diff --git a/.bazelrc b/.bazelrc index 42330a369c9f6..d58c889b9a998 100644 --- a/.bazelrc +++ b/.bazelrc @@ -83,7 +83,7 @@ build --define tsl_protobuf_header_only=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true - +build --incompatible_strict_action_env build --spawn_strategy=standalone build -c opt @@ -134,7 +134,7 @@ build --experimental_link_static_libraries_once=false # Prevent regressions on those two incompatible changes # TODO: remove those flags when they are flipped in the default Bazel version TF uses. build --incompatible_enforce_config_setting_visibility -# TODO: also enable this flag after fixing the visbility violations +# TODO: also enable this flag after fixing the visibility violations # build --incompatible_config_setting_private_default_visibility # Default options should come above this line. @@ -243,18 +243,26 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" +build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -307,7 +315,7 @@ build:macos --copt=-w build:windows --copt=/W0 build:windows --host_copt=/W0 -# Suppress most C++ complier warnings to reduce log size but allow +# Suppress most C++ compiler warnings to reduce log size but allow # for specific warnings to still be present. build:linux --copt="-Wno-all" build:linux --copt="-Wno-extra" @@ -441,6 +449,9 @@ test:win_clang --host_linkopt=/FORCE:MULTIPLE # TODO(kanglan): Change v2's define to default behavior build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 +# Enable all targets in XLA +build:cpu_cross --define=with_cross_compiler_support=true + # Disable XLA on mobile. build:xla --define=with_xla_support=true # TODO: remove, it's on by default. build:android --define=with_xla_support=false @@ -527,8 +538,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +588,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -585,10 +597,16 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux -# Disable clang extention that rejects type definitions within offsetof. +# Enable support for all targets +build:release_base --config=cpu_cross + +# Disable clang extension that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within # offset of in the current version of ubp. @@ -665,12 +683,23 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + # Build configs for macOS Arm64 build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 @@ -685,13 +714,18 @@ test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors test:release_macos_base --build_tests_only --keep_going test:release_macos_base --flaky_test_attempts=3 +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt @@ -707,7 +741,7 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs # push to the cache. For macOS, use --config=tf_public_macos_cache -build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/january2024" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials # Public cache for macOS builds @@ -746,6 +780,11 @@ test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-os test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -760,46 +799,137 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test +# MACOS X86 PYCPP +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# CROSS-COMPILE MACOS X86 PYCPP +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS -# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# START CROSS-COMPILE CONFIGS # Set execution platform to Linux x86 # Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" # flags seem to be actually used to specify the execution platform details. It # seems it is this way because these flags are old and predate the distinction # between host and execution platform. -build:cross_compile_linux_arm64 --host_cpu=k8 -build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 +build:cross_compile_base --host_cpu=k8 +build:cross_compile_base --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# XLA related settings for cross-compiled build. Certain paths are +# different in the XLA repo. +build:cross_compile_base_xla --host_cpu=k8 +build:cross_compile_base_xla --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base_xla --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 + +build:rbe_cross_compile_base --config=rbe_base +build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# XLA depends on some local Python headers that are configured as Genrule. They +# are present on the local host machine but not on the remote execution machine, +# leading to build failures. To resolve the issue, the following line is added +# to make sure all Genrule targets are excuted locally. +build:rbe_cross_compile_base_xla --config=rbe_cross_compile_base +build:rbe_cross_compile_base_xla --strategy=Genrule=standalone + +# Due to the above strategy, all Genrule commands are executed locally, but the +# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are +# only executabe on the RBE (x86) machine, so the strategy_regexp options are +# added to override and run the actions using remote strategy. +build:rbe_cross_compile_base_xla --strategy_regexp='Generating code from table.*=remote' +build:rbe_cross_compile_base_xla --strategy_regexp='Generating flatbuffer files.*=remote' +build:rbe_cross_compile_base_xla --strategy_regexp='Executing genrule @llvm-project.*=remote' + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_base --strategy=TestRunner=local --build_tests_only +test:rbe_cross_compile_base --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors + +test:rbe_cross_compile_base_xla --config=rbe_cross_compile_base + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +build:cross_compile_linux_arm64 --config=cross_compile_base # Set the target CPU to Aarch64 build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 build:cross_compile_linux_arm64 --cpu=aarch64 build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -# RBE configs +# XLA uses different paths for platforms and crosstool_top. +build:cross_compile_linux_arm64_xla --config=cross_compile_base_xla +build:cross_compile_linux_arm64_xla --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64_xla --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 -build:rbe_cross_compile_linux_arm64 --config=rbe_base -build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base +test:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base + +# RBE cross-compile configs for XLA Linux Aarch64 +build:rbe_cross_compile_linux_arm64_xla --config=cross_compile_linux_arm64_xla +build:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla +test:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla -# Test-related settings below this point -# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to -# force all tests to run locally on the Aarch64 host. -test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local -test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors -test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only # END LINUX AARCH64 CROSS-COMPILE CONFIGS + +# START MACOS CROSS-COMPILE CONFIGS +build:cross_compile_macos_x86 --config=cross_compile_base +build:cross_compile_macos_x86 --config=nonccl +# Target Catalina (10.15) as the minimum supported OS +build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Set the target CPU to Darwin x86 +build:cross_compile_macos_x86 --platforms=//tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_macos_x86 --cpu=darwin +build:cross_compile_macos_x86 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +# When RBE cross-compiling for macOS, we need to explicitly register the +# toolchain. Otherwise, oddly, RBE complains that a "docker container must be +# specified". +build:cross_compile_macos_x86 --extra_toolchains=//tensorflow/tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() +# and transistions that use these flags work. +build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cross_compile/config/platform_mappings + +# RBE cross-compile configs for Darwin x86 +build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 +build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete +test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +# Increase the test timeout as tests often take longer on mac. +test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 +# END MACOS CROSS-COMPILE CONFIGS +# END CROSS-COMPILE CONFIGS + +# Try to load the XLA warnings config if available +try-import %workspace%/warnings.bazelrc diff --git a/.bazelversion b/.bazelversion index 204ac7c926e43..f3c238740e5bc 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -6.4.0 +6.5.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..720f2f898b3b3 --- /dev/null +++ b/.clang-format @@ -0,0 +1,6 @@ +BasedOnStyle: Google +Language: Cpp +PointerBindsToType: true +SortIncludes: Never +AlignTrailingComments: + Kind: Always diff --git a/.github/workflows/buildifier.yml b/.github/workflows/buildifier.yml index 9a6d821f4c023..00413b1f505d4 100644 --- a/.github/workflows/buildifier.yml +++ b/.github/workflows/buildifier.yml @@ -14,7 +14,7 @@ # ============================================================================ name: Buildifier on: - pull_request_target: + pull_request: env: # Have `go install` place binaries in $PATH @@ -27,13 +27,10 @@ jobs: run: shell: bash timeout-minutes: 1 - if: | - github.event.sender.type == 'User' || - contains(github.event.pull_request.body, 'FORCE_TEST_ACTIONS') steps: - name: "Checking out repository" uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 - name: "Install buildifier" - run: go install github.com/bazelbuild/buildtools/buildifier@7d855c5 + run: go install github.com/bazelbuild/buildtools/buildifier@433ea85 # 6.4.0 - name: "Run buildifier" run: buildifier --lint=warn --warnings=-out-of-order-load -r xla/ diff --git a/.github/workflows/check_contents.yml b/.github/workflows/check_contents.yml index ced9ecba8fb44..abbf80a18db63 100644 --- a/.github/workflows/check_contents.yml +++ b/.github/workflows/check_contents.yml @@ -20,7 +20,7 @@ # TODO(ddunleavy): Update this after METADATA files are consolidated. name: Check Contents on: - pull_request_target: + pull_request: env: # A bit tricky here: this lets us invoke python files outside of `bazel run` diff --git a/.github/workflows/clang_format.yml b/.github/workflows/clang_format.yml new file mode 100644 index 0000000000000..c262edb2385ad --- /dev/null +++ b/.github/workflows/clang_format.yml @@ -0,0 +1,35 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +name: Clang Format +on: + pull_request: + +jobs: + clang-format: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash + timeout-minutes: 1 + if: | + contains(github.event.pull_request.body, 'FORCE_TEST_ACTIONS') + steps: + - name: "Checking out repository" + uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - name: "Fetch HEAD of main branch" + run: git fetch origin main --depth=1 + - name: "Run clang-format" # Use pipx to get version that apt doesn't have by default + run: pipx run clang-format==17.0.6 --dry-run --Werror --verbose $(git diff --name-only origin/main HEAD -- '*.cc' '*.h') diff --git a/.github/workflows/trusted_partners.js b/.github/workflows/trusted_partners.js index 75a1ff082592b..d886ce5fc6ad4 100644 --- a/.github/workflows/trusted_partners.js +++ b/.github/workflows/trusted_partners.js @@ -76,8 +76,7 @@ const filter_action = async ({github, context, domain}) => { assignees.push('hawkinsp', 'yashk2810', 'skye'); } if (lowercased_title.includes('xla') || lowercased_title.includes('gpu')) { - assignees.push( - 'cheshire', 'gcforster', 'reedwm', 'chsigg', 'xla-rotation'); + assignees.push('cheshire', 'reedwm', 'xla-rotation'); } if (lowercased_title.includes('tf')) { assignees.push('rohan100jain', 'bfontain'); diff --git a/.gitignore b/.gitignore index d510ed43c0a06..ee6ca187cc7fa 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ bazel-bin bazel-out bazel-testlogs +# Ignore files produced by `configure` +.tf_configure.bazelrc +xla_configure.bazelrc +tools/python_bin_path.sh # Emacs autosaves *~ diff --git a/.kokoro/jax/build.sh b/.kokoro/jax/build.sh index d533623a868f0..4fbe0bd34d7c0 100644 --- a/.kokoro/jax/build.sh +++ b/.kokoro/jax/build.sh @@ -63,7 +63,7 @@ build_and_test_on_rbe_cpu() { --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ --config=mkl_open_source_only \ - --config="rbe_cpu_linux_py312" \ + --config="rbe_cpu_linux_py3.12" \ --config=tensorflow_testing_rbe_linux \ --test_env=JAX_NUM_GENERATED_CASES=25 \ --test_output=errors \ @@ -80,7 +80,7 @@ build_and_test_on_rbe_gpu() { --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ --config=mkl_open_source_only \ - --config="rbe_linux_cuda12.2_nvcc_py3.9" \ + --config="rbe_linux_cuda12.3_nvcc_py3.9" \ --config=tensorflow_testing_rbe_linux \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ diff --git a/.kokoro/linux/build.sh b/.kokoro/linux/build.sh index 635af61a6d3ed..7164e9520435f 100644 --- a/.kokoro/linux/build.sh +++ b/.kokoro/linux/build.sh @@ -26,12 +26,21 @@ function is_linux_gpu_job() { [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*gpu.* ]] } -# Pull the container (in case it was updated since the instance started) and -# store its SHA in the Sponge log. -docker pull "$DOCKER_IMAGE" -echo "TF_INFO_DOCKER_IMAGE,$DOCKER_IMAGE" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" -echo "TF_INFO_DOCKER_SHA,$(docker pull "$DOCKER_IMAGE" | sed -n '/Digest:/s/Digest: //g p')" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" +function is_linux_cpu_arm64_job() { + [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*arm64.*/.*cpu.* ]] +} + +function pull_docker_image_with_retries() { + # Pull the container (in case it was updated since the instance started) and + # store its SHA in the Sponge log. + docker pull "$DOCKER_IMAGE" || sleep 15 + docker pull "$DOCKER_IMAGE" || sleep 15 + docker pull "$DOCKER_IMAGE" + echo "TF_INFO_DOCKER_IMAGE,$DOCKER_IMAGE" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" + echo "TF_INFO_DOCKER_SHA,$(docker pull "$DOCKER_IMAGE" | sed -n '/Digest:/s/Digest: //g p')" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" +} +pull_docker_image_with_retries # Start a container in the background docker run --name xla -w /tf/xla -itd --rm \ -v "$KOKORO_ARTIFACTS_DIR/github/xla:/tf/xla" \ @@ -39,40 +48,55 @@ docker run --name xla -w /tf/xla -itd --rm \ "$DOCKER_IMAGE" \ bash -# bazelrc Files currently come from https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools -RC_FILE="/usertools/cpu.bazelrc" -TARGET_FILTER="" TAGS_FILTER="-no_oss,-oss_excluded,-oss_serial" ADDITIONAL_FLAGS="" -RBE_CONFIG="" +RBE_FLAGS="" +TARGET_FILTERS="-@tsl//tsl/platform:subprocess_test -@tsl//tsl/platform/cloud:google_auth_provider_test -@tsl//tsl/platform/cloud:oauth_client_test" if is_linux_gpu_job ; then TAGS_FILTER="$TAGS_FILTER,gpu,requires-gpu-nvidia,-no_gpu" - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute" - RC_FILE="/usertools/gpu.bazelrc" - RBE_CONFIG="rbe_linux_cuda_nvcc" + + # We are currently running XLA presubmits on machines with NVIDIA T4 GPUs, + # which have a compute compatibility of 7.5. Se we filter out all the tests + # that need a newer GPU: + UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{80,86,89,90}{,-only})" + TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" + + ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --nobuild_tests_only --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute" + RBE_FLAGS="--config=rbe_linux_cuda_nvcc --jobs=150" echo "***NOTE: nvidia-smi lists the highest CUDA version the driver supports, which may be different than the version of CUDA actually used!!***" nvidia-smi else TAGS_FILTER="$TAGS_FILTER,-gpu,-requires-gpu-nvidia" ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --config=nonccl" - RBE_CONFIG="rbe_linux_cpu" + + if is_linux_cpu_arm64_job ; then + TAGS_FILTER="$TAGS_FILTER,-no_aarch64" + ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --action_env PYTHON_BIN_PATH=/usr/bin/python3.11 --python_path=/usr/bin/python3.11" + # Some cross-compile tests are not working for XLA Linux Aarch64. + # TODO(ddunleavy): Revisit these when hermetic python is available. + TARGET_FILTERS="$TARGET_FILTERS -//xla/python_api:xla_shape_test -//xla/python_api:xla_literal_test -//xla/service:xla_aot_compile_stablehlo_cpu_test -//xla/tests:local_client_aot_test" + RBE_FLAGS="--config=rbe_cross_compile_linux_arm64_xla --jobs=150" + else + RBE_FLAGS="--config=rbe_linux_cpu --jobs=150" + ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --nobuild_tests_only" + fi fi # Build & test XLA -docker exec xla bazel --bazelrc=$RC_FILE \ +docker exec xla bazel \ test \ --build_tag_filters=$TAGS_FILTER \ --test_tag_filters=$TAGS_FILTER \ + --test_output=errors \ --keep_going \ --features=layering_check \ --profile=/tf/pkg/profile.json.gz \ --flaky_test_attempts=3 \ - --config=$RBE_CONFIG \ - --jobs=150 \ - --nobuild_tests_only \ + --config=warnings \ + $RBE_FLAGS \ $ADDITIONAL_FLAGS \ - -- //xla/... //build_tools/... $TARGET_FILTER + -- //xla/... //build_tools/... @tsl//tsl/... $TARGET_FILTERS # Print build time statistics, including critical path. diff --git a/.kokoro/macos/build.sh b/.kokoro/macos/build.sh index a10e65d85e11c..b3c97310752fb 100644 --- a/.kokoro/macos/build.sh +++ b/.kokoro/macos/build.sh @@ -81,6 +81,7 @@ bazel test \ --output_filter="" \ --macos_minimum_os=10.15 \ --keep_going \ + --test_output=errors \ --config=nonccl \ --build_tag_filters=$TAGS_FILTER --test_tag_filters=$TAGS_FILTER \ --test_size_filters=small,medium \ diff --git a/BUILD.bazel b/BUILD.bazel index e5d409b78c165..23644329d2371 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -2,7 +2,7 @@ load("@rules_license//rules:license.bzl", "license") package( default_applicable_licenses = [":license"], - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], ) licenses(["notice"]) diff --git a/README.md b/README.md index 65e8b8f2211c4..be0325eefc03b 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ and then see the [developer guide](docs/developer_guide.md). ## Contacts -* For questions, contact Thea Lamkin - thealamkin at google.com. +* For questions, contact the maintainers - maintainers at openxla.org ## Resources diff --git a/WORKSPACE b/WORKSPACE index 374838de234e0..7ba74d6276c2e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,3 +1,4 @@ +# buildifier: disable=load-on-top workspace(name = "xla") # Initialize the XLA repository and all dependencies. diff --git a/build_tools/BUILD b/build_tools/BUILD new file mode 100644 index 0000000000000..3111ccb2505db --- /dev/null +++ b/build_tools/BUILD @@ -0,0 +1,28 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +load("//xla:pytype.default.bzl", "pytype_strict_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +pytype_strict_library( + name = "test_utils", + testonly = True, + srcs = ["test_utils.py"], + visibility = ["//visibility:public"], +) diff --git a/build_tools/configure/BUILD b/build_tools/configure/BUILD new file mode 100644 index 0000000000000..3be5b6a044610 --- /dev/null +++ b/build_tools/configure/BUILD @@ -0,0 +1,82 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# Placeholder: load py_test +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") +load("//xla:pytype.default.bzl", "pytype_strict_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +pytype_strict_library( + name = "configure", + srcs = ["configure.py"], +) + +py_test( + name = "configure_test", + srcs = ["configure_test.py"], + data = [ + "testdata/clang.bazelrc", + "testdata/cuda_clang.bazelrc", + "testdata/gcc.bazelrc", + "testdata/nvcc_clang.bazelrc", + "testdata/nvcc_gcc.bazelrc", + ], + deps = [ + ":configure", + "//build_tools:test_utils", + "@absl_py//absl/testing:absltest", + ], +) + +# Below targets are just for checking if the host/CUDA compiler are configured +# as expected. +cc_library( + name = "assert_clang", + srcs = ["assert_clang.cc"], + tags = ["manual"], +) + +cc_library( + name = "assert_gcc", + srcs = ["assert_gcc.cc"], + tags = ["manual"], +) + +cuda_library( + name = "assert_cuda_clang", + srcs = ["assert_cuda_clang.cu.cc"], + tags = [ + "gpu", + "manual", + ], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +cuda_library( + name = "assert_nvcc", + srcs = ["assert_nvcc.cu.cc"], + tags = [ + "gpu", + "manual", + ], + # Notably, this builds fine in OSS without this dependency. Apparently, + # NVCC can give targets access to CUDA headers without letting Bazel know, + # while CUDA clang cannot. + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/build_tools/configure/assert_clang.cc b/build_tools/configure/assert_clang.cc new file mode 100644 index 0000000000000..3dd57d1d1ff8b --- /dev/null +++ b/build_tools/configure/assert_clang.cc @@ -0,0 +1,18 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef __clang__ +#error "__clang__ not defined!" +#endif // #ifdef __clang__ diff --git a/build_tools/configure/assert_cuda_clang.cu.cc b/build_tools/configure/assert_cuda_clang.cu.cc new file mode 100644 index 0000000000000..12aeb2743b635 --- /dev/null +++ b/build_tools/configure/assert_cuda_clang.cu.cc @@ -0,0 +1,18 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !defined(__clang__) || !defined(__CUDA__) +#error "__clang__ or __CUDA__ not defined!" +#endif // #if !defined(__clang__) || !defined(__CUDA__) diff --git a/build_tools/configure/assert_gcc.cc b/build_tools/configure/assert_gcc.cc new file mode 100644 index 0000000000000..617da0d621a01 --- /dev/null +++ b/build_tools/configure/assert_gcc.cc @@ -0,0 +1,21 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Notably, clang will define `__GNUC__`, so need to make sure __clang__ is not +// defined to detect GCC (or, most correctly, some compiler that supports GNU +// extensions that is not clang). +#if !defined(__GNUC__) || defined(__clang__) +#error "__GNUC__ is not defined independently of __clang__!" +#endif // #if !defined(__GNUC__) || defined(__clang__) diff --git a/build_tools/configure/assert_nvcc.cu.cc b/build_tools/configure/assert_nvcc.cu.cc new file mode 100644 index 0000000000000..ea9287565755c --- /dev/null +++ b/build_tools/configure/assert_nvcc.cu.cc @@ -0,0 +1,17 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef __NVCC__ +#error "__NVCC__ not defined!" +#endif // #ifdef __NVCC__ diff --git a/build_tools/configure/configure.py b/build_tools/configure/configure.py new file mode 100755 index 0000000000000..663e4b8724280 --- /dev/null +++ b/build_tools/configure/configure.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python3 +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configure script to get build parameters from user. + +This script populates a bazelrc file that tells Bazel where to look for +cuda versions and compilers. Note: that a configuration is possible to request, +does not mean that it is supported (e.g. building with gcc). That being said, +if this stops working for you on an unsupported build and you have a fix, please +send a PR! + +Example usage: + `./configure.py --backend=cpu --host_compiler=clang` + Will write a bazelrc to the root of the repo with the lines required to find + the clang in your path. If that isn't the correct clang, you can override like + `./configure.py --backend=cpu --clang_path=`. + +NOTE(ddunleavy): Lots of these things should probably be outside of configure.py +but are here because of complexity in `cuda_configure.bzl` and the TF bazelrc. +Once XLA has it's own bazelrc, and cuda_configure.bzl is replaced or refactored, +we can probably make this file smaller. + +TODO(ddunleavy): add more thorough validation. +""" +import argparse +import dataclasses +import enum +import logging +import os +import pathlib +import shutil +import subprocess +import sys +from typing import Optional + +_REQUIRED_CUDA_LIBRARIES = ["cublas", "cuda", "cudnn"] +_DEFAULT_BUILD_AND_TEST_TAG_FILTERS = ("-no_oss",) +# Assume we are being invoked from the symlink at the root of the repo +_XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent +_FIND_CUDA_CONFIG = str( + _XLA_SRC_ROOT + / "third_party" + / "tsl" + / "third_party" + / "gpus" + / "find_cuda_config.py" +) +_XLA_BAZELRC_NAME = "xla_configure.bazelrc" +_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} + + +def _find_executable(executable: str) -> Optional[str]: + logging.info("Trying to find path to %s...", executable) + # Resolving the symlink is necessary for finding system headers. + if unresolved_path := shutil.which(executable): + return str(pathlib.Path(unresolved_path).resolve()) + return None + + +def _find_executable_or_die(executable: str) -> str: + """Finds executable and resolves symlinks or raises RuntimeError. + + Resolving symlinks is sometimes necessary for finding system headers. + + Args: + executable: The name of the executable that we want to find. + + Returns: + The path to the executable we are looking for. + Raises: + RuntimeError: if path to the executable cannot be found. + """ + resolved_path_to_exe = _find_executable(executable) + if resolved_path_to_exe is None: + raise RuntimeError( + f"Could not find executable `{executable}`! " + "Please change your $PATH or pass the path directly like" + f"`--{executable}_path=path/to/executable." + ) + logging.info("Found path to %s at %s", executable, resolved_path_to_exe) + + return resolved_path_to_exe + + +def _get_cuda_compute_capabilities_or_die() -> list[str]: + """Finds compute capabilities via nvidia-smi or rasies exception. + + Returns: + list of unique, sorted strings representing compute capabilities: + Raises: + RuntimeError: if path to nvidia-smi couldn't be found. + subprocess.CalledProcessError: if nvidia-smi process failed. + """ + try: + nvidia_smi = _find_executable_or_die("nvidia-smi") + nvidia_smi_proc = subprocess.run( + [nvidia_smi, "--query-gpu=compute_cap", "--format=csv,noheader"], + capture_output=True, + check=True, + text=True, + ) + # Command above returns a newline separated list of compute capabilities + # with possible repeats. So we should unique them and sort the final result. + capabilities = sorted(set(nvidia_smi_proc.stdout.strip().split("\n"))) + logging.info("Found CUDA compute capabilities: %s", capabilities) + return capabilities + except (RuntimeError, subprocess.CalledProcessError) as e: + logging.info( + "Could not find nvidia-smi, or nvidia-smi command failed. Please pass" + " capabilities directly using --cuda_compute_capabilities." + ) + raise e + + +def _get_clang_major_version(path_to_clang: str) -> int: + """Gets the major version of the clang at `path_to_clang`. + + Args: + path_to_clang: Path to a clang executable + + Returns: + The major version. + """ + logging.info("Running echo __clang_major__ | %s -E -P -", path_to_clang) + clang_version_proc = subprocess.run( + [path_to_clang, "-E", "-P", "-"], + input="__clang_major__", + check=True, + capture_output=True, + text=True, + ) + major_version = int(clang_version_proc.stdout) + logging.info("%s reports major version %s.", path_to_clang, major_version) + + return major_version + + +class ArgparseableEnum(enum.Enum): + """Enum base class with helper methods for working with argparse. + + Example usage: + ``` + class Fruit(ArgparseableEnum): + APPLE = enum.auto() + + # argparse setup + parser.add_argument("--fruit", type=Fruit.from_str, choices=list(Fruit)) + ``` + Users can pass strings like `--fruit=apple` with nice error messages and the + parser will get the corresponding enum value. + + NOTE: PyType gets confused when this class is used to create Enums in the + functional style like `ArgparseableEnum("Fruit", ["APPLE", "BANANA"])`. + """ + + def __str__(self): + return self.name + + @classmethod + def from_str(cls, s): + s = s.upper() + try: + return cls[s] + except KeyError: + # Sloppy looking exception handling, but argparse will catch ValueError + # and give a pleasant error message. KeyError would not work here. + raise ValueError # pylint: disable=raise-missing-from + + +class Backend(ArgparseableEnum): + CPU = enum.auto() + CUDA = enum.auto() + ROCM = enum.auto() + + +class HostCompiler(ArgparseableEnum): + CLANG = enum.auto() + GCC = enum.auto() + + +class CudaCompiler(ArgparseableEnum): + CLANG = enum.auto() + NVCC = enum.auto() + + +class OS(ArgparseableEnum): + LINUX = enum.auto() + MACOS = enum.auto() + WINDOWS = enum.auto() + + +@dataclasses.dataclass(**_KW_ONLY_IF_PYTHON310) +class DiscoverablePathsAndVersions: + """Paths to various tools and libraries needed to build XLA. + + This class is where all 'stateful' activity should happen, like trying to read + environment variables or looking for things in the $PATH. An instance that has + all fields set should not try to do any of these things though, so that this + file can remain unit testable. + """ + + clang_path: Optional[str] = None + clang_major_version: Optional[int] = None + gcc_path: Optional[str] = None + lld_path: Optional[str] = None + ld_library_path: Optional[str] = None + + # CUDA specific + cublas_version: Optional[str] = None + cuda_toolkit_path: Optional[str] = None + cuda_compute_capabilities: Optional[list[str]] = None + cudnn_version: Optional[str] = None + nccl_version: Optional[str] = None + + def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): + """Gets paths and versions as needed by the config. + + Args: + config: XLAConfigOptions instance that determines what paths and versions + to try to autoconfigure. + """ + if self.ld_library_path is None: + self.ld_library_path = os.environ.get("LD_LIBRARY_PATH", None) + + if config.host_compiler == HostCompiler.CLANG: + self.clang_path = self.clang_path or _find_executable_or_die("clang") + self.clang_major_version = ( + self.clang_major_version or _get_clang_major_version(self.clang_path) + ) + + # Notably, we don't use `_find_executable_or_die` for lld, as it changes + # which commands it accepts based on it's name! ld.lld is symlinked to a + # different executable just called lld, which should not be invoked + # directly. + self.lld_path = self.lld_path or shutil.which("ld.lld") + elif config.host_compiler == HostCompiler.GCC: + self.gcc_path = self.gcc_path or _find_executable_or_die("gcc") + + if config.backend == Backend.CUDA: + if config.cuda_compiler == CudaCompiler.CLANG: + self.clang_path = self.clang_path or _find_executable_or_die("clang") + + if not self.cuda_compute_capabilities: + self.cuda_compute_capabilities = _get_cuda_compute_capabilities_or_die() + + self._get_cuda_libraries_paths_and_versions_if_needed(config) + + def _get_cuda_libraries_paths_and_versions_if_needed( + self, config: "XLAConfigOptions" + ): + """Gets cuda paths and versions if user left any unspecified. + + This uses `find_cuda_config.py` to find versions for all libraries in + `_REQUIRED_CUDA_LIBRARIES`. + + Args: + config: config that determines which libraries should be found. + """ + should_find_nccl = config.using_nccl and self.nccl_version is None + any_cuda_config_unset = any([ + self.cublas_version is None, + self.cuda_toolkit_path is None, + self.cudnn_version is None, + should_find_nccl, + ]) + + maybe_nccl = ["nccl"] if should_find_nccl else [] + + if any_cuda_config_unset: + logging.info( + "Some CUDA config versions and paths were not provided, " + "so trying to find them using find_cuda_config.py" + ) + try: + find_cuda_config_proc = subprocess.run( + [ + sys.executable, + _FIND_CUDA_CONFIG, + *_REQUIRED_CUDA_LIBRARIES, + *maybe_nccl, + ], + capture_output=True, + check=True, + text=True, + ) + except subprocess.CalledProcessError as e: + logging.info("Command %s failed. Is CUDA installed?", e.cmd) + logging.info("Dumping %s ouptut:\n %s", e.cmd, e.output) + raise e + + cuda_config = dict( + tuple(line.split(": ")) + for line in find_cuda_config_proc.stdout.strip().split("\n") + ) + + self.cublas_version = self.cublas_version or cuda_config["cublas_version"] + self.cuda_toolkit_path = ( + self.cuda_toolkit_path or cuda_config["cuda_toolkit_path"] + ) + self.cudnn_version = self.cudnn_version or cuda_config["cudnn_version"] + if should_find_nccl: + self.nccl_version = self.nccl_version or cuda_config["nccl_version"] + + +@dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) +class XLAConfigOptions: + """Represents XLA configuration options.""" + + backend: Backend + os: OS + python_bin_path: str + host_compiler: HostCompiler + compiler_options: list[str] + + # CUDA specific + cuda_compiler: CudaCompiler + using_nccl: bool + using_tensorrt: bool + + def to_bazelrc_lines( + self, + dpav: DiscoverablePathsAndVersions, + ) -> list[str]: + """Creates a bazelrc given an XLAConfigOptions. + + Necessary paths are provided by the user, or retrieved via + `self._get_relevant_paths`. + + Args: + dpav: DiscoverablePathsAndVersions that may hold user-specified paths and + versions. The dpav will then read from `self` to determine what to try + to auto-configure. + + Returns: + The lines of a bazelrc. + """ + dpav.get_relevant_paths_and_versions(self) + rc = [] + build_and_test_tag_filters = list(_DEFAULT_BUILD_AND_TEST_TAG_FILTERS) + + # Platform independent options based on host compiler + if self.host_compiler == HostCompiler.GCC: + rc.append(f"build --action_env GCC_HOST_COMPILER_PATH={dpav.gcc_path}") + elif self.host_compiler == HostCompiler.CLANG: + rc.append(f"build --action_env CLANG_COMPILER_PATH={dpav.clang_path}") + rc.append(f"build --repo_env CC={dpav.clang_path}") + rc.append(f"build --repo_env BAZEL_COMPILER={dpav.clang_path}") + self.compiler_options.append("-Wno-error=unused-command-line-argument") + if dpav.lld_path: + rc.append(f"build --linkopt --ld-path={dpav.lld_path}") + + if self.backend == Backend.CPU: + build_and_test_tag_filters.append("-gpu") + + elif self.backend == Backend.CUDA: + compiler_pair = self.cuda_compiler, self.host_compiler + + if compiler_pair == (CudaCompiler.CLANG, HostCompiler.CLANG): + rc.append("build --config cuda_clang") + rc.append( + f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" + ) + elif compiler_pair == (CudaCompiler.NVCC, HostCompiler.CLANG): + rc.append("build --config nvcc_clang") + # This is demanded by cuda_configure.bzl + rc.append( + f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" + ) + elif compiler_pair == (CudaCompiler.NVCC, HostCompiler.GCC): + rc.append("build --config cuda") + else: + raise NotImplementedError( + "CUDA clang with host compiler gcc not supported" + ) + + # Lines needed for CUDA backend regardless of CUDA/host compiler + rc.append( + f"build --action_env CUDA_TOOLKIT_PATH={dpav.cuda_toolkit_path}" + ) + rc.append(f"build --action_env TF_CUBLAS_VERSION={dpav.cublas_version}") + rc.append( + "build --action_env" + f" TF_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + ) + rc.append(f"build --action_env TF_CUDNN_VERSION={dpav.cudnn_version}") + rc.append(f"build --repo_env TF_NEED_TENSORRT={int(self.using_tensorrt)}") + if self.using_nccl: + rc.append(f"build --action_env TF_NCCL_VERSION={dpav.nccl_version}") + else: + rc.append("build --config nonccl") + elif self.backend == Backend.ROCM: + pass + + # Lines that are added for every backend + if dpav.ld_library_path: + rc.append(f"build --action_env LD_LIBRARY_PATH={dpav.ld_library_path}") + + if dpav.clang_major_version in (16, 17): + self.compiler_options.append("-Wno-gnu-offsetof-extensions") + + rc.append(f"build --action_env PYTHON_BIN_PATH={self.python_bin_path}") + rc.append(f"build --python_path {self.python_bin_path}") + rc.append("test --test_env LD_LIBRARY_PATH") + rc.append("test --test_size_filters small,medium") + + rc.extend([ + f"build --copt {compiler_option}" + for compiler_option in self.compiler_options + ]) + + # Add build and test tag filters + build_and_test_tag_filters = ",".join(build_and_test_tag_filters) + rc.append(f"build --build_tag_filters {build_and_test_tag_filters}") + rc.append(f"build --test_tag_filters {build_and_test_tag_filters}") + rc.append(f"test --build_tag_filters {build_and_test_tag_filters}") + rc.append(f"test --test_tag_filters {build_and_test_tag_filters}") + + return rc + + +def _parse_args(): + """Creates an argparse.ArgumentParser and parses arguments.""" + comma_separated_list = lambda l: [s.strip() for s in l.split(",")] + + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument( + "--backend", + type=Backend.from_str, + choices=list(Backend), + required=True, + ) + parser.add_argument( + "--os", type=OS.from_str, choices=list(OS), default="linux" + ) + parser.add_argument( + "--host_compiler", + type=HostCompiler.from_str, + choices=list(HostCompiler), + default="clang", + ) + parser.add_argument( + "--cuda_compiler", + type=CudaCompiler.from_str, + choices=list(CudaCompiler), + default="nvcc", + ) + parser.add_argument( + "--cuda_compute_capabilities", + type=comma_separated_list, + default=None, + ) + parser.add_argument("--python_bin_path", default=sys.executable) + parser.add_argument( + "--compiler_options", + type=comma_separated_list, + default="-Wno-sign-compare", + ) + parser.add_argument("--nccl", action="store_true") + parser.add_argument("--tensorrt", action="store_true") + + # Path and version overrides + path_help = "Optional: will be found on PATH if possible." + parser.add_argument("--clang_path", help=path_help) + parser.add_argument("--gcc_path", help=path_help) + parser.add_argument( + "--ld_library_path", + help=( + "Optional: will be automatically taken from the current environment" + " if flag is not set" + ), + ) + parser.add_argument("--lld_path", help=path_help) + + # CUDA specific + find_cuda_config_help = ( + "Optional: will be found using `find_cuda_config.py` if flag is not set." + ) + parser.add_argument("--cublas_version", help=find_cuda_config_help) + parser.add_argument("--cuda_toolkit_path", help=find_cuda_config_help) + parser.add_argument("--cudnn_version", help=find_cuda_config_help) + parser.add_argument("--nccl_version", help=find_cuda_config_help) + + return parser.parse_args() + + +def main(): + # Setup logging + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + args = _parse_args() + + config = XLAConfigOptions( + backend=args.backend, + os=args.os, + host_compiler=args.host_compiler, + cuda_compiler=args.cuda_compiler, + python_bin_path=args.python_bin_path, + compiler_options=args.compiler_options, + using_nccl=args.nccl, + using_tensorrt=args.tensorrt, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=args.clang_path, + gcc_path=args.gcc_path, + lld_path=args.lld_path, + ld_library_path=args.ld_library_path, + cublas_version=args.cublas_version, + cuda_compute_capabilities=args.cuda_compute_capabilities, + cuda_toolkit_path=args.cuda_toolkit_path, + cudnn_version=args.cudnn_version, + nccl_version=args.nccl_version, + ) + ) + + bazelrc_path = _XLA_SRC_ROOT / _XLA_BAZELRC_NAME + bazelrc_contents = "\n".join(bazelrc_lines) + "\n" + + with (bazelrc_path).open("w") as f: + logging.info("Writing bazelrc to %s...", bazelrc_path) + f.write(bazelrc_contents) + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/build_tools/configure/configure_test.py b/build_tools/configure/configure_test.py new file mode 100644 index 0000000000000..c952c8f9241f4 --- /dev/null +++ b/build_tools/configure/configure_test.py @@ -0,0 +1,179 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from absl.testing import absltest + +from xla.build_tools import test_utils +from xla.build_tools.configure import configure + + +XLAConfigOptions = configure.XLAConfigOptions +DiscoverablePathsAndVersions = configure.DiscoverablePathsAndVersions +Backend = configure.Backend +HostCompiler = configure.HostCompiler +CudaCompiler = configure.CudaCompiler +OS = configure.OS + +_PYTHON_BIN_PATH = "/usr/bin/python3" +_CLANG_PATH = "/usr/lib/llvm-17/bin/clang" +_GCC_PATH = "/usr/bin/gcc" +_COMPILER_OPTIONS = ("-Wno-sign-compare",) + +# CUDA specific paths and versions +_CUDA_SPECIFIC_PATHS_AND_VERSIONS = { + "cublas_version": "12.3", + "cuda_toolkit_path": "/usr/local/cuda-12.2", + "cuda_compute_capabilities": ["7.5"], + "cudnn_version": "8", + "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", + "nccl_version": "2", +} + + +class ConfigureTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + testdata = ( + test_utils.xla_src_root() / "build_tools" / "configure" / "testdata" + ) + + with (testdata / "clang.bazelrc").open() as f: + cls.clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "gcc.bazelrc").open() as f: + cls.gcc_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "cuda_clang.bazelrc").open() as f: + cls.cuda_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "nvcc_clang.bazelrc").open() as f: + cls.nvcc_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + + with (testdata / "nvcc_gcc.bazelrc").open() as f: + cls.nvcc_gcc_bazelrc_lines = [line.strip() for line in f.readlines()] + + def test_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CPU, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + ld_library_path="", + clang_major_version=17, + ) + ) + + self.assertEqual(bazelrc_lines, self.clang_bazelrc_lines) + + def test_gcc_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CPU, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.GCC, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + gcc_path=_GCC_PATH, + ld_library_path="", + ) + ) + + self.assertEqual(bazelrc_lines, self.gcc_bazelrc_lines) + + def test_cuda_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.CLANG, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.cuda_clang_bazelrc_lines) + + def test_nvcc_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.nvcc_clang_bazelrc_lines) + + def test_nvcc_gcc_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.GCC, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.NVCC, + using_nccl=False, + using_tensorrt=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + gcc_path=_GCC_PATH, + **_CUDA_SPECIFIC_PATHS_AND_VERSIONS, + ) + ) + + self.assertEqual(bazelrc_lines, self.nvcc_gcc_bazelrc_lines) + + +if __name__ == "__main__": + absltest.main() diff --git a/build_tools/configure/testdata/clang.bazelrc b/build_tools/configure/testdata/clang.bazelrc new file mode 100644 index 0000000000000..317be65966633 --- /dev/null +++ b/build_tools/configure/testdata/clang.bazelrc @@ -0,0 +1,14 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss,-gpu +build --test_tag_filters -no_oss,-gpu +test --build_tag_filters -no_oss,-gpu +test --test_tag_filters -no_oss,-gpu diff --git a/build_tools/configure/testdata/cuda_clang.bazelrc b/build_tools/configure/testdata/cuda_clang.bazelrc new file mode 100644 index 0000000000000..b998cf06935f3 --- /dev/null +++ b/build_tools/configure/testdata/cuda_clang.bazelrc @@ -0,0 +1,23 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --config cuda_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/build_tools/configure/testdata/gcc.bazelrc b/build_tools/configure/testdata/gcc.bazelrc new file mode 100644 index 0000000000000..8eefec15ee8ef --- /dev/null +++ b/build_tools/configure/testdata/gcc.bazelrc @@ -0,0 +1,10 @@ +build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --build_tag_filters -no_oss,-gpu +build --test_tag_filters -no_oss,-gpu +test --build_tag_filters -no_oss,-gpu +test --test_tag_filters -no_oss,-gpu diff --git a/build_tools/configure/testdata/nvcc_clang.bazelrc b/build_tools/configure/testdata/nvcc_clang.bazelrc new file mode 100644 index 0000000000000..912dc50faff4c --- /dev/null +++ b/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -0,0 +1,23 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --repo_env CC=/usr/lib/llvm-17/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-17/bin/clang +build --config nvcc_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-17/bin/clang +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/build_tools/configure/testdata/nvcc_gcc.bazelrc b/build_tools/configure/testdata/nvcc_gcc.bazelrc new file mode 100644 index 0000000000000..863209697362d --- /dev/null +++ b/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -0,0 +1,18 @@ +build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc +build --config cuda +build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 +build --action_env TF_CUBLAS_VERSION=12.3 +build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 +build --action_env TF_CUDNN_VERSION=8 +build --repo_env TF_NEED_TENSORRT=0 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/build_tools/docker/context/install_bazel.sh b/build_tools/docker/context/install_bazel.sh index c9294f2b2bc93..1cca040d2f589 100755 --- a/build_tools/docker/context/install_bazel.sh +++ b/build_tools/docker/context/install_bazel.sh @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/docker/context/install_python_deps.sh b/build_tools/docker/context/install_python_deps.sh index 4e0073e5f5257..e22e0d920131b 100755 --- a/build_tools/docker/context/install_python_deps.sh +++ b/build_tools/docker/context/install_python_deps.sh @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/docker/dockerfiles/benchmarking.Dockerfile b/build_tools/docker/dockerfiles/benchmarking.Dockerfile index aaa5bff67c9b9..6159024ab3ecd 100644 --- a/build_tools/docker/dockerfiles/benchmarking.Dockerfile +++ b/build_tools/docker/dockerfiles/benchmarking.Dockerfile @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/github_actions/build_xla.sh b/build_tools/github_actions/build_xla.sh index d3fea2f921911..127f376f8b06f 100755 --- a/build_tools/github_actions/build_xla.sh +++ b/build_tools/github_actions/build_xla.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/lint/BUILD b/build_tools/lint/BUILD index 5387e3b3097cb..b4b825c925425 100644 --- a/build_tools/lint/BUILD +++ b/build_tools/lint/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -load("//xla:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_library") +load("//xla:pytype.default.bzl", "pytype_strict_library") # Placeholder: load py_test package( @@ -33,7 +33,7 @@ pytype_strict_library( visibility = ["//visibility:public"], ) -pytype_strict_binary( +pytype_strict_library( name = "generate_compile_commands", srcs = ["generate_compile_commands.py"], ) @@ -45,9 +45,9 @@ py_test( "testdata/bad_cc.diff", "testdata/important_cc.diff", ], - tags = ["no_oss"], deps = [ ":check_contents", + "//build_tools:test_utils", "@absl_py//absl/testing:absltest", ], ) @@ -60,9 +60,18 @@ py_test( "testdata/crosstool.diff", "testdata/important_cc.diff", ], - tags = ["no_oss"], deps = [ ":diff_parser", + "//build_tools:test_utils", + "@absl_py//absl/testing:absltest", + ], +) + +py_test( + name = "generate_compile_commands_test", + srcs = ["generate_compile_commands_test.py"], + deps = [ + ":generate_compile_commands", "@absl_py//absl/testing:absltest", ], ) diff --git a/build_tools/lint/check_contents.py b/build_tools/lint/check_contents.py index 1649152148d1a..f61876017e981 100644 --- a/build_tools/lint/check_contents.py +++ b/build_tools/lint/check_contents.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import logging # Intended to run on vanilla Github Actions runner import re import sys -from typing import Iterable, Optional, Sequence +from typing import Iterable, Sequence from xla.build_tools.lint import diff_parser @@ -92,7 +92,7 @@ def check_diffs( hunks: Iterable[diff_parser.Hunk], *, prohibited_regex: str, - suppression_regex: Optional[str] = None, # TODO(ddunleavy): CI not on 3.10 + suppression_regex: str | None = None, ) -> list[RegexLocation]: """Checks FileDiffs for prohibited regexes. diff --git a/build_tools/lint/check_contents_test.py b/build_tools/lint/check_contents_test.py index e97781f3270ca..21d58785c6f6a 100644 --- a/build_tools/lint/check_contents_test.py +++ b/build_tools/lint/check_contents_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================ from absl.testing import absltest +from xla.build_tools import test_utils from xla.build_tools.lint import check_contents from xla.build_tools.lint import diff_parser @@ -24,11 +25,11 @@ class CheckDiffsTest(absltest.TestCase): def setUpClass(cls): super().setUpClass() - base_path = "third_party/xla/build_tools/lint" - with open(f"{base_path}/testdata/bad_cc.diff") as f: + testdata = test_utils.xla_src_root() / "build_tools" / "lint" / "testdata" + with (testdata / "bad_cc.diff").open() as f: cls.bad_cc_hunks = diff_parser.parse_hunks(f.read()) - with open(f"{base_path}/testdata/important_cc.diff") as f: + with (testdata / "important_cc.diff").open() as f: cls.important_cc_hunks = diff_parser.parse_hunks(f.read()) def test_check_good_diff(self): diff --git a/build_tools/lint/diff_parser.py b/build_tools/lint/diff_parser.py index 6700264861ca5..ac459dccf6fbd 100644 --- a/build_tools/lint/diff_parser.py +++ b/build_tools/lint/diff_parser.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/lint/diff_parser_test.py b/build_tools/lint/diff_parser_test.py index 8e4d2acd75a3d..787020cc86503 100644 --- a/build_tools/lint/diff_parser_test.py +++ b/build_tools/lint/diff_parser_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================ from absl.testing import absltest +from xla.build_tools import test_utils from xla.build_tools.lint import diff_parser @@ -23,15 +24,15 @@ class ParseDiffTest(absltest.TestCase): def setUpClass(cls): super().setUpClass() - base_path = "third_party/xla/build_tools/lint" + testdata = test_utils.xla_src_root() / "build_tools" / "lint" / "testdata" - with open(f"{base_path}/testdata/bad_cc.diff") as f: + with (testdata / "bad_cc.diff").open() as f: cls.bad_cc_diff = f.read() - with open(f"{base_path}/testdata/important_cc.diff") as f: + with (testdata / "important_cc.diff").open() as f: cls.important_cc_diff = f.read() - with open(f"{base_path}/testdata/crosstool.diff") as f: + with (testdata / "crosstool.diff").open() as f: cls.crosstool_diff = f.read() def test_parse_important_cc_diff(self): diff --git a/build_tools/lint/generate_compile_commands.py b/build_tools/lint/generate_compile_commands.py index 735fc53f8aa8a..ec9d6fe0d2037 100644 --- a/build_tools/lint/generate_compile_commands.py +++ b/build_tools/lint/generate_compile_commands.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build_tools/lint/generate_compile_commands_test.py b/build_tools/lint/generate_compile_commands_test.py new file mode 100644 index 0000000000000..95318726f7f17 --- /dev/null +++ b/build_tools/lint/generate_compile_commands_test.py @@ -0,0 +1,54 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from absl.testing import absltest + +from xla.build_tools.lint import generate_compile_commands + +CompileCommand = generate_compile_commands.CompileCommand + + +class CompileCommandsTest(absltest.TestCase): + + def test_command_from_args_list(self): + arguments = [ + "/usr/bin/gcc", + "-DTEST_DEFINE", + "-fstack-protector", + "-c", + "xla/compiler.cc", + "-o", + "bazel-out/k8-opt/bin/xla/_objs/compiler/compiler.pic.o", + ] + + command = CompileCommand.from_args_list(arguments) + + self.assertEqual(command.file, "xla/compiler.cc") + self.assertEqual(command.arguments, arguments) + + def test_command_from_args_list_with_disallowed_option(self): + arguments = [ + "/usr/bin/gcc", + "-DTEST_DEFINE", + "-fno-canonical-system-headers", + "-c", + "xla/compiler.cc", + "-o", + "bazel-out/k8-opt/bin/xla/_objs/compiler/compiler.pic.o", + ] + + command = CompileCommand.from_args_list(arguments) + + self.assertEqual(command.file, "xla/compiler.cc") + self.assertEqual(command.arguments, arguments[0:2] + arguments[3:]) diff --git a/build_tools/test_utils.py b/build_tools/test_utils.py new file mode 100644 index 0000000000000..1d9672379d9ca --- /dev/null +++ b/build_tools/test_utils.py @@ -0,0 +1,28 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test utils for python tests in XLA.""" +import os +import pathlib + + +def xla_src_root() -> pathlib.Path: + """Gets the path to the root of the XLA source tree.""" + is_oss = "BAZEL_TEST" in os.environ + test_srcdir = os.environ["TEST_SRCDIR"] + test_workspace = os.environ["TEST_WORKSPACE"] + if is_oss: + return pathlib.Path(test_srcdir) / test_workspace + else: + return pathlib.Path(test_srcdir) / test_workspace / "third_party" / "xla" diff --git a/configure b/configure deleted file mode 100755 index e43908e39da0c..0000000000000 --- a/configure +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env bash - -set -e -set -o pipefail - -if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$(which python3 || which python || true) -fi - -# Set all env variables -CONFIGURE_DIR=$(dirname "$0") -"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" - -echo "Configuration finished" - diff --git a/configure.cmd b/configure.cmd deleted file mode 100644 index 738e106da18fb..0000000000000 --- a/configure.cmd +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -:: WARRANTIES OR CONDITIONS OF ANY KIND< either express or implied. See the -:: License for the specific language governing permissions and limitations under -:: the License. - -@echo off - -set configure_dir=%~dp0 -set configure_dir=%configure_dir:~0,-1% -python "%configure_dir%\configure.py" %* || ( exit /b ) -echo Configuration finished diff --git a/configure.py b/configure.py deleted file mode 100644 index b2b30dfa2956b..0000000000000 --- a/configure.py +++ /dev/null @@ -1,1102 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""configure script to get build parameters from user.""" - -import argparse -import os -import pathlib -import platform -import re -import subprocess -import sys - -# pylint: disable=g-import-not-at-top -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top - -_DEFAULT_CUDA_VERSION = '11' -_DEFAULT_CUDNN_VERSION = '2' -_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '5.2,7.0' - -_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 - -_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' -_TF_WORKSPACE_ROOT = '' -_TF_BAZELRC = '' -_TF_CURRENT_BAZEL_VERSION = None - - -class UserInputError(Exception): - pass - - -def is_windows(): - return platform.system() == 'Windows' - - -def is_linux(): - return platform.system() == 'Linux' - - -def is_macos(): - return platform.system() == 'Darwin' - - -def is_ppc64le(): - return platform.machine() == 'ppc64le' - - -def is_cygwin(): - return platform.system().startswith('CYGWIN_NT') - - -def get_input(question): - try: - try: - answer = raw_input(question) - except NameError: - answer = input(question) # pylint: disable=bad-builtin - except EOFError: - answer = '' - return answer - - -def write_to_bazelrc(line): - with open(_TF_BAZELRC, 'a') as f: - f.write(line + '\n') - - -def write_action_env_to_bazelrc(var_name, var): - write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) - - -def run_shell(cmd, allow_non_zero=False, stderr=None): - if stderr is None: - stderr = sys.stdout - if allow_non_zero: - try: - output = subprocess.check_output(cmd, stderr=stderr) - except subprocess.CalledProcessError as e: - output = e.output - else: - output = subprocess.check_output(cmd, stderr=stderr) - return output.decode('UTF-8').strip() - - -def cygpath(path): - """Convert path from posix to windows.""" - return os.path.abspath(path).replace('\\', '/') - - -def get_python_path(environ_cp, python_bin_path): - """Get the python site package paths.""" - python_paths = [] - if environ_cp.get('PYTHONPATH'): - python_paths = environ_cp.get('PYTHONPATH').split(':') - try: - stderr = open(os.devnull, 'wb') - library_paths = run_shell([ - python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))' - ], - stderr=stderr).split('\n') - except subprocess.CalledProcessError: - library_paths = [ - run_shell([ - python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())' - ]) - ] - - all_paths = set(python_paths + library_paths) - # Sort set so order is deterministic - all_paths = sorted(all_paths) - - paths = [] - for path in all_paths: - if os.path.isdir(path): - paths.append(path) - return paths - - -def get_python_major_version(python_bin_path): - """Get the python major version.""" - return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])']) - - -def setup_python(environ_cp): - """Setup python related env variables.""" - # Get PYTHON_BIN_PATH, default is the current running python. - default_python_bin_path = sys.executable - ask_python_bin_path = ('Please specify the location of python. [Default is ' - '{}]: ').format(default_python_bin_path) - while True: - python_bin_path = get_from_env_or_user_or_default(environ_cp, - 'PYTHON_BIN_PATH', - ask_python_bin_path, - default_python_bin_path) - # Check if the path is valid - if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): - break - elif not os.path.exists(python_bin_path): - print('Invalid python path: {} cannot be found.'.format(python_bin_path)) - else: - print('{} is not executable. Is it the python binary?'.format( - python_bin_path)) - environ_cp['PYTHON_BIN_PATH'] = '' - - # Convert python path to Windows style before checking lib and version - if is_windows() or is_cygwin(): - python_bin_path = cygpath(python_bin_path) - - # Get PYTHON_LIB_PATH - python_lib_path = environ_cp.get('PYTHON_LIB_PATH') - if not python_lib_path: - python_lib_paths = get_python_path(environ_cp, python_bin_path) - if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1': - python_lib_path = python_lib_paths[0] - else: - print('Found possible Python library paths:\n %s' % - '\n '.join(python_lib_paths)) - default_python_lib_path = python_lib_paths[0] - python_lib_path = get_input( - 'Please input the desired Python library path to use. ' - 'Default is [{}]\n'.format(python_lib_paths[0])) - if not python_lib_path: - python_lib_path = default_python_lib_path - environ_cp['PYTHON_LIB_PATH'] = python_lib_path - - python_major_version = get_python_major_version(python_bin_path) - if python_major_version == '2': - write_to_bazelrc('build --host_force_python=PY2') - - # Convert python path to Windows style before writing into bazel.rc - if is_windows() or is_cygwin(): - python_lib_path = cygpath(python_lib_path) - - # Set-up env variables used by python_configure.bzl - write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) - write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) - environ_cp['PYTHON_BIN_PATH'] = python_bin_path - - # If choosen python_lib_path is from a path specified in the PYTHONPATH - # variable, need to tell bazel to include PYTHONPATH - if environ_cp.get('PYTHONPATH'): - python_paths = environ_cp.get('PYTHONPATH').split(':') - if python_lib_path in python_paths: - write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) - - # Write tools/python_bin_path.sh - with open( - os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), - 'w') as f: - f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path)) - - -def reset_tf_configure_bazelrc(): - """Reset file that contains customized config settings.""" - open(_TF_BAZELRC, 'w').close() - - -def get_var(environ_cp, - var_name, - query_item, - enabled_by_default, - question=None, - yes_reply=None, - no_reply=None): - """Get boolean input from user. - - If var_name is not set in env, ask user to enable query_item or not. If the - response is empty, use the default. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - query_item: string for feature related to the variable, e.g. "CUDA for - Nvidia GPUs". - enabled_by_default: boolean for default behavior. - question: optional string for how to ask for user input. - yes_reply: optional string for reply when feature is enabled. - no_reply: optional string for reply when feature is disabled. - - Returns: - boolean value of the variable. - - Raises: - UserInputError: if an environment variable is set, but it cannot be - interpreted as a boolean indicator, assume that the user has made a - scripting error, and will continue to provide invalid input. - Raise the error to avoid infinitely looping. - """ - if not question: - question = 'Do you wish to build XLA with {} support?'.format( - query_item) - if not yes_reply: - yes_reply = '{} support will be enabled for XLA.'.format(query_item) - if not no_reply: - no_reply = 'No {}'.format(yes_reply) - - yes_reply += '\n' - no_reply += '\n' - - if enabled_by_default: - question += ' [Y/n]: ' - else: - question += ' [y/N]: ' - - var = environ_cp.get(var_name) - if var is not None: - var_content = var.strip().lower() - true_strings = ('1', 't', 'true', 'y', 'yes') - false_strings = ('0', 'f', 'false', 'n', 'no') - if var_content in true_strings: - var = True - elif var_content in false_strings: - var = False - else: - raise UserInputError( - 'Environment variable %s must be set as a boolean indicator.\n' - 'The following are accepted as TRUE : %s.\n' - 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % - (var_name, ', '.join(true_strings), ', '.join(false_strings), var)) - - while var is None: - user_input_origin = get_input(question) - user_input = user_input_origin.strip().lower() - if user_input == 'y': - print(yes_reply) - var = True - elif user_input == 'n': - print(no_reply) - var = False - elif not user_input: - if enabled_by_default: - print(yes_reply) - var = True - else: - print(no_reply) - var = False - else: - print('Invalid selection: {}'.format(user_input_origin)) - return var - - -def set_action_env_var(environ_cp, - var_name, - query_item, - enabled_by_default, - question=None, - yes_reply=None, - no_reply=None, - bazel_config_name=None): - """Set boolean action_env variable. - - Ask user if query_item will be enabled. Default is used if no input is given. - Set environment variable and write to .bazelrc. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - query_item: string for feature related to the variable, e.g. "CUDA for - Nvidia GPUs". - enabled_by_default: boolean for default behavior. - question: optional string for how to ask for user input. - yes_reply: optional string for reply when feature is enabled. - no_reply: optional string for reply when feature is disabled. - bazel_config_name: adding config to .bazelrc instead of action_env. - """ - var = int( - get_var(environ_cp, var_name, query_item, enabled_by_default, question, - yes_reply, no_reply)) - - if not bazel_config_name: - write_action_env_to_bazelrc(var_name, var) - elif var: - write_to_bazelrc('build --config=%s' % bazel_config_name) - environ_cp[var_name] = str(var) - - -def convert_version_to_int(version): - """Convert a version number to a integer that can be used to compare. - - Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The - 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored. - - Args: - version: a version to be converted - - Returns: - An integer if converted successfully, otherwise return None. - """ - version = version.split('-')[0] - version_segments = version.split('.') - # Treat "0.24" as "0.24.0" - if len(version_segments) == 2: - version_segments.append('0') - for seg in version_segments: - if not seg.isdigit(): - return None - - version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) - return int(version_str) - - -def retrieve_bazel_version(): - """Retrieve installed bazel version (or bazelisk). - - Returns: - The bazel version detected. - """ - bazel_executable = which('bazel') - if bazel_executable is None: - bazel_executable = which('bazelisk') - if bazel_executable is None: - print('Cannot find bazel. Please install bazel/bazelisk.') - sys.exit(1) - - stderr = open(os.devnull, 'wb') - curr_version = run_shell([bazel_executable, '--version'], - allow_non_zero=True, - stderr=stderr) - if curr_version.startswith('bazel '): - curr_version = curr_version.split('bazel ')[1] - - curr_version_int = convert_version_to_int(curr_version) - - # Check if current bazel version can be detected properly. - if not curr_version_int: - print('WARNING: current bazel installation is not a release version.') - return curr_version - - print('You have bazel %s installed.' % curr_version) - return curr_version - - -def set_cc_opt_flags(environ_cp): - """Set up architecture-dependent optimization flags. - - Also append CC optimization flags to bazel.rc.. - - Args: - environ_cp: copy of the os.environ. - """ - if is_ppc64le(): - # gcc on ppc64le does not support -march, use mcpu instead - default_cc_opt_flags = '-mcpu=native' - elif is_windows(): - default_cc_opt_flags = '/arch:AVX' - else: - # On all other platforms, no longer use `-march=native` as this can result - # in instructions that are too modern being generated. Users that want - # maximum performance should compile TF in their environment and can pass - # `-march=native` there. - # See https://github.com/tensorflow/tensorflow/issues/45744 and duplicates - default_cc_opt_flags = '-Wno-sign-compare' - question = ('Please specify optimization flags to use during compilation when' - ' bazel option "--config=opt" is specified [Default is %s]: ' - ) % default_cc_opt_flags - cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', - question, default_cc_opt_flags) - for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --copt=%s' % opt) - write_to_bazelrc('build:opt --host_copt=%s' % opt) - - -def set_tf_cuda_clang(environ_cp): - """set TF_CUDA_CLANG action_env. - - Args: - environ_cp: copy of the os.environ. - """ - question = 'Do you want to use clang as CUDA compiler?' - yes_reply = 'Clang will be used as CUDA compiler.' - no_reply = 'nvcc will be used as CUDA compiler.' - set_action_env_var( - environ_cp, - 'TF_CUDA_CLANG', - None, - False, - question=question, - yes_reply=yes_reply, - no_reply=no_reply, - bazel_config_name='cuda_clang') - - -def set_tf_download_clang(environ_cp): - """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you wish to download a fresh release of clang? (Experimental)' - yes_reply = 'Clang will be downloaded and used to compile tensorflow.' - no_reply = 'Clang will not be downloaded.' - set_action_env_var( - environ_cp, - 'TF_DOWNLOAD_CLANG', - None, - False, - question=question, - yes_reply=yes_reply, - no_reply=no_reply, - bazel_config_name='download_clang') - - -def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, - var_default): - """Get var_name either from env, or user or default. - - If var_name has been set as environment variable, use the preset value, else - ask for user input. If no input is provided, the default is used. - - Args: - environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". - ask_for_var: string for how to ask for user input. - var_default: default value string. - - Returns: - string value for var_name - """ - var = environ_cp.get(var_name) - if not var: - var = get_input(ask_for_var) - print('\n') - if not var: - var = var_default - return var - - -def set_clang_cuda_compiler_path(environ_cp): - """Set CLANG_CUDA_COMPILER_PATH.""" - default_clang_path = which('clang') or '' - ask_clang_path = ('Please specify which clang should be used as device and ' - 'host compiler. [Default is %s]: ') % default_clang_path - - while True: - clang_cuda_compiler_path = get_from_env_or_user_or_default( - environ_cp, 'CLANG_CUDA_COMPILER_PATH', ask_clang_path, - default_clang_path) - if os.path.exists(clang_cuda_compiler_path): - break - - # Reset and retry - print('Invalid clang path: %s cannot be found.' % clang_cuda_compiler_path) - environ_cp['CLANG_CUDA_COMPILER_PATH'] = '' - - # Set CLANG_CUDA_COMPILER_PATH - environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path - write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH', - clang_cuda_compiler_path) - - -def prompt_loop_or_load_from_env(environ_cp, - var_name, - var_default, - ask_for_var, - check_success, - error_msg, - suppress_default_error=False, - resolve_symlinks=False, - n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): - """Loop over user prompts for an ENV param until receiving a valid response. - - For the env param var_name, read from the environment or verify user input - until receiving valid input. When done, set var_name in the environ_cp to its - new value. - - Args: - environ_cp: (Dict) copy of the os.environ. - var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". - var_default: (String) default value string. - ask_for_var: (String) string for how to ask for user input. - check_success: (Function) function that takes one argument and returns a - boolean. Should return True if the value provided is considered valid. May - contain a complex error message if error_msg does not provide enough - information. In that case, set suppress_default_error to True. - error_msg: (String) String with one and only one '%s'. Formatted with each - invalid response upon check_success(input) failure. - suppress_default_error: (Bool) Suppress the above error message in favor of - one from the check_success function. - resolve_symlinks: (Bool) Translate symbolic links into the real filepath. - n_ask_attempts: (Integer) Number of times to query for valid input before - raising an error and quitting. - - Returns: - [String] The value of var_name after querying for input. - - Raises: - UserInputError: if a query has been attempted n_ask_attempts times without - success, assume that the user has made a scripting error, and will - continue to provide invalid input. Raise the error to avoid infinitely - looping. - """ - default = environ_cp.get(var_name) or var_default - full_query = '%s [Default is %s]: ' % ( - ask_for_var, - default, - ) - - for _ in range(n_ask_attempts): - val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, - default) - if check_success(val): - break - if not suppress_default_error: - print(error_msg % val) - environ_cp[var_name] = '' - else: - raise UserInputError('Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % - (var_name, n_ask_attempts)) - - if resolve_symlinks: - val = os.path.realpath(val) - environ_cp[var_name] = val - return val - - -def set_gcc_host_compiler_path(environ_cp): - """Set GCC_HOST_COMPILER_PATH.""" - default_gcc_host_compiler_path = which('gcc') or '' - cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') - - if os.path.islink(cuda_bin_symlink): - # os.readlink is only available in linux - default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - - gcc_host_compiler_path = prompt_loop_or_load_from_env( - environ_cp, - var_name='GCC_HOST_COMPILER_PATH', - var_default=default_gcc_host_compiler_path, - ask_for_var='Please specify which gcc should be used by nvcc as the host ' - 'compiler.', - check_success=os.path.exists, - resolve_symlinks=True, - error_msg='Invalid gcc path. %s cannot be found.', - ) - - write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) - - -def set_tf_cuda_paths(environ_cp): - """Set TF_CUDA_PATHS.""" - ask_cuda_paths = ( - 'Please specify the comma-separated list of base paths to look for CUDA ' - 'libraries and headers. [Leave empty to use the default]: ') - tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS', - ask_cuda_paths, '') - if tf_cuda_paths: - environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths - - -def set_tf_cuda_version(environ_cp): - """Set TF_CUDA_VERSION.""" - ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use. ' - '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - tf_cuda_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDA_VERSION', - ask_cuda_version, - _DEFAULT_CUDA_VERSION) - environ_cp['TF_CUDA_VERSION'] = tf_cuda_version - - -def set_tf_cudnn_version(environ_cp): - """Set TF_CUDNN_VERSION.""" - ask_cudnn_version = ( - 'Please specify the cuDNN version you want to use. ' - '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION - tf_cudnn_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDNN_VERSION', - ask_cudnn_version, - _DEFAULT_CUDNN_VERSION) - environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version - - -def set_tf_nccl_version(environ_cp): - """Set TF_NCCL_VERSION.""" - if not is_linux(): - raise ValueError('Currently NCCL is only supported on Linux platform.') - - if 'TF_NCCL_VERSION' in environ_cp: - return - - ask_nccl_version = ( - 'Please specify the locally installed NCCL version you want to use. ' - '[Leave empty to use http://github.com/nvidia/nccl]: ') - tf_nccl_version = get_from_env_or_user_or_default(environ_cp, - 'TF_NCCL_VERSION', - ask_nccl_version, '') - environ_cp['TF_NCCL_VERSION'] = tf_nccl_version - - -def get_native_cuda_compute_capabilities(environ_cp): - """Get native cuda compute capabilities. - - Args: - environ_cp: copy of the os.environ. - - Returns: - string of native cuda compute capabilities, separated by comma. - """ - device_query_bin = os.path.join( - environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') - if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK): - try: - output = run_shell(device_query_bin).split('\n') - pattern = re.compile('[0-9]*\\.[0-9]*') - output = [pattern.search(x) for x in output if 'Capability' in x] - output = ','.join(x.group() for x in output if x is not None) - except subprocess.CalledProcessError: - output = '' - else: - output = '' - return output - - -def set_tf_cuda_compute_capabilities(environ_cp): - """Set TF_CUDA_COMPUTE_CAPABILITIES.""" - while True: - native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( - environ_cp) - if not native_cuda_compute_capabilities: - default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES - else: - default_cuda_compute_capabilities = native_cuda_compute_capabilities - - ask_cuda_compute_capabilities = ( - 'Please specify a list of comma-separated CUDA compute capabilities ' - 'you want to build with.\nYou can find the compute capability of your ' - 'device at: https://developer.nvidia.com/cuda-gpus. Each capability ' - 'can be specified as "x.y" or "compute_xy" to include both virtual and' - ' binary GPU code, or as "sm_xy" to only include the binary ' - 'code.\nPlease note that each additional compute capability ' - 'significantly increases your build time and binary size, and that ' - 'XLA only supports compute capabilities >= 5.2 [Default is: ' - '%s]: ' % default_cuda_compute_capabilities - ) - tf_cuda_compute_capabilities = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', - ask_cuda_compute_capabilities, default_cuda_compute_capabilities) - # Check whether all capabilities from the input is valid - all_valid = True - # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error - tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) - for compute_capability in tf_cuda_compute_capabilities.split(','): - m = re.match('[0-9]+.[0-9]+', compute_capability) - if not m: - # We now support sm_52,compute_70. - sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)', - compute_capability) - if not sm_compute_match: - print('Invalid compute capability: %s' % compute_capability) - all_valid = False - else: - ver = int(sm_compute_match.group(2)) - if ver < 52: - print( - 'ERROR: XLA only supports small CUDA compute' - ' capabilities of sm_52 and higher. Please re-specify the list' - ' of compute capabilities excluding version %s.' % ver - ) - all_valid = False - else: - ver = float(m.group(0)) - if ver < 5.2: - print( - 'ERROR: XLA only supports CUDA compute capabilities 5.2 ' - 'and higher. Please re-specify the list of compute ' - 'capabilities excluding version %s.' % ver - ) - all_valid = False - - if all_valid: - break - - # Reset and Retry - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' - - # Set TF_CUDA_COMPUTE_CAPABILITIES - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities - write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES', - tf_cuda_compute_capabilities) - - -def set_other_cuda_vars(environ_cp): - """Set other CUDA related variables.""" - # If CUDA is enabled, always use GPU during build and test. - if environ_cp.get('TF_CUDA_CLANG') == '1': - write_to_bazelrc('build --config=cuda_clang') - else: - write_to_bazelrc('build --config=cuda') - - -def system_specific_test_config(environ_cp): - """Add default build and test flags required for TF tests to bazelrc.""" - write_to_bazelrc('test --flaky_test_attempts=3') - write_to_bazelrc('test --test_size_filters=small,medium') - - # Each instance of --test_tag_filters or --build_tag_filters overrides all - # previous instances, so we need to build up a complete list and write a - # single list of filters for the .bazelrc file. - - # Filters to use with both --test_tag_filters and --build_tag_filters - test_and_build_filters = ['-benchmark-test', '-no_oss', '-oss_excluded'] - # Additional filters for --test_tag_filters beyond those in - # test_and_build_filters - test_only_filters = ['-oss_serial'] - if is_windows(): - test_and_build_filters += ['-no_windows', '-windows_excluded'] - if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or - (environ_cp.get('TF_NEED_ROCM', None) == '1')): - test_and_build_filters += ['-no_windows_gpu', '-no_gpu'] - else: - test_and_build_filters.append('-gpu') - elif is_macos(): - test_and_build_filters += ['-gpu', '-nomac', '-no_mac', '-mac_excluded'] - elif is_linux(): - if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or - (environ_cp.get('TF_NEED_ROCM', None) == '1')): - test_and_build_filters.append('-no_gpu') - write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') - else: - test_and_build_filters.append('-gpu') - if environ_cp.get('TF_NEED_ROCM', None) == '1': - test_and_build_filters.append('-no_rocm') - - write_to_bazelrc('test --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('test --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - write_to_bazelrc('build --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('build --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - - # Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config - write_to_bazelrc('test:v1 --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters)) - write_to_bazelrc('test:v1 --build_tag_filters=%s' % - ','.join(test_and_build_filters)) - write_to_bazelrc( - 'test:v2 --test_tag_filters=%s' % - ','.join(test_and_build_filters + test_only_filters + ['-v1only'])) - write_to_bazelrc('test:v2 --build_tag_filters=%s' % - ','.join(test_and_build_filters + ['-v1only'])) - - -def set_system_libs_flag(environ_cp): - syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') - if syslibs: - if ',' in syslibs: - syslibs = ','.join(sorted(syslibs.split(','))) - else: - syslibs = ','.join(sorted(syslibs.split())) - write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - - for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): - if varname in environ_cp: - write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) - - -def set_windows_build_flags(): - """Set Windows specific build options.""" - - # First available in VS 16.4. Speeds up Windows compile times by a lot. See - # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion - # pylint: disable=line-too-long - write_to_bazelrc( - 'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions' - ) - - -def config_info_line(name, help_text): - """Helper function to print formatted help text for Bazel config options.""" - print('\t--config=%-12s\t# %s' % (name, help_text)) - - -def validate_cuda_config(environ_cp): - """Run find_cuda_config.py and return cuda_toolkit_path, or None.""" - - def maybe_encode_env(env): - """Encodes unicode in env to str on Windows python 2.x.""" - if not is_windows() or sys.version_info[0] != 2: - return env - for k, v in env.items(): - if isinstance(k, unicode): - k = k.encode('ascii') - if isinstance(v, unicode): - v = v.encode('ascii') - env[k] = v - return env - - cuda_libraries = ['cuda', 'cudnn'] - if is_linux(): - if environ_cp.get('TF_NCCL_VERSION', None): - cuda_libraries.append('nccl') - - find_cuda_script = os.path.join( - pathlib.Path(__file__).parent.resolve(), - 'third_party/tsl/third_party/gpus/find_cuda_config.py', - ) - if not os.path.isfile(find_cuda_script): - raise FileNotFoundError( - "Can't find 'find_cuda_config.py' script inside working directory," - f' expected in {find_cuda_script}' - ) - proc = subprocess.Popen( - [environ_cp['PYTHON_BIN_PATH'], find_cuda_script] + cuda_libraries, - stdout=subprocess.PIPE, - env=maybe_encode_env(environ_cp), - ) - - if proc.wait(): - # Errors from find_cuda_config.py were sent to stderr. - print('Asking for detailed CUDA configuration...\n') - return False - - config = dict( - tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout) - - print('Found CUDA %s in:' % config['cuda_version']) - print(' %s' % config['cuda_library_dir']) - print(' %s' % config['cuda_include_dir']) - - print('Found cuDNN %s in:' % config['cudnn_version']) - print(' %s' % config['cudnn_library_dir']) - print(' %s' % config['cudnn_include_dir']) - - if config.get('nccl_version', None): - print('Found NCCL %s in:' % config['nccl_version']) - print(' %s' % config['nccl_library_dir']) - print(' %s' % config['nccl_include_dir']) - - print('\n') - - environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path'] - return True - - -def get_gcc_compiler(environ_cp): - gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or which('gcc') - if gcc_env is not None: - gcc_version = run_shell([gcc_env, '--version']).split() - if gcc_version[0] in ('gcc', 'g++'): - return gcc_env - return None - - -def main(): - global _TF_WORKSPACE_ROOT - global _TF_BAZELRC - global _TF_CURRENT_BAZEL_VERSION - - parser = argparse.ArgumentParser() - parser.add_argument( - '--workspace', - type=str, - default=os.path.abspath(os.path.dirname(__file__)), - help='The absolute path to your active Bazel workspace.') - args = parser.parse_args() - - _TF_WORKSPACE_ROOT = args.workspace - _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) - - # Make a copy of os.environ to be clear when functions and getting and setting - # environment variables. - environ_cp = dict(os.environ) - - try: - current_bazel_version = retrieve_bazel_version() - except subprocess.CalledProcessError as e: - print('Error retrieving bazel version: ', e.output.decode('UTF-8').strip()) - raise e - - _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) - - reset_tf_configure_bazelrc() - - setup_python(environ_cp) - - if is_windows(): - environ_cp['TF_NEED_OPENCL'] = '0' - environ_cp['TF_CUDA_CLANG'] = '0' - # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - environ_cp['TF_NEED_MPI'] = '0' - - if is_ppc64le(): - # Enable MMA Dynamic Dispatch support if 'gcc' and if linker >= 2.35 - gcc_env = get_gcc_compiler(environ_cp) - if gcc_env is not None: - - # Use gold linker if 'gcc' and if 'ppc64le' - write_to_bazelrc('build --linkopt="-fuse-ld=gold"') - - # Get the linker version - ld_version = run_shell([gcc_env, '-Wl,-version']).split() - - ld_version_int = convert_version_to_int(ld_version[3]) - if ld_version_int is None: - ld_version_int = convert_version_to_int(ld_version[4]) - - # Enable if 'ld' version >= 2.35 - if ld_version_int >= 2035000: - write_to_bazelrc( - 'build --copt="-DEIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH=1"') - - set_action_env_var( - environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') - if (environ_cp.get('TF_NEED_ROCM') == '1' and - 'LD_LIBRARY_PATH' in environ_cp and - environ_cp.get('LD_LIBRARY_PATH') != '1'): - write_action_env_to_bazelrc('LD_LIBRARY_PATH', - environ_cp.get('LD_LIBRARY_PATH')) - - if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): - write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) - - if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')): - write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM')) - - if is_windows(): - print('\nWARNING: Cannot build with CUDA support on Windows.\n' - 'Starting in TF 2.11, CUDA build is not supported for Windows. ' - 'For using XLA GPU on Windows, you will need to build/install ' - 'XLA in WSL2.\n') - environ_cp['TF_NEED_CUDA'] = '0' - else: - environ_cp['TF_NEED_CUDA'] = str( - int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) - if (environ_cp.get('TF_NEED_CUDA') == '1' and - 'TF_CUDA_CONFIG_REPO' not in environ_cp): - - environ_save = dict(environ_cp) - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - - if validate_cuda_config(environ_cp): - cuda_env_names = [ - 'TF_CUDA_VERSION', - 'TF_CUBLAS_VERSION', - 'TF_CUDNN_VERSION', - 'TF_NCCL_VERSION', - 'TF_CUDA_PATHS', - # Items below are for backwards compatibility when not using - # TF_CUDA_PATHS. - 'CUDA_TOOLKIT_PATH', - 'CUDNN_INSTALL_PATH', - 'NCCL_INSTALL_PATH', - 'NCCL_HDR_PATH', - ] - # Note: set_action_env_var above already writes to bazelrc. - for name in cuda_env_names: - if name in environ_cp: - write_action_env_to_bazelrc(name, environ_cp[name]) - break - - # Restore settings changed below if CUDA config could not be validated. - environ_cp = dict(environ_save) - - set_tf_cuda_version(environ_cp) - set_tf_cudnn_version(environ_cp) - if is_linux(): - set_tf_nccl_version(environ_cp) - - set_tf_cuda_paths(environ_cp) - - else: - raise UserInputError( - 'Invalid CUDA setting were provided %d ' - 'times in a row. Assuming to be a scripting mistake.' % - _DEFAULT_PROMPT_ASK_ATTEMPTS) - - set_tf_cuda_compute_capabilities(environ_cp) - if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( - 'LD_LIBRARY_PATH') != '1': - write_action_env_to_bazelrc('LD_LIBRARY_PATH', - environ_cp.get('LD_LIBRARY_PATH')) - - set_tf_cuda_clang(environ_cp) - if environ_cp.get('TF_CUDA_CLANG') == '1': - # Ask whether we should download the clang toolchain. - set_tf_download_clang(environ_cp) - if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': - # Set up which clang we should use as the cuda / host compiler. - set_clang_cuda_compiler_path(environ_cp) - else: - # Use downloaded LLD for linking. - write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld') - else: - # Set up which gcc nvcc should use as the host compiler - # No need to set this on Windows - if not is_windows(): - set_gcc_host_compiler_path(environ_cp) - set_other_cuda_vars(environ_cp) - else: - # CUDA not required. Ask whether we should download the clang toolchain and - # use it for the CPU build. - set_tf_download_clang(environ_cp) - - # ROCm / CUDA are mutually exclusive. - # At most 1 GPU platform can be configured. - gpu_platform_count = 0 - if environ_cp.get('TF_NEED_ROCM') == '1': - gpu_platform_count += 1 - if environ_cp.get('TF_NEED_CUDA') == '1': - gpu_platform_count += 1 - if gpu_platform_count >= 2: - raise UserInputError('CUDA / ROCm are mututally exclusive. ' - 'At most 1 GPU platform can be configured.') - - # Disable NCCL if XLA is configured for CPU - if gpu_platform_count == 0: - write_to_bazelrc('build --config=nonccl') - - set_cc_opt_flags(environ_cp) - set_system_libs_flag(environ_cp) - if is_windows(): - set_windows_build_flags() - - system_specific_test_config(environ_cp) - - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See .bazelrc for more ' - 'details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line( - 'mkl_aarch64', - 'Build with oneDNN and Compute Library for the Arm Architecture (ACL).') - config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('numa', 'Build with NUMA support.') - config_info_line( - 'dynamic_kernels', - '(Experimental) Build kernels into separate shared objects.') - config_info_line('v1', 'Build with TensorFlow 1 API instead of TF 2 API.') - - print('Preconfigured Bazel build configs to DISABLE default on features:') - config_info_line('nogcp', 'Disable GCP support.') - - if gpu_platform_count == 1: - config_info_line('nonccl', 'Disable NVIDIA NCCL support.') - - -if __name__ == '__main__': - main() diff --git a/configure.py b/configure.py new file mode 120000 index 0000000000000..49938cbbc570f --- /dev/null +++ b/configure.py @@ -0,0 +1 @@ +build_tools/configure/configure.py \ No newline at end of file diff --git a/docs/_book.yaml b/docs/_book.yaml deleted file mode 100644 index a6030d45a9949..0000000000000 --- a/docs/_book.yaml +++ /dev/null @@ -1,47 +0,0 @@ -upper_tabs: -# Tabs left of dropdown menu -- include: /_upper_tabs_left.yaml -- include: /api_docs/_upper_tabs_api.yaml -# Dropdown menu -- name: Resources - path: /resources - is_default: true - menu: - - include: /resources/_menu_toc.yaml - lower_tabs: - # Subsite tabs - other: - - name: Overview - contents: - - heading: OpenXLA - - title: Overview - path: /xla - - title: XLA architecture - path: /xla/architecture - - title: Broadcasting semantics - path: /xla/broadcasting - - title: Develop a new backend for XLA - path: /xla/developing_new_backend - - title: Code Reviews Guide - path: /xla/code_reviews - - title: Operation semantics - path: /xla/operation_semantics - - title: Shapes and layout - path: /xla/shapes - - title: Aliasing - path: /xla/aliasing - - title: Tiled layout - path: /xla/tiled_layout - - title: Writing custom calls - path: /xla/custom_call - - heading: TensorFlow - XLA - - title: Known issues - path: /xla/known_issues - - title: Use AOT compilation - path: /xla/tfcompile - - title: XLA autoclustering - path: /xla/tutorials/autoclustering_xla - - title: Use XLA with tf.function - path: /xla/tutorials/jit_compile - -- include: /_upper_tabs_right.yaml diff --git a/docs/_toc.yaml b/docs/_toc.yaml new file mode 100644 index 0000000000000..50a24a1a6607c --- /dev/null +++ b/docs/_toc.yaml @@ -0,0 +1,51 @@ +toc: +- heading: XLA developer guide +- title: Getting started + section: + - title: Overview + path: /xla + - title: XLA architecture + path: /xla/architecture + - title: Operation semantics + path: /xla/operation_semantics +- title: Developer details + section: + - title: Broadcasting + path: /xla/broadcasting + - title: Shapes and layout + path: /xla/shapes + - title: Aliasing + path: /xla/aliasing + - title: Tiled layout + path: /xla/tiled_layout + - title: Writing custom calls + path: /xla/custom_call + - title: Persisted autotuning + path: /xla/persisted_autotuning + - title: Copybara quirks + path: /xla/copybara + - title: XLA Tooling + path: /xla/tools + - title: Using LSP autocompletion + path: /xla/lsp +- title: Contributing + section: + - title: Develop a new backend for XLA + path: /xla/developing_new_backend + - title: Developer guide + path: /xla/developer_guide + - title: Code reviews + path: /xla/code_reviews + - title: Build from source + path: /xla/build_from_source +- title: Using XLA in TensorFlow + section: + - title: Using XLA in TensorFlow + path: /xla/tf2xla + - title: Use tfcompile + path: /xla/tf2xla/tfcompile + - title: Autoclustering tutorial + path: /xla/tf2xla/tutorials/autoclustering_xla + - title: Use XLA with tf.function + path: /xla/tf2xla/tutorials/jit_compile + diff --git a/docs/async_ops.md b/docs/async_ops.md new file mode 100644 index 0000000000000..a2f7c1dbc7aff --- /dev/null +++ b/docs/async_ops.md @@ -0,0 +1,110 @@ +# Async HLO Instructions + +1. Adding async operations to HLO is cumbersome (i.e. `all-reduce-start` and + `all-reduce-done`). +2. The start and done split may be inadequate for some of the asynchronous use + cases. + +To target the first shortcoming, we propose to introduce one last set of new +asynchronous opcodes: `kAsyncStart`, `kAsyncUpdate`, and `kAsyncDone`. The idea +is to create a generic asynchronous opcode that can wrap any HLO instruction. +The actual operation that will be performed asynchronously will be encoded using +a called computation that only has the instruction as its root and any +parameters for inputs. The in-flight input/output buffer handling and aliasing +can then be shared for any asynchronous operation. The async-start instruction’s +output shape will then be a tuple of the input operands, output values, and any +intermediate state that is needed for the `async-update` or `async-done` +instructions. + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start) +``` + +In the representation above, only `async-start` has a called computation since +it is trivial to find what the `async-done` does by following its operand to +find the corresponding `async-start` to find the called computation. + +Also note +that the first element in the output tuple of `async-start` aliases with the +operand, so the buffer stays alive until at least the async-done instruction. +Similarly, the second element aliases with the output of `async-done`, and the +third element is the context state that is used to keep track of the +asynchronous operation. This representation also supports multiple tensors in +the asynchronous operation input and/or output and the aliasing works the same +way: + +``` +%async_op { + %param0 = f32[64] parameter(0) + %param1 = f32[64] parameter(1) + ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1), + op_specific_attr=”foo” +} + +%async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[]) + async-start(f32[64] %operand0, f32[64] %operand1), + calls=%async_op +%async-done = (f32[32], f32[32]) async-done(%async-start) +``` + +In addition, the op can further be decomposed into zero or more `async-update` +steps that perform intermediate computations. The input/output aliasing works +the same way with the `async-update` instruction and each `async-start` and +`async-update` instructions must have one user that is either another +`async-update` or an `async-done`: + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-update0 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-start) +%async-update1 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-update0) +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-update1) + +``` + +## Syntax sugar + +Since having a separate computation to define the operation that will be +performed asynchronously is a bit cumbersome, we also propose a syntax sugar to +automatically print and parse asynchronous operations as if they are first-class +opcodes. The idea is to treat the “-start”, “-update”, and “-done” suffixes +specially by automatically creating the computation and instruction (without the +suffix) when parsing. For example, the code snippet above can be pretty-printed +to the following and the two can be parsed to the same representation: + +``` +%op-start = (f32[64], f32[32], s32[]) op-start(f32[64] %operand), + op_specific_attr=”foo” +%op-update0 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-start), + op_specific_attr=”foo” +%op-update1 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-update0) +%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1) + +``` + +In order not to create ambiguities, the verifier will not allow an operation to +be wrapped with async-start if we explicitly defined an opcode for that +operation with the “-start” and/or “-done” suffixes. This is also an escape +hatch in case we have any instructions that require HLO-level treatment that +doesn’t fit in the model described above (e.g. the aliasing input/output +buffers). So, initially, `copy-start`/`copy-done`, +`collective-permute-start`/`collective-permute-done` etc. will continue to use +their respective first-class opcodes instead of the new +`async-start`/`async-done` opcodes until we clean up the code to remove these +“-start”/”-done” opcodes. diff --git a/docs/build_from_source.md b/docs/build_from_source.md index f5b2ded3c4cd4..c273f7f3cdf8c 100644 --- a/docs/build_from_source.md +++ b/docs/build_from_source.md @@ -10,22 +10,12 @@ If you did not clone the XLA repository or install Bazel, please check out the ### Configure XLA builds are configured by the `.bazelrc` file in the repository's root -directory. The `./configure` or `./configure.py` scripts can be used to adjust -common settings. +directory. The `./configure.py` script can be used to adjust common settings. -If you need to change the configuration, run the `./configure` script from the -repository's root directory. This script will prompt you for the location of XLA -dependencies and asks for additional build configuration options (compiler -flags, for example). Refer to the *Sample session* section for details. - -``` -./configure -``` - -There is also a python version of this script, `./configure.py`. If using a -virtual environment, `python configure.py` prioritizes paths within the -environment, whereas `./configure` prioritizes paths outside the environment. In -both cases you can change the default. +If you need to change the configuration, run the `./configure.py` script from +the repository's root directory. This script has flags for the location of XLA +dependencies and additional build configuration options (compiler flags, for +example). Refer to the *Sample session* section for details. ### CPU support @@ -33,56 +23,60 @@ We recommend using a suitable docker container to build/test XLA, such as [TensorFlow's docker container](https://www.tensorflow.org/install/docker): ``` -docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/tensorflow:latest-gpu bash +docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash ``` -Using a docker container you can build XLA with CPU support using the following commands: +Using a docker container you can build XLA with CPU support using the following +commands: ``` -docker exec xla ./configure +docker exec xla ./configure.py --backend=CPU docker exec xla bazel build //xla/... --spawn_strategy=sandboxed --test_output=all ``` -If you want to build XLA targets with CPU support without Docker you need to install gcc-10: +If you want to build XLA targets with CPU support without Docker you need to +install clang. XLA currently builds on CI with clang-17, but earlier versions +should also work: ``` -apt install gcc-10 g++-10 +apt install clang ``` Then configure and build targets using the following commands: -``` -yes '' | GCC_HOST_COMPILER_PATH=/usr/bin/gcc-10 CC=/usr/bin/gcc-10 TF_NEED_ROCM=0 TF_NEED_CUDA=0 TF_CUDA_CLANG=0 ./configure +```sh +./configure.py --backend=CPU bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` - ### GPU support -We recommend using a GPU docker container to build XLA with GPU support, such -as: +We recommend using the same docker container as above to build XLA with GPU +support: ``` -docker run --name xla_gpu -w /xla -it -d --rm -v $PWD:/xla tensorflow/tensorflow:devel-gpu bash +docker run --name xla_gpu -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash ``` To build XLA with GPU support use the following command: ``` -docker exec -e TF_NEED_CUDA=1 xla_gpu ./configure +docker exec xla_gpu ./configure.py --backend=CUDA docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` -If you want to build XLA targets with GPU support without Docker you need to install the following dependencies additional to CPU dependencies: [`cuda-11.2`](https://developer.nvidia.com/cuda-11.2.2-download-archive), [`cuDNN-8.1`](https://developer.nvidia.com/cudnn). +If you want to build XLA targets with GPU support without Docker you need to +install the following additional dependencies: +[`cuda-12.3`](https://developer.nvidia.com/cuda-downloads), +[`cuDNN-8.9`](https://developer.nvidia.com/cudnn). Then configure and build targets using the following commands: ``` -yes '' | GCC_HOST_COMPILER_PATH=/usr/bin/gcc-10 CC=/usr/bin/gcc-10 TF_NEED_ROCM=0 TF_NEED_CUDA=1 TF_CUDA_CLANG=0 ./configure +./configure.py --backend=CUDA bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` - For more details regarding [TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) diff --git a/docs/custom_call.md b/docs/custom_call.md index bd2bff418b176..84633d697daa2 100644 --- a/docs/custom_call.md +++ b/docs/custom_call.md @@ -14,6 +14,15 @@ program. > to change it capriciously, but it may change. Some possible future changes are > described below. +> **Caution** The HLO-visible names of functions registered with the custom-call +> macros API do not respect C++ namespaces. As a result, accidental collisions +> from functions registered by different libraries are entirely possible! The +> API will reject such duplicate registrations, but to avoid issues in large +> projects the safest option is to either fully namespace-qualify all references +> to the functions in both the `XLA_REGISTER_CUSTOM_CALL` registration macros +> and custom call target references or to use C-style namespacing directly in +> the function name. + ## Create a custom call on CPU You can create an HLO instruction that represents a custom call via XLA's client diff --git a/docs/developer_guide.md b/docs/developer_guide.md index 63bd56491cdde..53b3efcd8cab5 100644 --- a/docs/developer_guide.md +++ b/docs/developer_guide.md @@ -41,33 +41,27 @@ the repository, and create a pull request. 2. Create and run a [TensorFlow Docker container](https://www.tensorflow.org/install/docker). - To get the TensorFlow Docker image for CPU, run the following command: + To get the TensorFlow Docker image for both CPU and GPU building, run the + following command: ```sh docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash ``` - Alternatively, to get the TensorFlow Docker image for GPU, run the following - command: - - ```sh - docker run --name xla_gpu -w /xla -it -d --rm -v $PWD:/xla tensorflow/tensorflow:devel-gpu bash - ``` - ## Build Build for CPU: ```sh -docker exec xla ./configure +docker exec xla ./configure.py --backend=CPU docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` Build for GPU: ```sh -docker exec -e TF_NEED_CUDA=1 xla_gpu ./configure -docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... +docker exec xla ./configure.py --backend=CUDA +docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` Your first build will take quite a while because it has to build the entire diff --git a/docs/images/indexing_analysis_softmax.png b/docs/images/indexing_analysis_softmax.png new file mode 100644 index 0000000000000..6edeb05acb238 Binary files /dev/null and b/docs/images/indexing_analysis_softmax.png differ diff --git a/docs/images/indexing_analysis_transposes.svg b/docs/images/indexing_analysis_transposes.svg new file mode 100644 index 0000000000000..4a7d0f8c62d21 --- /dev/null +++ b/docs/images/indexing_analysis_transposes.svg @@ -0,0 +1,316 @@ + + + + + + + +G + + +Computation f (in fusion instruction fusion) + + + + +5734836738048 + +Parameter 0 +f32[20,10,50]{2,1,0} + + + +5734836737024 + +transpose +lhs_transpose_1 +dimensions={1,0,2} +f32[10,20,50]{2,1,0} + + + +5734836738048->5734836737024 + + + + + + + + +5734836733952 + +transpose +rhs_transpose_1 +dimensions={2,1,0} +f32[50,10,20]{2,1,0} + + + +5734836738048->5734836733952 + + + + + + + + +5734836736000 + +exponential +lhs_e +f32[10,20,50]{2,1,0} + + + +5734836737024->5734836736000 + + + + + + + + +5734836734976 + +transpose +lhs_transpose_2 +dimensions={0,2,1} +f32[10,50,20]{2,1,0} + + + +5734836736000->5734836734976 + + + + + + + + +5734836698112 + +add +f32[10,50,20]{2,1,0} + + + +5734836734976->5734836698112 + + + + + +0 + + + +5734836732928 + +exponential +rhs_log +f32[50,10,20]{2,1,0} + + + +5734836733952->5734836732928 + + + + + + + + +5734836731904 + +transpose +rhs_transpose_2 +dimensions={1,0,2} +f32[10,50,20]{2,1,0} + + + +5734836732928->5734836731904 + + + + + + + + +5734836731904->5734836698112 + + + + + +1 + + + +cluster_5734844288888 + + +ROOT + + + + + +5734836698112->cluster_5734844288888 + + + + + + + + diff --git a/docs/index.md b/docs/index.md index dce5b593748fe..76bbb657f9fe7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -37,16 +37,6 @@ Alibaba, Amazon Web Services, AMD, Apple, Arm, Google, Intel, Meta, and NVIDIA. ## Documentation -To learn more about XLA, check out the guides below. If you're a new XLA +To learn more about XLA, check out the links on the left. If you're a new XLA developer, you might want to start with [XLA architecture](architecture.md) and then read [Code reviews](code_reviews.md). - -- [Aliasing in XLA](aliasing.md) -- [XLA architecture](architecture.md) -- [Broadcasting](broadcasting.md) -- [Code reviews](code_reviews.md) -- [XLA custom calls](custom_call.md) -- [Developing a new backend for XLA](developing_new_backend.md) -- [Operation semantics](operation_semantics.md) -- [Shapes and layout](shapes.md) -- [Tiled layout](tiled_layout.md) diff --git a/docs/indexing.md b/docs/indexing.md new file mode 100644 index 0000000000000..4ee61f16f6203 --- /dev/null +++ b/docs/indexing.md @@ -0,0 +1,567 @@ +# Indexing analysis + +This document describes the HLO indexing analysis, which lets you symbolically +compute indexing maps for HLO ops. The indexing map is a function that maps +indices of one tensor to the indices of another, e.g. indices of an HLO +instruction output to indices of HLO instruction inputs or vice versa. + +#### Example + +For a broadcast from `tensor<20xf32>` to `tensor<10x20x30xf32>` + +```c +p0 = f32[20] parameter(0) +bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} +``` + +the indexing map from the output to input is $(i, j, k) \mapsto (j)$ for $i \in +[0, 10]$, $j \in [0, 20]$ and $k \in [0, 30]$. + +## Motivation + +XLA GPU uses several bespoke solutions to reason about coalescing, operand +utilization, and tiling schemes (more details below). The goal of indexing +analysis is providing a reusable component for such use cases. Indexing analysis +is built on MLIR's Affine Map infrastructure and adds HLO semantics. + +### Coalescing + +Reasoning about memory coalescing becomes feasible for non-trivial cases, when +we know what elements/slices of the inputs are read to compute an element of the +output. + +### Operand Utilization + +Operand utilization in XLA indicates how much each input of the instruction is +used assuming its output is fully used. Currently, utilization is also not +computed for a generic case. Indexing analysis allows to compute utilization +precisely. + +### Tiling + +A tile/slice is hyper-rectangular subset of a tensor parameterized by offsets, +sizes and strides. Tile propagation is a way to compute tile parameters of the +producer/consumer of the op using the tiling parameters of the op itself. There +is already a +[library](https://github.com/openxla/xla/blob/main/xla/service/gpu/triton_tiling_propagation.h) +that does it for softmax and dot. Tile propagation can be made more generic and +robust if it is expressed via indexing maps. + +## Function and Domain + +The indexing map is a function $\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})$ +that maps a multi-index $\boldsymbol{d}$ of a tensor $A$ to elements/ranges of +tensor $B$. The parameter $\boldsymbol{s}$ refers to the ranges of indices of +the dimensions that are present in tensor $B$, but not in tensor $A$​. + +For example, if we have a reduction from `tensor<2x4x8x16xf32>` to +`tensor<4x8xf32>`, then the indexing map from the 2D output to the 4D input is +$(d_0, d_1) \mapsto (s_0, d_0, d_1, s_1)$, where $d_i$ are the dimension +parameters that correspond to the indices of the output tensor. Parameters $s_j$ +encode multiple values, i.e. to compute a $(d_0, d_1)$ element of the output, we +need $(s_0, d_0, d_1, s_1)$ elements of the input, where $s_0 \in [0, 2)$ and +$s_1 \in [0, 16)$. + +This mapping can be constructed from the attributes of HLO instructions or the +mappings of unfused instructions can be composed to get indexing for a fusion. +The mapping also has a domain, which specifies for what elements of the tensor +the mapping exists. + +$$ +\begin{eqnarray} +\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ +\boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ +\boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ +\boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, + \boldsymbol{s}) \leq \boldsymbol{ub}_g +\end{eqnarray} +$$ + +Since we want to minimize recomputation, we need a library for symbolic +computations. XLA already depends on MLIR, so we use +[mlir::AffineMap](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/AffineMap.h) +instead of writing a symbolic arithmetic library. + +A typical `AffineMap` looks like + +``` +(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50) +``` + +`AffineMap` conveniently has two types of parameters: *dimensions* and *symbols* +that we can use for $\boldsymbol d$ and $\boldsymbol s$ respectively. +`AffineMap` does not contain any metadata about ranges of the dimensions, so we +have to provide this data ourselves. + +```c++ +struct Range { + int64_t lower_bound; + int64_t upper_bound; +}; + +struct IndexingMap { + mlir::AffineMap affine_map; + std::vector dim_ranges; + std::vector symbol_ranges; + llvm::DenseMap expr_ranges; +}; + +``` + +`dim_ranges` encodes the **inclusive** box constraints for the dimension +parameters $\boldsymbol{d}$ of the indexing map, which usually coincide with the +shape of the output tensor for ops like transpose, reduce, elementwise, dot, but +there are some exceptions like +[HloConcatenateInstruction](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate). + +`symbol_ranges` encode possible values that $\boldsymbol {s}$ parameters can +take. + +Let's study-by-example to understand what's all of the above actually means. + +## Indexing Maps for Unfused Ops + +### Elementwise + +For elementwise ops the indexing map is an identity. + +```c++ + p0 = f32[10, 20] parameter(0) + p1 = f32[10, 20] parameter(1) + add = f32[10, 20] add(p0, p1) +``` + +The output to input maps: + +- output -> input_0: $(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in + [0,9] \times [0, 19]$, i.e. $\boldsymbol{d} \in {\rm Dom}(output)$ +- output -> input_1: $(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in + {\rm Dom} (output)$ + +The input to output maps + +- input_i -> output: $(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in + {\rm Dom}(input)$ + +### [Broadcast](https://openxla.org/xla/operation_semantics#broadcastindim) + +Broadcasting means that some of the dimensions will be removed when we map +output to input and added when we map input to output. + +```c+ +p0 = f32[20] parameter(0) +bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} +``` + +The output to input map: + +- output -> input: $(d_0, d_1, d_2) \mapsto (d_1)$ for $\boldsymbol{d} \in + {\rm Dom}(output)$ + +The input to output map + +- input -> output: $(d_0) \mapsto (s_0, d_0, s_1)$ for $\boldsymbol{d} \in + {\rm Dom}(input)$ and $\boldsymbol{s} \in [0, 9] \times [0, 29]$. + +Note that now we have $\boldsymbol s$ on the right side for the input-to-output +mapping. Those are the symbols that represent ranges of values. For example, in +this particular case every element of input with index $d_0$ is mapped to a +10x1x30 slice of the output. + +### Constant and [Iota](https://openxla.org/xla/operation_semantics#iota) + +Conveniently, they do not have any input parameters, so there is nothing to +compute indexing for. + +### [Transpose](https://openxla.org/xla/operation_semantics#transpose) + +Indexing map for transpose is a permutation of input/output dimensions. + +```c+ +p0 = f32[3, 12288, 6, 128] parameter(0) +transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1} +``` + +The output to input map: + +- output -> input: $(d_0, d_1, d_2, d_3) \mapsto (d_0, d_3, d_1, d_2)$ for + $\boldsymbol{d} \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)$ for + $\boldsymbol{d} \in {\rm Dom}(input)$ + +### [Reverse](https://openxla.org/xla/operation_semantics#rev_reverse) + +Indexing map for reverse changes the reverted dimensions to $upper\_bound(d_i) - +d_i$: + +```c+ +p0 = f32[1, 17, 9, 9] parameter(0) +reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} +``` + +The output to input map: + +- output -> input: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, + d_3)$ for $\boldsymbol{d} \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, + d_3)$ for $\boldsymbol{d} \in {\rm Dom}(input)$ + +### **[(Variadic)Reduce](https://openxla.org/xla/operation_semantics#reduce)** + +Variadic reduction have several inputs and several inits, the map from output to +input adds the reduced dimensions. So, it behaves like an inverse to a broadcast +in some sense. + +```c+ +p0 = f32[256,10] parameter(0) +p0_init = f32[] constant(-inf) +p1 = s32[256,10] parameter(1) +p1_init = s32[] constant(0) +reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init), + dimensions={0}, to_apply=min +``` + +The output to input maps: + +- output -> input_j: $(d_0) \mapsto (s_0, d_0)$ for $\boldsymbol{d} \in {\rm + Dom}(output)$ and $\boldsymbol{s} \in [0, 9]$ +- output -> init_j: $(d_0) \mapsto ()$ for $\boldsymbol{d} \in {\rm + Dom}(output)$ + +The input to output maps: + +- input_i -> output_j: $(d_0, d_1) \mapsto (d_1)$ for $\boldsymbol{d} \in {\rm + Dom}(input)$ +- init_i -> output_j: $() \mapsto (s_0)$ for $\boldsymbol{s} \in [0, 9]$ + +for $i, j = 0, \ldots, INPUT\\_COUNT$. + +### [Slice](https://openxla.org/xla/operation_semantics#slice) + +Indexing from output to input for slice results in a strided indexing map which +is valid for every element of the output. Mapping from the input to output is +restricted to a strided range of the elements in the input. + +```c+ +p0 = f32[10, 20, 50] parameter(0) +slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0), + slice={[5:10:1], [3:20:7], [0:50:2]} +``` + +The output to input map: + +- output -> input: $(d_0, d_1, d_2) \mapsto (d_0 + 5, 7d_1 + 3, 2d_2)$ for + $\boldsymbol{d} \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1, d_2) \mapsto (d_0, d_1 / 7, d_2 / 2)$ for + $\boldsymbol{d} \in [5, 9] \times [3, 19] \times [0, 49]$ with strides $[1, + 7, 2]$​. + +**TBD**: input-to-output indexing + +### [Reshape](https://openxla.org/xla/operation_semantics#reshape) + +Reshapes come in different flavors. + +#### Collapse shape + +This is a "linearizing" reshape from N-D to 1D. + +```c+ +p0 = f32[4,8] parameter(0) +reshape = f32[32] reshape(p0) +``` + +The output to input map: + +- output -> input: $(d_0) \mapsto (d_0 / 8, d_0 \mod 8)$ for $\boldsymbol{d} + \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1) \mapsto (8 d_0 + d_1)$ for $\boldsymbol{d} \in + {\rm Dom}(input)$. + +#### Expand shape + +This is an inverse "collapse shape" op, it reshapes a 1D input into N-D output. + +```c+ +p0 = f32[32] parameter(0) +reshape = f32[4, 8] reshape(p0) +``` + +The output to input map: + +- output -> input: $(d_0, d_1) \mapsto (8 d_0 + d_1)$ for $\boldsymbol{d} \in + {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0) \mapsto (d_0 / 8, d_0 \mod 8)$ for $\boldsymbol{d} + \in {\rm Dom}(input)$. + +#### Generic reshape + +These are the reshape ops that cannot be represented as a single expand or +collapse shape. They can be only represented as a composition of 2 or more +expand or collapse shapes. + +##### Example 1: Linearization-delinearization. + +```c+ +p0 = f32[4,8] parameter(0) +reshape = f32[2, 4, 4] reshape(p0) +``` + +This reshape can be represented as a composition of collapse shape of +`tensor<4x8xf32>` to `tensor<32xf32>` and then a shape expansion to +`tensor<2x4x4xf32>`. + +The output to input map: + +- output -> input: $(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + d_2) / 8, 4d_1 + + d_2) \mod 8)$ + +for $\boldsymbol{d} \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1) \mapsto ((8d_0 + d_1) / 16, ((8d_0 + d_1) \mod + 16) / 4, d_1 \mod 4)$ + +for $\boldsymbol{d} \in {\rm Dom}(input)$. + +##### Example 2: Expanded and collapsed subshapes + +```c+ +p0 = f32[4, 8, 12] parameter(0) +reshape = f32[32, 3, 4] reshape(p0) +``` + +This reshape can be represented as a composition of two reshapes. The first one +collapses the outermost dimensions `tensor<4x8x12xf32>` to `tensor<32x12xf32>` +and the second one expand the innermost dimension `tensor<32x12xf32>` into +`tensor<32x3x4xf32>`. + +The output to input map: + +- output -> input: $(d_0, d_1, d_2) \mapsto (d_0 / 8, d_0 \mod 8, 4d_1 + d_2)$ + for $\boldsymbol{d} \in {\rm Dom}(output)$ + +The input to output map: + +- input -> output: $(d_0, d_1, d_2) \mapsto (8d_0 + d_1, d_2 / 4, d_2 \mod 4)$ + for $\boldsymbol{d} \in {\rm Dom}(input)$. + +### Bitcast + +A bitcast op can be represented as a +[sequence of transpose-reshape-transpose](https://github.com/openxla/xla/blob/578b6df240be94c3c84129fd83f34487efc623a5/xla/shape_util.h#L813). +Therefore, its indexing maps are just a composition of indexing maps for this +sequence. + +### [Concatenate](https://openxla.org/xla/operation_semantics#concatenate) + +Output-to-input mapping for concat is defined for all inputs, but with +non-overlapping domains, i.e. only one of the inputs will be used at a time. + +```c+ +p0 = f32[3,50] parameter(0) +p1 = f32[3,30] parameter(1) +concat = f32[3,80] concatenate(f32[3,50] p0, f32[3,30] p1), + dimensions={1} +``` + +The output to input map: + +- output -> input 1: + +$(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in [0, 2] \times [0, 49]$ + +- output -> input 2: + +$(d_0, d_1) \mapsto (d_0, d_1 - 50)$ for $\boldsymbol{d} \in [0, 2] \times [50, +79]$ + +The inputs to output map: + +- input 1 -> output: $(d_0, d_1) \mapsto (d_0, d_1)$ for $\boldsymbol{d} \in + {\rm Dom}(input_1)$. +- input 2 -> output: $(d_0, d_1) \mapsto (d_0, d_1 + 50)$ for $\boldsymbol{d} + \in {\rm Dom}(input_2)$. + +### [Dot](https://openxla.org/xla/operation_semantics#dot) + +Indexing maps for dot are very similar to the ones of reduce. + +```c+ +p0 = f32[4, 128, 256] parameter(0) +p1 = f32[4, 256, 64] parameter(1) +dot = f32[4, 128, 64] dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +``` + +The output to inputs maps: + +- output -> input_1: $(d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)$ for + $\boldsymbol{d} \in {\rm Dom}(output)$ and $\boldsymbol{s} \in [0, 255]$ +- output -> input_2: $(d_0, d_1, d_2) \mapsto (d_0, s_0, d_2)$ for + $\boldsymbol{d} \in {\rm Dom}(output)$ and $\boldsymbol{s} \in [0, 255]$ + +The inputs to output maps: + +- input_1 -> output: $(d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)$ for + $\boldsymbol{d} \in {\rm Dom}(input_1)$ and $\boldsymbol{s} \in [0, 63]$ +- input_2 -> output: $(d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)$ for + $\boldsymbol{d} \in {\rm Dom}(input_2)$ and $\boldsymbol{s} \in [0, 127]$ + +### [Pad](https://openxla.org/xla/operation_semantics#pad) + +Indexing of PadOp is inverse of SliceOp indexing. + +```c+ +p0 = f32[4, 4] parameter(0) +p1 = f32[] parameter(1) +pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 +``` + +The padding config `1_4_1x4_8_0` denotes `lowPad_highPad_interiorPad_dim_0 x lowPad_highPad_interiorPad_dim_1`. + +The output to input maps: + +- output -> input: $(d_0, d_1) \mapsto ((d_0 - 1) / 2, d_1 - 4)$ + for $\boldsymbol{d} \in [1, 7] \times [4, 7]$ and $(d_0 - 1) \mod 2 \equiv 0$ +- output -> init: $(d_0, d_1) \mapsto ()$ for $\boldsymbol{d} \in {\rm Dom}(output)$ + + +### [ReduceWindow](https://openxla.org/xla/operation_semantics#reducewindow) + +ReduceWindow in XLA also performs padding. Therefore, the indexing maps can be +computed as a composition of ReduceWindow indexing that does not do any padding +and PadOp's indexing. + + +```c+ +c_inf = f32[] constant(-inf) +p0 = f32[1024, 514] parameter(0) +reduce-window = f32[1024, 3] reduce-window(p0, c_inf), + window={size=1x512 pad=0_0x0_0}, to_apply=max +``` + +The output to input maps: + +- output -> input: $(d_0, d_1) \mapsto (d_0, d_1 + s_0)$ for $\boldsymbol{d} \in [0, 1023] \times [0, 2]$ and $\boldsymbol{s} \in [0, 511]$ +- output -> init: $(d_0, d_1) \mapsto ()$ for $\boldsymbol{d} \in {\rm Dom}(output)$ + +## Indexing Maps for Fusion + +Indexing map for fusion op is a composition of indexing maps for every op in the +cluster. It can happen that some inputs are read several times with different +access patterns. + +### One input, several indexing maps + +Here is an example for $p_0 + p_0^T$ + +```c+ +f { + p0 = f32[1000, 1000] parameter(0) + transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) +} +``` + +The output-to-input indexing maps for `p0` will be $(d_0, d_1) \mapsto (d_0, +d_1)$ and $(d_0, d_1) \mapsto (d_1, d_0)$. It means that to compute one element +of the output we might need to read the input parameter twice. + +### One input, deduplicated indexing map + +![img](./images/indexing_analysis_transposes.svg) + +There are cases when the indexing maps are actually the same, even though it is +not immediately obvious. + +```c+ +f { + p0 = f32[20, 10, 50] parameter(0) + lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2} + lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1) + lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1} + rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0} + rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1) + rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2} + ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2) +} +``` + +The output-to-input indexing map for `p0` in this case is just $(d_0, d_1, d_2) +\mapsto (d_2, d_0, d_1)$. + + +### Softmax + +![img](./images/indexing_analysis_softmax.png) + +The output-to-input indexing maps for `parameter 0` for softmax: + +- $(d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)$ +- $(d_0, d_1, d_2)[s_0] \mapsto (d_0, d_1, s_0)$ + +for $\boldsymbol{d} \in {\rm Dom}(output)$ and $\boldsymbol{s} \in [0, 124]$ +refers to the inner-most dimension of the input. + +## Indexing Map Simplifier + +The default simplifier for `mlir::AffineMap` upstream cannot make any +assumptions about the ranges of dimensions/symbols. Therefore, it cannot +simplify expressions with `mod` and `div`efficiently. + +We can leverage the knowledge about lower and upper bounds of the +sub-expressions in the affine maps to simplify them even more. + +The simplifier can rewrite the following expressions. + +1. $(d_0, d_1) \mapsto (d_0 + d1 / 16, d1 \mod 16)$ for $\boldsymbol{d} \in [0, + 6] \times [0, 14]$ becomes $(d_0, d_1) \mapsto (d_0, d_1)$ +2. $(d_0, d_1, d_2) \mapsto ((100d_0 + 10d_1 + d_2) /100, ((100d_0 + 10d_1 + + d_2) \mod 100) / 10, d_2 \mod 10)$ for $d_i \in [0, 9]$ becomes $(d_0, d_1, + d_2) \mapsto (d_0, d_1, d_2)$. +3. $(d_0, d_1, d_2) \mapsto ((16d_0 + 4d_1 + d_2) /8, (16d_0 + 4d_1 + d_2) \mod + 8)$ for $d_i \in [0, 9]$ becomes $(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + + d_2) /8,(4d_1 + d_2) \mod 8)$. +4. $(d_0, d_1) \mapsto (-(-11d_0 - d_1 + 109) / 11 + 9)$ for $\boldsymbol{d} + \in [0, 9] \times [0, 10]$ becomes $(d_0, d_1) \mapsto (d_0)$. + +Indexing map simplifier allows us to understand that some of the chained +reshapes in HLO cancel each other. + +```c+ +p0 = f32[10, 10, 10] parameter(0) +reshape1 = f32[50, 20] reshape(p0) +reshape2 = f32[10, 10, 10] reshape(reshape1) +``` + +After the composition of indexing maps and their simplification we will get + +$(d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)$. + +Indexing map simplification also simplifies the constraints. + +1. Constraints of type +`lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound` are +rewritten as `updated_lower_bound <= affine_expr <= updated_upped_bound`. +2. Constraints that are always satisfied, e.g. $d_0 + s_0 in [0, 20]$ +for $d_0 \in [0, 5]$ and $s_0 \in [1, 3]$ are eliminated. +3. Affine expressions in the constraints are optimized as the indexing affine +map above. diff --git a/docs/lsp.md b/docs/lsp.md new file mode 100644 index 0000000000000..12dd0f1724782 --- /dev/null +++ b/docs/lsp.md @@ -0,0 +1,20 @@ +# Setting up LSP with clangd + +## Background + +Editors such as Emacs, Vim, or VS Code support features like code navigation, +code completion, inline compiler error messages, and others, through +[LSP](https://en.wikipedia.org/wiki/Language_Server_Protocol), the Language +Server Protocol. A common language server with LSP support is +[clangd](https://clangd.llvm.org), which relies on the presence of +`compile_commands.json`, a JSON file with a record of the compile commands for +each file in a project. + +## How do I generate `compile_commands.json` for XLA source code? + +Use the +[build_tools/lint/generate_compile_commands.py](https://github.com/openxla/xla/blob/main/build_tools/lint/generate_compile_commands.py) +script. The following invocation from XLA repo root generates a +`compile_commands.json` file in place: `bash bazel aquery "mnemonic(CppCompile, +//xla/...)" --output=jsonproto | \ python3 +build_tools/lint/generate_compile_commands.py` diff --git a/docs/persisted_autotuning.md b/docs/persisted_autotuning.md new file mode 100644 index 0000000000000..5d1f01ab50144 --- /dev/null +++ b/docs/persisted_autotuning.md @@ -0,0 +1,76 @@ +# Persisted autotuning (GPU only) + +We use OpenAI Triton for generating some of the GPU kernels. Triton allows +generating fast GPU kernels for certain fusions, but we have to tune some +parameters for each such fusion. + +This can take a long time if there are many fusions, so we provide a way to load +those autotuning results, while still running the other compilation steps +normally. Autotuning caches are still useful if we make a few changes: the +fusions that are present in the cache will use the cache, and the other ones +will be autotuned normally. + +The autotuning results can be dumped/loaded using these parameters: + +``` +--xla_gpu_dump_autotune_results_to= +--xla_gpu_load_autotune_results_from= +``` + +If we specify a .txt or .textproto file, then the cache will be dumped in +textproto format, otherwise in binary protobuf format. + +## In tests + +Persisted autotuning can also be used in tests. It is recommended to use it if +the tests are very big, especially if the performance of the test environment is +limited. + +It only works well if the autotune cache contains results generated on the same +type of GPU where the tests are being run. + +### Making a test use persisted autotuning + +For now let's assume that the test in question always uses the same GPU type. + +1. We have to export the autotune results from the test, for example by + specifying these parameters to the test command: + + ``` + --test_env=XLA_FLAGS=--xla_gpu_dump_autotune_results_to=TEST_UNDECLARED_OUTPUTS_DIR/autotune_cache.textproto + --test_sharding_strategy=disabled + ``` + + Sharding must be disabled to correctly get a single autotune cache for all + tests. + +2. Then we have to upload that cache to our code repository. + +3. Then we have to add the cache to the data dependencies of our test target, + and load it using an environment variable. + + ``` + data = ["test_autotune_cache.textproto"], + env = {"XLA_FLAGS": "--xla_gpu_load_autotune_results_from=" + + "$(execpath test_autotune_cache.textproto)"}, + ``` + + (It is OK to use sharding in tests that load autotune results.) + +Please also see the example tests in +[xla/service/gpu/tests/BUILD](https://github.com/openxla/xla/blob/main/xla/service/gpu/tests/BUILD): + +- load_autotune_results_using_execpath_test +- load_autotune_results_from_test_workspace_test +- dump_autotune_results_to_test_outputs_test + +### Cache obsolescence + +If many changes are made to a model, it is possible that the cache will no +longer contain all fusions, so the test will become slower. In this case we +would have to regenerate the autotuning cache. + +If we start using a new type of GPU for running the tests, the same applies. + +The cache may also become obsolete if the XLA compiler evolves and generates +different fusions. diff --git a/docs/tf2xla/index.md b/docs/tf2xla/index.md new file mode 100644 index 0000000000000..edde1f7de6237 --- /dev/null +++ b/docs/tf2xla/index.md @@ -0,0 +1,239 @@ +# XLA: Optimizing Compiler for Machine Learning + +[OpenXLA](https://openxla.org) is a domain-specific compiler for linear +algebra that can accelerate TensorFlow models with potentially no source code +changes. + +## Introduction + +When a TensorFlow program is run, all of the operations are executed +individually by the TensorFlow executor. Each TensorFlow operation has a +precompiled GPU kernel implementation that the executor dispatches to. + +XLA provides an alternative mode of running models: it compiles the TensorFlow +graph into a sequence of computation kernels generated specifically for the +given model. Because these kernels are unique to the model, they can exploit +model-specific information for optimization. For example, let's look at an +optimization XLA does in the context of a simple TensorFlow computation: + +``` +def model_fn(x, y, z): + return tf.reduce_sum(x + y * z) +``` + +Run without XLA, the graph launches three kernels: one for the multiplication, +one for the addition and one for the reduction. However, XLA can optimize the +graph so that it computes the result in a single kernel launch. It does this by +"fusing" the addition, multiplication and reduction into a single GPU kernel. +Moreover, this fused operation does not write out the intermediate values +produced by `y*z` and `x+y*z` to memory; instead it "streams" the results of +these intermediate computations directly to their users while keeping them +entirely in GPU registers. Fusion is XLA's single most important optimization. +Memory bandwidth is typically the scarcest resource on hardware accelerators, so +removing memory operations is one of the best ways to improve performance. + +## Enable XLA for TensorFlow models + +### Explicit compilation with `tf.function(jit_compile=True)` + +Explicit compilation API offers a fine-grained control for choosing which +functions should be compiled. For example, the following TensorFlow function +which performs the MNIST training is compiled with XLA: + +``` +@tf.function(jit_compile=True) +def train_mnist(images, labels): + images, labels = cast(images, labels) + + with tf.GradientTape() as tape: + predicted_labels = layer(images) + loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=predicted_labels, labels=labels + )) + layer_variables = layer.trainable_variables + grads = tape.gradient(loss, layer_variables) + optimizer.apply_gradients(zip(grads, layer_variables)) +``` + +The `jit_compile` API has _must-compile_ semantics: either the entire +function is compiled with XLA, or an `errors.InvalidArgumentError` exception is +thrown. XLA can not currently compile functions where dimensions are not +_inferrable_: that is, if it's not possible to infer the dimensions of all +tensors without running the entire computation. For example, the following +function will not compile: + +``` +@tf.function +def not_compilable(x): + return tf.unique(x) +``` + +Shapes can vary across the runs though: + +``` +@tf.function(jit_compile=True) +def recompiled_on_launch(a, b): + return a + b + +recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10])) +recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100])) +``` + +Note: Nesting behavior: the function will be compiled if at least one function +in its call stack has `jit_compile=True`. + +See the [tutorial colab](./tutorials/jit_compile.ipynb) for a more detailed +usage example, and a +[tutorial video](https://www.youtube.com/watch?v=cPAD9vLKE0c) on +`jit_compile=True` usage. + +### Usage with Keras + +For Keras models, `jit_compile=True` can be set as an argument to +[`model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile): + +``` +model.compile(optimizer="adam", jit_compile=True) +``` + +### Usage with distributed strategy + +XLA:GPU can be used with TF distributed strategy +([`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy) +or +[`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)) +by annotating step function with `jit_compile=True`: + +``` +@tf.function(jit_compile=True) +def step_fn(): + t = tf.ones(shape=[100], dtype=tf.float32) + ctx = tf.distribute.get_replica_context() + return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t) + +@tf.function +def run_fn(): + return strategy.run(step_fn) +``` + +### Auto-clustering + +A simple way to start using XLA in TensorFlow models without any changes is to +enable _auto-clustering_, which automatically finds _clusters_ (connected +subgraphs) within the TensorFlow functions which can be compiled and executed +using XLA. Auto-clustering on GPU can be enabled by setting the `TF_XLA_FLAGS` +environment variable: + +Note: In TF2, only the code inside `tf.function` will be clustered. + +``` +$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program +``` + +Auto-clustering is currently optimized for GPU workloads, but it can also be +enabled on CPU by additionally using the flag `--tf_xla_cpu_global_jit`: + +``` +$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program +``` + +Note: Auto-clustering support on CPU and on multi-GPU environments is +experimental. + +For a detailed usage example see the +[auto-clustering tutorial colab](./tutorials/autoclustering_xla.ipynb). + +### AOT (Ahead-of-time) compilation for CPU with `tfcompile` + +You can also use a standalone [`tfcompile`](./tfcompile.md) tool, which converts +TensorFlow graph into executable code (for x86-64 CPU only). + +## Inspect compiled programs + +XLA provides introspection facilities which let you inspect the generated +programs. To dump the generated programs, use the environment variable +`XLA_FLAGS`: + +``` +$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program +``` + +After the dumping is performed, you can find the following files in +`/tmp/generated`: + +- `module_XXXX.*_optimizations.txt` Generated + [XLA programs](./operation_semantics.md), one per each compiled cluster. + Attaching those when submitting XLA bug reports is extremely helpful! + +- `module_XXXX.ir-*.ll` Generated files in + [LLVM](https://llvm.org/docs/LangRef.html) intermediate representation, with + [NVPTX](https://llvm.org/docs/NVPTXUsage.html) intrinsics. + +- `module_XXXX.ptx` Generated + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + files. + +You can also dump the graph visualizing the embedding of XLA clusters inside of +the TensorFlow graph with: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug" +``` + +## Reproducible bug reports + +A bug report is much easier to reproduce if it includes dumps for the generated +XLA programs and the used auto-clustering embedding. +To generate them for a TensorFlow program running with auto-clustering, launch: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \ + TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \ + XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \ + my/tensorflow/program" +``` + +When filing bugs, attach the contents of the `/tmp/generated` directory +(referenced above). + +If possible, try to isolate +a bug to a single XLA program by using the +[`run_hlo_module`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc) +and iteratively running it on generated programs. + +## Further reading + +- [OpenXLA Documentation](https://openxla.org) OpenXLA Documentation +- [Known Issues](./known_issues.md) List of known issues with XLA+TF +- [XLA - TensorFlow, Compiled](https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html): + Read on Google Developers Blog +- Check out the + [XLA source](https://github.com/openxla/xla) + on Github! + +## XLA Frontends + +Apart from TensorFlow, XLA programs can be generated by: + +- [JAX](https://github.com/google/jax): Composable transformations of + Python+NumPy programs +- [Julia](https://github.com/JuliaTPU/XLA.jl): The Julia language for + scientific computing +- [PyTorch](https://github.com/pytorch/xla): PyTorch framework +- [Nx](https://github.com/elixir-nx/nx): Numerical computing library for the + Elixir programming language + +## Talks + +### Using XLA from TF using `jit_compile=True` + + + +### XLA Overview + + diff --git a/docs/tf2xla/tfcompile.md b/docs/tf2xla/tfcompile.md new file mode 100644 index 0000000000000..5d60a4e90a9ac --- /dev/null +++ b/docs/tf2xla/tfcompile.md @@ -0,0 +1,279 @@ +# Using AOT compilation + +## What is tfcompile? + +`tfcompile` is a standalone tool that ahead-of-time (AOT) compiles TensorFlow +graphs into executable code. It can reduce total binary size, and also avoid +some runtime overheads. A typical use-case of `tfcompile` is to compile an +inference graph into executable code for mobile devices. + +The TensorFlow graph is normally executed by the TensorFlow runtime. This incurs +some runtime overhead for execution of each node in the graph. This also leads +to a larger total binary size, since the code for the TensorFlow runtime needs +to be available, in addition to the graph itself. The executable code produced +by `tfcompile` does not use the TensorFlow runtime, and only has dependencies on +kernels that are actually used in the computation. + +The compiler is built on top of the XLA framework. The code bridging TensorFlow +to the XLA framework resides under +[tensorflow/compiler](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/). + +## What does tfcompile do? + +`tfcompile` takes a subgraph, identified by the TensorFlow concepts of +feeds and fetches, and generates a function that implements that subgraph. +The `feeds` are the input arguments for the function, and the `fetches` are the +output arguments for the function. All inputs must be fully specified by the +feeds; the resulting pruned subgraph cannot contain Placeholder or Variable +nodes. It is common to specify all Placeholders and Variables as feeds, which +ensures the resulting subgraph no longer contains these nodes. The generated +function is packaged as a `cc_library`, with a header file exporting the +function signature, and an object file containing the implementation. The user +writes code to invoke the generated function as appropriate. + +## Using tfcompile + +This section details high level steps for generating an executable binary with +`tfcompile` from a TensorFlow subgraph. The steps are: + +* Step 1: Configure the subgraph to compile +* Step 2: Use the `tf_library` build macro to compile the subgraph +* Step 3: Write code to invoke the subgraph +* Step 4: Create the final binary + +### Step 1: Configure the subgraph to compile + +Identify the feeds and fetches that correspond to the input and output +arguments for the generated function. Then configure the `feeds` and `fetches` +in a [`tensorflow.tf2xla.Config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/tf2xla.proto) +proto. + +```textproto +# Each feed is a positional input argument for the generated function. The order +# of each entry matches the order of each input argument. Here “x_hold” and “y_hold” +# refer to the names of placeholder nodes defined in the graph. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 2 } + dim { size: 3 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 3 } + dim { size: 2 } + } +} + +# Each fetch is a positional output argument for the generated function. The order +# of each entry matches the order of each output argument. Here “x_y_prod” +# refers to the name of a matmul node defined in the graph. +fetch { + id { node_name: "x_y_prod" } +} +``` + +### Step 2: Use tf_library build macro to compile the subgraph + +This step converts the graph into a `cc_library` using the `tf_library` build +macro. The `cc_library` consists of an object file containing the code generated +from the graph, along with a header file that gives access to the generated +code. `tf_library` utilizes `tfcompile` to compile the TensorFlow graph into +executable code. + +```build +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# Use the tf_library macro to compile your graph into executable code. +tf_library( + # name is used to generate the following underlying build rules: + # : cc_library packaging the generated header and object files + # _test : cc_test containing a simple test and benchmark + # _benchmark : cc_binary containing a stand-alone benchmark with minimal deps; + # can be run on a mobile device + name = "test_graph_tfmatmul", + # cpp_class specifies the name of the generated C++ class, with namespaces allowed. + # The class will be generated in the given namespace(s), or if no namespaces are + # given, within the global namespace. + cpp_class = "foo::bar::MatMulComp", + # graph is the input GraphDef proto, by default expected in binary format. To + # use the text format instead, just use the ‘.pbtxt’ suffix. A subgraph will be + # created from this input graph, with feeds as inputs and fetches as outputs. + # No Placeholder or Variable ops may exist in this subgraph. + graph = "test_graph_tfmatmul.pb", + # config is the input Config proto, by default expected in binary format. To + # use the text format instead, use the ‘.pbtxt’ suffix. This is where the + # feeds and fetches were specified above, in the previous step. + config = "test_graph_tfmatmul.config.pbtxt", +) +``` + +> To generate the GraphDef proto (test_graph_tfmatmul.pb) for this example, run +> [make_test_graphs.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/make_test_graphs.py) +> and specify the output location with the --out_dir flag. + +Typical graphs contain [`Variables`](https://www.tensorflow.org/guide/variables) +representing the weights that are learned via training, but `tfcompile` cannot +compile a subgraph that contain `Variables`. The +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) +tool converts variables into constants, using values stored in a checkpoint +file. As a convenience, the `tf_library` macro supports the `freeze_checkpoint` +argument, which runs the tool. For more examples see +[tensorflow/compiler/aot/tests/BUILD](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/BUILD). + +> Constants that show up in the compiled subgraph are compiled directly into the +> generated code. To pass the constants into the generated function, rather than +> having them compiled-in, simply pass them in as feeds. + +For details on the `tf_library` build macro, see +[tfcompile.bzl](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile.bzl). + +For details on the underlying `tfcompile` tool, see +[tfcompile_main.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile_main.cc). + +### Step 3: Write code to invoke the subgraph + +This step uses the header file (`test_graph_tfmatmul.h`) generated by the +`tf_library` build macro in the previous step to invoke the generated code. The +header file is located in the `bazel-bin` directory corresponding to the +build package, and is named based on the name attribute set on the `tf_library` +build macro. For example, the header generated for `test_graph_tfmatmul` would +be `test_graph_tfmatmul.h`. Below is an abbreviated version of what is +generated. The generated file, in `bazel-bin`, contains additional useful +comments. + +```c++ +namespace foo { +namespace bar { + +// MatMulComp represents a computation previously specified in a +// TensorFlow graph, now compiled into executable code. +class MatMulComp { + public: + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers + RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers + }; + + MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + ~MatMulComp(); + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run(); + + // Arg methods for managing input buffers. Buffers are in row-major order. + // There is a set of methods for each positional argument. + void** args(); + + void set_arg0_data(float* data); + float* arg0_data(); + float& arg0(size_t dim0, size_t dim1); + + void set_arg1_data(float* data); + float* arg1_data(); + float& arg1(size_t dim0, size_t dim1); + + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. There is a set of methods + // for each positional result. + void** results(); + + + float* result0_data(); + float& result0(size_t dim0, size_t dim1); +}; + +} // end namespace bar +} // end namespace foo +``` + +The generated C++ class is called `MatMulComp` in the `foo::bar` namespace, +because that was the `cpp_class` specified in the `tf_library` macro. All +generated classes have a similar API, with the only difference being the methods +to handle arg and result buffers. Those methods differ based on the number and +types of the buffers, which were specified by the `feed` and `fetch` arguments +to the `tf_library` macro. + +There are three types of buffers managed within the generated class: `args` +representing the inputs, `results` representing the outputs, and `temps` +representing temporary buffers used internally to perform the computation. By +default, each instance of the generated class allocates and manages all of these +buffers for you. The `AllocMode` constructor argument may be used to change this +behavior. All buffers are aligned to 64-byte boundaries. + +The generated C++ class is just a wrapper around the low-level code generated by +XLA. + +Example of invoking the generated function based on +[`tfcompile_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/tfcompile_test.cc): + +```c++ +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated + +int main(int argc, char** argv) { + Eigen::ThreadPool tp(2); // Size the thread pool as appropriate. + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + + foo::bar::MatMulComp matmul; + matmul.set_thread_pool(&device); + + // Set up args and run the computation. + const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::copy(args + 0, args + 6, matmul.arg0_data()); + std::copy(args + 6, args + 12, matmul.arg1_data()); + matmul.Run(); + + // Check result + if (matmul.result0(0, 0) == 58) { + std::cout << "Success" << std::endl; + } else { + std::cout << "Failed. Expected value 58 at 0,0. Got:" + << matmul.result0(0, 0) << std::endl; + } + + return 0; +} +``` + +### Step 4: Create the final binary + +This step combines the library generated by `tf_library` in step 2 and the code +written in step 3 to create a final binary. Below is an example `bazel` BUILD +file. + +```build +# Example of linking your binary +# Also see //tensorflow/compiler/aot/tests/BUILD +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# The same tf_library call from step 2 above. +tf_library( + name = "test_graph_tfmatmul", + ... +) + +# The executable code generated by tf_library can then be linked into your code. +cc_binary( + name = "my_binary", + srcs = [ + "my_code.cc", # include test_graph_tfmatmul.h to access the generated header + ], + deps = [ + ":test_graph_tfmatmul", # link in the generated object file + "//third_party/eigen3", + ], + linkopts = [ + "-lpthread", + ] +) +``` diff --git a/docs/tf2xla/tutorials/autoclustering_xla.ipynb b/docs/tf2xla/tutorials/autoclustering_xla.ipynb new file mode 100644 index 0000000000000..88f94c2bbc3f8 --- /dev/null +++ b/docs/tf2xla/tutorials/autoclustering_xla.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asd4sdga7g" + }, + "source": [ + "# Classifying CIFAR-10 with XLA\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/autoclustering_xla\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mz65veHXsmnS" + }, + "source": [ + "This tutorial trains a TensorFlow model to classify the [CIFAR-10](https://en.wikipedia.org/wiki/CIFAR-10) dataset, and we compile it using XLA.\n", + "\n", + "You will load and normalize the dataset using the [TensorFlow Datasets (TFDS)](https://tensorflow.org/datasets) API. First, install/upgrade TensorFlow and TFDS:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R4xtYyOf78e3" + }, + "outputs": [], + "source": [ + "!pip install -U -q tensorflow tensorflow_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PH2HbLW65tmo" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7vm2QsMisCxI" + }, + "outputs": [], + "source": [ + "# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb\n", + "assert(tf.test.gpu_device_name())\n", + "\n", + "tf.keras.backend.clear_session()\n", + "tf.config.optimizer.set_jit(False) # Start with XLA disabled.\n", + "\n", + "def load_data():\n", + " result = tfds.load('cifar10', batch_size = -1)\n", + " (x_train, y_train) = result['train']['image'],result['train']['label']\n", + " (x_test, y_test) = result['test']['image'],result['test']['label']\n", + " \n", + " x_train = x_train.numpy().astype('float32') / 256\n", + " x_test = x_test.numpy().astype('float32') / 256\n", + "\n", + " # Convert class vectors to binary class matrices.\n", + " y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)\n", + " y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)\n", + " return ((x_train, y_train), (x_test, y_test))\n", + "\n", + "(x_train, y_train), (x_test, y_test) = load_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MgNM2tbgtScx" + }, + "source": [ + "We define the model, adapted from the Keras [CIFAR-10 example](https://keras.io/examples/cifar10_cnn/):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3ZRQSwoRsKM_" + }, + "outputs": [], + "source": [ + "def generate_model():\n", + " return tf.keras.models.Sequential([\n", + " tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(32, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Conv2D(64, (3, 3), padding='same'),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(64, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(512),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Dropout(0.5),\n", + " tf.keras.layers.Dense(10),\n", + " tf.keras.layers.Activation('softmax')\n", + " ])\n", + "\n", + "model = generate_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-M4GtGDZtb8a" + }, + "source": [ + "We train the model using the\n", + "[RMSprop](https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer)\n", + "optimizer:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UKCmrhF0tiMa" + }, + "outputs": [], + "source": [ + "def compile_model(model):\n", + " opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)\n", + " model.compile(loss='categorical_crossentropy',\n", + " optimizer=opt,\n", + " metrics=['accuracy'])\n", + " return model\n", + "\n", + "model = compile_model(model)\n", + "\n", + "def train_model(model, x_train, y_train, x_test, y_test, epochs=25):\n", + " model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)\n", + "\n", + "def warmup(model, x_train, y_train, x_test, y_test):\n", + " # Warm up the JIT, we do not wish to measure the compilation time.\n", + " initial_weights = model.get_weights()\n", + " train_model(model, x_train, y_train, x_test, y_test, epochs=1)\n", + " model.set_weights(initial_weights)\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)\n", + "\n", + "scores = model.evaluate(x_test, y_test, verbose=1)\n", + "print('Test loss:', scores[0])\n", + "print('Test accuracy:', scores[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SLpfQ0StRgsu" + }, + "source": [ + "Now let's train the model again, using the XLA compiler.\n", + "To enable the compiler in the middle of the application, we need to reset the Keras session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jxU-Tzy4SX7p" + }, + "outputs": [], + "source": [ + "# We need to clear the session to enable JIT in the middle of the program.\n", + "tf.keras.backend.clear_session()\n", + "tf.config.optimizer.set_jit(True) # Enable XLA.\n", + "model = compile_model(generate_model())\n", + "(x_train, y_train), (x_test, y_test) = load_data()\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iWHz6P1se92F" + }, + "source": [ + "On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU the speed up is ~1.17x." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "CIFAR-10 with XLA.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tf2xla/tutorials/jit_compile.ipynb b/docs/tf2xla/tutorials/jit_compile.ipynb new file mode 100644 index 0000000000000..b9967f4e94f4d --- /dev/null +++ b/docs/tf2xla/tutorials/jit_compile.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e1oSi4lHFt3z" + }, + "source": [ + "# Use XLA with tf.function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sDy5lSBd4BDE" + }, + "source": [ + "This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.\n", + "\n", + "First, load TensorFlow and enable eager execution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "45kUPj5ZFrRa" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GZVNiRmTDV-5" + }, + "source": [ + "Then define some necessary constants and prepare the MNIST dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f37TSEGvGX4_" + }, + "outputs": [], + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000\n", + "\n", + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "\n", + "# Casting from raw data to the required datatypes.\n", + "def cast(images, labels):\n", + " images = tf.cast(\n", + " tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)\n", + " labels = tf.cast(labels, tf.int64)\n", + " return (images, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lv7I-u_82v1S" + }, + "source": [ + "Finally, define the model and the optimizer. The model uses a single dense layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7O2NcEfG206Q" + }, + "outputs": [], + "source": [ + "layer = tf.keras.layers.Dense(NUM_CLASSES)\n", + "optimizer = tf.keras.optimizers.Adam()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x_ZehpZP-SfS" + }, + "source": [ + "# Define the training function\n", + "\n", + "In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `jit_compile=True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZbhJl_WvGa3g" + }, + "outputs": [], + "source": [ + "@tf.function(jit_compile=True)\n", + "def train_mnist(images, labels):\n", + " images, labels = cast(images, labels)\n", + "\n", + " with tf.GradientTape() as tape:\n", + " predicted_labels = layer(images)\n", + " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=predicted_labels, labels=labels\n", + " ))\n", + " layer_variables = layer.trainable_variables\n", + " grads = tape.gradient(loss, layer_variables)\n", + " optimizer.apply_gradients(zip(grads, layer_variables))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EZD1m_n1DxAF" + }, + "source": [ + "# Train and test the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gukC2Hol3sFZ" + }, + "source": [ + "Once you have defined the training function, define the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qe28bAHNHUG2" + }, + "outputs": [], + "source": [ + "for images, labels in train_ds:\n", + " if optimizer.iterations \u003e TRAIN_STEPS:\n", + " break\n", + " train_mnist(images, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qgsKmz3n2UiW" + }, + "source": [ + "And, finally, check the accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_GxF6jTRHVuA" + }, + "outputs": [], + "source": [ + "images, labels = cast(test[0], test[1])\n", + "predicted_labels = layer(images)\n", + "correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + "print(\"Prediction accuracy after training: %s\" % accuracy)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PXoOjJnuZRaV" + }, + "source": [ + "Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for \"stage\" are `optimized_hlo` for HLO after optimizations and `optimized_hlo_dot` for a Graphviz graph):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_a8GsNLVaLSQ" + }, + "outputs": [], + "source": [ + "print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "jit_compile.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/tools.md b/docs/tools.md new file mode 100644 index 0000000000000..fe7d8ee5a6a86 --- /dev/null +++ b/docs/tools.md @@ -0,0 +1,155 @@ +# Using XLA tooling + +The XLA development workflow is usually centered around +[HLO](./operation_semantics) IR, which represents isolated functional +computation given to the compiler. XLA comes with multiple command line tools +(described below) which consume HLO and either run it, or provide an +intermediate compilation stage. Using such tools is invaluable for a fast +`compile->modify->run` iteration cycle, as HLO is both visualizable and +hackable, and iteratively changing and running it is often the fastest way to +understand and to fix an XLA performance or behavior. + +The easiest way to obtain the HLO for a program being compiled with XLA is +usually to use the `XLA_FLAGS` environment variable: + +``` +XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point +``` + +which stores all before-optimization HLO files in the folder specified, along +with many other useful artifacts. + +## Running HLO snippets: `run_hlo_module` + +The tool `run_hlo_module` operates on pre-optimization HLO, and by default +bundles compilation, running and comparison with the reference interpreter +implementation. For example, the usual invocation to run an input file +`computation.hlo` on an NVIDIA GPU and to check it for correctness is: + +``` +run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo +``` + +As with all the tools, `--help` can be used to obtain the full list of options. + +## Running HLO snippets with SPMD support: `multihost_hlo_runner` + +Multihost HLO runner is a very similar tool, with the caveat that it supports +SPMD, including cross host communication. A typical invocation looks like: + +``` +hlo_runner_main --device_type=gpu --use_spmd_partitioning=true --num_partitions=4 --num_replicas=1 --hlo_file=computation.hlo +``` + +## Running passes/stages of HLO compilation: `hlo-opt` + +When debugging or understanding the workings of the compiler, it is often useful +to get the expansion for a particular hardware at a particular point in the +pipeline (be it HLO, optimized HLO, TritonIR or LLVM), for a given (Stable) HLO +input. + +`hlo-opt` supports multiple output stages: be it PTX, HLO after optimizations, +LLVM IR before optimizations, or TritonIR. The exact set of stages supported +depends on the platform (as e.g. PTX is NVIDIA-specific), and can be seen using +the --list-stages command: + +``` +$ hlo-opt --platform=CUDA --list-stages +hlo +llvm +ptx +``` + +After selecting a stage, the user can write the result of the conversion for a +given platform to a given stream: + +``` +$ hlo-opt myinput.hlo --platform=CUDA --stage=llvm +``` + +which would print the dump to stdout (or to a given file if `-o` was specified). + +### Deviceless Usage + +Access to a GPU is not needed for most of the compilation, and by specifying a +GPU spec on the command line we can get e.g. PTX output without access to an +accelerator: + +``` +$ hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=(pwd)/tools/data/gpu_specs/a100_80.txtpb input.hlo +``` + +Note: For the above invocation to work, the user would usually either need to +disable autotuning with `--xla_gpu_autotune_level=0` or load a pre-existing +autotuning results with `--xla_gpu_load_autotune_results_from=` +(obtained with `--xla_gpu_dump_autotune_results_to=`). + +Specs for popular GPUs are shipped with the compiler, and the provided file is +string serialization of `device_description.proto`: + +``` +gpu_device_info { + cuda_compute_capability { + major: 8 + minor: 0 + } + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 127152 + shared_memory_per_core: 65536 + threads_per_core_limit: 2048 + core_count: 6192 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 2039000000000 + l2_cache_size: 4194304 + clock_rate_ghz: 1.1105 + device_memory_size: 79050250240 +} +platform_name: "CUDA" +``` + +Deviceless compilation might run into issues if autotuning is required. Luckily, +we can also provide those on the command line: + +``` +hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo +``` + +The autotune file is text serialization of `autotune_results.proto`, with +example looking like: + +``` +version: 2 +results { + device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" + result { + run_time { + nanos: 31744 + } + triton { + block_m: 32 + block_n: 32 + block_k: 32 + split_k: 1 + num_stages: 1 + num_warps: 4 + } + } +} +``` + +The autotuning database can be serialized using +`XLA_FLAGS=--xla_gpu_dump_autotune_results_t=` + +### Running a Single Compiler Pass + +The flags from `XLA_FLAGS` are also supported, so the tool can be used to test +running a single pass: + +``` +hlo-opt --platform=CUDA --stage=hlo --xla-hlo-enable-passes-only=algebraic_simplifer input.hlo +``` diff --git a/opensource_only.files b/opensource_only.files index 9de7578a5801a..2b635d53128b3 100644 --- a/opensource_only.files +++ b/opensource_only.files @@ -1,10 +1,17 @@ -compiler/xla/glob_lit_test.bzl: compiler/xla/mlir_hlo/WORKSPACE: compiler/xla/stream_executor/build_defs.bzl: +compiler/xla/tsl/cuda/stub.bzl: +compiler/xla/tsl/mkl/BUILD: +compiler/xla/tsl/mkl/LICENSE: +compiler/xla/tsl/mkl/MKL_LICENSE: +compiler/xla/tsl/mkl/build_defs.bzl: third_party/BUILD: third_party/__init__:.py third_party/compute_library/BUILD: third_party/compute_library/build_defs.bzl: +third_party/implib_so/BUILD: +third_party/implib_so/get_symbols.py: +third_party/implib_so/make_stub.py: third_party/llvm_openmp/BUILD: third_party/llvm_openmp/cmake_vars.bzl: third_party/llvm_openmp/expand_cmake_vars:.py diff --git a/third_party/compute_library/build_defs.bzl b/third_party/compute_library/build_defs.bzl index 4e2effaa62402..d6ec1f1133dec 100644 --- a/third_party/compute_library/build_defs.bzl +++ b/third_party/compute_library/build_defs.bzl @@ -1,6 +1,6 @@ def if_enable_acl(if_true, if_false = []): return select({ - "@xla//third_party/compute_library:build_with_acl": if_true, + "@tsl//third_party/compute_library:build_with_acl": if_true, "//conditions:default": if_false, }) @@ -15,6 +15,6 @@ def acl_deps(): inclusion in the deps attribute of rules. """ return select({ - "@xla//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute"], + "@tsl//third_party/compute_library:build_with_acl": ["@compute_library//:arm_compute"], "//conditions:default": [], }) diff --git a/third_party/cudnn_frontend_header_fix.patch b/third_party/cudnn_frontend_header_fix.patch index ee37c4b14827f..70476bd3ff5d5 100644 --- a/third_party/cudnn_frontend_header_fix.patch +++ b/third_party/cudnn_frontend_header_fix.patch @@ -1,234 +1,13 @@ -diff --git a/include/cudnn_backend_base.h b/include/cudnn_backend_base.h -index 56d8bec..8ceb19c 100644 ---- a/include/cudnn_backend_base.h -+++ b/include/cudnn_backend_base.h -@@ -24,7 +24,7 @@ - - #include - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_ConvDesc.h b/include/cudnn_frontend_ConvDesc.h -index ded7e67..68341e1 100644 ---- a/include/cudnn_frontend_ConvDesc.h -+++ b/include/cudnn_frontend_ConvDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Engine.h b/include/cudnn_frontend_Engine.h -index 7e18cd7..d26f4ee 100644 ---- a/include/cudnn_frontend_Engine.h -+++ b/include/cudnn_frontend_Engine.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineConfig.h b/include/cudnn_frontend_EngineConfig.h -index ea68554..0888858 100644 ---- a/include/cudnn_frontend_EngineConfig.h -+++ b/include/cudnn_frontend_EngineConfig.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Engine.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_EngineFallbackList.h b/include/cudnn_frontend_EngineFallbackList.h -index 323106a..d90a1ea 100644 ---- a/include/cudnn_frontend_EngineFallbackList.h -+++ b/include/cudnn_frontend_EngineFallbackList.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - #include - #include "cudnn_frontend_Heuristics.h" - -diff --git a/include/cudnn_frontend_ExecutionPlan.h b/include/cudnn_frontend_ExecutionPlan.h -index e361821..88f5790 100644 ---- a/include/cudnn_frontend_ExecutionPlan.h -+++ b/include/cudnn_frontend_ExecutionPlan.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Engine.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_Filters.h b/include/cudnn_frontend_Filters.h -index aac4086..ed1f343 100644 ---- a/include/cudnn_frontend_Filters.h -+++ b/include/cudnn_frontend_Filters.h -@@ -22,7 +22,7 @@ - - #pragma once - --#include -+#include "third_party/gpus/cudnn/cudnn.h" - - namespace cudnn_frontend { - -diff --git a/include/cudnn_frontend_Heuristics.h b/include/cudnn_frontend_Heuristics.h -index 680906a..3df8924 100644 ---- a/include/cudnn_frontend_Heuristics.h -+++ b/include/cudnn_frontend_Heuristics.h -@@ -25,8 +25,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_OperationGraph.h" - #include "cudnn_frontend_EngineConfig.h" -diff --git a/include/cudnn_frontend_MatMulDesc.h b/include/cudnn_frontend_MatMulDesc.h -index e7dd8f7..7a5d443 100644 ---- a/include/cudnn_frontend_MatMulDesc.h -+++ b/include/cudnn_frontend_MatMulDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h -index fe75d5b..a43d696 100644 ---- a/include/cudnn_frontend_Operation.h -+++ b/include/cudnn_frontend_Operation.h -@@ -30,8 +30,8 @@ - #include - #include +diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h +index 0f0d5a6..802bcbb 100644 +--- a/include/cudnn_frontend.h ++++ b/include/cudnn_frontend.h +@@ -97,7 +97,7 @@ + * - Simpler samples on how to use the new API. + */ -#include --#include +#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" #include "cudnn_frontend_ConvDesc.h" - #include "cudnn_frontend_PointWiseDesc.h" -diff --git a/include/cudnn_frontend_OperationGraph.h b/include/cudnn_frontend_OperationGraph.h -index 919a190..5e31484 100644 ---- a/include/cudnn_frontend_OperationGraph.h -+++ b/include/cudnn_frontend_OperationGraph.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_Operation.h" - #include "cudnn_frontend_utils.h" -diff --git a/include/cudnn_frontend_PointWiseDesc.h b/include/cudnn_frontend_PointWiseDesc.h -index ad1f943..f320a27 100644 ---- a/include/cudnn_frontend_PointWiseDesc.h -+++ b/include/cudnn_frontend_PointWiseDesc.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_ReductionDesc.h b/include/cudnn_frontend_ReductionDesc.h -index e22a0fb..d69e8c6 100644 ---- a/include/cudnn_frontend_ReductionDesc.h -+++ b/include/cudnn_frontend_ReductionDesc.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h -index 4d2d197..1b7c24d 100644 ---- a/include/cudnn_frontend_Resample.h -+++ b/include/cudnn_frontend_Resample.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_Rng.h b/include/cudnn_frontend_Rng.h -index 0001ac6..80a623b 100644 ---- a/include/cudnn_frontend_Rng.h -+++ b/include/cudnn_frontend_Rng.h -@@ -29,8 +29,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - -diff --git a/include/cudnn_frontend_VariantPack.h b/include/cudnn_frontend_VariantPack.h -index dc68207..8b47fce 100644 ---- a/include/cudnn_frontend_VariantPack.h -+++ b/include/cudnn_frontend_VariantPack.h -@@ -30,8 +30,8 @@ - #include - #include - --#include --#include -+#include "third_party/gpus/cudnn/cudnn.h" -+#include "third_party/gpus/cudnn/cudnn_backend.h" - - #include "cudnn_frontend_utils.h" - + #include "cudnn_frontend_Heuristics.h" diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD new file mode 100644 index 0000000000000..581ffb15685a9 --- /dev/null +++ b/third_party/cutlass.BUILD @@ -0,0 +1,37 @@ +# Description: +# CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance +# matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT + +exports_files(["LICENSE.txt"]) + +filegroup( + name = "cutlass_header_files", + srcs = glob([ + "include/**", + ]), +) + +filegroup( + name = "cutlass_util_header_files", + srcs = glob([ + "tools/util/include/**", + ]), +) + +cc_library( + name = "cutlass", + hdrs = [ + ":cutlass_header_files", + ":cutlass_util_header_files", + ], + includes = [ + "include", + "tools/util/include", + ], +) diff --git a/third_party/dlpack/workspace.bzl b/third_party/dlpack/workspace.bzl index bda030ac4908e..3d7560af37205 100644 --- a/third_party/dlpack/workspace.bzl +++ b/third_party/dlpack/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "dlpack", - strip_prefix = "dlpack-9351cf542ab478499294864ff3acfdab5c8c5f3d", - sha256 = "7aca112f2809b7e9523e9b47b04a393affeca38247861951f07c42dee10180e2", - urls = tf_mirror_urls("https://github.com/dmlc/dlpack/archive/9351cf542ab478499294864ff3acfdab5c8c5f3d.tar.gz"), + strip_prefix = "dlpack-2a7e9f1256ddc48186c86dff7a00e189b47e5310", + sha256 = "044d2f5738e677c5f0f1ff9fb616a0245af67d09e42ae3514c73ba50cea0e4a5", + urls = tf_mirror_urls("https://github.com/dmlc/dlpack/archive/2a7e9f1256ddc48186c86dff7a00e189b47e5310.tar.gz"), build_file = "//third_party/dlpack:dlpack.BUILD", ) diff --git a/third_party/flash_attn/BUILD b/third_party/flash_attn/BUILD new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/third_party/flash_attn/flash_attn.BUILD b/third_party/flash_attn/flash_attn.BUILD new file mode 100644 index 0000000000000..5c994b9b42e3a --- /dev/null +++ b/third_party/flash_attn/flash_attn.BUILD @@ -0,0 +1,95 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda_is_configured") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cuda_library( + name = "flash_attn", + hdrs = if_cuda_is_configured([ + "csrc/flash_attn/src/alibi.h", + "csrc/flash_attn/src/block_info.h", + "csrc/flash_attn/src/dropout.h", + "csrc/flash_attn/src/flash_bwd_kernel.h", + "csrc/flash_attn/src/flash_bwd_launch_template.h", + "csrc/flash_attn/src/flash_bwd_preprocess_kernel.h", + "csrc/flash_attn/src/flash_fwd_kernel.h", + "csrc/flash_attn/src/flash_fwd_launch_template.h", + "csrc/flash_attn/src/flash_utils.h", + "csrc/flash_attn/src/flash.h", + "csrc/flash_attn/src/kernel_traits.h", + "csrc/flash_attn/src/mask.h", + "csrc/flash_attn/src/philox.cuh", + "csrc/flash_attn/src/rotary.h", + "csrc/flash_attn/src/softmax.h", + "csrc/flash_attn/src/static_switch.h", + "csrc/flash_attn/src/utils.h", + ]), + srcs = if_cuda_is_configured([ + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu.cc", + "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu.cc", + "csrc/flash_attn/src/utils.cc", + ]), + # https://github.com/Dao-AILab/flash-attention/blob/v2.5.7/setup.py#L193-L199 + copts = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ], + include_prefix = "flash_attn", + strip_include_prefix = "csrc/flash_attn/src", + deps = if_cuda_is_configured([ + "@cutlass_for_flash_attn//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ]), +) diff --git a/third_party/flash_attn/flash_attn.patch b/third_party/flash_attn/flash_attn.patch new file mode 100644 index 0000000000000..f00f8f503742d --- /dev/null +++ b/third_party/flash_attn/flash_attn.patch @@ -0,0 +1,1944 @@ +diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp +index ac753af..ab6ebf1 100644 +--- a/csrc/flash_attn/flash_api.cpp ++++ b/csrc/flash_attn/flash_api.cpp +@@ -498,13 +498,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + + std::vector + mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +- const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +- const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. ++ const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i ++ const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. +- c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, +@@ -540,15 +539,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + +- at::Tensor block_table; +- const bool paged_KV = block_table_.has_value(); +- if (paged_KV) { +- block_table = block_table_.value(); +- CHECK_DEVICE(block_table); +- TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); +- TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); +- } +- + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +@@ -560,12 +550,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; +- const int num_heads_k = paged_KV ? k.size(2) : k.size(1); +- +- const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); +- const int num_blocks = !paged_KV ? 0 : k.size(0); +- const int page_block_size = !paged_KV ? 1 : k.size(1); +- TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); ++ const int total_k = k.size(0); ++ const int num_heads_k = k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } +@@ -593,16 +579,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); +- if (!paged_KV) { +- const int total_k = k.size(0); +- CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); +- CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); +- } else { +- CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og); +- CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og); +- CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); +- } +- ++ CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); ++ CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ +@@ -684,14 +662,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s + window_size_left, + window_size_right, + seqlenq_ngroups_swapped); +- +- if (paged_KV) { +- params.block_table = block_table.data_ptr(); +- params.block_table_batch_stride = block_table.stride(0); +- params.k_batch_stride = k_padded.stride(0); +- params.v_batch_stride = v_padded.stride(0); +- } +- params.page_block_size = page_block_size; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + set_params_splitkv(params, batch_size, num_heads, +@@ -720,7 +690,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); +- run_mha_fwd(params, stream, paged_KV); ++ run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); +diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h +index 80d297f..dc01d6a 100644 +--- a/csrc/flash_attn/src/alibi.h ++++ b/csrc/flash_attn/src/alibi.h +@@ -1,11 +1,11 @@ + #include + +-#include ++#include "cute/tensor.hpp" + +-#include +-#include ++#include "cutlass/cutlass.h" ++#include "cutlass/array.h" + +-#include "utils.h" ++#include "flash_utils.h" + + namespace flash { + +diff --git a/csrc/flash_attn/src/dropout.h b/csrc/flash_attn/src/dropout.h +index 4882f97..0c006c4 100644 +--- a/csrc/flash_attn/src/dropout.h ++++ b/csrc/flash_attn/src/dropout.h +@@ -5,7 +5,7 @@ + #pragma once + + #include "philox.cuh" +-#include "utils.h" ++#include "flash_utils.h" + + namespace flash { + +diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h +index 88a7195..4817fdd 100644 +--- a/csrc/flash_attn/src/flash.h ++++ b/csrc/flash_attn/src/flash.h +@@ -5,15 +5,10 @@ + #pragma once + + #include ++#include + #include + +-#ifdef OLD_GENERATOR_PATH +-#include +-#else +-#include +-#endif +- +-#include // For at::cuda::philox::unpack ++#include "philox.cuh" + + constexpr int TOTAL_DIM = 0; + constexpr int H_DIM = 1; +@@ -120,7 +115,7 @@ struct Flash_fwd_params : public Qkv_params { + int window_size_left, window_size_right; + + // Random state. +- at::PhiloxCudaState philox_args; ++ flash::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; +diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h +index 6f89c21..c23211e 100644 +--- a/csrc/flash_attn/src/flash_bwd_kernel.h ++++ b/csrc/flash_attn/src/flash_bwd_kernel.h +@@ -4,15 +4,15 @@ + + #pragma once + +-#include ++#include "cute/algorithm/copy.hpp" + +-#include +-#include +-#include ++#include "cutlass/cutlass.h" ++#include "cutlass/array.h" ++#include "cutlass/numeric_types.h" + + #include "block_info.h" + #include "kernel_traits.h" +-#include "utils.h" ++#include "flash_utils.h" + #include "softmax.h" + #include "mask.h" + #include "dropout.h" +diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h +index fd81c88..5896245 100644 +--- a/csrc/flash_attn/src/flash_bwd_launch_template.h ++++ b/csrc/flash_attn/src/flash_bwd_launch_template.h +@@ -4,8 +4,6 @@ + + #pragma once + +-#include +- + #include "static_switch.h" + #include "flash.h" + #include "flash_bwd_preprocess_kernel.h" +@@ -72,7 +70,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + int gridDimx = num_n_block; + if (params.deterministic) { +- auto dprops = at::cuda::getCurrentDeviceProperties(); ++ auto dprops = flash::cuda::getCurrentDeviceProperties(); + gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h); + } + dim3 grid_n(gridDimx, params.b, params.h); +@@ -82,7 +80,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) + } else { + flash_bwd_dot_do_o_kernel<<>>(params); + } +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + + // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not + // a multiple of kBlockN, we'll need to apply mask in the loop. +@@ -101,11 +99,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { +- C10_CUDA_CHECK(cudaFuncSetAttribute( ++ FLASH_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +@@ -114,11 +112,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) + + auto kernel_dq = &flash_bwd_convert_dq_kernel; + if (Kernel_traits::kSmemdQSize >= 48 * 1024) { +- C10_CUDA_CHECK(cudaFuncSetAttribute( ++ FLASH_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + } + kernel_dq<<>>(params, !params.deterministic ? 1 : gridDimx); +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + } + + template +@@ -137,7 +135,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB +@@ -161,7 +159,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +@@ -206,7 +204,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +@@ -232,7 +230,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +@@ -266,7 +264,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 116 * 1024) { +@@ -286,7 +284,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 136 * 1024) { +@@ -314,7 +312,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 176 * 1024) { // H100 +diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +index 6582d81..3f6e7fb 100644 +--- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h ++++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +@@ -4,15 +4,15 @@ + + #pragma once + +-#include ++#include "cute/algorithm/copy.hpp" + +-#include +-#include +-#include ++#include "cutlass/cutlass.h" ++#include "cutlass/array.h" ++#include "cutlass/numeric_types.h" + + #include "block_info.h" + #include "kernel_traits.h" +-#include "utils.h" ++#include "flash_utils.h" + + namespace flash { + +diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h +index 104e164..d03b2e1 100644 +--- a/csrc/flash_attn/src/flash_fwd_kernel.h ++++ b/csrc/flash_attn/src/flash_fwd_kernel.h +@@ -4,15 +4,15 @@ + + #pragma once + +-#include ++#include "cute/algorithm/copy.hpp" + +-#include +-#include +-#include ++#include "cutlass/cutlass.h" ++#include "cutlass/array.h" ++#include "cutlass/numeric_types.h" + + #include "block_info.h" + #include "kernel_traits.h" +-#include "utils.h" ++#include "flash_utils.h" + #include "softmax.h" + #include "mask.h" + #include "dropout.h" +@@ -42,7 +42,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + +- auto seed_offset = at::cuda::philox::unpack(params.philox_args); ++ auto seed_offset = flash::cuda::philox::unpack(params.philox_args); + flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + +diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h +index fa6a6f6..c6a6fc4 100644 +--- a/csrc/flash_attn/src/flash_fwd_launch_template.h ++++ b/csrc/flash_attn/src/flash_fwd_launch_template.h +@@ -4,8 +4,6 @@ + + #pragma once + +-#include +- + #include "static_switch.h" + #include "flash.h" + #include "flash_fwd_kernel.h" +@@ -77,7 +75,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { +- C10_CUDA_CHECK(cudaFuncSetAttribute( ++ FLASH_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; +@@ -85,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +@@ -116,11 +114,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { +- C10_CUDA_CHECK(cudaFuncSetAttribute( ++ FLASH_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +@@ -150,7 +148,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } +- C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ FLASH_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + } +@@ -200,7 +198,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + template + void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; +- auto dprops = at::cuda::getCurrentDeviceProperties(); ++ auto dprops = flash::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { +@@ -226,7 +224,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + template + void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; +- auto dprops = at::cuda::getCurrentDeviceProperties(); ++ auto dprops = flash::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { +@@ -263,7 +261,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + template + void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; +- auto dprops = at::cuda::getCurrentDeviceProperties(); ++ auto dprops = flash::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { +@@ -318,7 +316,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +@@ -349,7 +347,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { +- C10_CUDA_CHECK(status_); ++ FLASH_CUDA_CHECK(status_); + } + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu.cc +similarity index 100% +rename from csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu +rename to csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu.cc +diff --git a/csrc/flash_attn/src/flash_utils.h b/csrc/flash_attn/src/flash_utils.h +new file mode 100644 +index 0000000..d897282 +--- /dev/null ++++ b/csrc/flash_attn/src/flash_utils.h +@@ -0,0 +1,399 @@ ++// Copyright (C) 2024 Ant Group Co., Ltd. All Rights Reserved. ++// SPDX-License-Identifier: Apache-2.0 ++ ++#pragma once ++ ++#include ++#include ++#include ++ ++#include ++ ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++#include ++#endif ++ ++#include "cute/algorithm/copy.hpp" ++#include "cute/algorithm/gemm.hpp" ++ ++#include "cutlass/array.h" ++#include "cutlass/cutlass.h" ++#include "cutlass/numeric_conversion.h" ++#include "cutlass/numeric_types.h" ++ ++#include "utils.h" ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++namespace flash { ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++using namespace cute; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ uint32_t relu2(const uint32_t x); ++ ++template<> ++__forceinline__ __device__ uint32_t relu2(const uint32_t x) { ++ uint32_t res; ++ const uint32_t zero = 0u; ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); ++#else ++ asm volatile( \ ++ "{\n" \ ++ "\t .reg .f16x2 sela;\n" \ ++ "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ ++ "\t and.b32 %0, sela, %1;\n" ++ "}\n" : "=r"(res) : "r"(x), "r"(zero)); ++#endif ++ return res; ++} ++ ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++template<> ++__forceinline__ __device__ uint32_t relu2(const uint32_t x) { ++ uint32_t res; ++ const uint32_t zero = 0u; ++ asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); ++ return res; ++} ++#endif ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ ++template ++__forceinline__ __device__ uint32_t convert_relu2(const float2 x); ++ ++template<> ++__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { ++ uint32_t res; ++ const uint32_t a = reinterpret_cast(x.x); ++ const uint32_t b = reinterpret_cast(x.y); ++ asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); ++ return res; ++} ++ ++template<> ++__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { ++ uint32_t res; ++ const uint32_t a = reinterpret_cast(x.x); ++ const uint32_t b = reinterpret_cast(x.y); ++ asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); ++ return res; ++} ++ ++#endif ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct MaxOp { ++__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } ++}; ++ ++template <> ++struct MaxOp { ++// This is slightly faster ++__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct SumOp { ++__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct Allreduce { ++ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); ++ template ++ static __device__ __forceinline__ T run(T x, Operator &op) { ++ constexpr int OFFSET = THREADS / 2; ++ x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); ++ return Allreduce::run(x, op); ++ } ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template<> ++struct Allreduce<2> { ++template ++static __device__ __forceinline__ T run(T x, Operator &op) { ++ x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); ++ return x; ++} ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, ++ Tensor4 const& tCsB, TiledMma tiled_mma, ++ TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, ++ ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { ++ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M ++ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N ++ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K ++ Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); ++ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M ++ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); ++ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N ++ if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } ++ if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } ++ #pragma unroll ++ for (int i = 0; i < size<2>(tCrA); ++i) { ++ if (i < size<2>(tCrA) - 1) { ++ if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } ++ if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } ++ } ++ cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); ++ } ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, ++ TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ++ ThrCopy smem_thr_copy_B) { ++ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M ++ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N ++ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K ++ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); ++ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N ++ cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); ++ #pragma unroll ++ for (int i = 0; i < size<2>(tCrA); ++i) { ++ if (i < size<2>(tCrA) - 1) { ++ cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); ++ } ++ cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); ++ } ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) ++template ++__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { ++ static_assert(decltype(size<0>(acc_layout))::value == 4); ++ static_assert(decltype(rank(acc_layout))::value == 3); ++ auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) ++ return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) ++// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. ++template ++__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { ++ using X = Underscore; ++ static_assert(decltype(size<0>(acc_layout))::value == 4); ++ static_assert(decltype(rank(acc_layout))::value == 3); ++ constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); ++ static_assert(mma_shape_K == 8 || mma_shape_K == 16); ++ if constexpr (mma_shape_K == 8) { ++ return acc_layout; ++ } else { ++ auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) ++ return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); ++ } ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) ++template ++__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { ++ using X = Underscore; ++ static_assert(decltype(size<0>(acc_layout))::value == 4); ++ static_assert(decltype(rank(acc_layout))::value == 3); ++ auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) ++ return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ auto convert_type(Tensor const &tensor) { ++ using From_type = typename Engine::value_type; ++ constexpr int numel = decltype(size(tensor))::value; ++ cutlass::NumericArrayConverter convert_op; ++ // HACK: this requires tensor to be "contiguous" ++ auto frag = convert_op(*reinterpret_cast *>(tensor.data())); ++ return make_tensor(make_rmem_ptr(&frag), tensor.layout()); ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ void relu_(Tensor &tensor) { ++ constexpr int numel = decltype(size(tensor))::value; ++ static_assert(numel % 2 == 0); ++ using value_t = typename Engine::value_type; ++ // HACK: this requires tensor to be "contiguous" ++ Tensor tensor_uint32 = recast(tensor); ++ #pragma unroll ++ for (int i = 0; i < size(tensor_uint32); ++i) { ++ tensor_uint32(i) = relu2(tensor_uint32(i)); ++ } ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction ++template ++__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { ++ using From_type = typename Engine::value_type; ++ static_assert(std::is_same_v || std::is_same_v); ++ static_assert(std::is_same_v); ++ constexpr int numel = decltype(size(tensor))::value; ++ static_assert(numel % 2 == 0); ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ // HACK: this requires tensor to be "contiguous" ++ Tensor tensor_float2 = recast(tensor); ++ Tensor out_uint32 = make_tensor(tensor_float2.layout()); ++ #pragma unroll ++ for (int i = 0; i < size(out_uint32); ++i) { ++ out_uint32(i) = convert_relu2(tensor_float2(i)); ++ } ++ Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); ++#else ++ Tensor out = flash::convert_type(tensor); ++ flash::relu_(out); ++#endif ++ return out; ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Blocks until all but N previous cp.async.commit_group operations have committed. ++// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all ++// (which is equivalent to commit_group then wait_group 0). ++// Instead we just call cp.async.wait_group 0, which is slightly faster. ++// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 ++template ++CUTE_HOST_DEVICE ++void cp_async_wait() { ++#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) ++ asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); ++#endif ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, ++ Tensor &D, Tensor const &identity_MN, ++ Tensor const &predicate_K, const int max_MN=0) { ++ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); ++ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); ++ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA ++ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M ++ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K ++ // There's no case where !Clear_OOB_K && Clear_OOB_MN ++ static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); ++ #pragma unroll ++ for (int m = 0; m < size<1>(S); ++m) { ++ if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { ++ #pragma unroll ++ for (int k = 0; k < size<2>(S); ++k) { ++ if (Is_even_K || predicate_K(k)) { ++ cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); ++ } else if (Clear_OOB_K) { ++ cute::clear(D(_, m, k)); ++ } ++ } ++ } else if (Clear_OOB_MN) { ++ cute::clear(D(_, m, _)); ++ } ++ } ++ // TD [2023-04-13]: Strange that the code below can cause race condition. ++ // I think it's because the copies are under an if statement. ++ // if (Is_even_K) { ++ // #pragma unroll ++ // for (int m = 0; m < size<1>(S); ++m) { ++ // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { ++ // copy(tiled_copy, S(_, m, _), D(_, m, _)); ++ // } else if (Clear_OOB_MN) { ++ // clear(D(_, m, _)); ++ // } ++ // } ++ // } else { // It's slightly faster in this case if iterate over K first ++ // #pragma unroll ++ // for (int k = 0; k < size<2>(S); ++k) { ++ // if (predicate_K(k)) { ++ // #pragma unroll ++ // for (int m = 0; m < size<1>(S); ++m) { ++ // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { ++ // copy(tiled_copy, S(_, m, k), D(_, m, k)); ++ // } else if (Clear_OOB_MN) { ++ // clear(D(_, m, k)); ++ // } ++ // } ++ // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN ++ // if (Clear_OOB_MN || Is_even_MN) { ++ // clear(D(_, _, k)); ++ // } else { ++ // #pragma unroll ++ // for (int m = 0; m < size<1>(S); ++m) { ++ // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { ++ // clear(D(_, m, k)); ++ // } ++ // } ++ // } ++ // } ++ // } ++ // } ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, ++ Tensor &D, Tensor const &identity_MN, ++ Tensor const &predicate_K, ++ const int max_MN=0, const int min_MN=0) { ++ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); ++ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); ++ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA ++ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M ++ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K ++ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } ++ #pragma unroll ++ for (int m = 0; m < size<1>(S); ++m) { ++ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } ++ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { ++ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } ++ #pragma unroll ++ for (int k = 0; k < size<2>(S); ++k) { ++ if (Is_even_K || predicate_K(k)) { ++ cute::copy(S(_, m, k), D(_, m, k)); ++ } ++ } ++ } ++ } ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++} // namespace flash +diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h +index a7a5cf1..4bf4a7b 100644 +--- a/csrc/flash_attn/src/kernel_traits.h ++++ b/csrc/flash_attn/src/kernel_traits.h +@@ -8,7 +8,7 @@ + + #include "cutlass/cutlass.h" + #include "cutlass/layout/layout.h" +-#include ++#include "cutlass/numeric_types.h" + + using namespace cute; + +diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h +index 3d9b429..73c52d5 100644 +--- a/csrc/flash_attn/src/mask.h ++++ b/csrc/flash_attn/src/mask.h +@@ -4,7 +4,7 @@ + + #pragma once + +-#include ++#include "cute/tensor.hpp" + + namespace flash { + +diff --git a/csrc/flash_attn/src/philox.cuh b/csrc/flash_attn/src/philox.cuh +index cd7e4d2..d7cde3f 100644 +--- a/csrc/flash_attn/src/philox.cuh ++++ b/csrc/flash_attn/src/philox.cuh +@@ -2,8 +2,87 @@ + #pragma once + // Philox CUDA. + ++#include ++#include ++#include ++ + namespace flash { + ++struct PhiloxCudaState { ++ PhiloxCudaState() = default; ++ // Called if graph capture is not underway ++ PhiloxCudaState(uint64_t seed, ++ uint64_t offset) ++ { ++ seed_.val = seed; ++ offset_.val = offset; ++ } ++ // Called if graph capture is underway ++ PhiloxCudaState(int64_t *seed, ++ int64_t *offset_extragraph, ++ uint32_t offset_intragraph) ++ { ++ seed_.ptr = seed; ++ offset_.ptr = offset_extragraph; ++ offset_intragraph_ = offset_intragraph; ++ captured_ = true; ++ } ++ ++ // Public members, directly accessible by at::cuda::philox::unpack. ++ // If we made them private with getters/setters, the getters/setters ++ // would have to be __device__, and we can't declare __device__ in ATen. ++ union Payload ++ { ++ uint64_t val; ++ int64_t *ptr; ++ }; ++ ++ Payload seed_; ++ Payload offset_; ++ uint32_t offset_intragraph_ = 0; ++ bool captured_ = false; ++}; ++ ++class CUDAPhiloxRandomGenerator { ++public: ++ CUDAPhiloxRandomGenerator() ++ { ++ std::random_device rd; ++ seed_ = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF; ++ philox_offset_per_thread_ = 0; ++ } ++ ++ PhiloxCudaState philox_cuda_state(uint64_t increment) ++ { ++ // rounds increment up to the nearest multiple of 4 ++ increment = ((increment + 3) / 4) * 4; ++ uint64_t offset = this->philox_offset_per_thread_; ++ this->philox_offset_per_thread_ += increment; ++ return PhiloxCudaState(this->seed_, offset); ++ } ++ std::mutex mutex_; ++ ++private: ++ uint64_t seed_; ++ uint64_t philox_offset_per_thread_; ++}; ++ ++namespace cuda::philox { ++ ++__host__ __device__ __forceinline__ std::tuple ++unpack(PhiloxCudaState arg) { ++ if (arg.captured_) { ++ // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". ++ // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. ++ // For most threads' reads it will hit in cache, so it shouldn't hurt performance. ++ return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); ++ } else { ++ return std::make_tuple(arg.seed_.val, arg.offset_.val); ++ } ++} ++ ++} // namespace cuda::philox ++ + struct ull2 { + unsigned long long x; + unsigned long long y; +diff --git a/csrc/flash_attn/src/rotary.h b/csrc/flash_attn/src/rotary.h +index dc2825b..4644027 100644 +--- a/csrc/flash_attn/src/rotary.h ++++ b/csrc/flash_attn/src/rotary.h +@@ -4,9 +4,9 @@ + + #pragma once + +-#include ++#include "cute/algorithm/copy.hpp" + +-#include "utils.h" ++#include "flash_utils.h" + + //////////////////////////////////////////////////////////////////////////////////////////////////// + +diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h +index ebf1b09..025fbc7 100644 +--- a/csrc/flash_attn/src/softmax.h ++++ b/csrc/flash_attn/src/softmax.h +@@ -6,12 +6,12 @@ + + #include + +-#include ++#include "cute/tensor.hpp" + +-#include ++#include "cutlass/numeric_types.h" + + #include "philox.cuh" +-#include "utils.h" ++#include "flash_utils.h" + + namespace flash { + +diff --git a/csrc/flash_attn/src/utils.cc b/csrc/flash_attn/src/utils.cc +new file mode 100644 +index 0000000..9e7808b +--- /dev/null ++++ b/csrc/flash_attn/src/utils.cc +@@ -0,0 +1,67 @@ ++// Copyright (C) 2024 Ant Group Co., Ltd. All Rights Reserved. ++// SPDX-License-Identifier: Apache-2.0 ++ ++#include "utils.h" ++ ++#include ++#include ++#include ++#include ++ ++namespace flash::cuda { ++ ++int getCurrentDevice() { ++ int device; ++ FLASH_CUDA_CHECK(cudaGetDevice(&device)); ++ return device; ++} ++ ++static int num_gpus; ++static std::once_flag device_init_flag; ++static std::deque device_flags; ++static std::vector device_properties; ++static std::deque cuda_gens_init_flag; ++static std::vector> default_gens_cuda; ++ ++static void initCUDAContextVectors() { ++ FLASH_CUDA_CHECK(cudaGetDeviceCount(&num_gpus)); ++ device_flags.resize(num_gpus); ++ device_properties.resize(num_gpus); ++ cuda_gens_init_flag.resize(num_gpus); ++ default_gens_cuda.resize(num_gpus); ++} ++ ++static void initDeviceProperty(int device) { ++ cudaDeviceProp prop; ++ FLASH_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); ++ device_properties[device] = prop; ++} ++ ++cudaDeviceProp* getDeviceProperties(int device) { ++ std::call_once(device_init_flag, initCUDAContextVectors); ++ if (device == -1) { ++ device = getCurrentDevice(); ++ } ++ FLASH_ASSERT(device >= 0 && device < num_gpus); ++ std::call_once(device_flags[device], initDeviceProperty, device); ++ return &device_properties[device]; ++} ++ ++cudaDeviceProp* getCurrentDeviceProperties() { ++ int cur_device = getCurrentDevice(); ++ return getDeviceProperties(cur_device); ++} ++ ++CUDAPhiloxRandomGenerator& getDefaultCUDAGenerator(int device) { ++ std::call_once(device_init_flag, initCUDAContextVectors); ++ if (device == -1) { ++ device = getCurrentDevice(); ++ } ++ FLASH_ASSERT(device >= 0 && device < num_gpus); ++ std::call_once(cuda_gens_init_flag[device], [&] { ++ default_gens_cuda[device] = std::make_unique(); ++ }); ++ return *default_gens_cuda[device]; ++} ++ ++} // namespace flash::cuda +diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h +index 2b45e87..f816c4b 100644 +--- a/csrc/flash_attn/src/utils.h ++++ b/csrc/flash_attn/src/utils.h +@@ -4,391 +4,46 @@ + + #pragma once + +-#include +-#include +-#include ++#include ++#include ++#include ++#include + +-#include ++#include ++#include + +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +-#include +-#endif ++#include "philox.cuh" + +-#include +-#include +- +-#include +-#include +-#include +-#include +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-namespace flash { +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-__forceinline__ __device__ uint32_t relu2(const uint32_t x); +- +-template<> +-__forceinline__ __device__ uint32_t relu2(const uint32_t x) { +- uint32_t res; +- const uint32_t zero = 0u; +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +- asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +-#else +- asm volatile( \ +- "{\n" \ +- "\t .reg .f16x2 sela;\n" \ +- "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ +- "\t and.b32 %0, sela, %1;\n" +- "}\n" : "=r"(res) : "r"(x), "r"(zero)); +-#endif +- return res; +-} +- +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +-template<> +-__forceinline__ __device__ uint32_t relu2(const uint32_t x) { +- uint32_t res; +- const uint32_t zero = 0u; +- asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +- return res; +-} +-#endif +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +- +-template +-__forceinline__ __device__ uint32_t convert_relu2(const float2 x); +- +-template<> +-__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { +- uint32_t res; +- const uint32_t a = reinterpret_cast(x.x); +- const uint32_t b = reinterpret_cast(x.y); +- asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); +- return res; +-} +- +-template<> +-__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { +- uint32_t res; +- const uint32_t a = reinterpret_cast(x.x); +- const uint32_t b = reinterpret_cast(x.y); +- asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); +- return res; +-} +- +-#endif +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-struct MaxOp { +-__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +-}; +- +-template <> +-struct MaxOp { +-// This is slightly faster +-__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-struct SumOp { +-__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-struct Allreduce { +- static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); +- template +- static __device__ __forceinline__ T run(T x, Operator &op) { +- constexpr int OFFSET = THREADS / 2; +- x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); +- return Allreduce::run(x, op); ++#define FLASH_ASSERT_WITH_MSG(EXPR, MSG) \ ++ if (!(EXPR)) { \ ++ const char *origin_msg = (MSG); \ ++ const char *msg = origin_msg != nullptr ? origin_msg : ""; \ ++ const char *msg_tail = origin_msg != nullptr ? ", " : ""; \ ++ printf("Assertion failed: %s%sfunction %s, file %s, line %d\n", msg, \ ++ msg_tail, __func__, __FILE__, __LINE__); \ ++ ::abort(); \ + } +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// + +-template<> +-struct Allreduce<2> { +-template +-static __device__ __forceinline__ T run(T x, Operator &op) { +- x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); +- return x; +-} +-}; ++#define FLASH_ASSERT(EXPR) FLASH_ASSERT_WITH_MSG(EXPR, nullptr) + +-//////////////////////////////////////////////////////////////////////////////////////////////////// ++namespace flash::cuda { + +-template +-__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, +- Tensor4 const& tCsB, TiledMma tiled_mma, +- TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, +- ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { +- CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M +- CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N +- CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K +- Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); +- CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M +- Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); +- CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N +- if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } +- if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } +- #pragma unroll +- for (int i = 0; i < size<2>(tCrA); ++i) { +- if (i < size<2>(tCrA) - 1) { +- if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } +- if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } +- } +- cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); +- } +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// ++#define FLASH_CUDA_CHECK(EXPR) \ ++ do \ ++ { \ ++ const cudaError_t cuda_err = EXPR; \ ++ if (cuda_err != cudaSuccess) \ ++ { \ ++ const auto &error_msg = std::string("CUDA error: ") + cudaGetErrorString(cuda_err); \ ++ FLASH_ASSERT_WITH_MSG(false, error_msg.c_str()); \ ++ } \ ++ } while (0) + +-template +-__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, +- TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, +- ThrCopy smem_thr_copy_B) { +- CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M +- CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N +- CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K +- Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); +- CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N +- cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +- #pragma unroll +- for (int i = 0; i < size<2>(tCrA); ++i) { +- if (i < size<2>(tCrA) - 1) { +- cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); +- } +- cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); +- } +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +-template +-__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { +- static_assert(decltype(size<0>(acc_layout))::value == 4); +- static_assert(decltype(rank(acc_layout))::value == 3); +- auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) +- return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +-// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +-template +-__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { +- using X = Underscore; +- static_assert(decltype(size<0>(acc_layout))::value == 4); +- static_assert(decltype(rank(acc_layout))::value == 3); +- constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); +- static_assert(mma_shape_K == 8 || mma_shape_K == 16); +- if constexpr (mma_shape_K == 8) { +- return acc_layout; +- } else { +- auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) +- return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +- } +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +-template +-__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { +- using X = Underscore; +- static_assert(decltype(size<0>(acc_layout))::value == 4); +- static_assert(decltype(rank(acc_layout))::value == 3); +- auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) +- return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +-}; +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-__forceinline__ __device__ auto convert_type(Tensor const &tensor) { +- using From_type = typename Engine::value_type; +- constexpr int numel = decltype(size(tensor))::value; +- cutlass::NumericArrayConverter convert_op; +- // HACK: this requires tensor to be "contiguous" +- auto frag = convert_op(*reinterpret_cast *>(tensor.data())); +- return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-__forceinline__ __device__ void relu_(Tensor &tensor) { +- constexpr int numel = decltype(size(tensor))::value; +- static_assert(numel % 2 == 0); +- using value_t = typename Engine::value_type; +- // HACK: this requires tensor to be "contiguous" +- Tensor tensor_uint32 = recast(tensor); +- #pragma unroll +- for (int i = 0; i < size(tensor_uint32); ++i) { +- tensor_uint32(i) = relu2(tensor_uint32(i)); +- } +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +-template +-__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { +- using From_type = typename Engine::value_type; +- static_assert(std::is_same_v || std::is_same_v); +- static_assert(std::is_same_v); +- constexpr int numel = decltype(size(tensor))::value; +- static_assert(numel % 2 == 0); +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +- // HACK: this requires tensor to be "contiguous" +- Tensor tensor_float2 = recast(tensor); +- Tensor out_uint32 = make_tensor(tensor_float2.layout()); +- #pragma unroll +- for (int i = 0; i < size(out_uint32); ++i) { +- out_uint32(i) = convert_relu2(tensor_float2(i)); +- } +- Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +-#else +- Tensor out = flash::convert_type(tensor); +- flash::relu_(out); +-#endif +- return out; +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-// Blocks until all but N previous cp.async.commit_group operations have committed. +-// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +-// (which is equivalent to commit_group then wait_group 0). +-// Instead we just call cp.async.wait_group 0, which is slightly faster. +-// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +-template +-CUTE_HOST_DEVICE +-void cp_async_wait() { +-#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) +- asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +-#endif +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, +- Tensor &D, Tensor const &identity_MN, +- Tensor const &predicate_K, const int max_MN=0) { +- CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); +- CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); +- CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA +- CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M +- CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +- // There's no case where !Clear_OOB_K && Clear_OOB_MN +- static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +- #pragma unroll +- for (int m = 0; m < size<1>(S); ++m) { +- if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +- #pragma unroll +- for (int k = 0; k < size<2>(S); ++k) { +- if (Is_even_K || predicate_K(k)) { +- cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); +- } else if (Clear_OOB_K) { +- cute::clear(D(_, m, k)); +- } +- } +- } else if (Clear_OOB_MN) { +- cute::clear(D(_, m, _)); +- } +- } +- // TD [2023-04-13]: Strange that the code below can cause race condition. +- // I think it's because the copies are under an if statement. +- // if (Is_even_K) { +- // #pragma unroll +- // for (int m = 0; m < size<1>(S); ++m) { +- // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +- // copy(tiled_copy, S(_, m, _), D(_, m, _)); +- // } else if (Clear_OOB_MN) { +- // clear(D(_, m, _)); +- // } +- // } +- // } else { // It's slightly faster in this case if iterate over K first +- // #pragma unroll +- // for (int k = 0; k < size<2>(S); ++k) { +- // if (predicate_K(k)) { +- // #pragma unroll +- // for (int m = 0; m < size<1>(S); ++m) { +- // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +- // copy(tiled_copy, S(_, m, k), D(_, m, k)); +- // } else if (Clear_OOB_MN) { +- // clear(D(_, m, k)); +- // } +- // } +- // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN +- // if (Clear_OOB_MN || Is_even_MN) { +- // clear(D(_, _, k)); +- // } else { +- // #pragma unroll +- // for (int m = 0; m < size<1>(S); ++m) { +- // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { +- // clear(D(_, m, k)); +- // } +- // } +- // } +- // } +- // } +- // } +-} +- +-//////////////////////////////////////////////////////////////////////////////////////////////////// +- +-template +-__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, +- Tensor &D, Tensor const &identity_MN, +- Tensor const &predicate_K, +- const int max_MN=0, const int min_MN=0) { +- CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); +- CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); +- CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA +- CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M +- CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +- #pragma unroll +- for (int m = 0; m < size<1>(S); ++m) { +- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +- if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +- #pragma unroll +- for (int k = 0; k < size<2>(S); ++k) { +- if (Is_even_K || predicate_K(k)) { +- cute::copy(S(_, m, k), D(_, m, k)); +- } +- } +- } +- } +-} ++#define FLASH_CUDA_KERNEL_LAUNCH_CHECK() FLASH_CUDA_CHECK(cudaGetLastError()) + +-//////////////////////////////////////////////////////////////////////////////////////////////////// ++int getCurrentDevice(); ++cudaDeviceProp* getDeviceProperties(int device); ++cudaDeviceProp* getCurrentDeviceProperties(); ++CUDAPhiloxRandomGenerator& getDefaultCUDAGenerator(int device); + +-} // namespace flash ++} // namespace flash::cuda +diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py +index a7f15be..a1ef865 100644 +--- a/flash_attn/flash_attn_interface.py ++++ b/flash_attn/flash_attn_interface.py +@@ -79,7 +79,6 @@ def _flash_attn_varlen_forward( + window_size, + alibi_slopes, + return_softmax, +- block_table, + ): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] +@@ -91,7 +90,6 @@ def _flash_attn_varlen_forward( + cu_seqlens_q, + cu_seqlens_k, + None, +- block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, +@@ -301,7 +299,6 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, +- block_table=None, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p +@@ -443,7 +440,6 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, +- block_table=None, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state +@@ -574,7 +570,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function): + alibi_slopes, + deterministic, + return_softmax, +- block_table, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) +@@ -592,7 +587,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function): + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, +- block_table=block_table, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state +@@ -636,7 +630,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] +- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None ++ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + + + def flash_attn_qkvpacked_func( +@@ -1007,7 +1001,6 @@ def flash_attn_varlen_func( + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +- block_table=None, + ): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads +@@ -1078,7 +1071,6 @@ def flash_attn_varlen_func( + alibi_slopes, + deterministic, + return_attn_probs, +- block_table, + ) + + +diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py +index 308e30b..892b8be 100644 +--- a/tests/test_flash_attn.py ++++ b/tests/test_flash_attn.py +@@ -1542,12 +1542,8 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + (1023, 1024), + ], + ) +-# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged +-@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) + # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +-def test_flash_attn_varlen_causal( +- seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype +-): ++def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 +@@ -1563,19 +1559,8 @@ def test_flash_attn_varlen_causal( + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) +- +- if paged_kv_block_size is None: +- k = torch.randn( +- batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True +- ) +- v = torch.randn( +- batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True +- ) +- block_table = None +- else: +- k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( +- seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype +- ) ++ k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) ++ v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + ( +@@ -1595,8 +1580,8 @@ def test_flash_attn_varlen_causal( + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out_unpad = flash_attn_varlen_func( + q_unpad, +- k_unpad if paged_kv_block_size is None else k_cache_paged, +- v_unpad if paged_kv_block_size is None else v_cache_paged, ++ k_unpad, ++ v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, +@@ -1604,7 +1589,6 @@ def test_flash_attn_varlen_causal( + 0.0, + causal=causal, + window_size=window_size, +- block_table=block_table, + ) + out = output_pad_fn(out_unpad) + out_ref, attn_ref = attention_ref( +@@ -1641,8 +1625,7 @@ def test_flash_attn_varlen_causal( + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) +- test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None +- if test_backward: ++ if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90): + ( + dq_unpad, + dk_unpad, +@@ -1678,7 +1661,7 @@ def test_flash_attn_varlen_causal( + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + +- if test_backward: ++ if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 +@@ -1905,16 +1888,29 @@ def test_flash_attn_kvcache( + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + block_table = None + else: +- ( +- k_cache, +- v_cache, +- block_table, +- k_cache_paged, +- v_cache_paged, +- num_blocks, +- ) = _generate_block_kvcache( +- seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype ++ num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 ++ k_cache_paged = torch.randn( ++ num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype ++ ) ++ v_cache_paged = torch.randn( ++ num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) ++ block_table = rearrange( ++ torch.randperm(num_blocks, dtype=torch.int32, device=device), ++ "(b nblocks) -> b nblocks", ++ b=batch_size, ++ ) ++ k_cache = rearrange( ++ # pytorch 1.12 doesn't have indexing with int32 ++ k_cache_paged[block_table.to(dtype=torch.long).flatten()], ++ "(b nblocks) block_size ... -> b (nblocks block_size) ...", ++ b=batch_size, ++ )[:, :seqlen_k] ++ v_cache = rearrange( ++ v_cache_paged[block_table.to(dtype=torch.long).flatten()], ++ "(b nblocks) block_size ... -> b (nblocks block_size) ...", ++ b=batch_size, ++ )[:, :seqlen_k] + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough +@@ -2077,33 +2073,6 @@ def test_flash_attn_kvcache( + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + + +-def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): +- num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 +- k_cache_paged = torch.randn( +- num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype +- ) +- v_cache_paged = torch.randn( +- num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype +- ) +- block_table = rearrange( +- torch.randperm(num_blocks, dtype=torch.int32, device=device), +- "(b nblocks) -> b nblocks", +- b=batch_size, +- ) +- k_cache = rearrange( +- # pytorch 1.12 doesn't have indexing with int32 +- k_cache_paged[block_table.to(dtype=torch.long).flatten()], +- "(b nblocks) block_size ... -> b (nblocks block_size) ...", +- b=batch_size, +- )[:, :seqlen_k] +- v_cache = rearrange( +- v_cache_paged[block_table.to(dtype=torch.long).flatten()], +- "(b nblocks) block_size ... -> b (nblocks block_size) ...", +- b=batch_size, +- )[:, :seqlen_k] +- return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks +- +- + # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) + @pytest.mark.parametrize("dtype", [torch.float16]) + @pytest.mark.parametrize("causal", [False, True]) \ No newline at end of file diff --git a/third_party/flash_attn/workspace.bzl b/third_party/flash_attn/workspace.bzl new file mode 100644 index 0000000000000..28766ba1ced1a --- /dev/null +++ b/third_party/flash_attn/workspace.bzl @@ -0,0 +1,19 @@ +"""Provides the repository macro to import flash-attention.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + # v2.5.7 + FLASH_ATTN_COMMIT = "85881f547fd1053a7b4a2c3faad6690cca969279" + FLASH_ATTN_SHA256 = "66f1c7c09d0783c2b5d89b17b542562166d4276b180ae5cad184ad8f2f32d115" + + tf_http_archive( + name = "flash_attn", + sha256 = FLASH_ATTN_SHA256, + strip_prefix = "flash-attention-{commit}".format(commit = FLASH_ATTN_COMMIT), + urls = tf_mirror_urls("https://github.com/Dao-AILab/flash-attention/archive/{commit}.tar.gz".format(commit = FLASH_ATTN_COMMIT)), + build_file = "//third_party/flash_attn:flash_attn.BUILD", + patch_file = [ + "//third_party/flash_attn:flash_attn.patch" + ], + ) diff --git a/third_party/gloo/BUILD b/third_party/gloo/BUILD new file mode 100644 index 0000000000000..3c413807167ae --- /dev/null +++ b/third_party/gloo/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/gloo/gloo.BUILD b/third_party/gloo/gloo.BUILD new file mode 100644 index 0000000000000..2a9ca06136ca4 --- /dev/null +++ b/third_party/gloo/gloo.BUILD @@ -0,0 +1,97 @@ +# Description: +# Gloo is a collective communications library + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +substitions = { + "@GLOO_VERSION_MAJOR@": "9999", + "@GLOO_VERSION_MINOR@": "0", + "@GLOO_VERSION_PATCH@": "0", + "#cmakedefine01 GLOO_USE_CUDA": "#define GLOO_USE_CUDA 0", + "#cmakedefine01 GLOO_USE_NCCL": "#define GLOO_USE_NCCL 0", + "#cmakedefine01 GLOO_USE_ROCM": "#define GLOO_USE_ROCM 0", + "#cmakedefine01 GLOO_USE_RCCL": "#define GLOO_USE_RCCL 0", + "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", + "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", + "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_UV": "#define GLOO_HAVE_TRANSPORT_UV 0", + "#cmakedefine01 GLOO_USE_AVX": "#define GLOO_USE_AVX __AVX__", +} + +expand_template( + name = "config", + out = "gloo/config.h", + substitutions = substitions, + template = "gloo/config.h.in", +) + +cc_library( + name = "gloo", + srcs = glob( + [ + "gloo/*.cc", + "gloo/common/*.cc", + "gloo/transport/*.cc", + ], + exclude = [ + "gloo/common/linux.cc", + "gloo/common/win.cc", + "gloo/cuda*.cc", + ], + ) + [ + "gloo/rendezvous/context.cc", + "gloo/rendezvous/file_store.cc", + "gloo/rendezvous/hash_store.cc", + "gloo/rendezvous/prefix_store.cc", + "gloo/rendezvous/store.cc", + ] + select({ + "@tsl//tsl:macos": [], + "@tsl//tsl:windows": [], + "//conditions:default": [ + "gloo/common/linux.cc", + ], + }), + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = ["."], + textual_hdrs = glob( + [ + "gloo/*.h", + "gloo/common/*.h", + "gloo/transport/*.h", + ], + exclude = [ + "gloo/cuda*.h", + "gloo/common/win.h", + ], + ) + [ + "gloo/config.h", + "gloo/rendezvous/context.h", + "gloo/rendezvous/file_store.h", + "gloo/rendezvous/hash_store.h", + "gloo/rendezvous/prefix_store.h", + "gloo/rendezvous/store.h", + ], +) + +cc_library( + name = "transport_tcp", + srcs = glob(["gloo/transport/tcp/*.cc"]), + hdrs = glob(["gloo/transport/tcp/*.h"]), + copts = ["-fexceptions"], + deps = [":gloo"], +) diff --git a/third_party/gloo/workspace.bzl b/third_party/gloo/workspace.bzl new file mode 100644 index 0000000000000..ede168395acdc --- /dev/null +++ b/third_party/gloo/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import Gloo.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports Gloo.""" + + GLOO_COMMIT = "5354032ea08eadd7fc4456477f7f7c6308818509" + GLOO_SHA256 = "5759a06e6c8863c58e8ceadeb56f7c701fec89b2559ba33a103a447207bf69c7" + + tf_http_archive( + name = "gloo", + sha256 = GLOO_SHA256, + strip_prefix = "gloo-{commit}".format(commit = GLOO_COMMIT), + urls = tf_mirror_urls("https://github.com/facebookincubator/gloo/archive/{commit}.tar.gz".format(commit = GLOO_COMMIT)), + build_file = "//third_party/gloo:gloo.BUILD", + ) diff --git a/third_party/implib_so/BUILD b/third_party/implib_so/BUILD new file mode 100644 index 0000000000000..ca6976cd8d342 --- /dev/null +++ b/third_party/implib_so/BUILD @@ -0,0 +1,21 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT + +py_binary( + name = "get_symbols", + srcs = ["get_symbols.py"], + deps = [ + "@bazel_tools//tools/python/runfiles", + "@implib_so//:implib_gen_lib", + ], +) + +py_binary( + name = "make_stub", + srcs = ["make_stub.py"], + deps = [ + "@bazel_tools//tools/python/runfiles", + "@implib_so//:implib_gen_lib", + ], +) diff --git a/third_party/implib_so/get_symbols.py b/third_party/implib_so/get_symbols.py new file mode 100644 index 0000000000000..c21d2bbbd0cde --- /dev/null +++ b/third_party/implib_so/get_symbols.py @@ -0,0 +1,38 @@ +"""Given a .so file, lists symbols that should be included in a stub. + +Example usage: +$ bazel run -c opt @tsl//third_party/implib_so:get_symbols +/usr/local/cuda/lib64/libcudart.so > third_party/tsl/tsl/cuda/cudart.symbols +""" + +import argparse +import importlib + +# We can't import implib-gen directly because it has a dash in its name. +implib = importlib.import_module('implib-gen') + + +def _is_exported_function(s): + return ( + s['Bind'] != 'LOCAL' + and s['Type'] == 'FUNC' + and s['Ndx'] != 'UND' + and s['Name'] not in ['', '_init', '_fini'] + and s['Default'] + ) + + +def main(): + parser = argparse.ArgumentParser( + description='Extracts a list of symbols from a shared library' + ) + parser.add_argument('library', help='Path to the .so file.') + args = parser.parse_args() + syms = implib.collect_syms(args.library) + funs = [s['Name'] for s in syms if _is_exported_function(s)] + for f in sorted(funs): + print(f) + + +if __name__ == '__main__': + main() diff --git a/third_party/implib_so/implib_so.BUILD b/third_party/implib_so/implib_so.BUILD new file mode 100644 index 0000000000000..bbfb2898eb12d --- /dev/null +++ b/third_party/implib_so/implib_so.BUILD @@ -0,0 +1,20 @@ +# Description: +# Implib.so is a simple equivalent of Windows DLL import libraries for POSIX +# shared libraries. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT + +exports_files([ + "LICENSE.txt", +]) + +py_library( + name = "implib_gen_lib", + srcs = ["implib-gen.py"], + data = glob([ + "arch/**/*.S.tpl", + "arch/**/*.ini", + ]), +) diff --git a/third_party/implib_so/make_stub.py b/third_party/implib_so/make_stub.py new file mode 100644 index 0000000000000..f0e1fe564c0c1 --- /dev/null +++ b/third_party/implib_so/make_stub.py @@ -0,0 +1,68 @@ +"""Given a list of symbols, generates a stub.""" + +import argparse +import configparser +import os +import string + +from bazel_tools.tools.python.runfiles import runfiles + +r = runfiles.Create() + + +def main(): + parser = argparse.ArgumentParser( + description='Generates stubs for CUDA libraries.' + ) + parser.add_argument('symbols', help='File containing a list of symbols.') + parser.add_argument( + '--outdir', '-o', help='Path to create wrapper at', default='.' + ) + parser.add_argument( + '--target', + help='Target platform name, e.g. x86_64, aarch64.', + required=True, + ) + args = parser.parse_args() + + config_path = r.Rlocation(f'implib_so/arch/{args.target}/config.ini') + table_path = r.Rlocation(f'implib_so/arch/{args.target}/table.S.tpl') + trampoline_path = r.Rlocation( + f'implib_so/arch/{args.target}/trampoline.S.tpl' + ) + + cfg = configparser.ConfigParser(inline_comment_prefixes=';') + cfg.read(config_path) + ptr_size = int(cfg['Arch']['PointerSize']) + + with open(args.symbols, 'r') as f: + funs = [s.strip() for s in f.readlines()] + + # Generate assembly code, containing a table for the resolved symbols and the + # trampolines. + lib_name, _ = os.path.splitext(os.path.basename(args.symbols)) + + with open(os.path.join(args.outdir, f'{lib_name}.tramp.S'), 'w') as f: + with open(table_path, 'r') as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_name, table_size=ptr_size * (len(funs) + 1) + ) + f.write(table_text) + + with open(trampoline_path, 'r') as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_name, sym=name, offset=i * ptr_size, number=i + ) + f.write(tramp_text) + + # Generates a list of symbols, formatted as a list of C++ strings. + with open(os.path.join(args.outdir, f'{lib_name}.inc'), 'w') as f: + sym_names = ''.join(f' "{name}",\n' for name in funs) + f.write(sym_names) + + +if __name__ == '__main__': + main() diff --git a/third_party/implib_so/workspace.bzl b/third_party/implib_so/workspace.bzl new file mode 100644 index 0000000000000..37f36cc135fd6 --- /dev/null +++ b/third_party/implib_so/workspace.bzl @@ -0,0 +1,13 @@ +"""Implib.so is a simple equivalent of Windows DLL import libraries for POSIX +shared libraries.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "implib_so", + strip_prefix = "Implib.so-2cce6cab8ff2c15f9da858ea0b68646a8d62aef2", + sha256 = "4ef3089969d57a5b60bb41b8212c478eaa15c56941f86d4bf5e7f98a3afd24e8", + urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/2cce6cab8ff2c15f9da858ea0b68646a8d62aef2.tar.gz"), + build_file = "//third_party/implib_so:implib_so.BUILD", + ) diff --git a/third_party/llvm/build.patch b/third_party/llvm/build.patch index bbf8f587acada..479e08cde869a 100644 --- a/third_party/llvm/build.patch +++ b/third_party/llvm/build.patch @@ -1,8 +1,8 @@ diff --git a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -index 2b88729d748b..e12d979b4908 100644 +index 7770284e5543..0b45127495dc 100644 --- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -@@ -207,13 +207,15 @@ cc_library( +@@ -218,13 +218,15 @@ cc_library( "lib/Support/BLAKE3/llvm_blake3_prefix.h", ] + select({ "@platforms//cpu:aarch64": [ @@ -23,7 +23,7 @@ index 2b88729d748b..e12d979b4908 100644 ], "//conditions:default": [ ], -@@ -238,14 +240,16 @@ cc_library( +@@ -249,14 +251,16 @@ cc_library( ], copts = llvm_copts, defines = select({ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index ce1937af46e5d..509398da979e8 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,12 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -594,6 +594,7 @@ - name = "__support_bit", - hdrs = ["src/__support/bit.h"], - deps = [ -+ ":__support_cpp_type_traits", - ":__support_macros_attributes", - ], - ) diff --git a/third_party/llvm/toolchains.patch b/third_party/llvm/toolchains.patch index dc45d4d4987dc..a4de4eaaff343 100644 --- a/third_party/llvm/toolchains.patch +++ b/third_party/llvm/toolchains.patch @@ -34,12 +34,12 @@ index c43ab727e285..7d848d2dffae 100644 # The necessary warnings and other compile flags should be provided by the # toolchain or the `.bazelrc` file. This is just a workaround until we have a diff --git a/utils/bazel/llvm-project-overlay/llvm/config.bzl b/utils/bazel/llvm-project-overlay/llvm/config.bzl -index b15ec9e1bb39..56c2766872fa 100644 +index 2e3bff53ead9..8d01617effdc 100644 --- a/utils/bazel/llvm-project-overlay/llvm/config.bzl +++ b/utils/bazel/llvm-project-overlay/llvm/config.bzl -@@ -89,8 +89,9 @@ os_defines = select({ +@@ -98,8 +98,9 @@ builtin_thread_pointer = select({ # TODO: We should split out host vs. target here. - llvm_config_defines = os_defines + select({ + llvm_config_defines = os_defines + builtin_thread_pointer + select({ "@bazel_tools//src/conditions:windows": native_arch_defines("X86", "x86_64-pc-win32"), - "@bazel_tools//src/conditions:darwin_arm64": native_arch_defines("AArch64", "arm64-apple-darwin"), - "@bazel_tools//src/conditions:darwin_x86_64": native_arch_defines("X86", "x86_64-unknown-darwin"), diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 62f02cded785c..3bcd8e242c18f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f688e0901213726feb9b26cedc61919413cbf59c" - LLVM_SHA256 = "b8885c22a9b77f9c91a316b21d71414a7b48dae38513f170da1554002e85b030" + LLVM_COMMIT = "8ee6ab7f69ca9c34eed56faad3971d075dc47121" + LLVM_SHA256 = "c408d2a80a53057fc3596cbfbea3ec64fe8feccbff7adc5e94fde192b96ca568" tf_http_archive( name = name, diff --git a/third_party/llvm_openmp/BUILD b/third_party/llvm_openmp/BUILD index fcba6b9bb2527..52d2e3aa4b611 100644 --- a/third_party/llvm_openmp/BUILD +++ b/third_party/llvm_openmp/BUILD @@ -1,22 +1,22 @@ # Build file for OpenMP library that is part of llvm +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( - "@tsl//tsl:tsl.bzl", - "if_linux_x86_64", - "if_macos", - "if_windows", -) -load( - "@xla//third_party/llvm_openmp:cmake_vars.bzl", + "@tsl//third_party/llvm_openmp:cmake_vars.bzl", "cmake_var_string", "expand_cmake_vars", ) load( - "@xla//third_party/llvm_openmp:openmp.bzl", + "@tsl//third_party/llvm_openmp:openmp.bzl", "dict_add", "libiomp5_cc_binary", ) -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load( + "@tsl//tsl:tsl.bzl", + "if_linux_x86_64", + "if_macos", + "if_windows", +) package( default_visibility = [ diff --git a/third_party/llvm_openmp/cmake_vars.bzl b/third_party/llvm_openmp/cmake_vars.bzl index b228f9449aeb8..6772b5d5ce8c5 100644 --- a/third_party/llvm_openmp/cmake_vars.bzl +++ b/third_party/llvm_openmp/cmake_vars.bzl @@ -46,7 +46,7 @@ def expand_cmake_vars(name, src, dst, cmake_vars): cmake_vars: a string containing the CMake variables, as generated by cmake_var_string. """ - expand_cmake_vars_tool = "@xla//third_party/llvm_openmp:expand_cmake_vars" + expand_cmake_vars_tool = "@tsl//third_party/llvm_openmp:expand_cmake_vars" native.genrule( name = name, srcs = [src], diff --git a/third_party/mpitrampoline/BUILD b/third_party/mpitrampoline/BUILD new file mode 100644 index 0000000000000..3c413807167ae --- /dev/null +++ b/third_party/mpitrampoline/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/mpitrampoline/gen.patch b/third_party/mpitrampoline/gen.patch new file mode 100644 index 0000000000000..35124db0abb1e --- /dev/null +++ b/third_party/mpitrampoline/gen.patch @@ -0,0 +1,149 @@ +diff --git a/gen/gen_decl.py b/gen/gen_decl.py +index 1005b95..696b4e0 100755 +--- a/gen/gen_decl.py ++++ b/gen/gen_decl.py +@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False +@@ -24,7 +24,7 @@ def wrap(line): + lines.append(line) + return "\n".join(lines) + +-with open("include/mpi_decl_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Declare C MPI constants\n") + file.write("\n") + for (tp, nm) in constants: +@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs)) + +-with open("include/mpi_decl_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Declare C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("include/mpi_decl_constants_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file: + file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl))) + file.write("\n") + +-with open("include/mpi_decl_functions_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_defn.py b/gen/gen_defn.py +index bf31f35..318222e 100755 +--- a/gen/gen_defn.py ++++ b/gen/gen_defn.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_defn_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Define C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Define C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("src/mpi_defn_constants_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file: + # Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes + file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_init.py b/gen/gen_init.py +index 4939261..0e52822 100755 +--- a/gen/gen_init.py ++++ b/gen/gen_init.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_init_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Initialize C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)} + file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Initialize C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file: + subs['anm{0}'.format(i)] = anm + file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_constants_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"} + file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: diff --git a/third_party/mpitrampoline/mpitrampoline.BUILD b/third_party/mpitrampoline/mpitrampoline.BUILD new file mode 100644 index 0000000000000..20c5514b164e7 --- /dev/null +++ b/third_party/mpitrampoline/mpitrampoline.BUILD @@ -0,0 +1,135 @@ +# Description: +# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@xla//xla:strict.default.bzl", "py_strict_binary") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE.md"]) + +genrule( + name = "mpi_version", + srcs = [ + "CMakeLists.txt", + "include/mpi_version.h.in", + ], + outs = ["include/mpi_version.h"], + cmd = """ + PROJECT_VERSION=`cat $(location CMakeLists.txt) \ + | grep "MPItrampoline VERSION" | awk '{print $$NF}'` + PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1` + PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2` + PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3` + sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \ + -e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \ + -e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \ + -e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \ + $(location include/mpi_version.h.in) > $(location include/mpi_version.h) + """, +) + +expand_template( + name = "mpi_defaults", + out = "src/mpi_defaults.h", + substitutions = { + "@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "", + "@MPITRAMPOLINE_DEFAULT_LIB@": "", + "@MPITRAMPOLINE_DEFAULT_PRELOAD@": "", + "@MPITRAMPOLINE_DEFAULT_VERBOSE@": "", + }, + template = "src/mpi_defaults.h.in", +) + +py_strict_binary( + name = "gen_decl", + srcs = [ + "gen/gen_decl.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "decl", + outs = [ + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + ], + cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \ + $(location include/mpi_decl_functions_c.h)", + tools = [":gen_decl"], +) + +py_strict_binary( + name = "gen_defn", + srcs = [ + "gen/gen_defn.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "defn", + outs = [ + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + ], + cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \ + $(location include/mpi_defn_functions_c.h)", + tools = [":gen_defn"], +) + +py_strict_binary( + name = "gen_init", + srcs = [ + "gen/gen_init.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "init", + outs = [ + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + ], + cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \ + $(location include/mpi_init_functions_c.h)", + tools = [":gen_init"], +) + +cc_library( + name = "mpitrampoline", + srcs = [ + "src/mpi.c", + ], + hdrs = [ + "include/mpi.h", + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + "include/mpi_version.h", + "mpiabi/mpiabi.h", + "src/mpi_defaults.h", + ], + copts = [ + "-fexceptions", + ], + includes = [ + "include", + "mpiabi", + "src", + ], +) diff --git a/third_party/mpitrampoline/workspace.bzl b/third_party/mpitrampoline/workspace.bzl new file mode 100644 index 0000000000000..4748931ae6e36 --- /dev/null +++ b/third_party/mpitrampoline/workspace.bzl @@ -0,0 +1,18 @@ +"""Provides the repository macro to import mpitrampoline.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports mpitrampoline.""" + + MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed" + MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84" + + tf_http_archive( + name = "mpitrampoline", + sha256 = MPITRAMPOLINE_SHA256, + strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT), + urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)), + patch_file = ["//third_party/mpitrampoline:gen.patch"], + build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD", + ) diff --git a/third_party/nanobind/BUILD b/third_party/nanobind/BUILD new file mode 100644 index 0000000000000..3c413807167ae --- /dev/null +++ b/third_party/nanobind/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/nanobind/nanobind.BUILD b/third_party/nanobind/nanobind.BUILD new file mode 100644 index 0000000000000..cfbf8fb993c25 --- /dev/null +++ b/third_party/nanobind/nanobind.BUILD @@ -0,0 +1,26 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "nanobind", + srcs = glob([ + "src/*.cpp", + ]), + copts = ["-fexceptions"], + defines = [ + "NB_BUILD=1", + "NB_SHARED=1", + ], + includes = ["include"], + textual_hdrs = glob( + [ + "include/**/*.h", + "src/*.h", + ], + ), + deps = [ + "@robin_map", + "@tsl//third_party/python_runtime:headers", + ], +) diff --git a/third_party/nanobind/pr438.patch b/third_party/nanobind/pr438.patch new file mode 100644 index 0000000000000..edb7d61700e03 --- /dev/null +++ b/third_party/nanobind/pr438.patch @@ -0,0 +1,51 @@ +diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp +index 86f64d1..91f3932 100644 +--- a/src/nb_enum.cpp ++++ b/src/nb_enum.cpp +@@ -73,6 +73,13 @@ static PyObject *nb_enum_get_doc(PyObject *self, void *) { + return result; + } + ++static PyObject *nb_enum_get_value(PyObject *self, void *) { ++ enum_supplement &supp = nb_enum_supplement(Py_TYPE(self)); ++ return supp.is_signed ? nb_enum_int_signed(self) ++ : nb_enum_int_unsigned(self); ++} ++ ++ + NB_NOINLINE static PyObject *nb_enum_int_signed(PyObject *o) { + type_data *t = nb_type_data(Py_TYPE(o)); + const void *p = inst_ptr((nb_inst *) o); +@@ -141,6 +148,8 @@ error: + static PyGetSetDef nb_enum_getset[] = { + { "__doc__", nb_enum_get_doc, nullptr, nullptr, nullptr }, + { "__name__", nb_enum_get_name, nullptr, nullptr, nullptr }, ++ { "name", nb_enum_get_name, nullptr, nullptr, nullptr }, ++ { "value", nb_enum_get_value, nullptr, nullptr, nullptr }, + { nullptr, nullptr, nullptr, nullptr, nullptr } + }; + +diff --git a/tests/test_enum.py b/tests/test_enum.py +index 2a6e9ff..1063eef 100644 +--- a/tests/test_enum.py ++++ b/tests/test_enum.py +@@ -14,6 +14,9 @@ def test01_unsigned_enum(): + assert int(t.Enum.A) == 0 + assert int(t.Enum.B) == 1 + assert int(t.Enum.C) == 0xffffffff ++ assert t.Enum.A.value == 0 ++ assert t.Enum.B.value == 1 ++ assert t.Enum.C.value == 0xffffffff + assert t.Enum(0) is t.Enum.A + assert t.Enum(1) is t.Enum.B + assert t.Enum(0xffffffff) is t.Enum.C +@@ -48,6 +51,9 @@ def test02_signed_enum(): + assert int(t.SEnum.A) == 0 + assert int(t.SEnum.B) == 1 + assert int(t.SEnum.C) == -1 ++ assert t.SEnum.A.value == 0 ++ assert t.SEnum.B.value == 1 ++ assert t.SEnum.C.value == -1 + assert t.SEnum(0) is t.SEnum.A + assert t.SEnum(1) is t.SEnum.B + assert t.SEnum(-1) is t.SEnum.C \ No newline at end of file diff --git a/third_party/nanobind/pr461.patch b/third_party/nanobind/pr461.patch new file mode 100644 index 0000000000000..aa0a51b68175a --- /dev/null +++ b/third_party/nanobind/pr461.patch @@ -0,0 +1,39 @@ +diff --git a/src/nb_type.cpp b/src/nb_type.cpp +--- a/src/nb_type.cpp ++++ b/src/nb_type.cpp +@@ -36,6 +36,11 @@ static PyObject **nb_weaklist_ptr(PyObje + return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; + } + ++static PyGetSetDef inst_getset[] = { ++ { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr }, ++ { nullptr, nullptr, nullptr, nullptr, nullptr } ++}; ++ + static int inst_clear(PyObject *self) { + PyObject **dict = nb_dict_ptr(self); + if (dict) +@@ -923,8 +928,11 @@ PyObject *nb_type_new(const type_init_da + } + + bool has_traverse = false; +- for (PyType_Slot *ts = slots; ts != s; ++ts) ++ bool has_getset = false; ++ for (PyType_Slot *ts = slots; ts != s; ++ts) { + has_traverse |= ts->slot == Py_tp_traverse; ++ has_getset |= ts->slot == Py_tp_getset; ++ } + + Py_ssize_t dictoffset = 0, weaklistoffset = 0; + int num_members = 0; +@@ -948,6 +956,10 @@ PyObject *nb_type_new(const type_init_da + has_traverse = true; + } + spec.basicsize = (int) basicsize; ++ ++ if (!has_getset) { ++ *s++ = { Py_tp_getset, (void *) inst_getset }; ++ } + } + + if (is_weak_referenceable) { diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl new file mode 100644 index 0000000000000..9f9022dbaa8d1 --- /dev/null +++ b/third_party/nanobind/workspace.bzl @@ -0,0 +1,16 @@ +"""Loads the nanobind library.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "nanobind", + strip_prefix = "nanobind-1.9.2", + sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), + build_file = "//third_party/nanobind:nanobind.BUILD", + patch_file = [ + "//third_party/nanobind:pr438.patch", # Remove when updating to nanobind 2.0.0. + "//third_party/nanobind:pr461.patch", # Remove when updating to nanobind 2.0.0. + ], + ) diff --git a/third_party/py/ml_dtypes/ml_dtypes.BUILD b/third_party/py/ml_dtypes/ml_dtypes.BUILD index c5f1bc01fa8c8..0eb4dfce866ca 100644 --- a/third_party/py/ml_dtypes/ml_dtypes.BUILD +++ b/third_party/py/ml_dtypes/ml_dtypes.BUILD @@ -49,7 +49,7 @@ pybind_extension( ":float8", ":int4", "@eigen_archive//:eigen3", - "@xla//third_party/py/numpy:headers", + "@tsl//third_party/py/numpy:headers", ], ) diff --git a/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD index a0f7fc39f88f6..fd86cd82a035e 100644 --- a/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD +++ b/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD @@ -8,7 +8,7 @@ py_library( "//:ml_dtypes", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", - "@xla//third_party/py/numpy", + "@tsl//third_party/py/numpy", ], ) diff --git a/third_party/robin_map/BUILD b/third_party/robin_map/BUILD new file mode 100644 index 0000000000000..3c413807167ae --- /dev/null +++ b/third_party/robin_map/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/robin_map/robin_map.BUILD b/third_party/robin_map/robin_map.BUILD new file mode 100644 index 0000000000000..b649dda317665 --- /dev/null +++ b/third_party/robin_map/robin_map.BUILD @@ -0,0 +1,17 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "robin_map", + hdrs = [ + "include/tsl/robin_growth_policy.h", + "include/tsl/robin_hash.h", + "include/tsl/robin_map.h", + "include/tsl/robin_set.h", + ], + copts = ["-fexceptions"], + features = ["-use_header_modules"], # Incompatible with -fexceptions. + includes = ["."], + strip_include_prefix = "include", +) diff --git a/third_party/robin_map/workspace.bzl b/third_party/robin_map/workspace.bzl new file mode 100644 index 0000000000000..397becb29c86b --- /dev/null +++ b/third_party/robin_map/workspace.bzl @@ -0,0 +1,12 @@ +"""Loads the robin_map library.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + tf_http_archive( + name = "robin_map", + strip_prefix = "robin-map-1.2.1", + sha256 = "2b54d2c1de2f73bea5c51d5dcbd64813a08caf1bfddcfdeee40ab74e9599e8e3", + urls = tf_mirror_urls("https://github.com/Tessil/robin-map/archive/refs/tags/v1.2.1.tar.gz"), + build_file = "//third_party/robin_map:robin_map.BUILD", + ) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch old mode 100644 new mode 100755 index 8abb8b476d8c9..f29517482dd96 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,7 +1,7 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt -@@ -13,135 +13,20 @@ +@@ -13,153 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # @@ -25,6 +25,11 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -if(POLICY CMP0116) - cmake_policy(SET CMP0116 OLD) -endif() +- +-# Support for return(PROPAGATE ...) in functions. +-if (POLICY CMP0140) +- cmake_policy(SET CMP0140 NEW) +-endif() +# This build of StableHLO is meant to be embedded in MLIR-HLO. +# As a result, its root CMakeLists.txt is different from the original +# CMakeLists.txt from https://github.com/openxla/stablehlo. @@ -39,6 +44,9 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF) -option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF) -option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF) +-option(STABLEHLO_ENABLE_SANITIZER "Enable a sanitizer [OFF, address]" OFF) +-option(STABLEHLO_ENABLE_SPLIT_DWARF "Enable split DWARF if the platform supports it" OFF) +-option(STABLEHLO_ENABLE_LLD "Use LLD as the linker if available" OFF) -#------------------------------------------------------------------------------- -# Project setup and globals @@ -55,29 +63,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - set(CMAKE_CXX_STANDARD 17) -endif() - --# Build with ccache if the package is present --set(LLVM_CCACHE_BUILD OFF CACHE BOOL "Set to ON for a ccache enabled build") --if(LLVM_CCACHE_BUILD) -- find_program(CCACHE_PROGRAM ccache) -- if(CCACHE_PROGRAM) -- set(LLVM_CCACHE_MAXSIZE "" CACHE STRING "Size of ccache") -- set(LLVM_CCACHE_DIR "" CACHE STRING "Directory to keep ccached data") -- set(LLVM_CCACHE_PARAMS "CCACHE_CPP2=yes CCACHE_HASHDIR=yes" -- CACHE STRING "Parameters to pass through to ccache") -- -- set(CCACHE_PROGRAM "${LLVM_CCACHE_PARAMS} ${CCACHE_PROGRAM}") -- if (LLVM_CCACHE_MAXSIZE) -- set(CCACHE_PROGRAM "CCACHE_MAXSIZE=${LLVM_CCACHE_MAXSIZE} ${CCACHE_PROGRAM}") -- endif() -- if (LLVM_CCACHE_DIR) -- set(CCACHE_PROGRAM "CCACHE_DIR=${LLVM_CCACHE_DIR} ${CCACHE_PROGRAM}") -- endif() -- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PROGRAM}) -- else() -- message(FATAL_ERROR "Unable to find the program ccache. Set LLVM_CCACHE_BUILD to OFF") -- endif() --endif() -- -#------------------------------------------------------------------------------- -# MLIR/LLVM Configuration -#------------------------------------------------------------------------------- @@ -114,10 +99,39 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt - message(STATUS "Building StableHLO embedded in another project") -endif() - +-# Add the CMake modules specific to StableHLO +-list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") +- -if(LLVM_ENABLE_ZLIB) - find_package(ZLIB) -endif() - +-#------------------------------------------------------------------------------- +-# Performance configuration +-#------------------------------------------------------------------------------- +- +-include(CheckCXXCompilerFlag) +-include(CheckLinkerFlag) +-if (STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling LLD as the linker") +- add_link_options("-fuse-ld=lld") +-endif() +- +-if(STABLEHLO_ENABLE_SPLIT_DWARF) +- check_cxx_compiler_flag(-gsplit-dwarf STABLEHLO_SUPPORTS_SPLIT_DWARF) +- if (STABLEHLO_SUPPORTS_SPLIT_DWARF) +- message(STATUS "Enabling split-dwarf build") +- add_compile_options(-gsplit-dwarf -ggnu-pubnames) +- endif() +- check_linker_flag(CXX "-Wl,--gdb-index" STABLEHLO_SUPPORTS_GDB_INDEX) +- # If we set LLD it doesn't seem to affect the check_linker_flag above. +- # Account for it with the generator expression OR +- if (STABLEHLO_SUPPORTS_GDB_INDEX OR STABLEHLO_ENABLE_LLD) +- message(STATUS "Enabling GDB index in binary") +- add_link_options("-Wl,--gdb-index") +- endif() +-endif() +- -include(TableGen) -include(AddLLVM) -include(AddMLIR) @@ -129,15 +143,19 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -link_directories(${LLVM_BUILD_LIBRARY_DIR}) -add_definitions(${LLVM_DEFINITIONS}) - +- +-#------------------------------------------------------------------------------- +-# Sanitizer configuration +-#------------------------------------------------------------------------------- +- +-include(SetupSanitizers) +-setup_sanitizers() +- -#------------------------------------------------------------------------------- -# Python configuration -#------------------------------------------------------------------------------- - -if(STABLEHLO_ENABLE_BINDINGS_PYTHON) -- if(NOT STABLEHLO_EXTERNAL_PROJECT_BUILD) -- message(WARNING "StableHLO Python bindings are not supported in standalone mode") -- endif() -- - include(MLIRDetectPythonEnv) - mlir_configure_python_dev_packages() -endif() @@ -159,7 +177,7 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel -@@ -0,0 +1,113 @@ +@@ -0,0 +1,115 @@ +# Copyright 2023 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -230,6 +248,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/ +cc_library( + name = "experimental_stablehlo_passes", + srcs = [ ++ "transforms/ChloRecomposeOps.cpp", + "transforms/StablehloCanonicalizeDynamism.cpp", + "transforms/StablehloRefineShapes.cpp", + ], @@ -243,6 +262,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/ + "//:chlo_ops", + "//:stablehlo_ops", + "//:stablehlo_ops_inc_gen", ++ "//:stablehlo_passes", + "//:stablehlo_type_inference", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", @@ -426,7 +446,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp --- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp -@@ -0,0 +1,506 @@ +@@ -0,0 +1,504 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -444,6 +464,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + +#include "stablehlo/experimental/dialect/StablehloOps.h" + ++#include +#include + +#include "llvm/ADT/ArrayRef.h" @@ -470,8 +491,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // api_version and backend_config have default values. + // call_target_name should be "stablehlo.dynamic_reduce_window". + // called_computations carries the body. -+ if (attr.getName() != "api_version" && -+ attr.getName() != "backend_config" && ++ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && + attr.getName() != "call_target_name" && + attr.getName() != "called_computations") + return op_.emitError() @@ -852,8 +872,8 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + + // dynamic_top_k_i2 + auto kType = k.getType().dyn_cast(); -+ if (!kType || !kType.hasRank() || -+ kType.getRank() != 0 || !kType.getElementType().isIntOrIndex()) ++ if (!kType || !kType.hasRank() || kType.getRank() != 0 || ++ !kType.getElementType().isIntOrIndex()) + return op_.emitError() + << "expects k (operand #1) " + << "to be a 0-dimensional tensor of integer or index type"; @@ -915,7 +935,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + return op_.getInputs()[1].cast>(); +} + -+ +TypedValue DynamicTopKOpAdaptor::getValues() { + return op_.getResults()[0].cast>(); +} @@ -924,8 +943,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + return op_.getResults()[1].cast>(); +} + -+std::optional getDynamicTopKOp( -+ CustomCallOp op) { ++std::optional getDynamicTopKOp(CustomCallOp op) { + if (op.getCallTargetName() != "stablehlo.dynamic_top_k") return {}; + return DynamicTopKOpAdaptor(op); +} @@ -1263,10 +1281,65 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/CMakeLists.txt b/stablehlo/s + stablehlo-translate +) +add_dependencies(check-stablehlo-quick check-experimental-stablehlo-tests) +diff --ruN a/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir b/stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir +--- stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir ++++ stablehlo/stablehlo/experimental/tests/chlo_recompose_ops.mlir +@@ -0,0 +1,51 @@ ++// RUN: experimental-stablehlo-opt --experimental-chlo-recompose-ops --split-input-file --verify-diagnostics %s | FileCheck %s ++ ++// ----- ++ ++// CHECK-LABEL: func @recompose_topk ++func.func @recompose_topk(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // CHECK: %values, %indices = chlo.top_k(%arg0, k = 4) {largest = true} : tensor<5x16xf32> -> (tensor, tensor) ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @recompose_topk_invalid_attr ++func.func @recompose_topk_invalid_attr(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // CHECK: stablehlo.custom_call @mhlo.topk ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = false} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: @recompose_tan ++func.func @recompose_tan(%arg0: tensor<16xf32>) -> tensor { ++ // CHECK: %0 = chlo.tan %arg0 : tensor<16xf32> -> tensor ++ %0 = "stablehlo.custom_call"(%arg0) { ++ call_target_name = "mhlo.tan", ++ mhlo.attributes = {}, ++ mhlo.version = 1 : i64 ++ } : (tensor<16xf32>) -> tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: @recompose_erf ++func.func @recompose_erf(%arg0: tensor<3x20x20xbf16>) -> tensor { ++ // CHECK: %0 = chlo.erf %arg0 : tensor<3x20x20xbf16> -> tensor ++ %0 = "stablehlo.custom_call"(%arg0) { ++ backend_config = "", ++ call_target_name = "mhlo.erf", ++ mhlo.attributes = {}, ++ mhlo.version = 1 : i64 ++ } : (tensor<3x20x20xbf16>) -> tensor ++ func.return %0 : tensor ++} ++ diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stablehlo/experimental/tests/lit.cfg.py --- stablehlo/stablehlo/experimental/tests/lit.cfg.py +++ stablehlo/stablehlo/experimental/tests/lit.cfg.py -@@ -0,0 +1,42 @@ +@@ -0,0 +1,46 @@ +"""Lit configuration to drive test in this repo.""" +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The StableHLO Authors. @@ -1297,6 +1370,10 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stabl +config.suffixes = ['.mlir'] +config.test_source_root = os.path.dirname(__file__) + ++# Disallow reusing variables across CHECK-LABEL matches. ++# A variable can eschew this (be made "global") by prefixing its name with $. ++config.environment['FILECHECK_OPTS'] = '-enable-var-scope' ++ +# Make LLVM and StableHLO tools available in RUN directives +tools = [ + 'FileCheck', @@ -1348,11 +1425,11 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynam + // CHECK-NEXT: %[[VAL1:.*]] = stablehlo.add %arg2, %arg3 : tensor + // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor + // CHECK-NEXT: }) { -+ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME: base_dilations = array, + // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, -+ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> ++ // CHECK-SAME: window_dilations = array, ++ // CHECK-SAME: window_dimensions = array, ++ // CHECK-SAME: window_strides = array + // CHECK-SAME: } : (tensor<3x2xf32>, tensor) -> tensor<2x2xf32> + %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> + %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> @@ -1380,11 +1457,11 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynam + // CHECK-NEXT: %[[VAL1:.*]] = stablehlo.add %arg2, %arg3 : tensor + // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor + // CHECK-NEXT: }) { -+ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME: base_dilations = array, + // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, -+ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> ++ // CHECK-SAME: window_dilations = array, ++ // CHECK-SAME: window_dimensions = array, ++ // CHECK-SAME: window_strides = array + // CHECK-SAME: } : (tensor, tensor) -> tensor + %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> + %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> @@ -1689,7 +1766,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +// RUN: experimental-stablehlo-opt --experimental-stablehlo-refine-shapes --split-input-file --verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @main -+func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<*xf32> { ++func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor { + // CHECK: stablehlo.dynamic_reduce_window{{.*}} -> tensor<2x2xf32> + %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> + %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> @@ -1698,8 +1775,8 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir + %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> + %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { + called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> -+ func.return %5 : tensor<*xf32> ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor ++ func.return %5 : tensor +} + +func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { @@ -1710,13 +1787,13 @@ diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +// ----- + +// CHECK-LABEL: @refine_dynamic_rng_bit_generator -+func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor<*xf32>) { ++func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor) { + // CHECK: stablehlo.dynamic_rng_bit_generator{{.*}} -> (tensor<2xui64>, tensor<1x4xf32>) + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> + %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { + rng_algorithm = #stablehlo -+ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor, tensor<*xf32>) -+ func.return %1#0, %1#1 : tensor, tensor<*xf32> ++ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor, tensor) ++ func.return %1#0, %1#1 : tensor, tensor +} + +// ----- @@ -1826,7 +1903,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp b/stabl diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt --- stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +++ stablehlo/stablehlo/experimental/transforms/CMakeLists.txt -@@ -0,0 +1,38 @@ +@@ -0,0 +1,40 @@ +# Copyright 2023 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); @@ -1847,6 +1924,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable + +add_mlir_dialect_library(ExperimentalStablehloPasses + PARTIAL_SOURCES_INTENDED ++ ChloRecomposeOps.cpp + StablehloCanonicalizeDynamism.cpp + StablehloRefineShapes.cpp + @@ -1862,13 +1940,196 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stable + MLIRTransformUtils + ExperimentalStablehloOps + StablehloBase -+ StablehloTypeInference + StablehloOps ++ StablehloPasses ++ StablehloTypeInference +) +diff --ruN a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp +--- stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp ++++ stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp +@@ -0,0 +1,178 @@ ++/* Copyright 2024 The StableHLO Authors. ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include ++#include ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/Attributes.h" ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Pass/PassManager.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/ChloOps.h" ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DEF_CHLORECOMPOSEOPSPASS ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++namespace { ++ ++FailureOr getCustomCallOpAttributes(CustomCallOp op, ++ PatternRewriter& rewriter) { ++ auto attrs = op->getAttrOfType("mhlo.attributes"); ++ if (!attrs) ++ return rewriter.notifyMatchFailure( ++ op, "Expected mhlo.attributes dictionary attribute."); ++ return attrs; ++} ++ ++LogicalResult verifyCustomCallOpAttributes( ++ CustomCallOp op, PatternRewriter& rewriter, ++ std::function verifyFn) { ++ auto attrs = getCustomCallOpAttributes(op, rewriter); ++ if (failed(attrs)) return failure(); ++ ++ for (auto attr : attrs->getValue()) { ++ if (failed(verifyFn(attr))) return failure(); ++ } ++ return success(); ++} ++ ++// Experimental and public ops in MHLO that do not exist yet in StableHLO ++// can be encoded as a StableHLO CustomCallOp to allow round-tripping ++// between dialects. Some of these ops are CHLO ops that are accelerated by XLA. ++// For these ops we can recompose to CHLO. ++// ++// Example: ++// %0 = stablehlo.custom_call @mhlo.topk(...) {...} ++// ==> ++// %0 = "chlo.topk"(...) {...} ++template ++LogicalResult recomposeChloOpFromCustomCall(stablehlo::CustomCallOp op, ++ PatternRewriter& rewriter) { ++ // Only call_target_name, backend_config, called_computations, mhlo.version, ++ // and mhlo.attributes are compatible with the extensibility protocol. ++ auto isSupportedAttrName = [](NamedAttribute attr) { ++ auto name = attr.getName(); ++ return name == "call_target_name" || name == "backend_config" || ++ name == "called_computations" || name == "mhlo.attributes" || ++ name == "mhlo.version"; ++ }; ++ if (!llvm::all_of(op->getAttrs(), isSupportedAttrName) || ++ !op.getBackendConfig().empty()) { ++ return rewriter.notifyMatchFailure( ++ op, "CHLO Recompose custom call did not have required attributes."); ++ } ++ if (!op.getCalledComputations().empty()) ++ return rewriter.notifyMatchFailure(op, "Ops with regions not supported."); ++ ++ auto attrs = getCustomCallOpAttributes(op, rewriter); ++ if (failed(attrs)) return failure(); ++ ++ rewriter.replaceOpWithNewOp(op, op->getResultTypes(), ++ op->getOperands(), attrs->getValue()); ++ return success(); ++} ++ ++struct TopKOpRecomposePattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp op, ++ PatternRewriter& rewriter) const override { ++ if (op.getCallTargetName() != "mhlo.topk") return failure(); ++ auto res = verifyCustomCallOpAttributes( ++ op, rewriter, [&](NamedAttribute attr) -> LogicalResult { ++ if (attr.getName() != "largest") return success(); ++ if (attr.getValue().cast().getValue() == false) ++ return rewriter.notifyMatchFailure( ++ op, "largest = false is not supported."); ++ return success(); ++ }); ++ if (failed(res)) return failure(); ++ return recomposeChloOpFromCustomCall(op, rewriter); ++ } ++}; ++ ++struct TanOpRecomposePattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp op, ++ PatternRewriter& rewriter) const override { ++ if (op.getCallTargetName() != "mhlo.tan") return failure(); ++ return recomposeChloOpFromCustomCall(op, rewriter); ++ } ++}; ++ ++struct ErfOpRecomposePattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp op, ++ PatternRewriter& rewriter) const override { ++ if (op.getCallTargetName() != "mhlo.erf") return failure(); ++ return recomposeChloOpFromCustomCall(op, rewriter); ++ } ++}; ++ ++} // namespace ++ ++struct ChloRecomposeOpsPass ++ : public impl::ChloRecomposeOpsPassBase { ++ using ChloRecomposeOpsPassBase::ChloRecomposeOpsPassBase; ++ ++ void runOnOperation() override { ++ // Do a single traversal to recompose CustomCallOp to CHLO ops. ++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 1; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::ExistingOps; ++ ++ RewritePatternSet patterns(&getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ ++ // Only apply to CustomCallOps ++ auto moduleOp = getOperation(); ++ llvm::SmallVector candidateOps; ++ moduleOp.walk([&](CustomCallOp op) { candidateOps.push_back(op); }); ++ ++ if (failed(applyOpPatternsAndFold(candidateOps, std::move(patterns), ++ config))) { ++ moduleOp.emitError("Failed to converge ChloRecomposeOps in ") ++ << config.maxIterations << " iterations"; ++ return signalPassFailure(); ++ } ++ } ++}; ++ ++void createChloLegalizeToStablehloPipeline(OpPassManager& pm) { ++ pm.addPass(mlir::stablehlo::experimental::createChloRecomposeOpsPass()); ++ pm.addNestedPass( ++ mlir::stablehlo::createChloLegalizeToStablehloPass()); ++ pm.addNestedPass( ++ mlir::stablehlo::createShapeLegalizeToStablehloPass()); ++} ++ ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h --- stablehlo/stablehlo/experimental/transforms/Passes.h +++ stablehlo/stablehlo/experimental/transforms/Passes.h -@@ -0,0 +1,37 @@ +@@ -0,0 +1,38 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -1895,12 +2156,13 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/st +namespace mlir { +namespace stablehlo { +namespace experimental { -+ -+#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS -+#define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS ++ ++#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "stablehlo/experimental/transforms/Passes.h.inc" + ++void createChloLegalizeToStablehloPipeline(OpPassManager &pm); ++ +} // namespace experimental +} // namespace stablehlo +} // namespace mlir @@ -1909,7 +2171,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/st diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/stablehlo/experimental/transforms/Passes.td --- stablehlo/stablehlo/experimental/transforms/Passes.td +++ stablehlo/stablehlo/experimental/transforms/Passes.td -@@ -0,0 +1,31 @@ +@@ -0,0 +1,39 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -1941,10 +2203,18 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/s + Experimental version of the --stablehlo-refine-shapes pass. + }]; +} ++ ++def ChloRecomposeOpsPass : Pass<"experimental-chlo-recompose-ops", "ModuleOp"> { ++ let summary = "(Experimental) Recompose CHLO ops serialized as custom calls."; ++ let description = [{ ++ Experimental version of CHLO serialization support. ++ }]; ++ let dependentDialects = ["chlo::ChloDialect"]; ++} diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp -@@ -0,0 +1,441 @@ +@@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2023 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); @@ -1960,14 +2230,12 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +limitations under the License. +==============================================================================*/ + -+#include "llvm/ADT/DenseSet.h" ++#include ++ +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/Value.h" ++#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -1975,6 +2243,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -1985,169 +2254,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + +namespace { + -+struct CanonicalizeCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ auto indicesAttr = -+ op->getAttr("indices_of_shape_operands").cast(); -+ DenseSet indices(indicesAttr.value_begin(), -+ indicesAttr.value_end()); -+ -+ // Discard the indices_of_shape_operands attribute. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that its value is consistent with the result types. -+ // In the future, when we upgrade indices_of_shape_operands from an -+ // experiment to a full-fledged StableHLO feature, this logic will be moved -+ // to a proper verifier. -+ SmallVector newAttrs; -+ for (auto attr : op->getAttrs()) { -+ if (attr.getName() == "indices_of_shape_operands") continue; -+ if (attr.getName() == "operand_layouts") { -+ // Drop the operand_layouts that correspond to indices_of_shape_operands -+ ArrayAttr operandLayouts = op.getOperandLayoutsAttr(); -+ SmallVector newOperandLayouts; -+ for (unsigned i = 0; i < operandLayouts.size(); ++i) { -+ if (indices.contains(i)) continue; -+ newOperandLayouts.push_back(operandLayouts[i]); -+ } -+ attr = NamedAttribute(attr.getName(), -+ rewriter.getArrayAttr(newOperandLayouts)); -+ } -+ newAttrs.push_back(attr); -+ } -+ -+ // Discard the operands that correspond to indices_of_shape_operands. -+ // We rely on the verification logic implemented in getShapeRefinements to -+ // make sure that: 1) these operands are static, 2) the values of these -+ // operands are consistent with the result types. -+ SmallVector newOperands; -+ auto resultIndex = 0; -+ for (auto& operand : op->getOpOperands()) { -+ if (indices.contains(operand.getOperandNumber())) { -+ auto resultType = -+ op->getResult(resultIndex).getType().dyn_cast(); -+ if (!resultType || !resultType.hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, -+ "expected static result types"); -+ ++resultIndex; -+ continue; -+ } -+ newOperands.push_back(operand.get()); -+ } -+ rewriter.replaceOpWithNewOp(op, op.getResultTypes(), -+ newOperands, newAttrs); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_dimensions operand as well as the -+ // known_expanding_dimensions and known_nonexpanding_dimensions attributes. -+ // We rely on the verifier to make sure that their values are consistent -+ // with the result type. -+ if (!op.getOperand().getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static operand type"); -+ if (!succeeded(hlo::matchInts(op.getOutputDimensions()))) -+ return rewriter.notifyMatchFailure(op, -+ "expected static output_dimensions"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getBroadcastDimensions()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicConvOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ // ConvolutionOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector padding; -+ if (!succeeded(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected static padding"); -+ auto paddingAttr = DenseIntElementsAttr::get( -+ RankedTensorType::get({static_cast(padding.size()) / 2, 2}, -+ rewriter.getI64Type()), -+ padding); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getLhs(), op.getRhs(), op.getWindowStridesAttr(), -+ paddingAttr, op.getLhsDilationAttr(), op.getRhsDilationAttr(), -+ op.getWindowReversalAttr(), op.getDimensionNumbers(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfigAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicGatherOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicGatherOp op, -+ PatternRewriter& rewriter) const override { -+ // GatherOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector sliceSizes; -+ if (!succeeded(hlo::matchInts(op.getSliceSizes(), sliceSizes))) -+ return rewriter.notifyMatchFailure(op, "expected static slice_sizes"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getStartIndices(), -+ op.getDimensionNumbersAttr(), rewriter.getI64TensorAttr(sliceSizes), -+ op.getIndicesAreSortedAttr()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicIotaOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern discards the output_shape operand. We rely on the verifier -+ // to make sure that its value is consistent with result type. -+ SmallVector outputShape; -+ if (!succeeded(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), -+ op.getIotaDimension()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // PadOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, "expected static low"); -+ if (!succeeded(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, "expected static high"); -+ if (!succeeded(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, "expected static interior"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), op.getPaddingValue(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding)); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2175,10 +2281,10 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + return rewriter.notifyMatchFailure(op, "expected static padding"); + auto newOp = rewriter.create( + op->getLoc(), op->getResultTypes(), op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), ++ rewriter.getDenseI64ArrayAttr(windowDimensions), ++ rewriter.getDenseI64ArrayAttr(windowStrides), ++ rewriter.getDenseI64ArrayAttr(baseDilations), ++ rewriter.getDenseI64ArrayAttr(windowDilations), + hlo::getPaddingAttr(&rewriter, padding)); + + // Inline the called computation into newOp. @@ -2196,22 +2302,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ // This pattern ignores and discards the output_shape operand. We rely on -+ // the verifier to make sure that its value is consistent with result type. -+ if (!succeeded(hlo::matchInts(op.getOutputShape()))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getType().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static result type"); -+ rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; -+ +struct CanonicalizeDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -2262,91 +2352,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + } +}; + -+struct CanonicalizeRealDynamicSliceOpToDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // DynamicSliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ -+ // This rewrite only works for unit strides because DynamicSliceOp -+ // doesn't support strides (i.e. it implicitly has unit strides). -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ if (!llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // Check that slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (!matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && -+ !matchPattern(op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) -+ return rewriter.notifyMatchFailure( -+ op, "expected limit indices equal to start indices plus constant"); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. -+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with DynamicSliceOp. -+ SmallVector startIndices; -+ for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { -+ auto startIndexElementType = -+ op.getStartIndices().getType().getElementType(); -+ auto startIndex1DType = RankedTensorType::get({1}, startIndexElementType); -+ auto startIndex1D = rewriter.create( -+ op.getLoc(), startIndex1DType, op.getStartIndices(), -+ rewriter.getI64TensorAttr(i), rewriter.getI64TensorAttr(i + 1), -+ rewriter.getI64TensorAttr(1)); -+ auto startIndex0DType = RankedTensorType::get({}, startIndexElementType); -+ auto startIndex0D = rewriter.create( -+ op.getLoc(), startIndex0DType, startIndex1D); -+ startIndices.push_back(startIndex0D); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), startIndices, -+ rewriter.getI64TensorAttr(sliceSizes)); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeRealDynamicSliceOpToSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // SliceOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector startIndices, limitIndices, strides; -+ if (!succeeded(hlo::matchInts(op.getStartIndices(), startIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static start"); -+ if (!succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices))) -+ return rewriter.notifyMatchFailure(op, "expected static limit"); -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides))) -+ return rewriter.notifyMatchFailure(op, "expected static strides"); -+ rewriter.replaceOpWithNewOp( -+ op, op.getType(), op.getOperand(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides)); -+ return success(); -+ } -+}; -+ +struct StablehloCanonicalizeDynamismPass + : public impl::StablehloCanonicalizeDynamismPassBase< + StablehloCanonicalizeDynamismPass> { @@ -2362,21 +2367,16 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + patterns.add(&getContext()); -+ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ patterns.add( -+ &getContext()); -+ patterns.add(&getContext()); -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), ++ ++ auto funcOp = getOperation(); ++ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), + config))) { ++ funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") ++ << config.maxIterations << " iterations"; + return signalPassFailure(); + } + } @@ -2389,7 +2389,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp --- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp -@@ -0,0 +1,1308 @@ +@@ -0,0 +1,170 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. @@ -2404,41 +2404,22 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +limitations under the License. +==============================================================================*/ + ++#include "stablehlo/transforms/StablehloRefineShapes.h" ++ +#include -+#include -+#include -+#include + -+#include "llvm/ADT/APInt.h" -+#include "llvm/ADT/APSInt.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/STLFunctionalExtras.h" -+#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" -+#include "llvm/ADT/StringRef.h" -+#include "llvm/Support/ErrorHandling.h" -+#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" -+#include "mlir/IR/MLIRContext.h" -+#include "mlir/IR/Matchers.h" -+#include "mlir/IR/OpDefinition.h" -+#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" -+#include "mlir/IR/Types.h" -+#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/Base.h" -+#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/TypeInference.h" +#include "stablehlo/experimental/dialect/StablehloOps.h" +#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { @@ -2449,813 +2430,10 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + +namespace { + -+// DenseElementsAttr can be constructed from ArrayRef but not from -+// ArrayRef. This helper bridges the gap. -+DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { -+ SmallVector supportedValues(values); -+ return DenseIntElementsAttr::get(type, supportedValues); -+} -+ -+APSInt getAPSInt(Type type, uint64_t value) { -+ unsigned numBits; -+ bool isUnsigned; -+ if (auto integerType = type.dyn_cast()) { -+ numBits = integerType.getWidth(); -+ // Signless types are treated as signed, per StableHLO convention. -+ isUnsigned = integerType.isUnsignedInteger(); -+ } else { -+ llvm::report_fatal_error("expected integer type"); -+ } -+ return APSInt({/*numBits=*/numBits, value}, -+ /*isUnsigned=*/isUnsigned); -+} -+ -+// The patterns below implement partial evaluation of shape computations which -+// is a critical part of implementing type refinement for ops like -+// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape -+// depends on the value of their shape operands. -+ -+template -+LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, -+ FuncType fn) { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || -+ !resultType.getElementType().template isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ -+ SmallVector result; -+ if constexpr (OpType::template hasTrait()) { -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ for (const auto& operandEl : operand) { -+ result.push_back(fn(operandEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<2>::Impl>()) { -+ SmallVector lhs, rhs; -+ if (failed(hlo::matchInts(op.getLhs(), lhs)) || -+ failed(hlo::matchInts(op.getRhs(), rhs))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { -+ result.push_back(fn(lhsEl, rhsEl)); -+ } -+ } else if constexpr (OpType::template hasTrait< -+ OpTrait::NOperands<3>::Impl>()) { -+ SmallVector x, y, z; -+ if (failed(hlo::matchInts(op->getOperand(0), x)) || -+ failed(hlo::matchInts(op->getOperand(1), y)) || -+ failed(hlo::matchInts(op->getOperand(2), z))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { -+ result.push_back(fn(xEl, yEl, zEl)); -+ } -+ } else { -+ llvm::report_fatal_error("unsupported number of operands"); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+} -+ -+struct EvalAddOpPattern : public OpRewritePattern { ++struct RefineDynamicReduceWindowOpPattern ++ : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AddOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); -+ } -+}; -+ -+struct EvalAndOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AndOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); -+ }); -+ } -+}; -+ -+struct EvalBroadcastInDimOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank() || operandType.getRank() != 0) -+ return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); -+ -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ auto scalar = operand[0]; -+ -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), scalar)); -+ return success(); -+ } -+}; -+ -+struct EvalClampOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ClampOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt min, APSInt operand, APSInt max) { -+ if (operand < min) return min; -+ if (max < operand) return max; -+ return operand; -+ }); -+ } -+}; -+ -+struct EvalCompareOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CompareOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ bool result; -+ switch (op.getComparisonDirection()) { -+ case ComparisonDirection::EQ: -+ result = lhs == rhs; -+ break; -+ case ComparisonDirection::NE: -+ result = lhs != rhs; -+ break; -+ case ComparisonDirection::GE: -+ result = lhs >= rhs; -+ break; -+ case ComparisonDirection::GT: -+ result = lhs > rhs; -+ break; -+ case ComparisonDirection::LE: -+ result = lhs <= rhs; -+ break; -+ case ComparisonDirection::LT: -+ result = lhs < rhs; -+ break; -+ } -+ return getAPSInt(resultType.getElementType(), result); -+ }); -+ } -+}; -+ -+struct EvalConcatenateOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConcatenateOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || op.getDimension() != 0) -+ return rewriter.notifyMatchFailure(op, "expected dimension = 0"); -+ -+ SmallVector result; -+ for (Value operand : op->getOperands()) { -+ if (failed(hlo::matchInts(operand, result))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; -+ -+struct EvalConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth(); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ return operand.extOrTrunc(resultBitWidth); -+ }); -+ } -+}; -+ -+struct EvalDivOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DivOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); -+ } -+}; -+ -+struct EvalGetDimensionSizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(GetDimensionSizeOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand"); -+ if (operandType.isDynamicDim(op.getDimension())) -+ return rewriter.notifyMatchFailure(op, "expected static dimension"); -+ -+ auto result = operandType.getDimSize(op.getDimension()); -+ rewriter.replaceOpWithNewOp( -+ op, DenseIntElementsAttr::get(op.getType(), result)); -+ return success(); -+ } -+}; -+ -+struct EvalMaxOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MaxOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs >= rhs ? lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMinOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MinOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { -+ return lhs <= rhs ? lhs : rhs; -+ }); -+ } -+}; -+ -+struct EvalMulOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(MulOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); -+ } -+}; -+ -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ -+struct EvalRemOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RemOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); -+ } -+}; -+ -+struct EvalReshapeOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ DenseIntElementsAttr attr; -+ if (!matchPattern(op.getOperand(), m_Constant(&attr))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ rewriter.replaceOpWithNewOp(op, attr.reshape(op.getType())); -+ return success(); -+ } -+}; -+ -+struct EvalSelectOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SelectOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector pred, onTrue, onFalse; -+ if (failed(hlo::matchInts(op.getPred(), pred)) || -+ failed(hlo::matchInts(op.getOnTrue(), onTrue)) || -+ failed(hlo::matchInts(op.getOnFalse(), onFalse))) -+ return rewriter.notifyMatchFailure(op, "expected constant operands"); -+ -+ SmallVector result; -+ for (auto [predEl, onTrueEl, onFalseEl] : -+ llvm::zip(pred, onTrue, onFalse)) { -+ result.push_back(predEl != 0 ? onTrueEl : onFalseEl); -+ } -+ -+ rewriter.replaceOpWithNewOp( -+ op, getTensorAttr(op.getType(), result)); -+ return success(); -+ } -+}; -+ -+struct EvalSignOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SignOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isa()) -+ return rewriter.notifyMatchFailure(op, -+ "expected integer result tensor type"); -+ return evalElementwise(rewriter, op, [&](APSInt operand) { -+ int64_t result; -+ if (operand.isNegative()) -+ result = -1; -+ else if (operand.isZero()) -+ result = 0; -+ else -+ result = 1; -+ return getAPSInt(resultType.getElementType(), result); -+ }); -+ } -+}; -+ -+struct EvalSliceOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SliceOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.hasRank() || resultType.getRank() != 1) -+ return rewriter.notifyMatchFailure(op, "expected 1-dimensional type"); -+ -+ SmallVector operand; -+ if (failed(hlo::matchInts(op.getOperand(), operand))) -+ return rewriter.notifyMatchFailure(op, "expected constant operand"); -+ -+ int64_t start = op.getStartIndices().getValues()[0]; -+ int64_t limit = op.getLimitIndices().getValues()[0]; -+ int64_t stride = op.getStrides().getValues()[0]; -+ SmallVector result; -+ for (auto i = start; i < limit; i += stride) { -+ result.push_back(operand[i]); -+ } -+ -+ rewriter.replaceOpWithNewOp(op, -+ getTensorAttr(resultType, result)); -+ return success(); -+ } -+}; -+ -+struct EvalSubtractOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(SubtractOp op, -+ PatternRewriter& rewriter) const override { -+ return evalElementwise(rewriter, op, -+ [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); -+ } -+}; -+ -+// The patterns below implement shape refinement of individual ops. -+// In a nutshell, they use the upstream type inference infrastructure and a -+// StableHLO-specific extension to refine return types based on potentially -+// refined operands. -+ -+// Refines the values using the given types. -+// Tricky implementation details: -+// 1) Need to support partial shape refinements, e.g. if just a single -+// dimension size out of an entire tensor type got refined. This is done -+// via inferMostSpecificType. -+// 2) Need to signal propagation of the refined shapes across the -+// StableHLO program. Different callers of this function have different -+// propagation needs, so this function doesn't signal anything on its own -+// and leaves that to the callers. -+LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, -+ ValueRange values, TypeRange types) { -+ if (values.size() != types.size()) -+ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { -+ diag << "refineValues failed for " << types << ": expected " -+ << values.size() << " types, got " << types.size(); -+ }); -+ -+ // Check whether `types` contain any new information with respect to existing -+ // return types. Even if just a single dimension size out of an entire tensor -+ // type got updated, using `inferMostSpecificType` ensures that we don't -+ // miss that. -+ bool needsRefinement = false; -+ SmallVector refinedTypes; -+ for (auto it : llvm::zip(values.getTypes(), types)) { -+ // Cannot use structured bindings to simplify this because capturing -+ // structured bindings in a lambda is a C++ 20 extension. -+ auto currentType = std::get<0>(it); -+ auto refinement = std::get<1>(it); -+ auto refinedType = hlo::inferMostSpecificType( -+ /*location=*/{}, {currentType, refinement}); -+ if (failed(refinedType)) -+ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { -+ diag << "inferMostSpecificType failed for " << currentType << " and " -+ << refinement; -+ }); -+ refinedTypes.push_back(*refinedType); -+ needsRefinement |= (currentType != *refinedType); -+ } -+ if (!needsRefinement) -+ return rewriter.notifyMatchFailure(op, "doesn't need refinement"); -+ -+ for (auto it : llvm::zip(values, refinedTypes)) { -+ // Cannot use structured bindings to simplify this because capturing -+ // structured bindings in a lambda is a C++ 20 extension. -+ auto value = std::get<0>(it); -+ auto refinedType = std::get<1>(it); -+ if (value.getType() == refinedType) continue; -+ -+ // Check whether the users of this value are ready for the type of the -+ // value to be refined. -+ for (Operation* user : value.getUsers()) { -+ // CHLO and StableHLO ops are designed to support type refinements of -+ // their operands and results. Any operand type in these ops can change -+ // within what's supported by `inferMostSpecificType` without breaking -+ // verification of the op. -+ if (isa(user->getDialect())) -+ continue; -+ -+ // Simply changing operand type of `func.return` won't work because -+ // that won't update the FunctionType of the enclosing `func.func`. -+ // Nonetheless, we still want to support these ops because they are widely -+ // used in StableHLO programs (although the plan of record is to replace -+ // `func.return` ops in StableHLO programs with `stablehlo.return`: -+ // https://github.com/openxla/stablehlo/issues/425). -+ if (isa(user)) continue; -+ -+ // Unlike in TensorFlow's type inference pass, here we work only with -+ // allowlisted ops to focus our support on well-defined semantics of -+ // StableHLO programs. -+ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { -+ diag << "unsupported refinement: tried to refine " << value.getType() -+ << " to " << refinedType << " for user " << user; -+ }); -+ } -+ -+ // Happy path: simply call setType here because most of our users are -+ // fine with that. -+ auto unrefinedType = value.getType(); -+ value.setType(refinedType); -+ -+ // Special case: for `func.return`, guard the refinement with a cast -+ // and leave propagation of the refined return type to a dedicated pattern. -+ auto isFuncReturn = [](OpOperand& use) -> bool { -+ return isa(use.getOwner()); -+ }; -+ if (llvm::none_of(value.getUses(), isFuncReturn)) continue; -+ rewriter.setInsertionPointAfter(op); -+ auto castToUnrefinedType = rewriter.create( -+ op->getLoc(), unrefinedType, value); -+ value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); -+ } -+ -+ return success(); -+} -+ -+// Refines the return types of the given operation using the given types. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. -+LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, -+ ArrayRef types) { -+ if (failed(refineValues(rewriter, op, op->getResults(), types))) -+ return failure(); -+ -+ // This `replaceOpWithIf` call doesn't actually change the IR, but -+ // it does ask the rewriter to visit all the users of this op. There is no -+ // upstream API to achieve this directly, but if it's introduced in the -+ // future, we could use it here. -+ rewriter.replaceOpWithIf(op, op->getResults(), -+ [](OpOperand& use) { return false; }); -+ return success(); -+} -+ -+// Refines the return types of the given operation using the given types. -+// Tricky implementation details: -+// 1) `types` can include non-shaped types. If there are tuple types, -+// then they are first flattened into non-tuple types using in-order -+// traversal, and only then we apply the refinements. If there are other -+// types, then the corresponding refinements must be completely empty. -+// 2) Encodings are not supported. In principle, TypeExtensions should be -+// supportable, but this needs careful thinking through. Given that no one -+// asked for support for bounded dynamism in this pass yet, this is left -+// for future work. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. -+LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, -+ ArrayRef refinements) { -+ SmallVector flattenedTypes; -+ hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); -+ auto flattenedSize = flattenedTypes.size(); -+ if (flattenedSize != refinements.size()) -+ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { -+ diag << "refineReturnTypes failed: expected " << flattenedSize -+ << " refinements, got " << refinements.size(); -+ }); -+ -+ SmallVector flattenedRefinedTypes; -+ for (auto it : llvm::zip(flattenedTypes, refinements)) { -+ // Cannot use structured bindings to simplify this because capturing -+ // structured bindings in a lambda is a C++ 20 extension. -+ ShapedType currentType = std::get<0>(it).dyn_cast(); -+ ShapedTypeComponents refinement = std::get<1>(it); -+ auto failWithReason = [&](StringRef reason) { -+ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { -+ diag << "refineTypes failed: refining " << currentType -+ << "with refinement: {"; -+ if (refinement.hasRank()) { -+ diag << "shape = [" << refinement.getDims() << "]"; -+ if (refinement.getAttribute()) -+ diag << "attribute = " << refinement.getAttribute(); -+ } else { -+ diag << "hasRank = false"; -+ } -+ diag << ", elementType = " << refinement.getElementType(); -+ diag << "} failed: " << reason; -+ }); -+ }; -+ -+ // If the current type is not a shaped type, then the refinement must -+ // be completely empty. -+ if (!currentType) { -+ if (refinement.hasRank() || refinement.getElementType() || -+ refinement.getAttribute()) -+ return failWithReason("unsupported refinement"); -+ flattenedRefinedTypes.push_back(currentType); -+ continue; -+ } -+ -+ // If the refinement has an element type, then it must be the same as -+ // the current element type. -+ Type currentElementType = currentType.getElementType(); -+ if (refinement.getElementType() && -+ currentElementType != refinement.getElementType()) -+ return failWithReason("expected compatible element types"); -+ -+ // If neither the current type nor the refinement are ranked, then there's -+ // nothing to refine, and we return the current type. -+ bool hasRank = currentType.hasRank() || refinement.hasRank(); -+ if (!hasRank) { -+ flattenedRefinedTypes.push_back(currentType); -+ continue; -+ } -+ -+ // If either the current type or the refinement have encodings, then -+ // we fail. Encodings are left for future work. -+ Attribute currentEncoding = nullptr; -+ if (auto currentRankedType = currentType.dyn_cast()) { -+ currentEncoding = currentRankedType.getEncoding(); -+ } -+ Attribute refinedEncoding = refinement.getAttribute(); -+ if (currentEncoding || refinedEncoding) -+ return failWithReason("expected compatible encodings"); -+ -+ // If both the current type and the refinement have shapes, use the shape -+ // from the refinement. Otherwise, pick whatever is available. -+ // Make sure that the resulting type is compatible with the current type -+ // to avoid creating invalid code. -+ auto refinedShape = -+ refinement.hasRank() ? refinement.getDims() : currentType.getShape(); -+ auto refinedType = RankedTensorType::get(refinedShape, currentElementType); -+ if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) -+ return failWithReason("expected compatible shapes"); -+ flattenedRefinedTypes.push_back(refinedType); -+ } -+ -+ SmallVector refinedTypes; -+ if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), -+ flattenedRefinedTypes, refinedTypes))) -+ return failure(); -+ return refineReturnTypes(rewriter, op, refinedTypes); -+} -+ -+// Refines the return type of the given operation using the given shape. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. -+template -+LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, -+ ArrayRef shape) { -+ return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); -+} -+ -+// Refines the return type of the given operation using the given shape. -+// This function also signals PatternRewriter that it needs to visit all the -+// users of this op if any updates to its results have happened during execution -+// of the function. -+template -+LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, -+ Value shapeValue) { -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ SmallVector shape; -+ if (failed(hlo::matchInts(shapeValue, shape))) -+ return rewriter.notifyMatchFailure(op, "expected constant output shape"); -+ return refineReturnShape(rewriter, op, shape); -+} -+ -+struct RefineAllGatherOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(AllGatherOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. -+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getAllGatherDim())) -+ refinement[op.getAllGatherDim()] *= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineBitcastConvertOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(BitcastConvertOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // If bit widths of the operand and the result are different, then -+ // operand and result shapes have different ranks. -+ // This complicates the logic quite a bit and is not needed to pass the -+ // current tests, so we leave this for future work. -+ auto resultType = op.getType(); -+ auto getBitWidthFn = [](ShapedType type) { -+ auto elementType = type.getElementType(); -+ if (auto complexType = elementType.dyn_cast()) -+ return complexType.getElementType().getIntOrFloatBitWidth(); -+ return elementType.getIntOrFloatBitWidth(); -+ }; -+ -+ if (getBitWidthFn(operandType) != getBitWidthFn(resultType)) -+ return rewriter.notifyMatchFailure(op, "unsupported bit width"); -+ -+ return refineReturnShape(rewriter, op, operandType.getShape()); -+ } -+}; -+ -+struct RefineConvertOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvertOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvertOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineConvolutionOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ConvolutionOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), op.getPadding(), op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineCustomCallOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector refinements; -+ if (failed(hlo::getShapeRefinements(op.getLoc(), op, refinements))) -+ return rewriter.notifyMatchFailure(op, "expected valid refinements"); -+ return refineReturnTypes(rewriter, op, refinements); -+ } -+}; -+ -+struct RefineDotGeneralOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DotGeneralOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferDotGeneralOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getDotDimensionNumbersAttr().getLhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsBatchingDimensions(), -+ op.getDotDimensionNumbersAttr().getLhsContractingDimensions(), -+ op.getDotDimensionNumbersAttr().getRhsContractingDimensions(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferDotGeneralOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineDynamicBroadcastInDimOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputDimensions()); -+ } -+}; -+ -+struct RefineDynamicConvOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicConvOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector padding; -+ if (failed(hlo::matchInts(op.getDPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant d_padding"); -+ if (op.getPadding().has_value()) -+ return rewriter.notifyMatchFailure(op, "expected empty padding"); -+ auto paddingType = RankedTensorType::get( -+ op.getDPadding().getType().getShape(), rewriter.getIntegerType(64)); -+ auto paddingAttr = DenseIntElementsAttr::get(paddingType, padding); -+ -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferConvolutionOp( -+ /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), -+ op.getWindowStrides(), paddingAttr, op.getLhsDilation(), -+ op.getRhsDilation(), op.getWindowReversal(), -+ op.getDimensionNumbers().getInputBatchDimension(), -+ op.getDimensionNumbers().getInputFeatureDimension(), -+ op.getDimensionNumbers().getInputSpatialDimensions(), -+ op.getDimensionNumbers().getKernelInputFeatureDimension(), -+ op.getDimensionNumbers().getKernelOutputFeatureDimension(), -+ op.getDimensionNumbers().getKernelSpatialDimensions(), -+ op.getDimensionNumbers().getOutputBatchDimension(), -+ op.getDimensionNumbers().getOutputFeatureDimension(), -+ op.getDimensionNumbers().getOutputSpatialDimensions(), -+ op.getFeatureGroupCount(), op.getBatchGroupCount(), -+ op.getPrecisionConfig(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvolutionOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineDynamicIotaOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicIotaOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicPadOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicPadOp op, -+ PatternRewriter& rewriter) const override { -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ SmallVector edgePaddingLow, edgePaddingHigh, interiorPadding; -+ if (failed(hlo::matchInts(op.getEdgePaddingLow(), edgePaddingLow))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_low"); -+ if (failed(hlo::matchInts(op.getEdgePaddingHigh(), edgePaddingHigh))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant edge_padding_high"); -+ if (failed(hlo::matchInts(op.getInteriorPadding(), interiorPadding))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant interior_padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferPadOp( -+ /*location=*/{}, op.getOperand().getType(), -+ op.getPaddingValue().getType(), -+ rewriter.getI64TensorAttr(edgePaddingLow), -+ rewriter.getI64TensorAttr(edgePaddingHigh), -+ rewriter.getI64TensorAttr(interiorPadding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferPadOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineDynamicReduceWindowOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, ++ LogicalResult matchAndRewrite(CustomCallOp impl, + PatternRewriter& rewriter) const override { + auto maybeOp = getDynamicReduceWindowOp(impl); + if (!maybeOp || failed(maybeOp->verify())) return failure(); @@ -3284,25 +2462,21 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + SmallVector inferredReturnTypes; + if (failed(hlo::inferReduceWindowOp( + /*location=*/{}, op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), -+ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) ++ llvm::to_vector(rewriter.getI64TensorAttr(windowDimensions) ++ .getValues()), ++ llvm::to_vector( ++ rewriter.getI64TensorAttr(windowStrides).getValues()), ++ llvm::to_vector( ++ rewriter.getI64TensorAttr(baseDilations).getValues()), ++ llvm::to_vector(rewriter.getI64TensorAttr(windowDilations) ++ .getValues()), ++ hlo::getPaddingAttr(&rewriter, padding), op.getBody(), ++ inferredReturnTypes))) + return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); + return refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; + -+struct RefineDynamicReshapeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(DynamicReshapeOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ +struct RefineDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -3346,291 +2520,13 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + } +}; + -+struct RefineInferTypeOpInterfacePattern -+ : public OpInterfaceRewritePattern { -+ explicit RefineInferTypeOpInterfacePattern(MLIRContext* context) -+ : OpInterfaceRewritePattern(context, /*benefit=*/0) {} -+ LogicalResult matchAndRewrite(InferTypeOpInterface op, -+ PatternRewriter& rewriter) const override { -+ // Unlike in TensorFlow's type inference pass, here we work only with -+ // allowlisted ops to focus our support on well-defined semantics of -+ // StableHLO programs. -+ if (!isa(op->getDialect())) -+ return rewriter.notifyMatchFailure(op, "unsupported dialect"); -+ -+ // For the ops that implement InferTypeOpInterface, we reinfer their return -+ // types and see what happens. -+ // Operands of these ops might have been refined elsewhere (e.g. someone -+ // might have updated argument types of a function) or earlier during this -+ // pass, and this might enable refinement opportunities downstream. -+ SmallVector inferredReturnTypes; -+ if (failed(op.inferReturnTypes(getContext(), /*location=*/{}, -+ op->getOperands(), op->getAttrDictionary(), -+ op->getPropertiesStorage(), op->getRegions(), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReturnTypes failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ -+struct RefineRealDynamicSliceOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RealDynamicSliceOp op, -+ PatternRewriter& rewriter) const override { -+ // Alternative #1: All attributes are fully static (SliceOp style). -+ SmallVector startIndices, limitIndices, strides; -+ if (succeeded(hlo::matchInts(op.getStartIndices(), startIndices)) && -+ succeeded(hlo::matchInts(op.getLimitIndices(), limitIndices)) && -+ succeeded(hlo::matchInts(op.getStrides(), strides))) { -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferSliceOp(/*location=*/{}, op.getOperand().getType(), -+ rewriter.getI64TensorAttr(startIndices), -+ rewriter.getI64TensorAttr(limitIndices), -+ rewriter.getI64TensorAttr(strides), -+ inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ // Alternative #2: Slice sizes are fully static (DynamicSliceOp style). -+ // To detect that, we check whether `limit_indices` is defined as -+ // `start_indices + constant` or `constant + start_indices`. -+ DenseIntElementsAttr sliceSizesAttr; -+ auto m_startIndices = matchers::m_Val(op.getStartIndices()); -+ if (matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) || -+ matchPattern( -+ op.getLimitIndices(), -+ m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) { -+ SmallVector strides; -+ if (!succeeded(hlo::matchInts(op.getStrides(), strides)) || -+ !llvm::all_of(strides, [&](int64_t stride) { return stride == 1; })) -+ return rewriter.notifyMatchFailure(op, "expected unit strides"); -+ -+ // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. -+ // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ auto startIndicesElementType = -+ op.getStartIndices().getType().getElementType(); -+ SmallVector startIndicesTypes( -+ sliceSizesAttr.size(), -+ RankedTensorType::get({}, startIndicesElementType)); -+ -+ // RealDynamicSliceOp can take tensors of integer or index element types. -+ // DynamicSliceOp::slice_sizes only supports i64 element type. -+ // Adapt accordingly in order to be compatible with inferDynamicSliceOp. -+ SmallVector sliceSizes; -+ for (auto element : sliceSizesAttr.getValues()) { -+ sliceSizes.push_back(element.getSExtValue()); -+ } -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferDynamicSliceOp( -+ op.getLoc(), op.getOperand().getType(), startIndicesTypes, -+ rewriter.getI64TensorAttr(sliceSizes), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferDynamicSliceOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+ -+ return rewriter.notifyMatchFailure( -+ op, -+ "expected either fully static attributes (SliceOp style) " -+ "or static sliceSizes (DynamicSliceOp style)"); -+ } -+}; -+ -+struct RefineReduceScatterOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReduceScatterOp op, -+ PatternRewriter& rewriter) const override { -+ auto operandType = op.getOperand().getType(); -+ if (!operandType.hasRank()) -+ return rewriter.notifyMatchFailure(op, "expected ranked operand type"); -+ -+ // This represents the cross_replica_and_partition process grouping strategy -+ // that requires num_partitions to compute shardCount. Since we don't know -+ // num_partitions at this point, we error out. -+ if (op.getChannelHandle() && !op.getUseGlobalDeviceIds()) -+ return rewriter.notifyMatchFailure(op, "unsupported strategy"); -+ DenseIntElementsAttr replicaGroups = op.getReplicaGroups(); -+ auto shardCount = replicaGroups.getType().getDimSize(1); -+ -+ SmallVector refinement(operandType.getShape()); -+ if (!operandType.isDynamicDim(op.getScatterDimension())) -+ refinement[op.getScatterDimension()] /= shardCount; -+ return refineReturnShape(rewriter, op, refinement); -+ } -+}; -+ -+struct RefineRngOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(RngOp op, -+ PatternRewriter& rewriter) const override { -+ return refineReturnShape(rewriter, op, op.getShape()); -+ } -+}; -+ -+struct RefineUniformQuantizeOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(UniformQuantizeOp op, -+ PatternRewriter& rewriter) const override { -+ SmallVector inferredReturnShapes; -+ if (failed(hlo::inferUniformQuantizeOp( -+ /*location=*/{}, op.getOperand(), inferredReturnShapes))) -+ return rewriter.notifyMatchFailure(op, "inferConvertOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnShapes); -+ } -+}; -+ -+struct RefineWhileOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(WhileOp op, -+ PatternRewriter& rewriter) const override { -+ // Push the potentially refined operand types into the nested regions. -+ // This can lead to refinements of the return types of the body (but not -+ // of the cond since it always returns tensor), but the key insight here -+ // is that the enclosing while op doesn't care about these refinements -+ // (because its return types are equal to its operand types). -+ // If we end up with incompatibilities between while's return types and -+ // body's return types, the verifier will tell us about that. This means -+ // that the original program wasn't well-formed. TODO(burmako): Implement -+ // better error reporting for this case. -+ // This serves the current use cases well, so the implementation of more -+ // sophisticated refinement algorithm is left for future work. -+ rewriter.startRootUpdate(op); -+ auto condStatus = refineValues(rewriter, op, op.getCond().getArguments(), -+ op.getOperandTypes()); -+ auto bodyStatus = refineValues(rewriter, op, op.getBody().getArguments(), -+ op.getOperandTypes()); -+ if (succeeded(condStatus) || succeeded(bodyStatus)) { -+ rewriter.finalizeRootUpdate(op); -+ return success(); -+ } else { -+ rewriter.cancelRootUpdate(op); -+ return failure(); -+ } -+ } -+}; -+ -+struct UpdateFunctionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(func::ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ // Check whether any of the values returned by `func.return` are casts -+ // which convert more specific type to less specific type. -+ // Such ops are produced by the algorithm behind this pass to avoid -+ // bringing the enclosing `func.func` op into an inconsistent state when -+ // refining individual ops. This pattern cleans this up. -+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ llvm::SmallSet castsToReplace; -+ for (auto [i, operand] : llvm::enumerate(op.getOperands())) { -+ auto cast = -+ dyn_cast_or_null(operand.getDefiningOp()); -+ if (!cast || cast.getInputs().size() != 1 || -+ cast.getOutputs().size() != 1) -+ continue; -+ -+ // Only proceed if the type that we're casting from is more specific -+ // than the type that we're casting to. -+ auto sourceType = cast.getInputs()[0].getType(); -+ auto destType = cast.getOutputs()[0].getType(); -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {sourceType, destType}); -+ if (failed(mostSpecificType) || destType == *mostSpecificType) continue; -+ -+ // If the source type of the cast is more specific than the target type, -+ // then we conclude that the cast is redundant (i.e. needs to be removed) -+ // and that the return type of the function needs an update. -+ needsUpdate = true; -+ updatedResultTypes[i] = sourceType; -+ -+ // Insert into set and continue iterating. -+ // ReturnOp may point to same value more than once. -+ castsToReplace.insert(cast); -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ // Replace CastOps with more specific operands than results. -+ for (auto cast : castsToReplace) -+ rewriter.replaceOp(cast, cast->getOperands()); -+ -+ // If the type of the enclosing `func.func` needs an update, we simply -+ // call setType. We can afford this simplicity because our algorithm -+ // currently supports only one function per module. -+ auto func = cast(op->getParentOp()); -+ func.setType( -+ rewriter.getFunctionType(func.getArgumentTypes(), updatedResultTypes)); -+ return success(); -+ } -+}; -+ -+struct UpdateRegionTypePattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(ReturnOp op, -+ PatternRewriter& rewriter) const override { -+ if (!isa(op->getParentOp())) -+ return rewriter.notifyMatchFailure(op, "unsupported region"); -+ -+ bool needsUpdate = false; -+ SmallVector updatedResultTypes(op.getOperandTypes()); -+ for (auto [regionType, refinedType] : llvm::zip( -+ op->getParentOp()->getResultTypes(), op->getOperandTypes())) { -+ auto mostSpecificType = hlo::inferMostSpecificType( -+ /*location=*/{}, {regionType, refinedType}); -+ if (failed(mostSpecificType) || regionType == *mostSpecificType) continue; -+ needsUpdate = true; -+ } -+ if (!needsUpdate) -+ return rewriter.notifyMatchFailure(op, "doesn't need update"); -+ -+ rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); -+ return success(); -+ } -+}; -+ +struct StablehloRefineShapesPass + : public impl::StablehloRefineShapesPassBase { + using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + + void runOnOperation() override { -+ // Only one function per module is supported at the moment to avoid the need -+ // to think about iterative type inference algorithms. -+ // Current use cases are served well by inlining multiple functions into -+ // a single function, so we leave native support for multiple functions to -+ // future work. -+ // To enable modules that contain CustomCallOp::called_computations, -+ // we allow multiple functions, in which case we only refine the main -+ // function called "main", assuming that the called computations will have -+ // static shapes. Lifting this assumption and expanding refinement to -+ // multiple functions is left for future work. -+ ModuleOp module = getOperation(); -+ auto funcs = llvm::to_vector(module.getOps()); -+ if (funcs.empty()) return; -+ func::FuncOp func; -+ if (funcs.size() == 1) { -+ func = funcs[0]; -+ } else { -+ func = module.lookupSymbol("main"); -+ } -+ if (!func) { -+ module.emitOpError() -+ << "must have no more than one function or a `main`" -+ << " function to clearly identify which function will be refined"; -+ return signalPassFailure(); -+ } -+ -+ // Similarly, only one block per function is supported at the moment. -+ // At the StableHLO level, functions are expected to only have one block, -+ // so supporting more is out of scope for this pass. -+ if (!func.getRegion().hasOneBlock()) { -+ func.emitOpError() << "must have exactly one block"; -+ return signalPassFailure(); -+ } ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per @@ -3641,54 +2537,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = true; -+ config.maxIterations = 2; ++ config.maxIterations = 3; + config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + patterns.add(&getContext()); -+ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); -+ patterns.add(&getContext()); + if (failed( + applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { ++ func.emitError() ++ << "Greedy rewriter in StablehloRefineShapes does not converge after " ++ << config.maxIterations << " iterations."; + return signalPassFailure(); + } + } @@ -3698,59 +2560,1185 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -340,6 +340,19 @@ - %1 = stablehlo.constant dense<2> : tensor - %2 = stablehlo.multiply %0, %1 : tensor - func.return %2 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_or -+func.func @eval_or() -> tensor { -+ // CHECK-NOT: stablehlo.or -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense : tensor -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense : tensor -+ %1 = stablehlo.constant dense : tensor -+ %2 = stablehlo.or %0, %1 : tensor -+ func.return %2 : tensor - } +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -1283,153 +1283,153 @@ + func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = stablehlo.convert %[[X]] : (tensor) -> tensor + // CHECK: %[[TMP_1:.*]] = stablehlo.convert %[[Q]] : (tensor) -> tensor +- // CHECK: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_3:.*]] = stablehlo.negate %[[TMP_0]] +- // CHECK: %[[TMP_4:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_3]] +- // CHECK: %[[TMP_5:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_5]] +- // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_3]] +- // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_4]], %[[TMP_7]] +- // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_5]] +- // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_3]] ++ // CHECK-DAG: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_3:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_4:.*]] = stablehlo.negate %[[TMP_0]] ++ // CHECK: %[[TMP_5:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_4]] ++ // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_3]] ++ // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_4]] ++ // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_7]] ++ // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_3]] ++ // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_4]] + // CHECK: %[[TMP_11:.*]] = stablehlo.add %[[TMP_8]], %[[TMP_10]] +- // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_5]] +- // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_3]] ++ // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_4]] + // CHECK: %[[TMP_14:.*]] = stablehlo.add %[[TMP_11]], %[[TMP_13]] +- // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_5]] +- // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_4]] + // CHECK: %[[TMP_17:.*]] = stablehlo.add %[[TMP_14]], %[[TMP_16]] +- // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_5]] +- // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_4]] + // CHECK: %[[TMP_20:.*]] = stablehlo.add %[[TMP_17]], %[[TMP_19]] +- // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_5]] +- // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_4]] + // CHECK: %[[TMP_23:.*]] = stablehlo.add %[[TMP_20]], %[[TMP_22]] +- // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_5]] +- // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_4]] + // CHECK: %[[TMP_26:.*]] = stablehlo.add %[[TMP_23]], %[[TMP_25]] +- // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_5]] +- // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_4]] + // CHECK: %[[TMP_29:.*]] = stablehlo.add %[[TMP_26]], %[[TMP_28]] +- // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_5]] +- // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_4]] + // CHECK: %[[TMP_32:.*]] = stablehlo.add %[[TMP_29]], %[[TMP_31]] +- // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_5]] +- // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_3]] ++ // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_4]] + // CHECK: %[[TMP_35:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_36:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] +- // CHECK: %[[TMP_37:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] +- // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_37]], %[[TMP_36]] +- // CHECK: %[[TMP_39:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] +- // CHECK: %[[TMP_40:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] +- // CHECK: %[[TMP_41:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_40]] +- // CHECK: %[[TMP_42:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_43:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_42]] +- // CHECK: %[[TMP_44:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_45:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_44]] +- // CHECK: %[[TMP_46:.*]] = stablehlo.multiply %[[TMP_43]], %[[TMP_45]] +- // CHECK: %[[TMP_47:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_48:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_47]] +- // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_48]] +- // CHECK: %[[TMP_50:.*]] = stablehlo.multiply %[[TMP_46]], %[[TMP_49]] +- // CHECK: %[[TMP_51:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_52:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_51]] +- // CHECK: %[[TMP_53:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_54:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_53]] +- // CHECK: %[[TMP_55:.*]] = stablehlo.multiply %[[TMP_52]], %[[TMP_54]] +- // CHECK: %[[TMP_56:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_57:.*]] = stablehlo.add %[[TMP_50]], %[[TMP_56]] +- // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_57]] +- // CHECK: %[[TMP_59:.*]] = stablehlo.multiply %[[TMP_55]], %[[TMP_58]] +- // CHECK: %[[TMP_60:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_61:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_60]] +- // CHECK: %[[TMP_62:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_63:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_62]] +- // CHECK: %[[TMP_64:.*]] = stablehlo.multiply %[[TMP_61]], %[[TMP_63]] +- // CHECK: %[[TMP_65:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_66:.*]] = stablehlo.add %[[TMP_59]], %[[TMP_65]] +- // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_66]] +- // CHECK: %[[TMP_68:.*]] = stablehlo.multiply %[[TMP_64]], %[[TMP_67]] +- // CHECK: %[[TMP_69:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_70:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_69]] +- // CHECK: %[[TMP_71:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_72:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_71]] +- // CHECK: %[[TMP_73:.*]] = stablehlo.multiply %[[TMP_70]], %[[TMP_72]] +- // CHECK: %[[TMP_74:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_75:.*]] = stablehlo.add %[[TMP_68]], %[[TMP_74]] +- // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_75]] +- // CHECK: %[[TMP_77:.*]] = stablehlo.multiply %[[TMP_73]], %[[TMP_76]] +- // CHECK: %[[TMP_78:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_79:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_78]] +- // CHECK: %[[TMP_80:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_81:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_80]] +- // CHECK: %[[TMP_82:.*]] = stablehlo.multiply %[[TMP_79]], %[[TMP_81]] +- // CHECK: %[[TMP_83:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_84:.*]] = stablehlo.add %[[TMP_77]], %[[TMP_83]] +- // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_84]] +- // CHECK: %[[TMP_86:.*]] = stablehlo.multiply %[[TMP_82]], %[[TMP_85]] +- // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_88:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_87]] +- // CHECK: %[[TMP_89:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_90:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_89]] +- // CHECK: %[[TMP_91:.*]] = stablehlo.multiply %[[TMP_88]], %[[TMP_90]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_93:.*]] = stablehlo.add %[[TMP_86]], %[[TMP_92]] +- // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.multiply %[[TMP_91]], %[[TMP_94]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_96]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.multiply %[[TMP_97]], %[[TMP_99]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_95]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_102]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.multiply %[[TMP_100]], %[[TMP_103]] +- // CHECK: %[[TMP_105:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_105]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.multiply %[[TMP_106]], %[[TMP_108]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_104]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_111]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.multiply %[[TMP_109]], %[[TMP_112]] +- // CHECK: %[[TMP_114:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_114]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.multiply %[[TMP_115]], %[[TMP_117]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_113]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_120]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.multiply %[[TMP_118]], %[[TMP_121]] +- // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_126:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_125]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_124]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_129:.*]] = stablehlo.add %[[TMP_122]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_129]] +- // CHECK: %[[TMP_131:.*]] = stablehlo.multiply %[[TMP_127]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_134]] +- // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_138:.*]] = stablehlo.add %[[TMP_131]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_138]] +- // CHECK: %[[TMP_140:.*]] = stablehlo.multiply %[[TMP_136]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_143]], %[[TMP_140]] +- // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.add %[[TMP_141]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_39]], %[[TMP_147]] ++ // CHECK: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] ++ // CHECK: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] ++ // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_36]], %[[TMP_37]] ++ // CHECK: %[[TMP_39:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] ++ // CHECK: %[[TMP_40:.*]] = stablehlo.divide %[[TMP_3]], %[[TMP_39]] ++ // CHECK: %[[TMP_41:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_42:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_41]] ++ // CHECK: %[[TMP_43:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_44:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_43]] ++ // CHECK: %[[TMP_45:.*]] = stablehlo.multiply %[[TMP_42]], %[[TMP_44]] ++ // CHECK: %[[TMP_46:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_47:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_46]] ++ // CHECK: %[[TMP_48:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_47]] ++ // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_45]], %[[TMP_48]] ++ // CHECK: %[[TMP_50:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_51:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_50]] ++ // CHECK: %[[TMP_52:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_53:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_52]] ++ // CHECK: %[[TMP_54:.*]] = stablehlo.multiply %[[TMP_51]], %[[TMP_53]] ++ // CHECK: %[[TMP_55:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_56:.*]] = stablehlo.add %[[TMP_49]], %[[TMP_55]] ++ // CHECK: %[[TMP_57:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_56]] ++ // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_54]], %[[TMP_57]] ++ // CHECK: %[[TMP_59:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_60:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_59]] ++ // CHECK: %[[TMP_61:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_62:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_61]] ++ // CHECK: %[[TMP_63:.*]] = stablehlo.multiply %[[TMP_60]], %[[TMP_62]] ++ // CHECK: %[[TMP_64:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_65:.*]] = stablehlo.add %[[TMP_58]], %[[TMP_64]] ++ // CHECK: %[[TMP_66:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_65]] ++ // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_63]], %[[TMP_66]] ++ // CHECK: %[[TMP_68:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_69:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_68]] ++ // CHECK: %[[TMP_70:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_71:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_70]] ++ // CHECK: %[[TMP_72:.*]] = stablehlo.multiply %[[TMP_69]], %[[TMP_71]] ++ // CHECK: %[[TMP_73:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_74:.*]] = stablehlo.add %[[TMP_67]], %[[TMP_73]] ++ // CHECK: %[[TMP_75:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_74]] ++ // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_72]], %[[TMP_75]] ++ // CHECK: %[[TMP_77:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_78:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_77]] ++ // CHECK: %[[TMP_79:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_80:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_79]] ++ // CHECK: %[[TMP_81:.*]] = stablehlo.multiply %[[TMP_78]], %[[TMP_80]] ++ // CHECK: %[[TMP_82:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_83:.*]] = stablehlo.add %[[TMP_76]], %[[TMP_82]] ++ // CHECK: %[[TMP_84:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_83]] ++ // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_81]], %[[TMP_84]] ++ // CHECK: %[[TMP_86:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_87:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_86]] ++ // CHECK: %[[TMP_88:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_89:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_88]] ++ // CHECK: %[[TMP_90:.*]] = stablehlo.multiply %[[TMP_87]], %[[TMP_89]] ++ // CHECK: %[[TMP_91:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.add %[[TMP_85]], %[[TMP_91]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_90]], %[[TMP_93]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_98:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_97]] ++ // CHECK: %[[TMP_99:.*]] = stablehlo.multiply %[[TMP_96]], %[[TMP_98]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_101:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_100]] ++ // CHECK: %[[TMP_102:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_101]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_99]], %[[TMP_102]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_104]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_107:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_106]] ++ // CHECK: %[[TMP_108:.*]] = stablehlo.multiply %[[TMP_105]], %[[TMP_107]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_110:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_109]] ++ // CHECK: %[[TMP_111:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_110]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_108]], %[[TMP_111]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_113]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_116:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_115]] ++ // CHECK: %[[TMP_117:.*]] = stablehlo.multiply %[[TMP_114]], %[[TMP_116]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_119:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_118]] ++ // CHECK: %[[TMP_120:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_119]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_117]], %[[TMP_120]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_123:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_122]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_125:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_124]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.multiply %[[TMP_123]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_128:.*]] = stablehlo.add %[[TMP_121]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_128]] ++ // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_126]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_134:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_133]] ++ // CHECK: %[[TMP_135:.*]] = stablehlo.multiply %[[TMP_132]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_137:.*]] = stablehlo.add %[[TMP_130]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_137]] ++ // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_135]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_143:.*]] = stablehlo.add %[[TMP_142]], %[[TMP_139]] ++ // CHECK: %[[TMP_144:.*]] = stablehlo.multiply %[[TMP_141]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_140]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_146]] + // CHECK: %[[TMP_149:.*]] = stablehlo.abs %[[TMP_34]] + // CHECK: %[[TMP_150:.*]] = stablehlo.abs %[[TMP_32]] + // CHECK: %[[TMP_151:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1456,7 +1456,7 @@ + // CHECK: %[[TMP_172:.*]] = stablehlo.and %[[TMP_169]], %[[TMP_171]] : tensor + // CHECK: %[[TMP_173:.*]] = stablehlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] + // CHECK: %[[TMP_174:.*]] = stablehlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_5]], NOTYPE ++ // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE + // CHECK: %[[TMP_176:.*]] = stablehlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] + // CHECK: %[[TMP_177:.*]] = stablehlo.convert %[[TMP_176]] : (tensor) -> tensor + %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor +@@ -1465,8 +1465,7 @@ // ----- -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -304,6 +304,20 @@ - } - }; -+struct EvalOrOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(OrOp op, -+ PatternRewriter& rewriter) const override { -+ auto resultType = op.getType(); -+ if (!resultType.getElementType().isInteger(1)) -+ return rewriter.notifyMatchFailure(op, "expected boolean element type"); -+ -+ return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { -+ return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); -+ }); -+ } -+}; -+ - struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, -@@ -1165,6 +1179,7 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); +- +-// CHECK-LABEL: @polygamma_f32 ++// CHECK: @polygamma_f32 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1559,153 +1558,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7F800000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1732,7 +1731,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +@@ -1853,8 +1852,7 @@ + + // ----- + +- +-// CHECK-LABEL: @polygamma_f64 ++// CHECK: @polygamma_f64 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1947,153 +1945,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7FF0000000000000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.5090028283602295E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.5860620562778452E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.3382536530684679E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767569878681E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.3068783068783071E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.0013888888888888889> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.083333333333333329> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.5090028283602295E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.5860620562778452E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.3382536530684679E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767569878681E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.3068783068783071E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.0013888888888888889> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.083333333333333329> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<4.940660e-324> +@@ -2120,7 +2118,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +@@ -1575,11 +1575,21 @@ + + static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, + ValueRange args) { +- // Code should match XLA's materializeZeta from chlo_legalize_to_hlo.cc ++ // Implementation ported from: ++ // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917 ++ // Reference: Johansson, Fredrik. ++ // "Rigorous high-precision computation of the Hurwitz zeta function and its ++ // derivatives." Numerical Algorithms 69.2 (2015): 253-270. ++ // https://arxiv.org/abs/1309.2877 - formula (5) ++ // Notation is more or less kept as a reference to the whitepaper. + assert(args.size() == 2); + Value x = args[0]; + Value q = args[1]; +- static const std::array kZetaCoeffs{ ++ ++ static constexpr auto kTerms = 12; ++ static constexpr auto kIters = 9; ++ static constexpr auto kTwoTermsMinusOne = 2 * kTerms - 1; ++ static constexpr auto kZetaCoeffs = std::array{ + -7.1661652561756670113e18, + 1.8152105401943546773e17, + -4.5979787224074726105e15, +@@ -1596,131 +1606,134 @@ + + // For speed we'll always use 9 iterations for the initial series estimate, + // and a 12 term expansion for the Euler-Maclaurin formula. +- Value a = q; +- Value zero = getConstantLike(rewriter, loc, 0.0, a); +- Value negPower = zero; +- Value negX = rewriter.create(loc, x); +- Value initialSum = rewriter.create(loc, q, negX); +- Value one = getConstantLike(rewriter, loc, 1.0, a); +- for (int i = 0; i < 9; ++i) { +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); +- initialSum = +- rewriter.create(loc, initialSum, negPower); +- } +- +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); ++ Value zero = getConstantLike(rewriter, loc, 0.0, q); ++ Value one = getConstantLike(rewriter, loc, 1.0, q); ++ Value acc = q; ++ Value qNegPower = zero; ++ Value negX = rewriter.create(loc, x); ++ Value powerSum = rewriter.create(loc, q, negX); ++ for (int i = 0; i < kIters; ++i) { ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); ++ powerSum = ++ rewriter.create(loc, powerSum, qNegPower); ++ } ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); + Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x); +- Value xMinusOne = +- rewriter.create(loc, x, oneLikeX); +- Value negPowerMulA = +- rewriter.create(loc, negPower, a); +- Value negPowerMulADivXMinusOne = +- rewriter.create(loc, negPowerMulA, xMinusOne); +- Value s = rewriter.create(loc, initialSum, +- negPowerMulADivXMinusOne); +- Value aInverseSquare = rewriter.create( +- loc, one, rewriter.create(loc, a, a)); +- +- Value hornerSum = zero; +- Value factor = one; ++ Value correctionEulerMaclaurin = rewriter.create( ++ loc, rewriter.create(loc, qNegPower, acc), ++ rewriter.create(loc, x, oneLikeX)); ++ ++ // Manual reciprocal of the square root as RsqrtOp produces different results ++ Value rsqrtAcc = rewriter.create( ++ loc, one, rewriter.create(loc, acc, acc)); ++ + // Use Horner's rule for this. + // Note this differs from Cephes which does a 'naive' polynomial evaluation. + // Using Horner's rule allows to avoid some NaN's and Infs from happening, + // resulting in more numerically stable code. +- for (int i = 0; i < 11; ++i) { +- Value factorLhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x)); +- Value factorRhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x)); +- factor = rewriter.create(loc, factorLhs, factorRhs); +- hornerSum = rewriter.create( +- loc, factor, +- rewriter.create( +- loc, aInverseSquare, +- rewriter.create( ++ Value hornerSum = zero; ++ Value hornerProduct = one; ++ ++ for (int i = 0; i < kTerms - 1; ++i) { ++ Value factorLhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 1 - 2 * i, x)); ++ Value factorRhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 2 - 2 * i, x)); ++ hornerProduct = ++ rewriter.create(loc, factorLhs, factorRhs); ++ hornerSum = rewriter.create( ++ loc, hornerProduct, ++ rewriter.create( ++ loc, rsqrtAcc, ++ rewriter.create( + loc, hornerSum, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); +- } +- Value zeroPointFiveLikeNegPower = +- getConstantLike(rewriter, loc, .5, negPower); +- Value xDivA = rewriter.create(loc, x, a); +- s = rewriter.create( +- loc, s, +- rewriter.create( +- loc, negPower, +- rewriter.create( +- loc, zeroPointFiveLikeNegPower, +- rewriter.create( +- loc, xDivA, +- rewriter.create( +- loc, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a), +- hornerSum))))); ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], acc)))); ++ } ++ Value zeroPointFiveLikeQNegPower = ++ getConstantLike(rewriter, loc, .5, qNegPower); ++ Value xDivAcc = rewriter.create(loc, x, acc); ++ Value bernoulliTailTerm = rewriter.create( ++ loc, qNegPower, ++ rewriter.create( ++ loc, zeroPointFiveLikeQNegPower, ++ rewriter.create( ++ loc, xDivAcc, ++ rewriter.create( ++ loc, ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[kTerms - 1], ++ acc), ++ hornerSum)))); ++ Value accurateResult = rewriter.create( ++ loc, ++ rewriter.create(loc, powerSum, ++ correctionEulerMaclaurin), ++ bernoulliTailTerm); + + // Use the initial zeta sum without the correction term coming + // from Euler-Maclaurin if it is accurate enough. +- Value absNegPower = rewriter.create(loc, negPower); +- Value absInitialSum = +- rewriter.create(loc, initialSum); +- Value output = rewriter.create( ++ Value absQNegPower = rewriter.create(loc, qNegPower); ++ Value absPowerSum = rewriter.create(loc, powerSum); ++ Value output = rewriter.create( + loc, +- rewriter.create( +- loc, absNegPower, +- rewriter.create( +- loc, absInitialSum, +- getConstantLikeSmallestFiniteValue(rewriter, loc, a)), +- mlir::stablehlo::ComparisonDirection::LT), +- initialSum, s); ++ rewriter.create( ++ loc, absQNegPower, ++ rewriter.create( ++ loc, absPowerSum, ++ getConstantLikeSmallestFiniteValue(rewriter, loc, acc)), ++ ComparisonDirection::LT), ++ powerSum, accurateResult); + + // Function is not defined for x < 1. + Value nan = getConstantLike(rewriter, loc, + std::numeric_limits::quiet_NaN(), x); +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT), ++ rewriter.create( ++ loc, x, oneLikeX, ComparisonDirection::LT), + nan, output); + + // For q <= 0, x must be an integer. +- Value qLeZero = rewriter.create( +- loc, q, zero, mlir::stablehlo::ComparisonDirection::LE); +- Value xNotInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::NE); ++ Value qLeZero = rewriter.create( ++ loc, q, zero, ComparisonDirection::LE); ++ Value xNotInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::NE); + Value xDomainError = +- rewriter.create(loc, qLeZero, xNotInt); +- output = rewriter.create(loc, xDomainError, nan, ++ rewriter.create(loc, qLeZero, xNotInt); ++ output = rewriter.create(loc, xDomainError, nan, + output); + + // For all integer q <= 0, zeta has a pole. The limit is only defined as + // +inf if x is and even integer. + Value inf = getConstantLike(rewriter, loc, + std::numeric_limits::infinity(), x); +- Value qIsInt = rewriter.create( +- loc, q, rewriter.create(loc, q), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value atPole = rewriter.create(loc, qLeZero, qIsInt); ++ Value qIsInt = rewriter.create( ++ loc, q, rewriter.create(loc, q), ++ ComparisonDirection::EQ); ++ Value atPole = rewriter.create(loc, qLeZero, qIsInt); + Value two = getConstantLike(rewriter, loc, 2.0, x); +- Value xIsInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value xIsEven = rewriter.create( +- loc, rewriter.create(loc, x, two), zero, +- mlir::stablehlo::ComparisonDirection::EQ); ++ Value xIsInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::EQ); ++ Value xIsEven = rewriter.create( ++ loc, rewriter.create(loc, x, two), zero, ++ ComparisonDirection::EQ); + Value xIsEvenInt = +- rewriter.create(loc, xIsInt, xIsEven); +- output = rewriter.create( ++ rewriter.create(loc, xIsInt, xIsEven); ++ output = rewriter.create( + loc, atPole, +- rewriter.create(loc, xIsEvenInt, inf, nan), ++ rewriter.create(loc, xIsEvenInt, inf, nan), + output); + + // For x = 1, this is the harmonic series and diverges. +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, one, mlir::stablehlo::ComparisonDirection::EQ), ++ rewriter.create( ++ loc, x, one, ComparisonDirection::EQ), + inf, output); + + return output; diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index a60b5db8e74b5..e6985fc2f07af 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "83f095e7217c897f1eccac5652600ceb944cb0e0" - STABLEHLO_SHA256 = "00e442f7e9c8a52a1ac774ce997f8b5a99d12450c4dfe1594df816dcbad5126f" + STABLEHLO_COMMIT = "1bdf7c2603b7e68d97c1b9be92a51826e06cb6ee" + STABLEHLO_SHA256 = "24b594aa66a5d780d30a98e50d24be6d52dd46643a875abc1004288144c6cbc2" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/triton/b304456327.patch b/third_party/triton/b304456327.patch deleted file mode 100644 index 68209ab4d7550..0000000000000 --- a/third_party/triton/b304456327.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -index 169939eb3..9d90acf51 100644 ---- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp -@@ -24,7 +24,7 @@ using ttg::SliceEncodingAttr; - // supported - static int getMMAVersionSafe(int computeCapability, tt::DotOp op) { - int baseVersion = 0; -- if (computeCapability < 75) { -+ if (computeCapability < 80) { - baseVersion = 1; - } else if (computeCapability < 90) { - baseVersion = 2; \ No newline at end of file diff --git a/third_party/triton/cl568176943.patch b/third_party/triton/cl568176943.patch deleted file mode 100644 index c187e670e7091..0000000000000 --- a/third_party/triton/cl568176943.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp -index e78e7298c..a4685653c 100644 ---- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp -+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp -@@ -40,7 +40,6 @@ - #include "llvm/Support/SourceMgr.h" - #include "llvm/Target/TargetMachine.h" - #include "llvm/Transforms/InstCombine/InstCombine.h" --#include "third_party/py/triton/google/find_cuda.h" - #include - #ifdef _WIN32 - #define WIN32_LEAN_AND_MEAN -@@ -277,8 +276,10 @@ static std::map getExternLibs(mlir::ModuleOp module) { - // Search for libdevice relative to its library path if used from Python - // Then native code is in `triton/_C/libtriton.so` and libdevice in - // `triton/third_party/cuda/lib/libdevice.10.bc` -+ static const auto this_library_path = getThisLibraryPath(); - static const auto runtime_path = -- fs::path(PathToLibdevice()) / "libdevice.10.bc"; -+ this_library_path.parent_path().parent_path() / "third_party" / "cuda" / -+ "lib" / "libdevice.10.bc"; - if (fs::exists(runtime_path)) { - externLibs.try_emplace(libdevice, runtime_path.string()); - } else { diff --git a/third_party/triton/cl584230333.patch b/third_party/triton/cl584230333.patch deleted file mode 100644 index d8399eadf8f4d..0000000000000 --- a/third_party/triton/cl584230333.patch +++ /dev/null @@ -1,14 +0,0 @@ -==== triton/lib/Dialect/Triton/IR/Dialect.cpp#6 - /google/src/cloud/jreiffers/mlir_26a0b277369adc31b162b1cc38b1a712bc10c1a0_1700552908/triton/lib/Dialect/Triton/IR/Dialect.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/Triton/IR/Dialect.cpp 2023-11-21 01:58:04.000000000 -0800 -@@ -64,8 +64,7 @@ - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. -- void handleTerminator(Operation *op, -- ArrayRef valuesToRepl) const final { -+ void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only return needs to be handled here. - auto returnOp = cast(op); - diff --git a/third_party/triton/cl607293980.patch b/third_party/triton/cl607293980.patch new file mode 100644 index 0000000000000..955eb6db8da68 --- /dev/null +++ b/third_party/triton/cl607293980.patch @@ -0,0 +1,14 @@ +Long standing patch due to licensing issues. +diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp +index 31bc03fe1..a19a432df 100644 +--- a/include/triton/Tools/Sys/GetEnv.hpp ++++ b/include/triton/Tools/Sys/GetEnv.hpp +@@ -34,7 +34,7 @@ inline const std::set ENV_VARS = { + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", +- "DISABLE_MMA_V3", ++ "ENABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "MLIR_ENABLE_DUMP", diff --git a/third_party/triton/sparse_dot_base.patch b/third_party/triton/sparse_dot_base.patch new file mode 100644 index 0000000000000..3a0066a026e2f --- /dev/null +++ b/third_party/triton/sparse_dot_base.patch @@ -0,0 +1,898 @@ +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +@@ -1158,4 +1158,12 @@ section 9.7.13.4.1 for more details. + let extraClassDeclaration = extraDistributedDeclaration; + } + ++def SparseDotMetaEncodingAttr : DistributedEncoding<"SparseDotMetaEncoding", "sparse_dot_meta_encoding"> { ++ let mnemonic = "sparse_dot_meta"; ++ ++ let parameters = (ins "Attribute":$parent); ++ let assemblyFormat = "`<``{` struct(params) `}``>`"; ++ let extraClassDeclaration = extraDistributedDeclaration; ++} ++ + #endif +diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td ++++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +@@ -7,6 +7,7 @@ include "triton/Dialect/TritonGPU/IR/Tri + include "mlir/Dialect/Arith/IR/ArithBase.td" + include "triton/Dialect/Triton/IR/TritonTypes.td" + include "triton/Dialect/Triton/IR/TritonAttrDefs.td" ++include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + include "mlir/IR/OpBase.td" + include "mlir/Interfaces/SideEffectInterfaces.td" // Pure + include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +@@ -214,4 +215,19 @@ def TTG_LocalLoadOp : TTG_Op<"local_load + let results = (outs TT_Tensor:$result); + } + ++def TTNG_SparseDotOp : TTG_Op<"sparse_dot", [ ++ Pure, DeclareOpInterfaceMethods, ++ TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { ++ let summary = "sparse dot"; ++ ++ let arguments = (ins ++ TT_TensorOrMemDesc:$a, ++ TT_TensorOrMemDesc:$b, ++ TT_FpIntTensor:$c, ++ TT_IntTensor: $aMeta); ++ let results = (outs TT_FpIntTensor:$d); ++ let assemblyFormat = "$a`,` $b`,` $c`,` $aMeta attr-dict `:` type($a) `meta` type($aMeta) `*` type($b) `->` type($d)"; ++ let hasVerifier = 1; ++} ++ + #endif +diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp +--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp ++++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp +@@ -479,6 +479,119 @@ getDefaultBlockedEncoding(MLIRContext *c + return encoding; + } + ++///--- SparseDotOp --- ++namespace { ++// Implied properties of 2:4 sparse dots. ++constexpr int kContractingFactor = 2; ++constexpr int kMetadataElementsPerPackedValue = 8; ++constexpr int kMetadataElementsPerWarp = 16; ++} // namespace ++ ++mlir::LogicalResult SparseDotOp::inferReturnTypes( ++ MLIRContext *context, std::optional location, ValueRange operands, ++ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, ++ SmallVectorImpl &inferredReturnTypes) { ++ return DotOp::inferReturnTypes(context, location, operands, attributes, ++ properties, regions, inferredReturnTypes); ++} ++ ++LogicalResult SparseDotOp::verify() { ++ // Verify operand A. ++ auto aTensorTy = getOperand(0).getType().cast(); ++ auto aElemTy = aTensorTy.getElementType(); ++ if (!aElemTy.isF16() && !aElemTy.isBF16()) ++ return emitError("element type of operand A is not supported"); ++ auto aShape = aTensorTy.getShape(); ++ if (aShape.size() != 2) return emitError("shape of operand A is incorrect"); ++ ++ // Verify operand B. ++ auto bTensorTy = getOperand(1).getType().cast(); ++ auto bElemTy = bTensorTy.getElementType(); ++ if (!bElemTy.isF16() && !bElemTy.isBF16()) ++ return emitError("element type of operand B is not supported"); ++ auto bShape = bTensorTy.getShape(); ++ if (bShape.size() != 2) return emitError("shape of operand B is incorrect"); ++ ++ // Verify operand C. ++ auto cTensorTy = getOperand(2).getType().cast(); ++ auto cElemTy = cTensorTy.getElementType(); ++ if (!cElemTy.isF32()) ++ return emitError("element type of operand C is not supported"); ++ auto cShape = cTensorTy.getShape(); ++ if (cShape.size() != 2) return emitError("shape of operand C is incorrect"); ++ ++ // Check operand dependencies. ++ if (aShape[0] != cShape[0] || bShape[1] != cShape[1] || ++ bShape[0] != aShape[1] * kContractingFactor) ++ return emitError("operand shape dimensions are incorrect"); ++ if (aElemTy != bElemTy) ++ return emitError("operand element types do not match"); ++ ++ // Verify sparse metadata. ++ auto metaTy = getOperand(3).getType().cast(); ++ auto metaShape = metaTy.getShape(); ++ if (!metaTy.getElementType().isInteger(16) || metaShape.size() != 2) ++ return emitError("sparse metadata tensor is invalid"); ++ if (metaShape[0] != aShape[0] || ++ metaShape[1] * kMetadataElementsPerPackedValue != aShape[1]) ++ return emitError("sparse metadata shape dimensions are incorrect"); ++ ++ // Verify tensor encoding. ++ auto aEncoding = aTensorTy.getEncoding(); ++ auto bEncoding = bTensorTy.getEncoding(); ++ if (!aEncoding && !bEncoding) return mlir::success(); ++ if (!aEncoding || !bEncoding) ++ return emitError("mismatching encoding between A and B operands"); ++ ++ Dialect &dialect = aEncoding.getDialect(); ++ auto interface = cast(&dialect); ++ return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, ++ bEncoding); ++} ++ ++//--- SparseDotMetaEncodingAttr --- ++unsigned SparseDotMetaEncodingAttr::getTotalElemsPerThread( ++ ArrayRef shape, Type eltTy) const { ++ auto mmaLayout = getParent().cast(); ++ return product(shape) / ++ (mmaLayout.getWarpsPerCTA()[0] * kMetadataElementsPerWarp); ++} ++ ++SmallVector SparseDotMetaEncodingAttr::getElemsPerThread( ++ ArrayRef shape, Type eltTy) const { ++ llvm_unreachable("getElemsPerThread is not supported for sparse dot meta"); ++ return SmallVector(); ++} ++ ++SmallVector SparseDotMetaEncodingAttr::getCTAsPerCGA() const { ++ return ::getCTAsPerCGA(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getCTAOrder() const { ++ return ::getCTAOrder(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getCTASplitNum() const { ++ return ::getCTASplitNum(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getWarpsPerCTA() const { ++ return ::getWarpsPerCTA(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getWarpOrder() const { ++ return {1, 0}; ++} ++SmallVector SparseDotMetaEncodingAttr::getThreadsPerWarp() const { ++ return ::getThreadsPerWarp(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getThreadOrder() const { ++ return {1, 0}; ++} ++SmallVector SparseDotMetaEncodingAttr::getSizePerThread() const { ++ return ::getSizePerThread(getParent()); ++} ++SmallVector SparseDotMetaEncodingAttr::getShapePerCTATile( ++ ArrayRef tensorShape) const { ++ return ::getShapePerCTATile(getParent(), tensorShape); ++} ++ + } // namespace gpu + } // namespace triton + } // namespace mlir +diff --git a/test/SparseDot/convert_to_llvm_ampere.mlir b/test/SparseDot/convert_to_llvm_ampere.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/convert_to_llvm_ampere.mlir +@@ -0,0 +1,26 @@ ++// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=80 | FileCheck %s ++ ++#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> ++#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> ++#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> ++#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma0}> ++ ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot(%A: tensor<32x32xf16, #blocked0>, %B: tensor<64x32xf16, #blocked0>, %meta: tensor<32x4xi16, #blocked0>) { ++ // CHECK-COUNT-2: ldmatrix.sync.aligned.m8n8.x4.shared.b16 ++ %A_alloc = triton_gpu.local_alloc %A {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked0>) -> !tt.memdesc<32x32xf16, #shared0> ++ %A_dot = triton_gpu.local_load %A_alloc : !tt.memdesc<32x32xf16, #shared0> -> tensor<32x32xf16, #dot_operand_a> ++ // CHECK-COUNT-4: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 ++ %B_alloc = triton_gpu.local_alloc %B {allocation.offset = 2048 : i32} : (tensor<64x32xf16, #blocked0>) -> !tt.memdesc<64x32xf16, #shared0> ++ %B_dot = triton_gpu.local_load %B_alloc : !tt.memdesc<64x32xf16, #shared0> -> tensor<64x32xf16, #dot_operand_b> ++ // CHECK-COUNT-4: llvm.load %[[_:.*]] : !llvm.ptr<3> -> i16 ++ %meta_alloc = triton_gpu.local_alloc %meta {allocation.offset = 6144 : i32} : (tensor<32x4xi16, #blocked0>) -> !tt.memdesc<32x4xi16, #shared0> ++ %meta_reg = triton_gpu.local_load %meta_alloc : !tt.memdesc<32x4xi16, #shared0> -> tensor<32x4xi16, #dot_meta_enc> ++ // CHECK-COUNT-4: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 ++ %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma0> ++ %D = triton_gpu.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mma0> ++ tt.return ++ } ++} +diff --git a/test/SparseDot/convert_to_llvm_hopper.mlir b/test/SparseDot/convert_to_llvm_hopper.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/convert_to_llvm_hopper.mlir +@@ -0,0 +1,28 @@ ++// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 | FileCheck %s ++ ++#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++#shared1 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 64, 16]}> ++#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma0}> ++ ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot(%A: tensor<64x32xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>, %meta: tensor<64x4xi16, #blocked0>) { ++ %A_alloc = triton_gpu.local_alloc %A {allocation.offset = 0 : i32} : (tensor<64x32xf16, #blocked0>) -> !tt.memdesc<64x32xf16, #shared0> ++ %B_alloc = triton_gpu.local_alloc %B {allocation.offset = 4096 : i32} : (tensor<64x64xf16, #blocked0>) -> !tt.memdesc<64x64xf16, #shared0> ++ // CHECK-COUNT-2: llvm.load %[[_:.*]] : !llvm.ptr<3> -> i16 ++ %meta_alloc = triton_gpu.local_alloc %meta {allocation.offset = 12288 : i32} : (tensor<64x4xi16, #blocked0>) -> !tt.memdesc<64x4xi16, #shared0> ++ %meta_reg = triton_gpu.local_load %meta_alloc : !tt.memdesc<64x4xi16, #shared0> -> tensor<64x4xi16, #dot_meta_enc> ++ // CHECK: nvgpu.wgmma_fence ++ // CHECK-COUNT-2: nvgpu.wgmma_sp %[[A:.*]] meta %[[M:.*]], %[[B:.*]], %[[C:.*]] { ++ // CHECK-DAG: layoutA = 0 : i32 ++ // CHECK-DAG: layoutB = 0 : i32 ++ // CHECK-DAG: m = 64 : i32 ++ // CHECK-DAG: n = 64 : i32 ++ // CHECK-DAG: k = 32 : i32 ++ // CHECK: nvgpu.wgmma_commit_group ++ %acc = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0> ++ %D = triton_gpu.sparse_dot %A_alloc, %B_alloc, %acc, %meta_reg : !tt.memdesc<64x32xf16, #shared0> meta tensor<64x4xi16, #dot_meta_enc> * !tt.memdesc<64x64xf16, #shared0> -> tensor<64x64xf32, #mma0> ++ tt.return ++ } ++} +diff --git a/test/SparseDot/validation.mlir b/test/SparseDot/validation.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/validation.mlir +@@ -0,0 +1,129 @@ ++// RUN: triton-opt --split-input-file --verify-diagnostics %s ++ ++tt.func @sparse_dot(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_lhs_type(%lhs: tensor<128x32xf32>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{element type of operand A is not supported}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xf32> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_lhs_shape(%lhs: tensor<1x128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{shape of operand A is incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<1x128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_rhs_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xf32>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{element type of operand B is not supported}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xf32> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_rhs_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<1x64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{shape of operand B is incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<1x64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_acc_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xbf16> ++ // expected-error @+1 {{element type of operand C is not supported}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xbf16> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_acc_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<16384xf32> ++ // expected-error @+1 {{shape of operand C is incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<16384xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_lhs_acc(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<64x128xf32> ++ // expected-error @+1 {{operand shape dimensions are incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<64x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_rhs_acc(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x64xf32> ++ // expected-error @+1 {{operand shape dimensions are incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x64xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_lhs_rhs(%lhs: tensor<128x32xbf16>, %rhs: tensor<32x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{operand shape dimensions are incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<32x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_input_types(%lhs: tensor<128x32xf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{operand element types do not match}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_meta_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi8>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{sparse metadata tensor is invalid}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi8> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_invalid_meta_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<512xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{sparse metadata tensor is invalid}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<512xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_meta_noncontracting(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<64x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{sparse metadata shape dimensions are incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<64x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++tt.func @sparse_dot_mismatch_meta_contracting(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x8xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{sparse metadata shape dimensions are incorrect}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x8xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} ++ ++// ----- ++#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], instrShape = [16, 8]}> ++#enc0 = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> ++tt.func @sparse_dot_encoding_operand_mismatch(%lhs: tensor<128x32xbf16, #enc0>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) { ++ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32> ++ // expected-error @+1 {{mismatching encoding between A and B operands}} ++ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16, #enc0> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32> ++ tt.return ++} +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +@@ -38,6 +38,14 @@ Value convertLayout(int opIdx, Conversio + const LLVMTypeConverter *typeConverter, Value thread); + } + ++namespace SharedToSparseDotOperand { ++Value convertLayout( ++ ConversionPatternRewriter &rewriter, Location loc, Value tensor, ++ triton::gpu::SparseDotMetaEncodingAttr sparseEncoding, ++ const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, ++ Value thread); ++} // namespace SharedToSparseDotOperand ++ + namespace { + + struct LocalLoadOpConversion +@@ -59,6 +67,10 @@ public: + .isa()) { + return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter); + } ++ if (srcLayout.isa() && ++ dstLayout.isa()) { ++ return lowerSharedToSparseMeta(op, adaptor, getTypeConverter(), rewriter); ++ } + return failure(); + } + +@@ -130,6 +142,29 @@ private: + rewriter.replaceOp(op, res); + return success(); + } ++ ++ // shared -> sparse dot meta ++ LogicalResult lowerSharedToSparseMeta( ++ triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, ++ const LLVMTypeConverter *typeConverter, ++ ConversionPatternRewriter &rewriter) const { ++ auto loc = op.getLoc(); ++ auto sparseEncoding = op.getResult() ++ .getType() ++ .cast() ++ .getEncoding() ++ .cast(); ++ auto llvmElemTy = typeConverter->convertType( ++ op.getSrc().getType().cast().getElementType()); ++ auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), ++ llvmElemTy, rewriter); ++ Value res = SharedToSparseDotOperand::convertLayout( ++ rewriter, loc, op.getSrc(), sparseEncoding, smemObj, typeConverter, ++ getThreadId(rewriter, loc)); ++ ++ rewriter.replaceOp(op, res); ++ return success(); ++ } + }; + + struct ConvertLayoutOpOptimizedConversion +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToSparseDotOperand.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToSparseDotOperand.cpp +new file mode 100644 +--- /dev/null ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToSparseDotOperand.cpp +@@ -0,0 +1,69 @@ ++#include "../Utility.h" ++ ++namespace SharedToSparseDotOperand { ++namespace { ++constexpr int kThreadsPerWarp = 32; ++ ++// Each 16x16 original sparse matrix tile requires 16 metadata values of 16-bit ++// size, where the first thread (T0) in each 4-thread group holds two such ++// values in a register (32-bit). ++// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage ++constexpr int kTileSize = 16; ++constexpr int kThreadsInGroup = 4; ++constexpr int kMetadataElementsPerPackedValue = 8; // 8 x 2-bit = 16-bit ++constexpr int kMetadataLineOffset = kThreadsPerWarp / kThreadsInGroup; ++} // namespace ++ ++Value convertLayout( ++ ConversionPatternRewriter &rewriter, Location loc, Value tensor, ++ triton::gpu::SparseDotMetaEncodingAttr sparseEncoding, ++ const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, ++ Value thread) { ++ // Calculate tile size as number of mask elements (4xi4). ++ NvidiaMmaEncodingAttr mmaLayout = ++ sparseEncoding.getParent().cast(); ++ SmallVector shapePerCTATile = { ++ kTileSize * mmaLayout.getWarpsPerCTA()[0], ++ kTileSize / kMetadataElementsPerPackedValue}; ++ Value strideM = smemObj.strides[0]; ++ Value strideK = smemObj.strides[1]; ++ ++ // Calculate offset in the tile for the current thread. ++ Value threadsPerWarp = i32_val(kThreadsPerWarp); ++ Value warpId = udiv(thread, threadsPerWarp); ++ Value warpGroupId = urem(warpId, i32_val(shapePerCTATile[0] / kTileSize)); ++ Value laneId = urem(thread, threadsPerWarp); ++ Value laneGroupId = udiv(laneId, i32_val(kThreadsInGroup)); ++ Value columnId = urem(laneId, i32_val(shapePerCTATile[1])); ++ Value rowId = add(mul(warpGroupId, i32_val(kTileSize)), laneGroupId); ++ ++ // Calculate number of tile repetitions. ++ auto shape = tensor.getType().cast().getShape(); ++ int repM = shape[0] / shapePerCTATile[0]; ++ int repK = shape[1] / shapePerCTATile[1]; ++ assert(repM > 0 && repK > 0); ++ ++ // Load sparse metadata from shared memory. ++ MLIRContext *ctx = tensor.getContext(); ++ Type ptrTy = ptr_ty(ctx, 3); ++ Value base = gep(ptrTy, i16_ty, smemObj.base, i32_val(0)); ++ SmallVector values; ++ ++ for (int k = 0; k < repK; ++k) { ++ for (int m = 0; m < repM; ++m) { ++ Value row = add(rowId, i32_val(m * shapePerCTATile[0])); ++ Value column = add(columnId, i32_val(k * shapePerCTATile[1])); ++ Value offset1 = add(mul(row, strideM), mul(column, strideK)); ++ Value offset2 = add(offset1, mul(i32_val(kMetadataLineOffset), strideM)); ++ Value lower = load(i16_ty, gep(ptrTy, i16_ty, base, offset1)); ++ Value upper = load(i16_ty, gep(ptrTy, i16_ty, base, offset2)); ++ values.push_back(lower); ++ values.push_back(upper); ++ } ++ } ++ ++ // Pack resulting values as LLVM struct. ++ Type structTy = struct_ty(SmallVector(values.size(), i16_ty)); ++ return packLLElements(loc, typeConverter, values, rewriter, structTy); ++} ++} // namespace SharedToSparseDotOperand +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +@@ -32,6 +32,12 @@ LogicalResult convertAsyncWGMMA(triton:: + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value thread); ++ ++LogicalResult rewriteSparseDotOp(triton::gpu::SparseDotOp op, ++ triton::gpu::SparseDotOp::Adaptor adaptor, ++ const LLVMTypeConverter *typeConverter, ++ ConversionPatternRewriter &rewriter); ++ + namespace { + struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +@@ -180,6 +186,18 @@ struct DotWaitOpConversion + return success(); + } + }; ++ ++struct SparseDotOpConversion ++ : public ConvertOpToLLVMPattern { ++ using ConvertOpToLLVMPattern< ++ triton::gpu::SparseDotOp>::ConvertOpToLLVMPattern; ++ ++ LogicalResult matchAndRewrite( ++ triton::gpu::SparseDotOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const override { ++ return rewriteSparseDotOp(op, adaptor, getTypeConverter(), rewriter); ++ } ++}; + } // namespace + + void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( +@@ -188,4 +206,5 @@ void mlir::triton::NVIDIA::populateDotOp + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); ++ patterns.add(typeConverter, benefit); + } +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/Sparse.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/Sparse.cpp +new file mode 100644 +--- /dev/null ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/Sparse.cpp +@@ -0,0 +1,337 @@ ++#include "../Utility.h" ++ ++using namespace mlir; ++using namespace mlir::triton; ++using namespace mlir::triton::gpu; ++ ++using ::mlir::LLVM::getSharedMemoryObjectFromStruct; ++using ::mlir::triton::gpu::getShapePerCTA; ++using ::mlir::triton::gpu::getShapePerCTATile; ++using ::mlir::triton::gpu::SharedEncodingAttr; ++ ++using ValueTableV2 = std::map, Value>; ++ ++namespace { ++constexpr int kContractingFactor = 2; // implied by N:M (2:4) ++constexpr int kCore = 2; // number of core matrices per batch ++constexpr int kCoreTile = kCore * kContractingFactor; ++} // namespace ++ ++// ----- Ampere implementation. ++ ++ValueTableV2 getValuesFromDotOperandLayoutStruct(SmallVector elems, ++ int n0, int n1) { ++ int offset = 0; ++ ValueTableV2 vals; ++ for (int i = 0; i < n0; ++i) { ++ for (int j = 0; j < n1; ++j) { ++ vals[{kCore * i, kCore * j}] = elems[offset++]; ++ vals[{kCore * i, kCore * j + 1}] = elems[offset++]; ++ vals[{kCore * i + 1, kCore * j}] = elems[offset++]; ++ vals[{kCore * i + 1, kCore * j + 1}] = elems[offset++]; ++ } ++ } ++ return vals; ++} ++ ++std::string getMmaSpPtxInstruction(Type type) { ++ if (type.isF16()) { ++ return "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32"; ++ } else if (type.isBF16()) { ++ return "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32"; ++ } ++ llvm::report_fatal_error("Unsupported SparseDotOp operand type"); ++} ++ ++LogicalResult convertSparseMMA(SparseDotOp op, ++ SparseDotOp::Adaptor adaptor, ++ const LLVMTypeConverter *typeConverter, ++ ConversionPatternRewriter &rewriter) { ++ // Get number of repetitions across the dimensions. ++ auto aTensorTy = op.getA().getType().cast(); ++ auto bTensorTy = op.getB().getType().cast(); ++ ++ auto layoutA = aTensorTy.getEncoding().dyn_cast(); ++ auto layoutB = bTensorTy.getEncoding().dyn_cast(); ++ assert(layoutA != nullptr && layoutB != nullptr); ++ ++ int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); ++ auto mmaEnc = layoutA.getParent().cast(); ++ auto repA = mmaEnc.getMMAv2Rep(triton::gpu::getShapePerCTA(aTensorTy), ++ bitwidth, layoutA.getOpIdx()); ++ auto repB = mmaEnc.getMMAv2Rep(triton::gpu::getShapePerCTA(bTensorTy), ++ bitwidth, layoutB.getOpIdx()); ++ ++ assert(repA[0] == 1 && repB[0] == 1); // batch size ++ assert(repB[1] == repA[2] * kContractingFactor); ++ int repM = repA[1], repN = repB[2], repK = repB[1]; ++ ++ // Arrange loaded values into positions. ++ Location loc = op.getLoc(); ++ auto ha = getValuesFromDotOperandLayoutStruct( ++ unpackLLElements(loc, adaptor.getA(), rewriter), repM, ++ repK / kContractingFactor); ++ auto hb = getValuesFromDotOperandLayoutStruct( ++ unpackLLElements(loc, adaptor.getB(), rewriter), ++ std::max(repN / kCore, 1), repK); ++ ++ // Combine loaded metadata values. ++ auto hMeta = unpackLLElements(loc, adaptor.getAMeta(), rewriter); ++ SmallVector hMetaPacked; ++ for (int i = 0; i < hMeta.size(); i += kCore) { ++ Value lower = zext(i32_ty, hMeta[i]); ++ Value upper = zext(i32_ty, hMeta[i + 1]); ++ Value packed = or_(shl(upper, i32_val(16)), lower); ++ hMetaPacked.push_back(packed); ++ } ++ ++ // Flatten accumulator values. ++ auto dTensorTy = op.getD().getType().cast(); ++ auto fc = unpackLLElements(loc, adaptor.getC(), rewriter); ++ ++ // Create `mma.sp` instruction for 4/8 core matrices. ++ auto callMma = [&](unsigned m, unsigned n, unsigned k) { ++ PTXBuilder builder; ++ auto &mma = ++ *builder.create(getMmaSpPtxInstruction(aTensorTy.getElementType())); ++ ++ auto retArgs = builder.newListOperand(kCoreTile, "=f"); ++ auto cArgs = builder.newListOperand(); ++ int baseIdx = m * repN * kCore + n * kCoreTile; ++ for (int i = 0; i < kCoreTile; ++i) { ++ cArgs->listAppend(builder.newOperand(fc[baseIdx + i], std::to_string(i))); ++ } ++ int i = k / kContractingFactor; ++ auto aArgs = builder.newListOperand({ ++ {ha[{m, i}], "r"}, ++ {ha[{m + 1, i}], "r"}, ++ {ha[{m, i + 1}], "r"}, ++ {ha[{m + 1, i + 1}], "r"}, ++ }); ++ auto bArgs = builder.newListOperand({ ++ {hb[{n, k}], "r"}, ++ {hb[{n, k + 1}], "r"}, ++ {hb[{n, k + 2}], "r"}, ++ {hb[{n, k + 3}], "r"}, ++ }); ++ auto metaArg = ++ builder.newOperand(hMetaPacked[k / kCoreTile * repM + m / kCore], "r"); ++ auto selector = builder.newConstantOperand(0); ++ mma(retArgs, aArgs, bArgs, cArgs, metaArg, selector); ++ ++ Type fp32x4Ty = LLVM::LLVMStructType::getLiteral( ++ op.getContext(), SmallVector(kCoreTile, f32_ty)); ++ Value mmaOut = builder.launch(rewriter, loc, fp32x4Ty); ++ for (int i = 0; i < kCoreTile; ++i) { ++ fc[baseIdx + i] = extract_val(f32_ty, mmaOut, i); ++ } ++ }; ++ ++ for (int k = 0; k < repK; k += kContractingFactor) ++ for (int m = 0; m < repM; ++m) ++ for (int n = 0; n < repN; ++n) callMma(kCore * m, n, kCore * k); ++ ++ // Replace with new packed result. ++ Type structTy = LLVM::LLVMStructType::getLiteral( ++ op.getContext(), SmallVector(fc.size(), f32_ty)); ++ Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); ++ rewriter.replaceOp(op, res); ++ ++ return success(); ++} ++ ++// ----- Hopper implementation. ++ ++// Forward declarations. ++Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, ++ int64_t swizzling, uint32_t stride); ++int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout, ++ uint32_t widthInByte); ++triton::nvgpu::WGMMAEltType getMmaRetType(Value); ++triton::nvgpu::WGMMAEltType getMmaOperandType(Value, bool); ++ ++namespace { ++constexpr int kThreadsPerWarp = 32; ++constexpr int kWarpsInGroup = 4; ++constexpr int kMmaAccumulatorCount = 2; ++constexpr int kMmaLineSize = 128; ++constexpr int kMmaAlignment = 16; ++} // namespace ++ ++// Shared memory descriptor builder for WGMMA. ++Value smemDescriptor(int a, int b, ConversionPatternRewriter &rewriter, ++ Location loc, std::vector instrShape, ++ bool trans, int dimWpt, Value warpId, MemDescType tensorTy, ++ Value baseDesc, int minor) { ++ auto sharedLayout = tensorTy.getEncoding().cast(); ++ int elemBytes = tensorTy.getElementTypeBitWidth() / 8; ++ int elemsPerSwizzlingRow = ++ kMmaLineSize / sharedLayout.getPerPhase() / elemBytes; ++ Value elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow); ++ ++ Value k = i32_val(b * instrShape[1]); ++ Value m = add(i32_val(a * dimWpt * instrShape[0]), ++ mul(warpId, i32_val(instrShape[0]))); ++ if (trans) { ++ std::swap(k, m); ++ } ++ Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal), ++ i32_val(minor * elemsPerSwizzlingRow)); ++ Value stride_offset = mul(m, elemsPerSwizzlingRowVal); ++ Value offset = ++ add(add(leading_offset, stride_offset), urem(k, elemsPerSwizzlingRowVal)); ++ Value off1 = mul(i32_val(elemBytes), offset); ++ Value off_ = zext(i64_ty, udiv(off1, i32_val(kMmaAlignment))); ++ ++ return add(baseDesc, off_); ++} ++ ++LogicalResult convertSparseWGMMA(SparseDotOp op, ++ SparseDotOp::Adaptor adaptor, ++ const LLVMTypeConverter *typeConverter, ++ ConversionPatternRewriter &rewriter, ++ Value thread) { ++ // Get number of repetitions across the dimensions. ++ auto aTensorTy = op.getA().getType().cast(); ++ auto bTensorTy = op.getB().getType().cast(); ++ auto dTensorTy = op.getD().getType().cast(); ++ auto mmaEnc = dTensorTy.getEncoding().cast(); ++ ++ auto shapePerCTA = getShapePerCTA(dTensorTy); ++ auto shapePerCTATile = getShapePerCTATile(mmaEnc); ++ auto instrShape = mmaEnc.getInstrShape(); ++ int repM = ceil(shapePerCTA[0], shapePerCTATile[0]); ++ int repN = ceil(shapePerCTA[1], shapePerCTATile[1]); ++ int repK = ceil(bTensorTy.getShape()[0], ++ instrShape[2] * kContractingFactor); ++ ++ // Flatten accumulator values. ++ auto loc = op.getLoc(); ++ auto fc = unpackLLElements(loc, adaptor.getC(), rewriter); ++ int accSize = kMmaAccumulatorCount * (instrShape[1] / kWarpsInGroup); ++ assert(fc.size() == repM * repN * accSize); ++ ++ // Get warp ID. ++ auto wpt = mmaEnc.getWarpsPerCTA(); ++ Value warp = ++ and_(udiv(thread, i32_val(kThreadsPerWarp)), i32_val(0xFFFFFFFC)); ++ Value warpM = urem(warp, i32_val(wpt[0])); ++ Value warpMN = udiv(warp, i32_val(wpt[0])); ++ Value warpN = urem(warpMN, i32_val(wpt[1])); ++ ++ // Create descriptor. ++ auto getSharedData = [&](Value arg, MemDescType tensorTy) { ++ auto sharedObj = getSharedMemoryObjectFromStruct( ++ loc, arg, typeConverter->convertType(tensorTy.getElementType()), ++ rewriter); ++ auto sharedLayout = tensorTy.getEncoding().cast(); ++ auto shape = getShapePerCTA(tensorTy); ++ auto ord = sharedLayout.getOrder(); ++ int byteSize = aTensorTy.getElementTypeBitWidth() / 8; ++ int64_t swizzling = ++ getSwizzlingFromLayout(sharedLayout, shape[ord[0]] * byteSize); ++ Value baseDesc = createDescriptor(rewriter, loc, swizzling, shape[ord[1]]); ++ return std::make_tuple(shape, ord, baseDesc); ++ }; ++ ++ // Create descriptor for loading A from shared memory. ++ auto tA = getSharedData(adaptor.getA(), aTensorTy); ++ Value warpA = urem(warpM, i32_val(std::get<0>(tA)[0] / instrShape[0])); ++ bool transA = std::get<1>(tA)[0] == 0; ++ auto loadA = [&](int m, int k) { ++ return smemDescriptor(m, k, rewriter, loc, {instrShape[0], instrShape[2]}, ++ transA, wpt[0], warpA, aTensorTy, std::get<2>(tA), ++ std::get<0>(tA)[std::get<1>(tA)[1]]); ++ }; ++ ++ // Create descriptor for loading B from shared memory. ++ auto tB = getSharedData(adaptor.getB(), bTensorTy); ++ Value warpB = urem(warpN, i32_val(std::get<0>(tB)[1] / instrShape[1])); ++ bool transB = std::get<1>(tB)[0] == 1; ++ auto loadB = [&](int n, int k) { ++ return smemDescriptor(n, k, rewriter, loc, ++ {instrShape[1], instrShape[2] * kContractingFactor}, ++ transB, wpt[1], warpB, bTensorTy, std::get<2>(tB), ++ std::get<0>(tB)[std::get<1>(tB)[1]]); ++ }; ++ ++ // Load metadata from shared memory. ++ auto hMeta = unpackLLElements(loc, adaptor.getAMeta(), rewriter); ++ SmallVector hMetaPacked; ++ for (int i = 0; i < hMeta.size(); i += kCore) { ++ Value lower = zext(i32_ty, hMeta[i]); ++ Value upper = zext(i32_ty, hMeta[i + 1]); ++ Value packed = or_(shl(upper, i32_val(16)), lower); ++ hMetaPacked.push_back(packed); ++ } ++ assert(hMetaPacked.size() == repM * repK); ++ ++ // Generate prologue. ++ triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false); ++ triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false); ++ triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(op.getD()); ++ ++ triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col ++ : triton::nvgpu::WGMMALayout::row; ++ triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row ++ : triton::nvgpu::WGMMALayout::col; ++ ++ rewriter.create(loc, 0); ++ rewriter.create(loc); ++ ++ // Generate main loop. ++ for (int m = 0; m < repM; ++m) { ++ for (int n = 0; n < repN; ++n) { ++ llvm::MutableArrayRef acc(&fc[(m * repN + n) * accSize], accSize); ++ auto accTy = LLVM::LLVMStructType::getLiteral( ++ op.getContext(), SmallVector(accSize, f32_ty)); ++ Value d = packLLElements(loc, typeConverter, acc, rewriter, accTy); ++ for (int k = 0; k < repK; ++k) { ++ Value a = loadA(m, k); ++ Value b = loadB(n, k); ++ Value meta = hMetaPacked[k * repM + m]; ++ d = rewriter.create( ++ loc, accTy, a, meta, b, d, kWarpsInGroup * instrShape[0], ++ instrShape[1], kContractingFactor * instrShape[2], eltTypeC, ++ eltTypeA, eltTypeB, layoutA, layoutB); ++ } ++ auto res = unpackLLElements(loc, d, rewriter); ++ for (int i = 0; i < res.size(); ++i) { ++ acc[i] = res[i]; ++ } ++ } ++ } ++ ++ // Replace with new packed result. ++ Type structTy = LLVM::LLVMStructType::getLiteral( ++ op.getContext(), SmallVector(fc.size(), f32_ty)); ++ Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); ++ ++ rewriter.create(loc); ++ res = rewriter.create(loc, res, 0); ++ rewriter.replaceOp(op, res); ++ ++ return success(); ++} ++ ++// ----- Dispatch based on architecture. ++ ++LogicalResult rewriteSparseDotOp(SparseDotOp op, ++ SparseDotOp::Adaptor adaptor, ++ const LLVMTypeConverter *typeConverter, ++ ConversionPatternRewriter &rewriter) { ++ auto resultTy = op.getResult().getType().cast(); ++ NvidiaMmaEncodingAttr mmaLayout = ++ resultTy.getEncoding().cast(); ++ ++ if (mmaLayout.isAmpere()) { ++ return convertSparseMMA(op, adaptor, typeConverter, rewriter); ++ } ++ if (mmaLayout.isHopper()) { ++ return convertSparseWGMMA(op, adaptor, typeConverter, rewriter, ++ getThreadId(rewriter, op.getLoc())); ++ } ++ ++ llvm::report_fatal_error( ++ "Unsupported SparseDotOp found when converting TritonGPU to LLVM."); ++} +diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +@@ -87,8 +87,8 @@ int64_t getSwizzlingFromLayout(const Sha + return swizzlingByteWidth; + } + +-static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, +- int64_t swizzling, uint32_t stride) { ++Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, ++ int64_t swizzling, uint32_t stride) { + // Create descriptor based on the format described in the spec: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor + union WGMMADescriptor { diff --git a/third_party/triton/sparse_dot_nvgpu.patch b/third_party/triton/sparse_dot_nvgpu.patch new file mode 100644 index 0000000000000..b96aeacced574 --- /dev/null +++ b/third_party/triton/sparse_dot_nvgpu.patch @@ -0,0 +1,136 @@ +diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +--- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td ++++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +@@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", [] + let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; + } + ++def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> { ++ let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, ++ I32Attr:$m, I32Attr:$n, I32Attr:$k, ++ WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, ++ WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); ++ let results = (outs LLVM_AnyStruct:$res); ++ let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; ++} ++ + def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> { + let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec); + let builders = [ +diff --git a/test/SparseDot/test_wgmma_sp.mlir b/test/SparseDot/test_wgmma_sp.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/test_wgmma_sp.mlir +@@ -0,0 +1,14 @@ ++// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s ++ ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @wgmma_sp(%descA: i64, %metaA: i32, %descB: i64, %acc: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) { ++ // CHECK: @wgmma_sp(%[[LHS:.*]]: i64, %[[META:.*]]: i32, %[[RHS:.*]]: i64, ++ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] ++ // CHECK-SAME: "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7}, $16, $17, $18, 0, 1, 1, 1, 0, 0;" ++ // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,l,l,r" %0, %1, %2, %3, %4, %5, %6, %7, %[[LHS]], %[[RHS]], %[[META]] ++ %acc0 = nvgpu.wgmma_sp %descA meta %metaA, %descB, %acc ++ {eltTypeA = 5 : i32, eltTypeB = 5 : i32, eltTypeC = 7 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 16 : i32, k = 32 : i32} : ++ (i64, i32, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> ++ tt.return ++ } ++} +diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +--- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp ++++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +@@ -688,6 +688,84 @@ public: + } + }; + ++class SparseWGMMAOpPattern ++ : public NVGPUOpPatternBase { ++public: ++ using Base = NVGPUOpPatternBase; ++ using Base::Base; ++ ++ std::vector getOutputConstraints(ttn::SparseWGMMAOp op) const { ++ auto outputStructType = op.getType().cast(); ++ uint32_t numOutputRegs = outputStructType.getBody().size(); ++ std::string output = ++ outputStructType.getBody().front().isF32() ? "=f" : "=r"; ++ return std::vector(numOutputRegs, output); ++ } ++ ++ OperandsAndConstraints getOperandsAndConstraints( ++ ttn::SparseWGMMAOp op) const { ++ return {{op.getOpC(), "0"}, {op.getOpA(), "l"}, {op.getOpB(), "l"}, ++ {op.getMetaA(), "r"}}; ++ } ++ ++ std::string getPtxAsm(ttn::SparseWGMMAOp op) const { ++ using namespace ttn; ++ auto opA = op.getOpA(); ++ auto opB = op.getOpB(); ++ auto m = op.getM(); ++ auto n = op.getN(); ++ auto k = op.getK(); ++ auto eltTypeC = op.getEltTypeC(); ++ auto eltTypeA = op.getEltTypeA(); ++ auto eltTypeB = op.getEltTypeB(); ++ auto layoutA = op.getLayoutA(); ++ auto layoutB = op.getLayoutB(); ++ ++ // Only f16/bf16 variant is supported. ++ bool supported = ++ eltTypeC == WGMMAEltType::f32 && ++ ((eltTypeA == WGMMAEltType::f16 && eltTypeB == WGMMAEltType::f16) || ++ (eltTypeA == WGMMAEltType::bf16 && eltTypeB == WGMMAEltType::bf16)) && ++ (m == 64 && 8 <= n && n <= 256 && n % 8 == 0 && k == 32); ++ assert(supported && "Sparse WGMMA type or shape is not supported"); ++ ++ // Operands ++ uint32_t asmOpIdx = 0; ++ std::string args = ""; ++ ++ // Output and operand C ++ uint32_t numCRegs = ++ op.getType().cast().getBody().size(); ++ args += "{"; ++ for (uint32_t i = 0; i < numCRegs; ++i) { ++ args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); ++ } ++ args += "}, "; ++ asmOpIdx += numCRegs; ++ ++ // Operands A and B (must be `desc`) ++ args += "$" + std::to_string(asmOpIdx++) + ", "; ++ args += "$" + std::to_string(asmOpIdx++) + ", "; ++ ++ // Metadata for A ++ args += "$" + std::to_string(asmOpIdx++) + ", 0, "; ++ ++ // `scale-d`, `imm-scale-a`, and `imm-scale-b` are 1 by default ++ args += "1, 1, 1"; ++ ++ // `trans-a` and `trans-b` ++ args += ", " + std::to_string(layoutA == WGMMALayout::col); ++ args += ", " + std::to_string(layoutB == WGMMALayout::row); ++ ++ auto ptxAsm = "wgmma.mma_async.sp.sync.aligned" ++ ".m" + std::to_string(m) + "n" + std::to_string(n) + "k" + ++ std::to_string(k) + "." + stringifyEnum(eltTypeC).str() + ++ "." + stringifyEnum(eltTypeA).str() + "." + ++ stringifyEnum(eltTypeB).str() + " " + args + ";"; ++ return ptxAsm; ++ } ++}; ++ + class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { + + public: +@@ -711,7 +789,8 @@ public: + + patterns.add(context); ++ WGMMAWaitGroupOpPattern, StoreDSmemOpPattern, ++ SparseWGMMAOpPattern>(context); + + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) + signalPassFailure(); diff --git a/third_party/triton/sparse_dot_passes.patch b/third_party/triton/sparse_dot_passes.patch new file mode 100644 index 0000000000000..e610f67cd7030 --- /dev/null +++ b/third_party/triton/sparse_dot_passes.patch @@ -0,0 +1,591 @@ +diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp ++++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +@@ -277,6 +277,89 @@ struct TritonDotPattern : public OpConve + } + }; + ++struct TritonSparseDotPattern ++ : public OpConversionPattern { ++ using OpConversionPattern::OpConversionPattern; ++ ++ LogicalResult matchAndRewrite( ++ triton::gpu::SparseDotOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const override { ++ RankedTensorType origType = op.getType().cast(); ++ auto origShape = origType.getShape(); ++ auto typeConverter = getTypeConverter(); ++ int numWarps = typeConverter->getNumWarps(); ++ int threadsPerWarp = typeConverter->getThreadsPerWarp(); ++ int numCTAs = typeConverter->getNumCTAs(); ++ ++ auto rank = origShape.size(); ++ auto numElements = product(origShape); ++ SmallVector retSizePerThread(rank, 1); ++ if (numElements / (numWarps * threadsPerWarp) >= 4) { ++ retSizePerThread[rank - 1] = 2; ++ retSizePerThread[rank - 2] = 2; ++ } ++ if (numElements / (numWarps * threadsPerWarp) >= 16) { ++ retSizePerThread[rank - 1] = 4; ++ retSizePerThread[rank - 2] = 4; ++ } ++ SmallVector retOrder(rank); ++ for (unsigned i = 0; i < rank; ++i) ++ retOrder[i] = rank - 1 - i; ++ Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( ++ getContext(), origShape, retSizePerThread, retOrder, numWarps, ++ threadsPerWarp, numCTAs); ++ RankedTensorType retType = ++ RankedTensorType::get(origShape, origType.getElementType(), dEncoding); ++ ++ // a & b must be of smem layout ++ auto aType = adaptor.getA().getType().cast(); ++ auto bType = adaptor.getB().getType().cast(); ++ Type aEltType = aType.getElementType(); ++ Type bEltType = bType.getElementType(); ++ Attribute aEncoding = aType.getEncoding(); ++ Attribute bEncoding = bType.getEncoding(); ++ if (!aEncoding || !bEncoding) ++ return failure(); ++ Value a = adaptor.getA(); ++ Value b = adaptor.getB(); ++ Value c = adaptor.getC(); ++ if (!aEncoding.isa()) { ++ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( ++ getContext(), 0, dEncoding, aEltType); ++ auto dstType = ++ RankedTensorType::get(aType.getShape(), aEltType, encoding); ++ a = rewriter.create(a.getLoc(), dstType, a); ++ } ++ if (!bEncoding.isa()) { ++ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( ++ getContext(), 1, dEncoding, bEltType); ++ auto dstType = ++ RankedTensorType::get(bType.getShape(), bEltType, encoding); ++ b = rewriter.create(b.getLoc(), dstType, b); ++ } ++ c = rewriter.create(c.getLoc(), retType, c); ++ ++ // aMeta must be of smem layout ++ auto aMetaType = adaptor.getAMeta().getType().cast(); ++ Attribute aMetaEncoding = aMetaType.getEncoding(); ++ if (!aMetaEncoding) return failure(); ++ Value aMeta = adaptor.getAMeta(); ++ if (!aMetaEncoding.isa()) { ++ Attribute encoding = ++ triton::gpu::SparseDotMetaEncodingAttr::get(getContext(), dEncoding); ++ auto dstType = RankedTensorType::get( ++ aMetaType.getShape(), aMetaType.getElementType(), encoding); ++ aMeta = rewriter.create(aMeta.getLoc(), ++ dstType, aMeta); ++ } ++ ++ addNamedAttrs(rewriter.replaceOpWithNewOp( ++ op, retType, a, b, c, aMeta), ++ adaptor.getAttributes()); ++ return success(); ++ } ++}; ++ + struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +@@ -550,6 +633,7 @@ void populateTritonPatterns(TritonGPUTyp + GenericOpPattern, GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); ++ patterns.insert(typeConverter, context); + } + + // +@@ -788,6 +872,12 @@ public: + IntegerAttr::get( + i32_ty, llvm::APInt(32, computeCapability.getValue()))); + ++ // Only transform sparse dot op with undefined layout. ++ target.addDynamicallyLegalOp( ++ [](triton::gpu::SparseDotOp op) { ++ return op.getAMeta().getType().getEncoding() != nullptr; ++ }); ++ + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + +diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +@@ -42,8 +42,9 @@ static int getMMAVersionSafe(int compute + return 0; + } + ++template + SmallVector +-warpsPerTileV2(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { ++warpsPerTileV2(DotType dotOp, const ArrayRef shape, int numWarps) { + auto rank = shape.size(); + // Early exit for batched matmul + if (rank == 3) +@@ -56,14 +57,14 @@ warpsPerTileV2(tt::DotOp dotOp, const Ar + auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); + bool hasChainedDot = false; + for (Operation *op : slices) { +- if (isa(op) && (op != dotOp)) { +- auto chainedDot = cast(op); ++ if (isa(op) && (op != dotOp)) { ++ auto chainedDot = cast(op); + auto resTy = chainedDot.getResult().getType(); + if (resTy.getRank() != rank) { + continue; + } + if (auto mmaEncoding = +- resTy.getEncoding().dyn_cast()) { ++ resTy.getEncoding().template dyn_cast()) { + return ttg::getWarpsPerCTA(mmaEncoding); + } + hasChainedDot = true; +@@ -101,12 +102,13 @@ warpsPerTileV2(tt::DotOp dotOp, const Ar + return ret; + } + +-SmallVector +-warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, +- const SmallVector &instrShape) { ++template ++SmallVector warpsPerTileV3( ++ DotType dotOp, const ArrayRef shape, int numWarps, ++ const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); +- if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != ++ if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != + slices.end()) + return {(unsigned)numWarps, 1}; + +@@ -175,9 +177,10 @@ public: + : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} + +- static SmallVector +- getWarpsPerTile(tt::DotOp dotOp, const ArrayRef shape, int version, +- int numWarps, const SmallVector &instrShape) { ++ template ++ static SmallVector getWarpsPerTile( ++ DotType dotOp, const ArrayRef shape, int version, int numWarps, ++ const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); +@@ -337,6 +340,97 @@ public: + return success(); + } + }; ++ ++class SparseBlockedToMMA : public mlir::RewritePattern { ++ public: ++ using SparseDotOp = mlir::triton::gpu::SparseDotOp; ++ using SparseDotMetaEncodingAttr = ++ mlir::triton::gpu::SparseDotMetaEncodingAttr; ++ ++ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability) ++ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context), ++ computeCapability(computeCapability) {} ++ ++ mlir::LogicalResult matchAndRewrite( ++ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { ++ auto dotOp = cast(op); ++ auto ctx = op->getContext(); ++ Value a = dotOp.getA(); ++ Value b = dotOp.getB(); ++ ++ // Check data-types and SM compatibility ++ RankedTensorType oldRetType = dotOp.getType(); ++ if (!oldRetType.getEncoding() || ++ oldRetType.getEncoding().isa()) ++ return failure(); ++ ++ assert(computeCapability >= 80 && ++ "SparseDot is supported on Ampere and higher"); ++ int versionMajor = computeCapability < 90 ? 2 : 3; ++ ++ // get MMA encoding for the given number of warps ++ auto retShapePerCTA = ttg::getShapePerCTA(oldRetType); ++ auto mod = op->getParentOfType(); ++ int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); ++ auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); ++ ++ auto instrShape = mmaVersionToInstrShape( ++ versionMajor, retShapePerCTA, a.getType().cast()); ++ auto warpsPerTile = BlockedToMMA::getWarpsPerTile( ++ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); ++ ttg::NvidiaMmaEncodingAttr mmaEnc = ++ ttg::NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0, ++ warpsPerTile, CTALayout, instrShape); ++ auto newRetType = RankedTensorType::get( ++ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); ++ ++ // convert accumulator ++ auto oldAcc = dotOp.getOperand(2); ++ auto newAcc = rewriter.create(oldAcc.getLoc(), ++ newRetType, oldAcc); ++ ++ if (versionMajor == 2) { ++ // convert A operand ++ auto oldAType = a.getType().cast(); ++ auto newAEncoding = ttg::DotOperandEncodingAttr::get( ++ ctx, 0, mmaEnc, oldAType.getElementType()); ++ auto newAType = RankedTensorType::get( ++ oldAType.getShape(), oldAType.getElementType(), newAEncoding); ++ a = rewriter.create(a.getLoc(), newAType, a); ++ ++ // convert B operand ++ auto oldBType = b.getType().cast(); ++ auto newBEncoding = ttg::DotOperandEncodingAttr::get( ++ ctx, 1, mmaEnc, oldBType.getElementType()); ++ auto newBType = RankedTensorType::get( ++ oldBType.getShape(), oldBType.getElementType(), newBEncoding); ++ b = rewriter.create(b.getLoc(), newBType, b); ++ } else { ++ a = BlockedToMMA::getMMAv3Operand(a, rewriter, 0); ++ b = BlockedToMMA::getMMAv3Operand(b, rewriter, 1); ++ } ++ ++ // convert metadata ++ Value meta = dotOp.getAMeta(); ++ auto oldMetaType = meta.getType().cast(); ++ auto newMetaType = RankedTensorType::get( ++ oldMetaType.getShape(), oldMetaType.getElementType(), ++ SparseDotMetaEncodingAttr::get(ctx, mmaEnc)); ++ meta = ++ rewriter.create(meta.getLoc(), newMetaType, meta); ++ ++ // convert dot instruction ++ auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, ++ newAcc, meta); ++ ++ rewriter.replaceOpWithNewOp(op, oldRetType, ++ newDot.getResult()); ++ return success(); ++ } ++ ++ private: ++ int computeCapability; ++}; + } // namespace + + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, +@@ -397,6 +491,7 @@ public: + + mlir::RewritePatternSet patterns(context); + patterns.add<::BlockedToMMA>(context, computeCapability); ++ patterns.add<::SparseBlockedToMMA>(context, computeCapability); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } +diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +@@ -47,6 +47,10 @@ struct PipelinedOpInfo { + bool loadIsMMAV3 = false; + }; + ++bool isDotOp(Operation* op) { ++ return isa(op); ++} ++ + } // namespace + + static bool isMMAv3Dot(Operation *op) { +@@ -163,22 +167,28 @@ getSharedEncIfAllUsersAreDotEnc(Value va + } else { + if (!isa(user)) + return std::nullopt; +- auto dotOpEnc = user->getResult(0) +- .getType() +- .cast() +- .getEncoding() +- .dyn_cast(); +- if (!dotOpEnc) ++ auto enc = ++ user->getResult(0).getType().cast().getEncoding(); ++ if (isa(enc)) { ++ auto srcTy = val.getType().cast(); ++ auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); ++ auto order = ttg::getOrder(srcTy.getEncoding()); ++ unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); ++ tempAttr = ttg::SharedEncodingAttr::get( ++ val.getContext(), cast(enc), ++ srcTy.getShape(), ttg::getOrder(srcTy.getEncoding()), ++ ttg::getCTALayout(srcTy.getEncoding()), ++ srcTy.getElementType().getIntOrFloatBitWidth(), ++ /*needTrans=*/false); ++ } else if (isa(enc)) { ++ auto srcTy = val.getType().cast(); ++ tempAttr = ttg::SharedEncodingAttr::get( ++ val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1, ++ ttg::getOrder(srcTy.getEncoding()), ++ ttg::getCTALayout(srcTy.getEncoding())); ++ } else { + return std::nullopt; +- auto srcTy = val.getType().cast(); +- auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); +- auto order = ttg::getOrder(srcTy.getEncoding()); +- unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); +- tempAttr = ttg::SharedEncodingAttr::get( +- val.getContext(), dotOpEnc, srcTy.getShape(), +- ttg::getOrder(srcTy.getEncoding()), +- ttg::getCTALayout(srcTy.getEncoding()), +- srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); ++ } + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) +@@ -311,7 +321,7 @@ loadOpsToDistanceAndUse(scf::ForOp forOp + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { +- if (!isa(op)) ++ if (!isDotOp(&op)) + continue; + dfs(&op, 0, &op); + } +@@ -385,7 +395,7 @@ collectOpsToPipeline(scf::ForOp forOp, + // loads. + for (auto &[loadOp, distAndUse] : loadOpToDistAndUse) { + PipelinedOpInfo loadInfo; +- if (isa(distAndUse.second)) { ++ if (isDotOp(distAndUse.second)) { + if (loadIsMMAv3(loadOp)) { + loadInfo.loadIsMMAV3 = true; + loadInfo.sharedEncoding = +@@ -743,7 +753,7 @@ bool mlir::triton::preProcessLoopAndGetS + int useStage = opToInfo[info.use].stage; + int numBuffers = useStage - defStage; + +- if (hasMMAV3 && isa(info.use)) { ++ if (hasMMAV3 && isDotOp(info.use)) { + // For MMAv3, we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + numBuffers++; +diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +@@ -36,6 +36,10 @@ public: + auto srcEncoding = srcType.getEncoding(); + if (srcEncoding.isa()) + return; ++ if (dstType.getEncoding().isa()) { ++ replaceSparseMetaEncoding(cvtOp); ++ return; ++ } + auto dstDotOp = + dstType.getEncoding().dyn_cast(); + if (!dstDotOp) +@@ -74,6 +78,27 @@ public: + cvtOp.erase(); + }); + } ++ ++ private: ++ void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) { ++ auto srcType = cvtOp.getOperand().getType().cast(); ++ auto srcEncoding = srcType.getEncoding(); ++ auto sharedLayout = triton::gpu::SharedEncodingAttr::get( ++ cvtOp.getContext(), 8, 1, 1, triton::gpu::getOrder(srcEncoding), ++ triton::gpu::getCTALayout(srcEncoding)); ++ ++ auto dstType = cvtOp.getType().cast(); ++ auto tmpType = triton::MemDescType::get( ++ dstType.getShape(), dstType.getElementType(), sharedLayout); ++ ++ OpBuilder builder(cvtOp); ++ auto tmp = builder.create( ++ cvtOp.getLoc(), tmpType, cvtOp.getSrc()); ++ auto newConvert = builder.create( ++ cvtOp.getLoc(), dstType, tmp); ++ cvtOp.replaceAllUsesWith(newConvert.getResult()); ++ cvtOp.erase(); ++ } + }; + + std::unique_ptr mlir::triton::gpu::createReduceDataDuplicationPass() { +diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +@@ -45,7 +45,7 @@ public: + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { +- if (!isa(op)) ++ if (!isa(op)) + return WalkResult::advance(); + OpBuilder builder(op); + auto a = op->getOperand(0); +@@ -83,7 +83,7 @@ private: + static DenseSet> trace; + auto op = operand.getDefiningOp(); + // avoid redundant insertion +- if (op && isa(op)) ++ if (op && isa(op)) + return false; + // reach convertlayout + if (op && isa(op) && +diff --git a/test/SparseDot/add_layout.mlir b/test/SparseDot/add_layout.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/add_layout.mlir +@@ -0,0 +1,15 @@ ++// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu | FileCheck %s ++ ++// CHECK-COUNT-4: #triton_gpu.blocked ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot() { ++ %A = arith.constant dense<1.00e+00> : tensor<64x32xf16> ++ %meta = arith.constant dense<0x3333> : tensor<64x4xi16> ++ %B = arith.constant dense<2.00e+00> : tensor<64x64xf16> ++ %C = arith.constant dense<0.00e+00> : tensor<64x64xf32> ++ // CHECK-COUNT-4: triton_gpu.convert_layout ++ // CHECK: triton_gpu.sparse_dot {{.+}} #triton_gpu.sparse_dot_meta ++ %D = triton_gpu.sparse_dot %A, %B, %C, %meta : tensor<64x32xf16> meta tensor<64x4xi16> * tensor<64x64xf16> -> tensor<64x64xf32> ++ tt.return ++ } ++} +diff --git a/test/SparseDot/ttg_accelerate_matmul.mlir b/test/SparseDot/ttg_accelerate_matmul.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/ttg_accelerate_matmul.mlir +@@ -0,0 +1,27 @@ ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file -tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s ++// RUN: triton-opt %s -split-input-file -tritongpu-accelerate-matmul=compute-capability=80 | FILECHECK_OPTS= FileCheck %s --check-prefix=CHECK-80 ++ ++#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> ++// CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> ++// CHECK-80: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> ++#lhs = #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}> ++#rhs = #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}> ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot(%A: tensor<64x32xf16, #lhs>, %B: tensor<64x64xf16, #rhs>, %meta: tensor<64x4xi16, #blocked>) -> tensor<64x64xf32, #blocked> { ++ %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> ++ // CHECK-DAG: %[[LHS:.+]] = triton_gpu.local_alloc {{.+}} : (tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> !tt.memdesc<64x32xf16, #{{.+}}> ++ // CHECK-DAG: %[[RHS:.+]] = triton_gpu.local_alloc {{.+}} : (tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<64x64xf16, #{{.+}}> ++ // CHECK-DAG: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #[[MMA]]> ++ // CHECK-DAG: %[[META:.+]] = triton_gpu.convert_layout {{.+}} : tensor<64x4xi16, #blocked> -> tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[MMA]]}>> ++ // CHECK: %[[OUT:.+]] = triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %[[ACC]], %[[META]] : {{.+}} -> tensor<64x64xf32, #[[MMA]]> ++ // CHECK-80-DAG: %[[LHS:.+]] = triton_gpu.convert_layout {{.+}} : {{.+}} -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> ++ // CHECK-80-DAG: %[[RHS:.+]] = triton_gpu.convert_layout {{.+}} : {{.+}} -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> ++ // CHECK-80-DAG: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : {{.+}} -> tensor<64x64xf32, #[[MMA]]> ++ // CHECK-80-DAG: %[[META:.+]] = triton_gpu.convert_layout {{.+}} : {{.+}} -> tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #[[MMA]]}>> ++ // CHECK-80: %[[OUT:.+]] = triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %[[ACC]], %[[META]] : {{.+}} -> tensor<64x64xf32, #[[MMA]]> ++ %D = triton_gpu.sparse_dot %A, %B, %C, %meta : tensor<64x32xf16, #lhs> meta tensor<64x4xi16, #blocked> * tensor<64x64xf16, #rhs> -> tensor<64x64xf32, #blocked> ++ // CHECK: triton_gpu.convert_layout %[[OUT]] : tensor<64x64xf32, #[[MMA]]> -> tensor<64x64xf32, #blocked> ++ // CHECK-80: triton_gpu.convert_layout %[[OUT]] : tensor<64x64xf32, #[[MMA]]> -> tensor<64x64xf32, #blocked> ++ tt.return %D : tensor<64x64xf32, #blocked> ++ } ++} +diff --git a/test/SparseDot/ttg_fence_insertion.mlir b/test/SparseDot/ttg_fence_insertion.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/ttg_fence_insertion.mlir +@@ -0,0 +1,18 @@ ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file -triton-nvidia-gpu-fence-insertion | FileCheck %s ++ ++#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> ++#lhs = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> ++#rhs = #triton_gpu.dot_op<{opIdx = 1, parent = #mma}> ++#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func public @sparse_dot_fence(%A: tensor<64x32xf16, #lhs>, %B: tensor<64x64xf16, #rhs>, %meta: tensor<64x4xi16, #blocked>) { ++ %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> ++ %0 = triton_gpu.local_alloc %A : (tensor<64x32xf16, #lhs>) -> !tt.memdesc<64x32xf16, #shared> ++ %1 = triton_gpu.local_alloc %B : (tensor<64x64xf16, #rhs>) -> !tt.memdesc<64x64xf16, #shared> ++ %2 = triton_gpu.convert_layout %meta : tensor<64x4xi16, #blocked> -> tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #mma}>> ++ // CHECK: triton_nvidia_gpu.fence_async_shared ++ %3 = triton_gpu.sparse_dot %0, %1, %C, %2 : !tt.memdesc<64x32xf16, #shared> meta tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf32, #mma> ++ tt.return ++ } ++} +diff --git a/test/SparseDot/ttg_loop_pipeline.mlir b/test/SparseDot/ttg_loop_pipeline.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/ttg_loop_pipeline.mlir +@@ -0,0 +1,61 @@ ++// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s ++ ++#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> ++#sliced = #triton_gpu.slice<{parent=#blocked, dim=0}> ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> ++#dot_operand_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth=2}> ++#dot_operand_b = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth=2}> ++#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma}> ++ ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot_loop(%lb : index, %ub : index, %step : index, ++ %A : !tt.ptr {tt.divisibility = 16 : i32}, ++ %B : !tt.ptr {tt.divisibility = 16 : i32}, ++ %A_meta : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { ++ // CHECK-COUNT-6: triton_gpu.async_copy_global_to_local ++ // CHECK: triton_gpu.async_wait {{.+}}, {{.+}} {num = 3 : i32} ++ %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked> ++ %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #sliced> ++ %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #sliced> -> tensor<1x32xi32, #blocked> ++ %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #blocked> -> tensor<128x32xi32, #blocked> ++ %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> ++ ++ %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked> ++ %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #sliced> ++ %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #sliced> -> tensor<1x128xi32, #blocked> ++ %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> ++ %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> ++ ++ %meta_ptr_splat = tt.splat %A_meta : !tt.ptr -> tensor<128x4x!tt.ptr, #blocked> ++ %meta_tmp0 = tt.make_range {end = 4: i32, start = 0: i32} : tensor<4xi32, #sliced> ++ %meta_tmp1 = tt.expand_dims %meta_tmp0 {axis = 0 : i32} : tensor<4xi32, #sliced> -> tensor<1x4xi32, #blocked> ++ %meta_offs = tt.broadcast %meta_tmp1 : tensor<1x4xi32, #blocked> -> tensor<128x4xi32, #blocked> ++ %meta_ptr_init = tt.addptr %meta_ptr_splat, %meta_offs : tensor<128x4x!tt.ptr, #blocked>, tensor<128x4xi32, #blocked> ++ ++ %a_off = arith.constant dense<4> : tensor<128x32xi32, #blocked> ++ %b_off = arith.constant dense<4> : tensor<64x128xi32, #blocked> ++ %meta_off = arith.constant dense<4> : tensor<128x4xi32, #blocked> ++ %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> ++ ++ // CHECK: scf.for ++ %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c = %c_init, %meta_ptr = %meta_ptr_init) ++ -> (tensor<128x32x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, tensor<128x4x!tt.ptr, #blocked>) { ++ // CHECK-COUNT-3: triton_gpu.local_load ++ // CHECK: triton_gpu.sparse_dot ++ // CHECK-COUNT-3: triton_gpu.async_copy_global_to_local ++ %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked> ++ %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #dot_operand_a> ++ %b_ = tt.load %b_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128xf16, #blocked> ++ %b = triton_gpu.convert_layout %b_ : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #dot_operand_b> ++ %meta_ = tt.load %meta_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x4xi16, #blocked> ++ %meta = triton_gpu.convert_layout %meta_ : tensor<128x4xi16, #blocked> -> tensor<128x4xi16, #dot_meta_enc> ++ %d = triton_gpu.sparse_dot %a, %b, %c, %meta : tensor<128x32xf16, #dot_operand_a> meta tensor<128x4xi16, #dot_meta_enc> * tensor<64x128xf16, #dot_operand_b> -> tensor<128x128xf32, #mma> ++ ++ %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> ++ %b_ptr_next = tt.addptr %b_ptr, %b_off : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> ++ %meta_ptr_next = tt.addptr %meta_ptr, %meta_off : tensor<128x4x!tt.ptr, #blocked>, tensor<128x4xi32, #blocked> ++ scf.yield %a_ptr_next, %b_ptr_next, %d, %meta_ptr_next : tensor<128x32x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, tensor<128x4x!tt.ptr, #blocked> ++ } ++ tt.return %loop#2: tensor<128x128xf32, #mma> ++ } ++} +diff --git a/test/SparseDot/ttg_reduce_data_duplication.mlir b/test/SparseDot/ttg_reduce_data_duplication.mlir +new file mode 100644 +--- /dev/null ++++ b/test/SparseDot/ttg_reduce_data_duplication.mlir +@@ -0,0 +1,13 @@ ++// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s ++ ++#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> ++// CHECK: #[[SHARED:.+]] = #triton_gpu.shared ++module attributes {"triton_gpu.num-warps" = 4 : i32} { ++ tt.func @sparse_dot_metadata(%meta: tensor<64x4xi16, #blocked>) { ++ // CHECK: %[[META:.+]] = triton_gpu.local_alloc {{.+}} : (tensor<64x4xi16, #blocked>) -> !tt.memdesc<64x4xi16, #[[SHARED]]> ++ // CHECK: triton_gpu.local_load %[[META]] : !tt.memdesc<64x4xi16, #[[SHARED]]> -> tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #mma}>> ++ %0 = triton_gpu.convert_layout %meta : tensor<64x4xi16, #blocked> -> tensor<64x4xi16, #triton_gpu.sparse_dot_meta<{parent = #mma}>> ++ tt.return ++ } ++} diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 834668f112a38..4c21d35991e20 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,9 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl584018112" - TRITON_SHA256 = "a0f2461af9fbcf576cef08e0b83ab7a1caa3cfe2041c60b2809cbd495ff14f08" - + TRITON_COMMIT = "cl619179472" + TRITON_SHA256 = "aa0b0b338bf16aa7eea778312fa549a421278b24d1a4bc04f5d6ced706f693fe" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, @@ -15,8 +14,9 @@ def repo(): urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. patch_file = [ - "//third_party/triton:b304456327.patch", - "//third_party/triton:cl568176943.patch", - "//third_party/triton:cl584230333.patch", + "//third_party/triton:cl607293980.patch", # long standing :( + "//third_party/triton:sparse_dot_nvgpu.patch", + "//third_party/triton:sparse_dot_base.patch", + "//third_party/triton:sparse_dot_passes.patch", ], ) diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index 42330a369c9f6..d8990ac5c12cc 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -134,7 +134,7 @@ build --experimental_link_static_libraries_once=false # Prevent regressions on those two incompatible changes # TODO: remove those flags when they are flipped in the default Bazel version TF uses. build --incompatible_enforce_config_setting_visibility -# TODO: also enable this flag after fixing the visbility violations +# TODO: also enable this flag after fixing the visibility violations # build --incompatible_config_setting_private_default_visibility # Default options should come above this line. @@ -243,18 +243,26 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" +build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang" build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -307,7 +315,7 @@ build:macos --copt=-w build:windows --copt=/W0 build:windows --host_copt=/W0 -# Suppress most C++ complier warnings to reduce log size but allow +# Suppress most C++ compiler warnings to reduce log size but allow # for specific warnings to still be present. build:linux --copt="-Wno-all" build:linux --copt="-Wno-extra" @@ -441,6 +449,9 @@ test:win_clang --host_linkopt=/FORCE:MULTIPLE # TODO(kanglan): Change v2's define to default behavior build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 +# Enable all targets in XLA +build:cpu_cross --define=with_cross_compiler_support=true + # Disable XLA on mobile. build:xla --define=with_xla_support=true # TODO: remove, it's on by default. build:android --define=with_xla_support=false @@ -527,8 +538,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +588,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -585,10 +597,16 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux -# Disable clang extention that rejects type definitions within offsetof. +# Enable support for all targets +build:release_base --config=cpu_cross + +# Disable clang extension that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within # offset of in the current version of ubp. @@ -665,12 +683,23 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + # Build configs for macOS Arm64 build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 @@ -685,13 +714,18 @@ test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors test:release_macos_base --build_tests_only --keep_going test:release_macos_base --flaky_test_attempts=3 +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt @@ -707,7 +741,7 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs # push to the cache. For macOS, use --config=tf_public_macos_cache -build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/january2024" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials # Public cache for macOS builds @@ -746,6 +780,11 @@ test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-os test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -760,46 +799,137 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test +# MACOS X86 PYCPP +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# CROSS-COMPILE MACOS X86 PYCPP +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS -# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# START CROSS-COMPILE CONFIGS # Set execution platform to Linux x86 # Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" # flags seem to be actually used to specify the execution platform details. It # seems it is this way because these flags are old and predate the distinction # between host and execution platform. -build:cross_compile_linux_arm64 --host_cpu=k8 -build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 +build:cross_compile_base --host_cpu=k8 +build:cross_compile_base --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# XLA related settings for cross-compiled build. Certain paths are +# different in the XLA repo. +build:cross_compile_base_xla --host_cpu=k8 +build:cross_compile_base_xla --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base_xla --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 + +build:rbe_cross_compile_base --config=rbe_base +build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# XLA depends on some local Python headers that are configured as Genrule. They +# are present on the local host machine but not on the remote execution machine, +# leading to build failures. To resolve the issue, the following line is added +# to make sure all Genrule targets are excuted locally. +build:rbe_cross_compile_base_xla --config=rbe_cross_compile_base +build:rbe_cross_compile_base_xla --strategy=Genrule=standalone + +# Due to the above strategy, all Genrule commands are executed locally, but the +# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are +# only executabe on the RBE (x86) machine, so the strategy_regexp options are +# added to override and run the actions using remote strategy. +build:rbe_cross_compile_base_xla --strategy_regexp='Generating code from table.*=remote' +build:rbe_cross_compile_base_xla --strategy_regexp='Generating flatbuffer files.*=remote' +build:rbe_cross_compile_base_xla --strategy_regexp='Executing genrule @llvm-project.*=remote' + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_base --strategy=TestRunner=local --build_tests_only +test:rbe_cross_compile_base --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors + +test:rbe_cross_compile_base_xla --config=rbe_cross_compile_base + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +build:cross_compile_linux_arm64 --config=cross_compile_base # Set the target CPU to Aarch64 build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 build:cross_compile_linux_arm64 --cpu=aarch64 build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -# RBE configs +# XLA uses different paths for platforms and crosstool_top. +build:cross_compile_linux_arm64_xla --config=cross_compile_base_xla +build:cross_compile_linux_arm64_xla --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64_xla --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 -build:rbe_cross_compile_linux_arm64 --config=rbe_base -build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance +build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base +test:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base + +# RBE cross-compile configs for XLA Linux Aarch64 +build:rbe_cross_compile_linux_arm64_xla --config=cross_compile_linux_arm64_xla +build:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla +test:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla -# Test-related settings below this point -# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to -# force all tests to run locally on the Aarch64 host. -test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local -test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors -test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only # END LINUX AARCH64 CROSS-COMPILE CONFIGS + +# START MACOS CROSS-COMPILE CONFIGS +build:cross_compile_macos_x86 --config=cross_compile_base +build:cross_compile_macos_x86 --config=nonccl +# Target Catalina (10.15) as the minimum supported OS +build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Set the target CPU to Darwin x86 +build:cross_compile_macos_x86 --platforms=//tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_macos_x86 --cpu=darwin +build:cross_compile_macos_x86 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +# When RBE cross-compiling for macOS, we need to explicitly register the +# toolchain. Otherwise, oddly, RBE complains that a "docker container must be +# specified". +build:cross_compile_macos_x86 --extra_toolchains=//tensorflow/tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() +# and transistions that use these flags work. +build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cross_compile/config/platform_mappings + +# RBE cross-compile configs for Darwin x86 +build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 +build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=nowait_for_upload_complete +test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +# Increase the test timeout as tests often take longer on mac. +test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 +# END MACOS CROSS-COMPILE CONFIGS +# END CROSS-COMPILE CONFIGS + +# Try to load the XLA warnings config if available +try-import %workspace%/warnings.bazelrc diff --git a/third_party/tsl/.bazelversion b/third_party/tsl/.bazelversion index 204ac7c926e43..f3c238740e5bc 100644 --- a/third_party/tsl/.bazelversion +++ b/third_party/tsl/.bazelversion @@ -1,2 +1,2 @@ -6.4.0 +6.5.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/third_party/tsl/.kokoro/generate_index_html.sh b/third_party/tsl/.kokoro/generate_index_html.sh deleted file mode 100755 index 8870bb2818e52..0000000000000 --- a/third_party/tsl/.kokoro/generate_index_html.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2022 Google LLC All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Generates a handy index.html with a bunch of Kokoro links for GitHub -# presubmits. -# Usage: generate_index_html.sh /path/to/output/index.html - -tee "$1" < - -#$KOKORO_GITHUB_PULL_REQUEST_NUMBER_tsl | $(basename "$KOKORO_JOB_NAME") - - -

TSL Job Logs and Links

-

Job Details

-
    -
  • Job name: $KOKORO_JOB_NAME
  • -
  • Job pool: $KOKORO_JOB_POOL
  • -
  • Job ID: $KOKORO_BUILD_ID
  • -
  • Current HEAD Piper Changelist (may be empty): cl/${KOKORO_PIPER_CHANGELIST:-not available}
  • -
  • Pull Request Number: $KOKORO_GITHUB_PULL_REQUEST_NUMBER_tsl
  • -
  • Pull Request Link: $KOKORO_GITHUB_PULL_REQUEST_URL_tsl
  • -
  • Commit: $KOKORO_GIT_COMMIT_tsl
  • -
-

Googlers-Only Links

- -

Non-Googler Links

- - -EOF diff --git a/third_party/tsl/.kokoro/linux/build.sh b/third_party/tsl/.kokoro/linux/build.sh deleted file mode 100644 index f05e02bfc6cc7..0000000000000 --- a/third_party/tsl/.kokoro/linux/build.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash -# Copyright 2022 Google LLC All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euo pipefail -o history - -# Generate a templated results file to make output accessible to everyone -"$KOKORO_ARTIFACTS_DIR"/github/tsl/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html - -function is_continuous_job() { - [[ "$KOKORO_JOB_NAME" =~ tensorflow/tsl/.*continuous.* ]] -} - -ADDITIONAL_FLAGS="" -TAGS_FILTER="-no_oss,-oss_excluded,-oss_serial,-gpu,-requires-gpu-nvidia" - -if is_continuous_job ; then - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --google_default_credentials" -else - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --remote_upload_local_results=false" -fi - -# Pull the container (in case it was updated since the instance started) and -# store its SHA in the Sponge log. -docker pull "$DOCKER_IMAGE" -echo "TF_INFO_DOCKER_IMAGE,$DOCKER_IMAGE" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" -echo "TF_INFO_DOCKER_SHA,$(docker pull "$DOCKER_IMAGE" | sed -n '/Digest:/s/Digest: //g p')" >> "$KOKORO_ARTIFACTS_DIR/custom_sponge_config.csv" - -# Start a container in the background -docker run --name tsl -w /tf/tsl -itd --rm \ - -v "$KOKORO_ARTIFACTS_DIR/github/tsl:/tf/tsl" \ - "$DOCKER_IMAGE" \ - bash - -# Build TSL -docker exec tsl bazel --bazelrc=/usertools/cpu.bazelrc build \ - --output_filter="" \ - --keep_going \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/tsl/linux" \ - $ADDITIONAL_FLAGS \ - -- //tsl/... - -# Test TSL -docker exec tsl bazel --bazelrc=/usertools/cpu.bazelrc test \ - --output_filter="" \ - --keep_going \ - --flaky_test_attempts=3 \ - --test_output=errors \ - --build_tests_only \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --verbose_failures=true \ - -- //tsl/... - -# Stop container -docker stop tsl diff --git a/third_party/tsl/.kokoro/linux/cpu/build_cpu.cfg b/third_party/tsl/.kokoro/linux/cpu/build_cpu.cfg deleted file mode 100644 index 8e105be39e67c..0000000000000 --- a/third_party/tsl/.kokoro/linux/cpu/build_cpu.cfg +++ /dev/null @@ -1,5 +0,0 @@ -build_file: "tsl/.kokoro/linux/build.sh" -env_vars: { - key: "DOCKER_IMAGE" - value: "gcr.io/tensorflow-sigs/build:latest-python3.9" -} \ No newline at end of file diff --git a/third_party/tsl/.kokoro/linux/cpu/common.cfg b/third_party/tsl/.kokoro/linux/cpu/common.cfg deleted file mode 100644 index e23a70ebe5127..0000000000000 --- a/third_party/tsl/.kokoro/linux/cpu/common.cfg +++ /dev/null @@ -1,11 +0,0 @@ -action { - define_artifacts { - # Sponge logs - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - # Full test logs to debug - regex: "**/*.tar.gz" - # Html helper for presubmits - regex: "**/*.html" - } -} \ No newline at end of file diff --git a/third_party/tsl/.kokoro/macos/build.sh b/third_party/tsl/.kokoro/macos/build.sh deleted file mode 100644 index 5f4a806d87be2..0000000000000 --- a/third_party/tsl/.kokoro/macos/build.sh +++ /dev/null @@ -1,111 +0,0 @@ -#!/bin/bash -# Copyright 2022 Google LLC All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euo pipefail -o history - -cd "${KOKORO_ARTIFACTS_DIR}/github/tsl" - -# Install Bazelisk, Bats, Pyenv, Python, upgrade pip, and activate ".tf-venv" -# virtual environment. We use the "PYENV_VERSION" variable here to decide which -# Python version to install. In addition, we update $PATH with the PYENV_ROOT -# environment variable and we set STATIC_DEPS=true for installing lxml for -# Python. Finally, we set up a symlink to the Python packages directory in -# ".tf-venv" which is referenced in macos.bazelrc. -function install_build_env_tools(){ - # Install Bazelisk; Useful as it would automatically pick the correct - # version of Bazel. - echo "===== Installing Bazelisk =====" - sudo wget --no-verbose -O "/usr/local/bin/bazel" \ - "https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" \ - && chmod +x "/usr/local/bin/bazel" - - echo "===== Installing Pyenv =====" - # Install pyenv; Set up a virtual environment to control dependencies and their - # versions - git clone --branch v2.3.17 https://github.com/pyenv/pyenv.git /Users/kbuilder/.tf_pyenv - export PYENV_ROOT=/Users/kbuilder/.tf_pyenv - export PATH="$PYENV_ROOT/bin:$PATH" # if `pyenv` is not already on PATH - eval "$(pyenv init --path)" - eval "$(pyenv init -)" - - echo "===== Installing Python =====" - # Install Python and set the local python version - pyenv install -s "${TF_PYENV_VERSION}" - pyenv rehash - pyenv local "${TF_PYENV_VERSION}" - # Do a sanity check to make sure that we using the correct Python version - echo "===== Python version =====" - python --version - # Set up virtual environment and activate it - python -m venv /Users/kbuilder/.tf-venv && source /Users/kbuilder/.tf-venv/bin/activate - - # Setup links to Python. Referenced in ./macos.bazelrc - ln -s /Users/kbuilder/.tf-venv/lib/python* /Users/kbuilder/.tf-venv/lib/python - - echo "===== Upgrading to latest pip =====" - python -m pip install --upgrade pip -} - -install_build_env_tools - -python -m pip install numpy==1.21.4 - -# Generate a templated results file to make output accessible to everyone -"$KOKORO_ARTIFACTS_DIR"/github/tsl/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html - -function is_continuous_job() { - [[ "$KOKORO_JOB_NAME" =~ tensorflow/tsl/.*continuous.* ]] -} - -# Set authentication for reading and writing cache from Google Cloud Storage -export GOOGLE_APPLICATION_CREDENTIALS="$KOKORO_KEYSTORE_DIR/73361_tensorflow_bazel_cache_writer" - -TAGS_FILTER="-no_oss,-oss_excluded,-gpu,-no_mac,-nomac,-mac_excluded" -ADDITIONAL_FLAGS="" - -if is_continuous_job ; then - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --google_default_credentials" -else - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --remote_upload_local_results=false" -fi - -# Build TSL -bazel build \ - --output_filter="" \ - --macos_minimum_os=10.15 \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --keep_going \ - --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/tsl/macos" \ - $ADDITIONAL_FLAGS \ - -- //tsl/... - -# Test TSL -bazel test \ - --output_filter="" \ - --macos_minimum_os=10.15 \ - --test_tag_filters=-no_mac,-nomac,-mac_excluded \ - --keep_going \ - --test_output=errors \ - --build_tests_only \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --verbose_failures=true \ - --flaky_test_attempts=3 \ - -- //tsl/... diff --git a/third_party/tsl/.kokoro/macos/cpu/common.cfg b/third_party/tsl/.kokoro/macos/cpu/common.cfg deleted file mode 100644 index b12f5c22a4f83..0000000000000 --- a/third_party/tsl/.kokoro/macos/cpu/common.cfg +++ /dev/null @@ -1,25 +0,0 @@ -# Not sure how long the timeout should be -timeout_mins: 720 - -action { - define_artifacts { - # Sponge logs - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - # Full test logs to debug the log squasher, and libtf.tar.gz - regex: "**/*.tar.gz" - # Html helper for presubmits - regex: "**/*.html" - } -} - -before_action { - fetch_keystore { - # Authentication for reading and writing cache to/from Google Cloud Storage - keystore_resource { - keystore_config_id: 73361 - keyname: "tensorflow_bazel_cache_writer" - backend: "blade:keystore-fastconfigpush" # disable-keystore-reliability-check - } - } -} \ No newline at end of file diff --git a/third_party/tsl/.kokoro/macos/cpu/cpu_py39_full.cfg b/third_party/tsl/.kokoro/macos/cpu/cpu_py39_full.cfg deleted file mode 100644 index 885373fc8c1d0..0000000000000 --- a/third_party/tsl/.kokoro/macos/cpu/cpu_py39_full.cfg +++ /dev/null @@ -1,5 +0,0 @@ -build_file: "tsl/.kokoro/macos/build.sh" -env_vars: { - key: "TF_PYENV_VERSION" - value: "3.9.16" -} diff --git a/third_party/tsl/.kokoro/windows/build.bat b/third_party/tsl/.kokoro/windows/build.bat deleted file mode 100644 index 4fb8ddf28ed4b..0000000000000 --- a/third_party/tsl/.kokoro/windows/build.bat +++ /dev/null @@ -1,20 +0,0 @@ -@REM Copyright 2023 Google LLC - -@REM Licensed under the Apache License, Version 2.0 (the "License"); -@REM you may not use this file except in compliance with the License. -@REM You may obtain a copy of the License at - -@REM https://www.apache.org/licenses/LICENSE-2.0 - -@REM Unless required by applicable law or agreed to in writing, software -@REM distributed under the License is distributed on an "AS IS" BASIS, -@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@REM See the License for the specific language governing permissions and -@REM limitations under the License. - -SET TMPDIR=T:/tmp -SET TMP=%TMPDIR% -SET TEMP=%TMPDIR% - -bash -l %0/../windows_build.sh %* -exit /b %ERRORLEVEL% diff --git a/third_party/tsl/.kokoro/windows/cpu/build_cpu_py39.cfg b/third_party/tsl/.kokoro/windows/cpu/build_cpu_py39.cfg deleted file mode 100644 index 3a935b23ed828..0000000000000 --- a/third_party/tsl/.kokoro/windows/cpu/build_cpu_py39.cfg +++ /dev/null @@ -1 +0,0 @@ -build_file: "tsl/.kokoro/windows/build.bat" diff --git a/third_party/tsl/.kokoro/windows/cpu/common.cfg b/third_party/tsl/.kokoro/windows/cpu/common.cfg deleted file mode 100644 index 8f936b4071a99..0000000000000 --- a/third_party/tsl/.kokoro/windows/cpu/common.cfg +++ /dev/null @@ -1,12 +0,0 @@ -# timeout_mins: 6000 -action { - define_artifacts { - regex: "**/sponge_log.xml" - # regex: "**/.tf_configure.bazelrc" - regex: "**/lib_package/*" - # regex: "**/java.log" - # regex: "**/win_minidumps/*.dmp" - # Html helper for presubmits - regex: "**/*.html" - } -} diff --git a/third_party/tsl/.kokoro/windows/windows_build.sh b/third_party/tsl/.kokoro/windows/windows_build.sh deleted file mode 100644 index 331efa186fb87..0000000000000 --- a/third_party/tsl/.kokoro/windows/windows_build.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2022 Google LLC All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# Note: set -x +x around anything you want to have logged. -set -euo pipefail - -cd "${KOKORO_ARTIFACTS_DIR}/github/tsl" - -# Generate a templated results file to make output accessible to everyone -"$KOKORO_ARTIFACTS_DIR"/github/tsl/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html - -function is_continuous_job() { - [[ "$KOKORO_JOB_NAME" =~ tensorflow/tsl/.*continuous.* ]] -} - -ADDITIONAL_FLAGS="" -TAGS_FILTER="-no_oss,-oss_excluded,-gpu,-no_windows,-windows_excluded" - -if is_continuous_job ; then - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --google_default_credentials" -else - ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --remote_upload_local_results=false" -fi - -export PATH="$PATH:/c/Python38" - -# Build TSL -/c/tools/bazel.exe build \ - --output_filter="" \ - --keep_going \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/tsl/windows" \ - $ADDITIONAL_FLAGS \ - -- //tsl/... \ - || { echo "Bazel Build Failed" && exit 1; } - -# Test TSL TODO(ddunleavy) enable all tests -/c/tools/bazel.exe test \ - --output_filter="" \ - --flaky_test_attempts=3 \ - --test_output=errors \ - --build_tests_only \ - --verbose_failures=true \ - --build_tag_filters=$TAGS_FILTER \ - --test_tag_filters=$TAGS_FILTER \ - --keep_going \ - -- //tsl/... -//tsl/platform:subprocess_test -//tsl/platform/cloud:google_auth_provider_test -//tsl/platform/cloud:oauth_client_test \ - || { echo "Bazel Test Failed" && exit 1; } - -exit 0 diff --git a/third_party/tsl/BUILD.bazel b/third_party/tsl/BUILD.bazel index cb36c9fd6150a..9bcea9e1c01ed 100644 --- a/third_party/tsl/BUILD.bazel +++ b/third_party/tsl/BUILD.bazel @@ -2,7 +2,7 @@ load("@rules_license//rules:license.bzl", "license") package( default_applicable_licenses = [":license"], - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], ) licenses(["notice"]) diff --git a/third_party/tsl/README.md b/third_party/tsl/README.md index ff61283c5e535..6f4ab2025257a 100644 --- a/third_party/tsl/README.md +++ b/third_party/tsl/README.md @@ -10,16 +10,6 @@ This repo contains base utilities and cross-platform support for projects like > [upstream location](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tsl) > to make any contributions or report any issues. -## Contacts - -Discord TBA - -Community proposals TBA - -Community meetings TBA - -Additional contacts TBA - ## Code of Conduct While under TensorFlow governance, all community spaces are subject to the diff --git a/third_party/tsl/WORKSPACE b/third_party/tsl/WORKSPACE index cab76d7b38432..6ad0d6e0e2b7b 100644 --- a/third_party/tsl/WORKSPACE +++ b/third_party/tsl/WORKSPACE @@ -1,3 +1,4 @@ +# buildifier: disable=load-on-top workspace(name = "tsl") # Initialize the TSL repository and all dependencies. @@ -6,6 +7,9 @@ workspace(name = "tsl") # restriction that load() statements need to be at the top of .bzl files. # E.g. we can not retrieve a new repository with http_archive and then load() # a macro from that repository in the same file. + +# buildifier: disable=load-on-top + load(":workspace3.bzl", "tsl_workspace3") tsl_workspace3() diff --git a/third_party/tsl/opensource_only.files b/third_party/tsl/opensource_only.files index fa84f35768a5d..2920ef7997ad9 100644 --- a/third_party/tsl/opensource_only.files +++ b/third_party/tsl/opensource_only.files @@ -17,8 +17,6 @@ third_party/ducc/threading.h: third_party/eigen3/BUILD: third_party/eigen3/LICENSE: third_party/eigen3/eigen_archive.BUILD: -third_party/gif.BUILD: -third_party/gif_fix_strtok_r.patch: third_party/git/BUILD.tpl: third_party/git/BUILD: third_party/git/git_configure.bzl: @@ -53,7 +51,6 @@ third_party/llvm_openmp/cmake_vars.bzl: third_party/llvm_openmp/expand_cmake_vars:.py third_party/llvm_openmp/openmp.bzl: third_party/mkl/BUILD: -third_party/mkl/build_defs.bzl: third_party/mkl_dnn/LICENSE: third_party/mkl_dnn/build_defs.bzl: third_party/mkl_dnn/mkldnn_acl.BUILD: @@ -63,10 +60,11 @@ third_party/nccl/LICENSE: third_party/nccl/archive.BUILD: third_party/nccl/archive.patch: third_party/nccl/build_defs.bzl.tpl: +third_party/nccl/generated_names.bzl.tpl: third_party/nccl/nccl_configure.bzl: third_party/nccl/system.BUILD.tpl: -third_party/png.BUILD: -third_party/png_fix_rpi.patch: +third_party/nvtx/BUILD: +third_party/nvtx/LICENSE: third_party/protobuf/BUILD: third_party/py/non_hermetic/BUILD.tpl: third_party/py/non_hermetic/BUILD: @@ -143,9 +141,4 @@ tools/toolchains/win/bazel_211/BUILD: tools/toolchains/win/tf_win_05022023/BUILD: tools/toolchains/win_1803/py38/BUILD: tools/toolchains/win_1803/py39/BUILD: -tsl/cuda/stub.bzl: -tsl/mkl/BUILD: -tsl/mkl/LICENSE: -tsl/mkl/MKL_LICENSE: -tsl/mkl/build_defs.bzl: -tsl/platform/default/build_config/BUILD: +tsl/profiler/BUILD: diff --git a/third_party/tsl/third_party/absl/system.absl.functional.BUILD b/third_party/tsl/third_party/absl/system.absl.functional.BUILD index a4f70acf35ca7..9439bd0ba222e 100644 --- a/third_party/tsl/third_party/absl/system.absl.functional.BUILD +++ b/third_party/tsl/third_party/absl/system.absl.functional.BUILD @@ -2,6 +2,10 @@ load("@rules_cc//cc:defs.bzl", "cc_library") package(default_visibility = ["//visibility:public"]) +cc_library( + name = "any_invocable", +) + cc_library( name = "bind_front", ) diff --git a/third_party/tsl/third_party/curl.BUILD b/third_party/tsl/third_party/curl.BUILD index 017f8cc28170b..14067c2dfd0db 100644 --- a/third_party/tsl/third_party/curl.BUILD +++ b/third_party/tsl/third_party/curl.BUILD @@ -14,11 +14,6 @@ CURL_WIN_COPTS = [ "/DCURL_DISABLE_PROXY", "/DHAVE_LIBZ", "/DHAVE_ZLIB_H", - # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect - # detection of what OS releases we can build on with VC 2012. This - # may not be needed (or may have to change) if the WINVER setting - # changes in //third_party/msvc/vc_12_0/CROSSTOOL. - "/D_USING_V110_SDK71_", ] CURL_WIN_SRCS = [ diff --git a/third_party/tsl/third_party/gif.BUILD b/third_party/tsl/third_party/gif.BUILD deleted file mode 100644 index 51621ba953e6e..0000000000000 --- a/third_party/tsl/third_party/gif.BUILD +++ /dev/null @@ -1,61 +0,0 @@ -# Description: -# A library for decoding and encoding GIF images - -licenses(["notice"]) # MIT - -exports_files(["COPYING"]) - -cc_library( - name = "gif", - srcs = [ - "dgif_lib.c", - "egif_lib.c", - "gif_err.c", - "gif_font.c", - "gif_hash.c", - "gif_hash.h", - "gif_lib_private.h", - "gifalloc.c", - "openbsd-reallocarray.c", - "quantize.c", - ], - hdrs = ["gif_lib.h"], - defines = select({ - ":android": [ - "S_IREAD=S_IRUSR", - "S_IWRITE=S_IWUSR", - "S_IEXEC=S_IXUSR", - ], - "//conditions:default": [], - }), - includes = ["."], - visibility = ["//visibility:public"], - deps = select({ - ":windows": [":windows_polyfill"], - "//conditions:default": [], - }), -) - -cc_library( - name = "windows_polyfill", - hdrs = ["windows/unistd.h"], - includes = ["windows"], -) - -genrule( - name = "windows_unistd_h", - outs = ["windows/unistd.h"], - cmd = "touch $@", -) - -config_setting( - name = "windows", - values = { - "cpu": "x64_windows", - }, -) - -config_setting( - name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, -) diff --git a/third_party/tsl/third_party/gif_fix_strtok_r.patch b/third_party/tsl/third_party/gif_fix_strtok_r.patch deleted file mode 100644 index c9c9c30c41fab..0000000000000 --- a/third_party/tsl/third_party/gif_fix_strtok_r.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff -r -u ./fixed_gif_font.c ./gif_font.c ---- ./fixed_gif_font.c 2019-09-05 11:05:25.009598262 -0700 -+++ ./gif_font.c 2019-09-05 10:52:45.308389085 -0700 -@@ -11,6 +11,11 @@ - - #include "gif_lib.h" - -+// Windows doesn't have strtok_r. -+#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) -+#define strtok_r strtok_s -+#endif -+ - /***************************************************************************** - Ascii 8 by 8 regular font - only first 128 characters are supported. - *****************************************************************************/ diff --git a/third_party/tsl/third_party/git/git_configure.bzl b/third_party/tsl/third_party/git/git_configure.bzl index dd6202639195c..3ce64242af6af 100644 --- a/third_party/tsl/third_party/git/git_configure.bzl +++ b/third_party/tsl/third_party/git/git_configure.bzl @@ -38,10 +38,10 @@ def _git_conf_impl(repository_ctx): ) tensorflow_root_path = str(repository_ctx.path( - Label("@tsl//:BUILD"), + Label("@org_tensorflow//:BUILD"), ))[:-len("BUILD")] python_script_path = repository_ctx.path( - Label("@tsl//tensorflow/tools/git:gen_git_source.py"), + Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py"), ) generated_files_path = repository_ctx.path("gen") diff --git a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 0da1d7b58f4bb..74fafb9b32f51 100755 --- a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -41,7 +41,7 @@ import os import subprocess import re import sys -import pipes +import shlex # Template values set by cuda_autoconf. CPU_COMPILER = ('%{cpu_compiler}') @@ -299,7 +299,7 @@ def main(): if args.x and args.x[0] == 'cuda': if args.cuda_log: Log('-x cuda') - leftover = [pipes.quote(s) for s in leftover] + leftover = [shlex.quote(s) for s in leftover] if args.cuda_log: Log('using nvcc') return InvokeNvcc(leftover, log=args.cuda_log) diff --git a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 77ec948af32c6..d1d44f9fdb2ce 100755 --- a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -75,6 +75,7 @@ def GetHostCompilerOptions(argv): parser.add_argument('--sysroot', nargs=1) parser.add_argument('-g', nargs='*', action='append') parser.add_argument('-fno-canonical-system-headers', action='store_true') + parser.add_argument('--genco', action='store_true') args, _ = parser.parse_known_args(argv) @@ -90,6 +91,8 @@ def GetHostCompilerOptions(argv): opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] + if args.genco: + opts += ' --genco' return opts diff --git a/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index 189d3e3e78400..bc865cecb3240 100644 --- a/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -94,6 +94,25 @@ def if_cuda_is_configured(x, no_cuda = []): return select({"//conditions:default": x}) return select({"//conditions:default": no_cuda}) +def if_cuda_newer_than(wanted_ver, if_true, if_false = []): + """Tests if CUDA was enabled during the configured process and if the + configured version is at least `wanted_ver`. `wanted_ver` needs + to be provided as a string in the format `_`. + Example: `11_0` + """ + + wanted_major = int(wanted_ver.split('_')[0]) + wanted_minor = int(wanted_ver.split('_')[1]) + + configured_version = "%{cuda_version}" + configured_major = int(configured_version.split('.')[0]) + configured_minor = int(configured_version.split('.')[1]) + + if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): + return select({"//conditions:default": if_true}) + return select({"//conditions:default": if_false}) + + def cuda_header_library( name, hdrs, diff --git a/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/tsl/third_party/gpus/cuda_configure.bzl index 7ee96c912280c..f4ed97ac4eb07 100644 --- a/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -26,7 +26,6 @@ * `PYTHON_BIN_PATH`: The python binary path """ -load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") load( "@bazel_tools//tools/cpp:lib_cc_configure.bzl", "escape_string", @@ -38,6 +37,7 @@ load( "find_vc_path", "setup_vc_env_vars", ) +load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -317,11 +317,52 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): ).stdout.strip() + "/share" inc_dirs += "\n" + resource_dir - return [ + compiler_includes = [ _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) for p in inc_dirs.split("\n") ] + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): """Compute the list of default C and C++ include directories.""" @@ -471,6 +512,8 @@ def compute_capabilities(repository_ctx): continue if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue auto_configure_fail("Invalid compute capability: %s" % capability) return capabilities @@ -784,6 +827,7 @@ def _create_dummy_repository(repository_ctx): "%{cuda_is_configured}": "False", "%{cuda_extra_copts}": "[]", "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", }, ) _tpl( @@ -1126,7 +1170,16 @@ def _create_local_cuda_repository(repository_ctx): # Select the headers based on the cuDNN version (strip '64_' for Windows). cudnn_headers = ["cudnn.h"] - if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": + if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "9": + cudnn_headers += [ + "cudnn_adv.h", + "cudnn_backend.h", + "cudnn_cnn.h", + "cudnn_graph.h", + "cudnn_ops.h", + "cudnn_version.h", + ] + elif cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": cudnn_headers += [ "cudnn_backend.h", "cudnn_adv_infer.h", @@ -1162,6 +1215,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_config.compute_capabilities, ), "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, }, ) @@ -1375,6 +1429,7 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo): repository_ctx, compute_capabilities(repository_ctx), ), + "%{cuda_version}": get_host_environ(repository_ctx, _TF_CUDA_VERSION), }, ) repository_ctx.template( diff --git a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 2b4595bb22288..339733755d6f1 100644 --- a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -38,6 +38,16 @@ def rocm_version_number(): """Returns a list of supported GPU architectures.""" return %{rocm_version_number} +def if_gpu_is_configured(if_true, if_false = []): + """Tests if ROCm or CUDA was enabled during the configure process. + + Unlike if_rocm() or if_cuda(), this does not require that we are building + with --config=rocm or --config=cuda, respectively. Used to allow non-GPU + code to depend on ROCm or CUDA libraries. + + """ + return select({"//conditions:default": %{gpu_is_configured}}) + def if_rocm_is_configured(x): """Tests if the ROCm was enabled during the configure process. diff --git a/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/tsl/third_party/gpus/rocm_configure.bzl index 13d630ea8ea2b..66a53276a194c 100644 --- a/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -8,12 +8,6 @@ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ -load( - ":cuda_configure.bzl", - "make_copy_dir_rule", - "make_copy_files_rule", - "to_list_of_strings", -) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -28,6 +22,13 @@ load( "realpath", "which", ) +load( + ":cuda_configure.bzl", + "enable_cuda", + "make_copy_dir_rule", + "make_copy_files_rule", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -449,6 +450,7 @@ def _create_dummy_repository(repository_ctx): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "False", + "%{gpu_is_configured}": "if_true" if enable_cuda(repository_ctx) else "if_false", "%{rocm_extra_copts}": "[]", "%{rocm_gpu_architectures}": "[]", "%{rocm_version_number}": "0", @@ -634,6 +636,7 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["rocm:build_defs.bzl"], { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, rocm_config.amdgpu_targets, @@ -762,6 +765,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo): "rocm:build_defs.bzl", { "%{rocm_is_configured}": "True", + "%{gpu_is_configured}": "if_true", "%{rocm_extra_copts}": _compute_rocm_extra_copts( repository_ctx, [], #_compute_capabilities(repository_ctx) @@ -815,6 +819,7 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", + "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, ] diff --git a/third_party/tsl/third_party/hwloc/hwloc.BUILD b/third_party/tsl/third_party/hwloc/hwloc.BUILD index 52734c6fceb4f..86704f5b7f27a 100644 --- a/third_party/tsl/third_party/hwloc/hwloc.BUILD +++ b/third_party/tsl/third_party/hwloc/hwloc.BUILD @@ -1,5 +1,8 @@ # hwloc: Portable Hardware Locality Library +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + package( default_visibility = ["//visibility:public"], ) @@ -8,9 +11,6 @@ licenses(["notice"]) exports_files(["COPYING"]) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") - COMMON_INCLUDE_COPTS = [ "-I.", "-Ihwloc", diff --git a/third_party/tsl/third_party/implib_so/workspace.bzl b/third_party/tsl/third_party/implib_so/workspace.bzl index 01dad3b169f40..37f36cc135fd6 100644 --- a/third_party/tsl/third_party/implib_so/workspace.bzl +++ b/third_party/tsl/third_party/implib_so/workspace.bzl @@ -6,8 +6,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "implib_so", - strip_prefix = "Implib.so-5fb84c2a750434b9df1da67d67b749eb929598f1", - sha256 = "10de0a616df24849f2a883747784c115f209708960e44556f5ce384de6f103e8", - urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/5fb84c2a750434b9df1da67d67b749eb929598f1.tar.gz"), + strip_prefix = "Implib.so-2cce6cab8ff2c15f9da858ea0b68646a8d62aef2", + sha256 = "4ef3089969d57a5b60bb41b8212c478eaa15c56941f86d4bf5e7f98a3afd24e8", + urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/2cce6cab8ff2c15f9da858ea0b68646a8d62aef2.tar.gz"), build_file = "//third_party/implib_so:implib_so.BUILD", ) diff --git a/third_party/tsl/third_party/jpeg/BUILD b/third_party/tsl/third_party/jpeg/BUILD deleted file mode 100644 index ed1568c32f33e..0000000000000 --- a/third_party/tsl/third_party/jpeg/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -# Needed to make this a package. - -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/tsl/third_party/jpeg/BUILD.system b/third_party/tsl/third_party/jpeg/BUILD.system deleted file mode 100644 index f4f52da9bdae1..0000000000000 --- a/third_party/tsl/third_party/jpeg/BUILD.system +++ /dev/null @@ -1,12 +0,0 @@ -licenses(["notice"]) # custom notice-style license, see LICENSE.md - -filegroup( - name = "LICENSE.md", - visibility = ["//visibility:public"], -) - -cc_library( - name = "jpeg", - linkopts = ["-ljpeg"], - visibility = ["//visibility:public"], -) diff --git a/third_party/tsl/third_party/jpeg/jpeg.BUILD b/third_party/tsl/third_party/jpeg/jpeg.BUILD deleted file mode 100644 index 9f61f9e31e5e1..0000000000000 --- a/third_party/tsl/third_party/jpeg/jpeg.BUILD +++ /dev/null @@ -1,806 +0,0 @@ -# Description: -# libjpeg-turbo is a drop in replacement for jpeglib optimized with SIMD. - -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") - -licenses(["notice"]) # custom notice-style license, see LICENSE.md - -exports_files(["LICENSE.md"]) - -WIN_COPTS = [ - "/Ox", - "-DWITH_SIMD", - "-wd4996", -] - -libjpegturbo_copts = select({ - ":android": [ - "-O3", - "-fPIC", - "-w", - ], - ":windows": WIN_COPTS, - "//conditions:default": [ - "-O3", - "-w", - ], -}) + select({ - ":armeabi-v7a": [ - "-D__ARM_NEON__", - "-DNEON_INTRINSICS", - "-march=armv7-a", - "-mfpu=neon", - "-mfloat-abi=softfp", - "-fprefetch-loop-arrays", - ], - ":arm64-v8a": [ - "-DNEON_INTRINSICS", - ], - ":linux_ppc64le": [ - "-mcpu=power8", - "-mtune=power8", - ], - "//conditions:default": [], -}) - -cc_library( - name = "jpeg", - srcs = [ - "jaricom.c", - "jcapimin.c", - "jcapistd.c", - "jcarith.c", - "jccoefct.c", - "jccolor.c", - "jcdctmgr.c", - "jchuff.c", - "jchuff.h", - "jcinit.c", - "jcmainct.c", - "jcmarker.c", - "jcmaster.c", - "jcomapi.c", - "jconfig.h", - "jconfigint.h", - "jcparam.c", - "jcphuff.c", - "jcprepct.c", - "jcsample.c", - "jctrans.c", - "jdapimin.c", - "jdapistd.c", - "jdarith.c", - "jdatadst.c", - "jdatasrc.c", - "jdcoefct.c", - "jdcoefct.h", - "jdcolor.c", - "jdct.h", - "jddctmgr.c", - "jdhuff.c", - "jdhuff.h", - "jdinput.c", - "jdmainct.c", - "jdmainct.h", - "jdmarker.c", - "jdmaster.c", - "jdmaster.h", - "jdmerge.c", - "jdmerge.h", - "jdphuff.c", - "jdpostct.c", - "jdsample.c", - "jdsample.h", - "jdtrans.c", - "jerror.c", - "jfdctflt.c", - "jfdctfst.c", - "jfdctint.c", - "jidctflt.c", - "jidctfst.c", - "jidctint.c", - "jidctred.c", - "jinclude.h", - "jmemmgr.c", - "jmemnobs.c", - "jmemsys.h", - "jpeg_nbits_table.h", - "jpegcomp.h", - "jquant1.c", - "jquant2.c", - "jutils.c", - "jversion.h", - ], - hdrs = [ - "jccolext.c", # should have been named .inc - "jdcol565.c", # should have been named .inc - "jdcolext.c", # should have been named .inc - "jdmrg565.c", # should have been named .inc - "jdmrgext.c", # should have been named .inc - "jerror.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jstdhuff.c", # should have been named .inc - ], - copts = libjpegturbo_copts, - visibility = ["//visibility:public"], - deps = select({ - ":nosimd": [":simd_none"], - ":k8": [":simd_x86_64"], - ":armeabi-v7a": [":simd_armv7a"], - ":arm64-v8a": [":simd_armv8a"], - ":linux_ppc64le": [":simd_altivec"], - ":windows": [":simd_win_x86_64"], - "//conditions:default": [":simd_none"], - }), -) - -cc_library( - name = "simd_altivec", - srcs = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimd.h", - "jsimddct.h", - "simd/jsimd.h", - "simd/powerpc/jccolor-altivec.c", - "simd/powerpc/jcgray-altivec.c", - "simd/powerpc/jcsample-altivec.c", - "simd/powerpc/jdcolor-altivec.c", - "simd/powerpc/jdmerge-altivec.c", - "simd/powerpc/jdsample-altivec.c", - "simd/powerpc/jfdctfst-altivec.c", - "simd/powerpc/jfdctint-altivec.c", - "simd/powerpc/jidctfst-altivec.c", - "simd/powerpc/jidctint-altivec.c", - "simd/powerpc/jquanti-altivec.c", - "simd/powerpc/jsimd.c", - ], - hdrs = [ - "simd/powerpc/jccolext-altivec.c", - "simd/powerpc/jcgryext-altivec.c", - "simd/powerpc/jcsample.h", - "simd/powerpc/jdcolext-altivec.c", - "simd/powerpc/jdmrgext-altivec.c", - "simd/powerpc/jsimd_altivec.h", - ], - copts = libjpegturbo_copts, -) - -SRCS_SIMD_COMMON = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimddct.h", - "jsimd.h", - "simd/jsimd.h", -] - -cc_library( - name = "simd_x86_64", - srcs = [ - "simd/x86_64/jccolor-avx2.o", - "simd/x86_64/jccolor-sse2.o", - "simd/x86_64/jcgray-avx2.o", - "simd/x86_64/jcgray-sse2.o", - "simd/x86_64/jchuff-sse2.o", - "simd/x86_64/jcphuff-sse2.o", - "simd/x86_64/jcsample-avx2.o", - "simd/x86_64/jcsample-sse2.o", - "simd/x86_64/jdcolor-avx2.o", - "simd/x86_64/jdcolor-sse2.o", - "simd/x86_64/jdmerge-avx2.o", - "simd/x86_64/jdmerge-sse2.o", - "simd/x86_64/jdsample-avx2.o", - "simd/x86_64/jdsample-sse2.o", - "simd/x86_64/jfdctflt-sse.o", - "simd/x86_64/jfdctfst-sse2.o", - "simd/x86_64/jfdctint-avx2.o", - "simd/x86_64/jfdctint-sse2.o", - "simd/x86_64/jidctflt-sse2.o", - "simd/x86_64/jidctfst-sse2.o", - "simd/x86_64/jidctint-avx2.o", - "simd/x86_64/jidctint-sse2.o", - "simd/x86_64/jidctred-sse2.o", - "simd/x86_64/jquantf-sse2.o", - "simd/x86_64/jquanti-avx2.o", - "simd/x86_64/jquanti-sse2.o", - "simd/x86_64/jsimd.c", - "simd/x86_64/jsimdcpu.o", - ] + SRCS_SIMD_COMMON, - copts = libjpegturbo_copts, - linkstatic = 1, -) - -genrule( - name = "simd_x86_64_assemblage23", - srcs = [ - "jconfig.h", - "jconfigint.h", - "simd/x86_64/jccolext-avx2.asm", - "simd/x86_64/jccolext-sse2.asm", - "simd/x86_64/jccolor-avx2.asm", - "simd/x86_64/jccolor-sse2.asm", - "simd/x86_64/jcgray-avx2.asm", - "simd/x86_64/jcgray-sse2.asm", - "simd/x86_64/jcgryext-avx2.asm", - "simd/x86_64/jcgryext-sse2.asm", - "simd/x86_64/jchuff-sse2.asm", - "simd/x86_64/jcphuff-sse2.asm", - "simd/x86_64/jcsample-avx2.asm", - "simd/x86_64/jcsample-sse2.asm", - "simd/x86_64/jdcolext-avx2.asm", - "simd/x86_64/jdcolext-sse2.asm", - "simd/x86_64/jdcolor-avx2.asm", - "simd/x86_64/jdcolor-sse2.asm", - "simd/x86_64/jdmerge-avx2.asm", - "simd/x86_64/jdmerge-sse2.asm", - "simd/x86_64/jdmrgext-avx2.asm", - "simd/x86_64/jdmrgext-sse2.asm", - "simd/x86_64/jdsample-avx2.asm", - "simd/x86_64/jdsample-sse2.asm", - "simd/x86_64/jfdctflt-sse.asm", - "simd/x86_64/jfdctfst-sse2.asm", - "simd/x86_64/jfdctint-avx2.asm", - "simd/x86_64/jfdctint-sse2.asm", - "simd/x86_64/jidctflt-sse2.asm", - "simd/x86_64/jidctfst-sse2.asm", - "simd/x86_64/jidctint-avx2.asm", - "simd/x86_64/jidctint-sse2.asm", - "simd/x86_64/jidctred-sse2.asm", - "simd/x86_64/jquantf-sse2.asm", - "simd/x86_64/jquanti-avx2.asm", - "simd/x86_64/jquanti-sse2.asm", - "simd/x86_64/jsimdcpu.asm", - "simd/nasm/jcolsamp.inc", - "simd/nasm/jdct.inc", - "simd/nasm/jsimdcfg.inc", - "simd/nasm/jsimdcfg.inc.h", - "simd/nasm/jsimdext.inc", - ], - outs = [ - "simd/x86_64/jccolor-avx2.o", - "simd/x86_64/jccolor-sse2.o", - "simd/x86_64/jcgray-avx2.o", - "simd/x86_64/jcgray-sse2.o", - "simd/x86_64/jchuff-sse2.o", - "simd/x86_64/jcphuff-sse2.o", - "simd/x86_64/jcsample-avx2.o", - "simd/x86_64/jcsample-sse2.o", - "simd/x86_64/jdcolor-avx2.o", - "simd/x86_64/jdcolor-sse2.o", - "simd/x86_64/jdmerge-avx2.o", - "simd/x86_64/jdmerge-sse2.o", - "simd/x86_64/jdsample-avx2.o", - "simd/x86_64/jdsample-sse2.o", - "simd/x86_64/jfdctflt-sse.o", - "simd/x86_64/jfdctfst-sse2.o", - "simd/x86_64/jfdctint-avx2.o", - "simd/x86_64/jfdctint-sse2.o", - "simd/x86_64/jidctflt-sse2.o", - "simd/x86_64/jidctfst-sse2.o", - "simd/x86_64/jidctint-avx2.o", - "simd/x86_64/jidctint-sse2.o", - "simd/x86_64/jidctred-sse2.o", - "simd/x86_64/jquantf-sse2.o", - "simd/x86_64/jquanti-avx2.o", - "simd/x86_64/jquanti-sse2.o", - "simd/x86_64/jsimdcpu.o", - ], - cmd = "for out in $(OUTS); do\n" + - " $(location @nasm//:nasm) -f elf64" + - " -DELF -DPIC -D__x86_64__" + - " -I $$(dirname $(location jconfig.h))/" + - " -I $$(dirname $(location jconfigint.h))/" + - " -I $$(dirname $(location simd/nasm/jsimdcfg.inc.h))/" + - " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + - " -o $$out" + - " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.o}.asm)\n" + - "done", - tools = ["@nasm"], -) - -expand_template( - name = "neon-compat_gen", - out = "simd/arm/neon-compat.h", - substitutions = { - "#cmakedefine HAVE_VLD1_S16_X3": "#define HAVE_VLD1_S16_X3", - "#cmakedefine HAVE_VLD1_U16_X2": "#define HAVE_VLD1_U16_X2", - "#cmakedefine HAVE_VLD1Q_U8_X4": "#define HAVE_VLD1Q_U8_X4", - }, - template = "simd/arm/neon-compat.h.in", -) - -genrule( - name = "neon-compat_hdr_src", - srcs = ["simd/arm/neon-compat.h"], - outs = ["neon-compat.h"], - cmd = "cp $(location simd/arm/neon-compat.h) $@", -) - -cc_library( - name = "neon-compat_hdr", - hdrs = ["neon-compat.h"], - copts = libjpegturbo_copts, -) - -SRCS_SIMD_ARM = [ - "simd/arm/jccolor-neon.c", - "simd/arm/jcgray-neon.c", - "simd/arm/jcphuff-neon.c", - "simd/arm/jcsample-neon.c", - "simd/arm/jdcolor-neon.c", - "simd/arm/jdmerge-neon.c", - "simd/arm/jdsample-neon.c", - "simd/arm/jfdctfst-neon.c", - "simd/arm/jfdctint-neon.c", - "simd/arm/jidctfst-neon.c", - "simd/arm/jidctint-neon.c", - "simd/arm/jidctred-neon.c", - "simd/arm/jquanti-neon.c", -] - -# .c files in the following list are used like .h files in that they are -# "#include"-ed in the actual .c files. So, treat them like normal headers, and -# they *should not* be compiled into individual objects. -HDRS_SIMD_ARM = [ - "simd/arm/align.h", - "simd/arm/jchuff.h", - "simd/arm/jcgryext-neon.c", - "simd/arm/jdcolext-neon.c", - "simd/arm/jdmrgext-neon.c", -] - -cc_library( - name = "simd_armv7a", - srcs = [ - "simd/arm/aarch32/jchuff-neon.c", - "simd/arm/aarch32/jsimd.c", - ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM, - hdrs = [ - "simd/arm/aarch32/jccolext-neon.c", - ] + HDRS_SIMD_ARM, - copts = libjpegturbo_copts, - visibility = ["//visibility:private"], - deps = [":neon-compat_hdr"], -) - -cc_library( - name = "simd_armv8a", - srcs = [ - "simd/arm/aarch64/jchuff-neon.c", - "simd/arm/aarch64/jsimd.c", - ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM, - hdrs = [ - "simd/arm/aarch64/jccolext-neon.c", - ] + HDRS_SIMD_ARM, - copts = libjpegturbo_copts, - visibility = ["//visibility:private"], - deps = [":neon-compat_hdr"], -) - -cc_library( - name = "simd_win_x86_64", - srcs = [ - "simd/x86_64/jccolor-avx2.obj", - "simd/x86_64/jccolor-sse2.obj", - "simd/x86_64/jcgray-avx2.obj", - "simd/x86_64/jcgray-sse2.obj", - "simd/x86_64/jchuff-sse2.obj", - "simd/x86_64/jcphuff-sse2.obj", - "simd/x86_64/jcsample-avx2.obj", - "simd/x86_64/jcsample-sse2.obj", - "simd/x86_64/jdcolor-avx2.obj", - "simd/x86_64/jdcolor-sse2.obj", - "simd/x86_64/jdmerge-avx2.obj", - "simd/x86_64/jdmerge-sse2.obj", - "simd/x86_64/jdsample-avx2.obj", - "simd/x86_64/jdsample-sse2.obj", - "simd/x86_64/jfdctflt-sse.obj", - "simd/x86_64/jfdctfst-sse2.obj", - "simd/x86_64/jfdctint-avx2.obj", - "simd/x86_64/jfdctint-sse2.obj", - "simd/x86_64/jidctflt-sse2.obj", - "simd/x86_64/jidctfst-sse2.obj", - "simd/x86_64/jidctint-avx2.obj", - "simd/x86_64/jidctint-sse2.obj", - "simd/x86_64/jidctred-sse2.obj", - "simd/x86_64/jquantf-sse2.obj", - "simd/x86_64/jquanti-avx2.obj", - "simd/x86_64/jquanti-sse2.obj", - "simd/x86_64/jsimd.c", - "simd/x86_64/jsimdcpu.obj", - ] + SRCS_SIMD_COMMON, - copts = libjpegturbo_copts, -) - -genrule( - name = "simd_win_x86_64_assemble", - srcs = [ - "jconfig.h", - "jconfigint.h", - "simd/x86_64/jccolext-avx2.asm", - "simd/x86_64/jccolext-sse2.asm", - "simd/x86_64/jccolor-avx2.asm", - "simd/x86_64/jccolor-sse2.asm", - "simd/x86_64/jcgray-avx2.asm", - "simd/x86_64/jcgray-sse2.asm", - "simd/x86_64/jcgryext-avx2.asm", - "simd/x86_64/jcgryext-sse2.asm", - "simd/x86_64/jchuff-sse2.asm", - "simd/x86_64/jcphuff-sse2.asm", - "simd/x86_64/jcsample-avx2.asm", - "simd/x86_64/jcsample-sse2.asm", - "simd/x86_64/jdcolext-avx2.asm", - "simd/x86_64/jdcolext-sse2.asm", - "simd/x86_64/jdcolor-avx2.asm", - "simd/x86_64/jdcolor-sse2.asm", - "simd/x86_64/jdmerge-avx2.asm", - "simd/x86_64/jdmerge-sse2.asm", - "simd/x86_64/jdmrgext-avx2.asm", - "simd/x86_64/jdmrgext-sse2.asm", - "simd/x86_64/jdsample-avx2.asm", - "simd/x86_64/jdsample-sse2.asm", - "simd/x86_64/jfdctflt-sse.asm", - "simd/x86_64/jfdctfst-sse2.asm", - "simd/x86_64/jfdctint-avx2.asm", - "simd/x86_64/jfdctint-sse2.asm", - "simd/x86_64/jidctflt-sse2.asm", - "simd/x86_64/jidctfst-sse2.asm", - "simd/x86_64/jidctint-avx2.asm", - "simd/x86_64/jidctint-sse2.asm", - "simd/x86_64/jidctred-sse2.asm", - "simd/x86_64/jquantf-sse2.asm", - "simd/x86_64/jquanti-avx2.asm", - "simd/x86_64/jquanti-sse2.asm", - "simd/x86_64/jsimdcpu.asm", - "simd/nasm/jcolsamp.inc", - "simd/nasm/jdct.inc", - "simd/nasm/jsimdcfg.inc", - "simd/nasm/jsimdcfg.inc.h", - "simd/nasm/jsimdext.inc", - ], - outs = [ - "simd/x86_64/jccolor-avx2.obj", - "simd/x86_64/jccolor-sse2.obj", - "simd/x86_64/jcgray-avx2.obj", - "simd/x86_64/jcgray-sse2.obj", - "simd/x86_64/jchuff-sse2.obj", - "simd/x86_64/jcphuff-sse2.obj", - "simd/x86_64/jcsample-avx2.obj", - "simd/x86_64/jcsample-sse2.obj", - "simd/x86_64/jdcolor-avx2.obj", - "simd/x86_64/jdcolor-sse2.obj", - "simd/x86_64/jdmerge-avx2.obj", - "simd/x86_64/jdmerge-sse2.obj", - "simd/x86_64/jdsample-avx2.obj", - "simd/x86_64/jdsample-sse2.obj", - "simd/x86_64/jfdctflt-sse.obj", - "simd/x86_64/jfdctfst-sse2.obj", - "simd/x86_64/jfdctint-avx2.obj", - "simd/x86_64/jfdctint-sse2.obj", - "simd/x86_64/jidctflt-sse2.obj", - "simd/x86_64/jidctfst-sse2.obj", - "simd/x86_64/jidctint-avx2.obj", - "simd/x86_64/jidctint-sse2.obj", - "simd/x86_64/jidctred-sse2.obj", - "simd/x86_64/jquantf-sse2.obj", - "simd/x86_64/jquanti-avx2.obj", - "simd/x86_64/jquanti-sse2.obj", - "simd/x86_64/jsimdcpu.obj", - ], - cmd = "for out in $(OUTS); do\n" + - " $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" + - " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + - " -I $$(dirname $(location simd/nasm/jdct.inc))/" + - " -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" + - " -o $$out" + - " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" + - "done", - tools = ["@nasm"], -) - -cc_library( - name = "simd_none", - srcs = [ - "jchuff.h", - "jconfig.h", - "jconfigint.h", - "jdct.h", - "jerror.h", - "jinclude.h", - "jmorecfg.h", - "jpegint.h", - "jpeglib.h", - "jsimd.h", - "jsimd_none.c", - "jsimddct.h", - ], - copts = libjpegturbo_copts, -) - -expand_template( - name = "jversion", - out = "jversion.h", - substitutions = { - "@COPYRIGHT_YEAR@": "1991-2022", - }, - template = "jversion.h.in", -) - -expand_template( - name = "jconfig_win", - out = "jconfig_win.h", - substitutions = { - "@JPEG_LIB_VERSION@": "62", - "@VERSION@": "2.1.4", - "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004", - "@BITS_IN_JSAMPLE@": "8", - "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED", - "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED", - "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED", - "#cmakedefine WITH_SIMD": "", - }, - template = "win/jconfig.h.in", -) - -JCONFIG_NOWIN_COMMON_SUBSTITUTIONS = { - "@JPEG_LIB_VERSION@": "62", - "@VERSION@": "2.1.4", - "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004", - "#cmakedefine C_ARITH_CODING_SUPPORTED 1": "#define C_ARITH_CODING_SUPPORTED 1", - "#cmakedefine D_ARITH_CODING_SUPPORTED 1": "#define D_ARITH_CODING_SUPPORTED 1", - "#cmakedefine MEM_SRCDST_SUPPORTED 1": "#define MEM_SRCDST_SUPPORTED 1", - "@BITS_IN_JSAMPLE@": "8", - "#cmakedefine HAVE_LOCALE_H 1": "#define HAVE_LOCALE_H 1", - "#cmakedefine HAVE_STDDEF_H 1": "#define HAVE_STDDEF_H 1", - "#cmakedefine HAVE_STDLIB_H 1": "#define HAVE_STDLIB_H 1", - "#cmakedefine NEED_SYS_TYPES_H 1": "#define NEED_SYS_TYPES_H 1", - "#cmakedefine NEED_BSD_STRINGS 1": "", - "#cmakedefine HAVE_UNSIGNED_CHAR 1": "#define HAVE_UNSIGNED_CHAR 1", - "#cmakedefine HAVE_UNSIGNED_SHORT 1": "#define HAVE_UNSIGNED_SHORT 1", - "#cmakedefine INCOMPLETE_TYPES_BROKEN 1": "", - "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED 1": "", - "#cmakedefine __CHAR_UNSIGNED__ 1": "", - "#undef const": "", - "#undef size_t": "", -} - -JCONFIG_NOWIN_SIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD 1": "#define WITH_SIMD 1", -} - -JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS = { - "#cmakedefine WITH_SIMD 1": "", -} - -JCONFIG_NOWIN_SIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) - -JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) - -expand_template( - name = "jconfig_nowin_nosimd", - out = "jconfig_nowin_nosimd.h", - substitutions = JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS, - template = "jconfig.h.in", -) - -expand_template( - name = "jconfig_nowin_simd", - out = "jconfig_nowin_simd.h", - substitutions = JCONFIG_NOWIN_SIMD_SUBSTITUTIONS, - template = "jconfig.h.in", -) - -JCONFIGINT_COMMON_SUBSTITUTIONS = { - "@BUILD@": "20221022", - "@VERSION@": "2.1.4", - "@CMAKE_PROJECT_NAME@": "libjpeg-turbo", - "#undef inline": "", - "#cmakedefine HAVE_INTRIN_H": "", -} - -JCONFIGINT_NOWIN_SUBSTITUTIONS = { - "#cmakedefine HAVE_BUILTIN_CTZL": "#define HAVE_BUILTIN_CTZL", - "@INLINE@": "inline __attribute__((always_inline))", - "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64 && !defined(__native_client__))\n" + - "#define SIZEOF_SIZE_T 8\n" + - "#else\n" + - "#define SIZEOF_SIZE_T 4\n" + - "#endif\n", -} - -JCONFIGINT_WIN_SUBSTITUTIONS = { - "#cmakedefine HAVE_BUILTIN_CTZL": "", - "#define INLINE @INLINE@": "#if defined(__GNUC__)\n" + - "#define INLINE inline __attribute__((always_inline))\n" + - "#elif defined(_MSC_VER)\n" + - "#define INLINE __forceinline\n" + - "#else\n" + - "#define INLINE\n" + - "#endif\n", - "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64)\n" + - "#define SIZEOF_SIZE_T 8\n" + - "#else\n" + - "#define SIZEOF_SIZE_T 4\n" + - "#endif\n", -} - -JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) - -JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) - -expand_template( - name = "jconfigint_nowin", - out = "jconfigint_nowin.h", - substitutions = JCONFIGINT_NOWIN_SUBSTITUTIONS, - template = "jconfigint.h.in", -) - -expand_template( - name = "jconfigint_win", - out = "jconfigint_win.h", - substitutions = JCONFIGINT_WIN_SUBSTITUTIONS, - template = "jconfigint.h.in", -) - -genrule( - name = "configure", - srcs = [ - "jconfig_win.h", - "jconfig_nowin_nosimd.h", - "jconfig_nowin_simd.h", - ], - outs = ["jconfig.h"], - cmd = select({ - ":windows": "cp $(location jconfig_win.h) $@", - ":k8": "cp $(location jconfig_nowin_simd.h) $@", - ":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@", - ":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@", - ":linux_ppc64le": "cp $(location jconfig_nowin_simd.h) $@", - "//conditions:default": "cp $(location jconfig_nowin_nosimd.h) $@", - }), -) - -genrule( - name = "configure_internal", - srcs = [ - "jconfigint_win.h", - "jconfigint_nowin.h", - ], - outs = ["jconfigint.h"], - cmd = select({ - ":windows": "cp $(location jconfigint_win.h) $@", - "//conditions:default": "cp $(location jconfigint_nowin.h) $@", - }), -) - -# jiminy cricket the way this file is generated is completely outrageous -genrule( - name = "configure_simd", - outs = ["simd/jsimdcfg.inc"], - cmd = "cat <<'EOF' >$@\n" + - "%define DCTSIZE 8\n" + - "%define DCTSIZE2 64\n" + - "%define RGB_RED 0\n" + - "%define RGB_GREEN 1\n" + - "%define RGB_BLUE 2\n" + - "%define RGB_PIXELSIZE 3\n" + - "%define EXT_RGB_RED 0\n" + - "%define EXT_RGB_GREEN 1\n" + - "%define EXT_RGB_BLUE 2\n" + - "%define EXT_RGB_PIXELSIZE 3\n" + - "%define EXT_RGBX_RED 0\n" + - "%define EXT_RGBX_GREEN 1\n" + - "%define EXT_RGBX_BLUE 2\n" + - "%define EXT_RGBX_PIXELSIZE 4\n" + - "%define EXT_BGR_RED 2\n" + - "%define EXT_BGR_GREEN 1\n" + - "%define EXT_BGR_BLUE 0\n" + - "%define EXT_BGR_PIXELSIZE 3\n" + - "%define EXT_BGRX_RED 2\n" + - "%define EXT_BGRX_GREEN 1\n" + - "%define EXT_BGRX_BLUE 0\n" + - "%define EXT_BGRX_PIXELSIZE 4\n" + - "%define EXT_XBGR_RED 3\n" + - "%define EXT_XBGR_GREEN 2\n" + - "%define EXT_XBGR_BLUE 1\n" + - "%define EXT_XBGR_PIXELSIZE 4\n" + - "%define EXT_XRGB_RED 1\n" + - "%define EXT_XRGB_GREEN 2\n" + - "%define EXT_XRGB_BLUE 3\n" + - "%define EXT_XRGB_PIXELSIZE 4\n" + - "%define RGBX_FILLER_0XFF 1\n" + - "%define JSAMPLE byte ; unsigned char\n" + - "%define SIZEOF_JSAMPLE SIZEOF_BYTE ; sizeof(JSAMPLE)\n" + - "%define CENTERJSAMPLE 128\n" + - "%define JCOEF word ; short\n" + - "%define SIZEOF_JCOEF SIZEOF_WORD ; sizeof(JCOEF)\n" + - "%define JDIMENSION dword ; unsigned int\n" + - "%define SIZEOF_JDIMENSION SIZEOF_DWORD ; sizeof(JDIMENSION)\n" + - "%define JSAMPROW POINTER ; JSAMPLE * (jpeglib.h)\n" + - "%define JSAMPARRAY POINTER ; JSAMPROW * (jpeglib.h)\n" + - "%define JSAMPIMAGE POINTER ; JSAMPARRAY * (jpeglib.h)\n" + - "%define JCOEFPTR POINTER ; JCOEF * (jpeglib.h)\n" + - "%define SIZEOF_JSAMPROW SIZEOF_POINTER ; sizeof(JSAMPROW)\n" + - "%define SIZEOF_JSAMPARRAY SIZEOF_POINTER ; sizeof(JSAMPARRAY)\n" + - "%define SIZEOF_JSAMPIMAGE SIZEOF_POINTER ; sizeof(JSAMPIMAGE)\n" + - "%define SIZEOF_JCOEFPTR SIZEOF_POINTER ; sizeof(JCOEFPTR)\n" + - "%define DCTELEM word ; short\n" + - "%define SIZEOF_DCTELEM SIZEOF_WORD ; sizeof(DCTELEM)\n" + - "%define float FP32 ; float\n" + - "%define SIZEOF_FAST_FLOAT SIZEOF_FP32 ; sizeof(float)\n" + - "%define ISLOW_MULT_TYPE word ; must be short\n" + - "%define SIZEOF_ISLOW_MULT_TYPE SIZEOF_WORD ; sizeof(ISLOW_MULT_TYPE)\n" + - "%define IFAST_MULT_TYPE word ; must be short\n" + - "%define SIZEOF_IFAST_MULT_TYPE SIZEOF_WORD ; sizeof(IFAST_MULT_TYPE)\n" + - "%define IFAST_SCALE_BITS 2 ; fractional bits in scale factors\n" + - "%define FLOAT_MULT_TYPE FP32 ; must be float\n" + - "%define SIZEOF_FLOAT_MULT_TYPE SIZEOF_FP32 ; sizeof(FLOAT_MULT_TYPE)\n" + - "%define JSIMD_NONE 0x00\n" + - "%define JSIMD_MMX 0x01\n" + - "%define JSIMD_3DNOW 0x02\n" + - "%define JSIMD_SSE 0x04\n" + - "%define JSIMD_SSE2 0x08\n" + - "EOF", -) - -string_flag( - name = "noasm", - build_setting_default = "no", -) - -config_setting( - name = "nosimd", - flag_values = {":noasm": "yes"}, -) - -config_setting( - name = "k8", - flag_values = {":noasm": "no"}, - values = {"cpu": "k8"}, -) - -config_setting( - name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, -) - -config_setting( - name = "armeabi-v7a", - flag_values = {":noasm": "no"}, - values = {"cpu": "armeabi-v7a"}, -) - -config_setting( - name = "arm64-v8a", - flag_values = {":noasm": "no"}, - values = {"cpu": "arm64-v8a"}, -) - -config_setting( - name = "windows", - flag_values = {":noasm": "no"}, - values = {"cpu": "x64_windows"}, -) - -config_setting( - name = "linux_ppc64le", - flag_values = {":noasm": "no"}, - values = {"cpu": "ppc"}, -) diff --git a/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel b/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel deleted file mode 100644 index 5b01f6e3e4cfd..0000000000000 --- a/third_party/tsl/third_party/jpeg/jpeg_helpers.BUILD.bazel +++ /dev/null @@ -1 +0,0 @@ -licenses(["notice"]) diff --git a/third_party/tsl/third_party/jpeg/workspace.bzl b/third_party/tsl/third_party/jpeg/workspace.bzl deleted file mode 100644 index 631cc933bc60d..0000000000000 --- a/third_party/tsl/third_party/jpeg/workspace.bzl +++ /dev/null @@ -1,13 +0,0 @@ -"""loads the jpeg library, used by TF.""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - tf_http_archive( - name = "libjpeg_turbo", - urls = tf_mirror_urls("https://github.com/libjpeg-turbo/libjpeg-turbo/archive/refs/tags/2.1.4.tar.gz"), - sha256 = "a78b05c0d8427a90eb5b4eb08af25309770c8379592bb0b8a863373128e6143f", - strip_prefix = "libjpeg-turbo-2.1.4", - build_file = "//third_party/jpeg:jpeg.BUILD", - system_build_file = "//third_party/jpeg:BUILD.system", - ) diff --git a/third_party/tsl/third_party/llvm/build.patch b/third_party/tsl/third_party/llvm/build.patch index bbf8f587acada..479e08cde869a 100644 --- a/third_party/tsl/third_party/llvm/build.patch +++ b/third_party/tsl/third_party/llvm/build.patch @@ -1,8 +1,8 @@ diff --git a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -index 2b88729d748b..e12d979b4908 100644 +index 7770284e5543..0b45127495dc 100644 --- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -@@ -207,13 +207,15 @@ cc_library( +@@ -218,13 +218,15 @@ cc_library( "lib/Support/BLAKE3/llvm_blake3_prefix.h", ] + select({ "@platforms//cpu:aarch64": [ @@ -23,7 +23,7 @@ index 2b88729d748b..e12d979b4908 100644 ], "//conditions:default": [ ], -@@ -238,14 +240,16 @@ cc_library( +@@ -249,14 +251,16 @@ cc_library( ], copts = llvm_copts, defines = select({ diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index ce1937af46e5d..509398da979e8 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,12 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -594,6 +594,7 @@ - name = "__support_bit", - hdrs = ["src/__support/bit.h"], - deps = [ -+ ":__support_cpp_type_traits", - ":__support_macros_attributes", - ], - ) diff --git a/third_party/tsl/third_party/llvm/toolchains.patch b/third_party/tsl/third_party/llvm/toolchains.patch index dc45d4d4987dc..a4de4eaaff343 100644 --- a/third_party/tsl/third_party/llvm/toolchains.patch +++ b/third_party/tsl/third_party/llvm/toolchains.patch @@ -34,12 +34,12 @@ index c43ab727e285..7d848d2dffae 100644 # The necessary warnings and other compile flags should be provided by the # toolchain or the `.bazelrc` file. This is just a workaround until we have a diff --git a/utils/bazel/llvm-project-overlay/llvm/config.bzl b/utils/bazel/llvm-project-overlay/llvm/config.bzl -index b15ec9e1bb39..56c2766872fa 100644 +index 2e3bff53ead9..8d01617effdc 100644 --- a/utils/bazel/llvm-project-overlay/llvm/config.bzl +++ b/utils/bazel/llvm-project-overlay/llvm/config.bzl -@@ -89,8 +89,9 @@ os_defines = select({ +@@ -98,8 +98,9 @@ builtin_thread_pointer = select({ # TODO: We should split out host vs. target here. - llvm_config_defines = os_defines + select({ + llvm_config_defines = os_defines + builtin_thread_pointer + select({ "@bazel_tools//src/conditions:windows": native_arch_defines("X86", "x86_64-pc-win32"), - "@bazel_tools//src/conditions:darwin_arm64": native_arch_defines("AArch64", "arm64-apple-darwin"), - "@bazel_tools//src/conditions:darwin_x86_64": native_arch_defines("X86", "x86_64-unknown-darwin"), diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 62f02cded785c..3bcd8e242c18f 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f688e0901213726feb9b26cedc61919413cbf59c" - LLVM_SHA256 = "b8885c22a9b77f9c91a316b21d71414a7b48dae38513f170da1554002e85b030" + LLVM_COMMIT = "8ee6ab7f69ca9c34eed56faad3971d075dc47121" + LLVM_SHA256 = "c408d2a80a53057fc3596cbfbea3ec64fe8feccbff7adc5e94fde192b96ca568" tf_http_archive( name = name, diff --git a/third_party/tsl/third_party/llvm_openmp/BUILD b/third_party/tsl/third_party/llvm_openmp/BUILD index 4fb09728661a2..52d2e3aa4b611 100644 --- a/third_party/tsl/third_party/llvm_openmp/BUILD +++ b/third_party/tsl/third_party/llvm_openmp/BUILD @@ -1,11 +1,6 @@ # Build file for OpenMP library that is part of llvm -load( - "@tsl//tsl:tsl.bzl", - "if_linux_x86_64", - "if_macos", - "if_windows", -) +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( "@tsl//third_party/llvm_openmp:cmake_vars.bzl", "cmake_var_string", @@ -16,7 +11,12 @@ load( "dict_add", "libiomp5_cc_binary", ) -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load( + "@tsl//tsl:tsl.bzl", + "if_linux_x86_64", + "if_macos", + "if_windows", +) package( default_visibility = [ diff --git a/third_party/tsl/third_party/llvm_openmp/openmp.bzl b/third_party/tsl/third_party/llvm_openmp/openmp.bzl index af212f5a8340c..9e53a62436bf5 100644 --- a/third_party/tsl/third_party/llvm_openmp/openmp.bzl +++ b/third_party/tsl/third_party/llvm_openmp/openmp.bzl @@ -4,7 +4,7 @@ after the TF 2.4 branch cut has passed. """ load( - "//tsl/platform:rules_cc.bzl", + "@tsl//tsl/platform:rules_cc.bzl", "cc_binary", ) diff --git a/third_party/tsl/third_party/mkl/BUILD b/third_party/tsl/third_party/mkl/BUILD index 6da193d41ba06..067771b43f7e6 100644 --- a/third_party/tsl/third_party/mkl/BUILD +++ b/third_party/tsl/third_party/mkl/BUILD @@ -1,66 +1 @@ licenses(["notice"]) # 3-Clause BSD - -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - -package(default_visibility = ["//visibility:public"]) - -alias( - name = "build_with_mkl", - actual = "//tsl/mkl:build_with_mkl", -) - -alias( - name = "build_with_mkl_lnx_x64", - actual = "//tsl/mkl:build_with_mkl_lnx_x64", -) - -alias( - name = "build_with_mkl_lnx_openmp", - actual = "//tsl/mkl:build_with_mkl_lnx_openmp", -) - -alias( - name = "build_with_mkl_windows_openmp", - actual = "//tsl/mkl:build_with_mkl_windows_openmp", -) - -alias( - name = "build_with_mkl_aarch64", - actual = "//tsl/mkl:build_with_mkl_aarch64", -) - -alias( - name = "enable_mkl", - actual = "//tsl/mkl:enable_mkl", -) - -alias( - name = "intel_binary_blob", - actual = "//tsl/mkl:intel_binary_blob", -) - -alias( - name = "LICENSE", - actual = "//tsl/mkl:LICENSE", -) - -alias( - name = "mkl_libs_linux", - actual = "//tsl/mkl:mkl_libs_linux", -) - -alias( - name = "mkl_libs_darwin", - actual = "//tsl/mkl:mkl_libs_darwin", -) - -alias( - name = "mkl_libs_windows", - actual = "//tsl/mkl:mkl_libs_windows", -) - -bzl_library( - name = "build_defs_bzl", - srcs = ["build_defs.bzl"], - visibility = ["//visibility:public"], -) diff --git a/third_party/tsl/third_party/mkl/build_defs.bzl b/third_party/tsl/third_party/mkl/build_defs.bzl deleted file mode 100644 index 16250861c05cc..0000000000000 --- a/third_party/tsl/third_party/mkl/build_defs.bzl +++ /dev/null @@ -1,30 +0,0 @@ -"""Starlark macros for MKL. - -if_mkl is a conditional to check if we are building with MKL. -if_mkl_ml is a conditional to check if we are building with MKL-ML. -if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode. -if_mkl_lnx_x64 is a conditional to check for MKL -if_enable_mkl is a conditional to check if building with MKL and MKL is enabled. - -mkl_repository is a repository rule for creating MKL repository rule that can -be pointed to either a local folder, or downloaded from the internet. -mkl_repository depends on the following environment variables: - * `TF_MKL_ROOT`: The root folder where a copy of libmkl is located. -""" - -load( - "@tsl//tsl/mkl:build_defs.bzl", - _if_enable_mkl = "if_enable_mkl", - _if_mkl = "if_mkl", - _if_mkl_lnx_x64 = "if_mkl_lnx_x64", - _if_mkl_ml = "if_mkl_ml", - _mkl_deps = "mkl_deps", - _mkl_repository = "mkl_repository", -) - -if_mkl = _if_mkl -if_mkl_ml = _if_mkl_ml -if_mkl_lnx_x64 = _if_mkl_lnx_x64 -if_enable_mkl = _if_enable_mkl -mkl_deps = _mkl_deps -mkl_repository = _mkl_repository diff --git a/third_party/tsl/third_party/mkl_dnn/build_defs.bzl b/third_party/tsl/third_party/mkl_dnn/build_defs.bzl index 3df468cd5ebb6..00abf41cd924b 100644 --- a/third_party/tsl/third_party/mkl_dnn/build_defs.bzl +++ b/third_party/tsl/third_party/mkl_dnn/build_defs.bzl @@ -23,7 +23,7 @@ def if_mkldnn_openmp(if_true, if_false = []): def if_mkldnn_aarch64_acl(if_true, if_false = []): return select({ - "@tsl//third_party/mkl:build_with_mkl_aarch64": if_true, + "@xla//xla/tsl/mkl:build_with_mkl_aarch64": if_true, "//conditions:default": if_false, }) diff --git a/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD index 53b77d340668f..13ce67110b465 100644 --- a/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -1,7 +1,7 @@ -exports_files(["LICENSE"]) - load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +exports_files(["LICENSE"]) + _DNNL_COPTS_THREADPOOL = [ "-fopenmp-simd", "-fexceptions", @@ -26,6 +26,7 @@ _DNNL_RUNTIME_THREADPOOL = { "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", diff --git a/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD index d042803166ffc..090c5a0717fd6 100644 --- a/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -1,8 +1,7 @@ -load("@tsl//tsl:tsl.bzl", "tf_openmp_copts") -load("@tsl//third_party/mkl:build_defs.bzl", "if_mkl") -load("@tsl//third_party/mkl_dnn:build_defs.bzl", "if_mkldnn_openmp") -load("@tsl//third_party/mkl:build_defs.bzl", "if_mkl_ml") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@tsl//third_party/mkl_dnn:build_defs.bzl", "if_mkldnn_openmp") +load("@tsl//tsl:tsl.bzl", "tf_openmp_copts") +load("@xla//xla/tsl/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml") exports_files(["LICENSE"]) @@ -14,8 +13,9 @@ _CMAKE_COMMON_LIST = { "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", + "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE", + "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", @@ -95,7 +95,7 @@ expand_template( substitutions = { "@DNNL_VERSION_MAJOR@": "3", "@DNNL_VERSION_MINOR@": "3", - "@DNNL_VERSION_PATCH@": "0", + "@DNNL_VERSION_PATCH@": "4", "@DNNL_VERSION_HASH@": "N/A", }, template = "include/oneapi/dnnl/dnnl_version.h.in", @@ -179,7 +179,7 @@ cc_library( textual_hdrs = _TEXTUAL_HDRS_LIST, visibility = ["//visibility:public"], deps = [":onednn_autogen"] + if_mkl_ml( - ["@tsl//third_party/mkl:intel_binary_blob"], + ["@xla//xla/tsl/mkl:intel_binary_blob"], [], ), ) diff --git a/third_party/tsl/third_party/mkl_dnn/onednn_acl_indirect_conv.patch b/third_party/tsl/third_party/mkl_dnn/onednn_acl_indirect_conv.patch new file mode 100644 index 0000000000000..217e668352de5 --- /dev/null +++ b/third_party/tsl/third_party/mkl_dnn/onednn_acl_indirect_conv.patch @@ -0,0 +1,31 @@ + ******************************************************************************* + Copyright 2024 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp +index f043fee4bc..0384cce757 100644 +--- a/src/cpu/aarch64/acl_convolution_utils.cpp ++++ b/src/cpu/aarch64/acl_convolution_utils.cpp +@@ -313,10 +313,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + +- // Indirect is slower than gemm for low thread counts, except for fast math +- if (dnnl_get_max_threads() < 28 && !acp.fast_math) +- return status::unimplemented; +- + // If we do not need to pad input channels for fast math mode then it would + // be faster to run convolution with im2row instead of using indirect kernel + int block_by = arm_compute::block_by(acp.weights_info.weight_format()); diff --git a/third_party/tsl/third_party/nccl/archive.BUILD b/third_party/tsl/third_party/nccl/archive.BUILD index 24e1804f74914..5c5f25d376044 100644 --- a/third_party/tsl/third_party/nccl/archive.BUILD +++ b/third_party/tsl/third_party/nccl/archive.BUILD @@ -1,10 +1,6 @@ # NVIDIA NCCL 2 # A package of optimized primitives for collective multi-GPU communication. -licenses(["notice"]) - -exports_files(["LICENSE.txt"]) - load("@bazel_skylib//rules:expand_template.bzl", "expand_template") load("@bazel_skylib//rules:write_file.bzl", "write_file") load( @@ -14,14 +10,21 @@ load( load( "@local_config_nccl//:build_defs.bzl", "cuda_rdc_library", - "gen_device_srcs", ) +load( + "@local_config_nccl//:generated_names.bzl", + "GENERATED_SOURCES", +) + +licenses(["notice"]) + +exports_files(["LICENSE.txt"]) NCCL_MAJOR = 2 -NCCL_MINOR = 18 +NCCL_MINOR = 19 -NCCL_PATCH = 5 +NCCL_PATCH = 3 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH # e.g., 21605 @@ -73,37 +76,63 @@ cc_library( cc_library( name = "device_hdrs", - hdrs = glob(["src/collectives/device/*.h"]), - strip_include_prefix = "src/collectives/device", + hdrs = glob(["src/device/**/*.h"]), + strip_include_prefix = "src/device", ) -# NCCL compiles the same source files with different NCCL_OP/NCCL_TYPE defines. -# RDC compilation requires that each compiled module has a unique ID. Clang -# derives the module ID from the path only so we need to copy the files to get -# different IDs for different parts of compilation. NVCC does not have that -# problem because it generates IDs based on preprocessed content. -gen_device_srcs( - name = "device_srcs", - srcs = [ - "src/collectives/device/all_gather.cu.cc", - "src/collectives/device/all_reduce.cu.cc", - "src/collectives/device/broadcast.cu.cc", - "src/collectives/device/reduce.cu.cc", - "src/collectives/device/reduce_scatter.cu.cc", - "src/collectives/device/sendrecv.cu.cc", - ], +py_binary( + name = "generate", + srcs = ["src/device/generate.py"], + python_version = "PY3", +) + +genrule( + name = "generated_srcs", + srcs = [], + outs = ["result.tar"], + cmd = """ + mkdir -p src/device/generated + $(location :generate) src/device/generated + tar -cf $@ src + """, + tools = [":generate"], +) + +genrule( + name = "generated_sources", + srcs = ["generated_srcs"], + outs = ["generated_names.bzl"], + cmd = """ + echo '"List of sources generated by :generate_nccl_kernels"' > $@ + echo "GENERATED_SOURCES = [" >> $@ + tar --list -f $< | grep '.cc' | sort | sed -e 's/\\(.*\\)/ "\\1",/' >> $@ + echo "]" >> $@ + """, +) + +EXTRACT_CMD = """ + set -x + OUTDIR=$$(mktemp -d) + tar -C $$OUTDIR -xf $(location :generated_srcs) + for outf in $(OUTS); do + F=$$(echo $$outf | sed -e 's@.*/src/device/generated/@@') + mv $$OUTDIR/src/device/generated/$$F $$outf + done +""" + +genrule( + name = "generated_files", + srcs = [":generated_srcs"], + outs = GENERATED_SOURCES, + cmd = EXTRACT_CMD, ) cuda_rdc_library( name = "device", srcs = [ - "src/collectives/device/functions.cu.cc", - "src/collectives/device/onerank_reduce.cu.cc", - ":device_srcs", - ] + glob([ - # Required for header inclusion checking, see below for details. - "src/collectives/device/*.h", - "src/nccl.h", + ":generated_files", + ] + glob(include = [ + "src/device/**/*.cu.cc", ]), deps = [ ":device_hdrs", @@ -134,7 +163,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@tsl//tsl/cuda:nccl_stub", + "@xla//xla/tsl/cuda:nccl_stub", ], ) @@ -159,7 +188,7 @@ cc_library( ], # Exclude device-library code. exclude = [ - "src/collectives/device/**", + "src/device/**", "src/transport/coll_net.cc", "src/transport/net.cc", "src/enqueue.cc", diff --git a/third_party/tsl/third_party/nccl/archive.patch b/third_party/tsl/third_party/nccl/archive.patch index 8ef0af95a1c6c..372c0f493fc38 100644 --- a/third_party/tsl/third_party/nccl/archive.patch +++ b/third_party/tsl/third_party/nccl/archive.patch @@ -1,46 +1,113 @@ -diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu.cc +diff --git a/src/device/common.cu b/src/device/common.cu.cc similarity index 100% -rename from src/collectives/device/all_gather.cu -rename to src/collectives/device/all_gather.cu.cc -diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu.cc +rename from src/device/common.cu +rename to src/device/common.cu.cc +diff --git a/src/device/common.h b/src/device/common.h +index 97581f7..134fdb8 100644 +--- a/src/device/common.h ++++ b/src/device/common.h +@@ -15,7 +15,7 @@ + #define COLL_UNROLL (ncclCollUnroll()) + + typedef void(*ncclDevFuncPtr_t)(); +-extern __device__ ncclDevFuncPtr_t const ncclDevFuncTable[]; ++extern __device__ ncclDevFuncPtr_t ncclDevFuncTable[]; + + struct ncclShmemGroup { + ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY]; +diff --git a/src/device/generate.py b/src/device/generate.py +index 0b053de..87bf6cb 100755 +--- a/src/device/generate.py ++++ b/src/device/generate.py +@@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) + ################################################################################ + + # Generate /device_table.cu +-with open(os.path.join(gensrc, "device_table.cu"), "w") as f: ++with open(os.path.join(gensrc, "device_table.cu.cc"), "w") as f: + out = f.write + out('#include "common.h"\n') + out("\n") +@@ -210,7 +210,7 @@ with open(os.path.join(gensrc, "device_table.cu"), "w") as f: + out("#endif\n") + out("\n") + +- out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n"); ++ out("__device__ ncclDevFuncPtr_t ncclDevFuncTable[] = {\n"); + index = 0 + for fn in primary_funcs: + sym = paste("_", "ncclDevFunc", *fn) +@@ -257,28 +257,45 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: + + # List of all kernel function pointers. + out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs)) +- out("extern void* const ncclDevKernelList[] = {\n") ++ + index = 0 + for kfn in kernel_funcs: + cudart, _ = required_cuda(*kfn) + sym = paste("_", "ncclDevKernel", *kfn) + if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart) +- out("/*%4d*/ (void*)%s,\n" % (index, sym)); +- if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index) ++ out("/*%4d*/ void* %s_ptr = (void*)%s;\n" % (index, sym, sym)); ++ if cudart != 0: ++ out("#else\n/*%4d*/ void* %s_ptr = nullptr;\n#endif\n" % (index, sym)); ++ index += 1 ++ ++ out("extern void* const ncclDevKernelList[] = {\n") ++ index = 0 ++ for kfn in kernel_funcs: ++ sym = paste("_", "ncclDevKernel", *kfn) ++ out("/*%4d*/ %s_ptr,\n" % (index, sym)); + index += 1 + out("nullptr};\n") + out("\n") + + # Maps primary id to kernel function pointer. +- out("extern void* const ncclDevKernelForFunc[] = {\n") ++ + index = 0 + for fn in primary_funcs: + kfn = best_kernel(*fn) + sym = paste("_", "ncclDevKernel", *kfn) + cudart, _ = required_cuda(*kfn) + if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart) +- out("/*%4d*/ (void*)%s,\n" % (index, sym)) +- if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index) ++ out("/*%4d*/ void* %s_ptr_%d = (void*)%s;\n" % (index, sym, index, sym)) ++ if cudart != 0: ++ out("#else\n" "/*%4d*/ void* %s_ptr_%d = nullptr;\n" "#endif\n" % (index, sym, index)) ++ index += 1 ++ ++ out("extern void* const ncclDevKernelForFunc[] = {\n") ++ index = 0 ++ for fn in primary_funcs: ++ kfn = best_kernel(*fn) ++ sym = paste("_", "ncclDevKernel", *kfn) ++ out("/*%4d*/ %s_ptr_%d,\n" % (index, sym, index)) + index += 1 + out("nullptr};\n") + out("\n") +@@ -297,7 +314,7 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: + # "coll" is reflected in the name: formally that no two funcs having different + # coll's map to the same filename. + def impl_filename(coll, redop, ty, algo, proto): +- return "%s.cu" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty) ++ return "%s.cu.cc" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty) + + # Partition the functions and kernels to the .cu filenames. The partition is + # a dictionary mapping filename to (coll, func-tuple list) +@@ -318,7 +335,7 @@ name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Gene + with open(os.path.join(gensrc, "rules.mk"), "w") as f: + out = f.write + impl_names = sorted(name_to_funcs.keys()) +- names = impl_names + ["host_table.cc", "device_table.cu"] ++ names = impl_names + ["host_table.cc", "device_table.cu.cc"] + out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n" + .format(names=" ".join(names))) + out("\n") +diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc similarity index 100% -rename from src/collectives/device/all_reduce.cu -rename to src/collectives/device/all_reduce.cu.cc -diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu.cc -similarity index 100% -rename from src/collectives/device/broadcast.cu -rename to src/collectives/device/broadcast.cu.cc -diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu.cc -similarity index 100% -rename from src/collectives/device/functions.cu -rename to src/collectives/device/functions.cu.cc -diff --git a/src/collectives/device/onerank_reduce.cu b/src/collectives/device/onerank_reduce.cu.cc -similarity index 100% -rename from src/collectives/device/onerank_reduce.cu -rename to src/collectives/device/onerank_reduce.cu.cc -diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu.cc -similarity index 100% -rename from src/collectives/device/reduce.cu -rename to src/collectives/device/reduce.cu.cc -diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu.cc -similarity index 100% -rename from src/collectives/device/reduce_scatter.cu -rename to src/collectives/device/reduce_scatter.cu.cc -diff --git a/src/collectives/device/sendrecv.cu b/src/collectives/device/sendrecv.cu.cc -similarity index 100% -rename from src/collectives/device/sendrecv.cu -rename to src/collectives/device/sendrecv.cu.cc -diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h -index accf8371a..4ab1bfac6 100644 ---- a/src/collectives/device/common.h -+++ b/src/collectives/device/common.h -@@ -166,7 +166,8 @@ __device__ void ncclKernel( - bytes = 0; - break; - } -- copyToShmem16(tid%WARP_SIZE, dst, src, bytes); -+ if (bytes) -+ copyToShmem16(tid%WARP_SIZE, dst, src, bytes); - } - __syncthreads(); // publish ncclShmem - \ No newline at end of file +rename from src/device/onerank.cu +rename to src/device/onerank.cu.cc diff --git a/third_party/tsl/third_party/nccl/build_defs.bzl.tpl b/third_party/tsl/third_party/nccl/build_defs.bzl.tpl index 1e6dc29462705..2a450f8a1d55a 100644 --- a/third_party/tsl/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/nccl/build_defs.bzl.tpl @@ -7,38 +7,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") _cuda_version = %{cuda_version} _cuda_clang = %{cuda_clang} -def _gen_device_srcs_impl(ctx): - ops = ["sum", "prod", "min", "max", "premulsum", "sumpostdiv"] - # TF uses CUDA version > 11.0, so enable bf16 type unconditionally. - types = ["i8", "u8", "i32", "u32", "i64", "u64", "f16", "bf16", "f32", "f64"] - hdr_tail = "****************************************/" - defines = "\n\n#define NCCL_OP %d\n#define NCCL_TYPE %d" - - files = [] - for NCCL_OP, op in enumerate(ops): - for NCCL_TYPE, dt in enumerate(types): - substitutions = { - hdr_tail: hdr_tail + defines % (NCCL_OP, NCCL_TYPE), - } - for src in ctx.files.srcs: - name = "%s_%s_%s" % (op, dt, src.basename) - file = ctx.actions.declare_file(name, sibling = src) - ctx.actions.expand_template( - output = file, - template = src, - substitutions = substitutions, - ) - files.append(file) - return [DefaultInfo(files = depset(files))] - -gen_device_srcs = rule( - implementation = _gen_device_srcs_impl, - attrs = { - "srcs": attr.label_list(allow_files = True), - }, -) -"""Adds prefix to each file name in srcs and adds #define NCCL_OP.""" - def _rdc_copts(): """Returns copts for compiling relocatable device code.""" diff --git a/third_party/tsl/third_party/nccl/generated_names.bzl.tpl b/third_party/tsl/third_party/nccl/generated_names.bzl.tpl new file mode 100644 index 0000000000000..dcb5ad9232786 --- /dev/null +++ b/third_party/tsl/third_party/nccl/generated_names.bzl.tpl @@ -0,0 +1,117 @@ +"List of sources generated by :generate_nccl_kernels" +GENERATED_SOURCES = [ + "src/device/generated/all_gather.cu.cc", + "src/device/generated/all_reduce.cu.cc", + "src/device/generated/all_reduce_minmax_bf16.cu.cc", + "src/device/generated/all_reduce_minmax_f16.cu.cc", + "src/device/generated/all_reduce_minmax_f32.cu.cc", + "src/device/generated/all_reduce_minmax_f64.cu.cc", + "src/device/generated/all_reduce_minmax_i32.cu.cc", + "src/device/generated/all_reduce_minmax_i64.cu.cc", + "src/device/generated/all_reduce_minmax_u32.cu.cc", + "src/device/generated/all_reduce_minmax_u64.cu.cc", + "src/device/generated/all_reduce_minmax_u8.cu.cc", + "src/device/generated/all_reduce_premulsum_bf16.cu.cc", + "src/device/generated/all_reduce_premulsum_f16.cu.cc", + "src/device/generated/all_reduce_premulsum_f32.cu.cc", + "src/device/generated/all_reduce_premulsum_f64.cu.cc", + "src/device/generated/all_reduce_premulsum_u32.cu.cc", + "src/device/generated/all_reduce_premulsum_u64.cu.cc", + "src/device/generated/all_reduce_premulsum_u8.cu.cc", + "src/device/generated/all_reduce_prod_bf16.cu.cc", + "src/device/generated/all_reduce_prod_f16.cu.cc", + "src/device/generated/all_reduce_prod_f32.cu.cc", + "src/device/generated/all_reduce_prod_f64.cu.cc", + "src/device/generated/all_reduce_prod_u32.cu.cc", + "src/device/generated/all_reduce_prod_u64.cu.cc", + "src/device/generated/all_reduce_prod_u8.cu.cc", + "src/device/generated/all_reduce_sum_bf16.cu.cc", + "src/device/generated/all_reduce_sum_f16.cu.cc", + "src/device/generated/all_reduce_sum_f32.cu.cc", + "src/device/generated/all_reduce_sum_f64.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_i32.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_i64.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_i8.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_u32.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_u64.cu.cc", + "src/device/generated/all_reduce_sumpostdiv_u8.cu.cc", + "src/device/generated/all_reduce_sum_u32.cu.cc", + "src/device/generated/all_reduce_sum_u64.cu.cc", + "src/device/generated/all_reduce_sum_u8.cu.cc", + "src/device/generated/broadcast.cu.cc", + "src/device/generated/device_table.cu.cc", + "src/device/generated/host_table.cc", + "src/device/generated/reduce.cu.cc", + "src/device/generated/reduce_minmax_bf16.cu.cc", + "src/device/generated/reduce_minmax_f16.cu.cc", + "src/device/generated/reduce_minmax_f32.cu.cc", + "src/device/generated/reduce_minmax_f64.cu.cc", + "src/device/generated/reduce_minmax_u32.cu.cc", + "src/device/generated/reduce_minmax_u64.cu.cc", + "src/device/generated/reduce_minmax_u8.cu.cc", + "src/device/generated/reduce_premulsum_bf16.cu.cc", + "src/device/generated/reduce_premulsum_f16.cu.cc", + "src/device/generated/reduce_premulsum_f32.cu.cc", + "src/device/generated/reduce_premulsum_f64.cu.cc", + "src/device/generated/reduce_premulsum_u32.cu.cc", + "src/device/generated/reduce_premulsum_u64.cu.cc", + "src/device/generated/reduce_premulsum_u8.cu.cc", + "src/device/generated/reduce_prod_bf16.cu.cc", + "src/device/generated/reduce_prod_f16.cu.cc", + "src/device/generated/reduce_prod_f32.cu.cc", + "src/device/generated/reduce_prod_f64.cu.cc", + "src/device/generated/reduce_prod_u32.cu.cc", + "src/device/generated/reduce_prod_u64.cu.cc", + "src/device/generated/reduce_prod_u8.cu.cc", + "src/device/generated/reduce_scatter.cu.cc", + "src/device/generated/reduce_scatter_minmax_bf16.cu.cc", + "src/device/generated/reduce_scatter_minmax_f16.cu.cc", + "src/device/generated/reduce_scatter_minmax_f32.cu.cc", + "src/device/generated/reduce_scatter_minmax_f64.cu.cc", + "src/device/generated/reduce_scatter_minmax_i32.cu.cc", + "src/device/generated/reduce_scatter_minmax_i64.cu.cc", + "src/device/generated/reduce_scatter_minmax_u32.cu.cc", + "src/device/generated/reduce_scatter_minmax_u64.cu.cc", + "src/device/generated/reduce_scatter_minmax_u8.cu.cc", + "src/device/generated/reduce_scatter_premulsum_bf16.cu.cc", + "src/device/generated/reduce_scatter_premulsum_f16.cu.cc", + "src/device/generated/reduce_scatter_premulsum_f32.cu.cc", + "src/device/generated/reduce_scatter_premulsum_f64.cu.cc", + "src/device/generated/reduce_scatter_premulsum_u32.cu.cc", + "src/device/generated/reduce_scatter_premulsum_u64.cu.cc", + "src/device/generated/reduce_scatter_premulsum_u8.cu.cc", + "src/device/generated/reduce_scatter_prod_bf16.cu.cc", + "src/device/generated/reduce_scatter_prod_f16.cu.cc", + "src/device/generated/reduce_scatter_prod_f32.cu.cc", + "src/device/generated/reduce_scatter_prod_f64.cu.cc", + "src/device/generated/reduce_scatter_prod_u32.cu.cc", + "src/device/generated/reduce_scatter_prod_u64.cu.cc", + "src/device/generated/reduce_scatter_prod_u8.cu.cc", + "src/device/generated/reduce_scatter_sum_bf16.cu.cc", + "src/device/generated/reduce_scatter_sum_f16.cu.cc", + "src/device/generated/reduce_scatter_sum_f32.cu.cc", + "src/device/generated/reduce_scatter_sum_f64.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_i32.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_i64.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_i8.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_u32.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_u64.cu.cc", + "src/device/generated/reduce_scatter_sumpostdiv_u8.cu.cc", + "src/device/generated/reduce_scatter_sum_u32.cu.cc", + "src/device/generated/reduce_scatter_sum_u64.cu.cc", + "src/device/generated/reduce_scatter_sum_u8.cu.cc", + "src/device/generated/reduce_sum_bf16.cu.cc", + "src/device/generated/reduce_sum_f16.cu.cc", + "src/device/generated/reduce_sum_f32.cu.cc", + "src/device/generated/reduce_sum_f64.cu.cc", + "src/device/generated/reduce_sumpostdiv_i32.cu.cc", + "src/device/generated/reduce_sumpostdiv_i64.cu.cc", + "src/device/generated/reduce_sumpostdiv_i8.cu.cc", + "src/device/generated/reduce_sumpostdiv_u32.cu.cc", + "src/device/generated/reduce_sumpostdiv_u64.cu.cc", + "src/device/generated/reduce_sumpostdiv_u8.cu.cc", + "src/device/generated/reduce_sum_u32.cu.cc", + "src/device/generated/reduce_sum_u64.cu.cc", + "src/device/generated/reduce_sum_u8.cu.cc", + "src/device/generated/sendrecv.cu.cc", +] diff --git a/third_party/tsl/third_party/nccl/nccl_configure.bzl b/third_party/tsl/third_party/nccl/nccl_configure.bzl index 2dc862f273a1a..a62c29caf27a4 100644 --- a/third_party/tsl/third_party/nccl/nccl_configure.bzl +++ b/third_party/tsl/third_party/nccl/nccl_configure.bzl @@ -123,6 +123,7 @@ def _create_local_nccl_repository(repository_ctx): else: repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + repository_ctx.template("generated_names.bzl", _label("generated_names.bzl.tpl"), {}) repository_ctx.template( "build_defs.bzl", _label("build_defs.bzl.tpl"), @@ -140,6 +141,7 @@ def _create_local_nccl_repository(repository_ctx): "%{nccl_library_dir}": config["nccl_library_dir"], } repository_ctx.template("BUILD", _label("system.BUILD.tpl"), config_wrap) + repository_ctx.template("generated_names.bzl", _label("generated_names.bzl.tpl"), {}) def _create_remote_nccl_repository(repository_ctx, remote_config_repo): repository_ctx.template( @@ -147,9 +149,13 @@ def _create_remote_nccl_repository(repository_ctx, remote_config_repo): config_repo_label(remote_config_repo, ":BUILD"), {}, ) - nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "") if nccl_version == "": + repository_ctx.template( + "generated_names.bzl", + config_repo_label(remote_config_repo, ":generated_names.bzl"), + {}, + ) repository_ctx.template( "build_defs.bzl", config_repo_label(remote_config_repo, ":build_defs.bzl"), diff --git a/third_party/tsl/third_party/nvtx/BUILD b/third_party/tsl/third_party/nvtx/BUILD new file mode 100644 index 0000000000000..af6de99cb8fcf --- /dev/null +++ b/third_party/tsl/third_party/nvtx/BUILD @@ -0,0 +1,13 @@ +# NVIDIA NVTX 3 + +licenses(["notice"]) + +exports_files(["LICENSE.txt"]) + +cc_library( + name = "headers", + hdrs = glob(["**"]), + include_prefix = "nvtx3", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/tsl/third_party/nvtx/LICENSE b/third_party/tsl/third_party/nvtx/LICENSE new file mode 100644 index 0000000000000..aa23e563e6f86 --- /dev/null +++ b/third_party/tsl/third_party/nvtx/LICENSE @@ -0,0 +1,5 @@ +Copyright 2009-2022 NVIDIA Corporation. All rights reserved. + +Licensed under the Apache License v2.0 with LLVM Exceptions. +See https://llvm.org/LICENSE.txt for license information. +SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/third_party/tsl/third_party/png.BUILD b/third_party/tsl/third_party/png.BUILD deleted file mode 100644 index 383c6e0ef4df6..0000000000000 --- a/third_party/tsl/third_party/png.BUILD +++ /dev/null @@ -1,70 +0,0 @@ -# Description: -# libpng is the official PNG reference library. - -licenses(["notice"]) # BSD/MIT-like license - -exports_files(["LICENSE"]) - -cc_library( - name = "png", - srcs = [ - "png.c", - "pngdebug.h", - "pngerror.c", - "pngget.c", - "pnginfo.h", - "pnglibconf.h", - "pngmem.c", - "pngpread.c", - "pngpriv.h", - "pngread.c", - "pngrio.c", - "pngrtran.c", - "pngrutil.c", - "pngset.c", - "pngstruct.h", - "pngtrans.c", - "pngwio.c", - "pngwrite.c", - "pngwtran.c", - "pngwutil.c", - ] + select({ - ":windows": [ - "intel/filter_sse2_intrinsics.c", - "intel/intel_init.c", - ], - "@tsl//tsl:linux_ppc64le": [ - #"powerpc/filter_vsx_intrinsics.c", - #"powerpc/powerpc_init.c", - ], - "//conditions:default": [ - ], - }), - hdrs = [ - "png.h", - "pngconf.h", - ], - copts = select({ - ":windows": ["-DPNG_INTEL_SSE_OPT=1"], - "//conditions:default": [], - }), - includes = ["."], - linkopts = select({ - ":windows": [], - "//conditions:default": ["-lm"], - }), - visibility = ["//visibility:public"], - deps = ["@zlib"], -) - -genrule( - name = "snappy_stubs_public_h", - srcs = ["scripts/pnglibconf.h.prebuilt"], - outs = ["pnglibconf.h"], - cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12d0/' $< >$@", -) - -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, -) diff --git a/third_party/tsl/third_party/png_fix_rpi.patch b/third_party/tsl/third_party/png_fix_rpi.patch deleted file mode 100644 index df6cfd7ffaee5..0000000000000 --- a/third_party/tsl/third_party/png_fix_rpi.patch +++ /dev/null @@ -1,16 +0,0 @@ -diff -r -u ./scripts/pnglibconf.h.prebuilt ./scripts/pnglibconf.h.prebuilt ---- ./scripts/pnglibconf.h.prebuilt -+++ ./scripts/pnglibconf.h.prebuilt -@@ -19,6 +19,12 @@ - #define PNG_ALIGNED_MEMORY_SUPPORTED - /*#undef PNG_ARM_NEON_API_SUPPORTED*/ - /*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/ -+ -+/* Workaround not having a great build file by forcing -+ * png filter optimization to be disabled on arm */ -+#define PNG_ARM_NEON_OPT 0 -+ -+ - #define PNG_BENIGN_ERRORS_SUPPORTED - #define PNG_BENIGN_READ_ERRORS_SUPPORTED - /*#undef PNG_BENIGN_WRITE_ERRORS_SUPPORTED*/ diff --git a/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl index 3c3f20c06ec5e..c659ca16366b7 100644 --- a/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl +++ b/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl @@ -172,8 +172,6 @@ _generate_cc = rule( cfg = "exec", ), }, - # We generate .h files, so we need to output to genfiles. - output_to_genfiles = True, implementation = generate_cc_impl, ) diff --git a/third_party/tsl/third_party/systemlibs/protobuf.BUILD b/third_party/tsl/third_party/systemlibs/protobuf.BUILD index 4d05ab28d12e9..c7d940605f9f7 100644 --- a/third_party/tsl/third_party/systemlibs/protobuf.BUILD +++ b/third_party/tsl/third_party/systemlibs/protobuf.BUILD @@ -1,10 +1,10 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") load( "@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "proto_gen", "py_proto_library", ) +load("@rules_proto//proto:defs.bzl", "proto_library") licenses(["notice"]) diff --git a/third_party/tsl/third_party/systemlibs/protobuf.bzl b/third_party/tsl/third_party/systemlibs/protobuf.bzl index 66a06300e4f9e..3813d04954e22 100644 --- a/third_party/tsl/third_party/systemlibs/protobuf.bzl +++ b/third_party/tsl/third_party/systemlibs/protobuf.bzl @@ -150,7 +150,6 @@ proto_gen = rule( "gen_py": attr.bool(), "outs": attr.output_list(), }, - output_to_genfiles = True, implementation = _proto_gen_impl, ) """Generates codes from Protocol Buffers definitions. diff --git a/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/tsl/third_party/tf_runtime/workspace.bzl index 7a59477c03eb4..2b789c8cead31 100644 --- a/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "8f915f25e8b17d2509bb6c7f199a45f2a5e6736c" - TFRT_SHA256 = "6d0cc4221d9bb6739bf16a03da482abc348f6143395726595d89e3f12158a0ea" + TFRT_COMMIT = "968eb3e5b0aa2e20301a41af9bb14a48dd1aee40" + TFRT_SHA256 = "cd3b1d190625d6ca5ddcf1c9cc0b095928707b623f1b986f1ba333a89d5418ae" tf_http_archive( name = "tf_runtime", diff --git a/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl b/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl index 4091a5713f8ee..f77217121d668 100644 --- a/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl +++ b/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl @@ -310,6 +310,9 @@ def main(): def_fp.write("\t ??1CoordinatedTask@tensorflow@@UEAA@XZ\n") # for _pywrap_tfe def_fp.write("\t ?CopyFrom@CoordinatedTask@tensorflow@@QEAAXAEBV12@@Z\n") # for _pywrap_tfe def_fp.write("\t ??0CoordinatedTask@tensorflow@@IEAA@PEAVArena@protobuf@google@@_N@Z\n") # for _pywrap_tfe + def_fp.write("\t ??0LogMessageFatal@log_internal@lts_20230802@absl@@QEAA@PEBDH@Z\n") # for _pywrap_tfe + def_fp.write("\t ??1LogMessageFatal@log_internal@lts_20230802@absl@@QEAA@XZ\n") # for _pywrap_tfe + def_fp.write("\t ??$CopyToEncodedBuffer@$0A@@LogMessage@log_internal@lts_20230802@absl@@AEAAXV?$basic_string_view@DU?$char_traits@D@std@@@std@@@Z\n") # for _pywrap_tfe def_fp.write("\t ?MaybeTrackCordImpl@CordzInfo@cord_internal@lts_20230802@absl@@CAXAEAVInlineData@234@AEBV5234@W4MethodIdentifier@CordzUpdateTracker@234@@Z\n") # for tensorflow::Status usage of absl::Cord diff --git a/third_party/tsl/tools/def_file_filter/def_file_filter_configure.bzl b/third_party/tsl/tools/def_file_filter/def_file_filter_configure.bzl index 3c7543cba86f9..a648d3d8d646a 100644 --- a/third_party/tsl/tools/def_file_filter/def_file_filter_configure.bzl +++ b/third_party/tsl/tools/def_file_filter/def_file_filter_configure.bzl @@ -19,9 +19,8 @@ symbols through this python script. * `VS140COMNTOOLS` """ -load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_vc_path") -load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool") load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail") +load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool", "find_vc_path") def _def_file_filter_configure_impl(repository_ctx): if repository_ctx.os.name.lower().find("windows") == -1: diff --git a/third_party/tsl/tools/def_file_filter/symbols_pybind.txt b/third_party/tsl/tools/def_file_filter/symbols_pybind.txt index 0f52e9fa804da..78c42c6f454c5 100644 --- a/third_party/tsl/tools/def_file_filter/symbols_pybind.txt +++ b/third_party/tsl/tools/def_file_filter/symbols_pybind.txt @@ -31,6 +31,11 @@ tensorflow::FlattenDictItems tensorflow::IsGoogleCudaEnabled tensorflow::IsBuiltWithROCm tensorflow::IsBuiltWithNvcc +tensorflow::IsAArch32Available +tensorflow::IsAArch64Available +tensorflow::IsPowerPCAvailable +tensorflow::IsSystemZAvailable +tensorflow::IsX86Available tensorflow::GpuSupportsHalfMatMulAndConv tensorflow::IsMklEnabled @@ -54,7 +59,7 @@ tensorflow::tfprof::SerializeToString [//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool] # graph_analyze tensorflow::grappler::graph_analyzer::GraphAnalyzerTool -[//external/local_tsl/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 +[//external/local_xla/xla/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 tsl::ml_dtypes::RegisterTypes tsl::ml_dtypes::GetBfloat16Dtype tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype @@ -302,7 +307,7 @@ tensorflow::AddWhileInputHack tensorflow::RecordMutation tensorflow::Graph::IsControlEdge -[//external/local_tsl/tsl/python/lib/core:numpy] # tf_session +[//external/local_xla/xla/tsl/python/lib/core:numpy] # tf_session tsl::ImportNumpy _tsl_numpy_api @@ -468,6 +473,9 @@ tensorflow::metrics::CheckpointWriteDuration tensorflow::metrics::AsyncCheckpointWriteDuration tensorflow::metrics::TrainingTimeSaved tensorflow::metrics::CheckpointSize +tensorflow::metrics::ShardingCallbackDuration +tensorflow::metrics::NumCheckpointShardsWritten +tensorflow::metrics::ShardingCallbackDescription [//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting tensorflow::saved_model::fingerprinting::CreateFingerprintDef @@ -554,7 +562,8 @@ tensorflow::Safe_PyObjectPtr tensorflow::quantization::QuantizeQatModel tensorflow::quantization::QuantizePtqModelPreCalibration tensorflow::quantization::QuantizePtqModelPostCalibration -tensorflow::quantization::QuantizePtqDynamicRange +tensorflow::quantization::QuantizeStaticRangePtq +tensorflow::quantization::QuantizeDynamicRangePtq tensorflow::quantization::QuantizeWeightOnly [//tensorflow/dtensor/cc:dtensor_device_cc] # DTensor diff --git a/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index 5f2057a3c1436..4455aea60109f 100644 --- a/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/third_party/tsl/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -1,8 +1,8 @@ """Configurations of AARCH64 builds used with Docker container.""" -load("//tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") -load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/py:python_configure.bzl", "remote_python_configure") +load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") +load("//tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") def ml2014_tf_aarch64_configs(name_container_map, env): for name, container in name_container_map.items(): diff --git a/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD b/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD index 7db2527259d02..7cf6d8c3747b2 100644 --- a/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD +++ b/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD @@ -1,6 +1,6 @@ """Toolchain configs for cross-compiling TensorFlow""" -load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") +load(":cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) @@ -11,11 +11,21 @@ cc_toolchain_suite( toolchains = { "aarch64": ":linux_aarch64_toolchain", "k8": ":linux_x86_toolchain", + "darwin": ":macos_x86_toolchain", }, ) filegroup(name = "empty") +# We define a wraper ("cc_wrapper.sh") around the compiler to replace all paths +# in the binary (bazel-out/.../path/to/original/library.so) by the paths +# relative to the binary. Without it, we run into "Library not loaded" error +# when trying run cross-compiled tests, see b/300002682. +filegroup( + name = "cc_wrapper_and_macos_sysroot", + srcs = ["cc_wrapper.sh"] + glob(["MacOSX.sdk/**"]), +) + cc_toolchain( name = "linux_x86_toolchain", all_files = ":empty", @@ -186,3 +196,100 @@ cc_toolchain_config( "-Wno-gnu-offsetof-extensions", ], ) + +cc_toolchain( + name = "macos_x86_toolchain", + all_files = ":cc_wrapper_and_macos_sysroot", + compiler_files = ":cc_wrapper_and_macos_sysroot", + dwp_files = ":empty", + linker_files = ":cc_wrapper_and_macos_sysroot", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":macos_x86_toolchain_config", + toolchain_identifier = "macos_x86_toolchain", +) + +cc_toolchain_config( + name = "macos_x86_toolchain_config", + abi_libc_version = "darwin_x86_64", + abi_version = "darwin_x86_64", + builtin_sysroot = "tensorflow/tools/toolchains/cross_compile/cc/MacOSX.sdk", + compile_flags = [ + "--target=x86_64-apple-darwin", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-DOS_MACOSX", + "-DGRPC_BAZEL_BUILD", + "-stdlib=libc++", + "-mavx", + # Target Catalina as the minimum supported OS + "-mmacos-version-min=10.15", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "darwin", + cxx_builtin_include_directories = [ + "%sysroot%/usr/include", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + "%sysroot%/System/Library/Frameworks/Security.framework/Headers", + "%sysroot%/System/Library/Frameworks/CoreFoundation.framework/Headers", + "%sysroot%/System/Library/Frameworks/SystemConfiguration.framework/Headers", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-apple-darwin", + "-lSystem", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld64.lld", + "-headerpad_max_install_names", + "-Wl,-undefined,dynamic_lookup", + # Target Catalina as the minimum supported OS + "-Wl,-platform_version,macos,10.15.0,10.15", + ], + link_libs = [ + "-lc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,-dead_strip"], + supports_start_end_lib = True, + target_libc = "macosx", + target_system_name = "x86_64-apple-macosx10.15", + tool_paths = { + "gcc": "cc_wrapper.sh", + "ld": "/usr/lib/llvm-17/bin/ld64.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-libtool-darwin", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "macos_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/third_party/tsl/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl b/third_party/tsl/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl new file mode 100644 index 0000000000000..de638c0159b5f --- /dev/null +++ b/third_party/tsl/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl @@ -0,0 +1,1444 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# The toolchain configuration rule file, "cc_toolchain_config.bzl", is a clone of https://github.com/bazelbuild/bazel/blob/6.1.0/tools/cpp/unix_cc_toolchain_config.bzl +# except for a few changes. We remove "-s" as a default archiver option because it is not supported +# by llvm-libtool-darwin, see https://github.com/bazelbuild/bazel/pull/17489. +# We remove "supports_dynamic_linker_feature" from the macOS toolchain config because otherwise +# certain TensorFlow tests get stuck at linking phase. See https://github.com/bazelbuild/bazel/issues/4341 +# which has more details on why this option does not work on macOS. +# +# Note to TF developers: Please make sure this file stays in sync with the Bazel version used +# by TensorFlow. If TensorFlow's Bazel version is updated, replace this file with the contents of +# `tools/cpp/unix_cc_toolchain_config.bzl` from the corresponding Bazel version tag in https://github.com/bazelbuild/bazel. +# Please make sure to add in the additional changes listed above if needed. Contact srnitin@ for +# any questions. +"""A Starlark cc_toolchain configuration rule""" + +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "feature", + "feature_set", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) + +def layering_check_features(compiler): + if compiler != "clang": + return [] + return [ + feature( + name = "use_module_maps", + requires = [feature_set(features = ["module_maps"])], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = [ + "-fmodule-name=%{module_name}", + "-fmodule-map-file=%{module_map_file}", + ], + ), + ], + ), + ], + ), + + # Note: not all C++ rules support module maps; thus, do not imply this + # feature from other features - instead, require it. + feature(name = "module_maps", enabled = True), + feature( + name = "layering_check", + implies = ["use_module_maps"], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group(flags = [ + "-fmodules-strict-decluse", + "-Wprivate-header", + ]), + flag_group( + iterate_over = "dependent_module_map_files", + flags = [ + "-fmodule-map-file=%{dependent_module_map_files}", + ], + ), + ], + ), + ], + ), + ] + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +lto_index_actions = [ + ACTION_NAMES.lto_index_for_executable, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, +] + +def _sanitizer_feature(name = "", specific_compile_flags = [], specific_link_flags = []): + return feature( + name = name, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group(flags = [ + "-fno-omit-frame-pointer", + "-fno-sanitize-recover=all", + ] + specific_compile_flags), + ], + with_features = [ + with_feature_set(features = [name]), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group(flags = specific_link_flags), + ], + with_features = [ + with_feature_set(features = [name]), + ], + ), + ], + ) + +def _impl(ctx): + tool_paths = [ + tool_path(name = name, path = path) + for name, path in ctx.attr.tool_paths.items() + ] + action_configs = [] + + llvm_cov_action = action_config( + action_name = ACTION_NAMES.llvm_cov, + tools = [ + tool( + path = ctx.attr.tool_paths["llvm-cov"], + ), + ], + ) + + action_configs.append(llvm_cov_action) + + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + supports_start_end_lib_feature = feature( + name = "supports_start_end_lib", + enabled = True, + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group( + # Security hardening requires optimization. + # We need to undef it as some distributions now have it enabled by default. + flags = ["-U_FORTIFY_SOURCE"], + ), + ], + with_features = [ + with_feature_set( + not_features = ["thin_lto"], + ), + ], + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.compile_flags, + ), + ] if ctx.attr.compile_flags else []), + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.dbg_compile_flags, + ), + ] if ctx.attr.dbg_compile_flags else []), + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.opt_compile_flags, + ), + ] if ctx.attr.opt_compile_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_cpp_compile_actions + [ACTION_NAMES.lto_backend], + flag_groups = ([ + flag_group( + flags = ctx.attr.cxx_flags, + ), + ] if ctx.attr.cxx_flags else []), + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.link_flags, + ), + ] if ctx.attr.link_flags else []), + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.opt_link_flags, + ), + ] if ctx.attr.opt_link_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + dbg_feature = feature(name = "dbg") + + opt_feature = feature(name = "opt") + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.unfiltered_compile_flags, + ), + ] if ctx.attr.unfiltered_compile_flags else []), + ), + ], + ) + + library_search_directories_feature = feature( + name = "library_search_directories", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-L%{library_search_directories}"], + iterate_over = "library_search_directories", + expand_if_available = "library_search_directories", + ), + ], + ), + ], + ) + + static_libgcc_feature = feature( + name = "static_libgcc", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.lto_index_for_executable, + ACTION_NAMES.lto_index_for_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-static-libgcc"])], + with_features = [ + with_feature_set(features = ["static_link_cpp_runtimes"]), + ], + ), + ], + ) + + pic_feature = feature( + name = "pic", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group(flags = ["-fPIC"], expand_if_available = "pic"), + ], + ), + ], + ) + + per_object_debug_info_feature = feature( + name = "per_object_debug_info", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["-gsplit-dwarf", "-g"], + expand_if_available = "per_object_debug_info_file", + ), + ], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["-D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + cs_fdo_optimize_feature = feature( + name = "cs_fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.lto_backend], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-Wno-profile-instr-unprofiled", + "-Wno-profile-instr-out-of-date", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["csprofile"], + ) + + autofdo_feature = feature( + name = "autofdo", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fauto-profile=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + runtime_library_search_directories_feature = feature( + name = "runtime_library_search_directories", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "runtime_library_search_directories", + flag_groups = [ + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$EXEC_ORIGIN/%{runtime_library_search_directories}", + ], + expand_if_true = "is_cc_test", + ), + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$ORIGIN/%{runtime_library_search_directories}", + ], + expand_if_false = "is_cc_test", + ), + ], + expand_if_available = + "runtime_library_search_directories", + ), + ], + with_features = [ + with_feature_set(features = ["static_link_cpp_runtimes"]), + ], + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "runtime_library_search_directories", + flag_groups = [ + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$ORIGIN/%{runtime_library_search_directories}", + ], + ), + ], + expand_if_available = + "runtime_library_search_directories", + ), + ], + with_features = [ + with_feature_set( + not_features = ["static_link_cpp_runtimes"], + ), + ], + ), + ], + ) + + fission_support_feature = feature( + name = "fission_support", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,--gdb-index"], + expand_if_available = "is_using_fission", + ), + ], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-shared"])], + ), + ], + ) + + random_seed_feature = feature( + name = "random_seed", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["-frandom-seed=%{output_file}"], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + includes_feature = feature( + name = "includes", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-include", "%{includes}"], + iterate_over = "includes", + expand_if_available = "includes", + ), + ], + ), + ], + ) + + fdo_instrument_feature = feature( + name = "fdo_instrument", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-fprofile-generate=%{fdo_instrument_path}", + "-fno-data-sections", + ], + expand_if_available = "fdo_instrument_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + cs_fdo_instrument_feature = feature( + name = "cs_fdo_instrument", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.lto_backend, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-fcs-profile-generate=%{cs_fdo_instrument_path}", + ], + expand_if_available = "cs_fdo_instrument_path", + ), + ], + ), + ], + provides = ["csprofile"], + ) + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-iquote", "%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["-I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["-isystem", "%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + external_include_paths_feature = feature( + name = "external_include_paths", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-isystem", "%{external_include_paths}"], + iterate_over = "external_include_paths", + expand_if_available = "external_include_paths", + ), + ], + ), + ], + ) + + symbol_counts_feature = feature( + name = "symbol_counts", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-Wl,--print-symbol-counts=%{symbol_counts_output}", + ], + expand_if_available = "symbol_counts_output", + ), + ], + ), + ], + ) + + strip_debug_symbols_feature = feature( + name = "strip_debug_symbols", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,-S"], + expand_if_available = "strip_debug_symbols", + ), + ], + ), + ], + ) + + build_interface_libraries_feature = feature( + name = "build_interface_libraries", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [ + "%{generate_interface_library}", + "%{interface_library_builder_path}", + "%{interface_library_input_path}", + "%{interface_library_output_path}", + ], + expand_if_available = "generate_interface_library", + ), + ], + with_features = [ + with_feature_set( + features = ["supports_interface_shared_libraries"], + ), + ], + ), + ], + ) + + libraries_to_link_feature = feature( + name = "libraries_to_link", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["-Wl,--start-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flags = ["-Wl,-whole-archive"], + expand_if_true = + "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["%{libraries_to_link.object_files}"], + iterate_over = "libraries_to_link.object_files", + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + flag_group( + flags = ["-l%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "dynamic_library", + ), + ), + flag_group( + flags = ["-l:%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "versioned_dynamic_library", + ), + ), + flag_group( + flags = ["-Wl,-no-whole-archive"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["-Wl,--end-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + flag_group( + flags = ["-Wl,@%{thinlto_param_file}"], + expand_if_true = "thinlto_param_file", + ), + ], + ), + ], + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_link_libs_feature = feature( + name = "default_link_libs", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [flag_group(flags = ctx.attr.link_libs)] if ctx.attr.link_libs else [], + ), + ], + ) + + fdo_prefetch_hints_feature = feature( + name = "fdo_prefetch_hints", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.lto_backend, + ], + flag_groups = [ + flag_group( + flags = [ + "-mllvm", + "-prefetch-hints-file=%{fdo_prefetch_hints_path}", + ], + expand_if_available = "fdo_prefetch_hints_path", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group(flags = ["rcsD"]), + flag_group( + flags = ["%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + with_features = [ + with_feature_set( + not_features = ["libtool"], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group(flags = ["-static"]), + flag_group( + flags = ["-o", "%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + with_features = [ + with_feature_set( + features = ["libtool"], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flags = ["%{libraries_to_link.object_files}"], + iterate_over = "libraries_to_link.object_files", + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = ([ + flag_group( + flags = ctx.attr.archive_flags, + ), + ] if ctx.attr.archive_flags else []), + ), + ], + ) + + force_pic_flags_feature = feature( + name = "force_pic_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.lto_index_for_executable, + ], + flag_groups = [ + flag_group( + flags = ["-pie"], + expand_if_available = "force_pic", + ), + ], + ), + ], + ) + + dependency_file_feature = feature( + name = "dependency_file", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["-MD", "-MF", "%{dependency_file}"], + expand_if_available = "dependency_file", + ), + ], + ), + ], + ) + + serialized_diagnostics_file_feature = feature( + name = "serialized_diagnostics_file", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["--serialize-diagnostics", "%{serialized_diagnostics_file}"], + expand_if_available = "serialized_diagnostics_file", + ), + ], + ), + ], + ) + + dynamic_library_linker_tool_feature = feature( + name = "dynamic_library_linker_tool", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [" + cppLinkDynamicLibraryToolPath + "], + expand_if_available = "generate_interface_library", + ), + ], + with_features = [ + with_feature_set( + features = ["supports_interface_shared_libraries"], + ), + ], + ), + ], + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-o", "%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + # Note that we also set --coverage for c++-link-nodeps-dynamic-library. The + # generated code contains references to gcov symbols, and the dynamic linker + # can't resolve them unless the library is linked against gcov. + coverage_feature = feature( + name = "coverage", + provides = ["profile"], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = ([ + flag_group(flags = ctx.attr.coverage_compile_flags), + ] if ctx.attr.coverage_compile_flags else []), + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group(flags = ctx.attr.coverage_link_flags), + ] if ctx.attr.coverage_link_flags else []), + ), + ], + ) + + thinlto_feature = feature( + name = "thin_lto", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group(flags = ["-flto=thin"]), + flag_group( + expand_if_available = "lto_indexing_bitcode_file", + flags = [ + "-Xclang", + "-fthin-link-bitcode=%{lto_indexing_bitcode_file}", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.linkstamp_compile], + flag_groups = [flag_group(flags = ["-DBUILD_LTO_TYPE=thin"])], + ), + flag_set( + actions = lto_index_actions, + flag_groups = [ + flag_group(flags = [ + "-flto=thin", + "-Wl,-plugin-opt,thinlto-index-only%{thinlto_optional_params_file}", + "-Wl,-plugin-opt,thinlto-emit-imports-files", + "-Wl,-plugin-opt,thinlto-prefix-replace=%{thinlto_prefix_replace}", + ]), + flag_group( + expand_if_available = "thinlto_object_suffix_replace", + flags = [ + "-Wl,-plugin-opt,thinlto-object-suffix-replace=%{thinlto_object_suffix_replace}", + ], + ), + flag_group( + expand_if_available = "thinlto_merged_object_file", + flags = [ + "-Wl,-plugin-opt,obj-path=%{thinlto_merged_object_file}", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.lto_backend], + flag_groups = [ + flag_group(flags = [ + "-c", + "-fthinlto-index=%{thinlto_index}", + "-o", + "%{thinlto_output_object_file}", + "-x", + "ir", + "%{thinlto_input_bitcode_file}", + ]), + ], + ), + ], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-Werror"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,-fatal-warnings"])], + ), + ], + ) + + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + asan_feature = _sanitizer_feature( + name = "asan", + specific_compile_flags = [ + "-fsanitize=address", + "-fno-common", + ], + specific_link_flags = [ + "-fsanitize=address", + ], + ) + + tsan_feature = _sanitizer_feature( + name = "tsan", + specific_compile_flags = [ + "-fsanitize=thread", + ], + specific_link_flags = [ + "-fsanitize=thread", + ], + ) + + ubsan_feature = _sanitizer_feature( + name = "ubsan", + specific_compile_flags = [ + "-fsanitize=undefined", + ], + specific_link_flags = [ + "-fsanitize=undefined", + ], + ) + + is_linux = ctx.attr.target_libc != "macosx" + libtool_feature = feature( + name = "libtool", + enabled = not is_linux, + ) + + # TODO(#8303): Mac crosstool should also declare every feature. + if is_linux: + # Linux artifact name patterns are the default. + artifact_name_patterns = [] + features = [ + dependency_file_feature, + serialized_diagnostics_file_feature, + random_seed_feature, + pic_feature, + per_object_debug_info_feature, + preprocessor_defines_feature, + includes_feature, + include_paths_feature, + external_include_paths_feature, + fdo_instrument_feature, + cs_fdo_instrument_feature, + cs_fdo_optimize_feature, + thinlto_feature, + fdo_prefetch_hints_feature, + autofdo_feature, + build_interface_libraries_feature, + dynamic_library_linker_tool_feature, + symbol_counts_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + runtime_library_search_directories_feature, + library_search_directories_feature, + libtool_feature, + archiver_flags_feature, + force_pic_flags_feature, + fission_support_feature, + strip_debug_symbols_feature, + coverage_feature, + supports_pic_feature, + asan_feature, + tsan_feature, + ubsan_feature, + ] + ( + [ + supports_start_end_lib_feature, + ] if ctx.attr.supports_start_end_lib else [] + ) + [ + default_compile_flags_feature, + default_link_flags_feature, + libraries_to_link_feature, + user_link_flags_feature, + default_link_libs_feature, + static_libgcc_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + treat_warnings_as_errors_feature, + archive_param_file_feature, + ] + layering_check_features(ctx.attr.compiler) + else: + # macOS artifact name patterns differ from the defaults only for dynamic + # libraries. + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "lib", + extension = ".dylib", + ), + ] + features = [ + libtool_feature, + archiver_flags_feature, + supports_pic_feature, + asan_feature, + tsan_feature, + ubsan_feature, + ] + ( + [ + supports_start_end_lib_feature, + ] if ctx.attr.supports_start_end_lib else [] + ) + [ + coverage_feature, + default_compile_flags_feature, + default_link_flags_feature, + user_link_flags_feature, + default_link_libs_feature, + fdo_optimize_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + treat_warnings_as_errors_feature, + archive_param_file_feature, + ] + layering_check_features(ctx.attr.compiler) + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories, + toolchain_identifier = ctx.attr.toolchain_identifier, + host_system_name = ctx.attr.host_system_name, + target_system_name = ctx.attr.target_system_name, + target_cpu = ctx.attr.cpu, + target_libc = ctx.attr.target_libc, + compiler = ctx.attr.compiler, + abi_version = ctx.attr.abi_version, + abi_libc_version = ctx.attr.abi_libc_version, + tool_paths = tool_paths, + builtin_sysroot = ctx.attr.builtin_sysroot, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(mandatory = True), + "toolchain_identifier": attr.string(mandatory = True), + "host_system_name": attr.string(mandatory = True), + "target_system_name": attr.string(mandatory = True), + "target_libc": attr.string(mandatory = True), + "abi_version": attr.string(mandatory = True), + "abi_libc_version": attr.string(mandatory = True), + "cxx_builtin_include_directories": attr.string_list(), + "tool_paths": attr.string_dict(), + "compile_flags": attr.string_list(), + "dbg_compile_flags": attr.string_list(), + "opt_compile_flags": attr.string_list(), + "cxx_flags": attr.string_list(), + "link_flags": attr.string_list(), + "archive_flags": attr.string_list(), + "link_libs": attr.string_list(), + "opt_link_flags": attr.string_list(), + "unfiltered_compile_flags": attr.string_list(), + "coverage_compile_flags": attr.string_list(), + "coverage_link_flags": attr.string_list(), + "supports_start_end_lib": attr.bool(), + "builtin_sysroot": attr.string(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/tsl/tools/toolchains/cross_compile/cc/cc_wrapper.sh b/third_party/tsl/tools/toolchains/cross_compile/cc/cc_wrapper.sh new file mode 100644 index 0000000000000..a5106941d70ef --- /dev/null +++ b/third_party/tsl/tools/toolchains/cross_compile/cc/cc_wrapper.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# +# Copyright 2015 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# OS X relpath is not really working. This is a wrapper script around gcc +# to simulate relpath behavior. +# +# This wrapper uses install_name_tool to replace all paths in the binary +# (bazel-out/.../path/to/original/library.so) by the paths relative to +# the binary. It parses the command line to behave as rpath is supposed +# to work. +# +# See https://blogs.oracle.com/dipol/entry/dynamic_libraries_rpath_and_mac +# on how to set those paths for Mach-O binaries. +# +set -eu + +LIBS= +LIB_DIRS= +RPATHS= +OUTPUT= + +function parse_option() { + local -r opt="$1" + if [[ "${OUTPUT}" = "1" ]]; then + OUTPUT=$opt + elif [[ "$opt" =~ ^-l(.*)$ ]]; then + LIBS="${BASH_REMATCH[1]} $LIBS" + elif [[ "$opt" =~ ^-L(.*)$ ]]; then + LIB_DIRS="${BASH_REMATCH[1]} $LIB_DIRS" + elif [[ "$opt" =~ ^\@loader_path/(.*)$ ]]; then + RPATHS="${BASH_REMATCH[1]} ${RPATHS}" + elif [[ "$opt" = "-o" ]]; then + # output is coming + OUTPUT=1 + fi +} + +# let parse the option list +for i in "$@"; do + if [[ "$i" = @* && -r "${i:1}" ]]; then + while IFS= read -r opt + do + parse_option "$opt" + done < "${i:1}" || exit 1 + else + parse_option "$i" + fi +done + +# Call the C++ compiler +/usr/lib/llvm-17/bin/clang "$@" + +function get_library_path() { + for libdir in ${LIB_DIRS}; do + if [ -f ${libdir}/lib$1.so ]; then + echo "${libdir}/lib$1.so" + elif [ -f ${libdir}/lib$1.dylib ]; then + echo "${libdir}/lib$1.dylib" + fi + done +} + +# A convenient method to return the actual path even for non symlinks +# and multi-level symlinks, see b/300002682 for more details. +function get_realpath() { + local mangled=$(echo $1 | sed 's/[-_\/a-zA-Z0-9]*_solib_darwin[-_a-zA-Z0-9]*\///g') + if [[ "${mangled:0:3}" = "lib" ]]; then + mangled="${mangled:3}" + fi + if [[ "${mangled:0:2}" = "_U" ]]; then + mangled="${mangled:2}" + fi + local mangled_path=(${mangled//_S/ }) + local demangled_path=() + for mangled in ${mangled_path[@]}; do + demangled_path+=(${mangled//_U/_}) + done + demangled_path=${demangled_path[@]} + echo "bazel-out/darwin-opt/bin/${demangled_path// //}" +} + +# Get the path of a lib inside a tool +function get_otool_path() { + # the lib path is the path of the original lib relative to the workspace + get_realpath $1 | sed 's|^.*/bazel-out/|bazel-out/|' +} + +# Do replacements in the output +for rpath in ${RPATHS}; do + for lib in ${LIBS}; do + unset libname + if [ -f "$(dirname ${OUTPUT})/${rpath}/lib${lib}.so" ]; then + libname="lib${lib}.so" + elif [ -f "$(dirname ${OUTPUT})/${rpath}/lib${lib}.dylib" ]; then + libname="lib${lib}.dylib" + fi + # ${libname-} --> return $libname if defined, or undefined otherwise. This is to make + # this set -e friendly + if [[ -n "${libname-}" ]]; then + libpath=$(get_library_path ${lib}) + if [ -n "${libpath}" ]; then + /usr/lib/llvm-17/bin/llvm-install-name-tool -change $(get_otool_path "${libpath}") \ + "@loader_path/${rpath}/${libname}" "${OUTPUT}" + fi + fi + done +done \ No newline at end of file diff --git a/third_party/tsl/tools/toolchains/cross_compile/config/BUILD b/third_party/tsl/tools/toolchains/cross_compile/config/BUILD index b6a504ba1449d..e60a32aced24e 100644 --- a/third_party/tsl/tools/toolchains/cross_compile/config/BUILD +++ b/third_party/tsl/tools/toolchains/cross_compile/config/BUILD @@ -21,3 +21,25 @@ platform( "@platforms//cpu:aarch64", ], ) + +platform( + name = "darwin_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], +) + +toolchain( + name = "macos-x86-cross-compile-cc-toolchain", + exec_compatible_with = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + target_compatible_with = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], + toolchain = "//tensorflow/tools/toolchains/cross_compile/cc:macos_x86_toolchain", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/third_party/tsl/tools/toolchains/cross_compile/config/platform_mappings b/third_party/tsl/tools/toolchains/cross_compile/config/platform_mappings new file mode 100644 index 0000000000000..aaf91d096d7b4 --- /dev/null +++ b/third_party/tsl/tools/toolchains/cross_compile/config/platform_mappings @@ -0,0 +1,11 @@ +platforms: +# Maps "--platforms=//tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64" +# to "--cpu=darwin". + //tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64 + --cpu=darwin + +flags: + # Maps "--cpu=darwin" to + # "--platforms=//tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64". + --cpu=darwin + //tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64 diff --git a/third_party/tsl/tools/toolchains/python/python_repo.bzl b/third_party/tsl/tools/toolchains/python/python_repo.bzl index 77011b2c0577b..47fe64d7b7b03 100644 --- a/third_party/tsl/tools/toolchains/python/python_repo.bzl +++ b/third_party/tsl/tools/toolchains/python/python_repo.bzl @@ -1,7 +1,10 @@ """ -Repository rule to set python version. -Can be set via build parameter "--repo_env=TF_PYTHON_VERSION=3.10" +Repository rule to set python version and wheel name. + +Version can be set via build parameter "--repo_env=TF_PYTHON_VERSION=3.10" Defaults to 3.10. + +To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ VERSIONS = ["3.9", "3.10", "3.11", "3.12"] @@ -16,20 +19,24 @@ export TF_PYTHON_VERSION=3.11 content = """ TF_PYTHON_VERSION = "{}" HERMETIC_PYTHON_VERSION = "{}" +WHEEL_NAME = "{}" +WHEEL_COLLAB = "{}" """ def _python_repository_impl(repository_ctx): repository_ctx.file("BUILD", "") version = repository_ctx.os.environ.get("TF_PYTHON_VERSION", "") + wheel_name = repository_ctx.os.environ.get("WHEEL_NAME", "tensorflow") + wheel_collab = repository_ctx.os.environ.get("WHEEL_COLLAB", False) if version not in VERSIONS: print(WARNING) # buildifier: disable=print version = DEFAULT_VERSION repository_ctx.file( "py_version.bzl", - content.format(version, version), + content.format(version, version, wheel_name, wheel_collab), ) python_repository = repository_rule( implementation = _python_repository_impl, - environ = ["TF_PYTHON_VERSION"], + environ = ["TF_PYTHON_VERSION", "WHEEL_NAME", "WHEEL_COLLAB"], ) diff --git a/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 4554463cb9067..f55194087b127 100644 --- a/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -178,6 +178,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.1-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.1", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.1-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.1", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9", compiler = "/usr/lib/llvm-17/bin/clang", @@ -200,6 +222,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_win_config( name = "windows_py37", python_bin_path = "C:/Python37/python.exe", @@ -605,11 +649,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0", - "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2", - "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc", + "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", + "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -633,7 +677,7 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.2", + "TF_CUDA_VERSION": "12.3", "TF_CUDNN_VERSION": "8.9", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", @@ -645,11 +689,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0", - "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2", - "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc", + "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", + "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -672,7 +716,7 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.2", + "TF_CUDA_VERSION": "12.3", "TF_CUDNN_VERSION": "8.9", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", diff --git a/third_party/tsl/tools/toolchains/remote_config/containers.bzl b/third_party/tsl/tools/toolchains/remote_config/containers.bzl index bfb4634e81032..dd222d06bd13b 100644 --- a/third_party/tsl/tools/toolchains/remote_config/containers.bzl +++ b/third_party/tsl/tools/toolchains/remote_config/containers.bzl @@ -5,8 +5,10 @@ container_digests = { # TF now uses only this container "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821", # JAX manylinux2014 configs. - "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645", - "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb", + "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3", + "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63", + "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", @@ -91,6 +93,13 @@ containers = { "digest": container_digests["cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython. "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { "registry": "gcr.io", @@ -98,6 +107,13 @@ containers = { "digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython. "rocm-ubuntu18.04-manylinux2010-multipython": { "registry": "gcr.io", diff --git a/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index b1488584566aa..18a84d96c39f8 100644 --- a/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -1,12 +1,12 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") +load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") +load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") load("//tools/toolchains/remote_config:containers.bzl", "containers") -load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") def _container_image_uri(container_name): container = containers[container_name] diff --git a/third_party/tsl/tools/toolchains/win/bazel_211/BUILD b/third_party/tsl/tools/toolchains/win/bazel_211/BUILD index cc23c8ecb2268..c7484d2ae2efd 100644 --- a/third_party/tsl/tools/toolchains/win/bazel_211/BUILD +++ b/third_party/tsl/tools/toolchains/win/bazel_211/BUILD @@ -15,8 +15,8 @@ # This becomes the BUILD file for @local_config_cc// under Windows. load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") -load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) diff --git a/third_party/tsl/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl b/third_party/tsl/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl index 30571b6a5ace8..9ccc1706e5eca 100644 --- a/third_party/tsl/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl +++ b/third_party/tsl/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl @@ -14,6 +14,7 @@ """A Starlark cc_toolchain configuration rule for Windows""" +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") load( "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", "action_config", @@ -29,7 +30,6 @@ load( "variable_with_value", "with_feature_set", ) -load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") all_compile_actions = [ ACTION_NAMES.c_compile, diff --git a/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD b/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD index f245f6d0789c9..8a2ae6fe4a9dd 100644 --- a/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD +++ b/third_party/tsl/tools/toolchains/win/tf_win_05022023/BUILD @@ -15,8 +15,8 @@ # This becomes the BUILD file for @local_config_cc// under Windows. load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") -load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) diff --git a/third_party/tsl/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl b/third_party/tsl/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl index ba3de607d1045..d6b966b32ceca 100644 --- a/third_party/tsl/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl +++ b/third_party/tsl/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl @@ -14,6 +14,7 @@ """A Starlark cc_toolchain configuration rule for Windows""" +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") load( "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", "action_config", @@ -28,7 +29,6 @@ load( "variable_with_value", "with_feature_set", ) -load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") all_compile_actions = [ ACTION_NAMES.c_compile, diff --git a/third_party/tsl/tsl/BUILD b/third_party/tsl/tsl/BUILD index 68828f0098d26..8fdcc6fa00f8e 100644 --- a/third_party/tsl/tsl/BUILD +++ b/third_party/tsl/tsl/BUILD @@ -1,7 +1,7 @@ -load("tsl.bzl", "if_google", "if_oss") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("tsl.bzl", "if_google", "if_oss") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -108,7 +108,7 @@ config_setting( config_setting( name = "emscripten", constraint_values = if_google( - ["//third_party/bazel_platforms/cpu:wasm32"], + ["//third_party/bazel_platforms/os:emscripten"], [], ), values = if_oss( @@ -177,6 +177,18 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) +config_setting( + name = "windows_x86_64", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "windows_aarch64", + values = {"cpu": "arm64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "windows", # Internal builds query the target OS. @@ -221,6 +233,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_cross_compiler_support", + define_values = {"with_cross_compiler_support": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "android_arm", constraint_values = if_google( @@ -281,6 +299,67 @@ config_setting( visibility = ["//visibility:public"], ) +selects.config_setting_group( + name = "aarch32_or_cross", + match_any = [ + ":linux_armhf", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "aarch64_or_cross", + match_any = [ + ":linux_aarch64", + ":macos_arm64", + ":windows_aarch64", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "arm_or_cross", + match_any = [ + ":linux_aarch64", + ":macos_arm64", + ":windows_aarch64", + ":linux_armhf", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "ppc64le_or_cross", + match_any = [ + ":linux_ppc64le", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "s390x_or_cross", + match_any = [ + ":linux_s390x", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "x86_or_cross", + match_any = [ + ":linux_x86_64", + ":macos_x86_64", + ":windows_x86_64", + ":with_cross_compiler_support", + ], + visibility = ["//visibility:public"], +) + # Config setting that disables the default logger, only logging # to registered TFLogSinks config_setting( @@ -449,10 +528,14 @@ bzl_library( srcs = ["tsl.bzl"], visibility = ["//visibility:public"], deps = [ + "//third_party/compute_library:build_defs_bzl", + "//third_party/mkl_dnn:build_defs_bzl", + "//tsl/platform:rules_cc_bzl", "@bazel_skylib//lib:new_sets", "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", + "@xla//xla/tsl/mkl:build_defs_bzl", ], ) diff --git a/third_party/tsl/tsl/c/BUILD b/third_party/tsl/tsl/c/BUILD deleted file mode 100644 index 93f39b2ecc35b..0000000000000 --- a/third_party/tsl/tsl/c/BUILD +++ /dev/null @@ -1,117 +0,0 @@ -# Description: -# C API for TensorFlow, for use by client language bindings. - -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "tsl_copts", "tsl_gpu_library") - -# buildifier: disable=same-origin-load -load("//tsl:tsl.default.bzl", "filegroup") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -# ----------------------------------------------------------------------------- -# Public targets - -filegroup( - name = "headers", - srcs = [ - "tsl_status.h", - ], - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "srcs", - srcs = glob( - [ - "*.cc", - "*.h", - ], - exclude = [ - "*test*", - ], - ), - visibility = [ - "//tensorflow/c:__subpackages__", - ], -) - -tsl_gpu_library( - name = "c_api", - hdrs = [ - "tsl_status.h", - ], - copts = tsl_copts(), - visibility = ["//visibility:public"], - deps = [ - ":tsl_status_internal", - ], -) - -tsl_gpu_library( - name = "tsl_status_internal", - hdrs = [ - "tsl_status.h", - "tsl_status_internal.h", - ], - visibility = ["//visibility:public"], - deps = [ - "//tsl/platform:status", - ], -) - -cc_library( - name = "tsl_status", - srcs = ["tsl_status.cc"], - hdrs = ["tsl_status.h"], - visibility = ["//visibility:public"], - deps = [ - ":tsl_status_internal", - "//tsl/platform:errors", - "//tsl/platform:status", - ], -) - -tsl_cc_test( - name = "tsl_status_test", - srcs = ["tsl_status_test.cc"], - deps = [ - ":tsl_status", - ":tsl_status_internal", - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/platform:test", - "//tsl/platform:test_main", - ], -) - -cc_library( - name = "tsl_status_headers", - hdrs = ["tsl_status.h"], - visibility = ["//visibility:public"], -) - -tsl_gpu_library( - name = "tsl_status_helper", - srcs = ["tsl_status_helper.cc"], - hdrs = ["tsl_status_helper.h"], - visibility = ["//visibility:public"], - deps = [ - ":tsl_status", - ":tsl_status_internal", - "//tsl/platform:errors", - "//tsl/platform:status", - ], -) - -filegroup( - name = "tsl_status_internal_headers", - srcs = ["tsl_status_internal.h"], - visibility = [ - "//tensorflow/c:__subpackages__", - ], -) diff --git a/third_party/tsl/tsl/concurrency/BUILD b/third_party/tsl/tsl/concurrency/BUILD index 1281da9ac0ea6..e37bc692df05d 100644 --- a/third_party/tsl/tsl/concurrency/BUILD +++ b/third_party/tsl/tsl/concurrency/BUILD @@ -1,6 +1,6 @@ load("//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -23,6 +23,9 @@ cc_library( deps = [ ":concurrent_vector", ":ref_count", + "//tsl/platform:logging", + "//tsl/platform:platform_port", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", @@ -39,6 +42,20 @@ tsl_cc_test( ":async_value", "//tsl/platform:test", "//tsl/platform:test_main", + "@com_google_absl//absl/status", + ], +) + +tsl_cc_test( + name = "async_value_ptr_test", + srcs = ["async_value_ptr_test.cc"], + deps = [ + ":async_value", + "//tsl/platform:test", + "//tsl/platform:test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) @@ -47,8 +64,12 @@ tsl_cc_test( srcs = ["async_value_ref_test.cc"], deps = [ ":async_value", + ":ref_count", "//tsl/platform:test", "//tsl/platform:test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/tsl/tsl/concurrency/async_value.cc b/third_party/tsl/tsl/concurrency/async_value.cc index 28af75607cda5..431e6272279aa 100644 --- a/third_party/tsl/tsl/concurrency/async_value.cc +++ b/third_party/tsl/tsl/concurrency/async_value.cc @@ -16,55 +16,21 @@ limitations under the License. #include "tsl/concurrency/async_value.h" #include +#include #include -#include +#include #include -#include #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/synchronization/blocking_counter.h" +#include "absl/types/span.h" #include "tsl/concurrency/async_value_ref.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/logging.h" namespace tsl { -namespace internal { - -void* AlignedAlloc(size_t alignment, size_t size) { - size = (size + alignment - 1) / alignment * alignment; -#ifdef _WIN32 - // MSVC runtime doesn't support aligned_alloc(). See - // https://developercommunity.visualstudio.com/t/c17-stdaligned-alloc%E7%BC%BA%E5%A4%B1/468021#T-N473365 - return _aligned_malloc(size, alignment); -#elif defined(__ANDROID__) || defined(OS_ANDROID) - return memalign(alignment, size); -#else - // posix_memalign requires that the requested alignment be at least - // alignof(void*). In this case, fall back on malloc which should return - // memory aligned to at least the size of a pointer. - if (alignment <= alignof(void*)) return std::malloc(size); - void* ptr = nullptr; - if (posix_memalign(&ptr, alignment, size) != 0) - return nullptr; - else - return ptr; -#endif -} - -void AlignedFree(void* ptr) { -#ifdef _WIN32 - // _aligned_alloc() must be paired with _aligned_free(). - // - // Attempting to use free() with a pointer returned by _aligned_malloc() - // results in runtime issues that are hard to debug. - _aligned_free(ptr); -#else - free(ptr); -#endif -} - -} // namespace internal - // This is a singly linked list of nodes waiting for notification, hanging off // of AsyncValue. When the value becomes available or if an error occurs, the // callbacks are informed. @@ -83,9 +49,8 @@ class NotifierListNode { uint16_t AsyncValue::CreateTypeInfoAndReturnTypeIdImpl( const TypeInfo& type_info) { size_t type_id = GetTypeInfoTableSingleton()->emplace_back(type_info) + 1; - // Detect overflow. - assert(type_id < std::numeric_limits::max() && - "Too many different AsyncValue types."); + DCHECK(type_id < std::numeric_limits::max()) + << "Too many different AsyncValue types."; return type_id; } @@ -99,7 +64,7 @@ std::atomic AsyncValue::total_allocated_async_values_; const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const { TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton(); - assert(type_id_ != 0); + DCHECK_NE(type_id_, 0); return (*type_info_table)[type_id_ - 1]; } @@ -108,17 +73,17 @@ const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const { // need to change our state and clear out the notifications. The current state // must be unavailable (i.e. kUnconstructed or kConstructed). void AsyncValue::NotifyAvailable(State available_state) { - assert((kind() == Kind::kConcrete || kind() == Kind::kIndirect) && - "Should only be used by ConcreteAsyncValue or IndirectAsyncValue"); + DCHECK((kind() == Kind::kConcrete || kind() == Kind::kIndirect)) + << "Should only be used by ConcreteAsyncValue or IndirectAsyncValue"; - assert(available_state == State::kConcrete || + DCHECK(available_state == State::kConcrete || available_state == State::kError); // Mark the value as available, ensuring that new queries for the state see // the value that got filled in. auto old_value = waiters_and_state_.exchange( WaitersAndState(nullptr, available_state), std::memory_order_acq_rel); - assert(old_value.state() == State::kUnconstructed || + DCHECK(old_value.state() == State::kUnconstructed || old_value.state() == State::kConstructed); RunWaiters(old_value.waiter()); @@ -158,7 +123,7 @@ void AsyncValue::EnqueueWaiter(absl::AnyInvocable waiter, // so, just run the waiter. if (old_value.state() == State::kConcrete || old_value.state() == State::kError) { - assert(old_value.waiter() == nullptr); + DCHECK(old_value.waiter() == nullptr); node->notification_(); delete node; return; @@ -169,16 +134,16 @@ void AsyncValue::EnqueueWaiter(absl::AnyInvocable waiter, // compare_exchange_weak succeeds. The old_value must be in either // kUnconstructed or kConstructed state. - assert(old_value.state() == State::kUnconstructed || + DCHECK(old_value.state() == State::kUnconstructed || old_value.state() == State::kConstructed); } void AsyncValue::SetError(absl::Status status) { - assert(!status.ok()); + DCHECK(!status.ok()); if (kind() == Kind::kConcrete) { GetTypeInfo().set_error(this, std::move(status)); } else { - assert(kind() == Kind::kIndirect); + DCHECK(kind() == Kind::kIndirect); auto error_av = MakeErrorAsyncValueRef(std::move(status)); static_cast(this)->ForwardTo(std::move(error_av)); } @@ -187,17 +152,17 @@ void AsyncValue::SetError(absl::Status status) { // Mark this IndirectAsyncValue as forwarding to the specified value. This // gives the IndirectAsyncValue a +1 reference. void IndirectAsyncValue::ForwardTo(RCReference value) { - assert(IsUnavailable()); + DCHECK(IsUnavailable()); auto s = value->state(); if (s == State::kConcrete || s == State::kError) { - assert(!value_ && "IndirectAsyncValue::ForwardTo is called more than once"); + DCHECK(!value_) << "IndirectAsyncValue::ForwardTo is called more than once"; auto* concrete_value = value.release(); if (concrete_value->kind() == Kind::kIndirect) { auto* indirect_value = static_cast(concrete_value); concrete_value = indirect_value->value_; - assert(concrete_value != nullptr); - assert(concrete_value->kind() == Kind::kConcrete); + DCHECK(concrete_value != nullptr); + DCHECK(concrete_value->kind() == Kind::kConcrete); concrete_value->AddRef(); indirect_value->DropRef(); } diff --git a/third_party/tsl/tsl/concurrency/async_value.h b/third_party/tsl/tsl/concurrency/async_value.h index ddeba8e161950..59d5e274ab0e2 100644 --- a/third_party/tsl/tsl/concurrency/async_value.h +++ b/third_party/tsl/tsl/concurrency/async_value.h @@ -22,16 +22,15 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/types/span.h" #include "tsl/concurrency/concurrent_vector.h" #include "tsl/concurrency/ref_count.h" +#include "tsl/platform/mem.h" namespace tsl { @@ -45,11 +44,6 @@ class ConcreteAsyncValue; template constexpr bool kMaybeBase = std::is_class::value && !std::is_final::value; -// TODO(ezhulenev): Switch to `tsl::port::Aligned(Malloc|Free)` once TFRT will -// be able to properly depend on TSL in the open source build. -void* AlignedAlloc(size_t alignment, size_t size); -void AlignedFree(void* ptr); - } // namespace internal // This is a future of the specified value type. Arbitrary C++ types may be used @@ -258,11 +252,12 @@ class AsyncValue { // ----------------------------------------------------------- // Implementation details follow. Clients should ignore them. + friend class IndirectAsyncValue; + // Utility template for tag dispatching. template struct TypeTag {}; - friend class IndirectAsyncValue; template AsyncValue(Kind kind, State state, bool is_refcounted, TypeTag) : refcount_(1), @@ -975,12 +970,12 @@ inline void AsyncValue::Destroy() { // explicit check and instead make ~IndirectAsyncValue go through the // GetTypeInfo().destructor case below. static_cast(this)->~IndirectAsyncValue(); - if (was_ref_counted) internal::AlignedFree(this); + if (was_ref_counted) port::AlignedFree(this); return; } GetTypeInfo().destructor(this); - if (was_ref_counted) internal::AlignedFree(this); + if (was_ref_counted) port::AlignedFree(this); } inline bool AsyncValue::IsUnique() const { diff --git a/third_party/tsl/tsl/concurrency/async_value_ptr_test.cc b/third_party/tsl/tsl/concurrency/async_value_ptr_test.cc index 137bf0fcaf8dd..54815c76190c2 100644 --- a/third_party/tsl/tsl/concurrency/async_value_ptr_test.cc +++ b/third_party/tsl/tsl/concurrency/async_value_ptr_test.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/platform/test.h" @@ -75,4 +80,206 @@ TEST(AsyncValuePtrTest, AndThen) { EXPECT_TRUE(executed); } +TEST(AsyncValuePtrTest, AndThenError) { + AsyncValueRef ref = MakeConstructedAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + + auto error = absl::InternalError("test error"); + ptr.SetError(error); + ptr.AndThen([&](absl::Status status) { EXPECT_EQ(status, error); }); +} + +TEST(AsyncValuePtrTest, AndThenNoError) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + + ptr.AndThen([](absl::Status status) { EXPECT_TRUE(status.ok()); }); +} + +TEST(AsyncValuePtrTest, AndThenStatusOrError) { + AsyncValueRef ref = MakeConstructedAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + + auto error = absl::InternalError("test error"); + ptr.SetError(error); + + ptr.AndThen([&](absl::StatusOr v) { + EXPECT_FALSE(v.ok()); + EXPECT_EQ(v.status(), error); + }); +} + +TEST(AsyncValuePtrTest, AndThenStatusOrNoError) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + + ptr.AndThen([&](absl::StatusOr v) { EXPECT_EQ(**v, 42); }); +} + +TEST(AsyncValuePtrTest, BlockUntilReady) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + BlockUntilReady(ptr); +} + +TEST(AsyncValuePtrTest, RunWhenReady) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + AsyncValuePtr ptr = ref.AsPtr(); + bool executed = false; + RunWhenReady(absl::MakeConstSpan({ptr}), [&] { executed = true; }); + EXPECT_TRUE(executed); +} + +namespace { +struct A { + virtual ~A() = default; +}; +struct B : public A {}; +struct C : public B {}; +struct D : public A {}; +} // namespace + +TEST(AsyncValuePtrTest, Isa) { + // Empty async pointer always returns false for any Isa. + AsyncValuePtr null_ptr; + EXPECT_FALSE(Isa(null_ptr)); + + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(Isa(a_ref.AsPtr())); + EXPECT_TRUE(Isa(b_ref.AsPtr())); + EXPECT_TRUE(Isa(c_ref.AsPtr())); + EXPECT_TRUE(Isa(d_ref.AsPtr())); + + // Error async value is Isa of any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(Isa(err.AsPtr())); + EXPECT_TRUE(Isa(err.AsPtr())); + EXPECT_TRUE(Isa(err.AsPtr())); + EXPECT_TRUE(Isa(err.AsPtr())); + + // If the value was constructed with a concrete type it should return true + // for Isa even if it was set to error later but only if types match. + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(Isa(a_err.AsPtr())); + EXPECT_TRUE(Isa(b_err.AsPtr())); + + // Indirect async value is Isa only if it would be a no-op cast. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(Isa(c_indirect.AsPtr())); + EXPECT_FALSE(Isa(c_indirect.AsPtr())); + + // After forwarding indirect async value to a concrete one it correctly + // returns true from Isa check. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(Isa(c_indirect.AsPtr())); + EXPECT_TRUE(Isa(c_indirect.AsPtr())); +} + +TEST(AsyncValuePtrTest, DynCast) { + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(DynCast(a_ref.AsPtr())); + EXPECT_TRUE(DynCast(b_ref.AsPtr())); + EXPECT_TRUE(DynCast(c_ref.AsPtr())); + EXPECT_TRUE(DynCast(d_ref.AsPtr())); + + // No-op casts are always successful. + EXPECT_TRUE(DynCast(c_ref.AsPtr())); + + // We don't support casting to base (C inherits from B) because we can't do + // that safely relying just on AsyncValue type id. For safe conversion to base + // we need to introduce some kind of traits to the type hierarchy or rely on + // builtin `dynamic_cast` (will work only for constructed values). + EXPECT_FALSE(DynCast(c_ref.AsPtr())); + + // Types are unrelated, although they have same base. + EXPECT_FALSE(DynCast(d_ref.AsPtr())); + + // Error async value can be DynCast to any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(DynCast(err.AsPtr())); + EXPECT_TRUE(DynCast(err.AsPtr())); + EXPECT_TRUE(DynCast(err.AsPtr())); + EXPECT_TRUE(DynCast(err.AsPtr())); + + // If the value was constructed with a concrete type it should DynCast + // successfully even it it was set to error later but only if types match. + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(DynCast(a_err.AsPtr())); + EXPECT_TRUE(DynCast(b_err.AsPtr())); + EXPECT_FALSE(DynCast(a_err.AsPtr())); + + // Indirect async value can't be DynCast until it's forwarded unless it's a + // no-op DynCast to the same type. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(DynCast(c_indirect.AsPtr())); + EXPECT_FALSE(DynCast(c_indirect.AsPtr())); + + // After forwarding indirect async value to a concrete one it can be DynCast + // to a concrete type. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(DynCast(c_indirect.AsPtr())); + EXPECT_TRUE(DynCast(c_indirect.AsPtr())); +} + +TEST(AsyncValuePtrTest, Cast) { + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(Cast(a_ref.AsPtr())); + EXPECT_TRUE(Cast(b_ref.AsPtr())); + EXPECT_TRUE(Cast(c_ref.AsPtr())); + EXPECT_TRUE(Cast(d_ref.AsPtr())); + + EXPECT_TRUE(Cast(c_ref.AsPtr())); + + // Error async value can be Cast to any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(Cast(err.AsPtr())); + EXPECT_TRUE(Cast(err.AsPtr())); + EXPECT_TRUE(Cast(err.AsPtr())); + EXPECT_TRUE(Cast(err.AsPtr())); + + // If the value was constructed with a concrete type it should Cast + // successfully even it it was set to error later but only if types match. + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(Cast(a_err.AsPtr())); + EXPECT_TRUE(Cast(b_err.AsPtr())); + + // Indirect async value can't be Cast until it's forwarded unless it's a + // no-op Cast to the same type. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(Cast(c_indirect.AsPtr())); + + // After forwarding indirect async value to a concrete one it can be Cast + // to a concrete type. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(Cast(c_indirect.AsPtr())); + EXPECT_TRUE(Cast(c_indirect.AsPtr())); +} + } // namespace tsl diff --git a/third_party/tsl/tsl/concurrency/async_value_ref.cc b/third_party/tsl/tsl/concurrency/async_value_ref.cc index b850182dffd6d..8a1e23c05d2cc 100644 --- a/third_party/tsl/tsl/concurrency/async_value_ref.cc +++ b/third_party/tsl/tsl/concurrency/async_value_ref.cc @@ -18,6 +18,10 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tsl/concurrency/async_value.h" +#include "tsl/concurrency/ref_count.h" + namespace tsl { RCReference MakeIndirectAsyncValue() { diff --git a/third_party/tsl/tsl/concurrency/async_value_ref.h b/third_party/tsl/tsl/concurrency/async_value_ref.h index 71d5ae0217b78..2d1b5db45cd31 100644 --- a/third_party/tsl/tsl/concurrency/async_value_ref.h +++ b/third_party/tsl/tsl/concurrency/async_value_ref.h @@ -17,15 +17,20 @@ limitations under the License. #define TENSORFLOW_TSL_CONCURRENCY_ASYNC_VALUE_REF_H_ #include -#include #include #include #include +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "tsl/concurrency/async_value.h" #include "tsl/concurrency/ref_count.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/mem.h" namespace tsl { @@ -53,12 +58,11 @@ class AsyncValueRef { // Support implicit conversion from AsyncValueRef to // AsyncValueRef. - template ::value>* = nullptr> - AsyncValueRef(AsyncValueRef&& u) // NOLINT + template * = nullptr> + AsyncValueRef(AsyncValueRef&& u) // NOLINT : value_(u.ReleaseRCRef()) {} - // Support implicit conversion from RCReference. + // Support implicit conversion from RCReference. AsyncValueRef(RCReference value) // NOLINT : value_(std::move(value)) {} @@ -83,12 +87,40 @@ class AsyncValueRef { // Return the stored value. The AsyncValueRef must be available. T& get() const { return value_->get(); } - // Return the stored value as a subclass type. The AsyncValueRef must be + // Return the stored value as a derived type. The AsyncValueRef must be // available. - template ::value>* = nullptr> - SubclassT& get() const { - return value_->get(); + template * = nullptr> + Derived& get() const { + return value_->get(); + } + + template * = nullptr> + bool Isa() const { + // Isa is successful if: + // (1) This is no-op cast even if concrete payload has different type. + // (2) Type id of a concrete payload matches Derived type id. + // (3) Payload is for a special case of ErrorAsyncValue. + return value_ && (std::is_same_v || // (1) + value_->IsType() || // (2) + value_->IsType()); // (3) + } + + template * = nullptr> + AsyncValueRef Cast() const { + DCHECK(DynCast()) << "Illegal async value cast"; + return AsyncValueRef(value_); + } + + template * = nullptr> + AsyncValueRef DynCast() const { + DCHECK(value_) << "Async value must be not null"; + return Isa() ? AsyncValueRef(value_) + : AsyncValueRef(nullptr); + } + + template * = nullptr> + AsyncValueRef DynCastOrNull() const { + return value_ ? DynCast(value_) : AsyncValueRef(nullptr); } T* operator->() const { return &get(); } @@ -130,7 +162,7 @@ class AsyncValueRef { } void SetError(absl::Status status) const { - assert(!status.ok() && "expected non-ok status"); + DCHECK(!status.ok()) << "expected non-ok status"; return value_->SetError(std::move(status)); } @@ -202,6 +234,35 @@ class AsyncValuePtr { return *this; } + template * = nullptr> + bool Isa() const { + // Isa is successful if: + // (1) This is no-op cast even if concrete payload has different type. + // (2) Type id of a concrete payload matches Derived type id. + // (3) Payload is for a special case of ErrorAsyncValue. + return value_ && (std::is_same_v || // (1) + value_->IsType() || // (2) + value_->IsType()); // (3) + } + + template * = nullptr> + AsyncValuePtr Cast() const { + DCHECK(DynCast()) << "Illegal async value cast"; + return AsyncValuePtr(value_); + } + + template * = nullptr> + AsyncValuePtr DynCast() const { + DCHECK(value_) << "Async value must be not null"; + return Isa() ? AsyncValuePtr(value_) + : AsyncValuePtr(nullptr); + } + + template * = nullptr> + AsyncValuePtr DynCastOrNull() const { + return value_ ? DynCast(value_) : AsyncValuePtr(nullptr); + } + bool IsAvailable() const { return value_->IsAvailable(); } bool IsUnavailable() const { return value_->IsUnavailable(); } @@ -218,7 +279,7 @@ class AsyncValuePtr { const absl::Status& GetError() const { return value_->GetError(); } void SetError(absl::Status status) const { - assert(!status.ok() && "expected non-ok status"); + DCHECK(!status.ok()) << "expected non-ok status"; return value_->SetError(std::move(status)); } @@ -307,6 +368,94 @@ RCReference MakeErrorAsyncValueRef(std::string_view message); // Construct an empty IndirectAsyncValue, not forwarding to anything. RCReference MakeIndirectAsyncValue(); +//===----------------------------------------------------------------------===// +// Functions for awaiting on the async values. +//===----------------------------------------------------------------------===// + +template +void BlockUntilReady(const AsyncValueRef& ref) { + BlockUntilReady(ref.GetAsyncValue()); +} + +template +void BlockUntilReady(const AsyncValuePtr& ptr) { + BlockUntilReady(ptr.value()); +} + +template +void RunWhenReady(absl::Span> refs, + absl::AnyInvocable callee) { + absl::InlinedVector values(refs.size()); + for (size_t i = 0; i < refs.size(); ++i) { + values[i] = refs[i].GetAsyncValue(); + } + RunWhenReady(values, std::move(callee)); +} + +template +void RunWhenReady(absl::Span> ptrs, + absl::AnyInvocable callee) { + absl::InlinedVector values(ptrs.size()); + for (size_t i = 0; i < ptrs.size(); ++i) { + values[i] = ptrs[i].value(); + } + RunWhenReady(values, std::move(callee)); +} + +//===----------------------------------------------------------------------===// +// LLVM-style type casting library for async value refs and ptrs. +//===----------------------------------------------------------------------===// + +template * = nullptr> +bool Isa(const AsyncValueRef& ref) { + return ref.template Isa(); +} + +template * = nullptr> +AsyncValueRef Cast(const AsyncValueRef& ref) { + return ref.template Cast(); +} + +template * = nullptr> +AsyncValueRef DynCast(const AsyncValueRef& ref) { + return ref.template DynCast(); +} + +template * = nullptr> +AsyncValueRef DynCastOrNull(const AsyncValueRef& ref) { + return ref.template DynCastOrNull(); +} + +template * = nullptr> +bool Isa(AsyncValuePtr ptr) { + return ptr.template Isa(); +} + +template * = nullptr> +AsyncValuePtr Cast(AsyncValuePtr ptr) { + return ptr.template Cast(); +} + +template * = nullptr> +AsyncValuePtr DynCast(AsyncValuePtr ptr) { + return ptr.template DynCast(); +} + +template * = nullptr> +AsyncValuePtr DynCastOrNull(AsyncValuePtr ptr) { + return ptr.template DynCastOrNull(); +} + +//===----------------------------------------------------------------------===// +// Constructing reference-counted async values on the heap. //===----------------------------------------------------------------------===// namespace internal { @@ -318,17 +467,12 @@ T* PlacementConstruct(void* buf, Args&&... args) { template T* AllocateAndConstruct(Args&&... args) { - // TODO(ezhulenev): `port::AlignedMalloc` has a different order of arguments! - void* buf = internal::AlignedAlloc(alignof(T), sizeof(T)); + void* buf = port::AlignedMalloc(sizeof(T), alignof(T)); return PlacementConstruct(buf, std::forward(args)...); } } // namespace internal -//===----------------------------------------------------------------------===// -// Constructing reference-counted async values on the heap. -//===----------------------------------------------------------------------===// - // Allocate an unconstructed AsyncValueRef. The AsyncValueRef should be made // available later by invoking AsyncValueRef::emplace or // AsyncValueRef::SetError. diff --git a/third_party/tsl/tsl/concurrency/async_value_ref_test.cc b/third_party/tsl/tsl/concurrency/async_value_ref_test.cc index 63b7b6237e897..4efd691ec4841 100644 --- a/third_party/tsl/tsl/concurrency/async_value_ref_test.cc +++ b/third_party/tsl/tsl/concurrency/async_value_ref_test.cc @@ -15,9 +15,14 @@ limitations under the License. #include "tsl/concurrency/async_value_ref.h" -#include +#include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tsl/concurrency/async_value.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/test.h" namespace tsl { @@ -112,67 +117,48 @@ TEST(AsyncValueRefTest, CopyRef) { EXPECT_EQ(value.GetAsyncValue(), copied_value.GetAsyncValue()); } -TEST(AsyncValueRefTest, AndThenError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); +TEST(AsyncValueRefTest, AndThen) { + AsyncValueRef ref = MakeUnconstructedAsyncValueRef(); + + EXPECT_FALSE(ref.IsConcrete()); + EXPECT_FALSE(ref.IsAvailable()); - auto diag = absl::InternalError("test error"); - value.AndThen([&](absl::Status status) { EXPECT_EQ(status, diag); }); + bool executed = false; + ref.AndThen([&]() { executed = true; }); - value.SetError(diag); + ref.emplace(42); + EXPECT_TRUE(executed); } -TEST(AsyncValueRefTest, AndThenNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); +TEST(AsyncValueRefTest, AndThenError) { + AsyncValueRef ref = MakeConstructedAsyncValueRef(42); - value.AndThen([](absl::Status status) { EXPECT_TRUE(status.ok()); }); + auto error = absl::InternalError("test error"); + ref.SetError(error); - value.SetStateConcrete(); + ref.AndThen([&](absl::Status status) { EXPECT_EQ(status, error); }); } -TEST(AsyncValueRefTest, AndThenStatusOrError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - auto diag = absl::InternalError("test error"); - value.AndThen([&](absl::StatusOr v) { - EXPECT_FALSE(v.ok()); - EXPECT_EQ(v.status(), diag); - }); - - value.SetError(diag); +TEST(AsyncValueRefTest, AndThenNoError) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + ref.AndThen([](absl::Status status) { EXPECT_TRUE(status.ok()); }); } -TEST(AsyncValueRefTest, PtrAndThenStatusOrError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); +TEST(AsyncValueRefTest, AndThenStatusOrError) { + AsyncValueRef ref = MakeConstructedAsyncValueRef(42); - auto diag = absl::InternalError("test error"); - value.AsPtr().AndThen([&](absl::StatusOr v) { + auto error = absl::InternalError("test error"); + ref.SetError(error); + + ref.AndThen([&](absl::StatusOr v) { EXPECT_FALSE(v.ok()); - EXPECT_EQ(v.status(), diag); + EXPECT_EQ(v.status(), error); }); - - value.SetError(diag); } TEST(AsyncValueRefTest, AndThenStatusOrNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - value.AndThen([](absl::StatusOr v) { - EXPECT_TRUE(v.ok()); - EXPECT_EQ(**v, kTestValue); - }); - - value.SetStateConcrete(); -} - -TEST(AsyncValueRefTest, PtrAndThenStatusOrNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - value.AsPtr().AndThen([](absl::StatusOr v) { - EXPECT_TRUE(v.ok()); - EXPECT_EQ(**v, kTestValue); - }); - - value.SetStateConcrete(); + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + ref.AndThen([&](absl::StatusOr v) { EXPECT_EQ(**v, 42); }); } TEST(AsyncValueRefTest, Nullptr) { @@ -187,4 +173,168 @@ TEST(AsyncValueRefTest, Nullptr) { EXPECT_FALSE(av_int2); } +TEST(AsyncValueRefTest, BlockUntilReady) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + BlockUntilReady(ref); +} + +TEST(AsyncValueRefTest, RunWhenReady) { + AsyncValueRef ref = MakeAvailableAsyncValueRef(42); + bool executed = false; + RunWhenReady(absl::MakeConstSpan({ref}), [&] { executed = true; }); + EXPECT_TRUE(executed); +} + +namespace { +struct A { + virtual ~A() = default; +}; +struct B : public A {}; +struct C : public B {}; +struct D : public A {}; +} // namespace + +TEST(AsyncValueRefTest, Isa) { + // Empty async reference always returns false for any Isa. + AsyncValueRef null_ref; + EXPECT_FALSE(Isa(null_ref)); + + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(Isa(a_ref)); + EXPECT_TRUE(Isa(b_ref)); + EXPECT_TRUE(Isa(c_ref)); + EXPECT_TRUE(Isa(d_ref)); + + // Error async value is Isa of any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(Isa(err)); + EXPECT_TRUE(Isa(err)); + EXPECT_TRUE(Isa(err)); + EXPECT_TRUE(Isa(err)); + + // If the value was constructed with a concrete type it should return true + // for Isa even if it was set to error later but only if types match.S + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(Isa(a_err)); + EXPECT_TRUE(Isa(b_err)); + + // Indirect async value is Isa only if it would be a no-op cast. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(Isa(c_indirect)); + EXPECT_FALSE(Isa(c_indirect)); + + // After forwarding indirect async value to a concrete one it correctly + // returns true from Isa check. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(Isa(c_indirect)); + EXPECT_TRUE(Isa(c_indirect)); +} + +TEST(AsyncValueRefTest, DynCast) { + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(DynCast(a_ref)); + EXPECT_TRUE(DynCast(b_ref)); + EXPECT_TRUE(DynCast(c_ref)); + EXPECT_TRUE(DynCast(d_ref)); + + // No-op casts are always successful. + EXPECT_TRUE(DynCast(c_ref)); + + // We don't support casting to base (C inherits from B) because we can't do + // that safely relying just on AsyncValue type id. For safe conversion to base + // we need to introduce some kind of traits to the type hierarchy or rely on + // builtin `dynamic_cast` (will work only for constructed values). + EXPECT_FALSE(DynCast(c_ref)); + + // Types are unrelated, although they have same base. + EXPECT_FALSE(DynCast(d_ref)); + + // Error async value can be DynCast to any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(DynCast(err)); + EXPECT_TRUE(DynCast(err)); + EXPECT_TRUE(DynCast(err)); + EXPECT_TRUE(DynCast(err)); + + // If the value was constructed with a concrete type it should DynCast + // successfully even it it was set to error later but only if types match. + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(DynCast(a_err)); + EXPECT_TRUE(DynCast(b_err)); + EXPECT_FALSE(DynCast(a_err)); + + // Indirect async value can't be DynCast until it's forwarded unless it's a + // no-op DynCast to the same type. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(DynCast(c_indirect)); + EXPECT_FALSE(DynCast(c_indirect)); + + // After forwarding indirect async value to a concrete one it can be DynCast + // to a concrete type. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(DynCast(c_indirect)); + EXPECT_TRUE(DynCast(c_indirect)); +} + +TEST(AsyncValueRefTest, Cast) { + AsyncValueRef a_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef b_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef c_ref = MakeAvailableAsyncValueRef(); + AsyncValueRef d_ref = MakeAvailableAsyncValueRef(); + + EXPECT_TRUE(Cast(a_ref)); + EXPECT_TRUE(Cast(b_ref)); + EXPECT_TRUE(Cast(c_ref)); + EXPECT_TRUE(Cast(d_ref)); + + EXPECT_TRUE(Cast(c_ref)); + + // Error async value can be Cast to any type in the hierarchy. + AsyncValueRef err = MakeErrorAsyncValueRef(absl::InternalError("error")); + EXPECT_TRUE(Cast(err)); + EXPECT_TRUE(Cast(err)); + EXPECT_TRUE(Cast(err)); + EXPECT_TRUE(Cast(err)); + + // If the value was constructed with a concrete type it should Cast + // successfully even it it was set to error later but only if types match. + AsyncValueRef a_err = MakeConstructedAsyncValueRef(); + AsyncValueRef b_err = MakeConstructedAsyncValueRef(); + a_err.SetError(absl::InternalError("error")); + b_err.SetError(absl::InternalError("error")); + + EXPECT_TRUE(Cast(a_err)); + EXPECT_TRUE(Cast(b_err)); + + // Indirect async value can't be Cast until it's forwarded unless it's a + // no-op Cast to the same type. + auto indirect = MakeIndirectAsyncValue(); + AsyncValueRef c_indirect(indirect); + EXPECT_TRUE(Cast(c_indirect)); + + // After forwarding indirect async value to a concrete one it can be Cast + // to a concrete type. + indirect->ForwardTo(c_ref.CopyRCRef()); + EXPECT_TRUE(Cast(c_indirect)); + EXPECT_TRUE(Cast(c_indirect)); +} + } // namespace tsl diff --git a/third_party/tsl/tsl/concurrency/async_value_test.cc b/third_party/tsl/tsl/concurrency/async_value_test.cc index 8770e787b3bc6..67ec25d26137c 100644 --- a/third_party/tsl/tsl/concurrency/async_value_test.cc +++ b/third_party/tsl/tsl/concurrency/async_value_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tsl/concurrency/async_value.h" +#include #include #include +#include "absl/status/status.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/platform/test.h" diff --git a/third_party/tsl/tsl/concurrency/ref_count.h b/third_party/tsl/tsl/concurrency/ref_count.h index c10921ef2b6fb..1b3154021c6f3 100644 --- a/third_party/tsl/tsl/concurrency/ref_count.h +++ b/third_party/tsl/tsl/concurrency/ref_count.h @@ -19,11 +19,19 @@ limitations under the License. #include #include #include +#include #include #include namespace tsl { +namespace internal { +// TODO(ezhulenev): Replace with C++20 concept when available. +// https://en.cppreference.com/w/cpp/concepts/derived_from +template +using DerivedFrom = typename std::enable_if_t>; +} // namespace internal + #ifndef NDEBUG inline std::atomic total_reference_counted_objects; @@ -110,8 +118,7 @@ class ReferenceCounted { }; // This is a smart pointer that keeps the specified reference counted value -// around. It is move-only to avoid accidental copies, but it can be copied -// explicitly. +// around. template class RCReference { public: @@ -138,14 +145,12 @@ class RCReference { } // Support implicit conversion from RCReference to RCReference. - template ::value>> - RCReference(RCReference&& u) : pointer_(u.pointer_) { // NOLINT + template * = nullptr> + RCReference(RCReference&& u) : pointer_(u.pointer_) { // NOLINT u.pointer_ = nullptr; } - template ::value>> - RCReference(const RCReference& u) : pointer_(u.pointer_) { // NOLINT + template * = nullptr> + RCReference(const RCReference& u) : pointer_(u.pointer_) { // NOLINT if (pointer_) pointer_->AddRef(); } diff --git a/third_party/tsl/tsl/cuda/BUILD.bazel b/third_party/tsl/tsl/cuda/BUILD.bazel deleted file mode 100644 index 50c2846bd20e4..0000000000000 --- a/third_party/tsl/tsl/cuda/BUILD.bazel +++ /dev/null @@ -1,291 +0,0 @@ -# Description: -# Stubs for dynamically loading CUDA. - -load("//tsl/cuda:stub.bzl", "cuda_stub") -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) -load( - "//tsl/platform:rules_cc.bzl", - "cc_library", -) -load( - "//tsl/platform/default:cuda_build_defs.bzl", - "cuda_rpath_flags", - "if_cuda_is_configured", -) - -package( - licenses = ["notice"], -) - -cuda_stub( - name = "cublas", - srcs = ["cublas.symbols"], -) - -cc_library( - name = "cublas", - srcs = if_cuda_is_configured([ - "cublas_stub.cc", - "cublas.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags( - "nvidia/cublas/lib", - )), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cublas.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@com_google_absl//absl/container:flat_hash_set", - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cublasLt", - srcs = ["cublasLt.symbols"], -) - -cc_library( - name = "cublas_lt", - srcs = if_cuda_is_configured([ - "cublasLt_stub.cc", - "cublasLt.tramp.S", - ]), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cublasLt.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cuda", - srcs = ["cuda.symbols"], -) - -cc_library( - name = "cuda", - srcs = if_cuda_is_configured([ - "cuda_stub.cc", - "cuda.tramp.S", - ]), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cuda.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cudart", - srcs = ["cudart.symbols"], -) - -cc_library( - name = "cudart", - srcs = select({ - # include dynamic loading implementation only when if_cuda_is_configured and build dynamically - "//tsl:is_cuda_enabled_and_oss": [ - "cudart.tramp.S", - "cudart_stub.cc", - ], - "//conditions:default": [], - }), - linkopts = select({ - "//tsl:is_cuda_enabled_and_oss": cuda_rpath_flags("nvidia/cuda_runtime/lib"), - "//conditions:default": [], - }), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cudart.inc"], - visibility = ["//visibility:public"], - deps = select({ - "//tsl:is_cuda_enabled_and_oss": [ - ":cuda", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - "@com_google_absl//absl/container:flat_hash_set", - "@local_config_cuda//cuda:cuda_headers", - ], - "//conditions:default": [], - }), -) - -cuda_stub( - name = "cudnn", - srcs = ["cudnn.symbols"], -) - -cc_library( - name = "cudnn", - srcs = if_cuda_is_configured([ - "cudnn_stub.cc", - "cudnn.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cudnn/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cudnn.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@com_google_absl//absl/container:flat_hash_map", - "@local_config_cuda//cuda:cudnn_header", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cc_library( - name = "nccl_rpath", - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), - visibility = ["//visibility:public"], -) - -cc_library( - name = "tensorrt_rpath", - linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")), - visibility = ["//visibility:public"], -) - -cuda_stub( - name = "cufft", - srcs = ["cufft.symbols"], -) - -cc_library( - name = "cufft", - srcs = if_cuda_is_configured([ - "cufft_stub.cc", - "cufft.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cufft/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cufft.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cupti", - srcs = ["cupti.symbols"], -) - -cc_library( - name = "cupti", - srcs = if_cuda_is_configured([ - "cupti_stub.cc", - "cupti.tramp.S", - ]), - data = if_cuda_is_configured(["@local_config_cuda//cuda:cupti_dsos"]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cuda_cupti/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cupti.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cupti_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cusolver", - srcs = ["cusolver.symbols"], -) - -cc_library( - name = "cusolver", - srcs = if_cuda_is_configured([ - "cusolver_stub.cc", - "cusolver.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusolver/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cusolver.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "cusparse", - srcs = ["cusparse.symbols"], -) - -cc_library( - name = "cusparse", - srcs = if_cuda_is_configured([ - "cusparse_stub.cc", - "cusparse.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusparse/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["cusparse.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) - -cuda_stub( - name = "nccl", - srcs = ["nccl.symbols"], -) - -cc_library( - name = "nccl_stub", - srcs = if_cuda_is_configured([ - "nccl_stub.cc", - "nccl.tramp.S", - ]), - linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), - local_defines = [ - "IMPLIB_EXPORT_SHIMS=1", - ], - textual_hdrs = ["nccl.inc"], - visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@com_google_absl//absl/container:flat_hash_set", - "@local_config_cuda//cuda:cuda_headers", - "@local_config_nccl//:nccl_headers", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - ]), -) diff --git a/third_party/tsl/tsl/cuda/stub.bzl b/third_party/tsl/tsl/cuda/stub.bzl deleted file mode 100644 index 0dbfad0965808..0000000000000 --- a/third_party/tsl/tsl/cuda/stub.bzl +++ /dev/null @@ -1,26 +0,0 @@ -"""Macros to generate CUDA library stubs from a list of symbols.""" - -def cuda_stub(name, srcs): - """Generates a CUDA stub from a list of symbols. - - Generates two files: - * library.inc, which contains a list of symbols suitable for inclusion by - C++, and - * library.tramp.S, which contains assembly-language trampolines for each - symbol. - """ - native.genrule( - name = "{}_stub_gen".format(name), - srcs = srcs, - tools = ["//third_party/implib_so:make_stub"], - outs = [ - "{}.inc".format(name), - "{}.tramp.S".format(name), - ], - tags = ["gpu"], - cmd = select({ - "//tsl:linux_aarch64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target aarch64", - "//tsl:linux_x86_64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target x86_64", - "//conditions:default": "NOT_IMPLEMENTED_FOR_THIS_PLATFORM_OR_ARCHITECTURE", - }), - ) diff --git a/third_party/tsl/tsl/distributed_runtime/BUILD b/third_party/tsl/tsl/distributed_runtime/BUILD index cc384976a5887..8d8bc75119024 100644 --- a/third_party/tsl/tsl/distributed_runtime/BUILD +++ b/third_party/tsl/tsl/distributed_runtime/BUILD @@ -2,6 +2,7 @@ # Distributed runtime modules for machine learning, which allows coordination between multiple # processes for distributed operations. +load("//tsl:tsl.bzl", "internal_visibility") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -9,9 +10,9 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//tsl:internal", - ], + ]), licenses = ["notice"], ) diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/BUILD b/third_party/tsl/tsl/distributed_runtime/coordination/BUILD index fc4f0f419e532..fe434d10bd4d9 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/tsl/tsl/distributed_runtime/coordination/BUILD @@ -1,10 +1,10 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") +load("//tsl:tsl.bzl", "if_oss", "internal_visibility", "tsl_gpu_library") load("//tsl/platform:build_config.bzl", "tf_proto_library", "tsl_cc_test") -load("//tsl:tsl.bzl", "if_oss", "set_external_visibility", "tsl_gpu_library") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl:internal", ]), licenses = ["notice"], @@ -78,13 +78,13 @@ tsl_gpu_library( "//tsl/platform:thread_annotations", "//tsl/protobuf:coordination_config_proto_cc", "//tsl/protobuf:coordination_service_proto_cc", - "//tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@xla//xla/tsl/util:device_name_utils", ], alwayslink = 1, ) @@ -140,7 +140,6 @@ tsl_gpu_library( "//tsl/framework:cancellation", "//tsl/lib/monitoring:gauge", "//tsl/platform:env", - "//tsl/platform:errors", "//tsl/platform:mutex", "//tsl/platform:random", "//tsl/platform:status", @@ -149,6 +148,7 @@ tsl_gpu_library( "//tsl/protobuf:coordination_config_proto_cc", "//tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -169,11 +169,15 @@ tsl_cc_test( "//tsl/platform:errors", "//tsl/platform:protobuf", "//tsl/platform:status", + "//tsl/platform:statusor", "//tsl/platform:test", "//tsl/platform:test_main", "//tsl/protobuf:coordination_config_proto_cc_impl", "//tsl/protobuf:coordination_service_proto_cc_impl", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/time", ], ) diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc index 0b916e65aaa20..45dfb972a131f 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc @@ -26,11 +26,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/distributed_runtime/call_options.h" #include "tsl/distributed_runtime/coordination/coordination_client.h" #include "tsl/distributed_runtime/coordination/coordination_service_error_util.h" @@ -44,7 +46,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/protobuf/coordination_config.pb.h" #include "tsl/protobuf/coordination_service.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { namespace { @@ -62,6 +63,7 @@ constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds constexpr size_t kOngoingBarriersSoftLimit = 20; constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck"; constexpr int kPendingTaskLogLimit = 20; +constexpr int kPendingStragglerLogLimit = 3; std::string GetTaskName(absl::string_view job_name, int task_id) { return strings::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id); @@ -104,32 +106,36 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void SetDeviceAggregationFunction( std::function post_aggregate_device_fn) override; - Status RegisterTask(const CoordinatedTask& task, - uint64_t incarnation) override; + + void LogConnectStatusLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + + absl::Status RegisterTask(const CoordinatedTask& task, + uint64_t incarnation) override; void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices, StatusCallback done) override; void ShutdownTaskAsync(const CoordinatedTask& task, StatusCallback done) override; - Status ResetTask(const CoordinatedTask& task) override; - Status RecordHeartbeat(const CoordinatedTask& task, - uint64_t incarnation) override; - Status ReportTaskError(const CoordinatedTask& task, Status error) override; + absl::Status ResetTask(const CoordinatedTask& task) override; + absl::Status RecordHeartbeat(const CoordinatedTask& task, + uint64_t incarnation) override; + absl::Status ReportTaskError(const CoordinatedTask& task, + absl::Status error) override; std::vector GetTaskState( const std::vector& task) override; - Status InsertKeyValue(const std::string& key, - const std::string& value) override; + absl::Status InsertKeyValue(const std::string& key, + const std::string& value) override; void GetKeyValueAsync(const std::string& key, StatusOrValueCallback done) override; - StatusOr TryGetKeyValue(const std::string& key) override; + absl::StatusOr TryGetKeyValue(const std::string& key) override; std::vector GetKeyValueDir( absl::string_view directory_key) override; - Status DeleteKeyValue(const std::string& key) override; + absl::Status DeleteKeyValue(const std::string& key) override; void BarrierAsync(const std::string& barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) override; - Status CancelBarrier(const std::string& barrier_id, - const CoordinatedTask& task) override; + absl::Status CancelBarrier(const std::string& barrier_id, + const CoordinatedTask& task) override; private: const DeviceInfo& ListClusterDevices() override @@ -139,22 +145,22 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void Stop(bool shut_staleness_thread = true); // Report service error to a specified task. void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task, - Status error); + absl::Status error); // Report error from a task to all other connected tasks if the task is not // recoverable. // Note: SetTaskError() must be called before propagating its error. void PropagateError(const CoordinatedTask& source_task, bool is_reported_by_task = false) TF_LOCKS_EXCLUDED(state_mu_); - void SetTaskError(absl::string_view task_name, Status error) + void SetTaskError(absl::string_view task_name, absl::Status error) TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); void AggregateClusterDevices() TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - Status DisconnectTask(const CoordinatedTask& task) + absl::Status DisconnectTask(const CoordinatedTask& task) TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); struct BarrierState { bool passed = false; - Status result = errors::Unknown( + absl::Status result = errors::Unknown( "Invalid barrier result."); // Only valid if `passed` is true. uint64_t deadline_in_micros = 0; int num_pending_tasks = 0; @@ -164,7 +170,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { tasks_at_barrier; std::vector done_callbacks; }; - void PassBarrier(absl::string_view barrier_id, Status result, + void PassBarrier(absl::string_view barrier_id, absl::Status result, BarrierState* barrier) TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Check if participating tasks are specified correctly across barrier calls. @@ -188,17 +194,17 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // tasks in the cluster. CoordinatedTaskState GetState() { return state_; } - Status GetStatus() { return status_; } + absl::Status GetStatus() { return status_; } uint64_t GetTaskIncarnation() { return task_incarnation_; } void SetConnected(uint64_t task_incarnation); void Disconnect(uint64_t grace_period_duration_us); - Status RecordHeartbeat(uint64_t task_incarnation); + absl::Status RecordHeartbeat(uint64_t task_incarnation); int64_t TimeSinceLastHeartbeatMs(); // This denotes the deadline after which we stop accepting heartbeats from a // disconnected task. This grace period accounts for the lag time between // the service recording the state change and the agent stopping heartbeats. uint64_t GetDisconnectedGracePeriodMicros(); - void SetError(Status status); + void SetError(absl::Status status); DeviceInfo GetDeviceInfo() { return devices_; } void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; } // Checks if task has called WaitForAllTasks() previously, which gathers the @@ -214,7 +220,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { uint64_t task_incarnation_ = 0; CoordinatedTaskState state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; - Status status_; + absl::Status status_; mutex last_heartbeat_mu_; uint64_t last_heartbeat_us_ TF_GUARDED_BY(last_heartbeat_mu_); // This denotes the deadline after which we stop accepting heartbeats from a @@ -277,7 +283,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void CoordinationServiceStandaloneImpl::TaskState::SetConnected( uint64_t task_incarnation) { state_ = CoordinatedTaskState::TASKSTATE_CONNECTED; - status_ = OkStatus(); + status_ = absl::OkStatus(); task_incarnation_ = task_incarnation; mutex_lock l(last_heartbeat_mu_); last_heartbeat_us_ = Env::Default()->NowMicros(); @@ -288,17 +294,17 @@ void CoordinationServiceStandaloneImpl::TaskState::Disconnect( disconnect_grace_period_us_ = Env::Default()->NowMicros() + grace_period_duration_us; state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; - status_ = OkStatus(); + status_ = absl::OkStatus(); } void CoordinationServiceStandaloneImpl::TaskState::SetError( - const Status status) { + const absl::Status status) { if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return; state_ = CoordinatedTaskState::TASKSTATE_ERROR; status_ = status; } -Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( +absl::Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( uint64_t task_incarnation) { if (!status_.ok()) return status_; if (task_incarnation != task_incarnation_) { @@ -308,7 +314,7 @@ Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( } mutex_lock l(last_heartbeat_mu_); last_heartbeat_us_ = Env::Default()->NowMicros(); - return OkStatus(); + return absl::OkStatus(); } int64_t @@ -357,6 +363,7 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( absl::Milliseconds(config.shutdown_barrier_timeout_in_ms())), allow_new_incarnation_to_reconnect_( config.allow_new_incarnation_to_reconnect()) { + LOG(INFO) << "Initializing CoordinationService"; recoverable_jobs_ = absl::flat_hash_set( config.recoverable_jobs().cbegin(), config.recoverable_jobs().cend()); for (const auto& job : config.coordinated_job_list()) { @@ -386,7 +393,7 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { } } // Heartbeat check. - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); { mutex_lock l(state_mu_); for (const auto& [task_name, task_state] : cluster_state_) { @@ -457,7 +464,7 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { } } } - const Status error = + const absl::Status error = MakeCoordinationError(errors::DeadlineExceeded(absl::StrCat( "Barrier timed out. Barrier_id: ", barrier_id, ". Timed out task names:\n", pending_tasks))); @@ -498,7 +505,7 @@ void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { mutex_lock l(state_mu_); for (auto& [barrier_id, barrier] : barriers_) { if (!barrier.passed) { - Status error = MakeCoordinationError(errors::Aborted(absl::StrCat( + absl::Status error = MakeCoordinationError(errors::Aborted(absl::StrCat( "Barrier failed because service is shutting down. Barrier_id: ", barrier_id))); PassBarrier(barrier_id, error, &barrier); @@ -519,11 +526,31 @@ void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { } } -Status CoordinationServiceStandaloneImpl::RegisterTask( +// Helper to log progress to having waited for all tasks. +void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { + const int num_tasks = cluster_state_.size(); + int pending_tasks = 0; + std::vector task_names; + for (const auto& [task_name, task_state] : cluster_state_) { + if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + pending_tasks++; + if (task_names.size() < kPendingStragglerLogLimit) { + task_names.push_back(task_name); + } + } + } + LOG(INFO) << "Waiting for " << pending_tasks << "/" << num_tasks + << " tasks to connect."; + if (!task_names.empty()) { + LOG(INFO) << "Example stragglers:\n" << absl::StrJoin(task_names, "\n"); + } +} + +absl::Status CoordinationServiceStandaloneImpl::RegisterTask( const CoordinatedTask& task, uint64_t incarnation) { const std::string& task_name = GetTaskName(task); - Status error; + absl::Status error; std::string error_message; { mutex_lock l(state_mu_); @@ -553,7 +580,8 @@ Status CoordinationServiceStandaloneImpl::RegisterTask( LOG(INFO) << task_name << " has connected to coordination service. Incarnation: " << incarnation; - return OkStatus(); + LogConnectStatusLocked(); + return absl::OkStatus(); } else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) { // This may happen if the service processes the initial RegisterTask(), // but the agent did not receive the response so the agent retries again. @@ -565,7 +593,8 @@ Status CoordinationServiceStandaloneImpl::RegisterTask( LOG(INFO) << task_name << " has connected to coordination service with the same " << "incarnation again: " << incarnation; - return OkStatus(); + LogConnectStatusLocked(); + return absl::OkStatus(); } else { error_message = absl::StrCat(task_name, @@ -615,7 +644,7 @@ void CoordinationServiceStandaloneImpl::ShutdownTaskAsync( BarrierAsync(shutdown_barrier_id_, shutdown_barrier_timeout_, task, {}, done); } else { - Status status; + absl::Status status; { mutex_lock l(state_mu_); // Disconnect task from service individually. @@ -625,13 +654,13 @@ void CoordinationServiceStandaloneImpl::ShutdownTaskAsync( } } -Status CoordinationServiceStandaloneImpl::ResetTask( +absl::Status CoordinationServiceStandaloneImpl::ResetTask( const CoordinatedTask& task) { mutex_lock l(state_mu_); return DisconnectTask(task); } -Status CoordinationServiceStandaloneImpl::DisconnectTask( +absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( const CoordinatedTask& task) { const std::string task_name = GetTaskName(task); // Check if task is valid and not already disconnected. @@ -649,14 +678,14 @@ Status CoordinationServiceStandaloneImpl::DisconnectTask( /*grace_period_duration_us=*/heartbeat_timeout_ms_ * 1000); for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { - Status error = MakeCoordinationError(errors::Internal(absl::StrCat( + absl::Status error = MakeCoordinationError(errors::Internal(absl::StrCat( "Barrier failed from a disconnected task. Barrier Id: ", barrier_id, ", Task: ", task_name))); PassBarrier(barrier_id, error, &barriers_[barrier_id]); } LOG(INFO) << task_name << " has disconnected from coordination service."; - return OkStatus(); + return absl::OkStatus(); } const DeviceInfo& CoordinationServiceStandaloneImpl::ListClusterDevices() { @@ -667,8 +696,8 @@ uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() { return service_incarnation_; } -Status CoordinationServiceStandaloneImpl::ReportTaskError( - const CoordinatedTask& task, Status error) { +absl::Status CoordinationServiceStandaloneImpl::ReportTaskError( + const CoordinatedTask& task, absl::Status error) { const std::string& task_name = GetTaskName(task); { mutex_lock l(state_mu_); @@ -684,7 +713,7 @@ Status CoordinationServiceStandaloneImpl::ReportTaskError( } } PropagateError(task, /*is_reported_by_task=*/true); - return OkStatus(); + return absl::OkStatus(); } std::vector @@ -694,7 +723,7 @@ CoordinationServiceStandaloneImpl::GetTaskState( for (const auto& task : tasks) { const std::string task_name = GetTaskName(task); auto& state_info = states_info.emplace_back(); - Status error; + absl::Status error; { mutex_lock l(state_mu_); state_info.set_state(cluster_state_[task_name]->GetState()); @@ -711,10 +740,10 @@ CoordinationServiceStandaloneImpl::GetTaskState( return states_info; } -Status CoordinationServiceStandaloneImpl::RecordHeartbeat( +absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( const CoordinatedTask& task, uint64_t incarnation) { const std::string& task_name = GetTaskName(task); - Status s = OkStatus(); + absl::Status s = absl::OkStatus(); { mutex_lock l(state_mu_); if (!cluster_state_.contains(task_name)) { @@ -754,7 +783,7 @@ Status CoordinationServiceStandaloneImpl::RecordHeartbeat( } void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync( - const CoordinatedTask& destination_task, Status error) { + const CoordinatedTask& destination_task, absl::Status error) { assert(!error.ok()); // Don't report error if there is no service-to-client connection. @@ -777,7 +806,7 @@ void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync( CoordinationClient* client = client_cache_->GetClient(task_name); client->ReportErrorToTaskAsync( call_opts.get(), request.get(), response.get(), - [request, response, task_name, call_opts](Status s) { + [request, response, task_name, call_opts](absl::Status s) { if (!s.ok()) { LOG(ERROR) << "Encountered another error while reporting to " << task_name << ": " << s; @@ -790,7 +819,7 @@ void CoordinationServiceStandaloneImpl::PropagateError( // If the error task is recoverable, do not propagate the error to other // connected tasks. if (isRecoverableJob(source_task.job_name())) return; - Status error; + absl::Status error; { mutex_lock l(state_mu_); error = cluster_state_[GetTaskName(source_task)]->GetStatus(); @@ -836,7 +865,8 @@ void CoordinationServiceStandaloneImpl::PropagateError( auto response = std::make_shared(); auto n = std::make_shared(); client->ReportErrorToTaskAsync( - &call_opts, &request, response.get(), [response, n, task](Status s) { + &call_opts, &request, response.get(), + [response, n, task](absl::Status s) { if (!s.ok()) { LOG(ERROR) << "Encountered another error while reporting to " << task << ": " << s; @@ -878,7 +908,7 @@ std::string NormalizeKey(const StringPiece orig_key) { return norm_key; } -Status CoordinationServiceStandaloneImpl::InsertKeyValue( +absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue( const std::string& key, const std::string& value) { VLOG(3) << "InsertKeyValue(): " << key << ": " << value; const std::string& norm_key = NormalizeKey(key); @@ -895,7 +925,7 @@ Status CoordinationServiceStandaloneImpl::InsertKeyValue( } get_cb_.erase(iter); } - return OkStatus(); + return absl::OkStatus(); } void CoordinationServiceStandaloneImpl::GetKeyValueAsync( @@ -916,7 +946,7 @@ void CoordinationServiceStandaloneImpl::GetKeyValueAsync( cb_iter->second.emplace_back(std::move(done)); } -StatusOr CoordinationServiceStandaloneImpl::TryGetKeyValue( +absl::StatusOr CoordinationServiceStandaloneImpl::TryGetKeyValue( const std::string& key) { VLOG(3) << "TryGetKeyValue(): " << key; const std::string& norm_key = NormalizeKey(key); @@ -956,7 +986,7 @@ std::vector CoordinationServiceStandaloneImpl::GetKeyValueDir( return kvs_in_directory; } -Status CoordinationServiceStandaloneImpl::DeleteKeyValue( +absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( const std::string& key) { VLOG(3) << "DeleteKeyValue(): " << key; const std::string& norm_key = NormalizeKey(key); @@ -975,15 +1005,15 @@ Status CoordinationServiceStandaloneImpl::DeleteKeyValue( if (iter != kv_store_.end()) { kv_store_.erase(iter); } - return OkStatus(); + return absl::OkStatus(); } void CoordinationServiceStandaloneImpl::SetTaskError( - absl::string_view task_name, Status error) { + absl::string_view task_name, absl::Status error) { cluster_state_[task_name]->SetError(error); for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { - Status error = MakeCoordinationError(errors::Internal(absl::StrCat( + absl::Status error = MakeCoordinationError(errors::Internal(absl::StrCat( "Barrier failed from a task error. Barrier Id: ", barrier_id, ", Task: ", task_name))); PassBarrier(barrier_id, error, &barriers_[barrier_id]); @@ -1021,7 +1051,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( // barrier. const std::string task_name = GetTaskName(task); if (!cluster_state_.contains(task_name)) { - Status error = MakeCoordinationError(errors::InvalidArgument( + absl::Status error = MakeCoordinationError(errors::InvalidArgument( absl::StrCat("Unexpected task (", task_name, ") that is not in the cluster called the barrier. " "Barrier Id: ", @@ -1040,7 +1070,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( const std::string task_name = GetTaskName(pending_task.first); if (cluster_state_[task_name]->GetState() == CoordinatedTaskState::TASKSTATE_ERROR) { - Status error = MakeCoordinationError(errors::Internal( + absl::Status error = MakeCoordinationError(errors::Internal( absl::StrCat("Task (", task_name, ") is already in error before the barrier " "was called. Barrier Id: ", @@ -1071,7 +1101,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( if (barrier->passed) { // Special hook for shutdown barrier to disconnect task. if (barrier_id == shutdown_barrier_id_) { - Status s = DisconnectTask(task); + absl::Status s = DisconnectTask(task); // Return any errors from the disconnect attempt, otherwise return the // barrier status outside of this hook. if (!s.ok()) { @@ -1090,7 +1120,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( // Check if caller task is participating in the barrier. if (!barrier->tasks_at_barrier.contains(task)) { // Unexpected barrier call from a task not participating in the barrier. - Status error = MakeCoordinationError(errors::InvalidArgument( + absl::Status error = MakeCoordinationError(errors::InvalidArgument( absl::StrCat("A non-participating task (", GetTaskName(task), ") called the barrier: ", barrier_id))); PassBarrier(barrier_id, error, barrier); @@ -1100,8 +1130,9 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( // Check if task args are specified consistently across barrier calls. if (!ValidateTaskArgs(participating_tasks, barrier->tasks_at_barrier, cluster_state_.size())) { - Status error = MakeCoordinationError(errors::InvalidArgument(absl::StrCat( - "Conflicting tasks specified for the same barrier: ", barrier_id))); + absl::Status error = + MakeCoordinationError(errors::InvalidArgument(absl::StrCat( + "Conflicting tasks specified for the same barrier: ", barrier_id))); PassBarrier(barrier_id, error, barrier); return; } @@ -1113,13 +1144,13 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( --barrier->num_pending_tasks; if (barrier->num_pending_tasks == 0) { - PassBarrier(barrier_id, OkStatus(), barrier); + PassBarrier(barrier_id, absl::OkStatus(), barrier); return; } } } -Status CoordinationServiceStandaloneImpl::CancelBarrier( +absl::Status CoordinationServiceStandaloneImpl::CancelBarrier( const std::string& barrier_id, const CoordinatedTask& task) { mutex_lock l(state_mu_); auto [it, inserted] = barriers_.try_emplace(barrier_id); @@ -1137,17 +1168,17 @@ Status CoordinationServiceStandaloneImpl::CancelBarrier( } // Cancel barrier. - Status cancelled = MakeCoordinationError(errors::Cancelled(absl::StrCat( + absl::Status cancelled = MakeCoordinationError(errors::Cancelled(absl::StrCat( "Barrier (", barrier_id, ") is cancelled by task: ", GetTaskName(task)))); PassBarrier(barrier_id, cancelled, barrier); VLOG(3) << "Barrier (" << barrier_id << ") is cancelled."; - return OkStatus(); + return absl::OkStatus(); } // Mark barrier as passed. void CoordinationServiceStandaloneImpl::PassBarrier( - absl::string_view barrier_id, Status result, BarrierState* barrier) { + absl::string_view barrier_id, absl::Status result, BarrierState* barrier) { barrier->passed = true; barrier->result = result; VLOG(3) << "Barrier(" << barrier_id << ") has passed with status: " << result; @@ -1173,14 +1204,14 @@ void CoordinationServiceStandaloneImpl::PassBarrier( "crashed early or too slow / hanging. Check the logs for " "an earlier error to identify the root cause."; } - Status shutdown_error = MakeCoordinationError(errors::Internal( + absl::Status shutdown_error = MakeCoordinationError(errors::Internal( absl::StrCat("Shutdown barrier has been passed with status: '", barrier->result.ToString(), "', but this task is not at the barrier yet."))); for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { if (at_barrier) { // Disconnect tasks that reached the barrier. - Status disconnect_status = DisconnectTask(task); + absl::Status disconnect_status = DisconnectTask(task); if (!disconnect_status.ok()) { LOG(ERROR) << disconnect_status; } diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.h index d95d4f7f54d0d..b82261e6d30bc 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.h @@ -70,7 +70,7 @@ class CoordinationServiceInterface { std::unique_ptr cache)>; using StatusOrValueCallback = - std::function&)>; + std::function&)>; virtual ~CoordinationServiceInterface() = default; @@ -116,8 +116,8 @@ class CoordinationServiceInterface { // - InvalidArgument: Unexpected task request. // - Aborted: (1) task is in error state, or (2) task is in connected state // with a different incarnation, indicating that it restarted. - virtual Status RegisterTask(const tensorflow::CoordinatedTask& task, - uint64_t incarnation) = 0; + virtual absl::Status RegisterTask(const tensorflow::CoordinatedTask& task, + uint64_t incarnation) = 0; // Wait for all tasks to be up and running, and register local device // info. The callback is invoked when all tasks are up and registered, or some @@ -141,16 +141,16 @@ class CoordinationServiceInterface { // Possible service errors: // - InvalidArgument: Unexpected task request. // - FailedPrecondition: task has already disconnected. - virtual Status ResetTask(const tensorflow::CoordinatedTask& task) = 0; + virtual absl::Status ResetTask(const tensorflow::CoordinatedTask& task) = 0; // Update the heartbeat timestamp of a task. This should only be invoked on // the leader of the cluster. - virtual Status RecordHeartbeat(const tensorflow::CoordinatedTask& task, - uint64_t incarnation) = 0; + virtual absl::Status RecordHeartbeat(const tensorflow::CoordinatedTask& task, + uint64_t incarnation) = 0; // Set a task in error state permanently. - virtual Status ReportTaskError(const tensorflow::CoordinatedTask& task, - Status error) = 0; + virtual absl::Status ReportTaskError(const tensorflow::CoordinatedTask& task, + absl::Status error) = 0; // Get the state and the error status of the tasks. virtual std::vector GetTaskState( @@ -159,8 +159,8 @@ class CoordinationServiceInterface { // Insert a configuration key-value in the coordination service. // For now, a key-value can only be inserted once and cannot be updated. // The key-values are not persisted and will be lost if the leader fails. - virtual Status InsertKeyValue(const std::string& key, - const std::string& value) = 0; + virtual absl::Status InsertKeyValue(const std::string& key, + const std::string& value) = 0; // Get a configuration key-value from the coordination service. The `done` // callback is invoked when the key-value becomes available. @@ -169,7 +169,8 @@ class CoordinationServiceInterface { // Get a configuration key-value from the coordination service. If the key // does not exist, return NotFound error. - virtual StatusOr TryGetKeyValue(const std::string& key) = 0; + virtual absl::StatusOr TryGetKeyValue( + const std::string& key) = 0; // Gets all values under a directory (key). // A value is considered to be in the directory if its key is prefixed with @@ -180,7 +181,7 @@ class CoordinationServiceInterface { // Delete configuration key-value. If key is a directory, recursively clean // up all key-values under the directory. - virtual Status DeleteKeyValue(const std::string& key) = 0; + virtual absl::Status DeleteKeyValue(const std::string& key) = 0; // Blocks until all (or a subset of) tasks are at the barrier or the barrier // fails. @@ -222,8 +223,9 @@ class CoordinationServiceInterface { // CANCELLED error status. // Possible service errors: // - FailedPrecondition: Barrier has already been passed. - virtual Status CancelBarrier(const std::string& barrier_id, - const tensorflow::CoordinatedTask& task) = 0; + virtual absl::Status CancelBarrier( + const std::string& barrier_id, + const tensorflow::CoordinatedTask& task) = 0; private: friend class CoordinationServiceRpcHandler; diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 79065f7a9118a..5f65f8e861bd9 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tsl/distributed_runtime/coordination/coordination_service_agent.h" #include +#include +#include #include #include #include @@ -27,6 +29,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" @@ -39,9 +42,10 @@ limitations under the License. #include "tsl/framework/cancellation.h" #include "tsl/lib/monitoring/gauge.h" #include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/mutex.h" #include "tsl/platform/random.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/thread_annotations.h" #include "tsl/protobuf/coordination_config.pb.h" #include "tsl/protobuf/coordination_service.pb.h" @@ -68,70 +72,69 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { public: CoordinationServiceAgentImpl() = default; ~CoordinationServiceAgentImpl() override { - Status s = Shutdown(); + absl::Status s = Shutdown(); VLOG(3) << "Coordination agent dtor failed with status: " << s; } - Status Initialize(Env* env, const std::string& job_name, int task_id, - const CoordinationServiceConfig& configs, - std::unique_ptr leader_client, - StatusCallback error_fn) override; - Status Initialize(Env* env, const CoordinatedTask& task, - const CoordinationServiceConfig& configs, - std::unique_ptr leader_client, - StatusCallback error_fn) override; + absl::Status Initialize(Env* env, std::string_view job_name, int task_id, + const CoordinationServiceConfig& configs, + std::unique_ptr leader_client, + StatusCallback error_fn) override; + absl::Status Initialize(Env* env, const CoordinatedTask& task, + const CoordinationServiceConfig& configs, + std::unique_ptr leader_client, + StatusCallback error_fn) override; bool IsInitialized() override; bool IsConnected() override; bool IsError() override; - Status Connect() override; - Status WaitForAllTasks(const DeviceInfo& local_devices) override; + absl::Status Connect() override; + absl::Status WaitForAllTasks(const DeviceInfo& local_devices) override; const DeviceInfo& GetClusterDeviceInfo() override; - StatusOr GetOwnTask() override; - StatusOr> GetTaskState( + absl::StatusOr GetOwnTask() override; + absl::StatusOr> GetTaskState( const std::vector& task) override; - Status ReportError(const Status& error) override; - Status Shutdown() override; - Status Reset() override; - - StatusOr GetKeyValue(std::string_view key) override; - StatusOr GetKeyValue(const char* key, int64_t key_size) override; - StatusOr GetKeyValue(std::string_view key, - absl::Duration timeout) override; + absl::Status ReportError(const absl::Status& error) override; + absl::Status Shutdown() override; + absl::Status Reset() override; + + absl::StatusOr GetKeyValue(std::string_view key) override; + absl::StatusOr GetKeyValue(std::string_view key, + absl::Duration timeout) override; std::shared_ptr GetKeyValueAsync( std::string_view key, StatusOrValueCallback done) override; - StatusOr TryGetKeyValue(std::string_view key) override; - StatusOr> GetKeyValueDir( + absl::StatusOr TryGetKeyValue(std::string_view key) override; + absl::StatusOr> GetKeyValueDir( std::string_view key) override; void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) override; - Status InsertKeyValue(std::string_view key, std::string_view value) override; - Status InsertKeyValue(const char* key, int64_t key_size, const char* value, - int64_t value_size) override; - Status DeleteKeyValue(std::string_view key) override; - Status DeleteKeyValue(const char* key, int64_t key_size) override; - Status UpdateKeyValue(std::string_view key, std::string_view value) override; - - Status StartWatchKey(std::string_view key, - ChangedKeyValuesCallback on_change) override; - Status StopWatchKey(std::string_view key) override; - Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout, - const std::vector& tasks) override; - void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout, + absl::Status InsertKeyValue(std::string_view key, + std::string_view value) override; + absl::Status DeleteKeyValue(std::string_view key) override; + absl::Status UpdateKeyValue(std::string_view key, + std::string_view value) override; + + absl::Status StartWatchKey(std::string_view key, + ChangedKeyValuesCallback on_change) override; + absl::Status StopWatchKey(std::string_view key) override; + absl::Status WaitAtBarrier( + std::string_view barrier_id, absl::Duration timeout, + const std::vector& tasks) override; + void WaitAtBarrierAsync(std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done) override; - Status CancelBarrier(const std::string& barrier_id) override; - void CancelBarrierAsync(const std::string& barrier_id, + absl::Status CancelBarrier(std::string_view barrier_id) override; + void CancelBarrierAsync(std::string_view barrier_id, StatusCallback done) override; - StatusOr GetEnv() override; + absl::StatusOr GetEnv() override; protected: - void SetError(const Status& error) override; - Status ActivateWatch(std::string_view key, - const std::map&) override; + void SetError(const absl::Status& error) override; + absl::Status ActivateWatch( + std::string_view key, const std::map&) override; // Returns an error if agent is not running. If `allow_disconnected` is true, // returns OK even if the agent is in DISCONNECTED state. - Status ValidateRunningAgent(bool allow_disconnected = false); + absl::Status ValidateRunningAgent(bool allow_disconnected = false); void StopHeartbeat(); private: @@ -144,7 +147,7 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { mutable mutex state_mu_; CoordinatedTaskState state_ TF_GUARDED_BY(state_mu_) = CoordinatedTaskState::TASKSTATE_UNINITIALIZED; - Status status_ TF_GUARDED_BY(state_mu_) = OkStatus(); + absl::Status status_ TF_GUARDED_BY(state_mu_) = absl::OkStatus(); // Note: this set grows without bounds. For now, this is okay as most users // require < 100 barriers. If there is a use case that requires many barriers, // consider using a monotonic sequence number to track instead. @@ -166,18 +169,18 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { void operator=(const CoordinationServiceAgentImpl&) = delete; }; -Status CoordinationServiceAgentImpl::Initialize( - Env* env, const std::string& job_name, int task_id, +absl::Status CoordinationServiceAgentImpl::Initialize( + Env* env, std::string_view job_name, int task_id, const CoordinationServiceConfig& configs, std::unique_ptr leader_client, StatusCallback error_fn) { CoordinatedTask task; - task.set_job_name(job_name); + task.set_job_name(std::string(job_name)); task.set_task_id(task_id); return Initialize(env, task, configs, std::move(leader_client), error_fn); } -Status CoordinationServiceAgentImpl::Initialize( +absl::Status CoordinationServiceAgentImpl::Initialize( Env* env, const CoordinatedTask& task, const CoordinationServiceConfig& configs, std::unique_ptr leader_client, @@ -185,7 +188,7 @@ Status CoordinationServiceAgentImpl::Initialize( enabled_usage_metric->GetCell()->Set(true); mutex_lock l(state_mu_); if (state_ != CoordinatedTaskState::TASKSTATE_UNINITIALIZED) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Coordination service agent has already been initialized.")); } @@ -193,17 +196,17 @@ Status CoordinationServiceAgentImpl::Initialize( task_ = task; configs_ = configs; if (configs_.service_leader().empty()) { - return MakeCoordinationError(errors::InvalidArgument( + return MakeCoordinationError(absl::InvalidArgumentError( "CoordinationServiceAgent must be initialized with a valid leader.")); } leader_client_ = std::move(leader_client); if (leader_client_ == nullptr) { - return MakeCoordinationError(errors::InvalidArgument( + return MakeCoordinationError(absl::InvalidArgumentError( "CoordinationServiceAgent must have a valid leader client.")); } error_fn_ = error_fn; state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; - return OkStatus(); + return absl::OkStatus(); } bool CoordinationServiceAgentImpl::IsInitialized() { @@ -230,16 +233,17 @@ void CoordinationServiceAgentImpl::StopHeartbeat() { heartbeat_thread_.reset(); } -Status CoordinationServiceAgentImpl::Connect() { +absl::Status CoordinationServiceAgentImpl::Connect() { VLOG(3) << "Agent has started trying to Connect()."; { mutex_lock l(state_mu_); if (state_ != CoordinatedTaskState::TASKSTATE_DISCONNECTED) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Coordination service agent is not in DISCONNECTED state.")); } } - Status connect_status = errors::Unknown("Connection not attempted yet."); + absl::Status connect_status = + absl::UnknownError("Connection not attempted yet."); RegisterTaskRequest request; *request.mutable_source_task() = task_; request.set_incarnation(incarnation_id_); @@ -261,7 +265,7 @@ Status CoordinationServiceAgentImpl::Connect() { call_opts.SetTimeout(absl::ToInt64Milliseconds(deadline - absl::Now())); absl::Notification n; leader_client_->RegisterTaskAsync( - &call_opts, &request, &response, [&](Status s) { + &call_opts, &request, &response, [&](absl::Status s) { if (s.ok()) { leader_incarnation_ = response.leader_incarnation(); { @@ -315,13 +319,13 @@ Status CoordinationServiceAgentImpl::Connect() { call_opts.SetTimeout(heartbeat_interval_ms); while (true) { - Status status; + absl::Status status; absl::Notification n; // Heartbeat RPC implementation automatically retries to tolerate // transient network failures. VLOG(10) << "HeartbeatRequest: " << request.DebugString(); leader_client_->HeartbeatAsync(&call_opts, &request, &response, - [&](Status s) { + [&](absl::Status s) { status = s; n.Notify(); }); @@ -341,8 +345,8 @@ Status CoordinationServiceAgentImpl::Connect() { SetError(status); } else if (response.leader_incarnation() != leader_incarnation_) { SetError(MakeCoordinationError( - errors::Aborted("Leader incarnation ID mismatch: the " - "coordination leader has restarted."))); + absl::AbortedError("Leader incarnation ID mismatch: the " + "coordination leader has restarted."))); } // Send next heartbeat after an interval. { @@ -355,12 +359,12 @@ Status CoordinationServiceAgentImpl::Connect() { } } })); - return OkStatus(); + return absl::OkStatus(); } -Status CoordinationServiceAgentImpl::WaitForAllTasks( +absl::Status CoordinationServiceAgentImpl::WaitForAllTasks( const DeviceInfo& local_devices) { - Status agent_running_status = ValidateRunningAgent(); + absl::Status agent_running_status = ValidateRunningAgent(); if (!agent_running_status.ok()) { return agent_running_status; } @@ -369,12 +373,13 @@ Status CoordinationServiceAgentImpl::WaitForAllTasks( *request.mutable_device_info() = local_devices; VLOG(3) << "WaitForAllTasksRequest: " << request.DebugString(); WaitForAllTasksResponse response; - Status status; + absl::Status status; absl::Notification n; - leader_client_->WaitForAllTasksAsync(&request, &response, [&](Status s) { - status = s; - n.Notify(); - }); + leader_client_->WaitForAllTasksAsync(&request, &response, + [&](absl::Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); if (!status.ok()) { VLOG(3) << "WaitForAllTasksResponse: " << status; @@ -383,53 +388,55 @@ Status CoordinationServiceAgentImpl::WaitForAllTasks( } VLOG(3) << "WaitForAllTasksResponse: " << response.DebugString(); cluster_devices_ = response.device_info(); - return OkStatus(); + return absl::OkStatus(); } const DeviceInfo& CoordinationServiceAgentImpl::GetClusterDeviceInfo() { return cluster_devices_; } -StatusOr CoordinationServiceAgentImpl::GetOwnTask() { +absl::StatusOr CoordinationServiceAgentImpl::GetOwnTask() { if (!IsInitialized()) { - return MakeCoordinationError( - errors::FailedPrecondition("Agent has not been initialized; we do not " - "know the associated task yet.")); + return MakeCoordinationError(absl::FailedPreconditionError( + "Agent has not been initialized; we do not " + "know the associated task yet.")); } return task_; } -StatusOr> +absl::StatusOr> CoordinationServiceAgentImpl::GetTaskState( const std::vector& tasks) { GetTaskStateRequest request; *request.mutable_source_task() = {tasks.begin(), tasks.end()}; GetTaskStateResponse response; absl::Notification n; - StatusOr> result; - leader_client_->GetTaskStateAsync(&request, &response, [&](const Status& s) { - if (s.ok()) { - result = std::vector( - std::make_move_iterator(response.task_state().begin()), - std::make_move_iterator(response.task_state().end())); - } else { - result = s; - } - n.Notify(); - }); + absl::StatusOr> result; + leader_client_->GetTaskStateAsync( + &request, &response, [&](const absl::Status& s) { + if (s.ok()) { + result = std::vector( + std::make_move_iterator(response.task_state().begin()), + std::make_move_iterator(response.task_state().end())); + } else { + result = s; + } + n.Notify(); + }); n.WaitForNotification(); return result; } -Status CoordinationServiceAgentImpl::ReportError(const Status& error) { +absl::Status CoordinationServiceAgentImpl::ReportError( + const absl::Status& error) { { mutex_lock l(state_mu_); if (state_ == CoordinatedTaskState::TASKSTATE_UNINITIALIZED) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Coordination service agent must be initialized first before " "reporting error.")); } else if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Coordination service agent is already in error state.")); } } @@ -444,24 +451,26 @@ Status CoordinationServiceAgentImpl::ReportError(const Status& error) { ReportErrorToServiceResponse response; absl::Notification n; - leader_client_->ReportErrorToServiceAsync(&request, &response, [&](Status s) { - VLOG(5) << "ReportErrorToServiceResponse: " << s; - if (!s.ok()) { - LOG(ERROR) << "Encountered another error when reporting error to " - "coordination service: " - << s - << "\nThis is usually caused by an earlier error during " - "execution. Check the logs (this task or the leader) for " - "an earlier error to debug further."; - } - n.Notify(); - }); + leader_client_->ReportErrorToServiceAsync( + &request, &response, [&](absl::Status s) { + VLOG(5) << "ReportErrorToServiceResponse: " << s; + if (!s.ok()) { + LOG(ERROR) + << "Encountered another error when reporting error to " + "coordination service: " + << s + << "\nThis is usually caused by an earlier error during " + "execution. Check the logs (this task or the leader) for " + "an earlier error to debug further."; + } + n.Notify(); + }); n.WaitForNotification(); - return OkStatus(); + return absl::OkStatus(); } -Status CoordinationServiceAgentImpl::Shutdown() { - Status status = OkStatus(); +absl::Status CoordinationServiceAgentImpl::Shutdown() { + absl::Status status = absl::OkStatus(); bool is_connected = false; { mutex_lock l(state_mu_); @@ -482,7 +491,7 @@ Status CoordinationServiceAgentImpl::Shutdown() { absl::Notification n; leader_client_->ShutdownTaskAsync(&call_opts, &request, &response, - [&status, &n](Status s) { + [&status, &n](absl::Status s) { status = s; n.Notify(); }); @@ -513,7 +522,7 @@ Status CoordinationServiceAgentImpl::Shutdown() { "Check the logs (this task or the leader) for an earlier error to " "debug further."); status = - MakeCoordinationError(errors::FailedPrecondition(status_message)); + MakeCoordinationError(absl::FailedPreconditionError(status_message)); LOG(ERROR) << status_message; } state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; @@ -524,11 +533,11 @@ Status CoordinationServiceAgentImpl::Shutdown() { return status; } -Status CoordinationServiceAgentImpl::Reset() { +absl::Status CoordinationServiceAgentImpl::Reset() { { mutex_lock l(state_mu_); if (state_ != CoordinatedTaskState::TASKSTATE_ERROR) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Reset() failed: coordination service agent is not in ERROR state.")); } } @@ -538,12 +547,13 @@ Status CoordinationServiceAgentImpl::Reset() { VLOG(3) << "ResetTaskRequest: " << request.DebugString(); ResetTaskResponse response; - Status status; + absl::Status status; absl::Notification n; - leader_client_->ResetTaskAsync(&request, &response, [&status, &n](Status s) { - status = s; - n.Notify(); - }); + leader_client_->ResetTaskAsync(&request, &response, + [&status, &n](absl::Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); VLOG(3) << "ResetTaskResponse: " << status; if (!status.ok()) { @@ -565,30 +575,25 @@ Status CoordinationServiceAgentImpl::Reset() { return status; } -StatusOr CoordinationServiceAgentImpl::GetKeyValue( +absl::StatusOr CoordinationServiceAgentImpl::GetKeyValue( std::string_view key) { return GetKeyValue(key, /*timeout=*/absl::InfiniteDuration()); } -StatusOr CoordinationServiceAgentImpl::GetKeyValue( - const char* key, int64_t key_size) { - return GetKeyValue(std::string_view(key, key_size)); -} - -StatusOr CoordinationServiceAgentImpl::GetKeyValue( +absl::StatusOr CoordinationServiceAgentImpl::GetKeyValue( std::string_view key, absl::Duration timeout) { auto n = std::make_shared(); auto result = std::make_shared>(); - GetKeyValueAsync(key, - [n, result](const StatusOr& status_or_value) { - *result = status_or_value; - n->Notify(); - }); + GetKeyValueAsync( + key, [n, result](const absl::StatusOr& status_or_value) { + *result = status_or_value; + n->Notify(); + }); bool call_completed_before_timeout = n->WaitForNotificationWithTimeout(timeout); if (!call_completed_before_timeout) { VLOG(3) << "GetKeyValue(" << key << ") timed out after " << timeout; - return MakeCoordinationError(errors::DeadlineExceeded(absl::Substitute( + return MakeCoordinationError(absl::DeadlineExceededError(absl::Substitute( "GetKeyValue() timed out with key: $0 and duration: $1", key, absl::FormatDuration(timeout)))); } @@ -608,13 +613,13 @@ std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( const bool already_cancelled = !cancellation_manager_.RegisterCallback( token, [call_opts]() { call_opts->StartCancel(); }); if (already_cancelled) { - done(errors::Cancelled("GetKeyValueAsync() was cancelled.")); + done(absl::CancelledError("GetKeyValueAsync() was cancelled.")); return call_opts; } leader_client_->GetKeyValueAsync( call_opts.get(), request.get(), response.get(), [call_opts, request, response, done = std::move(done), - &cm = cancellation_manager_, token](const Status& s) { + &cm = cancellation_manager_, token](const absl::Status& s) { // RPC call has completed (no longer needs to be cancelled if agent is // destroyed). cm.TryDeregisterCallback(token); @@ -631,16 +636,16 @@ std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( return call_opts; } -StatusOr CoordinationServiceAgentImpl::TryGetKeyValue( +absl::StatusOr CoordinationServiceAgentImpl::TryGetKeyValue( std::string_view key) { absl::Notification n; - StatusOr result; + absl::StatusOr result; TryGetKeyValueRequest request; request.set_key(key.data(), key.size()); VLOG(3) << "TryGetKeyValueRequest: " << request.DebugString(); TryGetKeyValueResponse response; leader_client_->TryGetKeyValueAsync( - &request, &response, [&](const Status& s) { + &request, &response, [&](const absl::Status& s) { if (s.ok()) { result = response.kv().value(); VLOG(3) << "TryGetKeyValueResponse: " << result.value(); @@ -655,12 +660,13 @@ StatusOr CoordinationServiceAgentImpl::TryGetKeyValue( return result; } -StatusOr> +absl::StatusOr> CoordinationServiceAgentImpl::GetKeyValueDir(std::string_view key) { absl::Notification n; - StatusOr> result; + absl::StatusOr> result; GetKeyValueDirAsync( - key, [&n, &result](StatusOr> status_or_value) { + key, [&n, &result]( + absl::StatusOr> status_or_value) { result = std::move(status_or_value); n.Notify(); }); @@ -677,7 +683,7 @@ void CoordinationServiceAgentImpl::GetKeyValueDirAsync( auto response = std::make_shared(); leader_client_->GetKeyValueDirAsync( request.get(), response.get(), - [request, response, done = std::move(done)](const Status& s) { + [request, response, done = std::move(done)](const absl::Status& s) { if (!s.ok()) { done(s); VLOG(3) << "GetKeyValueDirResponse: " << s; @@ -691,17 +697,17 @@ void CoordinationServiceAgentImpl::GetKeyValueDirAsync( }); } -Status CoordinationServiceAgentImpl::InsertKeyValue(std::string_view key, - std::string_view value) { +absl::Status CoordinationServiceAgentImpl::InsertKeyValue( + std::string_view key, std::string_view value) { InsertKeyValueRequest request; request.mutable_kv()->set_key(key.data(), key.size()); request.mutable_kv()->set_value(value.data(), value.size()); VLOG(3) << "InsertKeyValueRequest: " << request.DebugString(); InsertKeyValueResponse response; - Status status; + absl::Status status; absl::Notification n; - leader_client_->InsertKeyValueAsync(&request, &response, [&](Status s) { + leader_client_->InsertKeyValueAsync(&request, &response, [&](absl::Status s) { status = s; n.Notify(); }); @@ -710,56 +716,44 @@ Status CoordinationServiceAgentImpl::InsertKeyValue(std::string_view key, return status; } -Status CoordinationServiceAgentImpl::InsertKeyValue(const char* key, - int64_t key_size, - const char* value, - int64_t value_size) { - return InsertKeyValue(std::string_view(key, key_size), - std::string_view(value, value_size)); -} - -Status CoordinationServiceAgentImpl::DeleteKeyValue(std::string_view key) { +absl::Status CoordinationServiceAgentImpl::DeleteKeyValue( + std::string_view key) { DeleteKeyValueRequest request; request.set_key(key.data(), key.size()); request.set_is_directory(true); VLOG(3) << "DeleteKeyValueRequest: " << request.DebugString(); DeleteKeyValueResponse response; - Status status; + absl::Status status; absl::Notification n; - leader_client_->DeleteKeyValueAsync(&request, &response, [&](Status s) { + leader_client_->DeleteKeyValueAsync(&request, &response, [&](absl::Status s) { status = s; n.Notify(); }); n.WaitForNotification(); VLOG(3) << "DeleteKeyValueResponse " << status; - return OkStatus(); -} - -Status CoordinationServiceAgentImpl::DeleteKeyValue(const char* key, - int64_t key_size) { - return DeleteKeyValue(std::string_view(key, key_size)); + return absl::OkStatus(); } -Status CoordinationServiceAgentImpl::UpdateKeyValue(std::string_view key, - std::string_view value) { - return MakeCoordinationError(errors::Unimplemented( +absl::Status CoordinationServiceAgentImpl::UpdateKeyValue( + std::string_view key, std::string_view value) { + return MakeCoordinationError(absl::UnimplementedError( "CoordinationServiceAgent::UpdateKeyValue is not implemented.")); } -Status CoordinationServiceAgentImpl::StartWatchKey( +absl::Status CoordinationServiceAgentImpl::StartWatchKey( std::string_view key, CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) { - return MakeCoordinationError(errors::Unimplemented( + return MakeCoordinationError(absl::UnimplementedError( "CoordinationServiceAgent::StartWatchKey is not implemented.")); } -Status CoordinationServiceAgentImpl::StopWatchKey(std::string_view key) { - return MakeCoordinationError(errors::Unimplemented( +absl::Status CoordinationServiceAgentImpl::StopWatchKey(std::string_view key) { + return MakeCoordinationError(absl::UnimplementedError( "CoordinationServiceAgent::StopWatchKey is not implemented.")); } -void CoordinationServiceAgentImpl::SetError(const Status& error) { +void CoordinationServiceAgentImpl::SetError(const absl::Status& error) { assert(!error.ok()); mutex_lock l(state_mu_); if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return; @@ -770,18 +764,18 @@ void CoordinationServiceAgentImpl::SetError(const Status& error) { error_fn_(error); } -Status CoordinationServiceAgentImpl::ActivateWatch( +absl::Status CoordinationServiceAgentImpl::ActivateWatch( std::string_view key, const std::map& kvs) { - return MakeCoordinationError(errors::Unimplemented( + return MakeCoordinationError(absl::UnimplementedError( "CoordinationServiceAgent::ActivateWatch is not implemented.")); } -Status CoordinationServiceAgentImpl::WaitAtBarrier( - const std::string& barrier_id, absl::Duration timeout, +absl::Status CoordinationServiceAgentImpl::WaitAtBarrier( + std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks) { - Status status; + absl::Status status; absl::Notification n; - WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](Status s) { + WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](absl::Status s) { status = s; n.Notify(); }); @@ -790,9 +784,9 @@ Status CoordinationServiceAgentImpl::WaitAtBarrier( } void CoordinationServiceAgentImpl::WaitAtBarrierAsync( - const std::string& barrier_id, absl::Duration timeout, + std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done) { - Status agent_running_status = + absl::Status agent_running_status = ValidateRunningAgent(/*allow_disconnected=*/true); if (!agent_running_status.ok()) { done(agent_running_status); @@ -800,35 +794,35 @@ void CoordinationServiceAgentImpl::WaitAtBarrierAsync( } { mutex_lock l(state_mu_); - auto [it, inserted] = used_barrier_ids_.insert(barrier_id); + auto [it, inserted] = used_barrier_ids_.insert(std::string(barrier_id)); if (!inserted) { - done(errors::FailedPrecondition( + done(absl::FailedPreconditionError(absl::StrCat( "WaitAtBarrier() should not be called with the same id more than " "once. Barrier id: ", - barrier_id)); + barrier_id))); return; } } auto request = std::make_shared(); auto response = std::make_shared(); - request->set_barrier_id(barrier_id); + request->set_barrier_id(std::string(barrier_id)); request->set_barrier_timeout_in_ms(timeout / absl::Milliseconds(1)); *request->mutable_source_task() = task_; *request->mutable_tasks() = {tasks.begin(), tasks.end()}; VLOG(3) << "WaitAtBarrierRequest: " << request->DebugString(); leader_client_->BarrierAsync( request.get(), response.get(), - [request, response, done = std::move(done)](const Status& s) { + [request, response, done = std::move(done)](const absl::Status& s) { done(s); VLOG(3) << "WaitAtBarrierResponse: " << s; }); } -Status CoordinationServiceAgentImpl::CancelBarrier( - const std::string& barrier_id) { - Status status; +absl::Status CoordinationServiceAgentImpl::CancelBarrier( + std::string_view barrier_id) { + absl::Status status; absl::Notification n; - CancelBarrierAsync(barrier_id, [&](const Status& s) { + CancelBarrierAsync(barrier_id, [&](const absl::Status& s) { status = s; n.Notify(); }); @@ -837,8 +831,8 @@ Status CoordinationServiceAgentImpl::CancelBarrier( } void CoordinationServiceAgentImpl::CancelBarrierAsync( - const std::string& barrier_id, StatusCallback done) { - Status agent_running_status = + std::string_view barrier_id, StatusCallback done) { + absl::Status agent_running_status = ValidateRunningAgent(/*allow_disconnected=*/true); if (!agent_running_status.ok()) { done(agent_running_status); @@ -846,53 +840,53 @@ void CoordinationServiceAgentImpl::CancelBarrierAsync( } auto request = std::make_shared(); auto response = std::make_shared(); - request->set_barrier_id(barrier_id); + request->set_barrier_id(std::string(barrier_id)); *request->mutable_source_task() = task_; VLOG(3) << "CancelBarrierRequest: " << request->DebugString(); leader_client_->CancelBarrierAsync( request.get(), response.get(), - [request, response, done = std::move(done)](const Status& s) { + [request, response, done = std::move(done)](const absl::Status& s) { done(s); VLOG(3) << "CancelBarrierResponse: " << s; }); } // Returns an error if agent is not running. -Status CoordinationServiceAgentImpl::ValidateRunningAgent( +absl::Status CoordinationServiceAgentImpl::ValidateRunningAgent( bool allow_disconnected) { mutex_lock l(state_mu_); switch (state_) { case CoordinatedTaskState::TASKSTATE_CONNECTED: - return OkStatus(); + return absl::OkStatus(); case CoordinatedTaskState::TASKSTATE_UNINITIALIZED: - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Agent must be in CONNECTED state. It is currently UNINITIALIZED.")); case CoordinatedTaskState::TASKSTATE_DISCONNECTED: - if (allow_disconnected) return OkStatus(); - return MakeCoordinationError(errors::FailedPrecondition( + if (allow_disconnected) return absl::OkStatus(); + return MakeCoordinationError(absl::FailedPreconditionError( "Agent must be in CONNECTED state. It is currently DISCONNECTED.")); case CoordinatedTaskState::TASKSTATE_ERROR: - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Agent must be in CONNECTED state. It is currently in ERROR.")); default: - return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat( + return MakeCoordinationError(absl::FailedPreconditionError(absl::StrCat( "Agent is not in CONNECTED state. Current state: ", state_))); } } -StatusOr CoordinationServiceAgentImpl::GetEnv() { +absl::StatusOr CoordinationServiceAgentImpl::GetEnv() { if (!IsInitialized()) { - return MakeCoordinationError(errors::FailedPrecondition( + return MakeCoordinationError(absl::FailedPreconditionError( "Coordination service agent has not been initialized.")); } if (env_ == nullptr) { return MakeCoordinationError( - errors::FailedPrecondition("Coordination service agent was not " - "initialized with a valid Env* object.")); + absl::FailedPreconditionError("Coordination service agent was not " + "initialized with a valid Env* object.")); } return env_; } diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h index f94e6ac9dcb20..6c31eccffd10f 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ #define TENSORFLOW_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ +#include #include #include #include @@ -52,35 +53,34 @@ class Env; // reported by the user via `agent->ReportError()`. // // Possible service errors: -// - errors::Internal: Coordination service is not enabled. +// - Internal: Coordination service is not enabled. // If it was previously accessible, coordination service // has been shut down. -// - errors::Aborted: Incarnation mismatch during heartbeat (either remote +// - Aborted: Incarnation mismatch during heartbeat (either remote // task or coordination service has restarted). -// - errors::Unavailable: Heartbeat timeout from remote task (failed, +// - Unavailable: Heartbeat timeout from remote task (failed, // crashed or got preempted). -// - errors::InvalidArgument: Unexpected heartbeat from remote task (not +// - InvalidArgument: Unexpected heartbeat from remote task (not // registered or wrong config). -// TODO(hanyangtay): Migrate to string_view for string parameters. class CoordinationServiceAgent { public: using StatusOrValueCallback = - std::function&)>; + std::function&)>; // Collection of key-value pairs in the same directory. using StatusOrValueDirCallback = std::function>&)>; + const absl::StatusOr>&)>; using ChangedKeyValuesCallback = std::function&)>; virtual ~CoordinationServiceAgent() = default; // Initialize coordination service agent. - virtual Status Initialize( - tsl::Env* env, const std::string& job_name, int task_id, + virtual absl::Status Initialize( + tsl::Env* env, std::string_view job_name, int task_id, const tensorflow::CoordinationServiceConfig& configs, std::unique_ptr leader_client, StatusCallback error_fn) = 0; - virtual Status Initialize( + virtual absl::Status Initialize( tsl::Env* env, const tensorflow::CoordinatedTask& task, const tensorflow::CoordinationServiceConfig& configs, std::unique_ptr leader_client, @@ -105,14 +105,14 @@ class CoordinationServiceAgent { // - InvalidArgument: Unexpected task registration // - Aborted: Duplicate task registration (agent will retry connecting until // the configured timeout) - virtual Status Connect() = 0; + virtual absl::Status Connect() = 0; // Wait for all tasks to be up and registered. The call blocks until all tasks // in the cluster are up, or some error occurs. // Possible service errors: // - FailedPrecondition: Agent is not in CONNECTED state. // - InvalidArgument: Unexpected task request - virtual Status WaitForAllTasks( + virtual absl::Status WaitForAllTasks( const tensorflow::DeviceInfo& local_devices) = 0; // Get the device attributes of tasks from remote tasks in the cluster. @@ -127,10 +127,10 @@ class CoordinationServiceAgent { // Reset // Get task associated with this agent. - virtual StatusOr GetOwnTask() = 0; + virtual absl::StatusOr GetOwnTask() = 0; // Get status of a remote task. - virtual StatusOr> + virtual absl::StatusOr> GetTaskState(const std::vector& task) = 0; // Report error to coordination service. This will invoke the error callback. @@ -139,7 +139,7 @@ class CoordinationServiceAgent { // Possible service errors: // - FailedPrecondition: Uninitialized/disconnected/already in error state. // - InvalidArgument: Unexpected task request - virtual Status ReportError(const Status& error) = 0; + virtual absl::Status ReportError(const absl::Status& error) = 0; // Shuts down by disconnecting from the service. Should only be called if // agent is connected and no further agent calls (except the destructor) are @@ -151,14 +151,14 @@ class CoordinationServiceAgent { // - InvalidArgument: Unexpected task request. // - FailedPrecondition: Task was in error state (note: agent is still // shut down forcefully). - virtual Status Shutdown() = 0; + virtual absl::Status Shutdown() = 0; // Disconnect from the service, and clean up the internal error status. // Possible service errors: // - InvalidArgument: Unexpected task request. // - FailedPrecondition: task is not in error state/has already // disconnected. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; // Key-value store API. // The agent does not need to be connected to utilize the key-value store. @@ -168,51 +168,46 @@ class CoordinationServiceAgent { // Get config key-value from the service. // If the key-value is not inserted yet, this is a blocking call that waits // until the corresponding key is inserted. - // - errors::DeadlineExceeded: timed out waiting for key. - virtual StatusOr GetKeyValue(std::string_view key) = 0; - virtual StatusOr GetKeyValue(const char* key, - int64_t key_size) = 0; - virtual StatusOr GetKeyValue(std::string_view key, - absl::Duration timeout) = 0; + // - DeadlineExceeded: timed out waiting for key. + virtual absl::StatusOr GetKeyValue(std::string_view key) = 0; + virtual absl::StatusOr GetKeyValue(std::string_view key, + absl::Duration timeout) = 0; // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and // `call_opts->ClearCancelCallback()`. virtual std::shared_ptr GetKeyValueAsync( std::string_view, StatusOrValueCallback done) = 0; // Get config key-value from the service. - // - errors::NotFound: the requested key does not exist. - virtual StatusOr TryGetKeyValue(std::string_view key) = 0; + // - NotFound: the requested key does not exist. + virtual absl::StatusOr TryGetKeyValue(std::string_view key) = 0; // Get all values under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. - virtual StatusOr> GetKeyValueDir( + virtual absl::StatusOr> GetKeyValueDir( std::string_view key) = 0; virtual void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) = 0; // Insert config key-value to the service. - // - errors::AlreadyExists: key is already set. - virtual Status InsertKeyValue(std::string_view key, - std::string_view value) = 0; - virtual Status InsertKeyValue(const char* key, int64_t key_size, - const char* value, int64_t value_size) = 0; + // - AlreadyExists: key is already set. + virtual absl::Status InsertKeyValue(std::string_view key, + std::string_view value) = 0; // Delete config keys in the coordination service. - virtual Status DeleteKeyValue(std::string_view key) = 0; - virtual Status DeleteKeyValue(const char* key, int64_t key_size) = 0; + virtual absl::Status DeleteKeyValue(std::string_view key) = 0; // Update the value of a config key. - virtual Status UpdateKeyValue(std::string_view key, - std::string_view value) = 0; + virtual absl::Status UpdateKeyValue(std::string_view key, + std::string_view value) = 0; // Register a callback that will be invoked when the key or keys under the key // directory are changed (inserted, deleted, or updated). - virtual Status StartWatchKey(std::string_view key, - ChangedKeyValuesCallback on_change) = 0; - virtual Status StopWatchKey(std::string_view key) = 0; + virtual absl::Status StartWatchKey(std::string_view key, + ChangedKeyValuesCallback on_change) = 0; + virtual absl::Status StopWatchKey(std::string_view key) = 0; // Blocks until all (or a subset of) tasks are at the barrier or the barrier // fails. @@ -246,12 +241,12 @@ class CoordinationServiceAgent { // list of participating tasks. // - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state. Or the // same barrier_id was already used previously. - virtual Status WaitAtBarrier( - const std::string& barrier_id, absl::Duration timeout, + virtual absl::Status WaitAtBarrier( + std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks) = 0; virtual void WaitAtBarrierAsync( - const std::string& barrier_id, absl::Duration timeout, + std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done) = 0; @@ -260,22 +255,22 @@ class CoordinationServiceAgent { // CANCELLED error status. // Possible service errors: // - FailedPrecondition: Barrier has already been passed. - virtual Status CancelBarrier(const std::string& barrier_id) = 0; - virtual void CancelBarrierAsync(const std::string& barrier_id, + virtual absl::Status CancelBarrier(std::string_view barrier_id) = 0; + virtual void CancelBarrierAsync(std::string_view barrier_id, StatusCallback done) = 0; // Get unowned Env* that the agent was initialized with. - virtual StatusOr GetEnv() = 0; + virtual absl::StatusOr GetEnv() = 0; protected: // Set the service agent to error status and invoke the error callback. // Note: different from ReportError, this does not report the error status to // remote coordination service. - virtual void SetError(const Status& error) = 0; + virtual void SetError(const absl::Status& error) = 0; // Activate the key-value callback watch. - virtual Status ActivateWatch(std::string_view, - const std::map&) = 0; + virtual absl::Status ActivateWatch( + std::string_view, const std::map&) = 0; private: friend class CoordinationServiceRpcHandler; diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index bbbc3f01d0b9a..60b18a033dd34 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -16,11 +16,15 @@ limitations under the License. #include "tsl/distributed_runtime/coordination/coordination_service_agent.h" #include +#include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "tsl/distributed_runtime/call_options.h" @@ -29,6 +33,7 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/coordination_config.pb.h" #include "tsl/protobuf/coordination_service.pb.h" @@ -42,7 +47,6 @@ using tensorflow::KeyValueEntry; using ::testing::_; using ::testing::DoAll; using ::testing::InvokeArgument; -using ::testing::Pointee; using ::testing::SetArgPointee; using ::testing::UnorderedPointwise; using ::testing::WithArgs; @@ -60,8 +64,8 @@ class ProtoStringMatcher { return p.DebugString() == expected_; } - void DescribeTo(::std::ostream* os) const { *os << expected_; } - void DescribeNegationTo(::std::ostream* os) const { + void DescribeTo(std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(std::ostream* os) const { *os << "not equal to expected message: " << expected_; } @@ -69,11 +73,6 @@ class ProtoStringMatcher { const std::string expected_; }; -inline ::testing::PolymorphicMatcher EqualsProto( - const tsl::protobuf::Message& x) { - return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); -} - MATCHER(KvEq, "simple KeyValueEntry matcher") { const KeyValueEntry& kv0 = std::get<0>(arg); const KeyValueEntry& kv1 = std::get<1>(arg); @@ -141,7 +140,7 @@ class TestCoordinationClient : public CoordinationClient { void method##Async(const method##Request* request, \ method##Response* response, StatusCallback done) \ override { \ - done(errors::Unimplemented(#method "Async")); \ + done(absl::UnimplementedError(#method "Async")); \ } UNIMPLEMENTED(WaitForAllTasks); @@ -151,7 +150,7 @@ class TestCoordinationClient : public CoordinationClient { const ReportErrorToTaskRequest* request, ReportErrorToTaskResponse* response, StatusCallback done) override { - done(errors::Unimplemented("ReportErrorToTaskAsync")); + done(absl::UnimplementedError("ReportErrorToTaskAsync")); } }; @@ -159,19 +158,19 @@ class CoordinationServiceAgentTest : public ::testing::Test { public: void SetUp() override { ON_CALL(*client_, RegisterTaskAsync(_, _, _, _)) - .WillByDefault(InvokeArgument<3>(OkStatus())); + .WillByDefault(InvokeArgument<3>(absl::OkStatus())); ON_CALL(*client_, HeartbeatAsync(_, _, _, _)) - .WillByDefault(InvokeArgument<3>(OkStatus())); + .WillByDefault(InvokeArgument<3>(absl::OkStatus())); ON_CALL(*client_, ShutdownTaskAsync(_, _, _, _)) - .WillByDefault(InvokeArgument<3>(OkStatus())); + .WillByDefault(InvokeArgument<3>(absl::OkStatus())); ON_CALL(*client_, ReportErrorToServiceAsync(_, _, _)) - .WillByDefault(InvokeArgument<2>(OkStatus())); + .WillByDefault(InvokeArgument<2>(absl::OkStatus())); ON_CALL(*client_, ResetTaskAsync(_, _, _)) - .WillByDefault(InvokeArgument<2>(OkStatus())); + .WillByDefault(InvokeArgument<2>(absl::OkStatus())); ON_CALL(*client_, BarrierAsync(_, _, _)) - .WillByDefault(InvokeArgument<2>(OkStatus())); + .WillByDefault(InvokeArgument<2>(absl::OkStatus())); ON_CALL(*client_, GetTaskStateAsync(_, _, _)) - .WillByDefault(InvokeArgument<2>(OkStatus())); + .WillByDefault(InvokeArgument<2>(absl::OkStatus())); } // Should be called after mocking service responses, before testing the agent. @@ -180,7 +179,7 @@ class CoordinationServiceAgentTest : public ::testing::Test { TF_ASSERT_OK(agent_->Initialize( Env::Default(), /*job_name=*/"test_job", /*task_id=*/0, config, std::move(client_), - /*error_fn=*/[](Status s) { + /*error_fn=*/[](absl::Status s) { LOG(ERROR) << "Coordination agent is set to error: " << s; })); } @@ -209,7 +208,7 @@ TEST_F(CoordinationServiceAgentTest, GetKeyValue_Simple_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _)) .WillByDefault(DoAll(SetArgPointee<2>(mocked_response), - InvokeArgument<3>(OkStatus()))); + InvokeArgument<3>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -229,7 +228,7 @@ TEST_F(CoordinationServiceAgentTest, GetKeyValue_WithTimeout_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), GetKeyValueAsync(_, _, _, _)) .WillByDefault(DoAll(SetArgPointee<2>(mocked_response), - InvokeArgument<3>(OkStatus()))); + InvokeArgument<3>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -254,7 +253,7 @@ TEST_F(CoordinationServiceAgentTest, GetKeyValue_Timeout_ReturnError) { EXPECT_EQ(result.status().code(), error::DEADLINE_EXCEEDED); // Needed to tear down test safely since agent dtor would cancel pending // calls, which would reference deallocated call_opts. - owned_done(errors::Cancelled("error")); + owned_done(absl::CancelledError("error")); } TEST_F(CoordinationServiceAgentTest, @@ -282,7 +281,7 @@ TEST_F(CoordinationServiceAgentTest, auto kv = owned_response->mutable_kv(); kv->set_key(test_key); kv->set_value(test_value); - owned_done(OkStatus()); + owned_done(absl::OkStatus()); // No explicit test, but used to verify there is no stack-use-after-return // or other memory-related errors. } @@ -314,7 +313,7 @@ TEST_F(CoordinationServiceAgentTest, auto kv = owned_response->mutable_kv(); kv->set_key(test_key); kv->set_value(test_value); - owned_done(OkStatus()); + owned_done(absl::OkStatus()); })); })); InitializeAgent(); @@ -332,19 +331,19 @@ TEST_F(CoordinationServiceAgentTest, CancelGetKeyValue_Success) { WithArgs<0, 3>([](CallOptions* call_opts, StatusCallback done) { // Mock RPC call cancellation. call_opts->SetCancelCallback([callback = std::move(done)]() { - callback(errors::Cancelled("RPC call cancelled.")); + callback(absl::CancelledError("RPC call cancelled.")); }); })); InitializeAgent(); - Status status; + absl::Status status; std::shared_ptr get_kv_call_opts = agent_->GetKeyValueAsync( - test_key, [&status](const StatusOr& result) { + test_key, [&status](const absl::StatusOr& result) { status = result.status(); }); get_kv_call_opts->StartCancel(); - EXPECT_TRUE(errors::IsCancelled(status)) << status; + EXPECT_TRUE(absl::IsCancelled(status)) << status; // This is to prevent memory leaks due to how we set this particular cancel // callback. In practice, this should not be necessary. get_kv_call_opts->ClearCancelCallback(); @@ -360,7 +359,7 @@ TEST_F(CoordinationServiceAgentTest, TryGetKeyValue_Simple_Success) { kv->set_value(test_value); ON_CALL(*GetClient(), TryGetKeyValueAsync(_, _, _)) .WillByDefault(DoAll(SetArgPointee<1>(mocked_response), - InvokeArgument<2>(OkStatus()))); + InvokeArgument<2>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -380,7 +379,7 @@ TEST_F(CoordinationServiceAgentTest, GetKeyValueDir_Simple_Success) { *mocked_response.mutable_kv() = {test_values.begin(), test_values.end()}; ON_CALL(*GetClient(), GetKeyValueDirAsync(_, _, _)) .WillByDefault(DoAll(SetArgPointee<1>(mocked_response), - InvokeArgument<2>(OkStatus()))); + InvokeArgument<2>(absl::OkStatus()))); // Initialize coordination agent. InitializeAgent(); @@ -390,88 +389,16 @@ TEST_F(CoordinationServiceAgentTest, GetKeyValueDir_Simple_Success) { EXPECT_THAT(*result, UnorderedPointwise(KvEq(), test_values)); } -TEST_F(CoordinationServiceAgentTest, - InsertKeyValue_EarlyNullCharacter_Success) { - // Note: this API is passing C-strings, but users might pass a string with - // a null character in the middle due to their specific serialization / - // encoding mechanism (e.g. protobuf). This test makes sure the full C-string - // (as specified by the user) gets passed in regardless of the early null - // character. - std::string test_key = "test_x_key"; - test_key[5] = '\0'; // Replace x with null character '\0'. - std::string test_value = "test_x_value"; - test_value[5] = '\0'; - InsertKeyValueRequest expected_input; - expected_input.mutable_kv()->set_key(test_key); - expected_input.mutable_kv()->set_value(test_value); - - EXPECT_CALL(*GetClient(), - InsertKeyValueAsync(Pointee(EqualsProto(expected_input)), _, _)) - .WillOnce(InvokeArgument<2>(OkStatus())); - InitializeAgent(); - - TF_ASSERT_OK(agent_->InsertKeyValue(test_key.c_str(), test_key.size(), - test_value.c_str(), test_value.size())); -} - -TEST_F(CoordinationServiceAgentTest, GetKeyValue_EarlyNullCharacter_Success) { - // Note: this API is passing C-strings, but users might pass a string with - // a null character in the middle due to their specific serialization / - // encoding mechanism (e.g. protobuf). This test makes sure the full cstring - // (as specified by the user) gets passed in regardless of the early null - // character. - std::string test_key = "test_x_key"; - test_key[5] = '\0'; // Replace x with null character '\0'. - const std::string test_value = "test_value"; - GetKeyValueRequest expected_input; - expected_input.set_key(test_key); - - // Mock server response: set key-value pair and invoke done callback. - GetKeyValueResponse mocked_response; - mocked_response.mutable_kv()->set_key(test_key); - mocked_response.mutable_kv()->set_value(test_value); - EXPECT_CALL(*GetClient(), - GetKeyValueAsync(_, Pointee(EqualsProto(expected_input)), _, _)) - .WillOnce(DoAll(SetArgPointee<2>(mocked_response), - InvokeArgument<3>(OkStatus()))); - InitializeAgent(); - - auto result = agent_->GetKeyValue(test_key.data(), test_key.size()); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(*result, test_value); -} - -TEST_F(CoordinationServiceAgentTest, - DeleteKeyValue_EarlyNullCharacter_Success) { - // Note: this API is passing C-strings, but users might pass a string with - // a null character in the middle due to their specific serialization / - // encoding mechanism (e.g. protobuf). This test makes sure the full cstring - // (as specified by the user) gets passed in regardless of the early null - // character. - std::string test_key = "test_x_key"; - test_key[5] = '\0'; // Replace x with null character '\0'. - DeleteKeyValueRequest expected_input; - expected_input.set_key(test_key); - expected_input.set_is_directory(true); // This is default. - - EXPECT_CALL(*GetClient(), - DeleteKeyValueAsync(Pointee(EqualsProto(expected_input)), _, _)) - .WillOnce(InvokeArgument<2>(OkStatus())); - InitializeAgent(); - - TF_ASSERT_OK(agent_->DeleteKeyValue(test_key.c_str(), test_key.size())); -} - TEST_F(CoordinationServiceAgentTest, ShutdownInErrorShouldReturnError) { // Connect coordination agent and set it to error. InitializeAgent(); TF_ASSERT_OK(agent_->Connect()); - TF_ASSERT_OK(agent_->ReportError(errors::Internal("Test Error."))); + TF_ASSERT_OK(agent_->ReportError(absl::InternalError("Test Error."))); // Shutdown should return error. - Status s = agent_->Shutdown(); + absl::Status s = agent_->Shutdown(); - EXPECT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(absl::IsFailedPrecondition(s)); } TEST_F(CoordinationServiceAgentTest, Reset_ConnectedButNotInError_Fail) { @@ -482,14 +409,14 @@ TEST_F(CoordinationServiceAgentTest, Reset_ConnectedButNotInError_Fail) { auto status = agent_->Reset(); // Fails because agent is not in ERROR state. - EXPECT_TRUE(errors::IsFailedPrecondition(status)); + EXPECT_TRUE(absl::IsFailedPrecondition(status)); } TEST_F(CoordinationServiceAgentTest, ConnectAfterResetError) { // Connect coordination agent and set it to error. InitializeAgent(); TF_ASSERT_OK(agent_->Connect()); - TF_ASSERT_OK(agent_->ReportError(errors::Internal("Test Error."))); + TF_ASSERT_OK(agent_->ReportError(absl::InternalError("Test Error."))); // Reset error. TF_ASSERT_OK(agent_->Reset()); @@ -500,16 +427,16 @@ TEST_F(CoordinationServiceAgentTest, ConnectAfterResetError) { TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) { // Mock reset error failing for the first time. EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _)) - .WillOnce(InvokeArgument<2>(errors::Internal("Reset error"))) - .WillOnce(InvokeArgument<2>(OkStatus())); + .WillOnce(InvokeArgument<2>(absl::InternalError("Reset error"))) + .WillOnce(InvokeArgument<2>(absl::OkStatus())); // Connect coordination agent and set it to error. InitializeAgent(); TF_ASSERT_OK(agent_->Connect()); - TF_ASSERT_OK(agent_->ReportError(errors::Internal("Test Error."))); + TF_ASSERT_OK(agent_->ReportError(absl::InternalError("Test Error."))); // Reset error fails for the first time. - Status reset_status = agent_->Reset(); - EXPECT_TRUE(errors::IsInternal(reset_status)); + absl::Status reset_status = agent_->Reset(); + EXPECT_TRUE(absl::IsInternal(reset_status)); // Agent should be able to attempt resetting again. TF_ASSERT_OK(agent_->Reset()); @@ -535,7 +462,7 @@ TEST_F(CoordinationServiceAgentTest, GetOwnTask) { TEST_F(CoordinationServiceAgentTest, GetOwnTask_Uninitialized) { auto result = agent_->GetOwnTask(); - EXPECT_TRUE(errors::IsFailedPrecondition(result.status())); + EXPECT_TRUE(absl::IsFailedPrecondition(result.status())); } TEST_F(CoordinationServiceAgentTest, WaitAtBarrier_SameIdUsedTwice_Fails) { @@ -550,14 +477,14 @@ TEST_F(CoordinationServiceAgentTest, WaitAtBarrier_SameIdUsedTwice_Fails) { auto result = agent_->WaitAtBarrier(barrier_id, absl::Seconds(1), /*tasks=*/{}); - EXPECT_TRUE(errors::IsFailedPrecondition(result)); + EXPECT_TRUE(absl::IsFailedPrecondition(result)); } TEST_F(CoordinationServiceAgentTest, GetEnv_SucceedsAfterInit) { - EXPECT_TRUE(errors::IsFailedPrecondition(agent_->GetEnv().status())); + EXPECT_TRUE(absl::IsFailedPrecondition(agent_->GetEnv().status())); InitializeAgent(); - StatusOr result = agent_->GetEnv(); + absl::StatusOr result = agent_->GetEnv(); TF_ASSERT_OK(result.status()); EXPECT_EQ(*result, Env::Default()); @@ -566,9 +493,11 @@ TEST_F(CoordinationServiceAgentTest, GetEnv_SucceedsAfterInit) { TEST_F(CoordinationServiceAgentTest, Connect_AbortedErrorShouldBeRetried) { // Mock connection failing for the first two times. EXPECT_CALL(*GetClient(), RegisterTaskAsync(_, _, _, _)) - .WillOnce(InvokeArgument<3>(errors::Aborted("DuplicateTaskRegistration"))) - .WillOnce(InvokeArgument<3>(errors::Aborted("DuplicateTaskRegistration"))) - .WillOnce(InvokeArgument<3>(OkStatus())); + .WillOnce( + InvokeArgument<3>(absl::AbortedError("DuplicateTaskRegistration"))) + .WillOnce( + InvokeArgument<3>(absl::AbortedError("DuplicateTaskRegistration"))) + .WillOnce(InvokeArgument<3>(absl::OkStatus())); InitializeAgent(); TF_EXPECT_OK(agent_->Connect()); @@ -579,25 +508,25 @@ TEST_F(CoordinationServiceAgentTest, Connect_AbortedErrorShouldFailEventually) { // restarts. EXPECT_CALL(*GetClient(), RegisterTaskAsync(_, _, _, _)) .WillRepeatedly( - InvokeArgument<3>(errors::Aborted("DuplicateTaskRegistration"))); + InvokeArgument<3>(absl::AbortedError("DuplicateTaskRegistration"))); CoordinationServiceConfig config; // Connect should only be retried for 3 seconds. config.set_cluster_register_timeout_in_ms( absl::ToInt64Milliseconds(absl::Seconds(3))); InitializeAgent(config); - Status s = agent_->Connect(); + absl::Status s = agent_->Connect(); - EXPECT_TRUE(errors::IsAborted(s)); + EXPECT_TRUE(absl::IsAborted(s)); } TEST_F(CoordinationServiceAgentTest, Connect_InternalErrorShouldBeRetried) { EXPECT_CALL(*GetClient(), RegisterTaskAsync(_, _, _, _)) .WillOnce(InvokeArgument<3>( - errors::Internal("Coordination service is not enabled."))) + absl::InternalError("Coordination service is not enabled."))) .WillOnce(InvokeArgument<3>( - errors::Internal("Coordination service is not enabled."))) - .WillOnce(InvokeArgument<3>(OkStatus())); + absl::InternalError("Coordination service is not enabled."))) + .WillOnce(InvokeArgument<3>(absl::OkStatus())); InitializeAgent(); TF_EXPECT_OK(agent_->Connect()); diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util.h index 79156778f40d9..851e233c1a2e8 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util.h +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util.h @@ -28,7 +28,7 @@ constexpr absl::string_view CoordinationErrorPayloadKey() { // Mark error as a coordination service error (as opposed to RPC // errors). -inline Status MakeCoordinationError(Status s) { +inline absl::Status MakeCoordinationError(absl::Status s) { s.SetPayload(CoordinationErrorPayloadKey(), absl::Cord("")); return s; } @@ -37,9 +37,9 @@ inline Status MakeCoordinationError(Status s) { // errors), and indicate error origin. // Errors reported via the agent API by the user should set `is_reported_error` // to true. -inline Status MakeCoordinationError(Status s, - const tensorflow::CoordinatedTask& origin, - bool is_reported_error = false) { +inline absl::Status MakeCoordinationError( + absl::Status s, const tensorflow::CoordinatedTask& origin, + bool is_reported_error = false) { tensorflow::CoordinationServiceError error; *error.mutable_source_task() = origin; error.set_is_reported_error(is_reported_error); @@ -49,8 +49,8 @@ inline Status MakeCoordinationError(Status s, } // Mark error as a coordination service error with payload. -inline Status MakeCoordinationError( - Status s, const tensorflow::CoordinationServiceError& payload) { +inline absl::Status MakeCoordinationError( + absl::Status s, const tensorflow::CoordinationServiceError& payload) { s.SetPayload(CoordinationErrorPayloadKey(), absl::Cord(payload.SerializeAsString())); return s; diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc index 72bb8035b5766..5c10b5aec1354 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc @@ -26,9 +26,9 @@ using ::tensorflow::CoordinatedTask; using ::tensorflow::CoordinationServiceError; TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithEmptyPayload) { - Status error = errors::Internal("Test Error"); + absl::Status error = errors::Internal("Test Error"); - Status coordination_error = MakeCoordinationError(error); + absl::Status coordination_error = MakeCoordinationError(error); EXPECT_EQ(coordination_error.code(), error.code()); EXPECT_EQ(coordination_error.message(), error.message()); @@ -38,12 +38,12 @@ TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithEmptyPayload) { } TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithErrorOrigin) { - Status error = errors::Internal("Test Error"); + absl::Status error = errors::Internal("Test Error"); CoordinatedTask source_task; source_task.set_job_name("test_worker"); source_task.set_task_id(7); - Status coordination_error = MakeCoordinationError(error, source_task); + absl::Status coordination_error = MakeCoordinationError(error, source_task); EXPECT_EQ(coordination_error.code(), error.code()); EXPECT_EQ(coordination_error.message(), error.message()); @@ -57,13 +57,14 @@ TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithErrorOrigin) { } TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithUserReportedError) { - Status error = errors::Internal("Test Error"); + absl::Status error = errors::Internal("Test Error"); CoordinatedTask source_task; source_task.set_job_name("test_worker"); source_task.set_task_id(7); - Status coordination_error = MakeCoordinationError(error, source_task, - /*is_reported_error=*/true); + absl::Status coordination_error = + MakeCoordinationError(error, source_task, + /*is_reported_error=*/true); EXPECT_EQ(coordination_error.code(), error.code()); EXPECT_EQ(coordination_error.message(), error.message()); @@ -77,14 +78,14 @@ TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithUserReportedError) { } TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithPayload) { - Status error = errors::Internal("Test Error"); + absl::Status error = errors::Internal("Test Error"); CoordinationServiceError payload; CoordinatedTask* source_task = payload.mutable_source_task(); source_task->set_job_name("test_worker"); source_task->set_task_id(7); payload.set_is_reported_error(true); - Status coordination_error = MakeCoordinationError(error, payload); + absl::Status coordination_error = MakeCoordinationError(error, payload); EXPECT_EQ(coordination_error.code(), error.code()); EXPECT_EQ(coordination_error.message(), error.message()); diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index 4a9c1e9d6cdfd..fabcefd019789 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -113,7 +113,7 @@ class TestCoordinationServiceTaskState { void InitializeAndConnectCoordinationAgents( const std::string& job_name, int task_id, const CoordinationServiceConfig& coordination_config) { - auto error_fn = [this, job_name](const Status& status) { + auto error_fn = [this, job_name](const absl::Status& status) { this->status_ = status; LOG(ERROR) << "Coordination service agent of " << job_name << " is in error status: " << status; @@ -128,11 +128,11 @@ class TestCoordinationServiceTaskState { CoordinationClient* GetCoordinationClient() { return coord_client_.get(); } - Status ReportError(const Status& status) { + absl::Status ReportError(const absl::Status& status) { return coord_agent_->ReportError(status); } - Status GetStatus() const { return status_; } + absl::Status GetStatus() const { return status_; } private: std::unique_ptr<::grpc::Server> grpc_server_; @@ -142,7 +142,7 @@ class TestCoordinationServiceTaskState { std::unique_ptr coord_agent_ = CreateCoordinationServiceAgent(); std::unique_ptr coord_client_; - Status status_; + absl::Status status_; }; class CoordinationServiceRecoverableJobTest : public ::testing::Test { diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index b40864d786f13..2f2f87687d6fe 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -78,13 +78,13 @@ void CoordinationServiceRpcHandler::HeartbeatAsync( const CoordinatedTask& task = request->source_task(); const uint64_t incarnation = request->incarnation(); const uint64_t leader_incarnation = service_->GetServiceIncarnation(); - Status s = service_->RecordHeartbeat(task, incarnation); + absl::Status s = service_->RecordHeartbeat(task, incarnation); if (!s.ok()) { done(s); return; } response->set_leader_incarnation(leader_incarnation); - done(OkStatus()); + done(absl::OkStatus()); } void CoordinationServiceRpcHandler::WaitForAllTasksAsync( @@ -98,7 +98,7 @@ void CoordinationServiceRpcHandler::WaitForAllTasksAsync( } service_->WaitForAllTasks( request->source_task(), request->device_info(), - [response, service = service_, done = std::move(done)](Status s) { + [response, service = service_, done = std::move(done)](absl::Status s) { if (s.ok()) { *response->mutable_device_info() = service->ListClusterDevices(); } @@ -116,7 +116,7 @@ void CoordinationServiceRpcHandler::ShutdownTaskAsync( return; } service_->ShutdownTaskAsync(request->source_task(), - [done](Status s) { done(s); }); + [done](absl::Status s) { done(s); }); } void CoordinationServiceRpcHandler::ResetTaskAsync( @@ -141,14 +141,15 @@ void CoordinationServiceRpcHandler::ReportErrorToTaskAsync( return; } const CoordinationServiceError& error_payload = request->error_payload(); - Status error(static_cast(request->error_code()), - strings::StrCat("Error reported from /job:", - error_payload.source_task().job_name(), - "/task:", error_payload.source_task().task_id(), - ": ", request->error_message())); + absl::Status error( + static_cast(request->error_code()), + strings::StrCat( + "Error reported from /job:", error_payload.source_task().job_name(), + "/task:", error_payload.source_task().task_id(), ": ", + request->error_message())); error = MakeCoordinationError(error, error_payload); agent_->SetError(error); - done(OkStatus()); + done(absl::OkStatus()); } void CoordinationServiceRpcHandler::ReportErrorToServiceAsync( @@ -163,8 +164,8 @@ void CoordinationServiceRpcHandler::ReportErrorToServiceAsync( done(service_->ReportTaskError( request->error_origin(), MakeCoordinationError( - Status{static_cast(request->error_code()), - request->error_message()}, + absl::Status{static_cast(request->error_code()), + request->error_message()}, request->error_origin(), /*is_reported_error=*/true))); } @@ -182,7 +183,7 @@ void CoordinationServiceRpcHandler::GetTaskStateAsync( {request->source_task().begin(), request->source_task().end()}); absl::c_move(result, RepeatedFieldBackInserter(response->mutable_task_state())); - done(OkStatus()); + done(absl::OkStatus()); } void CoordinationServiceRpcHandler::InsertKeyValueAsync( @@ -209,7 +210,7 @@ void CoordinationServiceRpcHandler::GetKeyValueAsync( response->mutable_kv()->set_key(request->key()); service_->GetKeyValueAsync( request->key(), [response, done = std::move(done)]( - const StatusOr& status_or_value) { + const absl::StatusOr& status_or_value) { if (status_or_value.ok()) { response->mutable_kv()->set_value(status_or_value.value()); } @@ -233,7 +234,7 @@ void CoordinationServiceRpcHandler::TryGetKeyValueAsync( } response->mutable_kv()->set_key(request->key()); response->mutable_kv()->set_value(result.value()); - done(OkStatus()); + done(absl::OkStatus()); } void CoordinationServiceRpcHandler::GetKeyValueDirAsync( @@ -249,7 +250,7 @@ void CoordinationServiceRpcHandler::GetKeyValueDirAsync( service_->GetKeyValueDir(request->directory_key()); *response->mutable_kv() = {std::make_move_iterator(results.begin()), std::make_move_iterator(results.end())}; - done(OkStatus()); + done(absl::OkStatus()); } void CoordinationServiceRpcHandler::DeleteKeyValueAsync( @@ -279,7 +280,7 @@ void CoordinationServiceRpcHandler::BarrierAsync(const BarrierRequest* request, request->barrier_id(), absl::Milliseconds(request->barrier_timeout_in_ms()), request->source_task(), tasks, - [done = std::move(done)](const Status& status) { done(status); }); + [done = std::move(done)](const absl::Status& status) { done(status); }); } void CoordinationServiceRpcHandler::CancelBarrierAsync( diff --git a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_test.cc index 7419b72841902..b111a6235c8cb 100644 --- a/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -80,7 +80,7 @@ class TestCoordinationClient : public CoordinationClient { public: TestCoordinationClient() = default; - Status GetStatus() { + absl::Status GetStatus() { mutex_lock l(mu_); return status_; } @@ -88,7 +88,7 @@ class TestCoordinationClient : public CoordinationClient { void RegisterTaskAsync(CallOptions* opts, const RegisterTaskRequest* request, RegisterTaskResponse* response, StatusCallback done) override { - done(OkStatus()); + done(absl::OkStatus()); } void ReportErrorToTaskAsync(CallOptions* call_opts, @@ -96,9 +96,9 @@ class TestCoordinationClient : public CoordinationClient { ReportErrorToTaskResponse* response, StatusCallback done) override { mutex_lock l(mu_); - status_ = Status(static_cast(request->error_code()), - request->error_message()); - done(OkStatus()); + status_ = absl::Status(static_cast(request->error_code()), + request->error_message()); + done(absl::OkStatus()); } #define UNIMPLEMENTED(method) \ @@ -134,7 +134,7 @@ class TestCoordinationClient : public CoordinationClient { private: mutex mu_; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); }; class TestCoordinationClientCache : public CoordinationClientCache { @@ -183,7 +183,8 @@ class CoordinationBarrierTest : public ::testing::Test { Env::Default(), config, std::move(client_cache)); // Register the tasks. for (int i = 0; i < num_tasks; ++i) { - Status s = coord_service_->RegisterTask(tasks_[i], /*incarnation=*/0); + absl::Status s = + coord_service_->RegisterTask(tasks_[i], /*incarnation=*/0); if (!s.ok()) { LOG(FATAL) << "RegisterTask() failed in CoordinationBarrierTest(): " << s; @@ -278,7 +279,7 @@ TEST_F(CoordinateTwoTasksTest, TestStandaloneService) { TF_ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); absl::Notification wait_for_all; - coord_service_->WaitForAllTasks(task_0_, {}, [&](Status s) { + coord_service_->WaitForAllTasks(task_0_, {}, [&](absl::Status s) { TF_ASSERT_OK(s); wait_for_all.Notify(); }); @@ -286,7 +287,7 @@ TEST_F(CoordinateTwoTasksTest, TestStandaloneService) { ASSERT_FALSE(wait_for_all.HasBeenNotified()); TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); coord_service_->WaitForAllTasks(task_1_, {}, - [&](Status s) { TF_ASSERT_OK(s); }); + [&](absl::Status s) { TF_ASSERT_OK(s); }); // All tasks have registered. wait_for_all.WaitForNotification(); @@ -341,19 +342,19 @@ TEST(CoordinationServiceTest, TestCoordinatedJobs) { // Each coordinated task registers and waits for other tasks. absl::Notification register_chief; TF_ASSERT_OK(coord_service->RegisterTask(chief, /*incarnation=*/0)); - coord_service->WaitForAllTasks(chief, {}, [&](Status s) { + coord_service->WaitForAllTasks(chief, {}, [&](absl::Status s) { TF_ASSERT_OK(s); register_chief.Notify(); }); absl::Notification register_task0; TF_ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0)); - coord_service->WaitForAllTasks(task_0, {}, [&](Status s) { + coord_service->WaitForAllTasks(task_0, {}, [&](absl::Status s) { TF_ASSERT_OK(s); register_task0.Notify(); }); absl::Notification register_task1; TF_ASSERT_OK(coord_service->RegisterTask(task_1, /*incarnation=*/0)); - coord_service->WaitForAllTasks(task_1, {}, [&](Status s) { + coord_service->WaitForAllTasks(task_1, {}, [&](absl::Status s) { TF_ASSERT_OK(s); register_task1.Notify(); }); @@ -363,7 +364,8 @@ TEST(CoordinationServiceTest, TestCoordinatedJobs) { register_task1.WaitForNotification(); // Registering the evaluator task is unexpected - Status status = coord_service->RegisterTask(evaluator, /*incarnation=*/0); + absl::Status status = + coord_service->RegisterTask(evaluator, /*incarnation=*/0); EXPECT_TRUE(absl::IsInvalidArgument(status)) << status; EXPECT_TRUE(!status.message().empty()); } @@ -385,7 +387,8 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyConnected_Succeeds) { TF_ASSERT_OK(coord_service->RegisterTask(task_0, /*incarnation=*/0)); // Registration should succeed since it is the same task. - const Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0); + const absl::Status status = + coord_service->RegisterTask(task_0, /*incarnation=*/0); TF_EXPECT_OK(status) << status; } @@ -407,7 +410,8 @@ TEST(CoordinationServiceTest, // Registration should fail since task already registered previously with a // different incarnation. Note that incarnation usually changes if an agent // restarts. - const Status status = coord_service->RegisterTask(task_0, /*incarnation=*/1); + const absl::Status status = + coord_service->RegisterTask(task_0, /*incarnation=*/1); EXPECT_TRUE(absl::IsAborted(status)) << status; EXPECT_TRUE(!status.message().empty()); @@ -430,7 +434,8 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) { coord_service->ReportTaskError(task_0, errors::Internal("test_error"))); // Registration should fail since task already registered previously. - const Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0); + const absl::Status status = + coord_service->RegisterTask(task_0, /*incarnation=*/0); EXPECT_TRUE(absl::IsAborted(status)) << status; EXPECT_TRUE(!status.message().empty()); @@ -474,7 +479,7 @@ TEST_F(CoordinateTwoTasksTest, TestTaskRestart) { TF_ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Simulate task restart scenario: trying to register to cluster again. - Status s = + absl::Status s = coord_service_->RegisterTask(task_1_, /*incarnation=*/random::New64()); EXPECT_TRUE(absl::IsAborted(s)) << s; // Aborted error is also propagated to other tasks in cluster. @@ -497,9 +502,9 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) { // Get simple key absl::Notification n1; - StatusOr ret; + absl::StatusOr ret; coord_service_->GetKeyValueAsync( - "key0", [&](const StatusOr& status_or_value) { + "key0", [&](const absl::StatusOr& status_or_value) { ret = status_or_value; n1.Notify(); }); @@ -509,7 +514,8 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) { // Get key with redundant slashes absl::Notification n2; coord_service_->GetKeyValueAsync( - "path//to///key1////", [&](const StatusOr& status_or_value) { + "path//to///key1////", + [&](const absl::StatusOr& status_or_value) { ret = status_or_value; n2.Notify(); }); @@ -521,7 +527,7 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) { // Get key that is not available absl::Notification n3; coord_service_->GetKeyValueAsync( - "key0", [&](const StatusOr& status_or_value) { + "key0", [&](const absl::StatusOr& status_or_value) { ret = status_or_value; n3.Notify(); }); @@ -541,7 +547,9 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) { // service shutdown. Hence, we use a shared pointer for notification so // that the it will not be deallocated before the pending callback is // cleaned up. - [n4](const StatusOr& status_or_value) { n4->Notify(); }); + [n4](const absl::StatusOr& status_or_value) { + n4->Notify(); + }); EXPECT_FALSE(n4->HasBeenNotified()); } @@ -554,7 +562,8 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { Env::Default(), config, std::move(client_cache)); // Try to get nonexistent key. - StatusOr result = coord_service->TryGetKeyValue("test_key"); + absl::StatusOr result = + coord_service->TryGetKeyValue("test_key"); EXPECT_TRUE(absl::IsNotFound(result.status())); // Insert key value. @@ -660,7 +669,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_TfDevice) { CoordinatedTask task_2; task_2.set_job_name("worker"); task_2.set_task_id(2); - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); auto client_cache = std::make_unique(); std::unique_ptr coord_service = CoordinationServiceInterface::EnableCoordinationService( @@ -682,10 +691,10 @@ TEST(CoordinationServiceTest, ListClusterDevices_TfDevice) { // Each task sends its device info. DeviceInfo cluster_devices; coord_service->WaitForAllTasks(task_0, local_devices_0, - [&](Status s) { TF_ASSERT_OK(s); }); + [&](absl::Status s) { TF_ASSERT_OK(s); }); coord_service->WaitForAllTasks(task_1, local_devices_1, - [&](Status s) { TF_ASSERT_OK(s); }); - coord_service->WaitForAllTasks(task_2, local_devices_2, [&](Status s) { + [&](absl::Status s) { TF_ASSERT_OK(s); }); + coord_service->WaitForAllTasks(task_2, local_devices_2, [&](absl::Status s) { TF_ASSERT_OK(s); // Gather the cluster device info. cluster_devices = coord_service->ListClusterDevices(); @@ -716,7 +725,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_XlaDevice) { CoordinatedTask task_2; task_2.set_job_name("worker"); task_2.set_task_id(2); - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); auto client_cache = std::make_unique(); std::unique_ptr coord_service = CoordinationServiceInterface::EnableCoordinationService( @@ -757,10 +766,10 @@ TEST(CoordinationServiceTest, ListClusterDevices_XlaDevice) { // Make sure that cluster device order is deterministic even if devices are // sent out of order. coord_service->WaitForAllTasks(task_1, local_devices_1, - [&](Status s) { TF_ASSERT_OK(s); }); + [&](absl::Status s) { TF_ASSERT_OK(s); }); coord_service->WaitForAllTasks(task_0, local_devices_0, - [&](Status s) { TF_ASSERT_OK(s); }); - coord_service->WaitForAllTasks(task_2, local_devices_2, [&](Status s) { + [&](absl::Status s) { TF_ASSERT_OK(s); }); + coord_service->WaitForAllTasks(task_2, local_devices_2, [&](absl::Status s) { TF_ASSERT_OK(s); // Gather the cluster device info. cluster_devices = coord_service->ListClusterDevices(); @@ -794,7 +803,7 @@ TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) { CoordinatedTask task_1; task_1.set_job_name("worker"); task_1.set_task_id(1); - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); auto client_cache = std::make_unique(); std::unique_ptr coord_service = CoordinationServiceInterface::EnableCoordinationService( @@ -812,19 +821,20 @@ TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) { // Task0 sends device info. DeviceInfo cluster_devices; coord_service->WaitForAllTasks(task_0, local_devices_0, - [](Status s) { TF_ASSERT_OK(s); }); + [](absl::Status s) { TF_ASSERT_OK(s); }); // Task0 sends device info sgain. coord_service->WaitForAllTasks(task_0, local_devices_0, - [](Status s) { TF_ASSERT_OK(s); }); - coord_service->WaitForAllTasks( - task_1, local_devices_1, - [coord_service = coord_service.get(), &cluster_devices, &n](Status s) { - TF_ASSERT_OK(s); - // Gather the cluster device info. - cluster_devices = coord_service->ListClusterDevices(); - n.Notify(); - }); + [](absl::Status s) { TF_ASSERT_OK(s); }); + coord_service->WaitForAllTasks(task_1, local_devices_1, + [coord_service = coord_service.get(), + &cluster_devices, &n](absl::Status s) { + TF_ASSERT_OK(s); + // Gather the cluster device info. + cluster_devices = + coord_service->ListClusterDevices(); + n.Notify(); + }); n.WaitForNotification(); // No duplicates found. @@ -840,37 +850,37 @@ TEST(CoordinationServiceTest, ListClusterDevices_DevicesAreNotAddedTwice) { TEST_F(CoordinationBarrierTest, Barrier) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; - Status barrier_status_2; + absl::Status barrier_status_0; + absl::Status barrier_status_1; + absl::Status barrier_status_2; absl::Notification n_0; absl::Notification n_1; absl::Notification n_2; - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0), - /*participating_tasks=*/{}, - [&barrier_status_0, &n_0](Status s) { - barrier_status_0 = s; - n_0.Notify(); - }); - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(1), - /*participating_tasks=*/{}, - [&barrier_status_1, &n_1](Status s) { - barrier_status_1 = s; - n_1.Notify(); - }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(0), + /*participating_tasks=*/{}, [&barrier_status_0, &n_0](absl::Status s) { + barrier_status_0 = s; + n_0.Notify(); + }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(1), + /*participating_tasks=*/{}, [&barrier_status_1, &n_1](absl::Status s) { + barrier_status_1 = s; + n_1.Notify(); + }); // Make sure barrier has not been exited prematurely. EXPECT_FALSE(n_0.HasBeenNotified()); EXPECT_FALSE(n_1.HasBeenNotified()); EXPECT_FALSE(n_2.HasBeenNotified()); // Last task calls the barrier. - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(2), - /*participating_tasks=*/{}, - [&barrier_status_2, &n_2](Status s) { - barrier_status_2 = s; - n_2.Notify(); - }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(2), + /*participating_tasks=*/{}, [&barrier_status_2, &n_2](absl::Status s) { + barrier_status_2 = s; + n_2.Notify(); + }); EXPECT_TRUE(n_0.HasBeenNotified()); EXPECT_TRUE(n_1.HasBeenNotified()); @@ -883,22 +893,22 @@ TEST_F(CoordinationBarrierTest, Barrier) { TEST_F(CoordinationBarrierTest, BarrierWithSubsetOfTasks) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; + absl::Status barrier_status_0; + absl::Status barrier_status_1; absl::Notification n_0; absl::Notification n_1; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_0, &n_0](Status s) { + [&barrier_status_0, &n_0](absl::Status s) { barrier_status_0 = s; n_0.Notify(); }); GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_1, &n_1](Status s) { + [&barrier_status_1, &n_1](absl::Status s) { barrier_status_1 = s; n_1.Notify(); }); @@ -913,19 +923,19 @@ TEST_F(CoordinationBarrierTest, BarrierWithSubsetOfTasks) { TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; + absl::Status barrier_status_0; + absl::Status barrier_status_1; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_0](Status s) { barrier_status_0 = s; }); + [&barrier_status_0](absl::Status s) { barrier_status_0 = s; }); // task_1's barrier call specified a conflicting set of tasks (task_2 instead // of task_0). GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{GetTask(1), GetTask(2)}, - [&barrier_status_1](Status s) { barrier_status_1 = s; }); + [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1)); @@ -934,20 +944,20 @@ TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) { TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; + absl::Status barrier_status_0; + absl::Status barrier_status_1; absl::Notification n_0; absl::Notification n_1; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_0](Status s) { barrier_status_0 = s; }); + [&barrier_status_0](absl::Status s) { barrier_status_0 = s; }); // Task 2 unexpectedly calls a barrier that it is not participating in. GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(2), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_1](Status s) { barrier_status_1 = s; }); + [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); // Barrier should fail for all tasks with the unexpected call. EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); @@ -957,7 +967,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; + absl::Status barrier_status_0; absl::Notification n_0; CoordinatedTask unspecified_task; unspecified_task.set_job_name("task_from_another_cluster"); @@ -966,7 +976,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), unspecified_task}, - [&barrier_status_0, &n_0](Status s) { + [&barrier_status_0, &n_0](absl::Status s) { barrier_status_0 = s; n_0.Notify(); }); @@ -979,15 +989,15 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { TEST_F(CoordinationBarrierTest, BarrierTimeout) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(1); - Status barrier_status_0; + absl::Status barrier_status_0; absl::Notification n_0; - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0), - /*participating_tasks=*/{}, - [&barrier_status_0, &n_0](Status s) { - barrier_status_0 = s; - n_0.Notify(); - }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(0), + /*participating_tasks=*/{}, [&barrier_status_0, &n_0](absl::Status s) { + barrier_status_0 = s; + n_0.Notify(); + }); // Block until user-specified timeout. n_0.WaitForNotification(); @@ -1003,16 +1013,16 @@ TEST_F(CoordinationBarrierTest, BarrierTimeout) { TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(1); - Status barrier_status_0; - Status barrier_status_1; + absl::Status barrier_status_0; + absl::Status barrier_status_1; absl::Notification n_0; - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0), - /*participating_tasks=*/{}, - [&barrier_status_0, &n_0](Status s) { - barrier_status_0 = s; - n_0.Notify(); - }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(0), + /*participating_tasks=*/{}, [&barrier_status_0, &n_0](absl::Status s) { + barrier_status_0 = s; + n_0.Notify(); + }); TF_ASSERT_OK(GetCoordinationService()->ReportTaskError( GetTask(0), errors::Internal("test_error"))); // Block until barrier has failed due to task error. @@ -1021,7 +1031,7 @@ TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{}, - [&barrier_status_1](Status s) { barrier_status_1 = s; }); + [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); EXPECT_TRUE(absl::IsInternal(barrier_status_0)); EXPECT_TRUE(absl::IsInternal(barrier_status_1)); @@ -1030,13 +1040,13 @@ TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { TEST_F(CoordinationBarrierTest, BarrierCancelled) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status; + absl::Status barrier_status; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{}, - [&barrier_status](Status s) { barrier_status = s; }); - Status cancelled_status = + [&barrier_status](absl::Status s) { barrier_status = s; }); + absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); EXPECT_TRUE(absl::IsCancelled(barrier_status)); @@ -1046,7 +1056,7 @@ TEST_F(CoordinationBarrierTest, BarrierCancelled) { TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { const std::string barrier_id = "cancelled_barrier_id"; absl::Duration timeout = absl::Seconds(1); - Status barrier_status; + absl::Status barrier_status; // Cancel barrier should still succeed. TF_ASSERT_OK(GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0))); @@ -1054,7 +1064,7 @@ TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{}, - [&barrier_status](Status s) { barrier_status = s; }); + [&barrier_status](absl::Status s) { barrier_status = s; }); EXPECT_TRUE(absl::IsCancelled(barrier_status)) << barrier_status; } @@ -1062,24 +1072,24 @@ TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; - Status barrier_status_2; + absl::Status barrier_status_0; + absl::Status barrier_status_1; + absl::Status barrier_status_2; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{}, - [&barrier_status_0](Status s) { barrier_status_0 = s; }); + [&barrier_status_0](absl::Status s) { barrier_status_0 = s; }); GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{}, - [&barrier_status_1](Status s) { barrier_status_1 = s; }); + [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(2), /*participating_tasks=*/{}, - [&barrier_status_2](Status s) { barrier_status_2 = s; }); + [&barrier_status_2](absl::Status s) { barrier_status_2 = s; }); // Cancel barrier should fail if barrier has already been passed. - Status cancelled_status = + absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); EXPECT_TRUE(absl::IsFailedPrecondition(cancelled_status)); @@ -1091,38 +1101,38 @@ TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { TEST_F(CoordinationBarrierTest, PassedBarrierReturnsImmediately) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; - Status barrier_status_2; - Status barrier_status_repeat; + absl::Status barrier_status_0; + absl::Status barrier_status_1; + absl::Status barrier_status_2; + absl::Status barrier_status_repeat; absl::Notification n0; absl::Notification n1; absl::Notification n2; absl::Notification n_repeat; - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0), - /*participating_tasks=*/{}, - [&barrier_status_0, &n0](Status s) { - barrier_status_0 = s; - n0.Notify(); - }); - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(1), - /*participating_tasks=*/{}, - [&barrier_status_1, &n1](Status s) { - barrier_status_1 = s; - n1.Notify(); - }); - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(2), - /*participating_tasks=*/{}, - [&barrier_status_2, &n2](Status s) { - barrier_status_2 = s; - n2.Notify(); - }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(0), + /*participating_tasks=*/{}, [&barrier_status_0, &n0](absl::Status s) { + barrier_status_0 = s; + n0.Notify(); + }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(1), + /*participating_tasks=*/{}, [&barrier_status_1, &n1](absl::Status s) { + barrier_status_1 = s; + n1.Notify(); + }); + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(2), + /*participating_tasks=*/{}, [&barrier_status_2, &n2](absl::Status s) { + barrier_status_2 = s; + n2.Notify(); + }); // Repeated call should return the same result. GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{}, - [&barrier_status_repeat, &n_repeat](Status s) { + [&barrier_status_repeat, &n_repeat](absl::Status s) { barrier_status_repeat = s; n_repeat.Notify(); }); @@ -1143,12 +1153,12 @@ TEST_F(CoordinationBarrierTest, BarrierFailsIfTaskIsAlreadyInError) { // Set task 0 to error state. TF_ASSERT_OK(GetCoordinationService()->ReportTaskError( GetTask(0), errors::Internal("test_error"))); - Status barrier_status; + absl::Status barrier_status; GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{}, - [&barrier_status](Status s) { barrier_status = s; }); + [&barrier_status](absl::Status s) { barrier_status = s; }); EXPECT_TRUE(absl::IsInternal(barrier_status)); } @@ -1157,14 +1167,14 @@ TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); absl::Notification n0; - Status barrier_status; - - GetCoordinationService()->BarrierAsync(barrier_id, timeout, GetTask(0), - /*participating_tasks=*/{}, - [&barrier_status, &n0](Status s) { - barrier_status = s; - n0.Notify(); - }); + absl::Status barrier_status; + + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(0), + /*participating_tasks=*/{}, [&barrier_status, &n0](absl::Status s) { + barrier_status = s; + n0.Notify(); + }); TF_ASSERT_OK(GetCoordinationService()->ReportTaskError( GetTask(0), errors::Internal("test_error"))); n0.WaitForNotification(); @@ -1176,9 +1186,9 @@ TEST_F(CoordinationBarrierTest, BarrierStillBlocksIfSameTaskCallsOngoingBarrierRepeatedly) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(5); - Status barrier_status_0; - Status barrier_status_1; - Status barrier_status_2; + absl::Status barrier_status_0; + absl::Status barrier_status_1; + absl::Status barrier_status_2; absl::Notification n_0; absl::Notification n_1; absl::Notification n_2; @@ -1186,7 +1196,7 @@ TEST_F(CoordinationBarrierTest, GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_0, &n_0](Status s) { + [&barrier_status_0, &n_0](absl::Status s) { barrier_status_0 = s; n_0.Notify(); }); @@ -1194,7 +1204,7 @@ TEST_F(CoordinationBarrierTest, GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_1, &n_1](Status s) { + [&barrier_status_1, &n_1](absl::Status s) { barrier_status_1 = s; n_1.Notify(); }); @@ -1205,7 +1215,7 @@ TEST_F(CoordinationBarrierTest, GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(1), /*participating_tasks=*/{GetTask(0), GetTask(1)}, - [&barrier_status_2, &n_2](Status s) { + [&barrier_status_2, &n_2](absl::Status s) { barrier_status_2 = s; n_2.Notify(); }); @@ -1244,14 +1254,15 @@ TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - Status barrier_status; + absl::Status barrier_status; absl::Notification barrier_n; - coord_service_->BarrierAsync( - "ongoing_barrier", absl::InfiniteDuration(), task_0_, - /*participating_tasks=*/{}, [&barrier_status, &barrier_n](Status s) { - barrier_status = s; - barrier_n.Notify(); - }); + coord_service_->BarrierAsync("ongoing_barrier", absl::InfiniteDuration(), + task_0_, + /*participating_tasks=*/{}, + [&barrier_status, &barrier_n](absl::Status s) { + barrier_status = s; + barrier_n.Notify(); + }); TF_EXPECT_OK(coord_service_->ResetTask(task_0_)); @@ -1266,7 +1277,7 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); absl::Notification n; - coord_service_->ShutdownTaskAsync(task_0_, [&n](Status s) { + coord_service_->ShutdownTaskAsync(task_0_, [&n](absl::Status s) { TF_EXPECT_OK(s); n.Notify(); }); @@ -1287,17 +1298,18 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); - Status barrier_status; + absl::Status barrier_status; absl::Notification barrier_n; - coord_service_->BarrierAsync( - "ongoing_barrier", absl::InfiniteDuration(), task_0_, - /*participating_tasks=*/{}, [&barrier_status, &barrier_n](Status s) { - barrier_status = s; - barrier_n.Notify(); - }); + coord_service_->BarrierAsync("ongoing_barrier", absl::InfiniteDuration(), + task_0_, + /*participating_tasks=*/{}, + [&barrier_status, &barrier_n](absl::Status s) { + barrier_status = s; + barrier_n.Notify(); + }); absl::Notification shutdown_n; - coord_service_->ShutdownTaskAsync(task_0_, [&shutdown_n](Status s) { + coord_service_->ShutdownTaskAsync(task_0_, [&shutdown_n](absl::Status s) { TF_EXPECT_OK(s); shutdown_n.Notify(); }); @@ -1313,13 +1325,13 @@ TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) { /*enable_shutdown_barrier=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - Status barrier_status; - Status barrier_status_2; + absl::Status barrier_status; + absl::Status barrier_status_2; coord_service_->ShutdownTaskAsync( - task_0_, [&barrier_status](Status s) { barrier_status = s; }); + task_0_, [&barrier_status](absl::Status s) { barrier_status = s; }); coord_service_->ShutdownTaskAsync( - task_1_, [&barrier_status_2](Status s) { barrier_status_2 = s; }); + task_1_, [&barrier_status_2](absl::Status s) { barrier_status_2 = s; }); TF_EXPECT_OK(barrier_status); TF_EXPECT_OK(barrier_status_2); @@ -1337,13 +1349,14 @@ TEST_F(CoordinateTwoTasksTest, /*enable_shutdown_barrier=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - Status barrier_status; + absl::Status barrier_status; absl::Notification n; - coord_service_->ShutdownTaskAsync(task_0_, [&n, &barrier_status](Status s) { - barrier_status = s; - n.Notify(); - }); + coord_service_->ShutdownTaskAsync(task_0_, + [&n, &barrier_status](absl::Status s) { + barrier_status = s; + n.Notify(); + }); // Block until barrier times out. n.WaitForNotification(); @@ -1354,7 +1367,7 @@ TEST_F(CoordinateTwoTasksTest, TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); // Other task is alerted that shutdown has been initiated without it. - Status other_task_status = client_1_.GetStatus(); + absl::Status other_task_status = client_1_.GetStatus(); EXPECT_TRUE(absl::IsInternal(other_task_status)) << other_task_status; } @@ -1364,13 +1377,14 @@ TEST_F(CoordinateTwoTasksTest, /*enable_shutdown_barrier=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); TF_EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - Status barrier_status; + absl::Status barrier_status; absl::Notification n; - coord_service_->ShutdownTaskAsync(task_0_, [&n, &barrier_status](Status s) { - barrier_status = s; - n.Notify(); - }); + coord_service_->ShutdownTaskAsync(task_0_, + [&n, &barrier_status](absl::Status s) { + barrier_status = s; + n.Notify(); + }); // Block until barrier times out. n.WaitForNotification(); // Provide time for coordination service to shut down after barrier timeout. @@ -1383,7 +1397,7 @@ TEST_F(CoordinateTwoTasksTest, // error propagation. // Task 1 still sends unexpected heartbeat because it doesn't know that // service has stopped yet, which should fail. - Status s = coord_service_->RecordHeartbeat(task_1_, incarnation_1_); + absl::Status s = coord_service_->RecordHeartbeat(task_1_, incarnation_1_); EXPECT_TRUE(absl::IsInvalidArgument(s)) << s; } diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/BUILD b/third_party/tsl/tsl/distributed_runtime/preemption/BUILD index 858caa90a6dd2..5a661c4cae506 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/tsl/tsl/distributed_runtime/preemption/BUILD @@ -1,11 +1,11 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl:internal", ]), licenses = ["notice"], diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc index f9838f2823b25..56226a85896e5 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.cc @@ -94,13 +94,14 @@ void SigtermNotifier::StartListenerThread() { } // namespace -StatusOr PreemptionNotifier::WillBePreemptedAt() { +absl::StatusOr PreemptionNotifier::WillBePreemptedAt() { absl::Notification n; - StatusOr result; - WillBePreemptedAtAsync([&n, &result](StatusOr async_result) { - result = async_result; - n.Notify(); - }); + absl::StatusOr result; + WillBePreemptedAtAsync( + [&n, &result](absl::StatusOr async_result) { + result = async_result; + n.Notify(); + }); n.WaitForNotification(); return result; } @@ -117,7 +118,7 @@ void PreemptionNotifier::WillBePreemptedAtAsync(PreemptTimeCallback callback) { } void PreemptionNotifier::NotifyRegisteredListeners( - StatusOr death_time) { + absl::StatusOr death_time) { mutex_lock l(mu_); if (death_time.ok()) { death_time_ = death_time.value(); diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h index 53941ceea6493..075af20fcd334 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier.h @@ -75,7 +75,7 @@ namespace tsl { class PreemptionNotifier { public: - typedef std::function)> PreemptTimeCallback; + typedef std::function)> PreemptTimeCallback; using PreemptionNotifierFactory = std::function(Env* env)>; @@ -112,7 +112,7 @@ class PreemptionNotifier { // termination will occur once the listener receives the preemption // notification. If no death time is specified, absl::Now() is returned. // Returns error::Cancelled if UnregisterListeners() is called. - StatusOr WillBePreemptedAt(); + absl::StatusOr WillBePreemptedAt(); // Registers a callback that takes the death time as input once the listener // receives the preemption notification. @@ -126,7 +126,7 @@ class PreemptionNotifier { Env* GetEnv() { return env_; } // Invokes all pending callbacks upon receipt of preemption notice with death // time or errors (e.g. cancellation during shutdown). - void NotifyRegisteredListeners(StatusOr death_time); + void NotifyRegisteredListeners(absl::StatusOr death_time); private: static std::unordered_map* diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc index abd1d24c9f51e..d083e2ef1ba2a 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_notifier_test.cc @@ -59,7 +59,7 @@ TEST_F(PreemptNotifierTest, WillBePreemptedAt) { []() { std::raise(SIGTERM); }); // Preempt time should be current timestamp. - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -84,7 +84,7 @@ TEST_F(PreemptNotifierTest, env->SleepForMicroseconds(absl::ToInt64Microseconds(absl::Seconds(2))); // Preempt time should be current timestamp. - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -105,17 +105,17 @@ TEST_F(PreemptNotifierTest, WillBePreemptedAtAsync_SameResultForAllCallbacks) { []() { std::raise(SIGTERM); }); // Preempt time should be current timestamp. - StatusOr preempt_time; - StatusOr preempt_time_2; + absl::StatusOr preempt_time; + absl::StatusOr preempt_time_2; absl::Notification n; absl::Notification n_2; preempt_notifier->WillBePreemptedAtAsync( - [&preempt_time, &n](StatusOr result) { + [&preempt_time, &n](absl::StatusOr result) { preempt_time = result; n.Notify(); }); preempt_notifier->WillBePreemptedAtAsync( - [&preempt_time_2, &n_2](StatusOr result) { + [&preempt_time_2, &n_2](absl::StatusOr result) { preempt_time_2 = result; n_2.Notify(); }); @@ -135,7 +135,7 @@ TEST_F(PreemptNotifierTest, Reset_TwoDifferentPreemptTimesRecorded) { // Raise first signal. std::raise(SIGTERM); - StatusOr result = preempt_notifier->WillBePreemptedAt(); + absl::StatusOr result = preempt_notifier->WillBePreemptedAt(); TF_CHECK_OK(result.status()); absl::Time preempt_time = result.value(); @@ -154,10 +154,10 @@ TEST_F(PreemptNotifierTest, DestructorCancelsPendingCalls) { auto env = Env::Default(); std::unique_ptr preempt_notifier = PreemptionNotifier::CreatePreemptionNotifier("sigterm", env); - StatusOr result; + absl::StatusOr result; absl::Notification n; preempt_notifier->WillBePreemptedAtAsync( - [&result, &n](StatusOr status_or_time) { + [&result, &n](absl::StatusOr status_or_time) { result = status_or_time; n.Notify(); }); diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index a4ca1ac9159de..d7e9c1280e63f 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -73,11 +73,12 @@ class PreemptionSyncManagerImpl : public PreemptionSyncManager { ~PreemptionSyncManagerImpl() override { shutdown_.Notify(); } - Status Initialize(CoordinationServiceAgent* agent) override; - Status Initialize(CoordinationServiceAgent* agent, - const std::string& preemption_notifier_type) override; - Status Initialize(CoordinationServiceAgent* agent, - std::unique_ptr notifier) override; + absl::Status Initialize(CoordinationServiceAgent* agent) override; + absl::Status Initialize(CoordinationServiceAgent* agent, + const std::string& preemption_notifier_type) override; + absl::Status Initialize( + CoordinationServiceAgent* agent, + std::unique_ptr notifier) override; bool ReachedSyncPoint(int step_counter) override; private: @@ -103,11 +104,12 @@ class PreemptionSyncManagerImpl : public PreemptionSyncManager { std::shared_ptr call_opts_; }; -Status PreemptionSyncManagerImpl::Initialize(CoordinationServiceAgent* agent) { +absl::Status PreemptionSyncManagerImpl::Initialize( + CoordinationServiceAgent* agent) { return Initialize(agent, "sigterm"); } -Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManagerImpl::Initialize( CoordinationServiceAgent* agent, const std::string& preemption_notifier_type) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -115,7 +117,7 @@ Status PreemptionSyncManagerImpl::Initialize( preemption_notifier_type, env)); } -Status PreemptionSyncManagerImpl::Initialize( +absl::Status PreemptionSyncManagerImpl::Initialize( CoordinationServiceAgent* agent, std::unique_ptr notifier) { TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); @@ -131,7 +133,7 @@ Status PreemptionSyncManagerImpl::Initialize( * service when death time is within kProtocolDuration. */ preemption_notifier_->WillBePreemptedAtAsync( - [agent = agent_, task_name](StatusOr death_time) { + [agent = agent_, task_name](absl::StatusOr death_time) { if (!death_time.ok()) { // The preemption notifier invokes callback with Cancelled error when // its being destructed. @@ -147,8 +149,8 @@ Status PreemptionSyncManagerImpl::Initialize( } notified_metric->GetCell()->Set(true); // Notify coordination service about preemption notice. - const Status s = agent->InsertKeyValue(kPreemptionNoticeKey, - absl::FormatTime(*death_time)); + const absl::Status s = agent->InsertKeyValue( + kPreemptionNoticeKey, absl::FormatTime(*death_time)); LOG(INFO) << "Notified coordination service that this task will " "be preempted at " << *death_time << ". Status: " << s; @@ -159,7 +161,7 @@ Status PreemptionSyncManagerImpl::Initialize( */ call_opts_ = agent_->GetKeyValueAsync( kPreemptionNoticeKey, - [this, agent = agent_](StatusOr status_or_death_time) { + [this, agent = agent_](absl::StatusOr status_or_death_time) { if (errors::IsCancelled(status_or_death_time.status())) { // The agent cancels pending GetKeyValue RPCs because of shutdown, // so simply log and return. @@ -177,7 +179,7 @@ Status PreemptionSyncManagerImpl::Initialize( // CancelPreemptionBarrier() cannot be used because this may be // triggered after preemption sync manager has been destroyed. agent->CancelBarrierAsync( - kPreemptionBarrier, [](const Status& status) { + kPreemptionBarrier, [](const absl::Status& status) { if (!status.ok()) { LOG(ERROR) << "Failed to cancel preemption barrier: " << status; @@ -205,7 +207,7 @@ Status PreemptionSyncManagerImpl::Initialize( death_time))); }); - return OkStatus(); + return absl::OkStatus(); } void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { @@ -231,7 +233,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { // `preemption_sync_counter_` or the protocol failed. This ensures correctness // of the preemption sync protocol. mutex_lock l(mu_); - const Status notified_status = agent_->InsertKeyValue( + const absl::Status notified_status = agent_->InsertKeyValue( current_call_counter_key_, std::to_string(call_counter_)); if (!notified_status.ok()) { LOG(ERROR) << "Preemption sync failed - could not inform service of " @@ -243,7 +245,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { // 3. Impose a barrier to wait until everybody sends their current call // counter. - const Status barrier_status = + const absl::Status barrier_status = agent_->WaitAtBarrier(kPreemptionBarrier, kPreemptionBarrierTimeout, {}); if (!barrier_status.ok()) { LOG(ERROR) << "Preemption sync barrier failed: " << barrier_status; @@ -251,7 +253,7 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { } // 4. Retrieve every task's current call counter. - StatusOr> all_counters = + absl::StatusOr> all_counters = agent_->GetKeyValueDir(kPreemptionCounterDirKey); if (!all_counters.ok()) { LOG(ERROR) << "Preemption sync failed - unable to retrieve call counters: " @@ -287,11 +289,12 @@ void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) { } void PreemptionSyncManagerImpl::CancelPreemptionBarrier() { - agent_->CancelBarrierAsync(kPreemptionBarrier, [](const Status& status) { - if (!status.ok()) { - LOG(ERROR) << "Failed to cancel preemption barrier: " << status; - } - }); + agent_->CancelBarrierAsync( + kPreemptionBarrier, [](const absl::Status& status) { + if (!status.ok()) { + LOG(ERROR) << "Failed to cancel preemption barrier: " << status; + } + }); } bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) { diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h index 2c359b686ffc5..baf1911cac2d6 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager.h @@ -35,11 +35,13 @@ class PreemptionSyncManager { public: virtual ~PreemptionSyncManager() = default; - virtual Status Initialize(CoordinationServiceAgent* agent) = 0; - virtual Status Initialize(CoordinationServiceAgent* agent, - const std::string& preemption_notifier_type) = 0; - virtual Status Initialize(CoordinationServiceAgent* agent, - std::unique_ptr notifier) = 0; + virtual absl::Status Initialize(CoordinationServiceAgent* agent) = 0; + virtual absl::Status Initialize( + CoordinationServiceAgent* agent, + const std::string& preemption_notifier_type) = 0; + virtual absl::Status Initialize( + CoordinationServiceAgent* agent, + std::unique_ptr notifier) = 0; // Check if the synchronized point has been reached. When a task has been // preempted, a safe sync point will be determined by using the fastest task's diff --git a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index 4caed02d705ad..82d578c2b9658 100644 --- a/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -158,7 +158,7 @@ class PreemptionSyncManagerTest : public ::testing::Test { std::unique_ptr coord_client2 = absl::WrapUnique(NewGrpcCoordinationClient( grpc_server_->InProcessChannel(::grpc::ChannelArguments()))); - auto error_fn = [](const Status& status) { + auto error_fn = [](const absl::Status& status) { LOG(ERROR) << "Coordination service agent in error status: " << status; }; CoordinationServiceConfig coord_config; diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/BUILD b/third_party/tsl/tsl/distributed_runtime/rpc/BUILD index 9b52ee1cb0167..9be54e18af36d 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/tsl/tsl/distributed_runtime/rpc/BUILD @@ -1,14 +1,14 @@ # Description: # RPC communication interfaces and implementations for TensorFlow. -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl/platform:build_config.bzl", "tf_proto_library", "tsl_cc_test") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl/platform:build_config.bzl", "tf_proto_library", "tsl_cc_test") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl:internal", ]), licenses = ["notice"], @@ -37,7 +37,6 @@ cc_library( hdrs = ["grpc_util.h"], deps = [ "//tsl/platform:protobuf", - "//tsl/platform:random", "//tsl/platform:status", "//tsl/platform:stringpiece", "//tsl/platform:stringprintf", @@ -57,6 +56,7 @@ tsl_cc_test( deps = [ ":grpc_util", ":test_request_proto_cc_impl", + "//tsl/platform:env_impl", "//tsl/platform:errors", "//tsl/platform:test", "//tsl/platform:test_benchmark", @@ -95,8 +95,8 @@ cc_library( "//tsl/platform:thread_annotations", "//tsl/platform:types", "//tsl/protobuf:rpc_options_proto_cc", - "//tsl/util:device_name_utils", "@com_google_absl//absl/strings", + "@xla//xla/tsl/util:device_name_utils", ] + tsl_grpc_cc_dependencies(), ) @@ -109,11 +109,12 @@ tsl_cc_test( deps = [ ":grpc_channel", "//tsl/lib/core:status_test_util", + "//tsl/platform:env_impl", "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", "//tsl/protobuf:rpc_options_proto_cc_impl", - "//tsl/util:device_name_utils", + "@xla//xla/tsl/util:device_name_utils", ], ) @@ -128,8 +129,8 @@ cc_library( "//tsl/platform:errors", "//tsl/platform:status", "//tsl/platform:strcat", - "//tsl/util:env_var", "@com_google_absl//absl/status", + "@xla//xla/tsl/util:env_var", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD b/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD index 5ba50b9d77ac5..6e9f5ca0488bb 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD +++ b/third_party/tsl/tsl/distributed_runtime/rpc/coordination/BUILD @@ -1,10 +1,10 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl:internal", ]), licenses = ["notice"], diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc index ba886c1a7bf1d..ba12449f03bf2 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/str_split.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/distributed_runtime/rpc/grpc_channel_common.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" #include "tsl/protobuf/rpc_options.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { @@ -49,10 +49,10 @@ string MakeAddress(const string& job, int replica, int task) { } // Allows the host to be a raw IP (either v4 or v6). -Status ValidateHostPortPair(const string& host_port) { +absl::Status ValidateHostPortPair(const string& host_port) { string bns_prefix = "/bns/"; if (host_port.substr(0, bns_prefix.length()) == bns_prefix) { - return OkStatus(); + return absl::OkStatus(); } uint32 port; auto colon_index = host_port.find_last_of(':'); @@ -61,7 +61,7 @@ Status ValidateHostPortPair(const string& host_port) { return errors::InvalidArgument("Could not interpret \"", host_port, "\" as a host-port pair."); } - return OkStatus(); + return absl::OkStatus(); } ::grpc::ChannelArguments* CreateDefaultChannelArguments() { @@ -140,21 +140,22 @@ ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) { return args; } -Status NewHostPortGrpcChannel(const string& target, - const RPCOptions* rpc_options, - SharedGrpcChannelPtr* channel_pointer) { +absl::Status NewHostPortGrpcChannel(const string& target, + const RPCOptions* rpc_options, + SharedGrpcChannelPtr* channel_pointer) { // Minimally ensure that the target is valid TF_RETURN_IF_ERROR(ValidateHostPortPair(target)); ::grpc::ChannelArguments args = GetChannelArguments(rpc_options); *channel_pointer = ::grpc::CreateCustomChannel( "dns:///" + target, ::grpc::InsecureChannelCredentials(), args); - return OkStatus(); + return absl::OkStatus(); } ChannelCreationFunction ConvertToChannelCreationFunction( - const std::function& new_channel_func_ptr) { + const std::function& + new_channel_func_ptr) { return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr { SharedGrpcChannelPtr channel_ptr; if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr) @@ -166,7 +167,7 @@ ChannelCreationFunction ConvertToChannelCreationFunction( }; } -Status GrpcChannelSpec::AddHostPortsJob( +absl::Status GrpcChannelSpec::AddHostPortsJob( const string& job_id, const std::map& host_ports) { if (!job_ids_.insert(job_id).second) { return errors::InvalidArgument( @@ -176,7 +177,7 @@ Status GrpcChannelSpec::AddHostPortsJob( TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second)); } host_ports_jobs_.emplace_back(job_id, host_ports); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h index b019377f9986d..654e7aa91c321 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.h @@ -43,8 +43,8 @@ class GrpcChannelSpec { const std::map host_ports; }; - Status AddHostPortsJob(const string& job_id, - const std::map& host_ports); + absl::Status AddHostPortsJob(const string& job_id, + const std::map& host_ports); const std::vector& host_ports_jobs() const { return host_ports_jobs_; @@ -88,12 +88,13 @@ GrpcChannelCache* NewGrpcChannelCache( ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options); ChannelCreationFunction ConvertToChannelCreationFunction( - const std::function& new_channel_func_ptr); + const std::function& + new_channel_func_ptr); -Status NewHostPortGrpcChannel(const string& target, - const RPCOptions* rpc_options, - SharedGrpcChannelPtr* channel_pointer); +absl::Status NewHostPortGrpcChannel(const string& target, + const RPCOptions* rpc_options, + SharedGrpcChannelPtr* channel_pointer); } // namespace tsl diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 6b2d330cb1d57..adc0df2b89dde 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include +#include "xla/tsl/util/device_name_utils.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" #include "tsl/protobuf/rpc_options.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { #define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h index 37b41edc0a010..21d8f2df5099e 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h @@ -23,6 +23,7 @@ limitations under the License. #include "grpcpp/generic/generic_stub.h" #include "grpcpp/grpcpp.h" #include "absl/status/status.h" +#include "xla/tsl/util/env_var.h" #include "tsl/distributed_runtime/call_options.h" #include "tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "tsl/distributed_runtime/rpc/grpc_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/strcat.h" #include "tsl/platform/threadpool.h" -#include "tsl/util/env_var.h" namespace tsl { @@ -149,7 +149,7 @@ class RPCState : public GrpcClientCQTag { VLOG(2) << "Completed call: " << method_; - Status s = FromGrpcStatus(status_); + absl::Status s = FromGrpcStatus(status_); if (s.ok() && !ok) { // Since this function is only being used for processing the response // to Finish for client-side unary calls, ok should never be false @@ -206,7 +206,7 @@ class RPCState : public GrpcClientCQTag { } void ParseAndCallDone() { - Status s; + absl::Status s; if (!parse_proto_fn_(&response_buf_, response_)) { s.Update(errors::Internal("could not parse rpc response")); } diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.cc b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.cc index 39042ce77cb3e..3df92d1418f98 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.cc +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.cc @@ -16,76 +16,25 @@ limitations under the License. #include "tsl/distributed_runtime/rpc/grpc_util.h" #include -#include #include #include "grpcpp/impl/codegen/proto_utils.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/random.h" namespace tsl { -namespace { - -double GenerateUniformRandomNumber() { - return random::New64() * (1.0 / std::numeric_limits::max()); -} - -double GenerateUniformRandomNumberBetween(double a, double b) { - if (a == b) return a; - DCHECK_LT(a, b); - return a + GenerateUniformRandomNumber() * (b - a); -} - -} // namespace - -int64_t ComputeBackoffMicroseconds(int current_retry_attempt, int64_t min_delay, - int64_t max_delay) { - DCHECK_GE(current_retry_attempt, 0); - - // This function with the constants below is calculating: - // - // (0.4 * min_delay) + (random[0.6,1.0] * min_delay * 1.3^retries) - // - // Note that there is an extra truncation that occurs and is documented in - // comments below. - constexpr double kBackoffBase = 1.3; - constexpr double kBackoffRandMult = 0.4; - - // This first term does not vary with current_retry_attempt or a random - // number. It exists to ensure the final term is >= min_delay - const double first_term = kBackoffRandMult * min_delay; - - // This is calculating min_delay * 1.3^retries - double uncapped_second_term = min_delay; - while (current_retry_attempt > 0 && - uncapped_second_term < max_delay - first_term) { - current_retry_attempt--; - uncapped_second_term *= kBackoffBase; - } - // Note that first_term + uncapped_second_term can exceed max_delay here - // because of the final multiply by kBackoffBase. We fix that problem with - // the min() below. - double second_term = std::min(uncapped_second_term, max_delay - first_term); - - // This supplies the random jitter to ensure that retried don't cause a - // thundering herd problem. - second_term *= - GenerateUniformRandomNumberBetween(1.0 - kBackoffRandMult, 1.0); - - return std::max(static_cast(first_term + second_term), min_delay); -} - ::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src, grpc::ByteBuffer* dst) { bool own_buffer; - return ::grpc::GenericSerialize<::grpc::ProtoBufferWriter, - protobuf::Message>(src, dst, &own_buffer); + // grpc::ProtoBufferWriter + return ::grpc::SerializationTraits::Serialize(src, dst, + &own_buffer); } bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst) { - ::grpc::ProtoBufferReader reader(src); - return dst->ParseFromZeroCopyStream(&reader); + // grpc::ProtoBufferReader + return ::grpc::SerializationTraits::Deserialize(src, dst) + .ok(); } // GrpcMaybeUnparseProto from a string simply copies the string to the diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h index 2268124318559..b10fff85a003e 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util.h @@ -37,15 +37,6 @@ namespace tsl { constexpr char kGrpcPayloadsLost[] = "type.googleapis.com/tensorflow.distributed_runtime.GrpcPayloadsLost"; -// Given the total number of RPC retries attempted, return a randomized -// amount of time to delay before retrying the request. -// -// The average computed backoff increases with the number of RPCs attempted. -// See implementation for details on the calculations. -int64_t ComputeBackoffMicroseconds(int current_retry_attempt, - int64_t min_delay = 1000, - int64_t max_delay = 10000000); - constexpr char kStreamRemovedMessage[] = "Stream removed"; // Identify if the given grpc::Status corresponds to an HTTP stream removed @@ -61,7 +52,7 @@ inline bool IsStreamRemovedError(const ::grpc::Status& s) { s.error_message() == kStreamRemovedMessage; } -inline std::string SerializePayloads(const Status& s) { +inline std::string SerializePayloads(const absl::Status& s) { tensorflow::distributed_runtime::GrpcPayloadContainer container; s.ForEachPayload([&container](StringPiece key, const absl::Cord& value) { (*container.mutable_payloads())[std::string(key)] = std::string(value); @@ -69,7 +60,7 @@ inline std::string SerializePayloads(const Status& s) { return container.SerializeAsString(); } -inline void InsertSerializedPayloads(Status& s, std::string payloads) { +inline void InsertSerializedPayloads(absl::Status& s, std::string payloads) { tensorflow::distributed_runtime::GrpcPayloadContainer container; if (container.ParseFromString(payloads)) { for (const auto& key_val : container.payloads()) { @@ -82,24 +73,25 @@ inline void InsertSerializedPayloads(Status& s, std::string payloads) { } } -inline Status FromGrpcStatus(const ::grpc::Status& s) { +inline absl::Status FromGrpcStatus(const ::grpc::Status& s) { if (s.ok()) { - return OkStatus(); + return absl::OkStatus(); } else { - Status converted; + absl::Status converted; // Convert "UNKNOWN" stream removed errors into unavailable, to allow // for retry upstream. if (IsStreamRemovedError(s)) { - converted = Status(absl::StatusCode::kUnavailable, s.error_message()); + converted = + absl::Status(absl::StatusCode::kUnavailable, s.error_message()); } - converted = Status(static_cast(s.error_code()), - s.error_message()); + converted = absl::Status(static_cast(s.error_code()), + s.error_message()); InsertSerializedPayloads(converted, s.error_details()); return converted; } } -inline ::grpc::Status ToGrpcStatus(const Status& s) { +inline ::grpc::Status ToGrpcStatus(const absl::Status& s) { if (s.ok()) { return ::grpc::Status::OK; } else { diff --git a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc index 06e5228731e41..9872c1cf705a0 100644 --- a/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc +++ b/third_party/tsl/tsl/distributed_runtime/rpc/grpc_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tsl/distributed_runtime/rpc/grpc_util.h" #include +#include #include #include "grpcpp/grpcpp.h" @@ -70,9 +71,9 @@ TestRequest MakeProto(int size) { } TEST(PayloadSerialization, PayloadsAreTransmitted) { - Status status = errors::InvalidArgument("invalid arg message"); + absl::Status status = errors::InvalidArgument("invalid arg message"); status.SetPayload("a", absl::Cord("\\xFF\\x02\\x03")); - Status status_recovered = FromGrpcStatus(ToGrpcStatus(status)); + absl::Status status_recovered = FromGrpcStatus(ToGrpcStatus(status)); ASSERT_TRUE(status_recovered.GetPayload("a").has_value()); EXPECT_EQ(status_recovered.GetPayload("a").value(), "\\xFF\\x02\\x03"); @@ -83,7 +84,7 @@ TEST(PayloadSerialization, PayloadsCorrupted) { ::grpc::StatusCode::INVALID_ARGUMENT, "invalid arg message", "string that can not be serialized to the GrpcPayloadContainer proto"); - Status converted = FromGrpcStatus(status); + absl::Status converted = FromGrpcStatus(status); EXPECT_TRUE(converted.GetPayload(kGrpcPayloadsLost).has_value()); } diff --git a/third_party/tsl/tsl/framework/BUILD b/third_party/tsl/tsl/framework/BUILD index e055a619f5ac6..1947588e6e656 100644 --- a/third_party/tsl/tsl/framework/BUILD +++ b/third_party/tsl/tsl/framework/BUILD @@ -4,8 +4,12 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load( + "//tsl/platform:build_config.bzl", + "tsl_cc_test", +) load( "//tsl/platform:build_config_root.bzl", "if_static", @@ -14,10 +18,6 @@ load( "//tsl/platform:rules_cc.bzl", "cc_library", ) -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -82,7 +82,7 @@ filegroup( "tracking_allocator.h", "type_traits.h", ], - visibility = set_external_visibility(["//tensorflow/core:__subpackages__"]), + visibility = internal_visibility(["//tensorflow/core:__subpackages__"]), ) # Files needed for tf2xla build. @@ -154,8 +154,8 @@ cc_library( "cpu_allocator_impl.cc", "tracking_allocator.h", ], - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", + visibility = internal_visibility([ + "@xla//xla:__subpackages__", "//tensorflow/core:__subpackages__", "//tsl:__subpackages__", ]), @@ -273,10 +273,10 @@ cc_library( "//tsl/platform:status", "//tsl/platform:statusor", "//tsl/platform:str_util", - "//tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@xla//xla/tsl/util:device_name_utils", ], ) @@ -291,7 +291,7 @@ filegroup( cc_library( name = "numeric_types", hdrs = ["numeric_types.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/compiler:__subpackages__", "//tensorflow/core:__subpackages__", ]), @@ -333,7 +333,7 @@ cc_library( cc_library( name = "type_traits", hdrs = ["type_traits.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/framework:__pkg__", ]), deps = [ @@ -347,7 +347,7 @@ filegroup( srcs = [ "cancellation.h", ], - visibility = set_external_visibility(["//tensorflow/core:__subpackages__"]), + visibility = internal_visibility(["//tensorflow/core:__subpackages__"]), ) cc_library( @@ -374,6 +374,35 @@ cc_library( ], ) +cc_library( + name = "serving_device_selector", + srcs = ["serving_device_selector.cc"], + hdrs = ["serving_device_selector.h"], + visibility = ["//visibility:public"], + deps = [ + "//tsl/platform:logging", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "serving_device_selector_policies", + srcs = ["serving_device_selector_policies.cc"], + hdrs = ["serving_device_selector_policies.h"], + features = ["-layering_check"], + deps = [ + ":serving_device_selector", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "real_time_in_memory_metric", + hdrs = ["real_time_in_memory_metric.h"], +) + tsl_cc_test( name = "cancellation_test", size = "small", @@ -403,7 +432,7 @@ exports_files( "shared_counter.h", "tracking_allocator.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/common_runtime:__pkg__", "//tensorflow/core/common_runtime/gpu:__pkg__", @@ -436,6 +465,16 @@ tsl_cc_test( "//tsl/platform:status_matchers", "//tsl/platform:test_main", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/util:device_name_utils", + "@xla//xla/tsl/util:device_name_utils", + ], +) + +tsl_cc_test( + name = "real_time_in_memory_metric_test", + srcs = ["real_time_in_memory_metric_test.cc"], + deps = [ + ":real_time_in_memory_metric", + "//tsl/platform:test", + "//tsl/platform:test_main", ], ) diff --git a/third_party/tsl/tsl/framework/bfc_allocator.cc b/third_party/tsl/tsl/framework/bfc_allocator.cc index 9e4447108a921..79a8ea7c892d0 100644 --- a/third_party/tsl/tsl/framework/bfc_allocator.cc +++ b/third_party/tsl/tsl/framework/bfc_allocator.cc @@ -135,24 +135,19 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes); size_t bytes_received; void* mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); - if (mem_addr == nullptr && !started_backpedal_) { - // Only backpedal once. - started_backpedal_ = true; - + if (mem_addr == nullptr) { static constexpr float kBackpedalFactor = 0.9; // Try allocating less memory. while (mem_addr == nullptr) { bytes = RoundedBytes(bytes * kBackpedalFactor); - if (bytes < rounded_bytes) break; + if (bytes < rounded_bytes) { + return false; + } mem_addr = sub_allocator_->Alloc(alignment, bytes, &bytes_received); } } - if (mem_addr == nullptr) { - return false; - } - if (!increased_allocation) { // Increase the region size of the next required allocation. curr_region_allocation_bytes_ *= 2; @@ -311,6 +306,8 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes, } }(); VLOG(3) << "AllocateRaw " << Name() << " " << num_bytes << " " << result; + VLOG(4) << "[mem-debug] AllocateRaw," << Name() << "," << num_bytes << "," + << result << "," << tsl::CurrentStackTrace(); return result; } @@ -684,6 +681,9 @@ void BFCAllocator::SplitChunk(BFCAllocator::ChunkHandle h, size_t num_bytes) { void BFCAllocator::DeallocateRaw(void* ptr) { VLOG(3) << "DeallocateRaw " << Name() << " " << (ptr ? RequestedSize(ptr) : 0); + VLOG(4) << "[mem-debug] DeallocateRaw," << Name() << "," + << (ptr ? RequestedSize(ptr) : 0) << "," << ptr << "," + << tsl::CurrentStackTrace(); DeallocateRawInternal(ptr); retry_helper_.NotifyDealloc(); } diff --git a/third_party/tsl/tsl/framework/bfc_allocator.h b/third_party/tsl/tsl/framework/bfc_allocator.h index 47619856abe8d..76921c5f04a79 100644 --- a/third_party/tsl/tsl/framework/bfc_allocator.h +++ b/third_party/tsl/tsl/framework/bfc_allocator.h @@ -579,10 +579,6 @@ class BFCAllocator : public Allocator { // The size of the current region allocation. size_t curr_region_allocation_bytes_; - // An indicator that expansion of a region has hit the limits - // of the available memory. - bool started_backpedal_ = false; - // Whether the allocator will coalesce adjacent sub allocator provided // AllocationRegions. This may be disabled if discrete sub allocator // regions can't be treated as contiguous (e.g. if the allocation refers to diff --git a/third_party/tsl/tsl/framework/contraction/BUILD b/third_party/tsl/tsl/framework/contraction/BUILD index b52a7ae315d38..71c0d19971f8e 100644 --- a/third_party/tsl/tsl/framework/contraction/BUILD +++ b/third_party/tsl/tsl/framework/contraction/BUILD @@ -1,7 +1,7 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/third_party/tsl/tsl/framework/convolution/BUILD b/third_party/tsl/tsl/framework/convolution/BUILD index 1becabc7ac388..c8a8ab4439bba 100644 --- a/third_party/tsl/tsl/framework/convolution/BUILD +++ b/third_party/tsl/tsl/framework/convolution/BUILD @@ -1,9 +1,9 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/third_party/tsl/tsl/framework/device_id_utils.h b/third_party/tsl/tsl/framework/device_id_utils.h index c2479aded5fe0..e814e68c8530a 100644 --- a/third_party/tsl/tsl/framework/device_id_utils.h +++ b/third_party/tsl/tsl/framework/device_id_utils.h @@ -20,11 +20,11 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/framework/device_id.h" #include "tsl/framework/device_type.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/device_name_utils.h" namespace tsl { diff --git a/third_party/tsl/tsl/framework/device_id_utils_test.cc b/third_party/tsl/tsl/framework/device_id_utils_test.cc index ddf7cdd479935..21e574f95c1b2 100644 --- a/third_party/tsl/tsl/framework/device_id_utils_test.cc +++ b/third_party/tsl/tsl/framework/device_id_utils_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "xla/tsl/util/device_name_utils.h" #include "tsl/framework/device_id_manager.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" -#include "tsl/util/device_name_utils.h" namespace tsl { namespace { diff --git a/third_party/tsl/tsl/framework/fixedpoint/BUILD b/third_party/tsl/tsl/framework/fixedpoint/BUILD index ac56f82270aa7..4aef243762770 100644 --- a/third_party/tsl/tsl/framework/fixedpoint/BUILD +++ b/third_party/tsl/tsl/framework/fixedpoint/BUILD @@ -1,5 +1,6 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -47,10 +48,10 @@ filegroup( "TypeCastingAVX512.h", ], compatible_with = get_compatible_with_portable(), - visibility = [ + visibility = internal_visibility([ "//tensorflow:__subpackages__", "//tsl:internal", - ], + ]), ) # Files needed for core:mobile_srcs_no_runtime. diff --git a/third_party/tsl/tsl/framework/real_time_in_memory_metric.h b/third_party/tsl/tsl/framework/real_time_in_memory_metric.h new file mode 100644 index 0000000000000..0671bf35e69ed --- /dev/null +++ b/third_party/tsl/tsl/framework/real_time_in_memory_metric.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_FRAMEWORK_REAL_TIME_IN_MEMORY_METRIC_H_ +#define TENSORFLOW_TSL_FRAMEWORK_REAL_TIME_IN_MEMORY_METRIC_H_ + +#include + +namespace tsl { + +// Represents a metric with backing storage in local RAM, for exporting real +// time metrics for consumers that live in the same process. It currently only +// supports a simple scalar value. The implementation of this class is lossy but +// minimizes overhead, because there is usually no requirement for metrics +// consumer to get the exact value for any specific time point, but the metrics +// update is usually placed at the critical path of some request. This class is +// a replacement for streamz metric for the above described use case, not +// complimentary. +// +// This class is thread-safe. +// +// NOTE: Only integer and floating point values are supported. +template +class RealTimeInMemoryMetric { + public: + RealTimeInMemoryMetric() : value_(T{0}) {} + + // Gets the current value of this metric. + T Get() const { return value_.load(std::memory_order_relaxed); } + + // Updates the current value of this metric. + void Set(T new_value) { value_.store(new_value, std::memory_order_relaxed); } + + RealTimeInMemoryMetric(const RealTimeInMemoryMetric&) = delete; + RealTimeInMemoryMetric& operator=(const RealTimeInMemoryMetric&) = delete; + RealTimeInMemoryMetric(RealTimeInMemoryMetric&&) = delete; + RealTimeInMemoryMetric& operator=(RealTimeInMemoryMetric&&) = delete; + + static_assert(std::is_arithmetic_v); + + private: + std::atomic value_; +}; + +} // namespace tsl + +#endif // TENSORFLOW_TSL_FRAMEWORK_REAL_TIME_IN_MEMORY_METRIC_H_ diff --git a/third_party/tsl/tsl/framework/real_time_in_memory_metric_test.cc b/third_party/tsl/tsl/framework/real_time_in_memory_metric_test.cc new file mode 100644 index 0000000000000..87d7d9d51e3bd --- /dev/null +++ b/third_party/tsl/tsl/framework/real_time_in_memory_metric_test.cc @@ -0,0 +1,32 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/framework/real_time_in_memory_metric.h" + +#include + +#include "tsl/platform/test.h" + +namespace tsl { +namespace { + +TEST(RealTimeInMemoryMetric, SetAndGet) { + RealTimeInMemoryMetric m; + EXPECT_EQ(m.Get(), 0); + m.Set(100); + EXPECT_EQ(m.Get(), 100); +} + +} // namespace +} // namespace tsl diff --git a/third_party/tsl/tsl/framework/serving_device_selector.cc b/third_party/tsl/tsl/framework/serving_device_selector.cc new file mode 100644 index 0000000000000..e03e77eeaeb6b --- /dev/null +++ b/third_party/tsl/tsl/framework/serving_device_selector.cc @@ -0,0 +1,161 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/framework/serving_device_selector.h" + +#include +#include +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +inline constexpr int kHighPriority = 0; + +DeviceReservation::DeviceReservation(int device_index, + ServingDeviceSelector* device_selector) + : device_index_(device_index), device_selector_(device_selector) {} + +DeviceReservation::~DeviceReservation() { reset(); } + +void DeviceReservation::reset() { + if (device_selector_) device_selector_->FreeDeviceReservation(*this); + device_selector_ = nullptr; +} + +DeviceReservation::DeviceReservation(DeviceReservation&& r) + : device_index_{r.device_index_}, device_selector_{r.device_selector_} { + r.device_selector_ = nullptr; +} + +DeviceReservation& DeviceReservation::operator=(DeviceReservation&& r) { + if (this == &r) return *this; + + if (device_selector_) device_selector_->FreeDeviceReservation(*this); + + device_index_ = r.device_index_; + device_selector_ = r.device_selector_; + r.device_selector_ = nullptr; + return *this; +} + +/*static*/ void ServingDeviceSelector::CompletedHelper( + DeviceState& device_state, int32_t device_index, int32_t priority, + std::optional& min_exec_time, bool had_error, int64_t now_ns) { + // Check that priority 'priority' queue is non-empty. + DCHECK(!device_state.enqueued_programs[priority].empty()); + auto& program_info = device_state.enqueued_programs[priority].front(); + auto prefetch_results = program_info.prefetch_results; + auto execution_info = program_info.execution_info; + device_state.enqueued_programs[priority].pop_front(); + // To make tracked execution time as accurate as possible, we only record this + // execution time if two programs ran back-to-back without host round trip. + if (!device_state.timer_reset && !had_error) { + VLOG(4) << "Complete. update device[" << device_index + << "], priority: " << priority + << ", prefetch: " << static_cast(prefetch_results) + << ", time: " << now_ns - device_state.last_started_ns; + const_cast(execution_info) + ->AddTime(now_ns - device_state.last_started_ns, prefetch_results); + // Only update min_exec_time_ when running_average is updated. This avoids + // the case where running_average is zero. + if (!min_exec_time.has_value() || + execution_info->GetTime(prefetch_results) < min_exec_time.value()) { + min_exec_time = execution_info->GetTime(prefetch_results); + } + } + // If there are remaining programs, update the start time. + if (!device_state.enqueued_programs.empty()) { + device_state.last_started_ns = now_ns; + device_state.timer_reset = false; + } +} + +/*static*/ int64_t ServingDeviceSelector::EstimateTimeTillIdleNs( + const DeviceState& device_state, int32_t priority, int64_t min_exec_time, + int64_t now_ns) { + int64_t ns_till_idle = 0; + // Add time from each program in queues with priority 'priority' or higher. + for (int32_t i = 0; i <= priority; i++) { + for (auto& info : device_state.enqueued_programs[i]) { + ns_till_idle += + info.execution_info->MaybeGetValidTime(info.prefetch_results); + } + } + // Accounts for the elapsed time of the currently running but unfinished + // program (i.e., enqueued programs). + if (ns_till_idle > 0) { + DCHECK_GT(device_state.last_started_ns, 0); + ns_till_idle = std::max( + 0, ns_till_idle - (now_ns - device_state.last_started_ns)); + } + + // Add time from scheduled programs with priority 'priority' or higher + int64_t ns_of_schedule_programs = 0; + for (int32_t i = 0; i <= priority; i++) { + for (auto& info : device_state.scheduled_programs[i]) { + ns_of_schedule_programs += std::max( + info.execution_info->MaybeGetValidTime(info.prefetch_results), + min_exec_time); + } + } + return ns_till_idle + ns_of_schedule_programs; +} +/*static*/ void ServingDeviceSelector::EnqueueHelper( + DeviceState& device_state, int32_t device_index, + ExecutionInfo& execution_info, absl::string_view fingerprint, + int32_t priority, int64_t req_id, size_t priority_queue_count, + int prefetch_results, int64_t now_ns) { + if (!device_state.scheduled_programs[priority].empty()) { + auto& program = device_state.scheduled_programs[priority].front(); + if (program.fingerprint.empty()) { + program.execution_info = &execution_info; + program.fingerprint = fingerprint; + if (priority == kHighPriority) { + device_state.last_fingerprint = fingerprint; + } + device_state.unknown_fingerprint_requests--; + } + device_state.enqueued_programs[static_cast(priority)].push_back( + std::move(program)); + device_state.scheduled_programs[static_cast(priority)].pop_front(); + } else { + DeviceState::ProgramInfo program; + program.execution_info = &execution_info; + program.fingerprint = fingerprint; + program.req_id = req_id; + program.priority = priority; + program.prefetch_results = prefetch_results; + device_state.enqueued_programs[priority].push_back(program); + device_state.last_fingerprint = fingerprint; + } + + // Count number of programs in enqueued_programs queues. + int64_t num_programs_enqueued = 0; + for (int64_t i = 0; i < priority_queue_count; i++) { + num_programs_enqueued += device_state.enqueued_programs[i].size(); + } + + if (num_programs_enqueued == 1) { + device_state.last_started_ns = now_ns; + device_state.timer_reset = true; + } +} +} // namespace tsl diff --git a/third_party/tsl/tsl/framework/serving_device_selector.h b/third_party/tsl/tsl/framework/serving_device_selector.h new file mode 100644 index 0000000000000..426a43e4f1326 --- /dev/null +++ b/third_party/tsl/tsl/framework/serving_device_selector.h @@ -0,0 +1,201 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ +#define TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ + +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +class ServingDeviceSelector; + +// A RAII type for device reservation. +class DeviceReservation { + public: + DeviceReservation(int device_index, ServingDeviceSelector* selector); + ~DeviceReservation(); + + DeviceReservation(const DeviceReservation&) = delete; + DeviceReservation& operator=(const DeviceReservation&) = delete; + + DeviceReservation(DeviceReservation&& r); + DeviceReservation& operator=(DeviceReservation&& r); + + int device_index() const { return device_index_; } + + void reset(); + + private: + int device_index_; + ServingDeviceSelector* device_selector_; +}; + +// Interface for runtime device selection for serving. +// NOTE: This interface is experimental and subject to change. +class ServingDeviceSelector { + public: + // Tracks the running average of certain program execution time. + class RunningAverage { + public: + void Add(int64_t value) { + DCHECK_GE(value, 0); + sum_ += value; + ++count_; + latency_ = sum_ / count_; + } + + int64_t Get() const { return latency_; } + + private: + int64_t sum_ = 0; + int64_t count_ = 0; + int64_t latency_ = 0; + }; + + // Tracks the program execution information, including execution time. + class ExecutionInfo { + public: + explicit ExecutionInfo(int64_t num_prefetch_result = 1) + : running_average_(num_prefetch_result) {} + + virtual ~ExecutionInfo() = default; + + void AddTime(int64_t value, int result) { + DCHECK_GE(value, 0); + DCHECK_LT(result, running_average_.size()); + running_average_.at(result).Add(value); + } + + int64_t GetTime(int result) const { + DCHECK_LT(result, running_average_.size()); + return running_average_.at(result).Get(); + } + + // To be conservative when one of the path is missing. + virtual int64_t MaybeGetValidTime(int result) const { + return GetTime(result); + } + + private: + // Records program average execution time, one for each prefetch result. + absl::FixedArray running_average_; + }; + + struct DeviceState { + explicit DeviceState(int64_t priority_count = 1) + : enqueued_programs(priority_count), + scheduled_programs(priority_count) {} + // TODO(b/295352859): Add more stats to track that are useful for the Policy + // to use when selecting a device. + struct ProgramInfo { + std::string fingerprint; + int32_t priority; + int64_t req_id = -1; + const ExecutionInfo* execution_info; + int prefetch_results; + }; + // A queue of enqueued programs, one for each priority level + absl::FixedArray> enqueued_programs; + // A queue of scheduled yet enqueued programs, one for each priority level. + // May or may not have fingerprint. + absl::FixedArray> scheduled_programs; + // Timestamp in nanoseconds of last started program. + int64_t last_started_ns = 0; + // Fingerprint of last enqueued high priority program. + std::string last_fingerprint; + // The number of scheduled not yet enqueued programs with unknown + // fingerprints. + int32_t unknown_fingerprint_requests; + // Whether execution timer was reset, true iff a program is enqueued while + // all queues (for all priorities) were empty. + bool timer_reset = true; + }; + + // Struct of all tracked device states, which will be passed to Policy. + struct DeviceStates { + absl::Span states; + }; + + // Policy used to select a device. + class Policy { + public: + virtual ~Policy() = default; + // Selects a device based on the tracked states of all devices. + virtual int SelectDevice(absl::string_view program_fingerprint, + const DeviceStates& device_states) = 0; + }; + + virtual ~ServingDeviceSelector() = default; + + // Reserves a device according to a given selection policy. The reserved + // device will be freed when the lifetime of the returned `DeviceReservation` + // object ends. + virtual DeviceReservation ReserveDevice( + absl::string_view program_fingerprint) = 0; + + // Enqueues a program on the given device. Used only for load tracking + // purposes when the device selection feature is unused. + virtual void Enqueue(int32_t device_index, absl::string_view fingerprint) = 0; + + // Marks the completion of a program on the given device. Used only for load + // tracking purposes when the device selection feature is unused. + virtual void Completed(int32_t device_index, bool had_error) = 0; + + protected: + // A helper function for Enqueue. The EnqueueHelper does the following things. + // 1. If there are programs in the scheduled_programs queue of the given + // priority, move the program to the corresponding enqueued_programs + // queue. Update the fingerprint if it is unknown. This is a typical TF1 + // use case. + // 2. If there are no programs in the scheduled_programs queue of the given + // priority, create the program of the fingerprint and place it in the + // corresponding enqueued_programs queue. + // This can happen in two cases: (1) TFRT that doesn't need + // scheduled_programs queue. (2) In TF1, Schedule() was not called prior + // to Enqueue(). + // This helper also updates last_started_ns and timer_reset. + static void EnqueueHelper(DeviceState& device_state, int32_t device_index, + ExecutionInfo& execution_info, + absl::string_view fingerprint, int32_t priority, + int64_t req_id, size_t priority_queue_count, + int prefetch_results, int64_t now_ns); + // A helper function tells a program has completed on the given device. + static void CompletedHelper(DeviceState& device_state, int32_t device_index, + int32_t priority, + std::optional& min_exec_time, + bool had_error, int64_t now_ns); + // Helper to estimate the time until the core becomes idle in nanoseconds. + // Only considers queues with priority at least as high as 'priority'. + static int64_t EstimateTimeTillIdleNs(const DeviceState& device_state, + int32_t priority, int64_t min_exec_time, + int64_t now_ns); + + private: + friend DeviceReservation; + + // Frees the given device reservation. + virtual void FreeDeviceReservation(const DeviceReservation& reservation) = 0; +}; + +} // namespace tsl + +#endif // TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_H_ diff --git a/third_party/tsl/tsl/framework/serving_device_selector_policies.cc b/third_party/tsl/tsl/framework/serving_device_selector_policies.cc new file mode 100644 index 0000000000000..7c074ff078018 --- /dev/null +++ b/third_party/tsl/tsl/framework/serving_device_selector_policies.cc @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/framework/serving_device_selector_policies.h" + +#include + +#include "absl/strings/string_view.h" +#include "tsl/framework/serving_device_selector.h" + +namespace tsl { + +int RoundRobinPolicy::SelectDevice( + absl::string_view program_fingerprint, + const ServingDeviceSelector::DeviceStates& device_states) { + const int num_devices = device_states.states.size(); + return ordinal_.fetch_add(1, std::memory_order_relaxed) % num_devices; +} + +} // namespace tsl diff --git a/third_party/tsl/tsl/framework/serving_device_selector_policies.h b/third_party/tsl/tsl/framework/serving_device_selector_policies.h new file mode 100644 index 0000000000000..638206bc1229c --- /dev/null +++ b/third_party/tsl/tsl/framework/serving_device_selector_policies.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ +#define TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ + +#include + +#include "tsl/framework/serving_device_selector.h" + +namespace tsl { + +enum class ServingDeviceSelectorPolicy { + kRoundRobin, +}; + +class RoundRobinPolicy : public ServingDeviceSelector::Policy { + public: + RoundRobinPolicy() : ordinal_(0) {} + + int SelectDevice( + absl::string_view program_fingerprint, + const ServingDeviceSelector::DeviceStates& device_states) override; + + private: + std::atomic ordinal_; +}; + +} // namespace tsl + +#endif // TENSORFLOW_TSL_FRAMEWORK_SERVING_DEVICE_SELECTOR_POLICIES_H_ diff --git a/third_party/tsl/tsl/lib/core/BUILD b/third_party/tsl/tsl/lib/core/BUILD index 2d7a3ca220f8c..9715dfefbb5c7 100644 --- a/third_party/tsl/tsl/lib/core/BUILD +++ b/third_party/tsl/tsl/lib/core/BUILD @@ -4,9 +4,9 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -28,7 +28,7 @@ filegroup( "bits.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -39,7 +39,7 @@ filegroup( "status_test_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/core:__pkg__", ]), @@ -51,7 +51,7 @@ filegroup( "bitmap_test.cc", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -61,7 +61,7 @@ filegroup( "bits.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -70,7 +70,7 @@ filegroup( "status_test_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/core:__pkg__", ]), diff --git a/third_party/tsl/tsl/lib/gtl/BUILD b/third_party/tsl/tsl/lib/gtl/BUILD index 093a7c1880e02..c9f2ac9f6cd13 100644 --- a/third_party/tsl/tsl/lib/gtl/BUILD +++ b/third_party/tsl/tsl/lib/gtl/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:build_config.bzl", @@ -11,7 +11,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", # tensorflow/core/lib/strings:proto_serialization uses on gtl:inlined_vector @@ -27,7 +27,7 @@ package( "//tensorflow/core/tfrt/utils:__pkg__", # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", - "//tensorflow/compiler/xla:__subpackages__", + "@xla//xla:__subpackages__", "//tensorflow/core/lib/gtl:__subpackages__", "//tsl/distributed_runtime/rpc:__pkg__", "//tsl/profiler/utils:__pkg__", @@ -58,8 +58,8 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ - "//tsl/platform:prefetch", "//tsl/platform:types", + "@com_google_absl//absl/base:prefetch", ], ) @@ -116,7 +116,7 @@ filegroup( "inlined_vector.h", "iterator_range.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), @@ -128,7 +128,7 @@ filegroup( "int_type.h", "map_util.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), @@ -138,7 +138,7 @@ filegroup( name = "legacy_lib_test_internal_headers", srcs = [ ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), @@ -148,7 +148,7 @@ filegroup( name = "legacy_android_gif_internal_headers", srcs = [ ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), @@ -162,7 +162,7 @@ filegroup( "flatrep.h", "inlined_vector.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", "//tsl:__subpackages__", @@ -178,7 +178,7 @@ filegroup( "map_util.h", "//tsl/lib/gtl/subtle:map_traits", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), @@ -197,7 +197,7 @@ filegroup( "map_util.h", "//tsl/lib/gtl/subtle:map_traits", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", ]), diff --git a/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/tsl/tsl/lib/gtl/flatrep.h index 9e29a772c3eca..dfc65844e68ed 100644 --- a/third_party/tsl/tsl/lib/gtl/flatrep.h +++ b/third_party/tsl/tsl/lib/gtl/flatrep.h @@ -20,7 +20,7 @@ limitations under the License. #include -#include "tsl/platform/prefetch.h" +#include "absl/base/prefetch.h" #include "tsl/platform/types.h" namespace tsl { @@ -214,8 +214,8 @@ class FlatRep { size_t index = (h >> 8) & mask_; // Holds bucket num and index-in-bucket uint32 bi = index & (kWidth - 1); Bucket* b = &array_[index >> kBase]; - port::prefetch(&b->marker[bi]); - port::prefetch(&b->storage.key[bi]); + absl::PrefetchToLocalCache(&b->marker[bi]); + absl::PrefetchToLocalCache(&b->storage.key[bi]); } inline void MaybeResize() { diff --git a/third_party/tsl/tsl/lib/gtl/subtle/BUILD b/third_party/tsl/tsl/lib/gtl/subtle/BUILD index c69e16f2cab6b..3e9bfe7a5d03e 100644 --- a/third_party/tsl/tsl/lib/gtl/subtle/BUILD +++ b/third_party/tsl/tsl/lib/gtl/subtle/BUILD @@ -1,6 +1,7 @@ # Description: # gtl subtle packages. +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") package( @@ -13,8 +14,8 @@ filegroup( srcs = [ "map_traits.h", ], - visibility = [ + visibility = internal_visibility([ "//tensorflow/core/lib/gtl/subtle:__pkg__", "//tsl/lib/gtl:__pkg__", - ], + ]), ) diff --git a/third_party/tsl/tsl/lib/hash/BUILD b/third_party/tsl/tsl/lib/hash/BUILD index 4a5bcc363aad8..7433988f1d9b7 100644 --- a/third_party/tsl/tsl/lib/hash/BUILD +++ b/third_party/tsl/tsl/lib/hash/BUILD @@ -1,10 +1,10 @@ -load("//tsl:tsl.bzl", "set_external_visibility") -load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl:tsl.bzl", "if_linux_x86_64", + "internal_visibility", "tsl_copts", ) +load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", @@ -16,7 +16,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ # tensorflow/tsl/lib/io/table_builder.cc uses crc functionality "//tsl/lib/io:__pkg__", # tensorflow/core/lib/hash aliases hash for now @@ -29,7 +29,6 @@ cc_library( name = "crc32c", srcs = [ "crc32c.cc", - "crc32c_accelerate.cc", ], hdrs = ["crc32c.h"], # -msse4.2 enables the use of crc32c compiler builtins. @@ -37,8 +36,10 @@ cc_library( deps = [ "//tsl/platform", "//tsl/platform:cord", - "//tsl/platform:raw_coding", "//tsl/platform:types", + "@com_google_absl//absl/crc:crc32c", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", ], ) @@ -48,9 +49,8 @@ filegroup( srcs = [ "crc32c.cc", "crc32c.h", - "crc32c_accelerate.cc", ], - visibility = set_external_visibility(["//tensorflow/core/lib/hash:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core/lib/hash:__pkg__"]), ) filegroup( @@ -58,7 +58,7 @@ filegroup( srcs = [ "crc32c.h", ], - visibility = set_external_visibility(["//tensorflow/core/lib/hash:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core/lib/hash:__pkg__"]), ) tsl_cc_test( @@ -71,5 +71,7 @@ tsl_cc_test( "//tsl/platform:test", "//tsl/platform:test_benchmark", "//tsl/platform:test_main", + "//tsl/platform:types", + "@com_google_absl//absl/strings:cord", ], ) diff --git a/third_party/tsl/tsl/lib/hash/crc32c.cc b/third_party/tsl/tsl/lib/hash/crc32c.cc index 91edb249f9634..1bd005b6b0529 100644 --- a/third_party/tsl/tsl/lib/hash/crc32c.cc +++ b/third_party/tsl/tsl/lib/hash/crc32c.cc @@ -19,250 +19,14 @@ limitations under the License. #include "tsl/lib/hash/crc32c.h" #include -#include "tsl/platform/raw_coding.h" + +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/types.h" namespace tsl { namespace crc32c { -extern bool CanAccelerate(); -extern uint32_t AcceleratedExtend(uint32_t crc, const char *buf, size_t size); - -static const uint32 table0_[256] = { - 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, - 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, - 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c, - 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, - 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, - 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, - 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512, - 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, - 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, - 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, - 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf, - 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, - 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, - 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, - 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f, - 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, - 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, - 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, - 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e, - 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, - 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, - 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, - 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4, - 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, - 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, - 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, - 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5, - 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, - 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, - 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, - 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905, - 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, - 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, - 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, - 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8, - 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, - 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, - 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, - 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6, - 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, - 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, - 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, - 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351}; -static const uint32 table1_[256] = { - 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab, - 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21, - 0xd39ea264, 0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6, - 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4, - 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92, - 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b, - 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd, - 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff, - 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28, - 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2, - 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73, - 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41, - 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17, - 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c, - 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a, - 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78, - 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad, - 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27, - 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0, - 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2, - 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94, - 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260, - 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136, - 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004, - 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3, - 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059, - 0x1798a91c, 0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c, - 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be, - 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8, - 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3, - 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5, - 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287, - 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556, - 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc, - 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b, - 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439, - 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f, - 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766, - 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430, - 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502, - 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5, - 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f, - 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483}; -static const uint32 table2_[256] = { - 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664, - 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6, - 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4, - 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3, - 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b, - 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67, - 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af, - 0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8, - 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa, - 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828, - 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7, - 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0, - 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878, - 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20, - 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8, - 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff, - 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9, - 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b, - 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439, - 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e, - 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6, - 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730, - 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8, - 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def, - 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad, - 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f, - 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29, - 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e, - 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6, - 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae, - 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66, - 0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71, - 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe, - 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c, - 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e, - 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79, - 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1, - 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd, - 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035, - 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622, - 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760, - 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2, - 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8}; -static const uint32 table3_[256] = { - 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b, - 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf, - 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85, - 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7, - 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990, - 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2, - 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5, - 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7, - 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd, - 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69, - 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f, - 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 0xb2a6e72d, - 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a, - 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3, - 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784, - 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6, - 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027, - 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3, - 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9, - 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b, - 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc, - 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006, - 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061, - 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213, - 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349, - 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd, - 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c, - 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e, - 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519, - 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0, - 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7, - 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5, - 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93, - 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07, - 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d, - 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 0x09629f2f, - 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48, - 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a, - 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d, - 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f, - 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825, - 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1, - 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842}; - -// Used to fetch a naturally-aligned 32-bit word in little endian byte-order -static inline uint32_t LE_LOAD32(const uint8_t *p) { - return core::DecodeFixed32(reinterpret_cast(p)); -} - -uint32 Extend(uint32 crc, const char *buf, size_t size) { - static bool can_accelerate = CanAccelerate(); - if (can_accelerate) { - return AcceleratedExtend(crc, buf, size); - } - - const uint8 *p = reinterpret_cast(buf); - const uint8 *e = p + size; - uint32 l = crc ^ 0xffffffffu; - -#define STEP1 \ - do { \ - int c = (l & 0xff) ^ *p++; \ - l = table0_[c] ^ (l >> 8); \ - } while (0) - -#define STEP4 \ - do { \ - uint32 c = l ^ LE_LOAD32(p); \ - p += 4; \ - l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^ \ - table1_[(c >> 16) & 0xff] ^ table0_[c >> 24]; \ - } while (0) - - // Point x at first 4-byte aligned byte in string. This might be - // just past the end of the string. - const uintptr_t pval = reinterpret_cast(p); - const uint8 *x = reinterpret_cast(((pval + 3) >> 2) << 2); - if (x <= e) { - // Process bytes until finished or p is 4-byte aligned - while (p != x) { - STEP1; - } - } - // Process bytes 16 at a time - while ((e - p) >= 16) { - STEP4; - STEP4; - STEP4; - STEP4; - } - // Process bytes 4 at a time - while ((e - p) >= 4) { - STEP4; - } - // Process the last few bytes - while (p != e) { - STEP1; - } -#undef STEP4 -#undef STEP1 - return l ^ 0xffffffffu; -} - #if defined(TF_CORD_SUPPORT) uint32 Extend(uint32 crc, const absl::Cord &cord) { for (absl::string_view fragment : cord.Chunks()) { diff --git a/third_party/tsl/tsl/lib/hash/crc32c.h b/third_party/tsl/tsl/lib/hash/crc32c.h index 89da8d98f3321..10c4ea13e864d 100644 --- a/third_party/tsl/tsl/lib/hash/crc32c.h +++ b/third_party/tsl/tsl/lib/hash/crc32c.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "absl/crc/crc32c.h" +#include "absl/strings/string_view.h" #include "tsl/platform/cord.h" #include "tsl/platform/platform.h" #include "tsl/platform/types.h" @@ -28,7 +30,10 @@ namespace crc32c { // Return the crc32c of concat(A, buf[0,size-1]) where init_crc is the // crc32c of some string A. Extend() is often used to maintain the // crc32c of a stream of data. -extern uint32 Extend(uint32 init_crc, const char* buf, size_t size); +inline uint32 Extend(uint32 init_crc, const char* buf, size_t size) { + return static_cast(absl::ExtendCrc32c( + static_cast(init_crc), absl::string_view(buf, size))); +} #if defined(TF_CORD_SUPPORT) extern uint32 Extend(uint32 init_crc, const absl::Cord& cord); diff --git a/third_party/tsl/tsl/lib/hash/crc32c_accelerate.cc b/third_party/tsl/tsl/lib/hash/crc32c_accelerate.cc deleted file mode 100644 index bd48ef35b00b0..0000000000000 --- a/third_party/tsl/tsl/lib/hash/crc32c_accelerate.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -// SSE4.2 accelerated CRC32c. - -// See if the SSE4.2 crc32c instruction is available. -#undef USE_SSE_CRC32C -#ifdef __SSE4_2__ -#if defined(__x86_64__) && defined(__GNUC__) && \ - (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) -#define USE_SSE_CRC32C 1 -#elif defined(__x86_64__) && defined(__clang__) -#if __has_builtin(__builtin_cpu_supports) -#define USE_SSE_CRC32C 1 -#endif -#endif -#endif /* __SSE4_2__ */ - -// This version of Apple clang has a bug: -// https://llvm.org/bugs/show_bug.cgi?id=25510 -#if defined(__APPLE__) && (__clang_major__ <= 8) -#undef USE_SSE_CRC32C -#endif - -#ifdef USE_SSE_CRC32C -#include -#endif - -namespace tsl { -namespace crc32c { - -#ifndef USE_SSE_CRC32C - -bool CanAccelerate() { return false; } -uint32_t AcceleratedExtend(uint32_t crc, const char *buf, size_t size) { - // Should not be called. - return 0; -} - -#else - -// SSE4.2 optimized crc32c computation. -bool CanAccelerate() { return __builtin_cpu_supports("sse4.2"); } - -uint32_t AcceleratedExtend(uint32_t crc, const char *buf, size_t size) { - const uint8_t *p = reinterpret_cast(buf); - const uint8_t *e = p + size; - uint32_t l = crc ^ 0xffffffffu; - - // Advance p until aligned to 8-bytes.. - // Point x at first 7-byte aligned byte in string. This might be - // just past the end of the string. - const uintptr_t pval = reinterpret_cast(p); - const uint8_t *x = reinterpret_cast(((pval + 7) >> 3) << 3); - if (x <= e) { - // Process bytes until finished or p is 8-byte aligned - while (p != x) { - l = _mm_crc32_u8(l, *p); - p++; - } - } - - // Process bytes 16 at a time - uint64_t l64 = l; - while ((e - p) >= 16) { - l64 = _mm_crc32_u64(l64, *reinterpret_cast(p)); - l64 = _mm_crc32_u64(l64, *reinterpret_cast(p + 8)); - p += 16; - } - - // Process remaining bytes one at a time. - l = l64; - while (p < e) { - l = _mm_crc32_u8(l, *p); - p++; - } - - return l ^ 0xffffffffu; -} - -#endif - -} // namespace crc32c -} // namespace tsl diff --git a/third_party/tsl/tsl/lib/hash/crc32c_test.cc b/third_party/tsl/tsl/lib/hash/crc32c_test.cc index 6fcd9c902d155..9ba2e6e8108cf 100644 --- a/third_party/tsl/tsl/lib/hash/crc32c_test.cc +++ b/third_party/tsl/tsl/lib/hash/crc32c_test.cc @@ -17,9 +17,11 @@ limitations under the License. #include +#include "absl/strings/cord.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" +#include "tsl/platform/types.h" namespace tsl { namespace crc32c { diff --git a/third_party/tsl/tsl/lib/histogram/BUILD b/third_party/tsl/tsl/lib/histogram/BUILD index 1bc7af9ac546e..ffe3dc8ae6eed 100644 --- a/third_party/tsl/tsl/lib/histogram/BUILD +++ b/third_party/tsl/tsl/lib/histogram/BUILD @@ -1,3 +1,4 @@ +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:build_config.bzl", @@ -17,10 +18,11 @@ cc_library( name = "histogram", srcs = ["histogram.cc"], hdrs = ["histogram.h"], - visibility = [ + visibility = internal_visibility([ + "//learning/brain/google/monitoring:__pkg__", "//tensorflow/core/lib/histogram:__pkg__", "//tsl/lib/monitoring:__pkg__", - ], + ]), deps = [ "//tsl/platform:logging", "//tsl/platform:macros", @@ -39,7 +41,7 @@ filegroup( "histogram.cc", "histogram.h", ], - visibility = ["//tensorflow/core/lib/histogram:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/histogram:__pkg__"]), ) filegroup( @@ -47,7 +49,7 @@ filegroup( srcs = [ "histogram.h", ], - visibility = ["//tensorflow/core/lib/histogram:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/histogram:__pkg__"]), ) tsl_cc_test( diff --git a/third_party/tsl/tsl/lib/io/BUILD b/third_party/tsl/tsl/lib/io/BUILD index 7b133ff0e3183..79b39bbaa166a 100644 --- a/third_party/tsl/tsl/lib/io/BUILD +++ b/third_party/tsl/tsl/lib/io/BUILD @@ -1,6 +1,6 @@ -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -8,11 +8,11 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tensorflow/c/experimental/filesystem:__pkg__", "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", "//tsl/lib/io/snappy:__pkg__", - "//tensorflow/compiler/xla:__subpackages__", + "@xla//xla:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core/util:__subpackages__", "//tensorflow/core:__pkg__", @@ -388,7 +388,7 @@ filegroup( "//tsl/lib/io/snappy:snappy_inputstream.h", "//tsl/lib/io/snappy:snappy_outputbuffer.h", ], - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -406,7 +406,7 @@ filegroup( "table_builder.h", "table_options.h", ], - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -422,7 +422,7 @@ filegroup( "//tsl/lib/io/snappy:snappy_inputstream.h", "//tsl/lib/io/snappy:snappy_outputbuffer.h", ], - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) filegroup( @@ -432,7 +432,7 @@ filegroup( "block_builder.h", "format.h", ], - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) tsl_cc_test( diff --git a/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/tsl/tsl/lib/io/random_inputstream.h index 9313ca591757e..e1608ce3ec2b9 100644 --- a/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/tsl/tsl/lib/io/random_inputstream.h @@ -31,7 +31,7 @@ class RandomAccessInputStream : public InputStreamInterface { // must outlive *this. RandomAccessInputStream(RandomAccessFile* file, bool owns_file = false); - ~RandomAccessInputStream(); + ~RandomAccessInputStream() override; Status ReadNBytes(int64_t bytes_to_read, tstring* result) override; diff --git a/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/tsl/tsl/lib/io/snappy/BUILD index a5e5376c03a66..dbe78e0362dab 100644 --- a/third_party/tsl/tsl/lib/io/snappy/BUILD +++ b/third_party/tsl/tsl/lib/io/snappy/BUILD @@ -1,3 +1,4 @@ +load("//tsl:tsl.bzl", "internal_visibility") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", @@ -12,10 +13,10 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//tensorflow/core/lib/io:__pkg__", "//tsl/lib/io:__pkg__", - ], + ]), licenses = ["notice"], ) diff --git a/third_party/tsl/tsl/lib/io/zlib_inputstream.h b/third_party/tsl/tsl/lib/io/zlib_inputstream.h index e19843c812d35..d009cbfd0baf5 100644 --- a/third_party/tsl/tsl/lib/io/zlib_inputstream.h +++ b/third_party/tsl/tsl/lib/io/zlib_inputstream.h @@ -55,7 +55,7 @@ class ZlibInputStream : public InputStreamInterface { size_t output_buffer_bytes, const ZlibCompressionOptions& zlib_options); - ~ZlibInputStream(); + ~ZlibInputStream() override; // Reads bytes_to_read bytes into *result, overwriting *result. // diff --git a/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h b/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h index 3e4236ac1e44f..8f0793c985bae 100644 --- a/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h +++ b/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h @@ -49,7 +49,7 @@ class ZlibOutputBuffer : public WritableFile { int32_t output_buffer_bytes, // size of z_stream.next_out buffer const ZlibCompressionOptions& zlib_options); - ~ZlibOutputBuffer(); + ~ZlibOutputBuffer() override; // Initializes some state necessary for the output buffer. This call is // required before any other operation on the buffer. diff --git a/third_party/tsl/tsl/lib/math/BUILD b/third_party/tsl/tsl/lib/math/BUILD index 52773191775cc..e5f1178382650 100644 --- a/third_party/tsl/tsl/lib/math/BUILD +++ b/third_party/tsl/tsl/lib/math/BUILD @@ -1,5 +1,5 @@ +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("//tsl:tsl.bzl", "set_external_visibility") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", @@ -7,9 +7,9 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//tensorflow:__subpackages__", - ], + ]), licenses = ["notice"], ) @@ -17,7 +17,7 @@ cc_library( name = "math_util", hdrs = ["math_util.h"], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//platforms/performance/tf_sim/utils:__subpackages__", "//platforms/xla/service:__subpackages__", "//tensorflow:__subpackages__", @@ -48,7 +48,7 @@ filegroup( "math_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = ["//tensorflow/core:__pkg__"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) exports_files([ diff --git a/third_party/tsl/tsl/lib/monitoring/BUILD b/third_party/tsl/tsl/lib/monitoring/BUILD index 11974718d6dc3..f605971aa9715 100644 --- a/third_party/tsl/tsl/lib/monitoring/BUILD +++ b/third_party/tsl/tsl/lib/monitoring/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", @@ -7,7 +7,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//learning/brain/google/data:__subpackages__", "//learning/brain/google/monitoring:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** @@ -15,17 +15,18 @@ package( # tensorflow/core/platform:monitoring depends on this package "//tensorflow/core/platform:__subpackages__", # tensorflow/compiler/xla/pjrt:metrics depends on this package - "//tensorflow/compiler/xla/pjrt:__subpackages__", - "//tensorflow/compiler/xla/service/gpu:__subpackages__", + "@xla//xla/pjrt:__subpackages__", + "@xla//xla/service/gpu:__subpackages__", # tensorflow/compiler/mlir/tfrt:tf_jitrt depends on this package "//tensorflow/compiler/mlir/tfrt:__subpackages__", - "//tensorflow/compiler/xla/stream_executor:__subpackages__", - "//tensorflow/compiler/xla/hlo/experimental:__subpackages__", + "@xla//xla/stream_executor:__subpackages__", + "@xla//xla/hlo/experimental:__subpackages__", "//tensorflow/core/lib/monitoring:__subpackages__", - "//tensorflow/compiler/xla/service:__subpackages__", + "@xla//xla/service:__subpackages__", "//tsl/framework:__subpackages__", "//tsl/distributed_runtime:__subpackages__", "//tensorflow/compiler/mlir/tf2xla:__subpackages__", + "//tensorflow_serving/model_servers:__subpackages__", ]), licenses = ["notice"], ) @@ -94,7 +95,7 @@ cc_library( cc_library( name = "metric_def", hdrs = ["metric_def.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__subpackages__", ]), deps = [ @@ -109,8 +110,9 @@ cc_library( name = "collection_registry", srcs = ["collection_registry.cc"], hdrs = ["collection_registry.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__subpackages__", + "//tensorflow_serving/model_servers:__pkg__", ]), deps = [ ":collected_metrics", @@ -230,7 +232,7 @@ filegroup( "timed.h", "types.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/monitoring:__pkg__", ]), @@ -251,7 +253,7 @@ filegroup( "timed.h", "types.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/monitoring:__pkg__", ]), @@ -271,7 +273,7 @@ filegroup( "test_utils.h", "types.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/monitoring:__pkg__", ]), diff --git a/third_party/tsl/tsl/lib/monitoring/collected_metrics.h b/third_party/tsl/tsl/lib/monitoring/collected_metrics.h index 8582594922adf..ba67299b57a95 100644 --- a/third_party/tsl/tsl/lib/monitoring/collected_metrics.h +++ b/third_party/tsl/tsl/lib/monitoring/collected_metrics.h @@ -90,6 +90,7 @@ struct Point { int64_t int64_value; string string_value; bool bool_value; + double double_value; HistogramProto histogram_value; Percentiles percentiles_value; diff --git a/third_party/tsl/tsl/lib/monitoring/collection_registry.h b/third_party/tsl/tsl/lib/monitoring/collection_registry.h index d988d2f19f15a..7af6c87e51f0b 100644 --- a/third_party/tsl/tsl/lib/monitoring/collection_registry.h +++ b/third_party/tsl/tsl/lib/monitoring/collection_registry.h @@ -352,6 +352,18 @@ inline void CollectValue(Percentiles value, Point* const point) { point->percentiles_value = std::move(value); } +template <> +inline void CollectValue(double value, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value; +} + +template <> +inline void CollectValue(std::function value_fn, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value_fn(); +} + // Used by the CollectionRegistry class to collect all the values of all the // metrics in the registry. This is an implementation detail of the // CollectionRegistry class, please do not depend on this. diff --git a/third_party/tsl/tsl/lib/monitoring/gauge.h b/third_party/tsl/tsl/lib/monitoring/gauge.h index 93cbe9aa928df..0b69383b5f2d1 100644 --- a/third_party/tsl/tsl/lib/monitoring/gauge.h +++ b/third_party/tsl/tsl/lib/monitoring/gauge.h @@ -65,8 +65,10 @@ class Gauge { std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double and string types."); return new Gauge(); } @@ -296,8 +298,10 @@ Gauge* Gauge::New( std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double, and string types."); return new Gauge( MetricDef( std::forward(metric_def_args)...)); diff --git a/third_party/tsl/tsl/lib/monitoring/metric_def.h b/third_party/tsl/tsl/lib/monitoring/metric_def.h index f8c21c360a2b0..ab454664691b1 100644 --- a/third_party/tsl/tsl/lib/monitoring/metric_def.h +++ b/third_party/tsl/tsl/lib/monitoring/metric_def.h @@ -47,7 +47,8 @@ enum class ValueType : int { kHistogram, kString, kBool, - kPercentiles + kPercentiles, + kDouble }; // Everything in the internal namespace is implementation details. Do not depend @@ -97,6 +98,16 @@ inline ValueType GetValueType>() { return ValueType::kBool; } +template <> +inline ValueType GetValueType() { + return ValueType::kDouble; +} + +template <> +inline ValueType GetValueType>() { + return ValueType::kDouble; +} + } // namespace internal // Abstract base class for a metric definition. diff --git a/third_party/tsl/tsl/lib/random/BUILD b/third_party/tsl/tsl/lib/random/BUILD index 26d3ea28c3b11..cdafe9bde6c4c 100644 --- a/third_party/tsl/tsl/lib/random/BUILD +++ b/third_party/tsl/tsl/lib/random/BUILD @@ -1,17 +1,20 @@ +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "//tsl/platform:build_config.bzl", "tsl_cc_test", ) -load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("//tsl/platform:rules_cc.bzl", "cc_library") +default_visibility = [ + "//tsl/lib/io:__pkg__", + # tensorflow/core/platform/random aliases this package + "//tensorflow/core/lib/random:__pkg__", +] + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tsl/lib/io:__pkg__", - # tensorflow/core/platform/random aliases this package - "//tensorflow/core/lib/random:__pkg__", - ], + default_visibility = internal_visibility(default_visibility), licenses = ["notice"], ) @@ -32,6 +35,7 @@ cc_library( "random_distributions.h", "simple_philox.h", ], + visibility = internal_visibility(default_visibility), deps = [ ":exact_uniform_int", ":philox_random", @@ -49,10 +53,10 @@ cc_library( name = "random_distributions_utils", hdrs = ["random_distributions_utils.h"], compatible_with = get_compatible_with_portable(), - visibility = [ + visibility = internal_visibility([ "//tensorflow/core/lib/random:__pkg__", "//tensorflow/lite:__subpackages__", - ], + ]), deps = [":philox_random"], ) @@ -60,10 +64,10 @@ cc_library( name = "philox_random", hdrs = ["philox_random.h"], compatible_with = get_compatible_with_portable(), - visibility = [ + visibility = internal_visibility([ "//tensorflow/core/lib/random:__pkg__", "//tensorflow/lite:__subpackages__", - ], + ]), ) cc_library( @@ -116,7 +120,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", ], - visibility = ["//tensorflow/core/lib/random:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -126,7 +130,7 @@ filegroup( "random_distributions_utils.h", "weighted_picker.h", ], - visibility = ["//tensorflow/core/lib/random:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -134,7 +138,7 @@ filegroup( srcs = [ "philox_random_test_utils.h", ], - visibility = ["//tensorflow/core/lib/random:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) filegroup( @@ -149,7 +153,7 @@ filegroup( "simple_philox.h", "weighted_picker.h", ], - visibility = ["//tensorflow/core/lib/random:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/random:__pkg__"]), ) tsl_cc_test( diff --git a/third_party/tsl/tsl/lib/strings/BUILD b/third_party/tsl/tsl/lib/strings/BUILD index 9637cf492706c..e89016582814b 100644 --- a/third_party/tsl/tsl/lib/strings/BUILD +++ b/third_party/tsl/tsl/lib/strings/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", @@ -11,11 +11,11 @@ cc_library( name = "proto_serialization", srcs = ["proto_serialization.cc"], hdrs = ["proto_serialization.h"], - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/pjrt:__subpackages__", - "//tensorflow/compiler/xla/python:__pkg__", - "//tensorflow/compiler/xla/service:__pkg__", - "//tensorflow/compiler/xla/stream_executor:__pkg__", + visibility = internal_visibility([ + "@xla//xla/pjrt:__subpackages__", + "@xla//xla/python:__pkg__", + "@xla//xla/service:__pkg__", + "@xla//xla/stream_executor:__pkg__", "//tensorflow/core/lib/strings:__pkg__", "//tensorflow/compiler/tf2xla/kernels:__pkg__", "//tensorflow/core/util/autotune_maps:__pkg__", @@ -37,7 +37,7 @@ filegroup( "proto_serialization.cc", "proto_serialization.h", ], - visibility = ["//tensorflow/core/lib/strings:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -45,7 +45,7 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//tensorflow/core/lib/strings:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -53,7 +53,7 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//tensorflow/core/lib/strings:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) filegroup( @@ -61,5 +61,5 @@ filegroup( srcs = [ "proto_serialization.h", ], - visibility = ["//tensorflow/core/lib/strings:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/lib/strings:__pkg__"]), ) diff --git a/third_party/tsl/tsl/mkl/BUILD b/third_party/tsl/tsl/mkl/BUILD deleted file mode 100644 index db21c0fb1b402..0000000000000 --- a/third_party/tsl/tsl/mkl/BUILD +++ /dev/null @@ -1,146 +0,0 @@ -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") -load( - "//tsl:tsl.bzl", - "clean_dep", -) - -licenses(["notice"]) # 3-Clause BSD - -config_setting( - name = "build_with_mkl", - define_values = { - "build_with_mkl": "true", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "build_with_mkl_lnx_x64", - define_values = { - "build_with_mkl": "true", - }, - values = { - "cpu": "k8", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "build_with_mkl_lnx_openmp", - constraint_values = [ - "@platforms//os:linux", - ], - define_values = { - "build_with_mkl": "true", - "build_with_openmp": "true", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "build_with_mkl_windows_openmp", - constraint_values = [ - "@platforms//os:windows", - ], - define_values = { - "build_with_mkl": "true", - "build_with_openmp": "true", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "build_with_mkl_aarch64", - define_values = { - "build_with_mkl_aarch64": "true", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "enable_mkl", - define_values = { - "enable_mkl": "true", - "build_with_mkl": "true", - }, - visibility = ["//visibility:public"], -) - -filegroup( - name = "LICENSE", - srcs = [ - "MKL_LICENSE", - "@llvm_openmp//:LICENSE.txt", - ], - visibility = ["//visibility:public"], -) - -# TODO(Intel-tf) Remove the following 3 calls to cc_library and replace all uses -# of mkl_libs_* with @llvm_openmp//:libiomp5.* directly. - -cc_library( - name = "mkl_libs_linux", - srcs = [ - "@llvm_openmp//:libiomp5.so", - ], - hdrs = ["@llvm_openmp//:config_omp"], - target_compatible_with = select({ - "//tsl/mkl:build_with_mkl": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - visibility = ["//visibility:public"], -) - -# MacOS build configuration is provided for completness, it has not been tested -cc_library( - name = "mkl_libs_darwin", - srcs = [ - "@llvm_openmp//:libiomp5.dylib", - ], - hdrs = ["@llvm_openmp//:config_omp"], - target_compatible_with = select({ - "//tsl/mkl:build_with_mkl": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - visibility = ["//visibility:public"], -) - -cc_library( - name = "mkl_libs_windows", - srcs = [ - "@llvm_openmp//:libiomp5md.dll", - ], - hdrs = ["@llvm_openmp//:config_omp"], - target_compatible_with = select({ - "//tsl/mkl:build_with_mkl": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - visibility = ["//visibility:public"], -) - -cc_library( - name = "intel_binary_blob", - target_compatible_with = select({ - "//tsl/mkl:build_with_mkl": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - visibility = ["//visibility:public"], - deps = select({ - clean_dep("//tsl:linux_x86_64"): [ - ":mkl_libs_linux", - ], - clean_dep("//tsl:macos"): [ - ":mkl_libs_darwin", - ], - clean_dep("//tsl:windows"): [ - ":mkl_libs_windows", - ], - "//conditions:default": [], - }), -) - -bzl_library( - name = "build_defs_bzl", - srcs = ["build_defs.bzl"], - visibility = ["//visibility:public"], -) diff --git a/third_party/tsl/tsl/mkl/build_defs.bzl b/third_party/tsl/tsl/mkl/build_defs.bzl deleted file mode 100644 index ba97a5ae87857..0000000000000 --- a/third_party/tsl/tsl/mkl/build_defs.bzl +++ /dev/null @@ -1,161 +0,0 @@ -"""Starlark macros for MKL. - -if_mkl is a conditional to check if we are building with MKL. -if_mkl_ml is a conditional to check if we are building with MKL-ML. -if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode. -if_mkl_lnx_x64 is a conditional to check for MKL -if_enable_mkl is a conditional to check if building with MKL and MKL is enabled. - -mkl_repository is a repository rule for creating MKL repository rule that can -be pointed to either a local folder, or download it from the internet. -mkl_repository depends on the following environment variables: - * `TF_MKL_ROOT`: The root folder where a copy of libmkl is located. -""" - -_TF_MKL_ROOT = "TF_MKL_ROOT" - -def if_mkl(if_true, if_false = []): - """Shorthand for select()'ing on whether we're building with oneDNN. - - OneDNN gets built if we are building on platforms that support oneDNN - (x86 linux/windows) or if specifcially configured to use oneDNN. - - Args: - if_true: expression to evaluate if building with oneDNN. - if_false: expression to evaluate if building without oneDNN. - - Returns: - a select evaluating to either if_true or if_false as appropriate. - - TODO(intel-tf): - the first "if_true" line is kept because non-x86 platforms (e.g., ARM) - may need it. It may be deleted in future with refactoring. - """ - return select({ - "@tsl//tsl/mkl:build_with_mkl_aarch64": if_true, - "@tsl//tsl:linux_x86_64": if_true, - "@tsl//tsl:windows": if_true, - "//conditions:default": if_false, - }) - -def if_mkl_ml(if_true, if_false = []): - """Shorthand for select()'ing on whether we're building with MKL-ML. - - Args: - if_true: expression to evaluate if building with MKL-ML. - if_false: expression to evaluate if building without MKL-ML - (i.e. without MKL at all, or with MKL-DNN only). - - Returns: - a select evaluating to either if_true or if_false as appropriate. - """ - return select({ - "@tsl//third_party/mkl_dnn:build_with_mkl_opensource": if_false, - "@tsl//tsl/mkl:build_with_mkl": if_true, - "//conditions:default": if_false, - }) - -def if_mkl_lnx_x64(if_true, if_false = []): - """Shorthand to select() if building with MKL and the target is Linux x86-64. - - Args: - if_true: expression to evaluate if building with MKL is enabled and the - target platform is Linux x86-64. - if_false: expression to evaluate if building without MKL or for a - different platform. - - Returns: - a select evaluating to either if_true or if_false as appropriate. - """ - return select({ - "@tsl//tsl/mkl:build_with_mkl_lnx_x64": if_true, - "//conditions:default": if_false, - }) - -def if_enable_mkl(if_true, if_false = []): - """Shorthand to select() if we are building with MKL and MKL is enabled. - - This is only effective when built with MKL. - - Args: - if_true: expression to evaluate if building with MKL and MKL is enabled - if_false: expression to evaluate if building without MKL or MKL is not enabled. - - Returns: - A select evaluating to either if_true or if_false as appropriate. - """ - return select({ - "@tsl//tsl/mkl:enable_mkl": if_true, - "//conditions:default": if_false, - }) - -def mkl_deps(): - """Returns the correct set of oneDNN library dependencies. - - Shorthand for select() to pull in the correct set of oneDNN library deps - depending on the platform. x86 Linux/Windows with or without --config=mkl - will always build with oneDNN library. - - Returns: - a select evaluating to a list of library dependencies, suitable for - inclusion in the deps attribute of rules. - """ - return select({ - "@tsl//tsl/mkl:build_with_mkl_aarch64": ["@mkl_dnn_acl_compatible//:mkl_dnn_acl"], - "@tsl//tsl:linux_x86_64": ["@onednn//:mkl_dnn"], - "@tsl//tsl:windows": ["@onednn//:mkl_dnn"], - "//conditions:default": [], - }) - -def onednn_v3_define(): - """Returns a define to build with oneDNN v3.x if it is enabled. - - Returns: - A define to build with oneDNN v3.x for Linux x86 and Windows x86 builds. - An empty list of all other cases (include ARM builds). - """ - return select({ - "@tsl//tsl/mkl:build_with_mkl_aarch64": ["-DENABLE_ONEDNN_V3"], - "@tsl//tsl:linux_x86_64": ["-DENABLE_ONEDNN_V3"], - "@tsl//tsl:windows": ["-DENABLE_ONEDNN_V3"], - "//conditions:default": [], - }) - -def _enable_local_mkl(repository_ctx): - return _TF_MKL_ROOT in repository_ctx.os.environ - -def _mkl_autoconf_impl(repository_ctx): - """Implementation of the local_mkl_autoconf repository rule.""" - - if _enable_local_mkl(repository_ctx): - # Symlink lib and include local folders. - mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT] - mkl_lib_path = "%s/lib" % mkl_root - repository_ctx.symlink(mkl_lib_path, "lib") - mkl_include_path = "%s/include" % mkl_root - repository_ctx.symlink(mkl_include_path, "include") - mkl_license_path = "%s/license.txt" % mkl_root - repository_ctx.symlink(mkl_license_path, "license.txt") - else: - # setup remote mkl repository. - repository_ctx.download_and_extract( - repository_ctx.attr.urls, - sha256 = repository_ctx.attr.sha256, - stripPrefix = repository_ctx.attr.strip_prefix, - ) - - # Also setup BUILD file. - repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD") - -mkl_repository = repository_rule( - implementation = _mkl_autoconf_impl, - environ = [ - _TF_MKL_ROOT, - ], - attrs = { - "build_file": attr.label(), - "urls": attr.string_list(default = []), - "sha256": attr.string(default = ""), - "strip_prefix": attr.string(default = ""), - }, -) diff --git a/third_party/tsl/tsl/platform/BUILD b/third_party/tsl/tsl/platform/BUILD index 7277c8e50f627..e9020cbe5ae86 100644 --- a/third_party/tsl/tsl/platform/BUILD +++ b/third_party/tsl/tsl/platform/BUILD @@ -4,10 +4,14 @@ # The libraries in this package are not allowed to have ANY dependencies # to other TF components outside of TSL. +load( + "@bazel_skylib//:bzl_library.bzl", + "bzl_library", +) load( "//tsl:tsl.bzl", "if_not_fuchsia", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl:tsl.default.bzl", "get_compatible_with_portable") @@ -23,9 +27,9 @@ load( "tf_protobuf_compiler_deps", "tf_resource_deps", "tf_stream_executor_deps", - "tf_testing_deps", "tf_windows_aware_platform_deps", "tsl_cc_test", + "tsl_grpc_credentials_deps", "tsl_protobuf_deps", ) load("//tsl/platform:build_config_root.bzl", "if_static") @@ -33,10 +37,6 @@ load( "//tsl/platform:rules_cc.bzl", "cc_library", ) -load( - "@bazel_skylib//:bzl_library.bzl", - "bzl_library", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -52,7 +52,7 @@ exports_files( "load_library.h", "stringpiece_test.cc", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/platform:__subpackages__", "//tsl:__subpackages__", ]), @@ -64,8 +64,10 @@ cc_library( hdrs = ["base64.h"], deps = [ ":errors", + ":macros", ":status", ":stringpiece", + ":types", ], ) @@ -183,15 +185,7 @@ cc_library( hdrs = ["dynamic_annotations.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":platform", - ] + tf_platform_deps("dynamic_annotations"), -) - -cc_library( - name = "gif", - hdrs = ["gif.h"], - deps = [ - "@gif", + "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -381,7 +375,7 @@ filegroup( "test_benchmark.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/platform:__pkg__", ]), @@ -394,7 +388,7 @@ filegroup( "test.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/platform:__pkg__", ]), @@ -449,7 +443,6 @@ filegroup( "file_system_helper.h", "fingerprint.h", "init_main.h", - "logger.h", "mem.h", "net.h", "notification.h", @@ -601,7 +594,6 @@ filegroup( "error_logging.h", "fingerprint.h", "notification.h", - "png.h", "random.cc", "random.h", "test_benchmark.h", @@ -616,24 +608,12 @@ filegroup( "subprocess.h", ]), compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/platform:__pkg__", ]), ) -filegroup( - name = "gif_hdrs", - srcs = [ - "gif.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ - "//tensorflow/core/lib/gif:__pkg__", - "//tensorflow/core/platform:__pkg__", - ]), -) - filegroup( name = "legacy_lib_internal_headers", srcs = glob( @@ -642,9 +622,6 @@ filegroup( ], exclude = [ "dynamic_annotations.h", - "gif.h", - "png.h", - "jpeg.h", ], ) + [ "//tsl/platform/profile_utils:android_armv7a_cpu_utils_helper.h", @@ -673,6 +650,7 @@ exports_files( "file_system.h", "file_system_helper.cc", "file_system_helper.h", + "grpc_credentials.h", "host_info.h", "human_readable_json.h", "init_main.h", @@ -693,7 +671,7 @@ exports_files( "tracing.cc", "tracing.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ ":__subpackages__", "//tensorflow:__subpackages__", ]), @@ -709,7 +687,7 @@ filegroup( "stringpiece.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", ]), ) @@ -738,32 +716,12 @@ filegroup( "unbounded_work_queue.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/platform:__pkg__", ]), ) -filegroup( - name = "jpeg_hdrs", - srcs = [ - "jpeg.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ - "//tensorflow/core:__pkg__", - "//tensorflow/core/lib/jpeg:__pkg__", - ]), -) - -cc_library( - name = "jpeg", - hdrs = ["jpeg.h"], - deps = [ - "@libjpeg_turbo//:jpeg", - ], -) - filegroup( name = "tflite_portable_logging_hdrs", srcs = [ @@ -772,7 +730,7 @@ filegroup( "platform.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/jpeg:__pkg__", ]), @@ -789,7 +747,7 @@ filegroup( "stringpiece.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/jpeg:__pkg__", "//tensorflow/core/platform:__pkg__", @@ -806,7 +764,7 @@ filegroup( "platform.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gif:__pkg__", "//tensorflow/core/platform:__pkg__", @@ -883,8 +841,13 @@ cc_library( cc_library( name = "resource_loader", testonly = 1, + srcs = ["resource_loader.cc"], textual_hdrs = ["resource_loader.h"], - deps = tf_testing_deps("resource_loader"), + deps = [ + ":logging", + ":path", + ":test", + ], ) cc_library( @@ -1028,14 +991,12 @@ cc_library( cc_library( name = "build_test", testonly = 1, - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/platform:__pkg__", ]), deps = [ ":byte_order", ":fingerprint", - ":gif", - ":jpeg", ":macros", ":net", ":platform", @@ -1067,17 +1028,6 @@ cc_library( ], ) -cc_library( - name = "float8", - hdrs = ["float8.h"], - compatible_with = get_compatible_with_portable(), - # TODO - b/299180335: Add deprecation, update usages. - # deprecation = "Please use ml_dtypes.", - deps = [ - ":ml_dtypes", - ], -) - cc_library( name = "dso_loader", hdrs = ["dso_loader.h"], @@ -1086,20 +1036,6 @@ cc_library( ] + tf_stream_executor_deps("dso_loader"), ) -cc_library( - name = "logger", - srcs = ["logger.cc"], - hdrs = ["logger.h"], - deps = [ - "env", - ":logging", - ":protobuf", - "@com_google_absl//absl/base", - "@com_google_absl//absl/synchronization", - ], - alwayslink = 1, -) - cc_library( name = "logging", compatible_with = get_compatible_with_portable(), @@ -1123,6 +1059,13 @@ cc_library( ] + tf_error_logging_deps(), ) +cc_library( + name = "grpc_credentials", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["grpc_credentials.h"], + deps = tsl_grpc_credentials_deps(), +) + cc_library( name = "prefetch", hdrs = ["prefetch.h"], @@ -1170,7 +1113,7 @@ filegroup( "str_util.h", ], compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility(["//tensorflow/core:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) cc_library( @@ -1235,6 +1178,7 @@ cc_library( hdrs = ["fingerprint.h"], compatible_with = get_compatible_with_portable(), deps = [ + ":platform", ":stringpiece", ":types", ] + tf_fingerprint_deps(), @@ -1420,17 +1364,8 @@ cc_library( hdrs = ["notification.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":platform", - ] + tf_platform_deps("notification"), -) - -cc_library( - name = "png", - hdrs = ["png.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - ":platform", - "@png", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) @@ -1783,8 +1718,10 @@ cc_library( deps = [ ":env", ":errors", + ":logging", ":random", ":status", + "@com_google_absl//absl/time", ], ) @@ -1830,5 +1767,6 @@ tsl_cc_test( ":test", ":test_main", "//tsl/lib/core:status_test_util", + "@com_google_absl//absl/time", ], ) diff --git a/third_party/tsl/tsl/platform/abi.cc b/third_party/tsl/tsl/platform/abi.cc index 3558b42c6e65e..8e886535d4503 100644 --- a/third_party/tsl/tsl/platform/abi.cc +++ b/third_party/tsl/tsl/platform/abi.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tsl/platform/abi.h" +#include "tsl/platform/types.h" + #if defined(_MSC_VER) #include #include diff --git a/third_party/tsl/tsl/platform/base64.cc b/third_party/tsl/tsl/platform/base64.cc index 8918d6ab1e312..6421b5ec92001 100644 --- a/third_party/tsl/tsl/platform/base64.cc +++ b/third_party/tsl/tsl/platform/base64.cc @@ -19,7 +19,10 @@ limitations under the License. #include #include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" +#include "tsl/platform/types.h" namespace tsl { namespace { diff --git a/third_party/tsl/tsl/platform/build_config.bzl b/third_party/tsl/tsl/platform/build_config.bzl index faa3390cf7c58..bec0e8403b248 100644 --- a/third_party/tsl/tsl/platform/build_config.bzl +++ b/third_party/tsl/tsl/platform/build_config.bzl @@ -2,8 +2,6 @@ load( "//tsl/platform/default:build_config.bzl", - _if_llvm_aarch64_available = "if_llvm_aarch64_available", - _if_llvm_system_z_available = "if_llvm_system_z_available", _pyx_library = "pyx_library", _tf_additional_all_protos = "tf_additional_all_protos", _tf_additional_core_deps = "tf_additional_core_deps", @@ -36,14 +34,12 @@ load( _tf_pyclif_proto_library = "tf_pyclif_proto_library", _tf_resource_deps = "tf_resource_deps", _tf_stream_executor_deps = "tf_stream_executor_deps", - _tf_testing_deps = "tf_testing_deps", _tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps", _tsl_cc_test = "tsl_cc_test", + _tsl_grpc_credentials_deps = "tsl_grpc_credentials_deps", _tsl_protobuf_deps = "tsl_protobuf_deps", ) -if_llvm_aarch64_available = _if_llvm_aarch64_available -if_llvm_system_z_available = _if_llvm_system_z_available pyx_library = _pyx_library tf_additional_all_protos = _tf_additional_all_protos tf_additional_core_deps = _tf_additional_core_deps @@ -76,7 +72,7 @@ tf_py_clif_cc = _tf_py_clif_cc tf_pyclif_proto_library = _tf_pyclif_proto_library tf_resource_deps = _tf_resource_deps tf_stream_executor_deps = _tf_stream_executor_deps -tf_testing_deps = _tf_testing_deps tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps tsl_protobuf_deps = _tsl_protobuf_deps tsl_cc_test = _tsl_cc_test +tsl_grpc_credentials_deps = _tsl_grpc_credentials_deps diff --git a/third_party/tsl/tsl/platform/build_config_root.bzl b/third_party/tsl/tsl/platform/build_config_root.bzl index e236414e26722..151e40d4c02e3 100644 --- a/third_party/tsl/tsl/platform/build_config_root.bzl +++ b/third_party/tsl/tsl/platform/build_config_root.bzl @@ -2,6 +2,12 @@ load( "//tsl/platform/default:build_config_root.bzl", + _if_llvm_aarch32_available = "if_llvm_aarch32_available", + _if_llvm_aarch64_available = "if_llvm_aarch64_available", + _if_llvm_arm_available = "if_llvm_arm_available", + _if_llvm_powerpc_available = "if_llvm_powerpc_available", + _if_llvm_system_z_available = "if_llvm_system_z_available", + _if_llvm_x86_available = "if_llvm_x86_available", _if_static = "if_static", _if_static_and_not_mobile = "if_static_and_not_mobile", _tf_additional_grpc_deps_py = "tf_additional_grpc_deps_py", @@ -14,6 +20,12 @@ load( _tf_gpu_tests_tags = "tf_gpu_tests_tags", ) +if_llvm_aarch32_available = _if_llvm_aarch32_available +if_llvm_aarch64_available = _if_llvm_aarch64_available +if_llvm_arm_available = _if_llvm_arm_available +if_llvm_powerpc_available = _if_llvm_powerpc_available +if_llvm_system_z_available = _if_llvm_system_z_available +if_llvm_x86_available = _if_llvm_x86_available if_static = _if_static if_static_and_not_mobile = _if_static_and_not_mobile tf_additional_grpc_deps_py = _tf_additional_grpc_deps_py diff --git a/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/tsl/tsl/platform/cloud/BUILD index 9a7bf418091b2..220a05f96dbdc 100644 --- a/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/tsl/tsl/platform/cloud/BUILD @@ -1,18 +1,18 @@ # Description: # Cloud file system implementation. -load("//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl:tsl.bzl", "if_windows", - "set_external_visibility", + "internal_visibility", "tsl_copts", ) load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ ":dependency_allowlist", ]), licenses = ["notice"], @@ -217,8 +217,8 @@ cc_library( "//tsl/platform:str_util", "//tsl/platform:stringpiece", "//tsl/platform:types", - "//tsl/util:env_var", "@curl", + "@xla//xla/tsl/util:env_var", ], ) diff --git a/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/tsl/tsl/platform/cloud/curl_http_request.cc index a7e6a65e37335..c41f967c04b05 100644 --- a/third_party/tsl/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/tsl/tsl/platform/cloud/curl_http_request.cc @@ -17,13 +17,13 @@ limitations under the License. #include +#include "xla/tsl/util/env_var.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/scanner.h" #include "tsl/platform/str_util.h" #include "tsl/platform/types.h" -#include "tsl/util/env_var.h" #define CHECK_CURL_OK(expr) CHECK_EQ(expr, CURLE_OK) diff --git a/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc index ea65028a96cd2..869dc993ee0a9 100644 --- a/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc @@ -66,10 +66,10 @@ limitations under the License. namespace tsl { namespace { -constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/"; +constexpr char kGcsUriBase[] = "https://www.googleapis.com./storage/v1/"; constexpr char kGcsUploadUriBase[] = - "https://www.googleapis.com/upload/storage/v1/"; -constexpr char kStorageHost[] = "storage.googleapis.com"; + "https://www.googleapis.com./upload/storage/v1/"; +constexpr char kStorageHost[] = "storage.googleapis.com."; constexpr char kBucketMetadataLocationKey[] = "location"; constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes. constexpr int kGetChildrenDefaultPageSize = 1000; diff --git a/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc index 9221128276af9..e403599096e5f 100644 --- a/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc @@ -62,13 +62,13 @@ class FakeZoneProvider : public ZoneProvider { TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", "012345"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-11\n" "Timeouts: 5 1 20\n", @@ -108,13 +108,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -155,14 +155,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { std::vector requests({ new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "Server Not", errors::Unavailable("important HTTP error 308"), nullptr, {}, 308), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 6-15\n" "Timeouts: 5 1 20\n", @@ -204,13 +204,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "0123456789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 10-19\n" "Timeouts: 5 1 20\n", @@ -251,7 +251,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { // In this test, there is only one backend request since we cache the file // size. std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -297,13 +297,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { // a backend request. std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 1-10\n" "Timeouts: 5 1 20\n", "12345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -339,13 +339,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-18\n" "Timeouts: 5 1 20\n", @@ -387,13 +387,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { // Go backwards in the file. It should trigger a new read. std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 5-14\n" "Timeouts: 5 1 20\n", "56789"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-9\n" "Timeouts: 5 1 20\n", @@ -433,7 +433,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintInSameLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -460,7 +460,7 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -468,7 +468,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/anotherbucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/anotherbucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -476,7 +476,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { "location":"US-EAST1" })"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -517,7 +517,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintInDifferentLocation) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", R"( @@ -547,13 +547,13 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-2\n" "Timeouts: 5 1 20\n", "012"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 3-12\n" "Timeouts: 5 1 20\n", @@ -593,26 +593,26 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 9-17\n" "Timeouts: 5 1 20\n", "9abcde"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 18-26\n" "Timeouts: 5 1 20\n", @@ -679,27 +679,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { // "0123456789abcde". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "012345678"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"15\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -738,22 +738,24 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { // "0123456789abcdef". std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "object?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" - "Auth Token: fake_token\n" - "Range: 0-7\n" - "Timeouts: 5 1 20\n", - "01234567"), - new FakeHttpRequest("Uri: https://storage.googleapis.com/bucket/object\n" - "Auth Token: fake_token\n" - "Range: 8-15\n" - "Timeouts: 5 1 20\n", - "89abcdef")}); + new FakeHttpRequest( + "Uri: https://storage.googleapis.com./bucket/object\n" + "Auth Token: fake_token\n" + "Range: 0-7\n" + "Timeouts: 5 1 20\n", + "01234567"), + new FakeHttpRequest( + "Uri: https://storage.googleapis.com./bucket/object\n" + "Auth Token: fake_token\n" + "Range: 8-15\n" + "Timeouts: 5 1 20\n", + "89abcdef")}); GcsFileSystem fs( std::unique_ptr(new FakeAuthProvider), std::unique_ptr( @@ -800,27 +802,27 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_FileSignatureChanges) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", "01234"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-8\n" "Timeouts: 5 1 20\n", @@ -874,14 +876,14 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) { TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "random_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"6\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -917,20 +919,20 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { TEST(GcsFileSystemTest, NewWritableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -944,14 +946,14 @@ TEST(GcsFileSystemTest, NewWritableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:34.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -998,7 +1000,7 @@ TEST(GcsFileSystemTest, NewWritableFile) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1076,20 +1078,20 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { // path. std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"16\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1109,14 +1111,14 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Put: yes\n", "", OkStatus(), nullptr, {}, 201), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"33\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fwriteable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fwriteable\n" "Auth Token: fake_token\n" "Range: 0-7\n" "Timeouts: 5 1 20\n", @@ -1163,7 +1165,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1196,7 +1198,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. requests.emplace_back(new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1245,7 +1247,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1262,7 +1264,7 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { // These calls will be made in the Close() attempt from the destructor. // Letting the destructor succeed. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fwriteable.txt\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1334,26 +1336,26 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", "content1,"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 17\n" @@ -1367,14 +1369,14 @@ TEST(GcsFileSystemTest, NewAppendableFile) { "Put body: content1,content2\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:25:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-31\n" "Timeouts: 5 1 20\n", @@ -1435,13 +1437,13 @@ TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) { TEST(GcsFileSystemTest, NewAppendableFile_ObjectDoesNotExist) { std::vector requests( {new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/filename\n" + "Uri: https://storage.googleapis.com./bucket/filename\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o" "?uploadType=resumable&name=filename\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 0\n" @@ -1467,7 +1469,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { const string content = "file content"; std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Frandom_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1475,7 +1477,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { ", \"generation\": \"1\"", ", \"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - strings::StrCat("Uri: https://storage.googleapis.com/bucket/" + strings::StrCat("Uri: https://storage.googleapis.com./bucket/" "path%2Frandom_access.txt\n" "Auth Token: fake_token\n" "Range: 0-", @@ -1520,7 +1522,7 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { TEST(GcsFileSystemTest, FileExists_YesAsObject) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1543,13 +1545,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) { TEST(GcsFileSystemTest, FileExists_YesAsFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsubfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1573,12 +1575,12 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) { TEST(GcsFileSystemTest, FileExists_YesAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket1\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}")}); @@ -1600,13 +1602,13 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) { TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile1.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1630,12 +1632,12 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { TEST(GcsFileSystemTest, FileExists_NotAsBucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket2\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -1656,20 +1658,20 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { TEST(GcsFileSystemTest, FileExists_StatCache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsubfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -1697,7 +1699,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { TEST(GcsFileSystemTest, FileExists_DirectoryMark) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1720,7 +1722,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) { TEST(GcsFileSystemTest, GetChildren_NoItems) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1745,7 +1747,7 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1774,7 +1776,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1802,7 +1804,7 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1831,7 +1833,7 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { TEST(GcsFileSystemTest, GetChildren_Root) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket-a-b-c/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket-a-b-c/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1855,7 +1857,7 @@ TEST(GcsFileSystemTest, GetChildren_Root) { TEST(GcsFileSystemTest, GetChildren_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&prefix=" "path%2F\n" "Auth Token: fake_token\n" @@ -1881,7 +1883,7 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { TEST(GcsFileSystemTest, GetChildren_Pagination) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F\n" "Auth Token: fake_token\n" @@ -1892,7 +1894,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2Cprefixes%2CnextPageToken&delimiter=%2F&" "prefix=path%2F" "&pageToken=ABCD==\n" @@ -1923,7 +1925,7 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1949,7 +1951,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -1978,7 +1980,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2006,7 +2008,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2031,7 +2033,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2056,7 +2058,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2081,7 +2083,7 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2127,14 +2129,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2172,14 +2174,14 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2218,33 +2220,33 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { TEST(GcsFileSystemTest, DeleteFile) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile1.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:19:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Ffile1.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2296,26 +2298,26 @@ TEST(GcsFileSystemTest, DeleteFile_NoObjectName) { TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/file.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2347,7 +2349,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { TEST(GcsFileSystemTest, DeleteDir_Empty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2369,13 +2371,13 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/\" }]}"), - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2397,7 +2399,7 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?fields=items%2F" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?fields=items%2F" "name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2417,7 +2419,7 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F&maxResults=2\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2440,7 +2442,7 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { TEST(GcsFileSystemTest, GetFileSize) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2484,7 +2486,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { std::vector requests( {// Check if this is a folder or an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path1%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2493,7 +2495,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/subfolder/file1.txt\" }]}"), // Requesting the full list of files in the folder. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2503,7 +2505,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { " { \"name\": \"path1/file2.txt\" }]}"), // Copying the directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2F/rewriteTo/b/bucket/o/path2%2F\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2511,7 +2513,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the original directory marker. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2519,7 +2521,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the first file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt/rewriteTo/b/bucket/o/" "path2%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" @@ -2528,7 +2530,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the first original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Fsubfolder%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2536,7 +2538,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { ""), // Copying the second file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Ffile2.txt/rewriteTo/b/bucket/o/path2%2Ffile2.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2544,7 +2546,7 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "{\"done\": true}"), // Deleting the second original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path1%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2568,34 +2570,34 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { TEST(GcsFileSystemTest, RenameFile_Object) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "01234567"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "76543210"), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2603,7 +2605,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{}"), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2611,34 +2613,34 @@ TEST(GcsFileSystemTest, RenameFile_Object) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fsrc.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", "89abcdef"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"2\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fdst.txt\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Range: 0-15\n" "Timeouts: 5 1 20\n", @@ -2681,7 +2683,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { std::vector requests( {// Stat the target file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2689,7 +2691,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2697,7 +2699,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2705,7 +2707,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2713,14 +2715,14 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { "{\"done\": true}"), // Deleting the original file. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fdst.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2757,7 +2759,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2765,7 +2767,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2773,7 +2775,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2781,7 +2783,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "{\"done\": true}"), // Deleting the original file - the deletion returns a failure. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2789,7 +2791,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "", errors::Unavailable("503"), 503), // Deleting the original file again - the deletion returns NOT_FOUND. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -2815,7 +2817,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2823,7 +2825,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2831,7 +2833,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Copying to the new location. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -2854,7 +2856,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { TEST(GcsFileSystemTest, Stat_Object) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -2881,13 +2883,13 @@ TEST(GcsFileSystemTest, Stat_Object) { TEST(GcsFileSystemTest, Stat_Folder) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subfolder?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2915,13 +2917,13 @@ TEST(GcsFileSystemTest, Stat_Folder) { TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -2945,7 +2947,7 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { TEST(GcsFileSystemTest, Stat_Bucket) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -2969,7 +2971,7 @@ TEST(GcsFileSystemTest, Stat_Bucket) { TEST(GcsFileSystemTest, Stat_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -2992,20 +2994,20 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) { TEST(GcsFileSystemTest, Stat_Cache) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subfolder%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3041,14 +3043,14 @@ TEST(GcsFileSystemTest, Stat_Cache) { TEST(GcsFileSystemTest, Stat_Cache_Flush) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3085,7 +3087,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "dir%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3111,14 +3113,14 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { TEST(GcsFileSystemTest, IsDirectory_NotFound) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3141,14 +3143,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) { TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3172,14 +3174,14 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { TEST(GcsFileSystemTest, IsDirectory_Yes) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [{\"name\": \"subfolder/\"}]}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=subfolder%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3203,12 +3205,12 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) { TEST(GcsFileSystemTest, IsDirectory_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); @@ -3229,7 +3231,7 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) { TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); @@ -3254,14 +3256,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { { // File doesn't exist. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simple upload. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3269,7 +3271,7 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { ""), // File exists. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3277,14 +3279,14 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // File doesn't exist again. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "subpath%2F?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}"), // Simulate object uploaded in between. new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=media&name=subpath%2F&ifGenerationMatch=0\n" "Auth Token: fake_token\n" "Post: yes\n" @@ -3316,12 +3318,12 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { TEST(GcsFileSystemTest, CreateDir_Bucket) { std::vector requests( {new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket\n" + "Uri: https://www.googleapis.com./storage/v1/b/bucket\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "")}); @@ -3344,7 +3346,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3353,7 +3355,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/file1.txt\" }]}"), // GetChildren recursively. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3363,35 +3365,35 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Delete the current directory's marker. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object - fails and will be retried. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", "", errors::Unavailable("500"), 500), // Delete the object again. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Delete the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3419,7 +3421,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3428,7 +3430,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/file1.txt\" }]}"), // Calling GetChildren recursively. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3438,14 +3440,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}"), // Deleting the object. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile1.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the directory marker gs://bucket/path/ - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2F\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3453,7 +3455,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/subpath/ is a folder - it is. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3461,14 +3463,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { strings::StrCat("{\"items\": [ " " { \"name\": \"path/subpath/\" }]}")), // Deleting the object gs://bucket/path/subpath/file2.txt - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Fsubpath%2Ffile2.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), // Deleting the object s://bucket/path/file3.txt - fails with 404. - new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + new FakeHttpRequest("Uri: https://www.googleapis.com./storage/v1/b" "/bucket/o/path%2Ffile3.txt\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3476,7 +3478,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "", errors::NotFound("404"), 404), // Checking if gs://bucket/path/file3.txt/ is a folder - it's not. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2Ffile3.txt%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3484,7 +3486,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "{}"), // Checking if gs://bucket/path/file3.txt is an object - fails with 404. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Ffile3.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3512,7 +3514,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { std::vector requests( {// IsDirectory is checking whether there are children objects. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o?" "fields=items%2Fname%2CnextPageToken&prefix=path%2F" "&maxResults=1\n" "Auth Token: fake_token\n" @@ -3520,7 +3522,7 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { "{}"), // IsDirectory is checking if the path exists as an object. new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3604,7 +3606,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::vector requests( {// IsDirectory is checking whether there are children objects. - new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" "Auth Token: fake_token\n" "Header mynewheader: newheadercontents\n" "Header Hello: world\n", @@ -3622,7 +3624,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::unique_ptr request; TF_EXPECT_OK(fs7.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com/fake"); + request->SetUri("https://www.googleapis.com./fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3684,7 +3686,7 @@ TEST(GcsFileSystemTest, OverrideCacheParameters) { TEST(GcsFileSystemTest, CreateHttpRequest) { std::vector requests( {// IsDirectory is checking whether there are children objects. - new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n" + new FakeHttpRequest("Uri: https://www.googleapis.com./fake\n" "Auth Token: fake_token\n" "Header Hello: world\n", "{}")}); @@ -3701,7 +3703,7 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { std::unique_ptr request; TF_EXPECT_OK(fs.CreateHttpRequest(&request)); - request->SetUri("https://www.googleapis.com/fake"); + request->SetUri("https://www.googleapis.com./fake"); request->AddHeader("Hello", "world"); TF_EXPECT_OK(request->Send()); } @@ -3745,7 +3747,7 @@ class TestGcsStats : public GcsStatsInterface { TEST(GcsFileSystemTest, Stat_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "file.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3773,7 +3775,7 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) { TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) { std::vector requests({new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/random_access.txt\n" + "Uri: https://storage.googleapis.com./bucket/random_access.txt\n" "Auth Token: fake_token\n" "Range: 0-5\n" "Timeouts: 5 1 20\n", @@ -3815,7 +3817,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch the file (stats and then content) new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3823,14 +3825,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( "Uri: " - "https://storage.googleapis.com/bucket/some%2Fpath%2Fappendable\n" + "https://storage.googleapis.com./bucket/some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), // Upload entire file new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -3848,7 +3850,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Upload new part to a temporary object new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/upload/storage/v1/b/bucket/" + "https://www.googleapis.com./upload/storage/v1/b/bucket/" "o?uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "18\n" "Auth Token: fake_token\n" @@ -3870,7 +3872,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", @@ -3878,7 +3880,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Compose the new part at the end of the original object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3891,14 +3893,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { ""), // Delete the temporary object. new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable.18\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" "Delete: yes\n", ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=some%2Fpath%2F.tmpcompose%2Fappendable." "27\n" "Auth Token: fake_token\n" @@ -3917,14 +3919,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { // Fetch generation new FakeHttpRequest( "Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"4567\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n" @@ -3936,7 +3938,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "'some/path/.tmpcompose/appendable.27'}]}\n", ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/storage/v1/b/bucket/o/" + "https://www.googleapis.com./storage/v1/b/bucket/o/" "some%2Fpath%2F.tmpcompose%2Fappendable." "27\n" "Auth Token: fake_token\n" @@ -3973,20 +3975,20 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { {"content0,", "content1,", "content2,", "content3,"}); std::vector requests({ new FakeHttpRequest( - "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "Uri: https://www.googleapis.com./storage/v1/b/bucket/o/" "path%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"8\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( - "Uri: https://storage.googleapis.com/bucket/path%2Fappendable\n" + "Uri: https://storage.googleapis.com./bucket/path%2Fappendable\n" "Auth Token: fake_token\n" "Range: 0-1048575\n" "Timeouts: 5 1 20\n", contents[0]), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 18\n" @@ -4003,7 +4005,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], "\n"), ""), new FakeHttpRequest("Uri: " - "https://www.googleapis.com/upload/storage/v1/b/" + "https://www.googleapis.com./upload/storage/v1/b/" "bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" @@ -4024,7 +4026,7 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { contents[0], contents[1], contents[2], "\n"), ""), new FakeHttpRequest( - "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "Uri: https://www.googleapis.com./upload/storage/v1/b/bucket/o?" "uploadType=resumable&name=path%2Fappendable\n" "Auth Token: fake_token\n" "Header X-Upload-Content-Length: 36\n" diff --git a/third_party/tsl/tsl/platform/cpu_info.cc b/third_party/tsl/tsl/platform/cpu_info.cc index c25c354fd37ca..1de5eb8031623 100644 --- a/third_party/tsl/tsl/platform/cpu_info.cc +++ b/third_party/tsl/tsl/platform/cpu_info.cc @@ -82,6 +82,7 @@ class CPUIDInfo { : have_adx_(0), have_aes_(0), have_amx_bf16_(0), + have_amx_fp16_(0), have_amx_int8_(0), have_amx_tile_(0), have_avx_(0), @@ -98,8 +99,11 @@ class CPUIDInfo { have_avx512_4vnniw_(0), have_avx512_4fmaps_(0), have_avx512_bf16_(0), + have_avx512_fp16_(0), have_avx512_vnni_(0), have_avx_vnni_(0), + have_avx_vnni_int8_(0), + have_avx_ne_convert_(0), have_bmi1_(0), have_bmi2_(0), have_cmov_(0), @@ -226,12 +230,19 @@ class CPUIDInfo { cpuid->have_amx_int8_ = (edx >> 25) & 0x1; cpuid->have_amx_bf16_ = (edx >> 22) & 0x1; + // Check for avx512_fp16 using information from Xbyak in oneDNN: + // https://github.com/oneapi-src/oneDNN/blob/acf8d214cedfe7e24c9446bacc1f9f648c9273f8/src/cpu/x64/xbyak/xbyak_util.h#L516 + cpuid->have_avx512_fp16_ = have_avx512 && ((edx >> 23) & 0x1); + // Get more Structured Extended Feature info by issuing CPUID with // sub-leaf = 1 (eax = 7, ecx = 1) if (kMaxNumSubLeaves >= 1) { GETCPUID(eax, ebx, ecx, edx, 7, 1); cpuid->have_avx_vnni_ = (eax >> 4) & 0x1; cpuid->have_avx512_bf16_ = have_avx512 && ((eax >> 5) & 0x1); + cpuid->have_amx_fp16_ = (eax >> 21) & 0x1; + cpuid->have_avx_vnni_int8_ = (edx >> 4) & 0x1; + cpuid->have_avx_ne_convert_ = (edx >> 5) & 0x1; } } @@ -242,6 +253,7 @@ class CPUIDInfo { case ADX: return cpuid->have_adx_; case AES: return cpuid->have_aes_; case AMX_BF16: return cpuid->have_amx_bf16_; + case AMX_FP16: return cpuid->have_amx_fp16_; case AMX_INT8: return cpuid->have_amx_int8_; case AMX_TILE: return cpuid->have_amx_tile_; case AVX2: return cpuid->have_avx2_; @@ -258,8 +270,11 @@ class CPUIDInfo { case AVX512_4VNNIW: return cpuid->have_avx512_4vnniw_; case AVX512_4FMAPS: return cpuid->have_avx512_4fmaps_; case AVX512_BF16: return cpuid->have_avx512_bf16_; + case AVX512_FP16: return cpuid->have_avx512_fp16_; case AVX512_VNNI: return cpuid->have_avx512_vnni_; case AVX_VNNI: return cpuid->have_avx_vnni_; + case AVX_VNNI_INT8: return cpuid->have_avx_vnni_int8_; + case AVX_NE_CONVERT: return cpuid->have_avx_ne_convert_; case BMI1: return cpuid->have_bmi1_; case BMI2: return cpuid->have_bmi2_; case CMOV: return cpuid->have_cmov_; @@ -297,6 +312,7 @@ class CPUIDInfo { int have_adx_ : 1; int have_aes_ : 1; int have_amx_bf16_ : 1; + int have_amx_fp16_ : 1; int have_amx_int8_ : 1; int have_amx_tile_ : 1; int have_avx_ : 1; @@ -313,8 +329,11 @@ class CPUIDInfo { int have_avx512_4vnniw_ : 1; int have_avx512_4fmaps_ : 1; int have_avx512_bf16_ : 1; + int have_avx512_fp16_ : 1; int have_avx512_vnni_ : 1; int have_avx_vnni_ : 1; + int have_avx_vnni_int8_ : 1; + int have_avx_ne_convert_ : 1; int have_bmi1_ : 1; int have_bmi2_ : 1; int have_cmov_ : 1; diff --git a/third_party/tsl/tsl/platform/cpu_info.h b/third_party/tsl/tsl/platform/cpu_info.h index e0b0d66bb1111..68506b1d34ae8 100644 --- a/third_party/tsl/tsl/platform/cpu_info.h +++ b/third_party/tsl/tsl/platform/cpu_info.h @@ -132,6 +132,11 @@ enum CPUFeature { AMX_TILE = 41, // Tile configuration and load/store AMX_INT8 = 42, // Int8 tile matrix multiplication AMX_BF16 = 43, // Bfloat16 tile matrix multiplication + + AVX512_FP16 = 44, // Float16 neural network + AMX_FP16 = 45, // Float16 tile matrix multiplication + AVX_NE_CONVERT = 46, // Instructions for faster bfloat16, float16 convert. + AVX_VNNI_INT8 = 47, // VNNI instructions for combinations of u8, s8 dtypes. }; enum Aarch64CPU { diff --git a/third_party/tsl/tsl/platform/default/BUILD b/third_party/tsl/tsl/platform/default/BUILD index 82ef13527d53a..8e46929122a02 100644 --- a/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/tsl/tsl/platform/default/BUILD @@ -1,13 +1,18 @@ -load("//tsl:tsl.bzl", "if_not_fuchsia", "if_not_windows", "set_external_visibility", "tsl_copts") -load("//tsl:tsl.default.bzl", "filegroup") -load("//tsl/platform:rules_cc.bzl", "cc_library") - # Tensorflow default + linux implementations of tensorflow/core/platform libraries. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load( + "//tsl:tsl.bzl", + "if_not_fuchsia", + "if_not_windows", + "internal_visibility", + "tsl_copts", +) +load("//tsl:tsl.default.bzl", "filegroup", "tsl_grpc_cc_dependencies") +load("//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tensorflow/core/lib/jpeg:__pkg__", "//tensorflow/core/platform:__pkg__", "//tsl/platform:__pkg__", @@ -82,12 +87,11 @@ cc_library( "nobuilder", ], deps = [ - "//tsl/platform:env", - "//tsl/platform:errors", + "//tsl/platform:load_library", "//tsl/platform:logging", "//tsl/platform:path", - "//tsl/platform:status", - "//tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -97,16 +101,6 @@ cc_library( ], ) -cc_library( - name = "dynamic_annotations", - hdrs = ["dynamic_annotations.h"], - tags = [ - "manual", - "no_oss", - "nobuilder", - ], -) - cc_library( name = "env", srcs = [ @@ -124,6 +118,7 @@ cc_library( "//tsl/platform:ram_file_system.h", "//tsl/platform:threadpool.h", ], + copts = tsl_copts(), tags = [ "manual", "no_oss", @@ -213,6 +208,21 @@ cc_library( ], ) +cc_library( + name = "grpc_credentials", + srcs = ["grpc_credentials.cc"], + hdrs = ["//tsl/platform:grpc_credentials.h"], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + deps = [ + "//tsl/platform:logging", + "@com_google_absl//absl/log:check", + ] + tsl_grpc_cc_dependencies(), +) + cc_library( name = "human_readable_json", srcs = ["human_readable_json.cc"], @@ -240,8 +250,7 @@ cc_library( "nobuilder", ], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) @@ -269,7 +278,6 @@ cc_library( filegroup( name = "xla_cpu_runtime_srcs", srcs = [ - "dynamic_annotations.h", "integral_types.h", ] + if_not_windows(["env_time.cc"]), ) @@ -314,20 +322,6 @@ cc_library( alwayslink = True, ) -cc_library( - name = "notification", - hdrs = ["notification.h"], - tags = [ - "manual", - "no_oss", - "nobuilder", - ], - deps = [ - "//tsl/platform:mutex", - "//tsl/platform:types", - ], -) - cc_library( name = "platform_port", srcs = [ @@ -389,24 +383,6 @@ cc_library( ], ) -cc_library( - name = "resource_loader", - testonly = 1, - srcs = ["resource_loader.cc"], - hdrs = ["//tsl/platform:resource_loader.h"], - tags = [ - "manual", - "no_oss", - "nobuilder", - ], - deps = [ - "//tsl/platform:logging", - "//tsl/platform:path", - "//tsl/platform:test", - "@bazel_tools//tools/cpp/runfiles", - ], -) - cc_library( name = "rocm_rocdl_path", srcs = ["rocm_rocdl_path.cc"], @@ -533,7 +509,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["crash_analysis.h"], - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ "//tsl/platform", "//tsl/platform:protobuf", @@ -548,7 +524,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["status.h"], - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) cc_library( @@ -559,7 +535,7 @@ cc_library( "nobuilder", ], textual_hdrs = ["statusor.h"], - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ "//tsl/platform:macros", "//tsl/platform:status", @@ -570,7 +546,7 @@ cc_library( bzl_library( name = "cuda_build_defs_bzl", srcs = ["cuda_build_defs.bzl"], - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), ) bzl_library( @@ -581,7 +557,7 @@ bzl_library( # Export source files needed for mobile builds, which do not use granular targets. filegroup( name = "additional_mobile_srcs_no_runtime", - visibility = set_external_visibility(["//tensorflow/core/platform:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) filegroup( @@ -589,7 +565,6 @@ filegroup( srcs = [ "casts.h", "context.h", - "dynamic_annotations.h", "env.cc", "integral_types.h", "load_library.cc", @@ -603,7 +578,7 @@ filegroup( "//tsl/platform/profile_utils:cpu_utils.h", "//tsl/platform/profile_utils:i_cpu_utils_helper.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/platform:__pkg__", "//tsl/platform:__pkg__", ]), @@ -616,14 +591,13 @@ filegroup( "error_logging.cc", "mutex.h", "mutex_data.h", - "notification.h", "unbounded_work_queue.cc", "unbounded_work_queue.h", ] + if_not_fuchsia([ "subprocess.cc", "subprocess.h", ]), - visibility = set_external_visibility(["//tensorflow/core/platform:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) exports_files( @@ -635,7 +609,7 @@ exports_files( "test.cc", ], ), - visibility = set_external_visibility(["//tensorflow/core/platform:__pkg__"]), + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) exports_files( @@ -644,7 +618,7 @@ exports_files( "logging.h", "test.cc", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gif:__pkg__", "//tensorflow/core/lib/jpeg:__pkg__", diff --git a/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/tsl/tsl/platform/default/build_config.bzl index e8dd1d12abd73..3d08489c73dc6 100644 --- a/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/tsl/tsl/platform/default/build_config.bzl @@ -1,14 +1,14 @@ # Platform-specific build configurations. +load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc") load("@com_google_protobuf//:protobuf.bzl", "proto_gen") -load("//tsl/platform:build_config_root.bzl", "if_static") load( "//tsl:tsl.bzl", "clean_dep", "if_not_windows", "if_tsl_link_protobuf", ) -load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc") +load("//tsl/platform:build_config_root.bzl", "if_static") def well_known_proto_libs(): """Set of standard protobuf protos, like Any and Timestamp. @@ -655,12 +655,10 @@ def tf_additional_lib_hdrs(): clean_dep("//tsl/platform/default:casts.h"), clean_dep("//tsl/platform/default:context.h"), clean_dep("//tsl/platform/default:criticality.h"), - clean_dep("//tsl/platform/default:dynamic_annotations.h"), clean_dep("//tsl/platform/default:integral_types.h"), clean_dep("//tsl/platform/default:logging.h"), clean_dep("//tsl/platform/default:mutex.h"), clean_dep("//tsl/platform/default:mutex_data.h"), - clean_dep("//tsl/platform/default:notification.h"), clean_dep("//tsl/platform/default:stacktrace.h"), clean_dep("//tsl/platform/default:status.h"), clean_dep("//tsl/platform/default:statusor.h"), @@ -734,7 +732,7 @@ def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", clean_dep("@eigen_archive//:eigen3"), - clean_dep("//tsl/platform/default/build_config:proto_parsing"), + clean_dep("//tsl/protobuf:protos_all_cc"), ] def tf_py_clif_cc(name, visibility = None, **kwargs): @@ -820,9 +818,6 @@ def tf_windows_aware_platform_deps(name): def tf_platform_deps(name, platform_dir = "@tsl//tsl/platform/"): return [platform_dir + "default:" + name] -def tf_testing_deps(name, platform_dir = "@tsl//tsl/platform/"): - return tf_platform_deps(name, platform_dir) - def tf_stream_executor_deps(name, platform_dir = "@tsl//tsl/platform/"): return tf_platform_deps(name, platform_dir) @@ -835,6 +830,9 @@ def tf_logging_deps(): def tf_error_logging_deps(): return [clean_dep("//tsl/platform/default:error_logging")] +def tsl_grpc_credentials_deps(): + return [clean_dep("//tsl/platform/default:grpc_credentials")] + def tf_resource_deps(): return [clean_dep("//tsl/platform/default:resource")] @@ -853,14 +851,5 @@ def tf_google_mobile_srcs_no_runtime(): def tf_google_mobile_srcs_only_runtime(): return [] -def if_llvm_aarch64_available(then, otherwise = []): - return then - -def if_llvm_system_z_available(then, otherwise = []): - return select({ - clean_dep("//tsl:linux_s390x"): then, - "//conditions:default": otherwise, - }) - def tf_cuda_libdevice_path_deps(): return tf_platform_deps("cuda_libdevice_path") diff --git a/third_party/tsl/tsl/platform/default/build_config/BUILD b/third_party/tsl/tsl/platform/default/build_config/BUILD deleted file mode 100644 index ac9a71c993cef..0000000000000 --- a/third_party/tsl/tsl/platform/default/build_config/BUILD +++ /dev/null @@ -1,119 +0,0 @@ -# Description: -# Platform-specific build configurations. - -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "set_external_visibility", "tsl_copts") - -package(default_visibility = set_external_visibility(["//tsl:internal"])) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "gtest", - testonly = 1, - copts = tsl_copts(), - deps = [ - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "tensorflow_platform_specific", - copts = tsl_copts(), - linkstatic = 1, - deps = [], -) - -cc_library( - name = "_empty_lib", - visibility = ["//visibility:private"], -) - -# Dummy stream executor cuda plugins. -cc_library( - name = "cublas_plugin", - srcs = [], -) - -cc_library( - name = "cufft_plugin", - srcs = [], -) - -cc_library( - name = "cudnn_plugin", - srcs = [], -) - -# Minimal lib so that tools used for mobile compilation -# don't have to depend on platformlib. -cc_library( - name = "proto_parsing", - copts = tsl_copts(), - deps = [ - "//tsl/protobuf:protos_all_cc", - ], -) - -# Minimal lib to be used by tensorflow/core:framework_lite. -# This provides minimal support for writing operator implementations (kernels), -# and excludes anything that can bloat binary size if used. -cc_library( - name = "minimal", - srcs = [], - copts = tsl_copts(), -) - -cc_library( - name = "gif", - copts = tsl_copts(), - deps = [ - "@gif", - ], -) - -cc_library( - name = "jpeg", - copts = tsl_copts(), - deps = [ - "@libjpeg_turbo//:jpeg", - ], -) - -cc_library( - name = "png", - copts = tsl_copts(), - deps = [ - "@png", - "@zlib", - ], -) - -cc_library( - name = "test_main", - testonly = 1, - linkstatic = 1, - deps = [], -) - -cc_library( - name = "cuda", - data = [ - "@local_config_cuda//cuda:cudart", - ], - linkopts = select({ - "//tsl:macos": [ - "-Wl,-rpath,../local_config_cuda/cuda/lib", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib", - ], - "//conditions:default": [ - "-Wl,-rpath,../local_config_cuda/cuda/lib64", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", - ], - }), - deps = [ - "@local_config_cuda//cuda:cudart", - ], -) diff --git a/third_party/tsl/tsl/platform/default/build_config_root.bzl b/third_party/tsl/tsl/platform/default/build_config_root.bzl index 83912ed0f3956..9280779cbf236 100644 --- a/third_party/tsl/tsl/platform/default/build_config_root.bzl +++ b/third_party/tsl/tsl/platform/default/build_config_root.bzl @@ -60,3 +60,39 @@ def if_static_and_not_mobile(extra_deps, otherwise = []): str(Label("//tsl:ios")): otherwise, "//conditions:default": extra_deps, }) + +def if_llvm_aarch32_available(then, otherwise = []): + return select({ + str(Label("//tsl:aarch32_or_cross")): then, + "//conditions:default": otherwise, + }) + +def if_llvm_aarch64_available(then, otherwise = []): + return select({ + str(Label("//tsl:aarch64_or_cross")): then, + "//conditions:default": otherwise, + }) + +def if_llvm_arm_available(then, otherwise = []): + return select({ + str(Label("//tsl:arm_or_cross")): then, + "//conditions:default": otherwise, + }) + +def if_llvm_powerpc_available(then, otherwise = []): + return select({ + str(Label("//tsl:ppc64le_or_cross")): then, + "//conditions:default": otherwise, + }) + +def if_llvm_system_z_available(then, otherwise = []): + return select({ + str(Label("//tsl:s390x_or_cross")): then, + "//conditions:default": otherwise, + }) + +def if_llvm_x86_available(then, otherwise = []): + return select({ + str(Label("//tsl:x86_or_cross")): then, + "//conditions:default": otherwise, + }) diff --git a/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl b/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl index ad89515cd34c0..1f7f52a627b76 100644 --- a/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl +++ b/third_party/tsl/tsl/platform/default/cuda_build_defs.bzl @@ -1,6 +1,10 @@ """Open source build configurations for CUDA.""" -load("@local_config_cuda//cuda:build_defs.bzl", _if_cuda_is_configured = "if_cuda_is_configured") +load( + "@local_config_cuda//cuda:build_defs.bzl", + _if_cuda_is_configured = "if_cuda_is_configured", + _if_cuda_newer_than = "if_cuda_newer_than", +) # We perform this indirection so that the copybara tool can distinguish this # macro from others provided by the same file. @@ -16,3 +20,6 @@ def cuda_rpath_flags(relpath): "-Wl,-rpath='$$ORIGIN/../../" + relpath + "'", "-Wl,-rpath='$$ORIGIN/../" + relpath + "'", ] + +def if_cuda_newer_than(wanted_ver, if_true, if_false = []): + return _if_cuda_newer_than(wanted_ver, if_true, if_false) diff --git a/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc index ed2ffece58819..46321e74b5dc3 100644 --- a/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tsl/platform/path.h" #include "tsl/platform/platform.h" #if defined(PLATFORM_POSIX) && !defined(__APPLE__) @@ -44,15 +45,14 @@ std::vector CandidateCudaRoots() { Dl_info info; if (dladdr(&__FUNCTION__, &info)) { - auto lib = std::vector{info.dli_fname, - info.dli_fname + strlen(info.dli_fname)}; - auto dir = dirname(lib.data()); + auto lib = std::string(info.dli_fname); + auto dir = io::Dirname(lib); // TF lib binaries are located in both the package's root dir and within a // 'python' subdirectory (for pywrap libs). So we check two possible paths // relative to the current binary for the wheel-based nvcc package. - for (auto path : {"/../nvidia/cuda_nvcc", "/../../nvidia/cuda_nvcc"}) - roots.emplace_back(std::string(dir) + path); + for (auto path : {"../nvidia/cuda_nvcc", "../../nvidia/cuda_nvcc"}) + roots.emplace_back(io::JoinPath(dir, path)); } #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__) diff --git a/third_party/tsl/tsl/platform/default/dlopen_checker.cc b/third_party/tsl/tsl/platform/default/dlopen_checker.cc index 2d67789d8a001..da840d6153604 100644 --- a/third_party/tsl/tsl/platform/default/dlopen_checker.cc +++ b/third_party/tsl/tsl/platform/default/dlopen_checker.cc @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "tsl/platform/default/dso_loader.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace internal { namespace DsoLoader { -Status TryDlopenCUDALibraries() { +absl::Status TryDlopenCUDALibraries() { namespace CachedLoader = ::tsl::internal::CachedDsoLoader; auto cudart_status = CachedLoader::GetCudaRuntimeDsoHandle(); auto cublas_status = CachedLoader::GetCublasDsoHandle(); @@ -36,14 +36,14 @@ Status TryDlopenCUDALibraries() { !cufft_status.status().ok() || !cusolver_status.status().ok() || !cusparse_status.status().ok() || !cudnn_status.status().ok() || !cublaslt_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all CUDA libraries.")); + return absl::Status(absl::StatusCode::kInternal, + absl::StrCat("Cannot dlopen all CUDA libraries.")); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status TryDlopenROCmLibraries() { +absl::Status TryDlopenROCmLibraries() { auto rocblas_status = GetRocblasDsoHandle(); auto miopen_status = GetMiopenDsoHandle(); auto rocfft_status = GetHipfftDsoHandle(); @@ -57,32 +57,30 @@ Status TryDlopenROCmLibraries() { || !hipblaslt_status.status().ok() #endif ) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all ROCm libraries.")); + return absl::InternalError("Cannot dlopen all ROCm libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { #if GOOGLE_CUDA return TryDlopenCUDALibraries(); #elif TENSORFLOW_USE_ROCM return TryDlopenROCmLibraries(); #else LOG(INFO) << "Not built with GPU enabled. Skip GPU library dlopen check."; - return tsl::OkStatus(); + return absl::OkStatus(); #endif } -Status TryDlopenTensorRTLibraries() { +absl::Status TryDlopenTensorRTLibraries() { auto nvinfer_status = GetNvInferDsoHandle(); auto nvinferplugin_status = GetNvInferPluginDsoHandle(); if (!nvinfer_status.status().ok() || !nvinferplugin_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all TensorRT libraries.")); + return absl::InternalError("Cannot dlopen all TensorRT libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } diff --git a/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc b/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc index 1d4b213427b5a..67f734302835d 100644 --- a/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc +++ b/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tsl/platform/default/dso_loader.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace tsl { namespace internal { namespace DsoLoader { // Skip check when GPU libraries are statically linked. -Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { LOG(INFO) << "GPU libraries are statically linked, skip dlopen check."; - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace DsoLoader } // namespace internal diff --git a/third_party/tsl/tsl/platform/default/dso_loader.cc b/third_party/tsl/tsl/platform/default/dso_loader.cc index eeff5d9e7ed94..a835a81489367 100644 --- a/third_party/tsl/tsl/platform/default/dso_loader.cc +++ b/third_party/tsl/tsl/platform/default/dso_loader.cc @@ -16,17 +16,18 @@ limitations under the License. #include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "third_party/gpus/cuda/cuda_config.h" #include "third_party/nccl/nccl_config.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/load_library.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "third_party/tensorrt/tensorrt_config.h" #if TENSORFLOW_USE_ROCM @@ -37,22 +38,23 @@ namespace tsl { namespace internal { namespace { -string GetCudaVersion() { return TF_CUDA_VERSION; } -string GetCudaRtVersion() { return TF_CUDART_VERSION; } -string GetCuptiVersion() { return TF_CUPTI_VERSION; } -string GetCudnnVersion() { return TF_CUDNN_VERSION; } -string GetCublasVersion() { return TF_CUBLAS_VERSION; } -string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } -string GetCufftVersion() { return TF_CUFFT_VERSION; } -string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } -string GetNcclVersion() { return TF_NCCL_VERSION; } -string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } - -StatusOr GetDsoHandle(const string& name, const string& version) { - auto filename = Env::Default()->FormatLibraryFileName(name, version); +std::string GetCudaVersion() { return TF_CUDA_VERSION; } +std::string GetCudaRtVersion() { return TF_CUDART_VERSION; } +std::string GetCuptiVersion() { return TF_CUPTI_VERSION; } +std::string GetCudnnVersion() { return TF_CUDNN_VERSION; } +std::string GetCublasVersion() { return TF_CUBLAS_VERSION; } +std::string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } +std::string GetCufftVersion() { return TF_CUFFT_VERSION; } +std::string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } +std::string GetNcclVersion() { return TF_NCCL_VERSION; } +std::string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } + +absl::StatusOr GetDsoHandle(const std::string& name, + const std::string& version) { + auto filename = tsl::internal::FormatLibraryFileName(name, version); void* dso_handle; - Status status = - Env::Default()->LoadDynamicLibrary(filename.c_str(), &dso_handle); + absl::Status status = + tsl::internal::LoadDynamicLibrary(filename.c_str(), &dso_handle); if (status.ok()) { VLOG(1) << "Successfully opened dynamic library " << filename; return dso_handle; @@ -66,12 +68,12 @@ StatusOr GetDsoHandle(const string& name, const string& version) { } #endif VLOG(1) << message; - return Status(absl::StatusCode::kFailedPrecondition, message); + return absl::Status(absl::StatusCode::kFailedPrecondition, message); } } // namespace namespace DsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvcuda", ""); #elif defined(__APPLE__) @@ -85,31 +87,31 @@ StatusOr GetCudaDriverDsoHandle() { return GetDsoHandle("cuda", "1"); } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr GetCudaRuntimeDsoHandle() { return GetDsoHandle("cudart", GetCudaRtVersion()); } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { return GetDsoHandle("cublas", GetCublasVersion()); } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { return GetDsoHandle("cublasLt", GetCublasVersion()); } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { return GetDsoHandle("cufft", GetCufftVersion()); } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { return GetDsoHandle("cusolver", GetCusolverVersion()); } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { return GetDsoHandle("cusparse", GetCusparseVersion()); } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { // Load specific version of CUPTI this is built. auto status_or_handle = GetDsoHandle("cupti", GetCuptiVersion()); if (status_or_handle.ok()) return status_or_handle; @@ -117,15 +119,15 @@ StatusOr GetCuptiDsoHandle() { return GetDsoHandle("cupti", ""); } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { return GetDsoHandle("cudnn", GetCudnnVersion()); } -StatusOr GetNcclDsoHandle() { +absl::StatusOr GetNcclDsoHandle() { return GetDsoHandle("nccl", GetNcclVersion()); } -StatusOr GetNvInferDsoHandle() { +absl::StatusOr GetNvInferDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvinfer", ""); #else @@ -133,7 +135,7 @@ StatusOr GetNvInferDsoHandle() { #endif } -StatusOr GetNvInferPluginDsoHandle() { +absl::StatusOr GetNvInferPluginDsoHandle() { #if defined(PLATFORM_WINDOWS) return GetDsoHandle("nvinfer_plugin", ""); #else @@ -141,134 +143,142 @@ StatusOr GetNvInferPluginDsoHandle() { #endif } -StatusOr GetRocblasDsoHandle() { return GetDsoHandle("rocblas", ""); } +absl::StatusOr GetRocblasDsoHandle() { + return GetDsoHandle("rocblas", ""); +} -StatusOr GetMiopenDsoHandle() { return GetDsoHandle("MIOpen", ""); } +absl::StatusOr GetMiopenDsoHandle() { + return GetDsoHandle("MIOpen", ""); +} -StatusOr GetHipfftDsoHandle() { return GetDsoHandle("hipfft", ""); } +absl::StatusOr GetHipfftDsoHandle() { + return GetDsoHandle("hipfft", ""); +} -StatusOr GetRocrandDsoHandle() { return GetDsoHandle("rocrand", ""); } +absl::StatusOr GetRocrandDsoHandle() { + return GetDsoHandle("rocrand", ""); +} -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { return GetDsoHandle("rocsolver", ""); } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { return GetDsoHandle("hipsolver", ""); } #endif -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { return GetDsoHandle("roctracer64", ""); } -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { return GetDsoHandle("hipsparse", ""); } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { return GetDsoHandle("hipblaslt", ""); } -StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } +absl::StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } } // namespace DsoLoader namespace CachedDsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { static auto result = new auto(DsoLoader::GetCudaDriverDsoHandle()); return *result; } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr GetCudaRuntimeDsoHandle() { static auto result = new auto(DsoLoader::GetCudaRuntimeDsoHandle()); return *result; } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { static auto result = new auto(DsoLoader::GetCublasDsoHandle()); return *result; } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { static auto result = new auto(DsoLoader::GetCublasLtDsoHandle()); return *result; } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { static auto result = new auto(DsoLoader::GetCufftDsoHandle()); return *result; } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { static auto result = new auto(DsoLoader::GetCusolverDsoHandle()); return *result; } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { static auto result = new auto(DsoLoader::GetCusparseDsoHandle()); return *result; } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { static auto result = new auto(DsoLoader::GetCuptiDsoHandle()); return *result; } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { static auto result = new auto(DsoLoader::GetCudnnDsoHandle()); return *result; } -StatusOr GetRocblasDsoHandle() { +absl::StatusOr GetRocblasDsoHandle() { static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); return *result; } -StatusOr GetMiopenDsoHandle() { +absl::StatusOr GetMiopenDsoHandle() { static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); return *result; } -StatusOr GetHipfftDsoHandle() { +absl::StatusOr GetHipfftDsoHandle() { static auto result = new auto(DsoLoader::GetHipfftDsoHandle()); return *result; } -StatusOr GetRocrandDsoHandle() { +absl::StatusOr GetRocrandDsoHandle() { static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); return *result; } -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { static auto result = new auto(DsoLoader::GetRoctracerDsoHandle()); return *result; } -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { static auto result = new auto(DsoLoader::GetRocsolverDsoHandle()); return *result; } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { static auto result = new auto(DsoLoader::GetHipsolverDsoHandle()); return *result; } #endif -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { static auto result = new auto(DsoLoader::GetHipsparseDsoHandle()); return *result; } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { static auto result = new auto(DsoLoader::GetHipblasltDsoHandle()); return *result; } -StatusOr GetHipDsoHandle() { +absl::StatusOr GetHipDsoHandle() { static auto result = new auto(DsoLoader::GetHipDsoHandle()); return *result; } diff --git a/third_party/tsl/tsl/platform/default/dso_loader.h b/third_party/tsl/tsl/platform/default/dso_loader.h index ee5b2b28af348..6f72484d504f5 100644 --- a/third_party/tsl/tsl/platform/default/dso_loader.h +++ b/third_party/tsl/tsl/platform/default/dso_loader.h @@ -19,8 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ #define TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" namespace tsl { namespace internal { @@ -28,65 +28,65 @@ namespace internal { namespace DsoLoader { // The following methods either load the DSO of interest and return a dlopen // handle or error status. -StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); -StatusOr GetNcclDsoHandle(); -StatusOr GetNvInferDsoHandle(); -StatusOr GetNvInferPluginDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetNcclDsoHandle(); +absl::StatusOr GetNvInferDsoHandle(); +absl::StatusOr GetNvInferPluginDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipDsoHandle(); // The following method tries to dlopen all necessary GPU libraries for the GPU // platform TF is built with (CUDA or ROCm) only when these libraries should be // dynamically loaded. Error status is returned when any of the libraries cannot // be dlopened. -Status MaybeTryDlopenGPULibraries(); +absl::Status MaybeTryDlopenGPULibraries(); // The following method tries to dlopen all necessary TensorRT libraries when // these libraries should be dynamically loaded. Error status is returned when // any of the libraries cannot be dlopened. -Status TryDlopenTensorRTLibraries(); +absl::Status TryDlopenTensorRTLibraries(); } // namespace DsoLoader // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs // more than once. namespace CachedDsoLoader { // Cached versions of the corresponding DsoLoader methods above. -StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipblasltDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipblasltDsoHandle(); +absl::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader } // namespace internal diff --git a/third_party/tsl/tsl/platform/default/dynamic_annotations.h b/third_party/tsl/tsl/platform/default/dynamic_annotations.h deleted file mode 100644 index 4d275cc1169a6..0000000000000 --- a/third_party/tsl/tsl/platform/default/dynamic_annotations.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ -#define TENSORFLOW_TSL_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ - -// IWYU pragma: private, include "tsl/platform/dynamic_annotations.h" -// IWYU pragma: friend third_party/tensorflow/tsl/platform/dynamic_annotations.h - -// Do nothing for this platform. - -#define TF_ANNOTATE_MEMORY_IS_INITIALIZED(ptr, bytes) \ - do { \ - } while (0) - -#define TF_ANNOTATE_BENIGN_RACE(ptr, description) \ - do { \ - } while (0) - -#define TF_ATTRIBUTE_NO_SANITIZE_MEMORY - -#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_ diff --git a/third_party/tsl/tsl/platform/default/env.cc b/third_party/tsl/tsl/platform/default/env.cc index 62245dee98e63..6786be68aa4ef 100644 --- a/third_party/tsl/tsl/platform/default/env.cc +++ b/third_party/tsl/tsl/platform/default/env.cc @@ -189,13 +189,13 @@ class PosixEnv : public Env { }); } - Status LoadDynamicLibrary(const char* library_filename, - void** handle) override { + absl::Status LoadDynamicLibrary(const char* library_filename, + void** handle) override { return internal::LoadDynamicLibrary(library_filename, handle); } - Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) override { + absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { return internal::GetSymbolFromLibrary(handle, symbol_name, symbol); } @@ -218,7 +218,7 @@ class PosixEnv : public Env { // See if we have the executable path. if executable.runfiles exists, return // that folder. string runfiles_path = bin_path + runfiles_suffix; - Status s = this->IsDirectory(runfiles_path); + absl::Status s = this->IsDirectory(runfiles_path); if (s.ok()) { return runfiles_path; } diff --git a/third_party/tsl/tsl/platform/default/grpc_credentials.cc b/third_party/tsl/tsl/platform/default/grpc_credentials.cc new file mode 100644 index 0000000000000..44850f56e0519 --- /dev/null +++ b/third_party/tsl/tsl/platform/default/grpc_credentials.cc @@ -0,0 +1,42 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tsl/platform/grpc_credentials.h" + +#include + +#include "absl/log/check.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "tsl/platform/logging.h" + +namespace tsl { + +std::shared_ptr GetClientCredentials( + bool verify_secure_credentials) { + CHECK(!verify_secure_credentials) + << "Insecure gRPC credentials are unexpectedly used!"; + LOG(INFO) << "gRPC insecure client credentials are used."; + return grpc::InsecureChannelCredentials(); +} + +std::shared_ptr GetServerCredentials( + bool verify_secure_credentials) { + CHECK(!verify_secure_credentials) + << "Insecure gRPC credentials are unexpectedly used!"; + LOG(INFO) << "gRPC insecure server credentials are used."; + return grpc::InsecureServerCredentials(); +} + +} // namespace tsl diff --git a/third_party/tsl/tsl/platform/default/human_readable_json.cc b/third_party/tsl/tsl/platform/default/human_readable_json.cc index d2b757abfa8f5..5f2685cd93f76 100644 --- a/third_party/tsl/tsl/platform/default/human_readable_json.cc +++ b/third_party/tsl/tsl/platform/default/human_readable_json.cc @@ -21,8 +21,9 @@ limitations under the License. namespace tsl { -Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result, - bool ignore_accuracy_loss) { +absl::Status ProtoToHumanReadableJson(const protobuf::Message& proto, + string* result, + bool ignore_accuracy_loss) { result->clear(); protobuf::util::JsonPrintOptions json_options; @@ -38,16 +39,18 @@ Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result, strings::StrCat("Could not convert proto to JSON string: ", StringPiece(error_msg.data(), error_msg.length()))); } - return OkStatus(); + return absl::OkStatus(); } -Status ProtoToHumanReadableJson(const protobuf::MessageLite& proto, - string* result, bool ignore_accuracy_loss) { +absl::Status ProtoToHumanReadableJson(const protobuf::MessageLite& proto, + string* result, + bool ignore_accuracy_loss) { *result = "[human readable output not available for lite protos]"; - return OkStatus(); + return absl::OkStatus(); } -Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto) { +absl::Status HumanReadableJsonToProto(const string& str, + protobuf::Message* proto) { proto->Clear(); auto status = protobuf::util::JsonStringToMessage(str, proto); if (!status.ok()) { @@ -58,11 +61,11 @@ Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto) { strings::StrCat("Could not convert JSON string to proto: ", StringPiece(error_msg.data(), error_msg.length()))); } - return OkStatus(); + return absl::OkStatus(); } -Status HumanReadableJsonToProto(const string& str, - protobuf::MessageLite* proto) { +absl::Status HumanReadableJsonToProto(const string& str, + protobuf::MessageLite* proto) { return errors::Internal("Cannot parse JSON protos on Android"); } diff --git a/third_party/tsl/tsl/platform/default/load_library.cc b/third_party/tsl/tsl/platform/default/load_library.cc index f49adf2f7f257..70961c8dc990e 100644 --- a/third_party/tsl/tsl/platform/default/load_library.cc +++ b/third_party/tsl/tsl/platform/default/load_library.cc @@ -17,26 +17,26 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL); if (!*handle) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { // Check that the handle is not NULL to avoid dlsym's RTLD_DEFAULT behavior. if (!handle) { *symbol = nullptr; @@ -46,14 +46,14 @@ Status GetSymbolFromLibrary(void* handle, const char* symbol_name, if (!*symbol) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; #if defined(__APPLE__) if (version.size() == 0) { filename = "lib" + name + ".dylib"; diff --git a/third_party/tsl/tsl/platform/default/notification.h b/third_party/tsl/tsl/platform/default/notification.h deleted file mode 100644 index 99a1fa89b942c..0000000000000 --- a/third_party/tsl/tsl/platform/default/notification.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_NOTIFICATION_H_ -#define TENSORFLOW_TSL_PLATFORM_DEFAULT_NOTIFICATION_H_ - -#include - -#include // NOLINT -#include // NOLINT -#include // NOLINT - -#include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" - -namespace tsl { - -class Notification { - public: - Notification() : notified_(0) {} - ~Notification() { - // In case the notification is being used to synchronize its own deletion, - // force any prior notifier to leave its critical section before the object - // is destroyed. - mutex_lock l(mu_); - } - - void Notify() { - mutex_lock l(mu_); - assert(!HasBeenNotified()); - notified_.store(true, std::memory_order_release); - cv_.notify_all(); - } - - bool HasBeenNotified() const { - return notified_.load(std::memory_order_acquire); - } - - void WaitForNotification() { - if (!HasBeenNotified()) { - mutex_lock l(mu_); - while (!HasBeenNotified()) { - cv_.wait(l); - } - } - } - - private: - friend bool WaitForNotificationWithTimeout(Notification* n, - int64_t timeout_in_us); - bool WaitForNotificationWithTimeout(int64_t timeout_in_us) { - bool notified = HasBeenNotified(); - if (!notified) { - mutex_lock l(mu_); - do { - notified = HasBeenNotified(); - } while (!notified && - cv_.wait_for(l, std::chrono::microseconds(timeout_in_us)) != - std::cv_status::timeout); - } - return notified; - } - - mutex mu_; // protects mutations of notified_ - condition_variable cv_; // signaled when notified_ becomes non-zero - std::atomic notified_; // mutations under mu_ -}; - -inline bool WaitForNotificationWithTimeout(Notification* n, - int64_t timeout_in_us) { - return n->WaitForNotificationWithTimeout(timeout_in_us); -} - -} // namespace tsl - -#endif // TENSORFLOW_TSL_PLATFORM_DEFAULT_NOTIFICATION_H_ diff --git a/third_party/tsl/tsl/platform/default/port.cc b/third_party/tsl/tsl/platform/default/port.cc index c2151c78ec533..868fb35f887da 100644 --- a/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/tsl/tsl/platform/default/port.cc @@ -15,6 +15,7 @@ limitations under the License. #include "absl/base/internal/sysinfo.h" #include "tsl/platform/cpu_info.h" +#include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" @@ -256,7 +257,6 @@ int NUMAGetThreadNodeAffinity() { return node_index; } - void* NUMAMalloc(int node, size_t size, int minimum_alignment) { #ifdef TENSORFLOW_USE_NUMA if (HaveHWLocTopology()) { @@ -307,7 +307,6 @@ int NUMAGetMemAffinity(const void* addr) { return node; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -447,5 +446,8 @@ MemoryBandwidthInfo GetMemoryBandwidthInfo() { MemoryBandwidthInfo membw_info = {INT64_MAX}; return membw_info; } + +IOStatistics GetIOStatistics() { return IOStatistics(); } + } // namespace port } // namespace tsl diff --git a/third_party/tsl/tsl/platform/default/posix_file_system.cc b/third_party/tsl/tsl/platform/default/posix_file_system.cc index 16b1be88329dc..c87ba18019744 100644 --- a/third_party/tsl/tsl/platform/default/posix_file_system.cc +++ b/third_party/tsl/tsl/platform/default/posix_file_system.cc @@ -60,14 +60,14 @@ class PosixRandomAccessFile : public RandomAccessFile { } } - Status Name(StringPiece* result) const override { + absl::Status Name(StringPiece* result) const override { *result = filename_; - return OkStatus(); + return absl::OkStatus(); } - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { - Status s; + absl::Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + absl::Status s; char* dst = scratch; while (n > 0 && s.ok()) { // Some platforms, notably macs, throw EINVAL if pread is asked to read @@ -85,8 +85,8 @@ class PosixRandomAccessFile : public RandomAccessFile { n -= r; offset += r; } else if (r == 0) { - s = Status(absl::StatusCode::kOutOfRange, - "Read less bytes than requested"); + s = absl::Status(absl::StatusCode::kOutOfRange, + "Read less bytes than requested"); } else if (errno == EINTR || errno == EAGAIN) { // Retry } else { @@ -98,9 +98,9 @@ class PosixRandomAccessFile : public RandomAccessFile { } #if defined(TF_CORD_SUPPORT) - Status Read(uint64 offset, size_t n, absl::Cord* cord) const override { + absl::Status Read(uint64 offset, size_t n, absl::Cord* cord) const override { if (n == 0) { - return OkStatus(); + return absl::OkStatus(); } if (n < 0) { return errors::InvalidArgument( @@ -115,7 +115,7 @@ class PosixRandomAccessFile : public RandomAccessFile { } StringPiece tmp; - Status s = Read(offset, n, &tmp, scratch); + absl::Status s = Read(offset, n, &tmp, scratch); absl::Cord tmp_cord = absl::MakeCordFromExternal( absl::string_view(static_cast(scratch), tmp.size()), @@ -142,32 +142,32 @@ class PosixWritableFile : public WritableFile { } } - Status Append(StringPiece data) override { + absl::Status Append(StringPiece data) override { size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { return IOError(filename_, errno); } - return OkStatus(); + return absl::OkStatus(); } #if defined(TF_CORD_SUPPORT) // \brief Append 'cord' to the file. - Status Append(const absl::Cord& cord) override { + absl::Status Append(const absl::Cord& cord) override { for (const auto& chunk : cord.Chunks()) { size_t r = fwrite(chunk.data(), 1, chunk.size(), file_); if (r != chunk.size()) { return IOError(filename_, errno); } } - return OkStatus(); + return absl::OkStatus(); } #endif - Status Close() override { + absl::Status Close() override { if (file_ == nullptr) { return IOError(filename_, EBADF); } - Status result; + absl::Status result; if (fclose(file_) != 0) { result = IOError(filename_, errno); } @@ -175,28 +175,28 @@ class PosixWritableFile : public WritableFile { return result; } - Status Flush() override { + absl::Status Flush() override { if (fflush(file_) != 0) { return IOError(filename_, errno); } - return OkStatus(); + return absl::OkStatus(); } - Status Name(StringPiece* result) const override { + absl::Status Name(StringPiece* result) const override { *result = filename_; - return OkStatus(); + return absl::OkStatus(); } - Status Sync() override { - Status s; + absl::Status Sync() override { + absl::Status s; if (fflush(file_) != 0) { s = IOError(filename_, errno); } return s; } - Status Tell(int64_t* position) override { - Status s; + absl::Status Tell(int64_t* position) override { + absl::Status s; *position = ftell(file_); if (*position == -1) { @@ -222,11 +222,11 @@ class PosixReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { const uint64 length_; }; -Status PosixFileSystem::NewRandomAccessFile( +absl::Status PosixFileSystem::NewRandomAccessFile( const string& fname, TransactionToken* token, std::unique_ptr* result) { string translated_fname = TranslateName(fname); - Status s; + absl::Status s; int fd = open(translated_fname.c_str(), O_RDONLY); if (fd < 0) { s = IOError(fname, errno); @@ -236,11 +236,11 @@ Status PosixFileSystem::NewRandomAccessFile( return s; } -Status PosixFileSystem::NewWritableFile(const string& fname, - TransactionToken* token, - std::unique_ptr* result) { +absl::Status PosixFileSystem::NewWritableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) { string translated_fname = TranslateName(fname); - Status s; + absl::Status s; FILE* f = fopen(translated_fname.c_str(), "w"); if (f == nullptr) { s = IOError(fname, errno); @@ -250,11 +250,11 @@ Status PosixFileSystem::NewWritableFile(const string& fname, return s; } -Status PosixFileSystem::NewAppendableFile( +absl::Status PosixFileSystem::NewAppendableFile( const string& fname, TransactionToken* token, std::unique_ptr* result) { string translated_fname = TranslateName(fname); - Status s; + absl::Status s; FILE* f = fopen(translated_fname.c_str(), "a"); if (f == nullptr) { s = IOError(fname, errno); @@ -264,11 +264,11 @@ Status PosixFileSystem::NewAppendableFile( return s; } -Status PosixFileSystem::NewReadOnlyMemoryRegionFromFile( +absl::Status PosixFileSystem::NewReadOnlyMemoryRegionFromFile( const string& fname, TransactionToken* token, std::unique_ptr* result) { string translated_fname = TranslateName(fname); - Status s = OkStatus(); + absl::Status s = absl::OkStatus(); int fd = open(translated_fname.c_str(), O_RDONLY); if (fd < 0) { s = IOError(fname, errno); @@ -289,16 +289,17 @@ Status PosixFileSystem::NewReadOnlyMemoryRegionFromFile( return s; } -Status PosixFileSystem::FileExists(const string& fname, - TransactionToken* token) { +absl::Status PosixFileSystem::FileExists(const string& fname, + TransactionToken* token) { if (access(TranslateName(fname).c_str(), F_OK) == 0) { - return OkStatus(); + return absl::OkStatus(); } return errors::NotFound(fname, " not found"); } -Status PosixFileSystem::GetChildren(const string& dir, TransactionToken* token, - std::vector* result) { +absl::Status PosixFileSystem::GetChildren(const string& dir, + TransactionToken* token, + std::vector* result) { string translated_dir = TranslateName(dir); result->clear(); DIR* d = opendir(translated_dir.c_str()); @@ -315,25 +316,26 @@ Status PosixFileSystem::GetChildren(const string& dir, TransactionToken* token, if (closedir(d) < 0) { return IOError(dir, errno); } - return OkStatus(); + return absl::OkStatus(); } -Status PosixFileSystem::GetMatchingPaths(const string& pattern, - TransactionToken* token, - std::vector* results) { +absl::Status PosixFileSystem::GetMatchingPaths(const string& pattern, + TransactionToken* token, + std::vector* results) { return internal::GetMatchingPaths(this, Env::Default(), pattern, results); } -Status PosixFileSystem::DeleteFile(const string& fname, - TransactionToken* token) { - Status result; +absl::Status PosixFileSystem::DeleteFile(const string& fname, + TransactionToken* token) { + absl::Status result; if (unlink(TranslateName(fname).c_str()) != 0) { result = IOError(fname, errno); } return result; } -Status PosixFileSystem::CreateDir(const string& name, TransactionToken* token) { +absl::Status PosixFileSystem::CreateDir(const string& name, + TransactionToken* token) { string translated = TranslateName(name); if (translated.empty()) { return errors::AlreadyExists(name); @@ -341,20 +343,22 @@ Status PosixFileSystem::CreateDir(const string& name, TransactionToken* token) { if (mkdir(translated.c_str(), 0755) != 0) { return IOError(name, errno); } - return OkStatus(); + return absl::OkStatus(); } -Status PosixFileSystem::DeleteDir(const string& name, TransactionToken* token) { - Status result; +absl::Status PosixFileSystem::DeleteDir(const string& name, + TransactionToken* token) { + absl::Status result; if (rmdir(TranslateName(name).c_str()) != 0) { result = IOError(name, errno); } return result; } -Status PosixFileSystem::GetFileSize(const string& fname, - TransactionToken* token, uint64* size) { - Status s; +absl::Status PosixFileSystem::GetFileSize(const string& fname, + TransactionToken* token, + uint64* size) { + absl::Status s; struct stat sbuf; if (stat(TranslateName(fname).c_str(), &sbuf) != 0) { *size = 0; @@ -365,9 +369,9 @@ Status PosixFileSystem::GetFileSize(const string& fname, return s; } -Status PosixFileSystem::Stat(const string& fname, TransactionToken* token, - FileStatistics* stats) { - Status s; +absl::Status PosixFileSystem::Stat(const string& fname, TransactionToken* token, + FileStatistics* stats) { + absl::Status s; struct stat sbuf; if (stat(TranslateName(fname).c_str(), &sbuf) != 0) { s = IOError(fname, errno); @@ -379,17 +383,18 @@ Status PosixFileSystem::Stat(const string& fname, TransactionToken* token, return s; } -Status PosixFileSystem::RenameFile(const string& src, const string& target, - TransactionToken* token) { - Status result; +absl::Status PosixFileSystem::RenameFile(const string& src, + const string& target, + TransactionToken* token) { + absl::Status result; if (rename(TranslateName(src).c_str(), TranslateName(target).c_str()) != 0) { result = IOError(src, errno); } return result; } -Status PosixFileSystem::CopyFile(const string& src, const string& target, - TransactionToken* token) { +absl::Status PosixFileSystem::CopyFile(const string& src, const string& target, + TransactionToken* token) { string translated_src = TranslateName(src); struct stat sbuf; if (stat(translated_src.c_str(), &sbuf) != 0) { @@ -438,18 +443,18 @@ Status PosixFileSystem::CopyFile(const string& src, const string& target, } } - Status result = OkStatus(); + absl::Status result = absl::OkStatus(); if (rc < 0) { result = IOError(target, errno); } // Keep the error code rc = close(target_fd); - if (rc < 0 && result == OkStatus()) { + if (rc < 0 && result == absl::OkStatus()) { result = IOError(target, errno); } rc = close(src_fd); - if (rc < 0 && result == OkStatus()) { + if (rc < 0 && result == absl::OkStatus()) { result = IOError(target, errno); } diff --git a/third_party/tsl/tsl/platform/default/posix_file_system.h b/third_party/tsl/tsl/platform/default/posix_file_system.h index 2da9f03889c46..877a65319c6df 100644 --- a/third_party/tsl/tsl/platform/default/posix_file_system.h +++ b/third_party/tsl/tsl/platform/default/posix_file_system.h @@ -29,45 +29,48 @@ class PosixFileSystem : public FileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - Status NewRandomAccessFile( + absl::Status NewRandomAccessFile( const string& filename, TransactionToken* token, std::unique_ptr* result) override; - Status NewWritableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; + absl::Status NewWritableFile(const string& fname, TransactionToken* token, + std::unique_ptr* result) override; - Status NewAppendableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; + absl::Status NewAppendableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) override; - Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewReadOnlyMemoryRegionFromFile( const string& filename, TransactionToken* token, std::unique_ptr* result) override; - Status FileExists(const string& fname, TransactionToken* token) override; + absl::Status FileExists(const string& fname, + TransactionToken* token) override; - Status GetChildren(const string& dir, TransactionToken* token, - std::vector* result) override; + absl::Status GetChildren(const string& dir, TransactionToken* token, + std::vector* result) override; - Status Stat(const string& fname, TransactionToken* token, - FileStatistics* stats) override; + absl::Status Stat(const string& fname, TransactionToken* token, + FileStatistics* stats) override; - Status GetMatchingPaths(const string& pattern, TransactionToken* token, - std::vector* results) override; + absl::Status GetMatchingPaths(const string& pattern, TransactionToken* token, + std::vector* results) override; - Status DeleteFile(const string& fname, TransactionToken* token) override; + absl::Status DeleteFile(const string& fname, + TransactionToken* token) override; - Status CreateDir(const string& name, TransactionToken* token) override; + absl::Status CreateDir(const string& name, TransactionToken* token) override; - Status DeleteDir(const string& name, TransactionToken* token) override; + absl::Status DeleteDir(const string& name, TransactionToken* token) override; - Status GetFileSize(const string& fname, TransactionToken* token, - uint64* size) override; + absl::Status GetFileSize(const string& fname, TransactionToken* token, + uint64* size) override; - Status RenameFile(const string& src, const string& target, - TransactionToken* token) override; + absl::Status RenameFile(const string& src, const string& target, + TransactionToken* token) override; - Status CopyFile(const string& src, const string& target, - TransactionToken* token) override; + absl::Status CopyFile(const string& src, const string& target, + TransactionToken* token) override; }; class LocalPosixFileSystem : public PosixFileSystem { diff --git a/third_party/tsl/tsl/platform/default/resource_loader.cc b/third_party/tsl/tsl/platform/default/resource_loader.cc deleted file mode 100644 index 76087737c6da5..0000000000000 --- a/third_party/tsl/tsl/platform/default/resource_loader.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/resource_loader.h" - -#include - -#include "tsl/platform/logging.h" -#include "tsl/platform/path.h" -#include "tsl/platform/test.h" -#include "tools/cpp/runfiles/runfiles.h" - -using bazel::tools::cpp::runfiles::Runfiles; - -namespace tsl { - -std::string GetDataDependencyFilepath(const std::string& relative_path) { - std::string error; - std::unique_ptr runfiles(Runfiles::CreateForTest(&error)); - - if (runfiles == nullptr) { - LOG(FATAL) << "Unable to access the data dependencies of this test.\n" - "Make sure you are running this test using bazel."; - } - - const char* workspace_cstr = std::getenv("TEST_WORKSPACE"); - EXPECT_THAT(workspace_cstr, ::testing::NotNull()); - return runfiles->Rlocation(io::JoinPath(workspace_cstr, relative_path)); -} - -} // namespace tsl diff --git a/third_party/tsl/tsl/platform/default/rules_cc.bzl b/third_party/tsl/tsl/platform/default/rules_cc.bzl index 054460a83e19a..c649b10e29513 100644 --- a/third_party/tsl/tsl/platform/default/rules_cc.bzl +++ b/third_party/tsl/tsl/platform/default/rules_cc.bzl @@ -2,12 +2,29 @@ _cc_binary = native.cc_binary _cc_import = native.cc_import -_cc_library = native.cc_library _cc_shared_library = native.cc_shared_library _cc_test = native.cc_test cc_binary = _cc_binary cc_import = _cc_import -cc_library = _cc_library cc_shared_library = _cc_shared_library cc_test = _cc_test + +def cc_library(name, deps = None, **kwargs): + """cc_library that hides side effects of https://github.com/bazelbuild/bazel/issues/21519. + + Args: + name: name of target. + deps: deps with `xla:bazel_issue_21519` added. + **kwargs: passed to native.cc_library. + """ + + if deps == None: + deps = [] + + # Horrifying, but needed to prevent a cycle, as `bazel_issue_21519` is an + # alias of `empty`. + if name != "empty": + deps = deps + ["@xla//xla:bazel_issue_21519"] # buildifier: disable=list-append + + native.cc_library(name = name, deps = deps, **kwargs) diff --git a/third_party/tsl/tsl/platform/default/subprocess.cc b/third_party/tsl/tsl/platform/default/subprocess.cc index d750328ebf38f..c786295c08e0e 100644 --- a/third_party/tsl/tsl/platform/default/subprocess.cc +++ b/third_party/tsl/tsl/platform/default/subprocess.cc @@ -30,7 +30,11 @@ limitations under the License. #include "tsl/platform/logging.h" // Android versions older than 28 do not have posix_spawn(). -#define USE_POSIX_SPAWN !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#define USE_POSIX_SPAWN 1 +#else // defined(__ANDROID_API__) && __ANDROID_API__ < 28 +#define USE_POSIX_SPAWN 0 +#endif // !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 // 1) FYI from m3b@ about fork(): // A danger of calling fork() (as opposed to clone() or vfork()) is that if diff --git a/third_party/tsl/tsl/platform/dynamic_annotations.h b/third_party/tsl/tsl/platform/dynamic_annotations.h index 88912f7b7519b..e0c5867c9c503 100644 --- a/third_party/tsl/tsl/platform/dynamic_annotations.h +++ b/third_party/tsl/tsl/platform/dynamic_annotations.h @@ -16,17 +16,21 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DYNAMIC_ANNOTATIONS_H_ #define TENSORFLOW_TSL_PLATFORM_DYNAMIC_ANNOTATIONS_H_ -#include "tsl/platform/platform.h" - -// Include appropriate platform-dependent implementation. -#if defined(PLATFORM_GOOGLE) -#include "tsl/platform/google/dynamic_annotations.h" // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) -#include "tsl/platform/default/dynamic_annotations.h" // IWYU pragma: export +#include "absl/base/dynamic_annotations.h" + +#define TF_ANNOTATE_MEMORY_IS_INITIALIZED(ptr, bytes) \ + ANNOTATE_MEMORY_IS_INITIALIZED(ptr, bytes) + +#define TF_ANNOTATE_BENIGN_RACE(ptr, description) \ + ANNOTATE_BENIGN_RACE(ptr, description) + +// Tell MemorySanitizer to relax the handling of a given function. All "Use of +// uninitialized value" warnings from such functions will be suppressed, and +// all values loaded from memory will be considered fully initialized. +#ifdef MEMORY_SANITIZER +#define TF_ATTRIBUTE_NO_SANITIZE_MEMORY __attribute__((no_sanitize_memory)) #else -#error Define the appropriate PLATFORM_ macro for this platform +#define TF_ATTRIBUTE_NO_SANITIZE_MEMORY #endif #endif // TENSORFLOW_TSL_PLATFORM_DYNAMIC_ANNOTATIONS_H_ diff --git a/third_party/tsl/tsl/platform/env.h b/third_party/tsl/tsl/platform/env.h index fe3354c765a06..35b446a99445a 100644 --- a/third_party/tsl/tsl/platform/env.h +++ b/third_party/tsl/tsl/platform/env.h @@ -54,8 +54,12 @@ struct ThreadOptions; /// Callers may wish to provide a custom Env object to get fine grain /// control. /// -/// All Env implementations are safe for concurrent access from -/// multiple threads without any external synchronization. +/// All Env implementations of file-system modifying functionality are safe +/// for concurrent access from multiple threads without any external +/// synchronization, *however*, Envs and their underlying file systems are +/// global objects, and therefore, if any thread modifies options, the modified +/// options take effect process-wide. The SetOption functions themselves are +/// also *not* thread safe. class Env { public: Env(); diff --git a/third_party/tsl/tsl/platform/file_system.h b/third_party/tsl/tsl/platform/file_system.h index 76fab57f4b64b..8f7bd875e35bc 100644 --- a/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/tsl/tsl/platform/file_system.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -40,6 +41,7 @@ limitations under the License. namespace tsl { +class FileAcl; class RandomAccessFile; class ReadOnlyMemoryRegion; class WritableFile; @@ -531,6 +533,13 @@ class FileSystem { return errors::Unimplemented("SetOption"); } + /// \brief Set File System ACL checker. + /// + /// No checks are enforced if a FileAcl is never set. + virtual tsl::Status SetFileAcl(std::shared_ptr file_acl) { + return errors::Unimplemented("SetFileAcl"); + } + FileSystem() {} virtual ~FileSystem() = default; @@ -902,6 +911,13 @@ class FileSystemRegistry { std::vector* schemes) = 0; }; +/// \brief An abstraction for enforcing ACL checks in FileSystem. +class FileAcl { + public: + virtual absl::Status CheckAccess(std::string_view path) = 0; + virtual ~FileAcl() = default; +}; + } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ diff --git a/third_party/tsl/tsl/platform/fingerprint.h b/third_party/tsl/tsl/platform/fingerprint.h index ee5e5fc1314ef..bb961fd89c174 100644 --- a/third_party/tsl/tsl/platform/fingerprint.h +++ b/third_party/tsl/tsl/platform/fingerprint.h @@ -16,12 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FINGERPRINT_H_ #define TENSORFLOW_TSL_PLATFORM_FINGERPRINT_H_ +#include "tsl/platform/platform.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/types.h" -// The following line is used by copybara to set or unset the USE_OSS_FARMHASH -// preprocessor symbol as needed. Please do not remove. +#if TSL_IS_IN_OSS #define USE_OSS_FARMHASH +#endif // TSL_IS_IN_OSS #ifdef USE_OSS_FARMHASH #include diff --git a/third_party/tsl/tsl/platform/float8.h b/third_party/tsl/tsl/platform/float8.h deleted file mode 100644 index 247c1c12fd54b..0000000000000 --- a/third_party/tsl/tsl/platform/float8.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_TSL_PLATFORM_FLOAT8_H_ -#define TENSORFLOW_TSL_PLATFORM_FLOAT8_H_ - -// Deprecated, here only for backward-compatibility. Please use ml_dtypes.h. -#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export - -#endif // TENSORFLOW_TSL_PLATFORM_FLOAT8_H_ diff --git a/third_party/tsl/tsl/platform/gif.h b/third_party/tsl/tsl/platform/gif.h deleted file mode 100644 index 865b6f201e66f..0000000000000 --- a/third_party/tsl/tsl/platform/gif.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_GIF_H_ -#define TENSORFLOW_TSL_PLATFORM_GIF_H_ - -#include "gif_lib.h" // from @gif - -#endif // TENSORFLOW_TSL_PLATFORM_GIF_H_ diff --git a/third_party/tsl/tsl/platform/grpc_credentials.h b/third_party/tsl/tsl/platform/grpc_credentials.h new file mode 100644 index 0000000000000..5625811c0fdde --- /dev/null +++ b/third_party/tsl/tsl/platform/grpc_credentials.h @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ +#define TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ + +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" + +namespace tsl { + +// Get credentials to use in the client gRPC. +// If `verify_secure_credentials`, crash if insecure credentials are used. +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials( + bool verify_secure_credentials = true); + +// Get credentials to use in the server gRPC. +// If `verify_secure_credentials`, crash if insecure credentials are used. +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( + bool verify_secure_credentials = true); +} // namespace tsl + +#endif // TENSORFLOW_TSL_PLATFORM_GRPC_CREDENTIALS_H_ diff --git a/third_party/tsl/tsl/platform/host_info.h b/third_party/tsl/tsl/platform/host_info.h index 189f3be2934ce..630f9424525e0 100644 --- a/third_party/tsl/tsl/platform/host_info.h +++ b/third_party/tsl/tsl/platform/host_info.h @@ -16,11 +16,26 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ #define TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ +#include + #include "tsl/platform/types.h" namespace tsl { namespace port { +// Statistical data of IO operations performed by the job. +struct IOStatistics { + struct Distribution { + uint64_t count = 0; + double mean = 0.0; + double std_dev = 0.0; + }; + // Distribution of round trip IO latency in microseconds. + Distribution roundtrip_latency_usec; + // Distribution of data received by IO reads in bytes. + Distribution response_bytes; +}; + // Return the hostname of the machine on which this process is running. string Hostname(); @@ -34,6 +49,9 @@ int64_t JobUid(); // Returns the Borg task ID as an int64_t if it exists. Otherwise return -1. int64_t TaskId(); +// Retrieves the host file read statistics. +IOStatistics GetIOStatistics(); + } // namespace port } // namespace tsl diff --git a/third_party/tsl/tsl/platform/intrusive_ptr.h b/third_party/tsl/tsl/platform/intrusive_ptr.h index 3793ba3e8610e..5407f15af1aaa 100644 --- a/third_party/tsl/tsl/platform/intrusive_ptr.h +++ b/third_party/tsl/tsl/platform/intrusive_ptr.h @@ -32,7 +32,7 @@ class IntrusivePtr { // object needs to be externally managed. IntrusivePtr(T* h, bool add_ref) { reset(h, add_ref); } IntrusivePtr(const IntrusivePtr& o) { reset(o.handle_, /*add_ref=*/true); } - IntrusivePtr(IntrusivePtr&& o) { *this = std::move(o); } + IntrusivePtr(IntrusivePtr&& o) noexcept { *this = std::move(o); } IntrusivePtr() {} void reset(T* h, bool add_ref) { if (h != handle_) { @@ -45,7 +45,7 @@ class IntrusivePtr { reset(o.handle_, /*add_ref=*/true); return *this; } - IntrusivePtr& operator=(IntrusivePtr&& o) { + IntrusivePtr& operator=(IntrusivePtr&& o) noexcept { if (handle_ != o.handle_) { // Must clear o.handle_ before calling reset to capture the case where // handle_->member == o. In this case, calling handle_->Unref first would diff --git a/third_party/tsl/tsl/platform/jpeg.h b/third_party/tsl/tsl/platform/jpeg.h deleted file mode 100644 index a7b640db03943..0000000000000 --- a/third_party/tsl/tsl/platform/jpeg.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_JPEG_H_ -#define TENSORFLOW_TSL_PLATFORM_JPEG_H_ - -#include -#include -#include -#include - -extern "C" { -#include "jerror.h" // from @libjpeg_turbo // IWYU pragma: export -#include "jpeglib.h" // from @libjpeg_turbo // IWYU pragma: export -} - -#endif // TENSORFLOW_TSL_PLATFORM_JPEG_H_ diff --git a/third_party/tsl/tsl/platform/load_library.h b/third_party/tsl/tsl/platform/load_library.h index e46f85da0a7f9..5a42f2a3439fd 100644 --- a/third_party/tsl/tsl/platform/load_library.h +++ b/third_party/tsl/tsl/platform/load_library.h @@ -16,16 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ #define TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle); -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol); -string FormatLibraryFileName(const string& name, const string& version); +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle); +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol); +std::string FormatLibraryFileName(const std::string& name, + const std::string& version); } // namespace internal diff --git a/third_party/tsl/tsl/platform/logger.cc b/third_party/tsl/tsl/platform/logger.cc deleted file mode 100644 index af4c721f135f7..0000000000000 --- a/third_party/tsl/tsl/platform/logger.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/logger.h" - -#include "absl/base/call_once.h" -#include "absl/synchronization/notification.h" -#include "tsl/platform/env.h" -#include "tsl/platform/logging.h" - -namespace tsl { -namespace { - -class DefaultLogger : public Logger { - private: - void DoLogProto(google::protobuf::Any* proto) override {} - void DoFlush() override {} -}; - -} // namespace - -Logger::FactoryFunc Logger::singleton_factory_ = []() -> Logger* { - return new DefaultLogger(); -}; - -struct LoggerSingletonContainer { - // Used to kick off the construction of a new thread that will asynchronously - // construct a Logger. - absl::once_flag start_initialization_thread_flag; - - // The constructed logger, if there is one. - Logger* logger; - - // The initializing thread notifies `logger_initialized` after storing the - // constructed logger to `logger`. - absl::Notification logger_initialized; - - // The thread used to construct the Logger instance asynchronously. - std::unique_ptr initialization_thread; - - // Used to kick off the joining and destruction of `initialization_thread`. - absl::once_flag delete_initialization_thread_flag; -}; - -LoggerSingletonContainer* GetLoggerSingletonContainer() { - static LoggerSingletonContainer* container = new LoggerSingletonContainer; - return container; -} - -struct AsyncSingletonImpl { - static void InitializationThreadFn() { - LoggerSingletonContainer* container = GetLoggerSingletonContainer(); - container->logger = Logger::singleton_factory_(); - container->logger_initialized.Notify(); - } - - static void StartInitializationThread(LoggerSingletonContainer* container) { - Thread* thread = - Env::Default()->StartThread(ThreadOptions{}, "logger-init-thread", - AsyncSingletonImpl::InitializationThreadFn); - container->initialization_thread.reset(thread); - } -}; - -/*static*/ Logger* Logger::GetSingleton() { - // Call the async version to kick off the initialization thread if necessary. - (void)Logger::GetSingletonAsync(); - - // And wait for the thread to finish. - LoggerSingletonContainer* container = GetLoggerSingletonContainer(); - absl::call_once(container->delete_initialization_thread_flag, - [container]() { container->initialization_thread.reset(); }); - - return container->logger; -} - -/*static*/ Logger* Logger::GetSingletonAsync() { - LoggerSingletonContainer* container = GetLoggerSingletonContainer(); - absl::call_once(container->start_initialization_thread_flag, - AsyncSingletonImpl::StartInitializationThread, container); - - if (container->logger_initialized.HasBeenNotified()) { - // Wait for the initializing thread to finish to reclaim resources. - absl::call_once( - container->delete_initialization_thread_flag, - [container]() { container->initialization_thread.reset(); }); - return container->logger; - } else { - return nullptr; - } -} -} // namespace tsl diff --git a/third_party/tsl/tsl/platform/logger.h b/third_party/tsl/tsl/platform/logger.h deleted file mode 100644 index 369da338272b3..0000000000000 --- a/third_party/tsl/tsl/platform/logger.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_LOGGER_H_ -#define TENSORFLOW_TSL_PLATFORM_LOGGER_H_ - -#include "google/protobuf/any.pb.h" -#include "tsl/platform/protobuf.h" - -namespace tsl { - -// Abstract logging interface. Contrary to logging.h, this class describes an -// interface, not a concrete logging mechanism. This is useful when we want to -// log anything to a non-local place, e.g. a database. -class Logger { - public: - // The singleton is supposed to be used in the following steps: - // * At program start time, REGISTER_MODULE_INITIALIZER calls - // SetSingletonFactory. - // * At some point in the program execution, Singleton() is called for the - // first time, initializing the logger. - // * Succeeding calls to Singleton() return the initialized logger. - using FactoryFunc = Logger* (*)(); - - static void SetSingletonFactory(FactoryFunc factory) { - singleton_factory_ = factory; - } - - // Returns the per-process Logger instance, constructing synchronously it if - // necessary. - static Logger* GetSingleton(); - - // Like GetSingleton, except that this does not wait for the construction of - // Logger to finish before returning. - // - // Returns the constructed instance of Logger if it has been constructed, - // otherwise returns nullptr (if the logger is not ready yet). - static Logger* GetSingletonAsync(); - - virtual ~Logger() = default; - - // Logs a typed proto. - template - void LogProto(const ProtoType& proto) { - google::protobuf::Any any; - any.PackFrom(proto); - DoLogProto(&any); - } - - // Flushes any pending log. Blocks until everything is flushed. - void Flush() { DoFlush(); } - - private: - virtual void DoLogProto(google::protobuf::Any* proto) = 0; - virtual void DoFlush() = 0; - - static FactoryFunc singleton_factory_; - - friend struct AsyncSingletonImpl; -}; - -} // namespace tsl - -#endif // TENSORFLOW_TSL_PLATFORM_LOGGER_H_ diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index c25efc2f865b7..504085af8518e 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -20,16 +20,16 @@ limitations under the License. #include "ml_dtypes/include/int4.h" // from @ml_dtypes namespace tsl { -using float8_e4m3fn = ml_dtypes::float8_e4m3fn; -using float8_e4m3fnuz = ml_dtypes::float8_e4m3fnuz; -using float8_e4m3b11fnuz = ml_dtypes::float8_e4m3b11fnuz; +using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; +using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; +using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e4m3b11 = float8_e4m3b11fnuz; // Deprecated: old name for // backward-compatibility only. -using float8_e5m2 = ml_dtypes::float8_e5m2; -using float8_e5m2fnuz = ml_dtypes::float8_e5m2fnuz; +using float8_e5m2 = ::ml_dtypes::float8_e5m2; +using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; -using int4 = ml_dtypes::int4; -using uint4 = ml_dtypes::uint4; +using int4 = ::ml_dtypes::int4; +using uint4 = ::ml_dtypes::uint4; } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_ diff --git a/third_party/tsl/tsl/platform/notification.h b/third_party/tsl/tsl/platform/notification.h index b5adb1e51e850..80e5b388d2a93 100644 --- a/third_party/tsl/tsl/platform/notification.h +++ b/third_party/tsl/tsl/platform/notification.h @@ -16,17 +16,25 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_NOTIFICATION_H_ #define TENSORFLOW_TSL_PLATFORM_NOTIFICATION_H_ -#include "tsl/platform/platform.h" - -// Include appropriate platform-dependent implementations of Notification. -#if defined(PLATFORM_GOOGLE) -#include "tsl/platform/google/notification.h" // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) -#include "tsl/platform/default/notification.h" // IWYU pragma: export -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif +#include // NOLINT +#include // NOLINT +#include +#include // NOLINT + +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" + +namespace tsl { + +using absl::Notification; + +// TODO(ddunleavy): remove this method and replace uses of `tsl::Notification` +// with `absl::Notification`. +inline bool WaitForNotificationWithTimeout(Notification* n, + int64_t timeout_in_us) { + return n->WaitForNotificationWithTimeout(absl::Microseconds(timeout_in_us)); +} + +} // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_NOTIFICATION_H_ diff --git a/third_party/tsl/tsl/platform/path.cc b/third_party/tsl/tsl/platform/path.cc index b33af3eb7c311..580aacde900c1 100644 --- a/third_party/tsl/tsl/platform/path.cc +++ b/third_party/tsl/tsl/platform/path.cc @@ -407,5 +407,12 @@ bool ResolveTestPrefixes(tsl::StringPiece path, string& resolved_path) { } } +[[maybe_unused]] std::string& AppendDotExeIfWindows(std::string& path) { +#ifdef PLATFORM_WINDOWS + path.append(".exe"); +#endif // PLATFORM_WINDOWS + return path; +} + } // namespace io } // namespace tsl diff --git a/third_party/tsl/tsl/platform/path.h b/third_party/tsl/tsl/platform/path.h index 451addc60b465..f0a5b87d135c2 100644 --- a/third_party/tsl/tsl/platform/path.h +++ b/third_party/tsl/tsl/platform/path.h @@ -126,6 +126,9 @@ bool GetTestUndeclaredOutputsDir(std::string* dir); // be resolved. bool ResolveTestPrefixes(tsl::StringPiece path, std::string& resolved_path); +// Appends `.exe` if `PLATFORM_WINDOWS` is defined. +[[maybe_unused]] std::string& AppendDotExeIfWindows(std::string& path); + } // namespace io } // namespace tsl diff --git a/third_party/tsl/tsl/platform/platform.h b/third_party/tsl/tsl/platform/platform.h index 9456687727a1a..aca7a8141bc11 100644 --- a/third_party/tsl/tsl/platform/platform.h +++ b/third_party/tsl/tsl/platform/platform.h @@ -82,4 +82,6 @@ limitations under the License. #endif #endif +#define TSL_IS_IN_OSS 1 + #endif // TENSORFLOW_TSL_PLATFORM_PLATFORM_H_ diff --git a/third_party/tsl/tsl/platform/png.h b/third_party/tsl/tsl/platform/png.h deleted file mode 100644 index c66e88a5dcd6a..0000000000000 --- a/third_party/tsl/tsl/platform/png.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_PLATFORM_PNG_H_ -#define TENSORFLOW_TSL_PLATFORM_PNG_H_ - -#include "tsl/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) -#include "png.h" // from @png // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ - defined(PLATFORM_POSIX_ANDROID) || defined(IS_MOBILE_PLATFORM) -#include // IWYU pragma: export -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif - -#endif // TENSORFLOW_TSL_PLATFORM_PNG_H_ diff --git a/third_party/tsl/tsl/platform/profile_utils/BUILD b/third_party/tsl/tsl/platform/profile_utils/BUILD index 8dcb6525498fc..046ae3295f319 100644 --- a/third_party/tsl/tsl/platform/profile_utils/BUILD +++ b/third_party/tsl/tsl/platform/profile_utils/BUILD @@ -1,21 +1,17 @@ # Description: # profile_utils targets. -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", "cc_library", ) -load( - "//tsl:tsl.bzl", - "tsl_copts", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ - "//tensorflow/compiler/xla/stream_executor:__subpackages__", + default_visibility = internal_visibility([ + "@xla//xla/stream_executor:__subpackages__", "//tensorflow/core/platform:__subpackages__", "//tsl:__pkg__", "//tsl/platform/default:__pkg__", @@ -39,10 +35,10 @@ filegroup( "android_armv7a_cpu_utils_helper.cc", "clock_cycle_profiler.cc", ], - visibility = [ + visibility = internal_visibility([ "//tensorflow/core/platform:__subpackages__", "//tsl/platform:__pkg__", - ], + ]), ) cc_library( diff --git a/third_party/tsl/tsl/platform/protobuf.h b/third_party/tsl/tsl/platform/protobuf.h index fd9f9559f4876..e618016866225 100644 --- a/third_party/tsl/tsl/platform/protobuf.h +++ b/third_party/tsl/tsl/platform/protobuf.h @@ -27,22 +27,27 @@ limitations under the License. // TensorFlow code should use the ::tensorflow::protobuf namespace to // refer to all protobuf APIs. -#include "google/protobuf/descriptor.pb.h" // IWYU pragma:export -#include "google/protobuf/arena.h" // IWYU pragma:export -#include "google/protobuf/descriptor.h" // IWYU pragma:export -#include "google/protobuf/dynamic_message.h" // IWYU pragma:export -#include "google/protobuf/io/coded_stream.h" // IWYU pragma:export -#include "google/protobuf/io/tokenizer.h" // IWYU pragma:export -#include "google/protobuf/io/zero_copy_stream.h" // IWYU pragma:export -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" // IWYU pragma:export -#include "google/protobuf/map.h" // IWYU pragma:export -#include "google/protobuf/message.h" // IWYU pragma:export -#include "google/protobuf/repeated_field.h" // IWYU pragma:export -#include "google/protobuf/text_format.h" // IWYU pragma:export -#include "google/protobuf/util/field_comparator.h" // IWYU pragma:export -#include "google/protobuf/util/json_util.h" // IWYU pragma:export -#include "google/protobuf/util/message_differencer.h" // IWYU pragma:export -#include "google/protobuf/util/type_resolver_util.h" // IWYU pragma:export +#include "google/protobuf/descriptor.pb.h" // IWYU pragma: export +#include "google/protobuf/arena.h" // IWYU pragma: export +#include "google/protobuf/descriptor.h" // IWYU pragma: export +#include "google/protobuf/dynamic_message.h" // IWYU pragma: export +#include "google/protobuf/io/coded_stream.h" // IWYU pragma: export +#include "google/protobuf/io/tokenizer.h" // IWYU pragma: export +#include "google/protobuf/io/zero_copy_stream.h" // IWYU pragma: export +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" // IWYU pragma: export +#include "google/protobuf/map.h" // IWYU pragma: export +#include "google/protobuf/message.h" // IWYU pragma: export +#include "google/protobuf/repeated_field.h" // IWYU pragma: export +#include "google/protobuf/repeated_ptr_field.h" // IWYU pragma: export +#include "google/protobuf/text_format.h" // IWYU pragma: export +#include "google/protobuf/util/field_comparator.h" // IWYU pragma: export +#include "google/protobuf/util/json_util.h" // IWYU pragma: export +#include "google/protobuf/util/message_differencer.h" // IWYU pragma: export +#include "google/protobuf/util/type_resolver_util.h" // IWYU pragma: export + +#if !TSL_IS_IN_OSS +#define TENSORFLOW_PROTOBUF_USES_CORD 1 +#endif // TSL_IS_IN_OSS namespace tsl { @@ -120,6 +125,12 @@ class TStringOutputStream : public protobuf::io::ZeroCopyOutputStream { tstring* target_; }; + +std::string LegacyUnredactedDebugString(const tsl::protobuf::Message& message); +std::string LegacyUnredactedDebugString( + const tsl::protobuf::MessageLite& message); +std::string LegacyUnredactedShortDebugString( + const tsl::protobuf::Message& message); } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_PROTOBUF_H_ diff --git a/third_party/tsl/tsl/platform/protobuf_util.cc b/third_party/tsl/tsl/platform/protobuf_util.cc index 837bf112ed0ec..0bab613132797 100644 --- a/third_party/tsl/tsl/platform/protobuf_util.cc +++ b/third_party/tsl/tsl/platform/protobuf_util.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tsl/platform/protobuf.h" namespace tsl { @@ -27,4 +29,32 @@ bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized, return proto->ParseFromArray(serialized, size); } +std::string LegacyUnredactedDebugString(const tsl::protobuf::Message& message) { + std::string debug_string; + tsl::protobuf::TextFormat::Printer printer; + printer.SetExpandAny(true); + + printer.PrintToString(message, &debug_string); + return debug_string; +} + +std::string LegacyUnredactedDebugString( + const tsl::protobuf::MessageLite& message) { + return message.DebugString(); +} + +std::string LegacyUnredactedShortDebugString( + const tsl::protobuf::Message& message) { + std::string debug_string; + tsl::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.SetExpandAny(true); + + printer.PrintToString(message, &debug_string); + if (!debug_string.empty() && debug_string.back() == ' ') { + debug_string.pop_back(); + } + return debug_string; +} + } // namespace tsl diff --git a/third_party/tsl/tsl/platform/refcount.h b/third_party/tsl/tsl/platform/refcount.h index 1b3944bf3e13c..c3461c615a306 100644 --- a/third_party/tsl/tsl/platform/refcount.h +++ b/third_party/tsl/tsl/platform/refcount.h @@ -259,13 +259,13 @@ class WeakPtr { return *this; } - WeakPtr(WeakPtr&& other) { + WeakPtr(WeakPtr&& other) noexcept { data_ = std::move(other.data_); notifier_id_ = other.notifier_id_; other.notifier_id_ = 0; } - WeakPtr& operator=(WeakPtr&& other) { + WeakPtr& operator=(WeakPtr&& other) noexcept { if (this != &other) { if (data_ != nullptr && notifier_id_ != 0) { data_->RemoveNotifier(notifier_id_); diff --git a/third_party/tsl/tsl/platform/regexp.h b/third_party/tsl/tsl/platform/regexp.h index dadb53bbc0eb6..fac545c266aae 100644 --- a/third_party/tsl/tsl/platform/regexp.h +++ b/third_party/tsl/tsl/platform/regexp.h @@ -18,7 +18,10 @@ limitations under the License. #include "tsl/platform/platform.h" -// TODO(b/305283688): make a platform macro for internal windows builds -#include "re2/re2.h" +#if TSL_IS_IN_OSS +#include "re2/re2.h" // IWYU pragma: export +#else +#include "third_party/re2/re2.h" // IWYU pragma: export +#endif // TSL_IS_IN_OSS #endif // TENSORFLOW_TSL_PLATFORM_REGEXP_H_ diff --git a/third_party/tsl/tsl/platform/resource_loader.cc b/third_party/tsl/tsl/platform/resource_loader.cc new file mode 100644 index 0000000000000..616a6553e33c9 --- /dev/null +++ b/third_party/tsl/tsl/platform/resource_loader.cc @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/platform/resource_loader.h" + +#include +#include + +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" +#include "tsl/platform/test.h" + +namespace tsl { + +std::string GetDataDependencyFilepath(const std::string& relative_path) { + // TODO(ddunleavy): replace this with `TensorFlowSrcRoot()` from `test.h`. + const char* srcdir = std::getenv("TEST_SRCDIR"); + if (!srcdir) { + LOG(FATAL) << "Environment variable TEST_SRCDIR unset!"; // Crash OK + } + + const char* workspace = std::getenv("TEST_WORKSPACE"); + if (!workspace) { + LOG(FATAL) << "Environment variable TEST_WORKSPACE unset!"; // Crash OK + } + + return testing::kIsOpenSource + ? io::JoinPath(srcdir, workspace, relative_path) + : io::JoinPath(srcdir, workspace, "third_party", relative_path); +} + +} // namespace tsl diff --git a/third_party/tsl/tsl/platform/retrying_utils.cc b/third_party/tsl/tsl/platform/retrying_utils.cc index c148a5b45b233..3beb3f46110a2 100644 --- a/third_party/tsl/tsl/platform/retrying_utils.cc +++ b/third_party/tsl/tsl/platform/retrying_utils.cc @@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tsl/platform/retrying_utils.h" +#include +#include +#include +#include + +#include "absl/time/time.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" +#include "tsl/platform/logging.h" #include "tsl/platform/random.h" namespace tsl { @@ -36,6 +42,16 @@ bool IsRetriable(absl::StatusCode code) { } } +double GenerateUniformRandomNumber() { + return random::New64() * (1.0 / std::numeric_limits::max()); +} + +double GenerateUniformRandomNumberBetween(double a, double b) { + if (a == b) return a; + DCHECK_LT(a, b); + return a + GenerateUniformRandomNumber() * (b - a); +} + } // namespace Status RetryingUtils::CallWithRetries(const std::function& f, @@ -97,4 +113,40 @@ Status RetryingUtils::DeleteWithRetries( config); } +absl::Duration ComputeRetryBackoff(int current_retry_attempt, + absl::Duration min_delay, + absl::Duration max_delay) { + DCHECK_GE(current_retry_attempt, 0); + + // This function with the constants below is calculating: + // + // (0.4 * min_delay) + (random[0.6,1.0] * min_delay * 1.3^retries) + // + // Note that there is an extra truncation that occurs and is documented in + // comments below. + constexpr double kBackoffBase = 1.3; + constexpr double kBackoffRandMult = 0.4; + + // This first term does not vary with current_retry_attempt or a random + // number. It exists to ensure the final term is >= min_delay. + const absl::Duration first_term = min_delay * kBackoffRandMult; + + // This is calculating min_delay * 1.3^retries. + absl::Duration uncapped_second_term = + min_delay * std::pow(kBackoffBase, current_retry_attempt); + + // Note that first_term + uncapped_second_term can exceed max_delay here + // because of the final multiply by kBackoffBase. We fix that problem with + // the min() below. + absl::Duration second_term = + std::min(uncapped_second_term, max_delay - first_term); + + // This supplies the random jitter to ensure that retried don't cause a + // thundering herd problem. + second_term *= + GenerateUniformRandomNumberBetween(1.0 - kBackoffRandMult, 1.0); + + return std::max(first_term + second_term, min_delay); +} + } // namespace tsl diff --git a/third_party/tsl/tsl/platform/retrying_utils.h b/third_party/tsl/tsl/platform/retrying_utils.h index 950e5bb17c829..3252da2637c4d 100644 --- a/third_party/tsl/tsl/platform/retrying_utils.h +++ b/third_party/tsl/tsl/platform/retrying_utils.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef TENSORFLOW_TSL_PLATFORM_RETRYING_UTILS_H_ #define TENSORFLOW_TSL_PLATFORM_RETRYING_UTILS_H_ #include +#include "absl/time/time.h" #include "tsl/platform/status.h" namespace tsl { @@ -67,6 +67,15 @@ class RetryingUtils { const RetryConfig& config); }; +// Given the total number of retries attempted, returns a randomized duration of +// time to delay before the next retry. +// +// The average computed backoff increases with the number of retries attempted. +// See implementation for details on the calculations. +absl::Duration ComputeRetryBackoff( + int current_retry_attempt, absl::Duration min_delay = absl::Milliseconds(1), + absl::Duration max_delay = absl::Seconds(10)); + } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_RETRYING_UTILS_H_ diff --git a/third_party/tsl/tsl/platform/retrying_utils_test.cc b/third_party/tsl/tsl/platform/retrying_utils_test.cc index 1533376903639..c0b3ad7b651e1 100644 --- a/third_party/tsl/tsl/platform/retrying_utils_test.cc +++ b/third_party/tsl/tsl/platform/retrying_utils_test.cc @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tsl/platform/retrying_utils.h" +#include #include +#include "absl/time/time.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -154,5 +155,28 @@ TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) { .code()); } +TEST(RetryingUtilsTest, ComputeRetryBackoff) { + for (int i = 0; i < 30; ++i) { + EXPECT_LE(0.4 * absl::Milliseconds(1) + + 0.6 * absl::Milliseconds(1) * std::pow(1.3, i), + ComputeRetryBackoff(/*current_retry_attempt=*/i)); + EXPECT_LE( + ComputeRetryBackoff(/*current_retry_attempt=*/i), + 0.4 * absl::Milliseconds(1) + absl::Milliseconds(1) * std::pow(1.3, i)); + } +} + +TEST(RetryingUtilsTest, ComputeRetryBackoff_MinMaxDelays) { + for (int i = 0; i < 30; ++i) { + EXPECT_EQ(ComputeRetryBackoff(/*current_retry_attempt=*/i, + /*min_delay=*/absl::Seconds(10)), + absl::Seconds(10)); + EXPECT_EQ(ComputeRetryBackoff(/*current_retry_attempt=*/i, + /*min_delay=*/absl::Microseconds(1), + /*max_delay=*/absl::Microseconds(1)), + absl::Microseconds(1)); + } +} + } // namespace } // namespace tsl diff --git a/third_party/tsl/tsl/platform/status.cc b/third_party/tsl/tsl/platform/status.cc index 063ad31456881..2fb124322c109 100644 --- a/third_party/tsl/tsl/platform/status.cc +++ b/third_party/tsl/tsl/platform/status.cc @@ -166,12 +166,6 @@ const char* NullTerminatedMessage(const Status& status) { } -Status OkStatus() { return Status(); } - -Status FromAbslStatus(const absl::Status& s) { return s; } - -absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } - std::string* TfCheckOpHelperOutOfLine(const ::tsl::Status& v, const char* msg) { std::string r("Non-OK-status: "); r += msg; diff --git a/third_party/tsl/tsl/platform/status.h b/third_party/tsl/tsl/platform/status.h index e4d342af9f2d1..812ac1a0d6adf 100644 --- a/third_party/tsl/tsl/platform/status.h +++ b/third_party/tsl/tsl/platform/status.h @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/cord.h" @@ -45,6 +46,11 @@ limitations under the License. #include "tsl/platform/default/status.h" // IWYU pragma: export #endif +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + namespace tsl { // Since April 2023, tensorflow::Status is an alias to absl::Status. The first @@ -54,10 +60,10 @@ namespace tsl { // // Here is a set of correspondences: // - Use `absl::OkStatus()` instead of `tsl::OkStatus()`. -typedef absl::Status Status; +typedef absl::Status Status ABSL_DEPRECATE_AND_INLINE(); namespace errors { -typedef absl::StatusCode Code; +typedef absl::StatusCode Code ABSL_DEPRECATE_AND_INLINE(); } // namespace errors namespace error { typedef ::tensorflow::error::Code Code; @@ -99,10 +105,14 @@ namespace tsl { // // Returns an OK status, equivalent to a default constructed instance. Prefer // usage of `OkStatus()` when constructing such an OK status. -Status OkStatus(); +ABSL_DEPRECATE_AND_INLINE() inline absl::Status OkStatus() { + return absl::OkStatus(); +}; -absl::Status FromAbslStatus(const absl::Status& s); -absl::Status ToAbslStatus(const ::absl::Status& s); +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status FromAbslStatus(const absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } // Given `Status.message()` does not guarantee to be always backed by a // null-terminated string, we have this utility function when it's needed for diff --git a/third_party/tsl/tsl/platform/statusor.h b/third_party/tsl/tsl/platform/statusor.h index 0db4e733112c8..6c49be5132fc9 100644 --- a/third_party/tsl/tsl/platform/statusor.h +++ b/third_party/tsl/tsl/platform/statusor.h @@ -69,6 +69,7 @@ limitations under the License. #define TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ #include "absl/base/attributes.h" +#include "absl/base/macros.h" #include "absl/status/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -82,9 +83,15 @@ limitations under the License. #include "tsl/platform/default/statusor.h" // IWYU pragma: export #endif +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + namespace tsl { -using absl::StatusOr; +template +using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; } // namespace tsl diff --git a/third_party/tsl/tsl/platform/subprocess_test.cc b/third_party/tsl/tsl/platform/subprocess_test.cc index 4e36b0d790da2..1b1bbcb3113e1 100644 --- a/third_party/tsl/tsl/platform/subprocess_test.cc +++ b/third_party/tsl/tsl/platform/subprocess_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tsl/platform/subprocess.h" #include -#include #include +#include #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/path.h" @@ -36,48 +36,43 @@ limitations under the License. namespace tsl { namespace { -static string GetDataFilePath(const string& relative_path) { -#ifdef PLATFORM_WINDOWS - // While CreateProcess on windows is resilient to not having ".exe" suffix, - // Bazel_tools has to have the exact file path to return the resource. - return strings::StrCat(relative_path, ".exe"); -#else - return relative_path; -#endif -} string EchoProgram() { - return io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", - "test_echo"); + std::string path = + io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", "test_echo"); + return tsl::io::AppendDotExeIfWindows(path); } string EchoArgv1Program() { - return io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", - "test_echo_argv_1"); + std::string path = io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", + "test_echo_argv_1"); + return tsl::io::AppendDotExeIfWindows(path); } string NoopProgram() { - return io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", - "test_noop"); + std::string path = + io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", "test_noop"); + return tsl::io::AppendDotExeIfWindows(path); } string StdErrProgram() { - return io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", - "test_stderr"); + std::string path = io::JoinPath(testing::TslSrcRoot(), "platform", "testdata", + "test_stderr"); + return tsl::io::AppendDotExeIfWindows(path); } class SubProcessTest : public ::testing::Test {}; TEST_F(SubProcessTest, NoOutputNoComm) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(NoopProgram()).c_str(), {NoopProgram()}); + proc.SetProgram(NoopProgram().c_str(), {NoopProgram()}); EXPECT_TRUE(proc.Start()); EXPECT_TRUE(proc.Wait()); } TEST_F(SubProcessTest, NoOutput) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(NoopProgram()).c_str(), {NoopProgram()}); + proc.SetProgram(NoopProgram().c_str(), {NoopProgram()}); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); proc.SetChannelAction(CHAN_STDERR, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -93,7 +88,7 @@ TEST_F(SubProcessTest, NoOutput) { TEST_F(SubProcessTest, Stdout) { tsl::SubProcess proc; const char test_string[] = "hello_world"; - proc.SetProgram(GetDataFilePath(EchoArgv1Program()).c_str(), + proc.SetProgram(EchoArgv1Program().c_str(), {EchoArgv1Program(), test_string}); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); proc.SetChannelAction(CHAN_STDERR, ACTION_PIPE); @@ -110,7 +105,7 @@ TEST_F(SubProcessTest, Stdout) { TEST_F(SubProcessTest, StdoutIgnored) { tsl::SubProcess proc; const char test_string[] = "hello_world"; - proc.SetProgram(GetDataFilePath(EchoArgv1Program()).c_str(), + proc.SetProgram(EchoArgv1Program().c_str(), {EchoArgv1Program(), test_string}); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); proc.SetChannelAction(CHAN_STDERR, ACTION_PIPE); @@ -124,8 +119,7 @@ TEST_F(SubProcessTest, StdoutIgnored) { TEST_F(SubProcessTest, Stderr) { tsl::SubProcess proc; const char test_string[] = "muh_failure!"; - proc.SetProgram(GetDataFilePath(StdErrProgram()).c_str(), - {StdErrProgram(), test_string}); + proc.SetProgram(StdErrProgram().c_str(), {StdErrProgram(), test_string}); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); proc.SetChannelAction(CHAN_STDERR, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -141,8 +135,7 @@ TEST_F(SubProcessTest, Stderr) { TEST_F(SubProcessTest, StderrIgnored) { tsl::SubProcess proc; const char test_string[] = "muh_failure!"; - proc.SetProgram(GetDataFilePath(StdErrProgram()).c_str(), - {StdErrProgram(), test_string}); + proc.SetProgram(StdErrProgram().c_str(), {StdErrProgram(), test_string}); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); proc.SetChannelAction(CHAN_STDERR, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -154,7 +147,7 @@ TEST_F(SubProcessTest, StderrIgnored) { TEST_F(SubProcessTest, Stdin) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(EchoProgram()).c_str(), {EchoProgram()}); + proc.SetProgram(EchoProgram().c_str(), {EchoProgram()}); proc.SetChannelAction(CHAN_STDIN, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -166,7 +159,7 @@ TEST_F(SubProcessTest, Stdin) { TEST_F(SubProcessTest, StdinStdout) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(EchoProgram()).c_str(), {EchoProgram()}); + proc.SetProgram(EchoProgram().c_str(), {EchoProgram()}); proc.SetChannelAction(CHAN_STDIN, ACTION_PIPE); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -183,7 +176,7 @@ TEST_F(SubProcessTest, StdinStdout) { TEST_F(SubProcessTest, StdinChildExit) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(NoopProgram()).c_str(), {NoopProgram()}); + proc.SetProgram(NoopProgram().c_str(), {NoopProgram()}); proc.SetChannelAction(CHAN_STDIN, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -202,7 +195,7 @@ TEST_F(SubProcessTest, StdinChildExit) { TEST_F(SubProcessTest, StdinStdoutOverlap) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(EchoProgram()).c_str(), {EchoProgram()}); + proc.SetProgram(EchoProgram().c_str(), {EchoProgram()}); proc.SetChannelAction(CHAN_STDIN, ACTION_PIPE); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); EXPECT_TRUE(proc.Start()); @@ -226,7 +219,7 @@ TEST_F(SubProcessTest, StdinStdoutOverlap) { TEST_F(SubProcessTest, KillProc) { tsl::SubProcess proc; - proc.SetProgram(GetDataFilePath(EchoProgram()).c_str(), {EchoProgram()}); + proc.SetProgram(EchoProgram().c_str(), {EchoProgram()}); proc.SetChannelAction(CHAN_STDIN, ACTION_PIPE); proc.SetChannelAction(CHAN_STDOUT, ACTION_PIPE); EXPECT_TRUE(proc.Start()); diff --git a/third_party/tsl/tsl/platform/test.cc b/third_party/tsl/tsl/platform/test.cc index e8f4102e4f468..b2b2a8936c81e 100644 --- a/third_party/tsl/tsl/platform/test.cc +++ b/third_party/tsl/tsl/platform/test.cc @@ -37,7 +37,19 @@ std::string GetEnvVarOrDie(const char* env_var) { } // namespace -std::string TmpDir() { return GetEnvVarOrDie("TEST_TMPDIR"); } +std::string TmpDir() { + const char* tmp_dir = std::getenv("TEST_TMPDIR"); + if (!tmp_dir) { + tmp_dir = std::getenv("TMPDIR"); + } + if (tmp_dir) { + return tmp_dir; + } + LOG(FATAL) // Crash OK + << "Failed to find environment variables: TEST_TMPDIR, TMPDIR"; + + return tmp_dir; +} int PickUnusedPortOrDie() { return internal::PickUnusedPortOrDie(); } diff --git a/third_party/tsl/tsl/platform/test.h b/third_party/tsl/tsl/platform/test.h index 352a64785e530..313bfe5f0ea3d 100644 --- a/third_party/tsl/tsl/platform/test.h +++ b/third_party/tsl/tsl/platform/test.h @@ -87,11 +87,7 @@ int RandomSeed(); int PickUnusedPortOrDie(); // Constant which is false internally and true in open source. -#ifdef PLATFORM_GOOGLE -inline constexpr bool kIsOpenSource = false; -#else -inline constexpr bool kIsOpenSource = true; -#endif // PLATFORM_GOOGLE +inline constexpr bool kIsOpenSource = TSL_IS_IN_OSS; } // namespace testing } // namespace tsl diff --git a/third_party/tsl/tsl/platform/threadpool.cc b/third_party/tsl/tsl/platform/threadpool.cc index 218226611b13f..8b2c850331e94 100644 --- a/third_party/tsl/tsl/platform/threadpool.cc +++ b/third_party/tsl/tsl/platform/threadpool.cc @@ -28,6 +28,10 @@ limitations under the License. #include "tsl/platform/setround.h" #include "tsl/platform/tracing.h" +#ifdef DNNL_AARCH64_USE_ACL +#include "tsl/platform/cpu_info.h" +#endif // DNNL_AARCH64_USE_ACL + #ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL ABSL_FLAG(float, tensorflow_num_threads_scale_factor, 1.0, "Allows to scale all Tensorflow ThreadPools. Total number of threads " @@ -107,6 +111,14 @@ ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, bool low_latency_hint, Eigen::Allocator* allocator) { CHECK_GE(num_threads, 1); +#ifdef DNNL_AARCH64_USE_ACL + // To avoid cost of swapping in and out threads from running processes + // we do not use all available cores to parallelise TF operations. + if (num_threads == tsl::port::NumTotalCPUs() && num_threads >= 16) { + num_threads = num_threads - 1; + } +#endif // DNNL_AARCH64_USE_ACL + #ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL CHECK_GT(absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor), 0); num_threads *= absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor); diff --git a/third_party/tsl/tsl/platform/tstring.h b/third_party/tsl/tsl/platform/tstring.h index 8b8772e7d4207..0bdd9f52a76cc 100644 --- a/third_party/tsl/tsl/platform/tstring.h +++ b/third_party/tsl/tsl/platform/tstring.h @@ -128,7 +128,7 @@ class tstring { tstring& operator=(const view& tsv); // Move Assignment - tstring& operator=(tstring&& str); + tstring& operator=(tstring&& str) noexcept; // Comparison int compare(const char* str, size_t len) const; @@ -206,7 +206,7 @@ class tstring { tstring& insert(size_t pos, const tstring& str, size_t subpos, size_t sublen); tstring& insert(size_t pos, size_t n, char c); - void swap(tstring& str); + void swap(tstring& str) noexcept; void push_back(char ch); // Friends @@ -327,7 +327,7 @@ inline tstring& tstring::operator=(const tstring::view& tsv) { // Move Assignment -inline tstring& tstring::operator=(tstring&& str) { +inline tstring& tstring::operator=(tstring&& str) noexcept { TF_TString_Move(&tstr_, &str.tstr_); return *this; @@ -553,7 +553,7 @@ inline tstring& tstring::insert(size_t pos, size_t n, char c) { return *this; } -inline void tstring::swap(tstring& str) { +inline void tstring::swap(tstring& str) noexcept { // TODO(dero): Invalid for OFFSET (unimplemented). std::swap(tstr_, str.tstr_); } diff --git a/third_party/tsl/tsl/platform/windows/BUILD b/third_party/tsl/tsl/platform/windows/BUILD index 1eec5bdde62a6..3e9532dc3f538 100644 --- a/third_party/tsl/tsl/platform/windows/BUILD +++ b/third_party/tsl/tsl/platform/windows/BUILD @@ -1,10 +1,10 @@ -load("//tsl:tsl.default.bzl", "filegroup") - # Tensorflow windows-specific implementations of tensorflow/core/platform libraries. load( "//tsl:tsl.bzl", + "internal_visibility", "tsl_copts", ) +load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -12,10 +12,10 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//tensorflow/core/platform:__pkg__", "//tsl/platform:__pkg__", - ], + ]), licenses = ["notice"], ) @@ -145,7 +145,7 @@ cc_library( deps = [ ":wide_char", "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) @@ -266,5 +266,5 @@ filegroup( exports_files( srcs = ["intrinsics_port.h"], - visibility = ["//tensorflow/core/platform:__pkg__"], + visibility = internal_visibility(["//tensorflow/core/platform:__pkg__"]), ) diff --git a/third_party/tsl/tsl/platform/windows/load_library.cc b/third_party/tsl/tsl/platform/windows/load_library.cc index 0c47532dc687a..66d2d62cf6e13 100644 --- a/third_party/tsl/tsl/platform/windows/load_library.cc +++ b/third_party/tsl/tsl/platform/windows/load_library.cc @@ -28,7 +28,7 @@ limitations under the License. #include #include -#include "tsl/platform/errors.h" +#include "absl/status/status.h" #include "tsl/platform/windows/wide_char.h" #pragma comment(lib, "Shlwapi.lib") @@ -37,8 +37,8 @@ namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { - string file_name = library_filename; +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { + std::string file_name = library_filename; std::replace(file_name.begin(), file_name.end(), '/', '\\'); std::wstring ws_file_name(tsl::Utf8ToWideChar(file_name)); @@ -46,26 +46,27 @@ Status LoadDynamicLibrary(const char* library_filename, void** handle) { HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH); if (!hModule) { - return tsl::errors::NotFound(file_name + " not found"); + return absl::NotFoundError(file_name + " not found"); } *handle = hModule; - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { FARPROC found_symbol; found_symbol = GetProcAddress((HMODULE)handle, symbol_name); if (found_symbol == NULL) { - return tsl::errors::NotFound(std::string(symbol_name) + " not found"); + return absl::NotFoundError(std::string(symbol_name) + " not found"); } *symbol = (void**)found_symbol; - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; if (version.size() == 0) { filename = name + ".dll"; } else { diff --git a/third_party/tsl/tsl/platform/windows/port.cc b/third_party/tsl/tsl/platform/windows/port.cc index 9b5692650dbb5..f8e19503edb30 100644 --- a/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/tsl/tsl/platform/windows/port.cc @@ -61,6 +61,8 @@ int64_t JobUid() { return -1; } int64_t TaskId() { return -1; } +IOStatistics GetIOStatistics() { return IOStatistics(); } + int NumSchedulableCPUs() { SYSTEM_INFO system_info; GetSystemInfo(&system_info); @@ -122,7 +124,6 @@ void NUMAFree(void* ptr, size_t size) { tsl::port::Free(ptr); } int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -183,7 +184,7 @@ string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { DWORD data; DWORD data_size = sizeof(data); - #pragma comment(lib, "shlwapi.lib") // For SHGetValue(). +#pragma comment(lib, "shlwapi.lib") // For SHGetValue(). if (SUCCEEDED( SHGetValueA(HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", diff --git a/third_party/tsl/tsl/profiler/BUILD b/third_party/tsl/tsl/profiler/BUILD index af9a004e299ec..130527c3bae00 100644 --- a/third_party/tsl/tsl/profiler/BUILD +++ b/third_party/tsl/tsl/profiler/BUILD @@ -2,29 +2,16 @@ package_group( name = "friends", - includes = ["//tsl:internal"], ) package_group( name = "internal", - packages = [ - "//tensorflow/core/profiler/...", - "//tensorflow/python/eager/...", - "//tensorflow/python/profiler/...", - "//tensorflow/python/tpu/profiler/...", - "//tsl/profiler/...", - "//xla/backends/profiler/...", - ], ) package_group( name = "xla_profiler_backends", - packages = ["//xla/backends/profiler/..."], ) package_group( name = "xla_internal", - packages = [ - "//xla/...", - ], ) diff --git a/third_party/tsl/tsl/profiler/backends/cpu/BUILD b/third_party/tsl/tsl/profiler/backends/cpu/BUILD index 38b461fbb69c3..4deea64854922 100644 --- a/third_party/tsl/tsl/profiler/backends/cpu/BUILD +++ b/third_party/tsl/tsl/profiler/backends/cpu/BUILD @@ -1,8 +1,8 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("//tsl/platform:build_config_root.bzl", "if_static") load("//tsl/platform:rules_cc.bzl", "cc_library") load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -10,9 +10,10 @@ cc_library( name = "traceme_recorder", hdrs = ["traceme_recorder.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//tsl/profiler:xla_profiler_backends", + "//tensorflow/lite:__pkg__", ]), deps = [ "//tsl/platform:macros", @@ -32,7 +33,7 @@ cc_library( ], hdrs = ["traceme_recorder.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/python:__pkg__", "//tsl/platform/cloud:__pkg__", "//tsl/profiler:__pkg__", @@ -76,7 +77,7 @@ cc_library( name = "annotation_stack", hdrs = ["annotation_stack.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", ]), deps = [ @@ -95,8 +96,8 @@ cc_library( "annotation_stack.h", ], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", + visibility = internal_visibility([ + "@xla//xla:__subpackages__", "//tsl/profiler:internal", ]), deps = [ @@ -112,7 +113,7 @@ cc_library( srcs = ["host_tracer_utils.cc"], hdrs = ["host_tracer_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//tsl/profiler:xla_internal", ]), diff --git a/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc b/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc index 97b4c5daeb373..1e4b99bf79a84 100644 --- a/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc +++ b/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.cc @@ -16,33 +16,69 @@ limitations under the License. #include "tsl/profiler/backends/cpu/annotation_stack.h" #include +#include +#include +#include +#include +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tsl/platform/types.h" namespace tsl { namespace profiler { -namespace internal { - -#ifdef _WIN32 -#define DECL_DLL_EXPORT __declspec(dllexport) -#else -#define DECL_DLL_EXPORT -#endif -// DLL imported variables cannot be initialized on Windows. This file is -// included only on DLL exports. -DECL_DLL_EXPORT std::atomic g_annotation_enabled(0); - -// g_annotation_enabled implementation must be lock-free for faster execution of -// the ScopedAnnotation API. This can be commented (if compilation is failing) -// but execution might be slow (even when tracing is disabled). -static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); -} // namespace internal +// Returns the annotation data for the given generation. +static auto GetAnnotationData(const std::atomic& atomic) { + static thread_local struct { + int generation = 0; + std::vector stack; + std::string string; + } data; + int generation = atomic.load(std::memory_order_acquire); + if (generation != data.generation) { + data = {generation}; + } + return std::make_pair(&data.stack, &data.string); +}; + +void AnnotationStack::PushAnnotation(std::string_view name) { + auto [stack, string] = GetAnnotationData(generation_); + stack->push_back(string->size()); + if (!string->empty()) { + return absl::StrAppend( + string, "::", absl::string_view(name.data(), name.size()) // NOLINT + ); + } + string->assign(name); +} + +void AnnotationStack::PopAnnotation() { + auto [stack, string] = GetAnnotationData(generation_); + if (stack->empty()) { + return string->clear(); + } + string->resize(stack->back()); + stack->pop_back(); +} -/*static*/ string* AnnotationStack::ThreadAnnotationStack() { - static thread_local string annotation_stack; - return &annotation_stack; +const string& AnnotationStack::Get() { + return *std::get(GetAnnotationData(generation_)); } +void AnnotationStack::Enable(bool enable) { + int generation = generation_.load(std::memory_order_relaxed); + while (!generation_.compare_exchange_weak( + generation, enable ? generation | 1 : generation + 1 & ~1, + std::memory_order_release)) { + } +} + +// AnnotationStack::generation_ implementation must be lock-free for faster +// execution of the ScopedAnnotation API. +std::atomic AnnotationStack::generation_{0}; +static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); + } // namespace profiler } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h b/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h index 23bd5236f185b..44d1626e6a5cb 100644 --- a/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h +++ b/third_party/tsl/tsl/profiler/backends/cpu/annotation_stack.h @@ -15,80 +15,42 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_ANNOTATION_STACK_H_ #define TENSORFLOW_TSL_PROFILER_BACKENDS_CPU_ANNOTATION_STACK_H_ -#include - #include -#include +#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "tsl/platform/macros.h" #include "tsl/platform/types.h" namespace tsl { namespace profiler { -namespace internal { - -// Whether annotations are enabled. -// Static atomic so Annotation::IsEnabled can be fast and non-blocking. -TF_EXPORT extern std::atomic g_annotation_enabled; - -} // namespace internal // Backend for ScopedAnnotation. class AnnotationStack { public: - // Appends name to the annotation for the current thread and returns the - // original length of the annotation. - // Append name to the current annotation, separated by "::". - // The choice of separator "::" is based on characters not used by - // TensorFlow for its TensorOps. - static size_t PushAnnotation(absl::string_view name) { - string* annotation_stack = ThreadAnnotationStack(); - size_t old_length = annotation_stack->size(); - if (old_length != 0) { - absl::StrAppend(annotation_stack, "::", name); - } else { - *annotation_stack = string(name); - } - return old_length; - } + // Appends name to the annotations for the current thread, separated by "::". + // The choice of separator "::" is based on characters not used by TensorFlow + // for its TensorOps. + static void PushAnnotation(std::string_view name); - static size_t PushAnnotation(string&& name) { - string* annotation_stack = ThreadAnnotationStack(); - size_t old_length = annotation_stack->size(); - if (old_length != 0) { - absl::StrAppend(annotation_stack, "::", name); - } else { - *annotation_stack = std::move(name); - } - return old_length; - } + // Resizes the annotation stack for the current thread. + static void PopAnnotation(); // Returns the annotation stack for the current thread. - static const string& Get() { return *ThreadAnnotationStack(); } + static const string& Get(); - // Resizes the annotation stack for the current thread to its old length. - static void PopAnnotation(size_t old_length) { - ThreadAnnotationStack()->resize(old_length); - } - - static void Enable(bool enable) { - internal::g_annotation_enabled.store(enable, std::memory_order_release); - } + // Enables or disables the annotation stack. + static void Enable(bool enable); + // Returns whether the annotation stack is enabled. static bool IsEnabled() { - return internal::g_annotation_enabled.load(std::memory_order_acquire); + return generation_.load(std::memory_order_acquire) & 1; } private: AnnotationStack() = default; - AnnotationStack(const AnnotationStack&) = delete; - void operator=(const AnnotationStack&) = delete; - - // Returns a reference to the annotation for the current thread. - static string* ThreadAnnotationStack(); + // Enabled if odd, disabled if even. The value is incremented for every call + // to Enable() which changes the enabled state. + static std::atomic generation_; }; } // namespace profiler diff --git a/third_party/tsl/tsl/profiler/builds/BUILD b/third_party/tsl/tsl/profiler/builds/BUILD index cf7bbdb12b69c..050103eec5570 100644 --- a/third_party/tsl/tsl/profiler/builds/BUILD +++ b/third_party/tsl/tsl/profiler/builds/BUILD @@ -1,6 +1,8 @@ +load("//tsl:tsl.bzl", "internal_visibility") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tsl/profiler:internal"], + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) diff --git a/third_party/tsl/tsl/profiler/builds/build_config.bzl b/third_party/tsl/tsl/profiler/builds/build_config.bzl index 22df6844da1a0..72e2e53537794 100644 --- a/third_party/tsl/tsl/profiler/builds/build_config.bzl +++ b/third_party/tsl/tsl/profiler/builds/build_config.bzl @@ -1,11 +1,11 @@ """Provides a redirection point for platform specific implementations of Starlark utilities.""" +load("//tsl:tsl.bzl", "clean_dep") load( "//tsl/profiler/builds/oss:build_config.bzl", _tf_profiler_alias = "tf_profiler_alias", _tf_profiler_pybind_cc_library_wrapper = "tf_profiler_pybind_cc_library_wrapper", ) -load("//tsl:tsl.bzl", "clean_dep") tf_profiler_pybind_cc_library_wrapper = _tf_profiler_pybind_cc_library_wrapper tf_profiler_alias = _tf_profiler_alias diff --git a/third_party/tsl/tsl/profiler/convert/BUILD b/third_party/tsl/tsl/profiler/convert/BUILD index 0e9a6e85822a9..89068b6328047 100644 --- a/third_party/tsl/tsl/profiler/convert/BUILD +++ b/third_party/tsl/tsl/profiler/convert/BUILD @@ -1,14 +1,14 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl/platform:build_config.bzl", "tsl_cc_test") load( "//tsl/platform:rules_cc.bzl", "cc_library", ) -load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility(["//tsl/profiler:internal"]), + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) @@ -29,10 +29,10 @@ cc_library( cc_library( name = "xla_op_utils", hdrs = ["xla_op_utils.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//tsl/profiler:xla_profiler_backends", - "//tensorflow/compiler/xla/python:__pkg__", + "@xla//xla/python:__pkg__", ]), deps = ["@com_google_absl//absl/strings"], ) @@ -53,10 +53,11 @@ cc_library( srcs = ["post_process_single_host_xplane.cc"], hdrs = ["post_process_single_host_xplane.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility(["//tsl/profiler:internal"]), + visibility = internal_visibility(["//tsl/profiler:internal"]), deps = [ "//tsl/platform:types", "//tsl/profiler/protobuf:xplane_proto_cc", + "//tsl/profiler/utils:timestamp_utils", "//tsl/profiler/utils:xplane_schema", "//tsl/profiler/utils:xplane_utils", ], @@ -67,7 +68,7 @@ cc_library( srcs = ["trace_events_to_json.cc"], hdrs = ["trace_events_to_json.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", ]), deps = [ @@ -114,7 +115,7 @@ cc_library( srcs = ["xplane_to_trace_events.cc"], hdrs = ["xplane_to_trace_events.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", ]), deps = [ diff --git a/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc b/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc index fbba8c2eb840a..49e2f7dbda2ae 100644 --- a/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc +++ b/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/timestamp_utils.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" @@ -43,7 +45,7 @@ void MergeHostPlanesAndSortLines(tensorflow::profiler::XSpace* space) { } // namespace void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, - uint64 start_time_ns) { + uint64 start_time_ns, uint64 stop_time_ns) { VLOG(3) << "Post processing local profiler XSpace."; // Post processing the collected XSpace without hold profiler lock. // 1. Merge all host planes and sorts lines by name. @@ -51,7 +53,10 @@ void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, // 2. Normalize all timestamps by shifting timeline to profiling start time. // NOTE: this have to be done before sorting XSpace due to timestamp overflow. NormalizeTimestamps(space, start_time_ns); - // 3. Sort each plane of the XSpace + // 3. Add information regarding profiling start_time_ns_ and stop_time_ns_ to + // taskEnv. + SetSessionTimestamps(start_time_ns, stop_time_ns, *space); + // 4. Sort each plane of the XSpace SortXSpace(space); } diff --git a/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h b/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h index d0183f2dba188..0b413931e989f 100644 --- a/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h +++ b/third_party/tsl/tsl/profiler/convert/post_process_single_host_xplane.h @@ -23,7 +23,7 @@ namespace profiler { // Post process XSpaces collected locally from multiple profilers. void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space, - uint64 start_time_ns); + uint64 start_time_ns, uint64 stop_time_ns); } // namespace profiler } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/tsl/tsl/profiler/lib/BUILD index 249db4b173582..88462a4963215 100644 --- a/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/tsl/tsl/profiler/lib/BUILD @@ -1,8 +1,12 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl/platform:build_config_root.bzl", "if_static") +load("//tsl:tsl.bzl", "if_not_android", "internal_visibility", "nvtx_headers") load("//tsl:tsl.default.bzl", "filegroup") -load("//tsl:tsl.bzl", "if_not_android", "set_external_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//tsl/platform:build_config_root.bzl", "if_static") +load("//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", @@ -27,6 +31,8 @@ cc_library( filegroup( name = "mobile_srcs_no_runtime", srcs = [ + "nvtx_utils.h", + "nvtx_utils_stub.cc", #Include the stub implementation here since CUDA isn't relevant to Android. "scoped_annotation.h", "scoped_memory_debug_annotation.cc", "scoped_memory_debug_annotation.h", @@ -49,7 +55,7 @@ cc_library( name = "profiler_controller", srcs = ["profiler_controller.cc"], hdrs = ["profiler_controller.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", ]), deps = [ @@ -64,10 +70,10 @@ cc_library( cc_library( name = "profiler_factory", hdrs = ["profiler_factory.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//tsl/profiler:xla_profiler_backends", - "//tensorflow/compiler/xla/python:__pkg__", + "@xla//xla/python:__pkg__", "//learning/brain/tfrc/executor/stream_executor:__pkg__", ]), deps = [ @@ -85,7 +91,7 @@ cc_library( "profiler_factory.h", ], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//learning/brain/tfrc/executor/stream_executor:__pkg__", ]), @@ -119,7 +125,7 @@ cc_library( name = "profiler_interface", hdrs = ["profiler_interface.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl:internal", "//tsl/profiler:internal", "//tsl/profiler:xla_profiler_backends", @@ -135,7 +141,7 @@ cc_library( srcs = ["profiler_lock.cc"], hdrs = ["profiler_lock.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tsl/profiler:internal", "//tsl/profiler:xla_internal", ]), @@ -143,7 +149,8 @@ cc_library( "//tsl/platform:errors", "//tsl/platform:macros", "//tsl/platform:statusor", - "//tsl/util:env_var", + "@com_google_absl//absl/strings:string_view", + "@xla//xla/tsl/util:env_var", ], ) @@ -152,6 +159,7 @@ tsl_cc_test( srcs = ["profiler_lock_test.cc"], deps = [ ":profiler_lock", + "//tsl/platform:statusor", "//tsl/platform:test", "//tsl/platform:test_main", ], @@ -160,7 +168,7 @@ tsl_cc_test( cc_library( name = "profiler_session", hdrs = ["profiler_session.h"], - visibility = set_external_visibility(["//tsl:internal"]), + visibility = internal_visibility(["//tsl:internal"]), deps = [ "//tsl/platform", "//tsl/platform:errors", @@ -185,7 +193,7 @@ cc_library( "profiler_session.h", ], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/python:__pkg__", "//tsl/profiler:internal", ]), @@ -231,6 +239,7 @@ tsl_cc_test( ":traceme_encode", "//tsl/platform", "//tsl/platform:test", + "//tsl/platform:test_benchmark", "//tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -240,7 +249,7 @@ tsl_cc_test( tf_profiler_pybind_cc_library_wrapper( name = "traceme_for_pybind", actual = ":traceme", - visibility = set_external_visibility(["//tsl/profiler:xla_internal"]), + visibility = internal_visibility(["//tsl/profiler:xla_internal"]), ) cc_library( @@ -262,43 +271,28 @@ cc_library( cc_library( name = "nvtx_utils", + srcs = if_cuda_is_configured( + ["nvtx_utils.cc"], + ["nvtx_utils_stub.cc"], + ), hdrs = ["nvtx_utils.h"], visibility = ["//visibility:public"], - deps = [ - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", - "@com_google_absl//absl/strings", - ] + if_not_android([ - "//tsl/profiler/backends/cpu:annotation_stack", - ]), + deps = if_cuda_is_configured(nvtx_headers()), ) cc_library( name = "scoped_annotation", - hdrs = ["scoped_annotation.h"], + hdrs = [ + "nvtx_utils.h", + "scoped_annotation.h", + ], visibility = ["//visibility:public"], deps = [ + ":nvtx_utils", "//tsl/platform:macros", "//tsl/platform:types", "@com_google_absl//absl/strings", ] + if_not_android([ - ":nvtx_utils", - "//tsl/profiler/backends/cpu:annotation_stack", - ]), -) - -cc_library( - name = "scoped_annotation_stack", - hdrs = ["scoped_annotation_stack.h"], - visibility = set_external_visibility([ - "//tsl/profiler:internal", - "//tsl/profiler:xla_internal", - ]), - deps = [ - "@com_google_absl//absl/strings", - ] + if_not_android([ - ":nvtx_utils", "//tsl/profiler/backends/cpu:annotation_stack", ]), ) @@ -309,7 +303,6 @@ tsl_cc_test( srcs = ["scoped_annotation_test.cc"], deps = [ ":scoped_annotation", - ":scoped_annotation_stack", "//tsl/platform:test", "//tsl/platform:test_benchmark", "//tsl/platform:test_main", @@ -336,8 +329,8 @@ cc_library( name = "profiler_collection", srcs = ["profiler_collection.cc"], hdrs = ["profiler_collection.h"], - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/backends/profiler/plugin:__pkg__", + visibility = internal_visibility([ + "@xla//xla/backends/profiler/plugin:__pkg__", "//learning/brain/tfrc/executor/stream_executor:__pkg__", ]), deps = [ diff --git a/third_party/tsl/tsl/profiler/lib/connected_traceme.h b/third_party/tsl/tsl/profiler/lib/connected_traceme.h index e6e5bfed1493c..a4b01ae517f65 100644 --- a/third_party/tsl/tsl/profiler/lib/connected_traceme.h +++ b/third_party/tsl/tsl/profiler/lib/connected_traceme.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ #define TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ +#include #include #include @@ -79,7 +80,7 @@ class TraceMeProducer : public TraceMe { template explicit TraceMeProducer(NameT&& name, ContextType context_type = ContextType::kGeneric, - absl::optional context_id = absl::nullopt, + std::optional context_id = std::nullopt, int level = 2) : TraceMe(std::forward(name), level), context_id_(context_id.has_value() ? context_id.value() diff --git a/third_party/tsl/tsl/profiler/lib/context_types.cc b/third_party/tsl/tsl/profiler/lib/context_types.cc index 9379885c4de76..371631c10ba88 100644 --- a/third_party/tsl/tsl/profiler/lib/context_types.cc +++ b/third_party/tsl/tsl/profiler/lib/context_types.cc @@ -46,6 +46,8 @@ const char* GetContextTypeString(ContextType context_type) { return "tpu_launch"; case ContextType::kPathwaysExecutor: return "pathways_exec"; + case ContextType::kPjrtLibraryCall: + return "pjrt_library_call"; } } diff --git a/third_party/tsl/tsl/profiler/lib/context_types.h b/third_party/tsl/tsl/profiler/lib/context_types.h index 6f65454354a1d..621f35462fdae 100644 --- a/third_party/tsl/tsl/profiler/lib/context_types.h +++ b/third_party/tsl/tsl/profiler/lib/context_types.h @@ -36,6 +36,7 @@ enum class ContextType : int { kTpuStream, kTpuLaunch, kPathwaysExecutor, + kPjrtLibraryCall, kLastContextType = ContextType::kTpuLaunch, }; diff --git a/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc b/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc new file mode 100644 index 0000000000000..b122c6e12dfc1 --- /dev/null +++ b/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -0,0 +1,84 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/profiler/lib/nvtx_utils.h" + +#include +#include +#include +#include +#include + +#include "nvtx3/nvToolsExt.h" +#include "nvtx3/nvToolsExtPayload.h" + +namespace tsl::profiler { +static_assert(std::is_pointer_v); +static_assert(std::is_pointer_v); + +ProfilerDomainHandle DefaultProfilerDomain() { + static ProfilerDomainHandle domain = + reinterpret_cast(nvtxDomainCreateA("TSL")); + return domain; +} + +void RangePop(ProfilerDomainHandle domain) { + nvtxDomainRangePop(reinterpret_cast(domain)); +} + +void RangePush(ProfilerDomainHandle domain, const char* ascii) { + nvtxEventAttributes_t attrs{}; + attrs.version = NVTX_VERSION; + attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + attrs.messageType = NVTX_MESSAGE_TYPE_ASCII; + attrs.message.ascii = ascii; + nvtxDomainRangePushEx(reinterpret_cast(domain), &attrs); +} + +namespace detail { +void RangePush(ProfilerDomainHandle domain, StringHandle title, + uint64_t schema_id, const void* payload, size_t payload_size) { + nvtxEventAttributes_t attrs{}; + attrs.version = NVTX_VERSION; + attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; + attrs.message.registered = reinterpret_cast(title); + NVTX_PAYLOAD_EVTATTR_SET(attrs, schema_id, payload, payload_size); + nvtxDomainRangePushEx(reinterpret_cast(domain), &attrs); +} +} // namespace detail + +uint64_t RegisterSchema(ProfilerDomainHandle domain, const void* schemaAttr) { + return nvtxPayloadSchemaRegister( + reinterpret_cast(domain), + static_cast(schemaAttr)); +} + +StringHandle RegisterString(ProfilerDomainHandle domain, + const std::string& str) { + const auto impl = [domain](const char* c_str) { + return reinterpret_cast(nvtxDomainRegisterStringA( + reinterpret_cast(domain), c_str)); + }; + constexpr auto max_length = 65330; + if (str.size() <= max_length) { + return impl(str.c_str()); + } + // nvbugs 4340868 + std::string_view suffix{"\n[truncated]\n"}; + std::string buffer(str.data(), max_length - suffix.size()); + buffer.append(suffix); + return impl(buffer.c_str()); +} +} // namespace tsl::profiler diff --git a/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index 416d829378455..0727072d06390 100644 --- a/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -16,68 +16,54 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_NVTX_UTILS_H_ #define TENSORFLOW_TSL_PROFILER_LIB_NVTX_UTILS_H_ -#include +#include -#include "absl/strings/string_view.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include +#include -#if GOOGLE_CUDA -#include "nvtx3/nvToolsExt.h" -#endif +namespace tsl::profiler { +struct String; +// Opaque handle to a string that has been pre-registered with the profiler/NVTX +// implementation +using StringHandle = String*; -namespace tsl { -namespace profiler { -namespace nvtx { +struct ProfilerDomain; +// Opaque handle to a domain in the profiler/NVTX implementation +using ProfilerDomainHandle = ProfilerDomain*; -// Some typedef to help build without NVTX. -#if !GOOGLE_CUDA -typedef void* nvtxEventAttributes_t; -typedef void* nvtxDomainHandle_t; -#endif +// Get the "TSL" domain if NVTX profiling is enabled, otherwise null +ProfilerDomainHandle DefaultProfilerDomain(); -// A helper function that return the domains to use if NVTX profiling -// is enabled. -inline std::optional GetNVTXDomain() { -#if GOOGLE_CUDA - static nvtxDomainHandle_t domain; - static bool is_enabled = [] { - bool _is_enabled = false; - // Force NVTX marker if a tool triggered the profiler. - domain = nvtxDomainCreateA("TSL"); - if (domain) { - _is_enabled = true; - } - VLOG(1) << "Is NVTX marker enabled? " << _is_enabled; - return _is_enabled; - }(); - if (is_enabled) return domain; -#endif - return {}; +// Register a string with the profiler/NVTX implementation for faster use +StringHandle RegisterString(ProfilerDomainHandle, const std::string&); + +// End a range that was created on this thread by RangePush +void RangePop(ProfilerDomainHandle); + +// Older/simpler version; NVTX implementation copies a C-style string each time +void RangePush(ProfilerDomainHandle domain, const char*); +inline void RangePush(ProfilerDomainHandle domain, const std::string& str) { + RangePush(domain, str.c_str()); } -// A helper function to decide whether to enable CUDA NVTX profiling ranges. -inline bool RangesEnabled() { -#if GOOGLE_CUDA - return GetNVTXDomain().has_value(); -#else - return false; -#endif +namespace detail { +void RangePush(ProfilerDomainHandle domain, StringHandle title, + uint64_t schema_id, const void* payload, size_t payload_size); } -// Note: The memory backing msg must persist until the result of this function -// has been consumed by an NVTX API. -inline void MakeAttributes(const char* msg, nvtxEventAttributes_t* result) { - *result = {0}; -#if GOOGLE_CUDA - result->version = NVTX_VERSION; - result->size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - result->messageType = NVTX_MESSAGE_TYPE_ASCII; - result->message.ascii = msg; -#endif +// More powerful version: pass a registered string instead of a C-style +// string, and attach a generic payload. The Annotation type must implement a +// method called NvtxSchemaId() that allows the NVTX backend to interpret the +// payload. +template +void RangePush(ProfilerDomainHandle domain, StringHandle title, + const Annotation& annotation) { + return detail::RangePush(domain, title, annotation.NvtxSchemaId(), + &annotation, sizeof(Annotation)); } -} // namespace nvtx -} // namespace profiler -} // namespace tsl +// Register the schema of a custom payload type, for use with the more powerful +// version of RangePush +uint64_t RegisterSchema(ProfilerDomainHandle domain, const void* schemaAttr); +} // namespace tsl::profiler #endif // TENSORFLOW_TSL_PROFILER_LIB_NVTX_UTILS_H_ diff --git a/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc b/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc new file mode 100644 index 0000000000000..c887af77ec8b1 --- /dev/null +++ b/third_party/tsl/tsl/profiler/lib/nvtx_utils_stub.cc @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/profiler/lib/nvtx_utils.h" + +namespace tsl::profiler { +ProfilerDomainHandle DefaultProfilerDomain() { return {}; } +void RangePop(ProfilerDomainHandle) {} +void RangePush(ProfilerDomainHandle, const char*) {} +namespace detail { +void RangePush(ProfilerDomainHandle, StringHandle, uint64_t, const void*, + size_t) {} +} // namespace detail +uint64_t RegisterSchema(ProfilerDomainHandle, const void*) { return 0; } +StringHandle RegisterString(ProfilerDomainHandle, const std::string&) { + return {}; +} +} // namespace tsl::profiler diff --git a/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc b/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc index ef8c52e261f07..55bb71b107c36 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc +++ b/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/third_party/tsl/tsl/profiler/lib/profiler_lock.cc b/third_party/tsl/tsl/profiler/lib/profiler_lock.cc index e99db5ae36696..325713117a333 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_lock.cc +++ b/third_party/tsl/tsl/profiler/lib/profiler_lock.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" namespace tsl { namespace profiler { diff --git a/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/tsl/tsl/profiler/lib/profiler_lock.h index aead8353b2304..26d4b2a4471c4 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_lock.h +++ b/third_party/tsl/tsl/profiler/lib/profiler_lock.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tsl/platform/statusor.h" namespace tsl { diff --git a/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc b/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc index a483582034039..c22e65d635c73 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc +++ b/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/tsl/tsl/profiler/lib/profiler_session.cc b/third_party/tsl/tsl/profiler/lib/profiler_session.cc index f5ba7e1ee9281..39775b957a109 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_session.cc +++ b/third_party/tsl/tsl/profiler/lib/profiler_session.cc @@ -22,9 +22,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/platform.h" #include "tsl/platform/status.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -70,6 +68,7 @@ Status ProfilerSession::CollectDataInternal(XSpace* space) { LOG(INFO) << "Profiler session collecting data."; if (profilers_ != nullptr) { profilers_->Stop().IgnoreError(); + stop_time_ns_ = profiler::GetCurrentTimeNanos(); profilers_->CollectData(space).IgnoreError(); profilers_.reset(); // data has been collected. } @@ -83,7 +82,7 @@ Status ProfilerSession::CollectData(XSpace* space) { #if !defined(IS_MOBILE_PLATFORM) space->add_hostnames(port::Hostname()); TF_RETURN_IF_ERROR(CollectDataInternal(space)); - profiler::PostProcessSingleHostXSpace(space, start_time_ns_); + profiler::PostProcessSingleHostXSpace(space, start_time_ns_, stop_time_ns_); #endif return OkStatus(); } diff --git a/third_party/tsl/tsl/profiler/lib/profiler_session.h b/third_party/tsl/tsl/profiler/lib/profiler_session.h index 424e5c87d0b4e..e6fb67218ac5f 100644 --- a/third_party/tsl/tsl/profiler/lib/profiler_session.h +++ b/third_party/tsl/tsl/profiler/lib/profiler_session.h @@ -83,6 +83,7 @@ class ProfilerSession { std::unique_ptr profilers_ TF_GUARDED_BY(mutex_); uint64 start_time_ns_; + uint64 stop_time_ns_; tensorflow::ProfileOptions options_; #endif tsl::Status status_ TF_GUARDED_BY(mutex_); diff --git a/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index 643d704542860..c41a2a39a8dc3 100644 --- a/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -18,128 +18,83 @@ limitations under the License. #include #include -#include #include #include #include -#include "absl/strings/string_view.h" #include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "tsl/profiler/lib/nvtx_utils.h" #if !defined(IS_MOBILE_PLATFORM) #include "tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/profiler/lib/nvtx_utils.h" #endif -namespace tsl { -namespace profiler { +namespace tsl::profiler { -// Adds an annotation to all activities for the duration of the instance -// lifetime through the currently registered TraceCollector. -// -// Usage: { -// ScopedAnnotation annotation("my kernels"); -// Kernel1<<>>; -// LaunchKernel2(); // Launches a CUDA kernel. -// } -// This will add 'my kernels' to both kernels in the profiler UI -template -class ScopedAnnotationT { - public: - explicit ScopedAnnotationT(absl::string_view name) { -#if !defined(IS_MOBILE_PLATFORM) +// Adds an annotation to all activities through the currently registered +// TraceCollector until PopAnnotation() is called. +template +void PushAnnotation(const T& generator) { #if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(name); - } -#endif + if (auto domain = DefaultProfilerDomain(); + TF_PREDICT_FALSE(domain != nullptr)) { + RangePush(domain, generator()); + return; } +#endif - explicit ScopedAnnotationT(const char* name) - : ScopedAnnotationT(absl::string_view(name)) {} - - explicit ScopedAnnotationT(const string& name) { #if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(name); - } -#endif + if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { + AnnotationStack::PushAnnotation(static_cast(generator())); } +#endif +} + +inline void PushAnnotation(const char* name) { + PushAnnotation([&] { return name; }); +} +inline void PushAnnotation(const std::string& name) { + PushAnnotation([&] { return name; }); +} + +inline void PopAnnotation() { + // TODO(b/137971921): without this memory fence, two presubmit tests will + // fail probably due to compiler in that presubmit config. + std::atomic_thread_fence(std::memory_order_acquire); - explicit ScopedAnnotationT(string&& name) { -#if !defined(IS_MOBILE_PLATFORM) #if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - old_length_ = AnnotationStack::PushAnnotation(std::move(name)); - } -#endif + if (auto domain = DefaultProfilerDomain(); + TF_PREDICT_FALSE(domain != nullptr)) { + RangePop(domain); + return; } +#endif - template - explicit ScopedAnnotationT(NameGeneratorT name_generator) { #if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - auto name = name_generator(); - old_length_ = AnnotationStack::PushAnnotation(name); - } + if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { + AnnotationStack::PopAnnotation(); + } #endif +} + +// Adds an annotation to all activities for the duration of the instance +// lifetime through the currently registered TraceCollector. +// +// Usage: { +// ScopedAnnotation annotation("my kernels"); +// Kernel1<<>>; +// LaunchKernel2(); // Launches a CUDA kernel. +// } +// This will add 'my kernels' to both kernels in the profiler UI +class ScopedAnnotation { + public: + template + explicit ScopedAnnotation(T&& annotation) { + PushAnnotation(std::forward(annotation)); } // Pops the name passed in the constructor from the current annotation. - ~ScopedAnnotationT() { - // TODO(b/137971921): without this memory fence, two presubmit tests will - // fail probably due to compiler in that presubmit config. - std::atomic_thread_fence(std::memory_order_acquire); -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - ::nvtxDomainRangePop(domain.value()); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(old_length_ != kInvalidLength)) { - AnnotationStack::PopAnnotation(old_length_); - } -#endif - } + ~ScopedAnnotation() { PopAnnotation(); } static bool IsEnabled() { #if !defined(IS_MOBILE_PLATFORM) @@ -150,19 +105,10 @@ class ScopedAnnotationT { } private: - // signals that annotation is disabled at the constructor. - static constexpr size_t kInvalidLength = static_cast(-1); - - ScopedAnnotationT(const ScopedAnnotationT&) = delete; - void operator=(const ScopedAnnotationT&) = delete; - - size_t old_length_ = kInvalidLength; + ScopedAnnotation(const ScopedAnnotation&) = delete; + ScopedAnnotation& operator=(const ScopedAnnotation&) = delete; }; -using ScopedAnnotation = ScopedAnnotationT; -using ScopedAnnotationAlways = ScopedAnnotationT; - -} // namespace profiler -} // namespace tsl +} // namespace tsl::profiler #endif // TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_H_ diff --git a/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h b/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h deleted file mode 100644 index f4e538f127c9b..0000000000000 --- a/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ -#define TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ - -#include - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/profiler/lib/nvtx_utils.h" -#endif - -namespace tsl { -namespace profiler { - -// ScopedAnnotation for clients that can't use RAII for managing the lifetime -// of annotations. It provides an API similar to the `TraceMe::ActivityStart` -// and `TraceMe::ActivityEnd`. -// -// Usage: -// int64_t id = ScopedAnnotationStack::ActivityStart("foo"); -// foo(); -// ScopedAnnotationStack::ActivityEnd(id); -// -// Prefer a regular `ScopedAnnotation`. The name of this class is a misnomer, -// because it doesn't do any automatic destruction at the scope end, it's just -// for the sake of consistency. -class ScopedAnnotationStack { - static constexpr size_t kInvalidActivity = static_cast(-1); - - public: - static bool IsEnabled() { return AnnotationStack::IsEnabled(); } - - static int64_t ActivityStart(std::string name) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - return AnnotationStack::PushAnnotation(std::move(name)); - } -#endif - return kInvalidActivity; - } - - static int64_t ActivityStart(std::string_view name) { - return ActivityStart(std::string(name)); - } - - static int64_t ActivityStart(const char* name) { - return ActivityStart(std::string_view(name)); - } - - template - static int64_t ActivityStart(NameGeneratorT name_generator) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - return AnnotationStack::PushAnnotation(name_generator()); - } -#endif - return kInvalidActivity; - } - - static void ActivityEnd(int64_t activity_id) { -#if !defined(IS_MOBILE_PLATFORM) -#if GOOGLE_CUDA - std::optional domain = - tsl::profiler::nvtx::GetNVTXDomain(); - if (TF_PREDICT_FALSE(domain.has_value())) { - ::nvtxDomainRangePop(domain.value()); - } else // NOLINT -#endif - if (TF_PREDICT_FALSE(activity_id != kInvalidActivity)) { - AnnotationStack::PopAnnotation(activity_id); - } -#endif - } -}; - -} // namespace profiler -} // namespace tsl - -#endif // TENSORFLOW_TSL_PROFILER_LIB_SCOPED_ANNOTATION_STACK_H_ diff --git a/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc b/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc index dab3e91f2ed4c..0ae9d3276375f 100644 --- a/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc +++ b/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" #include "tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/profiler/lib/scoped_annotation_stack.h" namespace tsl { namespace profiler { @@ -50,11 +49,11 @@ TEST(ScopedAnnotation, Simple) { { AnnotationStack::Enable(true); - int64_t id0 = ScopedAnnotationStack::ActivityStart("foo"); - int64_t id1 = ScopedAnnotationStack::ActivityStart("bar"); + PushAnnotation("foo"); + PushAnnotation("bar"); EXPECT_EQ(AnnotationStack::Get(), "foo::bar"); // enabled - ScopedAnnotationStack::ActivityEnd(id1); - ScopedAnnotationStack::ActivityEnd(id0); + PopAnnotation(); + PopAnnotation(); AnnotationStack::Enable(false); } diff --git a/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc b/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc index ea64d28f9e2b4..4827bee4d820b 100644 --- a/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc +++ b/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tsl/platform/platform.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace tsl { namespace profiler { @@ -81,5 +82,25 @@ TEST(TraceMeEncodeTest, NoNameTest) { } } // namespace + +void BM_TraceMeEncode(::testing::benchmark::State& state) { + for (auto s : state) { + TraceMeEncode( + "MyTestEvent", + {{"Lorem ipsum dolor sit amet", 1}, + {"consectetur adipiscing elit", 2}, + {"sed do eiusmod tempor incididunt", 3.52}, + {"ut labore et dolore magna aliqua", "Ut enim ad minim veniam"}, + {"quis nostrud exercitation ullamco", "laboris nisi ut aliquip ex"}, + {"ea commodo consequat.", 11111.1111}, + {"Duis aute", 1234567890}, + {"irure dolor in", " reprehenderit in voluptate"}, + {"velit esse cillum dolore", "eu fugiat nulla pariatur."}, + {"Excepteur sint", "occaecat cupidatat non proident, sunt in"}, + {"culpa qui officia", "deserunt mollit anim id est laborum."}}); + } +} +BENCHMARK(BM_TraceMeEncode); + } // namespace profiler } // namespace tsl diff --git a/third_party/tsl/tsl/profiler/protobuf/BUILD b/third_party/tsl/tsl/profiler/protobuf/BUILD index 9994a26e7c82d..a011f2ff17fd7 100644 --- a/third_party/tsl/tsl/profiler/protobuf/BUILD +++ b/third_party/tsl/tsl/profiler/protobuf/BUILD @@ -1,6 +1,6 @@ -# Placeholder: load py_proto_library # copybara:uncomment(oss-unused) load("//net/grpc/go/build_defs:go_grpc_library.bzl", "go_grpc_library") -load("//tsl:tsl.bzl", "set_external_visibility") +# Placeholder: load py_proto_library +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl/platform:build_config.bzl", "tf_proto_library") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -17,7 +17,7 @@ tf_proto_library( srcs = ["xplane.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), ) tf_proto_library( @@ -71,7 +71,7 @@ tf_proto_library( name = "trace_events_proto", srcs = ["trace_events.proto"], cc_api_version = 2, - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), ) # copybara:uncomment_begin(google-only) @@ -86,7 +86,7 @@ tf_proto_library( # This is needed because of how tf_android_core_proto_sources parses proto paths. exports_files( srcs = ["xplane.proto"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tsl:__pkg__", ]), @@ -96,8 +96,8 @@ tf_proto_library( name = "profile_proto", srcs = ["profile.proto"], cc_api_version = 2, - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/python:__pkg__", + visibility = internal_visibility([ + "@xla//xla/python:__pkg__", "//tsl/profiler:internal", ]), ) @@ -119,7 +119,7 @@ tf_proto_library( # py_proto_library( # name = "xplane_py_pb2", # api_version = 2, -# visibility = set_external_visibility([":friends"]), +# visibility = internal_visibility([":friends"]), # deps = [":xplane_proto"], # ) # copybara:uncomment_end @@ -129,5 +129,5 @@ tf_proto_library( srcs = ["profiled_instructions.proto"], cc_api_version = 2, make_default_target_header_only = True, - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), ) diff --git a/third_party/tsl/tsl/profiler/rpc/BUILD b/third_party/tsl/tsl/profiler/rpc/BUILD index c2c9c8cfeaa7f..d03355da35325 100644 --- a/third_party/tsl/tsl/profiler/rpc/BUILD +++ b/third_party/tsl/tsl/profiler/rpc/BUILD @@ -1,15 +1,15 @@ +load("//tsl:tsl.bzl", "internal_visibility") +load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", "tf_profiler_pybind_cc_library_wrapper", ) -load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("//tsl:tsl.bzl", "set_external_visibility") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility(["//tsl/profiler:internal"]), + default_visibility = internal_visibility(["//tsl/profiler:internal"]), licenses = ["notice"], ) @@ -19,7 +19,7 @@ cc_library( srcs = ["profiler_service_impl.cc"], hdrs = ["profiler_service_impl.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/data/service:__pkg__", "//tensorflow/core/distributed_runtime/rpc:__pkg__", "//tensorflow/core/profiler/rpc:__pkg__", @@ -53,7 +53,7 @@ cc_library( tf_profiler_pybind_cc_library_wrapper( name = "profiler_server_for_pybind", actual = ":profiler_server_impl", - visibility = ["//tensorflow/python/profiler/internal:__pkg__"], + visibility = internal_visibility(["//tensorflow/python/profiler/internal:__pkg__"]), ) cc_library( @@ -61,8 +61,8 @@ cc_library( srcs = ["profiler_server.cc"], hdrs = ["profiler_server.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", + visibility = internal_visibility([ + "@xla//xla:__subpackages__", "//tensorflow/core/profiler/rpc:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", diff --git a/third_party/tsl/tsl/profiler/rpc/client/BUILD b/third_party/tsl/tsl/profiler/rpc/client/BUILD index fb3a39cd1ffaf..d1fcacb0e6798 100644 --- a/third_party/tsl/tsl/profiler/rpc/client/BUILD +++ b/third_party/tsl/tsl/profiler/rpc/client/BUILD @@ -1,11 +1,11 @@ -load("//tsl/platform:rules_cc.bzl", "cc_library") -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load( "//tsl/platform:build_config.bzl", "tf_protos_profiler_service", "tsl_cc_test", ) +load("//tsl/platform:rules_cc.bzl", "cc_library") load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", @@ -14,7 +14,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl/profiler:internal", ]), licenses = ["notice"], @@ -25,8 +25,8 @@ cc_library( srcs = ["capture_profile.cc"], hdrs = ["capture_profile.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/python:__pkg__", + visibility = internal_visibility([ + "@xla//xla/python:__pkg__", "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", ]), @@ -56,9 +56,9 @@ cc_library( srcs = ["save_profile.cc"], hdrs = ["save_profile.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/profiler/rpc/client:__pkg__", - "//tensorflow/compiler/xla/python:__pkg__", + "@xla//xla/python:__pkg__", "//tsl/profiler:internal", "//tsl/profiler/rpc:__pkg__", ]), @@ -81,7 +81,7 @@ cc_library( tf_profiler_pybind_cc_library_wrapper( name = "profiler_client_for_pybind", actual = ":profiler_client", - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", ]), @@ -90,8 +90,8 @@ tf_profiler_pybind_cc_library_wrapper( cc_library( name = "profiler_client", hdrs = ["profiler_client.h"], - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", + visibility = internal_visibility([ + "@xla//xla:__subpackages__", "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", ]), @@ -113,8 +113,8 @@ cc_library( "profiler_client.h", ], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/python:__pkg__", + visibility = internal_visibility([ + "@xla//xla/python:__pkg__", "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", diff --git a/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/tsl/tsl/profiler/utils/BUILD index 95fa784b13fde..3e9041b0ae988 100644 --- a/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/tsl/tsl/profiler/utils/BUILD @@ -1,4 +1,4 @@ -load("//tsl:tsl.bzl", "set_external_visibility") +load("//tsl:tsl.bzl", "internal_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") load("//tsl/platform:build_config_root.bzl", "if_static") load("//tsl/platform:rules_cc.bzl", "cc_library") @@ -6,7 +6,7 @@ load("//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tsl/profiler:internal", ]), licenses = ["notice"], @@ -36,7 +36,7 @@ cc_library( name = "time_utils", hdrs = ["time_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ] + if_static([ @@ -51,8 +51,8 @@ cc_library( "time_utils.h", ], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", + visibility = internal_visibility([ + "@xla//xla:__subpackages__", "//tsl/platform/cloud:__pkg__", "//tsl/profiler:internal", ]), @@ -114,7 +114,7 @@ cc_library( srcs = ["xplane_schema.cc"], hdrs = ["xplane_schema.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":tf_op_utils", "//tsl/lib/gtl:map_util", @@ -134,7 +134,7 @@ cc_library( srcs = ["xplane_visitor.cc"], hdrs = ["xplane_visitor.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":timespan", "//tsl/platform:logging", @@ -151,7 +151,7 @@ cc_library( srcs = ["xplane_builder.cc"], hdrs = ["xplane_builder.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ":timespan", @@ -185,14 +185,13 @@ cc_library( name = "trace_utils", hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/backends/profiler/gpu:__pkg__", + visibility = internal_visibility([ + "@xla//xla/backends/profiler/gpu:__pkg__", "//tsl/profiler:internal", ]), deps = [ "//tsl/platform:types", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", ], ) @@ -201,7 +200,7 @@ cc_library( srcs = ["xplane_utils.cc"], hdrs = ["xplane_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":math_utils", ":tf_xplane_visitor", @@ -214,12 +213,12 @@ cc_library( "//tsl/platform:types", "//tsl/profiler/lib:context_types", "//tsl/profiler/protobuf:xplane_proto_cc", - "//tsl/util:stats_calculator_portable", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -247,7 +246,7 @@ cc_library( name = "tf_xplane_visitor", hdrs = ["tf_xplane_visitor.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":xplane_schema", ":xplane_visitor", @@ -260,7 +259,7 @@ cc_library( srcs = ["parse_annotation.cc"], hdrs = ["parse_annotation.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/strings", ], @@ -282,7 +281,7 @@ cc_library( srcs = ["group_events.cc"], hdrs = ["group_events.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":tf_xplane_visitor", ":xplane_builder", @@ -310,7 +309,7 @@ cc_library( srcs = ["xplane_test_utils.cc"], hdrs = ["xplane_test_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([":friends"]), + visibility = internal_visibility([":friends"]), deps = [ ":xplane_builder", ":xplane_schema", @@ -375,8 +374,8 @@ cc_library( name = "file_system_utils", hdrs = ["file_system_utils.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/python:__pkg__", + visibility = internal_visibility([ + "@xla//xla/python:__pkg__", "//tsl/profiler:internal", ]), deps = [ @@ -390,8 +389,8 @@ cc_library( srcs = ["buffer_pool.cc"], hdrs = ["buffer_pool.h"], copts = tf_profiler_copts(), - visibility = set_external_visibility([ - "//tensorflow/compiler/xla/backends/profiler/gpu:__pkg__", + visibility = internal_visibility([ + "@xla//xla/backends/profiler/gpu:__pkg__", "//tsl/profiler:internal", ]), deps = [ @@ -417,7 +416,7 @@ cc_library( srcs = ["preprocess_xplane.cc"], hdrs = ["preprocess_xplane.h"], copts = tf_profiler_copts(), - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":tpu_xplane_utils", ":trace_utils", @@ -463,3 +462,29 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", ], ) + +cc_library( + name = "timestamp_utils", + srcs = ["timestamp_utils.cc"], + hdrs = ["timestamp_utils.h"], + deps = [ + ":xplane_builder", + ":xplane_schema", + ":xplane_utils", + "//tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/log", + ], +) + +tsl_cc_test( + name = "timestamp_utils_test", + srcs = ["timestamp_utils_test.cc"], + deps = [ + ":timestamp_utils", + ":xplane_schema", + ":xplane_utils", + ":xplane_visitor", + "//tsl/platform:test", + "//tsl/platform:test_main", + ], +) diff --git a/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc b/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc new file mode 100644 index 0000000000000..ea208ed309c46 --- /dev/null +++ b/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/timestamp_utils.h" + +#include + +#include "absl/log/log.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_builder.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" + +namespace tsl { +namespace profiler { + +void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, + tensorflow::profiler::XSpace& space) { + if (start_walltime_ns != 0 && stop_walltime_ns != 0) { + tsl::profiler::XPlaneBuilder plane( + tsl::profiler::FindOrAddMutablePlaneWithName( + &space, tsl::profiler::kTaskEnvPlaneName)); + plane.AddStatValue(*plane.GetOrCreateStatMetadata( + GetTaskEnvStatTypeStr(kEnvProfileStartTime)), + start_walltime_ns); + plane.AddStatValue(*plane.GetOrCreateStatMetadata( + GetTaskEnvStatTypeStr(kEnvProfileStopTime)), + stop_walltime_ns); + } else { + LOG(WARNING) << "Not Setting Session Timestamps, (start_walltime_ns, " + "stop_walltime_ns) : " + << start_walltime_ns << ", " << stop_walltime_ns; + } +} + +} // namespace profiler +} // namespace tsl diff --git a/third_party/tsl/tsl/profiler/utils/timestamp_utils.h b/third_party/tsl/tsl/profiler/utils/timestamp_utils.h new file mode 100644 index 0000000000000..87013c97a6f5b --- /dev/null +++ b/third_party/tsl/tsl/profiler/utils/timestamp_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ + +#include + +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +// Add metadata regarding profile start_time and stop_time to xspace. +// This function won't have an effect if either of the timestamps is zero. +void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, + tensorflow::profiler::XSpace& space); +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ diff --git a/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc b/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc new file mode 100644 index 0000000000000..893e31ebb5ec5 --- /dev/null +++ b/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/timestamp_utils.h" + +#include "tsl/platform/test.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" +#include "tsl/profiler/utils/xplane_visitor.h" + +namespace tsl { +namespace profiler { +using ::testing::Eq; + +TEST(TimestampUtilsTest, StartAndStopTimestampAreAdded) { + XSpace xspace; + + SetSessionTimestamps(1000, 2000, xspace); + + const XPlane* xplane = FindPlaneWithName(xspace, kTaskEnvPlaneName); + + XPlaneVisitor visitor(xplane, {}, {FindTaskEnvStatType}); + + auto start_time = visitor.GetStat(TaskEnvStatType::kEnvProfileStartTime); + auto stop_time = visitor.GetStat(TaskEnvStatType::kEnvProfileStopTime); + + EXPECT_THAT(start_time->IntOrUintValue(), Eq(1000)); + EXPECT_THAT(stop_time->IntOrUintValue(), Eq(2000)); +} + +} // namespace profiler + +} // namespace tsl diff --git a/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/tsl/tsl/profiler/utils/trace_utils.h index 6a7093b422c7d..27d27c5b9966a 100644 --- a/third_party/tsl/tsl/profiler/utils/trace_utils.h +++ b/third_party/tsl/tsl/profiler/utils/trace_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "tsl/platform/types.h" diff --git a/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 62b69f2910b33..990ffa1ed4e9d 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -47,8 +47,8 @@ const absl::string_view kHostCpusPlaneName = "Host CPUs"; const absl::string_view kSyscallsPlaneName = "Syscalls"; const absl::string_view kStepLineName = "Steps"; -const absl::string_view kTensorFlowNameScopeLineName = "TensorFlow Name Scope"; -const absl::string_view kTensorFlowOpLineName = "TensorFlow Ops"; +const absl::string_view kTensorFlowNameScopeLineName = "Framework Name Scope"; +const absl::string_view kTensorFlowOpLineName = "Framework Ops"; const absl::string_view kXlaModuleLineName = "XLA Modules"; const absl::string_view kXlaOpLineName = "XLA Ops"; const absl::string_view kXlaAsyncOpLineName = "Async XLA Ops"; @@ -59,6 +59,8 @@ const absl::string_view kCounterEventsLineName = "_counters_"; const absl::string_view kDeviceVendorNvidia = "Nvidia"; const absl::string_view kDeviceVendorAMD = "AMD"; +const absl::string_view kTaskEnvPlaneName = "Task Environment"; + namespace { constexpr int kNumHostEventTypes = @@ -138,6 +140,7 @@ const HostEventTypeMap& GetHostEventTypeMap() { // Batching related. {"BatchingSessionRun", kBatchingSessionRun}, {"ProcessBatch", kProcessBatch}, + {"BrainSessionRun", kBrainSessionRun}, {"ConcatInputTensors", kConcatInputTensors}, {"MergeInputTensors", kMergeInputTensors}, {"ScheduleWithoutSplit", kScheduleWithoutSplit}, @@ -265,6 +268,7 @@ const StatTypeMap& GetStatTypeMap() { {"tf_function_call", kTfFunctionCall}, {"tracing_count", kTfFunctionTracingCount}, {"flops", kFlops}, + {"model_flops", kModelFlops}, {"bytes_accessed", kBytesAccessed}, {"memory_access_breakdown", kMemoryAccessBreakdown}, {"source", kSourceInfo}, @@ -329,6 +333,7 @@ const StatTypeMap& GetStatTypeMap() { {"dcn_destination_per_slice_device_id", kDcnDestinationPerSliceDeviceId}, {"dcn_chunk", kDcnChunk}, {"dcn_loop_index", kDcnLoopIndex}, + {"dropped_traces", kDroppedTraces}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; @@ -394,6 +399,29 @@ const LineIdTypeStrMap& GetLineIdTypeStrMap() { return *line_id_type_str_map; } +using TaskEnvStatTypeMap = + absl::flat_hash_map; +using TaskEnvStatTypeStrMap = + absl::flat_hash_map; + +constexpr int kNumTaskEnvStatTypes = TaskEnvStatType::kLastTaskEnvStatType - + TaskEnvStatType::kFirstTaskEnvStatType + 1; + +const TaskEnvStatTypeMap& GetTaskEnvStatTypeMap() { + static auto* task_env_stat_type_map = new TaskEnvStatTypeMap({ + {"profile_start_time", kEnvProfileStartTime}, + {"profile_stop_time", kEnvProfileStopTime}, + }); + DCHECK_EQ(task_env_stat_type_map->size(), kNumTaskEnvStatTypes); + return *task_env_stat_type_map; +} + +const TaskEnvStatTypeStrMap& GetTaskEnvStatTypeStrMap() { + static auto* task_env_stat_type_str_map = new TaskEnvStatTypeStrMap( + gtl::ReverseMap(GetTaskEnvStatTypeMap())); + return *task_env_stat_type_str_map; +} + } // namespace absl::string_view GetHostEventTypeStr(HostEventType event_type) { @@ -442,6 +470,17 @@ std::optional FindMegaScaleStatType(absl::string_view stat_name) { return std::nullopt; } +absl::string_view GetTaskEnvStatTypeStr(TaskEnvStatType stat_type) { + return GetTaskEnvStatTypeStrMap().at(stat_type); +} + +std::optional FindTaskEnvStatType(absl::string_view stat_name) { + if (auto stat_type = gtl::FindOrNull(GetTaskEnvStatTypeMap(), stat_name)) { + return *stat_type; + } + return std::nullopt; +} + absl::string_view GetLineIdTypeStr(LineIdType line_id_type) { return GetLineIdTypeStrMap().at(line_id_type); } @@ -497,6 +536,8 @@ const absl::string_view kMegaScaleDcnReceive = const absl::string_view kMegaScaleDcnSend = "MegaScale: Communication Transport Send"; const absl::string_view kMegaScaleDcnSendFinished = "MegaScale: Send Finished"; +const absl::string_view kMegaScaleDcnMemAllocate = "MegaScale: Memory Allocate"; +const absl::string_view kMegaScaleDcnMemCopy = "MegaScale: Memory Copy"; const absl::string_view kMegaScaleTopologyDiscovery = "MegaScale: Communication Topology Discovery."; const absl::string_view kMegaScaleBarrier = "MegaScale: Barrier."; @@ -509,6 +550,9 @@ const absl::string_view kMegaScaleH2DTransferStart = "MegaScale: Host to Device Action"; const absl::string_view kMegaScaleH2DTransferFinished = "MegaScale: Host to Device Transfer Finished"; +const absl::string_view kMegaScaleReductionStart = "MegaScale: Reduction"; +const absl::string_view kMegaScaleReductionFinished = + "MegaScale: Reduction Finished"; const char kXProfMetadataKey[] = "key"; const char kXProfMetadataFlow[] = "flow"; const char kXProfMetadataTransfers[] = "transfers"; diff --git a/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 7bbd052f815eb..c3d0dbddd70d3 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -77,6 +77,9 @@ TF_CONST_INIT extern const absl::string_view kCounterEventsLineName; TF_CONST_INIT extern const absl::string_view kDeviceVendorNvidia; TF_CONST_INIT extern const absl::string_view kDeviceVendorAMD; +// Name of Xplane that contains environment information +TF_CONST_INIT extern const absl::string_view kTaskEnvPlaneName; + // Max collectives to display per TPU. // Since in most cases there will be more than 9 collectives, the last line // contains all collectives that did not qualify to get their own line. @@ -131,6 +134,7 @@ enum HostEventType { // Batching related. kBatchingSessionRun, kProcessBatch, + kBrainSessionRun, kConcatInputTensors, kMergeInputTensors, kScheduleWithoutSplit, @@ -252,6 +256,7 @@ enum StatType { kTfFunctionCall, kTfFunctionTracingCount, kFlops, + kModelFlops, kBytesAccessed, kMemoryAccessBreakdown, kSourceInfo, @@ -313,7 +318,8 @@ enum StatType { kEdgeTpuModelInfo, kEdgeTpuModelProfileInfo, kEdgeTpuMlir, - kLastStatType = kEdgeTpuMlir, + kDroppedTraces, + kLastStatType = kDroppedTraces, }; enum MegaScaleStatType : uint8_t { @@ -340,6 +346,13 @@ enum MegaScaleStatType : uint8_t { kLastMegaScaleStatType = kMegaScaleGraphProtos, }; +enum TaskEnvStatType { + kFirstTaskEnvStatType = 1, + kEnvProfileStartTime = kFirstTaskEnvStatType, + kEnvProfileStopTime, + kLastTaskEnvStatType = kEnvProfileStopTime, +}; + static constexpr uint32_t kLineIdOffset = 10000; enum LineIdType { @@ -400,6 +413,10 @@ bool IsInternalEvent(std::optional event_type); // Returns true if the given stat shouldn't be shown in the trace viewer. bool IsInternalStat(std::optional stat_type); +absl::string_view GetTaskEnvStatTypeStr(TaskEnvStatType stat_type); + +std::optional FindTaskEnvStatType(absl::string_view stat_name); + // Support for flow events: // This class enables encoding/decoding the flow id and direction, stored as // XStat value. The flow id are limited to 56 bits. @@ -474,6 +491,8 @@ class XFlow { TF_CONST_INIT extern const absl::string_view kMegaScaleDcnReceive; TF_CONST_INIT extern const absl::string_view kMegaScaleDcnSend; TF_CONST_INIT extern const absl::string_view kMegaScaleDcnSendFinished; +TF_CONST_INIT extern const absl::string_view kMegaScaleDcnMemAllocate; +TF_CONST_INIT extern const absl::string_view kMegaScaleDcnMemCopy; TF_CONST_INIT extern const absl::string_view kMegaScaleTopologyDiscovery; TF_CONST_INIT extern const absl::string_view kMegaScaleBarrier; TF_CONST_INIT extern const absl::string_view kMegaScaleHostCommand; @@ -481,6 +500,8 @@ TF_CONST_INIT extern const absl::string_view kMegaScaleD2HTransferStart; TF_CONST_INIT extern const absl::string_view kMegaScaleD2HTransferFinished; TF_CONST_INIT extern const absl::string_view kMegaScaleH2DTransferStart; TF_CONST_INIT extern const absl::string_view kMegaScaleH2DTransferFinished; +TF_CONST_INIT extern const absl::string_view kMegaScaleReductionStart; +TF_CONST_INIT extern const absl::string_view kMegaScaleReductionFinished; TF_CONST_INIT extern const char kXProfMetadataKey[]; TF_CONST_INIT extern const char kXProfMetadataFlow[]; TF_CONST_INIT extern const char kXProfMetadataTransfers[]; diff --git a/third_party/tsl/tsl/profiler/utils/xplane_utils.cc b/third_party/tsl/tsl/profiler/utils/xplane_utils.cc index 7ddffb8dd49af..88c7e30b76eee 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/tsl/tsl/profiler/utils/xplane_utils.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/stats_calculator.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_visitor.h" -#include "tsl/util/stats_calculator.h" namespace tsl { namespace profiler { @@ -560,7 +560,7 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { : event.EndTimestampPs(); const auto& group_stat = event.GetStat(StatType::kGroupId); int64_t group_id = - group_stat.has_value() ? group_stat->IntOrUintValue() : 0; + group_stat.has_value() ? group_stat->IntOrUintValue() : kint64max; StatByEvent& line_stats = stats[line.Id()][group_id]; line_stats[event.Id()].stat.UpdateStat(event.DurationPs()); @@ -606,7 +606,9 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { aggregated_line.AddEvent(event_metadata); aggregated_event.SetNumOccurrences(event_stat.stat.count()); aggregated_event.SetDurationPs(event_stat.stat.sum()); - aggregated_event.AddStatValue(*kGroupId, group_id); + if (group_id != kint64max) { + aggregated_event.AddStatValue(*kGroupId, group_id); + } if (event_stat.stat.count() > 1) { aggregated_event.AddStatValue(*kMinDurationPs, event_stat.stat.min()); } diff --git a/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc b/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc index ab3d6d62fe015..bfeeed52e1ef1 100644 --- a/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc +++ b/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc @@ -437,11 +437,10 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { IgnoringRepeatedFieldOrdering(EqualsProto( R"pb(lines { id: 1 - name: "TensorFlow Ops" + name: "Framework Ops" events { metadata_id: 1 duration_ps: 9000 - stats { metadata_id: 4 int64_value: 0 } stats { metadata_id: 2 int64_value: 4000 } stats { metadata_id: 3 int64_value: 4000 } num_occurrences: 2 @@ -449,21 +448,18 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { events { metadata_id: 3 duration_ps: 5000 - stats { metadata_id: 4 int64_value: 0 } stats { metadata_id: 2 int64_value: 2000 } num_occurrences: 2 } events { metadata_id: 4 duration_ps: 6000 - stats { metadata_id: 4 int64_value: 0 } stats { metadata_id: 3 int64_value: 2000 } num_occurrences: 1 } events { metadata_id: 2 duration_ps: 10000 - stats { metadata_id: 4 int64_value: 0 } stats { metadata_id: 2 int64_value: 5000 } num_occurrences: 2 } diff --git a/third_party/tsl/tsl/protobuf/BUILD b/third_party/tsl/tsl/protobuf/BUILD index 2f416cfd311cf..c6c461b8f8361 100644 --- a/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/tsl/tsl/protobuf/BUILD @@ -2,7 +2,7 @@ load( "//tsl:tsl.bzl", "if_google", - "set_external_visibility", + "internal_visibility", ) load( "//tsl/platform:build_config.bzl", @@ -11,7 +11,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tensorflow/core:__subpackages__", "//tsl:internal", "//tensorflow_models:__subpackages__", @@ -103,9 +103,9 @@ tf_proto_library( name = "test_log_proto", srcs = ["test_log.proto"], make_default_target_header_only = True, - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/core:__subpackages__", - "//tsl/util:__pkg__", + "@xla//xla/tsl/util:__pkg__", ]), ) diff --git a/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/tsl/tsl/protobuf/dnn.proto index b349115292e43..cc16b2141e0e7 100644 --- a/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/tsl/tsl/protobuf/dnn.proto @@ -179,6 +179,13 @@ message ConvolutionDescriptorProto { string name = 7; } +// NormKind kind +enum NormKind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; +} + // FusedMHAKind kind enum FusedMHAKind { BMM1_OUTPUT_UNKNOWN = 0; diff --git a/third_party/tsl/tsl/tsl.bzl b/third_party/tsl/tsl/tsl.bzl index 3372248a132fb..30dbf6312fff0 100644 --- a/third_party/tsl/tsl/tsl.bzl +++ b/third_party/tsl/tsl/tsl.bzl @@ -5,6 +5,12 @@ load( "@local_config_cuda//cuda:build_defs.bzl", "if_cuda", ) +load( + "@xla//xla/tsl/mkl:build_defs.bzl", + "if_enable_mkl", + "if_mkl", + "onednn_v3_define", +) load( "//third_party/compute_library:build_defs.bzl", "if_enable_acl", @@ -20,12 +26,6 @@ load( "if_rocm", "if_rocm_is_configured", ) -load( - "//tsl/mkl:build_defs.bzl", - "if_enable_mkl", - "if_mkl", - "onednn_v3_define", -) load( "//tsl/platform:rules_cc.bzl", "cc_binary", @@ -37,6 +37,11 @@ load( "if_tensorrt", ) +# Internally this loads a macro, but in OSS this is a function +# buildifier: disable=out-of-order-load +def register_extension_info(**kwargs): + pass + two_gpu_tags = ["requires-gpu-nvidia:2", "notap", "manual", "no_pip"] def clean_dep(target): @@ -96,6 +101,14 @@ def if_google(google_value, oss_value = []): """ return oss_value # copybara:comment_replace return google_value +def internal_visibility(internal_targets): + """Returns internal_targets in g3, but returns public in OSS. + + Useful for targets that are part of the XLA/TSL API surface but want finer-grained visibilites + internally. + """ + return if_google(internal_targets, ["//visibility:public"]) + # TODO(jakeharmon): Use this to replace if_static def if_tsl_link_protobuf(if_true, if_false = []): return select({ @@ -186,6 +199,7 @@ def if_nccl(if_true, if_false = []): return select({ clean_dep("//tsl:no_nccl_support"): if_false, clean_dep("//tsl:windows"): if_false, + clean_dep("//tsl:arm"): if_false, "//conditions:default": if_true, }) @@ -298,14 +312,12 @@ def tsl_copts( def tf_openmp_copts(): # We assume when compiling on Linux gcc/clang will be used and MSVC on Windows - # TODO(zacmustin): Update OSS to use TSL's MKL. return select({ + clean_dep("@xla//xla/tsl/mkl:build_with_mkl_lnx_openmp"): ["-fopenmp"], # copybara:uncomment_begin - # "//tsl/mkl:build_with_mkl_lnx_openmp": ["-fopenmp"], - # "//tsl/mkl:build_with_mkl_windows_openmp": ["/openmp"], + # "@xla//xla/tsl/mkl:build_with_mkl_windows_openmp": ["/openmp"], # copybara:uncomment_end_and_comment_begin - "@tsl//third_party/mkl:build_with_mkl_lnx_openmp": ["-fopenmp"], - "@tsl//third_party/mkl:build_with_mkl_windows_openmp": ["/openmp:llvm"], + clean_dep("@xla//xla/tsl/mkl:build_with_mkl_windows_openmp"): ["/openmp:llvm"], # copybara:comment_end "//conditions:default": [], }) @@ -340,7 +352,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs kwargs.pop("default_copts", None) cc_library( deps = deps + if_cuda([ - clean_dep("//tsl/cuda:cudart"), + clean_dep("@xla//xla/tsl/cuda:cudart"), "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", @@ -349,6 +361,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs **kwargs ) +register_extension_info(extension = tsl_gpu_library, label_regex_for_dep = "{extension_name}") + # Traverse the dependency graph along the "deps" attribute of the # target and return a struct with one field called 'tf_collected_deps'. # tf_collected_deps will be the union of the deps of the current target @@ -424,6 +438,9 @@ check_deps = rule( def get_compatible_with_portable(): return [] +def get_compatible_with_libtpu_portable(): + return [] + def filegroup(**kwargs): native.filegroup(**kwargs) @@ -458,7 +475,7 @@ _transitive_hdrs = rule( def transitive_hdrs(name, deps = [], **kwargs): _transitive_hdrs(name = name + "_gather", deps = deps) - native.filegroup(name = name, srcs = [":" + name + "_gather"]) + native.filegroup(name = name, srcs = [":" + name + "_gather"], **kwargs) # Create a header only library that includes all the headers exported by # the libraries in deps. @@ -585,8 +602,6 @@ def tsl_pybind_extension_opensource( filegroup_name = "%s_filegroup" % name pyd_file = "%s%s.pyd" % (prefix, sname) exported_symbols = [ - "init%s" % sname, - "init_%s" % sname, "PyInit_%s" % sname, ] + additional_exported_symbols @@ -755,7 +770,5 @@ def tsl_pybind_extension_opensource( compatible_with = compatible_with, ) -# Used for specifying external visibility constraints. In non-monorepo situations, this needs to be -# public, but monorepos can have more precise constraints. -def set_external_visibility(monorepo_paths): - return if_oss(["//visibility:public"], monorepo_paths) +def nvtx_headers(): + return if_oss(["@nvtx_archive//:headers"], ["@local_config_cuda//cuda:cuda_headers"]) diff --git a/third_party/tsl/tsl/tsl.default.bzl b/third_party/tsl/tsl/tsl.default.bzl index 1759e5106320d..912939245725a 100644 --- a/third_party/tsl/tsl/tsl.default.bzl +++ b/third_party/tsl/tsl/tsl.default.bzl @@ -3,6 +3,7 @@ load( "//tsl:tsl.bzl", _filegroup = "filegroup", + _get_compatible_with_libtpu_portable = "get_compatible_with_libtpu_portable", _get_compatible_with_portable = "get_compatible_with_portable", _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", _internal_hlo_deps = "internal_hlo_deps", @@ -11,6 +12,7 @@ load( ) get_compatible_with_portable = _get_compatible_with_portable +get_compatible_with_libtpu_portable = _get_compatible_with_libtpu_portable filegroup = _filegroup if_not_mobile_or_arm_or_lgpl_restricted = _if_not_mobile_or_arm_or_lgpl_restricted internal_hlo_deps = _internal_hlo_deps diff --git a/third_party/tsl/tsl/util/BUILD b/third_party/tsl/tsl/util/BUILD deleted file mode 100644 index 4b43f8096fc68..0000000000000 --- a/third_party/tsl/tsl/util/BUILD +++ /dev/null @@ -1,344 +0,0 @@ -# Description: -# Tensor Standard Libraries. -# -# The libraries in this package are not allowed to have ANY dependencies -# to other TF components outside of TSL. - -load( - "//tsl/platform:rules_cc.bzl", - "cc_library", -) -load( - "//tsl:tsl.bzl", - "check_deps", - "set_external_visibility", - "tsl_copts", -) -load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") -load( - "//tsl/platform:build_config_root.bzl", - "if_static", -) -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//visibility:public", - ], - licenses = ["notice"], -) - -filegroup( - name = "mobile_srcs_no_runtime", - srcs = [ - "byte_swap_array.cc", - "byte_swap_array.h", - ], -) - -filegroup( - name = "mobile_srcs_only_runtime", - srcs = [ - "command_line_flags.cc", - "command_line_flags.h", - "determinism.cc", - "determinism.h", - "device_name_utils.cc", - "device_name_utils.h", - "env_var.cc", - "env_var.h", - "use_cudnn.cc", - "use_cudnn.h", - ], -) - -filegroup( - name = "determnism_hdr", - srcs = [ - "determinism.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ - "//tensorflow:__subpackages__", - "//tensorflow/core/util:__pkg__", - ]), -) - -filegroup( - name = "framework_internal_private_hdrs", - srcs = [ - "byte_swap_array.h", - "command_line_flags.h", - "device_name_utils.h", - "env_var.h", - "stat_summarizer_options.h", - "stats_calculator.h", - "use_cudnn.h", - ], -) - -filegroup( - name = "framework_internal_impl_srcs", - srcs = [ - "use_cudnn.cc", - ], -) - -filegroup( - name = "lib_internal_public_hdrs", - srcs = [ - "command_line_flags.h", - "env_var.h", - "use_cudnn.h", - ], - visibility = set_external_visibility([ - "//tensorflow/core:__pkg__", - "//tensorflow/core/util:__pkg__", - ]), -) - -filegroup( - name = "determinism_hdr", - srcs = [ - "determinism.h", - ], - compatible_with = get_compatible_with_portable(), - visibility = set_external_visibility([ - "//tensorflow:__subpackages__", - "//tensorflow/core/util:__pkg__", - ]), -) - -filegroup( - name = "framework_srcs", - srcs = [ - "device_name_utils.h", - "stat_summarizer_options.h", - "use_cudnn.h", - ], -) - -cc_library( - name = "byte_swap_array", - srcs = ["byte_swap_array.cc"], - hdrs = ["byte_swap_array.h"], - deps = [ - "//tsl/platform:byte_order", - "//tsl/platform:errors", - "//tsl/platform:status", - ], -) - -cc_library( - name = "determinism_hdr_lib", - hdrs = [":determinism_hdr"], - compatible_with = get_compatible_with_portable(), - # TODO(b/298501506): narrow this in a way that won't break TAP - visibility = ["//visibility:public"], -) - -# Note: This rule should not be used as a dependency for kernels. Use the -# "determinism_for_kernels" rule below instead. -cc_library( - name = "determinism", - srcs = ["determinism.cc"], - hdrs = ["determinism.h"], - copts = tsl_copts(), - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), - deps = [ - ":env_var", - "//tsl/platform:mutex", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) - -# This alias should be used as a dependency for kernels which use determinism, -# as well any other rules which are in the same shared library as the kernels. -# This rule does not include the determinism.cc file for nonstatic builds. The -# reason is that for nonstatic builds, the shared object which contains the -# kernels (e.g. _pywrap_tensorflow_internal.so) must not contain the global -# variable in determinism.cc, since the global variable is already in -# libtensorflow_framework.so. -# -# To test that determinism.cc is not improperly included in the shared object -# which contains the kernels, you can run the "determinism_check_deps" rule -# below. -alias( - name = "determinism_for_kernels", - actual = if_static(":determinism", ":determinism_hdr_lib"), - visibility = set_external_visibility(["//tensorflow:__subpackages__"]), -) - -check_deps( - name = "determinism_check_deps", - disallowed_deps = if_static( - [], - otherwise = [":determinism"], - ), - deps = [ - ], -) - -cc_library( - name = "determinism_test_util", - hdrs = [":determinism_test_util.h"], - data = [ - # Adding this data dependency ensures determinism_check_deps is run - # whenever determinism tests are run. - ":determinism_check_deps", - ], - deps = [":determinism"], -) - -cc_library( - name = "env_var", - srcs = ["env_var.cc"], - hdrs = ["env_var.h"], - deps = [ - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:numbers", - "//tsl/platform:status", - "//tsl/platform:str_util", - "//tsl/platform:strcat", - "//tsl/platform:stringpiece", - "//tsl/platform:types", - ], -) - -cc_library( - name = "reporter", - srcs = ["reporter.cc"], - hdrs = ["reporter.h"], - visibility = set_external_visibility([ - "//tensorflow/core:__subpackages__", - "//tsl:__subpackages__", - ]), - deps = [ - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:mutex", - "//tsl/platform:str_util", - "//tsl/platform:types", - "//tsl/protobuf:test_log_proto_cc", - ], -) - -cc_library( - name = "stats_calculator_portable", - srcs = [ - "stats_calculator.cc", - ], - hdrs = [ - "stat_summarizer_options.h", - "stats_calculator.h", - ], - copts = tsl_copts(), - visibility = set_external_visibility([ - "//tsl:internal", - ]), -) - -tsl_cc_test( - name = "stats_calculator_test", - srcs = ["stats_calculator_test.cc"], - deps = [ - ":stats_calculator_portable", - "//tsl/platform:test", - "//tsl/platform:test_main", - ], -) - -cc_library( - name = "device_name_utils", - srcs = ["device_name_utils.cc"], - hdrs = ["device_name_utils.h"], - deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - ], -) - -tsl_cc_test( - name = "device_name_utils_test", - size = "small", - srcs = ["device_name_utils_test.cc"], - deps = [ - ":device_name_utils", - "//tsl/lib/core:status_test_util", - "//tsl/platform:errors", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - ], -) - -cc_library( - name = "command_line_flags", - srcs = ["command_line_flags.cc"], - hdrs = ["command_line_flags.h"], - deps = [ - "//tsl/platform:logging", - "//tsl/platform:str_util", - "//tsl/platform:stringpiece", - "//tsl/platform:stringprintf", - "//tsl/platform:types", - "@com_google_absl//absl/strings", - ], -) - -filegroup( - name = "test_hdrs", - testonly = 1, - srcs = [ - "reporter.h", - ], - visibility = set_external_visibility(["//tensorflow/core/util:__pkg__"]), -) - -filegroup( - name = "onednn_util_hdrs", - srcs = [ - "onednn_threadpool.h", - ], - visibility = set_external_visibility([ - "//tensorflow/compiler/xla:__subpackages__", - "//tensorflow/core:__pkg__", - "//tensorflow/core/framework:__pkg__", - "//tensorflow/core/util:__pkg__", - ]), -) - -filegroup( - name = "android_test_hdrs", - testonly = 1, - srcs = [ - "reporter.h", - ], - visibility = set_external_visibility([ - "//tensorflow/core:__pkg__", - "//tensorflow/core/util:__pkg__", - ]), -) - -filegroup( - name = "android_test_srcs", - testonly = 1, - srcs = [ - "reporter.cc", - ":android_test_hdrs", - ], - visibility = set_external_visibility([ - "//tensorflow/core:__pkg__", - "//tensorflow/core/util:__pkg__", - ]), -) diff --git a/third_party/tsl/tsl/util/env_var.cc b/third_party/tsl/tsl/util/env_var.cc deleted file mode 100644 index e7d818445c7de..0000000000000 --- a/third_party/tsl/tsl/util/env_var.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/util/env_var.h" - -#include - -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/numbers.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/strcat.h" - -namespace tsl { - -Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, - bool* value) { - *value = default_val; - const char* tf_env_var_val = getenv(string(env_var_name).c_str()); - if (tf_env_var_val == nullptr) { - return OkStatus(); - } - string str_value = absl::AsciiStrToLower(tf_env_var_val); - if (str_value == "0" || str_value == "false") { - *value = false; - return OkStatus(); - } else if (str_value == "1" || str_value == "true") { - *value = true; - return OkStatus(); - } - return errors::InvalidArgument(strings::StrCat( - "Failed to parse the env-var ${", env_var_name, "} into bool: ", - tf_env_var_val, ". Use the default value: ", default_val)); -} - -Status ReadInt64FromEnvVar(StringPiece env_var_name, int64_t default_val, - int64_t* value) { - *value = default_val; - const char* tf_env_var_val = getenv(string(env_var_name).c_str()); - if (tf_env_var_val == nullptr) { - return OkStatus(); - } - if (strings::safe_strto64(tf_env_var_val, value)) { - return OkStatus(); - } - return errors::InvalidArgument(strings::StrCat( - "Failed to parse the env-var ${", env_var_name, "} into int64: ", - tf_env_var_val, ". Use the default value: ", default_val)); -} - -Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val, - float* value) { - *value = default_val; - const char* tf_env_var_val = getenv(string(env_var_name).c_str()); - if (tf_env_var_val == nullptr) { - return OkStatus(); - } - if (strings::safe_strtof(tf_env_var_val, value)) { - return OkStatus(); - } - return errors::InvalidArgument(strings::StrCat( - "Failed to parse the env-var ${", env_var_name, "} into float: ", - tf_env_var_val, ". Use the default value: ", default_val)); -} - -Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val, - string* value) { - const char* tf_env_var_val = getenv(string(env_var_name).c_str()); - if (tf_env_var_val != nullptr) { - *value = tf_env_var_val; - } else { - *value = string(default_val); - } - return OkStatus(); -} - -Status ReadStringsFromEnvVar(StringPiece env_var_name, StringPiece default_val, - std::vector* value) { - string str_val; - TF_RETURN_IF_ERROR(ReadStringFromEnvVar(env_var_name, default_val, &str_val)); - *value = str_util::Split(str_val, ','); - return OkStatus(); -} - -} // namespace tsl diff --git a/third_party/tsl/tsl/util/env_var.h b/third_party/tsl/tsl/util/env_var.h deleted file mode 100644 index 9c6925c57f643..0000000000000 --- a/third_party/tsl/tsl/util/env_var.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_TSL_UTIL_ENV_VAR_H_ -#define TENSORFLOW_TSL_UTIL_ENV_VAR_H_ - -#include "tsl/platform/status.h" -#include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" - -namespace tsl { - -// Returns a boolean into "value" from the environmental variable -// "env_var_name". If it is unset, the default value is used. A string "0" or a -// case insensitive "false" is interpreted as false. A string "1" or a case -// insensitive "true" is interpreted as true. Otherwise, an error status is -// returned. -Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, - bool* value); - -// Returns an int64 into "value" from the environmental variable "env_var_name". -// If it is unset, the default value is used. -// If the string cannot be parsed into int64, an error status is returned. -Status ReadInt64FromEnvVar(StringPiece env_var_name, int64_t default_val, - int64_t* value); -// Returns a float into "value" from the environmental variable "env_var_name". -// If it is unset, the default value is used. -// If the string cannot be parsed into float, an error status is returned. -Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val, - float* value); - -// Returns a string into "value" from the environmental variable "env_var_name". -// If it is unset, the default value is used. -Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val, - std::string* value); - -// Returns a comma separated string into "value" from the environmental variable -// "env_var_name". If it is unset, the default value is comma split and used. -Status ReadStringsFromEnvVar(StringPiece env_var_name, StringPiece default_val, - std::vector* value); - -} // namespace tsl - -#endif // TENSORFLOW_TSL_UTIL_ENV_VAR_H_ diff --git a/third_party/tsl/tsl/util/proto/BUILD b/third_party/tsl/tsl/util/proto/BUILD deleted file mode 100644 index 5b4acb2f5f10e..0000000000000 --- a/third_party/tsl/tsl/util/proto/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load( - "//tsl/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//visibility:public", - ], - licenses = ["notice"], -) - -cc_library( - name = "proto_utils", - hdrs = ["proto_utils.h"], - deps = [ - "@com_google_absl//absl/time", - "@com_google_protobuf//:protobuf_headers", - ], -) diff --git a/third_party/tsl/tsl/util/reporter.cc b/third_party/tsl/tsl/util/reporter.cc deleted file mode 100644 index 41501bc68e8ce..0000000000000 --- a/third_party/tsl/tsl/util/reporter.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/util/reporter.h" - -#include "tsl/platform/errors.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/str_util.h" - -namespace tsl { - -TestReportFile::TestReportFile(const string& fname, const string& test_name) - : closed_(true), fname_(fname), test_name_(test_name) {} - -Status TestReportFile::Append(const string& content) { - if (closed_) return OkStatus(); - return log_file_->Append(content); -} - -Status TestReportFile::Close() { - if (closed_) return OkStatus(); - closed_ = true; - return log_file_->Close(); -} - -Status TestReportFile::Initialize() { - if (fname_.empty()) { - return OkStatus(); - } - string mangled_fname = strings::StrCat( - fname_, absl::StrJoin(str_util::Split(test_name_, '/'), "__")); - Env* env = Env::Default(); - if (env->FileExists(mangled_fname).ok()) { - return errors::InvalidArgument( - "Cannot create TestReportFile, file exists: ", mangled_fname); - } - TF_RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_)); - TF_RETURN_IF_ERROR(log_file_->Flush()); - - closed_ = false; - return OkStatus(); -} - -TestReporter::TestReporter(const string& fname, const string& test_name) - : report_file_(fname, test_name) { - benchmark_entry_.set_name(test_name); -} - -Status TestReporter::Close() { - if (report_file_.IsClosed()) return OkStatus(); - - tensorflow::BenchmarkEntries entries; - *entries.add_entry() = benchmark_entry_; - TF_RETURN_IF_ERROR(report_file_.Append(entries.SerializeAsString())); - benchmark_entry_.Clear(); - - return report_file_.Close(); -} - -Status TestReporter::Benchmark(int64_t iters, double cpu_time, double wall_time, - double throughput) { - if (report_file_.IsClosed()) return OkStatus(); - benchmark_entry_.set_iters(iters); - benchmark_entry_.set_cpu_time(cpu_time / iters); - benchmark_entry_.set_wall_time(wall_time / iters); - benchmark_entry_.set_throughput(throughput); - return OkStatus(); -} - -Status TestReporter::SetProperty(const string& name, const string& value) { - if (report_file_.IsClosed()) return OkStatus(); - (*benchmark_entry_.mutable_extras())[name].set_string_value(value); - return OkStatus(); -} - -Status TestReporter::SetProperty(const string& name, double value) { - if (report_file_.IsClosed()) return OkStatus(); - (*benchmark_entry_.mutable_extras())[name].set_double_value(value); - return OkStatus(); -} - -Status TestReporter::AddMetric(const string& name, double value) { - if (report_file_.IsClosed()) return OkStatus(); - auto* metric = benchmark_entry_.add_metrics(); - metric->set_name(name); - metric->set_value(value); - return OkStatus(); -} - -Status TestReporter::Initialize() { return report_file_.Initialize(); } - -} // namespace tsl diff --git a/third_party/tsl/workspace0.bzl b/third_party/tsl/workspace0.bzl index 82419fece2d46..30a9426e5c604 100644 --- a/third_party/tsl/workspace0.bzl +++ b/third_party/tsl/workspace0.bzl @@ -1,10 +1,10 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_toolchains//repositories:repositories.bzl", bazel_toolchains_repositories = "repositories") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") -load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") def _tf_bind(): diff --git a/third_party/tsl/workspace2.bzl b/third_party/tsl/workspace2.bzl index 1343f9e0f03ed..e23dcc3a4c7ad 100644 --- a/third_party/tsl/workspace2.bzl +++ b/third_party/tsl/workspace2.bzl @@ -22,7 +22,6 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") -load("//third_party/jpeg:workspace.bzl", jpeg = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") @@ -51,7 +50,6 @@ def _initialize_third_party(): gemmlowp() hwloc() implib_so() - jpeg() ml_dtypes() nasm() pybind11_abseil() @@ -250,26 +248,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/GoogleCloudPlatform/tensorflow-gcp-tools/archive/2643d8caeba6ca2a6a0b46bb123953cb95b7e7d5.tar.gz"), ) - tf_http_archive( - name = "png", - build_file = "//third_party:png.BUILD", - patch_file = ["//third_party:png_fix_rpi.patch"], - sha256 = "a00e9d2f2f664186e4202db9299397f851aea71b36a35e74910b8820e380d441", - strip_prefix = "libpng-1.6.39", - system_build_file = "//third_party/systemlibs:png.BUILD", - urls = tf_mirror_urls("https://github.com/glennrp/libpng/archive/v1.6.39.tar.gz"), - ) - - tf_http_archive( - name = "gif", - build_file = "//third_party:gif.BUILD", - patch_file = ["//third_party:gif_fix_strtok_r.patch"], - sha256 = "31da5562f44c5f15d63340a09a4fd62b48c45620cd302f77a6d9acf0077879bd", - strip_prefix = "giflib-5.2.1", - system_build_file = "//third_party/systemlibs:gif.BUILD", - urls = tf_mirror_urls("https://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.2.1.tar.gz"), - ) - tf_http_archive( name = "six_archive", build_file = "//third_party:six.BUILD", @@ -403,7 +381,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/open-source-parsers/jsoncpp/archive/1.9.5.tar.gz"), ) - # Note: if you update this, you have to update libpng too. See cl/437813808 tf_http_archive( name = "zlib", build_file = "//third_party:zlib.BUILD", @@ -426,9 +403,18 @@ def _tf_repositories(): name = "nccl_archive", build_file = "//third_party:nccl/archive.BUILD", patch_file = ["//third_party/nccl:archive.patch"], - sha256 = "16ac98f3e926c024ce48e10ab220e19ce734adc48c423cfd55ad6f509bd1179f", - strip_prefix = "nccl-2.18.5-1", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.18.5-1.tar.gz"), + sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2", + strip_prefix = "nccl-2.19.3-1", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"), + ) + + # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h + tf_http_archive( + name = "nvtx_archive", + build_file = "//third_party:nvtx/BUILD", + sha256 = "1c5474553afedb88e878c772f13d6f90b9226b3f2971dfa6f873adb9443100c2", + strip_prefix = "nccl-2.19.3-1/src/include/nvtx3", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.19.3-1.tar.gz"), ) java_import_external( diff --git a/third_party/tsl/workspace3.bzl b/third_party/tsl/workspace3.bzl index 2c13446fb4c1a..9510b09374206 100644 --- a/third_party/tsl/workspace3.bzl +++ b/third_party/tsl/workspace3.bzl @@ -1,8 +1,8 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//third_party/tf_runtime:workspace.bzl", tf_runtime = "repo") load("//third_party/llvm:workspace.bzl", llvm = "repo") +load("//third_party/tf_runtime:workspace.bzl", tf_runtime = "repo") def workspace(): http_archive( diff --git a/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl index 5f2057a3c1436..4455aea60109f 100644 --- a/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl +++ b/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl @@ -1,8 +1,8 @@ """Configurations of AARCH64 builds used with Docker container.""" -load("//tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") -load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/py:python_configure.bzl", "remote_python_configure") +load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") +load("//tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure") def ml2014_tf_aarch64_configs(name_container_map, env): for name, container in name_container_map.items(): diff --git a/tools/toolchains/cross_compile/cc/BUILD b/tools/toolchains/cross_compile/cc/BUILD index 7db2527259d02..a064bbf140251 100644 --- a/tools/toolchains/cross_compile/cc/BUILD +++ b/tools/toolchains/cross_compile/cc/BUILD @@ -1,6 +1,6 @@ """Toolchain configs for cross-compiling TensorFlow""" -load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") +load(":cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) @@ -11,11 +11,21 @@ cc_toolchain_suite( toolchains = { "aarch64": ":linux_aarch64_toolchain", "k8": ":linux_x86_toolchain", + "darwin": ":macos_x86_toolchain", }, ) filegroup(name = "empty") +# We define a wraper ("cc_wrapper.sh") around the compiler to replace all paths +# in the binary (bazel-out/.../path/to/original/library.so) by the paths +# relative to the binary. Without it, we run into "Library not loaded" error +# when trying run cross-compiled tests, see b/300002682. +filegroup( + name = "cc_wrapper_and_macos_sysroot", + srcs = ["cc_wrapper.sh"] + glob(["MacOSX.sdk/**"]), +) + cc_toolchain( name = "linux_x86_toolchain", all_files = ":empty", @@ -186,3 +196,100 @@ cc_toolchain_config( "-Wno-gnu-offsetof-extensions", ], ) + +cc_toolchain( + name = "macos_x86_toolchain", + all_files = ":cc_wrapper_and_macos_sysroot", + compiler_files = ":cc_wrapper_and_macos_sysroot", + dwp_files = ":empty", + linker_files = ":cc_wrapper_and_macos_sysroot", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":macos_x86_toolchain_config", + toolchain_identifier = "macos_x86_toolchain", +) + +cc_toolchain_config( + name = "macos_x86_toolchain_config", + abi_libc_version = "darwin_x86_64", + abi_version = "darwin_x86_64", + builtin_sysroot = "tools/toolchains/cross_compile/cc/MacOSX.sdk", + compile_flags = [ + "--target=x86_64-apple-darwin", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-DOS_MACOSX", + "-DGRPC_BAZEL_BUILD", + "-stdlib=libc++", + "-mavx", + # Target Catalina as the minimum supported OS + "-mmacos-version-min=10.15", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "darwin", + cxx_builtin_include_directories = [ + "%sysroot%/usr/include", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + "%sysroot%/System/Library/Frameworks/Security.framework/Headers", + "%sysroot%/System/Library/Frameworks/CoreFoundation.framework/Headers", + "%sysroot%/System/Library/Frameworks/SystemConfiguration.framework/Headers", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-apple-darwin", + "-lSystem", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld64.lld", + "-headerpad_max_install_names", + "-Wl,-undefined,dynamic_lookup", + # Target Catalina as the minimum supported OS + "-Wl,-platform_version,macos,10.15.0,10.15", + ], + link_libs = [ + "-lc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,-dead_strip"], + supports_start_end_lib = True, + target_libc = "macosx", + target_system_name = "x86_64-apple-macosx10.15", + tool_paths = { + "gcc": "cc_wrapper.sh", + "ld": "/usr/lib/llvm-17/bin/ld64.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-libtool-darwin", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "macos_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl b/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl new file mode 100644 index 0000000000000..de638c0159b5f --- /dev/null +++ b/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl @@ -0,0 +1,1444 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# The toolchain configuration rule file, "cc_toolchain_config.bzl", is a clone of https://github.com/bazelbuild/bazel/blob/6.1.0/tools/cpp/unix_cc_toolchain_config.bzl +# except for a few changes. We remove "-s" as a default archiver option because it is not supported +# by llvm-libtool-darwin, see https://github.com/bazelbuild/bazel/pull/17489. +# We remove "supports_dynamic_linker_feature" from the macOS toolchain config because otherwise +# certain TensorFlow tests get stuck at linking phase. See https://github.com/bazelbuild/bazel/issues/4341 +# which has more details on why this option does not work on macOS. +# +# Note to TF developers: Please make sure this file stays in sync with the Bazel version used +# by TensorFlow. If TensorFlow's Bazel version is updated, replace this file with the contents of +# `tools/cpp/unix_cc_toolchain_config.bzl` from the corresponding Bazel version tag in https://github.com/bazelbuild/bazel. +# Please make sure to add in the additional changes listed above if needed. Contact srnitin@ for +# any questions. +"""A Starlark cc_toolchain configuration rule""" + +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "feature", + "feature_set", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) + +def layering_check_features(compiler): + if compiler != "clang": + return [] + return [ + feature( + name = "use_module_maps", + requires = [feature_set(features = ["module_maps"])], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = [ + "-fmodule-name=%{module_name}", + "-fmodule-map-file=%{module_map_file}", + ], + ), + ], + ), + ], + ), + + # Note: not all C++ rules support module maps; thus, do not imply this + # feature from other features - instead, require it. + feature(name = "module_maps", enabled = True), + feature( + name = "layering_check", + implies = ["use_module_maps"], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group(flags = [ + "-fmodules-strict-decluse", + "-Wprivate-header", + ]), + flag_group( + iterate_over = "dependent_module_map_files", + flags = [ + "-fmodule-map-file=%{dependent_module_map_files}", + ], + ), + ], + ), + ], + ), + ] + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +lto_index_actions = [ + ACTION_NAMES.lto_index_for_executable, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, +] + +def _sanitizer_feature(name = "", specific_compile_flags = [], specific_link_flags = []): + return feature( + name = name, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group(flags = [ + "-fno-omit-frame-pointer", + "-fno-sanitize-recover=all", + ] + specific_compile_flags), + ], + with_features = [ + with_feature_set(features = [name]), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group(flags = specific_link_flags), + ], + with_features = [ + with_feature_set(features = [name]), + ], + ), + ], + ) + +def _impl(ctx): + tool_paths = [ + tool_path(name = name, path = path) + for name, path in ctx.attr.tool_paths.items() + ] + action_configs = [] + + llvm_cov_action = action_config( + action_name = ACTION_NAMES.llvm_cov, + tools = [ + tool( + path = ctx.attr.tool_paths["llvm-cov"], + ), + ], + ) + + action_configs.append(llvm_cov_action) + + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + supports_start_end_lib_feature = feature( + name = "supports_start_end_lib", + enabled = True, + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group( + # Security hardening requires optimization. + # We need to undef it as some distributions now have it enabled by default. + flags = ["-U_FORTIFY_SOURCE"], + ), + ], + with_features = [ + with_feature_set( + not_features = ["thin_lto"], + ), + ], + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.compile_flags, + ), + ] if ctx.attr.compile_flags else []), + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.dbg_compile_flags, + ), + ] if ctx.attr.dbg_compile_flags else []), + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.opt_compile_flags, + ), + ] if ctx.attr.opt_compile_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_cpp_compile_actions + [ACTION_NAMES.lto_backend], + flag_groups = ([ + flag_group( + flags = ctx.attr.cxx_flags, + ), + ] if ctx.attr.cxx_flags else []), + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.link_flags, + ), + ] if ctx.attr.link_flags else []), + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.opt_link_flags, + ), + ] if ctx.attr.opt_link_flags else []), + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + dbg_feature = feature(name = "dbg") + + opt_feature = feature(name = "opt") + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions, + flag_groups = ([ + flag_group( + flags = ctx.attr.unfiltered_compile_flags, + ), + ] if ctx.attr.unfiltered_compile_flags else []), + ), + ], + ) + + library_search_directories_feature = feature( + name = "library_search_directories", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-L%{library_search_directories}"], + iterate_over = "library_search_directories", + expand_if_available = "library_search_directories", + ), + ], + ), + ], + ) + + static_libgcc_feature = feature( + name = "static_libgcc", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.lto_index_for_executable, + ACTION_NAMES.lto_index_for_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-static-libgcc"])], + with_features = [ + with_feature_set(features = ["static_link_cpp_runtimes"]), + ], + ), + ], + ) + + pic_feature = feature( + name = "pic", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group(flags = ["-fPIC"], expand_if_available = "pic"), + ], + ), + ], + ) + + per_object_debug_info_feature = feature( + name = "per_object_debug_info", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["-gsplit-dwarf", "-g"], + expand_if_available = "per_object_debug_info_file", + ), + ], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["-D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + cs_fdo_optimize_feature = feature( + name = "cs_fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.lto_backend], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-Wno-profile-instr-unprofiled", + "-Wno-profile-instr-out-of-date", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["csprofile"], + ) + + autofdo_feature = feature( + name = "autofdo", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fauto-profile=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + runtime_library_search_directories_feature = feature( + name = "runtime_library_search_directories", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "runtime_library_search_directories", + flag_groups = [ + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$EXEC_ORIGIN/%{runtime_library_search_directories}", + ], + expand_if_true = "is_cc_test", + ), + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$ORIGIN/%{runtime_library_search_directories}", + ], + expand_if_false = "is_cc_test", + ), + ], + expand_if_available = + "runtime_library_search_directories", + ), + ], + with_features = [ + with_feature_set(features = ["static_link_cpp_runtimes"]), + ], + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "runtime_library_search_directories", + flag_groups = [ + flag_group( + flags = [ + "-Xlinker", + "-rpath", + "-Xlinker", + "$ORIGIN/%{runtime_library_search_directories}", + ], + ), + ], + expand_if_available = + "runtime_library_search_directories", + ), + ], + with_features = [ + with_feature_set( + not_features = ["static_link_cpp_runtimes"], + ), + ], + ), + ], + ) + + fission_support_feature = feature( + name = "fission_support", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,--gdb-index"], + expand_if_available = "is_using_fission", + ), + ], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-shared"])], + ), + ], + ) + + random_seed_feature = feature( + name = "random_seed", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["-frandom-seed=%{output_file}"], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + includes_feature = feature( + name = "includes", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-include", "%{includes}"], + iterate_over = "includes", + expand_if_available = "includes", + ), + ], + ), + ], + ) + + fdo_instrument_feature = feature( + name = "fdo_instrument", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-fprofile-generate=%{fdo_instrument_path}", + "-fno-data-sections", + ], + expand_if_available = "fdo_instrument_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + cs_fdo_instrument_feature = feature( + name = "cs_fdo_instrument", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.lto_backend, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-fcs-profile-generate=%{cs_fdo_instrument_path}", + ], + expand_if_available = "cs_fdo_instrument_path", + ), + ], + ), + ], + provides = ["csprofile"], + ) + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-iquote", "%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["-I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["-isystem", "%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + external_include_paths_feature = feature( + name = "external_include_paths", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["-isystem", "%{external_include_paths}"], + iterate_over = "external_include_paths", + expand_if_available = "external_include_paths", + ), + ], + ), + ], + ) + + symbol_counts_feature = feature( + name = "symbol_counts", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = [ + "-Wl,--print-symbol-counts=%{symbol_counts_output}", + ], + expand_if_available = "symbol_counts_output", + ), + ], + ), + ], + ) + + strip_debug_symbols_feature = feature( + name = "strip_debug_symbols", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,-S"], + expand_if_available = "strip_debug_symbols", + ), + ], + ), + ], + ) + + build_interface_libraries_feature = feature( + name = "build_interface_libraries", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [ + "%{generate_interface_library}", + "%{interface_library_builder_path}", + "%{interface_library_input_path}", + "%{interface_library_output_path}", + ], + expand_if_available = "generate_interface_library", + ), + ], + with_features = [ + with_feature_set( + features = ["supports_interface_shared_libraries"], + ), + ], + ), + ], + ) + + libraries_to_link_feature = feature( + name = "libraries_to_link", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["-Wl,--start-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flags = ["-Wl,-whole-archive"], + expand_if_true = + "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["%{libraries_to_link.object_files}"], + iterate_over = "libraries_to_link.object_files", + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + flag_group( + flags = ["-l%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "dynamic_library", + ), + ), + flag_group( + flags = ["-l:%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "versioned_dynamic_library", + ), + ), + flag_group( + flags = ["-Wl,-no-whole-archive"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["-Wl,--end-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + flag_group( + flags = ["-Wl,@%{thinlto_param_file}"], + expand_if_true = "thinlto_param_file", + ), + ], + ), + ], + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_link_libs_feature = feature( + name = "default_link_libs", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [flag_group(flags = ctx.attr.link_libs)] if ctx.attr.link_libs else [], + ), + ], + ) + + fdo_prefetch_hints_feature = feature( + name = "fdo_prefetch_hints", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.lto_backend, + ], + flag_groups = [ + flag_group( + flags = [ + "-mllvm", + "-prefetch-hints-file=%{fdo_prefetch_hints_path}", + ], + expand_if_available = "fdo_prefetch_hints_path", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group(flags = ["rcsD"]), + flag_group( + flags = ["%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + with_features = [ + with_feature_set( + not_features = ["libtool"], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group(flags = ["-static"]), + flag_group( + flags = ["-o", "%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + with_features = [ + with_feature_set( + features = ["libtool"], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flags = ["%{libraries_to_link.object_files}"], + iterate_over = "libraries_to_link.object_files", + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = ([ + flag_group( + flags = ctx.attr.archive_flags, + ), + ] if ctx.attr.archive_flags else []), + ), + ], + ) + + force_pic_flags_feature = feature( + name = "force_pic_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.lto_index_for_executable, + ], + flag_groups = [ + flag_group( + flags = ["-pie"], + expand_if_available = "force_pic", + ), + ], + ), + ], + ) + + dependency_file_feature = feature( + name = "dependency_file", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["-MD", "-MF", "%{dependency_file}"], + expand_if_available = "dependency_file", + ), + ], + ), + ], + ) + + serialized_diagnostics_file_feature = feature( + name = "serialized_diagnostics_file", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["--serialize-diagnostics", "%{serialized_diagnostics_file}"], + expand_if_available = "serialized_diagnostics_file", + ), + ], + ), + ], + ) + + dynamic_library_linker_tool_feature = feature( + name = "dynamic_library_linker_tool", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.lto_index_for_dynamic_library, + ACTION_NAMES.lto_index_for_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [" + cppLinkDynamicLibraryToolPath + "], + expand_if_available = "generate_interface_library", + ), + ], + with_features = [ + with_feature_set( + features = ["supports_interface_shared_libraries"], + ), + ], + ), + ], + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = [ + flag_group( + flags = ["-o", "%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + # Note that we also set --coverage for c++-link-nodeps-dynamic-library. The + # generated code contains references to gcov symbols, and the dynamic linker + # can't resolve them unless the library is linked against gcov. + coverage_feature = feature( + name = "coverage", + provides = ["profile"], + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = ([ + flag_group(flags = ctx.attr.coverage_compile_flags), + ] if ctx.attr.coverage_compile_flags else []), + ), + flag_set( + actions = all_link_actions + lto_index_actions, + flag_groups = ([ + flag_group(flags = ctx.attr.coverage_link_flags), + ] if ctx.attr.coverage_link_flags else []), + ), + ], + ) + + thinlto_feature = feature( + name = "thin_lto", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ] + all_link_actions + lto_index_actions, + flag_groups = [ + flag_group(flags = ["-flto=thin"]), + flag_group( + expand_if_available = "lto_indexing_bitcode_file", + flags = [ + "-Xclang", + "-fthin-link-bitcode=%{lto_indexing_bitcode_file}", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.linkstamp_compile], + flag_groups = [flag_group(flags = ["-DBUILD_LTO_TYPE=thin"])], + ), + flag_set( + actions = lto_index_actions, + flag_groups = [ + flag_group(flags = [ + "-flto=thin", + "-Wl,-plugin-opt,thinlto-index-only%{thinlto_optional_params_file}", + "-Wl,-plugin-opt,thinlto-emit-imports-files", + "-Wl,-plugin-opt,thinlto-prefix-replace=%{thinlto_prefix_replace}", + ]), + flag_group( + expand_if_available = "thinlto_object_suffix_replace", + flags = [ + "-Wl,-plugin-opt,thinlto-object-suffix-replace=%{thinlto_object_suffix_replace}", + ], + ), + flag_group( + expand_if_available = "thinlto_merged_object_file", + flags = [ + "-Wl,-plugin-opt,obj-path=%{thinlto_merged_object_file}", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.lto_backend], + flag_groups = [ + flag_group(flags = [ + "-c", + "-fthinlto-index=%{thinlto_index}", + "-o", + "%{thinlto_output_object_file}", + "-x", + "ir", + "%{thinlto_input_bitcode_file}", + ]), + ], + ), + ], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-Werror"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,-fatal-warnings"])], + ), + ], + ) + + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + asan_feature = _sanitizer_feature( + name = "asan", + specific_compile_flags = [ + "-fsanitize=address", + "-fno-common", + ], + specific_link_flags = [ + "-fsanitize=address", + ], + ) + + tsan_feature = _sanitizer_feature( + name = "tsan", + specific_compile_flags = [ + "-fsanitize=thread", + ], + specific_link_flags = [ + "-fsanitize=thread", + ], + ) + + ubsan_feature = _sanitizer_feature( + name = "ubsan", + specific_compile_flags = [ + "-fsanitize=undefined", + ], + specific_link_flags = [ + "-fsanitize=undefined", + ], + ) + + is_linux = ctx.attr.target_libc != "macosx" + libtool_feature = feature( + name = "libtool", + enabled = not is_linux, + ) + + # TODO(#8303): Mac crosstool should also declare every feature. + if is_linux: + # Linux artifact name patterns are the default. + artifact_name_patterns = [] + features = [ + dependency_file_feature, + serialized_diagnostics_file_feature, + random_seed_feature, + pic_feature, + per_object_debug_info_feature, + preprocessor_defines_feature, + includes_feature, + include_paths_feature, + external_include_paths_feature, + fdo_instrument_feature, + cs_fdo_instrument_feature, + cs_fdo_optimize_feature, + thinlto_feature, + fdo_prefetch_hints_feature, + autofdo_feature, + build_interface_libraries_feature, + dynamic_library_linker_tool_feature, + symbol_counts_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + runtime_library_search_directories_feature, + library_search_directories_feature, + libtool_feature, + archiver_flags_feature, + force_pic_flags_feature, + fission_support_feature, + strip_debug_symbols_feature, + coverage_feature, + supports_pic_feature, + asan_feature, + tsan_feature, + ubsan_feature, + ] + ( + [ + supports_start_end_lib_feature, + ] if ctx.attr.supports_start_end_lib else [] + ) + [ + default_compile_flags_feature, + default_link_flags_feature, + libraries_to_link_feature, + user_link_flags_feature, + default_link_libs_feature, + static_libgcc_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + treat_warnings_as_errors_feature, + archive_param_file_feature, + ] + layering_check_features(ctx.attr.compiler) + else: + # macOS artifact name patterns differ from the defaults only for dynamic + # libraries. + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "lib", + extension = ".dylib", + ), + ] + features = [ + libtool_feature, + archiver_flags_feature, + supports_pic_feature, + asan_feature, + tsan_feature, + ubsan_feature, + ] + ( + [ + supports_start_end_lib_feature, + ] if ctx.attr.supports_start_end_lib else [] + ) + [ + coverage_feature, + default_compile_flags_feature, + default_link_flags_feature, + user_link_flags_feature, + default_link_libs_feature, + fdo_optimize_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + treat_warnings_as_errors_feature, + archive_param_file_feature, + ] + layering_check_features(ctx.attr.compiler) + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories, + toolchain_identifier = ctx.attr.toolchain_identifier, + host_system_name = ctx.attr.host_system_name, + target_system_name = ctx.attr.target_system_name, + target_cpu = ctx.attr.cpu, + target_libc = ctx.attr.target_libc, + compiler = ctx.attr.compiler, + abi_version = ctx.attr.abi_version, + abi_libc_version = ctx.attr.abi_libc_version, + tool_paths = tool_paths, + builtin_sysroot = ctx.attr.builtin_sysroot, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(mandatory = True), + "toolchain_identifier": attr.string(mandatory = True), + "host_system_name": attr.string(mandatory = True), + "target_system_name": attr.string(mandatory = True), + "target_libc": attr.string(mandatory = True), + "abi_version": attr.string(mandatory = True), + "abi_libc_version": attr.string(mandatory = True), + "cxx_builtin_include_directories": attr.string_list(), + "tool_paths": attr.string_dict(), + "compile_flags": attr.string_list(), + "dbg_compile_flags": attr.string_list(), + "opt_compile_flags": attr.string_list(), + "cxx_flags": attr.string_list(), + "link_flags": attr.string_list(), + "archive_flags": attr.string_list(), + "link_libs": attr.string_list(), + "opt_link_flags": attr.string_list(), + "unfiltered_compile_flags": attr.string_list(), + "coverage_compile_flags": attr.string_list(), + "coverage_link_flags": attr.string_list(), + "supports_start_end_lib": attr.bool(), + "builtin_sysroot": attr.string(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/tools/toolchains/cross_compile/cc/cc_wrapper.sh b/tools/toolchains/cross_compile/cc/cc_wrapper.sh new file mode 100644 index 0000000000000..a5106941d70ef --- /dev/null +++ b/tools/toolchains/cross_compile/cc/cc_wrapper.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# +# Copyright 2015 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# OS X relpath is not really working. This is a wrapper script around gcc +# to simulate relpath behavior. +# +# This wrapper uses install_name_tool to replace all paths in the binary +# (bazel-out/.../path/to/original/library.so) by the paths relative to +# the binary. It parses the command line to behave as rpath is supposed +# to work. +# +# See https://blogs.oracle.com/dipol/entry/dynamic_libraries_rpath_and_mac +# on how to set those paths for Mach-O binaries. +# +set -eu + +LIBS= +LIB_DIRS= +RPATHS= +OUTPUT= + +function parse_option() { + local -r opt="$1" + if [[ "${OUTPUT}" = "1" ]]; then + OUTPUT=$opt + elif [[ "$opt" =~ ^-l(.*)$ ]]; then + LIBS="${BASH_REMATCH[1]} $LIBS" + elif [[ "$opt" =~ ^-L(.*)$ ]]; then + LIB_DIRS="${BASH_REMATCH[1]} $LIB_DIRS" + elif [[ "$opt" =~ ^\@loader_path/(.*)$ ]]; then + RPATHS="${BASH_REMATCH[1]} ${RPATHS}" + elif [[ "$opt" = "-o" ]]; then + # output is coming + OUTPUT=1 + fi +} + +# let parse the option list +for i in "$@"; do + if [[ "$i" = @* && -r "${i:1}" ]]; then + while IFS= read -r opt + do + parse_option "$opt" + done < "${i:1}" || exit 1 + else + parse_option "$i" + fi +done + +# Call the C++ compiler +/usr/lib/llvm-17/bin/clang "$@" + +function get_library_path() { + for libdir in ${LIB_DIRS}; do + if [ -f ${libdir}/lib$1.so ]; then + echo "${libdir}/lib$1.so" + elif [ -f ${libdir}/lib$1.dylib ]; then + echo "${libdir}/lib$1.dylib" + fi + done +} + +# A convenient method to return the actual path even for non symlinks +# and multi-level symlinks, see b/300002682 for more details. +function get_realpath() { + local mangled=$(echo $1 | sed 's/[-_\/a-zA-Z0-9]*_solib_darwin[-_a-zA-Z0-9]*\///g') + if [[ "${mangled:0:3}" = "lib" ]]; then + mangled="${mangled:3}" + fi + if [[ "${mangled:0:2}" = "_U" ]]; then + mangled="${mangled:2}" + fi + local mangled_path=(${mangled//_S/ }) + local demangled_path=() + for mangled in ${mangled_path[@]}; do + demangled_path+=(${mangled//_U/_}) + done + demangled_path=${demangled_path[@]} + echo "bazel-out/darwin-opt/bin/${demangled_path// //}" +} + +# Get the path of a lib inside a tool +function get_otool_path() { + # the lib path is the path of the original lib relative to the workspace + get_realpath $1 | sed 's|^.*/bazel-out/|bazel-out/|' +} + +# Do replacements in the output +for rpath in ${RPATHS}; do + for lib in ${LIBS}; do + unset libname + if [ -f "$(dirname ${OUTPUT})/${rpath}/lib${lib}.so" ]; then + libname="lib${lib}.so" + elif [ -f "$(dirname ${OUTPUT})/${rpath}/lib${lib}.dylib" ]; then + libname="lib${lib}.dylib" + fi + # ${libname-} --> return $libname if defined, or undefined otherwise. This is to make + # this set -e friendly + if [[ -n "${libname-}" ]]; then + libpath=$(get_library_path ${lib}) + if [ -n "${libpath}" ]; then + /usr/lib/llvm-17/bin/llvm-install-name-tool -change $(get_otool_path "${libpath}") \ + "@loader_path/${rpath}/${libname}" "${OUTPUT}" + fi + fi + done +done \ No newline at end of file diff --git a/tools/toolchains/cross_compile/config/BUILD b/tools/toolchains/cross_compile/config/BUILD index b6a504ba1449d..386b8858fa8b3 100644 --- a/tools/toolchains/cross_compile/config/BUILD +++ b/tools/toolchains/cross_compile/config/BUILD @@ -21,3 +21,25 @@ platform( "@platforms//cpu:aarch64", ], ) + +platform( + name = "darwin_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], +) + +toolchain( + name = "macos-x86-cross-compile-cc-toolchain", + exec_compatible_with = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + target_compatible_with = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], + toolchain = "//tools/toolchains/cross_compile/cc:macos_x86_toolchain", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/tools/toolchains/cross_compile/config/platform_mappings b/tools/toolchains/cross_compile/config/platform_mappings new file mode 100644 index 0000000000000..3ff2236a54a84 --- /dev/null +++ b/tools/toolchains/cross_compile/config/platform_mappings @@ -0,0 +1,11 @@ +platforms: +# Maps "--platforms=//tools/toolchains/cross_compile/config:darwin_x86_64" +# to "--cpu=darwin". + //tools/toolchains/cross_compile/config:darwin_x86_64 + --cpu=darwin + +flags: + # Maps "--cpu=darwin" to + # "--platforms=//tools/toolchains/cross_compile/config:darwin_x86_64". + --cpu=darwin + //tools/toolchains/cross_compile/config:darwin_x86_64 diff --git a/tools/toolchains/python/python_repo.bzl b/tools/toolchains/python/python_repo.bzl index 77011b2c0577b..47fe64d7b7b03 100644 --- a/tools/toolchains/python/python_repo.bzl +++ b/tools/toolchains/python/python_repo.bzl @@ -1,7 +1,10 @@ """ -Repository rule to set python version. -Can be set via build parameter "--repo_env=TF_PYTHON_VERSION=3.10" +Repository rule to set python version and wheel name. + +Version can be set via build parameter "--repo_env=TF_PYTHON_VERSION=3.10" Defaults to 3.10. + +To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ VERSIONS = ["3.9", "3.10", "3.11", "3.12"] @@ -16,20 +19,24 @@ export TF_PYTHON_VERSION=3.11 content = """ TF_PYTHON_VERSION = "{}" HERMETIC_PYTHON_VERSION = "{}" +WHEEL_NAME = "{}" +WHEEL_COLLAB = "{}" """ def _python_repository_impl(repository_ctx): repository_ctx.file("BUILD", "") version = repository_ctx.os.environ.get("TF_PYTHON_VERSION", "") + wheel_name = repository_ctx.os.environ.get("WHEEL_NAME", "tensorflow") + wheel_collab = repository_ctx.os.environ.get("WHEEL_COLLAB", False) if version not in VERSIONS: print(WARNING) # buildifier: disable=print version = DEFAULT_VERSION repository_ctx.file( "py_version.bzl", - content.format(version, version), + content.format(version, version, wheel_name, wheel_collab), ) python_repository = repository_rule( implementation = _python_repository_impl, - environ = ["TF_PYTHON_VERSION"], + environ = ["TF_PYTHON_VERSION", "WHEEL_NAME", "WHEEL_COLLAB"], ) diff --git a/tools/toolchains/remote_config/configs.bzl b/tools/toolchains/remote_config/configs.bzl index 4554463cb9067..f55194087b127 100644 --- a/tools/toolchains/remote_config/configs.bzl +++ b/tools/toolchains/remote_config/configs.bzl @@ -178,6 +178,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.1-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.1", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.1-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.1", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9", compiler = "/usr/lib/llvm-17/bin/clang", @@ -200,6 +222,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_win_config( name = "windows_py37", python_bin_path = "C:/Python37/python.exe", @@ -605,11 +649,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0", - "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2", - "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc", + "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", + "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -633,7 +677,7 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.2", + "TF_CUDA_VERSION": "12.3", "TF_CUDNN_VERSION": "8.9", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", @@ -645,11 +689,11 @@ def initialize_rbe_configs(): sigbuild_tf_configs( name_container_map = { - "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505", - "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0", - "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2", - "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc", + "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:22d863e6fe3f98946015b9e1264b2eeb8e56e504535a6c1d5e564cae65ae5d37", + "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:da15288c8464153eadd35da720540a544b76aa9d78cceb42a6821b2f3e70a0fa", + "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", + "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 # and manylinux2014 is 2.17. @@ -672,7 +716,7 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.2", + "TF_CUDA_VERSION": "12.3", "TF_CUDNN_VERSION": "8.9", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", diff --git a/tools/toolchains/remote_config/containers.bzl b/tools/toolchains/remote_config/containers.bzl index bfb4634e81032..dd222d06bd13b 100644 --- a/tools/toolchains/remote_config/containers.bzl +++ b/tools/toolchains/remote_config/containers.bzl @@ -5,8 +5,10 @@ container_digests = { # TF now uses only this container "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821", # JAX manylinux2014 configs. - "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645", - "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb", + "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3", + "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63", + "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", @@ -91,6 +93,13 @@ containers = { "digest": container_digests["cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython. "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { "registry": "gcr.io", @@ -98,6 +107,13 @@ containers = { "digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython. "rocm-ubuntu18.04-manylinux2010-multipython": { "registry": "gcr.io", diff --git a/tools/toolchains/remote_config/rbe_config.bzl b/tools/toolchains/remote_config/rbe_config.bzl index b1488584566aa..18a84d96c39f8 100644 --- a/tools/toolchains/remote_config/rbe_config.bzl +++ b/tools/toolchains/remote_config/rbe_config.bzl @@ -1,12 +1,12 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") +load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") +load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") load("//tools/toolchains/remote_config:containers.bzl", "containers") -load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") def _container_image_uri(container_name): container = containers[container_name] diff --git a/tools/toolchains/win/bazel_211/BUILD b/tools/toolchains/win/bazel_211/BUILD index cc23c8ecb2268..c7484d2ae2efd 100644 --- a/tools/toolchains/win/bazel_211/BUILD +++ b/tools/toolchains/win/bazel_211/BUILD @@ -15,8 +15,8 @@ # This becomes the BUILD file for @local_config_cc// under Windows. load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") -load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) diff --git a/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl b/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl index 30571b6a5ace8..9ccc1706e5eca 100644 --- a/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl +++ b/tools/toolchains/win/bazel_211/windows_cc_toolchain_config.bzl @@ -14,6 +14,7 @@ """A Starlark cc_toolchain configuration rule for Windows""" +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") load( "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", "action_config", @@ -29,7 +30,6 @@ load( "variable_with_value", "with_feature_set", ) -load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") all_compile_actions = [ ACTION_NAMES.c_compile, diff --git a/tools/toolchains/win/tf_win_05022023/BUILD b/tools/toolchains/win/tf_win_05022023/BUILD index f245f6d0789c9..8a2ae6fe4a9dd 100644 --- a/tools/toolchains/win/tf_win_05022023/BUILD +++ b/tools/toolchains/win/tf_win_05022023/BUILD @@ -15,8 +15,8 @@ # This becomes the BUILD file for @local_config_cc// under Windows. load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") -load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) diff --git a/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl b/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl index ba3de607d1045..d6b966b32ceca 100644 --- a/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl +++ b/tools/toolchains/win/tf_win_05022023/windows_cc_toolchain_config.bzl @@ -14,6 +14,7 @@ """A Starlark cc_toolchain configuration rule for Windows""" +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") load( "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", "action_config", @@ -28,7 +29,6 @@ load( "variable_with_value", "with_feature_set", ) -load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") all_compile_actions = [ ACTION_NAMES.c_compile, diff --git a/warnings.bazelrc b/warnings.bazelrc new file mode 100644 index 0000000000000..00e9d3f58028d --- /dev/null +++ b/warnings.bazelrc @@ -0,0 +1,98 @@ +# This file is autogenerated! Do not edit! + +# Treat warnings as errors... +build:warnings --copt=-Werror --host_copt=-Werror +# ...and silence them outside of the workspace. +build:warnings --per_file_copt=external/.*@-w +# ...and silence them on host builds. There is no host_per_file_copt and +# everything we build in the host configuration we either also build in the +# target configuration or is external, so we can't control it. +# If/when Bazel supports --host_per_file_copt, we could use that instead: +# https://github.com/bazelbuild/bazel/issues/12406. +# Would need to then make all the --copt below duplicated with --host_copt. +build:warnings --host_copt=-w + +build:warnings --copt=-Wall +build:warnings --copt=-Werror +build:warnings --copt=-Wno-address-of-packed-member +build:warnings --copt=-Wno-defaulted-function-deleted +build:warnings --copt=-Wno-enum-compare-switch +build:warnings --copt=-Wno-expansion-to-defined +build:warnings --copt=-Wno-ignored-attributes +build:warnings --copt=-Wno-ignored-qualifiers +build:warnings --copt=-Wno-inconsistent-missing-override +build:warnings --copt=-Wno-potentially-evaluated-expression +build:warnings --copt=-Wno-range-loop-analysis +build:warnings --copt=-Wno-strict-prototypes +build:warnings --copt=-Wno-tautological-type-limit-compare +build:warnings --copt=-Wno-tautological-undefined-compare +build:warnings --copt=-Wno-tautological-unsigned-zero-compare +build:warnings --copt=-Wno-tautological-unsigned-enum-zero-compare +build:warnings --copt=-Wno-undefined-func-template +build:warnings --copt=-Wno-unused-but-set-variable +build:warnings --copt=-Wno-unused-lambda-capture +build:warnings --copt=-Wno-unused-local-typedef +build:warnings --copt=-Wno-deprecated-builtins +build:warnings --copt=-Wno-deprecated-volatile +build:warnings --copt=-Wno-deprecated-anon-enum-enum-conversion +build:warnings --copt=-Wno-deprecated-enum-compare +build:warnings --copt=-Wno-deprecated-enum-enum-conversion +build:warnings --copt=-Wno-deprecated-enum-compare-conditional +build:warnings --copt=-Wno-deprecated-enum-float-conversion +build:warnings --copt=-Wno-deprecated-this-capture +build:warnings --copt=-Wno-deprecated-array-compare +build:warnings --copt=-Wno-deprecated-comma-subscript +build:warnings --copt=-Wno-bitfield-constant-conversion +build:warnings --copt=-Wno-bitwise-instead-of-logical +build:warnings --copt=-Wno-comment +build:warnings --copt=-Wno-compound-token-split +build:warnings --copt=-Wno-deprecated-non-prototype +build:warnings --copt=-Wno-enum-constexpr-conversion +build:warnings --copt=-Wno-misleading-indentation +build:warnings --copt=-Wno-psabi +build:warnings --copt=-Wno-unqualified-std-cast-call +build:warnings --copt=-Wno-ambiguous-member-template +build:warnings --copt=-Wno-char-subscripts +build:warnings --copt=-Wno-deprecated-declarations +build:warnings --copt=-Wno-deprecated-pragma +build:warnings --copt=-Wno-extern-c-compat +build:warnings --copt=-Wno-gnu-alignof-expression +build:warnings --copt=-Wno-gnu-variable-sized-type-not-at-end +build:warnings --copt=-Wno-implicit-int-float-conversion +build:warnings --copt=-Wno-invalid-source-encoding +build:warnings --copt=-Wno-mismatched-tags +build:warnings --copt=-Wno-pointer-sign +build:warnings --copt=-Wno-private-header +build:warnings --copt=-Wno-sign-compare +build:warnings --copt=-Wno-strict-overflow +build:warnings --copt=-Wno-unknown-pragmas +build:warnings --copt=-Wno-unused-command-line-argument +build:warnings --copt=-Wno-unused-const-variable +build:warnings --copt=-Wno-unused-function +build:warnings --copt=-Wno-unused-private-field +build:warnings --copt=-Wno-user-defined-warnings +build:warnings --copt=-Wfloat-overflow-conversion +build:warnings --copt=-Wfloat-zero-conversion +build:warnings --copt=-Wfor-loop-analysis +build:warnings --copt=-Wgnu-redeclared-enum +build:warnings --copt=-Winfinite-recursion +build:warnings --copt=-Wself-assign +build:warnings --copt=-Wstring-conversion +build:warnings --copt=-Wtautological-overlap-compare +build:warnings --copt=-Wunused-but-set-parameter +build:warnings --copt=-Wunused-comparison +build:warnings --copt=-Wvla +build:warnings --copt=-Wno-return-type-c-linkage +build:warnings --copt=-Wno-self-assign-overloaded +build:warnings --copt=-Wctad-maybe-unsupported +build:warnings --copt=-Wthread-safety-beta +build:warnings --copt=-Wno-trigraphs +build:warnings --copt=-Woverloaded-virtual +build:warnings --copt=-Wno-invalid-offsetof +build:warnings --copt=-Wno-final-dtor-non-final-class +build:warnings --copt=-Wnon-virtual-dtor +build:warnings --copt=-Wimplicit-fallthrough +build:warnings --copt=-Wthread-safety-analysis +build:warnings --copt=-Wno-tautological-type-limit-compare +build:warnings --copt=-Wno-builtin-macro-redefined +build:warnings --copt=-Wno-macro-redefined diff --git a/workspace0.bzl b/workspace0.bzl index 73245fd6dd153..e59cdaf8a682d 100644 --- a/workspace0.bzl +++ b/workspace0.bzl @@ -1,11 +1,11 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" -load("@tsl//:workspace0.bzl", "tsl_workspace0") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_toolchains//repositories:repositories.bzl", bazel_toolchains_repositories = "repositories") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") -load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") +load("@tsl//:workspace0.bzl", "tsl_workspace0") def _tf_bind(): """Bind targets for some external repositories""" diff --git a/workspace1.bzl b/workspace1.bzl index 61b00e1675fdb..75180880898e3 100644 --- a/workspace1.bzl +++ b/workspace1.bzl @@ -1,10 +1,10 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" -load("@tsl//:workspace1.bzl", "tsl_workspace1") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") +load("@tsl//:workspace1.bzl", "tsl_workspace1") # buildifier: disable=unnamed-macro def workspace(): diff --git a/workspace2.bzl b/workspace2.bzl index c2f8547b043f9..08e44e9a40775 100644 --- a/workspace2.bzl +++ b/workspace2.bzl @@ -1,20 +1,30 @@ """TensorFlow workspace initialization. Consult the WORKSPACE on how to use it.""" -# Import TSL Workspaces -load("@tsl//:workspace2.bzl", "tsl_workspace2") - # Import third party config rules. load("@bazel_skylib//lib:versions.bzl", "versions") + +# Import TSL Workspaces +load("@tsl//:workspace2.bzl", "tsl_workspace2") load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. load("//third_party/dlpack:workspace.bzl", dlpack = "repo") +load("//third_party/flash_attn:workspace.bzl", flash_attn = "repo") +load("//third_party/gloo:workspace.bzl", gloo = "repo") +load("//third_party/mpitrampoline:workspace.bzl", mpitrampoline = "repo") +load("//third_party/nanobind:workspace.bzl", nanobind = "repo") +load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") load("//third_party/triton:workspace.bzl", triton = "repo") def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ dlpack() + flash_attn() + gloo() + mpitrampoline() + nanobind() + robin_map() stablehlo() triton() @@ -32,9 +42,17 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "d8dba9e2607a0c256aa8eacb45b39986ab6f3f24a4d431d4397047a3cb0cd4fb", - strip_prefix = "cudnn-frontend-0.9", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.9.zip"), + sha256 = "1bb309af98fe9aad81b6a14fd52acbd6566aacfd322fc5803f9a1b77fc681a27", + strip_prefix = "cudnn-frontend-1.2.1", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.2.1.zip"), + ) + + tf_http_archive( + name = "cutlass_archive", + build_file = "//third_party:cutlass.BUILD", + sha256 = "84cf3fcc47c440a8dde016eb458f8d6b93b3335d9c3a7a16f388333823f1eae0", + strip_prefix = "cutlass-afa7b7241aabe598b725c65480bd9fa71121732c", + urls = tf_mirror_urls("https://github.com/chsigg/cutlass/archive/afa7b7241aabe598b725c65480bd9fa71121732c.tar.gz"), ) tf_http_archive( @@ -84,6 +102,22 @@ def _tf_repositories(): #url = "http://www.tcs.hut.fi/Software/bliss/bliss-0.73.zip", ) + tf_http_archive( + name = "pybind11_protobuf", + urls = tf_mirror_urls("https://github.com/pybind/pybind11_protobuf/archive/80f3440cd8fee124e077e2e47a8a17b78b451363.zip"), + sha256 = "c7ab64b1ccf9a678694a89035a8c865a693e4e872803778f91f0965c2f281d78", + strip_prefix = "pybind11_protobuf-80f3440cd8fee124e077e2e47a8a17b78b451363", + ) + + # v3.4.1 + tf_http_archive( + name = "cutlass_for_flash_attn", + build_file = "//third_party:cutlass.BUILD", + sha256 = "9fa1da6be3d2d9207b801d5768cbced59c202444a8c84b82325b0670f47f9d48", + strip_prefix = "cutlass-bbe579a9e3beb6ea6626d9227ec32d0dae119a49", + urls = tf_mirror_urls("https://github.com/NVIDIA/cutlass/archive/bbe579a9e3beb6ea6626d9227ec32d0dae119a49.tar.gz"), + ) + # buildifier: disable=function-docstring # buildifier: disable=unnamed-macro def workspace(): diff --git a/xla/.clang-format b/xla/.clang-format deleted file mode 100644 index c2aa867556199..0000000000000 --- a/xla/.clang-format +++ /dev/null @@ -1,3 +0,0 @@ -BasedOnStyle: Google -Language: Cpp -PointerBindsToType: true diff --git a/xla/BUILD b/xla/BUILD index 04befaee8f0f5..d84147f3fbe5e 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -1,16 +1,18 @@ -# Placeholder: load py_proto_library -load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") -load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@tsl//tsl/platform:build_config.bzl", "tf_proto_library", ) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") + +# Placeholder: load py_proto_library +load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -24,6 +26,8 @@ package_group( "//third_party/mira/...", "//third_party/mlcompass/...", "//third_party/mlir_edge/model_curriculum/...", + "//third_party/openxla/shardonnay/...", + "//third_party/py/enzyme_ad/...", "//third_party/py/jax/...", "//third_party/py/t5x/...", "//third_party/py/tpu_graphs/...", @@ -47,7 +51,9 @@ package_group( ], ) -exports_files(["run_lit.sh"]) +exports_files([ + "lit.cfg.py", +]) # Filegroup used to collect source files for dependency checking. filegroup( @@ -64,7 +70,7 @@ filegroup( "cpu_function_runtime.cc", "executable_run_options.cc", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) filegroup( @@ -74,7 +80,7 @@ filegroup( "executable_run_options.h", "types.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) tf_proto_library( @@ -102,7 +108,7 @@ tf_proto_library( cc_library( name = "bit_cast", hdrs = ["bit_cast.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":types", "@com_google_absl//absl/base", @@ -132,7 +138,7 @@ cc_library( "comparison_util.h", "primitive_util.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":shape_util", ":statusor", @@ -142,8 +148,8 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", ], ) @@ -159,6 +165,38 @@ xla_cc_test( ], ) +cc_library( + name = "compiler_macros", + hdrs = ["compiler_macros.h"], + visibility = internal_visibility([":friends"]), +) + +cc_library( + name = "ef57", + srcs = ["ef57.cc"], + hdrs = ["ef57.h"], + visibility = internal_visibility([":friends"]), + deps = [ + ":compiler_macros", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "ef57_test", + srcs = ["ef57_test.cc"], + deps = [ + ":ef57", + ":test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:log_streamer", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "execution_options_util", srcs = [ @@ -167,7 +205,7 @@ cc_library( hdrs = [ "execution_options_util.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":debug_options_flags", ":xla_proto_cc", @@ -182,7 +220,7 @@ cc_library( hdrs = [ "frontend_attributes.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = ["//xla/hlo/ir:hlo"], ) @@ -190,7 +228,7 @@ cc_library( name = "test", testonly = 1, hdrs = ["test.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ "@tsl//tsl/platform", "@tsl//tsl/platform:test", @@ -201,11 +239,10 @@ cc_library( name = "types", hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ - "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", - "@ml_dtypes//:int4", + "@tsl//tsl/platform:ml_dtypes", ], ) @@ -213,9 +250,7 @@ xla_cc_test( name = "types_test", size = "small", srcs = ["types_test.cc"], - visibility = [ - "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. - ], + visibility = ["//visibility:private"], deps = [ ":test", ":types", @@ -227,7 +262,7 @@ cc_library( name = "service_interface", srcs = [], hdrs = ["service_interface.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":status", ":xla_data_proto_cc", @@ -270,7 +305,9 @@ cc_library( hdrs = ["status.h"], visibility = ["//visibility:public"], deps = [ - "@tsl//tsl/platform:status", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) @@ -325,8 +362,8 @@ cc_library( "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:stacktrace", ], @@ -339,8 +376,8 @@ xla_cc_test( ":test", ":types", ":util", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", ], ) @@ -436,9 +473,9 @@ cc_library( "@com_google_absl//absl/types:span", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -466,7 +503,6 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", - "//xla:status", "@com_google_absl//absl/hash:hash_testing", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", @@ -566,7 +602,10 @@ cc_library( ":types", ":util", ":xla_data_proto_cc", + "//xla/tsl/util:byte_swap_array", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", @@ -575,13 +614,12 @@ cc_library( "@eigen_archive//:eigen3", "@tsl//tsl/lib/core:bitmap", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - "@tsl//tsl/util:byte_swap_array", ], ) @@ -599,15 +637,17 @@ xla_cc_test( ":status", ":test", ":types", + ":util", ":xla_data_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", @@ -636,8 +676,8 @@ cc_library( "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:bitmap", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:status", ], ) @@ -645,14 +685,14 @@ cc_library( cc_library( name = "error_spec", hdrs = ["error_spec.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) cc_library( name = "literal_comparison", srcs = ["literal_comparison.cc"], hdrs = ["literal_comparison.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":error_spec", ":literal", @@ -668,8 +708,8 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", ], ) @@ -701,7 +741,7 @@ cc_library( name = "array", srcs = ["array.cc"], hdrs = ["array.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":status", ":types", @@ -747,7 +787,7 @@ xla_cc_test( cc_library( name = "array3d", hdrs = ["array3d.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":array", ":types", @@ -769,7 +809,7 @@ xla_cc_test( cc_library( name = "array4d", hdrs = ["array4d.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":array", ":array2d", @@ -807,7 +847,7 @@ cc_library( name = "packed_literal_reader", srcs = ["packed_literal_reader.cc"], hdrs = ["packed_literal_reader.h"], - visibility = [":internal"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -827,13 +867,12 @@ cc_library( name = "test_helpers", testonly = 1, hdrs = ["test_helpers.h"], - visibility = [":internal"], + visibility = internal_visibility([":friends"]), deps = [ + ":status", ":statusor", ":types", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:protobuf", - "@tsl//tsl/platform:regexp", "@tsl//tsl/platform:test", ], ) @@ -842,7 +881,7 @@ cc_library( name = "text_literal_reader", srcs = ["text_literal_reader.cc"], hdrs = ["text_literal_reader.h"], - visibility = [":internal"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -879,7 +918,7 @@ cc_library( name = "text_literal_writer", srcs = ["text_literal_writer.cc"], hdrs = ["text_literal_writer.h"], - visibility = [":internal"], + visibility = internal_visibility([":friends"]), deps = [ ":literal", ":shape_util", @@ -1031,12 +1070,12 @@ cc_library( deps = [ ":types", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", - "@tsl//tsl/util:command_line_flags", ], ) @@ -1046,12 +1085,12 @@ xla_cc_test( deps = [ ":parse_flags_from_env", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:subprocess", "@tsl//tsl/platform:test", - "@tsl//tsl/util:command_line_flags", ], ) @@ -1063,11 +1102,14 @@ cc_library( ], hdrs = ["debug_options_flags.h"], copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]), - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":parse_flags_from_env", ":xla_proto_cc", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", @@ -1075,7 +1117,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:protobuf", - "@tsl//tsl/util:command_line_flags", ], ) @@ -1084,7 +1125,7 @@ cc_library( srcs = ["cpu_function_runtime.cc"], hdrs = ["cpu_function_runtime.h"], compatible_with = get_compatible_with_portable(), - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/base:dynamic_annotations", ], @@ -1182,24 +1223,13 @@ cc_library( ], ) -filegroup( - name = "litfiles", - srcs = [ - "runlit.cfg.py", - "runlit.site.cfg.py", - ], - visibility = [ - "//xla:__subpackages__", # Scheuklappen: keep - ], -) - # ----------------------------------------------------------------------------- # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xla_data_proto_py_pb2", # api_version = 2, -# visibility = [":friends"], +# visibility = internal_visibility([":friends"]), # deps = [":xla_data_proto"], # ) # @@ -1207,8 +1237,8 @@ filegroup( # name = "xla_py_pb2", # testonly = 0, # api_version = 2, -# compatible_with = ["//buildenv/target:gce"], -# visibility = [":friends"], +# compatible_with = ["//buildenv/target:non_prod"], +# visibility = internal_visibility([":friends"]), # deps = [":xla_proto"], # ) # copybara:uncomment_end @@ -1217,3 +1247,10 @@ cc_library( name = "empty", visibility = ["//visibility:public"], ) + +# Needed to workaround https://github.com/bazelbuild/bazel/issues/21519 +alias( + name = "bazel_issue_21519", + actual = ":empty", + visibility = ["//visibility:public"], +) diff --git a/xla/README.md b/xla/README.md index 6138b17feaac1..b250b82b1d3fc 100644 --- a/xla/README.md +++ b/xla/README.md @@ -4,7 +4,7 @@ XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes TensorFlow computations. See the -[documentation](./g3doc/index.md). +[documentation](./../docs/index.md). This directory is currently migrating to [OpenXLA](https://github.com/openxla/) and will be the root of the [openxla/xla](https://github.com/openxla/xla) diff --git a/xla/array.cc b/xla/array.cc index 04140243d10c6..da35c8eb83f48 100644 --- a/xla/array.cc +++ b/xla/array.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array.h b/xla/array.h index eebc13b4bef26..145cb2856b106 100644 --- a/xla/array.h +++ b/xla/array.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array2d.h b/xla/array2d.h index 8f86c09d4fc37..2e8c1547a967a 100644 --- a/xla/array2d.h +++ b/xla/array2d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 8514bd863584f..b7052d1c33f72 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array3d.h b/xla/array3d.h index 6c8b81d165ec1..ba3147476f0a7 100644 --- a/xla/array3d.h +++ b/xla/array3d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array3d_test.cc b/xla/array3d_test.cc index e79d59c996c4e..07599797fbbb0 100644 --- a/xla/array3d_test.cc +++ b/xla/array3d_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array4d.h b/xla/array4d.h index 0b4549dd0a859..816a1fd976da9 100644 --- a/xla/array4d.h +++ b/xla/array4d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array4d_test.cc b/xla/array4d_test.cc index adbf3135cb476..2285b38230d51 100644 --- a/xla/array4d_test.cc +++ b/xla/array4d_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/array_test.cc b/xla/array_test.cc index 8c764998fd65b..c033271d38894 100644 --- a/xla/array_test.cc +++ b/xla/array_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/autotune_results.proto b/xla/autotune_results.proto index 70e981364f4a1..0244dc6fa492d 100644 --- a/xla/autotune_results.proto +++ b/xla/autotune_results.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/autotuning.proto b/xla/autotuning.proto index b177fe3e2895e..9a867ab7e2177 100644 --- a/xla/autotuning.proto +++ b/xla/autotuning.proto @@ -82,6 +82,7 @@ message AutotuneResult { int64 split_k = 4; int64 num_stages = 5; int64 num_warps = 6; + int64 num_ctas = 7; } int64 scratch_bytes = 8; diff --git a/xla/backends/interpreter/BUILD b/xla/backends/interpreter/BUILD index 08123dbbe365c..4f39603307847 100644 --- a/xla/backends/interpreter/BUILD +++ b/xla/backends/interpreter/BUILD @@ -15,10 +15,9 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//xla/backends/interpreter:platform_id", + ":platform_id", "//xla/service:generic_transfer_manager", "//xla/service:transfer_manager", - "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -120,12 +119,9 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor", - "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:mutex", ], ) @@ -140,8 +136,9 @@ cc_library( "//xla/stream_executor/platform", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( @@ -156,7 +153,7 @@ cc_library( "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/host:host_stream", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:logging", ], ) diff --git a/xla/backends/interpreter/compiler.cc b/xla/backends/interpreter/compiler.cc index 3b89c3b6054de..e48f63f77929a 100644 --- a/xla/backends/interpreter/compiler.cc +++ b/xla/backends/interpreter/compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ namespace { // Handles custom_call ops during evaluation by routing them through the global // CPU registry used by other CPU-based backends. -StatusOr HandleEvaluatorCustomCall( +absl::StatusOr HandleEvaluatorCustomCall( const HloInstruction* custom_call, absl::Span operands) { // Find the target C function in the global registry. auto* registry = CustomCallTargetRegistry::Global(); @@ -92,12 +92,13 @@ StatusOr HandleEvaluatorCustomCall( Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( /*rewrite_training_op=*/true, @@ -109,7 +110,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> InterpreterCompiler::RunHloPasses( +absl::StatusOr> InterpreterCompiler::RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, const CompileOptions& /*options*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); @@ -117,7 +118,7 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } -StatusOr> InterpreterCompiler::RunBackend( +absl::StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& /*options*/) { TF_RET_CHECK(stream_exec != nullptr); @@ -146,7 +147,8 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } -StatusOr>> InterpreterCompiler::Compile( +absl::StatusOr>> +InterpreterCompiler::Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) { @@ -170,7 +172,7 @@ StatusOr>> InterpreterCompiler::Compile( return std::move(ret); } -StatusOr>> +absl::StatusOr>> InterpreterCompiler::CompileAheadOfTime( std::unique_ptr module_group, const AotCompilationOptions& aot_options) { diff --git a/xla/backends/interpreter/compiler.h b/xla/backends/interpreter/compiler.h index 218160ff132d8..cfdbaa4fd2392 100644 --- a/xla/backends/interpreter/compiler.h +++ b/xla/backends/interpreter/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,18 +42,18 @@ class InterpreterCompiler : public Compiler { InterpreterCompiler() {} ~InterpreterCompiler() override {} - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) override; - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) override; diff --git a/xla/backends/interpreter/executable.cc b/xla/backends/interpreter/executable.cc index 9bcb7c0908be0..bd7b6261a8515 100644 --- a/xla/backends/interpreter/executable.cc +++ b/xla/backends/interpreter/executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ InterpreterExecutable::InterpreterExecutable( } } -StatusOr InterpreterExecutable::Evaluate( +absl::StatusOr InterpreterExecutable::Evaluate( const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) { // Execute the graph using the HloEvaluator. diff --git a/xla/backends/interpreter/executable.h b/xla/backends/interpreter/executable.h index fcfb89c278615..4a66f3bb375e2 100644 --- a/xla/backends/interpreter/executable.h +++ b/xla/backends/interpreter/executable.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,9 +48,10 @@ class InterpreterExecutable : public InterpreterExecutableBase { static int64_t ShapeSizeBytes(const Shape& shape); protected: - StatusOr Evaluate(const ServiceExecutableRunOptions* run_options, - const HloComputation& computation, - absl::Span arg_literals) override + absl::StatusOr Evaluate( + const ServiceExecutableRunOptions* run_options, + const HloComputation& computation, + absl::Span arg_literals) override ABSL_LOCKS_EXCLUDED(evaluator_lock_); // The interpreter interprets executables with an HloEvaluator. diff --git a/xla/backends/interpreter/executable_base.cc b/xla/backends/interpreter/executable_base.cc index bcac33eff56b0..78443e635ad73 100644 --- a/xla/backends/interpreter/executable_base.cc +++ b/xla/backends/interpreter/executable_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ InterpreterExecutableBase::InterpreterExecutableBase( : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, /*hlo_profile_index_map=*/nullptr) {} -StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( +absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { @@ -150,14 +150,15 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( return std::move(result); } -StatusOr +absl::StatusOr InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, std::vector* arguments, se::Stream* stream) { TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( [&](const ShapeIndex& output_index, - std::optional alias) { + std::optional alias) + -> absl::Status { if (alias && alias->must_alias()) { VLOG(1) << alias->ToString(); const MaybeOwningDeviceMemory& original_input = @@ -187,8 +188,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( result.Result().on_device_shape(), result_index)); if (!ShapeUtil::IndexIsValid(alias_config.shape(), result_index)) { - return InternalError("result_index is invalid: %s", - result_index.ToString()); + return Internal("result_index is invalid: %s", result_index.ToString()); } std::optional alias = diff --git a/xla/backends/interpreter/executable_base.h b/xla/backends/interpreter/executable_base.h index db3ee23c0d355..fa55e56746443 100644 --- a/xla/backends/interpreter/executable_base.h +++ b/xla/backends/interpreter/executable_base.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,19 +37,19 @@ class InterpreterExecutableBase : public Executable { public: explicit InterpreterExecutableBase(std::unique_ptr hlo_module); - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; protected: - virtual StatusOr Evaluate( + virtual absl::StatusOr Evaluate( const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) = 0; private: - StatusOr AllocateOutputMemoryWithInputReuse( + absl::StatusOr AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, std::vector* arguments, stream_executor::Stream* stream); diff --git a/xla/backends/interpreter/executor.cc b/xla/backends/interpreter/executor.cc index 3766f7cb7af82..26c548e009029 100644 --- a/xla/backends/interpreter/executor.cc +++ b/xla/backends/interpreter/executor.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/functional/any_invocable.h" +#include "absl/log/log.h" #include "xla/status_macros.h" namespace stream_executor { @@ -34,66 +35,45 @@ DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64_t size, return DeviceMemoryBase(new char[size], size); } -void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent, - uint64_t offset_bytes, - uint64_t /*size_bytes*/) { - return parent + offset_bytes; -} - void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { delete[] static_cast(mem->opaque()); } -bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, - uint64_t size) { +absl::Status XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &dev_src, + uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { // Ignore errors. - tsl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); + absl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); - tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); - if (status.ok()) { - return true; - } - - // TODO(b/199316985): Return 'tsl::Status' instead of 'bool', so we don't need - // to throw away error information here. - LOG(WARNING) << "Memcpy: error on stream: " << status; - return false; + return AsExecutorStream(stream)->BlockUntilDone(); } -bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, - const void *host_src, uint64_t size) { +absl::Status XlaInterpreterExecutor::Memcpy(Stream *stream, + DeviceMemoryBase *dev_dst, + const void *host_src, + uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { // Ignore errors. - tsl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); + absl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); - tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); - if (status.ok()) { - return true; - } - - // TODO(b/199316985): Return 'tsl::Status' instead of 'bool', so we don't need - // to throw away error information here. - LOG(WARNING) << "Memcpy: error on stream: " << status; - return false; + return AsExecutorStream(stream)->BlockUntilDone(); } -tsl::Status XlaInterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, - const void *host_src, - uint64_t size) { +absl::Status XlaInterpreterExecutor::SynchronousMemcpy( + DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { memcpy(dev_dst->opaque(), host_src, size); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status XlaInterpreterExecutor::SynchronousMemcpy( +absl::Status XlaInterpreterExecutor::SynchronousMemcpy( void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) { memcpy(host_dst, dev_src.opaque(), size); - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool XlaInterpreterExecutor::HostCallback( - Stream *stream, absl::AnyInvocable callback) { + Stream *stream, absl::AnyInvocable callback) { AsExecutorStream(stream)->EnqueueTaskWithStatus(std::move(callback)); return true; } @@ -102,7 +82,7 @@ bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTaskWithStatus( [other]() { return other->BlockHostUntilDone(); }); - tsl::Status status = AsExecutorStream(dependent)->BlockUntilDone(); + absl::Status status = AsExecutorStream(dependent)->BlockUntilDone(); if (status.ok()) { return true; } @@ -113,11 +93,11 @@ bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, return false; } -tsl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { +absl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { return AsExecutorStream(stream)->BlockUntilDone(); } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) { internal::DeviceDescriptionBuilder builder; diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index 683d83a15d58e..31c1c03ea72d6 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,6 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_stream.h" #include "xla/stream_executor/kernel.h" @@ -47,25 +46,23 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: XlaInterpreterExecutor() = default; - tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { + absl::Status Init(int device_ordinal) override { device_ordinal_ = device_ordinal; - return ::tsl::OkStatus(); + return absl::OkStatus(); } int device_ordinal() const override { return device_ordinal_; }; - tsl::Status GetKernel(const MultiKernelLoaderSpec &spec, - Kernel *kernel) override { + absl::Status GetKernel(const MultiKernelLoaderSpec &spec, + Kernel *kernel) override { return tsl::errors::Unimplemented("Not Implemented"); } - tsl::Status Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &kernel, - const KernelArgs &args) override { + absl::Status Launch(Stream *stream, const ThreadDim &thread_dims, + const BlockDim &block_dims, const Kernel &kernel, + const KernelArgs &args) override { return tsl::errors::Unimplemented("Not Implemented"); } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void *GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes, - uint64_t size_bytes) override; void Deallocate(DeviceMemoryBase *mem) override; void *HostMemoryAllocate(uint64_t size) override { return new char[size]; } @@ -75,66 +72,67 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { bool HostMemoryRegister(void *mem, uint64_t size) override { return true; } bool HostMemoryUnregister(void *mem) override { return true; } - bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src, - uint64_t size) override; - bool Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src, - uint64_t size) override; + absl::Status Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &dev_src, uint64_t size) override; + absl::Status Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, + const void *host_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, const DeviceMemoryBase &host_src, uint64_t size) override { return false; } - tsl::Status MemZero(Stream *stream, DeviceMemoryBase *location, - uint64_t size) override { + absl::Status MemZero(Stream *stream, DeviceMemoryBase *location, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memzero"); } - tsl::Status Memset(Stream *stream, DeviceMemoryBase *location, - uint8_t pattern, uint64_t size) override { + absl::Status Memset(Stream *stream, DeviceMemoryBase *location, + uint8_t pattern, uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } - tsl::Status Memset32(Stream *stream, DeviceMemoryBase *location, - uint32_t pattern, uint64_t size) override { + absl::Status Memset32(Stream *stream, DeviceMemoryBase *location, + uint32_t pattern, uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } - tsl::Status SynchronousMemZero(DeviceMemoryBase *location, - uint64_t size) override { + absl::Status SynchronousMemZero(DeviceMemoryBase *location, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memzero"); } - tsl::Status SynchronousMemSet(DeviceMemoryBase *location, int value, - uint64_t size) override { + absl::Status SynchronousMemSet(DeviceMemoryBase *location, int value, + uint64_t size) override { return tsl::errors::Internal("Interpreter can not memset"); } - tsl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, const void *host_src, - uint64_t size) override; - tsl::Status SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &dev_src, - uint64_t size) override; - tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, - const DeviceMemoryBase &pop_src, - uint64_t size) override { - return tsl::Status{absl::StatusCode::kUnimplemented, ""}; + absl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, + const void *host_src, uint64_t size) override; + absl::Status SynchronousMemcpy(void *host_dst, + const DeviceMemoryBase &dev_src, + uint64_t size) override; + absl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, + const DeviceMemoryBase &pop_src, + uint64_t size) override { + return absl::Status{absl::StatusCode::kUnimplemented, ""}; } bool HostCallback(Stream *stream, - absl::AnyInvocable callback) override; + absl::AnyInvocable callback) override; - tsl::Status AllocateEvent(Event *event) override { return ::tsl::OkStatus(); } + absl::Status AllocateEvent(Event *event) override { return absl::OkStatus(); } - tsl::Status DeallocateEvent(Event *event) override { - return ::tsl::OkStatus(); + absl::Status DeallocateEvent(Event *event) override { + return absl::OkStatus(); } - tsl::Status RecordEvent(Stream *stream, Event *event) override { - return tsl::Status{absl::StatusCode::kUnimplemented, "RecordEvent"}; + absl::Status RecordEvent(Stream *stream, Event *event) override { + return absl::Status{absl::StatusCode::kUnimplemented, "RecordEvent"}; } - tsl::Status WaitForEvent(Stream *stream, Event *event) override { - return tsl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"}; + absl::Status WaitForEvent(Stream *stream, Event *event) override { + return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"}; } Event::Status PollForEventStatus(Event *event) override { @@ -145,22 +143,22 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { void DeallocateStream(Stream *stream) override {} bool CreateStreamDependency(Stream *dependent, Stream *other) override; - tsl::Status BlockHostUntilDone(Stream *stream) override; + absl::Status BlockHostUntilDone(Stream *stream) override; bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override { return false; } - tsl::StatusOr> CreateDeviceDescription() + absl::StatusOr> CreateDeviceDescription() const override { return CreateDeviceDescription(0); } - static tsl::StatusOr> + static absl::StatusOr> CreateDeviceDescription(int device_ordinal); - tsl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { - return ::tsl::OkStatus(); + absl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { + return absl::OkStatus(); } bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override { @@ -172,15 +170,9 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return nullptr; } - std::unique_ptr CreateKernelImplementation() - override { - return nullptr; - } - std::unique_ptr GetStreamImplementation() override { - return std::unique_ptr( - new host::HostStream(/*thread_stack_size=*/0)); + return std::make_unique(); } private: @@ -190,7 +182,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); - tsl::StatusOr AllocateOutputBuffer(const xla::Shape &shape); + absl::StatusOr AllocateOutputBuffer( + const xla::Shape &shape); }; } // namespace interpreter diff --git a/xla/backends/interpreter/interpreter_transfer_manager.cc b/xla/backends/interpreter/interpreter_transfer_manager.cc index b0876c02edac3..6f0d193b3562b 100644 --- a/xla/backends/interpreter/interpreter_transfer_manager.cc +++ b/xla/backends/interpreter/interpreter_transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/interpreter/interpreter_transfer_manager.h b/xla/backends/interpreter/interpreter_transfer_manager.h index 832121dd16fd0..86bdd55d2583c 100644 --- a/xla/backends/interpreter/interpreter_transfer_manager.h +++ b/xla/backends/interpreter/interpreter_transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/interpreter/platform.cc b/xla/backends/interpreter/platform.cc index b79a42e332763..0a9681ee0840f 100644 --- a/xla/backends/interpreter/platform.cc +++ b/xla/backends/interpreter/platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,10 +20,9 @@ limitations under the License. #include "absl/strings/str_format.h" #include "xla/backends/interpreter/executor.h" -#include "xla/stream_executor/device_options.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/status.h" namespace stream_executor { @@ -41,33 +40,32 @@ int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } const std::string& XlaInterpreterPlatform::Name() const { return name_; } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); } -tsl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( +absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; - config.device_options = DeviceOptions::Default(); return GetExecutor(config); } -tsl::StatusOr XlaInterpreterPlatform::GetExecutor( +absl::StatusOr XlaInterpreterPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } -tsl::StatusOr> +absl::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { auto executor = std::make_unique( this, std::make_unique(), config.ordinal); - auto init_status = executor->Init(config.device_options); + auto init_status = executor->Init(); if (!init_status.ok()) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", @@ -79,19 +77,12 @@ XlaInterpreterPlatform::GetUncachedExecutor( static void InitializeXlaInterpreterPlatform() { std::unique_ptr platform(new XlaInterpreterPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace interpreter } // namespace stream_executor -REGISTER_MODULE_INITIALIZER( +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); - -// Note that module initialization sequencing is not supported in the -// open-source project, so this will be a no-op there. -REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, - multi_platform_manager); -REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, - interpreter_platform); diff --git a/xla/backends/interpreter/platform.h b/xla/backends/interpreter/platform.h index 2833d09969339..c81f7f7f2fd60 100644 --- a/xla/backends/interpreter/platform.h +++ b/xla/backends/interpreter/platform.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ limitations under the License. #include "xla/backends/interpreter/platform_id.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/trace_listener.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace interpreter { @@ -39,15 +39,15 @@ class XlaInterpreterPlatform : public Platform { const std::string& Name() const override; - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; private: diff --git a/xla/backends/interpreter/platform_id.cc b/xla/backends/interpreter/platform_id.cc index 459f78ba56310..c788860207736 100644 --- a/xla/backends/interpreter/platform_id.cc +++ b/xla/backends/interpreter/platform_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/interpreter/platform_id.h b/xla/backends/interpreter/platform_id.h index 2e8caf79b9dd5..7552d9c63be31 100644 --- a/xla/backends/interpreter/platform_id.h +++ b/xla/backends/interpreter/platform_id.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/BUILD b/xla/backends/profiler/BUILD index 7b96d8086a94b..c4c63968769b1 100644 --- a/xla/backends/profiler/BUILD +++ b/xla/backends/profiler/BUILD @@ -1,4 +1,9 @@ -load("@tsl//tsl:tsl.bzl", "if_with_tpu_support", "tsl_gpu_library") +load( + "@tsl//tsl:tsl.bzl", + "if_with_tpu_support", + "internal_visibility", + "tsl_gpu_library", +) # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -13,7 +18,7 @@ package_group( tsl_gpu_library( name = "profiler_backends", - visibility = ["//xla:internal"], + visibility = internal_visibility(["//xla:internal"]), deps = [ "//xla/backends/profiler/cpu:host_tracer", "//xla/backends/profiler/cpu:metadata_collector", diff --git a/xla/backends/profiler/cpu/BUILD b/xla/backends/profiler/cpu/BUILD index 919d27b7db93d..f0154fed4eb80 100644 --- a/xla/backends/profiler/cpu/BUILD +++ b/xla/backends/profiler/cpu/BUILD @@ -1,18 +1,24 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") +load( + "//xla:xla.bzl", + "xla_cc_test", +) # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) cc_library( name = "host_tracer", srcs = ["host_tracer_factory.cc"], - visibility = [ + visibility = internal_visibility([ "//xla/backends/profiler:__pkg__", # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + ]), deps = [ ":host_tracer_impl", "@tsl//tsl/profiler/lib:profiler_factory", + "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ], alwayslink = True, @@ -23,13 +29,13 @@ cc_library( srcs = ["host_tracer.cc"], hdrs = ["host_tracer.h"], copts = tf_profiler_copts(), - visibility = [ + visibility = internal_visibility([ # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + ]), deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:types", "@tsl//tsl/profiler/backends/cpu:host_tracer_utils", "@tsl//tsl/profiler/backends/cpu:traceme_recorder", "@tsl//tsl/profiler/lib:profiler_interface", @@ -43,13 +49,15 @@ cc_library( cc_library( name = "python_tracer", srcs = ["python_tracer_factory.cc"], - visibility = [ + visibility = internal_visibility([ "//xla/python:__pkg__", # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + # copybara:uncomment "//tensorflow:internal", + ]), deps = [ ":python_tracer_impl", "@tsl//tsl/profiler/lib:profiler_factory", + "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ], alwayslink = True, @@ -61,15 +69,14 @@ cc_library( hdrs = ["python_tracer.h"], copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], - visibility = [ + visibility = internal_visibility([ # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + ]), deps = [ "//xla/python/profiler/internal:python_hooks", + "@com_google_absl//absl/status", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:status", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", ], @@ -79,20 +86,20 @@ cc_library( name = "metadata_collector", srcs = ["metadata_collector.cc"], copts = tf_profiler_copts(), - visibility = [ + visibility = internal_visibility([ "//xla/backends/profiler:__pkg__", # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + ]), deps = [ ":metadata_utils", + "//xla:status", "//xla/service:hlo_proto_cc", "//xla/service:xla_debug_info_manager", - "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:status", "@tsl//tsl/profiler/lib:profiler_factory", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@tsl//tsl/profiler/utils:xplane_builder", "@tsl//tsl/profiler/utils:xplane_schema", "@tsl//tsl/profiler/utils:xplane_utils", ], @@ -102,9 +109,9 @@ cc_library( cc_library( name = "metadata_utils", hdrs = ["metadata_utils.h"], - visibility = [ + visibility = internal_visibility([ # copybara:uncomment "//tensorflow/core/profiler:internal", - ], + ]), deps = [ "//xla/service:hlo_proto_cc", "@tsl//tsl/profiler/convert:xla_op_utils", @@ -113,3 +120,22 @@ cc_library( "@tsl//tsl/profiler/utils:xplane_schema", ], ) + +xla_cc_test( + name = "host_tracer_test", + srcs = ["host_tracer_test.cc"], + deps = [ + ":host_tracer_impl", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:types", + "@tsl//tsl/profiler/lib:profiler_interface", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@tsl//tsl/profiler/utils:xplane_schema", + "@tsl//tsl/profiler/utils:xplane_visitor", + ], +) diff --git a/xla/backends/profiler/cpu/host_tracer.cc b/xla/backends/profiler/cpu/host_tracer.cc index f4d4dfd9a2284..ad79fb1c398cc 100644 --- a/xla/backends/profiler/cpu/host_tracer.cc +++ b/xla/backends/profiler/cpu/host_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" #include "tsl/profiler/backends/cpu/host_tracer_utils.h" #include "tsl/profiler/backends/cpu/traceme_recorder.h" #include "tsl/profiler/lib/profiler_interface.h" @@ -42,11 +42,11 @@ class HostTracer : public tsl::profiler::ProfilerInterface { explicit HostTracer(int host_trace_level); ~HostTracer() override; - tsl::Status Start() override; // TENSORFLOW_STATUS_OK + absl::Status Start() override; // TENSORFLOW_STATUS_OK - tsl::Status Stop() override; // TENSORFLOW_STATUS_OK + absl::Status Stop() override; // TENSORFLOW_STATUS_OK - tsl::Status CollectData( // TENSORFLOW_STATUS_OK + absl::Status CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) override; private: @@ -68,7 +68,7 @@ HostTracer::HostTracer(int host_trace_level) HostTracer::~HostTracer() { Stop().IgnoreError(); } // NOLINT -tsl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK +absl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK if (recording_) { return tsl::errors::Internal("TraceMeRecorder already started"); } @@ -81,33 +81,33 @@ tsl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("Failed to start TraceMeRecorder"); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status HostTracer::Stop() { // TENSORFLOW_STATUS_OK +absl::Status HostTracer::Stop() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("TraceMeRecorder not started"); } events_ = tsl::profiler::TraceMeRecorder::Stop(); recording_ = false; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK +absl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) { VLOG(2) << "Collecting data to XSpace from HostTracer."; if (recording_) { return tsl::errors::Internal("TraceMeRecorder not stopped"); } if (events_.empty()) { - return tsl::OkStatus(); + return absl::OkStatus(); } tensorflow::profiler::XPlane* plane = tsl::profiler::FindOrAddMutablePlaneWithName( space, tsl::profiler::kHostThreadsPlaneName); ConvertCompleteEventsToXPlane(start_timestamp_ns_, std::exchange(events_, {}), plane); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/xla/backends/profiler/cpu/host_tracer.h b/xla/backends/profiler/cpu/host_tracer.h index fb2bdd20fedb3..975941086d1c9 100644 --- a/xla/backends/profiler/cpu/host_tracer.h +++ b/xla/backends/profiler/cpu/host_tracer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/cpu/host_tracer_factory.cc b/xla/backends/profiler/cpu/host_tracer_factory.cc index 414ce198e822f..9732ace9dc8f7 100644 --- a/xla/backends/profiler/cpu/host_tracer_factory.cc +++ b/xla/backends/profiler/cpu/host_tracer_factory.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/profiler/cpu/host_tracer.h" #include "tsl/profiler/lib/profiler_factory.h" +#include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" namespace xla { diff --git a/xla/backends/profiler/cpu/host_tracer_test.cc b/xla/backends/profiler/cpu/host_tracer_test.cc new file mode 100644 index 0000000000000..881f46e50837f --- /dev/null +++ b/xla/backends/profiler/cpu/host_tracer_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/backends/profiler/cpu/host_tracer.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/test.h" +#include "tsl/platform/types.h" +#include "tsl/profiler/lib/profiler_interface.h" +#include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_visitor.h" + +namespace xla { +namespace profiler { +namespace { + +using ::tsl::Env; +using ::tsl::Thread; +using ::tsl::ThreadOptions; +using ::tsl::profiler::TraceMe; +using ::tsl::profiler::XEventVisitor; +using ::tsl::profiler::XPlaneVisitor; +using ::tsl::profiler::XStatVisitor; + +TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { + tsl::uint32 thread_id; + std::string thread_name = "MyThreadName"; + tensorflow::profiler::XSpace space; + + // We start a thread with a known and controlled name. As of the time of + // writing, not all platforms (example: Windows) allow reading through the + // system to the current thread name/description. By starting a thread with a + // name, we control this behavior entirely within the TensorFlow subsystems. + std::unique_ptr traced_thread( + Env::Default()->StartThread(ThreadOptions(), thread_name, [&] { + // Some implementations add additional information to the thread name. + // Recapture this information. + ASSERT_TRUE(Env::Default()->GetCurrentThreadName(&thread_name)); + thread_id = Env::Default()->GetCurrentThreadId(); + + auto tracer = CreateHostTracer({}); + + TF_ASSERT_OK(tracer->Start()); + { TraceMe traceme("hello"); } + { TraceMe traceme("world"); } + { TraceMe traceme("contains#inside"); } + { TraceMe traceme("good#key1=value1#"); } + { TraceMe traceme("morning#key1=value1,key2=value2#"); } + { TraceMe traceme("incomplete#key1=value1,key2#"); } + // Special cases for tf.data + { TraceMe traceme("Iterator::XXX::YYY::ParallelMap"); } + TF_ASSERT_OK(tracer->Stop()); + + TF_ASSERT_OK(tracer->CollectData(&space)); + })); + traced_thread.reset(); // Join thread, waiting for completion. + ASSERT_NO_FATAL_FAILURE(); // Test for failure in child thread. + + ASSERT_EQ(space.planes_size(), 1); + const auto& plane = space.planes(0); + XPlaneVisitor xplane(&plane); + ASSERT_EQ(plane.name(), ::tsl::profiler::kHostThreadsPlaneName); + ASSERT_EQ(plane.lines_size(), 1); + ASSERT_EQ(plane.event_metadata_size(), 7); + ASSERT_EQ(plane.stat_metadata_size(), 4); + const auto& line = plane.lines(0); + EXPECT_EQ(line.id(), thread_id); + EXPECT_EQ(line.name(), thread_name); + ASSERT_EQ(line.events_size(), 7); + const auto& events = line.events(); + + XEventVisitor e0(&xplane, &line, &events[0]); + EXPECT_EQ(e0.Name(), "hello"); + ASSERT_EQ(events[0].stats_size(), 0); + + XEventVisitor e1(&xplane, &line, &events[1]); + EXPECT_EQ(e1.Name(), "world"); + ASSERT_EQ(events[1].stats_size(), 0); + + XEventVisitor e2(&xplane, &line, &events[2]); + EXPECT_EQ(e2.Name(), "contains#inside"); + ASSERT_EQ(events[2].stats_size(), 0); + + XEventVisitor e3(&xplane, &line, &events[3]); + EXPECT_EQ(e3.Name(), "good"); + ASSERT_EQ(events[3].stats_size(), 1); + { + std::optional value; + e3.ForEachStat([&](const XStatVisitor& stat) { + if (stat.Name() == "key1") value = stat.ToString(); + }); + ASSERT_TRUE(value); // The stat key is present. + EXPECT_EQ(*value, "value1"); // The stat value is expected. + } + + XEventVisitor e4(&xplane, &line, &events[4]); + EXPECT_EQ(e4.Name(), "morning"); + ASSERT_EQ(events[4].stats_size(), 2); + { + std::optional value1, value2; + e4.ForEachStat([&](const XStatVisitor& stat) { + if (stat.Name() == "key1") { + value1 = stat.ToString(); + } else if (stat.Name() == "key2") { + value2 = stat.ToString(); + } + }); + ASSERT_TRUE(value1 && value2); // The stat keys are presents. + EXPECT_EQ(*value1, "value1"); // The stat value1 is expected. + EXPECT_EQ(*value2, "value2"); // The stat value2 is expected. + } + + XEventVisitor e5(&xplane, &line, &events[5]); + EXPECT_EQ(e5.Name(), "incomplete"); + ASSERT_EQ(events[5].stats_size(), 1); + { + std::optional value1, value2; + e5.ForEachStat([&](const XStatVisitor& stat) { + if (stat.Name() == "key1") { + value1 = stat.ToString(); + } else if (stat.Name() == "key2") { + value2 = stat.ToString(); + } + }); + ASSERT_TRUE(value1 && !value2); // One of the stat key is present. + EXPECT_EQ(*value1, "value1"); // The stat value is expected. + } + + // Dataset Ops will trim intermediate namespace. + XEventVisitor e6(&xplane, &line, &events[6]); + EXPECT_EQ(e6.Name(), "Iterator::XXX::YYY::ParallelMap"); + + EXPECT_EQ(e6.DisplayName(), "Iterator::ParallelMap"); +} + +} // namespace +} // namespace profiler +} // namespace xla diff --git a/xla/backends/profiler/cpu/metadata_collector.cc b/xla/backends/profiler/cpu/metadata_collector.cc index 19bc06d288fc1..a19d4527ec930 100644 --- a/xla/backends/profiler/cpu/metadata_collector.cc +++ b/xla/backends/profiler/cpu/metadata_collector.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,12 +21,12 @@ limitations under the License. #include "xla/backends/profiler/cpu/metadata_utils.h" #include "xla/service/hlo.pb.h" #include "xla/service/xla_debug_info_manager.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" +#include "xla/status.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" diff --git a/xla/backends/profiler/cpu/metadata_utils.h b/xla/backends/profiler/cpu/metadata_utils.h index 1e3c8a215faa5..dfef42c7baeb4 100644 --- a/xla/backends/profiler/cpu/metadata_utils.h +++ b/xla/backends/profiler/cpu/metadata_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/cpu/python_tracer.cc b/xla/backends/profiler/cpu/python_tracer.cc index d45a221990f4d..30c9982d9b132 100644 --- a/xla/backends/profiler/cpu/python_tracer.cc +++ b/xla/backends/profiler/cpu/python_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,10 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/python/profiler/internal/python_hooks.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -35,11 +34,11 @@ class PythonTracer : public tsl::profiler::ProfilerInterface { : options_(options) {} ~PythonTracer() override; - tsl::Status Start() override; // TENSORFLOW_STATUS_OK + absl::Status Start() override; // TENSORFLOW_STATUS_OK - tsl::Status Stop() override; // TENSORFLOW_STATUS_OK + absl::Status Stop() override; // TENSORFLOW_STATUS_OK - tsl::Status CollectData( // TENSORFLOW_STATUS_OK + absl::Status CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) override; private: @@ -53,34 +52,34 @@ class PythonTracer : public tsl::profiler::ProfilerInterface { PythonTracer::~PythonTracer() { Stop().IgnoreError(); } // NOLINT -tsl::Status PythonTracer::Start() { // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::Start() { // TENSORFLOW_STATUS_OK if (recording_) { return tsl::errors::Internal("PythonTracer already started"); } VLOG(1) << __FUNCTION__; recording_ = true; PythonHooks::GetSingleton()->Start(options_); - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status PythonTracer::Stop() { // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::Stop() { // TENSORFLOW_STATUS_OK if (!recording_) { return tsl::errors::Internal("PythonTracer not started"); } VLOG(1) << __FUNCTION__; context_ = PythonHooks::GetSingleton()->Stop(); recording_ = false; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status PythonTracer::CollectData( // TENSORFLOW_STATUS_OK +absl::Status PythonTracer::CollectData( // TENSORFLOW_STATUS_OK tensorflow::profiler::XSpace* space) { VLOG(2) << "Collecting data to XSpace from PythonTracer."; if (context_) { context_->Finalize(space); context_.reset(); } - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/xla/backends/profiler/cpu/python_tracer.h b/xla/backends/profiler/cpu/python_tracer.h index 86e11ae9e8949..0ea3c7e4ade7c 100644 --- a/xla/backends/profiler/cpu/python_tracer.h +++ b/xla/backends/profiler/cpu/python_tracer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/cpu/python_tracer_factory.cc b/xla/backends/profiler/cpu/python_tracer_factory.cc index 07926ea1b5368..865d399e62a2b 100644 --- a/xla/backends/profiler/cpu/python_tracer_factory.cc +++ b/xla/backends/profiler/cpu/python_tracer_factory.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/profiler/cpu/python_tracer.h" #include "tsl/profiler/lib/profiler_factory.h" +#include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" namespace xla { diff --git a/xla/backends/profiler/gpu/BUILD b/xla/backends/profiler/gpu/BUILD index f225c603e1db2..37533e052fa68 100644 --- a/xla/backends/profiler/gpu/BUILD +++ b/xla/backends/profiler/gpu/BUILD @@ -1,11 +1,8 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") -load( - "//xla:xla.bzl", - "xla_cc_test", -) load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "@tsl//tsl:tsl.bzl", + "internal_visibility", "tsl_copts", "tsl_gpu_library", ) @@ -17,16 +14,19 @@ load( "@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") +load( + "//xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = internal_visibility(["//xla:internal"]), ) tsl_gpu_library( @@ -37,26 +37,22 @@ tsl_gpu_library( ":cupti_collector", ":cupti_tracer", ":cupti_wrapper", + ":rocm_collector", ":rocm_tracer", ], deps = [ ":cupti_utils", + "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:abi", - "@tsl//tsl/platform:env_time", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:mutex", "@tsl//tsl/platform:thread_annotations", "@tsl//tsl/profiler/lib:profiler_factory", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@tsl//tsl/profiler/utils:time_utils", - "@tsl//tsl/util:env_var", ], alwayslink = 1, ) @@ -69,7 +65,7 @@ tsl_gpu_library( deps = [ "@tsl//tsl/platform:macros", "@tsl//tsl/platform:types", - ] + if_cuda(["@tsl//tsl/cuda:cupti"]), + ] + if_cuda(["//xla/tsl/cuda:cupti"]), ) tsl_gpu_library( @@ -96,9 +92,7 @@ tsl_gpu_library( ], visibility = ["//visibility:public"], deps = [ - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:mutex", "@tsl//tsl/platform:thread_annotations", @@ -145,7 +139,7 @@ cuda_library( visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart", + "@local_config_cuda//cuda:cudart_static", "@tsl//tsl/platform:test", ], ) @@ -158,14 +152,17 @@ cuda_library( # that the wrapper is about the only direct user. tsl_gpu_library( name = "cupti_wrapper", - srcs = if_cuda(["cupti_wrapper.cc"]), + srcs = if_cuda([ + "cupti_wrapper.cc", + "cupti_wrapper_stub.cc", + ]), hdrs = if_cuda(["cupti_wrapper.h"]), copts = tf_profiler_copts() + tsl_copts(), linkstatic = 1, visibility = ["//visibility:public"], deps = [ ":cupti_interface", - ] + if_cuda(["@tsl//tsl/cuda:cupti"]), + ] + if_cuda(["//xla/tsl/cuda:cupti"]), ) tsl_gpu_library( @@ -184,13 +181,13 @@ tsl_gpu_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:types", "@tsl//tsl/profiler/backends/cpu:annotation_stack", "@tsl//tsl/profiler/lib:scoped_annotation", @@ -211,13 +208,12 @@ tsl_gpu_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/status", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:types", "@tsl//tsl/profiler/backends/cpu:annotation_stack", "@tsl//tsl/profiler/lib:scoped_annotation", @@ -225,6 +221,37 @@ tsl_gpu_library( ], ) +tsl_gpu_library( + name = "rocm_collector", + srcs = if_rocm(["rocm_collector.cc"]), + hdrs = if_rocm(["rocm_collector.h"]), + copts = tf_profiler_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + "//xla/stream_executor/rocm:roctracer_wrapper", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/types:optional", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:platform_port", + "@tsl//tsl/profiler/backends/cpu:annotation_stack", + "@tsl//tsl/profiler/lib:profiler_factory", + "@tsl//tsl/profiler/lib:profiler_interface", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@tsl//tsl/profiler/utils:parse_annotation", + "@tsl//tsl/profiler/utils:time_utils", + "@tsl//tsl/profiler/utils:trace_utils", + "@tsl//tsl/profiler/utils:xplane_builder", + "@tsl//tsl/profiler/utils:xplane_schema", + "@tsl//tsl/profiler/utils:xplane_utils", + ], +) + tsl_gpu_library( name = "rocm_tracer", srcs = if_rocm(["rocm_tracer.cc"]), @@ -232,18 +259,20 @@ tsl_gpu_library( copts = tf_profiler_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ + ":rocm_collector", "//xla/stream_executor/rocm:roctracer_wrapper", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/status", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:platform_port", "@tsl//tsl/profiler/backends/cpu:annotation_stack", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@tsl//tsl/profiler/utils:time_utils", ], ) @@ -257,14 +286,13 @@ tsl_gpu_library( "@com_google_absl//absl/strings", "@tsl//tsl/platform", "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:mutex", ], ) tsl_gpu_library( name = "cupti_collector", srcs = if_cuda(["cupti_collector.cc"]), - hdrs = if_cuda(["cupti_collector.h"]), + hdrs = ["cupti_collector.h"], copts = tf_profiler_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ @@ -273,12 +301,12 @@ tsl_gpu_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:abi", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:mutex", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:types", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@tsl//tsl/profiler/utils:parse_annotation", @@ -286,23 +314,7 @@ tsl_gpu_library( "@tsl//tsl/profiler/utils:xplane_builder", "@tsl//tsl/profiler/utils:xplane_schema", "@tsl//tsl/profiler/utils:xplane_utils", - ] + if_cuda(["@tsl//tsl/cuda:cupti"]), -) - -cc_library( - name = "cupti_collector_header", - hdrs = ["cupti_collector.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/strings", - "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:types", - "@tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], + ] + if_cuda(["//xla/tsl/cuda:cupti"]), ) tsl_gpu_library( @@ -313,7 +325,11 @@ tsl_gpu_library( ":cupti_error_manager", ":cupti_interface", ":cupti_wrapper", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:stringpiece", + "//xla/tsl/util:env_var", ], visibility = ["//visibility:public"], alwayslink = 1, diff --git a/xla/backends/profiler/gpu/cuda_test.cu.cc b/xla/backends/profiler/gpu/cuda_test.cu.cc index b41c94aadd8a8..a86c04c6cf8e7 100644 --- a/xla/backends/profiler/gpu/cuda_test.cu.cc +++ b/xla/backends/profiler/gpu/cuda_test.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/gpu/cuda_test.h b/xla/backends/profiler/gpu/cuda_test.h index c1c9fc68d2fab..b59c5384e021d 100644 --- a/xla/backends/profiler/gpu/cuda_test.h +++ b/xla/backends/profiler/gpu/cuda_test.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/gpu/cupti_collector.cc b/xla/backends/profiler/gpu/cupti_collector.cc index 2e97339714a07..c29cfc558c063 100644 --- a/xla/backends/profiler/gpu/cupti_collector.cc +++ b/xla/backends/profiler/gpu/cupti_collector.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -140,7 +140,7 @@ class PerDeviceCollector { } void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, - tsl::uint64 start_gpu_ns, tsl::uint64 end_gpu_ns, + uint64_t start_gpu_ns, uint64_t end_gpu_ns, XLineBuilder* line) { if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || event.start_time_ns > event.end_time_ns) { @@ -177,7 +177,7 @@ class PerDeviceCollector { if (event.context_id != CuptiTracerEvent::kInvalidContextId) { xevent.AddStatValue( *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)), - absl::StrCat("$$", static_cast(event.context_id))); + absl::StrCat("$$", static_cast(event.context_id))); } if (event.type == CuptiTracerEventType::Kernel && @@ -320,7 +320,7 @@ class PerDeviceCollector { events_.emplace_back(std::move(event)); } - size_t Flush(tsl::uint64 start_gpu_ns, tsl::uint64 end_gpu_ns, + size_t Flush(uint64_t start_gpu_ns, uint64_t end_gpu_ns, XPlaneBuilder* device_plane, XPlaneBuilder* host_plane) { mutex_lock l(m_); // Tracking event types per line. @@ -392,7 +392,7 @@ class PerDeviceCollector { // Times 2 because HBM is DDR memory; it gets two data bits per each // data lane. auto memory_bandwidth = - tsl::uint64{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; + uint64_t{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), @@ -404,7 +404,7 @@ class PerDeviceCollector { device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemorySize)), - static_cast(total_memory)); + static_cast(total_memory)); } auto compute_capability_major = GetDeviceAttribute( @@ -469,7 +469,7 @@ class PerDeviceCollector { } // namespace -void AnnotationMap::Add(tsl::uint32 device_id, tsl::uint32 correlation_id, +void AnnotationMap::Add(uint32_t device_id, uint32_t correlation_id, const absl::string_view annotation, const absl::string_view nvtx_range) { if (annotation.empty() && nvtx_range.empty()) return; @@ -488,8 +488,8 @@ void AnnotationMap::Add(tsl::uint32 device_id, tsl::uint32 correlation_id, } } -AnnotationMap::AnnotationInfo AnnotationMap::LookUp( - tsl::uint32 device_id, tsl::uint32 correlation_id) { +AnnotationMap::AnnotationInfo AnnotationMap::LookUp(uint32_t device_id, + uint32_t correlation_id) { if (device_id >= per_device_map_.size()) return AnnotationInfo(); auto& per_device_map = per_device_map_[device_id]; absl::MutexLock lock(&per_device_map.mutex); @@ -503,8 +503,7 @@ AnnotationMap::AnnotationInfo AnnotationMap::LookUp( class CuptiTraceCollectorImpl : public CuptiTraceCollector { public: CuptiTraceCollectorImpl(const CuptiTracerCollectorOptions& option, - tsl::uint64 start_walltime_ns, - tsl::uint64 start_gpu_ns) + uint64_t start_walltime_ns, uint64_t start_gpu_ns) : CuptiTraceCollector(option), num_callback_events_(0), num_activity_events_(0), @@ -531,13 +530,13 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { per_device_collector_[event.device_id].AddEvent(std::move(event)); } void OnEventsDropped(const std::string& reason, - tsl::uint32 num_events) override { + uint32_t num_events) override { absl::MutexLock lock(&mutex_); dropped_events_[reason] += num_events; } void Flush() override {} // Returns true if some GPU events are captured. - bool Export(XSpace* space, tsl::uint64 end_gpu_ns) override { + bool Export(XSpace* space, uint64_t end_gpu_ns) override { LOG(INFO) << " GpuTracer has collected " << num_callback_events_ << " callback api events and " << num_activity_events_ << " activity events. " << ReportDroppedEvents(); @@ -587,10 +586,10 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { std::atomic num_callback_events_; std::atomic num_activity_events_; absl::Mutex mutex_; - absl::flat_hash_map dropped_events_ + absl::flat_hash_map dropped_events_ ABSL_GUARDED_BY(mutex_); - tsl::uint64 start_walltime_ns_; - tsl::uint64 start_gpu_ns_; + uint64_t start_walltime_ns_; + uint64_t start_gpu_ns_; int num_gpus_; // Set the all XLines of specified XPlane to starting walltime. @@ -599,7 +598,7 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { // this fact. Eventually we change line start time to corresponding // start_walltime_ns to normalize with CPU wall time. static void NormalizeTimeStamps(XPlaneBuilder* plane, - tsl::uint64 start_walltime_ns) { + uint64_t start_walltime_ns) { plane->ForEachLine( [&](XLineBuilder line) { line.SetTimestampNs(start_walltime_ns); }); } @@ -611,8 +610,8 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { }; std::unique_ptr CreateCuptiCollector( - const CuptiTracerCollectorOptions& options, - const tsl::uint64 start_walltime_ns, const tsl::uint64 start_gputime_ns) { + const CuptiTracerCollectorOptions& options, uint64_t start_walltime_ns, + uint64_t start_gputime_ns) { return std::make_unique(options, start_walltime_ns, start_gputime_ns); } diff --git a/xla/backends/profiler/gpu/cupti_collector.h b/xla/backends/profiler/gpu/cupti_collector.h index 8af5927269de5..8a523548256b5 100644 --- a/xla/backends/profiler/gpu/cupti_collector.h +++ b/xla/backends/profiler/gpu/cupti_collector.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ limitations under the License. #ifndef XLA_BACKENDS_PROFILER_GPU_CUPTI_COLLECTOR_H_ #define XLA_BACKENDS_PROFILER_GPU_CUPTI_COLLECTOR_H_ +#include #include #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_set.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -35,16 +35,16 @@ struct MemcpyDetails { size_t num_bytes; // The destination device for peer-2-peer communication (memcpy). The source // device is implicit: it's the current device. - tsl::uint32 destination; + uint32_t destination; // Whether or not the memcpy is asynchronous. bool async; // This contains CUpti_ActivityMemcpyKind for activity event (on device). // For events from other CuptiTracerEventSource, it is always 0. - tsl::int8 copy_kind; + int8_t copy_kind; // CUpti_ActivityMemoryKind of source. - tsl::int8 src_mem_kind; + int8_t src_mem_kind; // CUpti_ActivityMemoryKind of destination. - tsl::int8 dst_mem_kind; + int8_t dst_mem_kind; // ID of the hardware channel on which this operation ran. uint32_t channel_id = -1; @@ -56,9 +56,9 @@ struct MemAllocDetails { // Size of memory to be written over in bytes. size_t num_bytes; // The CUpti_ActivityMemoryKind value for this activity event. - tsl::int8 mem_kind; + int8_t mem_kind; // The virtual address of allocation. 0 if it is a free operation. - tsl::uint64 address; + uint64_t address; }; using MemFreeDetails = MemAllocDetails; @@ -72,20 +72,20 @@ using MemoryResidencyDetails = MemAllocDetails; // cudaHostRegister struct HostRegisterDetails { size_t num_bytes; - tsl::uint64 address; + uint64_t address; unsigned int flags; }; // cudaHostUnregister struct HostUnregisterDetails { - tsl::uint64 address; + uint64_t address; }; struct MemsetDetails { // Size of memory to be written over in bytes. size_t num_bytes; // The CUpti_ActivityMemoryKind value for this activity event. - tsl::int8 mem_kind; + int8_t mem_kind; // Whether or not the memset is asynchronous. bool async; @@ -97,23 +97,23 @@ struct MemsetDetails { struct KernelDetails { // The number of registers used in this kernel. - tsl::uint32 registers_per_thread; + uint32_t registers_per_thread; // The amount of shared memory space used by a thread block. - tsl::uint32 static_shared_memory_usage; + uint32_t static_shared_memory_usage; // The amount of dynamic memory space used by a thread block. - tsl::uint32 dynamic_shared_memory_usage; + uint32_t dynamic_shared_memory_usage; // X-dimension of a thread block. - tsl::uint32 block_x; + uint32_t block_x; // Y-dimension of a thread block. - tsl::uint32 block_y; + uint32_t block_y; // Z-dimension of a thread block. - tsl::uint32 block_z; + uint32_t block_z; // X-dimension of a grid. - tsl::uint32 grid_x; + uint32_t grid_x; // Y-dimension of a grid. - tsl::uint32 grid_y; + uint32_t grid_y; // Z-dimension of a grid. - tsl::uint32 grid_z; + uint32_t grid_z; // ID of the hardware channel on which this operation ran. uint32_t channel_id = -1; @@ -165,13 +165,13 @@ enum class CuptiTracerEventSource { }; struct CuptiTracerEvent { - static constexpr tsl::uint32 kInvalidThreadId = + static constexpr uint32_t kInvalidThreadId = std::numeric_limits::max(); - static constexpr tsl::uint32 kInvalidCorrelationId = + static constexpr uint32_t kInvalidCorrelationId = std::numeric_limits::max(); - static constexpr tsl::uint64 kInvalidContextId = + static constexpr uint64_t kInvalidContextId = std::numeric_limits::max(); - static constexpr tsl::uint64 kInvalidStreamId = + static constexpr uint64_t kInvalidStreamId = std::numeric_limits::max(); CuptiTracerEventType type = CuptiTracerEventType::Unsupported; CuptiTracerEventSource source = CuptiTracerEventSource::Invalid; @@ -183,11 +183,11 @@ struct CuptiTracerEvent { // where serialization happens. absl::string_view annotation; absl::string_view nvtx_range; - tsl::uint64 start_time_ns = 0; - tsl::uint64 end_time_ns = 0; - tsl::uint32 device_id = 0; - tsl::uint32 correlation_id = kInvalidCorrelationId; - tsl::uint32 thread_id = kInvalidThreadId; + uint64_t start_time_ns = 0; + uint64_t end_time_ns = 0; + uint32_t device_id = 0; + uint32_t correlation_id = kInvalidCorrelationId; + uint32_t thread_id = kInvalidThreadId; int64_t context_id = kInvalidContextId; int64_t stream_id = kInvalidStreamId; union { @@ -214,13 +214,13 @@ struct CuptiTracerCollectorOptions { // Maximum number of events to collect from callback API; if -1, no limit. // if 0, the callback API is enabled to build a correlation map, but no // events are collected. - tsl::uint64 max_callback_api_events = 2 * 1024 * 1024; + uint64_t max_callback_api_events = 2 * 1024 * 1024; // Maximum number of events to collect from activity API; if -1, no limit. - tsl::uint64 max_activity_api_events = 2 * 1024 * 1024; + uint64_t max_activity_api_events = 2 * 1024 * 1024; // Maximum number of annotation strings that we can accommodate. - tsl::uint64 max_annotation_strings = 1024 * 1024; + uint64_t max_annotation_strings = 1024 * 1024; // Number of GPUs involved. - tsl::uint32 num_gpus; + uint32_t num_gpus; }; class AnnotationMap { @@ -230,12 +230,12 @@ class AnnotationMap { absl::string_view nvtx_range; }; - explicit AnnotationMap(tsl::uint64 max_size, tsl::uint32 num_gpus) + explicit AnnotationMap(uint64_t max_size, uint32_t num_gpus) : max_size_(max_size), per_device_map_(num_gpus) {} - void Add(tsl::uint32 device_id, tsl::uint32 correlation_id, + void Add(uint32_t device_id, uint32_t correlation_id, const absl::string_view annotation, const absl::string_view nvtx_range); - AnnotationInfo LookUp(tsl::uint32 device_id, tsl::uint32 correlation_id); + AnnotationInfo LookUp(uint32_t device_id, uint32_t correlation_id); private: struct PerDeviceAnnotationMap { @@ -246,9 +246,9 @@ class AnnotationMap { // an use the reference to the string in the map. absl::node_hash_set annotations; absl::node_hash_set nvtx_ranges; - absl::flat_hash_map correlation_map; + absl::flat_hash_map correlation_map; }; - const tsl::uint64 max_size_; + const uint64_t max_size_; absl::FixedArray per_device_map_; AnnotationMap(const AnnotationMap&) = delete; @@ -265,12 +265,12 @@ class CuptiTraceCollector { // Producer side functions (i.e. called by CuptiTracer). virtual void AddEvent(CuptiTracerEvent&& event) = 0; virtual void OnEventsDropped(const std::string& reason, - tsl::uint32 num_events) = 0; + uint32_t num_events) = 0; virtual void Flush() = 0; // Consumer side functions (i.e. called by GPU tracer); virtual bool Export(tensorflow::profiler::XSpace* space, - tsl::uint64 end_gpu_ns) { + uint64_t end_gpu_ns) { return true; } virtual std::string ReportNumEventsIfDropped() { return ""; } @@ -288,8 +288,8 @@ class CuptiTraceCollector { }; std::unique_ptr CreateCuptiCollector( - const CuptiTracerCollectorOptions& options, - const tsl::uint64 start_walltime_ns, const tsl::uint64 start_gputime_ns); + const CuptiTracerCollectorOptions& options, uint64_t start_walltime_ns, + uint64_t start_gputime_ns); } // namespace profiler } // namespace xla diff --git a/xla/backends/profiler/gpu/cupti_error_manager.cc b/xla/backends/profiler/gpu/cupti_error_manager.cc index 4d84d9b8cfaf2..d50be6c955e7e 100644 --- a/xla/backends/profiler/gpu/cupti_error_manager.cc +++ b/xla/backends/profiler/gpu/cupti_error_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -202,242 +202,6 @@ CUptiResult CuptiErrorManager::Unsubscribe(CUpti_SubscriberHandle subscriber) { return error; } -CUptiResult CuptiErrorManager::DeviceEnumEventDomains( - CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->DeviceEnumEventDomains( - device, array_size_bytes, domain_array); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::DeviceGetEventDomainAttribute( - CUdevice device, CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->DeviceGetEventDomainAttribute( - device, event_domain, attrib, value_size, value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::DisableKernelReplayMode(CUcontext context) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->DisableKernelReplayMode(context); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EnableKernelReplayMode(CUcontext context) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EnableKernelReplayMode(context); - if (error == CUPTI_SUCCESS) { - auto f = - std::bind(&CuptiErrorManager::DisableKernelReplayMode, this, context); - RegisterUndoFunction(f); - } - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::DeviceEnumMetrics(CUdevice device, - size_t* arraySizeBytes, - CUpti_MetricID* metricArray) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = - interface_->DeviceEnumMetrics(device, arraySizeBytes, metricArray); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::DeviceGetNumEventDomains(CUdevice device, - uint32_t* num_domains) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->DeviceGetNumEventDomains(device, num_domains); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventDomainEnumEvents( - CUpti_EventDomainID event_domain, size_t* array_size_bytes, - CUpti_EventID* event_array) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventDomainEnumEvents( - event_domain, array_size_bytes, event_array); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventDomainGetNumEvents( - CUpti_EventDomainID event_domain, uint32_t* num_events) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = - interface_->EventDomainGetNumEvents(event_domain, num_events); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGetAttribute(CUpti_EventID event, - CUpti_EventAttribute attrib, - size_t* value_size, - void* value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = - interface_->EventGetAttribute(event, attrib, value_size, value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGetIdFromName(CUdevice device, - const char* event_name, - CUpti_EventID* event) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGetIdFromName(device, event_name, event); - ALLOW_ERROR(error, CUPTI_ERROR_INVALID_EVENT_NAME); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupDisable(CUpti_EventGroup event_group) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupDisable(event_group); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupEnable(CUpti_EventGroup event_group) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupEnable(event_group); - if (error == CUPTI_SUCCESS) { - auto f = - std::bind(&CuptiErrorManager::EventGroupDisable, this, event_group); - RegisterUndoFunction(f); - } - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupGetAttribute( - CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t* value_size, void* value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupGetAttribute(event_group, attrib, - value_size, value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupReadEvent( - CUpti_EventGroup event_group, CUpti_ReadEventFlags flags, - CUpti_EventID event, size_t* event_value_buffer_size_bytes, - uint64_t* event_value_buffer) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupReadEvent( - event_group, flags, event, event_value_buffer_size_bytes, - event_value_buffer); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupSetAttribute( - CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t value_size, void* value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupSetAttribute(event_group, attrib, - value_size, value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupSetsCreate( - CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, CUpti_EventGroupSets** event_group_passes) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupSetsCreate( - context, event_id_array_size_bytes, event_id_array, event_group_passes); - if (error == CUPTI_SUCCESS) { - auto f = std::bind(&CuptiErrorManager::EventGroupSetsDestroy, this, - *event_group_passes); - RegisterUndoFunction(f); - } - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::EventGroupSetsDestroy( - CUpti_EventGroupSets* event_group_sets) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->EventGroupSetsDestroy(event_group_sets); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -// CUPTI metric API -CUptiResult CuptiErrorManager::DeviceGetNumMetrics(CUdevice device, - uint32_t* num_metrics) { - IGNORE_CALL_IF_DISABLED; - // Disable heap checking for the first CUPTI metric API. See b/22091576. - absl::LeakCheckDisabler disabler; - CUptiResult error = interface_->DeviceGetNumMetrics(device, num_metrics); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::MetricGetIdFromName(CUdevice device, - const char* metric_name, - CUpti_MetricID* metric) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = - interface_->MetricGetIdFromName(device, metric_name, metric); - ALLOW_ERROR(error, CUPTI_ERROR_INVALID_METRIC_NAME); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::MetricGetNumEvents(CUpti_MetricID metric, - uint32_t* num_events) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->MetricGetNumEvents(metric, num_events); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::MetricEnumEvents( - CUpti_MetricID metric, size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->MetricEnumEvents( - metric, event_id_array_size_bytes, event_id_array); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::MetricGetAttribute(CUpti_MetricID metric, - CUpti_MetricAttribute attrib, - size_t* value_size, - void* value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = - interface_->MetricGetAttribute(metric, attrib, value_size, value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - -CUptiResult CuptiErrorManager::MetricGetValue( - CUdevice device, CUpti_MetricID metric, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, size_t event_value_array_size_bytes, - uint64_t* event_value_array, uint64_t time_duration, - CUpti_MetricValue* metric_value) { - IGNORE_CALL_IF_DISABLED; - CUptiResult error = interface_->MetricGetValue( - device, metric, event_id_array_size_bytes, event_id_array, - event_value_array_size_bytes, event_value_array, time_duration, - metric_value); - LOG_AND_DISABLE_IF_ERROR(error); - return error; -} - void CuptiErrorManager::UndoAndDisable() { if (undo_disabled_) { // prevent deadlock return; diff --git a/xla/backends/profiler/gpu/cupti_error_manager.h b/xla/backends/profiler/gpu/cupti_error_manager.h index 9ee29433b6a7d..863ae2730a7ed 100644 --- a/xla/backends/profiler/gpu/cupti_error_manager.h +++ b/xla/backends/profiler/gpu/cupti_error_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -100,119 +100,6 @@ class CuptiErrorManager : public xla::profiler::CuptiInterface { // Unsubscribes callbacks. CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; - // CUPTI event API - // Returns a list of event domains. - CUptiResult DeviceEnumEventDomains( - CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array) override; - - // Returns domain attributes. - CUptiResult DeviceGetEventDomainAttribute(CUdevice device, - CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, - size_t* value_size, - void* value) override; - - // Disables kernel replay mode. - CUptiResult DisableKernelReplayMode(CUcontext context) override; - - // Enables kernel replay mode. If we successfully enable kernel replay mode, - // we add DisableKernelReplayMode to the undo log. - CUptiResult EnableKernelReplayMode(CUcontext context) override; - - // Returns the number of event domains. - CUptiResult DeviceGetNumEventDomains(CUdevice device, - uint32_t* num_domains) override; - - // Returns a list of events. - CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, - size_t* array_size_bytes, - CUpti_EventID* event_array) override; - - // Returns the number of events. - CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, - uint32_t* num_events) override; - - // Returns an event attribute. - CUptiResult EventGetAttribute(CUpti_EventID event, - CUpti_EventAttribute attrib, size_t* value_size, - void* value) override; - - // Convverts event ID from event name. - CUptiResult EventGetIdFromName(CUdevice device, const char* event_name, - CUpti_EventID* event) override; - - // Disables event group. - CUptiResult EventGroupDisable(CUpti_EventGroup event_group) override; - - // Enables event group. If we successfully enable an event group, we add - // EventGroupDisable to the undo log. - CUptiResult EventGroupEnable(CUpti_EventGroup event_group) override; - - // Returns an event group attribute. - CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t* value_size, void* value) override; - - // Returns a performance counter value. - CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group, - CUpti_ReadEventFlags flags, - CUpti_EventID event, - size_t* event_value_buffer_size_bytes, - uint64_t* event_value_buffer) override; - - // Returns an event group set attribute. - CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t value_size, void* value) override; - - // Creates an event group set. If we successfully creates an event group set, - // we add EventGroupSetsDestroy to the undo log. - CUptiResult EventGroupSetsCreate( - CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - CUpti_EventGroupSets** event_group_passes) override; - - // Destroys an event group set. - CUptiResult EventGroupSetsDestroy( - CUpti_EventGroupSets* event_group_sets) override; - - // CUPTI metric API: all thread-safe - // Enumerates metrics. - CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, - CUpti_MetricID* metricArray) override; - - // Returns the number of metrics. - CUptiResult DeviceGetNumMetrics(CUdevice device, - uint32_t* num_metrics) override; - - // Converts a metric ID to a metric name. - CUptiResult MetricGetIdFromName(CUdevice device, const char* metric_name, - CUpti_MetricID* metric) override; - - // Returns the number of events required to calculate a particular metric. - CUptiResult MetricGetNumEvents(CUpti_MetricID metric, - uint32_t* num_events) override; - - // Returns a list of events required to calculate a particular metric. - CUptiResult MetricEnumEvents(CUpti_MetricID metric, - size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array) override; - - // Returns a metric attribute. - CUptiResult MetricGetAttribute(CUpti_MetricID metric, - CUpti_MetricAttribute attrib, - size_t* value_size, void* value) override; - - // Returns a metric value. - CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, - size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - size_t event_value_array_size_bytes, - uint64_t* event_value_array, - uint64_t time_duration, - CUpti_MetricValue* metric_value) override; - CUptiResult GetResultString(CUptiResult result, const char** str) override; CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override; diff --git a/xla/backends/profiler/gpu/cupti_error_manager_test.cc b/xla/backends/profiler/gpu/cupti_error_manager_test.cc index 76bb76e93c585..e513e6322db92 100644 --- a/xla/backends/profiler/gpu/cupti_error_manager_test.cc +++ b/xla/backends/profiler/gpu/cupti_error_manager_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/gpu/cupti_interface.h b/xla/backends/profiler/gpu/cupti_interface.h index 631780e1792de..b487673d07e6a 100644 --- a/xla/backends/profiler/gpu/cupti_interface.h +++ b/xla/backends/profiler/gpu/cupti_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_BACKENDS_PROFILER_GPU_CUPTI_INTERFACE_H_ #define XLA_BACKENDS_PROFILER_GPU_CUPTI_INTERFACE_H_ -#include -#include +#include +#include #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/cuda.h" @@ -60,7 +60,7 @@ class CuptiInterface { CUpti_BuffersCallbackRequestFunc func_buffer_requested, CUpti_BuffersCallbackCompleteFunc func_buffer_completed) = 0; - virtual CUptiResult GetDeviceId(CUcontext context, tsl::uint32* deviceId) = 0; + virtual CUptiResult GetDeviceId(CUcontext context, uint32_t* deviceId) = 0; virtual CUptiResult GetTimestamp(uint64_t* timestamp) = 0; @@ -82,95 +82,6 @@ class CuptiInterface { virtual CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) = 0; - // CUPTI event API - virtual CUptiResult DeviceEnumEventDomains( - CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array) = 0; - - virtual CUptiResult DeviceGetEventDomainAttribute( - CUdevice device, CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) = 0; - - virtual CUptiResult DisableKernelReplayMode(CUcontext context) = 0; - - virtual CUptiResult EnableKernelReplayMode(CUcontext context) = 0; - - virtual CUptiResult DeviceGetNumEventDomains(CUdevice device, - uint32_t* num_domains) = 0; - - virtual CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, - size_t* array_size_bytes, - CUpti_EventID* event_array) = 0; - - virtual CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, - uint32_t* num_events) = 0; - - virtual CUptiResult EventGetAttribute(CUpti_EventID event, - CUpti_EventAttribute attrib, - size_t* value_size, void* value) = 0; - - virtual CUptiResult EventGetIdFromName(CUdevice device, - const char* event_name, - CUpti_EventID* event) = 0; - - virtual CUptiResult EventGroupDisable(CUpti_EventGroup event_group) = 0; - - virtual CUptiResult EventGroupEnable(CUpti_EventGroup event_group) = 0; - - virtual CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t* value_size, - void* value) = 0; - - virtual CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group, - CUpti_ReadEventFlags flags, - CUpti_EventID event, - size_t* event_value_buffer_size_bytes, - uint64_t* eventValueBuffer) = 0; - - virtual CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t value_size, - void* value) = 0; - - virtual CUptiResult EventGroupSetsCreate( - CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - CUpti_EventGroupSets** event_group_passes) = 0; - - virtual CUptiResult EventGroupSetsDestroy( - CUpti_EventGroupSets* event_group_sets) = 0; - - // CUPTI metric API - virtual CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, - CUpti_MetricID* metricArray) = 0; - - virtual CUptiResult DeviceGetNumMetrics(CUdevice device, - uint32_t* num_metrics) = 0; - - virtual CUptiResult MetricGetIdFromName(CUdevice device, - const char* metric_name, - CUpti_MetricID* metric) = 0; - - virtual CUptiResult MetricGetNumEvents(CUpti_MetricID metric, - uint32_t* num_events) = 0; - - virtual CUptiResult MetricEnumEvents(CUpti_MetricID metric, - size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array) = 0; - - virtual CUptiResult MetricGetAttribute(CUpti_MetricID metric, - CUpti_MetricAttribute attrib, - size_t* value_size, void* value) = 0; - - virtual CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, - size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - size_t event_value_array_size_bytes, - uint64_t* event_value_array, - uint64_t time_duration, - CUpti_MetricValue* metric_value) = 0; - virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0; virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0; diff --git a/xla/backends/profiler/gpu/cupti_profiler.cc b/xla/backends/profiler/gpu/cupti_profiler.cc index f49e4f54b9f42..b355308b21336 100644 --- a/xla/backends/profiler/gpu/cupti_profiler.cc +++ b/xla/backends/profiler/gpu/cupti_profiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ void CuptiProfiler::Enable(const CuptiProfilerOptions &option) {} void CuptiProfiler::Disable() {} -/*static*/ tsl::uint64 CuptiProfiler::GetTimestamp() { +/*static*/ uint64_t CuptiProfiler::GetTimestamp() { uint64_t tsc; CuptiInterface *cupti_interface = GetCuptiInterface(); if (cupti_interface && cupti_interface->GetTimestamp(&tsc) == CUPTI_SUCCESS) { diff --git a/xla/backends/profiler/gpu/cupti_profiler.h b/xla/backends/profiler/gpu/cupti_profiler.h index e16daafc9dfcc..3065da37b1269 100644 --- a/xla/backends/profiler/gpu/cupti_profiler.h +++ b/xla/backends/profiler/gpu/cupti_profiler.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,11 @@ limitations under the License. #ifndef XLA_BACKENDS_PROFILER_GPU_CUPTI_PROFILER_H_ #define XLA_BACKENDS_PROFILER_GPU_CUPTI_PROFILER_H_ -#include "absl/types/optional.h" +#include +#include + #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "xla/backends/profiler/gpu/cupti_interface.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" namespace xla { diff --git a/xla/backends/profiler/gpu/cupti_tracer.cc b/xla/backends/profiler/gpu/cupti_tracer.cc index 5c93d7872f633..a7ebab58e6e3f 100644 --- a/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/xla/backends/profiler/gpu/cupti_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/node_hash_set.h" #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/extras/CUPTI/include/generated_nvtx_meta.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/nvtx_utils.h" #include "tsl/platform/env.h" @@ -36,9 +37,9 @@ namespace profiler { namespace { +using absl::OkStatus; +using absl::Status; using tsl::Env; -using tsl::OkStatus; -using tsl::Status; using tsl::profiler::AnnotationStack; // CUPTI from CUDA 11.6 adds information about the hardware channel that ops @@ -379,9 +380,9 @@ void CUPTIAPI ProcessCuptiActivityBuffer(CUcontext context, uint32_t stream_id, } void AddKernelEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, + uint32_t device_id, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { CuptiTracerEvent event{}; event.type = CuptiTracerEventType::Kernel; event.source = CuptiTracerEventSource::DriverCallback; @@ -399,8 +400,8 @@ void AddKernelEventUponApiExit(CuptiTraceCollector *collector, // Performs the actual callback for both normal and P2P memcpy operations. CuptiTracerEvent PopulateMemcpyCallbackEvent( CuptiTracerEventType type, const CUpti_CallbackData *cbdata, - size_t num_bytes, tsl::uint32 src_device, tsl::uint32 dst_device, - bool async, tsl::uint64 start_time, tsl::uint64 end_time) { + size_t num_bytes, uint32_t src_device, uint32_t dst_device, bool async, + uint64_t start_time, uint64_t end_time) { CuptiTracerEvent event{}; event.type = type; event.source = CuptiTracerEventSource::DriverCallback; @@ -421,11 +422,9 @@ CuptiTracerEvent PopulateMemcpyCallbackEvent( } void AddNormalMemcpyEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { size_t num_bytes; CuptiTracerEventType type; bool async; @@ -440,9 +439,9 @@ void AddNormalMemcpyEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemsetEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { // We are casting all variants of cuMemset to cuMemsetD8 for accessing the // first member attribute, a CUdeviceptr. const auto *params = @@ -473,17 +472,16 @@ void AddCuMemsetEventUponApiExit(CuptiTraceCollector *collector, void AddP2PMemcpyEventUponApiExit(CuptiTraceCollector *collector, CuptiInterface *cupti_interface, - tsl::uint32 device_id, CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { size_t num_bytes; CuptiTracerEventType type; bool async; std::tie(num_bytes, type, async) = DecodeDriverMemcpy(cbid, cbdata->functionParams); - tsl::uint32 dst_device = -1, src_device = -1; + uint32_t dst_device = -1, src_device = -1; const auto *p2p_params = static_cast(cbdata->functionParams); cupti_interface->GetDeviceId(p2p_params->srcContext, &src_device); @@ -497,10 +495,9 @@ void AddP2PMemcpyEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemAllocEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); const void *dptr = reinterpret_cast(*params->dptr); @@ -521,12 +518,9 @@ void AddCuMemAllocEventUponApiExit(CuptiTraceCollector *collector, collector->AddEvent(std::move(event)); } -void AddCuMemAllocPitchEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { +void AddCuMemAllocPitchEventUponApiExit( + CuptiTraceCollector *collector, uint32_t device_id, CUpti_CallbackId cbid, + const CUpti_CallbackData *cbdata, uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); const void *dptr = reinterpret_cast(*params->dptr); @@ -548,12 +542,9 @@ void AddCuMemAllocPitchEventUponApiExit(CuptiTraceCollector *collector, collector->AddEvent(std::move(event)); } -void AddCuMemAllocManagedEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { +void AddCuMemAllocManagedEventUponApiExit( + CuptiTraceCollector *collector, uint32_t device_id, CUpti_CallbackId cbid, + const CUpti_CallbackData *cbdata, uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); const void *dptr = reinterpret_cast(*params->dptr); @@ -575,11 +566,10 @@ void AddCuMemAllocManagedEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemAllocHostEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -600,11 +590,10 @@ void AddCuMemAllocHostEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemHostAllocEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -626,10 +615,9 @@ void AddCuMemHostAllocEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemFreeEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); const void *dptr = reinterpret_cast(params->dptr); @@ -650,11 +638,9 @@ void AddCuMemFreeEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemFreeHostEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -673,12 +659,9 @@ void AddCuMemFreeHostEventUponApiExit(CuptiTraceCollector *collector, collector->AddEvent(std::move(event)); } -void AddCuMemHostRegisterEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { +void AddCuMemHostRegisterEventUponApiExit( + CuptiTraceCollector *collector, uint32_t device_id, CUpti_CallbackId cbid, + const CUpti_CallbackData *cbdata, uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -700,12 +683,9 @@ void AddCuMemHostRegisterEventUponApiExit(CuptiTraceCollector *collector, collector->AddEvent(std::move(event)); } -void AddCuMemHostUnregisterEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, - tsl::uint64 end_time) { +void AddCuMemHostUnregisterEventUponApiExit( + CuptiTraceCollector *collector, uint32_t device_id, CUpti_CallbackId cbid, + const CUpti_CallbackData *cbdata, uint64_t start_time, uint64_t end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -725,9 +705,9 @@ void AddCuMemHostUnregisterEventUponApiExit(CuptiTraceCollector *collector, } void AddGenericEventUponApiExit(CuptiTraceCollector *collector, - tsl::uint32 device_id, CUpti_CallbackId cbid, + uint32_t device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - tsl::uint64 start_time, tsl::uint64 end_time) { + uint64_t start_time, uint64_t end_time) { CuptiTracerEvent event{}; event.type = CuptiTracerEventType::Generic; event.source = CuptiTracerEventSource::DriverCallback; @@ -1082,630 +1062,6 @@ class CuptiDriverApiHookWithActivityApi : public CuptiDriverApiHook { void operator=(const CuptiDriverApiHookWithActivityApi &) = delete; }; -struct KernelRecord { - const char *kernel_name; - // TODO(csigg): cuStreamGetCtx introduced in CUDA 9.2 would allow us to only - // record the stream and infer the context during collection. - CUcontext context; - CUstream stream; - tsl::uint32 correlation_id; - CUevent start_event; - CUevent stop_event; - KernelDetails details; - tsl::uint64 start_timestamp; -}; - -struct MemcpyRecord { - CuptiTracerEventType type; - size_t size_bytes; - CUcontext context; - CUstream stream; - tsl::uint32 correlation_id; - bool async; - CUevent start_event; - CUevent stop_event; - tsl::uint64 start_timestamp; -}; - -Status CreateAndRecordEvent(CUevent *event, CUstream stream) { - CuptiApiTracingDisabler disabler; - TF_RETURN_IF_ERROR(ToStatus(cuEventCreate(event, CU_EVENT_DEFAULT))); - return ToStatus(cuEventRecord(*event, stream)); -} - -// Maintain and restore current thread's CUDA context. -// Note: cuStreamGetCtx only available after CUDA 9.2. -class ScopedCudaContext { - public: - explicit ScopedCudaContext(CUstream stream) : stream_(stream) { - CuptiApiTracingDisabler disabler; // don't trace cuda call in this func. - CUcontext context; - if (cuStreamGetCtx(stream, &context) != CUDA_SUCCESS) return; - context_ = context; - tsl::uint32 device_ordinal; - if (cuptiGetDeviceId(context, &device_ordinal) != CUPTI_SUCCESS) return; - device_ordinal_ = device_ordinal; - context_pushed_ = cuCtxPushCurrent(context) == CUDA_SUCCESS; - } - ~ScopedCudaContext() { - if (!context_pushed_) return; - CuptiApiTracingDisabler disabler; // don't trace cuda call in this func. - cuCtxPopCurrent(&*context_); - } - - // If successful, return the device ordinal of the relevant cuda stream. - // Otherwise std::nullopt; non-std ok - std::optional GetDeviceOrdinal() { return device_ordinal_; } - - // If successful, return the cuda context of the relevant cuda stream. - // Otherwise std::nullopt; - std::optional GetContext() { return context_; } - - private: - CUstream stream_; - std::optional context_; - std::optional device_ordinal_; - bool context_pushed_ = false; -}; - -// Stores a series of kernel and memcpy records. -class CudaEventRecorder { - public: - CudaEventRecorder(CuptiInterface *cupti_interface, - CuptiTraceCollector *collector, int ordinal) - : cupti_interface_(cupti_interface), - collector_(collector), - ordinal_(ordinal) { - device_name_ = absl::StrCat("gpu ", ordinal); // default. - CUdevice device; - if (cuDeviceGet(&device, ordinal) == CUDA_SUCCESS) { - char name[100]; - if (cuDeviceGetName(name, sizeof(name), device) == CUDA_SUCCESS) { - device_name_ = name; - } - } - } - - // Registers the start of a kernel launch. The returned index should be passed - // to StopKernel() after the kernel launch has completed. - template - size_t StartKernel(const char *kernel_name, CUcontext context, - tsl::uint32 correlation_id, const T *params) { - CUstream stream = params->hStream; - KernelRecord record = {kernel_name, context, stream, correlation_id}; - record.details.registers_per_thread = 0; // unknown. - record.details.static_shared_memory_usage = params->sharedMemBytes; - record.details.dynamic_shared_memory_usage = 0; // unknown - record.details.block_x = params->blockDimX; - record.details.block_y = params->blockDimY; - record.details.block_z = params->blockDimZ; - record.details.grid_x = params->gridDimX; - record.details.grid_y = params->gridDimY; - record.details.grid_z = params->gridDimZ; - record.start_timestamp = CuptiTracer::GetTimestamp(); - LogIfError(CreateAndRecordEvent(&record.start_event, stream)); - absl::MutexLock lock(&mutex_); - if (stopped_) return -1; - kernel_records_.push_back(record); - return kernel_records_.size() - 1; - } - tsl::uint64 StopKernel(size_t index) { - absl::MutexLock lock(&mutex_); - if (index >= kernel_records_.size()) return 0; - auto &record = kernel_records_[index]; - LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); - return record.start_timestamp; - } - - // Registers the start of a copy operation. The returned index should be - // passed to StopMemcpy() after the memcpy has completed. - size_t StartMemcpy(CuptiTracerEventType type, size_t size_bytes, - CUcontext context, CUstream stream, - tsl::uint32 correlation_id, bool async) { - MemcpyRecord record = {type, size_bytes, context, - stream, correlation_id, async}; - record.start_timestamp = CuptiTracer::GetTimestamp(); - LogIfError(CreateAndRecordEvent(&record.start_event, stream)); - absl::MutexLock lock(&mutex_); - if (stopped_) return -1; - memcpy_records_.push_back(record); - return memcpy_records_.size() - 1; - } - tsl::uint64 StopMemcpy(size_t index) { - absl::MutexLock lock(&mutex_); - if (index >= memcpy_records_.size()) return 0; - auto &record = memcpy_records_[index]; - LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); - return record.start_timestamp; - } - - Status Stop() { - { - absl::MutexLock lock(&mutex_); - stopped_ = true; - LOG(INFO) << "Collecting " << kernel_records_.size() - << " kernel records, " << memcpy_records_.size() - << " memcpy records."; - - // Gather all profiled streams and contexts. - for (const auto &record : kernel_records_) { - TF_RETURN_IF_ERROR( - AddStreamInfo(record.context, record.stream, "Kernel")); - } - for (const auto &record : memcpy_records_) { - TF_RETURN_IF_ERROR(AddStreamInfo(record.context, record.stream, - GetTraceEventTypeName(record.type))); - } - } - - // Synchronize all contexts, record end events, synchronize again. - // This scheme is an unreliable measure to associate a event with the wall - // time. There are chances that other threads might enque kernels which - // delay the second synchronization. - TF_RETURN_IF_ERROR(Synchronize()); - for (auto &pair : context_infos_) { - TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); - TF_RETURN_IF_ERROR(CreateAndRecordEvent(&pair.second.end_event, nullptr)); - } - - TF_RETURN_IF_ERROR(Synchronize()); - end_walltime_us_ = Env::Default()->NowMicros(); - return OkStatus(); - } - - Status Flush(AnnotationMap *annotation_map) { - auto kernel_records = ConsumeKernelRecords(); - auto memcpy_records = ConsumeMemcpyRecords(); - for (const auto &record : kernel_records) { - TF_RETURN_IF_ERROR(SaveRecord(record, annotation_map)); - } - for (const auto &record : memcpy_records) { - TF_RETURN_IF_ERROR(SaveRecord(record, annotation_map)); - } - return OkStatus(); - } - - std::vector ConsumeKernelRecords() { - absl::MutexLock lock(&mutex_); - return std::move(kernel_records_); - } - std::vector ConsumeMemcpyRecords() { - absl::MutexLock lock(&mutex_); - return std::move(memcpy_records_); - } - - private: - struct ContextInfo { - tsl::uint32 context_id = 0; - int num_streams = 0; - CUevent end_event; - }; - - struct StreamInfo { - tsl::uint32 stream_id = 0; - std::string name; - int index; // 0 is reserved for null stream. - const ContextInfo *ctx_info; - }; - - // Synchronizes all contexts. - Status Synchronize() const { - CuptiApiTracingDisabler disabler; - for (const auto &pair : context_infos_) { - TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); - TF_RETURN_IF_ERROR(ToStatus(cuCtxSynchronize())); - } - return OkStatus(); - } - - // Returns element from context_infos_, adding it if not yet present. - Status GetContextInfo(CUcontext context, ContextInfo **ctx_info_ptr) { - auto it = context_infos_.find(context); - - if (it == context_infos_.end()) { - tsl::uint32 context_id = 0; - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetContextId(context, &context_id)); - ContextInfo ctx_info = {context_id}; - it = context_infos_.emplace(context, ctx_info).first; - } - - *ctx_info_ptr = &it->second; - return OkStatus(); - } - - // Adds element to stream_infos_ if not yet present. If present, clear name - // if it doesn't match parameter. - Status AddStreamInfo(CUcontext context, CUstream stream, - absl::string_view name) { - StreamKey key(context, stream); - auto it = stream_infos_.find(key); - if (it != stream_infos_.end()) { - if (it->second.name != name) { - it->second.name.clear(); // Stream with inconsistent names, clear it. - } - return OkStatus(); - } - - ContextInfo *ctx_info; - TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info)); - int index = stream ? ++ctx_info->num_streams : 0; - tsl::uint32 stream_id = 0; -#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id)); -#else - RETURN_IF_CUPTI_ERROR( - cupti_interface_->GetStreamIdEx(context, stream, 0, &stream_id)); -#endif - - StreamInfo stream_info = {stream_id, static_cast(name), index, - ctx_info}; - stream_infos_.emplace(key, stream_info); - return OkStatus(); - } - - // Returns time in microseconds between events recorded on the GPU. - static uint64_t GetElapsedTimeUs(CUevent start, CUevent stop) { - CuptiApiTracingDisabler disabler; - float elapsed_ms = 0.0f; - LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop))); - return static_cast( - std::llroundf(1000 * std::max(elapsed_ms, 0.0f))); - } - - Status SaveRecord(const KernelRecord &record, - AnnotationMap *annotation_map) const { - if (!record.start_event || !record.stop_event) { - return OkStatus(); - } - const auto &stream_info = - stream_infos_.at(StreamKey(record.context, record.stream)); - auto start_us = - GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); - auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); - - std::string annotation; - - CuptiTracerEvent event{}; - event.type = CuptiTracerEventType::Kernel; - event.source = CuptiTracerEventSource::Activity; // on gpu device. - event.name = record.kernel_name; - event.start_time_ns = (end_walltime_us_ - start_us) * 1000; - event.end_time_ns = event.start_time_ns + elapsed_us * 1000; - event.device_id = ordinal_; - event.context_id = stream_info.ctx_info->context_id; - event.stream_id = stream_info.stream_id; - event.correlation_id = record.correlation_id; - AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp( - event.device_id, event.correlation_id); - event.annotation = info.annotation; - event.kernel_info = record.details; - collector_->AddEvent(std::move(event)); - return OkStatus(); - } - - Status SaveRecord(const MemcpyRecord &record, - AnnotationMap *annotation_map) const { - if (!record.start_event || !record.stop_event) { - return OkStatus(); - } - const auto &stream_info = - stream_infos_.at(StreamKey(record.context, record.stream)); - auto start_us = - GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); - auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); - - CuptiTracerEvent event{}; - event.type = record.type; - event.name = GetTraceEventTypeName(event.type); - event.source = CuptiTracerEventSource::Activity; - event.start_time_ns = (end_walltime_us_ - start_us) * 1000; - event.end_time_ns = event.start_time_ns + elapsed_us * 1000; - event.device_id = ordinal_; - event.context_id = stream_info.ctx_info->context_id; - event.stream_id = stream_info.stream_id; - event.correlation_id = record.correlation_id; - AnnotationMap::AnnotationInfo info = collector_->annotation_map()->LookUp( - event.device_id, event.correlation_id); - event.annotation = info.annotation; - event.memcpy_info.num_bytes = record.size_bytes; - // TODO: support MemcpyD2D where destination != source; - event.memcpy_info.destination = ordinal_; - event.memcpy_info.async = record.async; - // TODO: set src_mem_kind and dst_mem_kind. - collector_->AddEvent(std::move(event)); - return OkStatus(); - } - - absl::Mutex mutex_; - bool stopped_ TF_GUARDED_BY(mutex_) = false; - std::vector kernel_records_ TF_GUARDED_BY(mutex_); - std::vector memcpy_records_ - TF_GUARDED_BY(mutex_); // non std ok - - CuptiInterface *cupti_interface_; - CuptiTraceCollector *collector_; - const int ordinal_; - std::string device_name_; - tsl::uint64 end_walltime_us_; - // Include context in key to distinguish null streams. - using StreamKey = std::pair; - - absl::node_hash_map context_infos_; - absl::flat_hash_map stream_infos_; -}; - -// This hook uses cuda events to measure device side activities. -class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { - public: - CuptiDriverApiHookWithCudaEvent(const CuptiTracerOptions &option, - CuptiInterface *cupti_interface, - CuptiTraceCollector *collector) - : option_(option), - cupti_interface_(cupti_interface), - collector_(collector) { - int num_gpus = CuptiTracer::NumGpus(); - cuda_event_recorders_.reserve(num_gpus); - for (int i = 0; i < num_gpus; ++i) { - cuda_event_recorders_.emplace_back( - std::make_unique(cupti_interface, collector, i)); - } - } - ~CuptiDriverApiHookWithCudaEvent() { - for (auto *callback_context : callback_contexts_) delete callback_context; - } - - Status OnDriverApiEnter(int device_id, CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata) override { - auto *recorder = cuda_event_recorders_[device_id].get(); - switch (cbid) { - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = - static_cast(cbdata->functionParams); - *cbdata->correlationData = recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, params); - break; - } -#if CUDA_VERSION >= 11080 // CUDA 11.8 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = static_cast( - cbdata->functionParams); - *cbdata->correlationData = recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, - params->config); - break; - } -#endif // CUDA_VERSION >= 11080 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: { - DCHECK_NE(cbdata->symbolName, nullptr); - const auto *params = - static_cast( - cbdata->functionParams); - *cbdata->correlationData = - recorder->StartKernel( - cbdata->symbolName, cbdata->context, cbdata->correlationId, - params); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: { - const auto *params = - static_cast( - cbdata->functionParams); - std::vector record_indices; - record_indices.reserve(params->numDevices); - *cbdata->correlationData = -1; // Invalid value. - const auto &annotation = AnnotationStack::Get(); - for (int i = 0; i < params->numDevices; ++i) { - CUstream stream = params->launchParamsList[i].hStream; - ScopedCudaContext scoped_cuda_context(stream); - auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); - auto context = scoped_cuda_context.GetContext(); - if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); - // Because annotation are per device, therefore we need to populate - // annotation for each device involved. - collector_->annotation_map()->Add(*dev_id, cbdata->correlationId, - annotation, ""); - record_indices.push_back( - cuda_event_recorders_[*dev_id]->StartKernel( - "CooperativeKernelMultiDevice", *context, - cbdata->correlationId, &(params->launchParamsList[i]))); - } - auto *callback_context = - new CuptiApiCallbackContext(std::move(record_indices)); - callback_contexts_.insert(callback_context); - *cbdata->correlationData = - reinterpret_cast(callback_context); - } break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: { - const auto *params = - static_cast(cbdata->functionParams); - StartMemcpy(GetMemcpyType(params->src, params->dst), - cbdata, recorder); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: { - const auto *params = - static_cast(cbdata->functionParams); - StartMemcpyAsync( - GetMemcpyType(params->src, params->dst), cbdata, recorder); - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: - StartMemcpy(CuptiTracerEventType::MemcpyH2D, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyH2D, cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: - StartMemcpy(CuptiTracerEventType::MemcpyD2H, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyD2H, cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: - StartMemcpy(CuptiTracerEventType::MemcpyD2D, - cbdata, recorder); - break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: - StartMemcpyAsync( - CuptiTracerEventType::MemcpyD2D, cbdata, recorder); - break; - default: - VLOG(1) << "Unexpected callback id: " << cbid; - break; - } - return OkStatus(); - } - - Status OnDriverApiExit(int device_id, CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata) override { - auto *recorder = cuda_event_recorders_[device_id].get(); - if (*cbdata->correlationData == static_cast(-1)) return OkStatus(); - tsl::uint64 start_tsc = 0; - switch (cbid) { - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: -#if CUDA_VERSION >= 11080 // CUDA 11.8 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx: -#endif // CUDA_VERSION >= 11080 - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: - start_tsc = recorder->StopKernel(*cbdata->correlationData); - break; - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: { - auto *callback_context = reinterpret_cast( - *cbdata->correlationData); - callback_contexts_.erase(callback_context); - auto record_indices = std::move(callback_context->record_indices); - delete callback_context; - const auto *params = - static_cast( - cbdata->functionParams); - if (record_indices.size() != params->numDevices) - return tsl::errors::Internal("Invalid correlation data"); - for (int i = 0; i < params->numDevices; ++i) { - CUstream stream = params->launchParamsList[i].hStream; - ScopedCudaContext scoped_cuda_context(stream); - auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); - if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); - start_tsc = - cuda_event_recorders_[*dev_id]->StopKernel(record_indices[i]); - } - } break; - case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: - case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: - start_tsc = recorder->StopMemcpy(*cbdata->correlationData); - break; - default: - VLOG(1) << "Unexpected callback id: " << cbid; - // TODO: figure out how to get start timestamp in this case. - return OkStatus(); - } - // If we are not collecting CPU events from Callback API, we can return now. - if (!option_.required_callback_api_events) { - return OkStatus(); - } - - // Grab timestamp for API exit. API entry timestamp saved in cbdata. - tsl::uint64 end_tsc = CuptiTracer::GetTimestamp(); - return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, - start_tsc, end_tsc, domain, cbid, cbdata); - } - Status SyncAndFlush() override { - for (auto &recorder : cuda_event_recorders_) { - TF_RETURN_IF_ERROR(recorder->Stop()); - } - for (auto &recorder : cuda_event_recorders_) { - TF_RETURN_IF_ERROR(recorder->Flush(collector_->annotation_map())); - } - return OkStatus(); - } - - private: - template - static void StartMemcpy(CuptiTracerEventType type, - const CUpti_CallbackData *cbdata, - CudaEventRecorder *recorder) { - const auto *params = static_cast(cbdata->functionParams); - *cbdata->correlationData = - recorder->StartMemcpy(type, params->ByteCount, cbdata->context, nullptr, - cbdata->correlationId, /*async*/ false); - } - - template - static void StartMemcpyAsync(CuptiTracerEventType type, - const CUpti_CallbackData *cbdata, - CudaEventRecorder *recorder) { - const auto *params = static_cast(cbdata->functionParams); - *cbdata->correlationData = recorder->StartMemcpy( - type, params->ByteCount, cbdata->context, params->hStream, - cbdata->correlationId, /*async*/ true); - } - - static CUmemorytype GetMemoryType(CUdeviceptr ptr) { - CuptiApiTracingDisabler disabler; - CUmemorytype mem_type = CU_MEMORYTYPE_HOST; - auto status = - cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr); - if (status == CUDA_ERROR_INVALID_VALUE) { - // Pointer not registered with CUDA, must be host memory. - return CU_MEMORYTYPE_HOST; - } - LogIfError(ToStatus(status)); - return mem_type; - } - - static CuptiTracerEventType GetMemcpyType(CUdeviceptr src, CUdeviceptr dst) { - CUmemorytype src_type = GetMemoryType(src); - CUmemorytype dst_type = GetMemoryType(dst); - // TODO: handle CU_MEMORYTYPE_ARRAY case - if (src_type == CU_MEMORYTYPE_HOST && dst_type == CU_MEMORYTYPE_DEVICE) { - return CuptiTracerEventType::MemcpyH2D; - } else if (src_type == CU_MEMORYTYPE_DEVICE && - dst_type == CU_MEMORYTYPE_HOST) { - return CuptiTracerEventType::MemcpyD2H; - } else if (src_type == CU_MEMORYTYPE_DEVICE && - dst_type == CU_MEMORYTYPE_DEVICE) { - return CuptiTracerEventType::MemcpyD2D; - } - return CuptiTracerEventType::MemcpyOther; - } - - // Each cuLaunchCooperativeKernelMultiDevice will need to add an entry in - // each corresponding device, therefore we need to keep records of all - // the record indices in each device's record array. - // We allocate such data structure during API entry and free during API exit. - // However there is no guarantee that we receive such callbacks in pairs, we - // maintain a on-going API calls to make sure no memory leaks. - struct CuptiApiCallbackContext { - explicit CuptiApiCallbackContext(std::vector &&r) - : record_indices(std::move(r)) {} - std::vector record_indices; - }; - - const CuptiTracerOptions option_; - CuptiInterface *cupti_interface_; - CuptiTraceCollector *collector_; - absl::node_hash_set callback_contexts_; - std::vector> cuda_event_recorders_; - CuptiDriverApiHookWithCudaEvent(const CuptiDriverApiHookWithCudaEvent &) = - delete; - void operator=(const CuptiDriverApiHookWithCudaEvent &) = delete; -}; - /*static*/ std::string ErrorWithHostname(absl::string_view error_message) { return absl::StrCat(tsl::port::Hostname(), ": ", error_message); } @@ -1714,7 +1070,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { /*static*/ Status CuptiDriverApiHook::AddDriverApiCallbackEvent( CuptiTraceCollector *collector, CuptiInterface *cupti_interface, - int device_id, tsl::uint64 start_tsc, tsl::uint64 end_tsc, + int device_id, uint64_t start_tsc, uint64_t end_tsc, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata) { switch (cbid) { @@ -1890,30 +1246,21 @@ void CuptiTracer::Enable(const CuptiTracerOptions &option, CuptiTraceCollector *collector) { option_ = option; collector_ = collector; - if (option_->enable_event_based_activity) { - option_->enable_activity_api = false; - cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithCudaEvent( - option, cupti_interface_, collector)); - } else { - cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithActivityApi( - option, cupti_interface_, collector)); - } + + cupti_driver_api_hook_.reset(new CuptiDriverApiHookWithActivityApi( + option, cupti_interface_, collector)); Status status = EnableApiTracing(); need_root_access_ |= status.code() == tsl::error::PERMISSION_DENIED; if (!status.ok()) return; - if (option_->enable_activity_api) { - EnableActivityTracing().IgnoreError(); - } + EnableActivityTracing().IgnoreError(); tsl::profiler::AnnotationStack::Enable(true); } void CuptiTracer::Disable() { DisableApiTracing().IgnoreError(); - if (option_->enable_activity_api) { - DisableActivityTracing().IgnoreError(); - } + DisableActivityTracing().IgnoreError(); cupti_interface_->CleanUp(); Finalize().IgnoreError(); cupti_driver_api_hook_->SyncAndFlush().IgnoreError(); @@ -2028,7 +1375,7 @@ Status CuptiTracer::Finalize() { return OkStatus(); } -/*static*/ tsl::uint64 CuptiTracer::GetTimestamp() { +/*static*/ uint64_t CuptiTracer::GetTimestamp() { uint64_t tsc; CuptiInterface *cupti_interface = GetCuptiInterface(); if (cupti_interface && cupti_interface->GetTimestamp(&tsc) == CUPTI_SUCCESS) { @@ -2075,7 +1422,7 @@ Status CuptiTracer::HandleCallback(CUpti_CallbackDomain domain, } // Grab a correct device ID. - tsl::uint32 device_id = -1; + uint32_t device_id = -1; RETURN_IF_CUPTI_ERROR( cupti_interface_->GetDeviceId(cbdata->context, &device_id)); if (device_id >= num_gpus_) { @@ -2227,7 +1574,7 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, RETURN_IF_CUPTI_ERROR(cupti_interface_->ActivityGetNumDroppedRecords( context, stream_id, &dropped)); if (dropped != 0) { - tsl::uint32 device_id = -1; + uint32_t device_id = -1; RETURN_IF_CUPTI_ERROR(cupti_interface_->GetDeviceId(context, &device_id)); collector_->OnEventsDropped("cupti activity buffer full", dropped); } diff --git a/xla/backends/profiler/gpu/cupti_tracer.h b/xla/backends/profiler/gpu/cupti_tracer.h index a7ed77160d0e8..59583453779b7 100644 --- a/xla/backends/profiler/gpu/cupti_tracer.h +++ b/xla/backends/profiler/gpu/cupti_tracer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,12 @@ limitations under the License. #ifndef XLA_BACKENDS_PROFILER_GPU_CUPTI_TRACER_H_ #define XLA_BACKENDS_PROFILER_GPU_CUPTI_TRACER_H_ +#include "absl/status/status.h" #include "absl/types/optional.h" #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_interface.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" #include "tsl/profiler/utils/buffer_pool.h" @@ -30,13 +29,6 @@ namespace xla { namespace profiler { struct CuptiTracerOptions { - bool enable_activity_api = true; - - // Use cuda events to enclose the kernel/memcpy to measure device activity. - // enable_event_based_activity, if true, will override the enable_activity_api - // setting. - bool enable_event_based_activity = false; - bool required_callback_api_events = true; // The callback ids that will be enabled and monitored, if empty, all // Callback ids to be enabled using Callback API. @@ -59,18 +51,18 @@ class CuptiDriverApiHook { public: virtual ~CuptiDriverApiHook() {} - virtual tsl::Status OnDriverApiEnter( + virtual absl::Status OnDriverApiEnter( int device_id, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData* callback_info) = 0; - virtual tsl::Status OnDriverApiExit( + virtual absl::Status OnDriverApiExit( int device_id, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData* callback_info) = 0; - virtual tsl::Status SyncAndFlush() = 0; + virtual absl::Status SyncAndFlush() = 0; protected: - static tsl::Status AddDriverApiCallbackEvent( + static absl::Status AddDriverApiCallbackEvent( CuptiTraceCollector* collector, CuptiInterface* cupti_interface, - int device_id, tsl::uint64 start_tsc, tsl::uint64 end_tsc, + int device_id, uint64_t start_tsc, uint64_t end_tsc, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData* callback_info); }; @@ -94,8 +86,9 @@ class CuptiTracer { void Enable(const CuptiTracerOptions& option, CuptiTraceCollector* collector); void Disable(); - tsl::Status HandleCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid, - const CUpti_CallbackData* callback_info); + absl::Status HandleCallback(CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* callback_info); // Returns a buffer and its size for CUPTI to store activities. This buffer // will be reclaimed when CUPTI makes a callback to ProcessActivityBuffer. @@ -104,8 +97,8 @@ class CuptiTracer { // Parses CUPTI activity events from activity buffer, and emits events for // CuptiTraceCollector. This function is public because called from registered // callback. - tsl::Status ProcessActivityBuffer(CUcontext context, uint32_t stream_id, - uint8_t* buffer, size_t size); + absl::Status ProcessActivityBuffer(CUcontext context, uint32_t stream_id, + uint8_t* buffer, size_t size); static uint64_t GetTimestamp(); static int NumGpus(); @@ -120,14 +113,14 @@ class CuptiTracer { // Buffer size and alignment, 32K and 8 as in CUPTI samples. static constexpr size_t kBufferSizeInBytes = 32 * 1024; - tsl::Status EnableApiTracing(); - tsl::Status EnableActivityTracing(); - tsl::Status DisableApiTracing(); - tsl::Status DisableActivityTracing(); - tsl::Status Finalize(); + absl::Status EnableApiTracing(); + absl::Status EnableActivityTracing(); + absl::Status DisableApiTracing(); + absl::Status DisableActivityTracing(); + absl::Status Finalize(); void ConfigureActivityUnifiedMemoryCounter(bool enable); - tsl::Status HandleNVTXCallback(CUpti_CallbackId cbid, - const CUpti_CallbackData* cbdata); + absl::Status HandleNVTXCallback(CUpti_CallbackId cbid, + const CUpti_CallbackData* cbdata); int num_gpus_; std::optional option_; diff --git a/xla/backends/profiler/gpu/cupti_utils.cc b/xla/backends/profiler/gpu/cupti_utils.cc index f3a12d8256f7c..a4198811286be 100644 --- a/xla/backends/profiler/gpu/cupti_utils.cc +++ b/xla/backends/profiler/gpu/cupti_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,17 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/base/call_once.h" #include "absl/memory/memory.h" #include "xla/backends/profiler/gpu/cupti_error_manager.h" #include "xla/backends/profiler/gpu/cupti_interface.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/util/env_var.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/stringpiece.h" namespace xla { namespace profiler { +bool IsCuptiUseStubInterface() { + // TODO: b/149634979: Remove this after NVIDIA issue 4459155 resolved. + static constexpr tsl::StringPiece cupti_use_stub_interface_env = + "TF_GPU_CUPTI_USE_STUB_INTERFACE"; + static absl::once_flag once; // NOLINT(clang-diagnostic-unreachable-code) + static bool cupti_use_stub_interface = false; + absl::call_once(once, [&] { + tsl::ReadBoolFromEnvVar(cupti_use_stub_interface_env, false, + &cupti_use_stub_interface) + .IgnoreError(); + if (cupti_use_stub_interface) { + LOG(INFO) << cupti_use_stub_interface_env << " is set to true, " + << "XLA Profiler is using stub CUPTI interface to work around " + << "potential serious bug in CUPTI lib. Such control may be " + << "removed/disabled in future if the known issue is resolved!"; + } + }); + return cupti_use_stub_interface; +} + CuptiInterface* GetCuptiInterface() { static CuptiInterface* cupti_interface = - new CuptiErrorManager(std::make_unique()); + IsCuptiUseStubInterface() + ? new CuptiErrorManager(std::make_unique()) + : new CuptiErrorManager(std::make_unique()); return cupti_interface; } diff --git a/xla/backends/profiler/gpu/cupti_wrapper.cc b/xla/backends/profiler/gpu/cupti_wrapper.cc index 2c7b1b2ada00d..16565d9dd576f 100644 --- a/xla/backends/profiler/gpu/cupti_wrapper.cc +++ b/xla/backends/profiler/gpu/cupti_wrapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,8 +56,7 @@ CUptiResult CuptiWrapper::ActivityRegisterCallbacks( func_buffer_completed); } -CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, - tsl::uint32* deviceId) { +CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, uint32_t* deviceId) { return cuptiGetDeviceId(context, deviceId); } @@ -90,145 +89,6 @@ CUptiResult CuptiWrapper::Unsubscribe(CUpti_SubscriberHandle subscriber) { return cuptiUnsubscribe(subscriber); } -CUptiResult CuptiWrapper::DeviceEnumEventDomains( - CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array) { - return cuptiDeviceEnumEventDomains(device, array_size_bytes, domain_array); -} - -CUptiResult CuptiWrapper::DeviceGetEventDomainAttribute( - CUdevice device, CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) { - return cuptiDeviceGetEventDomainAttribute(device, event_domain, attrib, - value_size, value); -} - -CUptiResult CuptiWrapper::DisableKernelReplayMode(CUcontext context) { - return cuptiDisableKernelReplayMode(context); -} - -CUptiResult CuptiWrapper::EnableKernelReplayMode(CUcontext context) { - return cuptiEnableKernelReplayMode(context); -} - -CUptiResult CuptiWrapper::DeviceGetNumEventDomains(CUdevice device, - uint32_t* num_domains) { - return cuptiDeviceGetNumEventDomains(device, num_domains); -} - -CUptiResult CuptiWrapper::EventDomainEnumEvents( - CUpti_EventDomainID event_domain, size_t* array_size_bytes, - CUpti_EventID* event_array) { - return cuptiEventDomainEnumEvents(event_domain, array_size_bytes, - event_array); -} - -CUptiResult CuptiWrapper::EventDomainGetNumEvents( - CUpti_EventDomainID event_domain, uint32_t* num_events) { - return cuptiEventDomainGetNumEvents(event_domain, num_events); -} - -CUptiResult CuptiWrapper::EventGetAttribute(CUpti_EventID event, - CUpti_EventAttribute attrib, - size_t* value_size, void* value) { - return cuptiEventGetAttribute(event, attrib, value_size, value); -} - -CUptiResult CuptiWrapper::EventGetIdFromName(CUdevice device, - const char* event_name, - CUpti_EventID* event) { - return cuptiEventGetIdFromName(device, event_name, event); -} - -CUptiResult CuptiWrapper::EventGroupDisable(CUpti_EventGroup event_group) { - return cuptiEventGroupDisable(event_group); -} - -CUptiResult CuptiWrapper::EventGroupEnable(CUpti_EventGroup event_group) { - return cuptiEventGroupEnable(event_group); -} - -CUptiResult CuptiWrapper::EventGroupGetAttribute( - CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t* value_size, void* value) { - return cuptiEventGroupGetAttribute(event_group, attrib, value_size, value); -} - -CUptiResult CuptiWrapper::EventGroupReadEvent( - CUpti_EventGroup event_group, CUpti_ReadEventFlags flags, - CUpti_EventID event, size_t* event_value_buffer_size_bytes, - uint64_t* event_value_buffer) { - return cuptiEventGroupReadEvent(event_group, flags, event, - event_value_buffer_size_bytes, - event_value_buffer); -} - -CUptiResult CuptiWrapper::EventGroupSetAttribute( - CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t value_size, void* value) { - return cuptiEventGroupSetAttribute(event_group, attrib, value_size, value); -} - -CUptiResult CuptiWrapper::EventGroupSetsCreate( - CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, CUpti_EventGroupSets** event_group_passes) { - return cuptiEventGroupSetsCreate(context, event_id_array_size_bytes, - event_id_array, event_group_passes); -} - -CUptiResult CuptiWrapper::EventGroupSetsDestroy( - CUpti_EventGroupSets* event_group_sets) { - return cuptiEventGroupSetsDestroy(event_group_sets); -} - -// CUPTI metric API -CUptiResult CuptiWrapper::DeviceEnumMetrics(CUdevice device, - size_t* arraySizeBytes, - CUpti_MetricID* metricArray) { - return cuptiDeviceEnumMetrics(device, arraySizeBytes, metricArray); -} - -CUptiResult CuptiWrapper::DeviceGetNumMetrics(CUdevice device, - uint32_t* num_metrics) { - return cuptiDeviceGetNumMetrics(device, num_metrics); -} - -CUptiResult CuptiWrapper::MetricGetIdFromName(CUdevice device, - const char* metric_name, - CUpti_MetricID* metric) { - return cuptiMetricGetIdFromName(device, metric_name, metric); -} - -CUptiResult CuptiWrapper::MetricGetNumEvents(CUpti_MetricID metric, - uint32_t* num_events) { - return cuptiMetricGetNumEvents(metric, num_events); -} - -CUptiResult CuptiWrapper::MetricEnumEvents(CUpti_MetricID metric, - size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array) { - return cuptiMetricEnumEvents(metric, event_id_array_size_bytes, - event_id_array); -} - -CUptiResult CuptiWrapper::MetricGetAttribute(CUpti_MetricID metric, - CUpti_MetricAttribute attrib, - size_t* value_size, void* value) { - return cuptiMetricGetAttribute(metric, attrib, value_size, value); -} - -CUptiResult CuptiWrapper::MetricGetValue(CUdevice device, CUpti_MetricID metric, - size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - size_t event_value_array_size_bytes, - uint64_t* event_value_array, - uint64_t time_duration, - CUpti_MetricValue* metric_value) { - return cuptiMetricGetValue(device, metric, event_id_array_size_bytes, - event_id_array, event_value_array_size_bytes, - event_value_array, time_duration, metric_value); -} - CUptiResult CuptiWrapper::GetResultString(CUptiResult result, const char** str) { return cuptiGetResultString(result, str); diff --git a/xla/backends/profiler/gpu/cupti_wrapper.h b/xla/backends/profiler/gpu/cupti_wrapper.h index 658148106d272..ada4af91e6d48 100644 --- a/xla/backends/profiler/gpu/cupti_wrapper.h +++ b/xla/backends/profiler/gpu/cupti_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ class CuptiWrapper : public xla::profiler::CuptiInterface { CUpti_BuffersCallbackRequestFunc func_buffer_requested, CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override; - CUptiResult GetDeviceId(CUcontext context, tsl::uint32* deviceId) override; + CUptiResult GetDeviceId(CUcontext context, uint32_t* deviceId) override; CUptiResult GetTimestamp(uint64_t* timestamp) override; @@ -77,92 +77,75 @@ class CuptiWrapper : public xla::profiler::CuptiInterface { CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; - // CUPTI event API - CUptiResult DeviceEnumEventDomains( - CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array) override; - - CUptiResult DeviceGetEventDomainAttribute(CUdevice device, - CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, - size_t* value_size, - void* value) override; - - CUptiResult DisableKernelReplayMode(CUcontext context) override; + CUptiResult GetResultString(CUptiResult result, const char** str) override; - CUptiResult EnableKernelReplayMode(CUcontext context) override; + CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override; - CUptiResult DeviceGetNumEventDomains(CUdevice device, - uint32_t* num_domains) override; + CUptiResult GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) override; - CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, - size_t* array_size_bytes, - CUpti_EventID* event_array) override; + void CleanUp() override {} + bool Disabled() const override { return false; } - CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, - uint32_t* num_events) override; + private: + CuptiWrapper(const CuptiWrapper&) = delete; + void operator=(const CuptiWrapper&) = delete; +}; - CUptiResult EventGetAttribute(CUpti_EventID event, - CUpti_EventAttribute attrib, size_t* value_size, - void* value) override; +// This is an implementation of CuptiWrapper that implements all load bearing +// APIs as no-op. This is a stub that keeps XLA profiler functional, but all +// collected profiles will be empty. +class CuptiWrapperStub : public xla::profiler::CuptiInterface { + public: + CuptiWrapperStub() {} - CUptiResult EventGetIdFromName(CUdevice device, const char* event_name, - CUpti_EventID* event) override; + ~CuptiWrapperStub() override {} - CUptiResult EventGroupDisable(CUpti_EventGroup event_group) override; + // CUPTI activity API + CUptiResult ActivityDisable(CUpti_ActivityKind kind) override; - CUptiResult EventGroupEnable(CUpti_EventGroup event_group) override; + CUptiResult ActivityEnable(CUpti_ActivityKind kind) override; - CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t* value_size, void* value) override; + CUptiResult ActivityFlushAll(uint32_t flag) override; - CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group, - CUpti_ReadEventFlags flags, - CUpti_EventID event, - size_t* event_value_buffer_size_bytes, - uint64_t* event_value_buffer) override; + CUptiResult ActivityGetNextRecord(uint8_t* buffer, + size_t valid_buffer_size_bytes, + CUpti_Activity** record) override; - CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, - CUpti_EventGroupAttribute attrib, - size_t value_size, void* value) override; + CUptiResult ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) override; - CUptiResult EventGroupSetsCreate( - CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - CUpti_EventGroupSets** event_group_passes) override; + CUptiResult ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, + uint32_t count) override; - CUptiResult EventGroupSetsDestroy( - CUpti_EventGroupSets* event_group_sets) override; + CUptiResult ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override; - // CUPTI metric API - CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, - CUpti_MetricID* metricArray) override; + CUptiResult GetDeviceId(CUcontext context, uint32_t* deviceId) override; - CUptiResult DeviceGetNumMetrics(CUdevice device, - uint32_t* num_metrics) override; + CUptiResult GetTimestamp(uint64_t* timestamp) override; - CUptiResult MetricGetIdFromName(CUdevice device, const char* metric_name, - CUpti_MetricID* metric) override; + // cuptiFinalize is only defined in CUDA8 and above. + // To enable it in CUDA8, the environment variable CUPTI_ENABLE_FINALIZE must + // be set to 1. + CUptiResult Finalize() override; - CUptiResult MetricGetNumEvents(CUpti_MetricID metric, - uint32_t* num_events) override; + // CUPTI callback API + CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) override; - CUptiResult MetricEnumEvents(CUpti_MetricID metric, - size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array) override; + CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) override; - CUptiResult MetricGetAttribute(CUpti_MetricID metric, - CUpti_MetricAttribute attrib, - size_t* value_size, void* value) override; + CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, void* userdata) override; - CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, - size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - size_t event_value_array_size_bytes, - uint64_t* event_value_array, - uint64_t time_duration, - CUpti_MetricValue* metric_value) override; + CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; CUptiResult GetResultString(CUptiResult result, const char** str) override; @@ -176,8 +159,8 @@ class CuptiWrapper : public xla::profiler::CuptiInterface { bool Disabled() const override { return false; } private: - CuptiWrapper(const CuptiWrapper&) = delete; - void operator=(const CuptiWrapper&) = delete; + CuptiWrapperStub(const CuptiWrapperStub&) = delete; + void operator=(const CuptiWrapperStub&) = delete; }; } // namespace profiler diff --git a/xla/backends/profiler/gpu/cupti_wrapper_stub.cc b/xla/backends/profiler/gpu/cupti_wrapper_stub.cc new file mode 100644 index 0000000000000..945fe49e853de --- /dev/null +++ b/xla/backends/profiler/gpu/cupti_wrapper_stub.cc @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/backends/profiler/gpu/cupti_wrapper.h" + +namespace xla { +namespace profiler { + +CUptiResult CuptiWrapperStub::ActivityDisable(CUpti_ActivityKind kind) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityEnable(CUpti_ActivityKind kind) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityFlushAll(uint32_t flag) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityGetNextRecord( + uint8_t* buffer, size_t valid_buffer_size_bytes, CUpti_Activity** record) { + return CUPTI_ERROR_MAX_LIMIT_REACHED; +} + +CUptiResult CuptiWrapperStub::ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) { + *dropped = 0; + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::GetDeviceId(CUcontext context, + uint32_t* deviceId) { + return cuptiGetDeviceId(context, deviceId); +} + +CUptiResult CuptiWrapperStub::GetTimestamp(uint64_t* timestamp) { + return cuptiGetTimestamp(timestamp); +} + +CUptiResult CuptiWrapperStub::Finalize() { return CUPTI_SUCCESS; } + +CUptiResult CuptiWrapperStub::EnableCallback(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::EnableDomain(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, + void* userdata) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::Unsubscribe(CUpti_SubscriberHandle subscriber) { + return CUPTI_SUCCESS; +} + +CUptiResult CuptiWrapperStub::GetResultString(CUptiResult result, + const char** str) { + return cuptiGetResultString(result, str); +} + +CUptiResult CuptiWrapperStub::GetContextId(CUcontext context, + uint32_t* context_id) { + return cuptiGetContextId(context, context_id); +} + +CUptiResult CuptiWrapperStub::GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) { + return cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id); +} + +} // namespace profiler +} // namespace xla diff --git a/xla/backends/profiler/gpu/device_tracer_cuda.cc b/xla/backends/profiler/gpu/device_tracer_cuda.cc index 7aacb96615cdb..d7bb2524b6676 100644 --- a/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,9 +23,11 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/thread_annotations.h" @@ -33,16 +35,15 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/time_utils.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { +using absl::OkStatus; +using absl::Status; using tensorflow::ProfileOptions; using tensorflow::profiler::XSpace; -using tsl::OkStatus; using tsl::ReadBoolFromEnvVar; -using tsl::Status; // GpuTracer for GPU. class GpuTracer : public tsl::profiler::ProfilerInterface { @@ -130,12 +131,6 @@ Status GpuTracer::DoStart() { CUPTI_DRIVER_TRACE_CBID_cuStreamSynchronize, }; - bool use_cupti_activity_api = true; - ReadBoolFromEnvVar("TF_GPU_CUPTI_USE_ACTIVITY_API", true, - &use_cupti_activity_api) - .IgnoreError(); - options_.enable_event_based_activity = !use_cupti_activity_api; - bool trace_concurrent_kernels = false; ReadBoolFromEnvVar("TF_GPU_CUPTI_FORCE_CONCURRENT_KERNEL", true, &trace_concurrent_kernels) @@ -155,8 +150,8 @@ Status GpuTracer::DoStart() { CuptiTracerCollectorOptions collector_options; collector_options.num_gpus = cupti_tracer_->NumGpus(); - tsl::uint64 start_gputime_ns = CuptiTracer::GetTimestamp(); - tsl::uint64 start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); + uint64_t start_gputime_ns = CuptiTracer::GetTimestamp(); + uint64_t start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); cupti_collector_ = CreateCuptiCollector(collector_options, start_walltime_ns, start_gputime_ns); @@ -213,7 +208,7 @@ Status GpuTracer::CollectData(XSpace* space) { space->add_warnings(std::move(events_dropped)); } if (cupti_collector_) { - tsl::uint64 end_gpu_ns = CuptiTracer::GetTimestamp(); + uint64_t end_gpu_ns = CuptiTracer::GetTimestamp(); cupti_collector_->Export(space, end_gpu_ns); } return OkStatus(); diff --git a/xla/backends/profiler/gpu/device_tracer_rocm.cc b/xla/backends/profiler/gpu/device_tracer_rocm.cc index 7fdf43ef50587..81eb2d192ea09 100644 --- a/xla/backends/profiler/gpu/device_tracer_rocm.cc +++ b/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/backends/profiler/gpu/rocm_tracer.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" @@ -38,7 +40,6 @@ limitations under the License. #include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { @@ -65,872 +66,6 @@ using tsl::profiler::XLineBuilder; using tsl::profiler::XPlaneBuilder; using tsl::profiler::XSpace; -namespace { -// Set the all XLines of specified XPlane to starting walltime. -// Events time in both host and device planes are CUTPI timestamps. -// We set initial RocmTracer timestamp as start time for all lines to reflect -// this fact. Eventually we change line start time to corresponding -// start_walltime_ns to normalize with CPU wall time. -static void NormalizeTimeStamps(XPlaneBuilder* plane, - uint64_t start_walltime_ns) { - plane->ForEachLine([&](tsl::profiler::XLineBuilder line) { - line.SetTimestampNs(start_walltime_ns); - }); -} - -std::string GetDeviceXLineName( - int64_t stream_id, absl::flat_hash_set& event_types) { - std::string line_name = absl::StrCat("Stream #", stream_id); - event_types.erase(RocmTracerEventType::Unsupported); - if (event_types.empty()) return line_name; - std::vector type_names; - for (const auto event_type : event_types) { - type_names.emplace_back(GetRocmTracerEventTypeName(event_type)); - } - return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")"); -} - -} // namespace - -class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { - public: - RocmTraceCollectorImpl(const RocmTraceCollectorOptions& options, - uint64_t start_walltime_ns, uint64_t start_gputime_ns) - : RocmTraceCollector(options), - num_callback_events_(0), - num_activity_events_(0), - start_walltime_ns_(start_walltime_ns), - start_gputime_ns_(start_gputime_ns), - per_device_collector_(options.num_gpus) {} - - void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) override { - mutex_lock lock(event_maps_mutex_); - - if (event.source == RocmTracerEventSource::ApiCallback && !is_auxiliary) { - if (num_callback_events_ > options_.max_callback_api_events) { - OnEventsDropped("max callback event capacity reached", - event.correlation_id); - DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); - return; - } - num_callback_events_++; - } else if (event.source == RocmTracerEventSource::Activity && - event.domain == RocmTracerEventDomain::HIP_API) { - // we do not count HIP_OPS activities. - if (num_activity_events_ > options_.max_activity_api_events) { - OnEventsDropped("max activity event capacity reached", - event.correlation_id); - DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); - return; - } - num_activity_events_++; - } - - bool emplace_result = false; - if (event.source == RocmTracerEventSource::ApiCallback) { - auto& target_api_event_map = - (is_auxiliary) ? auxiliary_api_events_map_ : api_events_map_; - std::tie(std::ignore, emplace_result) = - target_api_event_map.emplace(event.correlation_id, std::move(event)); - } else if (event.source == RocmTracerEventSource::Activity) { - if (event.domain == RocmTracerEventDomain::HIP_API) { - std::tie(std::ignore, emplace_result) = - activity_api_events_map_.emplace(event.correlation_id, - std::move(event)); - } else if (event.domain == RocmTracerEventDomain::HCC_OPS) { - auto result = activity_ops_events_map_.emplace( - event.correlation_id, std::vector{}); - result.first->second.push_back(std::move(event)); - emplace_result = true; // we always accept Hip-Ops events - } - } - if (!emplace_result) { - OnEventsDropped("event with duplicate correlation_id was received.", - event.correlation_id); - DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); - } - } - - void OnEventsDropped(const std::string& reason, - uint32_t correlation_id) override { - LOG(INFO) << "RocmTracerEvent dropped (correlation_id=" << correlation_id - << ",) : " << reason << "."; - } - - void Flush() override { - mutex_lock lock(event_maps_mutex_); - auto& aggregated_events_ = ApiActivityInfoExchange(); - - VLOG(3) << "RocmTraceCollector collected " << num_callback_events_ - << " callback events, " << num_activity_events_ - << " activity events, and aggregated them into " - << aggregated_events_.size() << " events."; - - for (auto& event : aggregated_events_) { - if (event.device_id >= options_.num_gpus) { - OnEventsDropped("device id >= num gpus", event.correlation_id); - DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); - LOG(WARNING) << "A ROCm profiler event record with wrong device ID " - "dropped! Type=" - << GetRocmTracerEventTypeName(event.type); - continue; - } - - activity_api_events_map_.clear(); - activity_ops_events_map_.clear(); - api_events_map_.clear(); - auxiliary_api_events_map_.clear(); - - per_device_collector_[event.device_id].AddEvent(event); - } - - for (int i = 0; i < options_.num_gpus; ++i) { - per_device_collector_[i].SortByStartTime(); - } - } - - void Export(XSpace* space) { - uint64_t end_gputime_ns = RocmTracer::GetTimestamp(); - XPlaneBuilder host_plane(FindOrAddMutablePlaneWithName( - space, tsl::profiler::kRoctracerApiPlaneName)); - for (int i = 0; i < options_.num_gpus; ++i) { - std::string name = GpuPlaneName(i); - XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name)); - device_plane.SetId(i); - // Calculate device capabilities before flushing, so that device - // properties are available to the occupancy calculator in export(). - per_device_collector_[i].GetDeviceCapabilities(i, &device_plane); - per_device_collector_[i].Export(start_walltime_ns_, start_gputime_ns_, - end_gputime_ns, &device_plane, - &host_plane); - - NormalizeTimeStamps(&device_plane, start_walltime_ns_); - } - NormalizeTimeStamps(&host_plane, start_walltime_ns_); - } - - private: - std::atomic num_callback_events_; - std::atomic num_activity_events_; - uint64_t start_walltime_ns_; - uint64_t start_gputime_ns_; - - mutex event_maps_mutex_; - absl::flat_hash_map api_events_map_ - TF_GUARDED_BY(event_maps_mutex_); - absl::flat_hash_map activity_api_events_map_ - TF_GUARDED_BY(event_maps_mutex_); - - /* Some apis such as MEMSETD32 (based on an observation with ResNet50), - trigger multiple HIP ops domain activities. We keep them in a vector and - merge them with api activities at flush time. - */ - absl::flat_hash_map> - activity_ops_events_map_ TF_GUARDED_BY(event_maps_mutex_); - // This is for the APIs that we track because we need some information from - // them to populate the corresponding activity that we actually track. - absl::flat_hash_map auxiliary_api_events_map_ - TF_GUARDED_BY(event_maps_mutex_); - - const std::vector ApiActivityInfoExchange() { - /* Different from CUDA, roctracer activity records are not enough to fill a - TF event. For most of the activities, we need to enable the corresponding - API callsbacks (we call them auxiliary API callbacks) to capture the - necessary fields from them using the correlation id. The purpose of this - function is to let APIs and activities exchange information to reach a - state very similar to TF CUDA and getting ready to dump the event. - */ - - // Copying info from HIP-OPS activities to HIP-API activities - /*HIP-API activities <<==== HIP-OPS activities*/ - auto activity_api_events_map_iter = activity_api_events_map_.begin(); - while (activity_api_events_map_iter != activity_api_events_map_.end()) { - RocmTracerEvent& activity_api_event = - activity_api_events_map_iter->second; - - bool result = false; - switch (activity_api_event.type) { - case RocmTracerEventType::Kernel: - case RocmTracerEventType::Memset: { - // KERNEL & MEMSET - auto iter = - activity_ops_events_map_.find(activity_api_event.correlation_id); - result = (iter != activity_ops_events_map_.end()); - if (result) { - // since the key exist in the map, there should be at least one item - // in the vector - activity_api_event.device_id = iter->second.front().device_id; - activity_api_event.stream_id = iter->second.front().stream_id; - // we initialize the start time and end time based on the first - // element - activity_api_event.start_time_ns = - iter->second.front().start_time_ns; - activity_api_event.end_time_ns = iter->second.front().end_time_ns; - for (auto& kernel_activity_op : iter->second) { - activity_api_event.start_time_ns = - std::min(activity_api_event.start_time_ns, - kernel_activity_op.start_time_ns); - activity_api_event.end_time_ns = - std::max(activity_api_event.end_time_ns, - kernel_activity_op.end_time_ns); - } - } - break; - } - case RocmTracerEventType::MemcpyD2D: - case RocmTracerEventType::MemcpyH2D: - case RocmTracerEventType::MemcpyD2H: - case RocmTracerEventType::MemcpyOther: { - // MEMCPY - auto iter = - activity_ops_events_map_.find(activity_api_event.correlation_id); - result = (iter != activity_ops_events_map_.end()); - if (result) { - // since the key exist in the map, there should be at least one item - // in the vector - activity_api_event.device_id = iter->second.front().device_id; - activity_api_event.memcpy_info.destination = - iter->second.front() - .memcpy_info.destination; // similar to CUDA, it is the - // same as device_id - activity_api_event.stream_id = iter->second.front().stream_id; - /* IMPORTANT: it seems that the HCC timing is only valid for - * Synchronous memcpy activities*/ - if (!activity_api_event.memcpy_info.async) { - activity_api_event.start_time_ns = - iter->second.front().start_time_ns; - activity_api_event.end_time_ns = iter->second.front().end_time_ns; - for (auto& kernel_activity_op : iter->second) { - activity_api_event.start_time_ns = - std::min(activity_api_event.start_time_ns, - kernel_activity_op.start_time_ns); - activity_api_event.end_time_ns = - std::max(activity_api_event.end_time_ns, - kernel_activity_op.end_time_ns); - } - } - } - break; - } - default: - // nothing to do for the rest - result = true; - break; - } - if (!result) { - OnEventsDropped( - "A HIP-API activity with missing HIP-OPS activity was found", - activity_api_event.correlation_id); - DumpRocmTracerEvent(activity_api_event, 0, 0, ". Dropped!"); - activity_api_events_map_.erase(activity_api_events_map_iter++); - } else { - ++activity_api_events_map_iter; - } - } - - // the event vector to be returned - std::vector aggregated_events; - - // Copying info from HIP activities to HIP API callbacks - /*HIP-API call backs <<==== HIP-API activities*/ - for (auto& api_iter : api_events_map_) { - RocmTracerEvent& api_event = api_iter.second; - auto iter = activity_api_events_map_.find(api_event.correlation_id); - switch (api_event.type) { - /*KERNEL API*/ - case RocmTracerEventType::Kernel: { - aggregated_events.push_back(api_event); - break; - } - /*MEMCPY API*/ - case RocmTracerEventType::MemcpyD2H: - case RocmTracerEventType::MemcpyH2D: - case RocmTracerEventType::MemcpyD2D: - case RocmTracerEventType::MemcpyOther: { - if (iter != activity_api_events_map_.end()) { - api_event.device_id = iter->second.device_id; - api_event.memcpy_info.destination = - api_event.device_id; // Similar to CUDA - aggregated_events.push_back(api_event); - } else { - OnEventsDropped( - "A Memcpy event from HIP API discarded." - " Could not find the counterpart activity.", - api_event.correlation_id); - DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!"); - } - break; - } - /*MEMSET API*/ - case RocmTracerEventType::Memset: { - if (iter != activity_api_events_map_.end()) { - api_event.device_id = iter->second.device_id; - - aggregated_events.push_back(api_event); - } else { - OnEventsDropped( - "A Memset event from HIP API discarded." - " Could not find the counterpart activity.", - api_event.correlation_id); - DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!"); - } - break; - } - /*MALLOC API, FREE API*/ - case RocmTracerEventType::MemoryAlloc: - case RocmTracerEventType::MemoryFree: { - // no missing info - aggregated_events.push_back(api_event); - break; - } - /*SYNCHRONIZATION API*/ - case RocmTracerEventType::Synchronization: { - // no missing info - aggregated_events.push_back(api_event); - break; - } - default: - OnEventsDropped("Missing API-Activity information exchange. Dropped!", - api_event.correlation_id); - DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!"); - LOG(WARNING) << "A ROCm API event type with unimplemented activity " - "merge dropped! " - "Type=" - << GetRocmTracerEventTypeName(api_event.type); - break; - } // end switch(api_event.type) - } - - // Copying info from HIP API callbacks to HIP API activities - // API ACTIVITIES<<====API-CB - for (auto& activity_iter : activity_api_events_map_) { - RocmTracerEvent& activity_event = activity_iter.second; - // finding the corresponding activity either in the api_call backs or the - // axuilarities - auto iter = api_events_map_.find(activity_event.correlation_id); - - iter = (iter == api_events_map_.end()) - ? auxiliary_api_events_map_.find(activity_event.correlation_id) - : iter; - switch (activity_event.type) { - /*KERNEL ACTIVITY*/ - case RocmTracerEventType::Kernel: { - if (iter != api_events_map_.end() || - iter != auxiliary_api_events_map_.end()) { - activity_event.name = iter->second.name; - activity_event.kernel_info = iter->second.kernel_info; - aggregated_events.push_back(activity_event); - } else { - OnEventsDropped( - "A Kernel event activity was discarded." - " Could not find the counterpart API callback.", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - } - break; - } - /*MEMCPY ACTIVITY*/ - case RocmTracerEventType::MemcpyD2H: - case RocmTracerEventType::MemcpyH2D: - case RocmTracerEventType::MemcpyD2D: - case RocmTracerEventType::MemcpyOther: { - if (iter != api_events_map_.end() || - iter != auxiliary_api_events_map_.end()) { - activity_event.memcpy_info = iter->second.memcpy_info; - aggregated_events.push_back(activity_event); - } else { - OnEventsDropped( - "A Memcpy event activity was discarded." - " Could not find the counterpart API callback.", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - } - break; - } - /*MEMSET ACTIVITY*/ - case RocmTracerEventType::Memset: { - if (iter != api_events_map_.end() || - iter != auxiliary_api_events_map_.end()) { - activity_event.memset_info = iter->second.memset_info; - aggregated_events.push_back(activity_event); - - } else { - OnEventsDropped( - "A Memset event activity was discarded." - " Could not find the counterpart API callback.", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - } - break; - } - /*MALLOC ACTIVITY, FREE ACTIVITY*/ - case RocmTracerEventType::MemoryAlloc: - case RocmTracerEventType::MemoryFree: { - if (iter != api_events_map_.end() || - iter != auxiliary_api_events_map_.end()) { - activity_event.device_id = iter->second.device_id; - aggregated_events.push_back(activity_event); - } else { - OnEventsDropped( - "A Malloc/Free activity was discarded." - " Could not find the counterpart API callback.", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - } - break; - } - /*SYNCHRONIZATION ACTIVITY*/ - case RocmTracerEventType::Synchronization: { - if (iter != api_events_map_.end() || - iter != auxiliary_api_events_map_.end()) { - // CUDA does not provide device ID for these activities. - // Interestingly, TF-profiler by default set the device id to 0 for - // CuptiTracerEvent. - // RocmTracerEvent type, set device by default to an unvalid - // device-id value. To be consistent with CUDA (in terms of having a - // logically valid value for device id) we update the device-id to - // its correct value - activity_event.device_id = iter->second.device_id; - aggregated_events.push_back(activity_event); - } else { - OnEventsDropped( - "A sync event activity was discarded." - " Could not find the counterpart API callback.", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - } - break; - } - default: - OnEventsDropped("Missing API-Activity information exchange. Dropped!", - activity_event.correlation_id); - DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); - LOG(WARNING) << "A ROCm activity event with unimplemented API " - "callback merge dropped! " - "Type=" - << GetRocmTracerEventTypeName(activity_event.type); - break; - } // end switch(activity_event.type) - } - - return aggregated_events; - } - struct RocmDeviceOccupancyParams { - hipFuncAttributes attributes = {}; - int block_size = 0; - size_t dynamic_smem_size = 0; - void* func_ptr; - - friend bool operator==(const RocmDeviceOccupancyParams& lhs, - const RocmDeviceOccupancyParams& rhs) { - return 0 == memcmp(&lhs, &rhs, sizeof(lhs)); - } - - template - friend H AbslHashValue(H hash_state, - const RocmDeviceOccupancyParams& params) { - return H::combine( - std::move(hash_state), params.attributes.maxThreadsPerBlock, - params.attributes.numRegs, params.attributes.sharedSizeBytes, - params.attributes.maxDynamicSharedSizeBytes, params.block_size, - params.dynamic_smem_size, params.func_ptr); - } - }; - - struct OccupancyStats { - double occupancy_pct = 0.0; - int min_grid_size = 0; - int suggested_block_size = 0; - }; - struct CorrelationInfo { - CorrelationInfo(uint32_t t, uint32_t e) - : thread_id(t), enqueue_time_ns(e) {} - uint32_t thread_id; - uint64_t enqueue_time_ns; - }; - - struct PerDeviceCollector { - void GetDeviceCapabilities(int32_t device_ordinal, - XPlaneBuilder* device_plane) { - device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevVendor)), - kDeviceVendorAMD); - - if (hipGetDeviceProperties(&device_properties_, device_ordinal) != - hipSuccess) - return; - - auto clock_rate_in_khz = - device_properties_.clockRate; // this is also in Khz - if (clock_rate_in_khz) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapClockRateKHz)), - clock_rate_in_khz); - } - - auto core_count = device_properties_.multiProcessorCount; - if (core_count) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapCoreCount)), - core_count); - } - - auto mem_clock_khz = device_properties_.memoryClockRate; - auto mem_bus_width_bits = device_properties_.memoryBusWidth; - - if (mem_clock_khz && mem_bus_width_bits) { - // Times 2 because HBM is DDR memory; it gets two data bits per each - // data lane. - auto memory_bandwidth = - tsl::uint64{2} * (mem_clock_khz)*1000 * (mem_bus_width_bits) / 8; - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), - memory_bandwidth); - } - - size_t total_memory = device_properties_.totalGlobalMem; - if (total_memory) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemorySize)), - static_cast(total_memory)); - } - - auto compute_capability_major = device_properties_.major; - if (compute_capability_major) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMajor)), - compute_capability_major); - } - auto compute_capability_minor = device_properties_.minor; - if (compute_capability_minor) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMinor)), - compute_capability_minor); - } - } - - inline std::string ToXStat(const KernelDetails& kernel_info, - double occupancy_pct) { - return absl::StrCat( - "regs:", kernel_info.registers_per_thread, - " static_shared:", kernel_info.static_shared_memory_usage, - " dynamic_shared:", kernel_info.dynamic_shared_memory_usage, - " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",", - kernel_info.grid_z, " block:", kernel_info.block_x, ",", - kernel_info.block_y, ",", kernel_info.block_z, - " occ_pct:", occupancy_pct); - } - OccupancyStats GetOccupancy(const RocmDeviceOccupancyParams& params) const { - // TODO(rocm-profiler): hipOccupancyMaxActiveBlocksPerMultiprocessor only - // return hipSuccess for HIP_API_ID_hipLaunchKernel - - OccupancyStats stats; - int number_of_active_blocks; - hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor( - &number_of_active_blocks, params.func_ptr, params.block_size, - params.dynamic_smem_size); - - if (err != hipError_t::hipSuccess) { - return {}; - } - - stats.occupancy_pct = number_of_active_blocks * params.block_size * 100; - stats.occupancy_pct /= device_properties_.maxThreadsPerMultiProcessor; - - err = hipOccupancyMaxPotentialBlockSize( - &stats.min_grid_size, &stats.suggested_block_size, params.func_ptr, - params.dynamic_smem_size, 0); - - if (err != hipError_t::hipSuccess) { - return {}; - } - - return stats; - } - void AddEvent(const RocmTracerEvent& event) { - mutex_lock l(events_mutex); - if (event.source == RocmTracerEventSource::ApiCallback) { - // Cupti api callback events were used to populate launch times etc. - if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) { - correlation_info_.insert( - {event.correlation_id, - CorrelationInfo(event.thread_id, event.start_time_ns)}); - } - events.emplace_back(std::move(event)); - } else { - // Cupti activity events measure device times etc. - events.emplace_back(std::move(event)); - } - } - - void SortByStartTime() { - mutex_lock lock(events_mutex); - std::sort( - events.begin(), events.end(), - [](const RocmTracerEvent& event1, const RocmTracerEvent& event2) { - return event1.start_time_ns < event2.start_time_ns; - }); - } - - void CreateXEvent(const RocmTracerEvent& event, XPlaneBuilder* plane, - uint64_t start_gpu_ns, uint64_t end_gpu_ns, - XLineBuilder* line) { - if (event.start_time_ns < start_gpu_ns || - event.end_time_ns > end_gpu_ns || - event.start_time_ns > event.end_time_ns) { - VLOG(2) << "events have abnormal timestamps:" << event.name - << " start time(ns): " << event.start_time_ns - << " end time(ns): " << event.end_time_ns - << " start gpu(ns):" << start_gpu_ns - << " end gpu(ns):" << end_gpu_ns - << " corr. id:" << event.correlation_id; - return; - } - std::string kernel_name = tsl::port::MaybeAbiDemangle(event.name.c_str()); - if (kernel_name.empty()) { - kernel_name = GetRocmTracerEventTypeName(event.type); - } - XEventMetadata* event_metadata = - plane->GetOrCreateEventMetadata(std::move(kernel_name)); - XEventBuilder xevent = line->AddEvent(*event_metadata); - VLOG(7) << "Adding event to line=" << line->Id(); - xevent.SetTimestampNs(event.start_time_ns); - xevent.SetEndTimestampNs(event.end_time_ns); - if (event.source == RocmTracerEventSource::ApiCallback) { - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDeviceId)), - event.device_id); - } - if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) { - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCorrelationId)), - event.correlation_id); - } - if (!event.roctx_range.empty()) { - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kNVTXRange)), - *plane->GetOrCreateStatMetadata(event.roctx_range)); - } - // if (event.context_id != CuptiTracerEvent::kInvalidContextId) { - // xevent.AddStatValue( - // *plane->GetOrCreateStatMetadata( - // GetStatTypeStr(StatType::kContextId)), - // absl::StrCat("$$", static_cast(event.context_id))); - // } - - if (event.type == RocmTracerEventType::Kernel && - event.source == RocmTracerEventSource::Activity) { - RocmDeviceOccupancyParams params{}; - params.attributes.maxThreadsPerBlock = INT_MAX; - params.attributes.numRegs = - static_cast(event.kernel_info.registers_per_thread); - params.attributes.sharedSizeBytes = - event.kernel_info.static_shared_memory_usage; - // params.attributes.partitionedGCConfig = PARTITIONED_GC_OFF; - // params.attributes.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT; - params.attributes.maxDynamicSharedSizeBytes = 0; - params.block_size = static_cast(event.kernel_info.block_x * - event.kernel_info.block_y * - event.kernel_info.block_z); - - params.dynamic_smem_size = - event.kernel_info.dynamic_shared_memory_usage; - params.func_ptr = event.kernel_info.func_ptr; - - OccupancyStats& occ_stats = occupancy_cache_[params]; - if (occ_stats.occupancy_pct == 0.0) { - occ_stats = GetOccupancy(params); - } - xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( - StatType::kTheoreticalOccupancyPct)), - occ_stats.occupancy_pct); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( - StatType::kOccupancyMinGridSize)), - static_cast(occ_stats.min_grid_size)); - xevent.AddStatValue( - *plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kOccupancySuggestedBlockSize)), - static_cast(occ_stats.suggested_block_size)); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kKernelDetails)), - *plane->GetOrCreateStatMetadata(ToXStat( - event.kernel_info, occ_stats.occupancy_pct))); - } else if (event.type == RocmTracerEventType::MemcpyH2D || - event.type == RocmTracerEventType::MemcpyD2H || - event.type == RocmTracerEventType::MemcpyD2D || - event.type == RocmTracerEventType::MemcpyP2P || - event.type == RocmTracerEventType::MemcpyOther) { - VLOG(7) << "Add Memcpy stat"; - const auto& memcpy_info = event.memcpy_info; - std::string memcpy_details = absl::StrCat( - // TODO(rocm-profiler): we need to discover the memory kind similar - // to CUDA - "kind:", "Unknown", " size:", memcpy_info.num_bytes, - " dest:", memcpy_info.destination, " async:", memcpy_info.async); - xevent.AddStatValue( - *plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemcpyDetails)), - *plane->GetOrCreateStatMetadata(std::move(memcpy_details))); - } else if (event.type == RocmTracerEventType::MemoryAlloc) { - VLOG(7) << "Add MemAlloc stat"; - std::string value = - // TODO(rocm-profiler): we need to discover the memory kind similar - // to CUDA - absl::StrCat("kind:", "Unknown", - " num_bytes:", event.memalloc_info.num_bytes); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemallocDetails)), - *plane->GetOrCreateStatMetadata(std::move(value))); - } else if (event.type == RocmTracerEventType::MemoryFree) { - VLOG(7) << "Add MemFree stat"; - std::string value = - // TODO(rocm-profiler): we need to discover the memory kind similar - // to CUDA - absl::StrCat("kind:", "Unknown", - " num_bytes:", event.memalloc_info.num_bytes); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemFreeDetails)), - *plane->GetOrCreateStatMetadata(std::move(value))); - } else if (event.type == RocmTracerEventType::Memset) { - VLOG(7) << "Add Memset stat"; - auto value = - // TODO(rocm-profiler): we need to discover the memory kind similar - // to CUDA - absl::StrCat("kind:", "Unknown", - " num_bytes:", event.memset_info.num_bytes, - " async:", event.memset_info.async); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemsetDetails)), - *plane->GetOrCreateStatMetadata(std::move(value))); - } - // TODO(rocm-profiler): we need to support the following event type - /* else if (event.type == CuptiTracerEventType::MemoryResidency) { - VLOG(7) << "Add MemoryResidency stat"; - std::string value = absl::StrCat( - "kind:", GetMemoryKindName(event.memory_residency_info.kind), - " num_bytes:", event.memory_residency_info.num_bytes, - " addr:", event.memory_residency_info.address); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( - StatType::kMemoryResidencyDetails)), - *plane->GetOrCreateStatMetadata(std::move(value))); - } */ - - std::vector annotation_stack = - ParseAnnotationStack(event.annotation); - if (!annotation_stack.empty()) { - xevent.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), - *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name)); - } - // If multiple metadata have the same key name, show the values from the - // top of the stack (innermost annotation). Concatenate the values from - // "hlo_op". - absl::flat_hash_set key_set; - - for (auto annotation = annotation_stack.rbegin(); - annotation != annotation_stack.rend(); ++annotation) { - for (const Annotation::Metadata& metadata : annotation->metadata) { - if (key_set.insert(metadata.key).second) { - xevent.ParseAndAddStatValue( - *plane->GetOrCreateStatMetadata(metadata.key), metadata.value); - } - } - } - } - bool IsHostEvent(const RocmTracerEvent& event, tsl::int64* line_id) { - // DriverCallback(i.e. kernel launching) events are host events. - if (event.source == RocmTracerEventSource::ApiCallback) { - *line_id = event.thread_id; - return true; - } else { // activities - *line_id = event.stream_id; - return false; - } - - // TODO(rocm-profiler): do we have such a report in rocm? - // Non-overhead activity events are device events. - /* if (event.type != CuptiTracerEventType::Overhead) { - *line_id = event.stream_id; - return false; - } */ - // Overhead events can be associated with a thread or a stream, etc. - // If a valid thread id is specified, we consider it as a host event. - // - - if (event.stream_id != RocmTracerEvent::kInvalidStreamId) { - *line_id = event.stream_id; - return false; - } else if (event.thread_id != RocmTracerEvent::kInvalidThreadId && - event.thread_id != 0) { - *line_id = event.thread_id; - return true; - } else { - *line_id = tsl::profiler::kThreadIdOverhead; - return false; - } - } - void Export(uint64_t start_walltime_ns, uint64_t start_gputime_ns, - uint64_t end_gputime_ns, XPlaneBuilder* device_plane, - XPlaneBuilder* host_plane) { - int host_ev_cnt = 0, dev_ev_cnt = 0; - mutex_lock l(events_mutex); - // Tracking event types per line. - absl::flat_hash_map> - events_types_per_line; - for (const RocmTracerEvent& event : events) { - int64_t line_id = RocmTracerEvent::kInvalidThreadId; - bool is_host_event = IsHostEvent(event, &line_id); - - if (is_host_event) { - host_ev_cnt++; - } else { - dev_ev_cnt++; - } - - if (line_id == RocmTracerEvent::kInvalidThreadId || - line_id == RocmTracerEvent::kInvalidStreamId) { - VLOG(3) << "Ignoring event, type=" << static_cast(event.type); - continue; - } - auto* plane = is_host_event ? host_plane : device_plane; - VLOG(9) << "Event" - << " type=" << static_cast(event.type) - << " line_id=" << line_id - << (is_host_event ? " host plane=" : " device plane=") - << plane->Name(); - XLineBuilder line = plane->GetOrCreateLine(line_id); - line.SetTimestampNs(start_gputime_ns); - CreateXEvent(event, plane, start_gputime_ns, end_gputime_ns, &line); - events_types_per_line[line_id].emplace(event.type); - } - device_plane->ForEachLine([&](XLineBuilder line) { - line.SetName( - GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()])); - }); - host_plane->ForEachLine([&](XLineBuilder line) { - line.SetName(absl::StrCat("Host Threads/", line.Id())); - }); - events.clear(); - } - - mutex events_mutex; - std::vector events TF_GUARDED_BY(events_mutex); - absl::flat_hash_map correlation_info_ - TF_GUARDED_BY(events_mutex); - absl::flat_hash_map - occupancy_cache_; - hipDeviceProp_t device_properties_; - }; - - absl::FixedArray per_device_collector_; -}; - // GpuTracer for ROCm GPU. class GpuTracer : public profiler::ProfilerInterface { public: @@ -962,7 +97,7 @@ class GpuTracer : public profiler::ProfilerInterface { State profiling_state_ = State::kNotStarted; RocmTracer* rocm_tracer_; - std::unique_ptr rocm_trace_collector_; + std::unique_ptr rocm_trace_collector_; }; RocmTracerOptions GpuTracer::GetRocmTracerOptions() { @@ -1034,12 +169,6 @@ RocmTracerOptions GpuTracer::GetRocmTracerOptions() { hip_api_aux_ops.end()); options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, hip_api_domain_ops); - // options.api_callbacks.emplace(ACTIVITY_DOMAIN_ROCTX, empty_vec); - // options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec); - - // options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API, - // hip_api_domain_ops); - options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec); options.activity_tracing.emplace(ACTIVITY_DOMAIN_HCC_OPS, empty_vec); return options; @@ -1066,8 +195,12 @@ Status GpuTracer::DoStart() { GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus()); uint64_t start_gputime_ns = RocmTracer::GetTimestamp(); uint64_t start_walltime_ns = tsl::EnvTime::NowNanos(); - rocm_trace_collector_ = std::make_unique( + rocm_trace_collector_ = CreateRocmCollector( trace_collector_options, start_walltime_ns, start_gputime_ns); + // rocm_trace_collector_ = + // std::make_unique(trace_collector_options, + // start_walltime_ns, + // start_gputime_ns); RocmTracerOptions tracer_options = GetRocmTracerOptions(); rocm_tracer_->Enable(tracer_options, rocm_trace_collector_.get()); diff --git a/xla/backends/profiler/gpu/mock_cupti.h b/xla/backends/profiler/gpu/mock_cupti.h index 617406a0e17a0..2097e4049b32d 100644 --- a/xla/backends/profiler/gpu/mock_cupti.h +++ b/xla/backends/profiler/gpu/mock_cupti.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,85 +68,6 @@ class MockCupti : public xla::profiler::CuptiInterface { (override)); MOCK_METHOD(CUptiResult, Unsubscribe, (CUpti_SubscriberHandle subscriber), (override)); - MOCK_METHOD(CUptiResult, DeviceEnumEventDomains, - (CUdevice device, size_t* array_size_bytes, - CUpti_EventDomainID* domain_array), - (override)); - MOCK_METHOD(CUptiResult, DeviceGetEventDomainAttribute, - (CUdevice device, CUpti_EventDomainID event_domain, - CUpti_EventDomainAttribute attrib, size_t* value_size, - void* value), - (override)); - MOCK_METHOD(CUptiResult, DisableKernelReplayMode, (CUcontext context), - (override)); - MOCK_METHOD(CUptiResult, EnableKernelReplayMode, (CUcontext context), - (override)); - MOCK_METHOD(CUptiResult, DeviceGetNumEventDomains, - (CUdevice device, uint32_t* num_domains), (override)); - MOCK_METHOD(CUptiResult, EventDomainEnumEvents, - (CUpti_EventDomainID event_domain, size_t* array_size_bytes, - CUpti_EventID* event_array), - (override)); - MOCK_METHOD(CUptiResult, EventDomainGetNumEvents, - (CUpti_EventDomainID event_domain, uint32_t* num_events), - (override)); - MOCK_METHOD(CUptiResult, EventGetAttribute, - (CUpti_EventID event, CUpti_EventAttribute attrib, - size_t* value_size, void* value), - (override)); - MOCK_METHOD(CUptiResult, EventGetIdFromName, - (CUdevice device, const char* event_name, CUpti_EventID* event), - (override)); - MOCK_METHOD(CUptiResult, EventGroupDisable, (CUpti_EventGroup event_group), - (override)); - MOCK_METHOD(CUptiResult, EventGroupEnable, (CUpti_EventGroup event_group), - (override)); - MOCK_METHOD(CUptiResult, EventGroupGetAttribute, - (CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t* value_size, void* value), - (override)); - MOCK_METHOD(CUptiResult, EventGroupReadEvent, - (CUpti_EventGroup event_group, CUpti_ReadEventFlags flags, - CUpti_EventID event, size_t* event_value_buffer_size_bytes, - uint64_t* eventValueBuffer), - (override)); - MOCK_METHOD(CUptiResult, EventGroupSetAttribute, - (CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, - size_t value_size, void* value), - (override)); - MOCK_METHOD(CUptiResult, EventGroupSetsCreate, - (CUcontext context, size_t event_id_array_size_bytes, - CUpti_EventID* event_id_array, - CUpti_EventGroupSets** event_group_passes), - (override)); - MOCK_METHOD(CUptiResult, EventGroupSetsDestroy, - (CUpti_EventGroupSets * event_group_sets), (override)); - MOCK_METHOD(CUptiResult, DeviceEnumMetrics, - (CUdevice device, size_t* arraySizeBytes, - CUpti_MetricID* metricArray), - (override)); - MOCK_METHOD(CUptiResult, DeviceGetNumMetrics, - (CUdevice device, uint32_t* num_metrics), (override)); - MOCK_METHOD(CUptiResult, MetricGetIdFromName, - (CUdevice device, const char* metric_name, - CUpti_MetricID* metric), - (override)); - MOCK_METHOD(CUptiResult, MetricGetNumEvents, - (CUpti_MetricID metric, uint32_t* num_events), (override)); - MOCK_METHOD(CUptiResult, MetricEnumEvents, - (CUpti_MetricID metric, size_t* event_id_array_size_bytes, - CUpti_EventID* event_id_array), - (override)); - MOCK_METHOD(CUptiResult, MetricGetAttribute, - (CUpti_MetricID metric, CUpti_MetricAttribute attrib, - size_t* value_size, void* value), - (override)); - MOCK_METHOD(CUptiResult, MetricGetValue, - (CUdevice device, CUpti_MetricID metric, - size_t event_id_array_size_bytes, CUpti_EventID* event_id_array, - size_t event_value_array_size_bytes, uint64_t* event_value_array, - uint64_t time_duration, CUpti_MetricValue* metric_value), - (override)); MOCK_METHOD(CUptiResult, GetResultString, (CUptiResult result, const char** str), (override)); diff --git a/xla/backends/profiler/gpu/nvtx_utils.cc b/xla/backends/profiler/gpu/nvtx_utils.cc index 512c5fc2d1024..c4b38d899b081 100644 --- a/xla/backends/profiler/gpu/nvtx_utils.cc +++ b/xla/backends/profiler/gpu/nvtx_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/gpu/nvtx_utils.h b/xla/backends/profiler/gpu/nvtx_utils.h index 58be1980e2a29..43f0c91bf917f 100644 --- a/xla/backends/profiler/gpu/nvtx_utils.h +++ b/xla/backends/profiler/gpu/nvtx_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/gpu/rocm_collector.cc b/xla/backends/profiler/gpu/rocm_collector.cc new file mode 100644 index 0000000000000..41b21c486eb34 --- /dev/null +++ b/xla/backends/profiler/gpu/rocm_collector.cc @@ -0,0 +1,853 @@ + +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/profiler/gpu/rocm_collector.h" + +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "xla/stream_executor/rocm/roctracer_wrapper.h" +#include "xla/tsl/util/env_var.h" +#include "tsl/platform/abi.h" +#include "tsl/platform/env_time.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/status.h" +#include "tsl/platform/thread_annotations.h" +#include "tsl/platform/types.h" +#include "tsl/profiler/backends/cpu/annotation_stack.h" +#include "tsl/profiler/lib/profiler_factory.h" +#include "tsl/profiler/lib/profiler_interface.h" +#include "tsl/profiler/utils/parse_annotation.h" +#include "tsl/profiler/utils/xplane_builder.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" + +namespace xla { +namespace profiler { + +namespace se = ::stream_executor; +using tensorflow::ProfileOptions; +using tsl::mutex; +using tsl::mutex_lock; +// using tsl::OkStatus; +using tsl::Status; +using tsl::profiler::Annotation; +using tsl::profiler::AnnotationStack; +using tsl::profiler::FindOrAddMutablePlaneWithName; +using tsl::profiler::GetStatTypeStr; +using tsl::profiler::GpuPlaneName; +using tsl::profiler::kDeviceVendorAMD; +using tsl::profiler::kThreadIdOverhead; +using tsl::profiler::ParseAnnotationStack; +using tsl::profiler::ProfilerInterface; +// using tsl::profiler::RegisterProfilerFactory; +using tsl::profiler::StatType; +using tsl::profiler::XEventBuilder; +using tsl::profiler::XEventMetadata; +using tsl::profiler::XLineBuilder; +using tsl::profiler::XPlaneBuilder; +using tsl::profiler::XSpace; + +void AnnotationMap::Add(uint32_t correlation_id, + const std::string& annotation) { + if (annotation.empty()) return; + VLOG(3) << "Add annotation: " << " correlation_id=" << correlation_id + << ", annotation: " << annotation; + absl::MutexLock lock(&map_.mutex); + if (map_.annotations.size() < max_size_) { + absl::string_view annotation_str = + *map_.annotations.insert(annotation).first; + map_.correlation_map.emplace(correlation_id, annotation_str); + } +} + +absl::string_view AnnotationMap::LookUp(uint32_t correlation_id) { + absl::MutexLock lock(&map_.mutex); + auto it = map_.correlation_map.find(correlation_id); + return it != map_.correlation_map.end() ? it->second : absl::string_view(); +} + +//========== +namespace { +// Set the all XLines of specified XPlane to starting walltime. +// Events time in both host and device planes are CUTPI timestamps. +// We set initial RocmTracer timestamp as start time for all lines to reflect +// this fact. Eventually we change line start time to corresponding +// start_walltime_ns to normalize with CPU wall time. +static void NormalizeTimeStamps(XPlaneBuilder* plane, + uint64_t start_walltime_ns) { + plane->ForEachLine([&](tsl::profiler::XLineBuilder line) { + line.SetTimestampNs(start_walltime_ns); + }); +} + +std::string GetDeviceXLineName( + int64_t stream_id, absl::flat_hash_set& event_types) { + std::string line_name = absl::StrCat("Stream #", stream_id); + event_types.erase(RocmTracerEventType::Unsupported); + if (event_types.empty()) return line_name; + std::vector type_names; + for (const auto event_type : event_types) { + type_names.emplace_back(GetRocmTracerEventTypeName(event_type)); + } + return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")"); +} + +} // namespace + +static void DumpRocmTracerEvent(const RocmTracerEvent& event, + uint64_t start_walltime_ns, + uint64_t start_gputime_ns, + const std::string& message) { + std::ostringstream oss; + oss << "correlation_id=" << event.correlation_id; + oss << ",type=" << GetRocmTracerEventTypeName(event.type); + oss << ",source=" << GetRocmTracerEventSourceName(event.source); + oss << ",domain=" << GetRocmTracerEventDomainName(event.domain); + oss << ",name=" << event.name; + oss << ",annotation=" << event.annotation; + oss << ",start_time_us=" + << (start_walltime_ns + (start_gputime_ns - event.start_time_ns)) / 1000; + oss << ",duration=" << (event.end_time_ns - event.start_time_ns) / 1000; + oss << ",device_id=" << event.device_id; + oss << ",thread_id=" << event.thread_id; + oss << ",stream_id=" << event.stream_id; + + switch (event.type) { + case RocmTracerEventType::Kernel: + break; + case RocmTracerEventType::MemcpyD2H: + case RocmTracerEventType::MemcpyH2D: + case RocmTracerEventType::MemcpyD2D: + case RocmTracerEventType::MemcpyP2P: + oss << ",num_bytes=" << event.memcpy_info.num_bytes; + oss << ",destination=" << event.memcpy_info.destination; + oss << ",async=" << event.memcpy_info.async; + break; + case RocmTracerEventType::MemoryAlloc: + oss << ",num_bytes=" << event.memalloc_info.num_bytes; + break; + case RocmTracerEventType::Synchronization: + break; + case RocmTracerEventType::Generic: + break; + default: + DCHECK(false); + break; + } + oss << message; + VLOG(3) << oss.str(); +} + +static uint64_t get_timestamp() { + uint64_t ts; + if (se::wrap::roctracer_get_timestamp(&ts) != ROCTRACER_STATUS_SUCCESS) { + const char* errstr = se::wrap::roctracer_error_string(); + LOG(ERROR) << "function roctracer_get_timestamp failed with error " + << errstr; + // Return 0 on error. + return 0; + } + return ts; +} + +struct RocmDeviceOccupancyParams { + hipFuncAttributes attributes = {}; + int block_size = 0; + size_t dynamic_smem_size = 0; + void* func_ptr; + + friend bool operator==(const RocmDeviceOccupancyParams& lhs, + const RocmDeviceOccupancyParams& rhs) { + return 0 == memcmp(&lhs, &rhs, sizeof(lhs)); + } + + template + friend H AbslHashValue(H hash_state, + const RocmDeviceOccupancyParams& params) { + return H::combine( + std::move(hash_state), params.attributes.maxThreadsPerBlock, + params.attributes.numRegs, params.attributes.sharedSizeBytes, + params.attributes.maxDynamicSharedSizeBytes, params.block_size, + params.dynamic_smem_size, params.func_ptr); + } +}; + +struct OccupancyStats { + double occupancy_pct = 0.0; + int min_grid_size = 0; + int suggested_block_size = 0; +}; + +struct CorrelationInfo { + CorrelationInfo(uint32_t t, uint32_t e) : thread_id(t), enqueue_time_ns(e) {} + uint32_t thread_id; + uint64_t enqueue_time_ns; +}; + +class PerDeviceCollector { + private: + OccupancyStats GetOccupancy(const RocmDeviceOccupancyParams& params) const { + // TODO(rocm-profiler): hipOccupancyMaxActiveBlocksPerMultiprocessor only + // return hipSuccess for HIP_API_ID_hipLaunchKernel + + OccupancyStats stats; + int number_of_active_blocks; + hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &number_of_active_blocks, params.func_ptr, params.block_size, + params.dynamic_smem_size); + + if (err != hipError_t::hipSuccess) { + return {}; + } + + stats.occupancy_pct = number_of_active_blocks * params.block_size * 100; + stats.occupancy_pct /= device_properties_.maxThreadsPerMultiProcessor; + + err = hipOccupancyMaxPotentialBlockSize( + &stats.min_grid_size, &stats.suggested_block_size, params.func_ptr, + params.dynamic_smem_size, 0); + + if (err != hipError_t::hipSuccess) { + return {}; + } + + return stats; + } + + void CreateXEvent(const RocmTracerEvent& event, XPlaneBuilder* plane, + uint64_t start_gpu_ns, uint64_t end_gpu_ns, + XLineBuilder* line) { + if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || + event.start_time_ns > event.end_time_ns) { + VLOG(2) << "events have abnormal timestamps:" << event.name + << " start time(ns): " << event.start_time_ns + << " end time(ns): " << event.end_time_ns + << " start gpu(ns):" << start_gpu_ns + << " end gpu(ns):" << end_gpu_ns + << " corr. id:" << event.correlation_id; + return; + } + std::string kernel_name = tsl::port::MaybeAbiDemangle(event.name.c_str()); + if (kernel_name.empty()) { + kernel_name = GetRocmTracerEventTypeName(event.type); + } + XEventMetadata* event_metadata = + plane->GetOrCreateEventMetadata(std::move(kernel_name)); + XEventBuilder xevent = line->AddEvent(*event_metadata); + VLOG(7) << "Adding event to line=" << line->Id(); + xevent.SetTimestampNs(event.start_time_ns); + xevent.SetEndTimestampNs(event.end_time_ns); + if (event.source == RocmTracerEventSource::ApiCallback) { + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kDeviceId)), + event.device_id); + } + if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) { + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCorrelationId)), + event.correlation_id); + } + if (!event.roctx_range.empty()) { + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kNVTXRange)), + *plane->GetOrCreateStatMetadata(event.roctx_range)); + } + + if (event.type == RocmTracerEventType::Kernel && + event.source == RocmTracerEventSource::Activity) { + RocmDeviceOccupancyParams params{}; + params.attributes.maxThreadsPerBlock = INT_MAX; + params.attributes.numRegs = + static_cast(event.kernel_info.registers_per_thread); + params.attributes.sharedSizeBytes = + event.kernel_info.static_shared_memory_usage; + // params.attributes.partitionedGCConfig = PARTITIONED_GC_OFF; + // params.attributes.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT; + params.attributes.maxDynamicSharedSizeBytes = 0; + params.block_size = static_cast(event.kernel_info.block_x * + event.kernel_info.block_y * + event.kernel_info.block_z); + + params.dynamic_smem_size = event.kernel_info.dynamic_shared_memory_usage; + params.func_ptr = event.kernel_info.func_ptr; + + OccupancyStats& occ_stats = occupancy_cache_[params]; + if (occ_stats.occupancy_pct == 0.0) { + occ_stats = GetOccupancy(params); + } + xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( + StatType::kTheoreticalOccupancyPct)), + occ_stats.occupancy_pct); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kOccupancyMinGridSize)), + static_cast(occ_stats.min_grid_size)); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kOccupancySuggestedBlockSize)), + static_cast(occ_stats.suggested_block_size)); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kKernelDetails)), + *plane->GetOrCreateStatMetadata(ToXStat( + event.kernel_info, occ_stats.occupancy_pct))); + } else if (event.type == RocmTracerEventType::MemcpyH2D || + event.type == RocmTracerEventType::MemcpyD2H || + event.type == RocmTracerEventType::MemcpyD2D || + event.type == RocmTracerEventType::MemcpyP2P || + event.type == RocmTracerEventType::MemcpyOther) { + VLOG(7) << "Add Memcpy stat"; + const auto& memcpy_info = event.memcpy_info; + std::string memcpy_details = absl::StrCat( + // TODO(rocm-profiler): we need to discover the memory kind similar + // to CUDA + "kind:", "Unknown", " size:", memcpy_info.num_bytes, + " dest:", memcpy_info.destination, " async:", memcpy_info.async); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemcpyDetails)), + *plane->GetOrCreateStatMetadata(std::move(memcpy_details))); + } else if (event.type == RocmTracerEventType::MemoryAlloc) { + VLOG(7) << "Add MemAlloc stat"; + std::string value = + // TODO(rocm-profiler): we need to discover the memory kind similar + // to CUDA + absl::StrCat("kind:", "Unknown", + " num_bytes:", event.memalloc_info.num_bytes); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemallocDetails)), + *plane->GetOrCreateStatMetadata(std::move(value))); + } else if (event.type == RocmTracerEventType::MemoryFree) { + VLOG(7) << "Add MemFree stat"; + std::string value = + // TODO(rocm-profiler): we need to discover the memory kind similar + // to CUDA + absl::StrCat("kind:", "Unknown", + " num_bytes:", event.memalloc_info.num_bytes); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemFreeDetails)), + *plane->GetOrCreateStatMetadata(std::move(value))); + } else if (event.type == RocmTracerEventType::Memset) { + VLOG(7) << "Add Memset stat"; + auto value = + // TODO(rocm-profiler): we need to discover the memory kind similar + // to CUDA + absl::StrCat("kind:", "Unknown", + " num_bytes:", event.memset_info.num_bytes, + " async:", event.memset_info.async); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemsetDetails)), + *plane->GetOrCreateStatMetadata(std::move(value))); + } + // TODO(rocm-profiler): we need to support the following event type + /* else if (event.type == CuptiTracerEventType::MemoryResidency) { + VLOG(7) << "Add MemoryResidency stat"; + std::string value = absl::StrCat( + "kind:", GetMemoryKindName(event.memory_residency_info.kind), + " num_bytes:", event.memory_residency_info.num_bytes, + " addr:", event.memory_residency_info.address); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( + StatType::kMemoryResidencyDetails)), + *plane->GetOrCreateStatMetadata(std::move(value))); + } */ + + std::vector annotation_stack = + ParseAnnotationStack(event.annotation); + if (!annotation_stack.empty()) { + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)), + *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name)); + } + // If multiple metadata have the same key name, show the values from the + // top of the stack (innermost annotation). Concatenate the values from + // "hlo_op". + absl::flat_hash_set key_set; + + for (auto annotation = annotation_stack.rbegin(); + annotation != annotation_stack.rend(); ++annotation) { + for (const Annotation::Metadata& metadata : annotation->metadata) { + if (key_set.insert(metadata.key).second) { + xevent.ParseAndAddStatValue( + *plane->GetOrCreateStatMetadata(metadata.key), metadata.value); + } + } + } + } + + void SortByStartTime() { + mutex_lock lock(events_mutex); + std::sort(events.begin(), events.end(), + [](const RocmTracerEvent& event1, const RocmTracerEvent& event2) { + return event1.start_time_ns < event2.start_time_ns; + }); + } + + bool IsHostEvent(const RocmTracerEvent& event, tsl::int64* line_id) { + // DriverCallback(i.e. kernel launching) events are host events. + if (event.source == RocmTracerEventSource::ApiCallback) { + *line_id = event.thread_id; + return true; + } else { // activities + *line_id = event.stream_id; + return false; + } + + // TODO(rocm-profiler): do we have such a report in rocm? + // Non-overhead activity events are device events. + /* if (event.type != CuptiTracerEventType::Overhead) { + *line_id = event.stream_id; + return false; + } */ + // Overhead events can be associated with a thread or a stream, etc. + // If a valid thread id is specified, we consider it as a host event. + // + + if (event.stream_id != RocmTracerEvent::kInvalidStreamId) { + *line_id = event.stream_id; + return false; + } else if (event.thread_id != RocmTracerEvent::kInvalidThreadId && + event.thread_id != 0) { + *line_id = event.thread_id; + return true; + } else { + *line_id = tsl::profiler::kThreadIdOverhead; + return false; + } + } + + public: + void Export(uint64_t start_walltime_ns, uint64_t start_gputime_ns, + uint64_t end_gputime_ns, XPlaneBuilder* device_plane, + XPlaneBuilder* host_plane) { + int host_ev_cnt = 0, dev_ev_cnt = 0; + mutex_lock l(events_mutex); + // Tracking event types per line. + absl::flat_hash_map> + events_types_per_line; + for (const RocmTracerEvent& event : events) { + int64_t line_id = RocmTracerEvent::kInvalidThreadId; + bool is_host_event = IsHostEvent(event, &line_id); + + if (is_host_event) { + host_ev_cnt++; + } else { + dev_ev_cnt++; + } + + if (line_id == RocmTracerEvent::kInvalidThreadId || + line_id == RocmTracerEvent::kInvalidStreamId) { + VLOG(3) << "Ignoring event, type=" << static_cast(event.type); + continue; + } + auto* plane = is_host_event ? host_plane : device_plane; + VLOG(9) << "Event" << " type=" << static_cast(event.type) + << " line_id=" << line_id + << (is_host_event ? " host plane=" : " device plane=") + << plane->Name(); + XLineBuilder line = plane->GetOrCreateLine(line_id); + line.SetTimestampNs(start_gputime_ns); + CreateXEvent(event, plane, start_gputime_ns, end_gputime_ns, &line); + events_types_per_line[line_id].emplace(event.type); + } + device_plane->ForEachLine([&](XLineBuilder line) { + line.SetName( + GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()])); + }); + host_plane->ForEachLine([&](XLineBuilder line) { + line.SetName(absl::StrCat("Host Threads/", line.Id())); + }); + events.clear(); + } + + PerDeviceCollector() = default; + + void AddEvent(const RocmTracerEvent& event) { + mutex_lock l(events_mutex); + if (event.source == RocmTracerEventSource::ApiCallback) { + // Cupti api callback events were used to populate launch times etc. + if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) { + correlation_info_.insert( + {event.correlation_id, + CorrelationInfo(event.thread_id, event.start_time_ns)}); + } + events.emplace_back(std::move(event)); + } else { + // Cupti activity events measure device times etc. + events.emplace_back(std::move(event)); + } + } + + void GetDeviceCapabilities(int32_t device_ordinal, + XPlaneBuilder* device_plane) { + device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevVendor)), + kDeviceVendorAMD); + + if (hipGetDeviceProperties(&device_properties_, device_ordinal) != + hipSuccess) + return; + + auto clock_rate_in_khz = + device_properties_.clockRate; // this is also in Khz + if (clock_rate_in_khz) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapClockRateKHz)), + clock_rate_in_khz); + } + + auto core_count = device_properties_.multiProcessorCount; + if (core_count) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapCoreCount)), + core_count); + } + + auto mem_clock_khz = device_properties_.memoryClockRate; + auto mem_bus_width_bits = device_properties_.memoryBusWidth; + + if (mem_clock_khz && mem_bus_width_bits) { + // Times 2 because HBM is DDR memory; it gets two data bits per each + // data lane. + auto memory_bandwidth = + uint64_t{2} * (mem_clock_khz) * 1000 * (mem_bus_width_bits) / 8; + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), + memory_bandwidth); + } + + size_t total_memory = device_properties_.totalGlobalMem; + if (total_memory) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemorySize)), + static_cast(total_memory)); + } + + auto compute_capability_major = device_properties_.major; + if (compute_capability_major) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMajor)), + compute_capability_major); + } + auto compute_capability_minor = device_properties_.minor; + if (compute_capability_minor) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMinor)), + compute_capability_minor); + } + } + + private: + mutex events_mutex; + std::vector events TF_GUARDED_BY(events_mutex); + absl::flat_hash_map correlation_info_ + TF_GUARDED_BY(events_mutex); + absl::flat_hash_map + occupancy_cache_; + hipDeviceProp_t device_properties_; +}; + +class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { + public: + RocmTraceCollectorImpl(const RocmTraceCollectorOptions& options, + uint64_t start_walltime_ns, uint64_t start_gputime_ns) + : RocmTraceCollector(options), + num_callback_events_(0), + num_activity_events_(0), + start_walltime_ns_(start_walltime_ns), + start_gputime_ns_(start_gputime_ns), + num_gpus_(options.num_gpus), + per_device_collector_(options.num_gpus) {} + + void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) override; + void Flush() override; + void Export(XSpace* space) override; + + void OnEventsDropped(const std::string& reason, + uint32_t correlation_id) override { + LOG(INFO) << "RocmTracerEvent dropped (correlation_id=" << correlation_id + << ",) : " << reason << "."; + } + + private: + std::atomic num_callback_events_; + std::atomic num_activity_events_; + uint64_t start_walltime_ns_; + uint64_t start_gputime_ns_; + int num_gpus_; + + mutex event_maps_mutex_; + absl::flat_hash_map api_events_map_ + TF_GUARDED_BY(event_maps_mutex_); + + /* Some apis such as MEMSETD32 (based on an observation with ResNet50), + trigger multiple HIP ops domain activities. We keep them in a vector and + merge them with api activities at flush time. + */ + absl::flat_hash_map> + activity_ops_events_map_ TF_GUARDED_BY(event_maps_mutex_); + // This is for the APIs that we track because we need some information from + // them to populate the corresponding activity that we actually track. + absl::flat_hash_map auxiliary_api_events_map_ + TF_GUARDED_BY(event_maps_mutex_); + + const std::vector ApiActivityInfoExchange(); + + absl::flat_hash_map per_device_collector_; +}; +//========== + +void RocmTraceCollectorImpl::AddEvent(RocmTracerEvent&& event, + bool is_auxiliary) { + mutex_lock lock(event_maps_mutex_); + + if (event.source == RocmTracerEventSource::ApiCallback && !is_auxiliary) { + if (num_callback_events_ > options_.max_callback_api_events) { + OnEventsDropped("max callback event capacity reached", + event.correlation_id); + DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); + return; + } + num_callback_events_++; + } else if (event.source == RocmTracerEventSource::Activity && + event.domain == RocmTracerEventDomain::HIP_API) { + // we do not count HIP_OPS activities. + if (num_activity_events_ > options_.max_activity_api_events) { + OnEventsDropped("max activity event capacity reached", + event.correlation_id); + DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); + return; + } + num_activity_events_++; + } + + bool emplace_result = false; + if (event.source == RocmTracerEventSource::ApiCallback) { + auto& target_api_event_map = + (is_auxiliary) ? auxiliary_api_events_map_ : api_events_map_; + std::tie(std::ignore, emplace_result) = + target_api_event_map.emplace(event.correlation_id, std::move(event)); + } else if (event.source == RocmTracerEventSource::Activity) { + auto result = activity_ops_events_map_.emplace( + event.correlation_id, std::vector{}); + result.first->second.push_back(std::move(event)); + emplace_result = true; // we always accept Hip-Ops events + } + if (!emplace_result) { + OnEventsDropped("event with duplicate correlation_id was received.", + event.correlation_id); + DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); + } +} + +void RocmTraceCollectorImpl::Flush() { + mutex_lock lock(event_maps_mutex_); + auto& aggregated_events_ = ApiActivityInfoExchange(); + + VLOG(3) << "RocmTraceCollector collected " << num_callback_events_ + << " callback events, " << num_activity_events_ + << " activity events, and aggregated them into " + << aggregated_events_.size() << " events."; + + // device ids for GPUs filled in by roctracer are not zero indexed. + // They are offset by number of CPUs on the machine + tsl::uint32 min_device_id = INT32_MAX; + ; + for (auto& event : aggregated_events_) { + if (event.device_id < min_device_id) { + min_device_id = event.device_id; + } + } + + for (auto event : aggregated_events_) { + event.device_id = event.device_id - min_device_id; + if (event.device_id < num_gpus_) { + per_device_collector_[event.device_id].AddEvent(event); + } else { + OnEventsDropped("Invalid device id for an event.", event.correlation_id); + DumpRocmTracerEvent(event, 0, 0, ". Dropped!"); + } + } + + activity_ops_events_map_.clear(); + api_events_map_.clear(); + auxiliary_api_events_map_.clear(); +} + +void RocmTraceCollectorImpl::Export(XSpace* space) { + uint64_t end_gputime_ns = get_timestamp(); + XPlaneBuilder host_plane(FindOrAddMutablePlaneWithName( + space, tsl::profiler::kRoctracerApiPlaneName)); + + for (int device_ordinal = 0; device_ordinal < num_gpus_; ++device_ordinal) { + std::string name = GpuPlaneName(device_ordinal); + XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name)); + device_plane.SetId(device_ordinal); + // Calculate device capabilities before flushing, so that device + // properties are available to the occupancy calculator in export(). + per_device_collector_[device_ordinal].GetDeviceCapabilities(device_ordinal, + &device_plane); + per_device_collector_[device_ordinal].Export( + start_walltime_ns_, start_gputime_ns_, end_gputime_ns, &device_plane, + &host_plane); + NormalizeTimeStamps(&device_plane, start_walltime_ns_); + } + NormalizeTimeStamps(&host_plane, start_walltime_ns_); +} + +const std::vector +RocmTraceCollectorImpl::ApiActivityInfoExchange() { + /* Different from CUDA, roctracer activity records are not enough to fill a + TF event. For most of the activities, we need to enable the corresponding + API callsbacks (we call them auxiliary API callbacks) to capture the + necessary fields from them using the correlation id. The purpose of this + function is to let APIs and activities exchange information to reach a + state very similar to TF CUDA and getting ready to dump the event. + */ + + std::vector aggregated_events; + + // Copy info from activity events to API callback events + for (auto& api_iter : api_events_map_) { + RocmTracerEvent& api_event = api_iter.second; + auto activity_event = + activity_ops_events_map_.find(api_event.correlation_id); + + if (activity_event == activity_ops_events_map_.end()) { + OnEventsDropped( + "An event from HIP API discarded." + "Could not find the counterpart activity.", + api_event.correlation_id); + DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!"); + } else { + api_event.device_id = activity_event->second.front().device_id; + api_event.stream_id = activity_event->second.front().stream_id; + switch (api_event.type) { + case RocmTracerEventType::Kernel: + case RocmTracerEventType::Memset: + case RocmTracerEventType::MemoryAlloc: + case RocmTracerEventType::MemoryFree: + case RocmTracerEventType::Synchronization: { + aggregated_events.push_back(api_event); + break; + } + case RocmTracerEventType::MemcpyD2H: + case RocmTracerEventType::MemcpyH2D: + case RocmTracerEventType::MemcpyD2D: + case RocmTracerEventType::MemcpyOther: { + api_event.memcpy_info.destination = + activity_event->second.front().device_id; + aggregated_events.push_back(api_event); + break; + } + default: + OnEventsDropped("Missing API-Activity information exchange. Dropped!", + api_event.correlation_id); + DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!"); + LOG(WARNING) << "A ROCm API event type with unimplemented activity " + "merge dropped! " + "Type=" + << GetRocmTracerEventTypeName(api_event.type); + } + } + } + + // Make sure for all activity events we have API callback events + for (auto& activity_iter : activity_ops_events_map_) { + RocmTracerEvent& activity_event = activity_iter.second.front(); + auto api_event = api_events_map_.find(activity_event.correlation_id); + + if (api_event == api_events_map_.end()) { + api_event = auxiliary_api_events_map_.find(activity_event.correlation_id); + } + + if (api_event == auxiliary_api_events_map_.end()) { + OnEventsDropped( + "An event from activity was discarded." + "Could not find the counterpart HIP API.", + activity_event.correlation_id); + DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); + } else { + switch (activity_event.type) { + // KERNEL ACTIVITY + case RocmTracerEventType::Kernel: { + activity_event.name = api_event->second.name; + activity_event.kernel_info = api_event->second.kernel_info; + aggregated_events.push_back(activity_event); + break; + } + // MEMCPY ACTIVITY + case RocmTracerEventType::MemcpyD2H: + case RocmTracerEventType::MemcpyH2D: + case RocmTracerEventType::MemcpyD2D: + case RocmTracerEventType::MemcpyOther: { + activity_event.memcpy_info = api_event->second.memcpy_info; + aggregated_events.push_back(activity_event); + break; + } + // MEMSET ACTIVITY + case RocmTracerEventType::Memset: { + activity_event.memset_info = api_event->second.memset_info; + aggregated_events.push_back(activity_event); + break; + } + // MALLOC ACTIVITY, FREE ACTIVITY + case RocmTracerEventType::MemoryAlloc: + case RocmTracerEventType::MemoryFree: { + activity_event.device_id = api_event->second.device_id; + aggregated_events.push_back(activity_event); + break; + } + // SYNCHRONIZATION ACTIVITY + case RocmTracerEventType::Synchronization: { + activity_event.device_id = api_event->second.device_id; + aggregated_events.push_back(activity_event); + break; + } + default: + OnEventsDropped("Missing API-Activity information exchange. Dropped!", + activity_event.correlation_id); + DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!"); + LOG(WARNING) << "A ROCm activity event with unimplemented API " + "callback merge dropped! " + "Type=" + << GetRocmTracerEventTypeName(activity_event.type); + break; + } + } + } + + return aggregated_events; +} + +std::unique_ptr CreateRocmCollector( + const RocmTraceCollectorOptions& options, const uint64_t start_walltime_ns, + const uint64_t start_gputime_ns) { + return std::make_unique(options, start_walltime_ns, + start_gputime_ns); +} + +} // namespace profiler +} // namespace xla diff --git a/xla/backends/profiler/gpu/rocm_collector.h b/xla/backends/profiler/gpu/rocm_collector.h new file mode 100644 index 0000000000000..af8bc26f97aa2 --- /dev/null +++ b/xla/backends/profiler/gpu/rocm_collector.h @@ -0,0 +1,227 @@ +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_PROFILER_GPU_ROCM_COLLECTOR_H_ +#define XLA_BACKENDS_PROFILER_GPU_ROCM_COLLECTOR_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_set.h" +#include "tsl/profiler/utils/xplane_builder.h" + +namespace xla { +namespace profiler { + +using tsl::profiler::XSpace; + +struct MemcpyDetails { + // The amount of data copied for memcpy events. + size_t num_bytes; + // The destination device for peer-2-peer communication (memcpy). The source + // device is implicit: it's the current device. + uint32_t destination; + // Whether or not the memcpy is asynchronous. + bool async; +}; + +struct MemAllocDetails { + // The amount of data requested for cudaMalloc events. + uint64_t num_bytes; +}; + +struct MemsetDetails { + // The number of memory elements getting set + size_t num_bytes; + // Whether or not the memset is asynchronous. + bool async; +}; + +struct KernelDetails { + // The number of registers used in this kernel. + uint32_t registers_per_thread; + // The amount of shared memory space used by a thread block. + uint32_t static_shared_memory_usage; + // The amount of dynamic memory space used by a thread block. + uint32_t dynamic_shared_memory_usage; + // X-dimension of a thread block. + uint32_t block_x; + // Y-dimension of a thread block. + uint32_t block_y; + // Z-dimension of a thread block. + uint32_t block_z; + // X-dimension of a grid. + uint32_t grid_x; + // Y-dimension of a grid. + uint32_t grid_y; + // Z-dimension of a grid. + uint32_t grid_z; + + // kernel address. Used for calculating core occupancy + void* func_ptr; +}; + +inline std::string ToXStat(const KernelDetails& kernel_info, + double occupancy_pct) { + return absl::StrCat( + "regs:", kernel_info.registers_per_thread, + " static_shared:", kernel_info.static_shared_memory_usage, + " dynamic_shared:", kernel_info.dynamic_shared_memory_usage, + " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",", + kernel_info.grid_z, " block:", kernel_info.block_x, ",", + kernel_info.block_y, ",", kernel_info.block_z, + " occ_pct:", occupancy_pct); +} + +enum class RocmTracerEventType { + Unsupported = 0, + Kernel, + MemcpyH2D, + MemcpyD2H, + MemcpyD2D, + MemcpyP2P, + MemcpyOther, + MemoryAlloc, + MemoryFree, + Memset, + Synchronization, + Generic, +}; + +const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type); + +enum class RocmTracerEventSource { + Invalid = 0, + ApiCallback, + Activity, +}; + +const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source); + +enum class RocmTracerEventDomain { + InvalidDomain = 0, + HIP_API, + HCC_OPS, // TODO(rocm-profiler): renme this to HIP_OPS +}; +const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain); +// RocmTracerSyncTypes forward decleration +enum class RocmTracerSyncTypes; + +struct SynchronizationDetails { + RocmTracerSyncTypes sync_type; +}; + +struct RocmTracerEvent { + static constexpr uint32_t kInvalidDeviceId = + std::numeric_limits::max(); + static constexpr uint32_t kInvalidThreadId = + std::numeric_limits::max(); + static constexpr uint32_t kInvalidCorrelationId = + std::numeric_limits::max(); + static constexpr uint64_t kInvalidStreamId = + std::numeric_limits::max(); + RocmTracerEventType type; + RocmTracerEventSource source = RocmTracerEventSource::Invalid; + RocmTracerEventDomain domain; + std::string name; + // This points to strings in AnnotationMap, which should outlive the point + // where serialization happens. + absl::string_view annotation; + absl::string_view roctx_range; + uint64_t start_time_ns = 0; + uint64_t end_time_ns = 0; + uint32_t device_id = kInvalidDeviceId; + uint32_t correlation_id = kInvalidCorrelationId; + uint32_t thread_id = kInvalidThreadId; + int64_t stream_id = kInvalidStreamId; + union { + MemcpyDetails memcpy_info; // If type == Memcpy* + MemsetDetails memset_info; // If type == Memset* + MemAllocDetails memalloc_info; // If type == MemoryAlloc + KernelDetails kernel_info; // If type == Kernel + SynchronizationDetails synchronization_info; // If type == Synchronization + }; +}; + +struct RocmTraceCollectorOptions { + // Maximum number of events to collect from callback API; if -1, no limit. + // if 0, the callback API is enabled to build a correlation map, but no + // events are collected. + uint64_t max_callback_api_events; + // Maximum number of events to collect from activity API; if -1, no limit. + uint64_t max_activity_api_events; + // Maximum number of annotation strings that we can accommodate. + uint64_t max_annotation_strings; + // Number of GPUs involved. + uint32_t num_gpus; +}; + +class AnnotationMap { + public: + explicit AnnotationMap(uint64_t max_size) : max_size_(max_size) {} + void Add(uint32_t correlation_id, const std::string& annotation); + absl::string_view LookUp(uint32_t correlation_id); + + private: + struct AnnotationMapImpl { + // The population/consumption of annotations might happen from multiple + // callback/activity api related threads. + absl::Mutex mutex; + // Annotation tends to be repetitive, use a hash_set to store the strings, + // an use the reference to the string in the map. + absl::node_hash_set annotations; + absl::flat_hash_map correlation_map; + }; + const uint64_t max_size_; + AnnotationMapImpl map_; + + public: + // Disable copy and move. + AnnotationMap(const AnnotationMap&) = delete; + AnnotationMap& operator=(const AnnotationMap&) = delete; +}; + +class RocmTraceCollector { + public: + explicit RocmTraceCollector(const RocmTraceCollectorOptions& options) + : options_(options), annotation_map_(options.max_annotation_strings) {} + virtual ~RocmTraceCollector() {} + + virtual void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) = 0; + virtual void OnEventsDropped(const std::string& reason, + uint32_t num_events) = 0; + virtual void Flush() = 0; + virtual void Export(XSpace* space) = 0; + + AnnotationMap* annotation_map() { return &annotation_map_; } + + protected: + RocmTraceCollectorOptions options_; + + private: + AnnotationMap annotation_map_; + + public: + // Disable copy and move. + RocmTraceCollector(const RocmTraceCollector&) = delete; + RocmTraceCollector& operator=(const RocmTraceCollector&) = delete; +}; + +std::unique_ptr CreateRocmCollector( + const RocmTraceCollectorOptions& options, const uint64_t start_walltime_ns, + const uint64_t start_gputime_ns); + +} // namespace profiler +} // namespace xla + +#endif // XLA_BACKENDS_PROFILER_GPU_ROCM_COLLECTOR_H_ diff --git a/xla/backends/profiler/gpu/rocm_tracer.cc b/xla/backends/profiler/gpu/rocm_tracer.cc index fd3711b3c0fc3..8bdef9300baa2 100644 --- a/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/xla/backends/profiler/gpu/rocm_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -169,7 +169,15 @@ inline void DumpApiCallbackData(uint32_t domain, uint32_t cbid, break; case HIP_API_ID_hipStreamSynchronize: break; + case HIP_API_ID_hipStreamWaitEvent: // ignore all aux HIP API Events + case HIP_API_ID_hipHostFree: + case HIP_API_ID_hipHostMalloc: + case HIP_API_ID_hipSetDevice: + break; default: + VLOG(3) << "Warning: HIP_API_ID_x is not handled in " + "DumpApiCallbackData, HIP_API_ID=" + << cbid; DCHECK(false); break; } @@ -269,51 +277,8 @@ const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain) { return ""; } -void DumpRocmTracerEvent(const RocmTracerEvent& event, - uint64_t start_walltime_ns, uint64_t start_gputime_ns, - const std::string& message) { - std::ostringstream oss; - oss << "correlation_id=" << event.correlation_id; - oss << ",type=" << GetRocmTracerEventTypeName(event.type); - oss << ",source=" << GetRocmTracerEventSourceName(event.source); - oss << ",domain=" << GetRocmTracerEventDomainName(event.domain); - oss << ",name=" << event.name; - oss << ",annotation=" << event.annotation; - oss << ",start_time_us=" - << (start_walltime_ns + (start_gputime_ns - event.start_time_ns)) / 1000; - oss << ",duration=" << (event.end_time_ns - event.start_time_ns) / 1000; - oss << ",device_id=" << event.device_id; - oss << ",thread_id=" << event.thread_id; - oss << ",stream_id=" << event.stream_id; - - switch (event.type) { - case RocmTracerEventType::Kernel: - break; - case RocmTracerEventType::MemcpyD2H: - case RocmTracerEventType::MemcpyH2D: - case RocmTracerEventType::MemcpyD2D: - case RocmTracerEventType::MemcpyP2P: - oss << ",num_bytes=" << event.memcpy_info.num_bytes; - oss << ",destination=" << event.memcpy_info.destination; - oss << ",async=" << event.memcpy_info.async; - break; - case RocmTracerEventType::MemoryAlloc: - oss << ",num_bytes=" << event.memalloc_info.num_bytes; - break; - case RocmTracerEventType::Synchronization: - break; - case RocmTracerEventType::Generic: - break; - default: - DCHECK(false); - break; - } - oss << message; - VLOG(3) << oss.str(); -} - -tsl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, - const void* cbdata) { +absl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, + const void* cbdata) { /* Some APIs such as hipMalloc, implicitly work on th devices set by the user using APIs such as hipSetDevice. API callbacks and activity records for functions like hipMalloc does not return the device id (CUDA does). To @@ -321,7 +286,7 @@ tsl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, hipSetDevice) for each thread. */ - thread_local uint32_t default_device = 0; + thread_local uint32_t default_device = hipGetStreamDeviceId(nullptr); // DumpApiCallbackData(domain, cbid, cbdata); @@ -338,7 +303,7 @@ tsl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, } if (cbid == HIP_API_ID_hipSetDevice) { - default_device = data->args.hipSetDevice.deviceId; + default_device = hipGetStreamDeviceId(nullptr); } } else if (data->phase == ACTIVITY_API_PHASE_EXIT) { uint64_t enter_time = 0, exit_time = 0; @@ -556,7 +521,7 @@ void RocmApiCallbackImpl::AddNormalMemcpyEventUponApiExit( missing: device_id(partially, have only for async), context_id, memcpy_info.kind(CUPTI puts CUPTI_ACTIVITY_MEMCPY_KIND_UNKNOWN), - memcpy_info.destenation(partially, only for async)( CUPTI puts device_id), + memcpy_info.destination(partially, only for async)( CUPTI puts device_id), extra: domain, name, @@ -873,8 +838,8 @@ void RocmApiCallbackImpl::AddSynchronizeEventUponApiExit( collector_->AddEvent(std::move(event), is_auxiliary); } -tsl::Status RocmActivityCallbackImpl::operator()(const char* begin, - const char* end) { +absl::Status RocmActivityCallbackImpl::operator()(const char* begin, + const char* end) { // we do not dump activities in this set in logger static std::set dump_excluded_activities = { @@ -1332,26 +1297,6 @@ void RocmActivityCallbackImpl::AddHipOpsMemsetActivityEvent( collector_->AddEvent(std::move(event), false); } -void AnnotationMap::Add(uint32_t correlation_id, - const std::string& annotation) { - if (annotation.empty()) return; - VLOG(3) << "Add annotation: " - << " correlation_id=" << correlation_id - << ", annotation: " << annotation; - absl::MutexLock lock(&map_.mutex); - if (map_.annotations.size() < max_size_) { - absl::string_view annotation_str = - *map_.annotations.insert(annotation).first; - map_.correlation_map.emplace(correlation_id, annotation_str); - } -} - -absl::string_view AnnotationMap::LookUp(uint32_t correlation_id) { - absl::MutexLock lock(&map_.mutex); - auto it = map_.correlation_map.find(correlation_id); - return it != map_.correlation_map.end() ? it->second : absl::string_view(); -} - /* static */ RocmTracer* RocmTracer::GetRocmTracerSingleton() { static auto* singleton = new RocmTracer(); return singleton; @@ -1414,14 +1359,14 @@ void ApiCallback(uint32_t domain, uint32_t cbid, const void* cbdata, tracer->ApiCallbackHandler(domain, cbid, cbdata).IgnoreError(); } -tsl::Status RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid, - const void* cbdata) { +absl::Status RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid, + const void* cbdata) { if (api_tracing_enabled_) TF_RETURN_IF_ERROR((*api_cb_impl_)(domain, cbid, cbdata)); return tsl::OkStatus(); } -tsl::Status RocmTracer::EnableApiTracing() { +absl::Status RocmTracer::EnableApiTracing() { if (api_tracing_enabled_) return tsl::OkStatus(); api_tracing_enabled_ = true; @@ -1447,7 +1392,7 @@ tsl::Status RocmTracer::EnableApiTracing() { return tsl::OkStatus(); } -tsl::Status RocmTracer::DisableApiTracing() { +absl::Status RocmTracer::DisableApiTracing() { if (!api_tracing_enabled_) return tsl::OkStatus(); api_tracing_enabled_ = false; @@ -1478,8 +1423,8 @@ void ActivityCallback(const char* begin, const char* end, void* user_data) { tracer->ActivityCallbackHandler(begin, end).IgnoreError(); } -tsl::Status RocmTracer::ActivityCallbackHandler(const char* begin, - const char* end) { +absl::Status RocmTracer::ActivityCallbackHandler(const char* begin, + const char* end) { if (activity_tracing_enabled_) { TF_RETURN_IF_ERROR((*activity_cb_impl_)(begin, end)); } else { @@ -1507,7 +1452,7 @@ tsl::Status RocmTracer::ActivityCallbackHandler(const char* begin, return tsl::OkStatus(); } -tsl::Status RocmTracer::EnableActivityTracing() { +absl::Status RocmTracer::EnableActivityTracing() { if (activity_tracing_enabled_) return tsl::OkStatus(); activity_tracing_enabled_ = true; @@ -1548,7 +1493,7 @@ tsl::Status RocmTracer::EnableActivityTracing() { return tsl::OkStatus(); } -tsl::Status RocmTracer::DisableActivityTracing() { +absl::Status RocmTracer::DisableActivityTracing() { if (!activity_tracing_enabled_) return tsl::OkStatus(); for (auto& iter : options_->activity_tracing) { @@ -1589,8 +1534,8 @@ tsl::Status RocmTracer::DisableActivityTracing() { size_t threshold = 1; for (int i = 0; i < 6; i++, duration_ms *= 2, threshold *= 2) { if (GetPendingActivityRecordsCount() < threshold) break; - VLOG(3) << "Wait for pending activity records :" - << " Pending count = " << GetPendingActivityRecordsCount() + VLOG(3) << "Wait for pending activity records :" << " Pending count = " + << GetPendingActivityRecordsCount() << ", Threshold = " << threshold; VLOG(3) << "Wait for pending activity records : sleep for " << duration_ms << " ms"; diff --git a/xla/backends/profiler/gpu/rocm_tracer.h b/xla/backends/profiler/gpu/rocm_tracer.h index cfc29d3890988..b82a1e66d0092 100644 --- a/xla/backends/profiler/gpu/rocm_tracer.h +++ b/xla/backends/profiler/gpu/rocm_tracer.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/types/optional.h" +#include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/stream_executor/rocm/roctracer_wrapper.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -30,88 +31,6 @@ limitations under the License. namespace xla { namespace profiler { -struct MemcpyDetails { - // The amount of data copied for memcpy events. - size_t num_bytes; - // The destination device for peer-2-peer communication (memcpy). The source - // device is implicit: it's the current device. - uint32_t destination; - // Whether or not the memcpy is asynchronous. - bool async; -}; - -struct MemsetDetails { - // The number of memory elements getting set - size_t num_bytes; - // Whether or not the memset is asynchronous. - bool async; -}; - -struct MemAllocDetails { - // The amount of data requested for cudaMalloc events. - uint64_t num_bytes; -}; - -struct KernelDetails { - // The number of registers used in this kernel. - uint32_t registers_per_thread; - // The amount of shared memory space used by a thread block. - uint32_t static_shared_memory_usage; - // The amount of dynamic memory space used by a thread block. - uint32_t dynamic_shared_memory_usage; - // X-dimension of a thread block. - uint32_t block_x; - // Y-dimension of a thread block. - uint32_t block_y; - // Z-dimension of a thread block. - uint32_t block_z; - // X-dimension of a grid. - uint32_t grid_x; - // Y-dimension of a grid. - uint32_t grid_y; - // Z-dimension of a grid. - uint32_t grid_z; - - // kernel address. Used for calculating core occupancy - void* func_ptr; -}; - -// RocmTracerSyncTypes forward decleration -enum class RocmTracerSyncTypes; -struct SynchronizationDetails { - RocmTracerSyncTypes sync_type; -}; - -enum class RocmTracerEventType { - Unsupported = 0, - Kernel, - MemcpyH2D, - MemcpyD2H, - MemcpyD2D, - MemcpyP2P, - MemcpyOther, - MemoryAlloc, - MemoryFree, - Memset, - Synchronization, - Generic, -}; - -const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type); - -enum class RocmTracerEventSource { - Invalid = 0, - ApiCallback, - Activity, -}; - -const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source); - -enum class RocmTracerEventDomain { - InvalidDomain = 0, - HIP_API, - HCC_OPS, // TODO(rocm-profiler): renme this to HIP_OPS -}; enum class RocmTracerSyncTypes { InvalidSync = 0, StreamSynchronize, // caller thread wait stream to become empty @@ -119,44 +38,6 @@ enum class RocmTracerSyncTypes { StreamWait // compute stream will wait for event to happen }; -const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain); - -struct RocmTracerEvent { - static constexpr uint32_t kInvalidDeviceId = - std::numeric_limits::max(); - static constexpr uint32_t kInvalidThreadId = - std::numeric_limits::max(); - static constexpr uint32_t kInvalidCorrelationId = - std::numeric_limits::max(); - static constexpr uint64_t kInvalidStreamId = - std::numeric_limits::max(); - RocmTracerEventType type; - RocmTracerEventSource source = RocmTracerEventSource::Invalid; - RocmTracerEventDomain domain; - std::string name; - // This points to strings in AnnotationMap, which should outlive the point - // where serialization happens. - absl::string_view annotation; - absl::string_view roctx_range; - uint64_t start_time_ns = 0; - uint64_t end_time_ns = 0; - uint32_t device_id = kInvalidDeviceId; - uint32_t correlation_id = kInvalidCorrelationId; - uint32_t thread_id = kInvalidThreadId; - int64_t stream_id = kInvalidStreamId; - union { - MemcpyDetails memcpy_info; // If type == Memcpy* - MemsetDetails memset_info; // If type == Memset* - MemAllocDetails memalloc_info; // If type == MemoryAlloc - KernelDetails kernel_info; // If type == Kernel - SynchronizationDetails synchronization_info; // If type == Synchronization - }; -}; - -void DumpRocmTracerEvent(const RocmTracerEvent& event, - uint64_t start_walltime_ns, uint64_t start_gputime_ns, - const std::string& message); - struct RocmTracerOptions { std::set api_tracking_set; // actual api set we want to profile @@ -170,69 +51,6 @@ struct RocmTracerOptions { activity_tracing; }; -struct RocmTraceCollectorOptions { - // Maximum number of events to collect from callback API; if -1, no limit. - // if 0, the callback API is enabled to build a correlation map, but no - // events are collected. - uint64_t max_callback_api_events; - // Maximum number of events to collect from activity API; if -1, no limit. - uint64_t max_activity_api_events; - // Maximum number of annotation strings that we can accommodate. - uint64_t max_annotation_strings; - // Number of GPUs involved. - uint32_t num_gpus; -}; - -class AnnotationMap { - public: - explicit AnnotationMap(uint64_t max_size) : max_size_(max_size) {} - void Add(uint32_t correlation_id, const std::string& annotation); - absl::string_view LookUp(uint32_t correlation_id); - - private: - struct AnnotationMapImpl { - // The population/consumption of annotations might happen from multiple - // callback/activity api related threads. - absl::Mutex mutex; - // Annotation tends to be repetitive, use a hash_set to store the strings, - // an use the reference to the string in the map. - absl::node_hash_set annotations; - absl::flat_hash_map correlation_map; - }; - const uint64_t max_size_; - AnnotationMapImpl map_; - - public: - // Disable copy and move. - AnnotationMap(const AnnotationMap&) = delete; - AnnotationMap& operator=(const AnnotationMap&) = delete; -}; - -class RocmTraceCollector { - public: - explicit RocmTraceCollector(const RocmTraceCollectorOptions& options) - : options_(options), annotation_map_(options.max_annotation_strings) {} - virtual ~RocmTraceCollector() {} - - virtual void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) = 0; - virtual void OnEventsDropped(const std::string& reason, - uint32_t num_events) = 0; - virtual void Flush() = 0; - - AnnotationMap* annotation_map() { return &annotation_map_; } - - protected: - RocmTraceCollectorOptions options_; - - private: - AnnotationMap annotation_map_; - - public: - // Disable copy and move. - RocmTraceCollector(const RocmTraceCollector&) = delete; - RocmTraceCollector& operator=(const RocmTraceCollector&) = delete; -}; - class RocmTracer; class RocmApiCallbackImpl { @@ -241,7 +59,7 @@ class RocmApiCallbackImpl { RocmTraceCollector* collector) : options_(options), tracer_(tracer), collector_(collector) {} - tsl::Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata); + absl::Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata); private: void AddKernelEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, @@ -279,7 +97,7 @@ class RocmActivityCallbackImpl { RocmTraceCollector* collector) : options_(options), tracer_(tracer), collector_(collector) {} - tsl::Status operator()(const char* begin, const char* end); + absl::Status operator()(const char* begin, const char* end); private: void AddHipKernelActivityEvent(const roctracer_record_t* record); @@ -295,7 +113,7 @@ class RocmActivityCallbackImpl { RocmTraceCollector* collector_ = nullptr; }; -// The class use to enable cupti callback/activity API and forward the collected +// The class uses roctracer callback/activity API and forward the collected // trace events to RocmTraceCollector. There should be only one RocmTracer // per process. class RocmTracer { @@ -309,9 +127,9 @@ class RocmTracer { void Enable(const RocmTracerOptions& options, RocmTraceCollector* collector); void Disable(); - tsl::Status ApiCallbackHandler(uint32_t domain, uint32_t cbid, - const void* cbdata); - tsl::Status ActivityCallbackHandler(const char* begin, const char* end); + absl::Status ApiCallbackHandler(uint32_t domain, uint32_t cbid, + const void* cbdata); + absl::Status ActivityCallbackHandler(const char* begin, const char* end); static uint64_t GetTimestamp(); static int NumGpus(); @@ -335,11 +153,11 @@ class RocmTracer { explicit RocmTracer() : num_gpus_(NumGpus()) {} private: - tsl::Status EnableApiTracing(); - tsl::Status DisableApiTracing(); + absl::Status EnableApiTracing(); + absl::Status DisableApiTracing(); - tsl::Status EnableActivityTracing(); - tsl::Status DisableActivityTracing(); + absl::Status EnableActivityTracing(); + absl::Status DisableActivityTracing(); int num_gpus_; std::optional options_; diff --git a/xla/backends/profiler/plugin/BUILD b/xla/backends/profiler/plugin/BUILD index 53fd864e1ad9c..169a4eaa4edb9 100644 --- a/xla/backends/profiler/plugin/BUILD +++ b/xla/backends/profiler/plugin/BUILD @@ -1,9 +1,9 @@ +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") load( "//xla:xla.bzl", "xla_cc_test", ) -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -23,10 +23,7 @@ cc_library( srcs = ["plugin_tracer.cc"], hdrs = ["plugin_tracer.h"], copts = tf_profiler_copts(), - visibility = [ - "//third_party/xprof/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow:__pkg__", - "//xla:internal", - ], + visibility = ["//xla:internal"], deps = [ ":profiler_c_api_hdrs", "//xla:status", diff --git a/xla/backends/profiler/plugin/plugin_tracer.cc b/xla/backends/profiler/plugin/plugin_tracer.cc index 1df97001bc7ed..8176ac55f553e 100644 --- a/xla/backends/profiler/plugin/plugin_tracer.cc +++ b/xla/backends/profiler/plugin/plugin_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/plugin_tracer.h b/xla/backends/profiler/plugin/plugin_tracer.h index ee3ce3950d7b6..acd88644ac090 100644 --- a/xla/backends/profiler/plugin/plugin_tracer.h +++ b/xla/backends/profiler/plugin/plugin_tracer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/plugin_tracer_impl.cc b/xla/backends/profiler/plugin/plugin_tracer_impl.cc index dbeb577c1f7f3..63b062103ce66 100644 --- a/xla/backends/profiler/plugin/plugin_tracer_impl.cc +++ b/xla/backends/profiler/plugin/plugin_tracer_impl.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/plugin_tracer_impl.h b/xla/backends/profiler/plugin/plugin_tracer_impl.h index 99f9e5153b1a2..5e1eda3864eb0 100644 --- a/xla/backends/profiler/plugin/plugin_tracer_impl.h +++ b/xla/backends/profiler/plugin/plugin_tracer_impl.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc b/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc index c7f888f7f2b4d..b9dc203ddad27 100644 --- a/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc +++ b/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/profiler_c_api.h b/xla/backends/profiler/plugin/profiler_c_api.h index 6f548257c0911..23ead3089d788 100644 --- a/xla/backends/profiler/plugin/profiler_c_api.h +++ b/xla/backends/profiler/plugin/profiler_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/profiler_error.cc b/xla/backends/profiler/plugin/profiler_error.cc index 7abeae126e8a6..7be485b944bf9 100644 --- a/xla/backends/profiler/plugin/profiler_error.cc +++ b/xla/backends/profiler/plugin/profiler_error.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/plugin/profiler_error.h b/xla/backends/profiler/plugin/profiler_error.h index 7439101685c77..bbd6cc601d646 100644 --- a/xla/backends/profiler/plugin/profiler_error.h +++ b/xla/backends/profiler/plugin/profiler_error.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/backends/profiler/tpu/BUILD b/xla/backends/profiler/tpu/BUILD index 11c7452b9757c..76bfe264d670e 100644 --- a/xla/backends/profiler/tpu/BUILD +++ b/xla/backends/profiler/tpu/BUILD @@ -1,6 +1,6 @@ +load("@tsl//tsl:tsl.bzl", "if_with_tpu_support") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("@tsl//tsl:tsl.bzl", "if_with_tpu_support") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -11,18 +11,15 @@ cc_library( name = "tpu_tracer", srcs = if_with_tpu_support(["tpu_tracer.cc"]), copts = tf_profiler_copts(), - visibility = [ - "//third_party/xprof/plugin/tensorboard_plugin_profile/integration_tests/tpu/tensorflow:__pkg__", - "//xla:internal", - ], + visibility = ["//xla:internal"], deps = [ "//xla/stream_executor/tpu:tpu_api", "//xla/stream_executor/tpu:tpu_api_dlsym_set_fn", "//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", "//xla/stream_executor/tpu:tpu_profiler_init_fns", "//xla/stream_executor/tpu:tsl_status_helper", + "//xla/tsl/c:tsl_status", "@com_google_absl//absl/strings", - "@tsl//tsl/c:tsl_status", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:types", diff --git a/xla/backends/profiler/tpu/tpu_tracer.cc b/xla/backends/profiler/tpu/tpu_tracer.cc index aa6bab072d87b..19c7e57e49a13 100644 --- a/xla/backends/profiler/tpu/tpu_tracer.cc +++ b/xla/backends/profiler/tpu/tpu_tracer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_api_dlsym_set_fn.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tsl_status_helper.h" -#include "tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" @@ -61,7 +61,7 @@ class ProfilerStatusHelper { stream_executor::tpu::ProfilerApiFn()->TpuStatus_FreeFn(c_status); } - static tsl::Status FromC( // TENSORFLOW_STATUS_OK + static absl::Status FromC( // TENSORFLOW_STATUS_OK TF_Status* const c_status) { if (stream_executor::tpu::ProfilerApiFn()->TpuStatus_CodeFn(c_status) == TSL_OK) { @@ -80,7 +80,7 @@ class ProfilerStatusHelper { TSL_OK; } - tsl::Status status() const { // TENSORFLOW_STATUS_OK + absl::Status status() const { // TENSORFLOW_STATUS_OK return FromC(c_status); } diff --git a/xla/bit_cast.h b/xla/bit_cast.h index 026ba88f1f095..95c13c4070e38 100644 --- a/xla/bit_cast.h +++ b/xla/bit_cast.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/bit_cast_test.cc b/xla/bit_cast_test.cc index f94007b536192..15d5195024c63 100644 --- a/xla/bit_cast_test.cc +++ b/xla/bit_cast_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/c/BUILD b/xla/c/BUILD index 3807acbba10e3..7d4f5bcc99ce7 100644 --- a/xla/c/BUILD +++ b/xla/c/BUILD @@ -1,11 +1,12 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//learning/brain/tfrt/tpu_plugin:__subpackages__", "//tensorflow/core/common_runtime/next_pluggable_device:__subpackages__", - ], + ]), licenses = ["notice"], ) diff --git a/xla/c/c_api_decl.h b/xla/c/c_api_decl.h index a7c0c3b61fd50..4c415151c336a 100644 --- a/xla/c/c_api_decl.h +++ b/xla/c/c_api_decl.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/BUILD b/xla/client/BUILD index d90cc0bc6612c..6a37f59be4043 100644 --- a/xla/client/BUILD +++ b/xla/client/BUILD @@ -1,9 +1,9 @@ # Description: # XLA client libraries. -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -39,7 +39,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", ], ) @@ -84,7 +83,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -101,14 +99,19 @@ cc_library( "//xla:debug_options_flags", "//xla:execution_options_util", "//xla:shape_util", + "//xla:statusor", + "//xla:util", "//xla:xla_proto_cc", "//xla/pjrt:compile_options_proto_cc", "//xla/service:compilation_environments", "//xla/service:computation_placer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) @@ -179,7 +182,6 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", "@tsl//tsl/platform:logging", ], ) @@ -209,7 +211,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/memory", ], ) @@ -287,12 +288,15 @@ xla_cc_test( name = "xla_builder_test", srcs = ["xla_builder_test.cc"], deps = [ + ":padding", ":sharding_builder", ":value_inference", ":xla_builder", ":xla_computation", + "//xla:comparison_util", "//xla:debug_options_flags", "//xla:shape_util", + "//xla:status", "//xla:statusor", "//xla:test", "//xla:test_helpers", @@ -303,7 +307,13 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/client/client.cc b/xla/client/client.cc index 2f7ad6efcfabf..1d2fdab9029af 100644 --- a/xla/client/client.cc +++ b/xla/client/client.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,8 +38,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr Client::Transfer(const GlobalData& data, - const Shape* shape_with_layout) { +absl::StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -65,7 +65,7 @@ StatusOr Client::Transfer(const GlobalData& data, return Literal::CreateFromProto(*response.mutable_literal()); } -StatusOr> Client::TransferToServer( +absl::StatusOr> Client::TransferToServer( const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); @@ -115,7 +115,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id, return OkStatus(); } -StatusOr Client::TransferFromOutfeed( +absl::StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64_t replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -163,7 +163,7 @@ Status Client::ResetDevice() { return OkStatus(); } -StatusOr Client::ExecuteAndTransfer( +absl::StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -181,8 +181,8 @@ StatusOr Client::ExecuteAndTransfer( : nullptr); } -StatusOr Client::ComputeConstant(const XlaComputation& computation, - const Layout* output_layout) const { +absl::StatusOr Client::ComputeConstant( + const XlaComputation& computation, const Layout* output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { @@ -202,19 +202,19 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; if (!response.has_literal()) { - return InternalError( + return Internal( "no computed literal in the provided response in ComputeConstantGraph " "request"); } return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const HloSnapshot& module) { +absl::StatusOr Client::LoadSnapshot(const HloSnapshot& module) { TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); return XlaComputation(module.hlo().hlo_module()); } -StatusOr Client::Compile( +absl::StatusOr Client::Compile( const XlaComputation& computation, absl::Span argument_shapes, const ExecutionOptions* execution_options) { CompileRequest request; @@ -248,7 +248,7 @@ StatusOr Client::Compile( return response.handle(); } -StatusOr> Client::Execute( +absl::StatusOr> Client::Execute( const ExecutionHandle& handle, absl::Span arguments, ExecutionProfile* execution_profile) { ExecuteRequest request; @@ -274,7 +274,7 @@ StatusOr> Client::Execute( return std::make_unique(stub_, response.output()); } -StatusOr> Client::Execute( +absl::StatusOr> Client::Execute( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -325,8 +325,8 @@ StatusOr> Client::Execute( return std::move(results[0]); } -StatusOr>> Client::ExecuteParallel( - absl::Span computations) { +absl::StatusOr>> +Client::ExecuteParallel(absl::Span computations) { ExecuteGraphParallelRequest request; for (const XlaComputationInstance& computation : computations) { @@ -362,7 +362,7 @@ StatusOr>> Client::ExecuteParallel( return std::move(outputs); } -StatusOr> Client::GetDeviceHandles( +absl::StatusOr> Client::GetDeviceHandles( int64_t device_count) { if (device_count < 1) { return InvalidArgument("device_count must be greater than 0"); @@ -401,8 +401,8 @@ Status Client::Unregister(const GlobalData& data) { return s; } -StatusOr>> Client::DeconstructTuple( - const GlobalData& data) { +absl::StatusOr>> +Client::DeconstructTuple(const GlobalData& data) { DeconstructTupleRequest request; *request.mutable_tuple_handle() = data.handle(); DeconstructTupleResponse response; @@ -422,7 +422,7 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( +absl::StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { ComputationGraphStatsRequest request; @@ -443,13 +443,13 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( +absl::StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); return std::make_unique(result); } -StatusOr Client::GetShape(const GlobalData& data) { +absl::StatusOr Client::GetShape(const GlobalData& data) { GetShapeRequest request; *request.mutable_data() = data.handle(); GetShapeResponse response; @@ -465,7 +465,7 @@ StatusOr Client::GetShape(const GlobalData& data) { return Shape(response.shape()); } -StatusOr Client::ExecutionStatsAsString( +absl::StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, @@ -486,7 +486,7 @@ StatusOr Client::ExecutionStatsAsString( return std::string("[Execution Statistics] not available."); } -StatusOr Client::CreateChannelHandleByType( +absl::StatusOr Client::CreateChannelHandleByType( ChannelHandle::ChannelType type) { CreateChannelHandleRequest request; request.set_channel_type(type); @@ -503,15 +503,15 @@ StatusOr Client::CreateChannelHandleByType( return response.channel(); } -StatusOr Client::CreateChannelHandle() { +absl::StatusOr Client::CreateChannelHandle() { return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE); } -StatusOr Client::CreateHostToDeviceChannelHandle() { +absl::StatusOr Client::CreateHostToDeviceChannelHandle() { return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE); } -StatusOr Client::CreateDeviceToHostChannelHandle() { +absl::StatusOr Client::CreateDeviceToHostChannelHandle() { return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST); } diff --git a/xla/client/client.h b/xla/client/client.h index bdd7cfaaa93d3..e4c48d5835995 100644 --- a/xla/client/client.h +++ b/xla/client/client.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ class Client { // to call the Execute(const XlaComputation&) overload. If you're going to // run the computation more than once but you want control over when the // Executable is unloaded, use the LocalClient API. - StatusOr Compile( + absl::StatusOr Compile( const XlaComputation& computation, absl::Span argument_shapes, const ExecutionOptions* execution_options = nullptr); @@ -68,7 +68,7 @@ class Client { // arguments and returns the global data that was produced from the execution. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( + absl::StatusOr> Execute( const ExecutionHandle& handle, absl::Span arguments, ExecutionProfile* execution_profile = nullptr); @@ -87,7 +87,7 @@ class Client { // TODO(b/122731460): The given computation is compiled and then thrown away // immediately after it's run. If you want control over how long the // resulting Executable lives, use the LocalClient API. - StatusOr> Execute( + absl::StatusOr> Execute( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options = nullptr, @@ -117,14 +117,15 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - StatusOr>> ExecuteParallel( + absl::StatusOr>> ExecuteParallel( absl::Span computations); // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations // (see ExecuteParallel) or to transfer data (see TransferToServer or // TransferToInfeed). - StatusOr> GetDeviceHandles(int64_t device_count); + absl::StatusOr> GetDeviceHandles( + int64_t device_count); // Transfer the global data provided to this client process, which is // returned in the provided literal. Use sparingly to avoid transfer @@ -132,8 +133,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr Transfer(const GlobalData& data, - const Shape* shape_with_layout = nullptr); + absl::StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -142,7 +143,7 @@ class Client { // If device_handle is not nullptr, data is transferred to the associated // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). - StatusOr> TransferToServer( + absl::StatusOr> TransferToServer( const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. @@ -158,7 +159,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr TransferFromOutfeed( + absl::StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64_t replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -168,7 +169,7 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr ExecuteAndTransfer( + absl::StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options = nullptr, @@ -189,7 +190,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - StatusOr ComputeConstant( + absl::StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; @@ -197,43 +198,43 @@ class Client { Status Unregister(const GlobalData& data); // Returns a vector of global data handles that point to the tuple elements. - StatusOr>> DeconstructTuple( + absl::StatusOr>> DeconstructTuple( const GlobalData& data); // Retrieves the statistics of the given computation. - StatusOr GetComputationStats( + absl::StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. - StatusOr GetShape(const GlobalData& data); + absl::StatusOr GetShape(const GlobalData& data); // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( + absl::StatusOr> GetComputationShape( const XlaComputation& computation); // Creates a channel handle that can be used to transfer data between two // computations on different devices via a pair of Send and Recv instructions. - StatusOr CreateChannelHandle(); + absl::StatusOr CreateChannelHandle(); // Create a channel for communicating with the host via a SendtoHost or // RecvFromHost operation. - StatusOr CreateHostToDeviceChannelHandle(); - StatusOr CreateDeviceToHostChannelHandle(); + absl::StatusOr CreateHostToDeviceChannelHandle(); + absl::StatusOr CreateDeviceToHostChannelHandle(); - StatusOr LoadSnapshot(const HloSnapshot& module); + absl::StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString( + absl::StatusOr ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile); - StatusOr CreateChannelHandleByType( + absl::StatusOr CreateChannelHandleByType( ChannelHandle::ChannelType type); ServiceInterface* stub_; // Stub that this client is connected on. diff --git a/xla/client/client_library.cc b/xla/client/client_library.cc index bbab29178cd40..b55691be31f7f 100644 --- a/xla/client/client_library.cc +++ b/xla/client/client_library.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ const std::optional>& LocalClientOptions::allowed_devices() ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; -/* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( +/* static */ absl::StatusOr ClientLibrary::GetOrCreateLocalClient( se::Platform* platform, const std::optional>& device_set) { LocalClientOptions default_options; default_options.set_platform(platform); @@ -91,7 +91,7 @@ ClientLibrary::~ClientLibrary() = default; return GetOrCreateLocalClient(default_options); } -/* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( +/* static */ absl::StatusOr ClientLibrary::GetOrCreateLocalClient( const LocalClientOptions& options) { se::Platform* platform = options.platform(); int replica_count = options.number_of_replicas(); @@ -139,7 +139,7 @@ ClientLibrary::~ClientLibrary() = default; return it->second->service.get(); } -/* static */ StatusOr +/* static */ absl::StatusOr ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { ClientLibrary& client_library = Singleton(); absl::MutexLock lock(&client_library.service_mutex_); diff --git a/xla/client/client_library.h b/xla/client/client_library.h index f98a01aa9697a..bb84e521b3b5c 100644 --- a/xla/client/client_library.h +++ b/xla/client/client_library.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,10 +83,10 @@ class ClientLibrary { // null then default platform is used. // device_set: Set of device IDs for which the stream executor will be // created, for the given platform. - static StatusOr GetOrCreateLocalClient( + static absl::StatusOr GetOrCreateLocalClient( se::Platform* platform = nullptr, const std::optional>& device_set = std::nullopt); - static StatusOr GetOrCreateLocalClient( + static absl::StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); // Convenience "or-die" wrapper around the above which returns the existing @@ -101,7 +101,7 @@ class ClientLibrary { // // platform : The platform the underlying XLA service should target. If // null then default platform is used. - static StatusOr GetOrCreateCompileOnlyClient( + static absl::StatusOr GetOrCreateCompileOnlyClient( se::Platform* platform = nullptr); // Clears the local instance and compile only instance caches. The client diff --git a/xla/client/compile_only_client.cc b/xla/client/compile_only_client.cc index 084829a30131c..23c07b3742ba0 100644 --- a/xla/client/compile_only_client.cc +++ b/xla/client/compile_only_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ limitations under the License. namespace xla { -StatusOr> +absl::StatusOr> CompileOnlyClient::CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, @@ -33,7 +33,7 @@ CompileOnlyClient::CreateModuleConfig( execution_options); } -StatusOr>> +absl::StatusOr>> CompileOnlyClient::CompileAheadOfTime( const absl::Span computations, const AotCompilationOptions& options, diff --git a/xla/client/compile_only_client.h b/xla/client/compile_only_client.h index 3b737fc26cf89..fac2fbd98932a 100644 --- a/xla/client/compile_only_client.h +++ b/xla/client/compile_only_client.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class CompileOnlyClient : public Client { // This is intended for use in static compilation. The |options| // parameter describes the target for which the compiler should emit // code. |metadata|, if provided, is populated during compilation. - StatusOr>> + absl::StatusOr>> CompileAheadOfTime( absl::Span computations, const AotCompilationOptions& options, @@ -61,7 +61,7 @@ class CompileOnlyClient : public Client { // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. - StatusOr> CreateModuleConfig( + absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options); diff --git a/xla/client/executable_build_options.cc b/xla/client/executable_build_options.cc index 3089a9820a181..64be6188908e5 100644 --- a/xla/client/executable_build_options.cc +++ b/xla/client/executable_build_options.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,26 @@ limitations under the License. #include "xla/client/executable_build_options.h" +#include #include #include #include #include +#include "absl/log/check.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" #include "xla/execution_options_util.h" +#include "xla/layout_util.h" +#include "xla/service/compilation_environments.h" +#include "xla/service/computation_placer.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" #include "xla/xla.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -100,14 +110,14 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_use_auto_spmd_partitioning( ExecutableBuildOptions& ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape( std::vector mesh_shape) { - auto_spmd_partitioning_mesh_shape_ = mesh_shape; + auto_spmd_partitioning_mesh_shape_ = std::move(mesh_shape); return *this; } ExecutableBuildOptions& ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids( std::vector mesh_ids) { - auto_spmd_partitioning_mesh_ids_ = mesh_ids; + auto_spmd_partitioning_mesh_ids_ = std::move(mesh_ids); return *this; } @@ -134,7 +144,8 @@ std::string ExecutableBuildOptions::ToString() const { device_ordinal_, result_layout, num_replicas_); } -StatusOr ExecutableBuildOptions::ToProto() const { +absl::StatusOr ExecutableBuildOptions::ToProto() + const { ExecutableBuildOptionsProto output; output.set_device_ordinal(device_ordinal()); if (result_layout()) { @@ -151,6 +162,10 @@ StatusOr ExecutableBuildOptions::ToProto() const { "Cannot serialize " "ExecutableBuildOptions::layout_canonicalization_callback"); } + if (compile_thread_pool() != nullptr) { + return InvalidArgument( + "Cannot serialize ExecutableBuildOptions::compile_thread_pool"); + } output.set_num_replicas(num_replicas()); output.set_num_partitions(num_partitions()); output.set_use_spmd_partitioning(use_spmd_partitioning()); @@ -162,6 +177,12 @@ StatusOr ExecutableBuildOptions::ToProto() const { } output.set_alias_passthrough_params(alias_passthrough_params()); output.set_run_backend_only(run_backend_only()); + if (!allow_spmd_sharding_propagation_to_parameters().empty()) { + output.mutable_allow_spmd_sharding_propagation_to_parameters()->Clear(); + for (bool v : allow_spmd_sharding_propagation_to_parameters()) { + output.mutable_allow_spmd_sharding_propagation_to_parameters()->Add(v); + } + } if (!allow_spmd_sharding_propagation_to_output().empty()) { output.mutable_allow_spmd_sharding_propagation_to_output()->Clear(); for (bool v : allow_spmd_sharding_propagation_to_output()) { @@ -170,10 +191,16 @@ StatusOr ExecutableBuildOptions::ToProto() const { } *output.mutable_fdo_profile() = fdo_profile(); output.set_device_memory_size(device_memory_size()); + for (int64_t s : auto_spmd_partitioning_mesh_shape()) { + output.mutable_auto_spmd_partitioning_mesh_shape()->Add(s); + } + for (int64_t s : auto_spmd_partitioning_mesh_ids()) { + output.mutable_auto_spmd_partitioning_mesh_ids()->Add(s); + } return output; } -StatusOr ExecutableBuildOptionsFromProto( +absl::StatusOr ExecutableBuildOptionsFromProto( const ExecutableBuildOptionsProto& input) { xla::ExecutableBuildOptions output; if (input.device_ordinal() != -1) { @@ -204,10 +231,18 @@ StatusOr ExecutableBuildOptionsFromProto( } output.set_alias_passthrough_params(input.alias_passthrough_params()); output.set_run_backend_only(input.run_backend_only()); + output.set_allow_spmd_sharding_propagation_to_parameters( + input.allow_spmd_sharding_propagation_to_parameters()); output.set_allow_spmd_sharding_propagation_to_output( input.allow_spmd_sharding_propagation_to_output()); *output.mutable_fdo_profile() = input.fdo_profile(); output.set_device_memory_size(input.device_memory_size()); + output.set_auto_spmd_partitioning_mesh_shape( + std::vector(input.auto_spmd_partitioning_mesh_shape().begin(), + input.auto_spmd_partitioning_mesh_shape().end())); + output.set_auto_spmd_partitioning_mesh_ids( + std::vector(input.auto_spmd_partitioning_mesh_ids().begin(), + input.auto_spmd_partitioning_mesh_ids().end())); return output; } @@ -240,6 +275,15 @@ ExecutionOptions CreateExecutionOptions( execution_options.mutable_auto_spmd_partitioning_mesh_ids()->Add(t); } execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo()); + if (!build_options.allow_spmd_sharding_propagation_to_parameters().empty()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_parameters() + ->Clear(); + for (bool v : + build_options.allow_spmd_sharding_propagation_to_parameters()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_parameters() + ->Add(v); + } + } if (!build_options.allow_spmd_sharding_propagation_to_output().empty()) { execution_options.mutable_allow_spmd_sharding_propagation_to_output() ->Clear(); diff --git a/xla/client/executable_build_options.h b/xla/client/executable_build_options.h index bc0a039f56510..fe51f7e0fd02f 100644 --- a/xla/client/executable_build_options.h +++ b/xla/client/executable_build_options.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -153,13 +153,34 @@ class ExecutableBuildOptions { return *this; } + absl::Span allow_spmd_sharding_propagation_to_parameters() const { + return allow_spmd_sharding_propagation_to_parameters_; + } absl::Span allow_spmd_sharding_propagation_to_output() const { return allow_spmd_sharding_propagation_to_output_; } + bool any_allow_spmd_sharding_propagation_to_parameters() const { + return absl::c_linear_search(allow_spmd_sharding_propagation_to_parameters_, + true); + } bool any_allow_spmd_sharding_propagation_to_output() const { return absl::c_linear_search(allow_spmd_sharding_propagation_to_output_, true); } + // Allows sharding propagation to propagate to the inputs. This changes the + // input shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the input + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + ExecutableBuildOptions& set_allow_spmd_sharding_propagation_to_parameters( + absl::Span allow_spmd_sharding_propagation_to_parameters) { + allow_spmd_sharding_propagation_to_parameters_.assign( + allow_spmd_sharding_propagation_to_parameters.begin(), + allow_spmd_sharding_propagation_to_parameters.end()); + return *this; + } // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -186,7 +207,7 @@ class ExecutableBuildOptions { } using LayoutCanonicalizationCallback = - std::function, Shape>>( + std::function, Shape>>( const HloModule& module)>; void set_layout_canonicalization_callback( LayoutCanonicalizationCallback callback) { @@ -197,8 +218,8 @@ class ExecutableBuildOptions { } absl::string_view fdo_profile() const { return fdo_profile_; } - void set_fdo_profile(const std::string& fdo_profile) { - fdo_profile_ = fdo_profile; + void set_fdo_profile(std::string fdo_profile) { + fdo_profile_ = std::move(fdo_profile); } std::string* mutable_fdo_profile() { return &fdo_profile_; } @@ -213,7 +234,7 @@ class ExecutableBuildOptions { // debugging. std::string ToString() const; - StatusOr ToProto() const; + absl::StatusOr ToProto() const; private: int device_ordinal_ = -1; @@ -233,6 +254,8 @@ class ExecutableBuildOptions { std::optional device_assignment_; bool alias_passthrough_params_ = false; bool run_backend_only_ = false; + absl::InlinedVector allow_spmd_sharding_propagation_to_parameters_ = + {false}; absl::InlinedVector allow_spmd_sharding_propagation_to_output_ = { false}; tsl::thread::ThreadPool* compile_thread_pool_ = nullptr; @@ -241,7 +264,7 @@ class ExecutableBuildOptions { int64_t device_memory_size_ = 0; }; -StatusOr ExecutableBuildOptionsFromProto( +absl::StatusOr ExecutableBuildOptionsFromProto( const ExecutableBuildOptionsProto& input); // Creates an ExecutionOptions based on a given ExecutableBuildOptions and diff --git a/xla/client/global_data.cc b/xla/client/global_data.cc index c655909879b4f..1235d1e3c50d6 100644 --- a/xla/client/global_data.cc +++ b/xla/client/global_data.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/global_data.h b/xla/client/global_data.h index f54735d4429ca..9a485258a9aaf 100644 --- a/xla/client/global_data.h +++ b/xla/client/global_data.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/BUILD b/xla/client/lib/BUILD index 0f72a9bdf54b6..f7552e4081599 100644 --- a/xla/client/lib/BUILD +++ b/xla/client/lib/BUILD @@ -1,12 +1,13 @@ # Common computation builders for XLA. -load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla/client:friends"], + default_visibility = internal_visibility(["//xla/client:friends"]), licenses = ["notice"], ) @@ -62,7 +63,6 @@ cc_library( ], deps = [ ":constants", - "//xla:literal_util", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", @@ -102,7 +102,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", ], ) @@ -145,7 +145,6 @@ cc_library( "//xla:status_macros", "//xla/client:padding", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -155,7 +154,6 @@ cc_library( hdrs = ["dynamic_shaped_ops.h"], deps = [ ":constants", - "//xla:literal_util", "//xla:shape_util", "//xla:types", "//xla:util", @@ -254,11 +252,9 @@ xla_test( ":matrix", ":slicing", "//xla:status", - "//xla:status_macros", "//xla:statusor", "//xla:test", "//xla:types", - "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", @@ -343,7 +339,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -368,7 +363,6 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -486,7 +480,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "@tsl//tsl/platform:bfloat16", - "@tsl//tsl/platform:logging", ], ) @@ -528,7 +521,6 @@ cc_library( "//xla/client:xla_computation", "//xla/tests:test_utils", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:protobuf", ], ) @@ -553,7 +545,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -582,7 +573,6 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -605,7 +595,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -634,7 +623,6 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -663,18 +651,13 @@ xla_test( shard_count = 10, tags = ["optonly"], deps = [ - ":constants", ":slicing", ":tridiagonal", - "//xla:array2d", - "//xla:error_spec", "//xla:literal", "//xla:shape_util", "//xla:status", - "//xla:test", "//xla/client:xla_builder", "//xla/tests:client_library_test_base", - "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", ], @@ -698,7 +681,6 @@ cc_library( "//xla:statusor", "//xla/client:xla_builder", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -722,7 +704,6 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -756,6 +737,5 @@ xla_test( "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) diff --git a/xla/client/lib/approx_topk.cc b/xla/client/lib/approx_topk.cc index 8472e82c514cd..40b9360976215 100644 --- a/xla/client/lib/approx_topk.cc +++ b/xla/client/lib/approx_topk.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ const uint64_t kTpuChunkTiling = 1024; namespace xla { namespace { -StatusOr> GetOperandTypes( +absl::StatusOr> GetOperandTypes( XlaBuilder* builder, absl::Span operands, absl::Span init_values) { std::vector op_types; diff --git a/xla/client/lib/approx_topk.h b/xla/client/lib/approx_topk.h index d92c784b9de65..6b097e0a8483d 100644 --- a/xla/client/lib/approx_topk.h +++ b/xla/client/lib/approx_topk.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/approx_topk_shape.cc b/xla/client/lib/approx_topk_shape.cc index 9309c0ea63269..aec7b7105e9d3 100644 --- a/xla/client/lib/approx_topk_shape.cc +++ b/xla/client/lib/approx_topk_shape.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "xla/util.h" @@ -37,7 +38,7 @@ inline uint32_t log2_ceil(uint64_t value) { } // LINT.IfChange -StatusOr> ApproxTopKReductionOutputSize( +absl::StatusOr> ApproxTopKReductionOutputSize( int64_t input_size, int64_t rank, int64_t top_k, float recall_target, bool aggregate_to_topk, int64_t input_size_override) { if (aggregate_to_topk) { @@ -97,7 +98,7 @@ StatusOr> ApproxTopKReductionOutputSize( static_cast((1.0 - top_k) / std::log(static_cast(recall_target))), tpu_tiling), - input_size); + logical_input_size); uint32_t log2_reduction = log2_floor(logical_input_size / m); if (log2_reduction == 0) { return std::pair(input_size, 0); diff --git a/xla/client/lib/approx_topk_shape.h b/xla/client/lib/approx_topk_shape.h index c569aa3aa40c1..de12c7fc0a3e1 100644 --- a/xla/client/lib/approx_topk_shape.h +++ b/xla/client/lib/approx_topk_shape.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ namespace xla { // // 2. is invalid and set to -1 when the approximate output is disabled, i.e. // top_k = 1 or aggregate_to_topk = true. -StatusOr> ApproxTopKReductionOutputSize( +absl::StatusOr> ApproxTopKReductionOutputSize( int64_t input_size, int64_t rank, int64_t top_k, float recall_target, bool aggregate_to_topk, int64_t input_size_override = -1); diff --git a/xla/client/lib/arithmetic.cc b/xla/client/lib/arithmetic.cc index eda614597e547..0b089552e11d7 100644 --- a/xla/client/lib/arithmetic.cc +++ b/xla/client/lib/arithmetic.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -104,7 +104,7 @@ XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, XlaOp Any(XlaOp predicates) { XlaBuilder* builder = predicates.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto f = ConstantR0(builder, false); XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, @@ -142,7 +142,7 @@ static XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); XlaOp value_init_value; if (is_min) { diff --git a/xla/client/lib/arithmetic.h b/xla/client/lib/arithmetic.h index d7b0d2cb5d8ad..8ebf47ca53aff 100644 --- a/xla/client/lib/arithmetic.h +++ b/xla/client/lib/arithmetic.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/arithmetic_test.cc b/xla/client/lib/arithmetic_test.cc index dee791bb237a9..f0b285cc6cbf4 100644 --- a/xla/client/lib/arithmetic_test.cc +++ b/xla/client/lib/arithmetic_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/broadcast.cc b/xla/client/lib/broadcast.cc index 4874feecba28d..43ab0cf422887 100644 --- a/xla/client/lib/broadcast.cc +++ b/xla/client/lib/broadcast.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,8 +25,8 @@ limitations under the License. namespace xla { -StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims) { +absl::StatusOr BroadcastTo(XlaOp input, + absl::Span output_dims) { XlaBuilder* builder = input.builder(); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); absl::Span input_dims = input_shape.dimensions(); diff --git a/xla/client/lib/broadcast.h b/xla/client/lib/broadcast.h index 50a8ab2124f37..a4b3c93144374 100644 --- a/xla/client/lib/broadcast.h +++ b/xla/client/lib/broadcast.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,7 +25,8 @@ namespace xla { // Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting // rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. -StatusOr BroadcastTo(XlaOp input, absl::Span output_dims); +absl::StatusOr BroadcastTo(XlaOp input, + absl::Span output_dims); } // namespace xla diff --git a/xla/client/lib/comparators.cc b/xla/client/lib/comparators.cc index 4d61aedb7060d..02ba3064382dd 100644 --- a/xla/client/lib/comparators.cc +++ b/xla/client/lib/comparators.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/comparators.h b/xla/client/lib/comparators.h index 952a804eddfa3..c2dfc1c7def3f 100644 --- a/xla/client/lib/comparators.h +++ b/xla/client/lib/comparators.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/comparators_test.cc b/xla/client/lib/comparators_test.cc index 16b08911c0ee8..7c3305cac94c6 100644 --- a/xla/client/lib/comparators_test.cc +++ b/xla/client/lib/comparators_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/constants.cc b/xla/client/lib/constants.cc index dfe9332fff83e..2ab6b2abfff45 100644 --- a/xla/client/lib/constants.cc +++ b/xla/client/lib/constants.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { @@ -35,7 +35,7 @@ XlaOp Zeros(XlaBuilder* builder, const Shape& shape) { XlaOp ZerosLike(XlaOp prototype) { XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); return Zeros(builder, shape); }); diff --git a/xla/client/lib/constants.h b/xla/client/lib/constants.h index d6c968540eae0..3126b362b3f2b 100644 --- a/xla/client/lib/constants.h +++ b/xla/client/lib/constants.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { @@ -66,7 +66,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { template XlaOp ScalarLike(XlaOp prototype, T value) { XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); return ConstantR0WithType(builder, shape.element_type(), value); }); @@ -80,7 +80,7 @@ XlaOp ScalarLike(XlaOp prototype, T value) { template XlaOp FullLike(XlaOp prototype, T value) { XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { return Broadcast(ScalarLike(prototype, value), shape.dimensions()); diff --git a/xla/client/lib/constants_test.cc b/xla/client/lib/constants_test.cc index 2af7870e9e2f3..16e0e12c7dbf5 100644 --- a/xla/client/lib/constants_test.cc +++ b/xla/client/lib/constants_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/conv_grad_size_util.cc b/xla/client/lib/conv_grad_size_util.cc index ee88f2e84be7a..6d23c71a15f39 100644 --- a/xla/client/lib/conv_grad_size_util.cc +++ b/xla/client/lib/conv_grad_size_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ namespace xla { namespace { -StatusOr GetWindowedOutputSize( +absl::StatusOr GetWindowedOutputSize( int64_t input_size, int64_t filter_size, int64_t dilation_rate, int64_t stride, Padding padding_type) { if (stride <= 0) { @@ -65,7 +65,7 @@ StatusOr GetWindowedOutputSize( } // namespace -StatusOr +absl::StatusOr ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, int64_t output_size, int64_t dilation, int64_t stride, Padding padding) { diff --git a/xla/client/lib/conv_grad_size_util.h b/xla/client/lib/conv_grad_size_util.h index 1de8045049221..5112add3043d3 100644 --- a/xla/client/lib/conv_grad_size_util.h +++ b/xla/client/lib/conv_grad_size_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ struct SpatialDimensionOutputSizeAndPadding { // Verifies that the dimensions all match, and computes the size and padding of // a spatial dimension for convolution gradient operations. -StatusOr +absl::StatusOr ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, int64_t output_size, int64_t dilation, int64_t stride, Padding padding); diff --git a/xla/client/lib/dynamic_shaped_ops.cc b/xla/client/lib/dynamic_shaped_ops.cc index 07387b862f9c8..89c53b21b38d7 100644 --- a/xla/client/lib/dynamic_shaped_ops.cc +++ b/xla/client/lib/dynamic_shaped_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -65,9 +65,9 @@ Shape FindMaxShape(absl::Span shapes) { return result; } -StatusOr ReconsileBranchDifference(const Shape& left_branch_shape, - const Shape& right_branch_shape, - XlaOp left_root) { +absl::StatusOr ReconsileBranchDifference(const Shape& left_branch_shape, + const Shape& right_branch_shape, + XlaOp left_root) { if (left_branch_shape.IsTuple()) { // Invariant sanity check -- Left branch and right branch need to have // compatible shapes. @@ -124,7 +124,7 @@ XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, const XlaComputation& true_computation, XlaOp false_operand, const XlaComputation& false_computation) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto true_shape = true_computation.GetProgramShape().value().result(); auto false_shape = false_computation.GetProgramShape().value().result(); @@ -136,8 +136,8 @@ XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, auto reconsile_branch = [](const Shape& root_shape, const Shape& operand_shape, - const Shape& reference_root_shape, - const XlaComputation& computation) -> StatusOr { + const Shape& reference_root_shape, const XlaComputation& computation) + -> absl::StatusOr { xla::XlaBuilder builder("dynamic_builder"); auto param = xla::Parameter(&builder, 0, operand_shape, "param"); auto call = Call(&builder, computation, {param}); @@ -165,7 +165,7 @@ XlaOp DynamicConditional( XlaBuilder* builder, XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector root_shapes; root_shapes.reserve(branch_computations.size()); for (int64_t i = 0; i < branch_computations.size(); ++i) { @@ -194,8 +194,8 @@ XlaOp DynamicConditional( auto reconsile_branch = [](const Shape& root_shape, const Shape& operand_shape, - const Shape& reference_root_shape, - const XlaComputation& computation) -> StatusOr { + const Shape& reference_root_shape, const XlaComputation& computation) + -> absl::StatusOr { xla::XlaBuilder builder("dynamic_builder"); auto param = xla::Parameter(&builder, 0, operand_shape, "param"); auto call = Call(&builder, computation, {param}); @@ -227,9 +227,9 @@ XlaOp DynamicConditional( }); } -StatusOr SetDimensionSizeWithRebound(ValueInference* value_inference, - XlaOp operand, XlaOp dimension_size, - int64_t dimension) { +absl::StatusOr SetDimensionSizeWithRebound( + ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, + int64_t dimension) { auto inferred_bound_status_or = value_inference->AnalyzeConstant( dimension_size, xla::ValueInferenceMode::kUpperBound); @@ -253,8 +253,8 @@ StatusOr SetDimensionSizeWithRebound(ValueInference* value_inference, return operand; } -StatusOr SetAllDimensionSizes(ValueInference* value_inference, - XlaOp operand, XlaOp size_vector) { +absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, + XlaOp operand, XlaOp size_vector) { auto builder = value_inference->builder(); TF_RETURN_IF_ERROR(builder->GetCurrentStatus()); TF_ASSIGN_OR_RETURN(auto shape_ptr, builder->GetShapePtr(operand)); diff --git a/xla/client/lib/dynamic_shaped_ops.h b/xla/client/lib/dynamic_shaped_ops.h index 969d0ddadf971..4dd4591563e64 100644 --- a/xla/client/lib/dynamic_shaped_ops.h +++ b/xla/client/lib/dynamic_shaped_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,15 +42,15 @@ XlaOp DynamicConditional( // Similar to SetDimensionSize, but automatically adjust the bound of output if // a tighter one can be inferred by `value_inference`. -StatusOr SetDimensionSizeWithRebound(ValueInference* value_inference, - XlaOp operand, XlaOp dimension_size, - int64_t dimension); +absl::StatusOr SetDimensionSizeWithRebound( + ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, + int64_t dimension); // Take a `operand` tensor and a R1 tensor `size_vector` representing the sizes // of `operand`, Call SetDimensionSize if for each dimension whose size is // dynamic. -StatusOr SetAllDimensionSizes(ValueInference* value_inference, - XlaOp operand, XlaOp size_vector); +absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, + XlaOp operand, XlaOp size_vector); } // namespace xla #endif // XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/xla/client/lib/logdet.cc b/xla/client/lib/logdet.cc index eb16067409a81..7eaee01274e21 100644 --- a/xla/client/lib/logdet.cc +++ b/xla/client/lib/logdet.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,8 @@ limitations under the License. namespace xla { SignAndLogDet SLogDet(XlaOp a) { - StatusOr result = [&]() -> StatusOr { + absl::StatusOr result = + [&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a)); auto qr = Qr(a); diff --git a/xla/client/lib/logdet.h b/xla/client/lib/logdet.h index 6ecd7a107dc29..ee3d984fa6931 100644 --- a/xla/client/lib/logdet.h +++ b/xla/client/lib/logdet.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/logdet_test.cc b/xla/client/lib/logdet_test.cc index e9384d936f319..c9be45c5ed13f 100644 --- a/xla/client/lib/logdet_test.cc +++ b/xla/client/lib/logdet_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/loops.cc b/xla/client/lib/loops.cc index e897d054fc490..9cee8e354e875 100644 --- a/xla/client/lib/loops.cc +++ b/xla/client/lib/loops.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ limitations under the License. namespace xla { -StatusOr> WhileLoopHelper( +absl::StatusOr> WhileLoopHelper( const WhileLoopHelperConditionFunction& condition_function, const WhileLoopHelperBodyFunction& body_function, absl::Span initial_values, absl::string_view name, @@ -83,19 +83,19 @@ StatusOr> WhileLoopHelper( return unpack_tuple(outputs, arity, builder); } -StatusOr> ForEachIndex( +absl::StatusOr> ForEachIndex( int64_t num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, absl::Span initial_values, absl::string_view name, XlaBuilder* builder) { auto while_cond_fn = [&](absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { + XlaBuilder* cond_builder) -> absl::StatusOr { return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type, num_iterations)); }; auto while_body_fn = [&](absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { + XlaBuilder* body_builder) -> absl::StatusOr> { XlaOp iteration = values[0]; std::vector updated_values; diff --git a/xla/client/lib/loops.h b/xla/client/lib/loops.h index 6ee295be155ee..5cd7426992cea 100644 --- a/xla/client/lib/loops.h +++ b/xla/client/lib/loops.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,13 +29,14 @@ namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, XlaBuilder*)> +typedef std::function(absl::Span, + XlaBuilder*)> WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>(absl::Span, - XlaBuilder*)> +typedef std::function>( + absl::Span, XlaBuilder*)> WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -46,7 +47,7 @@ typedef std::function>(absl::Span, // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. -StatusOr> WhileLoopHelper( +absl::StatusOr> WhileLoopHelper( const WhileLoopHelperConditionFunction& condition_function, const WhileLoopHelperBodyFunction& body_function, absl::Span initial_values, absl::string_view name, @@ -57,11 +58,11 @@ StatusOr> WhileLoopHelper( // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( +typedef std::function>( XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -StatusOr> ForEachIndex( +absl::StatusOr> ForEachIndex( int64_t num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, absl::Span initial_values, absl::string_view name, diff --git a/xla/client/lib/lu_decomposition.cc b/xla/client/lib/lu_decomposition.cc index a8d267ed12e8b..cbce5be6b06c8 100644 --- a/xla/client/lib/lu_decomposition.cc +++ b/xla/client/lib/lu_decomposition.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ namespace xla { LuDecompositionResult LuDecomposition(XlaOp a) { XlaBuilder* builder = a.builder(); - XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr { + XlaOp result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int ndims = a_shape.rank(); TF_RET_CHECK(ndims >= 2); diff --git a/xla/client/lib/lu_decomposition.h b/xla/client/lib/lu_decomposition.h index 4ec7459922aab..a2d26e02f4e63 100644 --- a/xla/client/lib/lu_decomposition.h +++ b/xla/client/lib/lu_decomposition.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/math.cc b/xla/client/lib/math.cc index 5a3aacd8941b8..60592de4ddc80 100644 --- a/xla/client/lib/math.cc +++ b/xla/client/lib/math.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,7 +75,7 @@ static XlaOp DoWithUpcastToF32(XlaOp operand, absl::Span upcast_types, const std::function& operation) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); PrimitiveType elem_ty = shape.element_type(); bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty); @@ -107,7 +107,7 @@ static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) { XlaOp IsPosInf(XlaOp operand) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); // Note that this is only correct for floating-point types. If we wanted it @@ -118,7 +118,7 @@ XlaOp IsPosInf(XlaOp operand) { XlaOp IsNegInf(XlaOp operand) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); // Note that this is only correct for floating-point types. If we wanted it @@ -129,7 +129,7 @@ XlaOp IsNegInf(XlaOp operand) { XlaOp IsInf(XlaOp operand) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand)); return IsPosInf(Abs(operand)); }); @@ -137,7 +137,7 @@ XlaOp IsInf(XlaOp operand) { XlaOp IsNan(XlaOp operand) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand)); return Ne(operand, operand); }); @@ -145,7 +145,7 @@ XlaOp IsNan(XlaOp operand) { XlaOp IsNegZero(XlaOp operand) { auto& b = *operand.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); @@ -291,7 +291,7 @@ static XlaOp ErfImpl64(XlaOp x) { XlaOp Erfc(XlaOp x) { auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); // erfc(x) = @@ -336,26 +336,6 @@ static XlaOp ErfImpl32(XlaOp x) { EvaluatePolynomial(x2, kBeta); } -XlaOp Erf(XlaOp x) { - auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { - TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); - TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); - // erf(x) = - // erf_impl(x) if x < 1 - // 1 - erfc_impl(x) otherwise - if (shape.element_type() == F64) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x), - ScalarLike(x, 1) - ErfcImpl64(x)); - } - // Erf(c)Impl don't have enough precision when run with bf16 intermediates - // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32( - x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, - [](XlaOp x) { return ErfImpl32(x); }); - }); -} - namespace { // Approximation for the inverse error function from @@ -403,7 +383,7 @@ XlaOp ErfInv32(XlaOp x) { // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is // indeterminate, and can give nan or -/+inf.) auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x)); return Select(Eq(Abs(x), ScalarLike(x, 1)), x * MaxValue(&b, shape.element_type()), result); @@ -483,7 +463,7 @@ XlaOp ErfInv64(XlaOp x) { // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is // indeterminate, and can give nan or -/+inf.) auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x)); return Select(Eq(Abs(x), ScalarLike(x, 1)), x * MaxValue(&b, shape.element_type()), result); @@ -494,7 +474,7 @@ XlaOp ErfInv64(XlaOp x) { XlaOp ErfInv(XlaOp x) { auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); if (shape.element_type() == F64) { @@ -625,7 +605,7 @@ XlaOp Lgamma(XlaOp input) { }; auto& b = *input.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input)); // F16 and BF16 don't provide sufficient precision for intermediate results // here (although it's better than you might expect!), so do the @@ -726,7 +706,7 @@ XlaOp Digamma(XlaOp input) { }; auto& b = *input.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); return DoWithUpcastToF32( input, @@ -753,12 +733,12 @@ XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, // size, and then run independent loops for each warp's worth of // data. auto cond = [&](absl::Span vals, - XlaBuilder* builder) -> StatusOr { + XlaBuilder* builder) -> absl::StatusOr { XlaOp enabled = vals[0]; return Any(enabled); }; auto body = [&](absl::Span vals, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { XlaOp enabled = vals[0]; XlaOp r = vals[1]; XlaOp c = vals[2]; @@ -791,7 +771,7 @@ XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, }; }; auto& b = *ax.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector vals = { enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0), FullLike(a, 0), @@ -822,13 +802,13 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, xla::PrimitiveType type) { // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2 auto cond = [&](absl::Span vals, - XlaBuilder* builder) -> StatusOr { + XlaBuilder* builder) -> absl::StatusOr { XlaOp enabled = vals[0]; XlaOp c = vals[5]; return And(Lt(c, ScalarLike(c, 2000)), Any(enabled)); }; auto body = [&](absl::Span vals, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { XlaOp enabled = vals[0]; XlaOp ans = vals[1]; XlaOp t = vals[2]; @@ -911,7 +891,7 @@ XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled, }; auto& b = *ax.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { XlaOp y = ScalarLike(a, 1) - a; XlaOp z = x + y + ScalarLike(x, 1); XlaOp c = ScalarLike(x, 0); @@ -975,7 +955,7 @@ XlaOp Igamma(XlaOp a, XlaOp x) { output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; }; - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); if (a_shape != x_shape) { @@ -1028,7 +1008,7 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; }; - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); if (a_shape != x_shape) { @@ -1081,7 +1061,7 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x) { output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; }; - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); if (a_shape != x_shape) { @@ -1126,7 +1106,7 @@ XlaOp Igammac(XlaOp a, XlaOp x) { result = Select(x_is_infinity, ZerosLike(result), result); return Select(out_of_range, FullLike(a, 1), result); }; - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); if (a_shape != x_shape) { @@ -1157,7 +1137,7 @@ XlaOp Igammac(XlaOp a, XlaOp x) { // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { // Reject non-real non-fp inputs (What does it even mean to round a complex // number? Do you round each component equally? In that case, you should // just ask for that explicitly.) @@ -1175,7 +1155,7 @@ XlaOp RoundToEven(XlaOp x) { // acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) XlaOp Acos(XlaOp x) { XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); if (primitive_util::IsComplexType(shape.element_type())) { @@ -1214,7 +1194,7 @@ XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } // overflows; x < -1 simply yields nan. This is quite different than asinh!) XlaOp Acosh(XlaOp x) { XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); auto one = ScalarLike(x, 1); @@ -1252,7 +1232,7 @@ XlaOp Acosh(XlaOp x) { // -asinh(x). XlaOp Asinh(XlaOp x) { XlaBuilder* b = x.builder(); - auto do_it = [&](XlaOp x) -> StatusOr { + auto do_it = [&](XlaOp x) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); auto one = ScalarLike(x, 1); @@ -1300,7 +1280,7 @@ XlaOp Asinh(XlaOp x) { // atanh(x) = nan otherwise XlaOp Atanh(XlaOp x) { XlaBuilder* b = x.builder(); - auto do_it = [&](XlaOp x) -> StatusOr { + auto do_it = [&](XlaOp x) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5); @@ -1347,7 +1327,7 @@ XlaOp Cosh(XlaOp x) { // we deem this acceptable. XlaOp Sinh(XlaOp x) { XlaBuilder* b = x.builder(); - auto do_it = [&](XlaOp x) -> StatusOr { + auto do_it = [&](XlaOp x) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); auto one_half = ScalarLike(x, 0.5); auto log_one_half = Log(ScalarLike(x, 0.5)); @@ -1380,7 +1360,7 @@ XlaOp Sinh(XlaOp x) { XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); auto perform_conj = primitive_util::IsComplexType(shape.element_type()) && conjugate; @@ -1390,7 +1370,7 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaOp NextAfter(XlaOp from, XlaOp to) { auto builder = from.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from)); int bitwidth = primitive_util::BitWidth(shape.element_type()); auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth); @@ -1547,7 +1527,7 @@ static XlaOp I0eImpl64(XlaOp x) { XlaOp BesselI0e(XlaOp x) { auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI0e", x)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); if (shape.element_type() == F64) { @@ -1643,7 +1623,7 @@ static XlaOp I1eImpl64(XlaOp x) { XlaOp BesselI1e(XlaOp x) { auto& b = *x.builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI1e", x)); TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); if (shape.element_type() == F64) { @@ -1665,7 +1645,7 @@ static XlaOp LentzThompsonBarnettAlgorithm( const ForEachIndexBodyFunction& nth_partial_denominator, absl::Span inputs, absl::string_view name) { auto& b = *inputs.front().builder(); - return b.ReportErrorOrReturn([&]() -> StatusOr { + return b.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RET_CHECK(num_iterations < INT32_MAX); enum { @@ -1684,9 +1664,9 @@ static XlaOp LentzThompsonBarnettAlgorithm( // Inputs follow all of the other state. kFirstInputIdx, }; - auto while_cond_fn = [num_iterations]( - absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { + auto while_cond_fn = + [num_iterations](absl::Span values, + XlaBuilder* cond_builder) -> absl::StatusOr { auto iteration = values[kIterationIdx]; auto iterations_remain_cond = Lt(iteration, ScalarLike(iteration, num_iterations)); @@ -1697,7 +1677,7 @@ static XlaOp LentzThompsonBarnettAlgorithm( auto while_body_fn = [small, threshold, &nth_partial_numerator, &nth_partial_denominator]( absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { + XlaBuilder* body_builder) -> absl::StatusOr> { XlaOp iteration = values[kIterationIdx]; TF_ASSIGN_OR_RETURN( @@ -1763,7 +1743,7 @@ static XlaOp LentzThompsonBarnettAlgorithm( XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) { auto& builder = *x.builder(); - return builder.ReportErrorOrReturn([&]() -> StatusOr { + return builder.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder.GetShape(b)); TF_ASSIGN_OR_RETURN(Shape x_shape, builder.GetShape(x)); @@ -1793,7 +1773,7 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) { // case: the partial numerator for the first iteration is one. auto NthPartialBetaincNumerator = [&](XlaOp iteration, absl::Span inputs, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto a = inputs[0]; auto b = inputs[1]; auto x = inputs[2]; @@ -1820,7 +1800,7 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) { auto NthPartialBetaincDenominator = [&shape](XlaOp iteration, absl::Span inputs, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto x = inputs[2]; auto iteration_bcast = Broadcast(iteration, shape.dimensions()); return std::vector{ @@ -1904,7 +1884,7 @@ XlaOp Polygamma(XlaOp n, XlaOp x) { ScalarLike(n, nan), output); return output; }; - return builder.ReportErrorOrReturn([&]() -> StatusOr { + return builder.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n)); TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x)); if (n_shape != x_shape) { @@ -2023,7 +2003,7 @@ XlaOp Zeta(XlaOp x, XlaOp q) { return output; }; - return builder.ReportErrorOrReturn([&]() -> StatusOr { + return builder.ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x)); TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q)); if (x_shape != q_shape) { diff --git a/xla/client/lib/math.h b/xla/client/lib/math.h index 9ef7a5967d7a0..74b8a387a416d 100644 --- a/xla/client/lib/math.h +++ b/xla/client/lib/math.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,9 +46,6 @@ XlaOp Reciprocal(XlaOp operand); // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); -// Computes an approximation of the error function. -XlaOp Erf(XlaOp x); - // Computes an approximation of the inverse of the error function. XlaOp ErfInv(XlaOp x); diff --git a/xla/client/lib/math_test.cc b/xla/client/lib/math_test.cc index 71b94898c87bf..90c3304536050 100644 --- a/xla/client/lib/math_test.cc +++ b/xla/client/lib/math_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/matrix.cc b/xla/client/lib/matrix.cc index e5b060a49e290..156689d4cda39 100644 --- a/xla/client/lib/matrix.cc +++ b/xla/client/lib/matrix.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m, XlaOp GetDiagonalMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); auto n_dims = static_cast(shape.rank()); TF_RET_CHECK(n_dims >= 2); @@ -79,7 +79,7 @@ XlaOp GetDiagonalMask(XlaOp x, int diagonal) { XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); auto n_dims = static_cast(shape.rank()); TF_RET_CHECK(n_dims >= 2); @@ -113,7 +113,7 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); auto n_dims = static_cast(shape.rank()); TF_RET_CHECK(n_dims >= 2); @@ -176,7 +176,7 @@ XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) { XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { XlaBuilder* builder = matrix.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix)); TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag)); auto n_dims = static_cast(shape.rank()); @@ -215,7 +215,7 @@ XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64_t n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); @@ -242,7 +242,7 @@ XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } XlaOp Symmetrize(XlaOp x, bool lower) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); if (shape.rank() < 2) { return InvalidArgument( @@ -297,7 +297,7 @@ std::optional, 3>> EinsumDiagonalLabels( // reduction. xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span config) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); Shape iota_shape = ShapeUtil::MakeShape(S32, x_shape.dimensions()); XlaOp mask = ConstantR0(builder, true); @@ -317,7 +317,7 @@ xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span config) { xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto labels = EinsumDiagonalLabels(config); if (!labels) { return x; @@ -333,7 +333,7 @@ xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config) { xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span config) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto labels = EinsumDiagonalLabels(config); if (!labels) { return x; @@ -390,7 +390,7 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, std::optional preferred_element_type, bool grad_x, bool grad_y) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto x_diagonal_labels = EinsumDiagonalLabels(x_config); if (x_diagonal_labels) { return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y, @@ -590,7 +590,7 @@ XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y, std::optional preferred_element_type, bool grad_x, bool grad_y) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { std::string string("...mk,...kn->...mn"); if (transpose_x) { std::swap(string[3], string[4]); @@ -603,7 +603,7 @@ XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y, }); } -StatusOr, 3>> ParseEinsumString( +absl::StatusOr, 3>> ParseEinsumString( absl::string_view einsum_config, int64_t x_rank, int64_t y_rank) { std::array, 3> einsum_config_numeric; std::vector main_split = @@ -612,7 +612,7 @@ StatusOr, 3>> ParseEinsumString( return InvalidArgument("Expected one \",\" in einsum_config."); } - auto maybe_invalid_character = [](char d) { + auto maybe_invalid_character = [](char d) -> absl::Status { if (absl::ascii_isalpha(d)) { return OkStatus(); } @@ -625,7 +625,7 @@ StatusOr, 3>> ParseEinsumString( auto string_config_to_numeric = [&](absl::string_view config, bool is_input_config, int64_t input_rank, int64_t ellipsis_rank, - std::vector* numeric_config) -> StatusOr { + std::vector* numeric_config) -> absl::StatusOr { std::vector splits = absl::StrSplit(config, "..."); if (splits.empty()) { return ellipsis_rank; @@ -723,7 +723,7 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, std::optional preferred_element_type, bool grad_x, bool grad_y) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { auto new_config = NormalizeEinsumString(einsum_config); if (!new_config.empty()) { return Einsum(x, y, new_config, precision, preferred_element_type, grad_x, @@ -748,7 +748,7 @@ XlaOp Einsum(XlaOp x, absl::string_view einsum_config, XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64_t n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); diff --git a/xla/client/lib/matrix.h b/xla/client/lib/matrix.h index 48f75b2e650b6..b1b18b1ae9fd8 100644 --- a/xla/client/lib/matrix.h +++ b/xla/client/lib/matrix.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -117,7 +117,7 @@ xla::XlaOp BatchDot( // NOTE: This function is meant for testing, there is no need to call it // directly. -StatusOr, 3>> ParseEinsumString( +absl::StatusOr, 3>> ParseEinsumString( absl::string_view einsum_config, int64_t x_rank, int64_t y_rank); // If an einsum config does not contain an -> one will be added and the output diff --git a/xla/client/lib/matrix_test.cc b/xla/client/lib/matrix_test.cc index 5201f81ebfbc6..1f3aba6d92e61 100644 --- a/xla/client/lib/matrix_test.cc +++ b/xla/client/lib/matrix_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/pooling.cc b/xla/client/lib/pooling.cc index eb7919bc7fff4..83534ba696190 100644 --- a/xla/client/lib/pooling.cc +++ b/xla/client/lib/pooling.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -80,7 +80,7 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value, absl::Span stride, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); PrimitiveType accumulation_type = init_shape.element_type(); @@ -140,7 +140,7 @@ XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, absl::Span stride, Padding padding, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); PrimitiveType dtype = operand_shape.element_type(); auto max_computation = CreateScalarMaxComputation(dtype, b); @@ -156,7 +156,7 @@ XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, const TensorFormat& data_format, const bool counts_include_padding) { XlaBuilder* b = operand.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); PrimitiveType dtype = operand_shape.element_type(); auto init_value = Zero(b, dtype); @@ -202,7 +202,7 @@ XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, const TensorFormat& data_format, const bool counts_include_padding) { XlaBuilder* b = out_backprop.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { const int num_dims = kernel_size.size(); const int num_gradients = gradients_size.size(); if (num_gradients != num_dims) { diff --git a/xla/client/lib/pooling.h b/xla/client/lib/pooling.h index d8f0ff0a91654..d48c4de7cc825 100644 --- a/xla/client/lib/pooling.h +++ b/xla/client/lib/pooling.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/pooling_test.cc b/xla/client/lib/pooling_test.cc index 46a12c30e870e..67de2130e5848 100644 --- a/xla/client/lib/pooling_test.cc +++ b/xla/client/lib/pooling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/prng.cc b/xla/client/lib/prng.cc index 00d9dcbdf5a60..def89a0d40b95 100644 --- a/xla/client/lib/prng.cc +++ b/xla/client/lib/prng.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -505,7 +505,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, XlaOp maxval) { XlaBuilder* builder = bits.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* minval_shape, builder->GetShapePtr(minval)); TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits)); diff --git a/xla/client/lib/prng.h b/xla/client/lib/prng.h index 77d931ef074ad..a54b629564b0a 100644 --- a/xla/client/lib/prng.h +++ b/xla/client/lib/prng.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/prng_test.cc b/xla/client/lib/prng_test.cc index a5f9b24cd1027..424846cc6fc2f 100644 --- a/xla/client/lib/prng_test.cc +++ b/xla/client/lib/prng_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/qr.cc b/xla/client/lib/qr.cc index e76e6f5631354..39631584fea62 100644 --- a/xla/client/lib/qr.cc +++ b/xla/client/lib/qr.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ limitations under the License. namespace xla { QrDecomposition Qr(XlaOp a) { - auto result = [&]() -> StatusOr { + auto result = [&]() -> absl::StatusOr { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int num_dims = a_shape.rank(); @@ -69,7 +69,7 @@ QrDecomposition Qr(XlaOp a) { XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus)); if (a_shape.rank() < 2) { @@ -109,7 +109,7 @@ XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) { } void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r) { - StatusOr a_shape_or = a.builder()->GetShape(a); + absl::StatusOr a_shape_or = a.builder()->GetShape(a); if (!a_shape_or.ok()) { q = a.builder()->ReportError(a_shape_or.status()); r = q; diff --git a/xla/client/lib/qr.h b/xla/client/lib/qr.h index 6d8a102b28fc1..ce51ab342bb39 100644 --- a/xla/client/lib/qr.h +++ b/xla/client/lib/qr.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/qr_test.cc b/xla/client/lib/qr_test.cc index eb3f2e2ae2dfe..a21932b3e797e 100644 --- a/xla/client/lib/qr_test.cc +++ b/xla/client/lib/qr_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -145,4 +145,39 @@ XLA_TEST_F(QrTest, SubnormalComplex) { xla::ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_F(QrTest, DuplicateHouseholderExpansion) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a0_vals({ + {0, 1, 1}, + {1, 0, 1}, + {1, 1, 0}, + }); + xla::Array2D a1_vals({ + {1, 0}, + {0, 1}, + {1, 0}, + }); + + // Verifies that different computations are created to generate HouseHolder + // transformations with identical QR shapes, but different tau shapes. + // The first QR decomposition should generate a ([3,3], [3]) computation, + // the second should generate a ([3,3], [2]) computation. Mismatch will result + // in compilation failure. + + xla::XlaOp a0, q0, r0; + auto a0_data = CreateR2Parameter(a0_vals, 0, "a0", &builder, &a0); + xla::QrExplicit(a0, /*full_matrices=*/true, q0, r0); + + xla::XlaOp a1, q1, r1; + auto a1_data = CreateR2Parameter(a1_vals, 1, "a1", &builder, &a1); + xla::QrExplicit(a1, /*full_matrices=*/true, q1, r1); + + // Verifies that the decomposition composes back to the original matrix. + xla::BatchDot(q1, r1, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a1_vals, {a0_data.get(), a1_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + } // namespace diff --git a/xla/client/lib/quantize.h b/xla/client/lib/quantize.h index a510651e1e640..f9835c42642d3 100644 --- a/xla/client/lib/quantize.h +++ b/xla/client/lib/quantize.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -82,7 +82,7 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, absl::string_view mode_string = "MIN_COMBINED", bool transpose_output = false) { XlaBuilder* const builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { float half_range = !std::is_signed::value ? 0.0f diff --git a/xla/client/lib/quantize_test.cc b/xla/client/lib/quantize_test.cc index df7526a1c7680..d93defa46e05b 100644 --- a/xla/client/lib/quantize_test.cc +++ b/xla/client/lib/quantize_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/self_adjoint_eig.cc b/xla/client/lib/self_adjoint_eig.cc index d25b2e12bbafe..606f968c6648c 100644 --- a/xla/client/lib/self_adjoint_eig.cc +++ b/xla/client/lib/self_adjoint_eig.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ namespace xla { SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64_t max_iter, float tol, bool sort_eigenvalues) { XlaBuilder* builder = a.builder(); - XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr { + XlaOp result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t num_dims = a_shape.rank(); if (num_dims < 2) { diff --git a/xla/client/lib/self_adjoint_eig.h b/xla/client/lib/self_adjoint_eig.h index c21163dff056b..f375f192e71f0 100644 --- a/xla/client/lib/self_adjoint_eig.h +++ b/xla/client/lib/self_adjoint_eig.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/self_adjoint_eig_test.cc b/xla/client/lib/self_adjoint_eig_test.cc index b5cac96a4e514..7be635f2a9a79 100644 --- a/xla/client/lib/self_adjoint_eig_test.cc +++ b/xla/client/lib/self_adjoint_eig_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -286,8 +286,9 @@ XLA_TEST_P(RandomEighTest, Random) { GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); // TODO(phawkins): this would be better expressed as <= 6e-3. - ComputeAndCompareR0(&builder, 3e-3, {a_data.get()}, - ErrorSpec(3e-3, 0)); + double kExpected = 0.00300000003; + ComputeAndCompareR0(&builder, kExpected, {a_data.get()}, + ErrorSpec(kExpected, 0)); } #ifndef XLA_TEST_BACKEND_CPU diff --git a/xla/client/lib/slicing.cc b/xla/client/lib/slicing.cc index 914e41ddabbf2..881181b86c317 100644 --- a/xla/client/lib/slicing.cc +++ b/xla/client/lib/slicing.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ XlaOp DynamicStridedSlice(XlaOp input, absl::Span base_indices, XlaOp SliceInMinorDims(XlaOp x, absl::Span start, absl::Span end) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RET_CHECK(start.size() == end.size()); int64_t n_minor_dims = start.size(); @@ -78,7 +78,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64_t n_dims = shape.rank(); const int64_t start_size = start.size(); @@ -98,7 +98,7 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64_t n_dims = shape.rank(); const int64_t n_minor_dims = start.size(); @@ -120,7 +120,7 @@ std::vector ConcatVectors(absl::Span xs, return output; } -StatusOr> PrependZerosInMajorDims( +absl::StatusOr> PrependZerosInMajorDims( XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); @@ -138,7 +138,7 @@ StatusOr> PrependZerosInMajorDims( XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, absl::Span sizes) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64_t n_dims = shape.rank(); int64_t n_minor_dims = starts.size(); @@ -156,7 +156,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); return DynamicUpdateSlice(x, update, padded_starts); }); @@ -164,7 +164,7 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); if (ShapeUtil::ElementHasBitWidth(index_shape, 64) && @@ -234,7 +234,7 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, const std::function& combiner) { XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); std::vector index_broadcast_dims; @@ -273,7 +273,7 @@ XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, int64_t batch_dims) { XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); if (dim < batch_dims) { diff --git a/xla/client/lib/slicing.h b/xla/client/lib/slicing.h index b0b955c023420..af734497522d3 100644 --- a/xla/client/lib/slicing.h +++ b/xla/client/lib/slicing.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/slicing_test.cc b/xla/client/lib/slicing_test.cc index d0578caffa064..ff1dfe3fd7c31 100644 --- a/xla/client/lib/slicing_test.cc +++ b/xla/client/lib/slicing_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/sorting.cc b/xla/client/lib/sorting.cc index 8fad0484c2daa..2f3de2c0427b9 100644 --- a/xla/client/lib/sorting.cc +++ b/xla/client/lib/sorting.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ namespace xla { XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { XlaBuilder* const builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; int64_t last_dim_size = input_shape.dimensions(last_dim); @@ -158,7 +158,7 @@ XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, PrimitiveType index_type) { XlaBuilder* const builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; // Calculate per partition size. @@ -183,7 +183,7 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, auto topk_body_fn = [&](XlaOp partition, absl::Span values_and_indices, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto values = values_and_indices[0]; auto indices = values_and_indices[1]; auto input = values_and_indices[2]; diff --git a/xla/client/lib/sorting.h b/xla/client/lib/sorting.h index 03b317037df48..4af4f8caaf977 100644 --- a/xla/client/lib/sorting.h +++ b/xla/client/lib/sorting.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/sorting_test.cc b/xla/client/lib/sorting_test.cc index 4f627993a9b63..a2ce452b04fb9 100644 --- a/xla/client/lib/sorting_test.cc +++ b/xla/client/lib/sorting_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/svd.cc b/xla/client/lib/svd.cc index af6bfe3b0af79..9e7a33209d0b7 100644 --- a/xla/client/lib/svd.cc +++ b/xla/client/lib/svd.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -107,8 +107,9 @@ struct OneSidedJacobiRotation { // // A[i, j:] * H = [sigma, 0, 0, ..., 0] // -StatusOr HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, - PrecisionConfig::Precision precision) { +absl::StatusOr HouseRow( + XlaOp a, XlaOp i, XlaOp j, XlaOp eps, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t num_dims = a_shape.rank(); @@ -172,8 +173,9 @@ StatusOr HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, // // H * A[i:, j] = [xnorm, 0, 0, ..., 0] // -StatusOr HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, - PrecisionConfig::Precision precision) { +absl::StatusOr HouseCol( + XlaOp a, XlaOp i, XlaOp j, XlaOp eps, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t num_dims = a_shape.rank(); @@ -250,7 +252,7 @@ StatusOr HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, // A = np.matmul(A, R) // return LL, A, RR // -StatusOr HouseHolderBidiagonalization( +absl::StatusOr HouseHolderBidiagonalization( XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -268,13 +270,13 @@ StatusOr HouseHolderBidiagonalization( IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims); auto while_cond_fn = [&](absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { + XlaBuilder* cond_builder) -> absl::StatusOr { auto i = values[0]; return Lt(i, ScalarLike(i, n - 2)); }; auto while_body_fn = [&](absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { + XlaBuilder* body_builder) -> absl::StatusOr> { auto i = values[0]; auto one = ScalarLike(i, 1); @@ -357,7 +359,8 @@ StatusOr HouseHolderBidiagonalization( // s = 0.0 // return c, s // -StatusOr MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) { +absl::StatusOr MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, + XlaOp eps) { auto zero = ScalarLike(ps, 0.0); auto one = ScalarLike(ps, 1.0); auto two = ScalarLike(ps, 2.0); @@ -411,8 +414,10 @@ StatusOr MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) { // rot_l = rot @ rot_r // return rot_l, rot_r // -StatusOr GetOneSidedJacobiRotation(XlaOp a, XlaOp p, - XlaOp q, XlaOp eps) { +absl::StatusOr GetOneSidedJacobiRotation(XlaOp a, + XlaOp p, + XlaOp q, + XlaOp eps) { XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1}); @@ -449,8 +454,8 @@ StatusOr GetOneSidedJacobiRotation(XlaOp a, XlaOp p, } // Apply one-sided Jacobi on elements at indices pp, pq, qp, qq. -StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q, - XlaOp eps) { +absl::StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, + XlaOp q, XlaOp eps) { XlaOp u = svd_result.u; XlaOp v = svd_result.v; XlaOp d = svd_result.d; @@ -563,7 +568,7 @@ StatusOr OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q, return svd_result; } -StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { +absl::StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { XlaBuilder* builder = w.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); auto num_dims = static_cast(shape.rank()); @@ -600,14 +605,14 @@ StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { } // Main boby of One-sided Jacobi Method. -StatusOr> WhileLoopFn( +absl::StatusOr> WhileLoopFn( absl::Span initial_values, // int matrix_dimension, // int max_sweep_updates, // absl::string_view name, // XlaBuilder* builder) { auto while_cond_fn = [&](absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { + XlaBuilder* cond_builder) -> absl::StatusOr { auto k = values[0]; auto max_sweeps = ScalarLike(k, max_sweep_updates); auto sweep_update_cond = Gt(max_sweeps, k); @@ -623,27 +628,27 @@ StatusOr> WhileLoopFn( auto while_body_fn = [&](absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { + XlaBuilder* body_builder) -> absl::StatusOr> { auto while_cond_fn_inner = [&](absl::Span values_inner, - XlaBuilder* inner_cond_builder) -> StatusOr { + XlaBuilder* inner_cond_builder) -> absl::StatusOr { auto p = values_inner[0]; return Lt(p, ScalarLike(p, matrix_dimension - 1)); }; - auto while_body_fn_inner = - [&](absl::Span values_inner, - XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_body_fn_inner = [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) + -> absl::StatusOr> { auto while_cond_fn_innermost = [&](absl::Span values_innermost, - XlaBuilder* innermost_cond_builder) -> StatusOr { + XlaBuilder* innermost_cond_builder) -> absl::StatusOr { auto q = values_innermost[1]; return Lt(q, ScalarLike(q, matrix_dimension)); }; auto while_body_fn_innermost = [&](absl::Span values_innermost, XlaBuilder* innermost_body_builder) - -> StatusOr> { + -> absl::StatusOr> { auto p = values_innermost[0]; auto q = values_innermost[1]; @@ -731,7 +736,8 @@ StatusOr> WhileLoopFn( // Sort singular values in decending order, and make sure they are non-negative // by flipping the signs of negative diagonal values and transferring the signs // to V. And for numeric stability, renormalize U and V. -StatusOr SortBySingularValuesAndPostProcessing(SVDResult result) { +absl::StatusOr SortBySingularValuesAndPostProcessing( + SVDResult result) { XlaBuilder* builder = result.d.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d)); const int64_t num_dims = shape.rank(); diff --git a/xla/client/lib/svd.h b/xla/client/lib/svd.h index 177d84f3153d8..07f361f73b3a3 100644 --- a/xla/client/lib/svd.h +++ b/xla/client/lib/svd.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/svd_test.cc b/xla/client/lib/svd_test.cc index 56d034e9f71ab..f27d78974e1dd 100644 --- a/xla/client/lib/svd_test.cc +++ b/xla/client/lib/svd_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/testing.cc b/xla/client/lib/testing.cc index dcea26ba5943e..0461108bb2496 100644 --- a/xla/client/lib/testing.cc +++ b/xla/client/lib/testing.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie( std::unique_ptr MakeFakeDataOrDie( const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr literal_status = MakeFakeLiteral(shape); + absl::StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. diff --git a/xla/client/lib/testing.h b/xla/client/lib/testing.h index b20ae049df8a8..cbc5f79340d25 100644 --- a/xla/client/lib/testing.h +++ b/xla/client/lib/testing.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/tridiagonal.cc b/xla/client/lib/tridiagonal.cc index 3f629aa8e8d43..4daab7b4e8408 100644 --- a/xla/client/lib/tridiagonal.cc +++ b/xla/client/lib/tridiagonal.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,10 +50,10 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64_t rank, return OkStatus(); } -StatusOr CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, - XlaOp main_diagonal, - XlaOp upper_diagonal, - XlaOp rhs) { +absl::StatusOr CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, + XlaOp rhs) { XlaBuilder* builder = lower_diagonal.builder(); TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape, @@ -160,9 +160,9 @@ Status ValidateTridiagonalMatMulDiagonal(const Shape& diagonal_shape, return OkStatus(); } -StatusOr CheckMatMulSystemAndReturnShapeParams( - XlaOp upper_diagonal, XlaOp main_diagonal, XlaOp lower_diagonal, - XlaOp rhs) { +absl::StatusOr +CheckMatMulSystemAndReturnShapeParams(XlaOp upper_diagonal, XlaOp main_diagonal, + XlaOp lower_diagonal, XlaOp rhs) { XlaBuilder* builder = upper_diagonal.builder(); TF_ASSIGN_OR_RETURN(const Shape upper_diagonal_shape, @@ -217,8 +217,9 @@ XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) { } template -StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, XlaOp main_diagonal, - XlaOp upper_diagonal, XlaOp rhs); +absl::StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, XlaOp rhs); // Applies Thomas algorithm to solve a linear system where the linear operand // is a tri-diagonal matrix. @@ -233,10 +234,10 @@ StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, XlaOp main_diagonal, // the right-hand-side `rhs` should be [..., num_rhs, num_equations]. The // solution will have the shape [..., num_rhs, num_equations]. template <> -StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, - XlaOp main_diagonal, - XlaOp upper_diagonal, - XlaOp rhs) { +absl::StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, + XlaOp rhs) { XlaBuilder* builder = lower_diagonal.builder(); TF_ASSIGN_OR_RETURN(int64_t num_eqs, @@ -258,7 +259,7 @@ StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, auto preparation_body_fn = [](XlaOp i, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto upper_diagonal_coeffs = values[0]; auto upper_diagonal = values[1]; // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i]; @@ -275,7 +276,7 @@ StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, // Forward transformation. auto forward_transformation_fn = [](XlaOp i_minus_one, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto lower_diagonal = values[0]; auto main_diagonal = values[1]; auto rhs = values[2]; @@ -333,7 +334,7 @@ StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, Coefficient(main_diag_after_elimination, num_eqs - 1)); auto bwd_reduction_fn = [num_eqs](XlaOp j, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { auto x_coeffs = values[0]; auto rhs_after_elimination = values[1]; auto upper_diagonal_coeffs = values[2]; @@ -368,9 +369,10 @@ StatusOr TridiagonalSolverImpl(XlaOp lower_diagonal, } // namespace -StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp lower_diagonal, - XlaOp main_diagonal, XlaOp upper_diagonal, - XlaOp rhs) { +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, + XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, XlaOp rhs) { switch (algo) { case kThomas: return TridiagonalSolverImpl(lower_diagonal, main_diagonal, @@ -394,8 +396,8 @@ StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp lower_diagonal, // The right-hand-side d is expected to have dimension // [..., num_rhs, num_equations]. // The solution will have size [..., num_rhs, num_equations]. -StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, - XlaOp rhs) { +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, + XlaOp rhs) { XlaBuilder* builder = diagonals.builder(); TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals)); const int64_t rank = diagonals_shape.rank(); @@ -440,8 +442,9 @@ StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, // [..., 0] is ignored. // The `right-hand-side` is expected to have dimension [..., M, N]. // The solution will have size [..., M, N]. -StatusOr TridiagonalMatMul(XlaOp upper_diagonal, XlaOp main_diagonal, - XlaOp lower_diagonal, XlaOp rhs) { +absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, + XlaOp main_diagonal, + XlaOp lower_diagonal, XlaOp rhs) { TF_ASSIGN_OR_RETURN(const TridiagonalMatMulShapeParams shape_params, CheckMatMulSystemAndReturnShapeParams( upper_diagonal, main_diagonal, lower_diagonal, rhs)); diff --git a/xla/client/lib/tridiagonal.h b/xla/client/lib/tridiagonal.h index f458a1934711d..10ce52ed21ca1 100644 --- a/xla/client/lib/tridiagonal.h +++ b/xla/client/lib/tridiagonal.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,15 +24,17 @@ namespace tridiagonal { enum SolverAlgorithm { kThomas }; -StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp lower_diagonal, - XlaOp main_diagonal, XlaOp upper_diagonal, - XlaOp rhs); +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, + XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, XlaOp rhs); -StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, - XlaOp rhs); +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, + XlaOp rhs); -StatusOr TridiagonalMatMul(XlaOp upper_diagonal, XlaOp main_diagonal, - XlaOp lower_diagonal, XlaOp rhs); +absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, + XlaOp main_diagonal, + XlaOp lower_diagonal, XlaOp rhs); } // namespace tridiagonal } // namespace xla diff --git a/xla/client/lib/tridiagonal_test.cc b/xla/client/lib/tridiagonal_test.cc index c76aa42aa9e6c..8be29c3304d79 100644 --- a/xla/client/lib/tridiagonal_test.cc +++ b/xla/client/lib/tridiagonal_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/lib/tuple.cc b/xla/client/lib/tuple.cc index 943be6ade2dcc..ae116fba3ae72 100644 --- a/xla/client/lib/tuple.cc +++ b/xla/client/lib/tuple.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ limitations under the License. namespace xla { -StatusOr> DisassembleTuple(XlaOp tuple) { +absl::StatusOr> DisassembleTuple(XlaOp tuple) { TF_ASSIGN_OR_RETURN(Shape shape, tuple.builder()->GetShape(tuple)); ShapeTree result(shape); result.ForEachMutableElement([&](ShapeIndexView index, XlaOp* element) { diff --git a/xla/client/lib/tuple.h b/xla/client/lib/tuple.h index 59cbb24f5353b..56c8a3f6b99e0 100644 --- a/xla/client/lib/tuple.h +++ b/xla/client/lib/tuple.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ namespace xla { // Returns a ShapeTree where each index is a GetTupleElement instruction for // that subshape of the tuple. The root index is the original argument. -StatusOr> DisassembleTuple(XlaOp tuple); +absl::StatusOr> DisassembleTuple(XlaOp tuple); // Assembles a tuple from a ShapeTree that contains the leaves of the tuple. // Non-leaf elements of the ShapeTree are ignored. DisassembleTuple and diff --git a/xla/client/lib/tuple_test.cc b/xla/client/lib/tuple_test.cc index 6a3368014d412..1e1f8eb9ed098 100644 --- a/xla/client/lib/tuple_test.cc +++ b/xla/client/lib/tuple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/local_client.cc b/xla/client/local_client.cc index 8b61d6bd3ed6d..033361e897b67 100644 --- a/xla/client/local_client.cc +++ b/xla/client/local_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,8 +32,8 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -StatusOr BorrowStreamForDevice(int device_ordinal, - Backend* backend) { +absl::StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { if (device_ordinal < 0) { device_ordinal = backend->default_device_ordinal(); } @@ -115,7 +115,7 @@ Status LocalExecutable::ValidateExecutionOptions( return OkStatus(); } -StatusOr> +absl::StatusOr> LocalExecutable::RunHelper(const absl::Span argument_shapes, ExecutableRunOptions run_options) { const ComputationLayout& computation_layout = @@ -171,7 +171,7 @@ LocalExecutable::RunHelper(const absl::Span argument_shapes, return std::make_pair(service_options, std::move(stream)); } -StatusOr LocalExecutable::Run( +absl::StatusOr LocalExecutable::Run( const absl::Span arguments, ExecutableRunOptions run_options) { std::vector argument_shapes; @@ -185,7 +185,7 @@ StatusOr LocalExecutable::Run( }); } -StatusOr LocalExecutable::Run( +absl::StatusOr LocalExecutable::Run( std::vector arguments, ExecutableRunOptions run_options) { std::vector argument_shapes; argument_shapes.reserve(arguments.size()); @@ -239,7 +239,7 @@ static void DumpOutputsAndSaveSnapshot(const Backend* backend, }); } -StatusOr LocalExecutable::RunAsync( +absl::StatusOr LocalExecutable::RunAsync( const absl::Span arguments, ExecutableRunOptions run_options) { std::vector argument_shapes; @@ -279,7 +279,7 @@ static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer( return result; } -StatusOr LocalExecutable::RunAsync( +absl::StatusOr LocalExecutable::RunAsync( absl::Span argument_host_shapes, std::vector arguments, ExecutableRunOptions run_options) { if (argument_host_shapes.size() != arguments.size()) { @@ -321,7 +321,7 @@ StatusOr LocalExecutable::RunAsync( return std::move(outputs); } -StatusOr LocalExecutable::RunAsync( +absl::StatusOr LocalExecutable::RunAsync( std::vector arguments, ExecutableRunOptions run_options) { std::vector argument_shapes; argument_shapes.reserve(arguments.size()); @@ -355,7 +355,7 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -static StatusOr UpdateBuildOptions( +static absl::StatusOr UpdateBuildOptions( const ExecutableBuildOptions& options, int default_device_ordinal) { ExecutableBuildOptions updated_options = options; if (options.device_ordinal() == -1) { @@ -383,10 +383,10 @@ static StatusOr UpdateBuildOptions( return updated_options; } -StatusOr>> LocalClient::Compile( - const XlaComputation& computation, - const absl::Span argument_layouts, - const ExecutableBuildOptions& options) { +absl::StatusOr>> +LocalClient::Compile(const XlaComputation& computation, + const absl::Span argument_layouts, + const ExecutableBuildOptions& options) { TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, UpdateBuildOptions(options, default_device_ordinal())); TF_ASSIGN_OR_RETURN(std::vector> executables, @@ -405,7 +405,7 @@ StatusOr>> LocalClient::Compile( return std::move(local_executables); } -StatusOr>> +absl::StatusOr>> LocalClient::CompileAheadOfTime( const XlaComputation& computation, const absl::Span argument_layouts, @@ -420,7 +420,7 @@ LocalClient::CompileAheadOfTime( return std::move(aot_results); } -StatusOr> LocalClient::Load( +absl::StatusOr> LocalClient::Load( const std::string& serialized_aot_result, const ExecutableBuildOptions& options) { TF_ASSIGN_OR_RETURN(ExecutableBuildOptions updated_options, @@ -442,7 +442,7 @@ StatusOr> LocalClient::Load( updated_options); } -StatusOr LocalClient::LiteralToShapedBuffer( +absl::StatusOr LocalClient::LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, se::DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { @@ -458,7 +458,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr LocalClient::ShapedBufferToLiteral( +absl::StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -466,7 +466,7 @@ StatusOr LocalClient::ShapedBufferToLiteral( shaped_buffer); } -StatusOr LocalClient::GlobalDataToShapedBuffer( +absl::StatusOr LocalClient::GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number) { return local_service_->GlobalDataToShapedBuffer(data, replica_number); } @@ -487,11 +487,12 @@ Status LocalClient::TransferFromOutfeedLocal(int device_ordinal, literal); } -StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { +absl::StatusOr LocalClient::ReplicaNumberToDeviceOrdinal( + int replica_number) { return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } -StatusOr LocalClient::TransferToLocalServer( +absl::StatusOr LocalClient::TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal) { const ::xla::Shape& shape = literal.shape(); diff --git a/xla/client/local_client.h b/xla/client/local_client.h index f14b98581a9fb..d6e382192cab8 100644 --- a/xla/client/local_client.h +++ b/xla/client/local_client.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,25 +49,25 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. - StatusOr Run( + absl::StatusOr Run( absl::Span arguments, ExecutableRunOptions run_options); // Similar to Run(), but allows for donating argument buffers to the // executable. - StatusOr Run(std::vector arguments, - ExecutableRunOptions run_options); + absl::StatusOr Run(std::vector arguments, + ExecutableRunOptions run_options); // Similar to Run(), but need not block the host waiting for the computation // to complete before returning. - StatusOr RunAsync( + absl::StatusOr RunAsync( absl::Span arguments, ExecutableRunOptions run_options); // Similar to RunAsync(), but allows for donating argument buffers to the // executable. - StatusOr RunAsync(std::vector arguments, - ExecutableRunOptions run_options); + absl::StatusOr RunAsync( + std::vector arguments, ExecutableRunOptions run_options); // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -76,7 +76,7 @@ class LocalExecutable { Executable* executable() const { return executable_.get(); } private: - StatusOr RunAsync( + absl::StatusOr RunAsync( absl::Span argument_host_shapes, std::vector arguments, ExecutableRunOptions run_options); @@ -89,11 +89,12 @@ class LocalExecutable { const Backend& backend); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); + absl::StatusOr LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer); - StatusOr> RunHelper( - absl::Span argument_shapes, - ExecutableRunOptions run_options); + absl::StatusOr> + RunHelper(absl::Span argument_shapes, + ExecutableRunOptions run_options); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -142,7 +143,7 @@ class LocalClient : public Client { // // The given ExecutableBuildOptions overrides any values from XLA_FLAGS // environment variable. - StatusOr>> Compile( + absl::StatusOr>> Compile( const XlaComputation& computation, absl::Span argument_layouts, const ExecutableBuildOptions& options); @@ -150,14 +151,14 @@ class LocalClient : public Client { // Same as Compile() above, but return AotCompilationResult objects (instead // of LocalExecutable objects), which can be persisted to later load // LocalExecutable(s) using the Load() method below. - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(const XlaComputation& computation, absl::Span argument_layouts, const ExecutableBuildOptions& options); // Return a LocalExecutable object loaded from a serialized // AotCompilationResult. - StatusOr> Load( + absl::StatusOr> Load( const std::string& serialized_aot_result, const ExecutableBuildOptions& options); @@ -165,21 +166,22 @@ class LocalClient : public Client { // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the // device is used. - StatusOr LiteralToShapedBuffer( + absl::StatusOr LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. - StatusOr TransferToLocalServer( + absl::StatusOr TransferToLocalServer( const ::xla::BorrowingLiteral& literal, int device_ordinal); // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. - StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); + absl::StatusOr ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. - StatusOr GlobalDataToShapedBuffer( + absl::StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); // Transfer the given literal to the infeed queue of the given device. @@ -201,7 +203,7 @@ class LocalClient : public Client { // This returns an error if there is not a one-to-one correspondence of // replicas to device ordinals, but is useful as a short term mechanism for // the "easy" case where a single replica is a single device. - StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + absl::StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); // Returns the platform that the underlying service targets. se::Platform* platform() const; diff --git a/xla/client/padding.cc b/xla/client/padding.cc index ffb8f0b6552d0..f36c1cd18a505 100644 --- a/xla/client/padding.cc +++ b/xla/client/padding.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/padding.h b/xla/client/padding.h index 38a55c3a5816d..5c4a34b0d0922 100644 --- a/xla/client/padding.h +++ b/xla/client/padding.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/padding_test.cc b/xla/client/padding_test.cc index 7d39be74d1dda..0d183d0e16ede 100644 --- a/xla/client/padding_test.cc +++ b/xla/client/padding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/sharding_builder.cc b/xla/client/sharding_builder.cc index 32fa0c68e1124..e2324d68f92db 100644 --- a/xla/client/sharding_builder.cc +++ b/xla/client/sharding_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/sharding_builder.h b/xla/client/sharding_builder.h index b4f2424458df5..98d6512d59c28 100644 --- a/xla/client/sharding_builder.h +++ b/xla/client/sharding_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/client/value_inference.cc b/xla/client/value_inference.cc index 2c48fbc67e48c..b4bb8af37c4ea 100644 --- a/xla/client/value_inference.cc +++ b/xla/client/value_inference.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -148,7 +148,7 @@ struct HloProtoEvaluator { return *this; } - StatusOr Evaluate() { + absl::StatusOr Evaluate() { // Evaluate the instruction by swapping it's operands with constant // instructions with given literals. HloComputation::Builder builder("EmptyComputation"); @@ -286,11 +286,11 @@ struct PostorderDFSDep { // This function represents the logic to visit a node once its dependencies // (operands) are all resolved. -using Visit = std::function(absl::Span)>; +using Visit = std::function(absl::Span)>; // Convenient specializations of Visit function for different operands. -using Visit0D = std::function()>; -using Visit1D = std::function(Literal)>; -using Visit2D = std::function(Literal, Literal)>; +using Visit0D = std::function()>; +using Visit1D = std::function(Literal)>; +using Visit2D = std::function(Literal, Literal)>; // A postorder dfs node can be visited once its dependency requests are all // fulfilled. @@ -332,7 +332,7 @@ struct [[nodiscard]] PostorderDFSNode { // Convert an interger handle to HloInstructionProto. using HandleToInstruction = - std::function(int64_t)>; + std::function(int64_t)>; using HandleToComputation = std::function; struct PostorderDFSVisitor { @@ -343,20 +343,20 @@ struct PostorderDFSVisitor { handle_to_instruction(handle_to_instruction), handle_to_computation(handle_to_computation) {} - StatusOr AnalyzeUpperBound(int64_t handle, - InferenceContext context); - StatusOr AnalyzeLowerBound(int64_t handle, - InferenceContext context); - StatusOr AnalyzeIsDynamic(int64_t handle, - PostorderDFSNodeType type, - InferenceContext context); - StatusOr AnalyzeConstant(int64_t handle, - InferenceContext context); - StatusOr AnalyzeConstantValueFallback( + absl::StatusOr AnalyzeUpperBound(int64_t handle, + InferenceContext context); + absl::StatusOr AnalyzeLowerBound(int64_t handle, + InferenceContext context); + absl::StatusOr AnalyzeIsDynamic(int64_t handle, + PostorderDFSNodeType type, + InferenceContext context); + absl::StatusOr AnalyzeConstant(int64_t handle, + InferenceContext context); + absl::StatusOr AnalyzeConstantValueFallback( int64_t handle, PostorderDFSNodeType type, InferenceContext context); - StatusOr PostOrderDFSVisit(int64_t handle, - PostorderDFSNodeType type); + absl::StatusOr PostOrderDFSVisit(int64_t handle, + PostorderDFSNodeType type); // Returns true if a value represented by `handle` is an integeral type or // a floating pointer type that just got converted from an integral type. @@ -469,8 +469,10 @@ PostorderDFSNode CreateAllDynamicResult(const Shape& shape, } // namespace // Analyze a tensor's constant value, upper-bound value or lower-bound value. -StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( - int64_t handle, PostorderDFSNodeType type, InferenceContext context) { +absl::StatusOr +PostorderDFSVisitor::AnalyzeConstantValueFallback(int64_t handle, + PostorderDFSNodeType type, + InferenceContext context) { TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, handle_to_instruction(handle)); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); @@ -534,7 +536,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( call_context.caller_operand_handles.push_back(call_proto->operand_ids(0)); node.AddDependency(called_root, PostorderDFSNodeType::kConstantValue, call_context, "callee's root instruction"); - return node.AddVisit([](Literal operand) -> StatusOr { + return node.AddVisit([](Literal operand) -> absl::StatusOr { // Forward result of callee's root to caller. return std::move(operand); }); @@ -565,7 +567,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( branch_context); } return node.AddVisit( - [](absl::Span operands) -> StatusOr { + [](absl::Span operands) -> absl::StatusOr { int64_t pred_is_dynamic = operands[1].Get({}); if (pred_is_dynamic) { // If predicate is dynamic, return the value of the first branch @@ -606,7 +608,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( handle_to_computation(root->called_computation_ids(0)); return result.AddVisit( [root, computation_proto, context, - this](absl::Span operands) -> StatusOr { + this](absl::Span operands) -> absl::StatusOr { TF_ASSIGN_OR_RETURN( auto computation, HloComputation::CreateFromProto(*computation_proto, {})); @@ -638,7 +640,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( } } -StatusOr PostorderDFSVisitor::AnalyzeUpperBound( +absl::StatusOr PostorderDFSVisitor::AnalyzeUpperBound( int64_t handle, InferenceContext context) { TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, handle_to_instruction(handle)); @@ -657,7 +659,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( const HloInstructionProto* operand_proto = handle_to_instruction(operand_handle).value(); return PostorderDFSNode().AddVisit( - [operand_proto, dimension]() -> StatusOr { + [operand_proto, dimension]() -> absl::StatusOr { return LiteralUtil::CreateR0( operand_proto->shape().dimensions(dimension)); }); @@ -671,7 +673,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( .AddDependency(root->operand_ids(0), PostorderDFSNodeType::kConstantUpperBound, context) .AddVisit([this](Literal lower_bound, - Literal upper_bound) -> StatusOr { + Literal upper_bound) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto lower_bound_abs, evaluator.EvaluateElementwiseUnaryOp( HloOpcode::kAbs, lower_bound)); @@ -701,7 +703,8 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( } return dfs.AddVisit( - [root, context](absl::Span operands) -> StatusOr { + [root, + context](absl::Span operands) -> absl::StatusOr { std::vector results; results.reserve(operands.size()); // Conservatively set each element of the tensor to the max value. @@ -724,7 +727,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( return PostorderDFSNode() .AddDependency(root->operand_ids(0), PostorderDFSNodeType::kConstantLowerBound, context) - .AddVisit([this](Literal lower_bound) -> StatusOr { + .AddVisit([this](Literal lower_bound) -> absl::StatusOr { return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate, lower_bound); }); @@ -739,7 +742,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( PostorderDFSNodeType::kConstantLowerBound, context) .AddVisit([root, opcode, this]( Literal upper_bound, - Literal lower_bound) -> StatusOr { + Literal lower_bound) -> absl::StatusOr { if (opcode == HloOpcode::kDivide && this->IsValueEffectiveInteger(root->operand_ids(1))) { // Because in many cases the lower bound of a value is @@ -771,7 +774,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( } case HloOpcode::kCustomCall: { if (root->custom_call_target() == "SetBound") { - return PostorderDFSNode().AddVisit([root]() -> StatusOr { + return PostorderDFSNode().AddVisit([root]() -> absl::StatusOr { if (root->literal().shape().element_type() == TUPLE) { // First literal of SetBound contains bounds, second literal // contains dynamism indicators. @@ -808,7 +811,7 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( } } -StatusOr PostorderDFSVisitor::AnalyzeLowerBound( +absl::StatusOr PostorderDFSVisitor::AnalyzeLowerBound( int64_t handle, InferenceContext context) { TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, handle_to_instruction(handle)); @@ -826,7 +829,7 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, handle_to_instruction(operand_handle)); return PostorderDFSNode().AddVisit( - [dimension, operand_proto]() -> StatusOr { + [dimension, operand_proto]() -> absl::StatusOr { if (operand_proto->shape().is_dynamic_dimension(dimension)) { return LiteralUtil::CreateR0(0); } else { @@ -844,7 +847,7 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( .AddDependency(root->operand_ids(0), PostorderDFSNodeType::kConstantUpperBound, context) .AddVisit([this](Literal lower_bound, - Literal upper_bound) -> StatusOr { + Literal upper_bound) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto lower_bound_abs, evaluator.EvaluateElementwiseUnaryOp( HloOpcode::kAbs, lower_bound)); @@ -860,7 +863,7 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( return PostorderDFSNode() .AddDependency(root->operand_ids(0), PostorderDFSNodeType::kConstantUpperBound, context) - .AddVisit([this](Literal upper_bound) -> StatusOr { + .AddVisit([this](Literal upper_bound) -> absl::StatusOr { return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate, upper_bound); }); @@ -874,7 +877,8 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( .AddDependency(root->operand_ids(1), PostorderDFSNodeType::kConstantUpperBound, context) .AddVisit( - [root, this](absl::Span operands) -> StatusOr { + [root, + this](absl::Span operands) -> absl::StatusOr { return std::make_unique(evaluator, *root) ->WithOperands(operands) .Evaluate(); @@ -898,7 +902,7 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( } } -StatusOr PostorderDFSVisitor::AnalyzeConstant( +absl::StatusOr PostorderDFSVisitor::AnalyzeConstant( int64_t handle, InferenceContext context) { TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, handle_to_instruction(handle)); @@ -916,7 +920,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, handle_to_instruction(operand_handle)); return PostorderDFSNode().AddVisit( - [operand_proto, dimension, root]() -> StatusOr { + [operand_proto, dimension, root]() -> absl::StatusOr { if (operand_proto->shape().is_dynamic_dimension(dimension)) { // The value is dynamic, we return garbage data here and mask them // out later. @@ -939,7 +943,8 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( context); } return result.AddVisit( - [root, this](absl::Span operands) -> StatusOr { + [root, + this](absl::Span operands) -> absl::StatusOr { return std::make_unique(evaluator, *root) ->WithOperands(operands) .Evaluate(); @@ -952,8 +957,9 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( return PostorderDFSNode() .AddDependency(root->operand_ids(0), PostorderDFSNodeType::kConstantValue, context) - .AddVisit( - [](Literal operand) -> StatusOr { return operand; }); + .AddVisit([](Literal operand) -> absl::StatusOr { + return operand; + }); } else if (root->custom_call_target() == "Sharding") { return PostorderDFSNode() .AddDependency(root->operand_ids(0), @@ -981,7 +987,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( handle_to_computation(root->called_computation_ids(0)); return result.AddVisit( [root, context, computation_proto, - this](absl::Span operands) -> StatusOr { + this](absl::Span operands) -> absl::StatusOr { TF_ASSIGN_OR_RETURN( auto computation, HloComputation::CreateFromProto(*computation_proto, {})); @@ -998,7 +1004,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( } } -StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( +absl::StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( int64_t handle, PostorderDFSNodeType type, InferenceContext context) { TF_RETURN_IF_ERROR(handle_to_instruction(handle).status()); // Invariant check. @@ -1028,7 +1034,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, handle_to_instruction(operand_handle)); return PostorderDFSNode().AddVisit( - [operand_proto, dimension, type]() -> StatusOr { + [operand_proto, dimension, type]() -> absl::StatusOr { if (type == PostorderDFSNodeType::kBoundIsDynamic) { // The bound of dynamic dimension is not dynamic. return LiteralUtil::CreateR0(false); @@ -1048,7 +1054,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( } return dfs.AddVisit([root, context, type](absl::Span operands) - -> StatusOr { + -> absl::StatusOr { bool all_operands_values_static = true; for (int64_t i = 0; i < operands.size(); ++i) { all_operands_values_static &= operands[i].IsAll(0); @@ -1113,6 +1119,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1204,10 +1211,11 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( call_proto->operand_ids(0)); node.AddDependency(call_root, PostorderDFSNodeType::kValueIsDynamic, branch_context, "callee's root instruction"); - return node.AddVisit([context](Literal operand) -> StatusOr { - // Forward result of callee's root to caller. - return operand; - }); + return node.AddVisit( + [context](Literal operand) -> absl::StatusOr { + // Forward result of callee's root to caller. + return operand; + }); } case HloOpcode::kConditional: { auto node = PostorderDFSNode(); @@ -1245,7 +1253,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( // 2*i + 1: Branch value is dynamic. return node.AddVisit([root, branch_size, context](absl::Span operands) - -> StatusOr { + -> absl::StatusOr { int64_t pred_is_dynamic = operands[1].Get({}); auto result = CreatePredLiteral( true, @@ -1385,7 +1393,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( .AddDependency(root->operand_ids(1), type, context) // rhs dependency. .AddDependency(root->operand_ids(2), type, context) - .AddVisit([root](absl::Span operands) -> StatusOr { + .AddVisit([root](absl::Span operands) + -> absl::StatusOr { OptionalLiteral optional_selector_literal(std::move(operands[0]), std::move(operands[1])); Literal lhs = std::move(operands[2]); @@ -1422,7 +1431,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( .AddDependency(root->operand_ids(1), PostorderDFSNodeType::kValueIsDynamic, context) .AddVisit( - [root, this](absl::Span operands) -> StatusOr { + [root, + this](absl::Span operands) -> absl::StatusOr { OptionalLiteral optional_selector_literal( std::move(operands[1]), std::move(operands[2])); @@ -1443,7 +1453,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( } case HloOpcode::kCustomCall: { if (root->custom_call_target() == "SetBound") { - return PostorderDFSNode().AddVisit([type, root]() -> StatusOr { + return PostorderDFSNode().AddVisit([type, + root]() -> absl::StatusOr { if (type == PostorderDFSNodeType::kBoundIsDynamic) { return CreatePredLiteral(false, Shape(root->shape())); } else { @@ -1475,8 +1486,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kWhile: { - return PostorderDFSNode().AddVisit([root, - context]() -> StatusOr { + return PostorderDFSNode().AddVisit([root, context]() + -> absl::StatusOr { return CreatePredLiteral( true, ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)); @@ -1484,8 +1495,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( break; } default: - return PostorderDFSNode().AddVisit([root, - context]() -> StatusOr { + return PostorderDFSNode().AddVisit([root, context]() + -> absl::StatusOr { return CreatePredLiteral( true, ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index)); @@ -1493,7 +1504,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( } } -StatusOr PostorderDFSVisitor::PostOrderDFSVisit( +absl::StatusOr PostorderDFSVisitor::PostOrderDFSVisit( int64_t handle, PostorderDFSNodeType type) { enum VisitState { kUnvisited = 0, @@ -1608,7 +1619,7 @@ StatusOr PostorderDFSVisitor::PostOrderDFSVisit( return evaluated[root.GetCacheKey()].Clone(); } -StatusOr ValueInference::AnalyzeIsDynamic(XlaOp op) { +absl::StatusOr ValueInference::AnalyzeIsDynamic(XlaOp op) { PostorderDFSVisitor visitor( evaluator_, [&](int64_t handle) { @@ -1621,7 +1632,8 @@ StatusOr ValueInference::AnalyzeIsDynamic(XlaOp op) { return result; } -StatusOr> ValueInference::CseOpHandle(int64_t handle) { +absl::StatusOr> ValueInference::CseOpHandle( + int64_t handle) { TF_ASSIGN_OR_RETURN(auto inst, builder_->LookUpInstructionByHandle(handle)); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode())); // For now, only handle kGetDimensionSize as that's the most duplicated one. @@ -1652,7 +1664,7 @@ StatusOr> ValueInference::CseOpHandle(int64_t handle) { return {std::nullopt}; } -StatusOr ValueInference::SimplifyOp(int64_t handle) { +absl::StatusOr ValueInference::SimplifyOp(int64_t handle) { TF_ASSIGN_OR_RETURN(auto cse_handle, CseOpHandle(handle)); if (cse_handle) { // Use the CSE'd handle instead. @@ -1768,7 +1780,7 @@ StatusOr ValueInference::SimplifyOp(int64_t handle) { } } -StatusOr ValueInference::AnalyzeConstant( +absl::StatusOr ValueInference::AnalyzeConstant( XlaOp op, ValueInferenceMode mode) { TF_RETURN_IF_ERROR(builder_->LookUpInstructionByHandle(op.handle()).status()); PostorderDFSVisitor visitor( diff --git a/xla/client/value_inference.h b/xla/client/value_inference.h index 47f9f7129ed9f..6f1685f1a42e0 100644 --- a/xla/client/value_inference.h +++ b/xla/client/value_inference.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,10 +85,11 @@ class ValueInference { explicit ValueInference(XlaBuilder* builder) : builder_(builder) { CHECK(builder_); } - StatusOr AnalyzeIsDynamic(XlaOp op); + absl::StatusOr AnalyzeIsDynamic(XlaOp op); // Returns an OptionalLiteral. Each individual value of the literal is // the concrete constant value if it can be inferred, otherwise a nullopt. - StatusOr AnalyzeConstant(XlaOp op, ValueInferenceMode mode); + absl::StatusOr AnalyzeConstant(XlaOp op, + ValueInferenceMode mode); // Returns underlying xla builder. XlaBuilder* builder() { return builder_; } @@ -97,11 +98,11 @@ class ValueInference { // Given an op handle, returns a simplified version of the handle inside a // int64_t Literal. If the a -1 value for the handle means invalid // simplification and the result shouldn't be used. - StatusOr SimplifyOp(int64_t handle); + absl::StatusOr SimplifyOp(int64_t handle); // Perform CSE on a given handle, and return an equivalent handle if seen // before. Otherwise, returns nullopt. - StatusOr> CseOpHandle(int64_t handle); + absl::StatusOr> CseOpHandle(int64_t handle); XlaBuilder* builder_; HloEvaluator evaluator_; // A map from instruction_hash to handle that helps perform CSE. diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc index ccb353f052989..246ab2c384ba8 100644 --- a/xla/client/xla_builder.cc +++ b/xla/client/xla_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -118,7 +118,7 @@ namespace internal { XlaOp XlaBuilderFriend::BuildAddDependency(XlaBuilder* builder, XlaOp operand, XlaOp token, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kAddDependency, @@ -131,7 +131,7 @@ XlaOp XlaBuilderFriend::BuildFusion( absl::string_view fusion_kind, const XlaComputation& fused_computation, absl::Span>> output_operand_aliasing) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; instr.set_fusion_kind(std::string(fusion_kind)); if (!output_operand_aliasing.empty()) { @@ -160,20 +160,11 @@ std::pair XlaBuilderFriend::BuildAsyncStart( XlaBuilder* builder, absl::Span operands, std::string execution_thread, const XlaComputation& called_computation, const Shape& shape) { - return BuildAsyncStart(builder, operands, execution_thread, /*group_id=*/-1, - called_computation, shape); -} - -std::pair XlaBuilderFriend::BuildAsyncStart( - XlaBuilder* builder, absl::Span operands, - std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, const Shape& shape) { int64_t called_computation_id; - auto start_op = builder->ReportErrorOrReturn([&]() -> StatusOr { + auto start_op = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); - instr.set_async_group_id(group_id); builder->AddCalledComputation(called_computation, &instr); called_computation_id = instr.called_computation_ids()[0]; return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncStart, @@ -187,18 +178,10 @@ XlaOp XlaBuilderFriend::BuildAsyncUpdate(XlaBuilder* builder, std::string execution_thread, int64_t called_computation, const Shape& shape) { - return BuildAsyncUpdate(builder, operand, execution_thread, /*group_id=*/-1, - called_computation, shape); -} - -XlaOp XlaBuilderFriend::BuildAsyncUpdate( - XlaBuilder* builder, const XlaOp operand, std::string execution_thread, - int64_t group_id, int64_t called_computation, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); - instr.set_async_group_id(group_id); instr.add_called_computation_ids(called_computation); return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncUpdate, {operand}); @@ -209,20 +192,10 @@ XlaOp XlaBuilderFriend::BuildAsyncDone(XlaBuilder* builder, const XlaOp operand, std::string execution_thread, int64_t called_computation, const Shape& shape) { - return BuildAsyncDone(builder, operand, execution_thread, /*group_id=*/-1, - called_computation, shape); -} - -XlaOp XlaBuilderFriend::BuildAsyncDone(XlaBuilder* builder, const XlaOp operand, - std::string execution_thread, - int64_t group_id, - int64_t called_computation, - const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); - instr.set_async_group_id(group_id); instr.add_called_computation_ids(called_computation); return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncDone, {operand}); @@ -243,7 +216,7 @@ XlaOp XlaBuilderFriend::BuildAllGatherStart( XlaOp XlaBuilderFriend::BuildAllGatherDone(XlaBuilder* builder, const XlaOp operand, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kAllGatherDone, @@ -265,7 +238,7 @@ XlaOp XlaBuilderFriend::BuildAllReduceStart( XlaOp XlaBuilderFriend::BuildAllReduceDone(XlaBuilder* builder, const XlaOp operand, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kAllReduceDone, @@ -276,7 +249,7 @@ XlaOp XlaBuilderFriend::BuildAllReduceDone(XlaBuilder* builder, XlaOp XlaBuilderFriend::BuildCopyStart( XlaBuilder* builder, const XlaOp operand, std::optional cross_program_prefetch_index) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; if (cross_program_prefetch_index) { instr.set_cross_program_prefetch_index(*cross_program_prefetch_index); @@ -296,7 +269,7 @@ XlaOp XlaBuilderFriend::BuildCopyStart( XlaOp XlaBuilderFriend::BuildCopyDone(XlaBuilder* builder, const XlaOp operand, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kCopyDone, @@ -315,7 +288,7 @@ XlaOp XlaBuilderFriend::BuildCollectivePermuteStart( XlaOp XlaBuilderFriend::BuildCollectivePermuteDone(XlaBuilder* builder, const XlaOp operand, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction( @@ -325,7 +298,7 @@ XlaOp XlaBuilderFriend::BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast, @@ -336,7 +309,7 @@ XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, XlaOp XlaBuilderFriend::BuildDomain(XlaBuilder* builder, XlaOp operand, const OpSharding entry, const OpSharding exit, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_domain_entry_sharding() = entry; *instr.mutable_domain_exit_sharding() = exit; @@ -348,7 +321,7 @@ XlaOp XlaBuilderFriend::BuildDomain(XlaBuilder* builder, XlaOp operand, XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return builder->AddInstruction(std::move(instr), HloOpcode::kPartitionId); @@ -358,7 +331,7 @@ XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder, XlaOp XlaBuilderFriend::BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, const ChannelHandle& handle, bool is_host_transfer) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(operand)); // Send instruction produces a tuple of {aliased operand, U32 context, @@ -377,7 +350,7 @@ XlaOp XlaBuilderFriend::BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp XlaBuilderFriend::BuildSendDone(XlaBuilder* builder, XlaOp operand, const ChannelHandle& handle, bool is_host_transfer) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto send_done_instr; *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); @@ -391,7 +364,7 @@ XlaOp XlaBuilderFriend::BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, const ChannelHandle& handle, bool is_host_transfer) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; @@ -410,7 +383,7 @@ XlaOp XlaBuilderFriend::BuildRecvDone(XlaBuilder* builder, XlaOp token, const Shape& shape, const ChannelHandle& handle, bool is_host_transfer) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) @@ -426,7 +399,7 @@ XlaOp XlaBuilderFriend::BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; instr.set_delta(delta); *instr.mutable_shape() = shape.ToProto(); @@ -462,7 +435,7 @@ XlaOp operator<<(XlaOp x, XlaOp y) { return ShiftLeft(x, y); } XlaOp operator>>(XlaOp x, XlaOp y) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(x)); if (!ShapeUtil::ElementIsIntegral(*shape)) { return InvalidArgument( @@ -477,7 +450,7 @@ XlaOp operator>>(XlaOp x, XlaOp y) { }); } -StatusOr XlaBuilder::GetShapePtr(XlaOp op) const { +absl::StatusOr XlaBuilder::GetShapePtr(XlaOp op) const { TF_RETURN_IF_ERROR(first_error_); TF_RETURN_IF_ERROR(CheckOpBuilder(op)); auto it = handle_to_index_.find(op.handle()); @@ -487,12 +460,12 @@ StatusOr XlaBuilder::GetShapePtr(XlaOp op) const { return instruction_shapes_.at(it->second).get(); } -StatusOr XlaBuilder::GetShape(XlaOp op) const { +absl::StatusOr XlaBuilder::GetShape(XlaOp op) const { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(op)); return *shape; } -StatusOr> XlaBuilder::GetOperandShapes( +absl::StatusOr> XlaBuilder::GetOperandShapes( absl::Span operands) const { std::vector operand_shapes; operand_shapes.reserve(operands.size()); @@ -559,7 +532,7 @@ XlaOp XlaBuilder::ReportError(const Status& error) { return XlaOp(this); } -XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr& op) { +XlaOp XlaBuilder::ReportErrorOrReturn(const absl::StatusOr& op) { if (!first_error_.ok()) { return XlaOp(this); } @@ -570,11 +543,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr& op) { } XlaOp XlaBuilder::ReportErrorOrReturn( - absl::FunctionRef()> op_creator) { + absl::FunctionRef()> op_creator) { return ReportErrorOrReturn(op_creator()); } -StatusOr XlaBuilder::GetProgramShape(int64_t root_id) const { +absl::StatusOr XlaBuilder::GetProgramShape( + int64_t root_id) const { TF_RETURN_IF_ERROR(first_error_); TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, LookUpInstructionByHandle(root_id)); @@ -605,12 +579,12 @@ StatusOr XlaBuilder::GetProgramShape(int64_t root_id) const { return program_shape; } -StatusOr XlaBuilder::GetProgramShape() const { +absl::StatusOr XlaBuilder::GetProgramShape() const { TF_RET_CHECK(!instructions_.empty()); return GetProgramShape(instructions_.back().id()); } -StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { +absl::StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } @@ -729,21 +703,22 @@ Status XlaBuilder::GetCurrentStatus() const { return OkStatus(); } -StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { +absl::StatusOr XlaBuilder::Build( + bool remove_dynamic_dimensions) { TF_RETURN_IF_ERROR(GetCurrentStatus()); return Build(instructions_.back().id(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(XlaOp root, - bool remove_dynamic_dimensions) { +absl::StatusOr XlaBuilder::Build( + XlaOp root, bool remove_dynamic_dimensions) { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } return Build(root.handle(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(int64_t root_id, - bool remove_dynamic_dimensions) { +absl::StatusOr XlaBuilder::Build( + int64_t root_id, bool remove_dynamic_dimensions) { TF_RETURN_IF_ERROR(GetCurrentStatus()); // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove @@ -864,7 +839,70 @@ StatusOr XlaBuilder::Build(int64_t root_id, return OkStatus(); } -StatusOr XlaBuilder::InDimBroadcast( +XlaOp XlaBuilder::DynamicBroadcastInDim( + const XlaOp operand, const XlaOp output_dimensions, + absl::Span broadcast_dimensions, const Shape& output_shape) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN(const Shape* output_dimensions_shape, + GetShapePtr(output_dimensions)); + + if (!output_dimensions_shape->IsInteger()) { + return InvalidArgument("output_dimensions must be an integer type %s", + output_dimensions_shape->ToString()); + } + + if (output_dimensions_shape->rank() != 1) { + return InvalidArgument("output_dimensions must be rank 1 but got rank %d", + output_dimensions_shape->rank()); + } + + int64_t operand_rank = operand_shape->rank(); + int64_t result_rank = output_shape.rank(); + int64_t broadcast_dimensions_size = broadcast_dimensions.size(); + if (broadcast_dimensions_size != operand_rank) { + return InvalidArgument( + "broadcast_dimensions size (%d) does not match operand rank (%d)", + broadcast_dimensions_size, operand_rank); + } + + if (result_rank < operand_rank) { + return InvalidArgument("result rank (%d) is less than operand rank (%d)", + result_rank, operand_rank); + } + + for (int64_t i = 0; i != broadcast_dimensions_size; ++i) { + int64_t dim_index = broadcast_dimensions[i]; + if (dim_index < 0 || dim_index >= result_rank) { + return InvalidArgument( + "broadcast_dimensions contains invalid value %d for result with " + "rank %d", + dim_index, result_rank); + } + + int64_t dim_size = operand_shape->dimensions(i); + int64_t result_dim_size = output_shape.dimensions(dim_index); + + if (dim_size != 1 && dim_size != result_dim_size && + dim_size != Shape::kUnboundedSize) { + return InvalidArgument( + "size of operand dimension %d (%d) is not compatible with size of " + "result dimension %d (%d)", + i, dim_size, dim_index, result_dim_size); + } + } + + return xla::CustomCall( + operand.builder(), "mhlo.dynamic_broadcast_in_dim", + /*operands=*/{operand, output_dimensions}, + /*shape=*/output_shape, + /*opaque=*/ + absl::StrCat("{broadcast_dimensions=[", + absl::StrJoin(broadcast_dimensions, ","), "]}")); + }); +} + +absl::StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(first_error_); @@ -876,26 +914,28 @@ StatusOr XlaBuilder::InDimBroadcast( } TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_RET_CHECK(!shape.is_unbounded_dynamic()) + << "broadcast op result shapes must be static"; for (int64_t i = 0; i < shape.rank(); i++) { if (auto it = absl::c_find(broadcast_dimensions, i); it != broadcast_dimensions.end()) { // Broadcast dimensions are permitted to be dynamic iff the operand // dimension is dynamic. - TF_RET_CHECK(operand_shape->is_dynamic_dimension( + TF_RET_CHECK(operand_shape->is_bounded_dynamic_dimension( it - broadcast_dimensions.begin()) == - shape.is_dynamic_dimension(i)) + shape.is_bounded_dynamic_dimension(i)) << " i: " << i << ", shape: " << shape.ToString() << ", operand_shape: " << operand_shape->ToString(); } else { - // Non-broadcast dimensions must not be dynamic. - TF_RET_CHECK(!shape.is_dynamic_dimension(i)); + // Non-broadcast dimensions must be static. + TF_RET_CHECK(shape.is_static_dimension(i)); } } return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); } -StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, - XlaOp operand) { +absl::StatusOr XlaBuilder::AddBroadcastSequence( + const Shape& output_shape, XlaOp operand) { TF_RETURN_IF_ERROR(first_error_); TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -923,7 +963,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, operand_shape->is_dynamic_dimension(i)); } else { TF_RET_CHECK(operand_shape->dimensions(i) == 1 && - !operand_shape->is_dynamic_dimension(i)) + operand_shape->is_static_dimension(i)) << "An explicit broadcast sequence requires the broadcasted " "dimensions to be trivial; operand shape: " << *operand_shape << "; output_shape: " << output_shape; @@ -937,16 +977,22 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, reshaped_dynamic_dimensions); // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN( - XlaOp reshaped_operand, - ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1)); + // The added reshape reduces the rank of the tensor. Hence we cannot directly + // apply the broadcast's sharding on reshape. + XlaOp reshaped_operand; + { + XlaScopedShardingAssignment scoped_sharding(this, std::nullopt); + TF_ASSIGN_OR_RETURN( + reshaped_operand, + ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1)); + } // Broadcast 'reshape' up to the larger size. return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); } XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape)); @@ -954,58 +1000,169 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { }); } +namespace { + +// Broadcasts an origin XLA op to the rank of target_shape. +// Does not broadcast rank dimensions to match, only expands rank. +// Is identity function if origin rank matches target rank. +absl::StatusOr BroadcastToTargetRank( + XlaOp origin, const Shape& origin_shape, const Shape& target_shape, + absl::Span broadcast_dimensions) { + if (ShapeUtil::IsScalar(origin_shape)) { + return origin; + } + + const int64_t origin_rank = origin_shape.rank(); + const int64_t target_rank = target_shape.rank(); + + // Identity op if ranks match, should never be larger than target. + if (origin_rank >= target_rank) { + return origin; + } + + // Update target_size with origin sizes using broadcast_dimensions + absl::Span target_dimensions = target_shape.dimensions(); + std::vector target_size{target_dimensions.begin(), + target_dimensions.end()}; + for (int64_t origin_dim = 0; origin_dim < origin_rank; origin_dim++) { + int64_t target_dim = broadcast_dimensions[origin_dim]; + target_size[target_dim] = origin_shape.dimensions(origin_dim); + } + return xla::BroadcastInDim(origin, target_size, broadcast_dimensions); +} + +// Extract the `num_dims` counts of dimension sizes from the `op`. First, +// prepend `pad_count` of 1's reshaped to `tensor<1xi32>` to `op_dims`. If size +// is static, append them at `op_dims`. If size is dynamic, get the dimension +// size, reshape them to `tensor<1xi32>`, and append them at `op_dims`. +absl::StatusOr> ExtractDimensionSizesAndPadOnesToLeft( + XlaBuilder* builder, XlaOp op, size_t num_dims, int pad_count) { + TF_ASSIGN_OR_RETURN(const Shape* op_shape, builder->GetShapePtr(op)); + std::vector op_dims( + pad_count, xla::ConstantR1(builder, absl::Span({1}))); + for (size_t i = 0; i < num_dims; i++) { + op_dims.push_back( + op_shape->is_static_dimension(i) + ? ConstantR1(builder, + absl::Span( + {static_cast(op_shape->dimensions(i))})) + : xla::Reshape(xla::GetDimensionSize(op, i), {1})); + } + return op_dims; +} + +// Broadcast `scalar` to `output_shape` with all shapes static at runtime. If a +// dimension of `output_shape` is dynamic, get the dimension size of the dynamic +// dimension from `output` and reshape them to `tensor<1xi32>`. This is used as +// one of the inputs to DynamicBroadcastInDim. +absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( + XlaBuilder* builder, XlaOp scalar, XlaOp output, + const Shape& output_shape) { + TF_ASSIGN_OR_RETURN(const Shape* scalar_shape, builder->GetShapePtr(scalar)); + CHECK(ShapeUtil::IsScalar(*scalar_shape)); + + std::vector output_sizes(output_shape.rank()); + for (size_t i = 0; i < output_shape.rank(); i++) { + output_sizes[i] = + output_shape.is_static_dimension(i) + ? ConstantR1(builder, + absl::Span({static_cast( + output_shape.dimensions(i))})) + : xla::Reshape(xla::GetDimensionSize(output, i), {1}); + } + return xla::DynamicBroadcastInDim( + scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {}, + output_shape); +} + +// The shape of `operand` is broadcasted to the values in `output_dimensions` if +// the dimension size is degenerate (dimension size is 1). +absl::StatusOr DegenerateBroadcastWithUnbounded( + XlaBuilder* builder, XlaOp operand, XlaOp output_dimensions, + const Shape& output_shape) { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, + builder->GetShapePtr(operand)); + + std::vector broadcast_dimensions(operand_shape->rank()); + std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(), + output_shape.rank() - operand_shape->rank()); + + return xla::DynamicBroadcastInDim(operand, output_dimensions, + broadcast_dimensions, output_shape); +} + +// Helper struct to store the result of `BroadcastToOutputShapeWithUnbounded`. +struct UnboundedBroadcastResult { + XlaOp lhs; + XlaOp rhs; +}; + +// Broadcast `lhs` and `rhs` to `output_shape` with unbounded dimensions where +// `lhs` or `rhs` are possibly different ranks than `output_shape`. +absl::StatusOr BroadcastToOutputShapeWithUnbounded( + XlaBuilder* builder, XlaOp lhs, const Shape& lhs_shape, XlaOp rhs, + const Shape rhs_shape, const Shape& output_shape, + absl::Span broadcast_dimensions) { + const int64_t lhs_rank = lhs_shape.rank(); + const int64_t rhs_rank = rhs_shape.rank(); + const int64_t output_rank = output_shape.rank(); + + // If the rank of the op is less than the output rank, pad the dimension + // sizes of the op with 1's to match the output rank. + TF_ASSIGN_OR_RETURN(std::vector lhs_dims, + ExtractDimensionSizesAndPadOnesToLeft( + builder, lhs, lhs_rank, output_rank - lhs_rank)); + TF_ASSIGN_OR_RETURN(std::vector rhs_dims, + ExtractDimensionSizesAndPadOnesToLeft( + builder, rhs, rhs_rank, output_rank - rhs_rank)); + + // The output dimensions of the dynamic broadcast is the maximum of the input + // shapes. The `output_dimensions` refer to the runtime shape and should not + // contain any dynamic sizes at run time. + XlaOp output_dimensions = + Max(ConcatInDim(builder, lhs_dims, 0), ConcatInDim(builder, rhs_dims, 0)); + + // Broadcast `lhs` and `rhs` to `output_shape`. + TF_ASSIGN_OR_RETURN(XlaOp lhs_result, + DegenerateBroadcastWithUnbounded( + builder, lhs, output_dimensions, output_shape)); + TF_ASSIGN_OR_RETURN(XlaOp rhs_result, + DegenerateBroadcastWithUnbounded( + builder, rhs, output_dimensions, output_shape)); + return UnboundedBroadcastResult{lhs_result, rhs_result}; +} + +} // namespace + XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, std::optional direction, std::optional type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferBinaryOpShape( binop, *lhs_shape, *rhs_shape, broadcast_dimensions)); - const int64_t lhs_rank = lhs_shape->rank(); - const int64_t rhs_rank = rhs_shape->rank(); - XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; - if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { - const bool should_broadcast_lhs = lhs_rank < rhs_rank; - XlaOp from = should_broadcast_lhs ? lhs : rhs; - const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape; - - std::vector to_size; - std::vector to_size_is_dynamic; - const auto rank = shape.rank(); - to_size.reserve(rank); - to_size_is_dynamic.reserve(rank); - for (int i = 0; i < rank; i++) { - to_size.push_back(shape.dimensions(i)); - to_size_is_dynamic.push_back(false); + if (!lhs_shape->is_unbounded_dynamic() && + !rhs_shape->is_unbounded_dynamic()) { + if (lhs_shape->rank() < shape.rank()) { + TF_ASSIGN_OR_RETURN(updated_lhs, + BroadcastToTargetRank(lhs, *lhs_shape, shape, + broadcast_dimensions)); } - for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) { - int64_t to_dim = broadcast_dimensions[from_dim]; - to_size[to_dim] = from_shape.dimensions(from_dim); - to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); + if (rhs_shape->rank() < shape.rank()) { + TF_ASSIGN_OR_RETURN(updated_rhs, + BroadcastToTargetRank(rhs, *rhs_shape, shape, + broadcast_dimensions)); } - - const Shape& broadcasted_shape = ShapeUtil::MakeShape( - from_shape.element_type(), to_size, to_size_is_dynamic); - TF_ASSIGN_OR_RETURN( - XlaOp broadcasted_operand, - InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); - - updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; - updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; - } - - TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape, - GetShapePtr(updated_lhs)); - TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape, - GetShapePtr(updated_rhs)); - if (!updated_lhs_shape->is_unbounded_dynamic() && - !updated_rhs_shape->is_unbounded_dynamic()) { + TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape, + GetShapePtr(updated_lhs)); + TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape, + GetShapePtr(updated_rhs)); if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, updated_lhs)); @@ -1014,6 +1171,30 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, updated_rhs)); } + } else { + if (ShapeUtil::IsScalar(*lhs_shape) || ShapeUtil::IsScalar(*rhs_shape)) { + if (ShapeUtil::IsScalar(*lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + BroadcastScalarToOutputShapeWithUnbounded( + this, lhs, rhs, *rhs_shape)); + } + if (ShapeUtil::IsScalar(*rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + BroadcastScalarToOutputShapeWithUnbounded( + this, rhs, lhs, *lhs_shape)); + } + } else { + if (!ShapeUtil::SameDimensions(*lhs_shape, *rhs_shape)) { + Shape output_shape = shape; + output_shape.set_element_type(lhs_shape->element_type()); + TF_ASSIGN_OR_RETURN(UnboundedBroadcastResult broadcast_result, + BroadcastToOutputShapeWithUnbounded( + this, lhs, *lhs_shape, rhs, *rhs_shape, + output_shape, broadcast_dimensions)); + updated_lhs = broadcast_result.lhs; + updated_rhs = broadcast_result.rhs; + } + } } if (binop == HloOpcode::kCompare) { @@ -1039,24 +1220,26 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), binop, {lhs, rhs}); }); } -StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction) { +absl::StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, + XlaOp rhs, + ComparisonDirection direction) { TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs)); return Compare( shape, lhs, rhs, direction, Comparison::DefaultComparisonType(operand_shape.element_type())); } -StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction, - Comparison::Type type) { +absl::StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, + XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type) { HloInstructionProto instr; instr.set_comparison_direction(ComparisonDirectionToString(direction)); instr.set_comparison_type(ComparisonTypeToString(type)); @@ -1064,11 +1247,32 @@ StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs}); } +absl::StatusOr XlaBuilder::BroadcastScalarToOutputShape(XlaOp scalar, + XlaOp output) { + TF_ASSIGN_OR_RETURN(const Shape* scalar_shape, GetShapePtr(scalar)); + TF_ASSIGN_OR_RETURN(const Shape* output_shape, GetShapePtr(output)); + + XlaOp updated_output = scalar; + if (output_shape->is_unbounded_dynamic()) { + Shape output_shape_copy = *output_shape; + output_shape_copy.set_element_type(scalar_shape->element_type()); + TF_ASSIGN_OR_RETURN(updated_output, + BroadcastScalarToOutputShapeWithUnbounded( + this, scalar, output, output_shape_copy)); + return updated_output; + } + + TF_ASSIGN_OR_RETURN(updated_output, + AddBroadcastSequence(*output_shape, updated_output)); + return updated_output; +} + XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; + // The client API supports implicit broadcast for kSelect and kClamp, but // XLA does not support implicit broadcast. Make implicit broadcast explicit // and update the operands. @@ -1076,34 +1280,36 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs)); + TF_ASSIGN_OR_RETURN( + std::optional output_shape, + ShapeInference::InferScalarBroadcastShape( + absl::Span({*lhs_shape, *rhs_shape, *ehs_shape}))); - std::optional non_scalar_shape; - for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) { - if (shape->IsArray() && shape->rank() != 0) { - if (non_scalar_shape.has_value()) { - // TODO(jpienaar): The case where we need to compute the broadcasted - // shape by considering multiple of the shapes is not implemented. - // Consider reusing getBroadcastedType from mlir/Dialect/Traits.h. - TF_RET_CHECK(non_scalar_shape.value().dimensions() == - shape->dimensions()) - << "Unimplemented implicit broadcast."; - } else { - non_scalar_shape = ShapeUtil::MakeStaticShape(*shape); - } - } - } - if (non_scalar_shape.has_value()) { + // Scalar broadcast if mix of scalars and non-scalars + if (output_shape.has_value()) { if (ShapeUtil::IsScalar(*lhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(*non_scalar_shape, lhs)); + TF_ASSIGN_OR_RETURN( + updated_lhs, + BroadcastScalarToOutputShape( + /*scalar=*/lhs, + /*output=*/ + ShapeUtil::Equal(*output_shape, *rhs_shape) ? rhs : ehs)); } if (ShapeUtil::IsScalar(*rhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(*non_scalar_shape, rhs)); + TF_ASSIGN_OR_RETURN( + updated_rhs, + BroadcastScalarToOutputShape( + /*scalar=*/rhs, + /*output=*/ + ShapeUtil::Equal(*output_shape, *lhs_shape) ? lhs : ehs)); } if (ShapeUtil::IsScalar(*ehs_shape)) { - TF_ASSIGN_OR_RETURN(updated_ehs, - AddBroadcastSequence(*non_scalar_shape, ehs)); + TF_ASSIGN_OR_RETURN( + updated_ehs, + BroadcastScalarToOutputShape( + /*scalar=*/ehs, + /*output=*/ + ShapeUtil::Equal(*output_shape, *lhs_shape) ? lhs : rhs)); } } } @@ -1111,21 +1317,17 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs)); TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs)); - StatusOr status_or_shape = ShapeInference::InferTernaryOpShape( - triop, *lhs_shape, *rhs_shape, *ehs_shape); - if (!status_or_shape.status().ok()) { - return InvalidArgument( - "%s Input scalar shapes may have been changed to non-scalar shapes.", - status_or_shape.status().message()); - } + TF_ASSIGN_OR_RETURN(const Shape inferred_shape, + ShapeInference::InferTernaryOpShape( + triop, *lhs_shape, *rhs_shape, *ehs_shape)); - return AddOpWithShape(triop, status_or_shape.value(), + return AddOpWithShape(triop, inferred_shape, {updated_lhs, updated_rhs, updated_ehs}); }); } XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (literal.shape().IsArray() && literal.element_count() > 1 && literal.IsAllFirst()) { Literal scalar = LiteralUtil::GetFirstScalarLiteral(literal); @@ -1146,7 +1348,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { } XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (!shape.is_static()) { return InvalidArgument( "The output of iota must not have dynamic dimensions: %s", @@ -1165,7 +1367,7 @@ XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) { XlaOp XlaBuilder::Call(const XlaComputation& computation, absl::Span operands) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); @@ -1187,7 +1389,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, XlaOp XlaBuilder::Parameter( int64_t parameter_number, const Shape& shape, const std::string& name, const std::vector& replicated_at_leaf_buffers) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { return InvalidArgument("parameter %d already registered", @@ -1208,7 +1410,7 @@ XlaOp XlaBuilder::Parameter( XlaOp XlaBuilder::Broadcast(XlaOp operand, absl::Span broadcast_sizes) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, @@ -1231,15 +1433,18 @@ XlaOp XlaBuilder::Broadcast(XlaOp operand, } XlaOp XlaBuilder::BroadcastInDim( - XlaOp operand, const absl::Span out_dim_size, - const absl::Span broadcast_dimensions) { - return ReportErrorOrReturn([&]() -> StatusOr { + XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. TF_ASSIGN_OR_RETURN(auto output_shape, ShapeUtil::MakeValidatedShape( operand_shape->element_type(), out_dim_size)); + TF_RET_CHECK(!output_shape.is_unbounded_dynamic()) + << "BroadcastInDim output must shape be static or bounded dynamic " + << output_shape.ToString(); int64_t broadcast_rank = broadcast_dimensions.size(); if (operand_shape->rank() != broadcast_rank) { return InvalidArgument( @@ -1254,7 +1459,8 @@ XlaOp XlaBuilder::BroadcastInDim( broadcast_dimensions[i]); } output_shape.set_dynamic_dimension( - broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i)); + broadcast_dimensions[i], + operand_shape->is_bounded_dynamic_dimension(i)); } TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( @@ -1263,9 +1469,12 @@ XlaOp XlaBuilder::BroadcastInDim( std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); std::vector in_dim_dynamic(out_dim_size.size(), false); for (int i = 0; i < broadcast_rank; i++) { - in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i); + in_dim_size[broadcast_dimensions[i]] = + (operand_shape->is_unbounded_dynamic_dimension(i)) + ? out_dim_size[broadcast_dimensions[i]] + : operand_shape->dimensions(i); in_dim_dynamic[broadcast_dimensions[i]] = - operand_shape->is_dynamic_dimension(i); + operand_shape->is_bounded_dynamic_dimension(i); } const auto& in_dim_shape = ShapeUtil::MakeShape( operand_shape->element_type(), in_dim_size, in_dim_dynamic); @@ -1283,9 +1492,14 @@ XlaOp XlaBuilder::BroadcastInDim( }); } -StatusOr XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, - int64_t inferred_dimension) { +absl::StatusOr XlaBuilder::ReshapeInternal(const Shape& shape, + XlaOp operand, + int64_t inferred_dimension) { TF_RETURN_IF_ERROR(first_error_); + if (shape.is_unbounded_dynamic()) { + return InvalidArgument( + "Reshaping with unbounded result shape is not supported."); + } HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -1298,7 +1512,7 @@ StatusOr XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape( *operand_shape, start_indices, @@ -1307,7 +1521,7 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, }); } -StatusOr XlaBuilder::SliceInternal( +absl::StatusOr XlaBuilder::SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { @@ -1325,7 +1539,7 @@ StatusOr XlaBuilder::SliceInternal( XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); std::vector starts(shape->rank(), 0); std::vector limits(shape->dimensions().begin(), @@ -1341,7 +1555,7 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, @@ -1356,7 +1570,7 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, }); } -StatusOr XlaBuilder::DynamicSliceInternal( +absl::StatusOr XlaBuilder::DynamicSliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { HloInstructionProto instr; @@ -1373,7 +1587,7 @@ StatusOr XlaBuilder::DynamicSliceInternal( XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); std::vector start_indices_shape_ptrs; @@ -1389,7 +1603,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, }); } -StatusOr XlaBuilder::DynamicUpdateSliceInternal( +absl::StatusOr XlaBuilder::DynamicUpdateSliceInternal( const Shape& shape, XlaOp operand, XlaOp update, absl::Span start_indices) { HloInstructionProto instr; @@ -1403,7 +1617,7 @@ StatusOr XlaBuilder::DynamicUpdateSliceInternal( XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64_t dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1414,7 +1628,7 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span operands, }); } -StatusOr XlaBuilder::ConcatInDimInternal( +absl::StatusOr XlaBuilder::ConcatInDimInternal( const Shape& shape, absl::Span operands, int64_t dimension) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -1426,7 +1640,7 @@ StatusOr XlaBuilder::ConcatInDimInternal( XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape, GetShapePtr(padding_value)); @@ -1439,7 +1653,7 @@ XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value, XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, int64_t pad_lo, int64_t pad_hi) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); PaddingConfig padding_config = MakeNoPaddingConfig(shape->rank()); auto* dims = padding_config.mutable_dimensions(dimno); @@ -1449,9 +1663,9 @@ XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, }); } -StatusOr XlaBuilder::PadInternal(const Shape& shape, XlaOp operand, - XlaOp padding_value, - const PaddingConfig& padding_config) { +absl::StatusOr XlaBuilder::PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); *instr.mutable_padding_config() = padding_config; @@ -1462,7 +1676,7 @@ StatusOr XlaBuilder::PadInternal(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, int64_t inferred_dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( *operand_shape, dimensions, @@ -1476,7 +1690,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span new_sizes, int64_t inferred_dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); std::vector dimensions(shape->dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -1486,7 +1700,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span new_sizes, XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, int64_t inferred_dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { return ReshapeInternal(shape, operand, inferred_dimension); }); } @@ -1495,7 +1709,7 @@ XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector dim_size_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes, @@ -1523,7 +1737,7 @@ XlaOp XlaBuilder::DynamicReshape(XlaOp operand, XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span dimensions) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. @@ -1563,14 +1777,15 @@ XlaOp XlaBuilder::Collapse(XlaOp operand, } // Dummy pass-through computation returning it's parameter of shape `shape`. -static StatusOr PassthroughComputation(const Shape& shape) { +static absl::StatusOr PassthroughComputation( + const Shape& shape) { XlaBuilder builder("dummy"); XlaOp out = Parameter(&builder, 0, shape, "p"); return builder.Build(out); } XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true)); TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false)); TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple()); @@ -1587,7 +1802,7 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) { } XlaOp XlaBuilder::Tuple(absl::Span elements) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1599,15 +1814,15 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { }); } -StatusOr XlaBuilder::TupleInternal(const Shape& shape, - absl::Span elements) { +absl::StatusOr XlaBuilder::TupleInternal( + const Shape& shape, absl::Span elements) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); } XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data)); if (!tuple_shape->IsTuple()) { return InvalidArgument( @@ -1625,9 +1840,9 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) { }); } -StatusOr XlaBuilder::GetTupleElementInternal(const Shape& shape, - XlaOp tuple_data, - int64_t index) { +absl::StatusOr XlaBuilder::GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64_t index) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_tuple_index(index); @@ -1638,7 +1853,7 @@ StatusOr XlaBuilder::GetTupleElementInternal(const Shape& shape, XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); DotDimensionNumbers dimension_numbers; @@ -1654,7 +1869,7 @@ XlaOp XlaBuilder::DotGeneral( XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN( @@ -1666,7 +1881,7 @@ XlaOp XlaBuilder::DotGeneral( }); } -StatusOr XlaBuilder::DotGeneralInternal( +absl::StatusOr XlaBuilder::DotGeneralInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config) { @@ -1679,6 +1894,35 @@ StatusOr XlaBuilder::DotGeneralInternal( return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); } +XlaOp XlaBuilder::SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config, + std::optional preferred_element_type) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); + TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferDotOpShape( + *lhs_shape, *rhs_shape, dimension_numbers, + preferred_element_type, sparsity)); + std::vector operands{lhs, rhs}; + operands.insert(operands.end(), sparse_meta.begin(), sparse_meta.end()); + + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } + for (const SparsityDescriptor& descriptor : sparsity) { + *instr.add_dot_sparsity() = descriptor; + } + return AddInstruction(std::move(instr), HloOpcode::kDot, operands); + }); +} + Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { @@ -1697,8 +1941,9 @@ Status XlaBuilder::VerifyConvolution( } int num_spatial_dims = num_dims - 2; - const auto check_spatial_dimensions = [&](absl::string_view field_name, - absl::Span numbers) { + const auto check_spatial_dimensions = + [&](absl::string_view field_name, + absl::Span numbers) -> absl::Status { if (numbers.size() != num_spatial_dims) { return InvalidArgument("Expected %d elements for %s, but got %d.", num_spatial_dims, field_name, numbers.size()); @@ -1753,7 +1998,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( int64_t feature_group_count, int64_t batch_group_count, const PrecisionConfig* precision_config, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); @@ -1811,7 +2056,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( const PrecisionConfig* precision_config, std::optional preferred_element_type, std::optional> window_reversal) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_RETURN_IF_ERROR( @@ -1841,7 +2086,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( }); } -StatusOr XlaBuilder::DynamicConvInstruction( +absl::StatusOr XlaBuilder::DynamicConvInstruction( XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, @@ -1894,7 +2139,7 @@ XlaOp XlaBuilder::DynamicConvInputGrad( int64_t feature_group_count, int64_t batch_group_count, const PrecisionConfig* precision_config, PaddingType padding_type, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, DynamicConvInstruction( @@ -1919,7 +2164,7 @@ XlaOp XlaBuilder::DynamicConvKernelGrad( int64_t feature_group_count, int64_t batch_group_count, const PrecisionConfig* precision_config, PaddingType padding_type, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, DynamicConvInstruction(activations, gradients, window_strides, padding, @@ -1946,7 +2191,7 @@ XlaOp XlaBuilder::DynamicConvForward( int64_t feature_group_count, int64_t batch_group_count, const PrecisionConfig* precision_config, PaddingType padding_type, std::optional preferred_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( HloInstructionProto instr, DynamicConvInstruction( @@ -1959,7 +2204,7 @@ XlaOp XlaBuilder::DynamicConvForward( }); } -StatusOr XlaBuilder::ConvGeneralDilatedInternal( +absl::StatusOr XlaBuilder::ConvGeneralDilatedInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, absl::Span window_strides, absl::Span> padding, @@ -1985,7 +2230,7 @@ StatusOr XlaBuilder::ConvGeneralDilatedInternal( XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type, const absl::Span fft_length) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape( *operand_shape, fft_type, fft_length)); @@ -1993,7 +2238,7 @@ XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type, }); } -StatusOr XlaBuilder::FftInternal( +absl::StatusOr XlaBuilder::FftInternal( const Shape& shape, XlaOp operand, const FftType fft_type, const absl::Span fft_length) { HloInstructionProto instr; @@ -2006,7 +2251,7 @@ StatusOr XlaBuilder::FftInternal( return AddInstruction(std::move(instr), HloOpcode::kFft, {operand}); } -StatusOr XlaBuilder::TriangularSolveInternal( +absl::StatusOr XlaBuilder::TriangularSolveInternal( const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) { HloInstructionProto instr; *instr.mutable_triangular_solve_options() = std::move(options); @@ -2015,8 +2260,8 @@ StatusOr XlaBuilder::TriangularSolveInternal( return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b}); } -StatusOr XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a, - bool lower) { +absl::StatusOr XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) { HloInstructionProto instr; CholeskyOptions& options = *instr.mutable_cholesky_options(); options.set_lower(lower); @@ -2026,7 +2271,7 @@ StatusOr XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a, } XlaOp XlaBuilder::Infeed(const Shape& shape, const std::string& config) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); @@ -2100,7 +2345,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const std::string& config) { XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, const std::string& config) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } @@ -2122,7 +2367,7 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, }); } -StatusOr XlaBuilder::InfeedWithTokenInternal( +absl::StatusOr XlaBuilder::InfeedWithTokenInternal( const Shape& infeed_instruction_shape, XlaOp token, const std::string& config) { HloInstructionProto instr; @@ -2133,7 +2378,7 @@ StatusOr XlaBuilder::InfeedWithTokenInternal( void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, const std::string& outfeed_config) { - ReportErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); @@ -2206,7 +2451,7 @@ void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const std::string& outfeed_config) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { return InvalidArgument("Given shape to Outfeed must have a layout"); @@ -2223,7 +2468,7 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, }); } -StatusOr XlaBuilder::OutfeedWithTokenInternal( +absl::StatusOr XlaBuilder::OutfeedWithTokenInternal( XlaOp operand, XlaOp token, const Shape& shape_with_layout, const std::string& outfeed_config) { HloInstructionProto instr; @@ -2235,7 +2480,7 @@ StatusOr XlaBuilder::OutfeedWithTokenInternal( } XlaOp XlaBuilder::CreateToken() { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll); @@ -2243,7 +2488,7 @@ XlaOp XlaBuilder::CreateToken() { } XlaOp XlaBuilder::AfterAll(absl::Span tokens) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); } @@ -2272,7 +2517,7 @@ XlaOp XlaBuilder::CustomCall( const Literal* literal, std::optional window, std::optional dnums, CustomCallSchedule schedule, CustomCallApiVersion api_version) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " @@ -2309,7 +2554,7 @@ XlaOp XlaBuilder::CustomCall( }); } -StatusOr XlaBuilder::CustomCallInternal( +absl::StatusOr XlaBuilder::CustomCallInternal( const std::string& call_target_name, absl::Span operands, const XlaComputation* computation, const Shape& shape, const std::string& opaque, @@ -2380,7 +2625,7 @@ XlaOp XlaBuilder::CustomCall( output_operand_aliasing, const Literal* literal, CustomCallSchedule schedule, CustomCallApiVersion api_version) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " @@ -2418,7 +2663,7 @@ XlaOp XlaBuilder::CustomCall( } XlaOp XlaBuilder::OptimizationBarrier(XlaOp operand) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); Shape shape = *operand_shape; HloInstructionProto instr; @@ -2430,7 +2675,7 @@ XlaOp XlaBuilder::OptimizationBarrier(XlaOp operand) { XlaOp XlaBuilder::Transpose(XlaOp operand, absl::Span permutation) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape( *operand_shape, permutation)); @@ -2438,7 +2683,7 @@ XlaOp XlaBuilder::Transpose(XlaOp operand, }); } -StatusOr XlaBuilder::TransposeInternal( +absl::StatusOr XlaBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -2449,7 +2694,7 @@ StatusOr XlaBuilder::TransposeInternal( } XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span dimensions) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape( *operand_shape, dimensions)); @@ -2457,8 +2702,8 @@ XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span dimensions) { }); } -StatusOr XlaBuilder::RevInternal(const Shape& shape, XlaOp operand, - absl::Span dimensions) { +absl::StatusOr XlaBuilder::RevInternal( + const Shape& shape, XlaOp operand, absl::Span dimensions) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); for (int64_t dim : dimensions) { @@ -2470,7 +2715,7 @@ StatusOr XlaBuilder::RevInternal(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::Sort(absl::Span operands, const XlaComputation& comparator, int64_t dimension, bool is_stable) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(std::vector operand_shapes, GetOperandShapes(operands)); @@ -2482,10 +2727,11 @@ XlaOp XlaBuilder::Sort(absl::Span operands, }); } -StatusOr XlaBuilder::SortInternal(const Shape& shape, - absl::Span operands, - const XlaComputation& comparator, - int64_t dimension, bool is_stable) { +absl::StatusOr XlaBuilder::SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64_t dimension, + bool is_stable) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_is_stable(is_stable); @@ -2499,7 +2745,7 @@ StatusOr XlaBuilder::SortInternal(const Shape& shape, } XlaOp XlaBuilder::TopK(XlaOp operand, int64_t k, bool largest) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, @@ -2508,8 +2754,9 @@ XlaOp XlaBuilder::TopK(XlaOp operand, int64_t k, bool largest) { }); } -StatusOr XlaBuilder::TopKInternal(const Shape& shape, XlaOp operand, - int64_t k, bool largest) { +absl::StatusOr XlaBuilder::TopKInternal(const Shape& shape, + XlaOp operand, int64_t k, + bool largest) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_k(k); @@ -2519,7 +2766,7 @@ StatusOr XlaBuilder::TopKInternal(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( *operand_shape, new_element_type)); @@ -2533,7 +2780,7 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand, XlaOp XlaBuilder::BitcastConvertType(XlaOp operand, PrimitiveType new_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBitcastConvertShape( *operand_shape, new_element_type)); @@ -2541,8 +2788,8 @@ XlaOp XlaBuilder::BitcastConvertType(XlaOp operand, }); } -StatusOr XlaBuilder::BitcastConvertTypeInternal(const Shape& shape, - XlaOp operand) { +absl::StatusOr XlaBuilder::BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, @@ -2551,7 +2798,7 @@ StatusOr XlaBuilder::BitcastConvertTypeInternal(const Shape& shape, XlaOp XlaBuilder::StochasticConvertType(XlaOp operand, XlaOp random, PrimitiveType new_element_type) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* random_shape, GetShapePtr(random)); TF_ASSIGN_OR_RETURN(Shape shape, @@ -2570,7 +2817,7 @@ XlaOp XlaBuilder::Map(absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); } @@ -2612,7 +2859,7 @@ XlaOp XlaBuilder::Map(absl::Span operands, XlaOp XlaBuilder::RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { // Check the number of parameters per RNG distribution. switch (distribution) { case RandomDistribution::RNG_NORMAL: @@ -2632,9 +2879,9 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, }); } -StatusOr XlaBuilder::RngOpInternal(RandomDistribution distribution, - absl::Span parameters, - const Shape& shape) { +absl::StatusOr XlaBuilder::RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_distribution(distribution); @@ -2652,7 +2899,7 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) { XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); Shape output_shape = shape; @@ -2672,7 +2919,7 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, }); } -StatusOr XlaBuilder::RngBitGeneratorInternal( +absl::StatusOr XlaBuilder::RngBitGeneratorInternal( const Shape& full_result_shape, RandomAlgorithm algorithm, XlaOp initial_state) { HloInstructionProto instr; @@ -2684,7 +2931,7 @@ StatusOr XlaBuilder::RngBitGeneratorInternal( XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { // Infer shape. TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, @@ -2697,10 +2944,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, }); } -StatusOr XlaBuilder::WhileInternal(const Shape& shape, - const XlaComputation& condition, - const XlaComputation& body, - XlaOp init) { +absl::StatusOr XlaBuilder::WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); // Body comes before condition computation in the vector. @@ -2713,7 +2960,7 @@ XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input)); TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, GetShapePtr(start_indices)); @@ -2725,7 +2972,7 @@ XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices, }); } -StatusOr XlaBuilder::GatherInternal( +absl::StatusOr XlaBuilder::GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted) { @@ -2755,7 +3002,7 @@ XlaOp XlaBuilder::Scatter(absl::Span inputs, XlaOp scatter_indices, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, bool unique_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (inputs.empty()) { return InvalidArgument("Scatter inputs cannot be empty."); } @@ -2788,12 +3035,12 @@ XlaOp XlaBuilder::Scatter(absl::Span inputs, XlaOp scatter_indices, }); } -StatusOr XlaBuilder::ScatterInternal( +absl::StatusOr XlaBuilder::ScatterInternal( const Shape& shape, absl::Span inputs, XlaOp scatter_indices, absl::Span updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, bool unique_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; instr.set_indices_are_sorted(indices_are_sorted); instr.set_unique_indices(unique_indices); @@ -2814,7 +3061,7 @@ XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand, const XlaComputation& true_computation, XlaOp false_operand, const XlaComputation& false_computation) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(predicate)); if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != PRED) { @@ -2834,7 +3081,7 @@ XlaOp XlaBuilder::Conditional( XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(branch_index)); if (!ShapeUtil::IsScalar(*shape) || shape->element_type() != S32) { @@ -2853,7 +3100,7 @@ XlaOp XlaBuilder::AllReduceImpl(XlaOp operand, const std::optional& layout, const std::optional use_global_device_ids, bool async) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector operand_shapes; @@ -2939,7 +3186,7 @@ XlaOp XlaBuilder::AllGatherImpl(const XlaOp operand, const std::optional& layout, const std::optional use_global_device_ids, bool async) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -2991,7 +3238,7 @@ XlaOp XlaBuilder::ConditionalImpl( XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape, @@ -3046,7 +3293,7 @@ XlaOp XlaBuilder::Reduce(absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); @@ -3070,11 +3317,11 @@ XlaOp XlaBuilder::Reduce(absl::Span operands, }); } -StatusOr XlaBuilder::ReduceInternal( +absl::StatusOr XlaBuilder::ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, absl::Span dimensions_to_reduce) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); @@ -3089,7 +3336,7 @@ StatusOr XlaBuilder::ReduceInternal( XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value, const XlaComputation& computation) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector all_dimnos(operand_shape->rank()); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); @@ -3113,7 +3360,7 @@ XlaOp XlaBuilder::ReduceWindow(absl::Span operands, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { const Shape* operand_shape = nullptr; for (const auto& operand : operands) { TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand)); @@ -3167,7 +3414,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( absl::Span window_dilations, absl::Span> padding) { std::vector operand_shapes, init_shapes; - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (operands.size() == 1) { const auto& operand = operands[0]; const auto& init_value = init_values[0]; @@ -3203,7 +3450,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( }); } -StatusOr XlaBuilder::ReduceWindowInternal( +absl::StatusOr XlaBuilder::ReduceWindowInternal( absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span window_dimensions, @@ -3238,7 +3485,7 @@ StatusOr XlaBuilder::ReduceWindowInternal( return instr; } -StatusOr XlaBuilder::ReduceWindowInternal( +absl::StatusOr XlaBuilder::ReduceWindowInternal( const Shape& shape, XlaOp operand, XlaOp init_value, const XlaComputation& computation, Window window) { HloInstructionProto instr; @@ -3252,7 +3499,7 @@ StatusOr XlaBuilder::ReduceWindowInternal( XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, int64_t feature_index) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -3275,7 +3522,7 @@ XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, XlaOp variance, float epsilon, int64_t feature_index) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -3300,7 +3547,7 @@ XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, XlaOp batch_var, XlaOp grad_output, float epsilon, int64_t feature_index) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -3336,7 +3583,7 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension, XlaOp XlaBuilder::CrossReplicaSum( XlaOp operand, absl::Span replica_groups) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); const Shape* element_shape; if (shape->IsTuple()) { @@ -3380,7 +3627,7 @@ XlaOp XlaBuilder::ReduceScatter( const std::optional& channel_id, const std::optional& layout, const std::optional use_global_device_ids) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector operand_shapes; @@ -3452,7 +3699,7 @@ XlaOp XlaBuilder::AllToAllArray( XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, const std::optional& channel_id) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( const Shape all_to_all_shape, @@ -3511,7 +3758,7 @@ XlaOp XlaBuilder::AllToAllTuple( absl::Span replica_groups, const std::optional& layout, const std::optional& channel_id) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(auto operand_shapes, this->GetOperandShapes(operands)); std::vector operand_shape_ptrs; @@ -3554,10 +3801,10 @@ XlaOp XlaBuilder::AllToAllTuple( int64_t split_count, absl::Span replica_groups, const std::optional& layout, const std::optional& channel_id) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - // The HloInstruction for Alltoall currently only handles the data + // The HloInstruction for AllToAll currently only handles the data // communication: it accepts N already split parts and scatters them to N // cores, and each core gathers the N received parts into a tuple as the // output. So here we explicitly split the operand before the hlo alltoall, @@ -3594,6 +3841,34 @@ XlaOp XlaBuilder::AllToAllTuple( }); } +XlaOp XlaBuilder::CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id) { + return CollectiveBroadcastImpl(operand, replica_groups, channel_id); +} + +XlaOp XlaBuilder::CollectiveBroadcastImpl( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferCollectiveBroadcastShape({operand_shape})); + *instr.mutable_shape() = shape.ToProto(); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectiveBroadcast, + {operand}); + }); +} + XlaOp XlaBuilder::CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs, @@ -3606,7 +3881,7 @@ XlaOp XlaBuilder::CollectivePermuteImpl( XlaOp operand, const std::vector>& source_target_pairs, const std::optional& channel_id, bool async) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); HloInstructionProto instr; TF_ASSIGN_OR_RETURN( @@ -3631,7 +3906,7 @@ XlaOp XlaBuilder::CollectivePermuteImpl( } XlaOp XlaBuilder::ReplicaId() { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {}); @@ -3644,7 +3919,7 @@ XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select, Padding padding, XlaOp source, XlaOp init_value, const XlaComputation& scatter) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector> padding_values = @@ -3683,7 +3958,7 @@ XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select, }); } -StatusOr XlaBuilder::SelectAndScatterInternal( +absl::StatusOr XlaBuilder::SelectAndScatterInternal( XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -3719,7 +3994,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( absl::Span window_strides, absl::Span> padding, XlaOp source, XlaOp init_value, const XlaComputation& scatter) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(HloInstructionProto instr, SelectAndScatterInternal( operand, select, window_dimensions, window_strides, @@ -3732,7 +4007,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits, const int mantissa_bits) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReducePrecisionShape( @@ -3742,10 +4017,9 @@ XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits, }); } -StatusOr XlaBuilder::ReducePrecisionInternal(const Shape& shape, - XlaOp operand, - const int exponent_bits, - const int mantissa_bits) { +absl::StatusOr XlaBuilder::ReducePrecisionInternal( + const Shape& shape, XlaOp operand, const int exponent_bits, + const int mantissa_bits) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_exponent_bits(exponent_bits); @@ -3755,7 +4029,7 @@ StatusOr XlaBuilder::ReducePrecisionInternal(const Shape& shape, } void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) { - ReportErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> absl::StatusOr { // Send HLO takes two operands: a data operand and a token. Generate the // token to pass into the send. // TODO(b/80000000): Remove this when clients have been updated to handle @@ -3771,7 +4045,7 @@ void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) { XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { return InvalidArgument("Send must use a device-to-device channel"); } @@ -3784,7 +4058,7 @@ XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token, } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { // Recv HLO takes a single token operand. Generate the token to pass into // the Recv and RecvDone instructions. // TODO(b/80000000): Remove this when clients have been updated to handle @@ -3810,7 +4084,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape, const ChannelHandle& handle) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { return InvalidArgument("Recv must use a device-to-device channel"); } @@ -3825,7 +4099,7 @@ XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape, XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const ChannelHandle& handle) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (!LayoutUtil::HasLayout(shape_with_layout)) { return InvalidArgument("Shape passed to SendToHost must have a layout"); } @@ -3873,7 +4147,7 @@ XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token, XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape, const ChannelHandle& handle) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Shape passed to RecvFromHost must have a layout"); } @@ -3913,14 +4187,14 @@ XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape, } XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( *operand_shape, dimension)); // Calling GetDimensionSize on a static dimension returns a constant // instruction. - if (!operand_shape->is_dynamic_dimension(dimension)) { + if (operand_shape->is_static_dimension(dimension)) { return ConstantR0(this, operand_shape->dimensions(dimension)); } *instr.mutable_shape() = shape.ToProto(); @@ -3931,7 +4205,7 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) { } XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); Shape shape = *operand_shape; @@ -3946,7 +4220,7 @@ XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) { XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val)); @@ -3957,9 +4231,10 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, }); } -StatusOr XlaBuilder::SetDimensionSizeInternal(const Shape& shape, - XlaOp operand, XlaOp val, - int64_t dimension) { +absl::StatusOr XlaBuilder::SetDimensionSizeInternal(const Shape& shape, + XlaOp operand, + XlaOp val, + int64_t dimension) { // Note that both SetDimensionSize and RemoveDynamicDimension use // HloOpcode::kSetDimensionSize internally. However, The SetDimensionSize // builder always produces an output with a dynamic bound on the given @@ -3976,7 +4251,7 @@ StatusOr XlaBuilder::SetDimensionSizeInternal(const Shape& shape, {operand, val}); } -StatusOr XlaBuilder::IsConstant(XlaOp operand) const { +absl::StatusOr XlaBuilder::IsConstant(XlaOp operand) const { TF_RETURN_IF_ERROR(first_error_); // Verify that the handle is valid. @@ -3988,7 +4263,7 @@ StatusOr XlaBuilder::IsConstant(XlaOp operand) const { return is_constant; } -StatusOr XlaBuilder::BuildConstantSubGraph( +absl::StatusOr XlaBuilder::BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_minus_one) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); if (!is_constant) { @@ -4256,9 +4531,9 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { return OkStatus(); } -StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, - HloOpcode opcode, - absl::Span operands) { +absl::StatusOr XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); const int64_t handle = GetNextId(); @@ -4304,8 +4579,8 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, return op; } -StatusOr XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape, - absl::Span operands) { +absl::StatusOr XlaBuilder::AddOpWithShape( + HloOpcode opcode, const Shape& shape, absl::Span operands) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), opcode, operands); @@ -4367,25 +4642,25 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation, } } -StatusOr XlaBuilder::LookUpInstruction( +absl::StatusOr XlaBuilder::LookUpInstruction( const XlaOp op) const { TF_RETURN_IF_ERROR(first_error_); return LookUpInstructionInternal(op); } -StatusOr XlaBuilder::LookUpInstructionByHandle( - int64_t handle) const { +absl::StatusOr +XlaBuilder::LookUpInstructionByHandle(int64_t handle) const { return LookUpInstructionByHandleInternal(handle); } -StatusOr XlaBuilder::LookUpMutableInstruction( +absl::StatusOr XlaBuilder::LookUpMutableInstruction( const XlaOp op) { TF_RETURN_IF_ERROR(first_error_); return LookUpInstructionInternal(op); } -StatusOr XlaBuilder::LookUpMutableInstructionByHandle( - int64_t handle) { +absl::StatusOr +XlaBuilder::LookUpMutableInstructionByHandle(int64_t handle) { return LookUpInstructionByHandleInternal(handle); } @@ -4416,12 +4691,19 @@ XlaOp Broadcast(const XlaOp operand, } XlaOp BroadcastInDim(const XlaOp operand, - const absl::Span out_dim_size, - const absl::Span broadcast_dimensions) { + absl::Span out_dim_size, + absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, out_dim_size, broadcast_dimensions); } +XlaOp DynamicBroadcastInDim(const XlaOp operand, const XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape) { + return operand.builder()->DynamicBroadcastInDim( + operand, output_dimensions, broadcast_dimensions, output_shape); +} + XlaOp Copy(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kCopy, operand); } @@ -4516,7 +4798,7 @@ static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection comparison_direction) { auto b = lhs.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs)); auto operand_element_type = operand_shape.element_type(); auto compare_type = @@ -4621,6 +4903,17 @@ XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs, precision_config, preferred_element_type); } +XlaOp SparseDot(const XlaOp lhs, const XlaOp rhs, + absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config, + std::optional preferred_element_type) { + return lhs.builder()->SparseDot(lhs, rhs, sparse_meta, sparsity, + dimension_numbers, precision_config, + preferred_element_type); +} + XlaOp Conv(const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, Padding padding, int64_t feature_group_count, int64_t batch_group_count, @@ -4738,7 +5031,7 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b)); TriangularSolveOptions options; @@ -4754,7 +5047,7 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, XlaOp Cholesky(XlaOp a, bool lower) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCholeskyShape(*a_shape)); @@ -5021,6 +5314,18 @@ XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension, layout, use_global_device_ids); } +XlaOp AllGatherTuple(absl::Span operands, + int64_t all_gather_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + const std::optional use_global_device_ids) { + CHECK(!operands.empty()); + return operands[0].builder()->AllGather( + operands[0].builder()->Tuple(operands), all_gather_dimension, shard_count, + replica_groups, channel_id, layout, use_global_device_ids); +} + XlaOp CrossReplicaSum(const XlaOp operand, absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); @@ -5036,7 +5341,7 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation, use_global_device_ids); } -XlaOp AllReduceTuple(const absl::Span operands, +XlaOp AllReduceTuple(absl::Span operands, const XlaComputation& computation, absl::Span replica_groups, const std::optional& channel_id, @@ -5088,6 +5393,13 @@ XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension, replica_groups, layout, channel_id); } +XlaOp CollectiveBroadcast(const XlaOp operand, + absl::Span replica_groups, + const std::optional& channel_id) { + return operand.builder()->CollectiveBroadcast(operand, replica_groups, + channel_id); +} + XlaOp CollectivePermute( const XlaOp operand, const std::vector>& source_target_pairs, @@ -5152,6 +5464,9 @@ XlaOp Log(const XlaOp operand) { XlaOp Log1p(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); } +XlaOp Erf(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kErf, operand); +} XlaOp Logistic(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLogistic, operand); } @@ -5448,7 +5763,7 @@ OpSharding GetManualSharding(const OpSharding& original, int64_t single_dim) { return manual; } -StatusOr ConvertSpmdFullToShardShape( +absl::StatusOr ConvertSpmdFullToShardShape( XlaBuilder* builder, XlaOp input, int single_dim, const OpSharding& manual_sharding, absl::Span unspecified_dims) { @@ -5493,7 +5808,7 @@ StatusOr ConvertSpmdFullToShardShape( } } -StatusOr ConvertSpmdShardToFullShape( +absl::StatusOr ConvertSpmdShardToFullShape( XlaBuilder* builder, XlaOp input, const Shape& output_shape, int single_dim, const OpSharding& manual_sharding, absl::Span unspecified_dims) { diff --git a/xla/client/xla_builder.h b/xla/client/xla_builder.h index aca638833e709..b744c679bd7a6 100644 --- a/xla/client/xla_builder.h +++ b/xla/client/xla_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -71,23 +71,13 @@ struct XlaBuilderFriend { static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, XlaOp token, const Shape& shape); - static std::pair BuildAsyncStart( - XlaBuilder* builder, absl::Span operands, - std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, const Shape& shape); static std::pair BuildAsyncStart( XlaBuilder* builder, absl::Span operands, std::string execution_thread, const XlaComputation& called_computation, const Shape& shape); - static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, - std::string execution_thread, int64_t group_id, - int64_t called_computation, const Shape& shape); static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t called_computation, const Shape& shape); - static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, - std::string execution_thread, int64_t group_id, - int64_t called_computation, const Shape& shape); static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t called_computation, const Shape& shape); @@ -368,12 +358,12 @@ class XlaBuilder { // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the // dynamic dimensions information when XLA backend can handle dynamic // dimensions. - StatusOr Build(bool remove_dynamic_dimensions = false); + absl::StatusOr Build(bool remove_dynamic_dimensions = false); // Overload of Build which specifies a particular root instruction for the // computation. - StatusOr Build(XlaOp root, - bool remove_dynamic_dimensions = false); + absl::StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = false); // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. @@ -389,7 +379,7 @@ class XlaBuilder { // compile-time constant (see `IsConstant`), returns an error. // // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph( + absl::StatusOr BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_minus_one = false); // Returns the first error that was encountered while building the @@ -405,18 +395,18 @@ class XlaBuilder { Status GetCurrentStatus() const; // Returns the shape of the given op. - StatusOr GetShape(XlaOp op) const; + absl::StatusOr GetShape(XlaOp op) const; // Returns the shape of the given op. - virtual StatusOr GetShapePtr(XlaOp op) const; + virtual absl::StatusOr GetShapePtr(XlaOp op) const; // Returns the (inferred) result for the current computation's shape. This // assumes the root instruction is the last added instruction. - StatusOr GetProgramShape() const; + absl::StatusOr GetProgramShape() const; // Returns the (inferred) result for the current computation's shape using the // given operation as the root. - StatusOr GetProgramShape(XlaOp root) const; + absl::StatusOr GetProgramShape(XlaOp root) const; // Reports an error to the builder, by // * storing it internally and capturing a backtrace if it's the first error @@ -430,11 +420,12 @@ class XlaBuilder { // A helper function that converts a StatusOr into an XlaOp. // If the Status was an error, reports the error to builder and returns an // invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const StatusOr& op); + XlaOp ReportErrorOrReturn(const absl::StatusOr& op); // A helper function that runs a function that returns a StatusOr and // returns an XlaOp. - XlaOp ReportErrorOrReturn(absl::FunctionRef()> op_creator); + XlaOp ReportErrorOrReturn( + absl::FunctionRef()> op_creator); // Returns true if 'operand' is a compile-time constant. A compile-time // constant does not depend on any parameters, or on stateful operators such @@ -442,7 +433,7 @@ class XlaBuilder { // // This tests whether a computation is a compile-time constant without // evaluating the computation. - StatusOr IsConstant(XlaOp operand) const; + absl::StatusOr IsConstant(XlaOp operand) const; // Adds a new input/output alias. Since the input/output shape information are // not available until the computation is built, any eventual error in the @@ -490,7 +481,7 @@ class XlaBuilder { std::string value); // Returns shapes for the operands. - StatusOr> GetOperandShapes( + absl::StatusOr> GetOperandShapes( absl::Span operands) const; // Converts the op to string for the ease of debugging. @@ -500,8 +491,8 @@ class XlaBuilder { void ToStringHelper(std::string* out, int ident, int64_t op_handle) const; // Build helper which takes the id of the root operation.. - StatusOr Build(int64_t root_id, - bool remove_dynamic_dimensions); + absl::StatusOr Build(int64_t root_id, + bool remove_dynamic_dimensions); // Description for the methods below can be found in the corresponding public // functions section in this file. @@ -522,14 +513,22 @@ class XlaBuilder { XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions); + // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim + // op from the XlaBuilder. This is only intended for export to MHLO or + // StableHLO, and cannot be compiled. Only static output_dimensions are + // allowed, and broadcast_dimensions is verified. + XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, int64_t pad_lo, int64_t pad_hi); - virtual StatusOr PadInternal(const Shape& shape, XlaOp operand, - XlaOp padding_value, - const PaddingConfig& padding_config); + virtual absl::StatusOr PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, @@ -550,40 +549,40 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - virtual StatusOr SliceInternal(const Shape& shape, XlaOp operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); + virtual absl::StatusOr SliceInternal( + const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); - virtual StatusOr DynamicSliceInternal( + virtual absl::StatusOr DynamicSliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); - virtual StatusOr DynamicUpdateSliceInternal( + virtual absl::StatusOr DynamicUpdateSliceInternal( const Shape& shape, XlaOp operand, XlaOp update, absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64_t dimension); - virtual StatusOr ConcatInDimInternal(const Shape& shape, - absl::Span operands, - int64_t dimension); + virtual absl::StatusOr ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64_t dimension); XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); XlaOp Tuple(absl::Span elements); - virtual StatusOr TupleInternal(const Shape& shape, - absl::Span elements); + virtual absl::StatusOr TupleInternal(const Shape& shape, + absl::Span elements); XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - virtual StatusOr GetTupleElementInternal(const Shape& shape, - XlaOp tuple_data, - int64_t index); + virtual absl::StatusOr GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64_t index); XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr, @@ -594,6 +593,13 @@ class XlaBuilder { const PrecisionConfig* precision_config = nullptr, std::optional preferred_element_type = std::nullopt); + XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + XlaOp Conv( XlaOp lhs, XlaOp rhs, absl::Span window_strides, Padding padding, int64_t feature_group_count = 1, @@ -666,7 +672,7 @@ class XlaBuilder { const PrecisionConfig* precision_config, PaddingType padding_type, std::optional preferred_element_type = std::nullopt); - StatusOr DynamicConvInstruction( + absl::StatusOr DynamicConvInstruction( XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, @@ -676,7 +682,7 @@ class XlaBuilder { const PrecisionConfig* precision_config, PaddingType padding_type, std::optional preferred_element_type = std::nullopt); - virtual StatusOr ConvGeneralDilatedInternal( + virtual absl::StatusOr ConvGeneralDilatedInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, absl::Span window_strides, absl::Span> padding, @@ -688,20 +694,20 @@ class XlaBuilder { XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); - virtual StatusOr FftInternal(const Shape& shape, XlaOp operand, - FftType fft_type, - absl::Span fft_length); + virtual absl::StatusOr FftInternal( + const Shape& shape, XlaOp operand, FftType fft_type, + absl::Span fft_length); - virtual StatusOr TriangularSolveInternal( + virtual absl::StatusOr TriangularSolveInternal( const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); - virtual StatusOr CholeskyInternal(const Shape& shape, XlaOp a, - bool lower); + virtual absl::StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower); XlaOp Infeed(const Shape& shape, const std::string& config = ""); XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const std::string& config); - virtual StatusOr InfeedWithTokenInternal( + virtual absl::StatusOr InfeedWithTokenInternal( const Shape& infeed_instruction_shape, XlaOp token, const std::string& config); @@ -710,7 +716,7 @@ class XlaBuilder { XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const std::string& outfeed_config); - virtual StatusOr OutfeedWithTokenInternal( + virtual absl::StatusOr OutfeedWithTokenInternal( XlaOp operand, XlaOp token, const Shape& shape_with_layout, const std::string& outfeed_config); XlaOp Call(const XlaComputation& computation, @@ -730,7 +736,7 @@ class XlaBuilder { // Internal version of CustomCall without computation that doesn't do op // specific error handling and expects arguments to be legal. CustomCall // method above calls this method after error handling. - virtual StatusOr CustomCallInternal( + virtual absl::StatusOr CustomCallInternal( const std::string& call_target_name, absl::Span operands, const XlaComputation* computation, const Shape& shape_with_layout, const std::string& opaque, @@ -766,7 +772,7 @@ class XlaBuilder { const XlaComputation& computation, absl::Span dimensions_to_reduce); - virtual StatusOr ReduceInternal( + virtual absl::StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, absl::Span dimensions_to_reduce); @@ -793,7 +799,7 @@ class XlaBuilder { absl::Span base_dilations, absl::Span window_dilations, absl::Span> padding); - StatusOr ReduceWindowInternal( + absl::StatusOr ReduceWindowInternal( absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span window_dimensions, @@ -801,7 +807,7 @@ class XlaBuilder { absl::Span base_dilations, absl::Span window_dilations, absl::Span> padding); - virtual StatusOr ReduceWindowInternal( + virtual absl::StatusOr ReduceWindowInternal( const Shape& shape, XlaOp operand, XlaOp init_value, const XlaComputation& computation, Window window); XlaOp CrossReplicaSum(XlaOp operand, @@ -846,6 +852,10 @@ class XlaBuilder { const std::optional& layout, const std::optional& channel_id = std::nullopt); + XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + XlaOp CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs, @@ -866,7 +876,7 @@ class XlaBuilder { absl::Span> padding, XlaOp source, XlaOp init_value, const XlaComputation& scatter); - StatusOr SelectAndScatterInternal( + absl::StatusOr SelectAndScatterInternal( XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -880,30 +890,30 @@ class XlaBuilder { XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - virtual StatusOr BitcastConvertTypeInternal(const Shape& shape, - XlaOp operand); + virtual absl::StatusOr BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand); XlaOp StochasticConvertType(XlaOp operand, XlaOp random, PrimitiveType new_element_type); XlaOp Transpose(XlaOp operand, absl::Span permutation); - virtual StatusOr TransposeInternal( + virtual absl::StatusOr TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation); XlaOp Rev(XlaOp operand, absl::Span dimensions); - virtual StatusOr RevInternal(const Shape& shape, XlaOp operand, - absl::Span dimensions); + virtual absl::StatusOr RevInternal( + const Shape& shape, XlaOp operand, absl::Span dimensions); XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64_t dimension = -1, bool is_stable = false); - virtual StatusOr SortInternal(const Shape& shape, - absl::Span operands, - const XlaComputation& comparator, - int64_t dimension, bool is_stable); + virtual absl::StatusOr SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64_t dimension, bool is_stable); XlaOp TopK(XlaOp operand, int64_t k, bool largest); - virtual StatusOr TopKInternal(const Shape& shape, XlaOp operand, - int64_t k, bool largest); + virtual absl::StatusOr TopKInternal(const Shape& shape, XlaOp operand, + int64_t k, bool largest); XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); @@ -919,15 +929,16 @@ class XlaBuilder { const Shape& shape); // Internal variant for the op with the full result shape containing both data // and state shape as a tuple. - virtual StatusOr RngBitGeneratorInternal( + virtual absl::StatusOr RngBitGeneratorInternal( const Shape& full_result_shape, RandomAlgorithm algorithm, XlaOp initial_state); XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init); - virtual StatusOr WhileInternal(const Shape& shape, - const XlaComputation& condition, - const XlaComputation& body, XlaOp init); + virtual absl::StatusOr WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init); XlaOp Conditional(XlaOp predicate, XlaOp true_operand, const XlaComputation& true_computation, XlaOp false_operand, @@ -938,17 +949,17 @@ class XlaBuilder { absl::Span branch_operands); XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); - virtual StatusOr ReducePrecisionInternal(const Shape& shape, - XlaOp operand, - int exponent_bits, - int mantissa_bits); + virtual absl::StatusOr ReducePrecisionInternal(const Shape& shape, + XlaOp operand, + int exponent_bits, + int mantissa_bits); XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted = false); - virtual StatusOr GatherInternal( + virtual absl::StatusOr GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted); @@ -963,7 +974,7 @@ class XlaBuilder { const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted = false, bool unique_indices = false); - virtual StatusOr ScatterInternal( + virtual absl::StatusOr ScatterInternal( const Shape& shape, absl::Span inputs, XlaOp scatter_indices, absl::Span updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, @@ -1001,28 +1012,29 @@ class XlaBuilder { XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - virtual StatusOr SetDimensionSizeInternal(const Shape& shape, - XlaOp operand, XlaOp val, - int64_t dimension); + virtual absl::StatusOr SetDimensionSizeInternal(const Shape& shape, + XlaOp operand, + XlaOp val, + int64_t dimension); XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - virtual StatusOr AddInstruction(HloInstructionProto&& instr, - HloOpcode opcode, - absl::Span operands); - StatusOr AddInstruction(HloInstructionProto&& instr, - HloOpcode opcode) { + virtual absl::StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands); + absl::StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { return AddInstruction(std::move(instr), opcode, /*operands=*/{}); } void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); - StatusOr LookUpInstruction(XlaOp op) const; - StatusOr LookUpInstructionByHandle( + absl::StatusOr LookUpInstruction(XlaOp op) const; + absl::StatusOr LookUpInstructionByHandle( int64_t handle) const; - StatusOr LookUpMutableInstruction(XlaOp op); - StatusOr LookUpMutableInstructionByHandle( + absl::StatusOr LookUpMutableInstruction(XlaOp op); + absl::StatusOr LookUpMutableInstructionByHandle( int64_t handle); // Internal helper method that does the building for an arbitrary unary op. @@ -1037,13 +1049,14 @@ class XlaBuilder { std::optional direction = std::nullopt, std::optional type = std::nullopt); - StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction); + absl::StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction); // Internal helper method for binary op compare without broadcast dimensions. - virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction, - Comparison::Type type); + virtual absl::StatusOr Compare(const Shape& shape, XlaOp lhs, + XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type); // Internal helper method that does the building for an arbitrary binary op // with same ranked operands that doesn't broadcast. @@ -1056,11 +1069,11 @@ class XlaBuilder { XlaOp RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape); - virtual StatusOr RngOpInternal(RandomDistribution distribution, - absl::Span parameters, - const Shape& shape); + virtual absl::StatusOr RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape); - virtual StatusOr InDimBroadcast( + virtual absl::StatusOr InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions); @@ -1069,16 +1082,21 @@ class XlaBuilder { // All dimensions of the operand must either be equal to the corresponding // output shape dimension, or be exactly 1. (Such dimensions are the // degenerate dimensions.) - StatusOr AddBroadcastSequence(const Shape& output_shape, - XlaOp operand); + absl::StatusOr AddBroadcastSequence(const Shape& output_shape, + XlaOp operand); + + // Internal helper method that broadcasts a scalar to the shape of the output. + absl::StatusOr BroadcastScalarToOutputShape(XlaOp scalar, + XlaOp output); // Internal helper method for creating a Reshape op with the already inferred // shape. - virtual StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, - int64_t inferred_dimension); + virtual absl::StatusOr ReshapeInternal(const Shape& shape, + XlaOp operand, + int64_t inferred_dimension); // Returns the (inferred) result for the program shape using the given root. - StatusOr GetProgramShape(int64_t root_id) const; + absl::StatusOr GetProgramShape(int64_t root_id) const; // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful @@ -1187,6 +1205,11 @@ class XlaBuilder { absl::Span out_dim_size, absl::Span broadcast_dimensions); + friend XlaOp DynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + friend XlaOp Copy(XlaOp operand); friend XlaOp Pad(XlaOp operand, XlaOp padding_value, @@ -1246,10 +1269,16 @@ class XlaBuilder { const DotDimensionNumbers& dimension_number, const PrecisionConfig* precision_config, std::optional preferred_element_type); - virtual StatusOr DotGeneralInternal( + virtual absl::StatusOr DotGeneralInternal( const Shape& shape, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number, const PrecisionConfig* precision_config); + friend XlaOp SparseDot(XlaOp lhs, XlaOp rhs, + absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); friend XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, Padding padding, int64_t feature_group_count, int64_t batch_group_count, @@ -1447,6 +1476,12 @@ class XlaBuilder { const std::optional& channel_id, const std::optional& layout, std::optional use_global_device_ids); + friend XlaOp AllGatherTuple(absl::Span operands, + int64_t all_gather_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const std::optional& channel_id, @@ -1479,6 +1514,9 @@ class XlaBuilder { absl::Span replica_groups, const std::optional& layout, const std::optional& channel_id); + friend XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id); friend XlaOp CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs, @@ -1498,6 +1536,7 @@ class XlaBuilder { friend XlaOp Abs(XlaOp operand); friend XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions); + friend XlaOp Erf(XlaOp operand); friend XlaOp Exp(XlaOp operand); friend XlaOp Expm1(XlaOp operand); friend XlaOp Floor(XlaOp operand); @@ -1626,6 +1665,10 @@ class XlaBuilder { const std::optional& layout, std::optional use_global_device_ids, bool async); + XlaOp CollectiveBroadcastImpl(XlaOp operand, + absl::Span replica_groups, + const std::optional& channel_id); + XlaOp CollectivePermuteImpl( XlaOp operand, const std::vector>& source_target_pairs, @@ -1642,8 +1685,8 @@ class XlaBuilder { const std::optional& channel_id = std::nullopt); // Creates an op with the given opcode and the output shape. - virtual StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, - absl::Span operands); + virtual absl::StatusOr AddOpWithShape( + HloOpcode opcode, const Shape& shape, absl::Span operands); // Here, InstructionType is either const HloInstructionProto* or non-const // HloInstructionProto*. @@ -1863,6 +1906,15 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, absl::Span broadcast_dimensions); +// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim +// op from the XlaBuilder. This is only intended for export to MHLO or +// StableHLO, and cannot be compiled. See +// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. +// for the op semantics. +XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + // Copies the input operand to the output. This operation is for internal // purpose and is only used by the compiler for optimization purposes or to // ensure correctness. The XLA client should never have to generate this @@ -2072,6 +2124,14 @@ XlaOp DotGeneral( const PrecisionConfig* precision_config = nullptr, std::optional preferred_element_type = std::nullopt); +// Enqueues a sparse dot instruction onto the computation. +XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, @@ -2441,6 +2501,13 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, const std::optional& layout = std::nullopt, std::optional use_global_device_ids = std::nullopt); +XlaOp AllGatherTuple( + absl::Span operands, int64_t all_gather_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then // broadcasting the reduction result to those cores. The reduction function is @@ -2502,6 +2569,10 @@ XlaOp AllToAllTuple( const std::optional& layout = std::nullopt, const std::optional& channel_id = std::nullopt); +XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + // Enqueues an collective operation that sends and receives data cross replicas. // // - `source_target_pair`: a list of (source_replica_id, target_replica_id) @@ -2542,6 +2613,9 @@ XlaOp Abs(XlaOp operand); XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions = {}); +// Enqueues an erf instruction onto the computation. +XlaOp Erf(XlaOp operand); + // Enqueues an exp instruction onto the computation. XlaOp Exp(XlaOp operand); @@ -2954,7 +3028,7 @@ XlaOp ConstantR4FromArray4D(XlaBuilder* builder, // Switches from automatic SPMD partitioning to manual partitioning. Converts a // full-shaped tensor (to be automatically partitioned by SPMD partitioner) to a // shard-shaped tensor to be consumed by manually partitioned ops. -StatusOr ConvertSpmdFullToShardShape( +absl::StatusOr ConvertSpmdFullToShardShape( xla::XlaBuilder* builder, xla::XlaOp input, int single_dim, const xla::OpSharding& manual_sharding, absl::Span unspecified_dims); @@ -2962,7 +3036,7 @@ StatusOr ConvertSpmdFullToShardShape( // Switches from manual partitioning to automatic SPMD partitioning. Converts a // shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped // tensor to be partitioned automatically by the SPMD partitioner. -StatusOr ConvertSpmdShardToFullShape( +absl::StatusOr ConvertSpmdShardToFullShape( xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape, int single_dim, const xla::OpSharding& manual_sharding, absl::Span unspecified_dims); diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc index 8979a6f2b8e46..d4c24a5ef1ef7 100644 --- a/xla/client/xla_builder_test.cc +++ b/xla/client/xla_builder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,31 +16,49 @@ limitations under the License. #include "xla/client/xla_builder.h" #include +#include #include +#include #include #include +#include #include +#include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/client/padding.h" #include "xla/client/sharding_builder.h" #include "xla/client/value_inference.h" #include "xla/client/xla_computation.h" +#include "xla/comparison_util.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout_util.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -49,53 +67,57 @@ namespace { namespace m = ::xla::match; +using ::testing::_; using ::testing::HasSubstr; +using ::testing::Test; +using ::tsl::testing::StatusIs; + +HloInstruction* GetRoot(HloModule& module) { + return module.entry_computation()->root_instruction(); +} // TODO(b/74197823): Move the tests to service/. -class XlaBuilderTest : public ::testing::Test { - protected: - StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, - b->Build(/*remove_dynamic_dimensions=*/false)); - const HloModuleProto& proto = computation.proto(); - TF_ASSIGN_OR_RETURN(const auto& config, - HloModule::CreateModuleConfigFromProto( - proto, GetDebugOptionsFromFlags())); - return HloModule::CreateFromProto(proto, config); - } +absl::StatusOr> BuildHloModule(XlaBuilder& b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b.Build(/*remove_dynamic_dimensions=*/false)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); +} - // Overload which explicitly specifies the root instruction. - StatusOr> BuildHloModule(XlaBuilder* b, - XlaOp root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, - b->Build(root, /*remove_dynamic_dimensions=*/false)); - const HloModuleProto& proto = computation.proto(); - TF_ASSIGN_OR_RETURN(const auto& config, - HloModule::CreateModuleConfigFromProto( - proto, GetDebugOptionsFromFlags())); - return HloModule::CreateFromProto(proto, config); - } +// Overload which explicitly specifies the root instruction. +absl::StatusOr> BuildHloModule(XlaBuilder& b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b.Build(root, /*remove_dynamic_dimensions=*/false)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); +} - // Returns the name of the test currently being run. - std::string TestName() const { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); - } -}; +// Returns the name of the test currently being run. +std::string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); +} -TEST_F(XlaBuilderTest, OnePlusTwo) { +TEST(XlaBuilderTest, OnePlusTwo) { XlaBuilder b(TestName()); Add(ConstantR0(&b, 1.0), ConstantR0(&b, 2.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Constant(), m::Constant()))); } -TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { +TEST(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { auto test_unary_operator = [&](std::function op, auto matches_pattern) { XlaBuilder b(TestName()); op(ConstantR0(&b, 1)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, matches_pattern); }; @@ -105,12 +127,12 @@ TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { GmockMatch(m::Not(m::Constant()))); } -TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { +TEST(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { auto test_binary_operator = [&](std::function op, auto matches_pattern) { XlaBuilder b(TestName()); op(ConstantR0(&b, 1), ConstantR0(&b, 2)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, matches_pattern); }; @@ -140,7 +162,7 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { [&](std::function op, auto matches_pattern) { XlaBuilder b(TestName()); op(ConstantR0(&b, 1), ConstantR0(&b, 2)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, matches_pattern); }; @@ -149,12 +171,12 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { GmockMatch(m::ShiftRightLogical(m::Constant(), m::Constant()))); } -TEST_F(XlaBuilderTest, VariadicAnd) { +TEST(XlaBuilderTest, VariadicAnd) { XlaBuilder b(TestName()); - Shape s = ShapeUtil::MakeShape(PRED, {}); + const Shape s = ShapeUtil::MakeShape(PRED, {}); And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), Parameter(&b, 2, s, "p2")); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); // Don't specify in the test whether And(x, y, z) is right- or // left-associative; accept either one. EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -165,12 +187,12 @@ TEST_F(XlaBuilderTest, VariadicAnd) { m::Parameter(2))))); } -TEST_F(XlaBuilderTest, VariadicOr) { +TEST(XlaBuilderTest, VariadicOr) { XlaBuilder b(TestName()); - Shape s = ShapeUtil::MakeShape(PRED, {}); + const Shape s = ShapeUtil::MakeShape(PRED, {}); Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), Parameter(&b, 2, s, "p2")); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); // Don't specify in the test whether Or(x, y, z) is right- or // left-associative; accept either one. EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -181,7 +203,7 @@ TEST_F(XlaBuilderTest, VariadicOr) { m::Parameter(2))))); } -TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { +TEST(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { XlaBuilder b(TestName()); ConstantR0(&b, 1) >> ConstantR0(&b, 2); auto statusor = b.Build(); @@ -191,27 +213,27 @@ TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { HasSubstr("Argument to >> operator does not have an integral type")); } -TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { +TEST(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); Add(x, ConstantR0(&b, 1.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(), m::Broadcast(m::Constant())))); } -TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcastReversed) { +TEST(XlaBuilderTest, ParamPlusConstantHasScalarBroadcastReversed) { XlaBuilder b(TestName()); - XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + const XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); Add(ConstantR0(&b, 1.0), x); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Broadcast(m::Constant()), m::Parameter()))); } -TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { +TEST(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); @@ -219,46 +241,70 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { auto y = Parameter(&b, 1, y_shape, "y"); auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); + TF_ASSERT_OK_AND_ASSIGN(const auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT( root, GmockMatch(m::Add(m::Parameter(0), m::Broadcast(m::Parameter(1))))); } -TEST_F(XlaBuilderTest, XPlusX) { +TEST(XlaBuilderTest, XPlusX) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); Add(x, x); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Parameter(0)))); } -TEST_F(XlaBuilderTest, ShapeInferenceError) { +TEST(XlaBuilderTest, TestBinaryOpImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2, 2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2,2]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/{1}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, TestBinaryOpImplicitBroadcastBounded) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[<=2, <=2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, <=2]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/{1}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, ShapeInferenceError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); Add(x, y); - auto statusor = BuildHloModule(&b); + auto statusor = BuildHloModule(b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("Shapes must be equal rank")); } -TEST_F(XlaBuilderTest, DynamicDimensionReshapeToR0) { +TEST(XlaBuilderTest, DynamicDimensionReshapeToR0) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "dyn_dim"); auto dx = SetDimensionSize(x, y, 0); Reshape(dx, {}); - auto statusor = BuildHloModule(&b); + auto statusor = BuildHloModule(b); ASSERT_TRUE(statusor.ok()); } -TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { +TEST(XlaBuilderTest, ParameterAlreadyRegistered) { XlaBuilder b_call("add"); Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x"); @@ -266,36 +312,36 @@ TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x"); auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y"); Add(x, y); - auto statusor = BuildHloModule(&b); + auto statusor = BuildHloModule(b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("parameter 0 already registered")); } -TEST_F(XlaBuilderTest, Call) { +TEST(XlaBuilderTest, Call) { XlaBuilder b_call("the_only_to_apply"); auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); Add(p0, p1); - TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto call, b_call.Build()); XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); auto one = ConstantR0(&b, 1); auto two = ConstantR0(&b, 2); Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Call(m::Parameter(), m::Parameter()), m::Call(m::Constant(), m::Constant())))); } -TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { +TEST(XlaBuilderTest, BinopHasDegenerateBroadcast) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); Add(x, y); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); // Expected: // @@ -312,12 +358,12 @@ TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { m::Broadcast(m::Reshape(m::Parameter(1)))))); } -TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { +TEST(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); // The binary operation has in-dim broadcast and degenerate broadcast, should // first do the in-dim broadcast then convert the degenerate broadcast into a @@ -338,37 +384,37 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { m::Broadcast(m::Reshape(m::Parameter(1)))))); } -TEST_F(XlaBuilderTest, BroadcastInDim) { +TEST(XlaBuilderTest, BroadcastInDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); BroadcastInDim(x, {2, 4, 3}, /*broadcast_dimensions=*/{0, 2}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Broadcast())); } -TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { +TEST(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); BroadcastInDim(x, {2, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Broadcast(m::Reshape(m::Broadcast())))); } -TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { +TEST(XlaBuilderTest, BroadcastInDimWithNegativeSize) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); BroadcastInDim(x, {-3, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2}); - auto statusor = BuildHloModule(&b); + auto statusor = BuildHloModule(b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("invalid shape")); } -TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { +TEST(XlaBuilderTest, OperandFromWrongBuilder) { XlaBuilder b1("b1"); auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); XlaBuilder builder("main"); @@ -382,49 +428,49 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { "built by builder 'b1', but is trying to use it in builder 'main'")); } -TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { +TEST(XlaBuilderTest, ReshapeDefaultOrder) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); Reshape(x, /*new_sizes=*/{6, 35}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Reshape(m::Parameter()))); } -TEST_F(XlaBuilderTest, ReshapeHasTranspose) { +TEST(XlaBuilderTest, ReshapeHasTranspose) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Reshape(m::Transpose(m::Parameter())))); } -TEST_F(XlaBuilderTest, Transpose) { +TEST(XlaBuilderTest, Transpose) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); Transpose(x, /*permutation=*/{1, 0}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter()))); } -TEST_F(XlaBuilderTest, AllGatherR1) { +TEST(XlaBuilderTest, AllGatherR1) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x"); AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/4); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {16}))); } -TEST_F(XlaBuilderTest, AllGatherR2) { +TEST(XlaBuilderTest, AllGatherR2) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); AllGather(x, /*all_gather_dimension=*/1, /*shard_count=*/4); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); @@ -432,13 +478,13 @@ TEST_F(XlaBuilderTest, AllGatherR2) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64}))); } -TEST_F(XlaBuilderTest, AllGatherWithTuple) { +TEST(XlaBuilderTest, AllGatherWithTuple) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x"); auto x2 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "x2"); AllGather(Tuple(&b, {x, x2}), /*all_gather_dimension=*/0, /*shard_count=*/4); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); @@ -448,7 +494,22 @@ TEST_F(XlaBuilderTest, AllGatherWithTuple) { ShapeUtil::MakeShape(F32, {64, 4})}))); } -TEST_F(XlaBuilderTest, ReduceScatter) { +TEST(XlaBuilderTest, AllGatherTuple) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {128, 8}), "p1"); + AllGatherTuple({p0, p1}, /*all_gather_dimension=*/1, /*shard_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + auto root = module->entry_computation()->root_instruction(); + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {128, 16}), + ShapeUtil::MakeShape(F32, {128, 32})}); + EXPECT_THAT(root, GmockMatch(m::Op() + .WithOpcode(HloOpcode::kAllGather) + .WithShapeEqualTo(&tuple_shape))); +} + +TEST(XlaBuilderTest, ReduceScatter) { XlaBuilder b(TestName()); XlaComputation to_apply; { @@ -466,7 +527,7 @@ TEST_F(XlaBuilderTest, ReduceScatter) { group.add_replica_ids(1); ReduceScatter(x, to_apply, /*scatter_dimension=*/1, /*shard_count=*/2, /*replica_groups=*/{group}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter); @@ -474,7 +535,7 @@ TEST_F(XlaBuilderTest, ReduceScatter) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8}))); } -TEST_F(XlaBuilderTest, ReduceScatterWithTuple) { +TEST(XlaBuilderTest, ReduceScatterWithTuple) { XlaBuilder b(TestName()); XlaComputation to_apply; { @@ -494,7 +555,7 @@ TEST_F(XlaBuilderTest, ReduceScatterWithTuple) { ReduceScatter(Tuple(&b, {x, x2}), to_apply, /*scatter_dimension=*/1, /*shard_count=*/2, /*replica_groups=*/{group}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter); @@ -504,12 +565,12 @@ TEST_F(XlaBuilderTest, ReduceScatterWithTuple) { ShapeUtil::MakeShape(F32, {16, 2})}))); } -TEST_F(XlaBuilderTest, AllToAll) { +TEST(XlaBuilderTest, AllToAll) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/2); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. @@ -521,12 +582,12 @@ TEST_F(XlaBuilderTest, AllToAll) { } // Test the special case where split_dimension is the same as concat_dimension. -TEST_F(XlaBuilderTest, AllToAllSpecial) { +TEST(XlaBuilderTest, AllToAllSpecial) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x"); AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0, /*split_count=*/2); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); // AllToAll is converted into a single all-to-all HloInstruction. @@ -535,7 +596,7 @@ TEST_F(XlaBuilderTest, AllToAllSpecial) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8}))); } -TEST_F(XlaBuilderTest, AllToAllTuple) { +TEST(XlaBuilderTest, AllToAllTuple) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 4}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 4}), "p1"); @@ -544,7 +605,7 @@ TEST_F(XlaBuilderTest, AllToAllTuple) { replica_group.add_replica_ids(1); AllToAllTuple({p0, p1}, {replica_group}, LayoutUtil::MakeAscendingLayout(2)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); // Check shape and replica groups. @@ -566,7 +627,7 @@ TEST_F(XlaBuilderTest, AllToAllTuple) { .WithPredicate(is_replica_group_pred))); } -TEST_F(XlaBuilderTest, AllReduceTuple) { +TEST(XlaBuilderTest, AllReduceTuple) { XlaBuilder b(TestName()); auto shape0 = ShapeUtil::MakeShape(F32, {}); auto shape1 = ShapeUtil::MakeShape(F32, {1, 2}); @@ -576,10 +637,10 @@ TEST_F(XlaBuilderTest, AllReduceTuple) { XlaBuilder bsum(TestName()); auto f32Scalar = ShapeUtil::MakeShape(F32, {}); Add(Parameter(&bsum, 0, f32Scalar, "x"), Parameter(&bsum, 1, f32Scalar, "y")); - TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto sum, bsum.Build()); AllReduceTuple({p0, p1}, sum); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); // Check shape and replica groups. @@ -591,37 +652,49 @@ TEST_F(XlaBuilderTest, AllReduceTuple) { .WithShapeEqualTo(&tuple_shape))); } -TEST_F(XlaBuilderTest, CollectivePermute) { +TEST(XlaBuilderTest, CollectiveBroadcast) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + ReplicaGroup replica_group; + replica_group.add_replica_ids(0); + replica_group.add_replica_ids(1); + CollectiveBroadcast(x, {replica_group}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectiveBroadcast); +} + +TEST(XlaBuilderTest, CollectivePermute) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); } -TEST_F(XlaBuilderTest, GetDimensionSize) { +TEST(XlaBuilderTest, GetDimensionSize) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x"); GetDimensionSize(x, 1); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); } -TEST_F(XlaBuilderTest, GetDimensionSizeConstant) { +TEST(XlaBuilderTest, GetDimensionSizeConstant) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x"); // Get dimension size from a constant dimension gives us a constant. GetDimensionSize(x, 0); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConstant); } -TEST_F(XlaBuilderTest, ReportError) { +TEST(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); Add(b.ReportError(InvalidArgument("a test error")), x); @@ -630,56 +703,57 @@ TEST_F(XlaBuilderTest, ReportError) { EXPECT_THAT(statusor.status().message(), HasSubstr("a test error")); } -TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { +TEST(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { XlaBuilder b(TestName()); - StatusOr op(ConstantR0(&b, 1.0)); + absl::StatusOr op(ConstantR0(&b, 1.0)); Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Add(m::Constant(), m::Constant()))); } -TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { +TEST(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { XlaBuilder b(TestName()); - StatusOr op(InvalidArgument("a test error")); + absl::StatusOr op(InvalidArgument("a test error")); Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); auto statusor = b.Build(); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("a test error")); } -TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { +TEST(XlaBuilderTest, BuildWithSpecificRoot) { XlaBuilder b(TestName()); - XlaOp constant = ConstantR0(&b, 1.0); + const XlaOp constant = ConstantR0(&b, 1.0); Add(constant, ConstantR0(&b, 2.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, + BuildHloModule(b, /*root=*/constant)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Constant())); } -TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { +TEST(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { // Specifying a particular root in Build should still include all entry // parameters. XlaBuilder b(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); - XlaOp x = Parameter(&b, 0, shape, "x"); - XlaOp y = Parameter(&b, 1, shape, "y"); - XlaOp z = Parameter(&b, 2, shape, "z"); + const XlaOp x = Parameter(&b, 0, shape, "x"); + const XlaOp y = Parameter(&b, 1, shape, "y"); + const XlaOp z = Parameter(&b, 2, shape, "z"); Add(x, Sub(y, z)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b, /*root=*/x)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Parameter())); EXPECT_EQ(module->entry_computation()->num_parameters(), 3); EXPECT_EQ(module->entry_computation()->instruction_count(), 5); } -TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { +TEST(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { XlaBuilder b(TestName()); XlaBuilder other_b(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); Parameter(&b, 0, shape, "param"); - XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + const XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); Status status = b.Build(other_param).status(); ASSERT_IS_NOT_OK(status); @@ -688,7 +762,7 @@ TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { ::testing::HasSubstr("root operation is not in this computation")); } -TEST_F(XlaBuilderTest, ProtoMatches) { +TEST(XlaBuilderTest, ProtoMatches) { std::vector computations; const int n = 2; computations.reserve(n); @@ -697,7 +771,7 @@ TEST_F(XlaBuilderTest, ProtoMatches) { auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); Add(p0, Add(p1, p0)); - TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto call, b_call.Build()); XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); @@ -711,13 +785,13 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } -TEST_F(XlaBuilderTest, DynamicParameter) { +TEST(XlaBuilderTest, DynamicParameter) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6}, {true})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b, /*root=*/p0)); const Shape& param_shape = module->entry_computation() ->parameter_instruction(0) ->shape() @@ -725,33 +799,33 @@ TEST_F(XlaBuilderTest, DynamicParameter) { EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); } -TEST_F(XlaBuilderTest, SetDimensionSize) { +TEST(XlaBuilderTest, SetDimensionSize) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); auto set_dim_size = SetDimensionSize(p0, p1, 0); - TF_ASSERT_OK_AND_ASSIGN(auto module, - BuildHloModule(&b, /*root=*/set_dim_size)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, + BuildHloModule(b, /*root=*/set_dim_size)); const Shape& root_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(root_shape.is_dynamic_dimension(0)); } -TEST_F(XlaBuilderTest, RemoveDynamicDimension) { +TEST(XlaBuilderTest, RemoveDynamicDimension) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); auto set_dim_size = SetDimensionSize(p0, p1, 0); auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0); - TF_ASSERT_OK_AND_ASSIGN(auto module, - BuildHloModule(&b, /*root=*/remove_dim_size)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, + BuildHloModule(b, /*root=*/remove_dim_size)); const Shape& root_shape = module->entry_computation()->root_instruction()->shape(); // Dynamic dimension has been removed. EXPECT_FALSE(root_shape.is_dynamic_dimension(0)); } -TEST_F(XlaBuilderTest, RemoveDynamicDimensionMultiDims) { +TEST(XlaBuilderTest, RemoveDynamicDimensionMultiDims) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 10}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); @@ -759,8 +833,8 @@ TEST_F(XlaBuilderTest, RemoveDynamicDimensionMultiDims) { set_dim_size = SetDimensionSize(set_dim_size, p1, 1); auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0); remove_dim_size = RemoveDynamicDimension(remove_dim_size, 1); - TF_ASSERT_OK_AND_ASSIGN(auto module, - BuildHloModule(&b, /*root=*/remove_dim_size)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, + BuildHloModule(b, /*root=*/remove_dim_size)); const Shape& root_shape = module->entry_computation()->root_instruction()->shape(); // Dynamic dimensions are removed. @@ -768,60 +842,60 @@ TEST_F(XlaBuilderTest, RemoveDynamicDimensionMultiDims) { EXPECT_FALSE(root_shape.is_dynamic_dimension(1)); } -TEST_F(XlaBuilderTest, DynamicUnary) { +TEST(XlaBuilderTest, DynamicUnary) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte = GetTupleElement(p0, 0); Neg(gte); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); } -TEST_F(XlaBuilderTest, DynamicBinary) { +TEST(XlaBuilderTest, DynamicBinary) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte0 = GetTupleElement(p0, 0); auto gte1 = GetTupleElement(p0, 1); Add(gte0, gte1); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); } -TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { +TEST(XlaBuilderTest, DynamicBinaryHasBroadcast) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte0 = GetTupleElement(p0, 0); auto gte1 = GetTupleElement(p0, 1); Add(gte0, gte1, {0}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) << result_shape; } -TEST_F(XlaBuilderTest, DynamicBroadcast) { +TEST(XlaBuilderTest, DynamicBroadcast) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte = GetTupleElement(p0, 0); BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, /*broadcast_dimensions=*/{1, 2}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE( @@ -829,25 +903,25 @@ TEST_F(XlaBuilderTest, DynamicBroadcast) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { +TEST(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {10}, {true}), ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte0 = GetTupleElement(p0, 0); auto gte1 = GetTupleElement(p0, 1); Add(gte0, gte1, /*broadcast_dimensions=*/{0}); // f32[<=10, 15] - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) << result_shape; } -TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { +TEST(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {10}, {true}), ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); @@ -856,26 +930,26 @@ TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { Select(gte0, gte1, gte1); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true})) << result_shape; } -TEST_F(XlaBuilderTest, SelectIntoConditional) { +TEST(XlaBuilderTest, SelectIntoConditional) { XlaBuilder b(TestName()); - Shape selector_shape = ShapeUtil::MakeShape(PRED, {}); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape selector_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}); - XlaOp p0 = Parameter(&b, 0, selector_shape, "p0"); - XlaOp p1 = Parameter(&b, 1, tuple_param_shape, "p1"); - XlaOp p2 = Parameter(&b, 2, tuple_param_shape, "p2"); + const XlaOp p0 = Parameter(&b, 0, selector_shape, "p0"); + const XlaOp p1 = Parameter(&b, 1, tuple_param_shape, "p1"); + const XlaOp p2 = Parameter(&b, 2, tuple_param_shape, "p2"); Select(p0, p1, p2); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Conditional(m::Parameter(0), m::Parameter(1), m::Parameter(2)))); @@ -891,9 +965,9 @@ TEST_F(XlaBuilderTest, SelectIntoConditional) { GmockMatch(m::Parameter(0))); } -TEST_F(XlaBuilderTest, DynamicPad) { +TEST(XlaBuilderTest, DynamicPad) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); @@ -907,16 +981,16 @@ TEST_F(XlaBuilderTest, DynamicPad) { dimension->set_interior_padding(0); } Pad(gte, pad_val, padding_config); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) << result_shape; } -TEST_F(XlaBuilderTest, DynamicConvolution) { +TEST(XlaBuilderTest, DynamicConvolution) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}), ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); @@ -938,7 +1012,7 @@ TEST_F(XlaBuilderTest, DynamicConvolution) { dnums.set_kernel_output_feature_dimension(3); ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, /*feature_group_count=*/1); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), @@ -946,9 +1020,9 @@ TEST_F(XlaBuilderTest, DynamicConvolution) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicDot) { +TEST(XlaBuilderTest, DynamicDot) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}), ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); @@ -962,7 +1036,7 @@ TEST_F(XlaBuilderTest, DynamicDot) { dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE( @@ -970,9 +1044,9 @@ TEST_F(XlaBuilderTest, DynamicDot) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicReduce) { +TEST(XlaBuilderTest, DynamicReduce) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); @@ -981,18 +1055,18 @@ TEST_F(XlaBuilderTest, DynamicReduce) { XlaBuilder bsum(TestName()); Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); - TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto sum, bsum.Build()); Reduce(gte, init, sum, {0}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) << result_shape; } -TEST_F(XlaBuilderTest, DynamicReduceWindow) { +TEST(XlaBuilderTest, DynamicReduceWindow) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); @@ -1001,10 +1075,10 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) { XlaBuilder bsum(TestName()); Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); - TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto sum, bsum.Build()); ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, /*window_strides=*/{1, 1, 1}, Padding::kValid); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); VLOG(2) << module->entry_computation()->root_instruction()->ToString() << "\n"; const Shape& result_shape = @@ -1014,9 +1088,9 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) { << result_shape; } -TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) { +TEST(XlaBuilderTest, VariadicDynamicReduceWindow) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); @@ -1031,12 +1105,12 @@ TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) { auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1"); std::vector output_operands = {Add(p2, p4), Add(p3, p5)}; Tuple(&bsum, absl::MakeSpan(output_operands)); - TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto sum, bsum.Build()); auto init = ConstantR0(&b, 0.f); ReduceWindow(input_operands, {init, init}, sum, /*window_dimensions=*/{1, 2, 4}, /*window_strides=*/{1, 1, 1}, Padding::kValid); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); VLOG(2) << module->entry_computation()->root_instruction()->ToString() << "\n"; const Shape& result_shape = @@ -1049,9 +1123,9 @@ TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) { << result_shape.tuple_shapes(1); } -TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { +TEST(XlaBuilderTest, DynamicSelectAndScatter) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}), ShapeUtil::MakeShape(U32, {})}); @@ -1060,17 +1134,17 @@ TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { XlaBuilder bsum(TestName()); Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); - TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto sum, bsum.Build()); XlaBuilder bge(TestName()); Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"), Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y")); - TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build()); + TF_ASSERT_OK_AND_ASSIGN(const auto ge, bge.Build()); auto gte0 = GetTupleElement(p0, 0); auto source = GetTupleElement(p0, 1); SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source, init, sum); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE( @@ -1078,16 +1152,16 @@ TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicReshape) { +TEST(XlaBuilderTest, DynamicReshape) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}, {false, false, true, true, false}), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte = GetTupleElement(p0, 0); // f32[2, 3, <=4, <=5, 6] Reshape(gte, /*new_sizes=*/{6, 4, 5, 2, 3}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); @@ -1097,9 +1171,9 @@ TEST_F(XlaBuilderTest, DynamicReshape) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicSelect) { +TEST(XlaBuilderTest, DynamicSelect) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); @@ -1108,7 +1182,7 @@ TEST_F(XlaBuilderTest, DynamicSelect) { auto gte0 = GetTupleElement(p0, 0); auto gte1 = GetTupleElement(p0, 1); Select(pred, gte0, gte1); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); @@ -1118,9 +1192,9 @@ TEST_F(XlaBuilderTest, DynamicSelect) { << result_shape; } -TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { +TEST(XlaBuilderTest, DynamicSelectNotCompatible) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}), ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}), ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); @@ -1129,29 +1203,29 @@ TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { auto gte0 = GetTupleElement(p0, 0); // f32[4,<=5,6] auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] Select(pred, gte0, gte1); - Status status = BuildHloModule(&b).status(); + Status status = BuildHloModule(b).status(); ASSERT_IS_OK(status); } -TEST_F(XlaBuilderTest, DynamicTranspose) { +TEST(XlaBuilderTest, DynamicTranspose) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 5}, {true, false}), ShapeUtil::MakeShape(U32, {})}); auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); auto gte = GetTupleElement(p0, 0); Transpose(gte, /*permutation=*/{1, 0}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true})) << result_shape; } -TEST_F(XlaBuilderTest, DotWithPreferredElementType) { +TEST(XlaBuilderTest, DotWithPreferredElementType) { XlaBuilder b(TestName()); - Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3}); - Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2}); + const Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3}); + const Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2}); auto p0 = Parameter(&b, 0, p0_shape, "p0"); auto p1 = Parameter(&b, 1, p1_shape, "p1"); @@ -1160,17 +1234,42 @@ TEST_F(XlaBuilderTest, DotWithPreferredElementType) { dnums.add_rhs_contracting_dimensions(0); DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr, /*preferred_element_type=*/U32); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape)); } -TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) { +TEST(XlaBuilderTest, SparseDot) { XlaBuilder b(TestName()); - Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128}); - Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8}); + auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 16}), "lhs"); + auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {32, 20}), "rhs"); + auto meta = Parameter(&b, 2, ShapeUtil::MakeShape(U16, {10, 2}), "meta"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(0); + sparsity_descriptor.set_dimension(1); + std::vector sparsity = {sparsity_descriptor}; + std::vector sparse_meta = {meta}; + + SparseDot(lhs, rhs, sparse_meta, sparsity, dnums); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[10, 20]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, ConvolutionWithPreferredElementType) { + XlaBuilder b(TestName()); + const Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128}); + const Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8}); auto p0 = Parameter(&b, 0, p0_shape, "p0"); auto p1 = Parameter(&b, 1, p1_shape, "p1"); @@ -1191,14 +1290,14 @@ TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) { /*feature_group_count=*/1, /*batch_group_count=*/1, /*precision_config=*/nullptr, /*preferred_element_type=*/S32); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape)); } -TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { +TEST(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); Status status = b.Build().status(); @@ -1207,7 +1306,7 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } -TEST_F(XlaBuilderTest, CheckInputOutputAlias) { +TEST(XlaBuilderTest, CheckInputOutputAlias) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); @@ -1218,7 +1317,7 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) { b.SetUpAlias({1}, 0, {}); b.SetUpAlias({0}, 1, {}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b, root)); const HloInputOutputAliasConfig& config = module->input_output_alias_config(); EXPECT_TRUE(config.ParameterHasAlias(0, {})); @@ -1233,7 +1332,7 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) { EXPECT_EQ(*alias_p1, ShapeIndex({0})); } -TEST_F(XlaBuilderTest, CheckBufferDonor) { +TEST(XlaBuilderTest, CheckBufferDonor) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); @@ -1243,14 +1342,14 @@ TEST_F(XlaBuilderTest, CheckBufferDonor) { b.AddBufferDonor(0, {}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b, root)); const HloBufferDonorConfig& config = module->buffer_donor_config(); EXPECT_TRUE(config.ParameterIsBufferDonor(0, {})); EXPECT_FALSE(config.ParameterIsBufferDonor(1, {})); } -TEST_F(XlaBuilderTest, InvalidInputOutputAliasBufferDonor) { +TEST(XlaBuilderTest, InvalidInputOutputAliasBufferDonor) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); @@ -1262,14 +1361,14 @@ TEST_F(XlaBuilderTest, InvalidInputOutputAliasBufferDonor) { b.SetUpAlias({1}, 0, {}); b.AddBufferDonor(0, {}); - auto statusor = BuildHloModule(&b, root); + auto statusor = BuildHloModule(b, root); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("is already aliased with one output, thus it cannot be " "added as a buffer donor for any output.")); } -TEST_F(XlaBuilderTest, ValidInputOutputAliasBufferDonor) { +TEST(XlaBuilderTest, ValidInputOutputAliasBufferDonor) { XlaBuilder b(TestName()); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); @@ -1280,7 +1379,7 @@ TEST_F(XlaBuilderTest, ValidInputOutputAliasBufferDonor) { b.SetUpAlias({1}, 0, {}); b.AddBufferDonor(1, {}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b, root)); const HloInputOutputAliasConfig& io_alias_config = module->input_output_alias_config(); @@ -1319,7 +1418,7 @@ void ExpectInstructionsAttributesMatch( EXPECT_EQ(expected_it, expected.end()); } -TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) { +TEST(XlaBuilderTest, SimpleSetFrontendAttributes) { XlaBuilder b(TestName()); FrontendAttributes attributes; @@ -1332,14 +1431,14 @@ TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) { b.ClearFrontendAttributes(); ConstantR0(&b, 0); // No attribute set - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); std::vector expected{FrontendAttributes(), attributes, FrontendAttributes()}; ExpectInstructionsAttributesMatch(*module, expected); } -TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { +TEST(XlaBuilderTest, ComplexSetFrontendAttributes) { XlaBuilder b(TestName()); ConstantR0(&b, 0); // No attribute set. @@ -1374,11 +1473,11 @@ TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { ConstantR0(&b, 0); // No attribute set expected.push_back(FrontendAttributes()); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); ExpectInstructionsAttributesMatch(*module, expected); } -TEST_F(XlaBuilderTest, AddFrontendAttribute) { +TEST(XlaBuilderTest, AddFrontendAttribute) { XlaBuilder b(TestName()); ConstantR0(&b, 0); @@ -1440,24 +1539,24 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) { ConstantR0(&b, 0); // No attribute set expected.push_back(FrontendAttributes()); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); ExpectInstructionsAttributesMatch(*module, expected); } -TEST_F(XlaBuilderTest, ComparisonType) { +TEST(XlaBuilderTest, ComparisonType) { XlaBuilder b(TestName()); (void)Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto root = module->entry_computation()->root_instruction(); ASSERT_THAT(root, GmockMatch(m::Compare(m::Constant(), m::Constant()))); EXPECT_EQ(Comparison::Type::kSigned, DynCast(root)->type()); } -TEST_F(XlaBuilderTest, StableLookUpInstructionByHandle) { +TEST(XlaBuilderTest, StableLookUpInstructionByHandle) { XlaBuilder b(TestName()); internal::XlaBuilderFriend builder_friend; - XlaOp le = Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); + const XlaOp le = Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); HloInstructionProto* first_op = builder_friend.GetInstruction(le); // Create some more instructions. for (int i = 0; i < 100; ++i) { @@ -1468,36 +1567,38 @@ TEST_F(XlaBuilderTest, StableLookUpInstructionByHandle) { EXPECT_EQ(first_op, first_op_now); } -TEST_F(XlaBuilderTest, ComplexAbsConstant) { +TEST(XlaBuilderTest, ComplexAbsConstant) { XlaBuilder b(TestName()); - XlaOp out = + const XlaOp out = Abs(ConstantR0>(&b, std::complex{-1, -1})); ValueInference value_inference(&b); - StatusOr analyzed = + absl::StatusOr analyzed = value_inference.AnalyzeConstant(out, kUpperBound); EXPECT_IS_OK(analyzed.status()); EXPECT_EQ(analyzed->GetValue().value().shape().element_type(), PrimitiveType::F32); } -TEST_F(XlaBuilderTest, OutfeedDummyTupleSharding) { +TEST(XlaBuilderTest, OutfeedDummyTupleSharding) { XlaBuilder b(TestName()); - XlaOp value = ConstantR1(&b, {0}); - Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, - /* minor_to_major= */ {0}); + const XlaOp value = ConstantR1(&b, {0}); + const Shape shape = + ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, + /* minor_to_major= */ {0}); Outfeed(value, shape, ""); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); EXPECT_FALSE(module->entry_computation()->root_instruction()->has_sharding()); } -TEST_F(XlaBuilderTest, OutfeedTokenSharding) { +TEST(XlaBuilderTest, OutfeedTokenSharding) { XlaBuilder b(TestName()); - XlaOp value = ConstantR1(&b, {0}); - Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, - /* minor_to_major= */ {0}); + const XlaOp value = ConstantR1(&b, {0}); + const Shape shape = + ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, + /* minor_to_major= */ {0}); b.SetSharding(sharding_builder::Replicate()); Outfeed(value, shape, ""); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); auto it = std::find_if(module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end(), HloPredicateIsOp); @@ -1513,23 +1614,23 @@ TEST_F(XlaBuilderTest, OutfeedTokenSharding) { HloSharding::FromProto(sharding_builder::AssignDevice(0)).value()); } -TEST_F(XlaBuilderTest, NormalizeTupleSharding) { +TEST(XlaBuilderTest, NormalizeTupleSharding) { XlaBuilder b(TestName()); - Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); b.SetSharding(sharding_builder::Replicate()); Parameter(&b, 0, tuple_param_shape, "p0"); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_TRUE(root->has_sharding()); EXPECT_TRUE(root->sharding().IsTuple()); EXPECT_EQ(root->sharding().tuple_elements().size(), 2); } -TEST_F(XlaBuilderTest, InvalidSharding) { +TEST(XlaBuilderTest, InvalidSharding) { XlaBuilder b(TestName()); - Shape shape2d = ShapeUtil::MakeShape(F32, {6, 8}); - Shape shape1d = ShapeUtil::MakeShape(F32, {5}); + const Shape shape2d = ShapeUtil::MakeShape(F32, {6, 8}); + const Shape shape1d = ShapeUtil::MakeShape(F32, {5}); b.SetSharding(sharding_builder::Tile1D(shape1d, 4)); Parameter(&b, 0, shape2d, "p0"); auto statusor = b.Build(); @@ -1539,13 +1640,13 @@ TEST_F(XlaBuilderTest, InvalidSharding) { "subgroups) is different than the input rank")); } -TEST_F(XlaBuilderTest, TopKDimensions) { +TEST(XlaBuilderTest, TopKDimensions) { XlaBuilder b(TestName()); int64_t k = 1; int64_t largest = true; TopK(Parameter(&b, 0, ShapeUtil::MakeShape(F32, {6, 8}), "p0"), k, largest); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); const HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_TRUE(root->opcode() == HloOpcode::kTopK); EXPECT_TRUE(root->shape().IsTuple()); @@ -1558,227 +1659,1180 @@ TEST_F(XlaBuilderTest, TopKDimensions) { EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(1), k); } -TEST_F(XlaBuilderTest, UnboundedAbs) { +//============================================================================// +// Experimental Test +//============================================================================// + +TEST(XlaBuilderTest, DynamicBroadcastInDimExportSuccess) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[1, 2, 3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, 2, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_broadcast_in_dim")); + EXPECT_THAT(module->ToString(), HasSubstr("broadcast_dimensions=[1,2]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimNonBroadcastDimSizeGreaterThanOne) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 2, 3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_broadcast_in_dim")); + EXPECT_THAT(module->ToString(), HasSubstr("broadcast_dimensions=[1,2]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimDynamicResultSize) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[1, 2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, 2, ?]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(module->ToString(), HasSubstr("mhlo.dynamic_broadcast_in_dim")); + EXPECT_THAT(module->ToString(), HasSubstr("broadcast_dimensions=[1,2]")); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimInvalidOutputDimensionsElementType) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("f32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, + HasSubstr("output_dimensions must be an integer type f32[3]"))); +} + +TEST(XlaBuilderTest, DynamicBroadcastInDimInvalidOutputDimensionsRank) { XlaBuilder b(TestName()); - StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - ASSERT_IS_OK(operand.status()); - ASSERT_IS_OK(expected.status()); - Abs(Parameter(&b, 0, operand.value(), "operand")); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedAdd) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Add(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Add(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); -} - -TEST_F(XlaBuilderTest, UnboundedDiv) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Div(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedDivUnsupportedImplicitBroadcast) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Div(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); -} - -TEST_F(XlaBuilderTest, UnboundedExp) { - XlaBuilder b(TestName()); - StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - ASSERT_IS_OK(operand.status()); - ASSERT_IS_OK(expected.status()); - Exp(Parameter(&b, 0, operand.value(), "operand")); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedMax) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Max(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedMaxUnsupportedImplicitBroadcast) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Max(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); -} - -TEST_F(XlaBuilderTest, UnboundedMul) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Mul(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedMulUnsupportedImplicitBroadcast) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Mul(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); -} - -TEST_F(XlaBuilderTest, UnboundedPow) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Pow(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); -} - -TEST_F(XlaBuilderTest, UnboundedPowUnsupportedImplicitBroadcast) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Pow(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); -} - -TEST_F(XlaBuilderTest, UnboundedSub) { - XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); - StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); - StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - ASSERT_IS_OK(expected.status()); - Sub(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - const Shape& result = - module->entry_computation()->root_instruction()->shape(); - EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) - << "result: " << ShapeUtil::HumanString(result) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, + ParseShape("s32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, + HasSubstr("output_dimensions must be rank 1 but got rank 2"))); } -TEST_F(XlaBuilderTest, UnboundedSubUnsupportedImplicitBroadcast) { +TEST(XlaBuilderTest, DynamicBroadcastInDimIncompatibleBroadcastSize) { XlaBuilder b(TestName()); - StatusOr lhs = ParseShape("f32[?, 10]"); - StatusOr rhs = ParseShape("f32[1]"); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - Sub(Parameter(&b, 0, lhs.value(), "lhs"), - Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); - StatusOr> build_status = BuildHloModule(&b); - EXPECT_FALSE(build_status.ok()); - EXPECT_THAT(build_status.status().message(), - HasSubstr("Unbounded dynamic shapes not supported")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_dimensions, ParseShape("s32[3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ParseShape("f32[2, 3, 3]")); + DynamicBroadcastInDim( + Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, output_dimensions, "output_dimensions"), + /*broadcast_dimensions=*/{1, 2}, output_shape); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, HasSubstr("size of operand dimension 0 (2) is not compatible " + "with size of result dimension 1 (3)"))); } +//============================================================================// +// Unbounded Dynamism Test +//============================================================================// + +struct UnaryOpTestCase { + std::string operand; + std::string expected; + std::function unary_op; +}; + +struct BinaryOpTestCase { + std::string lhs; + std::string rhs; + absl::Span broadcast_dimensions; + std::string expected; + std::function)> binary_op; + std::optional error_message; +}; + +constexpr absl::string_view kBroadcastDimensionMismatch = + "Broadcast dimension 0 mismatch: 2 != -9223372036854775808; f32[2] and " + "f32[?,10]."; +std::array empty_array = {}; +std::array zero_array = {0}; + +class XlaBuilderUnboundedUnaryOpTest + : public ::testing::TestWithParam {}; + +class XlaBuilderUnboundedBinaryOpTest + : public ::testing::TestWithParam {}; + +TEST_P(XlaBuilderUnboundedUnaryOpTest, UnboundedUnaryOpTest) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape(GetParam().operand)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + GetParam().unary_op(Parameter(&b, 0, operand, "operand")); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST_P(XlaBuilderUnboundedBinaryOpTest, UnboundedBinaryOpTest) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + GetParam().binary_op(Parameter(&b, 0, lhs, "lhs"), + Parameter(&b, 1, rhs, "rhs"), + GetParam().broadcast_dimensions); + if (const auto result = BuildHloModule(b); result.ok()) { + ASSERT_NE(*result, nullptr); + EXPECT_THAT(GetRoot(**result), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(result, StatusIs(_, HasSubstr(*GetParam().error_message))); + } +} + +TEST(XlaBuilderTest, UnboundedAddScalarBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/empty_array); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedAddDegenerateBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[1, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Add(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/zero_array); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr(kBroadcastDimensionMismatch))); +} + +TEST(XlaBuilderTest, UnboundedAnd) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("s32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("s32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); + And(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/empty_array); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBatchNormGrad) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape mean, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape variance, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_scale, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_offset, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_output, ParseShape("f32[5, ?, 7]")); + const Shape expected = + ShapeUtil::MakeTupleShape({grad_operand, grad_scale, grad_offset}); + BatchNormGrad( + Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, scale, "scale"), + Parameter(&b, 2, mean, "mean"), Parameter(&b, 3, variance, "variance"), + Parameter(&b, 4, grad_output, "grad_output"), 1.0, 1); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBatchNormInference) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape offset, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape mean, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape variance, ParseShape("f32[5]")); + BatchNormInference( + Parameter(&b, 0, operand, "operand"), Parameter(&b, 1, scale, "scale"), + Parameter(&b, 2, offset, "offset"), Parameter(&b, 3, mean, "mean"), + Parameter(&b, 4, variance, "variance"), 1.0, 1); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBatchNormTraining) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape offset, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape batch_mean, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape batch_var, ParseShape("f32[?]")); + const Shape expected = + ShapeUtil::MakeTupleShape({output, batch_mean, batch_var}); + BatchNormTraining(Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, scale, "scale"), + Parameter(&b, 2, offset, "offset"), 1.0, 1); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBitcastConvert) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f16[?, 10, 2]")); + BitcastConvertType(Parameter(&b, 0, operand, "operand"), PrimitiveType::F16); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBroadcastUnsupportedOperand) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=3, ?]")); + Broadcast(Parameter(&b, 0, operand, "operand"), /*broadcast_sizes=*/{1}); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("is_unbounded_dynamic"))); +} + +TEST(XlaBuilderTest, UnboundedBroadcastUnsupportedBroadcastSize) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1]")); + Broadcast(Parameter(&b, 0, operand, "operand"), + /*broadcast_sizes=*/{Shape::kUnboundedSize}); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, HasSubstr("Non-broadcast dimensions must not be dynamic."))); +} + +TEST(XlaBuilderTest, UnboundedBroadcastInDim) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, 4]")); + BroadcastInDim(Parameter(&b, 0, operand, "operand"), + /*out_dim_size=*/{2, 3, 4}, + /*broadcast_dimensions=*/{0, 2}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedBroadcastInDimUnsupported) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=3, ?]")); + BroadcastInDim(Parameter(&b, 0, operand, "operand"), + /*out_dim_size=*/{2, 3, Shape::kUnboundedSize}, + /*broadcast_dimensions=*/{0, 2}); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("BroadcastInDim output must shape be " + "static or bounded dynamic"))); +} + +TEST(XlaBuilderTest, UnboundedCholesky) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Cholesky(Parameter(&b, 0, a, "a"), /*lower=*/true); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedClamp) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, + ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedClampScalarMinImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedClampScalarMinMaxImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedClampScalarOperandMaxImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedClampScalarMinOperandImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, + UnboundedClampUnsupportedDegenerateOperandImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + Clamp(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); +} + +TEST(XlaBuilderTest, UnboundedCompare) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("pred[?, ?, 2, 2, <=2, <=2, ?]")); + Compare(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*direction=*/{}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedConcatenate) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, + ParseShape("f32[3, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, + ParseShape("f32[?, 4, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand3, + ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("f32[3, 4, ?, 2, <=2, <=2, ?]")); + ConcatInDim(&b, + {Parameter(&b, 0, operand1, "operand1"), + Parameter(&b, 1, operand2, "operand2"), + Parameter(&b, 2, operand3, "operand3")}, + 2); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedConvert) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("s32[?]")); + ConvertElementType(Parameter(&b, 0, operand, "operand"), PrimitiveType::S32); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedConvolution) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 2, ?, 128]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2, 2, <=128, 8]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 1, ?, 8]")); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(Parameter(&b, 0, lhs, "lhs"), + Parameter(&b, 1, rhs, "rhs"), + /*window_strides=*/{1, 1}, Padding::kValid, dnums); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedDot) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Dot(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedDotGeneral) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, <=3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2, 4, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, <=3, 5]")); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), dnums); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedDynamicSlice) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]")); + DynamicSlice(Parameter(&b, 0, operand, "operand"), + /*start_indices=*/ + { + Parameter(&b, 1, start_indices, "start_indices0"), + Parameter(&b, 2, start_indices, "start_indices1"), + }, + /*slice_sizes=*/{2, 2}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedGather) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, + ParseShape("s32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 2, 2]")); + + GatherDimensionNumbers dimension_numbers; + dimension_numbers.add_offset_dims(2); + dimension_numbers.add_offset_dims(3); + dimension_numbers.add_collapsed_slice_dims(0); + dimension_numbers.add_start_index_map(1); + dimension_numbers.add_start_index_map(0); + dimension_numbers.set_index_vector_dim(2); + + Gather(Parameter(&b, 0, operand, "operand"), + Parameter(&b, 1, start_indices, "start_indices"), dimension_numbers, + /*slice_sizes=*/{1, 2, 2}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedGetTupleElement) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + GetTupleElement(Tuple(&b, {Parameter(&b, 0, operand, "operand")}), 0); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedMap) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand0, ParseShape("f32[2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, ?, ?]")); + + XlaComputation computation; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(computation, sub_builder->Build()); + } + + Map(&b, /*operands=*/ + {Parameter(&b, 0, operand0, "operand0"), + Parameter(&b, 1, operand1, "operand1")}, + computation, /*dimensions=*/{0, 1, 2}, + /*static_operands=*/{}); + + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedOptimizationBarrier) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + OptimizationBarrier(Parameter(&b, 0, operand, "operand")); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedOr) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("s32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("s32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); + Or(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/empty_array); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedPad) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 21]")); + PaddingConfig padding_config; + for (int i = 0; i < 2; i++) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(1); + dimension->set_interior_padding(1); + } + Pad(Parameter(&b, 0, operand, "operand"), + /*padding_value=*/ConstantR0(&b, 0), padding_config); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedReduce) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {7}, {false}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + TF_ASSERT_OK_AND_ASSIGN(const Shape input0, ParseShape("f32[7, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input1, ParseShape("f32[?, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input2, ParseShape("f32[7, ?]")); + const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {}); + const XlaOp init = Parameter(&b, 3, scalar_f32, "init"); + + XlaBuilder bsum(TestName()); + std::vector output_operands = { + Add(Parameter(&bsum, 0, scalar_f32, "arg0"), + Parameter(&bsum, 1, scalar_f32, "arg1")), + Add(Parameter(&bsum, 2, scalar_f32, "arg2"), + Parameter(&bsum, 3, scalar_f32, "arg3")), + Add(Parameter(&bsum, 4, scalar_f32, "arg4"), + Parameter(&bsum, 5, scalar_f32, "arg5"))}; + Tuple(&bsum, absl::MakeSpan(output_operands)); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation sum, bsum.Build()); + Reduce( + &b, + {Parameter(&b, 0, input0, "input0"), Parameter(&b, 1, input1, "input1"), + Parameter(&b, 2, input2, "input2")}, + {init, init, init}, sum, {1}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedReducePrecision) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + ReducePrecision(Parameter(&b, 0, operand, "operand"), /*exponent_bits=*/2, + /*mantissa_bits=*/2); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedReduceWindow) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, 4, 8]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 3, 5]")); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation sum, bsum.Build()); + + ReduceWindow(Parameter(&b, 0, input, "input"), ConstantR0(&b, 0.f), + sum, + /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedReshape) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2,3]")); + Reshape(Parameter(&b, 0, operand, "operand"), /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedReshapeUnsupportedOutputShape) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[6]")); + Reshape(Parameter(&b, 0, operand, "operand"), /*dimensions=*/{0}, + /*new_sizes=*/{Shape::kUnboundedSize, Shape::kUnboundedSize}); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, + HasSubstr( + "Reshaping with unbounded result shape is not supported."))); +} + +TEST(XlaBuilderTest, UnboundedReshapeUnsupportedInferredShape) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + Reshape(operand, Parameter(&b, 0, operand, "operand")); + EXPECT_THAT( + BuildHloModule(b), + StatusIs(_, + HasSubstr( + "Reshaping with unbounded result shape is not supported."))); +} + +TEST(XlaBuilderTest, UnboundedReverse) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Rev(Parameter(&b, 0, operand, "operand"), /*dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngBitGenerator) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape initial_state, ParseShape("u32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("u32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("(u32[?, 10], u32[?, 10])")); + RngBitGenerator(RandomAlgorithm::RNG_DEFAULT, + Parameter(&b, 0, initial_state, "initial_state"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngNormal) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + RngNormal(Parameter(&b, 0, ShapeUtil::MakeScalarShape(F32), "mu"), + Parameter(&b, 1, ShapeUtil::MakeScalarShape(F32), "sigma"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedRngUniform) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + RngUniform(Parameter(&b, 0, ShapeUtil::MakeScalarShape(F32), "a"), + Parameter(&b, 1, ShapeUtil::MakeScalarShape(F32), "b"), shape); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedScatter) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_indices, + ParseShape("s32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape updates, ParseShape("f32[?, ?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, ?]")); + + XlaComputation update_computation; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(update_computation, sub_builder->Build()); + } + + ScatterDimensionNumbers dimension_numbers; + dimension_numbers.add_update_window_dims(2); + dimension_numbers.add_update_window_dims(3); + dimension_numbers.add_inserted_window_dims(0); + dimension_numbers.add_scatter_dims_to_operand_dims(1); + dimension_numbers.add_scatter_dims_to_operand_dims(0); + dimension_numbers.set_index_vector_dim(2); + + Scatter(Parameter(&b, 0, input, "input"), + Parameter(&b, 1, scatter_indices, "scatter_indices"), + Parameter(&b, 2, updates, "updates"), update_computation, + dimension_numbers, /*indices_are_sorted=*/false, + /*unique_indices=*/false); + + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSelect) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("pred[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, + ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("f32[1, 1, 2, 2, <=2, <=2, ?]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSelectScalarPred) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("pred[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSelectScalarOnTrueOnFalseImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("pred[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSelectScalarPredOnFalseImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("pred[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSelectScalarPredOnTrueImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("pred[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, + UnboundedSelectUnsupportedDegenerateOperandImplicitBroadcast) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("pred[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[1]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("f32[?, 10]")); + Select(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + Parameter(&b, 2, ehs, "ehs")); + EXPECT_THAT(BuildHloModule(b), + StatusIs(_, HasSubstr("Unimplemented implicit broadcast."))); +} + +TEST(XlaBuilderTest, UnboundedSelectAndScatter) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + XlaComputation select; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1"), + ComparisonDirection::kGe); + TF_ASSERT_OK_AND_ASSIGN(select, sub_builder->Build()); + } + + XlaComputation scatter; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(scatter, sub_builder->Build()); + } + + SelectAndScatter(Parameter(&b, 0, operand, "operand"), select, + /*window_dimensions=*/ + std::array({3, 1}), + /*window_strides=*/std::array({2, 1}), + Padding::kValid, Parameter(&b, 1, source, "source"), + Parameter(&b, 2, init_value, "init_value"), scatter); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSlice) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, 3]")); + Slice(Parameter(&b, 0, operand, "operand"), + /*start_indices=*/{0, 1, 2}, + /*limit_indices=*/{1, 3, 5}, + /*strides=*/{1, 1, 1}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedSort) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + XlaComputation comparator; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Compare(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1"), + ComparisonDirection::kLt); + TF_ASSERT_OK_AND_ASSIGN(comparator, sub_builder->Build()); + } + + Sort({Parameter(&b, 0, operand, "operand")}, comparator, + /*dimension=*/0, /*is_stable=*/true); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedTranspose) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, + ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}")); + Transpose(Parameter(&b, 0, operand, "operand"), + /*permutation=*/{4, 0, 3, 2, 1}); + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedTriangularSolve) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape a_shape, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape b_shape, ParseShape("f32[10, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[10, ?]")); + TriangularSolveOptions options; + TriangularSolve(Parameter(&b, 0, a_shape, "a"), + Parameter(&b, 1, b_shape, "b"), + /*left_side=*/true, /*lower*/ true, /*unit_diagonal=*/false, + TriangularSolveOptions::TRANSPOSE); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedTuple) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + const Shape expected = ShapeUtil::MakeTupleShape({operand}); + Tuple(&b, {Parameter(&b, 0, operand, "operand")}); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedWhile) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); + + XlaComputation add; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(Parameter(sub_builder.get(), 0, ShapeUtil::MakeScalarShape(F32), + "arg0"), + Parameter(sub_builder.get(), 1, ShapeUtil::MakeScalarShape(F32), + "arg1")); + TF_ASSERT_OK_AND_ASSIGN(add, sub_builder->Build()); + } + + XlaComputation condition; + { + const std::unique_ptr sub_builder = + b.CreateSubBuilder("compare"); + Ge(/*lhs=*/ConstantR0(sub_builder.get(), 10.0f), + /*rhs=*/Reduce(/*operand=*/Parameter(sub_builder.get(), 0, init, "prev"), + ConstantR0(sub_builder.get(), 0.0f), add, + /*dimensions_to_reduce=*/{0})); + TF_ASSERT_OK_AND_ASSIGN(condition, sub_builder->Build()); + } + + XlaComputation body; + { + const std::unique_ptr sub_builder = b.CreateSubBuilder("add"); + Add(ConstantR1(sub_builder.get(), {1.0f}), + Parameter(sub_builder.get(), 0, init, "prev"), + /*broadcast_dimensions=*/{0}); + TF_ASSERT_OK_AND_ASSIGN(body, sub_builder->Build()); + } + + While(condition, body, Parameter(&b, 0, init, "init")); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +TEST(XlaBuilderTest, UnboundedXor) { + XlaBuilder b(TestName()); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, + ParseShape("s32[1, ?, 2, ?, <=2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, + ParseShape("s32[?, 1, ?, 2, ?, <=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("s32[?, ?, 2, 2, <=2, <=2, ?]")); + Xor(Parameter(&b, 0, lhs, "lhs"), Parameter(&b, 1, rhs, "rhs"), + /*broadcast_dimensions=*/empty_array); + TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr module, + BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Op().WithShapeEqualTo(&expected))); +} + +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, XlaBuilderUnboundedUnaryOpTest, + ::testing::ValuesIn( + {{"f32[?]", "f32[?]", &Abs}, + {"f32[?]", "f32[?]", &Cbrt}, + {"f32[?]", "f32[?]", &Ceil}, + {"u32[?]", "u32[?]", &Clz}, + {"f32[?]", "f32[?]", &Cos}, + {"f32[?]", "f32[?]", &Erf}, + {"f32[?]", "f32[?]", &Exp}, + {"f32[?]", "f32[?]", &Expm1}, + {"f32[?]", "f32[?]", &Floor}, + {"f32[?]", "f32[?]", &Imag}, + {"f32[?]", "pred[?]", &IsFinite}, + {"f32[?]", "f32[?]", &Log}, + {"f32[?]", "f32[?]", &Log1p}, + {"f32[?]", "f32[?]", &Logistic}, + {"f32[?]", "f32[?]", &Neg}, + {"s32[?]", "s32[?]", &Not}, + {"u32[?]", "u32[?]", &PopulationCount}, + {"f32[?]", "f32[?]", &Real}, + {"f32[?]", "f32[?]", &Round}, + {"f32[?]", "f32[?]", &RoundNearestEven}, + {"f32[?]", "f32[?]", &Rsqrt}, + {"f32[?]", "f32[?]", &Sign}, + {"f32[?]", "f32[?]", &Sin}, + {"f32[?]", "f32[?]", &Sqrt}, + {"f32[?]", "f32[?]", &Tanh}})); + +INSTANTIATE_TEST_SUITE_P( + UnboundedDynamism, XlaBuilderUnboundedBinaryOpTest, + ::testing::ValuesIn({ + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Add}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Add}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Atan2}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "c64[?, ?, 2, 2, <=2, <=2, ?]", + &Complex}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "c64[?, 10]", &Complex}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Div}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Div}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Max}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Max}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Min}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Min}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Mul}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Mul}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "pred[?, 10]", &Ne}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Pow}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Pow}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Rem}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Rem}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftLeft}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftLeft}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftRightArithmetic}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftRightArithmetic}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &ShiftRightLogical}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &ShiftRightLogical}, + {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", + /*broadcast_dimensions=*/empty_array, "f32[?, ?, 2, 2, <=2, <=2, ?]", + &Sub}, + {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, + "f32[?, 10]", &Sub}, + })); + } // namespace } // namespace xla diff --git a/xla/client/xla_computation.cc b/xla/client/xla_computation.cc index b493a3b573e23..c92de63495d19 100644 --- a/xla/client/xla_computation.cc +++ b/xla/client/xla_computation.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,12 +23,12 @@ limitations under the License. namespace xla { -StatusOr XlaComputation::GetProgramShape() const { +absl::StatusOr XlaComputation::GetProgramShape() const { TF_RET_CHECK(proto_.has_host_program_shape()); return ProgramShape(proto_.host_program_shape()); } -StatusOr> XlaComputation::Snapshot() const { +absl::StatusOr> XlaComputation::Snapshot() const { if (IsNull()) { return InvalidArgument("Computation is invalid."); } diff --git a/xla/client/xla_computation.h b/xla/client/xla_computation.h index 4f70662d79c6d..e21a92d630065 100644 --- a/xla/client/xla_computation.h +++ b/xla/client/xla_computation.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ class XlaComputation { // Returns the "program shape" (parameter and return shapes) for this // computation. - StatusOr GetProgramShape() const; + absl::StatusOr GetProgramShape() const; const std::string& name() const { return proto().name(); } @@ -54,7 +54,7 @@ class XlaComputation { // Requests that we snapshot the computation into a serializable protocol // buffer form. - StatusOr> Snapshot() const; + absl::StatusOr> Snapshot() const; // Returns true if this object is a null Computation. bool IsNull() const { return unique_id_ == -1; } diff --git a/xla/comparison_util.cc b/xla/comparison_util.cc index 19da596af9825..a0df27e07928d 100644 --- a/xla/comparison_util.cc +++ b/xla/comparison_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -163,7 +163,7 @@ absl::string_view ComparisonOrderToString(Comparison::Order order) { } } -StatusOr StringToComparisonDirection( +absl::StatusOr StringToComparisonDirection( absl::string_view direction) { static auto* map = new absl::flat_hash_map({ @@ -181,7 +181,8 @@ StatusOr StringToComparisonDirection( return it->second; } -StatusOr StringToComparisonOrder(absl::string_view order) { +absl::StatusOr StringToComparisonOrder( + absl::string_view order) { static auto* map = new absl::flat_hash_map({ {"TOTALORDER", Comparison::Order::kTotal}, {"PARTIALORDER", Comparison::Order::kPartial}, @@ -193,7 +194,7 @@ StatusOr StringToComparisonOrder(absl::string_view order) { return it->second; } -StatusOr StringToComparisonType( +absl::StatusOr StringToComparisonType( absl::string_view comparison) { static auto* map = new absl::flat_hash_map({ {"FLOAT", Comparison::Type::kFloat}, diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 0c11edf302d46..3c72cac3655b3 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,6 +61,13 @@ class Comparison { kPartial, }; + friend absl::string_view ComparisonOrderToString(Comparison::Order order); + + template + friend void AbslStringify(Sink& sink, const Order& p) { + absl::Format(&sink, "%s", ComparisonOrderToString(p)); + } + // Represents different comparison operations. enum class Direction : uint8_t { kEq, @@ -228,12 +235,13 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { std::string ComparisonDirectionToString(Comparison::Direction direction); std::string ComparisonTypeToString(Comparison::Type type); absl::string_view ComparisonPrimitiveTypeToString(PrimitiveType type); -absl::string_view ComparisonOrderToString(Comparison::Order order); -StatusOr StringToComparisonDirection( +absl::StatusOr StringToComparisonDirection( absl::string_view direction); -StatusOr StringToComparisonType(absl::string_view comparison); -StatusOr StringToComparisonOrder(absl::string_view order); +absl::StatusOr StringToComparisonType( + absl::string_view comparison); +absl::StatusOr StringToComparisonOrder( + absl::string_view order); // Returns a comparison function using the provided key function on each value, // i.e. `key_fn(a) < key_fn(b)`. diff --git a/xla/comparison_util_test.cc b/xla/comparison_util_test.cc index cd3f568282f4e..1581569a5d284 100644 --- a/xla/comparison_util_test.cc +++ b/xla/comparison_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/compiler_macros.h b/xla/compiler_macros.h new file mode 100644 index 0000000000000..026ebbdcd525d --- /dev/null +++ b/xla/compiler_macros.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_COMPILER_MACROS_H_ +#define XLA_COMPILER_MACROS_H_ + +#if (defined(__GNUC__) || defined(__clang__)) && defined(__SSE2__) +#define XLA_HAS_SSE2 +#elif defined(_MSC_VER) && !defined(_M_ARM64EC) && defined(_M_X64) +#define XLA_HAS_SSE2 +#elif defined(_MSC_VER) && !defined(_M_ARM64EC) && \ + (defined(_M_IX86_FP) && _M_IX86_FP >= 2) +#define XLA_HAS_SSE2 +#elif defined(__AVX__) +#define XLA_HAS_SSE2 +#endif + +#if defined(_M_ARM64) || defined(_M_ARM64EC) +#define XLA_HAS_ARM64 +#define XLA_HAS_ARM_NEON +#elif defined(__ARM_NEON) && !defined(__ARM_BIG_ENDIAN) +#define XLA_HAS_ARM_NEON + +#if defined(__aarch64__) +#define XLA_HAS_ARM64 +#endif // defined(__aarch64__) + +#endif // defined(_M_ARM64) || defined(_M_ARM64EC) + +#if defined(__clang__) +#define XLA_UNROLL _Pragma("unroll") +#elif defined(__GNUC__) +#define XLA_UNROLL _Pragma("GCC unroll 128") +#else +#define XLA_UNROLL +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define XLA_FLATTEN __attribute__((flatten)) +#elif defined(_MSC_VER) +#define XLA_FLATTEN [[msvc::flatten]] +#else +#define XLA_FLATTEN +#endif + +#endif // XLA_COMPILER_MACROS_H_ diff --git a/xla/cpu_function_runtime.cc b/xla/cpu_function_runtime.cc index bc361a662caae..034f21acc359a 100644 --- a/xla/cpu_function_runtime.cc +++ b/xla/cpu_function_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/cpu_function_runtime.h b/xla/cpu_function_runtime.h index 99e0eaa1c578c..6dccf0e0facc1 100644 --- a/xla/cpu_function_runtime.h +++ b/xla/cpu_function_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index fb3a714cbdebd..1908e72d135cc 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,10 +22,12 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -33,9 +35,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/util/command_line_flags.h" namespace xla { @@ -85,7 +88,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_fast_math_honor_division(true); // TODO(AyanmoI): Remove this flag when cuDNN FMHA is fully supported. - opts.set_xla_gpu_enable_cudnn_fmha(false); + opts.set_xla_gpu_enable_cudnn_fmha(true); opts.set_xla_gpu_fused_attention_use_cudnn_rng(false); @@ -100,8 +103,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); + opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); + opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); opts.set_xla_gpu_graph_num_runs_to_instantiate(-1); - opts.set_xla_gpu_enable_persistent_temp_buffers(false); opts.set_xla_gpu_graph_min_graph_size(5); opts.set_xla_gpu_graph_enable_concurrent_region(false); opts.set_xla_gpu_graph_eviction_timeout_seconds(60); @@ -120,9 +124,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_all_gather_combine_by_dim(true); opts.set_xla_gpu_enable_reduce_scatter_combine_by_dim(true); - opts.set_xla_gpu_enable_async_collectives(false); + opts.set_xla_gpu_enable_async_collectives(true); opts.set_xla_gpu_enable_async_all_reduce(true); opts.set_xla_gpu_enable_async_all_gather(false); + opts.set_xla_gpu_enable_async_collective_broadcast(true); opts.set_xla_gpu_enable_async_collective_permute(false); opts.set_xla_gpu_enable_async_all_to_all(false); opts.set_xla_gpu_enable_async_reduce_scatter(false); @@ -135,14 +140,14 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_detailed_logging(true); opts.set_xla_enable_dumping(true); - opts.set_xla_gpu_enable_xla_runtime_executable(true); + opts.set_xla_gpu_enable_xla_runtime_executable(false); opts.set_xla_gpu_enable_custom_fusions(false); + opts.set_xla_gpu_enable_address_computation_fusion(true); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); - - // XLA:GPU + IREE runtime flags. - opts.set_xla_gpu_enable_gpu2_runtime(false); - opts.set_xla_gpu_enable_gpu2_hal(true); + opts.set_xla_gpu_enable_nccl_user_buffers(false); + opts.set_xla_gpu_enable_nccl_comm_splitting(false); + opts.set_xla_gpu_enable_nccl_per_stream_comms(false); // Set 4GB space limit for redzone scratch allocator. opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); @@ -154,6 +159,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_latency_hiding_scheduler(false); opts.set_xla_gpu_lhs_enable_gpu_async_tracker(true); opts.set_xla_gpu_enable_analytical_latency_estimator(false); + opts.set_xla_gpu_enable_linear_program_scheduler(false); + opts.set_xla_gpu_pgle_profile_file_or_directory_path(""); opts.set_xla_gpu_memory_limit_slop_factor(95); opts.set_xla_gpu_enable_highest_priority_async_stream(true); @@ -192,7 +199,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_exhaustive_tiling_search(false); - opts.set_xla_gpu_enable_priority_fusion(false); + opts.set_xla_gpu_enable_priority_fusion(true); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); @@ -204,6 +211,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_reduction_epilogue_fusion(true); opts.set_xla_gpu_enable_nccl_clique_optimization(false); opts.set_xla_gpu_cublas_fallback(true); + opts.set_xla_gpu_cudnn_gemm_fusion_level(0); opts.set_xla_gpu_enable_while_loop_double_buffering(false); opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false); opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true); @@ -212,6 +220,37 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cub_radix_sort(true); opts.set_xla_gpu_enable_cudnn_layer_norm(false); opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000); + + opts.set_xla_gpu_enable_triton_hopper(false); + + // We disable this until b/319271534 is fixed due to errors during linking. + // + // TODO(b/319271534): Re-enable once we use libnvjitlink. + opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); + + opts.set_xla_gpu_enable_libnvptxcompiler(false); + + opts.set_xla_gpu_enable_dot_strength_reduction(true); + + opts.set_xla_gpu_enable_bf16_6way_gemm(false); + opts.set_xla_gpu_enable_bf16_3way_gemm(false); + opts.set_xla_gpu_nccl_collective_max_nchannels(0); + opts.set_xla_gpu_nccl_p2p_max_nchannels(0); + + opts.set_xla_gpu_enable_mlir_emitters(false); + opts.set_xla_gpu_max_mlir_kernels(0); + opts.set_xla_gpu_skip_mlir_kernels(0); + + opts.set_xla_gpu_multi_streamed_windowed_einsum(false); + + // Minimum combined size of matrices in matrix multiplication to + // be rewritten to cuBLAS or Triton kernel call. + // This threshold is a conservative estimate and has been measured + // to be always beneficial (up to generally several times faster) + // on V100 and H100 GPUs. See openxla/xla #9319 for details. + const int64_t kDefaultMinGemmRewriteSize = 100; + opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize); + return opts; } @@ -290,12 +329,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, }; }; - auto float_setter_for = [](void (DebugOptions::*member_setter)(float)) { - return [member_setter](float value) { - (flag_values->*member_setter)(value); - return true; - }; - }; + auto float_setter_for = + [debug_options](void (DebugOptions::*member_setter)(float)) { + return [debug_options, member_setter](float value) { + (debug_options->*member_setter)(value); + return true; + }; + }; // Custom "sub-parser" lambda for xla_gpu_shape_checks. auto setter_for_xla_gpu_shape_checks = @@ -390,17 +430,76 @@ void MakeDebugOptionsFlags(std::vector* flag_list, // Custom "sub-parser" lambda for xla_gpu_enable_command_buffer. auto setter_for_xla_gpu_enable_command_buffer = - [debug_options](const std::string& values) { - debug_options->clear_xla_gpu_enable_command_buffer(); - for (const absl::string_view value : absl::StrSplit(values, ',')) { + [debug_options](const std::string& input) { + auto is_command_type = [](absl::string_view value) { DebugOptions::CommandBufferCmdType cmd_type; - if (!DebugOptions::CommandBufferCmdType_Parse( - absl::AsciiStrToUpper(value), &cmd_type)) { - return false; + return DebugOptions::CommandBufferCmdType_Parse( + absl::AsciiStrToUpper(value), &cmd_type); + }; + + auto is_add_or_remove_command_type = [&](absl::string_view value) { + if (absl::StartsWith(value, "+") || absl::StartsWith(value, "-")) { + return (is_command_type(value.substr(1))); } - debug_options->add_xla_gpu_enable_command_buffer(cmd_type); + return false; + }; + + auto parse_command_type = [](absl::string_view value) { + DebugOptions::CommandBufferCmdType cmd_type; + DebugOptions::CommandBufferCmdType_Parse(absl::AsciiStrToUpper(value), + &cmd_type); + return cmd_type; + }; + + auto erase_command_type = [](tsl::protobuf::RepeatedField* enabled, + DebugOptions::CommandBufferCmdType type) { + auto it = enabled->begin(); + while (it != enabled->end()) { + if (*it == type) { + it = enabled->erase(it); + } else { + it++; + } + } + }; + + // Disable command buffers by clearing a set of supported commands. + if (input.empty()) { + debug_options->clear_xla_gpu_enable_command_buffer(); + return true; } - return true; + + std::vector values = absl::StrSplit(input, ','); + + // Overwrite a set of supported commands with a flag. + if (absl::c_all_of(values, is_command_type)) { + debug_options->clear_xla_gpu_enable_command_buffer(); + for (const absl::string_view value : values) { + debug_options->add_xla_gpu_enable_command_buffer( + parse_command_type(value)); + } + return true; + } + + // Add or remove a commands from a default set. + if (absl::c_all_of(values, is_add_or_remove_command_type)) { + for (const absl::string_view value : values) { + DebugOptions::CommandBufferCmdType cmd_type = + parse_command_type(value.substr(1)); + if (absl::StartsWith(value, "+")) { + debug_options->add_xla_gpu_enable_command_buffer(cmd_type); + } else if (absl::StartsWith(value, "-")) { + tsl::protobuf::RepeatedField* enabled = + debug_options->mutable_xla_gpu_enable_command_buffer(); + erase_command_type(enabled, cmd_type); + } + return true; + } + } + + // Return an error if flag value was not recognized as one of the + // supported modes. + return false; }; // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any @@ -857,6 +956,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_force_compilation_parallelism(), "Overrides normal multi-threaded compilation setting to use this many " "threads. Setting to 0 (the default value) means no enforcement.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_llvm_module_compilation_parallelism", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_enable_llvm_module_compilation_parallelism), + debug_options->xla_gpu_enable_llvm_module_compilation_parallelism(), + "Decides whether we can do LLVM module compilation in a parallelised " + "way. If set to false, then it will be single threaded, otherwise the " + "number of threads depends on the " + "--xla_gpu_force_compilation_parallelism flag and the thread pool " + "supplied to GpuCompiler.")); + flag_list->push_back( tsl::Flag("xla_gpu_deterministic_ops", bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops), @@ -872,6 +983,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_reduce), debug_options->xla_gpu_enable_async_all_reduce(), "Converts synchronous all-reduce ops into asynchronous.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_async_collective_broadcast", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_async_collective_broadcast), + debug_options->xla_gpu_enable_async_collective_broadcast(), + "Converts synchronous collective-broadcast ops into asynchronous.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_async_collective_permute", bool_setter_for( @@ -1003,7 +1120,10 @@ void MakeDebugOptionsFlags(std::vector* flag_list, flag_list->push_back(tsl::Flag( "xla_gpu_enable_command_buffer", setter_for_xla_gpu_enable_command_buffer, command_types_to_string(debug_options->xla_gpu_enable_command_buffer()), - "The types of the commands that are recorded into command buffers")); + "The types of the commands that are recorded into command buffers. It" + " can either be a list of command types or a list of command types with" + " + and - as prefix, which indicate adding or removing a command type" + " to/from the default list.")); flag_list->push_back(tsl::Flag( "xla_gpu_graph_num_runs_to_instantiate", int32_setter_for( @@ -1033,14 +1153,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "XLA instantiates new Gpu graphs, it evicts graphs that were not " "recently executed to free space on device.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_persistent_temp_buffers", - bool_setter_for( - &DebugOptions::set_xla_gpu_enable_persistent_temp_buffers), - debug_options->xla_gpu_enable_persistent_temp_buffers(), - "Allocate temp buffers once during the first execution of an executable. " - "Reuse the allocated buffers in subsequent executions. Executables cannot" - " run concurrently if this is enabled.")); flag_list->push_back( tsl::Flag("xla_dump_disable_metadata", bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata), @@ -1078,16 +1190,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Limits custom fusion only to fusions which match this regular " "expression. Default is all custom fusions registerered in a current " "process.")); - flag_list->push_back( - tsl::Flag("xla_gpu_enable_gpu2_runtime", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_gpu2_runtime), - debug_options->xla_gpu_enable_gpu2_runtime(), - "Whether to enable experimental XLA:GPU runtime")); - flag_list->push_back( - tsl::Flag("xla_gpu_enable_gpu2_hal", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_gpu2_hal), - debug_options->xla_gpu_enable_gpu2_hal(), - "Whether to enable CUDA HAL in experimental XLA:GPU runtime")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_address_computation_fusion", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_address_computation_fusion), + debug_options->xla_gpu_enable_address_computation_fusion(), + "Whether to enable XLA address computation fusion")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", int64_setter_for( @@ -1099,6 +1207,27 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_shared_constants), debug_options->xla_gpu_enable_shared_constants(), "Enable constant sharing between GPU executables")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_nccl_user_buffers", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_user_buffers), + debug_options->xla_gpu_enable_nccl_user_buffers(), + "Enables NCCL User Buffer Registration. collective_memory_size in the " + "allocator config must also be set to a non-zero value that is large " + "enough to meet peak collective memory usage.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_nccl_comm_splitting", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_comm_splitting), + debug_options->xla_gpu_enable_nccl_comm_splitting(), + "Enables NCCL communicator splitting which allows sharing NCCL resources " + "between different NCCL cliques.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_nccl_per_stream_comms", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_nccl_per_stream_comms), + debug_options->xla_gpu_enable_nccl_per_stream_comms(), + "A separate NCCL communicator will be created for each stream that a " + "NCCL collective is executed on. This can lead to higher performance if " + "NCCL collectives are issued concurrently at the cost of more GPU memory" + " usage.")); flag_list->push_back(tsl::Flag( "xla_gpu_redzone_scratch_max_megabytes", int64_setter_for( @@ -1188,6 +1317,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_analytical_latency_estimator(), "Enable analytical latency estimator for latency-hiding scheduler for " "XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_linear_program_scheduler", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_linear_program_scheduler), + debug_options->xla_gpu_enable_linear_program_scheduler(), + "Enable linear program sheduler for better performance" + "XLA:GPU")); flag_list->push_back(tsl::Flag( "xla_gpu_pgle_profile_file_or_directory_path", string_setter_for( @@ -1270,7 +1406,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_triton_gemm_any), debug_options->xla_gpu_triton_gemm_any(), "Use Triton-based matrix multiplication for any GEMM it " - "supports without filtering only faster ones.")); + "supports without filtering only faster ones. To make sure " + "only triton gemm is chosen by the autotuner run with " + "`xla_gpu_cublas_fallback` set to false.")); flag_list->push_back(tsl::Flag( "xla_gpu_exhaustive_tiling_search", bool_setter_for(&DebugOptions::set_xla_gpu_exhaustive_tiling_search), @@ -1298,6 +1436,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "unless the name ends with .txt or .textproto. It will be loaded at most " "once per process. This only works on CUDA. In tests, the TEST_WORKSPACE " "prefix can be used to load files from their data dependencies.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_require_complete_aot_autotune_results", + bool_setter_for( + &DebugOptions::set_xla_gpu_require_complete_aot_autotune_results), + debug_options->xla_gpu_multi_streamed_windowed_einsum(), + "Whether to require complete AOT autotuning results.")); flag_list->push_back(tsl::Flag( "xla_gpu_auto_spmd_partitioning_memory_budget_gb", int32_setter_for( @@ -1372,6 +1516,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_cublas_fallback(), "Allow Triton GEMM autotuning to fall back to cuBLAS when that " "is faster.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_cudnn_gemm_fusion_level", + int32_setter_for(&DebugOptions::set_xla_gpu_cudnn_gemm_fusion_level), + debug_options->xla_gpu_cudnn_gemm_fusion_level(), + "cuDNN GEMM fusion level; higher level corresponds to more kinds of " + "fused operations.")); flag_list->push_back( tsl::Flag("xla_gpu_mock_custom_calls", bool_setter_for(&DebugOptions::set_xla_gpu_mock_custom_calls), @@ -1426,7 +1576,86 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_threshold_for_windowed_einsum_mib), debug_options->xla_gpu_threshold_for_windowed_einsum_mib(), "Threshold to enable windowed einsum (collective matmul) in MB." + "Einsums that have partitioned operand(can be either LHS or RHS) that's " + "larger than this threshold will be transformed to use windowed einsums." "Default is 100000")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_triton_hopper", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper), + debug_options->xla_gpu_enable_triton_hopper(), + "Enable Hopper-specific optimizations such as MMA_V3 and pipelining.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_libnvptxcompiler", + [debug_options](bool enabled) { + if (enabled && !stream_executor::IsLibNvPtxCompilerSupported()) { + // This feature can't be enabled when XLA was built without + // libnvptxcompiler support. + return false; + } + debug_options->set_xla_gpu_enable_libnvptxcompiler(enabled); + return true; + }, + debug_options->xla_gpu_enable_libnvptxcompiler(), + "Use libnvptxcompiler for PTX-to-GPU-assembly compilation instead of " + "calling ptxas.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_dot_strength_reduction", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_dot_strength_reduction), + debug_options->xla_gpu_enable_dot_strength_reduction(), + "Enable rewriting matmuls with a vector into reductions.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_bf16_6way_gemm", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_6way_gemm), + debug_options->xla_gpu_enable_bf16_6way_gemm(), + "Use BF16 6way gemm to compute F32 gemm.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_bf16_3way_gemm", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_bf16_3way_gemm), + debug_options->xla_gpu_enable_bf16_3way_gemm(), + "Use BF16 3way gemm to compute F32 gemm.")); + flag_list->push_back( + tsl::Flag("xla_gpu_nccl_collective_max_nchannels", + int64_setter_for( + &DebugOptions::set_xla_gpu_nccl_collective_max_nchannels), + debug_options->xla_gpu_nccl_collective_max_nchannels(), + "Specify the maximum number of channels(SMs) NCCL will use " + "for collective operations. Default is 0 which is to let " + "NCCL decide.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_nccl_p2p_max_nchannels", + int64_setter_for(&DebugOptions::set_xla_gpu_nccl_p2p_max_nchannels), + debug_options->xla_gpu_nccl_p2p_max_nchannels(), + "Specify the maximum number of channels(SMs) NCCL will use " + "for p2p operations. Default is 0 which is to let " + "NCCL decide.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_mlir_emitters", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_mlir_emitters), + debug_options->xla_gpu_enable_mlir_emitters(), + "Enable new MLIR-based emitters.")); + flag_list->push_back( + tsl::Flag("xla_gpu_max_mlir_kernels", + int64_setter_for(&DebugOptions::set_xla_gpu_max_mlir_kernels), + debug_options->xla_gpu_max_mlir_kernels(), + "Maximum number of kernels to emit with MLIR.")); + flag_list->push_back( + tsl::Flag("xla_gpu_skip_mlir_kernels", + int64_setter_for(&DebugOptions::set_xla_gpu_skip_mlir_kernels), + debug_options->xla_gpu_skip_mlir_kernels(), + "Number of initial kernels to skip MLIR emission for.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_multi_streamed_windowed_einsum", + bool_setter_for( + &DebugOptions::set_xla_gpu_multi_streamed_windowed_einsum), + debug_options->xla_gpu_multi_streamed_windowed_einsum(), + "Whether to run windowed einsum using multiple compute streams.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_gemm_rewrite_size_threshold", + int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold), + debug_options->xla_gpu_gemm_rewrite_size_threshold(), + "Threshold to rewrite matmul to cuBLAS or Triton " + "(minumum combined number of elements of both matrices " + "in non-batch dimensions to be considered for a rewrite).")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/debug_options_flags.h b/xla/debug_options_flags.h index 6ec89d7746e05..4bc8420441af1 100644 --- a/xla/debug_options_flags.h +++ b/xla/debug_options_flags.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/logging.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/xla/debug_options_parsers.h b/xla/debug_options_parsers.h index 0350acccff3f3..865f69f2fd244 100644 --- a/xla/debug_options_parsers.h +++ b/xla/debug_options_parsers.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/debug_options_parsers_test.cc b/xla/debug_options_parsers_test.cc index 1260096898b2a..318b98f163faa 100644 --- a/xla/debug_options_parsers_test.cc +++ b/xla/debug_options_parsers_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/device_util.h b/xla/device_util.h index ed20599360c3e..17f8aca2dbfa4 100644 --- a/xla/device_util.h +++ b/xla/device_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/ef57.cc b/xla/ef57.cc new file mode 100644 index 0000000000000..b0c3f974319f2 --- /dev/null +++ b/xla/ef57.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ef57.h" + +#include +#include + +#include "absl/types/span.h" +#include "xla/compiler_macros.h" +#include "tsl/platform/logging.h" + +#ifdef XLA_HAS_SSE2 +#include // IWYU pragma: keep +#endif + +#if defined(XLA_HAS_ARM_NEON) && defined(XLA_HAS_ARM64) +#include // IWYU pragma: keep +#endif + +namespace xla { + +void ConvertF64ToEf57(absl::Span input, + absl::Span output) { + DCHECK_EQ(input.size() * 2, output.size()); +#ifdef __AVX__ + constexpr int kDoublesPerAvxIteration = sizeof(__m256d) / sizeof(double); + constexpr int kFloatsPerSseRegister = sizeof(__m128) / sizeof(float); + while (input.size() >= kDoublesPerAvxIteration) { + __m256d x = _mm256_loadu_pd(input.data()); + + __m128 x_hi_f32 = _mm256_cvtpd_ps(x); + __m256d x_hi_f64 = _mm256_cvtps_pd(x_hi_f32); + __m256d x_lo_f64 = _mm256_sub_pd(x, x_hi_f64); + __m128 x_lo_f32 = _mm256_cvtpd_ps(x_lo_f64); + + const __m128 inf = _mm_set1_ps(std::numeric_limits::infinity()); + __m128 x_hi_exponent = _mm_and_ps(x_hi_f32, inf); + __m128 x_is_finite = _mm_cmplt_ps(x_hi_exponent, inf); + x_lo_f32 = _mm_and_ps(x_lo_f32, x_is_finite); + + _mm_storeu_ps(output.data(), _mm_unpacklo_ps(x_hi_f32, x_lo_f32)); + output.remove_prefix(kFloatsPerSseRegister); + _mm_storeu_ps(output.data(), _mm_unpackhi_ps(x_hi_f32, x_lo_f32)); + output.remove_prefix(kFloatsPerSseRegister); + + input.remove_prefix(kDoublesPerAvxIteration); + } +#endif +#ifdef XLA_HAS_SSE2 + constexpr int kDoublesPerSseIteration = sizeof(__m128d) / sizeof(double); + constexpr int kFloatsPerSseIteration = sizeof(__m128) / sizeof(float); + while (input.size() >= kDoublesPerSseIteration) { + __m128d x = _mm_loadu_pd(input.data()); + __m128 x_hi_f32 = _mm_cvtpd_ps(x); + __m128d x_hi_f64 = _mm_cvtps_pd(x_hi_f32); + __m128d x_lo_f64 = _mm_sub_pd(x, x_hi_f64); + __m128 x_lo_f32 = _mm_cvtpd_ps(x_lo_f64); + + const __m128 inf = _mm_set1_ps(std::numeric_limits::infinity()); + __m128 x_hi_exponent = _mm_and_ps(x_hi_f32, inf); + __m128 x_is_finite = _mm_cmplt_ps(x_hi_exponent, inf); + x_lo_f32 = _mm_and_ps(x_lo_f32, x_is_finite); + + __m128 to_store = _mm_unpacklo_ps(x_hi_f32, x_lo_f32); + _mm_storeu_ps(output.data(), to_store); + + input.remove_prefix(kDoublesPerSseIteration); + output.remove_prefix(kFloatsPerSseIteration); + } +#endif +#if defined(XLA_HAS_ARM_NEON) && defined(XLA_HAS_ARM64) + constexpr int kDoublesPerNeonIteration = sizeof(float64x2_t) / sizeof(double); + constexpr int kFloatsPerNeonIteration = sizeof(float32x2x2_t) / sizeof(float); + while (input.size() >= kDoublesPerNeonIteration) { + float64x2_t x = vld1q_f64(input.data()); + float32x2_t x_hi_f32 = vcvt_f32_f64(x); + float64x2_t x_hi_f64 = vcvt_f64_f32(x_hi_f32); + float64x2_t x_lo_f64 = vsubq_f64(x, x_hi_f64); + float32x2_t x_lo_f32 = vcvt_f32_f64(x_lo_f64); + + uint32x2_t x_is_finite = + vcalt_f32(x_hi_f32, vdup_n_f32(std::numeric_limits::infinity())); + x_lo_f32 = vreinterpret_f32_u32( + vand_u32(vreinterpret_u32_f32(x_lo_f32), x_is_finite)); + + float32x2x2_t to_store; + to_store.val[0] = x_hi_f32; + to_store.val[1] = x_lo_f32; + vst2_f32(output.data(), to_store); + + input.remove_prefix(kDoublesPerNeonIteration); + output.remove_prefix(kFloatsPerNeonIteration); + } +#endif + + while (input.size() >= 1) { + std::tie(output[0], output[1]) = SplitF64ToF32(input.front()); + input.remove_prefix(1); + output.remove_prefix(2); + } +} + +} // namespace xla diff --git a/xla/ef57.h b/xla/ef57.h new file mode 100644 index 0000000000000..06cf1715c2123 --- /dev/null +++ b/xla/ef57.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_EF57_H_ +#define XLA_EF57_H_ + +#include +#include + +#include "absl/types/span.h" + +namespace xla { + +// Utility function to split a double-precision float (F64) into a pair of F32s. +// For a p-bit number, and a splitting point (p/2) <= s <= (p - 1), the +// algorithm produces a (p - s)-bit value 'hi' and a non-overlapping (s - 1)-bit +// value 'lo'. See Theorem 4 in [1] (attributed to Dekker) or [2] for the +// original theorem by Dekker. +// +// For double-precision F64s, which contain a 53 bit mantissa (52 of them +// explicit), we can represent the most significant 49 digits as the unevaluated +// sum of two single-precision floats 'hi' and 'lo'. The 'hi' float stores the +// most significant 24 bits and the sign bit of 'lo' together with its mantissa +// store the remaining 25 bits. The exponent of the resulting representation is +// still restricted to 8 bits of F32. +// +// References: +// [1] A. Thall, Extended-Precision Floating-Point Numbers for GPU Computation, +// SIGGRAPH Research Posters, 2006. +// (http://andrewthall.org/papers/df64_qf128.pdf) +// [2] T. J. Dekker, A floating point technique for extending the available +// precision, Numerische Mathematik, vol. 18, pp. 224–242, 1971. +inline std::pair SplitF64ToF32(double x) { + const float x_f32 = static_cast(x); + + const bool result_is_finite = std::isfinite(x_f32); + + // The high float is simply the double rounded to the nearest float. Because + // we are rounding to nearest with ties to even, the error introduced in + // rounding is less than half an ULP in the high ULP. + const float hi = x_f32; + // We can compute the low term using Sterbenz' lemma: If a and b are two + // positive floating point numbers and a/2 ≤ b ≤ 2a, then their difference can + // be computed exactly. + // Note: the difference is computed exactly but is rounded to the nearest + // float which will introduce additional error. + const float lo = static_cast(x - static_cast(hi)); + return std::make_pair(hi, result_is_finite ? lo : 0.0f); +} +void ConvertF64ToEf57(absl::Span input, absl::Span output); + +} // namespace xla + +#endif // XLA_EF57_H_ diff --git a/xla/ef57_test.cc b/xla/ef57_test.cc new file mode 100644 index 0000000000000..1f5d48cfda016 --- /dev/null +++ b/xla/ef57_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ef57.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log_streamer.h" +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "xla/test.h" + +namespace xla { +namespace { + +TEST(Ef57Test, DoubleMax) { + // Overflowing the F32 exponent in SplitF64ToF32 should result in a pair of + // [∞,0]. + auto [high, low] = SplitF64ToF32(std::numeric_limits::max()); + EXPECT_EQ(high, std::numeric_limits::infinity()); + EXPECT_EQ(low, 0.0f); +} + +TEST(Ef57Test, Overflow) { + auto [high, low] = SplitF64ToF32(0x1.ffffffp+127); + EXPECT_EQ(high, std::numeric_limits::infinity()); + EXPECT_EQ(low, 0.0f); +} + +TEST(Ef57Test, CheckPrecision) { + auto [high, low] = SplitF64ToF32(2.0 - 0x1p-52); + EXPECT_EQ(high, 2.0f); + EXPECT_EQ(low, -0x1p-52f); +} + +TEST(Ef57Test, SimpleArray) { + std::vector inputs(127); + + absl::BitGen gen; + for (double& input : inputs) { + input = absl::Uniform(gen, 0.0f, 1.0f); + } + + std::vector outputs(inputs.size() * 2); + ConvertF64ToEf57(inputs, absl::MakeSpan(outputs)); + for (int i = 0; i < inputs.size(); ++i) { + EXPECT_EQ(outputs[i * 2], inputs[i]); + EXPECT_EQ(outputs[i * 2 + 1], 0.0f); + } +} + +TEST(Ef57Test, RelativeSplit) { + const float distance = std::scalbnf(1.0f, std::numeric_limits::digits); + std::vector inputs(127); + + absl::BitGen gen; + for (double& input : inputs) { + input = absl::Uniform(gen, 0.0, 1.0); + } + + std::vector outputs(inputs.size() * 2); + ConvertF64ToEf57(inputs, absl::MakeSpan(outputs)); + for (int i = 0; i < outputs.size(); i += 2) { + auto most_significant = outputs[i]; + auto least_significant = outputs[i + 1]; + auto most_significant_mag = std::fabs(most_significant); + auto least_significant_mag = std::fabs(least_significant); + EXPECT_FALSE(std::isnan(most_significant_mag)); + if (most_significant_mag == 0.0f) { + EXPECT_EQ(least_significant_mag, 0.0f); + } else { + EXPECT_GT(most_significant_mag, least_significant_mag * distance); + } + } +} + +} // namespace +} // namespace xla diff --git a/xla/error_spec.h b/xla/error_spec.h index cde904940ecec..27d6029f9ec6c 100644 --- a/xla/error_spec.h +++ b/xla/error_spec.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,11 +20,11 @@ namespace xla { // Structure describing permissible absolute and relative error bounds. struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + explicit ErrorSpec(double aabs, double arel = 0, bool relaxed_nans = false) : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - float abs; // Absolute error bound. - float rel; // Relative error bound. + double abs; // Absolute error bound. + double rel; // Relative error bound. // If relaxed_nans is true then any result is valid if we are expecting NaNs. // In effect, this allows the tested operation to produce incorrect results diff --git a/xla/examples/axpy/BUILD b/xla/examples/axpy/BUILD index 0c31fb33ad022..7db0f85830a9d 100644 --- a/xla/examples/axpy/BUILD +++ b/xla/examples/axpy/BUILD @@ -11,7 +11,6 @@ xla_cc_test( "//xla/client:client_library", "//xla/client:local_client", "//xla/pjrt:local_device_state", - "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_stream_executor_client", "//xla/service:cpu_plugin", "//xla/service:platform_util", diff --git a/xla/examples/axpy/stablehlo_compile_test.cc b/xla/examples/axpy/stablehlo_compile_test.cc index a3c44736f0ad9..678e32237f06d 100644 --- a/xla/examples/axpy/stablehlo_compile_test.cc +++ b/xla/examples/axpy/stablehlo_compile_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/executable_run_options.cc b/xla/executable_run_options.cc index 795c5fc417643..9986ef4175de5 100644 --- a/xla/executable_run_options.cc +++ b/xla/executable_run_options.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -124,6 +124,17 @@ ExecutableRunOptions::gpu_executable_run_options() const { return gpu_executable_run_options_; } +ExecutableRunOptions& ExecutableRunOptions::set_cpu_executable_run_options( + const cpu::CpuExecutableRunOptions* cpu_executable_run_options) { + cpu_executable_run_options_ = cpu_executable_run_options; + return *this; +} + +const cpu::CpuExecutableRunOptions* +ExecutableRunOptions::cpu_executable_run_options() const { + return cpu_executable_run_options_; +} + ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { rng_seed_ = rng_seed; return *this; diff --git a/xla/executable_run_options.h b/xla/executable_run_options.h index 31ba23bf3b7a1..b1a9d07b28ddf 100644 --- a/xla/executable_run_options.h +++ b/xla/executable_run_options.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,6 +51,10 @@ class DeviceAssignment; class ExecutionProfile; class Shape; +namespace cpu { +class CpuExecutableRunOptions; +} // namespace cpu + namespace gpu { class GpuExecutableRunOptions; } // namespace gpu @@ -210,6 +214,12 @@ class ExecutableRunOptions { return recv_device_memory_function_; } + // CPU-backend specific options. These are kept out-of-line to avoid bloating + // the size of this dependency for CPU-only AOT builds. + ExecutableRunOptions& set_cpu_executable_run_options( + const cpu::CpuExecutableRunOptions* cpu_executable_run_options); + const cpu::CpuExecutableRunOptions* cpu_executable_run_options() const; + // GPU-backend specific options. These are kept out-of-line to avoid bloating // the size of this dependency for CPU-only AOT builds. ExecutableRunOptions& set_gpu_executable_run_options( @@ -231,6 +241,7 @@ class ExecutableRunOptions { SendDeviceMemoryFunction* send_device_memory_function_ = nullptr; RecvDeviceMemoryFunction* recv_device_memory_function_ = nullptr; RunId run_id_; + const cpu::CpuExecutableRunOptions* cpu_executable_run_options_ = nullptr; const gpu::GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr; }; diff --git a/xla/execution_options_util.cc b/xla/execution_options_util.cc index d1d2fdeb9fd3b..e86196a3eb622 100644 --- a/xla/execution_options_util.cc +++ b/xla/execution_options_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "xla/execution_options_util.h" + #include "xla/debug_options_flags.h" +#include "xla/xla.pb.h" namespace xla { diff --git a/xla/execution_options_util.h b/xla/execution_options_util.h index 2a69fa2aab2ec..60db0257f318c 100644 --- a/xla/execution_options_util.h +++ b/xla/execution_options_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/experiments/sm_bandwidth_benchmark/BUILD b/xla/experiments/sm_bandwidth_benchmark/BUILD index 6399f47b8237e..0fb3865181ea9 100644 --- a/xla/experiments/sm_bandwidth_benchmark/BUILD +++ b/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -31,6 +31,6 @@ xla_cc_test( ":sm_bw_utils", "@com_google_googletest//:gtest_main", ] + if_cuda([ - "@tsl//tsl/cuda:cudart", + "//xla/tsl/cuda:cudart", ]), ) diff --git a/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc b/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc index 824ddcd24ba90..798d4ceeae086 100644 --- a/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc +++ b/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h b/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h index 172fb23d5610b..611a9a61a6af8 100644 --- a/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h +++ b/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc b/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc index 485eb2f528588..0246e86c7a22c 100644 --- a/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc +++ b/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h b/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h index 3de989c5497c7..47bf5a08a4240 100644 --- a/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h +++ b/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/BUILD b/xla/experiments/triton_autotuning/BUILD new file mode 100644 index 0000000000000..cdadad94af0bb --- /dev/null +++ b/xla/experiments/triton_autotuning/BUILD @@ -0,0 +1,10 @@ +# Single matmul autotuning for Triton + +package( + default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +# The scripts in this directory are meant to be used with OSS Triton and are not integrated +# into the build system. diff --git a/xla/experiments/triton_autotuning/check_csv.py b/xla/experiments/triton_autotuning/check_csv.py index ade709694ebf9..6974c6ea66bbd 100755 --- a/xla/experiments/triton_autotuning/check_csv.py +++ b/xla/experiments/triton_autotuning/check_csv.py @@ -1,5 +1,5 @@ #!/usr/bin/python3 -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/check_data.py b/xla/experiments/triton_autotuning/check_data.py index 8ad71b671bd81..3a982c550ecb4 100755 --- a/xla/experiments/triton_autotuning/check_data.py +++ b/xla/experiments/triton_autotuning/check_data.py @@ -1,5 +1,5 @@ #!/usr/bin/python3 -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/matmul_lib.py b/xla/experiments/triton_autotuning/matmul_lib.py index f69c4105d8349..baa3ee715e8d3 100755 --- a/xla/experiments/triton_autotuning/matmul_lib.py +++ b/xla/experiments/triton_autotuning/matmul_lib.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/run_single_matmul.py b/xla/experiments/triton_autotuning/run_single_matmul.py index 733b332c22e59..e8156e99e43a8 100755 --- a/xla/experiments/triton_autotuning/run_single_matmul.py +++ b/xla/experiments/triton_autotuning/run_single_matmul.py @@ -1,5 +1,5 @@ #!/usr/bin/python3 -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/search.py b/xla/experiments/triton_autotuning/search.py index 2adbe975322af..94122ad586c87 100755 --- a/xla/experiments/triton_autotuning/search.py +++ b/xla/experiments/triton_autotuning/search.py @@ -1,5 +1,5 @@ #!/usr/bin/python3 -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/experiments/triton_autotuning/tune_single_matmul.py b/xla/experiments/triton_autotuning/tune_single_matmul.py index 7f4b70a9ab54a..360bf84ecfc07 100755 --- a/xla/experiments/triton_autotuning/tune_single_matmul.py +++ b/xla/experiments/triton_autotuning/tune_single_matmul.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/ffi/BUILD b/xla/ffi/BUILD index 07b8625504e3f..b76582b240793 100644 --- a/xla/ffi/BUILD +++ b/xla/ffi/BUILD @@ -1,5 +1,5 @@ -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,18 +18,14 @@ cc_library( srcs = ["call_frame.cc"], hdrs = ["call_frame.h"], deps = [ - ":api", - "//xla:status", "//xla:types", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", - "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) @@ -39,22 +35,17 @@ cc_library( hdrs = ["ffi.h"], deps = [ ":api", - ":call_frame", + "//xla:shape_util", "//xla:status", - "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", + "//xla/hlo/ir:hlo", "//xla/runtime:memref_view", - "//xla/service:executable", + "//xla/stream_executor", "//xla/stream_executor:device_memory", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -67,18 +58,13 @@ cc_library( ":call_frame", "//xla:status", "//xla:statusor", - "//xla:types", - "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", - "//xla/runtime:memref_view", + "//xla/hlo/ir:hlo", "//xla/service:executable", - "//xla/stream_executor:device_memory", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", ], ) @@ -92,9 +78,13 @@ xla_cc_test( ":ffi_api", "//xla:xla_data_proto_cc", "//xla/service:executable", + "//xla/stream_executor", "//xla/stream_executor:device_memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], diff --git a/xla/ffi/api/BUILD b/xla/ffi/api/BUILD index ca0e6f8ef146e..aeec931230050 100644 --- a/xla/ffi/api/BUILD +++ b/xla/ffi/api/BUILD @@ -1,6 +1,6 @@ -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -74,6 +74,7 @@ xla_cc_test( "//xla/ffi:ffi_api", "//xla/stream_executor:device_memory", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status_matchers", diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index b37a170f57638..6169f06b2ce82 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,20 +57,34 @@ limitations under the License. #include "xla/ffi/api/c_api.h" +#ifdef __has_builtin +#define XLA_FFI_HAS_BUILTIN(x) __has_builtin(x) +#else +#define XLA_FFI_HAS_BUILTIN(x) 0 +#endif + #if __has_attribute(always_inline) -#define XLA_ATTRIBUTE_ALWAYS_INLINE inline __attribute__((always_inline)) +#define XLA_FFI_ATTRIBUTE_ALWAYS_INLINE inline __attribute__((always_inline)) #elif defined(_MSC_VER) -#define XLA_ATTRIBUTE_ALWAYS_INLINE __forceinline +#define XLA_FFI_ATTRIBUTE_ALWAYS_INLINE __forceinline #else -#define XLA_ATTRIBUTE_ALWAYS_INLINE inline +#define XLA_FFI_ATTRIBUTE_ALWAYS_INLINE inline #endif #if __has_attribute(noinline) -#define XLA_ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#define XLA_FFI_ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) #elif defined(_MSC_VER) -#define XLA_ATTRIBUTE_NEVER_INLINE __declspec(noinline) +#define XLA_FFI_ATTRIBUTE_NEVER_INLINE __declspec(noinline) +#else +#define XLA_FFI_ATTRIBUTE_NEVER_INLINE +#endif + +#if XLA_FFI_HAS_BUILTIN(__builtin_expect) +#define XLA_FFI_PREDICT_FALSE(x) (__builtin_expect(false || (x), false)) +#define XLA_FFI_PREDICT_TRUE(x) (__builtin_expect(false || (x), true)) #else -#define XLA_ATTRIBUTE_NEVER_INLINE +#define XLA_FFI_PREDICT_FALSE(x) (x) +#define XLA_FFI_PREDICT_TRUE(x) (x) #endif namespace xla::ffi { @@ -89,15 +103,27 @@ class Handler; class Ffi { public: + // Creates and empty binding specification wich allows to define FFI handler + // signature separately from implementation and rely on compile time type + // checking to verify that signature matches the provided implementation. static Binding<> Bind(); + // Automatic FFI binding that does binding specification inference from the + // `fn` type signature and binds `fn` to it. This enables a more concise FFI + // handler registration with fully automatic type inference at the cost of + // less readable error messages, template metaprogramming "magic" and a risk + // to accidentally change handler type without noticing it. + template + static auto BindTo(Fn fn); + virtual ~Ffi() = default; virtual XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const = 0; - // Registers handler with an XLA runtime under the given name. - static inline XLA_FFI_Error* RegisterStaticHandler(const XLA_FFI_Api* api, - std::string_view name, - XLA_FFI_Handler* handler); + // Registers handler with an XLA runtime under the given name on a given + // platform. + static inline XLA_FFI_Error* RegisterStaticHandler( + const XLA_FFI_Api* api, std::string_view name, std::string_view platform, + XLA_FFI_Handler* handler, XLA_FFI_Handler_Traits traits = 0); protected: template @@ -117,14 +143,20 @@ class Ffi { XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, std::string_view name, - XLA_FFI_Handler* handler) { - std::string name_str(name); // make a copy to guarantee it's null terminated + std::string_view platform, + XLA_FFI_Handler* handler, + XLA_FFI_Handler_Traits traits) { + // Make copies of string views to guarantee they are null terminated. + std::string name_str(name); + std::string platform_str(platform); XLA_FFI_Handler_Register_Args args; args.struct_size = XLA_FFI_Handler_Register_Args_STRUCT_SIZE; args.priv = nullptr; args.name = name_str.c_str(); + args.platform = platform_str.c_str(); args.handler = handler; + args.traits = traits; return api->XLA_FFI_Handler_Register(&args); } @@ -178,8 +210,15 @@ namespace internal { // A type tag to forward all remaining args as `RemainingArgs`. struct RemainingArgsTag {}; -// A type tag to distinguish arguments tied to the attributes in the -// `Binding` variadic template argument. +// A type tag to distinguish parameters tied to results in the `Binding` +// variadic template. In XLA FFI we use destination passing style APIs and don't +// return anything from the handler, but instead pass a destination where the +// handler should write the result. +template +struct RetTag {}; + +// A type tag to distinguish parameters tied to the attributes in the +// `Binding` variadic template. template struct AttrTag {}; @@ -188,7 +227,7 @@ struct AttrTag {}; template struct AttrsTag {}; -// A type tag to distinguish arguments extracted from an execution context. +// A type tag to distinguish parameter extracted from an execution context. template struct CtxTag {}; @@ -235,6 +274,11 @@ class Binding { return {std::move(*this)}; } + template + Binding> Ret() && { + return {std::move(*this)}; + } + Binding RemainingArgs() && { static_assert(!internal::HasRemainingArgsTag::value, "remaining arguments can be passed just once"); @@ -264,7 +308,7 @@ class Binding { template std::unique_ptr> To(Fn fn) { return std::unique_ptr>( - new Handler(std::forward(fn), std::move(attrs_))); + new Handler(std::move(fn), std::move(attrs_))); } private: @@ -287,6 +331,234 @@ class Binding { inline Binding<> Ffi::Bind() { return xla::ffi::Binding<>(); } +//===----------------------------------------------------------------------===// +// Template metaprogramming to automatially infer Binding from invocable object. +//===----------------------------------------------------------------------===// + +// A little bit of metaprogramming that automatically infers the binding schema +// from an invocable type signature. + +// XLA FFI binding for an argument. +// +// Example: binding for the `MyType` argument +// +// template <> +// struct ArgBinding { +// using Arg = MyType; +// }; +// +template +struct ArgBinding { + using Arg = void; +}; + +// XLA FFI binding for a returned result. +// +// Example: binding for the `MyType` result +// +// template <> +// struct RetBinding { +// using Ret = MyType; +// }; +// +template +struct RetBinding { + using Ret = void; +}; + +// XLA FFI binding for a named attribute. +// +// Example: binding for the `MyType` attribute +// +// template <> +// struct AttrBinding { +// using Attr = MyAttr; +// static constexpr std::string_view name() { return "my_attr"; } +// }; +// +template +struct AttrBinding { + using Attr = void; +}; + +// XLA FFI binding for dictionary attributes: automatic parsing of all +// attributes into user defined struct. +template +struct AttrsBinding { + using Attrs = void; +}; + +// XLA FFI binding for values passed via context. +// +// Example: binding for the `gpuStream_t` platform stream +// +// template <> +// struct CtxBinding { +// using Ctx = PlatformStream; +// }; +// +template +struct CtxBinding { + using Ctx = void; +}; + +namespace internal { + +template +inline constexpr bool is_arg_binding_v = + !std::is_void_v::Arg>; + +template +inline constexpr bool is_ret_binding_v = + !std::is_void_v::Ret>; + +template +inline constexpr bool is_attr_binding_v = + !std::is_void_v::Attr>; + +template +inline constexpr bool is_attrs_binding_v = + !std::is_void_v::Attrs>; + +template +inline constexpr bool is_ctx_binding_v = + !std::is_void_v::Ctx>; + +// A helper template to bind `Params` to `Fn` one by one. +template +struct BindOne; + +// A specialization that binds one parameter. +template +struct BindOne { + // Binds single parameter and then continues with remaining parameters using + // recursive template instantiation. + template + static auto To(Fn fn, InFlightBinding binding) { + if constexpr (is_arg_binding_v) { + // Bind parameter as an FFI handler argument. + return BindOne::To( + std::move(fn), + std::move(binding).template Arg::Arg>()); + } else if constexpr (is_ret_binding_v) { + // Bind parameter as an FFI handler result. + return BindOne::To( + std::move(fn), + std::move(binding).template Ret::Ret>()); + + } else if constexpr (is_attr_binding_v) { + // Bind parameter as a named FFI handler attribute. + return BindOne::To( + std::move(fn), + std::move(binding).template Attr::Attr>( + std::string(AttrBinding::name()))); + + } else if constexpr (is_attrs_binding_v) { + // Bind parameter as attributes dictionary. + return BindOne::To( + std::move(fn), + std::move(binding) + .template Attrs::Attrs>()); + + } else if constexpr (is_ctx_binding_v) { + // Bind parameter as an FFI handler context. + return BindOne::To( + std::move(fn), + std::move(binding).template Ctx::Ctx>()); + + } else { + // Parameter is not recognized as one of the types that can be bound to + // FFI handler. + static_assert(sizeof(Param) == 0, + "parameter is not supported for binding"); + } + } +}; + +// A specialization that binds `Fn` after consuming all parameters. +template +struct BindOne { + template + static auto To(Fn fn, InFlightBinding binding) { + return binding.To(std::move(fn)); + } +}; + +template +struct Bind; + +// Binding specialization for function pointers (and captureless lambdas that +// can be casted to function pointers). +template +struct Bind { + using Fn = ResultType (*)(Params...); + + static auto To(Fn fn) { + return BindOne::To(std::move(fn), Ffi::Bind()); + } +}; + +// Binding specialization for callables (lambdas with captures). +template +struct Bind { + static auto To(Fn fn) { + return BindOne::To(std::move(fn), Ffi::Bind()); + } +}; + +} // namespace internal + +template +auto Ffi::BindTo(Fn fn) { + if constexpr (std::is_pointer_v) { + return internal::Bind::To(fn); + } else { + return internal::Bind::To(std::move(fn)); + } +} + +// A container for defining parameters corresponding to results. +template +class Result { + public: + Result(T value) : value_(value) {} // NOLINT + T& operator*() { return value_; } + T* operator->() { return &value_; } + + private: + T value_; +}; + +// A container for defining parameters corresponding to attributes with an +// attribute name available as compile time value. +template +class Attr { + public: + Attr(T value) : value_(value) {} // NOLINT + T& operator*() { return value_; } + T* operator->() { return &value_; } + + private: + T value_; +}; + +//===----------------------------------------------------------------------===// +// Attributes bindings +//===----------------------------------------------------------------------===// + +// Default attribute binding for `Attr` parameters. +template +struct AttrBinding> { + using Attr = T; + static constexpr std::string_view name() { return attr_name; } +}; + +// Default attributes binding for `Dictonary` parameters. +template <> +struct AttrsBinding { + using Attrs = Dictionary; +}; + //===----------------------------------------------------------------------===// // Arguments decoding implementation //===----------------------------------------------------------------------===// @@ -304,6 +576,23 @@ inline Binding<> Ffi::Bind() { return xla::ffi::Binding<>(); } template struct ArgDecoding; +//===----------------------------------------------------------------------===// +// Results decoding implementation +//===----------------------------------------------------------------------===// + +// XLA FFI results decoding must be defined by specializing this template. +// +// Example: decoding for the `MyType` results +// +// template <> +// struct RetDecoding { +// static std::optional Decode(XLA_FFI_RetType type, void* ret); +// }; +// +// If argument can't be decoded it should return the empty optional. +template +struct RetDecoding; + //===----------------------------------------------------------------------===// // Attributes decoding implementation //===----------------------------------------------------------------------===// @@ -314,6 +603,7 @@ struct ArgDecoding; // // template <> // struct AttrDecoding { +// using Type = // static std::optional Decode(XLA_FFI_AttrType type, void* attr, // DiagnosticEngine&); // } @@ -430,6 +720,7 @@ namespace internal { // attributes we decoded so far to compute call frame offsets. struct DecodingOffsets { int64_t args = 0; + int64_t rets = 0; int64_t attrs = 0; }; @@ -442,7 +733,7 @@ struct DecodingContext { template struct Decode { - XLA_ATTRIBUTE_ALWAYS_INLINE + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, DiagnosticEngine& diagnostic) { int64_t idx = offsets.args++; @@ -453,9 +744,22 @@ struct Decode { } // namespace internal +template +struct internal::Decode> { + static std::optional> call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + int64_t idx = offsets.rets++; + return RetDecoding::Decode(ctx.call_frame->rets.types[idx], + ctx.call_frame->rets.rets[idx], diagnostic); + } +}; + template struct internal::Decode> { - static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, + using R = typename AttrDecoding::Type; + + static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, DiagnosticEngine& diagnostic) { // Find decoded attribute corresponding to the given attribute index. int64_t i = offsets.attrs++; @@ -474,7 +778,10 @@ struct internal::Decode> { // Attribute name does not match. std::string_view attr_name_view = {attr_name->ptr, attr_name->len}; - if (attr_name_view != ctx.attrs_names[i]) return std::nullopt; + if (attr_name_view != ctx.attrs_names[i]) { + return diagnostic.Emit("Attribute name mismatch: ") + << attr_name_view << " vs " << ctx.attrs_names[i]; + } return AttrDecoding::Decode(attr_type, attr, diagnostic); } @@ -545,16 +852,16 @@ class RemainingArgs { public: RemainingArgs(const XLA_FFI_Args* args, size_t offset) : args_(args), offset_(offset) { - assert(offset <= args_->num_args && "illegal remaining args offset"); + assert(offset <= args_->size && "illegal remaining args offset"); } - size_t size() const { return args_->num_args - offset_; } + size_t size() const { return args_->size - offset_; } bool empty() const { return size() == 0; } template Expected get(size_t index) const { size_t idx = offset_ + index; - if (idx >= args_->num_args) { + if (idx >= args_->size) { return Unexpected("Index out of range."); } @@ -589,10 +896,10 @@ class Dictionary { public: explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {} - size_t size() const { return attrs_->num_attrs; } + size_t size() const { return attrs_->size; } bool contains(std::string_view name) const { - return Find(name) < attrs_->num_attrs; + return Find(name) < attrs_->size; } template @@ -609,7 +916,7 @@ class Dictionary { std::optional get(std::string_view name, DiagnosticEngine& diagnostic) const { size_t idx = Find(name); - if (idx >= attrs_->num_attrs) { + if (idx >= attrs_->size) { return diagnostic.Emit("Unexpected attribute: ") << name; } @@ -621,7 +928,7 @@ class Dictionary { private: size_t Find(std::string_view name) const { XLA_FFI_ByteSpan** begin = attrs_->names; - XLA_FFI_ByteSpan** end = begin + attrs_->num_attrs; + XLA_FFI_ByteSpan** end = begin + attrs_->size; auto name_eq = [&](XLA_FFI_ByteSpan* attr) { std::string_view name_view = {attr->ptr, attr->len}; @@ -668,21 +975,21 @@ struct FnArgType { using Type = T; }; -// Extracts the underlying type from the attribute type tag. -template -struct FnArgType> { - using Type = T; +template <> +struct FnArgType { + using Type = RemainingArgs; }; -// Extracts the underlying type from the context type tag. +// Extracts the underlying type from the returned result type tag. template -struct FnArgType> { - using Type = typename CtxDecoding::Type; +struct FnArgType> { + using Type = Result; }; -template <> -struct FnArgType { - using Type = RemainingArgs; +// Extracts the underlying type from the attribute type tag. +template +struct FnArgType> { + using Type = typename AttrDecoding::Type; }; template @@ -690,11 +997,19 @@ struct FnArgType> { using Type = T; }; +// Extracts the underlying type from the context type tag. +template +struct FnArgType> { + using Type = typename CtxDecoding::Type; +}; + // A template for checking if type in a parameter pack is a tagged one and has // a special decoding rule defined by template specialization. template struct IsTagged : std::false_type {}; template +struct IsTagged> : std::true_type {}; +template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; @@ -729,6 +1044,9 @@ class Handler : public Ffi { static constexpr int64_t kNumArgs = internal::NumArgs::value; + static constexpr int64_t kNumRets = + internal::NumTagged::value; + static constexpr int64_t kNumAttrs = internal::NumTagged::value; @@ -757,31 +1075,41 @@ class Handler : public Ffi { // Check that the number of passed arguments matches the signature. Each // individual argument decoding will check the actual type. if (internal::HasRemainingArgsTag::value) { - if (call_frame->args.num_args < kNumArgs) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected at least ", - kNumArgs - 1, " but got ", call_frame->args.num_args)); + kNumArgs - 1, " but got ", call_frame->args.size)); } } else { - if (call_frame->args.num_args != kNumArgs) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected ", kNumArgs, - " but got ", call_frame->args.num_args)); + " but got ", call_frame->args.size)); } } + // Check that the number of results matches the signature. Each individual + // result decoding will check the actual type. + if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of results: expected ", kNumRets, " but got ", + call_frame->rets.size)); + } + // Check that the number of passed attributes matches the signature. Each // individual attribute decoding will check the actual type. If we decode // attributes into a dictionary (or a custom struct decoded from a // dictionary), then there is no need to check attributes, as the FFI // handler (or a struct decoding) should be responsible for it. - if (kNumDictAttrs == 0 && call_frame->attrs.num_attrs != kNumAttrs) { + if (XLA_FFI_PREDICT_FALSE(kNumDictAttrs == 0 && + call_frame->attrs.size != kNumAttrs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of attributes: expected ", kNumAttrs, - " but got ", call_frame->attrs.num_attrs)); + " but got ", call_frame->attrs.size)); } // Define index sequences to access custom call operands. @@ -792,7 +1120,7 @@ class Handler : public Ffi { private: template - XLA_ATTRIBUTE_ALWAYS_INLINE XLA_FFI_Error* Call( + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE XLA_FFI_Error* Call( const XLA_FFI_CallFrame* call_frame, std::index_sequence) const { // A helper structure to allow each decoder find the correct offset. internal::DecodingOffsets offsets; @@ -807,7 +1135,7 @@ class Handler : public Ffi { internal::Decode::call(offsets, ctx, diagnostic)...}; bool all_decoded = (std::get(args).has_value() && ...); - if (!all_decoded) { + if (XLA_FFI_PREDICT_FALSE(!all_decoded)) { return FailedDecodeError(call_frame, {std::get(args).has_value()...}, diagnostic); } @@ -875,45 +1203,51 @@ class Handler : public Ffi { inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_AttrType type) { switch (type) { - case XLA_FFI_AttrType_I32: - return os << "int32"; - case XLA_FFI_AttrType_I64: - return os << "int64"; - case XLA_FFI_AttrType_F32: - return os << "float"; - case XLA_FFI_AttrType_STRING: - return os << "string"; + case XLA_FFI_AttrType_ARRAY: + return os << "array"; case XLA_FFI_AttrType_DICTIONARY: return os << "dictionary"; + case XLA_FFI_AttrType_SCALAR: + return os << "scalar"; + case XLA_FFI_AttrType_STRING: + return os << "string"; } } #define XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(T, TYPE) \ template <> \ struct AttrDecoding { \ + using Type = T; \ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ DiagnosticEngine& diagnostic) { \ - if (type != TYPE) { \ + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_SCALAR)) { \ return diagnostic.Emit("Wrong attribute type: expected ") \ - << TYPE << " but got " << type; \ + << XLA_FFI_AttrType_SCALAR << " but got " << type; \ + } \ + \ + auto* scalar = reinterpret_cast(attr); \ + if (XLA_FFI_PREDICT_FALSE(scalar->dtype != TYPE)) { \ + return diagnostic.Emit("Wrong scalar data type: expected ") \ + << TYPE << " but got " << scalar->dtype; \ } \ \ - return *reinterpret_cast(attr); \ + return *reinterpret_cast(scalar->value); \ } \ } -XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int32_t, XLA_FFI_AttrType_I32); -XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int64_t, XLA_FFI_AttrType_I64); -XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_AttrType_F32); +XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int32_t, XLA_FFI_DataType_S32); +XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int64_t, XLA_FFI_DataType_S64); +XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_DataType_F32); #undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING template <> struct AttrDecoding { + using Type = std::string_view; static std::optional Decode(XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { - if (type != XLA_FFI_AttrType_STRING) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { return diagnostic.Emit("Wrong attribute type: expected ") << XLA_FFI_AttrType_STRING << " but got " << type; } @@ -925,9 +1259,10 @@ struct AttrDecoding { template <> struct AttrDecoding { + using Type = Dictionary; static std::optional Decode(XLA_FFI_AttrType type, void* attr, DiagnosticEngine& diagnostic) { - if (type != XLA_FFI_AttrType_DICTIONARY) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Wrong attribute type: expected ") << XLA_FFI_AttrType_DICTIONARY << " but got " << type; } @@ -957,7 +1292,7 @@ template struct DecodeDictionaryAttr { static constexpr size_t kSize = sizeof...(Ts); - XLA_ATTRIBUTE_ALWAYS_INLINE + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode(const XLA_FFI_Attrs* attrs, std::array names, DiagnosticEngine& diagnostic) { @@ -965,12 +1300,12 @@ struct DecodeDictionaryAttr { } template - XLA_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode( + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode( const XLA_FFI_Attrs* attrs, std::array names, std::index_sequence, DiagnosticEngine& diagnostic) { - if (kSize != attrs->num_attrs) { + if (XLA_FFI_PREDICT_FALSE(kSize != attrs->size)) { return diagnostic.Emit("Wrong number of attributes: expected ") - << kSize << " attributes but got " << attrs->num_attrs; + << kSize << " attributes but got " << attrs->size; } // TODO(ezhulenev): We rely on dictionary to lookup struct members by name @@ -986,7 +1321,7 @@ struct DecodeDictionaryAttr { std::tuple...> members = { dict.get(names[Is], diagnostic)...}; bool all_decoded = (std::get(members).has_value() && ...); - if (!all_decoded) return std::nullopt; + if (XLA_FFI_PREDICT_FALSE(!all_decoded)) return std::nullopt; return T{std::move(*std::get(members))...}; } @@ -1013,22 +1348,31 @@ auto DictionaryDecoder(Members... m) { // StructMember("a"), // StructMember("b")); // -#define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ - template <> \ - struct AttrDecoding { \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine& diagnostic) { \ - if (type != XLA_FFI_AttrType_DICTIONARY) { \ - diagnostic.Emit("Wrong attribute type: expected ") \ - << XLA_FFI_AttrType_DICTIONARY << " but got " << type; \ - return std::nullopt; \ - } \ - \ - auto decoder = internal::DictionaryDecoder(__VA_ARGS__); \ - return decltype(decoder)::Decode( \ - reinterpret_cast(attr), \ - internal::StructMemberNames(__VA_ARGS__), diagnostic); \ - } \ +// Automatically registers attributes binding for a struct that allows automatic +// binding specification inference from a callable signature. +// +#define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ + template <> \ + struct AttrsBinding { \ + using Attrs = T; \ + }; \ + \ + template <> \ + struct AttrDecoding { \ + using Type = T; \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine& diagnostic) { \ + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { \ + diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_DICTIONARY << " but got " << type; \ + return std::nullopt; \ + } \ + \ + auto decoder = internal::DictionaryDecoder(__VA_ARGS__); \ + return decltype(decoder)::Decode( \ + reinterpret_cast(attr), \ + internal::StructMemberNames(__VA_ARGS__), diagnostic); \ + } \ } //===----------------------------------------------------------------------===// @@ -1043,23 +1387,44 @@ auto DictionaryDecoder(Members... m) { // Use captureless lambda to function pointer conversion to create a static // XLA_FFI_Handler function pointer variable. -#define XLA_FFI_DEFINE_HANDLER(fn, impl, binding) \ + +// Use explicit binding specification to create a handler. +#define XLA_FFI_DEFINE_HANDLER_EXPLICIT(fn, impl, binding) \ static constexpr XLA_FFI_Handler* fn = +[](XLA_FFI_CallFrame* call_frame) { \ static auto* handler = binding.To(impl).release(); \ return handler->Call(call_frame); \ } +// Automatically infer binding specification from the implementation. +#define XLA_FFI_DEFINE_HANDLER_AUTO(fn, impl) \ + static constexpr XLA_FFI_Handler* fn = +[](XLA_FFI_CallFrame* call_frame) { \ + static auto* handler = ::xla::ffi::Ffi::BindTo(impl).release(); \ + return handler->Call(call_frame); \ + } + +#define XLA_FFI_DEFINE_HANDLER_X(x, fn, impl, binding, FUNC, ...) FUNC + +// This is a trick to define macro with optional parameters. +// Source: https://stackoverflow.com/a/8814003 +#define XLA_FFI_DEFINE_HANDLER(fn, impl, ...) \ + XLA_FFI_DEFINE_HANDLER_X( \ + , fn, impl, ##__VA_ARGS__, \ + XLA_FFI_DEFINE_HANDLER_EXPLICIT(fn, impl, __VA_ARGS__), \ + XLA_FFI_DEFINE_HANDLER_AUTO(fn, impl)) + // TODO(ezhulenev): Add a callback so that end users can log registration error // to appropriate logging destination, e.g. LOG(FATAL) for duplicate internal // FFI handlers. -#define XLA_FFI_REGISTER_HANDLER(API, NAME, FUNC) \ - XLA_FFI_REGISTER_HANDLER_(API, NAME, FUNC, __COUNTER__) -#define XLA_FFI_REGISTER_HANDLER_(API, NAME, FUNC, N) \ - XLA_FFI_REGISTER_HANDLER__(API, NAME, FUNC, N) -#define XLA_FFI_REGISTER_HANDLER__(API, NAME, FUNC, N) \ - XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ - xla_ffi_static_handler_##N##_registered_ = [] { \ - return ::xla::ffi::Ffi::RegisterStaticHandler(API, NAME, FUNC); \ +#define XLA_FFI_REGISTER_HANDLER(API, NAME, PLATFORM, FUNC, ...) \ + XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, __COUNTER__, \ + ##__VA_ARGS__) +#define XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, N, ...) \ + XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N, ##__VA_ARGS__) +#define XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N, ...) \ + XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ + xla_ffi_static_handler_##N##_registered_ = [] { \ + return ::xla::ffi::Ffi::RegisterStaticHandler(API, NAME, PLATFORM, \ + FUNC, ##__VA_ARGS__); \ }() } // namespace xla::ffi diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index cd6eb0fe1e755..d4dc246fc6f8a 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,12 +20,16 @@ limitations under the License. #include // XLA FFI C API follows PJRT API style for consistency. See `pjrt_c_api.h`. +// More details on versioning strategy and example version checks: +// https://github.com/tensorflow/community/blob/master/rfcs/20200612-stream-executor-c-api/C_API_versioning_strategy.md // Every struct passed across the C API boundary has its size as a member, and // we use it as a sanity check for API compatibility. #define XLA_FFI_STRUCT_SIZE(struct_type, last_field) \ (offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field)) +// Must update XLA_FFI_DEFINE_STRUCT_TRAITS with the new `last_field` after +// adding a new member to a struct. #define XLA_FFI_DEFINE_STRUCT_TRAITS(sname, last_field) \ typedef struct sname sname; \ enum { sname##_STRUCT_SIZE = XLA_FFI_STRUCT_SIZE(sname, last_field) } @@ -118,6 +122,17 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Create_Args, errc); typedef XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args); +struct XLA_FFI_Error_GetMessage_Args { + size_t struct_size; + void* priv; + XLA_FFI_Error* error; + const char* message; // out +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_GetMessage_Args, message); + +typedef void XLA_FFI_Error_GetMessage(XLA_FFI_Error_GetMessage_Args* args); + struct XLA_FFI_Error_Destroy_Args { size_t struct_size; void* priv; @@ -172,16 +187,23 @@ typedef enum { XLA_FFI_ArgType_BUFFER = 1, } XLA_FFI_ArgType; +//===----------------------------------------------------------------------===// +// Builtin result types +//===----------------------------------------------------------------------===// + +typedef enum { + XLA_FFI_RetType_BUFFER = 1, +} XLA_FFI_RetType; + //===----------------------------------------------------------------------===// // Builtin attribute types //===----------------------------------------------------------------------===// typedef enum { - XLA_FFI_AttrType_I32 = 1, - XLA_FFI_AttrType_I64 = 2, - XLA_FFI_AttrType_F32 = 3, + XLA_FFI_AttrType_ARRAY = 1, + XLA_FFI_AttrType_DICTIONARY = 2, + XLA_FFI_AttrType_SCALAR = 3, XLA_FFI_AttrType_STRING = 4, - XLA_FFI_AttrType_DICTIONARY = 5, } XLA_FFI_AttrType; //===----------------------------------------------------------------------===// @@ -208,27 +230,61 @@ struct XLA_FFI_ByteSpan { XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_ByteSpan, len); +// A struct to pass a scalar value to FFI handler. +struct XLA_FFI_Scalar { + size_t struct_size; + void* priv; + + XLA_FFI_DataType dtype; + void* value; +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Scalar, value); + +// A struct to pass a dense array to FFI handler. +struct XLA_FFI_Array { + size_t struct_size; + void* priv; + + XLA_FFI_DataType dtype; + size_t size; + void* data; +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Array, data); + struct XLA_FFI_Args { size_t struct_size; void* priv; - int64_t num_args; - XLA_FFI_ArgType* types; // length == num_args - void** args; // length == num_args + int64_t size; + XLA_FFI_ArgType* types; // length == size + void** args; // length == size }; XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Args, args); +struct XLA_FFI_Rets { + size_t struct_size; + void* priv; + + int64_t size; + XLA_FFI_RetType* types; // length == size + void** rets; // length == size +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Rets, rets); + // FFI handler attributes are always sorted by name, so that the handler can // rely on binary search to look up attributes by name. struct XLA_FFI_Attrs { size_t struct_size; void* priv; - int64_t num_attrs; - XLA_FFI_AttrType* types; // length == num_attrs - XLA_FFI_ByteSpan** names; // length == num_attrs - void** attrs; // length == num_attrs + int64_t size; + XLA_FFI_AttrType* types; // length == size + XLA_FFI_ByteSpan** names; // length == size + void** attrs; // length == size }; XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Attrs, attrs); @@ -240,6 +296,7 @@ struct XLA_FFI_CallFrame { XLA_FFI_Api* api; XLA_FFI_ExecutionContext* ctx; XLA_FFI_Args args; + XLA_FFI_Rets rets; XLA_FFI_Attrs attrs; }; @@ -252,15 +309,26 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs); // External functions registered with XLA as FFI handlers. typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame); +enum XLA_FFI_Handler_TraitsBits { + // Calls to FFI handler are safe to trace into the command buffer. It means + // that calls to FFI handler always launch exactly the same device operations + // (can depend on attribute values) that can be captured and then replayed. + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0, +}; + +typedef uint32_t XLA_FFI_Handler_Traits; + struct XLA_FFI_Handler_Register_Args { size_t struct_size; void* priv; - const char* name; // null terminated + const char* name; // null terminated + const char* platform; // null terminated XLA_FFI_Handler* handler; + XLA_FFI_Handler_Traits traits; }; -XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, handler); +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, traits); typedef XLA_FFI_Error* XLA_FFI_Handler_Register( XLA_FFI_Handler_Register_Args* args); @@ -296,6 +364,7 @@ struct XLA_FFI_Api { XLA_FFI_InternalApi* internal_api; _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create); + _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get); @@ -303,7 +372,7 @@ struct XLA_FFI_Api { #undef _XLA_FFI_API_STRUCT_FIELD -XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api, XLA_FFI_Handler_Register); +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api, XLA_FFI_Stream_Get); #ifdef __cplusplus } diff --git a/xla/ffi/api/c_api_internal.h b/xla/ffi/api/c_api_internal.h index c8c9fc78a5c35..7bbb0441b8858 100644 --- a/xla/ffi/api/c_api_internal.h +++ b/xla/ffi/api/c_api_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_FFI_API_C_API_INTERNAL_H_ #define XLA_FFI_API_C_API_INTERNAL_H_ +#include + #include "xla/ffi/api/c_api.h" // Internal XLA FFI API that gives access to XLA implementation details that @@ -38,10 +40,28 @@ extern "C" { // Forwards `absl::Status` object pointed to by `status` to XLA FFI error // (status left in moved-from state). Pointer ownership stays with the // caller. -typedef XLA_FFI_Error* XLA_FFI_Error_Forward(void* status); +typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status); + +// Returns a pointer to main compute stream (`se::Stream` pointer). In +// contrast to public C API which returns a pointer to underlying platform +// stream (i.e. cudaStream_t for CUDA backend), this API returns a pointer to +// StreamExecutor stream which is unsafe to use across dynamic library boundary. +typedef void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx); + +// Returns the device ordinal of the device associated with the execution +// context. +typedef int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get( + XLA_FFI_ExecutionContext* ctx); + +// Returns a pointer to device memory allocator (`se::DeviceMemoryAllocator` +// pointer) which allows to allocate memory inside a custom call from the same +// allocator as XLA (i.e. it allows to construct scratch memory allocator). +typedef void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get( + XLA_FFI_ExecutionContext* ctx); -// Returns a pointer to `xla::ServiceExecutableRunOptions`. -typedef void* XLA_FFI_ServiceExecutableRunOptions_Get( +// Returns a pointer to `xla::HloComputation` if FFI handler has a called +// computation attached to it. +typedef void* XLA_FFI_INTERNAL_CalledComputation_Get( XLA_FFI_ExecutionContext* ctx); //===----------------------------------------------------------------------===// @@ -51,8 +71,12 @@ typedef void* XLA_FFI_ServiceExecutableRunOptions_Get( #define _XLA_FFI_INTERNAL_API_STRUCT_FIELD(fn_type) fn_type* fn_type struct XLA_FFI_InternalApi { - _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_Error_Forward); - _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_ServiceExecutableRunOptions_Get); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Error_Forward); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Stream_Get); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_DeviceOrdinal_Get); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD( + XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get); }; #undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index a3e4b1fbbdbde..b652d5accda0d 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_FFI_API_FFI_H_ #define XLA_FFI_API_FFI_H_ -#ifdef TENSORFLOW_COMPILER_XLA_FFI_FFI_H_ +#ifdef XLA_FFI_FFI_H_ #error Two different XLA FFI implementations cannot be included together -#endif // XLA_FFI_API_H_ +#endif // XLA_FFI_FFI_H_ #include #include @@ -57,8 +57,8 @@ enum class DataType : uint8_t { inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { static constexpr const char* kDataTypeNames[] = { - "PRED", "S8", "S16", "S32", "S64", "U8", "U16", - "U32", "U64", "F16", "F32", "F64", "BF16", + "INVALID", "PRED", "S8", "S16", "S32", "S64", "U8", + "U16", "U32", "U64", "F16", "F32", "F64", "BF16", }; return os << kDataTypeNames[static_cast(dtype)]; } @@ -116,6 +116,12 @@ class Error { // Arguments //===----------------------------------------------------------------------===// +struct BufferBase { + DataType dtype; + void* data; + Span dimensions; +}; + namespace internal { // A workaround for the fact that a static_assertion can be evaluated @@ -124,36 +130,106 @@ template struct always_false : std::false_type {}; template -struct PtrType { +struct DataTypeToNative { static_assert(always_false::value, "unsupported data type"); }; // clang-format off -template <> struct PtrType { using Type = bool; }; -template <> struct PtrType { using Type = uint8_t; }; -template <> struct PtrType { using Type = uint16_t; }; -template <> struct PtrType { using Type = uint32_t; }; -template <> struct PtrType { using Type = uint64_t; }; -template <> struct PtrType { using Type = int8_t; }; -template <> struct PtrType { using Type = int16_t; }; -template <> struct PtrType { using Type = int32_t; }; -template <> struct PtrType { using Type = int64_t; }; -template <> struct PtrType { using Type = uint16_t; }; -template <> struct PtrType { using Type = float; }; -template <> struct PtrType { using Type = double; }; -template <> struct PtrType { using Type = uint16_t; }; +template <> struct DataTypeToNative { using type = bool; }; +template <> struct DataTypeToNative { using type = uint8_t; }; +template <> struct DataTypeToNative { using type = uint16_t; }; +template <> struct DataTypeToNative { using type = uint32_t; }; +template <> struct DataTypeToNative { using type = uint64_t; }; +template <> struct DataTypeToNative { using type = int8_t; }; +template <> struct DataTypeToNative { using type = int16_t; }; +template <> struct DataTypeToNative { using type = int32_t; }; +template <> struct DataTypeToNative { using type = int64_t; }; +template <> struct DataTypeToNative { using type = uint16_t; }; +template <> struct DataTypeToNative { using type = float; }; +template <> struct DataTypeToNative { using type = double; }; +template <> struct DataTypeToNative { using type = uint16_t; }; // clang-format on inline constexpr size_t kDynamicRank = std::numeric_limits::max(); +template +using NativeType = typename DataTypeToNative::type; + } // namespace internal template -struct BufferBase { - typename internal::PtrType::Type* data; +struct Buffer { + internal::NativeType* data; Span dimensions; }; +// clang-format off +template using BufferR0 = Buffer; +template using BufferR1 = Buffer; +template using BufferR2 = Buffer; +template using BufferR3 = Buffer; +template using BufferR4 = Buffer; +// clang-format on + +namespace internal { + +inline BufferBase DecodeBuffer(XLA_FFI_Buffer* buf) { + return BufferBase{static_cast(buf->dtype), buf->data, + Span(buf->dims, buf->rank)}; +} + +template +std::optional> DecodeBuffer(XLA_FFI_Buffer* buf, + DiagnosticEngine& diagnostic) { + if (auto buf_dtype = static_cast(buf->dtype); + XLA_FFI_PREDICT_FALSE(buf_dtype != dtype)) { + return diagnostic.Emit("Wrong buffer dtype: expected ") + << dtype << " but got " << buf_dtype; + } + + if constexpr (rank != internal::kDynamicRank) { + if (XLA_FFI_PREDICT_FALSE(buf->rank != rank)) { + return diagnostic.Emit("Wrong buffer rank: expected ") + << rank << " but got " << buf->rank; + } + } + + Buffer buffer; + buffer.data = static_cast*>(buf->data); + buffer.dimensions = Span(buf->dims, buf->rank); + return buffer; +} + +} // namespace internal + +//===----------------------------------------------------------------------===// +// Arguments binding +//===----------------------------------------------------------------------===// + +template <> +struct ArgBinding { + using Arg = BufferBase; +}; + +template +struct ArgBinding> { + using Arg = Buffer; +}; + +//===----------------------------------------------------------------------===// +// Results binding +//===----------------------------------------------------------------------===// + +template <> +struct RetBinding> { + using Ret = BufferBase; +}; + +template +struct RetBinding>> { + using Ret = Buffer; +}; + //===----------------------------------------------------------------------===// // Arguments decoding //===----------------------------------------------------------------------===// @@ -165,31 +241,124 @@ inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_ArgType type) { } } +template <> +struct ArgDecoding { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional Decode(XLA_FFI_ArgType type, void* arg, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_ArgType_BUFFER)) { + return diagnostic.Emit("Wrong argument type: expected ") + << XLA_FFI_ArgType_BUFFER << " but got " << type; + } + return internal::DecodeBuffer(reinterpret_cast(arg)); + } +}; + template -struct ArgDecoding> { - XLA_ATTRIBUTE_ALWAYS_INLINE - static std::optional> Decode( +struct ArgDecoding> { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional> Decode( XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) { - if (type != XLA_FFI_ArgType_BUFFER) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_ArgType_BUFFER)) { return diagnostic.Emit("Wrong argument type: expected ") << XLA_FFI_ArgType_BUFFER << " but got " << type; } - auto* buf = reinterpret_cast(arg); - if (auto actual_dtype = static_cast(buf->dtype); - actual_dtype != dtype) { - return diagnostic.Emit("Wrong buffer dtype: expected ") - << dtype << " but got " << actual_dtype; + + return internal::DecodeBuffer( + reinterpret_cast(arg), diagnostic); + } +}; + +//===----------------------------------------------------------------------===// +// Results decoding +//===----------------------------------------------------------------------===// + +inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_RetType type) { + switch (type) { + case XLA_FFI_RetType_BUFFER: + return os << "buffer"; + } +} + +template <> +struct RetDecoding { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional> Decode( + XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; } - auto* data = - static_cast::Type*>(buf->data); - if constexpr (rank != internal::kDynamicRank) { - if (buf->rank != rank) { - diagnostic.Emit("Wrong buffer rank: expected ") - << rank << " but got " << buf->rank; - return std::nullopt; - } + return internal::DecodeBuffer(reinterpret_cast(ret)); + } +}; + +template +struct RetDecoding> { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional>> Decode( + XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; } - return BufferBase{data, Span(buf->dims, rank)}; + + return internal::DecodeBuffer( + reinterpret_cast(ret), diagnostic); + } +}; + +//===----------------------------------------------------------------------===// +// Attributes decoding +//===----------------------------------------------------------------------===// + +#define XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(T, TYPE) \ + template <> \ + struct AttrDecoding> { \ + using Type = Span; \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine& diagnostic) { \ + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_ARRAY)) { \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_ARRAY << " but got " << type; \ + } \ + \ + auto* array = reinterpret_cast(attr); \ + if (XLA_FFI_PREDICT_FALSE(array->dtype != TYPE)) { \ + return diagnostic.Emit("Wrong array data type: expected ") \ + << TYPE << " but got " << array->dtype; \ + } \ + \ + return Span(reinterpret_cast(array->data), array->size); \ + } \ + } + +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(int32_t, XLA_FFI_DataType_S32); +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(int64_t, XLA_FFI_DataType_S64); +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(float, XLA_FFI_DataType_F32); + +#undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING + +// A type tag to mark i64 attributes as pointers to `T`. +template +struct Pointer {}; + +template +struct AttrDecoding> { + using Type = T*; + + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + auto* scalar = reinterpret_cast(attr); + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_SCALAR || + scalar->dtype != XLA_FFI_DataType_S64)) { + return diagnostic.Emit("Wrong attribute type: ") + << "expected i64 scalar for passing pointer but got " << type; + } + + static_assert(sizeof(uintptr_t) == sizeof(int64_t)); + uintptr_t ptr = *reinterpret_cast(scalar->value); + return reinterpret_cast(ptr); } }; @@ -226,7 +395,7 @@ struct CtxDecoding> { static std::optional Decode(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, - DiagnosticEngine&) { + DiagnosticEngine& diagnostic) { XLA_FFI_Stream_Get_Args args; args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; args.priv = nullptr; @@ -234,6 +403,8 @@ struct CtxDecoding> { args.stream = nullptr; if (XLA_FFI_Error* error = api->XLA_FFI_Stream_Get(&args); error) { + diagnostic.Emit("Failed to get platform stream: ") + << GetErrorMessage(api, error); DestroyError(api, error); return std::nullopt; } @@ -241,14 +412,22 @@ struct CtxDecoding> { return reinterpret_cast(args.stream); } - // TODO(ezhulenev): We need to log error message somewhere, currently we - // silently destroy it. + static const char* GetErrorMessage(const XLA_FFI_Api* api, + XLA_FFI_Error* error) { + XLA_FFI_Error_GetMessage_Args args; + args.struct_size = XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE; + args.priv = nullptr; + args.error = error; + api->XLA_FFI_Error_GetMessage(&args); + return args.message; + } + static void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) { - XLA_FFI_Error_Destroy_Args destroy_args; - destroy_args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; - destroy_args.error = error; - api->XLA_FFI_Error_Destroy(&destroy_args); + XLA_FFI_Error_Destroy_Args args; + args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; + args.priv = nullptr; + args.error = error; + api->XLA_FFI_Error_Destroy(&args); } }; diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 5e081313d59be..b1dc769ca8de9 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,12 +20,12 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/status/status.h" #include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" #include "xla/stream_executor/device_memory.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -63,6 +63,24 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::BF16), encoded(DataType::BF16)); } +TEST(FfiTest, BufferBaseArgument) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg().To([&](auto buffer) { + EXPECT_EQ(buffer.data, storage.data()); + EXPECT_EQ(buffer.dimensions.size(), 2); + return Error::Success(); + }); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, BufferArgument) { std::vector storage(4, 0.0f); se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); @@ -72,7 +90,7 @@ TEST(FfiTest, BufferArgument) { auto call_frame = builder.Build(); auto handler = - Ffi::Bind().Arg>().To([&](auto buffer) { + Ffi::Bind().Arg>().To([&](auto buffer) { EXPECT_EQ(buffer.data, storage.data()); EXPECT_EQ(buffer.dimensions.size(), 2); return Error::Success(); @@ -82,51 +100,215 @@ TEST(FfiTest, BufferArgument) { TF_ASSERT_OK(status); } +TEST(FfiTest, BufferBaseResult) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = + Ffi::Bind().Ret().To([&](Result buffer) { + EXPECT_EQ(buffer->data, storage.data()); + EXPECT_EQ(buffer->dimensions.size(), 2); + return Error::Success(); + }); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, MissingBufferArgument) { CallFrameBuilder builder; auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); - EXPECT_THAT(status, StatusIs(tsl::error::INVALID_ARGUMENT, + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Wrong number of arguments"))); } TEST(FfiTest, WrongRankBufferArgument) { - std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(std::int32_t)); + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder; builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); EXPECT_THAT(status, - StatusIs(tsl::error::INVALID_ARGUMENT, + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Wrong buffer rank: expected 1 but got 2"))); } TEST(FfiTest, WrongTypeBufferArgument) { - std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(std::int32_t)); + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder; builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); EXPECT_THAT( status, - StatusIs(tsl::error::INVALID_ARGUMENT, - HasSubstr("Wrong buffer dtype: expected F64 but got S64"))); + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Wrong buffer dtype: expected F32 but got S32"))); +} + +TEST(FfiTest, AutoBinding) { + static constexpr char kI32[] = "i32"; + + auto handler = Ffi::BindTo(+[](BufferBase buffer, Attr foo) { + EXPECT_EQ(*foo, 42); + return Error::Success(); + }); + + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert(kI32, 42); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + +TEST(FfiTest, AutoBindingResult) { + auto handler = + Ffi::BindTo(+[](Result buffer) { return Error::Success(); }); + + CallFrameBuilder builder; + builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{}); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + +struct I32AndF32 { + int32_t i32; + float f32; +}; + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(I32AndF32, StructMember("i32"), + StructMember("f32")); + +TEST(FfiTest, AutoBindingStructs) { + auto handler = Ffi::BindTo(+[](I32AndF32 attrs) { + EXPECT_EQ(attrs.i32, 42); + EXPECT_EQ(attrs.f32, 42.0f); + return Error::Success(); + }); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("i32", 42); + attrs.Insert("f32", 42.0f); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + +TEST(FfiTest, AutoBindingDictionary) { + auto handler = Ffi::BindTo(+[](Dictionary attrs) { + EXPECT_EQ(*attrs.get("i32"), 42); + EXPECT_EQ(*attrs.get("f32"), 42.0f); + return Error::Success(); + }); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("i32", 42); + attrs.Insert("f32", 42.0f); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + +// Use opaque struct to define a platform stream type just like platform +// stream for GPU backend (e.g. `CUstream_st` and `cudaStream_t`). +struct TestStreamSt; +using TestStream = TestStreamSt*; + +template <> +struct CtxBinding { + using Ctx = PlatformStream; +}; + +TEST(FfiTest, BindingPlatformStreamInference) { + // We only check that it compiles. + (void)Ffi::BindTo(+[](TestStream stream) { return Error::Success(); }); +} + +TEST(FfiTest, ArrayAttr) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("arr", std::vector({1, 2, 3, 4})); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](Span arr) { + EXPECT_EQ(arr.size(), 4); + EXPECT_EQ(arr[0], 1); + EXPECT_EQ(arr[1], 2); + EXPECT_EQ(arr[2], 3); + EXPECT_EQ(arr[3], 4); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Attr>("arr").To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, PointerAttr) { + std::string foo = "foo"; + + // Test for convenience attr binding that casts i64 attribute to user-type + // pointers. It's up to the user to guarantee that pointer is valid. + auto ptr = reinterpret_cast(&foo); + static_assert(sizeof(ptr) == sizeof(int64_t)); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("ptr", static_cast(ptr)); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](const std::string* str) { + EXPECT_EQ(*str, "foo"); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Attr>("ptr").To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); } //===----------------------------------------------------------------------===// @@ -144,6 +326,51 @@ static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { return builder; } +//===----------------------------------------------------------------------===// +// BM_BufferBaseArgX1 +//===----------------------------------------------------------------------===// + +void BM_BufferBaseArgX1(benchmark::State& state) { + auto call_frame = WithBufferArgs(1).Build(); + + auto handler = Ffi::Bind().Arg().To([](auto buffer) { + benchmark::DoNotOptimize(buffer); + return Error::Success(); + }); + for (auto _ : state) { + CHECK_OK(Call(*handler, call_frame)); + } +} + +BENCHMARK(BM_BufferBaseArgX1); + +//===----------------------------------------------------------------------===// +// BM_BufferBaseArgX4 +//===----------------------------------------------------------------------===// + +void BM_BufferBaseArgX4(benchmark::State& state) { + auto call_frame = WithBufferArgs(4).Build(); + + auto handler = Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .To([](auto b0, auto b1, auto b2, auto b3) { + benchmark::DoNotOptimize(b0); + benchmark::DoNotOptimize(b1); + benchmark::DoNotOptimize(b2); + benchmark::DoNotOptimize(b3); + return Error::Success(); + }); + + for (auto _ : state) { + CHECK_OK(Call(*handler, call_frame)); + } +} + +BENCHMARK(BM_BufferBaseArgX4); + //===----------------------------------------------------------------------===// // BM_BufferArgX1 //===----------------------------------------------------------------------===// @@ -151,11 +378,10 @@ static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { void BM_BufferArgX1(benchmark::State& state) { auto call_frame = WithBufferArgs(1).Build(); - auto handler = - Ffi::Bind().Arg>().To([](auto buffer) { - benchmark::DoNotOptimize(buffer); - return Error::Success(); - }); + auto handler = Ffi::Bind().Arg>().To([](auto buffer) { + benchmark::DoNotOptimize(buffer); + return Error::Success(); + }); for (auto _ : state) { CHECK_OK(Call(*handler, call_frame)); } @@ -171,10 +397,10 @@ void BM_BufferArgX4(benchmark::State& state) { auto call_frame = WithBufferArgs(4).Build(); auto handler = Ffi::Bind() - .Arg>() - .Arg>() - .Arg>() - .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() .To([](auto b0, auto b1, auto b2, auto b3) { benchmark::DoNotOptimize(b0); benchmark::DoNotOptimize(b1); diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 3ddc3bc4971fb..4d064bfafcd06 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -84,13 +85,19 @@ void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory, args_.push_back(Buffer{memory, type, {dims.begin(), dims.end()}}); } +void CallFrameBuilder::AddBufferRet(se::DeviceMemoryBase memory, + PrimitiveType type, + absl::Span dims) { + rets_.push_back(Buffer{memory, type, {dims.begin(), dims.end()}}); +} + void CallFrameBuilder::AddAttributes(AttributesMap attrs) { for (auto& [name, attr] : attrs) { attrs_.try_emplace(std::move(name), std::move(attr)); } } -CallFrame CallFrameBuilder::Build() { return CallFrame(args_, attrs_); } +CallFrame CallFrameBuilder::Build() { return CallFrame(args_, rets_, attrs_); } CallFrameBuilder::CallFrameBuilder(CallFrameBuilder&&) = default; CallFrameBuilder& CallFrameBuilder::operator=(CallFrameBuilder&&) = default; @@ -125,6 +132,19 @@ struct CallFrame::Dictionary { std::unique_ptr attrs; }; +struct CallFrame::Array { + std::variant, std::vector, std::vector> + value; // XLA_FFI_Array::data + + XLA_FFI_Array array = {XLA_FFI_Array_STRUCT_SIZE, nullptr}; +}; + +struct CallFrame::Scalar { + std::variant value; // XLA_FFI_Scalar::value + + XLA_FFI_Scalar scalar = {XLA_FFI_Scalar_STRUCT_SIZE, nullptr}; +}; + struct CallFrame::String { std::string value; // XLA_FFI_ByteSpan::ptr @@ -151,6 +171,21 @@ struct CallFrame::Arguments { XLA_FFI_Args ffi_args = {XLA_FFI_Args_STRUCT_SIZE, nullptr}; }; +struct CallFrame::Results { + explicit Results(size_t size) { + results.reserve(size); + types.reserve(size); + rets.reserve(size); + } + + std::vector results; + + std::vector types; // XLA_FFI_Rets::types + std::vector rets; // XLA_FFI_Rets::rets + + XLA_FFI_Rets ffi_rets = {XLA_FFI_Rets_STRUCT_SIZE, nullptr}; +}; + struct CallFrame::Attributes { explicit Attributes(size_t size) { attributes.reserve(size); @@ -173,8 +208,11 @@ struct CallFrame::Attributes { //===----------------------------------------------------------------------===// CallFrame::CallFrame(absl::Span args, + absl::Span rets, const CallFrameBuilder::AttributesMap& attrs) - : arguments_(InitArgs(args)), attributes_(InitAttrs(attrs)) {} + : arguments_(InitArgs(args)), + results_(InitRets(rets)), + attributes_(InitAttrs(attrs)) {} XLA_FFI_CallFrame CallFrame::Build(XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx) { @@ -182,53 +220,60 @@ XLA_FFI_CallFrame CallFrame::Build(XLA_FFI_Api* api, call_frame.api = api; call_frame.ctx = ctx; call_frame.args = arguments_->ffi_args; + call_frame.rets = results_->ffi_rets; call_frame.attrs = attributes_->ffi_attrs; return call_frame; } CallFrame::~CallFrame() = default; +// We rely on casting to and from underlying integral type to convert from +// PrimitiveType to XLA FFI DataType, and for safety convert all unknown types +// to invalid type, otherwise we can accidentally cause UB. +static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::PRIMITIVE_TYPE_INVALID: + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::BF16: + return static_cast(primitive_type); + default: + DCHECK(false) << "Unsupported primitive type" << primitive_type; + return XLA_FFI_DataType_INVALID; + } +} + +CallFrame::Buffer CallFrame::ConvertBuffer( + const CallFrameBuilder::Buffer& buffer) { + Buffer result; + result.dims = buffer.dims; + result.buffer.data = const_cast(buffer.memory.opaque()); + result.buffer.dtype = ToDataType(buffer.type); + result.buffer.rank = result.dims.size(); + return result; +} + //===----------------------------------------------------------------------===// // Call frame arguments //===----------------------------------------------------------------------===// -/*static*/ std::unique_ptr CallFrame::InitArgs( +std::unique_ptr CallFrame::InitArgs( absl::Span bargs) { auto res = std::make_unique(bargs.size()); - // We rely on casting to and from underlying integral type to convert from - // PrimitiveType to XLA FFI DataType, and for safety convert all unknown types - // to invalid type, otherwise we can accidentally cause UB. - auto to_data_type = [](PrimitiveType primitive_type) { - switch (primitive_type) { - case PrimitiveType::PRIMITIVE_TYPE_INVALID: - case PrimitiveType::S8: - case PrimitiveType::S16: - case PrimitiveType::S32: - case PrimitiveType::S64: - case PrimitiveType::U8: - case PrimitiveType::U16: - case PrimitiveType::U32: - case PrimitiveType::U64: - case PrimitiveType::F16: - case PrimitiveType::F32: - case PrimitiveType::F64: - case PrimitiveType::BF16: - return static_cast(primitive_type); - default: - DCHECK(false) << "Unsupported primitive type" << primitive_type; - return XLA_FFI_DataType_INVALID; - } - }; - // Convert call frame builder arguments to call frame arguments. for (const CallFrameBuilder::Buffer& barg : bargs) { - Buffer buffer; - buffer.dims = barg.dims; - buffer.buffer.data = const_cast(barg.memory.opaque()); - buffer.buffer.dtype = to_data_type(barg.type); - buffer.buffer.rank = buffer.dims.size(); - res->arguments.push_back(std::move(buffer)); + res->arguments.push_back(ConvertBuffer(barg)); } // Fix up pointers in XLA FFI structs. @@ -244,13 +289,46 @@ CallFrame::~CallFrame() = default; // Finally initialize the XLA FFI struct. At this point all storage is // allocated and it's safe to grab a pointer to it. - res->ffi_args.num_args = res->arguments.size(); + res->ffi_args.size = res->arguments.size(); res->ffi_args.types = res->types.data(); res->ffi_args.args = res->args.data(); return res; } +//===----------------------------------------------------------------------===// +// Call frame results +//===----------------------------------------------------------------------===// + +std::unique_ptr CallFrame::InitRets( + absl::Span brets) { + auto res = std::make_unique(brets.size()); + + // Convert call frame builder arguments to call frame arguments. + for (const CallFrameBuilder::Buffer& bret : brets) { + res->results.push_back(ConvertBuffer(bret)); + } + + // Fix up pointers in XLA FFI structs. + for (CallFrame::Buffer& arg : res->results) { + arg.buffer.dims = arg.dims.data(); + } + + // Initialize vectors required for building XLA_FFI_Rets. + for (CallFrame::Buffer& ret : res->results) { + res->types.push_back(XLA_FFI_RetType_BUFFER); + res->rets.push_back(&ret.buffer); + } + + // Finally initialize the XLA FFI struct. At this point all storage is + // allocated and it's safe to grab a pointer to it. + res->ffi_rets.size = res->results.size(); + res->ffi_rets.types = res->types.data(); + res->ffi_rets.rets = res->rets.data(); + + return res; +} + //===----------------------------------------------------------------------===// // Call frame attributes //===----------------------------------------------------------------------===// @@ -258,9 +336,12 @@ CallFrame::~CallFrame() = default; // An std::visit overload set for converting CallFrameBuilder::Attribute to // CallFrame::Attribute. struct CallFrame::ConvertAttribute { - template - CallFrame::Attribute operator()(const T& value) { - return value; + CallFrame::Attribute operator()(const CallFrameBuilder::Array& array) { + return CallFrame::Array{array}; + } + + CallFrame::Attribute operator()(const CallFrameBuilder::Scalar& scalar) { + return CallFrame::Scalar{scalar}; } CallFrame::Attribute operator()(const std::string& str) { @@ -272,25 +353,58 @@ struct CallFrame::ConvertAttribute { } }; +template +static XLA_FFI_DataType GetDataType() { + if constexpr (std::is_same_v) { + return XLA_FFI_DataType_S32; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_S64; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_F32; + } else { + static_assert(sizeof(T) == 0, "unsupported FFI data type"); + } +} + // An std::visit overload set to fix up CallFrame::Attribute storage and // initialize XLA FFI structs with valid pointers into storage objects. struct CallFrame::FixupAttribute { - template - void operator()(T& value) {} + void operator()(CallFrame::Array& array) { + auto visitor = [&](auto& value) { + using T = typename std::remove_reference_t::value_type; + array.array.dtype = GetDataType(); + array.array.size = value.size(); + array.array.data = value.data(); + }; + std::visit(visitor, array.value); + } + + void operator()(CallFrame::Scalar& scalar) { + auto visitor = [&](auto& value) { + using T = std::remove_reference_t; + scalar.scalar.dtype = GetDataType(); + scalar.scalar.value = &value; + }; + std::visit(visitor, scalar.value); + } void operator()(CallFrame::String& str) { str.span.ptr = str.value.data(); str.span.len = str.value.size(); } + + void operator()(CallFrame::Dictionary&) {} }; // An std::visit overload set to get CallFrame::Attribute XLA FFI type. struct CallFrame::AttributeType { - XLA_FFI_AttrType operator()(int32_t&) { return XLA_FFI_AttrType_I32; } - - XLA_FFI_AttrType operator()(int64_t&) { return XLA_FFI_AttrType_I64; } + XLA_FFI_AttrType operator()(CallFrame::Array&) { + return XLA_FFI_AttrType_ARRAY; + } - XLA_FFI_AttrType operator()(float&) { return XLA_FFI_AttrType_F32; } + XLA_FFI_AttrType operator()(CallFrame::Scalar&) { + return XLA_FFI_AttrType_SCALAR; + } XLA_FFI_AttrType operator()(CallFrame::String&) { return XLA_FFI_AttrType_STRING; @@ -308,6 +422,10 @@ struct CallFrame::AttributeStorage { return &value; } + void* operator()(CallFrame::Array& array) { return &array.array; } + + void* operator()(CallFrame::Scalar& scalar) { return &scalar.scalar; } + void* operator()(CallFrame::String& str) { return &str.span; } void* operator()(CallFrame::Dictionary& dict) { @@ -346,7 +464,7 @@ struct CallFrame::AttributeStorage { // Finally initialize XLA FFI struct. At this point all storage is allocated // and it's safe to grab a pointer to it. - res->ffi_attrs.num_attrs = res->attributes.size(); + res->ffi_attrs.size = res->attributes.size(); res->ffi_attrs.names = res->names.data(); res->ffi_attrs.types = res->types.data(); res->ffi_attrs.attrs = res->attrs.data(); diff --git a/xla/ffi/call_frame.h b/xla/ffi/call_frame.h index f57a70caeecf2..d283f607fb84a 100644 --- a/xla/ffi/call_frame.h +++ b/xla/ffi/call_frame.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,11 +57,15 @@ class CallFrameBuilder { CallFrameBuilder(CallFrameBuilder&&); CallFrameBuilder& operator=(CallFrameBuilder&&); + using Scalar = std::variant; + using Array = std::variant, std::vector, + std::vector>; + // Declare implementation detail structs for call frame builder storage. struct Dictionary; // Attributes that do not support nested dictionaries. - using FlatAttribute = std::variant; + using FlatAttribute = std::variant; using FlatAttributesMap = absl::flat_hash_map; // Attributes that support arbitrary nesting. @@ -96,6 +100,9 @@ class CallFrameBuilder { void AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims); + void AddBufferRet(se::DeviceMemoryBase memory, PrimitiveType type, + absl::Span dims); + void AddAttributes(AttributesMap attrs); private: @@ -104,6 +111,7 @@ class CallFrameBuilder { struct Buffer; std::vector args_; + std::vector rets_; AttributesMap attrs_; }; @@ -123,24 +131,34 @@ class CallFrame { // Declare implementation detail structs for call frame storage. struct Arguments; + struct Array; struct Attributes; struct Buffer; struct Dictionary; struct NamedAttribute; + struct Results; + struct Scalar; struct String; - using Attribute = std::variant; + using Attribute = std::variant; CallFrame(absl::Span args, + absl::Span rets, const CallFrameBuilder::AttributesMap& attrs); static std::unique_ptr InitArgs( absl::Span args); + static std::unique_ptr InitRets( + absl::Span rets); + static std::unique_ptr InitAttrs( const CallFrameBuilder::AttributesMap& attrs); + static Buffer ConvertBuffer(const CallFrameBuilder::Buffer& buffer); + std::unique_ptr arguments_; + std::unique_ptr results_; std::unique_ptr attributes_; // Declare implementation detail structs to grant access to private members. diff --git a/xla/ffi/ffi.h b/xla/ffi/ffi.h index 83878d2b08595..e0d555e3d33f2 100644 --- a/xla/ffi/ffi.h +++ b/xla/ffi/ffi.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,13 @@ limitations under the License. #ifndef XLA_FFI_FFI_H_ #define XLA_FFI_FFI_H_ -#ifdef TENSORFLOW_COMPILER_XLA_FFI_API_FFI_H_ +#ifdef XLA_FFI_API_FFI_H_ #error Two different XLA FFI implementations cannot be included together #endif // XLA_FFI_API_FFI_H_ +#include #include +#include #include // IWYU pragma: begin_exports @@ -30,20 +32,29 @@ limitations under the License. #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/primitive_util.h" #include "xla/runtime/memref_view.h" -#include "xla/service/service_executable_run_options.h" #include "xla/status.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" namespace xla::ffi { +// Type tags to bind parameters passed via execution context to FFI handler. +struct Stream {}; // binds `se::Stream*` +struct ScratchAllocator {}; // binds `se::OwningScratchAllocator` +struct CalledComputation {}; // binds `HloComputation*` + //===----------------------------------------------------------------------===// // Arguments //===----------------------------------------------------------------------===// -struct Buffer { +struct BufferBase { PrimitiveType dtype; se::DeviceMemoryBase data; absl::Span dimensions; @@ -55,22 +66,199 @@ struct Buffer { } }; +namespace internal { + +inline constexpr size_t kDynamicRank = std::numeric_limits::max(); + +template +using NativeType = typename primitive_util::PrimitiveTypeToNative::type; + +} // namespace internal + +template +struct Buffer { + se::DeviceMemory> data; + absl::Span dimensions; +}; + +// clang-format off +template using BufferR0 = Buffer; +template using BufferR1 = Buffer; +template using BufferR2 = Buffer; +template using BufferR3 = Buffer; +template using BufferR4 = Buffer; +// clang-format on + +namespace internal { + +inline BufferBase DecodeBuffer(XLA_FFI_Buffer* buf) { + size_t size_bytes = primitive_util::ByteWidth(PrimitiveType(buf->dtype)); + for (int64_t i = 0; i < buf->rank; ++i) size_bytes *= buf->dims[i]; + + BufferBase buffer; + buffer.dtype = PrimitiveType(buf->dtype); + buffer.data = se::DeviceMemoryBase(buf->data, size_bytes); + buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank); + return buffer; +} + +template +std::optional> DecodeBuffer(XLA_FFI_Buffer* buf, + DiagnosticEngine& diagnostic) { + if (auto buf_dtype = PrimitiveType(buf->dtype); + XLA_FFI_PREDICT_FALSE(buf_dtype != dtype)) { + return diagnostic.Emit("Wrong buffer dtype: expected ") + << primitive_util::LowercasePrimitiveTypeName(dtype) << " but got " + << primitive_util::LowercasePrimitiveTypeName(buf_dtype); + } + + if constexpr (rank != internal::kDynamicRank) { + if (XLA_FFI_PREDICT_FALSE(buf->rank != rank)) { + return diagnostic.Emit("Wrong buffer rank: expected ") + << rank << " but got " << buf->rank; + } + } + + Buffer buffer; + buffer.data = + se::DeviceMemory>(se::DeviceMemoryBase(buf->data)); + buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank); + return buffer; +} + +} // namespace internal + +//===----------------------------------------------------------------------===// +// Arguments binding +//===----------------------------------------------------------------------===// + +template <> +struct ArgBinding { + using Arg = BufferBase; +}; + +template +struct ArgBinding> { + using Arg = Buffer; +}; + //===----------------------------------------------------------------------===// // Arguments decoding //===----------------------------------------------------------------------===// template <> -struct ArgDecoding { - static std::optional Decode(XLA_FFI_ArgType type, void* arg, - DiagnosticEngine&) { - if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; - auto* buf = reinterpret_cast(arg); - - Buffer buffer; - buffer.dtype = PrimitiveType(buf->dtype); - buffer.data = se::DeviceMemoryBase(buf->data); - buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank); - return buffer; +struct ArgDecoding { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional Decode(XLA_FFI_ArgType type, void* arg, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_ArgType_BUFFER)) { + return diagnostic.Emit("Wrong argument type: expected ") + << XLA_FFI_ArgType_BUFFER << " but got " << type; + } + + return internal::DecodeBuffer(reinterpret_cast(arg)); + } +}; + +template +struct ArgDecoding> { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional> Decode( + XLA_FFI_ArgType type, void* arg, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_ArgType_BUFFER)) { + return diagnostic.Emit("Wrong argument type: expected ") + << XLA_FFI_ArgType_BUFFER << " but got " << type; + } + + return internal::DecodeBuffer( + reinterpret_cast(arg), diagnostic); + } +}; + +//===----------------------------------------------------------------------===// +// Results decoding +//===----------------------------------------------------------------------===// + +template <> +struct RetDecoding { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional> Decode( + XLA_FFI_RetType type, void* arg, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; + } + return internal::DecodeBuffer(reinterpret_cast(arg)); + } +}; + +template +struct RetDecoding> { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional>> Decode( + XLA_FFI_RetType type, void* arg, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; + } + + return internal::DecodeBuffer( + reinterpret_cast(arg), diagnostic); + } +}; + +//===----------------------------------------------------------------------===// +// Attributes decoding +//===----------------------------------------------------------------------===// + +#define XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(T, TYPE) \ + template <> \ + struct AttrDecoding> { \ + using Type = absl::Span; \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine& diagnostic) { \ + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_ARRAY)) { \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_ARRAY << " but got " << type; \ + } \ + \ + auto* array = reinterpret_cast(attr); \ + if (XLA_FFI_PREDICT_FALSE(array->dtype != TYPE)) { \ + return diagnostic.Emit("Wrong array data type: expected ") \ + << TYPE << " but got " << array->dtype; \ + } \ + \ + return absl::Span(reinterpret_cast(array->data), \ + array->size); \ + } \ + } + +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(int32_t, XLA_FFI_DataType_S32); +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(int64_t, XLA_FFI_DataType_S64); +XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(float, XLA_FFI_DataType_F32); + +#undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING + +// A type tag to mark i64 attributes as pointers to `T`. +template +struct Pointer {}; + +template +struct AttrDecoding> { + using Type = T*; + + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + auto* scalar = reinterpret_cast(attr); + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_SCALAR || + scalar->dtype != XLA_FFI_DataType_S64)) { + return diagnostic.Emit("Wrong attribute type: ") + << "expected i64 scalar for passing pointer but got " << type; + } + + static_assert(sizeof(uintptr_t) == sizeof(int64_t)); + uintptr_t ptr = *reinterpret_cast(scalar->value); + return reinterpret_cast(ptr); } }; @@ -79,13 +267,43 @@ struct ArgDecoding { //===----------------------------------------------------------------------===// template <> -struct CtxDecoding { - using Type = const ServiceExecutableRunOptions*; +struct CtxDecoding { + using Type = se::Stream*; + + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { + void* ptr = api->internal_api->XLA_FFI_INTERNAL_Stream_Get(ctx); + return reinterpret_cast(ptr); + } +}; + +template <> +struct CtxDecoding { + using Type = se::OwningScratchAllocator<>; + + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { + int32_t device_ordinal = + api->internal_api->XLA_FFI_INTERNAL_DeviceOrdinal_Get(ctx); + void* device_allocator = + api->internal_api->XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(ctx); + + return se::OwningScratchAllocator<>( + device_ordinal, + reinterpret_cast(device_allocator)); + } +}; + +template <> +struct CtxDecoding { + using Type = const HloComputation*; static std::optional Decode(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, DiagnosticEngine&) { - void* ptr = api->internal_api->XLA_FFI_ServiceExecutableRunOptions_Get(ctx); + void* ptr = api->internal_api->XLA_FFI_INTERNAL_CalledComputation_Get(ctx); return reinterpret_cast(ptr); } }; @@ -97,7 +315,7 @@ struct CtxDecoding { template <> struct ResultEncoding { static XLA_FFI_Error* Encode(XLA_FFI_Api* api, Status status) { - return api->internal_api->XLA_FFI_Error_Forward(&status); + return api->internal_api->XLA_FFI_INTERNAL_Error_Forward(&status); } }; diff --git a/xla/ffi/ffi_api.cc b/xla/ffi/ffi_api.cc index 303ba5ea6cdf0..af8dbaa5d50b6 100644 --- a/xla/ffi/ffi_api.cc +++ b/xla/ffi/ffi_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,21 +16,22 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include +#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "xla/ffi/api/api.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/call_frame.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/service/service_executable_run_options.h" #include "xla/status.h" -#include "xla/statusor.h" -#include "tsl/platform/logging.h" //===----------------------------------------------------------------------===// // XLA FFI C structs definition @@ -42,12 +43,17 @@ struct XLA_FFI_Error { struct XLA_FFI_ExecutionContext { const xla::ServiceExecutableRunOptions* run_options; + const xla::HloComputation* called_computation; }; //===----------------------------------------------------------------------===// namespace xla::ffi { +bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) { + return traits & XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE; +} + //===----------------------------------------------------------------------===// // Calling XLA FFI handlers //===----------------------------------------------------------------------===// @@ -63,14 +69,16 @@ Status TakeStatus(XLA_FFI_Error* error) { } Status Call(Ffi& handler, CallFrame& call_frame, const CallOptions& options) { - XLA_FFI_ExecutionContext ctx = {options.run_options}; + XLA_FFI_ExecutionContext ctx = {options.run_options, + options.called_computation}; XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx); return TakeStatus(handler.Call(&ffi_call_frame)); } Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, const CallOptions& options) { - XLA_FFI_ExecutionContext ctx = {options.run_options}; + XLA_FFI_ExecutionContext ctx = {options.run_options, + options.called_computation}; XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx); return TakeStatus((*handler)(&ffi_call_frame)); } @@ -79,30 +87,52 @@ Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, // XLA FFI registry //===----------------------------------------------------------------------===// -// TODO(ezhulenev): We have to support platform-specific handler registration. -using HandlerRegistry = absl::flat_hash_map; +using HandlerKey = std::pair; +using HandlerRegistry = absl::flat_hash_map; + +static HandlerKey MakeHandlerKey(std::string_view name, + std::string_view platform) { + return std::make_pair(std::string(name), absl::AsciiStrToLower(platform)); +} static HandlerRegistry& GetHandlerRegistry() { static auto* registry = new HandlerRegistry(); return *registry; } -static Status RegisterHandler(std::string_view name, XLA_FFI_Handler* handler) { - auto emplaced = GetHandlerRegistry().try_emplace(std::string(name), handler); +static Status RegisterHandler(std::string_view name, std::string_view platform, + XLA_FFI_Handler* handler, + XLA_FFI_Handler_Traits traits) { + auto emplaced = GetHandlerRegistry().try_emplace( + MakeHandlerKey(name, platform), HandlerRegistration{handler, traits}); if (!emplaced.second) return absl::InvalidArgumentError( - absl::StrCat("Duplicate FFI handler registration for ", name)); + absl::StrCat("Duplicate FFI handler registration for ", name, + " on a platform ", platform)); return OkStatus(); } -StatusOr FindHandler(std::string_view name) { - auto it = GetHandlerRegistry().find(name); +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform) { + auto it = GetHandlerRegistry().find(MakeHandlerKey(name, platform)); if (it == GetHandlerRegistry().end()) - return absl::NotFoundError( - absl::StrCat("No FFI handler registered for ", name)); + return absl::NotFoundError(absl::StrCat("No FFI handler registered for ", + name, " on a platform ", platform)); return it->second; } +absl::flat_hash_map StaticRegisteredHandlers( + std::string_view platform) { + absl::flat_hash_map calls; + for (const auto& [metadata, handler] : GetHandlerRegistry()) { + if (absl::AsciiStrToLower(platform) == metadata.second) { + calls[metadata.first] = handler; + } + } + + return calls; +} + //===----------------------------------------------------------------------===// // XLA FFI Api Implementation //===----------------------------------------------------------------------===// @@ -183,6 +213,18 @@ static XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args) { return new XLA_FFI_Error{Status(ToStatusCode(args->errc), args->message)}; } +static void XLA_FFI_Error_GetMessage(XLA_FFI_Error_GetMessage_Args* args) { + Status struct_size_check = ActualStructSizeIsGreaterOrEqual( + "XLA_FFI_Error_GetMessage", XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE, + args->struct_size); + if (!struct_size_check.ok()) { + LOG(ERROR) << struct_size_check.message(); + } + // absl::Status owns error message in a std::string which guarantees that + // we'll get a null terminated string. + args->message = args->error->status.message().data(); +} + static void XLA_FFI_Error_Destroy(XLA_FFI_Error_Destroy_Args* args) { Status struct_size_check = ActualStructSizeIsGreaterOrEqual( "XLA_FFI_Error_Destroy", XLA_FFI_Error_Destroy_Args_STRUCT_SIZE, @@ -199,7 +241,9 @@ static XLA_FFI_Error* XLA_FFI_Handler_Register( "XLA_FFI_Handler_Register", XLA_FFI_Handler_Register_Args_STRUCT_SIZE, args->struct_size)); - if (auto status = RegisterHandler(args->name, args->handler); !status.ok()) { + if (auto status = RegisterHandler(args->name, args->platform, args->handler, + args->traits); + !status.ok()) { return new XLA_FFI_Error{std::move(status)}; } return nullptr; @@ -220,13 +264,27 @@ static XLA_FFI_Error* XLA_FFI_Stream_Get(XLA_FFI_Stream_Get_Args* args) { // XLA FFI Internal Api Implementation //===----------------------------------------------------------------------===// -static XLA_FFI_Error* XLA_FFI_Error_Forward(void* status) { +static XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status) { return new XLA_FFI_Error{std::move(*reinterpret_cast(status))}; } -static void* XLA_FFI_ServiceExecutableRunOptions_Get( +static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) { + return ctx->run_options->stream(); +} + +static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get( + XLA_FFI_ExecutionContext* ctx) { + return ctx->run_options->device_ordinal(); +} + +static void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get( + XLA_FFI_ExecutionContext* ctx) { + return ctx->run_options->allocator(); +} + +static void* XLA_FFI_INTERNAL_CalledComputation_Get( XLA_FFI_ExecutionContext* ctx) { - return const_cast(ctx->run_options); + return const_cast(ctx->called_computation); } //===----------------------------------------------------------------------===// @@ -234,8 +292,11 @@ static void* XLA_FFI_ServiceExecutableRunOptions_Get( //===----------------------------------------------------------------------===// static XLA_FFI_InternalApi internal_api = { - XLA_FFI_Error_Forward, - XLA_FFI_ServiceExecutableRunOptions_Get, + XLA_FFI_INTERNAL_Error_Forward, + XLA_FFI_INTERNAL_Stream_Get, + XLA_FFI_INTERNAL_DeviceOrdinal_Get, + XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get, + XLA_FFI_INTERNAL_CalledComputation_Get, }; static XLA_FFI_Api api = { @@ -245,6 +306,7 @@ static XLA_FFI_Api api = { &internal_api, XLA_FFI_Error_Create, // creates error + XLA_FFI_Error_GetMessage, // get error message XLA_FFI_Error_Destroy, // frees error XLA_FFI_Handler_Register, // registers handler XLA_FFI_Stream_Get, // returns platform specific stream diff --git a/xla/ffi/ffi_api.h b/xla/ffi/ffi_api.h index b127ed0b2b5b6..a90f19182ffc6 100644 --- a/xla/ffi/ffi_api.h +++ b/xla/ffi/ffi_api.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,17 @@ limitations under the License. #ifndef XLA_FFI_FFI_API_H_ #define XLA_FFI_FFI_API_H_ +#include #include +#include "absl/container/flat_hash_map.h" #include "xla/ffi/api/api.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/call_frame.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/service/service_executable_run_options.h" #include "xla/status.h" -#include "xla/statusor.h" namespace xla::ffi { @@ -42,6 +44,7 @@ namespace xla::ffi { struct CallOptions { const ServiceExecutableRunOptions* run_options = nullptr; + const HloComputation* called_computation = nullptr; }; // Takes ownership of the XLA FFI error and returns underlying status. Frees @@ -58,9 +61,21 @@ Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, // XLA FFI registry //===----------------------------------------------------------------------===// -// Returns registered FFI handler for a given name, or an error if it's not -// found in the static registry. -StatusOr FindHandler(std::string_view name); +struct HandlerRegistration { + XLA_FFI_Handler* handler = nullptr; + XLA_FFI_Handler_Traits traits = 0; +}; + +bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits); + +// Returns registered FFI handler for a given name and platform, or an error if +// it's not found in the static registry. +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform); + +// Returns all registered calls in the static registry for a given platform. +absl::flat_hash_map StaticRegisteredHandlers( + std::string_view platform); //===----------------------------------------------------------------------===// // XLA FFI Api Implementation diff --git a/xla/ffi/ffi_test.cc b/xla/ffi/ffi_test.cc index f40b09fc1eb18..56c49fc8063b7 100644 --- a/xla/ffi/ffi_test.cc +++ b/xla/ffi/ffi_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,28 +16,55 @@ limitations under the License. #include "xla/ffi/ffi.h" #include +#include #include #include #include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/types/span.h" #include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" namespace xla::ffi { +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::StatusIs; + TEST(FfiTest, StaticRegistration) { static constexpr auto* noop = +[] { return absl::OkStatus(); }; - XLA_FFI_DEFINE_HANDLER(NoOp, noop, Ffi::Bind()); - XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op", NoOp); + // Use explicit binding specification. + XLA_FFI_DEFINE_HANDLER(NoOp0, noop, Ffi::Bind()); + + // Automatically infer binding specification from function signature. + XLA_FFI_DEFINE_HANDLER(NoOp1, noop); + + XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op-0", "Host", NoOp0); + XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op-1", "Host", NoOp1, + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); + + auto handler0 = FindHandler("no-op-0", "Host"); + auto handler1 = FindHandler("no-op-1", "Host"); + + TF_ASSERT_OK(handler0.status()); + TF_ASSERT_OK(handler1.status()); + + ASSERT_EQ(handler0->traits, 0); + ASSERT_EQ(handler1->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); - auto handler = FindHandler("no-op"); - TF_ASSERT_OK(handler.status()); + EXPECT_THAT(StaticRegisteredHandlers("Host"), + UnorderedElementsAre(Pair("no-op-0", _), Pair("no-op-1", _))); } TEST(FfiTest, ForwardError) { @@ -52,8 +79,8 @@ TEST(FfiTest, WrongNumArgs) { builder.AddBufferArg(se::DeviceMemoryBase(nullptr), PrimitiveType::F32, {}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg().Arg().To( - [](Buffer, Buffer) { return absl::OkStatus(); }); + auto handler = Ffi::Bind().Arg().Arg().To( + [](BufferBase, BufferBase) { return absl::OkStatus(); }); auto status = Call(*handler, call_frame); @@ -107,6 +134,82 @@ TEST(FfiTest, BuiltinAttributes) { TF_ASSERT_OK(status); } +TEST(FfiTest, BuiltinAttributesAutoBinding) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("i32", 42); + attrs.Insert("f32", 42.0f); + attrs.Insert("str", "foo"); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + static constexpr char kI32[] = "i32"; + static constexpr char kF32[] = "f32"; + static constexpr char kStr[] = "str"; + + auto fn = [&](Attr i32, Attr f32, + Attr str) { + EXPECT_EQ(*i32, 42); + EXPECT_EQ(*f32, 42.0f); + EXPECT_EQ(*str, "foo"); + return absl::OkStatus(); + }; + + auto handler = Ffi::BindTo(fn); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + +TEST(FfiTest, ArrayAttr) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("arr", std::vector({1, 2, 3, 4})); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](absl::Span arr) { + EXPECT_EQ(arr.size(), 4); + EXPECT_EQ(arr[0], 1); + EXPECT_EQ(arr[1], 2); + EXPECT_EQ(arr[2], 3); + EXPECT_EQ(arr[3], 4); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Attr>("arr").To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, PointerAttr) { + std::string foo = "foo"; + + // Test for convenience attr binding that casts i64 attribute to user-type + // pointers. It's up to the user to guarantee that pointer is valid. + auto ptr = reinterpret_cast(&foo); + static_assert(sizeof(ptr) == sizeof(int64_t)); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("ptr", static_cast(ptr)); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](const std::string* str) { + EXPECT_EQ(*str, "foo"); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Attr>("ptr").To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, AttrsAsDictionary) { CallFrameBuilder::AttributesBuilder attrs; attrs.Insert("i32", 42); @@ -275,12 +378,21 @@ TEST(FfiTest, DecodingErrors) { auto status = Call(*handler, call_frame); - ASSERT_EQ( + EXPECT_TRUE(absl::StrContains( status.message(), - "Failed to decode all FFI handler operands (bad operands at: 0, 1, 3)"); + "Failed to decode all FFI handler operands (bad operands at: 0, 1, 3)")); + + EXPECT_TRUE(absl::StrContains( + status.message(), "Attribute name mismatch: i32 vs not_i32_should_fail")); + + EXPECT_TRUE(absl::StrContains( + status.message(), "Attribute name mismatch: i64 vs not_i64_should_fail")); + + EXPECT_TRUE(absl::StrContains( + status.message(), "Attribute name mismatch: str vs not_str_should_fail")); } -TEST(FfiTest, BufferArgument) { +TEST(FfiTest, BufferBaseArgument) { std::vector storage(4, 0.0f); se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); @@ -288,17 +400,86 @@ TEST(FfiTest, BufferArgument) { builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto fn = [&](Buffer buffer) { + auto fn = [&](BufferBase buffer) { EXPECT_EQ(buffer.dtype, PrimitiveType::F32); EXPECT_EQ(buffer.data.opaque(), storage.data()); EXPECT_EQ(buffer.dimensions.size(), 2); return absl::OkStatus(); }; - auto handler = Ffi::Bind().Arg().To(fn); + { // Test explicit binding signature declaration. + auto handler = Ffi::Bind().Arg().To(fn); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); + } + + { // Test inferring binding signature from a handler type. + auto handler = Ffi::BindTo(fn); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, TypedAndRankedBufferArgument) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto fn = [&](BufferR2 buffer) { + EXPECT_EQ(buffer.data.opaque(), storage.data()); + EXPECT_EQ(buffer.dimensions.size(), 2); + return absl::OkStatus(); + }; + + { // Test explicit binding signature declaration. + auto handler = Ffi::Bind().Arg>().To(fn); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); + } + + { // Test inferring binding signature from a handler type. + auto handler = Ffi::BindTo(fn); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, WrongRankBufferArgument) { + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg>().To( + [](auto) { return absl::OkStatus(); }); auto status = Call(*handler, call_frame); - TF_ASSERT_OK(status); + EXPECT_THAT(status, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Wrong buffer rank: expected 1 but got 2"))); +} + +TEST(FfiTest, WrongTypeBufferArgument) { + std::vector storage(4, 0.0); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = Ffi::Bind().Arg>().To( + [](auto) { return absl::OkStatus(); }); + auto status = Call(*handler, call_frame); + + EXPECT_THAT( + status, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Wrong buffer dtype: expected f32 but got s32"))); } TEST(FfiTest, RemainingArgs) { @@ -311,8 +492,8 @@ TEST(FfiTest, RemainingArgs) { auto fn = [&](RemainingArgs args) { EXPECT_EQ(args.size(), 1); - EXPECT_TRUE(args.get(0).has_value()); - EXPECT_FALSE(args.get(1).has_value()); + EXPECT_TRUE(args.get(0).has_value()); + EXPECT_FALSE(args.get(1).has_value()); return absl::OkStatus(); }; @@ -324,15 +505,18 @@ TEST(FfiTest, RemainingArgs) { TEST(FfiTest, RunOptionsCtx) { auto call_frame = CallFrameBuilder().Build(); - auto* expected = reinterpret_cast(0x01234567); + auto* expected = reinterpret_cast(0x01234567); + + ServiceExecutableRunOptions opts; + opts.mutable_run_options()->set_stream(expected); - auto fn = [&](const ServiceExecutableRunOptions* run_options) { + auto fn = [&](const se::Stream* run_options) { EXPECT_EQ(run_options, expected); return absl::OkStatus(); }; - auto handler = Ffi::Bind().Ctx().To(fn); - auto status = Call(*handler, call_frame, {expected}); + auto handler = Ffi::Bind().Ctx().To(fn); + auto status = Call(*handler, call_frame, {&opts}); TF_ASSERT_OK(status); } diff --git a/xla/frontend_attributes.cc b/xla/frontend_attributes.cc index 405aacbfd31bd..8831040f89c15 100644 --- a/xla/frontend_attributes.cc +++ b/xla/frontend_attributes.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/frontend_attributes.h b/xla/frontend_attributes.h index de2643beeebb5..f8f9a68ce30fe 100644 --- a/xla/frontend_attributes.h +++ b/xla/frontend_attributes.h @@ -1,6 +1,6 @@ #ifndef XLA_FRONTEND_ATTRIBUTES_H_ #define XLA_FRONTEND_ATTRIBUTES_H_ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/glob_lit_test.bzl b/xla/glob_lit_test.bzl deleted file mode 100644 index 9780db87b2984..0000000000000 --- a/xla/glob_lit_test.bzl +++ /dev/null @@ -1,138 +0,0 @@ -# Test definitions for Lit, the LLVM test runner. -# -# This is reusing the LLVM Lit test runner in the interim until the new build -# rules are upstreamed. -# TODO(b/136126535): remove this custom rule. -"""Lit runner globbing test -""" - -load("@bazel_skylib//lib:paths.bzl", "paths") - -# Default values used by the test runner. -_default_test_file_exts = ["mlir", ".pbtxt", ".td"] -_default_driver = "@llvm-project//mlir:run_lit.sh" -_default_size = "small" -_default_tags = [] - -# These are patterns which we should never match, for tests, subdirectories, or -# test input data files. -_ALWAYS_EXCLUDE = [ - "**/LICENSE.txt", - "**/README.txt", - "**/lit.local.cfg", - # Exclude input files that have spaces in their names, since bazel - # cannot cope with such "targets" in the srcs list. - "**/* *", - "**/* */**", -] - -def _run_lit_test( - name, - data, - size, - tags, - driver, # @unused - features, - exec_properties): - """Runs lit on all tests it can find in `data` under xla/. - - Note that, due to Bazel's hermetic builds, lit only sees the tests that - are included in the `data` parameter, regardless of what other tests might - exist in the directory searched. - - Args: - name: str, the name of the test, including extension. - data: [str], the data input to the test. - size: str, the size of the test. - tags: [str], tags to attach to the test. - driver: str, label of the driver shell script. - Note: use of a custom driver is not currently supported - and specifying a default driver will abort the tests. - features: [str], list of extra features to enable. - exec_properties: may enable things like remote execution. - """ - xla_root_dir = "xla/" - - # Disable tests on windows for now, to enable testing rest of all xla and mlir. - native.py_test( - name = name, - srcs = ["@llvm-project//llvm:lit"], - tags = tags + ["no_windows"], - args = [ - xla_root_dir + paths.basename(data[-1]) + " --config-prefix=runlit -v", - ] + features, - data = data + [ - "//xla:litfiles", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:not", - ], - size = size, - main = "lit.py", - exec_properties = exec_properties, - ) - -def glob_lit_tests( - name = None, - exclude = [], - test_file_exts = _default_test_file_exts, - default_size = _default_size, - size_override = {}, - data = [], - per_test_extra_data = {}, - default_tags = _default_tags, - tags_override = {}, - driver = _default_driver, - features = [], - exec_properties = {}): - """Creates all plausible Lit tests (and their inputs) under this directory. - - Args: - name: str, name of the test_suite rule to generate for running all tests. - exclude: [str], paths to exclude (for tests and inputs). - test_file_exts: [str], extensions for files that are tests. - default_size: str, the test size for targets not in "size_override". - size_override: {str: str}, sizes to use for specific tests. - data: [str], additional input data to the test. - per_test_extra_data: {str: [str]}, extra data to attach to a given file. - default_tags: [str], additional tags to attach to the test. - tags_override: {str: str}, tags to add to specific tests. - driver: str, label of the driver shell script. - Note: use of a custom driver is not currently supported - and specifying a default driver will abort the tests. - features: [str], list of extra features to enable. - exec_properties: a dictionary of properties to pass on. - """ - - # Ignore some patterns by default for tests and input data. - exclude = _ALWAYS_EXCLUDE + exclude - - tests = native.glob( - ["*." + ext for ext in test_file_exts], - exclude = exclude, - ) - - # Run tests individually such that errors can be attributed to a specific - # failure. - all_tests = [] - for curr_test in tests: - all_tests.append(curr_test + ".test") - - # Instantiate this test with updated parameters. - _run_lit_test( - name = curr_test + ".test", - data = data + [curr_test] + per_test_extra_data.get(curr_test, []), - size = size_override.get(curr_test, default_size), - tags = default_tags + tags_override.get(curr_test, []), - driver = driver, - features = features, - exec_properties = exec_properties, - ) - - # TODO: remove this check after making it a required param. - if name: - native.test_suite( - name = name, - tests = all_tests, - tags = ["manual"], - ) diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index ca64de895fad7..026a984878305 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -1,8 +1,8 @@ # Description: # XLA evaluator implementation. -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -82,8 +82,8 @@ cc_library( "@tsl//tsl/lib/core:bitmap", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index ef50a619263fe..25cb1126e24b2 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,8 +83,8 @@ limitations under the License. #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/types.h" @@ -96,9 +96,10 @@ namespace { using primitive_util::NativeTypeOf; template -StatusOr Compare(const Shape& shape, Comparison comparison, - LiteralSlice lhs_literal, LiteralSlice rhs_literal) { - auto populate = [&](auto compare_op) -> StatusOr { +absl::StatusOr Compare(const Shape& shape, Comparison comparison, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + auto populate = [&](auto compare_op) -> absl::StatusOr { Literal result(shape); TF_RETURN_IF_ERROR(result.PopulateParallel( [&](absl::Span multi_index, int /*thread_id*/) { @@ -147,7 +148,7 @@ StatusOr Compare(const Shape& shape, Comparison comparison, std::optional GetInstructionStaticValueAsBool( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { return static_value->GetFirstElement(); @@ -251,7 +252,7 @@ struct DynamicOrStaticInteger { std::optional GetInstructionValueAsInteger( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { if (instruction->shape().element_type() == PrimitiveType::PRED) { @@ -859,7 +860,7 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) }); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -920,7 +921,7 @@ StatusOr HloEvaluator::Evaluate( return result.Clone(); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); @@ -955,7 +956,7 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, return true; } -StatusOr HloEvaluator::EvaluateWithSubstitutions( +absl::StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions) { @@ -983,7 +984,7 @@ StatusOr HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -998,7 +999,7 @@ StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs, const Literal& ehs) { std::unique_ptr lhs_instr = @@ -1016,7 +1017,7 @@ StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( return Evaluate(cloned_instruction.get()); } -StatusOr HloEvaluator::EvaluateElementwiseCompareOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseCompareOp( ComparisonDirection direction, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -1032,7 +1033,7 @@ StatusOr HloEvaluator::EvaluateElementwiseCompareOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = HloInstruction::CreateConstant(operand.Clone()); @@ -1046,7 +1047,7 @@ StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr HloEvaluator::EvaluateDotOp( +absl::StatusOr HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { @@ -1189,7 +1190,7 @@ Status HloEvaluator::EvaluateInternal( } if (!tuple_points_to_analysis_cache_) { HloModule* module = instruction->GetModule(); - StatusOr> + absl::StatusOr> tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); if (tuple_points_to_analysis.ok()) { tuple_points_to_analysis_cache_ = @@ -2347,7 +2348,7 @@ class OutputBatchIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2467,7 +2468,7 @@ class OutputOffsetIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); return absl::Span(input_index_); @@ -2507,9 +2508,9 @@ class OutputOffsetIndexToInputIndex { // Reshapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if // there is one) to `reshaped_start_indices`. -static StatusOr> ReshapedGatherIndices( - int64_t index_vector_dim, const Literal& start_indices, - Literal* reshaped_start_indices) { +static absl::StatusOr> +ReshapedGatherIndices(int64_t index_vector_dim, const Literal& start_indices, + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -2574,7 +2575,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { auto gather_inner_loop_body = [&](absl::Span output_window_index, absl::Span input_gather_index, - absl::Span output_gather_index) -> StatusOr { + absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, output_offset_index_to_input_index(output_window_index)); @@ -2608,7 +2610,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { }; auto gather_outer_loop_body = - [&](absl::Span output_gather_index) -> StatusOr { + [&](absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( @@ -2628,7 +2631,7 @@ namespace { // Reshapes the scatter indices input to have a trailing degenerate `1` // dimension if necessary. Hands over the ownership of the newly created // literal (if there is one) to `reshaped_indices`. -StatusOr> ReshapedScatterIndices( +absl::StatusOr> ReshapedScatterIndices( int64_t index_vector_dim, const Literal& indices, Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { @@ -2750,7 +2753,7 @@ class UpdateScatterIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2873,7 +2876,7 @@ class UpdateWindowIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); return absl::Span(input_index_); @@ -2966,7 +2969,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { auto scatter_inner_loop_body = [&](absl::Span update_window_index, absl::Span input_scatter_index, - absl::Span update_scatter_index) -> StatusOr { + absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); @@ -3018,7 +3022,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { }; auto scatter_outer_loop_body = - [&](absl::Span update_scatter_index) -> StatusOr { + [&](absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); @@ -3115,13 +3120,14 @@ Status HloEvaluator::HandleAsyncStart(const HloInstruction* async_start) { arg_literals.push_back(&arg_literal); } - HloEvaluator embedded_evaluator; - embedded_evaluator.set_dynamic_dimension_inference( + std::unique_ptr embedded_evaluator = + CreateEmbedded(max_loop_iterations_); + embedded_evaluator->set_dynamic_dimension_inference( dynamic_dimension_inference_); TF_ASSIGN_OR_RETURN( Literal result, - embedded_evaluator.Evaluate(*async_start->async_wrapped_computation(), - arg_literals)); + embedded_evaluator->Evaluate(*async_start->async_wrapped_computation(), + arg_literals)); evaluated_[async_start] = Literal(async_start->shape()); // Copy the operand values to the index {0, i} of the output. @@ -3415,10 +3421,10 @@ Status HloEvaluator::HandleSelect(const HloInstruction* select) { namespace { -StatusOr CreateScalarLiteral(int64_t value, - PrimitiveType element_type) { +absl::StatusOr CreateScalarLiteral(int64_t value, + PrimitiveType element_type) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { return LiteralUtil::CreateR0( static_cast>(value)); @@ -3431,7 +3437,7 @@ StatusOr CreateScalarLiteral(int64_t value, // Parses the while loop if it matches one of the known patterns. Returns the // value of the loop induction variable after the loop execution if the loop is // static. -StatusOr TryParseAndEvaluateWhileInductionVar( +absl::StatusOr TryParseAndEvaluateWhileInductionVar( const HloInstruction* while_hlo) { std::optional parsed_while_loop = PatternMatchParseWhileLoop(while_hlo); @@ -3506,7 +3512,7 @@ Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) { dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - StatusOr result = + absl::StatusOr result = TryParseAndEvaluateWhileInductionVar(while_hlo); if (result.ok()) { lcv = std::move(result).value(); @@ -3545,11 +3551,11 @@ Literal ExtractLiteralFromIndexPositions(const Literal& from, return LiteralUtil::CreateR1(values); } -StatusOr ExtractFromIndexPositions(const Literal& from, - absl::Span indices) { +absl::StatusOr ExtractFromIndexPositions( + const Literal& from, absl::Span indices) { PrimitiveType type = from.shape().element_type(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { return ExtractLiteralFromIndexPositions< NativeTypeOf>(from, indices); @@ -3608,9 +3614,9 @@ void IterateThroughWindow( } template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { bool is_negative = static_cast(Eigen::numext::signbit(operand)); @@ -3672,9 +3678,9 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Converts from primitive types to native types. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return StochasticConvertOp< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative::type, @@ -3684,11 +3690,11 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Evaluates all possible paths of converting to different integers. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsSignedIntegralType( primitive_type_constant)) { return StochasticConvertOp StochasticConvertOp(const Literal& operand_literal, result_shape.element_type()); } -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsFloatingPointType( primitive_type_constant)) { return StochasticConvertOp< @@ -3924,9 +3930,9 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { << " accessing increment of size " << increment.size(); increment[sort_dim] = sort_dim_elements; - auto comparator = [sort](absl::Span literals_to_sort, - int64_t a, int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto comparator = + [sort](absl::Span literals_to_sort, int64_t a, int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { absl::InlinedVector literals; literals.reserve(2 * sort->operand_count()); for (int64_t i = 0; i < sort->operand_count(); ++i) { @@ -3947,10 +3953,10 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { embedded_evaluator->ResetVisitStates(); return computed_result.Get({}); }; - auto less_than = [&comparator]( - absl::Span literals_to_sort, int64_t a, - int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto less_than = + [&comparator](absl::Span literals_to_sort, int64_t a, + int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(bool a_is_smaller, comparator(literals_to_sort, a, b, embedded_evaluator)); #ifndef NDEBUG @@ -4100,7 +4106,7 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( key_shape, zero_base, key_shape.dimensions(), increment, - [&](absl::Span indices) -> StatusOr { + [&](absl::Span indices) -> absl::StatusOr { // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); @@ -4185,7 +4191,7 @@ static bool IsScalarAdd(HloComputation* computation) { // the user-provided computation on the accumulator and the output element // (until the reduction is completed, the output element is also used as // an accumulator). -static StatusOr PerformReductionStep( +static absl::StatusOr PerformReductionStep( bool is_tuple, absl::Span input_index, absl::Span output_index, absl::Span input_args, absl::Span results, @@ -4235,7 +4241,7 @@ static StatusOr PerformReductionStep( return true; } -static StatusOr GenerateReduceOutputElement( +static absl::StatusOr GenerateReduceOutputElement( bool is_tuple, absl::Span output_index, absl::Span init_values, @@ -4278,7 +4284,7 @@ static StatusOr GenerateReduceOutputElement( // Periodically compute partial sum to avoid linear_indices getting // large computed_result += *input_arg0->GetSumAsDouble( - absl::MakeConstSpan(&linear_indices[0], n_linear_indices)); + absl::MakeConstSpan(linear_indices, n_linear_indices)); n_linear_indices = 0; } return true; @@ -4288,7 +4294,7 @@ static StatusOr GenerateReduceOutputElement( if (n_linear_indices > 0) { // Add in sum over any final indices collected computed_result += *input_arg0->GetSumAsDouble( - absl::MakeConstSpan(&linear_indices[0], n_linear_indices)); + absl::MakeConstSpan(linear_indices, n_linear_indices)); } TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result)); return true; diff --git a/xla/hlo/evaluator/hlo_evaluator.h b/xla/hlo/evaluator/hlo_evaluator.h index f2095fa1d756c..ed6accfacf96e 100644 --- a/xla/hlo/evaluator/hlo_evaluator.h +++ b/xla/hlo/evaluator/hlo_evaluator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -89,7 +89,9 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // instance of the subclass instead. virtual std::unique_ptr CreateEmbedded( int64_t max_loop_iterations) { - return std::make_unique(max_loop_iterations); + auto result = std::make_unique(max_loop_iterations); + result->set_custom_call_handler(custom_call_handler_); + return result; } // Enables subclasses to be notified when a new computation is being @@ -105,13 +107,13 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate( + const HloModule& module, absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } template - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } @@ -134,11 +136,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); + absl::StatusOr Evaluate( + const HloComputation& computation, + absl::Span arg_literals); template - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) { std::vector arg_literal_ptrs; for (const auto& l : arg_literals) { arg_literal_ptrs.push_back(&l); @@ -152,7 +155,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // within its parent computation until it encounters something that cannot be // evaluated, such as an Infeed or a Parameter instruction. // It makes best effort to partially evaluate a dependency if possible. - StatusOr Evaluate( + absl::StatusOr Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands = false); @@ -166,30 +169,29 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr EvaluateWithSubstitutions( + absl::StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions); - StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, - const Literal& operand); + absl::StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs, - const Literal& ehs); + absl::StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs, + const Literal& ehs); - StatusOr EvaluateElementwiseCompareOp(ComparisonDirection direction, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseCompareOp( + ComparisonDirection direction, const Literal& lhs, const Literal& rhs); - StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, - const Literal& lhs, const Literal& rhs); + absl::StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); void set_dynamic_dimension_inference( DynamicDimensionInference* dynamic_dimension_inference) { @@ -206,7 +208,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // Handles evaluation of a custom-call op. // Operand literals are provided in |operands| and implementations must // populate |output| before returning. - using CustomCallHandler = std::function( + using CustomCallHandler = std::function( const HloInstruction* custom_call, absl::Span operands)>; // Sets a handler that is called during evaluation for custom-call ops. @@ -434,7 +436,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { private: template - static StatusOr ElementWiseUnaryOpImpl( + static absl::StatusOr ElementWiseUnaryOpImpl( const HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { diff --git a/xla/hlo/evaluator/hlo_evaluator_test.cc b/xla/hlo/evaluator/hlo_evaluator_test.cc index a831f51f4a91a..ef3cb8ad8cee9 100644 --- a/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -77,7 +77,7 @@ class HloEvaluatorTest : public HloTestBase { public: HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } - StatusOr Evaluate( + absl::StatusOr Evaluate( absl::Span arg_literals = {}) { if (use_bfloat16_) { HloElementTypeConverter(F32, BF16).Run(m_.get()).value(); @@ -155,7 +155,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate(instruction); + absl::StatusOr result = evaluator_.Evaluate(instruction); EXPECT_TRUE(!result.ok()); } @@ -170,7 +170,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestRecursiveEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate( + absl::StatusOr result = evaluator_.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); EXPECT_TRUE(!result.ok()); } @@ -4560,7 +4560,7 @@ TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) { HloEvaluator evaluator; evaluator.set_custom_call_handler([](const HloInstruction* custom_call, absl::Span operands) { - return InternalError("Test error"); + return Internal("Test error"); }); EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(), ::tsl::error::INTERNAL); @@ -4605,6 +4605,30 @@ TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); } +TEST_F(HloEvaluatorTest, EvaluateCustomCallInFusion) { + const absl::string_view hlo_text = R"( +fusion1 { + p = f32[] parameter(0) + ROOT c = f32[] custom-call(p), custom_call_target="__cchandler1" +} + +ENTRY e { + p = f32[] parameter(0) + ROOT f = f32[] fusion(p), kind=kCustom, calls=fusion1 +})"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto input = LiteralUtil::CreateR0(0); + HloEvaluator evaluator; + evaluator.set_custom_call_handler([](const HloInstruction* custom_call, + absl::Span operands) { + return LiteralUtil::CreateR0(1 - + operands[0]->GetFirstElement()); + }); + TF_ASSERT_OK_AND_ASSIGN(auto output, evaluator.Evaluate(*m_, {&input})); + EXPECT_EQ(output, LiteralUtil::CreateR0(1)); +} + TEST_F(HloEvaluatorTest, IsFiniteF16) { const absl::string_view hlo_text = R"( HloModule test @@ -4843,6 +4867,23 @@ TEST_F(HloEvaluatorTest, SortC64) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(HloEvaluatorTest, ConvertC128ToC64) { + const absl::string_view hlo_text = R"( + HloModule m + + ENTRY main { + c = c128[3] constant({(2, 0), (4, 0), (6, 0)}) + ROOT sort = c64[3]{0} convert(c) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal expected = + LiteralUtil::CreateR1>({2.f, 4.f, 6.f}); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + // Tests that HloEvaluator can evaluate an instruction even when its operands // are not constant. TEST_F(HloEvaluatorTest, RecursivelyEvaluateNonConstantOperands) { diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 63b480253caf1..68b79d25bf5d2 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -245,6 +245,18 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return UnsupportedTypeError(ceil); } + Status HandleErf(const HloInstruction* erf) override { + if constexpr (!is_complex_v) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[erf], + ElementWiseUnaryOp(erf, [](ElementwiseT elem_operand) { + return std::erf(elem_operand); + })); + return OkStatus(); + } + return UnsupportedTypeError(erf); + } + Status HandleExp(const HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { @@ -511,15 +523,13 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { parent_->evaluated_[power], ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - // Case 0: 1^x = 1 - if (lhs_el == ElementwiseT(1)) { - return static_cast(1); - } - // Case 1: 0^0 = 1 - if (lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)) { + // Case 0: 1^x = 1 and x^0 = 1, regardless of X, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + if (lhs_el == ElementwiseT(1) || rhs_el == ElementwiseT(0)) { return static_cast(1); } - // Case 2: + // Case 1: // 1. inf^(a + 0i) = inf, if a > 0. // 2. inf^(a + 0i) = 0, if a < 0. if constexpr (is_complex_v) { @@ -539,7 +549,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return static_cast(0); } } - // Case 3: + // Case 2: // Fallback to pow. if constexpr (std::is_same_v) { return lhs_el || !rhs_el; @@ -1595,7 +1605,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } private: - StatusOr ElementWiseUnaryOp( + absl::StatusOr ElementWiseUnaryOp( const HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -1608,7 +1618,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr ElementWiseBinaryOp( + absl::StatusOr ElementWiseBinaryOp( const HloInstruction* instruction, const std::function& binary_op) { @@ -1633,7 +1643,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } template - StatusOr ElementwiseTernaryOp( + absl::StatusOr ElementwiseTernaryOp( const HloInstruction* instruction, const std::function& ternary_op) { const auto& shape = instruction->shape(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc index c7e7fa4a7273d..09859c69bb874 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bfloat16.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bool.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bool.cc index 6e5dccfb85eb6..f5ea9853d3ee5 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bool.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_bool.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc index 0faf2f0ae40b5..d113a6e30c626 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex128.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc index 0826f1903c7bc..9a55017daff5c 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_complex64.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_double.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_double.cc index 468c2ebbe8dd3..25c26f1a6ea33 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_double.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_double.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float.cc index 44d90adfdd7fb..c60fa62ebfb95 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 18a7b2c576000..b2cd8eb87292e 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc index cebf51586379c..bb34e07c4d5e6 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_half.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc index 72ee798da9e21..fb102fdebdca0 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc index 29d342c862d52..dd0889da2efaa 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int4.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int4.cc index 45d2787af86de..c7e1391d24501 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int4.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int4.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int64.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int64.cc index a08d2380fc45a..fc75f05fa71f2 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int64.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int64.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc index f2252507a7f1f..5b23fb6c48911 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc index a5b0c367fd46e..d01feccd69dba 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc index dad7f93558782..e919ef81fb1db 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint64.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint64.cc index 0afce9402c4ef..44a3a307d1bb7 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint64.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint64.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc index 666db2aa3a1d8..79cebf522cfab 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/experimental/auto_reorder/BUILD b/xla/hlo/experimental/auto_reorder/BUILD new file mode 100644 index 0000000000000..dd3094b724153 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/BUILD @@ -0,0 +1,189 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//xla:xla.bzl", "xla_cc_test","xla_cc_binary") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_libtpu_portable") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load( + "@local_config_cuda//cuda:build_defs.bzl", + "if_cuda", +) +load( + "@tsl//tsl/platform:build_config.bzl", + "tf_proto_library", +) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], +) + +package_group( + name = "friends", + packages= [ + # "//xla/hlo/experimental/auto_reorder", + "//xla/service/gpu/...", + "//xla/hlo/utils/...", + ], + +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) +cc_library( + name = "auto_reorder_solver", + srcs = ["auto_reorder_solver.cc", + ], + hdrs = [ + "auto_reorder_solver.h", + ], + deps = [ + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/hlo/utils:common_ortools_deps", + "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@tsl//tsl/platform:hash", + "@tsl//tsl/platform:types", + ] +) +# All header files that are used in the build must be declared in +# the hdrs or srcs of cc_* rules. +# This is enforced. + +cc_library( + name = "auto_reorder", + srcs = [ + "auto_reorder.cc", + ], + hdrs = [ + "auto_reorder.h", + "auto_reorder_solver.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:hlo_pass", + "//xla/service:hlo_cost_analysis", + "//xla/service:latency_hiding_scheduler", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:analytical_latency_estimator", + "//xla/service:backend", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", + ":auto_reorder_solver" + ], +) +tf_proto_library( + name = "instr_profile_info_proto", + srcs = ["instr_profile_info.proto"], +) + +cc_library( + name="convert_xplane", + srcs=["convert_xplane.cc"], + hdrs=["convert_xplane.h"], + deps=[ + "//xla:status", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla:shape_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:types", + "@tsl//tsl/profiler/convert:xla_op_utils", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@tsl//tsl/profiler/utils:file_system_utils", + "@tsl//tsl/profiler/utils:tf_xplane_visitor", + "@tsl//tsl/profiler/utils:xplane_schema", + "@tsl//tsl/profiler/utils:xplane_utils", + "@tsl//tsl/profiler/utils:xplane_visitor", + "@com_google_protobuf//:protobuf", + ":instr_profile_info_proto_cc" + ] +) +xla_cc_test( + name = "auto_reorder_test", + srcs = ["auto_reorder_test.cc"], + deps = [ + ":auto_reorder", + ":convert_xplane", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/service:latency_hiding_scheduler", + "//xla/service/gpu:gpu_hlo_schedule", + "//xla/service:gpu_plugin", + "//xla/service/gpu:gpu_device_info_for_tests", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_binary( + name="convert_xplane_tools", + linkopts = [ + "-Wl,--allow-multiple-definition", + "-lstdc++fs", # For std::filesystem + ], + srcs=["convert_xplane_bin.cc"], + deps=[ + ":convert_xplane", + ":auto_reorder", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:latency_hiding_scheduler", + "//xla/service/gpu:gpu_hlo_schedule", + "//xla/service:gpu_plugin", + "//xla:device_util", + "@com_google_absl//absl/log", + "@tsl//tsl/platform:statusor", + + ] +) +# # compatible_with = get_compatible_with_libtpu_portable(), +# deps=[ +# "//xla:status", +# "//xla:xla_proto_cc", +# "//xla/hlo/ir:hlo", +# "//xla/service:hlo_proto_cc", +# "@com_google_absl//absl/container:flat_hash_map", +# "@com_google_absl//absl/status", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/types:optional", +# "@tsl//tsl/platform:env", +# "@tsl//tsl/platform:types", +# "@tsl//tsl/profiler/convert:xla_op_utils", +# "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", +# "@tsl//tsl/profiler/protobuf:xplane_proto_cc", +# "@tsl//tsl/profiler/utils:file_system_utils", +# "@tsl//tsl/profiler/utils:tf_xplane_visitor", +# "@tsl//tsl/profiler/utils:xplane_schema", +# "@tsl//tsl/profiler/utils:xplane_utils", +# "@tsl//tsl/profiler/utils:xplane_visitor", +# "@tsl//tsl/platform:platform_port", +# ] +# ) \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/README b/xla/hlo/experimental/auto_reorder/README new file mode 100644 index 0000000000000..902abdb3cf208 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/README @@ -0,0 +1,8 @@ +# Auto Reorder + +run tests +``` +export TF_CPP_MIN_LOG_LEVEL=0 +export TF_CPP_VMODULE="auto_reorder=5,auto_reorder_solver=5" +bazel run --compilation_mode=dbg xla/hlo/experimental/auto_reorder:auto_reorder_test --incompatible_strict_action_env --action_env=USE_CUDA --action_env=XLA_CUDA +``` \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder.cc b/xla/hlo/experimental/auto_reorder/auto_reorder.cc new file mode 100644 index 0000000000000..3a66801c2ebc1 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/auto_reorder.cc @@ -0,0 +1,402 @@ +#include "xla/hlo/experimental/auto_reorder/auto_reorder.h" +namespace xla { +constexpr int64_t kPointerSize = 8; +// get shape byte size, f32 have 4 bytes; +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +tsl::Status AutoReorderPass::RebuildHloOrdering( + HloSchedule& module_schedule, HloComputation* entry_computation) { + bool is_debug = false; + // module_schedule.remove_computation(entry_computation); + // module_schedule.GetOrCreateSequence(entry_computation); + auto status = module_schedule.UpdateComputationSchedule(entry_computation); + + if (!status.ok()) { + return status; + } else { + } + status = module_schedule.Update({}); + if (!status.ok()) { + VLOG(2) << "Update error:" << status.message() << std::endl; + return status; + } + // SequentialHloOrdering seq_ordering(module_schedule); + // auto seqs = seq_ordering.SequentialOrder(*entry_computation); + // module_schedule.set_sequence(entry_computation, *seqs); + + auto new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + for (auto i = 0; i < new_instruction_sequence.size(); i++) { + auto inst = new_instruction_sequence.at(i); + } + status = module_schedule.Verify(); + if (!status.ok()) { + return status; + } + return OkStatus(); +} +tsl::StatusOr> +AutoReorderPass::ScheduleComputation(HloComputation* computation) { + int64_t current_pos = 0; + auto post_order_instructions = computation->MakeInstructionPostOrder(); + HloScheduleGraph schedule_graph(&post_order_instructions, + /*alias_analysis=*/nullptr, + latency_estimator_.get(), + async_tracker_.get()); + async_tracker_->PostProcessScheduleGraph(&schedule_graph, + latency_estimator_.get()); + // we don't need InitializeGraphAnalysis for init node status; + + auto solver_ = absl::make_unique, const HloInstruction*>>(); + std::vector*> comm_lp_nodes; + + // scan instructions, get every instruction cost and deps + // post order,every inst will iter before it's operators + for (HloInstruction* instr : post_order_instructions) { + // AddHint + + const HloGraphNode& instr_node = schedule_graph.GetNode(instr); + VLOG(2) << instr->ToShortString() << "flops cost :" << instr_node.GetCost(); + auto addEdge = [&](const xla::HloInstruction* from_inst, + LPContainer* dst_node, + NodeType edge_type) { + auto operand_lp_node = solver_->FindInstructionLPNode(from_inst); + if (!operand_lp_node.ok()) { + VLOG(2) << "operand_lp_node not found:" << from_inst->ToShortString(); + return false; + } + auto operand_node = schedule_graph.GetNode(from_inst); + CostType edge_cost = + latency_estimator_->GetLatencyBetween(operand_node, instr_node); + VLOG(2) << from_inst->ToShortString() + " should execute before " + + instr->ToShortString(); + // if(edge_type==NodeType::kCommunication){ + // //let edge become a node,so edge will no overlap + // auto edge_node = solver_->FindLPNodeOrCreate(nullptr, edge_cost, + // edge_type); + // }else{ + dst_node->AddDep(operand_lp_node.value(), edge_cost, edge_type); + // } + + return true; + }; + + CostType cost = std::ceil(instr_node.GetCost()); + // there are 2 type now: 1. compute 2. communication + if (async_tracker_->IsSupportedAsyncStart(*instr) || + async_tracker_->IsSupportedAsyncDone(*instr)) { + // communication + // GetCost return float, floor to int + auto current_inst_lp_node = + solver_->FindLPNodeOrCreate(instr, cost, NodeType::kCommunication); + // add current node as constraint + + if (async_tracker_->IsSupportedAsyncDone(*instr)) { + // create a edge, which is communication + auto operand_inst = instr->operand(0); + auto is_success = addEdge(operand_inst, current_inst_lp_node, + NodeType::kCommunication); + TF_RET_CHECK(is_success) + << "operand_lp_node not found:" << operand_inst->ToShortString(); + } else { + // add it's operands to his deps + for (auto i = 0; i < instr->operand_count(); i++) { + auto operand_inst = instr->operand(i); + auto is_success = + addEdge(operand_inst, current_inst_lp_node, NodeType::kCompute); + TF_RET_CHECK(is_success) + << "operand_lp_node not found:" << operand_inst->ToShortString(); + } + instr->control_predecessors(); + for (auto control_inst : instr->control_predecessors()) { + // if it's communication, if control_inst is communicate op,this type + // should be kCommunication? + auto is_success = addEdge(control_inst, current_inst_lp_node, + NodeType::kCompute); // which type? + TF_RET_CHECK(is_success) + << "operand_lp_node not found:" << control_inst->ToShortString(); + } + } + + TF_CHECK_OK(solver_->AddConstraint(current_inst_lp_node)); + if (reorder::is_keep_communicate_order()) { + comm_lp_nodes.push_back(current_inst_lp_node); + } + } else { // compute + auto current_inst_lp_node = + solver_->FindLPNodeOrCreate(instr, cost, NodeType::kCompute); + // when adding edge node, current node have no add to Constraint? + for (auto i = 0; i < instr->operand_count(); i++) { + auto operand_inst = instr->operand(i); + auto is_success = + addEdge(operand_inst, current_inst_lp_node, NodeType::kCompute); + TF_RET_CHECK(is_success) + << "operand_lp_node not found:" << operand_inst->ToShortString(); + } + for (auto control_inst : instr->control_predecessors()) { + // if it's + auto is_success = addEdge(control_inst, current_inst_lp_node, + NodeType::kCompute); // which type? + TF_RET_CHECK(is_success) + << "operand_lp_node not found:" << control_inst->ToShortString(); + } + + TF_CHECK_OK(solver_->AddConstraint(current_inst_lp_node)); + } + } + + // set hint, using post order + std::reverse(post_order_instructions.begin(), post_order_instructions.end()); + for (HloInstruction* instr : post_order_instructions) { + auto lp_node = solver_->FindInstructionLPNode(instr); + if (!lp_node.ok()) { + VLOG(2) << "operand_lp_node not found:" << instr->ToShortString(); + continue; + } + auto operand_lp_node = lp_node.value(); + CostType start_at = -1; + for (auto dep_pair : operand_lp_node->GetDeps()) { + CostType cost = std::get<1>(dep_pair); + auto dep_node = std::get<0>(dep_pair); + if (dep_node->GetHintStart() > -1) { + start_at = std::max(start_at, dep_node->GetHintStart() + cost); + } + } + if (start_at > -1) { + operand_lp_node->SetHintStart(start_at); + } + } + if (reorder::solve_debug) { + // save to pid related file + solver_->SaveGraphviz(absl::StrCat("gantt_before_", computation->name())); + solver_->SaveJSON(absl::StrCat("gantt_before_", computation->name())); + } + auto status = + solver_->Solve(absl::StrCat("mps_file_of_", computation->name())); + if (reorder::solve_debug) { + // save to pid related file + solver_->SaveGantt(absl::StrCat("gantt_", computation->name())); + solver_->SaveGraphviz(absl::StrCat("gantt_", computation->name())); + } + + if (status.ok()) { + // return instruction order by solver + std::vector new_schedule_params; + std::vector new_schedule; + auto sorted_nodes = solver_->GetSortedNodes(); + for (auto node : sorted_nodes) { + auto insts = node->GetValues(); + for (auto inst : insts) { + // extra check: param inst must move to head; + if (inst->opcode() == HloOpcode::kParameter) { + new_schedule_params.insert(new_schedule_params.begin(), + const_cast(inst)); + } else { + new_schedule.push_back(const_cast(inst)); + } + } + } + std::sort(new_schedule_params.begin(), new_schedule_params.end(), + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + new_schedule_params.insert(new_schedule_params.end(), new_schedule.begin(), + new_schedule.end()); + return new_schedule_params; + } + TF_RET_CHECK(status.ok()) << "Solver error:" << status.message(); + return status; +} +tsl::Status AutoReorderPass::MoveInstruction(HloComputation* src_computation, + absl::string_view src_name, + HloComputation* dst_computation) { + bool is_debug = true; + + // Move instruction from src_computation to dst_computation. + auto src_instruction = src_computation->GetInstructionWithName(src_name); + // step 1: found src_instruction input args and output args + std::vector + src_inputs; // instruction which outputs is needed by src_instruction + std::vector + src_outputs; // instruction which input is src_instruction's output + for (auto i = 0; i < src_instruction->operand_count(); i++) { + auto src_input = src_instruction->mutable_operand(i); + src_inputs.push_back(src_input); + } + std::vector user_insts = src_instruction->users(); + for (auto i = 0; i < src_instruction->user_count(); i++) { + src_outputs.push_back(user_insts.at(i)); + } + // step 2: create Send Instruction for input args, create Recv Instruction for + // output args + int64_t channel_id = 0; + std::vector dst_inputs; + std::vector send_params; + dst_inputs.reserve(src_inputs.size()); + send_params.reserve(src_inputs.size()); + for (size_t i = 0; i < src_inputs.size(); i++) { + channel_id++; + auto src_input = src_inputs.at(i); + auto src_input_shape = src_input->shape(); + // src_instruction + auto token = src_computation->AddInstruction(HloInstruction::CreateToken()); + + auto send_inst = src_computation->AddInstruction(HloInstruction::CreateSend( + src_input, token, channel_id, false /*is_host_transfer*/)); + auto send_done = src_computation->AddInstruction( + HloInstruction::CreateSendDone(send_inst)); + token = dst_computation->AddInstruction(HloInstruction::CreateToken()); + auto recv_inst = dst_computation->AddInstruction( + HloInstruction::CreateRecv(src_input_shape, token, channel_id, + false /*is_host_transfer*/), + "dst_recv" + std::to_string(i)); + auto recv_done = dst_computation->AddInstruction( + HloInstruction::CreateRecvDone(recv_inst)); + HloInstruction* recv_parameter = dst_computation->AddInstruction( + HloInstruction::CreateGetTupleElement(recv_done, 0)); + + dst_inputs.push_back(recv_parameter); + } + channel_id++; + // step3: clone same instruction to dst_computation + auto dst_inst = + dst_computation->AddInstruction(src_instruction->CloneWithNewOperands( + src_instruction->shape(), dst_inputs)); + + // step4 :create Send Instruction from dst_compuation, create Recv Instruction + // in src_computation + auto token = dst_computation->AddInstruction(HloInstruction::CreateToken()); + + auto ret_send_inst = + dst_computation->AddInstruction(HloInstruction::CreateSend( + dst_inst, token, channel_id, false /*is_host_transfer*/)); + auto send_done = dst_computation->AddInstruction( + HloInstruction::CreateSendDone(ret_send_inst)); + + // create recv in src_computation, create token node,so recv_inst will be + // executed by scheduler + token = src_computation->AddInstruction(HloInstruction::CreateToken()); + + auto recv_inst = src_computation->AddInstruction( + HloInstruction::CreateRecv(dst_inst->shape(), token, channel_id, + false /*is_host_transfer*/), + "src_recv_ret"); + auto recv_done = src_computation->AddInstruction( + HloInstruction::CreateRecvDone(recv_inst)); + HloInstruction* recv_parameter = src_computation->AddInstruction( + HloInstruction::CreateGetTupleElement(recv_done, 0)); + + // step5: replace instruction which use src_instruction's output with Recv + // Instruction + for (size_t i = 0; i < src_outputs.size(); i++) { + /* code */ + auto src_output = src_outputs.at(i); + // add dependency + auto status = src_instruction->ReplaceUseWith(src_output, recv_parameter); + if (!status.ok()) { + VLOG(2) << "ReplaceUseWith error:" << status.message() << std::endl; + } + absl::flat_hash_map new_instruction_uses; + int operand_num = 0; + for (const HloInstruction* operand : src_output->operands()) { + if (operand->unique_id() == src_instruction->unique_id()) { + new_instruction_uses[operand_num] = recv_parameter; + } + operand_num++; + } + for (auto it = new_instruction_uses.begin(); + it != new_instruction_uses.end(); ++it) { + status = src_output->ReplaceOperandWith(it->first, it->second); + if (!status.ok()) { + VLOG(2) << "ReplaceOperandWith error:" << status.message() << std::endl; + } + } + } + // step6: remove src_instruction + src_instruction->DetachFromOperandsAndUsers(); + auto status = src_computation->RemoveInstruction(src_instruction); + if (!status.ok()) { + VLOG(2) << "RemoveInstruction error:" << status.message() << std::endl; + return status; + } else { + VLOG(3) << "RemoveInstruction success" + << src_computation->instruction_count() << std::endl; + return OkStatus(); + } +} +StatusOr AutoReorderPass::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // about reorder: be careful about RNG, such as dropout, random_shuffle, + // random_uniform; + // HloCostAnalysis, get instruction cost + HloComputation* entry_computation = module->entry_computation(); + + // Currently we expect that a schedule that minimizes memory pressure is + // provided as a base. It's not necessary for the algorithm itself but it + // allows us to not having to think for now about memory pressure. + std::vector computations_to_schedule; + computations_to_schedule.reserve(module->computation_count()); + // Collect which computations have latency hiding opportunities. + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (auto* instr : computation->instructions()) { + if (async_tracker_->IsSupportedAsyncStart(*instr) || + async_tracker_->IsSupportedAsyncDone(*instr)) { + computations_to_schedule.push_back(computation); + break; + } + } + } + if (computations_to_schedule.empty()) { + return false; + } + + absl::flat_hash_map> + saved_schedules; + // TF_RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); //TODO: + // we don't limit memory usage + for (HloComputation* computation : computations_to_schedule) { + TF_ASSIGN_OR_RETURN(std::vector new_schedule, + ScheduleComputation(computation)); + VLOG(2) << "new_schedule length:" << new_schedule.size() + << " computation instruction_count:" + << computation->instruction_count(); + + saved_schedules[computation] = std::move(new_schedule); + } + + // TODO: now memory is not in constraction + // LOG(INFO) << "AutoReorderPass current memory usage: " + // << scheduler_core_->GetMemoryPeak() << " bytes."; + for (HloComputation* computation : computations_to_schedule) { + // VLOG(1) << "Statistics before scheduling:"; + VLOG(4) << "sequences length:" << module->schedule().sequences().size() + << std::endl; + module->schedule().set_sequence( + computation, absl::MakeConstSpan(saved_schedules[computation])); + VLOG(1) << "Statistics after scheduling:"; + // LogScheduleStatistics(computation); + } + return true; + +} // AutoReorderPass::Run +CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { + switch (hlo.opcode()) { + case HloOpcode::kSend: + return {HloOpcode::kAsyncStart, HloOpcode::kSend}; + case HloOpcode::kSendDone: + return {HloOpcode::kAsyncDone, HloOpcode::kSend}; + case HloOpcode::kRecv: + return {HloOpcode::kAsyncStart, HloOpcode::kRecv}; + case HloOpcode::kRecvDone: + return {HloOpcode::kAsyncDone, HloOpcode::kRecv}; + default: + return DefaultGetCanonicalAsyncOp(hlo); + } +} + +} // namespace xla diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder.h b/xla/hlo/experimental/auto_reorder/auto_reorder.h new file mode 100644 index 0000000000000..a2c9592249425 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/auto_reorder.h @@ -0,0 +1,72 @@ +#ifndef XLA_AUTO_REORDER_H_ +#define XLA_AUTO_REORDER_H_ +#include "absl/strings/string_view.h" +#include "xla/hlo/experimental/auto_reorder/auto_reorder_solver.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/backend.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/model/analytical_latency_estimator.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/service/latency_hiding_scheduler.h" + +// #include "xla/statusor.h" +namespace xla { +class AutoReorderPass : public HloModulePass { + public: + AutoReorderPass(){}; + AutoReorderPass(std::unique_ptr latency_estimator, + std::unique_ptr async_tracker, + std::unique_ptr scheduler_core, + HloCostAnalysis::ShapeSizeFunction shape_size_bytes) + : async_tracker_(std::move(async_tracker)), + scheduler_core_(std::move(scheduler_core)), + latency_estimator_(std::move(latency_estimator)), + shape_size_bytes_(shape_size_bytes){}; + absl::string_view name() const override { return "auto-reorder"; } + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + // when computation is changed, we need to rebuild the hlo ordering + tsl::Status RebuildHloOrdering(HloSchedule& module_schedule, + HloComputation* entry_computation); + tsl::Status MoveInstruction(HloComputation* src_computation, + absl::string_view src_name, + HloComputation* dst_computation); + int64_t OriginalInstructionPosition(const HloInstruction* instr) const { + auto it = instr_order_map_.find(instr); + CHECK(it != instr_order_map_.end()); + return it->second; + } + tsl::StatusOr> ScheduleComputation( + HloComputation* computation); + CostType GetInstructionStart(const HloInstruction* instr) const { + auto it = instr_order_map_.find(instr); + CHECK(it != instr_order_map_.end()); + return it->second; + } + void LogScheduleStatistics(const HloComputation* computation) { + XLA_VLOG_LINES(1, LatencyHidingScheduler::SchedulerStatisticsString( + LatencyHidingScheduler::LatencyHidingStatistics( + computation, latency_estimator_.get(), + async_tracker_.get(), shape_size_bytes_))); + } + + private: + std::unique_ptr async_tracker_; + std::unique_ptr scheduler_core_; + std::unique_ptr latency_estimator_; + absl::flat_hash_map> + nodes_; + absl::flat_hash_map instr_order_map_; + // std::unique_ptr solver_; + int64_t move_cost_threshold_in_bytes_; + HloCostAnalysis::ShapeSizeFunction shape_size_bytes_; +}; + +CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo); + +} // namespace xla + +#endif \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc new file mode 100644 index 0000000000000..28f39778731c6 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc @@ -0,0 +1,519 @@ +#include "xla/hlo/experimental/auto_reorder/auto_reorder_solver.h" +#include +#include + +#ifndef LPSchedulerFunc(return_type) +#define LPSchedulerFunc(return_type) \ + template \ + return_type LinearProgramScheduler +#endif + +#ifndef LPContainerDAGFunc(return_type) +#define LPContainerDAGFunc(return_type) \ + template \ + return_type LPContainerDAG +#endif + +namespace xla { +using IntVar = operations_research::sat::IntVar; +using CpModelBuilder = operations_research::sat::CpModelBuilder; +using IntervalVar = operations_research::sat::IntervalVar; +// namespace ORTools = operations_research::sat; +using Task = + std::tuple; // (channel, processing_time), we have two + // channel now:communication and computation +using Job = std::vector; +namespace reorder { +uint32_t get_autoreorder_timeout() { + const char* env = std::getenv("XLA_AUTOREORDER_TIMEOUT"); + if (env == nullptr) { + return ksolveTimeout; + } + return std::atoi(env); +}; +int get_horizon(int max_time) { + // scale should be fit with module? + return max_time * 2; +} +const bool is_keep_communicate_order() { + const char* env = std::getenv("XLA_KEEP_COMMUNICATE_ORDER"); + if (env == nullptr) { + return false; + } + return std::strcmp(env, "true") == 0; +}; +int get_cpu_number() { + // return 8; + return std::thread::hardware_concurrency(); +} + +} // namespace reorder +template +LinearProgramScheduler::~LinearProgramScheduler() { + uuid2container.clear(); + node_to_task_.clear(); + channel_to_intervals_.clear(); + // destroy nodes + for (auto node : nodes_) { + delete node; + } + nodes_.clear(); +}; +template +void LPContainer::AddDep(LPContainer* dep, CostType cost, + NodeType edgetype) { + if (frozen_) { + LOG(FATAL) << "Can not add dep to a frozen node"; + // raise exception + return; + } + // every node should start after dep+cost + deps_.push_back(std::make_tuple(dep, cost, edgetype)); +}; + +LPSchedulerFunc(StatusOr)::FindInstructionLPNode( + ElementType instruction) { + auto it = uuid2container.find(instruction->unique_id()); + + if (it != uuid2container.end()) { + return it->second; + } + TF_RET_CHECK(false) << "Can not find the node:" << instruction->ToString(); +} +LPSchedulerFunc(ContainerType*)::FindLPNodeOrCreate(ElementType element, + CostType cost, + NodeType type) { + auto it = uuid2container.find(element->unique_id()); + if (it != uuid2container.end()) { + return it->second; + } + auto node = new ContainerType(element, cost, type); + nodes_.push_back(node); + uuid2container.emplace(element->unique_id(), node); + return node; +}; +LPSchedulerFunc(bool)::NodeHasAddTasks(ContainerType* node) { + auto it = node_to_task_.find(node->UUID()); + return it != node_to_task_.end(); +}; +LPSchedulerFunc(void)::AddNodeToTask(ContainerType* node, TaskType task) {} + +LPSchedulerFunc(StatusOr)::FindTask(ContainerType* node) { + auto it = node_to_task_.find(node->UUID()); + if (it != node_to_task_.end()) { + VLOG(3) << "Find task for node:" << node->GetName() << " success"; + return std::get<1>(it->second); + } else { + TF_RET_CHECK(false) << "Can not find the task for node:" << node->GetName(); + } +}; +LPSchedulerFunc(Status)::AddConstraint(ContainerType* node) { + if (NodeHasAddTasks(node)) { + return OkStatus(); + } + // XD can't frozen node here, we will add other constraint after that + return OkStatus(); +}; +LPSchedulerFunc(StatusOr)::AddNodeToTask(ContainerType* node) { + IntVar start = cp_model_.NewIntVar({0, horizon_}); + IntVar end = cp_model_.NewIntVar({0, horizon_}); + IntervalVar interval = cp_model_.NewIntervalVar(start, node->GetCost(), end); + TaskType task{start, end, interval}; + if (node->GetHintStart() != -1) { + cp_model_.AddHint(start, node->GetHintStart()); + } + // AddNodeToTask(node, task); + node_to_task_.emplace(node->UUID(), std::make_tuple(node, task)); + return task; +}; + +LPSchedulerFunc(tsl::Status)::Solve(std::string mps_filename) { + uint32_t max_execution_time = 0; + for (auto node : nodes_) { + node->Freeze(); + max_execution_time += node->GetCost(); + for (auto dep_pair : node->GetDeps()) { + auto cost = std::get<1>(dep_pair); + max_execution_time += cost; + }; + } + SetHorizon(reorder::get_horizon(max_execution_time)); + // nodes_ is added by post order,so we should add it before its deps; + for (auto node : nodes_) { + VLOG(3) << "Add to scheduler" << node->GetName(); + TF_ASSIGN_OR_RETURN(TaskType node_task, AddNodeToTask(node)); + } + for (auto node : nodes_) { + auto node_task = std::get<1>(node_to_task_.at(node->UUID())); + + channel_to_intervals_[node->GetType()].push_back(node_task.interval); + for (auto dep_pair : node->GetDeps()) { + auto dep_node = std::get<0>(dep_pair); + auto cost = std::get<1>(dep_pair); + TaskType dep_task; + VLOG(3) << node->GetName() << "should start after" << dep_node->GetName() + << "+" << cost; + TF_ASSIGN_OR_RETURN(dep_task, FindTask(dep_node)); + + cp_model_.AddGreaterOrEqual(node_task.start, dep_task.end + cost); + } + } + // add constraint, channels can be overlap each other + for (auto it = channel_to_intervals_.begin(); + it != channel_to_intervals_.end(); it++) { + cp_model_.AddNoOverlap(it->second); + } + // for communicate stream, edge also should no overlap + std::vector no_overlap_edges; + for (auto node : nodes_) { + if (!node->IsCommunication()) { + continue; + } + // simple method to create 01 program + auto node_task = std::get<1>(node_to_task_.at(node->UUID())); + for (auto dep_tuple : node->GetDeps()) { + auto dep_node = std::get<0>(dep_tuple); + auto cost = std::get<1>(dep_tuple); + auto dep_type = std::get<2>(dep_tuple); + + if (IsSingleChannel(dep_type)) { + auto dep_task = std::get<1>(node_to_task_.at(dep_node->UUID())); + // interval + IntervalVar interval = + cp_model_.NewIntervalVar(dep_task.end, cost, node_task.start); + no_overlap_edges.push_back(interval); + } + } + } + cp_model_.AddNoOverlap(no_overlap_edges); + + // objective. + IntVar obj_var = cp_model_.NewIntVar({0, horizon_}).WithName("makespan"); + std::vector ends; + for (auto it = node_to_task_.begin(); it != node_to_task_.end(); it++) { + ends.push_back(std::get<1>(it->second).end); + } + cp_model_.AddMaxEquality(obj_var, ends); + cp_model_.Minimize(obj_var); + // cp_model_. + // VLOG(2)<<"Number of variables:"< 0) { + operations_research::MPModelProto output; + operations_research::sat::ConvertCpModelProtoToMPModelProto(model, &output); + auto status_of_string = operations_research::ExportModelAsMpsFormat(output); + if (status_of_string.ok()) { + VLOG(2) << "ExportModelAsMpsFormat success"; + std::ofstream out(absl::StrCat("/tmp/", mps_filename, ".mps")); + out << status_of_string.value(); + out.close(); + } + } + + const operations_research::sat::CpSolverResponse response = + operations_research::sat::SolveWithParameters(model, parameters); + uint64_t solve_time = response.wall_time(); + VLOG(1) << "Solve finish:" << response.status() + << " solve time:" << solve_time; + + if (response.status() == operations_research::sat::CpSolverStatus::OPTIMAL || + response.status() == operations_research::sat::CpSolverStatus::FEASIBLE) { + VLOG(2) << "Optimal objective value:" << response.objective_value() + << " status:" << response.status(); + for (auto kv : node_to_task_) { + auto node_task_tuple = std::get<1>(kv); + auto node = std::get<0>(node_task_tuple); + auto task = std::get<1>(node_task_tuple); + CostType start = + operations_research::sat::SolutionIntegerValue(response, task.start); + node->SetStart(start); + VLOG(2) << node->GetName() << "should start at" << start << std::endl; + node_starttime_.emplace(node->UUID(), start); + } + + return OkStatus(); + } else { + VLOG(2) << "Solve failed:" << response.status(); + return tsl::errors::NotFound("Linear Programming solve failed"); + } +}; +std::string ReplaceUnusedChar(const std::string str, + const std::string need_move_str) { + std::string result = str; + for (auto c : need_move_str) { + result.erase(std::remove(result.begin(), result.end(), c), result.end()); + } + return result; +} +LPSchedulerFunc(std::vector)::GetSortedNodes() const { + std::vector sorted_nodes; + sorted_nodes.reserve(nodes_.size()); + for (auto node : nodes_) { + sorted_nodes.push_back(node); + } + // we need stable_sort,let same graph on diffence device have same computation + std::stable_sort( + // std::sort( + sorted_nodes.begin(), sorted_nodes.end(), + [this](ContainerType* a, ContainerType* b) { + return a->GetStart() < b->GetStart(); + }); + return sorted_nodes; +} +LPSchedulerFunc(void)::SaveJSON(std::string filename) const { + std::string json_file = absl::StrCat("/tmp/", filename, ".json"); + std::ofstream json_out(json_file); + json_out << "{" << std::endl; + json_out << "\"nodes\": [" << std::endl; + int32_t node_count = 0; + int32_t edge_count = 0; + + for (auto node : this->GetSortedNodes()) { + std::string name; + if (node->IsCommunication()) { + name = "communication"; + } else { + name = "compute"; + } + if (node_count > 0) { + json_out << ",\n{ \"uuid\": \"" << node->UUID() << "\",\"typename\": \"" + << name << "\", \"name\": \"" + << ReplaceUnusedChar(node->GetName(), "'") + << "\", \"cost\": " << node->GetCost() << " }"; + } else { + json_out << "{ \"uuid\": \"" << node->UUID() << "\",\"typename\": \"" + << name << "\", \"name\": \"" + << ReplaceUnusedChar(node->GetName(), "'") + << "\", \"cost\": " << node->GetCost() << " }"; + } + node_count++; + } + json_out << "]," << std::endl; + json_out << "\"edges\": [" << std::endl; + for (auto node : this->GetSortedNodes()) { + for (auto dep_pair : node->GetDeps()) { + auto dep_node = std::get<0>(dep_pair); + auto dep_cost = std::get<1>(dep_pair); + NodeType dep_type = std::get<2>(dep_pair); + std::string name; + if (IsSingleChannel(dep_type)) { + name = "communication"; + } else { + name = "compute"; + } + // draw edge + if (edge_count > 0) { + json_out << ",\n{ \"from\": \"" << dep_node->UUID() << "\", \"to\": \"" + << node->UUID() << "\", \"typename\": \"" << name + << "\", \"cost\": " << dep_cost << " }"; + } else { + json_out << "{ \"from\": \"" << dep_node->UUID() << "\", \"to\": \"" + << node->UUID() << "\", \"typename\": \"" << name + << "\", \"cost\": " << dep_cost << " }"; + } + edge_count++; + } + } + json_out << "]" << std::endl; + json_out << "}" << std::endl; +} +LPSchedulerFunc(void)::SaveGraphviz(std::string filename) const { + // write a dot file + std::string dot_file = absl::StrCat("/tmp/", filename, ".dot"); + std::ofstream out(dot_file); + out << "digraph G {\n"; + VLOG(4) << "write node number:" << nodes_.size() << " to /tmp/" << filename + << ".dot" << std::endl; + auto get_node_name = [](const ContainerType* node) { + return "\"" + ReplaceUnusedChar(node->GetName(), "%") + "\""; + }; + bool draw_start_time = (node_starttime_.size() > 0); + for (auto node : nodes_) { + std::string color; + if (node->IsCommunication()) { + color = "orange"; + } else { + color = "green"; + } + if (draw_start_time) { + out << get_node_name(node) << "[label=\"" + << ReplaceUnusedChar(node->GetName(), "") << "\\n" + << "cost=" << node->GetCost() + << "\nstart=" << node_starttime_.at(node->UUID()) + << "\",shape=box,color=" << color << "];\n"; + } else { + out << get_node_name(node) << "[label=\"" + << ReplaceUnusedChar(node->GetName(), "") << "\\n" + << "cost=" << node->GetCost() << "\",shape=box,color=" << color + << "];\n"; + } + + for (auto dep_pair : node->GetDeps()) { + auto dep_node = std::get<0>(dep_pair); + auto dep_cost = std::get<1>(dep_pair); + // draw edge + out << get_node_name(dep_node) << "->" << get_node_name(node) + << "[label=\"" << dep_cost << "\"];\n"; + } + } + out << "}\n"; + + out.close(); + // convert dot file to png + std::string png_file = absl::StrCat("/tmp/", filename, ".png"); + std::string cmd = absl::StrCat("dot -Tpng ", dot_file, " -o ", png_file); + auto status = system(cmd.c_str()); + VLOG(4) << cmd << " execute status:" << status << std::endl; +} +LPSchedulerFunc(void)::SaveGantt(std::string filename) const { + // https://g2.antv.antgroup.com/en/examples/storytelling/storytelling/#gantt + // { name: 'compute',label:'kernel name1', startTime: 1, endTime: 4 }, + VLOG(4) << "write node number:" << nodes_.size() << " to /tmp/" << filename + << ".js" << std::endl; + auto get_node_name = [](const ContainerType* node) { + return ReplaceUnusedChar(node->GetName(), "'"); + }; + bool draw_start_time = (node_starttime_.size() > 0); + std::string csv_file = absl::StrCat("/tmp/", filename, ".js"); + std::ofstream csv_out(csv_file); + csv_out << R"(import { Chart } from '@antv/g2'; + const events = [ )"; + for (auto node : this->GetSortedNodes()) { + std::string name; + if (node->IsCommunication()) { + name = "communication"; + } else { + name = "compute"; + } + if (draw_start_time) { + csv_out << "{ name: \"" << name << "\",label:'" + << ReplaceUnusedChar(node->GetName(), "'") + << "', startTime: " << node_starttime_.at(node->UUID()) + << ", endTime: " + << node_starttime_.at(node->UUID()) + node->GetCost() << " },\n"; + } + } + csv_out << "];"; + + csv_out << R"( + const chart = new Chart({ + container: 'container', + autoFit: true, + }); + + chart.coordinate({ transform: [{ type: 'transpose' }] }); + + chart + .interval() + .data(events) + .encode('x', 'name') + .encode('y', ['endTime', 'startTime']) + .encode('color', 'name') + .label({ + text: 'label', + position: 'inside', + transform: [{ type: 'overflowHide' }], + }) + .encode('enterDuration', (d) => d.endTime - d.startTime) + .encode('enterDelay', 'startTime') + .scale('enterDuration', { + zero: true, + range: [0, 3000], + }); + + chart.render();)"; +} + +LPContainerDAGFunc(bool)::IsIn(LPContainer* a) { + return operands_.find(a) != operands_.end(); +}; +LPContainerDAGFunc(void)::AddToDAG(LPContainer* child) { + inner_elements.push_back(child); + if (IsIn(child)) { + operands_.erase(child); + } + for (auto dep_pair : child->GetDeps()) { + auto dep = std::get<0>(dep_pair); + auto cost = std::get<1>(dep_pair); // if cost need store ? + operands_.insert(dep); + } +} +LPContainerDAGFunc(Status)::MergeFrom(LPContainerDAG* other) { + /* + step 1: this inner_elements must have dep to other's inner_elements. so that + link to other's inner_elements change to inner edges + */ + + // maintain this LPContainerDAG inner_elements's deps,so that can create inner + // edge after merge {dep: [,]} + std::unordered_map< + int, std::vector*, CostType>>> + dep_operands2element; + + for (LPContainer* element : GetInnerElements()) { + // from operate to element, there are outer edge,maybe convert to inner edge + for (auto dep_pair : element->GetDeps()) { + auto dep = std::get<0>(dep_pair); + auto cost = std::get<1>(dep_pair); + if (dep_operands2element.find(dep->UUID()) == + dep_operands2element.end()) { + dep_operands2element[dep->UUID()] = + std::vector*, CostType>>(); + } + dep_operands2element[dep->UUID()].push_back( + std::make_tuple(element, cost)); + } + } + // other + for (auto child : other->GetInnerElements()) { + // there child must in inner_elements_deps + TF_RET_CHECK(dep_operands2element.find(child->UUID()) == + dep_operands2element.end()) + << "child is not in dep_operands2element"; + for (auto dep_pair : dep_operands2element[child->UUID()]) { + auto dep = std::get<0>(dep_pair); + auto cost = std::get<1>(dep_pair); + if (dep_operands2element.find(dep->UUID()) != + dep_operands2element.end()) { + for (auto element_pair : dep_operands2element[dep->UUID()]) { + auto element = std::get<0>(element_pair); + auto cost = std::get<1>(element_pair); + // create edge between element and child + DAGEdge edge; + edge.from = element; + edge.to = child; + edge.cost = cost; + edges_.push_back(edge); + } + } + } + + AddToDAG(child); + }; +} +template class LPContainer; +template class LinearProgramScheduler, + const HloInstruction*>; + +template class LPContainerDAG; +// template class LinearProgramScheduler, +// const HloInstruction*>; + +} // namespace xla diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h new file mode 100644 index 0000000000000..d6a9625eac3f5 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h @@ -0,0 +1,270 @@ +#ifndef XLA_AUTO_REORDER_SOLVER_H_ +#define XLA_AUTO_REORDER_SOLVER_H_ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/utils/common_ortools_deps.h" + +namespace xla { +using IntVar = operations_research::sat::IntVar; +using CpModelBuilder = operations_research::sat::CpModelBuilder; +using IntervalVar = operations_research::sat::IntervalVar; +namespace reorder { +const uint32_t ksolveTimeout = 180; // 3min +uint32_t get_autoreorder_timeout(); +constexpr const int kChannelNumber = 2; +int get_horizon(int max_time); +constexpr bool solve_debug = true; +// TODO: no keep order will cause hung on multi processing, we should consider +// how to resolve it +// get cpu number of current machine +const bool is_keep_communicate_order(); +int get_cpu_number(); +} // namespace reorder +enum class NodeType { + kCompute = 0, + kCommunication = 1 + +}; +static bool IsSingleChannel(NodeType nodetype) { + return nodetype == NodeType::kCommunication; +} + +struct TaskType { + IntVar start; + IntVar end; + IntervalVar interval; +}; +using CostType = int64_t; // we can change it to double? + +// TODO: using LPNode to abstract LPContainer and LPContainerDAG +template +class LPNode { + public: + virtual const std::string GetName() const = 0; + virtual const int UUID() = 0; + virtual CostType GetCost() const = 0; + virtual void SetStart(CostType start) = 0; + virtual CostType GetStart() = 0; + virtual bool IsComputation() const = 0; + virtual bool IsCommunication() const = 0; + virtual NodeType GetType() const = 0; + virtual bool HasValue() const = 0; + virtual ElementType GetValue() const = 0; + virtual void AddDep(LPNode* dep, CostType cost) = 0; + virtual const std::vector> GetDeps() const = 0; + virtual void Freeze() = 0; + + private: + std::vector> deps_; +}; + +// LPContainer is a template class, it can be used to store any type of data +// 1. LPContainer; using to store one instruction +// 2. LPContainer; using to store a graph of +// instructions,decrese lp hard +// 3. LPContainer; maybe we can use it to store a pipeline stage +template +class LPContainer { + public: + // create a LPContainer with inner_element, cost and type + LPContainer(ElementType inner_element, CostType cost, NodeType type) + : inner_element_(inner_element), cost_(cost), type_(type) { + uuid_ = reinterpret_cast(this); + }; + ~LPContainer() { deps_.clear(); }; + const std::string GetName() const { return inner_element_->ToShortString(); } + const int UUID() { return inner_element_->unique_id(); } + + CostType GetCost() const { return cost_; } + void SetStart(CostType start) { startat_ = start; } + CostType GetStart() { return startat_; } + // speed up reorder, we can set a hint start time + CostType GetHintStart() { return hint_start_; } + void SetHintStart(CostType start) { hint_start_ = start; } + + // Get the type of the container: compute or communication + bool IsComputation() const { return type_ == NodeType::kCompute; } + bool IsCommunication() const { return type_ == NodeType::kCommunication; } + + NodeType GetType() const { return type_; } + + const bool HasValue() { return inner_element_ != nullptr; } + const std::vector GetValues() { + return std::vector{inner_element_}; + } + // Add a dep of this container, cost is the cost of the edge; this Container + // will be executed after dep + void AddDep(LPContainer* dep, CostType cost, NodeType nodetype); + // Get all deps of the container + const std::vector> GetDeps() + const { + return deps_; + } + /** + * Checks if the given dependency in this container. + * + * @param dep The dependency to check. + * @return True if the dependency in this container, false otherwise. + */ + bool HasDep(LPContainer* dep) { + for (auto d : deps_) { + if (std::get<0>(d) == dep) { + return true; + } + } + return false; + } + // when a container is frozen, it can not be add deps + void Freeze() { frozen_ = true; } + + private: + CostType cost_; + CostType startat_; + CostType hint_start_ = -1; + NodeType type_; + ElementType inner_element_; + // deps store the edge + std::vector> deps_; + bool frozen_ = + false; // if it is frozen, it can not be changed,such as add deps + uintptr_t uuid_; + std::string name_; // edge need a name +}; +// LPContainerDAG is a graph of container, it can be used to store the DAG of +// container be used as a atomic unit of LPContainer +template +class LPContainerDAG : public LPContainer { + // we can use InstructionDAG to get memory effect order + public: + // maintain a DAG of inner elements + struct DAGEdge { + LPContainer* from; + LPContainer* to; + CostType cost; + }; + // create a LPContainerDAG with one element + LPContainerDAG(ElementType inner_element, CostType cost, NodeType type) + : LPContainer(inner_element, cost, type) { + // TODO: there should not create element? + auto ele = new LPContainer(inner_element, cost, type); + inner_elements.push_back(ele); + }; + bool IsIn(LPContainer* a); + // which container can be put together:1. they have the same type 2. they have + // dep between them + // static bool CanFused(LPContainerDAG* a, + // LPContainerDAG* b); + + // override LPContainer + const std::string GetName() { + std::string name = "LPContainerDAG{"; + for (auto ele : inner_elements) { + name += ele->GetName(); + name += "\n"; + } + name += "}"; + return name; + } + const int UUID() { return inner_elements[0]->UUID(); } + const bool HasValue() { return inner_elements.size() > 0; } + const std::vector GetValues() { + std::vector values; + for (auto ele : inner_elements) { + for (auto inst : ele->GetValues()) { + values.push_back(inst); + } + } + return values; + } + // AddChild, child should maintain the deps before + void AddToDAG(LPContainer* child); + const std::vector*> GetInnerElements() const { + return inner_elements; + } + // merge other LPContainerDAG to this LPContainerDAG,then destroy other + // LPContainerDAG + Status MergeFrom(LPContainerDAG* other); + + private: + std::set*> operands_; + std::vector*> inner_elements; + // maintain edges between inner_elements + std::vector edges_; + CostType cost_; + CostType startat_; + NodeType type_; +}; + +// we only define node, edge is express by deps; +// edge is use to express the dependency between two nodes ,it have no effect +// constraint + +// ContainerType is a template class, it can be used to store ElementType of +// data example: LPContainer; using to store one +// instruction, ElementType is const HloInstruction*, ContainerType is +// LPContainer +template +class LinearProgramScheduler { + // https://developers.google.com/optimization/scheduling/job_shop?hl=zh-cn + // be a linear programming problem or a integer programming problem,that's a + // problem + public: + explicit LinearProgramScheduler(bool verbose = false) { + cp_model_ = CpModelBuilder(); + verbose_ = verbose; + }; + ~LinearProgramScheduler(); + // add Node to scheduler, its deps will execute before it + Status AddConstraint(ContainerType* node); + // solve the LP problem + // Status Solve(); + Status Solve(std::string mps_filename); + // find instruction,if not exist, return error + StatusOr FindInstructionLPNode(ElementType instruction); + // find LPNode by instruction,if not exist,create it + ContainerType* FindLPNodeOrCreate(ElementType instruction, CostType cost, + NodeType type); + // ContainerType* + std::vector GetSortedNodes() const; + // for debug: save graph viz file + void SaveGraphviz(std::string filename) const; + // for debug: render gantt chart + void SaveGantt(std::string filename) const; + void SaveJSON(std::string filename) const; + // set max start time as horizon + void SetHorizon(uint32_t horizon) { horizon_ = horizon; } + StatusOr FindTask(ContainerType* node); + bool NodeHasAddTasks(ContainerType* node); + CostType GetNodeStartTime(ContainerType* node); + void AddNodeToTask(ContainerType* node, TaskType task); + StatusOr AddNodeToTask(ContainerType* node); + + private: + StatusOr AddEdgesNoOverlap(ContainerType* node); + CpModelBuilder cp_model_; + bool verbose_ = false; + std::unordered_map uuid2container; + std::vector nodes_; + uint32_t horizon_ = std::numeric_limits::max(); + absl::flat_hash_map> + node_to_task_; // every node hold interval_var,show what time it start + // and end + // channels can be overlap each other + std::map> channel_to_intervals_; + std::map node_starttime_; +}; +} // namespace xla +#endif // XLA_AUTO_REORDER_H_ \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc new file mode 100644 index 0000000000000..801b010e3e87e --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc @@ -0,0 +1,1596 @@ + + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "xla/hlo/experimental/auto_reorder/auto_reorder.h" +#include "xla/hlo/experimental/auto_reorder/convert_xplane.h" +#include "xla/service/async_collective_creator.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/tests/hlo_test_base.h" + +#define debug_log(x) \ + { \ + if (is_debug) std::cout << x << std::endl; \ + } +namespace xla { +uint32_t kRandomSeed = 1243; +struct ScheduleStatus { + bool success; + uint32_t exec_time; + uint32_t memory_usage; +}; +class CostGenerator { + public: + CostGenerator(int mean, int std, int seed) { + gen_ = std::mt19937(seed); + dist_ = std::normal_distribution(static_cast(mean), + static_cast(std)); + }; + int operator()() { return std::max(1, static_cast(dist_(gen_))); } + + private: + std::mt19937 gen_; + std::normal_distribution dist_; +}; +class GpuLatencyEstimator : public ApproximateLatencyEstimator { + public: + explicit GpuLatencyEstimator( + GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp) + : ApproximateLatencyEstimator(func) {} + TimeCost NodeCost(const HloInstruction* instr) const override { + HloOpcode op = instr->opcode(); + if (op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || + op == HloOpcode::kConstant || op == HloOpcode::kParameter || + instr->IsEffectiveBitcast()) { + return 0.0; + } + // Consider cublas/cuddn/softmax custom calls as medium cost. Since the + // latency between async-start and async-done is 5000 and cost of each + // custom call is 1000, the LHS will try to schedule approximately 5 of + // these in between each start/end pair. + if (instr->opcode() == HloOpcode::kCustomCall) { + if (gpu::IsCublasGemm(*instr) || + gpu::IsCustomCallToDnnConvolution(*instr)) { + return ApproximateLatencyEstimator::kMediumCost; + } + // consider other custom calls as medium cost for now. Keeping the case + // explicitly separate for further tuning. + return ApproximateLatencyEstimator::kMediumCost; + } + return ApproximateLatencyEstimator::NodeCost(instr); + } + + LatencyEstimator::TimeCost GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& target) const override { + if (IsAsyncPair(from, target)) { + if (from.GetInstr().opcode() == HloOpcode::kRecv) { + // Recv -> RecvDone has a low latency. + return ApproximateLatencyEstimator::kLowLatency; + } else if (from.GetInstr().opcode() == HloOpcode::kSend) { + // Send -> SendDone has a very high latency. + return ApproximateLatencyEstimator::kHighLatency * 10; + } + + return ApproximateLatencyEstimator::kHighLatency; + } + // Every other instruction we consider synchronous, which means the + // latency between each of them is always one unit. + return ApproximateLatencyEstimator::kLowLatency; + } +}; + +class SavedInstLatencyEstimator : public GpuLatencyEstimator { + // make random inst cost + // usage: + // 1. create instruction + // 2. using SetInstructionCost; + // 3. using this estimator in scheduler + + public: + explicit SavedInstLatencyEstimator( + GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp) + : GpuLatencyEstimator(func) {} + TimeCost NodeCost(const HloInstruction* instr) const override { + auto cost = GpuLatencyEstimator::NodeCost(instr); + if (inst_cost_.find(instr->unique_id()) != inst_cost_.end()) { + cost = inst_cost_.at(instr->unique_id()); + } + return cost; + } + LatencyEstimator::TimeCost GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& target) const override { + if (IsAsyncPair(from, target)) { + if (edge_cost_.find(target.GetInstr().unique_id()) != edge_cost_.end()) { + auto cost = edge_cost_.at(target.GetInstr().unique_id()); + return cost; + } + if (edge_cost_.find(from.GetInstr().unique_id()) != edge_cost_.end()) { + auto cost = edge_cost_.at(from.GetInstr().unique_id()); + return cost; + } + return ApproximateLatencyEstimator::kLowLatency; + } + // Every other instruction we consider synchronous, which means the + // latency between each of them is always one unit. + return ApproximateLatencyEstimator::kLowLatency; + } + + void SetInstructionCost(const HloInstruction* instr, TimeCost cost) { + inst_cost_.emplace(instr->unique_id(), cost); + } + void SetInstructionBetween(const HloInstruction* target, TimeCost cost) { + // let all node link to target have cost + if (target->unique_id() == -1) { + // raise exception? + std::cout << "SetInstructionBetween fail" << target->ToShortString() + << std::endl; + ASSERT_ANY_THROW(target->unique_id() != -1); + } + edge_cost_.emplace(target->unique_id(), cost); + } + void CloneCost(std::unordered_map input_costs, + std::unordered_map edge_cost) { + inst_cost_ = input_costs; + edge_cost_ = edge_cost; + } + std::unordered_map GetCosts() { return inst_cost_; } + std::unique_ptr clone() { + auto estimator = std::make_unique(); + estimator->CloneCost(inst_cost_, edge_cost_); + return estimator; + } + + private: + std::unordered_map inst_cost_; + std::unordered_map edge_cost_; +}; +using namespace xla::gpu; +constexpr int kMaxConcurrentAsyncCollectivePermutes = 5; + +SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { + SchedulerConfig config; + config.all_reduce_overlap_limit = 1; + config.collective_permute_overlap_limit = 1; + config.use_real_cost_model = false; + config.aggressive_scheduling_policies = true; + config.schedule_send_recvs = true; + config.memory_limit = memory_limit; + return config; +} +SchedulerConfig GetDefaultSchedConfig() { + SchedulerConfig sched_cfg; + sched_cfg.collective_permute_overlap_limit = + kMaxConcurrentAsyncCollectivePermutes; + sched_cfg.send_recv_overlap_limit = INT32_MAX; + return sched_cfg; +} + +class AutoReorderingTest : public HloTestBase { + protected: + void SetUp() override { setenv("XLA_AUTOREORDER_TIMEOUT", "60", 1); } + void TearDown() override { unsetenv("XLA_AUTOREORDER_TIMEOUT"); } + const char* const add_hlo_string_ = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[16,32,64]{2,1,0} parameter(0) + %param1 = f32[16,32,64]{2,1,0} parameter(1) + ROOT root = f32[16,32,64]{2,1,0} add(%param0, %param1) +})"; + void RunDemoAutoReorderWithOptions(size_t expected_num_tiles, + size_t expected_sharded_dimensions = 1) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(add_hlo_string_)); + auto* instruction = FindInstruction(module.get(), "param0"); + } + + public: + HloComputation* MakeReduction(const HloOpcode type, HloModule* module) { + HloComputation::Builder sum_builder(HloOpcodeString(type)); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), type, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + return reduction; + } + HloComputation* MakeReduceScatter(const HloOpcode type, Shape input_shape, + Shape output_shape, HloModule* module) { + HloComputation::Builder async_builder("AsyncOp"); + HloInstruction* param = async_builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "pasync")); + async_builder.AddInstruction(HloInstruction::CreateReduceScatter( + output_shape, {param}, MakeReduction(type, module), + CreateReplicaGroups({{0, 1}}), false, /*channel_id*/ 3, false, 0)); + HloComputation* reduction = + module->AddEmbeddedComputation(async_builder.Build()); + + return reduction; + } + HloComputation* MakeAll2All(Shape input_shape, HloModule* module) { + HloComputation::Builder async_builder("AsyncOp"); + HloInstruction* param = async_builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "pasync")); + HloInstruction* param1 = async_builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "pasync")); + async_builder.AddInstruction(HloInstruction::CreateAllToAll( + input_shape, {param, param1}, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/std::nullopt, + /*split_dimension*/ 0)); + HloComputation* reduction = + module->AddEmbeddedComputation(async_builder.Build()); + + return reduction; + } + std::string GetInstructionsOrderString(HloModule* hlo_module) { + auto insts = hlo_module->schedule() + .sequence(hlo_module->entry_computation()) + .instructions(); + auto ret = std::string(""); + // hard to keep all instruction order,only keep communicate order + for (auto inst : insts) { + switch (inst->opcode()) { + case HloOpcode::kAllToAll: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherDone: + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceDone: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kReduceScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: { + ret = ret + "\n" + inst->ToString(); + } + default: { + continue; + } + } + } + bool isStart = false; + // check start-done pair, there is no start-start + for (auto inst : insts) { + switch (inst->opcode()) { + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllReduceStart: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kSend: + case HloOpcode::kRecv: { + CHECK_NE(isStart, true); + isStart = true; + continue; + } + case HloOpcode::kAllGatherDone: + case HloOpcode::kAllReduceDone: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: { + isStart = false; + continue; + } + case HloOpcode::kReduceScatter: { + ret = ret + "\n" + inst->ToString(); + } + default: { + continue; + } + } + } + return ret; + } + bool CheckParameterInst(HloModule* hlo_module) { + auto insts = hlo_module->schedule() + .sequence(hlo_module->entry_computation()) + .instructions(); + bool isParameter = true; + // check parameter must at head + for (auto inst : insts) { + if (inst->opcode() == HloOpcode::kParameter) { + if (isParameter) { + continue; + } else { + return false; + } + } else { + isParameter = false; + } + } + return true; + } + StatusOr GetModuleCost( + HloModule* module, SchedulerConfig sched_config, + const xla::LatencyEstimator* latency_estimator + + ) { + // we should implement method independent of scheduler to get module time + // cost + + // ASSERT_ANY_THROW(latency_estimator!=nullptr); + auto computation = module->entry_computation(); + auto schedule = module->schedule(); + auto seq = schedule.sequence(computation); + + HloCostAnalysis::ShapeSizeFunction shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + auto async_tracker = std::make_unique(sched_config); + // copy some code from LatencyHidingStatistics + if (latency_estimator == nullptr) { + return absl::InvalidArgumentError("latency_estimator is nullptr!"); + } + // here latency_estimator make segmant fault? + auto ret = LatencyHidingScheduler::LatencyHidingStatistics( + computation, latency_estimator, async_tracker.get(), shape_size_bytes); + return ret; + } + std::vector CreateReplicaGroups( + absl::Span> groups) { + std::vector replica_groups(groups.size()); + for (int64_t i = 0; i < groups.size(); ++i) { + *replica_groups[i].mutable_replica_ids() = {groups[i].begin(), + groups[i].end()}; + } + return replica_groups; + } + StatusOr MakeCommunicateComputation( + HloModule* module, SavedInstLatencyEstimator* estimator) { + /* + p0->allreduce-100->allreduce_done->p0.1 + p1->allreduce.1-10->allreduce_done.1->p0.2 + dot(p0,p1):cost=10 + dot(p0.1,p0.2):cost=10 + + p0 + + */ + HloComputation::Builder builder("test"); + auto add_reducer = MakeReduction(HloOpcode::kAdd, module); + Shape shape = ShapeUtil::MakeShape(F32, {4, 256, 256}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + int64_t channel_id = 0; + auto precision_config = DefaultPrecisionConfig(2); + + auto p0 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, shape, "p0")); + auto p1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, shape, "p1")); + HloInstruction* ar_start0 = + builder.AddInstruction(HloInstruction::CreateAllReduceStart( + shape, {p0}, add_reducer, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/1, + /*use_global_device_ids=*/true)); + HloInstruction* ar_done0 = + builder.AddInstruction(HloInstruction::CreateUnary( + shape, HloOpcode::kAllReduceDone, ar_start0)); + + auto dot0 = builder.AddInstruction( + HloInstruction::CreateDot(shape, p0, p1, dot_dnums, precision_config)); + + HloInstruction* ar_start1 = + builder.AddInstruction(HloInstruction::CreateAllReduceStart( + shape, {p1}, add_reducer, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/1, + /*use_global_device_ids=*/true)); + HloInstruction* ar_done1 = + builder.AddInstruction(HloInstruction::CreateUnary( + shape, HloOpcode::kAllReduceDone, ar_start1)); + + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + shape, ar_done0, ar_done1, dot_dnums, precision_config)); + + auto ret = + builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); + auto computation = builder.Build(); + computation->set_root_instruction(ret); + auto entry_computation = + module->AddEntryComputation(std::move(computation)); + estimator->SetInstructionCost(ar_start0, 1); + estimator->SetInstructionCost(ar_done0, 1); + estimator->SetInstructionCost(dot0, 10); + estimator->SetInstructionCost(ar_start1, 1); + estimator->SetInstructionCost(ar_done1, 1); + estimator->SetInstructionCost(dot1, 10); + estimator->SetInstructionBetween(ar_done0, 100); + estimator->SetInstructionBetween(ar_done1, 10); + + VLOG(2) << "finish creating instruction now scheduling" + << module->has_schedule(); + + // let module have one schedule + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + + return entry_computation; + } + StatusOr MakeTestComputation(HloModule* module) { + // param: p0,p1,p2,p3 + // d01 = dot(p0,p1) + // d23 = dot(p2,p3) + // + auto add_reducer = MakeReduction(HloOpcode::kAdd, module); + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 256, 256}); + Shape reduce_shape = ShapeUtil::MakeShape(F32, {2, 256, 256}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + int64_t channel_id = 0; + auto precision_config = DefaultPrecisionConfig(2); + auto p0 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, shape, "p0")); + auto p1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, shape, "p1")); + auto p2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, shape, "p2")); + auto p3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, shape, "p3")); + auto d01 = builder.AddInstruction( + HloInstruction::CreateDot(shape, p0, p1, dot_dnums, precision_config)); + auto d23 = builder.AddInstruction( + HloInstruction::CreateDot(shape, p2, p3, dot_dnums, precision_config)); + HloInstruction* all_reduce_start = builder.AddInstruction( + HloInstruction::CreateAllReduceStart( + shape, {d01}, add_reducer, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/1, + /*use_global_device_ids=*/true), + "all_reduce_start"); + HloInstruction* ar_done0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAllReduceDone, + all_reduce_start), + "ar_done0"); + + HloInstruction* all_reduce_start1 = + builder.AddInstruction(HloInstruction::CreateAllReduceStart( + shape, {d23}, add_reducer, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/1, + /*use_global_device_ids=*/true)); + HloInstruction* ar_done1 = + builder.AddInstruction(HloInstruction::CreateUnary( + shape, HloOpcode::kAllReduceDone, all_reduce_start1)); + + // d01 dot with p0 + auto d01_dot_p0 = builder.AddInstruction(HloInstruction::CreateDot( + shape, ar_done0, p0, dot_dnums, precision_config)); + // d23 dot with p1 + auto d23_dot_p1 = builder.AddInstruction(HloInstruction::CreateDot( + shape, ar_done1, p1, dot_dnums, precision_config)); + // d01_dot_p0 add p2 + auto d01_dot_p0_add_p2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, d01_dot_p0, p2)); + // d23_dot_p1 add p3 + auto d23_dot_p1_add_p3 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, d23_dot_p1, p3)); + auto rs_computation = + MakeReduceScatter(HloOpcode::kAdd, shape, reduce_shape, module); + + const Shape async_start_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({shape}), reduce_shape}); + HloInstruction* async_start0 = + builder.AddInstruction(HloInstruction::CreateAsyncStart( + async_start_shape, {d01_dot_p0_add_p2}, rs_computation, + /*async_execution_thread=*/"parallel_thread")); + HloInstruction* async_update = builder.AddInstruction( + HloInstruction::CreateAsyncUpdate(async_start_shape, async_start0)); + auto reducescater_ret0 = builder.AddInstruction( + HloInstruction::CreateAsyncDone(reduce_shape, async_update)); + // new version:need create other computation + auto rs_computation2 = + MakeReduceScatter(HloOpcode::kAdd, shape, reduce_shape, module); + HloInstruction* async_start1 = + builder.AddInstruction(HloInstruction::CreateAsyncStart( + async_start_shape, {d23_dot_p1_add_p3}, rs_computation2, + /*async_execution_thread=*/"parallel_thread")); + HloInstruction* async_update1 = builder.AddInstruction( + HloInstruction::CreateAsyncUpdate(async_start_shape, async_start1)); + auto reducescater_ret1 = builder.AddInstruction( + HloInstruction::CreateAsyncDone(reduce_shape, async_update1)); + + auto add_reduce_scatter = builder.AddInstruction( + HloInstruction::CreateBinary(reduce_shape, HloOpcode::kAdd, + reducescater_ret1, reducescater_ret0)); + auto all2all_ret = builder.AddInstruction(HloInstruction::CreateAllToAll( + reduce_shape, {add_reduce_scatter}, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/true, /*channel_id=*/2, + /*split_dimension*/ 0)); + std::vector compute_vec = {d01_dot_p0_add_p2, + d23_dot_p1_add_p3, all2all_ret}; + auto ret = builder.AddInstruction(HloInstruction::CreateTuple(compute_vec)); + auto computation = builder.Build(); + computation->set_root_instruction(ret); + auto entry_computation = + module->AddEntryComputation(std::move(computation)); + VLOG(2) << "finish creating instruction now scheduling" + << module->has_schedule(); + + // let module have one schedule + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + + return entry_computation; + } + StatusOr MakeRandomComputation( + HloModule* module, SavedInstLatencyEstimator* estimator, + uint32_t inst_nums = 100, uint8_t max_deps = 5, + double communication_rate = 0.1f, + std::mt19937 gen = std::mt19937{kRandomSeed}, + CostGenerator cost_gen = CostGenerator(50, 5, kRandomSeed)) { + /* create instruction list with inst_nums instructions + every inst be used by output + */ + VLOG(2) << "create computation begin,test name: " << TestName() + << ",inst_nums=" << inst_nums << ",max_deps=" << max_deps + << ",communication_rate=" << communication_rate; + HloComputation::Builder builder( + absl::StrCat(TestName(), "N", inst_nums, "R", communication_rate)); + Shape shape = ShapeUtil::MakeShape(F32, {4, 256, 256}); + + // insts_list: store instruction list,which have one result + std::vector insts_list; + + uint32_t communication_count = std::floor(communication_rate * inst_nums); + uint32_t insert_comm_every = inst_nums / communication_count; + + // Node cost must add after AddEntryComputation,so that instruction have + // unique_id + std::vector> insts2cost; + std::vector> edge2cost; + + std::set used_insts; + std::set not_used_insts; + + CostGenerator random_gen = CostGenerator(50, 5, kRandomSeed); + + for (size_t i = 0; i < inst_nums; i++) { + // random deps from 1~5 + if (i < 2) { + auto inst = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/i, shape, "p" + std::to_string(i))); + insts_list.push_back(inst); + insts2cost.push_back(std::make_tuple(inst, 0)); + not_used_insts.insert(inst); + continue; + } + uint32_t deps_count = 2; + std::vector deps; + // from 0~i, pick deps_count insts as deps + if (not_used_insts.size() >= deps_count) { + // first pick not used insts + std::sample(not_used_insts.begin(), not_used_insts.end(), + std::back_inserter(deps), deps_count, gen); + // remove deps from not_used_insts + for (auto& dep : deps) { + not_used_insts.erase(dep); + } + } else { + std::sample(insts_list.begin(), insts_list.end(), + std::back_inserter(deps), deps_count, gen); + } + + if (deps.size() != 2) { + return absl::InvalidArgumentError("deps size not equal 2"); + } + + auto inst = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, deps.at(0), deps.at(1))); + // we can add control dependency to inst + // random add control dep, test control_predecessors issue + if (random_gen() > 60) { + std::vector control_deps; + std::sample(insts_list.begin(), insts_list.end(), + std::back_inserter(control_deps), 2, gen); + for (auto control_dep : control_deps) { + auto status = control_dep->AddControlDependencyTo(inst); + if (!status.ok()) { + return absl::InvalidArgumentError("AddControlDependencyTo error"); + } + } + } + + insts_list.push_back(inst); + insts2cost.push_back(std::make_tuple(inst, cost_gen())); + not_used_insts.insert(inst); + + if (i % insert_comm_every == 0) { + uint8_t comm_deps_count = 2; + // from 0~i, pick deps_count insts + std::vector comm_deps; + if (not_used_insts.size() >= comm_deps_count) { + // first pick not used insts + std::sample(not_used_insts.begin(), not_used_insts.end(), + std::back_inserter(comm_deps), comm_deps_count, gen); + // remove deps from not_used_insts + for (auto& dep : comm_deps) { + not_used_insts.erase(dep); + } + } else { + // pick from all insts + std::sample(insts_list.begin(), insts_list.end(), + std::back_inserter(comm_deps), comm_deps_count, gen); + } + if (comm_deps.size() != comm_deps_count) { + return absl::InvalidArgumentError("comm_deps size not equal 2"); + } + for (auto& dep : comm_deps) { + auto all_reduce_start = + builder.AddInstruction(HloInstruction::CreateAllReduceStart( + shape, {dep}, MakeReduction(HloOpcode::kAdd, module), + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/1, + /*use_global_device_ids=*/true)); + // estimator->SetInstructionCost(all_reduce_start, 1); + insts2cost.push_back(std::make_tuple(all_reduce_start, 1)); + auto ar_done = builder.AddInstruction(HloInstruction::CreateUnary( + shape, HloOpcode::kAllReduceDone, all_reduce_start)); + // estimator->SetInstructionCost(ar_done, 1); + insts2cost.push_back(std::make_tuple(ar_done, 1)); + + insts_list.push_back(ar_done); + edge2cost.push_back(std::make_tuple(ar_done, cost_gen() + 50)); + not_used_insts.insert(ar_done); + } + } + } + // get no use insts,let them sum and return,avoid graph optimizer delete + // them + + auto reduce_sum_func = [&](HloInstruction* left, HloInstruction* right) { + return builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, left, right)); + }; + auto sum_op = std::reduce( + not_used_insts.begin(), not_used_insts.end(), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + reduce_sum_func); + + std::vector last_n_insts; + uint32_t last_n = 3; + for (size_t i = 0; i < last_n; i++) { + last_n_insts.push_back(insts_list.at(inst_nums - 1 - i)); + } + last_n_insts.push_back(sum_op); + auto last_ret = builder.AddInstruction( + HloInstruction::CreateTuple(absl::MakeSpan(last_n_insts))); + estimator->SetInstructionCost(last_ret, 0); + auto computation = builder.Build(); + VLOG(2) << "create computation success,test name" << TestName(); + computation->set_root_instruction(last_ret); + // Node cost must after AddEntryComputation,so that instruction have + // unique_id + auto computation_ptr = module->AddEntryComputation(std::move(computation)); + for (auto& inst_cost : insts2cost) { + estimator->SetInstructionCost(std::get<0>(inst_cost), + std::get<1>(inst_cost)); + } + for (auto& edge_cost : edge2cost) { + estimator->SetInstructionBetween(std::get<0>(edge_cost), + std::get<1>(edge_cost)); + } + // let module have one schedule + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + VLOG(2) << "setting default schedule finish." << TestName(); + + return computation_ptr; + } + StatusOr RunLatencyHidingScheduler( + HloModule* module, SchedulerConfig sched_config = GetDefaultSchedConfig(), + std::unique_ptr latency_estimator = + std::make_unique()) { + AsyncCollectiveCreator::CollectiveCreatorConfig config{ + /*convert_all_reduce=*/HloPredicateTrue, + /*convert_all_gather=*/HloPredicateTrue, + /*convert_collective_permute=*/HloPredicateTrue}; + TF_ASSIGN_OR_RETURN(bool value, + AsyncCollectiveCreator(std::move(config)).Run(module)); + HloCostAnalysis::ShapeSizeFunction shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + auto async_tracker = std::make_unique(sched_config); + auto scheduler_core = std::make_unique( + shape_size_bytes, async_tracker.get(), latency_estimator.get(), + sched_config); + TF_ASSIGN_OR_RETURN( + bool is_success, + LatencyHidingScheduler(std::move(latency_estimator), + std::move(async_tracker), + std::move(scheduler_core), shape_size_bytes) + .Run(module)); + + return ScheduleStatus{is_success, 0, 0}; + } + StatusOr RunScheduler( + HloModule* module, SchedulerConfig sched_config = GetDefaultSchedConfig(), + std::unique_ptr latency_estimator = + std::make_unique()) { + HloCostAnalysis::ShapeSizeFunction shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + auto async_tracker = std::make_unique(sched_config); + auto scheduler_core = std::make_unique( + shape_size_bytes, async_tracker.get(), latency_estimator.get(), + sched_config); + auto test_pass = + AutoReorderPass(std::move(latency_estimator), std::move(async_tracker), + std::move(scheduler_core), shape_size_bytes); + TF_ASSIGN_OR_RETURN(bool is_success, test_pass.Run(module)); + + return ScheduleStatus{is_success, 0, 0}; + } + StatusOr> ParseHloText( + absl::string_view hlo_string) { + TF_ASSIGN_OR_RETURN( + auto hlo_module, + ParseAndReturnVerifiedModule(hlo_string, GetModuleConfigForTest())); + return StatusOr>(std::move(hlo_module)); + } + tsl::Status RebuildHloOrdering(HloSchedule& module_schedule, + HloComputation* entry_computation) { + bool is_debug = false; + // module_schedule.remove_computation(entry_computation); + // module_schedule.GetOrCreateSequence(entry_computation); + auto status = module_schedule.UpdateComputationSchedule(entry_computation); + debug_log("UpdateComputationSchedule"); + + if (!status.ok()) { + debug_log("UpdateComputationSchedule error:" << status.message()); + return status; + } else { + debug_log( + "UpdateComputationSchedule success:" + << module_schedule.sequence(entry_computation).instructions().size()); + } + status = module_schedule.Update({}); + if (!status.ok()) { + std::cout << "Update error:" << status.message() << std::endl; + return status; + } + // SequentialHloOrdering seq_ordering(module_schedule); + // auto seqs = seq_ordering.SequentialOrder(*entry_computation); + // module_schedule.set_sequence(entry_computation, *seqs); + // debug_log("seqs length" << seqs.size()); + + auto new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + debug_log("new_instruction_sequence length" + << new_instruction_sequence.size()); + for (auto i = 0; i < new_instruction_sequence.size(); i++) { + auto inst = new_instruction_sequence.at(i); + debug_log("rebuild idx=" << i << "=" << inst->ToString()); + } + status = module_schedule.Verify(); + if (!status.ok()) { + debug_log("Verify error:" << status.message()); + return status; + } else { + debug_log( + "Verify success:" + << module_schedule.sequence(entry_computation).instructions().size()); + } + } + + void MoveInstruction(HloComputation* src_computation, + absl::string_view src_name, + HloComputation* dst_computation) { + bool is_debug = true; + + // Move instruction from src_computation to dst_computation. + auto src_instruction = src_computation->GetInstructionWithName(src_name); + // step 1: found src_instruction input args and output args + std::vector + src_inputs; // instruction which outputs is needed by src_instruction + std::vector + src_outputs; // instruction which input is src_instruction's output + for (auto i = 0; i < src_instruction->operand_count(); i++) { + auto src_input = src_instruction->mutable_operand(i); + src_inputs.push_back(src_input); + } + std::vector user_insts = src_instruction->users(); + for (auto i = 0; i < src_instruction->user_count(); i++) { + src_outputs.push_back(user_insts.at(i)); + } + // step 2: create Send Instruction for input args, create Recv Instruction + // for output args + int64_t channel_id = 0; + std::vector dst_inputs; + std::vector send_params; + dst_inputs.reserve(src_inputs.size()); + send_params.reserve(src_inputs.size()); + for (size_t i = 0; i < src_inputs.size(); i++) { + channel_id++; + auto src_input = src_inputs.at(i); + auto src_input_shape = src_input->shape(); + // src_instruction + auto token = + src_computation->AddInstruction(HloInstruction::CreateToken()); + + auto send_inst = + src_computation->AddInstruction(HloInstruction::CreateSend( + src_input, token, channel_id, false /*is_host_transfer*/)); + auto send_done = src_computation->AddInstruction( + HloInstruction::CreateSendDone(send_inst)); + token = dst_computation->AddInstruction(HloInstruction::CreateToken()); + auto recv_inst = dst_computation->AddInstruction( + HloInstruction::CreateRecv(src_input_shape, token, channel_id, + false /*is_host_transfer*/), + "dst_recv" + std::to_string(i)); + auto recv_done = dst_computation->AddInstruction( + HloInstruction::CreateRecvDone(recv_inst)); + HloInstruction* recv_parameter = dst_computation->AddInstruction( + HloInstruction::CreateGetTupleElement(recv_done, 0)); + + dst_inputs.push_back(recv_parameter); + } + channel_id++; + // step3: clone same instruction to dst_computation + auto dst_inst = + dst_computation->AddInstruction(src_instruction->CloneWithNewOperands( + src_instruction->shape(), dst_inputs)); + + // step4 :create Send Instruction from dst_compuation, create Recv + // Instruction in src_computation + auto token = dst_computation->AddInstruction(HloInstruction::CreateToken()); + + auto ret_send_inst = + dst_computation->AddInstruction(HloInstruction::CreateSend( + dst_inst, token, channel_id, false /*is_host_transfer*/)); + auto send_done = dst_computation->AddInstruction( + HloInstruction::CreateSendDone(ret_send_inst)); + + // create recv in src_computation, create token node,so recv_inst will be + // executed by scheduler + token = src_computation->AddInstruction(HloInstruction::CreateToken()); + + auto recv_inst = src_computation->AddInstruction( + HloInstruction::CreateRecv(dst_inst->shape(), token, channel_id, + false /*is_host_transfer*/), + "src_recv_ret"); + auto recv_done = src_computation->AddInstruction( + HloInstruction::CreateRecvDone(recv_inst)); + HloInstruction* recv_parameter = src_computation->AddInstruction( + HloInstruction::CreateGetTupleElement(recv_done, 0)); + + // step5: replace instruction which use src_instruction's output with Recv + // Instruction + for (size_t i = 0; i < src_outputs.size(); i++) { + /* code */ + auto src_output = src_outputs.at(i); + // add dependency + auto status = src_instruction->ReplaceUseWith(src_output, recv_parameter); + if (!status.ok()) { + std::cout << "ReplaceUseWith error:" << status.message() << std::endl; + } + absl::flat_hash_map new_instruction_uses; + int operand_num = 0; + for (const HloInstruction* operand : src_output->operands()) { + if (operand->unique_id() == src_instruction->unique_id()) { + new_instruction_uses[operand_num] = recv_parameter; + } + operand_num++; + } + for (auto it = new_instruction_uses.begin(); + it != new_instruction_uses.end(); ++it) { + status = src_output->ReplaceOperandWith(it->first, it->second); + if (!status.ok()) { + std::cout << "ReplaceOperandWith error:" << status.message() + << std::endl; + } + } + } + // step6: remove src_instruction + src_instruction->DetachFromOperandsAndUsers(); + auto status = src_computation->RemoveInstruction(src_instruction); + if (!status.ok()) { + std::cout << "RemoveInstruction error:" << status.message() << std::endl; + } else { + std::cout << "RemoveInstruction success" + << src_computation->instruction_count() << std::endl; + } + } +}; +TEST_F(AutoReorderingTest, ConvertPDO) { + // GTEST_SKIP() << "using convert here;"; + // get filepath from env + const char* env = std::getenv("XLA_AUTOREORDER_XPLANE_DIR"); + if (env == nullptr) { + GTEST_SKIP() << "have no set XLA_AUTOREORDER_XPLANE_DIR env skip"; + } + auto status = ConvertXplaneToFile( + env, "/root/tb/llama_xla_trace_2n16g/llama_fdo.jsonl"); + std::cout << status.message() << std::endl; + EXPECT_TRUE(status.ok()); +} + +TEST_F(AutoReorderingTest, CommOpCostTest) { + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction = + [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + + HloCostAnalysis::Options options_{ShapeSizeBytesFunction, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis analysis_(options_); + + auto hlo_module = CreateNewVerifiedModule(TestName(), /*replica_count*/ 2); + auto st = MakeTestComputation(hlo_module.get()); + HloComputation* comp = st.value(); + ASSERT_IS_OK(hlo_module->entry_computation()->Accept(&analysis_)); + + const HloModuleConfig& config = hlo_module->config(); + int64_t num_devices = config.num_partitions(); + int64_t num_replicas = config.replica_count(); + const HloInstruction* all_reduce_start = + comp->GetInstructionWithName("all_reduce_start"); + + EXPECT_EQ(analysis_.NumOfDevices(*all_reduce_start), 2); + // auto numel_bytes = analysis_.bytes_accessed(*all_reduce_start); +}; +TEST_F(AutoReorderingTest, AllReduceDeviceNumber) { + absl::string_view hlo_string = R"( +HloModule m +%add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %a = f32[] add(p0, p1) +} +ENTRY e { + %p0 = bf16[1024,4096] parameter(0) + %ag-start = (bf16[1024,4096], bf16[8192,4096]) all-gather-start(bf16[1024,4096] %p0), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true + %ag-done = bf16[8192,4096] all-gather-done( + (bf16[1024,4096], bf16[8192,4096]) %ag-start ) + %ar-start = bf16[8192,4096] all-reduce-start(bf16[8192,4096] %ag-done), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true to_apply=%add + %ar-done = bf16[8192,4096] all-reduce-done( + (bf16[1024,4096], bf16[8192,4096]) %ar-start ) + %add-ret = bf16[8192,4096] call(%ag-done,%ar-done), to_apply=%add + ROOT tuple = (bf16[8192,4096],bf16[8192,4096]) tuple(%ag-done, %ar-done) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + hlo_string, /*replica_count*/ 8)); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction = + [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + + HloCostAnalysis::Options options_{ShapeSizeBytesFunction, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis analysis_(options_); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + auto computation = module->entry_computation(); + const HloInstruction* ag_start = + computation->GetInstructionWithName("ag-start"); + + EXPECT_EQ(analysis_.NumOfDevices(*ag_start), 8); + auto numel_bytes = analysis_.bytes_accessed(*ag_start); + // numel_bytes should be 1024*4096*4=16777216 + EXPECT_EQ(numel_bytes, 16777216); + + const HloInstruction* ar_start = + computation->GetInstructionWithName("ar-start"); + + EXPECT_EQ(analysis_.NumOfDevices(*ar_start), 8); + numel_bytes = analysis_.bytes_accessed(*ar_start); + // numel_bytes should be 1024*4096*4=16777216 + EXPECT_EQ(numel_bytes, 16777216 * 2); + const HloInstruction* add_ret = + computation->GetInstructionWithName("add-ret "); + auto flops = analysis_.flop_count(*add_ret); + std::cout << "called flops:" << flops << std::endl; +} + +TEST_F(AutoReorderingTest, SPMDAutoReorder) { + // GTEST_SKIP()<<"new version xla can'parse this old hlo"; + absl::string_view hlo_string = R"( +HloModule SyncTensorsGraph.23, is_scheduled=true, entry_computation_layout={(f32[9600]{0}, f32[2400,12800]{1,0}, f32[400,12800]{1,0}, f32[320,9600]{1,0}, f32[1280]{0})->(f32[400,9600]{1,0}, f32[9600,320]{1,0}, f32[400,1280]{1,0})}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="45ccf07b9a0d113ede565121a6790508"} + +fused_add { + param_0 = f32[400,9600]{1,0} parameter(0) + param_1.1 = f32[9600]{0} parameter(1) + broadcast.6.1 = f32[400,9600]{1,0} broadcast(param_1.1), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"} + ROOT add.4.1 = f32[400,9600]{1,0} add(param_0, broadcast.6.1), metadata={op_type="aten__addmm" op_name="aten__addmm"} +} + +fused_add.1 { + param_0.1 = f32[400,1280]{1,0} parameter(0) + param_1.3 = f32[1280]{0} parameter(1) + broadcast.8.1 = f32[400,1280]{1,0} broadcast(param_1.3), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"} + ROOT add.5.1 = f32[400,1280]{1,0} add(param_0.1, broadcast.8.1), metadata={op_type="aten__addmm" op_name="aten__addmm"} +} + +wrapped_transpose_computation { + param_0.2 = f32[320,9600]{1,0} parameter(0) + ROOT transpose.4.1 = f32[9600,320]{1,0} transpose(param_0.2), dimensions={1,0} +} + +ENTRY SyncTensorsGraph.23_spmd { + param.5 = f32[400,12800]{1,0} parameter(2), sharding={devices=[4,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"} + param.1.0 = f32[2400,12800]{1,0} parameter(1), sharding={devices=[4,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"} + param.2.0 = f32[9600]{0} parameter(0), sharding={replicated}, metadata={op_type="xla__device_data" op_name="xla__device_data"} + param.3.0 = f32[320,9600]{1,0} parameter(3), sharding={devices=[4,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"} + param.4.0 = f32[1280]{0} parameter(4), sharding={replicated}, metadata={op_type="xla__device_data" op_name="xla__device_data"} + bitcast.44.0 = f32[12800,2400]{0,1} bitcast(param.1.0) + wrapped_transpose = f32[9600,320]{1,0} fusion(param.3.0), kind=kInput, calls=wrapped_transpose_computation + bitcast.61.0 = f32[9600,320]{0,1} bitcast(param.3.0) + all-gather-start = (f32[12800,2400]{0,1}, f32[12800,9600]{0,1}) all-gather-start(bitcast.44.0), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true, metadata={op_type="aten__addmm" op_name="aten__addmm"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":false,"no_parallel_custom_call":false},"force_earliest_schedule":false} + all-gather-done = f32[12800,9600]{0,1} all-gather-done(all-gather-start), metadata={op_type="aten__addmm" op_name="aten__addmm"} + all-gather-start.1 = (f32[9600,320]{0,1}, f32[9600,1280]{0,1}) all-gather-start(bitcast.61.0), channel_id=2, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true, metadata={op_type="aten__addmm" op_name="aten__addmm"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":false,"no_parallel_custom_call":false},"force_earliest_schedule":false} + all-gather-done.1 = f32[9600,1280]{0,1} all-gather-done(all-gather-start.1), metadata={op_type="aten__addmm" op_name="aten__addmm"} + custom-call.2.0 = (f32[400,9600]{1,0}, s8[4194304]{0}) custom-call(param.5, all-gather-done), custom_call_target="__cublas$gemm", metadata={op_type="aten__addmm" op_name="aten__addmm"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","lhs_stride":"5120000","rhs_stride":"122880000","grad_x":false,"grad_y":false},"force_earliest_schedule":false} + get-tuple-element.3.0 = f32[400,9600]{1,0} get-tuple-element(custom-call.2.0), index=0, metadata={op_type="aten__addmm" op_name="aten__addmm"} + loop_add_fusion = f32[400,9600]{1,0} fusion(get-tuple-element.3.0, param.2.0), kind=kLoop, calls=fused_add, metadata={op_type="aten__addmm" op_name="aten__addmm"} + custom-call.3.0 = (f32[400,1280]{1,0}, s8[4194304]{0}) custom-call(loop_add_fusion, all-gather-done.1), custom_call_target="__cublas$gemm", metadata={op_type="aten__addmm" op_name="aten__addmm"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","lhs_stride":"3840000","rhs_stride":"12288000","grad_x":false,"grad_y":false},"force_earliest_schedule":false} + get-tuple-element.4.0 = f32[400,1280]{1,0} get-tuple-element(custom-call.3.0), index=0, metadata={op_type="aten__addmm" op_name="aten__addmm"} + loop_add_fusion.1 = f32[400,1280]{1,0} fusion(get-tuple-element.4.0, param.4.0), kind=kLoop, calls=fused_add.1, metadata={op_type="aten__addmm" op_name="aten__addmm"} + ROOT tuple.1.0 = (f32[400,9600]{1,0}, f32[9600,320]{1,0}, f32[400,1280]{1,0}) tuple(loop_add_fusion, wrapped_transpose, loop_add_fusion.1) +} // SyncTensorsGraph.23_spmd +)"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + auto gpu_latency_estimator = std::make_unique(); + std::unique_ptr latency_estimator; + int pointer_size_ = 4; + Backend& test_backend = backend(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + VLOG(2) << "threads_per_block_limit:" + << gpu_device_info.threads_per_block_limit() << " threads_per_warp" + << gpu_device_info.threads_per_warp(); + const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( + hlo_module.get(), gpu_device_info, pointer_size_); + SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + HloCostAnalysis::ShapeSizeFunction shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + + auto async_tracker = std::make_unique(sched_config); + latency_estimator = std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size_](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + hlo_module->entry_computation()); + auto entry_computation = hlo_module->entry_computation(); + auto scheduler_core = std::make_unique( + shape_size_bytes, async_tracker.get(), latency_estimator.get(), + sched_config); + auto test_pass = + AutoReorderPass(std::move(latency_estimator), std::move(async_tracker), + std::move(scheduler_core), shape_size_bytes); + + for (HloComputation* computation : + hlo_module->MakeNonfusionComputations({})) { + auto status = test_pass.ScheduleComputation(computation); + if (!status.ok()) { + std::cout << "NonfusionComputations src_module fail" << std::endl; + } + EXPECT_TRUE(status.ok()); + } +} +TEST_F(AutoReorderingTest, DemoAutoReorder) { + GTEST_SKIP() << "Skipping DemoAutoReorder"; + + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +%add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %a = f32[] add(p0, p1) +} + +ENTRY %module { + %constant.19 = u32[] constant(0) + %replica_id = u32[]{:T(128)} replica-id() + %convert = f32[]{:T(128)} convert(u32[]{:T(128)} %replica_id) + %color_operand.1 = f32[2,8,256,256]{3,2,1,0:T(8,128)} broadcast( + f32[]{:T(128)} %convert), dimensions={} + %color_operand.2 = f32[2,8,256,256]{3,2,1,0:T(8,128)} broadcast( + f32[]{:T(128)} %convert), dimensions={} + %ar-start = f32[2,8,256,256] all-reduce-start( + f32[2,8,256,256] %color_operand.1), replica_groups={{0,1}}, to_apply=%add, + metadata={op_type="AllReduce" op_name="ar0"} + %ar-start.2 = f32[2,8,256,256] all-reduce-start( + f32[2,8,256,256] %color_operand.2), replica_groups={{0,1}}, to_apply=%add, + metadata={op_type="AllReduce" op_name="ar1"} + %ar-done = f32[2,8,256,256] all-reduce-done( + f32[2,8,256,256] %ar-start), + metadata={op_type="AllReduce" op_name="ar0"} + %ar-done-bc = f32[16,256,256] bitcast(f32[2,8,256,256] %ar-done), + metadata={op_type="Bitcast" op_name="ar0"} + %ar-done.2 = f32[2,8,256,256] all-reduce-done( + f32[2,8,256,256] %ar-start.2), + metadata={op_type="AllReduce" op_name="ar1"} + %ar-done-bc.2 = f32[16,256,256] bitcast(f32[2,8,256,256] %ar-done.2), + metadata={op_type="Bitcast" op_name="ar1"} + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[16,64,256]{2,1,0} parameter(1) + p2 = f32[16,256,256]{2,1,0} parameter(2) + p3 = f32[16,256,256]{2,1,0} parameter(3) + c0 = f32[16,256,256]{2,1,0} convolution(p0, p1), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, + metadata={op_type="AllReduce" op_name="c0"} + c1 = f32[16,256,256]{2,1,0} convolution(p0, p1), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, + metadata={op_type="AllReduce" op_name="c1"} + a2 = f32[16,256,256]{2,1,0} add(c1, c0) + a3 = f32[16,256,256]{2,1,0} add(a2, c0) + a4 = f32[16,256,256]{2,1,0} add(a3, c1) + ROOT t = (f32[16,256,256], f32[16,256,256], f32[16,256,256]) tuple(a4, %ar-done-bc.2, %ar-done-bc) +} +)"; + absl::string_view hlo_string_cpu = R"( +HloModule module, is_scheduled=true + +%add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %a = f32[] add(p0, p1) +} + +ENTRY %module { + p2 = f32[16,256,256]{2,1,0} parameter(0) + p3 = f32[16,256,256]{2,1,0} parameter(1) + c0 = f32[16,256,256]{2,1,0} multiply(p2, p3) + a2 = f32[16,256,256]{2,1,0} add(p2, p3) + a3 = f32[16,256,256]{2,1,0} add(a2, c0) + ROOT t = (f32[16,256,256]) tuple(a3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module_cpu, ParseHloText(hlo_string_cpu)); + // VLOG(10) << module->ToString(); + auto* instruction = FindInstruction(hlo_module.get(), "param0"); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + HloComputation* entry_computation_cpu = hlo_module_cpu->entry_computation(); + + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + for (auto i = 0; i < new_instruction_sequence.size(); i++) { + auto inst = new_instruction_sequence.at(i); + std::cout << "idx=" << i << "=" << inst->ToString() << std::endl; + } + // test create H2D + // entry_computation + // idx=17=%a2 = f32[16,256,256]{2,1,0} add(f32[16,256,256]{2,1,0} %c1, + // f32[16,256,256]{2,1,0} %c0) + auto inst_a2 = new_instruction_sequence.at(17); + std::cout << "Before Move,there are:" << new_instruction_sequence.size() + << " instructions" << std::endl; + std::cout << "Before Move,there are:" + << hlo_module_cpu->schedule() + .sequence(hlo_module_cpu->entry_computation()) + .instructions() + .size() + << " instructions" << std::endl; + auto test_pass = AutoReorderPass(); + auto status = test_pass.MoveInstruction(hlo_module->entry_computation(), "a3", + hlo_module_cpu->entry_computation()); + if (!status.ok()) { + std::cout << "MoveInstruction src_module fail" << status.message() + << std::endl; + EXPECT_TRUE(status.ok()); + } + std::cout << "after Move" << std::endl; + status = + test_pass.RebuildHloOrdering(hlo_module->schedule(), entry_computation); + if (!status.ok()) { + std::cout << "RebuildHloOrdering src_module fail" << status.message() + << std::endl; + EXPECT_TRUE(status.ok()); + } + // std::cout << "after rebuild ordering src module="<< std::endl; + // new_instruction_sequence = + // hlo_module->schedule().sequence(entry_computation).instructions(); + // for (auto i = 0; i < new_instruction_sequence.size(); i++) { + // auto inst = new_instruction_sequence.at(i); + // std::cout << "idx=" << i << "=" << inst->ToString() << std::endl; + // } + + status = test_pass.RebuildHloOrdering(hlo_module_cpu->schedule(), + entry_computation_cpu); + + if (!status.ok()) { + std::cout << "RebuildHloOrdering hlo_module_cpu fail" << status.message() + << std::endl; + EXPECT_TRUE(status.ok()); + } +} +TEST_F(AutoReorderingTest, ReorderScheduleComputation) { + auto hlo_module = CreateNewUnverifiedModule(); + auto st = MakeTestComputation(hlo_module.get()); + EXPECT_TRUE(st.ok()); + auto gpu_latency_estimator = std::make_unique(); + std::unique_ptr latency_estimator; + int pointer_size_ = 4; + Backend& test_backend = backend(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + VLOG(2) << "threads_per_block_limit:" + << gpu_device_info.threads_per_block_limit() << " threads_per_warp" + << gpu_device_info.threads_per_warp(); + const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( + hlo_module.get(), gpu_device_info, pointer_size_); + SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + HloCostAnalysis::ShapeSizeFunction shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + + auto async_tracker = std::make_unique(sched_config); + latency_estimator = std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size_](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + hlo_module->entry_computation()); + auto entry_computation = hlo_module->entry_computation(); + auto scheduler_core = std::make_unique( + shape_size_bytes, async_tracker.get(), latency_estimator.get(), + sched_config); + auto test_pass = + AutoReorderPass(std::move(latency_estimator), std::move(async_tracker), + std::move(scheduler_core), shape_size_bytes); + + for (HloComputation* computation : + hlo_module->MakeNonfusionComputations({})) { + auto status = test_pass.ScheduleComputation(computation); + if (!status.ok()) { + std::cout << "NonfusionComputations src_module fail" << std::endl; + } + EXPECT_TRUE(status.ok()); + } +} +TEST_F(AutoReorderingTest, ReorderPass) { + auto hlo_module = CreateNewUnverifiedModule(); + auto st = MakeTestComputation(hlo_module.get()); + EXPECT_TRUE(st.ok()); + int pointer_size_ = 4; + Backend& test_backend = backend(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( + hlo_module.get(), gpu_device_info, pointer_size_); + SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); + auto gpu_latency_estimator = std::make_unique(); + std::unique_ptr latency_estimator = + std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size_](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + hlo_module->entry_computation()); + // we should create other estimator,otherwise it's nullptr(move to other + // place) + auto gpu_latency_estimator2 = std::make_unique(); + + auto latency_estimator2 = std::make_unique( + config, std::move(gpu_latency_estimator2), gpu_device_info, + [input_pointer_size = pointer_size_](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + hlo_module->entry_computation()); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto status = RunScheduler(hlo_module.get(), sched_config, + std::move(latency_estimator)); + EXPECT_TRUE(status.ok()); + auto statics_or_status = + GetModuleCost(hlo_module.get(), sched_config, latency_estimator2.get()); + + EXPECT_TRUE(statics_or_status.ok()); + auto statics = statics_or_status.value(); + // statics. +}; +TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) { + auto hlo_module = CreateNewUnverifiedModule(); + auto st = MakeTestComputation(hlo_module.get()); + EXPECT_TRUE(st.ok()); + int pointer_size_ = 4; + Backend& test_backend = backend(); + auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit( + hlo_module.get(), gpu_device_info, pointer_size_); + SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit); + auto gpu_latency_estimator = std::make_unique(); + std::unique_ptr latency_estimator; + latency_estimator = std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size_](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + hlo_module->entry_computation()); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto status = RunScheduler(hlo_module.get(), sched_config); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(CheckParameterInst(hlo_module.get())); +} +TEST_F(AutoReorderingTest, ReorderPassCommunicateComputation) { + // GTEST_SKIP() << "Skipping single test"; + std::srand(kRandomSeed); + auto hlo_module = CreateNewUnverifiedModule(); + auto gpu_latency_estimator = std::make_unique(); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto st = + MakeCommunicateComputation(hlo_module.get(), gpu_latency_estimator.get()); + auto gpu_latency_estimator2 = gpu_latency_estimator->clone(); + auto gpu_latency_estimator3 = gpu_latency_estimator->clone(); + + auto status = RunScheduler(hlo_module.get(), sched_config, + std::move(gpu_latency_estimator)); + EXPECT_TRUE(status.ok()); + auto insts = hlo_module->schedule() + .sequence(hlo_module->entry_computation()) + .instructions(); + for (auto inst : insts) { + std::cout << inst->ToString() << std::endl; + } +} + +TEST_F(AutoReorderingTest, ReorderPassStableOrder) { + GTEST_SKIP() << "Skipping test when dev"; + std::srand(kRandomSeed); + auto hlo_module = CreateNewUnverifiedModule(); + auto gpu_latency_estimator = std::make_unique(); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto st = MakeRandomComputation(hlo_module.get(), gpu_latency_estimator.get(), + /*inst num*/ 200, + /*max deps*/ 5, + /*communication rate*/ 0.2); + auto gpu_latency_estimator2 = gpu_latency_estimator->clone(); + auto gpu_latency_estimator3 = gpu_latency_estimator->clone(); + + auto status = RunScheduler(hlo_module.get(), sched_config, + std::move(gpu_latency_estimator)); + EXPECT_TRUE(status.ok()); + auto comm_op_1 = GetInstructionsOrderString(hlo_module.get()); + + // get hlo_module instruction order,and compute a hash + status = RunScheduler(hlo_module.get(), sched_config, + std::move(gpu_latency_estimator2)); + auto comm_op_2 = GetInstructionsOrderString(hlo_module.get()); + + EXPECT_TRUE(status.ok()); + EXPECT_EQ(comm_op_1, comm_op_2); + EXPECT_TRUE(CheckParameterInst(hlo_module.get())); +} +TEST_F(AutoReorderingTest, ReorderPassWithRandom) { + // GTEST_SKIP() << "Skipping single test"; + std::srand(kRandomSeed); + auto hlo_module = CreateNewUnverifiedModule(); + auto gpu_latency_estimator = std::make_unique(); + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto st = MakeRandomComputation(hlo_module.get(), gpu_latency_estimator.get(), + /*inst num*/ 200, + /*max deps*/ 5, + /*communication rate*/ 0.2); + // std::cout<ToString()<clone(); + auto gpu_latency_estimator3 = gpu_latency_estimator->clone(); + // run AutoReorder for compare + + auto status = RunScheduler(hlo_module.get(), sched_config, + std::move(gpu_latency_estimator)); + EXPECT_TRUE(status.ok()); + + auto statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + EXPECT_TRUE(CheckParameterInst(hlo_module.get())); + auto auto_reorder_cost = statics.value().total_cycles; + std::cout << "ReorderPassWithRandom:" << auto_reorder_cost << std::endl; + + // compare post order vs reorder + auto post_insts_order = + hlo_module->entry_computation()->MakeInstructionPostOrder(); + hlo_module->schedule().set_sequence(hlo_module->entry_computation(), + post_insts_order); + + statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + auto post_order_cost = statics.value().total_cycles; + EXPECT_TRUE(CheckParameterInst(hlo_module.get())); + std::cout << "MakeInstructionPostOrder:" << post_order_cost << std::endl; + + // run LatencyHidingScheduler for compare + // NOTICE: DO NOT using gpu_latency_estimator after + // std::move(gpu_latency_estimator) + auto lhs_status = RunLatencyHidingScheduler( + hlo_module.get(), sched_config, std::move(gpu_latency_estimator3)); + EXPECT_TRUE(lhs_status.ok()); + EXPECT_TRUE(CheckParameterInst(hlo_module.get())); + statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + auto xla_lhs_cost = statics.value().total_cycles; + EXPECT_LE(auto_reorder_cost, post_order_cost); + EXPECT_LE(auto_reorder_cost, xla_lhs_cost); +} +// skip this test +TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) { + GTEST_SKIP() << "Skipping single test because longtime"; + std::srand(kRandomSeed); + auto gen = std::mt19937{kRandomSeed}; + int repeat_time = 1; + uint32_t nnodes = 100; + std::vector communication_rates; + // = { + // 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9 + // }; + for (float current = 0.1; current < 0.9; current += 0.05) { + communication_rates.push_back(current); + } + + // communication rate from 0.05 to 0.95,step is 0.05 + std::ofstream csv_out("/tmp/test_ret.csv"); + csv_out + << "exp_id,nnodes,communication_rate,auto_reorder_cost,post_order_cost," + "xla_hiding_order_cost,xla_hiding_solve_time,auto_reorder_solve_time" + << std::endl; + for (auto communication_rate : communication_rates) { + for (size_t i = 0; i < repeat_time; i++) { + std::cout << TestName() << " repeat time:" << i << std::endl; + auto hlo_module = CreateNewUnverifiedModule(); + auto gpu_latency_estimator = + std::make_unique(); + // float communication_rate = 0.2; + SchedulerConfig sched_config = GetDefaultSchedConfig(); + auto st = + MakeRandomComputation(hlo_module.get(), gpu_latency_estimator.get(), + /*inst num*/ nnodes, + /*max deps*/ 5, + /*communication rate*/ communication_rate, + /* gen */ gen); + EXPECT_TRUE(st.ok()); + // auto latency_estimator = create_latency_estimator(); + + auto gpu_latency_estimator2 = gpu_latency_estimator->clone(); + auto gpu_latency_estimator3 = gpu_latency_estimator->clone(); + // run AutoReorder for compare + // get running time cost + auto start = std::chrono::steady_clock::now(); + auto status = RunScheduler(hlo_module.get(), sched_config, + std::move(gpu_latency_estimator)); + EXPECT_TRUE(status.ok()); + auto end = std::chrono::steady_clock::now(); + auto auto_reorder_solve_time = + std::chrono::duration_cast(end - start) + .count(); + + auto statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + auto auto_reorder_cost = statics.value().total_cycles; + // compare post order vs reorder + auto post_insts_order = + hlo_module->entry_computation()->MakeInstructionPostOrder(); + hlo_module->schedule().set_sequence(hlo_module->entry_computation(), + post_insts_order); + + statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + auto post_order_cost = statics.value().total_cycles; + + // run LatencyHidingScheduler for compare + // NOTICE: DO NOT using gpu_latency_estimator after + // std::move(gpu_latency_estimator) + start = std::chrono::steady_clock::now(); + auto lhs_status = RunLatencyHidingScheduler( + hlo_module.get(), sched_config, std::move(gpu_latency_estimator3)); + EXPECT_TRUE(lhs_status.ok()); + end = std::chrono::steady_clock::now(); + auto xla_hiding_solve_time = + std::chrono::duration_cast(end - start) + .count(); + statics = GetModuleCost(hlo_module.get(), sched_config, + gpu_latency_estimator2.get()); + EXPECT_TRUE(statics.ok()); + auto xla_hiding_order_cost = statics.value().total_cycles; + csv_out << i << "," << nnodes << "," << communication_rate << "," + << auto_reorder_cost << "," << post_order_cost << "," + << xla_hiding_order_cost << "," << xla_hiding_solve_time << "," + << auto_reorder_solve_time << std::endl; + } + } +} + +} // namespace xla diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_xla.py b/xla/hlo/experimental/auto_reorder/auto_reorder_xla.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/xla/hlo/experimental/auto_reorder/convert_xplane.cc b/xla/hlo/experimental/auto_reorder/convert_xplane.cc new file mode 100644 index 0000000000000..528fef6f4fdf4 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/convert_xplane.cc @@ -0,0 +1,325 @@ +#include "xla/hlo/experimental/auto_reorder/convert_xplane.h" + +namespace xla { + +using tensorflow::profiler::XPlane; +using tensorflow::profiler::XSpace; +using tsl::profiler::CreateTfXPlaneVisitor; +using tsl::profiler::FindPlanesWithPrefix; +using tsl::profiler::FindPlaneWithName; +using tsl::profiler::GetStatTypeStr; +using tsl::profiler::HostEventType; +using tsl::profiler::IsInternalEvent; +using tsl::profiler::ProfilerJoinPath; +using tsl::profiler::StatType; +using tsl::profiler::XEventMetadataVisitor; +using tsl::profiler::XEventVisitor; +using tsl::profiler::XLineVisitor; +using tsl::profiler::XPlaneVisitor; +using tsl::profiler::XStatVisitor; + +// maintain info for next PGLE +/* +steps + +1. parse module, maintain map of {instr_name: instr_info} func: +GetHloInstrProfileInfo +2. parse xplane, maintain map of {instr_name: instr_latency}, update instr_info +3. write to sqlite file, as origin DB +4. use origin DB,group instr_type/shape to generate summary DB +5. use summary DB for next PGLE +*/ + +void GetXPlaneLatencyInfo( + const XPlaneVisitor& xplane, + absl::flat_hash_map* hlo_latency_info) { + // Iterate events. + xplane.ForEachLine([hlo_latency_info](const XLineVisitor& xline) { + if (xline.DisplayName() == tsl::profiler::kXlaAsyncOpLineName) { + return; + } + VLOG(5) << "Processing line: " << xline.DisplayName(); + xline.ForEachEvent([hlo_latency_info](const XEventVisitor& xevent) { + int64_t event_type = + xevent.Type().value_or(HostEventType::kUnknownHostEventType); + if (IsInternalEvent(event_type)) return; + std::optional hlo_name = std::nullopt; + + auto for_each_stat = [&](const XStatVisitor& stat) { + if (stat.ValueCase() == tsl::profiler::XStat::VALUE_NOT_SET) return; + // Store latency information for HLOs. + if (stat.Name() == GetStatTypeStr(StatType::kHloOp)) { + hlo_name = stat.ToString(); + } + }; + xevent.Metadata().ForEachStat(for_each_stat); + xevent.ForEachStat(for_each_stat); + double latency = static_cast(xevent.DurationNs()) / 1e3; + VLOG(5) << "hlo_name: " << hlo_name.value_or("N/A") + << "latency:" << latency; + + std::string key = hlo_name.value(); + (*hlo_latency_info)[key].durations.emplace_back(latency); + }); + }); +} + +std::unique_ptr CreateModuleFromProto( + const xla::HloModuleProto& proto) { + auto config = xla::HloModule::CreateModuleConfigFromProto(proto, {}); + if (config.ok()) { + auto module = xla::HloModule::CreateFromProto(proto, config.value()); + if (module.ok()) { + return std::move(*module); + } + } + return nullptr; +} + +Status GetHloInstrProfileInfo( + const xla::HloModuleProto& hlo_module_proto, + absl::flat_hash_map* + hlo_module_info) { + std::unique_ptr hlo_module = + CreateModuleFromProto(hlo_module_proto); + if (hlo_module == nullptr) { + return absl::InternalError("Failed to create HloModule from proto"); + } + VLOG(5) << "success get hlo module from proto"; + for (HloComputation* computation : + hlo_module->MakeNonfusionComputations({})) { + for (auto* instr : computation->instructions()) { + // instr to json + //{name:"name",opcode:"opcode",operand_count:1,operand_names:["a"],operand_types:["f32"],shape:"[1,2,3]",result_type:"f32",result_shape:"[1,2,3]",result_element_type:"f32",result_element_shape:"[1,2,3]",result_element_count:6} + // TODO: should we need shard info? + // TODO: custom call + // there are 3 category instrs: + // 1. custom call, include GEMM now; record its input shape/dtype + // 2. communicate call, include async reducescatter ; record its input + // shape/dtype + // 3. other, + HloInstructionProto instr_origin_proto = instr->ToProto(); + auto_reorder::InstrProfileInfo instr_info; + auto_reorder::Size ret_size; + instr_info.set_name(instr_origin_proto.name()); + HloOpcode code = instr->opcode(); + + instr_info.set_opcode(static_cast(code)); + + // set operand count/type/size + instr_info.set_operand_count(instr->operand_count()); + for (auto operand : instr->operands()) { + Shape op_shape = operand->shape(); + // operand dtype + + instr_info.add_operand_types( + PrimitiveType_Name(op_shape.element_type())); + auto_reorder::Size* op_size = instr_info.add_operand_sizes(); + op_size->set_rank(op_shape.dimensions_size()); + for (size_t i = 0; i < op_shape.dimensions_size(); i++) { + op_size->add_sizes(op_shape.dimensions(i)); + } + } + + Shape shape = instr->shape(); + instr_info.mutable_result_size()->set_rank(shape.dimensions_size()); + for (size_t i = 0; i < shape.dimensions_size(); i++) { + /* code */ + instr_info.mutable_result_size()->add_sizes(shape.dimensions(i)); + } + // custom call + switch (code) { + case HloOpcode::kCustomCall: { + instr_info.set_custom_call_target(instr->custom_call_target()); + break; + } + case HloOpcode::kReduceScatter: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kCollectivePermuteStart: { // comm op need record + // process group + // example :{{1,2,3,4}}, {{1,2},{3,4}} + std::vector replica_groups = instr->replica_groups(); + uint16_t group_id = 0; + for (auto replica_group : replica_groups) { + xla::auto_reorder::ReplicaGroup* group = + instr_info.add_process_groups(); + group->set_replica_group_id(group_id); + group_id++; + for (auto replica : replica_group.replica_ids()) { + group->add_replica_ids(replica); + } + } + + // instr_info.set_process_group(); + break; + } + case HloOpcode::kAsyncStart: { + // get async inner instr + } + default: + break; + + } // end switch + hlo_module_info->emplace(instr_origin_proto.name(), instr_info); + } // end for instrs + } // end for computations + return absl::OkStatus(); +} + +void GetXPlaneHloModuleProfileInfo( + const XPlaneVisitor& xplane, + absl::flat_hash_map* + hlo_module_info) { + // Iterate events. + xplane.ForEachEventMetadata([&](const XEventMetadataVisitor& event_metadata) { + event_metadata.ForEachStat([&](const XStatVisitor& stat) { + xla::HloProto hlo_proto; + if (tsl::ParseProtoUnlimited(&hlo_proto, stat.BytesValue().data(), + stat.BytesValue().size())) { + const xla::HloModuleProto& hlo_module_proto = hlo_proto.hlo_module(); + + Status st = GetHloInstrProfileInfo(hlo_module_proto, hlo_module_info); + if (!st.ok()) { + VLOG(5) << "Failed to get HloInstrProfileInfo from HloModuleProto"; + } + } + }); + }); +} + +Status ConvertXplaneToProfiledJSONLine( + std::vector xspaces, + std::vector* jsonline_vector) { + // name to HloLatencyInfo + absl::flat_hash_map hlo_latency_info; + // name to HloInstructionProto + absl::flat_hash_map + hlo_instr_profile_info; + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = true; + options.always_print_primitive_fields = true; + google::protobuf::util::Status st; + // st = google::protobuf::util::MessageToJsonString(profile_proto, + // &json_string, options); if(!st.ok()) { + // return absl::InternalError("Failed to convert ProfiledInstructionsProto + // to json"); + // } + // Iterate through each host. + for (const XSpace& xspace : xspaces) { + const XPlane* metadata_plane = + FindPlaneWithName(xspace, tsl::profiler::kMetadataPlaneName); + if (metadata_plane != nullptr) { + XPlaneVisitor xplane = CreateTfXPlaneVisitor(metadata_plane); + GetXPlaneHloModuleProfileInfo(xplane, &hlo_instr_profile_info); + } + std::vector device_planes = + FindPlanesWithPrefix(xspace, tsl::profiler::kGpuPlanePrefix); + // We don't expect GPU and TPU planes and custom devices to be present in + // the same XSpace. + if (device_planes.empty()) { + VLOG(5) << "No GPU plane found, try to find TPU plane."; + device_planes = + FindPlanesWithPrefix(xspace, tsl::profiler::kTpuPlanePrefix); + } + if (device_planes.empty()) { + VLOG(5) << "No TPU plane found, try to find custom device plane."; + device_planes = + FindPlanesWithPrefix(xspace, tsl::profiler::kCustomPlanePrefix); + } + // Go over each device plane. + for (const XPlane* device_plane : device_planes) { + XPlaneVisitor xplane = CreateTfXPlaneVisitor(device_plane); + GetXPlaneLatencyInfo(xplane, &hlo_latency_info); + } + } + if (hlo_instr_profile_info.empty()) { + VLOG(5) << "No HLO instruction info found in xplane protobuf."; + return absl::InternalError("No HLO latency info found in xplane"); + } + if (hlo_latency_info.empty()) { + VLOG(5) << "No HLO latency info found in xplane."; + return absl::InternalError("No HLO latency info found in xplane"); + } + HloLatencyStats stats; + + // Get the mean duration for each hlo and store into the proto. + for (const auto& iter : hlo_latency_info) { + // auto* cost = profiled_instructions_proto->add_costs(); + auto profile_it = hlo_instr_profile_info.find(iter.first); + if (profile_it == hlo_instr_profile_info.end()) { + VLOG(5) << "No instr info found for instr: " << iter.first; + stats.misses++; + continue; + } else { + stats.hits++; + } + + auto_reorder::InstrProfileInfo cost = profile_it->second; + for (auto duration : iter.second.durations) { + // cost->add_durations(d); + cost.set_cost(duration); + std::string json_string; + auto st = google::protobuf::util::MessageToJsonString(cost, &json_string, + options); + if (!st.ok()) { + return absl::InternalError( + "Failed to convert ProfiledInstructionsProto to json"); + } + jsonline_vector->push_back(json_string); + } + } + VLOG(5) << "Lookup inst profiler, Hits: " << stats.hits + << " Misses: " << stats.misses; + return OkStatus(); +} +Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( + const std::string& logdir, std::vector* jsonline_vector) { + // Find the xplane files for each host under logdir. + std::vector children_path; + TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(logdir, &children_path)); + if (children_path.empty()) { + return absl::NotFoundError( + absl::StrCat("Could not find file under: ", logdir)); + } + std::vector xspaces; + for (const std::string& child_path : children_path) { + if (absl::StrContains(child_path, kXPlanePb)) { + std::string xspace_path = ProfilerJoinPath(logdir, child_path); + tensorflow::profiler::XSpace xspace; + TF_RETURN_IF_ERROR( + ReadBinaryProto(tsl::Env::Default(), xspace_path, &xspace)); + xspaces.emplace_back(xspace); + } + } + if (xspaces.size() == 0) { + return absl::NotFoundError( + absl::StrCat("Could not find xplane file under: ", logdir)); + } + VLOG(3) << "Have load " << xspaces.size() << " xspaces"; + return ConvertXplaneToProfiledJSONLine(xspaces, jsonline_vector); +} + +Status ConvertXplaneToFile(const std::string& xplane_dir, + const std::string& output_filename) { + tensorflow::profiler::ProfiledInstructionsProto profile_proto; + std::vector jsonline_vector; + auto status = ConvertXplaneUnderLogdirToProfiledInstructionsProto( + xplane_dir, &jsonline_vector); + if (!status.ok()) { + return status; + } + // open file,write jsonline + std::ofstream fout = std::ofstream(output_filename); + if (!fout.is_open()) { + return absl::InternalError("Failed to open file for writing"); + } + for (const std::string& jsonline : jsonline_vector) { + fout << jsonline << std::endl; + } + return OkStatus(); +} + +} // namespace xla \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/convert_xplane.h b/xla/hlo/experimental/auto_reorder/convert_xplane.h new file mode 100644 index 0000000000000..46e0f382b1dbe --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/convert_xplane.h @@ -0,0 +1,82 @@ +// decouple xla/python deps, xla/python need +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" + +// #include "xla/status.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" +#include "xla/primitive_util.h" +#include "xla/xla.pb.h" + +#include "tsl/platform/env.h" +#include "tsl/platform/types.h" +#include "tsl/profiler/convert/xla_op_utils.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/file_system_utils.h" +#include "tsl/profiler/utils/tf_xplane_visitor.h" +#include "tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/utils/xplane_utils.h" +#include "tsl/profiler/utils/xplane_visitor.h" +#include "tsl/profiler/protobuf/profiled_instructions.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "google/protobuf/util/json_util.h" +#include "xla/hlo/experimental/auto_reorder/instr_profile_info.pb.h" + +namespace xla { + +constexpr char kXPlanePb[] = "xplane.pb"; +constexpr char kCostNameSep[] = "::"; + +using tensorflow::profiler::XPlane; +using tensorflow::profiler::XSpace; +using tsl::profiler::CreateTfXPlaneVisitor; +using tsl::profiler::FindPlanesWithPrefix; +using tsl::profiler::FindPlaneWithName; +using tsl::profiler::GetStatTypeStr; +using tsl::profiler::HostEventType; +using tsl::profiler::IsInternalEvent; +using tsl::profiler::ProfilerJoinPath; +using tsl::profiler::StatType; +using tsl::profiler::XEventMetadataVisitor; +using tsl::profiler::XEventVisitor; +using tsl::profiler::XLineVisitor; +using tsl::profiler::XPlaneVisitor; +using tsl::profiler::XStatVisitor; + +// Latency info for a single HLO instruction. it's a vector of durations. Each +// duration is the latency of the instruction +struct HloLatencyInfo { + std::vector durations; +}; +struct HloLatencyStats { + uint32_t hits; + uint32_t misses; +}; +Status ConvertXplaneToProfiledInstructionsProto( + std::vector xspaces, + tensorflow::profiler::ProfiledInstructionsProto* + profiled_instructions_proto); +Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( + const std::string& logdir, tensorflow::profiler::ProfiledInstructionsProto* + profiled_instructions_proto); + +Status ConvertXplaneToFile(const std::string& xplane_dir, + const std::string& output_filename); + +} // namespace xla \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/convert_xplane_bin.cc b/xla/hlo/experimental/auto_reorder/convert_xplane_bin.cc new file mode 100644 index 0000000000000..442a3bb9070a1 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/convert_xplane_bin.cc @@ -0,0 +1,3 @@ +#include "xla/hlo/experimental/auto_reorder/convert_xplane.h" + +int main() { return 0; } \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/instr_profile_info.proto b/xla/hlo/experimental/auto_reorder/instr_profile_info.proto new file mode 100644 index 0000000000000..afe8162aa6837 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/instr_profile_info.proto @@ -0,0 +1,42 @@ +// Copyright 2023 The Lynx Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +package xla.auto_reorder; + +import "google/protobuf/any.proto"; +// Size(rank=3,sizes=[2,3,4]) +message Size { + int64 rank = 1; + repeated int64 sizes = 2; +} +//ReplicaGroup(replica_group_id=1, replica_ids={1,2}) +message ReplicaGroup{ + int64 replica_group_id=1; + repeated int64 replica_ids=2; +} +// as xla/service/hlo.proto HloInstructionProto subset,we focus on compute/communicate complexity +message InstrProfileInfo { + string name = 1; + uint32 operand_count=2; + uint32 result_count=3; + uint32 opcode=4; + uint32 version=5; + repeated string operand_types = 6; + repeated string result_types = 7; + repeated Size operand_sizes = 8; + Size result_size = 9; + repeated ReplicaGroup process_groups=10; + optional string custom_call_target = 11; + double cost=12; +} \ No newline at end of file diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index 9eefa8a0c02f0..dc2a1fe5a06f0 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -1,8 +1,9 @@ # Automatic sharding annotation load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:xla.bzl", "auto_sharding_deps", "auto_sharding_solver_deps", "xla_cc_binary", "xla_cc_test") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_libtpu_portable") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") +load("//xla:xla.bzl", "auto_sharding_deps", "auto_sharding_solver_deps", "xla_cc_binary", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -23,10 +24,12 @@ cc_library( srcs = [ "auto_sharding.cc", "auto_sharding_dot_handler.cc", + "auto_sharding_strategy.cc", ], hdrs = [ "auto_sharding.h", ], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -59,10 +62,12 @@ cc_library( "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_pass", "//xla/service:hlo_value", + "//xla/service:optimize_input_output_buffer_alias", "//xla/service:sharding_propagation", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -79,9 +84,11 @@ cc_library( cc_library( name = "auto_sharding_solver_impl", srcs = ["auto_sharding_solver_impl.cc"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", ":auto_sharding_strategy", + "@com_google_absl//absl/log:check", "@com_google_ortools//ortools/linear_solver", ], ) @@ -89,11 +96,15 @@ cc_library( cc_library( name = "auto_sharding_solver", srcs = ["auto_sharding_solver.cc"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", ":auto_sharding_strategy", + "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:util", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -115,14 +126,16 @@ cc_library( "auto_sharding_solver.h", "auto_sharding_strategy.h", ], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_proto_cc", "//xla:shape_util", - "//xla:statusor", + "//xla:status", "//xla/hlo/ir:hlo", "//xla/service:hlo_value", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_ortools//ortools/linear_solver", ], @@ -130,12 +143,16 @@ cc_library( cc_library( name = "auto_sharding_cost_graph", + srcs = ["auto_sharding_cost_graph.cc"], hdrs = ["auto_sharding_cost_graph.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_strategy", ":matrix", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -146,19 +163,20 @@ cc_library( name = "auto_sharding_option", srcs = ["auto_sharding_option.cc"], hdrs = ["auto_sharding_option.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", ], ) cc_library( name = "auto_sharding_wrapper", hdrs = ["auto_sharding_wrapper.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -167,12 +185,15 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", ], ) cc_library( name = "auto_sharding_impl", srcs = ["auto_sharding_impl.cc"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", ":auto_sharding_option", @@ -182,12 +203,14 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", ], ) cc_library( name = "matrix", hdrs = ["matrix.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ "@com_google_absl//absl/strings", "@tsl//tsl/platform:logging", @@ -198,19 +221,23 @@ cc_library( name = "cluster_environment", srcs = ["cluster_environment.cc"], hdrs = ["cluster_environment.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", ":profiling_result", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/types:span", ], ) cc_library( name = "profiling_result", hdrs = ["profiling_result.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [":auto_sharding_strategy"], ) @@ -218,6 +245,7 @@ cc_library( name = "auto_sharding_util", srcs = ["auto_sharding_util.cc"], hdrs = ["auto_sharding_util.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_strategy", "//xla:array", @@ -227,9 +255,11 @@ cc_library( "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:ptrvec", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", "//xla/service:sharding_propagation", + "//xla/service:while_loop_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -248,18 +278,19 @@ cc_library( name = "metrics", srcs = ["metrics.cc"], hdrs = ["metrics.h"], + compatible_with = get_compatible_with_libtpu_portable(), deps = ["@tsl//tsl/lib/monitoring:counter"], ) xla_cc_binary( name = "auto_sharding_runner", srcs = ["auto_sharding_runner.cc"], + compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding", "//xla:status", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", - "//xla/stream_executor:dnn", "//xla/tools:hlo_module_loader", "@tsl//tsl/platform:platform_port", ], @@ -288,14 +319,23 @@ xla_cc_test( deps = [ ":auto_sharding", ":auto_sharding_option", - ":auto_sharding_proto_cc", + ":auto_sharding_strategy", ":auto_sharding_util", + "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", + "//xla/service:buffer_value", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_parser", + "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", @@ -312,9 +352,12 @@ xla_cc_test( ], deps = [ ":auto_sharding_proto_cc", - ":auto_sharding_solver", + ":auto_sharding_solver", # build_cleaner: keep ":auto_sharding_strategy", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", ], ) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3de8fbeb86ce7..8231912f55b54 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -57,6 +56,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/metrics.h" #include "xla/hlo/experimental/auto_sharding/profiling_result.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -74,10 +74,12 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_value.h" +#include "xla/service/optimize_input_output_buffer_alias.h" #include "xla/service/sharding_propagation.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -90,12 +92,12 @@ namespace spmd { namespace { constexpr double kOverbudgetCoeff = 1e6; -constexpr double kSaltiplier = 0.001; // Modifies each obj. term by at most .1% +constexpr double kSaltiplier = 0.0; // This value (0.0) disables salting. } // namespace -// Compute the resharding cost vector from multiple possible strategies -// to a desired sharding spec. -std::vector ReshardingCostVector( +// Compute the resharding cost vector from multiple possible strategies to a +// desired sharding spec. +std::vector CommunicationReshardingCostVector( const StrategyGroup* strategy_group, const Shape& operand_shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env) { @@ -112,9 +114,66 @@ std::vector ReshardingCostVector( return ret; } +double ComputeMemoryReshardingCost(const Shape& shape, + const HloSharding& src_sharding, + const HloSharding& dst_sharding, + const Array& device_mesh) { + int64_t src_n_dim = NumTileDimensions(src_sharding); + int64_t dst_n_dim = NumTileDimensions(dst_sharding); + + int64_t src_sharded_bytes = GetShardedInstructionSize( + shape, device_mesh.num_elements(), src_sharding); + double result = std::max( + src_sharded_bytes, GetShardedInstructionSize( + shape, device_mesh.num_elements(), dst_sharding)); + + if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { + Shape inter_shape = ComputeIntermediateShape(src_sharding, dst_sharding, + shape, device_mesh); + + std::optional src_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, inter_shape, src_sharding); + std::optional dst_inter_sharding = + hlo_sharding_util::ReshapeSharding(shape, inter_shape, dst_sharding); + if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { + src_inter_sharding = HloSharding::Replicate(); + dst_inter_sharding = HloSharding::Replicate(); + } + + result = std::max( + result, + static_cast(std::max( + GetShardedInstructionSize(inter_shape, device_mesh.num_elements(), + src_inter_sharding), + GetShardedInstructionSize(inter_shape, device_mesh.num_elements(), + dst_inter_sharding)))); + } + return result - src_sharded_bytes; +} + +std::vector MemoryReshardingCostVector( + const StrategyGroup* strategy_group, const Shape& operand_shape, + const HloSharding& required_sharding, + const ClusterEnvironment& cluster_env) { + CHECK(!strategy_group->is_tuple) << "Only works with strategy vector."; + std::vector ret; + ret.reserve(strategy_group->strategies.size()); + auto required_sharding_for_resharding = required_sharding.IsTileMaximal() + ? HloSharding::Replicate() + : required_sharding; + CHECK_OK(required_sharding.Validate(operand_shape)) + << strategy_group->ToString(); + for (const auto& x : strategy_group->strategies) { + ret.push_back(ComputeMemoryReshardingCost(operand_shape, x.output_sharding, + required_sharding_for_resharding, + cluster_env.device_mesh_)); + } + return ret; +} + // Factory functions for StrategyGroup. std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( - size_t instruction_id, StrategyGroups& strategy_groups) { + const size_t instruction_id, StrategyGroups& strategy_groups) { auto strategy_group = std::make_unique(); strategy_group->is_tuple = false; strategy_group->node_idx = strategy_groups.size(); @@ -125,7 +184,7 @@ std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( // Factory functions for StrategyGroup. std::unique_ptr CreateLeafStrategyGroup( - size_t instruction_id, const HloInstruction* ins, + const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, StrategyGroups& strategy_groups) { auto strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); @@ -135,7 +194,8 @@ std::unique_ptr CreateLeafStrategyGroup( return strategy_group; } -std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id) { +std::unique_ptr CreateTupleStrategyGroup( + const size_t instruction_id) { auto strategy_group = std::make_unique(); strategy_group->is_tuple = true; strategy_group->node_idx = -1; @@ -143,43 +203,38 @@ std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id) { return strategy_group; } -// ShardingPropagation::GetShardingFromUser does not handle TopK custom -// calls. Mirroring that function's handling of kSort, we handle TopK below. -HloSharding InferInputShardingForTopK(const HloInstruction* ins, - const HloSharding& output_sharding) { - return output_sharding; -} - // Compute the resharding costs as well as input shardings (when missing) for // all operands of a given instruction, and an output sharding for that // instruction. -std::vector> +std::pair GenerateReshardingCostsAndMissingShardingsForAllOperands( const HloInstruction* ins, const HloSharding& output_sharding, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const CallGraph& call_graph, std::vector>& input_shardings) { - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; if (input_shardings.empty() && ins->operand_count() > 0) { input_shardings.resize(ins->operand_count()); } for (int64_t k = 0; k < ins->operand_count(); ++k) { auto operand = ins->operand(k); if (operand->shape().IsToken() || operand->shape().rank() == 0) { - resharding_costs.push_back(std::vector( + communication_resharding_costs.push_back(std::vector( + strategy_map.at(operand)->strategies.size(), 0.0)); + memory_resharding_costs.push_back(std::vector( strategy_map.at(operand)->strategies.size(), 0.0)); if (!input_shardings[k].has_value()) { input_shardings[k] = HloSharding::Replicate(); } } else { std::optional cur_input_sharding; + CHECK_EQ(input_shardings.size(), ins->operand_count()); if (input_shardings[k].has_value()) { - CHECK_EQ(input_shardings.size(), ins->operand_count()); cur_input_sharding = input_shardings[k]; } else { - cur_input_sharding = - GetInputSharding(ins, operand, k, output_sharding, call_graph, - cluster_env.NumDevices()); + cur_input_sharding = GetInputSharding( + ins, k, output_sharding, call_graph, cluster_env.NumDevices()); } bool is_sharding_default_replicated = false; if (!cur_input_sharding.has_value()) { @@ -187,8 +242,6 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( (ins->opcode() == HloOpcode::kScatter && k != 0)) { is_sharding_default_replicated = true; cur_input_sharding = HloSharding::Replicate(); - } else if (IsTopKCustomCall(ins)) { - cur_input_sharding = InferInputShardingForTopK(ins, output_sharding); } else if (ins->opcode() == HloOpcode::kCustomCall) { is_sharding_default_replicated = true; cur_input_sharding = HloSharding::Replicate(); @@ -206,27 +259,34 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( is_sharding_default_replicated) { VLOG(2) << "Zeroing out operand 0 resharding costs for gather sharding " << output_sharding.ToString(); - resharding_costs.push_back( + communication_resharding_costs.push_back( + std::vector(operand_strategies->strategies.size(), 0)); + memory_resharding_costs.push_back( std::vector(operand_strategies->strategies.size(), 0)); input_shardings[k] = std::nullopt; } else { - resharding_costs.push_back( - ReshardingCostVector(operand_strategies, ins->operand(k)->shape(), - *cur_input_sharding, cluster_env)); + communication_resharding_costs.push_back( + CommunicationReshardingCostVector( + operand_strategies, ins->operand(k)->shape(), + *cur_input_sharding, cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( + operand_strategies, ins->operand(k)->shape(), *cur_input_sharding, + cluster_env)); } } } - return resharding_costs; + return std::make_pair(communication_resharding_costs, + memory_resharding_costs); } -std::pair>, - std::vector>> +std::tuple>> GenerateReshardingCostsAndShardingsForAllOperands( const HloInstruction* ins, const HloSharding& output_sharding, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const CallGraph& call_graph) { std::vector> input_shardings_optional; - auto resharding_costs = + std::pair resharding_costs = GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, output_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); @@ -234,30 +294,35 @@ GenerateReshardingCostsAndShardingsForAllOperands( CHECK(sharding_optional.has_value()); } - return {resharding_costs, input_shardings_optional}; + return {resharding_costs.first, resharding_costs.second, + input_shardings_optional}; } // When computing resharding costs for inputs, this function assumes that the -// shape of the input is the same as the shape of the output (ie. the `shape` -// operand to the function) +// shape of the input is the same as the shape of the output (i.e., the `shape` +// operand to the function). void FollowArrayOrTokenStrategyGroup( const StrategyGroup& src_strategy_group, const Shape& shape, - size_t instruction_id, bool have_memory_cost, + const size_t instruction_id, const bool have_memory_cost, const ClusterEnvironment& cluster_env, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map, StrategyGroup& strategy_group) { CHECK(shape.IsArray() || shape.IsToken()); + std::vector pretrimmed_strategies; // Only follows the given strategy when there is no other strategy to be // restored. - if (!pretrimmed_strategy_map.contains(src_strategy_group.node_idx)) { + auto pretrimmed_strategy_map_it = + pretrimmed_strategy_map.find(src_strategy_group.node_idx); + if (pretrimmed_strategy_map_it != pretrimmed_strategy_map.end()) { + pretrimmed_strategies = pretrimmed_strategy_map_it->second; + } else { strategy_group.following = &src_strategy_group; } + strategy_group.strategies.reserve(src_strategy_group.strategies.size()); // Creates the sharding strategies and restores trimmed strategies, if any. - std::vector& pretrimmed_strategies = - pretrimmed_strategy_map[src_strategy_group.node_idx]; for (int64_t sid = 0; sid < src_strategy_group.strategies.size() + pretrimmed_strategies.size(); ++sid) { @@ -271,30 +336,35 @@ void FollowArrayOrTokenStrategyGroup( VLOG(1) << "Adding outspec from the trimmed strategy map: " << output_spec->ToString(); } - std::string name = ToStringSimple(*output_spec); + const std::string name = ToStringSimple(*output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = have_memory_cost ? GetBytes(shape) / output_spec->NumTiles() : 0; size_t num_in_nodes = strategy_group.in_nodes.size(); std::vector> input_shardings(num_in_nodes, *output_spec); - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) { - resharding_costs.push_back(ReshardingCostVector( + communication_resharding_costs.push_back( + CommunicationReshardingCostVector(strategy_group.in_nodes[i], shape, + *output_spec, cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( strategy_group.in_nodes[i], shape, *output_spec, cluster_env)); } strategy_group.strategies.push_back( ShardingStrategy({name, *output_spec, compute_cost, communication_cost, - memory_cost, resharding_costs, input_shardings})); + memory_cost, communication_resharding_costs, + memory_resharding_costs, input_shardings})); } } std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, - size_t instruction_id, bool have_memory_cost, + const size_t instruction_id, const bool have_memory_cost, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map) { std::unique_ptr strategy_group; if (src_strategy_group->is_tuple) { @@ -321,12 +391,12 @@ std::unique_ptr MaybeFollowInsStrategyGroup( return strategy_group; } -StatusOr> FollowReduceStrategy( +absl::StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, const HloInstruction* operand, const HloInstruction* unit, - size_t instruction_id, StrategyMap& strategy_map, + const size_t instruction_id, StrategyMap& strategy_map, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - bool allow_mixed_mesh_shape, bool crash_at_error) { + const bool allow_mixed_mesh_shape, const bool crash_at_error) { std::unique_ptr strategy_group; if (output_shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); @@ -352,8 +422,8 @@ StatusOr> FollowReduceStrategy( strategy_group->following = src_strategy_group; strategy_group->strategies.reserve(src_strategy_group->strategies.size()); // Map operand dims to inst dim - // Example: f32[1,16]{1,0} reduce(f32[1,16,4096]{2,1,0} %param0, f32[] - // %param1), dimensions={2} + // Example: f32[1,16]{1,0} reduce(f32[1,16,4096]{2,1,0} %param0, + // f32[] %param1), dimensions={2} // op_dim_to_output_dim = [0, 1, -1] std::vector op_dim_to_output_dim = GetDimensionMapping(/*reduced_dimensions=*/ins->dimensions(), @@ -389,51 +459,59 @@ StatusOr> FollowReduceStrategy( std::unique_ptr unit_clone = unit->Clone(); // Creates a new reduce op with one output, which is easier to use // GetShardingFromUser() to get the input sharding. - auto new_reduce = HloInstruction::CreateReduce( + std::unique_ptr new_reduce = HloInstruction::CreateReduce( output_shape, operand_clone.get(), unit_clone.get(), ins->dimensions(), ins->to_apply()); operand_clone->set_sharding( src_strategy_group->strategies[sid].output_sharding); - auto s = new_reduce->ReplaceOperandWith(0, operand_clone.get()); + absl::Status s = new_reduce->ReplaceOperandWith(0, operand_clone.get()); if (!s.ok()) { continue; } - ShardingPropagation::ComputationMap computation_map; - bool changed = - InferReduceShardingFromOperand(new_reduce.get(), false, true); - CHECK(changed); + CHECK(InferReduceShardingFromOperand(new_reduce.get(), false, true)); HloSharding output_spec = new_reduce->sharding(); new_reduce.reset(); operand_clone.reset(); unit_clone.reset(); - std::string name = ToStringSimple(output_spec); + const std::string name = ToStringSimple(output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(output_shape) / output_spec.NumTiles(); - for (auto mesh_dim : all_reduce_dims) { + for (int64_t mesh_dim : all_reduce_dims) { communication_cost += cluster_env.AllReduceCost(memory_cost, mesh_dim); } - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; for (int64_t k = 0; k < ins->operand_count(); ++k) { - auto cur_operand = ins->operand(k); + const HloInstruction* cur_operand = ins->operand(k); if (ToString(cur_operand->shape().dimensions()) == ToString(operand->shape().dimensions())) { - auto operand_strategies = strategy_map.at(cur_operand).get(); - resharding_costs.push_back(ReshardingCostVector( - operand_strategies, output_shape, input_sharding, cluster_env)); + const StrategyGroup* operand_strategies = + strategy_map.at(cur_operand).get(); + communication_resharding_costs.push_back( + CommunicationReshardingCostVector(operand_strategies, + cur_operand->shape(), + input_sharding, cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( + operand_strategies, cur_operand->shape(), input_sharding, + cluster_env)); } else { - resharding_costs.push_back(std::vector( + communication_resharding_costs.push_back(std::vector( + strategy_map.at(cur_operand)->strategies.size(), 0.0)); + memory_resharding_costs.push_back(std::vector( strategy_map.at(cur_operand)->strategies.size(), 0.0)); } } - const ShardingStrategy strategy = ShardingStrategy({name, - output_spec, - compute_cost, - communication_cost, - memory_cost, - resharding_costs, - {input_sharding}}); + const ShardingStrategy strategy = + ShardingStrategy({name, + output_spec, + compute_cost, + communication_cost, + memory_cost, + communication_resharding_costs, + memory_resharding_costs, + {input_sharding}}); strategy_group->strategies.push_back(strategy); } } else { @@ -453,14 +531,15 @@ std::vector FindReplicateStrategyIndices( return indices; } -std::pair>, - std::vector>> +std::tuple>> ReshardingCostsForTupleOperand(const HloInstruction* operand, StrategyGroup* operand_strategy_vector) { // TODO(yuemmawang) Support instructions with more than one tuple operand. // Creates resharding costs such that favors when operand strategies are // replicated. - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; std::vector tuple_element_shardings; for (size_t tuple_element_idx = 0; tuple_element_idx < operand->shape().tuple_shapes_size(); @@ -473,21 +552,23 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand, << "There is no replicated strategy in instruction " << operand->ToString() << ".\nStrategies:\n" << tuple_element_strategies->ToString(); - resharding_costs.push_back(std::vector( + memory_resharding_costs.push_back( + std::vector(tuple_element_strategies->strategies.size(), 0)); + communication_resharding_costs.push_back(std::vector( tuple_element_strategies->strategies.size(), kInfinityCost)); tuple_element_shardings.push_back(HloSharding::Replicate()); for (const size_t i : indices) { - resharding_costs.back().at(i) = 0.0; + communication_resharding_costs.back().at(i) = 0.0; } } - return {resharding_costs, + return {communication_resharding_costs, memory_resharding_costs, std::vector>( {HloSharding::Tuple(operand->shape(), tuple_element_shardings)})}; } -std::vector> CreateZeroReshardingCostsForAllOperands( +ReshardingCosts CreateZeroReshardingCostsForAllOperands( const HloInstruction* ins, const StrategyMap& strategy_map) { - std::vector> resharding_costs; + ReshardingCosts resharding_costs; for (size_t i = 0; i < ins->operand_count(); ++i) { auto operand = ins->operand(i); const auto& operand_strategies = strategy_map.at(operand); @@ -520,12 +601,13 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, - double replicated_penalty) { + const double replicated_penalty) { HloSharding output_spec = HloSharding::Replicate(); - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; std::vector> input_shardings; - int tuple_size = ins->operand(0)->shape().tuple_shapes_size(); + const int tuple_size = ins->operand(0)->shape().tuple_shapes_size(); if (ins->has_sharding()) { std::vector operand_shapes(ins->operand_count()); for (int i = 0; i < ins->operand_count(); ++i) { @@ -547,7 +629,12 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, for (size_t i = 0; i < tuple_size; ++i) { auto input_sharding = get_input_sharding(i); input_shardings.push_back(input_sharding); - resharding_costs.push_back(ReshardingCostVector( + communication_resharding_costs.push_back( + CommunicationReshardingCostVector( + strategy_map.at(ins->operand(0))->childs[i].get(), + ins->operand(0)->shape().tuple_shapes(i), input_sharding, + cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( strategy_map.at(ins->operand(0))->childs[i].get(), ins->operand(0)->shape().tuple_shapes(i), input_sharding, cluster_env)); @@ -556,16 +643,21 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, input_shardings.push_back(input_sharding); } else { for (size_t i = 0; i < tuple_size; ++i) { - resharding_costs.push_back(std::vector( + communication_resharding_costs.push_back(std::vector( + strategy_map.at(ins->operand(0))->childs[i].get()->strategies.size(), + 0)); + memory_resharding_costs.push_back(std::vector( strategy_map.at(ins->operand(0))->childs[i].get()->strategies.size(), 0)); } } - resharding_costs.push_back({}); + communication_resharding_costs.push_back({}); + memory_resharding_costs.push_back({}); double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - strategy_group->strategies.push_back(ShardingStrategy( - {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, - std::move(resharding_costs), input_shardings})); + strategy_group->strategies.push_back( + ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0, + memory_cost, std::move(communication_resharding_costs), + std::move(memory_resharding_costs), input_shardings})); } double ComputeCommunicationCost( @@ -583,7 +675,7 @@ double ComputeCommunicationCost( // As seen in the test // SpmdPartitioningTest.GatherPartitionedOnTrivialSliceDims (in file // third_party/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc), - // when the gather op is replicated, and the first operand sharded, we + // when the gather op is replicated and the first operand sharded, we // need an AllReduce to implement the gather op. We capture that cost // here. // TODO(pratikf) Model gather communication costs in a more principled @@ -609,8 +701,9 @@ double ComputeCommunicationCost( void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, double replicated_penalty, - absl::flat_hash_set operands_to_consider_all_strategies_for = {}) { + std::unique_ptr& strategy_group, + const double replicated_penalty, + absl::flat_hash_set operands_to_consider_all_strategies_for) { HloSharding replicated_strategy = HloSharding::Replicate(); HloSharding output_spec = replicated_strategy; double memory_cost = GetBytes(shape) / output_spec.NumTiles(); @@ -626,9 +719,12 @@ void AddReplicatedStrategy( possible_input_shardings( operand_strategies_to_consider->strategies.size(), std::vector>(ins->operand_count())); - std::vector>> possible_resharding_costs( + std::vector possible_communication_resharding_costs( + operand_strategies_to_consider->strategies.size(), + ReshardingCosts(ins->operand_count())); + std::vector possible_memory_resharding_costs( operand_strategies_to_consider->strategies.size(), - std::vector>(ins->operand_count())); + ReshardingCosts(ins->operand_count())); for (int64_t k = 0; k < ins->operand_count(); ++k) { CHECK(!ins->operand(k)->shape().IsTuple()); @@ -638,7 +734,13 @@ void AddReplicatedStrategy( for (size_t j = 0; j < possible_input_shardings.size(); ++j) { possible_input_shardings[j][k] = operand_strategies_to_consider->strategies[j].output_sharding; - possible_resharding_costs[j][k] = ReshardingCostVector( + possible_communication_resharding_costs[j][k] = + CommunicationReshardingCostVector( + strategy_map.at(ins->operand(k)).get(), + ins->operand(k)->shape(), + operand_strategies_to_consider->strategies[j].output_sharding, + cluster_env); + possible_memory_resharding_costs[j][k] = MemoryReshardingCostVector( strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(), operand_strategies_to_consider->strategies[j].output_sharding, cluster_env); @@ -646,7 +748,11 @@ void AddReplicatedStrategy( } else { for (size_t j = 0; j < possible_input_shardings.size(); ++j) { possible_input_shardings[j][k] = replicated_strategy; - possible_resharding_costs[j][k] = ReshardingCostVector( + possible_communication_resharding_costs[j][k] = + CommunicationReshardingCostVector( + strategy_map.at(ins->operand(k)).get(), + ins->operand(k)->shape(), replicated_strategy, cluster_env); + possible_memory_resharding_costs[j][k] = MemoryReshardingCostVector( strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(), replicated_strategy, cluster_env); } @@ -658,11 +764,13 @@ void AddReplicatedStrategy( ins, possible_input_shardings[j], cluster_env); strategy_group->strategies.push_back(ShardingStrategy( {"R", replicated_strategy, replicated_penalty, communication_cost, - memory_cost, std::move(possible_resharding_costs[j]), + memory_cost, std::move(possible_communication_resharding_costs[j]), + std::move(possible_memory_resharding_costs[j]), std::move(possible_input_shardings[j])})); } } else { - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; std::vector> input_shardings; if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) { @@ -670,17 +778,24 @@ void AddReplicatedStrategy( << "Do not support instructions with more than one tuple " "operand. If this CHECK fails, we will need to fix " "b/233412625."; - std::tie(resharding_costs, input_shardings) = + std::tie(communication_resharding_costs, memory_resharding_costs, + input_shardings) = ReshardingCostsForTupleOperand( ins->operand(0), strategy_map.at(ins->operand(0)).get()); } else { for (int64_t k = 0; k < ins->operand_count(); ++k) { auto operand = ins->operand(k); if (ins->opcode() == HloOpcode::kConditional) { - resharding_costs.push_back(std::vector( + communication_resharding_costs.push_back(std::vector( + strategy_map.at(operand)->strategies.size(), 0)); + memory_resharding_costs.push_back(std::vector( strategy_map.at(operand)->strategies.size(), 0)); } else { - resharding_costs.push_back(ReshardingCostVector( + communication_resharding_costs.push_back( + CommunicationReshardingCostVector(strategy_map.at(operand).get(), + ins->operand(k)->shape(), + output_spec, cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( strategy_map.at(operand).get(), ins->operand(k)->shape(), output_spec, cluster_env)); input_shardings.push_back(output_spec); @@ -689,15 +804,16 @@ void AddReplicatedStrategy( } strategy_group->strategies.push_back(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, - std::move(resharding_costs), input_shardings})); + std::move(communication_resharding_costs), + std::move(memory_resharding_costs), input_shardings})); } } // TODO(pratikf) Communication costs for sort HLO ops. This is currently a // placeholder approximation and should be improved. -double ComputeSortCommunicationCost(int64_t sort_dim, - int64_t operand_sharded_dim, - int64_t mesh_sharding_dim, +double ComputeSortCommunicationCost(const int64_t sort_dim, + const int64_t operand_sharded_dim, + const int64_t mesh_sharding_dim, const Shape& shape, const ClusterEnvironment& cluster_env) { if (sort_dim == operand_sharded_dim) { @@ -712,7 +828,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, - bool only_allow_divisible, + const bool only_allow_divisible, const std::string& suffix, const CallGraph& call_graph) { for (int64_t i = 0; i < shape.rank(); ++i) { @@ -723,34 +839,39 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, continue; } - std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix; + const std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix; HloSharding output_spec = Tile(shape, {i}, {j}, device_mesh); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; std::vector> input_shardings; if (ins->opcode() == HloOpcode::kConditional) { // TODO(pratikf): Compute input_shardings for kConditional ops - resharding_costs = + communication_resharding_costs = + CreateZeroReshardingCostsForAllOperands(ins, strategy_map); + memory_resharding_costs = CreateZeroReshardingCostsForAllOperands(ins, strategy_map); } else if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) { CHECK_EQ(ins->operand_count(), 1) << "Do not support instructions with more than one tuple " "operand."; - std::tie(resharding_costs, input_shardings) = + std::tie(communication_resharding_costs, memory_resharding_costs, + input_shardings) = ReshardingCostsForTupleOperand( ins->operand(0), strategy_map.at(ins->operand(0)).get()); } else if (ins->opcode() == HloOpcode::kRngBitGenerator && ins->operand(0)->shape().IsArray()) { input_shardings.push_back(HloSharding::Replicate()); - resharding_costs = + std::tie(communication_resharding_costs, memory_resharding_costs) = GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, output_spec, strategy_map, cluster_env, call_graph, input_shardings); } else { - std::tie(resharding_costs, input_shardings) = + std::tie(communication_resharding_costs, memory_resharding_costs, + input_shardings) = GenerateReshardingCostsAndShardingsForAllOperands( ins, output_spec, strategy_map, cluster_env, call_graph); } @@ -768,7 +889,8 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, } strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_costs), input_shardings})); + std::move(communication_resharding_costs), + std::move(memory_resharding_costs), input_shardings})); } } } @@ -781,17 +903,16 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const CallGraph& call_graph, absl::Span tensor_dims); -// Enumerate all partitions recursively void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, const InstructionBatchDimMap& batch_dim_map, - bool only_allow_divisible, + const bool only_allow_divisible, const CallGraph& call_graph, - int64_t partition_dimensions, - const std::vector& tensor_dims = {}) { + const int64_t partition_dimensions, + const std::vector& tensor_dims) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForOp(ins, shape, device_mesh, cluster_env, @@ -825,7 +946,6 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, } } -// Builds the strategy + cost for the given tensor_dims & mesh_dims. void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, @@ -835,60 +955,60 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, absl::Span tensor_dims) { std::vector mesh_dims(tensor_dims.size()); std::iota(mesh_dims.begin(), mesh_dims.end(), 0); - std::string name = + const std::string name = absl::StrFormat("S{%s} @ {%s}", absl::StrJoin(tensor_dims, ","), absl::StrJoin(mesh_dims, ",")); HloSharding output_spec = Tile(shape, tensor_dims, mesh_dims, device_mesh); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(shape) / output_spec.NumTiles(); std::vector> input_shardings; - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; if (ins->opcode() == HloOpcode::kConditional) { // TODO(pratikf): Compute input_shardings for kConditional ops - resharding_costs = + communication_resharding_costs = + CreateZeroReshardingCostsForAllOperands(ins, strategy_map); + memory_resharding_costs = CreateZeroReshardingCostsForAllOperands(ins, strategy_map); } else if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) { CHECK_EQ(ins->operand_count(), 1) << "Do not support instructions with more than one tuple " "operand. If this CHECK fails, we will need to fix " "b/233412625."; - std::tie(resharding_costs, input_shardings) = + std::tie(communication_resharding_costs, memory_resharding_costs, + input_shardings) = ReshardingCostsForTupleOperand(ins->operand(0), strategy_map.at(ins->operand(0)).get()); } else { - std::tie(resharding_costs, input_shardings) = + std::tie(communication_resharding_costs, memory_resharding_costs, + input_shardings) = GenerateReshardingCostsAndShardingsForAllOperands( ins, output_spec, strategy_map, cluster_env, call_graph); } // TODO(pratikf) Communication costs for sort HLO ops. This is currently a // placeholder approximation and should be improved. + int64_t sort_or_topk_dim = -1; if (ins->opcode() == HloOpcode::kSort) { auto sort_ins = xla::DynCast(ins); CHECK(sort_ins); - for (int64_t dim = 0; dim < tensor_dims.size(); ++dim) { - if (sort_ins->sort_dimension() == tensor_dims[dim]) { - communication_cost = ComputeSortCommunicationCost( - sort_ins->sort_dimension(), tensor_dims[dim], dim, shape, - cluster_env); - break; - } - } + sort_or_topk_dim = sort_ins->sort_dimension(); } else if (IsTopKCustomCall(ins)) { - auto topk_dim = ins->operand(0)->shape().rank() - 1; - for (int64_t dim = 0; dim < tensor_dims.size(); ++dim) { - if (topk_dim == tensor_dims[dim]) { - communication_cost = ComputeSortCommunicationCost( - topk_dim, tensor_dims[dim], dim, shape, cluster_env); - break; - } + sort_or_topk_dim = ins->operand(0)->shape().rank() - 1; + } + + if (sort_or_topk_dim != -1) { + if (auto index = GetIndex(tensor_dims, sort_or_topk_dim); index != -1) { + communication_cost = ComputeSortCommunicationCost( + sort_or_topk_dim, sort_or_topk_dim, index, shape, cluster_env); } } - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_costs), input_shardings})); + + strategy_group->strategies.push_back( + ShardingStrategy({name, output_spec, compute_cost, communication_cost, + memory_cost, std::move(communication_resharding_costs), + std::move(memory_resharding_costs), input_shardings})); } -// Enumerate all 1d partition strategies for reshape. void EnumerateAll1DPartitionReshape( const HloInstruction* ins, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, @@ -918,20 +1038,25 @@ void EnumerateAll1DPartitionReshape( continue; } - std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix; + const std::string name = absl::StrFormat("S%d @ %d", i, j) + suffix; double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - std::vector> resharding_costs{ - ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), - *input_spec, cluster_env)}; + ReshardingCosts communication_resharding_costs{ + CommunicationReshardingCostVector(strategy_map.at(operand).get(), + operand->shape(), *input_spec, + cluster_env)}; + ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( + strategy_map.at(operand).get(), operand->shape(), *input_spec, + cluster_env)}; strategy_group->strategies.push_back( ShardingStrategy({name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_costs), + std::move(communication_resharding_costs), + std::move(memory_resharding_costs), {*input_spec}})); } } @@ -950,8 +1075,8 @@ void EnumeratePartitionReshape(const HloInstruction* ins, const StrategyMap& strategy_map, const InstructionBatchDimMap& batch_dim_map, std::unique_ptr& strategy_group, - bool only_allow_divisible, - int64_t partition_dimensions, + const bool only_allow_divisible, + const int64_t partition_dimensions, const std::vector& tensor_dims = {}) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { @@ -1010,16 +1135,21 @@ void BuildStrategyAndCostForReshape( double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - std::vector> resharding_costs{ - ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), - *input_spec, cluster_env)}; + ReshardingCosts communication_resharding_costs{ + CommunicationReshardingCostVector(strategy_map.at(operand).get(), + operand->shape(), *input_spec, + cluster_env)}; + ReshardingCosts memory_resharding_costs{ + MemoryReshardingCostVector(strategy_map.at(operand).get(), + operand->shape(), *input_spec, cluster_env)}; strategy_group->strategies.push_back( ShardingStrategy({name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_costs), + std::move(communication_resharding_costs), + std::move(memory_resharding_costs), {*input_spec}})); } @@ -1037,13 +1167,11 @@ int64_t MaxNumTiles(const StrategyMap& strategy_map, std::max(max_num_tiles, strategy_group->strategies[i].output_sharding.NumTiles()); } - return max_num_tiles; } -// Choose an operand to follow. We choose to follow the operand with the -// highest priority. The priority is defined as a function of two entities as -// below: +// Choose an operand to follow. We choose to follow the operand with the highest +// priority. The priority is defined as a function of two entities as below: // // priority(operand) = // max(x.output_spec.num_tiles for x in operand.strategies) + @@ -1061,7 +1189,8 @@ int64_t MaxNumTiles(const StrategyMap& strategy_map, // one to follow. std::pair ChooseOperandToFollow( const StrategyMap& strategy_map, const InstructionDepthMap& depth_map, - const AliasMap& alias_map, int64_t max_depth, const HloInstruction* ins) { + const AliasMap& alias_map, const int64_t max_depth, + const HloInstruction* ins) { // If an alias constraint is set, always follow its alias source. auto it = alias_map.find(ins); if (it != alias_map.end()) { @@ -1092,12 +1221,11 @@ std::pair ChooseOperandToFollow( } } CHECK(follow_idx.has_value()); - return {*follow_idx, tie}; } -// Return whether an instruciton can follow one of its operand when -// more than one operand have the same priority. +// Return whether an instruction can follow one of its operand when more than +// one operand have the same priority. // Consider adding special cases here if the auto sharding following strategy // behaves weird for your model. bool AllowTieFollowing(const HloInstruction* ins) { @@ -1142,13 +1270,65 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim( } } -StatusOr> CreateAllStrategiesGroup( - const HloInstruction* ins, const Shape& shape, size_t instruction_id, +void FillAllStrategiesForArray( + std::unique_ptr& strategy_group, const HloInstruction* ins, + const Shape& shape, const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, const AutoShardingOption& option, + const double replicated_penalty, + const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + const bool only_allow_divisible, const bool create_replicated_strategies, + const bool create_partially_replicated_strategies) { + if (create_partially_replicated_strategies || cluster_env.IsDeviceMesh1D()) { + EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, + strategy_map, strategy_group, only_allow_divisible, + "", call_graph); + } + // Split 2 dims + if (cluster_env.IsDeviceMesh2D()) { + EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, + strategy_map, strategy_group, batch_dim_map, + only_allow_divisible, call_graph, /*partitions*/ 2); + } + // Split 3 dims + if (cluster_env.IsDeviceMesh3D()) { + EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, + strategy_map, strategy_group, batch_dim_map, + only_allow_divisible, call_graph, /*partitions*/ 3); + } + + if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { + // Set penalty for 1d partial tiled layout + for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { + strategy_group->strategies[i].compute_cost += replicated_penalty * 0.8; + } + + // Split 1 dim, but for 1d mesh + EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_1d_, + cluster_env, strategy_map, strategy_group, + only_allow_divisible, " 1d", call_graph); + } + if (create_replicated_strategies || strategy_group->strategies.empty()) { + AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, + replicated_penalty); + } + + // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies + // and only keep the data parallel strategies. + if (option.force_batch_dim_to_mesh_dim >= 0 && + batch_dim_map.contains(GetBatchDimMapKey(ins))) { + CHECK_OK(FilterStrategy(ins, shape, strategy_group, cluster_env, + batch_dim_map, option)); + } +} + +absl::StatusOr> CreateAllStrategiesGroup( + const HloInstruction* ins, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, - double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies) { + const double replicated_penalty, + const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + const bool only_allow_divisible, const bool create_replicated_strategies, + const bool create_partially_replicated_strategies) { std::unique_ptr strategy_group; if (shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); @@ -1159,7 +1339,8 @@ StatusOr> CreateAllStrategiesGroup( strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, - create_replicated_strategies) + create_replicated_strategies, + create_partially_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); @@ -1167,45 +1348,11 @@ StatusOr> CreateAllStrategiesGroup( } else if (shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, only_allow_divisible, - "", call_graph); - // Split 2 dims - if (cluster_env.IsDeviceMesh2D()) { - EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*partitions*/ 2); - } - // Split 3 dims - if (cluster_env.IsDeviceMesh3D()) { - EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*partitions*/ 3); - } - - if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { - // Set penalty for 1d partial tiled layout - for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { - strategy_group->strategies[i].compute_cost += replicated_penalty * 0.8; - } - - // Split 1 dim, but for 1d mesh - EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_1d_, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, " 1d", call_graph); - } - if (create_replicated_strategies || strategy_group->strategies.empty()) { - AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, - strategy_group, replicated_penalty); - } - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option.force_batch_dim_to_mesh_dim >= 0 && - batch_dim_map.contains(GetBatchDimMapKey(ins))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins, shape, strategy_group, cluster_env, - batch_dim_map, option)); - } + FillAllStrategiesForArray( + strategy_group, ins, shape, cluster_env, strategy_map, option, + replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, + create_replicated_strategies, create_partially_replicated_strategies); } else if (shape.IsToken()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); @@ -1261,11 +1408,11 @@ bool ShardingIsConsistent(const HloSharding& partial_sharding, void TrimOrGenerateStrategiesBasedOnExistingSharding( const Shape& output_shape, StrategyGroup* strategy_group, const StrategyMap& strategy_map, - const std::vector instructions, + const std::vector& instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableHashMap>& pretrimmed_strategy_map, - const CallGraph& call_graph, bool strict) { + const CallGraph& call_graph, const bool strict) { if (strategy_group->is_tuple) { for (size_t i = 0; i < strategy_group->childs.size(); ++i) { TrimOrGenerateStrategiesBasedOnExistingSharding( @@ -1274,6 +1421,9 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( cluster_env, pretrimmed_strategy_map, call_graph, strict); } } else { + if (existing_sharding.IsUnknown()) { + return; + } if (ShardingIsComplete(existing_sharding, cluster_env.device_mesh_.num_elements())) { // Sharding provided by XLA users, we need to keep them. @@ -1297,31 +1447,48 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( } else { VLOG(1) << "Generate a new strategy based on user sharding."; std::string name = ToStringSimple(existing_sharding); - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; std::vector> input_shardings; if (!strategy_group->in_nodes.empty()) { HloInstruction* ins = instructions.at(strategy_group->instruction_id); for (size_t i = 0; i < strategy_group->in_nodes.size(); i++) { HloInstruction* operand = instructions.at(strategy_group->in_nodes.at(i)->instruction_id); - std::optional input_sharding_or = - ShardingPropagation::GetShardingFromUser(*operand, *ins, 10, - true, call_graph); - if (input_sharding_or.has_value()) { - input_shardings.push_back(input_sharding_or.value()); - } - + std::optional input_sharding = + ShardingPropagation::GetShardingFromUser( + *operand, *ins, 10, true, call_graph, + /*sharding_helper=*/nullptr); StrategyGroup* operand_strategy_group = strategy_map.at(operand).get(); Shape operand_shape = operand->shape(); if (ins->opcode() == HloOpcode::kGetTupleElement) { + if (input_sharding && input_sharding->IsTuple()) { + input_sharding = input_sharding->GetSubSharding( + operand->shape(), {ins->tuple_index()}); + } operand_strategy_group = operand_strategy_group->childs[ins->tuple_index()].get(); - operand_shape = operand_shape.tuple_shapes(ins->tuple_index()); + operand_shape = operand->shape().tuple_shapes(ins->tuple_index()); } - resharding_costs.push_back( - ReshardingCostVector(operand_strategy_group, operand_shape, - existing_sharding, cluster_env)); + + if (input_sharding.has_value()) { + input_sharding = *input_sharding; + } else if (existing_sharding.Validate(operand_shape).ok()) { + input_sharding = existing_sharding; + } else { + input_sharding = HloSharding::Replicate(); + } + CHECK(input_sharding.has_value()); + + input_shardings.push_back(*input_sharding); + communication_resharding_costs.push_back( + CommunicationReshardingCostVector( + operand_strategy_group, operand_shape, *input_sharding, + cluster_env)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( + operand_strategy_group, operand_shape, *input_sharding, + cluster_env)); } } double memory_cost = @@ -1333,18 +1500,19 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( strategy_group->strategies.clear(); strategy_group->strategies.push_back( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, - resharding_costs, input_shardings})); + communication_resharding_costs, + memory_resharding_costs, input_shardings})); } // If there is only one option for resharding, and the cost computed for // that option is kInfinityCost, set the cost to zero. This is okay // because there is only one option anyway, and having the costs set to // kInfinityCost is problematic for the solver. if (strategy_group->strategies.size() == 1) { - for (auto& operand_resharding_costs : - strategy_group->strategies[0].resharding_costs) { - if (operand_resharding_costs.size() == 1 && - operand_resharding_costs[0] >= kInfinityCost) { - operand_resharding_costs[0] = 0; + for (auto& operand_communication_resharding_costs : + strategy_group->strategies[0].communication_resharding_costs) { + if (operand_communication_resharding_costs.size() == 1 && + operand_communication_resharding_costs[0] >= kInfinityCost) { + operand_communication_resharding_costs[0] = 0; } } } @@ -1404,9 +1572,9 @@ void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { } } -void RemoveInvalidShardingsWithShapes(const Shape& shape, - StrategyGroup* strategy_group, - bool instruction_has_user_sharding) { +void RemoveInvalidShardingsWithShapes( + const Shape& shape, StrategyGroup* strategy_group, + const bool instruction_has_user_sharding) { if (strategy_group->is_tuple) { for (size_t i = 0; i < strategy_group->childs.size(); i++) { RemoveInvalidShardingsWithShapes(shape.tuple_shapes().at(i), @@ -1461,21 +1629,23 @@ void CheckReshardingCostsShape(StrategyGroup* strategy_group) { if (strategy_group->in_nodes.size() == 1 && strategy_group->in_nodes.at(0)->is_tuple) { // This is when current instruction's only operand is tuple, and the - // first dimension of resharding_costs should equal its number of - // tuple elements. - CHECK_EQ(strategy.resharding_costs.size(), + // first dimension of communication_resharding_costs should equal its + // number of tuple elements. + CHECK_EQ(strategy.communication_resharding_costs.size(), strategy_group->in_nodes.at(0)->childs.size()) << "Instruction ID: " << strategy_group->instruction_id << "\n" << strategy_group->ToString(); } else { - // The rest of the time, the first dimension of resharding_costs - // should equal its number of operands (in_nodes). - CHECK_EQ(strategy.resharding_costs.size(), + // The rest of the time, the first dimension of + // communication_resharding_costs should equal its number of operands + // (in_nodes). + CHECK_EQ(strategy.communication_resharding_costs.size(), strategy_group->in_nodes.size()) << "Instruction ID: " << strategy_group->instruction_id << "\n" << strategy_group->ToString(); } - for (size_t i = 0; i < strategy.resharding_costs.size(); i++) { + for (size_t i = 0; i < strategy.communication_resharding_costs.size(); + i++) { size_t to_compare; if (strategy_group->in_nodes.size() == 1 && strategy_group->in_nodes.at(0)->is_tuple) { @@ -1486,8 +1656,8 @@ void CheckReshardingCostsShape(StrategyGroup* strategy_group) { } else { to_compare = strategy_group->in_nodes.at(i)->strategies.size(); } - CHECK_EQ(strategy.resharding_costs[i].size(), to_compare) - << "\nIndex of resharding_costs: " << i + CHECK_EQ(strategy.communication_resharding_costs[i].size(), to_compare) + << "\nIndex of communication_resharding_costs: " << i << "\nInstruction ID: " << strategy_group->instruction_id << "\nCurrent strategies:\n" << strategy_group->ToString(); @@ -1496,17 +1666,8 @@ void CheckReshardingCostsShape(StrategyGroup* strategy_group) { } } -bool LeafVectorsAreConsistent(const std::vector& one, - const std::vector& two, - bool is_reshape) { - if (one.size() != two.size()) { - return false; - } - return true; -} - void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, - int64_t execution_count) { + const int64_t execution_count) { if (strategy_group->is_tuple) { for (size_t i = 0; i < strategy_group->childs.size(); ++i) { ScaleCostsWithExecutionCounts(strategy_group->childs[i].get(), @@ -1516,29 +1677,29 @@ void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, for (auto& strategy : strategy_group->strategies) { strategy.compute_cost *= execution_count; strategy.communication_cost *= execution_count; - for (auto i = 0; i < strategy.resharding_costs.size(); ++i) { - for (auto j = 0; j < strategy.resharding_costs[i].size(); ++j) { - strategy.resharding_costs[i][j] *= execution_count; + for (auto i = 0; i < strategy.communication_resharding_costs.size(); + ++i) { + for (auto j = 0; j < strategy.communication_resharding_costs[i].size(); + ++j) { + strategy.communication_resharding_costs[i][j] *= execution_count; } } } } } -// Enumerates sharding strategies for elementwise operators by following -// strategies of an operand of the elementwise op. std::unique_ptr CreateElementwiseOperatorStrategies( - size_t instruction_id, const HloInstruction* ins, + const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - StableHashMap>& + const StableHashMap>& pretrimmed_strategy_map, - int64_t max_depth, StrategyGroups& strategy_groups, + const int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs) { std::unique_ptr strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, strategy_groups); - // Choose an operand to follow + // Choose an operand to follow. int64_t follow_idx; bool tie; std::tie(follow_idx, tie) = @@ -1550,7 +1711,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( strategy_group->following = nullptr; } - // Get all possible sharding specs from operands + // Get all possible sharding specs from operands. for (int64_t i = 0; i < ins->operand_count(); ++i) { if (strategy_group->following != nullptr && i != follow_idx) { // If ins follows one operand, do not consider sharding specs from @@ -1585,23 +1746,22 @@ std::unique_ptr CreateElementwiseOperatorStrategies( return strategy_group; } -// Enumerates sharding strategies for reshape operators. The function does so by -// essentially reshaping the sharding of the operand in a manner similar to the -// tensor reshape itself. std::unique_ptr CreateReshapeStrategies( - size_t instruction_id, const HloInstruction* ins, + const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, - bool only_allow_divisible, double replicated_penalty, + const bool only_allow_divisible, const double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const AutoShardingOption& option, StrategyGroups& strategy_groups) { - std::unique_ptr strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, strategy_groups); - const HloInstruction* operand = ins->operand(0); + const AutoShardingOption& option, StrategyGroups& strategy_groups, + const CallGraph& call_graph) { const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); + std::unique_ptr strategy_group = CreateLeafStrategyGroup( + instruction_id, ins, strategy_map, strategy_groups); + if (mesh_nn_dims < 2 || !option.allow_mixed_mesh_shape) { + const HloInstruction* operand = ins->operand(0); + // Create follow strategies const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); CHECK(!src_strategy_group->is_tuple); @@ -1624,10 +1784,14 @@ std::unique_ptr CreateReshapeStrategies( if (!TileAssignmentMatchesMesh(*output_spec, device_mesh)) { continue; } - std::string name = ToStringSimple(*output_spec); + const std::string name = ToStringSimple(*output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); - std::vector resharding_costs = ReshardingCostVector( + std::vector communication_resharding_costs = + CommunicationReshardingCostVector( + src_strategy_group, operand->shape(), + src_strategy_group->strategies[sid].output_sharding, cluster_env); + std::vector memory_resharding_costs = MemoryReshardingCostVector( src_strategy_group, operand->shape(), src_strategy_group->strategies[sid].output_sharding, cluster_env); strategy_group->strategies.push_back(ShardingStrategy( @@ -1636,858 +1800,38 @@ std::unique_ptr CreateReshapeStrategies( compute_cost, communication_cost, memory_cost, - {resharding_costs}, + {communication_resharding_costs}, + {memory_resharding_costs}, {src_strategy_group->strategies[sid].output_sharding}})); } } - // Fail to create follow strategies, enumerate all possible cases if (strategy_group->strategies.empty()) { - strategy_group->strategies.clear(); - strategy_group->following = nullptr; - - // Split 1 dim - if (cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, ""); - } - if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { - // Split 1 dim, but for 1d mesh - EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, " 1d"); - } - if (cluster_env.IsDeviceMesh2D()) { - // Split 2 dim, one is always the batch dim - EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, strategy_group, - only_allow_divisible, - /*partitions*/ 2); - } - if (cluster_env.IsDeviceMesh3D()) { - // Split 3 dim, one is always the batch dim - EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, strategy_group, - only_allow_divisible, - /*partitions*/ 3); - } - - // Replicate - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); + // Fail to create follow strategies, enumerate all possible cases + VLOG(2) << "Enumerating all strategies for reshape"; + FillAllStrategiesForArray( + strategy_group, ins, ins->shape(), cluster_env, strategy_map, option, + replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true); } - return strategy_group; -} - -// NOLINTBEGIN(readability/fn_size) -// TODO(zhuohan): Decompose this function into smaller pieces -// Build possible sharding strategies and their costs for all instructions. -StatusOr> -BuildStrategyAndCost(const HloInstructionSequence& sequence, - const HloModule* module, - const absl::flat_hash_map& - instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, - const AliasMap& alias_map, - const ClusterEnvironment& cluster_env, - AutoShardingOption& option, const CallGraph& call_graph, - const HloCostAnalysis& hlo_cost_analysis, - bool trying_multiple_mesh_shapes) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; - StrategyMap strategy_map; - // This map stores all of the trimmed strategies due to user specified - // sharding. The key is the instruction id, the value is the strategies. This - // is useful when the operand is forced to use a user sharding, and the op - // doesn't need to strictly follow it. We restore the trimmed strategies in - // this situation. - StableHashMap> pretrimmed_strategy_map; - StrategyGroups strategy_groups; - AssociativeDotPairs associative_dot_pairs; - - const std::vector& instructions = sequence.instructions(); - - // Add penalty for replicated tensors - double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + - cluster_env.AllReduceCost(1, 1)); - - int64_t max_depth = -1; - for (auto iter : depth_map) { - max_depth = std::max(max_depth, iter.second); - } - - // Register strategies and their costs for each instruction. - for (size_t instruction_id = 0; instruction_id < instructions.size(); - ++instruction_id) { - const HloInstruction* ins = instructions[instruction_id]; - VLOG(2) << "instruction_id = " << instruction_id << ": " - << ToAdaptiveString(ins); - std::unique_ptr strategy_group; - - HloOpcode opcode = ins->opcode(); - - bool only_allow_divisible; - if (IsEntryComputationInputOrOutput(module, ins)) { - // With IsEntryComputationInputOrOutput(module, ins) == true, entry - // computation's root instruction may still be unevenly sharded because it - // usually "follows" other instruction's sharding. If the instruction it - // follows is an intermediate instruction, it may be able to choose - // unevenly sharded strategiyes. Usually if we constraint input's sharding - // strategies, outputs would be constrained as welll, but if outputs are - // still unevely sharded in some cases, we need to fix the implementation - // in auto sharding. - only_allow_divisible = option.only_allow_divisible_input_output; - } else { - only_allow_divisible = option.only_allow_divisible_intermediate; - } - - switch (opcode) { - case HloOpcode::kParameter: - case HloOpcode::kRngBitGenerator: - case HloOpcode::kRng: { - strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters) - .value(); - break; - } - case HloOpcode::kConstant: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - break; - } - case HloOpcode::kScatter: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - // We follow the first operand (the array we're scattering into) - auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); - CHECK(!src_strategy_group->is_tuple); - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = - src_strategy_group->strategies[sid].output_sharding; - std::string name = ToStringSimple(output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - - std::vector> input_shardings_optional( - {output_spec, std::nullopt, std::nullopt}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, output_spec, strategy_map, cluster_env, call_graph, - input_shardings_optional); - - for (const auto& sharding_optional : input_shardings_optional) { - CHECK(sharding_optional.has_value()); - } - - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, memory_cost, - std::move(resharding_cost), input_shardings_optional})); - } - break; - } - case HloOpcode::kGather: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - const HloInstruction* indices = ins->operand(1); - const Shape& shape = ins->shape(); - for (int32_t index_dim = 0; index_dim < indices->shape().rank(); - index_dim++) { - // Shard on indices dimensions that correspond to output dimensions - // TODO(b/220935014) Shard the last dim of output (model dim) with - // AllGather cost and no follow. - if (index_dim == ins->gather_dimension_numbers().index_vector_dim()) { - continue; - } - for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { - // Split only when the tensor shape is divisible by device - // mesh. - if (device_mesh.dim(j) == 1 || - (only_allow_divisible && - !IsDivisible(shape.dimensions(index_dim), - device_mesh.dim(j)))) { - continue; - } - std::string name = absl::StrCat("S", index_dim, " @ ", j); - - HloSharding output_spec = - Tile(shape, {index_dim}, {j}, device_mesh); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - std::optional input_spec = - hlo_sharding_util::ReshapeSharding(shape, indices->shape(), - output_spec); - if (!input_spec.has_value()) { // invalid reshape - continue; - } - std::vector> input_shardings_optional( - {std::nullopt, input_spec}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, output_spec, strategy_map, cluster_env, call_graph, - input_shardings_optional); - - strategy_group->strategies.push_back(ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, - memory_cost, std::move(resharding_cost), - input_shardings_optional})); - } - } - auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = - src_strategy_group->strategies[sid].output_sharding; - auto gather_parallel_dims = - hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); - absl::Span operand_parallel_dims; - if (gather_parallel_dims) { - operand_parallel_dims = absl::MakeConstSpan( - gather_parallel_dims->operand_parallel_dims); - } - HloSharding filtered_operand_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - output_spec, operand_parallel_dims); - auto maybe_from_data = hlo_sharding_util:: - GatherOutputShardingFromOperandOperandPassthroughDimensions( - filtered_operand_sharding, *ins); - if (!maybe_from_data) continue; - std::string name = ToStringSimple(*maybe_from_data); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - GetBytes(ins->shape()) / maybe_from_data->NumTiles(); - std::vector> input_shardings_optional( - {*maybe_from_data, std::nullopt}); - std::vector> resharding_cost = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, *maybe_from_data, strategy_map, cluster_env, call_graph, - input_shardings_optional); - strategy_group->strategies.push_back(ShardingStrategy( - {name, *maybe_from_data, compute_cost, communication_cost, - memory_cost, std::move(resharding_cost), - input_shardings_optional})); - } - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, - /* operands_to_consider_all_strategies_for */ {0}); - break; - } - case HloOpcode::kBroadcast: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - - if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } else { - EnumerateAllPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategy_group, - batch_dim_map, only_allow_divisible, call_graph, - /*partitions*/ 2); - if (option.allow_mixed_mesh_shape) { - EnumerateAll1DPartition(ins, ins->shape(), - cluster_env.device_mesh_1d_, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "1d", call_graph); - } - } - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - - break; - } - case HloOpcode::kReshape: { - strategy_group = CreateReshapeStrategies( - instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups); - break; - } - case HloOpcode::kTranspose: - case HloOpcode::kReverse: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - - const HloInstruction* operand = ins->operand(0); - - // Create follow strategies - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; - - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - HloSharding output_spec = Undefined(); - auto input_spec = src_strategy_group->strategies[sid].output_sharding; - if (opcode == HloOpcode::kTranspose) { - output_spec = hlo_sharding_util::TransposeSharding( - input_spec, ins->dimensions()); - } else { - output_spec = hlo_sharding_util::ReverseSharding(input_spec, - ins->dimensions()); - } - - std::string name = ToStringSimple(output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); - auto resharding_costs = ReshardingCostVector( - src_strategy_group, operand->shape(), input_spec, cluster_env); - strategy_group->strategies.push_back( - ShardingStrategy({name, - output_spec, - compute_cost, - communication_cost, - memory_cost, - {resharding_costs}, - {input_spec}})); - } - break; - } - case HloOpcode::kPad: - case HloOpcode::kSlice: - case HloOpcode::kConcatenate: // TODO(zhuohan): revisit concatenate - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - int64_t follow_idx; - switch (opcode) { - // TODO(yuemmawang) Re-evaluate the follow_idx choices for the - // following 3. - case HloOpcode::kPad: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kConcatenate: - // Follow the operand according to the follow heuristics - follow_idx = ChooseOperandToFollow(strategy_map, depth_map, - alias_map, max_depth, ins) - .first; - break; - // The following types are better to follow the first operand. - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - follow_idx = 0; - break; - default: - LOG(FATAL) << "Selecting follow index encounters an unhandled " - "instruction type: " + - ins->ToShortString(); - } - // Create follow strategies - const HloInstruction* operand = ins->operand(follow_idx); - StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; - - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { - std::optional output_spec; - HloSharding input_spec = - src_strategy_group->strategies[sid].output_sharding; - - // Find output shardings. - switch (opcode) { - case HloOpcode::kPad: - case HloOpcode::kSlice: - case HloOpcode::kConcatenate: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - output_spec = PropagateDimwiseSharding( - input_spec, operand->shape(), ins->shape()); - break; - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - output_spec = PropagateReduceWindowSharding( - input_spec, operand->shape(), ins->window()); - break; - default: - LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); - } - - // Get a list of input shardings, each corresponds to an operand. - std::vector> input_shardings; - for (int64_t k = 0; k < ins->operand_count(); ++k) { - if (k == follow_idx || - ToString(ins->operand(k)->shape().dimensions()) == - ToString(operand->shape().dimensions())) { - input_shardings.push_back(input_spec); - } else { - input_shardings.push_back(std::nullopt); - } - } - if (!output_spec.has_value()) { - continue; - } - std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); - std::vector> resharding_costs = - GenerateReshardingCostsAndMissingShardingsForAllOperands( - ins, *output_spec, strategy_map, cluster_env, call_graph, - input_shardings); - - strategy_group->strategies.push_back( - ShardingStrategy({name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - std::move(resharding_costs), - {input_spec}})); - } - - if (strategy_group->strategies.empty()) { - strategy_group->following = nullptr; - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - } - - break; - } - case HloOpcode::kOptimizationBarrier: { - auto operand_strategies = strategy_map.at(ins->operand(0)).get(); - strategy_group = MaybeFollowInsStrategyGroup( - operand_strategies, ins->shape(), instruction_id, - /* have_memory_cost */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - break; - } - case HloOpcode::kBitcast: { - if (ins->shape() == ins->operand(0)->shape()) { - strategy_group = CreateElementwiseOperatorStrategies( - instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, - associative_dot_pairs); - } else { - strategy_group = CreateReshapeStrategies( - instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups); - } - break; - } - // Unary elementwise operations. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kRoundNearestEven: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kPopulationCount: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kRsqrt: - case HloOpcode::kLogistic: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kSqrt: - case HloOpcode::kCbrt: - case HloOpcode::kTan: - case HloOpcode::kTanh: - // Binary elementwise operations - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kCompare: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kXor: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - case HloOpcode::kStochasticConvert: - // Ternary elementwise operations. - case HloOpcode::kSelect: - case HloOpcode::kClamp: { - strategy_group = CreateElementwiseOperatorStrategies( - instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, - associative_dot_pairs); - break; - } - case HloOpcode::kReduce: { - auto strategies_status = FollowReduceStrategy( - ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id, - strategy_map, strategy_groups, cluster_env, - option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); - if (strategies_status.ok()) { - strategy_group = std::move(strategies_status.value()); - } else { - return strategies_status.status(); - } - break; - } - case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, - GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); - } - break; - } - case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv( - strategy_group, strategy_groups, strategy_map, ins, instruction_id, - cluster_env, batch_dim_map, option, call_graph)); - if (option.allow_replicated_strategy_for_dot_and_conv) { - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, - GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); - } - break; - } - case HloOpcode::kRngGetAndUpdateState: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); - break; - } - case HloOpcode::kIota: { - strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - strategy_groups); - if (cluster_env.IsDeviceMesh1D()) { - EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, - only_allow_divisible, "", call_graph); - } - if (cluster_env.IsDeviceMesh2D()) { - // Split 2 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 2); - } - if (cluster_env.IsDeviceMesh3D()) { - // Split 3 dims - EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*parts*/ 3); - } - if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) { - // Split 1 dim, but for 1d flattened version of the 2d mesh - // For example, when the mesh shape is (2, 4), we add strategies for - // mesh shape (1, 8) here in addition. - EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, " 1d", call_graph); - } - - // Replicate - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty * 5); - - break; - } - case HloOpcode::kTuple: { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->operand_count()); - for (size_t i = 0; i < ins->operand_count(); ++i) { - const HloInstruction* operand = ins->operand(i); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group, operand->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); - } - break; - } - case HloOpcode::kGetTupleElement: { - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), - instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - break; - } - case HloOpcode::kCustomCall: { - auto generate_non_following_strategies = - [&](bool only_replicated, - absl::flat_hash_set - operands_to_consider_all_strategies_for = {}) { - if (ins->shape().IsTuple()) { - if (only_replicated) { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve( - ins->shape().tuple_shapes_size()); - for (size_t i = 0; i < ins->shape().tuple_shapes_size(); - ++i) { - std::unique_ptr child_strategies = - CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), - cluster_env, strategy_map, - child_strategies, replicated_penalty); - strategy_group->childs.push_back( - std::move(child_strategies)); - } - } else { - strategy_group = - CreateAllStrategiesGroup( - ins, ins->shape(), instruction_id, strategy_groups, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, true) - .value(); - } - } else { - if (only_replicated) { - strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, - strategy_map, strategy_group, - replicated_penalty); - } else { - strategy_group = - CreateAllStrategiesGroup( - ins, ins->shape(), instruction_id, strategy_groups, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, true) - .value(); - } - } - }; - - if (IsCustomCallMarker(ins)) { - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - } else if (ins->has_sharding()) { - generate_non_following_strategies(false); - } else if (OutputInputSameShapes(ins)) { - auto* partitioner = - GetCustomCallPartitioner(ins->custom_call_target()); - if (partitioner && partitioner->IsCustomCallShardable(ins)) { - // Follows operand 0's strategies if this custom-call op is - // shardable and has the same input and output sizes. - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - } - } else if (IsTopKCustomCall(ins)) { - generate_non_following_strategies(false, {0}); - } else { - // TODO (b/258723035) Handle CustomCall ops for GPUs in a better way. - generate_non_following_strategies(true); - } - break; - } - case HloOpcode::kWhile: { - strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); - const StrategyGroup* src_strategy_group = - strategy_map.at(ins->operand(0)).get(); - for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { - auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[i].get(), - ins->shape().tuple_shapes().at(i), instruction_id, - /* have_memory_cost= */ true, strategy_groups, cluster_env, - pretrimmed_strategy_map); - child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); - } - break; - } - case HloOpcode::kConditional: - case HloOpcode::kInfeed: - case HloOpcode::kSort: { - strategy_group = - CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - /*create_replicated_strategies*/ true) - .value(); - break; - } - case HloOpcode::kOutfeed: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - break; - } - case HloOpcode::kAfterAll: { - strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); - break; - } - default: - LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); - } - RemoveDuplicatedStrategy(strategy_group); - if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { - // Finds the sharding strategy that aligns with the given sharding spec - // Do not merge nodes if this one instruction has annotations. - TrimOrGenerateStrategiesBasedOnExistingSharding( - ins->shape(), strategy_group.get(), strategy_map, instructions, - ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, - option.nd_sharding_iteratively_strict_search_space); - } - if (!strategy_group->is_tuple && strategy_group->following) { - if (!LeafVectorsAreConsistent( - strategy_group->strategies, strategy_group->following->strategies, - /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { - // It confuses the solver if two instructions have different number of - // sharding strategies but share the same ILP variable. The solver - // would run much longer and/or return infeasible solutions. - // So if two strategies' strategiess are inconsistent, we unfollow - // them. - strategy_group->following = nullptr; - } - } else if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { - if (strategy_group->childs.at(i)->following && - !LeafVectorsAreConsistent( - strategy_group->childs.at(i)->strategies, - strategy_group->childs.at(i)->following->strategies, - /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { - strategy_group->childs.at(i)->following = nullptr; - } - } - } - RemoveInvalidShardingsWithShapes( - ins->shape(), strategy_group.get(), - /* instruction_has_user_sharding */ ins->has_sharding()); - - if (instruction_execution_counts.contains(ins)) { - ScaleCostsWithExecutionCounts(strategy_group.get(), - instruction_execution_counts.at(ins)); - } else { - VLOG(5) << "No execution count available for " << ins->name(); - } - XLA_VLOG_LINES(2, - absl::StrCat("strategies:\n", strategy_group->ToString())); - - // Debug options: forcibly set the strategy of some instructions. - if (option.force_strategy) { - std::vector inst_indices = option.force_strategy_inst_indices; - std::vector stra_names = option.force_strategy_stra_names; - CHECK_EQ(inst_indices.size(), stra_names.size()); - auto it = absl::c_find(inst_indices, strategy_group->node_idx); - if (it != inst_indices.end()) { - CHECK(!strategy_group->is_tuple); - std::vector new_strategies; - int64_t idx = it - inst_indices.begin(); - for (const auto& stra : strategy_group->strategies) { - if (stra.name == stra_names[idx]) { - new_strategies.push_back(stra); - } - } - strategy_group->strategies = std::move(new_strategies); - } - } - - // When trying out multiple mesh shapes in the presence of user specified - // sharding (as in - // AutoShardingTest.AutoShardingKeepUserShardingInputOutput), there may be a - // situation when we cannot generate any shardings for an instruction when - // the mesh shape we're trying does not match with the mesh shape used in - // user specified shardings. So we disable the check in that situation. - if (!trying_multiple_mesh_shapes) { - CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) - << ins->ToString() << " does not have any valid strategies."; - } else if (!(strategy_group->is_tuple || - !strategy_group->strategies.empty())) { - return Status(absl::StatusCode::kFailedPrecondition, - "Could not generate any shardings for an instruction due " - "to mismatched mesh shapes."); - } - // Checks the shape of resharding_costs is valid. It will check fail if the - // shape is not as expected. - // CheckReshardingCostsShape(strategies.get()); - CheckMemoryCosts(strategy_group.get(), ins->shape()); - strategy_map[ins] = std::move(strategy_group); - } // end of for loop - - // If gradient accumulation is used, adjust the cost of all-reduce for - // gradient synchronization. - if (option.grad_acc_num_micro_batches > 1) { - // find gradient-computation instructions - std::vector grad_insts = - GetGradientComputationInstructions(instructions); - for (const HloInstruction* inst : grad_insts) { - StrategyGroup* stra_vector = strategy_map[inst].get(); - CHECK(!stra_vector->is_tuple); - - for (auto& stra : stra_vector->strategies) { - if (absl::StrContains(stra.name, "allreduce")) { - stra.communication_cost /= option.grad_acc_num_micro_batches; - } - } - } - } - - return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), - std::move(associative_dot_pairs)); + return strategy_group; } -// NOLINTEND - AutoShardingSolverResult CallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, + const LivenessNodeSet& liveness_node_set, + const LivenessEdgeSet& liveness_edge_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, - bool compute_iis, int64_t solver_timeout_in_seconds, - const AutoShardingOption& option, + const absl::flat_hash_set& peak_times, const bool compute_iis, + const int64_t solver_timeout_in_seconds, const AutoShardingOption& option, + std::optional max_cost, absl::string_view request_name, const absl::flat_hash_map& - sharding_propagation_solution) { - // Serialize edges and edge costs to 1d numpy arrays + sharding_propagation_solution, + bool deterministic_mode) { + // Serialize edges and edge costs to 1d numpy arrays. AutoShardingSolverRequest request; request.set_module_name(hlo_module.name()); request.set_num_nodes(strategy_groups.size()); @@ -2497,39 +1841,52 @@ AutoShardingSolverResult CallSolver( request.mutable_s_follow()->Add(cost_graph.follow_idx_.begin(), cost_graph.follow_idx_.end()); request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); + request.mutable_peak_times()->Add(peak_times.begin(), peak_times.end()); request.mutable_solver_timeout()->set_solver_timeout_in_seconds( solver_timeout_in_seconds); request.mutable_overbudget_coeff()->set_coeff(kOverbudgetCoeff); request.set_crash_at_infinity_costs_check(!option.try_multiple_mesh_shapes); request.set_compute_iis(compute_iis); request.set_saltiplier(kSaltiplier); + request.set_deterministic_mode(deterministic_mode); + request.set_request_name(std::string(request_name)); + request.set_enable_memory_edge_costs(option.model_resharding_memory_costs); + if (max_cost) { + request.mutable_max_cost()->set_coeff(*max_cost); + } for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) { AutoShardingSolverRequest_Pair raw_edge; raw_edge.set_first(edge.first); raw_edge.set_second(edge.second); *request.add_edges() = raw_edge; AutoShardingSolverRequest_Costs rij; + AutoShardingSolverRequest_Costs mij; for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) { for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) { - rij.add_costs(edge_cost(i, j)); + rij.add_costs(edge_cost(i, j).communication_cost); + mij.add_costs(edge_cost(i, j).memory_cost); } } request.mutable_resharding_costs()->Add(std::move(rij)); + request.mutable_memory_edge_costs()->Add(std::move(mij)); } const HloInstructionSequence& sequence = hlo_live_range.flattened_instruction_sequence(); const std::vector& instructions = sequence.instructions(); - // Serialize node costs + // Serialize node costs. int num_nodes_without_default = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { const StrategyGroup* strategy_group = strategy_groups[node_idx]; - auto instruction_name = - instructions.at(strategy_group->instruction_id)->name(); + const auto instruction = instructions.at(strategy_group->instruction_id); + const auto instruction_name = instruction->name(); + const auto opcode = HloOpcodeString(instruction->opcode()); request.add_instruction_names( absl::StrCat(instruction_name, " (id: ", node_idx, ")")); + request.add_opcodes(std::string(opcode)); AutoShardingSolverRequest_Costs ci, di, mi, pi; + AutoShardingSolverRequest_Names strategy_names; std::optional default_strategy; auto iter = sharding_propagation_solution.find(instruction_name); if (iter != sharding_propagation_solution.end()) { @@ -2550,6 +1907,7 @@ AutoShardingSolverResult CallSolver( cost_graph.extra_node_costs_[node_idx][j]); mi.add_costs(strategy.memory_cost); pi.add_costs(default_strategy && sharding == *default_strategy ? 0 : 1); + strategy_names.add_names(sharding.ToString()); } if (option.use_sharding_propagation_for_default_shardings && *std::min_element(pi.costs().begin(), pi.costs().end()) > 0) { @@ -2562,17 +1920,18 @@ AutoShardingSolverResult CallSolver( request.mutable_communication_costs()->Add(std::move(di)); request.mutable_memory_costs()->Add(std::move(mi)); request.mutable_departure_costs()->Add(std::move(pi)); + request.mutable_strategy_names()->Add(std::move(strategy_names)); } LOG(INFO) << "Total nodes without default: " << num_nodes_without_default; // Serialize special edges that forces a alias pair have the same sharding - // spec + // spec. std::vector> new_followers; for (const auto& pair : alias_set) { const StrategyGroup* src_strategy_group = strategy_groups[pair.first]; const StrategyGroup* dst_strategy_group = strategy_groups[pair.second]; - Matrix raw_cost(src_strategy_group->strategies.size(), - dst_strategy_group->strategies.size()); + Matrix raw_cost(src_strategy_group->strategies.size(), + dst_strategy_group->strategies.size()); for (NodeStrategyIdx i = 0; i < src_strategy_group->strategies.size(); ++i) { for (NodeStrategyIdx j = 0; j < dst_strategy_group->strategies.size(); @@ -2654,6 +2013,12 @@ AutoShardingSolverResult CallSolver( liveness_node_subset.end()); request.mutable_live()->Add(std::move(nodes)); } + for (const auto& liveness_edge_subset : liveness_edge_set) { + AutoShardingSolverRequest_Edges edges; + edges.mutable_edges()->Add(liveness_edge_subset.begin(), + liveness_edge_subset.end()); + request.mutable_live_edges()->Add(std::move(edges)); + } PopulateTemporalValues(cost_graph, request); @@ -2661,7 +2026,7 @@ AutoShardingSolverResult CallSolver( } void CheckHloSharding(const HloInstructionSequence& sequence, - size_t total_num_devices) { + const size_t total_num_devices) { const std::vector& instructions = sequence.instructions(); std::vector> size_string; for (const HloInstruction* ins : instructions) { @@ -2732,7 +2097,7 @@ void CheckHloSharding(const HloInstructionSequence& sequence, std::sort(size_string.begin(), size_string.end(), MemLarger); size_t k = 10; k = std::min(k, size_string.size()); - for (size_t t = 0; t < k; t++) { + for (size_t t = 0; t < k; ++t) { LOG(INFO) << size_string.at(t).second; } } @@ -2742,12 +2107,16 @@ void SetHloSharding(const HloInstructionSequence& sequence, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val, - bool last_iteration) { + const bool last_iteration) { // Set the HloSharding for every instruction const std::vector& instructions = sequence.instructions(); for (HloInstruction* inst : instructions) { - if (inst->opcode() == HloOpcode::kOutfeed) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kRecv || + inst->opcode() == HloOpcode::kRecvDone || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { continue; } auto iter = strategy_map.find(inst); @@ -2812,8 +2181,8 @@ void SetHloSharding(const HloInstructionSequence& sequence, Status SetHloShardingPostProcessing( const HloInstructionSequence& sequence, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val, - const ClusterEnvironment& cluster_env, bool crash_at_error, - absl::flat_hash_map>* + const ClusterEnvironment& cluster_env, const bool crash_at_error, + absl::flat_hash_map>& preserve_shardings) { const std::vector& instructions = sequence.instructions(); const Array& device_mesh = cluster_env.device_mesh_; @@ -2829,23 +2198,33 @@ Status SetHloShardingPostProcessing( // Here we insert some extra annotated identity instructions to help the // spmd partitioner generate correct code. - if (inst->opcode() == HloOpcode::kDot) { + if (inst->opcode() == HloOpcode::kDot || + inst->opcode() == HloOpcode::kConvolution) { const ShardingStrategy& stra = GetShardingStrategy(inst, strategy_map, cost_graph, s_val); const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); const HloSharding& lhs_sharding = lhs->sharding(); const HloSharding& rhs_sharding = rhs->sharding(); - const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); - const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions(); - const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions(); + std::vector lhs_con_dims; + std::vector rhs_con_dims; + if (inst->opcode() == HloOpcode::kDot) { + const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); + lhs_con_dims.push_back(dot_dnums.lhs_contracting_dimensions()[0]); + rhs_con_dims.push_back(dot_dnums.rhs_contracting_dimensions()[0]); + } else { + const ConvolutionDimensionNumbers& conv_dnums = + inst->convolution_dimension_numbers(); + lhs_con_dims.push_back(conv_dnums.input_feature_dimension()); + rhs_con_dims.push_back(conv_dnums.kernel_input_feature_dimension()); + } - const auto& lhs_tensor_dim_to_mesh_dim = + const std::vector& lhs_tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper( lhs->shape(), lhs_sharding, /* consider_reverse_device_meshes */ true, /* crash_at_error */ crash_at_error); - const auto& rhs_tensor_dim_to_mesh_dim = + const std::vector& rhs_tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper( rhs->shape(), rhs_sharding, /* consider_reverse_device_meshes */ true, @@ -2856,10 +2235,17 @@ Status SetHloShardingPostProcessing( return absl::InvalidArgumentError( "Cannot generate tensor dim to mesh dim mapping"); } + if (absl::StrContains(stra.name, "allreduce") && - lhs_tensor_dim_to_mesh_dim[lhs_con_dims[0]] == -1 && - rhs_tensor_dim_to_mesh_dim[rhs_con_dims[0]] == -1) { - // Allow duplicatd dot computation in this case to reduce + std::any_of(lhs_con_dims.begin(), lhs_con_dims.end(), + [&lhs_tensor_dim_to_mesh_dim](int64_t dim) { + return lhs_tensor_dim_to_mesh_dim[dim] == -1; + }) && + std::any_of(rhs_con_dims.begin(), rhs_con_dims.end(), + [&rhs_tensor_dim_to_mesh_dim](int64_t dim) { + return rhs_tensor_dim_to_mesh_dim[dim] == -1; + })) { + // Allow duplicated dot computation in this case to reduce // communication } else { CHECK(stra.input_shardings.size() == 2) @@ -2875,59 +2261,21 @@ Status SetHloShardingPostProcessing( device_mesh, resharding_cache); } } - } else if (inst->opcode() == HloOpcode::kConvolution) { - const ShardingStrategy& stra = - GetShardingStrategy(inst, strategy_map, cost_graph, s_val); - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - const HloSharding& lhs_sharding = lhs->sharding(); - const HloSharding& rhs_sharding = rhs->sharding(); - const ConvolutionDimensionNumbers& conv_dnums = - inst->convolution_dimension_numbers(); - const int lhs_in_channel_dim = conv_dnums.input_feature_dimension(); - const int rhs_in_channel_dim = - conv_dnums.kernel_input_feature_dimension(); - - const auto& lhs_tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper( - lhs->shape(), lhs_sharding, - /* consider_reverse_device_meshes */ true, - /* crash_at_error */ crash_at_error); - const auto& rhs_tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper( - rhs->shape(), rhs_sharding, - /* consider_reverse_device_meshes */ true, - /* crash_at_error */ crash_at_error); - - if (lhs_tensor_dim_to_mesh_dim.size() != lhs->shape().rank() || - rhs_tensor_dim_to_mesh_dim.size() != rhs->shape().rank()) { - return absl::InvalidArgumentError( - "Cannot generate tensor dim to mesh dim mapping"); - } - - if (absl::StrContains(stra.name, "allreduce") && - lhs_tensor_dim_to_mesh_dim[lhs_in_channel_dim] == -1 && - rhs_tensor_dim_to_mesh_dim[rhs_in_channel_dim] == -1) { - // Allow duplicatd conv computation in this case to reduce - // communication - } else { - if (stra.input_shardings[0].has_value()) { - FixMixedMeshShapeResharding(inst, 0, stra.input_shardings[0].value(), - device_mesh, resharding_cache); - } - if (stra.input_shardings[1].has_value()) { - FixMixedMeshShapeResharding(inst, 1, stra.input_shardings[1].value(), - device_mesh, resharding_cache); - } - } - } else if (inst->opcode() == HloOpcode::kOutfeed) { - // Outfeed operand shardings are handled in downstream passes and so we - // ignore outfeed ops here. However, we need to ensure that outfeed ops - // which have user shardings have their shardings restored at the end. If - // not, this can lead to errors downstream in the spmd_partitioner pass. - auto preserved_sharding_iter = preserve_shardings->find(inst->name()); - if (preserved_sharding_iter != preserve_shardings->end()) { - const auto& preserved_sharding = preserved_sharding_iter->second; + } else if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kSendDone) { + // Outfeed: Outfeed operand shardings are handled in downstream passes and + // so we ignore outfeed ops here. However, we need to ensure that outfeed + // ops which have user shardings have their shardings restored at the + // end. If not, this can lead to errors downstream in the spmd_partitioner + // pass. + + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send-done ops usually + // have. Here we restore these maximal shardings if present. + auto preserved_sharding_iter = preserve_shardings.find(inst->name()); + if (preserved_sharding_iter != preserve_shardings.end()) { + const std::vector& preserved_sharding = + preserved_sharding_iter->second; if (preserved_sharding.size() > 1) { std::vector tuple_elements_shape( inst->operand(0)->shape().tuple_shapes().begin(), @@ -2938,18 +2286,51 @@ Status SetHloShardingPostProcessing( ShapeTree output_tuple_sharding( output_tuple_sharding_shape, Undefined()); size_t i = 0; - for (auto& leaf : output_tuple_sharding.leaves()) { + for (std::pair& leaf : + output_tuple_sharding.leaves()) { leaf.second = preserved_sharding.at(i++); } inst->set_sharding(HloSharding::Tuple(output_tuple_sharding)); } else { - inst->set_sharding(preserved_sharding.at(0)); + CHECK_EQ(preserved_sharding.size(), 1); // Crash OK + inst->set_sharding(preserved_sharding[0]); + } + } + continue; + } else if (inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kRecv || + inst->opcode() == HloOpcode::kRecvDone) { + // In the analysis itself, we use replicated strategies as a stand-in for + // the (expected) maximal sharding annotations that send ops usually + // have. Here we restore these maximal shardings if present. + auto preserved_sharding_iter = preserve_shardings.find(inst->name()); + if (preserved_sharding_iter != preserve_shardings.end()) { + const std::vector& preserved_sharding = + preserved_sharding_iter->second; + if (preserved_sharding.size() > 1) { + inst->set_sharding( + HloSharding::Tuple(inst->shape(), preserved_sharding)); + } else { + if (preserved_sharding.size() != 1) { + return absl::InternalError(absl::StrCat( + "An empty sharding was preserved for ", inst->name(), + ". This should be reported as a bug.")); + } + inst->set_sharding(preserved_sharding[0]); } } - continue; } else { if (inst->shape().IsTuple()) { + // While we do not support nested tuples fully (b/332951306), this is a + // hack to get things to work in some cases (specifically observed for + // the llama and gemma models) where nested tuples as used as + // inputs/outputs of the kOptimizationBarrier instruction. + if (absl::c_any_of( + inst->shape().tuple_shapes(), + [](const Shape& shape) { return shape.IsTuple(); })) { + continue; + } switch (inst->opcode()) { case HloOpcode::kReduce: case HloOpcode::kCustomCall: @@ -2987,7 +2368,7 @@ Status SetHloShardingPostProcessing( for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) { CHECK(!inst->shape().tuple_shapes(i).IsTuple()) << "We currently do not support ops with nested tuples as " - "output."; + "output. See b/332951306."; const ShardingStrategy& stra = GetShardingStrategyForTuple(inst, {static_cast(i)}, strategy_map, cost_graph, s_val); @@ -2997,7 +2378,7 @@ Status SetHloShardingPostProcessing( } } FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( - inst, dst_shardings, device_mesh, preserve_shardings); + inst, dst_shardings, device_mesh); break; } @@ -3078,14 +2459,14 @@ std::string PrintStrategyMap(const StrategyMap& strategy_map, } // Print the chosen auto sharding strategy for debugging. -// TODO (zhuohan): update the following function +// TODO (zhuohan): Update the following function. std::string PrintAutoShardingSolution(const HloInstructionSequence& sequence, const LivenessSet& liveness_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, absl::Span s_val, - double objective) { + const double objective) { std::string str("=== Auto sharding strategy ===\n"); const std::vector& instructions = sequence.instructions(); size_t N = strategy_groups.size(); @@ -3207,82 +2588,34 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, } void SaveShardingForInstruction( + const HloInstruction* inst, bool save_for_copy_users, absl::flat_hash_map>& - preserve_shardings, - HloInstruction* inst) { - if (!inst->has_sharding()) { - return; - } - if (!inst->sharding().IsTuple()) { - preserve_shardings[inst->name()] = {inst->sharding()}; - } else { - preserve_shardings[inst->name()] = inst->sharding().tuple_elements(); - } -} - -// Saves the user shardings that need to be preserved, and check whether they -// are preserved after this pass. -absl::flat_hash_map> SaveUserShardings( - HloModule* module, - const absl::flat_hash_set& replicated_small_tensors, - AutoShardingOption::PreserveShardingsType type) { - absl::flat_hash_map> preserve_shardings; - if (type == AutoShardingOption::PreserveShardingsType::kKeepAllShardings) { - // Saves shardings for all instructions. - for (const auto computation : module->computations()) { - for (const auto inst : computation->instructions()) { - SaveShardingForInstruction(preserve_shardings, inst); - for (const auto user : inst->users()) { - // Also preserve the shardings of copy ops that are the users of those - // instructions. - if (user->opcode() == HloOpcode::kCopy) { - SaveShardingForInstruction(preserve_shardings, user); - } - } - } - } - } else if (type == AutoShardingOption::PreserveShardingsType:: - kKeepInputOutputShardings) { - // Saves parameter shardings - for (const auto inst : - module->entry_computation()->parameter_instructions()) { - SaveShardingForInstruction(preserve_shardings, inst); - for (const auto user : inst->users()) { - // Also preserve the shardings of copy ops that are the users of those - // instructions. - if (user->opcode() == HloOpcode::kCopy) { - SaveShardingForInstruction(preserve_shardings, user); - } - } + preserve_shardings) { + auto save_sharding = [&preserve_shardings](const HloInstruction* inst) { + if (!inst->has_sharding()) { + return; } - for (const auto computation : module->computations()) { - for (const auto inst : computation->instructions()) { - if (inst->opcode() == HloOpcode::kOutfeed || - replicated_small_tensors.count(inst->name())) { - SaveShardingForInstruction(preserve_shardings, inst); - } - } + if (!inst->sharding().IsTuple()) { + preserve_shardings[inst->name()] = {inst->sharding()}; + } else { + preserve_shardings[inst->name()] = inst->sharding().tuple_elements(); } - // Saves output shardings - auto inst = module->entry_computation()->root_instruction(); - SaveShardingForInstruction(preserve_shardings, inst); - } + }; - if (VLOG_IS_ON(1)) { - LOG(INFO) << "User shardings that need to be kept (printing only the 1st " - "elemenet of tuples): "; - for (const auto& tmp : preserve_shardings) { - std::string sharding; - for (const auto& s : tmp.second) { - sharding += s.ToString() + ","; + save_sharding(inst); + + if (save_for_copy_users) { + for (const auto user : inst->users()) { + // Also preserve the shardings of copy ops that are the users of those + // instructions. + if (user->opcode() == HloOpcode::kCopy) { + save_sharding(user); } - LOG(INFO) << tmp.first << ": " << sharding; } } - return preserve_shardings; } -// Check whether the shardings that need to be perserved are preserved. +// Check whether the shardings that need to be preserved are preserved. void CheckUserShardingPreservation( HloModule* module, const absl::flat_hash_map>& @@ -3298,6 +2631,7 @@ void CheckUserShardingPreservation( << preserve_shardings.at(inst->name())[0].ToString() << "\nbut it's empty."; } else if (!inst->sharding().IsTuple() && + !preserve_shardings.at(inst->name())[0].IsUnknown() && preserve_shardings.at(inst->name())[0] != inst->sharding()) { LOG(FATAL) << "User sharding is not preserved! Instruction with name " << inst->name() << " should be: " @@ -3307,8 +2641,9 @@ void CheckUserShardingPreservation( const std::vector* preserve_shardings_tuple = &preserve_shardings.at(inst->name()); for (size_t i = 0; i < inst->shape().tuple_shapes_size(); i++) { - if (preserve_shardings_tuple->at(i) != - inst->sharding().tuple_elements().at(i)) { + if (!preserve_shardings_tuple->at(i).IsUnknown() && + preserve_shardings_tuple->at(i) != + inst->sharding().tuple_elements().at(i)) { LOG(FATAL) << "Tuple sharding is not preserved! Instruction " "with name " << inst->name() << " " << i << "th tuple element " @@ -3323,11 +2658,12 @@ void CheckUserShardingPreservation( } } -int64_t MemoryBudgetLowerBound(const HloModule& module, - const LivenessSet& liveness_set, - const HloAliasAnalysis* alias_analysis, - int64_t num_devices) { - auto get_value_sharding = [](const HloValue* value) { +int64_t MemoryBudgetLowerBound( + const HloModule& module, const LivenessSet& liveness_set, + const HloAliasAnalysis& alias_analysis, const int64_t num_devices, + const absl::flat_hash_map>& + preserved_shardings) { + auto get_value_sharding = [](const HloValue* value) -> HloSharding { return !value->index().empty() ? value->instruction()->sharding().GetSubSharding( value->instruction()->shape(), value->index()) @@ -3340,24 +2676,28 @@ int64_t MemoryBudgetLowerBound(const HloModule& module, // as aliasing HloValues are mapped to the same buffer. absl::flat_hash_map buffer_to_sharded_value_mapping; - for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { - for (const HloValue* value : liveness_set[time_idx]) { - auto buffer = alias_analysis->GetBufferContainingValue(*value); + bool vlog_is_on_5 = VLOG_IS_ON(5); + for (const HloBuffer& buffer : alias_analysis.buffers()) { + for (const HloValue* value : buffer.values()) { if (value->instruction()->has_sharding()) { - auto this_value_sharding = get_value_sharding(value); - auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); - if (iter != buffer_to_sharded_value_mapping.end()) { - auto buffer_value_sharding = get_value_sharding(iter->second); - if (this_value_sharding != buffer_value_sharding) { - // TODO(pratikf): This is an unavoidable situation, but possibly - // there is a better design decision that can be made here. - VLOG(1) << "We have a situation where two HloValues alias, but " - "they have different shardings. This can happen in the " - "presence of user-specified shardings, and is expected. " - "This, however, means that the memory budget estimate " - "is not very accurate. The aliasing HLOs are " - << value->ToShortString() << " and " - << iter->second->ToShortString(); + if (vlog_is_on_5) { + const HloSharding& this_value_sharding = get_value_sharding(value); + auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); + if (iter != buffer_to_sharded_value_mapping.end()) { + const HloSharding& buffer_value_sharding = + get_value_sharding(iter->second); + if (this_value_sharding != buffer_value_sharding) { + // TODO(pratikf): This is an unavoidable situation, but possibly + // there is a better design decision that can be made here. + VLOG(1) + << "We have a situation where two HloValues alias, but " + "they have different shardings. This can happen in the " + "presence of user-specified shardings, and is expected. " + "This, however, means that the memory budget estimate " + "is not very accurate. The aliasing HLOs are " + << value->ToShortString() << " and " + << iter->second->ToShortString(); + } } } buffer_to_sharded_value_mapping[buffer.id()] = value; @@ -3366,25 +2706,51 @@ int64_t MemoryBudgetLowerBound(const HloModule& module, } int64_t max_memory_usage = 0; + absl::flat_hash_map value_to_memory_size_mapping; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { int64_t memory_usage = 0; for (const HloValue* value : liveness_set[time_idx]) { if (value->instruction()->shape().IsTuple() && value->index().empty()) { continue; } - Shape shape = - ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); - auto buffer = alias_analysis->GetBufferContainingValue(*value); - auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); + + auto iter1 = value_to_memory_size_mapping.find(value); + if (iter1 != value_to_memory_size_mapping.end()) { + memory_usage += iter1->second; + continue; + } + std::optional optional_sharding = std::nullopt; - if (iter != buffer_to_sharded_value_mapping.end()) { - optional_sharding = get_value_sharding(iter->second); + const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(*value); + auto iter2 = buffer_to_sharded_value_mapping.find(buffer.id()); + if (iter2 != buffer_to_sharded_value_mapping.end()) { + // The instructions here can have partial sharding annotations from + // previous iterations with partial mesh shapes when + // solve_nd_sharding_iteratively is true. To exclude these, we only + // utilize those shardings which corresponding to the current device + // mesh. + if (preserved_shardings.find(value->instruction()->name()) != + preserved_shardings.end()) { + optional_sharding = get_value_sharding(iter2->second); + } else { + const HloSharding& value_sharding = get_value_sharding(iter2->second); + if (!value_sharding.IsTiled() || + value_sharding.TotalNumTiles() == num_devices) { + optional_sharding = value_sharding; + } + } } - memory_usage += + + const Shape& shape = + ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); + int64_t value_memory_usage = GetShardedInstructionSize(shape, num_devices, optional_sharding); + value_to_memory_size_mapping[value] = value_memory_usage; + memory_usage += value_memory_usage; } max_memory_usage = std::max(max_memory_usage, memory_usage); } + return max_memory_usage; } @@ -3423,12 +2789,13 @@ void RecoverShardingsFromPartialMesh( } } } + // DFS to find the replicated set starting from cur instruction. void FindReplicateSet( HloInstruction* cur, const AliasMap& alias_map, const CostGraph& cost_graph, absl::Span s_val, const StrategyMap& strategy_map, const ShardingStrategy& strategy, const HloInstruction* output, - bool do_all_gather_after_backward, HloInstruction*& transpose_inst, + const bool do_all_gather_after_backward, HloInstruction*& transpose_inst, StableHashSet& replicated_set, StableHashSet& boundary_set, StableHashSet& consumer_set, @@ -3440,13 +2807,13 @@ void FindReplicateSet( for (HloInstruction* consumer : users) { const HloInstruction* shape_inst = cur; - // Allow at most one transpose + // Allow at most one transpose. if (consumer->opcode() == HloOpcode::kTranspose && (transpose_inst == nullptr || DimensionsEqual(transpose_inst->shape(), consumer->shape()))) { shape_inst = consumer; transpose_inst = consumer; - // TODO(zhuohan): fix output_sharding comparison. + // TODO(zhuohan): Fix output_sharding comparison. } if (consumer->opcode() == HloOpcode::kTuple || @@ -3488,14 +2855,14 @@ void FindReplicateSet( } // Substitute all-reduce strategies with their reduce-scatter variants. -void GenerateReduceScatter( +absl::Status GenerateReduceScatter( const HloInstructionSequence& sequence, const AliasMap& alias_map, const InstructionDepthMap& depth_map, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val, const ClusterEnvironment& cluster_env, const AutoShardingOption& option) { const std::vector& instructions = sequence.instructions(); - // Propagation ends at output + // Propagation ends at output. const HloInstruction* output = instructions.back(); if (IsCustomCallMarker(output)) { output = output->operand(0); @@ -3709,7 +3076,7 @@ void GenerateReduceScatter( CHECK(!cur->users().empty()); - // Find the first user + // Find the first user. HloInstruction* first_user = nullptr; int64_t min_depth = ((int64_t)1) << 50; for (const auto& x : cur->users()) { @@ -3719,7 +3086,7 @@ void GenerateReduceScatter( } if (x->opcode() != HloOpcode::kConvolution && x->opcode() != HloOpcode::kDot) { - // Only apply this aggressive optimization for dot and conv + // Only apply this aggressive optimization for dot and conv. continue; } if (iter->second < min_depth) { @@ -3729,7 +3096,7 @@ void GenerateReduceScatter( } if (first_user != nullptr) { - // Insert an identity to prevent CSE of all-gather + // Insert an identity to prevent CSE of all-gather. HloInstruction* identity = inst->parent()->AddInstruction( HloInstruction::CreateCustomCall(cur->shape(), {cur}, kIdentityMarker)); @@ -3753,8 +3120,9 @@ void GenerateReduceScatter( replace_with->set_sharding( GetShardingStrategy(inst, strategy_map, cost_graph, s_val) .output_sharding); - TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with)); + TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(replace_with)); } + return OkStatus(); } void AnnotateShardingWithSimpleHeuristic( @@ -4083,49 +3451,99 @@ bool HasReduceScatterOpportunity( } // namespace spmd -StatusOr AutoShardingImplementation::RemoveShardingAnnotation( +std::pair>, bool> +AutoShardingImplementation::SaveAndRemoveShardingAnnotation( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads) { + absl::flat_hash_map> preserve_shardings; + absl::flat_hash_set keep_inst; + + for (const HloComputation* computation : + module->computations(execution_threads)) { + for (const auto inst : computation->instructions()) { + if (inst->opcode() == HloOpcode::kOutfeed || + inst->opcode() == HloOpcode::kRecv || + inst->opcode() == HloOpcode::kRecvDone || + inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kSendDone) { + spmd::SaveShardingForInstruction(inst, + /* save_for_copy_users */ false, + preserve_shardings); + continue; + } + if (inst->has_sharding() && + spmd::IsShardingMisaligned(inst->sharding(), inst->shape())) { + LOG(WARNING) + << "Instruction " << inst->name() + << " has a user sharding annotation that is misaligned. Shape: " + << inst->shape().ToString() + << ". Sharding:" << inst->sharding().ToString(); + } + } + } + if (option_.preserve_shardings == AutoShardingOption::PreserveShardingsType::kKeepAllShardings) { - return false; + // Saves shardings for all instructions. + for (const HloComputation* computation : + module->computations(execution_threads)) { + for (const auto inst : computation->instructions()) { + spmd::SaveShardingForInstruction(inst, + /* save_for_copy_users */ true, + preserve_shardings); + } + } + return std::make_pair(preserve_shardings, /* module_is_changed */ false); } - VLOG(0) << "Removing user sharding annotations."; - bool changed = false; - absl::flat_hash_set keep_inst; + + bool module_is_changed = false; for (HloComputation* computation : module->computations(execution_threads)) { bool is_entry_computation = computation->IsEntryComputation(); for (HloInstruction* ins : computation->instructions()) { - // Do not remove sharding annotations from instructions replicated as they - // are small tensors + // Do not remove sharding annotations from instructions replicated as + // they are small tensors if (replicated_small_tensors.count(ins->name())) { keep_inst.insert(ins); + spmd::SaveShardingForInstruction(ins, + /* save_for_copy_users */ false, + preserve_shardings); continue; } - // Do not remove entry computation's parameter and root instruction's // sharding if preserve_shardings is kKeepInputOutputShardings. if (option_.preserve_shardings == AutoShardingOption::PreserveShardingsType:: kKeepInputOutputShardings && - (is_entry_computation && - (ins->opcode() == HloOpcode::kParameter || ins->IsRoot()))) { + is_entry_computation && + (ins->opcode() == HloOpcode::kParameter || ins->IsRoot())) { keep_inst.insert(ins); + spmd::SaveShardingForInstruction( + ins, + /* save_for_copy_users */ ins->opcode() == HloOpcode::kParameter, + preserve_shardings); continue; } + if (ins->opcode() == HloOpcode::kCopy && keep_inst.find(ins->operand(0)) != keep_inst.end()) { continue; } + + if (ins->opcode() == HloOpcode::kOutfeed || + ins->opcode() == HloOpcode::kSend || + ins->opcode() == HloOpcode::kSendDone) { + continue; + } + if (ins->has_sharding()) { - changed |= true; + module_is_changed |= true; ins->clear_sharding(); } } } - return changed; + return std::make_pair(preserve_shardings, module_is_changed); } Status AutoShardingImplementation::CanonicalizeLayouts(HloModule* module) { @@ -4157,7 +3575,7 @@ AutoShardingImplementation::AutoShardingImplementation( const AutoShardingOption& option) : option_(option) {} -StatusOr AutoShardingImplementation::RunAutoSharding( +absl::StatusOr AutoShardingImplementation::RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, @@ -4174,14 +3592,14 @@ StatusOr AutoShardingImplementation::RunAutoSharding( // shardings to their input ops. absl::flat_hash_map> unspecified_dims; - auto status_or_changed = ProcessShardingInstruction( + absl::StatusOr changed = ProcessShardingInstruction( module, execution_threads, /*replace_sharding_with_copy=*/true, &unspecified_dims, /*saved_root_shardings=*/nullptr, /*saved_parameter_shardings=*/nullptr); - if (!status_or_changed.ok()) { - return status_or_changed.status(); + if (!changed.ok()) { + return changed.status(); } - if (status_or_changed.value()) { + if (changed.value()) { module_is_changed = true; VLOG(3) << "CustomCalls with custom_call_target=Sharding are removed and " "their shardings are moved to their input ops."; @@ -4190,27 +3608,12 @@ StatusOr AutoShardingImplementation::RunAutoSharding( "custom_call_target=Sharding."; } + std::pair>, bool> + preserve_shardings_result = SaveAndRemoveShardingAnnotation( + module, replicated_small_tensors, execution_threads); absl::flat_hash_map> - preserve_shardings = spmd::SaveUserShardings( - module, replicated_small_tensors, option_.preserve_shardings); - - // Remove xla sharding annotations, if there is any. - if (option_.preserve_shardings != - AutoShardingOption::PreserveShardingsType::kKeepAllShardings) { - StatusOr status_or_changed = RemoveShardingAnnotation( - module, replicated_small_tensors, execution_threads); - if (!status_or_changed.ok()) { - return status_or_changed.status(); - } - if (status_or_changed.value()) { - module_is_changed = true; - LOG(INFO) << "XLA sharding annotations are removed."; - } else { - LOG(INFO) << "This workload does not have XLA sharding annotations."; - } - } else { - LOG(INFO) << "Preserving XLA sharding annotations."; - } + preserve_shardings = preserve_shardings_result.first; + module_is_changed |= preserve_shardings_result.second; // ----- Get a sequential schedule and do liveness analysis ----- auto size_fn = [](const BufferValue& buffer) { @@ -4224,7 +3627,19 @@ StatusOr AutoShardingImplementation::RunAutoSharding( const HloComputation* entry_computation = module->entry_computation(); std::unique_ptr alias_analysis = HloAliasAnalysis::Run(module).value(); - spmd::AliasMap alias_map = spmd::BuildAliasMap(module); + + // Handle donated args by resolving them into input-output aliases. While we + // want to perform this resolution, we do not want to modify the module, which + // is why we run the OptimizeInputOutputBufferAlias pass on a clone. + auto module_clone = module->Clone(""); + OptimizeInputOutputBufferAlias input_output_buffer_alias_optimizer( + /* registered_buffer_donor_only */ true); + CHECK_OK(input_output_buffer_alias_optimizer.Run(module_clone.get())); + const HloInputOutputAliasConfig& input_output_alias_config = + module_clone->input_output_alias_config(); + + spmd::AliasMap alias_map = + spmd::BuildAliasMap(module, input_output_alias_config); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_live_range, @@ -4232,9 +3647,9 @@ StatusOr AutoShardingImplementation::RunAutoSharding( absl::flat_hash_map& buffer_live_ranges = hlo_live_range->buffer_live_ranges(); spmd::LivenessSet liveness_set(hlo_live_range->schedule_end_time() + 1); - for (const auto& iter : buffer_live_ranges) { - for (spmd::LivenessIdx i = iter.second.start; i <= iter.second.end; ++i) { - liveness_set[i].push_back(iter.first); + for (const auto& [hlo_value, live_range] : buffer_live_ranges) { + for (spmd::LivenessIdx i = live_range.start; i <= live_range.end; ++i) { + liveness_set[i].push_back(hlo_value); } } VLOG(10) << hlo_live_range->ToString(); @@ -4253,10 +3668,11 @@ StatusOr AutoShardingImplementation::RunAutoSharding( // supposed to make the solver faster, but it makes it much much slower for // both 1D and 2D mesh shapes. // batch_dim_map = spmd::BuildInstructionBatchDimMap(sequence); - // ----- Read parameters of device mesh ---- + + // ----- Read parameters of device mesh ----- Array original_device_mesh(option_.device_mesh_shape); original_device_mesh.SetValues(option_.device_mesh_ids); - int64_t original_memory_budget = option_.memory_budget_per_device; + const int64_t original_memory_budget = option_.memory_budget_per_device; std::vector> partial_mesh_shapes; if (option_.solve_nd_sharding_iteratively) { @@ -4272,9 +3688,8 @@ StatusOr AutoShardingImplementation::RunAutoSharding( .shape_size = [](const Shape& shape) { return spmd::GetBytes(shape); }}; HloCostAnalysis hlo_cost_analysis(hlo_cost_analysis_options); CHECK_OK(module->entry_computation()->Accept(&hlo_cost_analysis)); - for (size_t mesh_idx = 0; mesh_idx < partial_mesh_shapes.size(); ++mesh_idx) { - // Adjust existing shardings with current partial mesh shapes; + // Adjust existing shardings with current partial mesh shapes. std::vector mesh_shape = partial_mesh_shapes[mesh_idx]; LOG(INFO) << "Processing partial mesh shape: " << spmd::ToString(mesh_shape); @@ -4285,34 +3700,40 @@ StatusOr AutoShardingImplementation::RunAutoSharding( total_devices *= i; } if (mesh_idx != partial_mesh_shapes.size() - 1) { - auto changed_or = spmd::AdjustShardingsWithPartialMeshShape( + absl::StatusOr changed = spmd::AdjustShardingsWithPartialMeshShape( sequence.instructions(), mesh_shape, total_devices, /* crash_on_error */ !option_.try_multiple_mesh_shapes); - if (changed_or.ok()) { + if (changed.ok()) { LOG(INFO) << "Shardings are adjusted based on current partial mesh shape: " - << *changed_or; + << *changed; } else { - return changed_or.status(); + return changed.status(); } } - std::vector device_mesh_ids = std::vector(total_devices); - std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); - device_mesh.SetValues(device_mesh_ids); + if (option_.device_mesh_ids.size() == total_devices) { + // It is unclear what device order to use for partial meshes. So we only + // use the actual device order only for the final full mesh. + device_mesh.SetValues(option_.device_mesh_ids); + } else { + std::vector device_mesh_ids = + std::vector(total_devices); + std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); + device_mesh.SetValues(device_mesh_ids); + } - // TODO (zhuohan): include the prof result as an option. + // TODO (zhuohan): Include the prof result as an option. spmd::ProfilingResult prof_result; spmd::ClusterEnvironment cluster_env( original_device_mesh, device_mesh, option_.device_mesh_alpha, option_.device_mesh_beta, prof_result, option_); XLA_VLOG_LINES(6, module->ToString()); - int64_t memory_lower_bound = spmd::MemoryBudgetLowerBound( - *module, liveness_set, alias_analysis.get(), - device_mesh.num_elements()); - // Rounds up to the next GB. - int64_t memory_lower_bound_gb = - 1 + memory_lower_bound / (1024 * 1024 * 1024); + const int64_t memory_lower_bound = spmd::MemoryBudgetLowerBound( + *module, liveness_set, *alias_analysis, device_mesh.num_elements(), + preserve_shardings); + const float memory_lower_bound_gb = + static_cast(memory_lower_bound) / (1024 * 1024 * 1024); LOG(INFO) << "Memory consumption lower bound is " << memory_lower_bound_gb << " GB."; if (set_to_memory_lower_bound) { @@ -4320,14 +3741,12 @@ StatusOr AutoShardingImplementation::RunAutoSharding( << "--xla_tpu_auto_spmd_partitioning_memory_budget_gb is 0, and " "--xla_tpu_auto_spmd_partitioning_memory_budget_ratio is " << option_.memory_budget_ratio - << ", so setting " - "option.memory_budget_per_device to " + << ", so setting option.memory_budget_per_device to " << memory_lower_bound_gb << " x " << option_.memory_budget_ratio << " = " << memory_lower_bound_gb * option_.memory_budget_ratio << " GB"; - option_.memory_budget_per_device = memory_lower_bound_gb * - (1024 * 1024 * 1024) * - option_.memory_budget_ratio; + option_.memory_budget_per_device = + memory_lower_bound * option_.memory_budget_ratio; } else if (option_.memory_budget_per_device > 0) { option_.memory_budget_per_device = original_memory_budget * original_device_mesh.num_elements() / @@ -4351,27 +3770,41 @@ StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Analyze depth ----- spmd::InstructionDepthMap ins_depth_map; ins_depth_map = spmd::BuildInstructionDepthMap(sequence, batch_dim_map); + // ----- Build strategies and costs ----- spmd::StrategyMap strategy_map; spmd::StrategyGroups strategy_groups; spmd::AssociativeDotPairs associative_dot_pairs; - TF_ASSIGN_OR_RETURN( std::tie(strategy_map, strategy_groups, associative_dot_pairs), BuildStrategyAndCost( sequence, module, instruction_execution_counts, ins_depth_map, batch_dim_map, alias_map, cluster_env, option_, *call_graph, hlo_cost_analysis, option_.try_multiple_mesh_shapes)); - spmd::AliasSet alias_set = spmd::BuildAliasSet(module, strategy_map); - CheckAliasSetCompatibility(alias_set, strategy_groups, sequence); + spmd::AliasSet alias_set = + spmd::BuildAliasSet(module, input_output_alias_config, strategy_map); + if (Status alias_set_status = CheckAliasSetCompatibility( + alias_set, strategy_groups, sequence, + /* crash_at_error */ !option_.try_multiple_mesh_shapes); + !alias_set_status.ok()) { + return alias_set_status; + } XLA_VLOG_LINES(8, PrintStrategyMap(strategy_map, sequence)); // ----- Build cost graph and merge unimportant nodes ----- spmd::CostGraph cost_graph(strategy_groups, associative_dot_pairs); cost_graph.Simplify(option_.simplify_graph); - // ----- Build the liveness node set ----- + // ----- Build the liveness node & edge sets ----- + std::vector> node_to_edges( + strategy_groups.size()); + spmd::EdgeIdx edge_idx = 0; + for (const auto& [edge, _] : cost_graph.edge_costs_) { + node_to_edges[edge.second].insert(edge_idx); + ++edge_idx; + } spmd::LivenessNodeSet liveness_node_set(liveness_set.size()); + spmd::LivenessEdgeSet liveness_edge_set(liveness_set.size()); for (spmd::LivenessIdx t = 0; t < liveness_set.size(); ++t) { for (const HloValue* value : liveness_set[t]) { const HloInstruction* instruction = value->instruction(); @@ -4381,54 +3814,55 @@ StatusOr AutoShardingImplementation::RunAutoSharding( strategy_map.at(instruction).get(); const spmd::NodeIdx node_idx = strategy_group->GetSubStrategyGroup(index)->node_idx; - if (node_idx >= 0) liveness_node_set[t].push_back(node_idx); + if (node_idx < 0) continue; + liveness_node_set[t].push_back(node_idx); + for (const spmd::EdgeIdx edge_idx : node_to_edges[node_idx]) { + liveness_edge_set[t].push_back(edge_idx); + } } std::sort(liveness_node_set[t].begin(), liveness_node_set[t].end()); + std::sort(liveness_edge_set[t].begin(), liveness_edge_set[t].end()); } // ----- Call the ILP Solver ----- - std::vector s_val; - std::vector e_val; - double objective = -1.0; - if (!option_.load_solution_vector) { - auto solver_result = - Solve(*module, *hlo_live_range, liveness_node_set, strategy_map, - strategy_groups, cost_graph, alias_set, option_, - sharding_propagation_solution); - if (solver_result.skip_auto_sharding) { - return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; - } else if (!solver_result.status.ok()) { - return AutoShardingResult::kModuleUnchanged; - } else { - TF_ASSIGN_OR_RETURN(auto solution, solver_result.status); - std::tie(s_val, e_val, objective) = solution; - if (mesh_idx == partial_mesh_shapes.size() - 1) { - this->solver_optimal_objective_value_ = objective; - } - } + spmd::AutoShardingSolverOutput output; + std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); + auto solver_result = + Solve(*module, *hlo_live_range, liveness_node_set, liveness_edge_set, + strategy_map, strategy_groups, cost_graph, alias_set, option_, + request_name, sharding_propagation_solution); + if (solver_result.skip_auto_sharding) { + return AutoShardingResult::kModuleUnchangedNoShardingPerformed; + } else if (!solver_result.status.ok()) { + return AutoShardingResult::kModuleUnchanged; } else { - s_val = option_.strategy_vector; + TF_ASSIGN_OR_RETURN(auto solution, solver_result.status); + output = solution; + if (mesh_idx == partial_mesh_shapes.size() - 1) { + this->solver_optimal_objective_value_ = output.cost; + } } - XLA_VLOG_LINES(5, PrintAutoShardingSolution(sequence, liveness_set, - strategy_map, strategy_groups, - cost_graph, s_val, objective)); - XLA_VLOG_LINES(1, PrintSolutionMemoryUsage(liveness_set, strategy_map, - cost_graph, s_val)); + XLA_VLOG_LINES(5, PrintAutoShardingSolution( + sequence, liveness_set, strategy_map, strategy_groups, + cost_graph, output.s_val, output.cost)); + XLA_VLOG_LINES(6, PrintSolutionMemoryUsage(liveness_set, strategy_map, + cost_graph, output.s_val)); // ----- Substitute all-reduce with reduce-scatter ----- if (option_.prefer_reduce_scatter) { - GenerateReduceScatter(sequence, alias_map, ins_depth_map, strategy_map, - cost_graph, s_val, cluster_env, option_); + TF_RETURN_IF_ERROR(GenerateReduceScatter( + sequence, alias_map, ins_depth_map, strategy_map, cost_graph, + output.s_val, cluster_env, option_)); } // ----- Set Sharding ----- - SetHloSharding(sequence, strategy_map, cost_graph, s_val, + SetHloSharding(sequence, strategy_map, cost_graph, output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1)); if (mesh_idx == partial_mesh_shapes.size() - 1) { if (!SetHloShardingPostProcessing( - sequence, strategy_map, cost_graph, s_val, cluster_env, + sequence, strategy_map, cost_graph, output.s_val, cluster_env, /* crash_at_error */ !option_.try_multiple_mesh_shapes, - &preserve_shardings) + preserve_shardings) .ok()) { return AutoShardingResult::kModuleUnchanged; } @@ -4469,9 +3903,6 @@ bool ModuleHasUserShardings(const HloModule* module) { return has_shardings; } -AutoSharding::AutoSharding(const AutoShardingOption& option) - : option_(option) {} - bool IsSmallTensor(const HloInstruction* ins, const AutoShardingOption& option) { return spmd::GetInstructionSize(ins->shape()) <= @@ -4490,6 +3921,33 @@ bool IsModuleManuallySharded(const HloModule* module) { return false; } +bool ShardedOnTooManyMeshAxes(const HloModule& module) { + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->has_sharding() && instruction->sharding().IsTiled() && + spmd::NumTileDimensions(instruction->sharding()) >= 3) { + return true; + } + } + } + return false; +} + +bool HasUnsupportedNestedTuples(const HloModule& module) { + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kConditional) { + for (const HloInstruction* operand : instruction->operands()) { + if (ShapeUtil::IsNestedTuple(operand->shape())) { + return true; + } + } + } + } + } + return false; +} + std::unique_ptr CloneModule(const HloModule* module) { auto module_clone = module->Clone(""); module_clone->set_layout_canonicalization_callback( @@ -4497,7 +3955,10 @@ std::unique_ptr CloneModule(const HloModule* module) { return module_clone; } -StatusOr AutoSharding::Run( +AutoSharding::AutoSharding(const AutoShardingOption& option) + : option_(option) {} + +absl::StatusOr AutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!option_.enable) { @@ -4506,9 +3967,26 @@ StatusOr AutoSharding::Run( LOG(INFO) << "Starting the auto sharding pass"; if (IsModuleManuallySharded(module)) { - LOG(ERROR) - << "Auto-sharding on partially manually sharded modules is not yet " - "supported. Please fall back on the sharding propagation pass."; + LOG(FATAL) + << "Auto-sharding on partially manually sharded modules " // Crash OK + "is not yet supported. Please fall back on the sharding " + "propagation pass."; + return false; + } + + if (ShardedOnTooManyMeshAxes(*module)) { + LOG(FATAL) << "The input module contains sharding annotations " // Crash OK + "over a mesh with too many axes (>2). This case is currently " + "not well supported."; + return false; + } + + // TODO(b/332951306): Remove this check once nested tuples are supported + // everywhere + if (HasUnsupportedNestedTuples(*module)) { + LOG(FATAL) << "The input module contains nested tuples " // Crash OK + "which we do not currently support well. See b/332951306 to " + "track progress on this."; return false; } @@ -4522,6 +4000,8 @@ StatusOr AutoSharding::Run( metrics::RecordAutoShardingInvocations(); #endif + TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); + TF_RETURN_IF_ERROR(option_.CheckAndSetup()); LOG(INFO) << "AutoShardingOptions:\n" << option_.ToString(); @@ -4564,19 +4044,41 @@ StatusOr AutoSharding::Run( mesh_shapes.push_back(option_.device_mesh_shape); } + if (module->entry_computation()->num_parameters() > 0) { + HloInstruction* parameter_instruction = + module->entry_computation()->parameter_instruction(0); + if (parameter_instruction->shape().IsTuple() && + parameter_instruction->has_sharding()) { + CHECK_EQ(module->entry_computation()->num_parameters(), 1); + parameter_instruction->set_sharding( + spmd::ReplaceGivenShardingsWithUnknownForTuple( + parameter_instruction->sharding(), parameter_instruction->shape(), + module->config() + .allow_spmd_sharding_propagation_to_parameters())); + } + } + + HloInstruction* root_instruction = + module->entry_computation()->root_instruction(); + if (root_instruction->shape().IsTuple() && root_instruction->has_sharding()) { + root_instruction->set_sharding( + spmd::ReplaceGivenShardingsWithUnknownForTuple( + root_instruction->sharding(), root_instruction->shape(), + module->config().allow_spmd_sharding_propagation_to_output())); + } + absl::flat_hash_map sharding_propagation_solution; std::unique_ptr module_with_default_solution = nullptr; if (option_.use_sharding_propagation_for_default_shardings) { module_with_default_solution = CloneModule(module); - // TODO(pratikf): Ensure that we're passing the correct custom call sharding - // helper to the sharding propagation pass. + // TODO(pratikf): Ensure that we're passing the correct custom call + // sharding helper to the sharding propagation pass. auto sharding_prop = ShardingPropagation( /*is_spmd */ true, /*propagate_metadata */ false, /*allow_spmd_sharding_propagation_to_output*/ module->config().allow_spmd_sharding_propagation_to_output(), - /*allow_spmd_sharding_propagation_to_parameters */ - absl::InlinedVector{false}, + module->config().allow_spmd_sharding_propagation_to_parameters(), /*cse_prevention_only */ false, /*sharding_helper*/ nullptr); @@ -4608,6 +4110,12 @@ StatusOr AutoSharding::Run( VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]); AutoShardingOption this_option = option_; this_option.device_mesh_shape = mesh_shapes[i]; + if (this_option.device_mesh_shape.size() != + this_option.device_mesh_alpha.size()) { + this_option.device_mesh_alpha.clear(); + this_option.device_mesh_beta.clear(); + TF_RETURN_IF_ERROR(this_option.CheckAndSetup()); + } auto pass = new AutoShardingImplementation(this_option); auto module_clone = CloneModule(module); auto pass_result = @@ -4620,8 +4128,7 @@ StatusOr AutoSharding::Run( delete pass; if (!pass_result.ok()) { VLOG(1) << "Mesh shape " << spmd::ToString(mesh_shapes[i]) - << " did work lead to an auto-sharding solution due to the " - "following error: " + << " led to the following error: " << pass_result.status().message(); continue; } @@ -4633,30 +4140,29 @@ StatusOr AutoSharding::Run( } if (pass_result.ok() && pass_result.value() != - AutoShardingResult::kModuleUnchangedNoShardingPerfomed) { + AutoShardingResult::kModuleUnchangedNoShardingPerformed) { skip_auto_sharding = false; } } - StatusOr module_is_changed; + absl::StatusOr module_is_changed; if (skip_auto_sharding) { - VLOG(1) << "Solver timed out. Will now rely on sharding propagation to " - "perform sharding."; - if (!ModuleHasUserShardings(module)) { - LOG(WARNING) - << "The auto-sharding solver has timed out without a solution. " - "Further, as the input module does not contain any sharding " - "annotations, we cannot rely on sharding propagation to perform " - "heuristic-guided sharding. The module therefore may not be " - "sharded leading to low performance."; - } + LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; module_is_changed = false; } else { + std::string trying_to_find; + if (option_.try_multiple_mesh_shapes) { + trying_to_find = "a device mesh (and the corresponding shardings)"; + } else { + trying_to_find = "shardings"; + } CHECK_GE(min_mesh_shape_index, 0) - << "The auto-sharding pass could not find a device mesh that works for " - "this input. This could be the result of a low memory budget. If " - "you think you have set a reasonably large memory budget, please " - "report this as a bug."; + << "The auto-sharding pass could not find " << trying_to_find + << " that works for this input. This could be the result of a low " + "memory budget (please refer to the " + "`--xla_tpu_auto_spmd_partitioning_memory_budget_ratio` flag to set " + "a higher budget). If you think you have set a reasonably large " + "memory budget, please report this as a bug."; if (!changed[min_mesh_shape_index].ok()) { module_is_changed = changed[min_mesh_shape_index].status(); @@ -4669,12 +4175,25 @@ StatusOr AutoSharding::Run( << " which had the minimal solver objective value of " << min_objective_value; chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; + TF_RETURN_IF_ERROR( + modules[min_mesh_shape_index]->RemoveUnusedComputations()); + const std::vector& original_module_computations = + module->MakeComputationSorted(); + const std::vector& clone_module_computations = + modules[min_mesh_shape_index]->MakeComputationSorted(); + if (original_module_computations.size() != + clone_module_computations.size()) { + return absl::InternalError( + "The cloned and the original modules do not have the same number " + "of computations. This is a bug and should be reported."); + } + absl::flat_hash_map computation_replacements; - for (size_t i = 0; i < module->computation_count(); ++i) { - auto original_computation = module->mutable_computation(i); - auto new_computation = - modules[min_mesh_shape_index]->mutable_computation(i); + for (size_t i = 0; i < original_module_computations.size(); ++i) { + HloComputation* original_computation = + original_module_computations[i]; + HloComputation* new_computation = clone_module_computations[i]; computation_replacements[original_computation] = new_computation; } @@ -4683,6 +4202,10 @@ StatusOr AutoSharding::Run( *module->mutable_config().mutable_entry_computation_layout() = modules[min_mesh_shape_index]->entry_computation_layout(); + module->input_output_alias_config() = + modules[min_mesh_shape_index]->input_output_alias_config(); + module->buffer_donor_config() = + modules[min_mesh_shape_index]->buffer_donor_config(); module_is_changed = true; } else if (changed[min_mesh_shape_index].value() == @@ -4709,7 +4232,7 @@ StatusOr AutoSharding::Run( return module_is_changed; } -StatusOr DummyAutoSharding::Run( +absl::StatusOr DummyAutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // ----- Set Dummy Replicated Sharding ----- diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index 5af6ad3d35cc8..9a994bfce9a7e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,10 @@ limitations under the License. #include #include #include +#include #include +#include +#include #include #include "absl/container/flat_hash_map.h" @@ -37,9 +40,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" #include "xla/shape.h" -#include "xla/statusor.h" +#include "xla/status.h" namespace xla { @@ -50,7 +55,7 @@ class DummyAutoSharding : public HloModulePass { absl::string_view name() const override { return "dummy_auto_sharding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; @@ -58,7 +63,7 @@ class DummyAutoSharding : public HloModulePass { enum class AutoShardingResult { kModuleUnchanged, kModuleChangedShardingPerformed, - kModuleUnchangedNoShardingPerfomed + kModuleUnchangedNoShardingPerformed }; class AutoShardingImplementation { @@ -66,21 +71,21 @@ class AutoShardingImplementation { explicit AutoShardingImplementation(const AutoShardingOption& option); ~AutoShardingImplementation() = default; - // using HloPassInterface::Run; - StatusOr RunAutoSharding( + absl::StatusOr RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, const absl::flat_hash_map& sharding_propagation_solution = {}); - // Removes SPMD annotations (if there are) to test AutoSharding on manually - // annotated graphs. - StatusOr RemoveShardingAnnotation( + // Returns sharding annotations that need to be preserved in a map (for + // verification after auto-sharding is done), and removes any sharding + // anotations that need to be removed. + std::pair>, bool> + SaveAndRemoveShardingAnnotation( HloModule* module, - const absl::flat_hash_set& replicated_small_tensors = {}, - - const absl::flat_hash_set& execution_threads = {}); + const absl::flat_hash_set& replicated_small_tensors, + const absl::flat_hash_set& execution_threads); // Canonicalizes entry_computation_layouts by calling // module.layout_canonicalization_callback(), which gives canonicalized @@ -90,7 +95,7 @@ class AutoShardingImplementation { // tensorflow/compiler/xla/pjrt/utils.cc Status CanonicalizeLayouts(HloModule* module); - // Returns the optimal objective value that the ILP solver computes + // Returns the optimal objective value that the ILP solver computes. double GetSolverOptimalObjectiveValue() { return solver_optimal_objective_value_; } @@ -99,7 +104,7 @@ class AutoShardingImplementation { AutoShardingOption option_; // Stores the optimal value of the objective the solver found. This is used to - // chose the best mesh shape when the try_multiple_mesh_shapes option is on. + // choose the best mesh shape when the try_multiple_mesh_shapes option is on. double solver_optimal_objective_value_ = -1.0; }; @@ -110,7 +115,7 @@ class AutoSharding : public HloModulePass { absl::string_view name() const override { return "auto_sharding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -129,16 +134,21 @@ class AutoSharding : public HloModulePass { }; namespace spmd { -// Function declarations +// Function declarations. // Their comments can be found in their definitions in *.cc files. HloSharding Tile(const Shape& shape, absl::Span tensor_dims, absl::Span mesh_dims, const Array& device_mesh); -std::vector ReshardingCostVector(const StrategyGroup* strategy_group, - const Shape& shape, - const HloSharding& required_sharding, - const ClusterEnvironment& cluster_env); +std::vector CommunicationReshardingCostVector( + const StrategyGroup* strategy_group, const Shape& shape, + const HloSharding& required_sharding, + const ClusterEnvironment& cluster_env); + +std::vector MemoryReshardingCostVector( + const StrategyGroup* strategy_group, const Shape& operand_shape, + const HloSharding& required_sharding, + const ClusterEnvironment& cluster_env); std::vector FollowInsCostVector(int64_t source_len, int64_t index); @@ -161,6 +171,8 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, Status HandleDot(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -168,6 +180,8 @@ Status HandleDot(std::unique_ptr& strategy_group, Status HandleConv(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -191,11 +205,12 @@ AliasMap BuildAliasMap(const HloModule* module); AliasSet BuildAliasSet(const HloModule* module, const StrategyMap& strategy_map); -void CheckAliasSetCompatibility(const AliasSet& alias_set, - const StrategyGroups& strategy_groups, - const HloInstructionSequence& sequence); +Status CheckAliasSetCompatibility(const AliasSet& alias_set, + const StrategyGroups& strategy_groups, + const HloInstructionSequence& sequence, + bool crash_on_error); -void GenerateReduceScatter( +absl::Status GenerateReduceScatter( const HloInstructionSequence& sequence, const AliasMap& alias_map, const InstructionDepthMap& depth_map, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val, @@ -213,9 +228,11 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, // The high-level "recipe" for solving an Auto Sharding problem. AutoShardingSolverResult Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, + const LivenessNodeSet& liveness_node_set, + const LivenessEdgeSet& liveness_edge_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, + absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution = {}); @@ -223,6 +240,160 @@ AutoShardingSolverResult Solve( void PopulateTemporalValues(const CostGraph& cost_graph, AutoShardingSolverRequest& request); +void AddReplicatedStrategy( + const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, double replicated_penalty, + absl::flat_hash_set operands_to_consider_all_strategies_for = {}); + +void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape); + +// Choose an operand to follow. We choose to follow the operand with the highest +// priority. +std::pair ChooseOperandToFollow( + const StrategyMap& strategy_map, const InstructionDepthMap& depth_map, + const AliasMap& alias_map, int64_t max_depth, const HloInstruction* ins); + +void FillAllStrategiesForArray( + std::unique_ptr& strategy_group, const HloInstruction* ins, + const Shape& shape, const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, const AutoShardingOption& option, + double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, + const CallGraph& call_graph, bool only_allow_divisible, + bool create_replicated_strategies, + bool create_partially_replicated_strategies); + +absl::StatusOr> CreateAllStrategiesGroup( + const HloInstruction* ins, const Shape& shape, size_t instruction_id, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, const AutoShardingOption& option, + double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, + const CallGraph& call_graph, bool only_allow_divisible, + bool create_replicated_strategies, + bool create_partially_replicated_strategies); + +// Enumerates sharding strategies for elementwise operators by following +// strategies of an operand of the elementwise op. +std::unique_ptr CreateElementwiseOperatorStrategies( + size_t instruction_id, const HloInstruction* ins, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, + const StableHashMap>& + pretrimmed_strategy_map, + int64_t max_depth, StrategyGroups& strategy_groups, + AssociativeDotPairs& associative_dot_pairs); + +// Factory functions for StrategyGroup. +std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( + size_t instruction_id, StrategyGroups& strategy_groups); + +// Enumerates sharding strategies for reshape operators. The function does so by +// essentially reshaping the sharding of the operand in a manner similar to the +// tensor reshape itself. +std::unique_ptr CreateReshapeStrategies( + size_t instruction_id, const HloInstruction* ins, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + bool only_allow_divisible, double replicated_penalty, + const InstructionBatchDimMap& batch_dim_map, + const AutoShardingOption& option, StrategyGroups& strategy_groups, + const CallGraph& call_graph); + +std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id); + +// Enumerate all 1d partition strategies. +void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, + const Array& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + bool only_allow_divisible, + const std::string& suffix, + const CallGraph& call_graph); + +// Enumerate all partitions recursively. +void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, + const Array& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + const InstructionBatchDimMap& batch_dim_map, + bool only_allow_divisible, + const CallGraph& call_graph, + int64_t partition_dimensions, + const std::vector& tensor_dims = {}); + +absl::StatusOr> FollowReduceStrategy( + const HloInstruction* ins, const Shape& output_shape, + const HloInstruction* operand, const HloInstruction* unit, + size_t instruction_id, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + bool allow_mixed_mesh_shape, bool crash_at_error); + +void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + double replicated_penalty); + +std::pair +GenerateReshardingCostsAndMissingShardingsForAllOperands( + const HloInstruction* ins, const HloSharding& output_sharding, + const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, + const CallGraph& call_graph, + std::vector>& input_shardings); + +bool LeafVectorsAreConsistent(const std::vector& one, + const std::vector& two, + bool is_reshape); + +std::unique_ptr MaybeFollowInsStrategyGroup( + const StrategyGroup* src_strategy_group, const Shape& shape, + size_t instruction_id, bool have_memory_cost, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, + const StableHashMap>& + pretrimmed_strategy_map); + +void RemoveInvalidShardingsWithShapes(const Shape& shape, + StrategyGroup* strategy_group, + bool instruction_has_user_sharding); + +void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, + int64_t execution_count); + +// Existing shardings refer to the HloSharding field in the given +// HloInstruction. +void TrimOrGenerateStrategiesBasedOnExistingSharding( + const Shape& output_shape, StrategyGroup* strategy_group, + const StrategyMap& strategy_map, + const std::vector& instructions, + const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, + StableHashMap>& + pretrimmed_strategy_map, + const CallGraph& call_graph, bool strict); + +// Build possible sharding strategies and their costs for all instructions. +absl::StatusOr> +BuildStrategyAndCost(const HloInstructionSequence& sequence, + const HloModule* module, + const absl::flat_hash_map& + instruction_execution_counts, + const InstructionDepthMap& depth_map, + const InstructionBatchDimMap& batch_dim_map, + const AliasMap& alias_map, + const ClusterEnvironment& cluster_env, + AutoShardingOption& option, const CallGraph& call_graph, + const HloCostAnalysis& hlo_cost_analysis, + bool trying_multiple_mesh_shapes); + +// Computes an approximate lower bound on the per-device memory usage of a +// module once it has been sharded. This quantity is multiplied with +// memory_budget_ratio to obtain the memory budget using in our ILP formulation. +int64_t MemoryBudgetLowerBound( + const HloModule& module, const LivenessSet& liveness_set, + const HloAliasAnalysis& alias_analysis, int64_t num_devices, + const absl::flat_hash_map>& + preserved_shardings); + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.proto b/xla/hlo/experimental/auto_sharding/auto_sharding.proto index 095a3fce35e3e..102b3e86e35a0 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.proto +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,12 @@ message AutoShardingSolverRequest { message Nodes { repeated int64 nodes = 1; } + message Edges { + repeated int64 edges = 1; + } + message Names { + repeated string names = 1; + } message SolverTimeout { int64 solver_timeout_in_seconds = 1; } @@ -40,23 +46,34 @@ message AutoShardingSolverRequest { repeated int64 s_len = 3; repeated int64 s_follow = 4; repeated int64 s_hint = 5; + repeated int64 peak_times = 35; repeated Pair edges = 6; repeated Nodes live = 7; + repeated Edges live_edges = 28; repeated Costs computation_costs = 8; repeated Costs communication_costs = 9; repeated Costs memory_costs = 10; + repeated Costs memory_edge_costs = 29; repeated Costs departure_costs = 11; repeated Costs resharding_costs = 12; repeated Costs duration_costs = 13; repeated Pair aliases = 14; repeated Costs value_costs = 15; repeated string instruction_names = 16; + repeated string opcodes = 33; + repeated Names strategy_names = 32; optional SolverTimeout solver_timeout = 17; optional Coeff overbudget_coeff = 18; optional Coeff makespan_coeff = 19; optional Coeff max_departures = 20; + optional Coeff max_cost = 25; + optional Coeff coeff_limit = 26; bool crash_at_infinity_costs_check = 21; bool compute_iis = 22; double saltiplier = 23; + bool deterministic_mode = 27; string module_name = 24; + string request_name = 30; + bool enable_output = 31; + bool enable_memory_edge_costs = 34; } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc new file mode 100644 index 0000000000000..1156e0b80c302 --- /dev/null +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -0,0 +1,345 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "xla/hlo/experimental/auto_sharding/matrix.h" + +namespace xla { +namespace spmd { + +CostGraph::CostGraph(const StrategyGroups& strategy_groups, + const AssociativeDotPairs& associative_dot_pairs) { + node_lens_.reserve(strategy_groups.size()); + extra_node_costs_.reserve(strategy_groups.size()); + adjacency_.assign(strategy_groups.size(), StableHashSet()); + + // Build the cost graph. + for (StrategyGroup* strategy_group : strategy_groups) { + node_lens_.push_back(strategy_group->strategies.size()); + extra_node_costs_.push_back( + std::vector(strategy_group->strategies.size(), 0.0)); + + const auto& in_nodes = strategy_group->in_nodes; + for (size_t i = 0; i < in_nodes.size(); ++i) { + if (!in_nodes[i]->is_tuple) { + NodeIdx src_idx = in_nodes[i]->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; + EdgeReshardingCostMatrix edge_cost = + CreateEdgeCost(src_idx, dst_idx, i, strategy_group); + AddEdgeCost(src_idx, dst_idx, edge_cost); + } else if (in_nodes[i]->is_tuple && in_nodes.size() > 1) { + for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) { + NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; + EdgeReshardingCostMatrix edge_cost = + CreateEdgeCost(src_idx, dst_idx, i, strategy_group, true); + AddEdgeCost(src_idx, dst_idx, edge_cost); + } + } else { + CHECK_EQ(in_nodes.size(), 1) + << "Do not support instructions with more than one tuple " + "operand. If this CHECK fails, we will need to fix " + "b/233412625."; + for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) { + NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx; + NodeIdx dst_idx = strategy_group->node_idx; + // TODO(b/233412625) Support more general case, e.g., multiple tuple + // operands. If there is only one operand and it's a tuple, the + // first index of communication_resharding_costs is for the tuple + // element. + EdgeReshardingCostMatrix edge_cost = CreateEdgeCost( + src_idx, dst_idx, /*in_node_idx=*/l, strategy_group); + AddEdgeCost(src_idx, dst_idx, edge_cost); + } + } + } + + if (strategy_group->following) { + if (strategy_group->strategies.size() == + strategy_group->following->strategies.size()) { + to_merge_pairs_.push_back( + {strategy_group->node_idx, strategy_group->following->node_idx}); + } else { + LOG(WARNING) << "Different strategy counts for instruction ID " + << strategy_group->instruction_id + << " and following instruction ID " + << strategy_group->following->instruction_id; + } + } + } + + // Adjust the edge costs for dot pairs that can be optimized by + // AllReduceReassociate. + for (const auto& pair : associative_dot_pairs) { + NodeIdx src_idx = pair.first->node_idx; + NodeIdx dst_idx = pair.second->node_idx; + + EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], + node_lens_[dst_idx]); + absl::flat_hash_map + src_strategy_name_to_idx_map; + for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { + const ShardingStrategy& strategy = + strategy_groups[src_idx]->strategies[i]; + if (strategy.communication_cost > 0) { + src_strategy_name_to_idx_map[strategy.name] = i; + } + } + for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) { + const ShardingStrategy& dst_strategy = + strategy_groups[dst_idx]->strategies[i]; + if (dst_strategy.communication_cost > 0) { + auto it = src_strategy_name_to_idx_map.find(dst_strategy.name); + if (it != src_strategy_name_to_idx_map.end()) { + const ShardingStrategy& src_strategy = + strategy_groups[src_idx]->strategies[it->second]; + CHECK_LE(std::abs(src_strategy.communication_cost - + dst_strategy.communication_cost), + 1e-6); + edge_cost(it->second, i).communication_cost = + -src_strategy.communication_cost; + } + } + } + AddEdgeCost(src_idx, dst_idx, edge_cost); + } +} + +EdgeReshardingCostMatrix CostGraph::CreateEdgeCost( + const NodeIdx src_idx, const NodeIdx dst_idx, const size_t in_node_idx, + StrategyGroup* strategy_group, const bool zero_cost) { + CHECK_LT(src_idx, node_lens_.size()); + CHECK_LT(dst_idx, node_lens_.size()); + EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); + for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { + const ShardingStrategy& strategy = strategy_group->strategies[k]; + size_t start_idx = 0; + CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size()) + << strategy_group->node_idx; + if (strategy.memory_resharding_costs[in_node_idx].size() > + node_lens_[src_idx]) { + start_idx = strategy.memory_resharding_costs[in_node_idx].size() - + node_lens_[src_idx]; + } + for (size_t j = start_idx; + j < strategy.memory_resharding_costs[in_node_idx].size(); ++j) { + double communication_cost = 0; + double memory_cost = 0; + if (!zero_cost) { + communication_cost = + strategy.communication_resharding_costs[in_node_idx][j]; + memory_cost = strategy.memory_resharding_costs[in_node_idx][j]; + } + edge_cost(j - start_idx, k) = + EdgeReshardingCost(communication_cost, memory_cost); + } + } + return edge_cost; +} + +EdgeReshardingCostMatrix CostGraph::GetEdgeCost(const NodeIdx i, + const NodeIdx j) { + if (i <= j) { + return edge_costs_[{i, j}]; + } + return edge_costs_[{j, i}].Transpose(); +} + +void CostGraph::AddEdgeCost(NodeIdx i, NodeIdx j, + EdgeReshardingCostMatrix& cost) { + if (i > j) { + std::swap(i, j); + cost = cost.Transpose(); + } + + if (edge_costs_.contains({i, j})) { + CHECK(adjacency_[i].contains(j)); + CHECK(adjacency_[j].contains(i)); + edge_costs_[{i, j}] = edge_costs_[{i, j}] + cost; + } else { + adjacency_[i].insert(j); + adjacency_[j].insert(i); + edge_costs_[{i, j}] = cost; + } +} + +void CostGraph::RemoveEdge(NodeIdx i, NodeIdx j) { + if (i > j) { + std::swap(i, j); + } + + CHECK(adjacency_[i].contains(j)); + CHECK(adjacency_[j].contains(i)); + CHECK(edge_costs_.contains({i, j})); + + adjacency_[i].erase(j); + adjacency_[j].erase(i); + edge_costs_.erase({i, j}); +} + +void CostGraph::MergeNode(const NodeIdx src, const NodeIdx dst) { + CHECK(adjacency_[src].contains(dst)); + CHECK(adjacency_[dst].contains(src)); + CHECK(!merged_to_.contains(src)); + CHECK(!merged_to_.contains(dst)); + CHECK_NE(src, dst); + + EdgeReshardingCostMatrix edge_cost = GetEdgeCost(dst, src); + + std::vector reindexing(node_lens_[dst]); + if (node_lens_[dst] == node_lens_[src]) { + // Assume the orders of strategies in src and dst match + // (i.e., i-th strategy in src follows i-th strategy in dst). + // This is true in most cases because of how we create the + // following strategies. + std::iota(reindexing.begin(), reindexing.end(), 0); + } else { + // Otherwise, find the strategy to follow greedily. + // For every strategy in dst, find the strategy in src with + // the lowest resharding cost. + std::vector arange(node_lens_[src]); + std::iota(arange.begin(), arange.end(), 0); + for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { + std::vector> keys; + + // If there are multiple strategies with the same lowest costs, + // prefer to follow "replicated", which has the largest index. + // Node: We assume the strategy "Repilcated" is always appended + // as the last strategy in BuildStrategyAndCost. + keys.reserve(node_lens_[src]); + for (NodeStrategyIdx j = 0; j < node_lens_[src]; ++j) { + keys.push_back({edge_cost(i, j).communication_cost, -j}); + } + + std::sort(arange.begin(), arange.end(), [&keys](int l, int r) { + return (keys[l].first < keys[r].first) || + (keys[l].first == keys[r].first && + keys[l].second < keys[r].second); + }); + reindexing[i] = arange.front(); + } + } + merged_to_[src] = dst; + reindexing_vector_[src] = reindexing; + + // Merge edge-cost matrix. + std::vector adj_list(adjacency_[src].begin(), adjacency_[src].end()); + for (const NodeIdx adj : adj_list) { + if (adj == dst) { + for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { + extra_node_costs_[dst][i] += + edge_cost(i, reindexing[i]).communication_cost; + } + } else { + EdgeReshardingCostMatrix added_edge_cost(node_lens_[dst], + node_lens_[adj]); + EdgeReshardingCostMatrix edge_cost_src_adj = GetEdgeCost(src, adj); + for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { + for (NodeStrategyIdx k = 0; k < node_lens_[adj]; ++k) { + added_edge_cost(i, k) = edge_cost_src_adj(reindexing[i], k); + } + } + AddEdgeCost(dst, adj, added_edge_cost); + } + } + // Remove edges + for (const NodeIdx adj : adj_list) { + RemoveEdge(src, adj); + } +} + +NodeIdx CostGraph::QueryDestination(const NodeIdx node_idx) { + if (merged_to_.contains(node_idx)) { + NodeIdx old_dst = merged_to_[node_idx]; + NodeIdx new_dst = QueryDestination(old_dst); + if (old_dst != new_dst) { + // Compress path. + absl::Span old_reindexing_vector = + reindexing_vector_[node_idx]; + std::vector new_reindexing_vector; + new_reindexing_vector.reserve(node_lens_.size()); + for (NodeStrategyIdx i = 0; i < node_lens_[new_dst]; ++i) { + new_reindexing_vector.push_back( + old_reindexing_vector[reindexing_vector_[old_dst][i]]); + } + reindexing_vector_[node_idx] = new_reindexing_vector; + merged_to_[node_idx] = new_dst; + } + return new_dst; + } + return node_idx; +} + +void CostGraph::Simplify(const bool enable) { + // Merge nodes. + if (enable) { + for (const auto& [src, dst] : to_merge_pairs_) { + MergeNode(src, QueryDestination(dst)); + } + } + // Build follow map. + follow_idx_.reserve(node_lens_.size()); + for (NodeIdx i = 0; i < node_lens_.size(); ++i) { + if (merged_to_.contains(i)) { + follow_idx_.push_back(QueryDestination(i)); + } else { + follow_idx_.push_back(-1); + } + } +} + +NodeStrategyIdx CostGraph::RemapIndex(const NodeIdx node_id, + const NodeStrategyIdx value) const { + if (follow_idx_[node_id] < 0) { + return value; + } + return reindexing_vector_.at(node_id)[value]; +} + +std::string CostGraph::ToString() const { + std::string str; + absl::StrAppend(&str, "Cost Graph:\n"); + + for (NodeIdx i = 0; i < node_lens_.size(); ++i) { + absl::StrAppend(&str, "Node", i, ": ", node_lens_[i], "\n"); + } + absl::StrAppend(&str, "\n"); + + for (const auto& iter : edge_costs_) { + absl::StrAppend(&str, "Edge (", iter.first.first, ", ", iter.first.second, + "):\n"); + absl::StrAppend(&str, iter.second.ToString(), "\n"); + } + + return str; +} + +} // namespace spmd +} // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index fd62735508968..08b0bd968b6d4 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,8 @@ limitations under the License. #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_ -#include #include -#include +#include #include #include #include @@ -30,310 +29,77 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/matrix.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" + namespace xla { namespace spmd { -// A graph data structrue to simplify the edge cost graph. -// It merges nodes and does path compression. -class CostGraph { - public: - CostGraph(const StrategyGroups& strategy_groups, - const AssociativeDotPairs& associative_dot_pairs) { - node_lens_.reserve(strategy_groups.size()); - extra_node_costs_.reserve(strategy_groups.size()); - adjacency_.assign(strategy_groups.size(), StableHashSet()); - - // Build the cost graph - for (const auto& strategies : strategy_groups) { - node_lens_.push_back(strategies->strategies.size()); - extra_node_costs_.push_back( - std::vector(strategies->strategies.size(), 0.0)); - - for (size_t i = 0; i < strategies->in_nodes.size(); ++i) { - if (!strategies->in_nodes[i]->is_tuple) { - NodeIdx src_idx = strategies->in_nodes[i]->node_idx; - NodeIdx dst_idx = strategies->node_idx; - Matrix edge_cost = CreateEdgeCost(src_idx, dst_idx, i, strategies); - AddEdgeCost(src_idx, dst_idx, edge_cost); - } else if (strategies->in_nodes[i]->is_tuple && - strategies->in_nodes.size() > 1) { - for (size_t l = 0; l < strategies->in_nodes[i]->childs.size(); l++) { - NodeIdx src_idx = strategies->in_nodes[i]->childs.at(l)->node_idx; - NodeIdx dst_idx = strategies->node_idx; - Matrix edge_cost = - CreateEdgeCost(src_idx, dst_idx, i, strategies, true); - AddEdgeCost(src_idx, dst_idx, edge_cost); - } - - } else { - CHECK_EQ(strategies->in_nodes.size(), 1) - << "Do not support instructions with more than one tuple " - "operand. If this CHECK fails, we will need to fix " - "b/233412625."; - for (size_t l = 0; l < strategies->in_nodes[i]->childs.size(); l++) { - NodeIdx src_idx = strategies->in_nodes[i]->childs.at(l)->node_idx; - NodeIdx dst_idx = strategies->node_idx; - // TODO(b/233412625) Support more general case, e.g., multiple tuple - // operands. If there is only one operand and it's a tuple, the - // first index of resharding_costs is for the tuple element. - Matrix edge_cost = - CreateEdgeCost(src_idx, dst_idx, /*in_node_idx=*/l, strategies); - AddEdgeCost(src_idx, dst_idx, edge_cost); - } - } - } - - if (strategies->following) { - to_merge_pairs_.push_back( - {strategies->node_idx, strategies->following->node_idx}); - } - } - - // Adjust the edge costs for dot pairs that can be optimized by - // AllReduceReassociate - for (const auto& pair : associative_dot_pairs) { - NodeIdx src_idx = pair.first->node_idx; - NodeIdx dst_idx = pair.second->node_idx; +struct EdgeReshardingCost { + double communication_cost = 0; + double memory_cost = 0; - if (node_lens_[src_idx] != node_lens_[dst_idx]) { - continue; - } + EdgeReshardingCost() : communication_cost(0), memory_cost(0) {} - Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { - if (strategy_groups[src_idx]->strategies[i].communication_cost > 0) { - CHECK_LE( - std::abs( - strategy_groups[src_idx]->strategies[i].communication_cost - - strategy_groups[dst_idx]->strategies[i].communication_cost), - 1e-6); - edge_cost(i, i) = - -strategy_groups[src_idx]->strategies[i].communication_cost; - } - } - AddEdgeCost(src_idx, dst_idx, edge_cost); - } - } - - Matrix CreateEdgeCost(NodeIdx src_idx, NodeIdx dst_idx, size_t in_node_idx, - StrategyGroup* strategy_group, bool zero_cost = false) { - CHECK_LT(src_idx, node_lens_.size()); - CHECK_LT(dst_idx, node_lens_.size()); - Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { - const ShardingStrategy& strategy = strategy_group->strategies[k]; - size_t start_idx = 0; - if (strategy.resharding_costs[in_node_idx].size() > node_lens_[src_idx]) { - start_idx = - strategy.resharding_costs[in_node_idx].size() - node_lens_[src_idx]; - } - for (size_t j = start_idx; - j < strategy.resharding_costs[in_node_idx].size(); ++j) { - edge_cost(j - start_idx, k) = - zero_cost ? 0 : strategy.resharding_costs[in_node_idx][j]; - } - } + EdgeReshardingCost(double communication_cost_, double memory_cost_) + : communication_cost(communication_cost_), memory_cost(memory_cost_) {} - return edge_cost; + EdgeReshardingCost operator+(const EdgeReshardingCost& other) const { + return EdgeReshardingCost(other.communication_cost + communication_cost, + other.memory_cost + memory_cost); } - Matrix GetEdgeCost(NodeIdx i, NodeIdx j) { - if (i <= j) { - return edge_costs_[{i, j}]; - } - return edge_costs_[{j, i}].Transpose(); - } - - void AddEdgeCost(NodeIdx i, NodeIdx j, Matrix& cost) { - if (i > j) { - std::swap(i, j); - cost = cost.Transpose(); - } - - if (edge_costs_.contains({i, j})) { - CHECK(adjacency_[i].contains(j)); - CHECK(adjacency_[j].contains(i)); - edge_costs_[{i, j}] = edge_costs_[{i, j}] + cost; - } else { - adjacency_[i].insert(j); - adjacency_[j].insert(i); - edge_costs_[{i, j}] = cost; - } - } - - void RemoveEdge(NodeIdx i, NodeIdx j) { - if (i > j) { - std::swap(i, j); - } - - CHECK(adjacency_[i].contains(j)); - CHECK(adjacency_[j].contains(i)); - CHECK(edge_costs_.contains({i, j})); - - adjacency_[i].erase(j); - adjacency_[j].erase(i); - edge_costs_.erase({i, j}); + std::string ToString() const { + return absl::StrCat("{communication_cost=", communication_cost, + ", memory_cost=", memory_cost, "}"); } +}; - void MergeNode(NodeIdx src, NodeIdx dst) { - // Merge node src into node dst. This is used when we set one operator to - // follow another operator's sharding spec. For the following computation - // graph: - // dst -- src -- adj1 - // | - // adj2 - // - // It will be transformed into the following graph: - // (src) - // dst -- adj1 - // | - // adj2 - // Where all the edges costs between src and adjs will be added into - // the edge costs between dst and adjs. The edge cost between src and - // dst will be added to the extra node cost of dst. Other node costs of - // src will be added into dst's node cost in the ILP. - - CHECK(adjacency_[src].contains(dst)); - CHECK(adjacency_[dst].contains(src)); - CHECK(!merged_to_.contains(src)); - CHECK(!merged_to_.contains(dst)); - CHECK_NE(src, dst); - - Matrix edge_cost = GetEdgeCost(dst, src); - - std::vector reindexing(node_lens_[dst]); - if (node_lens_[dst] == node_lens_[src]) { - // Assume the orders of strategies in src and dst match - // (i.e. i-th strategy in src follows i-th strategy in dst). - // This is true in most cases because of how we create the - // following strategies. - std::iota(reindexing.begin(), reindexing.end(), 0); - } else { - // Otherwise, find the strategy to follow greedily. - // For every strategy in dst, find the strategy in src with - // the lowest resharding cost. - std::vector arange(node_lens_[src]); - std::iota(arange.begin(), arange.end(), 0); - for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { - std::vector> keys; - - // If there are multiple strategies with the same lowest costs, - // prefer to follow "replicated", which has the largest index. - // Node: We assume the strategy "Repilcated" is always appended - // as the last strategy in BuildStrategyAndCost. - keys.reserve(node_lens_[src]); - for (NodeStrategyIdx j = 0; j < node_lens_[src]; ++j) { - keys.push_back({edge_cost(i, j), -j}); - } - - std::sort(arange.begin(), arange.end(), [&keys](int l, int r) { - return (keys[l].first < keys[r].first) || - (keys[l].first == keys[r].first && - keys[l].second < keys[r].second); - }); - - reindexing[i] = arange.front(); - } - } - merged_to_[src] = dst; - reindexing_vector_[src] = reindexing; - - // Merge edge cost matrix - std::vector adj_list(adjacency_[src].begin(), - adjacency_[src].end()); - for (NodeIdx adj : adj_list) { - if (adj == dst) { - for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { - extra_node_costs_[dst][i] += edge_cost(i, reindexing[i]); - } - } else { - Matrix added_edge_cost(node_lens_[dst], node_lens_[adj]); - Matrix edge_cost_src_adj = GetEdgeCost(src, adj); +using EdgeReshardingCostMatrix = Matrix; - for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { - for (NodeStrategyIdx k = 0; k < node_lens_[adj]; ++k) { - added_edge_cost(i, k) = edge_cost_src_adj(reindexing[i], k); - } - } +// A graph data structure to simplify the edge cost graph. It merges nodes and +// performs path compression. +class CostGraph { + public: + CostGraph(const StrategyGroups& strategy_groups, + const AssociativeDotPairs& associative_dot_pairs); - AddEdgeCost(dst, adj, added_edge_cost); - } - } + EdgeReshardingCostMatrix CreateEdgeCost(NodeIdx src_idx, NodeIdx dst_idx, + size_t in_node_idx, + StrategyGroup* strategy_group, + bool zero_cost = false); - // Remove edges - for (NodeIdx adj : adj_list) { - RemoveEdge(src, adj); - } - } + EdgeReshardingCostMatrix GetEdgeCost(NodeIdx i, NodeIdx j); - NodeIdx QueryDestination(NodeIdx node_idx) { - if (merged_to_.contains(node_idx)) { - NodeIdx old_dst = merged_to_[node_idx]; - NodeIdx new_dst = QueryDestination(old_dst); - if (old_dst != new_dst) { - // Compresss path - absl::Span old_reindexing_vector = - reindexing_vector_[node_idx]; - std::vector new_reindexing_vector; - new_reindexing_vector.reserve(node_lens_.size()); - for (NodeStrategyIdx i = 0; i < node_lens_[new_dst]; ++i) { - new_reindexing_vector.push_back( - old_reindexing_vector[reindexing_vector_[old_dst][i]]); - } - reindexing_vector_[node_idx] = new_reindexing_vector; - merged_to_[node_idx] = new_dst; - } - return new_dst; - } - return node_idx; - } + void AddEdgeCost(NodeIdx i, NodeIdx j, EdgeReshardingCostMatrix& cost); - void Simplify(bool enable) { - // Merge nodes - for (const auto& pair : to_merge_pairs_) { - NodeIdx src = pair.first; - NodeIdx dst = pair.second; - dst = QueryDestination(dst); - if (enable) { - MergeNode(src, dst); - } - } + void RemoveEdge(NodeIdx i, NodeIdx j); - // Build follow map - follow_idx_.reserve(node_lens_.size()); - for (NodeIdx i = 0; i < node_lens_.size(); ++i) { - if (merged_to_.contains(i)) { - follow_idx_.push_back(QueryDestination(i)); - } else { - follow_idx_.push_back(-1); - } - } - } + // Merge node src into node dst. This is used when we set one operator to + // follow another operator's sharding spec. For the following computation + // graph: + // dst -- src -- adj1 + // | + // adj2 + // + // It will be transformed into the following graph: + // (src) + // dst -- adj1 + // | + // adj2 + // Where all the edges costs between src and adjs will be added into the edge + // costs between dst and adjs. The edge cost between src and dst will be added + // to the extra node cost of dst. Other node costs of src will be added into + // dst's node cost in the ILP. + void MergeNode(NodeIdx src, NodeIdx dst); - NodeStrategyIdx RemapIndex(NodeIdx node_id, NodeStrategyIdx value) const { - if (follow_idx_[node_id] < 0) { - return value; - } - return reindexing_vector_.at(node_id)[value]; - } + NodeIdx QueryDestination(NodeIdx node_idx); - std::string ToString() { - std::string str; - absl::StrAppend(&str, "Cost Graph:\n"); + void Simplify(bool enable); - for (NodeIdx i = 0; i < node_lens_.size(); ++i) { - absl::StrAppend(&str, "Node", i, ": ", node_lens_[i], "\n"); - } - absl::StrAppend(&str, "\n"); + NodeStrategyIdx RemapIndex(NodeIdx node_id, NodeStrategyIdx value) const; - for (const auto& iter : edge_costs_) { - absl::StrAppend(&str, "Edge (", iter.first.first, ", ", iter.first.second, - "):\n"); - absl::StrAppend(&str, iter.second.ToString(), "\n"); - } + std::string ToString() const; - return str; - } + // TODO: Make class member variables private. // The number of strategies of each node. std::vector node_lens_; @@ -341,7 +107,8 @@ class CostGraph { std::vector> adjacency_; // The cost matrix between two nodes. - StableHashMap, Matrix> edge_costs_; + StableHashMap, EdgeReshardingCostMatrix> + edge_costs_; // The extra node costs introduced by merging nodes. std::vector> extra_node_costs_; // The reindexing vector of the node. @@ -358,7 +125,7 @@ class CostGraph { std::vector> to_merge_pairs_; }; -// Get the final sharding strategy according to the ilp solution. +// Get the final sharding strategy according to the ILP solution. inline const ShardingStrategy& GetShardingStrategy( const HloInstruction* inst, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val) { @@ -369,7 +136,7 @@ inline const ShardingStrategy& GetShardingStrategy( return strategy_group->strategies[stra_idx]; } -// Get the final sharding strategy according to the ilp solution. +// Get the final sharding strategy according to the ILP solution. inline const ShardingStrategy& GetShardingStrategyForTuple( const HloInstruction* inst, ShapeIndex index, const StrategyMap& strategy_map, const CostGraph& cost_graph, @@ -388,4 +155,5 @@ inline const ShardingStrategy& GetShardingStrategyForTuple( } // namespace spmd } // namespace xla + #endif // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_ diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index bde495a71447d..84c22df6fa074 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -1,4 +1,4 @@ -/*Copyright 2022 The TensorFlow Authors.All Rights Reserved. +/*Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0(the "License"); you may not use this file except in compliance with the License. @@ -33,14 +33,17 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/sharding_propagation.h" #include "xla/status.h" #include "tsl/platform/errors.h" @@ -63,12 +66,18 @@ class HandlerBase { protected: HandlerBase(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : strategy_group_(strategy_group), strategy_map_(strategy_map), ins_(ins), + instruction_id_(instruction_id), + instruction_sequence_(instruction_sequence), + hlo_cost_analysis_(hlo_cost_analysis), cluster_env_(cluster_env), batch_map_(batch_map), option_(option), @@ -83,18 +92,6 @@ class HandlerBase { absl::Span input_specs, double compute_cost, double communication_cost); - bool CheckDims(const HloInstruction* ins, const DimMap& dim_map) const { - for (const auto& [tensor_dim, mesh_dim] : dim_map) { - auto shape_dim = ins->shape().dimensions().at(tensor_dim); - auto device_mesh_dim = device_mesh_.dim(mesh_dim); - if (shape_dim < device_mesh_dim) return false; - if (option_.only_allow_divisible_intermediate && - !IsDivisible(shape_dim, device_mesh_dim)) - return false; - } - return true; - } - HloSharding CreateInputSpec(const HloInstruction* ins, const DimMap& dim_map, const Array& device_mesh) const { if (dim_map.empty()) return HloSharding::Replicate(); @@ -145,6 +142,9 @@ class HandlerBase { std::unique_ptr& strategy_group_; StrategyMap& strategy_map_; const HloInstruction* ins_; + const int64_t instruction_id_; + const HloInstructionSequence& instruction_sequence_; + const HloCostAnalysis& hlo_cost_analysis_; const ClusterEnvironment& cluster_env_; const InstructionBatchDimMap& batch_map_; const AutoShardingOption& option_; @@ -160,13 +160,18 @@ class DotHandler : public HandlerBase { public: DotHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloDotInstruction* ins, + int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); DotHandler( std::unique_ptr& strategy_group, StrategyMap& strategy_map, - const HloConvolutionInstruction* ins, + const HloConvolutionInstruction* ins, int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -216,6 +221,9 @@ class ConvHandler : public HandlerBase { public: ConvHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -246,13 +254,17 @@ void HandlerBase::AppendNewStrategy(const std::string& name, absl::Span input_specs, double compute_cost, double communication_cost) { - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; for (int i = 0; i < ins_->operand_count(); ++i) { const HloInstruction* operand = ins_->operand(i); - resharding_costs.push_back( - ReshardingCostVector(strategy_map_.at(operand).get(), operand->shape(), - input_specs[i], cluster_env_)); + communication_resharding_costs.push_back(CommunicationReshardingCostVector( + strategy_map_.at(operand).get(), operand->shape(), input_specs[i], + cluster_env_)); + memory_resharding_costs.push_back(MemoryReshardingCostVector( + strategy_map_.at(operand).get(), operand->shape(), input_specs[i], + cluster_env_)); } strategy_group_->strategies.push_back(ShardingStrategy({ @@ -261,7 +273,8 @@ void HandlerBase::AppendNewStrategy(const std::string& name, compute_cost, communication_cost, GetBytes(ins_->shape()) / output_spec.NumTiles(), - resharding_costs, + communication_resharding_costs, + memory_resharding_costs, {input_specs.begin(), input_specs.end()}, })); } @@ -283,9 +296,9 @@ void HandlerBase::MaybeAppend( communication_cost_fn) { HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh); HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh); - if (std::optional output_spec = - GetShardingFromUser(lhs_spec, rhs_spec); - output_spec.has_value()) { + std::optional output_spec = + GetShardingFromUser(lhs_spec, rhs_spec); + if (output_spec.has_value()) { if (expected_output_dim_map.has_value()) { HloSharding expected_output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); @@ -306,15 +319,21 @@ void HandlerBase::MaybeAppend( "mismatch, we continue with the expected sharding"; } } - double communication_cost = 0; - if (communication_cost_fn.has_value()) { - communication_cost = communication_cost_fn.value()(*output_spec); - } - AppendNewStrategy(name, *output_spec, {lhs_spec, rhs_spec}, compute_cost, - communication_cost); } else { - LOG(FATAL) << "Sharding propagation could not infer output sharding"; + CHECK(expected_output_dim_map.has_value()); + output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); + LOG(WARNING) + << "Sharding propagation could not infer output sharding for:\n " + << ins_->ToString() << "\n LHS Spec: " << lhs_spec + << "\n RHS Spec: " << rhs_spec << "\n Output sharding name: " << name; } + + double communication_cost = 0; + if (communication_cost_fn.has_value()) { + communication_cost = communication_cost_fn.value()(*output_spec); + } + AppendNewStrategy(name, *output_spec, {lhs_spec, rhs_spec}, compute_cost, + communication_cost); } std::optional HandlerBase::GetShardingFromUser( @@ -347,12 +366,16 @@ std::optional HandlerBase::GetShardingFromUser( DotHandler::DotHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloDotInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), is_dot_(true), space_base_dim_(ins->dot_dimension_numbers().lhs_batch_dimensions_size()), lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), @@ -367,13 +390,16 @@ DotHandler::DotHandler(std::unique_ptr& strategy_group, DotHandler::DotHandler( std::unique_ptr& strategy_group, StrategyMap& strategy_map, - const HloConvolutionInstruction* ins, + const HloConvolutionInstruction* ins, const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), is_dot_(false), space_base_dim_(-1) { CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); @@ -596,7 +622,8 @@ void DotHandler::SplitBatchDimBothContract() { absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; + const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[1]}, + {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; std::optional out_dim_map = std::nullopt; if (is_dot_) { out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; @@ -645,6 +672,9 @@ void DotHandler::RecomputeSplitBothContract() { if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || device_mesh_.dim(e.mesh_dims[1]) <= 1) return; + if (!option_.allow_recompute_heavy_op) { + return; + } std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", e.mesh_dims[0], e.mesh_dims[0]); const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; @@ -653,7 +683,10 @@ void DotHandler::RecomputeSplitBothContract() { if (is_dot_) { out_dim_map = DimMap{}; } - double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); + double compute_cost = GetDotConvReplicationPenalty( + ins_, instruction_id_, /* window */ 10, + instruction_sequence_, hlo_cost_analysis_) / + device_mesh_.dim(e.mesh_dims[0]); auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); @@ -831,12 +864,16 @@ Status DotHandler::RegisterStrategies() { ConvHandler::ConvHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, + const int64_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, - option, call_graph), + : HandlerBase(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, + batch_map, option, call_graph), conv_dnums_(ins->convolution_dimension_numbers()) { lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); @@ -1004,6 +1041,8 @@ void ConvHandler::SplitDepthwise(bool forward) { Status HandleDot(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -1012,6 +1051,7 @@ Status HandleDot(std::unique_ptr& strategy_group, strategy_groups); DotHandler handler(strategy_group, strategy_map, Cast(ins), + instruction_id, instruction_sequence, hlo_cost_analysis, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return OkStatus(); @@ -1021,6 +1061,8 @@ Status HandleDot(std::unique_ptr& strategy_group, Status HandleConv(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, + const HloInstructionSequence& instruction_sequence, + const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, @@ -1031,13 +1073,15 @@ Status HandleConv(std::unique_ptr& strategy_group, auto conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { - DotHandler handler(strategy_group, strategy_map, - Cast(ins), conv_as_dot_dims, - cluster_env, batch_map, option, call_graph); + DotHandler handler( + strategy_group, strategy_map, Cast(ins), + instruction_id, instruction_sequence, hlo_cost_analysis, + conv_as_dot_dims, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } else { - ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, + ConvHandler handler(strategy_group, strategy_map, ins, instruction_id, + instruction_sequence, hlo_cost_analysis, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 21281269eb535..1adcc3f6a05ca 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" @@ -32,15 +34,19 @@ namespace spmd { AutoShardingSolverResult Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, + const LivenessNodeSet& liveness_node_set, + const LivenessEdgeSet& liveness_edge_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, + absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver(hlo_module, hlo_live_range, liveness_node_set, strategy_map, - strategy_groups, cost_graph, alias_set, /*s_hint*/ {}, - /*compute_iis*/ true, option.solver_timeout_in_seconds, - option, sharding_propagation_solution); + return CallSolver( + hlo_module, hlo_live_range, liveness_node_set, liveness_edge_set, + strategy_map, strategy_groups, cost_graph, alias_set, /*s_hint*/ {}, + /*peak_times*/ {}, /*compute_iis*/ true, option.solver_timeout_in_seconds, + option, /*max_cost*/ std::nullopt, request_prefix, + sharding_propagation_solution, /*deterministic mode*/ true); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index 661535c9db027..e9648436047dd 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,8 +40,6 @@ std::string AutoShardingOption::ToString() const { absl::StrCat("memory_budget_per_device: ", memory_budget_per_device / (1024 * 1024 * 1024), " GB")); } - lines.push_back( - absl::StrCat("try_multiple_mesh_shapes: ", try_multiple_mesh_shapes)); lines.push_back(absl::StrCat("force_override_all_gather_cost: ", force_override_all_gather_cost)); @@ -85,7 +83,8 @@ std::string AutoShardingOption::ToString() const { absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape)); lines.push_back( absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches)); - lines.push_back(absl::StrCat("load_solution_vector: ", load_solution_vector)); + lines.push_back(absl::StrCat("solve_nd_sharding_iteratively: ", + solve_nd_sharding_iteratively)); lines.push_back( absl::StrCat("force_simple_heuristic: ", force_simple_heuristic)); lines.push_back(absl::StrCat("force_strategy: ", force_strategy)); @@ -108,9 +107,6 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("nd_sharding_iteratively_strict_search_space: ", nd_sharding_iteratively_strict_search_space)); - lines.push_back(absl::StrCat("allow_replicated_strategy_for_dot_and_conv: ", - allow_replicated_strategy_for_dot_and_conv)); - lines.push_back(absl::StrCat("device_mesh_shape: [", absl::StrJoin(device_mesh_shape, ","), "]")); lines.push_back(absl::StrCat("device_mesh_alpha: [", @@ -118,11 +114,27 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("device_mesh_beta: [", absl::StrJoin(device_mesh_beta, ","), "]")); - lines.push_back(absl::StrCat("load_strategy: ", load_strategy)); - if (load_strategy) { - lines.push_back(absl::StrCat("strategy_vector: [", - absl::StrJoin(strategy_vector, ","), "]")); - } + lines.push_back( + absl::StrCat("try_multiple_mesh_shapes: ", try_multiple_mesh_shapes)); + + lines.push_back( + absl::StrCat("solver_timeout_in_seconds: ", solver_timeout_in_seconds)); + + lines.push_back(absl::StrCat("loop_iteration_count_estimate: ", + loop_iteration_count_estimate)); + + lines.push_back(absl::StrCat("allow_alias_to_follower_conversion: ", + allow_alias_to_follower_conversion)); + + lines.push_back( + absl::StrCat("small_tensor_byte_size: ", small_tensor_byte_size)); + + lines.push_back( + absl::StrCat("use_sharding_propagation_for_default_shardings: ", + use_sharding_propagation_for_default_shardings)); + + lines.push_back(absl::StrCat("model_resharding_memory_costs: ", + model_resharding_memory_costs)); return absl::StrJoin(lines, "\n"); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 775c4a59bc301..a858ecfce1d4d 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -104,9 +104,12 @@ struct AutoShardingOption { // 2d mesh case. bool batch_matmul_always_split_batch = false; - // If true, allow strategies that recompute heavy operators (e.g., dot) - // to reduce communication. - bool allow_recompute_heavy_op = false; + // If true, allow strategies that recompute heavy operators (e.g., dot) to + // reduce communication. This will generate generate replicated or partially + // replicated strategies for dot/conv ops. Generating these seems to be + // beneficial for LLM serving models, but can increase the search space, so + // this feature is exposed as an option. + bool allow_recompute_heavy_op = true; // If true, allow adding 1d strategies in 2d logical mesh. bool allow_mixed_mesh_shape = false; @@ -116,9 +119,6 @@ struct AutoShardingOption { // is divided by this number. int grad_acc_num_micro_batches = 1; - // If true, load solution vector from PassContext - bool load_solution_vector = false; - // If true, N-D sharding (e.g., N maybe be 2 or 3) will be solved in N // iterations, where one iteration chooses one tensor dimension to shard. If // false, solve N-D sharding directly, i.e., generating all possible sharding @@ -146,11 +146,6 @@ struct AutoShardingOption { // space more scalable. Therefore leaving it as an option. bool nd_sharding_iteratively_strict_search_space = false; - // Whether or not to generate replicated strategies for dot/conv - // ops. Generating these seems to be beneficial for LLM serving models, but - // can increase the search space, so this feature is exposed as an option. - bool allow_replicated_strategy_for_dot_and_conv = true; - // Device mesh shape. std::vector device_mesh_shape; // Device IDs in the mesh. @@ -161,8 +156,7 @@ struct AutoShardingOption { // element models the communication performance along each mesh dimension. std::vector device_mesh_alpha; std::vector device_mesh_beta; - // Load the strategy vector instead of solving one. - bool load_strategy = false; + // Explore other mesh shapes with the same number of devices as the provided // one for a potentially better auto-sharding solution. bool try_multiple_mesh_shapes = false; @@ -172,7 +166,10 @@ struct AutoShardingOption { // sharding_propagation.cc. int64_t solver_timeout_in_seconds = 3600; - // Static estimate for iteration count of a while loop, used in the cost model + // Static estimate for iteration count of a while loop, used in the cost + // model. This estimate is used when we cannot infer an upper bound on the + // number of iterations in the loop (as implemented in + // third_party/tensorflow/compiler/xla/service/while_loop_analysis.h) int64_t loop_iteration_count_estimate = 100; // Allows the conversion of aliases to followers if their pairwise strategy @@ -180,7 +177,6 @@ struct AutoShardingOption { // smaller Mixed ILP). bool allow_alias_to_follower_conversion = true; - std::vector strategy_vector; // If greater than zero, tensors with size smaller than or equal to this limit // will always be replicated if they don't have a different user-specified // sharding. @@ -191,6 +187,10 @@ struct AutoShardingOption { // a simple replicated default. bool use_sharding_propagation_for_default_shardings = false; + // Whether or not to model the memory usage of intermediate tensors, if any, + // for resharding edges. + bool model_resharding_memory_costs = true; + // Prints a debug string. std::string ToString() const; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc index 59e729f06c8b8..b7a721780b998 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 8b73e6a43bc00..b880e5137f64d 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include #include +#include "absl/container/btree_set.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #ifdef PLATFORM_GOOGLE @@ -34,11 +36,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "xla/status.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/hash.h" @@ -50,14 +53,29 @@ limitations under the License. #include "util/task/status.pb.h" #endif -using MPConstraint = operations_research::MPConstraint; -using MPSolver = operations_research::MPSolver; -using MPSolverParameters = operations_research::MPSolverParameters; -using MPVariable = operations_research::MPVariable; - namespace xla { namespace spmd { +using ::operations_research::MPConstraint; +using ::operations_research::MPSolver; +using ::operations_research::MPVariable; + +// We need to nudge the maximum cost (if present) slightly, since the constraint +// solver cannot guarantee exact numerical precision. +constexpr double kMaxCostEpsilon = 1.0001; + +// In the Mixed ILP, we model all memory-related terms (i.e., coefficients, +// bounds, etc.) using smaller absolute values, due to limitations on precision. +// To compensate, the overbudget objective coefficient must be amplified by the +// same amount. +constexpr double kMemoryMultiplier = 1e-6; + +bool AutoShardingSolverOutput::operator==( + const AutoShardingSolverOutput& other) const { + return s_val == other.s_val && e_val == other.e_val && cost == other.cost && + peak_times == other.peak_times; +} + bool AutoShardingSolverResult::operator==( const AutoShardingSolverResult& other) const { return status == other.status && @@ -67,17 +85,16 @@ bool AutoShardingSolverResult::operator==( void PrintLargestInstructions( const std::vector& chosen_strategy, const AutoShardingSolverRequest& request) { - // This memory consumption computation is different from - // that in PrintAutoShardingSolution() because how L and m are created to be - // different from liveness_set and strategy.memory_cost. - + // This memory consumption computation is different from that in + // PrintAutoShardingSolution() because L and m are created to be different + // from liveness_set and strategy.memory_cost. std::vector> time_memory_usage; for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { double mem = 0.0; for (NodeIdx node_idx : request.live(time_idx).nodes()) { mem += request.memory_costs(node_idx).costs(chosen_strategy[node_idx]); } - time_memory_usage.push_back(std::make_pair(time_idx, mem)); + time_memory_usage.push_back({time_idx, mem}); } struct { bool operator()(std::pair a, @@ -95,14 +112,14 @@ void PrintLargestInstructions( k = std::min(k, time_memory_usage.size()); std::vector> instruction_mem; absl::flat_hash_set instruction_set; - for (auto usage_idx = 0; usage_idx < k; usage_idx++) { + for (auto usage_idx = 0; usage_idx < k; ++usage_idx) { LivenessIdx time_idx = time_memory_usage.at(usage_idx).first; for (NodeIdx node_idx : request.live(time_idx).nodes()) { double mem = request.memory_costs(node_idx).costs(chosen_strategy[node_idx]); if (mem > 100 * 1024 * 1024 && instruction_set.find(node_idx) == instruction_set.end()) { - instruction_mem.push_back(std::make_pair(node_idx, mem)); + instruction_mem.push_back({node_idx, mem}); instruction_set.insert(node_idx); } } @@ -112,7 +129,7 @@ void PrintLargestInstructions( size_t top_tensors = 10; top_tensors = std::min(top_tensors, instruction_mem.size()); VLOG(1) << "Top " << top_tensors << " largest tensors:"; - for (size_t i = 0; i < top_tensors; i++) { + for (size_t i = 0; i < top_tensors; ++i) { VLOG(1) << "instruction name: " << request.instruction_names(instruction_mem.at(i).first) << " memory usage: " @@ -120,11 +137,11 @@ void PrintLargestInstructions( } } -// Applies deterministic noise to the coefficient using the name & saltiplier, +// Applies deterministic noise to the coefficient using `name` and `saltiplier` // so that ties between equal terms can be broken in the solver's objective -// function. We include both a multiplicative term (in case the coefficient is +// function. We include both a multiplicative term (in case the coefficient is // large) and an additive term (in case the coefficient is zero). -void AddSalt(const std::string& name, double saltiplier, double* coeff) { +void AddSalt(const std::string& name, const double saltiplier, double* coeff) { if (saltiplier <= 0.0) return; const tsl::uint64 hash = tsl::Hash64(name); // stable across runs & platforms double salt = saltiplier * hash / std::numeric_limits::max(); @@ -139,22 +156,72 @@ AutoShardingSolverResult SolveAndExtractSolution( MPSolver& solver); double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { - double minimum_memory_budget_required_estimate = 0.0; + double min_memory_budget_required_estimate = 0.0; for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { - double minimum_memory_budget_required_estimate_local = 0.0; + double min_memory_budget_required_estimate_local = 0.0; for (NodeIdx node_idx : request.live(time_idx).nodes()) { const auto& m = request.memory_costs(node_idx).costs(); const double fixed_memory_cost = *std::min_element(m.begin(), m.end()); - minimum_memory_budget_required_estimate_local += fixed_memory_cost; + min_memory_budget_required_estimate_local += fixed_memory_cost; + } + min_memory_budget_required_estimate = + std::max(min_memory_budget_required_estimate, + min_memory_budget_required_estimate_local); + } + return min_memory_budget_required_estimate; +} + +double MaxCoeff( + const tsl::protobuf::RepeatedPtrField& + cost_mat) { + double max_coeff = 0.0; + for (auto& costs : cost_mat) { + for (auto& cost : costs.costs()) { + if (cost < kInfinityCost) { + max_coeff = std::max(max_coeff, cost); + } + } + } + return max_coeff; +} + +void ScaleCoeffs( + double scaling_factor, + tsl::protobuf::RepeatedPtrField* + cost_mat) { + for (auto& costs : *cost_mat) { + for (auto& cost : *costs.mutable_costs()) { + if (cost < kInfinityCost) { + cost = floor(cost * scaling_factor); + } } - minimum_memory_budget_required_estimate = - std::max(minimum_memory_budget_required_estimate, - minimum_memory_budget_required_estimate_local); } - return minimum_memory_budget_required_estimate; } -// We formulate the auto sharding process as the following ILP problem: +AutoShardingSolverRequest ScaleRequest( + const AutoShardingSolverRequest& request) { + if (!request.has_coeff_limit()) return request; + VLOG(0) << "Scaling request by coefficient limit: " + << request.coeff_limit().coeff(); + double max_coeff = 0.0; + max_coeff = std::max(max_coeff, MaxCoeff(request.communication_costs())); + max_coeff = std::max(max_coeff, MaxCoeff(request.computation_costs())); + max_coeff = std::max(max_coeff, MaxCoeff(request.resharding_costs())); + if (max_coeff <= request.coeff_limit().coeff()) return request; + const double scaling_factor = request.coeff_limit().coeff() / max_coeff; + AutoShardingSolverRequest scaled_request = request; + ScaleCoeffs(scaling_factor, scaled_request.mutable_communication_costs()); + ScaleCoeffs(scaling_factor, scaled_request.mutable_computation_costs()); + ScaleCoeffs(scaling_factor, scaled_request.mutable_resharding_costs()); + return scaled_request; +} + +// Taking an auto-sharding problem (`request`) as an input, calls the OR tools +// CP-SAT solver and outputs a solution to the input problem. +// +// We formulate the auto-sharding process as the following ILP problem +// (correspondences to the fields of the request parameter are specified in +// parenthesis): // Variables: // s[i]: Sharding strategy one-hot vector. // dim(s[i]) == # sharding strategies of the i-th XLA op @@ -162,19 +229,22 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { // e[i, j]: Strategy one-hot vector of edge i -> j. // dim(e[i, j]) == dim(s[i]) * dim(s[j]) // Constants: -// N: Number of total XLA ops -// M: Memory budget -// E: Edge set {(i, j)} -// L[t]: Index of live instructions at time t -// c[i]: Computation cost vector of instruction i +// N: Number of total XLA ops (request.num_nodes) +// M: Memory budget (request.memory_budget) +// E: Edge set {(i, j)} (request.edges) +// L[t]: Index of live instructions at time t (request.live) +// c[i]: Computation cost vector of instruction i (request.computation_costs) // d[i]: Communication cost vector of instruction i -// m[i]: Memory cost vector of instruction i +// (request.communication_costs) +// m[i]: Memory cost vector of instruction i (request.memory_costs) // dim(c[i]) == dim(d[i]) == dim(m[i]) == dim(s[i]) // r[i, j]: The resharding cost vector of edge i -> j +// (request.resharding_costs) // dim(e[i, j]) == dim(r[i, j]) -// A: Alias set {(i, j)} +// A: Alias set {(i, j)} (request.aliases) // v[i, j]: v[i, j](p, q) == 1 if strategy p is different than q, otherwise // v[i, j](p, q) == 0 +// (request.value_costs) // dim(e[i, j]) == dim(v[i, j]) // Problem: // Minimize sum_{0 <= i < N} s[i]^T * (c[i] + d[i]) @@ -199,11 +269,27 @@ double MinimumMemoryBudgetRequired(const AutoShardingSolverRequest& request) { // Serialize parameters of the ILP problem as numpy arrays and call the python // solver. +// Beyond what is described, note the following: +// 1. We also enforce that certain HLO ops have the same sharding as some other +// HLO ops (think elementwise ops, for example). This information stored in +// request.s_follow, where if s_follow[i] >= 0, then instruction i is forced +// the share same sharding as s_follow[i]. +// 2. If request.overbudget_coeff is present, we turn the hard memory budget +// constraint into a soft constraint instead. +// 3. If request.makespan_coeff is present, the objective additionally includes +// a makespan term. This is experimental and turned off by default. +// 4. request.max_departures is used only for debugging and can be ignored. +// 5. Note that due to our modeling of XLA's AllReduceReassociate optimization +// (more details in CostGraph::CostGraph() in auto_sharding_cost_graph.cc, +// and in CreateElementwiseOperatorStrategies() in auto_sharding.cc), there +// can be a few (usually < 10) edges in the problem with negative costs. This +// is guaranteed to never produce a negative overall cost for the graph, +// however. AutoShardingSolverResult CallORToolsSolver( - const AutoShardingSolverRequest& request) { - size_t num_edges = request.edges_size(); - - int32_t num_workers = 32; + const AutoShardingSolverRequest& unscaled_request) { + const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); + const size_t num_edges = request.edges_size(); + const int num_workers = 32; // SAT or SCIP std::unique_ptr solver(std::make_unique("", MPSolver::SAT_INTEGER_PROGRAMMING)); CHECK(solver); @@ -213,11 +299,15 @@ AutoShardingSolverResult CallORToolsSolver( if (solver->ProblemType() == operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { // Set random_seed, interleave_search and share_binary_clauses for - // determinism, and num_workers for parallelism. - solver_parameter_str = absl::StrCat( - "share_binary_clauses:false,random_seed:1,interleave_" - "search:true,num_workers:", - num_workers); + // determinism, mip_max_bound (to handle large costs), and num_workers for + // parallelism. + solver_parameter_str = + request.deterministic_mode() + ? absl::StrCat( + "share_binary_clauses:false,random_seed:1,interleave_" + "search:true,mip_max_bound:1e9,num_workers:", + num_workers) + : absl::StrCat("mip_max_bound:1e9,num_workers:", num_workers); solver->SetSolverSpecificParametersAsString(solver_parameter_str); } #endif @@ -227,10 +317,10 @@ AutoShardingSolverResult CallORToolsSolver( MPVariable* overbudget_var = nullptr; MPVariable* makespan_var = nullptr; - size_t var_vector_cnt = 0; + size_t unique_nodes = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) < 0) { - var_vector_cnt += 1; + unique_nodes += 1; // Creates variables for instructions that do not follow others. solver->MakeBoolVarArray(request.s_len(node_idx), absl::StrCat("s[", node_idx, "]"), &s[node_idx]); @@ -239,12 +329,15 @@ AutoShardingSolverResult CallORToolsSolver( for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) >= 0) { + CHECK_EQ(request.s_len(node_idx), + request.s_len(request.s_follow(node_idx))); // Copies the variable of followed instruction to the following // instruction. s[node_idx] = s[request.s_follow(node_idx)]; } } + size_t unique_edges = 0; std::vector e_follow(num_edges, -1); absl::flat_hash_map, EdgeIdx> edge_map; for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { @@ -258,6 +351,7 @@ AutoShardingSolverResult CallORToolsSolver( e_follow[edge_idx] = it->second; continue; } + unique_edges += 1; solver->MakeBoolVarArray( request.s_len(edge.first) * request.s_len(edge.second), absl::StrCat("e[", edge.first, ",", edge.second, "]"), &e[edge_idx]); @@ -273,14 +367,19 @@ AutoShardingSolverResult CallORToolsSolver( makespan_var = CreateMakespanVar(request, e, *solver); } - // Objective + // Construct objective function. // Node costs + absl::flat_hash_set infinity_vars; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { double accumulated_coefficient = - solver->Objective().GetCoefficient(s[node_idx][j]); + solver->MutableObjective()->GetCoefficient(s[node_idx][j]); double coefficient = request.computation_costs(node_idx).costs(j) + request.communication_costs(node_idx).costs(j); + if (coefficient >= kInfinityCost) { + infinity_vars.insert(s[node_idx][j]); + continue; + } AddSalt(absl::StrCat(node_idx, "S", j), request.saltiplier(), &coefficient); solver->MutableObjective()->SetCoefficient( @@ -291,24 +390,32 @@ AutoShardingSolverResult CallORToolsSolver( for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { double accumulated_coefficient = - solver->Objective().GetCoefficient(e[edge_idx][j]); + solver->MutableObjective()->GetCoefficient(e[edge_idx][j]); double coefficient = request.resharding_costs(edge_idx).costs(j); + if (coefficient >= kInfinityCost) { + infinity_vars.insert(e[edge_idx][j]); + continue; + } AddSalt(absl::StrCat(edge_idx, "E", j), request.saltiplier(), &coefficient); solver->MutableObjective()->SetCoefficient( e[edge_idx][j], accumulated_coefficient + coefficient); } } + LOG(INFO) << "Number of infinity terms: " << infinity_vars.size(); - // Constraints + // Add constraints. // 0. Do not choose solutions with infinity costs, as it will make the // objective value so large that other solution choices do not matter anymore. - // Remove these constraints once b/238210866 is done. + // Also eliminate strategies that are known to be dominated by others. + const NodeStrategies shaved_strategies = + StrategyShaver(request).FindShavedStrategies(); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (s[node_idx].empty() || request.s_follow(node_idx) >= 0) continue; bool all_infinity = true; for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - if (solver->Objective().GetCoefficient(s[node_idx][j]) >= kInfinityCost) { + if (infinity_vars.contains(s[node_idx][j]) || + shaved_strategies.contains({node_idx, j})) { MPConstraint* constraint = solver->MakeRowConstraint( 0.0, 0.0, absl::StrCat("infinitycost: s[", node_idx, "][", j, "] = 0")); @@ -321,12 +428,11 @@ AutoShardingSolverResult CallORToolsSolver( LOG(FATAL) << "All of s[" << node_idx << "][*] have infinity costs"; } } - for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { if (e[edge_idx].empty() || e_follow[edge_idx] >= 0) continue; bool all_infinity = true; for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - if (solver->Objective().GetCoefficient(e[edge_idx][j]) >= kInfinityCost) { + if (infinity_vars.contains(e[edge_idx][j])) { MPConstraint* constraint = solver->MakeRowConstraint( 0.0, 0.0, absl::StrCat("infinitycost: e[", edge_idx, "][", j, "] = 0")); @@ -362,36 +468,15 @@ AutoShardingSolverResult CallORToolsSolver( } // c. if (request.memory_budget() > 0) { - const double minimum_memory_budget_required_estimate = - MinimumMemoryBudgetRequired(request); - const double minimum_memory_overbudget = std::max( - 0.0, minimum_memory_budget_required_estimate - request.memory_budget()); - for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { - double upper_bound = request.memory_budget(); - if (overbudget_var) upper_bound += minimum_memory_overbudget; - MPConstraint* constraint = - solver->MakeRowConstraint(-MPSolver::infinity(), upper_bound, - absl::StrCat("mem[", time_idx, "]")); - if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0); - for (NodeIdx node_idx : request.live(time_idx).nodes()) { - for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - const double accumulated_coefficient = - constraint->GetCoefficient(s[node_idx][j]); - const double memory_cost = request.memory_costs(node_idx).costs(j); - constraint->SetCoefficient(s[node_idx][j], - accumulated_coefficient + memory_cost); - } - } - } if (overbudget_var) { solver->MutableObjective()->SetCoefficient( - overbudget_var, request.overbudget_coeff().coeff()); - solver->MutableObjective()->SetOffset(request.overbudget_coeff().coeff() * - minimum_memory_overbudget); + overbudget_var, + request.overbudget_coeff().coeff() / kMemoryMultiplier); } LOG(INFO) << "Minimum memory budget estimate: " - << minimum_memory_budget_required_estimate; - LOG(INFO) << "Using memory budget: " << request.memory_budget(); + << MinimumMemoryBudgetRequired(request); + LOG(INFO) << "Using memory budget: " + << static_cast(request.memory_budget()); } // d. specified via "BoolVarArray" @@ -416,8 +501,8 @@ AutoShardingSolverResult CallORToolsSolver( absl::StrCat("f for i = ", edge_idx, ", p = ", p)); constraint->SetCoefficient(s[edge.first()][p], -1.0); for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { - constraint->SetCoefficient(e[edge_idx][p * s[edge.second()].size() + q], - 1.0); + const EdgeStrategyIdx j = p * s[edge.second()].size() + q; + constraint->SetCoefficient(e[edge_idx][j], 1.0); } } } @@ -431,25 +516,30 @@ AutoShardingSolverResult CallORToolsSolver( absl::StrCat("g for i = ", edge_idx, ", q = ", q)); constraint->SetCoefficient(s[edge.second()][q], -1.0); for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { - constraint->SetCoefficient(e[edge_idx][p * s[edge.second()].size() + q], - 1.0); + const EdgeStrategyIdx j = p * s[edge.second()].size() + q; + constraint->SetCoefficient(e[edge_idx][j], 1.0); } } } // h. + absl::flat_hash_set> alias_set; for (auto alias_idx = 0; alias_idx < request.aliases_size(); ++alias_idx) { - const auto& alias = request.aliases(alias_idx); + const auto& raw_alias = request.aliases(alias_idx); + const std::pair alias(raw_alias.first(), + raw_alias.second()); + if (alias_set.contains(alias)) continue; + alias_set.insert(alias); const auto& value_costs = request.value_costs(alias_idx).costs(); - for (NodeStrategyIdx p = 0; p < s[alias.first()].size(); ++p) { - for (NodeStrategyIdx q = 0; q < s[alias.second()].size(); ++q) { + for (NodeStrategyIdx p = 0; p < s[alias.first].size(); ++p) { + for (NodeStrategyIdx q = 0; q < s[alias.second].size(); ++q) { // if lhs == 1 - if (value_costs[p * s[alias.second()].size() + q] > 0.5) { + if (value_costs[p * s[alias.second].size() + q] > 0.5) { MPConstraint* constraint = solver->MakeRowConstraint( -MPSolver::infinity(), 1, - absl::StrCat("s[", alias.first(), "][", p, "] + s[", - alias.second(), "][", q, "] <= 1")); - constraint->SetCoefficient(s[alias.first()][p], 1.0); - constraint->SetCoefficient(s[alias.second()][q], 1.0); + absl::StrCat("s[", alias.first, "][", p, "] + s[", alias.second, + "][", q, "] <= 1")); + constraint->SetCoefficient(s[alias.first][p], 1.0); + constraint->SetCoefficient(s[alias.second][q], 1.0); } } } @@ -468,8 +558,17 @@ AutoShardingSolverResult CallORToolsSolver( } } } + if (request.has_max_cost()) { + double max_cost = kMaxCostEpsilon * request.max_cost().coeff(); + max_cost -= solver->Objective().offset(); + MPConstraint* cost_constraint = solver->MakeRowConstraint( + -MPSolver::infinity(), max_cost, "cost_constraint"); + for (const auto [var, coeff] : solver->Objective().terms()) { + cost_constraint->SetCoefficient(var, coeff); + } + } - if (!request.s_hint().empty()) { + if (!request.s_hint().empty() && !request.deterministic_mode()) { std::vector> hint; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { if (request.s_follow(node_idx) >= 0) continue; @@ -495,15 +594,18 @@ AutoShardingSolverResult CallORToolsSolver( LOG(ERROR) << write_status.message(); } } - // Exports the solver request proto for debugging. + // Exports the *unscaled* solver request proto for debugging. bool dump_solver_request = false; if (dump_solver_request) { uint64_t solver_request_fprint = - tsl::Fingerprint64(request.SerializeAsString()); + tsl::Fingerprint64(unscaled_request.SerializeAsString()); + std::string request_dump_path = + absl::StrCat("/tmp/solver_request_", unscaled_request.request_name(), + "_", solver_request_fprint, ".proto"); auto write_status = file::SetBinaryProto( // Modify this file path if needed. - absl::StrCat("/tmp/solver_request_", solver_request_fprint, ".proto"), - request, file::Defaults()); + request_dump_path, unscaled_request, file::Defaults()); + VLOG(5) << "Dumped solver request to " << request_dump_path; if (!write_status.ok()) { LOG(ERROR) << write_status.message(); } @@ -513,22 +615,35 @@ AutoShardingSolverResult CallORToolsSolver( solver->SetTimeLimit( absl::Seconds(request.solver_timeout().solver_timeout_in_seconds())); } + if (request.enable_output()) { + solver->EnableOutput(); + } VLOG(0) << "Starting solver " << solver->ProblemType() << "\n" << "Solver parameter string: " << solver_parameter_str << "\n" << "Number of workers: " << num_workers << "\n" << "Number of threads: " << solver->GetNumThreads() << "\n" << "Time limit: " << solver->time_limit() << "\n" - << "Number variables for ILP: " << solver->NumVariables() << "\n" - << "Total vector of variables: " << var_vector_cnt << "\n" + << "Request valid: " << ValidateRequest(request).ok() << "\n" + << "Aliases: " << request.aliases_size() << "\n" + << "Unique nodes: " << unique_nodes << "\n" + << "Unique edges: " << unique_edges << "\n" << "Total instructions: " << request.num_nodes() << "\n" + << "Total edges: " << request.edges_size() << "\n" << "Memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) << "GB\n" + << "Number variables for ILP: " << solver->NumVariables() << "\n" << "Number of ILP constraints: " << solver->NumConstraints() << "\n" + << "Deterministic mode: " << request.deterministic_mode() << "\n" << "Module name: " << request.module_name(); + if (request.has_max_cost()) { + VLOG(0) << "Max cost: " << request.max_cost().coeff(); + } auto result = SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, *solver); if (result.status.ok()) { - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = + Evaluate(unscaled_request, result); + LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost << " (lower bound: " << evaluation.lower_bound.communication_cost @@ -554,13 +669,139 @@ AutoShardingSolverResult CallORToolsSolver( return result; } +std::vector GetChosenNodeStrategy( + const AutoShardingSolverRequest& request, + const std::vector>& s) { + std::vector chosen_node_strategy(request.num_nodes(), -1); + for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { + for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { + // if lhs == 1 + if (s[node_idx][j]->solution_value() > 0.5) { + chosen_node_strategy[node_idx] = j; + break; + } + } + } + return chosen_node_strategy; +} + +std::vector GetChosenEdgeStrategy( + const AutoShardingSolverRequest& request, + const std::vector>& e) { + size_t num_edges = request.edges_size(); + std::vector chosen_edge_strategy(num_edges, -1); + for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { + for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { + // if lhs == 1 + if (e[edge_idx][j]->solution_value() > 0.5) { + chosen_edge_strategy[edge_idx] = j; + break; + } + } + } + return chosen_edge_strategy; +} + +// Finds the timestep with the largest memory overbudget (-1 if no such value). +LivenessIdx FindPeakLiveness(const AutoShardingSolverRequest& request, + const std::vector>& s, + const std::vector>& e) { + const std::vector chosen_node_strategy = + GetChosenNodeStrategy(request, s); + const std::vector chosen_edge_strategy = + GetChosenEdgeStrategy(request, e); + LivenessIdx peak_time_idx = -1; + double peak_overbudget = 0.0; + for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { + double memory_usage = 0.0; + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + const NodeStrategyIdx j = chosen_node_strategy[node_idx]; + memory_usage += request.memory_costs(node_idx).costs(j); + } + if (!request.live_edges().empty() && request.enable_memory_edge_costs()) { + for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { + const EdgeStrategyIdx j = chosen_edge_strategy[edge_idx]; + memory_usage += request.memory_edge_costs(edge_idx).costs(j); + } + } + const double overbudget = memory_usage - request.memory_budget(); + if (peak_overbudget < overbudget) { + peak_overbudget = overbudget; + peak_time_idx = time_idx; + } + } + return peak_time_idx; +} + +// Imposes a new memory constraint at the given location. +void ImposeMemoryConstraint(const AutoShardingSolverRequest& request, + const std::vector>& s, + const std::vector>& e, + const MPVariable* overbudget_var, MPSolver& solver, + LivenessIdx time_idx) { + VLOG(1) << "Imposing a memory constraint at time " << time_idx; + MPConstraint* constraint = + solver.MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(), + absl::StrCat("mem[", time_idx, "]")); + if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0); + for (NodeIdx node_idx : request.live(time_idx).nodes()) { + for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { + double memory_cost = request.memory_costs(node_idx).costs(j); + memory_cost *= kMemoryMultiplier; + const double accumulated_coefficient = + constraint->GetCoefficient(s[node_idx][j]); + constraint->SetCoefficient(s[node_idx][j], + accumulated_coefficient + memory_cost); + } + } + if (!request.live_edges().empty() && request.enable_memory_edge_costs()) { + for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { + for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { + double memory_cost = request.memory_edge_costs(edge_idx).costs(j); + memory_cost *= kMemoryMultiplier; + const double accumulated_coefficient = + constraint->GetCoefficient(e[edge_idx][j]); + constraint->SetCoefficient(e[edge_idx][j], + accumulated_coefficient + memory_cost); + } + } + } + constraint->SetUB(kMemoryMultiplier * request.memory_budget()); +} + AutoShardingSolverResult SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, const MPVariable* overbudget_var, const MPVariable* makespan_var, MPSolver& solver) { + absl::Time start_time = absl::Now(); + absl::flat_hash_set peak_times; + if (request.memory_budget() > 0 && !request.deterministic_mode()) { + for (const LivenessIdx peak_time_idx : request.peak_times()) { + peak_times.insert(peak_time_idx); + ImposeMemoryConstraint(request, s, e, overbudget_var, solver, + peak_time_idx); + } + } auto status = solver.Solve(); + if (request.memory_budget() > 0) { + while (status == operations_research::MPSolver::OPTIMAL) { + const LivenessIdx peak_time_idx = FindPeakLiveness(request, s, e); + if (peak_time_idx == -1 || peak_times.contains(peak_time_idx)) break; + peak_times.insert(peak_time_idx); + ImposeMemoryConstraint(request, s, e, overbudget_var, solver, + peak_time_idx); + status = solver.Solve(); + } + LOG(INFO) << "Imposed " << peak_times.size() + << " memory constraints out of " << request.live_size(); + } + absl::Time end_time = absl::Now(); + auto duration = end_time - start_time; + LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms"; + LOG(INFO) << "Solver Status: " << status; + if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; #ifdef PLATFORM_GOOGLE @@ -599,7 +840,6 @@ AutoShardingSolverResult SolveAndExtractSolution( "likely a bug and should be reported."; } else if (status != operations_research::MPSolver::OPTIMAL) { auto err_msg = "Solver timed out."; - LOG(WARNING) << err_msg << " Solver status " << status; return AutoShardingSolverResult(absl::InternalError(err_msg), true); } @@ -610,10 +850,10 @@ AutoShardingSolverResult SolveAndExtractSolution( uint64_t model_fprint = tsl::Fingerprint64(model_proto.SerializeAsString()); operations_research::MPSolutionResponse response; solver.FillSolutionResponseProto(&response); + response.clear_solve_info(); // Remove for fingerprint; can vary between runs uint64_t solution_fprint = tsl::Fingerprint64(response.SerializeAsString()); - LOG(INFO) << "Solver Status: " << status - << " Objective value: " << solver.Objective().Value() + LOG(INFO) << "Objective value: " << solver.Objective().Value() << " Model fingerprint: " << model_fprint << " Solution fingerprint: " << solution_fprint; if (solver.Objective().Value() >= kInfinityCost) { @@ -634,33 +874,22 @@ AutoShardingSolverResult SolveAndExtractSolution( // Return value size_t num_edges = request.edges_size(); double unsalted_objective = 0.0; - std::vector chosen_strategy(request.num_nodes(), -1); - std::vector e_val(num_edges, -1); + const std::vector chosen_node_strategy = + GetChosenNodeStrategy(request, s); + const std::vector chosen_edge_strategy = + GetChosenEdgeStrategy(request, e); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - // if lhs == 1 - if (s[node_idx][j]->solution_value() > 0.5) { - chosen_strategy[node_idx] = j; - unsalted_objective += request.computation_costs(node_idx).costs(j) + - request.communication_costs(node_idx).costs(j); - break; - } - } + const NodeStrategyIdx j = chosen_node_strategy[node_idx]; + unsalted_objective += request.computation_costs(node_idx).costs(j) + + request.communication_costs(node_idx).costs(j); } for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - // if lhs == 1 - if (e[edge_idx][j]->solution_value() > 0.5) { - e_val[edge_idx] = j; - unsalted_objective += request.resharding_costs(edge_idx).costs(j); - break; - } - } + const EdgeStrategyIdx j = chosen_edge_strategy[edge_idx]; + unsalted_objective += request.resharding_costs(edge_idx).costs(j); } if (overbudget_var) { - unsalted_objective += - request.overbudget_coeff().coeff() * overbudget_var->solution_value(); - unsalted_objective += solver.Objective().offset(); + unsalted_objective += request.overbudget_coeff().coeff() * + overbudget_var->solution_value() / kMemoryMultiplier; } if (makespan_var) { unsalted_objective += @@ -675,11 +904,11 @@ AutoShardingSolverResult SolveAndExtractSolution( LOG(INFO) << "memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) << " GB"; } - PrintLargestInstructions(chosen_strategy, request); - return AutoShardingSolverResult( - std::make_tuple(std::move(chosen_strategy), std::move(e_val), - unsalted_objective), - false); + PrintLargestInstructions(chosen_node_strategy, request); + const AutoShardingSolverOutput output = {std::move(chosen_node_strategy), + std::move(chosen_edge_strategy), + unsalted_objective, peak_times}; + return AutoShardingSolverResult(output, false); } bool CostComponents::operator==(const CostComponents& other) const { @@ -709,8 +938,8 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const auto& r = request.resharding_costs(); const auto& v = request.value_costs(); const auto& p = request.departure_costs(); - const std::vector& s_val = std::get<0>(*result.status); - const std::vector& e_val = std::get<1>(*result.status); + const std::vector& s_val = result.status->s_val; + const std::vector& e_val = result.status->e_val; AutoShardingEvaluation evaluation; // Compute violations. for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { @@ -757,6 +986,13 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, total_memory_cost += m[s_val[node_idx]]; lower_bound_memory_cost += *std::min_element(m.begin(), m.end()); } + if (!request.live_edges().empty() && request.enable_memory_edge_costs()) { + for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { + const auto& m = request.memory_edge_costs(edge_idx).costs(); + total_memory_cost += m[e_val[edge_idx]]; + lower_bound_memory_cost += *std::min_element(m.begin(), m.end()); + } + } if (request.has_overbudget_coeff()) { total_overbudget = std::max( total_overbudget, total_memory_cost - request.memory_budget()); @@ -774,7 +1010,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, request.overbudget_coeff().coeff() * lower_bound_overbudget; } } - // Compute metrics & lower bounds. + // Compute metrics and lower bounds. for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { evaluation.total.communication_cost += d.at(node_idx).costs(s_val[node_idx]); @@ -799,8 +1035,8 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, std::vector rationales; const auto& names = request.instruction_names(); - const std::vector& s_result = std::get<0>(*result.status); - const std::vector& s_subopt = std::get<0>(*subopt.status); + const std::vector& s_result = result.status->s_val; + const std::vector& s_subopt = subopt.status->s_val; for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { const NodeStrategyIdx j = s_result[node_idx], k = s_subopt[node_idx]; if (j != k) { @@ -823,8 +1059,8 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, } } - const std::vector& e_result = std::get<1>(*result.status); - const std::vector& e_subopt = std::get<1>(*subopt.status); + const std::vector& e_result = result.status->e_val; + const std::vector& e_subopt = subopt.status->e_val; for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { const auto& edge = request.edges(edge_idx); const EdgeStrategyIdx j = e_result[edge_idx], k = e_subopt[edge_idx]; @@ -841,5 +1077,45 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, return rationales; } +Status ValidateRequest(const AutoShardingSolverRequest& request) { + const int num_nodes = request.num_nodes(); + const int num_edges = request.edges_size(); + TF_RET_CHECK(num_nodes == request.computation_costs_size()); + TF_RET_CHECK(num_nodes == request.communication_costs_size()); + TF_RET_CHECK(num_nodes == request.memory_costs_size()); + TF_RET_CHECK(num_edges == request.resharding_costs_size()); + + for (NodeIdx u = 0; u < num_nodes; ++u) { + const int num_strategies = request.computation_costs(u).costs_size(); + TF_RET_CHECK(num_strategies >= 1); + TF_RET_CHECK(num_strategies == request.communication_costs(u).costs_size()); + TF_RET_CHECK(num_strategies == request.memory_costs(u).costs_size()); + for (NodeStrategyIdx strategy = 0; strategy < num_strategies; ++strategy) { + TF_RET_CHECK(request.computation_costs(u).costs(strategy) >= 0.0); + TF_RET_CHECK(request.communication_costs(u).costs(strategy) >= 0.0); + TF_RET_CHECK(request.memory_costs(u).costs(strategy) >= 0.0); + } + } + + absl::btree_set> edges_seen; + for (EdgeIdx e = 0; e < num_edges; ++e) { + const int u = request.edges(e).first(); + const int v = request.edges(e).second(); + TF_RET_CHECK(u >= 0); + TF_RET_CHECK(u < num_nodes); + TF_RET_CHECK(v >= 0); + TF_RET_CHECK(v < num_nodes); + TF_RET_CHECK(u < v); + TF_RET_CHECK(edges_seen.count({u, v}) == 0); + edges_seen.insert({u, v}); + + const int num_strategies = request.resharding_costs(e).costs_size(); + const int num_u_strategies = request.computation_costs(u).costs_size(); + const int num_v_strategies = request.computation_costs(v).costs_size(); + CHECK_EQ(num_strategies, num_u_strategies * num_v_strategies); + } + return OkStatus(); +} + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 9bfa64a114990..32809503603c7 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,36 +16,35 @@ limitations under the License. #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ -#include -#include #include -#include -#include #include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" -#include "xla/statusor.h" +#include "xla/status.h" #include "ortools/linear_solver/linear_solver.h" -using MPSolver = operations_research::MPSolver; -using MPVariable = operations_research::MPVariable; - namespace xla { namespace spmd { +struct AutoShardingSolverOutput { + std::vector s_val; + std::vector e_val; + double cost = -1.0; + absl::flat_hash_set peak_times; + + bool operator==(const AutoShardingSolverOutput& other) const; +}; + struct AutoShardingSolverResult { public: - AutoShardingSolverResult( - StatusOr, - std::vector, double>> - status, - bool skip_auto_sharding) + AutoShardingSolverResult(absl::StatusOr status, + bool skip_auto_sharding) : status(status), skip_auto_sharding(skip_auto_sharding) {} bool operator==(const AutoShardingSolverResult& other) const; - StatusOr, std::vector, double>> - status; + absl::StatusOr status; bool skip_auto_sharding; }; @@ -78,14 +77,14 @@ struct AutoShardingEvaluation { // A set of constraint violations; should be empty for any viable solution. absl::flat_hash_set violation_codes; - // A breakdown & lower bound for each individual cost component. + // A breakdown and lower bound for each individual cost component. CostComponents total; CostComponents lower_bound; // How many instructions departed from the "default" sharding strategy. double total_departures = 0.0; - // The (raw) total makespan, i.e. not scaled by the makespan coefficient. + // The (raw) total makespan, i.e., not scaled by the makespan coefficient. double total_makespan = 0.0; bool operator==(const AutoShardingEvaluation& other) const; @@ -102,14 +101,49 @@ std::vector Rationalize(const AutoShardingSolverRequest& request, const AutoShardingSolverResult& subopt); // Creates and returns a variable for makespan. -MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request, - const std::vector>& e, - MPSolver& solver); +operations_research::MPVariable* CreateMakespanVar( + const AutoShardingSolverRequest& request, + const std::vector>& e, + operations_research::MPSolver& solver); double EvaluateMakespan(const AutoShardingSolverRequest& request, const AutoShardingSolverResult& result, AutoShardingEvaluation& evaluation); +// Scale down values to reduce the range of costs & coefficients in the solver. +AutoShardingSolverRequest ScaleRequest( + const AutoShardingSolverRequest& request); + +// Determines if strategy 'first' is dominated by strategy 'second' (i.e., its +// costs are all equal or worse, and it has identical alias mappings). +bool CheckDominance(const AutoShardingSolverRequest& request, + const std::vector& src_edges, + const std::vector& dst_edges, + const std::vector& src_aliases, + const std::vector& dst_aliases, NodeIdx node_idx, + NodeStrategyIdx first, NodeStrategyIdx second); + +class StrategyShaver { + public: + explicit StrategyShaver(const AutoShardingSolverRequest& request); + + // For every node, examine each sharding strategy to see if it is dominated by + // another. + NodeStrategies FindShavedStrategies() const; + + private: + const AutoShardingSolverRequest& request_; // NOLINT + std::vector> src_edge_map_; + std::vector> dst_edge_map_; + std::vector> src_alias_map_; + std::vector> dst_alias_map_; + std::vector> followers_; +}; + +// Check fail if `request` is invalid (e.g., because of negative node costs). +// Note: This does not include checks for valid variable aliasing yet. +Status ValidateRequest(const AutoShardingSolverRequest& request); + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index f155be5c12d8b..4be54f98a0a49 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,15 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "ortools/linear_solver/linear_solver.h" -using MPSolver = operations_research::MPSolver; -using MPVariable = operations_research::MPVariable; - namespace xla { namespace spmd { +using ::operations_research::MPSolver; +using ::operations_research::MPVariable; + MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request, const std::vector>& e, MPSolver& solver) { @@ -37,5 +38,12 @@ double EvaluateMakespan(const AutoShardingSolverRequest& request, return 0.0; // TODO(moffitt): Implement this. } +StrategyShaver::StrategyShaver(const AutoShardingSolverRequest& request) + : request_(request) {} + +NodeStrategies StrategyShaver::FindShavedStrategies() const { + return {}; // TODO(moffitt): Implement this. +} + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index d2b63f690f332..ba42a25107aa0 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -13,13 +13,15 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include -#include #include -#include #include #include +#include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -29,6 +31,7 @@ namespace { using CostMatrix = std::vector>; using NodeMatrix = std::vector>; +using EdgeMatrix = std::vector>; void AddCosts(proto2::RepeatedPtrField* costs, const CostMatrix& cost_matrix) { @@ -48,6 +51,15 @@ void AddNodes(proto2::RepeatedPtrField* nodes, } } +void AddEdges(proto2::RepeatedPtrField* edges, + const EdgeMatrix& edge_matrix) { + for (const auto& edge_row : edge_matrix) { + AutoShardingSolverRequest_Edges edge; + edge.mutable_edges()->Add(edge_row.begin(), edge_row.end()); + edges->Add(std::move(edge)); + } +} + // clang-format off AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { @@ -69,7 +81,7 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { {20, 21, 22}, {30, 31, 32, 33}, {40, 41, 42, 43}, - {50, 51, 52, 53}}; + {50, 51, 52}}; const CostMatrix d = {{100, 110, 120, 130}, {200, 210, 220}, {300, 310, 320, 330}, @@ -125,7 +137,87 @@ AutoShardingSolverRequest DefaultAutoShardingSolverRequest() { AddCosts(request.mutable_value_costs(), v); request.mutable_instruction_names()->Add(instruction_names.begin(), instruction_names.end()); + return request; +} + +AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { + const auto s_len = {4, 3, 7, 7, 3}; + const auto s_follow = {-1, -1, -1, 2, -1}; + AutoShardingSolverRequest_Pair edge1, edge2; + edge1.set_first(0); + edge1.set_second(2); + edge2.set_first(1); + edge2.set_second(2); + const auto edges = {edge1, edge2}; + const NodeMatrix live = {{1, 0}, + {1, 0}, + {1, 2, 0}, + {1, 2, 3, 0}, + {1, 3, 0}}; + const CostMatrix c = {{10, 10, 10, 10}, + {20, 20, 20}, + {30, 30, 31, 30, 30, 30, 30}, + {40, 40, 40, 40, 40, 40, 40}, + {50, 50, 50}}; + const CostMatrix d = {{100, 100, 100, 100}, + {200, 200, 200}, + {300, 300, 300, 300, 300, 300, 300}, + {400, 400, 400, 400, 400, 400, 410}, + {500, 500, 500}}; + const CostMatrix m = {{10000, 10000, 10000, 10000}, + {20000, 20000, 20000}, + {30000, 30000, 30000, 31000, 30000, 30000, 30000}, + {40000, 40000, 40000, 40000, 40000, 40000, 40000}, + {50000, 50000, 50000}}; + const CostMatrix p = {{1.0, 0.0, 1.0, 1.0}, + {1.0, 0.0, 1.0}, + {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {1.0, 0.0, 1.0}}; + const CostMatrix r = {{1000, 1000, 1000, 1000, 1000, 1000, 1000, + 2000, 2000, 2000, 2000, 2000, 2000, 2000, + 3000, 3000, 3000, 3000, 3100, 3000, 3000, + 4000, 4000, 4000, 4000, 4000, 4000, 4000}, + {5000, 5000, 5000, 5000, 5000, 5000, 5000, + 6000, 6000, 6000, 6000, 6000, 6000, 6000, + 7000, 7000, 7000, 7000, 7000, 7000, 7000}}; + const CostMatrix t = {{70000, 70000, 70000, 70000, 70000, 70000, 70000, + 60000, 60000, 60000, 60000, 60000, 60000, 60000, + 50000, 50000, 50000, 50000, 50000, 50000, 50000, + 40000, 40000, 40000, 40000, 40000, 40000, 40000}, + {30000, 30000, 30000, 30000, 30000, 30000, 30000, + 20000, 20000, 20000, 20000, 20000, 20000, 20000, + 10000, 10000, 10000, 10000, 10000, 10000, 10000}}; + AutoShardingSolverRequest_Pair alias; + alias.set_first(2); + alias.set_second(4); + const auto aliases = {alias}; + const CostMatrix v = {{0, 1, 0, + 0, 1, 0, + 0, 1, 0, + 0, 1, 0, + 0, 1, 0, + 1, 0, 1, + 0, 1, 0}}; + const std::vector instruction_names = {"A", "B", "C", "D", "E"}; + + AutoShardingSolverRequest request; + request.set_num_nodes(5); + request.set_memory_budget(1500000); + request.mutable_s_len()->Add(s_len.begin(), s_len.end()); + request.mutable_s_follow()->Add(s_follow.begin(), s_follow.end()); + request.mutable_edges()->Add(edges.begin(), edges.end()); + AddNodes(request.mutable_live(), live); AddCosts(request.mutable_computation_costs(), c); + AddCosts(request.mutable_communication_costs(), d); + AddCosts(request.mutable_memory_costs(), m); + AddCosts(request.mutable_departure_costs(), p); + AddCosts(request.mutable_resharding_costs(), r); + AddCosts(request.mutable_duration_costs(), t); + request.mutable_aliases()->Add(aliases.begin(), aliases.end()); + AddCosts(request.mutable_value_costs(), v); + request.mutable_instruction_names()->Add(instruction_names.begin(), + instruction_names.end()); return request; } @@ -137,9 +229,9 @@ TEST(CallORToolsSolverTest, SolvesOptimally) { const std::vector s_val = {0, 0, 0, 0, 0}; const std::vector e_val = {0, 0}; const double objective_value = 7650.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -153,9 +245,10 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) { const std::vector s_val = {0, 0, 0, 0, 0}; const std::vector e_val = {0, 0}; const double objective_value = 9007650.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const absl::flat_hash_set peak_times = {3}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value, peak_times}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -168,9 +261,9 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) { const std::vector s_val = {0, 0, 1, 1, 0}; const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -185,9 +278,9 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { const std::vector s_val = {3, 0, 0, 0, 0}; const std::vector e_val = {12, 0}; const double objective_value = 10683.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -200,9 +293,9 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { const std::vector s_val = {0, 0, 1, 1, 0}; const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -227,9 +320,9 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { const std::vector s_val = {0, 0, 0, 0, 0}; const std::vector e_val = {0, 0, 0}; const double objective_value = 12650.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -243,9 +336,59 @@ TEST(CallORToolsSolverTest, UsesHint) { const std::vector s_val = {0, 0, 0, 0, 0}; const std::vector e_val = {0, 0}; const double objective_value = 7650.0; - const AutoShardingSolverResult expected_result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; + EXPECT_EQ(result, expected_result); +} + +TEST(CallORToolsSolverTest, HonorsMaxCost) { + AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); + request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 + + const AutoShardingSolverResult result = CallORToolsSolver(request); + + EXPECT_TRUE(absl::IsInternal(result.status.status())); +} + +TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { + AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); + const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}}; + const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, + 2000, 2100, 2200, 2300, + 3000, 3100, 3200, 3300, + 4000, 4100, 4200, 4300}, + {5000000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 7300}}; + AddEdges(request.mutable_live_edges(), live_edges); + AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); + request.set_enable_memory_edge_costs(true); + + const AutoShardingSolverResult result = CallORToolsSolver(request); + + const std::vector s_val = {0, 0, 1, 1, 0}; + const std::vector e_val = {1, 1}; + const double objective_value = 7872.0; + const absl::flat_hash_set peak_times = {2}; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value, peak_times}; + const AutoShardingSolverResult expected_result = {expected_output, false}; + EXPECT_EQ(result, expected_result); +} + +TEST(CallORToolsSolverTest, SolvesWithEquivalences) { + const AutoShardingSolverRequest request = + AutoShardingSolverRequestWithEquivalences(); + + const AutoShardingSolverResult result = CallORToolsSolver(request); + + const std::vector s_val = {0, 0, 5, 5, 1}; + const std::vector e_val = {5, 5}; + const double objective_value = 7650.0; + const AutoShardingSolverOutput expected_output = + {s_val, e_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -254,9 +397,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { const std::vector s_val = {3, 1, 2, 2, 1}; const std::vector e_val = {14, 6}; const double objective_value = 12149.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -278,9 +420,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -302,9 +443,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const std::vector s_val = {3, 1, 2, 1 /* violates */, 1}; const std::vector e_val = {14, 6}; const double objective_value = 12138.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -325,9 +465,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const std::vector s_val = {3, 1, 2, 2, 0 /* violates */}; const std::vector e_val = {14, 6}; const double objective_value = 12138.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -348,9 +487,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -374,9 +512,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { const std::vector s_val = {0 /* violates */, 1, 2, 2, 1}; const std::vector e_val = {2, 6}; const double objective_value = 1e+20; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -398,9 +535,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { const std::vector s_val = {0, 1, 2, 2, 1}; const std::vector e_val = {2 /* violates */, 6}; const double objective_value = 1e+20; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -422,9 +558,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { const std::vector s_val = {3, 1, 2, 2, 1}; const std::vector e_val = {14, 6}; const double objective_value = 12149.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -445,18 +580,17 @@ TEST(AutoShardingRationalizerTest, RationalizesProperly) { const std::vector s_val = {0, 1, 2, 2, 1}; const std::vector e_val = {2, 6}; const double objective_value = 9116.0; - const AutoShardingSolverResult result = { - std::make_tuple( - std::move(s_val), std::move(e_val), objective_value), false}; + const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverResult result = {output, false}; const std::vector s_subopt = {3, 1, 2, 2, 1}; const std::vector e_subopt = {14, 6}; const double subopt_value = 12149.0; - const AutoShardingSolverResult subopt = { - std::make_tuple( - std::move(s_subopt), std::move(e_subopt), subopt_value), false}; + const AutoShardingSolverOutput subopt_output = + {s_subopt, e_subopt, subopt_value}; + const AutoShardingSolverResult subopt_result = {subopt_output, false}; const std::vector rationales = - Rationalize(request, result, subopt); + Rationalize(request, result, subopt_result); const std::vector expected_rationales = { "strategy changes for A (0 -> 3)", @@ -466,6 +600,112 @@ TEST(AutoShardingRationalizerTest, RationalizesProperly) { EXPECT_EQ(rationales, expected_rationales); } +TEST(ScaleRequest, ScalesProperly) { + AutoShardingSolverRequest unscaled_request; + const CostMatrix c = {{10000000, 11000000, 12000000, 13000000}, + {20000000, 21000000, 22000000}, + {30000000, 31000000, 32000000, 33000000}, + {40000000, 41000000, 42000000, 43000000}, + {50000000, 51000000, 52000000, 53000000}}; + const CostMatrix d = {{100000000, 110000000, 120000000, 130000000}, + {200000000, 210000000, 220000000}, + {300000000, 310000000, 320000000, 330000000}, + {400000000, 410000000, 420000000, 430000000}, + {500000000, 510000000, 520000000}}; + const CostMatrix r = {{1000000000, 1100000000, 1200000000, 1300000000, + 2000000000, 2100000000, 2200000000, 2300000000, + 3000000000, 3100000000, 3200000000, 3300000000, + 4000000000, 4100000000, 4200000000, 4300000000}, + {5000000000, 5100000000, 5200000000, 5300000000, + 6000000000, 6100000000, 6200000000, 6300000000, + 7000000000, 7100000000, 7200000000, 10000000000000}}; + AddCosts(unscaled_request.mutable_computation_costs(), c); + AddCosts(unscaled_request.mutable_communication_costs(), d); + AddCosts(unscaled_request.mutable_resharding_costs(), r); + unscaled_request.mutable_coeff_limit()->set_coeff(1e7); + + AutoShardingSolverRequest request = ScaleRequest(unscaled_request); + + AutoShardingSolverRequest expected_request; + const CostMatrix expected_c = {{10, 11, 12, 13}, + {20, 21, 22}, + {30, 31, 32, 33}, + {40, 41, 42, 43}, + {50, 51, 52, 53}}; + const CostMatrix expected_d = {{100, 110, 120, 130}, + {200, 210, 220}, + {300, 310, 320, 330}, + {400, 410, 420, 430}, + {500, 510, 520}}; + const CostMatrix expected_r = {{1000, 1100, 1200, 1300, + 2000, 2100, 2200, 2300, + 3000, 3100, 3200, 3300, + 4000, 4100, 4200, 4300}, + {5000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 10000000}}; + AddCosts(expected_request.mutable_computation_costs(), expected_c); + AddCosts(expected_request.mutable_communication_costs(), expected_d); + AddCosts(expected_request.mutable_resharding_costs(), expected_r); + expected_request.mutable_coeff_limit()->set_coeff(1e7); + EXPECT_THAT(request, ::testing::EqualsProto(expected_request)); +} + +TEST(ScaleRequest, SkipsScaling) { + AutoShardingSolverRequest unscaled_request; + const CostMatrix c = {{10, 11, 12, 13}, + {20, 21, 22}, + {30, 31, 32, 33}, + {40, 41, 42, 43}, + {50, 51, 52, 53}}; + const CostMatrix d = {{100, 110, 120, 130}, + {200, 210, 220}, + {300, 310, 320, 330}, + {400, 410, 420, 430}, + {500, 510, 520}}; + const CostMatrix r = {{1000, 1100, 1200, 1300, + 2000, 2100, 2200, 2300, + 3000, 3100, 3200, 3300, + 4000, 4100, 4200, 4300}, + {5000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 10000000}}; + AddCosts(unscaled_request.mutable_computation_costs(), c); + AddCosts(unscaled_request.mutable_communication_costs(), d); + AddCosts(unscaled_request.mutable_resharding_costs(), r); + unscaled_request.mutable_coeff_limit()->set_coeff(1e7); + + AutoShardingSolverRequest request = ScaleRequest(unscaled_request); + + AutoShardingSolverRequest expected_request; + const CostMatrix expected_c = {{10, 11, 12, 13}, + {20, 21, 22}, + {30, 31, 32, 33}, + {40, 41, 42, 43}, + {50, 51, 52, 53}}; + const CostMatrix expected_d = {{100, 110, 120, 130}, + {200, 210, 220}, + {300, 310, 320, 330}, + {400, 410, 420, 430}, + {500, 510, 520}}; + const CostMatrix expected_r = {{1000, 1100, 1200, 1300, + 2000, 2100, 2200, 2300, + 3000, 3100, 3200, 3300, + 4000, 4100, 4200, 4300}, + {5000, 5100, 5200, 5300, + 6000, 6100, 6200, 6300, + 7000, 7100, 7200, 10000000}}; + AddCosts(expected_request.mutable_computation_costs(), expected_c); + AddCosts(expected_request.mutable_communication_costs(), expected_d); + AddCosts(expected_request.mutable_resharding_costs(), expected_r); + expected_request.mutable_coeff_limit()->set_coeff(1e7); + EXPECT_THAT(request, ::testing::EqualsProto(expected_request)); +} + +TEST(ValidateRequest, AcceptsAutoShardingSolverRequest) { + CHECK_OK(ValidateRequest(DefaultAutoShardingSolverRequest())); +} + // clang-format on } // namespace diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc new file mode 100644 index 0000000000000..2c67fa56a1dc2 --- /dev/null +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -0,0 +1,912 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" +#include "xla/hlo/experimental/auto_sharding/cluster_environment.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/sharding_propagation.h" +#include "xla/shape.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace spmd { + +bool LeafVectorsAreConsistent(const std::vector& one, + const std::vector& two) { + return one.size() == two.size(); +} + +// NOLINTBEGIN(readability/fn_size) +// TODO(zhuohan): Decompose this function into smaller pieces +absl::StatusOr> +BuildStrategyAndCost(const HloInstructionSequence& sequence, + const HloModule* module, + const absl::flat_hash_map& + instruction_execution_counts, + const InstructionDepthMap& depth_map, + const InstructionBatchDimMap& batch_dim_map, + const AliasMap& alias_map, + const ClusterEnvironment& cluster_env, + AutoShardingOption& option, const CallGraph& call_graph, + const HloCostAnalysis& hlo_cost_analysis, + bool trying_multiple_mesh_shapes) { + const Array& device_mesh = cluster_env.device_mesh_; + StrategyMap strategy_map; + // This map stores all of the trimmed strategies due to user specified + // sharding. The key is the instruction id, the value is the strategies. This + // is useful when the operand is forced to use a user sharding, and the op + // doesn't need to strictly follow it. We restore the trimmed strategies in + // this situation. + StableHashMap> pretrimmed_strategy_map; + StrategyGroups strategy_groups; + AssociativeDotPairs associative_dot_pairs; + + const std::vector& instructions = sequence.instructions(); + + // Add penalty for replicated tensors + double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + + cluster_env.AllReduceCost(1, 1)); + + int64_t max_depth = -1; + for (auto iter : depth_map) { + max_depth = std::max(max_depth, iter.second); + } + + absl::flat_hash_map + while_body_args_to_input_tuple; + // Register strategies and their costs for each instruction. + for (size_t instruction_id = 0; instruction_id < instructions.size(); + ++instruction_id) { + const HloInstruction* ins = instructions[instruction_id]; + VLOG(2) << "instruction_id = " << instruction_id << ": " + << ToAdaptiveString(ins); + std::unique_ptr strategy_group; + + HloOpcode opcode = ins->opcode(); + + bool only_allow_divisible; + if (IsEntryComputationInputOrOutput(module, ins)) { + // With IsEntryComputationInputOrOutput(module, ins) == true, entry + // computation's root instruction may still be unevenly sharded because it + // usually "follows" other instruction's sharding. If the instruction it + // follows is an intermediate instruction, it may be able to choose + // unevenly sharded strategiyes. Usually if we constraint input's sharding + // strategies, outputs would be constrained as welll, but if outputs are + // still unevely sharded in some cases, we need to fix the implementation + // in auto sharding. + only_allow_divisible = option.only_allow_divisible_input_output; + } else { + only_allow_divisible = option.only_allow_divisible_intermediate; + } + + bool is_follow_necessary_for_correctness = false; + switch (opcode) { + case HloOpcode::kParameter: { + auto it = while_body_args_to_input_tuple.find(ins); + if (it != while_body_args_to_input_tuple.end()) { + const HloInstruction* while_input_tuple = it->second; + const StrategyGroup* while_input_tuple_strategy_group = + strategy_map.at(while_input_tuple).get(); + + VLOG(5) << "Following while input " << while_input_tuple->name(); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + // We use this following relationship to ensure that the input tuple + // of the while loop, and the parameter of the body of that while + // loop. Therefore, this followinf relationship is necessary for + // correctness, and is not merely an optmization. + is_follow_necessary_for_correctness = true; + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr child_strategies = + MaybeFollowInsStrategyGroup( + while_input_tuple_strategy_group->childs[i].get(), + ins->shape().tuple_shapes().at(i), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + } else { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, + option.allow_replicated_parameters, + /* create_partially_replicated_strategies */ true) + .value(); + } + break; + } + case HloOpcode::kRngBitGenerator: + case HloOpcode::kRng: { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + option.allow_replicated_parameters, + /* create_partially_replicated_strategies */ true) + .value(); + break; + } + case HloOpcode::kConstant: { + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } + case HloOpcode::kScatter: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + // We follow the first operand (the array we're scattering into) + auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); + CHECK(!src_strategy_group->is_tuple); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + HloSharding output_spec = + src_strategy_group->strategies[sid].output_sharding; + std::string name = ToStringSimple(output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); + + std::vector> input_shardings_optional( + {output_spec, std::nullopt, std::nullopt}); + std::pair resharding_costs = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, output_spec, strategy_map, cluster_env, call_graph, + input_shardings_optional); + + for (const auto& sharding_optional : input_shardings_optional) { + CHECK(sharding_optional.has_value()); + } + + strategy_group->strategies.push_back(ShardingStrategy( + {name, output_spec, compute_cost, communication_cost, memory_cost, + std::move(resharding_costs.first), + std::move(resharding_costs.second), input_shardings_optional})); + } + break; + } + case HloOpcode::kGather: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + // Follows the strategy of start_indices (operand 1) + const HloInstruction* indices = ins->operand(1); + const Shape& shape = ins->shape(); + const StrategyGroup* src_strategy_group = + strategy_map.at(indices).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; + for (int32_t index_dim = 0; index_dim < indices->shape().rank(); + index_dim++) { + // Shard on indices dimensions that correspond to output dimensions + // TODO(b/220935014) Shard the last dim of output (model dim) with + // AllGather cost and no follow. + if (index_dim == ins->gather_dimension_numbers().index_vector_dim()) { + continue; + } + for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { + // Split only when the tensor shape is divisible by device + // mesh. + if (device_mesh.dim(j) == 1 || + (only_allow_divisible && + !IsDivisible(shape.dimensions(index_dim), + device_mesh.dim(j)))) { + continue; + } + std::string name = absl::StrCat("S", index_dim, " @ ", j); + + HloSharding output_spec = + Tile(shape, {index_dim}, {j}, device_mesh); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(shape) / output_spec.NumTiles(); + std::optional input_spec = + hlo_sharding_util::ReshapeSharding(shape, indices->shape(), + output_spec); + if (!input_spec.has_value()) { // invalid reshape + continue; + } + std::vector> input_shardings_optional( + {std::nullopt, input_spec}); + std::pair resharding_costs = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, output_spec, strategy_map, cluster_env, call_graph, + input_shardings_optional); + + strategy_group->strategies.push_back(ShardingStrategy( + {name, output_spec, compute_cost, communication_cost, + memory_cost, std::move(resharding_costs.first), + std::move(resharding_costs.second), + input_shardings_optional})); + } + } + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, + /* operands_to_consider_all_strategies_for */ {0}); + break; + } + case HloOpcode::kBroadcast: { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) + .value(); + break; + } + case HloOpcode::kReshape: { + strategy_group = CreateReshapeStrategies( + instruction_id, ins, strategy_map, cluster_env, + only_allow_divisible, replicated_penalty, batch_dim_map, option, + strategy_groups, call_graph); + break; + } + case HloOpcode::kTranspose: + case HloOpcode::kReverse: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + + const HloInstruction* operand = ins->operand(0); + + // Create follow strategies + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; + + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + HloSharding output_spec = Undefined(); + auto input_spec = src_strategy_group->strategies[sid].output_sharding; + if (opcode == HloOpcode::kTranspose) { + output_spec = hlo_sharding_util::TransposeSharding( + input_spec, ins->dimensions()); + } else { + output_spec = hlo_sharding_util::ReverseSharding(input_spec, + ins->dimensions()); + } + + std::string name = ToStringSimple(output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); + std::vector communication_resharding_costs = + CommunicationReshardingCostVector(src_strategy_group, + operand->shape(), input_spec, + cluster_env); + std::vector memory_resharding_costs = + MemoryReshardingCostVector(src_strategy_group, operand->shape(), + input_spec, cluster_env); + strategy_group->strategies.push_back( + ShardingStrategy({name, + output_spec, + compute_cost, + communication_cost, + memory_cost, + {communication_resharding_costs}, + {memory_resharding_costs}, + {input_spec}})); + } + break; + } + case HloOpcode::kPad: + case HloOpcode::kSlice: + case HloOpcode::kConcatenate: // TODO(zhuohan): revisit concatenate + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + int64_t follow_idx; + switch (opcode) { + // TODO(yuemmawang) Re-evaluate the follow_idx choices for the + // following 3. + case HloOpcode::kPad: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kConcatenate: + // Follow the operand according to the follow heuristics + follow_idx = ChooseOperandToFollow(strategy_map, depth_map, + alias_map, max_depth, ins) + .first; + break; + // The following types are better to follow the first operand. + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + follow_idx = 0; + break; + default: + LOG(FATAL) << "Selecting follow index encounters an unhandled " + "instruction type: " + + ins->ToShortString(); + } + // Create follow strategies + const HloInstruction* operand = ins->operand(follow_idx); + StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; + + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); + ++sid) { + std::optional output_spec; + HloSharding input_spec = + src_strategy_group->strategies[sid].output_sharding; + + // Find output shardings. + switch (opcode) { + case HloOpcode::kSlice: { + // When solve_nd_sharding_iteratively is true, in some cases, we + // can have 1D shardings where the total number of tiles is larger + // than the number of elements in the partial mesh (and is + // actually equal to the number of devices in the original + // mesh). Below, we use the correct mesh depending on the number + // of elements in the 1D sharding. + bool is_1d_sharding = + VectorGreaterThanOneElementCount( + input_spec.tile_assignment().dimensions()) == 1; + if (is_1d_sharding && + input_spec.TotalNumTiles() == + cluster_env.device_mesh_1d_.num_elements()) { + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_1d_); + } else if (is_1d_sharding) { + CHECK_EQ(input_spec.TotalNumTiles(), + cluster_env.original_device_mesh_1d_.num_elements()); + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.original_device_mesh_1d_); + } else { + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_); + } + break; + } + case HloOpcode::kPad: + case HloOpcode::kConcatenate: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + output_spec = PropagateDimwiseSharding( + input_spec, operand->shape(), ins->shape()); + break; + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + output_spec = PropagateReduceWindowSharding( + input_spec, operand->shape(), ins->window()); + break; + default: + LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); + } + + // Get a list of input shardings, each corresponds to an operand. + std::vector> input_shardings; + for (int64_t k = 0; k < ins->operand_count(); ++k) { + if (k == follow_idx || + ToString(ins->operand(k)->shape().dimensions()) == + ToString(operand->shape().dimensions())) { + input_shardings.push_back(input_spec); + } else { + input_shardings.push_back(std::nullopt); + } + } + if (!output_spec.has_value()) { + continue; + } + + std::string name = ToStringSimple(*output_spec); + double compute_cost = 0, communication_cost = 0; + double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); + std::pair resharding_costs = + GenerateReshardingCostsAndMissingShardingsForAllOperands( + ins, *output_spec, strategy_map, cluster_env, call_graph, + input_shardings); + + strategy_group->strategies.push_back( + ShardingStrategy({name, + *output_spec, + compute_cost, + communication_cost, + memory_cost, + std::move(resharding_costs.first), + std::move(resharding_costs.second), + {input_spec}})); + } + + if (strategy_group->strategies.empty()) { + strategy_group->following = nullptr; + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + } + break; + } + case HloOpcode::kOptimizationBarrier: { + auto operand_strategies = strategy_map.at(ins->operand(0)).get(); + strategy_group = MaybeFollowInsStrategyGroup( + operand_strategies, ins->shape(), instruction_id, + /* have_memory_cost */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + break; + } + case HloOpcode::kBitcast: { + if (ins->shape() == ins->operand(0)->shape()) { + strategy_group = CreateElementwiseOperatorStrategies( + instruction_id, ins, strategy_map, cluster_env, depth_map, + alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, + associative_dot_pairs); + } else { + strategy_group = CreateReshapeStrategies( + instruction_id, ins, strategy_map, cluster_env, + only_allow_divisible, replicated_penalty, batch_dim_map, option, + strategy_groups, call_graph); + } + break; + } + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRoundNearestEven: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kErf: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kTan: + case HloOpcode::kTanh: + // Binary elementwise operations + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kCompare: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kStochasticConvert: + // Ternary elementwise operations. + case HloOpcode::kSelect: + case HloOpcode::kClamp: { + strategy_group = CreateElementwiseOperatorStrategies( + instruction_id, ins, strategy_map, cluster_env, depth_map, + alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, + associative_dot_pairs); + break; + } + case HloOpcode::kReduce: { + auto strategies_status = FollowReduceStrategy( + ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id, + strategy_map, strategy_groups, cluster_env, + option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); + if (strategies_status.ok()) { + strategy_group = std::move(strategies_status.value()); + } else { + return strategies_status.status(); + } + break; + } + case HloOpcode::kDot: { + TF_RETURN_IF_ERROR(HandleDot(strategy_group, strategy_groups, + strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, + batch_dim_map, option, call_graph)); + + if (option.allow_recompute_heavy_op) { + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, + GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, + sequence, hlo_cost_analysis)); + } + break; + } + case HloOpcode::kConvolution: { + TF_RETURN_IF_ERROR(HandleConv(strategy_group, strategy_groups, + strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, + batch_dim_map, option, call_graph)); + if (option.allow_recompute_heavy_op) { + AddReplicatedStrategy( + ins, ins->shape(), cluster_env, strategy_map, strategy_group, + GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, + sequence, hlo_cost_analysis)); + } + break; + } + case HloOpcode::kRngGetAndUpdateState: { + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } + case HloOpcode::kIota: { + // For an unknown reason, we do not generate partially replicated + // strategies for iota ops. This can be changed if we find that our + // search isn't exhaustive enough for certain ops. + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ false) + .value(); + break; + } + case HloOpcode::kTuple: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->operand_count()); + for (size_t i = 0; i < ins->operand_count(); ++i) { + const HloInstruction* operand = ins->operand(i); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group, operand->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + + if (ins->users().size() == 1 && + ins->users()[0]->opcode() == HloOpcode::kWhile) { + const HloInstruction* while_op = ins->users()[0]; + while_body_args_to_input_tuple[while_op->while_body() + ->parameter_instruction(0)] = ins; + while_body_args_to_input_tuple[while_op->while_condition() + ->parameter_instruction(0)] = ins; + } + + break; + } + case HloOpcode::kGetTupleElement: { + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), + instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + break; + } + case HloOpcode::kCustomCall: { + auto generate_non_following_strategies = + [&](bool only_replicated, + absl::flat_hash_set + operands_to_consider_all_strategies_for = {}) { + if (only_replicated) { + if (ins->shape().IsTuple()) { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve( + ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); + ++i) { + std::unique_ptr child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), + cluster_env, strategy_map, + child_strategies, replicated_penalty); + strategy_group->childs.push_back( + std::move(child_strategies)); + } + } else { + strategy_group = CreateLeafStrategyGroup( + instruction_id, ins, strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, + strategy_map, strategy_group, + replicated_penalty); + } + } else { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) + .value(); + } + }; + + if (IsCustomCallMarker(ins)) { + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + } else if (ins->has_sharding()) { + generate_non_following_strategies(false); + } else if (OutputInputSameShapes(ins)) { + auto* partitioner = + GetCustomCallPartitioner(ins->custom_call_target()); + if (partitioner && partitioner->IsCustomCallShardable(ins)) { + // Follows operand 0's strategies if this custom-call op is + // shardable and has the same input and output sizes. + const HloInstruction* operand = ins->operand(0); + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + } + } else if (IsTopKCustomCall(ins)) { + generate_non_following_strategies(false, {0}); + } else { + // TODO (b/258723035) Handle CustomCall ops for GPUs in a better way. + generate_non_following_strategies(true); + } + break; + } + case HloOpcode::kWhile: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + const StrategyGroup* src_strategy_group = + strategy_map.at(ins->operand(0)).get(); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[i].get(), + ins->shape().tuple_shapes().at(i), instruction_id, + /* have_memory_cost= */ true, strategy_groups, cluster_env, + pretrimmed_strategy_map); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + + break; + } + case HloOpcode::kConditional: + case HloOpcode::kInfeed: + case HloOpcode::kSort: { + strategy_group = + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /* create_replicated_strategies */ true, + /* create_partially_replicated_strategies */ true) + .value(); + break; + } + case HloOpcode::kOutfeed: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty); + break; + } + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: { + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + strategy_groups); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, + strategy_map, child_strategies, 0); + child_strategies->tuple_element_idx = i; + strategy_group->childs.push_back(std::move(child_strategies)); + } + break; + } + case HloOpcode::kSendDone: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, 0); + break; + } + case HloOpcode::kAfterAll: { + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, strategy_groups); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategy_group, replicated_penalty); + break; + } + default: + LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); + } + RemoveDuplicatedStrategy(strategy_group); + if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { + // Finds the sharding strategy that aligns with the given sharding spec + // Do not merge nodes if this one instruction has annotations. + TrimOrGenerateStrategiesBasedOnExistingSharding( + ins->shape(), strategy_group.get(), strategy_map, instructions, + ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, + option.nd_sharding_iteratively_strict_search_space); + } + if (!strategy_group->is_tuple && strategy_group->following) { + if (!LeafVectorsAreConsistent(strategy_group->strategies, + strategy_group->following->strategies)) { + // It confuses the solver if two instructions have different number of + // sharding strategies but share the same ILP variable. The solver would + // run much longer and/or return infeasible solutions. So if two + // strategies are inconsistent, we unfollow them. + CHECK(!is_follow_necessary_for_correctness) + << "Reverting a following decision that is necessary for " + "correctness. Please report this as a bug."; + strategy_group->following = nullptr; + } + } else if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { + if (strategy_group->childs.at(i)->following && + !LeafVectorsAreConsistent( + strategy_group->childs.at(i)->strategies, + strategy_group->childs.at(i)->following->strategies)) { + CHECK(!is_follow_necessary_for_correctness) + << "Reverting a following decision that is necessary for " + "correctness. Please report this as a bug."; + strategy_group->childs.at(i)->following = nullptr; + } + } + } + RemoveInvalidShardingsWithShapes( + ins->shape(), strategy_group.get(), + /* instruction_has_user_sharding */ ins->has_sharding()); + + if (instruction_execution_counts.contains(ins)) { + ScaleCostsWithExecutionCounts(strategy_group.get(), + instruction_execution_counts.at(ins)); + } else { + VLOG(5) << "No execution count available for " << ins->name(); + } + XLA_VLOG_LINES(2, + absl::StrCat("strategies:\n", strategy_group->ToString())); + + // Debug options: forcibly set the strategy of some instructions. + if (option.force_strategy) { + std::vector inst_indices = option.force_strategy_inst_indices; + std::vector stra_names = option.force_strategy_stra_names; + CHECK_EQ(inst_indices.size(), stra_names.size()); + auto it = absl::c_find(inst_indices, strategy_group->node_idx); + if (it != inst_indices.end()) { + CHECK(!strategy_group->is_tuple); + std::vector new_strategies; + int64_t idx = it - inst_indices.begin(); + for (const auto& stra : strategy_group->strategies) { + if (stra.name == stra_names[idx]) { + new_strategies.push_back(stra); + } + } + strategy_group->strategies = std::move(new_strategies); + } + } + + // When trying out multiple mesh shapes in the presence of user specified + // sharding (as in + // AutoShardingTest.AutoShardingKeepUserShardingInputOutput), there may be a + // situation when we cannot generate any shardings for an instruction when + // the mesh shape we're trying does not match with the mesh shape used in + // user specified shardings. So we disable the check in that situation. + if (!trying_multiple_mesh_shapes) { + CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) + << ins->ToString() << " does not have any valid strategies."; + } else if (!(strategy_group->is_tuple || + !strategy_group->strategies.empty())) { + return Status(absl::StatusCode::kFailedPrecondition, + "Could not generate any shardings for an instruction due " + "to mismatched mesh shapes."); + } + // Checks the shape of resharding_costs is valid. It will check fail if the + // shape is not as expected. + // CheckReshardingCostsShape(strategies.get()); + CheckMemoryCosts(strategy_group.get(), ins->shape()); + strategy_map[ins] = std::move(strategy_group); + } // end of for loop + + // If gradient accumulation is used, adjust the cost of all-reduce for + // gradient synchronization. + if (option.grad_acc_num_micro_batches > 1) { + // find gradient-computation instructions + std::vector grad_insts = + GetGradientComputationInstructions(instructions); + for (const HloInstruction* inst : grad_insts) { + StrategyGroup* stra_vector = strategy_map[inst].get(); + CHECK(!stra_vector->is_tuple); + + for (auto& stra : stra_vector->strategies) { + if (absl::StrContains(stra.name, "allreduce")) { + stra.communication_cost /= option.grad_acc_num_micro_batches; + } + } + } + } + + return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), + std::move(associative_dot_pairs)); +} + +// NOLINTEND + +} // namespace spmd +} // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 24fa988df5729..5233a71d1e58e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,6 +56,8 @@ using AliasMap = StableHashMap; using ReshardingCache = StableHashMap>>; +// Resharding costs for each operand +using ReshardingCosts = std::vector>; // One sharding strategy struct ShardingStrategy { @@ -68,7 +70,8 @@ struct ShardingStrategy { // i-th operand's j-th strategy to this strategy. // If there is only one tuple operand,resharding_costs[i][j] is the resharding // cost from i-th tuple element's j-th strategy. - std::vector> resharding_costs; + ReshardingCosts communication_resharding_costs; + ReshardingCosts memory_resharding_costs; // Optional: the required shardings of operands. // This is used to guide the SPMD partitioner. std::vector> input_shardings; @@ -78,14 +81,25 @@ struct ShardingStrategy { } std::string ToStringLong() const { - std::vector resharding_vector_strings; - resharding_vector_strings.reserve(resharding_costs.size()); - for (const auto& v : resharding_costs) { - resharding_vector_strings.push_back( + std::vector communication_resharding_vector_strings; + communication_resharding_vector_strings.reserve( + communication_resharding_costs.size()); + for (const auto& v : communication_resharding_costs) { + communication_resharding_vector_strings.push_back( absl::StrCat("[", absl::StrJoin(v, ", "), "]")); } - std::string resharding_cost_str = - absl::StrCat("{", absl::StrJoin(resharding_vector_strings, ", "), "}"); + std::string communication_resharding_cost_str = absl::StrCat( + "{", absl::StrJoin(communication_resharding_vector_strings, ", "), "}"); + + std::vector memory_resharding_vector_strings; + memory_resharding_vector_strings.reserve(memory_resharding_costs.size()); + for (const auto& v : memory_resharding_costs) { + memory_resharding_vector_strings.push_back( + absl::StrCat("[", absl::StrJoin(v, ", "), "]")); + } + std::string memory_resharding_cost_str = absl::StrCat( + "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}"); + std::string input_sharding_str = "{"; for (const auto& s : input_shardings) { if (!s.has_value()) { @@ -105,12 +119,13 @@ struct ShardingStrategy { } } input_sharding_str += "}\n"; - return absl::StrCat(name, ", ", output_sharding.ToString(), - ", compute_cost=", compute_cost, - ", communication_cost=", communication_cost, - ", memory_cost=", memory_cost, - ", resharding_costs=", resharding_cost_str, - ", input_shardings=", input_sharding_str); + return absl::StrCat( + name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost, + ", communication_cost=", communication_cost, + ", memory_cost=", memory_cost, + ", communication_resharding_costs=", communication_resharding_cost_str, + ", memory_resharding_costs=", memory_resharding_cost_str, + ", input_shardings=", input_sharding_str); } bool operator==(const ShardingStrategy& other) const { @@ -118,7 +133,9 @@ struct ShardingStrategy { compute_cost == other.compute_cost && communication_cost == other.communication_cost && memory_cost == other.memory_cost && - resharding_costs == other.resharding_costs && + communication_resharding_costs == + other.communication_resharding_costs && + memory_resharding_costs == other.memory_resharding_costs && input_shardings == other.input_shardings; } }; @@ -130,6 +147,10 @@ using EdgeStrategyIdx = int64_t; // An index into an edge's strategy vector. using LivenessIdx = int64_t; // An index into the liveness vector. using AliasIdx = int64_t; // An index into the alias vector. +// Various classes needed to support strategy shaving. +using NodeStrategy = std::pair; +using NodeStrategies = StableHashSet; + // A group of strategy choices (along with details like index values) // for each instruction. struct StrategyGroup { @@ -160,6 +181,10 @@ struct StrategyGroup { absl::StrAppend(&str, indent, "node_idx: ", node_idx, "\n"); absl::StrAppend(&str, indent, "instruction id: ", instruction_id, "\n"); absl::StrAppend(&str, indent, "is_tuple: ", is_tuple, "\n"); + if (tuple_element_idx.has_value()) { + absl::StrAppend(&str, indent, + "index in producer inst.: ", *tuple_element_idx, "\n"); + } if (following != nullptr) { absl::StrAppend(&str, indent, "following instruction: ", following->instruction_id, @@ -198,6 +223,8 @@ struct StrategyGroup { using LivenessSet = std::vector>; // A liveness set using node indices instead of HLO values. using LivenessNodeSet = std::vector>; +// A liveness set using edge indices instead of HLO values. +using LivenessEdgeSet = std::vector>; // Map an instruction to its strategy group. using StrategyMap = StableHashMap>; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 88700280b6b3c..9ce627b58867b 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -16,20 +16,33 @@ limitations under the License. #include #include #include +#include #include +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_parser.h" +#include "xla/service/hlo_value.h" +#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -39,11 +52,21 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace spmd { namespace { +using ::testing::Each; +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::FieldsAre; +using ::testing::IsEmpty; +using ::testing::IsFalse; +using ::testing::IsTrue; +using ::testing::Pair; +using ::testing::ResultOf; +using ::testing::UnorderedElementsAre; using DummyAutoShardingTest = HloTestBase; TEST_F(DummyAutoShardingTest, ReplicatedShardingDummy) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %elementwise { %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) @@ -52,8 +75,8 @@ ENTRY %elementwise { ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); TF_ASSERT_OK_AND_ASSIGN(bool changed, DummyAutoSharding().Run(module.get())); EXPECT_TRUE(changed); auto* instruction = FindInstruction(module.get(), "param0"); @@ -63,14 +86,14 @@ ENTRY %elementwise { class AutoShardingTest : public HloTestBase { protected: - const char* const dot_hlo_string_ = R"( + const absl::string_view kDotHloString = R"( HloModule module ENTRY matmul { parameter.1 = f32[32,64]{1,0} parameter(0) parameter.2 = f32[64,128]{1,0} parameter(1) ROOT root = f32[32,128]{1,0} dot(parameter.1, parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - const char* const add_hlo_string_ = R"( + const absl::string_view kAddHloString = R"( HloModule module ENTRY %elementwise { %param0 = f32[16,32,64]{2,1,0} parameter(0) @@ -80,8 +103,8 @@ ENTRY %elementwise { void RunMatMulAutoShardingWithOptions( AutoShardingOption option, size_t expected_num_tiles, size_t expected_sharded_dimensions = 1) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(dot_hlo_string_)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kDotHloString)); RunAutoShardingWithOptions(module.get(), option, expected_num_tiles, expected_sharded_dimensions); } @@ -89,8 +112,8 @@ ENTRY %elementwise { void RunAddAutoShardingWithOptions(AutoShardingOption option, size_t expected_num_tiles, size_t expected_sharded_dimensions = 1) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(add_hlo_string_)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kAddHloString)); RunAutoShardingWithOptions(module.get(), option, expected_num_tiles, expected_sharded_dimensions); } @@ -111,8 +134,8 @@ ENTRY %elementwise { } void RunMatMulAutoShardingWithOptionsExpectFail(AutoShardingOption option) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(dot_hlo_string_)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kDotHloString)); RunAutoShardingWithOptionsExpectFail(module.get(), option); } @@ -124,8 +147,8 @@ ENTRY %elementwise { void RunMatMulAutoShardingWithOptionsNoDeviceIds( AutoShardingOption option, std::vector expected_tile, bool expeted_last_dim_replicate = false) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(dot_hlo_string_)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kDotHloString)); RunAutoShardingWithOptionsNoDeviceIds(module.get(), option, expected_tile, expeted_last_dim_replicate); } @@ -142,12 +165,74 @@ ENTRY %elementwise { EXPECT_EQ(root->sharding().ReplicateOnLastTileDim(), expeted_last_dim_replicate); EXPECT_THAT(root->sharding().tile_assignment().dimensions(), - ::testing::ElementsAreArray(expected_tile)); + ElementsAreArray(expected_tile)); } }; +TEST_F(AutoShardingTest, MemoryBudgetTest) { + auto compute_memory_budget_lower_bound = + [](const HloModule& module, int64_t num_devices, + const absl::flat_hash_map>& + preserved_shardings = {}) -> absl::StatusOr { + auto size_fn = [](const BufferValue& buffer) { + return spmd::GetBytes(buffer.shape()); + }; + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(&module, size_fn, + ComputationSchedulerToModuleScheduler( + DFSMemoryScheduler), + /* execution_threads */ {})); + const HloComputation* entry_computation = module.entry_computation(); + std::unique_ptr alias_analysis = + HloAliasAnalysis::Run(&module).value(); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, *alias_analysis, entry_computation)); + absl::flat_hash_map& + buffer_live_ranges = hlo_live_range->buffer_live_ranges(); + spmd::LivenessSet liveness_set(hlo_live_range->schedule_end_time() + 1); + for (const auto& [hlo_value, live_range] : buffer_live_ranges) { + for (spmd::LivenessIdx i = live_range.start; i <= live_range.end; ++i) { + liveness_set[i].push_back(hlo_value); + } + } + return spmd::MemoryBudgetLowerBound(module, liveness_set, *alias_analysis, + num_devices, preserved_shardings); + }; + + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[16384,16384]{0,1} parameter(0) + %param1 = f32[16384,16384]{0,1} parameter(1) + %add = f32[16384,16384]{0,1} add(%param0, %param1) + ROOT %copy = f32[16384,16384]{0,1} copy(%add) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN(HloSharding partial_sharding, + ParseSharding("{devices=[64,1]<=[64]}")); + TF_ASSERT_OK_AND_ASSIGN( + int64_t partial_mesh_64x1_budget_lower_bound, + compute_memory_budget_lower_bound(*module, /* num_devices */ 64)); + for (HloInstruction* ins : module->entry_computation()->instructions()) { + ins->set_sharding(partial_sharding); + } + TF_ASSERT_OK_AND_ASSIGN( + int64_t full_mesh_64x8_budget_lower_bound, + compute_memory_budget_lower_bound(*module, /* num_devices */ 512)); + CHECK_LT(full_mesh_64x8_budget_lower_bound, + partial_mesh_64x1_budget_lower_bound) + << "The memory budget lower bound per device should be lower with a " + "larger number of devices. Instead, the bound was " + << partial_mesh_64x1_budget_lower_bound << " bytes for 64 devices and " + << full_mesh_64x8_budget_lower_bound << " bytes for 512 devices."; +} + TEST_F(AutoShardingTest, DISABLED_ElementWiseOperator) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %elementwise { %param0 = f32[128,128]{0,1} parameter(0) @@ -156,8 +241,8 @@ ENTRY %elementwise { ROOT %copy = f32[128,128]{0,1} copy(%add) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -172,8 +257,154 @@ ENTRY %elementwise { EXPECT_THAT(instruction, op::Sharding("{devices=[2,2]0,2,1,3}")); } +TEST_F(AutoShardingTest, Unsupported3DShardingTest) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[32,32,32,32] parameter(0) + %param1 = f32[32,32,32,32] parameter(1) + %add = f32[32,32,32,32] add(%param0, %param1), sharding={devices=[2,2,1,2]<=[8]} + ROOT %copy = f32[32,32,32,32] copy(%add) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + // The case of a fleet HLO when run with try_multiple_mesh_shapes = true + option.device_mesh_shape = {2, 4}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + EXPECT_DEATH(auto status = AutoSharding(option).Run(module.get()), + ".*too many axes.*"); +} + +TEST_F(AutoShardingTest, NDIterativeSolveTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + param = s32[512,3084]{1,0} parameter(0), sharding={devices=[256,1]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,8,9,10,11,12,13,14,15,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255} + sharding_call = s32[512,3084]{1,0} custom-call(param), custom_call_target="Sharding", sharding={devices=[256,1]<=[256]} + ROOT slice = s32[512,2048]{1,0} slice(sharding_call), slice={[0:512], [0:2048]} +})"; + + AutoShardingOption option; + option.enable = true; + option.solve_nd_sharding_iteratively = true; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.device_mesh_shape = {16, 16}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + HloInstruction* slice = FindInstruction(module.get(), "slice"); + EXPECT_NE(slice, nullptr); + EXPECT_THAT(slice, op::Sharding("{devices=[256,1]<=[256]}")); +} + +TEST_F(AutoShardingTest, SliceDeviceMeshTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + param = s32[512,3084]{1,0} parameter(0) + slice = s32[512,2048]{1,0} slice(param), slice={[0:512], [0:2048]} + ROOT copy = s32[512,2048]{1,0} copy(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, AutoSharding(/* option */ AutoShardingOption{ + .enable = true, + .solve_nd_sharding_iteratively = true, + .device_mesh_shape = {2, 2}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* slice = FindInstruction(module.get(), "slice"); + ASSERT_NE(slice, nullptr); + EXPECT_THAT( + slice, + AnyOf(op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"))); +} + +TEST_F(AutoShardingTest, SliceMixedUserShardingTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + param = s32[512,3084]{1,0} parameter(0), sharding={devices=[4,1]0,2,1,3} + slice = s32[512,2048]{1,0} slice(param), slice={[0:512], [0:2048]} + ROOT copy = s32[512,2048]{1,0} copy(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ { + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .solve_nd_sharding_iteratively = true, + .device_mesh_shape = {2, 2}, + .device_mesh_ids = {0, 2, 1, 3}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Each(ResultOf( + [](const HloInstruction* ins) { return ins->has_sharding(); }, + IsTrue()))); + EXPECT_THAT(instructions, Each(op::Sharding("{devices=[4,1]0,2,1,3}"))); +} + +TEST_F(AutoShardingTest, UserShardingTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + concatenate.76306 = bf16[1,4096,8,256]{3,2,1,0} parameter(0) + constant.15158 = bf16[] constant(0) + pad.70 = bf16[1,4352,8,256]{3,2,1,0} pad(concatenate.76306, constant.15158), padding=0_0x0_256x0_0x0_0, sharding={devices=[1,1,128,1]<=[128]} + ROOT copy.45 = bf16[1,4352,8,256]{3,2,1,0} copy(pad.70) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ AutoShardingOption{ + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .device_mesh_shape = {128, 1}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); +} + TEST_F(AutoShardingTest, RngBitGeneratorArrayInput) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule rng_bit_generator ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { @@ -181,8 +412,8 @@ ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { ROOT %rand = (u64[2]{0}, u32[16,16]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -198,7 +429,7 @@ ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { } TEST_F(AutoShardingTest, RngBitGeneratorTupleInput) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule rng_bit_generator ENTRY %RngBitGenerator { @@ -208,8 +439,8 @@ ENTRY %RngBitGenerator { ROOT rng-bit-generator = u32[100,100]{1,0:T(8,128)} rng-bit-generator(tuple.3), algorithm=rng_default })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -228,7 +459,7 @@ ENTRY %RngBitGenerator { } TEST_F(AutoShardingTest, DotLHSTwoNonContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,64]{2,1,0} parameter(0) @@ -237,8 +468,8 @@ ENTRY %entry { ROOT %copy = f32[4,256,32]{2,1,0} copy(%dot) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -257,22 +488,22 @@ ENTRY %entry { EXPECT_THAT( std::make_tuple(param0, param1, dot), AnyOf( - ::testing::FieldsAre( + FieldsAre( op::Sharding( "{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,1,2,3}")), - ::testing::FieldsAre( + FieldsAre( op::Sharding( "{devices=[1,2,1,2]0,2,1,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,2,1,3}")), - ::testing::FieldsAre( + FieldsAre( op::Sharding( "{devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}"), op::Sharding("{devices=[2,1,2]0,1,2,3}")), - ::testing::FieldsAre( + FieldsAre( op::Sharding( "{devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}"), op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"), @@ -280,7 +511,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DotRHSTwoNonContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,32]{2,1,0} parameter(0) @@ -289,8 +520,8 @@ ENTRY %entry { ROOT %copy = f32[32,4,8]{2,1,0} copy(%dot) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -301,22 +532,38 @@ ENTRY %entry { VLOG(2) << module->ToString(); EXPECT_TRUE(changed); auto* param0 = FindInstruction(module.get(), "param0"); - ASSERT_NE(param0, nullptr); - EXPECT_THAT( - param0, - op::Sharding("{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}")); auto* param1 = FindInstruction(module.get(), "param1"); - ASSERT_NE(param1, nullptr); - EXPECT_THAT( - param1, - op::Sharding("{devices=[1,1,2,1,2]0,2,1,3 last_tile_dim_replicate}")); auto* dot = FindInstruction(module.get(), "dot"); + ASSERT_NE(param0, nullptr); + ASSERT_NE(param1, nullptr); ASSERT_NE(dot, nullptr); - EXPECT_THAT(dot, op::Sharding("{devices=[2,2,1]0,1,2,3}")); + EXPECT_THAT( + std::make_tuple(param0, param1, dot), + AnyOf( + FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,2,1,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,2,1]0,1,2,3}")), + FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,2]0,1,2,3}")), + FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,2]0,2,1,3}")), + FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,2,1]0,2,1,3}")))); } TEST_F(AutoShardingTest, DotTwoContractingDims) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = f32[4,256,64]{2,1,0} parameter(0) @@ -325,8 +572,8 @@ ENTRY %entry { ROOT %copy = f32[64,32]{1,0} copy(%dot) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -344,22 +591,20 @@ ENTRY %entry { ASSERT_NE(dot, nullptr); EXPECT_THAT( std::make_tuple(param0, param1, dot), - AnyOf(::testing::FieldsAre( - op::Sharding( - "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), - op::Sharding( - "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), - op::Sharding("{devices=[2,2]0,2,1,3}")), - ::testing::FieldsAre( - op::Sharding( - "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), - op::Sharding( - "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), - op::Sharding("{devices=[2,2]0,1,2,3}")))); + AnyOf(FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,2]0,2,1,3}")), + FieldsAre(op::Sharding( + "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding( + "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,2]0,1,2,3}")))); } TEST_F(AutoShardingTest, TwoMatmul) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY twomatmul { parameter.1 = f32[64,64]{1,0} parameter(0) @@ -369,11 +614,11 @@ ENTRY twomatmul { ROOT dot.5 = f32[64,64]{1,0} dot(dot.4, parameter.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; - option.allow_replicated_strategy_for_dot_and_conv = false; + option.allow_recompute_heavy_op = false; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; @@ -402,9 +647,9 @@ ENTRY twomatmul { op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}")); // Test with replicated strategies on for dot - TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(kHloString)); option.enable = true; - option.allow_replicated_strategy_for_dot_and_conv = true; + option.allow_recompute_heavy_op = true; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; @@ -413,25 +658,32 @@ ENTRY twomatmul { VLOG(10) << module->ToString(); EXPECT_TRUE(changed); param1 = FindInstruction(module.get(), "parameter.1"); - ASSERT_NE(param1, nullptr); - EXPECT_THAT(param1, op::Sharding("{replicated}")); param2 = FindInstruction(module.get(), "parameter.2"); - ASSERT_NE(param2, nullptr); - EXPECT_THAT(param2, op::Sharding("{replicated}")); param3 = FindInstruction(module.get(), "parameter.3"); - ASSERT_NE(param3, nullptr); - EXPECT_THAT(param3, - op::Sharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")); dot4 = FindInstruction(module.get(), "dot.4"); - ASSERT_NE(dot4, nullptr); - EXPECT_THAT(dot4, op::Sharding("{replicated}")); dot5 = FindInstruction(module.get(), "dot.5"); + ASSERT_NE(param1, nullptr); + ASSERT_NE(param2, nullptr); + ASSERT_NE(param3, nullptr); + ASSERT_NE(dot4, nullptr); ASSERT_NE(dot5, nullptr); - EXPECT_THAT(dot5, op::Sharding("{devices=[2,2]0,1,2,3}")); + EXPECT_THAT( + std::make_tuple(param1, param2, param3, dot4, dot5), + AnyOf( + FieldsAre( + op::Sharding("{replicated}"), op::Sharding("{replicated}"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{replicated}"), + op::Sharding("{devices=[2,2]0,2,1,3}")), + FieldsAre( + op::Sharding("{replicated}"), op::Sharding("{replicated}"), + op::Sharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{replicated}"), + op::Sharding("{devices=[2,2]0,1,2,3}")))); } TEST_F(AutoShardingTest, ProcessCustomCallShardings) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { @@ -443,8 +695,8 @@ ENTRY %entry { %copy.2 = f32[6,3] copy(%annotate) ROOT %copy.3 = f32[6,3] copy(%copy.2) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -460,8 +712,8 @@ ENTRY %entry { op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}")); } -TEST_F(AutoShardingTest, RemoveShardingAnnotationKeepAll) { - const char* const hlo_string = R"( +TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationKeepAll) { + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -470,94 +722,270 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={devices=[2,2]0,1,2,3} ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; // Keep all user shardings option.preserve_shardings = AutoShardingOption::PreserveShardingsType::kKeepAllShardings; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, AutoShardingImplementation(option).RemoveShardingAnnotation( - module.get())); + std::pair>, bool> + saved_shardings_result = + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), /* replicated_small_tensors */ {}, + /* execution_threads */ {}); + absl::flat_hash_map> saved_shardings = + saved_shardings_result.first; + bool changed = saved_shardings_result.second; EXPECT_FALSE(changed); - for (HloComputation* computation : module->computations()) { - for (HloInstruction* ins : computation->instructions()) { - EXPECT_TRUE(ins->has_sharding()); - } - } + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Each(ResultOf( + [](const HloInstruction* ins) { return ins->has_sharding(); }, + IsTrue()))); + + auto verified_parse_sharding = [](const absl::string_view sharding_str) { + absl::StatusOr sharding = ParseSharding(sharding_str); + CHECK_OK(sharding); + return *sharding; + }; + + EXPECT_THAT( + saved_shardings, + UnorderedElementsAre( + Pair("param0", + ElementsAre(verified_parse_sharding( + "{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}"))), + Pair("param1", + ElementsAre(verified_parse_sharding( + "{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}"))), + Pair("dot", + ElementsAre(verified_parse_sharding("{devices=[2,2]0,1,2,3}"))), + Pair("copy", ElementsAre(verified_parse_sharding( + "{devices=[2,2]0,1,2,3}"))))); } -TEST_F(AutoShardingTest, RemoveShardingAnnotationRemoveIntermediate) { - const char* const hlo_string = R"( +TEST_F(AutoShardingTest, + SaveAndRemoveShardingAnnotationKeepInputOutputSmallTensor) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { + %param0 = f32[4,256,64]{2,1,0} parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %param1 = f32[4,256,32]{2,1,0} parameter(1), sharding={devices=[2,2,1]0,1,2,3} + %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={replicated} + ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + // Keep all user shardings + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepInputOutputShardings; + std::pair>, bool> + saved_shardings_result = + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), /* replicated_small_tensors */ {"dot"}, + /* execution_threads */ {}); + absl::flat_hash_map> saved_shardings = + saved_shardings_result.first; + bool changed = saved_shardings_result.second; + EXPECT_FALSE(changed); + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Each(ResultOf( + [](const HloInstruction* ins) { return ins->has_sharding(); }, + IsTrue()))); + + auto verified_parse_sharding = [](const absl::string_view sharding_str) { + absl::StatusOr sharding = ParseSharding(sharding_str); + CHECK_OK(sharding); + return *sharding; + }; + + EXPECT_THAT( + saved_shardings, + UnorderedElementsAre( + Pair("param0", ElementsAre(verified_parse_sharding( + "{devices=[2,2,1]0,1,2,3}"))), + Pair("param1", ElementsAre(verified_parse_sharding( + "{devices=[2,2,1]0,1,2,3}"))), + Pair("dot", ElementsAre(verified_parse_sharding("{replicated}"))), + Pair("copy", ElementsAre(verified_parse_sharding( + "{devices=[2,2]0,1,2,3}"))))); +} + +TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationKeepInputOutput) { + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { %param0 = f32[4,256,64]{2,1,0} parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} %param1 = f32[4,256,32]{2,1,0} parameter(1), sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate} - %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={devices=[2,2]0,1,2,3} + %param0_copy = f32[4,256,64]{2,1,0} copy(param0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} + %param1_copy = f32[4,256,32]{2,1,0} copy(param1), sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate} + %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0_copy, f32[4,256,32]{2,1,0} %param1_copy), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={devices=[2,2]0,1,2,3} ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.preserve_shardings = AutoShardingOption::PreserveShardingsType::kKeepInputOutputShardings; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, AutoShardingImplementation(option).RemoveShardingAnnotation( - module.get())); + std::pair>, bool> + saved_shardings_result = + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), /* replicated_small_tensors */ {}, + /* execution_threads */ {}); + absl::flat_hash_map> saved_shardings = + saved_shardings_result.first; + bool changed = saved_shardings_result.second; EXPECT_TRUE(changed); + // Dot does not have shardings anymore. - auto* dot = FindInstruction(module.get(), "dot"); + const HloInstruction* dot = FindInstruction(module.get(), "dot"); ASSERT_NE(dot, nullptr); EXPECT_FALSE(dot->has_sharding()); - // params and copy still have shardings. - auto* param0 = FindInstruction(module.get(), "param0"); + + // params and copies still have shardings. + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); ASSERT_NE(param0, nullptr); EXPECT_TRUE(param0->has_sharding()); EXPECT_THAT( param0, op::Sharding("{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}")); - auto* param1 = FindInstruction(module.get(), "param1"); + + const HloInstruction* param0_copy = + FindInstruction(module.get(), "param0_copy"); + ASSERT_NE(param0_copy, nullptr); + EXPECT_TRUE(param0_copy->has_sharding()); + EXPECT_THAT( + param0_copy, + op::Sharding("{devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}")); + + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); ASSERT_NE(param1, nullptr); EXPECT_TRUE(param1->has_sharding()); EXPECT_THAT( param1, op::Sharding("{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}")); - auto* copy = FindInstruction(module.get(), "copy"); + + const HloInstruction* param1_copy = + FindInstruction(module.get(), "param1_copy"); + ASSERT_NE(param1_copy, nullptr); + EXPECT_TRUE(param1_copy->has_sharding()); + EXPECT_THAT( + param1_copy, + op::Sharding("{devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}")); + + // Root still has sharding + const HloInstruction* copy = FindInstruction(module.get(), "copy"); ASSERT_NE(copy, nullptr); EXPECT_TRUE(copy->has_sharding()); EXPECT_THAT(copy, op::Sharding("{devices=[2,2]0,1,2,3}")); + + EXPECT_THAT( + saved_shardings, + UnorderedElementsAre(Pair("param0", ElementsAre(param0->sharding())), + Pair("param0_copy", ElementsAre(param0->sharding())), + Pair("param1", ElementsAre(param1->sharding())), + Pair("param1_copy", ElementsAre(param1->sharding())), + Pair("copy", ElementsAre(copy->sharding())))); } -TEST_F(AutoShardingTest, RemoveShardingAnnotationRemoveAll) { - const char* const hlo_string = R"( +TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationRemoveAll) { + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { - %param0 = f32[4,256,64]{2,1,0} parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} - %param1 = f32[4,256,32]{2,1,0} parameter(1), sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate} - %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={devices=[2,2]0,1,2,3} - ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} + %param0 = f32[4,256,64]{2,1,0} parameter(0), + sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate} %param1 = + f32[4,256,32]{2,1,0} parameter(1), sharding={devices=[1,1,2,2]0,2,1,3 + last_tile_dim_replicate} %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} + %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, + rhs_contracting_dims={0,1}, sharding={devices=[2,2]0,1,2,3} ROOT %copy = + f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; // Remove all user shardings option.preserve_shardings = AutoShardingOption::PreserveShardingsType::kRemoveAllShardings; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, AutoShardingImplementation(option).RemoveShardingAnnotation( - module.get())); + std::pair>, bool> + saved_shardings_result = + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), /* replicated_small_tensors */ {}, + /* execution_threads */ {}); + absl::flat_hash_map> saved_shardings = + saved_shardings_result.first; + bool changed = saved_shardings_result.second; EXPECT_TRUE(changed); - for (HloComputation* computation : module->computations()) { - for (HloInstruction* ins : computation->instructions()) { - EXPECT_FALSE(ins->has_sharding()); - } - } + EXPECT_THAT(saved_shardings, IsEmpty()); + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Each(ResultOf( + [](const HloInstruction* ins) { return ins->has_sharding(); }, + IsFalse()))); +} + +TEST_F(AutoShardingTest, SaveAndRemoveShardingAnnotationRemoveAllSmallTensor) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { + %param0 = f32[4,256,64]{2,1,0} parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %param1 = f32[4,256,32]{2,1,0} parameter(1), sharding={devices=[2,2,1]0,1,2,3} + %dot = f32[64,32]{1,0} dot(f32[4,256,64]{2,1,0} %param0, f32[4,256,32]{2,1,0} %param1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, sharding={replicated} + ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + // Remove all user shardings + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kRemoveAllShardings; + std::pair>, bool> + saved_shardings_result = + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), /* replicated_small_tensors */ {"dot", "copy"}, + /* execution_threads */ {}); + absl::flat_hash_map> saved_shardings = + saved_shardings_result.first; + bool changed = saved_shardings_result.second; + EXPECT_TRUE(changed); + + // params have no shardings. + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); + ASSERT_NE(param0, nullptr); + EXPECT_FALSE(param0->has_sharding()); + + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + ASSERT_NE(param1, nullptr); + EXPECT_FALSE(param1->has_sharding()); + + // Dot and copy have shardings as they are specified as replicated small + // tensors. + const HloInstruction* dot = FindInstruction(module.get(), "dot"); + ASSERT_NE(dot, nullptr); + EXPECT_TRUE(dot->has_sharding()); + EXPECT_TRUE(dot->sharding().IsReplicated()); + + const HloInstruction* copy = FindInstruction(module.get(), "copy"); + ASSERT_NE(copy, nullptr); + EXPECT_TRUE(copy->has_sharding()); + EXPECT_TRUE(copy->sharding().IsReplicated()); + + EXPECT_THAT( + saved_shardings, + UnorderedElementsAre(Pair("dot", ElementsAre(dot->sharding())), + Pair("copy", ElementsAre(copy->sharding())))); } TEST_F(AutoShardingTest, TupleReduceTest) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module %func (lhs_value: f32[], lhs_index: s32[], rhs_value: f32[], rhs_index: s32[]) -> (f32[], s32[]) { %lhs_value = f32[] parameter(0) @@ -580,8 +1008,8 @@ ENTRY %entry { %constant.b = s32[] constant(0) %reduce = (f32[1,16]{1,0}, s32[1,16]{1,0}) reduce(f32[1,16,40]{2,1,0} %param0, s32[1,16,40]{2,1,0} %iota, f32[] %constant.a, s32[] %constant.b), dimensions={2}, to_apply=%func })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -603,7 +1031,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ReduceTest) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module %func (x: f32[], y: f32[]) -> f32[] { @@ -617,8 +1045,8 @@ ENTRY %entry { %param1 = f32[] parameter(1) %reduce = f32[1,16]{1,0} reduce(f32[1,16,128]{2,1,0} %param0, f32[] %param1), dimensions={2}, to_apply=%func })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -646,7 +1074,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ScatterTest2D) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module region { @@ -661,8 +1089,8 @@ ENTRY %Scatter { ROOT scatter = s32[4,128]{1,0} scatter(call, clamp, broadcast), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=region } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -685,7 +1113,7 @@ ENTRY %Scatter { } TEST_F(AutoShardingTest, ScatterTest3D) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module region { @@ -700,8 +1128,8 @@ ENTRY %Scatter { ROOT scatter = f32[4,128,128]{2,1,0} scatter(call, clamp, multiply), update_window_dims={1,2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0,1,2}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=region } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -728,15 +1156,15 @@ ENTRY %Scatter { } TEST_F(AutoShardingTest, GatherTest) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = f32[256,1024]{0,1} parameter(0) %param1 = s32[128,512,1]{2,1,0} parameter(1) ROOT %gather = f32[128,512,1024]{2,1,0} gather(f32[256,1024]{0,1} %param0, s32[128,512,1]{2,1,0} %param1), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,1024} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -749,21 +1177,25 @@ ENTRY %entry { ASSERT_NE(gather, nullptr); EXPECT_THAT( gather, - op::Sharding("{devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}")); + AnyOf( + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[1,2,1,2]0,2,1,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}"), + op::Sharding("{devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}"))); auto gather_sharding = gather->sharding(); TF_EXPECT_OK(gather_sharding.Validate(gather->shape(), 4)); } TEST_F(AutoShardingTest, GatherTestNoReshard) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { get-tuple-element = s8[1000,128]{1,0} parameter(0) reshape = s32[8,1,1]{2,1,0} parameter(1) gather = s8[8,1,128]{2,1,0} gather(get-tuple-element, reshape), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,128} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {1, 1, 8}; @@ -778,14 +1210,15 @@ ENTRY %entry { ASSERT_NE(gather, nullptr); ASSERT_NE(param0, nullptr); EXPECT_THAT(gather, op::Sharding("{devices=[8,1,1]0,1,2,3,4,5,6,7}")); - EXPECT_THAT(param0, op::Sharding("{devices=[8,1]0,1,2,3,4,5,6,7}")); + EXPECT_THAT(param0, AnyOf(op::Sharding("{devices=[1,8]0,1,2,3,4,5,6,7}"), + op::Sharding("{devices=[8,1]0,1,2,3,4,5,6,7}"))); TF_EXPECT_OK(gather->sharding().Validate(gather->shape(), 8)); // Ensure no resharding op is created for operand 0 of gather in this case. EXPECT_EQ(param0, gather->operand(0)); } TEST_F(AutoShardingTest, GatherConvTest) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = f32[1024,1024]{0,1} parameter(0) @@ -798,8 +1231,8 @@ ENTRY %entry { ROOT convolution = f32[128,1024,1024]{2,1,0} convolution(gather, reshape), window={size=1}, dim_labels=b0f_io0->b0f })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {4, 1, 1}; @@ -809,13 +1242,13 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); EXPECT_TRUE(changed); auto* gather = FindInstruction(module.get(), "gather"); + auto* conv = FindInstruction(module.get(), "convolution"); ASSERT_NE(gather, nullptr); - EXPECT_THAT(gather, op::Sharding("{devices=[4,1,1]0,1,2,3}")); + ASSERT_NE(conv, nullptr); + EXPECT_THAT(gather, op::Sharding("{devices=[1,4,1]0,1,2,3}")); + EXPECT_THAT(conv, op::Sharding("{devices=[1,4,1]0,1,2,3}")); auto gather_sharding = gather->sharding(); TF_EXPECT_OK(gather_sharding.Validate(gather->shape(), 4)); - auto* conv = FindInstruction(module.get(), "convolution"); - ASSERT_NE(conv, nullptr); - EXPECT_THAT(conv, op::Sharding("{devices=[4,1,1]0,1,2,3}")); auto conv_sharding = conv->sharding(); TF_EXPECT_OK(conv_sharding.Validate(conv->shape(), 4)); } @@ -1084,7 +1517,7 @@ TEST_F(AutoShardingTest, InvalidOptions) { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingInputOutput) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -1094,8 +1527,8 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); // Remove the sharding in dot auto* dot = FindInstruction(module.get(), "dot"); dot->clear_sharding(); @@ -1116,7 +1549,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingAdd) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %elementwise { %param0 = f32[128,128]{0,1} parameter(0) @@ -1124,8 +1557,8 @@ ENTRY %elementwise { %add = f32[128,128]{0,1} add(%param0, %param1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} ROOT %copy = f32[128,128]{0,1} copy(%add) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); // Run AutoSharding AutoShardingOption option; option.enable = true; @@ -1150,7 +1583,7 @@ ENTRY %elementwise { TEST_F(AutoShardingTest, AutoShardingKeepUserShardingDot) { // An HLO Module with sharding for all instructions. - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { @@ -1160,8 +1593,8 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { ROOT %copy = f32[64,32]{1,0} copy(f32[64,32]{1,0} %dot), sharding={devices=[2,2]0,1,2,3} } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); // Remove the sharding in param0, param1 and copy auto* param0 = FindInstruction(module.get(), "param0"); param0->clear_sharding(); @@ -1196,7 +1629,7 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { } TEST_F(AutoShardingTest, DISABLED_AutoShardingKeepUserShardingTupleReduce) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module %func (lhs_value: f32[], lhs_index: s32[], rhs_value: f32[], rhs_index: s32[]) -> (f32[], s32[]) { %lhs_value = f32[] parameter(0) @@ -1220,8 +1653,8 @@ ENTRY %entry { %reduce = (f32[1,16]{1,0}, s32[1,16]{1,0}) reduce(f32[1,16,40]{2,1,0} %param0, s32[1,16,40]{2,1,0} %iota, f32[] %constant.a, s32[] %constant.b), dimensions={2}, to_apply=%func, sharding={{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}, {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1243,8 +1676,42 @@ ENTRY %entry { EXPECT_FALSE(param0->sharding().IsReplicated()); } +TEST_F(AutoShardingTest, GetTupleElementUserShardingsParameter) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %tupleparameter { + %param0 = f32[32,64]{1,0} parameter(0) + %param1 = f32[32,64]{1,0} parameter(1), sharding={devices=[2,2]<=[4]} + %tuple1 = (f32[32,64]{1,0}, f32[32,64]{1,0}) tuple(f32[32,64]{1,0} %param0, f32[32,64]{1,0} %param1) + %first = f32[32,64]{1,0} get-tuple-element((f32[32,64]{1,0}, f32[32,64]{1,0}) %tuple1), index=0 + %second = f32[32,64]{1,0} get-tuple-element((f32[32,64]{1,0}, f32[32,64]{1,0}) %tuple1), index=1, sharding={devices=[4,1]<=[4]} + ROOT root = f32[32,64]{1,0} add(%first, %second) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.device_mesh_shape = {2, 2}; + option.device_mesh_ids = {0, 1, 2, 3}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + ASSERT_NE(param1, nullptr); + EXPECT_THAT(param1, op::Sharding("{devices=[2,2]<=[4]}")); + + const HloInstruction* second = FindInstruction(module.get(), "root"); + ASSERT_NE(second, nullptr); + EXPECT_THAT(second, op::Sharding("{devices=[4,1]<=[4]}")); +} + TEST_F(AutoShardingTest, DISABLED_TupleParameter) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %tupleparameter { %tuple_param = (f32[16,32,64]{2,1,0}, f32[16,32,64]{2,1,0}) parameter(0) @@ -1253,8 +1720,8 @@ ENTRY %tupleparameter { ROOT root = f32[16,32,64]{2,1,0} add(%first, %second) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1274,7 +1741,7 @@ ENTRY %tupleparameter { // CRASHES TEST_F(AutoShardingTest, DISABLED_GetTupleElementWithUserShardingTest) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module %while_cond { @@ -1306,8 +1773,8 @@ ENTRY %entry (param0: f32[16,256,256], param1: f32[16,256,256]) -> f32[16,256,25 %tuple1 = f32[16,256,256]{2,1,0} get-tuple-element((u32[], f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}) %while.1), index=1, sharding={devices=[2,2,1]0,2,1,3} ROOT %tanh = f32[16,256,256]{2,1,0} tanh(f32[16,256,256]{2,1,0} %tuple1) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.preserve_shardings = AutoShardingOption::PreserveShardingsType::kKeepAllShardings; @@ -1321,7 +1788,7 @@ ENTRY %entry (param0: f32[16,256,256], param1: f32[16,256,256]) -> f32[16,256,25 } TEST_F(AutoShardingTest, While) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module %cond { @@ -1357,8 +1824,8 @@ ENTRY %entry { ROOT %result = bf16[128,512,768] get-tuple-element(%while), index=3 })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1401,7 +1868,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DynamicSlice) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param0 = s32[] parameter(0) @@ -1416,8 +1883,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1430,7 +1897,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Alias) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { @@ -1442,8 +1909,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1456,7 +1923,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, AliasTupleParameter) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module, input_output_alias={ {0}: (0, {0}, may-alias), {1}: (0, {1}, may-alias), {2}: (0, {2}, may-alias), {3}: (0, {3}, may-alias)} ENTRY %entry { @@ -1469,8 +1936,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1483,7 +1950,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, JaxRandomUniform) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module clone { lhs.1 = u32[] parameter(0) @@ -1516,8 +1983,8 @@ ENTRY %entry { ROOT maximum = f32[8,512]{1,0} maximum(subtract, broadcast.d) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1534,7 +2001,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Reshape) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { @@ -1547,8 +2014,8 @@ ENTRY %entry { %dot = bf16[512,1024,16,128]{3,2,1,0} dot(bf16[512,1024,2048]{2,1,0} %param.2, bf16[2048,16,128]{2,1,0} %reshape), lhs_contracting_dims={2}, rhs_contracting_dims={0} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {64, 1}; @@ -1562,7 +2029,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, ReshapeWithInvalidUserSharding) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { @@ -1571,8 +2038,8 @@ ENTRY %entry { %copy = bf16[1,24,16,16]{3,2,1,0} copy(%reshape) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {32, 1}; @@ -1588,15 +2055,15 @@ ENTRY %entry { } TEST_F(AutoShardingTest, Broadcast) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { %param.0 = s32[32]{0} parameter(0) ROOT broadcast = s32[512,1024,1024,32]{3,2,1,0} broadcast(s32[32]{0} %param.0), dimensions={3} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {1, 1, 64}; @@ -1607,7 +2074,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, TestReshardingCostsForUserAnnotatedSharding) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module ENTRY %entry { @@ -1616,8 +2083,8 @@ ENTRY %entry { %dot = f32[256,256] dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1} ROOT %result = f32[256,256] tanh(%dot), sharding={devices=[1,4]0,1,2,3} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1633,7 +2100,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, AllowAliasToFollowerConversion) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { @@ -1645,8 +2112,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1660,7 +2127,7 @@ ENTRY %entry { } TEST_F(AutoShardingTest, DisallowAliasToFollowerConversion) { - const char* const hlo_string = R"( + constexpr absl::string_view kHloString = R"( HloModule module, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias)} ENTRY %entry { @@ -1672,8 +2139,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; @@ -1686,43 +2153,60 @@ ENTRY %entry { EXPECT_TRUE(changed); } -TEST_F(AutoShardingTest, - GatherMergedIndexParallelAndOperandPassthroughBackwardPass) { - const char* const hlo_string = R"( -HloModule module +TEST_F(AutoShardingTest, BufferDonorConfigPreservation) { + constexpr absl::string_view kHloString = R"( +HloModule Module, buffer_donor={ (0, {0}), (0, {1}) } -ENTRY %module { - %arg.0 = s32[8,4,2,2]{3,2,1,0} parameter(0) - %arg.1 = s32[1,8,4]{2,1,0} parameter(1) - %operand = s32[8,4,2,2]{3,2,1,0} copy(s32[8,4,2,2]{3,2,1,0} %arg.0) - %indices = s32[1,8,4]{2,1,0} copy(s32[1,8,4]{2,1,0} %arg.1) - %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1 - %concatenate = s32[2,8,4]{2,1,0} concatenate( - s32[1,8,4]{2,1,0} %iota, s32[1,8,4]{2,1,0} %indices), dimensions={0} - %gather = s32[8,4,2,2]{3,2,1,0} gather( - s32[8,4,2,2]{3,2,1,0} %operand, - s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3}, - collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0, - slice_sizes={1,1,2,2}, - sharding={devices=[2,1,2,1]0,1,4,5 metadata={op_name="a"}} - ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather) +ENTRY entry { + %p = (f32[], f32[]) parameter(0) + %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0 + %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1 + ROOT %out = (f32[], f32[]) tuple(%p0, %p1) } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {2, 2}; + // Creating an explicit copy here to ensure that it is not modified during + // auto-sharding + const HloBufferDonorConfig buffer_donor_config_before = + module->buffer_donor_config(); + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + EXPECT_TRUE(changed); + const HloBufferDonorConfig& buffer_donor_config_after = + module->buffer_donor_config(); + EXPECT_EQ(buffer_donor_config_before.ToString(), + buffer_donor_config_after.ToString()); +} - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); +TEST_F(AutoShardingTest, InputOutputAliasConfigPreservation) { + constexpr absl::string_view kHloString = R"( +HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) } + +ENTRY entry { + %p = (f32[], f32[]) parameter(0) + %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0 + %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1 + ROOT %out = (f32[], f32[]) tuple(%p0, %p1) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; - option.device_mesh_ids = {0, 1, 2, 3}; - option.device_mesh_alpha = {1.0, 1.0}; - option.device_mesh_beta = {0.01, 1.0}; + // Creating an explicit copy here to ensure that it is not modified during + // auto-sharding + const HloInputOutputAliasConfig input_output_alias_config_before = + module->input_output_alias_config(); TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); - VLOG(0) << module->ToString(); EXPECT_TRUE(changed); - auto* gather = FindInstruction(module.get(), "gather"); - ASSERT_NE(gather, nullptr); - EXPECT_THAT(gather, op::Sharding("{devices=[1,1,2,2]0,1,2,3}")); + const HloInputOutputAliasConfig& input_output_alias_config_after = + module->input_output_alias_config(); + EXPECT_EQ(input_output_alias_config_before.ToString(), + input_output_alias_config_after.ToString()); } } // namespace diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index d5965d057576a..bd4b00de3b594 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,9 +49,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/service/call_graph.h" #include "xla/service/sharding_propagation.h" +#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" @@ -68,12 +70,11 @@ inline HloInstruction* PassThroughCustomCallMarkerUser( HloInstruction* raw_user, const HloInstruction* inst); std::optional GetInputSharding(const HloInstruction* ins, - const HloInstruction* operand, int64_t op_index, const HloSharding& output_sharding, const CallGraph& call_graph, int64_t num_devices) { - auto ins_clone = ins->Clone(); + std::unique_ptr ins_clone = ins->Clone(); ins_clone->set_sharding(output_sharding); std::vector> operands; @@ -95,9 +96,17 @@ std::optional GetInputSharding(const HloInstruction* ins, operands.push_back(std::move(operand_clone)); } - auto result = ShardingPropagation::GetShardingFromUser( - *ins_clone->operand(op_index), *ins_clone, 10, true, call_graph); - return result; + std::optional inferred_sharding = + ShardingPropagation::GetShardingFromUser(*ins_clone->operand(op_index), + *ins_clone, 10, true, call_graph, + /*sharding_helper=*/nullptr); + + if (!inferred_sharding.has_value() && IsTopKCustomCall(ins)) { + // ShardingPropagation::GetShardingFromUser does not handle TopK custom + // calls. Mirroring that function's handling of kSort, we handle TopK below. + inferred_sharding = output_sharding; + } + return inferred_sharding; } // Return whether the instruction is an activation from another pipeline stage. @@ -144,6 +153,32 @@ std::optional PropagateDimwiseSharding( return input_spec; } +HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, + const Shape& old_shape, + const Shape& new_shape, + const Array& device_mesh) { + if (input_spec.IsReplicated()) { + return input_spec; + } + + CHECK(old_shape.IsArray()); + + std::vector tensor_to_mesh_dim = + GetTensorDimToMeshDim(new_shape.rank(), input_spec, device_mesh, + /* consider_reverse_device_meshes */ false); + + std::vector tensor_dims; + std::vector mesh_dims; + for (size_t i = 0; i < new_shape.rank(); ++i) { + if (new_shape.dimensions(i) == old_shape.dimensions(i) && + tensor_to_mesh_dim[i] > -1) { + tensor_dims.push_back(i); + mesh_dims.push_back(tensor_to_mesh_dim[i]); + } + } + return Tile(new_shape, tensor_dims, mesh_dims, device_mesh); +} + // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. @@ -348,6 +383,7 @@ void BatchDimMapForward(const std::vector& instructions, case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -607,6 +643,7 @@ void BatchDimMapBackward(const std::vector& instructions, case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -865,7 +902,8 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { absl::flat_hash_set added; size_t num_skipped_due_to_infinity_costs = 0; for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { - if (AllInfinityCosts(strategy_group->strategies[i].resharding_costs)) { + if (AllInfinityCosts( + strategy_group->strategies[i].communication_resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } @@ -1038,8 +1076,8 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], // array[1, 1], array [2, 1], .... // Returns error status if dim >= array.num_dimensions(). -StatusOr> GetValuesAlongOneDim(const Array& array, - int dim) { +absl::StatusOr> GetValuesAlongOneDim( + const Array& array, int dim) { if (dim >= array.num_dimensions()) { return absl::OutOfRangeError(absl::StrCat( "Input dim (", dim, @@ -1056,7 +1094,8 @@ StatusOr> GetValuesAlongOneDim(const Array& array, } // Check whether a sequence is an arithmetic sequence. -StatusOr CheckArithmeticSequence(absl::Span sequence) { +absl::StatusOr CheckArithmeticSequence( + absl::Span sequence) { if (sequence.size() < 2) { return absl::OutOfRangeError( "Invalid device id assignment: sequence.size() < 2"); @@ -1313,9 +1352,7 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_shardings, - const Array& device_mesh, - absl::flat_hash_map>* - preserve_shardings) { + const Array& device_mesh) { size_t tuple_size = inst->shape().tuple_shapes_size(); auto current_sharding = inst->sharding(); @@ -1376,7 +1413,7 @@ void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( void FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, const Array& device_mesh, - absl::flat_hash_map>* + absl::flat_hash_map>& preserve_shardings) { HloInstruction* operand = inst->mutable_operand(0); auto input_tuple_sharding = operand->sharding(); @@ -1406,11 +1443,11 @@ void FixMixedMeshShapeReshardingGetTupleElement( TF_CHECK_OK(inst->ReplaceUseWith(user, replace_with)); } - CHECK_NE(preserve_shardings, nullptr); - if (preserve_shardings->contains(inst->name())) { - (*preserve_shardings)[replace_with->name()] = - preserve_shardings->at(inst->name()); - preserve_shardings->erase(inst->name()); + auto iter = preserve_shardings.find(inst->name()); + if (iter != preserve_shardings.end()) { + preserve_shardings[replace_with->name()] = + std::vector(iter->second); + preserve_shardings.erase(inst->name()); } } @@ -1419,7 +1456,8 @@ void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const Array& device_mesh, ReshardingCache* resharding_cache) { HloInstruction* operand = inst->mutable_operand(operand_num); - if (operand->opcode() == HloOpcode::kOutfeed) { + if (operand->opcode() == HloOpcode::kOutfeed || + operand->opcode() == HloOpcode::kSendDone) { return; } @@ -1548,7 +1586,6 @@ HloSharding Tile(const Shape& tensor_shape, tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]); split_prod *= device_mesh.dim(mesh_dims[i]); } - // Replicate on remaining mesh dimensions bool replicate_on_last_tile_dim = false; if (split_prod < device_mesh.num_elements()) { @@ -1604,12 +1641,10 @@ HloSharding Tile(const Shape& tensor_shape, : HloSharding::Tile(std::move(tile_assignment)); } -AliasMap BuildAliasMap(const HloModule* module) { +AliasMap BuildAliasMap(const HloModule* module, + const HloInputOutputAliasConfig& alias_config) { AliasMap alias_map; - const HloInputOutputAliasConfig& alias_config = - module->input_output_alias_config(); - HloComputation* entry = module->entry_computation(); const auto& parameter_instructions = entry->parameter_instructions(); const HloInstruction* output_tuple = entry->root_instruction(); @@ -1670,13 +1705,11 @@ AliasMap BuildAliasMap(const HloModule* module) { } AliasSet BuildAliasSet(const HloModule* module, + const HloInputOutputAliasConfig& alias_config, const StrategyMap& strategy_map) { - // Adjust the edge cost for aliases (donated buffer). - // Typically, old weights and new weights are aliases, so we should + // We also look at alias_config to adjust the edge cost for aliases (donated + // buffer). Typically, old weights and new weights are aliases, so we should // let them have the same sharding spec. - const HloInputOutputAliasConfig& alias_config = - module->input_output_alias_config(); - HloComputation* entry = module->entry_computation(); const auto& parameter_instructions = entry->parameter_instructions(); const HloInstruction* output_tuple = entry->root_instruction(); @@ -1729,19 +1762,14 @@ AliasSet BuildAliasSet(const HloModule* module, for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { + // Aliasing between the while op, and the parameters of its body and + // conditional computations is handled by making the latter follow the + // input tuple to thew while loop in the function + // BuildStrategyAndCost(). traverse_tuple_alias( strategy_map.at(instruction).get(), strategy_map.at(instruction->while_body()->root_instruction()) .get()); - traverse_tuple_alias( - strategy_map.at(instruction).get(), - strategy_map.at(instruction->while_body()->parameter_instruction(0)) - .get()); - traverse_tuple_alias( - strategy_map.at(instruction).get(), - strategy_map - .at(instruction->while_condition()->parameter_instruction(0)) - .get()); } else if (instruction->opcode() == HloOpcode::kConditional) { auto branch_computations = instruction->branch_computations(); for (size_t i = 0; i < branch_computations.size(); ++i) { @@ -1760,9 +1788,10 @@ AliasSet BuildAliasSet(const HloModule* module, return alias_set; } -void CheckAliasSetCompatibility(const AliasSet& alias_set, - const StrategyGroups& strategy_groups, - const HloInstructionSequence& sequence) { +Status CheckAliasSetCompatibility(const AliasSet& alias_set, + const StrategyGroups& strategy_groups, + const HloInstructionSequence& sequence, + bool crash_on_error) { const std::vector& instructions = sequence.instructions(); // Checks the compatibility for (const auto& pair : alias_set) { @@ -1793,24 +1822,29 @@ void CheckAliasSetCompatibility(const AliasSet& alias_set, "tensors and may result in large memory consumption: " << "(" << instructions.at(src_strategy_group->instruction_id)->name() << ", " << instructions.at(dst_strategy_group->instruction_id)->name() - << ")" - << "\n" + << ")" << "\n" << "(" << src_strategy_group->node_idx << ", " << dst_strategy_group->node_idx << ")\n" << src_strategy_group->ToString() << "\n" << dst_strategy_group->ToString(); } - CHECK(compatible_cnt > 0) - << "Alias pair does not have any sharding strategy in common: " - << "(" << instructions.at(src_strategy_group->instruction_id)->name() - << ", " << instructions.at(dst_strategy_group->instruction_id)->name() - << ")" - << "\n" - << "(" << src_strategy_group->node_idx << ", " - << dst_strategy_group->node_idx << ")\n" - << src_strategy_group->ToString() << "\n" - << dst_strategy_group->ToString(); + if (compatible_cnt <= 0) { + std::string err_msg = absl::StrCat( + "Alias pair does not have any sharding strategy in common: (", + instructions.at(src_strategy_group->instruction_id)->name(), ", ", + instructions.at(dst_strategy_group->instruction_id)->name(), ")\n(", + src_strategy_group->node_idx, ", ", dst_strategy_group->node_idx, + ")\n", src_strategy_group->ToString(), "\n", + dst_strategy_group->ToString()); + if (crash_on_error) { + LOG(FATAL) << err_msg; + } else { + LOG(WARNING) << err_msg; + return absl::InternalError(err_msg); + } + } } + return OkStatus(); } size_t VectorGreaterThanOneElementCount(absl::Span span, @@ -1845,6 +1879,9 @@ int64_t GetInstructionSize(const Shape& shape) { int64_t GetShardedInstructionSize(const Shape& shape, int64_t num_devices, std::optional sharding) { + if (sharding && sharding->IsUnknown()) { + sharding = HloSharding::Replicate(); + } if (shape.IsTuple()) { int64_t size = 0; for (size_t i = 0; i < shape.tuple_shapes_size(); i++) { @@ -1885,52 +1922,7 @@ HloInstruction* FindInstruction( return nullptr; } -double AllToAllCostUtil(double num_bytes, int mesh_dim, int64_t num_devices, - const std::vector& mesh_alpha, - const std::vector& mesh_beta) { - // A penalty factor to make the theoretical cost match the - // empirical cost on v100 + nvlink. - double penalty_factor = static_cast(num_devices) / 2.0; - return (round(mesh_alpha[mesh_dim] + mesh_beta[mesh_dim] * (num_devices - 1) / - num_devices / num_devices * - num_bytes * penalty_factor) + - 0.001); -} - -// Do not consider device id changes yet. -double ReshardingCostMixedMeshShape( - const Shape& shape, std::vector src_tensor_dim_to_mesh_dim, - std::vector dst_tensor_dim_to_mesh_dim, int64_t num_devices, - const std::vector& mesh_alpha, - const std::vector& mesh_beta) { - double resharding_costs = 0.0; - for (size_t i = 0; i < shape.rank(); ++i) { - // Only consider sharded dimensions, do not consider replicate_on_last_dim. - if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) { - continue; - } - if (dst_tensor_dim_to_mesh_dim[i] == -1 || - src_tensor_dim_to_mesh_dim[i] == -1) { - // AllToAll cost - int64_t communication_dim; - if (dst_tensor_dim_to_mesh_dim[i] != -1) { - communication_dim = dst_tensor_dim_to_mesh_dim[i]; - } else { - communication_dim = src_tensor_dim_to_mesh_dim[i]; - } - int64_t communication_bytes = GetBytes(shape); - resharding_costs += - AllToAllCostUtil(communication_bytes, communication_dim, num_devices, - mesh_alpha, mesh_beta); - } else { - // Do not support this sharding, assuming it is gonna be very expensive. - return kInfinityCost; - } - } - return resharding_costs; -} - -std::pair> +absl::StatusOr> AdjustShardingWithPartialMeshShapePerElement( const HloSharding& sharding, const absl::flat_hash_set& valid_shards, int64_t total_num_devices, @@ -1956,7 +1948,7 @@ AdjustShardingWithPartialMeshShapePerElement( LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; - return {absl::InternalError(err_msg), std::nullopt}; + return absl::InternalError(err_msg); } } } @@ -1969,7 +1961,7 @@ AdjustShardingWithPartialMeshShapePerElement( if (valid_shards.find(sharding.tile_assignment().dim( sharding.tile_assignment().num_dimensions() - 1)) != valid_shards.end()) { - return {OkStatus(), HloSharding::Replicate()}; + return HloSharding::Replicate(); } // If replicate on other dimensions, remove the // replicate_on_last_tile @@ -2015,13 +2007,12 @@ AdjustShardingWithPartialMeshShapePerElement( // Set arbitrary values because it will not be used. std::iota(device_ids.begin(), device_ids.end(), 0); tile_assignment.SetValues(device_ids); - HloSharding new_sharding = HloSharding::Tile(std::move(tile_assignment)); - return {OkStatus(), new_sharding}; + return HloSharding::Tile(std::move(tile_assignment)); } - return {OkStatus(), std::nullopt}; + return std::nullopt; } -StatusOr AdjustShardingsWithPartialMeshShape( +absl::StatusOr AdjustShardingsWithPartialMeshShape( const std::vector& instructions, const std::vector& mesh_shape, int64_t total_num_devices, bool crash_on_error) { @@ -2042,17 +2033,21 @@ StatusOr AdjustShardingsWithPartialMeshShape( for (size_t i = 0; i < inst->shape().tuple_shapes_size(); i++) { auto shape = inst->shape().tuple_shapes(i); auto sharding = inst->sharding().tuple_elements()[i]; - std::pair> new_sharding_result = + if (sharding.IsUnknown()) { + output_flattened_shardings.push_back(sharding); + continue; + } + absl::StatusOr> new_sharding_result = AdjustShardingWithPartialMeshShapePerElement( sharding, valid_shards, total_num_devices, crash_on_error); - if (new_sharding_result.first.ok()) { - if (new_sharding_result.second.has_value()) { - output_flattened_shardings.push_back(*new_sharding_result.second); + if (new_sharding_result.ok()) { + if (new_sharding_result.value().has_value()) { + output_flattened_shardings.push_back(*new_sharding_result.value()); } else { output_flattened_shardings.push_back(sharding); } } else { - return new_sharding_result.first; + return new_sharding_result.status(); } } size_t i = 0; @@ -2061,17 +2056,17 @@ StatusOr AdjustShardingsWithPartialMeshShape( } inst->set_sharding(HloSharding::Tuple(output_tuple_sharding)); } else { - std::pair> sharding_result = + absl::StatusOr> sharding_result = AdjustShardingWithPartialMeshShapePerElement( inst->sharding(), valid_shards, total_num_devices, crash_on_error); - if (sharding_result.first.ok()) { - if (sharding_result.second.has_value()) { - inst->set_sharding(*sharding_result.second); + if (sharding_result.ok()) { + if (sharding_result.value().has_value()) { + inst->set_sharding(*sharding_result.value()); changed = true; } } else { - return sharding_result.first; + return sharding_result.status(); } } } @@ -2130,50 +2125,56 @@ bool IsEntryComputationInputOrOutput(const HloModule* module, void ComputeInstructionExecutionCountsHelper( const HloComputation* computation, int64_t computation_execution_count, - int64_t loop_iteration_count_estimate, - absl::flat_hash_map* + int64_t static_loop_iteration_count_estimate, + absl::flat_hash_map& instruction_execution_counts) { - for (auto instruction : computation->instructions()) { - (*instruction_execution_counts)[instruction] = computation_execution_count; + for (const HloInstruction* instruction : computation->instructions()) { + (instruction_execution_counts)[instruction] = computation_execution_count; if (instruction->opcode() == HloOpcode::kWhile) { + int64_t loop_iteration_count = static_loop_iteration_count_estimate; + if (std::optional upper_bound = + ComputeWhileLoopTripCountUpperBound(instruction)) { + loop_iteration_count = *upper_bound; + } int64_t while_body_condition_execution_count = - computation_execution_count * loop_iteration_count_estimate; + computation_execution_count * loop_iteration_count; ComputeInstructionExecutionCountsHelper( instruction->while_body(), - /*computation_execution_count */ + /* computation_execution_count */ while_body_condition_execution_count, - /*loop_iteration_count_estimate*/ loop_iteration_count_estimate, - instruction_execution_counts); + /* loop_iteration_count_estimate */ + static_loop_iteration_count_estimate, instruction_execution_counts); ComputeInstructionExecutionCountsHelper( instruction->while_condition(), - /*computation_execution_count */ + /* computation_execution_count */ while_body_condition_execution_count, - /*loop_iteration_count_estimate*/ loop_iteration_count_estimate, - instruction_execution_counts); + /* loop_iteration_count_estimate */ + static_loop_iteration_count_estimate, instruction_execution_counts); } else if (instruction->opcode() == HloOpcode::kConditional) { // TODO(pratikf): For now, we do not scale down the execution counts of // branch statements, though we should at some point. - auto branch_computations = instruction->branch_computations(); + PtrVec branch_computations = + instruction->branch_computations(); for (size_t i = 0; i < branch_computations.size(); ++i) { ComputeInstructionExecutionCountsHelper( branch_computations[i], - /*computation_execution_count */ + /* computation_execution_count */ computation_execution_count, - /*loop_iteration_count_estimate*/ loop_iteration_count_estimate, - instruction_execution_counts); + /* loop_iteration_count_estimate */ + static_loop_iteration_count_estimate, instruction_execution_counts); } } } } absl::flat_hash_map -ComputeInstructionExecutionCounts(const HloModule* module, - int64_t loop_iteration_count_estimate) { +ComputeInstructionExecutionCounts( + const HloModule* module, int64_t static_loop_iteration_count_estimate) { absl::flat_hash_map instruction_execution_counts; ComputeInstructionExecutionCountsHelper(module->entry_computation(), 1, - loop_iteration_count_estimate, - &instruction_execution_counts); + static_loop_iteration_count_estimate, + instruction_execution_counts); return instruction_execution_counts; } @@ -2199,29 +2200,6 @@ void EnumerateAllPossibleMeshShapesHelper( } } -std::vector> EnumerateAllPossibleMeshShapes( - const int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims) { - std::vector> result; - EnumerateAllPossibleMeshShapesHelper(num_devices, num_mesh_dims, {}, result); - - if (symmetrical_mesh_dims) { - absl::flat_hash_set> dedup_result; - for (const std::vector& mesh_shape : result) { - dedup_result.insert( - absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); - } - - result.clear(); - - for (const absl::btree_multiset& mesh_shape_set : dedup_result) { - result.push_back( - std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); - } - } - - return result; -} - std::vector> InferMeshShapesToTry( const HloModule& module) { int64_t sharding_1d = -1; @@ -2233,7 +2211,7 @@ std::vector> InferMeshShapesToTry( for (const HloSharding& child : sharding.tuple_elements()) { process_sharding(child); } - } else if (!sharding.IsReplicated()) { + } else if (!sharding.IsReplicated() && !sharding.IsTileMaximal()) { absl::Span dims = sharding.tile_assignment().dimensions(); std::vector dims_greater_than_one; for (const int64_t dim : dims) { @@ -2280,12 +2258,74 @@ std::vector> InferOrEnumerateMeshShapesToTry( bool symmetrical_mesh_dims) { std::vector> mesh_shapes = InferMeshShapesToTry(module); if (mesh_shapes.empty()) { - mesh_shapes = spmd::EnumerateAllPossibleMeshShapes( - num_devices, num_mesh_dims, - /* symmetrical_mesh_dims */ symmetrical_mesh_dims); + EnumerateAllPossibleMeshShapesHelper(num_devices, num_mesh_dims, {}, + mesh_shapes); + } + if (symmetrical_mesh_dims) { + absl::flat_hash_set> dedup_result; + for (const std::vector& mesh_shape : mesh_shapes) { + dedup_result.insert( + absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); + } + + mesh_shapes.clear(); + + for (const absl::btree_multiset& mesh_shape_set : dedup_result) { + mesh_shapes.push_back( + std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); + } } + return mesh_shapes; } +bool IsShardingMisaligned(const HloSharding& sharding, const Shape& shape) { + if (shape.IsTuple()) { + for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { + if (IsShardingMisaligned( + sharding.IsTuple() + ? sharding.GetSubSharding(shape, {static_cast(i)}) + : sharding, + shape.tuple_shapes(i))) { + return true; + } + } + return false; + } + + if (sharding.IsReplicated() || sharding.IsManual() || sharding.IsUnknown() || + sharding.IsTileMaximal()) { + return false; + } + + for (size_t i = 0; i < shape.rank(); ++i) { + int64_t shape_dim = shape.dimensions()[i]; + int64_t sharding_dim = sharding.tile_assignment().dim(i); + if (shape_dim % sharding_dim != 0) { + return true; + } + } + return false; +} + +HloSharding ReplaceGivenShardingsWithUnknownForTuple( + const HloSharding& sharding, const Shape& shape, + absl::Span to_replace_sharding_ids) { + std::vector new_tuple_shardings; + int64_t num_elements = sharding.tuple_elements().size(); + for (int32_t i = 0; i < num_elements; ++i) { + bool can_change_sharding = to_replace_sharding_ids.size() == 1 + ? to_replace_sharding_ids[0] + : to_replace_sharding_ids[i]; + if (can_change_sharding) { + new_tuple_shardings.push_back(HloSharding::Unknown()); + } else { + new_tuple_shardings.push_back(sharding.tuple_elements()[i]); + } + } + + return HloSharding::Tuple(shape, new_tuple_shardings); +} + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 290d92f45a86a..4a129bb159076 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,9 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -357,7 +359,6 @@ inline std::vector Argsort(const std::vector& scores) { // Given the sharding for an instruction, invoke the sharding propagation pass // to infer appropriate shardings for its operands. std::optional GetInputSharding(const HloInstruction* ins, - const HloInstruction* operand, int64_t op_index, const HloSharding& output_sharding, const xla::CallGraph& call_graph, @@ -427,6 +428,11 @@ std::optional PropagateDimwiseSharding( const HloSharding& input_spec, const Shape& old_shape, const Shape& new_shape); +HloSharding PropagateDimwiseShardingSlice(const HloSharding& input_spec, + const Shape& old_shape, + const Shape& new_shape, + const Array& device_mesh); + // Propagate sharding for ReduceWindow-like operations. // The sharding can successfully propagate if the window operation only happens // on tensor dimensions that are not tiled. @@ -459,15 +465,13 @@ Shape ComputeIntermediateShape(const HloSharding& src_sharding, void FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, const Array& device_mesh, - absl::flat_hash_map>* + absl::flat_hash_map>& preserve_shardings); void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_sharding, - const Array& device_mesh, - absl::flat_hash_map>* - preserve_shardings); + const Array& device_mesh); void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, @@ -527,10 +531,11 @@ std::vector> GetReplicaGroupsAlongOneDimension( // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], // array[1, 1], array [2, 1], .... // Returns error status if dim >= array.num_dimensions(). -StatusOr> GetValuesAlongOneDim(const Array& array, - int dim); +absl::StatusOr> GetValuesAlongOneDim( + const Array& array, int dim); -StatusOr CheckArithmeticSequence(absl::Span sequence); +absl::StatusOr CheckArithmeticSequence( + absl::Span sequence); // Checks if the number of sharded dimensions in the tile assignment matches the // device mesh. @@ -557,9 +562,11 @@ HloSharding Tile(const Shape& tensor_shape, absl::Span mesh_dims, const Array& device_mesh); -AliasMap BuildAliasMap(const HloModule* module); +AliasMap BuildAliasMap(const HloModule* module, + const HloInputOutputAliasConfig& alias_config); AliasSet BuildAliasSet(const HloModule* module, + const HloInputOutputAliasConfig& alias_config, const StrategyMap& strategy_map); // Transpose an array of any number of dimensions given any axes order. @@ -600,21 +607,12 @@ int64_t GetShardedInstructionSize( HloInstruction* FindInstruction( const std::vector& instructions, absl::string_view name); -double AllToAllCostUtil(double num_bytes, int mesh_dim, int64_t num_devices, - const std::vector& mesh_alpha, - const std::vector& mesh_beta); - -double ReshardingCostMixedMeshShape( - const Shape& shape, std::vector src_tensor_dim_to_mesh_dim, - std::vector dst_tensor_dim_to_mesh_dim, int64_t num_devices, - const std::vector& mesh_alpha, - const std::vector& mesh_beta); // When a complete mesh shape is [1, 8, 4], [1, 8, 1] is its partial mesh shape. // If a sharding is [8, 4] for the complete mesh shape, we convert it to [8, 1] // given [1, 8, 1] as the partial mesh shape. // total_num_devices should equal to the product of mesh_shape elements. -StatusOr AdjustShardingsWithPartialMeshShape( +absl::StatusOr AdjustShardingsWithPartialMeshShape( const std::vector& instructions, const std::vector& mesh_shape, int64_t total_num_devices, bool crash_on_error); @@ -655,6 +653,17 @@ std::vector> InferOrEnumerateMeshShapesToTry( const HloModule& module, int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims); +// Check if the sharding is "misaligned" wrt the shape. This is true if there is +// at least one dimension of the tensor that is sharded over a number of devices +// that do not complete divide the size of the tensor dimension. +bool IsShardingMisaligned(const HloSharding& sharding, const Shape& shape); + +// In a given tuple sharding, replace certain leaves with +// HloSharding::Unknown() +HloSharding ReplaceGivenShardingsWithUnknownForTuple( + const HloSharding& sharding, const Shape& shape, + absl::Span to_replace_sharding_ids); + } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 008a4184e00b9..1811a2cc619e5 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,13 @@ limitations under the License. #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" @@ -37,13 +41,16 @@ namespace spmd { // combinatorial optimization problem & solves it. AutoShardingSolverResult CallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, + const LivenessNodeSet& liveness_node_set, + const LivenessEdgeSet& liveness_edge_set, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, - bool compute_iis, int64_t solver_timeout_in_seconds, - const AutoShardingOption& option, + const absl::flat_hash_set& peak_times, bool compute_iis, + int64_t solver_timeout_in_seconds, const AutoShardingOption& option, + std::optional max_cost, absl::string_view request_name, const absl::flat_hash_map& - sharding_propagation_solution = {}); + sharding_propagation_solution = {}, + bool deterministic_mode = false); // Computes the penalty to be used for fully replicated sharding strategies for // dots and convs. diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/xla/hlo/experimental/auto_sharding/cluster_environment.cc index fc4fe744542ff..5e66abd4f7b64 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,18 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include +#include #include #include #include #include #include +#include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/shape.h" namespace xla { namespace spmd { @@ -99,6 +102,17 @@ double ClusterEnvironment::ReduceScatterCost(double num_bytes, 0.001); } +double ClusterEnvironment::AllToAllCostUtil(double num_bytes, int mesh_dim, + int64_t num_devices) const { + // A penalty factor to make the theoretical cost match the + // empirical cost on v100 + nvlink. + double penalty_factor = static_cast(num_devices) / 2.0; + return (round(mesh_alpha_[mesh_dim] + + mesh_beta_[mesh_dim] * (num_devices - 1) / num_devices / + num_devices * num_bytes * penalty_factor) + + 0.001); +} + double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { if (auto_sharding_option_.force_override_all_to_all_cost) { return auto_sharding_option_.all_to_all_cost; @@ -116,26 +130,43 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { } int64_t num_devices = device_mesh_.dim(mesh_dim); - return AllToAllCostUtil(num_bytes, mesh_dim, num_devices, mesh_alpha_, - mesh_beta_); + return AllToAllCostUtil(num_bytes, mesh_dim, num_devices); } -double ClusterEnvironment::DotCost(const Shape& lhs_shape, - const Shape& rhs_shape) const { - if (!auto_sharding_option_.allow_recompute_heavy_op) { - return kInfinityCost; +// Do not consider device id changes yet. +double ClusterEnvironment::ReshardingCostMixedMeshShape( + const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, + absl::Span dst_tensor_dim_to_mesh_dim) const { + int64_t num_devices = device_mesh_.num_elements(); + double resharding_costs = 0.0; + for (size_t i = 0; i < shape.rank(); ++i) { + // Only consider sharded dimensions, do not consider replicate_on_last_dim. + if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) { + continue; + } + if (dst_tensor_dim_to_mesh_dim[i] == -1 || + src_tensor_dim_to_mesh_dim[i] == -1) { + // AllToAll cost + int64_t communication_dim; + if (dst_tensor_dim_to_mesh_dim[i] != -1) { + communication_dim = dst_tensor_dim_to_mesh_dim[i]; + } else { + communication_dim = src_tensor_dim_to_mesh_dim[i]; + } + int64_t communication_bytes = GetBytes(shape); + resharding_costs += + AllToAllCostUtil(communication_bytes, communication_dim, num_devices); + } else { + // Do not support this sharding, assuming it is gonna be very expensive. + return kInfinityCost; + } } - - // TODO(zhuohan): When profiling data is not available, it is not easy to - // align the scale of compute cost and communication cost. Here we just use - // a simple heuristic to compute the compute cost with communication cost. - double num_bytes = GetBytes(lhs_shape) + GetBytes(rhs_shape); - return AllReduceCost(num_bytes, 0) + AllReduceCost(num_bytes, 1); + return resharding_costs; } double ClusterEnvironment::CollectivePermuteCost( double num_bytes, - const std::vector>& src_dst_pairs) const { + absl::Span> src_dst_pairs) const { absl::flat_hash_map> device_to_index_map; device_mesh_.Each([&](absl::Span indices, int64_t device) { std::vector indices_vector; @@ -163,17 +194,18 @@ double ClusterEnvironment::CollectivePermuteCost( // Overestimate the cost of replicating a tensor by decomposing the resharding // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( - const Shape& shape, const HloSharding& src_spec) const { + const Shape& shape, const HloSharding& src_spec, + const Array& device_mesh) const { if (src_spec.IsTileMaximal() || src_spec.IsManual()) { // TODO(b/238210866) Do not use kInfinityCost. return kInfinityCost; } int64_t bytes_moved = GetBytes(shape) / src_spec.NumTiles(); double cost = 0.0; - for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) { + for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) { auto this_cost = this->AllGatherCost(bytes_moved, i); cost += this_cost; - bytes_moved *= device_mesh_.dimensions()[i]; + bytes_moved *= device_mesh.dimensions()[i]; } return cost; } @@ -213,10 +245,9 @@ double ClusterEnvironment::TryCollectivePermuteForResharding( // Since we only estimate communication costs here, we only need to consider // the cost of step 1, ie. replicating the tensor starting from sharding // s2. We estimate this cost by invoking OverestimateReplicationCost. - return OverestimateReplicationCost(shape, src_spec); + return OverestimateReplicationCost(shape, src_spec, device_mesh_); } -// The communication cost of resharding a tensor from src to dst double ClusterEnvironment::ReshardingCost(const Shape& shape, const HloSharding& src_spec, const HloSharding& dst_spec) const { @@ -226,6 +257,17 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape, return 0.0; } + if (src_spec.tile_assignment().num_elements() > device_mesh_.num_elements() || + dst_spec.tile_assignment().num_elements() > device_mesh_.num_elements()) { + LOG(WARNING) + << "Full device sharding found when solving for the partial mesh " + << spmd::ToString(device_mesh_.dimensions()) + << ". Overestimating the resharding cost by assuming full replication " + "on the full device mesh " + << spmd::ToString(device_mesh_.dimensions()) << "."; + return OverestimateReplicationCost(shape, src_spec, original_device_mesh_); + } + CHECK(!IsUndefined(dst_spec)); int64_t src_n_dim = NumTileDimensions(src_spec); int64_t dst_n_dim = NumTileDimensions(dst_spec); @@ -283,9 +325,8 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape, dst_tensor_dim_to_mesh_dim_or.value(); if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - return ReshardingCostMixedMeshShape( - shape, src_tensor_dim_to_mesh_dim, dst_tensor_dim_to_mesh_dim, - device_mesh_.num_elements(), mesh_alpha_, mesh_beta_); + return ReshardingCostMixedMeshShape(shape, src_tensor_dim_to_mesh_dim, + dst_tensor_dim_to_mesh_dim); } AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim); diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.h b/xla/hlo/experimental/auto_sharding/cluster_environment.h index 25c3e6ba26cad..19736d19e25f0 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,7 +50,8 @@ class ClusterEnvironment { mesh_beta_(mesh_beta.begin(), mesh_beta.end()), prof_result_(prof_result), total_devices_(device_mesh.num_elements()), - device_mesh_1d_(original_device_mesh), + device_mesh_1d_(device_mesh), + original_device_mesh_1d_(original_device_mesh), auto_sharding_option_(auto_sharding_option) { // Build replica group for each dimension. non_zero_mesh_dims_ = @@ -68,11 +69,15 @@ class ClusterEnvironment { original_device_mesh_shape.end()); size_t largest_dim_idx = std::distance(original_device_mesh_shape.begin(), max_dim_iterator); + std::vector device_mesh_1d_shape(device_mesh.num_dimensions(), 1); + device_mesh_1d_shape[largest_dim_idx] = device_mesh.num_elements(); + device_mesh_1d_.Reshape(device_mesh_1d_shape); - std::vector device_mesh_1d_shape( + std::vector original_device_mesh_1d_shape( original_device_mesh.num_dimensions(), 1); - device_mesh_1d_shape[largest_dim_idx] = original_device_mesh.num_elements(); - device_mesh_1d_.Reshape(device_mesh_1d_shape); + original_device_mesh_1d_shape[largest_dim_idx] = + original_device_mesh.num_elements(); + original_device_mesh_1d_.Reshape(original_device_mesh_1d_shape); } size_t NumDevices() const { return total_devices_; } @@ -125,20 +130,23 @@ class ClusterEnvironment { double AllToAllCost(double num_bytes, int mesh_dim) const; + double ReshardingCostMixedMeshShape( + const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, + absl::Span dst_tensor_dim_to_mesh_dim) const; + double CollectivePermuteCost( double num_bytes, - const std::vector>& src_dst_pairs) const; + absl::Span> src_dst_pairs) const; double TryCollectivePermuteForResharding(const Shape& shape, const HloSharding& src_spec, const HloSharding& dst_spec) const; - double DotCost(const Shape& lhs_shape, const Shape& rhs_shape) const; - // This function attempts to overestimate the cost of replicating a tensor of // shape `shape` sharded according to `src_spec`. double OverestimateReplicationCost(const Shape& shape, - const HloSharding& src_spec) const; + const HloSharding& src_spec, + const Array& device_mesh) const; double ReshardingCost(const Shape& shape, const HloSharding& src_spec, const HloSharding& dst_spec) const; @@ -170,6 +178,10 @@ class ClusterEnvironment { // Used for mixed mesh shape strategies. Array device_mesh_1d_; + // Cache a flatten 1d version of the original device mesh. + // Used for mixed mesh shape strategies. + Array original_device_mesh_1d_; + // The option may override the cost of communication primitives const AutoShardingOption& auto_sharding_option_; @@ -177,6 +189,9 @@ class ClusterEnvironment { std::vector>> cached_replica_groups_; private: + double AllToAllCostUtil(double num_bytes, int mesh_dim, + int64_t num_devices) const; + void GenerateCachedReplicaGroups() { // One vector per device_mesh_ dimension. cached_replica_groups_.reserve(device_mesh_.num_dimensions()); diff --git a/xla/hlo/experimental/auto_sharding/matrix.h b/xla/hlo/experimental/auto_sharding/matrix.h index 90db18f591b07..903973eea5a3a 100644 --- a/xla/hlo/experimental/auto_sharding/matrix.h +++ b/xla/hlo/experimental/auto_sharding/matrix.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ namespace spmd { // It can create a view for matrix transpose without copying the memory. // TODO (zhuohan): Inherit from Array2D and add Transpose and operator+ (See // tensorflow/compiler/xla/array2d.h;l=39) +template class Matrix { public: Matrix() : n_(0), m_(0), transpose_(false), data_(nullptr) {} @@ -44,11 +45,11 @@ class Matrix { this->n_ = n; this->m_ = m; transpose_ = false; - data_ = std::make_shared>(n * m, 0.0); + data_ = std::make_shared>(n * m, T()); } Matrix(size_t n, size_t m, bool transpose, - std::shared_ptr> data) { + std::shared_ptr> data) { this->n_ = n; this->m_ = m; this->transpose_ = transpose; @@ -57,7 +58,7 @@ class Matrix { Matrix Transpose() { return Matrix(m_, n_, !transpose_, data_); } - double operator()(size_t i, size_t j) const { + T operator()(size_t i, size_t j) const { size_t idx; if (transpose_) { idx = j * n_ + i; @@ -69,7 +70,7 @@ class Matrix { return (*data_)[idx]; } - double& operator()(size_t i, size_t j) { + T& operator()(size_t i, size_t j) { size_t idx; if (transpose_) { idx = j * n_ + i; @@ -81,7 +82,7 @@ class Matrix { return (*data_)[idx]; } - Matrix operator+(const Matrix& other) { + Matrix operator+(const Matrix& other) { CHECK_EQ(n_, other.n_); CHECK_EQ(m_, other.m_); Matrix ret = Matrix(n_, m_); @@ -98,7 +99,7 @@ class Matrix { for (size_t i = 0; i < n_; ++i) { for (size_t j = 0; j < m_; ++j) { - absl::StrAppend(&str, operator()(i, j), " "); + absl::StrAppend(&str, operator()(i, j).ToString(), " "); } absl::StrAppend(&str, "\n"); } @@ -109,7 +110,7 @@ class Matrix { size_t n_; size_t m_; bool transpose_; - std::shared_ptr> data_; + std::shared_ptr> data_; }; } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/metrics.cc b/xla/hlo/experimental/auto_sharding/metrics.cc index efe6b1086eb1c..0c52981cf2df9 100644 --- a/xla/hlo/experimental/auto_sharding/metrics.cc +++ b/xla/hlo/experimental/auto_sharding/metrics.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/experimental/auto_sharding/metrics.h b/xla/hlo/experimental/auto_sharding/metrics.h index 11a684bbb1f89..31f56fe1b5e26 100644 --- a/xla/hlo/experimental/auto_sharding/metrics.h +++ b/xla/hlo/experimental/auto_sharding/metrics.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/experimental/auto_sharding/profiling_result.h b/xla/hlo/experimental/auto_sharding/profiling_result.h index f95d80c39b073..873aabf786388 100644 --- a/xla/hlo/experimental/auto_sharding/profiling_result.h +++ b/xla/hlo/experimental/auto_sharding/profiling_result.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index 2ffa2da72e198..39e501b034dbf 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -1,11 +1,12 @@ # Description: # XLA’s HLO Intermediate Representation implementation. +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -55,6 +56,7 @@ cc_library( "hlo_sharding_metadata.h", ], deps = [ + ":ptrvec", ":tile_assignment", "//xla:array", "//xla:comparison_util", @@ -81,15 +83,18 @@ cc_library( "//xla/service:name_uniquer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/lib/gtl:iterator_range", "@tsl//tsl/lib/gtl:map_util", @@ -129,6 +134,40 @@ cc_library( ], ) +cc_library( + name = "hlo_dfs_reachability", + srcs = ["hlo_dfs_reachability.cc"], + hdrs = ["hlo_dfs_reachability.h"], + deps = [ + ":hlo", + "@com_google_absl//absl/algorithm:container", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "ptrvec", + hdrs = ["ptrvec.h"], + deps = [ + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:logging", + ], +) + +cc_test( + name = "ptrvec_test", + srcs = ["ptrvec_test.cc"], + deps = [ + ":ptrvec", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "tile_assignment", srcs = ["tile_assignment.cc"], diff --git a/xla/hlo/ir/dfs_hlo_visitor.cc b/xla/hlo/ir/dfs_hlo_visitor.cc index 3afdc7a01005d..22f031eb7f77f 100644 --- a/xla/hlo/ir/dfs_hlo_visitor.cc +++ b/xla/hlo/ir/dfs_hlo_visitor.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/dfs_hlo_visitor.h b/xla/hlo/ir/dfs_hlo_visitor.h index 1d981823c7479..29d5715716bd5 100644 --- a/xla/hlo/ir/dfs_hlo_visitor.h +++ b/xla/hlo/ir/dfs_hlo_visitor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -128,6 +128,7 @@ class DfsHloVisitorBase { virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectiveBroadcast(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; @@ -175,6 +176,9 @@ class DfsHloVisitorBase { virtual Status HandleRoundNearestEven(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleErf(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleLogistic(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } diff --git a/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/xla/hlo/ir/dfs_hlo_visitor_with_default.h index 09b3689274438..93b085772cd62 100644 --- a/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/base/optimization.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" @@ -119,6 +120,9 @@ class DfsHloVisitorWithDefaultBase Status HandleAllToAll(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleCollectiveBroadcast(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } @@ -297,7 +301,7 @@ using ConstDfsHloVisitorWithDefault = class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { public: // Runs a visitor on the module and returns whether the module has changed. - StatusOr RunOnModule( + absl::StatusOr RunOnModule( HloModule* module, const absl::flat_hash_set& execution_threads = {}) { Status status; @@ -323,8 +327,8 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { Status ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { - VLOG(3) << "Replacing instruction:" - << "\n old: " << old_instruction->ToString() + VLOG(3) << "Replacing instruction:" << "\n old: " + << old_instruction->ToString() << "\n new: " << new_instruction->ToString(); Status status = old_instruction->parent()->ReplaceWithNewInstruction( old_instruction, std::move(new_instruction)); @@ -337,14 +341,15 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. // Returns the Status representing the result of the replace operation. - StatusOr ReplaceInstruction(HloInstruction* old_instruction, - HloInstruction* new_instruction, - bool preserve_sharding) { - VLOG(3) << "Replacing instruction:" - << "\n old: " << old_instruction->ToString() + absl::StatusOr ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction, + bool preserve_sharding) { + VLOG(3) << "Replacing instruction:" << "\n old: " + << old_instruction->ToString() << "\n new: " << new_instruction->ToString(); - StatusOr changed_or = old_instruction->parent()->ReplaceInstruction( - old_instruction, new_instruction, preserve_sharding); + absl::StatusOr changed_or = + old_instruction->parent()->ReplaceInstruction( + old_instruction, new_instruction, preserve_sharding); if (ABSL_PREDICT_TRUE(changed_or.ok())) { changed_ |= changed_or.value(); } @@ -353,7 +358,7 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { Status ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - StatusOr changed_or = + absl::StatusOr changed_or = ReplaceInstruction(old_instruction, new_instruction, /*preserve_sharding=*/false); if (ABSL_PREDICT_TRUE(changed_or.ok())) { diff --git a/xla/hlo/ir/dynamic_parameter_binding.cc b/xla/hlo/ir/dynamic_parameter_binding.cc index 049676a795659..ca96fbe135113 100644 --- a/xla/hlo/ir/dynamic_parameter_binding.cc +++ b/xla/hlo/ir/dynamic_parameter_binding.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/dynamic_parameter_binding.h b/xla/hlo/ir/dynamic_parameter_binding.h index 77edb06ee44f1..70a199db63532 100644 --- a/xla/hlo/ir/dynamic_parameter_binding.h +++ b/xla/hlo/ir/dynamic_parameter_binding.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_casting_utils.h b/xla/hlo/ir/hlo_casting_utils.h index 38549d68fad5f..ff635675bca9c 100644 --- a/xla/hlo/ir/hlo_casting_utils.h +++ b/xla/hlo/ir/hlo_casting_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_clone_context.h b/xla/hlo/ir/hlo_clone_context.h index b9115bb46cae4..b1eaa31306c7c 100644 --- a/xla/hlo/ir/hlo_clone_context.h +++ b/xla/hlo/ir/hlo_clone_context.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_computation.cc b/xla/hlo/ir/hlo_computation.cc index 30af92643a7b1..67489fd30b3d6 100644 --- a/xla/hlo/ir/hlo_computation.cc +++ b/xla/hlo/ir/hlo_computation.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,10 @@ limitations under the License. #include #include #include -#include +#include #include #include +#include #include #include #include @@ -30,6 +31,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -37,12 +41,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/map_util.h" #include "xla/printer.h" #include "xla/service/mapped_ptr_container_sorter.h" +#include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -51,6 +59,61 @@ namespace xla { using absl::StrCat; +enum class VisitState { kNew = 0, kVisiting = 1, kVisited = 2 }; + +static std::ostream& operator<<(std::ostream& os, const VisitState& state) { + switch (state) { + case VisitState::kNew: + os << "new"; + break; + case VisitState::kVisiting: + os << "visiting"; + break; + case VisitState::kVisited: + os << "visited"; + break; + } + return os; +} + +class HloComputation::VisitMap { + public: + VisitMap() = default; + explicit VisitMap(int capacity) : size_(capacity) { + int num_words = (capacity + 31) / 32; + bits_.resize(num_words); + bit_ptr_ = bits_.empty() ? nullptr : bits_.data(); + } + + // A handle is a dense index used to identify a particular node. + using Handle = uint32_t; + + // Returns the current VisitState for the instruction with handle "h" + VisitState GetState(Handle h) const { + DCHECK_LT(h, size_); + uint32_t word = (h / 32); + uint32_t shift = (h % 32) << 1; + return static_cast((bit_ptr_[word] >> shift) & 0x3); + } + + // Sets the VisitState for the instruction with Handle "h" to "new_state" + void SetState(Handle h, VisitState new_state) { + DCHECK_LT(h, size_); + uint32_t word = (h / 32); + uint32_t shift = (h % 32) << 1; + uint64_t mask = ~(3ull << shift); + uint64_t val = static_cast(new_state); + bit_ptr_[word] = (bit_ptr_[word] & mask) | (val << shift); + } + + private: + // bits_ stores VisitState entries (2 bits per entry, packed 32 entries per + // 64-bit word) + absl::InlinedVector bits_; + uint64_t* bit_ptr_ = nullptr; // + int size_ = 0; // Number of entries. bits_ holds at least 2 * this many bits +}; + std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { int parameter_count = 0; @@ -63,25 +126,18 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? root_instruction : last_added_instruction(); CHECK_NE(nullptr, root); - return absl::WrapUnique(new HloComputation( - name_, parameter_count, &instructions_, root, fusion_instruction_)); + return absl::WrapUnique( + new HloComputation(name_, parameter_count, &instructions_, root)); } HloComputation::HloComputation( const std::string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(-1), + HloInstruction* root_instruction) + : unique_id_(-1), root_instruction_(root_instruction), - fusion_instruction_(fusion_instruction), - is_fusion_computation_(fusion_instruction != nullptr), - custom_call_instruction_(nullptr), - is_custom_call_computation_(false), - collective_call_instruction_(nullptr), - is_collective_called_computation_(false), - while_call_instruction_(nullptr), - is_while_call_body_computation_(false) { + instruction_count_(0), + name_(NameUniquer::GetSanitizedName(name)) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { @@ -100,21 +156,44 @@ HloComputation::HloComputation( } CHECK(root_found) << "\nERROR: root instruction is not present in computation."; + root_instruction_->MarkAsRoot(); } HloComputation::~HloComputation() { - if (fusion_instruction_ != nullptr) { - CHECK(fusion_instruction_->fused_instructions_computation() == this); - fusion_instruction_->ClearCalledComputations(); - fusion_instruction_ = nullptr; + if (FusionInstruction() != nullptr) { + CHECK(FusionInstruction()->fused_instructions_computation() == this); + FusionInstruction()->ClearCalledComputations(); } if (IsAsyncComputation()) { - for (auto* async_instr : async_instructions_) { - CHECK(async_instr->async_wrapped_computation() == this); - async_instr->ClearCalledComputations(); - } - async_instructions_.clear(); + CHECK(async_start_->async_wrapped_computation() == this); + async_start_->ClearCalledComputations(); + } + for (const auto& i : instructions_) { + delete i.inst(); + } + Cleanup(); +} + +void HloComputation::SetInstruction(HloInstruction* instruction, + InstructionType type) { + static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1, + "HloInstruction should be aligned as a QWORD"); + + DCHECK(type != InstructionType::kUnset) + << "Set instruction must be called with a valid type, not kUnset."; + DCHECK(instruction_type() == InstructionType::kUnset || + instruction_type() == type) + << "Unexpected instruction type. Current type is " + << static_cast(instruction_type()) << " and it cannot be reset to " + << static_cast(type); + + // If `instruction` is nullptr, we need to preserve the existing type. + if (instruction == nullptr) { + type = instruction_type(); } + + instruction_and_type_ = + reinterpret_cast(instruction) | static_cast(type); } HloInstruction* HloComputation::AddInstruction( @@ -143,9 +222,16 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } instruction->set_parent(this); - HloInstruction* pinst = instruction.get(); - instruction_iterators_[pinst] = - instructions_.insert(instructions_.end(), std::move(instruction)); + HloInstruction* pinst = instruction.release(); // Take ownership + HloInstructionInfo info; + info.opcode_ = pinst->opcode(); + info.inst_ = pinst; + VLOG(2) << "Adding instruction " << pinst << " " << pinst->name() + << " from computation " << name() << " opcode " << info.opcode(); + uint32_t index = instructions_.size(); + instruction_count_++; + pinst->index_in_parent_ = index; + instructions_.push_back(info); return pinst; } @@ -153,7 +239,7 @@ HloInstruction* HloComputation::AddParameter( std::unique_ptr instruction) { CHECK(instruction->opcode() == HloOpcode::kParameter); CHECK(!IsFusionComputation() || - fusion_instruction_->operand_count() == param_instructions_.size()); + FusionInstruction()->operand_count() == param_instructions_.size()); instruction->set_parent(this); param_instructions_.push_back(instruction.get()); AddInstructionInternal(std::move(instruction)); @@ -227,7 +313,7 @@ HloInstruction* HloComputation::ReplaceParameter( CHECK_LT(param_no, param_instructions_.size()); CHECK(instruction->opcode() == HloOpcode::kParameter); CHECK(!IsFusionComputation() || - fusion_instruction_->operand_count() == param_instructions_.size()); + FusionInstruction()->operand_count() == param_instructions_.size()); instruction->set_parent(this); HloInstruction* new_instruction = @@ -236,7 +322,7 @@ HloInstruction* HloComputation::ReplaceParameter( TF_CHECK_OK( old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction)); param_instructions_[param_no] = new_instruction; - TF_CHECK_OK(RemoveInstruction(old_instruction)); + TF_CHECK_OK(ForceRemoveInstruction(old_instruction)); return new_instruction; } @@ -355,8 +441,8 @@ Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) { Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check) { - VLOG(2) << "Removing instruction " << instruction->name() - << " from computation " << name(); + VLOG(2) << "Removing instruction " << instruction << " " + << instruction->name() << " from computation " << name(); TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction)) << "cannot remove instruction: " << instruction->ToString(); TF_RET_CHECK(instruction->IsDead()) << "instruction " << instruction->name() @@ -368,17 +454,25 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, << "instruction " << instruction->name() << " has control successors and cannot be removed"; - auto inst_it = instruction_iterators_.find(instruction); - TF_RET_CHECK(inst_it != instruction_iterators_.end()); - (*inst_it->second)->set_parent(nullptr); - to_be_deleted_.emplace_back(inst_it->second->release()); + HloInstructionInfo* info = &instructions_[instruction->index_in_parent_]; + DCHECK_EQ(info->inst(), instruction); + info->inst()->set_parent(nullptr); + to_be_deleted_.push_back(info->inst()); // Takes ownership to_be_deleted_.back()->DetachFromOperandsAndUsers(); // Clear all operands to avoid Null operands. to_be_deleted_.back()->RemoveAllOperands(); - to_be_deleted_.back()->ClearCalledComputations(); + // These require non-trivial cleanup for their called computations, + // which is invoked in the ops destructor. + if (!to_be_deleted_.back()->IsAsynchronous() && + !to_be_deleted_.back()->IsFused()) { + to_be_deleted_.back()->ClearCalledComputations(); + } to_be_deleted_.back()->MarkAsDead(); - instructions_.erase(inst_it->second); - instruction_iterators_.erase(inst_it); + // TODO(jeff): should we set info->opcode to something? + info->inst_ = + nullptr; // Leave a hole: this is no longer part of "instructions()" + instruction->index_in_parent_ = ~0u; + instruction_count_--; return OkStatus(); } @@ -411,35 +505,52 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, } } + // `root_instruction_` can be equal to `new_root_instruction` and so it is + // important that we call MarkAsNonRoot before calling MarkAsRoot. + root_instruction_->MarkAsNonRoot(); + new_root_instruction->MarkAsRoot(); root_instruction_ = new_root_instruction; } void HloComputation::ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited, - std::vector& post_order) const { + VisitMap& visited, std::vector& post_order, + std::vector* dfs_stack_scratch) const { ForEachInstructionPostOrderImpl( [&post_order](HloInstruction* hlo) { post_order.push_back(hlo); }, root, - channel_dependencies, visited); + channel_dependencies, visited, dfs_stack_scratch); } void HloComputation::ForEachInstructionPostOrderImpl( absl::FunctionRef func, HloInstruction* root, - const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited) const { - std::vector dfs_stack = {root}; - while (!dfs_stack.empty()) { - HloInstruction& current = *dfs_stack.back(); - - auto [it, was_inserted] = visited.insert({¤t, kVisiting}); - if (!was_inserted) { // We've already seen this instruction. - dfs_stack.pop_back(); - if (it->second != kVisited) { - DCHECK_EQ(current.parent(), this) - << "Instruction " << current.name() - << " is not in the current computation (" << name() << ")."; - func(¤t); - it->second = kVisited; + const ChannelDependencies& channel_dependencies, VisitMap& visited, + std::vector* dfs_stack_scratch) const { + bool has_channel_dependencies = !channel_dependencies.empty(); + auto* dfs_stack = dfs_stack_scratch; + dfs_stack->clear(); + + // Pushes instruction to dfs stack only if it was not already processed. + auto dfs_stack_push = [&](HloInstruction* instr) { + VisitState state = visited.GetState(instr->index_in_parent_); + if (state != VisitState::kVisited) dfs_stack->push_back(instr); + }; + + dfs_stack_push(root); + while (!dfs_stack->empty()) { + HloInstruction* current = dfs_stack->back(); + DCHECK_EQ(current->parent(), this) + << "Instruction " << current->name() + << " is not in the current computation (" << name() << ")."; + + VisitMap::Handle h = current->index_in_parent_; + VisitState state = visited.GetState(h); + if (state == VisitState::kNew) { + visited.SetState(h, VisitState::kVisiting); + } else { + dfs_stack->pop_back(); + if (state != VisitState::kVisited) { + visited.SetState(h, VisitState::kVisited); + func(current); } continue; } @@ -448,22 +559,22 @@ void HloComputation::ForEachInstructionPostOrderImpl( // Collectives with the same channel ID must be performed together, as these // represent MPMD-partitioned that will later be split into separate modules // and the order must be preserved. - if (¤t != root) { - auto it = channel_dependencies.find(¤t); + if (has_channel_dependencies && current != root) { + auto it = channel_dependencies.find(current); if (it != channel_dependencies.end()) { - dfs_stack.insert(dfs_stack.end(), it->second.begin(), it->second.end()); + absl::c_for_each(it->second, dfs_stack_push); } } // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer // result for things like HLO stringification. - const HloInstruction::InstructionVector& operands = current.operands(); - dfs_stack.insert(dfs_stack.end(), operands.rbegin(), operands.rend()); + const HloInstruction::InstructionVector& operands = current->operands(); + absl::c_for_each(tsl::gtl::make_range(operands.rbegin(), operands.rend()), + dfs_stack_push); - const std::vector& predecessors = - current.control_predecessors(); - dfs_stack.insert(dfs_stack.end(), predecessors.begin(), predecessors.end()); + // Add control predecessors to the stack. + absl::c_for_each(current->control_predecessors(), dfs_stack_push); } } @@ -480,21 +591,23 @@ HloComputation::ChannelDependencies HloComputation::ComputeChannelDependencies() // Create dependencies between partitioned collectives. ChannelDependencies dependencies; - for (const auto& instruction : instructions_) { - switch (instruction->opcode()) { + for (const auto& inst : instructions_with_info()) { + switch (inst.opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kAllGather: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kReduceScatter: { + HloInstruction* instruction = inst.inst(); std::optional channel_id = instruction->channel_id(); if (channel_id) { Instructions& group = channel_groups[*channel_id]; for (const HloInstruction* group_inst : group) { - dependencies[group_inst].push_back(instruction.get()); + dependencies[group_inst].push_back(instruction); } - dependencies[instruction.get()] = group; - group.push_back(instruction.get()); + dependencies[instruction] = group; + group.push_back(instruction); } break; } @@ -508,9 +621,11 @@ HloComputation::ChannelDependencies HloComputation::ComputeChannelDependencies() std::vector HloComputation::MakeInstructionPostOrderFrom( HloInstruction& postorder_root) const { std::vector post_order; - absl::flat_hash_map visited; + VisitMap visited(instructions_.size()); + + std::vector dfs_stack_scratch; ComputeInstructionPostOrder(&postorder_root, ComputeChannelDependencies(), - visited, post_order); + visited, post_order, &dfs_stack_scratch); return post_order; } @@ -522,15 +637,17 @@ std::vector HloComputation::MakeInstructionPostOrder( const ChannelDependencies& channel_dependencies) const { std::vector post_order; post_order.reserve(instruction_count()); - absl::flat_hash_map visited; - visited.reserve(instruction_count()); - for (auto& instruction : instructions_) { + VisitMap visited(instructions_.size()); + std::vector dfs_stack_scratch; + dfs_stack_scratch.reserve(instruction_count()); + + for (const auto& instruction : instructions()) { if (instruction->users().empty()) { - ComputeInstructionPostOrder(instruction.get(), channel_dependencies, - visited, post_order); + ComputeInstructionPostOrder(instruction, channel_dependencies, visited, + post_order, &dfs_stack_scratch); } } - CHECK_EQ(instructions_.size(), post_order.size()) + CHECK_EQ(instruction_count(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } @@ -600,13 +717,14 @@ HloComputation::MakeInstructionPostOrderWithReshapeFirst() const { void HloComputation::ForEachInstructionPostOrder( absl::FunctionRef func) const { - absl::flat_hash_map visited; - visited.reserve(instruction_count()); + VisitMap visited(instructions_.size()); + std::vector dfs_stack_scratch; + dfs_stack_scratch.reserve(instruction_count()); auto channel_dependencies = ComputeChannelDependencies(); - for (auto& instruction : instructions_) { + for (const auto& instruction : instructions()) { if (instruction->users().empty()) { - ForEachInstructionPostOrderImpl(func, instruction.get(), - channel_dependencies, visited); + ForEachInstructionPostOrderImpl(func, instruction, channel_dependencies, + visited, &dfs_stack_scratch); } } } @@ -618,25 +736,32 @@ std::vector HloComputation::MakeEmbeddedComputationsList() // The first element of the pair is the currently processed computation, the // second is iterator inside the instructions list of the computation that is // currently being processed. - std::stack> st; + using ComputationIter = + std::pair; + std::stack> st; // We cannot directly push (this, instructions_.cbegin()) to the stack, as the // stack should contain only mutable computations. Also, we don't want to // include the computation itself in the list of embedded computations. - for (auto* instruction : instructions()) { - auto process_called_computations = - [&](std::vector called_computations) { - // Put the called computations in reverse order onto the stack. - // Otherwise we don't match the recursive enumeration of - // computations, which processes the first called computation first. - absl::c_reverse(called_computations); - for (HloComputation* called_computation : called_computations) { - if (visited.insert(called_computation).second) { - st.emplace(called_computation, - called_computation->instructions_.cbegin()); - } - } - }; + for (const HloInstructionInfo& instruction : instructions_with_info()) { + using PtrVec = PtrVec; + auto process_called_computations = [&](const PtrVec& called_computations) { + if (called_computations.empty()) return; + // Put the called computations in reverse order onto the stack. + // Otherwise we don't match the recursive enumeration of + // computations, which processes the first called computation first. + std::reverse_iterator i( + called_computations.end()); + std::reverse_iterator rend( + called_computations.begin()); + for (; i != rend; ++i) { + HloComputation* called_computation = *i; + if (visited.insert(called_computation).second) { + st.emplace(called_computation, + called_computation->instructions_.cbegin()); + } + } + }; process_called_computations(instruction->called_computations()); while (!st.empty()) { auto& cur = st.top(); @@ -645,9 +770,19 @@ std::vector HloComputation::MakeEmbeddedComputationsList() st.pop(); post_order.push_back(computation); } else { - HloInstruction* next_instruction = cur.second->get(); - ++cur.second; - process_called_computations(next_instruction->called_computations()); + if (cur.second->inst() == nullptr) { + ++cur.second; + } else { + HloOpcode opcode = cur.second->opcode(); + HloInstruction* next_instruction = cur.second->get(); + ++cur.second; + if (HloInstruction::MightHaveCalledComputations(opcode)) { + process_called_computations( + next_instruction->called_computations()); + } else { + DCHECK(next_instruction->called_computations().empty()); + } + } } } } @@ -776,13 +911,13 @@ HloComputationProto HloComputation::ToProto() const { } proto.set_root_id(root_instruction()->unique_id()); *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); - proto.set_is_fusion_computation(is_fusion_computation_); + proto.set_is_fusion_computation(IsFusionComputation()); proto.set_execution_thread(IsMainThread() ? "" : std::string(execution_thread())); return proto; } -/* static */ StatusOr> +/* static */ absl::StatusOr> HloComputation::CreateFromProto( const HloComputationProto& proto, const absl::flat_hash_map& computation_map, @@ -838,10 +973,12 @@ HloComputation::CreateFromProto( }()); auto computation = absl::WrapUnique( - new HloComputation(proto.name(), parameter_count, &instructions, root, - /*fusion_instruction=*/nullptr)); + new HloComputation(proto.name(), parameter_count, &instructions, root)); computation->unique_id_ = proto.id(); - computation->is_fusion_computation_ = proto.is_fusion_computation(); + if (proto.is_fusion_computation()) { + computation->instruction_and_type_ = + static_cast(InstructionType::kFusion); + } if (!proto.execution_thread().empty()) { computation->SetExecutionThread(proto.execution_thread()); } @@ -889,47 +1026,76 @@ HloInstruction* HloComputation::CreateCallInstruction( return call_instruction; } -StatusOr HloComputation::CreateAsyncInstructions( +absl::StatusOr HloComputation::CreateAsyncInstructions( HloInstruction* instruction, absl::Span context_shapes, absl::string_view async_execution_thread, bool replace, bool override_names) { - Builder builder("async_computation"); - std::vector parameters(instruction->operand_count()); - std::vector parameter_shapes(instruction->operand_count()); - for (int i = 0; i < instruction->operand_count(); ++i) { - const Shape& parameter_shape = instruction->operand(i)->shape(); - parameters[i] = builder.AddInstruction(HloInstruction::CreateParameter( - i, parameter_shape, absl::StrCat("param_", i))); - parameter_shapes[i] = parameter_shape; - } - HloInstruction* root = builder.AddInstruction( - instruction->CloneWithNewOperands(instruction->shape(), parameters)); - if (override_names) { - root->SetAndSanitizeName(absl::StrCat(instruction->name(), ".cloned")); - } - HloComputation* async_computation = - parent_->AddEmbeddedComputation(builder.Build(root)); - std::vector start_shapes = { - ShapeUtil::MakeTupleShape(parameter_shapes), root->shape()}; - for (const Shape& context_shape : context_shapes) { - start_shapes.push_back(context_shape); - } - HloInstruction* async_start = AddInstruction(HloInstruction::CreateAsyncStart( - ShapeUtil::MakeTupleShape(start_shapes), instruction->operands(), - async_computation, /*async_group_id=*/std::nullopt, - async_execution_thread)); - HloInstruction* async_done = AddInstruction(HloInstruction::CreateAsyncDone( - root->shape(), async_start, async_computation, - /*async_group_id=*/std::nullopt, async_execution_thread)); - if (override_names) { - async_start->SetAndSanitizeName(absl::StrCat(root->name(), ".call-start")); - async_done->SetAndSanitizeName(absl::StrCat(root->name(), ".call-done")); + HloInstruction* async_start; + HloInstruction* async_done; + if (instruction->opcode() == HloOpcode::kCopy) { + // Until the async ops are unified, add specialized support for copy here. + // TODO(b/319466176): Remove this special case once this bug is complete. + // Note that CopyStart/CopyDone uses (dest_shape, src_shape, context) + // convention while async-start/async-done uses ((src_shapes), dest_shape, + // context). + std::vector context_shapes_tuple; + context_shapes_tuple.reserve(context_shapes.size() + 2); + Shape instruction_shape_destination = instruction->shape(); + context_shapes_tuple.push_back(instruction_shape_destination); + Shape instruction_shape_source = instruction->operand(0)->shape(); + context_shapes_tuple.push_back(instruction_shape_source); + context_shapes_tuple.insert(context_shapes_tuple.end(), + context_shapes.begin(), context_shapes.end()); + + async_start = AddInstruction(HloInstruction::CreateCopyStart( + ShapeUtil::MakeTupleShape(context_shapes_tuple), + instruction->mutable_operand(0))); + async_done = AddInstruction(HloInstruction::CreateUnary( + instruction_shape_destination, HloOpcode::kCopyDone, async_start)); + } else { + Builder builder("async_computation"); + std::vector parameters(instruction->operand_count()); + std::vector parameter_shapes(instruction->operand_count()); + for (int i = 0; i < instruction->operand_count(); ++i) { + const Shape& parameter_shape = instruction->operand(i)->shape(); + parameters[i] = builder.AddInstruction(HloInstruction::CreateParameter( + i, parameter_shape, absl::StrCat("param_", i))); + parameter_shapes[i] = parameter_shape; + } + HloInstruction* root = builder.AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), parameters)); + if (override_names) { + root->SetAndSanitizeName(absl::StrCat(instruction->name(), ".cloned")); + } + HloComputation* async_computation = + parent_->AddEmbeddedComputation(builder.Build(root)); + std::vector start_shapes = { + ShapeUtil::MakeTupleShape(parameter_shapes), root->shape()}; + for (const Shape& context_shape : context_shapes) { + start_shapes.push_back(context_shape); + } + async_start = AddInstruction(HloInstruction::CreateAsyncStart( + ShapeUtil::MakeTupleShape(start_shapes), instruction->operands(), + async_computation, async_execution_thread)); + async_done = AddInstruction( + HloInstruction::CreateAsyncDone(root->shape(), async_start)); + if (override_names) { + async_start->SetAndSanitizeName( + absl::StrCat(root->name(), ".call-start")); + async_done->SetAndSanitizeName(absl::StrCat(root->name(), ".call-done")); + } } async_start->set_metadata(instruction->metadata()); async_start->CopyBackendConfigFrom(instruction); async_done->set_metadata(instruction->metadata()); async_done->CopyBackendConfigFrom(instruction); - TF_RETURN_IF_ERROR(async_done->CopyAllControlDepsFrom(instruction)); + for (HloInstruction* control_pred : instruction->control_predecessors()) { + TF_RETURN_IF_ERROR(control_pred->AddControlDependencyTo(async_start)); + } + for (HloInstruction* control_successor : instruction->control_successors()) { + TF_RETURN_IF_ERROR(async_done->AddControlDependencyTo(control_successor)); + } + if (replace) { TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); TF_RETURN_IF_ERROR(ReplaceInstruction(instruction, async_done)); @@ -937,7 +1103,7 @@ StatusOr HloComputation::CreateAsyncInstructions( return async_done; } -StatusOr HloComputation::DeepCopyHelper( +absl::StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, ShapeIndex* index, absl::FunctionRef HloComputation::DeepCopyHelper( return copy_leaf(instruction, *index, this); } -StatusOr HloComputation::DeepCopyInstruction( +absl::StatusOr HloComputation::DeepCopyInstruction( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added) { if (instruction->parent() != this) { @@ -1007,7 +1173,8 @@ StatusOr HloComputation::DeepCopyInstruction( return DeepCopyHelper(instruction, &index, copy_leaf); } -StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( +absl::StatusOr +HloComputation::DeepCopyInstructionWithCustomCopier( HloInstruction* instruction, absl::FunctionRef HloComputation::ReplaceInstruction( +absl::StatusOr HloComputation::ReplaceInstruction( HloInstruction* old_instruction, HloInstruction* new_instruction, bool preserve_sharding, bool relay_control_dependency) { TF_RET_CHECK( @@ -1131,7 +1298,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return OkStatus(); } -StatusOr HloComputation::ReplaceInstructionWithDifferentShape( +absl::StatusOr HloComputation::ReplaceInstructionWithDifferentShape( HloInstruction* old_instruction, HloInstruction* new_instruction, bool preserve_sharding, bool relay_control_dependency, bool remove_unused_operands) { @@ -1145,6 +1312,10 @@ StatusOr HloComputation::ReplaceInstructionWithDifferentShape( TF_RETURN_IF_ERROR( new_instruction->CopyAllControlDepsFrom(old_instruction)); TF_RETURN_IF_ERROR(old_instruction->DropAllControlDeps()); + } else if (old_instruction->HasControlDependencies()) { + VLOG(10) << "Skipping replacement because old instruction has " + "control dependencies"; + return false; } VLOG(10) << "transformed " << old_instruction->ToString() << " to " << new_instruction->ToString(); @@ -1156,11 +1327,7 @@ StatusOr HloComputation::ReplaceInstructionWithDifferentShape( // But still this seems to be better than nothing. bool overwrite_op_name = new_instruction->metadata().op_name().empty() && !old_instruction->metadata().op_name().empty(); - bool overwrite_pass_id = - new_instruction->metadata().op_name().empty() && - new_instruction->metadata().logical_creation_pass_id() == 0 && - old_instruction->metadata().logical_creation_pass_id() != 0; - if (overwrite_op_name || overwrite_pass_id) { + if (overwrite_op_name) { new_instruction->set_metadata(old_instruction->metadata()); } if (new_instruction->frontend_attributes().map().empty()) { @@ -1343,8 +1510,7 @@ void SortClonedInstructionUsersAndControlLists( auto instruction_mapper = [&context, replace](const HloInstruction* i) { return context.FindInstruction(replace(i)); }; - for (const std::unique_ptr& instruction : - sorted_instructions) { + for (const HloInstructionInfo& instruction : sorted_instructions) { HloInstruction* cloned_instruction = context.FindInstruction(replace(instruction.get())); if (!cloned_instruction) { @@ -1405,12 +1571,13 @@ std::unique_ptr HloComputation::CloneInContext( // ourselves. std::vector postorder; absl::flat_hash_map visited; - for (const auto& instr : instructions_) { - std::vector dfs_stack; - const HloInstruction* new_instr = replace(instr.get()); + std::vector dfs_stack; + for (const auto& instr : instructions()) { + const HloInstruction* new_instr = replace(instr); if (!new_instr) { continue; } + dfs_stack.clear(); dfs_stack.push_back(new_instr); while (!dfs_stack.empty()) { @@ -1418,16 +1585,16 @@ std::unique_ptr HloComputation::CloneInContext( auto it = visited.find(cur); if (it != visited.end()) { dfs_stack.pop_back(); - if (it->second == kVisited) { + if (it->second == VisitState::kVisited) { continue; } - CHECK_EQ(it->second, kVisiting); + CHECK_EQ(it->second, VisitState::kVisiting); postorder.push_back(cur); - it->second = kVisited; + it->second = VisitState::kVisited; continue; } - visited.insert({cur, kVisiting}); + visited.insert({cur, VisitState::kVisiting}); for (HloInstruction* operand : cur->operands()) { const HloInstruction* new_operand = replace(operand); if (new_operand) { @@ -1495,7 +1662,6 @@ std::unique_ptr HloComputation::CloneInContext( context.MapComputation(this, result.get()); result->SetExecutionThread(execution_thread()); - return result; } diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h index 92d2e00013bb1..37c7b63b98df5 100644 --- a/xla/hlo/ir/hlo_computation.h +++ b/xla/hlo/ir/hlo_computation.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_COMPUTATION_H_ #define XLA_HLO_IR_HLO_COMPUTATION_H_ +#include #include #include #include @@ -26,12 +27,16 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/iterator_util.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" @@ -64,14 +69,12 @@ class HloModule; class HloComputation { public: // Used by instructions_. - using InstructionList = std::list>; + using InstructionList = std::vector; // Builder class for HloComputation. class Builder { public: - explicit Builder(absl::string_view name, - HloInstruction* fusion_instruction = nullptr) - : name_(name), fusion_instruction_(fusion_instruction) {} + explicit Builder(absl::string_view name) : name_(name) {} Builder(Builder&& b) = default; virtual ~Builder() = default; @@ -99,11 +102,11 @@ class HloComputation { return AddInstruction(std::move(instruction)); } - StatusOr AddParameter( + absl::StatusOr AddParameter( std::unique_ptr parameter) { if (!parameter_numbers_.insert(parameter->parameter_number()).second) { - return InternalError("Duplicate parameter number %d", - parameter->parameter_number()); + return Internal("Duplicate parameter number %d", + parameter->parameter_number()); } return AddInstruction(std::move(parameter)); } @@ -122,7 +125,6 @@ class HloComputation { private: const std::string name_; - HloInstruction* fusion_instruction_; std::vector> instructions_; absl::flat_hash_set parameter_numbers_; @@ -148,6 +150,49 @@ class HloComputation { OpMetadata metadata_; }; + // Helper class for returning the instruction post order for a computation, + // but maintaining a cache to avoid repeated calls to + // computation->MakeInstructionPostorder(). The cache is invalidated if + // RecordChange() is called. + // + // This class can be handy to avoid recomputing the instruction post order + // when an optimization pass wants to make multiple passes over the + // instructions. + // + // Example usage: + // for (auto* computation : module->computations(execution_threads)) { + // HloComputation::CachingPostOrder cpo(computation); + // for (auto instruction : cpo.PostOrder()) { // Pass 1 + // bool did_change = ... maybe do something to instruction ...; + // cpo.RecordChange(did_change); + // } + // for (auto instruction : cpo.PostOrder()) { // Pass 2 + // bool did_change = ... maybe do something else to instruction ...; + // cpo.RecordChange(did_change); + // } + // } + class CachingPostOrder { + public: + explicit CachingPostOrder(const HloComputation* computation) + : computation_(computation), recompute_(true) {} + + // Returns the instruction post-order for "computation" + const std::vector& PostOrder() { + if (recompute_) { + cached_post_order_ = computation_->MakeInstructionPostOrder(); + recompute_ = false; + } + return cached_post_order_; + } + + void RecordChange(bool changed) { recompute_ |= changed; } + + private: + const HloComputation* computation_; + bool recompute_; + std::vector cached_post_order_; + }; + ~HloComputation(); // Add an instruction to the computation. The computation takes ownership of @@ -252,8 +297,20 @@ class HloComputation { absl::string_view name() const { return name_; } + // Sets the string identifier for this computation. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + // + // See also HloModule::SetAndUniquifyComputationName(), which does this plus + // UniqufyName(). + void SetAndSanitizeName(absl::string_view name) { + name_ = NameUniquer::GetSanitizedName(name); + } + // Use the given NameUniquer to select a unique name for the computation based // on the computation's existing name. + // + // See also HloModule::SetAndUniquifyComputationName(), which does this plus + // SetAndSanitizeName(). void UniquifyName(NameUniquer* name_uniquer); // Prints a string representation of the computation. @@ -300,7 +357,7 @@ class HloComputation { // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - static StatusOr> CreateFromProto( + static absl::StatusOr> CreateFromProto( const HloComputationProto& proto, const absl::flat_hash_map& computation_map, bool prohibit_empty_literal = true); @@ -319,10 +376,10 @@ class HloComputation { } using InstructionSequence = tsl::gtl::iterator_range< - UnwrappingIterator>::iterator>>; + UnwrappingIterator>; - using ConstInstructionSequence = tsl::gtl::iterator_range>::const_iterator>>; + using ConstInstructionSequence = tsl::gtl::iterator_range< + UnwrappingIterator>; // Gets the instructions in this computation. // @@ -331,13 +388,33 @@ class HloComputation { // // for (HloInstruction* instr : computation->instructions()) { ... } // - ConstInstructionSequence instructions() const { - return {MakeUnwrappingIterator(instructions_.begin()), - MakeUnwrappingIterator(instructions_.end())}; + + tsl::gtl::iterator_range + instructions() const { + const int end = instructions_.size(); + return {HloInstructionUnwrappingConstIterator( + HloInstructionConstIterator(&instructions_, 0, end)), + HloInstructionUnwrappingConstIterator( + HloInstructionConstIterator(&instructions_, end, end))}; + } + tsl::gtl::iterator_range + instructions() { + const int end = instructions_.size(); + return {HloInstructionUnwrappingIterator( + HloInstructionIterator(&instructions_, 0, end)), + HloInstructionUnwrappingIterator( + HloInstructionIterator(&instructions_, end, end))}; } - InstructionSequence instructions() { - return {MakeUnwrappingIterator(instructions_.begin()), - MakeUnwrappingIterator(instructions_.end())}; + tsl::gtl::iterator_range instructions_with_info() { + const int end = instructions_.size(); + return {HloInstructionIterator(&instructions_, 0, end), + HloInstructionIterator(&instructions_, end, end)}; + } + tsl::gtl::iterator_range instructions_with_info() + const { + const int end = instructions_.size(); + return {HloInstructionConstIterator(&instructions_, 0, end), + HloInstructionConstIterator(&instructions_, end, end)}; } using ChannelDependencies = @@ -362,7 +439,7 @@ class HloComputation { void ForEachInstructionPostOrder( absl::FunctionRef func) const; - int64_t instruction_count() const { return instruction_iterators_.size(); } + int64_t instruction_count() const { return instruction_count_; } // Creates and returns a list of the embedded computations called by this // computation. This includes all embedded computations called directly or @@ -400,7 +477,7 @@ class HloComputation { // If `replace` is true, replace instruction with the async done instruction. // If `override_names` is true, the clone on `instruction` and the async op // created will get non-default names. - StatusOr CreateAsyncInstructions( + absl::StatusOr CreateAsyncInstructions( HloInstruction* instruction, absl::Span context_shapes, absl::string_view async_execution_thread = HloInstruction::kMainExecutionThread, @@ -416,14 +493,14 @@ class HloComputation { // transparently. If copies_added is non-null, then the added kCopy // instructions will be inserted in the respective index in the given // ShapeTree. - StatusOr DeepCopyInstruction( + absl::StatusOr DeepCopyInstruction( HloInstruction* instruction, const ShapeTree* indices_to_copy = nullptr, ShapeTree* copies_added = nullptr); // As above, but uses a custom function to copy the leaf nodes, which could // create alternative HLOs other than kCopy, or even pass-throughs. - StatusOr DeepCopyInstructionWithCustomCopier( + absl::StatusOr DeepCopyInstructionWithCustomCopier( HloInstruction* instruction, absl::FunctionRef ReplaceInstruction(HloInstruction* old_instruction, - HloInstruction* new_instruction, - bool preserve_sharding, - bool relay_control_dependency = false); + absl::StatusOr ReplaceInstruction( + HloInstruction* old_instruction, HloInstruction* new_instruction, + bool preserve_sharding, bool relay_control_dependency = false); // Same as above, with preserve_sharding=false. Since this replacement always // happens, it returns just a Status as opposed to StatusOr @@ -512,7 +588,7 @@ class HloComputation { // Same as ReplaceInstruction, but the new instruction can have a different // shape. - StatusOr ReplaceInstructionWithDifferentShape( + absl::StatusOr ReplaceInstructionWithDifferentShape( HloInstruction* old_instruction, HloInstruction* new_instruction, bool preserve_sharding, bool relay_control_dependency = false, bool remove_unused_operands = true); @@ -636,112 +712,108 @@ class HloComputation { // Returns if this computation is a fusion computation. // Do not use this method to determine if fusion_instruction_ != nullptr. // Instead, directly do: FusionInstruction() != nullptr - bool IsFusionComputation() const { return is_fusion_computation_; } + bool IsFusionComputation() const { + return instruction_type() == InstructionType::kFusion; + } // Returns if this computation is the entry computation of the module. bool IsEntryComputation() const; // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation. - HloInstruction* FusionInstruction() const { return fusion_instruction_; } + HloInstruction* FusionInstruction() const { + return instruction_type() == InstructionType::kFusion ? instruction() + : nullptr; + } void SetFusionInstruction(HloInstruction* fusion_instruction) { - CHECK(!IsCustomCallComputation() && !IsAsyncComputation() && - !IsCollectiveCalledComputation() && !IsWhileBodyComputation()); - fusion_instruction_ = fusion_instruction; - is_fusion_computation_ |= (fusion_instruction != nullptr); + SetInstruction(fusion_instruction, InstructionType::kFusion); } // Returns if this computation is a custom-call computation. - bool IsCustomCallComputation() const { return is_custom_call_computation_; } + bool IsCustomCallComputation() const { + return instruction_type() == InstructionType::kCustomCall; + } // Returns the owning custom call instruction, or nullptr if this is not a // custom call computation. HloInstruction* CustomCallInstruction() const { - return custom_call_instruction_; + return instruction_type() == InstructionType::kCustomCall ? instruction() + : nullptr; } void SetCustomCallInstruction(HloInstruction* custom_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCollectiveCalledComputation() && !IsWhileBodyComputation()); - custom_call_instruction_ = custom_call_instruction; - is_custom_call_computation_ |= (custom_call_instruction != nullptr); + SetInstruction(custom_call_instruction, InstructionType::kCustomCall); } // Returns if this computation is a to_apply region of a collective. bool IsCollectiveCalledComputation() const { - return is_collective_called_computation_; + return instruction_type() == InstructionType::kCollective; } // Returns the owning collective call instruction, or nullptr if this is not a // collective call computation. HloInstruction* CollectiveCallInstruction() const { - return collective_call_instruction_; + return instruction_type() == InstructionType::kCollective ? instruction() + : nullptr; } void SetCollectiveCallInstruction( HloInstruction* collective_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCustomCallComputation() && !IsWhileBodyComputation()); - collective_call_instruction_ = collective_call_instruction; - is_collective_called_computation_ |= - (collective_call_instruction != nullptr); + SetInstruction(collective_call_instruction, InstructionType::kCollective); } // Returns if this computation is a body computation of a while. bool IsWhileBodyComputation() const { - return is_while_call_body_computation_; + return instruction_type() == InstructionType::kWhile; } // Returns the owning while call instruction, or nullptr if this is not a // while call body computation. HloInstruction* WhileCallInstruction() const { - return while_call_instruction_; + return instruction_type() == InstructionType::kWhile ? instruction() + : nullptr; } void SetWhileCallInstruction(HloInstruction* while_call_instruction) { - CHECK(!IsFusionComputation() && !IsAsyncComputation() && - !IsCustomCallComputation() && !IsCollectiveCalledComputation()); CHECK(while_call_instruction != nullptr); CHECK(while_call_instruction->opcode() == HloOpcode::kWhile); - while_call_instruction_ = while_call_instruction; - is_while_call_body_computation_ |= (while_call_instruction != nullptr); + SetInstruction(while_call_instruction, InstructionType::kWhile); } - // Returns if this computation is an async computation. - bool IsAsyncComputation() const { return !async_instructions_.empty(); } - - // Returns the owning async instruction. It's empty if this is not an async - // computation. - const std::vector& AsyncInstructions() const { - return async_instructions_; + // Returns if this computation is a branch computation of a conditional. + bool IsConditionalBranchComputation() const { + return instruction_type() == InstructionType::kConditional; } - std::vector& AsyncInstructions() { - return async_instructions_; + // Returns the owning conditional call instruction, or nullptr if this is not + // a conditional branch computation. + HloInstruction* ConditionalCallInstruction() const { + return instruction_type() == InstructionType::kConditional ? instruction() + : nullptr; } - void AddAsyncInstruction(HloInstruction& async_instruction) { - CHECK(!IsFusionComputation() && !IsCustomCallComputation()); - CHECK(async_instruction.opcode() == HloOpcode::kAsyncStart || - async_instruction.opcode() == HloOpcode::kAsyncUpdate || - async_instruction.opcode() == HloOpcode::kAsyncDone); - async_instructions_.push_back(&async_instruction); + void SetConditionalCallInstruction( + HloInstruction* conditional_call_instruction) { + CHECK(conditional_call_instruction != nullptr); + CHECK(conditional_call_instruction->opcode() == HloOpcode::kConditional); + SetInstruction(conditional_call_instruction, InstructionType::kConditional); } - void RemoveAsyncInstruction(HloInstruction* async_instruction) { - if (async_instruction == nullptr) { - return; - } - async_instructions_.erase( - std::remove(async_instructions_.begin(), async_instructions_.end(), - async_instruction), - async_instructions_.end()); - } + // Returns if this computation is an async computation. + bool IsAsyncComputation() const { return async_start_ != nullptr; } - // Returns if this computation is invoked by an Hlo instruction. - bool IsCalledComputation() const { - return IsFusionComputation() || IsCustomCallComputation(); + // Returns the owning async instruction. It's nullptr if this is not an async + // computation. + HloInstruction* AsyncStart() const { return async_start_; } + + void AddAsyncStart(HloInstruction* async_instruction) { + // TODO: Add instruction type for async instructions. + CHECK(instruction_type() == InstructionType::kUnset); + CHECK(async_instruction->opcode() == HloOpcode::kAsyncStart); + async_start_ = async_instruction; } + void RemoveAsyncStart() { async_start_ = nullptr; } + // Clear the unique ID of the computation so that it can be re-assigned, such // as for the purpose of compacting the unique IDs. void ClearUniqueIdInternal() { unique_id_ = -1; } @@ -773,7 +845,12 @@ class HloComputation { // stage clean up process is designed such that HloPass can have stable // internal pointers to HloInstructions while we create and remove // HloInstructions in a pass. - void Cleanup() { to_be_deleted_.clear(); } + void Cleanup() { + for (HloInstruction* it : to_be_deleted_) { + delete it; + } + to_be_deleted_.clear(); + } // Returns true if a given instruction is marked dead in this computation. bool IsMarkedAsDead(const HloInstruction* inst); @@ -785,7 +862,7 @@ class HloComputation { explicit HloComputation( const std::string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, HloInstruction* fusion_instruction); + HloInstruction* root_instruction); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -806,7 +883,7 @@ class HloComputation { // Internal helper for recursive copying of an instruction. Creates and // returns a deep copy of the given instruction. - StatusOr DeepCopyHelper( + absl::StatusOr DeepCopyHelper( HloInstruction* instruction, ShapeIndex* index, absl::FunctionRef CollectUnreachableRoots() const; - enum VisitState { kVisiting, kVisited }; + class VisitMap; void ComputeInstructionPostOrder( HloInstruction* root, const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited, - std::vector& post_order) const; + VisitMap& visited, std::vector& post_order, + std::vector* dfs_stack_scratch) const; void ForEachInstructionPostOrderImpl( absl::FunctionRef func, HloInstruction* root, - const ChannelDependencies& channel_dependencies, - absl::flat_hash_map& visited) const; + const ChannelDependencies& channel_dependencies, VisitMap& visited, + std::vector* dfs_stack_scratch) const; Status RemoveUnusedParametersImpl(bool allow_non_fusion); Status RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check); - std::string name_; - int64_t unique_id_; - HloInstruction* root_instruction_; - - // If this computation is a fusion computation, this field points to the - // corresponding fusion instruction (if it is live). Otherwise, this is null. - HloInstruction* fusion_instruction_; - - // Determines whether this computation is a fusion computation. A fusion - // computation ordinarily also has a non-null fusion_instruction_. However, if - // a fusion instruction is removed during compilation, the fusion computation - // becomes unreachable, and its fusion_instruction_ is set to null. We still - // need to regard such computations as fusion computations for HLO scheduling - // purposes. - bool is_fusion_computation_; + enum class InstructionType : uint8_t { + kUnset, + // This computation is a fusion computation. A fusion computation ordinarily + // also has a non-null instruction. However, if a fusion instruction + // is removed during compilation, the fusion computation becomes + // unreachable, and its instruction is set to null. We still need to regard + // such computations as fusion computations for HLO scheduling purposes. + kFusion, + // This computation is a custom-call computation. + kCustomCall, + // This computation is a while body computation. + kCollective, + // This computation is a while body computation. + kWhile, + // This computation is a conditional branch computation. + kConditional, + }; + static constexpr uintptr_t kInstructionTypeMask = 0b111; + static_assert(static_cast(InstructionType::kUnset) == 0, + "kUnset must be 0."); - // If this computation is a custom-call computation, this field points to the - // corresponding custom-call instruction (if it is live). Otherwise, this is - // null. - HloInstruction* custom_call_instruction_; + InstructionType instruction_type() const { + return static_cast(instruction_and_type_ & + kInstructionTypeMask); + } - // Determines whether this computation is a custom-call computation. - bool is_custom_call_computation_; + HloInstruction* instruction() const { + return reinterpret_cast(instruction_and_type_ & + ~kInstructionTypeMask); + } - // If this computation is a collective sub-computation, this field points to - // the corresponding collective instruction. Otherwise, this is null. - HloInstruction* collective_call_instruction_; + void SetInstruction(HloInstruction* instruction, InstructionType type); - // Determines whether this computation is a collective sub-computation. - bool is_collective_called_computation_; + int64_t unique_id_; + HloInstruction* root_instruction_; - // If this computation is a while body computation, this field points to - // the corresponding while instruction. Otherwise, this is null. - HloInstruction* while_call_instruction_; + // Module containing this computation. + HloModule* parent_ = nullptr; - // Determines whether this computation is a while body computation. - bool is_while_call_body_computation_; + // Contains HloInstruction* and its type. + // The respective type in the least significant three bits. + uintptr_t instruction_and_type_ = 0; // If this computation is an async computation, this field points to the - // corresponding async instructions (if live) that call this computation. + // first async instruction (async-start) in the asynchronous op chain that + // calls this computation. // Otherwise, this is empty. - std::vector async_instructions_; + HloInstruction* async_start_ = nullptr; - // Execution thread of this computation. By default, it's main thread. - std::string execution_thread_ = HloInstruction::kMainExecutionThread; + HloInstruction::InstructionVector param_instructions_; - // Module containing this computation. - HloModule* parent_ = nullptr; + // Store instructions in std::vector as they can be added and removed + // arbitrarily and we want a stable iteration order. + // For the reverse mapping we use HloInstruction::index_in_parent_. + HloInstructionList instructions_; - // Store instructions in std::list as they can be added and removed - // arbitrarily and we want a stable iteration order. Keep a map from - // instruction pointer to location in the list for fast lookup. - InstructionList instructions_; - absl::flat_hash_map - instruction_iterators_; + // Number of not-marked-for-deletion entries in instructions_. + int64_t instruction_count_; // Removed instructions are moved into to_be_deleted_ first and then // deallocated when Cleanup is called. - std::vector> to_be_deleted_; + PtrVec to_be_deleted_; - HloInstruction::InstructionVector param_instructions_; + // Execution thread of this computation. By default, it's main thread. + std::string execution_thread_ = HloInstruction::kMainExecutionThread; + + std::string name_; HloComputation(const HloComputation&) = delete; HloComputation& operator=(const HloComputation&) = delete; @@ -929,9 +1012,6 @@ Status HloComputation::AcceptOrdered( absl::flat_hash_set visited; for (const HloInstruction* instruction : order) { VLOG(3) << "Visiting ordered: " << instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.contains(instruction)) - << "Instruction " << instruction->name() << " is not in computation " - << name(); TF_RET_CHECK(!visited.contains(instruction)) << "Instruction " << instruction->name() << " appears more than once in order"; diff --git a/xla/hlo/ir/hlo_dfs_reachability.cc b/xla/hlo/ir/hlo_dfs_reachability.cc new file mode 100644 index 0000000000000..c831f31cec03f --- /dev/null +++ b/xla/hlo/ir/hlo_dfs_reachability.cc @@ -0,0 +1,113 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_dfs_reachability.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { + +bool HloDfsReachability::IsPresent(const HloInstruction* instruction) const { + return instruction_to_idx_.contains(instruction); +} + +bool HloDfsReachability::IsReachable(const HloInstruction* from, + const HloInstruction* to) const { + if (from == to) { + return true; + } + if (to->operand_count() == 0 && from->control_predecessors().empty()) { + return false; + } + + size_t target_node_idx = instruction_to_idx_.at(from); + size_t dfs_root_idx = instruction_to_idx_.at(to); + + // Note that the DFS goes from the "uses" root towards the "defs", i.e. from + // `to` node to `from` node, so the node indices are decreasing. + if (dfs_root_idx < target_node_idx) { + return false; + } + + // We use LLVM support library here because it has stack-allocated bit vector + // which significantly improves performance by avoiding heap allocations when + // instructions are reachable via a short chain. + llvm::SmallVector stack{to}; + + // We will visit instructions in the [target_node_idx, dfs_root_idx] range, so + // we can construct a smaller bit vector. + llvm::BitVector visited_idxs(1 + (dfs_root_idx - target_node_idx)); + visited_idxs.set(dfs_root_idx - target_node_idx); + + auto check_and_enqueue = [&](const HloInstruction* instr) { + if (instr == from) { + return true; + } + size_t instr_idx = instruction_to_idx_.at(instr); + if (instr_idx < target_node_idx) { + return false; + } + size_t visited_idx = instr_idx - target_node_idx; + if (visited_idxs.test(visited_idx)) { + return false; + } + visited_idxs.set(visited_idx); + stack.push_back(instr); + return false; + }; + + while (!stack.empty()) { + const HloInstruction* instr = stack.pop_back_val(); + + if (absl::c_any_of(instr->operands(), check_and_enqueue) || + absl::c_any_of(instr->control_predecessors(), check_and_enqueue)) { + return true; + } + } + return false; +} + +bool HloDfsReachability::IsConnected(const HloInstruction* a, + const HloInstruction* b) const { + return IsReachable(a, b) || IsReachable(b, a); +} + +std::unique_ptr HloDfsReachability::Build( + const HloComputation* computation) { + auto res = std::make_unique(); + + // For instruction reachability we do not care about correct order of + // collective operations as we only care about use-def chains. + HloComputation::ChannelDependencies empty_channel_dependencies; + std::vector instructions = + computation->MakeInstructionPostOrder(empty_channel_dependencies); + + res->instruction_to_idx_.reserve(instructions.size()); + for (size_t i = 0; i < instructions.size(); ++i) { + res->instruction_to_idx_[instructions[i]] = i; + } + + return res; +} + +} // namespace xla diff --git a/xla/hlo/ir/hlo_dfs_reachability.h b/xla/hlo/ir/hlo_dfs_reachability.h new file mode 100644 index 0000000000000..3db9a5309b4ef --- /dev/null +++ b/xla/hlo/ir/hlo_dfs_reachability.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ +#define XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { + +// A simple DFS-based reachability analysis for HLO instructions. +// +// When the class is created, the instructions are ordered in a defs-before-uses +// topological order. +// The reachability query runs a DFS from the destination node (going up through +// operands / control predecessors), and stops when the instruction's index in +// the defs-before-uses list is before the source node. As the reachability is +// tested for nodes that are close to each other, this optimization works well, +// and the time is dominated by the post-order sort. +class HloDfsReachability { + public: + // Returns true iff the instruction was present in the computation passed to + // Build(). The calling code may want to still use the class after the + // computation is modified, if it's known that the def-before-use order is + // still preserved. + bool IsPresent(const HloInstruction* instruction) const; + // Returns true iff there is a path (with edges being users and control + // successors) from 'from' to 'to'. (i.e. path from definitions to uses; from + // producers to consumers) + bool IsReachable(const HloInstruction* from, const HloInstruction* to) const; + // Returns true iff either `a` is reachable from `b` or `b` is reachable from + // `a`. + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + static std::unique_ptr Build( + const HloComputation* computation); + + private: + // LLVM dense map shows ~10-20% speedup compared to absl::flat_hash_map. + llvm::DenseMap instruction_to_idx_; +}; + +} // namespace xla + +#endif // XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ diff --git a/xla/hlo/ir/hlo_domain_metadata.h b/xla/hlo/ir/hlo_domain_metadata.h index 2849dd2e3b44c..a90f6d8a5ede2 100644 --- a/xla/hlo/ir/hlo_domain_metadata.h +++ b/xla/hlo/ir/hlo_domain_metadata.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_frontend_attributes.cc b/xla/hlo/ir/hlo_frontend_attributes.cc index f5ae28dcf15c5..347edcec61f39 100644 --- a/xla/hlo/ir/hlo_frontend_attributes.cc +++ b/xla/hlo/ir/hlo_frontend_attributes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_frontend_attributes.h b/xla/hlo/ir/hlo_frontend_attributes.h index 7e38d15806dce..73486915d3ae7 100644 --- a/xla/hlo/ir/hlo_frontend_attributes.h +++ b/xla/hlo/ir/hlo_frontend_attributes.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/hlo_input_output_alias_config.cc b/xla/hlo/ir/hlo_input_output_alias_config.cc index 87117bbc2a58a..3dca857c43a84 100644 --- a/xla/hlo/ir/hlo_input_output_alias_config.cc +++ b/xla/hlo/ir/hlo_input_output_alias_config.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -91,7 +91,8 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { return result; } -StatusOr HloInputOutputAliasConfig::CreateFromProto( +absl::StatusOr +HloInputOutputAliasConfig::CreateFromProto( Shape output_shape, const HloInputOutputAliasProto& proto) { HloInputOutputAliasConfig result(std::move(output_shape)); for (const HloInputOutputAliasProto::AliasEntryProto& entry : @@ -221,7 +222,7 @@ Status HloInputOutputAliasConfig::Verify( TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape)); if (size_func(param_subshape) != size_func(output_subshape)) { - return InternalError( + return Internal( "Expected aliased input %lld at index %s and output at index %s to " "have the same size. Input sub-shape is %s with size %lld, output " "sub-shape is %s with size %lld", @@ -278,7 +279,7 @@ HloBufferDonorProto HloBufferDonorConfig::ToProto() const { return result; } -StatusOr HloBufferDonorConfig::CreateFromProto( +absl::StatusOr HloBufferDonorConfig::CreateFromProto( const HloBufferDonorProto& proto) { HloBufferDonorConfig result; for (const HloBufferDonorProto::BufferDonorEntryProto& entry : @@ -334,7 +335,7 @@ Status HloBufferDonorConfig::Verify(const HloModule& module) const { TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); if (alias_config.ParameterHasAlias(donor.param_number, donor.param_index)) { - return InternalError( + return Internal( "Input %lld at index %s is registered as a buffer donor. However, it " "is also in the input output alias config.", donor.param_number, donor.param_index.ToString()); diff --git a/xla/hlo/ir/hlo_input_output_alias_config.h b/xla/hlo/ir/hlo_input_output_alias_config.h index ab06940319aa3..236621dbe0538 100644 --- a/xla/hlo/ir/hlo_input_output_alias_config.h +++ b/xla/hlo/ir/hlo_input_output_alias_config.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" #include "xla/service/hlo.pb.h" @@ -92,7 +94,7 @@ class HloInputOutputAliasConfig { // HloInputOutputAliasProto. HloInputOutputAliasProto ToProto() const; - static StatusOr CreateFromProto( + static absl::StatusOr CreateFromProto( Shape output_shape, const HloInputOutputAliasProto& proto); // Returns the output index that the given parameter and parameter index is @@ -168,6 +170,14 @@ class HloBufferDonorConfig { param_index == other.param_index; } + bool operator<(const BufferDonor& other) const { + return std::forward_as_tuple(param_number, param_index) < + std::forward_as_tuple(other.param_number, other.param_index); + } + bool operator>(const BufferDonor& other) const { return other < *this; } + bool operator<=(const BufferDonor& other) const { return !(*this > other); } + bool operator>=(const BufferDonor& other) const { return !(*this < other); } + // A hash function borrowed from go/absl-hash. template friend H AbslHashValue(H h, const BufferDonor& donor) { @@ -189,7 +199,7 @@ class HloBufferDonorConfig { // (De)Serializes an HloBufferDonorConfig to/from an HloBufferDonorProto. HloBufferDonorProto ToProto() const; - static StatusOr CreateFromProto( + static absl::StatusOr CreateFromProto( const HloBufferDonorProto& proto); // Verifies that the given config is valid for the given module. @@ -198,7 +208,7 @@ class HloBufferDonorConfig { Status Verify(const HloModule& module) const; // Returns the registered buffer donors - const absl::flat_hash_set& buffer_donor() const { + const absl::btree_set& buffer_donor() const { return buffer_donor_; } @@ -208,7 +218,7 @@ class HloBufferDonorConfig { private: // A set recording the registered buffer donors. - absl::flat_hash_set buffer_donor_; + absl::btree_set buffer_donor_; }; std::ostream& operator<<(std::ostream& out, diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 2ffc311685329..1b54ddac95e12 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,14 +36,17 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -58,6 +61,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/iterator_util.h" #include "xla/layout.h" #include "xla/literal.h" @@ -80,8 +84,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/human_readable_json.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -90,6 +92,134 @@ using absl::StrAppend; using absl::StrCat; using absl::StrJoin; +// Empty static object +const HloInstruction::Rare* const HloInstruction::kEmptyRare = + new HloInstruction::Rare; + +namespace { +// Specialization for erasing from PtrVec. +template +Status EraseElementFromVector(PtrVec* container, T value) { + // absl::c_find returns a const_iterator which does not seem to work on + // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot. + auto it = std::find(container->begin(), container->end(), value); + TF_RET_CHECK(it != container->end()); + container->erase(it); + return OkStatus(); +} +} // namespace + +HloInstruction::Users::~Users() = default; + +void HloInstruction::Users::Clear() { + users_.clear(); + user_map_.reset(nullptr); + DCHECK(CheckInvariants()); +} + +bool HloInstruction::Users::Contains(const HloInstruction* instruction) const { + if (user_map_ == nullptr) { + return std::find(users_.begin(), users_.end(), instruction) != users_.end(); + } else { + return user_map_->contains(instruction); + } +} + +void HloInstruction::Users::AddUser(HloInstruction* user) { + if (!Contains(user)) { + // Create hash table if user list is large. + if (user_map_ == nullptr && users_.size() >= kMapThreshold) { + user_map_ = + std::make_unique>( + users_.size()); + RebuildMap(); + DCHECK(CheckInvariants()); + } + + if (user_map_ != nullptr) { + user_map_->emplace(user, users_.size()); + } + users_.push_back(user); + DCHECK(CheckInvariants()); + } +} + +int64_t HloInstruction::Users::UserId(HloInstruction* user) { + if (user_map_ == nullptr) { + auto it = std::find(users_.begin(), users_.end(), user); + CHECK(it != users_.end()); + return it - users_.begin(); + } else { + auto result = user_map_->find(user); + CHECK(result != user_map_->end()); + return result->second; + } +} + +void HloInstruction::Users::MaybeRemoveUser(HloInstruction* user) { + if (Contains(user)) { + RemoveUser(user); + DCHECK(CheckInvariants()); + } +} + +void HloInstruction::Users::RemoveUser(HloInstruction* user) { + const int64_t index = UserId(user); + CHECK_EQ(users_[index], user); + + // Move the last user into the position of the removed user. + HloInstruction* last = users_.back(); + + // Update map if allocated. + if (user_map_ != nullptr) { + (*user_map_)[last] = index; + user_map_->erase(user); + } + + // Replace found user with last slot from the vector. + users_[index] = last; + users_.pop_back(); + + DCHECK(CheckInvariants()); +} + +void HloInstruction::Users::SortInstructionUsers( + const MappedPtrContainerSorter::MapPtrFn& map_fn, + const Users& sorted_instruction_users) { + using Sorter = MappedPtrContainerSorter; + auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), + sorted_instruction_users.users_, users_); + if (!status.ok()) { + LOG(ERROR) << "Failed to sort instruction users: " << status; + } + if (user_map_ != nullptr) { + user_map_->clear(); + RebuildMap(); + } + DCHECK(CheckInvariants()); +} + +void HloInstruction::Users::RebuildMap() { + for (uint64_t i = 0; i < users_.size(); ++i) { + (*user_map_)[users_[i]] = i; + } +} + +bool HloInstruction::Users::CheckInvariants() { + if (user_map_ != nullptr) { + // Avoid quadratic behavior by doing a quick and dirty check on + // size instead of actually comparing mapped indices. + CHECK_EQ(users_.size(), user_map_->size()); + } + return true; +} + +void HloInstruction::AppendComputation(HloComputation* computation) { + // In .cc file since PtrVec::push_back() wants to check the alignment + // of T and hlo_instruction.h does not include hlo_computation.h. + mutable_rare()->called_computations.push_back(computation); +} + HloInstruction* HloInstruction::AddInstruction( std::unique_ptr derived_instruction) { HloInstruction* derived = @@ -104,7 +234,7 @@ HloInstruction* HloInstruction::AddInstruction( } /* static */ -StatusOr> HloInstruction::CreateFromProto( +absl::StatusOr> HloInstruction::CreateFromProto( const HloInstructionProto& proto, const absl::flat_hash_map& instruction_map, const absl::flat_hash_map& computation_map, @@ -153,6 +283,19 @@ StatusOr> HloInstruction::CreateFromProto( }); return result; }; + const auto output_to_operand_aliasing = [&proto]() { + std::vector>> + output_to_operand_aliasing; + for (const auto& aliasing : proto.output_operand_aliasing()) { + output_to_operand_aliasing.emplace_back( + ShapeIndex(aliasing.output_shape_index().begin(), + aliasing.output_shape_index().end()), + std::make_pair(aliasing.operand_index(), + ShapeIndex(aliasing.operand_shape_index().begin(), + aliasing.operand_shape_index().end()))); + } + return output_to_operand_aliasing; + }; const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; @@ -215,46 +358,56 @@ StatusOr> HloInstruction::CreateFromProto( << "Async start instruction should have 1 called computation but " "sees " << proto.called_computation_ids_size(); - std::optional async_group_id; - if (proto.async_group_id() >= 0) { - async_group_id = proto.async_group_id(); - } instruction = CreateAsyncStart(shape, all_operands(), computations(0), - async_group_id, proto.async_execution_thread().empty() ? kMainExecutionThread : proto.async_execution_thread()); break; } case HloOpcode::kAsyncUpdate: { - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Async update instruction should have 1 called computation but " - "sees " - << proto.called_computation_ids_size(); - std::optional async_group_id; - if (proto.async_group_id() >= 0) { - async_group_id = proto.async_group_id(); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Async update requires one singular operand"; + HloInstruction* prev_op = operands(0); + TF_RET_CHECK(prev_op->IsAsynchronous()) + << "Async update requires its operand to be an asynchronous op"; + if (!proto.async_execution_thread().empty()) { + TF_RET_CHECK(proto.async_execution_thread() == + prev_op->async_execution_thread()) + << "Async update should have " << prev_op->async_execution_thread() + << " async_execution_thread, but sees " + << proto.async_execution_thread(); } - instruction = - CreateAsyncUpdate(shape, operands(0), computations(0), async_group_id, - proto.async_execution_thread().empty() - ? kMainExecutionThread - : proto.async_execution_thread()); + if (!proto.called_computation_ids().empty()) { + TF_RET_CHECK(computations(0) == prev_op->async_wrapped_computation()) + << "Async update should have " + << prev_op->async_wrapped_computation()->name() + << " async_wrapped_computation, but sees " + << computations(0)->name(); + } + instruction = CreateAsyncUpdate(shape, prev_op); break; } case HloOpcode::kAsyncDone: { - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Async done instruction should have 1 called computation but sees " - << proto.called_computation_ids_size(); - std::optional async_group_id; - if (proto.async_group_id() >= 0) { - async_group_id = proto.async_group_id(); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Async done requires one singular operand"; + HloInstruction* prev_op = operands(0); + TF_RET_CHECK(prev_op->IsAsynchronous()) + << "Async done requires its operand to be an asynchronous op"; + if (!proto.async_execution_thread().empty()) { + TF_RET_CHECK(proto.async_execution_thread() == + prev_op->async_execution_thread()) + << "Async done should have " << prev_op->async_execution_thread() + << " async_execution_thread, but sees " + << proto.async_execution_thread(); } - instruction = - CreateAsyncDone(shape, operands(0), computations(0), async_group_id, - proto.async_execution_thread().empty() - ? kMainExecutionThread - : proto.async_execution_thread()); + if (!proto.called_computation_ids().empty()) { + TF_RET_CHECK(computations(0) == prev_op->async_wrapped_computation()) + << "Async done should have " + << prev_op->async_wrapped_computation()->name() + << " async_wrapped_computation, but sees " + << computations(0)->name(); + } + instruction = CreateAsyncDone(shape, prev_op); break; } case HloOpcode::kCopyStart: { @@ -467,19 +620,9 @@ StatusOr> HloInstruction::CreateFromProto( << "No fusion computation with id " << fusion_id; instruction = CreateFusion(shape, fusion_kind, all_operands(), fused_computation); - std::vector>> - output_to_operand_aliasing; - for (const auto& aliasing : proto.output_operand_aliasing()) { - output_to_operand_aliasing.emplace_back( - ShapeIndex(aliasing.output_shape_index().begin(), - aliasing.output_shape_index().end()), - std::make_pair(aliasing.operand_index(), - ShapeIndex(aliasing.operand_shape_index().begin(), - aliasing.operand_shape_index().end()))); - } auto fusion_instr = DynCast(instruction.get()); fusion_instr->set_output_to_operand_aliasing( - std::move(output_to_operand_aliasing)); + output_to_operand_aliasing()); break; } case HloOpcode::kRng: @@ -613,6 +756,17 @@ StatusOr> HloInstruction::CreateFromProto( /*channel_id=*/channel_id, split_dimension); break; } + case HloOpcode::kCollectiveBroadcast: { + std::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } + auto replica_groups = std::vector( + proto.replica_groups().begin(), proto.replica_groups().end()); + instruction = CreateCollectiveBroadcast( + shape, all_operands(), replica_groups, false, channel_id); + break; + } case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: { TF_RET_CHECK(proto.operand_ids().size() == 1 || @@ -836,18 +990,8 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); *custom_call_instr->mutable_precision_config() = precision_config; - std::vector>> - output_to_operand_aliasing; - for (const auto& aliasing : proto.output_operand_aliasing()) { - output_to_operand_aliasing.emplace_back( - ShapeIndex(aliasing.output_shape_index().begin(), - aliasing.output_shape_index().end()), - std::make_pair(aliasing.operand_index(), - ShapeIndex(aliasing.operand_shape_index().begin(), - aliasing.operand_shape_index().end()))); - } custom_call_instr->set_output_to_operand_aliasing( - std::move(output_to_operand_aliasing)); + output_to_operand_aliasing()); custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule()); custom_call_instr->set_api_version(proto.custom_call_api_version()); break; @@ -939,16 +1083,27 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateIota(shape, proto.dimensions(0)); break; case HloOpcode::kDot: { + int expected_operands = + HloDotInstruction::kOperands + proto.dot_sparsity_size(); + TF_RET_CHECK(proto.dot_sparsity_size() <= HloDotInstruction::kOperands) + << "Too many sparse dot descriptors: " << proto.dot_sparsity_size(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << proto.opcode() << " instruction should have " << expected_operands + << " operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_dot_dimension_numbers()) << "Dot instruction should have dot_dimension_numbers."; TF_RET_CHECK(absl::c_all_of(proto.precision_config().operand_precision(), PrecisionConfig::Precision_IsValid)); PrecisionConfig precision_config = proto.precision_config(); precision_config.mutable_operand_precision()->Resize( - proto.operand_ids_size(), PrecisionConfig::DEFAULT); + HloDotInstruction::kOperands, PrecisionConfig::DEFAULT); + std::vector sparsity(proto.dot_sparsity().begin(), + proto.dot_sparsity().end()); + auto operand_vector = all_operands(); instruction = std::make_unique( shape, operands(0), operands(1), proto.dot_dimension_numbers(), - precision_config); + precision_config, std::move(sparsity), + absl::MakeSpan(operand_vector).subspan(HloDotInstruction::kOperands)); break; } case HloOpcode::kDomain: { @@ -988,8 +1143,9 @@ StatusOr> HloInstruction::CreateFromProto( inferred_dimension = proto.dimensions()[0]; } TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() && - ShapeUtil::ElementsIn(shape) == - ShapeUtil::ElementsIn(operands(0)->shape())) + (operands(0)->shape().is_unbounded_dynamic() || + ShapeUtil::StaticExtentProduct(shape) == + ShapeUtil::StaticExtentProduct(operands(0)->shape()))) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operands(0)->shape()); instruction = CreateReshape(shape, operands(0), inferred_dimension); @@ -997,8 +1153,8 @@ StatusOr> HloInstruction::CreateFromProto( } case HloOpcode::kDynamicReshape: { TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() && - ShapeUtil::ElementsIn(shape) == - ShapeUtil::ElementsIn(operands(0)->shape())) + ShapeUtil::StaticExtentProduct(shape) == + ShapeUtil::StaticExtentProduct(operands(0)->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operands(0)->shape()); const auto& operand_vector = all_operands(); @@ -1006,27 +1162,37 @@ StatusOr> HloInstruction::CreateFromProto( shape, operands(0), absl::MakeSpan(operand_vector).subspan(1)); break; } + case HloOpcode::kCall: { + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Call should have 1 called computation but has " + << proto.called_computation_ids_size(); + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.name(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); + + auto call_instruction = new HloCallInstruction( + shape, all_operands(), + computation_map.at(proto.called_computation_ids()[0])); + call_instruction->set_output_to_operand_aliasing( + output_to_operand_aliasing()); + instruction = absl::WrapUnique(call_instruction); + break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "While should have 2 called computation but has " + << proto.called_computation_ids_size(); + } + for (const int64_t operand_id : proto.operand_ids()) { instruction->AppendOperand(instruction_map.at(operand_id)); } - if (instruction->opcode() != HloOpcode::kFusion) { - if (instruction->opcode() == HloOpcode::kCall) { - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Call should have 1 called computation but has " - << proto.called_computation_ids_size(); - } - if (instruction->opcode() == HloOpcode::kWhile) { - TF_RET_CHECK(proto.called_computation_ids_size() == 2) - << "While should have 2 called computation but has " - << proto.called_computation_ids_size(); - } - for (const int64_t computation_id : proto.called_computation_ids()) { - instruction->called_computations_.push_back( - computation_map.at(computation_id)); - } + for (const int64_t computation_id : proto.called_computation_ids()) { + instruction->AppendComputation(computation_map.at(computation_id)); } + TF_RET_CHECK(!proto.has_precision_config()) << instruction->opcode() << proto.DebugString(); TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); @@ -1043,7 +1209,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); - instruction->metadata_ = proto.metadata(); + *instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); TF_RET_CHECK(proto.id() >= 0) @@ -1155,6 +1321,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kCos: case HloOpcode::kOptimizationBarrier: case HloOpcode::kClz: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1258,29 +1425,23 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, /* static */ std::unique_ptr HloInstruction::CreateAsyncStart( const Shape& shape, absl::Span operands, - HloComputation* async_computation, std::optional async_group_id, + HloComputation* async_computation, absl::string_view async_execution_thread) { - return std::make_unique( + return std::make_unique( HloOpcode::kAsyncStart, shape, operands, async_computation, - async_group_id, async_execution_thread); + async_execution_thread); } /* static */ std::unique_ptr HloInstruction::CreateAsyncUpdate( - const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, std::optional async_group_id, - absl::string_view async_execution_thread) { - return std::make_unique( - HloOpcode::kAsyncUpdate, shape, operand, async_computation, - async_group_id, async_execution_thread); + const Shape& shape, HloInstruction* operand) { + return std::make_unique(HloOpcode::kAsyncUpdate, shape, + operand); } /* static */ std::unique_ptr HloInstruction::CreateAsyncDone( - const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, std::optional async_group_id, - absl::string_view async_execution_thread) { - return std::make_unique( - HloOpcode::kAsyncDone, shape, operand, async_computation, async_group_id, - async_execution_thread); + const Shape& shape, HloInstruction* operand) { + return std::make_unique(HloOpcode::kAsyncDone, shape, + operand); } /* static */ std::unique_ptr HloInstruction::CreateCopyStart( @@ -1312,9 +1473,12 @@ HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config) { + const PrecisionConfig& precision_config, + std::vector sparsity, + absl::Span sparse_meta) { return std::make_unique(shape, lhs, rhs, dimension_numbers, - precision_config); + precision_config, + std::move(sparsity), sparse_meta); } /* static */ std::unique_ptr @@ -1390,6 +1554,16 @@ HloInstruction::CreateAllReduceStart( split_dimension); } +/* static */ std::unique_ptr +HloInstruction::CreateCollectiveBroadcast( + const Shape& shape, absl::Span operands, + absl::Span replica_groups, bool constrain_layout, + const std::optional& channel_id) { + return std::make_unique( + HloOpcode::kCollectiveBroadcast, shape, operands, replica_groups, + constrain_layout, channel_id); +} + /* static */ std::unique_ptr HloInstruction::CreateCollectivePermute( const Shape& shape, HloInstruction* operand, @@ -1481,6 +1655,12 @@ HloInstruction::CreateCollectivePermuteStart( is_host_transfer); } +/* static */ std::unique_ptr HloInstruction::CreateSendDone( + HloInstruction* operand, int64_t channel_id, bool is_host_transfer) { + return std::make_unique(operand, channel_id, + is_host_transfer); +} + /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64_t channel_id, bool is_host_transfer) { @@ -1497,6 +1677,12 @@ HloInstruction::CreateCollectivePermuteStart( is_host_transfer); } +/* static */ std::unique_ptr HloInstruction::CreateRecvDone( + HloInstruction* operand, int64_t channel_id, bool is_host_transfer) { + return std::make_unique(operand, channel_id, + is_host_transfer); +} + /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, absl::Span dimensions) { @@ -1536,9 +1722,9 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. - instruction->called_computations_.push_back(body); - instruction->called_computations_.push_back(condition); - // Set back pointer from body computation to the while call instruction + instruction->AppendComputation(body); + instruction->AppendComputation(condition); + // Set back pointer from body computation to the while call instruction. body->SetWhileCallInstruction(instruction.get()); return instruction; } @@ -1552,11 +1738,14 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); - // In called_computations_, the index of true_computation must be 0 and that + // In called_computations, the index of true_computation must be 0 and that // of false computation must be 1, as defined by kTrueComputationIndex and // kFalseComputationIndex. - instruction->called_computations_.push_back(true_computation); - instruction->called_computations_.push_back(false_computation); + instruction->AppendComputation(true_computation); + instruction->AppendComputation(false_computation); + // Set back pointer from computations to the conditional instruction. + true_computation->SetConditionalCallInstruction(instruction.get()); + false_computation->SetConditionalCallInstruction(instruction.get()); return instruction; } @@ -1569,8 +1758,10 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, instruction->AppendOperand(branch_index); CHECK_EQ(branch_computations.size(), branch_computation_args.size()); for (int i = 0; i < branch_computations.size(); ++i) { - instruction->called_computations_.push_back(branch_computations[i]); + instruction->AppendComputation(branch_computations[i]); instruction->AppendOperand(branch_computation_args[i]); + // Set back pointer from the computation to the conditional instruction. + branch_computations[i]->SetConditionalCallInstruction(instruction.get()); } return instruction; } @@ -1830,8 +2021,9 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) { - CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(operand->shape())) + CHECK(operand->shape().is_unbounded_dynamic() || + ShapeUtil::StaticExtentProduct(shape) == + ShapeUtil::StaticExtentProduct(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); @@ -1843,8 +2035,8 @@ HloInstruction::CreateBroadcastSequence( HloInstruction::CreateDynamicReshape( const Shape& shape, HloInstruction* data_operand, absl::Span dim_sizes) { - CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(data_operand[0].shape())) + CHECK_EQ(ShapeUtil::StaticExtentProduct(shape), + ShapeUtil::StaticExtentProduct(data_operand[0].shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(data_operand[0].shape()); CHECK_EQ(shape.rank(), dim_sizes.size()); @@ -1899,13 +2091,19 @@ void HloInstruction::SetupDerivedInstruction( } else { derived_instruction->clear_sharding(); } - derived_instruction->set_metadata(metadata_); - derived_instruction->set_frontend_attributes(frontend_attributes_); - derived_instruction->set_statistics_viz(statistics_viz_); -} - -bool HloInstruction::IsRoot() const { - return parent_ != nullptr && this == parent_->root_instruction(); + derived_instruction->set_metadata(*metadata_); + if (has_rare()) { + derived_instruction->set_frontend_attributes(frontend_attributes()); + derived_instruction->set_statistics_viz(statistics_viz()); + } else if (derived_instruction->has_rare()) { + derived_instruction->mutable_rare()->frontend_attributes.Clear(); + derived_instruction->mutable_rare()->statistics_viz.Clear(); + } + // If the derived instruction has the same opcode as current, + // then the backend config is also applicable. + if (opcode() == derived_instruction->opcode() && has_backend_config()) { + derived_instruction->CopyBackendConfigFrom(this); + } } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1922,6 +2120,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kAllReduceDone: case HloOpcode::kAllGatherStart: case HloOpcode::kAllGatherDone: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: return true; @@ -2163,6 +2362,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReduceScatter: case HloOpcode::kAllReduceStart: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kInfeed: @@ -2200,6 +2400,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kOptimizationBarrier: case HloOpcode::kCopyDone: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -2303,6 +2504,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 0); clone = CreatePartitionId(shape); break; + default: + CHECK(0) << "Unsupported opcode: " << opcode_; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -2338,9 +2541,7 @@ void HloInstruction::DetachFromOperandsAndUsers() { if (operand == nullptr) { continue; } - if (operand->user_map_.find(this) != operand->user_map_.end()) { - operand->RemoveUser(this); - } + operand->users_.MaybeRemoveUser(this); operands_[operand_num] = nullptr; } @@ -2398,12 +2599,12 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { } const HloInstruction* HloInstruction::operand(int64_t i) const { - return operands_.at(i); + return operands_[i]; } HloInstruction* HloInstruction::mutable_operand(int64_t i) { CHECK(operands_[i] != nullptr); - return operands_.at(i); + return operands_[i]; } int64_t HloInstruction::operand_index(const HloInstruction* target) const { @@ -2428,42 +2629,52 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - if (!absl::c_linear_search(control_successors_, instruction)) { - control_successors_.push_back(instruction); - TF_RET_CHECK( - !absl::c_linear_search(instruction->control_predecessors_, this)); - instruction->control_predecessors_.push_back(this); + if (!absl::c_linear_search(control_successors(), instruction)) { + mutable_rare()->control_successors.push_back(instruction); + TF_RET_CHECK(!absl::c_linear_search( + instruction->rare()->control_predecessors, this)); + instruction->mutable_rare()->control_predecessors.push_back(this); } return OkStatus(); } Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction)); - TF_RETURN_IF_ERROR( - EraseElementFromVector(&instruction->control_predecessors_, this)); + if (has_rare()) { + TF_RETURN_IF_ERROR(EraseElementFromVector( + &mutable_rare()->control_successors, instruction)); + } + if (instruction->has_rare()) { + TF_RETURN_IF_ERROR(EraseElementFromVector( + &instruction->mutable_rare()->control_predecessors, this)); + } return OkStatus(); } Status HloInstruction::DropAllControlDeps() { - for (auto* ctrl_succ : control_successors_) { - TF_RETURN_IF_ERROR( - EraseElementFromVector(&ctrl_succ->control_predecessors_, this)); - } - for (auto* ctrl_pred : control_predecessors_) { - TF_RETURN_IF_ERROR( - EraseElementFromVector(&ctrl_pred->control_successors_, this)); + if (has_rare()) { + for (auto* ctrl_succ : rare()->control_successors) { + TF_RETURN_IF_ERROR(EraseElementFromVector( + &ctrl_succ->mutable_rare()->control_predecessors, this)); + } + for (auto* ctrl_pred : rare()->control_predecessors) { + TF_RETURN_IF_ERROR(EraseElementFromVector( + &ctrl_pred->mutable_rare()->control_successors, this)); + } + Rare* r = mutable_rare(); + r->control_successors.clear(); + r->control_predecessors.clear(); } - control_successors_.clear(); - control_predecessors_.clear(); return OkStatus(); } Status HloInstruction::SafelyDropAllControlDependencies() { // Add all pairs of transitive dependencies from predecessors to successors. - for (HloInstruction* predecessor : control_predecessors_) { - for (HloInstruction* successor : control_successors_) { - TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(successor)); + if (has_rare()) { + for (HloInstruction* predecessor : rare()->control_predecessors) { + for (HloInstruction* successor : rare()->control_successors) { + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(successor)); + } } } TF_RETURN_IF_ERROR(DropAllControlDeps()); @@ -2471,18 +2682,18 @@ Status HloInstruction::SafelyDropAllControlDependencies() { } bool HloInstruction::HasControlDependencies() const { - return !control_predecessors_.empty() || !control_successors_.empty(); + const Rare* r = rare(); + return (!r->control_predecessors.empty() || !r->control_successors.empty()); } -Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { - for (auto* ctrl_pred : inst->control_predecessors()) { - TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this)); +Status HloInstruction::CopyAllControlDepsTo(HloInstruction* start, + HloInstruction* end) const { + for (auto* ctrl_pred : control_predecessors()) { + TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(start)); } - - for (auto* ctrl_succ : inst->control_successors()) { - TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ)); + for (auto* ctrl_succ : control_successors()) { + TF_RETURN_IF_ERROR(end->AddControlDependencyTo(ctrl_succ)); } - return OkStatus(); } @@ -2584,19 +2795,6 @@ void HloInstruction::RemoveOperandsAtAscendingIndices( operands_.resize(operands_.size() - removed_count); } -void HloInstruction::AddUser(HloInstruction* user) { - if (!ContainsKey(user_map_, user)) { - user_map_.emplace(user, users_.size()); - users_.push_back(user); - } -} - -int64_t HloInstruction::UserId(HloInstruction* user) { - auto result = user_map_.find(user); - CHECK(result != user_map_.end()); - return result->second; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -2633,6 +2831,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -2735,6 +2934,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReduceScatter: case HloOpcode::kAllReduceStart: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConvolution: @@ -2758,23 +2958,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } -void HloInstruction::RemoveUser(HloInstruction* user) { - auto map_it = user_map_.find(user); - CHECK(map_it != user_map_.end()); - - const int64_t index = map_it->second; - CHECK_EQ(users_[index], user); - - // Move the last user into the position of the removed user. - users_[index] = users_.back(); - user_map_[users_.back()] = index; - - // Remove the user from the map and drop the last slot from the vector what - // have been moved to the position of the original user. - user_map_.erase(map_it); - users_.pop_back(); -} - Status HloInstruction::ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer) { TF_RET_CHECK( @@ -2862,6 +3045,54 @@ Status HloInstruction::ReplaceOperandWithDifferentShape( return OkStatus(); } +// Copy all the instructions in the given fusion instruction into the fusion +// instruction's parent computation and replace the use of the fusion +// instruction with the copy of the fusion expression root. +Status HloInstruction::Defuse() { + if (opcode() != HloOpcode::kFusion) { + return OkStatus(); + } + VLOG(2) << "Defusing instruction: " << ToString(); + + HloComputation* fused_computation = fused_instructions_computation(); + + // A map from fused instruction to its defused clone. + absl::flat_hash_map + defused_instructions; + // Initialize map to contain the fusion instruction parameters mapping + // to the operands of the fusion instruction. + for (int64_t i = 0; i < operand_count(); ++i) { + defused_instructions[fused_computation->parameter_instruction(i)] = + mutable_operand(i); + } + + // Create a clone of each instruction of the fused computation in the same + // computation as the fusion instruction itself. + // TODO(b/68227302): Moving instruction to new computation rather than + // cloning and deleting. + for (HloInstruction* fused_instruction : + fused_computation->MakeInstructionPostOrder()) { + if (fused_instruction->opcode() == HloOpcode::kParameter) { + continue; + } + std::vector new_operands; + for (HloInstruction* operand : fused_instruction->operands()) { + new_operands.push_back(defused_instructions.at(operand)); + } + HloInstruction* defused_instruction = + parent()->AddInstruction(fused_instruction->CloneWithNewOperands( + fused_instruction->shape(), new_operands)); + defused_instructions[fused_instruction] = defused_instruction; + } + + TF_RETURN_IF_ERROR( + ReplaceAllUsesWith(defused_instructions.at(fused_expression_root()))); + + HloModule* module = GetModule(); + TF_RETURN_IF_ERROR(parent()->RemoveInstruction(this)); + return module->RemoveEmbeddedComputation(fused_computation); +} + Status HloInstruction::ReplaceUsesWith(absl::Span users, HloInstruction* new_producer) { TF_RET_CHECK( @@ -2872,7 +3103,9 @@ Status HloInstruction::ReplaceUsesWith(absl::Span users, Status HloInstruction::ReplaceAllUsesWithDifferentShape( absl::Span users, HloInstruction* new_producer) { - for (HloInstruction* user : users) { + // Make a copy since users span might get mutated during the loop + std::vector users_vector(users.begin(), users.end()); + for (HloInstruction* user : users_vector) { TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer)); } @@ -2900,7 +3133,9 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer, Status HloInstruction::ReplaceAllUsesWithDifferentShape( HloInstruction* new_producer) { bool new_producer_is_user = false; - for (HloInstruction* user : users()) { + // Make a copy since users span might get mutated during the loop + std::vector users_vector(users().begin(), users().end()); + for (HloInstruction* user : users_vector) { if (user == new_producer) { // It's possible that new_producer is a user of this instruction as might // be the case when replacing an instruction with a kCopy of itself. In @@ -2917,8 +3152,7 @@ Status HloInstruction::ReplaceAllUsesWithDifferentShape( } } } - users_.clear(); - user_map_.clear(); + users_.Clear(); if (new_producer_is_user) { AddUser(new_producer); } @@ -2939,18 +3173,18 @@ bool HloInstruction::IsEffectiveBitcast() const { HloComputation* HloInstruction::to_apply() const { if (has_to_apply()) { - CHECK_EQ(called_computations_.size(), 1) + CHECK_EQ(called_computations().size(), 1) << "Expected a to_apply computation for " << opcode(); - return called_computations_[0]; + return called_computations()[0]; } LOG(FATAL) << "Invalid opcode for to_apply(): " << opcode(); } void HloInstruction::set_to_apply(HloComputation* computation) { if (has_to_apply()) { - CHECK_EQ(called_computations_.size(), 1) + CHECK_EQ(called_computations().size(), 1) << "Expected a to_apply computation for " << opcode(); - called_computations_[0] = computation; + rare_->called_computations[0] = computation; return; } LOG(FATAL) << "Invalid opcode for to_apply(): " << opcode(); @@ -2967,12 +3201,11 @@ bool HloInstruction::has_to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSort: - case HloOpcode::kTopK: return true; case HloOpcode::kCustomCall: // CustomCall can have a to_apply computation, but it is not required to // have one. - return called_computations_.size() == 1; + return called_computations().size() == 1; default: return false; } @@ -2980,22 +3213,22 @@ bool HloInstruction::has_to_apply() const { HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return called_computations_[kConditionComputationIndex]; + return called_computations()[kConditionComputationIndex]; } HloComputation* HloInstruction::while_body() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return called_computations_[kBodyComputationIndex]; + return called_computations()[kBodyComputationIndex]; } void HloInstruction::set_while_condition(HloComputation* computation) { CHECK_EQ(HloOpcode::kWhile, opcode_); - called_computations_[kConditionComputationIndex] = computation; + rare_->called_computations[kConditionComputationIndex] = computation; } void HloInstruction::set_while_body(HloComputation* computation) { CHECK_EQ(HloOpcode::kWhile, opcode_); - called_computations_[kBodyComputationIndex] = computation; + rare_->called_computations[kBodyComputationIndex] = computation; } HloInstruction* HloInstruction::while_init() const { @@ -3006,37 +3239,36 @@ HloInstruction* HloInstruction::while_init() const { HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); CHECK_EQ(PRED, operand(0)->shape().element_type()); - return called_computations_[kTrueComputationIndex]; + return called_computations()[kTrueComputationIndex]; } HloComputation* HloInstruction::false_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); CHECK_EQ(PRED, operand(0)->shape().element_type()); - return called_computations_[kFalseComputationIndex]; + return called_computations()[kFalseComputationIndex]; } -const std::vector& HloInstruction::branch_computations() - const { +const PtrVec& HloInstruction::branch_computations() const { CHECK(HloOpcode::kConditional == opcode_); - return called_computations_; + return called_computations(); } int HloInstruction::branch_count() const { CHECK(HloOpcode::kConditional == opcode_); - return called_computations_.size(); + return called_computations().size(); } HloComputation* HloInstruction::branch_computation(int b) const { CHECK(HloOpcode::kConditional == opcode_); CHECK_GE(b, 0); - CHECK_LT(b, called_computations_.size()); - return called_computations_[b]; + CHECK_LT(b, called_computations().size()); + return called_computations()[b]; } void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); - called_computations_[b] = computation; + rare_->called_computations[b] = computation; } std::string HloInstruction::SignatureString() const { @@ -3149,6 +3381,7 @@ bool HloInstruction::IsOpElementwise(HloOpcode opcode) { case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -3270,8 +3503,8 @@ void HloInstruction::PrintWithCanonicalNameMap( // Print opcode, operand(s). if (options.syntax_sugar_async_ops() && HloOpcodeIsAsync(opcode()) && - (!called_computations_.empty() && - called_computations_[0]->CanExpandIntoSingleInstruction())) { + (async_wrapped_computation() && + async_wrapped_computation()->CanExpandIntoSingleInstruction())) { absl::string_view suffix = [&]() { switch (opcode()) { case HloOpcode::kAsyncStart: @@ -3302,10 +3535,11 @@ void HloInstruction::PrintWithCanonicalNameMap( PrintExtraAttributes(attr_printer, options); if (options.print_metadata() && - (!metadata_.op_type().empty() || !metadata_.op_name().empty() || - !metadata_.source_file().empty())) { + (!metadata_->op_type().empty() || !metadata_->op_name().empty() || + !metadata_->source_file().empty())) { printer->Append(", metadata={"); - printer->Append(xla::OpMetadataToString(metadata_)); + printer->Append(xla::OpMetadataToString( + *metadata_, options.print_metadata_only_op_name())); printer->Append("}"); } if (options.print_backend_config() && !backend_config_.empty()) { @@ -3469,9 +3703,10 @@ void HloInstruction::PrintExtraAttributes( }); } } else if (HloOpcodeIsAsync(opcode())) { - if (!options.syntax_sugar_async_ops() || - (!called_computations().empty() && - !called_computations_[0]->CanExpandIntoSingleInstruction())) { + if (opcode() == HloOpcode::kAsyncStart && + (!options.syntax_sugar_async_ops() || + (async_wrapped_computation() && + !async_wrapped_computation()->CanExpandIntoSingleInstruction()))) { printer.Next([this, &options](Printer* printer) { printer->Append("calls="); PrintNameInternal(printer, async_wrapped_computation()->name(), @@ -3574,17 +3809,17 @@ void HloInstruction::PrintExtraAttributes( sharding().Print(printer, options.print_metadata()); }); } - if (!frontend_attributes_.map().empty()) { + if (!rare()->frontend_attributes.map().empty()) { printer.Next([this](Printer* printer) { AppendCat(printer, "frontend_attributes=", - FrontendAttributesToString(frontend_attributes_)); + FrontendAttributesToString(rare()->frontend_attributes)); }); } - if (options.print_control_dependencies() && !control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors().empty()) { printer.Next([this, &options](Printer* printer) { printer->Append("control-predecessors={"); - AppendJoin(printer, control_predecessors_, ", ", + AppendJoin(printer, control_predecessors(), ", ", [&](Printer* printer, HloInstruction* pre) { PrintNameInternal(printer, pre->name(), options); }); @@ -3592,26 +3827,10 @@ void HloInstruction::PrintExtraAttributes( }); } - if (!statistics_viz_.statistics().empty()) { + if (!statistics_viz().statistics().empty()) { printer.Next([this](Printer* printer) { - AppendCat(printer, "statistics=", StatisticsVizToString(statistics_viz_)); - }); - } - - if (operation_queue_id_) { - printer.Next([this](Printer* printer) { - AppendCat(printer, "operation_queue_id=", *operation_queue_id_); - }); - } - - if (wait_on_operation_queues_.size() > 0) { - printer.Next([this, &options](Printer* printer) { - printer->Append("wait_on_operation_queues={"); - AppendJoin(printer, wait_on_operation_queues_, ", ", - [&](Printer* printer, int64_t queue_id) { - printer->Append(queue_id); - }); - printer->Append("}"); + AppendCat(printer, + "statistics=", StatisticsVizToString(statistics_viz())); }); } } @@ -3663,14 +3882,14 @@ HloInstructionProto HloInstruction::ToProto() const { for (const HloInstruction* operand : operands_) { proto.add_operand_ids(operand->unique_id()); } - for (const HloInstruction* control : control_predecessors_) { + for (const HloInstruction* control : control_predecessors()) { proto.add_control_predecessor_ids(control->unique_id()); } - *proto.mutable_metadata() = metadata_; + *proto.mutable_metadata() = *metadata_; proto.set_backend_config(backend_config_.GetRawString()); if (opcode() != HloOpcode::kFusion) { - for (const HloComputation* computation : called_computations_) { + for (const HloComputation* computation : called_computations()) { proto.add_called_computation_ids(computation->unique_id()); } } @@ -3679,9 +3898,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_sharding() = sharding().ToProto(); } - *proto.mutable_frontend_attributes() = frontend_attributes_; + *proto.mutable_frontend_attributes() = frontend_attributes(); - *proto.mutable_statistics_viz() = statistics_viz_; + *proto.mutable_statistics_viz() = statistics_viz(); return proto; } @@ -3756,10 +3975,14 @@ bool HloInstruction::IsFusible() const { HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), + index_in_parent_(~0u), opcode_(opcode), + is_default_config_(false), + cleaned_up_(false), + marked_as_dead_(false), + is_root_(false), shape_(shape), - name_(HloOpcodeString(opcode)), - marked_as_dead_(false) { + name_(HloOpcodeString(opcode)) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } @@ -3780,6 +4003,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleBatchNormInference(this); case HloOpcode::kBatchNormGrad: return visitor->HandleBatchNormGrad(this); + case HloOpcode::kErf: + return visitor->HandleErf(this); case HloOpcode::kLogistic: return visitor->HandleLogistic(this); case HloOpcode::kSign: @@ -3856,6 +4081,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAllReduceDone(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectiveBroadcast: + return visitor->HandleCollectiveBroadcast(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); case HloOpcode::kCollectivePermuteStart: @@ -4004,11 +4231,12 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCholesky(this); case HloOpcode::kOptimizationBarrier: return visitor->HandleOptimizationBarrier(this); + default: + return Internal( + "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " + "please file a bug for XLA.", + HloOpcodeString(opcode_)); } - return InternalError( - "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " - "please file a bug for XLA.", - HloOpcodeString(opcode_)); } // Explicit instantiations. @@ -4357,7 +4585,7 @@ absl::string_view ToString(HloInstruction::FusionKind kind) { } } -StatusOr StringToFusionKind( +absl::StatusOr StringToFusionKind( absl::string_view kind_name) { if (kind_name == "kLoop") { return HloInstruction::FusionKind::kLoop; @@ -4422,6 +4650,13 @@ std::string PrecisionToString(const PrecisionConfig::Precision& precision) { return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } +std::string AlgorithmToString(const PrecisionConfig::Algorithm& algorithm) { + constexpr absl::string_view kPrefix = "ALG_"; + const std::string& name = PrecisionConfig::Algorithm_Name(algorithm); + DCHECK(absl::StartsWith(name, kPrefix)); + return absl::AsciiStrToLower(name.substr(kPrefix.size())); +} + static std::string CustomCallScheduleToString( const CustomCallSchedule& schedule) { return absl::AsciiStrToLower(CustomCallSchedule_Name(schedule)); @@ -4509,7 +4744,8 @@ std::string ReplicaGroupsToString( return StrCat("{", StrJoin(replica_group_str, ","), "}"); } -StatusOr StringToRandomAlgorithm(const std::string& name) { +absl::StatusOr StringToRandomAlgorithm( + const std::string& name) { static absl::flat_hash_map* map = [] { static auto* map = new absl::flat_hash_map; for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) { @@ -4527,7 +4763,7 @@ StatusOr StringToRandomAlgorithm(const std::string& name) { return found->second; } -StatusOr StringToRandomDistribution( +absl::StatusOr StringToRandomDistribution( const std::string& name) { static absl::flat_hash_map* map = [] { static auto* map = new absl::flat_hash_map; @@ -4546,7 +4782,7 @@ StatusOr StringToRandomDistribution( return found->second; } -StatusOr StringToPrecision( +absl::StatusOr StringToPrecision( const std::string& name) { static absl::flat_hash_map* map = [] { @@ -4562,12 +4798,33 @@ StatusOr StringToPrecision( }(); auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { - return InvalidArgument("Unknown distribution"); + return InvalidArgument("Unknown precision"); } return found->second; } -StatusOr StringToCustomCallSchedule( +absl::StatusOr StringToAlgorithm( + const std::string& name) { + static absl::flat_hash_map* map = + [] { + static auto* map = + new absl::flat_hash_map; + for (int i = 0; i < PrecisionConfig::Algorithm_ARRAYSIZE; i++) { + if (PrecisionConfig::Algorithm_IsValid(i)) { + auto value = static_cast(i); + (*map)[AlgorithmToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(absl::AsciiStrToLower(name)); + if (found == map->end()) { + return InvalidArgument("Unknown algorithm"); + } + return found->second; +} + +absl::StatusOr StringToCustomCallSchedule( absl::string_view name) { static const absl::flat_hash_map* map = [] { static auto* map = new absl::flat_hash_map; @@ -4586,7 +4843,7 @@ StatusOr StringToCustomCallSchedule( return found->second; } -StatusOr StringToCustomCallApiVersion( +absl::StatusOr StringToCustomCallApiVersion( absl::string_view name) { static const absl::flat_hash_map* map = [] { @@ -4654,6 +4911,7 @@ Status HloInstruction::GetBackendConfigInternal( } const std::string& HloInstruction::BackendConfigRep::GetRawString() const { + absl::WriterMutexLock lock{&mutex_}; if (proto_ && raw_string_.empty()) { raw_string_ = BackendConfigToRawString(*proto_).value(); } @@ -4667,6 +4925,8 @@ HloInstruction::BackendConfigRep HloInstruction::BackendConfigRep::Clone() if (auto* proto = GetProtoPtr()) { cloned.SetProto(*proto); } else { + absl::MutexLock source_lock{&mutex_}; + absl::MutexLock target_lock{&cloned.mutex_}; cloned.raw_string_ = raw_string_; } return cloned; @@ -4674,6 +4934,7 @@ HloInstruction::BackendConfigRep HloInstruction::BackendConfigRep::Clone() HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=( std::string raw_string) { + absl::MutexLock lock{&mutex_}; raw_string_ = std::move(raw_string); proto_.reset(); return *this; @@ -4682,6 +4943,7 @@ HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=( HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=( const tsl::protobuf::Message& proto) { SetProto(proto); + absl::MutexLock lock{&mutex_}; raw_string_.clear(); return *this; } @@ -4704,8 +4966,8 @@ bool HloInstruction::BackendConfigRep::operator==( return GetRawString() == other.GetRawString(); } -/* static */ StatusOr HloInstruction::BackendConfigToRawString( - const tsl::protobuf::Message& proto) { +/* static */ absl::StatusOr +HloInstruction::BackendConfigToRawString(const tsl::protobuf::Message& proto) { std::string ret; // Pass ignore_accuracy_loss = true because estimated_cycles field can be // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles = @@ -4758,26 +5020,23 @@ void HloInstruction::SortInstructionUsersAndControlLists( const MappedPtrContainerSorter::MapPtrFn& map_fn, const HloInstruction& sorted_instruction) { using Sorter = MappedPtrContainerSorter; - auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), - sorted_instruction.users_, users_); - if (!status.ok()) { - LOG(ERROR) << "Failed to sort instruction users for " << name() << "; " - << status; - } - user_map_.clear(); - for (uint64_t i = 0; i < users_.size(); ++i) { - user_map_[users_[i]] = i; + users_.SortInstructionUsers(map_fn, sorted_instruction.users_); + + absl::Status status; + if (has_rare()) { + status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), + sorted_instruction.control_predecessors(), + mutable_rare()->control_predecessors); } - status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), - sorted_instruction.control_predecessors_, - control_predecessors_); if (!status.ok()) { LOG(ERROR) << "Failed to sort instruction control predecessors for " << name() << "; " << status; } - status = - Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), - sorted_instruction.control_successors_, control_successors_); + if (has_rare()) { + status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), + sorted_instruction.control_successors(), + mutable_rare()->control_successors); + } if (!status.ok()) { LOG(ERROR) << "Failed to sort instruction control successors for " << name() << "; " << status; @@ -4916,14 +5175,12 @@ HloInstruction* HloInstruction::fused_expression_root() const { return Cast(this)->fused_expression_root(); } -tsl::gtl::iterator_range>::const_iterator>> +tsl::gtl::iterator_range HloInstruction::fused_instructions() const { return Cast(this)->fused_instructions(); } -tsl::gtl::iterator_range< - UnwrappingIterator>::iterator>> +tsl::gtl::iterator_range HloInstruction::fused_instructions() { return Cast(this)->fused_instructions(); } @@ -5171,9 +5428,16 @@ bool HloInstruction::IsAsynchronous() const { return HloOpcodeIsAsync(opcode()); } +HloInstruction* HloInstruction::async_chain_start() const { + return Cast(this)->async_chain_start(); +} + +HloInstruction* HloInstruction::async_chain_done() const { + return Cast(this)->async_chain_done(); +} + HloComputation* HloInstruction::async_wrapped_computation() const { - CHECK(IsAsynchronous()); - return called_computations()[0]; + return Cast(this)->async_wrapped_computation(); } HloInstruction* HloInstruction::async_wrapped_instruction() const { @@ -5184,14 +5448,6 @@ HloOpcode HloInstruction::async_wrapped_opcode() const { return Cast(this)->async_wrapped_opcode(); } -std::optional HloInstruction::async_group_id() const { - return Cast(this)->async_group_id(); -} - -void HloInstruction::set_async_group_id(std::optional async_group_id) { - Cast(this)->set_async_group_id(async_group_id); -} - absl::string_view HloInstruction::async_execution_thread() const { return Cast(this)->async_execution_thread(); } diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 3de91ccd73a3a..a28fb7cf52914 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,12 +37,14 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -50,6 +52,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/iterator_util.h" #include "xla/layout.h" #include "xla/literal.h" @@ -94,6 +97,7 @@ class HloPrintOptions { print_large_constants_(false), print_only_essential_constants_(false), print_metadata_(true), + print_metadata_only_op_name_(false), print_backend_config_(true), print_infeed_outfeed_config_(true), compact_operands_(false), @@ -202,6 +206,13 @@ class HloPrintOptions { return *this; } + // If true and print_metadata is true, metadata op name will be printed. Other + // metadata values will be omitted. + HloPrintOptions& set_print_metadata_only_op_name(bool value) { + print_metadata_only_op_name_ = value; + return *this; + } + // If true, backend_config will be printed. HloPrintOptions& set_print_backend_config(bool value) { print_backend_config_ = value; @@ -368,6 +379,9 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_metadata_only_op_name() const { + return print_metadata_only_op_name_; + } bool print_backend_config() const { return print_backend_config_; } bool print_infeed_outfeed_config() const { return print_infeed_outfeed_config_; @@ -407,6 +421,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_only_essential_constants_; bool print_metadata_; + bool print_metadata_only_op_name_; bool print_backend_config_; bool print_infeed_outfeed_config_; bool compact_operands_; @@ -445,6 +460,132 @@ class CanonicalNameMap { absl::flat_hash_map canonical_name_map_; }; +class HloInstruction; + +// A small holder that is used to keep some immutable info alongside an +// instruction pointer in an HloComputation's list of instructions +class HloInstructionInfo { + public: + HloInstruction* get() const { return inst_; } + HloInstruction& operator*() { return *inst_; } + HloInstruction* operator->() { return inst_; } + const HloInstruction& operator*() const { return *inst_; } + const HloInstruction* operator->() const { return inst_; } + + HloOpcode opcode() const { return opcode_; } + HloInstruction* inst() const { return inst_; } + + private: // TODO: Make private and provide accessors? + friend class HloComputation; + HloOpcode opcode_; + HloInstruction* inst_; +}; + +namespace mapped_ptr_container_sorter_internal { + +template +struct PtrGetter { + static const T* Get(const HloInstructionInfo& p) { return p.get(); } +}; + +} // namespace mapped_ptr_container_sorter_internal + +using HloInstructionList = std::vector; + +template +class HloInstructionIteratorBase { + public: + using iterator_category = std::input_iterator_tag; + using value_type = HloInstructionInfo; + using difference_type = ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + HloInstructionIteratorBase(UnderlyingList* list, int begin_index, + int end_index) + : list_(list), current_(begin_index), end_index_(end_index) { + if (current_ < end_index_ && (*list_)[current_].inst() == nullptr) { + ++*this; + } + } + + HloInstruction* get() const { return (*list_)[current_].inst(); } + + auto operator*() -> HloInstructionInfo { return (*list_)[current_]; } + HloInstructionIteratorBase& operator++() { + int next = current_; + do { + ++next; + } while (next < end_index_ && (*list_)[next].inst() == nullptr); + current_ = next; + return *this; + } + HloInstructionIteratorBase operator++(int) { + HloInstructionIteratorBase temp(list_, current_, end_index_); + operator++(); + return temp; + } + + friend bool operator==(const HloInstructionIteratorBase& a, + const HloInstructionIteratorBase& b) { + return a.current_ == b.current_; + } + + friend bool operator!=(const HloInstructionIteratorBase& a, + const HloInstructionIteratorBase& b) { + return !(a == b); + } + + private: + UnderlyingList* list_; + int current_; + int end_index_; +}; +using HloInstructionIterator = HloInstructionIteratorBase; +using HloInstructionConstIterator = + HloInstructionIteratorBase; + +template +class HloInstructionUnwrappingIteratorBase { + public: + using iterator_category = std::input_iterator_tag; + using value_type = HloInstruction*; + using difference_type = ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + explicit HloInstructionUnwrappingIteratorBase(WrappedIter iter) + : iter_(std::move(iter)) {} + + auto operator*() -> value_type { return iter_.get(); } + HloInstructionUnwrappingIteratorBase& operator++() { + ++iter_; + return *this; + } + HloInstructionUnwrappingIteratorBase operator++(int) { + HloInstructionUnwrappingIteratorBase temp(iter_); + operator++(); + return temp; + } + + friend bool operator==(const HloInstructionUnwrappingIteratorBase& a, + const HloInstructionUnwrappingIteratorBase& b) { + return a.iter_ == b.iter_; + } + + friend bool operator!=(const HloInstructionUnwrappingIteratorBase& a, + const HloInstructionUnwrappingIteratorBase& b) { + return !(a == b); + } + + private: + WrappedIter iter_; +}; +using HloInstructionUnwrappingIterator = + HloInstructionUnwrappingIteratorBase; +using HloInstructionUnwrappingConstIterator = + HloInstructionUnwrappingIteratorBase; + // HLO instructions are the atomic unit of the high-level compiler's IR. // // HloInstructions live inside of an HloComputation, which is analogous to a @@ -554,7 +695,7 @@ class HloInstruction { // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - static StatusOr> CreateFromProto( + static absl::StatusOr> CreateFromProto( const HloInstructionProto& proto, const absl::flat_hash_map& instruction_map, const absl::flat_hash_map& computation_map = {}, @@ -665,18 +806,11 @@ class HloInstruction { static std::unique_ptr CreateAsyncStart( const Shape& shape, absl::Span operands, HloComputation* async_computation, - std::optional async_group_id = std::nullopt, absl::string_view async_execution_thread = kMainExecutionThread); static std::unique_ptr CreateAsyncUpdate( - const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, - std::optional async_group_id = std::nullopt, - absl::string_view async_execution_thread = kMainExecutionThread); + const Shape& shape, HloInstruction* operand); static std::unique_ptr CreateAsyncDone( - const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, - std::optional async_group_id = std::nullopt, - absl::string_view async_execution_thread = kMainExecutionThread); + const Shape& shape, HloInstruction* operand); // Creates a copy-start op, indicating whether this is a cross-program // prefetch or not. @@ -698,11 +832,16 @@ class HloInstruction { const Shape& shape, HloInstruction* a, const CholeskyOptions& options); // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch - // dimensions specified in 'dimension_numbers'. + // dimensions specified in 'dimension_numbers'. If 'sparsity' is set, then + // 'sparse_meta' must also be present (and have the same size). + // Note: 'sparsity' argument is eventually moved in the HloDotInstruction + // constructor, so no extra copies are created. static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config); + const PrecisionConfig& precision_config, + std::vector sparsity = {}, + absl::Span sparse_meta = {}); // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to @@ -815,6 +954,21 @@ class HloInstruction { const std::optional& channel_id, const std::optional& split_dimension = std::nullopt); + // Creates a communication instruction that broadcasts data cross replicas. + // Data is sent from to the first replica id in each group to the other ids in + // the same group. If a replica id is not a in any replica group, the output + // on that replica is a tensor consists of 0(s) in `shape`. + static std::unique_ptr CreateCollectiveBroadcast( + const Shape& shape, absl::Span operand, + absl::Span replica_groups, bool constrain_layout, + const std::optional& channel_id); + + static std::unique_ptr CreateCollectiveBroadcast( + const Shape& shape, HloInstruction* input, HloInstruction* output, + HloInstruction* input_start_indices, HloInstruction* output_start_indices, + absl::Span replica_groups, bool constrain_layout, + const std::optional& channel_id); + // Creates a communication instruction that permutes data cross replicas. // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a @@ -902,6 +1056,10 @@ class HloInstruction { // The operand must be kSend. static std::unique_ptr CreateSendDone( HloInstruction* operand, bool is_host_transfer = false); + // Similar to the above, but the operand doesn't have to be a kSend. + static std::unique_ptr CreateSendDone( + HloInstruction* operand, int64_t channel_id, + bool is_host_transfer = false); // Creates an asynchronous receive instruction with the given channel id, // which allocates resources to receive data of the given shape from a unique @@ -916,6 +1074,10 @@ class HloInstruction { // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr CreateRecvDone( HloInstruction* operand, bool is_host_transfer = false); + // Similar to the above, but the operand doesn't have to be a kRecv. + static std::unique_ptr CreateRecvDone( + HloInstruction* operand, int64_t channel_id, + bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. @@ -1225,7 +1387,9 @@ class HloInstruction { HloOpcode* mutable_opcode() { return &opcode_; } // Returns whether this instruction is the root of its parent computation. - bool IsRoot() const; + bool IsRoot() const { return is_root_; } + void MarkAsRoot() { is_root_ = true; } + void MarkAsNonRoot() { is_root_ = false; } // Does this instruction have no users. bool IsDead() const { return users_.empty() && !IsRoot(); } @@ -1237,7 +1401,7 @@ class HloInstruction { // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. - bool HasSideEffect() const; + virtual bool HasSideEffect() const; // Returns the result shape of this instruction. const Shape& shape() const; @@ -1271,16 +1435,16 @@ class HloInstruction { int64_t user_count() const { return users_.size(); } // Returns the users of this instruction. - const std::vector& users() const { return users_; } + const PtrVec& users() const { return users_.vec(); } // Returns the index of the user in the users() vector. // // Precondition: `user` is a user of the instruction. - int64_t UserId(HloInstruction* user); + int64_t UserId(HloInstruction* user) { return users_.UserId(user); } // Returns true if this instruction is a user of 'instruction'. bool IsUserOf(const HloInstruction* instruction) const { - return instruction->user_map_.contains(this); + return instruction->users_.Contains(this); } // Adds a control dependency from this instruction to the given @@ -1316,16 +1480,24 @@ class HloInstruction { // Depending on the use cases we see in practice, in the future we may // consider folding the logic here into Clone, CloneWithNewOperands and // ReplaceAllUsesWith by treating control dependencies like data dependencies. - Status CopyAllControlDepsFrom(const HloInstruction* inst); + Status CopyAllControlDepsFrom(const HloInstruction* inst) { + return inst->CopyAllControlDepsTo(this, this); + } + + // Copies all control dependencies of this instruction to start/end. Copies + // all control predecessors of this instruction to control predecessors of + // `start` and copies all control successors of this instruction to control + // successors of `end`. + Status CopyAllControlDepsTo(HloInstruction* start, HloInstruction* end) const; // Returns the set of control predecessors (successors) of this // instruction. Control predecessors (successors) must execute before (after) // the current instruction. - const std::vector& control_predecessors() const { - return control_predecessors_; + const PtrVec& control_predecessors() const { + return rare()->control_predecessors; } - const std::vector& control_successors() const { - return control_successors_; + const PtrVec& control_successors() const { + return rare()->control_successors; } // Returns true if 'other' performs the same computation as this instruction. @@ -1436,6 +1608,9 @@ class HloInstruction { Status ReplaceOperandWithDifferentShape(int64_t operand_num, HloInstruction* new_operand); + // Decomposes fusion back to individual parts. + Status Defuse(); + // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use // of this instruction to avoid introducing cycles into the graph. @@ -1525,6 +1700,13 @@ class HloInstruction { // a bitcast. bool IsEffectiveBitcast() const; + // Returns true if this instruction is asynchronous with the + // async_execution_thread set to `execution_thread`. + bool IsAsyncInstructionWithExecutionThread( + absl::string_view execution_thread) const { + return IsAsynchronous() && async_execution_thread() == execution_thread; + }; + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // @@ -1554,7 +1736,7 @@ class HloInstruction { // Gets the branch HloComputations for Conditional. // // Precondition: The instruction is a Conditional instruction. - const std::vector& branch_computations() const; + const PtrVec& branch_computations() const; int branch_count() const; HloComputation* branch_computation(int b) const; // Sets a branch HloComputation for Conditional. @@ -1726,17 +1908,25 @@ class HloInstruction { const std::string& suffix, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). - const std::vector& called_computations() const { - return called_computations_; + const PtrVec& called_computations() const { + return rare()->called_computations; } + bool has_called_computations() const { + return has_rare() && !called_computations().empty(); + } + + // Returns true iff an instruction of type "opcode" might have non-empty + // called_computations. + static bool MightHaveCalledComputations(HloOpcode opcode); // Replaces all called computations based on a map function. This is needed // when we clone hlo_computations and want to let the instructions to point // to the newly cloned nodes. void ReplaceCalledComputations( absl::FunctionRef map_function) { - for (int64_t i = 0; i < called_computations_.size(); ++i) { - called_computations_[i] = map_function(called_computations_[i]); + for (int64_t i = 0; i < called_computations().size(); ++i) { + mutable_rare()->called_computations[i] = + map_function(rare()->called_computations[i]); } } @@ -1748,7 +1938,11 @@ class HloInstruction { // clearing out the computations, we reflect the fact that all side-effecting // properties have been reflected in the caller, and make the call HLO // removable. - virtual void ClearCalledComputations() { called_computations_.clear(); } + virtual void ClearCalledComputations() { + if (has_rare()) { + mutable_rare()->called_computations.clear(); + } + } // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, @@ -1850,7 +2044,7 @@ class HloInstruction { return OkStatus(); } - bool preserve_layout() const { return metadata_.preserve_layout(); } + bool preserve_layout() const { return metadata_->preserve_layout(); } bool has_backend_config() const { return !backend_config_.empty(); } @@ -1861,44 +2055,49 @@ class HloInstruction { } void set_frontend_attributes(FrontendAttributes frontend_attributes) { - frontend_attributes_ = std::move(frontend_attributes); + if (!has_rare() && frontend_attributes.map().empty()) { + return; + } + mutable_rare()->frontend_attributes = std::move(frontend_attributes); } void add_frontend_attributes(FrontendAttributes frontend_attributes) { - frontend_attributes_.mutable_map()->insert( - frontend_attributes.map().begin(), frontend_attributes.map().end()); + if (!frontend_attributes.map().empty()) { + mutable_rare()->frontend_attributes.mutable_map()->insert( + frontend_attributes.map().begin(), frontend_attributes.map().end()); + } } const FrontendAttributes& frontend_attributes() const { - return frontend_attributes_; + return rare()->frontend_attributes; } void add_single_statistic(Statistic statistic) { - *statistics_viz_.add_statistics() = std::move(statistic); + *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic); } void set_stat_index_to_visualize(int64_t index) { - statistics_viz_.set_stat_index_to_visualize(index); + mutable_rare()->statistics_viz.set_stat_index_to_visualize(index); } // Whether this specific instruction has statistics - bool has_statistics() const { return !statistics_viz_.statistics().empty(); } + bool has_statistics() const { return !statistics_viz().statistics().empty(); } // Whether any instruction within the same HLO module as this has statistics bool module_has_statistics() const { - return statistics_viz_.stat_index_to_visualize() == -1; + return statistics_viz().stat_index_to_visualize() == -1; } const Statistic& statistic_to_visualize() const { - return statistics_viz_.statistics().at( - statistics_viz_.stat_index_to_visualize()); + return statistics_viz().statistics().at( + statistics_viz().stat_index_to_visualize()); } void set_statistics_viz(StatisticsViz statistics_viz) { - statistics_viz_ = std::move(statistics_viz); + mutable_rare()->statistics_viz = std::move(statistics_viz); } - const StatisticsViz& statistics_viz() const { return statistics_viz_; } + const StatisticsViz& statistics_viz() const { return rare()->statistics_viz; } // Getter/setter for raw JSON-encoded backend config. Prefer the // functions above that deal in proto Messages where possible. @@ -1912,26 +2111,6 @@ class HloInstruction { bool is_default_config() const { return is_default_config_; } void set_default_config() { is_default_config_ = true; } - void set_operation_queue_id(int64_t operation_queue_id) { - operation_queue_id_ = operation_queue_id; - } - - const std::optional operation_queue_id() const { - return operation_queue_id_; - } - - void set_wait_on_operation_queues(std::vector& operation_queue_ids) { - wait_on_operation_queues_ = operation_queue_ids; - } - - const std::vector wait_on_operation_queues() const { - return wait_on_operation_queues_; - } - - void add_wait_on_operation_queues(int64_t operation_queue_id) { - wait_on_operation_queues_.push_back(operation_queue_id); - } - // Returns a string representation of a proto in the format used by // raw_backend_config_string. // @@ -1941,7 +2120,7 @@ class HloInstruction { // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); // return instr.raw_backend_config_string(); // - static StatusOr BackendConfigToRawString( + static absl::StatusOr BackendConfigToRawString( const tsl::protobuf::Message& proto); // Returns the information used to tell the implementation information about @@ -1957,36 +2136,26 @@ class HloInstruction { // Sets the debug metadata for this instruction, excluding creation_pass_id, // which should never be copied anywhere. - void set_metadata(const OpMetadata& metadata) { - int64_t creation_pass_id = metadata_.creation_pass_id(); - metadata_ = metadata; - metadata_.set_creation_pass_id(creation_pass_id); - } + void set_metadata(const OpMetadata& metadata) { *metadata_ = metadata; } void set_size_of_generated_code_in_bytes(int64_t code_size_in_bytes) { - metadata_.set_size_of_generated_code_in_bytes(code_size_in_bytes); + metadata_->set_size_of_generated_code_in_bytes(code_size_in_bytes); } void set_size_of_memory_working_set_in_bytes( int64_t working_set_size_in_bytes) { - metadata_.set_size_of_memory_working_set_in_bytes( + metadata_->set_size_of_memory_working_set_in_bytes( working_set_size_in_bytes); } - void set_creation_pass_id(int64_t pass_id) { - metadata_.set_creation_pass_id(pass_id); - } void set_metadata_op_name(const std::string& name) { - metadata_.set_op_name(name); - } - void set_logical_creation_pass_id(int64_t pass_id) { - metadata_.set_logical_creation_pass_id(pass_id); + metadata_->set_op_name(name); } void set_metadata_deduplicated_name(std::string deduplicated_name) { - metadata_.set_deduplicated_name(std::move(deduplicated_name)); + metadata_->set_deduplicated_name(std::move(deduplicated_name)); } void set_metadata_preserve_layout(bool preserve_layout) { - metadata_.set_preserve_layout(preserve_layout); + metadata_->set_preserve_layout(preserve_layout); } - const OpMetadata& metadata() const { return metadata_; } + const OpMetadata& metadata() const { return *metadata_; } // Set/get the computation containing this instruction. set_parent should only // be called by HloComputation methods which add/remove instructions to @@ -2103,12 +2272,10 @@ class HloInstruction { HloInstruction* fused_expression_root() const; // Delegates to HloFusionInstruction::fused_instructions. - tsl::gtl::iterator_range>::const_iterator>> + tsl::gtl::iterator_range fused_instructions() const; - tsl::gtl::iterator_range< - UnwrappingIterator>::iterator>> + tsl::gtl::iterator_range fused_instructions(); // Delegates to HloFusionInstruction::fused_instruction_count. @@ -2272,6 +2439,12 @@ class HloInstruction { // async-done. bool IsAsynchronous() const; + // Delagates to HloAsyncInstruction::async_chain_start(). + HloInstruction* async_chain_start() const; + + // Delagates to HloAsyncInstruction::async_done(). + HloInstruction* async_chain_done() const; + // Returns the computation that will executed asynchronously. HloComputation* async_wrapped_computation() const; @@ -2281,12 +2454,6 @@ class HloInstruction { // Delagates to HloAsyncInstruction::async_wrapped_opcode(). HloOpcode async_wrapped_opcode() const; - // Delegates to HloAsyncInstruction::async_group_id(). - std::optional async_group_id() const; - - // Delegates to HloAsyncInstruction::set_async_group_id(). - void set_async_group_id(std::optional async_group_id); - // Delegates to HloAsyncInstruction::async_execution_thread(). absl::string_view async_execution_thread() const; @@ -2341,16 +2508,14 @@ class HloInstruction { void RemoveOperandsAtAscendingIndices( absl::Span ascending_indices); - void AppendComputation(HloComputation* computation) { - called_computations_.push_back(computation); - } + void AppendComputation(HloComputation* computation); void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } void set_called_computation(int index, HloComputation* computation) { - called_computations_[index] = computation; + mutable_rare()->called_computations[index] = computation; } - // Indices of computations in called_computations_ for instructions which call + // Indices of computations in called_computations for instructions which call // multiple computations. enum { // kWhile computations. @@ -2385,21 +2550,41 @@ class HloInstruction { return !(*this == other); } - bool empty() const { return proto_ == nullptr && raw_string_.empty(); } + bool empty() const { + absl::MutexLock lock{&mutex_}; + return proto_ == nullptr && raw_string_.empty(); + } void clear() { proto_.reset(); + absl::MutexLock lock{&mutex_}; raw_string_.clear(); } + BackendConfigRep() = default; + BackendConfigRep(BackendConfigRep&& other) + : proto_(std::move(other.proto_)), raw_string_([&] { + absl::MutexLock lock{&other.mutex_}; + return std::move(other.raw_string_); + }()) {} + BackendConfigRep& operator=(std::string raw_string); BackendConfigRep& operator=(const tsl::protobuf::Message& proto); + BackendConfigRep& operator=(BackendConfigRep&& other) { + proto_ = std::move(other.proto_); + absl::MutexLock destination_lock{&mutex_}; + absl::MutexLock source_lock{&other.mutex_}; + raw_string_ = std::move(other.raw_string_); + return *this; + } + void SetProto(const tsl::protobuf::Message& proto); private: std::unique_ptr proto_; // If proto_ is not null, raw_string_ is a lazy cache of its string format. - mutable std::string raw_string_; + mutable absl::Mutex mutex_; + mutable std::string raw_string_ ABSL_GUARDED_BY(mutex_); }; bool IdenticalInternal( @@ -2449,10 +2634,10 @@ class HloInstruction { absl::Span operands); // Adds a user for this instruction. - void AddUser(HloInstruction* user); + void AddUser(HloInstruction* user) { users_.AddUser(user); } // Removes a user for this instruction. - void RemoveUser(HloInstruction* user); + void RemoveUser(HloInstruction* user) { users_.RemoveUser(user); } // Helper for implementing backend_config(). Parses backend_config_ into the // given proto. @@ -2465,90 +2650,152 @@ class HloInstruction { // HloInstruction. bool IsMarkedAsDead() const { return marked_as_dead_; } + // Rare is allocated lazily, only when any of its constituent fields are + // non-empty. This reduces the memory footprint of HloInstruction objects. + struct Rare { + // Computations called by this instruction. + PtrVec called_computations; + + // The set of control predecessors of this instruction. + // Note that the order of the instructions in the vector influences the + // order computed in HloComputation::ComputeInstructionPostOrder, which may + // influence the result of the compilation by changing the scheduling. We + // are not sure if it matters. + PtrVec control_predecessors; + + // The set of control successors of this instruction. + PtrVec control_successors; + + // Attributes passed from the frontend to give hints to the backend about + // how to compile this HLO. + // HLO -> HLO transforms are expected to preserve these attributes on a + // "best effort" basis only. + // For example: + // x = const(10, frontend_attributes={x} + // y = const(10, frontend_attributes={y} + // z = add(x,y), frontend_attributes={y} + // Could be simplified to: + // z' = const(20), frontend_attributes={?} + FrontendAttributes frontend_attributes; + + // Used to render an HLO graph when tracking the propagation desired values + // through it. + StatisticsViz statistics_viz; + }; + + static const Rare* const kEmptyRare; + + bool has_rare() const { return rare_ != nullptr; } + + // Return the allocated rare state, or the pointer to the static empty rare + // state + const Rare* rare() const { + Rare* r = rare_.get(); + return (r == nullptr) ? kEmptyRare : r; + } + + // Lazily allocate the Rare struct + Rare* mutable_rare() { + if (rare_ == nullptr) { + rare_ = std::make_unique(); + } + return rare_.get(); + } + + // Users holds the list of users of an HloInstruction, plus it provides a fast + // way for checking for presence of a potential user. + class Users { + public: + Users() = default; + ~Users(); + + // No copying allowed + Users(const Users&) = delete; + Users& operator=(const Users&) = delete; + + bool empty() const { return users_.empty(); } + int64_t size() const { return users_.size(); } + const PtrVec& vec() const { return users_; } + + void Clear(); + bool Contains(const HloInstruction* instruction) const; + void AddUser(HloInstruction* user); + void MaybeRemoveUser(HloInstruction* user); // Remove user if present + void RemoveUser(HloInstruction* user); // REQUIRES: Contains(user) + int64_t UserId(HloInstruction* user); + void SortInstructionUsers( + const MappedPtrContainerSorter::MapPtrFn& map_fn, + const Users& sorted_instruction_users); + bool CheckInvariants(); + + private: + void RebuildMap(); + + PtrVec users_; + + // If users_ is big, we also maintain a copy of the elements of users_ + // in a hash map to enable fast membership tests. The value in the map + // contains the index of the instruction in the vector what enables fast + // removal. + static constexpr size_t kMapThreshold = 16; + std::unique_ptr> + user_map_; + }; + int unique_id_; // Unique to this HloInstruction within a HloModule + uint32_t index_in_parent_; // Index that identifies inst in HloComputation // Opcode for this instruction. HloOpcode opcode_; + // This field is assigned to true when backend_config_ is assigned to + // a default configuration. + bool is_default_config_ : 1; + + // True if this instruction has already been detached from its user and + // operands. + bool cleaned_up_ : 1; + + // Intrusive flag used by HloComputation, whether this instruction has + // been marked as dead. + bool marked_as_dead_ : 1; + + // True if this instruction is the root of a computation. + bool is_root_ : 1; + // Instruction operands. InstructionVector operands_; - // The set of control predecessors of this instruction. - // Note that the order of the instructions in the vector influences the order - // computed in HloComputation::ComputeInstructionPostOrder, which may - // influence the result of the compilation by changing the scheduling. We are - // not sure if it matters. - std::vector control_predecessors_; + // If needed, points off to allocated struct holding out-of-line info + // for things that are rarely filled + std::unique_ptr rare_; // The users of this instruction. Users are HLOs where this instruction is an - // operand. The vector users_ and the map user_map_ contain identical members. - // The map enables fast membership testing and the vector enables fast, stable - // iteration. The value in the map contains the index of the instruction in - // the vector what enables fast removal. - std::vector users_; - absl::flat_hash_map user_map_; - - // The set of control successors of this instruction. - std::vector control_successors_; + // operand. + Users users_; // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; - // Result shape of this instruction. - Shape shape_; - // The sharding, if one exists. // Uses std::shared_ptr to allow reuse of the same sharding object between // HloInstructions and other components as HloSharding can be very large for // many element tuples. std::shared_ptr sharding_; - // Computations called by this instruction. - std::vector called_computations_; + // Result shape of this instruction. + Shape shape_; // The backend-specific configuration for how a backend should compile this // HLO. See the documentation on backend_config(). mutable BackendConfigRep backend_config_; - // Attributes passed from the frontend to give hints to the backend about - // how to compile this HLO. - // HLO -> HLO transforms are expected to preserve these attributes on a - // "best effort" basis only. - // For example: - // x = const(10, frontend_attributes={x} - // y = const(10, frontend_attributes={y} - // z = add(x,y), frontend_attributes={y} - // Could be simplified to: - // z' = const(20), frontend_attributes={?} - FrontendAttributes frontend_attributes_; - - // Used to render an HLO graph when tracking the propagation desired values - // through it. - StatisticsViz statistics_viz_; - // String identifier for instruction. std::string name_; - // Metadata for debugging. - OpMetadata metadata_; - - // This field is assigned to true when backend_config_ is assigned to - // a default configuration. - bool is_default_config_ = false; - - // True if this instruction has already been detached from its user and - // operands. - bool cleaned_up_ = false; - - // Intrusive flag used by HloComputation, whether this instruction has - // been marked as dead. - bool marked_as_dead_; - - // ID of the operation queue to run this instruction. - std::optional operation_queue_id_; - - // IDs of operation queues to await before running this instruction. - std::vector wait_on_operation_queues_; + // Metadata for debugging. Allocate it on heap, so that it does not increase + // the memory footprint of HloInstruction. + std::unique_ptr metadata_ = std::make_unique(); }; // Explicit instantiations in hlo_instruction.cc. @@ -2559,7 +2806,7 @@ extern template Status HloInstruction::Visit(DfsHloVisitor* visitor); extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); absl::string_view ToString(HloInstruction::FusionKind kind); -StatusOr StringToFusionKind( +absl::StatusOr StringToFusionKind( absl::string_view kind_name); // Custom (de)stringification functions for protos that live inside @@ -2569,18 +2816,24 @@ std::string StatisticsVizToString(const StatisticsViz& statistics_viz); std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm); std::string RandomDistributionToString(const RandomDistribution& distribution); std::string PrecisionToString(const PrecisionConfig::Precision& precision); +std::string AlgorithmToString(const PrecisionConfig::Algorithm& algorithm); std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums); std::string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); std::string ReplicaGroupsToString( absl::Span replica_groups); -StatusOr StringToRandomAlgorithm(const std::string& name); -StatusOr StringToRandomDistribution( +absl::StatusOr StringToRandomAlgorithm( + const std::string& name); +absl::StatusOr StringToRandomDistribution( + const std::string& name); +absl::StatusOr StringToPrecision( + const std::string& name); +absl::StatusOr StringToAlgorithm( const std::string& name); -StatusOr StringToPrecision(const std::string& name); -StatusOr StringToCustomCallSchedule(absl::string_view name); -StatusOr StringToCustomCallApiVersion( +absl::StatusOr StringToCustomCallSchedule( + absl::string_view name); +absl::StatusOr StringToCustomCallApiVersion( absl::string_view name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -2608,13 +2861,41 @@ using ConstHloInstructionSet = template bool HloPredicateIsOp(const HloInstruction* instruction) { - if (instruction->opcode() == op) { - return true; - } - if constexpr (sizeof...(rest) == 0) { - return false; - } else { - return HloPredicateIsOp(instruction); + return (instruction->opcode() == op) || + ((instruction->opcode() == rest) || ...); +} + +/* static */ inline bool HloInstruction::MightHaveCalledComputations( + HloOpcode opcode) { + switch (opcode) { + // Control flow opcodes + case HloOpcode::kWhile: + case HloOpcode::kConditional: + + // Fusion contains a sub-computation + case HloOpcode::kFusion: + + // Async + case HloOpcode::kAsyncStart: + case HloOpcode::kAsyncUpdate: + case HloOpcode::kAsyncDone: + + // Opcodes for which has_to_apply can return true + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceScatter: + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: + case HloOpcode::kTopK: + case HloOpcode::kCustomCall: + return true; + default: + return false; } } diff --git a/xla/hlo/ir/hlo_instructions.cc b/xla/hlo/ir/hlo_instructions.cc index 566b33dccf109..d2f05e6eb1102 100644 --- a/xla/hlo/ir/hlo_instructions.cc +++ b/xla/hlo/ir/hlo_instructions.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,24 +85,54 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, void PrintPrecisionConfig(HloInstruction::AttributePrinter& printer, const PrecisionConfig& precision_config) { - if (absl::c_all_of( + if (absl::c_any_of( precision_config.operand_precision(), [](int32_t precision) { - return static_cast(precision) == + return static_cast(precision) != PrecisionConfig::DEFAULT; })) { - return; + printer.Next([&precision_config](Printer* printer) { + printer->Append("operand_precision={"); + AppendJoin(printer, precision_config.operand_precision(), ",", + [](Printer* printer, int32_t precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) + << precision; + printer->Append(PrecisionToString( + static_cast(precision))); + }); + printer->Append("}"); + }); } - printer.Next([&precision_config](Printer* printer) { - printer->Append("operand_precision={"); - AppendJoin(printer, precision_config.operand_precision(), ",", - [](Printer* printer, int32_t precision) { - CHECK(PrecisionConfig::Precision_IsValid(precision)) - << precision; - printer->Append(PrecisionToString( - static_cast(precision))); - }); - printer->Append("}"); + if (precision_config.algorithm() != PrecisionConfig::ALG_UNSET) { + printer.Next([&precision_config](Printer* printer) { + printer->Append("algorithm="); + printer->Append(AlgorithmToString(precision_config.algorithm())); + }); + } +} + +void PrintSparsityDescriptor(HloInstruction::AttributePrinter& printer, + absl::Span sparsity) { + printer.Next([&sparsity](Printer* printer) { + printer->Append("sparsity="); + for (int i = 0; i < sparsity.size(); ++i) { + if (i != 0) { + printer->Append("_"); + } + const SparsityDescriptor& cur = sparsity[i]; + printer->Append(cur.index() == 0 ? "L." : "R."); + printer->Append(cur.dimension()); + printer->Append("@"); + switch (cur.type()) { + case SPARSITY_STRUCTURED_N_M: + printer->Append(cur.n()); + printer->Append(":"); + printer->Append(cur.m()); + break; + default: + LOG(FATAL) << "Unknown sparsity type: " << cur.type(); + } + } }); } @@ -272,81 +302,81 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( HloAsyncInstruction::HloAsyncInstruction( HloOpcode opcode, const Shape& shape, - absl::Span operands, - HloComputation* async_computation, std::optional async_group_id, - absl::string_view async_execution_thread) - : HloInstruction(opcode, shape), - async_group_id_(async_group_id), - async_execution_thread_(async_execution_thread) { + absl::Span operands, HloOpcode async_wrapped_opcode) + : HloInstruction(opcode, shape) { CHECK(opcode == HloOpcode::kAsyncStart || operands.size() == 1); for (auto operand : operands) { AppendOperand(operand); } - AppendComputation(async_computation); - CHECK(!async_computation->IsCustomCallComputation()); - CHECK(!async_computation->IsFusionComputation()); - async_computation->AddAsyncInstruction(*this); - set_async_execution_thread(async_execution_thread); // Drop 'async' from async-{start/update/done} to get the suffix. absl::string_view suffix = HloOpcodeString(opcode).substr(5); - absl::string_view wrapped_name = HloOpcodeString(async_wrapped_opcode()); + absl::string_view wrapped_name = HloOpcodeString(async_wrapped_opcode); SetAndSanitizeName(absl::StrCat(wrapped_name, suffix)); } -HloAsyncInstruction::HloAsyncInstruction( - HloOpcode opcode, const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, std::optional async_group_id, - absl::string_view async_execution_thread) +HloAsyncInstruction::HloAsyncInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand) : HloAsyncInstruction(opcode, shape, absl::MakeConstSpan(&operand, 1), - async_computation, async_group_id, - async_execution_thread) {} - -HloAsyncInstruction::~HloAsyncInstruction() { - ClearAsyncComputationInstruction(); - ClearCalledComputations(); + operand->async_wrapped_opcode()) { + CHECK(operand->opcode() == HloOpcode::kAsyncStart || + operand->opcode() == HloOpcode::kAsyncUpdate); + HloAsyncInstruction* prev = Cast(operand); + prev->async_chain_next_ = this; } -void HloAsyncInstruction::ClearAsyncComputationInstruction() { - // Each async instruction calls a single computation, but we use - // called_computations() instead of async_wrapped_instruction(), because the - // order in which things get destructed can vary; the async computation's - // back-pointer may already be null, which violates a check in - // async_wrapped_instruction. - for (HloComputation* computation : called_computations()) { - CHECK(computation != nullptr); - if (computation->IsAsyncComputation()) { - computation->RemoveAsyncInstruction(this); - } - } +HloComputation* HloAsyncInstruction::async_wrapped_computation() const { + return async_chain_start()->called_computations().front(); } HloInstruction* HloAsyncInstruction::async_wrapped_instruction() const { - CHECK(!called_computations().empty()); - return called_computations()[0]->root_instruction(); + return async_chain_start()->async_wrapped_computation()->root_instruction(); } HloOpcode HloAsyncInstruction::async_wrapped_opcode() const { - return async_wrapped_instruction()->opcode(); + return async_chain_start()->async_wrapped_instruction()->opcode(); } -void HloAsyncInstruction::PrintExtraAttributesImpl( - AttributePrinter& printer, const HloPrintOptions& options) const { - if (async_group_id_.has_value()) { - printer.Next([this](Printer* printer) { - AppendCat(printer, "async_group_id=", *async_group_id_); - }); +absl::string_view HloAsyncInstruction::async_execution_thread() const { + return async_chain_start()->async_execution_thread(); +} + +HloAsyncInstruction* HloAsyncInstruction::async_chain_start() const { + if (opcode() == HloOpcode::kAsyncStart) { + return const_cast(this); } - if (async_execution_thread_ != kMainExecutionThread) { - printer.Next([this](Printer* printer) { - AppendCat(printer, "async_execution_thread=\"", async_execution_thread_, - "\""); - }); + + HloInstruction* prev = operands()[0]; + while (prev->opcode() != HloOpcode::kAsyncStart) { + // If the prev op in the chain isn't async-start, it must be async-update. + CHECK(prev->opcode() == HloOpcode::kAsyncUpdate); + prev = prev->operands()[0]; } - if (options.syntax_sugar_async_ops() && - async_wrapped_computation()->CanExpandIntoSingleInstruction()) { - async_wrapped_instruction()->PrintExtraAttributes(printer, options); + return Cast(prev); +} + +HloAsyncInstruction* HloAsyncInstruction::async_chain_done() const { + if (opcode() == HloOpcode::kAsyncDone) { + return const_cast(this); } + + HloAsyncInstruction* next = async_chain_next_; + while (next->opcode() != HloOpcode::kAsyncDone) { + // If the next op in the chain isn't async-done, it must be async-update. + CHECK(next->opcode() == HloOpcode::kAsyncUpdate); + next = next->async_chain_next_; + } + return next; +} + +std::vector HloAsyncInstruction::GetAsyncChain() const { + std::vector chain; + HloAsyncInstruction* current = async_chain_start(); + do { + chain.push_back(current); + current = current->async_chain_next_; + } while (current != nullptr); + return chain; } bool HloAsyncInstruction::IdenticalSlowPath( @@ -361,46 +391,50 @@ bool HloAsyncInstruction::IdenticalSlowPath( std::unique_ptr HloAsyncInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - HloModule* module = context != nullptr ? context->module() : GetModule(); - HloComputation* new_wrapped_computation = nullptr; - if (context != nullptr) { - new_wrapped_computation = - context->FindComputation(async_wrapped_computation()); - } - if (new_wrapped_computation == nullptr) { - // kAsyncDone and kAsyncUpdate uses the sync wrapped computation of its - // corresponding kAsyncUpdate or kAsyncDone, avoid cloning the computation - // again. - if ((opcode() == HloOpcode::kAsyncDone || - opcode() == HloOpcode::kAsyncUpdate) && - operand(0)->async_wrapped_computation() == - async_wrapped_computation()) { - new_wrapped_computation = new_operands[0]->async_wrapped_computation(); - } else { - new_wrapped_computation = module->AddEmbeddedComputation( - async_wrapped_computation()->Clone("clone", context)); - } - } - return std::make_unique( - opcode(), shape, new_operands, new_wrapped_computation, async_group_id_, - async_execution_thread_); + return std::make_unique(opcode(), shape, + new_operands[0]); +} + +HloAsyncStartInstruction::HloAsyncStartInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + HloComputation* async_computation, absl::string_view async_execution_thread) + : HloAsyncInstruction(opcode, shape, operands, + async_computation->root_instruction()->opcode()) { + CHECK(!async_computation->IsCustomCallComputation()); + CHECK(!async_computation->IsFusionComputation()); + CHECK(!async_computation->IsAsyncComputation()); + AppendComputation(async_computation); + async_computation->AddAsyncStart(this); + HloAsyncStartInstruction::set_async_execution_thread(async_execution_thread); +} + +HloAsyncStartInstruction::~HloAsyncStartInstruction() { + ClearAsyncComputationInstruction(); + ClearCalledComputations(); } -void HloAsyncInstruction::set_async_group_id( - std::optional async_group_id) { - async_group_id_ = async_group_id; +void HloAsyncStartInstruction::ClearAsyncComputationInstruction() { + // Each async instruction calls a single computation, but we use + // called_computations() instead of async_wrapped_instruction(), because the + // order in which things get destructed can vary; the async computation's + // back-pointer may already be null, which violates a check in + // async_wrapped_instruction. + if (!called_computations().empty() && + async_wrapped_computation()->AsyncStart() == this) { + async_wrapped_computation()->RemoveAsyncStart(); + } } -void HloAsyncInstruction::set_async_execution_thread( +void HloAsyncStartInstruction::set_async_execution_thread( absl::string_view async_execution_thread) { async_execution_thread_ = std::string(async_execution_thread); SetThreadName(async_wrapped_computation(), async_execution_thread, /*skip_async_execution_thread_overwrite=*/false); } -HloInstructionProto HloAsyncInstruction::ToProto() const { +HloInstructionProto HloAsyncStartInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_async_group_id(async_group_id_.has_value() ? *async_group_id_ : -1); proto.set_async_execution_thread(async_execution_thread_ == HloInstruction::kMainExecutionThread ? "" @@ -408,6 +442,40 @@ HloInstructionProto HloAsyncInstruction::ToProto() const { return proto; } +void HloAsyncStartInstruction::PrintExtraAttributesImpl( + AttributePrinter& printer, const HloPrintOptions& options) const { + if (async_execution_thread_ != kMainExecutionThread) { + printer.Next([this](Printer* printer) { + AppendCat(printer, "async_execution_thread=\"", async_execution_thread_, + "\""); + }); + } + if (options.syntax_sugar_async_ops() && + async_wrapped_computation()->CanExpandIntoSingleInstruction()) { + async_wrapped_instruction()->PrintExtraAttributes(printer, options); + } +} + +std::unique_ptr +HloAsyncStartInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_wrapped_computation = nullptr; + if (context != nullptr) { + new_wrapped_computation = + context->FindComputation(async_wrapped_computation()); + } + if (new_wrapped_computation == nullptr) { + new_wrapped_computation = module->AddEmbeddedComputation( + async_wrapped_computation()->Clone("clone", context)); + } + + return std::make_unique( + opcode(), shape, new_operands, new_wrapped_computation, + async_execution_thread_); +} + HloCopyStartInstruction::HloCopyStartInstruction( const Shape& shape, HloInstruction* operand, std::optional cross_program_prefetch_index) @@ -759,13 +827,26 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, AppendOperand(operand); } +HloSendDoneInstruction::HloSendDoneInstruction(HloInstruction* operand, + int64_t channel_id, + bool is_host_transfer) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), + channel_id, is_host_transfer) { + AppendOperand(operand); +} + std::unique_ptr HloSendDoneInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); + HloSendInstruction* send = dynamic_cast(new_operands[0]); + if (send != nullptr) { + return std::make_unique(send, is_host_transfer()); + } + return std::make_unique( - Cast(new_operands[0]), is_host_transfer()); + new_operands[0], channel_id().value(), is_host_transfer()); } // Recv instruction produces a tuple of {receive buffer, U32 context}. @@ -801,13 +882,30 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, AppendOperand(operand); } +HloRecvDoneInstruction::HloRecvDoneInstruction(HloInstruction* operand, + int64_t channel_id, + bool is_host_transfer) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { + AppendOperand(operand); +} + std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); + HloRecvInstruction* recv = dynamic_cast(new_operands[0]); + if (recv != nullptr) { + return std::make_unique(recv, is_host_transfer()); + } + return std::make_unique( - Cast(new_operands[0]), is_host_transfer()); + new_operands[0], channel_id().value(), is_host_transfer()); } HloCollectiveInstruction::HloCollectiveInstruction( @@ -1063,6 +1161,27 @@ bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues( split_dimension_ == casted_other.split_dimension(); } +HloCollectiveBroadcastInstruction::HloCollectiveBroadcastInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + absl::Span replica_groups, bool constrain_layout, + const std::optional& channel_id) + : HloCollectiveInstruction(opcode, shape, operands, replica_groups, + constrain_layout, channel_id) {} + +HloInstructionProto HloCollectiveBroadcastInstruction::ToProto() const { + return HloCollectiveInstruction::ToProto(); +} + +std::unique_ptr +HloCollectiveBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return std::make_unique( + opcode(), shape, new_operands, replica_groups(), constrain_layout(), + channel_id()); +} + HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs, @@ -1737,8 +1856,6 @@ HloInstruction* HloCallableInstruction::AppendInstructionIntoCalledComputation( HloInstruction* HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( HloInstruction* instruction_to_append, bool add_output) { - CHECK(instruction_to_append->IsFusible()) - << instruction_to_append->ToString(); VLOG(3) << "CloneAndAppendInstructionIntoCalledComputation:\n" << instruction_to_append->ToString(); HloInstruction* clone = nullptr; @@ -1751,12 +1868,15 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( if (called_computations().empty()) { // New fusion instruction. It should not be a multioutput instruction. CHECK(!add_output); - auto builder = HloComputation::Builder( - default_called_computation_name(), - opcode() == HloOpcode::kFusion ? this : nullptr); + auto builder = HloComputation::Builder(default_called_computation_name()); builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/"")); - AppendComputation( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + auto* new_computation = + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()); + AppendComputation(new_computation); + if (opcode() == HloOpcode::kFusion) { + new_computation->SetFusionInstruction(this); + } + clone = called_computation_root(); } else { // When add_output is false, instruction_to_append is necessarily an @@ -1986,6 +2106,15 @@ void HloFusionInstruction::ClearCalledComputations() { HloInstruction::ClearCalledComputations(); } +HloInstruction* +HloFusionInstruction::CloneAndAppendInstructionIntoCalledComputation( + HloInstruction* instruction_to_append, bool add_output) { + CHECK(instruction_to_append->IsFusible()) + << instruction_to_append->ToString(); + return HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( + instruction_to_append, add_output); +} + std::string HloFusionInstruction::ToCategory() const { switch (fusion_kind()) { case FusionKind::kLoop: @@ -2217,15 +2346,13 @@ HloFusionInstruction::fused_parameters() const { return fused_instructions_computation()->parameter_instructions(); } -tsl::gtl::iterator_range>::const_iterator>> +tsl::gtl::iterator_range HloFusionInstruction::fused_instructions() const { const HloComputation* subcomp = fused_instructions_computation(); return subcomp->instructions(); } -tsl::gtl::iterator_range< - UnwrappingIterator>::iterator>> +tsl::gtl::iterator_range HloFusionInstruction::fused_instructions() { return fused_instructions_computation()->instructions(); } @@ -3471,18 +3598,34 @@ std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( HloDotInstruction::HloDotInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config) + const PrecisionConfig& precision_config, + std::vector sparsity, + absl::Span sparse_meta) : HloInstruction(HloOpcode::kDot, shape), dot_dimension_numbers_(dimension_numbers), - precision_config_(precision_config) { + precision_config_(precision_config), + sparsity_(std::move(sparsity)) { AppendOperand(lhs); AppendOperand(rhs); + CHECK_LE(sparsity_.size(), kOperands); + CHECK_EQ(sparsity_.size(), sparse_meta.size()); + for (HloInstruction* meta : sparse_meta) { + AppendOperand(meta); + } + if (sparsity_.size() == kOperands && + sparsity_[0].index() > sparsity_[1].index()) { + std::swap(sparsity_[0], sparsity_[1]); // Keep descriptors ordered. + std::swap(mutable_operands()[2], mutable_operands()[3]); + } } HloInstructionProto HloDotInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; *proto.mutable_precision_config() = precision_config_; + for (const SparsityDescriptor& descriptor : sparsity_) { + *proto.add_dot_sparsity() = descriptor; + } return proto; } @@ -3492,6 +3635,9 @@ void HloDotInstruction::PrintExtraAttributesImpl( printer->Append(DotDimensionNumbersToString(dot_dimension_numbers_)); }); PrintPrecisionConfig(printer, precision_config_); + if (!sparsity_.empty()) { + PrintSparsityDescriptor(printer, absl::MakeSpan(sparsity_)); + } } bool HloDotInstruction::IdenticalSlowPath( @@ -3502,16 +3648,18 @@ bool HloDotInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), casted_other.dot_dimension_numbers()) && protobuf_util::ProtobufEquals(precision_config(), - casted_other.precision_config()); + casted_other.precision_config()) && + absl::c_equal(sparsity_, casted_other.sparsity_, + protobuf_util::ProtobufEquals); } std::unique_ptr HloDotInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); + CHECK_EQ(new_operands.size(), kOperands + sparse_operands()); return std::make_unique( shape, new_operands[0], new_operands[1], dot_dimension_numbers_, - precision_config_); + precision_config_, sparsity_, new_operands.subspan(kOperands)); } HloDomainInstruction::HloDomainInstruction( diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index e2509d7284c11..79c23ce67c6c6 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,6 @@ class HloDimensionsInstruction : public HloInstruction { case HloOpcode::kReduce: case HloOpcode::kReverse: case HloOpcode::kSort: - case HloOpcode::kTopK: case HloOpcode::kTranspose: return true; default: @@ -226,41 +225,23 @@ class HloFftInstruction : public HloInstruction { class HloAsyncInstruction : public HloInstruction { public: - HloAsyncInstruction( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - HloComputation* async_computation, - std::optional async_group_id = std::nullopt, - absl::string_view async_execution_thread = kMainExecutionThread); - HloAsyncInstruction( - HloOpcode opcode, const Shape& shape, HloInstruction* operand, - HloComputation* async_computation, - std::optional async_group_id = std::nullopt, - absl::string_view async_execution_thread = kMainExecutionThread); - - ~HloAsyncInstruction() override; - // When an async instruction is being destructed, remove it from the vector of - // pointers of its called computation, to avoid referencing freed memory. - void ClearAsyncComputationInstruction(); + // Constructs async-{update,done}. + HloAsyncInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand); + HloComputation* async_wrapped_computation() const; HloInstruction* async_wrapped_instruction() const; HloOpcode async_wrapped_opcode() const; - // Async group id is a unique id given to a group of async operations that - // consist of one async start, one async done, and zero or more async update - // operations. The async group participates in a single async operation. The - // async operation canonicalizer pass assigns async group ids. - std::optional async_group_id() const { return async_group_id_; } - // Async thread name is a unique thread name for one or more async groups. // Typically one HLO module contains a main thread as well as one or more // parallel threads. - absl::string_view async_execution_thread() const { - return async_execution_thread_; + virtual absl::string_view async_execution_thread() const; + virtual void set_async_execution_thread( + absl::string_view async_execution_thread) {} + HloInstructionProto ToProto() const override { + return HloInstruction::ToProto(); } - void set_async_group_id(std::optional async_group_id); - void set_async_execution_thread(absl::string_view async_execution_thread); - HloInstructionProto ToProto() const override; static bool ClassOf(const HloInstruction* hlo) { switch (hlo->opcode()) { @@ -273,9 +254,31 @@ class HloAsyncInstruction : public HloInstruction { } } + // Returns async-start instruction of the async chain. + HloAsyncInstruction* async_chain_start() const; + // Returns async-done instruction of the async chain. + HloAsyncInstruction* async_chain_done() const; + // Returns the chain of async op referencing this computation, + // where *begin(GetAsyncChain()) is the async-start op and + // *end(GetAsyncChain()) is the async-done op. + std::vector GetAsyncChain() const; + + bool HasSideEffect() const override { + return async_wrapped_instruction()->HasSideEffect(); + } + + protected: + // Helper to constructs async-{start,update,done}. + HloAsyncInstruction(HloOpcode opcode, const Shape& shape, + absl::Span operands, + HloOpcode async_wrapped_opcode); + private: + // async-{update,done} inherit all their attributes from async-start, + // so they shouldn't print any. void PrintExtraAttributesImpl(AttributePrinter& printer, - const HloPrintOptions& options) const override; + const HloPrintOptions& options) const override { + } bool IdenticalSlowPath( const HloInstruction& other, absl::FunctionRef @@ -283,7 +286,37 @@ class HloAsyncInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - std::optional async_group_id_; + HloAsyncInstruction* async_chain_next_ = nullptr; +}; + +// Creates async-start. +class HloAsyncStartInstruction : public HloAsyncInstruction { + public: + HloAsyncStartInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + HloComputation* async_computation, + absl::string_view async_execution_thread = kMainExecutionThread); + + ~HloAsyncStartInstruction() override; + // When an async instruction is being destructed, remove it from the vector of + // pointers of its called computation, to avoid referencing freed memory. + void ClearAsyncComputationInstruction(); + + absl::string_view async_execution_thread() const override { + return async_execution_thread_; + }; + void set_async_execution_thread( + absl::string_view async_execution_thread) override; + HloInstructionProto ToProto() const override; + + private: + void PrintExtraAttributesImpl(AttributePrinter& printer, + const HloPrintOptions& options) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + std::string async_execution_thread_ = kMainExecutionThread; }; @@ -415,9 +448,9 @@ class HloCholeskyInstruction : public HloInstruction { // Class that represents instructions that synchronize and transfer data between // partitioned devices. Send/Recv and collective instructions (AllReduce, -// AllToAll, CollectivePermute) belong to this instruction type. A group of -// instructions (of the same opcode) with the same channel_id communicate during -// execution. +// AllToAll, CollectivePermute, CollectiveBroadcast) belong to this instruction +// type. A group of instructions (of the same opcode) with the same channel_id +// communicate during execution. class HloChannelInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. The id is @@ -545,7 +578,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { public: explicit HloSendDoneInstruction(HloSendInstruction* operand, bool is_host_transfer); - + explicit HloSendDoneInstruction(HloInstruction* operand, int64_t channel_id, + bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kSendDone; } @@ -577,6 +611,8 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { public: explicit HloRecvDoneInstruction(HloRecvInstruction* operand, bool is_host_transfer); + explicit HloRecvDoneInstruction(HloInstruction* operand, int64_t channel_id, + bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kRecvDone; @@ -813,6 +849,28 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { std::optional split_dimension_; }; +class HloCollectiveBroadcastInstruction : public HloCollectiveInstruction { + public: + explicit HloCollectiveBroadcastInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + absl::Span replica_groups, bool constrain_layout, + const std::optional& channel_id); + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + static bool ClassOf(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kCollectiveBroadcast; + } + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; +}; + class HloCollectivePermuteInstruction : public HloChannelInstruction { public: explicit HloCollectivePermuteInstruction( @@ -868,6 +926,7 @@ inline bool HloAllReduceInstructionBase::ClassOf(const HloInstruction* hlo) { inline bool HloCollectiveInstruction::ClassOf(const HloInstruction* hlo) { return HloAllReduceInstructionBase::ClassOf(hlo) || + HloCollectiveBroadcastInstruction::ClassOf(hlo) || HloAllGatherInstruction::ClassOf(hlo) || HloAllToAllInstruction::ClassOf(hlo); } @@ -1310,6 +1369,11 @@ class HloFusionInstruction : public HloCallableInstruction { // its fusion computation, to avoid referencing freed memory. void ClearFusionComputationInstruction(); + // Clones the given instruction_to_append and inserts the clone into this + // callable instruction. + HloInstruction* CloneAndAppendInstructionIntoCalledComputation( + HloInstruction* instruction_to_append, bool add_output = false); + std::string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1363,12 +1427,10 @@ class HloFusionInstruction : public HloCallableInstruction { // Returns the list of fused instructions inside this fusion instruction. The // returned type is a range of HloInstruction*s. - tsl::gtl::iterator_range>::const_iterator>> + tsl::gtl::iterator_range fused_instructions() const; - tsl::gtl::iterator_range< - UnwrappingIterator>::iterator>> + tsl::gtl::iterator_range fused_instructions(); // Gets the number of instructions inside this fusion instruction. @@ -2152,6 +2214,8 @@ class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { int64_t first_index_operand_number() const override { return 2; } + const HloInstruction* update() const { return operand(1); } + static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kDynamicUpdateSlice; } @@ -2309,12 +2373,17 @@ class HloIotaInstruction : public HloInstruction { class HloDotInstruction : public HloInstruction { public: + static const int kOperands = 2; + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch - // dimensions specified in 'dimension_numbers'. - explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig& precision_config); + // dimensions specified in 'dimension_numbers'. If 'sparsity' is set, then + // 'sparse_meta' must also be present (and have the same size). + explicit HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config, + std::vector sparsity = {}, + absl::Span sparse_meta = {}); // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { @@ -2336,6 +2405,13 @@ class HloDotInstruction : public HloInstruction { const PrecisionConfig& precision_config() const { return precision_config_; } PrecisionConfig* mutable_precision_config() { return &precision_config_; } + // Sparsity descriptors are optional. If present, additional operands define + // how the data is read for the dot inputs. + int sparse_operands() const { return sparsity_.size(); } + absl::Span sparsity() const { + return absl::MakeSpan(sparsity_); + } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -2361,6 +2437,11 @@ class HloDotInstruction : public HloInstruction { // Information used to communicate to the implementation about the algorithm // used to produce results. See the documentation on precision_config(). PrecisionConfig precision_config_; + + // Sparsity descriptors are set if some operands are sparse. In this case, the + // additional metadata operands contain the information that defines how + // the data is read. + std::vector sparsity_; }; class HloDomainInstruction : public HloInstruction { diff --git a/xla/hlo/ir/hlo_module.cc b/xla/hlo/ir/hlo_module.cc index 149fec7c3f1d4..f5a7806c0cbd9 100644 --- a/xla/hlo/ir/hlo_module.cc +++ b/xla/hlo/ir/hlo_module.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -188,7 +188,7 @@ HloComputation* HloModule::AddEntryComputationWithLayouts( } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { - if (has_schedule() && !to_remove->IsCalledComputation()) { + if (has_schedule()) { schedule_->remove_computation(to_remove); } @@ -370,6 +370,15 @@ void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { entry_computation_layout().Print(printer); printer->Append("}"); } + if (config.allow_spmd_sharding_propagation_to_parameters().size() != 1 || + config.allow_spmd_sharding_propagation_to_parameters().back()) { + printer->Append(", allow_spmd_sharding_propagation_to_parameters={"); + AppendJoin(printer, config.allow_spmd_sharding_propagation_to_parameters(), + ",", [](Printer* printer, bool i) { + printer->Append(i ? "true" : "false"); + }); + printer->Append("}"); + } if (config.allow_spmd_sharding_propagation_to_output().size() != 1 || config.allow_spmd_sharding_propagation_to_output().back()) { printer->Append(", allow_spmd_sharding_propagation_to_output={"); @@ -495,7 +504,7 @@ HloModuleProto HloModule::ToProto() const { return proto; } -StatusOr HloModule::ToProtoWithConfig() const { +absl::StatusOr HloModule::ToProtoWithConfig() const { HloModuleProtoWithConfig result; TF_ASSIGN_OR_RETURN(*result.mutable_config(), config_.get().ToProto()); *result.mutable_hlo_module() = ToProto(); @@ -531,7 +540,7 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { } /* static */ -StatusOr> HloModule::CreateFromProto( +absl::StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool prohibit_empty_literal) { VLOG(2) << "CreateFromProto()"; @@ -675,7 +684,7 @@ StatusOr> HloModule::CreateFromProto( } /* static */ -StatusOr HloModule::CreateModuleConfigFromShape( +absl::StatusOr HloModule::CreateModuleConfigFromShape( const ProgramShape& program_shape, const DebugOptions& debug_options, const ExecutionOptions* execution_options) { HloModuleConfig module_config(ProgramShape{program_shape}); @@ -691,17 +700,18 @@ StatusOr HloModule::CreateModuleConfigFromShape( execution_options->use_spmd_partitioning()); module_config.set_use_auto_spmd_partitioning( execution_options->use_auto_spmd_partitioning()); - std::vector mesh_shape; - for (auto t : execution_options->auto_spmd_partitioning_mesh_shape()) { - mesh_shape.push_back(t); - } - module_config.set_auto_spmd_partitioning_mesh_shape(mesh_shape); - std::vector mesh_ids; - for (auto t : execution_options->auto_spmd_partitioning_mesh_ids()) { - mesh_ids.push_back(t); - } - module_config.set_auto_spmd_partitioning_mesh_ids(mesh_ids); + module_config.set_auto_spmd_partitioning_mesh_shape(std::vector( + execution_options->auto_spmd_partitioning_mesh_shape().begin(), + execution_options->auto_spmd_partitioning_mesh_shape().end())); + module_config.set_auto_spmd_partitioning_mesh_ids(std::vector( + execution_options->auto_spmd_partitioning_mesh_ids().begin(), + execution_options->auto_spmd_partitioning_mesh_ids().end())); module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo()); + if (!execution_options->allow_spmd_sharding_propagation_to_parameters() + .empty()) { + module_config.set_allow_spmd_sharding_propagation_to_parameters( + execution_options->allow_spmd_sharding_propagation_to_parameters()); + } if (!execution_options->allow_spmd_sharding_propagation_to_output() .empty()) { module_config.set_allow_spmd_sharding_propagation_to_output( @@ -721,11 +731,10 @@ StatusOr HloModule::CreateModuleConfigFromShape( module_config.num_partitions()); } } - std::vector param_requires_broadcast_via_collectives( + module_config.set_param_requires_broadcast_via_collectives(std::vector< + bool>( execution_options->param_requires_broadcast_via_collectives().begin(), - execution_options->param_requires_broadcast_via_collectives().end()); - module_config.set_param_requires_broadcast_via_collectives( - param_requires_broadcast_via_collectives); + execution_options->param_requires_broadcast_via_collectives().end())); module_config.set_allow_separate_sharding_programs( execution_options->allow_separate_sharding_programs()); HloModuleConfig::AssignStructShardableValueUpdatePairs( @@ -747,7 +756,7 @@ StatusOr HloModule::CreateModuleConfigFromShape( } /* static */ -StatusOr HloModule::CreateModuleConfigFromProto( +absl::StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options, const ExecutionOptions* execution_options) { if (!module.has_host_program_shape()) { @@ -770,7 +779,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( return config; } -StatusOr> HloModule::CreateFromProtoWithConfig( +absl::StatusOr> HloModule::CreateFromProtoWithConfig( const HloModuleProtoWithConfig& proto, bool prohibit_empty_literal) { auto hlo_module_proto = proto.hlo_module(); TF_ASSIGN_OR_RETURN(std::unique_ptr config_ptr, @@ -914,10 +923,12 @@ std::vector HloModule::MakeComputationPostOrder( absl::flat_hash_set nonroot_computations; nonroot_computations.reserve(computations_.size() - 1); for (auto& computation : computations_) { - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - nonroot_computations.insert(called_computation); + for (const HloInstructionInfo& inst : + computation->instructions_with_info()) { + if (HloInstruction::MightHaveCalledComputations(inst.opcode())) { + for (HloComputation* called_computation : inst->called_computations()) { + nonroot_computations.insert(called_computation); + } } } } diff --git a/xla/hlo/ir/hlo_module.h b/xla/hlo/ir/hlo_module.h index 43ddb9ac5539d..d232fb1462ce0 100644 --- a/xla/hlo/ir/hlo_module.h +++ b/xla/hlo/ir/hlo_module.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ limitations under the License. namespace xla { using LayoutCanonicalizationCallback = - std::function, Shape>>( + std::function, Shape>>( const HloModule& module)>; // Helper class to maintain a copy-on-write storage of an object of the @@ -451,25 +451,25 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; - static StatusOr> CreateFromProto( + static absl::StatusOr> CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool prohibit_empty_literal = true); // Convert an HloModule to or from a proto that includes module configuration - StatusOr ToProtoWithConfig() const; - static StatusOr> CreateFromProtoWithConfig( + absl::StatusOr ToProtoWithConfig() const; + static absl::StatusOr> CreateFromProtoWithConfig( const HloModuleProtoWithConfig& proto, bool prohibit_empty_literal = true); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. - static StatusOr CreateModuleConfigFromProto( + static absl::StatusOr CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options, const ExecutionOptions* execution_options = nullptr); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. - static StatusOr CreateModuleConfigFromShape( + static absl::StatusOr CreateModuleConfigFromShape( const ProgramShape& program_shape, const DebugOptions& debug_options, const ExecutionOptions* execution_options = nullptr); @@ -505,6 +505,9 @@ class HloModule { const HloInputOutputAliasConfig& input_output_alias_config() const { return input_output_alias_config_; } + void set_input_output_alias_config(HloInputOutputAliasConfig config) { + input_output_alias_config_ = std::move(config); + } // buffer_donor_config_ indicates the set of input buffer donors that are // expected from the module. @@ -512,6 +515,9 @@ class HloModule { const HloBufferDonorConfig& buffer_donor_config() const { return buffer_donor_config_; } + void set_buffer_donor_config(HloBufferDonorConfig config) { + buffer_donor_config_ = std::move(config); + } // Returns an id that is unique to this module across all modules created over // the lifetime of this process. @@ -546,6 +552,12 @@ class HloModule { instr->UniquifyName(&instruction_name_uniquer_); } + void SetAndUniquifyComputationName(HloComputation* computation, + absl::string_view name) { + computation->SetAndSanitizeName(name); + computation->UniquifyName(&computation_name_uniquer_); + } + Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; // Checks if this config has a list of entry parameters' HLO shardings for diff --git a/xla/hlo/ir/hlo_module_group.cc b/xla/hlo/ir/hlo_module_group.cc index 33eb30b11c426..c5623e45aea58 100644 --- a/xla/hlo/ir/hlo_module_group.cc +++ b/xla/hlo/ir/hlo_module_group.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ HloModuleGroupProto HloModuleGroup::ToProto() const { return proto; } -/* static */ StatusOr HloModuleGroup::CreateFromProto( +/* static */ absl::StatusOr HloModuleGroup::CreateFromProto( const HloModuleGroupProto& proto, absl::Span module_configs) { TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; diff --git a/xla/hlo/ir/hlo_module_group.h b/xla/hlo/ir/hlo_module_group.h index c9afd22ec44e6..753e8bc61c62d 100644 --- a/xla/hlo/ir/hlo_module_group.h +++ b/xla/hlo/ir/hlo_module_group.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,11 @@ class HloModuleGroup { HloModuleGroup(absl::string_view name, std::vector>&& modules); + HloModuleGroup(const HloModuleGroup& other) = delete; + HloModuleGroup(HloModuleGroup&& other) = default; + HloModuleGroup& operator=(const HloModuleGroup& other) = delete; + HloModuleGroup& operator=(HloModuleGroup&& other) = default; + // Returns the modules contained in the group. const std::vector& modules() const { return module_ptrs_; } @@ -82,7 +87,7 @@ class HloModuleGroup { // Serialize the module group to/from a proto. HloModuleGroupProto ToProto() const; - static StatusOr CreateFromProto( + static absl::StatusOr CreateFromProto( const HloModuleGroupProto& proto, absl::Span module_configs); diff --git a/xla/hlo/ir/hlo_module_metadata.cc b/xla/hlo/ir/hlo_module_metadata.cc index 8e9de396ad7ec..aa910c1091924 100644 --- a/xla/hlo/ir/hlo_module_metadata.cc +++ b/xla/hlo/ir/hlo_module_metadata.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,15 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "xla/util.h" #include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" namespace xla { -StatusOr HloModuleMetadata::GetCurrentHloPassMetadata() { +absl::StatusOr +HloModuleMetadata::GetCurrentHloPassMetadata() { if (running_passes_.empty()) { return NotFound( "HloPassMetadata for currently running pass not found, either because " @@ -84,4 +88,16 @@ void HloModuleMetadata::set_prepartitioning_metadata( } } +Status HloModuleMetadata::set_custom_metadata( + const ::tsl::protobuf::Message& message) { + TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, + GetCurrentHloPassMetadata()); + if (!pass_metadata->mutable_custom_metadata()->PackFrom(message)) { + LOG(WARNING) << "failed to pack custom metadata for " + << pass_metadata->pass_id(); + return Internal("failed to pack custom metadata"); + }; + return OkStatus(); +} + } // namespace xla diff --git a/xla/hlo/ir/hlo_module_metadata.h b/xla/hlo/ir/hlo_module_metadata.h index 27fcc5361348e..a18e9090db4d0 100644 --- a/xla/hlo/ir/hlo_module_metadata.h +++ b/xla/hlo/ir/hlo_module_metadata.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -62,8 +63,9 @@ class HloModuleMetadata { void add_partitioned_module_id(int64_t id) { module_metadata_.add_partitioned_module_ids(id); } + Status set_custom_metadata(const ::tsl::protobuf::Message& message); - StatusOr current_pass_id() { + absl::StatusOr current_pass_id() { TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, GetCurrentHloPassMetadata()); return pass_metadata->pass_id(); @@ -111,7 +113,7 @@ class HloModuleMetadata { // Gets mutable metadata for the currently running pass. If passes are nested, // finds the deepest one still running. Returns NotFound if metadata for the // currently running pass cannot be found. - StatusOr GetCurrentHloPassMetadata(); + absl::StatusOr GetCurrentHloPassMetadata(); Status MutateCurrentHloPassMetadata( absl::FunctionRef mutator); diff --git a/xla/hlo/ir/hlo_op_metadata.cc b/xla/hlo/ir/hlo_op_metadata.cc index 69af1ec010d92..b45f23ba96d43 100644 --- a/xla/hlo/ir/hlo_op_metadata.cc +++ b/xla/hlo/ir/hlo_op_metadata.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,8 +24,16 @@ limitations under the License. namespace xla { -std::string OpMetadataToString(const OpMetadata& metadata) { +std::string OpMetadataToString(const OpMetadata& metadata, bool only_op_name) { std::vector result; + if (only_op_name) { + if (!metadata.op_name().empty()) { + return absl::StrCat("op_name=\"", absl::CEscape(metadata.op_name()), + "\""); + } else { + return ""; + } + } if (!metadata.op_type().empty()) { result.push_back( absl::StrCat("op_type=\"", absl::CEscape(metadata.op_type()), "\"")); diff --git a/xla/hlo/ir/hlo_op_metadata.h b/xla/hlo/ir/hlo_op_metadata.h index ebf7f51e597d2..acbd34c84af2f 100644 --- a/xla/hlo/ir/hlo_op_metadata.h +++ b/xla/hlo/ir/hlo_op_metadata.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,8 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -std::string OpMetadataToString(const OpMetadata& metadata); +std::string OpMetadataToString(const OpMetadata& metadata, + bool only_op_name = false); } // namespace xla #endif // XLA_HLO_IR_HLO_OP_METADATA_H_ diff --git a/xla/hlo/ir/hlo_opcode.cc b/xla/hlo/ir/hlo_opcode.cc index 3a3b4b520d407..dcdf9c1933c82 100644 --- a/xla/hlo/ir/hlo_opcode.cc +++ b/xla/hlo/ir/hlo_opcode.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ absl::string_view HloOpcodeString(HloOpcode opcode) { } } -StatusOr StringToHloOpcode(absl::string_view opcode_name) { +absl::StatusOr StringToHloOpcode(absl::string_view opcode_name) { static auto* opcode_map = new absl::flat_hash_map({ #define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ {opcode_name, HloOpcode::enum_name}, diff --git a/xla/hlo/ir/hlo_opcode.h b/xla/hlo/ir/hlo_opcode.h index b3127b614a952..79d90e3120d60 100644 --- a/xla/hlo/ir/hlo_opcode.h +++ b/xla/hlo/ir/hlo_opcode.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,6 +75,7 @@ namespace xla { V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kClz, "count-leading-zeros", 1) \ + V(kCollectiveBroadcast, "collective-broadcast", kHloOpcodeIsVariadic) \ V(kCollectivePermute, "collective-permute", kHloOpcodeIsVariadic) \ V(kCollectivePermuteDone, "collective-permute-done", 1) \ V(kCollectivePermuteStart, "collective-permute-start", kHloOpcodeIsVariadic) \ @@ -92,10 +93,11 @@ namespace xla { V(kCustomCall, "custom-call", kHloOpcodeIsVariadic) \ V(kDivide, "divide", 2) \ V(kDomain, "domain", 1) \ - V(kDot, "dot", 2) \ + V(kDot, "dot", kHloOpcodeIsVariadic) \ V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ + V(kErf, "erf", 1) \ V(kExp, "exponential", 1) \ V(kExpm1, "exponential-minus-one", 1) \ V(kFft, "fft", 1) \ @@ -169,7 +171,8 @@ namespace xla { /* go/keep-sorted end */ // LINT.ThenChange(../../mlir_hlo/mhlo/IR/hlo_ops.td) -enum class HloOpcode { +// Upto 256 opcodes. Increase the base type if/when needed. +enum class HloOpcode : uint8_t { #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, HLO_OPCODE_LIST(DECLARE_ENUM) #undef DECLARE_ENUM @@ -184,7 +187,7 @@ enum { absl::string_view HloOpcodeString(HloOpcode opcode); // Retrieves the opcode enum by name if the opcode exists. -StatusOr StringToHloOpcode(absl::string_view opcode_name); +absl::StatusOr StringToHloOpcode(absl::string_view opcode_name); inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { return os << HloOpcodeString(opcode); @@ -226,6 +229,9 @@ inline constexpr uint32_t HloOpcodeCount() { #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); } +static_assert(HloOpcodeCount() < 256, + "HloOpcode is a uint8_t. You need to increase its size before " + "adding new op codes."); } // namespace xla diff --git a/xla/hlo/ir/hlo_reachability.cc b/xla/hlo/ir/hlo_reachability.cc index c98c44cc2c040..cdc1b664ef438 100644 --- a/xla/hlo/ir/hlo_reachability.cc +++ b/xla/hlo/ir/hlo_reachability.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ namespace xla { HloReachabilityMap::HloReachabilityMap( absl::Span instructions) : bit_sets_(instructions.size(), BitSet(instructions.size())) { + indices_.reserve(instructions.size()); for (size_t i = 0; i < instructions.size(); ++i) { bit_sets_[i].Set(i); // Instructions are reachable from themselves. indices_[GetKey(instructions[i])] = i; diff --git a/xla/hlo/ir/hlo_reachability.h b/xla/hlo/ir/hlo_reachability.h index bdceed946180d..157991067ae9a 100644 --- a/xla/hlo/ir/hlo_reachability.h +++ b/xla/hlo/ir/hlo_reachability.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -161,8 +161,16 @@ class HloReachabilityMap { // Sets this bit-set to union of this bit-set and `other`. void operator|=(const BitSet& other) { - for (size_t i = 0; i < vector_.size(); ++i) { - vector_[i] |= other.vector_[i]; + if (this == &other) return; + DCHECK(size_ == other.size_); + + // Ease the work of the auto-vectorizer. + const Word* a = vector_.data(); + const Word* b = other.vector_.data(); + Word* __restrict out = vector_.data(); + size_t num_words = vector_.size(); + for (size_t i = 0; i < num_words; ++i) { + out[i] = a[i] | b[i]; } } @@ -182,6 +190,8 @@ class HloReachabilityMap { std::vector vector_; }; + friend class HloReachabilityMapBitSetBenchmark; + using Key = std::pair; // module ID, instruction ID. static Key GetKey(const HloInstruction* instruction) { return {instruction->GetModule()->unique_id(), instruction->unique_id()}; diff --git a/xla/hlo/ir/hlo_schedule.cc b/xla/hlo/ir/hlo_schedule.cc index 6e945b5d7ded8..2922a1b1286f3 100644 --- a/xla/hlo/ir/hlo_schedule.cc +++ b/xla/hlo/ir/hlo_schedule.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ limitations under the License. namespace xla { -/* static */ StatusOr HloSchedule::CreateFromProto( +/* static */ absl::StatusOr HloSchedule::CreateFromProto( const HloModule* module, const HloScheduleProto& proto) { absl::flat_hash_map id_to_computation; for (const HloComputation* computation : module->computations()) { @@ -76,7 +76,7 @@ namespace xla { return std::move(schedule); } -StatusOr HloSchedule::ToProto() const { +absl::StatusOr HloSchedule::ToProto() const { TF_RETURN_IF_ERROR(Verify()); HloScheduleProto proto; for (const auto& id_sequence : sequences_) { diff --git a/xla/hlo/ir/hlo_schedule.h b/xla/hlo/ir/hlo_schedule.h index 597ed2881a9eb..0f7e2a5733993 100644 --- a/xla/hlo/ir/hlo_schedule.h +++ b/xla/hlo/ir/hlo_schedule.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,6 +78,14 @@ class HloInstructionSequence { *id_it = new_instruction->unique_id(); } + // Adds the instruction to the sequence at a specified index, + void insert_instruction(HloInstruction* instruction, int64_t index) { + CHECK(0 <= index && index < size()) << "Index out of bounds"; + instruction_sequence_.insert(instruction_sequence_.begin() + index, + instruction); + id_sequence_.insert(id_sequence_.begin() + index, instruction->unique_id()); + } + // Clears the sequence of all instructions. void clear() { instruction_sequence_.clear(); @@ -114,9 +122,9 @@ class HloSchedule { explicit HloSchedule(const HloModule* module) : module_(module) {} // (De)Serialize an HloSchedule to/from a HloScheduleProto. - static StatusOr CreateFromProto(const HloModule* module, - const HloScheduleProto& proto); - StatusOr ToProto() const; + static absl::StatusOr CreateFromProto( + const HloModule* module, const HloScheduleProto& proto); + absl::StatusOr ToProto() const; // Returns a reference to the sequence for the given computation. const HloInstructionSequence& sequence( @@ -151,7 +159,8 @@ class HloSchedule { // Removes the computation from the sequences. void remove_computation(const HloComputation* computation) { auto it = sequences_.find(computation->unique_id()); - CHECK(it != sequences_.end()); + // The computation is not scheduled. Nothing to remove. + if (it == sequences_.end()) return; sequences_.erase(it); execution_threads_.erase(computation->unique_id()); } @@ -195,10 +204,10 @@ class HloSchedule { bool empty() const { return sequences_.empty(); } const HloModule* module() const { return module_; } - - private: // Updates the instruction sequence for the given computation. Status UpdateComputationSchedule(const HloComputation* computation); + private: + const HloModule* module_; diff --git a/xla/hlo/ir/hlo_sharding.cc b/xla/hlo/ir/hlo_sharding.cc index 9497747e6ece2..0577e5de93d0d 100644 --- a/xla/hlo/ir/hlo_sharding.cc +++ b/xla/hlo/ir/hlo_sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -371,7 +372,7 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, int64_t leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; flattened_list.resize(leaf_count, sharding); - return HloSharding(flattened_list); + return HloSharding(std::move(flattened_list)); } HloSharding HloSharding::Single(const Shape& shape, @@ -624,7 +625,7 @@ Status HloSharding::CheckLeafCount(const Shape& shape) const { return OkStatus(); } -StatusOr> HloSharding::AsShapeTree( +absl::StatusOr> HloSharding::AsShapeTree( const Shape& shape) const { if (IsTuple()) { ShapeTree result(shape, HloSharding::Replicate()); @@ -639,7 +640,8 @@ StatusOr> HloSharding::AsShapeTree( } } -StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { +absl::StatusOr HloSharding::GetTupleSharding( + const Shape& shape) const { if (IsTuple()) { TF_RETURN_IF_ERROR(CheckLeafCount(shape)); return *this; @@ -791,7 +793,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return OkStatus(); } -/*static*/ StatusOr HloSharding::FromProto( +/*static*/ absl::StatusOr HloSharding::FromProto( const OpSharding& proto) { std::vector metadata(proto.metadata().begin(), proto.metadata().end()); @@ -845,14 +847,15 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); auto product_no_overflow = - [](absl::Span dims) -> StatusOr { + [](absl::Span dims) -> absl::StatusOr { int64_t product_of_dimensions = 1; + bool any_overflow = false; for (auto dimension : dims) { - TF_RET_CHECK(dimension > 0); - product_of_dimensions = - MultiplyWithoutOverflow(product_of_dimensions, dimension); - TF_RET_CHECK(product_of_dimensions > 0); + bool overflow = false; + std::tie(product_of_dimensions, overflow) = + OverflowSafeMultiply(product_of_dimensions, dimension); } + TF_RET_CHECK(!any_overflow); return product_of_dimensions; }; @@ -1017,6 +1020,16 @@ int64_t HloSharding::NumTiles() const { .subspan(0, TiledDataRank())); } +int64_t HloSharding::NumTilesLeaf() const { + DCHECK(!IsTuple()); + if (IsTileMaximalLeaf()) { + return 1; + } + CHECK(!IsManualLeaf() && !IsUnknownLeaf()); + return Product(absl::Span(tile_assignment_.dimensions()) + .subspan(0, TiledDataRankLeaf())); +} + int64_t HloSharding::NumTiles(absl::Span dims) const { if (IsTileMaximal()) { return 1; @@ -1039,17 +1052,18 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, const Shape* sub_shape = &shape; for (int64_t idx : index) { for (int64_t i = 0; i < idx; ++i) { - sharding_index += - ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i})); + sharding_index += ShapeUtil::GetLeafCount( + ShapeUtil::GetSubshapeOneIndex(*sub_shape, i)); } - sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); + sub_shape = &ShapeUtil::GetSubshapeOneIndex(*sub_shape, idx); } if (sub_shape->IsTuple()) { auto begin_it = tuple_elements_.begin() + sharding_index; return HloSharding::Tuple( *sub_shape, absl::MakeConstSpan( - &*begin_it, &*(begin_it + ShapeUtil::GetLeafCount(*sub_shape)))); + &*begin_it, + &*(begin_it + ShapeUtil::GetLeafCountTuple(*sub_shape)))); } else { return tuple_elements_[sharding_index]; } diff --git a/xla/hlo/ir/hlo_sharding.h b/xla/hlo/ir/hlo_sharding.h index 8b943b75b00c0..4237b1e182ba9 100644 --- a/xla/hlo/ir/hlo_sharding.h +++ b/xla/hlo/ir/hlo_sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/tile_assignment.h" // IWYU pragma: export @@ -147,7 +148,7 @@ class HloSharding { static HloSharding Single(const Shape& shape, const HloSharding& sharding); // Create a new sharding from a protobuf OpSharding. - static StatusOr FromProto(const OpSharding& proto); + static absl::StatusOr FromProto(const OpSharding& proto); // Checks whether device is a reserved device number. A reserved device number // has usually a special meaning, with dedicated handling logic. @@ -177,6 +178,10 @@ class HloSharding { return absl::c_all_of( tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); }); } + bool IsReplicatedLeaf() const { + DCHECK(!IsTuple()); + return replicated_; + } // Returns true if the tile size is the same as the input size. bool IsTileMaximal() const { @@ -187,6 +192,10 @@ class HloSharding { return s.IsTileMaximal(); }); } + bool IsTileMaximalLeaf() const { + DCHECK(!IsTuple()); + return maximal_; + } // Returns whether the sharding represents manual partitioning. bool IsManual() const { @@ -196,6 +205,10 @@ class HloSharding { return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { return s.IsManual(); }); } + bool IsManualLeaf() const { + DCHECK(!IsTuple()); + return manual_; + } // Returns whether the sharding represents a placeholder sharding. bool IsUnknown() const { @@ -205,6 +218,10 @@ class HloSharding { return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { return s.IsUnknown(); }); } + bool IsUnknownLeaf() const { + DCHECK(!IsTuple()); + return unknown_; + } bool IsShardGroup() const { if (!IsTuple()) { @@ -250,6 +267,9 @@ class HloSharding { bool IsTiled() const { return !IsTileMaximal() && !IsManual() && !IsUnknown(); } + bool IsTiledLeaf() const { + return !IsTileMaximalLeaf() && !IsManualLeaf() && !IsUnknownLeaf(); + } // Returns if the sharding has partial replication and partial sharding. If // true, data is sharded according to other dimensions of tile_assignment(), @@ -317,7 +337,7 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. - StatusOr> AsShapeTree(const Shape& shape) const; + absl::StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { return AsShapeTree(shape).value(); } @@ -329,7 +349,7 @@ class HloSharding { // If the current sharding is a tuple sharding, return itself as result. // Otherwise returns a tuple sharding for the input shape, with all the leaves // having this object sharding. - StatusOr GetTupleSharding(const Shape& shape) const; + absl::StatusOr GetTupleSharding(const Shape& shape) const; // If the shape is tuple and the current sharding is not a tuple, attempt to // construct a sharding that is compatible with the shape by replicating the @@ -408,6 +428,7 @@ class HloSharding { // Gets the number of tiles. If it has partial replication, this will not // equal the device count. int64_t NumTiles() const; + int64_t NumTilesLeaf() const; // Like NumTiles() but considers only some specific dimensions passed as // argument int64_t NumTiles(absl::Span dims) const; @@ -447,6 +468,16 @@ class HloSharding { rank -= subgroup_types_.size(); return rank; } + int64_t TiledDataRankLeaf() const { + DCHECK(!IsTuple()); + CHECK(IsTiledLeaf()); + int64_t rank = tile_assignment_.num_dimensions(); + if (ReplicateOnLastTileDim()) { + rank--; + } + rank -= subgroup_types_.size(); + return rank; + } // Returns the number of tuple_elements_ entries to fit the shape. static int64_t RequiredLeaves(const Shape& shape); @@ -639,16 +670,19 @@ class HloSharding { // When creating HloSharding, subgroup dims of the same type will be merged, // so that there is at most one dim with a given type. std::vector subgroup_types_; - bool replicated_; - bool maximal_; - bool tuple_; - bool manual_; - bool unknown_; + bool replicated_ : 1; // When non-tuple, true if the sharding is trivial. + bool maximal_ : 1; // When non-tuple, true if the tile size is the same as + // the input size. + bool tuple_ : 1; // True if this is a tuple. + bool manual_ : 1; // When non-tuple, true if the sharding represents manual + // partitioning. + bool unknown_ : 1; // When non-tuple, true if the sharding represents a + // placeholder sharding. // This flag is to support partial replication and partial sharding. If it is // true, tile_assignment_ will have an extra dimension in addition to the data // shape rank, and the added last dimension represents the subgroups of // replications, i.e., elements in slice [..., :] will be replicated. - bool replicate_on_last_tile_dim_; + bool replicate_on_last_tile_dim_ : 1; // This field is used to store the shard group information. Instructions // within the same shard group(i.e. under the same shard_group_id) will be // sharded alike or exactly the same as each other. diff --git a/xla/hlo/ir/hlo_sharding_metadata.cc b/xla/hlo/ir/hlo_sharding_metadata.cc index c0835571b3a80..85f5fc04e579d 100644 --- a/xla/hlo/ir/hlo_sharding_metadata.cc +++ b/xla/hlo/ir/hlo_sharding_metadata.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -174,7 +174,7 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, // If user is a tuple instruction, return the tuple subsharding corresponding to // the operand matching the instruction argument, because that is the // subsharding corresponding to instruction. -StatusOr> GetShardingTreeFromUser( +absl::StatusOr> GetShardingTreeFromUser( const HloInstruction& instruction, const HloInstruction& user) { if (user.opcode() == HloOpcode::kTuple) { return user.sharding() @@ -188,8 +188,8 @@ StatusOr> GetShardingTreeFromUser( // then no assignment is made. Therefore kUnassignedDevice is never propagated. // kConflict is returned if lhs is already assigned and rhs is assigned to a // different device. -StatusOr AssignLeafSharding(HloSharding* lhs, - const HloSharding& rhs) { +absl::StatusOr AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); if (rhs.UsesDevice(kUnassignedDevice)) { return AssignmentKind::kUnassigned; @@ -207,7 +207,7 @@ StatusOr AssignLeafSharding(HloSharding* lhs, // In case of conflicting assignment AssignmentKind::kConflict is returned. In // this case lhs_tree is partially assigned, up to the conflicting leaf. It is // up to the caller to discard the partial assignment in case of conflict. -StatusOr AssignTreeSharding( +absl::StatusOr AssignTreeSharding( ShapeTree* lhs_tree, ShapeTree::iterator lhs_it, const ShapeTree& rhs_tree) { AssignmentKind assigned = AssignmentKind::kUnassigned; @@ -233,9 +233,9 @@ StatusOr AssignTreeSharding( return assigned; } -StatusOr ApplyShardingFromUsers(HloInstruction* instruction, - const DomainMetadata::Domain& domain, - const HloSharding& domain_sharding) { +absl::StatusOr ApplyShardingFromUsers( + HloInstruction* instruction, const DomainMetadata::Domain& domain, + const HloSharding& domain_sharding) { if (instruction->users().empty()) { // No sharding from users, use domain_sharding, after checking // compatibility. @@ -317,8 +317,8 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, // Tries to propagate the sharding information into the instructions that are // part of the domain, in a reverse post order manner (users propagate to // instruction). -StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& domain_sharding) { +absl::StatusOr ApplyDomainShardingPass( + const DomainMetadata::Domain& domain, const HloSharding& domain_sharding) { int64_t assigned = 0; // domain.instructions are ordered in a post-order manner. As we do // user->operand propagation we process instructions in reverse order. In so @@ -380,8 +380,8 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return OkStatus(); } -StatusOr> ExtractOriginalCommonSharding( - absl::Span instructions) { +absl::StatusOr> +ExtractOriginalCommonSharding(absl::Span instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the @@ -435,7 +435,7 @@ std::string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } -/*static*/ StatusOr +/*static*/ absl::StatusOr ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) { if (metadata->Kind() != ShardingMetadata::KindName()) { return Status( diff --git a/xla/hlo/ir/hlo_sharding_metadata.h b/xla/hlo/ir/hlo_sharding_metadata.h index 7fa15698e586e..f8d1b9368eb5d 100644 --- a/xla/hlo/ir/hlo_sharding_metadata.h +++ b/xla/hlo/ir/hlo_sharding_metadata.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ class ShardingMetadata : public DomainMetadata { static absl::string_view KindName() { return "sharding"; } - static StatusOr ToShardingMetadata( + static absl::StatusOr ToShardingMetadata( const DomainMetadata* metadata); // Apply the specified domain metadata onto the specified domain. If no diff --git a/xla/hlo/ir/ptrvec.h b/xla/hlo/ir/ptrvec.h new file mode 100644 index 0000000000000..6dd30fc1ec317 --- /dev/null +++ b/xla/hlo/ir/ptrvec.h @@ -0,0 +1,380 @@ +/* Copyright 2023 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_HLO_IR_PTRVEC_H_ +#define XLA_HLO_IR_PTRVEC_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep + +namespace xla { + +// PtrVec is like a std::vector or absl::InlinedVector, but +// optimized to use less memory for empty and single element vectors. +// +// T must be a pointer type (e.g., char*, const int*, double*, etc.). +template +class PtrVec { + public: + static_assert(std::is_pointer::value); + + // Default constructible. + PtrVec(); + ~PtrVec(); + + // Copyable. + PtrVec(const PtrVec& x); + PtrVec& operator=(const PtrVec& x); + + // Movable. + PtrVec(PtrVec&& x); + PtrVec& operator=(PtrVec&& x); + + // Const iteration. Non-const iteration can be easily added if necessary. + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = T&; + using const_reference = T const&; + using const_iterator = T const*; + const_iterator begin() const; + const_iterator end() const; + + // Subset of vector-like operations. + size_t size() const; + bool empty() const; + T* data(); + T const* data() const; + T& operator[](size_t i); + T operator[](size_t i) const; + T at(size_t i) const; + T front() const; + T back() const; + void clear(); + void pop_back(); + void push_back(T x); + void erase(const_iterator iter); + + // For compatibility with existing code, allow conversion to vector. + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::vector() const; + + private: + // rep_ is either a T, or its bottom two bits are interpreted as a tag: + // kEmptyTag empty + // kBigTag remaining bits are a Big* + // + // kEmptyTag and kBigTag have bottom bit 1. If we attempt to store a single + // pointer whose bottom bit is 1, we immediately switch to the big + // representation to avoid ambiguity. + // Empty vectors are represented uniquely in the small representation. + static constexpr uintptr_t kEmptyTag = 0x1; + static constexpr uintptr_t kBigTag = 0x3; + static constexpr uintptr_t kTagMask = 0x3; + + struct Big { + size_t size; + size_t capacity; + T data[]; // Beginning of variable sized portion + }; + + inline static bool can_inline(T ptr) { + // If T has enough alignment, ptr bottom bit must be zero, so we can store + // it inline without ambiguity. Otherwise we do a dynamic check. + if constexpr (alignof(decltype(*ptr)) >= 2) { + DCHECK_EQ(reinterpret_cast(ptr) & 0x1, 0); + return true; + } + return ((reinterpret_cast(ptr) & 0x1) == 0); + } + + inline bool is_big() const { return (rep_ & kTagMask) == kBigTag; } + + inline Big* big() const { + DCHECK(is_big()); + return reinterpret_cast(rep_ & ~kTagMask); + } + + // big_size returns the number of bytes to allocate for a Big representation + // that can store up to the specified number of elements. + inline static size_t big_size(size_t n) { + // Verify that we won't overflow. + static constexpr size_t kMaxFit = + (std::numeric_limits::max() - sizeof(Big)) / sizeof(T); + DCHECK_LE(n, kMaxFit); + const size_t result = sizeof(Big) + n * sizeof(T); + DCHECK_GE(result, sizeof(Big)); + return result; + } + + // MakeBig switches to an empty Big representation with at least the + // specified capacity. Caller is responsible for freeing any old Big + // representation. + inline Big* MakeBig(size_t capacity) { + Big* big = static_cast(malloc(big_size(capacity))); + big->size = 0; + big->capacity = capacity; + rep_ = reinterpret_cast(big) | kBigTag; + return big; + } + + inline static void FreeBig(Big* big) { free(big); } + + uintptr_t rep_; +}; + +// Implementation details: + +template +inline PtrVec::PtrVec() : rep_(kEmptyTag) {} + +template +inline PtrVec::~PtrVec() { + if (is_big()) FreeBig(big()); +} + +template +inline PtrVec::PtrVec(const PtrVec& x) : rep_(kEmptyTag) { + *this = x; +} + +template +inline PtrVec& PtrVec::operator=(const PtrVec& x) { + if (this == &x) { + return *this; + } + + const size_t n = x.size(); + Big* b; + if (!is_big()) { + // Stick with small representation if we can. + if (n < 2) { + if (n == 0) { + rep_ = kEmptyTag; + return *this; + } + T single = x.front(); + if (can_inline(single)) { + rep_ = reinterpret_cast(single); + DCHECK(!empty()); + DCHECK(!is_big()); + return *this; + } + } + + // Switch to big representation. + b = MakeBig(x.size()); + } else { + if (n == 0) { + // Make empty() faster by always using a unique representation for empty + // vectors (tag is empty). + clear(); + return *this; + } + b = big(); + if (b->capacity < n) { + FreeBig(b); + b = MakeBig(n); + } + } + + memcpy(b->data, x.data(), n * sizeof(T)); + b->size = n; + return *this; +} + +template +inline PtrVec::PtrVec(PtrVec&& x) : rep_(x.rep_) { + x.rep_ = kEmptyTag; +} + +template +inline PtrVec& PtrVec::operator=(PtrVec&& x) { + if (this != &x) { + if (is_big()) { + FreeBig(big()); + } + rep_ = x.rep_; + x.rep_ = kEmptyTag; + } + return *this; +} + +template +inline size_t PtrVec::size() const { + return is_big() ? big()->size : (rep_ != kEmptyTag ? 1 : 0); +} + +template +inline bool PtrVec::empty() const { + return rep_ == kEmptyTag; +} + +template +inline T* PtrVec::data() { + return is_big() ? big()->data : reinterpret_cast(&rep_); +} + +template +inline T const* PtrVec::data() const { + return is_big() ? big()->data : reinterpret_cast(&rep_); +} + +template +inline T& PtrVec::operator[](size_t i) { + DCHECK_LT(i, size()); + return *(data() + i); +} + +template +inline T PtrVec::operator[](size_t i) const { + DCHECK_LT(i, size()); + return *(data() + i); +} + +template +inline T PtrVec::at(size_t i) const { + DCHECK_LT(i, size()); + return *(data() + i); +} + +template +inline T PtrVec::front() const { + return (*this)[0]; +} + +template +inline T PtrVec::back() const { + return (*this)[size() - 1]; +} + +template +inline typename PtrVec::const_iterator PtrVec::begin() const { + return data(); +} + +template +inline typename PtrVec::const_iterator PtrVec::end() const { + return data() + size(); +} + +template +inline void PtrVec::clear() { + if (is_big()) { + FreeBig(big()); + } + rep_ = kEmptyTag; +} + +template +inline void PtrVec::pop_back() { + DCHECK(!empty()); + if (is_big()) { + big()->size--; + if (big()->size == 0) { + // Revert to unique representation of empty vectors. + clear(); + } + } else { + rep_ = kEmptyTag; // From length 1 to length 0 + } +} + +template +inline void PtrVec::push_back(T x) { + if (!is_big()) { + if (rep_ == kEmptyTag) { + if (can_inline(x)) { + // Switch from empty to singleton representation. + rep_ = reinterpret_cast(x); + DCHECK(!empty()); + DCHECK(!is_big()); + } else { + // Avoid ambiguity by jumping from empty to big representation. + Big* b = MakeBig(1); + b->size = 1; + b->data[0] = x; + } + } else { + // Switch from singleton to Big representation. + T singleton = front(); + Big* b = MakeBig(2); + b->size = 2; + b->data[0] = singleton; + b->data[1] = x; + } + } else { + // See if x fits in current Big. + Big* b = big(); + const size_t n = b->size; + DCHECK_LE(n, b->capacity); + if (n == b->capacity) { + Big* old = b; + b = MakeBig(std::max(2, 2 * old->capacity)); + memcpy(b->data, old->data, n * sizeof(T)); + FreeBig(old); + } + b->data[n] = x; + b->size = n + 1; + } +} + +template +inline void PtrVec::erase(const_iterator iter) { + DCHECK_GE(iter, begin()); + DCHECK_LT(iter, end()); + if (!is_big()) { + // Must be going from single element to zero. + rep_ = kEmptyTag; + } else { + Big* b = big(); + const size_t index = iter - b->data; + memmove(b->data + index, b->data + index + 1, + (b->size - index - 1) * sizeof(T)); + b->size--; + if (b->size == 0) { + // Revert to unique representation for empty vectors. + clear(); + } + } +} + +template +inline PtrVec::operator std::vector() const { + if (empty()) return {}; + return std::vector(begin(), end()); +} + +template +bool operator==(const PtrVec& a, const PtrVec& b) { + auto a_data = a.data(); + auto b_data = b.data(); + return std::equal(a_data, a_data + a.size(), b_data, b_data + b.size()); +} + +template +bool operator!=(const PtrVec& a, const PtrVec& b) { + return !(a == b); +} + +} // namespace xla + +#endif // XLA_HLO_IR_PTRVEC_H_ diff --git a/xla/hlo/ir/ptrvec_test.cc b/xla/hlo/ir/ptrvec_test.cc new file mode 100644 index 0000000000000..0834d38bcfcff --- /dev/null +++ b/xla/hlo/ir/ptrvec_test.cc @@ -0,0 +1,281 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/hlo/ir/ptrvec.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla { +namespace { + +class PtrVecTest : public testing::Test { + public: + int* NewInt(int v) { + ints_.push_back(std::make_unique(v)); + return ints_.back().get(); + } + + void Fill(PtrVec& dst, absl::Span src) { + for (int v : src) { + dst.push_back(NewInt(v)); + } + } + + std::vector Pointees(const PtrVec& src) { + std::vector result; + result.reserve(src.size()); + for (int* ptr : src) { + result.push_back(*ptr); + } + return result; + } + + private: + // Underlying storage for pointers stored in PtrVec<>. + std::vector> ints_; +}; + +// Some useful vectors to test with. +std::vector> TestCases() { + return std::vector>{ + {}, + {100}, + {200, 300}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }; +} + +TEST_F(PtrVecTest, Accessors) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + ASSERT_EQ(v.empty(), c.empty()); + ASSERT_EQ(v.size(), c.size()); + if (!c.empty()) { + ASSERT_EQ(*v.front(), c.front()); + ASSERT_EQ(*v.back(), c.back()); + } + } +} + +TEST_F(PtrVecTest, Iteration) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + int i = 0; + for (auto ptr : v) { + ASSERT_EQ(*ptr, c[i]); + i++; + } + } +} + +TEST_F(PtrVecTest, Indexing) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + for (int i = 0; i < c.size(); i++) { + ASSERT_EQ(*v[i], c[i]); + ASSERT_EQ(*v.at(i), c[i]); + } + } +} + +TEST_F(PtrVecTest, Data) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + int** data = v.data(); + for (int i = 0; i < c.size(); i++) { + ASSERT_EQ(*data[i], c[i]); + } + } +} + +TEST_F(PtrVecTest, ConversionToVector) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + std::vector vec = v; + ASSERT_EQ(vec.size(), c.size()); + for (int i = 0; i < c.size(); i++) { + ASSERT_EQ(*vec[i], c[i]); + } + } +} + +TEST_F(PtrVecTest, Clear) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + v.clear(); + EXPECT_EQ(Pointees(v), std::vector{}); + } +} + +TEST_F(PtrVecTest, PopBack) { + for (const auto& c : TestCases()) { + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + auto model = c; + while (!model.empty()) { + model.pop_back(); + v.pop_back(); + EXPECT_EQ(Pointees(v), model); + } + } +} + +TEST_F(PtrVecTest, Erase) { + for (const auto& c : TestCases()) { + if (c.empty()) { + continue; + } + SCOPED_TRACE(c.size()); + PtrVec v; + Fill(v, c); + auto model = c; + int offset = c.size() / 2; + model.erase(model.begin() + offset); + v.erase(v.begin() + offset); + EXPECT_EQ(Pointees(v), model); + } +} + +TEST_F(PtrVecTest, Assign) { + const auto cases = TestCases(); + for (const auto& x : cases) { + for (const auto& y : cases) { + SCOPED_TRACE(absl::StrFormat("from %d to %d", x.size(), y.size())); + + // Copy construct + { + PtrVec b; + Fill(b, y); + PtrVec a = b; + ASSERT_EQ(Pointees(a), y); + } + + // Move construct + { + PtrVec b; + Fill(b, y); + PtrVec a = std::move(b); + ASSERT_EQ(Pointees(a), y); + // NOLINTNEXTLINE(bugprone-use-after-move) + ASSERT_EQ(Pointees(b), std::vector{}); + } + + // Copy + { + PtrVec a; + Fill(a, x); + ASSERT_EQ(Pointees(a), x); + PtrVec b; + Fill(b, y); + a = b; + ASSERT_EQ(Pointees(a), y); + } + + // Move + { + PtrVec a; + Fill(a, x); + PtrVec b; + Fill(b, y); + a = std::move(b); + ASSERT_EQ(Pointees(a), y); + // NOLINTNEXTLINE(bugprone-use-after-move) + ASSERT_EQ(Pointees(b), std::vector{}); + } + } + } +} + +TEST_F(PtrVecTest, ReducedAlignment) { + const char* str = "hello world"; + for (int i = 0; i < 11; i++) { + PtrVec vec; + vec.push_back(&str[i]); + EXPECT_EQ(vec.size(), 1); + EXPECT_EQ(vec[0], &str[i]); + + PtrVec copy; + copy = vec; + EXPECT_EQ(copy.size(), 1); + EXPECT_EQ(copy[0], &str[i]); + } +} + +struct Elem { + int64_t number; +}; + +void BM_PtrVecIter(::testing::benchmark::State& state) { + const int n = state.range(0); + std::vector storage(n); + PtrVec vec; + for (int i = 0; i < n; i++) { + storage[i].number = i; + vec.push_back(&storage[i]); + } + + uintptr_t sum = 0; + for (auto s : state) { + for (int i = 0; i < vec.size(); i++) { + sum += reinterpret_cast(vec[i]); + } + } + VLOG(1) << sum; +} +BENCHMARK(BM_PtrVecIter)->Arg(0)->Arg(1)->Arg(2)->Arg(4)->Arg(8)->Arg(1024); + +void BM_StdVecIter(::testing::benchmark::State& state) { + const int n = state.range(0); + std::vector storage(n); + std::vector vec; + for (int i = 0; i < n; i++) { + storage[i].number = i; + vec.push_back(&storage[i]); + } + + uintptr_t sum = 0; + for (auto s : state) { + for (int i = 0; i < vec.size(); i++) { + sum += reinterpret_cast(vec[i]); + } + } + VLOG(1) << sum; +} +BENCHMARK(BM_StdVecIter)->Arg(0)->Arg(1)->Arg(2)->Arg(4)->Arg(8)->Arg(1024); + +} // namespace +} // namespace xla diff --git a/xla/hlo/ir/tile_assignment.cc b/xla/hlo/ir/tile_assignment.cc index a445e37d04e37..8bba81c8528ea 100644 --- a/xla/hlo/ir/tile_assignment.cc +++ b/xla/hlo/ir/tile_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/ir/tile_assignment.h b/xla/hlo/ir/tile_assignment.h index 871b6e6aeabb3..a68d932f7627b 100644 --- a/xla/hlo/ir/tile_assignment.h +++ b/xla/hlo/ir/tile_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/hlo/transforms/hlo_constant_splitter.cc b/xla/hlo/transforms/hlo_constant_splitter.cc index 9df91d0e7eb18..f5849e5173d46 100644 --- a/xla/hlo/transforms/hlo_constant_splitter.cc +++ b/xla/hlo/transforms/hlo_constant_splitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -41,9 +41,9 @@ bool IsSupportedConstantExpression(const HloInstruction* instruction) { // Perform duplication of a certain constant expression and replace the // original expression for a specific user. -StatusOr DuplicateConstantExpressionPerUser(HloComputation* computation, - HloInstruction* to_clone, - HloInstruction* user) { +absl::StatusOr DuplicateConstantExpressionPerUser( + HloComputation* computation, HloInstruction* to_clone, + HloInstruction* user) { absl::InlinedVector, 8> worklist( 1, std::make_pair(to_clone, 0)); absl::InlinedVector to_clone_vec; @@ -94,7 +94,7 @@ StatusOr DuplicateConstantExpressionPerUser(HloComputation* computation, } // namespace -StatusOr HloConstantSplitter::Run( +absl::StatusOr HloConstantSplitter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/hlo/transforms/hlo_constant_splitter.h b/xla/hlo/transforms/hlo_constant_splitter.h index 50ac60671235a..3ad29dea4703a 100644 --- a/xla/hlo/transforms/hlo_constant_splitter.h +++ b/xla/hlo/transforms/hlo_constant_splitter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -36,7 +36,7 @@ class HloConstantSplitter : public HloModulePass { : split_expressions_(split_expressions) {} absl::string_view name() const override { return "hlo-constant-splitter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/hlo/transforms/hlo_constant_splitter_test.cc b/xla/hlo/transforms/hlo_constant_splitter_test.cc index 7ef0b68e58b32..c12ae702906b8 100644 --- a/xla/hlo/transforms/hlo_constant_splitter_test.cc +++ b/xla/hlo/transforms/hlo_constant_splitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/hlo/utils/BUILD b/xla/hlo/utils/BUILD index d2ea0d16293f8..5bc30bf1d01e0 100644 --- a/xla/hlo/utils/BUILD +++ b/xla/hlo/utils/BUILD @@ -1,12 +1,16 @@ # Description: # Implementation of XLA’s HLO utilities used for higher-level transformations. +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//xla:xla.bzl", "xla_cc_test", ) -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") - +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load( + "@tsl//tsl:tsl.bzl", + "clean_dep", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], @@ -22,26 +26,22 @@ package_group( cc_library( name = "hlo_live_range", - srcs = [ - "hlo_live_range.cc", - ], - hdrs = [ - "hlo_live_range.h", - ], + srcs = ["hlo_live_range.cc"], + hdrs = ["hlo_live_range.h"], deps = [ + "//xla:shape_util", "//xla:statusor", - "//xla:types", "//xla/hlo/ir:hlo", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_ordering", "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:status", + "@tsl//tsl/platform:logging", ], ) @@ -58,7 +58,9 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -69,6 +71,7 @@ cc_library( hdrs = ["hlo_matchers.h"], deps = [ "//xla:test", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", "@com_google_absl//absl/strings", @@ -82,8 +85,10 @@ xla_cc_test( ":hlo_matchers", "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -100,15 +105,19 @@ cc_library( "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", + "//xla:status", + "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", "//xla/service:call_graph", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -123,10 +132,16 @@ xla_cc_test( ], deps = [ ":hlo_sharding_util", + "//xla:array", + "//xla:shape_util", "//xla:test", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/types:span", ], ) @@ -142,3 +157,14 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", ], ) +cc_library( + name = "common_ortools_deps", + srcs = ["common_ortools_deps.h"], + visibility = ["//visibility:public"], + deps = [ + clean_dep("@com_google_ortools//ortools/linear_solver"), + clean_dep("@com_google_ortools//ortools/linear_solver:linear_solver_cc_proto"), + clean_dep("@com_google_ortools//ortools/sat:cp_model"), + clean_dep("@com_google_absl//absl/strings"), + ] +) \ No newline at end of file diff --git a/xla/hlo/utils/common_ortools_deps.h b/xla/hlo/utils/common_ortools_deps.h new file mode 100644 index 0000000000000..46a0de861d000 --- /dev/null +++ b/xla/hlo/utils/common_ortools_deps.h @@ -0,0 +1,11 @@ +#ifndef ORTOOLS_LINEAR_SOLVER_H +#define ORTOOLS_LINEAR_SOLVER_H +#include "ortools/linear_solver/linear_solver.h" +#include "ortools/linear_solver/linear_solver.pb.h" +#include "ortools/linear_solver/model_exporter.h" +#include "ortools/sat/cp_model.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/lp_utils.h" +#include "absl/strings/string_view.h" +#endif \ No newline at end of file diff --git a/xla/hlo/utils/hlo_live_range.cc b/xla/hlo/utils/hlo_live_range.cc index e62f55f9e3ee2..670bf8f570889 100644 --- a/xla/hlo/utils/hlo_live_range.cc +++ b/xla/hlo/utils/hlo_live_range.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,21 +16,32 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include +#include +#include +#include +#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "tsl/platform/logging.h" namespace xla { /*static*/ -StatusOr> HloLiveRange::Run( +absl::StatusOr> HloLiveRange::Run( const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const HloComputation* computation, bool module_scoped_analysis) { std::unique_ptr hlo_live_range( @@ -154,6 +165,12 @@ HloLiveRange::LogicalTime HloLiveRange::GetLastUsageTime( LogicalTime end_time = -1; for (const HloUse& use : value.GetUses()) { const HloInstruction* used = use.instruction; + + // In module scoped mode when all call operations are flattened ignore uses + // by call operation itself, and rely on the last usage time inferred from + // the operations in the called computation. + if (module_scoped_analysis_ && used->opcode() == HloOpcode::kCall) continue; + // As an optimization, we deem a while's init value's live range ends as // soon as the loop body starts. This optimization is only applicable in // module scoped mode. @@ -201,12 +218,11 @@ void HloLiveRange::CalculateBufferStartEndMap() { if (async_context_it != computations_in_async_context_.end()) { const HloComputation* async_context = async_context_it->second; CHECK(async_context->IsAsyncComputation()); - auto async_done_it = - absl::c_find_if(async_context->AsyncInstructions(), - HloPredicateIsOp); - CHECK(async_done_it != async_context->AsyncInstructions().end()); + auto async_done = async_context->AsyncStart()->async_chain_done(); + auto async_done_it = instruction_schedule_.find(async_done); + CHECK(async_done_it != instruction_schedule_.end()); definition_end_time = - std::max(definition_end_time, instruction_schedule_[*async_done_it]); + std::max(definition_end_time, async_done_it->second); VLOG(2) << "Setting the definition end time for op in async context: " << definition_end_time; } diff --git a/xla/hlo/utils/hlo_live_range.h b/xla/hlo/utils/hlo_live_range.h index 62a2f0d19abd5..eb1530503ab12 100644 --- a/xla/hlo/utils/hlo_live_range.h +++ b/xla/hlo/utils/hlo_live_range.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,17 @@ the License. #ifndef XLA_HLO_UTILS_HLO_LIVE_RANGE_H_ #define XLA_HLO_UTILS_HLO_LIVE_RANGE_H_ +#include #include #include -#include #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" #include "xla/statusor.h" -#include "xla/types.h" -#include "tsl/platform/status.h" namespace xla { @@ -43,7 +37,7 @@ class HloLiveRange { public: // Constructs a hlo live range object for the given module and computation // assuming the given HLO instruction ordering. - static StatusOr> Run( + static absl::StatusOr> Run( const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const HloComputation* computation, bool module_scoped_analysis = true); diff --git a/xla/hlo/utils/hlo_live_range_test.cc b/xla/hlo/utils/hlo_live_range_test.cc index ec49337ac08ee..b4155b103cfc1 100644 --- a/xla/hlo/utils/hlo_live_range_test.cc +++ b/xla/hlo/utils/hlo_live_range_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,21 +14,22 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/utils/hlo_live_range.h" +#include #include #include #include #include +#include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" -#include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -417,9 +418,9 @@ HloModule AsyncCall, is_scheduled=true, entry_computation_layout={(f32[4096]{0}, %called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] { %param_0 = f32[4096]{0} parameter(0) %param_1 = f32[4096]{0} parameter(1) - %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0) - %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1) - ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1) + %negate_2 = f32[4096]{0} negate(f32[4096]{0} %param_0) + %negate_3 = f32[4096]{0} negate(f32[4096]{0} %param_1) + ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3) } %async_wrapped (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { @@ -431,11 +432,11 @@ HloModule AsyncCall, is_scheduled=true, entry_computation_layout={(f32[4096]{0}, ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { %a = f32[4096]{0} parameter(0) %b = f32[4096]{0} parameter(1) - %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %a, f32[4096]{0} %b), async_group_id=0, calls=%async_wrapped - %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a) - %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b) - %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3) - %async-done = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), async_group_id=0, calls=%async_wrapped + %negate_0 = f32[4096]{0} negate(f32[4096]{0} %a) + %negate_1 = f32[4096]{0} negate(f32[4096]{0} %b) + %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1), calls=%async_wrapped + %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1) + %async-done = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start) ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) } )"; @@ -445,6 +446,72 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { Analyze(schedule); CheckSchedule(); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr aa, + HloAliasAnalysis::Run(module_.get())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(module_->schedule(), *aa, + module_->entry_computation())); + + absl::flat_hash_map> inst_ranges; + for (auto& [value, time_bound] : hlo_live_range->buffer_live_ranges()) { + inst_ranges[value->instruction()->name()] = {time_bound.start, + time_bound.end}; + } + + // Parameters live range spans whole computation. + EXPECT_EQ(inst_ranges["a"], std::make_pair(0, 16)); + EXPECT_EQ(inst_ranges["b"], std::make_pair(0, 16)); + + // Check `add` operations live range to make sure that `negate` values live + // range spans past the last non-async use. + EXPECT_EQ(inst_ranges["add_0"], std::make_pair(13, 15)); + EXPECT_EQ(inst_ranges["add_1"], std::make_pair(15, 16)); + + // `negate_0` and `negate_1` live range ends after `async-done`. + EXPECT_EQ(inst_ranges["negate_0"], std::make_pair(2, 14)); + EXPECT_EQ(inst_ranges["negate_1"], std::make_pair(3, 14)); } + +TEST_F(HloLiveRangeTest, Call) { + std::string hlo_string = R"( + HloModule Call, is_scheduled=true + + %called_computation (param_0: f32[4096]) -> f32[4096] { + %param_0 = f32[4096]{0} parameter(0) + ROOT %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0) + } + + ENTRY %main (a: f32[4096]) -> f32[4096] { + %a = f32[4096]{0} parameter(0) + %b = f32[4096]{0} negate(%a) + %c = f32[4096]{0} call(%b), to_apply=%called_computation + %d = f32[4096]{0} negate(%c) + ROOT %e = f32[4096]{0} add(%c, %d) + })"; + + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr aa, + HloAliasAnalysis::Run(module_.get())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(module_->schedule(), *aa, + module_->entry_computation())); + + absl::flat_hash_map> inst_ranges; + for (auto& [value, time_bound] : hlo_live_range->buffer_live_ranges()) { + inst_ranges[value->instruction()->name()] = {time_bound.start, + time_bound.end}; + } + + EXPECT_EQ(inst_ranges["a"], std::make_pair(0, 7)); + EXPECT_EQ(inst_ranges["b"], std::make_pair(1, 3)); + EXPECT_EQ(inst_ranges["negate_0"], std::make_pair(3, 6)); + EXPECT_EQ(inst_ranges["d"], std::make_pair(5, 6)); + EXPECT_EQ(inst_ranges["e"], std::make_pair(6, 7)); +} + } // namespace } // namespace xla diff --git a/xla/hlo/utils/hlo_matchers.cc b/xla/hlo/utils/hlo_matchers.cc index 1704fcb7823a9..584381674d07a 100644 --- a/xla/hlo/utils/hlo_matchers.cc +++ b/xla/hlo/utils/hlo_matchers.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -390,6 +390,47 @@ void HloSourceTargetPairsMatcher::DescribeTo(std::ostream* os) const { }; *os << '{' << absl::StrJoin(source_target_pairs_, ",", pair_formatter) << "}"; } +bool HloMetadataMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + *listener << " (metadata: "; + if (instruction->metadata().op_type() != metadata_.op_type()) { + *listener << " has wrong metadata (got " + << instruction->metadata().op_type() << ", want " + << metadata_.op_type() << ")"; + return false; + } + *listener << metadata_.op_type() << " "; + if (instruction->metadata().op_name() != metadata_.op_name()) { + *listener << " has wrong metadata (got " + << instruction->metadata().op_name() << ", want " + << metadata_.op_name() << ")"; + return false; + } + *listener << metadata_.op_name() << " "; + if (instruction->metadata().source_file() != metadata_.source_file()) { + *listener << " has wrong metadata (got " + << instruction->metadata().source_file() << ", want " + << metadata_.source_file() << ")"; + return false; + } + *listener << metadata_.source_file() << " "; + if (instruction->metadata().source_line() != metadata_.source_line()) { + *listener << " has wrong metadata (got " + << instruction->metadata().source_line() << ", want " + << metadata_.source_line() << ")"; + return false; + } + *listener << metadata_.source_line(); + *listener << ")"; + return true; +} + +void HloMetadataMatcher::DescribeTo(std::ostream* os) const { + *os << " (metadata: " << metadata_.op_type() << " " << metadata_.op_name() + << " " << metadata_.source_file() << " " << metadata_.source_line() + << ")"; +} } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/xla/hlo/utils/hlo_matchers.h b/xla/hlo/utils/hlo_matchers.h index a3bf9f214dce0..17f3294156bca 100644 --- a/xla/hlo/utils/hlo_matchers.h +++ b/xla/hlo/utils/hlo_matchers.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_parser.h" #include "xla/test.h" +#include "xla/xla_data.pb.h" namespace xla { namespace testing { @@ -237,6 +238,20 @@ class HloSourceTargetPairsMatcher std::vector> source_target_pairs_; }; +class HloMetadataMatcher + : public ::testing::MatcherInterface { + public: + explicit HloMetadataMatcher(OpMetadata metadata) + : metadata_(std::move(metadata)) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + OpMetadata metadata_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -270,6 +285,7 @@ HLO_MATCHER(Broadcast); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); +HLO_MATCHER(CollectiveBroadcast); HLO_MATCHER(CollectivePermute); HLO_MATCHER(CollectivePermuteStart); HLO_MATCHER(CollectivePermuteDone); @@ -285,6 +301,7 @@ HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); +HLO_MATCHER(Erf); HLO_MATCHER(Exp); HLO_MATCHER(Fft); HLO_MATCHER(Floor); @@ -552,6 +569,12 @@ inline ::testing::Matcher SourceTargetPairs( std::move(source_target_pairs))); } +inline ::testing::Matcher Metadata( + OpMetadata metadata) { + return ::testing::MakeMatcher( + new ::xla::testing::HloMetadataMatcher(std::move(metadata))); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/xla/hlo/utils/hlo_matchers_test.cc b/xla/hlo/utils/hlo_matchers_test.cc index 799e9b26cb19f..9811e8830f944 100644 --- a/xla/hlo/utils/hlo_matchers_test.cc +++ b/xla/hlo/utils/hlo_matchers_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,9 +20,11 @@ limitations under the License. #include #include +#include #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -277,10 +279,12 @@ TEST_F(HloMatchersTest, ComparisonMatcher) { TEST_F(HloMatchersTest, AsyncCopyMatcher) { Shape shape_memspace1 = ShapeUtil::MakeShapeWithDenseLayout( F32, {16}, /*minor_to_major=*/{0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, /*memory_space=*/1); Shape shape_memspace2 = ShapeUtil::MakeShapeWithDenseLayout( F32, {16}, /*minor_to_major=*/{0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, /*memory_space=*/2); @@ -370,5 +374,47 @@ TEST_F(HloMatchersTest, SourceTargetPairsMatcher) { HasSubstr("source_target_pairs (expected: {{0,1},{2,3}}")); EXPECT_THAT(cp.get(), op::SourceTargetPairs({{0, 1}, {2, 3}, {1, 2}})); } + +TEST_F(HloMatchersTest, MetadataMatcher) { + Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + std::unique_ptr p0 = + HloInstruction::CreateParameter(0, shape, "param"); + OpMetadata metadata; + metadata.set_op_type("op_type1"); + metadata.set_op_name("op_name1"); + p0->set_metadata(metadata); + + OpMetadata actual_opname; + actual_opname.set_op_type("op_type1"); + actual_opname.set_op_name("op_name2"); + + OpMetadata actual_source_file; + actual_source_file.set_op_type("op_type1"); + actual_source_file.set_op_name("op_name1"); + actual_source_file.set_source_file("source_file"); + + OpMetadata actual_optype; + actual_optype.set_op_type("op_type2"); + actual_optype.set_op_name("op_name1"); + + OpMetadata actual_source_line; + actual_source_line.set_op_type("op_type1"); + actual_source_line.set_op_name("op_name1"); + actual_source_line.set_source_line(1); + + EXPECT_THAT(Explain(p0.get(), op::Metadata(actual_opname)), + HasSubstr("has wrong metadata (got op_name1, want op_name2)")); + EXPECT_THAT(Explain(p0.get(), op::Metadata(actual_source_file)), + HasSubstr("has wrong metadata (got " + ", want source_file)")); + EXPECT_THAT(Explain(p0.get(), op::Metadata(actual_optype)), + HasSubstr("has wrong metadata (got" + " op_type1, want op_type2)")); + EXPECT_THAT(Explain(p0.get(), op::Metadata(actual_source_line)), + HasSubstr("has wrong metadata (got 0" + ", want 1)")); + EXPECT_THAT(DescribeHloMatcher(op::Metadata(p0->metadata())), + R"( (metadata: op_type1 op_name1 0))"); +} } // namespace } // namespace xla diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc index c40445cec08ff..487fc72e99b5d 100644 --- a/xla/hlo/utils/hlo_query.cc +++ b/xla/hlo/utils/hlo_query.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,23 +31,32 @@ namespace hlo_query { bool IsCollectiveCommunicationOp(HloOpcode op) { return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather || op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute || + op == HloOpcode::kCollectiveBroadcast || op == HloOpcode::kReduceScatter || op == HloOpcode::kAllReduceStart || op == HloOpcode::kAllGatherStart || op == HloOpcode::kCollectivePermuteStart; } -bool IsAsyncCollectiveStartOp(HloOpcode op, bool include_send_recv) { +bool IsAsyncCollectiveStartOp(const HloInstruction* instruction, + bool include_send_recv) { + HloOpcode op = instruction->opcode(); + if (op == HloOpcode::kAsyncStart) { + return IsCollectiveCommunicationOp(instruction->async_wrapped_opcode()); + } return op == HloOpcode::kAllReduceStart || op == HloOpcode::kAllGatherStart || op == HloOpcode::kCollectivePermuteStart || - op == HloOpcode::kAsyncStart || (include_send_recv && (op == HloOpcode::kSend || op == HloOpcode::kRecv)); } -bool IsAsyncCollectiveDoneOp(HloOpcode op, bool include_send_recv) { +bool IsAsyncCollectiveDoneOp(const HloInstruction* instruction, + bool include_send_recv) { + HloOpcode op = instruction->opcode(); + if (op == HloOpcode::kAsyncDone) { + return IsCollectiveCommunicationOp(instruction->async_wrapped_opcode()); + } return op == HloOpcode::kAllReduceDone || op == HloOpcode::kAllGatherDone || op == HloOpcode::kCollectivePermuteDone || - op == HloOpcode::kAsyncDone || (include_send_recv && (op == HloOpcode::kSendDone || op == HloOpcode::kRecvDone)); } @@ -159,6 +168,11 @@ bool IsBroadcastOfScalarConstant(const HloInstruction& instr) { IsScalarConstant(instr.operand(0)); } +bool IsBroadcastOfParameter(const HloInstruction& instr) { + return instr.opcode() == HloOpcode::kBroadcast && + instr.operand(0)->opcode() == HloOpcode::kParameter; +} + HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation, const HloOpcode opcode) { auto instructions = computation.instructions(); diff --git a/xla/hlo/utils/hlo_query.h b/xla/hlo/utils/hlo_query.h index e166cbe9c9c09..e12e330665137 100644 --- a/xla/hlo/utils/hlo_query.h +++ b/xla/hlo/utils/hlo_query.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,12 +31,14 @@ namespace hlo_query { // that is represented as HloCollectiveInstruction. bool IsCollectiveCommunicationOp(HloOpcode op); -// Returns whether the given opcode represents the start operation for a +// Returns whether the given instruction represents the start operation for a // collective communication, may include send & recv operations. -bool IsAsyncCollectiveStartOp(HloOpcode op, bool include_send_recv = false); -// Returns whether the given opcode represents the done operation for a +bool IsAsyncCollectiveStartOp(const HloInstruction* instruction, + bool include_send_recv = false); +// Returns whether the given instruction represents the done operation for a // collective communication, may include send & recv operations. -bool IsAsyncCollectiveDoneOp(HloOpcode op, bool include_send_recv = false); +bool IsAsyncCollectiveDoneOp(const HloInstruction* instruction, + bool include_send_recv = false); // Returns whether the instruction provided is a constant rank-0 float32, and // if so, places the constant value into out. @@ -69,6 +71,9 @@ bool IsBroadcastedConstantOrScalar(const HloInstruction& instr); // scalar constant. bool IsBroadcastOfScalarConstant(const HloInstruction& instr); +// Returns whether the `instr` is a broadcast and its input is a parameter. +bool IsBroadcastOfParameter(const HloInstruction& instr); + // Returns first HLO of the computation with the opcode, otherwise nullptr. HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation, HloOpcode opcode); diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index bee53152f046d..1d8734ce85e8e 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,31 +22,37 @@ limitations under the License. #include #include #include -#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/literal_util.h" +#include "xla/map_util.h" #include "xla/protobuf_util.h" #include "xla/service/call_graph.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -74,31 +80,81 @@ bool IsSubTilingOrEqualSharding(const Shape& potential_sharded_shape, const int32_t tiled_data_rank = potential_subsharding.TiledDataRank(); // Different tiled ranks can't be compared (something is wrong, are the // shardings for different shapes?) - if (tiled_data_rank != sharding.TiledDataRank()) { + if (tiled_data_rank != sharding.TiledDataRank() || + tiled_data_rank != potential_sharded_shape.dimensions_size()) { return false; } - // Helper to construct the base tile bounds based on a shape and a sharding. - auto get_base_tile_for_sharding = [](const Shape& shape, - const HloSharding& sharding) { - absl::InlinedVector base_tile; - base_tile.resize(shape.dimensions_size()); - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { - base_tile[i] = - CeilOfRatio(shape.dimensions(i), sharding.tile_assignment().dim(i)); - } - return base_tile; - }; - auto potential_base_tile = get_base_tile_for_sharding(potential_sharded_shape, - potential_subsharding); - auto base_tile = - get_base_tile_for_sharding(potential_sharded_shape, sharding); - // If the potential_base_tile is bigger than the base_tile on any dimension - // then it can't be contained regardless. - for (int64_t i = 0; i < potential_base_tile.size(); ++i) { - if (potential_base_tile[i] > base_tile[i]) { + + DimensionVector potential_base_tile(tiled_data_rank); + DimensionVector base_tile(tiled_data_rank); + bool shortcut = true; + int64_t diff_dim_counter = 0; + DimensionVector reshape_dims( + potential_subsharding.tile_assignment().dimensions().begin(), + potential_subsharding.tile_assignment().dimensions().end()); + for (int64_t i = 0; i < tiled_data_rank; ++i) { + const auto shape_i = potential_sharded_shape.dimensions(i); + const auto p_tile_dim_i = potential_subsharding.tile_assignment().dim(i); + const auto s_tile_dim_i = sharding.tile_assignment().dim(i); + if (p_tile_dim_i < s_tile_dim_i) { return false; } + potential_base_tile[i] = CeilOfRatio(shape_i, p_tile_dim_i); + base_tile[i] = CeilOfRatio(shape_i, s_tile_dim_i); + + if (s_tile_dim_i != 1 && + (p_tile_dim_i % s_tile_dim_i != 0 || + base_tile[i] % potential_base_tile[i] != 0 || + shape_i <= (p_tile_dim_i - 1) * potential_base_tile[i] || + shape_i <= (s_tile_dim_i - 1) * base_tile[i])) { + // The comment below explains this condition. + shortcut = false; + } + if (shortcut && p_tile_dim_i != s_tile_dim_i) { + reshape_dims[i + diff_dim_counter] = s_tile_dim_i; + reshape_dims.insert(reshape_dims.begin() + i + diff_dim_counter + 1, + p_tile_dim_i / s_tile_dim_i); + diff_dim_counter++; + } + } + + if (shortcut) { + // In the shortcut, we ensure that (1) p_tile_dim_i is divisible by + // s_tile_dim_i, (2) base_tile[i] is divisible by potential_base_tile[i], + // and (3) all devices have raw data of the tensor (a counterexample is that + // a device may only have paddings). We can use this shortcut to quickly + // make the decision. + // + // s_tile_dim_i == 1 means that it is replicated along dimension i with + // `sharding`, which is compatible with the shortcut. + // + // We cannot extend the shortcut if the condition fails. An example is + // listed below. Given potential_sharded_shape = [1, 1, 1, ..., 1], the raw + // data of the tensor is only on the first tile. Thus, we only need to focus + // on the first tile in the two input shardings. + if (!sharding.HasPartialReplication()) { + return potential_subsharding == sharding; + } + + std::vector perm(reshape_dims.size()); + absl::c_iota(perm, 0); + for (int64_t i = 0; i < tiled_data_rank; ++i) { + if (potential_subsharding.tile_assignment().dim(i) != + sharding.tile_assignment().dim(i)) { + auto element = perm[i + 1]; + perm.erase(perm.begin() + i + 1); + perm.push_back(element); + } + } + + auto reshaped_ta = potential_subsharding.tile_assignment() + .Reshape(reshape_dims) + .Transpose(perm) + .Reshape(sharding.tile_assignment().dimensions()); + return HloSharding::PartialTile(reshaped_ta).tile_assignment() == + sharding.tile_assignment(); } + // Use one contiguous storage to reduce allocation overhead. auto storage = std::make_unique( sharding.tile_assignment().num_elements() * tiled_data_rank); @@ -128,8 +184,8 @@ bool IsSubTilingOrEqualSharding(const Shape& potential_sharded_shape, }); // Compare the start offsets and the end offset of the tiles for each device. auto& potential_ta = potential_subsharding.tile_assignment().array(); - absl::Status ok_if_no_vialation = potential_ta.EachStatus( - [&](absl::Span indices, int64_t device) { + absl::Status ok_if_no_violation = potential_ta.EachStatus( + [&](absl::Span indices, int64_t device) -> absl::Status { auto sharding_offset = get_sharding_offsets(device); for (int j = 0; j < tiled_data_rank; ++j) { const int32_t subsharding_offset_j = @@ -137,23 +193,43 @@ bool IsSubTilingOrEqualSharding(const Shape& potential_sharded_shape, // The subsharding contains data outside of the tile we are comparing // against. if (subsharding_offset_j < sharding_offset[j]) { - return InternalError(""); + return Internal(""); } // Skip last tile. It can never go beyond the limit as the shape is // the same for both shardings and sometimes there's padding making // one of the two limits bigger than the other, but it shouldn't be // counted. - const bool is_last_tile = - subsharding_offset_j + potential_base_tile[j] >= - potential_sharded_shape.dimensions(j); - if (!is_last_tile && subsharding_offset_j + potential_base_tile[j] > - sharding_offset[j] + base_tile[j]) { - return InternalError(""); + if (subsharding_offset_j + potential_base_tile[j] <= + potential_sharded_shape.dimensions(j) && + subsharding_offset_j + potential_base_tile[j] > + sharding_offset[j] + base_tile[j]) { + return Internal(""); } } return absl::OkStatus(); }); - return ok_if_no_vialation.ok(); + return ok_if_no_violation.ok(); +} + +static bool IsLeafShardingMoreSpecific(const HloSharding& lhs, + const HloSharding& rhs) { + DCHECK(!lhs.IsTuple()); + DCHECK(!rhs.IsTuple()); + // Manual sharding is more specific than tile maximal sharding. + if (lhs.IsManualLeaf() && rhs.IsTileMaximalLeaf()) { + return true; + } + if (lhs.IsManualLeaf() || rhs.IsManualLeaf()) { + return false; + } + if (!rhs.IsTileMaximalLeaf()) { + return lhs.NumTilesLeaf() > rhs.NumTilesLeaf(); + } + // If we are not replicated then only tiled (not tile maximal) shardings + // can improve us. + // If we are replicated then any non-replicated sharding can improve us. + return !(rhs.IsReplicatedLeaf() ? lhs.IsReplicatedLeaf() + : lhs.IsTileMaximalLeaf()); } bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { @@ -176,23 +252,7 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { } return is_better; } - // Manual sharding is more specific than tile maximal sharding. - if (lhs.IsManual() && rhs.IsTileMaximal()) { - return true; - } - if (lhs.IsManual() || rhs.IsManual()) { - return false; - } - if (!rhs.IsTileMaximal()) { - return lhs.NumTiles() > rhs.NumTiles(); - } else if (!rhs.IsReplicated()) { - // If we are not replicated then only tiled (not tile maximal) shardings - // can improve us. - return !lhs.IsTileMaximal(); - } else { - // If we are replicated then any non-replicated sharding can improve us. - return !lhs.IsReplicated(); - } + return IsLeafShardingMoreSpecific(lhs, rhs); } bool MergeSharding(const HloSharding& to_merge, HloSharding* dst, @@ -211,7 +271,7 @@ bool MergeSharding(const HloSharding& to_merge, HloSharding* dst, !dst->HasPartialReplication() || to_merge.tile_assignment().num_elements() != dst->tile_assignment().num_elements()) { - return IsShardingMoreSpecific(*dst, to_merge); + goto check_if_more_specific; } if (MergeShardingIfCompatible( @@ -220,7 +280,8 @@ bool MergeSharding(const HloSharding& to_merge, HloSharding* dst, dst)) { return true; } - return IsShardingMoreSpecific(*dst, to_merge); +check_if_more_specific: + return IsLeafShardingMoreSpecific(*dst, to_merge); } bool MergeShardingIfCompatible(const HloSharding& to_merge, @@ -237,124 +298,251 @@ bool MergeShardingIfCompatible(const HloSharding& to_merge, if (!dst->HasPartialReplication()) { return false; } + if (dst->TiledDataRank() != to_merge.TiledDataRank()) { + return false; + } + + const int64_t to_merge_man_dim = to_merge.SubgroupManualDim(); + const int64_t dst_man_dim = dst->SubgroupManualDim(); + if ((to_merge_man_dim >= 0) != (dst_man_dim >= 0)) { + return false; + } + // Combine the tile dimension sizes from dst and to_merge. - int64_t num_devices = to_merge.tile_assignment().num_elements(); - std::vector merged_tile_dims; + DimensionVector perm_merge(dst->tile_assignment().num_dimensions(), -1); + DimensionVector perm_dst(dst->tile_assignment().num_dimensions(), -1); + int64_t perm_merge_counter = 0; + int64_t perm_dst_counter = 0; + DimensionVector merge_old_tile_dim, dst_old_tile_dim; + DimensionVector merge_new_tile_dim, dst_new_tile_dim; + DimensionVector merge_new_tile_index, dst_new_tile_index; + DimensionVector merged_tile_dims; merged_tile_dims.reserve(dst->tile_assignment().num_dimensions()); + int64_t num_merge_groups = 1; + int64_t num_dst_groups = 1; for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) { - int64_t dst_dim = dst->tile_assignment().dim(i); int64_t merge_dim = to_merge.tile_assignment().dim(i); - if (dst_dim == 1) { + int64_t dst_dim = dst->tile_assignment().dim(i); + num_merge_groups *= merge_dim; + num_dst_groups *= dst_dim; + if (dst_dim == merge_dim) { + merge_old_tile_dim.push_back(merge_dim); + perm_merge[i] = perm_merge_counter++; + dst_old_tile_dim.push_back(dst_dim); + perm_dst[i] = perm_dst_counter++; + merged_tile_dims.push_back(dst_dim); + } else if (dst_dim == 1) { + merge_old_tile_dim.push_back(merge_dim); + perm_merge[i] = perm_merge_counter++; + dst_new_tile_dim.push_back(merge_dim); + dst_new_tile_index.push_back(i); merged_tile_dims.push_back(merge_dim); } else if (merge_dim == 1) { - merged_tile_dims.push_back(dst_dim); - } else if (dst_dim == merge_dim) { + merge_new_tile_dim.push_back(dst_dim); + merge_new_tile_index.push_back(i); + dst_old_tile_dim.push_back(dst_dim); + perm_dst[i] = perm_dst_counter++; merged_tile_dims.push_back(dst_dim); } else { return false; } } - const int64_t num_tiles = Product(merged_tile_dims); - if (num_devices % num_tiles != 0 || num_tiles < minimum_tiles) { - return false; - } - int64_t to_merge_man_dim = to_merge.SubgroupManualDim(); - int64_t dst_man_dim = dst->SubgroupManualDim(); + if (to_merge_man_dim >= 0) { - if (dst_man_dim < 0) { - return false; - } int64_t man_group_size = to_merge.tile_assignment().dim(to_merge_man_dim); if (man_group_size != dst->tile_assignment().dim(dst_man_dim)) { return false; } + merge_old_tile_dim.push_back(man_group_size); + dst_old_tile_dim.push_back(man_group_size); + perm_merge[to_merge.TiledDataRank()] = perm_merge_counter++; + perm_dst[to_merge.TiledDataRank()] = perm_dst_counter++; + merged_tile_dims.push_back(man_group_size); + num_merge_groups *= man_group_size; + num_dst_groups *= man_group_size; } - int64_t replication = num_devices / Product(merged_tile_dims); - merged_tile_dims.push_back(replication); - Array merged_tile(merged_tile_dims); - // Maps from replication group ID to sorted members. - absl::flat_hash_map> merge_group_members; - absl::flat_hash_map> dst_group_members; - auto get_group_index = [&](absl::Span tile_indices, - const HloSharding& sharding, int64_t manual_dim) { - int64_t group_id = 0; - for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) { - group_id *= sharding.tile_assignment().dim(i); - group_id += tile_indices[i]; - } - if (manual_dim >= 0) { - group_id *= sharding.tile_assignment().dim(manual_dim); - group_id += tile_indices[manual_dim]; - } - return group_id; - }; - to_merge.tile_assignment().Each([&](absl::Span indices, - int64_t device) { - merge_group_members[get_group_index(indices, to_merge, to_merge_man_dim)] - .insert(device); - }); - dst->tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - dst_group_members[get_group_index(indices, *dst, dst_man_dim)].insert( - device); - }); - // Try to find the intersection of to_merge and dst replication groups, in - // order to determine the merged tile assignment. - Status compatible = merged_tile.EachStatus( - [&](absl::Span indices, int64_t* device) { - std::vector to_merge_index( - to_merge.tile_assignment().num_dimensions()); - std::vector dst_index(dst->tile_assignment().num_dimensions()); - for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) { - if (to_merge.tile_assignment().dim(i) == 1) { - to_merge_index[i] = 0; - } else { - to_merge_index[i] = indices[i]; - } - if (dst->tile_assignment().dim(i) == 1) { - dst_index[i] = 0; - } else { - dst_index[i] = indices[i]; - } - } - if (to_merge_man_dim >= 0) { - to_merge_index[to_merge_man_dim] = indices[to_merge.TiledDataRank()]; - dst_index[dst_man_dim] = indices[to_merge.TiledDataRank()]; - } - if (to_merge.HasPartialReplication()) { - to_merge_index[to_merge.SubgroupReplicationDim()] = indices.back(); + + const int64_t num_devices = to_merge.tile_assignment().num_elements(); + const int64_t new_num_tiles = Product(merged_tile_dims); + if (num_devices % new_num_tiles != 0 || new_num_tiles < minimum_tiles) { + return false; + } + const int64_t replication = num_devices / new_num_tiles; + if (replication > 1) { + merged_tile_dims.push_back(replication); + } + + std::optional compatible_tile_assignment; + // We use two methods to find compatible_tile_assignment. The comparisons are + // liste below. + // 1. In terms of compilation speed, the first method is usually faster than + // the second one, especially when the number of devices is large. + // 2. The first method is friendly to the iota tile assignment. If to_merge or + // dst has iota tile assignment, the resultant sharding also has iota tile + // assignment. The second method always generates v1 sharding. + // 3. The first method can handle the common cases. However, it fails on + // corner cases, such as the arbitrary device order. Conversely, the second + // method can handle all cases. Above all, we initially try the first method, + // and proceed with the second one if the first one fails. + + { + // In the first method, we use reshape and transpose to generate the + // compatible tile assignments for the input sharding. Reshape: decompose + // the input sharding along the replicated dimension. Transpose: assign the + // decomposed dimensions to the new tiled dimensions. + auto get_compatible_tile_assignment = + [&](const HloSharding& sharding, const DimensionVector& old_tile_dims, + DimensionVector& new_tile_dims, DimensionVector& new_tile_indices, + DimensionVector& perm, + const int64_t perm_counter) -> std::vector { + if (!sharding.HasPartialReplication() || + sharding.tile_assignment().dim(sharding.SubgroupReplicationDim()) == + replication) { + return {sharding.tile_assignment()}; + } + if (replication == 1) { + perm.pop_back(); + } else { + new_tile_dims.push_back(replication); + new_tile_indices.push_back(dst->tile_assignment().num_dimensions() - 1); + } + + std::vector result; + DimensionVector iota(new_tile_dims.size()); + absl::c_iota(iota, 0); + do { + std::vector local_perm(perm.begin(), perm.end()); + int64_t local_perm_counter = perm_counter; + DimensionVector reshape_dims(old_tile_dims.begin(), + old_tile_dims.end()); + reshape_dims.reserve(old_tile_dims.size() + new_tile_dims.size()); + for (auto i : iota) { + reshape_dims.push_back(new_tile_dims[i]); + local_perm[new_tile_indices[i]] = local_perm_counter++; } - dst_index[dst->SubgroupReplicationDim()] = indices.back(); - int64_t to_merge_group_id = - get_group_index(to_merge_index, to_merge, to_merge_man_dim); - int64_t dst_group_id = get_group_index(dst_index, *dst, dst_man_dim); - if (merge_group_members[to_merge_group_id].empty() || - dst_group_members[dst_group_id].empty()) { - return InvalidArgument("Not compatible"); + result.push_back(sharding.tile_assignment() + .Reshape(reshape_dims) + .Transpose(local_perm)); + } while (std::next_permutation(iota.begin(), iota.end())); + return result; + }; + + auto merge_compatible_tile_assignment = get_compatible_tile_assignment( + to_merge, merge_old_tile_dim, merge_new_tile_dim, merge_new_tile_index, + perm_merge, perm_merge_counter); + auto dst_compatible_tile_assignment = get_compatible_tile_assignment( + *dst, dst_old_tile_dim, dst_new_tile_dim, dst_new_tile_index, perm_dst, + perm_dst_counter); + + // Find the intersection of merge_compatible_tile_assignment and + // dst_compatible_tile_assignment, such that the resultant tile assignment + // is compatible to both to_merge and dst. + for (const auto& ta1 : dst_compatible_tile_assignment) { + for (const auto& ta2 : merge_compatible_tile_assignment) { + if (ta1 == ta2) { + // Try to get the tile assignment in the iota format + compatible_tile_assignment = ta1.iota() ? ta1 : ta2; } + } + } + } - int64_t smallest_to_merge = - *merge_group_members[to_merge_group_id].begin(); - int64_t smallest_dst = *dst_group_members[dst_group_id].begin(); - if (smallest_to_merge < smallest_dst) { - if (merge_group_members[to_merge_group_id].count(smallest_dst) == 0) { - return InvalidArgument("Not compatible"); + // If the first method fails, try the second method, which handles the element + // in the new tile assignment one by one. + if (!compatible_tile_assignment.has_value()) { + Array new_tile_array(merged_tile_dims); + // Maps from replication group ID to sorted members. + std::vector> merge_group_members(num_merge_groups); + std::vector> dst_group_members(num_dst_groups); + const int64_t merge_group_size = num_devices / num_merge_groups; + const int64_t dst_group_size = num_devices / num_dst_groups; + const auto* merge_begin = to_merge.tile_assignment().array().begin(); + const auto* dst_begin = dst->tile_assignment().array().begin(); + for (int64_t i = 0; i < num_merge_groups; ++i) { + merge_group_members[i] = + absl::btree_set{merge_begin + i * merge_group_size, + merge_begin + (i + 1) * merge_group_size}; + } + for (int64_t i = 0; i < num_dst_groups; ++i) { + dst_group_members[i] = absl::btree_set{ + dst_begin + i * dst_group_size, dst_begin + (i + 1) * dst_group_size}; + } + + auto get_group_index = [&](absl::Span tile_indices, + const HloSharding& sharding, + int64_t manual_dim) { + int64_t group_id = 0; + for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) { + group_id *= sharding.tile_assignment().dim(i); + group_id += tile_indices[i]; + } + if (manual_dim >= 0) { + group_id *= sharding.tile_assignment().dim(manual_dim); + group_id += tile_indices[manual_dim]; + } + return group_id; + }; + // Try to find the intersection of to_merge and dst replication groups, in + // order to determine the merged tile assignment. + Status compatible = + new_tile_array.EachStatus([&](absl::Span indices, + int64_t* device) -> absl::Status { + DimensionVector to_merge_index( + to_merge.tile_assignment().num_dimensions()); + DimensionVector dst_index(dst->tile_assignment().num_dimensions()); + for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) { + if (to_merge.tile_assignment().dim(i) == 1) { + to_merge_index[i] = 0; + } else { + to_merge_index[i] = indices[i]; + } + if (dst->tile_assignment().dim(i) == 1) { + dst_index[i] = 0; + } else { + dst_index[i] = indices[i]; + } } - *device = smallest_dst; - } else { - if (dst_group_members[dst_group_id].count(smallest_to_merge) == 0) { - return InvalidArgument("Not compatible"); + if (to_merge_man_dim >= 0) { + to_merge_index[to_merge_man_dim] = + indices[to_merge.TiledDataRank()]; + dst_index[dst_man_dim] = indices[to_merge.TiledDataRank()]; } - *device = smallest_to_merge; - } - merge_group_members[to_merge_group_id].erase(*device); - dst_group_members[dst_group_id].erase(*device); - return OkStatus(); - }); - if (!compatible.ok()) { - return false; + if (to_merge.HasPartialReplication()) { + to_merge_index[to_merge.SubgroupReplicationDim()] = indices.back(); + } + dst_index[dst->SubgroupReplicationDim()] = indices.back(); + + int64_t to_merge_group_id = + get_group_index(to_merge_index, to_merge, to_merge_man_dim); + int64_t dst_group_id = get_group_index(dst_index, *dst, dst_man_dim); + auto& gm1 = merge_group_members[to_merge_group_id]; + auto& gm2 = dst_group_members[dst_group_id]; + + // Find the smallest element in the intersection of gm1 and gm2. + auto it1 = gm1.begin(); + auto it2 = gm2.begin(); + while (it1 != gm1.end() && it2 != gm2.end()) { + if (*it1 == *it2) { + *device = *it1; + gm1.erase(it1); + gm2.erase(it2); + return OkStatus(); + } else if (*it1 < *it2) { + it1++; + } else { + it2++; + } + } + return InvalidArgument("Not compatible"); + }); + if (!compatible.ok()) { + return false; + } + compatible_tile_assignment = + TileAssignment(std::make_shared>(new_tile_array)); } + std::vector merged_metadata(std::move(dst->metadata())); merged_metadata.reserve(merged_metadata.size() + to_merge.metadata().size()); const absl::flat_hash_set= 0) { subgroup_types.push_back(OpSharding::MANUAL); } - subgroup_types.push_back(OpSharding::REPLICATED); - *dst = HloSharding::Subgroup(merged_tile, subgroup_types, merged_metadata); + if (replication > 1) { + subgroup_types.push_back(OpSharding::REPLICATED); + } + *dst = HloSharding::Subgroup(compatible_tile_assignment.value(), + subgroup_types, merged_metadata); return true; } @@ -511,10 +702,10 @@ std::optional ReshapeSharding(const Shape& source_shape, // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks // to make supporting such cases easy. const Shape tile_shape = sharding.TileShape(source_shape); - std::vector target_tile_assignment_dimensions; - std::vector source_dims_stack(source_shape.rank()); - std::vector target_dims_stack(target_shape.rank()); - std::vector sharding_tile_dims_stack(source_shape.rank()); + DimensionVector target_tile_assignment_dimensions; + DimensionVector source_dims_stack(source_shape.rank()); + DimensionVector target_dims_stack(target_shape.rank()); + DimensionVector sharding_tile_dims_stack(source_shape.rank()); int64_t added_to_partially_replicated = 1; for (int64_t i = 0; i < source_shape.rank(); ++i) { source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i); @@ -654,7 +845,6 @@ std::optional ReshapeSharding(const Shape& source_shape, HloSharding PropagateShardingThroughReshape(const Shape& source_shape, const Shape& target_shape, const HloSharding& sharding) { - HloSharding result = HloSharding::Replicate(); if (sharding.IsTileMaximal() || sharding.IsManual()) { return sharding; } @@ -671,33 +861,58 @@ HloSharding PropagateShardingThroughReshape(const Shape& source_shape, // Find intervals of consecutive dimensions that could use ReshapeSharding(). // then merge the results. We start with the longest interval (whole shape), // and if it fails, we find a sub-interval of it or a disjoint interval. + HloSharding result = HloSharding::Replicate(); int64_t start_dim = 0; while (start_dim < source_shape.rank()) { - int64_t found_compatible = false; + bool found_compatible = false; // For each start_dim, try to use all dims after it. If that fails, reduce // the range. for (int64_t end_dim = source_shape.rank(); end_dim > start_dim; --end_dim) { - std::vector preserved_dims(end_dim - start_dim); - absl::c_iota(preserved_dims, start_dim); - auto group = GroupShardingOnAllDimsExcept(sharding, preserved_dims); + DimensionVector grouped_tiling_dims(source_shape.rank(), 1); + for (int64_t i = start_dim; i < end_dim; ++i) { + grouped_tiling_dims[i] = sharding.tile_assignment().dim(i); + } + HloSharding grouped_sharding = + HloSharding::Tile(TileAssignment(grouped_tiling_dims)); if (auto reshaped = - ReshapeSharding(source_shape, target_shape, group.sharding)) { - group.sharding = std::move(*reshaped); - group.group_dims.clear(); - // Replication dim. - group.group_dims.push_back(target_shape.rank()); - group.data_rank = target_shape.rank(); - int64_t group_size = Product(group.group_dim_sizes); - group.group_dim_sizes.clear(); - group.group_dim_sizes.push_back(group_size); - if (MergeShardingIfCompatible(UngroupSharding(group), - result.NumTiles() + 1, &result)) { + ReshapeSharding(source_shape, target_shape, grouped_sharding)) { + std::vector perm; + perm.reserve(sharding.tile_assignment().num_dimensions()); + for (int64_t i = start_dim; i < end_dim; i++) { + perm.push_back(i); + } + for (int64_t i = 0; i < start_dim; i++) { + perm.push_back(i); + } + for (int64_t i = end_dim; + i < sharding.tile_assignment().num_dimensions(); i++) { + perm.push_back(i); + } + + DimensionVector reshape_dims( + reshaped->tile_assignment().dimensions().begin(), + reshaped->tile_assignment().dimensions().end()); + CHECK_EQ( + sharding.tile_assignment().num_elements() % Product(reshape_dims), + 0); + int64_t num_replicated_dims = + sharding.tile_assignment().num_elements() / Product(reshape_dims); + const int64_t diff = reshape_dims.size() - target_shape.rank(); + CHECK(diff == 0 || diff == 1); + if (diff == 0) { + reshape_dims.push_back(num_replicated_dims); + } else { + reshape_dims.back() *= num_replicated_dims; + } + HloSharding ungrouped_sharding = HloSharding::PartialTile( + sharding.tile_assignment().Transpose(perm).Reshape(reshape_dims)); + if (MergeShardingIfCompatible(ungrouped_sharding, result.NumTiles() + 1, + &result)) { // If the current interval works, we can skip all dimensions within // or before it in future intervals, since they have been considered // already. Set start_dim to end_dim to start with the next disjoint // interval. - result.metadata() = sharding.metadata(); start_dim = end_dim; found_compatible = true; break; @@ -710,6 +925,7 @@ HloSharding PropagateShardingThroughReshape(const Shape& source_shape, start_dim += 1; } } + result.metadata() = sharding.metadata(); return result; } @@ -753,43 +969,31 @@ HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim, // | | | | 3 | | | | | | // +---+---+ +---+---+ +-+-+-+-+ - std::vector tile_dims(sharding.tile_assignment().num_dimensions(), - 1); - // Handle ignore dimensions. - std::vector ignore_sizes; - int64_t ignore_size = 1; + auto old_dims = sharding.tile_assignment().dimensions(); + DimensionVector new_dims(old_dims.begin(), old_dims.end()); + std::vector not_in_dims, dims_except_the_dim; for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { - if (absl::c_find(dims, i) == dims.end()) { - int64_t size = sharding.tile_assignment().dim(i); - ignore_sizes.push_back(size); - tile_dims[i] = size; - ignore_size *= size; + if (i == dim) { + continue; + } else if (absl::c_find(dims, i) != dims.end()) { + dims_except_the_dim.push_back(i); + new_dims[dim] *= old_dims[i]; + new_dims[i] = 1; + } else { + not_in_dims.push_back(i); } } + // perm = not_in_dims + {dim} + dims_except_the_dim + std::vector perm; + perm.reserve(sharding.tile_assignment().num_dimensions()); + perm.insert(perm.end(), not_in_dims.begin(), not_in_dims.end()); + perm.push_back(dim); + perm.insert(perm.end(), dims_except_the_dim.begin(), + dims_except_the_dim.end()); - using Buckets = std::vector>; - Array buckets(ignore_sizes, - Buckets(sharding.tile_assignment().dim(dim))); - sharding.tile_assignment().Each( - [&](absl::Span index, int64_t device) { - std::vector ignore_index; - for (int64_t i = 0; i < index.size(); ++i) { - if (absl::c_find(dims, i) == dims.end()) { - ignore_index.push_back(index[i]); - } - } - buckets(ignore_index)[index[dim]].push_back(device); - }); - std::vector devices; - buckets.Each([&](absl::Span index, const Buckets& buckets) { - for (auto& bucket : buckets) { - devices.insert(devices.end(), bucket.begin(), bucket.end()); - } - }); - tile_dims[dim] = devices.size() / ignore_size; - Array tile_assignment(tile_dims); - tile_assignment.SetValues(devices); - return HloSharding::Tile(tile_assignment, sharding.metadata()); + auto new_tile_assignment = + sharding.tile_assignment().Transpose(perm).Reshape(new_dims); + return HloSharding::Tile(new_tile_assignment, sharding.metadata()); } bool ContainsTileSharding(const HloModule& module) { @@ -819,7 +1023,7 @@ HloSharding GatherOutputShardingFromIndexIndexPassthroughDimensions( GetGatherScatterIndexPassthroughOutputOrUpdateDims(hlo->shape().rank(), dnums.offset_dims()); CHECK_EQ(index_passthrough_dims.size(), output_passthrough_dims.size()); - std::vector output_tile(hlo->shape().rank(), 1); + DimensionVector output_tile(hlo->shape().rank(), 1); for (auto i = 0; i != index_passthrough_dims.size(); ++i) { output_tile[output_passthrough_dims[i]] = index_sharding.tile_assignment().dim(index_passthrough_dims[i]); @@ -860,7 +1064,7 @@ HloSharding GatherIndexShardingFromOutputIndexPassthroughDimensions( GetGatherScatterIndexPassthroughOutputOrUpdateDims(hlo->shape().rank(), dnums.offset_dims()); CHECK_EQ(index_passthrough_dims.size(), output_passthrough_dims.size()); - std::vector index_tile(hlo->operand(1)->shape().rank(), 1); + DimensionVector index_tile(hlo->operand(1)->shape().rank(), 1); for (auto i = 0; i != index_passthrough_dims.size(); ++i) { index_tile[index_passthrough_dims[i]] = output_sharding.tile_assignment().dim(output_passthrough_dims[i]); @@ -892,7 +1096,7 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { } const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); - std::vector tile_assignment_dims(hlo.shape().rank()); + DimensionVector tile_assignment_dims(hlo.shape().rank()); int64_t num_elements = 1; for (int64_t i = 0; i < hlo.shape().rank(); ++i) { if (!absl::c_binary_search(dnums.offset_dims(), i)) { @@ -923,7 +1127,7 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { // - first dimension is non offset dimension, // - second dimension is offset dimension, // Then the result sharding will be [2,1]{0,2}. - std::vector slice_starts(hlo.shape().rank(), 0LL), + DimensionVector slice_starts(hlo.shape().rank(), 0LL), slice_limits(hlo.shape().rank()); for (int64_t i = 0; i < hlo.shape().rank(); ++i) { if (!absl::c_binary_search(dnums.offset_dims(), i)) { @@ -953,8 +1157,7 @@ HloSharding ScatterIndexShardingFromUpdateIndexPassthroughDimensions( scatter->scatter_updates()[0]->shape().rank(), dnums.update_window_dims()); CHECK_EQ(index_passthrough_dims.size(), update_passthrough_dims.size()); - std::vector index_tile(scatter->scatter_indices()->shape().rank(), - 1); + DimensionVector index_tile(scatter->scatter_indices()->shape().rank(), 1); for (auto i = 0; i != index_passthrough_dims.size(); ++i) { index_tile[index_passthrough_dims[i]] = update_sharding.tile_assignment().dim(update_passthrough_dims[i]); @@ -995,8 +1198,7 @@ HloSharding ScatterUpdateShardingFromIndexIndexPassthroughDimensions( scatter->scatter_updates()[0]->shape().rank(), dnums.update_window_dims()); CHECK_EQ(index_passthrough_dims.size(), update_passthrough_dims.size()); - std::vector update_tile( - scatter->scatter_updates()[0]->shape().rank(), 1); + DimensionVector update_tile(scatter->scatter_updates()[0]->shape().rank(), 1); for (auto i = 0; i != index_passthrough_dims.size(); ++i) { update_tile[update_passthrough_dims[i]] = index_sharding.tile_assignment().dim(index_passthrough_dims[i]); @@ -1053,7 +1255,7 @@ HloSharding ScatterEffectiveIndexSharding( } const int64_t index_rank = scatter.scatter_indices()->shape().rank(); - std::vector slice_starts(index_rank, 0LL), slice_limits(index_rank); + DimensionVector slice_starts(index_rank, 0LL), slice_limits(index_rank); for (int64_t i = 0; i < index_rank; ++i) { if (i < index_dim) { slice_limits[i] = index_sharding.tile_assignment().dim(i); @@ -1075,7 +1277,7 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers(); const int64_t data_rank = scatter.scatter_updates()[0]->shape().rank(); - std::vector tile_assignment_dims(data_rank, 1LL); + DimensionVector tile_assignment_dims(data_rank, 1LL); int64_t num_elements = 1; for (int64_t i = 0; i < scatter.shape().rank(); ++i) { if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { @@ -1104,7 +1306,7 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, // - first dimension is scatter_window_dims, // - second dimension is update_window_dims, // Then the result sharding will be [2,1]{0,2}. - std::vector slice_starts(data_rank, 0LL); + DimensionVector slice_starts(data_rank, 0LL); Array tile_assignment = data_sharding.tile_assignment().array().Slice(slice_starts, tile_assignment_dims); @@ -1178,12 +1380,12 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( absl::Span offset_or_window_dims, absl::Span slice_size, const int64_t index_vector_dim) { if (operand_sharding.IsTileMaximal() || operand_sharding.IsManual()) { - return operand_sharding; + return std::nullopt; } auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( operand_shape, collapsed_or_inserted_dims, index_map, offset_or_window_dims, slice_size); - std::vector passthrough_tile(output_or_update_rank, 1); + DimensionVector passthrough_tile(output_or_update_rank, 1); int64_t collapsed = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { @@ -1233,10 +1435,10 @@ std::optional PassthroughGatherOutputOrScatterUpdateToOperand( auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( operand_shape, collapsed_or_inserted_dims, index_map, offset_or_window_dims, slice_size); - std::vector passthrough_tile(operand_shape.rank(), 1); + DimensionVector passthrough_tile(operand_shape.rank(), 1); int64_t collapsed = 0; // Relevant dims have shardings passed to the operand. - std::vector relevant_output_or_update_dims; + DimensionVector relevant_output_or_update_dims; for (int64_t i = 0; i < operand_shape.rank(); ++i) { if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { collapsed++; @@ -1285,9 +1487,9 @@ std::optional GatherOperandShardingFromOutputParallelDimensions( const Shape gather_shape = gather.shape(); CHECK_EQ(output_parallel_dims.size(), output_aligned_operand_parallel_dims.size()); - std::vector operand_tile_assignment( - gather.operand(0)->shape().rank(), 1); - std::vector relevant_output_dims; + DimensionVector operand_tile_assignment(gather.operand(0)->shape().rank(), + 1); + DimensionVector relevant_output_dims; for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) { if (parallel_idx >= output_parallel_dims.size() || output_parallel_dims[parallel_idx] != i) { @@ -1528,9 +1730,9 @@ std::optional ScatterUpdateShardingFromOutputParallelDimensions( : scatter.shape(); CHECK_EQ(update_parallel_dims.size(), index_aligned_operand_parallel_dims.size()); - std::vector update_tile_assignment( + DimensionVector update_tile_assignment( scatter.scatter_updates()[0]->shape().rank(), 1); - std::vector relevant_output_dims; + DimensionVector relevant_output_dims; for (int i = 0, parallel_idx = 0; i < scatter_shape.rank(); ++i) { if (parallel_idx >= operand_parallel_dims_sorted.size() || operand_parallel_dims_sorted[parallel_idx] != i) { @@ -1612,7 +1814,7 @@ HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( indices_sharding.metadata()); } -StatusOr, HloOpcode>> +absl::StatusOr, HloOpcode>> IdentityValueAndHloOpcodeForScatterReduceComputation( const HloScatterInstruction& scatter) { auto computation = scatter.to_apply(); @@ -1707,7 +1909,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims( return sharding; } int64_t group_count = 1; - std::vector valid_dims_to_replicate; + DimensionVector valid_dims_to_replicate; for (int64_t dim : dims_to_replicate) { if (dim >= sharding.TiledDataRank()) { continue; @@ -1721,7 +1923,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims( if (group_count == sharding.NumTiles() && sharding.subgroup_types().empty()) { return HloSharding::Replicate(sharding.metadata()); } - std::vector dim_permutation(sharding.TiledDataRank()); + DimensionVector dim_permutation(sharding.TiledDataRank()); absl::c_iota(dim_permutation, 0); absl::c_stable_sort(dim_permutation, [&](const int64_t a, const int64_t b) { return absl::c_linear_search(valid_dims_to_replicate, a) < @@ -1729,7 +1931,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims( }); auto new_tile = TransposeSharding(sharding, dim_permutation).tile_assignment(); - std::vector new_tile_shape( + DimensionVector new_tile_shape( sharding.tile_assignment().dimensions().begin(), sharding.tile_assignment().dimensions().end()); for (int64_t dim : valid_dims_to_replicate) { @@ -1757,7 +1959,7 @@ HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept( if (sharding.IsTileMaximal() || sharding.IsManual()) { return sharding; } - std::vector dims_to_replicate(sharding.TiledDataRank()); + DimensionVector dims_to_replicate(sharding.TiledDataRank()); absl::c_iota(dims_to_replicate, 0); dims_to_replicate.erase( @@ -1780,7 +1982,7 @@ HloSharding ReplicateAllDataDims(const HloSharding& sharding, PartiallyReplicateTiledShardingOnAllDimsExcept(sharding, {}); if (data_rank >= 0 && data_rank != result.TiledDataRank() && !result.IsTileMaximal()) { - std::vector new_tile_shape(data_rank, 1); + DimensionVector new_tile_shape(data_rank, 1); for (int64_t i = result.TiledDataRank(); i < result.tile_assignment().num_dimensions(); ++i) { new_tile_shape.push_back(result.tile_assignment().dim(i)); @@ -1796,7 +1998,7 @@ HloSharding RemoveShapeDimensions(const HloSharding& sharding, if (sharding.IsTileMaximal() || dims_to_remove.empty()) { return sharding; } - std::vector new_tile_shape; + DimensionVector new_tile_shape; new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() - dims_to_remove.size()); for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { @@ -1821,8 +2023,8 @@ std::optional TransposeShardingWithCollapsedDims( } if (src_to_tgt.size() < source.tile_assignment().num_dimensions()) { // Add missing subgroup dims. - std::vector new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end()); - std::vector new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end()); + DimensionVector new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end()); + DimensionVector new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end()); for (int64_t i = 0; i < source.tile_assignment().num_dimensions() - src_to_tgt.size(); ++i) { @@ -1832,7 +2034,7 @@ std::optional TransposeShardingWithCollapsedDims( return TransposeShardingWithCollapsedDims(source, new_src_to_tgt, new_tgt_to_src); } - std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + DimensionVector tgt_dims_skipping_new(tgt_to_src.size(), -1); int64_t skipped_tgt_dims = 0; int64_t src_non_subgroup_dims = src_to_tgt.size() - source.subgroup_types().size(); @@ -1848,7 +2050,7 @@ std::optional TransposeShardingWithCollapsedDims( } } int64_t skipped_src_dims = absl::c_count(src_to_tgt, -1); - std::vector perm(src_to_tgt.size()); + DimensionVector perm(src_to_tgt.size()); for (int64_t i = 0; i < src_non_subgroup_dims; ++i) { if (src_to_tgt[i] < 0) { if (source.tile_assignment().dim(i) > 1) { @@ -1867,7 +2069,7 @@ std::optional TransposeShardingWithCollapsedDims( perm[src_to_tgt[i] - skipped_tgt_dims + skipped_src_dims] = i; } auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); - std::vector tgt_tiles(tgt_to_src.size(), 1); + DimensionVector tgt_tiles(tgt_to_src.size(), 1); for (int64_t i = 0; i < tgt_tiles.size(); ++i) { if (tgt_to_src[i] >= 0) { int64_t dim = tgt_dims_skipping_new[i]; @@ -2235,16 +2437,17 @@ absl::InlinedVector IndexAlignedOperandParallelDims( } std::string GroupedSharding::ToString() const { - auto result = absl::StrCat("dims: ", absl::StrJoin(group_dims, ","), - "\ndevice_groups:\n"); - absl::StrAppend(&result, - "group dim sizes: ", absl::StrJoin(group_dim_sizes, ",")); - absl::StrAppend(&result, "data rank: ", data_rank); - absl::StrAppend(&result, "subgroup manual: ", subgroup_manual); + auto result = + absl::StrCat("group dims: ", absl::StrJoin(group_dims, ","), "\n"); + absl::StrAppend( + &result, "group dim sizes: ", absl::StrJoin(group_dim_sizes, ","), "\n"); + absl::StrAppend(&result, "data rank: ", data_rank, "\n"); + absl::StrAppend(&result, "subgroup manual: ", subgroup_manual, "\n"); + absl::StrAppend(&result, "inner sharding: ", sharding.ToString(), "\n"); + absl::StrAppend(&result, "device groups:", "\n"); for (auto& device_group : device_groups) { absl::StrAppend(&result, "\t", absl::StrJoin(device_group, ","), "\n"); } - absl::StrAppend(&result, "inner sharding: ", sharding.ToString()); return result; } @@ -2275,31 +2478,77 @@ GroupedSharding GroupShardingOnDims(const HloSharding& sharding, absl::Span group_dim_shards, bool subgroup_manual) { CHECK(!sharding.IsTileMaximal()); - std::vector grouped_tiling_dims( - sharding.tile_assignment().dimensions().begin(), - sharding.tile_assignment().dimensions().end()); - std::vector group_dim_sizes(group_dims.size()); + + // The first item of the pair is the group_dim_size. The second item is the + // group_dim_shard. + std::vector> decomposed_tiling_dims( + sharding.tile_assignment().num_dimensions()); + for (int64_t i = 0; i < decomposed_tiling_dims.size(); ++i) { + // Set default values for group_dim_size and group_dim_shard. + decomposed_tiling_dims[i] = + std::make_pair(1, sharding.tile_assignment().dim(i)); + } + + DimensionVector group_dim_sizes(group_dims.size()); for (int64_t i = 0; i < group_dims.size(); ++i) { - CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0); + CHECK_EQ( + sharding.tile_assignment().dim(group_dims[i]) % group_dim_shards[i], 0); group_dim_sizes[i] = - grouped_tiling_dims[group_dims[i]] / group_dim_shards[i]; - grouped_tiling_dims[group_dims[i]] = group_dim_shards[i]; + sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i]; + + decomposed_tiling_dims[group_dims[i]].first = group_dim_sizes[i]; + decomposed_tiling_dims[group_dims[i]].second = group_dim_shards[i]; + } + + DimensionVector grouped_tiling_dims(decomposed_tiling_dims.size()); + for (int64_t i = 0; i < decomposed_tiling_dims.size(); ++i) { + grouped_tiling_dims[i] = decomposed_tiling_dims[i].second; + } + + DimensionVector sorted_group_dims(group_dims.size()); + std::partial_sort_copy(group_dims.begin(), group_dims.end(), + sorted_group_dims.begin(), sorted_group_dims.end()); + + absl::flat_hash_map group_dim_to_index(group_dims.size()); + DimensionVector reshape_dimensions(grouped_tiling_dims.begin(), + grouped_tiling_dims.end()); + reshape_dimensions.reserve(decomposed_tiling_dims.size() + group_dims.size()); + for (int64_t i = 0; i < sorted_group_dims.size(); ++i) { + int64_t index = sorted_group_dims[i] + i; + group_dim_to_index[sorted_group_dims[i]] = index; + reshape_dimensions.insert( + reshape_dimensions.begin() + index, + decomposed_tiling_dims[sorted_group_dims[i]].first); + } + + std::vector perm(reshape_dimensions.size()); + absl::c_iota(perm, 0); + for (int64_t i = 0; i < group_dims.size(); ++i) { + const int64_t index = group_dim_to_index[group_dims[i]]; + perm.erase(std::remove(perm.begin(), perm.end(), index), perm.end()); + perm.insert(perm.begin() + i, index); + } + + auto grouped_array = sharding.tile_assignment() + .Reshape(reshape_dimensions) + .Transpose(perm) + .array(); + + const int64_t num_device_groups = Product(group_dim_sizes); + const int64_t num_devices = sharding.tile_assignment().num_elements(); + CHECK_EQ(num_devices % num_device_groups, 0); + const int64_t device_group_size = num_devices / num_device_groups; + std::vector> device_groups( + num_device_groups, std::vector(device_group_size)); + for (int64_t i = 0; i < num_device_groups; ++i) { + device_groups[i].assign( + grouped_array.begin() + i * device_group_size, + grouped_array.begin() + (i + 1) * device_group_size); } - std::vector> device_groups(Product(group_dim_sizes)); - sharding.tile_assignment().Each([&](absl::Span indices, - int64_t device) { - int64_t group_id = 0; - for (int64_t i = 0; i < group_dims.size(); ++i) { - group_id *= - sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i]; - group_id += indices[group_dims[i]] / group_dim_shards[i]; - } - device_groups[group_id].push_back(device); - }); auto grouped = GroupedSharding( std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), + DimensionVector(group_dims.begin(), group_dims.end()), std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), HloSharding::Replicate(), subgroup_manual); if (sharding.ReplicateOnLastTileDim()) { @@ -2388,7 +2637,7 @@ GroupedSharding GroupShardingOnReplicatedDim( num_groups / (sharding.ReplicateOnLastTileDim() ? sharding.tile_assignment().dimensions().back() : 1); - std::vector tile_dims( + DimensionVector tile_dims( sharding.tile_assignment().dimensions().begin(), sharding.tile_assignment().dimensions().end()); if (!sharding.ReplicateOnLastTileDim()) { @@ -2437,7 +2686,7 @@ GroupedSharding GetGroupedReplicatedSharding(const int64_t num_groups, absl::c_iota(device_group, device_id); device_id = device_group.back() + 1; } - return GroupedSharding(std::move(device_groups), {data_rank}, {group_size}, + return GroupedSharding(std::move(device_groups), {data_rank}, {num_groups}, data_rank, HloSharding::Replicate(), /*subgroup_manual=*/false); } @@ -2447,7 +2696,7 @@ GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding) { int64_t tile_dimensions = sharding.tile_assignment().num_dimensions(); int64_t subgroup_size = sharding.subgroup_types().size(); int64_t rank = tile_dimensions - subgroup_size; - std::vector group_dims; + DimensionVector group_dims; bool last_tile_dim_replicate = false; for (int64_t i = 0; i < subgroup_size; i++) { @@ -2479,14 +2728,14 @@ PartialReplicatedGroupShardingWithAssignedDeviceGroups( VLOG(5) << "Failed because not partial replicated or not divisible"; return std::nullopt; } - std::vector> device_to_index( + std::vector device_to_index( Product(sharding.tile_assignment().dimensions()), - std::vector(sharding.tile_assignment().num_dimensions())); + DimensionVector(sharding.tile_assignment().num_dimensions())); sharding.tile_assignment().Each( [&device_to_index](absl::Span indices, int64_t device) { device_to_index[device].assign(indices.begin(), indices.end()); }); - std::vector grouped_tiling_dims( + DimensionVector grouped_tiling_dims( sharding.tile_assignment().dimensions().begin(), sharding.tile_assignment().dimensions().end()); grouped_tiling_dims.back() /= device_groups.size(); @@ -2543,12 +2792,12 @@ PartialReplicatedGroupShardingWithAssignedDeviceGroups( } HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { - std::vector tiling_dims; + DimensionVector tiling_dims; bool partial_sharding = false; std::vector subgroup_types; auto grouped_tiling = grouped_sharding.sharding.tile_assignment(); if (grouped_sharding.sharding.IsTileMaximal()) { - tiling_dims = std::vector(grouped_sharding.data_rank, 1); + tiling_dims = DimensionVector(grouped_sharding.data_rank, 1); if (grouped_sharding.device_groups[0].size() != 1 || absl::c_linear_search(grouped_sharding.group_dims, tiling_dims.size())) { @@ -2580,7 +2829,6 @@ HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { tiling_dims.insert(tiling_dims.begin() + grouped_sharding.group_dims[i], 1); } - grouped_tiling = grouped_tiling.Reshape(tiling_dims); } else if (!grouped_sharding.sharding.IsTileMaximal()) { // Handles tile replicated. partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim(); @@ -2590,39 +2838,73 @@ HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { if (absl::c_linear_search(grouped_sharding.group_dims, tiling_dims.size())) { tiling_dims.push_back(1); - grouped_tiling = grouped_tiling.Reshape(tiling_dims); partial_sharding = true; } } - // Update group dim sizes. + DimensionVector group_dim_sizes_and_tiling_dims( + grouped_sharding.group_dim_sizes.begin(), + grouped_sharding.group_dim_sizes.end()); + group_dim_sizes_and_tiling_dims.insert(group_dim_sizes_and_tiling_dims.end(), + tiling_dims.begin(), + tiling_dims.end()); + Array tiling(group_dim_sizes_and_tiling_dims); + + DimensionVector sorted_group_dims(grouped_sharding.group_dims.size()); + std::partial_sort_copy(grouped_sharding.group_dims.begin(), + grouped_sharding.group_dims.end(), + sorted_group_dims.begin(), sorted_group_dims.end()); + absl::flat_hash_map group_dim_to_index( + grouped_sharding.group_dims.size()); + for (int64_t i = 0; i < sorted_group_dims.size(); ++i) { + group_dim_to_index[sorted_group_dims[i]] = sorted_group_dims[i] + i; + } + + std::vector perm(tiling_dims.size() + grouped_sharding.group_dims.size(), + -1); + for (int64_t i = 0; i < grouped_sharding.group_dims.size(); i++) { + perm[group_dim_to_index[grouped_sharding.group_dims[i]]] = i; + } + int64_t j = grouped_sharding.group_dims.size(); + for (int64_t i = 0; i < perm.size(); i++) { + if (perm[i] == -1) { + perm[i] = j++; + } + } + + std::vector flattened_device_groups; + flattened_device_groups.reserve(grouped_sharding.device_groups.size() * + grouped_sharding.device_groups[0].size()); + bool same_length = + grouped_tiling.num_elements() == grouped_sharding.device_groups[0].size(); + for (auto const& v : grouped_sharding.device_groups) { + if (same_length) { + // Reorder the device_groups based on the grouped_tiling.array() + for (int64_t i = 0; i < v.size(); ++i) { + flattened_device_groups.push_back( + v[*(grouped_tiling.array().begin() + i)]); + } + } else { + flattened_device_groups.insert(flattened_device_groups.end(), v.begin(), + v.end()); + } + } + tiling.SetValues(flattened_device_groups); + TileAssignment tile_assignment( + std::make_shared>(std::move(tiling))); + for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) { int64_t dim = grouped_sharding.group_dims[i]; tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i]; } - Array tiling(tiling_dims); - grouped_tiling.Each([&](absl::Span indices, int64_t device) { - std::vector ungrouped_inds(indices.begin(), indices.end()); - for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) { - int64_t remaining_group_index = g; - for (int64_t i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { - int64_t dim = grouped_sharding.group_dims[i]; - int64_t groups_in_this_dim = grouped_sharding.group_dim_sizes[i]; - ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) * - grouped_tiling.dim(dim) + - indices[dim]; - remaining_group_index /= groups_in_this_dim; - } - tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; - } - }); + tile_assignment = tile_assignment.Transpose(perm).Reshape(tiling_dims); if (grouped_sharding.subgroup_manual) { - return HloSharding::Subgroup(tiling, subgroup_types, + return HloSharding::Subgroup(tile_assignment, subgroup_types, grouped_sharding.sharding.metadata()); } - return partial_sharding ? HloSharding::PartialTile(tiling) - : HloSharding::Tile(tiling); + return partial_sharding ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs, @@ -2632,7 +2914,8 @@ bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs, } bool matching_groups = true; - absl::flat_hash_map device_to_ref_group; + std::vector device_to_ref_group(lhs.device_groups.size() * + lhs.device_groups[0].size()); for (int64_t g = 0; g < lhs.device_groups.size(); ++g) { for (int64_t device : lhs.device_groups[g]) { device_to_ref_group[device] = g; @@ -2666,9 +2949,8 @@ HloSharding SplitShardingDimension(const HloSharding& sharding, CHECK_GT(sharding.TiledDataRank(), dimension); CHECK_EQ(sharding.tile_assignment().dim(dimension) % new_dim_size, 0) << "dim size " << new_dim_size; - std::vector dimensions( - sharding.tile_assignment().dimensions().begin(), - sharding.tile_assignment().dimensions().end()); + DimensionVector dimensions(sharding.tile_assignment().dimensions().begin(), + sharding.tile_assignment().dimensions().end()); int64_t current_dimension = dimensions[dimension]; dimensions.insert(dimensions.begin() + dimension + 1, current_dimension / new_dim_size); @@ -2683,9 +2965,8 @@ HloSharding SplitShardingDimension(const HloSharding& sharding, HloSharding MergeShardingDimension(const HloSharding& sharding, int64_t dimension) { CHECK_GT(sharding.TiledDataRank(), dimension); - std::vector dimensions( - sharding.tile_assignment().dimensions().begin(), - sharding.tile_assignment().dimensions().end()); + DimensionVector dimensions(sharding.tile_assignment().dimensions().begin(), + sharding.tile_assignment().dimensions().end()); dimensions[dimension] *= dimensions[dimension + 1]; dimensions.erase(dimensions.begin() + dimension + 1); auto new_tile_assignment = sharding.tile_assignment().Reshape(dimensions); @@ -2734,7 +3015,8 @@ bool IsSortOperandShardingMovable(const HloInstruction* sort_operand, auto tile_assignment_dims = sharding.tile_assignment().dimensions(); const int rank = sort_operand->shape().rank(); for (int64_t dim = 0; dim < rank; ++dim) { - if (dim == sort_dim || tile_assignment_dims[dim] != 1) { + if (dim == sort_dim || tile_assignment_dims[dim] != 1 || + sort_operand->shape().dimensions(dim) == 1) { continue; } return true; @@ -2782,13 +3064,52 @@ Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape) { if (sharding.IsTileMaximal() || sharding.IsManual() || sharding.IsUnknown()) { return shape; } + if (!shape.IsArray()) { + return shape; + } Shape result_shape = shape; - for (int64_t i = 0; i < sharding.TiledDataRank(); ++i) { + // sharding.TiledDataRank() == i < shape.rank() is not always true? + for (int64_t i = 0; i < sharding.TiledDataRank() && i < shape.rank(); ++i) { result_shape.set_dimensions( i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); } return result_shape; } +Shape TileShape(const HloSharding& sharding, const Shape& shape) { + if (!sharding.IsTuple()) { + return TileLeafShape(sharding, shape); + } + Shape result_shape = shape; + ShapeUtil::ForEachMutableSubshape( + &result_shape, + [&shape, &sharding](Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(shape, index)) { + return; + } + const HloSharding& subshape_sharding = + sharding.GetSubSharding(shape, index); + *subshape = TileLeafShape(subshape_sharding, *subshape); + }); + + return result_shape; +} + +Shape TileLeafShape(const HloSharding& sharding, const Shape& shape) { + if (sharding.IsTileMaximal() || sharding.IsManual() || sharding.IsUnknown()) { + return shape; + } + if (!shape.IsArray()) { + return shape; + } + Shape result_shape = shape; + for (int64_t i = 0; i < sharding.TiledDataRank() && i < shape.rank(); ++i) { + CHECK_EQ(shape.dimensions(i) % sharding.tile_assignment().dim(i), 0); + result_shape.set_dimensions( + i, shape.dimensions(i) / sharding.tile_assignment().dim(i)); + } + return result_shape; +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/hlo/utils/hlo_sharding_util.h b/xla/hlo/utils/hlo_sharding_util.h index c158f15891d8d..8671ebfe2554f 100644 --- a/xla/hlo/utils/hlo_sharding_util.h +++ b/xla/hlo/utils/hlo_sharding_util.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,13 +25,17 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" #include "xla/shape.h" +#include "xla/statusor.h" +#include "xla/util.h" namespace xla { namespace hlo_sharding_util { @@ -256,7 +260,7 @@ HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( // - If computation is min/max, return max value/min value with corresponding op // code. // - Otherwise, return error status. -StatusOr, HloOpcode>> +absl::StatusOr, HloOpcode>> IdentityValueAndHloOpcodeForScatterReduceComputation( const HloScatterInstruction& scatter); @@ -358,9 +362,9 @@ absl::InlinedVector IndexAlignedOperandParallelDims( // represents the in-group sharding. struct GroupedSharding { GroupedSharding(std::vector> device_groups, - std::vector group_dims, - std::vector group_dim_sizes, int64_t data_rank, - HloSharding grouped_sharding, bool subgroup_manual = false) + DimensionVector group_dims, DimensionVector group_dim_sizes, + int64_t data_rank, HloSharding grouped_sharding, + bool subgroup_manual = false) : device_groups(std::move(device_groups)), group_dims(std::move(group_dims)), group_dim_sizes(std::move(group_dim_sizes)), @@ -369,8 +373,8 @@ struct GroupedSharding { subgroup_manual(subgroup_manual) {} std::string ToString() const; std::vector> device_groups; - std::vector group_dims; - std::vector group_dim_sizes; + DimensionVector group_dims; + DimensionVector group_dim_sizes; int64_t data_rank; HloSharding sharding; bool subgroup_manual; @@ -475,6 +479,13 @@ Shape UntileShape(const HloSharding& sharding, const Shape& shape); // REQUIRES: !sharding.IsTuple() Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape); +// Returns the tiled shape. +Shape TileShape(const HloSharding& sharding, const Shape& shape); + +// Returns the tiled shape. +// REQUIRES: !sharding.IsTuple() +Shape TileLeafShape(const HloSharding& sharding, const Shape& shape); + } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/hlo/utils/hlo_sharding_util_test.cc b/xla/hlo/utils/hlo_sharding_util_test.cc index 15d74f2a012de..ad042361a7bf2 100644 --- a/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/xla/hlo/utils/hlo_sharding_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,98 @@ limitations under the License. #include "xla/hlo/utils/hlo_sharding_util.h" +#include +#include #include +#include #include -#include "xla/hlo/ir/hlo_instruction.h" +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "xla/array.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/tile_assignment.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { namespace hlo_sharding_util { namespace { +TEST(HloShardingUtilTest, MergeShardingIfCompatible1) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({1, 4, 2, 16}, {16, 8}, {1, 0})); + HloSharding dst = HloSharding::PartialTile(TileAssignment({4, 1, 1, 32})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, HloSharding::PartialTile( + TileAssignment({4, 4, 2, 4}, {4, 4, 8}, {0, 2, 1}))); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible2) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({1, 2, 4, 16}, {16, 8}, {1, 0})); + HloSharding dst = HloSharding::PartialTile(TileAssignment({4, 1, 1, 32})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, HloSharding::PartialTile( + TileAssignment({4, 2, 4, 4}, {4, 4, 8}, {0, 2, 1}))); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible3) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({4, 2, 1, 16}, {16, 8}, {1, 0})); + HloSharding dst = HloSharding::PartialTile(TileAssignment({1, 1, 4, 32})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, HloSharding::PartialTile( + TileAssignment({4, 2, 4, 4}, {16, 8}, {1, 0}))); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible4) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({1, 4, 2, 16}, {16, 8}, {1, 0})); + HloSharding dst = + HloSharding::PartialTile(TileAssignment({4, 1, 1, 32}, {4, 32}, {1, 0})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, HloSharding::PartialTile( + TileAssignment({4, 4, 2, 4}, {4, 32}, {1, 0}))); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible5) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({1, 4, 2, 16}, {16, 8}, {1, 0})); + HloSharding dst = + HloSharding::PartialTile(TileAssignment({4, 1, 1, 32}, {32, 4}, {1, 0})); + EXPECT_FALSE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible6) { + HloSharding to_merge = + HloSharding::PartialTile(TileAssignment({1, 4, 2, 16})); + HloSharding dst = HloSharding::PartialTile(TileAssignment({4, 1, 1, 32})); + EXPECT_FALSE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible7) { + HloSharding to_merge = HloSharding::PartialTile( + TileAssignment({2, 1, 2, 2}, {2, 2, 2}, {2, 1, 0})); + HloSharding dst = HloSharding::PartialTile(TileAssignment({1, 2, 1, 4})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, + HloSharding::Tile(TileAssignment({2, 2, 2}, {2, 2, 2}, {2, 0, 1}))); +} + +TEST(HloShardingUtilTest, MergeShardingIfCompatible8) { + HloSharding to_merge = HloSharding::PartialTile(TileAssignment({2, 1, 4})); + HloSharding dst = + HloSharding::PartialTile(TileAssignment({1, 4, 2}, {2, 2, 2}, {2, 1, 0})); + EXPECT_TRUE(MergeShardingIfCompatible(to_merge, dst.NumTiles() + 1, &dst)); + EXPECT_EQ(dst, + HloSharding::Tile(TileAssignment({2, 4}, {2, 2, 2}, {0, 2, 1}))); +} + TEST(HloShardingUtilTest, TransposeShardingReplicated) { EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), HloSharding::Replicate()); @@ -34,8 +114,7 @@ TEST(HloShardingUtilTest, TransposeShardingReplicated) { TEST(HloShardingUtilTest, TransposeShardingTiled) { HloSharding input = HloSharding::IotaTile({1, 2, 1, 2}); - HloSharding output = - HloSharding::Tile(TileAssignment({2, 1, 2, 1}, {2, 2}, {1, 0})); + HloSharding output = HloSharding::IotaTile({2, 1, 2, 1}, {2, 2}, {1, 0}); EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output); } @@ -132,7 +211,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) { EXPECT_EQ(result.value(), output_sharding); } -TEST(HloShardingUtilTest, ReshapeShardingTrivialDImensionInsertedToEnd) { +TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) { Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16}); Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1}); HloSharding input_sharding = HloSharding::IotaTile({2, 1}); @@ -160,56 +239,102 @@ TEST(HloShardingUtilTest, ReshapeShardingScalar) { EXPECT_FALSE(result.has_value()); } -TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) { - HloSharding sharding = HloSharding::IotaTile({2, 2}); - HloSharding result = - ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}); - EXPECT_EQ(result.tile_assignment(), - TileAssignment((absl::Span){4, 1})); -} - -TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) { - HloSharding sharding = HloSharding::IotaTile({2, 2}); - HloSharding result = ReshapeToTileDimension( - sharding, /*dim=*/1, /*dims=*/(absl::Span){0, 1}); - EXPECT_EQ(result.tile_assignment(), TileAssignment({1, 4}, {2, 2}, {1, 0})); -} - -TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) { - HloSharding sharding = HloSharding::IotaTile({2, 2, 2}); - HloSharding result = - ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}); - EXPECT_EQ(result.tile_assignment(), TileAssignment({8, 1, 1})); -} - -TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) { - HloSharding sharding = HloSharding::IotaTile({2, 2, 2}); - HloSharding result = - ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}); - EXPECT_EQ(result.tile_assignment(), - TileAssignment({1, 8, 1}, {2, 2, 2}, {1, 0, 2})); -} - -TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) { - HloSharding sharding = HloSharding::IotaTile({2, 2, 2}); - HloSharding result = - ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}); - EXPECT_EQ(result.tile_assignment(), - TileAssignment({1, 1, 8}, {4, 2}, {1, 0})); -} - -TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) { - // Tile sharding in batch dimension, i.e. - // sharding={devices[2,2,2]<=[8]. - HloSharding sharding = HloSharding::IotaTile({2, 2, 2}); - // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0. - HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2, - /*dims=*/{1, 2}); - // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two - // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually - // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}. - EXPECT_EQ(result.tile_assignment().array(), - Array3D({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}})); +TEST(HloShardingUtilTest, ReshapeToTileDimension2D) { + // The two sharding in the vector are the same. They will be processed in + // different branches in ReshapeToTileDimension. + std::vector shardings = {HloSharding::IotaTile({2, 2}), + HloSharding::Tile({{0, 1}, {2, 3}})}; + + for (const HloSharding& sharding : shardings) { + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}) + .tile_assignment(), + TileAssignment((absl::Span){4, 1})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}) + .tile_assignment(), + TileAssignment({1, 4}, {2, 2}, {1, 0})); + } +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Case1) { + std::vector shardings = { + HloSharding::IotaTile({2, 2, 2}), + HloSharding::Tile({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})}; + + for (const HloSharding& sharding : shardings) { + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({8, 1, 1})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({1, 8, 1}, {2, 2, 2}, {1, 0, 2})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({1, 1, 8}, {4, 2}, {1, 0})); + + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{1, 2}) + .tile_assignment(), + TileAssignment({2, 1, 4}, {2, 2, 2}, {0, 2, 1})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/0, + /*dims=*/{0, 2}) + .tile_assignment(), + TileAssignment({4, 2, 1}, {2, 2, 2}, {1, 0, 2})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{0, 2}) + .tile_assignment(), + TileAssignment({1, 2, 4}, {2, 2, 2}, {1, 2, 0})); + } +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Case2) { + // The input sharding has a complicated device list. + std::vector shardings = { + HloSharding::IotaTile({2, 2, 2}, {4, 2}, {1, 0}), + HloSharding::Tile({{{0, 2}, {4, 6}}, {{1, 3}, {5, 7}}})}; + for (const HloSharding& sharding : shardings) { + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({8, 1, 1}, {4, 2}, {1, 0})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({1, 8, 1}, {2, 2, 2}, {0, 2, 1})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({1, 1, 8}, {2, 4}, {1, 0})); + } +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension4D) { + HloSharding sharding1 = HloSharding::IotaTile({2, 3, 5, 7}); + HloSharding sharding2 = + HloSharding::Tile(sharding1.tile_assignment().array()); + std::vector shardings = {sharding1, sharding2}; + + for (const HloSharding& sharding : shardings) { + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}) + .tile_assignment(), + TileAssignment({1, 6, 5, 7}, {2, 3, 5, 7}, {2, 3, 1, 0})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{1, 2}) + .tile_assignment(), + TileAssignment({2, 15, 1, 7}, {2, 3, 5, 7}, {0, 3, 1, 2})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{1, 3}) + .tile_assignment(), + TileAssignment({2, 21, 5, 1}, {2, 3, 5, 7}, {0, 2, 1, 3})); + + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}) + .tile_assignment(), + TileAssignment({1, 30, 1, 7}, {2, 3, 5, 7}, {3, 1, 0, 2})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 3}) + .tile_assignment(), + TileAssignment({1, 42, 5, 1}, {2, 3, 5, 7}, {2, 1, 0, 3})); + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{1, 2, 3}) + .tile_assignment(), + TileAssignment({2, 105, 1, 1}, {2, 3, 5, 7}, {0, 1, 2, 3})); + + EXPECT_EQ(ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2, 3}) + .tile_assignment(), + TileAssignment({1, 210, 1, 1}, {2, 3, 5, 7}, {1, 0, 2, 3})); + } } TEST(HloShardingUtilTest, PropagateReshapeShardingTiledSplitPartialMatch) { @@ -310,12 +435,11 @@ TEST(HloShardingUtilTest, GetManualSubgroupSharding_ReplicatedAndManual) { TEST(HloShardingUtilTest, UngroupSharding_ManualOnly) { HloSharding sharding = HloSharding::IotaTile({1, 2}); std::vector> device_groups = {{0, 2}, {1, 3}}; - std::vector group_dims = {2}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {2}; + DimensionVector group_dim_sizes = {2}; auto grouped = GroupedSharding( - std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), + std::move(device_groups), std::move(group_dims), std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), sharding, /*subgroup_manual=*/true); @@ -329,15 +453,14 @@ TEST(HloShardingUtilTest, UngroupSharding_ReplicatedAndManual) { HloSharding sharding = HloSharding::PartialTile(TileAssignment({1, 2, 2})); std::vector> device_groups = {{0, 2, 4, 6}, {1, 3, 5, 7}}; - std::vector group_dims = {3}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {3}; + DimensionVector group_dim_sizes = {2}; - auto grouped = GroupedSharding( - std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), - std::move(group_dim_sizes), - sharding.tile_assignment().num_dimensions() - 1, sharding, - /*subgroup_manual=*/true); + auto grouped = + GroupedSharding(std::move(device_groups), std::move(group_dims), + std::move(group_dim_sizes), + sharding.tile_assignment().num_dimensions() - 1, sharding, + /*subgroup_manual=*/true); HloSharding ungroup_sharding = UngroupSharding(grouped); VLOG(1) << "ungroup_sharding: " << ungroup_sharding.ToString(); @@ -351,15 +474,14 @@ TEST(HloShardingUtilTest, UngroupSharding_ManualAndReplicated) { HloSharding sharding = HloSharding::PartialTile(TileAssignment({1, 2, 2})); std::vector> device_groups = {{0, 1, 4, 5}, {2, 3, 6, 7}}; - std::vector group_dims = {2}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {2}; + DimensionVector group_dim_sizes = {2}; - auto grouped = GroupedSharding( - std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), - std::move(group_dim_sizes), - sharding.tile_assignment().num_dimensions() - 1, sharding, - /*subgroup_manual=*/true); + auto grouped = + GroupedSharding(std::move(device_groups), std::move(group_dims), + std::move(group_dim_sizes), + sharding.tile_assignment().num_dimensions() - 1, sharding, + /*subgroup_manual=*/true); HloSharding ungroup_sharding = UngroupSharding(grouped); VLOG(1) << "ungroup_sharding: " << ungroup_sharding.ToString(); @@ -372,16 +494,15 @@ TEST(HloShardingUtilTest, UngroupSharding_ManualAndReplicated) { TEST(HloShardingUtilTest, UngroupSharding_Replicated) { HloSharding sharding = HloSharding::Replicate(); - std::vector group_dims = {3}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {3}; + DimensionVector group_dim_sizes = {2}; std::vector> device_groups = {{0, 1}, {2, 3}}; - auto grouped = GroupedSharding( - std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), - std::move(group_dim_sizes), 2, sharding, - /*subgroup_manual=*/true); + auto grouped = + GroupedSharding(std::move(device_groups), std::move(group_dims), + std::move(group_dim_sizes), 2, sharding, + /*subgroup_manual=*/true); HloSharding ungroup_sharding = UngroupSharding(grouped); VLOG(1) << "ungroup_sharding: " << ungroup_sharding.ToString(); @@ -392,16 +513,15 @@ TEST(HloShardingUtilTest, UngroupSharding_Replicated) { TEST(HloShardingUtilTest, UngroupSharding_Replicated2) { HloSharding sharding = HloSharding::Replicate(); - std::vector group_dims = {2}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {2}; + DimensionVector group_dim_sizes = {2}; std::vector> device_groups = {{0, 2}, {1, 3}}; - auto grouped = GroupedSharding( - std::move(device_groups), - std::vector(group_dims.begin(), group_dims.end()), - std::move(group_dim_sizes), 2, sharding, - /*subgroup_manual=*/true); + auto grouped = + GroupedSharding(std::move(device_groups), std::move(group_dims), + std::move(group_dim_sizes), 2, sharding, + /*subgroup_manual=*/true); HloSharding ungroup_sharding = UngroupSharding(grouped); VLOG(1) << "ungroup_sharding: " << ungroup_sharding.ToString(); @@ -410,50 +530,78 @@ TEST(HloShardingUtilTest, UngroupSharding_Replicated2) { "{devices=[1,1,2,2]0,2,1,3 last_tile_dims={manual, replicated}}"); } +TEST(HloShardingUtilTest, GroupedAndUngroupedReplicatedSharding) { + GroupedSharding group_sharding = GetGroupedReplicatedSharding( + /*num_groups=*/3, /*num_tiles=*/12, /*data_rank=*/2); + EXPECT_EQ(UngroupSharding(group_sharding), HloSharding::Replicate()); +} + +TEST(HloShardingUtilTest, GroupedAndUngroupedIotaSharding) { + std::vector> device_groups = {{0, 1, 2, 3, 4, 5}, + {6, 7, 8, 9, 10, 11}}; + GroupedSharding group_sharding = GroupedSharding( + device_groups, /*group_dims=*/{0}, /*group_dim_sizes=*/{2}, + /*data_rank=*/2, HloSharding::IotaTile({1, 2, 3}, {2, 3}, {1, 0})); + EXPECT_EQ(UngroupSharding(group_sharding), + HloSharding::IotaTile({2, 2, 3}, {2, 2, 3}, {0, 2, 1})); +} + +TEST(HloShardingUtilTest, GroupedAndUngroupedShardingWithUnsortedGroupDims) { + HloSharding sharding = HloSharding::IotaTile({4, 3, 5, 7}); + GroupedSharding group_sharding = + GroupShardingOnDims(sharding, {2, 0}, {1, 2}); + EXPECT_EQ(group_sharding.sharding, HloSharding::IotaTile({2, 3, 1, 7})); + EXPECT_EQ(UngroupSharding(group_sharding), sharding); +} + +TEST(HloShardingUtilTest, UngroupShardingWithUnsortedGroupDims) { + GroupedSharding group_sharding({{0}, {1}, {2}, {3}}, {1, 0}, {2, 2}, 4, + HloSharding::Replicate()); + EXPECT_EQ(UngroupSharding(group_sharding), + HloSharding::IotaTile({2, 2, 1, 1}, {2, 2}, {1, 0})); +} + TEST(HloShardingUtilTest, DeviceGroupsDoesNotMatch) { HloSharding sharding = HloSharding::PartialTile( TileAssignment((absl::Span){2, 2})); - std::vector group_dims = {2}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dim_sizes = {2}; std::vector> lhs_device_groups = {{0, 2, 4, 6}, {1, 3, 5, 7}}; - std::vector lhs_group_dims = {3}; + DimensionVector lhs_group_dims = {3}; - auto lhs = GroupedSharding( - std::move(lhs_device_groups), - std::vector(lhs_group_dims.begin(), lhs_group_dims.end()), - group_dim_sizes, 2, sharding, - /*subgroup_manual=*/true); + auto lhs = + GroupedSharding(std::move(lhs_device_groups), std::move(lhs_group_dims), + group_dim_sizes, 2, sharding, + /*subgroup_manual=*/true); std::vector> rhs_device_groups = {{0, 1, 4, 5}, {2, 3, 6, 7}}; - std::vector rhs_group_dims = {2}; + DimensionVector rhs_group_dims = {2}; - auto rhs = GroupedSharding( - std::move(rhs_device_groups), - std::vector(rhs_group_dims.begin(), rhs_group_dims.end()), - group_dim_sizes, 2, sharding, - /*subgroup_manual=*/true); + auto rhs = + GroupedSharding(std::move(rhs_device_groups), std::move(rhs_group_dims), + group_dim_sizes, 2, sharding, + /*subgroup_manual=*/true); EXPECT_FALSE(DeviceGroupsAreMatch(lhs, rhs)); } TEST(HloShardingUtilTest, DeviceGroupsMatch) { HloSharding lhs_sharding = HloSharding::Replicate(); - std::vector group_dims = {2}; - std::vector group_dim_sizes = {2}; + DimensionVector group_dims = {2}; + DimensionVector group_dim_sizes = {2}; std::vector> device_groups = {{0, 2}, {1, 3}}; auto lhs = GroupedSharding( - device_groups, std::vector(group_dims.begin(), group_dims.end()), + device_groups, DimensionVector(group_dims.begin(), group_dims.end()), group_dim_sizes, 2, lhs_sharding, /*subgroup_manual=*/true); HloSharding rhs_sharding = HloSharding::PartialTile( TileAssignment((absl::Span){2, 2})); auto rhs = GroupedSharding( - device_groups, std::vector(group_dims.begin(), group_dims.end()), + device_groups, DimensionVector(group_dims.begin(), group_dims.end()), group_dim_sizes, 2, rhs_sharding, /*subgroup_manual=*/true); @@ -521,12 +669,131 @@ TEST(HloShardingUtilTest, IsSubShardingCompatibleShapeTiledPartialTiled) { EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); } +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingNoShortcut) { + HloSharding rhs_sharding = HloSharding::PartialTile( + TileAssignment((absl::Span){2, 2})); + HloSharding lhs_sharding = HloSharding::IotaTile({4}); + std::vector success = {1, 3, 4, 7, 8, 11, 12, 15, 16, 19, 20}; + std::vector fail = {2, 5, 6, 9, 10, 13, 14, 17, 18}; + for (int64_t i : success) { + Shape shape = ShapeUtil::MakeShape(F32, {i}); + EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); + } + for (int64_t i : fail) { + Shape shape = ShapeUtil::MakeShape(F32, {i}); + EXPECT_FALSE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); + } +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut1) { + HloSharding rhs_sharding = HloSharding::PartialTile( + TileAssignment((absl::Span){2, 2})); + HloSharding lhs_sharding = HloSharding::IotaTile({4}); + Shape shape = ShapeUtil::MakeShape(F32, {8}); + EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut2) { + HloSharding rhs_sharding = HloSharding::PartialTile( + TileAssignment((absl::Span){2, 2})); + Array lhs_array({4}); + lhs_array.SetValues({1, 0, 2, 3}); + HloSharding lhs_sharding = HloSharding::Tile(lhs_array); + Shape shape = ShapeUtil::MakeShape(F32, {8}); + EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut3) { + HloSharding rhs_sharding = HloSharding::PartialTile( + TileAssignment((absl::Span){2, 2})); + HloSharding lhs_sharding = HloSharding::IotaTile({4}, {2, 2}, {1, 0}); + Shape shape = ShapeUtil::MakeShape(F32, {8}); + EXPECT_FALSE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut4) { + HloSharding rhs_sharding = + HloSharding::PartialTile(TileAssignment({2, 2}, {2, 2}, {1, 0})); + HloSharding lhs_sharding = HloSharding::IotaTile({4}, {2, 2}, {1, 0}); + Shape shape = ShapeUtil::MakeShape(F32, {8}); + EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut5) { + HloSharding rhs_sharding = + HloSharding::PartialTile(TileAssignment({2, 3, 5, 7})); + HloSharding lhs_sharding_1 = + HloSharding::IotaTile({2, 21, 5}, {2, 3, 5, 7}, {0, 1, 3, 2}); + HloSharding lhs_sharding_2 = + HloSharding::IotaTile({2, 21, 5}, {2, 3, 5, 7}, {0, 2, 3, 1}); + HloSharding lhs_sharding_3 = HloSharding::IotaTile({2, 21, 5}); + std::vector shapes = {ShapeUtil::MakeShape(F32, {10, 42, 10}), + ShapeUtil::MakeShape(F32, {11, 41, 11})}; + for (const auto& shape : shapes) { + EXPECT_TRUE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_1, rhs_sharding)); + EXPECT_FALSE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_2, rhs_sharding)); + EXPECT_FALSE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_3, rhs_sharding)); + } +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut6) { + HloSharding rhs_sharding = + HloSharding::PartialTile(TileAssignment({2, 3, 5, 7 * 11 * 13})); + HloSharding lhs_sharding_1 = HloSharding::PartialTile(TileAssignment( + {2 * 7, 3, 5 * 11, 13}, {2, 3, 5, 7, 11, 13}, {0, 3, 1, 2, 4, 5})); + HloSharding lhs_sharding_2 = HloSharding::PartialTile(TileAssignment( + {2 * 7, 3, 5 * 11, 13}, {2, 3, 5, 11, 7, 13}, {0, 4, 1, 2, 3, 5})); + HloSharding lhs_sharding_3 = HloSharding::PartialTile(TileAssignment( + {2 * 7, 3, 5 * 11, 13}, {2, 3, 5, 13, 7, 11}, {0, 4, 1, 2, 5, 3})); + HloSharding lhs_sharding_4 = HloSharding::PartialTile(TileAssignment( + {2 * 7, 3, 5 * 11, 13}, {2, 3, 5, 7, 13, 11}, {0, 3, 1, 2, 5, 4})); + HloSharding lhs_sharding_5 = + HloSharding::PartialTile(TileAssignment({2 * 7, 3, 5 * 11, 13})); + std::vector shapes = { + ShapeUtil::MakeShape(F32, {2 * 7, 9, 5 * 11}), + ShapeUtil::MakeShape(F32, {2 * 7 - 1, 4, 5 * 11 - 1})}; + for (const auto& shape : shapes) { + EXPECT_TRUE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_1, rhs_sharding)); + EXPECT_TRUE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_2, rhs_sharding)); + EXPECT_TRUE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_3, rhs_sharding)); + EXPECT_TRUE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_4, rhs_sharding)); + EXPECT_FALSE( + IsSubTilingOrEqualSharding(shape, lhs_sharding_5, rhs_sharding)); + } +} + +TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut7) { + HloSharding rhs_sharding = + HloSharding::PartialTile(TileAssignment({1, 2, 1, 3, 5 * 7 * 11})); + HloSharding lhs_sharding = HloSharding::PartialTile( + TileAssignment({5, 2, 7, 3, 11}, {2, 3, 5, 7, 11}, {2, 0, 3, 1, 4})); + std::vector shapes = {ShapeUtil::MakeShape(F32, {5, 2, 7, 3}), + ShapeUtil::MakeShape(F32, {2, 2, 9, 3})}; + for (const auto& shape : shapes) { + EXPECT_TRUE(IsSubTilingOrEqualSharding(shape, lhs_sharding, rhs_sharding)); + } +} + TEST(HloShardingUtilTest, IsSortOperandShardingMovableRankTwoOneFreeDim) { HloIotaInstruction iota(ShapeUtil::MakeShape(F32, {8, 128}), 1); iota.set_sharding(HloSharding::IotaTile({1, 2})); EXPECT_TRUE(IsSortOperandShardingMovable(&iota, 1)); } +TEST(HloShardingUtilTest, + IsSortOperandShardingMovableRankTwoOneFreeDimOfSize1) { + HloIotaInstruction iota(ShapeUtil::MakeShape(F32, {1, 128}), 1); + iota.set_sharding(HloSharding::IotaTile({1, 2})); + EXPECT_FALSE(IsSortOperandShardingMovable(&iota, 1)); +} + TEST(HloShardingUtilTest, IsSortOperandShardingMovableRankTwoNoFreeDims) { HloIotaInstruction iota(ShapeUtil::MakeShape(F32, {8, 128}), 1); iota.set_sharding(HloSharding::IotaTile({2, 2})); @@ -556,6 +823,37 @@ TEST(HloShardingUtilTest, IsSortOperandShardingMovableSortDimUnsharded) { iota.set_sharding(HloSharding::IotaTile({1, 2})); EXPECT_FALSE(IsSortOperandShardingMovable(&iota, 0)); } + +TEST(HloShardingUtilTest, TileShape) { + HloSharding sharding = HloSharding::Tile(TileAssignment({4, 1})); + Shape shape_0 = ShapeUtil::MakeShape(F32, {80, 128}); + auto tile_shape_0 = hlo_sharding_util::TileShape(sharding, shape_0); + auto expected_shape_0 = ShapeUtil::MakeShape(F32, {20, 128}); + EXPECT_EQ(tile_shape_0, expected_shape_0); + Shape shape_1 = ShapeUtil::MakeShape(F32, {40, 128}); + auto tile_shape_1 = hlo_sharding_util::TileShape(sharding, shape_1); + auto expected_shape_1 = ShapeUtil::MakeShape(F32, {10, 128}); + EXPECT_EQ(tile_shape_1, expected_shape_1); + const Shape tuple = ShapeUtil::MakeTupleShape({tile_shape_0, tile_shape_1}); + EXPECT_EQ(hlo_sharding_util::TileShape(sharding, tuple), + ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); +} + +TEST(HloShardingUtilTest, UntileShape) { + HloSharding sharding = HloSharding::Tile(TileAssignment({4, 1})); + Shape shape_0 = ShapeUtil::MakeShape(F32, {80, 128}); + auto tile_shape_0 = hlo_sharding_util::UntileShape(sharding, shape_0); + auto expected_shape_0 = ShapeUtil::MakeShape(F32, {320, 128}); + EXPECT_EQ(tile_shape_0, expected_shape_0); + Shape shape_1 = ShapeUtil::MakeShape(F32, {40, 128}); + auto tile_shape_1 = hlo_sharding_util::UntileShape(sharding, shape_1); + auto expected_shape_1 = ShapeUtil::MakeShape(F32, {160, 128}); + EXPECT_EQ(tile_shape_1, expected_shape_1); + const Shape tuple = ShapeUtil::MakeTupleShape({tile_shape_0, tile_shape_1}); + EXPECT_EQ(hlo_sharding_util::UntileShape(sharding, tuple), + ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); +} + } // namespace } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/index_util.cc b/xla/index_util.cc index 6797df88551d0..f209be28e7fff 100644 --- a/xla/index_util.cc +++ b/xla/index_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,19 @@ limitations under the License. #include "xla/index_util.h" #include +#include #include #include #include "absl/strings/str_join.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "tsl/platform/logging.h" namespace xla { -/* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( +/* static */ DimensionVector IndexUtil::LinearIndexToMultidimensionalIndex( const Shape& shape, int64_t linear_index) { DCHECK_GE(linear_index, 0); DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); @@ -37,7 +40,7 @@ namespace xla { // I{L(1)} = (linear_index / D{L(0)}) % D{L(1)} // I{L(2)} = (linear_index / (D{L(0)} * D{L(1)})) % D{L(2)} // ... - std::vector multi_index(shape.dimensions_size()); + DimensionVector multi_index(shape.dimensions_size()); // Accumulated product D{L(0)} * D{L(1)} * ... int64_t divisor = 1; diff --git a/xla/index_util.h b/xla/index_util.h index e096b93c00c48..bbafbbac83328 100644 --- a/xla/index_util.h +++ b/xla/index_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -110,7 +110,7 @@ class IndexUtil { // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional // index is dimension 0. - static std::vector LinearIndexToMultidimensionalIndex( + static DimensionVector LinearIndexToMultidimensionalIndex( const Shape& shape, int64_t linear_index); // Bumps a sequence of indices; e.g. {0,0,0,0} up by one index value; e.g. to diff --git a/xla/index_util_test.cc b/xla/index_util_test.cc index 19a8eb36e69e8..de393fadf7b20 100644 --- a/xla/index_util_test.cc +++ b/xla/index_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ TEST(IndexUtilTest, VectorIndexing) { Shape vector_shape = ShapeUtil::MakeShape(F32, {100}); EXPECT_EQ(42, IndexUtil::MultidimensionalIndexToLinearIndex(vector_shape, {42})); - std::vector multi_index = + auto multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(vector_shape, 42); EXPECT_EQ(1, multi_index.size()); EXPECT_EQ(42, multi_index[0]); @@ -56,8 +56,9 @@ TEST(IndexUtilTest, MatrixIndexingRowMajor) { {9, 19})); EXPECT_EQ(53, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01, {3, 5})); - EXPECT_EQ(std::vector({3, 5}), - IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_01, 53)); + EXPECT_THAT( + IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_01, 53), + testing::ElementsAre(3, 5)); } TEST(IndexUtilTest, MatrixIndexingColumnMajor) { @@ -72,8 +73,9 @@ TEST(IndexUtilTest, MatrixIndexingColumnMajor) { {9, 19})); EXPECT_EQ(65, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10, {3, 5})); - EXPECT_EQ(std::vector({3, 5}), - IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_10, 65)); + EXPECT_THAT( + IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_10, 65), + testing::ElementsAre(3, 5)); } TEST(IndexUtilTest, ThreeDArrayIndexing210) { @@ -131,7 +133,7 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30, 40, 30, 20, 10}); SetMinorToMajorLayout(&shape, minor_to_major_order); for (auto linear_index : linear_indexes) { - std::vector multi_index = + auto multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); EXPECT_EQ(linear_index, IndexUtil::MultidimensionalIndexToLinearIndex( shape, multi_index)); diff --git a/xla/iterator_util.h b/xla/iterator_util.h index 23367890adcc7..80001e4a9b299 100644 --- a/xla/iterator_util.h +++ b/xla/iterator_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/iterator_util_test.cc b/xla/iterator_util_test.cc index 0bf6e7fe43f5a..ac093c3d1bd68 100644 --- a/xla/iterator_util_test.cc +++ b/xla/iterator_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/layout.cc b/xla/layout.cc index 8e91029a5bc9e..8fe0743e8588d 100644 --- a/xla/layout.cc +++ b/xla/layout.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/layout_util.h" @@ -63,43 +64,82 @@ std::string Tile::ToString() const { return std::move(printer).ToString(); } -Layout::Layout() = default; +Layout::Layout() + : index_primitive_type_(PRIMITIVE_TYPE_INVALID), + pointer_primitive_type_(PRIMITIVE_TYPE_INVALID) {} + +SplitConfigProto SplitConfig::ToProto() const { + SplitConfigProto split_config_proto; + split_config_proto.set_dimension(dimension_); + for (int64_t i : split_indices_) { + split_config_proto.add_split_indices(i); + } + return split_config_proto; +} + +std::string SplitConfig::ToString() const { + return absl::StrCat("(", dimension_, ":", absl::StrJoin(split_indices_, ","), + ")"); +} Layout::Layout(absl::Span minor_to_major) - : minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + : index_primitive_type_(PRIMITIVE_TYPE_INVALID), + pointer_primitive_type_(PRIMITIVE_TYPE_INVALID), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} Layout::Layout(absl::Span minor_to_major, absl::Span dim_level_types, absl::Span dim_unique, absl::Span dim_ordered, absl::Span tiles, + int64_t tail_padding_alignment_in_elements, PrimitiveType index_primitive_type, PrimitiveType element_primitive_type, int64_t element_size_in_bits, int64_t memory_space, + absl::Span split_configs, std::unique_ptr physical_shape, int64_t dynamic_shape_metadata_prefix_bytes) - : dim_level_types_(dim_level_types.begin(), dim_level_types.end()), - dim_unique_(dim_unique.begin(), dim_unique.end()), - dim_ordered_(dim_ordered.begin(), dim_ordered.end()), - minor_to_major_(minor_to_major.begin(), minor_to_major.end()), - tiles_(tiles.begin(), tiles.end()), - index_primitive_type_(index_primitive_type), + : index_primitive_type_(index_primitive_type), pointer_primitive_type_(element_primitive_type), - element_size_in_bits_(element_size_in_bits), memory_space_(memory_space), + element_size_in_bits_(element_size_in_bits), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()), + split_configs_(split_configs.begin(), split_configs.end()), + tail_padding_alignment_in_elements_(tail_padding_alignment_in_elements), physical_shape_(std::move(physical_shape)), dynamic_shape_metadata_prefix_bytes_( - dynamic_shape_metadata_prefix_bytes) {} + dynamic_shape_metadata_prefix_bytes) { + // Grow dim_attributes_ to the maximum length of "dim_level_types", + // "dim_unique", and "dim_ordered", and then initialize the attributes that + // should exist. + n_dim_level_types_ = dim_level_types.size(); + n_dim_unique_ = dim_unique.size(); + n_dim_ordered_ = dim_ordered.size(); + const int n_attributes = std::max( + n_dim_level_types_, std::max(n_dim_unique_, n_dim_ordered_)); + dim_attributes_.resize(n_attributes); + for (int i = 0; i < n_attributes; i++) { + if (i < n_dim_level_types_) + dim_attributes_[i].dim_level_type = dim_level_types[i]; + if (i < n_dim_unique_) dim_attributes_[i].dim_unique = dim_unique[i]; + if (i < n_dim_ordered_) dim_attributes_[i].dim_ordered = dim_ordered[i]; + } +} Layout::Layout(const Layout& other) - : dim_level_types_(other.dim_level_types_), - dim_unique_(other.dim_unique_), - dim_ordered_(other.dim_ordered_), - minor_to_major_(other.minor_to_major_), - tiles_(other.tiles_), + : dim_attributes_(other.dim_attributes_), + n_dim_level_types_(other.n_dim_level_types_), + n_dim_unique_(other.n_dim_unique_), + n_dim_ordered_(other.n_dim_ordered_), index_primitive_type_(other.index_primitive_type_), pointer_primitive_type_(other.pointer_primitive_type_), - element_size_in_bits_(other.element_size_in_bits_), memory_space_(other.memory_space_), + element_size_in_bits_(other.element_size_in_bits_), + minor_to_major_(other.minor_to_major_), + tiles_(other.tiles_), + split_configs_(other.split_configs_), + tail_padding_alignment_in_elements_( + other.tail_padding_alignment_in_elements_), physical_shape_(other.physical_shape_ != nullptr ? std::make_unique(*other.physical_shape_) : nullptr), @@ -112,15 +152,19 @@ Layout::~Layout() = default; Layout& Layout::operator=(const Layout& other) { if (this != &other) { - dim_level_types_ = other.dim_level_types_; - dim_unique_ = other.dim_unique_; - dim_ordered_ = other.dim_ordered_; + dim_attributes_ = other.dim_attributes_; + n_dim_level_types_ = other.n_dim_level_types_; + n_dim_unique_ = other.n_dim_unique_; + n_dim_ordered_ = other.n_dim_ordered_; minor_to_major_ = other.minor_to_major_; tiles_ = other.tiles_; + tail_padding_alignment_in_elements_ = + other.tail_padding_alignment_in_elements_; index_primitive_type_ = other.index_primitive_type_; pointer_primitive_type_ = other.pointer_primitive_type_; element_size_in_bits_ = other.element_size_in_bits_; memory_space_ = other.memory_space_; + split_configs_ = other.split_configs_; if (other.physical_shape_ != nullptr) { physical_shape_ = std::make_unique(*other.physical_shape_); } else { @@ -152,10 +196,19 @@ Layout& Layout::operator=(Layout&& other) = default; for (const TileProto& tile_proto : proto.tiles()) { *layout.add_tiles() = Tile::CreateFromProto(tile_proto); } + if (proto.tail_padding_alignment_in_elements() != 0) { + layout.set_tail_padding_alignment_in_elements( + proto.tail_padding_alignment_in_elements()); + } else { + layout.set_tail_padding_alignment_in_elements(1); + } layout.set_index_primitive_type(proto.index_primitive_type()); layout.set_pointer_primitive_type(proto.pointer_primitive_type()); layout.set_element_size_in_bits(proto.element_size_in_bits()); layout.set_memory_space(proto.memory_space()); + for (const SplitConfigProto& split_config_proto : proto.split_configs()) { + layout.add_split_configs(SplitConfig::CreateFromProto(split_config_proto)); + } if (proto.has_physical_shape()) { *layout.mutable_physical_shape() = Shape(proto.physical_shape()); } @@ -166,14 +219,14 @@ Layout& Layout::operator=(Layout&& other) = default; LayoutProto Layout::ToProto() const { LayoutProto proto; - for (DimLevelType dim_level_type : dim_level_types()) { - proto.add_dim_level_types(dim_level_type); + for (int i = 0; i < n_dim_level_types_; i++) { + proto.add_dim_level_types(dim_level_type(i)); } - for (bool dim_unique : dim_unique()) { - proto.add_dim_unique(dim_unique); + for (int i = 0; i < n_dim_unique_; i++) { + proto.add_dim_unique(dim_unique(i)); } - for (bool dim_ordered : dim_ordered()) { - proto.add_dim_ordered(dim_ordered); + for (int i = 0; i < n_dim_ordered_; i++) { + proto.add_dim_ordered(dim_ordered(i)); } proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); for (const int64_t dimension : minor_to_major()) { @@ -182,10 +235,15 @@ LayoutProto Layout::ToProto() const { for (const Tile& tile : tiles()) { *proto.add_tiles() = tile.ToProto(); } + proto.set_tail_padding_alignment_in_elements( + tail_padding_alignment_in_elements()); proto.set_index_primitive_type(index_primitive_type()); proto.set_pointer_primitive_type(pointer_primitive_type()); proto.set_element_size_in_bits(element_size_in_bits_); proto.set_memory_space(memory_space_); + for (const SplitConfig& split_config : split_configs()) { + *proto.add_split_configs() = split_config.ToProto(); + } if (has_physical_shape()) { *proto.mutable_physical_shape() = physical_shape_->ToProto(); } @@ -222,20 +280,20 @@ void Layout::Print(Printer* printer) const { colon_printed = true; }; - if (!dim_level_types().empty()) { + if (n_dim_level_types_ > 0) { auto print_one = [&](int i) { printer->Append(DimLevelTypeAbbrev(dim_level_type(i))); - if (!dim_unique().empty() && !dim_unique(i)) { + if (n_dim_unique_ > 0 && !dim_unique(i)) { printer->Append("+"); } - if (!dim_ordered().empty() && !dim_ordered(i)) { + if (n_dim_ordered_ > 0 && !dim_ordered(i)) { printer->Append("~"); } }; print_colon(); printer->Append("D("); print_one(0); - for (int i = 1; i < dim_level_types().size(); ++i) { + for (int i = 1; i < n_dim_level_types_; ++i) { printer->Append(","); print_one(i); } @@ -250,6 +308,13 @@ void Layout::Print(Printer* printer) const { } } + if (tail_padding_alignment_in_elements() != 1) { + print_colon(); + printer->Append("L("); + printer->Append(tail_padding_alignment_in_elements()); + printer->Append(")"); + } + if (index_primitive_type() != PRIMITIVE_TYPE_INVALID) { print_colon(); if (primitive_util::IsIntegralType(index_primitive_type())) { @@ -287,6 +352,13 @@ void Layout::Print(Printer* printer) const { printer->Append(memory_space()); printer->Append(")"); } + if (!split_configs().empty()) { + print_colon(); + printer->Append("SC"); + for (const auto& split_config : split_configs()) { + printer->Append(split_config.ToString()); + } + } if (has_physical_shape()) { print_colon(); @@ -313,9 +385,33 @@ std::string Layout::ToString() const { bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { if (!LayoutUtil::IsDense(lhs) || !LayoutUtil::IsDense(rhs)) { - if (lhs.dim_level_types() != rhs.dim_level_types()) { + // dim_level_types + if (lhs.dim_level_types_size() != rhs.dim_level_types_size()) { return false; } + for (int i = 0; i < lhs.dim_level_types_size(); i++) { + if (lhs.dim_level_type(i) != rhs.dim_level_type(i)) { + return false; + } + } + // dim_unique + if (lhs.dim_unique_size() != rhs.dim_unique_size()) { + return false; + } + for (int i = 0; i < lhs.dim_unique_size(); i++) { + if (lhs.dim_unique(i) != rhs.dim_unique(i)) { + return false; + } + } + // dim_ordered + if (lhs.dim_ordered_size() != rhs.dim_ordered_size()) { + return false; + } + for (int i = 0; i < lhs.dim_ordered_size(); i++) { + if (lhs.dim_ordered(i) != rhs.dim_ordered(i)) { + return false; + } + } } if (lhs.minor_to_major() != rhs.minor_to_major()) { return false; @@ -323,6 +419,11 @@ bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { return false; } + if (!ignore_tail_padding_alignment_in_elements_ && + lhs.tail_padding_alignment_in_elements() != + rhs.tail_padding_alignment_in_elements()) { + return false; + } if (!ignore_index_primitive_type_ && lhs.index_primitive_type() != rhs.index_primitive_type()) { return false; @@ -338,6 +439,9 @@ bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { if (!ignore_memory_space_ && lhs.memory_space() != rhs.memory_space()) { return false; } + if (!ignore_split_configs_ && lhs.split_configs() != rhs.split_configs()) { + return false; + } if (!ignore_physical_shape_) { if (lhs.has_physical_shape() || rhs.has_physical_shape()) { if (!lhs.has_physical_shape() || !rhs.has_physical_shape()) { @@ -387,9 +491,10 @@ Layout& Layout::DeleteDimension(int64_t dim_to_delete) { } // Delete the corresponding dim level types. if (LayoutUtil::IsSparse(*this)) { - dim_level_types_.erase(dim_level_types_.begin() + dim_to_delete); - dim_unique_.erase(dim_unique_.begin() + dim_to_delete); - dim_ordered_.erase(dim_ordered_.begin() + dim_to_delete); + if (dim_to_delete < n_dim_level_types_) n_dim_level_types_--; + if (dim_to_delete < n_dim_unique_) n_dim_unique_--; + if (dim_to_delete < n_dim_ordered_) n_dim_ordered_--; + dim_attributes_.erase(dim_attributes_.begin() + dim_to_delete); } return *this; } diff --git a/xla/layout.h b/xla/layout.h index c64806c0b2644..bbc73effab1e4 100644 --- a/xla/layout.h +++ b/xla/layout.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ class Tile { std::string ToString() const; // Returns the bound of the tile in the given dimension index. - int64_t dimension(int i) const { return dimensions_.at(i); } + int64_t dimension(int i) const { return dimensions_[i]; } // Returns the dimensions of the tile. absl::Span dimensions() const { return dimensions_; } @@ -90,6 +90,64 @@ class Tile { using TileVector = absl::InlinedVector; +// Describes how data is split between different memories. Each SplitConfig +// object represents a split in one dimension. Each SplitConfig is associated +// with a vector of split indices which point to the points in the iteration +// where the splits occur. For example, if the dimension contains 1024 elements, +// a split indices value of {512} indicates splitting this dimension into two +// right through the middle. The dimension here refers to the physical dimension +// such that 0 is the majormost dimension and rank-1 is the minormost dimension. +class SplitConfig { + public: + SplitConfig(int64_t dimension, absl::Span split_indices) + : dimension_(dimension), + split_indices_(split_indices.begin(), split_indices.end()) {} + + static SplitConfig CreateFromProto( + const SplitConfigProto& split_config_proto) { + return SplitConfig(split_config_proto.dimension(), + split_config_proto.split_indices()); + } + SplitConfigProto ToProto() const; + + bool operator==(const SplitConfig& other) const { + return dimension() == other.dimension() && + split_indices() == other.split_indices(); + } + bool operator!=(const SplitConfig& other) const { return !(*this == other); } + + std::string ToString() const; + + // Returns the dimension that is split. + int64_t dimension() const { return dimension_; } + SplitConfig& set_dimension(int64_t dimension) { + dimension_ = dimension; + return *this; + } + + // Returns the indices where splits occur. + absl::Span split_indices() const { return split_indices_; } + int64_t split_indices(int64_t idx) const { return split_indices_.at(idx); } + int64_t split_indices_size() const { return split_indices_.size(); } + SplitConfig& add_split_indices(int64_t split_index) { + split_indices_.push_back(split_index); + return *this; + } + SplitConfig& clear_split_indices() { + split_indices_.clear(); + return *this; + } + + template + friend H AbslHashValue(H h, const SplitConfig& t) { + return H::combine(std::move(h), t.dimension_, t.split_indices_); + } + + private: + int64_t dimension_; + absl::InlinedVector split_indices_; +}; + // TODO: Rename the `dim_level_types` field to `lvl_types`, so that it // matches `mlir::sparse_tensor::SparseTensorEncodingAttr`. class Layout { @@ -109,9 +167,11 @@ class Layout { absl::Span dim_unique, absl::Span dim_ordered, absl::Span tiles, + int64_t tail_padding_alignment_in_elements = 1, PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType element_primitive_type = PRIMITIVE_TYPE_INVALID, int64_t element_size_in_bits = 0, int64_t memory_space = 0, + absl::Span split_configs = {}, std::unique_ptr physical_shape = nullptr, int64_t dynamic_shape_metadata_prefix_bytes = 0); @@ -147,6 +207,11 @@ class Layout { return *this; } + Equal& IgnoreTailPaddingAlignmentInElements() { + ignore_tail_padding_alignment_in_elements_ = true; + return *this; + } + Equal& IgnoreIndexPrimitiveType() { ignore_index_primitive_type_ = true; return *this; @@ -162,6 +227,11 @@ class Layout { return *this; } + Equal& IgnoreSplitConfigs() { + ignore_split_configs_ = true; + return *this; + } + Equal& IgnorePhysicalShape() { ignore_physical_shape_ = true; return *this; @@ -178,15 +248,18 @@ class Layout { .IgnorePointerPrimitiveType() .IgnoreMemorySpace() .IgnorePhysicalShape() - .IgnoreElementSize(); + .IgnoreElementSize() + .IgnoreTailPaddingAlignmentInElements(); } private: bool ignore_tiles_ = false; + bool ignore_tail_padding_alignment_in_elements_ = false; bool ignore_element_size_ = false; bool ignore_index_primitive_type_ = false; bool ignore_pointer_primitive_type_ = false; bool ignore_memory_space_ = false; + bool ignore_split_configs_ = false; bool ignore_physical_shape_ = false; }; @@ -201,72 +274,66 @@ class Layout { // interface. // Methods for accessing the DimLevelType array. - int dim_level_types_size() const { return dim_level_types_.size(); } + int dim_level_types_size() const { return n_dim_level_types_; } DimLevelType dim_level_type(int index) const { - return dim_level_types_.at(index); + return dim_attributes_[index].dim_level_type; } Layout& set_dim_level_type(int index, DimLevelType dim_level_type) { - dim_level_types_.at(index) = dim_level_type; + dim_attributes_[index].dim_level_type = dim_level_type; return *this; } Layout& add_dim_level_type(DimLevelType dim_level_type) { - dim_level_types_.push_back(dim_level_type); + while (n_dim_level_types_ >= dim_attributes_.size()) { + dim_attributes_.push_back(DimInfo()); + } + dim_attributes_[n_dim_level_types_].dim_level_type = dim_level_type; + n_dim_level_types_++; return *this; } Layout& clear_dim_level_types() { - dim_level_types_.clear(); + n_dim_level_types_ = 0; return *this; } - absl::Span dim_level_types() const { - return dim_level_types_; - } - DimLevelTypeVector* mutable_dim_level_types() { return &dim_level_types_; } // Methods for accessing the dim_unique array. - int dim_unique_size() const { return dim_unique_.size(); } - bool dim_unique(int index) const { return dim_unique_.at(index); } + int dim_unique_size() const { return n_dim_unique_; } + bool dim_unique(int index) const { return dim_attributes_[index].dim_unique; } Layout& set_dim_unique(int index, bool unique) { - dim_unique_.at(index) = unique; + dim_attributes_[index].dim_unique = unique; return *this; } Layout& add_dim_unique(bool unique) { - dim_unique_.push_back(unique); - return *this; - } - Layout& clear_dim_unique() { - dim_unique_.clear(); + while (n_dim_unique_ >= dim_attributes_.size()) { + dim_attributes_.push_back(DimInfo()); + } + dim_attributes_[n_dim_unique_].dim_unique = unique; + n_dim_unique_++; return *this; } - absl::Span dim_unique() const { return dim_unique_; } - absl::InlinedVector* mutable_dim_unique() { - return &dim_unique_; - } // Methods for accessing the dim_ordered array. - int dim_ordered_size() const { return dim_ordered_.size(); } - bool dim_ordered(int index) const { return dim_ordered_.at(index); } + int dim_ordered_size() const { return n_dim_ordered_; } + bool dim_ordered(int index) const { + return dim_attributes_[index].dim_ordered; + } Layout& set_dim_ordered(int index, bool ordered) { - dim_ordered_.at(index) = ordered; + dim_attributes_[index].dim_ordered = ordered; return *this; } Layout& add_dim_ordered(bool ordered) { - dim_ordered_.push_back(ordered); - return *this; - } - Layout& clear_dim_ordered() { - dim_ordered_.clear(); + while (n_dim_ordered_ >= dim_attributes_.size()) { + dim_attributes_.push_back(DimInfo()); + } + dim_attributes_[n_dim_ordered_].dim_ordered = ordered; + n_dim_ordered_++; return *this; } - absl::Span dim_ordered() const { return dim_ordered_; } - absl::InlinedVector* mutable_dim_ordered() { - return &dim_ordered_; - } // Methods for accessing the minor-to-major array. int minor_to_major_size() const { return minor_to_major_.size(); } - int64_t minor_to_major(int index) const { return minor_to_major_.at(index); } + int64_t minor_to_major(int index) const { return minor_to_major_[index]; } Layout& set_minor_to_major(int index, int64_t value) { - minor_to_major_.at(index) = value; + minor_to_major_[index] = value; return *this; } Layout& add_minor_to_major(int64_t value) { @@ -286,8 +353,8 @@ class Layout { // Methods for accessing the tile field. int64_t tiles_size() const { return tiles_.size(); } - const Tile& tiles(int index) const { return tiles_.at(index); } - Tile* mutable_tiles(int index) { return &tiles_.at(index); } + const Tile& tiles(int index) const { return tiles_[index]; } + Tile* mutable_tiles(int index) { return &tiles_[index]; } Tile* add_tiles() { tiles_.push_back(Tile()); return &tiles_.back(); @@ -305,6 +372,14 @@ class Layout { return *this; } + int64_t tail_padding_alignment_in_elements() const { + return tail_padding_alignment_in_elements_; + } + Layout& set_tail_padding_alignment_in_elements(int64_t value) { + tail_padding_alignment_in_elements_ = value; + return *this; + } + PrimitiveType index_primitive_type() const { return index_primitive_type_; } Layout& set_index_primitive_type(PrimitiveType value) { index_primitive_type_ = value; @@ -321,12 +396,27 @@ class Layout { static constexpr int64_t kDefaultMemorySpace = 0; static constexpr int64_t kGenericFastMemorySpace = 1; + static constexpr int64_t kHostMemorySpace = 5; int64_t memory_space() const { return memory_space_; } Layout& set_memory_space(int64_t value) { memory_space_ = value; return *this; } + int split_configs_size() const { return split_configs_.size(); } + const SplitConfig& split_configs(int index) const { + return split_configs_.at(index); + } + SplitConfig* mutable_split_configs(int index) { + return &split_configs_.at(index); + } + Layout& add_split_configs(const SplitConfig& split_config) { + split_configs_.push_back(split_config); + return *this; + } + void clear_split_configs() { split_configs_.clear(); } + absl::Span split_configs() const { return split_configs_; } + // Methods for accessing the physical shape. bool has_physical_shape() const { return physical_shape_ != nullptr; } const Shape& physical_shape() const { @@ -354,17 +444,37 @@ class Layout { friend H AbslHashValue(H h, const Layout& l) { return H::combine(std::move(h), l.minor_to_major_, l.tiles_, l.element_size_in_bits_, l.index_primitive_type_, - l.pointer_primitive_type_, l.memory_space_); + l.pointer_primitive_type_, l.memory_space_, + l.split_configs_, l.tail_padding_alignment_in_elements_); } private: - // The list of dimension level types, indicating the method that will be used - // to represent each dimension of the array. - DimLevelTypeVector dim_level_types_; + // We store a single inlined vector to hold + struct DimInfo { + DimInfo() + : dim_level_type(DIM_DENSE), dim_unique(false), dim_ordered(false) {} + + DimLevelType dim_level_type : 6; + bool dim_unique : 1; + bool dim_ordered : 1; + }; + absl::InlinedVector dim_attributes_; - // Whether each DimLevelType is unique and ordered. - absl::InlinedVector dim_unique_; - absl::InlinedVector dim_ordered_; + uint8_t n_dim_level_types_ = 0; + uint8_t n_dim_unique_ = 0; + uint8_t n_dim_ordered_ = 0; + + // The primitive type to use for sparse array indices and pointers. Each of + // these must either be INVALID, or an unsigned integer type. + PrimitiveType index_primitive_type_ : 8; + PrimitiveType pointer_primitive_type_ : 8; + + // The assigned memory space. + int8_t memory_space_ = 0; + + // The number of bits used to store an individual array element. + // When the value is 0, default to ShapeUtil::ByteSizeOfPrimitiveType. + int64_t element_size_in_bits_ = 0; // A map from physical dimension numbers to logical dimension numbers. // The first element is the most minor physical dimension (fastest varying @@ -382,17 +492,18 @@ class Layout { // The tiles used in tiling-based layout. TileVector tiles_; - // The primitive type to use for sparse array indices and pointers. Each of - // these must either be INVALID, or an unsigned integer type. - PrimitiveType index_primitive_type_ = PRIMITIVE_TYPE_INVALID; - PrimitiveType pointer_primitive_type_ = PRIMITIVE_TYPE_INVALID; - - // The number of bits used to store an individual array element. - // When the value is 0, default to ShapeUtil::ByteSizeOfPrimitiveType. - int64_t element_size_in_bits_ = 0; - - // The assigned memory space. - int64_t memory_space_ = 0; + // The split configurations of the shape, which describes how the storage of + // the tensor is split between different physical memories. + absl::InlinedVector split_configs_; + + // The shape is padded at the end to multiple of, in terms of number of + // elements. This is useful when tiling does not bring the shape to certain + // desired granules. Tiling effectively pads/reshapes/transposes the shape + // to another shape. This field pads the total number of elements of that + // new shape to a multiple of certain number of elements. This is useful such + // as we want a layout which does not tile the data but still requires it to + // be padded to certain number of elements. + int64_t tail_padding_alignment_in_elements_ = 1; // The physical on-device shape used to represent a sparse array. std::unique_ptr physical_shape_; diff --git a/xla/layout_test.cc b/xla/layout_test.cc index 85675a39d7a42..46a13cf421b0e 100644 --- a/xla/layout_test.cc +++ b/xla/layout_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,13 +37,19 @@ TEST_F(LayoutTest, ToString) { .ToString(), "{3,2,1,0:T(42,123)(4,5)}"); EXPECT_EQ(Layout({3, 2, 1, 0}, {}, {}, {}, {Tile({42, 123}), Tile({4, 5})}) + .set_tail_padding_alignment_in_elements(100) .set_element_size_in_bits(42) .ToString(), - "{3,2,1,0:T(42,123)(4,5)E(42)}"); + "{3,2,1,0:T(42,123)(4,5)L(100)E(42)}"); EXPECT_EQ(Layout({3, 2, 1, 0}, {}, {}, {}, {Tile({42, 123}), Tile({4, 5})}) .set_memory_space(3) .ToString(), "{3,2,1,0:T(42,123)(4,5)S(3)}"); + EXPECT_EQ(Layout({0, 1}, {}, {}, {}, {Tile({123})}) + .add_split_configs(SplitConfig(0, {3})) + .add_split_configs(SplitConfig(1, {0, 4})) + .ToString(), + "{0,1:T(123)SC(0:3)(1:0,4)}"); } TEST_F(LayoutTest, StreamOut) { @@ -83,19 +89,26 @@ TEST_F(LayoutTest, Equality) { Layout({0, 1, 2}).set_memory_space(3)); EXPECT_FALSE(Layout::Equal()(Layout({0, 1, 2}, {}, {}, {}, {Tile({42, 44})}), Layout({0, 1, 2}))); + EXPECT_EQ(Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {2})), + Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {2}))); + EXPECT_NE(Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {2})), + Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {3}))); EXPECT_TRUE(Layout::Equal().IgnoreTiles()( Layout({0, 1, 2}, {}, {}, {}, {Tile({42, 44})}), Layout({0, 1, 2}))); - EXPECT_FALSE( - Layout::Equal()(Layout({0, 1, 2}, {}, {}, {}, {}, PRIMITIVE_TYPE_INVALID, - PRIMITIVE_TYPE_INVALID, 32), - Layout({0, 1, 2}, {}, {}, {}, {}, PRIMITIVE_TYPE_INVALID, - PRIMITIVE_TYPE_INVALID, 1))); + EXPECT_FALSE(Layout::Equal()( + Layout({0, 1, 2}, {}, {}, {}, {}, 1, PRIMITIVE_TYPE_INVALID, + PRIMITIVE_TYPE_INVALID, 32), + Layout({0, 1, 2}, {}, {}, {}, {}, 1, PRIMITIVE_TYPE_INVALID, + PRIMITIVE_TYPE_INVALID, 1))); EXPECT_TRUE(Layout::Equal().IgnoreElementSize()( Layout({0, 1, 2}).set_element_size_in_bits(32), Layout({0, 1, 2}).set_element_size_in_bits(1))); EXPECT_TRUE(Layout::Equal().IgnoreMemorySpace()( Layout({0, 1, 2}).set_memory_space(1), Layout({0, 1, 2}).set_memory_space(3))); + EXPECT_TRUE(Layout::Equal().IgnoreSplitConfigs()( + Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {2})), + Layout({0, 1, 2}).add_split_configs(SplitConfig(0, {3})))); } TEST_F(LayoutTest, LayoutToFromProto) { @@ -111,9 +124,12 @@ TEST_F(LayoutTest, LayoutToFromProto) { Layout({3, 2, 1, 0}, {}, {}, {}, {Tile({42, 123}), Tile({4, 5})})); expect_unchanged(Layout({1, 0}, {DIM_DENSE, DIM_COMPRESSED}, {}, {}, {})); expect_unchanged( - Layout({1, 0}, {DIM_DENSE, DIM_COMPRESSED}, {}, {}, {}, - PRIMITIVE_TYPE_INVALID, PRIMITIVE_TYPE_INVALID, 0, 0, + Layout({1, 0}, {DIM_DENSE, DIM_COMPRESSED}, {}, {}, {}, 1, + PRIMITIVE_TYPE_INVALID, PRIMITIVE_TYPE_INVALID, 0, 0, {}, std::make_unique(ShapeUtil::MakeShape(S32, {10, 10})))); + expect_unchanged(Layout({0, 1}, {}, {}, {}, {Tile({123})}) + .add_split_configs(SplitConfig(0, {3})) + .add_split_configs(SplitConfig(1, {0, 4}))); } } // namespace diff --git a/xla/layout_util.cc b/xla/layout_util.cc index c4c2c5ece4eed..fbbf02ad9f0ed 100644 --- a/xla/layout_util.cc +++ b/xla/layout_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -66,9 +66,11 @@ absl::string_view BoolToString(bool b) { return b ? "true" : "false"; } absl::Span minor_to_major, absl::Span dim_level_types, absl::Span dim_unique, absl::Span dim_ordered, - absl::Span tiles, PrimitiveType index_primitive_type, - PrimitiveType pointer_primitive_type, int64_t element_size_in_bits, - int64_t memory_space, std::optional physical_shape, + absl::Span tiles, int64_t tail_padding_alignment_in_elements, + PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, + int64_t element_size_in_bits, int64_t memory_space, + absl::Span split_configs, + std::optional physical_shape, int64_t dynamic_shape_metadata_prefix_bytes) { Layout layout; for (int64_t dimension_number : minor_to_major) { @@ -94,10 +96,15 @@ absl::string_view BoolToString(bool b) { return b ? "true" : "false"; } } *layout.add_tiles() = tile; } + layout.set_tail_padding_alignment_in_elements( + tail_padding_alignment_in_elements); layout.set_index_primitive_type(index_primitive_type); layout.set_pointer_primitive_type(pointer_primitive_type); layout.set_element_size_in_bits(element_size_in_bits); layout.set_memory_space(memory_space); + for (const SplitConfig& split_config : split_configs) { + layout.add_split_configs(split_config); + } if (physical_shape != std::nullopt) { *layout.mutable_physical_shape() = *std::move(physical_shape); } @@ -252,7 +259,8 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector dimensions_in_layout(shape.rank(), false); + absl::InlinedVector dimensions_in_layout(shape.rank(), + false); for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dim = layout.minor_to_major(i); if (dim < 0 || dim >= shape.rank()) { @@ -271,13 +279,17 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { dimensions_in_layout[dim] = true; } - if (!layout.dim_level_types().empty()) { - if (layout.dim_level_types().size() != shape.rank()) { + if (layout.dim_level_types_size() > 0) { + if (layout.dim_level_types_size() != shape.rank()) { + std::vector dim_level_types(layout.dim_level_types_size()); + for (int i = 0; i < dim_level_types.size(); i++) { + dim_level_types[i] = layout.dim_level_type(i); + } return InvalidArgument( "layout dim_level_types field contains %d elements, but shape is " "rank %d: {%s}; shape: %s", layout.dim_level_types_size(), shape.rank(), - absl::StrJoin(layout.dim_level_types(), ", ", + absl::StrJoin(dim_level_types, ", ", [](std::string* out, DimLevelType dim_level_type) { absl::StrAppend(out, DimLevelType_Name(dim_level_type)); @@ -286,13 +298,17 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } } - if (!layout.dim_unique().empty()) { - if (layout.dim_unique().size() != shape.rank()) { + if (layout.dim_unique_size() > 0) { + if (layout.dim_unique_size() != shape.rank()) { + std::vector dim_unique(layout.dim_unique_size()); + for (int i = 0; i < dim_unique.size(); i++) { + dim_unique[i] = layout.dim_unique(i); + } return InvalidArgument( "layout dim_unique field contains %d elements, but shape is " "rank %d: {%s}; shape: %s", layout.dim_unique_size(), shape.rank(), - absl::StrJoin(layout.dim_unique(), ", ", + absl::StrJoin(dim_unique, ", ", [](std::string* out, bool dim_unique) { absl::StrAppend(out, BoolToString(dim_unique)); }), @@ -300,20 +316,30 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } } - if (!layout.dim_ordered().empty()) { - if (layout.dim_ordered().size() != shape.rank()) { + if (layout.dim_ordered_size() > 0) { + if (layout.dim_ordered_size() != shape.rank()) { + std::vector dim_ordered(layout.dim_ordered_size()); + for (int i = 0; i < dim_ordered.size(); i++) { + dim_ordered[i] = layout.dim_ordered(i); + } return InvalidArgument( - "layout dim_unique field contains %d elements, but shape is " + "layout dim_ordered field contains %d elements, but shape is " "rank %d: {%s}; shape: %s", layout.dim_ordered_size(), shape.rank(), - absl::StrJoin(layout.dim_unique(), ", ", - [](std::string* out, bool dim_unique) { - absl::StrAppend(out, BoolToString(dim_unique)); + absl::StrJoin(dim_ordered, ", ", + [](std::string* out, bool dim_ordered) { + absl::StrAppend(out, BoolToString(dim_ordered)); }), shape.ShortDebugString()); } } + if (layout.tail_padding_alignment_in_elements() <= 0) { + return InvalidArgument( + "layout tail_padding_alignment_in_elements field is <= 0: {%d}", + layout.tail_padding_alignment_in_elements()); + } + if (LayoutUtil::IsSparse(layout)) { if (layout.tiles_size() > 0) { return InvalidArgument( @@ -324,7 +350,7 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(layout.physical_shape())); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( layout.physical_shape(), - [&](const Shape& subshape, const ShapeIndex& index) { + [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { if (subshape.has_layout() && subshape.layout().has_physical_shape()) { return InvalidArgument( @@ -388,6 +414,11 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } } + if (layout.element_size_in_bits() < 0) { + return InvalidArgument("layout element_size_in_bits field is negative: %d", + layout.element_size_in_bits()); + } + return OkStatus(); } @@ -439,9 +470,10 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } /* static */ bool LayoutUtil::IsDense(const Layout& layout) { - return absl::c_all_of( - layout.dim_level_types(), - [](DimLevelType dim_level_type) { return dim_level_type == DIM_DENSE; }); + for (int i = 0; i < layout.dim_level_types_size(); i++) { + if (layout.dim_level_type(i) != DIM_DENSE) return false; + } + return true; } /* static */ bool LayoutUtil::IsSparse(const Layout& layout) { @@ -449,24 +481,28 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } /* static */ bool LayoutUtil::IsCOO(const Layout& layout) { - return !layout.dim_level_types().empty() && - layout.dim_level_type(0) == DIM_COMPRESSED && - absl::c_all_of(layout.dim_level_types().subspan(1), - [](DimLevelType dim_level_type) { - return dim_level_type == DIM_SINGLETON; - }); + if ((layout.dim_level_types_size() == 0) || + (layout.dim_level_type(0) != DIM_COMPRESSED)) { + return false; + } + for (int i = 1; i < layout.dim_level_types_size(); i++) { + if (layout.dim_level_type(i) != DIM_SINGLETON) return false; + } + return true; } /* static */ bool LayoutUtil::IsCSR(const Layout& layout) { return IsMonotonicWithDim0Major(layout) && - layout.dim_level_types() == - absl::Span{DIM_DENSE, DIM_COMPRESSED}; + (layout.dim_level_types_size() == 2) && + (layout.dim_level_type(0) == DIM_DENSE) && + (layout.dim_level_type(1) == DIM_COMPRESSED); } /* static */ bool LayoutUtil::IsCSC(const Layout& layout) { return IsMonotonicWithDim0Minor(layout) && - layout.dim_level_types() == - absl::Span{DIM_DENSE, DIM_COMPRESSED}; + (layout.dim_level_types_size() == 2) && + (layout.dim_level_type(0) == DIM_DENSE) && + (layout.dim_level_type(1) == DIM_COMPRESSED); } /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { @@ -741,4 +777,40 @@ bool LayoutUtil::ValidateDimLevel(DimLevelType dim_level_type, bool dim_unique, return true; } +/*static*/ int64_t LayoutUtil::MaxSplitSize(const Shape& shape, int64_t dim) { + CHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + if (!shape.has_layout()) { + return shape.dimensions(dim); + } + const SplitConfig* split_config = nullptr; + for (const SplitConfig& config : shape.layout().split_configs()) { + if (Major(shape.layout(), config.dimension()) == dim) { + split_config = &config; + break; + } + } + if (split_config != nullptr) { + int64_t max_split_size = 0; + int64_t last_split_index = 0; + for (int split_index : split_config->split_indices()) { + int64_t split_size = split_index - last_split_index; + max_split_size = std::max(split_size, max_split_size); + last_split_index = split_index; + } + max_split_size = + std::max(max_split_size, shape.dimensions(dim) - last_split_index); + return max_split_size; + } + return shape.dimensions(dim); +} + +/*static*/ int64_t LayoutUtil::MaxElementsInPerSplit(const Shape& shape) { + CHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + int64_t max_elements_in = 1; + for (int dim = 0; dim < shape.rank(); ++dim) { + max_elements_in *= MaxSplitSize(shape, dim); + } + return max_elements_in; +} + } // namespace xla diff --git a/xla/layout_util.h b/xla/layout_util.h index 545d365975404..fca4ef169d2b7 100644 --- a/xla/layout_util.h +++ b/xla/layout_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,9 +47,11 @@ class LayoutUtil { absl::Span dim_unique = {}, absl::Span dim_ordered = {}, absl::Span tiles = {}, + int64_t tail_padding_alignment_in_elements = 1, PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID, int64_t element_size_in_bits = 0, int64_t memory_space = 0, + absl::Span split_configs = {}, std::optional physical_shape = std::nullopt, int64_t dynamic_shape_metadata_prefix_bytes = 0); @@ -285,6 +287,19 @@ class LayoutUtil { static bool ByteStridesIsMajorToMinor(absl::Span byte_strides, absl::Span dims, PrimitiveType element_type); + + // The max size of the split in the given dimension. If the layout doesn't + // have a split config in the given dimension, the value returned from this + // function is equal to the Shape::dimensions(). If there is a split config in + // the given dimension, we then find the size of the largest split in that + // dimension. + static int64_t MaxSplitSize(const Shape& shape, int64_t dim); + + // This function is analogous to ShapeUtil::ElementsIn, except we use the max + // split sizes for each dimension to calculate the max number of elements + // stored in a particular split. This can be useful for calculating how much + // memory to allocate in each of the memories. + static int64_t MaxElementsInPerSplit(const Shape& shape); }; } // namespace xla diff --git a/xla/layout_util_test.cc b/xla/layout_util_test.cc index 89b56cf5fc0b7..ed2f6ff479d7e 100644 --- a/xla/layout_util_test.cc +++ b/xla/layout_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -435,8 +435,9 @@ TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) { TEST_F(LayoutUtilTest, ValidateLayout_InvalidDimLevelTypes) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - *shape.mutable_layout()->mutable_dim_level_types() = {DIM_DENSE, DIM_DENSE, - DIM_DENSE}; + shape.mutable_layout()->add_dim_level_type(DIM_DENSE); + shape.mutable_layout()->add_dim_level_type(DIM_DENSE); + shape.mutable_layout()->add_dim_level_type(DIM_DENSE); auto status = LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); EXPECT_FALSE(status.ok()); @@ -592,5 +593,27 @@ TEST_F(LayoutUtilTest, HasCustomElementSizeInBits) { EXPECT_TRUE(LayoutUtil::HasCustomElementSizeInBits(shape)); } +TEST_F(LayoutUtilTest, MaxSplitSize) { + Shape shape = ShapeUtil::MakeShape(F32, {150, 200, 100}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}) + .add_split_configs(SplitConfig(0, {30})) + .add_split_configs(SplitConfig(1, {40, 130})); + + EXPECT_EQ(LayoutUtil::MaxSplitSize(shape, 0), 150); + EXPECT_EQ(LayoutUtil::MaxSplitSize(shape, 1), 90); + EXPECT_EQ(LayoutUtil::MaxSplitSize(shape, 2), 70); +} + +TEST_F(LayoutUtilTest, MaxElementsInPerSplit) { + Shape shape = ShapeUtil::MakeShape(F32, {150, 200, 100}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + EXPECT_EQ(LayoutUtil::MaxElementsInPerSplit(shape), 150 * 200 * 100); + + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}) + .add_split_configs(SplitConfig(0, {30})) + .add_split_configs(SplitConfig(1, {40, 130})); + EXPECT_EQ(LayoutUtil::MaxElementsInPerSplit(shape), 150 * 90 * 70); +} + } // namespace } // namespace xla diff --git a/xla/lazy.h b/xla/lazy.h index e296a41403a1c..8e960dd3d7d3a 100644 --- a/xla/lazy.h +++ b/xla/lazy.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ class Lazy { explicit Lazy(absl::AnyInvocable func) : maybe_value_(std::move(func)) {} + bool has_value() const { return std::holds_alternative(maybe_value_); } + const T& get() const { if (!std::holds_alternative(maybe_value_)) { maybe_value_ = diff --git a/xla/lit.bzl b/xla/lit.bzl new file mode 100644 index 0000000000000..57d25fff26124 --- /dev/null +++ b/xla/lit.bzl @@ -0,0 +1,301 @@ +"""Helper rules for writing LIT tests.""" + +load("@bazel_skylib//lib:paths.bzl", "paths") + +def enforce_glob(files, **kwargs): + """A utility to enforce that a list matches a glob expression. + + Note that the comparison is done in an order-independent fashion. + + Args: + files: a list that is expected to contain the same files as the + specified glob expression. + **kwargs: keyword arguments forwarded to the glob. + + Returns: + files. The input argument unchanged + """ + glob_result = native.glob(**kwargs) + + # glob returns a sorted list. + if sorted(files) != sorted(glob_result): + missing = [k for k in glob_result if k not in files] + extra = [k for k in files if k not in glob_result] + expected_formatted = "\n".join(['"{}",'.format(file) for file in glob_result]) + fail(("Error in enforce_glob." + + "\nExpected {}." + + "\nGot {}." + + "\nMissing {}." + + "\nExtra {}" + + "\nPaste this into the first enforce_glob argument:" + + "\n{}").format( + glob_result, + files, + missing, + extra, + expected_formatted, + )) + return files + +def lit_test_suite( + name, + srcs, + cfg, + tools = None, + args = None, + data = None, + visibility = None, + env = None, + timeout = None, + default_tags = None, + tags_override = None, + **kwargs): + """Creates one lit test per source file and a test suite that bundles them. + + Args: + name: string. the name of the generated test suite. + srcs: label_list. The files which contain the lit tests. + cfg: label. The lit config file. It must list the file extension of + the files in `srcs` in config.suffixes and must be in a parent directory + of `srcs`. + tools: label list. Tools invoked in the lit RUN lines. These binaries will + be symlinked into a directory which is on the path. They must therefore + have unique basenames. + args: string list. Additional arguments to pass to lit. Note that the test + file, `-v`, and a `--path` argument for the directory to which `tools` + are symlinked are added automatically. + data: label list. Additional data dependencies of the test. Note that + targets in `cfg` and `tools`, as well as their data dependencies, are + added automatically. + visibility: visibility of the generated test targets and test suite. + env: string_dict. Environment variables available during test execution. + See the common Bazel test attribute. + timeout: timeout argument passed to the individual tests. + default_tags: string list. Tags applied to all tests. + tags_override: string_dict. Tags applied in addition to only select tests. + **kwargs: additional keyword arguments to pass to all generated rules. + + See https://llvm.org/docs/CommandGuide/lit.html for details on lit + """ + # If there are kwargs that need to be passed to only some of the generated + # rules, they should be extracted into separate named arguments. + + args = args or [] + data = data or [] + tools = tools or [] + default_tags = default_tags or [] + tags_override = tags_override or {} + + tests = [] + for test_file in srcs: + # It's generally good practice to prefix any generated names with the + # macro name, but it's also nice to have the test name just match the + # file name. + test_name = "%s.test" % (test_file) + tests.append(test_name) + lit_test( + name = test_name, + test_file = test_file, + cfg = cfg, + tools = tools, + args = args, + data = data, + visibility = visibility, + env = env, + timeout = timeout, + tags = default_tags + tags_override.get(test_file, []), + **kwargs + ) + + native.test_suite( + name = name, + tests = tests, + **kwargs + ) + +def lit_test( + name, + test_file, + cfg, + tools = None, + args = None, + data = None, + visibility = None, + env = None, + timeout = None, + **kwargs): + """Runs a single test file with LLVM's lit tool. + + Args: + name: string. the name of the generated test target. + test_file: label. The file on which to run lit. + cfg: label. The lit config file. It must list the file extension of + `test_file` in config.suffixes and must be in a parent directory of + `test_file`. + tools: label list. Tools invoked in the lit RUN lines. These binaries will + be symlinked into a directory which is on the path. They must therefore + have unique basenames. + args: string list. Additional arguments to pass to lit. Note that the test + file, `-v`, and a `--path` argument for the directory to which `tools` + are symlinked are added automatically. + data: label list. Additional data dependencies of the test. Note that + targets in `cfg` and `tools`, as well as their data dependencies, are + added automatically. + visibility: visibility of the generated test target. + env: string_dict. Environment variables available during test execution. + See the common Bazel test attribute. + timeout: bazel test timeout string, as per common bazel definitions. + **kwargs: additional keyword arguments to pass to all generated rules. + + See https://llvm.org/docs/CommandGuide/lit.html for details on lit + """ + args = args or [] + data = data or [] + tools = tools or [] + env = env or {} + + tools_on_path_target_name = "_{}_tools_on_path".format(name) + + llvm_symbolizer = "@llvm-project//llvm:llvm-symbolizer" + if llvm_symbolizer not in tools: + tools.append(llvm_symbolizer) + + filecheck_env_var = "FILECHECK_OPTS" + if filecheck_env_var not in env: + env[filecheck_env_var] = "--enable-var-scope" + + bin_dir = paths.join( + native.package_name(), + tools_on_path_target_name, + "lit_bin", + ) + + _tools_on_path( + name = tools_on_path_target_name, + testonly = True, + srcs = tools, + bin_dir = bin_dir, + visibility = ["//visibility:private"], + **kwargs + ) + + native_test( + name = name, + src = "@llvm-project//llvm:lit", + args = [ + "-a", + "--path", + bin_dir, + "$(location {})".format(test_file), + ] + args, + data = [ + "@llvm-project//llvm:lit", + test_file, + + # TODO(cheshire): Config is not passed properly when it's not + # called lit.cfg.py + cfg, + tools_on_path_target_name, + ] + data, + visibility = visibility, + env = env, + timeout = timeout, + **kwargs + ) + +def _shared_impl(ctx): + out = ctx.attr.out + if not out: + out = ctx.attr.name + output = ctx.actions.declare_file(out) + ctx.actions.symlink( + target_file = ctx.executable.src, + output = output, + is_executable = True, + ) + + runfiles = ctx.runfiles(files = ctx.files.data) + + # For Bazel 4.x support. Drop when Bazel 4.x is no longer supported + to_merge = ([d[DefaultInfo].default_runfiles for d in ctx.attr.data] + + [ctx.attr.src[DefaultInfo].default_runfiles]) + if hasattr(runfiles, "merge_all"): + runfiles = runfiles.merge_all(to_merge) + else: + for m in to_merge: + runfiles = runfiles.merge(m) + return DefaultInfo( + executable = output, + files = depset([output]), + runfiles = runfiles, + ) + +def _native_test_impl(ctx): + default_info = _shared_impl(ctx) + return [default_info, testing.TestEnvironment(ctx.attr.env)] + +def _tools_on_path_impl(ctx): + runfiles = ctx.runfiles() + + # For Bazel 4.x support. Drop when Bazel 4.x is no longer supported + to_merge = [d[DefaultInfo].default_runfiles for d in ctx.attr.srcs] + if hasattr(runfiles, "merge_all"): + runfiles = runfiles.merge_all(to_merge) + else: + for m in to_merge: + runfiles = runfiles.merge(m) + + runfiles_symlinks = {} + + for src in ctx.attr.srcs: + exe = src[DefaultInfo].files_to_run.executable + if not exe: + fail("All targets used as tools by lit tests must have exactly one" + + " executable, but {} has none".format(src)) + bin_path = paths.join(ctx.attr.bin_dir, exe.basename) + if bin_path in runfiles_symlinks: + fail("All tools used by lit tests must have unique basenames, as" + + " they are added to the path." + + " {} and {} conflict".format(runfiles_symlinks[bin_path], exe)) + runfiles_symlinks[bin_path] = exe + + return [ + DefaultInfo(runfiles = ctx.runfiles( + symlinks = runfiles_symlinks, + ).merge(runfiles)), + ] + +_tools_on_path = rule( + _tools_on_path_impl, + attrs = { + "srcs": attr.label_list(allow_files = True, mandatory = True), + "bin_dir": attr.string(mandatory = True), + }, + doc = "Symlinks srcs into a single lit_bin directory. All basenames must be unique.", +) + +# We have to manually set "env" on the test rule because the builtin one is only +# available in native rules. See +# https://docs.bazel.build/versions/main/be/common-definitions.html#test.env +_TEST_ATTRS = { + "src": attr.label( + executable = True, + allow_files = True, + mandatory = True, + cfg = "target", + ), + "data": attr.label_list(allow_files = True), + # "out" is attr.string instead of attr.output, so that it is select()'able. + "out": attr.string(), + "env": attr.string_dict( + doc = "Mirrors the common env attribute that otherwise is" + + " only available on native rules. See" + + " https://docs.bazel.build/versions/main/be/common-definitions.html#test.env", + ), +} + +native_test = rule( + implementation = _native_test_impl, + attrs = _TEST_ATTRS, + test = True, +) diff --git a/xla/lit.cfg.py b/xla/lit.cfg.py new file mode 100644 index 0000000000000..ebb8bd434f382 --- /dev/null +++ b/xla/lit.cfg.py @@ -0,0 +1,50 @@ +# Copyright 2019 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lit runner configuration.""" + +import os +import sys +import tempfile + +import lit.formats + + +# pylint: disable=undefined-variable + + +config.name = "XLA" +config.suffixes = [".cc", ".hlo", ".hlotxt", ".json", ".mlir", ".pbtxt", ".py"] + +config.test_format = lit.formats.ShTest(execute_external=True) + + +# Passthrough XLA_FLAGS. +config.environment["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + +# Use the most preferred temp directory. +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) + +config.substitutions.extend([ + ("%PYTHON", os.getenv("PYTHON", sys.executable)), +]) + +# Include additional substitutions that may be defined via params +config.substitutions.extend( + ("%%{%s}" % key, val) + for key, val in lit_config.params.items() +) diff --git a/xla/literal.cc b/xla/literal.cc index 3749bcdbaed2f..62f42791577ea 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,17 +49,17 @@ limitations under the License. #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" +#include "xla/tsl/util/byte_swap_array.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" #include "tsl/platform/errors.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/mem.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/byte_swap_array.h" namespace xla { namespace { @@ -120,12 +120,9 @@ const Shape& ScalarShapeImpl() { } const Shape& ScalarShape(PrimitiveType type) { - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> const Shape& { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - return ScalarShapeImpl(); - } - LOG(FATAL) << "Unhandled primitive type " << type; + return ScalarShapeImpl(); }, type); } @@ -252,6 +249,21 @@ Literal::Literal() : Literal(NilShape()) {} Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} +void Literal::SetShape(const Shape& shape) { + Shape shape_storage; + const Shape* shape_ptr = &shape; + if (LayoutUtil::HasCustomElementSizeInBits(shape)) { + shape_storage = shape; + shape_storage.mutable_layout()->set_element_size_in_bits(0); + shape_ptr = &shape_storage; + } + if (const Shape* intered_shape_ptr = TryInternShape(*shape_ptr)) { + shape_ = intered_shape_ptr; + } else { + shape_ = std::make_unique(*shape_ptr); + } +} + void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays, ArrayValueState leaf_array_value_state) { if (shape.IsTuple()) { @@ -279,16 +291,9 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays, Literal::Literal(const Shape& shape, bool allocate_arrays, ArrayValueState leaf_array_value_state) : MutableLiteralBase() { - if (const Shape* intered_shape_ptr = TryInternShape(shape)) { - shape_ = intered_shape_ptr; - } else { - shape_ = std::make_unique(shape); - } + SetShape(shape); CHECK(leaf_array_value_state != ArrayValueState::kKnown || LayoutUtil::HasLayout(*shape_)); - // Currently we do nibble packing/unpacking in TPU host/device transfer. - CHECK(!LayoutUtil::HasCustomElementSizeInBits(*shape_)) - << "Literal does not support layouts with custom bit size: " << *shape_; root_piece_.set_subshape(shape_.get()); CHECK(&root_piece_.subshape() == shape_.get()); @@ -351,24 +356,38 @@ int32_t LiteralBase::GetDynamicSize(int64_t dim_index, } std::optional LiteralBase::GetFirstInteger() const { - return primitive_util::PrimitiveTypeSwitch>( + if (!primitive_util::IsIntegralType(shape().element_type())) { + return std::nullopt; + } + return primitive_util::IntegralTypeSwitch>( [&](auto primitive_type_constant) -> std::optional { - if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - auto first_element = GetFirstElement(); - if constexpr (std::is_same_v) { - int64_t v = static_cast(first_element); - if (v < 0) { - return std::nullopt; - } + using NativeT = NativeTypeOf; + auto first_element = GetFirstElement(); + if constexpr (std::is_same_v) { + int64_t v = static_cast(first_element); + if (v < 0) { + return std::nullopt; } - return first_element; } - return std::nullopt; + return first_element; }, shape().element_type()); } +absl::Status LiteralBase::SerializeToString(std::string* output) const { + ShapeProto shape_proto = shape().ToProto(); + TF_ASSIGN_OR_RETURN(int64_t size, + ShapeUtil::SerializedSizeWithProto(shape(), shape_proto)); + output->resize(size); + return SerializeWithShapeProto(shape_proto, output->data()); +} + +absl::StatusOr LiteralBase::SerializeAsString() const { + std::string result; + TF_RETURN_IF_ERROR(SerializeToString(&result)); + return std::move(result); +} + template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, @@ -449,7 +468,7 @@ void MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } } -/* static */ StatusOr MutableLiteralBase::CreateFromProto( +/* static */ absl::StatusOr MutableLiteralBase::CreateFromProto( const LiteralProto& proto, bool prohibit_empty_literal) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); @@ -471,7 +490,7 @@ void MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, Literal literal(shape); TF_RETURN_IF_ERROR(literal.root_piece_.ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { + [&](const ShapeIndex& index, Piece* piece) -> absl::Status { const LiteralProto* proto_element = &proto; for (int64_t i : index) { CHECK(i < proto_element->tuple_literals_size()); @@ -671,25 +690,18 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src, memcpy(buffer(), src.buffer(), src.size_bytes_dense()); } else { std::vector origin(subshape().rank(), 0); - TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> Status { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - if (only_dynamic_bound) { - CopyElementsWithDynamicBound(src); - } else { - CopyElementsBetween(this->data(), - src.data(), subshape(), - src.subshape()); - } - return OkStatus(); + primitive_util::ArrayTypeSwitch( + [&](auto primitive_type_constant) { + using NativeT = NativeTypeOf; + if (only_dynamic_bound) { + CopyElementsWithDynamicBound(src); + } else { + CopyElementsBetween(this->data(), + src.data(), subshape(), + src.subshape()); } - return Unimplemented( - "Copying a Literal object with element type %s is not " - "implemented.", - PrimitiveType_Name(subshape().element_type())); }, - subshape().element_type())); + subshape().element_type()); } DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes()); if (subshape().is_dynamic() && src.subshape().is_dynamic()) { @@ -815,17 +827,11 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); TF_RET_CHECK(shape().rank() == dest_base.size()); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> Status { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - return CopySliceFromInternal(src_literal, src_base, - dest_base, copy_size); - } - return Unimplemented( - "Copying a slice from a Literal object with element type %d is not " - "implemented.", - shape().element_type()); + using NativeT = NativeTypeOf; + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); }, shape().element_type()); } @@ -866,7 +872,7 @@ void MutableLiteralBase::PopulateInplaceInternal( } auto init_function = [&](absl::Span indexes, - int thread_id) -> StatusOr { + int thread_id) -> absl::StatusOr { const int64_t index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); DimensionVector minor_scan_indexes(rank, 0); @@ -894,7 +900,7 @@ void MutableLiteralBase::PopulateInplaceInternal( this_shape, stride_config.base, stride_config.dimensions, stride_config.step, [&init_function]( - absl::Span indexes) -> StatusOr { + absl::Span indexes) -> absl::StatusOr { auto result_ignored = init_function(indexes, /*thread_id=*/-1); return true; }); @@ -1005,10 +1011,10 @@ Literal LiteralBase::ToStatic() const { namespace { template -StatusOr BroadcastHelper(const LiteralBase& src, - const Shape& src_shape, - const Shape& result_shape, - absl::Span dimensions) { +absl::StatusOr BroadcastHelper(const LiteralBase& src, + const Shape& src_shape, + const Shape& result_shape, + absl::Span dimensions) { for (int64_t i = 0, end = dimensions.size(); i < end; i++) { TF_RET_CHECK(src_shape.dimensions(i) == result_shape.dimensions(dimensions[i])); @@ -1073,7 +1079,7 @@ StatusOr BroadcastHelper(const LiteralBase& src, } } // anonymous namespace -StatusOr LiteralBase::Broadcast( +absl::StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { const LiteralBase& src = *this; const Shape& src_shape = shape(); @@ -1103,7 +1109,7 @@ StatusOr LiteralBase::Broadcast( } } -StatusOr LiteralBase::Reshape( +absl::StatusOr LiteralBase::Reshape( absl::Span dimensions) const { if (!LayoutUtil::IsDenseArray(shape())) { return InvalidArgument("Reshape is only supported for dense arrays."); @@ -1225,14 +1231,10 @@ Literal LiteralBase::Slice(absl::Span start_indices, LayoutUtil::MinorToMajor(shape())); ShapeUtil::CopyDynamicDimensions(&result_shape, shape()); Literal result_literal(result_shape); - primitive_util::PrimitiveTypeSwitch( + primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> void { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - return SliceInternal(*this, start_indices, result_literal); - } - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(result_shape.element_type()); + using NativeT = NativeTypeOf; + return SliceInternal(*this, start_indices, result_literal); }, result_shape.element_type()); return result_literal; @@ -1262,27 +1264,23 @@ std::string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> std::string { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - if constexpr (primitive_util::IsIntegralType( - primitive_type_constant)) { - return StrCat(Get(multi_index, shape_index)); - } - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - return RoundTripFpToString(Get(multi_index, shape_index)); - } - if constexpr (primitive_util::IsComplexType( - primitive_type_constant)) { - NativeT c = Get(multi_index, shape_index); - return StrCat("(", RoundTripFpToString(c.real()), ", ", - RoundTripFpToString(c.imag()), ")"); - } - if constexpr (primitive_type_constant == PRED) { - return Get(multi_index, shape_index) ? "true" : "false"; - } + using NativeT = NativeTypeOf; + if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { + return StrCat(Get(multi_index, shape_index)); + } + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + return RoundTripFpToString(Get(multi_index, shape_index)); + } + if constexpr (primitive_util::IsComplexType(primitive_type_constant)) { + NativeT c = Get(multi_index, shape_index); + return StrCat("(", RoundTripFpToString(c.real()), ", ", + RoundTripFpToString(c.imag()), ")"); + } + if constexpr (primitive_type_constant == PRED) { + return Get(multi_index, shape_index) ? "true" : "false"; } LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); }, @@ -1325,19 +1323,19 @@ std::optional LiteralBase::GetSumAsDouble( const Shape& s = shape(); CHECK(LayoutUtil::IsDenseArray(s)); - return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> std::optional { - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - using NativeT = NativeTypeOf; - double sum = 0.0; - auto d = root_piece().data(); - for (const int64_t idx : linear_indices) { - sum += static_cast(d[idx]); - } - return sum; + if (!primitive_util::IsFloatingPointType(s.element_type())) { + return std::nullopt; + } + + return primitive_util::FloatingPointTypeSwitch( + [&](auto primitive_type_constant) -> double { + using NativeT = NativeTypeOf; + double sum = 0.0; + auto d = root_piece().data(); + for (const int64_t idx : linear_indices) { + sum += static_cast(d[idx]); } - return std::nullopt; + return sum; }, s.element_type()); } @@ -1388,18 +1386,17 @@ Status MutableLiteralBase::SetIntegralAsS64( Status MutableLiteralBase::SetFromDouble(absl::Span multi_index, double value) { CHECK(LayoutUtil::IsDenseArray(shape())); - return primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> Status { - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - using NativeT = NativeTypeOf; - Set(multi_index, static_cast(value)); - return OkStatus(); - } - return FailedPrecondition("Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type())); + if (!primitive_util::IsFloatingPointType(shape().element_type())) { + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); + } + primitive_util::FloatingPointTypeSwitch( + [&](auto primitive_type_constant) -> void { + using NativeT = NativeTypeOf; + Set(multi_index, static_cast(value)); }, shape().element_type()); + return OkStatus(); } namespace { @@ -1631,7 +1628,7 @@ void LiteralBase::EachCellAsString( if (ShapeUtil::IsZeroElementArray(shape())) { return; } - std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + auto indices = IndexUtil::LinearIndexToMultidimensionalIndex( shape(), /*linear_index=*/0); do { per_cell(indices, GetAsString(indices)); @@ -1682,8 +1679,8 @@ void ConvertBetweenNativeTypes(absl::Span src_data, } template -void ConvertIfDestTypeMatches(const LiteralBase& src_literal, - MutableLiteralBase& dst_literal) { +Status ConvertIfDestTypeMatches(const LiteralBase& src_literal, + MutableLiteralBase& dst_literal) { DCHECK(dst_literal.shape().IsArray()); using NativeSrcT = NativeTypeOf; // Pass raw data Span/pointers to called template methods to avoid duplicating @@ -1691,31 +1688,32 @@ void ConvertIfDestTypeMatches(const LiteralBase& src_literal, auto src_data = src_literal.data(); void* dst_base = dst_literal.untyped_data(); DCHECK_EQ(src_data.size(), dst_literal.element_count()); - primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> void { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - if constexpr (kSrcType != primitive_type_constant) { - using NativeDestT = NativeTypeOf; - ConvertBetweenNativeTypes(src_data, - dst_base); - } - return; + return primitive_util::ArrayTypeSwitch( + [&](auto primitive_type_constant) -> Status { + if constexpr (primitive_util::IsComplexType(kSrcType) && + !primitive_util::IsComplexType(primitive_type_constant)) { + return Unimplemented("%s from type %s to type %s is not implemented.", + "Converting", PrimitiveType_Name(kSrcType), + PrimitiveType_Name(primitive_type_constant())); + } else if constexpr (kSrcType != primitive_type_constant) { + using NativeDestT = NativeTypeOf; + ConvertBetweenNativeTypes(src_data, + dst_base); } - // This code path is impossible to hit. - LOG(FATAL) << "Unexpected type " << dst_literal.shape().element_type(); + return OkStatus(); }, dst_literal.shape().element_type()); } -StatusOr ConvertSwitch(const LiteralBase& literal, - PrimitiveType primitive_dest_type) { +absl::StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type) { TF_RET_CHECK(LayoutUtil::IsDenseArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.Clone(); } // Source Array type requirement is ensured by IsDenseArray before. if (!primitive_util::IsArrayType(primitive_dest_type) || - primitive_util::IsComplexType(literal.shape().element_type())) { + !primitive_util::IsArrayType(literal.shape().element_type())) { return Unimplemented("%s from type %s to type %s is not implemented.", "Converting", PrimitiveType_Name(literal.shape().element_type()), @@ -1726,29 +1724,24 @@ StatusOr ConvertSwitch(const LiteralBase& literal, // duplicating it N^2 times in the conversion implementation. Literal result( ShapeUtil::ChangeElementType(literal.shape(), primitive_dest_type)); - primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> void { - if constexpr (primitive_util::IsArrayType(primitive_type_constant) && - !primitive_util::IsComplexType(primitive_type_constant)) { - ConvertIfDestTypeMatches(literal, result); - return; - } - // Unsupported conversions are checked before this switch, this path is - // not possible to hit. - LOG(FATAL) << "Unexpected type " << literal.shape().element_type(); + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](auto primitive_type_constant) -> Status { + return ConvertIfDestTypeMatches(literal, + result); }, - literal.shape().element_type()); + literal.shape().element_type())); return result; } } // namespace -StatusOr LiteralBase::Convert( +absl::StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type); } -StatusOr LiteralBase::BitcastConvert(const Shape& dest_shape) const { +absl::StatusOr LiteralBase::BitcastConvert( + const Shape& dest_shape) const { if (ShapeUtil::ByteSizeOf(dest_shape) != ShapeUtil::ByteSizeOf(shape())) { return InvalidArgument( "Can not bitcast-convert from shape %s to a shape of different size %s", @@ -1788,7 +1781,8 @@ StatusOr LiteralBase::BitcastConvert(const Shape& dest_shape) const { return out; } -StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { +absl::StatusOr LiteralBase::ConvertToShape( + const Shape& dest_shape) const { if (!dest_shape.IsTuple()) { return Convert(dest_shape.element_type()); } @@ -1858,19 +1852,23 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { CHECK(LayoutUtil::IsDenseArray(subshape())) << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(size_bytes_dense(), other.size_bytes_dense()); + if (primitive_util::Is4BitType(subshape().element_type())) { + auto one_array = buffer(); + auto two_array = other.buffer(); + for (int64_t i = 0; i < size_bytes_dense(); ++i) { + if ((one_array[i] & uint8_t{0xf}) != (two_array[i] & uint8_t{0xf})) + return false; + } + return true; + } return memcmp(buffer(), other.buffer(), size_bytes_dense()) == 0; } std::vector multi_index; - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeSrcT = NativeTypeOf; - return EqualElementsInternal(other, &multi_index); - } - LOG(FATAL) - << "Unimplemented: LiteralBase::Piece::EqualElements for type " - << PrimitiveType_Name(subshape().element_type()); + using NativeSrcT = NativeTypeOf; + return EqualElementsInternal(other, &multi_index); }, subshape().element_type()); } @@ -1947,14 +1945,11 @@ bool Literal::Piece::IsAll(const Literal& scalar) const { CHECK(LayoutUtil::IsDenseArray(subshape())) << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(subshape().element_type(), scalar.shape().element_type()); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - return AllElementsEqualValue(this->data(), - scalar.GetFirstElement()); - } - return false; + using NativeT = NativeTypeOf; + return AllElementsEqualValue(this->data(), + scalar.GetFirstElement()); }, subshape().element_type()); } @@ -1968,17 +1963,13 @@ int64_t Literal::Piece::CountAll(const Literal& scalar) const { CHECK(LayoutUtil::IsDenseArray(subshape())) << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(subshape().element_type(), scalar.shape().element_type()); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> int64_t { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - return absl::c_count_if( - this->data(), [&](NativeT elem) -> bool { - return EqualIncludingNan(elem, - scalar.GetFirstElement()); - }); - } - return 0; + using NativeT = NativeTypeOf; + return absl::c_count_if( + this->data(), [&](NativeT elem) -> bool { + return EqualIncludingNan(elem, scalar.GetFirstElement()); + }); }, subshape().element_type()); } @@ -1999,27 +1990,23 @@ bool LiteralBase::IsAll(int8_t value) const { return false; } Literal scalar(ShapeUtil::MakeScalarShape(ty)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - NativeT converted(value); - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - if (!Eigen::numext::isfinite(converted)) { - return false; - } + using NativeT = NativeTypeOf; + NativeT converted(value); + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + if (!Eigen::numext::isfinite(converted)) { + return false; } - if constexpr (!primitive_util::IsComplexType( - primitive_type_constant)) { - if (static_cast(converted) != value) { - return false; - } + } + if constexpr (!primitive_util::IsComplexType(primitive_type_constant)) { + if (static_cast(converted) != value) { + return false; } - scalar.Set({}, converted); - return root_piece().IsAll(scalar); } - return false; + scalar.Set({}, converted); + return root_piece().IsAll(scalar); }, ty); } @@ -2029,41 +2016,34 @@ bool LiteralBase::IsAllFloat(float value) const { } bool LiteralBase::IsAllFloatImpl(float value, bool round_value) const { - if (!shape().IsArray()) { + PrimitiveType ty = shape().element_type(); + if (!primitive_util::IsFloatingPointType(ty)) { return false; } - PrimitiveType ty = shape().element_type(); Literal scalar(ShapeUtil::MakeScalarShape(ty)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::FloatingPointTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - using NativeT = NativeTypeOf; - scalar.Set({}, static_cast(value)); - if (!round_value && scalar.GetAsDouble({}) != value) { - return false; - } - return root_piece().IsAll(scalar); + using NativeT = NativeTypeOf; + scalar.Set({}, static_cast(value)); + if (!round_value && scalar.GetAsDouble({}) != value) { + return false; } - return false; + return root_piece().IsAll(scalar); }, ty); } bool LiteralBase::IsAllComplex(complex64 value) const { - if (!shape().IsArray()) { + PrimitiveType ty = shape().element_type(); + if (!primitive_util::IsComplexType(ty)) { return false; } - PrimitiveType ty = shape().element_type(); Literal scalar(ShapeUtil::MakeScalarShape(ty)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ComplexTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsComplexType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - scalar.Set({}, static_cast(value)); - return root_piece().IsAll(scalar); - } - return false; + using NativeT = NativeTypeOf; + scalar.Set({}, static_cast(value)); + return root_piece().IsAll(scalar); }, ty); } @@ -2096,36 +2076,32 @@ bool LiteralBase::IsR1Iota() const { return false; } - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - const int64_t elements = ShapeUtil::ElementsIn(shape()); - for (int64_t idx = 0; idx < elements; ++idx) { - if constexpr (primitive_util::IsIntegralType( - primitive_type_constant)) { - if (static_cast(Get({idx})) != idx) { - return false; - } - } else if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - if (Get({idx}) != static_cast(idx)) { - return false; - } - } else if constexpr (primitive_util::IsComplexType( - primitive_type_constant)) { - if (Get({idx}) != NativeT(idx, 0.0f)) { - return false; - } - } else { - // pred is not iota. + using NativeT = NativeTypeOf; + const int64_t elements = ShapeUtil::ElementsIn(shape()); + for (int64_t idx = 0; idx < elements; ++idx) { + if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + if (static_cast(Get({idx})) != idx) { + return false; + } + } else if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + if (Get({idx}) != static_cast(idx)) { return false; } + } else if constexpr (primitive_util::IsComplexType( + primitive_type_constant)) { + if (Get({idx}) != NativeT(idx, 0.0f)) { + return false; + } + } else { + // pred is not iota. + return false; } - return true; } - // token, opaque, tuple, etc. are all not iota. - return false; + return true; }, shape().element_type()); } @@ -2147,27 +2123,24 @@ std::optional LiteralBase::IsR1StridedIota() const { return std::nullopt; } - return primitive_util::PrimitiveTypeSwitch>( + return primitive_util::IntegralTypeSwitch>( [&](auto primitive_type_constant) -> std::optional { - if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { - using NativeT = NativeTypeOf; + using NativeT = NativeTypeOf; - // Infer the stride as the second element (since first element is - // supposed to be zero). - const int64_t stride = static_cast(Get({1})); - if (stride == 0) { - return std::nullopt; - } + // Infer the stride as the second element (since first element is + // supposed to be zero). + const int64_t stride = static_cast(Get({1})); + if (stride == 0) { + return std::nullopt; + } - for (int64_t idx = 0; idx < elements; ++idx) { - if (static_cast(Get({idx})) != idx * stride) { - return std::nullopt; - } + for (int64_t idx = 0; idx < elements; ++idx) { + if (static_cast(Get({idx})) != idx * stride) { + return std::nullopt; } - - return stride; } - return std::nullopt; + + return stride; }, shape().element_type()); } @@ -2175,13 +2148,10 @@ std::optional LiteralBase::IsR1StridedIota() const { bool LiteralBase::IsZero(absl::Span indices) const { CHECK(LayoutUtil::IsDenseArray(shape())) << __func__ << " is only supported for dense arrays: " << shape(); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> bool { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = NativeTypeOf; - return Get(indices) == NativeT{0}; - } - LOG(FATAL) << "Input literal must be an array."; + using NativeT = NativeTypeOf; + return Get(indices) == NativeT{0}; }, shape().element_type()); } @@ -2213,9 +2183,6 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S4: *proto->mutable_s4s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_s4s()); - } break; case S8: proto->set_s8s(static_cast(data().data()), @@ -2224,9 +2191,6 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case U4: *proto->mutable_u4s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_u4s()); - } break; case U8: proto->set_u8s(static_cast(data().data()), @@ -2374,9 +2338,6 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { const std::string& s(proto.s4s()); TF_RET_CHECK(data().size() * sizeof(s4) == s.size()); memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } } break; case S8: { auto s8_data = data(); @@ -2387,9 +2348,6 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { const std::string& s(proto.u4s()); TF_RET_CHECK(data().size() * sizeof(u4) == s.size()); memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } } break; case U8: { auto u8_data = data(); diff --git a/xla/literal.h b/xla/literal.h index 2c6b43cad4e1c..f1a8d65ae8c11 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #define XLA_LITERAL_H_ #include +#include #include #include #include @@ -31,6 +32,9 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" +#include "absl/base/casts.h" +#include "absl/base/config.h" #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -55,6 +59,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -66,6 +71,8 @@ class LiteralSlice; // Abstract base class for literals. class LiteralBase { public: + using DynamicSizeType = ShapeUtil::DynamicSizeType; + virtual ~LiteralBase() = 0; // Literals are equal if they have compatible shapes and the same data @@ -92,6 +99,34 @@ class LiteralBase { const void* untyped_data(const ShapeIndex& shape_index = {}) const; int64_t size_bytes(const ShapeIndex& shape_index = {}) const; + // Computes the size in bytes of the output of the Serialize method. + absl::StatusOr SerializedSize() const { + return ShapeUtil::SerializedSize(shape()); + } + + // Serialize the Literal into the given output iterator, whose value_type must + // be char. It's up to the caller to ensure that output can store + // SerializedSize() bytes of data. This can be ensured by using + // std::back_inserter, or by manually resizing the target container. + // This serializer is useful for bypassing the 2GB protobuf serialization + // limit with very large literals, and it should be faster than protobuf + // serialization when performance is a concern. + // The serialization format should not be relied on for forward/backward + // compatibility. If compatibility is required, you should use protobuf + // serialization instead. + template + Status Serialize(OutputIterator output) const { + return SerializeWithShapeProto(shape().ToProto(), output); + } + + // Serialize the Literal into the given string. This method has the same + // caveats as the Serialize() method above. + absl::Status SerializeToString(std::string* output) const; + + // Serialize the Literal into a string and return it. This method has the + // same caveats as the Serialize() method above. + absl::StatusOr SerializeAsString() const; + // Returns this literal's data as a string. This literal must be a rank-1 U8 // array. std::string GetR1U8AsString() const; @@ -157,9 +192,9 @@ class LiteralBase { NativeT Get(absl::Span multi_index) const; // Get the dynamic size on dim_index in the literal at the given shape_index. - int32_t GetDynamicSize(int64_t dim_index, - const ShapeIndex& shape_index) const; - int32_t GetDynamicSize(int64_t dim_index) const; + DynamicSizeType GetDynamicSize(int64_t dim_index, + const ShapeIndex& shape_index) const; + DynamicSizeType GetDynamicSize(int64_t dim_index) const; // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -342,16 +377,16 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - StatusOr ConvertToShape(const Shape& dest_shape) const; + absl::StatusOr ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. Returns an error if the conversion is not possible. This // literal must be array-shaped. - StatusOr BitcastConvert(const Shape& dest_shape) const; + absl::StatusOr BitcastConvert(const Shape& dest_shape) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr Convert(PrimitiveType primitive_dest_type) const; + absl::StatusOr Convert(PrimitiveType primitive_dest_type) const; // Clones the underlying buffers into a new Literal. Literal Clone() const; @@ -396,12 +431,12 @@ class LiteralBase { // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr Reshape(absl::Span dimensions) const; + absl::StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr Broadcast(const Shape& result_shape, - absl::Span dimensions) const; + absl::StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -457,6 +492,248 @@ class LiteralBase { static Literal CreateFromShapeWithUndeterminedLeafArrays(const Shape& shape); protected: + template + Status SerializeWithShapeProto(const ShapeProto& proto, + OutputIterator output) const; + + template + class SerializeState { + public: + SerializeState(const ShapeProto& shape, OutputIterator output) + : output_(output) { + WriteShape(shape); + } + + int64_t num_written() const { return num_written_; } + + template + void WriteElement(NativeT element) { + constexpr PrimitiveType primitive_type = + primitive_util::NativeToPrimitiveType(); + static_assert(primitive_type != PRED); + static_assert(!primitive_util::Is4BitType(primitive_type)); + if constexpr (primitive_util::IsComplexType(primitive_type)) { + WriteElement(element.real()); + WriteElement(element.imag()); + } else { + constexpr PrimitiveType unsigned_type = + primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(primitive_type)); + using UnsignedT = primitive_util::NativeTypeOf; + UnsignedT unsigned_element = absl::bit_cast(element); + if constexpr (sizeof(UnsignedT) == 1) { + *output_++ = absl::bit_cast(unsigned_element); + ++num_written_; + } else { + for (int i = 0; i < sizeof unsigned_element; ++i) { + *output_++ = static_cast(unsigned_element); + unsigned_element >>= CHAR_BIT; + ++num_written_; + } + } + } + } + + template + void WriteElements(absl::Span elements) { + if constexpr (std::is_same_v) { + int64_t bytes = elements.size() / 8; + for (int64_t i = 0; i < bytes; ++i) { + uint8_t byte = 0; + for (int b = 0; b < 8; ++b) { + if (elements[i * 8 + b]) { + byte |= uint8_t{1} << b; + } + } + WriteElement(byte); + } + int64_t rest = elements.size() % 8; + if (rest != 0) { + uint8_t byte = 0; + for (int64_t b = 0; b < rest; ++b) { + if (elements[bytes * 8 + b]) { + byte |= uint8_t{1} << b; + } + } + WriteElement(byte); + } + } else if constexpr (primitive_util::Is4BitType( + primitive_util::NativeToPrimitiveType< + NativeT>())) { + int64_t bytes = elements.size() / 2; + for (int64_t i = 0; i < bytes; ++i) { + uint8_t low = static_cast(elements[i * 2]); + uint8_t high = static_cast(elements[i * 2 + 1]); + uint8_t byte = (low & uint8_t{0xf}) | (high << 4); + WriteElement(byte); + } + if (elements.size() % 2 != 0) { + uint8_t last = static_cast(elements.back()) & uint8_t{0xf}; + WriteElement(last); + } + } else { + for (NativeT element : elements) { + WriteElement(element); + } + } + } + + void WriteDynamicSizes(absl::Span sizes) { + WriteElements(sizes); + } + + private: + void WriteShape(const ShapeProto& proto) { + std::string shape_bytes = proto.SerializeAsString(); + uint64_t shape_size = shape_bytes.size(); + WriteElement(shape_size); + output_ = std::copy(shape_bytes.begin(), shape_bytes.end(), output_); + num_written_ += shape_bytes.size(); + } + + OutputIterator output_; + int64_t num_written_ = 0; + }; + + template + class DeserializeState { + public: + DeserializeState(InputIterator input, InputIterator end) + : input_(input), end_(end) {} + + int64_t num_read() const { return num_read_; } + + template + ABSL_MUST_USE_RESULT bool ReadElement(NativeT& element) { + constexpr PrimitiveType primitive_type = + primitive_util::NativeToPrimitiveType(); + static_assert(!primitive_util::Is4BitType(primitive_type)); + static_assert(primitive_type != PRED); + if constexpr (primitive_util::IsComplexType(primitive_type)) { + using ComponentT = + primitive_util::NativeTypeOf; + ComponentT real; + if (!ReadElement(real)) { + return false; + } + ComponentT imag; + if (!ReadElement(imag)) { + return false; + } + element = NativeT(real, imag); + } else { + constexpr PrimitiveType unsigned_type = + primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(primitive_type)); + using UnsignedT = primitive_util::NativeTypeOf; + if constexpr (sizeof(UnsignedT) == 1) { + if (at_end()) { + return false; + } + element = absl::bit_cast(*input_++); + ++num_read_; + } else { + UnsignedT unsigned_element = 0; + for (int i = 0, shift = 0; i < sizeof unsigned_element; + ++i, shift += CHAR_BIT) { + if (at_end()) { + return false; + } + unsigned_element |= + static_cast(static_cast(*input_++)) + << shift; + ++num_read_; + } + element = absl::bit_cast(unsigned_element); + } + } + return true; + } + + template + ABSL_MUST_USE_RESULT bool ReadElements(absl::Span elements) { + if constexpr (std::is_same_v) { + int64_t bytes = elements.size() / 8; + for (int64_t i = 0; i < bytes; ++i) { + uint8_t byte; + if (!ReadElement(byte)) { + return false; + } + for (int b = 0; b < 8; ++b) { + elements[i * 8 + b] = !!(byte & (uint8_t{1} << b)); + } + } + int64_t rest = elements.size() % 8; + if (rest != 0) { + uint8_t byte; + if (!ReadElement(byte)) { + return false; + } + for (int64_t b = 0; b < rest; ++b) { + elements[bytes * 8 + b] = !!(byte & (uint8_t{1} << b)); + } + } + } else if constexpr (primitive_util::Is4BitType( + primitive_util::NativeToPrimitiveType< + NativeT>())) { + int64_t bytes = elements.size() / 2; + for (int64_t i = 0; i < bytes; ++i) { + uint8_t byte; + if (!ReadElement(byte)) { + return false; + } + elements[i * 2] = static_cast(byte & uint8_t{0xf}); + elements[i * 2 + 1] = static_cast(byte >> 4); + } + if (elements.size() % 2 != 0) { + uint8_t last; + if (!ReadElement(last)) { + return false; + } + elements.back() = static_cast(last); + } + } else { + for (NativeT& element : elements) { + if (!ReadElement(element)) { + return false; + } + } + } + return true; + } + + bool ReadDynamicSizes(absl::Span sizes) { + return ReadElements(sizes); + } + + absl::StatusOr ReadShape(uint64_t size) { + std::string shape_bytes; + shape_bytes.reserve(size); + while (shape_bytes.size() < size) { + if (at_end()) { + return InvalidArgument("Failed to read shape data"); + } + shape_bytes.push_back(*input_++); + ++num_read_; + } + ShapeProto proto; + if (!proto.ParseFromString(shape_bytes)) { + return InvalidArgument("Failed to parse shape protobuf"); + } + Shape shape(proto); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + return std::move(shape); + } + + bool at_end() const { return input_ == end_; } + + private: + InputIterator input_; + InputIterator end_; + int64_t num_read_ = 0; + }; + // Array literals could be in one of the following three states: // 1) Known: we have evaluated and known the value of the array literal. // 2) Unknown: we have tried to evaluate the array literal, but its value @@ -496,8 +773,8 @@ class LiteralBase { template void Set(absl::Span index, NativeT value); - int32_t GetDynamicSize(int64_t dim_index) const; - void SetDynamicSize(int64_t dim_index, int32_t size); + DynamicSizeType GetDynamicSize(int64_t dim_index) const; + void SetDynamicSize(int64_t dim_index, DynamicSizeType size); void AllocateBuffers(); void DeallocateBuffers(); // Gets/sets the buffer holding the array data. @@ -525,8 +802,6 @@ class LiteralBase { from.rep_.emplace(); } - using DynamicSizeType = int32_t; - // Gets/sets the buffer holding dynamic sizes. const DynamicSizeType* dynamic_size_buffer() const { DCHECK(LayoutUtil::IsDenseArray(*subshape_)); @@ -534,7 +809,7 @@ class LiteralBase { buffer() + dynamic_size_buffer_offset()); } DynamicSizeType* dynamic_size_buffer() { - return const_cast( + return const_cast( const_cast(this)->dynamic_size_buffer()); } @@ -702,6 +977,15 @@ class LiteralBase { bool IsKnown() const; + // Serialize the data contained by this Piece into the given serialization + // state. + template + void SerializeData(SerializeState& state) const; + + // Deserialize the data for this Piece from the given serialization state. + template + bool DeserializeData(DeserializeState& state); + private: // Uninitialized state representation. struct Uninitialized {}; @@ -842,8 +1126,8 @@ class MutableLiteralBase : public LiteralBase { // Set the dynamic size on dim_index in the literal at the given shape_index. void SetDynamicSize(int64_t dim_index, const ShapeIndex& shape_index, - int32_t size); - void SetDynamicSize(int64_t dim_index, int32_t size); + DynamicSizeType size); + void SetDynamicSize(int64_t dim_index, DynamicSizeType size); // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex @@ -976,8 +1260,8 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr CreateFromProto(const LiteralProto& proto, - bool prohibit_empty_literal = true); + static absl::StatusOr CreateFromProto( + const LiteralProto& proto, bool prohibit_empty_literal = true); protected: // Returns the piece at the given ShapeIndex. @@ -1156,6 +1440,16 @@ class Literal : public MutableLiteralBase { // ref-qualified with &&. Literal SubLiteral(ShapeIndexView shape_index); + // Deserialize a Literal from the given iterator range, whose value type must + // be char. See the comments on the Serialize() method for caveats. + template + static absl::StatusOr Deserialize(InputIterator begin, + InputIterator end); + + static absl::StatusOr DeserializeFromString(std::string_view data) { + return Deserialize(data.data(), data.data() + data.size()); + } + private: friend class LiteralBase; friend class MutableLiteralBase; @@ -1163,6 +1457,11 @@ class Literal : public MutableLiteralBase { // Deallocate the buffers held by this literal. void DeallocateBuffers(); + // Sets the shape_ field from a Shape. shape_'s element_size_in_bits field + // on the layout is always set to 0 since Literals do not support packed + // subbyte elements. + void SetShape(const Shape& shape); + // Recursively sets the subshapes and buffers of all subpieces rooted at // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in // the shape. @@ -1252,6 +1551,121 @@ class BorrowingLiteral : public LiteralBase { std::unique_ptr shape_; }; +template +void LiteralBase::Piece::SerializeData( + SerializeState& state) const { + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (subshape().is_dynamic()) { + absl::Span sizes(dynamic_size_buffer(), + subshape().rank()); + state.WriteDynamicSizes(sizes); + } + state.WriteElements(data()); +} + +template +bool LiteralBase::Piece::DeserializeData( + DeserializeState& state) { + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (subshape().is_dynamic()) { + absl::Span sizes(dynamic_size_buffer(), subshape().rank()); + if (!state.ReadDynamicSizes(sizes)) { + return false; + } + } + return state.ReadElements(data()); +} + +// Description of the native serialization format: +// +// - All data are stored in little-endian order. +// +// - The serialized format begins with a header. +// +// - The first 8 bytes (int64_t) of the header are the size of the serialized +// ShapeProto that provides the shape of the literal. +// +// - The remaining bytes of the header provide the serialized ShapeProto itself. +// +// - After the header, each piece of the literal is serialized, as produced +// through a depth-first traversal of the tuple tree. +// +// - If a piece is dynamic, we first write the sizes of the dynamic dimensions. +// +// - The elements of the piece are then written. Elements smaller than a single +// byte (PRED, S4, U4) are packed into bytes. Otherwise, they are written in +// little-endian byte order. +template +Status LiteralBase::SerializeWithShapeProto(const ShapeProto& shape_proto, + OutputIterator output) const { + SerializeState state(shape_proto, output); + TF_RETURN_IF_ERROR(root_piece().ForEachSubpieceWithStatus( + [&](const ShapeIndex& shape_index, const Piece& piece) -> absl::Status { + const Shape& subshape = piece.subshape(); + if (subshape.IsTuple()) { + return OkStatus(); + } + if (!subshape.IsArray()) { + return InvalidArgument("Shape cannot be serialized: %s", + shape().ToString()); + } + primitive_util::ArrayTypeSwitch( + [&](auto primitive_type) { + using NativeT = primitive_util::NativeTypeOf; + piece.SerializeData(state); + }, + subshape.element_type()); + return OkStatus(); + })); + DCHECK_EQ(state.num_written(), SerializedSize().value()) + << shape().ToString(); + return OkStatus(); +} + +template +absl::StatusOr Literal::Deserialize(InputIterator begin, + InputIterator end) { + DeserializeState state(begin, end); + uint64_t shape_size; + if (!state.ReadElement(shape_size)) { + return InvalidArgument("Failed to read shape size"); + } + TF_ASSIGN_OR_RETURN(Shape shape, state.ReadShape(shape_size)); + Literal literal(shape); + TF_RETURN_IF_ERROR( + literal.mutable_root_piece().ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& shape_index, Piece* piece) -> absl::Status { + const Shape& subshape = piece->subshape(); + if (subshape.IsTuple()) { + return OkStatus(); + } + if (!subshape.IsArray()) { + return InvalidArgument("Shape cannot be deserialized: %s", + shape.ToString()); + } + bool ok = primitive_util::ArrayTypeSwitch( + [&](auto primitive_type) { + using NativeT = primitive_util::NativeTypeOf; + return piece->DeserializeData(state); + }, + subshape.element_type()); + if (!ok) { + return InvalidArgument( + "Failed to deserialize all data for shape: %s", + shape.ToString()); + } + return OkStatus(); + })); + DCHECK_EQ(state.num_read(), ShapeUtil::SerializedSize(shape).value()) + << shape.ToString(); + if (!state.at_end()) { + return InvalidArgument("Did not consume all input data"); + } + return std::move(literal); +} + template absl::Span LiteralBase::Piece::data() const { DCHECK(LayoutUtil::IsDenseArray(subshape())) @@ -1349,38 +1763,32 @@ NativeT LiteralBase::GetFirstElement() const { template int64_t LiteralBase::CountEqual(T value) const { - if (!shape().IsArray()) { + PrimitiveType ty = shape().element_type(); + if (!primitive_util::IsArrayType(ty)) { return 0; } - PrimitiveType ty = shape().element_type(); Literal scalar(ShapeUtil::MakeScalarShape(ty)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ArrayTypeSwitch( [&](auto primitive_type_constant) -> int64_t { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = primitive_util::NativeTypeOf; - scalar.Set({}, static_cast(value)); - return root_piece().CountAll(scalar); - } - return 0; + using NativeT = primitive_util::NativeTypeOf; + scalar.Set({}, static_cast(value)); + return root_piece().CountAll(scalar); }, ty); } template int64_t LiteralBase::CountEqual(std::complex value) const { - if (!shape().IsArray()) { + PrimitiveType ty = shape().element_type(); + if (!primitive_util::IsComplexType(ty)) { return 0; } - PrimitiveType ty = shape().element_type(); Literal scalar(ShapeUtil::MakeScalarShape(ty)); - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::ComplexTypeSwitch( [&](auto primitive_type_constant) -> int64_t { - if constexpr (primitive_util::IsComplexType(primitive_type_constant)) { - using NativeT = primitive_util::NativeTypeOf; - scalar.Set({}, static_cast(value)); - return root_piece().CountAll(scalar); - } - return 0; + using NativeT = primitive_util::NativeTypeOf; + scalar.Set({}, static_cast(value)); + return root_piece().CountAll(scalar); }, ty); } diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index 3d32bb48fb8ce..4da0d71d1ff77 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/literal_comparison.h" +#include + #ifndef _WIN32 #include #endif @@ -47,8 +49,8 @@ limitations under the License. #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/ml_dtypes.h" using absl::StrAppend; using absl::StrAppendFormat; @@ -358,8 +360,8 @@ class NearComparator { // used for sorting a std::set of the top mismatches, and a nan value // here will result in undefined behavior because nan's do not satisfy // the strict weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } else { abs_error = 0; rel_error = 0; @@ -376,14 +378,14 @@ class NearComparator { if (expected != T{0}) { rel_error = abs_error / FpAbsoluteValue(expected); } else { - rel_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as infinity. CHECK(!CompareEqual(expected, actual, {linear_index})); - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); @@ -392,7 +394,7 @@ class NearComparator { if (expected != T{0}) { rel_error = abs_error / FpAbsoluteValue(expected); } else { - rel_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } } const bool is_abs_mismatch = abs_error > error_.abs; @@ -433,25 +435,12 @@ class NearComparator { } // For complex types, we compare real and imaginary parts individually. - void CompareValues(complex64 expected, complex64 actual, - int64_t linear_index) { - const auto both_parts_mismatch = num_mismatches_ + 2; - CompareValues(expected.real(), actual.real(), linear_index); - CompareValues(expected.imag(), actual.imag(), linear_index); - if (num_mismatches_ == both_parts_mismatch) { - // The mismatch counter had been incremented by each CompareValues() call, - // which means that both real and imaginary parts of the passed-in complex - // values are different. However, the counter should reflect a single - // mismatch between these complex values. - num_mismatches_--; - } - } - - void CompareValues(complex128 expected, complex128 actual, + template + void CompareValues(std::complex expected, std::complex actual, int64_t linear_index) { const auto both_parts_mismatch = num_mismatches_ + 2; - CompareValues(expected.real(), actual.real(), linear_index); - CompareValues(expected.imag(), actual.imag(), linear_index); + CompareValues(expected.real(), actual.real(), linear_index); + CompareValues(expected.imag(), actual.imag(), linear_index); if (num_mismatches_ == both_parts_mismatch) { // The mismatch counter had been incremented by each CompareValues() call, // which means that both real and imaginary parts of the passed-in complex @@ -510,8 +499,8 @@ class NearComparator { std::string out; int64_t element_count = ShapeUtil::ElementsIn(actual_.shape()); - auto percent_string = [](float a, float b) { - float pct = b == 0.0 ? 0.0 : 100.0 * a / b; + auto percent_string = [](double a, double b) { + double pct = b == 0.0 ? 0.0 : 100.0 * a / b; return absl::StrFormat("%0.4f%%", pct); }; @@ -619,8 +608,9 @@ class NearComparator { // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the // bounds of these buckets. abs_value_buckets_ contains a pair for each // bucket: the element count and failure count. - static constexpr std::array kAbsValueBucketBounds = { - 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + static inline constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity(), + }; std::vector> abs_value_buckets_; // Buckets for relative and absolute errors. The relative error buckets only @@ -631,17 +621,12 @@ class NearComparator { // a cumulative distribution so an error value may appear in more than one // bucket. For example an error value of 0.003 may appear in the buckets // bounded by 0.01, 0.1, and 1.0. - static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, - 0.01, 0.1, 1}; + static inline constexpr std::array kErrorBucketBounds = { + 0.0001, 0.001, 0.01, 0.1, 1}; std::vector abs_error_buckets_; std::vector rel_error_buckets_; }; -template -constexpr std::array NearComparator::kAbsValueBucketBounds; -template -constexpr std::array NearComparator::kErrorBucketBounds; - Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual, const ShapeIndex& shape_index, const MiscompareCallback& miscompare_callback) { @@ -829,8 +814,9 @@ Status EqualDynamicShapesAndDimensions(const LiteralSlice& expected, const LiteralSlice& actual) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); return ShapeUtil::ForEachSubshapeWithStatus( - expected.shape(), [&expected, &actual](const Shape& expected_shape, - const ShapeIndex& index) { + expected.shape(), + [&expected, &actual](const Shape& expected_shape, + const ShapeIndex& index) -> absl::Status { auto actual_shape = ShapeUtil::GetSubshape(actual.shape(), index); for (int i = 0; i < expected_shape.dimensions().size(); ++i) { if (!expected_shape.is_dynamic_dimension(i) && diff --git a/xla/literal_comparison.h b/xla/literal_comparison.h index fc30fea3b243c..06d416f89dc4c 100644 --- a/xla/literal_comparison.h +++ b/xla/literal_comparison.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 6bce20ab9ca89..24c12bb92d6d8 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,12 +21,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include "absl/base/casts.h" +#include "absl/random/random.h" #include "absl/strings/match.h" #include "absl/types/span.h" #include "xla/array.h" @@ -43,12 +45,13 @@ limitations under the License. #include "xla/status.h" #include "xla/test.h" #include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" @@ -2791,6 +2794,113 @@ TEST_F(LiteralUtilTest, PopulateR3FromArray3DDynamicDim2) { EXPECT_EQ(expected, literal.ToString()); } +TEST_F(LiteralUtilTest, Compare4BitType) { + Literal literal1 = Literal(ShapeUtil::MakeShape(S4, {})); + Literal literal2 = Literal(ShapeUtil::MakeShape(S4, {})); + void* p = literal1.untyped_data(); + void* q = literal2.untyped_data(); + *((uint8_t*)p) = 0x44; + *((uint8_t*)q) = 0xc4; + std::string expected = R"(s4[] 4)"; + EXPECT_EQ(expected, literal1.ToString()); + EXPECT_EQ(literal1.ToString(), literal2.ToString()); + EXPECT_EQ(literal1, literal2); +} + +class LiteralSerializationTest : public ::testing::Test, + public ::testing::WithParamInterface { + public: + static std::vector GenerateSimpleParams() { + std::vector params; + for (PrimitiveType element_type : + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, + F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { + for (const DimensionVector& dimensions : { + DimensionVector{}, + DimensionVector{0}, + DimensionVector{1}, + DimensionVector{7}, + DimensionVector{8}, + DimensionVector{9}, + DimensionVector{0, 8}, + DimensionVector{8, 9}, + }) { + params.push_back(ShapeUtil::MakeShape(element_type, dimensions)); + } + } + return params; + } + + static std::vector GenerateTupleParams() { + std::vector params; + const Shape tuple_elements[] = { + ShapeUtil::MakeShape(PRED, {}), + ShapeUtil::MakeShape(U4, {3}), + ShapeUtil::MakeShape(U32, {0}), + ShapeUtil::MakeShape(F32, {7}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(BF16, {3}), + ShapeUtil::MakeShape(C64, {7}), + }), + }; + for (const Shape& lhs : tuple_elements) { + for (const Shape& rhs : tuple_elements) { + params.push_back(ShapeUtil::MakeTupleShape({lhs, rhs})); + } + } + return params; + } +}; + +TEST_P(LiteralSerializationTest, Test) { + const Shape& shape = GetParam(); + LOG(INFO) << "shape: " << shape.ToString(); + absl::InsecureBitGen bitgen(std::seed_seq({42})); + Literal literal(shape); + ASSERT_NO_FATAL_FAILURE(ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (subshape.IsTuple()) { + return; + } + ASSERT_TRUE(subshape.IsArray()); + primitive_util::ArrayTypeSwitch( + [&](auto primitive_type) { + using NativeT = primitive_util::NativeTypeOf; + for (auto& element : literal.data(shape_index)) { + if constexpr (std::is_same_v) { + element = absl::Uniform(bitgen, 0, 2); + } else if constexpr (primitive_util::IsComplexType( + primitive_type)) { + element = NativeT(absl::Uniform(bitgen, -1.0, 1.0), + absl::Uniform(bitgen, -1.0, 1.0)); + } else if constexpr (primitive_util::IsFloatingPointType( + primitive_type)) { + element = static_cast( + absl::Uniform(bitgen, -1.0, 1.0)); + } else { + element = + static_cast(absl::Uniform(bitgen)); + } + } + }, + subshape.element_type()); + })); + TF_ASSERT_OK_AND_ASSIGN(std::string serialized, literal.SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN(Literal deserialized, + Literal::DeserializeFromString(serialized)); + EXPECT_EQ(literal, deserialized); +} + +INSTANTIATE_TEST_SUITE_P( + Simple, LiteralSerializationTest, + ::testing::ValuesIn(LiteralSerializationTest::GenerateSimpleParams())); + +INSTANTIATE_TEST_SUITE_P( + Tuples, LiteralSerializationTest, + ::testing::ValuesIn(LiteralSerializationTest::GenerateTupleParams())); + void BM_BroadcastVectorToMatrix(::testing::benchmark::State& state) { const int d0 = state.range(0); const int d1 = state.range(1); diff --git a/xla/literal_util.cc b/xla/literal_util.cc index 572483672a7b6..6afa977248567 100644 --- a/xla/literal_util.cc +++ b/xla/literal_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,8 +40,8 @@ limitations under the License. #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" namespace xla { @@ -299,10 +299,10 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return CreateScalar(primitive_type); } -/* static */ StatusOr LiteralUtil::NanValue( +/* static */ absl::StatusOr LiteralUtil::NanValue( PrimitiveType primitive_type) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsFloatingPointType( primitive_type_constant)) { using NativeT = typename primitive_util::PrimitiveTypeToNative< @@ -367,9 +367,9 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, // Copy data into new literal, element-by-element. for (int64_t i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = + auto from_multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = + auto to_multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); primitive_util::PrimitiveTypeSwitch( [&](auto primitive_type_constant) -> void { diff --git a/xla/literal_util.h b/xla/literal_util.h index e125ef7fe30b7..5a27f1b510088 100644 --- a/xla/literal_util.h +++ b/xla/literal_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,6 +76,8 @@ class LiteralUtil { // literal's linear representation in memory. template static Literal CreateR0(NativeT value); + template + static Literal CreateR0(PrimitiveType primitive_type, T value); template static Literal CreateR1(absl::Span values); static Literal CreateR1(const tsl::core::Bitmap& values); @@ -121,7 +123,7 @@ class LiteralUtil { // Creates a scalar literal value containing the NaN value of the given // primitive type. Fail for non-inexact types. For complex types, returns a // nan + nan * j value. - static StatusOr NanValue(PrimitiveType primitive_type); + static absl::StatusOr NanValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template static Literal CreateFullWithDescendingLayout( @@ -251,7 +253,7 @@ class LiteralUtil { // generator to populate the literal's values. // Returns the new literal object, or an error Status if failed. template > - static StatusOr CreateLiteralWithGenerator( + static absl::StatusOr CreateLiteralWithGenerator( const Shape& shape, absl::FunctionRef)> generator); @@ -261,16 +263,17 @@ class LiteralUtil { // Returns the new literal object, or an error Status if failed. template > - static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, - T mean, T stddev); + static absl::StatusOr CreateRandomLiteral(const Shape& shape, + E* engine, T mean, + T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard // deviation. // Returns the new literal object, or an error Status if failed. template > - static StatusOr CreateRandomLiteral(const Shape& shape, T mean, - T stddev); + static absl::StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. @@ -297,6 +300,17 @@ template return literal; } +template +/* static */ Literal LiteralUtil::CreateR0(PrimitiveType primitive_type, + T value) { + return primitive_util::ArrayTypeSwitch( + [&value](auto type) { + using NativeT = primitive_util::NativeTypeOf; + return CreateR0(static_cast(value)); + }, + primitive_type); +} + template /* static */ Literal LiteralUtil::CreateR1(absl::Span values) { Literal literal( @@ -522,7 +536,7 @@ template } template -/* static */ StatusOr LiteralUtil::CreateLiteralWithGenerator( +/* static */ absl::StatusOr LiteralUtil::CreateLiteralWithGenerator( const Shape& shape, absl::FunctionRef)> generator) { using NativeT = primitive_util::NativeTypeOf; @@ -534,7 +548,7 @@ template } template -/* static */ StatusOr LiteralUtil::CreateRandomLiteral( +/* static */ absl::StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, E* engine, T mean, T stddev) { using NativeT = primitive_util::NativeTypeOf; std::normal_distribution generator(mean, stddev); @@ -545,7 +559,7 @@ template } template -/* static */ StatusOr LiteralUtil::CreateRandomLiteral( +/* static */ absl::StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); diff --git a/xla/map_util.h b/xla/map_util.h index b6321bcba40b2..a47c661f8f679 100644 --- a/xla/map_util.h +++ b/xla/map_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/metric_table_report.cc b/xla/metric_table_report.cc index e2dc63049a8d5..e84bf5f9cd637 100644 --- a/xla/metric_table_report.cc +++ b/xla/metric_table_report.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/metric_table_report.h b/xla/metric_table_report.h index 9c3659363e8b8..a84855491db1b 100644 --- a/xla/metric_table_report.h +++ b/xla/metric_table_report.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/BUILD b/xla/mlir/backends/cpu/BUILD deleted file mode 100644 index f26817898f4f3..0000000000000 --- a/xla/mlir/backends/cpu/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:xla.bzl", "xla_cc_binary") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla/mlir:__subpackages__"], - licenses = ["notice"], -) - -build_test( - name = "xla-cpu-opt_build_test", - targets = [ - ":xla-cpu-opt", - ], -) - -xla_cc_binary( - name = "xla-cpu-opt", - srcs = ["xla-cpu-opt.cc"], - deps = [ - "//xla/mlir/backends/cpu/transforms:passes", - "//xla/mlir/xla_cpu/ir:xla_cpu", - "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/mlir_hlo:lhlo", - "//xla/service/cpu:cpu_compiler", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:VectorDialect", - "@stablehlo//:register", - ], -) diff --git a/xla/mlir/backends/cpu/transforms/BUILD b/xla/mlir/backends/cpu/transforms/BUILD index 6f654b168f93f..83b4f8442bb16 100644 --- a/xla/mlir/backends/cpu/transforms/BUILD +++ b/xla/mlir/backends/cpu/transforms/BUILD @@ -1,10 +1,10 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = ["//xla:friends"], licenses = ["notice"], ) @@ -32,7 +32,6 @@ cc_library( "legalize_i1_vector_transfers.cc", "legalize_library_ops.cc", "remove_copies_to_out_params.cc", - "sparse_rewrite_passes.cc", "xla_abi_legalization.cc", "xla_cpu_memref_element_cast_to_llvm.cc", "xla_cpu_to_cpu_runtime.cc", @@ -45,7 +44,6 @@ cc_library( "//xla/mlir/runtime/utils:custom_calls", "//xla/mlir/xla_cpu/ir:xla_cpu", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:hlo_parser", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", diff --git a/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc b/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc index f9f5d0d1bac7b..09d4293af32dc 100644 --- a/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc +++ b/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc b/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc index 14948e0ee7279..c3231a9262538 100644 --- a/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc +++ b/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/transforms/passes.h b/xla/mlir/backends/cpu/transforms/passes.h index 0310dbf0a9689..7993c0f5d926f 100644 --- a/xla/mlir/backends/cpu/transforms/passes.h +++ b/xla/mlir/backends/cpu/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,9 +47,6 @@ createConvertXlaCpuMemRefElementCastToLLVMPass(); std::unique_ptr> createRemoveCopiesToOutParamsPass(); -std::unique_ptr> -createSparseCustomCallRewritingPass(); - std::unique_ptr> createRewriteReallocToAllocPass(); diff --git a/xla/mlir/backends/cpu/transforms/passes.td b/xla/mlir/backends/cpu/transforms/passes.td index 1aeacd631770e..4d84b86bf970e 100644 --- a/xla/mlir/backends/cpu/transforms/passes.td +++ b/xla/mlir/backends/cpu/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -109,24 +109,6 @@ def RemoveCopiesToOutParamsPass : let constructor = "createRemoveCopiesToOutParamsPass()"; } -def SparseCustomCallRewritingPass : - Pass<"xla-sparse-custom-call-to-pack", "mlir::func::FuncOp"> { - let summary = "Converts CustomCall operations to sparse operations"; - - let description = [{ - Converts CustomCallOp operations to sparse operations - to avoid dropping sparsity information during the - HLO roundtrip. - }]; - - let dependentDialects = [ - "mlir::sparse_tensor::SparseTensorDialect", - "mlir::chlo::ChloDialect", - ]; - - let constructor = "createSparseCustomCallRewritingPass()"; -} - def RewriteReallocToAllocPass : Pass<"xla-rewrite-realloc-to-alloc", "mlir::func::FuncOp"> { let summary = "Rewrites realloc to alloc + copy"; diff --git a/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc b/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc index b9974de5ee609..5171fd2dadd34 100644 --- a/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc +++ b/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc deleted file mode 100644 index 871292d2a6eed..0000000000000 --- a/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc +++ /dev/null @@ -1,682 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "xla/mlir/backends/cpu/transforms/passes.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace xla { -namespace cpu { -namespace { - -#define GEN_PASS_DEF_SPARSECUSTOMCALLREWRITINGPASS -#include "xla/mlir/backends/cpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -DenseIntElementsAttr getDenseIntAttrFromConstant(Value v) { - if (auto const_op = v.getDefiningOp()) { - return const_op.getValue().cast(); - } else if (auto itoa_op = v.getDefiningOp()) { - // MHLO canonicalizer canonicalizes constants like [0, 1, 2, .., n-1] to - // mhlo.itoa {itoa_dimension=0}: tensor - RankedTensorType rtt = itoa_op.getOutput().getType(); - // We only use 1-D tensors to encode constant parameters in custom calls. - assert(itoa_op.getIotaDimension() == 0 && rtt.getRank() == 1); - SmallVector const_values; - const_values.reserve(rtt.getShape()[0]); - for (int i = 0; i < rtt.getShape()[0]; ++i) { - const_values.push_back(i); - } - return DenseIntElementsAttr::get(rtt, const_values); - } - llvm_unreachable("unrecognizable type of constant"); -} - -void getIntegersFromDenseElements(Value v, SmallVectorImpl& values) { - auto attr = getDenseIntAttrFromConstant(v); - values.reserve(values.size() + attr.size()); - auto range = llvm::map_range(attr, [](APInt i) { return i.getZExtValue(); }); - values.append(range.begin(), range.end()); -} - -Value getEmptyTensor(OpBuilder& b, Location loc, RankedTensorType type) { - auto t = b.create(loc, type.getShape(), - type.getElementType(), ValueRange{}); - auto zero = b.getZeroAttr(type.getElementType()); - auto c0 = b.create(loc, zero); - return b.create(loc, ValueRange{c0}, ValueRange{t}) - .getResult(0); -} - -struct SparseBatchedAssembleCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getResults().size() == 1 && "Must be packing into one tensor"); - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs()[0], // sparse tensor values - op.getInputs().drop_front()); // sparse tensor levels - return success(); - } -}; - -template -struct SparseBinaryCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 2 && "Need two argument"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // Reconstruct the binary mhlo operation. - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs()[0], op.getInputs()[1]); - return success(); - } -}; - -struct SparseBroadcastInDimCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 2 && - "Need argument and broadcast dimensions"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // Broadcast dimensions are passed in as a constant of dense int elements. - auto dims_constant = op.getInputs()[1]; - auto broadcast_dimensions = getDenseIntAttrFromConstant(dims_constant); - // Reconstruct the broadcast_in_dim operation. - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs()[0], broadcast_dimensions); - return success(); - } -}; - -template -struct SparseCmpNoEqualCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 2 && "Need two arguments"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - - Value lhs = op.getInputs().front(); - Value rhs = op.getInputs().back(); - // Uses the explicit type in case this is a sparse tensor. - Type ret_tp = op.getResultTypes().front(); - auto cmp_attr = mhlo::ComparisonTypeAttr::get(op.getContext(), CmpType); - // Replaces the call with the compare operation. - rewriter.replaceOpWithNewOp(op, ret_tp, lhs, rhs, CmpDir, - cmp_attr); - return success(); - } -}; - -struct SparseConcatenateCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getResults().size() == 1 && "Need one output tensor"); - // The concatenation dimension. - auto concat_dim = op.getInputs().back().getDefiningOp(); - auto concat_dim_attr = concat_dim.getValue().cast(); - // Reconstruct the concatenate operation. - Value ret_sp_tensor = op.getResults()[0]; - // Depending on test setup, we can get either a 32-bit integer or a 64-bit - // integer. - if (concat_dim_attr.getElementType().isInteger(32)) { - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs().drop_back(), - rewriter.getIndexAttr(concat_dim_attr.getValues()[0])); - } else { - assert(concat_dim_attr.getElementType().isInteger(64)); - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs().drop_back(), - rewriter.getIndexAttr(concat_dim_attr.getValues()[0])); - } - return success(); - } -}; - -struct SparseConvCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 2 && "Need two input tensors"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - auto rtp = op.getResults()[0].getType().cast(); - rewriter.replaceOpWithNewOp( - op, op.getInputs(), getEmptyTensor(rewriter, op.getLoc(), rtp)); - return success(); - } -}; - -struct SparseConvertCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 1 && "Need one input tensor"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs()[0]); - return success(); - } -}; - -struct SparseDotCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 6 && "Need arguments and metadata"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - SmallVector lhs_contr, rhs_contr, lhs_batch, rhs_batch; - getIntegersFromDenseElements(op.getInputs()[2], lhs_contr); - getIntegersFromDenseElements(op.getInputs()[3], rhs_contr); - getIntegersFromDenseElements(op.getInputs()[4], lhs_batch); - getIntegersFromDenseElements(op.getInputs()[5], rhs_batch); - auto dot_dims = mlir::mhlo::DotDimensionNumbersAttr::get( - op.getContext(), lhs_batch, rhs_batch, lhs_contr, rhs_contr); - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp(op, ret_sp_tensor.getType(), - op.getInputs()[0], - op.getInputs()[1], dot_dims, - /*defaultPrecision*/ - ArrayAttr()); - return success(); - } -}; - -struct SparseDynSliceCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getResults().size() == 1 && "Need one output tensor"); - auto ctx = op.getContext(); - auto loc = op.getLoc(); - auto retTp = op.getResults().getTypes()[0].cast(); - // Strips the tensor operand at the front and the static_size array at - // the end. Inputs in between specify the dynamic offsets. - auto dyn_off_tensors = op.getInputs().drop_front().drop_back(); - auto sizes = getDenseIntAttrFromConstant(op.getInputs().back()); - - assert(sizes.getNumElements() == retTp.getRank() && - dyn_off_tensors.size() == retTp.getRank()); - - SmallVector slice_attrs; - SmallVector static_offsets, static_sizes, static_strides; - SmallVector dyn_offsets; - constexpr auto dyn_v = sparse_tensor::SparseTensorDimSliceAttr::kDynamic; - for (auto em : llvm::enumerate(sizes)) { - // Populates sparse tensor slice attribute - uint64_t sz = em.value().getZExtValue(); - slice_attrs.push_back( - sparse_tensor::SparseTensorDimSliceAttr::get(ctx, dyn_v, sz, 1)); - // Populates arrays used for ExtractSliceOp. - static_offsets.push_back(ShapedType::kDynamic); - static_strides.push_back(1); // dynamic_slice always uses stride == 1 - static_sizes.push_back(sz); - // Populates dynamic offset value arrays for ExtractSliceOp. - Value dyn_off = rewriter.create( - loc, dyn_off_tensors[em.index()], ValueRange{}); - Value dyn_off_idx = rewriter.create( - loc, rewriter.getIndexType(), dyn_off); - dyn_offsets.push_back(dyn_off_idx); - } - - auto srcEnc = - retTp.getEncoding().cast(); - auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get( - ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getLvlToDim(), - srcEnc.getPosWidth(), srcEnc.getCrdWidth(), slice_attrs); - auto sliceTp = RankedTensorType::get(retTp.getShape(), - retTp.getElementType(), sliceEnc); - - auto slice = rewriter.create( - loc, sliceTp, op.getInputs()[0], dyn_offsets, /*sizes=*/ValueRange{}, - /*strides=*/ValueRange{}, static_offsets, static_sizes, static_strides); - - // TODO(peiming): This weakens the performance benefit we get from the - // sparse compiler by forcing every slice to be materizalized while the - // sparse compiler supports view-based slice. - rewriter.replaceOpWithNewOp(op, retTp, slice); - return success(); - } -}; - -template -struct SparseReduceCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 3 && - "Need one input tensor, identity, and axes"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - SmallVector axes; - getIntegersFromDenseElements(op.getInputs()[2], axes); - Value result = op.getResults()[0]; - auto resultType = result.getType().dyn_cast(); - auto elementType = resultType.getElementType(); - - Location loc = op.getLoc(); - RankedTensorType blockArgumentType = RankedTensorType::get({}, elementType); - mhlo::ReduceOp reduce = rewriter.create( - loc, result.getType(), op.getInputs()[0], op.getInputs()[1], - rewriter.getI64TensorAttr(axes)); - - // Setup the body for mhlo.reduce. Note that sparse reductions like - // add/or/xor are good to go, but the more complicated prod/min/max/and - // need semi-ring lowering when converting to linalg. - Region& region = reduce.getBody(); - Block& block = region.emplaceBlock(); - block.addArgument(blockArgumentType, loc); - block.addArgument(blockArgumentType, loc); - auto* firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value red = - rewriter.create(loc, *firstArgument, *secondArgument); - rewriter.create(loc, red); - } - rewriter.replaceOp(op, reduce.getResults()); - return success(); - } -}; - -struct SparseReshapeCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 1 && "Need one input tensor"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // Reconstruct the reshape operation. - Value ret_sp_tensor = op.getResults()[0]; - // TODO(anlunx): Fix the issue that the reshape is rewritten to a collapse + - // expand pair where the sparsity encoding is dropped in between. - rewriter.replaceOpWithNewOp(op, ret_sp_tensor.getType(), - op.getInputs()[0]); - return success(); - } -}; - -struct SparseSelectRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 3 && "Need three input tensors"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // Reconstruct the operation. - rewriter.replaceOpWithNewOp(op, op.getResults().getTypes(), - op.getInputs()); - return success(); - } -}; - -struct SparseSliceCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 4 && - "Need one operand and three slicing parameters"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - auto ctx = op.getContext(); - auto loc = op.getLoc(); - auto retTp = op.getResults().getTypes()[0].cast(); - auto offsets = getDenseIntAttrFromConstant(op.getInputs()[1]); - auto strides = getDenseIntAttrFromConstant(op.getInputs()[3]); - assert(offsets.getNumElements() == strides.getNumElements() && - offsets.getNumElements() == retTp.getRank()); - SmallVector slice_attrs; - SmallVector static_offsets, static_sizes, static_strides; - for (auto [offset, size, stride] : - llvm::zip(offsets, retTp.getShape(), strides)) { - int64_t o = offset.getZExtValue(), s = stride.getZExtValue(); - // Converts limits to sizes. - slice_attrs.push_back( - sparse_tensor::SparseTensorDimSliceAttr::get(ctx, o, size, s)); - static_offsets.push_back(o); - static_sizes.push_back(size); - static_strides.push_back(s); - } - auto srcEnc = - retTp.getEncoding().cast(); - // TODO(peiming): add a getSliceEncodingFrom into MLIR upstream. - auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get( - ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getLvlToDim(), - srcEnc.getPosWidth(), srcEnc.getCrdWidth(), slice_attrs); - auto sliceTp = RankedTensorType::get(retTp.getShape(), - retTp.getElementType(), sliceEnc); - auto slice = rewriter.create( - loc, sliceTp, op.getInputs()[0], ValueRange(), ValueRange(), - ValueRange(), static_offsets, static_sizes, static_strides); - // TODO(peiming): This weakens the performance benefit we get from the - // sparse compiler by forcing every slice to be materialized while the - // sparse compiler supports view-based slice. - rewriter.replaceOpWithNewOp(op, retTp, slice); - return success(); - } -}; - -struct SparseTransposeCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 2 && "Need argument and permutation"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // The permutation is passed in as a constant of dense int elements. - auto permutation_constant = - op.getInputs()[1].getDefiningOp(); - auto permutation = - permutation_constant.getValue().cast(); - // Reconstruct the transpose operation. - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp( - op, ret_sp_tensor.getType(), op.getInputs()[0], permutation); - return success(); - } -}; - -template -struct SparseUnaryChloCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 1 && "Need one argument"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - // Reconstruct the unary chlo operation. - Value ret_sp_tensor = op.getResults()[0]; - rewriter.replaceOpWithNewOp(op, ret_sp_tensor.getType(), - op.getInputs()[0]); - return success(); - } -}; - -struct SparseDisassembleCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - // TODO(peiming): Canonicalizes these two cases. The old bridge that uses - // jax.BCOO/BCSR does not require buffer lengths. - unsigned disassemble_bufs_num = op.getInputs().size() - 1; - assert(op.getResults().size() == disassemble_bufs_num || - op.getResults().size() == disassemble_bufs_num * 2); - SmallVector disassemble_ret_tp( - op.getResults().take_front(disassemble_bufs_num).getTypes()); - // Extra lengths for each buffer returned. - disassemble_ret_tp.append(disassemble_bufs_num, rewriter.getIndexType()); - Value tensor = op.getInputs()[0]; - Value out_vals = op.getInputs()[1]; - ValueRange out_lvls = op.getInputs().drop_front(2); - // Constructs the disassembleOp. - auto disassemble_op = rewriter.create( - op.getLoc(), disassemble_ret_tp, tensor, out_vals, out_lvls); - assert(disassemble_op.getResults().size() == disassemble_bufs_num * 2); - ValueRange bufs = - disassemble_op.getResults().take_front(disassemble_bufs_num); - ValueRange lens = - disassemble_op.getResults().take_back(disassemble_bufs_num); - - // Wraps the scalar value into a "scalar tensor", i.e., tensor - SmallVector rets(bufs.begin(), bufs.end()); - if (op.getResults().size() == disassemble_bufs_num * 2) { - ValueRange ret_lens = op.getResults().take_back(disassemble_bufs_num); - for (auto [len, tensor_len] : llvm::zip(lens, ret_lens)) { - auto ret_t_len = rewriter.create( - op.getLoc(), tensor_len.getType(), ValueRange{}); - auto int_len = rewriter.create( - op.getLoc(), ret_t_len.getType().getElementType(), len); - auto ret_len = rewriter.create( - op.getLoc(), ret_t_len.getType(), int_len, ret_t_len, ValueRange{}); - rets.push_back(ret_len); - } - } - rewriter.replaceOp(op, rets); - return success(); - } -}; - -struct SparseSDDMMCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 3 && "Need S, A, B matrices"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - Location loc = op.getLoc(); - Value matS = op.getInputs()[0]; - Value matA = op.getInputs()[1]; - Value matB = op.getInputs()[2]; - auto etp = matS.getType().dyn_cast().getElementType(); - // Build the enveloping generic op with the following trait: - // indexing_maps = [ - // affine_map<(i,j,k) -> (i,k)>, // A - // affine_map<(i,j,k) -> (k,j)>, // B - // affine_map<(i,j,k) -> (i,j)> // S - // ], - // iterator_types = ["parallel", "parallel", "reduction"], - // doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)" - SmallVector iteratorTypes; - iteratorTypes.push_back(utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::reduction); - using MapList = ArrayRef>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; - AffineExpr i, j, k; - bindDims(op.getContext(), i, j, k); - auto indexingMaps = infer({{i, k}, {k, j}, {i, j}}); - auto genericOp = rewriter.create( - loc, TypeRange{matS.getType()}, ValueRange{matA, matB}, - ValueRange{matS}, indexingMaps, iteratorTypes); - // Construct semi-ring op. - Block* main = rewriter.createBlock(&genericOp.getRegion(), {}, - {etp, etp, etp}, {loc, loc, loc}); - Value argS = main->getArgument(2); - rewriter.setInsertionPointToStart(&genericOp.getRegion().front()); - auto semiring = rewriter.create(loc, etp, argS); - rewriter.createBlock(&semiring.getPresentRegion(), {}, etp, loc); - rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front()); - auto mul = rewriter.create(loc, main->getArgument(0), - main->getArgument(1)); - rewriter.create(loc, mul.getResult()); - rewriter.setInsertionPointAfter(semiring); - // Construct reduction op. - auto identity = - rewriter.create(loc, rewriter.getZeroAttr(etp)); - auto custom = rewriter.create( - loc, etp, argS, semiring.getResult(), identity); - Block* red = - rewriter.createBlock(&custom.getRegion(), {}, {etp, etp}, {loc, loc}); - rewriter.setInsertionPointToStart(&custom.getRegion().front()); - auto add = rewriter.create(loc, red->getArgument(0), - red->getArgument(1)); - rewriter.create(loc, add.getResult()); - rewriter.setInsertionPointAfter(custom); - rewriter.create(loc, custom.getResult()); - rewriter.replaceOp(op, genericOp.getResults()); - return success(); - } -}; - -// This rewriter rewrites 2:4 SpMM custom op to linalg.generic operator that -// carries the DENSE24 trait and does multiplication. -struct Sparse2To4SpMMCallRewriter { - LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) { - assert(op.getInputs().size() == 3 && "Need C, A, B matrices"); - assert(op.getResults().size() == 1 && "Need one output tensor"); - Location loc = op.getLoc(); - Value mat_c = op.getInputs()[0]; - Value mat_a = op.getInputs()[1]; - Value mat_b = op.getInputs()[2]; - - auto etp = mat_c.getType().dyn_cast().getElementType(); - // Build the enveloping generic op with the following trait: - // indexing_maps = [ - // affine_map<(i,j,k) -> (i,k)>, // A - // affine_map<(i,j,k) -> (k,j)>, // B - // affine_map<(i,j,k) -> (i,j)> // S - // ], - // iterator_types = ["parallel", "parallel", "reduction"], - // doc = "C(i,j) += SUM_k A(i,k) B(k,j)" - SmallVector iteratorTypes; - iteratorTypes.push_back(utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::reduction); - using MapList = ArrayRef>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; - AffineExpr i, j, k; - bindDims(op.getContext(), i, j, k); - auto indexing_maps = infer({{i, k}, {k, j}, {i, j}}); - auto generic_op = rewriter.create( - loc, TypeRange{mat_c.getType()}, ValueRange{mat_a, mat_b}, - ValueRange{mat_c}, indexing_maps, iteratorTypes); - // Set DENSE24 attribute. - generic_op->setAttr("DENSE24", rewriter.getI32IntegerAttr(1)); - // Construct operations in the linalg.generic block. - Block* main = rewriter.createBlock(&generic_op.getRegion(), {}, - {etp, etp, etp}, {loc, loc, loc}); - Value arg_c = main->getArgument(2); - rewriter.setInsertionPointToStart(&generic_op.getRegion().front()); - auto mul = rewriter.create(loc, main->getArgument(0), - main->getArgument(1)); - auto add = rewriter.create(loc, mul.getResult(), arg_c); - rewriter.create(loc, add.getResult()); - rewriter.replaceOp(op, generic_op.getResults()); - return success(); - } -}; - -class SparseCustomCallRewriter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - using SparseCustomTargetRewriter = std::function; - - const llvm::StringMap rewriter_map_{ - // Internal custom ops that need rewriting. - std::make_pair("sparse_tensor_add", - SparseBinaryCallRewriter()), - std::make_pair("sparse_tensor_asin", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_asinh", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_atan", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_atanh", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_bessel_i1e", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_broadcast_in_dim", - SparseBroadcastInDimCallRewriter()), - std::make_pair("sparse_tensor_concatenate", - SparseConcatenateCallRewriter()), - std::make_pair("sparse_tensor_conv_general_dilated", - SparseConvCallRewriter()), - std::make_pair("sparse_tensor_convert", SparseConvertCallRewriter()), - std::make_pair("sparse_tensor_dot_general", SparseDotCallRewriter()), - std::make_pair("sparse_tensor_dynamic_slice", - SparseDynSliceCallRewriter()), - std::make_pair( - "sparse_tensor_gt_SIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_gt_FLOAT", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_gt_UNSIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_lt_SIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_lt_FLOAT", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_lt_UNSIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair("sparse_tensor_mul", - SparseBinaryCallRewriter()), - std::make_pair( - "sparse_tensor_ne_SIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_ne_FLOAT", - SparseCmpNoEqualCallRewriter()), - std::make_pair( - "sparse_tensor_ne_UNSIGNED", - SparseCmpNoEqualCallRewriter()), - std::make_pair("sparse_tensor_reduce_and", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_max", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_min", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_or", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_prod", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_sum", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reduce_xor", - SparseReduceCallRewriter()), - std::make_pair("sparse_tensor_reshape", SparseReshapeCallRewriter()), - std::make_pair("sparse_tensor_select_n", SparseSelectRewriter()), - std::make_pair("sparse_tensor_sinh", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_slice", SparseSliceCallRewriter()), - std::make_pair("sparse_tensor_assemble", - SparseBatchedAssembleCallRewriter()), - std::make_pair("sparse_tensor_disassemble", - SparseDisassembleCallRewriter()), - std::make_pair("sparse_tensor_sub", - SparseBinaryCallRewriter()), - std::make_pair("sparse_tensor_tan", - SparseUnaryChloCallRewriter()), - std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()), - // User custom ops that need rewriting. - std::make_pair("sparse_jax_sddmm", SparseSDDMMCallRewriter()), - std::make_pair("sparse_jax_2to4_spmm", Sparse2To4SpMMCallRewriter()), - }; - - // Rewrites a CustomCallOp to corresponding sparse_tensor operation. - LogicalResult matchAndRewrite(mhlo::CustomCallOp op, - PatternRewriter& rewriter) const override { - if (auto it = rewriter_map_.find(op.getCallTargetName()); - it != rewriter_map_.end()) { - return it->second(op, rewriter); - } - // Returns failure on unmatched call target. - return failure(); - } -}; - -class SparseCustomCallRewritingPass - : public impl::SparseCustomCallRewritingPassBase< - SparseCustomCallRewritingPass> { - void runOnOperation() override { - func::FuncOp func = getOperation(); - MLIRContext* ctx = func.getContext(); - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createSparseCustomCallRewritingPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace xla diff --git a/xla/mlir/backends/cpu/transforms/tests/BUILD b/xla/mlir/backends/cpu/transforms/tests/BUILD deleted file mode 100644 index 52b72df95a3af..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/backends/cpu:xla-cpu-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) diff --git a/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir b/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir deleted file mode 100644 index a219ff457b6bc..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir +++ /dev/null @@ -1,489 +0,0 @@ -// RUN: xla-cpu-opt %s -xla-legalize-library-ops | FileCheck %s - -func.func @max_reduce(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.maximum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 2 - >, - use_global_device_ids - } : (tensor<10xf32>) -> tensor<10xf32> - func.return %0 : tensor<10xf32> -} - -// CHECK-LABEL: @max_reduce -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<10xf32> -// CHECK: %[[RET:.*]] = "xla_cpu.all_reduce"(%[[ARG0]], %[[DST]]) { -// CHECK-SAME: channel_handle = 5 : i64, -// CHECK-SAME: reduction_kind = 3 : i32, -// CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> -// CHECK-SAME: use_global_device_ids = 1 -// CHECK: return %[[RET]] - -func.func @and_reduce(%arg0: tensor<1xi1>) -> tensor<1xi1> { - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %1 = mhlo.and %lhs, %rhs : tensor - mhlo.return %1 : tensor - }) { - replica_groups = dense<> : tensor<0x0xi64> - } : (tensor<1xi1>) -> tensor<1xi1> - func.return %0 : tensor<1xi1> -} - -// CHECK-LABEL: @and_reduce -// CHECK: reduction_kind = 2 : i32, - -func.func @or_reduce(%arg0: tensor<1xi1>) -> tensor<1xi1> { - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %1 = mhlo.or %lhs, %rhs : tensor - mhlo.return %1 : tensor - }) { - replica_groups = dense<> : tensor<0x0xi64> - } : (tensor<1xi1>) -> tensor<1xi1> - func.return %0 : tensor<1xi1> -} - -// CHECK-LABEL: @or_reduce -// CHECK: reduction_kind = 3 : i32, - -func.func @min_reduce_dynamic(%arg0: tensor) -> tensor { - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.minimum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<> : tensor<0x0xi64>, - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 2 - > - } : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @min_reduce -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[DST:.*]] = tensor.empty(%[[DIM]]) -// CHECK: "xla_cpu.all_reduce"(%[[ARG0]], %[[DST]]) -// CHECK-SAME: reduction_kind = 2 -// CHECK-SAME: use_global_device_ids = 0 - -func.func @partition_id() -> tensor { - %0 = "mhlo.partition_id"() : () -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @partition_id -// CHECK: %[[ID:.*]] = "xla_cpu.partition_id"() : () -> i32 -// CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[ID]] : tensor -// CHECK: %[[CAST:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor -// CHECK: return %[[CAST]] - -func.func @replica_id() -> tensor { - %0 = "mhlo.replica_id"() : () -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @replica_id -// CHECK: %[[ID:.*]] = "xla_cpu.replica_id"() : () -> i32 -// CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[ID]] : tensor -// CHECK: %[[CAST:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor -// CHECK: return %[[CAST]] - -func.func @collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { - %0 = "mhlo.collective_permute"(%arg0) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, - channel_handle = #mhlo.channel_handle - } : (tensor<16x8xf32>) -> tensor<16x8xf32> - func.return %0 : tensor<16x8xf32> -} - -// CHECK-LABEL: @collective_permute -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x8xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<16x8xf32> -// CHECK: %[[RET:.*]] = "xla_cpu.collective_permute"(%[[ARG0]], %[[DST]]) { -// CHECK-SAME: channel_handle = 1 -// CHECK-SAME: source_target_pairs = dense< -// CHECK: return %[[RET]] - -func.func @collective_permute_dynamic(%arg0: tensor<16x?xf32>) - -> tensor<16x?xf32> { - %0 = "mhlo.collective_permute"(%arg0) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, - channel_handle = #mhlo.channel_handle - } : (tensor<16x?xf32>) -> tensor<16x?xf32> - func.return %0 : tensor<16x?xf32> -} - -// CHECK-LABEL: @collective_permute_dynamic -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x?xf32> -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[DST:.*]] = tensor.empty(%[[DIM]]) : tensor<16x?xf32> -// CHECK: "xla_cpu.collective_permute"(%[[ARG0]], %[[DST]]) { - -func.func @all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { - %0 = "mhlo.all_to_all"(%arg0) { - split_dimension = 1 : i64, - concat_dimension = 0 : i64, - split_count = 4 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> - } : (tensor<4x16xf32>) -> tensor<16x4xf32> - func.return %0 : tensor<16x4xf32> -} - -// CHECK-LABEL: @all_to_all -// CHECK-SAME: %[[ARG0:.*]]: tensor<4x16xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<16x4xf32> -// CHECK: %[[RET:.*]] = "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { -// CHECK-SAME: channel_id_present = 1 -// CHECK-SAME: concat_dimension = 0 -// CHECK-SAME: op_id = 2 -// CHECK-SAME: replica_groups = dense< -// CHECK-SAME: split_count = 4 -// CHECK-SAME: split_dimension = 1 -// CHECK: return %[[RET]] - -func.func @all_to_all_dynamic_concat_dim(%arg0: tensor) - -> tensor { - %0 = "mhlo.all_to_all"(%arg0) { - split_dimension = 1 : i64, - concat_dimension = 0 : i64, - split_count = 4 : i64, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> - } : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @all_to_all_dynamic_concat_dim -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 -// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[CONCAT_DIM:.*]] = arith.muli %[[DIM]], %[[C4]] -// CHECK: %[[DST:.*]] = tensor.empty(%[[CONCAT_DIM]]) : tensor -// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { - -func.func @all_to_all_dynamic_split_dim(%arg0: tensor<4x?xf32>) - -> tensor<16x?xf32> { - %0 = "mhlo.all_to_all"(%arg0) { - split_dimension = 1 : i64, - concat_dimension = 0 : i64, - split_count = 4 : i64, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> - } : (tensor<4x?xf32>) -> tensor<16x?xf32> - func.return %0 : tensor<16x?xf32> -} - -// CHECK-LABEL: @all_to_all_dynamic_split_dim -// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32> -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 -// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[CONCAT_DIM:.*]] = arith.divui %[[DIM]], %[[C4]] -// CHECK: %[[DST:.*]] = tensor.empty(%[[CONCAT_DIM]]) : tensor<16x?xf32> -// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { - -func.func @all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) - -> (tensor<128x4xf32>, tensor<128x4xf32>) { - %0:2 = "mhlo.all_to_all"(%arg0, %arg1) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - } : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) - return %0#0, %0#1 : tensor<128x4xf32>, tensor<128x4xf32> -} - -// CHECK-LABEL: @all_to_all_tuple -// CHECK-SAME: %[[ARG0:.*]]: tensor<128x4xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<128x4xf32> -// CHECK: %[[DST0:.*]] = tensor.empty() : tensor<128x4xf32> -// CHECK: %[[DST1:.*]] = tensor.empty() : tensor<128x4xf32> -// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[ARG1]], %[[DST0]], %[[DST1]]) - -func.func @outfeed_0_input(%token: !mhlo.token) -> !mhlo.token { - %res = "mhlo.outfeed"(%token) {outfeed_config = "foobar"} : (!mhlo.token) -> !mhlo.token - func.return %res : !mhlo.token -} - -// CHECK-LABEL: @outfeed_0_input -// CHECK: "xla_cpu.outfeed"() {config = "foobar", result_type = []} : () -> () - -func.func @outfeed_1_input(%data: tensor<2xui32>, %token: !mhlo.token) - -> !mhlo.token attributes {xlaframework.result_mapping = 1 : i32} { - %res = "mhlo.outfeed"(%data, %token) { - outfeed_config = "", xla_shape = "token[]" - } : (tensor<2xui32>, !mhlo.token) -> !mhlo.token - func.return %res : !mhlo.token -} - -// CHECK-LABEL: @outfeed_1_input -// CHECK-SAME: %[[DATA:.*]]: tensor<2xui32> -// CHECK-SAME: %[[TOKEN:.*]]: !mhlo.token -// CHECK: "xla_cpu.outfeed"(%[[DATA]]) {config = "", result_type = [ui32]} : (tensor<2xui32>) -> () -// CHECK: return %[[TOKEN]] : !mhlo.token - -func.func @outfeed_2_input(%data1: tensor<3xui32>, %data2: tensor<3xi32>, %token: !mhlo.token) -> !mhlo.token { - %res = "mhlo.outfeed"(%data1, %data2, %token) {outfeed_config = "foobar"} - : (tensor<3xui32>, tensor<3xi32>, !mhlo.token) -> !mhlo.token - func.return %res : !mhlo.token -} - -// CHECK-LABEL: @outfeed_2_input -// CHECK-SAME: %[[ARG0:.*]]: tensor<3xui32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<3xi32> -// CHECK: "xla_cpu.outfeed"(%[[ARG0]], %[[ARG1]]) {config = "foobar", result_type = [ui32, i32]} -// CHECK-SAME: (tensor<3xui32>, tensor<3xi32>) - -func.func @add_dependency(%arg0: tensor<16xf32>, %arg1: !mhlo.token) -> tensor<16xf32> { - %0 = "mhlo.add_dependency"(%arg0, %arg1) : (tensor<16xf32>, !mhlo.token) -> tensor<16xf32> - func.return %0 : tensor<16xf32> -} - -// CHECK-LABEL: @add_dependency -// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32> -// CHECK-SAME: %[[ARG1:.*]]: !mhlo.token -// CHECK: %[[RES:.*]] = "xla_cpu.add_dependency" -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor<16xf32> - -func.func @conv_i4(%arg0: tensor<64x8x8x8xi4>, %arg1: tensor<4x4x8x32xi4>) - -> tensor<64x3x3x32xi8> { - %0 = mhlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]} - {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : - (tensor<64x8x8x8xi4>, tensor<4x4x8x32xi4>) -> tensor<64x3x3x32xi8> - func.return %0 : tensor<64x3x3x32xi8> -} - -// CHECK-LABEL: @conv_i4 -// CHECK-SAME: %[[ARG0:.*]]: tensor<64x8x8x8xi4> -// CHECK-SAME: %[[ARG1:.*]]: tensor<4x4x8x32xi4> -// CHECK: %[[RES:.*]] = mhlo.convolution -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor<64x3x3x32xi8> - -func.func @conv_0d_nc(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>) - -> tensor<3x3xf32> { - %0 = mhlo.convolution(%arg0, %arg1) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], - reverse = []} - {batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#mhlo, #mhlo]} - : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} - -// CHECK-LABEL: @conv_0d_nc -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> -// CHECK: %[[RES:.*]] = mhlo.convolution -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor<3x3xf32> - -func.func @conv_1d_nwc_dyn(%arg0: tensor, %arg1: tensor<2x?x?xf32>) - -> tensor { - %0 = "mhlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0]]> : tensor<1x2xi64>, - rhs_dilation = dense<1> : tensor<1xi64>, - window_strides = dense<1> : tensor<1xi64>, - someattr - } : (tensor, tensor<2x?x?xf32>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @conv_1d_nwc_dyn -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: %[[ARG1:.*]]: tensor<2x?x?xf32> -// CHECK: %[[RES:.*]] = mhlo.convolution -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor - -func.func @depthwise_conv1d(%arg0: tensor<1x10x8xf32>, - %arg1: tensor<3x1x16xf32>) -> tensor<1x10x16xf32> { - %0 = mhlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], - window = { - stride = [1], - pad = [[1, 1]], - lhs_dilate = [1], - rhs_dilate = [1], - reverse = [0]} { - batch_group_count = 1 : i64, - feature_group_count = 8 : i64, - someattr} : (tensor<1x10x8xf32>, tensor<3x1x16xf32>) -> tensor<1x10x16xf32> - func.return %0 : tensor<1x10x16xf32> -} - -// CHECK-LABEL: @depthwise_conv1d -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x10x8xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<3x1x16xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<1x10x16xf32> -// CHECK: %[[RES:.*]] = "xla_cpu.convolution" -// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[DST]] -// CHECK: return %[[RES]] : tensor<1x10x16xf32> - -func.func @conv_2d_nhwc_hwcf(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x2x1x1xf32>) - -> tensor<1x2x4x1xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>, - lhs_dilation = dense<1> : tensor<2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x4x5x1xf32>, tensor<3x2x1x1xf32>) -> tensor<1x2x4x1xf32> - func.return %0 : tensor<1x2x4x1xf32> -} - -// CHECK-LABEL: @conv_2d_nhwc_hwcf -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x4x5x1xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<3x2x1x1xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<1x2x4x1xf32> -// CHECK: %[[RES:.*]] = "xla_cpu.convolution" -// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[DST]] -// CHECK: return %[[RES]] : tensor<1x2x4x1xf32> - -func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor<1x8x8x8x1xf32>, - %arg1: tensor<2x2x2x1x1xf32>) - -> tensor<1x7x7x7x1xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>, - lhs_dilation = dense<1> : tensor<3xi64>, - rhs_dilation = dense<1> : tensor<3xi64>, - window_strides = dense<1> : tensor<3xi64> - } : (tensor<1x8x8x8x1xf32>, tensor<2x2x2x1x1xf32>) -> tensor<1x7x7x7x1xf32> - func.return %0 : tensor<1x7x7x7x1xf32> -} - -// CHECK-LABEL: @conv_3d_ndhwc_dhwcf -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x8x8x8x1xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2x2x1x1xf32> -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<1x7x7x7x1xf32> -// CHECK: %[[RES:.*]] = "xla_cpu.convolution" -// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[DST]] -// CHECK: return %[[RES]] : tensor<1x7x7x7x1xf32> - -func.func @normal_convolution_with_reversal(%arg0: tensor<1x3x3x3xf32>, - %arg1: tensor<3x3x3x1xf32>) -> tensor<1x1x1x1xf32> { - %0 = mhlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { - stride = [1, 1], - pad = [[0, 0], [0, 0]], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [1, 1] - } { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, precision_config = [ - #mhlo, - #mhlo] - } : (tensor<1x3x3x3xf32>, tensor<3x3x3x1xf32>) -> tensor<1x1x1x1xf32> - return %0 : tensor<1x1x1x1xf32> -} - -// CHECK-LABEL: @normal_convolution_with_reversal -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x3x3xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x3x1xf32> -// CHECK: %[[RES:.*]] = mhlo.convolution -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor<1x1x1x1xf32> - -func.func @general_convolution_with_zero_sized_dimension_in_output( - %arg0: tensor<2x4x9x0xi64> {bufferization.writable = false, - xla_framework.input_mapping = 2 : i32}, - %arg1: tensor<4x5x2x4xi64> {bufferization.writable = false, - xla_framework.input_mapping = 0 : i32}) - -> tensor<2x5x0x4xi64> attributes {xla_framework.result_mapping = 1 : i32} { - %0 = mhlo.convolution(%arg0, %arg1) - dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [2, 1], pad = [[1, 2], [2, 0]], lhs_dilate = [1, 4], - rhs_dilate = [1, 1], reverse = [0, 0]} - {batch_group_count = 1 : i64, feature_group_count = 2 : i64, - precision_config = [#mhlo, #mhlo]} - : (tensor<2x4x9x0xi64>, tensor<4x5x2x4xi64>) -> tensor<2x5x0x4xi64> - return %0 : tensor<2x5x0x4xi64> -} - -// CHECK-LABEL: @general_convolution_with_zero_sized_dimension_in_output -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x9x0xi64> -// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5x2x4xi64> -// CHECK: %[[RES:.*]] = mhlo.convolution -// CHECK-SAME: %[[ARG0]], %[[ARG1]] -// CHECK: return %[[RES]] : tensor<2x5x0x4xi64> - -func.func @foo(%0: tensor<3x9x9x8xf32>, %1: tensor<1x7x8x8xf32>) - -> tensor<3x9x9x8xf32> { - %2 = mhlo.convolution(%0, %1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], - pad = [[0, 0], [3, 3]], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0]} - {batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#mhlo, - #mhlo]} - : (tensor<3x9x9x8xf32>, tensor<1x7x8x8xf32>) -> tensor<3x9x9x8xf32> - return %2 : tensor<3x9x9x8xf32> -} - -// CHECK-LABEL: @infeed -// CHECK: "xla_cpu.infeed" -func.func @infeed(%token: !mhlo.token) -> tensor<3x3xi32> { - %res:3 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout=[[0,1], [0]]} - : (!mhlo.token) -> (tensor<3x3xi32>, tensor, !mhlo.token) - func.return %res#0 : tensor<3x3xi32> -} diff --git a/xla/mlir/backends/cpu/transforms/tests/fft.mlir b/xla/mlir/backends/cpu/transforms/tests/fft.mlir deleted file mode 100644 index 72973b0add585..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/fft.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: xla-cpu-opt %s -xla-legalize-library-ops | FileCheck %s - -func.func @fft(%arg0: tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> { - %0 = "mhlo.fft"(%arg0) { - fft_length = dense<[4, 8, 256]> : tensor<3xi64>, - fft_type = #mhlo - } : (tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> - func.return %0 : tensor<3x5x4x8x129xcomplex> -} - -// CHECK-LABEL: @fft -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[DST:.*]] = tensor.empty() : tensor<3x5x4x8x129xcomplex> -// CHECK: %[[FFT:.*]] = "xla_cpu.fft"(%[[ARG0]], %[[DST]]) -// CHECK-SAME: {fft_length = [4, 8, 256], fft_type = 2 : i32} -// CHECK: return %[[FFT]] diff --git a/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir b/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir deleted file mode 100644 index 39914b387704d..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-legalize-i1-vector-transfers \ -// RUN: | FileCheck %s - -func.func @transfer_read(%in: memref<8xi1>) -> vector<8xi1> { - %pad = arith.constant true - %c1 = arith.constant 1 : index - %ret = vector.transfer_read %in[%c1], %pad : memref<8xi1>, vector<8xi1> - return %ret : vector<8xi1> -} - -// CHECK-LABEL: @transfer_read -// CHECK-SAME: %[[IN:.*]]: memref<8xi1> -// CHECK-DAG: %[[C1_I8:.*]] = arith.constant 1 : i8 -// CHECK-DAG: %[[C0_V:.*]] = arith.constant dense<0> : vector<8xi8> -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[CAST:.*]] = xla_cpu.memref_element_cast %[[IN]] -// CHECK: %[[READ:.*]] = vector.transfer_read %[[CAST]][%[[C1]]], -// CHECK-SAME: %[[C1_I8]] -// CHECK: %[[RET:.*]] = arith.cmpi ne, %[[READ]], %[[C0_V]] -// CHECK: return %[[RET]] - -func.func @transfer_write(%in: vector<8xi1>, %out: memref<8xi1>) { - %c0 = arith.constant 0 : index - vector.transfer_write %in, %out[%c0] : vector<8xi1>, memref<8xi1> - return -} - -// CHECK-LABEL: @transfer_write -// CHECK-SAME: %[[IN:.*]]: vector<8xi1> -// CHECK-SAME: %[[OUT:.*]]: memref<8xi1> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CAST_IN:.*]] = arith.extui %[[IN]] {{.*}} to vector<8xi8> -// CHECK-DAG: %[[CAST_OUT:.*]] = xla_cpu.memref_element_cast %[[OUT]] -// CHECK-NOT: vector.transfer_write {{.*}}%[[IN]] -// CHECK: vector.transfer_write %[[CAST_IN]], %[[CAST_OUT]][%[[C0]]] diff --git a/xla/mlir/backends/cpu/transforms/tests/library_ops_to_cpu_runtime.mlir b/xla/mlir/backends/cpu/transforms/tests/library_ops_to_cpu_runtime.mlir deleted file mode 100644 index 1aed4ad53e321..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/library_ops_to_cpu_runtime.mlir +++ /dev/null @@ -1,131 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-cpu-to-cpu-runtime | FileCheck %s - -func.func @partition_id() -> i32 { - %0 = "xla_cpu.partition_id"() : () -> i32 - func.return %0 : i32 -} - -// CHECK-LABEL: @partition_id -// CHECK: call @xla.cpu.partition_id() : () -> i32 - -// CHECK: func private @xla.cpu.partition_id() -> i32 attributes {rt.custom_call = "xla.cpu.partition_id"} - -// ----- - -func.func @replica_id() -> i32 { - %0 = "xla_cpu.replica_id"() : () -> i32 - func.return %0 : i32 -} - -// CHECK-LABEL: @replica_id -// CHECK: call @xla.cpu.replica_id() : () -> i32 - -// CHECK: func private @xla.cpu.replica_id() -> i32 attributes {rt.custom_call = "xla.cpu.replica_id"} - -// ----- - -#map = affine_map<(d0)[s0] -> (d0 + s0)> -func.func @all_reduce(%arg0: memref<32xf32, #map>, %arg1: memref<32xf32>) { - "xla_cpu.all_reduce"(%arg0, %arg1) { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - channel_handle = 42 : i64, - reduction_kind = 3 : i32, - use_global_device_ids = 0 : i32 - } : (memref<32xf32, #map>, memref<32xf32>) -> () - func.return -} - -// CHECK-LABEL: @all_reduce -// CHECK-SAME: %[[ARG0:.*]]: memref<32xf32, -// CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> -// CHECK: %[[ALLOC:.*]] = memref.alloc -// CHECK: memref.copy %[[ARG0]], %[[ALLOC]] -// CHECK: call @xla.cpu.all_reduce(%[[ALLOC]], %[[ARG1]]) -// CHECK-SAME: channel_handle = 42 -// CHECK-SAME: op_id = 0 -// CHECK-SAME: reduction_kind = 3 -// CHECK-SAME: replica_groups = dense< -// CHECK: func.func private @xla.cpu.all_reduce( -// CHECK-SAME: memref<32xf32>, memref<32xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.all_reduce"} - - -// ----- - -func.func @collective_permute(%arg0: memref<16x8xf32>, %arg1: memref<16x8xf32>) { - "xla_cpu.collective_permute"(%arg0, %arg1) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, - channel_handle = 42 : i64 - } : (memref<16x8xf32>, memref<16x8xf32>) -> () - func.return -} - -// CHECK-LABEL: @collective_permute -// CHECK-SAME: %[[ARG0:.*]]: memref<16x8xf32>, -// CHECK-SAME: %[[ARG1:.*]]: memref<16x8xf32> -// CHECK: call @xla.cpu.collective_permute(%[[ARG0]], %[[ARG1]]) -// CHECK-SAME: channel_handle = 42 -// CHECK-SAME: source_target_pairs = dense< -// CHECK: func.func private @xla.cpu.collective_permute( -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.collective_permute"} - -// ----- - -func.func @rng_bit_generator_default(%state: memref<3xui64>, - %state_out: memref<3xui64>, %values_out: memref<10xui32>) { - "xla_cpu.rng_bit_generator"(%state, %state_out, %values_out) - {rng_algorithm = #mhlo.rng_algorithm - } : (memref<3xui64>, memref<3xui64>, memref<10xui32>) -> () - return -} - -// CHECK-LABEL: @rng_bit_generator_default -// CHECK-SAME: %[[ARG0:.*]]: memref<3xui64>, %[[ARG1:.*]]: memref<3xui64>, -// CHECK-SAME: %[[ARG2:.*]]: memref<10xui32> -// CHECK: call @xla_cpu_rng_philox(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK: func.func private @xla_cpu_rng_philox( -// CHECK-SAME: attributes {rt.custom_call = "xla_cpu_rng_philox"} - -// ----- - -func.func @rng_bit_generator_three_fry(%state: memref<2xui64>, - %state_out: memref<2xui64>, %values_out: memref<10xui32>) { - "xla_cpu.rng_bit_generator"(%state, %state_out, %values_out) - {rng_algorithm = #mhlo.rng_algorithm - } : (memref<2xui64>, memref<2xui64>, memref<10xui32>) -> () - return -} - -// CHECK-LABEL: @rng_bit_generator_three_fry -// CHECK: call @xla_cpu_rng_three_fry( -// CHECK: func.func private @xla_cpu_rng_three_fry( -// CHECK-SAME: attributes {rt.custom_call = "xla_cpu_rng_three_fry"} - -// ----- - -func.func @conv_2d_nhwc_hwcf(%arg0: memref<1x4x5x1xf32>, %arg1: memref<3x2x1x1xf32>, %out: memref<1x2x4x1xf32>) { - "xla_cpu.convolution"(%arg0, %arg1, %out) {batch_group_count = 1 : i64, feature_group_count = 1 : i64, inputBatchDimension = 0 : i64, inputFeatureDimension = 3 : i64, inputSpatialDimensions = [1, 2], kernelInputFeatureDimension = 2 : i64, kernelOutputFeatureDimension = 3 : i64, kernelSpatialDimensions = [0, 1], lhs_dilation = dense<1> : tensor<2xi64>, outputBatchDimension = 0 : i64, outputFeatureDimension = 3 : i64, outputSpatialDimensions = [1, 2], padding = dense<0> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (memref<1x4x5x1xf32>, memref<3x2x1x1xf32>, memref<1x2x4x1xf32>) -> () - return -} - -// ----- - -func.func @conv_3d_ndhwc_dhwcf(%arg0: memref<1x8x8x8x1xf32>, %arg1: memref<2x2x2x1x1xf32>, %out: memref<1x7x7x7x1xf32>) { - "xla_cpu.convolution"(%arg0, %arg1, %out) {batch_group_count = 1 : i64, feature_group_count = 1 : i64, inputBatchDimension = 0 : i64, inputFeatureDimension = 4 : i64, inputSpatialDimensions = [1, 2, 3], kernelInputFeatureDimension = 3 : i64, kernelOutputFeatureDimension = 4 : i64, kernelSpatialDimensions = [0, 1, 2], lhs_dilation = dense<1> : tensor<3xi64>, outputBatchDimension = 0 : i64, outputFeatureDimension = 4 : i64, outputSpatialDimensions = [1, 2, 3], padding = dense<0> : tensor<3x2xi64>, rhs_dilation = dense<1> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (memref<1x8x8x8x1xf32>, memref<2x2x2x1x1xf32>, memref<1x7x7x7x1xf32>) -> () - return -} - -// ----- - -func.func @depthwise_conv1d(%arg0: memref<1x10x8xf32>, %arg1: memref<3x1x16xf32>, %out: memref<1x10x16xf32>) { - "xla_cpu.convolution"(%arg0, %arg1, %out) {batch_group_count = 1 : i64, feature_group_count = 8 : i64, inputBatchDimension = 0 : i64, inputFeatureDimension = 2 : i64, inputSpatialDimensions = [1], kernelInputFeatureDimension = 1 : i64, kernelOutputFeatureDimension = 2 : i64, kernelSpatialDimensions = [0], lhs_dilation = dense<1> : tensor<1xi64>, outputBatchDimension = 0 : i64, outputFeatureDimension = 2 : i64, outputSpatialDimensions = [1], padding = dense<1> : tensor<1x2xi64>, rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (memref<1x10x8xf32>, memref<3x1x16xf32>, memref<1x10x16xf32>) -> () - return -} - -// ----- - -func.func @foo(%arg0: memref<3x9x9x8xf32>, %arg1: memref<1x7x8x8xf32>, %out: memref<3x9x9x8xf32>) { - "xla_cpu.convolution"(%arg0, %arg1, %out) {batch_group_count = 1 : i64, feature_group_count = 1 : i64, inputBatchDimension = 0 : i64, inputFeatureDimension = 3 : i64, inputSpatialDimensions = [1, 2], kernelInputFeatureDimension = 2 : i64, kernelOutputFeatureDimension = 3 : i64, kernelSpatialDimensions = [0, 1], lhs_dilation = dense<1> : tensor<2xi64>, outputBatchDimension = 0 : i64, outputFeatureDimension = 3 : i64, outputSpatialDimensions = [1, 2], padding = dense<[[0, 0], [3, 3]]> : tensor<2x2xi64>, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense : tensor<2xi1>, window_strides = dense<1> : tensor<2xi64>} : (memref<3x9x9x8xf32>, memref<1x7x8x8xf32>, memref<3x9x9x8xf32>) -> () - return -} - diff --git a/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir b/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir deleted file mode 100644 index a06ad795e58c9..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir +++ /dev/null @@ -1,90 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-cpu-to-cpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: func @test -// CHECK: %[[ARG0:.*]]: memref -// CHECK: ) -func.func @test(%arg0: memref) { - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]]) - // CHECK-SAME: api_version = 2 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - // CHECK-SAME: num_results = 1 : i32 - // CHECK-SAME: output_tuple = false - // CHECK-SAME: : (memref) -> () - "lmhlo.custom_call"(%arg0) ({}) { - api_version = 2 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array - } : (memref) -> () - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.custom_call"} - -// ----- - -// CHECK: func @test_with_mapping -// CHECK: %[[ARG0:[0-9a-z]*]]: memref, -// CHECK: %[[ARG1:[0-9a-z]*]]: memref, -// CHECK: %[[ARG2:[0-9a-z]*]]: memref, -// CHECK: %[[ARG3:[0-9a-z]*]]: memref, -// CHECK: %[[ARG4:[0-9a-z]*]]: memref -// CHECK: ) -func.func @test_with_mapping( - %arg0: memref, - %arg1: memref, - %arg2: memref, - %arg3: memref, - %arg4: memref) { - // CHECK: %[[HOLE:.*]] = memref.alloca() : memref<0xi8> - - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]], %[[HOLE]], %[[ARG1]], %[[HOLE]], - // CHECK-SAME: %[[ARG2]], %[[ARG3]], %[[HOLE]], %[[ARG4]]) - // CHECK-SAME: api_version = 1 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - // CHECK-SAME: num_results = 4 : i32 - // CHECK-SAME: output_tuple = true - "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) ({}) { - api_version = 1 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array, - target_arg_mapping = #lmhlo.custom_call_target_arg_mapping< - num_args = 4, - num_results = 4, - args_to_target_args = [0, 2], - results_to_target_results = [0, 1, 3]> - } : (memref, memref, memref, memref, memref) -> () - - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref, memref<0xi8>, -// CHECK-SAME: memref, memref<0xi8>, memref, memref, -// CHECK-SAME: memref<0xi8>, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.custom_call"} - -// ----- - -// CHECK: func @one_element_output_tuple -// CHECK: %[[ARG0:.*]]: memref -// CHECK: ) -func.func @one_element_output_tuple(%arg0: memref) { - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]]) - // CHECK-SAME: api_version = 2 : i32 - // CHECK-SAME: call_target_name = "target" - // CHECK-SAME: num_results = 1 : i32 - // CHECK-SAME: output_tuple = true - // CHECK-SAME: : (memref) -> () - "lmhlo.custom_call"(%arg0) ({}) { - api_version = 2 : i32, - call_target_name = "target", - operandSegmentSizes = array, - xla_shape = "(f32[])" - } : (memref) -> () - return -} diff --git a/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir b/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir deleted file mode 100644 index fbe3b502ca6ce..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir +++ /dev/null @@ -1,127 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-remove-copies-to-out-params \ -// RUN: | FileCheck %s - -func.func @alloca(%arg0: memref, %arg1: memref) { - %0 = memref.load %arg0[] : memref - %1 = arith.addf %0, %0 : f64 - %alloca = memref.alloca() : memref - memref.store %1, %alloca[] : memref - memref.copy %alloca, %arg1 : memref to memref - return -} - -// CHECK-LABEL: func.func @alloca( -// CHECK-SAME: %[[ARG0:.*]]: memref, -// CHECK-SAME: %[[ARG1:.*]]: memref) { -// CHECK: %[[R0:.*]] = memref.load %[[ARG0]][] : memref -// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] : f64 -// CHECK-NOT memref.alloca -// CHECK: memref.store %[[R1]], %[[ARG1]][] : memref -// CHECK-NOT: memref.copy -// CHECK-NEXT: return -// CHECK: } - -// ----- - -func.func @alloc_vectorized(%arg0: memref<1024xf64>, %arg1: memref<1024xf64>) { - %c1024 = arith.constant 1024 : index - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %cst = arith.constant 0.000000e+00 : f64 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024xf64> - scf.parallel (%arg2) = (%c0) to (%c1024) step (%c8) { - %subview = memref.subview %alloc[%arg2] [8] [1] : - memref<1024xf64> to memref<8xf64, strided<[1], offset: ?>> - %0 = vector.transfer_read %arg0[%arg2], %cst {in_bounds = [true]} : - memref<1024xf64>, vector<8xf64> - %1 = arith.addf %0, %0 : vector<8xf64> - vector.transfer_write %1, %subview[%c0] {in_bounds = [true]} : - vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> - scf.yield - } - memref.copy %alloc, %arg1 : memref<1024xf64> to memref<1024xf64> - memref.dealloc %alloc : memref<1024xf64> - return -} - -// CHECK-LABEL: func.func @alloc_vectorized( -// CHECK-SAME: %[[ARG0:.*]]: memref<1024xf64>, -// CHECK-SAME: %[[ARG1:.*]]: memref<1024xf64>) { -// CHECK-NOT: memref.alloc -// CHECK: scf.parallel -// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG1]] -// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG0]] -// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] : vector<8xf64> -// CHECK: vector.transfer_write %[[R1]], %[[SUBVIEW]] -// CHECK: scf.yield -// CHECK: } -// CHECK-NOT: memref.copy -// CHECK-NOT: memref.dealloc -// CHECK-NEXT: return -// CHECK: } - -// ----- - -// Similar to alloc_vectorized, but with two output params (%arg1 and %arg2). -// Note: %arg1 = %arg0 + %arg0, and %arg2 = (%arg0 + %arg0) * %arg0 -func.func @alloc2_vectorized(%arg0: memref<256xf64>, - %arg1: memref<256xf64>, - %arg2: memref<256xf64>) { - %c256 = arith.constant 256 : index - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %cst = arith.constant 0.000000e+00 : f64 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<256xf64> - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<256xf64> - scf.parallel (%arg3) = (%c0) to (%c256) step (%c8) { - %alloca = memref.alloca() : memref<8xf64> - %0 = vector.transfer_read %arg0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> - %1 = arith.addf %0, %0 : vector<8xf64> - vector.transfer_write %1, %alloca[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64> - %subview = memref.subview %alloc_0[%arg3] [8] [1] : memref<256xf64> to memref<8xf64, strided<[1], offset: ?>> - memref.copy %alloca, %subview : memref<8xf64> to memref<8xf64, strided<[1], offset: ?>> - scf.yield - } - scf.parallel (%arg3) = (%c0) to (%c256) step (%c8) { - %subview = memref.subview %alloc[%arg3] [8] [1] : memref<256xf64> to memref<8xf64, strided<[1], offset: ?>> - %0 = vector.transfer_read %alloc_0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> - %1 = vector.transfer_read %arg0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> - %2 = arith.mulf %0, %1 : vector<8xf64> - vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> - scf.yield - } - memref.copy %alloc_0, %arg1 : memref<256xf64> to memref<256xf64> - memref.dealloc %alloc_0 : memref<256xf64> - memref.copy %alloc, %arg2 : memref<256xf64> to memref<256xf64> - memref.dealloc %alloc : memref<256xf64> - return -} - -// CHECK-LABEL: func.func @alloc2_vectorized( -// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: memref<256xf64>, -// CHECK-SAME: %[[ARG1:.*]]: memref<256xf64>, -// CHECK-SAME: %[[ARG2:.*]]: memref<256xf64>) { -// CHECK-NOT: memref.alloc -// CHECK: scf.parallel -// CHECK: %[[ALLOCA:.*]] = memref.alloca() -// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG0]] -// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] -// CHECK: vector.transfer_write %[[R1]], %[[ALLOCA]] -// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG1]] -// CHECK: memref.copy %[[ALLOCA]], %[[SUBVIEW]] -// CHECK: scf.yield -// CHECK: } -// CHECK-NOT: memref.copy -// CHECK-NOT: memref.dealloc -// CHECK-NEXT: scf.parallel -// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG2]] -// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG1]] -// CHECK: %[[R1:.*]] = vector.transfer_read %[[ARG0]] -// CHECK: %[[R2:.*]] = arith.mulf %[[R0]], %[[R1]] -// CHECK: vector.transfer_write %[[R2]], %[[SUBVIEW]] -// CHECK: scf.yield -// CHECK: } -// CHECK-NOT: memref.copy -// CHECK-NOT: memref.dealloc -// CHECK-NEXT: return -// CHECK: } diff --git a/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir b/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir deleted file mode 100644 index fc7b6ba99e32b..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: xla-cpu-opt %s -xla-legalize-library-ops | FileCheck %s - -func.func @rng_bit_generator(%state: tensor<2xui64>) -> (tensor<2xui64>, tensor<10xui32>) { - %new_state, %output = "mhlo.rng_bit_generator"(%state) { - rng_algorithm = #mhlo.rng_algorithm - } : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10xui32>) - func.return %new_state, %output : tensor<2xui64>, tensor<10xui32> -} - -// CHECK-LABEL: @rng_bit_generator -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK: %[[STATE_INIT:.*]] = tensor.empty() : tensor<2xui64> -// CHECK: %[[DST_INIT:.*]] = tensor.empty() : tensor<10xui32> -// CHECK: "xla_cpu.rng_bit_generator"(%[[ARG0]], %[[STATE_INIT]], %[[DST_INIT]]) -// CHECK-SAME: {rng_algorithm = #mhlo.rng_algorithm} : -// CHECK-SAME: (tensor<2xui64>, tensor<2xui64>, tensor<10xui32>) -> (tensor<2xui64>, tensor<10xui32>) diff --git a/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir b/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir deleted file mode 100644 index 9e6b3520a3f13..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir +++ /dev/null @@ -1,139 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-legalize-abi \ -// RUN: -allow-unregistered-dialect \ -// RUN: | FileCheck %s - -func.func @all_custom(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) - -> tensor<2x3x4xf32> attributes { - xla_entry_computation_parameter_layouts = [ - dense<[0, 1, 2]> : tensor<3xindex>, - dense<[1, 2, 0]> : tensor<3xindex> - ], - xla_entry_computation_result_layout = dense<[2, 0, 1]> : tensor<3xindex> - } { - %add = mhlo.add %arg0, %arg1 : tensor<2x3x4xf32> - func.return %add : tensor<2x3x4xf32> -} - -// CHECK-LABEL: @all_custom -// CHECK-SAME: %[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}} -// CHECK-NOT: attributes -// CHECK: %[[R0:.*]] = mhlo.reshape %[[ARG0]] {{.*}} -> tensor<4x3x2xf32> -// CHECK: %[[T0:.*]] = "mhlo.transpose"(%[[R0]]) {{.*}} -> tensor<2x3x4xf32> -// CHECK: %[[R1:.*]] = mhlo.reshape %[[ARG1]] {{.*}} -> tensor<2x4x3xf32> -// CHECK: %[[T1:.*]] = "mhlo.transpose"(%[[R1]]) {{.*}} -> tensor<2x3x4xf32> -// CHECK: %[[ADD:.*]] = mhlo.add %[[T0]], %[[T1]] -// CHECK: %[[TR:.*]] = "mhlo.transpose"(%[[ADD]]) {{.*}} -> tensor<3x2x4xf32> -// CHECK: %[[RR:.*]] = mhlo.reshape %[[TR]] {{.*}} -> tensor<2x3x4xf32> -// CHECK: return %[[RR]] - -// ----- - -func.func @scalar_and_default_args(%arg0: tensor, %arg1: tensor<2x3xf32>, - %arg2: tensor<2x3xf32>) -> tensor attributes { - xla_entry_computation_parameter_layouts = [ - dense<> : tensor<0xindex>, - dense<[0, 1]> : tensor<2xindex>, - dense<[1, 0]> : tensor<2xindex> - ], - xla_entry_computation_result_layout = dense<> : tensor<0xindex> - } { - %result = "test.dummy"(%arg0, %arg1, %arg2) : - (tensor, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor - func.return %result : tensor -} - -// CHECK-LABEL: @scalar_and_default_args -// CHECK-SAME: %[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}} -// CHECK: %[[R1:.*]] = mhlo.reshape %[[ARG1]] {{.*}} -> tensor<3x2xf32> -// CHECK: %[[T1:.*]] = "mhlo.transpose"(%[[R1]]) {{.*}} -> tensor<2x3xf32> -// CHECK: %[[R:.*]] = "test.dummy"(%[[ARG0]], %[[T1]], %[[ARG2]]) -// CHECK: return %[[R]] - -// ----- - -func.func @two_scalar_return_values() -> (tensor, tensor) attributes { - xla_entry_computation_result_layout = [ - dense<> : tensor<0xindex>, - dense<> : tensor<0xindex> - ] - } { - %result:2 = "test.dummy"() : () -> (tensor, tensor) - func.return %result#0, %result#1 : tensor, tensor -} - -// CHECK-LABEL: @two_scalar_return_values - -// ----- - -func.func @return_i1() -> (tensor) { - %result = "test.dummy"() : () -> tensor - func.return %result : tensor -} - -// CHECK-LABEL: @return_i1() -> tensor -// CHECK: %[[I1:.*]] = "test.dummy"() : () -> tensor -// CHECK: %[[U8:.*]] = mhlo.convert %[[I1]] {{.*}} -> tensor -// CHECK: return %[[U8]] - -// ----- - -func.func @custom_call(%arg0: tensor, %arg1: tensor<2x3xf32>) -> (tensor<6x3xf32>, tensor<3xf32>) { - %result:2 = "mhlo.custom_call"(%arg0, %arg1) { - call_target_name = "yolo", - operand_layouts = [dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], - result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>] - } : (tensor, tensor<2x3xf32>) -> (tensor<6x3xf32>, tensor<3xf32>) - return %result#0, %result#1 : tensor<6x3xf32>, tensor<3xf32> -} - -// CHECK-LABEL: @custom_call -// CHECK-SAME: %[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}} -// CHECK-NOT: operand_layouts -// CHECK-NOT: result_layouts -// CHECK: %[[T1:.*]] = "mhlo.transpose"(%[[ARG1]]) {{.*}} -> tensor<3x2xf32> -// CHECK: %[[R1:.*]] = mhlo.reshape %[[T1]] {{.*}} -> tensor<2x3xf32> -// CHECK: %[[CC:.*]]:2 = mhlo.custom_call @yolo(%[[ARG0]], %[[R1]]) -// CHECK: %[[RR:.*]] = mhlo.reshape %[[CC]]#0 {{.*}} -> tensor<3x6xf32> -// CHECK: %[[TR:.*]] = "mhlo.transpose"(%[[RR]]) {{.*}} -> tensor<6x3xf32> -// CHECK: return %[[TR]], %[[CC]]#1 - -// ----- - -func.func @custom_call_i1_input(%arg0: tensor<42xi1>) { - "mhlo.custom_call"(%arg0) { call_target_name = "yolo" } - : (tensor<42xi1>) -> () - return -} - -// CHECK-LABEL: @custom_call_i1_input -// CHECK: %[[CONVERTED:.*]] = mhlo.convert {{.*}} : (tensor<42xi1>) -> tensor<42xui8> -// CHECK: mhlo.custom_call @yolo(%[[CONVERTED]]) - -// ----- - -func.func @constant_with_layout() -> tensor<2x3xf32> { - %c = "mhlo.constant"() { - value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>, - result_layout = dense<[0, 1]> : tensor<2xindex> - } : () -> tensor<2x3xf32> - return %c : tensor<2x3xf32> -} - -// CHECK-LABEL: @constant_with_layout -// CHECK: %[[CST:.*]] = mhlo.constant {{.*}} : tensor<3x2xf32> -// CHECK: %[[TR:.*]] = "mhlo.transpose"(%[[CST]]) {{.*}} -> tensor<2x3xf32> -// CHECK: return %[[TR]] - -// ----- - -func.func @non_tensor_inouts() -> !mhlo.token { - %0 = mhlo.create_token : !mhlo.token - %1 = "mhlo.custom_call"(%0) { - call_target_name = "yolo", - operand_layouts = [dense<> : tensor<0xindex>], - result_layouts = [dense<> : tensor<0xindex>] - } : (!mhlo.token) -> (!mhlo.token) - return %1 : !mhlo.token -} - -// CHECK-LABEL: @non_tensor_inouts diff --git a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_infeed.mlir b/xla/mlir/backends/cpu/transforms/tests/xla_cpu_infeed.mlir deleted file mode 100644 index 5ddb959a7713b..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_infeed.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: xla-cpu-opt %s -xla-cpu-to-cpu-runtime | FileCheck %s - -func.func @infeed(%arg0 : memref<3x3xi32>, %arg1 : memref) -> () { - "xla_cpu.infeed"(%arg0, %arg1) {config = "foobar", layout = [[0, 1], [0]]} - : (memref<3x3xi32>, memref) -> () - return -} - -// CHECK: func @infeed( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<3x3xi32> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref -// CHECK-SAME: ) -// CHECK: call @[[INFEED:.*]](%[[ARG0]], %[[ARG1]]) -// CHECK SAME: : (memref<3x3xi32>, memref) -> () -// CHECK: func private @[[INFEED]](memref<3x3xi32>, memref) -// CHECK-SAME: attributes {rt.custom_call = "[[INFEED]]"} diff --git a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir b/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir deleted file mode 100644 index f56b3b1566e8c..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir +++ /dev/null @@ -1,42 +0,0 @@ -// RUN: xla-cpu-opt -xla-convert-memref-element-cast-to-llvm %s \ -// RUN: -split-input-file | FileCheck %s - -func.func @memref_cast(%arg0: memref<10xf32>) -> memref<10xi32> { - %ret = xla_cpu.memref_element_cast %arg0 : memref<10xf32> to memref<10xi32> - return %ret : memref<10xi32> -} -// CHECK-LABEL: func.func @memref_cast( -// CHECK-SAME: %[[SRC:.*]]: memref<10xf32>) -> memref<10xi32> -// CHECK: %[[SRC_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] -// CHECK-SAME: : memref<10xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[ALLOC_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][0] -// CHECK-NEXT: %[[ALIGN_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][1] - -// CHECK: %[[DST_DESC:.*]] = llvm.mlir.undef -// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[DST_DESC_:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DST_DESC]][0] -// CHECK-NEXT: llvm.insertvalue %[[ALIGN_PTR]], %[[DST_DESC_]][1] - -// CHECK: builtin.unrealized_conversion_cast -// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<10xi32> - -// ----- - -func.func @memref_cast_i1(%arg0: memref<10xi1>) -> memref<10xi8> { - %ret = xla_cpu.memref_element_cast %arg0 : memref<10xi1> to memref<10xi8> - return %ret : memref<10xi8> -} -// CHECK-LABEL: func.func @memref_cast_i1( -// CHECK-SAME: %[[SRC:.*]]: memref<10xi1>) -> memref<10xi8> -// CHECK: %[[SRC_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] -// CHECK-SAME: : memref<10xi1> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[ALLOC_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][0] -// CHECK-NEXT: %[[ALIGN_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][1] - -// CHECK: %[[DST_DESC:.*]] = llvm.mlir.undef -// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[DST_DESC_:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DST_DESC]][0] -// CHECK-NEXT: llvm.insertvalue %[[ALIGN_PTR]], %[[DST_DESC_]][1] - -// CHECK: builtin.unrealized_conversion_cast -// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<10xi8> diff --git a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir b/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir deleted file mode 100644 index f4f17c9470587..0000000000000 --- a/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -xla-cpu-to-cpu-runtime \ -// RUN: | FileCheck %s - -func.func @cpu_outfeed(%arg0: memref<8xf32>, %arg1: memref<10xui32>) { - "xla_cpu.outfeed"(%arg0, %arg1) {config = "abc", result_type = [f32, ui32]} : (memref<8xf32>, memref<10xui32>) -> () - return -} - -// CHECK: func @cpu_outfeed( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<8xf32> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<10xui32> -// CHECK-SAME: ) -// CHECK: call @[[OUTFEED:.*]](%[[ARG0]], %[[ARG1]]) -// CHECK-SAME: {result_type = [11 : i32, 8 : i32]} : (memref<8xf32>, memref<10xui32>) -> () -// CHECK: func private @[[OUTFEED]](memref<8xf32>, memref<10xui32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.outfeed"} - -// ----- - -func.func @cpu_outfeed_strided( - %arg0: memref<8x8xf32, strided<[?, 1], offset: ?>>, - %arg1: memref<10xui32>) { - "xla_cpu.outfeed"(%arg0, %arg1) {config = "abc", result_type = [f32, ui32]} - : (memref<8x8xf32, strided<[?, 1], offset: ?>>, memref<10xui32>) -> () - return -} - -// CHECK: func @cpu_outfeed_strided( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<8x8xf32, strided<[?, 1], offset: ?>> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<10xui32> -// CHECK-SAME: ) -// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: memref.copy %[[ARG0]], %[[ALLOC]] -// CHECK: call @[[OUTFEED:.*]](%[[ALLOC]], %[[ARG1]]) -// CHECK-SAME: {result_type = [11 : i32, 8 : i32]} : (memref<8x8xf32>, memref<10xui32>) -> () -// CHECK: func private @[[OUTFEED]](memref<8x8xf32>, memref<10xui32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.outfeed"} diff --git a/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc b/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc index 8383b7c0b79c3..03298f338e798 100644 --- a/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc +++ b/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc b/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc index 1bd61e1420c07..8bedb9be524ae 100644 --- a/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc +++ b/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc b/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc index d297114f9ca95..fb3bb71548c6f 100644 --- a/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc +++ b/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,6 @@ limitations under the License. #include "xla/mlir/runtime/transforms/type_converter.h" #include "xla/mlir/runtime/utils/custom_calls.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" @@ -48,8 +47,6 @@ namespace { using namespace mlir; // NOLINT -using mlir::lmhlo::CustomCallOp; - using xla_cpu::PartitionIdOp; using xla_cpu::ReplicaIdOp; @@ -115,133 +112,6 @@ func::CallOp CreateCallForDpsCollectiveOp(Operation* op, //===----------------------------------------------------------------------===// -class CustomCallOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.cpu.custom_call"; - - public: - CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime - // custom calls bypassing custom call adaptor. - LogicalResult rewriteTypedCustomCall(CustomCallOp op, - PatternRewriter& rewriter) const { - // TODO(ezhulenev): Support target arg mapping, or explain why we do not - // need them for typed custom calls. - if (op.getTargetArgMapping()) - return op.emitOpError( - "API_VERSION_TYPED_FFI custom calls do not " - "support target arg mapping"); - - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); - callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); - - // Forward backend config to the custom call implementation. - auto config = op.getBackendConfig(); - if (!config) return op.emitOpError("Failed to get backend config"); - auto dict = config->cast(); - llvm::SmallVector backend_config(dict.begin(), dict.end()); - - // Call the custom call function forwarding user-defined attributes. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, backend_config); - - return success(); - } - - LogicalResult matchAndRewrite(CustomCallOp op, - PatternRewriter& rewriter) const override { - // Typed custom calls lowered directly to XLA runtime custom calls. - if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) - return rewriteTypedCustomCall(op, rewriter); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // By default all operands passed to the custom call handler. - llvm::SmallVector operands = op.getOperands(); - - // Get the number of outputs from operand_segment_sizes. - int64_t num_results = op->getAttrOfType( - op.getOperandSegmentSizesAttrName())[1]; - - // If custom call has target arguments mapping, then we need to pass empty - // memrefs in place of holes. - if (op.getTargetArgMapping().has_value()) { - auto mapping = *op.getTargetArgMapping(); - int64_t num_args = mapping.getNumArgs(); - num_results = mapping.getNumResults(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value hole = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().front()); - return b.create(MemRefType::get({0}, b.getI8Type())); - }(); - - // We represent holes as empty i8 memrefs. - operands = llvm::SmallVector(num_args + num_results, hole); - - // Update operands to mapped custom call arguments. - auto args = mapping.getArgsToTargetArgs(); - for (const auto& indexed : llvm::enumerate(args)) - operands[indexed.value()] = op.getArgs()[indexed.index()]; - - // Update operands to mapped custom call results. - auto res = mapping.getResultsToTargetResults(); - for (const auto& indexed : llvm::enumerate(res)) - operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; - } - - // TODO(jreiffers): This will break if an output has a non-default layout. - operands = EnsureFlatMemrefs(operands, b); - // Create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); - - // The ABI is different depending on whether the original op was outputting - // a tuple or not. For multiple outputs this is trivial but for a single - // output we rely on the xla_shape attribute to distinguish the ABIs. - bool output_tuple = num_results > 1; - if (auto xla_shape = op->getAttrOfType("xla_shape")) - output_tuple = ParseShape(xla_shape.strref())->IsTuple(); - - // This is not equivalent to op.getApiVersionAttr() - that call returns null - // if the attribute is absent. getApiVersion returns the default. - Attribute api_version = - mhlo::CustomCallApiVersionAttr::get(getContext(), op.getApiVersion()); - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("num_results"), - b.getI32IntegerAttr(static_cast(num_results))}, - {b.getStringAttr("output_tuple"), b.getBoolAttr(output_tuple)}, - {b.getStringAttr("api_version"), api_version}, - {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; - - if (auto backend_config = op.getBackendConfigAttr()) { - custom_call_attrs.emplace_back(b.getStringAttr("backend_config"), - op.getBackendConfigAttr()); - } - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - template class IdOpLowering : public OpRewritePattern { public: @@ -542,11 +412,10 @@ void ConvertXlaCpuToCpuRuntimePass::runOnOperation() { // Convert xla_cpu operations to XLA cpu runtime custom calls. RewritePatternSet patterns(ctx); - patterns - .insert( - ctx, custom_calls); + patterns.insert( + ctx, custom_calls); patterns.insert>(ctx, "xla.cpu.partition_id", custom_calls); patterns.insert>(ctx, "xla.cpu.replica_id", diff --git a/xla/mlir/backends/cpu/transforms/xla_rewrite_realloc_to_alloc.cc b/xla/mlir/backends/cpu/transforms/xla_rewrite_realloc_to_alloc.cc index 0f43b6799d15f..f358eb4f27c7c 100644 --- a/xla/mlir/backends/cpu/transforms/xla_rewrite_realloc_to_alloc.cc +++ b/xla/mlir/backends/cpu/transforms/xla_rewrite_realloc_to_alloc.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/backends/cpu/xla-cpu-opt.cc b/xla/mlir/backends/cpu/xla-cpu-opt.cc deleted file mode 100644 index 88a8e5c8663b9..0000000000000 --- a/xla/mlir/backends/cpu/xla-cpu-opt.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project -#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "xla/mlir/backends/cpu/transforms/passes.h" -#include "xla/mlir/xla_cpu/ir/xla_cpu.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo/transforms/passes.h" -#include "xla/mlir_hlo/mhlo/IR/register.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" - -int main(int argc, char **argv) { - mlir::mhlo::registerAllMhloPasses(); - mlir::lmhlo::registerAllLmhloPasses(); - mlir::bufferization::registerBufferizationPasses(); - - mlir::DialectRegistry registry; - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - mlir::func::registerAllExtensions(registry); - - xla::cpu::registerCpuTransformsPasses(); - - return failed(MlirOptMain(argc, argv, "Xla Cpu Pass Driver\n", registry)); -} diff --git a/xla/mlir/backends/gpu/BUILD b/xla/mlir/backends/gpu/BUILD deleted file mode 100644 index de4a77c1d0bd4..0000000000000 --- a/xla/mlir/backends/gpu/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//xla:xla.bzl", "xla_cc_binary") -load("@bazel_skylib//rules:build_test.bzl", "build_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla/mlir:__subpackages__"], - licenses = ["notice"], -) - -build_test( - name = "xla-gpu-opt_build_test", - targets = [ - ":xla-gpu-opt", - ], -) - -xla_cc_binary( - name = "xla-gpu-opt", - srcs = ["xla-gpu-opt.cc"], - deps = [ - "//xla/mlir/backends/gpu/transforms:passes", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/stream_executor", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MlirOptLib", - ], -) diff --git a/xla/mlir/backends/gpu/transforms/BUILD b/xla/mlir/backends/gpu/transforms/BUILD deleted file mode 100644 index 673c188a77d70..0000000000000 --- a/xla/mlir/backends/gpu/transforms/BUILD +++ /dev/null @@ -1,96 +0,0 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], - licenses = ["notice"], -) - -gentbl_cc_library( - name = "passes_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GpuTransforms", - ], - "passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "dataflow_analysis", - srcs = ["dataflow_analysis.cc"], - hdrs = ["dataflow_analysis.h"], - compatible_with = [], - deps = [ - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - ], -) - -cc_library( - name = "passes", - srcs = [ - "add_concurrent_regions.cc", - "add_hlo_trace_annotations.cc", - "gpu_to_gpu_runtime.cc", - "lmhlo_gpu_to_gpu_runtime.cc", - "lmhlo_to_gpu_launch.cc", - "lmhlo_to_gpu_runtime.cc", - "memref_get_global_to_arg.cc", - "outline_cuda_graphs.cc", - "passes.cc", - "stream_assignment.cc", - "uid_generator.h", - ], - hdrs = ["passes.h"], - # Override cc_library()'s internal default value of ["//buildenv/target:gce"].` - # TODO(ezhulenev): Do not depend on NCCL thunks in compiler passes. - compatible_with = [], - deps = [ - ":dataflow_analysis", - ":passes_inc_gen", - "//xla:debug_options_flags", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/ir:rt", - "//xla/mlir/runtime/utils:custom_calls", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:gpu_executable", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:nccl_collective_thunks", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_description", - "//xla/translate/mhlo_to_hlo:location_exporter", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@tsl//tsl/platform:env", - ], -) diff --git a/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc deleted file mode 100644 index f8c4d4f110371..0000000000000 --- a/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "tsl/platform/env.h" - -namespace xla { -namespace gpu { - -namespace { - -#define GEN_PASS_DEF_ADDCONCURRENTREGIONSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT -using mlir::func::FuncOp; -using xla::runtime::CustomCallDeclarations; - -class AddConcurrentRegionsPass - : public impl::AddConcurrentRegionsPassBase { - void runOnOperation() override; -}; - -//===----------------------------------------------------------------------===// - -struct RegionInfo { - Operation* start; - Operation* end; - int size; -}; - -bool IsNoOp(Operation* op) { - return isa(op); -} - -int GetKernelCount(llvm::ArrayRef region) { - int kernel_count = 0; - for (const DataflowAnalysis::Node& node : region) { - Operation* op = node.operation; - if (!IsNoOp(op)) { - kernel_count++; - } - } - return kernel_count; -} - -// We use the size of the inputs to the kernel as a heuristic to avoid -// adding memory bound kernels to the concurrent region. -// The memory bandwidth on A100 is 2MB/us, so a data movement less than 10MB -// is hidden by the kernel launch overhead, which is 5us. -static constexpr int64_t kInputSizeThreshold = 10'000'000; - -bool IsKernelMemoryBound(Operation* op) { - if (auto launch_func = dyn_cast(op)) { - size_t size = 0; - - for (Value operand : launch_func.getOperands()) { - if (auto memref_type = dyn_cast(operand.getType())) { - size += (memref_type.getNumElements() * - memref_type.getElementTypeBitWidth() + - 7) / - 8; - } - } - - if (size > kInputSizeThreshold) { - return true; - } - } - - return false; -} - -// -// Return a list of pairs of operations, in which the first element is the -// first operation in the region, and the second is the last operation in the -// region. -// -// We currently use a greedy algorithm to determine region starting point: -// regions = [] -// region = {first operation} -// for operation in the capture function -// if HasDependency(region, operation) -// regions.add(region) -// region = new region -// else -// region.add(operation) -// -llvm::SmallVector GetRegionInfos( - FuncOp capture_func, DataflowAnalysis& dataflow_analysis) { - llvm::SmallVector region_infos; - DataflowAnalysis::DataflowGraph dataflow_graph = - dataflow_analysis.GetDataflowGraph(capture_func); - - // If verbose logging is enabled print the dataflow graph as a DOT graph. - if (VLOG_IS_ON(100)) { - std::cout << "Dependency graph for graph capture function " - << capture_func.getName().str() << ":\n" - << dataflow_analysis.ToDot(dataflow_graph); - } - - llvm::SmallVector region; - - auto store_region_and_start_new_region = [&]() { - int kernel_count = GetKernelCount(region); - if (kernel_count >= 2) { - RegionInfo region_info = {region.front().operation, - region.back().operation, kernel_count}; - region_infos.push_back(region_info); - } - region.clear(); - }; - - auto append_node_to_region = [&](const DataflowAnalysis::Node& node) { - if (region.empty()) { - if (!IsNoOp(node.operation)) { - region.push_back(node); - } - } else { - region.push_back(node); - } - }; - - for (const DataflowAnalysis::Node& node : dataflow_graph) { - if (isa(node.operation)) { - break; - } - - bool has_dependency = false; - for (const DataflowAnalysis::Node& node_in_region : region) { - std::vector children = node_in_region.children; - if (std::find(children.begin(), children.end(), node.index) != - children.end()) { - has_dependency = true; - break; - } - } - - if (IsKernelMemoryBound(node.operation)) { - store_region_and_start_new_region(); - } else if (has_dependency) { - store_region_and_start_new_region(); - append_node_to_region(node); - } else { - append_node_to_region(node); - } - } - - store_region_and_start_new_region(); - return region_infos; -} - -void InsertConcurrentRegions(FuncOp capture_func, - CustomCallDeclarations& custom_calls, - DataflowAnalysis& dataflow_analysis) { - llvm::SmallVector region_infos = - GetRegionInfos(capture_func, dataflow_analysis); - auto sym_table = custom_calls.sym_table(); - - for (RegionInfo region_info : region_infos) { - Operation* start = region_info.start; - Operation* end = region_info.end; - - ImplicitLocOpBuilder b(start->getLoc(), sym_table.getOp()); - func::FuncOp begin_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.begin", TypeRange(), TypeRange()); - b.setInsertionPoint(start); - auto call = b.create(begin_marker.getName(), TypeRange()); - call->setAttr(b.getStringAttr("size"), - IntegerAttr::get(b.getIntegerType(64), region_info.size)); - - func::FuncOp end_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.end", TypeRange(), TypeRange()); - b.setInsertionPointAfter(end); - b.create(end_marker.getName(), TypeRange()); - } -} - -//===----------------------------------------------------------------------===// - -void AddConcurrentRegionsPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - auto func_ops = llvm::to_vector(module.getOps()); - - for (auto func_op : func_ops) { - // Find the gpu graph capture function. - if (absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.graph.capture")) { - InsertConcurrentRegions(func_op, custom_calls, - getAnalysis()); - } - } -} - -} // namespace - -std::unique_ptr> createAddConcurrentRegionsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc b/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc deleted file mode 100644 index 64e2f0f38bec5..0000000000000 --- a/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/strings/match.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_ADDHLOTRACEANNOTATIONSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using xla::runtime::HloTraceAttr; - -class AddHloTraceAnnotationsPass - : public impl::AddHloTraceAnnotationsPassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -void AddHloTraceAnnotationsPass::runOnOperation() { - MLIRContext* ctx = &getContext(); - - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - - getOperation().walk([&](func::CallOp call) { - // Check if the callee is a custom call. - auto callee = sym_table.lookup(call.getCallee()); - if (!callee->hasAttr("rt.custom_call")) return; - - // Drop multi-op trace for CUDA graphs since they are too large for xprof to - // display. - // TODO(b/275240695): Report the graph content once the Xprof team provides - // an API. - if (absl::StrContains(call.getCalleeAttr().getValue(), - "xla.gpu.graph.launch")) { - auto capture = call->getAttr("capture").cast(); - std::string op_name = "cuda_graph/" + capture.getValue().str(); - auto annotation = HloTraceAttr::get(ctx, std::move(op_name)); - call->setAttr("rt.trace", annotation); - return; - } - - // HLO operation name is encoded in the operation location. - std::string hlo_op = mlir::mhlo::GetDebugNameFromLocation(call->getLoc()); - auto annotation = HloTraceAttr::get(ctx, std::move(hlo_op)); - call->setAttr("rt.trace", annotation); - }); -} - -std::unique_ptr> -createAddHloTraceAnnotationsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc deleted file mode 100644 index 9772ca22e1da0..0000000000000 --- a/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ /dev/null @@ -1,277 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" - -namespace xla { -namespace gpu { - -namespace { - -using namespace mlir; // NOLINT -using mlir::BlockArgument; -using mlir::Operation; -using mlir::func::FuncOp; - -// Represents a slice of the buffer argument to the graph capture function. -struct BufferUse { - BlockArgument arg; - size_t offset; - size_t byte_len; - - // The buffer is only read by the operation. - bool read_only; -}; - -BufferUse GetBufferUse(Value operand, bool read_only = false) { - Operation* defining_op = operand.getDefiningOp(); - if (!defining_op) { - auto block_argument = cast(operand); - auto memref_type = cast(block_argument.getType()); - size_t byte_len = - (memref_type.getNumElements() * memref_type.getElementTypeBitWidth() + - 7) / - 8; - return {block_argument, 0, byte_len, read_only}; - } - - if (isa(defining_op)) { - auto view_op = cast(defining_op); - auto buffer_use = GetBufferUse(view_op.getSource()); - - IntegerAttr offset_attr; - bool is_constant = - matchPattern(view_op.getByteShift(), m_Constant(&offset_attr)); - if (!is_constant) { - // Failed to refine the BufferUse. - return buffer_use; - } - size_t offset = offset_attr.getInt(); - - // Get len. - auto memref_type = cast(view_op.getType()); - // TODO(b/274157088): Handle the case where elements are complex numbers. - if (!memref_type.getElementType().isIntOrFloat()) { - return buffer_use; - } - - size_t byte_len = - (memref_type.getNumElements() * memref_type.getElementTypeBitWidth() + - 7) / - 8; - - return {buffer_use.arg, buffer_use.offset + offset, byte_len, read_only}; - } - - if (auto cast = dyn_cast(defining_op)) { - return GetBufferUse(cast.getSource(), read_only); - } - - return {}; -} - -llvm::SmallVector GetBufferUses(Operation& operation) { - llvm::SmallVector operand_buffer_uses; - if (auto launch_func = dyn_cast(operation)) { - auto kernel_func = - SymbolTable::lookupNearestSymbolFrom( - &operation, launch_func.getKernel()); - auto kernel_operands = launch_func.getKernelOperands(); - for (auto it : llvm::enumerate(kernel_operands)) { - BufferUse buffer_use = GetBufferUse( - it.value(), - /*read_only=*/!kernel_func.getArgAttrOfType( - it.index(), "lmhlo.written")); - operand_buffer_uses.push_back(buffer_use); - } - } else if (auto gemm = dyn_cast(operation)) { - BufferUse buffer_use_0 = GetBufferUse(gemm.getA(), /*read_only=*/true); - BufferUse buffer_use_1 = GetBufferUse(gemm.getB(), /*read_only=*/true); - BufferUse buffer_use_2 = GetBufferUse(gemm.getC(), /*read_only=*/false); - operand_buffer_uses.push_back(buffer_use_0); - operand_buffer_uses.push_back(buffer_use_1); - operand_buffer_uses.push_back(buffer_use_2); - } else if (auto memcpy = dyn_cast(operation)) { - BufferUse src_buffer = GetBufferUse(memcpy.getSrc(), /*read_only=*/true); - BufferUse dst_buffer = GetBufferUse(memcpy.getDst(), /*read_only=*/false); - operand_buffer_uses.push_back(src_buffer); - operand_buffer_uses.push_back(dst_buffer); - } - - return operand_buffer_uses; -} - -// Arguments to the graph capture function may have the "lmhlo.constant_name" -// attribute, which indicates that the passed-in buffer is constant. -bool IsConstant(BlockArgument block_argument) { - // Check if the input buffer is marked as constant. - Region* parent_region = block_argument.getParentRegion(); - auto parent_func = parent_region->getParentOfType(); - unsigned parent_func_arg_index = block_argument.getArgNumber(); - auto cst = parent_func.getArgAttrOfType(parent_func_arg_index, - "lmhlo.constant_name"); - return cst != nullptr; -} - -// Check if two buffer_uses overlap. -bool HasDependency(BufferUse buffer_use_a, BufferUse buffer_use_b) { - if (buffer_use_a.arg.getArgNumber() != buffer_use_b.arg.getArgNumber()) - return false; - if (IsConstant(buffer_use_a.arg) || IsConstant(buffer_use_b.arg)) - return false; - if (buffer_use_a.read_only && buffer_use_b.read_only) return false; - - // Check if two buffer slices overlap. - size_t start1 = buffer_use_a.offset; - size_t end1 = buffer_use_a.offset + buffer_use_a.byte_len; - size_t start2 = buffer_use_b.offset; - size_t end2 = buffer_use_b.offset + buffer_use_b.byte_len; - if (std::max(start1, start2) < std::min(end1, end2)) { - return true; - } - return false; -} - -bool HasDependency(llvm::ArrayRef buffer_uses_a, - llvm::ArrayRef buffer_uses_b) { - for (auto buffer_use_a : buffer_uses_a) { - for (auto buffer_use_b : buffer_uses_b) { - if (HasDependency(buffer_use_a, buffer_use_b)) return true; - } - } - return false; -} - -// Remove edges that are redundant for determining the execution order of -// kernels. We use the following algorithm to compute the transitive reduction: -// -// For source node in graph: -// For each edge (source -> target) -// longest_distance = the length of the longest path from source to target -// if (longest_distance > 1): -// remove (source -> target) -// -void TransitiveReduction(DataflowAnalysis::DataflowGraph& graph) { - std::vector> parents(graph.size(), std::vector()); - for (const DataflowAnalysis::Node& node : graph) { - for (size_t child_index : node.children) { - parents[child_index].push_back(node.index); - } - } - - std::vector longest_distance(graph.size()); - for (DataflowAnalysis::Node& source : graph) { - if (source.children.empty()) { - continue; - } - - std::fill(longest_distance.begin(), longest_distance.end(), 0); - size_t farthest_child = source.children.back(); - for (size_t target = source.index + 1; target <= farthest_child; target++) { - for (size_t mid : parents[target]) { - // If the mid node is before source in the topological order, no path - // source -> mid -> target can exits and we can skip it. - if (mid >= source.index) { - // If source -> mid -> target is longer than the longest path so far - // from source -> target, update the longest distance. - int candidate_longest_distance = longest_distance[mid] + 1; - if (candidate_longest_distance > longest_distance[target]) { - longest_distance[target] = candidate_longest_distance; - } - } - } - } - - source.children.erase( - std::remove_if( - source.children.begin(), source.children.end(), - [&](size_t target) { return longest_distance[target] > 1; }), - source.children.end()); - } -} - -} // namespace - -DataflowAnalysis::DataflowGraph DataflowAnalysis::GetDataflowGraph( - FuncOp graph_capture_function) { - std::vector graph; - for (auto [index, op] : llvm::enumerate(graph_capture_function.getOps())) { - graph.push_back(Node{&op, index, {}}); - } - - // A vector that stores the buffer used by each operation in the graph. The - // i-th operation's buffer uses are stored as the vector buffer_uses[i]; - std::vector> buffer_uses; - for (Operation& operation : graph_capture_function.getOps()) { - buffer_uses.push_back(GetBufferUses(operation)); - } - - for (int i = 0; i < graph.size(); ++i) { - Node& node_i = graph[i]; - llvm::ArrayRef buffer_uses_i = buffer_uses[i]; - for (int j = i + 1; j < graph.size(); ++j) { - llvm::ArrayRef buffer_uses_j = buffer_uses[j]; - if (HasDependency(buffer_uses_i, buffer_uses_j)) { - node_i.children.push_back(j); - } - } - } - - TransitiveReduction(graph); - return graph; -} - -std::string DataflowAnalysis::ToDot(const DataflowGraph& graph) { - std::string pad; - std::string res; - auto indent = [&] { pad.append(2, ' '); }; - auto outdent = [&] { pad.resize(pad.size() - 2); }; - auto addline = [&](auto&&... args) { - absl::StrAppend(&res, pad, args..., "\n"); - }; - auto get_name = [](const Node& node) -> std::string { - return absl::StrCat("\"", node.operation->getName().getStringRef().str(), - "_", node.index, "\""); - }; - - addline("digraph {"); - indent(); - for (const Node& node : graph) { - for (size_t child_index : node.children) { - Node child = graph[child_index]; - addline(get_name(node), " -> ", get_name(child)); - } - } - outdent(); - addline("}"); - return res; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/dataflow_analysis.h b/xla/mlir/backends/gpu/transforms/dataflow_analysis.h deleted file mode 100644 index 033037af11310..0000000000000 --- a/xla/mlir/backends/gpu/transforms/dataflow_analysis.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project - -namespace xla { -namespace gpu { - -class DataflowAnalysis { - public: - explicit DataflowAnalysis(mlir::Operation* op) {} - - struct Node { - mlir::Operation* operation; - size_t index; - std::vector children; - }; - - using DataflowGraph = std::vector; - - // This function creates a dataflow graph that represent data dependencies in - // the graph capture function. The analysis relies on some properties of the - // IR in XLA: - // (1) Buffer arguments do not alias. It is guaranteed that two buffer - // arguments to the graph capture function do not overlap. - // (2) XLA operations do not have any side effects beyond writing to its - // buffer arguments. So it is safe to reorder operations if they do not - // have write-conflicts. - // (3) We have information about read-only and read-write buffer arguments. - DataflowGraph GetDataflowGraph(mlir::func::FuncOp graph_capture_function); - - std::string ToDot(const DataflowGraph& graph); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_DATAFLOW_ANALYSIS_H_ diff --git a/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc b/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc deleted file mode 100644 index a016b4802100e..0000000000000 --- a/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc +++ /dev/null @@ -1,241 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/utils/custom_calls.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTGPUTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::GPUModuleOp; -using mlir::gpu::LaunchFuncOp; -using mlir::gpu::MemcpyOp; -using mlir::gpu::MemsetOp; - -using xla::runtime::CustomCallDeclarations; - -class ConvertGpuToGpuRuntimePass - : public impl::ConvertGpuToGpuRuntimePassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class GpuModuleOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GPUModuleOp op, - PatternRewriter& rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// - -class MemcpyOpLowering : public OpRewritePattern { - public: - MemcpyOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // We use a heuristic to identify the direction of the memcpy operation, if - // the operand was allocated by alloca op or is a global memref, then it must - // be a memref on the host. - static bool IsHostMemRef(Value value) { - auto* op = value.getDefiningOp(); - return llvm::isa_and_nonnull(op); - } - - // Identify the direction of the memcpy operation. - static StringRef Target(MemcpyOp op) { - if (IsHostMemRef(op.getDst())) return "xla.gpu.memcpy.d2h"; - if (IsHostMemRef(op.getSrc())) return "xla.gpu.memcpy.h2d"; - return "xla.gpu.memcpy.d2d"; - } - - LogicalResult matchAndRewrite(MemcpyOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - auto stream = op->getAttrOfType("stream"); - - // Create a function launch call operation. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - - if (stream) { - call->setAttr(b.getStringAttr("stream"), stream); - } else { - call->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(0)); - } - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class MemsetOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.memset"; - - public: - MemsetOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(MemsetOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Create a function launch call operation. - rewriter.replaceOpWithNewOp(op, callee.getName(), TypeRange(), - op.getOperands()); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class LaunchFuncOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.func.launch"; - - public: - LaunchFuncOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(LaunchFuncOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Cast grid and block dimensions to i32 before passing to the custom call. - auto cast = [&](mlir::Value value) { - return b.create(b.getI32Type(), value); - }; - - // Prepare arguments for the custom call. - llvm::SmallVector args = { - cast(op.getGridSizeX()), cast(op.getGridSizeY()), - cast(op.getGridSizeZ()), cast(op.getBlockSizeX()), - cast(op.getBlockSizeY()), cast(op.getBlockSizeZ())}; - - // Shared memory size is optional for the `gpu.launch` but mandatory for the - // Xla runtime kernel launch custom call. - if (op.getDynamicSharedMemorySize()) { - args.insert(args.begin(), op.getDynamicSharedMemorySize()); - } else { - auto zero = b.create(0, b.getI32Type()); - args.insert(args.begin(), zero); - } - - // Add kernel arguments. - llvm::copy(op.getKernelOperands(), std::back_inserter(args)); - - // Get or create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, "xla.gpu.func.launch", TypeRange(ValueRange(args)), TypeRange()); - - // Create a function launch call operation. - auto call = b.create(callee.getName(), TypeRange(), args); - call->setAttr(b.getStringAttr("kernel"), op.getKernelName()); - - // Assign a unique id to this instance of a kernel launch operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Set assigned stream for the kernel launch. - auto stream = op->getAttrOfType("stream"); - if (stream) { - call->setAttr(b.getStringAttr("stream"), stream); - } else { - call->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(0)); - } - - // Erase the original gpu launch operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -void ConvertGpuToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Each kernel launch operation gets a unique id. - UidGenerator kernel_uid; - - // Convert gpu operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - patterns.insert(ctx, kernel_uid, custom_calls); - patterns.insert(ctx, custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr> -createConvertGpuToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc deleted file mode 100644 index 3a74ed46de424..0000000000000 --- a/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc +++ /dev/null @@ -1,1043 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/stream_executor/blas.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOGPUTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::lmhlo_gpu::CholeskyOp; -using mlir::lmhlo_gpu::ConvBackwardFilterOp; -using mlir::lmhlo_gpu::ConvBackwardInputOp; -using mlir::lmhlo_gpu::ConvForwardFusedOp; -using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; -using mlir::lmhlo_gpu::ConvForwardGraphOp; -using mlir::lmhlo_gpu::ConvForwardOp; -using mlir::lmhlo_gpu::CublasLtMatmulF8Op; -using mlir::lmhlo_gpu::CublasLtMatmulOp; -using mlir::lmhlo_gpu::CudnnConvReorderFilterAndBiasOp; -using mlir::lmhlo_gpu::CudnnConvReorderFilterOp; -using mlir::lmhlo_gpu::CudnnNormOp; -using mlir::lmhlo_gpu::GEMMOp; -using mlir::lmhlo_gpu::RadixSortOp; - -using xla::runtime::CustomCallDeclarations; - -class ConvertLmhloGpuToGpuRuntimePass - : public impl::ConvertLmhloGpuToGpuRuntimePassBase< - ConvertLmhloGpuToGpuRuntimePass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class GemmOpLowering : public OpRewritePattern { - static constexpr const char kCustomCallTarget[] = "xla.gpu.gemm"; - - public: - GemmOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(GEMMOp op, - PatternRewriter& rewriter) const override { - { - // Set requires_blas attribute to true. The runtime pass will add cuBLAS - // initialization custom call to the entry function if the attribute is - // set to true. - auto module = op.getOperation()->getParentOfType(); - ImplicitLocOpBuilder b(module.getLoc(), rewriter); - module->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName), - BoolAttr::get(b.getContext(), true)); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert Gemm to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a gemm operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - auto algorithm_attr = - op.getAlgorithm() - ? op.getAlgorithmAttr() - : b.getI64IntegerAttr(stream_executor::blas::kDefaultGemmAlgo); - call->setAttr(b.getStringAttr("algorithm"), algorithm_attr); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original gemm operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class CublasLtMatmulOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.cublas.lt.matmul"; - - public: - CublasLtMatmulOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CublasLtMatmulOp op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - std::string matmul = kCustomCallTarget; - - switch (op.getEpilogue()) { - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: - if (op.getNumOperands() != 4) { - return op.emitOpError("unexpected number of operands for matmul"); - } - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: - if (op.getNumOperands() != 5) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".bias"; - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: - if (op.getNumOperands() != 5) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".aux"; - break; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: - if (op.getNumOperands() != 6) { - return op.emitOpError("unexpected number of operands for matmul"); - } - matmul += ".bias.aux"; - break; - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, matmul, op); - - // Convert matmul to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a matmul operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr()); - - // TODO(ezhulenev): Today we can't pass an array of enum attributes to the - // custom call. Also we do not have a corresponding precision enum on the - // SE/XLA side, so we encode it as an i32 array (tensor). - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original matmul operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -// As above for FP8 Custom Calls. -class CublasLtMatmulF8OpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = - "xla.gpu.cublas.lt.matmul.f8"; - - public: - CublasLtMatmulF8OpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CublasLtMatmulF8Op op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert matmul to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a matmul operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); - call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); - call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); - call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr()); - - // TODO(ezhulenev): Today we can't pass an array of enum attributes to the - // custom call. Also we do not have a corresponding precision enum on the - // SE/XLA side, so we encode it as an i32 array (tensor). - if (auto precisions = op.getPrecisionConfig()) { - llvm::SmallVector values; - for (auto precision : *precisions) { - auto value = precision.cast().getValue(); - values.push_back(static_cast(value)); - } - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); - } else { - call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); - } - - // Erase the original matmul operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -template -class ConvOpLowering : public OpRewritePattern { - private: - static StringRef CustomCallTarget(ConvForwardOp) { - return "xla.gpu.conv.forward"; - } - static StringRef CustomCallTarget(ConvForwardFusedOp) { - return "xla.gpu.conv.forward.fused"; - } - static StringRef CustomCallTarget(ConvForwardFusedSideInputOp) { - return "xla.gpu.conv.forward.fused.side_input"; - } - static StringRef CustomCallTarget(ConvBackwardFilterOp) { - return "xla.gpu.conv.backward.filter"; - } - static StringRef CustomCallTarget(ConvBackwardInputOp) { - return "xla.gpu.conv.backward.input"; - } - static StringRef CustomCallTarget(ConvForwardGraphOp) { - return "xla.gpu.conv.forward.graph"; - } - - public: - explicit ConvOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(Conv op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, CustomCallTarget(op), op); - - // Convert Conv to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - call->setAttr(b.getStringAttr(name), attr); - }; - - auto set_xi64 = [&](StringRef name, - std::optional attr) { - SmallVector values; - if (attr.has_value()) - values = llvm::to_vector(attr->getValues()); - set_attr(name, b.getI64TensorAttr(values)); - }; - - // Convert `BoolElementsAttr` to i64 before passing to the runtime. - // TODO(ezhulenev): Allow passing boolean tensors to the XLA custom calls. - auto set_xi1 = [&](StringRef name, std::optional attr) { - SmallVector values; - if (attr.has_value()) - values.assign(attr->getValues().begin(), - attr->getValues().end()); - set_attr(name, b.getI64TensorAttr(values)); - }; - - // Assign a unique id to this instance of a conv operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy dimension number attributes. - call->setAttr(b.getStringAttr("conv_dims"), op.getDimensionNumbers()); - - // Copy convolution window attributes. - set_xi1("window_reversal", op.getWindowReversal()); - set_xi64("window_strides", op.getWindowStrides()); - set_xi64("lhs_dilation", op.getLhsDilation()); - set_xi64("rhs_dilation", op.getRhsDilation()); - set_xi64("padding", op.getPadding()); - - // Copy backend config. - call->setAttr(b.getStringAttr("backend_config"), op.getBackendConfig()); - - // Copy remaining attributes. - set_attr("feature_group_count", op.getFeatureGroupCountAttr()); - set_attr("result_scale", op.getResultScaleAttr()); - - // Copy attributes specific for fused convolutions. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("activation_mode"), - fused.getActivationModeAttr()); - set_attr("leakyrelu_alpha", fused.getLeakyreluAlphaAttr()); - } - - // Copy attributes specific for fused convolutions with side input. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("activation_mode"), - fused.getActivationModeAttr()); - set_attr("side_input_scale", fused.getSideInputScaleAttr()); - } - - // Copy attributes specific for graph convolutions. - if (auto fused = dyn_cast(op.getOperation())) { - call->setAttr(b.getStringAttr("n_aux_outputs"), - fused.getNAuxOutputsAttr()); - call->setAttr(b.getStringAttr("serialized_graph"), - fused.getSerializedGraphAttr()); - } - - // Erase the original conv operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class ConvForwardOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardFusedOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvBackwardFilterOpLowering - : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvBackwardInputOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardFusedSideInputOpLowering - : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -class ConvForwardGraphOpLowering : public ConvOpLowering { - public: - using ConvOpLowering::ConvOpLowering; -}; - -//===----------------------------------------------------------------------===// - -template -class CudnnConvReorderOpLowering : public OpRewritePattern { - private: - static StringRef CustomCallTarget(CudnnConvReorderFilterOp) { - return "xla.gpu.conv.reorder.filter"; - } - static StringRef CustomCallTarget(CudnnConvReorderFilterAndBiasOp) { - return "xla.gpu.conv.reorder.filter_and_bias"; - } - - public: - explicit CudnnConvReorderOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(ConvReorder op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, CustomCallTarget(op), op); - - auto filterDims = rewriter.getDenseI64ArrayAttr( - llvm::to_vector(op.getFilterDims().template getValues())); - - // Replace ConvOp with an equivalent custom call. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - call->setAttr(b.getStringAttr("filter_dims"), filterDims); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class CudnnConvReorderFilterOpLowering - : public CudnnConvReorderOpLowering { - public: - using CudnnConvReorderOpLowering::CudnnConvReorderOpLowering; -}; - -class CudnnConvReorderFilterAndBiasOpLowering - : public CudnnConvReorderOpLowering { - public: - using CudnnConvReorderOpLowering::CudnnConvReorderOpLowering; -}; - -//===----------------------------------------------------------------------===// - -class CholeskyOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.cholesky"; - - public: - explicit CholeskyOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CholeskyOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert Cholesky to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - const auto& dims = - op.getInput().getType().cast().getShape(); - if (dims.size() < 2) - return op.emitOpError() << "Input's dimension count (" << dims.size() - << ") must be 2 or greater."; - int64_t n = dims[dims.size() - 1]; - int64_t batch_size = - std::accumulate(dims.begin(), dims.end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("batch_size"), - b.getI64IntegerAttr(batch_size)); - call->setAttr(b.getStringAttr("n"), b.getI64IntegerAttr(n)); - call->setAttr(b.getStringAttr("is_lower"), op.getIsLowerAttr()); - - // Erase the original Cholesky operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class NormOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.norm"; - - public: - NormOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CudnnNormOp op, - PatternRewriter& rewriter) const override { - // Get or create a Custom Call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - // Convert norm to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a norm operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("norm_algorithm_config"), - op.getAlgorithmConfigAttr()); - call->setAttr(b.getStringAttr("epsilon"), op.getEpsilonAttr()); - - mlir::ArrayAttr array = op.getOperandLayouts(); - SmallVector values; - for (auto array_elem : array) { - mlir::IntegerAttr attr = array_elem.dyn_cast(); - values.push_back(attr.getInt()); - } - call->setAttr(b.getStringAttr("operand_layouts"), - b.getI64TensorAttr(values)); - - // Erase the original norm operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -using mlir::lmhlo_gpu::fusedMHAOp; - -template -class FusedAttentionForwardLowering - : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.fused.attention."; - - public: - explicit FusedAttentionForwardLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FusedDotAttentionForward op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - std::string fused_attention = kCustomCallTarget; - auto num_operands = op.getNumOperands(); - switch (op.getFusedMhaDag()) { - case mlir::lmhlo_gpu::FusedMhaDagSignature::Default: - if (num_operands == 5) { - fused_attention += "bmm.bmm.inference"; - } else if (num_operands == 6) { - fused_attention += "bmm.bmm.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - BMMBMM"); - } - break; - case mlir::lmhlo_gpu::FusedMhaDagSignature::Softmax: - if (num_operands == 5) { - fused_attention += "softmax.inference"; - } else if (num_operands == 6) { - fused_attention += "softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Softmax_BMM"); - } - break; - case mlir::lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout: - if (num_operands == 5) { - fused_attention += "softmax.dropout.inference"; - } else if (num_operands == 6) { - fused_attention += "softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax: - if (num_operands == 7) { - fused_attention += "scale.bias.mask.softmax.inference"; - } else if (num_operands == 8) { - fused_attention += "scale.bias.mask.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Bias_Mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout: - if (num_operands == 7) { - fused_attention += "scale.bias.mask.softmax.dropout.inference"; - } else if (num_operands == 8) { - fused_attention += "scale.bias.mask.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_Bias_Mask_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax: - if (num_operands == 6) { - fused_attention += "scale.mask.softmax.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.mask.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout: - if (num_operands == 6) { - fused_attention += "scale.mask.softmax.dropout.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.mask.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_mask_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax: - if (num_operands == 6) { - fused_attention += "scale.bias.softmax.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.bias.softmax.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_bias_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout: - if (num_operands == 6) { - fused_attention += "scale.bias.softmax.dropout.inference"; - } else if (num_operands == 7) { - fused_attention += "scale.bias.softmax.dropout.forward"; - } else { - return op.emitOpError( - "unexpected number of operands for fused dot attention - " - "BMM_bias_Softmax_Dropout_BMM"); - } - break; - - default: - return op.emitOpError("Undefined fused dot attention DAG signature"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, fused_attention, op); - - // Convert fused_attention to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a fused_attention operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - if (attr) { - call->setAttr(b.getStringAttr(name), attr); - } - }; - - set_attr("fmha_scale", op.getFmhaScaleAttr()); - set_attr("dropout_rate", op.getDropoutRateAttr()); - set_attr("seed", op.getSeedAttr()); - set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); - set_attr("is_causal_mask", op.getIsCausalMaskAttr()); - set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); - set_attr("algorithm_config", op.getAlgorithmConfigAttr()); - set_attr("bmm1_dot_dimension_numbers", op.getBmm1DotDimensionNumbers()); - set_attr("bmm2_dot_dimension_numbers", op.getBmm2DotDimensionNumbers()); - - auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) { - int rank = array.size(); - SmallVector values; - for (int i = 0; i < rank; i++) { - mlir::IntegerAttr attr = array[i].dyn_cast(); - values.push_back(attr.getInt()); - } - set_attr(name, b.getI64TensorAttr(values)); - }; - - set_xi64("intermediate_tensor_dimensions", - op.getIntermediateTensorDimensions()); - set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout()); - - // Erase the original fused dot attention operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class FusedAttentionForwardOpLowering - : public FusedAttentionForwardLowering { - public: - using FusedAttentionForwardLowering::FusedAttentionForwardLowering; -}; - -using mlir::lmhlo_gpu::fusedMHABackwardOp; - -template -class FusedAttentionBackwardLowering - : public OpRewritePattern { - private: - static constexpr const char kFusedAttentionCustomCallTarget[] = - "xla.gpu.fused.attention.backward."; - static constexpr const char kFlashAttentionCustomCallTarget[] = - "xla.gpu.flash.attention.backward."; - - public: - explicit FusedAttentionBackwardLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FusedDotAttentionBackward op, - PatternRewriter& rewriter) const override { - // Get the custom call target. - bool is_flash_attention = op.getIsFlashAttention(); - std::string fused_attention = is_flash_attention - ? kFlashAttentionCustomCallTarget - : kFusedAttentionCustomCallTarget; - auto num_operands = op.getNumOperands(); - switch (op.getFusedMhaDag()) { - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: - if (is_flash_attention) { - if (num_operands == 12) { - fused_attention += "scale.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Softmax_BMM"); - } - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmax: - if (is_flash_attention) { - if (num_operands == 13) { - fused_attention += "scale.bias.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for flash attention backward - " - "BMM_Bias_Softmax_BMM"); - } - break; - } - if (num_operands == 10) { - fused_attention += "scale.softmax"; - } else if (num_operands == 11) { - fused_attention += "scale.dbias.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout: - if (num_operands == 10) { - fused_attention += "scale.softmax.dropout"; - } else if (num_operands == 11) { - fused_attention += "scale.dbias.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Softmax_Dropout_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax: - if (num_operands == 11) { - fused_attention += "scale.mask.softmax"; - } else if (num_operands == 12) { - fused_attention += "scale.dbias.mask.softmax"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Mask_Softmax_BMM"); - } - break; - - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout: - if (num_operands == 11) { - fused_attention += "scale.mask.softmax.dropout"; - } else if (num_operands == 12) { - fused_attention += "scale.dbias.mask.softmax.dropout"; - } else { - return op.emitOpError( - "unexpected number of operands for fused attention backward - " - "BMM_Bias_Mask_Softmax_Dropout_BMM"); - } - break; - - default: - return op.emitOpError("Undefined fused attention DAG signature"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, fused_attention, op); - - // Convert fused_attention to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Assign a unique id to this instance of a fused_attention operation. - call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); - - // Helper functins to copy attributes from the conv op to the custom call. - auto set_attr = [&](StringRef name, Attribute attr) { - if (attr) { - call->setAttr(b.getStringAttr(name), attr); - } - }; - - set_attr("fmha_scale", op.getFmhaScaleAttr()); - set_attr("dropout_rate", op.getDropoutRateAttr()); - set_attr("seed", op.getSeedAttr()); - set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); - set_attr("is_causal_mask", op.getIsCausalMaskAttr()); - set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); - set_attr("algorithm_config", op.getAlgorithmConfigAttr()); - set_attr("bmm1_grad_gemm1_dot_dimension_numbers", - op.getBmm1GradGemm1DotDimensionNumbers()); - set_attr("bmm1_grad_gemm2_dot_dimension_numbers", - op.getBmm1GradGemm2DotDimensionNumbers()); - set_attr("bmm2_grad_gemm1_dot_dimension_numbers", - op.getBmm2GradGemm1DotDimensionNumbers()); - set_attr("bmm2_grad_gemm2_dot_dimension_numbers", - op.getBmm2GradGemm2DotDimensionNumbers()); - - auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) { - int rank = array.size(); - SmallVector values; - for (int i = 0; i < rank; i++) { - mlir::IntegerAttr attr = array[i].dyn_cast(); - values.push_back(attr.getInt()); - } - set_attr(name, b.getI64TensorAttr(values)); - }; - - set_xi64("intermediate_tensor_dimensions", - op.getIntermediateTensorDimensions()); - set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout()); - - // Erase the original fused dot attention operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -class FusedAttentionBackwardOpLowering - : public FusedAttentionBackwardLowering { - public: - using FusedAttentionBackwardLowering::FusedAttentionBackwardLowering; -}; - -class RadixSortOpLowering : public OpRewritePattern { - private: - static constexpr const char kSortKeysTarget[] = "xla.gpu.radix_sort_keys"; - static constexpr const char kSortPairsTarget[] = "xla.gpu.radix_sort_pairs"; - - public: - explicit RadixSortOpLowering(MLIRContext* ctx, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(RadixSortOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, op.getOperands().size() == 3 ? kSortKeysTarget : kSortPairsTarget, - op); - - // Convert radix sort to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - call->setAttr(b.getStringAttr("descending"), op.getDescendingAttr()); - - // Erase the original operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -void ConvertLmhloGpuToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Convert lmhlo_gpu operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - - // Each unique Gemm/Matmul operation in the module will get assigned a uid. - UidGenerator matmul_uid; - patterns.insert(ctx, matmul_uid, custom_calls); - - // Each unique Conv operation in the module will get assigned a uid. - UidGenerator conv_uid; - patterns - .insert( - ctx, conv_uid, custom_calls); - - // Patterns for every other Gpu operation. - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - patterns.insert(ctx, custom_calls); - - // Each unique Norm operation in the module will get assigned a uid. - UidGenerator norm_uid; - patterns.insert(ctx, norm_uid, custom_calls); - - // Each unique fused_attention operation in the module will get assigned a - // uid. - UidGenerator fused_attention_uid; - patterns.insert(ctx, fused_attention_uid, - custom_calls); - - // Each unique fused_attention_backward operation in the module will get - // assigned a uid. - UidGenerator fused_attention_backward_uid; - patterns.insert( - ctx, fused_attention_backward_uid, custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr> -createConvertLmhloGpuToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc b/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc deleted file mode 100644 index d0df74a362ad0..0000000000000 --- a/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc +++ /dev/null @@ -1,419 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/gpu/conditional_thunk.h" -#include "xla/service/gpu/copy_thunk.h" -#include "xla/service/gpu/kernel_thunk.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/memset_thunk.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/while_thunk.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOTOGPULAUNCHPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::GPUDialect; -using mlir::gpu::GPUFuncOp; -using mlir::gpu::GPUModuleOp; -using mlir::gpu::KernelDim3; -using mlir::gpu::LaunchFuncOp; -using mlir::gpu::MemcpyOp; -using mlir::gpu::MemsetOp; -using mlir::gpu::ReturnOp; - -class ConvertLmhloToGpuLaunchPass - : public impl::ConvertLmhloToGpuLaunchPassBase< - ConvertLmhloToGpuLaunchPass> { - public: - explicit ConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence) - : thunk_sequence_(thunk_sequence) {} - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - private: - ThunkSequence* thunk_sequence_; -}; - -// XLA some times (ab)uses custom calls to represent operations for which we do -// not want to define a separate `HloOpcode`. These operations emitted as device -// kernels (similar to fusions), and we detect such custom calls by name, and -// handle them similar to how we handle fusions. -static std::array kCustomCallIntrinsics = { - "SliceToDynamic", "PadToStatic"}; - -//===-----------------------------------------------------------------------===/ - -static Value MakeBitPatternConstant(OpBuilder& b, Location loc, Type type, - uint32_t bit_pattern) { - mlir::MLIRContext* ctx = type.getContext(); - - // For zero bit pattern always memset with a zero value of the same type. - if (bit_pattern == 0) { - // Because `arith` dialect doesn't support unsigned constants, we have to - // create signless constant first, and then use `rt.unsigned_cast` operation - // to make it unsigned. When lowering to LLVM and function calls, this - // casting operation will be erased. - if (type.isUnsignedInteger()) { - auto signless = IntegerType::get(ctx, type.getIntOrFloatBitWidth()); - auto zero = b.create(loc, b.getZeroAttr(signless)); - return b.create(loc, type, zero.getResult()); - } - - return b.create(loc, b.getZeroAttr(type)); - } - - // In XLA a 1-byte bit pattern copied to fill a 32-byte word when - // `Memset32BitValueThunk` is constructed, so to get back an `i1` constant we - // only need to check if any bit is set to `1`. - if (type.isInteger(1)) { - return b.create(loc, b.getBoolAttr(bit_pattern)); - } - - // Xla IR emitter copies integers of smaller width to fill 32 bits, so we can - // safely truncate the bit pattern. For integers larger than 32 bits we can - // construct a wider integer, as Xla guarantees that all 32-bit words are - // equal. - if (auto integer = type.dyn_cast()) { - llvm::APInt i32(32, bit_pattern); - - assert(integer.getWidth() <= 64 && "integer value must be <= 64 bits"); - llvm::APInt value = integer.getWidth() <= 32 ? i32.trunc(integer.getWidth()) - : i32.concat(i32); - - // See unsigned-to-signed cast documentation above. - if (integer.isUnsigned()) { - auto signless = IntegerType::get(ctx, integer.getWidth()); - auto cst = - b.create(loc, b.getIntegerAttr(signless, value)); - return b.create(loc, type, cst.getResult()); - } - - return b.create(loc, b.getIntegerAttr(integer, value)); - } - - // Similar to integer type we can safely truncate or concat bit pattern. - if (auto fp = type.dyn_cast()) { - llvm::APInt i32(32, bit_pattern); - - assert(fp.getWidth() <= 64 && "floating point value must be <= 64 bits"); - llvm::APInt ivalue = - fp.getWidth() <= 32 ? i32.trunc(fp.getWidth()) : i32.concat(i32); - - llvm::APFloat fvalue = [&]() -> llvm::APFloat { - if (fp.isBF16()) return {llvm::APFloat::BFloat(), ivalue}; - if (fp.isF16()) return {llvm::APFloat::IEEEhalf(), ivalue}; - if (fp.isF32()) return {llvm::APFloat::IEEEsingle(), ivalue}; - if (fp.isF64()) return {llvm::APFloat::IEEEdouble(), ivalue}; - - assert(false && "unsupported floating point type"); - return llvm::APFloat::getZero(llvm::APFloat::IEEEsingle()); - }(); - - return b.create(loc, fvalue, fp); - } - - // Return a constant index value, that will safely fail verification (there is - // no memset operation for `index` type), so that we do not accidentally crash - // the binary in optimized builds. - assert(false && "unsupported memset type"); - return b.create(loc, 0); -} - -static void ExtractThunksForOp(Operation* op, ThunkSequence& thunk_sequence, - ThunkSequence* thunks_for_op) { - for (std::unique_ptr& thunk : thunk_sequence) { - if (thunk == nullptr) { - // This thunk has already been std::move()'ed out of the ThunkSequence - // (see below). Do nothing. - } else if (thunk->kind() == Thunk::kWhile) { - // Search for thunks for the op in while loop. - auto* while_thunk = static_cast(thunk.get()); - ExtractThunksForOp(op, while_thunk->condition_thunk_sequence()->thunks(), - thunks_for_op); - ExtractThunksForOp(op, while_thunk->body_thunk_sequence()->thunks(), - thunks_for_op); - } else if (thunk->kind() == Thunk::kConditional) { - // Search for thunks for the op in conditional branches. - auto* cond_thunk = static_cast(thunk.get()); - for (const std::unique_ptr& branch_thunks : - cond_thunk->branch_thunks()) { - ExtractThunksForOp(op, branch_thunks->thunks(), thunks_for_op); - } - } else if (thunk->op() == op) { - // Found a thunk for the op. - thunks_for_op->push_back(std::move(thunk)); - } else { - // Thunk is not relevant to the op. Do nothing. - } - } -} - -// Returns the data to rewrite op without changing the IR. -static absl::StatusOr> Match( - Operation* op, ThunkSequence& thunk_sequence) { - auto thunks_for_op = std::make_unique(); - ExtractThunksForOp(op, thunk_sequence, thunks_for_op.get()); - - // Check if we know how to lower a Thunk to Gpu operation(s). - auto is_supported = [](const std::unique_ptr& thunk) -> bool { - Thunk::Kind kinds[] = {Thunk::kKernel, Thunk::kCopy, - Thunk::kMemset32BitValue, Thunk::kMemzero, - Thunk::kSequential}; - return llvm::any_of( - kinds, [&](Thunk::Kind kind) { return thunk->kind() == kind; }); - }; - - if (!llvm::all_of(*thunks_for_op, is_supported)) { - return absl::InternalError("Unsupported Thunk kind"); - } - - return std::move(thunks_for_op); -} - -static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, - GPUModuleOp gpu_module, Thunk* thunk); - -// Replaces op with gpu.launch_func, gpu.memcpy, gpu.memset ops. -static void Rewrite(Operation* op, OpBuilder& b, SymbolTable& symbol_table, - ThunkSequence* thunks) { - OpBuilder::InsertionGuard guard(b); - auto loc = op->getLoc(); - - b.setInsertionPoint(op->getParentOfType()); - auto gpu_module = b.create(loc, "gpu_module"); - symbol_table.insert(gpu_module); - - for (const std::unique_ptr& thunk : *thunks) { - LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); - } - - op->erase(); -} - -static void LowerKernelThunkToGpuOp( - Operation* op, OpBuilder& b, GPUModuleOp gpu_module, - const KernelThunk& thunk, const SmallVector& kernel_args, - const SmallVector& kernel_args_written) { - mlir::Location loc = op->getLoc(); - b.setInsertionPointToStart(gpu_module.getBody()); - - auto func_type = - b.getType(TypeRange(ValueRange(kernel_args)), TypeRange()); - - gpu::GPUFuncOp kernel_func = - b.create(loc, thunk.kernel_name(), func_type); - kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); - - for (int i = 0; i < kernel_args.size(); ++i) { - if (kernel_args_written[i]) { - kernel_func.setArgAttr(i, "lmhlo.written", b.getUnitAttr()); - } - } - - b.setInsertionPointToEnd(&kernel_func.getBody().back()); - b.create(loc); - - auto make_const_idx = [&](int64_t value) { - auto attr = b.getIndexAttr(value); - return b.create(loc, attr).getResult(); - }; - - auto make_kernel_dim3 = [&](const auto& dim3) { - return KernelDim3{make_const_idx(dim3.x), make_const_idx(dim3.y), - make_const_idx(dim3.z)}; - }; - - b.setInsertionPoint(op); - const auto& launch_dims = thunk.launch_dimensions(); - auto grid_size = make_kernel_dim3(launch_dims.block_counts()); - auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block()); - auto shmem_size = b.create( - loc, b.getI32IntegerAttr(thunk.shmem_bytes())); - - b.create(loc, kernel_func, grid_size, block_size, shmem_size, - kernel_args); -} - -static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, - GPUModuleOp gpu_module, Thunk* thunk) { - auto loc = op->getLoc(); - - if (thunk->kind() == Thunk::kSequential) { - const auto* seq_thunk = static_cast(thunk); - for (const std::unique_ptr& thunk : seq_thunk->thunks()) { - LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); - } - return; - } - - if (thunk->kind() == Thunk::kCopy) { - const auto* copy_thunk = static_cast(thunk); - b.setInsertionPoint(op); - b.create(loc, TypeRange(), ValueRange(), - copy_thunk->destination_value(), - copy_thunk->source_value()); - return; - } - - auto rewrite_memset = [&](const xla::BufferAllocation::Slice& slice, - uint32_t memset_value, Value buffer_arg) { - auto element_type = - buffer_arg.getType().cast().getElementType(); - b.setInsertionPoint(op); - Value value = MakeBitPatternConstant(b, loc, element_type, memset_value); - b.create(loc, TypeRange(), ValueRange(), buffer_arg, value); - }; - - if (thunk->kind() == Thunk::kMemset32BitValue) { - const auto* memset_thunk = static_cast(thunk); - rewrite_memset(memset_thunk->destination(), memset_thunk->value(), - memset_thunk->dest_value()); - return; - } - if (thunk->kind() == Thunk::kMemzero) { - const auto* memzero_thunk = static_cast(thunk); - rewrite_memset(memzero_thunk->destination(), 0, - memzero_thunk->dest_value()); - return; - } - - if (thunk->kind() == Thunk::kKernel) { - const auto* kernel_thunk = static_cast(thunk); - - SmallVector kernel_args; - for (auto kernel_arg : kernel_thunk->values()) - kernel_args.push_back(kernel_arg); - - SmallVector kernel_args_written; - for (auto written : kernel_thunk->written()) { - kernel_args_written.push_back(written); - } - - LowerKernelThunkToGpuOp(op, b, gpu_module, *kernel_thunk, kernel_args, - kernel_args_written); - return; - } - - CHECK(false) << "Thunk kind not handled: " << thunk->kind(); -} - -// An overload set for defining predicates for operations that should -// conditionally go through the XLA GPU code emitters. -template -static bool HasGpuEmitter(OpTy) { - return true; -} - -// Select custom calls that have corresponding GPU emitters. -static bool HasGpuEmitter(lmhlo::CustomCallOp custom_call) { - return llvm::any_of(kCustomCallIntrinsics, [&](std::string_view name) { - return custom_call.getCallTargetName().equals(name); - }); -} - -//===-----------------------------------------------------------------------===/ - -void ConvertLmhloToGpuLaunchPass::runOnOperation() { - ModuleOp module = getOperation(); - - // No thunks to lower from. Skip pass. - if (thunk_sequence_ == nullptr) return signalPassFailure(); - - // Collect thunks for rewriting each compatible operation in the module into - // the sequence of device kernel launches. Some operation might have an empty - // thunk sequence (e.g. redundant copy operation that does not require running - // anything on device). - absl::flat_hash_map> rewrites; - - // Get data to rewrite kernel ops without changing the IR. - auto walk = [&](auto op_type_tag) { - return module.walk([&](decltype(op_type_tag) op) -> WalkResult { - if (!HasGpuEmitter(op)) return success(); - - auto data = Match(op, *thunk_sequence_); - if (!data.ok()) return op.emitOpError(data.status().message()); - - rewrites[op] = std::move(*data); - return success(); - }); - }; - - // Collect all operations that have GPU code emitters. - if (walk(lmhlo::FusionOp()).wasInterrupted() || - walk(lmhlo::RngGetAndUpdateStateOp()).wasInterrupted() || - walk(lmhlo::ScatterOp()).wasInterrupted() || - walk(lmhlo::SelectAndScatterOp()).wasInterrupted() || - walk(lmhlo::SortOp()).wasInterrupted() || - walk(lmhlo::CustomCallOp()).wasInterrupted() || - walk(LaunchFuncOp()).wasInterrupted()) - return signalPassFailure(); - - // No operations that should be lowered to sequence of device launches. - if (rewrites.empty()) return; - - OpBuilder b(module); - SymbolTable symbol_table(module); - - // Replace matched operations with gpu.launch_func's. - for (const auto& [op, thunks] : rewrites) { - Rewrite(op, b, symbol_table, thunks.get()); - } - - // Mark module as gpu.container_module. - module->setAttr(GPUDialect::getContainerModuleAttrName(), b.getUnitAttr()); -} - -std::unique_ptr> -createConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence) { - return std::make_unique(thunk_sequence); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc b/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc deleted file mode 100644 index 593b81ac2dcc6..0000000000000 --- a/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc +++ /dev/null @@ -1,1239 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/uid_generator.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" -#include "xla/service/gpu/nccl_collective_permute_thunk.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_recv_thunk.h" -#include "xla/service/gpu/nccl_send_thunk.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLMHLOTOGPURUNTIMEPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::MemcpyOp; - -using mlir::lmhlo::CaseOp; -using mlir::lmhlo::CustomCallOp; -using mlir::lmhlo::FftOp; -using mlir::lmhlo::InfeedOp; -using mlir::lmhlo::OutfeedOp; -using mlir::lmhlo::TerminatorOp; -using mlir::lmhlo::WhileOp; - -using xla::runtime::AppendCustomCallAttrs; -using xla::runtime::CustomCallDeclarations; - -// helper template to check T is any of the types listed in Ts. -template -inline constexpr bool is_any = std::disjunction_v...>; - -class ConvertLmhloToGpuRuntimePass - : public impl::ConvertLmhloToGpuRuntimePassBase< - ConvertLmhloToGpuRuntimePass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry - .insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class TerminatorOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TerminatorOp op, - PatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp(op); - return mlir::success(); - } -}; - -//===----------------------------------------------------------------------===// - -template -class IoFeedOpLowering : public OpRewritePattern { - static StringRef Target(InfeedOp) { return "xla.gpu.infeed"; } - static StringRef Target(OutfeedOp) { return "xla.gpu.outfeed"; } - - public: - IoFeedOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(IoFeedOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("config"), op.getConfigAttr()}}; - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class InfeedOpLowering : public IoFeedOpLowering { - public: - using IoFeedOpLowering::IoFeedOpLowering; -}; - -class OutfeedOpLowering : public IoFeedOpLowering { - public: - using IoFeedOpLowering::IoFeedOpLowering; -}; - -//===----------------------------------------------------------------------===// - -class CustomCallOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.custom_call"; - - public: - CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime - // custom calls bypassing custom call adaptor. - LogicalResult rewriteTypedCustomCall(CustomCallOp op, - PatternRewriter& rewriter) const { - // TODO(ezhulenev): Support target arg mapping, or explain why we do not - // need them for typed custom calls. - if (op.getTargetArgMapping()) - return op.emitOpError( - "API_VERSION_TYPED_FFI custom calls do not " - "support target arg mapping"); - - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); - // Custom calls starting with the __gpu$ prefix are considered internal and - // statically linked (e.g. __gpu$TopK). - if (!op.getCallTargetName().starts_with("__gpu$")) { - callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); - } - - // Forward backend config to the custom call implementation. - auto dict = op.getBackendConfig() - ? op.getBackendConfig()->cast() - : nullptr; - llvm::SmallVector backend_config(dict.begin(), dict.end()); - - // Call the custom call function forwarding user-defined attributes. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, backend_config); - - return success(); - } - - LogicalResult matchAndRewrite(CustomCallOp op, - PatternRewriter& rewriter) const override { - // Typed custom calls lowered directly to XLA runtime custom calls. - if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) - return rewriteTypedCustomCall(op, rewriter); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // By default all operands passed to the custom call handler. - llvm::SmallVector operands = op.getOperands(); - - // If custom call has target arguments mapping, then we need to pass `i64` - // scalars in place of holes to detect them in custom call handler. - // - // TODO(ezhulenev): We need an `xla` dialect to model Xla framework - // semantics including holes for custom call. As a work around we pass `i64` - // values because xla custom call do not support scalar arguments, and we - // can disambiguate holes from buffers. - if (op.getTargetArgMapping().has_value()) { - auto mapping = *op.getTargetArgMapping(); - int64_t num_args = mapping.getNumArgs(); - int64_t num_results = mapping.getNumResults(); - - // We represent holes as an arbitrary `i64` constant. - Value hole = b.create(b.getI64IntegerAttr(-1)); - operands = llvm::SmallVector(num_args + num_results, hole); - - // Update operands to mapped custom call arguments. - auto args = mapping.getArgsToTargetArgs(); - for (const auto& indexed : llvm::enumerate(args)) - operands[indexed.value()] = op.getArgs()[indexed.index()]; - - // Update operands to mapped custom call results. - auto res = mapping.getResultsToTargetResults(); - for (const auto& indexed : llvm::enumerate(res)) - operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; - } - - // Create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("api_version"), op.getApiVersionAttr()}, - {b.getStringAttr("backend_config"), op.getBackendConfigAttr()}, - {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class FftOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.gpu.fft"; - - public: - FftOpLowering(MLIRContext* ctx, UidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(FftOp op, - PatternRewriter& rewriter) const override { - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("fft_length"), op.getFftLengthAttr()}, - {b.getStringAttr("fft_type"), op.getFftTypeAttr()}, - {b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())}}; - - // Convert Fft to a function call. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, custom_call_attrs); - return success(); - } - - private: - UidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -class CaseOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CaseOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Copy index buffer to the host ... - auto index_type = op.getIndex().getType().dyn_cast(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value index_on_host = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&op->getParentOfType().front()); - return b.create(index_type); - }(); - - b.create(TypeRange(), ValueRange({index_on_host, op.getIndex()})); - - // Get the index value from the buffer. - Value index = b.create(index_type.getElementType(), - index_on_host, ValueRange()); - - bool is_predicate = index_type.getElementType().isInteger(1); - - // For binary index (predicate) convert i1 to i32 index. - if (is_predicate) { - Value c0 = b.create(b.getI32IntegerAttr(0)); - Value c1 = b.create(b.getI32IntegerAttr(1)); - index = b.create(index, c0, c1); - } - - // For integer index make sure that it is within range. - if (!is_predicate) { - unsigned n = op.getNumRegions() - 1; - Value c0 = b.create(b.getI32IntegerAttr(0)); - Value cN = b.create(b.getI32IntegerAttr(n)); - - Value too_small = b.create( - b.getI1Type(), arith::CmpIPredicate::slt, index, c0); - Value too_large = b.create( - b.getI1Type(), arith::CmpIPredicate::sgt, index, cN); - - Value out_of_range = b.create(too_small, too_large); - index = b.create(out_of_range, cN, index); - } - - // Wrap the CFG constructed from the `lmhlo.case` operation in an - // `scf.execute_region` operation, so that we do not introduce the CFG - // into regions that expect a single block (e.g. inside the loop body). - auto execute = b.create(TypeRange()); - - // Add an entry block to the execute region operation. - Block& entry = execute.getRegion().emplaceBlock(); - - // Create a block with `scf.yield` terminator. - Block& yield = execute.getRegion().emplaceBlock(); - b.setInsertionPointToStart(&yield); - b.create(); - - // Prepare case destinations for the `scf.switch` operation. - llvm::SmallVector case_values; - llvm::SmallVector case_blocks; - llvm::SmallVector case_operands; - - // Create blocks from each of the case regions. - for (Region& region : op->getRegions()) { - // Move `lmhlo.case` block into the execute region. - Block& block = region.front(); - block.moveBefore(&yield); - - // Erase original `lmhlo.terminator`. - rewriter.eraseOp(block.getTerminator()); - - // Branch into the yield block. - b.setInsertionPointToEnd(&block); - b.create(&yield); - - // Add a `cf.switch` case. - int32_t idx = case_blocks.size(); - case_values.push_back(b.getI32IntegerAttr(idx).getValue()); - case_blocks.push_back(&block); - case_operands.push_back({}); - } - - // Create a `cf.switch` operation in the execute region entry block. - b.setInsertionPointToEnd(&entry); - b.create(index, &yield, ValueRange(), case_values, - case_blocks, case_operands); - - // Erase the original case operation. - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// - -class WhileOpLowering : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - // Rewrite while loop with known trip count to `scf.for` operation. - LogicalResult rewriteForLoop(WhileOp op, PatternRewriter& rewriter) const { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - Value lb = b.create(0); - Value ub = b.create(*op.getTripCount()); - Value c1 = b.create(1); - - // Create an `scf.for` loop in place of `lmhlo.while` loop. - auto loop = b.create(lb, ub, c1, ValueRange()); - - // Move body region into the new loop operation. - IRMapping mapping; - rewriter.eraseOp(op.getBody().front().getTerminator()); - rewriter.inlineBlockBefore(&op.getBody().front(), - loop.getBody()->getTerminator()); - - // Erase the original while loop. - rewriter.eraseOp(op); - - return success(); - } - - // Rewrite while loop with unknown trip count to `scf.while` operation. - LogicalResult rewriteWhileLoop(WhileOp op, PatternRewriter& rewriter) const { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Create an `scf.while` loop in place of `lmhlo.while` loop. - auto loop = b.create(TypeRange(), ValueRange()); - - // Predicate buffer placed on the device. - Value pred = op.getOperand(0); - - // Inline condition and body regions into the new loop operation. - IRMapping mapping; - rewriter.inlineRegionBefore(op.getCond(), loop.getBefore(), - loop.getBefore().begin()); - rewriter.inlineRegionBefore(op.getBody(), loop.getAfter(), - loop.getAfter().begin()); - - { // Replace loop condition terminator. - auto* terminator = loop.getBefore().back().getTerminator(); - b.setInsertionPointAfter(terminator); - - auto i1 = b.getI1Type(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value pred_on_host = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().front()); - return b.create(MemRefType::get({}, i1)); - }(); - - // Copy predicate buffer to the host ... - b.create(TypeRange(), ValueRange({pred_on_host, pred})); - - // .. and check if we need to continue loop iteration. - Value cond = b.create(i1, pred_on_host, ValueRange()); - b.create(cond, ValueRange()); - rewriter.eraseOp(terminator); - } - - { // Replace loop body terminator. - auto* terminator = loop.getAfter().back().getTerminator(); - b.setInsertionPointAfter(terminator); - b.create(TypeRange(), ValueRange()); - rewriter.eraseOp(terminator); - } - - // Erase the original while loop. - rewriter.eraseOp(op); - - return success(); - } - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter& rewriter) const override { - assert(op.getNumOperands() == 1 && "expected single lmhlo.while operand"); - return op.getTripCount().has_value() ? rewriteForLoop(op, rewriter) - : rewriteWhileLoop(op, rewriter); - } -}; - -//===----------------------------------------------------------------------===// -// Collective operations lowerings. -//===----------------------------------------------------------------------===// - -using mlir::lmhlo::PartitionIdOp; -using mlir::lmhlo::ReplicaIdOp; -using mlir::lmhlo_gpu::AllGatherDoneOp; -using mlir::lmhlo_gpu::AllGatherStartOp; -using mlir::lmhlo_gpu::AllReduceDoneOp; -using mlir::lmhlo_gpu::AllReduceStartOp; -using mlir::lmhlo_gpu::AllToAllDoneOp; -using mlir::lmhlo_gpu::AllToAllStartOp; -using mlir::lmhlo_gpu::CollectivePermuteDoneOp; -using mlir::lmhlo_gpu::CollectivePermuteStartOp; -using mlir::lmhlo_gpu::ReduceScatterDoneOp; -using mlir::lmhlo_gpu::ReduceScatterStartOp; - -using lmhlo::RecvDoneOp; -using lmhlo::RecvOp; -using lmhlo::SendDoneOp; -using lmhlo::SendOp; - -// We assign unique id to all collective operations in the module, so that we -// can efficiently access per-op state at run time. Exception to this rule are -// asynchronous collective operations, that share the same unique id by the pair -// of corresponding `start` and `done` operations. -// -// Asynchronous collective operations pass HLO Token to represent the dependency -// between the `Start` and `Done` operations. When we lower to XLA runtime -// custom calls we rely on assigning each unique pair of `Start` and `Done` -// operations a unique event id, and use shared "context" owned by the -// GpuExecutable to pass Gpu events from `Start` to `Done` custom call handlers. -// -// TODO(ezhulenev): Once XLA runtime custom calls support returning values, we -// should explicitly return event id from the `Start` custom call, and pass it -// to the `Done` custom call. Longer term this should become an `!async.token` -// and rely on XLA runtime asynchronous execution. -class CollectiveUidGenerator { - public: - CollectiveUidGenerator() : cnt_(0) {} - - // Assigns a unique event id to the pair of start and done operations. - int32_t AssignUid(Operation* start, Operation* done) { - int32_t id = next(); - uids_[start] = id; - uids_[done] = id; - return id; - } - - FailureOr AssignedUid(Operation* op) { - // Async operations must be assigned uid ahead of time. - if (isa(op)) { - auto it = uids_.find(op); - if (it == uids_.end()) return failure(); - return it->second; - } - // For every other operation we just assign a next id. - return next(); - } - - private: - int32_t next() { return cnt_++; } - - int32_t cnt_; - llvm::DenseMap uids_; -}; - -// Filters out host send/recv which do not participate in collective op -// lowerings. -struct CollectiveFilter { - template - static std::enable_if_t, bool> ShouldHandle( - OpT) { - return true; - } - - // We only handle send/recv that is not a host transfer. - template - static std::enable_if_t, bool> ShouldHandle( - OpT op) { - return !op.getIsHostTransfer(); - } -}; - -template -NcclCollectiveConfig GetNcclCollectiveConfigForP2POps(OpT op, int replica_count, - int num_partitions) { - return ThunkT::GetNcclP2PConfig(op, replica_count, num_partitions).config; -} - -template -class CollectiveOpLowering : public OpRewritePattern { - // Define target custom call for lowering of collective ops. - static StringRef Target(AllGatherStartOp) { return "xla.gpu.all_gather"; } - static StringRef Target(AllReduceStartOp) { return "xla.gpu.all_reduce"; } - static StringRef Target(AllToAllStartOp) { return "xla.gpu.all_to_all"; } - static StringRef Target(ReduceScatterStartOp) { - return "xla.gpu.reduce_scatter"; - } - static StringRef Target(CollectivePermuteStartOp) { - return "xla.gpu.collective_permute"; - } - static StringRef Target(SendOp) { return "xla.gpu.send"; } - static StringRef Target(RecvOp) { return "xla.gpu.recv"; } - - template - static std::enable_if_t< - is_any, - NcclCollectiveConfig> - GetNcclCollectiveConfig(OpT op, int /*replica_count*/, - int /*num_partitions*/) { - return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(AllToAllStartOp op, - int /*replica_count*/, - int /*num_partitions*/) { - // TODO(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids - // attribute and it should be removed. - return GetNcclCollectiveConfigForMlir(op, std::nullopt); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig( - CollectivePermuteStartOp op, int replica_count, int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(SendOp op, - int replica_count, - int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - static NcclCollectiveConfig GetNcclCollectiveConfig(RecvOp op, - int replica_count, - int num_partitions) { - return GetNcclCollectiveConfigForP2POps( - op, replica_count, num_partitions); - } - - template - static std::enable_if_t, - LogicalResult> - TryDegenerateToMemCopy(NonCollectivePermuteOp op, - const NcclCollectiveConfig& config, int replica_count, - int num_partitions, PatternRewriter& rewriter) { - if (!config.IsDegenerate(replica_count, num_partitions)) { - return failure(); - } - - for (int64_t i = 0; i < op.getInputs().size(); i++) { - rewriter.create( - op.getLoc(), TypeRange(), - ValueRange({op.getOutputs()[i], op.getOperands()[i]})); - } - - return success(); - } - - // Send/Recv is never degenerate by itself, so returns failure(). - template - static std::enable_if_t, LogicalResult> - TryDegenerateToMemCopy(OpT op, const NcclCollectiveConfig& config, - int replica_count, int num_partitions, - PatternRewriter& rewriter) { - return failure(); - } - - static LogicalResult TryDegenerateToMemCopy( - CollectivePermuteStartOp op, const NcclCollectiveConfig& config, - int replica_count, int num_partitions, PatternRewriter& rewriter) { - if (!NcclCollectivePermuteStartThunk::IsDegenerate(op, replica_count, - num_partitions)) { - return failure(); - } - - rewriter.create( - op.getLoc(), TypeRange(), - ValueRange({op.getOutput(), op.getOperand()})); - - return success(); - } - - static Status CheckImplementable(AllGatherStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllGatherStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(AllReduceStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllReduceStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(AllToAllStartOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclAllToAllStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - static Status CheckImplementable(CollectivePermuteStartOp op, - int64_t replica_count, - int64_t num_partitions) { - return NcclCollectivePermuteStartThunk::CheckImplementable( - op, replica_count, num_partitions); - } - - static Status CheckImplementable(SendOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclSendThunk::CheckImplementable(op, replica_count, num_partitions); - } - - static Status CheckImplementable(RecvOp op, int64_t replica_count, - int64_t num_partitions) { - return NcclRecvThunk::CheckImplementable(op, replica_count, num_partitions); - } - - static Status CheckImplementable(ReduceScatterStartOp op, - int64_t replica_count, - int64_t num_partitions) { - return NcclReduceScatterStartThunk::CheckImplementable(op, replica_count, - num_partitions); - } - - template - static typename std::enable_if_t< - is_any, LogicalResult> - SetSpecificAttrs(ImplicitLocOpBuilder& b, OpT op, func::CallOp call) { - std::optional reduction_kind = - NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( - op.getComputation()); - if (!reduction_kind.has_value()) - return op.emitOpError() - << "Failed to determine reduction computation for AllReduce"; - - call->setAttr( - b.getStringAttr("reduction_kind"), - b.getI64IntegerAttr(static_cast(reduction_kind.value()))); - - return success(); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - AllGatherStartOp op, - func::CallOp call) { - return success(); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - AllToAllStartOp op, func::CallOp call) { - call->setAttr(b.getStringAttr("has_split_dimension"), - b.getBoolAttr(op.getSplitDimension().has_value())); - return success(); - } - - static void SetSourceTargetPeersAttrs( - ImplicitLocOpBuilder& b, - const std::vector>& source_target_pairs, - func::CallOp call) { - std::vector source_peers; - std::vector target_peers; - source_peers.reserve(source_target_pairs.size()); - target_peers.reserve(source_target_pairs.size()); - for (const auto& source_target_pair : source_target_pairs) { - source_peers.push_back(source_target_pair.first); - target_peers.push_back(source_target_pair.second); - } - - auto source_peers_attr = b.getI64TensorAttr(source_peers); - auto target_peers_attr = b.getI64TensorAttr(target_peers); - call->setAttr(b.getStringAttr("source_peers"), source_peers_attr); - call->setAttr(b.getStringAttr("target_peers"), target_peers_attr); - } - - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - CollectivePermuteStartOp op, - func::CallOp call) { - auto source_target_pairs_or = - ConvertNx2Attribute(op.getSourceTargetPairs()); - if (!source_target_pairs_or.ok()) { - return op.emitOpError() << source_target_pairs_or.status().message(); - } - SetSourceTargetPeersAttrs(b, source_target_pairs_or.value(), call); - return success(); - } - - template - static typename std::enable_if_t, LogicalResult> - SetSpecificAttrs(ImplicitLocOpBuilder& b, OpT op, func::CallOp call) { - auto source_target_pairs_or = - GetSourceTargetPairs(op.getFrontendAttributes()); - if (!source_target_pairs_or.ok()) { - return op.emitOpError() << source_target_pairs_or.status().message(); - } - SetSourceTargetPeersAttrs(b, source_target_pairs_or.value(), call); - return success(); - } - - template - static typename std::enable_if_t, bool> getIsSync( - OpT) { - return false; - } - - template - static typename std::enable_if_t, bool> - getIsSync(OpT op) { - return op.getIsSync(); - } - - template - static typename std::enable_if_t, bool> - noParallelCustomCall(OpT) { - return false; - } - - template - static typename std::enable_if_t, bool> - noParallelCustomCall(OpT op) { - return op.getNoParallelCustomCall(); - } - - // For async collective erase all corresponding done operations. - template - void eraseDoneOp(PatternRewriter& rewriter, CollectiveOp op) const { - if (auto start = dyn_cast(op.getOperation())) { - auto users = llvm::to_vector(start.getToken().getUsers()); - llvm::for_each(users, [&](Operation* user) { - if (isa(user)) rewriter.eraseOp(user); - }); - } - } - - public: - CollectiveOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), - uid_(uid), - custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CollectiveOp op, - PatternRewriter& rewriter) const override { - if (!CollectiveFilter::ShouldHandle(op)) { - return failure(); - } - - // Construct an NCCL collective config from the parent func attributes. - func::FuncOp fn = op->template getParentOfType(); - auto replica_count_attr = fn->getAttrOfType("replica_count"); - auto num_partitions_attr = fn->getAttrOfType("num_partitions"); - const int64_t replica_count = replica_count_attr.getInt(); - const int64_t num_partitions = num_partitions_attr.getInt(); - - NcclCollectiveConfig config = - GetNcclCollectiveConfig(op, replica_count, num_partitions); - - // For async collective erase all corresponding done operations. - auto erase_done_op = [&]() { - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, - op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - eraseDoneOp(rewriter, op); - }; - - // A given collective op can be degenerate if across all groups formed - // by it are singleton. In such a case, we don't need to do any - // communication and we can just copy the input to the output. - if (succeeded(TryDegenerateToMemCopy(op, config, replica_count, - num_partitions, rewriter))) { - // For async collective erase all corresponding done operations. - erase_done_op(); - - // Erase the original collective operation. - rewriter.eraseOp(op); - - return success(); - } - - Status implementable_status = - CheckImplementable(op, replica_count, num_partitions); - if (!implementable_status.ok()) { - return op.emitOpError() << implementable_status.message(); - } - - // Check that we have and assigned unique collective operation id. - auto uid = uid_.AssignedUid(op); - if (failed(uid)) { - return op.emitOpError("failed to get a unique collective operation id"); - } - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // We always drop the return value from the signature, because for - // AllReduceStart operation we pass dependency through the collective - // operation id. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, Target(op), TypeRange(op.getOperands()), TypeRange()); - - // Convert collective op to a function call. - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), op.getOperands()); - - // Copy backend specific attributes. - call->setAttr(b.getStringAttr("group_mode"), - b.getI64IntegerAttr(static_cast(config.group_mode))); - call->setAttr(b.getStringAttr("op_id"), b.getI64IntegerAttr(config.op_id)); - - // TODO(b/233930690): Pass the attribute below as a nested array. - // Pass an array of arrays using two vectors; one specifying all the values - // and another specifying the (ending) offsets of each array in the other - // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into - // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90]. - std::vector replica_group_offsets; - std::vector replica_group_values; - replica_group_offsets.reserve(config.replica_groups.size()); - int replica_group_offset = 0; - for (const auto& replica_group : config.replica_groups) { - replica_group_offset += replica_group.replica_ids_size(); - replica_group_offsets.push_back(replica_group_offset); - replica_group_values.reserve(replica_group_offset); - for (auto replica_id : replica_group.replica_ids()) { - replica_group_values.push_back(replica_id); - } - } - call->setAttr(b.getStringAttr("replica_group_offsets"), - b.getI64TensorAttr(replica_group_offsets)); - call->setAttr(b.getStringAttr("replica_group_values"), - b.getI64TensorAttr(replica_group_values)); - - // Assign a unique collective operation id. - call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid)); - - // Set attributes specific to the type of collective operation. - auto result = SetSpecificAttrs(b, op, call); - if (failed(result)) return result; - - bool is_async = !getIsSync(op); - call->setAttr(b.getStringAttr("is_async"), b.getBoolAttr(is_async)); - - call->setAttr(b.getStringAttr("no_parallel_custom_call"), - b.getBoolAttr(noParallelCustomCall(op))); - - // If the collective will not execute asynchronously, erase the associated - // done op. - if (!is_async) { - erase_done_op(); - } else { - // For asynchronous start operation we need to produce a fake token, that - // will be later removed, because corresponding `done` operation doesn't - // have a token argument. We rely on the `unrealized_conversion_cast` - // operation to create a fake token from the `i8` constant, and on the - // dead code elimination pass that will remove unused fake tokens. - Value token = op.getToken(); - Value c0 = b.create(b.getI8IntegerAttr(0)); - auto fake = b.create(token.getType(), c0); - token.replaceAllUsesWith(fake.getResult(0)); - } - - // Erase the original collective operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CollectiveUidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_COLLECTIVE_OP_LOWERING(OP) \ - class OP##Lowering : public CollectiveOpLowering { \ - public: \ - using CollectiveOpLowering::CollectiveOpLowering; \ - } - -DEFINE_COLLECTIVE_OP_LOWERING(AllGatherStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(AllReduceStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(AllToAllStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(CollectivePermuteStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(ReduceScatterStartOp); -DEFINE_COLLECTIVE_OP_LOWERING(SendOp); -DEFINE_COLLECTIVE_OP_LOWERING(RecvOp); - -#undef DEFINE_COLLECTIVE_OP_LOWERING - -template -class AsyncDoneOpLowering : public OpRewritePattern { - public: - AsyncDoneOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, "xla.gpu.collective_done", TypeRange(), TypeRange()); - - // Get a unique collective operation id. - FailureOr uid = uid_.AssignedUid(op); - if (failed(uid)) - return op.emitOpError("failed to get a unique collective operation id"); - - llvm::SmallVector custom_call_attributes = { - {b.getStringAttr("uid"), b.getI32IntegerAttr(*uid)}, - {b.getStringAttr("done_type"), b.getStringAttr(Derived::kDoneType)}}; - - // Convert AllReduceDone to a function call. - auto call = rewriter.replaceOpWithNewOp(op, callee.getName(), - TypeRange()); - AppendCustomCallAttrs(call, custom_call_attributes); - - return success(); - } - - private: - CollectiveUidGenerator& uid_; - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_COLLECTIVE_DONE_OP_LOWERING(OP, done_type) \ - struct OP##Lowering : public AsyncDoneOpLowering { \ - static constexpr const char kDoneType[] = done_type; \ - using AsyncDoneOpLowering::AsyncDoneOpLowering; \ - } - -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllGatherDoneOp, "all_gather_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllReduceDoneOp, "all_reduce_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(AllToAllDoneOp, "all_to_all_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(CollectivePermuteDoneOp, - "collective_permute_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(ReduceScatterDoneOp, "reduce_scatter_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(SendDoneOp, "send_done"); -DEFINE_COLLECTIVE_DONE_OP_LOWERING(RecvDoneOp, "recv_done"); - -#undef DEFINE_COLLECTIVE_DONE_OP_LOWERING - -template -class CollectiveIdOpLowering : public OpRewritePattern { - static StringRef Target(ReplicaIdOp) { return "xla.gpu.replica_id"; } - static StringRef Target(PartitionIdOp) { return "xla.gpu.partition_id"; } - - public: - CollectiveIdOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CollectiveIdOp op, - PatternRewriter& rewriter) const override { - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate(b, Target(op), op); - - // Call the runtime intrinsic with the original operands. - rewriter.replaceOpWithNewOp(op, callee.getName(), TypeRange(), - op->getOperands()); - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -class ReplicaIdOpLowering : public CollectiveIdOpLowering { - public: - using CollectiveIdOpLowering::CollectiveIdOpLowering; -}; - -class PartitionIdOpLowering : public CollectiveIdOpLowering { - public: - using CollectiveIdOpLowering::CollectiveIdOpLowering; -}; - -//===----------------------------------------------------------------------===// -// Host<->Device communication ops lowering (Send/Recv). -//===----------------------------------------------------------------------===// - -template -class HostSendRecvOpLowering : public OpRewritePattern { - public: - HostSendRecvOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(OpT op, - PatternRewriter& rewriter) const override { - if (!op.getIsHostTransfer()) { - return failure(); - } - - constexpr bool is_done_op = - is_any; - - // Get or create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // For done ops, drop the token input. - TypeRange input_types = - is_done_op ? TypeRange() : TypeRange(op->getOperands()); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, Derived::kCustomCallTarget, input_types, TypeRange()); - - llvm::SmallVector custom_call_attributes = { - {b.getStringAttr("channel_handle"), op.getChannelHandleAttr()}}; - if constexpr (!is_done_op) { - custom_call_attributes.push_back(NamedAttribute( - b.getStringAttr("frontend_attributes"), op.getFrontendAttributes())); - } - - // Convert Send/Recv/SendDone/RecvDone to a function call. - ValueRange inputs = - is_done_op ? ValueRange() : ValueRange(op->getOperands()); - auto call = rewriter.create(op.getLoc(), callee.getName(), - TypeRange(), inputs); - AppendCustomCallAttrs(call, custom_call_attributes); - - if constexpr (!is_done_op) { - // For communication operation we need to produce a fake token, that will - // be later removed, because corresponding `done` operation doesn't have - // the token argument. We rely on the `unrealized_conversion_cast` - // operation to create a fake token from the `i8` constant. - Value token = op.getResult(); - Value c0 = b.create(b.getI8IntegerAttr(0)); - auto fake = b.create(token.getType(), c0); - token.replaceAllUsesWith(fake.getResult(0)); - } - - // Erase the original operation. - rewriter.eraseOp(op); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -#define DEFINE_HOST_SENDRECV_OP_LOWERING(OP, custom_call) \ - struct Host##OP##Lowering \ - : public HostSendRecvOpLowering { \ - static constexpr const char kCustomCallTarget[] = custom_call; \ - using HostSendRecvOpLowering::HostSendRecvOpLowering; \ - } - -DEFINE_HOST_SENDRECV_OP_LOWERING(SendOp, "xla.gpu.send_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(SendDoneOp, "xla.gpu.send_done_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(RecvOp, "xla.gpu.recv_host"); -DEFINE_HOST_SENDRECV_OP_LOWERING(RecvDoneOp, "xla.gpu.recv_done_host"); - -//===----------------------------------------------------------------------===// - -template -static WalkResult AssignAsyncUid(Operation* op, - CollectiveUidGenerator& collective_uid) { - auto start = dyn_cast(op); - if (!start) { - if constexpr (sizeof...(Remaining) != 0) { - return AssignAsyncUid(op, collective_uid); - } else { - return WalkResult::advance(); - } - } - - if (!CollectiveFilter::ShouldHandle(start)) { - return WalkResult::advance(); - } - - Value token = start.getToken(); - - // We expect the token to be consumed just once. - if (!token.hasOneUse()) return start.emitOpError("token has multiple uses"); - - // Token must be consumed by the corresponding done operation. - auto done = dyn_cast(*token.getUsers().begin()); - if (!done) return start.emitOpError("illegal token user"); - - collective_uid.AssignUid(start, done); - return WalkResult::advance(); -} - -void ConvertLmhloToGpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Convert lmhlo operations to XLA gpu runtime custom calls. - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - patterns.insert( - ctx, custom_calls); - - UidGenerator fft_uid; - patterns.insert(ctx, fft_uid, custom_calls); - - // Assign shared unique id to each unique pair of async start-done operations, - // all other collective operations will get assigned uid. - CollectiveUidGenerator collective_uid; - auto walked = module.walk([&collective_uid](Operation* op) { - return AssignAsyncUid< - std::pair, - std::pair, - std::pair, - std::pair, - std::pair, - std::pair, std::pair>( - op, collective_uid); - }); - if (walked.wasInterrupted()) return signalPassFailure(); - - // Convert lmhlo collective operations to XLA gpu runtime custom calls. - patterns.insert(ctx, - custom_calls); - patterns.insert( - ctx, collective_uid, custom_calls); - - // Convert lmhlo host<->device point-to-point communication operations to XLA - // gpu runtime. - patterns.insert(ctx, - custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); - - // TODO(ezhulenev): We must run `done` op lowering after the `start` op - // lowering to ensure that all redundant collective operations will be - // safely replaced by a `memcpy` operations. - // - // This should be a part of lmhlo operation canonicalization. - { - RewritePatternSet patterns(ctx); - patterns.insert(ctx, collective_uid, custom_calls); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); - } -} - -std::unique_ptr> -createConvertLmhloToGpuRuntimePass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc b/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc deleted file mode 100644 index 4685b891c0340..0000000000000 --- a/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTMEMREFGETGLOBALTOARGPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -class ConvertMemrefGetGlobalToArgPass - : public impl::ConvertMemrefGetGlobalToArgPassBase< - ConvertMemrefGetGlobalToArgPass> { - public: - ConvertMemrefGetGlobalToArgPass() = default; - - explicit ConvertMemrefGetGlobalToArgPass(int64_t min_num_elements) { - this->min_num_elements_ = min_num_elements; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -using GlobalConstantsArgs = - llvm::DenseMap>; - -// Returns a mapping from a global constant name to the function argument. -// -// Example: -// -// memref.global "private" constant @cst : memref<2x3xf32> -// func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst"}) -// -// All memref.get_global operations will be replaced by constant arguments -// corresponding to the global constant. -static GlobalConstantsArgs GetConstantArgs(ModuleOp m) { - GlobalConstantsArgs mapping; - - m.walk([&](func::FuncOp func) { - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto cst = func.getArgAttrOfType(i, "lmhlo.constant_name"); - if (cst) mapping[func][cst] = func.getArgument(i); - } - }); - - return mapping; -} - -class GetGlobalOpLowering : public OpRewritePattern { - public: - GetGlobalOpLowering(MLIRContext* ctx, const GlobalConstantsArgs& cst_args) - : OpRewritePattern(ctx), cst_args_(cst_args) {} - - LogicalResult matchAndRewrite(memref::GetGlobalOp op, - PatternRewriter& rewriter) const override { - // Find global constants mapping for the parent function. - auto func_mapping = cst_args_.find(op->getParentOfType()); - if (func_mapping == cst_args_.end()) return failure(); - - // Check if the global operation corresponds to the LMHLO constant arg. - auto arg = func_mapping->second.find(op.getName()); - if (arg == func_mapping->second.end()) return failure(); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - MemRefType memref = op->getResult(0).getType().cast(); - - // For identity layouts we can replace all loads from a global with the - // corresponding argument. - if (memref.getLayout().isIdentity()) { - Value c0 = b.create(rewriter.getIndexAttr(0)); - rewriter.replaceOpWithNewOp(op, memref, arg->second, c0, - ValueRange()); - return success(); - } - - // For non-identity type we first view constant argument as a flat memref - // with the correct element type, and then cast it to the strided memref - // corresponding to the original memref layout. - - // Get the strides and offset from the original memref type. - int64_t offset; - llvm::SmallVector strides; - if (failed(getStridesAndOffset(memref, strides, offset))) - return op.emitOpError("failed to compute strides and offset"); - - // Create a 1d view into the corresponding argument. - Value c0 = b.create(rewriter.getIndexAttr(0)); - Value flat_view = b.create( - MemRefType::get({memref.getNumElements()}, memref.getElementType()), - arg->second, c0, ValueRange()); - - // Cast flat memref view into the original memref type. - rewriter.replaceOpWithNewOp( - op, memref, flat_view, offset, memref.getShape(), strides); - - return success(); - } - - private: - const GlobalConstantsArgs& cst_args_; -}; - -void ConvertMemrefGetGlobalToArgPass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Replace memref loads from globals corresponding to the constant arguments. - RewritePatternSet patterns(ctx); - GlobalConstantsArgs cst_args = GetConstantArgs(module); - patterns.insert(ctx, cst_args); - - // Set up conversion target to rewrite only GetGlobalOp larger than the - // threshold and avoid any other canonicalizations that can break later - // passes. - ConversionTarget target(*ctx); - target.addDynamicallyLegalOp( - [&](memref::GetGlobalOp op) { - auto memref = op.getType(); - return memref.getNumElements() < min_num_elements_; - }); - target.addLegalOp(); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) - signalPassFailure(); -} - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass() { - return std::make_unique(); -} - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(int64_t min_num_elements) { - return std::make_unique(min_num_elements); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc deleted file mode 100644 index 63a2e723eb950..0000000000000 --- a/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ /dev/null @@ -1,518 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Dominance.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "xla/debug_options_flags.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir/runtime/utils/custom_calls.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/stream_executor/blas.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_OUTLINEGPUGRAPHSPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::LaunchFuncOp; - -class OutlineGpuGraphsPass - : public impl::OutlineGpuGraphsPassBase { - public: - OutlineGpuGraphsPass() = default; - explicit OutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int min_graph_size) - : command_types_(std::move(command_types)) { - this->min_graph_size_ = min_graph_size; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - private: - absl::flat_hash_set command_types_ = { - DebugOptions::FUSION, DebugOptions::CUBLAS, DebugOptions::CUDNN}; - int gpu_graph_level_ = 3; -}; - -//===----------------------------------------------------------------------===// - -struct OpCapturePattern { - // CUDA-graph-compatible operations can be either moved or cloned into the - // graph capture function. Most of the operations should be moved, as they - // have side effects, however small constants and pure operations like - // `memref.view` can be safely cloned into the graph region. We rely on later - // dead code elimination to erase them from the "main" function if they are - // not used by any other operations. - enum class Capture { kMove, kClone }; - - virtual ~OpCapturePattern() = default; - virtual FailureOr match(Operation* op) = 0; -}; - -using OpCapturePatternSet = std::vector>; - -// A sequence of operations to be outlined into cuda graph capture function. -using CaptureSequence = - llvm::SmallVector>; - -//===----------------------------------------------------------------------===// - -template -struct OpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (isa(op)) return capture; - return failure(); - } -}; - -static constexpr auto kMove = OpCapturePattern::Capture::kMove; -static constexpr auto kClone = OpCapturePattern::Capture::kClone; - -template -using MoveOp = OpCapture; -template -using CloneOp = OpCapture; - -// Capture gpu operations by moving them into graph capture function. -struct LaunchFuncOpCapture : public MoveOp {}; - -template -struct ConvOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto conv = llvm::dyn_cast(op)) { - // Convolution that does runtime autotuning should not be captured, since - // CUDA graphs do not support operations that allocate memory. - lmhlo_gpu::ConvolutionBackendConfigAttr backend_config = - conv.getBackendConfig(); - if (backend_config.getAlgorithm() != -1) { - return kMove; - } - } - return failure(); - } -}; - -// TODO(b/270426911): Right now GEMM/Convolution with runtime autotuning can't -// be captured by a cuda graph. However, longer term the proper fix is to make -// autotuning "cuda-graph-aware", and run autotuning on a separate stream that -// is not in capture mode. -struct ConvForwardOpCapture : public ConvOpCapture {}; -struct ConvBackwardInputOpCapture - : public ConvOpCapture {}; -struct ConvBackwardFilterOpCapture - : public ConvOpCapture {}; -struct ConvForwardFusedOpCapture - : public ConvOpCapture {}; -struct ConvForwardFusedSideInputOpCapture - : public ConvOpCapture {}; - -struct GemmOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto gemm = llvm::dyn_cast(op)) { - // GEMM that does runtime autotuning should not be captured, since CUDA - // graph does not support operations that allocate memory. - if (!gemm.getAlgorithm().has_value() || - gemm.getAlgorithm().value() != - stream_executor::blas::kRuntimeAutotuning) { - return kMove; - } - } - return failure(); - } -}; - -struct MemcpyOpCapture : public OpCapturePattern { - FailureOr match(Operation* op) final { - if (auto memcpy = llvm::dyn_cast(op)) { - // We use a heuristic to identify the direction of the memcpy operation, - // if the operand was allocated by alloca op or is a global memref, then - // it must be a memref on the host. - auto IsHostMemRef = [](Value value) { - auto* op = value.getDefiningOp(); - return llvm::isa_and_nonnull(op); - }; - - auto IsDeviceToDevice = [&](mlir::gpu::MemcpyOp op) { - return !IsHostMemRef(op.getDst()) && !IsHostMemRef(op.getSrc()); - }; - - // Device-to-host Memcpy cannot be captured by CUDA graphs. - if (IsDeviceToDevice(memcpy)) { - return kMove; - } - } - return failure(); - } -}; - -// Capture pure operations by cloning them into graph capture function. -struct ConstantOpCapture : public CloneOp {}; -struct ViewOpCapture : public CloneOp {}; -struct ReinterpretCastOpCapture : public CloneOp {}; - -//===----------------------------------------------------------------------===// - -// Collect sequences of operations that can be outlined into Cuda Graphs. -static std::vector CollectCaptureSequences( - DominanceInfo& dominance, ModuleOp module, OpCapturePatternSet& patterns) { - std::vector seqs; - - // Match given operation with all capture patterns. - auto match = [&](Operation* op) -> FailureOr { - for (auto& pattern : patterns) { - if (auto matched = pattern->match(op); succeeded(matched)) return matched; - } - return failure(); - }; - - // Find graph-compatible sequences of operations in every block. - module.walk([&](Block* block) { - CaptureSequence* seq = &seqs.emplace_back(); - - for (Operation& op : *block) { - FailureOr matched = match(&op); - // Append matched operation to the current sequence. We only append - // operations that must be moved into the graph capture function (ops with - // side effects), and add cloneable operations later. - if (succeeded(matched) && *matched == kMove) - seq->emplace_back(&op, *matched); - - // Skip unsupported operation and start a new sequence. - if (failed(matched) && !seq->empty()) seq = &seqs.emplace_back(); - } - - // Remove the last sequence if it's empty. - if (seq->empty()) seqs.pop_back(); - }); - - // Remove cloneable operations accidentally captured by the sequence of ops, - // e.g. we can have `memref.view` between two kernel launch operations that - // is not used by operations in the captured sequence. - for (CaptureSequence& seq : seqs) { - llvm::DenseSet moveable_ops; - for (auto& [op, capture] : seq) - if (capture == kMove) moveable_ops.insert(op); - - llvm::erase_if(seq, [&](auto& pair) { - return pair.second == kClone && - llvm::none_of(pair.first->getUsers(), [&](Operation* user) { - return moveable_ops.contains(user); - }); - }); - } - - // Try to extend discovered sequences of ops following operands use-def chains - // and pulling cloneable operations defining operands into the graph capture - // sequence. In practice we just clone `arith.constant` and `memref.view` - // operations into the graph capture function, to make it cheaper to compute - // the hash of the arguments at run time. - for (CaptureSequence& seq : seqs) { - llvm::DenseSet seq_ops; // operations already in `seq` - llvm::SmallVector worklist; - - // Add operations that define `op` arguments to the worklist. - auto populate_worklist = [&](Operation* op) { - for (Value arg : op->getOperands()) - if (Operation* op = arg.getDefiningOp()) worklist.push_back(op); - }; - - for (auto& [op, _] : seq) { - seq_ops.insert(op); - populate_worklist(op); - } - - // Find cloneable ops and group them by block where they are defined. - llvm::DenseMap> cloneable; - - // Traverse use-def chains to collect all cloneable operations. - while (!worklist.empty()) { - Operation* op = worklist.pop_back_val(); - if (seq_ops.contains(op)) continue; - - // Check if operation can be cloned into graph capture function. - if (auto matched = match(op); - succeeded(matched) && *matched == OpCapturePattern::Capture::kClone) { - cloneable[op->getBlock()].push_back(op); - seq_ops.insert(op); - populate_worklist(op); - } - } - - // Traverse blocks according to their dominance to avoid used-before-defined - // invalid SSA region construction in graph capture function. - llvm::SmallVector blocks; - for (auto& [block, _] : cloneable) blocks.push_back(block); - llvm::sort(blocks, [&](Block* a, Block* b) { - return dominance.properlyDominates(a, b); - }); - - for (Block* block : llvm::reverse(blocks)) { - // Sort operations according to their original position in the block. - llvm::sort(cloneable[block], [](Operation* a, Operation* b) { - return a->isBeforeInBlock(b); - }); - - // Prepend all cloneable operations to the discovered ops sequence. - auto cloned = llvm::map_range(cloneable[block], [](Operation* op) { - return std::make_pair(op, OpCapturePattern::Capture::kClone); - }); - seq.insert(seq.begin(), cloned.begin(), cloned.end()); - } - } - - return seqs; -} - -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCallDeclarations; - -static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { - llvm::SetVector args; - - // Values defined by operations in the capture sequence. - llvm::DenseSet defined_by_seq; - for (auto& [op, _] : seq) - defined_by_seq.insert(op->result_begin(), op->result_end()); - - // Add arguments defined outside of the capture sequence. - for (auto& [op, _] : seq) { - auto external_args = llvm::make_filter_range( - op->getOperands(), - [&](Value arg) { return !defined_by_seq.contains(arg); }); - args.insert(external_args.begin(), external_args.end()); - } - llvm::SmallVector args_sv = args.takeVector(); - std::vector args_tv(args_sv.begin(), args_sv.end()); - return args_tv; -} - -// Given a sequence of operations, outline them into a graph capture function -// and replace them with an XLA Gpu runtime function call. -static LogicalResult Outline(unsigned ordinal, - CustomCallDeclarations& custom_calls, - CaptureSequence& seq, int min_graph_size) { - // Only operations that have to be moved into the graph capture function - // represent Gpu computations. - unsigned num_move_captures = llvm::count_if(seq, [](auto capture) { - return capture.second == OpCapturePattern::Capture::kMove; - }); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - if (num_move_captures < min_graph_size) return failure(); - - SymbolTable& sym_table = custom_calls.sym_table(); - MLIRContext* ctx = sym_table.getOp()->getContext(); - - // Create a fused location out of LaunchFuncOp operations. - llvm::SmallVector locations; - for (auto& op : seq) locations.push_back(op.first->getLoc()); - ImplicitLocOpBuilder b(FusedLoc::get(ctx, locations), sym_table.getOp()); - - // Arguments of the graph capture function. - std::vector args = GetGraphCaptureFuncArgs(seq); - - // Create a function in the compiled module. - auto func = b.create( - "xla.gpu.graph.capture", - FunctionType::get(ctx, TypeRange(ValueRange(args)), TypeRange())); - - Operation* first_op = seq.front().first; - auto parent_func = first_op->getParentOfType(); - - // If an argument to parent_func has the "lmhlo.constant_name" attribute and - // is passed to the graph capture function, we propagate the attribute the - // graph capture function. - // - // We also annotate all arguments with "rt.allocation_index" attribute that - // allows us to forward correct arguments to graph capture function during - // Gpu executable initialization (see `InstantiateAllGraphs` implementation). - for (unsigned i = 0; i < args.size(); ++i) { - Value arg = args[i]; - - // Check if arg is a function argument of parent_func. - if (!isa(arg)) continue; - - // Function arguments are passed in as block arguments to the entry block. - auto block_arg = cast(arg); - Block* parent_block = block_arg.getParentBlock(); - if (!parent_block->isEntryBlock()) continue; - - // If this is an argument to the entry block of the parent function, it - // means that it's the XLA allocation, and we forward index to the capture - // function. - func.setArgAttr(i, "rt.allocation_index", - b.getIndexAttr(block_arg.getArgNumber())); - - // Check that the parent_block is in the SSACFG region of parent_func. - Region& parent_func_region = parent_func.getRegion(); - if (parent_block->getParent() != &parent_func_region) continue; - - unsigned parent_func_arg_index = block_arg.getArgNumber(); - auto cst = parent_func.getArgAttrOfType(parent_func_arg_index, - "lmhlo.constant_name"); - if (cst) { - func.setArgAttr(i, "lmhlo.constant_name", cst); - } - } - - for (auto op : seq) { - mlir::Operation* captured_op = op.first; - if (isa(captured_op)) { - func->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName), - BoolAttr::get(ctx, true)); - break; - } - } - - // Add graph capture function to the module. - sym_table.insert(func); - - // Export graph capture function to the runtime. - b.setInsertionPoint(func); - b.create(func, ordinal); - - // Create a custom call declaration corresponding to the outlined graph - // capture function. - func::FuncOp graph_launch = custom_calls.GetOrCreate( - b, "xla.gpu.graph.launch", TypeRange(ValueRange(args)), TypeRange()); - - // Call the cuda graph launch custom call right before the first moved op. - auto insertion_point = llvm::find_if(seq, [](auto capture) { - return capture.second == OpCapturePattern::Capture::kMove; - }); - b.setInsertionPoint(insertion_point->first); - - auto call = b.create(graph_launch.getName(), TypeRange(), args); - call->setAttr(b.getStringAttr("capture"), FlatSymbolRefAttr::get(func)); - - // At this point we successfully added new functions to the module, so we can - // move or clone captured operations from their original location to the graph - // capture function. - Block* body = func.addEntryBlock(); - - // We'll need to replace operands of cloned/moved operations inside the graph - // capture function. - llvm::SmallVector> mappings; // {from, to} mappings - for (auto mapping : llvm::zip(args, func.getArguments())) - mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); - - // Move or clone operations into the graph capture function. - for (auto& [op, capture] : seq) { - if (capture == OpCapturePattern::Capture::kMove) - op->moveBefore(body, body->end()); - - if (capture == OpCapturePattern::Capture::kClone) { - Operation* clone = op->clone(); - OpBuilder::atBlockEnd(body).insert(clone); - - for (auto mapping : llvm::zip(op->getResults(), clone->getResults())) - mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); - } - } - - // Update def-use chains inside the graph capture function. - for (auto mapping : mappings) { - replaceAllUsesInRegionWith(mapping.first, mapping.second, func.getBody()); - } - - // Add a return operation to the graph capture function. - b.setInsertionPointToEnd(body); - b.create(ValueRange()); - - return success(); -} - -//===----------------------------------------------------------------------===// - -void OutlineGpuGraphsPass::runOnOperation() { - SymbolTable sym_table(getOperation()); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - OpCapturePatternSet patterns; - - if (command_types_.contains(DebugOptions::FUSION)) { - // Enable capturing fusions and memcpies. - patterns.emplace_back(new LaunchFuncOpCapture()); - patterns.emplace_back(new ConstantOpCapture()); - patterns.emplace_back(new ViewOpCapture()); - patterns.emplace_back(new MemcpyOpCapture()); - patterns.emplace_back(new ReinterpretCastOpCapture()); - } - - if (command_types_.contains(DebugOptions::CUBLAS)) { - // Enable capturing gemms. - patterns.emplace_back(new GemmOpCapture()); - } - - if (command_types_.contains(DebugOptions::CUDNN)) { - // Enable capturing convolutions. - patterns.emplace_back(new ConvForwardOpCapture()); - patterns.emplace_back(new ConvBackwardInputOpCapture()); - patterns.emplace_back(new ConvBackwardFilterOpCapture()); - patterns.emplace_back(new ConvForwardFusedOpCapture()); - patterns.emplace_back(new ConvForwardFusedSideInputOpCapture()); - } - - unsigned ordinal = 1; // entry point will be exported with ordinal 0 - for (auto& seq : CollectCaptureSequences(getAnalysis(), - getOperation(), patterns)) { - if (succeeded(Outline(ordinal, custom_calls, seq, min_graph_size_))) - ordinal++; - } -} - -std::unique_ptr> createOutlineGpuGraphsPass() { - return std::make_unique(); -} - -std::unique_ptr> createOutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int min_graph_size) { - return std::make_unique(command_types, min_graph_size); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/passes.cc b/xla/mlir/backends/gpu/transforms/passes.cc deleted file mode 100644 index 8fcb6d148ed88..0000000000000 --- a/xla/mlir/backends/gpu/transforms/passes.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/backends/gpu/transforms/passes.h" - -#include -#include - -#include "absl/log/log.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_ops.h" - -namespace xla { -namespace gpu { - -using namespace mlir; // NOLINT - -std::vector> GetAllocationIndices(mlir::ModuleOp module) { - std::vector> res; - - SymbolTable sym_table(module); - for (auto op : module.getOps()) { - unsigned ordinal = *op.ordinal(); - if (ordinal >= res.size()) res.resize(ordinal + 1); - - auto func = sym_table.lookup(op.getFunctionRef()); - res[ordinal].resize(func.getNumArguments(), -1); - - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto idx = func.getArgAttrOfType(i, "rt.allocation_index"); - if (idx) res[ordinal][i] = idx.getInt(); - } - } - - return res; -} - -void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence, - const GpuPipelineOpts& opts) { - // Lower operations with registered IR emitters to Gpu launches. - pm.addPass(createConvertLmhloToGpuLaunchPass(thunk_sequence)); - - // Clean up IR before converting it to the runtime operations. - pm.addPass(createCSEPass()); - - // Convert global memrefs corresponding to constant arguments. - pm.addPass(createConvertMemrefGetGlobalToArgPass()); - pm.addPass(createSymbolDCEPass()); // Clean up unused global constants. - - // Outline CUDA-Graph-compatible operations into graph capture functions. - pm.addPass( - createOutlineGpuGraphsPass(opts.command_types, opts.min_graph_size)); - if (opts.enable_concurrent_region) { - // Concurrent regions create repeated-fork-join topology inside CUDA graphs, - // which is not optimized by architectures prior to Ampere and may cause - // regression. So we enable concurrent regions only on Ampere GPUs. - if (auto cc = std::get_if( - &opts.compute_capability); - !cc || cc->IsAtLeast(8, 0)) { - pm.addPass(createAddConcurrentRegionsPass()); - } else { - LOG(WARNING) - << "Multi-stream execution disabled on non-ampere architectures"; - } - } - - // Lower all Gpu operations to the XLA Gpu runtime custom calls. - pm.addPass(createConvertLmhloGpuToGpuRuntimePass()); - pm.addPass(createConvertLmhloToGpuRuntimePass()); - pm.addPass(createConvertGpuToGpuRuntimePass()); - - // Add performance tracing annotations. - pm.addPass(createAddHloTraceAnnotationsPass()); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/passes.h b/xla/mlir/backends/gpu/transforms/passes.h deleted file mode 100644 index 253eda6c3b3e1..0000000000000 --- a/xla/mlir/backends/gpu/transforms/passes.h +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "xla/stream_executor/device_description.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DECL_ADDHLOTRACEANNOTATIONSPASS -#define GEN_PASS_DECL_CONVERTGPUTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTLMHLOGPUTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTLMHLOTOGPULAUNCHPASS -#define GEN_PASS_DECL_CONVERTLMHLOTOGPURUNTIMEPASS -#define GEN_PASS_DECL_CONVERTMEMREFGETGLOBALTOARGPASS -#define GEN_PASS_DECL_OUTLINEGPUGRAPHSPASS -#define GEN_PASS_DECL_ADDCONCURRENTREGIONSPASS -#define GEN_PASS_DECL_STREAMASSIGNMENTPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -class ThunkSequence; // forward declare - -// Collects `rt.allocation_index` attributes from all exported functions. -// -// auto result = GetAllocationIndices(); -// result[ordinal][argument_index] == allocation_index; -// -// Returns `-1` for all arguments that do not have `rt.allocation_index` -// attribute. -// -// TODO(ezhulenev): This is a very ugly hack for graph capture integration, but -// given that we are moving towards a new runtime and command buffers, it's -// supposed to be a very short lived hack. -std::vector> GetAllocationIndices(mlir::ModuleOp module); - -struct GpuPipelineOpts { - // Enable experimental pass that outlines parts of the XLA computation into - // CUDA Graphs, which allows us to amortize the cost of launching multiple - // device kernels. - absl::flat_hash_set command_types; - int32_t min_graph_size = 0; - bool enable_concurrent_region = false; - stream_executor::GpuComputeCapability compute_capability; -}; - -// Populate passes that lower MLIR modules from a combination of LMHLO and -// LMHLO_GPU dialects to the XLA Gpu runtime. This pipeline is composed from -// the passes defined below, and few builtin MLIR passes. -void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence, - const GpuPipelineOpts& opts = {}); - -//===----------------------------------------------------------------------===// -// Auxiliary passes for lowering to XLA Gpu runtime. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(); - -std::unique_ptr> -createConvertMemrefGetGlobalToArgPass(int64_t min_num_elements); - -//===-----------------------------------------------------------------------===/ -// Passes for lowering from the `gpu` dialect. -//===-----------------------------------------------------------------------===/ - -std::unique_ptr> -createConvertGpuToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo` dialect. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertLmhloToGpuLaunchPass(ThunkSequence* thunk_sequence = nullptr); - -std::unique_ptr> -createConvertLmhloToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo_gpu` dialect. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createConvertLmhloGpuToGpuRuntimePass(); - -//===----------------------------------------------------------------------===// -// XLA runtime performance tracing passes. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createAddHloTraceAnnotationsPass(); - -//===----------------------------------------------------------------------===// -// XLA runtime <-> Cuda Graphs integration. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createOutlineGpuGraphsPass(); - -std::unique_ptr> createOutlineGpuGraphsPass( - absl::flat_hash_set command_types, - int32_t min_graph_size); - -//===----------------------------------------------------------------------===// -// Passes for marking concurrent region in CUDA graph capture function. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createAddConcurrentRegionsPass(); - -//===----------------------------------------------------------------------===// -// Passes for assigning kernels to streams in CUDA graph capture function. -//===----------------------------------------------------------------------===// - -std::unique_ptr> -createStreamAssignmentPass(); - -//===-----------------------------------------------------------------------===/ - -#define GEN_PASS_REGISTRATION -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ diff --git a/xla/mlir/backends/gpu/transforms/passes.td b/xla/mlir/backends/gpu/transforms/passes.td deleted file mode 100644 index c03897e766c10..0000000000000 --- a/xla/mlir/backends/gpu/transforms/passes.td +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_GPU_PASSES -#define XLA_GPU_PASSES - -include "mlir/Pass/PassBase.td" - -//===----------------------------------------------------------------------===// -// Auxiliary passes for lowering to XLA Gpu runtime. -//===----------------------------------------------------------------------===// - -def ConvertMemrefGetGlobalToArgPass : - Pass<"xla-memref-get-global-to-arg", "mlir::ModuleOp"> { - let summary = "Converts memref.get_global corresponding to lmhlo constants"; - - let description = [{ - Replaces `memref.get_global` operations corresponding to the lmhlo constant - arguments (arguments marked with `lmhlo.constant_name` attribute) to use - the constant arguments directly. - - Once we used global constants for constant folding, we no longer need to - keep them in the module, because they'll be in the binary constant section - on the host, and we need them on the device. - }]; - - let constructor = "createConvertMemrefGetGlobalToArgPass()"; - - let options = [ - Option<"min_num_elements_", "min-num-elements", "int64_t", /*default=*/"0", - "Do not convert `memref.get_global` operation if the number of " - "elements is smaller than the given value.">, - ]; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `gpu` dialect. -//===----------------------------------------------------------------------===// - -def ConvertGpuToGpuRuntimePass : - Pass<"xla-gpu-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts gpu operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts gpu operations (function launch, memcpy, etc...) to the XLA Gpu - runtime custom calls. - }]; - - let constructor = "createConvertGpuToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo` dialect. -//===----------------------------------------------------------------------===// - -def ConvertLmhloToGpuLaunchPass : - Pass<"xla-lmhlo-to-gpu-launch", "mlir::ModuleOp"> { - let summary = "Converts lmhlo fusions to Gpu dialect kernel launch"; - - let description = [{ - Converts lmhlo operations that have registered IR emitters (e.g. fusions) to - Gpu dialect kernel launch operations (and trivial memory operations like - memcpy or memset). This pass relies on a pre-compiled ThunkSequence with an - associated device module (PTX and cubin) to find device kernels - corresponding to lmhlo operation in the input module. - - Created Gpu kernel launch operations can be further lowered to the Gpu - runtime by the `xla-gpu-to-gpu-runtime` pass. - }]; - - let constructor = "createConvertLmhloToGpuLaunchPass()"; -} - -def ConvertLmhloToGpuRuntimePass : - Pass<"xla-lmhlo-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts lmhlo operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts lmhlo dialect operations (infeed, outfeed, collectives, etc...) to - the XLA Gpu runtime custom calls. - }]; - - let constructor = "createConvertLmhloToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// Passes for lowering from the `lmhlo_gpu` dialect. -//===----------------------------------------------------------------------===// - -def ConvertLmhloGpuToGpuRuntimePass : - Pass<"xla-lmhlo-gpu-to-gpu-runtime", "mlir::ModuleOp"> { - let summary = "Converts lmhlo_gpu operations to XLA Gpu runtime custom calls"; - - let description = [{ - Converts lmhlo_gpu dialect operations (gemm, convolution, etc...) to - the XLA Gpu runtime custom calls. - }]; - - let constructor = "createConvertLmhloGpuToGpuRuntimePass()"; -} - -//===----------------------------------------------------------------------===// -// XLA runtime performance tracing passes. -//===----------------------------------------------------------------------===// - -// TODO(ezhulenev): This pass should be generic for all backends, consider -// moving it to the `transforms/runtime` folder once it will be used by CPU -// compiler. - -def AddHloTraceAnnotationsPass : - Pass<"xla-add-hlo-trace-annotations", "mlir::ModuleOp"> { - let summary = "Adds HLO trace annotations to the supported operations"; - - let description = [{ - Adds HLO trace annotations to the operations that result from compiling - an input HLO module, e.g. it adds HLO trace annotations to all runtime custom - calls that are constructed from the corresponding HLO operations. - - Example: - - ```mlir - call @xla.gpu.gemm(...) : (...) -> memref - ``` - - becomes: - - ```mlir - call @xla.gpu.gemm(...) { rt.trace = #rt.hlo<"gemm.1", "xla_module", 0> } - : (...) -> memref - ``` - - XLA compilation pipeline wraps traced operations into the `rt.trace` - operation, and eventually lowers them to the tracing API calls. - }]; - - let constructor = "createAddHloTraceAnnotationsPass()"; -} - -//===----------------------------------------------------------------------===// -// Xla Gpu <-> Cuda Graphs integration. -//===----------------------------------------------------------------------===// - -def OutlineGpuGraphsPass : - Pass<"xla-gpu-outline-gpu-graphs", "mlir::ModuleOp"> { - let summary = "Outline sequences of Xla Gpu operations into CUDA Graphs"; - - let description = [{ - Converts sequences of supported Xla Gpu operations to Cuda Graph capture - functions, and replaces the original sequences with calls to the Xla Cuda - Graph runtime API. - - Example: - - ```mlir - gpu.launch_func @compute::foo args(%arg0: memref) - gpu.launch_func @compute::bar args(%arg1: memref) - ``` - - becomes: - - ```mlir - // Export cuda graph capture function to Xla runtime. - rt.export @capture ordinal 1 - func.func @capture(@arg0: memref, %arg1: memref) { - ... capture a graph corresponding to a sequence of `gpu.launch_func` ops - } - - // Replace a sequence of graph launch operations with a call to runtime API. - call @xla.gpu.graph.launch(%arg0: memref, - %arg1: memref) - attributes { capture = @capture } - ``` - }]; - - let constructor = "createOutlineGpuGraphsPass()"; - - let options = [ - Option<"min_graph_size_", "min_graph_size", "int64_t", /*default=*/"2", - "The minimum size of the outlined CUDA graph function.">, - ]; -} - -//===----------------------------------------------------------------------===// -// Add concurrent regions to CUDA graph capture functions. -//===----------------------------------------------------------------------===// - -def AddConcurrentRegionsPass: - Pass<"xla-gpu-add-concurrent-regions", "mlir::ModuleOp"> { - let summary = "Identify and mark concurrent regions in CUDA graph capture " - "functions"; - - let description = [{ - Add concurent region markers to indicate a region of operations that can be - executed concurrently. - - Example: - - ```mlir - func.func @capture.cuda.graph() { - call @xla.gpu.launch.func - call @xla.gpu.launch.func - - // Everything here can run concurrently - call @xla.gpu.launch.func - call @xla.gpu.launch.func - call @xla.gpu.launch.func - call @xla.gpu.launch.func - // Back to sequential execution - - call @xla.gpu.launch.func - func.return - } - ``` - - becomes: - - ```mlir - func.func @capture.cuda.graph() { - call @xla.gpu.launch.func - call @xla.gpu.launch.func - - call @xla.gpu.concurrent_region.begin() - call @xla.gpu.launch.func - call @xla.gpu.launch.func - call @xla.gpu.launch.func - call @xla.gpu.launch.func - call @xla.gpu.concurrent_region.end() - - call @xla.gpu.launch.func - func.return - } - ``` - - }]; - - let constructor = "createAddConcurrentRegionsPass()"; -} - -//===----------------------------------------------------------------------===// -// Stream assignment. -//===----------------------------------------------------------------------===// - -def StreamAssignmentPass: - Pass<"xla-gpu-stream-assignment", "mlir::ModuleOp"> { - let summary = "Identify and mark concurrent regions in CUDA graph capture " - "functions"; - - let description = [{ - Assign a stream to each kernel launch in the capture function. Streams are - assigned to exploit parallelism, so that we can build parallel GPU graphs - duing graph capture. - - Example: - - ```mlir - func.func @capture.cuda.graph() { - // func1, func2, func3 can run in parallel - call @xla.gpu.launch.func1 - call @xla.gpu.launch.func2 - call @xla.gpu.launch.func3 - - // Depends on xla.gpu.launc.func1 and xla.gpu.launch.func2 to finish. - call @xla.gpu.launch.func - func.return - } - ``` - - becomes: - - ```mlir - func.func @capture.cuda.graph() { - // func1, func2, func3 can run in parallel - call @xla.gpu.launch.func1 {stream = 0 : i64} - call @xla.gpu.launch.func2 {stream = 1 : i64} - call @xla.gpu.launch.func3 {stream = 2 : i64} - - // Add explicit synchronization to wait for stream 1 to finish executing - // func2. - call @xla.stream.await {from = 0 : i64, to = [1]} - call @xla.gpu.launch.func {stream = 0: i64} - func.return - } - ``` - - }]; - - let constructor = "createStreamAssignmentPass()"; -} - -#endif // XLA_GPU_PASSES diff --git a/xla/mlir/backends/gpu/transforms/stream_assignment.cc b/xla/mlir/backends/gpu/transforms/stream_assignment.cc deleted file mode 100644 index 87c3ba7ac0b8a..0000000000000 --- a/xla/mlir/backends/gpu/transforms/stream_assignment.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/dataflow_analysis.h" -#include "xla/mlir/runtime/utils/custom_calls.h" - -namespace xla { -namespace gpu { - -namespace { - -#define GEN_PASS_DEF_STREAMASSIGNMENTPASS -#include "xla/mlir/backends/gpu/transforms/passes.h.inc" - -using namespace mlir; // NOLINT -using mlir::func::FuncOp; -using DataflowGraph = DataflowAnalysis::DataflowGraph; -using Node = DataflowAnalysis::Node; - -class StreamAssignmentPass - : public impl::StreamAssignmentPassBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -static constexpr int kNumStreams = 10; - -//===----------------------------------------------------------------------===// - -bool IsParallelizableOp(Operation* op) { - return isa(op); -} - -// -// A simple algorithm to assign streams using the dependency information -// provided by the dataflow graph. -// Pseudocode: -// stream = 0 -// while there exists op such that it is unassigned: -// assign op to stream -// while op has a child: -// op = the last child in the order of execution in the capture function -// assign op to stream -// stream++ -// -// When assigning a stream to a dependency chain, we find the next op in the -// chain by finding the last child of the current op. For example, in the -// following dependency graph, A and C are assigned to stream 0, while B is -// assigned to 1. -// -// A-->B C -// | ^ -// +------| -// -std::vector AssignStreams(const DataflowGraph& graph, int num_streams) { - std::vector stream_assignment(graph.size(), -1); - size_t current_stream = 0; - - auto get_current_stream = [&]() -> size_t { - size_t assigned_stream = current_stream; - current_stream++; - if (current_stream == num_streams) { - current_stream = 0; - } - return assigned_stream; - }; - - auto get_first_unassigned_node = [&stream_assignment = - std::as_const(stream_assignment), - &graph]() -> std::optional { - for (auto [index, stream] : llvm::enumerate(stream_assignment)) { - if (stream == -1 && IsParallelizableOp(graph[index].operation)) { - return index; - } - } - return std::nullopt; - }; - - auto get_last_unassigned_child = [&stream_assignment = - std::as_const(stream_assignment), - &graph](Node node) -> std::optional { - for (int i = node.children.size() - 1; i >= 0; i--) { - Node child = graph[node.children[i]]; - if (!IsParallelizableOp(child.operation)) continue; - if (stream_assignment[child.index] == -1) { - return child; - } - } - return std::nullopt; - }; - - std::function assign_stream_to_dependency_chain = - [&](Node node, size_t stream) { - stream_assignment[node.index] = stream; - - if (auto child = get_last_unassigned_child(node)) { - assign_stream_to_dependency_chain(child.value(), stream); - } - }; - - while (std::optional unassigned_index = get_first_unassigned_node()) { - Node unassigned_node = graph[unassigned_index.value()]; - size_t assigned_stream = get_current_stream(); - assign_stream_to_dependency_chain(unassigned_node, assigned_stream); - } - - // next: Assign all non parallelizable ops to stream 0. - - return stream_assignment; -} - -std::optional GetAssignedStream(Operation* op) { - if (op->hasAttr("stream")) { - return op->getAttrOfType("stream").getInt(); - } - return std::nullopt; -} - -// -// Add synchronizations between assigned streams. The added custom call -// xla.streams.await() {from = A, to = [B, C, ...]} makes future work submitted -// to A wait for work that are already submitted to streams B, C, ... -// -// Pseudo code: -// For each node in the dependency graph -// If the node has a stream A assigned -// parents = A's parents -// to_streams = the assigned streams of its parents -// add xla.streams.await() {from = A, to = to_streams} before node -// -// TODO(anlunx): Handle the case where the cuda graph contains non -// parallelizable ops (cuBLAS, cuDNN). -// -void AddSynchronization(FuncOp await_op, - runtime::CustomCallDeclarations custom_calls, - const DataflowGraph& graph) { - for (const Node& node : graph) { - Operation* op = node.operation; - std::optional op_stream = GetAssignedStream(op); - if (!op_stream.has_value()) { - continue; - } - int from_stream = op_stream.value(); - - std::array dependent_streams; - dependent_streams.fill(false); - for (int i = 0; i < node.index; i++) { - if (std::find(graph[i].children.begin(), graph[i].children.end(), - node.index) != graph[i].children.end()) { - if (std::optional to_stream = - GetAssignedStream(graph[i].operation)) { - if (to_stream.value() != from_stream) { - dependent_streams[to_stream.value()] = true; - } - } - } - } - - ImplicitLocOpBuilder b(op->getLoc(), custom_calls.sym_table().getOp()); - llvm::SmallVector to_streams; - for (int i = 0; i < kNumStreams; i++) { - if (dependent_streams[i]) { - to_streams.push_back(b.getI64IntegerAttr(i)); - } - } - - if (to_streams.empty()) { - continue; - } - - b.setInsertionPoint(op); - auto call = b.create(await_op.getName(), TypeRange()); - call->setAttr(b.getStringAttr("from"), b.getI64IntegerAttr(from_stream)); - call->setAttr(b.getStringAttr("to"), b.getArrayAttr(to_streams)); - } -} - -//===----------------------------------------------------------------------===// - -void StreamAssignmentPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable sym_table(module); - runtime::CustomCallDeclarations custom_calls(std::move(sym_table)); - - auto func_ops = llvm::to_vector(module.getOps()); - ImplicitLocOpBuilder b(module->getLoc(), custom_calls.sym_table().getOp()); - func::FuncOp begin_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.begin", TypeRange(), TypeRange()); - func::FuncOp end_marker = custom_calls.GetOrCreate( - b, "xla.gpu.concurrent_region.end", TypeRange(), TypeRange()); - func::FuncOp await_op = custom_calls.GetOrCreate(b, "xla.streams.await", - TypeRange(), TypeRange()); - - for (auto func_op : func_ops) { - if (!absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.graph.capture")) { - continue; - } - - DataflowAnalysis dataflow_analysis(func_op); - DataflowGraph graph = dataflow_analysis.GetDataflowGraph(func_op); - std::vector stream_assignment = AssignStreams(graph, kNumStreams); - - size_t stream_count = 0; - for (auto [index, stream] : llvm::enumerate(stream_assignment)) { - stream_count = std::max(stream_count, stream + 1); - Node node = graph[index]; - Operation* op = node.operation; - ImplicitLocOpBuilder b(op->getLoc(), custom_calls.sym_table().getOp()); - if (stream != -1) { - op->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(stream)); - } - } - - AddSynchronization(await_op, custom_calls, graph); - - ImplicitLocOpBuilder b(func_op->getLoc(), custom_calls.sym_table().getOp()); - auto first_op = &(*func_op.getOps().begin()); - b.setInsertionPoint(first_op); - auto call = b.create(begin_marker.getName(), TypeRange()); - call->setAttr(b.getStringAttr("size"), b.getI64IntegerAttr(stream_count)); - - auto op_it = func_op.getOps().begin(); - while (!isa(*op_it)) { - op_it++; - } - Operation* return_op = &(*op_it); - b.setInsertionPoint(return_op); - b.create(end_marker.getName(), TypeRange()); - } -} - -} // namespace - -std::unique_ptr> createStreamAssignmentPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/mlir/backends/gpu/transforms/tests/BUILD b/xla/mlir/backends/gpu/transforms/tests/BUILD deleted file mode 100644 index 68f8a0ba7fc10..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - # TODO(b/286919981). Remove nomsan once we pass through MSAN_OPTIONS env var - # to the test. - default_tags = ["nomsan"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/backends/gpu:xla-gpu-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) diff --git a/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir b/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir deleted file mode 100644 index e9f21a15b65e7..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir +++ /dev/null @@ -1,348 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-add-concurrent-regions \ -// RUN: | FileCheck %s - - -// ----- -// Check that two consecutive launch_funcs using different buffers is captured -// by a concurrent_region. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that two consecutive launch_funcs using the same buffer is not -// captured. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that there is no dependency from launch_funcs that do not write to -// buffers. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that the i1 data type is handled correctly. -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi1> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi1> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi1> - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi1> - - // CHECK-NOT: xla.gpu.concurrent_region.begin() - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi1>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi1>) - return - } -} - -// ----- -// Check that disjoint buffer slices does not introduce dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<144xi8>) { - %c0 = arith.constant 0 : index - %c72 = arith.constant 72 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c72][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: call @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that overlapping buffer slices creates dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<144xi8>) { - %c0 = arith.constant 0 : index - %c36 = arith.constant 36 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c36][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that constant input buffer does not create dependency. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<144xi8> {lmhlo.constant_name = "cst0"}) { - %c0 = arith.constant 0 : index - %c36 = arith.constant 36 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<144xi8> to memref<3x3xi64> - %view_0 = memref.view %arg0[%c36][] : memref<144xi8> to memref<3x3xi64> - - // CHECK: call @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} - -// ----- -// Check that two gemms that read the same buffer are moved into a concurrent -// region. - -module attributes {gpu.container_module} { - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>, - %arg3: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - %view_3 = memref.view %arg3[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: call @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: lmhlo_gpu.gemm - // CHECK-NEXT: lmhlo_gpu.gemm - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_3) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - return - } - - func.func private @external() -} - -// ----- -// Check that lmhlo_gpu.gemm is not moved into the concurrent region if it -// uses a buffer used by a kernel launch. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8> {lmhlo.written} ) kernel { gpu.return } - } - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK-NOT: @xla.gpu.concurrent_region.begin() - // CHECK: lmhlo_gpu.gemm - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: return - "lmhlo_gpu.gemm"(%view_0, %view_1, %view_2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%arg0: memref<16xi8>) - return - } - - func.func private @external() -} - -// ----- -// Check that memcpies are added to concurrent regions. - -module attributes {gpu.container_module} { - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, - %arg1: memref<16xi8>, - %arg2: memref<16xi8>) { - %c0 = arith.constant 0 : index - %view_0 = memref.view %arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %view_1 = memref.view %arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %view_2 = memref.view %arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: @xla.gpu.concurrent_region.begin() - // CHECK-NEXT: gpu.memcpy - // CHECK-NEXT: gpu.memcpy - // CHECK-NEXT: @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.memcpy %view_1, %view_0 : memref<2x2xf32>, memref<2x2xf32> - gpu.memcpy %view_2, %view_0 : memref<2x2xf32>, memref<2x2xf32> - return - } - - func.func private @external() -} - -// ----- -// Check that region size is set correctly. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: call @xla.gpu.concurrent_region.begin() {size = 2 : i64} - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: memref.view - // CHECK-NEXT: gpu.launch_func - // CHECK-NEXT: call @xla.gpu.concurrent_region.end() - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - return - } -} diff --git a/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir b/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir deleted file mode 100644 index d2f46c8d6bd05..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-add-hlo-trace-annotations | FileCheck %s - -module attributes { mhlo.unique_id = 42 : i64 } { - -func.func private @xla.foo() attributes { rt.custom_call = "xla.foo" } - -// CHECK: func @func() { -func.func @func() { - // CHECK: call @xla.foo() - // CHECK-SAME: rt.trace = #rt.hlo_trace<"gemm.name.42"> - call @xla.foo() : () -> () loc("gemm.name.42") - return -} - -} loc("module-name") diff --git a/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir b/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir deleted file mode 100644 index e05ff982bb3c4..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-gpu-to-gpu-runtime | FileCheck %s - -module attributes {gpu.container_module} { - -// CHECK-NOT: gpu.module -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref<4x4xf32>, -// CHECK: %[[ARG1:.*]]: memref<4x4xf32> -// CHECK: ) -func.func @func(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) { - // Launch dimensions converted to i32 as a part of the lowering. - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : i32 - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 - // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : i32 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[C256:.*]] = arith.constant 256 : i32 - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - %c256 = arith.constant 256 : i32 - - // CHECK: call @[[LAUNCH:[_a-z.]+]](%[[C0]], %[[C1]], %[[C2]], %[[C3]], - // CHECK-SAME: %[[C4]], %[[C5]], %[[C6]], %[[ARG0]], %[[ARG1]]) - // CHECK-SAME: kernel = "fn0" - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c2, %c3) - threads in (%c4, %c5, %c6) - args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) - - // CHECK: call @[[LAUNCH]](%[[C256]], %[[C3]], %[[C2]], %[[C1]], %[[C6]], - // CHECK-SAME: %[[C5]], %[[C4]], %[[ARG0]], %[[ARG1]]) - // CHECK-DAG: kernel = "fn1" - gpu.launch_func @gpu_module::@fn1 - blocks in (%c3, %c2, %c1) - threads in (%c6, %c5, %c4) - dynamic_shared_memory_size %c256 - args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) - - func.return -} - -// CHECK: func private @[[LAUNCH]](i32, i32, i32, i32, i32, i32, -// CHECK-SAME: memref<4x4xf32>, memref<4x4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.func.launch"} - -// Check that we have a single custom call declaration in the module. -// CHECK-NOT: rt.custom_call - -} diff --git a/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir b/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir deleted file mode 100644 index 410a94c489ed1..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_memcpy_d2d( -// CHECK: %[[DST:[a-z0-9]+]]: memref, -// CHECK: %[[SRC:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_d2d(%dst: memref, %src: memref) { - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.d2d"} - -// ----- - -// CHECK: func @gpu_memcpy_h2d( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_h2d(%dst: memref, %dim: index) { - // CHECK: %[[SRC:.*]] = memref.alloca - %src = memref.alloca(%dim) : memref - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.h2d"} - -// ----- - -// CHECK: func @gpu_memcpy_d2h( -// CHECK: %[[SRC:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memcpy_d2h(%src: memref, %dim: index) { - // CHECK: %[[DST:.*]] = memref.alloca - %dst = memref.alloca(%dim) : memref - // CHECK: call @[[MEMCPY:.*]](%[[DST]], %[[SRC]]) - gpu.memcpy %dst, %src : memref, memref - return -} - -// CHECK: func private @[[MEMCPY]](memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memcpy.d2h"} diff --git a/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir b/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir deleted file mode 100644 index 33b2232fc8519..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_memset_i32( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memset_i32(%dst: memref) { - // CHECK: %[[CST:.*]] = arith.constant 0 : i32 - %cst = arith.constant 0 : i32 - // CHECK: call @[[MEMSET:.*]](%[[DST]], %[[CST]]) - gpu.memset %dst, %cst : memref, i32 - return -} - -// CHECK: func private @[[MEMSET]](memref, i32) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memset"} - -// ----- - -// CHECK: func @gpu_memset_f32( -// CHECK: %[[DST:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_memset_f32(%dst: memref) { - // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 - %cst = arith.constant 0.000000e+00 : f32 - // CHECK: call @[[MEMSET:.*]](%[[DST]], %[[CST]]) - gpu.memset %dst, %cst : memref, f32 - return -} - -// CHECK: func private @[[MEMSET]](memref, f32) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.memset"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir deleted file mode 100644 index dc30c5a42374a..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir +++ /dev/null @@ -1,116 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @case0 attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - gpu.module @case1 attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @case_true_false( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @case_true_false(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - - // CHECK: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - - // CHECK: %[[HOST:.*]] = memref.alloca() : memref - // CHECK: gpu.memcpy %[[HOST]], %[[ARG1]] - - // CHECK: %[[PRED:.*]] = memref.load %[[HOST]][] : memref - // CHECK: %[[IDX:.*]] = arith.select %[[PRED]], %[[C0]], %[[C1]] - - // CHECK: scf.execute_region - // CHECK: cf.switch %[[IDX]] : i32 - // CHECK: default: ^[[YIELD:.*]], - // CHECK: 0: ^[[CASE0:.*]], - // CHECK: 1: ^[[CASE1:.*]] - "lmhlo.case"(%arg1) ({ - gpu.launch_func @case0::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @case1::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - - // CHECK: ^[[CASE0]]: - // CHECK: gpu.launch_func @case0::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[CASE1]]: - // CHECK: gpu.launch_func @case1::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[YIELD]]: - // CHECK-NEXT: scf.yield - - // CHECK: return - "lmhlo.terminator"() : () -> () - } - - // CHECK: @case_index( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @case_index(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - - // CHECK: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - - // CHECK: %[[HOST:.*]] = memref.alloca() : memref - // CHECK: gpu.memcpy %[[HOST]], %[[ARG1]] - - // CHECK: %[[PRED:.*]] = memref.load %[[HOST]][] : memref - // CHECK: %[[SMALL:.*]] = arith.cmpi slt, %[[PRED]], %[[C0]] : i32 - // CHECK: %[[LARGE:.*]] = arith.cmpi sgt, %[[PRED]], %[[C1]] : i32 - // CHECK: %[[OOR:.*]] = arith.ori %[[SMALL]], %[[LARGE]] : i1 - // CHECK: %[[IDX:.*]] = arith.select %[[OOR]], %[[C1]], %[[PRED]] : i32 - - // CHECK: scf.execute_region - // CHECK: cf.switch %[[IDX]] : i32 - // CHECK: default: ^[[YIELD:.*]], - // CHECK: 0: ^[[CASE0:.*]], - // CHECK: 1: ^[[CASE1:.*]] - "lmhlo.case"(%arg1) ({ - gpu.launch_func @case0::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @case1::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - - // CHECK: ^[[CASE0]]: - // CHECK: gpu.launch_func @case0::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[CASE1]]: - // CHECK: gpu.launch_func @case1::@fn - // CHECK: cf.br ^[[YIELD]] - - // CHECK: ^[[YIELD]]: - // CHECK-NEXT: scf.yield - - // CHECK: return - "lmhlo.terminator"() : () -> () - } -} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir deleted file mode 100644 index 6b333f90f77dc..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: func @test -// CHECK: %[[ARG0:.*]]: memref -// CHECK: ) -func.func @test(%arg0: memref) { - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]]) - // CHECK-SAME: api_version = 2 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - // CHECK-SAME: : (memref) -> () - "lmhlo.custom_call"(%arg0) ({}) { - api_version = 2 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array - } : (memref) -> () - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.custom_call"} - -// ----- - -// CHECK: func @test_with_mapping -// CHECK: %[[ARG0:[0-9a-z]*]]: memref, -// CHECK: %[[ARG1:[0-9a-z]*]]: memref, -// CHECK: %[[ARG2:[0-9a-z]*]]: memref, -// CHECK: %[[ARG3:[0-9a-z]*]]: memref, -// CHECK: %[[ARG4:[0-9a-z]*]]: memref -// CHECK: ) -func.func @test_with_mapping( - %arg0: memref, - %arg1: memref, - %arg2: memref, - %arg3: memref, - %arg4: memref) { - // CHECK: %[[HOLE:.*]] = arith.constant -1 : i64 - - // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]], %[[HOLE]], %[[ARG1]], %[[HOLE]], - // CHECK-SAME: %[[ARG2]], %[[ARG3]], %[[HOLE]], %[[ARG4]]) - // CHECK-SAME: api_version = 1 : i32 - // CHECK-SAME: backend_config = "" - // CHECK-SAME: call_target_name = "target" - "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) ({}) { - api_version = 1 : i32, - backend_config = "", - call_target_name = "target", - operandSegmentSizes = array, - target_arg_mapping = #lmhlo.custom_call_target_arg_mapping< - num_args = 4, - num_results = 4, - args_to_target_args = [0, 2], - results_to_target_results = [0, 1, 3]> - } : (memref, memref, memref, memref, memref) -> () - - return -} - -// CHECK: func.func private @[[CUSTOM_CALL]](memref, i64, memref, i64, -// CHECK-SAME: memref, memref, i64, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.custom_call"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir deleted file mode 100644 index aeb19228d013e..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<3x5x16x5xcomplex -// CHECK: %[[ARG1:[a-z0-9]+]]: memref<3x5x16x8xf32> -// CHECK: ) -func.func @compute(%arg0: memref<3x5x16x5xcomplex>, - %arg1: memref<3x5x16x8xf32>) { - - // CHECK: call @[[FFT:.*]](%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: fft_length = dense<[16, 8]> : tensor<2xi64> - // CHECK-SAME: fft_type = #mhlo - // CHECK-SAME: uid = 0 : i64 - "lmhlo.fft"(%arg0, %arg1) { - fft_length = dense<[16, 8]> : tensor<2xi64>, - fft_type = #mhlo - } : (memref<3x5x16x5xcomplex>, memref<3x5x16x8xf32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[FFT]](memref<3x5x16x5xcomplex>, -// CHECK-SAME: memref<3x5x16x8xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.fft"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir deleted file mode 100644 index 9dbc0ee2eb260..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-gpu-to-gpu-runtime | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG1:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG2:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: %[[ARG3:[a-z0-9]+]]: memref<4x4xi32> -// CHECK: ) -func.func @compute(%operand: memref<4x4xi32>, %a: memref<4x4xi32>, - %workspace: memref<4x4xi32>, %info: memref<4x4xi32>) { - - // CHECK: call @[[CHOLESKY:.*]](%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) - // CHECK-SAME: batch_size = 1 : i64 - // CHECK-SAME: is_lower = true - // CHECK-SAME: n = 4 : i64 - "lmhlo_gpu.cholesky"(%operand, %a, %workspace, %info) { - batch_size = 1 : i64, - is_lower = true, - n = 4 : i64 - } : (memref<4x4xi32>, memref<4x4xi32>, memref<4x4xi32>, memref<4x4xi32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[CHOLESKY]](memref<4x4xi32>, memref<4x4xi32>, -// CHECK-SAME: memref<4x4xi32>, memref<4x4xi32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cholesky"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir deleted file mode 100644 index 59d6ca10a9bdc..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir +++ /dev/null @@ -1,380 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -// CHECK: @conv_forward( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward(%input: memref<1x4x4x1024xf16, #map1>, - %filter: memref<3x3x1x1024xf16, #map0>, - %output: memref<1x2x2x1024xf16, #map2>, - %scratch: memref<0xui8>) { - - // CHECK: call @xla.gpu.conv.forward( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: uid = 0 : i64 - // CHECK-DAG: conv_dims = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]> - - // CHECK-DAG: window_strides = dense<1> : tensor<2xi64> - // CHECK-DAG: lhs_dilation = dense<1> : tensor<2xi64> - // CHECK-DAG: rhs_dilation = dense<1> : tensor<2xi64> - // CHECK-DAG: window_reversal = dense<0> : tensor<2xi64> - // CHECK-DAG: padding = dense<> : tensor<0xi64> - - // CHECK-DAG: backend_config = #lmhlo_gpu.convolution_backend_config< - // CHECK-DAG: algorithm = 0 - // CHECK-DAG: is_cudnn_frontend = true - // CHECK-DAG: knob_ids = [] - // CHECK-DAG: knob_values = [] - // CHECK-DAG: operand_0_layout = [2, 1, 3, 0] - // CHECK-DAG: operand_1_layout = [1, 0, 2, 3] - // CHECK-DAG: tensor_ops_enabled = false - // CHECK-DAG: workspace_size = 0 - - // CHECK-DAG: feature_group_count = 1024 : i64 - // CHECK-DAG: result_scale = 1.000000e+00 : f64 - lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1024 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x4x4x1024xf16, #map1>, - memref<3x3x1x1024xf16, #map0>, - memref<1x2x2x1024xf16, #map2>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @xla.gpu.conv.forward( -// CHECK-SAME: memref<1x4x4x1024xf16, #map{{[0-9]*}}>, memref<3x3x1x1024xf16, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x2x2x1024xf16, #map{{[0-9]*}}>, memref<0xui8>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.conv.forward"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 9 + d1 * 3 + d2 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 27 + d3 * 9)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)> - -// CHECK: @conv_backwardfilter( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[D_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[D_FILTER:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_backwardfilter(%input: memref<1x3x3x5xf16, #map0>, - %d_output: memref<3x3x5x3xf16, #map1>, - %d_filter: memref<1x1x1x3xf16, #map2>, - %scratch: memref<0xui8>) { - // CHECK: call @xla.gpu.conv.backward.filter( - // CHECK-SAME: %[[INPUT]], %[[D_OUTPUT]], %[[D_FILTER]], %[[SCRATCH]]) - lmhlo_gpu.conv_backwardfilter(%input, %d_output, %d_filter, %scratch) - dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 0, 3], - operand_1_layout = [1, 0, 3, 2], - result_layout = [2, 1, 0, 3], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x3x3x5xf16, #map0>, - memref<3x3x5x3xf16, #map1>, - memref<1x1x1x3xf16, #map2>, - memref<0xui8>) -> () - return -} - -// CHECK: func private @xla.gpu.conv.backward.filter( -// CHECK-SAME: memref<1x3x3x5xf16, #map{{[0-9]*}}>, memref<3x3x5x3xf16, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x1x1x3xf16, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.backward.filter"} - -// ----- - -// CHECK: @conv_backwardinput( -// CHECK: %[[D_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[D_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_backwardinput(%d_output: memref<4x5x16x16xf64>, - %filter: memref<5x3x7x7xf64>, - %d_input: memref<4x3x16x16xf64>, - %scratch: memref<0xui8>) { - // CHECK: call @xla.gpu.conv.backward.input( - // CHECK-SAME: %[[D_OUTPUT]], %[[FILTER]], %[[D_INPUT]], %[[SCRATCH]]) - lmhlo_gpu.conv_backwardinput(%d_output, %filter, %d_input, %scratch) - dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 2, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [3, 2], - knob_values = [0, 3], - operand_0_layout = [3, 2, 1, 0], - operand_1_layout = [3, 2, 1, 0], - result_layout = [3, 2, 1, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<4x5x16x16xf64>, - memref<5x3x7x7xf64>, - memref<4x3x16x16xf64>, - memref<0xui8>) -> () - return -} - -// CHECK: func private @xla.gpu.conv.backward.input( -// CHECK-SAME: memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, -// CHECK-SAME: memref<4x3x16x16xf64>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.backward.input"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 5 + d2 + d3 * 25)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 800 + d1 * 5 + d2 + d3 * 25)> - -// CHECK: @conv_forward_fused( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[BIAS:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward_fused(%input: memref<8x5x5x1xf32, #map1>, - %filter: memref<3x3x1x32xf32, #map0>, - %bias: memref<32xf32>, - %output: memref<8x5x5x32xf32, #map2>, - %scratch: memref<0xui8>) { - // CHECK: call @xla.gpu.conv.forward.fused( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[BIAS]], %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: activation_mode = #lmhlo_gpu - // CHECK-DAG: knob_ids = [2, 3] - // CHECK-DAG: knob_values = [4, 0] - lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - leakyrelu_alpha = 0.0 : f64, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 11, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [2, 3], - knob_values = [4, 0], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<8x5x5x1xf32, #map1>, - memref<3x3x1x32xf32, #map0>, - memref<32xf32>, - memref<8x5x5x32xf32, #map2>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @xla.gpu.conv.forward.fused( -// CHECK-SAME: memref<8x5x5x1xf32, #map{{[0-9]*}}>, memref<3x3x1x32xf32, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32>, memref<8x5x5x32xf32, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.forward.fused"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 576 + d1 * 3 + d2 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 576)> - -// CHECK: @conv_forward_fused_with_side_input( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER:[a-z0-9]+]]: memref -// CHECK: %[[BIAS:[a-z0-9]+]]: memref -// CHECK: %[[SIDE_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[SCRATCH:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_forward_fused_with_side_input( - %input: memref<1x3x3x64xf64, #map0>, - %filter: memref<3x3x64x64xf64, #map1>, - %bias: memref<64xf64>, - %side_input: memref<1x3x3x64xf64, #map0>, - %output: memref<1x3x3x64xf64, #map0>, - %scratch: memref<0xui8>) { - - // CHECK: call @xla.gpu.conv.forward.fused.side_input( - // CHECK-SAME: %[[INPUT]], %[[FILTER]], %[[BIAS]], %[[SIDE_INPUT]], - // CHECK-SAME: %[[OUTPUT]], %[[SCRATCH]]) - - // CHECK-DAG: activation_mode = #lmhlo_gpu - // CHECK-DAG: side_input_scale = 1.000000e+00 : f64 - lmhlo_gpu.conv_forward_fused_with_side_input( - %input, %filter, %bias, %side_input, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64, - side_input_scale = 1.000000e+00 : f64 - } : (memref<1x3x3x64xf64, #map0>, - memref<3x3x64x64xf64, #map1>, - memref<64xf64>, - memref<1x3x3x64xf64, #map0>, - memref<1x3x3x64xf64, #map0>, - memref<0xui8>) -> () - - return -} - -// CHECK: func private @xla.gpu.conv.forward.fused.side_input( -// CHECK-SAME: memref<1x3x3x64xf64, #map{{[0-9]*}}>, memref<3x3x64x64xf64, #map{{[0-9]*}}>, -// CHECK-SAME: memref<64xf64>, memref<1x3x3x64xf64, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x3x3x64xf64, #map{{[0-9]*}}>, memref<0xui8> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.forward.fused.side_input"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 + d3 * 3 + d4 * 9)> - -// CHECK: @conv_reorder_filter( -// CHECK: %[[INPUT:[a-z0-9]+]]: memref -// CHECK: %[[OUTPUT:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_reorder_filter( - %input: memref<1x1x3x3x32xi8, #map0>, - %output: memref<1x1x3x3x32xi8, #map0>) { - - // CHECK: call @xla.gpu.conv.reorder.filter( - // CHECK-SAME: %[[INPUT]], %[[OUTPUT]] - // CHECK-DAG: filter_dims = array - "lmhlo_gpu.cudnn_conv_reorder_filter"(%input, %output) { - filter_dims = dense<[1, 32, 3, 3]> : tensor<4xi64> - }: (memref<1x1x3x3x32xi8, #map0>, - memref<1x1x3x3x32xi8, #map0>) -> () - - return -} - -// CHECK: func private @xla.gpu.conv.reorder.filter( -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.reorder.filter"} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 + d3 * 3 + d4 * 9)> - -// CHECK: @conv_reorder_filter_and_bias( -// CHECK: %[[FILTER_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[BIAS_INPUT:[a-z0-9]+]]: memref -// CHECK: %[[FILTER_OUTPUT:[a-z0-9]+]]: memref -// CHECK: %[[BIAS_OUTPUT:[a-z0-9]+]]: memref -// CHECK: ) -func.func @conv_reorder_filter_and_bias( - %filter_input: memref<1x1x3x3x32xi8, #map0>, - %bias_input: memref<32xf32>, - %filter_output: memref<1x1x3x3x32xi8, #map0>, - %bias_output: memref<32xf32>) { - - // CHECK: call @xla.gpu.conv.reorder.filter_and_bias( - // CHECK-SAME: %[[FILTER_INPUT]], %[[BIAS_INPUT]], %[[FILTER_OUTPUT]], %[[BIAS_OUTPUT]] - // CHECK-DAG: filter_dims = array - "lmhlo_gpu.cudnn_conv_reorder_filter_and_bias"( - %filter_input, %bias_input, %filter_output, %bias_output) { - filter_dims = dense<[1, 32, 3, 3]> : tensor<4xi64> - }: (memref<1x1x3x3x32xi8, #map0>, memref<32xf32>, - memref<1x1x3x3x32xi8, #map0>, memref<32xf32>) -> () - - return -} - -// CHECK: func private @xla.gpu.conv.reorder.filter_and_bias( -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32>, -// CHECK-SAME: memref<1x1x3x3x32xi8, #map{{[0-9]*}}>, -// CHECK-SAME: memref<32xf32> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.conv.reorder.filter_and_bias"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir deleted file mode 100644 index 249461a5266a7..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir +++ /dev/null @@ -1,100 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[A:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[B:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[C:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[D:[a-z0-9]+]]: memref<2x6x2x2xf32> -// CHECK: ) -func.func @compute(%a: memref<2x6x2x2xf32>, - %b: memref<2x6x2x2xf32>, - %c: memref<2x6x2x2xf32>, - %d: memref<2x6x2x2xf32>) { - - // CHECK: @xla.gpu.cublas.lt.matmul(%[[A]], %[[B]], %[[C]], %[[D]]) - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: epilogue = #lmhlo_gpu - // CHECK-SAME: precision = dense<0> : tensor<2xi32> - // CHECK-SAME: uid = 0 : i64 - "lmhlo_gpu.cublas.lt.matmul"(%a, %b, %c, %d) { - algorithm = 0 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2]>, - epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo], - operandSegmentSizes = array - } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, - memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () - - return -} - -// CHECK: func private @xla.gpu.cublas.lt.matmul( -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32> -// CHECK-SAME: ) attributes {rt.custom_call = "xla.gpu.cublas.lt.matmul"} - -// ----- - -// CHECK: @compute( -// CHECK: %[[A:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[B:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[C:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[D:[a-z0-9]+]]: memref<2x6x2x2xf32>, -// CHECK: %[[BIAS:[a-z0-9]+]]: memref<2x6x2x2xf32> -// CHECK: ) -func.func @compute(%a: memref<2x6x2x2xf32>, - %b: memref<2x6x2x2xf32>, - %c: memref<2x6x2x2xf32>, - %d: memref<2x6x2x2xf32>, - %bias: memref<2x6x2x2xf32>) { - - // CHECK: @xla.gpu.cublas.lt.matmul.bias(%[[A]], %[[B]], %[[C]], %[[D]], - // CHECK-SAME: %[[BIAS]]) - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: epilogue = #lmhlo_gpu - // CHECK-SAME: precision = dense<0> : tensor<2xi32> - // CHECK-SAME: uid = 0 : i64 - "lmhlo_gpu.cublas.lt.matmul"(%a, %b, %c, %d, %bias) { - algorithm = 0 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2]>, - epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo], - operandSegmentSizes = array - } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, - memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () - - return -} - -// CHECK: func private @xla.gpu.cublas.lt.matmul.bias( -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, -// CHECK-SAME: memref<2x6x2x2xf32>, memref<2x6x2x2xf32> -// CHECK-SAME: ) attributes {rt.custom_call = -// CHECK-SAME: "xla.gpu.cublas.lt.matmul.bias"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir deleted file mode 100644 index 51d2247049d3c..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir +++ /dev/null @@ -1,41 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-gpu-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: @compute( -// CHECK: %[[LHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[RHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[OUT:[a-z0-9]+]]: memref<4x4xf32> -// CHECK: ) -func.func @compute(%lhs: memref<4x4xf32>, %rhs: memref<4x4xf32>, - %out: memref<4x4xf32>) { - - // CHECK: call @[[GEMM:[_a-z.]+]](%[[LHS]], %[[RHS]], %[[OUT]]) - // CHECK-SAME: algorithm = 13 : i64 - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 0.000000e+00 : f64 - // CHECK-SAME: dot_dims = #mhlo.dot - // CHECK-SAME: uid = 0 : i64 - // CHECK-SAME: (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () - "lmhlo_gpu.gemm"(%lhs, %rhs, %out) - { - algorithm = 13 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - batch_size = 1 : i64, - beta = 0.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot, - lhs_stride = 16 : i64, - rhs_stride = 16 : i64 - } - : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @[[GEMM:[_a-z.]+]](memref<4x4xf32>, memref<4x4xf32>, -// CHECK-SAME: memref<4x4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.gemm"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir deleted file mode 100644 index 089a7e6ac3351..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_infeed( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_infeed(%arg0: memref) { - // CHECK: call @[[INFEED:.*]](%[[ARG0]]) - // CHECK-SAME: {config = "abc"} : (memref) -> () - "lmhlo.infeed"(%arg0) {config = "abc"} : (memref) -> () - return -} - -// CHECK: func private @[[INFEED]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.infeed"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir deleted file mode 100644 index 32cf254a7ff99..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-lmhlo-to-gpu-runtime | FileCheck %s - -// CHECK: func @gpu_infeed( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref -// CHECK: ) -func.func @gpu_infeed(%arg0: memref) { - // CHECK: call @[[OUTFEED:.*]](%[[ARG0]]) - // CHECK-SAME: {config = "abc"} : (memref) -> () - "lmhlo.outfeed"(%arg0) {config = "abc"} : (memref) -> () - return -} - -// CHECK: func private @[[OUTFEED]](memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.outfeed"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir deleted file mode 100644 index 95c8515dc6aee..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -// CHECK: func @send( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> -// CHECK: ) -func.func @send(%arg0: memref<4xf32>) { - // CHECK: call @xla.gpu.send_host(%[[ARG0]]) { - // CHECK-SAME: channel_handle = #mhlo.channel_handle, - // CHECK-SAME: frontend_attributes = { - // CHECK-SAME: _xla_dcn_recv_channel = "2", - // CHECK-SAME: _xla_host_transfer_handler_name = "undef", - // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" - // CHECK-SAME: }} : (memref<4xf32>) -> () - "lmhlo.send"(%arg0) { - channel_handle = #mhlo.channel_handle, - frontend_attributes = {_xla_dcn_recv_channel = "2", - _xla_host_transfer_handler_name = "undef", - _xla_host_transfer_rendezvous = "undef"}, - is_host_transfer = true - } : (memref<4xf32>) -> !mhlo.token - return -} - -// CHECK: func private @xla.gpu.send_host(memref<4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send_host"} - -// ----- - -// CHECK: func @recv( -// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> -// CHECK: ) -func.func @recv(%arg0: memref<4xf32>) { - // CHECK: call @xla.gpu.recv_host(%[[ARG0]]) { - // CHECK-SAME: channel_handle = #mhlo.channel_handle, - // CHECK-SAME: frontend_attributes = { - // CHECK-SAME: _xla_host_transfer_handler_name = "undef", - // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" - // CHECK-SAME: }} : (memref<4xf32>) -> () - "lmhlo.recv"(%arg0) { - channel_handle = #mhlo.channel_handle, - frontend_attributes = {_xla_host_transfer_handler_name = "undef", - _xla_host_transfer_rendezvous = "undef"}, - is_host_transfer = true - } : (memref<4xf32>) -> !mhlo.token - return -} - -// CHECK: func private @xla.gpu.recv_host(memref<4xf32>) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv_host"} - -// ----- - -// CHECK: func @send_done( -// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token -// CHECK: ) -func.func @send_done(%arg0: !mhlo.token) { - // CHECK: call @xla.gpu.send_done_host() { - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME: } : () -> () - "lmhlo.send_done"(%arg0) { - channel_handle = #mhlo.channel_handle, - is_host_transfer = true - } : (!mhlo.token) -> () - return -} - -// CHECK: func private @xla.gpu.send_done_host() -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send_done_host"} - -// ----- - -// CHECK: func @recv_done( -// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token -// CHECK: ) -func.func @recv_done(%arg0: !mhlo.token) { - // CHECK: call @xla.gpu.recv_done_host() { - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME: } : () -> () - "lmhlo.recv_done"(%arg0) { - channel_handle = #mhlo.channel_handle, - is_host_transfer = true - } : (!mhlo.token) -> () - return -} - -// CHECK: func private @xla.gpu.recv_done_host() -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv_done_host"} diff --git a/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir b/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir deleted file mode 100644 index 95459d3bcbff4..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir +++ /dev/null @@ -1,97 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-lmhlo-to-gpu-runtime \ -// RUN: | FileCheck %s - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @cond attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref, %arg1: memref) kernel { - gpu.return - } - } - - gpu.module @body attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @while_loop( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @while_loop(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - %0 = memref.get_global @constant : memref - gpu.memcpy %arg0, %0 : memref, memref - - // CHECK: %[[HOST_PRED:.*]] = memref.alloca() : memref - // CHECK: scf.while : () -> () - "lmhlo.while"(%arg1) ({ - // CHECK: gpu.launch_func @cond::@fn - // CHECK: gpu.memcpy %[[HOST_PRED]], %[[ARG1]] - // CHECK: %[[COND:.*]] = memref.load %[[HOST_PRED]][] : memref - // CHECK: scf.condition(%[[COND]]) - gpu.launch_func @cond::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref, %arg1 : memref) - "lmhlo.terminator"() : () -> () - }, { - // CHECK: gpu.launch_func @body::@fn - // CHECK: scf.yield - gpu.launch_func @body::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - "lmhlo.terminator"() : () -> () - } -} - -// ----- -// Check that while loops with known trip counts lower to `scf.for` loops. - -module attributes {gpu.container_module} { - memref.global "private" constant @constant : memref = dense<0> - - gpu.module @cond attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref, %arg1: memref) kernel { - gpu.return - } - } - - gpu.module @body attributes {binary = "ptx"} { - gpu.func @fn(%arg0: memref) kernel { - gpu.return - } - } - - // CHECK: @for_loop( - // CHECK-SAME: %[[ARG0:.*]]: memref, - // CHECK-SAME: %[[ARG1:.*]]: memref - // CHECK-SAME: ) - func.func @for_loop(%arg0: memref, %arg1: memref) { - // CHECK: %[[LB:.*]] = arith.constant 0 - // CHECK: %[[UB:.*]] = arith.constant 3000 - // CHECK: %[[C1:.*]] = arith.constant 1 - %c1 = arith.constant 1 : index - - // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[C1]] - // CHECK-NEXT: gpu.launch_func @body::@fn - // CHECK-NOT: gpu.launch.func - - "lmhlo.while"(%arg1) ({ - gpu.launch_func @cond::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref, %arg1 : memref) - "lmhlo.terminator"() : () -> () - }, { - gpu.launch_func @body::@fn blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - "lmhlo.terminator"() : () -> () - }) {trip_count = 3000 : i64} : (memref) -> () - - "lmhlo.terminator"() : () -> () - } -} diff --git a/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir b/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir deleted file mode 100644 index 6361c77f145f1..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir +++ /dev/null @@ -1,43 +0,0 @@ -// RUN: xla-gpu-opt %s -xla-memref-get-global-to-arg=min-num-elements=2 \ -// RUN: | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0 + 2 * d1)> - -memref.global "private" constant @cst0 : memref<2x3xf32> = - dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], - [4.000000e+00, 5.000000e+00, 6.000000e+00]]> - -memref.global "private" constant @cst1 : memref = - dense<1.000000e+00> - -memref.global "private" constant @cst2 : memref<2x3xf32, #map> = - dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], - [4.000000e+00, 5.000000e+00, 6.000000e+00]]> - -// CHECK: func.func @get_global( -// CHECK-SAME: %[[ARG0:.*]]: memref<24xi8> {lmhlo.constant_name = "cst0"}, -// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8> {lmhlo.constant_name = "cst1"}, -// CHECK-SAME: %[[ARG2:.*]]: memref<24xi8> {lmhlo.constant_name = "cst2"} -// CHECK-SAME: ) -func.func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst0"}, - %arg1: memref<4xi8> {lmhlo.constant_name = "cst1"}, - %arg2: memref<24xi8> {lmhlo.constant_name = "cst2"}) - -> (memref<2x3xf32>, memref, memref<2x3xf32, #map>) { - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[V0:.*]] = memref.view %[[ARG0]][%[[C0]]][] {{.*}} memref<2x3xf32> - %0 = memref.get_global @cst0 : memref<2x3xf32> - - // CHECK: %[[V1:.*]] = memref.get_global {{.*}} : memref - %1 = memref.get_global @cst1 : memref - - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[F:.*]] = memref.view %[[ARG2]][%[[C0_1]]][] {{.*}} memref<6xf32> - // CHECK: %[[V2:.*]] = memref.reinterpret_cast %[[F]] - // CHECK-SAME: to offset: [0], sizes: [2, 3], strides: [1, 2] - %2 = memref.get_global @cst2 : memref<2x3xf32, #map> - - // CHECK: return %[[V0]], %[[V1]], %[[V2]] - // CHECK-SAME: : memref<2x3xf32>, memref, memref<2x3xf32, #map{{[0-9]*}}> - return %0, %1, %2 : memref<2x3xf32>, memref, memref<2x3xf32, #map> -} diff --git a/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir deleted file mode 100644 index ce446c8df5a80..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir +++ /dev/null @@ -1,686 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-outline-gpu-graphs \ -// RUN: | FileCheck %s - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref, -// CHECK: %[[ARG1:.*]]: memref -// CHECK: ) -func.func @func(%arg0: memref, %arg1: memref) { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c2, %c3) - threads in (%c4, %c5, %c6) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c3, %c2, %c1) - threads in (%c6, %c5, %c4) - args(%arg1 : memref) - - func.return -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 -// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 -// CHECK-NEXT: %[[C5:.*]] = arith.constant 5 -// CHECK-NEXT: %[[C6:.*]] = arith.constant 6 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-SAME: blocks in (%[[C1]], %[[C2]], %[[C3]]) -// CHECK-SAME: threads in (%[[C4]], %[[C5]], %[[C6]]) -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-SAME: blocks in (%[[C3]], %[[C2]], %[[C1]]) -// CHECK-SAME: threads in (%[[C6]], %[[C5]], %[[C4]]) -// CHECK-NEXT: return - -// CHECK: func private @xla.gpu.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} -} - -// ----- -// Check that single function launch was not outlined into graph capture. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - %c1 = arith.constant 1 : index - - // CHECK: gpu.launch_func {{.*}} args(%[[ARG0]] : memref) - // CHECK-NOT: call @xla.gpu.graph.launch - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - func.return -} - -} - -// ----- -// Check that two different sequences are outlined in different capture -// functions. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - // CHECK: %[[C1:.*]] = arith.constant 1 - %c1 = arith.constant 1 : index - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE:.*]]} - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - // CHECK: %[[C2:.*]] = arith.constant 2 - %c2 = arith.constant 2 : index - - // Use function call to break the captured ops sequence. - // CHECK: call @external - call @external(): () -> () - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE_0:.*]]} - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - func.return -} - -func.func private @external() - -// CHECK: rt.export @[[CAPTURE]] -// CHECK: func.func @[[CAPTURE]]( -// CHECK: %arg0: memref -// CHECK: ) -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 - -// CHECK: rt.export @[[CAPTURE_0]] -// CHECK: func.func @[[CAPTURE_0]]( -// CHECK: %arg0: memref -// CHECK: ) -// CHECK-NEXT: arith.constant 2 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 - -} - -// ----- -// Check that constants from the different basic blocks are cloned into the -// graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref, -// CHECK: %[[ARG1:.*]]: memref -// CHECK: ) -func.func @func(%arg0: memref, %arg1: memref) { - cf.br ^bb2 -^bb1: - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg1 : memref) - - func.return - -^bb2: - %c1 = arith.constant 1 : index - cf.br ^bb1 -} -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that memref.view operations are cloned into the graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<4xf32>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<4xf32>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> - - call @external() : () -> () - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: memref.view -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that memref.view not used by operations in the captured graph will not -// be moved into the graph capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - call @external() : () -> () - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: memref.view - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) - %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo_gpu.gemm is moved into the graph capture function. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - // CHECK: @func(%[[ARG0:.*]]: memref<16xi8> {lmhlo.params = 0 : index} - // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.params = 1 : index} - // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> - func.func @func(%raw_arg0: memref<16xi8> {lmhlo.params = 0 : index}, - %raw_arg1: memref<16xi8> {lmhlo.params = 1 : index}, - %raw_arg2: memref<16xi8> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes { - result_xla_shape = "(f32[4]) " - } { - %c0 = arith.constant 0 : index - %arg0 = memref.view %raw_arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %arg1 = memref.view %raw_arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%raw_arg0 : memref<16xi8>) - "lmhlo.terminator"() : () -> () - } - - func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: memref.view -// CHECK-NEXT: "lmhlo_gpu.gemm" -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo_gpu.gemm with runtime autotuning is not captured by a CUDA -// graph. - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - // CHECK: @func(%[[ARG0:.*]]: memref<16xi8> {lmhlo.params = 0 : index} - // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.params = 1 : index} - // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> - func.func @func(%raw_arg0: memref<16xi8> {lmhlo.params = 0 : index}, - %raw_arg1: memref<16xi8> {lmhlo.params = 1 : index}, - %raw_arg2: memref<16xi8> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes { - result_xla_shape = "(f32[4]) " - } { - %c0 = arith.constant 0 : index - %arg0 = memref.view %raw_arg0[%c0][] : memref<16xi8> to memref<2x2xf32> - %c1 = arith.constant 0 : index - %arg1 = memref.view %raw_arg1[%c1][] : memref<16xi8> to memref<2x2xf32> - %c2 = arith.constant 0 : index - %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - - - // CHECK-NOT: call @xla.gpu.graph.launch - // CHECK: "lmhlo_gpu.gemm" - "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = -5, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%raw_arg0 : memref<16xi8>) - "lmhlo.terminator"() : () -> () - } - - func.func private @external() -} - -// ----- -// Check that convolution with runtime autotuning is not captured by a CUDA -// graph. - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - - // CHECK: @func(%[[ARG0:.*]]: memref<8x5x5x1xf32, #map> - // CHECK-SAME: %[[ARG1:.*]]: memref<3x3x1x32xf32, #map1> - // CHECK-SAME: %[[ARG2:.*]]: memref<32xf32> - // CHECK-SAME: %[[ARG3:.*]]: memref<8x5x5x32xf32, #map2> - // CHECK-SAME: %[[ARG4:.*]]: memref<0xui8> - // CHECK-SAME: %[[ARG5:.*]]: memref<16xi8> - func.func @func(%input: memref<8x5x5x1xf32, #map1>, - %filter: memref<3x3x1x32xf32, #map0>, - %bias: memref<32xf32>, - %output: memref<8x5x5x32xf32, #map2>, - %scratch: memref<0xui8>, - %raw_arg0: memref<16xi8> {lmhlo.params = 0 : index} - ) { - %c0 = arith.constant 0 : index - - // CHECK-NOT: call @xla.g.cuda.graph.launch - // CHECK: lmhlo_gpu.conv_forward_fused - lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { activation_mode = #lmhlo_gpu, - leakyrelu_alpha = 0.0 : f64, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = -1, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [2, 3], - knob_values = [4, 0], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<8x5x5x1xf32, #map1>, - memref<3x3x1x32xf32, #map0>, - memref<32xf32>, - memref<8x5x5x32xf32, #map2>, - memref<0xui8>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%raw_arg0 : memref<16xi8>) - return - } - func.func private @external() -} - -// ----- -// Check that convolutions are captured by cuda graphs. - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d1 + d2 * 9 + d3 * 9)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 4 + d2 + d3 * 16)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 2 + d2 + d3 * 4)> - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - } - - - // CHECK: @func(%[[ARG0:.*]]: memref<1x4x4x1024xf16, #map> - // CHECK-SAME: %[[ARG1:.*]]: memref<3x3x1x1024xf16, #map1> - // CHECK-SAME: %[[ARG2:.*]]: memref<1x2x2x1024xf16, #map2> - // CHECK-SAME: %[[ARG3:.*]]: memref<0xui8> - // CHECK-SAME: %[[ARG4:.*]]: memref<16xi8> - func.func @func(%input: memref<1x4x4x1024xf16, #map1>, - %filter: memref<3x3x1x1024xf16, #map0>, - %output: memref<1x2x2x1024xf16, #map2>, - %scratch: memref<0xui8>, - %raw_arg0: memref<16xi8> {lmhlo.params = 0 : index} - ) { - %c0 = arith.constant 0 : index - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], - lhs_dilate = [1, 1], - rhs_dilate = [1, 1], - reverse = [0, 0] - } - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - is_cudnn_frontend = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - operand_0_layout = [2, 1, 3, 0], - operand_1_layout = [1, 0, 2, 3], - result_layout = [2, 1, 3, 0], - tensor_ops_enabled = false, - workspace_size = 0 - >, - batch_group_count = 1 : i64, - feature_group_count = 1024 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<1x4x4x1024xf16, #map1>, - memref<3x3x1x1024xf16, #map0>, - memref<1x2x2x1024xf16, #map2>, - memref<0xui8>) -> () - gpu.launch_func @gpu_module::@fn0 blocks in (%c0, %c0, %c0) - threads in (%c0, %c0, %c0) args(%raw_arg0 : memref<16xi8>) - return - } - func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 0 -// CHECK-NEXT: lmhlo_gpu.conv_forward -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: return - -// ----- -// Check that d2d memcpy are captured. - -module attributes {gpu.container_module} { - - // CHECK: @func(%[[ARG0:.*]]: memref<100xi8>) - func.func @func(%arg0: memref<100xi8>) { - %c0 = arith.constant 0 : index - %dst = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> - %src = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> - gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> - - // CHECK: return - return - } - func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK: gpu.memcpy -// CHECK: gpu.memcpy -// CHECK-NEXT: return - -// ----- -// Check that memref.reinterpret_cast operations are cloned into the graph -// capture function. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) -func.func @func(%arg0: memref<16xi8>) { - %c1 = arith.constant 1 : index - %view = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1]: memref<16xi8> to memref<16xi8, strided<[1], offset: 0>> - - call @external() : () -> () - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: return - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>) - - func.return -} - -func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: memref.reinterpret_cast -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that the loop body of lmhlo.while is cloned into the graph. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } - gpu.func @fn1(%arg0: memref<16xi8>) kernel { gpu.return } -} - -// CHECK: @func(%[[ARG0:.*]]: memref<16xi8> -func.func @func(%arg0: memref<16xi8>, %cond: memref) { - %c1 = arith.constant 1 : index - - call @external() : () -> () - - "lmhlo.while"(%cond) ({ - // CHECK: func.call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - "lmhlo.terminator"() : () -> () }, { - // CHECK: func.call @xla.gpu.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture_0} - gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) - "lmhlo.terminator"() : () -> () - }) : (memref) -> () - func.return -} - -func.func private @external() -} - -// CHECK: func @xla.gpu.graph.capture -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// CHECK: func @xla.gpu.graph.capture_0 -// CHECK-NEXT: arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// ----- -// Check that lmhlo.constant_name is propogated to the graph capture function -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref {lmhlo.constant_name = "cst0"}, -// CHECK: %[[ARG1:.*]]: memref {lmhlo.constant_name = "cst1"} -// CHECK: ) -func.func @func(%arg0: memref {lmhlo.constant_name = "cst0"}, - %arg1: memref {lmhlo.constant_name = "cst1"}) { - %c1 = arith.constant 1 : index - - // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg1 : memref) - - func.return -} - -// CHECK: func @xla.gpu.graph.capture( -// CHECK-SAME: %[[ARG0]]: memref {lmhlo.constant_name = "cst0", -// CHECK-SAME: %[[ARG1]]: memref {lmhlo.constant_name = "cst1", -// CHECK-SAME: ) -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-SAME: threads in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-SAME: blocks in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-SAME: threads in (%[[C1]], %[[C1]], %[[C1]]) -// CHECK-NEXT: return - -// CHECK: func private @xla.gpu.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} -} diff --git a/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir b/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir deleted file mode 100644 index 9bac5e0cf8472..0000000000000 --- a/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-stream-assignment \ -// RUN: | FileCheck %s - -// ----- -// Check that independent kernels are assigned to different streams. -// A B--->C -// | ^ -// | | -// +--------+ -// -// Stream assignment: A->0 B->1 C->0 - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64>, %arg1: memref<3x3xi64>) kernel { gpu.return } - } - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>, %view_0 : memref<3x3xi64>) - // CHECK: call @xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that the assignment for the following pattern correctly exploits -// parallelism. -// A--->B C -// | ^ -// | | -// +--------+ -// -// Stream assignment: A->0 B->1 C->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64> {lmhlo.written}) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_0 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>, %view_0 : memref<3x3xi64>) - // CHECK: call @xla.streams.await() {from = 1 : i64, to = [0]} - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: call @xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that stream with multiple dependencies is handled correctly. -// A B C-->D -// | | ^ -// | |--------| -// +-------------+ -// -// Stream assignment: A->0 B->1 C->2 D->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64> {lmhlo.written}, %arg3: memref<3x3xi64>) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<72xi8>) { - // CHECK: call @xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_2 = memref.view %arg2[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 2 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_2 : memref<3x3xi64>) - // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1, 2]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>, %view_2 : memref<3x3xi64>) - // CHECK: call @xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} - -// ----- -// Check that stream synchronization only happens when two streams joins. -// A B--->C-->D -// | ^ -// | | -// +---------+ -// -// Stream assignment: A->0 B->1 C->0 D->0 -// - -module attributes {gpu.container_module} { - - gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } - gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64>) kernel { gpu.return } - } - - - // CHECK: func @xla.gpu.graph.capture - func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { - // CHECK: call @xla.gpu.concurrent_region.begin() - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> - %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> - - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) - // CHECK: gpu.launch_func @gpu_module::@fn1 - // CHECK-SAME: {stream = 1 : i64} - gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) - // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1]} - // CHECK: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) - // CHECK-NEXT: gpu.launch_func @gpu_module::@fn2 - // CHECK-SAME: {stream = 0 : i64} - gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) - // CHECK: call @xla.gpu.concurrent_region.end() - // CHECK: return - return - } -} diff --git a/xla/mlir/backends/gpu/transforms/uid_generator.h b/xla/mlir/backends/gpu/transforms/uid_generator.h deleted file mode 100644 index 1a89d184ac04a..0000000000000 --- a/xla/mlir/backends/gpu/transforms/uid_generator.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ -#define XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ - -#include - -namespace xla { -namespace gpu { - -// Every stateful operation in the module gets assigned a unique id, that is -// passed to the custom call handler. This id is used for caching resources -// between the different invocations of the same custom call (e.g. cache -// convolution descriptors). -// -// TODO(b/255600288): Improve stateful custom calls in Xla runtime. -class UidGenerator { - public: - UidGenerator() : uid_(0) {} - int64_t uid() { return uid_.fetch_add(1); } - - private: - std::atomic uid_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ diff --git a/xla/mlir/backends/gpu/xla-gpu-opt.cc b/xla/mlir/backends/gpu/xla-gpu-opt.cc deleted file mode 100644 index e974a69b725e2..0000000000000 --- a/xla/mlir/backends/gpu/xla-gpu-opt.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" - -int main(int argc, char **argv) { - mlir::DialectRegistry registry; - registry - .insert(); - mlir::func::registerAllExtensions(registry); - - xla::gpu::registerGpuTransformsPasses(); - - return failed(MlirOptMain(argc, argv, "Xla Gpu Pass Driver\n", registry)); -} diff --git a/xla/mlir/framework/ir/BUILD b/xla/mlir/framework/ir/BUILD index e9390a6a403fd..18c3101220b01 100644 --- a/xla/mlir/framework/ir/BUILD +++ b/xla/mlir/framework/ir/BUILD @@ -1,10 +1,11 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), licenses = ["notice"], ) diff --git a/xla/mlir/framework/ir/xla_framework.cc b/xla/mlir/framework/ir/xla_framework.cc index 6cab85cd88515..7a5cb85706ac2 100644 --- a/xla/mlir/framework/ir/xla_framework.cc +++ b/xla/mlir/framework/ir/xla_framework.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/ir/xla_framework.h b/xla/mlir/framework/ir/xla_framework.h index 690fb63e84062..c1d9be0f97bbc 100644 --- a/xla/mlir/framework/ir/xla_framework.h +++ b/xla/mlir/framework/ir/xla_framework.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/ir/xla_framework_ops.td b/xla/mlir/framework/ir/xla_framework_ops.td index 9a894e1b5f820..f8015781bed1a 100644 --- a/xla/mlir/framework/ir/xla_framework_ops.td +++ b/xla/mlir/framework/ir/xla_framework_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/tests/BUILD b/xla/mlir/framework/tests/BUILD index 165b571df37e6..e0311ea4ac362 100644 --- a/xla/mlir/framework/tests/BUILD +++ b/xla/mlir/framework/tests/BUILD @@ -1,29 +1,23 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) -glob_lit_tests( +lit_test_suite( name = "all_tests", - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = [ - "mlir", - "hlotxt", - ], -) - -# Bundle together all of the test utilities that are used by tests. -# This intentionally does not pull-in the top-level tf-opt to reduce the -# dependencies. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-opt", + srcs = enforce_glob( + [ + "legalize-xla-framework.mlir", + "outline-with-xla-framework.mlir", + "xla-framework.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/translate:xla-translate-opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/xla/mlir/framework/transforms/BUILD b/xla/mlir/framework/transforms/BUILD index 3893d22f8fb29..f178411d0872e 100644 --- a/xla/mlir/framework/transforms/BUILD +++ b/xla/mlir/framework/transforms/BUILD @@ -1,10 +1,11 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), licenses = ["notice"], ) diff --git a/xla/mlir/framework/transforms/outline_with_xla_framework.cc b/xla/mlir/framework/transforms/outline_with_xla_framework.cc index 0b4857502e4f6..8efd1726ed05c 100644 --- a/xla/mlir/framework/transforms/outline_with_xla_framework.cc +++ b/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/transforms/passes.h b/xla/mlir/framework/transforms/passes.h index cb11e9e96deb9..213ced404a71a 100644 --- a/xla/mlir/framework/transforms/passes.h +++ b/xla/mlir/framework/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/transforms/passes.td b/xla/mlir/framework/transforms/passes.td index b062f63837999..f851fbce6858f 100644 --- a/xla/mlir/framework/transforms/passes.td +++ b/xla/mlir/framework/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc b/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc index 3d127ee7a084f..dc609381152f6 100644 --- a/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc +++ b/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/math/BUILD b/xla/mlir/math/BUILD deleted file mode 100644 index 0fbd61df0678f..0000000000000 --- a/xla/mlir/math/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -package_group( - name = "friends", - packages = [ - "//xla/mlir/...", - # copybara:uncomment_begin(google-only) - # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. - # "//third_party/tf_runtime/...", - # copybara:uncomment_end(google-only) - ], -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], - licenses = ["notice"], -) diff --git a/xla/mlir/math/transforms/BUILD b/xla/mlir/math/transforms/BUILD deleted file mode 100644 index b23eed4d7ee63..0000000000000 --- a/xla/mlir/math/transforms/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla/mlir/math:friends"], - licenses = ["notice"], -) - -gentbl_cc_library( - name = "passes_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=MathTransforms", - ], - "passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "passes.td", - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "passes", - srcs = [ - "math_approximation.cc", - "math_optimization.cc", - ], - hdrs = ["passes.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - ":passes_inc_gen", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MathToLibm", - "@llvm-project//mlir:MathTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:X86VectorDialect", - ], -) diff --git a/xla/mlir/math/transforms/math_approximation.cc b/xla/mlir/math/transforms/math_approximation.cc deleted file mode 100644 index a96d361f01a6f..0000000000000 --- a/xla/mlir/math/transforms/math_approximation.cc +++ /dev/null @@ -1,774 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/mlir/math/transforms/passes.h" - -namespace xla { -namespace { - -#define GEN_PASS_DEF_MATHAPPROXIMATIONPASS -#include "xla/mlir/math/transforms/passes.h.inc" - -using ::llvm::ArrayRef; -using ::llvm::SmallVector; - -using ::mlir::ImplicitLocOpBuilder; -using ::mlir::LogicalResult; -using ::mlir::OperationPass; -using ::mlir::OpRewritePattern; -using ::mlir::PatternRewriter; -using ::mlir::RewritePatternSet; -using ::mlir::Type; -using ::mlir::Value; -using ::mlir::VectorType; - -namespace arith = ::mlir::arith; -namespace func = ::mlir::func; -namespace math = ::mlir::math; -namespace vector = ::mlir::vector; - -using TypePredicate = ::llvm::function_ref; - -#define LN2_VALUE \ - 0.693147180559945309417232121458176568075500134360255254120680009493393621L -#define LOG2E_VALUE \ - 1.442695040888963407359924681001892137426645954152985934135449406931109219L - -// Returns vector shape if the element type is matching the predicate (scalars -// that do match the predicate have shape equal to `{1}`). -std::optional> vectorShape(Type type, - TypePredicate pred) { - // If the type matches the predicate then its shape is `{1}`. - if (pred(type)) return SmallVector{1}; - - // Otherwise check if the type is a vector type. - auto vectorType = type.dyn_cast(); - if (vectorType && pred(vectorType.getElementType())) { - return llvm::to_vector<2>(vectorType.getShape()); - } - - return std::nullopt; -} - -bool isF32(Type type) { return type.isF32(); } -bool isI32(Type type) { return type.isInteger(32); } - -//----------------------------------------------------------------------------// -// Broadcast scalar types and values into vector types and values. -//----------------------------------------------------------------------------// - -// Returns true if shape != {1}. -bool isNonScalarShape(ArrayRef shape) { - return shape.size() > 1 || shape[0] > 1; -} - -// Broadcasts scalar type into vector type (iff shape is non-scalar). -Type broadcast(Type type, ArrayRef shape) { - assert(!type.isa() && "must be scalar type"); - return isNonScalarShape(shape) ? VectorType::get(shape, type) : type; -} - -// Broadcasts scalar value into vector (iff shape is non-scalar). -Value broadcast(ImplicitLocOpBuilder &builder, Value value, - ArrayRef shape) { - assert(!value.getType().isa() && "must be scalar value"); - auto type = broadcast(value.getType(), shape); - return isNonScalarShape(shape) - ? builder.create(type, value) - : value; -} - -//----------------------------------------------------------------------------// -// Helper functions to create constants. -//----------------------------------------------------------------------------// - -Value f32Cst(ImplicitLocOpBuilder &builder, float value) { - return builder.create(builder.getF32FloatAttr(value)); -} - -Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { - return builder.create(builder.getI32IntegerAttr(value)); -} - -Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { - Value i32v = i32Cst(builder, static_cast(bits)); - return builder.create(builder.getF32Type(), i32v); -} - -//----------------------------------------------------------------------------// -// Helper functions to build math functions approximations. -//----------------------------------------------------------------------------// - -// Return the clamped value or NaN if value is NaN. -// Note: the bounds must be normal, not NaN's. -Value ClampWithNormals(ImplicitLocOpBuilder &builder, - const llvm::SmallVector &shape, Value value, - float lower_bound, float upper_bound) { - assert(!std::isnan(lower_bound)); - assert(!std::isnan(upper_bound)); - - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, shape); - }; - - auto select_cmp = [&builder](auto pred, Value value, Value bound) { - return builder.create( - builder.create(pred, value, bound), value, bound); - }; - - // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. - // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with - // arith::{Max,Min}FOp. - value = select_cmp(arith::CmpFPredicate::UGE, value, - bcast(f32Cst(builder, lower_bound))); - value = select_cmp(arith::CmpFPredicate::ULE, value, - bcast(f32Cst(builder, upper_bound))); - return value; -} - -// Return the maximum of the two values or NaN if value is NaN -Value Max(ImplicitLocOpBuilder &builder, Value value, Value bound) { - return builder.create( - builder.create(arith::CmpFPredicate::UGE, value, bound), - value, bound); -} - -// Computes exp2 for an i32 argument. -Value Exp2I32(ImplicitLocOpBuilder &builder, Value arg) { - auto shape = vectorShape(arg.getType(), isI32); - assert(shape.has_value() && "arg must be of i32 type"); - - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - auto f32_vec = broadcast(builder.getF32Type(), *shape); - // The exponent of f32 located at 23-bit. - Value cst_exponent_bit = bcast(i32Cst(builder, 23)); - // Set the exponent bias to zero. - Value cst_bias = bcast(i32Cst(builder, 127)); - - Value biased_arg = builder.create(arg, cst_bias); - Value exp2_i32 = builder.create(biased_arg, cst_exponent_bit); - Value exp2_f32 = builder.create(f32_vec, exp2_i32); - - return exp2_f32; -} - -// Decomposes given floating point value `arg` into a normalized fraction and -// an integral power of two (see std::frexp). Returned values have float type. -std::pair Frexp(ImplicitLocOpBuilder &builder, Value arg, - bool isPositive = false) { - auto shape = vectorShape(arg.getType(), isF32); - assert(shape.has_value() && "arg must be of f32 type"); - - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - auto i32 = builder.getIntegerType(32); - auto i32_vec = broadcast(builder.getI32Type(), *shape); - auto f32_vec = broadcast(builder.getF32Type(), *shape); - - Value cst126f = f32Cst(builder, 126.0f); - Value cst_half = f32Cst(builder, 0.5f); - Value cst_inv_mant_mask = f32FromBits(builder, ~0x7f800000u); - - // Bitcast to i32 for bitwise operations. - Value i32_half = builder.create(i32, cst_half); - Value i32_inv_mant_mask = - builder.create(i32, cst_inv_mant_mask); - Value i32_arg = builder.create(i32_vec, arg); - - // Compute normalized fraction. - Value tmp0 = builder.create(i32_arg, bcast(i32_inv_mant_mask)); - Value tmp1 = builder.create(tmp0, bcast(i32_half)); - Value normalized_fraction = builder.create(f32_vec, tmp1); - - // Compute exponent. - Value arg0 = isPositive ? arg : builder.create(arg); - Value biased_exponent_bits = builder.create( - builder.create(i32_vec, arg0), - bcast(i32Cst(builder, 23))); - Value biased_exponent = - builder.create(f32_vec, biased_exponent_bits); - Value exponent = - builder.create(biased_exponent, bcast(cst126f)); - - return {normalized_fraction, exponent}; -} - -struct ExpM1Approximation : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(math::ExpM1Op op, - PatternRewriter &rewriter) const final; -}; - -// This approximation comes from XLA Classic. -LogicalResult ExpM1Approximation::matchAndRewrite( - math::ExpM1Op op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.has_value()) - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - - ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - Value cst_zero = bcast(f32Cst(builder, 0.0f)); - Value cst_half = bcast(f32Cst(builder, 0.5f)); - Value cst_one = bcast(f32Cst(builder, 1.0f)); - - // expm1(x) == tanh(x/2)*(exp(x)+1) - // x/2 can underflow, if it does we approximate expm1 with x. - Value x = op.getOperand(); - Value x_over_two = builder.create(x, cst_half); - Value x_over_two_is_zero = builder.create( - arith::CmpFPredicate::OEQ, x_over_two, cst_zero); - Value abs_x = builder.create(x); - - Value abs_x_is_large = - builder.create(arith::CmpFPredicate::OGT, abs_x, cst_half); - Value tanh_of_x_over_two = builder.create(x_over_two); - Value exp_of_x = builder.create(x); - Value exp_of_x_plus_one = builder.create(exp_of_x, cst_one); - Value exp_of_x_minus_one = builder.create(exp_of_x, cst_one); - - Value expm1_of_x = - builder.create(tanh_of_x_over_two, exp_of_x_plus_one); - expm1_of_x = builder.create(abs_x_is_large, - exp_of_x_minus_one, expm1_of_x); - expm1_of_x = - builder.create(x_over_two_is_zero, x, expm1_of_x); - - rewriter.replaceOp(op, expm1_of_x); - return mlir::success(); -} - -struct ExpApproximation : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(math::ExpOp op, - PatternRewriter &rewriter) const final; -}; - -LogicalResult ExpApproximation::matchAndRewrite( - math::ExpOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.has_value()) { - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - } - - ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - - auto add = [&](Value a, Value b) -> Value { - return builder.create(a, b); - }; - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - auto floor = [&](Value a) { return builder.create(a); }; - auto fmla = [&](Value a, Value b, Value c) { - return builder.create(a, b, c); - }; - auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); - }; - - // Polynomial approximation. Originally from Cephes, but then modified for - // XLA Classic. - // - // To compute e^x, we re-express it as - // - // e^x = e^(a + b) - // = e^(a + n log(2)) - // = e^a * 2^n. - // - // We choose n = round(x / log(2)), restricting the value of `a` to - // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The - // relative error between our approximation and the true value of e^a is less - // than 2^-22.5 for all values of `a` within this range. - - // Restrict input to a small range, including some values that evaluate to - // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of - // log(F32_EPSILON). We do so because this routine always flushes denormal - // floating points to 0. Therefore, we only need to worry about exponentiating - // up to the smallest representable non-denormal floating point, which is - // 2^-126. - - // Constants. - Value cst_half = bcast(f32Cst(builder, 0.5f)); - Value cst_one = bcast(f32Cst(builder, 1.0f)); - - // 1/log(2) - Value cst_log2ef = bcast(f32Cst(builder, 1.44269504088896341f)); - - Value cst_exp_c1 = bcast(f32Cst(builder, -0.693359375f)); - Value cst_exp_c2 = bcast(f32Cst(builder, 2.12194440e-4f)); - Value cst_exp_p0 = bcast(f32Cst(builder, 1.9875691500E-4f)); - Value cst_exp_p1 = bcast(f32Cst(builder, 1.3981999507E-3f)); - Value cst_exp_p2 = bcast(f32Cst(builder, 8.3334519073E-3f)); - Value cst_exp_p3 = bcast(f32Cst(builder, 4.1665795894E-2f)); - Value cst_exp_p4 = bcast(f32Cst(builder, 1.6666665459E-1f)); - Value cst_exp_p5 = bcast(f32Cst(builder, 5.0000001201E-1f)); - - // Our computations below aren't particularly sensitive to the exact choices - // here, so we choose values a bit larger/smaller than - // - // log(F32_MAX) = 88.723... - // log(2^-126) = -87.337... - Value x = op.getOperand(); - x = ClampWithNormals(builder, *shape, x, -87.8f, 88.8f); - Value n = floor(fmla(x, cst_log2ef, cst_half)); - - // When we eventually do the multiplication in e^a * 2^n, we need to handle - // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1 - // (so e^a * 2^n != inf). There's a similar problem for n < -126, the - // smallest fp32 exponent. - // - // A straightforward solution would be to detect n out of range and split it - // up, doing - // - // e^a * 2^n = e^a * 2^(n1 + n2) - // = (2^n1 * e^a) * 2^n2. - // - // But it turns out this approach is quite slow, probably because it - // manipulates subnormal values. - // - // The approach we use instead is to clamp n to [-127, 127]. Let n' be the - // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow - // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though - // this value of `a` is outside our previously specified range, e^a will still - // only have a relative error of approximately 2^-16 at worse. In practice - // this seems to work well enough; it passes our exhaustive tests, breaking - // only one result, and by one ulp (we return exp(88.7228394) = max-float but - // we should return inf). - // - // In the case where n' = -127, the original input value of x is so small that - // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest - // normal floating point, and since we flush denormals, we simply return 0. We - // do this in a branchless way by observing that our code for constructing 2^n - // produces 0 if n = -127. - // - // The proof that n' = -127 implies e^x < 2^-126 is as follows: - // - // n' = -127 implies n <= -127 - // implies round(x / log(2)) <= -127 - // implies x/log(2) < -126.5 - // implies x < -126.5 * log(2) - // implies e^x < e^(-126.5 * log(2)) - // implies e^x < 2^-126.5 < 2^-126 - // - // This proves that n' = -127 implies e^x < 2^-126. - n = ClampWithNormals(builder, *shape, n, -127.0f, 127.0f); - - // Computes x = x - n' * log(2), the value for `a` - x = fmla(cst_exp_c1, n, x); - x = fmla(cst_exp_c2, n, x); - - // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5). - Value z = fmla(x, cst_exp_p0, cst_exp_p1); - z = fmla(z, x, cst_exp_p2); - z = fmla(z, x, cst_exp_p3); - z = fmla(z, x, cst_exp_p4); - z = fmla(z, x, cst_exp_p5); - z = fmla(z, mul(x, x), x); - z = add(cst_one, z); - - // Convert n' to an i32. This is safe because we clamped it above. - auto i32_vec = broadcast(builder.getI32Type(), *shape); - Value n_i32 = builder.create(i32_vec, n); - - // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. - Value pow2 = Exp2I32(builder, n_i32); - - // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127. - Value ret = mul(z, pow2); - - rewriter.replaceOp(op, ret); - return mlir::success(); -} - -template -struct LogApproximationBase : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise. - LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter, - bool base2) const; -}; - -// This approximation comes from Julien Pommier's SSE math library. -// Link: http://gruntthepeon.free.fr/ssemath -template -LogicalResult LogApproximationBase::logMatchAndRewrite( - Op op, PatternRewriter &rewriter, bool base2) const { - auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.has_value()) { - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - } - - ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - Value cst_zero = bcast(f32Cst(builder, 0.0f)); - Value cst_one = bcast(f32Cst(builder, 1.0f)); - Value cst_neg_half = bcast(f32Cst(builder, -0.5f)); - - // The smallest non denormalized float number. - Value cst_min_norm_pos = bcast(f32FromBits(builder, 0x00800000u)); - Value cst_minus_inf = bcast(f32FromBits(builder, 0xff800000u)); - Value cst_pos_inf = bcast(f32FromBits(builder, 0x7f800000u)); - Value cst_nan = bcast(f32FromBits(builder, 0x7fc00000)); - - // Polynomial coefficients. - Value cst_cephes_sqrthf = bcast(f32Cst(builder, 0.707106781186547524f)); - - // Truncate input values to the minimum positive normal. - // Extract significant in the range [0.5,1) and exponent. - auto [x, e] = Frexp(builder, Max(builder, op.getOperand(), cst_min_norm_pos), - /*isPositive=*/true); - - // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift - // by -1.0. The values are then centered around 0, which improves the - // stability of the polynomial evaluation: - // - // if( x < SQRTHF ) { - // e -= 1; - // x = x + x - 1.0; - // } else { x = x - 1.0; } - Value mask = builder.create(arith::CmpFPredicate::OLT, x, - cst_cephes_sqrthf); - Value tmp = builder.create(mask, x, cst_zero); - - x = builder.create(x, cst_one); - e = builder.create( - e, builder.create(mask, cst_one, cst_zero)); - x = builder.create(x, tmp); - - Value x2 = builder.create(x, x); - Value x3 = builder.create(x2, x); - - Value cephes_log_p0 = bcast(f32Cst(builder, 7.0376836292E-2)); - Value cephes_log_p1 = bcast(f32Cst(builder, -1.1514610310E-1)); - Value cephes_log_p2 = bcast(f32Cst(builder, 1.1676998740E-1)); - Value cephes_log_p3 = bcast(f32Cst(builder, -1.2420140846E-1)); - Value cephes_log_p4 = bcast(f32Cst(builder, +1.4249322787E-1)); - Value cephes_log_p5 = bcast(f32Cst(builder, -1.6668057665E-1)); - Value cephes_log_p6 = bcast(f32Cst(builder, +2.0000714765E-1)); - Value cephes_log_p7 = bcast(f32Cst(builder, -2.4999993993E-1)); - Value cephes_log_p8 = bcast(f32Cst(builder, +3.3333331174E-1)); - Value cephes_log_q1 = bcast(f32Cst(builder, -2.12194440e-4)); - Value cephes_log_q2 = bcast(f32Cst(builder, 0.693359375)); - Value half = bcast(f32Cst(builder, 0.5f)); - - // Evaluate the polynomial approximant of degree 8 in three parts. - Value y = builder.create(x, cephes_log_p0, cephes_log_p1); - Value y1 = builder.create(x, cephes_log_p3, cephes_log_p4); - Value y2 = builder.create(x, cephes_log_p6, cephes_log_p7); - y = builder.create(y, x, cephes_log_p2); - y1 = builder.create(y1, x, cephes_log_p5); - y2 = builder.create(y2, x, cephes_log_p8); - // y = y * x3 + y1 - y = builder.create(y, x3, y1); - // y = y * x3 + y2 - y = builder.create(y, x3, y2); - // y *= x3 - y = builder.create(y, x3); - - Value tmp1 = builder.create(cephes_log_q1, e); - Value tmp2 = builder.create(half, x2); - if (base2) { - x = builder.create(x, tmp2); - Value cst_log2e = bcast(f32Cst(builder, static_cast(LOG2E_VALUE))); - x = builder.create(x, cst_log2e, e); - } else { - // y += log_q1 * e - y = builder.create(y, tmp1); - // x -= 0.5 * x2 - x = builder.create(x, tmp2); - Value tmp3 = builder.create(cephes_log_q2, e); - // x += y - x = builder.create(x, y); - // x += log_q2 * e - x = builder.create(x, tmp3); - } - - Value invalid_mask = builder.create(arith::CmpFPredicate::ULT, - op.getOperand(), cst_zero); - Value zero_mask = builder.create(arith::CmpFPredicate::OEQ, - op.getOperand(), cst_zero); - Value pos_inf_mask = builder.create( - arith::CmpFPredicate::OEQ, op.getOperand(), cst_pos_inf); - - // Filter out invalid values: - // • x == 0 -> -INF - // • x < 0 -> NAN - // • x == +INF -> +INF - Value aproximation = builder.create( - zero_mask, cst_minus_inf, - builder.create( - invalid_mask, cst_nan, - builder.create(pos_inf_mask, cst_pos_inf, x))); - - rewriter.replaceOp(op, aproximation); - - return mlir::success(); -} - -struct Log2Approximation : public LogApproximationBase { - using LogApproximationBase::LogApproximationBase; - - LogicalResult matchAndRewrite(math::Log2Op op, - PatternRewriter &rewriter) const final { - return logMatchAndRewrite(op, rewriter, /*base2=*/true); - } -}; - -struct LogApproximation : public LogApproximationBase { - using LogApproximationBase::LogApproximationBase; - - LogicalResult matchAndRewrite(math::LogOp op, - PatternRewriter &rewriter) const final { - return logMatchAndRewrite(op, rewriter, /*base2=*/false); - } -}; - -struct Log1pApproximation : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(math::Log1pOp op, - PatternRewriter &rewriter) const final; -}; - -// Approximate log(1+x). -LogicalResult Log1pApproximation::matchAndRewrite( - math::Log1pOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.has_value()) { - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - } - - ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - // Approximate log(1+x) using the following, due to W. Kahan: - // u = x + 1.0; - // if (u == 1.0 || u == inf) return x; - // return x * log(u) / (u - 1.0); - // ^^^^^^^^^^^^^^^^^^^^^^ - // "log_large" below. - Value cst_one = bcast(f32Cst(builder, 1.0f)); - Value cst_negative_half = bcast(f32Cst(builder, -0.5f)); - - Value x = op.getOperand(); - Value for_large_x = - builder.create(builder.create(cst_one, x)); - - // When x is small, (defined to be less than sqrt(2) / 2), use a rational - // approximation. The approximation below is based on one from the Cephes - // Mathematical Library. - // - // sqrt(2) - 1. - const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880; - - static const std::array kDenominatorCoeffs{ - 1., - 1.5062909083469192043167E1, - 8.3047565967967209469434E1, - 2.2176239823732856465394E2, - 3.0909872225312059774938E2, - 2.1642788614495947685003E2, - 6.0118660497603843919306E1, - }; - - static const std::array kNumeratorCoeffs{ - 4.5270000862445199635215E-5, 4.9854102823193375972212E-1, - 6.5787325942061044846969E0, 2.9911919328553073277375E1, - 6.0949667980987787057556E1, 5.7112963590585538103336E1, - 2.0039553499201281259648E1, - }; - - auto eval_polynomial = [&](const std::array &coefficients) { - auto poly = bcast(f32Cst(builder, 0.0)); - for (double c : coefficients) { - poly = builder.create(poly, x, bcast(f32Cst(builder, c))); - } - return poly; - }; - - auto x_squared = builder.create(x, x); - Value denominator = eval_polynomial(kDenominatorCoeffs); - Value numerator = eval_polynomial(kNumeratorCoeffs); - Value for_small_x = builder.create(numerator, denominator); - for_small_x = builder.create( - builder.create(x, x_squared), for_small_x); - for_small_x = - builder.create(cst_negative_half, x_squared, for_small_x); - for_small_x = builder.create(x, for_small_x); - - auto abs_x = builder.create(x); - auto x_is_small = builder.create( - arith::CmpFPredicate::OLT, abs_x, - bcast(f32Cst(builder, kAntilogarithmIsSmallThreshold))); - Value approximation = - builder.create(x_is_small, for_small_x, for_large_x); - - rewriter.replaceOp(op, approximation); - return mlir::success(); -} - -struct TanhApproximation : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(math::TanhOp op, - PatternRewriter &rewriter) const final; -}; - -// This approximation comes from Eigen::generic_fast_tanh function. -LogicalResult TanhApproximation::matchAndRewrite( - math::TanhOp op, PatternRewriter &rewriter) const { - auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.has_value()) { - return rewriter.notifyMatchFailure(op, "unsupported operand type"); - } - - ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - auto bcast = [&](Value value) -> Value { - return broadcast(builder, value, *shape); - }; - - Value x = ClampWithNormals(builder, *shape, op.getOperand(), - -7.99881172180175781f, 7.99881172180175781f); - - // Mask for tiny values that are approximated with `operand`. - Value tiny = bcast(f32Cst(builder, 0.0004f)); - Value tiny_mask = builder.create( - arith::CmpFPredicate::OLT, builder.create(op.getOperand()), - tiny); - - // The monomial coefficients of the numerator polynomial (odd). - Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); - Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f)); - Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f)); - Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f)); - Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f)); - Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f)); - Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f)); - - // The monomial coefficients of the denominator polynomial (even). - Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f)); - Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f)); - Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f)); - Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); - - // Since the polynomials are odd/even, we need x^2. - Value x2 = builder.create(x, x); - - // Evaluate the numerator polynomial p. - Value p = builder.create(x2, alpha13, alpha11); - p = builder.create(x2, p, alpha9); - p = builder.create(x2, p, alpha7); - p = builder.create(x2, p, alpha5); - p = builder.create(x2, p, alpha3); - p = builder.create(x2, p, alpha1); - p = builder.create(x, p); - - // Evaluate the denominator polynomial q. - Value q = builder.create(x2, beta6, beta4); - q = builder.create(x2, q, beta2); - q = builder.create(x2, q, beta0); - - // Divide the numerator by the denominator. - Value res = builder.create( - tiny_mask, x, builder.create(p, q)); - - rewriter.replaceOp(op, res); - - return mlir::success(); -} - -void populateMathApproximationPatterns(RewritePatternSet &patterns, - ArrayRef oplist) { - for (const std::string &op : oplist) { - if (op == "all") { - patterns.add( - patterns.getContext()); - } else if (op == "exp") { - patterns.add(patterns.getContext()); - } else if (op == "expm1") { - patterns.add(patterns.getContext()); - } else if (op == "log") { - patterns.add(patterns.getContext()); - } else if (op == "log1p") { - patterns.add(patterns.getContext()); - } else if (op == "log2") { - patterns.add(patterns.getContext()); - } else if (op == "tanh") { - patterns.add(patterns.getContext()); - } - } -} - -struct MathApproximationPass - : public impl::MathApproximationPassBase { - explicit MathApproximationPass(ArrayRef approx_oplist) { - this->oplist = approx_oplist; - } - - void runOnOperation() override; -}; - -void MathApproximationPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateMathApproximationPatterns(patterns, oplist); - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) - signalPassFailure(); -} - -} // namespace - -std::unique_ptr> CreateMathApproximationPass( - ArrayRef oplist) { - return std::make_unique(oplist); -} - -} // namespace xla diff --git a/xla/mlir/math/transforms/math_optimization.cc b/xla/mlir/math/transforms/math_optimization.cc deleted file mode 100644 index a00e5325cdfec..0000000000000 --- a/xla/mlir/math/transforms/math_optimization.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" // from @llvm-project -#include "xla/mlir/math/transforms/passes.h" - -namespace xla { - -using namespace mlir; // NOLINT - -#define GEN_PASS_DEF_MATHOPTIMIZATIONPASS -#include "xla/mlir/math/transforms/passes.h.inc" - -struct MathOptimizationPass - : public impl::MathOptimizationPassBase { - explicit MathOptimizationPass(bool enable_avx2) { - enable_avx2_ = enable_avx2; - } - void runOnOperation() override; -}; - -void MathOptimizationPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateMathAlgebraicSimplificationPatterns(patterns); - - MathPolynomialApproximationOptions approx_options; - approx_options.enableAvx2 = enable_avx2_; - populateMathPolynomialApproximationPatterns(patterns, approx_options); - - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - signalPassFailure(); -} - -std::unique_ptr> CreateMathOptimizationPass( - bool enable_avx2) { - return std::make_unique(enable_avx2); -} - -} // namespace xla diff --git a/xla/mlir/math/transforms/passes.h b/xla/mlir/math/transforms/passes.h deleted file mode 100644 index 512f3eec1577c..0000000000000 --- a/xla/mlir/math/transforms/passes.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ -#define XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace xla { - -#define GEN_PASS_DECL_MATHAPPROXIMATIONPASS -#define GEN_PASS_DECL_MATHOPTIMIZATIONPASS -#include "xla/mlir/math/transforms/passes.h.inc" - -std::unique_ptr> -CreateMathOptimizationPass(bool enable_avx2 = false); - -std::unique_ptr> -CreateMathApproximationPass(llvm::ArrayRef oplist = {}); - -#define GEN_PASS_REGISTRATION -#include "xla/mlir/math/transforms/passes.h.inc" - -} // namespace xla - -#endif // XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ diff --git a/xla/mlir/math/transforms/passes.td b/xla/mlir/math/transforms/passes.td deleted file mode 100644 index 4b324b66c58db..0000000000000 --- a/xla/mlir/math/transforms/passes.td +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MATH_PASSES -#define XLA_MATH_PASSES - -include "mlir/Pass/PassBase.td" - -def MathOptimizationPass - : Pass<"xla-math-optimization", "mlir::func::FuncOp"> { - let summary = "Optimize operations from the `math` dialect."; - - let description = [{ - This pass performs algebraic simplification and polynomial approximation for - ops from the Math dialect. - }]; - - let dependentDialects = [ - "mlir::vector::VectorDialect", - "mlir::x86vector::X86VectorDialect" - ]; - - let constructor = "::xla::CreateMathOptimizationPass()"; - - let options = [ - Option<"enable_avx2_", "enable-avx2", "bool", "false", - "Enable math approximations that emit AVX2 intrinsics."> - ]; -} - -def MathApproximationPass - : Pass<"xla-math-approximation", "mlir::func::FuncOp"> { - let summary = "Approximate math operations for accuracy and speed."; - let constructor = "::xla::CreateMathApproximationPass()"; - let options = [ - ListOption<"oplist", "oplist", "std::string", - "List of math operations to be approximated. Use 'all' to select " - "all supported math operations.">, - ]; -} - -#endif // XLA_MATH_PASSES diff --git a/xla/mlir/math/transforms/tests/BUILD b/xla/mlir/math/transforms/tests/BUILD deleted file mode 100644 index 8f64aacfb4252..0000000000000 --- a/xla/mlir/math/transforms/tests/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/runtime:xla-runtime-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) diff --git a/xla/mlir/math/transforms/tests/math_optimization.mlir b/xla/mlir/math/transforms/tests/math_optimization.mlir deleted file mode 100644 index b4d215c82ae2a..0000000000000 --- a/xla/mlir/math/transforms/tests/math_optimization.mlir +++ /dev/null @@ -1,636 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-math-optimization \ -// RUN: | FileCheck %s - -// RUN: xla-runtime-opt %s --xla-math-optimization=enable-avx2 \ -// RUN: | FileCheck --check-prefix=AVX2 %s - -// CHECK-LABEL: @pow_noop -func.func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: return %arg0, %arg1 - %c = arith.constant 1.0 : f32 - %v = arith.constant dense <1.0> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// CHECK-LABEL: @pow_square -func.func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %arg0 - // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %arg1 - // CHECK: return %[[SCALAR]], %[[VECTOR]] - %c = arith.constant 2.0 : f32 - %v = arith.constant dense <2.0> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// CHECK-LABEL: @pow_cube -func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0 - // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %[[TMP_S]] - // CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1 - // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %[[TMP_V]] - // CHECK: return %[[SCALAR]], %[[VECTOR]] - %c = arith.constant 3.0 : f32 - %v = arith.constant dense <3.0> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// CHECK-LABEL: @pow_recip -func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32 - // CHECK: %[[CST_V:.*]] = arith.constant dense<1.0{{.*}}> : vector<4xf32> - // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %arg0 - // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %arg1 - // CHECK: return %[[SCALAR]], %[[VECTOR]] - %c = arith.constant -1.0 : f32 - %v = arith.constant dense <-1.0> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// CHECK-LABEL: @pow_sqrt -func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 - // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1 - // CHECK: return %[[SCALAR]], %[[VECTOR]] - %c = arith.constant 0.5 : f32 - %v = arith.constant dense <0.5> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// CHECK-LABEL: @pow_rsqrt -func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { - // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 - // CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1 - // CHECK: return %[[SCALAR]], %[[VECTOR]] - %c = arith.constant -0.5 : f32 - %v = arith.constant dense <-0.5> : vector<4xf32> - %0 = math.powf %arg0, %c : f32 - %1 = math.powf %arg1, %v : vector<4xf32> - func.return %0, %1 : f32, vector<4xf32> -} -// Check that all math functions lowered to approximations built from -// standard operations (add, mul, fma, shift, etc...). -// CHECK-LABEL: func @erf_scalar( -// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 { -// CHECK-DAG: %[[val_cst:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[val_cst_0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[val_cst_1:.*]] = arith.constant 1.12837911 : f32 -// CHECK-DAG: %[[val_cst_2:.*]] = arith.constant -0.523018539 : f32 -// CHECK-DAG: %[[val_cst_3:.*]] = arith.constant 0.209741712 : f32 -// CHECK-DAG: %[[val_cst_4:.*]] = arith.constant 0.0258146804 : f32 -// CHECK-DAG: %[[val_cst_5:.*]] = arith.constant 1.12750685 : f32 -// CHECK-DAG: %[[val_cst_6:.*]] = arith.constant -0.364721417 : f32 -// CHECK-DAG: %[[val_cst_7:.*]] = arith.constant 0.118407398 : f32 -// CHECK-DAG: %[[val_cst_8:.*]] = arith.constant 0.0370645523 : f32 -// CHECK-DAG: %[[val_cst_9:.*]] = arith.constant -0.00330093061 : f32 -// CHECK-DAG: %[[val_cst_10:.*]] = arith.constant 0.00351961935 : f32 -// CHECK-DAG: %[[val_cst_11:.*]] = arith.constant -0.00141373626 : f32 -// CHECK-DAG: %[[val_cst_12:.*]] = arith.constant 2.53447099E-4 : f32 -// CHECK-DAG: %[[val_cst_13:.*]] = arith.constant -1.71048032E-5 : f32 -// CHECK-DAG: %[[val_cst_14:.*]] = arith.constant -0.463513821 : f32 -// CHECK-DAG: %[[val_cst_15:.*]] = arith.constant 0.519230127 : f32 -// CHECK-DAG: %[[val_cst_16:.*]] = arith.constant -0.131808966 : f32 -// CHECK-DAG: %[[val_cst_17:.*]] = arith.constant 0.0739796459 : f32 -// CHECK-DAG: %[[val_cst_18:.*]] = arith.constant -3.276070e-01 : f32 -// CHECK-DAG: %[[val_cst_19:.*]] = arith.constant 0.448369086 : f32 -// CHECK-DAG: %[[val_cst_20:.*]] = arith.constant -0.0883462652 : f32 -// CHECK-DAG: %[[val_cst_21:.*]] = arith.constant 0.0572442785 : f32 -// CHECK-DAG: %[[val_cst_22:.*]] = arith.constant -2.0606916 : f32 -// CHECK-DAG: %[[val_cst_23:.*]] = arith.constant 1.62705934 : f32 -// CHECK-DAG: %[[val_cst_24:.*]] = arith.constant -0.583389878 : f32 -// CHECK-DAG: %[[val_cst_25:.*]] = arith.constant 0.0821908935 : f32 -// CHECK-DAG: %[[val_cst_26:.*]] = arith.constant 8.000000e-01 : f32 -// CHECK-DAG: %[[val_cst_27:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: %[[val_cst_28:.*]] = arith.constant 3.750000e+00 : f32 -// CHECK: %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[val_cst]] : f32 -// CHECK: %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32 -// CHECK: %[[val_2:.*]] = arith.select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32 -// CHECK: %[[val_3:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_26]] : f32 -// CHECK: %[[val_4:.*]] = arith.select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32 -// CHECK: %[[val_5:.*]] = arith.select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32 -// CHECK: %[[val_6:.*]] = arith.select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32 -// CHECK: %[[val_7:.*]] = arith.select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32 -// CHECK: %[[val_8:.*]] = arith.select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32 -// CHECK: %[[val_9:.*]] = arith.select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32 -// CHECK: %[[val_10:.*]] = arith.select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32 -// CHECK: %[[val_11:.*]] = arith.select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32 -// CHECK: %[[val_12:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_27]] : f32 -// CHECK: %[[val_13:.*]] = arith.select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32 -// CHECK: %[[val_14:.*]] = arith.select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32 -// CHECK: %[[val_15:.*]] = arith.select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32 -// CHECK: %[[val_16:.*]] = arith.select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32 -// CHECK: %[[val_17:.*]] = arith.select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32 -// CHECK: %[[val_18:.*]] = arith.select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32 -// CHECK: %[[val_19:.*]] = arith.select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32 -// CHECK: %[[val_20:.*]] = arith.select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32 -// CHECK: %[[val_21:.*]] = arith.select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32 -// CHECK: %[[val_22:.*]] = arith.select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32 -// CHECK: %[[val_23:.*]] = arith.cmpf ult, %[[val_2]], %[[val_cst_28]] : f32 -// CHECK: %[[val_24:.*]] = math.fma %[[val_2]], %[[val_20]], %[[val_18]] : f32 -// CHECK: %[[val_25:.*]] = math.fma %[[val_2]], %[[val_24]], %[[val_16]] : f32 -// CHECK: %[[val_26:.*]] = math.fma %[[val_2]], %[[val_25]], %[[val_14]] : f32 -// CHECK: %[[val_27:.*]] = math.fma %[[val_2]], %[[val_26]], %[[val_13]] : f32 -// CHECK: %[[val_28:.*]] = math.fma %[[val_2]], %[[val_21]], %[[val_19]] : f32 -// CHECK: %[[val_29:.*]] = math.fma %[[val_2]], %[[val_28]], %[[val_17]] : f32 -// CHECK: %[[val_30:.*]] = math.fma %[[val_2]], %[[val_29]], %[[val_15]] : f32 -// CHECK: %[[val_31:.*]] = math.fma %[[val_2]], %[[val_30]], %[[val_cst_0]] : f32 -// CHECK: %[[val_32:.*]] = arith.divf %[[val_27]], %[[val_31]] : f32 -// CHECK: %[[val_33:.*]] = arith.addf %[[val_22]], %[[val_32]] : f32 -// CHECK: %[[val_34:.*]] = arith.select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32 -// CHECK: %[[val_35:.*]] = arith.negf %[[val_34]] : f32 -// CHECK: %[[val_36:.*]] = arith.select %[[val_0]], %[[val_35]], %[[val_34]] : f32 -// CHECK: return %[[val_36]] : f32 -// CHECK: } -func.func @erf_scalar(%arg0: f32) -> f32 { - %0 = math.erf %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @erf_vector( -// CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-NOT: erf -// CHECK-COUNT-20: select -// CHECK: %[[res:.*]] = arith.select -// CHECK: return %[[res]] : vector<8xf32> -// CHECK: } -func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.erf %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// CHECK-LABEL: func @exp_scalar( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1.44269502 : f32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -0.693359375 : f32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2.12194442E-4 : f32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1.98756912E-4 : f32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.00139819994 : f32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.00833345205 : f32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.0416657962 : f32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.166666657 : f32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant -8.780000e+01 : f32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 8.880000e+01 : f32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant -1.270000e+02 : f32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 1.270000e+02 : f32 -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 23 : i32 -// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 127 : i32 -// CHECK-DAG: %[[VAL_17:.*]] = arith.cmpf uge, %[[VAL_0]], %[[VAL_11]] : f32 -// CHECK-DAG: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_0]], %[[VAL_11]] : f32 -// CHECK-DAG: %[[VAL_19:.*]] = arith.cmpf ule, %[[VAL_18]], %[[VAL_12]] : f32 -// CHECK-DAG: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_18]], %[[VAL_12]] : f32 -// CHECK-DAG: %[[VAL_21:.*]] = math.fma %[[VAL_20]], %[[VAL_3]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_22:.*]] = math.floor %[[VAL_21]] : f32 -// CHECK-DAG: %[[VAL_23:.*]] = arith.cmpf uge, %[[VAL_22]], %[[VAL_13]] : f32 -// CHECK-DAG: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[VAL_22]], %[[VAL_13]] : f32 -// CHECK-DAG: %[[VAL_25:.*]] = arith.cmpf ule, %[[VAL_24]], %[[VAL_14]] : f32 -// CHECK-DAG: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_14]] : f32 -// CHECK-DAG: %[[VAL_27:.*]] = math.fma %[[VAL_4]], %[[VAL_26]], %[[VAL_20]] : f32 -// CHECK-DAG: %[[VAL_28:.*]] = math.fma %[[VAL_5]], %[[VAL_26]], %[[VAL_27]] : f32 -// CHECK-DAG: %[[VAL_29:.*]] = math.fma %[[VAL_28]], %[[VAL_6]], %[[VAL_7]] : f32 -// CHECK-DAG: %[[VAL_30:.*]] = math.fma %[[VAL_29]], %[[VAL_28]], %[[VAL_8]] : f32 -// CHECK-DAG: %[[VAL_31:.*]] = math.fma %[[VAL_30]], %[[VAL_28]], %[[VAL_9]] : f32 -// CHECK-DAG: %[[VAL_32:.*]] = math.fma %[[VAL_31]], %[[VAL_28]], %[[VAL_10]] : f32 -// CHECK-DAG: %[[VAL_33:.*]] = math.fma %[[VAL_32]], %[[VAL_28]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_34:.*]] = arith.mulf %[[VAL_28]], %[[VAL_28]] : f32 -// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_33]], %[[VAL_34]], %[[VAL_28]] : f32 -// CHECK-DAG: %[[VAL_36:.*]] = arith.addf %[[VAL_35]], %[[VAL_2]] : f32 -// CHECK-DAG: %[[VAL_37:.*]] = arith.fptosi %[[VAL_26]] : f32 to i32 -// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_39:.*]] = arith.shli %[[VAL_38]], %[[VAL_15]] : i32 -// CHECK-DAG: %[[VAL_40:.*]] = arith.bitcast %[[VAL_39]] : i32 to f32 -// CHECK-DAG: %[[VAL_41:.*]] = arith.mulf %[[VAL_36]], %[[VAL_40]] : f32 -// CHECK: return %[[VAL_41]] : f32 -func.func @exp_scalar(%arg0: f32) -> f32 { - %0 = math.exp %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @exp_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK-NOT: math.exp -func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.exp %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// CHECK-LABEL: func @expm1_scalar( -// CHECK-SAME: %[[X:.*]]: f32) -> f32 { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1.000000e+00 : f32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1.44269502 : f32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -0.693359375 : f32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2.12194442E-4 : f32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1.98756912E-4 : f32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.00139819994 : f32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.00833345205 : f32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.0416657962 : f32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.166666657 : f32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant -8.780000e+01 : f32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 8.880000e+01 : f32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -1.270000e+02 : f32 -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1.270000e+02 : f32 -// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 23 : i32 -// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 127 : i32 -// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[VAL_19:.*]] = arith.constant -5.000000e-01 : f32 -// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 1.17549435E-38 : f32 -// CHECK-DAG: %[[VAL_21:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: %[[VAL_22:.*]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: %[[VAL_23:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-DAG: %[[VAL_24:.*]] = arith.constant 0.707106769 : f32 -// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 0.0703768358 : f32 -// CHECK-DAG: %[[VAL_26:.*]] = arith.constant -0.115146101 : f32 -// CHECK-DAG: %[[VAL_27:.*]] = arith.constant 0.116769984 : f32 -// CHECK-DAG: %[[VAL_28:.*]] = arith.constant -0.12420141 : f32 -// CHECK-DAG: %[[VAL_29:.*]] = arith.constant 0.142493233 : f32 -// CHECK-DAG: %[[VAL_30:.*]] = arith.constant -0.166680574 : f32 -// CHECK-DAG: %[[VAL_31:.*]] = arith.constant 0.200007141 : f32 -// CHECK-DAG: %[[VAL_32:.*]] = arith.constant -0.24999994 : f32 -// CHECK-DAG: %[[VAL_33:.*]] = arith.constant 0.333333313 : f32 -// CHECK-DAG: %[[VAL_34:.*]] = arith.constant 1.260000e+02 : f32 -// CHECK-DAG: %[[VAL_35:.*]] = arith.constant -2139095041 : i32 -// CHECK-DAG: %[[VAL_36:.*]] = arith.constant 1056964608 : i32 -// CHECK-DAG: %[[VAL_37:.*]] = arith.constant 0.693147182 : f32 -// CHECK-DAG: %[[VAL_38:.*]] = arith.cmpf uge, %[[X]], %[[VAL_12]] : f32 -// CHECK-DAG: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[X]], %[[VAL_12]] : f32 -// CHECK-DAG: %[[VAL_40:.*]] = arith.cmpf ule, %[[VAL_39]], %[[VAL_13]] : f32 -// CHECK-DAG: %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_13]] : f32 -// CHECK-DAG: %[[VAL_42:.*]] = math.fma %[[VAL_41]], %[[VAL_4]], %[[VAL_3]] : f32 -// CHECK-DAG: %[[VAL_43:.*]] = math.floor %[[VAL_42]] : f32 -// CHECK-DAG: %[[VAL_44:.*]] = arith.cmpf uge, %[[VAL_43]], %[[VAL_14]] : f32 -// CHECK-DAG: %[[VAL_45:.*]] = arith.select %[[VAL_44]], %[[VAL_43]], %[[VAL_14]] : f32 -// CHECK-DAG: %[[VAL_46:.*]] = arith.cmpf ule, %[[VAL_45]], %[[VAL_15]] : f32 -// CHECK-DAG: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_15]] : f32 -// CHECK-DAG: %[[VAL_48:.*]] = math.fma %[[VAL_5]], %[[VAL_47]], %[[VAL_41]] : f32 -// CHECK-DAG: %[[VAL_49:.*]] = math.fma %[[VAL_6]], %[[VAL_47]], %[[VAL_48]] : f32 -// CHECK-DAG: %[[VAL_50:.*]] = math.fma %[[VAL_49]], %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK-DAG: %[[VAL_51:.*]] = math.fma %[[VAL_50]], %[[VAL_49]], %[[VAL_9]] : f32 -// CHECK-DAG: %[[VAL_52:.*]] = math.fma %[[VAL_51]], %[[VAL_49]], %[[VAL_10]] : f32 -// CHECK-DAG: %[[VAL_53:.*]] = math.fma %[[VAL_52]], %[[VAL_49]], %[[VAL_11]] : f32 -// CHECK-DAG: %[[VAL_54:.*]] = math.fma %[[VAL_53]], %[[VAL_49]], %[[VAL_3]] : f32 -// CHECK-DAG: %[[VAL_55:.*]] = arith.mulf %[[VAL_49]], %[[VAL_49]] : f32 -// CHECK-DAG: %[[VAL_56:.*]] = math.fma %[[VAL_54]], %[[VAL_55]], %[[VAL_49]] : f32 -// CHECK-DAG: %[[VAL_57:.*]] = arith.addf %[[VAL_56]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_58:.*]] = arith.fptosi %[[VAL_47]] : f32 to i32 -// CHECK-DAG: %[[VAL_59:.*]] = arith.addi %[[VAL_58]], %[[VAL_17]] : i32 -// CHECK-DAG: %[[VAL_60:.*]] = arith.shli %[[VAL_59]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_61:.*]] = arith.bitcast %[[VAL_60]] : i32 to f32 -// CHECK-DAG: %[[VAL_62:.*]] = arith.mulf %[[VAL_57]], %[[VAL_61]] : f32 -// CHECK-DAG: %[[VAL_63:.*]] = arith.cmpf ueq, %[[VAL_62]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_64:.*]] = arith.subf %[[VAL_62]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_65:.*]] = arith.cmpf oeq, %[[VAL_64]], %[[VAL_2]] : f32 -// CHECK-DAG: %[[VAL_66:.*]] = arith.cmpf ugt, %[[VAL_62]], %[[VAL_20]] : f32 -// CHECK-DAG: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_62]], %[[VAL_20]] : f32 -// CHECK-DAG: %[[VAL_68:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32 -// CHECK-DAG: %[[VAL_69:.*]] = arith.andi %[[VAL_68]], %[[VAL_35]] : i32 -// CHECK-DAG: %[[VAL_70:.*]] = arith.ori %[[VAL_69]], %[[VAL_36]] : i32 -// CHECK-DAG: %[[VAL_71:.*]] = arith.bitcast %[[VAL_70]] : i32 to f32 -// CHECK-DAG: %[[VAL_72:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32 -// CHECK-DAG: %[[VAL_73:.*]] = arith.shrui %[[VAL_72]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_74:.*]] = arith.sitofp %[[VAL_73]] : i32 to f32 -// CHECK-DAG: %[[VAL_75:.*]] = arith.subf %[[VAL_74]], %[[VAL_34]] : f32 -// CHECK-DAG: %[[VAL_76:.*]] = arith.cmpf olt, %[[VAL_71]], %[[VAL_24]] : f32 -// CHECK-DAG: %[[VAL_77:.*]] = arith.select %[[VAL_76]], %[[VAL_71]], %[[VAL_18]] : f32 -// CHECK-DAG: %[[VAL_78:.*]] = arith.subf %[[VAL_71]], %[[VAL_1]] : f32 -// CHECK-DAG: %[[VAL_79:.*]] = arith.select %[[VAL_76]], %[[VAL_1]], %[[VAL_18]] : f32 -// CHECK-DAG: %[[VAL_80:.*]] = arith.subf %[[VAL_75]], %[[VAL_79]] : f32 -// CHECK-DAG: %[[VAL_81:.*]] = arith.addf %[[VAL_78]], %[[VAL_77]] : f32 -// CHECK-DAG: %[[VAL_82:.*]] = arith.mulf %[[VAL_81]], %[[VAL_81]] : f32 -// CHECK-DAG: %[[VAL_83:.*]] = arith.mulf %[[VAL_82]], %[[VAL_81]] : f32 -// CHECK-DAG: %[[VAL_84:.*]] = math.fma %[[VAL_25]], %[[VAL_81]], %[[VAL_26]] : f32 -// CHECK-DAG: %[[VAL_85:.*]] = math.fma %[[VAL_28]], %[[VAL_81]], %[[VAL_29]] : f32 -// CHECK-DAG: %[[VAL_86:.*]] = math.fma %[[VAL_31]], %[[VAL_81]], %[[VAL_32]] : f32 -// CHECK-DAG: %[[VAL_87:.*]] = math.fma %[[VAL_84]], %[[VAL_81]], %[[VAL_27]] : f32 -// CHECK-DAG: %[[VAL_88:.*]] = math.fma %[[VAL_85]], %[[VAL_81]], %[[VAL_30]] : f32 -// CHECK-DAG: %[[VAL_89:.*]] = math.fma %[[VAL_86]], %[[VAL_81]], %[[VAL_33]] : f32 -// CHECK-DAG: %[[VAL_90:.*]] = math.fma %[[VAL_87]], %[[VAL_83]], %[[VAL_88]] : f32 -// CHECK-DAG: %[[VAL_91:.*]] = math.fma %[[VAL_90]], %[[VAL_83]], %[[VAL_89]] : f32 -// CHECK-DAG: %[[VAL_92:.*]] = arith.mulf %[[VAL_91]], %[[VAL_83]] : f32 -// CHECK-DAG: %[[VAL_93:.*]] = math.fma %[[VAL_19]], %[[VAL_82]], %[[VAL_92]] : f32 -// CHECK-DAG: %[[VAL_94:.*]] = arith.addf %[[VAL_81]], %[[VAL_93]] : f32 -// CHECK-DAG: %[[VAL_95:.*]] = math.fma %[[VAL_80]], %[[VAL_37]], %[[VAL_94]] : f32 -// CHECK-DAG: %[[VAL_96:.*]] = arith.cmpf ult, %[[VAL_62]], %[[VAL_18]] : f32 -// CHECK-DAG: %[[VAL_97:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_18]] : f32 -// CHECK-DAG: %[[VAL_98:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_22]] : f32 -// CHECK-DAG: %[[VAL_99:.*]] = arith.select %[[VAL_98]], %[[VAL_22]], %[[VAL_95]] : f32 -// CHECK-DAG: %[[VAL_100:.*]] = arith.select %[[VAL_96]], %[[VAL_23]], %[[VAL_99]] : f32 -// CHECK-DAG: %[[VAL_101:.*]] = arith.select %[[VAL_97]], %[[VAL_21]], %[[VAL_100]] : f32 -// CHECK-DAG: %[[VAL_102:.*]] = arith.cmpf oeq, %[[VAL_101]], %[[VAL_62]] : f32 -// CHECK-DAG: %[[VAL_103:.*]] = arith.divf %[[X]], %[[VAL_101]] : f32 -// CHECK-DAG: %[[VAL_104:.*]] = arith.mulf %[[VAL_64]], %[[VAL_103]] : f32 -// CHECK-DAG: %[[VAL_105:.*]] = arith.select %[[VAL_102]], %[[VAL_62]], %[[VAL_104]] : f32 -// CHECK-DAG: %[[VAL_106:.*]] = arith.select %[[VAL_65]], %[[VAL_2]], %[[VAL_105]] : f32 -// CHECK-DAG: %[[VAL_107:.*]] = arith.select %[[VAL_63]], %[[X]], %[[VAL_106]] : f32 -// CHECK-DAG: return %[[VAL_107]] : f32 -// CHECK: } -func.func @expm1_scalar(%arg0: f32) -> f32 { - %0 = math.expm1 %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @expm1_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> { -// CHECK-NOT: exp -// CHECK-NOT: log -// CHECK-NOT: expm1 -func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { - %0 = math.expm1 %arg0 : vector<8x8xf32> - func.return %0 : vector<8x8xf32> -} -// CHECK-LABEL: func @log_scalar( -// CHECK-SAME: %[[X:.*]]: f32) -> f32 { -// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[VAL_3:.*]] = arith.constant -5.000000e-01 : f32 -// CHECK: %[[VAL_4:.*]] = arith.constant 1.17549435E-38 : f32 -// CHECK: %[[VAL_5:.*]] = arith.constant 0xFF800000 : f32 -// CHECK: %[[VAL_6:.*]] = arith.constant 0x7F800000 : f32 -// CHECK: %[[VAL_7:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK: %[[VAL_8:.*]] = arith.constant 0.707106769 : f32 -// CHECK: %[[VAL_9:.*]] = arith.constant 0.0703768358 : f32 -// CHECK: %[[VAL_10:.*]] = arith.constant -0.115146101 : f32 -// CHECK: %[[VAL_11:.*]] = arith.constant 0.116769984 : f32 -// CHECK: %[[VAL_12:.*]] = arith.constant -0.12420141 : f32 -// CHECK: %[[VAL_13:.*]] = arith.constant 0.142493233 : f32 -// CHECK: %[[VAL_14:.*]] = arith.constant -0.166680574 : f32 -// CHECK: %[[VAL_15:.*]] = arith.constant 0.200007141 : f32 -// CHECK: %[[VAL_16:.*]] = arith.constant -0.24999994 : f32 -// CHECK: %[[VAL_17:.*]] = arith.constant 0.333333313 : f32 -// CHECK: %[[VAL_18:.*]] = arith.constant 1.260000e+02 : f32 -// CHECK: %[[VAL_19:.*]] = arith.constant -2139095041 : i32 -// CHECK: %[[VAL_20:.*]] = arith.constant 1056964608 : i32 -// CHECK: %[[VAL_21:.*]] = arith.constant 23 : i32 -// CHECK: %[[VAL_22:.*]] = arith.constant 0.693147182 : f32 -// CHECK: %[[VAL_23:.*]] = arith.cmpf ugt, %[[X]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32 -// CHECK-NOT: frexp -// CHECK: %[[VAL_25:.*]] = arith.bitcast %[[VAL_24]] : f32 to i32 -// CHECK: %[[VAL_26:.*]] = arith.andi %[[VAL_25]], %[[VAL_19]] : i32 -// CHECK: %[[VAL_27:.*]] = arith.ori %[[VAL_26]], %[[VAL_20]] : i32 -// CHECK: %[[VAL_28:.*]] = arith.bitcast %[[VAL_27]] : i32 to f32 -// CHECK: %[[VAL_29:.*]] = arith.bitcast %[[VAL_24]] : f32 to i32 -// CHECK: %[[VAL_30:.*]] = arith.shrui %[[VAL_29]], %[[VAL_21]] : i32 -// CHECK: %[[VAL_31:.*]] = arith.sitofp %[[VAL_30]] : i32 to f32 -// CHECK: %[[VAL_32:.*]] = arith.subf %[[VAL_31]], %[[VAL_18]] : f32 -// CHECK: %[[VAL_33:.*]] = arith.cmpf olt, %[[VAL_28]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_33]], %[[VAL_28]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_35:.*]] = arith.subf %[[VAL_28]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_33]], %[[VAL_2]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_37:.*]] = arith.subf %[[VAL_32]], %[[VAL_36]] : f32 -// CHECK: %[[VAL_38:.*]] = arith.addf %[[VAL_35]], %[[VAL_34]] : f32 -// CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_40:.*]] = arith.mulf %[[VAL_39]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_41:.*]] = math.fma %[[VAL_9]], %[[VAL_38]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_42:.*]] = math.fma %[[VAL_12]], %[[VAL_38]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_43:.*]] = math.fma %[[VAL_15]], %[[VAL_38]], %[[VAL_16]] : f32 -// CHECK: %[[VAL_44:.*]] = math.fma %[[VAL_41]], %[[VAL_38]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_45:.*]] = math.fma %[[VAL_42]], %[[VAL_38]], %[[VAL_14]] : f32 -// CHECK: %[[VAL_46:.*]] = math.fma %[[VAL_43]], %[[VAL_38]], %[[VAL_17]] : f32 -// CHECK: %[[VAL_47:.*]] = math.fma %[[VAL_44]], %[[VAL_40]], %[[VAL_45]] : f32 -// CHECK: %[[VAL_48:.*]] = math.fma %[[VAL_47]], %[[VAL_40]], %[[VAL_46]] : f32 -// CHECK: %[[VAL_49:.*]] = arith.mulf %[[VAL_48]], %[[VAL_40]] : f32 -// CHECK: %[[VAL_50:.*]] = math.fma %[[VAL_3]], %[[VAL_39]], %[[VAL_49]] : f32 -// CHECK: %[[VAL_51:.*]] = arith.addf %[[VAL_38]], %[[VAL_50]] : f32 -// CHECK: %[[VAL_52:.*]] = math.fma %[[VAL_37]], %[[VAL_22]], %[[VAL_51]] : f32 -// CHECK: %[[VAL_53:.*]] = arith.cmpf ult, %[[X]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_54:.*]] = arith.cmpf oeq, %[[X]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_55:.*]] = arith.cmpf oeq, %[[X]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_6]], %[[VAL_52]] : f32 -// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_53]], %[[VAL_7]], %[[VAL_56]] : f32 -// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_54]], %[[VAL_5]], %[[VAL_57]] : f32 -// CHECK: return %[[VAL_58]] : f32 -// CHECK: } -func.func @log_scalar(%arg0: f32) -> f32 { - %0 = math.log %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @log_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<8xf32> -// CHECK-COUNT-5: select -// CHECK: %[[VAL_71:.*]] = arith.select -// CHECK: return %[[VAL_71]] : vector<8xf32> -// CHECK: } -func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.log %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// CHECK-LABEL: func @log2_scalar( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { -// CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32 -// CHECK-COUNT-5: select -// CHECK: %[[VAL_65:.*]] = arith.select -// CHECK: return %[[VAL_65]] : f32 -// CHECK: } -func.func @log2_scalar(%arg0: f32) -> f32 { - %0 = math.log2 %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @log2_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<8xf32> -// CHECK-COUNT-5: select -// CHECK: %[[VAL_71:.*]] = arith.select -// CHECK: return %[[VAL_71]] : vector<8xf32> -// CHECK: } -func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.log2 %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// CHECK-LABEL: func @log1p_scalar( -// CHECK-SAME: %[[X:.*]]: f32) -> f32 { -// CHECK: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[U:.*]] = arith.addf %[[X]], %[[CST_ONE]] : f32 -// CHECK: %[[U_SMALL:.*]] = arith.cmpf oeq, %[[U]], %[[CST_ONE]] : f32 -// CHECK-NOT: log -// CHECK-COUNT-5: select -// CHECK: %[[LOG_U:.*]] = arith.select -// CHECK: %[[U_INF:.*]] = arith.cmpf oeq, %[[U]], %[[LOG_U]] : f32 -// CHECK: %[[VAL_69:.*]] = arith.subf %[[U]], %[[CST_ONE]] : f32 -// CHECK: %[[VAL_70:.*]] = arith.divf %[[LOG_U]], %[[VAL_69]] : f32 -// CHECK: %[[LOG_LARGE:.*]] = arith.mulf %[[X]], %[[VAL_70]] : f32 -// CHECK: %[[VAL_72:.*]] = arith.ori %[[U_SMALL]], %[[U_INF]] : i1 -// CHECK: %[[APPROX:.*]] = arith.select %[[VAL_72]], %[[X]], %[[LOG_LARGE]] : f32 -// CHECK: return %[[APPROX]] : f32 -// CHECK: } -func.func @log1p_scalar(%arg0: f32) -> f32 { - %0 = math.log1p %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @log1p_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-COUNT-6: select -// CHECK: %[[VAL_79:.*]] = arith.select -// CHECK: return %[[VAL_79]] : vector<8xf32> -// CHECK: } -func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.log1p %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// CHECK-LABEL: func @tanh_scalar( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { -// CHECK: %[[VAL_1:.*]] = arith.constant -7.99881172 : f32 -// CHECK: %[[VAL_2:.*]] = arith.constant 7.99881172 : f32 -// CHECK: %[[VAL_3:.*]] = arith.constant 4.000000e-04 : f32 -// CHECK: %[[VAL_4:.*]] = arith.constant 0.00489352457 : f32 -// CHECK: %[[VAL_5:.*]] = arith.constant 6.37261954E-4 : f32 -// CHECK: %[[VAL_6:.*]] = arith.constant 1.48572235E-5 : f32 -// CHECK: %[[VAL_7:.*]] = arith.constant 5.12229725E-8 : f32 -// CHECK: %[[VAL_8:.*]] = arith.constant -8.60467184E-11 : f32 -// CHECK: %[[VAL_9:.*]] = arith.constant 2.00018794E-13 : f32 -// CHECK: %[[VAL_10:.*]] = arith.constant -2.76076837E-16 : f32 -// CHECK: %[[VAL_11:.*]] = arith.constant 0.00489352504 : f32 -// CHECK: %[[VAL_12:.*]] = arith.constant 0.00226843474 : f32 -// CHECK: %[[VAL_13:.*]] = arith.constant 1.18534706E-4 : f32 -// CHECK: %[[VAL_14:.*]] = arith.constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_15:.*]] = arith.cmpf ult, %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_17:.*]] = arith.cmpf ugt, %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_19:.*]] = math.absf %[[VAL_0]] : f32 -// CHECK: %[[VAL_20:.*]] = arith.cmpf olt, %[[VAL_19]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_18]], %[[VAL_18]] : f32 -// CHECK: %[[VAL_22:.*]] = math.fma %[[VAL_21]], %[[VAL_10]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_23:.*]] = math.fma %[[VAL_21]], %[[VAL_22]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_24:.*]] = math.fma %[[VAL_21]], %[[VAL_23]], %[[VAL_7]] : f32 -// CHECK: %[[VAL_25:.*]] = math.fma %[[VAL_21]], %[[VAL_24]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_26:.*]] = math.fma %[[VAL_21]], %[[VAL_25]], %[[VAL_5]] : f32 -// CHECK: %[[VAL_27:.*]] = math.fma %[[VAL_21]], %[[VAL_26]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_18]], %[[VAL_27]] : f32 -// CHECK: %[[VAL_29:.*]] = math.fma %[[VAL_21]], %[[VAL_14]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_30:.*]] = math.fma %[[VAL_21]], %[[VAL_29]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_31:.*]] = math.fma %[[VAL_21]], %[[VAL_30]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_32:.*]] = arith.divf %[[VAL_28]], %[[VAL_31]] : f32 -// CHECK: %[[VAL_33:.*]] = arith.select %[[VAL_20]], %[[VAL_18]], %[[VAL_32]] : f32 -// CHECK: return %[[VAL_33]] : f32 -// CHECK: } -func.func @tanh_scalar(%arg0: f32) -> f32 { - %0 = math.tanh %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @tanh_vector( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<8xf32> -// CHECK-NOT: tanh -// CHECK-COUNT-2: select -// CHECK: %[[VAL_33:.*]] = arith.select -// CHECK: return %[[VAL_33]] : vector<8xf32> -// CHECK: } -func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.tanh %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// We only approximate rsqrt for vectors and when the AVX2 option is enabled. -// CHECK-LABEL: func @rsqrt_scalar -// AVX2-LABEL: func @rsqrt_scalar -// CHECK: math.rsqrt -// AVX2: math.rsqrt -func.func @rsqrt_scalar(%arg0: f32) -> f32 { - %0 = math.rsqrt %arg0 : f32 - func.return %0 : f32 -} -// CHECK-LABEL: func @rsqrt_vector_8xf32 -// CHECK: math.rsqrt -// AVX2-LABEL: func @rsqrt_vector_8xf32( -// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32> -// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32> -// AVX2: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32> -// AVX2: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32> -// AVX2: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32> -// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32> -// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32> -// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1> -// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32> -// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32> -// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32> -// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32> -// AVX2: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32> -// AVX2: return %[[VAL_13]] : vector<8xf32> -// AVX2: } -func.func @rsqrt_vector_8xf32(%arg0: vector<8xf32>) -> vector<8xf32> { - %0 = math.rsqrt %arg0 : vector<8xf32> - func.return %0 : vector<8xf32> -} -// Virtual vector width is not a multiple of an AVX2 vector width. -// -// CHECK-LABEL: func @rsqrt_vector_5xf32 -// CHECK: math.rsqrt -// AVX2-LABEL: func @rsqrt_vector_5xf32 -// AVX2: math.rsqrt -func.func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> { - %0 = math.rsqrt %arg0 : vector<5xf32> - func.return %0 : vector<5xf32> -} -// One dimensional virtual vector expanded and unrolled into multiple AVX2-sized -// vectors. -// -// CHECK-LABEL: func @rsqrt_vector_16xf32 -// CHECK: math.rsqrt -// AVX2-LABEL: func @rsqrt_vector_16xf32( -// AVX2-SAME: %[[ARG:.*]]: vector<16xf32> -// AVX2-SAME: ) -> vector<16xf32> -// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> -// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32> -// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0] -// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] -// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1] -// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] -// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] -// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] -// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32> -func.func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = math.rsqrt %arg0 : vector<16xf32> - func.return %0 : vector<16xf32> -} -// Two dimensional virtual vector unrolled into multiple AVX2-sized vectors. -// -// CHECK-LABEL: func @rsqrt_vector_2x8xf32 -// CHECK: math.rsqrt -// AVX2-LABEL: func @rsqrt_vector_2x8xf32( -// AVX2-SAME: %[[ARG:.*]]: vector<2x8xf32> -// AVX2-SAME: ) -> vector<2x8xf32> -// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> -// AVX2-NOT: vector.shape_cast -// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0] -// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] -// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1] -// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] -// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] -// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] -// AVX2-NOT: vector.shape_cast -func.func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> { - %0 = math.rsqrt %arg0 : vector<2x8xf32> - func.return %0 : vector<2x8xf32> -} -// Two dimensional virtual vector expanded and unrolled into multiple AVX2-sized -// vectors. -// -// CHECK-LABEL: func @rsqrt_vector_2x16xf32 -// CHECK: math.rsqrt -// AVX2-LABEL: func @rsqrt_vector_2x16xf32( -// AVX2-SAME: %[[ARG:.*]]: vector<2x16xf32> -// AVX2-SAME: ) -> vector<2x16xf32> -// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32> -// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32> -// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0] -// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]] -// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1] -// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]] -// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0] -// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]] -// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1] -// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]] -// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0] -// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1] -// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0] -// AVX2: %[[RESULT3:.*]] = vector.insert %[[RSQRT11]], %[[RESULT2]] [1, 1] -// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT3]] : vector<2x2x8xf32> to vector<2x16xf32> -func.func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> { - %0 = math.rsqrt %arg0 : vector<2x16xf32> - func.return %0 : vector<2x16xf32> -} diff --git a/xla/mlir/memref/BUILD b/xla/mlir/memref/BUILD index 0fbd61df0678f..1bb06251af0dc 100644 --- a/xla/mlir/memref/BUILD +++ b/xla/mlir/memref/BUILD @@ -5,6 +5,7 @@ package_group( # copybara:uncomment_begin(google-only) # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. # "//third_party/tf_runtime/...", + # "//third_party/py/enzyme_ad/...", # copybara:uncomment_end(google-only) ], ) diff --git a/xla/mlir/memref/transforms/BUILD b/xla/mlir/memref/transforms/BUILD index 7ada7e5a3645b..eea3a3123dfd9 100644 --- a/xla/mlir/memref/transforms/BUILD +++ b/xla/mlir/memref/transforms/BUILD @@ -1,6 +1,6 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -32,10 +32,9 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", ], ) diff --git a/xla/mlir/memref/transforms/aligned_allocations.cc b/xla/mlir/memref/transforms/aligned_allocations.cc index 875e8f606514b..8cee23aec11e3 100644 --- a/xla/mlir/memref/transforms/aligned_allocations.cc +++ b/xla/mlir/memref/transforms/aligned_allocations.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "xla/mlir/memref/transforms/passes.h" diff --git a/xla/mlir/memref/transforms/passes.h b/xla/mlir/memref/transforms/passes.h index da79f1c16400f..05b6f5fec2aeb 100644 --- a/xla/mlir/memref/transforms/passes.h +++ b/xla/mlir/memref/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/memref/transforms/passes.td b/xla/mlir/memref/transforms/passes.td index cf657aafa4959..8310e55556294 100644 --- a/xla/mlir/memref/transforms/passes.td +++ b/xla/mlir/memref/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/memref/transforms/tests/BUILD b/xla/mlir/memref/transforms/tests/BUILD deleted file mode 100644 index 8f64aacfb4252..0000000000000 --- a/xla/mlir/memref/transforms/tests/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/runtime:xla-runtime-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) diff --git a/xla/mlir/memref/transforms/tests/aligned_allocations.mlir b/xla/mlir/memref/transforms/tests/aligned_allocations.mlir deleted file mode 100644 index ee024f42d2de6..0000000000000 --- a/xla/mlir/memref/transforms/tests/aligned_allocations.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-memref-aligned-allocations \ -// RUN: | FileCheck %s - -// RUN: xla-runtime-opt %s --xla-memref-aligned-allocations=alignment=16 \ -// RUN: | FileCheck --check-prefix=ALIGN16 %s - -// CHECK-LABEL: @aligned_alloc -// ALIGN16-LABEL: @aligned_alloc -func.func @aligned_alloc(%arg0: index) -> memref { - // CHECK: %[[ALLOC:.*]] = memref.alloc(%arg0) {alignment = 64 : i64} - // ALIGN16: %[[ALLOC:.*]] = memref.alloc(%arg0) {alignment = 32 : i64} - %0 = memref.alloc(%arg0) { alignment = 32 : i64 } : memref - // CHECK: return %[[ALLOC]] - // ALIGN16: return %[[ALLOC]] - return %0 : memref -} - -// CHECK-LABEL: @unaligned_alloc -// ALIGN16-LABEL: @unaligned_alloc -func.func @unaligned_alloc(%arg0: index) -> memref { - // CHECK: %[[ALLOC:.*]] = memref.alloc(%arg0) {alignment = 64 : i64} - // ALIGN16: %[[ALLOC:.*]] = memref.alloc(%arg0) {alignment = 16 : i64} - %0 = memref.alloc(%arg0) : memref - // CHECK: return %[[ALLOC]] - // ALIGN16: return %[[ALLOC]] - return %0 : memref -} - diff --git a/xla/mlir/runtime/BUILD b/xla/mlir/runtime/BUILD index 44eb131a5cccc..a4687f8a4044f 100644 --- a/xla/mlir/runtime/BUILD +++ b/xla/mlir/runtime/BUILD @@ -1,7 +1,3 @@ -load("//xla:xla.bzl", "xla_cc_binary") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@bazel_skylib//rules:build_test.bzl", "build_test") - package_group( name = "friends", packages = [ @@ -18,6 +14,7 @@ package_group( # TODO(ezhulenev): All targets depending on mlir must be under xla/mlir folder "//xla/service/cpu/...", "//xla/service/gpu/...", + "//third_party/py/enzyme_ad/...", ], ) @@ -26,30 +23,3 @@ package( default_visibility = [":friends"], licenses = ["notice"], ) - -build_test( - name = "xla-runtime-opt_build_test", - targets = [ - ":xla-runtime-opt", - ], -) - -xla_cc_binary( - name = "xla-runtime-opt", - srcs = ["xla-runtime-opt.cc"], - compatible_with = get_compatible_with_portable(), - deps = [ - "//xla/mlir/math/transforms:passes", - "//xla/mlir/memref/transforms:passes", - "//xla/mlir/runtime/ir/tests:testlib", - "//xla/mlir/runtime/transforms:compilation_pipeline_cpu", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/mlir/runtime/transforms:passes", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MlirOptLib", - ], -) diff --git a/xla/mlir/runtime/ir/BUILD b/xla/mlir/runtime/ir/BUILD index 8365d7727bfb7..9c0d4e8b4f9c2 100644 --- a/xla/mlir/runtime/ir/BUILD +++ b/xla/mlir/runtime/ir/BUILD @@ -1,6 +1,6 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -107,6 +107,5 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", ], ) diff --git a/xla/mlir/runtime/ir/rt_dialect.cc b/xla/mlir/runtime/ir/rt_dialect.cc index 6f1463bc0de3a..45b52b3f7f0b1 100644 --- a/xla/mlir/runtime/ir/rt_dialect.cc +++ b/xla/mlir/runtime/ir/rt_dialect.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_dialect.h b/xla/mlir/runtime/ir/rt_dialect.h index 5a690aacba7ac..d13a536e57599 100644 --- a/xla/mlir/runtime/ir/rt_dialect.h +++ b/xla/mlir/runtime/ir/rt_dialect.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_dialect.td b/xla/mlir/runtime/ir/rt_dialect.td index 84cef0274448c..e2a5baa842023 100644 --- a/xla/mlir/runtime/ir/rt_dialect.td +++ b/xla/mlir/runtime/ir/rt_dialect.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_interfaces.cc b/xla/mlir/runtime/ir/rt_interfaces.cc index 7fb51f29dd7fa..611b341b2bad1 100644 --- a/xla/mlir/runtime/ir/rt_interfaces.cc +++ b/xla/mlir/runtime/ir/rt_interfaces.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_interfaces.h b/xla/mlir/runtime/ir/rt_interfaces.h index 8b310a6f6bc97..6a58a9578f9db 100644 --- a/xla/mlir/runtime/ir/rt_interfaces.h +++ b/xla/mlir/runtime/ir/rt_interfaces.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_interfaces.td b/xla/mlir/runtime/ir/rt_interfaces.td index ac053f8460429..790a67eae249d 100644 --- a/xla/mlir/runtime/ir/rt_interfaces.td +++ b/xla/mlir/runtime/ir/rt_interfaces.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_ops.cc b/xla/mlir/runtime/ir/rt_ops.cc index a8dfe04f2c17d..3440ddb9dd14a 100644 --- a/xla/mlir/runtime/ir/rt_ops.cc +++ b/xla/mlir/runtime/ir/rt_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_ops.h b/xla/mlir/runtime/ir/rt_ops.h index d15d5c5e15e10..3bc455a15f9f1 100644 --- a/xla/mlir/runtime/ir/rt_ops.h +++ b/xla/mlir/runtime/ir/rt_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/rt_ops.td b/xla/mlir/runtime/ir/rt_ops.td index 2cc3987b3021f..eee5e09250fb8 100644 --- a/xla/mlir/runtime/ir/rt_ops.td +++ b/xla/mlir/runtime/ir/rt_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/ir/tests/BUILD b/xla/mlir/runtime/ir/tests/BUILD deleted file mode 100644 index c95a1b26d209c..0000000000000 --- a/xla/mlir/runtime/ir/tests/BUILD +++ /dev/null @@ -1,94 +0,0 @@ -load("//xla:glob_lit_test.bzl", "glob_lit_tests") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") - -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/runtime:xla-runtime-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) - -td_library( - name = "testlib_td_files", - srcs = [ - "testlib.td", - "testlib_attrs.td", - "testlib_enums.td", - "testlib_types.td", - ], - compatible_with = get_compatible_with_portable(), - includes = ["include"], - deps = ["@llvm-project//mlir:OpBaseTdFiles"], -) - -gentbl_cc_library( - name = "testlib_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "testlib_attrs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "testlib_attrs.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "testlib_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "testlib_dialect.cc.inc", - ), - ( - ["-gen-enum-decls"], - "testlib_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "testlib_enums.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "testlib_types.h.inc", - ), - ( - ["-gen-typedef-defs"], - "testlib_types.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "testlib.td", - deps = [":testlib_td_files"], -) - -cc_library( - name = "testlib", - srcs = ["testlib.cc"], - hdrs = ["testlib.h"], - compatible_with = get_compatible_with_portable(), - visibility = ["//xla/mlir/runtime:friends"], - deps = [ - ":testlib_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/xla/mlir/runtime/ir/tests/ops.mlir b/xla/mlir/runtime/ir/tests/ops.mlir deleted file mode 100644 index 36387c24a5e31..0000000000000 --- a/xla/mlir/runtime/ir/tests/ops.mlir +++ /dev/null @@ -1,83 +0,0 @@ -// RUN: xla-runtime-opt %s | FileCheck %s - -// CHECK: rt.export @pass_context -rt.export @pass_context - -// CHECK-LABEL: func @pass_context( -// CHECK: %[[CTX:.*]]: !rt.execution_context -func.func @pass_context(%arg0: !rt.execution_context) { - return -} - -// CHECK: rt.export @set_output ordinal 42 -rt.export @set_output ordinal 42 - -// CHECK-LABEL: func @set_output( -// CHECK: %[[CTX:.*]]: !rt.execution_context -func.func @set_output(%arg0: !rt.execution_context) { - // CHECK: %[[MEMREF:.*]] = memref.alloc - %0 = memref.alloc() : memref - // CHECK: rt.set_output %[[CTX]], 0, %[[MEMREF]] - rt.set_output %arg0, 0, %0 : memref - return -} - -// CHECK-LABEL: func @set_error( -// CHECK: %[[CTX:.*]]: !rt.execution_context -func.func @set_error(%arg0: !rt.execution_context) { - // CHECK: rt.set_error %[[CTX]], "Failed precondition" - rt.set_error %arg0, "Failed precondition" - return -} - -// CHECK-LABEL: func @custom_call( -// CHECK: %[[CTX:.*]]: !rt.execution_context -// CHECK: %[[MEMREF:.*]]: memref -func.func @custom_call(%ctx: !rt.execution_context, - %input: memref) -> f32 { - // CHECK: rt.call %[[CTX]]["f32_reduce"] (%[[MEMREF]]) - // CHECK-SAME: : (memref) -> f32 - %status, %0 = rt.call %ctx["f32_reduce"] (%input) : (memref) -> f32 - %ok = rt.is_ok %status - cf.assert %ok, "failed to call custom call" - return %0 : f32 -} - -// CHECK-LABEL: func @dynamic_custom_call( -// CHECK: %[[CTX:.*]]: !rt.execution_context -func.func @dynamic_custom_call(%ctx: !rt.execution_context) { - // CHECK: rt.call dynamic %[[CTX]]["f32_reduce"] () : () -> () - %status = rt.call dynamic %ctx["f32_reduce"] () : () -> () - return -} - -// CHECK-LABEL: func @opaque_arg( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: !rt.opaque -// CHECK: ) -> !rt.opaque -func.func @opaque_arg(%ctx: !rt.execution_context, - %arg0: !rt.opaque) -> !rt.opaque { - // CHECK: rt.call %[[CTX]]["test"] - // CHECK-SAME: (%[[ARG]]) : (!rt.opaque) -> !rt.opaque - %status, %result = rt.call %ctx["test"] (%arg0) : (!rt.opaque) -> (!rt.opaque) - return %result : !rt.opaque -} - -// CHECK-LABEL: func @trace( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) -> memref -func.func @trace(%ctx: !rt.execution_context, - %arg: memref) -> memref { - // CHECK: rt.trace #rt.hlo_trace<"fusion">, %[[CTX]] - rt.trace #rt.hlo_trace<"fusion">, %ctx {} - - // CHECK: rt.trace #rt.hlo_trace<"fusion"> - // CHECK-SAME: %[[CTX]] -> memref - // CHECK-NEXT: yield %[[ARG]] : memref - %0 = rt.trace #rt.hlo_trace<"fusion">, %ctx -> memref { - yield %arg : memref - } - - return %0 : memref -} diff --git a/xla/mlir/runtime/ir/tests/ops_verify.mlir b/xla/mlir/runtime/ir/tests/ops_verify.mlir deleted file mode 100644 index 4fb5231582aec..0000000000000 --- a/xla/mlir/runtime/ir/tests/ops_verify.mlir +++ /dev/null @@ -1,83 +0,0 @@ -// RUN: xla-runtime-opt -verify-diagnostics -split-input-file %s - -// ----- -// expected-error @+1 {{func op named 'foo' not found for export}} -rt.export @foo - -// ----- -// expected-error @+1 {{'func.func' op requires "rt.exported" to be an integer attribute}} -func.func private @verify_rt_exported(%arg0: memref) - attributes { rt.exported } { - call @custom_call(%arg0) : (memref) -> () - return -} - -// ----- -func.func private @verify_exported_non_func(%arg0: memref) { - // expected-error @+1 {{"rt.exported" can only be applied to a function}} - call @custom_call(%arg0) { rt.exported = 0 : i32}: (memref) -> () - return -} - -// ----- -func.func private @verify_exported_non_func(%arg0: memref) { - // expected-error @+1 {{"rt.dynamic" can only be applied to a custom call declaration}} - call @custom_call(%arg0) {rt.dynamic}: (memref) -> () - return -} - -// ----- -// expected-error @+1 {{'func.func' op requires non-empty body for function with attribute "rt.exported"}} -func.func private @verify_rt_exported(%arg0: memref) - attributes { rt.exported = 0 : i32 } - - -// ----- -// expected-error @+1 {{'func.func' op requires "rt.custom_call" to only accept string value}} -func.func private @custom_call(%arg0: memref) -> memref - attributes { rt.custom_call = 1, attr0 = 1 : i32, attr1 = 1.0 : f32 } - - -// ----- -// expected-error @+1 {{'func.func' op requires "rt.custom_call" to only accept string value}} -func.func private @custom_call(%arg0: memref) -> memref - attributes { rt.custom_call = 1, attr0 = 1 : i32, attr1 = 1.0 : f32 } - - -// ----- -// expected-error @+1 {{'func.func' op requires "rt.custom_call" to only apply to a function declaration}} -func.func private @custom_call(%arg0: memref) -> memref - attributes { rt.custom_call = "target", attr0 = 1 : i32, attr1 = 1.0 : f32 }{ - call @custom_call(%arg0) : (memref) -> () - return -} - -// ----- -func.func private @verify_custom_call_non_func(%arg0: memref) - attributes { rt.exported = 0 : i32 } { - // expected-error @+1 {{"rt.custom_call" can only be applied to a function}} - call @custom_call(%arg0) {rt.custom_call = "target"}: (memref) -> () - return -} - -// ----- -// expected-error @+1 {{'func.func' op has illegal attribute value of rt.constraint for argument 0}} -func.func private @constraint( - %input0: memref<*xf32> { rt.constraint = "test" }, - %input1: memref { rt.constraint = "shape" }, - %perm: memref<4xi32> { rt.constraint = "value" } -) attributes {rt.custom_call = "target"} - -// ----- -func.func @trace(%ctx: !rt.execution_context) { - // expected-error @+1 {{'rt.trace' invalid kind of attribute specified}} - rt.trace "string attribute", %ctx {} - return -} - -// ----- -func.func @trace_attribute(%ctx: !rt.execution_context) { - // expected-error @+1 {{"rt.trace" to be a trace annotation attribute}} - call @custom_call() { rt.trace = "foo" } : () -> () - return -} diff --git a/xla/mlir/runtime/ir/tests/testlib.cc b/xla/mlir/runtime/ir/tests/testlib.cc deleted file mode 100644 index e3bee7378d3c2..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib.cc +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/runtime/ir/tests/testlib.h" - -#include "llvm/ADT/TypeSwitch.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/DialectImplementation.h" // from @llvm-project - -// clang-format off -#include "xla/mlir/runtime/ir/tests/testlib_dialect.cc.inc" -#include "xla/mlir/runtime/ir/tests/testlib_enums.cc.inc" -// clang-format on - -#define GET_ATTRDEF_CLASSES -#include "xla/mlir/runtime/ir/tests/testlib_attrs.cc.inc" - -#define GET_TYPEDEF_CLASSES -#include "xla/mlir/runtime/ir/tests/testlib_types.cc.inc" - -namespace xla { -namespace runtime { - -void TestlibDialect::initialize() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "xla/mlir/runtime/ir/tests/testlib_attrs.cc.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "xla/mlir/runtime/ir/tests/testlib_types.cc.inc" - >(); -} - -} // namespace runtime -} // namespace xla diff --git a/xla/mlir/runtime/ir/tests/testlib.h b/xla/mlir/runtime/ir/tests/testlib.h deleted file mode 100644 index c7c60418680c7..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_RUNTIME_IR_TESTS_TESTLIB_H_ -#define XLA_MLIR_RUNTIME_IR_TESTS_TESTLIB_H_ - -#include - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project - -// clang-format off -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir/runtime/ir/tests/testlib_dialect.h.inc" -#include "xla/mlir/runtime/ir/tests/testlib_enums.h.inc" -// clang-format on - -#define GET_ATTRDEF_CLASSES -#include "xla/mlir/runtime/ir/tests/testlib_attrs.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "xla/mlir/runtime/ir/tests/testlib_types.h.inc" - -namespace xla { -namespace runtime { - -inline mlir::Type ConvertValueType(ValueType type) { - return mlir::LLVM::LLVMPointerType::get(type.getContext()); -} - -inline void AddTestlibTypeConversions(mlir::TypeConverter& converter) { - converter.addConversion(ConvertValueType); -} - -} // namespace runtime -} // namespace xla - -#endif // XLA_MLIR_RUNTIME_IR_TESTS_TESTLIB_H_ diff --git a/xla/mlir/runtime/ir/tests/testlib.mlir b/xla/mlir/runtime/ir/tests/testlib.mlir deleted file mode 100644 index a09c0660b3729..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: xla-runtime-opt %s | FileCheck %s - -// CHECK-LABEL: func @custom_arg( -// CHECK: %[[ARG:.*]]: !testlib.custom_arg -func.func @custom_arg(%arg0: !testlib.custom_arg) { - return -} - -// CHECK-LABEL: func @enum( -// CHECK: enum = #testlib.enum_type -func.func @enum() attributes { enum = #testlib.enum_type } { - return -} - -// CHECK-LABEL: func @another_enum( -// CHECK: enum = #testlib.another_enum_type -func.func @another_enum() attributes { enum = #testlib.another_enum_type } -{ - return -} - -// CHECK-LABEL: func @dims( -// CHECK: dims = #testlib.pair_of_dims<2, [1, 1], [2, 2]> -func.func @dims() attributes { dims = #testlib.pair_of_dims<2, [1, 1], [2, 2]> } -{ - return -} diff --git a/xla/mlir/runtime/ir/tests/testlib.td b/xla/mlir/runtime/ir/tests/testlib.td deleted file mode 100644 index 24b0c4b2d0196..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib.td +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TESTLIB -#define XLA_TESTLIB - -include "mlir/IR/DialectBase.td" - -//===----------------------------------------------------------------------===// -// XLA Runtime Testlib dialect definitions. -//===----------------------------------------------------------------------===// - -def TestlibDialect : Dialect { - let name = "testlib"; - - let description = [{ - XLA Runtime Testlib dialect for writing tests for the runtime features. - }]; - - let cppNamespace = "::xla::runtime"; - - let useDefaultAttributePrinterParser = 1; - - let useDefaultTypePrinterParser = 1; -} - -include "testlib_attrs.td" -include "testlib_enums.td" -include "testlib_types.td" - -#endif // XLA_TESTLIB diff --git a/xla/mlir/runtime/ir/tests/testlib_attrs.td b/xla/mlir/runtime/ir/tests/testlib_attrs.td deleted file mode 100644 index 8fa33335c379d..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib_attrs.td +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TESTLIB_ATTRS -#define XLA_TESTLIB_ATTRS - -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" - -def Dim : ArrayRefParameter<"int64_t", "Dimension">; - -def PairOfDims : AttrDef { - let mnemonic = "pair_of_dims"; - let summary = "Pair of dimensions"; - let parameters = (ins - "int64_t":$rank, - Dim:$a, - Dim:$b - ); - let assemblyFormat = "`<` $rank `,` `[` $a `]` `,` `[` $b `]` `>`"; -} - -#endif // XLA_TESTLIB_ATTRS diff --git a/xla/mlir/runtime/ir/tests/testlib_enums.td b/xla/mlir/runtime/ir/tests/testlib_enums.td deleted file mode 100644 index 5af9e312fb7d1..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib_enums.td +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TESTLIB_ENUMS -#define XLA_TESTLIB_ENUMS - -include "mlir/IR/EnumAttr.td" - -def TESTLIB_FOO : I32EnumAttrCase<"Foo", 0>; -def TESTLIB_BAR : I32EnumAttrCase<"Bar", 1>; -def TESTLIB_BAZ : I32EnumAttrCase<"Baz", 2>; - -def TESTLIB_EnumType : I32EnumAttr<"EnumType", - "Custom Call Testlib Enum Type.", - [TESTLIB_FOO, TESTLIB_BAR, TESTLIB_BAZ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::xla::runtime"; -} - -def TESTLIB_EnumTypeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -// Define another enum type to test enum conversion at the custom call boundary. -def TESTLIB_EnumType2 : I32EnumAttr<"EnumType2", - "Another Custom Call Testlib Enum Type.", - [TESTLIB_FOO, TESTLIB_BAR, TESTLIB_BAZ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::xla::runtime"; -} - -def TESTLIB_EnumType2Attr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -#endif // XLA_TESTLIB_ENUMS diff --git a/xla/mlir/runtime/ir/tests/testlib_types.td b/xla/mlir/runtime/ir/tests/testlib_types.td deleted file mode 100644 index dbd43037337a4..0000000000000 --- a/xla/mlir/runtime/ir/tests/testlib_types.td +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_TESTLIB_TYPES -#define XLA_TESTLIB_TYPES - -class TESTLIB_Type - : TypeDef { - let mnemonic = typeMnemonic; -} - -// TODO(ezhulenev): Remove this type once all tests are migrated to the type -// defined below. This requires moving tests from `jitrt/cpp_tests` to XLA. -def TESTLIB_CustomArgType : TESTLIB_Type<"CustomArg", "custom_arg"> { - let summary = "custom argument type"; - let description = [{ - Type for testing passing custom user-defined arguments to XLA executables. - }]; -} - -def TESTLIB_ValueType : TESTLIB_Type<"Value", "value"> { - let summary = "custom value type"; - let description = [{ - Custom type for testing passing custom user-defined types as XLA executable - and XLA runtime custom calls arguments and results. - }]; -} - -#endif // XLA_TESTLIB_TYPES diff --git a/xla/mlir/runtime/transforms/BUILD b/xla/mlir/runtime/transforms/BUILD index 03743ce98dfc6..0563e379303f1 100644 --- a/xla/mlir/runtime/transforms/BUILD +++ b/xla/mlir/runtime/transforms/BUILD @@ -1,8 +1,15 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available") +load( + "@tsl//tsl/platform:build_config_root.bzl", + "if_llvm_aarch64_available", + "if_llvm_arm_available", + "if_llvm_powerpc_available", + "if_llvm_system_z_available", + "if_llvm_x86_available", +) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -46,6 +53,7 @@ cc_library( "//xla/mlir/runtime/ir:rt", "//xla/mlir/runtime/utils:custom_calls", "//xla/runtime:custom_call", + "//xla/runtime:logical_result", "//xla/runtime:tracing", "//xla/runtime:type_id", "@com_google_absl//absl/log:check", @@ -72,9 +80,10 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla/mlir/runtime/ir:rt", - "@llvm-project//mlir:FuncDialect", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], ) @@ -85,10 +94,9 @@ xla_cc_test( deps = [ ":calling_convention", "//xla/mlir/runtime/ir:rt", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:TransformUtils", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -99,24 +107,26 @@ cc_library( srcs = ["compilation_pipeline_cpu.cc"], hdrs = ["compilation_pipeline_cpu.h"], compatible_with = get_compatible_with_portable(), + local_defines = select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "EXPERIMENTAL_MLIR_GPU=1", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", ":compiler", ":passes", "//xla/mlir/backends/cpu/transforms:passes", - "//xla/mlir/math/transforms:passes", "//xla/mlir/memref/transforms:passes", "//xla/mlir/runtime/ir:rt", "//xla/mlir_hlo:transforms_passes", "//xla/runtime:compiler", - "@llvm-project//mlir:AMXToLLVMIRTranslation", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithTransforms", - "@llvm-project//mlir:ArmNeonToLLVMIRTranslation", - "@llvm-project//mlir:ArmSVEToLLVMIRTranslation", "@llvm-project//mlir:AsyncDialect", "@llvm-project//mlir:AsyncToLLVM", "@llvm-project//mlir:AsyncTransforms", @@ -125,8 +135,6 @@ cc_library( "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUToGPURuntimeTransforms", - "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", @@ -140,46 +148,21 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Transforms", + "@tsl//tsl/platform:logging", + ] + select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUTransforms", + ], + "//conditions:default": [], + }) + if_llvm_aarch64_available([ + "@llvm-project//mlir:ArmSVEToLLVMIRTranslation", + ]) + if_llvm_arm_available([ + "@llvm-project//mlir:ArmNeonToLLVMIRTranslation", + ]) + if_llvm_x86_available([ + "@llvm-project//mlir:AMXToLLVMIRTranslation", "@llvm-project//mlir:X86VectorToLLVMIRTranslation", - ], - alwayslink = 1, # has pipeline registration -) - -cc_library( - name = "compilation_pipeline_gpu", - srcs = ["compilation_pipeline_gpu.cc"], - hdrs = ["compilation_pipeline_gpu.h"], - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], - deps = [ - ":compilation_pipeline_options", - ":compiler", - ":passes", - "//xla/mlir/runtime/ir:rt", - "//xla/mlir/runtime/ir/tests:testlib", - "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/runtime:compiler", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:AsyncToLLVM", - "@llvm-project//mlir:AsyncTransforms", - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:Transforms", - ], + ]), alwayslink = 1, # has pipeline registration ) @@ -190,6 +173,7 @@ cc_library( deps = [ ":custom_call_encoding", "//xla/runtime:type_id", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -201,10 +185,15 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/mlir/runtime/ir:rt", "//xla/runtime:custom_call", + "//xla/runtime:logical_result", + "//xla/runtime:memref_view", "//xla/runtime:tracing", "//xla/runtime:type_id", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsyncDialect", @@ -232,36 +221,40 @@ cc_library( "//xla/runtime:arguments", "//xla/runtime:compiler", "//xla/runtime:constraints", + "//xla/runtime:errors", "//xla/runtime:executable", + "//xla/runtime:execution_engine", + "//xla/runtime:logical_result", + "//xla/runtime:memory_mapper", "//xla/runtime:symbolic_shape", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Core", + "@llvm-project//llvm:JITLink", "@llvm-project//llvm:Passes", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", - ] + select({ - "@tsl//tsl:arm_any": [ - "@llvm-project//llvm:AArch64AsmParser", - ], - "@tsl//tsl:linux_ppc64le": [ - "@llvm-project//llvm:PowerPCAsmParser", - ], - "@tsl//tsl:macos_arm64": [ - "@llvm-project//llvm:AArch64AsmParser", - ], - "//conditions:default": [ - "@llvm-project//llvm:X86AsmParser", - ], - }) + if_llvm_system_z_available([ + ] + if_llvm_aarch64_available([ + "@llvm-project//llvm:AArch64AsmParser", + ]) + if_llvm_powerpc_available([ + "@llvm-project//llvm:PowerPCAsmParser", + ]) + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZAsmParser", + ]) + if_llvm_x86_available([ + "@llvm-project//llvm:X86AsmParser", ]), ) @@ -272,10 +265,10 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", + "//xla:xla_data_proto_cc", "//xla/mlir/runtime/utils:constraints", "//xla/runtime:arguments", "//xla/runtime:constraints", - "//xla/runtime:errors", "//xla/runtime:symbolic_shape", "//xla/runtime:types", "@com_google_absl//absl/status", @@ -283,7 +276,6 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -298,14 +290,17 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/mlir/runtime/ir:rt", "//xla/runtime:types", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", "@llvm-project//mlir:AsyncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@tsl//tsl/platform:statusor", ], ) @@ -315,7 +310,9 @@ xla_cc_test( compatible_with = get_compatible_with_portable(), deps = [ ":type_converter", + "//xla:xla_data_proto_cc", "//xla/runtime:types", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", diff --git a/xla/mlir/runtime/transforms/add_initializations.cc b/xla/mlir/runtime/transforms/add_initializations.cc index d4a65c5b85d6f..09e4d4d564769 100644 --- a/xla/mlir/runtime/transforms/add_initializations.cc +++ b/xla/mlir/runtime/transforms/add_initializations.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/calling_convention.cc b/xla/mlir/runtime/transforms/calling_convention.cc index a98276975446a..d7861d265614f 100644 --- a/xla/mlir/runtime/transforms/calling_convention.cc +++ b/xla/mlir/runtime/transforms/calling_convention.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,15 @@ limitations under the License. #include #include -#include "xla/mlir/runtime/ir/rt_ops.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir/runtime/ir/rt_dialect.h" namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/transforms/calling_convention.h b/xla/mlir/runtime/transforms/calling_convention.h index 6504af7e8b36e..da21f46780762 100644 --- a/xla/mlir/runtime/transforms/calling_convention.h +++ b/xla/mlir/runtime/transforms/calling_convention.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/calling_convention_test.cc b/xla/mlir/runtime/transforms/calling_convention_test.cc index 01572346568f9..6d326264d3be1 100644 --- a/xla/mlir/runtime/transforms/calling_convention_test.cc +++ b/xla/mlir/runtime/transforms/calling_convention_test.cc @@ -21,7 +21,7 @@ #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_ops.h" +#include "xla/mlir/runtime/ir/rt_dialect.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 50488cad7f13f..c5e845c22516d 100644 --- a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project @@ -32,7 +31,6 @@ limitations under the License. #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project @@ -42,21 +40,34 @@ limitations under the License. #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#ifdef TF_LLVM_X86_AVAILABLE #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" // from @llvm-project +#endif +#if defined(TF_LLVM_AARCH64_AVAILABLE) || defined(TF_LLVM_AARCH32_AVAILABLE) #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" // from @llvm-project +#ifdef TF_LLVM_AARCH64_AVAILABLE #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" // from @llvm-project +#endif +#endif #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project +#ifdef TF_LLVM_X86_AVAILABLE #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" // from @llvm-project +#endif #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" -#include "xla/mlir/math/transforms/passes.h" #include "xla/mlir/memref/transforms/passes.h" #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir/runtime/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" +#include "tsl/platform/logging.h" + +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU namespace xla { namespace runtime { @@ -75,12 +86,18 @@ void RegisterDefaultXlaCpuRuntimeDialects(DialectRegistry& dialects) { mlir::func::registerAllExtensions(*dialects); // Register MLIR dialects that can be translated to LLVM IR. +#ifdef TF_LLVM_AARCH64_AVAILABLE + mlir::registerArmSVEDialectTranslation(*dialects); +#endif +#if defined(TF_LLVM_AARCH64_AVAILABLE) || defined(TF_LLVM_AARCH32_AVAILABLE) mlir::registerArmNeonDialectTranslation(*dialects); +#endif +#ifdef TF_LLVM_X86_AVAILABLE mlir::registerAMXDialectTranslation(*dialects); - mlir::registerArmSVEDialectTranslation(*dialects); + mlir::registerX86VectorDialectTranslation(*dialects); +#endif mlir::registerBuiltinDialectTranslation(*dialects); mlir::registerLLVMDialectTranslation(*dialects); - mlir::registerX86VectorDialectTranslation(*dialects); } static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, @@ -97,10 +114,6 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - // Enable math approximations to match XLA's FP accuracy spec. - pm.addNestedPass( - xla::CreateMathApproximationPass({"all"})); - // Convert all linalg operations to parallel loops. pm.addNestedPass( mlir::createConvertLinalgToParallelLoopsPass()); @@ -146,6 +159,7 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, llvm_options.enableAvx2 = opts.math_avx2; pm.addPass(mlir::hlo::createGenericHostToLLVMPass(llvm_options)); const bool gpuCodegen = opts.xla_cpu_sparse_cuda_threads > 0; +#ifdef EXPERIMENTAL_MLIR_GPU if (gpuCodegen) { #ifdef MLIR_GPU_TO_CUBIN_PASS_ENABLE pm.addNestedPass( @@ -154,6 +168,10 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, #endif pm.addPass(mlir::createGpuToLLVMConversionPass()); } +#else // EXPERIMENTAL_MLIR_GPU + CHECK(!gpuCodegen) + << "Experimental MLIR GPU code generation was not enabled at build time"; +#endif // EXPERIMENTAL_MLIR_GPU pm.addPass(mlir::createReconcileUnrealizedCastsPass()); // Prepare module for translation to LLVM. diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.h b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.h index 91c5a9054f633..df2eb7ba863f4 100644 --- a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.h +++ b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc b/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc deleted file mode 100644 index 5fb9099fa3b10..0000000000000 --- a/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" - -#include - -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project -#include "mlir/Dialect/Async/Passes.h" // from @llvm-project -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/ir/tests/testlib.h" -#include "xla/mlir/runtime/transforms/compiler.h" -#include "xla/mlir/runtime/transforms/passes.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" - -namespace xla { -namespace runtime { - -void RegisterDefaultXlaGpuRuntimeDialects(DialectRegistry& dialects) { - // Register MLIR dialects supported by the compiled executables. - dialects->insert(); - - // Register MLIR dialects that can be translated to LLVM IR. - mlir::registerBuiltinDialectTranslation(*dialects); - mlir::registerLLVMDialectTranslation(*dialects); -} - -void RegisterLmhloGpuDialect(DialectRegistry& dialects) { - dialects->insert(); -} - -void RegisterTestlibDialect(DialectRegistry& dialects) { - dialects->insert(); -} - -static void CreateDefaultXlaGpuRuntimeCompilationPipeline( - mlir::OpPassManager& pm, const CompilationPipelineOptions& opts, - bool add_async_passes) { - pm.addPass(mlir::createConvertSCFToCFPass()); - - if (add_async_passes) pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); - - // Export functions to the XLA runtime. - pm.addPass(CreateExportRuntimeFunctionsPass()); - pm.addPass(CreateAddInitializationsPass()); - pm.addPass(CreateConvertCustomCallsPass()); - pm.addPass(CreateConvertAssertsPass()); - - if (add_async_passes) { - // Lower from high level async operations to async runtime. - pm.addPass(mlir::createAsyncToAsyncRuntimePass()); - - // Add async.runtime reference counting operations. - pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); - } - - // Prepare memrefs for lowering to LLVM. - pm.addNestedPass(mlir::memref::createExpandOpsPass()); - pm.addNestedPass( - mlir::memref::createExpandStridedMetadataPass()); - - // Convert runtime operations and custom calls to LLVM dialect. - ConvertRuntimeToLLvmOpts rt_to_llvm_opts = { - opts.populate_type_id_names, opts.populate_type_conversions, - opts.populate_arg_encodings, opts.populate_ret_encodings, - opts.populate_attr_encodings}; - pm.addPass(CreateConvertRuntimeToLLVMPass(std::move(rt_to_llvm_opts))); - - // Convert async dialect to LLVM once everything else is in the LLVM dialect. - if (add_async_passes) pm.addPass(mlir::createConvertAsyncToLLVMPass()); - - // Convert everything else to LLVM dialect. - pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); - pm.addPass(mlir::createConvertFuncToLLVMPass()); - pm.addPass(mlir::createReconcileUnrealizedCastsPass()); - - // Clean up IR before passing it to LLVM. - pm.addPass(mlir::createCSEPass()); -} - -void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts, - bool add_async_passes) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(*passes, opts, - add_async_passes); -} - -void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context) { - DialectRegistry dialects; - RegisterDefaultXlaGpuRuntimeDialects(dialects); - context.appendDialectRegistry(*dialects); -} - -static void CreateDefaultGpuPipeline(mlir::OpPassManager& pm) { - CompilationPipelineOptions copts; - CreateDefaultXlaGpuRuntimeCompilationPipeline(pm, copts, false); -} - -static mlir::PassPipelineRegistration<> kXlaRuntimePipeline( - "xla-runtime-default-gpu-pipeline", - "Default XLA-GPU runtime compilation pipeline", CreateDefaultGpuPipeline); -} // namespace runtime -} // namespace xla diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h b/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h deleted file mode 100644 index a36235baff6bc..0000000000000 --- a/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ -#define XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ - -#include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" -#include "xla/runtime/compiler.h" - -namespace mlir { -class MLIRContext; -} // namespace mlir - -namespace xla { -namespace runtime { - -// Registers dialects, interfaces and dialects translations with the registry -// required by the default XLA-GPU runtime compilation pipeline. -void RegisterDefaultXlaGpuRuntimeDialects(DialectRegistry& dialects); - -void RegisterLmhloGpuDialect(DialectRegistry& dialects); - -void RegisterTestlibDialect(DialectRegistry& dialects); - -// Creates default XLA-GPU runtime compilation pipeline that lowers from the -// `rt` and `memref` dialects to the LLVMIR dialect. This is a very simple -// pipeline that is mostly intended for writing tests for the XLA runtime, and -// it is expected that all end users will construct their own compilation -// pipelines from the available XLA and MLIR passes. -void CreateDefaultXlaGpuRuntimeCompilationPipeline( - PassManager& passes, const CompilationPipelineOptions& opts, - bool add_async_passes = false); - -void AppendXlaGpuDialectRegistry(mlir::MLIRContext& context); - -} // namespace runtime -} // namespace xla - -#endif // XLA_MLIR_RUNTIME_TRANSFORMS_COMPILATION_PIPELINE_GPU_H_ diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_options.h b/xla/mlir/runtime/transforms/compilation_pipeline_options.h index a80876fb8bcb6..9840984d347f0 100644 --- a/xla/mlir/runtime/transforms/compilation_pipeline_options.h +++ b/xla/mlir/runtime/transforms/compilation_pipeline_options.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/compiler.h b/xla/mlir/runtime/transforms/compiler.h index 4db705a32bef5..f9a9c44ddd1cf 100644 --- a/xla/mlir/runtime/transforms/compiler.h +++ b/xla/mlir/runtime/transforms/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/convert_asserts.cc b/xla/mlir/runtime/transforms/convert_asserts.cc index 8f28463394ff6..94dfee42bf010 100644 --- a/xla/mlir/runtime/transforms/convert_asserts.cc +++ b/xla/mlir/runtime/transforms/convert_asserts.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,16 @@ limitations under the License. #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/passes.h" diff --git a/xla/mlir/runtime/transforms/convert_custom_calls.cc b/xla/mlir/runtime/transforms/convert_custom_calls.cc index 3b136dff62bff..2c7886cce9b27 100644 --- a/xla/mlir/runtime/transforms/convert_custom_calls.cc +++ b/xla/mlir/runtime/transforms/convert_custom_calls.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,29 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir/runtime/ir/rt_dialect.h" +#include "xla/mlir/runtime/ir/rt_interfaces.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/passes.h" diff --git a/xla/mlir/runtime/transforms/custom_call_encoding.cc b/xla/mlir/runtime/transforms/custom_call_encoding.cc index b8d2613104fce..b2c75bd12443c 100644 --- a/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,29 +22,45 @@ limitations under the License. #include #include +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Async/IR/AsyncTypes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/primitive_util.h" #include "xla/runtime/custom_call.h" +#include "xla/runtime/logical_result.h" +#include "xla/runtime/memref_view.h" #include "xla/runtime/tracing.h" #include "xla/runtime/type_id.h" +#include "xla/xla_data.pb.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/concurrency/chain.h" diff --git a/xla/mlir/runtime/transforms/custom_call_encoding.h b/xla/mlir/runtime/transforms/custom_call_encoding.h index 3582987c1fb65..ea9fac26865ed 100644 --- a/xla/mlir/runtime/transforms/custom_call_encoding.h +++ b/xla/mlir/runtime/transforms/custom_call_encoding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,8 +26,15 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project diff --git a/xla/mlir/runtime/transforms/export_functions.cc b/xla/mlir/runtime/transforms/export_functions.cc index 54f6c590654c6..612666397173d 100644 --- a/xla/mlir/runtime/transforms/export_functions.cc +++ b/xla/mlir/runtime/transforms/export_functions.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,9 +22,14 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/passes.h" diff --git a/xla/mlir/runtime/transforms/jit_compiler.cc b/xla/mlir/runtime/transforms/jit_compiler.cc index c4d66d9073805..bfb072221687e 100644 --- a/xla/mlir/runtime/transforms/jit_compiler.cc +++ b/xla/mlir/runtime/transforms/jit_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,31 +23,58 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PassTimingInfo.h" #include "llvm/Pass.h" #include "llvm/Passes/OptimizationLevel.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/SMLoc.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/Timing.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir/runtime/transforms/passes.h" +#include "xla/mlir/runtime/transforms/specialization.h" +#include "xla/runtime/arguments.h" +#include "xla/runtime/constraints.h" +#include "xla/runtime/errors.h" +#include "xla/runtime/executable.h" +#include "xla/runtime/execution_engine.h" +#include "xla/runtime/logical_result.h" +#include "xla/runtime/memory_mapper.h" +#include "xla/service/llvm_ir/llvm_util.h" namespace xla { namespace runtime { @@ -112,7 +139,8 @@ static void PrintPassPipeline(const mlir::PassManager& pm) { } static LogicalResult RunPipeline( - ModuleOp module, const std::function& create_pipeline, + ModuleOp module, + const std::function& create_pipeline, int verification_level) { if (!create_pipeline) return success(); @@ -136,7 +164,11 @@ static LogicalResult RunPipeline( pm.enableTiming(timing); } PassManager passes(&pm); - create_pipeline(passes); + absl::Status pipeline_created = create_pipeline(passes); + if (!pipeline_created.ok()) { + llvm::errs() << pipeline_created.message() << "\n"; + return mlir::failure(); + } if (DebugJitCompiler()) { PrintPassPipeline(pm); @@ -422,6 +454,11 @@ MakeOptimizingTransformerForJit(llvm::TargetMachine* targetMachine) { if (!llvm_module) return compiler->Error("failed to translate module to LLVM IR"); + std::string llvm_module_string; + if (compiler->options().embed_ir_in_executable) { + llvm_module_string = llvm_ir::DumpToString(llvm_module.get()); + } + // Compile input module to the native function. auto engine = ExecutionEngine::CreateFromModule( std::move(llvm_ctx), std::move(llvm_module), engine_options, exported); @@ -441,7 +478,7 @@ MakeOptimizingTransformerForJit(llvm::TargetMachine* targetMachine) { return Executable(compiler->name(), std::move(memory_mapper), std::move(*engine), std::move(functions), specialization, - time_to_compile); + time_to_compile, std::move(llvm_module_string)); } // TODO(ezhulenev): Currently it's possible to specialize only one function. It diff --git a/xla/mlir/runtime/transforms/jit_compiler.h b/xla/mlir/runtime/transforms/jit_compiler.h index 2bd569bc1f23d..d36bdbcf9e90e 100644 --- a/xla/mlir/runtime/transforms/jit_compiler.h +++ b/xla/mlir/runtime/transforms/jit_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,9 +25,17 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/CodeGen.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "xla/mlir/runtime/transforms/calling_convention.h" #include "xla/mlir/runtime/transforms/specialization.h" #include "xla/mlir/runtime/transforms/type_converter.h" @@ -35,6 +43,7 @@ limitations under the License. #include "xla/runtime/compiler.h" #include "xla/runtime/constraints.h" #include "xla/runtime/executable.h" +#include "xla/runtime/execution_engine.h" #include "xla/runtime/symbolic_shape.h" namespace xla { @@ -62,7 +71,7 @@ class JitCompiler { // Original input module might have an undefined calling convention (e.g. // XLA runtime does not support unranked tensors), and specialization can be // required as a precondition for compilation. - std::function create_specialization_pipeline; + std::function create_specialization_pipeline; // Create a pass pipeline that lowers compiled module from high level // dialects to the LLVM dialect. XLA runtime will use the LLVM ORC compiler @@ -73,7 +82,7 @@ class JitCompiler { // (convert them to an ABI compatible with the calling convention advertised // to XLA through the `calling_convention` type conversion), and for // that it usually must include `xla-rt-export-functions` pass. - std::function create_compilation_pipeline; + std::function create_compilation_pipeline; // LLVM optimization level when JIT compiling a module. llvm::CodeGenOptLevel jit_code_opt_level = llvm::CodeGenOptLevel::Default; @@ -107,6 +116,9 @@ class JitCompiler { // How much verification would you like to do? int verification_level = 0; + + // Whether to embed the LLVM IR generated in the executable + bool embed_ir_in_executable = false; }; // Instantiates compiler from the serialized mlir source. diff --git a/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc b/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc index 7a9a892770b8e..8b9a4c6cf62ee 100644 --- a/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc +++ b/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/ordinal_assignment.cc b/xla/mlir/runtime/transforms/ordinal_assignment.cc index 5cca97fbca622..5902c50d8fd99 100644 --- a/xla/mlir/runtime/transforms/ordinal_assignment.cc +++ b/xla/mlir/runtime/transforms/ordinal_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/passes.h" diff --git a/xla/mlir/runtime/transforms/passes.h b/xla/mlir/runtime/transforms/passes.h index 9f186607783f5..a0849eb6a9ac8 100644 --- a/xla/mlir/runtime/transforms/passes.h +++ b/xla/mlir/runtime/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/passes.td b/xla/mlir/runtime/transforms/passes.td index aadf975e6d507..0d485208170d9 100644 --- a/xla/mlir/runtime/transforms/passes.td +++ b/xla/mlir/runtime/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/transforms/rt_to_llvm.cc b/xla/mlir/runtime/transforms/rt_to_llvm.cc index f08d70ef31912..8af1cfda92803 100644 --- a/xla/mlir/runtime/transforms/rt_to_llvm.cc +++ b/xla/mlir/runtime/transforms/rt_to_llvm.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -43,13 +44,15 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/mlir/runtime/transforms/custom_call_encoding.h" #include "xla/mlir/runtime/transforms/passes.h" #include "xla/runtime/custom_call.h" +#include "xla/runtime/logical_result.h" #include "xla/runtime/tracing.h" #include "xla/runtime/type_id.h" diff --git a/xla/mlir/runtime/transforms/specialization.cc b/xla/mlir/runtime/transforms/specialization.cc index 8aeffa272d5d1..3109917e6fb2e 100644 --- a/xla/mlir/runtime/transforms/specialization.cc +++ b/xla/mlir/runtime/transforms/specialization.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,18 +23,30 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "xla/mlir/runtime/transforms/type_converter.h" #include "xla/mlir/runtime/utils/constraints.h" #include "xla/runtime/arguments.h" +#include "xla/runtime/constraints.h" #include "xla/runtime/symbolic_shape.h" +#include "xla/runtime/types.h" +#include "xla/xla_data.pb.h" namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/transforms/specialization.h b/xla/mlir/runtime/transforms/specialization.h index 592a3dbd1020c..4a43daa6ab49e 100644 --- a/xla/mlir/runtime/transforms/specialization.h +++ b/xla/mlir/runtime/transforms/specialization.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_MLIR_RUNTIME_TRANSFORMS_SPECIALIZATION_H_ #define XLA_MLIR_RUNTIME_TRANSFORMS_SPECIALIZATION_H_ +#include "absl/status/status.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project diff --git a/xla/mlir/runtime/transforms/tests/BUILD b/xla/mlir/runtime/transforms/tests/BUILD deleted file mode 100644 index 7373506d182c0..0000000000000 --- a/xla/mlir/runtime/transforms/tests/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/runtime:xla-runtime-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) - -cc_library( - name = "testlib_pipeline", - testonly = 1, - srcs = ["testlib_pipeline.cc"], - hdrs = ["testlib_pipeline.h"], - visibility = ["//xla:runtime"], - deps = [ - "//xla/mlir/runtime/transforms:compiler", - "//xla/mlir/runtime/transforms:passes", - "//xla/runtime:compiler", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:AsyncToLLVM", - "@llvm-project//mlir:AsyncTransforms", - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/xla/mlir/runtime/transforms/tests/compilation_pipeline.mlir b/xla/mlir/runtime/transforms/tests/compilation_pipeline.mlir deleted file mode 100644 index 18fe89180a2b9..0000000000000 --- a/xla/mlir/runtime/transforms/tests/compilation_pipeline.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-runtime-default-cpu-pipeline | FileCheck %s -// RUN: xla-runtime-opt %s --xla-runtime-default-gpu-pipeline | FileCheck %s - -// Check that entrypoint function was lowered to LLVM function with expected -// ABI. - -// CHECK-LABEL: llvm.func @main( -// CHECK-SAME: %[[ARG0:arg[0-9]+]]: !llvm.ptr, -// CHECK-SAME: %[[ARG1:arg[0-9]+]]: !llvm.ptr, -// CHECK-SAME: %[[ARG2:arg[0-9]+]]: !llvm.ptr, -// CHECK-SAME: %[[ARG3:arg[0-9]+]]: i64, -// CHECK-SAME: %[[ARG4:arg[0-9]+]]: i64, -// CHECK-SAME: %[[ARG5:arg[0-9]+]]: i64 -// CHECK-SAME: ) -rt.export @main ordinal 0 -func.func @main(%arg0: memref) { - call @custom_call(%arg0) : (memref) -> () - return -} - -// Check that XLA runtime custom call was lowered to a LLVM function call. - -// CHECK: llvm.func @target -// CHECK-SAME: passthrough = ["nounwind"] -func.func private @custom_call(%arg0: memref) - attributes { rt.custom_call = "target" } diff --git a/xla/mlir/runtime/transforms/tests/convert_asserts.mlir b/xla/mlir/runtime/transforms/tests/convert_asserts.mlir deleted file mode 100644 index 196e80cc465ee..0000000000000 --- a/xla/mlir/runtime/transforms/tests/convert_asserts.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-rt-convert-asserts | FileCheck %s - -// CHECK: func @exported( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[PRED:.*]]: i1 -// CHECK: ) -func.func @exported(%arg0: !rt.execution_context, %arg1: i1) - attributes {rt.exported = 0 : i32} { - // CHECK: cf.cond_br %[[PRED]], ^[[OK:.*]], ^[[ERR:.*]] - // CHECK: ^[[OK]]: - // CHECK: return - // CHECK: ^[[ERR]]: - // CHECK: rt.set_error %[[CTX]], "Oops" - // CHECK: return - cf.assert %arg1, "Oops" - return -} diff --git a/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir b/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir deleted file mode 100644 index 4b4e8495454e8..0000000000000 --- a/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir +++ /dev/null @@ -1,59 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-rt-convert-custom-calls | FileCheck %s - -// CHECK-NOT: func private @custom_call(memref) -func.func private @custom_call(%arg0: memref) -> memref - attributes { rt.custom_call = "target", attr0 = 1 : i32, attr1 = 1.0 : f32 } - -// CHECK-NOT: func private @dynamic_custom_call(memref) -func.func private @dynamic_custom_call(%arg0: memref) - attributes { rt.dynamic, rt.custom_call = "target" } - -// CHECK: func @function_call_to_custom_call( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) -> memref attributes {rt.exported = 0 : i32} { -func.func @function_call_to_custom_call( - %arg0: !rt.execution_context, - %arg1: memref -) -> memref attributes {rt.exported = 0 : i32} { - // CHECK: %[[STATUS:.*]], %[[RES:.*]] = rt.call %[[CTX]]["target"] - // CHECK-SAME: (%[[ARG]]) {attr0 = 2 : i32, attr1 = 1.000000e+00 : f32} - // CHECK: %[[IS_OK:.*]] = rt.is_ok %[[STATUS]] - // CHECK: assert %[[IS_OK]], "custom call 'target' failed" - %0 = call @custom_call(%arg1) { attr0 = 2 : i32 } - : (memref) -> memref - return %0 : memref -} - -// CHECK: func @function_call_to_dynamic_custom_call( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) attributes {rt.exported = 0 : i32} { -func.func @function_call_to_dynamic_custom_call( - %arg0: !rt.execution_context, - %arg1: memref -) attributes {rt.exported = 0 : i32} { - // CHECK: rt.call dynamic %[[CTX]]["target"] - call @dynamic_custom_call(%arg1) : (memref) -> () - return -} - -// CHECK: func @function_call_to_traced_custom_call( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) -> memref attributes {rt.exported = 0 : i32} { -func.func @function_call_to_traced_custom_call( - %arg0: !rt.execution_context, - %arg1: memref -) -> memref attributes {rt.exported = 0 : i32} { - // CHECK: %[[RES:.*]]:2 = rt.trace #rt.hlo_trace<"fusion">, %[[CTX]] - // CHECK-SAME: -> !rt.status, memref { - // CHECK-NEXT: %[[STATUS:.*]], %[[RET:.*]] = call %[[CTX]]["target"] - // CHECK-NOT: #rt.hlo_trace - // CHECK-NEXT: yield %[[STATUS]], %[[RET]] : !rt.status, memref - // CHECK-NEXT: } - // CHECK: rt.is_ok %[[RES]]#0 - %0 = call @custom_call(%arg1) { rt.trace = #rt.hlo_trace<"fusion"> } - : (memref) -> memref - return %0 : memref -} \ No newline at end of file diff --git a/xla/mlir/runtime/transforms/tests/export_functions.mlir b/xla/mlir/runtime/transforms/tests/export_functions.mlir deleted file mode 100644 index 7d167df73daa9..0000000000000 --- a/xla/mlir/runtime/transforms/tests/export_functions.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-rt-export-functions | FileCheck %s - -// CHECK: func @single_result( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) attributes {rt.exported = 0 : i32} { -rt.export @single_result ordinal 0 -func.func @single_result(%arg0: memref) -> memref { - // CHECK: rt.set_output %[[CTX]], 0, %[[ARG]] : memref - // CHECK: return - return %arg0 : memref -} - -// CHECK: func @two_results( -// CHECK: %[[CTX:.*]]: !rt.execution_context, -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) attributes {rt.exported = 1 : i32} { -rt.export @two_results ordinal 1 -func.func @two_results(%arg0: memref) -> (memref, memref) { - // CHECK: rt.set_output %[[CTX]], 0, %[[ARG]] : memref - // CHECK: rt.set_output %[[CTX]], 1, %[[ARG]] : memref - // CHECK: return - return %arg0, %arg0 : memref, memref -} - -// CHECK: func @not_exported( -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) -> memref { -func.func @not_exported(%arg0: memref) -> memref { - // CHECK-NOT: rt.set_output - // CHECK: return %[[ARG]] - return %arg0 : memref -} diff --git a/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir b/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir deleted file mode 100644 index dff76dfc312c2..0000000000000 --- a/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: xla-runtime-opt %s --xla-rt-move-allocas-to-entry-block | FileCheck %s - -func.func @compute( - %arg0: !rt.execution_context, - %arg1: !async.value> -) -> !async.token attributes {passthrough = ["presplitcoroutine"]} { - // CHECK: %alloca = memref.alloca() {alignment = 64 : i64} : memref - // CHECK: %0 = async.runtime.create : !async.token - // CHECK: %1 = async.coro.id - // CHECK: %2 = async.coro.begin %1 - // CHECK: %3 = async.coro.save %2 - // CHECK: async.runtime.resume %2 - // CHECK: async.coro.suspend %3, ^bb9, ^bb1, ^bb8 - // CHECK: ^bb1: // pred: ^bb0 - // CHECK: %status = rt.call %arg0["test.producer"] (%alloca) - // CHECK: : (memref) -> () - %0 = async.runtime.create : !async.token - %1 = async.coro.id - %2 = async.coro.begin %1 - %3 = async.coro.save %2 - async.runtime.resume %2 - async.coro.suspend %3, ^bb9, ^bb1, ^bb8 -^bb1: // pred: ^bb0 - %alloca = memref.alloca() {alignment = 64 : i64} : memref - %status = rt.call %arg0["test.producer"] (%alloca) : (memref) -> () - %4 = rt.is_ok %status - cf.cond_br %4, ^bb2, ^bb6 -^bb2: // pred: ^bb1 - %5 = async.coro.save %2 - async.runtime.await_and_resume %arg1, %2 : !async.value> - async.coro.suspend %5, ^bb9, ^bb3, ^bb8 -^bb3: // pred: ^bb2 - %6 = async.runtime.is_error %arg1 : !async.value> - cf.cond_br %6, ^bb6, ^bb4 -^bb4: // pred: ^bb3 - %7 = async.runtime.load %arg1 : > - %status_0 = rt.call %arg0["test.consumer"] (%alloca) : (memref) -> () - %8 = rt.is_ok %status_0 - cf.cond_br %8, ^bb5, ^bb6 -^bb5: // pred: ^bb4 - async.runtime.set_available %0 : !async.token - cf.br ^bb7 -^bb6: // 3 preds: ^bb1, ^bb3, ^bb4 - async.runtime.set_error %0 : !async.token - cf.br ^bb7 -^bb7: // 2 preds: ^bb5, ^bb6 - async.coro.free %1, %2 - cf.br ^bb9 -^bb8: // 2 preds: ^bb0, ^bb2 - async.coro.free %1, %2 - cf.br ^bb9 -^bb9: // 4 preds: ^bb0, ^bb2, ^bb7, ^bb8 - async.coro.end %2 - return %0 : !async.token -} diff --git a/xla/mlir/runtime/transforms/tests/ordinal_assignment.mlir b/xla/mlir/runtime/transforms/tests/ordinal_assignment.mlir deleted file mode 100644 index e6b2ffd9754a6..0000000000000 --- a/xla/mlir/runtime/transforms/tests/ordinal_assignment.mlir +++ /dev/null @@ -1,56 +0,0 @@ -// RUN: xla-runtime-opt %s --split-input-file --xla-rt-ordinal-assignment \ -// RUN: | FileCheck %s - -// CHECK: rt.export @exported.0 ordinal 0 -// CHECK: rt.export @exported.1 ordinal 1 -rt.export @exported.0 -rt.export @exported.1 - -func.func @exported.0() { return } -func.func @exported.1() { return } - -// ----- - -// CHECK: rt.export @exported.0 ordinal 0 -// CHECK: rt.export @exported.1 ordinal 1 -rt.export @exported.0 ordinal 0 -rt.export @exported.1 - -func.func @exported.0() { return } -func.func @exported.1() { return } - -// ----- - -// CHECK: rt.export @exported.0 ordinal 1 -// CHECK: rt.export @exported.1 ordinal 0 -rt.export @exported.0 ordinal 1 -rt.export @exported.1 - -func.func @exported.0() { return } -func.func @exported.1() { return } - -// ----- - -// CHECK: rt.export @exported.0 ordinal 0 -// CHECK: rt.export @exported.1 ordinal 1 -// CHECK: rt.export @exported.2 ordinal 2 -rt.export @exported.0 -rt.export @exported.1 -rt.export @exported.2 ordinal 2 - -func.func @exported.0() { return } -func.func @exported.1() { return } -func.func @exported.2() { return } - -// ----- - -// CHECK: rt.export @exported.0 ordinal 0 -// CHECK: rt.export @exported.1 ordinal 2 -// CHECK: rt.export @exported.2 ordinal 1 -rt.export @exported.0 -rt.export @exported.1 -rt.export @exported.2 ordinal 1 - -func.func @exported.0() { return } -func.func @exported.1() { return } -func.func @exported.2() { return } diff --git a/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir b/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir deleted file mode 100644 index 08fc6dddd0b2d..0000000000000 --- a/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir +++ /dev/null @@ -1,648 +0,0 @@ -// RUN: xla-runtime-opt %s --split-input-file --xla-rt-to-llvm | FileCheck %s - -// CHECK: func @pass_context( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @pass_context(%arg0: !rt.execution_context) { - func.return -} - -// ----- - -// CHECK: func @set_output( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @set_output(%arg0: !rt.execution_context) { - // CHECK: %[[MEMREF:.*]] = memref.alloc - // CHECK: %[[LLVM_MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEMREF]] - %0 = memref.alloc() : memref - // CHECK: %[[C0:.*]] = arith.constant 0 : i64 - // CHECK: %[[RES_PTR:.*]] = call @runtimeGetResultStorage(%[[CTX]], %[[C0]]) - // CHECK: llvm.store %[[LLVM_MEMREF]], %[[RES_PTR]] - rt.set_output %arg0, 0, %0 : memref - func.return -} - -// ----- - -// CHECK-DAG: llvm.mlir.global {{.*}} @[[ERR0:.*]]("Failed precondition #0\00") -// CHECK-DAG: llvm.mlir.global {{.*}} @[[ERR1:.*]]("Failed precondition #1\00") - -// CHECK: func @set_error( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @set_error(%arg0: !rt.execution_context) { - // CHECK: %[[ADDR0:.*]] = llvm.mlir.addressof @[[ERR0]] - // CHECK: call @runtimeSetError(%[[CTX]], %[[ADDR0]]) - rt.set_error %arg0, "Failed precondition #0" - // CHECK: %[[ADDR1:.*]] = llvm.mlir.addressof @[[ERR1]] - // CHECK: call @runtimeSetError(%[[CTX]], %[[ADDR1]]) - rt.set_error %arg0, "Failed precondition #1" - func.return -} - -// ----- - -// CHECK: llvm.mlir.global {{.*}} @[[ERR:.*]]("Failed precondition\00") -// CHECK-NOT: Failed precondition - -// CHECK: func @dedup_error_message( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @dedup_error_message(%arg0: !rt.execution_context) { - // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[ERR]] - rt.set_error %arg0, "Failed precondition" - // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[ERR]] - rt.set_error %arg0, "Failed precondition" - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) {{.*}}: i64 - -// CHECK: global internal constant @__rt_attr_value() -// CHECK-SAME: !llvm.array<3 x i64> { -// CHECK: llvm.mlir.undef : !llvm.array<3 x i64> -// CHECK: arith.constant 1 : i64 -// CHECK: llvm.insertvalue -// CHECK: arith.constant 2 : i64 -// CHECK: llvm.insertvalue -// CHECK: arith.constant 3 : i64 -// CHECK: llvm.insertvalue -// CHECK: llvm.return -// CHECK: } - -// CHECK: global internal constant @__rt_attr_value_0() -// CHECK-SAME: !llvm.struct<(i64, ptr)> { -// CHECK: arith.constant 3 : i64 -// CHECK: llvm.mlir.addressof @__rt_attr_value : !llvm.ptr -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: llvm.return -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () { arr = [1, 2, 3] } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) {{.*}}: i64 - -// CHECK: global internal constant @__rt_attr_value() -// CHECK-SAME: : !llvm.array<3 x i64> { -// CHECK: llvm.mlir.undef : !llvm.array<3 x i64> -// CHECK: arith.constant 1 : i64 -// CHECK: llvm.insertvalue -// CHECK: arith.constant 2 : i64 -// CHECK: llvm.insertvalue -// CHECK: arith.constant 3 : i64 -// CHECK: llvm.insertvalue -// CHECK: } - -// CHECK: global internal constant @__rt_attr_value_0() -// CHECK-SAME: !llvm.struct<(i64, ptr)> { -// CHECK arith.constant 3 : i64 -// CHECK llvm.mlir.addressof @__rt_attr_value -// CHECK llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK llvm.mlir.insertvalue -// CHECK llvm.mlir.insertvalue -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () - { attr_name = array } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) - -// CHECK: global internal constant @__rt_attr_value() -// CHECK-SAME: !llvm.struct<(i64, ptr)> { -// CHECK: arith.constant 0 : i64 -// CHECK: llvm.mlir.zero : !llvm.ptr -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () { arr = [] } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_custom_call_name("target\00") - -// CHECK: global internal constant @__rt_empty_rets() -// CHECK: { -// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> -// CHECK: llvm.mlir.addressof @__rt_zero : !llvm.ptr -// CHECK: } - -// CHECK: global internal constant @__rt_num_attrs(0 : i64) - -// CHECK: global internal constant @__rt_custom_call_attrs() -// CHECK: { -// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> -// CHECK: llvm.mlir.addressof @__rt_num_attrs : !llvm.ptr -// CHECK: } - -// CHECK: global internal constant @__rt_empty_args() -// CHECK: { -// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> -// CHECK: llvm.mlir.addressof @__rt_zero : !llvm.ptr -// CHECK: } - -// CHECK: func @dynamic_custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @dynamic_custom_call(%arg0: !rt.execution_context) { - - // CHECK: %[[CALLEE_ADDR:.*]] = llvm.mlir.addressof @__rt_custom_call_name - // CHECK: %[[ARGS:.*]] = llvm.mlir.addressof @__rt_empty_args - // CHECK: %[[ATTRS:.*]] = llvm.mlir.addressof @__rt_custom_call_attrs - // CHECK: %[[RETS:.*]] = llvm.mlir.addressof @__rt_empty_rets - - // CHECK: %[[STATUS:.*]] = call @runtimeCustomCall(%[[CTX]], %[[CALLEE_ADDR]], - // CHECK-SAME: %[[ARGS]], %[[ATTRS]], - // CHECK-SAME: %[[RETS]]) - // CHECK: cf.assert %[[STATUS]], "oops" - %status = rt.call dynamic %arg0["target"] () : () -> () - %ok = rt.is_ok %status - cf.assert %ok, "oops" - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) -// CHECK: global internal constant @__rt_attr_value(1.230000e+02 : f32) -// CHECK: global internal constant @__rt_str("attr_name\00") - -// CHECK: global internal constant @__rt_attr_name() -// CHECK-SAME: : !llvm.struct<(i64, ptr)> { -// CHECK: arith.constant 9 : i64 -// CHECK: llvm.mlir.addressof @__rt_str : !llvm.ptr -// CHECK: } - -// CHECK: global internal constant @__rt_custom_call_attrs() -// CHECK-SAME: : !llvm.array<4 x ptr> { -// CHECK: llvm.mlir.addressof @__rt_num_attrs -// CHECK: llvm.mlir.addressof @__rt_attr_name -// CHECK: llvm.mlir.addressof @__type_id_float -// CHECK: llvm.mlir.addressof @__rt_attr_value -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () { attr_name = 123.0 : f32 } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) - -// CHECK: llvm.mlir.global internal constant @__rt_attr_value -// CHECK-SAME: (dense<[1, 2, 3]> : tensor<3xi32>) - -// CHECK: llvm.mlir.global internal constant @__rt_attr_value_0() -// CHECK-SAME: : !llvm.struct -// CHECK-SAME: <(struct<(i64, ptr)>, i64, array<1 x i64>)> { -// CHECK: arith.constant 3 : i64 -// CHECK: llvm.mlir.addressof -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: arith.constant 1 : i64 -// CHECK: llvm.mlir.undef : !llvm.array<1 x i64> -// CHECK: arith.constant 3 : i64 -// CHECK: llvm.insertvalue -// CHECK: llvm.mlir.undef : !llvm.struct -// CHECK-SAME: <(struct<(i64, ptr)>, i64, array<1 x i64>)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () - { attr_name = dense<[1, 2, 3]> : tensor<3xi32> } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) - -// CHECK: llvm.mlir.global internal constant @__rt_attr_value -// CHECK-SAME: (dense<[1, 2]> : tensor<2xi32>) - -// CHECK: llvm.mlir.global internal constant @__rt_attr_value_0() -// CHECK-SAME: : !llvm.struct -// CHECK-SAME: <(struct<(i64, ptr)>, i64, array<2 x i64>)> { -// CHECK: arith.constant 2 : i64 -// CHECK: llvm.mlir.addressof -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: arith.constant 2 : i64 -// CHECK: llvm.mlir.undef : !llvm.array<2 x i64> -// CHECK: arith.constant 2 : i64 -// CHECK: llvm.insertvalue -// CHECK: arith.constant 1 : i64 -// CHECK: llvm.insertvalue -// CHECK: llvm.mlir.undef : !llvm.struct -// CHECK-SAME: <(struct<(i64, ptr)>, i64, array<2 x i64>)> -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: llvm.insertvalue -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () - { attr_name = dense<[[1], [2]]> : tensor<2x1xi32> } : () -> () - func.return -} - -// ----- - -// CHECK: global internal constant @__rt_num_attrs(1 : i64) -// CHECK: global internal constant @[[STR:.*]]("attr_value\00") - -// CHECK: global internal constant @__rt_attr_value() -// CHECK-SAME: : !llvm.struct<(i64, ptr)> { -// CHECK: arith.constant 10 : i64 -// CHECK: llvm.mlir.addressof @[[STR]] : !llvm.ptr -// CHECK: } - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () { attr_name = "attr_value" } : () -> () - func.return -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_empty_rets() - -// CHECK: llvm.mlir.global internal constant @__rt_num_attrs(0 : i64) -// CHECK: llvm.mlir.global internal constant @__rt_custom_call_attrs - -// CHECK: llvm.mlir.global internal constant @__rt_args_type_table -// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> -// CHECK: llvm.mlir.addressof @__type_id_float - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: %[[ARG:.*]]: f32 -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context, %arg1 : f32) { - // CHECK-DAG: %[[MEM:.*]] = llvm.alloca {{.*}} x f32 - // CHECK-DAG: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> - - // CHECK-DAG: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args - // CHECK-DAG: llvm.store volatile %[[ARG]], %[[MEM]] - - // CHECK: %[[ARGS_TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table - // CHECK: llvm.insertvalue %[[ARGS_TYPES]], {{.*}}[1] : !llvm.array<3 x ptr> - // CHECK: llvm.intr.lifetime.start -1, %[[ARGS]] - // CHECK: llvm.store {{.*}}, %[[ARGS]] : !llvm.array<3 x ptr>, !llvm.ptr - - // CHECK: %[[RETS:.*]] = llvm.mlir.addressof @__rt_empty_rets - - // CHECK: call @target - // CHECK: llvm.intr.lifetime.end -1, %[[ARGS]] - rt.call %arg0["target"] (%arg1) : (f32) -> () - func.return -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_args_type_table -// CHECK: llvm.mlir.addressof @__type_id_memref_view - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: %[[ARG:.*]]: memref -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context, %arg1 : memref) { - - // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] - // CHECK-SAME: to !llvm.struct - - // CHECK: llvm.mlir.undef : !llvm.array<4 x i64> - // CHECK-NEXT: llvm.extractvalue %[[DESC]][3, 0] - // CHECK-NEXT: arith.constant 256 : i64 - // CHECK-NEXT: llvm.insertvalue - // CHECK-NEXT: llvm.insertvalue - // CHECK-NEXT: arith.constant 256 : i64 - // CHECK-NEXT: arith.constant 1 : i64 - // CHECK-NEXT: llvm.insertvalue - // CHECK-NEXT: %[[SIZES:.*]] = llvm.insertvalue - - // llvm.mlir.undef : !llvm.struct<(i8, i8, ptr, array<2 x i64>)> - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue - // CHECK: llvm.insertvalue %[[SIZES]] - // CHECK: llvm.insertvalue - - // CHECK: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args - // CHECK: %[[TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table - - // CHECK: call @target - rt.call %arg0["target"] (%arg1) : (memref) -> () - func.return -} - -// ----- - -// CHECK: internal constant @__rt_custom_call_attrs() {{.*}}: !llvm.array<4 x ptr> -// CHECK-NOT: internal constant @__rt_custom_call_attrs - -// CHECK: func @dedup_custom_call_attrs( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @dedup_custom_call_attrs(%arg0: !rt.execution_context) { - // CHECK: call @target - rt.call %arg0["target"] () { arr = [1, 2, 3] } : () -> () - // CHECK: call @target - rt.call %arg0["target"] () { arr = [1, 2, 3] } : () -> () - func.return -} - -// CHECK: func private @target(!llvm.ptr, !llvm.ptr, -// CHECK-SAME: !llvm.ptr) -> i1 - -// ----- - -// CHECK: func @dynamic_custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @dynamic_custom_call(%arg0: !rt.execution_context) { - // CHECK: call @runtimeCustomCall - // CHECK: call @runtimeCustomCall - rt.call dynamic %arg0["target"] () : () -> () - rt.call dynamic %arg0["target"] () : () -> () - func.return -} - -// ----- - -func.func @custom_call(%ctx: !rt.execution_context) -> (f32) { - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[RETS:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> - - // CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 - // CHECK: %[[F32_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x f32 - - // CHECK: %[[N_RETS:.*]] = llvm.mlir.addressof @__rt_num_rets - - // CHECK: call @f32_reduce - // CHECK: %[[LOAD2:.*]] = llvm.load %[[F32_ALLOCA]] - // CHECK: llvm.intr.lifetime.end -1, %[[F32_ALLOCA]] - %status, %0 = rt.call %ctx["f32_reduce"] () : () -> (f32) - return %0 : f32 -} - -// ----- - -// CHECK: func @opaque_arg( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr -// CHECK-SAME: ) -func.func @opaque_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { - return -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_args_type_table -// CHECK: llvm.mlir.addressof @__type_id_opaque : !llvm.ptr - -// CHECK: func @opaque_custom_call_arg( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr -// CHECK-SAME: ) -func.func @opaque_custom_call_arg(%ctx: !rt.execution_context, - %arg: !rt.opaque) { - // CHECK: %[[ALLOCA:.*]] = llvm.alloca {{.*}} x !llvm.ptr - // CHECK: llvm.store volatile %[[ARG1]], %[[ALLOCA]] : !llvm.ptr - // CHECK: call @target - %status = rt.call %ctx["target"] (%arg) : (!rt.opaque) -> () - return -} - -// ----- - -// CHECK: func @opaque_custom_call_res( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr -// CHECK-SAME: ) -func.func @opaque_custom_call_res(%ctx: !rt.execution_context) { - // CHECK: %[[ALLOCA:.*]] = llvm.alloca {{.*}} x !llvm.ptr - // CHECK: call @target - %status, %res = rt.call %ctx["target"] () : () -> (!rt.opaque) - // CHECK: llvm.load %[[ALLOCA]] : !llvm.ptr -> !llvm.ptr - return -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_custom_call_attrs -// CHECK: llvm.mlir.addressof @__type_id_nullopt - -// CHECK: func @custom_call_unit_attr( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr -// CHECK-SAME: ) -func.func @custom_call_unit_attr(%ctx: !rt.execution_context) { - // CHECK: llvm.mlir.addressof @__rt_custom_call_attrs - %status = rt.call %ctx["target"] () { attr } : () -> () - return -} - -// ----- - -// CHECK: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK: %[[RETS_ALLOCA:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> - -// CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 -// CHECK: %[[MEMREF_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x !llvm.struct<(i8, i8, ptr, array<4 x i64>)> - -// CHECK: call @f32_reduce - -// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: %[[DATA_GEP:.*]] = llvm.getelementptr %[[MEMREF_ALLOCA]] -// CHECK: %[[DATA:.*]] = llvm.load %[[DATA_GEP]] - -// CHECK: llvm.insertvalue %[[DATA]], {{.*}}[0] -// CHECK: llvm.insertvalue %[[DATA]], {{.*}}[1] - -// CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) -// CHECK: llvm.insertvalue %[[OFFSET]], {{.*}}[2] - -// CHECK: %[[DIM0:.*]] = llvm.mlir.constant(2 : index) -// CHECK: llvm.insertvalue %[[DIM0]], {{.*}}[3, 0] -// CHECK: %[[STRIDE0:.*]] = llvm.mlir.constant(2 : index) -// CHECK: llvm.insertvalue %[[STRIDE0]], {{.*}}[4, 0] - -// CHECK: %[[DIM1:.*]] = llvm.mlir.constant(2 : index) -// CHECK: llvm.insertvalue %[[DIM1]], {{.*}}[3, 1] -// CHECK: %[[STRIDE1:.*]] = llvm.mlir.constant(1 : index) -// CHECK: llvm.insertvalue %[[STRIDE1]], {{.*}}[4, 1] -func.func @custom_call(%ctx: !rt.execution_context) -> (memref<2x2xf32>) { - %status, %0 = rt.call %ctx["f32_reduce"] () : () -> (memref<2x2xf32>) - return %0 : memref<2x2xf32> -} - -// ----- - -// CHECK: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK: %[[RETS_ALLOCA:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> - -// CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 -// CHECK: %[[MEMREF_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x !llvm.struct<(i8, i8, ptr, array<4 x i64>)> - -// CHECK: call @f32_reduce -func.func @custom_call(%ctx: !rt.execution_context) - -> (!async.value>) { - %status, %0 = rt.call %ctx["f32_reduce"] () - : () -> (!async.value>) - return %0 : !async.value> -} - -// ----- - -// Test that custom call encoding can pass a reference to exported function as a -// custom call attribute. -func.func @init(%ctx: !rt.execution_context) - attributes {rt.exported = 0: i32} { return } - -// CHECK-DAG: mlir.global internal constant @__rt_num_attrs(1 : i64) -// CHECK-DAG: mlir.global external constant @__type_id_function_ordinal() -// CHECK-DAG: mlir.global internal constant @__rt_attr_value(0 : i32) - -// CHECK: mlir.global internal constant @__rt_custom_call_attrs -// CHECK: mlir.addressof @__type_id_function_ordinal -// CHECK: mlir.addressof @__rt_attr_value -// CHECK: llvm.return {{.*}} : !llvm.array<4 x ptr> - -// CHECK: @custom_call_exported_function_ref -func.func @custom_call_exported_function_ref(%ctx: !rt.execution_context) { - %status = rt.call %ctx["call_init"] () { init = @init } : () -> () - return -} - -// ----- - -func.func private @compute() -> tensor - -// CHECK: mlir.global internal constant @__rt_aggregate_hlo_trace -// CHECK: llvm.mlir.addressof @__rt_aggregate_hlo_trace - -// CHECK: func @trace -func.func @trace(%ctx: !rt.execution_context) -> tensor { - // CHECK: call @xla.trace.activity_start - // CHECK: call @compute - // CHECK: call @xla.trace.activity_end - %0 = rt.trace #rt.hlo_trace<"foo">, %ctx -> tensor { - %1 = func.call @compute(): () -> tensor - yield %1 : tensor - } - return %0 : tensor -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_c123(123 : i32) - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: llvm.mlir.addressof @__rt_c123 : !llvm.ptr - // CHECK: call @target - %c123 = arith.constant 123 : i32 - rt.call %arg0["target"] (%c123) : (i32) -> () - func.return -} - -// ----- - -// CHECK: llvm.mlir.global internal constant @__rt_cst(1.234560e+02 : f32) - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context) { - // CHECK: llvm.mlir.addressof @__rt_cst : !llvm.ptr - // CHECK: call @target - %cst = arith.constant 123.456 : f32 - rt.call %arg0["target"] (%cst) : (f32) -> () - func.return -} - -// ----- -// Check that we reuse allocas for encoding arguments on the stack. - -// CHECK: func @custom_call( -// CHECK: %[[CTX:.*]]: !llvm.ptr, -// CHECK: %[[ARG:.*]]: f32 -// CHECK: ) -func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { - // CHECK: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> - // CHECK: %[[ARG_ALLOCA:.*]] = llvm.alloca %{{.*}} x f32 - // CHECK-NOT: llvm.alloca - - // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr - // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr - // CHECK: llvm.store {{.*}}, %[[ARGS]] - // CHECK: call @target - rt.call %arg0["target"] (%arg1) : (f32) -> () - // llvm.intr.lifetime.end -1, %[[ARGS]] : !llvm.ptr - // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr - - // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr - // CHECK: llvm.store volatile %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr - // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr - // CHECK: llvm.store {{.*}}, %[[ARGS]] - // CHECK: call @target - rt.call %arg0["target"] (%arg1) : (f32) -> () - // llvm.intr.lifetime.end -1, %[[ARGS]] : !llvm.ptr - // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr - - func.return -} diff --git a/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc b/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc deleted file mode 100644 index 4afd9c54dc94f..0000000000000 --- a/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/mlir/runtime/transforms/tests/testlib_pipeline.h" - -#include - -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project -#include "mlir/Dialect/Async/Passes.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/runtime/transforms/compiler.h" -#include "xla/mlir/runtime/transforms/passes.h" - -namespace xla { -namespace runtime { - -void RegisterXlaRuntimeTestlibDialects(DialectRegistry& dialects) { - // Register MLIR dialects supported by the Xla runtime tests. - dialects->insert(); - - // Register MLIR dialects that can be translated to LLVM IR. - registerBuiltinDialectTranslation(*dialects); - registerLLVMDialectTranslation(*dialects); -} - -void CreateXlaRuntimeTestlibPipeline(PassManager& passes) { - passes->addPass(mlir::createConvertSCFToCFPass()); - passes->addPass(mlir::createAsyncFuncToAsyncRuntimePass()); - - // Export functions to the XLA runtime. - passes->addPass(CreateExportRuntimeFunctionsPass()); - passes->addPass(CreateConvertCustomCallsPass()); - passes->addPass(CreateConvertAssertsPass()); - - // Lower from high level async operations to async runtime. - passes->addPass(mlir::createAsyncToAsyncRuntimePass()); - - // Add async.runtime reference counting operations. - passes->addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); - - // Convert runtime operations and custom calls to LLVM dialect. - ConvertRuntimeToLLvmOpts rt_to_llvm_opts; - passes->addPass(CreateConvertRuntimeToLLVMPass(std::move(rt_to_llvm_opts))); - - // Convert async runtime operations to LLVM dialect. - passes->addPass(mlir::createConvertAsyncToLLVMPass()); - - // Convert everything else to LLVM dialect. - passes->addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); - passes->addPass(mlir::createConvertFuncToLLVMPass()); - passes->addPass(mlir::createReconcileUnrealizedCastsPass()); - - // Clean up IR before translating it to LLVM. - passes->addPass(mlir::createCSEPass()); - passes->addPass(mlir::createCanonicalizerPass()); -} - -} // namespace runtime -} // namespace xla diff --git a/xla/mlir/runtime/transforms/tests/testlib_pipeline.h b/xla/mlir/runtime/transforms/tests/testlib_pipeline.h deleted file mode 100644 index 31edab0068e4b..0000000000000 --- a/xla/mlir/runtime/transforms/tests/testlib_pipeline.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_MLIR_RUNTIME_TRANSFORMS_TESTS_TESTLIB_PIPELINE_H_ -#define XLA_MLIR_RUNTIME_TRANSFORMS_TESTS_TESTLIB_PIPELINE_H_ - -#include "xla/runtime/compiler.h" - -namespace xla { -namespace runtime { - -// Registers dialects supported by the Xla runtime tests. -void RegisterXlaRuntimeTestlibDialects(DialectRegistry& dialects); - -// Populates passes for compiling Xla runtime tests. -void CreateXlaRuntimeTestlibPipeline(PassManager& passes); - -} // namespace runtime -} // namespace xla - -#endif // XLA_MLIR_RUNTIME_TRANSFORMS_TESTS_TESTLIB_PIPELINE_H_ diff --git a/xla/mlir/runtime/transforms/type_converter.cc b/xla/mlir/runtime/transforms/type_converter.cc index 4967bcf2640f5..0c0303bdf864c 100644 --- a/xla/mlir/runtime/transforms/type_converter.cc +++ b/xla/mlir/runtime/transforms/type_converter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,11 +24,17 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Async/IR/AsyncTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/primitive_util.h" +#include "xla/runtime/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/transforms/type_converter.h b/xla/mlir/runtime/transforms/type_converter.h index 2c99b81fa8b1d..edf275bf5c255 100644 --- a/xla/mlir/runtime/transforms/type_converter.h +++ b/xla/mlir/runtime/transforms/type_converter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,12 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "xla/runtime/types.h" +#include "xla/xla_data.pb.h" namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/transforms/type_converter_test.cc b/xla/mlir/runtime/transforms/type_converter_test.cc index c85b68132032e..94425abfbfa25 100644 --- a/xla/mlir/runtime/transforms/type_converter_test.cc +++ b/xla/mlir/runtime/transforms/type_converter_test.cc @@ -16,8 +16,11 @@ #include "xla/mlir/runtime/transforms/type_converter.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/runtime/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/mlir/runtime/utils/BUILD b/xla/mlir/runtime/utils/BUILD index 8d7b9d2fc8880..69fb07cee751d 100644 --- a/xla/mlir/runtime/utils/BUILD +++ b/xla/mlir/runtime/utils/BUILD @@ -17,8 +17,11 @@ cc_library( "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/functional:any_invocable", "@llvm-project//llvm:OrcJIT", + "@llvm-project//llvm:OrcShared", + "@llvm-project//llvm:Support", "@llvm-project//mlir:mlir_async_runtime_api", "@tsl//tsl/concurrency:async_value", + "@tsl//tsl/concurrency:ref_count", "@tsl//tsl/platform:platform_port", ], ) @@ -40,12 +43,10 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla/runtime:constraints", - "//xla/runtime:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -68,8 +69,5 @@ cc_library( name = "float_16bits", hdrs = ["float_16bits.h"], compatible_with = get_compatible_with_portable(), - deps = [ - "@llvm-project//llvm:OrcJIT", - "@llvm-project//mlir:mlir_float16_utils", - ], + deps = ["@llvm-project//llvm:OrcJIT"], ) diff --git a/xla/mlir/runtime/utils/async_runtime_api.cc b/xla/mlir/runtime/utils/async_runtime_api.cc index 104af42e4192e..2e7967a9e2eba 100644 --- a/xla/mlir/runtime/utils/async_runtime_api.cc +++ b/xla/mlir/runtime/utils/async_runtime_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,11 +25,18 @@ limitations under the License. #include #include "absl/base/dynamic_annotations.h" +#include "absl/functional/any_invocable.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" #include "mlir/ExecutionEngine/AsyncRuntime.h" // from @llvm-project #include "xla/runtime/async_runtime.h" #include "tsl/concurrency/async_value.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/concurrency/chain.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/mem.h" namespace xla { diff --git a/xla/mlir/runtime/utils/async_runtime_api.h b/xla/mlir/runtime/utils/async_runtime_api.h index 545cd2a1a1c9e..94f3dcd2f4649 100644 --- a/xla/mlir/runtime/utils/async_runtime_api.h +++ b/xla/mlir/runtime/utils/async_runtime_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ limitations under the License. #define XLA_MLIR_RUNTIME_UTILS_ASYNC_RUNTIME_API_H_ #include "absl/functional/any_invocable.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" #include "xla/runtime/async_runtime.h" +#include "tsl/concurrency/async_value.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/concurrency/chain.h" diff --git a/xla/mlir/runtime/utils/c_runner_utils.h b/xla/mlir/runtime/utils/c_runner_utils.h index 7a7910c141ab2..152f7b2bf5c3b 100644 --- a/xla/mlir/runtime/utils/c_runner_utils.h +++ b/xla/mlir/runtime/utils/c_runner_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/utils/constraints.cc b/xla/mlir/runtime/utils/constraints.cc index a48f3958bd55e..da0023ea62b4d 100644 --- a/xla/mlir/runtime/utils/constraints.cc +++ b/xla/mlir/runtime/utils/constraints.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,15 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "xla/runtime/constraints.h" namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/utils/constraints.h b/xla/mlir/runtime/utils/constraints.h index 99cd314cfac38..f88c3f267e0d3 100644 --- a/xla/mlir/runtime/utils/constraints.h +++ b/xla/mlir/runtime/utils/constraints.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,9 @@ limitations under the License. #define XLA_MLIR_RUNTIME_UTILS_CONSTRAINTS_H_ #include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "xla/runtime/constraints.h" diff --git a/xla/mlir/runtime/utils/custom_calls.cc b/xla/mlir/runtime/utils/custom_calls.cc index fb7661bbc6c61..41991bc8f50b1 100644 --- a/xla/mlir/runtime/utils/custom_calls.cc +++ b/xla/mlir/runtime/utils/custom_calls.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,15 @@ limitations under the License. #include "xla/mlir/runtime/utils/custom_calls.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project namespace xla { namespace runtime { diff --git a/xla/mlir/runtime/utils/custom_calls.h b/xla/mlir/runtime/utils/custom_calls.h index 1b50fda35390b..02a85d2c38ff6 100644 --- a/xla/mlir/runtime/utils/custom_calls.h +++ b/xla/mlir/runtime/utils/custom_calls.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/xla/mlir/runtime/utils/float_16bits.h b/xla/mlir/runtime/utils/float_16bits.h index 51b5ac419cb5b..0ea054626b4d3 100644 --- a/xla/mlir/runtime/utils/float_16bits.h +++ b/xla/mlir/runtime/utils/float_16bits.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/runtime/xla-runtime-opt.cc b/xla/mlir/runtime/xla-runtime-opt.cc deleted file mode 100644 index 22adfe8131054..0000000000000 --- a/xla/mlir/runtime/xla-runtime-opt.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project -#include "xla/mlir/math/transforms/passes.h" -#include "xla/mlir/memref/transforms/passes.h" -#include "xla/mlir/runtime/ir/tests/testlib.h" -#include "xla/mlir/runtime/transforms/passes.h" - -int main(int argc, char **argv) { - mlir::DialectRegistry registry; - - registry.insert(); - mlir::func::registerAllExtensions(registry); - xla::registerMathTransformsPasses(); - xla::registerMemrefTransformsPasses(); - xla::runtime::registerRuntimeTransformsPasses(); - - return failed(MlirOptMain(argc, argv, "Xla Runtime Pass Driver\n", registry)); -} diff --git a/xla/mlir/utils/BUILD b/xla/mlir/utils/BUILD index cc92bcb3e6f60..f3a48908faac0 100644 --- a/xla/mlir/utils/BUILD +++ b/xla/mlir/utils/BUILD @@ -1,12 +1,14 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ "//third_party/golang/github_com/gomlx/gomlx:__subpackages__", "//xla:internal", - ], + ]), licenses = ["notice"], ) @@ -37,3 +39,34 @@ cc_test( "@tsl//tsl/platform:test_main", ], ) + +cc_library( + name = "type_util", + srcs = ["type_util.cc"], + hdrs = ["type_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "type_util_test", + srcs = ["type_util_test.cc"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/functional:function_ref", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/xla/mlir/utils/error_util.cc b/xla/mlir/utils/error_util.cc index 335087e159d3b..9df4c468c4257 100644 --- a/xla/mlir/utils/error_util.cc +++ b/xla/mlir/utils/error_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/utils/error_util.h b/xla/mlir/utils/error_util.h index ddf85ec63fdb6..b37f478e173c4 100644 --- a/xla/mlir/utils/error_util.h +++ b/xla/mlir/utils/error_util.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/utils/error_util_test.cc b/xla/mlir/utils/error_util_test.cc index 1a6b8335e31c9..9771cd7a7a5b7 100644 --- a/xla/mlir/utils/error_util_test.cc +++ b/xla/mlir/utils/error_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc new file mode 100644 index 0000000000000..873d960908327 --- /dev/null +++ b/xla/mlir/utils/type_util.cc @@ -0,0 +1,110 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/mlir/utils/type_util.h" + +#include "absl/status/statusor.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/primitive_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr ConvertPrimitiveTypeToMlirType( + xla::PrimitiveType type, mlir::Builder b) { + switch (type) { + case xla::PrimitiveType::PRED: + return b.getI1Type(); + case xla::PrimitiveType::F8E5M2: + return b.getFloat8E5M2Type(); + case xla::PrimitiveType::F8E4M3FN: + return b.getFloat8E4M3FNType(); + case xla::PrimitiveType::F8E4M3B11FNUZ: + return b.getFloat8E4M3B11FNUZType(); + case xla::PrimitiveType::F8E5M2FNUZ: + return b.getFloat8E5M2FNUZType(); + case xla::PrimitiveType::F8E4M3FNUZ: + return b.getFloat8E4M3FNUZType(); + case xla::PrimitiveType::F16: + return b.getF16Type(); + case xla::PrimitiveType::BF16: + return b.getBF16Type(); + case xla::PrimitiveType::F32: + return b.getF32Type(); + case xla::PrimitiveType::F64: + return b.getF64Type(); + // TODO(b/130356985): Support unsigned primitive types. + default: + if (xla::primitive_util::IsIntegralType(type)) { + return mlir::IntegerType::get( + b.getContext(), + /*width=*/xla::primitive_util::BitWidth(type), + /*signed=*/ + xla::primitive_util::IsUnsignedIntegralType(type) + ? mlir::IntegerType::Unsigned + : mlir::IntegerType::Signless); + } + if (xla::primitive_util::IsComplexType(type)) { + TF_ASSIGN_OR_RETURN( + mlir::Type component_type, + xla::ConvertPrimitiveTypeToMlirType( + xla::primitive_util::ComplexComponentType(type), b)); + return mlir::ComplexType::get(component_type); + } + return xla::Internal("Unsupported type: %s", + xla::PrimitiveType_Name(type)); + } +} + +xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { + if (type.isFloat8E5M2()) { + return xla::PrimitiveType::F8E5M2; + } else if (type.isFloat8E4M3FN()) { + return xla::PrimitiveType::F8E4M3FN; + } else if (type.isFloat8E4M3B11FNUZ()) { + return xla::PrimitiveType::F8E4M3B11FNUZ; + } else if (type.isFloat8E4M3FNUZ()) { + return xla::PrimitiveType::F8E4M3FNUZ; + } else if (type.isFloat8E5M2FNUZ()) { + return xla::PrimitiveType::F8E5M2FNUZ; + } else if (type.isBF16()) { + return xla::PrimitiveType::BF16; + } else if (type.isF16()) { + return xla::PrimitiveType::F16; + } else if (type.isF32()) { + return xla::PrimitiveType::F32; + } else if (type.isF64()) { + return xla::PrimitiveType::F64; + } else if (auto complex_type = type.dyn_cast()) { + mlir::Type element_ty = complex_type.getElementType(); + return xla::primitive_util::ComplexType( + ConvertMlirTypeToPrimitiveType(element_ty)); + } else if (auto integer_type = type.dyn_cast()) { + bool is_unsigned = integer_type.isUnsigned(); + if (integer_type.getWidth() == 1) { + return xla::PrimitiveType::PRED; + } + return is_unsigned ? xla::primitive_util::UnsignedIntegralTypeForBitWidth( + integer_type.getWidth()) + : xla::primitive_util::SignedIntegralTypeForBitWidth( + integer_type.getWidth()); + } + return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; +} +} // namespace xla diff --git a/xla/mlir/utils/type_util.h b/xla/mlir/utils/type_util.h new file mode 100644 index 0000000000000..1505a55cfed2f --- /dev/null +++ b/xla/mlir/utils/type_util.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_MLIR_UTILS_TYPE_UTIL_H_ +#define XLA_MLIR_UTILS_TYPE_UTIL_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/xla_data.pb.h" + +// Type utilities to match MLIR types to XLA primitive types and vice versa. +namespace xla { +// Converts an XLA primitive type to the corresponding MLIR type. +// Signed XLA primitive types are converted to signless MLIR types; +// unsigned XLA primitive types are converted to unsigned MLIR types. +absl::StatusOr ConvertPrimitiveTypeToMlirType( + xla::PrimitiveType type, mlir::Builder b); + +// Returns an XLA xla::PrimitiveType equivalent of an MLIR Type that represents +// a primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID. +// Signless MLIR types are converted to signed XLA primitive types. +xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type); +} // namespace xla + +#endif // XLA_MLIR_UTILS_TYPE_UTIL_H_ diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc new file mode 100644 index 0000000000000..67110b32e1fc5 --- /dev/null +++ b/xla/mlir/utils/type_util_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/mlir/utils/type_util.h" + +#include +#include + +#include +#include "absl/functional/function_ref.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +// A pair of corresponding types +struct TypeUtilTestParam { + xla::PrimitiveType xla_t; + absl::FunctionRef mlir_t; +}; + +inline std::string mlirTypeToString(mlir::Type type) { + std::string result{}; + llvm::raw_string_ostream sstream(result); + sstream << type; + return result; +} + +class TypeUtilTest : public ::testing::TestWithParam {}; + +TEST_P(TypeUtilTest, ConvertInvalidTypeTest) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + EXPECT_EQ(ConvertMlirTypeToPrimitiveType(b.getIntegerType(17)), + xla::PrimitiveType::PRIMITIVE_TYPE_INVALID); +} + +TEST_P(TypeUtilTest, MLIRToPrimitiveTypeConversionTest) { + mlir::MLIRContext context = mlir::MLIRContext(); + mlir::Builder b = mlir::Builder(&context); + xla::PrimitiveType xla_type_expected = GetParam().xla_t; + mlir::Type mlir_type = GetParam().mlir_t(b); + xla::PrimitiveType xla_type_actual = + ConvertMlirTypeToPrimitiveType(mlir_type); + EXPECT_EQ(xla_type_actual, xla_type_expected) + << "Expected: " + << primitive_util::LowercasePrimitiveTypeName(xla_type_expected) + << ". Actual: " + << primitive_util::LowercasePrimitiveTypeName(xla_type_actual) << "."; +} + +TEST_P(TypeUtilTest, PrimitiveTypeToMLIRTypeConversionTest) { + mlir::MLIRContext context = mlir::MLIRContext(); + mlir::Builder b = mlir::Builder(&context); + xla::PrimitiveType xla_type = GetParam().xla_t; + mlir::Type mlir_type_expected = GetParam().mlir_t(b); + TF_ASSERT_OK_AND_ASSIGN(mlir::Type mlir_type_actual, + ConvertPrimitiveTypeToMlirType(xla_type, b)); + EXPECT_EQ(mlir_type_actual, mlir_type_expected) + << "Expected: " << mlirTypeToString(mlir_type_expected) + << ". Actual: " << mlirTypeToString(mlir_type_actual) << "."; +} + +TEST_P(TypeUtilTest, BidirectionalConversionTest) { + mlir::MLIRContext context = mlir::MLIRContext(); + mlir::Builder b = mlir::Builder(&context); + xla::PrimitiveType xla_type_expected = GetParam().xla_t; + TF_ASSERT_OK_AND_ASSIGN(mlir::Type mlir_type_actual, + ConvertPrimitiveTypeToMlirType(xla_type_expected, b)); + xla::PrimitiveType xla_type_actual = + ConvertMlirTypeToPrimitiveType(mlir_type_actual); + EXPECT_EQ(xla_type_actual, xla_type_expected) + << "Expected: " + << primitive_util::LowercasePrimitiveTypeName(xla_type_expected) + << ". Actual: " + << primitive_util::LowercasePrimitiveTypeName(xla_type_actual) + << ". Intermediate MLIR type: " << mlirTypeToString(mlir_type_actual) + << "."; +} + +INSTANTIATE_TEST_SUITE_P( + Execute, TypeUtilTest, + ::testing::ValuesIn(std::vector( + {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, + {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, + {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, + {F8E4M3B11FNUZ, + [](mlir::Builder b) { return b.getFloat8E4M3B11FNUZType(); }}, + {F8E5M2FNUZ, + [](mlir::Builder b) { return b.getFloat8E5M2FNUZType(); }}, + {F8E4M3FNUZ, + [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, + {F16, [](mlir::Builder b) { return b.getF16Type(); }}, + {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, + {F32, [](mlir::Builder b) { return b.getF32Type(); }}, + {F64, [](mlir::Builder b) { return b.getF64Type(); }}, + {U4, [](mlir::Builder b) { return b.getIntegerType(4, false); }}, + {U8, [](mlir::Builder b) { return b.getIntegerType(8, false); }}, + {U16, [](mlir::Builder b) { return b.getIntegerType(16, false); }}, + {U32, [](mlir::Builder b) { return b.getIntegerType(32, false); }}, + {U64, [](mlir::Builder b) { return b.getIntegerType(64, false); }}, + {S4, + [](mlir::Builder b) { + return mlir::IntegerType::get(b.getContext(), 4, + mlir::IntegerType::Signless); + }}, + {S8, + [](mlir::Builder b) { + return mlir::IntegerType::get(b.getContext(), 8, + mlir::IntegerType::Signless); + }}, + {S16, + [](mlir::Builder b) { + return mlir::IntegerType::get(b.getContext(), 16, + mlir::IntegerType::Signless); + }}, + {S32, + [](mlir::Builder b) { + return mlir::IntegerType::get(b.getContext(), 32, + mlir::IntegerType::Signless); + }}, + {S64, + [](mlir::Builder b) { + return mlir::IntegerType::get(b.getContext(), 64, + mlir::IntegerType::Signless); + }}})), + [](const auto& info) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + return absl::StrFormat( + "xla_%s_mlir_%s", + primitive_util::LowercasePrimitiveTypeName(info.param.xla_t), + mlirTypeToString(info.param.mlir_t(b))); + }); + +} // namespace +} // namespace xla diff --git a/xla/mlir/xla_cpu/ir/BUILD b/xla/mlir/xla_cpu/ir/BUILD index 821e79e1ef216..7d3f2007cf1c5 100644 --- a/xla/mlir/xla_cpu/ir/BUILD +++ b/xla/mlir/xla_cpu/ir/BUILD @@ -1,10 +1,14 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load( + "@tsl//tsl:tsl.default.bzl", + "get_compatible_with_portable", +) +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility(["//learning/brain/mlir:xla_friends"]), ) td_library( @@ -100,6 +104,7 @@ cc_library( "//xla/mlir_hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", ], diff --git a/xla/mlir/xla_cpu/ir/xla_cpu.cc b/xla/mlir/xla_cpu/ir/xla_cpu.cc index c8378c4768560..3e977dc7f2e7c 100644 --- a/xla/mlir/xla_cpu/ir/xla_cpu.cc +++ b/xla/mlir/xla_cpu/ir/xla_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/xla_cpu/ir/xla_cpu.h b/xla/mlir/xla_cpu/ir/xla_cpu.h index 17197dfe1c5c2..6391e985d1622 100644 --- a/xla/mlir/xla_cpu/ir/xla_cpu.h +++ b/xla/mlir/xla_cpu/ir/xla_cpu.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td b/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td index 3008651981409..026168df1f189 100644 --- a/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td +++ b/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,8 +27,6 @@ def XlaCpuDialect : Dialect { CPU runtime. }]; let cppNamespace = "::mlir::xla_cpu"; - - let usePropertiesForAttributes = 0; } #endif // XLA_MLIR_XLA_CPU_DIALECT_TD_ diff --git a/xla/mlir/xla_cpu/ir/xla_cpu_enums.td b/xla/mlir/xla_cpu/ir/xla_cpu_enums.td index 8b6ba925f23a9..ff09a17c97520 100644 --- a/xla/mlir/xla_cpu/ir/xla_cpu_enums.td +++ b/xla/mlir/xla_cpu/ir/xla_cpu_enums.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/xla_cpu/ir/xla_cpu_ops.td b/xla/mlir/xla_cpu/ir/xla_cpu_ops.td index fa554a2c622b1..ce107cf5e2531 100644 --- a/xla/mlir/xla_cpu/ir/xla_cpu_ops.td +++ b/xla/mlir/xla_cpu/ir/xla_cpu_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir/xla_cpu/tests/BUILD b/xla/mlir/xla_cpu/tests/BUILD deleted file mode 100644 index c1d06c5f97dde..0000000000000 --- a/xla/mlir/xla_cpu/tests/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = [ - "mlir", - ], -) - -# Bundle together all of the test utilities that are used by tests. -# This intentionally does not pull-in the top-level tf-opt to reduce the -# dependencies. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/mlir/backends/cpu:xla-cpu-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", - ], -) diff --git a/xla/mlir/xla_cpu/tests/bufferize.mlir b/xla/mlir/xla_cpu/tests/bufferize.mlir deleted file mode 100644 index f6d727d415a08..0000000000000 --- a/xla/mlir/xla_cpu/tests/bufferize.mlir +++ /dev/null @@ -1,133 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -empty-tensor-to-alloc-tensor \ -// RUN: -one-shot-bufferize | FileCheck %s - -func.func @max_reduce(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = tensor.empty() : tensor<10xf32> - %1 = "xla_cpu.all_reduce"(%arg0, %0) { - channel_handle = 5 : i64, - reduction_kind = 3 : i32, - replica_groups = dense<[]> : tensor<0xi64>, - use_global_device_ids = 0 : i32 - } : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - return %1 : tensor<10xf32> -} - -// CHECK-LABEL: @max_reduce -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32> -// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] -// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<10xf32> -// CHECK: "xla_cpu.all_reduce"(%[[ARG0_MEMREF]], %[[OUT]]) { -// CHECK-SAME: channel_handle = 5 -// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { - %0 = tensor.empty() : tensor<16x8xf32> - %1 = "xla_cpu.collective_permute"(%arg0, %0) { - channel_handle = 1 : i64, - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> - } : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> - return %1 : tensor<16x8xf32> -} - -// CHECK-LABEL: @collective_permute -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x8xf32> -// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] -// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<16x8xf32> -// CHECK: "xla_cpu.collective_permute"(%[[ARG0_MEMREF]], %[[OUT]]) { -// CHECK-SAME: channel_handle = 1 -// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] -// CHECK: return %[[RESULT]] - -// ----- - -func.func @all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { - %0 = tensor.empty() : tensor<16x4xf32> - %1 = "xla_cpu.all_to_all"(%arg0, %0) { - concat_dimension = 0 : i64, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - channel_id_present = 0 : i32, - op_id = 0 : i64, - split_count = 4 : i64, - split_dimension = 1 : i64 - } : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<16x4xf32> - return %1 : tensor<16x4xf32> -} - -// CHECK-LABEL: @all_to_all -// CHECK-SAME: %[[ARG0:.*]]: tensor<4x16xf32> -// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] -// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<16x4xf32> -// CHECK: "xla_cpu.all_to_all"(%[[ARG0_MEMREF]], %[[OUT]]) { -// CHECK-SAME: split_count = 4 -// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] -// CHECK: return %[[RESULT]] - - -// ----- - -func.func @all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) - -> (tensor<128x4xf32>, tensor<128x4xf32>) { - %0 = tensor.empty() : tensor<128x4xf32> - %1 = tensor.empty() : tensor<128x4xf32> - %2:2 = "xla_cpu.all_to_all"(%arg0, %arg1, %0, %1) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_id_present = 0 : i32, - op_id = 0 : i64 - } : (tensor<128x4xf32>, tensor<128x4xf32>, - tensor<128x4xf32>, tensor<128x4xf32>) -> - (tensor<128x4xf32>, tensor<128x4xf32>) - return %2#0, %2#1 : tensor<128x4xf32>, tensor<128x4xf32> -} - -// CHECK-LABEL: @all_to_all_tuple -// CHECK-SAME: %[[ARG0:.*]]: tensor<128x4xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<128x4xf32> -// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] -// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1]] -// CHECK-DAG: "xla_cpu.all_to_all"(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]], %[[OUT0:.*]], %[[OUT1:.*]]) { -// CHECK-DAG: %[[OUT0]] = memref.alloc() {{.*}} memref<128x4xf32> -// CHECK-DAG: %[[OUT1]] = memref.alloc() {{.*}} memref<128x4xf32> -// CHECK-DAG: %[[RESULT0:.*]] = bufferization.to_tensor %[[OUT0]] : -// CHECK-DAG: %[[RESULT1:.*]] = bufferization.to_tensor %[[OUT1]] : -// CHECK: return %[[RESULT0]], %[[RESULT1]] - -// ----- - -func.func @fft(%arg0: tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> { - %0 = tensor.empty() : tensor<3x5x4x8x129xcomplex> - %1 = "xla_cpu.fft"(%arg0, %0) { - fft_length = [4, 8, 256], - fft_type = 2 : i32 - } : (tensor<3x5x4x8x256xf32>,tensor<3x5x4x8x129xcomplex>) -> tensor<3x5x4x8x129xcomplex> - return %1 : tensor<3x5x4x8x129xcomplex> -} - -// CHECK-LABEL: @fft -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5x4x8x256xf32> -// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] -// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} -// CHECK: "xla_cpu.fft"(%[[ARG0_MEMREF]], %[[OUT]]) - - -// ----- - -func.func @rng_bit_generator(%state: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { - %new_state_init = tensor.empty() : tensor<2xui64> - %output_init = tensor.empty() : tensor<10x12xui32> - %new_state, %output = "xla_cpu.rng_bit_generator"(%state, %new_state_init, - %output_init) { - rng_algorithm = #mhlo.rng_algorithm - } : (tensor<2xui64>, tensor<2xui64>, tensor<10x12xui32>) - -> (tensor<2xui64>, tensor<10x12xui32>) - func.return %new_state, %output : tensor<2xui64>, tensor<10x12xui32> -} - -// CHECK-LABEL: @rng_bit_generator -// CHECK-SAME: %[[STATE:.*]]: tensor -// CHECK: %[[STATE_MEMREF:.*]] = bufferization.to_memref %[[STATE]] -// CHECK: %[[STATE_OUT:.*]] = memref.alloc() {{.*}}<2xui64> -// CHECK: %[[OUTPUT:.*]] = memref.alloc() {{.*}}<10x12xui32> -// CHECK: "xla_cpu.rng_bit_generator"(%[[STATE_MEMREF]], %[[STATE_OUT]], %[[OUTPUT]]) \ No newline at end of file diff --git a/xla/mlir/xla_cpu/tests/invalid.mlir b/xla/mlir/xla_cpu/tests/invalid.mlir deleted file mode 100644 index 8f9584417e6dd..0000000000000 --- a/xla/mlir/xla_cpu/tests/invalid.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -verify-diagnostics - -func.func @memref_cast_out_of_place(%arg0: memref<10xi1>) -> memref<10xi16> { - // expected-error @+1 {{cannot cast from 'i1' to 'i16'}} - %ret = xla_cpu.memref_element_cast %arg0 : memref<10xi1> to memref<10xi16> - return %ret : memref<10xi16> -} diff --git a/xla/mlir/xla_cpu/tests/ops.mlir b/xla/mlir/xla_cpu/tests/ops.mlir deleted file mode 100644 index 7f06ab3fd3d17..0000000000000 --- a/xla/mlir/xla_cpu/tests/ops.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: xla-cpu-opt %s -split-input-file -empty-tensor-to-alloc-tensor \ -// RUN: -one-shot-bufferize | FileCheck %s - -func.func @memref_cast(%arg0: memref<10xf32>) -> memref<10xi32> { - %ret = xla_cpu.memref_element_cast %arg0 : memref<10xf32> to memref<10xi32> - return %ret : memref<10xi32> -} - -// CHECK: xla_cpu.memref_element_cast {{.*}} : memref<10xf32> to memref<10xi32> - -func.func @memref_cast_i1(%arg0: memref<10xi1>) -> memref<10xi8> { - %ret = xla_cpu.memref_element_cast %arg0 : memref<10xi1> to memref<10xi8> - return %ret : memref<10xi8> -} - -// CHECK: xla_cpu.memref_element_cast {{.*}} : memref<10xi1> to memref<10xi8> \ No newline at end of file diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD index a0e7075d59e32..7ee7c8f493249 100644 --- a/xla/mlir_hlo/BUILD +++ b/xla/mlir_hlo/BUILD @@ -1,11 +1,12 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility(["//learning/brain/mlir:mhlo_friends"]), licenses = ["notice"], ) @@ -219,63 +220,6 @@ gentbl_cc_library( deps = [":lhlo_ops_td_files"], ) -gentbl_cc_library( - name = "lhlo_gpu_ops_enums_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", - deps = [":lhlo_gpu_ops_td_files"], -) - -gentbl_cc_library( - name = "lhlo_gpu_ops_dialect_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", - deps = [":lhlo_gpu_ops_td_files"], -) - -gentbl_cc_library( - name = "lhlo_gpu_ops_attrdefs_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", - deps = [":lhlo_gpu_ops_td_files"], -) - gentbl_filegroup( name = "hlo_ops_doc_gen", compatible_with = get_compatible_with_portable(), @@ -321,37 +265,6 @@ cc_library( ], ) -td_library( - name = "lhlo_gpu_ops_td_files", - srcs = glob(["lhlo_gpu/IR/*.td"]), - compatible_with = get_compatible_with_portable(), - includes = ["."], - deps = [ - ":hlo_ops_td_files", - ":lhlo_ops_td_files", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "lhlo_gpu_ops_inc_gen", - compatible_with = get_compatible_with_portable(), - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "lhlo_gpu/IR/lhlo_gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lhlo_gpu/IR/lhlo_gpu_ops.td", - deps = [":lhlo_gpu_ops_td_files"], -) - #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl_cc_library( name = "canonicalize_inc_gen", @@ -393,6 +306,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -526,6 +440,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:Dialect", @@ -533,6 +448,7 @@ cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", @@ -575,51 +491,9 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:ViewLikeInterface", - "@stablehlo//:stablehlo_type_inference", - ], -) - -cc_library( - name = "lhlo_gpu", - srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc"], - hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h"], - strip_include_prefix = ".", - deps = [ - ":hlo_ops_common", - ":lhlo", - ":lhlo_gpu_ops_attrdefs_inc_gen", - ":lhlo_gpu_ops_dialect_inc_gen", - ":lhlo_gpu_ops_enums_inc_gen", - ":lhlo_gpu_ops_inc_gen", - ":lhlo_gpu_ops_ops", - ":mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "lhlo_gpu_ops_ops", - srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc.inc"], - hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h.inc"], - strip_include_prefix = ".", - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:CopyOpInterface", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:LoopLikeInterface", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", "@llvm-project//mlir:ViewLikeInterface", + "@stablehlo//:stablehlo_type_inference", ], ) @@ -662,8 +536,6 @@ cc_library( "mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc", "mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc", "mhlo/transforms/legalize_sort/legalize_sort.cc", - "mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc", - "mhlo/transforms/legalize_sparse_ops/sparse_ops_to_custom_calls.cc", "mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc", "mhlo/transforms/legalize_to_standard/generated_legalize_to_standard.inc", "mhlo/transforms/legalize_to_standard/legalize_to_standard.cc", @@ -681,16 +553,15 @@ cc_library( "mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc", "mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc", "mhlo/transforms/mhlo_passes.h.inc", + "mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc", "mhlo/transforms/optimize_mhlo/optimize_mhlo.cc", "mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc", "mhlo/transforms/prepare_for_export/prepare_for_export.cc", - "mhlo/transforms/rank_specialization/rank_specialization.cc", "mhlo/transforms/restrict_max_rank/restrict_max_rank.cc", "mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc", "mhlo/transforms/shape_reification/shape_reification_pass.cc", "mhlo/transforms/shape_simplification/shape_simplification.cc", "mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc", - "mhlo/transforms/sparse_rewriting/sparse_rewriting.cc", "mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc", "mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc", "mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc", @@ -705,7 +576,7 @@ cc_library( ], strip_include_prefix = ".", deps = [ - ":chlo_legalize_to_hlo", + ":chlo_legalize_to_hlo_inc_gen", ":hlo_legalize_to_stablehlo", ":legalize_to_linalg_utils", ":legalize_to_standard_inc_gen", @@ -727,6 +598,7 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowDialect", @@ -741,6 +613,7 @@ cc_library( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeTransforms", @@ -749,10 +622,12 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:base", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", ], ) @@ -767,6 +642,7 @@ cc_library( "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", ], @@ -777,17 +653,8 @@ cc_library( hdrs = ["lhlo/transforms/map_lmhlo_to_scalar_op.h"], strip_include_prefix = ".", deps = [ - ":lhlo", ":map_lhlo_to_hlo_op", ":map_mhlo_to_scalar_op", - ":mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:SCFDialect", ], ) @@ -803,7 +670,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:VectorDialect", ], ) @@ -844,7 +710,6 @@ cc_library( strip_include_prefix = ".", deps = [ ":mlir_hlo", - "@llvm-project//mlir:IR", "@stablehlo//:stablehlo_ops", ], ) @@ -909,27 +774,6 @@ cc_library( deps = ["@llvm-project//llvm:Support"], ) -cc_library( - name = "lhlo_elemental_utils", - srcs = ["lhlo/transforms/lhlo_elemental_utils.cc"], - hdrs = ["lhlo/transforms/lhlo_elemental_utils.h"], - strip_include_prefix = ".", - deps = [ - ":codegen_utils", - ":lhlo", - ":map_lmhlo_to_scalar_op", - ":mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Transforms", - ], -) - cc_library( name = "legalize_to_linalg_utils", srcs = ["mhlo/utils/legalize_to_linalg_utils.cc"], @@ -952,6 +796,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", ], @@ -1039,32 +884,11 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) -cc_library( - name = "chlo_legalize_to_hlo", - srcs = ["mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc"], - hdrs = ["mhlo/transforms/rewriters.h"], - strip_include_prefix = ".", - deps = [ - ":chlo_legalize_to_hlo_inc_gen", - ":map_chlo_to_hlo_op", - ":mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@stablehlo//:broadcast_utils", - "@stablehlo//:chlo_ops", - ], -) - gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", compatible_with = get_compatible_with_portable(), @@ -1092,6 +916,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_ops_inc_gen", @@ -1111,6 +936,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_ops_inc_gen", @@ -1134,7 +960,6 @@ cc_library( ], strip_include_prefix = ".", deps = [ - ":chlo_legalize_to_hlo", ":deallocation_passes", ":deallocation_passes_inc_gen", ":lhlo", @@ -1145,7 +970,6 @@ cc_library( ":stablehlo_legalize_to_hlo", ":transforms_passes", ":transforms_passes_inc_gen", - ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -1158,10 +982,7 @@ cc_library( cc_library( name = "transforms_passes", srcs = [ - "analysis/test_userange_analysis.cc", - "mhlo/analysis/test_shape_component_analysis.cc", "transforms/alloc_to_arg_pass.cc", - "transforms/buffer_packing.cc", "transforms/bufferize.cc", "transforms/bufferize_pass.cc", "transforms/collapse_parallel_loops_to_1d_pass.cc", @@ -1190,7 +1011,6 @@ cc_library( ":shape_component_analysis", ":transforms_passes_inc_gen", ":type_conversion", - ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -1201,6 +1021,7 @@ cc_library( "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToLLVM", @@ -1242,6 +1063,7 @@ cc_library( "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:TensorUtils", "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:TransformDialectInterfaces", "@llvm-project//mlir:TransformDialectTransforms", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", @@ -1281,7 +1103,9 @@ cc_library( "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUCommonTransforms", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", @@ -1301,6 +1125,7 @@ cc_library( "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", @@ -1344,20 +1169,6 @@ gentbl_cc_library( deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) -cc_library( - name = "userange_analysis", - srcs = ["analysis/userange_analysis.cc"], - hdrs = ["analysis/userange_analysis.h"], - strip_include_prefix = ".", - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopLikeInterface", - ], -) - cc_library( name = "shape_component_analysis", srcs = ["mhlo/analysis/shape_component_analysis.cc"], @@ -1437,7 +1248,6 @@ cc_binary( ":all_passes", ":hlo_dialect_registration", ":lhlo", - ":lhlo_gpu", ":transforms_gpu_passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllExtensions", diff --git a/xla/mlir_hlo/CMakeLists.txt b/xla/mlir_hlo/CMakeLists.txt index 9bfdc58b3a3eb..bb81e4a7e80ea 100644 --- a/xla/mlir_hlo/CMakeLists.txt +++ b/xla/mlir_hlo/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -158,11 +158,9 @@ set(MLIR_HLO_TOOLS_DIR ${MLIR_HLO_BINARY_DIR}/bin) set(MLIR_HLO_LIB_DIR ${MLIR_HLO_BINARY_DIR}/lib) add_custom_target(check-mlir-hlo) -add_subdirectory(analysis) add_subdirectory(bindings) add_subdirectory(deallocation) add_subdirectory(lhlo) -add_subdirectory(lhlo_gpu) add_subdirectory(mhlo) add_subdirectory(stablehlo) add_subdirectory(tests) diff --git a/xla/mlir_hlo/WORKSPACE b/xla/mlir_hlo/WORKSPACE index c3115e33da9ab..ae7ca4dd0254e 100644 --- a/xla/mlir_hlo/WORKSPACE +++ b/xla/mlir_hlo/WORKSPACE @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Workspace for MLIR HLO.""" +# buildifier: disable=load-on-top + +# buildifier: disable=load-on-top load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") diff --git a/xla/mlir_hlo/analysis/CMakeLists.txt b/xla/mlir_hlo/analysis/CMakeLists.txt deleted file mode 100644 index 88a0fec149403..0000000000000 --- a/xla/mlir_hlo/analysis/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -add_mlir_library(MLIRHLOAnalysis - userange_analysis.cc - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - MLIRAnalysis - MLIRIR -) - -add_mlir_library(MLIRHLOTestAnalysis - test_userange_analysis.cc - - DEPENDS - LMHLOTransformsPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - LmhloDialect - LmhloGPUDialect - MLIRHLOAnalysis - MLIRAnalysis - MLIRPass - MLIRTransforms -) diff --git a/xla/mlir_hlo/analysis/test_userange_analysis.cc b/xla/mlir_hlo/analysis/test_userange_analysis.cc deleted file mode 100644 index 7fde2920a174d..0000000000000 --- a/xla/mlir_hlo/analysis/test_userange_analysis.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "analysis/userange_analysis.h" -#include "lhlo/IR/lhlo_ops.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { - -#define GEN_PASS_DEF_TESTUSERANGE -#include "transforms/passes.h.inc" - -namespace { - -struct TestUserangePass : public impl::TestUserangeBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - llvm::outs() << "Testing : " << getOperation().getName() << "\n"; - UserangeAnalysis(getOperation(), - bufferization::BufferPlacementAllocs(getOperation()), - BufferViewFlowAnalysis(getOperation())) - .dump(llvm::outs()); - } -}; - -} // end anonymous namespace - -std::unique_ptr> createTestUserangePass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/xla/mlir_hlo/analysis/userange_analysis.cc b/xla/mlir_hlo/analysis/userange_analysis.cc deleted file mode 100644 index cc25afdc6959d..0000000000000 --- a/xla/mlir_hlo/analysis/userange_analysis.cc +++ /dev/null @@ -1,625 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "analysis/userange_analysis.h" - -#include -#include -#include -#include - -#include "llvm/ADT/SetOperations.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Region.h" -#include "mlir/Interfaces/LoopLikeInterface.h" - -using namespace mlir; - -namespace { -/// Builds a userange information from the given value and its liveness. The -/// information includes all operations that are within the userange. -struct UserangeInfoBuilder { - using OperationListT = Liveness::OperationListT; - using ValueSetT = BufferViewFlowAnalysis::ValueSetT; - - public: - /// Constructs an Userange builder. - UserangeInfoBuilder(Liveness liveness, ValueSetT values, - OperationListT opList) - : values(std::move(values)), - opList(std::move(opList)), - liveness(std::move(liveness)) {} - - /// Computes the userange of the current value by iterating over all of its - /// uses. - Liveness::OperationListT computeUserange() { - Region *topRegion = findTopRegion(); - // Iterate over all associated uses. - for (Operation *use : opList) { - // If one of the parents implements a LoopLikeOpInterface we need to add - // all operations inside of its regions to the userange. - Operation *loopParent = use->getParentOfType(); - if (loopParent && topRegion->isProperAncestor(use->getParentRegion())) - addAllOperationsInRegion(loopParent); - - // Check if the parent block has already been processed. - Block *useBlock = findTopLiveBlock(use); - if (!startBlocks.insert(useBlock).second || visited.contains(useBlock)) - continue; - - // Add all operations inside the block that are within the userange. - findOperationsInUse(useBlock); - } - return currentUserange; - } - - private: - /// Find the top most Region of all values stored in the values set. - Region *findTopRegion() const { - Region *topRegion = nullptr; - llvm::for_each(values, [&](Value v) { - Region *other = v.getParentRegion(); - if (!topRegion || topRegion->isAncestor(other)) topRegion = other; - }); - return topRegion; - } - - /// Finds the highest level block that has the current value in its liveOut - /// set. - Block *findTopLiveBlock(Operation *op) const { - Operation *topOp = op; - while (const LivenessBlockInfo *blockInfo = - liveness.getLiveness(op->getBlock())) { - if (llvm::any_of(values, - [&](Value v) { return blockInfo->isLiveOut(v); })) - topOp = op; - op = op->getParentOp(); - } - return topOp->getBlock(); - } - - /// Adds all operations from start to end to the userange of the current - /// value. If an operation implements a nested region all operations inside of - /// it are included as well. If includeEnd is false the end operation is not - /// added. - void addAllOperationsBetween(Operation *start, Operation *end) { - currentUserange.push_back(start); - addAllOperationsInRegion(start); - - while (start != end) { - start = start->getNextNode(); - addAllOperationsInRegion(start); - currentUserange.push_back(start); - } - } - - /// Adds all operations that are uses of the value in the given block to the - /// userange of the current value. Additionally iterate over all successors - /// where the value is live. - void findOperationsInUse(Block *block) { - SmallVector blocksToProcess; - addOperationsInBlockAndFindSuccessors( - block, block, getStartOperation(block), blocksToProcess); - while (!blocksToProcess.empty()) { - Block *toProcess = blocksToProcess.pop_back_val(); - addOperationsInBlockAndFindSuccessors( - block, toProcess, &toProcess->front(), blocksToProcess); - } - } - - /// Adds the operations between the given start operation and the computed end - /// operation to the userange. If the current value is live out, add all - /// successor blocks that have the value live in to the process queue. If we - /// find a loop, add the operations before the first use in block to the - /// userange (if any). The startBlock is the block where the iteration over - /// all successors started and is propagated further to find potential loops. - void addOperationsInBlockAndFindSuccessors( - const Block *startBlock, Block *toProcess, Operation *start, - SmallVector &blocksToProcess) { - const LivenessBlockInfo *blockInfo = liveness.getLiveness(toProcess); - Operation *end = getEndOperation(toProcess); - - addAllOperationsBetween(start, end); - - // If the value is live out we need to process all successors at which the - // value is live in. - if (!llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); })) - return; - for (Block *successor : toProcess->getSuccessors()) { - // If the successor is the startBlock, we found a loop and only have to - // add the operations from the block front to the first use of the - // value. - if (!llvm::any_of(values, [&](Value v) { - return liveness.getLiveness(successor)->isLiveIn(v); - })) - continue; - if (successor == startBlock) { - start = &successor->front(); - end = getStartOperation(successor); - if (start != end) addAllOperationsBetween(start, end->getPrevNode()); - // Else we need to check if the value is live in and the successor - // has not been visited before. If so we also need to process it. - } else if (visited.insert(successor).second) { - blocksToProcess.push_back(successor); - } - } - } - - /// Iterates over all regions of a given operation and adds all operations - /// inside those regions to the userange of the current value. - void addAllOperationsInRegion(Operation *parentOp) { - // Iterate over all regions of the parentOp. - for (Region ®ion : parentOp->getRegions()) { - // Iterate over blocks inside the region. - for (Block &block : region) { - // If the blocks have been used as a startBlock before, we need to add - // all operations between the block front and the startOp of the value. - if (startBlocks.contains(&block)) { - Operation *start = &block.front(); - Operation *end = getStartOperation(&block); - if (start != end) addAllOperationsBetween(start, end->getPrevNode()); - - // If the block has never been seen before, we need to add all - // operations inside. - } else if (visited.insert(&block).second) { - for (Operation &op : block) { - addAllOperationsInRegion(&op); - currentUserange.push_back(&op); - } - continue; - } - // If the block has either been visited before or was used as a - // startBlock, we need to add all operations between the endOp of the - // value and the end of the block. - Operation *end = getEndOperation(&block); - if (end == &block.back()) continue; - addAllOperationsBetween(end->getNextNode(), &block.back()); - } - } - } - - /// Find the start operation of the current value inside the given block. - Operation *getStartOperation(Block *block) { - Operation *startOperation = &block->back(); - for (Operation *useOp : opList) { - // Find the associated operation in the current block (if any). - useOp = block->findAncestorOpInBlock(*useOp); - // Check whether the use is in our block and after the current end - // operation. - if (useOp && useOp->isBeforeInBlock(startOperation)) - startOperation = useOp; - } - return startOperation; - } - - /// Find the end operation of the current value inside the given block. - Operation *getEndOperation(Block *block) { - const LivenessBlockInfo *blockInfo = liveness.getLiveness(block); - if (llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); })) - return &block->back(); - - Operation *endOperation = &block->front(); - for (Operation *useOp : opList) { - // Find the associated operation in the current block (if any). - useOp = block->findAncestorOpInBlock(*useOp); - // Check whether the use is in our block and after the current end - // operation. - if (useOp && endOperation->isBeforeInBlock(useOp)) endOperation = useOp; - } - return endOperation; - } - - /// The current Value. - ValueSetT values; - - /// The list of all operations used by the values. - OperationListT opList; - - /// The result list of the userange computation. - OperationListT currentUserange; - - /// The set of visited blocks during the userange computation. - SmallPtrSet visited; - - /// The set of blocks that the userange computation started from. - SmallPtrSet startBlocks; - - /// The current liveness info. - Liveness liveness; -}; -} // namespace - -/// Empty UseInterval Constructor. -UseInterval::UseInterval() - : start(std::numeric_limits::max()), - end(std::numeric_limits::min()) {} - -/// Performs an interval subtraction => A = A - B. -void UseInterval::intervalSubtract(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - for (auto *iterA = a.begin(); iterA != a.end() && iterB != endB;) { - // iterA is strictly before iterB => increment iterA. - if (*iterA < *iterB) { - ++iterA; - // iterB is strictly before iterA => increment iterB. - } else if (*iterA > *iterB) { - ++iterB; - // iterB overlaps with the start of iterA, but iterA has some values that - // go beyond those of iterB. We have to set the start of iterA to the end - // of iterB + 1 and increment iterB. A(3, 100) - B(3, 5) => A(6,100) - } else if (iterA->start >= iterB->start && iterA->end > iterB->end) { - iterA->start = iterB->end + 1; - ++iterB; - // iterB overlaps with the end of iterA, but iterA has some values that - // come before iterB. We have to set the end of iterA to the start of - // iterB - 1 and increment iterA. A(4, 50) - B(40, 50) => A(4, 39) - } else if (iterA->end <= iterB->end && iterA->start < iterB->start) { - iterA->end = iterB->start - 1; - ++iterA; - // iterB is in the middle of iterA. We have to split iterA and increment - // iterB. - // A(2, 10) - B(5, 7) => (2, 4), (8, 10) - } else if (iterA->start < iterB->start && iterA->end > iterB->end) { - size_t endA = iterA->end; - iterA->end = iterB->start - 1; - iterA = a.insert(iterA, UseInterval(iterB->end + 1, endA)); - ++iterB; - // Both intervals are equal. We have to erase the whole interval. - // A(5, 5) - B(5, 5) => {} - } else { - iterA = a.erase(iterA); - ++iterB; - } - } -} - -/// Performs an interval intersection => A = A ^ B. -void UseInterval::intervalIntersect(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - for (auto *iterA = a.begin(); iterA != a.end();) { - // iterB points to the end, therefore the remaining UseIntervals from A must - // be erased or iterA is strictly before iterB => erase iterA. - if (iterB == endB || *iterA < *iterB) { - iterA = a.erase(iterA); - // iterB is strictly before iterA => increment iterB. - } else if (*iterA > *iterB) { - ++iterB; - // iterB overlaps with iterA => reduce the interval to the overlap and - // insert the ending split-off to vector A again. - } else { - size_t currentEndA = iterA->end; - iterA->start = std::max(iterA->start, iterB->start); - iterA->end = std::min(currentEndA, iterB->end); - if (currentEndA > iterB->end) { - iterA = a.insert(std::next(iterA), - UseInterval(iterB->end + 1, currentEndA)); - ++iterB; - } else { - ++iterA; - } - } - } -} - -/// Performs an interval merge => A = A u B. -/// Note: All overlapping and contiguous UseIntervals are merged. -void UseInterval::intervalMerge(UseInterval::Vector &a, - const UseInterval::Vector &b) { - const auto *iterB = b.begin(); - const auto *endB = b.end(); - // Iterate over UseInterval::Vector a and b. - for (auto *iterA = a.begin(); iterA != a.end() && iterB != endB;) { - // Let A be the UseInterval of iterA and B the UseInterval of iterB. - // Check if A is before B. - if (*iterA < *iterB) { - // Check if A and B can be merged if they are contiguous. If the merge - // result contains the next elements of A, we can erase them. - if (iterA->isContiguous(*iterB)) { - mergeAndEraseContiguousIntervals(a, iterA, *iterB); - ++iterB; - } - ++iterA; - // Check if B is before A. - } else if (*iterA > *iterB) { - // Check if A and B can be merged if they are contiguous, else add B - // to the Vector of A. - if (iterB->isContiguous(*iterA)) - iterA->mergeWith(*iterB); - else - iterA = a.insert(iterA, *iterB); - ++iterB; - // The UseIntervals interfere and must be merged. - } else { - mergeAndEraseContiguousIntervals(a, iterA, *iterB); - ++iterB; - } - } - // If there are remaining UseIntervals in b, add them to a. - if (iterB != endB) a.insert(a.end(), iterB, endB); -} - -/// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals -/// of the UseInterval::Vector. -void UseInterval::mergeAndEraseContiguousIntervals( - UseInterval::Vector &interval, UseInterval *iter, - const UseInterval &toMerge) { - // Return if the iter points to the end. - if (iter == interval.end()) return; - - // Merge the UseIntervals. - iter->mergeWith(toMerge); - - // Find the next UseInterval from iter that is not contiguous with the merged - // iter. - UseInterval *next = std::next(iter); - while (next != interval.end() && iter->isContiguous(*next)) { - if (iter->end < next->end) iter->end = next->end; - ++next; - } - // Remove contiguous UseIntervals. - if (std::next(iter) != next) iter = interval.erase(std::next(iter), next); -} - -UserangeAnalysis::UserangeAnalysis( - Operation *op, const bufferization::BufferPlacementAllocs &allocs, - const BufferViewFlowAnalysis &aliases) - : liveness(op) { - // Walk over all operations and map them to an ID. - op->walk([&](Operation *operation) { - gatherMemoryEffects(operation); - operationIds.insert({operation, operationIds.size()}); - operations.push_back(operation); - }); - - // Compute the use range for every allocValue and its aliases. Merge them - // and compute an interval. Add all computed intervals to the useIntervalMap. - for (const bufferization::BufferPlacementAllocs::AllocEntry &entry : allocs) { - Value allocValue = std::get<0>(entry); - const Value::use_range &allocUses = allocValue.getUses(); - size_t dist = std::distance(allocUses.begin(), allocUses.end()); - OperationListT useList; - useList.reserve(dist); - for (auto &use : allocUses) useList.push_back(use.getOwner()); - computeUsePositions(allocValue); - - UserangeInfoBuilder builder(liveness, {allocValue}, useList); - OperationListT liveOperations = builder.computeUserange(); - - // Sort the operation list by ids. - std::sort(liveOperations.begin(), liveOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - - UseInterval::Vector allocInterval = - computeInterval(allocValue, liveOperations); - // Iterate over all aliases and add their useranges to the userange of the - // current value. Also add the useInterval of each alias to the - // useIntervalMap. - ValueSetT aliasSet = aliases.resolve(allocValue); - for (Value alias : aliasSet) { - if (alias == allocValue) continue; - if (!aliasUseranges.count(alias)) { - OperationListT aliasOperations; - // If the alias is a BlockArgument then the value is live with the first - // operation inside that block. Otherwise the liveness analysis is - // sufficient for the use range. - if (alias.isa()) { - aliasOperations.push_back(&alias.getParentBlock()->front()); - for (auto &use : alias.getUses()) - aliasOperations.push_back(use.getOwner()); - // Compute the use range for the alias and sort the operations - // afterwards. - UserangeInfoBuilder aliasBuilder(liveness, {alias}, aliasOperations); - aliasOperations = aliasBuilder.computeUserange(); - std::sort(aliasOperations.begin(), aliasOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - } else { - aliasOperations = liveness.resolveLiveness(alias); - } - - aliasUseranges.insert({alias, aliasOperations}); - useIntervalMap.insert( - {alias, computeInterval(alias, aliasUseranges[alias])}); - computeUsePositions(alias); - } - UseInterval::intervalMerge(allocInterval, useIntervalMap[alias]); - mergeUsePositions(usePositionMap[allocValue], usePositionMap[alias]); - } - aliasCache.insert(std::make_pair(allocValue, aliasSet)); - - // Map the current allocValue to the computed useInterval. - useIntervalMap.insert(std::make_pair(allocValue, allocInterval)); - } -} - -/// Computes the doubled Id for the given value inside the operation based on -/// the program sequence. If the value has only read effects, the returning ID -/// will be even, otherwise odd. -size_t UserangeAnalysis::computeId(Value v, Operation *op) const { - size_t doubledID = (operationIds.find(op)->second + 1) * 2 - 1; - auto mapIter = opReadWriteMap.find(op); - if (mapIter == opReadWriteMap.end()) return doubledID; - auto reads = mapIter->second.first; - auto writes = mapIter->second.second; - if (reads.contains(v) && !writes.contains(v)) return doubledID - 1; - return doubledID; -} - -/// Computes the UsePositions of the given Value, sorts and inserts them into -/// the usePositionMap. -void UserangeAnalysis::computeUsePositions(Value v) { - // Get the uses of v. - const Value::use_range &uses = v.getUses(); - - // Create a UsePositionList. - UsePositionList usePosList; - size_t dist = std::distance(uses.begin(), uses.end()); - usePosList.reserve(dist); - - // Add all ids and Operations to the UsePositionList. - for (auto &use : uses) { - Operation *useOwner = use.getOwner(); - usePosList.emplace_back(computeId(v, useOwner), useOwner); - } - - // Sort the UsePositions by ascending Ids. - std::sort(usePosList.begin(), usePosList.end(), - [](const UsePosition &a, const UsePosition &b) { - return a.first < b.first; - }); - - // Insert the UsePositionList into the usePositionMap. - usePositionMap.insert(std::make_pair(v, usePosList)); -} - -/// Merges listB into listA, sorts the result and removes all duplicates. -void UserangeAnalysis::mergeUsePositions(UsePositionList &listA, - const UsePositionList &listB) { - // Insert listB into listA. - listA.insert(listA.end(), listB.begin(), listB.end()); - - // Sort the resulting listA. - std::sort(listA.begin(), listA.end(), - [](const UsePosition &a, const UsePosition &b) { - return a.first < b.first; - }); - - // Remove duplicates. - listA.erase(std::unique(listA.begin(), listA.end()), listA.end()); -} - -/// Checks if the use intervals of the given values interfere. -bool UserangeAnalysis::rangesInterfere(Value itemA, Value itemB) const { - ValueSetT intersect = aliasCache.find(itemA)->second; - llvm::set_intersect(intersect, aliasCache.find(itemB)->second); - UseInterval::Vector tmpIntervalA = useIntervalMap.find(itemA)->second; - const UseInterval::Vector &intervalsB = useIntervalMap.find(itemB)->second; - - // If the two values share a common alias, then the alias does not count as an - // interference and should be removed. - if (!intersect.empty()) { - for (Value alias : intersect) { - const UseInterval::Vector &aliasInterval = - useIntervalMap.find(alias)->second; - UseInterval::intervalSubtract(tmpIntervalA, aliasInterval); - } - } - - // Iterate over both UseInterval::Vector and check if they interfere. - const auto *iterB = intervalsB.begin(); - const auto *endB = intervalsB.end(); - for (auto iterA = tmpIntervalA.begin(), endA = tmpIntervalA.end(); - iterA != endA && iterB != endB;) { - if (*iterA < *iterB) - ++iterA; - else if (*iterA > *iterB) - ++iterB; - else - return true; - } - return false; -} - -/// Merges the userange of itemB into the userange of itemA. -void UserangeAnalysis::unionRanges(Value itemA, Value itemB) { - UseInterval::intervalMerge(useIntervalMap[itemA], useIntervalMap[itemB]); -} - -/// Builds an UseInterval::Vector corresponding to the given OperationList. -UseInterval::Vector UserangeAnalysis::computeInterval( - Value value, const Liveness::OperationListT &operationList) { - assert(!operationList.empty() && "Operation list must not be empty"); - size_t start = computeId(value, *operationList.begin()); - size_t last = start; - UseInterval::Vector intervals; - // Iterate over all operations in the operationList. If the gap between the - // respective operationIds is greater 1 create a new interval. - for (auto opIter = ++operationList.begin(), e = operationList.end(); - opIter != e; ++opIter) { - size_t current = computeId(value, *opIter); - if (current - last > 2) { - intervals.emplace_back(start, last); - start = current; - } - last = current; - } - intervals.emplace_back(start, last); - return intervals; -} - -/// Checks each operand within the operation for its memory effects and -/// separates them into read and write. -void UserangeAnalysis::gatherMemoryEffects(Operation *op) { - if (OpTrait::hasElementwiseMappableTraits(op)) { - if (auto effectInterface = dyn_cast(op)) { - SmallPtrSet readEffectSet; - SmallPtrSet writeEffectSet; - SmallVector effects; - for (auto operand : op->getOperands()) { - effects.clear(); - effectInterface.getEffectsOnValue(operand, effects); - for (auto effect : effects) { - if (isa(effect.getEffect())) - writeEffectSet.insert(operand); - else if (isa(effect.getEffect())) - readEffectSet.insert(operand); - } - } - opReadWriteMap.insert( - {op, std::make_pair(readEffectSet, writeEffectSet)}); - } - } -} - -/// Computes the doubled Id back to the OperationId. -size_t UserangeAnalysis::unwrapId(size_t id) const { return id / 2; } - -void UserangeAnalysis::dump(raw_ostream &os) { - os << "// ---- UserangeAnalysis -----\n"; - llvm::SmallVector values; - values.reserve(useIntervalMap.size()); - for (auto const &item : useIntervalMap) { - values.push_back(item.first); - } - std::sort(values.begin(), values.end(), [&](Value left, Value right) { - if (left.getDefiningOp()) { - if (right.getDefiningOp()) - return operationIds[left.getDefiningOp()] < - operationIds[right.getDefiningOp()]; - return true; - } - if (right.getDefiningOp()) return false; - return operationIds[&left.getParentBlock()->front()] < - operationIds[&right.getParentBlock()->front()]; - }); - for (auto value : values) { - os << "Value: " << value << (value.getDefiningOp() ? "\n" : ""); - auto *rangeIt = useIntervalMap[value].begin(); - os << "Userange: {(" << rangeIt->start << ", " << rangeIt->end << ")"; - rangeIt++; - for (auto *e = useIntervalMap[value].end(); rangeIt != e; ++rangeIt) { - os << ", (" << rangeIt->start << ", " << rangeIt->end << ")"; - } - os << "}\n"; - } - os << "// ---------------------------\n"; -} diff --git a/xla/mlir_hlo/analysis/userange_analysis.h b/xla/mlir_hlo/analysis/userange_analysis.h deleted file mode 100644 index 48216cfd395a4..0000000000000 --- a/xla/mlir_hlo/analysis/userange_analysis.h +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H -#define MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H - -#include -#include - -#include "mlir/Analysis/Liveness.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" - -namespace mlir { - -/// Representation of an inclusive Interval for the Userange. -struct UseInterval { - using Vector = SmallVector; - - public: - /// UseInterval Constructor. - UseInterval(); - /// Empty UseInterval Constructor. - UseInterval(size_t start, size_t end) : start(start), end(end) {} - - /// Checks if the given UseInterval overlaps with this UseInterval. - bool isOverlapping(const UseInterval &other) const { - return start <= other.end && end >= other.start; - } - - /// Checks if the given UseInterval is contiguous with this UseInterval in - /// terms of doubled Ids. - /// For example: (0, 2) and (4, 6) are contiguous where (0, 2) and (5, 6) are - /// not. - bool isContiguous(const UseInterval &other) const { - return start <= other.end + 2 && end + 2 >= other.start; - } - - /// Checks if the given position is inside this UseInterval. - bool contains(size_t position) const { - return start <= position && end >= position; - } - - /// Merges this UseInterval with the given UseInterval by updating start and - /// end. - bool mergeWith(const UseInterval &other) { - if (!isContiguous(other)) return false; - start = std::min(start, other.start); - end = std::max(end, other.end); - return true; - } - - /// Performs an interval subtraction => A = A - B. - static void intervalSubtract(Vector &a, const Vector &b); - - /// Performs an interval intersection => A = A ^ B. - static void intervalIntersect(Vector &a, const Vector &b); - - /// Performs an interval merge => A = A u B. - /// Note: All overlapping and contiguous UseIntervals are merged. - static void intervalMerge(Vector &a, const Vector &b); - - /// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals - /// of the UseInterval::Vector. - static void mergeAndEraseContiguousIntervals(Vector &interval, - UseInterval *iter, - const UseInterval &toMerge); - - bool operator<(const UseInterval &other) const { return end < other.start; } - - bool operator>(const UseInterval &other) const { return start > other.end; } - - bool operator==(const UseInterval &other) const { - return start == other.start && end == other.end; - } - - /// The start of this UseInterval. - size_t start; - - /// The end of this UseInterval. - size_t end; -}; - -/// Represents an analysis for computing the useranges of all alloc values -/// inside a given function operation. The analysis uses liveness information to -/// compute intervals starting at the first and ending with the last use of -/// every alloc value. -class UserangeAnalysis { - public: - using UsePosition = std::pair; - using UsePositionList = std::vector; - - UserangeAnalysis(Operation *op, - const bufferization::BufferPlacementAllocs &allocs, - const BufferViewFlowAnalysis &aliases); - - /// Returns the index of the first operation that uses the given value or an - /// empty Optional if the value has no uses. - std::optional getFirstUseIndex(Value value) const { - auto &intervals = useIntervalMap.find(value)->second; - if (intervals.empty()) return std::nullopt; - return intervals.begin()->start; - } - - /// Returns the UseInterval::Vector of the given value. - std::optional getUserangeInterval( - Value value) const { - auto intervals = useIntervalMap.find(value); - if (intervals == useIntervalMap.end()) return std::nullopt; - return &intervals->second; - } - - /// Returns an UsePositionList* of the given value or an empty Optional - /// if the value has no uses. - std::optional getUserangePositions( - Value value) const { - auto usePosition = usePositionMap.find(value); - if (usePosition == usePositionMap.end() || usePosition->second.empty()) - return std::nullopt; - return &usePosition->second; - } - - /// Returns the operation associated with a given Id. - Operation *getOperation(size_t id) const { return operations[unwrapId(id)]; }; - - /// Computes the doubled Id for the given value inside the operation based on - /// the program sequence. If the value has only read effects, the returning ID - /// will be even, otherwise odd. - size_t computeId(Value v, Operation *op) const; - - /// Checks if the use intervals of the given values interfere. - bool rangesInterfere(Value itemA, Value itemB) const; - - /// Merges the userange of itemB into the userange of itemA. - void unionRanges(Value itemA, Value itemB); - - /// Merges listB into listA, sorts the result and removes all duplicates. - static void mergeUsePositions(UsePositionList &listA, - const UsePositionList &listB); - - /// Dumps the liveness information to the given stream. - void dump(raw_ostream &os); - - private: - using ValueSetT = BufferViewFlowAnalysis::ValueSetT; - using OperationListT = Liveness::OperationListT; - - /// Builds an UseInterval::Vector corresponding to the given OperationList. - UseInterval::Vector computeInterval( - Value value, const Liveness::OperationListT &operationList); - - /// Computes the UsePositions of the given Value, sorts and inserts them into - /// the usePositionMap. - void computeUsePositions(Value v); - - /// Checks each operand within the operation for its memory effects and - /// separates them into read and write. - void gatherMemoryEffects(Operation *op); - - /// Computes the doubled Id back to the OperationId. - size_t unwrapId(size_t id) const; - - /// Maps each Operation to a unique ID according to the program sequence. - DenseMap operationIds; - - /// Stores all operations according to the program sequence. - std::vector operations; - - /// Maps a value to its UseInterval::Vector. - DenseMap useIntervalMap; - - /// Maps an Operation to a pair of read and write Operands. - DenseMap, SmallPtrSet>> - opReadWriteMap; - - /// Maps aliasValues to their use ranges. This is necessary to prevent - /// recomputations of the use range intervals of the aliases. - DenseMap aliasUseranges; - - /// Maps a Value to a UsePostionList which contains all uses of the Value and - /// their userange position. - DenseMap usePositionMap; - - /// Cache the alias lists for all values to avoid recomputation. - BufferViewFlowAnalysis::ValueMapT aliasCache; - - /// The current liveness info. - Liveness liveness; -}; - -} // namespace mlir - -#endif // MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H diff --git a/xla/mlir_hlo/bindings/c/Attributes.cc b/xla/mlir_hlo/bindings/c/Attributes.cc index 7e5b16faea39f..d5183742094a3 100644 --- a/xla/mlir_hlo/bindings/c/Attributes.cc +++ b/xla/mlir_hlo/bindings/c/Attributes.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -654,3 +654,29 @@ intptr_t mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr) { int64_t mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos) { return unwrap(attr).cast().getBounds()[pos]; } + +// +// SparsityDescriptor +// + +MlirAttribute mlirMhloSparsityDescriptorGet(MlirContext ctx, int64_t dimension, + int64_t n, int64_t m) { + return wrap( + mlir::mhlo::SparsityDescriptorAttr::get(unwrap(ctx), dimension, n, m)); +} + +bool mlirMhloAttributeIsASparsityDescriptor(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +int64_t mlirMhloSparsityDescriptorGetDimension(MlirAttribute attr) { + return unwrap(attr).cast().getDimension(); +} + +int64_t mlirMhloSparsityDescriptorGetN(MlirAttribute attr) { + return unwrap(attr).cast().getN(); +} + +int64_t mlirMhloSparsityDescriptorGetM(MlirAttribute attr) { + return unwrap(attr).cast().getM(); +} diff --git a/xla/mlir_hlo/bindings/c/Attributes.h b/xla/mlir_hlo/bindings/c/Attributes.h index 624440c2ce8f5..1eabbcd6ecde8 100644 --- a/xla/mlir_hlo/bindings/c/Attributes.h +++ b/xla/mlir_hlo/bindings/c/Attributes.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -381,6 +381,23 @@ mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr); MLIR_CAPI_EXPORTED int64_t mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos); +// +// SparsityDescriptor +// +// Creates a SparseDescriptor attribute with the given sparsity configurations. +MLIR_CAPI_EXPORTED MlirAttribute mlirMhloSparsityDescriptorGet( + MlirContext ctx, int64_t dimension, int64_t n, int64_t m); + +// Returns true if the given attribute is a SparsityDescriptor attribute. +MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsASparsityDescriptor( + MlirAttribute attr); + +// Returns the dimension and N:M sparsity configurations. +MLIR_CAPI_EXPORTED int64_t +mlirMhloSparsityDescriptorGetDimension(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t mlirMhloSparsityDescriptorGetN(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t mlirMhloSparsityDescriptorGetM(MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/xla/mlir_hlo/bindings/c/Dialects.cc b/xla/mlir_hlo/bindings/c/Dialects.cc index edb68eb11a327..667cb065f005a 100644 --- a/xla/mlir_hlo/bindings/c/Dialects.cc +++ b/xla/mlir_hlo/bindings/c/Dialects.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/c/Dialects.h b/xla/mlir_hlo/bindings/c/Dialects.h index ca18c4980e412..a21cc2059144c 100644 --- a/xla/mlir_hlo/bindings/c/Dialects.h +++ b/xla/mlir_hlo/bindings/c/Dialects.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/c/Passes.cc b/xla/mlir_hlo/bindings/c/Passes.cc index 0a47ced1836ce..f5508035fc700 100644 --- a/xla/mlir_hlo/bindings/c/Passes.cc +++ b/xla/mlir_hlo/bindings/c/Passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/c/Passes.h b/xla/mlir_hlo/bindings/c/Passes.h index a2cfb784575a9..85cf8a9f33fd4 100644 --- a/xla/mlir_hlo/bindings/c/Passes.h +++ b/xla/mlir_hlo/bindings/c/Passes.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/c/Types.cc b/xla/mlir_hlo/bindings/c/Types.cc index b4669eccb8ee1..0be0e34c02069 100644 --- a/xla/mlir_hlo/bindings/c/Types.cc +++ b/xla/mlir_hlo/bindings/c/Types.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/c/Types.h b/xla/mlir_hlo/bindings/c/Types.h index 6869997aa0379..bd0d825ee82ce 100644 --- a/xla/mlir_hlo/bindings/c/Types.h +++ b/xla/mlir_hlo/bindings/c/Types.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/bindings/python/MlirHloModule.cc b/xla/mlir_hlo/bindings/python/MlirHloModule.cc index 1f96eb75a7ce2..18f87bbcf662d 100644 --- a/xla/mlir_hlo/bindings/python/MlirHloModule.cc +++ b/xla/mlir_hlo/bindings/python/MlirHloModule.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -508,4 +508,29 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloTypeExtensionsGetBoundsSize, mlirMhloTypeExtensionsGetBoundsElem); }); + + mlir::python::adaptors::mlir_attribute_subclass( + m, "SparsityDescriptor", mlirMhloAttributeIsASparsityDescriptor) + .def_classmethod( + "get", + [](py::object cls, const int64_t dimension, const int64_t n, + const int64_t m, MlirContext ctx) { + return cls(mlirMhloSparsityDescriptorGet(ctx, dimension, n, m)); + }, + py::arg("cls"), py::arg("dimension"), py::arg("n"), py::arg("m"), + py::arg("context") = py::none(), + "Creates a SparseDescriptor attribute with the given sparsity " + "configurations.") + .def_property_readonly( + "dimension", + [](MlirAttribute self) { + return mlirMhloSparsityDescriptorGetDimension(self); + }) + .def_property_readonly("n", + [](MlirAttribute self) { + return mlirMhloSparsityDescriptorGetN(self); + }) + .def_property_readonly("m", [](MlirAttribute self) { + return mlirMhloSparsityDescriptorGetM(self); + }); } diff --git a/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td b/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td index d87efee87b2ab..33791f3e3eb6e 100644 --- a/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td +++ b/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py b/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py index b25ec4d5d9401..aa45ea20ea66e 100644 --- a/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py +++ b/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/build_tools/build_mlir.sh b/xla/mlir_hlo/build_tools/build_mlir.sh index 7af3164775f7d..ff95012fff1ab 100755 --- a/xla/mlir_hlo/build_tools/build_mlir.sh +++ b/xla/mlir_hlo/build_tools/build_mlir.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/xla/mlir_hlo/deallocation/CMakeLists.txt b/xla/mlir_hlo/deallocation/CMakeLists.txt index d758e74bb38e8..c4ec5db137cd1 100644 --- a/xla/mlir_hlo/deallocation/CMakeLists.txt +++ b/xla/mlir_hlo/deallocation/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # limitations under the License. add_subdirectory(transforms) -add_subdirectory(utils) \ No newline at end of file +add_subdirectory(utils) diff --git a/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt b/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt index efe5e92b44ac8..3b014d858b270 100644 --- a/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/deallocation/transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 7f95cdbbabc51..951dc240e9b53 100644 --- a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/transforms/passes.h b/xla/mlir_hlo/deallocation/transforms/passes.h index 4ba90fed2eae9..b0ead40c4f72f 100644 --- a/xla/mlir_hlo/deallocation/transforms/passes.h +++ b/xla/mlir_hlo/deallocation/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/transforms/passes.td b/xla/mlir_hlo/deallocation/transforms/passes.td index 34dc1f56be6c4..65bcbc5d03cd2 100644 --- a/xla/mlir_hlo/deallocation/transforms/passes.td +++ b/xla/mlir_hlo/deallocation/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/utils/CMakeLists.txt b/xla/mlir_hlo/deallocation/utils/CMakeLists.txt index 9603fef1a7023..9e10fcf5ba960 100644 --- a/xla/mlir_hlo/deallocation/utils/CMakeLists.txt +++ b/xla/mlir_hlo/deallocation/utils/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/utils/util.cc b/xla/mlir_hlo/deallocation/utils/util.cc index 4226425e752ea..5c383b357e7f6 100644 --- a/xla/mlir_hlo/deallocation/utils/util.cc +++ b/xla/mlir_hlo/deallocation/utils/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/deallocation/utils/util.h b/xla/mlir_hlo/deallocation/utils/util.h index 14ff8ec0798a8..ce6be44f99d96 100644 --- a/xla/mlir_hlo/deallocation/utils/util.h +++ b/xla/mlir_hlo/deallocation/utils/util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/CMakeLists.txt b/xla/mlir_hlo/lhlo/CMakeLists.txt index e138afa587f38..649a7d47b96d5 100644 --- a/xla/mlir_hlo/lhlo/CMakeLists.txt +++ b/xla/mlir_hlo/lhlo/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/CMakeLists.txt b/xla/mlir_hlo/lhlo/IR/CMakeLists.txt index 1cffce7c86fe6..51516427777c8 100644 --- a/xla/mlir_hlo/lhlo/IR/CMakeLists.txt +++ b/xla/mlir_hlo/lhlo/IR/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td b/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td index 7cddf4e08b1d1..4e95122db70c2 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td +++ b/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc b/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc index 30f6a3f1fe1ad..952e3c94751d3 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops.h b/xla/mlir_hlo/lhlo/IR/lhlo_ops.h index f5a1f7c013cfa..7adcacee3060e 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops.h +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops.td b/xla/mlir_hlo/lhlo/IR/lhlo_ops.td index e406b7006573e..c793ed6b1002d 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops.td +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -86,6 +86,22 @@ def LHLO_CommandBufferOp: LHLO_Op<"command_buffer", []> { let arguments = (ins); } +def LHLO_AsyncStartOp: LHLO_Op<"async_start", []> { + let summary = "Async start operator"; + let description = [{ + A dummy operation that represents the submission of a generic async start. + }]; + let arguments = (ins); +} + +def LHLO_AsyncDoneOp: LHLO_Op<"async_done", []> { + let summary = "Async done operator"; + let description = [{ + A dummy operation that represents the submission of a generic async done. + }]; + let arguments = (ins); +} + //===----------------------------------------------------------------------===// // LMHLO unary elementwise op definitions. //===----------------------------------------------------------------------===// @@ -380,7 +396,7 @@ class LHLO_BinaryElementwiseOp:$lhs, Arg:$rhs, Arg:$out, - OptionalAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); } @@ -421,7 +437,7 @@ def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]> { Arg:$lhs, Arg:$rhs, Arg:$output, - OptionalAttr:$broadcast_dimensions + OptionalAttr:$broadcast_dimensions ); } @@ -681,7 +697,7 @@ def LHLO_CompareOp: LHLO_Op<"compare", []> { Arg:$lhs, Arg:$rhs, Arg:$out, - OptionalAttr:$broadcast_dimensions, + OptionalAttr:$broadcast_dimensions, MHLO_ComparisonDirectionAttr:$comparison_direction, OptionalAttr:$compare_type ); @@ -845,7 +861,7 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", let arguments = (ins Arg:$operand, Arg:$output, - BroadcastDimAttr:$broadcast_dimensions + I64ElementsAttr:$broadcast_dimensions ); } @@ -1485,7 +1501,7 @@ def LHLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim", Arg:$operand, Arg:$output_dimensions, Arg:$output, - BroadcastDimAttr:$broadcast_dimensions + I64ElementsAttr:$broadcast_dimensions ); } diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td b/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td index 0ba78dfd69853..4760c4c09eec3 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h b/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h index 593249a71eb06..38371e3595fc9 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td b/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td index 44a3650fa6987..ca8a42ca33ba4 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td +++ b/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc index 73bd3450c1607..bb75b36e64834 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc +++ b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h index 0a584db58c402..05f6c0c3ed298 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h +++ b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td index 38156ef129083..efc0ee0c81c7d 100644 --- a/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td +++ b/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt b/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt index 6a85513005eeb..30aae289f923d 100644 --- a/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ include_directories(BEFORE add_mlir_library(LmhloPasses legalize_to_tensor_op/legalize_to_tensor_op.cc - lhlo_elemental_utils.cc lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc diff --git a/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc b/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc index 1f843908f2b8f..6513fb7370295 100644 --- a/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc +++ b/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc b/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc deleted file mode 100644 index 85183e5b020d6..0000000000000 --- a/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc +++ /dev/null @@ -1,275 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file provides basic utilities for the elemental lowering of -// each node - -#include "lhlo/transforms/lhlo_elemental_utils.h" - -#include "lhlo/IR/lhlo_ops.h" -#include "lhlo/transforms/map_lmhlo_to_scalar_op.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "utils/codegen_utils.h" - -using mlir::memref::DimOp; -using mlir::memref::LoadOp; -using mlir::memref::StoreOp; - -namespace mlir { -namespace lmhlo { - -Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref, - ValueRange indices, - OpBuilder::InsertPoint insertPoint) { - // Check if there are any cached value that can be reused, - // within the current Block. Alternatively we can do this for - // all the Blocks that dominant this Block, but that will be - // complicated anyway. - std::vector storeOps; - insertPoint.getBlock()->walk( - insertPoint.getBlock()->begin(), insertPoint.getPoint(), - [&](StoreOp storeOp) { - if (storeOp.getOperation()->getBlock() != insertPoint.getBlock()) - return; - if ((storeOp.getMemRef() == memref) && - (storeOp.getIndices() == indices)) - storeOps.emplace_back(storeOp); - }); - if (!storeOps.empty()) return storeOps[0].getOperand(0); - int rank = memref.getType().dyn_cast().getRank(); - return rank > 0 ? b->create(loc, memref, indices) - : b->create(loc, memref); -} - -DenseSet noLoaderUser(SmallVectorImpl& ops) { - SmallVector worklist; - DenseSet hasLoaderOps; - for (Operation* op : ops) { - Value memref = cast(op).getResultBuffer(); - if (memref == nullptr) continue; - for (auto* user : memref.getUsers()) { - if (isa(user)) { - worklist.push_back(op); - hasLoaderOps.insert(op); - } - } - } - - while (!worklist.empty()) { - Operation* op = worklist.pop_back_val(); - int numOperands = op->getNumOperands(); - for (int i = 0; i < numOperands - 1; ++i) { - Value memref = op->getOperand(i); - for (Operation* user : memref.getUsers()) { - if ((!isa(user)) || hasLoaderOps.count(user)) continue; - if (cast(user).getResultBuffer() == memref) { - worklist.push_back(user); - hasLoaderOps.insert(user); - } - } - } - } - - DenseSet noLoaderOps; - for (Operation* op : ops) - if (!hasLoaderOps.count(op)) noLoaderOps.insert(op); - return noLoaderOps; -} - -void cleanUnusedLhloOps(Block* parent) { - SmallVector lhloOps; - for (Operation& op : parent->getOperations()) { - if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") && - (!isa(op))) - lhloOps.push_back(&op); - } - for (auto* lhloOp : noLoaderUser(lhloOps)) lhloOp->erase(); -} - -template -Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op, - ValueRange outputIndex, bool checkCache); - -template <> -Value elementalLower(OpBuilder* b, Location loc, - lmhlo::RealDynamicSliceOp op, - ValueRange outputIndex, - bool checkCache) { - Value startIndicesMemref = op->getOperand(1); - Value stridesMemref = op->getOperand(3); - int rank = outputIndex.size(); - SmallVector inputIndex; - for (int dim = 0; dim < rank; ++dim) { - SmallVector dimIndex; - dimIndex.push_back(b->create( - loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim))); - auto startIndexLoad = - b->create(loc, startIndicesMemref, ValueRange{dimIndex}); - auto startIndex = - b->create(loc, b->getIndexType(), startIndexLoad); - auto strideLoad = - b->create(loc, stridesMemref, ValueRange{dimIndex}); - auto stride = - b->create(loc, b->getIndexType(), strideLoad); - // input_dim = out_dim * stride + start_index - auto inputDim = b->create( - loc, b->create(loc, outputIndex[dim], stride), - startIndex); - inputIndex.push_back(inputDim); - } - - Value operandMemref = *(op->getOperands().begin()); - - if (!checkCache) return b->create(loc, operandMemref, inputIndex); - return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex, - b->saveInsertionPoint()); -} - -namespace { - -template -Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc, - T broadcastInDim, - ValueRange outputIndex, - bool checkCache) { - auto broadcastDimensions = - broadcastInDim.getBroadcastDimensions().template getValues(); - int outRank = outputIndex.size(); - Value operandMemref = broadcastInDim->getOperand(0); - SmallVector inputIndex; - for (int64_t dim = 0; dim < outRank; ++dim) { - auto it = - std::find(broadcastDimensions.begin(), broadcastDimensions.end(), dim); - - bool isBroadcastDim = (it != broadcastDimensions.end()); - if (isBroadcastDim) { - int inputDim = std::distance(broadcastDimensions.begin(), it); - int64_t staticDimSize = - operandMemref.getType().cast().getShape()[inputDim]; - if (staticDimSize == 1) { - // we know this dim is to be broadcasted at compile time - auto zero = b->create( - loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0)); - inputIndex.push_back(zero); - } else if (staticDimSize == ShapedType::kDynamic) { - // we are not sure if this dim is to be broadcasted at compile time - auto dimSize = b->create(loc, operandMemref, inputDim); - auto one = b->create( - loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1)); - auto zero = b->create( - loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0)); - auto dimSizeIs1 = b->create( - loc, arith::CmpIPredicate::eq, dimSize, one); - inputIndex.push_back(b->create( - loc, dimSizeIs1, zero, outputIndex[dim])); - } else { - // we know this dim is not to be broadcasted at compile time - inputIndex.push_back(outputIndex[dim]); - } - } - } - - if (!checkCache) { - int rank = operandMemref.getType().dyn_cast().getRank(); - return (rank > 0) ? b->create(loc, operandMemref, inputIndex) - : b->create(loc, operandMemref, ValueRange()); - } - return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex, - b->saveInsertionPoint()); -} - -} // namespace - -template <> -Value elementalLower( - OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op, - ValueRange outputIndex, bool checkCache) { - return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex, - checkCache); -} - -template <> -Value elementalLower(OpBuilder* b, Location loc, - lmhlo::BroadcastInDimOp op, - ValueRange outputIndex, - bool checkCache) { - return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex, - checkCache); -} - -scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var, - Value lb, Value ub, Value step, - ArrayRef initValues) { - auto forOp = b.create(loc, lb, ub, step, initValues); - b.setInsertionPointToStart(forOp.getBody()); - var = forOp.getInductionVar(); - return forOp; -} - -scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc, - SmallVectorImpl& vars, - ArrayRef lbs, - ArrayRef ubs, - ArrayRef steps, - ArrayRef initValues) { - auto parOp = b.create(loc, lbs, ubs, steps, initValues, - /*bodyBuilderFn=*/nullptr); - b.setInsertionPointToStart(parOp.getBody()); - vars.append(parOp.getInductionVars().begin(), parOp.getInductionVars().end()); - return parOp; -} - -// reinterpret_cast the input memref into 1D -memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b, - Location loc, - Value memref) { - auto memrefTy = memref.getType().cast(); - assert(memrefTy.getLayout().isIdentity()); - Value size = codegen_utils::emitNumElementsComputation(b, loc, memref); - Value stride = b.create( - loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1)); - Value zero = b.create( - loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0)); - auto memref1dType = - MemRefType::get({ShapedType::kDynamic}, memrefTy.getElementType(), - b.getMultiDimIdentityMap(1), memrefTy.getMemorySpace()); - return b.create( - loc, memref1dType, memref, zero, ValueRange{size}, ValueRange{stride}); -} - -void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref, - Value offset) { - Value memref1d = createMemRef1DReinterpretCast(b, loc, memref); - b.create(loc, res, memref1d, ValueRange{offset}); -} - -memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref, - Value offset) { - Value memref1d = createMemRef1DReinterpretCast(b, loc, memref); - return b.create(loc, memref1d, ValueRange{offset}); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h b/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h deleted file mode 100644 index ac906415a4a71..0000000000000 --- a/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H -#define MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H - -#include "mlir/IR/Builders.h" - -namespace mlir { -namespace func { -class FuncOp; -} // namespace func -class Value; -class Location; -class Operation; -class ValueRange; -class Region; -enum class AtomicRMWKind : uint64_t; - -namespace scf { -class ForOp; -class ParallelOp; -} // namespace scf - -namespace memref { -class LoadOp; -} // namespace memref - -namespace lmhlo { - -Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref, - ValueRange indices, - OpBuilder::InsertPoint insertPoint); - -DenseSet noLoaderUser(SmallVectorImpl& ops); -void cleanUnusedLhloOps(Block* parent); - -template -Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op, - ValueRange outputIndex, bool checkCache = false); - -scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var, - Value lb, Value ub, Value step, - ArrayRef initValues = {}); - -scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc, - SmallVectorImpl& vars, - ArrayRef lbs, - ArrayRef ubs, - ArrayRef steps, - ArrayRef initValues); - -void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref, - Value offset); - -memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref, - Value offset); - -} // namespace lmhlo -} // namespace mlir - -#endif // MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H diff --git a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc index dee3c6afe603e..45e44cc116182 100644 --- a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc +++ b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc index 49c7f0fbb528d..7572655977e21 100644 --- a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc +++ b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc index 5433d93fe8127..f3c6b3940aced 100644 --- a/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc +++ b/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,9 +75,8 @@ Value applySingleResultLhloCode(Location loc, ValueRange operands, // into a reduction operator of scf.reduce by doing buffer allocation for // scalar arguments and the result of `scf.reduce` to make it compatible with // LHLO ops. -void convertToReductionOperator(Location loc, scf::ReduceOp reduceOp, +void convertToReductionOperator(Location loc, Block& loopReduceOpBody, Block* lhloBlock, OpBuilder* b) { - Block& loopReduceOpBody = reduceOp.getReductionOperator().front(); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(&loopReduceOpBody); b->create( @@ -211,7 +210,8 @@ class ReduceOpConverter : public OpConversionPattern { scf::ReduceOp scfReduceOp = createReduceOpInNestedParallelLoops(reduceOp, &rewriter); - convertToReductionOperator(reduceOp.getLoc(), scfReduceOp, + convertToReductionOperator(reduceOp.getLoc(), + scfReduceOp.getReductions().front().front(), &reduceOp.getBody().front(), &rewriter); rewriter.replaceOp(reduceOp, std::nullopt); return success(); @@ -387,7 +387,8 @@ class ReduceWindowOpConverter scf::ReduceOp reduceOp = createReduceOpInNestedParallelLoops( reduceWindowOp, outputLoop, windowLoop, &rewriter); - convertToReductionOperator(reduceWindowOp.getLoc(), reduceOp, + convertToReductionOperator(reduceWindowOp.getLoc(), + reduceOp.getReductions().front().front(), &reduceWindowOp.getBody().front(), &rewriter); rewriter.replaceOp(reduceWindowOp, std::nullopt); return success(); @@ -452,12 +453,14 @@ class ReduceWindowOpConverter loc, inputType.getElementType(), mappedIvs.inBounds, /*withElseRegion=*/true); - OpBuilder thenBuilder = elemOrInit.getThenBodyBuilder(rewriter); + OpBuilder thenBuilder = + elemOrInit.getThenBodyBuilder(rewriter->getListener()); Value elem = thenBuilder.create(loc, input, mappedIvs.ivs); thenBuilder.create(loc, elem); - OpBuilder elseBuilder = elemOrInit.getElseBodyBuilder(rewriter); + OpBuilder elseBuilder = + elemOrInit.getElseBodyBuilder(rewriter->getListener()); elseBuilder.create(loc, *windowLoop.getInitVals().begin()); return rewriter->create(loc, diff --git a/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td b/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td index 8cddbbf7dc944..189e9b7761ea1 100644 --- a/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td +++ b/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h b/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h index 3a286d9878230..b12cfd580f6cf 100644 --- a/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h +++ b/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h b/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h index 68c92f386960a..269cdccf3c70b 100644 --- a/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h +++ b/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h b/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h index fb4a2e86672e8..0a9dec9f7179b 100644 --- a/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/transforms/passes.h b/xla/mlir_hlo/lhlo/transforms/passes.h index 8225dfa238ccf..41a8247ea3393 100644 --- a/xla/mlir_hlo/lhlo/transforms/passes.h +++ b/xla/mlir_hlo/lhlo/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo/utils/lhlo_utils.h b/xla/mlir_hlo/lhlo/utils/lhlo_utils.h index 007e9ddc7ea19..d85288237d78f 100644 --- a/xla/mlir_hlo/lhlo/utils/lhlo_utils.h +++ b/xla/mlir_hlo/lhlo/utils/lhlo_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/lhlo_gpu/CMakeLists.txt b/xla/mlir_hlo/lhlo_gpu/CMakeLists.txt deleted file mode 100644 index b16dd4a6fd48e..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -add_subdirectory(IR) diff --git a/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt b/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt deleted file mode 100644 index 81912905556cb..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) -mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) -mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) - -set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_enums.td) -mlir_tablegen(lhlo_gpu_ops_enums.h.inc -gen-enum-decls) -mlir_tablegen(lhlo_gpu_ops_enums.cc.inc -gen-enum-defs) -mlir_tablegen(lhlo_gpu_ops_attrdefs.h.inc -gen-attrdef-decls) -mlir_tablegen(lhlo_gpu_ops_attrdefs.cc.inc -gen-attrdef-defs) -mlir_tablegen(lhlo_gpu_ops_dialect.h.inc -gen-dialect-decls) -mlir_tablegen(lhlo_gpu_ops_dialect.cc.inc -gen-dialect-defs) - -add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) -add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) - - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_dialect_library(LmhloGPUDialect - lhlo_gpu_ops.cc - - DEPENDS - MLIRlhlo_gpu_opsIncGen - - LINK_LIBS PUBLIC - MhloDialect - MLIRIR - HloOpsCommon -) diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc deleted file mode 100644 index 798b27cabd819..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the LMHLO GPU dialect. - -#include "lhlo_gpu/IR/lhlo_gpu_ops.h" - -#include -#include -#include - -#include - -#include "lhlo/utils/lhlo_utils.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/FormatVariadic.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/IR/hlo_ops_common.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" - -namespace mlir { -namespace lmhlo_gpu { -static FailureOr parseBool(AsmParser &parser) { - if (succeeded(parser.parseOptionalKeyword("true"))) return true; - if (succeeded(parser.parseOptionalKeyword("false"))) return false; - return failure(); -} - -static FailureOr> parseI64Array(AsmParser &parser) { - SmallVector elements; - auto elementParser = [&]() { - int64_t element = 0; - if (failed(parser.parseInteger(element))) return failure(); - elements.push_back(element); - return success(); - }; - if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, - elementParser)) - return failure(); - return elements; -} -} // namespace lmhlo_gpu -} // namespace mlir - -// Include order below matters. -#include "lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc" -#include "lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc" -#define GET_ATTRDEF_CLASSES -#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" - -namespace mlir { -namespace lmhlo_gpu { - -using mhlo::TokenType; - -void LmhloGpuDialect::initialize() { - getContext()->loadDialect(); - addOperations< -#define GET_OP_LIST -#include "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" - >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" - >(); -} - -// TODO(jurahul): Add verification for operand shapes and ranks. - -using mlir::hlo::parseWindowAttributes; -using mlir::hlo::printWindowAttributes; - -//===----------------------------------------------------------------------===// -// AllReduceStartOp -//===----------------------------------------------------------------------===// - -mlir::LogicalResult AllReduceStartOp::verify() { - AllReduceStartOp op = *this; - return lmhlo::verifyAllReduce(op); -} - -//===----------------------------------------------------------------------===// -// AllToAllStartOp -//===----------------------------------------------------------------------===// - -mlir::LogicalResult AllToAllStartOp::verify() { - AllToAllStartOp op = *this; - return mlir::hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), - /*allGroupsMustHaveSameSize=*/true, - /*useGlobalDeviceIds=*/false, - /*expectedGroupSize=*/std::nullopt); -} - -//===----------------------------------------------------------------------===// -// CollectivePermuteStartOp -//===----------------------------------------------------------------------===// - -mlir::LogicalResult CollectivePermuteStartOp::verify() { - CollectivePermuteStartOp op = *this; - return mlir::hlo::verifyCollectivePermuteSourceTargetPairs( - op, op.getSourceTargetPairs()); -} - -//===----------------------------------------------------------------------===// -// AllGatherStartOp -//===----------------------------------------------------------------------===// - -mlir::LogicalResult AllGatherStartOp::verify() { - AllGatherStartOp op = *this; - return mlir::hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), - /*allGroupsMustHaveSameSize=*/true, - op.getUseGlobalDeviceIds(), - /*expectedGroupSize=*/std::nullopt); -} - -//===----------------------------------------------------------------------===// -// ReduceScatterStartOp -//===----------------------------------------------------------------------===// - -LogicalResult ReduceScatterStartOp::verify() { - ReduceScatterStartOp op = *this; - if (failed(hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), - /*allGroupsMustHaveSameSize=*/true, - op.getUseGlobalDeviceIds(), - /*expectedGroupSize=*/std::nullopt))) - return failure(); - if (failed(mlir::hlo::verifyReduceScatter( - op, /*operandTypes=*/op.getInputs().getTypes(), - /*resultTypes=*/op.getOutputs().getTypes(), - /*scatterDimension=*/op.getScatterDimension()))) - return failure(); - return success(); -} - -} // namespace lmhlo_gpu -} // namespace mlir - -#define GET_OP_CLASSES -#include "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h deleted file mode 100644 index 75d5f338693d4..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the LHLO dialect. - -#ifndef MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H -#define MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H - -#include "llvm/ADT/StringRef.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Types.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -namespace mlir { -class OpBuilder; -} // namespace mlir - -// Include order below matters. -#include "lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc" -#include "lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc" -#define GET_ATTRDEF_CLASSES -#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc" -#define GET_OP_CLASSES -#include "lhlo_gpu/IR/lhlo_gpu_ops.h.inc" - -#endif // MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td deleted file mode 100644 index b2e9b4dff8f45..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ /dev/null @@ -1,416 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the operation definition file for LHMLO level GPU operations. -// Because these are LMHLO level operations, they operate on memrefs. - -#ifndef LHLO_GPU_OPS -#define LHLO_GPU_OPS - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "lhlo/IR/lhlo_ops_base.td" -include "lhlo_gpu/IR/lhlo_gpu_ops_base.td" -include "lhlo_gpu/IR/lhlo_gpu_ops_enums.td" -include "stablehlo/dialect/Base.td" - -class LHLOGPU_Op traits = []> : - Op], traits)>; - -// Type for scratch buffers used by GPU library calls (memref) -def UntypedBuffer : MemRefRankOf<[I8], [1]>; - -// Cholesky info output buffer type. -def I32Buffer : MemRefOf<[I32]>; - -//===----------------------------------------------------------------------===// -// LMHLO ops representing convolution library functions. -//===----------------------------------------------------------------------===// - -class GpuConvolutionAttributes { - dag attributes = !con( - MHLO_ConvolutionAttributes.attributes, - (ins F64Attr:$result_scale), - extraAttribs, - (ins ConvolutionBackendConfigAttr:$backend_config)); -} - -// Provide a custom assembly format for all LHLO_GPU convolution operations. -class LHLOGPU_ConvBaseOp traits = []> : LHLOGPU_Op { - let assemblyFormat = [{ - `(`operands`)` - `dim_numbers` `=` custom($dimension_numbers) `,` - `window` `=` `{` custom($window_strides, $padding, - $lhs_dilation, $rhs_dilation, - $window_reversal) `}` - attr-dict `:` functional-type(operands, results) - }]; -} - -def LHLOGPU_ConvForwardOp : LHLOGPU_ConvBaseOp<"conv_forward"> { - let arguments = !con( - (ins - Arg:$input, - Arg:$filter, - Arg:$output, - Arg:$scratch), - GpuConvolutionAttributes<(ins)>.attributes); -} - -def LHLOGPU_ConvBackwardInputOp : LHLOGPU_ConvBaseOp<"conv_backwardinput"> { - let arguments = !con( - (ins - Arg:$d_output, - Arg:$filter, - Arg:$d_input, - Arg:$scratch), - GpuConvolutionAttributes<(ins)>.attributes); -} - -def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_ConvBaseOp<"conv_backwardfilter"> { - let arguments = !con( - (ins - Arg:$input, - Arg:$d_output, - Arg:$d_filter, - Arg:$scratch), - GpuConvolutionAttributes<(ins)>.attributes); -} - -// output = activation(result_scale * conv(input, filter) + bias) -def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> { - let arguments = !con( - (ins - Arg:$input, - Arg:$filter, - Arg:$bias, - Arg:$output, - Arg:$scratch), - GpuConvolutionAttributes<(ins - ActivationAttr:$activation_mode, - F64Attr:$leakyrelu_alpha)>.attributes); -} - -// output = activation(result_scale * conv(input, filter) + -// side_input * side_input_scale + -// bias) -def LHLOGPU_ConvForwardFusedSideInputOp : - LHLOGPU_ConvBaseOp<"conv_forward_fused_with_side_input"> { - let arguments = !con( - (ins - Arg:$input, - Arg:$filter, - Arg:$bias, - Arg:$side_input, - Arg:$output, - Arg:$scratch), - GpuConvolutionAttributes<(ins - ActivationAttr:$activation_mode, - F64Attr:$side_input_scale)>.attributes); -} - -// Reordering helpers for int8x32 cuDNN convolutions. -def LHLOGPU_CudnnConvReorderFilterOp : LHLOGPU_Op<"cudnn_conv_reorder_filter"> { - let arguments = (ins - Arg:$filter_input, - Arg:$filter_output, - I64ElementsAttr:$filter_dims); -} - -def LHLOGPU_CudnnConvReorderFilterAndBiasOp : - LHLOGPU_Op<"cudnn_conv_reorder_filter_and_bias"> { - let arguments = (ins - Arg:$filter_input, - Arg:$bias_input, - Arg:$filter_output, - Arg:$bias_output, - I64ElementsAttr:$filter_dims); -} - -def LHLOGPU_ConvForwardGraphOp : - LHLOGPU_ConvBaseOp<"conv_forward_graph", [AttrSizedOperandSegments]> { - let arguments = !con( - (ins - Arg:$input, - Arg:$filter, - Arg, "", [MemRead]>:$binary_operands, - Arg:$output, - Arg, "", [MemWrite]>:$aux_outputs, - Arg:$scratch), - GpuConvolutionAttributes<(ins - I32Attr:$n_aux_outputs, - StrAttr:$serialized_graph)>.attributes); -} - -//===----------------------------------------------------------------------===// -// LMHLO ops representing other library functions. -//===----------------------------------------------------------------------===// - -// c = alpha * (a @ b) + beta * c -def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { - let arguments = (ins - Arg:$a, - Arg:$b, - Arg:$c, - Arg, "", [MemRead, MemWrite]>:$workspace, - MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config, - F64Attr:$alpha_real, - F64Attr:$alpha_imag, - F64Attr:$beta, - OptionalAttr:$algorithm, - OptionalAttr:$grad_x, - OptionalAttr:$grad_y); -} - -def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandSegments]> { - let arguments = (ins - Arg:$a, - Arg:$b, - Arg:$c, - Arg:$d, - Arg, "", [MemRead]>:$bias, - Arg, "", [MemRead, MemWrite]>:$aux, - MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config, - F64Attr:$alpha_real, - F64Attr:$alpha_imag, - F64Attr:$beta, - CublasLtMatmulEpilogueAttr:$epilogue, - I64Attr:$algorithm, - OptionalAttr:$grad_x, - OptionalAttr:$grad_y); -} - -def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8", [AttrSizedOperandSegments]> { - let arguments = (ins - Arg:$a, - Arg:$b, - Arg:$c, - Arg:$a_scale, - Arg:$b_scale, - Arg:$c_scale, - Arg:$d_scale, - Arg:$d, - Arg, "", [MemRead]>:$bias, - Arg, "", [MemWrite]>:$d_amax, - MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config, - F64Attr:$alpha_real, - F64Attr:$alpha_imag, - F64Attr:$beta, - CublasLtMatmulEpilogueAttr:$epilogue, - I64Attr:$algorithm, - OptionalAttr:$grad_x, - OptionalAttr:$grad_y); -} - -def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { - let arguments = (ins - Arg:$input, - Arg:$output, - Arg:$scratch, - Arg:$info, - BoolAttr:$is_lower); -} - -// Base class for all async collective communication operations. -class LHLOGPU_AsyncCollectiveCommunicationOpBase traits = []> : - LHLOGPU_Op { - let results = (outs MHLO_Token:$token); - let hasVerifier = 1; -} - -// Base class for async all-reduce & all-gather. -class LHLOGPU_AsyncCollectiveCommunicationOp traits = []> : - LHLOGPU_AsyncCollectiveCommunicationOpBase { - dag arguments_base = (ins - Arg, "", [MemRead]>:$inputs, - Arg, "", [MemWrite]>:$outputs, - I64ElementsAttr:$replica_groups, - DefaultValuedOptionalAttr:$constrain_layout, - OptionalAttr:$channel_id, - DefaultValuedOptionalAttr:$use_global_device_ids, - BoolAttr:$is_sync, - BoolAttr:$no_parallel_custom_call - ); -} - -def LHLOGPU_AllReduceStartOp : - LHLOGPU_AsyncCollectiveCommunicationOp<"all_reduce_start", [SameOperandsElementType]> { - let summary = "AllReduceStart operator"; - let description = [{ - Performs an asynchronous custom reduction across replicas. - }]; - let arguments = arguments_base; - let regions = (region SizedRegion<1>:$computation); -} - -def LHLOGPU_AllReduceDoneOp: LHLOGPU_Op<"all_reduce_done"> { - let summary = "AllReduceDone operator"; - let arguments = (ins MHLO_Token:$token); -} - -def LHLOGPU_CollectivePermuteStartOp : - LHLOGPU_AsyncCollectiveCommunicationOpBase<"collective_permute_start"> { - let summary = "CollectivePermuteStart operator"; - let arguments = (ins - Arg:$operand, - Arg:$output, - I64ElementsAttr:$source_target_pairs, - OptionalAttr:$channel_id, - BoolAttr:$is_sync, - BoolAttr:$no_parallel_custom_call - ); -} - -def LHLOGPU_CollectivePermuteDoneOp: LHLOGPU_Op<"collective_permute_done"> { - let summary = "CollectivePermuteDone operator"; - let arguments = (ins MHLO_Token:$token); -} - -def LHLOGPU_AllGatherStartOp : - LHLOGPU_AsyncCollectiveCommunicationOp<"all_gather_start"> { - let summary = "AllGatherStart operator"; - let description = [{ - Performs asynchronous concatenation across replicas. - }]; - let arguments = !con( - arguments_base, - (ins I64Attr:$all_gather_dimension)); -} - -def LHLOGPU_AllGatherDoneOp: LHLOGPU_Op<"all_gather_done"> { - let summary = "AllGatherDone operator"; - let arguments = (ins MHLO_Token:$token); -} - -def LHLOGPU_ReduceScatterStartOp : - LHLOGPU_AsyncCollectiveCommunicationOp<"reduce_scatter_start", [SameOperandsElementType]> { - let summary = "ReduceScatter start operator"; - let description = [{ - Performs all_reduce followed by a scatter. - }]; - let arguments = !con( - arguments_base, - (ins I64Attr:$scatter_dimension)); - let regions = (region SizedRegion<1>:$computation); -} - -def LHLOGPU_ReduceScatterDoneOp: LHLOGPU_Op<"reduce_scatter_done"> { - let summary = "ReduceScatterDone operator"; - let arguments = (ins MHLO_Token:$token); -} - -def LHLOGPU_AllToAllStartOp : - LHLOGPU_AsyncCollectiveCommunicationOp<"all_to_all_start", [SameOperandsElementType]> { - let summary = "All2AllStart operator"; - let description = [{ - Send data from all cores to all cores. - }]; - let arguments = !con( - arguments_base, - (ins OptionalAttr:$split_dimension)); -} - -def LHLOGPU_AllToAllDoneOp: LHLOGPU_Op<"all_to_all_done"> { - let summary = "All2AllDone operator"; - let arguments = (ins MHLO_Token:$token); -} - -def LHLOGPU_CudnnNormOp : LHLOGPU_Op<"Norm", [AttrSizedOperandSegments]> { - let arguments = (ins - Arg:$input, - Arg:$scale, - Arg:$bias, - Arg:$output, - Arg, "", [MemWrite]>:$expectation, - Arg, "", [MemWrite]>:$norm_factor, - Arg:$scratch, - NormAlgorithmConfigAttr:$algorithm_config, - F64Attr:$epsilon, - I64ArrayAttr:$operand_layouts - ); -} - -def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> { - let arguments = (ins - Arg:$lhs_bmm1, - Arg:$rhs_bmm1, - Arg:$rhs_bmm2, - Arg, "", [MemRead]>:$mask, - Arg, "", [MemRead]>:$bias, - Arg:$output, - Arg:$scratch, - Arg, "", [MemWrite]>:$activation, - MHLO_DotDimensionNumbers:$bmm1_dot_dimension_numbers, - MHLO_DotDimensionNumbers:$bmm2_dot_dimension_numbers, - I64ArrayAttr:$intermediate_tensor_dimensions, - I64ArrayAttr:$intermediate_tensor_layout, - F64Attr:$fmha_scale, - FusedMhaDagSignatureAttr:$fused_mha_dag, - FusedMHAAlgorithmConfigAttr:$algorithm_config, - OptionalAttr:$dropout_rate, - OptionalAttr:$seed, - BoolAttr:$is_flash_attention, - BoolAttr:$is_causal_mask - ); -} - -def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSegments]> { - let arguments = (ins - Arg:$bmm1_grad_gemm1_rhs, - Arg:$bmm1_grad_gemm2_rhs, - Arg:$bmm2_grad_gemm2_rhs, - Arg:$bmm2_grad_gemm1_lhs, - Arg:$d_output, - Arg, "", [MemRead]>:$mask, - Arg, "", [MemRead]>:$bias, - Arg, "", [MemRead]>:$fwd_output, - Arg:$d_bmm1_lhs, - Arg:$d_bmm1_rhs, - Arg:$d_bmm2_rhs, - Arg, "", [MemWrite]>:$d_S, - Arg, "", [MemWrite]>:$softmax_sum, - Arg, "", [MemWrite]>:$d_Q_accum, - Arg:$scratch, - Arg, "", [MemWrite]>:$d_bias, - MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers, - MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers, - MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers, - MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers, - I64ArrayAttr:$intermediate_tensor_dimensions, - I64ArrayAttr:$intermediate_tensor_layout, - F64Attr:$fmha_scale, - FusedMhaBackwardDagSignatureAttr:$fused_mha_dag, - FusedMHAAlgorithmConfigAttr:$algorithm_config, - OptionalAttr:$dropout_rate, - OptionalAttr:$seed, - BoolAttr:$is_flash_attention, - BoolAttr:$is_causal_mask - ); -} - -def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> { - let arguments = (ins - Arg, "", [MemRead]>:$inputs, - Arg, "", [MemWrite]>:$output, - Arg:$scratch, - DefaultValuedOptionalAttr:$descending - ); -} - -#endif // LHLO_GPU_OPS diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td deleted file mode 100644 index 878dc2c17f2b7..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// We define the dialect here so that both structs and ops can refer to it. - -#ifndef LHLO_GPU_OPS_BASE -#define LHLO_GPU_OPS_BASE - -include "mlir/IR/OpBase.td" - -def LmhloGpuDialect : Dialect { - let name = "lmhlo_gpu"; - let cppNamespace = "::mlir::lmhlo_gpu"; - - let useDefaultAttributePrinterParser = 1; -} - -#endif // LHLO_GPU_OPS_BASE diff --git a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td deleted file mode 100644 index 7ce614e43b859..0000000000000 --- a/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef LHLO_GPU_OPS_ENUMS -#define LHLO_GPU_OPS_ENUMS - -include "mlir/IR/OpBase.td" -include "mlir/IR/EnumAttr.td" -include "mlir/IR/AttrTypeBase.td" - -include "lhlo_gpu/IR/lhlo_gpu_ops_base.td" - -def ActivationModeNone : I32EnumAttrCase<"None", 0>; -def ActivationModeSigmoid : I32EnumAttrCase<"Sigmoid", 1>; -def ActivationModeTanh : I32EnumAttrCase<"Tanh", 2>; -def ActivationModeRelu : I32EnumAttrCase<"Relu", 3>; -def ActivationModeRelu6 : I32EnumAttrCase<"Relu6", 4>; -def ActivationModeReluX : I32EnumAttrCase<"ReluX", 5>; -def ActivationModeBandPass : I32EnumAttrCase<"BandPass", 6>; -def ActivationModeElu: I32EnumAttrCase<"Elu", 7>; -def ActivationModeLeakyRelu: I32EnumAttrCase<"LeakyRelu", 8>; - -def Activation: I32EnumAttr<"Activation", - "Activation applied with fused convolution", - [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, - ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, - ActivationModeBandPass, ActivationModeElu, ActivationModeLeakyRelu]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::lmhlo_gpu"; -} - -def ActivationAttr : EnumAttr; - -def BoolParameter : AttrOrTypeParameter<"bool", ""> { - let parser = "::mlir::lmhlo_gpu::parseBool($_parser)"; -} - -def I64ArrayParameter : - AttrOrTypeParameter<"::llvm::ArrayRef", ""> { - let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let cppStorageType = "::llvm::SmallVector"; - let parser = "::mlir::lmhlo_gpu::parseI64Array($_parser)"; - let printer = "$_printer << '[' << $_self << ']'"; -} - -def ConvolutionBackendConfigAttr : AttrDef< - LmhloGpuDialect, "ConvolutionBackendConfig"> { - let mnemonic = "convolution_backend_config"; - let parameters = (ins - // These six fields are a TableGen transliteration of AlgorithmProto. - "int64_t":$algorithm, - BoolParameter:$tensor_ops_enabled, - // The next two fields are aligned arrays of knob IDs and values to - // represent the knob_id -> knob_value map. - I64ArrayParameter:$knob_ids, - I64ArrayParameter:$knob_values, - BoolParameter:$is_cudnn_frontend, - // If the convolution has CUDNN_TENSOR_NCHW_VECT_C layout (applicable to - // int8 data type only), this flag denotes that the filter (and bias, if - // present) are reordered using `cudnnReorderFilterAndBias`. - BoolParameter:$is_cudnn_reordered_int8, - "int64_t":$workspace_size, - - // The following 3 attributes describe the layout as an array of integers - // that list the dimensions in minor-to-major order similar to XLA's layout - // representation. operand_0_layout and operand_0_layout described the layout - // of the first 2 operands of the convolution, and result_layout describes - // the layout of the primary output operand of the convolution. - // Note: Not using names like input_layout or filter_layout as `input` may be - // an input operand (for ConvForward) but output for ConvBackward. - I64ArrayParameter:$operand_0_layout, - I64ArrayParameter:$operand_1_layout, - I64ArrayParameter:$result_layout - ); - let assemblyFormat = "`<` struct(params) `>`"; - let summary = "GPU Convolution backend configuration"; -} - -def CublasLtMatmulEpilogueDefault : I32EnumAttrCase<"Default", 0>; -def CublasLtMatmulEpilogueBias : I32EnumAttrCase<"Bias", 1>; -def CublasLtMatmulEpilogueRelu : I32EnumAttrCase<"Relu", 2>; -def CublasLtMatmulEpilogueBiasRelu : I32EnumAttrCase<"BiasRelu", 3>; -def CublasLtMatmulEpilogueGelu : I32EnumAttrCase<"Gelu", 4>; -def CublasLtMatmulEpilogueBiasGelu : I32EnumAttrCase<"BiasGelu", 5>; -def CublasLtMatmulEpilogueGeluAux : I32EnumAttrCase<"GeluAux", 6>; -def CublasLtMatmulEpilogueBiasGeluAux : I32EnumAttrCase<"BiasGeluAux", 7>; - - -def CublasLtMatmulEpilogue: I32EnumAttr<"CublasLtMatmulEpilogue", - "Epilogue for cublasLt matmul", - [CublasLtMatmulEpilogueDefault, CublasLtMatmulEpilogueBias, - CublasLtMatmulEpilogueRelu, CublasLtMatmulEpilogueBiasRelu, - CublasLtMatmulEpilogueGelu, CublasLtMatmulEpilogueBiasGelu, - CublasLtMatmulEpilogueGeluAux, CublasLtMatmulEpilogueBiasGeluAux]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::lmhlo_gpu"; -} - -def CublasLtMatmulEpilogueAttr : EnumAttr; - -def NormAlgorithmConfigAttr : AttrDef< - LmhloGpuDialect, "NormAlgorithmConfig"> { - let mnemonic = "norm_algorithm_config"; - let parameters = (ins - "int64_t":$algorithm, - "int64_t":$workspace_size - ); - let assemblyFormat = "`<` struct(params) `>`"; - let summary = "GPU Norm Algorithm configuration"; -} - -def FusedMHAAlgorithmConfigAttr : AttrDef< - LmhloGpuDialect, "FusedMHAAlgorithmConfig"> { - let mnemonic = "fHMA_algorithm_config"; - let parameters = (ins - "int64_t":$algorithm, - // These 2 fields are a TableGen transliteration of AlgorithmProto. - // Currently only knobs ids and values are relevant for fMHA but this - // Attr can be used to add algorithm related fields. - // The next two fields are aligned arrays of knob IDs and values to - // represent the knob_id -> knob_value map. - I64ArrayParameter:$knob_ids, - I64ArrayParameter:$knob_values, - "int64_t":$workspace_size - ); - let assemblyFormat = "`<` struct(params) `>`"; - let summary = "GPU Fused Multi Headed Attention Algorithm configuration"; -} - -def FusedMhaDagDefault : I32EnumAttrCase<"Default", 0>; -def FusedMhaDagScaleBiasMaskSoftmax : I32EnumAttrCase<"ScaleBiasMaskSoftmax", 1>; -def FusedMhaDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"ScaleBiasMaskSoftmaxDropout", 2>; -def FusedMhaDagScaleMaskSoftmax : I32EnumAttrCase<"ScaleMaskSoftmax", 3>; -def FusedMhaDagScaleMaskSoftmaxDropout : I32EnumAttrCase<"ScaleMaskSoftmaxDropout", 4>; -def FusedMhaDagSoftmaxDropout : I32EnumAttrCase<"SoftmaxDropout", 5>; -def FusedMhaDagSoftmax : I32EnumAttrCase<"Softmax", 6>; -def FusedMhaDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"ScaleBiasSoftmaxDropout", 7>; -def FusedMhaDagScaleBiasSoftmax : I32EnumAttrCase<"ScaleBiasSoftmax", 8>; - -def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasSoftmaxDropout", 0>; -def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>; -def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>; -def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>; -def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>; -def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>; - -def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature", - "DAG configuration for Fused Multi-Headed Attention", - [FusedMhaDagDefault, - FusedMhaDagScaleBiasMaskSoftmax, - FusedMhaDagScaleBiasMaskSoftmaxDropout, - FusedMhaDagScaleMaskSoftmax, - FusedMhaDagScaleMaskSoftmaxDropout, - FusedMhaDagSoftmaxDropout, - FusedMhaDagSoftmax, - FusedMhaDagScaleBiasSoftmaxDropout, - FusedMhaDagScaleBiasSoftmax]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::lmhlo_gpu"; -} - -def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature", - "DAG configuration for Fused Multi-Headed Attention Backward", - [ - FusedMhaBackwardDagScaleBiasSoftmaxDropout, - FusedMhaBackwardDagScaleBiasSoftmax, - FusedMhaBackwardDagScaleBiasMaskSoftmax, - FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout, - FusedMhaBackwardDagSoftmax, - FusedMhaBackwardDagSoftmaxDropout]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::lmhlo_gpu"; -} - -def FusedMhaDagSignatureAttr : EnumAttr; -def FusedMhaBackwardDagSignatureAttr : EnumAttr; -#endif // LHLO_GPU_OPS_ENUMS diff --git a/xla/mlir_hlo/mhlo/CMakeLists.txt b/xla/mlir_hlo/mhlo/CMakeLists.txt index a4a8881e2043a..347117c8bcb1c 100644 --- a/xla/mlir_hlo/mhlo/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/CMakeLists.txt b/xla/mlir_hlo/mhlo/IR/CMakeLists.txt index b982ed7343548..099d0bb6f6049 100644 --- a/xla/mlir_hlo/mhlo/IR/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/IR/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td b/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td index dcc0ab4636614..2db08da7e5f2e 100644 --- a/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td +++ b/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_base.td b/xla/mlir_hlo/mhlo/IR/hlo_base.td index 429edb829fe63..15d8dcc8cf5f5 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_base.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_base.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,9 +39,9 @@ defvar MHLO_Complex = HLO_Complex; // Integer-based uniform quantized types. The definitions can be used to specify // operand's tensor types. -defvar MHLO_QuantizedSignedInt = HLO_QuantizedSignedInt; -defvar MHLO_QuantizedUnsignedInt = HLO_QuantizedUnsignedInt; -defvar MHLO_QuantizedInt = HLO_QuantizedInt; +defvar MHLO_QuantizedSignedInt = AnyTypeOf<[HLO_QuantizedSignedInt, HLO_PerAxisQuantizedSignedInt]>; +defvar MHLO_QuantizedUnsignedInt = AnyTypeOf<[HLO_QuantizedUnsignedInt, HLO_PerAxisQuantizedUnsignedInt]>; +defvar MHLO_QuantizedInt = AnyTypeOf<[HLO_QuantizedSignedInt, HLO_PerAxisQuantizedSignedInt, HLO_QuantizedUnsignedInt, HLO_PerAxisQuantizedUnsignedInt]>; // The broadcasting dimensions correspond to a tuple that describes how a // smaller rank shape is broadcast into a larger rank shape. For example, @@ -65,17 +65,21 @@ defvar MHLO_FpTensor = HLO_FpTensor; defvar MHLO_Fp32Or64Tensor = HLO_Fp32Or64Tensor; // Any quantized integer tensor types -defvar MHLO_QuantizedIntTensor = HLO_QuantizedIntTensor; +defvar MHLO_QuantizedIntTensor = HLO_QuantizedIntOrPerAxisQuantizedIntTensor; defvar MHLO_PredTensor = HLO_PredTensor; -defvar MHLO_Tensor = HLO_Tensor; +// TODO: b/327490705 - Change this alias back to HLO_Tensor once MHLO_Tensor no +// longer needs per axis quantization. +defvar MHLO_Tensor = HLO_TensorOrPerAxisQuantizedTensor; + +defvar MHLO_AnyTensor = HLO_AnyTensor; defvar MHLO_ComplexTensor = HLO_ComplexTensor; defvar MHLO_Tuple = HLO_Tuple; -defvar MHLO_TensorOrToken = HLO_TensorOrToken; +defvar MHLO_TensorOrToken = HLO_TensorOrPerAxisQuantizedTensorOrToken; defvar MHLO_TensorOrTokenOrTuple = AnyTypeOf<[MHLO_Tensor, MHLO_Token, MHLO_Tuple]>; @@ -84,6 +88,8 @@ defvar MHLO_DimensionValue = HLO_DimensionValue; // Dynamic representation of a shape vector as a tensor. defvar MHLO_DimensionTensor = HLO_DimensionTensor; +defvar MHLO_CustomCallValue = HLO_CustomCallValue; + //===----------------------------------------------------------------------===// // MHLO combined type definitions. //===----------------------------------------------------------------------===// @@ -103,7 +109,7 @@ defvar MHLO_FpOrComplexTensor = HLO_FpOrComplexTensor; defvar MHLO_FpComplexOrQuantizedIntTensor = HLO_FpComplexOrQuantizedIntTensor; // Any int, floating-point, complex or quantized tensor types -defvar MHLO_IntFpOrComplexOrQuantizedIntTensor = HLO_IntFpOrComplexOrQuantizedIntTensor; +defvar MHLO_IntFpOrComplexOrQuantizedIntTensor = RankedTensorOf<[HLO_Int, HLO_Float, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>; // Any pred, int or floating-point tensor types defvar MHLO_PredIntOrFpTensor = HLO_PredIntOrFpTensor; @@ -117,9 +123,9 @@ defvar MHLO_PredIntFpOrQuantizedTensor = HLO_PredIntFpOrQuantizedTensor; // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. -defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensor; +defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensorOrPerAxisQuantizedTensor; -defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrToken; +defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken; defvar MHLO_StaticShapeIntOrFpTensor = HLO_StaticShapeIntOrFpTensor; diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 50aa7b33341bf..c5ef02a62e14d 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,6 +84,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/InliningUtils.h" #include "stablehlo/dialect/AssemblyFormat.h" +#include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/TypeInference.h" #include "utils/convert_op_folder.h" #include "utils/hlo_utils.h" @@ -168,16 +169,6 @@ hlo::HloDialectInterface* getMhloDialect(MLIRContext* context) { return dialect->getRegisteredInterface(); } -void createArgs(ArrayRef operands, - ArrayRef types, - SmallVector& args) { - for (auto argAndType : llvm::zip(operands, types)) { - auto& arg = args.emplace_back(); - arg.ssaName = std::get<0>(argAndType); - arg.type = std::get<1>(argAndType); - } -} - //===----------------------------------------------------------------------===// // Utilities for the canonicalize patterns //===----------------------------------------------------------------------===// @@ -334,17 +325,6 @@ LogicalResult TypeExtensionsAttr::verifyEncoding( getBounds(), RankedTensorType::get(shape, elementType), emitError); } -//===----------------------------------------------------------------------===// -// CollectivePermuteOp -//===----------------------------------------------------------------------===// - -void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState, - Type resultType, Value operand, - DenseIntElementsAttr sourceTargetPairs) { - CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand, - sourceTargetPairs, /*channel_handle=*/nullptr); -} - //===----------------------------------------------------------------------===// // ReduceScatterOp //===----------------------------------------------------------------------===// @@ -392,12 +372,14 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectiveBroadcastOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp) @@ -433,11 +415,43 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) // Async ops //===----------------------------------------------------------------------===// -Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types) { - if (types.size() == 1 && !types[0].isa()) return types[0]; +Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types, + bool expectsTuple = false) { + if (!expectsTuple && types.size() == 1 && !types[0].isa()) + return types[0]; return TupleType::get(ctx, TypeRange(types)); } +template +LogicalResult verifyAsyncBundleType(AsyncOp* op, AsyncBundleType bundleType, + FunctionType calleeType) { + auto bundleTypes = bundleType.getTypes(); + if (bundleTypes.size() < 2) + return op->emitOpError() << "bundle is expected to have at least 2 " + << "components, but got " << bundleTypes.size(); + + auto calleeInputTypes = calleeType.getInputs(); + auto calleeResultTypes = calleeType.getResults(); + MLIRContext* ctx = op->getContext(); + // TODO(vsytch): Cleanup callee operand verification when old-style HLO async + // types are removed. + // + // async-* expects the computation operand's types to be wrapped in a tuple. + // Old style async ops did not do this, so we need to check both cases. + if (bundleTypes[0] != maybeTupleFromTypes(ctx, calleeInputTypes) && + bundleTypes[0] != maybeTupleFromTypes(ctx, calleeInputTypes, + /*expectsTuple=*/true)) { + return op->emitOpError() + << "component #0 of async bundle doesn't match callee input types"; + } + if (bundleTypes[1] != maybeTupleFromTypes(ctx, calleeResultTypes)) { + return op->emitOpError() + << "component #1 of async bundle doesn't match callee result types"; + } + + return success(); +} + LogicalResult AsyncStartOp::verify() { ModuleOp module = getOperation()->getParentOfType(); func::FuncOp callee = @@ -446,8 +460,6 @@ LogicalResult AsyncStartOp::verify() { return emitOpError() << "can't find function: " << getCalledComputation(); } FunctionType calleeType = callee.getFunctionType(); - auto calleeInputTypes = calleeType.getInputs(); - auto calleeResultTypes = calleeType.getResults(); auto calleeThreadName = callee->getAttrOfType("execution_thread"); if (!calleeThreadName) @@ -476,21 +488,8 @@ LogicalResult AsyncStartOp::verify() { } } - auto resultTypes = getResult().getType().cast().getTypes(); - if (resultTypes.size() < 2) - return emitOpError() << "result is expected to be a bundle of at least 2 " - "components, but got " - << resultTypes.size(); - if (resultTypes[0] != maybeTupleFromTypes(getContext(), calleeInputTypes)) { - return emitOpError() - << "component #0 of return type doesn't match callee input types"; - } - if (resultTypes[1] != maybeTupleFromTypes(getContext(), calleeResultTypes)) { - return emitOpError() - << "component #1 of return type doesn't match callee result types"; - } - - return success(); + auto bundleType = getResult().getType().cast(); + return verifyAsyncBundleType(this, bundleType, calleeType); } LogicalResult AsyncUpdateOp::verify() { @@ -501,9 +500,6 @@ LogicalResult AsyncUpdateOp::verify() { return emitOpError() << "can't find function: " << getCalledComputation(); } FunctionType calleeType = callee.getFunctionType(); - auto calleeInputTypes = calleeType.getInputs(); - auto calleeResultTypes = calleeType.getResults(); - auto bundleTypes = getBundle().getType().cast().getTypes(); auto calleeThreadName = callee->getAttrOfType("execution_thread"); if (!calleeThreadName) @@ -515,27 +511,15 @@ LogicalResult AsyncUpdateOp::verify() { << calleeThreadName << "."; } - if (bundleTypes.size() < 2) - return emitOpError() << "operand is expected to be a bundle of at least 2 " - "components, but got " - << bundleTypes.size(); - if (bundleTypes[0] != maybeTupleFromTypes(getContext(), calleeInputTypes)) { - return emitOpError() << "component #0 of operand bundle type doesn't match " - "callee input types"; - } - if (bundleTypes[1] != maybeTupleFromTypes(getContext(), calleeResultTypes)) { - return emitOpError() << "component #1 of operand bundle type doesn't match " - "callee result types"; - } - - return success(); + auto bundleType = getResult().getType().cast(); + return verifyAsyncBundleType(this, bundleType, calleeType); } LogicalResult AsyncUpdateOp::inferReturnTypes( MLIRContext*, std::optional, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - AsyncUpdateOp::Adaptor adaptor(operands, attributes, {}, regions); + AsyncUpdateOp::Adaptor adaptor(operands, attributes, properties, regions); auto stateType = adaptor.getBundle().getType().cast(); inferredReturnTypes.push_back(stateType); return success(); @@ -549,9 +533,6 @@ LogicalResult AsyncDoneOp::verify() { return emitOpError() << "can't find function: " << getCalledComputation(); } FunctionType calleeType = callee.getFunctionType(); - auto calleeInputTypes = calleeType.getInputs(); - auto calleeResultTypes = calleeType.getResults(); - auto bundleTypes = getBundle().getType().cast().getTypes(); auto calleeThreadName = callee->getAttrOfType("execution_thread"); if (!calleeThreadName) @@ -563,27 +544,15 @@ LogicalResult AsyncDoneOp::verify() { << calleeThreadName << "."; } - if (bundleTypes.size() < 2) - return emitOpError() << "operand is expected to be a bundle of at least 2 " - "components, but got " - << bundleTypes.size(); - if (bundleTypes[0] != maybeTupleFromTypes(getContext(), calleeInputTypes)) { - return emitOpError() - << "operand type component #0 doesn't match callee input types"; - } - if (bundleTypes[1] != maybeTupleFromTypes(getContext(), calleeResultTypes)) { - return emitOpError() - << "operand type component #1 doesn't match callee result types"; - } - - return success(); + auto bundleType = getBundle().getType().cast(); + return verifyAsyncBundleType(this, bundleType, calleeType); } LogicalResult AsyncDoneOp::inferReturnTypes( MLIRContext*, std::optional, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - AsyncDoneOp::Adaptor adaptor(operands, attributes, {}, regions); + AsyncDoneOp::Adaptor adaptor(operands, attributes, properties, regions); ModuleOp module = adaptor.getBundle().getDefiningOp()->getParentOfType(); auto calledComputation = adaptor.getCalledComputationAttr(); @@ -609,6 +578,16 @@ LogicalResult AfterAllOp::inferReturnTypes( inferredReturnTypes); } +//===----------------------------------------------------------------------===// +// CompositeOp +//===----------------------------------------------------------------------===// + +LogicalResult CompositeOp::verifySymbolUses( + SymbolTableCollection& symbolTable) { + return hlo::verifyCompositeOp(getLoc(), getOperation(), getName(), + getDecomposition(), symbolTable); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -623,34 +602,35 @@ OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { // Builds a constant op with the specified attribute `value`. void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result, Attribute value) { + Properties& properties = result.getOrAddProperties(); Type type; if (auto elemAttr = value.dyn_cast()) { type = elemAttr.getType(); + properties.value = elemAttr; } else if (value.isa()) { // All XLA types must be tensor types. In the build() method, we want to // provide more flexibility by allowing attributes of scalar types. But we // need to wrap it up with ElementsAttr to construct valid XLA constants. type = RankedTensorType::get(/*shape=*/{}, value.cast().getType()); - value = DenseElementsAttr::get(type.cast(), value); + properties.value = DenseElementsAttr::get(type.cast(), value); } else if (auto complexAttr = value.dyn_cast()) { type = RankedTensorType::get(/*shape=*/{}, complexAttr.cast().getType()); - value = + properties.value = DenseElementsAttr::get(type.cast(), complexAttr.getValue()); } // TODO: support other XLA specific types. assert(type && "unsupported attribute type for building mhlo.constant"); result.types.push_back(type); - result.addAttribute("value", value); } LogicalResult ConstantOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - ConstantOpAdaptor adaptor(operands, attributes); + ConstantOpAdaptor adaptor(operands, attributes, properties, regions); return hlo::inferConstantOp(location, adaptor.getValue(), inferredReturnTypes); } @@ -934,9 +914,9 @@ void CustomCallOp::getEffects( LogicalResult CholeskyOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - CholeskyOp::Adaptor adaptor(operands, attributes, {}, regions); + CholeskyOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferCholeskyOp(location, adaptor.getA(), inferredReturnShapes); } @@ -945,8 +925,8 @@ LogicalResult CholeskyOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// LogicalResult DotOp::verify() { - return hlo::verifyDotOp(getLoc(), getLhs(), getRhs(), getPrecisionConfig(), - getResult()); + return hlo::verifyDotOp(getLoc(), getLhs().getType(), getRhs().getType(), + getPrecisionConfig(), getResult()); } //===----------------------------------------------------------------------===// @@ -1000,19 +980,81 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes( return success(); } +//===----------------------------------------------------------------------===// +// SparseDotOp +//===----------------------------------------------------------------------===// + +LogicalResult SparseDotOp::verify() { + RankedTensorType lhsType = getLhs().getType().dyn_cast(); + RankedTensorType rhsType = getRhs().getType().dyn_cast(); + // If either operand is unranked, static verification is not possible. + if (!lhsType || !rhsType) return success(); + + auto applySparsityDescriptor = [&](std::optional attr, + RankedTensorType* type) { + if (!attr.has_value()) return success(); + SmallVector sparseShape(type->getShape()); + if (static_cast(attr->getDimension()) >= sparseShape.size()) { + return emitOptionalError(getLoc(), "sparsity dimension is incorrect"); + } + if (attr->getN() != 2 || attr->getM() != 4) { + return emitOptionalError(getLoc(), "only 2:4 sparsity is supported"); + } + sparseShape[attr->getDimension()] *= attr->getM() / attr->getN(); + *type = type->clone(sparseShape); + return success(); + }; + if (failed(applySparsityDescriptor(getLhsSparsity(), &lhsType)) || + failed(applySparsityDescriptor(getRhsSparsity(), &rhsType))) + return failure(); + + SmallVector inferredReturnShapes; + if (failed(hlo::inferDotGeneralOp( + getLoc(), lhsType, rhsType, + getDotDimensionNumbersAttr().getLhsBatchingDimensions(), + getDotDimensionNumbersAttr().getRhsBatchingDimensions(), + getDotDimensionNumbersAttr().getLhsContractingDimensions(), + getDotDimensionNumbersAttr().getRhsContractingDimensions(), + getPrecisionConfig(), inferredReturnShapes))) + return failure(); + + auto inferredShape = inferredReturnShapes[0]; + auto resultType = getResult().getType().cast(); + if (inferredShape.hasRank() && resultType.hasRank() && + failed(verifyCompatibleShape(inferredShape.getDims(), + resultType.getShape()))) + return emitOptionalError(getLoc(), "inferred shape '", + hlo::dimSizesToString(inferredShape.getDims()), + "' is incompatible with return type of operation ", + resultType); + return success(); +} + //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// +LogicalResult verify1dTensor(std::optional loc, + DenseIntElementsAttr attr, std::string attrName) { + auto rank = attr.getType().getRank(); + if (rank != 1) { + return emitOptionalError(loc, attrName, " has rank ", rank, + " instead of required rank 1."); + } + return success(); +} LogicalResult FftOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - FftOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferFftOp(location, adaptor.getOperand(), - adaptor.getFftType() == FftType::RFFT, - adaptor.getFftType() == FftType::IRFFT, - adaptor.getFftLength(), inferredReturnShapes); + FftOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getFftLength(), "fft_length"))) + return failure(); + return hlo::inferFftOp( + location, adaptor.getOperand(), adaptor.getFftType() == FftType::RFFT, + adaptor.getFftType() == FftType::IRFFT, + llvm::to_vector(adaptor.getFftLength().getValues()), + inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -1187,17 +1229,20 @@ LogicalResult GatherOp::reifyReturnTypeShapes( LogicalResult GatherOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - GatherOp::Adaptor adaptor(operands, attributes, {}, regions); + GatherOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getSliceSizes(), "slice_sizes"))) + return failure(); return hlo::inferGatherOp( location, adaptor.getOperand(), adaptor.getStartIndices(), adaptor.getDimensionNumbers().getOffsetDims(), adaptor.getDimensionNumbers().getCollapsedSliceDims(), adaptor.getDimensionNumbers().getStartIndexMap(), adaptor.getDimensionNumbers().getIndexVectorDim(), - adaptor.getSliceSizes(), inferredReturnShapes); + llvm::to_vector(adaptor.getSliceSizes().getValues()), + inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -1245,10 +1290,10 @@ LogicalResult DynamicGatherOp::reifyReturnTypeShapes( LogicalResult DynamicGatherOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - DynamicGatherOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicGatherOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferDynamicGatherOp( location, adaptor.getOperand(), adaptor.getStartIndices(), adaptor.getSliceSizes(), adaptor.getDimensionNumbers().getOffsetDims(), @@ -1263,9 +1308,10 @@ LogicalResult DynamicGatherOp::inferReturnTypeComponents( LogicalResult GetDimensionSizeOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - GetDimensionSizeOp::Adaptor adaptor(operands, attributes, {}, regions); + GetDimensionSizeOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferGetDimensionSizeOp(location, adaptor.getOperand().getType(), adaptor.getDimension(), inferredReturnShapes); @@ -1442,9 +1488,10 @@ LogicalResult DynamicIotaOp::reifyReturnTypeShapes( LogicalResult DynamicUpdateSliceOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferDynamicUpdateSliceOp( location, adaptor.getOperand(), adaptor.getUpdate(), adaptor.getStartIndices(), inferredReturnShapes); @@ -1482,16 +1529,39 @@ OpFoldResult DynamicUpdateSliceOp::fold(FoldAdaptor /*adaptor*/) { LogicalResult AbsOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - AbsOp::Adaptor adaptor(operands, attributes, {}, regions); + AbsOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes); } +//===----------------------------------------------------------------------===// +// CollectiveBroadcastOp +//===----------------------------------------------------------------------===// + +void CollectiveBroadcastOp::build(OpBuilder& odsBuilder, + OperationState& odsState, Type resultType, + Value operand, + DenseIntElementsAttr replicaGroups) { + CollectiveBroadcastOp::build(odsBuilder, odsState, resultType, operand, + replicaGroups, /*channel_handle=*/nullptr); +} + +LogicalResult CollectiveBroadcastOp::verify() { + return hlo::verifyCollectiveBroadcastOp(getLoc(), getReplicaGroups()); +} + //===----------------------------------------------------------------------===// // CollectivePermuteOp //===----------------------------------------------------------------------===// +void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState, + Type resultType, Value operand, + DenseIntElementsAttr sourceTargetPairs) { + CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand, + sourceTargetPairs, /*channel_handle=*/nullptr); +} + LogicalResult CollectivePermuteOp::verify() { return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs()); } @@ -1785,13 +1855,8 @@ void DynamicConvOp::getCanonicalizationPatterns(RewritePatternSet& results, void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, Type resultElementTy) { - Type resultTy; - Type operandTy = operand.getType(); - if (auto rankedTy = operandTy.dyn_cast()) { - resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy); - } else { - resultTy = UnrankedTensorType::get(resultElementTy); - } + auto rankedTy = operand.getType().cast(); + auto resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy); build(builder, result, resultTy, operand); } @@ -1931,9 +1996,9 @@ void TupleOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult AllToAllOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - AllToAllOp::Adaptor adaptor(operands, attributes, {}, regions); + AllToAllOp::Adaptor adaptor(operands, attributes, properties, regions); bool isArrayAllToAll = adaptor.getSplitDimension() && adaptor.getConcatDimension() && @@ -2001,9 +2066,21 @@ LogicalResult AllGatherOp::verify() { if (auto channelHandleAttr = getChannelHandleAttr()) channelId = channelHandleAttr.getHandle(); - return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(), - getReplicaGroups(), channelId, - getUseGlobalDeviceIds(), getResult()); + if (getOperands().empty()) + return emitOptionalError(getLoc(), + "AllGather must have have at least one operand"); + if (getNumOperands() != getNumResults()) + return emitOptionalError( + getLoc(), "AllGather requires the same number of operands and results"); + + for (unsigned i = 0; i < getNumOperands(); ++i) { + if (failed(hlo::verifyAllGatherOp( + getLoc(), getOperand(i), getAllGatherDim(), getReplicaGroups(), + channelId, getUseGlobalDeviceIds(), getResult(i)))) { + return failure(); + } + } + return success(); } void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -2011,8 +2088,8 @@ void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, IntegerAttr allGatherDim, DenseIntElementsAttr replicaGroups, ChannelHandleAttr channelHandle) { - AllGatherOp::build(odsBuilder, odsState, resultType, operand, allGatherDim, - replicaGroups, channelHandle, + AllGatherOp::build(odsBuilder, odsState, resultType, ValueRange(operand), + allGatherDim, replicaGroups, channelHandle, /*use_global_device_ids=*/nullptr); } @@ -2040,9 +2117,9 @@ void AllReduceOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult AllReduceOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - AllReduceOp::Adaptor adaptor(operands, attributes, {}, regions); + AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions); // Verify constraints if (adaptor.getOperands().empty()) @@ -2061,15 +2138,8 @@ LogicalResult AllReduceOp::inferReturnTypeComponents( } // Populate inferred return shapes - for (auto resultType : adaptor.getOperands().getTypes()) { - auto rankedResult = resultType.dyn_cast(); - if (rankedResult) - inferredReturnShapes.emplace_back(rankedResult.getShape(), - rankedResult.getElementType(), - rankedResult.getEncoding()); - else - inferredReturnShapes.emplace_back(resultType.cast()); - } + return hlo::inferAllReduceOp(location, adaptor.getOperands(), + adaptor.getComputation(), inferredReturnShapes); return success(); } @@ -2079,10 +2149,10 @@ LogicalResult AllReduceOp::inferReturnTypeComponents( LogicalResult BatchNormGradOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - BatchNormGradOp::Adaptor adaptor(operands, attributes, {}, regions); + BatchNormGradOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferBatchNormGradOp( location, adaptor.getOperand(), adaptor.getScale(), adaptor.getMean(), adaptor.getVariance(), adaptor.getGradOutput(), adaptor.getFeatureIndex(), @@ -2095,10 +2165,11 @@ LogicalResult BatchNormGradOp::inferReturnTypeComponents( LogicalResult BatchNormTrainingOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - BatchNormTrainingOp::Adaptor adaptor(operands, attributes, {}, regions); + BatchNormTrainingOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferBatchNormTrainingOp( location, adaptor.getOperand(), adaptor.getScale(), adaptor.getOffset(), adaptor.getFeatureIndex(), inferredReturnShapes); @@ -2110,10 +2181,11 @@ LogicalResult BatchNormTrainingOp::inferReturnTypeComponents( LogicalResult BatchNormInferenceOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - BatchNormInferenceOp::Adaptor adaptor(operands, attributes, {}, regions); + BatchNormInferenceOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferBatchNormInferenceOp( location, adaptor.getOperand(), adaptor.getScale(), adaptor.getOffset(), adaptor.getMean(), adaptor.getVariance(), adaptor.getFeatureIndex(), @@ -2214,12 +2286,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { LogicalResult BroadcastOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - BroadcastOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferBroadcastOp(location, adaptor.getOperand(), - adaptor.getBroadcastSizes(), - inferredReturnShapes); + BroadcastOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getBroadcastSizes(), + "broadcast_sizes"))) + return failure(); + return hlo::inferBroadcastOp( + location, adaptor.getOperand(), + llvm::to_vector(adaptor.getBroadcastSizes().getValues()), + inferredReturnShapes); } LogicalResult BroadcastOp::reifyReturnTypeShapes( @@ -2261,8 +2337,10 @@ LogicalResult BroadcastOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult BroadcastInDimOp::verify() { - return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(), - getBroadcastDimensions(), getResult()); + return hlo::verifyBroadcastInDimOp( + getLoc(), getOperand(), + llvm::to_vector(getBroadcastDimensions().getValues()), + getResult()); } OpFoldResult BroadcastInDimOp::fold(FoldAdaptor adaptor) { @@ -2365,9 +2443,27 @@ void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results, //===----------------------------------------------------------------------===// LogicalResult DynamicBroadcastInDimOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto outputdimensionsType = + getOutputDimensions().getType().cast(); + auto resultType = getResult().getType().cast(); + if (!outputdimensionsType.hasRank() || !resultType.hasRank()) { + return success(); + } + return hlo::verifyDynamicBroadcastInDimOp( - getLoc(), getOperand(), getOutputDimensions(), getBroadcastDimensions(), - getKnownExpandingDimensions(), getKnownNonexpandingDimensions(), + getLoc(), getOperand(), getOutputDimensions(), + llvm::to_vector(getBroadcastDimensions().getValues()), + getKnownExpandingDimensionsAttr() + ? std::optional>(llvm::to_vector( + getKnownExpandingDimensions()->getValues())) + : std::nullopt, + getKnownNonexpandingDimensions() + ? std::optional>(llvm::to_vector( + getKnownNonexpandingDimensions()->getValues())) + : std::nullopt, getResult()); } @@ -2524,9 +2620,9 @@ LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes( LogicalResult ComplexOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - ComplexOp::Adaptor adaptor(operands, attributes, {}, regions); + ComplexOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferComplexOp(location, adaptor.getLhs(), inferredReturnTypes); } @@ -2546,9 +2642,9 @@ OpFoldResult ComplexOp::fold(FoldAdaptor) { LogicalResult ImagOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - ImagOp::Adaptor adaptor(operands, attributes, {}, regions); + ImagOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferImagOp(location, adaptor.getOperand(), inferredReturnTypes); } @@ -2566,9 +2662,9 @@ OpFoldResult ImagOp::fold(FoldAdaptor) { LogicalResult IsFiniteOp::inferReturnTypes( MLIRContext* ctx, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - IsFiniteOp::Adaptor adaptor(operands, attributes, {}, regions); + IsFiniteOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferIsFiniteOp(ctx, location, adaptor.getX(), inferredReturnTypes); } @@ -2579,9 +2675,9 @@ LogicalResult IsFiniteOp::inferReturnTypes( LogicalResult RealOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - RealOp::Adaptor adaptor(operands, attributes, {}, regions); + RealOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferRealOp(location, adaptor.getOperand(), inferredReturnTypes); } @@ -2677,9 +2773,9 @@ class ConcatenateForwarding : public OpRewritePattern { LogicalResult ConcatenateOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - ConcatenateOp::Adaptor adaptor(operands, attributes, {}, regions); + ConcatenateOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferConcatenateOp(location, adaptor.getVal().getTypes(), adaptor.getDimension(), inferredReturnTypes); } @@ -2739,7 +2835,8 @@ static Attribute foldConcatenate(ConcatenateOp* op, OpFoldResult ConcatenateOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); - if (getNumOperands() == 1) return getOperand(0); + if (getNumOperands() == 1 && getOperand(0).getType() == getType()) + return getOperand(0); ShapedType type = getResult().getType().cast(); if (!type.hasStaticShape()) return {}; @@ -2818,6 +2915,13 @@ LogicalResult ConcatenateOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult DynamicReshapeOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyDynamicReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto resultType = getResult().getType().cast(); + auto outputShapeType = getOutputShape().getType().cast(); + if (!resultType.hasRank() || !outputShapeType.hasStaticShape()) + return success(); return hlo::verifyDynamicReshapeOp(getLoc(), getOutputShape(), getResult()); } @@ -2991,13 +3095,16 @@ void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult DynamicSliceOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - DynamicSliceOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferDynamicSliceOp(location, adaptor.getOperand().getType(), - adaptor.getStartIndices().getTypes(), - adaptor.getSliceSizes(), - inferredReturnShapes); + DynamicSliceOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getSliceSizes(), "slice_sizes"))) + return failure(); + return hlo::inferDynamicSliceOp( + location, adaptor.getOperand().getType(), + adaptor.getStartIndices().getTypes(), + llvm::to_vector(adaptor.getSliceSizes().getValues()), + inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -3137,11 +3244,15 @@ LogicalResult InfeedOp::verify() { LogicalResult MapOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - MapOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferMapOp(location, adaptor.getInputs(), adaptor.getDimensions(), - adaptor.getComputation(), inferredReturnShapes); + MapOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getDimensions(), "dimensions"))) + return failure(); + return hlo::inferMapOp( + location, adaptor.getInputs(), + llvm::to_vector(adaptor.getDimensions().getValues()), + adaptor.getComputation(), inferredReturnShapes); } OpFoldResult MapOp::fold(FoldAdaptor) { @@ -3218,21 +3329,52 @@ OpFoldResult CopyOp::fold(FoldAdaptor) { return getOperand(); } LogicalResult ReduceWindowOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - ReduceWindowOp::Adaptor adaptor(operands, attributes, {}, regions); + ReduceWindowOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferReduceWindowOp( location, adaptor.getInputs(), adaptor.getInitValues(), - adaptor.getWindowDimensions(), adaptor.getWindowStrides(), - adaptor.getBaseDilations(), adaptor.getWindowDilations(), - adaptor.getPadding(), inferredReturnShapes); + llvm::to_vector(adaptor.getWindowDimensions().getValues()), + adaptor.getWindowStrides() + ? llvm::to_vector(adaptor.getWindowStrides()->getValues()) + : ArrayRef{}, + adaptor.getBaseDilations() + ? llvm::to_vector(adaptor.getBaseDilations()->getValues()) + : ArrayRef{}, + adaptor.getWindowDilations() + ? llvm::to_vector(adaptor.getWindowDilations()->getValues()) + : ArrayRef{}, + adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes); } LogicalResult ReduceWindowOp::verify() { - return hlo::verifyReduceWindowOp(getLoc(), getInputs(), getInitValues(), - getWindowDimensions(), getWindowStrides(), - getBaseDilations(), getWindowDilations(), - getPadding(), getBody()); + if (failed( + verify1dTensor(getLoc(), getWindowDimensions(), "window_dimensions"))) + return failure(); + // TODO: simplify this code and others in this file + if (getWindowStrides() && + failed(verify1dTensor(getLoc(), *getWindowStrides(), "window_strides"))) + return failure(); + if (getBaseDilations() && + failed(verify1dTensor(getLoc(), *getBaseDilations(), "base_dilations"))) + return failure(); + if (getWindowDilations() && + failed( + verify1dTensor(getLoc(), *getWindowDilations(), "window_dilations"))) + return failure(); + return hlo::verifyReduceWindowOp( + getLoc(), getInputs(), getInitValues(), + llvm::to_vector(getWindowDimensions().getValues()), + getWindowStrides() + ? llvm::to_vector(getWindowStrides()->getValues()) + : ArrayRef{}, + getBaseDilations() + ? llvm::to_vector(getBaseDilations()->getValues()) + : ArrayRef{}, + getWindowDilations() + ? llvm::to_vector(getWindowDilations()->getValues()) + : ArrayRef{}, + getPadding(), getBody()); } // Get the operation used for reduction applied to `result_index`th result. Its @@ -3312,23 +3454,12 @@ void ReduceWindowOp::build( function_ref bodyBuilder) { odsState.addOperands(inputs); odsState.addOperands(init_values); - odsState.addAttribute(getWindowDimensionsAttrName(odsState.name), - window_dimensions); - if (window_strides) { - odsState.addAttribute(getWindowStridesAttrName(odsState.name), - window_strides); - } - if (base_dilations) { - odsState.addAttribute(getBaseDilationsAttrName(odsState.name), - base_dilations); - } - if (window_dilations) { - odsState.addAttribute(getWindowDilationsAttrName(odsState.name), - window_dilations); - } - if (padding) { - odsState.addAttribute(getPaddingAttrName(odsState.name), padding); - } + Properties& properties = odsState.getOrAddProperties(); + properties.window_dimensions = window_dimensions; + properties.window_strides = window_strides; + properties.base_dilations = base_dilations; + properties.window_dilations = window_dilations; + properties.padding = padding; Region* region = odsState.addRegion(); llvm::SmallVector blockArgTypes; @@ -3476,32 +3607,45 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { // ReduceOp //===----------------------------------------------------------------------===// -LogicalResult ReduceOp::fold(FoldAdaptor /*adaptor*/, - SmallVectorImpl& results) { +static LogicalResult tryFoldZeroDimReduction( + ReduceOp reduceOp, SmallVectorImpl& results) { + if (reduceOp.getDimensions().getNumElements() != 0) return failure(); // No dimensions to reduce. - if (getDimensions().getNumElements() == 0) { - for (Value operand : this->getInputs()) { - results.push_back(operand); + for (auto [operand, opResult] : + llvm::zip_equal(reduceOp.getInputs(), reduceOp.getResults())) { + if (operand.getType() != opResult.getType()) { + results.clear(); + return failure(); } - return success(); + results.push_back(operand); } + return success(); +} +static LogicalResult tryFoldOutsideValuesReduction( + ReduceOp reduceOp, SmallVectorImpl& results) { // If all returned values in the ReduceOp region exists outside // the region replace the ReduceOp with those values. - mlir::Block& bb = this->getBody().front(); - SmallVector replacedResults; - if (auto retOp = mlir::dyn_cast(bb.back())) { - for (Value result : retOp.getResults()) { - if (result.getParentRegion() == retOp->getParentRegion()) - return failure(); - replacedResults.push_back(result); + mlir::Block& bb = reduceOp.getBody().front(); + auto retOp = mlir::dyn_cast(bb.back()); + if (!retOp) return failure(); + for (auto [result, opResult] : + llvm::zip_equal(retOp.getResults(), reduceOp.getResults())) { + if (result.getParentRegion() == retOp->getParentRegion() || + result.getType() != opResult.getType()) { + results.clear(); + return failure(); } - - results.insert(results.end(), replacedResults.begin(), - replacedResults.end()); - return success(); + results.push_back(result); } + return success(); +} +LogicalResult ReduceOp::fold(FoldAdaptor /*adaptor*/, + SmallVectorImpl& results) { + if (succeeded(tryFoldZeroDimReduction(*this, results))) return success(); + if (succeeded(tryFoldOutsideValuesReduction(*this, results))) + return success(); return failure(); } @@ -3516,319 +3660,82 @@ bool hasSameOperandAndResultTypes(Operation& op) { llvm::all_of(op.getResultTypes(), typeMatch); } -// Checks the following eligibility criteria for compact printing of -// mhlo.reduce: -// E1. The reduce-op wraps a single inner-op in the associated region. -// E2. The single operation is a commutative binary-op from mhlo dialect, zero -// region, producing single result such that the operands and result all -// have the same type. -// E3. The reduce-op consist of at least one input-operand; The operand-types of -// inner-op should be derived trivially from the element-type of reduce-op's -// first input-operand. -// E4. The arguments of the region's only basic block are forwarded perfectly -// to inner-op's operands. -// E5. The reduce-op, inner-op, blocks arguments, and the return-op all have the -// same location. -// E6. The single operation result is perfectly forwarded to the reduce op -// return. -static bool isEligibleForCompactPrint(ReduceOp op) { - // Check E1. - auto& block = op.getBody().front(); - if (!hasSingleElement(block.without_terminator())) return false; - - Operation& innerOp = *block.begin(); - - // Check E2. - if (innerOp.getDialect() != op->getDialect()) return false; - - if (innerOp.getNumOperands() != 2 || - !innerOp.hasTrait() || - !hasSameOperandAndResultTypes(innerOp) || - !innerOp.hasTrait() || - !innerOp.hasTrait()) - return false; - - // Check E3. - if (op.getInputs().empty()) return false; - - auto elemType = - op.getInputs()[0].getType().cast().getElementType(); - auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType); - if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false; - - // Check E4. - if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false; - - // Check E5. - auto retOp = dyn_cast(block.getTerminator()); - if (!retOp) return false; - - auto blockArgLoc = block.getArgument(0).getLoc(); - if (blockArgLoc != block.getArgument(1).getLoc()) return false; - - if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() || - blockArgLoc != op.getLoc()) - return false; - - // Check E6. - return llvm::equal(innerOp.getResults(), retOp.getOperands()); -} - void ReduceOp::print(OpAsmPrinter& p) { - { - // Print the pairs of operands under the form: - // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) - StringRef comma = ""; - int numOperandPairs = getNumOperands() / 2; - for (int opId : llvm::seq(0, numOperandPairs)) { - p << comma << "(" << getOperand(opId) - << " init: " << getOperand(opId + numOperandPairs) << ")"; - comma = ", "; - } - } - - // If the reduce-op is eligible for compact printing, we emit the one-liner: - // mhlo.reduce applies across dimensions = [...] : - // Note: We are not printing the function type of reduction operation. We - // have some simplifying assumptions (refer to IsEligibleForCompactPrint::E3) - // to derive the type from that of reduce-op. - if (isEligibleForCompactPrint(*this)) { - Operation& innerOp = getBody().front().front(); - p << " applies "; - printEscapedString(innerOp.getName().getStringRef(), p.getStream()); - - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; - p << " : "; - p.printFunctionalType(*this); - } else { - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; - p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); - p << " : "; - p.printFunctionalType(*this); - p.printNewline(); - p << " reducer"; - { - // Print the pairs of block operands under the form: - // (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc): - Block& reducer = getBody().front(); - int numOperandPairs = getNumOperands() / 2; - for (int opId : llvm::seq(0, numOperandPairs)) { - p << "("; - p.printRegionArgument(reducer.getArgument(opId)); - p << ", "; - p.printRegionArgument(reducer.getArgument(opId + numOperandPairs)); - p << ") "; - } - } - p << ' '; - p.printRegion(getBody(), /*printEntryBlockArgs=*/false); - } + auto dimensions = llvm::to_vector(getDimensions().getValues()); + hlo::printReduceOp(p, getOperation(), getInputs(), dimensions, getBody()); } ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { - llvm::SMLoc loc = parser.getCurrentLocation(); - Location currLocation = parser.getEncodedSourceLoc(loc); - - // Parse the operands of reduce-op, this is a list of pair under the form: - // (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5) - // Each input to reduce is paired with its init value, even though in memory - // they are stored with the input first and the init values after. - SmallVector operands; - SmallVector initOperands; - do { - (void)parser.parseOptionalComma(); - if (parser.parseOptionalLParen()) break; - OpAsmParser::UnresolvedOperand operand, initOperand; - if (parser.parseOperand(operand) || parser.parseKeyword("init") || - parser.parseColon() || parser.parseOperand(initOperand) || - parser.parseRParen()) - return failure(); - operands.push_back(operand); - initOperands.push_back(initOperand); - } while (true); - operands.append(initOperands); - - // Check if we are parsing the compact version of reduce-op: - // mhlo.reduce applies across dimensions = [...] : - // else parse the "region-based" variant. - if (failed(parser.parseOptionalKeyword("applies"))) { - // Parse the inner-op dimensions, reduce-op's function-type and - // optional location. - SmallVector dimensions; - auto parseDim = [&]() -> ParseResult { - if (parser.parseInteger(dimensions.emplace_back())) return failure(); - return success(); - }; - - FunctionType reduceOpFntype; - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, - parseDim) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(reduceOpFntype) || - parser.parseKeyword("reducer")) - return failure(); - OpBuilder builder(parser.getBuilder().getContext()); - result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); - - // Parse the "reducer" region now. - SmallVector reducerOperands; - SmallVector reducerInitOperands; - SmallVector reducerTypes; - SmallVector reducerInitTypes; - SmallVector, 2> reducerLocs; - SmallVector, 2> reducerInitLocs; - auto parseBlockOperand = - [&](SmallVectorImpl& operands, - SmallVectorImpl& types, - SmallVectorImpl>& locs) -> ParseResult { - OpAsmParser::UnresolvedOperand operand; - Type type; - std::optional loc; - if (parser.parseOperand(operand, /*allowResultNumber=*/false) || - parser.parseColon() || parser.parseType(type) || - parser.parseOptionalLocationSpecifier(loc)) - return failure(); - operands.push_back(operand); - types.push_back(type); - locs.push_back(loc); - return success(); - }; - do { - if (failed(parser.parseOptionalLParen())) break; - if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) || - parser.parseComma() || - parseBlockOperand(reducerInitOperands, reducerInitTypes, - reducerInitLocs) || - parser.parseRParen()) - return failure(); - } while (true); - reducerOperands.append(reducerInitOperands); - reducerTypes.append(reducerInitTypes); - reducerLocs.append(reducerInitLocs); - result.addTypes(reduceOpFntype.getResults()); - SmallVector reducerArgs; - createArgs(reducerOperands, reducerTypes, reducerArgs); - - // Derive the SSA-values for reduce-op's operands and parse the region, and - // the optional trailing location. - std::optional trailingLoc; - if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, - result.operands) || - parser.parseRegion(*result.addRegion(), reducerArgs)) - return failure(); - // Set the individual block arguments. - for (auto argAndLoc : - llvm::zip(result.regions.front()->front().getArguments(), reducerLocs)) - if (std::get<1>(argAndLoc)) - std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value()); - result.location = trailingLoc.value_or(currLocation); - return success(); - } - - // Parse the inner-op name and check if the contract on inner-op - // mentioned in "isEligibleForCompactPrint::E2" for pretty-priting is met. - FailureOr innerOpNameInfo = parser.parseCustomOperationName(); - if (failed(innerOpNameInfo)) return failure(); - - StringRef innerOpName = innerOpNameInfo->getStringRef(); - Dialect* innerOpDialect = innerOpNameInfo->getDialect(); - if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") || - !innerOpNameInfo->hasTrait::Impl>() || - !innerOpNameInfo->hasTrait() || - !innerOpNameInfo->hasTrait() || - !innerOpNameInfo->hasTrait()) { - parser.emitError(loc, - "expected the inner-op to be a commutative binary-op from " - "mhlo dialect, zero region, producing single result"); - return failure(); - } - - // Parse the inner-op dimensions, reduce-op's function-type and - // optional location. - SmallVector dimensions; - auto parseDim = [&]() -> ParseResult { - if (parser.parseInteger(dimensions.emplace_back())) return failure(); - return success(); + auto parseDenseElements = [](OpBuilder& b, + ArrayRef dims) -> Attribute { + return b.getI64TensorAttr(dims); }; - - std::optional explicitLoc; - FunctionType reduceOpFntype; - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || - parser.parseColon() || parser.parseType(reduceOpFntype) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); - - if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) { - if (!reduceOpFntype) return parser.emitError(loc, "expected function type"); - return parser.emitError(loc, - "input types missing in reduce-op function type"); - } - - // If location of reduce-op is explicitly provided, then use it; Else use - // the parser's current location. - Location reduceOpLoc = explicitLoc.value_or(currLocation); - - // Derive the SSA-values for reduce-op's operands. - if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc, - result.operands)) - return failure(); - - // Derive the type of inner-op from that of reduce-op's input operand. - auto innerOpType = RankedTensorType::get( - /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0))); - - // Add a region for reduce-op. - Region& region = *result.addRegion(); - - // Create a basic-block inside reduce-op's region. - Block& block = region.emplaceBlock(); - auto lhs = block.addArgument(innerOpType, reduceOpLoc); - auto rhs = block.addArgument(innerOpType, reduceOpLoc); - - // Create and insert an "inner-op" operation in the block. - OpBuilder builder(parser.getBuilder().getContext()); - builder.setInsertionPointToStart(&block); - - OperationState innerOpState(reduceOpLoc, innerOpName); - innerOpState.operands.push_back(lhs); - innerOpState.operands.push_back(rhs); - innerOpState.addTypes(innerOpType); - - Operation* innerOp = builder.create(innerOpState); - - // Insert a return statement in the block returning the inner-op's result. - builder.create(innerOp->getLoc(), innerOp->getResults()); - - // Populate the reduce-op operation-state with result-type, location, and - // dimension attribute. - result.addTypes(reduceOpFntype.getResults()); - result.location = innerOp->getLoc(); - result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions)); - - return success(); + return hlo::parseReduceOp(parser, result, parseDenseElements); } LogicalResult ReduceOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - ReduceOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(), - adaptor.getInitValues().getTypes(), - adaptor.getDimensions(), inferredReturnShapes); + ReduceOp::Adaptor adaptor(operands, attributes, properties, regions); + return hlo::inferReduceOp( + location, adaptor.getInputs().getTypes(), + llvm::to_vector(adaptor.getDimensions().getValues()), + adaptor.getBody(), inferredReturnShapes); +} + +void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs, + ValueRange initValues, DenseIntElementsAttr dimensions, + TypeRange elementTypes) { + odsState.addOperands(inputs); + odsState.addOperands(initValues); + Properties& properties = odsState.getOrAddProperties(); + properties.dimensions = dimensions; + (void)odsState.addRegion(); + + SmallVector newDimensions; + Attribute encoding; + ReduceOp::Adaptor adaptor( + odsState.operands, + odsState.attributes.getDictionary(odsState.getContext()), {}, + odsState.regions); + + SmallVector inputArgTensorTypes{ + llvm::map_range(adaptor.getInputs().getTypes(), + [](Type t) { return t.cast(); })}; + SmallVector initValueTensorTypes{ + llvm::map_range(adaptor.getInitValues().getTypes(), + [](Type t) { return t.cast(); })}; + + if (succeeded(hlo::verifyReduceOpInputsAndInferShape( + odsState.location, inputArgTensorTypes, + llvm::to_vector(dimensions.getValues()), newDimensions, + encoding))) { + SmallVector inferredReturnTypes; + for (uint64_t inputIdx = 0; inputIdx < inputArgTensorTypes.size(); + ++inputIdx) { + Type elementTy = elementTypes[inputIdx]; + ShapedType inputType = inputArgTensorTypes[inputIdx]; + if (inputType.hasRank()) { + inferredReturnTypes.push_back( + RankedTensorType::get(newDimensions, elementTy, encoding)); + } else { + assert(encoding == nullptr && "attribute not supported"); + inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); + } + } + odsState.addTypes(inferredReturnTypes); + } else { + llvm::report_fatal_error("Failed to infer result type(s)."); + } } LogicalResult ReduceOp::verify() { - return hlo::verifyReduceOp(getLoc(), getInputs(), getInitValues(), - getDimensions(), getBody()); + if (failed(verify1dTensor(getLoc(), getDimensions(), "dimensions"))) + return failure(); + return hlo::verifyReduceOp( + getLoc(), getInputs(), getInitValues(), + llvm::to_vector(getDimensions().getValues()), getBody()); } // Enable constant folding to occur within the region of the ReduceOp @@ -3970,9 +3877,10 @@ LogicalResult ReduceOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult OptimizationBarrierOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - OptimizationBarrierOp::Adaptor adaptor(operands, attributes); + OptimizationBarrierOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferOptimizationBarrierOp(location, adaptor.getOperand(), inferredReturnTypes); } @@ -3981,7 +3889,11 @@ LogicalResult OptimizationBarrierOp::inferReturnTypes( // ReverseOp //===----------------------------------------------------------------------===// LogicalResult ReverseOp::verify() { - return hlo::verifyReverseOp(getLoc(), getOperand(), getDimensions()); + if (failed(verify1dTensor(getLoc(), getDimensions(), "dimensions"))) + return failure(); + return hlo::verifyReverseOp( + getLoc(), getOperand(), + llvm::to_vector(getDimensions().getValues())); } //===----------------------------------------------------------------------===// @@ -4000,10 +3912,10 @@ LogicalResult RngBitGeneratorOp::verify() { LogicalResult RngOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - RngOp::Adaptor adaptor(operands, attributes, {}, regions); + RngOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferRngOp( location, adaptor.getA(), adaptor.getB(), adaptor.getShape(), adaptor.getRngDistribution() == RngDistribution::UNIFORM, @@ -4083,9 +3995,9 @@ void SelectOp::getCanonicalizationPatterns(RewritePatternSet& results, // the return type based on operand type. LogicalResult SelectOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - SelectOp::Adaptor op(operands, attributes); + SelectOp::Adaptor op(operands, attributes, properties, regions); return hlo::inferSelectOp(location, op.getPred(), op.getOnTrue(), op.getOnFalse(), inferredReturnShapes); } @@ -4121,10 +4033,11 @@ OpFoldResult SetDimensionSizeOp::fold(FoldAdaptor adaptor) { LogicalResult SetDimensionSizeOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - SetDimensionSizeOp::Adaptor adaptor(operands, attributes, {}, regions); + SetDimensionSizeOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferSetDimensionSizeOp( getMhloDialect(context), location, adaptor.getOperand().getType(), adaptor.getSize(), adaptor.getDimension(), inferredReturnShapes); @@ -4136,14 +4049,23 @@ LogicalResult SetDimensionSizeOp::inferReturnTypeComponents( LogicalResult PadOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - PadOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferPadOp(location, adaptor.getOperand().getType(), - adaptor.getPaddingValue().getType(), - adaptor.getEdgePaddingLow(), - adaptor.getEdgePaddingHigh(), - adaptor.getInteriorPadding(), inferredReturnTypes); + PadOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getEdgePaddingLow(), + "edge_padding_low")) || + failed(verify1dTensor(location, adaptor.getEdgePaddingHigh(), + "edge_padding_high")) || + failed(verify1dTensor(location, adaptor.getInteriorPadding(), + "interior_padding"))) + return failure(); + return hlo::inferPadOp( + location, adaptor.getOperand().getType(), + adaptor.getPaddingValue().getType(), + llvm::to_vector(adaptor.getEdgePaddingLow().getValues()), + llvm::to_vector(adaptor.getEdgePaddingHigh().getValues()), + llvm::to_vector(adaptor.getInteriorPadding().getValues()), + inferredReturnTypes); } template @@ -4232,7 +4154,8 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) { LogicalResult PadOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { - PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary()); + PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary(), + this->getOperation()->getPropertiesStorage()); auto loc = this->getLoc(); Value operand = adaptor.getOperand(); auto operandTy = operand.getType().cast(); @@ -4455,6 +4378,14 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult ReshapeOp::verify() { + // Check for unranked dynamism. Unranked dynamism is not supported by + // StableHLO (hlo::verifyReshapeOp will fail) and we can't verify + // anything statically in that case anyway. + auto operandType = getOperand().getType().cast(); + auto resultType = getResult().getType().cast(); + if (!operandType.hasRank() || !resultType.hasRank()) { + return success(); + } return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult()); } @@ -4522,9 +4453,9 @@ LogicalResult AddDependencyOp::inferReturnTypes( LogicalResult IfOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - IfOp::Adaptor adaptor(operands, attributes, {}, regions); + IfOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferIfOp(location, adaptor.getPred(), adaptor.getRegions(), inferredReturnTypes); } @@ -4553,9 +4484,9 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult CaseOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - CaseOp::Adaptor adaptor(operands, attributes, {}, regions); + CaseOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferCaseOp(location, adaptor.getIndex(), adaptor.getRegions(), inferredReturnTypes); } @@ -4766,6 +4697,7 @@ UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven) UNARY_FOLDER_FLOAT(RoundOp, Round) UNARY_FOLDER_UPCAST_TO_F64(CosineOp, std::cos, AnyValue) +UNARY_FOLDER_UPCAST_TO_F64(ErfOp, std::erf, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(ExpOp, std::exp, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(LogisticOp, logistic, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(LogOp, std::log, PositiveValue) @@ -5101,9 +5033,9 @@ OpFoldResult ClampOp::fold(FoldAdaptor adaptor) { LogicalResult ClampOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - ClampOp::Adaptor adaptor(operands, attributes, {}, regions); + ClampOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferClampOp(location, adaptor.getMin(), adaptor.getOperand(), adaptor.getMax(), inferredReturnShapes); } @@ -5122,12 +5054,21 @@ LogicalResult ClampOp::reifyReturnTypeShapes( LogicalResult SliceOp::inferReturnTypes( MLIRContext* /*context*/, std::optional location, - ValueRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange /*regions*/, SmallVectorImpl& inferredReturnTypes) { - SliceOpAdaptor adaptor(operands, attributes); - return hlo::inferSliceOp(location, adaptor.getOperand().getType(), - adaptor.getStartIndices(), adaptor.getLimitIndices(), - adaptor.getStrides(), inferredReturnTypes); + ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, SmallVectorImpl& inferredReturnTypes) { + SliceOpAdaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(location, adaptor.getStartIndices(), + "start_indices")) || + failed(verify1dTensor(location, adaptor.getLimitIndices(), + "limit_indices")) || + failed(verify1dTensor(location, adaptor.getStrides(), "strides"))) + return failure(); + return hlo::inferSliceOp( + location, adaptor.getOperand().getType(), + llvm::to_vector(adaptor.getStartIndices().getValues()), + llvm::to_vector(adaptor.getLimitIndices().getValues()), + llvm::to_vector(adaptor.getStrides().getValues()), + inferredReturnTypes); } template @@ -5334,8 +5275,9 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results, void SortOp::build(OpBuilder& builder, OperationState& state, ValueRange operands, int64_t dimension, bool isStable) { state.addOperands(operands); - state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); - state.addAttribute("is_stable", builder.getBoolAttr(isStable)); + Properties& properties = state.getOrAddProperties(); + properties.dimension = builder.getI64IntegerAttr(dimension); + properties.is_stable = builder.getBoolAttr(isStable); for (Value operand : operands) state.addTypes(operand.getType()); @@ -5344,9 +5286,9 @@ void SortOp::build(OpBuilder& builder, OperationState& state, LogicalResult SortOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - SortOp::Adaptor adaptor(operands, attributes, {}, regions); + SortOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferSortOp(location, adaptor.getInputs(), inferredReturnShapes); } @@ -5456,7 +5398,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { return {}; } } - return getOperand(); + if (getOperand().getType() == getType()) return getOperand(); + return {}; } // transpose(transpose(X)) => transpose(X) @@ -5581,11 +5524,15 @@ LogicalResult TransposeOp::reifyReturnTypeShapes( LogicalResult TransposeOp::inferReturnTypes( MLIRContext*, std::optional loc, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - TransposeOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferTransposeOp(loc, adaptor.getOperand(), - adaptor.getPermutation(), inferredReturnTypes); + TransposeOp::Adaptor adaptor(operands, attributes, properties, regions); + if (failed(verify1dTensor(loc, adaptor.getPermutation(), "permutation"))) + return failure(); + return hlo::inferTransposeOp( + loc, adaptor.getOperand(), + llvm::to_vector(adaptor.getPermutation().getValues()), + inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -5594,9 +5541,9 @@ LogicalResult TransposeOp::inferReturnTypes( LogicalResult TriangularSolveOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - TriangularSolveOp::Adaptor adaptor(operands, attributes, {}, regions); + TriangularSolveOp::Adaptor adaptor(operands, attributes, properties, regions); bool isTransposeAInvalid = (adaptor.getTransposeA() == Transpose::TRANSPOSE_INVALID); return hlo::inferTriangularSolveOp(location, adaptor.getA(), adaptor.getB(), @@ -5618,9 +5565,9 @@ OpFoldResult GetTupleElementOp::fold(FoldAdaptor /*adaptor*/) { LogicalResult GetTupleElementOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - GetTupleElementOp::Adaptor adaptor(operands, attributes, {}, regions); + GetTupleElementOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferGetTupleElementOp(location, adaptor.getOperand(), adaptor.getIndex(), inferredReturnTypes); } @@ -5631,9 +5578,9 @@ LogicalResult GetTupleElementOp::inferReturnTypes( LogicalResult TupleOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - TupleOp::Adaptor adaptor(operands, attributes, {}, regions); + TupleOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferTupleOp(context, location, adaptor.getVal(), inferredReturnTypes); } @@ -5661,10 +5608,10 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, LogicalResult CompareOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, - RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - CompareOp::Adaptor adaptor(operands, attributes, {}, regions); + CompareOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferCompareOp(context, location, adaptor.getLhs(), inferredReturnShapes); } @@ -5792,19 +5739,34 @@ OpFoldResult CompareOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// LogicalResult SelectAndScatterOp::inferReturnTypes( - MLIRContext*, std::optional, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + MLIRContext*, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - SelectAndScatterOp::Adaptor adaptor(operands, attributes, {}, regions); - return hlo::inferSelectAndScatterOp(adaptor.getOperand(), + SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties, + regions); + return hlo::inferSelectAndScatterOp(location, adaptor.getOperand(), + adaptor.getScatter(), inferredReturnTypes); } LogicalResult SelectAndScatterOp::verify() { - return hlo::verifySelectAndScatterOp(getLoc(), getOperand(), getSource(), - getInitValue(), getWindowDimensions(), - getWindowStrides(), getPadding(), - getSelect(), getScatter()); + if (getWindowDimensions() && + failed(verify1dTensor(getLoc(), *getWindowDimensions(), + "window_dimensions"))) + return failure(); + if (getWindowStrides() && + failed(verify1dTensor(getLoc(), *getWindowStrides(), "window_strides"))) + return failure(); + + return hlo::verifySelectAndScatterOp( + getLoc(), getOperand(), getSource(), getInitValue(), + getWindowDimensions() + ? llvm::to_vector(getWindowDimensions()->getValues()) + : ArrayRef{}, + getWindowStrides() + ? llvm::to_vector(getWindowStrides()->getValues()) + : ArrayRef{}, + getPadding(), getSelect(), getScatter()); } //===----------------------------------------------------------------------===// @@ -5813,10 +5775,11 @@ LogicalResult SelectAndScatterOp::verify() { LogicalResult ScatterOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - ScatterOp::Adaptor adaptor(operands, attributes, {}, regions); + ScatterOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferScatterOp(location, adaptor.getInputs(), + adaptor.getUpdateComputation(), inferredReturnTypes); } @@ -6038,9 +6001,9 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult WhileOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - WhileOp::Adaptor adaptor(operands, attributes, {}, regions); + WhileOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferWhileOp(location, adaptor.getOperand(), inferredReturnTypes); } @@ -6048,70 +6011,12 @@ LogicalResult WhileOp::verify() { return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody()); } -/// Print a `while` op. -/// -/// op ::= `mhlo.while` `(` assignment-list `)` `:` types attribute-dict -/// `cond` region -/// `do` region -/// assignment-list ::= assignment | assignment `,` assignment-list -/// assignment ::= ssa-value `=` ssa-value void WhileOp::print(OpAsmPrinter& p) { - p << '('; - llvm::interleaveComma( - llvm::zip(SingleBlock::getBody()->getArguments(), getOperands()), p, - [&](auto zip) { - p.printOperand(std::get<0>(zip)); - p << " = "; - p.printOperand(std::get<1>(zip)); - }); - p << ")"; - if (getNumOperands()) { - p << " : "; - llvm::interleaveComma(getOperandTypes(), p); - } - p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - p.printNewline(); - p << " cond "; - p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false); - p << " do "; - p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false); + hlo::printWhileOp(p, getOperation(), getCond(), getBody()); } ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { - llvm::SMLoc loc = parser.getCurrentLocation(); - // Parse the operands of the while: these are of the form: - // %iter_arg = %init_val - // where %iter_arg is the name of the block argument in the cond/body blocks - // and %init_val is the actual operand. - SmallVector operands; - SmallVector iterArgs; - if (parser.parseLParen()) return failure(); - do { - if (succeeded(parser.parseOptionalRParen())) break; - OpAsmParser::UnresolvedOperand operand, iterArg; - if (parser.parseOperand(iterArg) || parser.parseEqual() || - parser.parseOperand(operand)) - return failure(); - iterArgs.push_back(iterArg); - operands.push_back(operand); - if (succeeded(parser.parseOptionalRParen())) break; - if (failed(parser.parseComma())) return failure(); - } while (true); - if (!operands.empty()) { - if (parser.parseColon() || parser.parseTypeList(result.types)) - return failure(); - } - - SmallVector args; - createArgs(iterArgs, result.types, args); - if (parser.resolveOperands(operands, result.types, loc, result.operands) || - parser.parseOptionalAttrDictWithKeyword(result.attributes) || - parser.parseKeyword("cond") || - parser.parseRegion(*result.addRegion(), args) || - parser.parseKeyword("do") || - parser.parseRegion(*result.addRegion(), args)) - return failure(); - return success(); + return hlo::parseWhileOp(parser, result); } LogicalResult WhileOp::fold(FoldAdaptor /*adaptor*/, @@ -6195,9 +6100,10 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult UniformDequantizeOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - UniformDequantizeOp::Adaptor adaptor(operands, attributes, {}, regions); + UniformDequantizeOp::Adaptor adaptor(operands, attributes, properties, + regions); return hlo::inferUniformDequantizeOp(location, adaptor.getOperand(), inferredReturnShapes); } diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.h b/xla/mlir_hlo/mhlo/IR/hlo_ops.h index 4a5483a91c6d8..543c47c27a992 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.h +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -89,6 +89,7 @@ class MhloDialect : public Dialect { class TokenType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "mhlo.token"; }; void printConvolutionDimensions(AsmPrinter &p, ConvDimensionNumbersAttr dnums); diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index e566a7463a64b..63d89a47ab906 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -194,8 +195,8 @@ class MHLO_UnaryElementwiseOp traits, // Abs supports complex to real, so element type is not guaranteed to match. def MHLO_AbsOp: MHLO_UnaryElementwiseOp<"abs", [Pure, DeclareOpInterfaceMethods], - TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]>, - TensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> { + RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex, MHLO_QuantizedInt]>, + RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_QuantizedInt]>> { let summary = "Abs operation"; let description = [{ Performs element-wise abs operation on `operand` tensor and produces a @@ -305,6 +306,23 @@ def MHLO_CosineOp: MHLO_UnaryElementwiseOp<"cosine", let hasCustomHLOConverter = 1; } +def MHLO_ErfOp: MHLO_UnaryElementwiseOp<"erf", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { + let summary = "Erf operation"; + let description = [{ + Performs element-wise erf operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#erf + + Example: + ```mlir + %result = mhlo.erf %operand : tensor<2x2xf32> + ``` + }]; + let hasFolder = 1; +} def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential", [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpComplexOrQuantizedIntTensor> { let summary = "Exp operation"; @@ -577,7 +595,7 @@ def MHLO_RsqrtOp: MHLO_UnaryElementwiseOp<"rsqrt", def MHLO_SignOp: MHLO_UnaryElementwiseOp<"sign", [Pure, HLO_CompatibleOperandsAndResultType], - TensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex, HLO_QuantizedInt]>> { + RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex, HLO_QuantizedInt]>> { let summary = "Sign operation"; let description = [{ Returns the sign of the `operand` element-wise and produces a `result` @@ -942,7 +960,7 @@ def MHLO_SubtractOp : MHLO_BinaryElementwiseOp<"subtract", } def MHLO_StochasticConvertOp : MHLO_Op<"stochastic_convert", - [Pure, AllShapesMatch<["operand", "random", "result"]>]> { + [Pure, Elementwise, AllShapesMatch<["operand", "random", "result"]>]> { let summary = "StochasticConvert operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -953,7 +971,7 @@ def MHLO_StochasticConvertOp : MHLO_Op<"stochastic_convert", number passed in. }]; - let arguments = (ins MHLO_FpTensor:$operand, TensorOf<[MHLO_UInt]>:$random); + let arguments = (ins MHLO_FpTensor:$operand, RankedTensorOf<[MHLO_UInt]>:$random); let results = (outs MHLO_Tensor:$result); let hasCustomHLOConverter = 1; let hasVerifier = 1; @@ -1162,7 +1180,7 @@ def MHLO_ReplicaIdOp : MHLO_Op<"replica_id", [Pure, %result = mhlo.replica_id : tensor ``` }]; - let results = (outs TensorOf<[UI32]>); + let results = (outs UI32RankedTensor); let assemblyFormat = "attr-dict `:` type(results)"; } @@ -1246,9 +1264,6 @@ def MHLO_AsyncStartOp : MHLO_Op<"async_start", []> { `called_computation` is the function that will be run asynchronously `execution_thread` is the name of the thread in which it will be run. The main thread is called "main". All threads have names. - `group_id` labels a set of async-start, async-done, and zero or more - async-update ops corresponding to the same computation. We - represent a missing group_id with either an negative value or None. This returns all the state needed between async ops. After buffer assignment, the return values represents the space needed to hold the input, @@ -1258,8 +1273,7 @@ def MHLO_AsyncStartOp : MHLO_Op<"async_start", []> { let arguments = (ins Variadic:$inputs, FlatSymbolRefAttr:$called_computation, - StrAttr:$execution_thread, - OptionalAttr:$group_id + StrAttr:$execution_thread ); let results = (outs MHLO_AsyncBundle); @@ -1282,8 +1296,7 @@ def MHLO_AsyncUpdateOp : MHLO_Op<"async_update", [DeclareOpInterfaceMethods:$group_id + StrAttr:$execution_thread ); let results = (outs MHLO_AsyncBundle); @@ -1307,8 +1320,7 @@ def MHLO_AsyncDoneOp : MHLO_Op<"async_done", [DeclareOpInterfaceMethods:$group_id + StrAttr:$execution_thread ); let results = (outs Variadic); @@ -1455,8 +1467,9 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> string summary = "AllGather operation"; string description = [{ Within each process group in the process grid, concatenates the values of the - `operand` tensor from each process along `all_gather_dim` and produces a - `result` tensor. + operand tensor from each process along `all_gather_dim` and produces a + result tensor. The `computation` is applied separately for each operand in + `operands`, producing one result per operand. See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather @@ -1474,13 +1487,13 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> }]; let arguments = (ins - MHLO_Tensor:$operand, + Variadic:$operands, I64Attr:$all_gather_dim, I64ElementsAttr:$replica_groups, OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids ); - let results = (outs MHLO_Tensor); + let results = (outs Variadic); // use_global_device_ids is rarely used, so we add simplified builder methods // for convenience. let builders = [ @@ -1495,7 +1508,6 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> } def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [ - SameOperandsAndResultElementType, SingleBlockImplicitTerminator<"ReturnOp">, InferTensorType ]> { @@ -1549,8 +1561,7 @@ def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [ let hasCustomHLOConverter = 1; } -def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", - [SameOperandsAndResultElementType]> { +def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", []> { let summary = "ReduceScatter operation"; let description = [{ Within each process group in the process grid, performs reduction, using @@ -1696,6 +1707,12 @@ def MHLO_ReduceOp: MHLO_ShapedInterfaceOp<"reduce", [ // compatible with reduce op's operands. let regions = (region SizedRegion<1>:$body); + // Builder + let builders = [ + OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values, + "DenseIntElementsAttr":$dimensions, "TypeRange":$element_types)>, + ]; + // TODO(b/129422361): ReduceOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } @@ -1948,17 +1965,17 @@ def MHLO_BatchNormGradOp : MHLO_Op<"batch_norm_grad", [Pure, }]; let arguments = (ins - RankedTensorOf<[MHLO_Float]>:$operand, + MHLO_FpTensor:$operand, 1DTensorOf<[MHLO_Float]>:$scale, 1DTensorOf<[MHLO_Float]>:$mean, 1DTensorOf<[MHLO_Float]>:$variance, - RankedTensorOf<[MHLO_Float]>:$grad_output, + MHLO_FpTensor:$grad_output, F32Attr:$epsilon, I64Attr:$feature_index ); let results = (outs - RankedTensorOf<[MHLO_Float]>:$grad_operand, + MHLO_FpTensor:$grad_operand, 1DTensorOf<[MHLO_Float]>:$grad_scale, 1DTensorOf<[MHLO_Float]>:$grad_offset); @@ -1985,7 +2002,7 @@ def MHLO_BatchNormInferenceOp : MHLO_Op<"batch_norm_inference", }]; let arguments = (ins - RankedTensorOf<[MHLO_Float]>:$operand, + MHLO_FpTensor:$operand, 1DTensorOf<[MHLO_Float]>:$scale, 1DTensorOf<[MHLO_Float]>:$offset, 1DTensorOf<[MHLO_Float]>:$mean, @@ -1994,7 +2011,7 @@ def MHLO_BatchNormInferenceOp : MHLO_Op<"batch_norm_inference", I64Attr:$feature_index ); - let results = (outs RankedTensorOf<[MHLO_Float]>:$result); + let results = (outs MHLO_FpTensor:$result); } def MHLO_BatchNormTrainingOp : MHLO_Op<"batch_norm_training", @@ -2019,7 +2036,7 @@ def MHLO_BatchNormTrainingOp : MHLO_Op<"batch_norm_training", }]; let arguments = (ins - RankedTensorOf<[MHLO_Float]>:$operand, + MHLO_FpTensor:$operand, 1DTensorOf<[MHLO_Float]>:$scale, 1DTensorOf<[MHLO_Float]>:$offset, F32Attr:$epsilon, @@ -2027,7 +2044,7 @@ def MHLO_BatchNormTrainingOp : MHLO_Op<"batch_norm_training", ); let results = (outs - RankedTensorOf<[MHLO_Float]>:$output, + MHLO_FpTensor:$output, 1DTensorOf<[MHLO_Float]>:$batch_mean, 1DTensorOf<[MHLO_Float]>:$batch_var); @@ -2243,6 +2260,42 @@ def MHLO_ConcatenateOp : MHLO_ShapedInterfaceOp<"concatenate", let hasFolder = 1; } +def MHLO_CollectiveBroadcastOp: MHLO_Op<"collective_broadcast", + [HLO_CompatibleOperandsAndResultType]> { + let summary = "CollectiveBroadcast operation"; + let description = [{ + Within each process group in the process grid, send the value of the + `operand` tensor from the source process to the target processes and produce a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast + + Example: + ```mlir + %result = "mhlo.collective_broadcast"(%operand) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + ``` + }]; + + let arguments = (ins + MHLO_Tensor:$operand, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle + ); + let results = (outs MHLO_Tensor); + let hasCustomHLOConverter = 1; + let hasVerifier = 1; + // channel_handle is only used for the SPMD partitioner, so we add a + // simplified builder method for convenience. + let builders = [ + OpBuilder<(ins + "::mlir::Type":$result_type, "::mlir::Value":$operand, + "::mlir::DenseIntElementsAttr":$replica_groups)>]; +} + def MHLO_CollectivePermuteOp: MHLO_Op<"collective_permute", [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "CollectivePermute operation"; @@ -2279,6 +2332,44 @@ def MHLO_CollectivePermuteOp: MHLO_Op<"collective_permute", "::mlir::DenseIntElementsAttr":$source_target_pairs)>]; } +def MHLO_CompositeOp : MHLO_Op<"composite", [DeclareOpInterfaceMethods]> { + let summary = "Composite operation"; + let description = [{ + Encapsulates an operation made up (composed) of other StableHLO operations, + taking `inputs` and `composite_attributes` and producing `results`. The + semantics of the op are implemented by the `decomposition` attribute. The + `composite` op can be replaced with its decomposition without changing program + semantics. In cases where inlining the decomposition does not provide the same + op semantics, prefer using `custom_call`. + + The `version` field (defaults to `0`) is used to denote when a composite's + semantics change. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite + + Example: + ```mlir + %results = mhlo.composite "my.op" %arg0, %arg1 { + decomposition = @my_op, + composite_attributes = { my_attribute = "my_value" }, + version = 1 : i32 + } : (tensor, tensor) -> tensor + ``` + }]; + + let arguments = (ins + Variadic:$inputs, + StrAttr:$name, + DefaultValuedOptionalAttr:$composite_attributes, + FlatSymbolRefAttr:$decomposition, + DefaultValuedOptionalAttr:$version + ); + let results = (outs Variadic); + let hasCustomHLOConverter = 1; + let assemblyFormat = "$name $inputs attr-dict `:` functional-type(operands, results)"; +} + def MHLO_ConvolutionOp : MHLO_Op<"convolution", [Pure]> { let summary = "Convolution operation"; let description = [{ @@ -2332,7 +2423,8 @@ def MHLO_ConvolutionOp : MHLO_Op<"convolution", [Pure]> { }]; } -def MHLO_CopyOp: MHLO_Op<"copy", [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_CopyOp: MHLO_Op<"copy", + [Pure, Elementwise, HLO_CompatibleOperandsAndResultType]> { let summary = "Copy operation"; let description = [{ This operation is private to the XLA compiler, so it is does not yet have @@ -2435,7 +2527,7 @@ def MHLO_CustomCallOp: MHLO_Op<"custom_call", ``` }]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, StrAttr:$call_target_name, DefaultValuedOptionalAttr:$has_side_effect, OptionalAttr>:$backend_config, @@ -2455,7 +2547,7 @@ def MHLO_CustomCallOp: MHLO_Op<"custom_call", "Aliasing attribute for outputs and operands of CustomCall">, "{}">:$output_operand_aliases ); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; let hasVerifier = 1; @@ -2540,6 +2632,29 @@ def MHLO_DotGeneralOp: MHLO_ShapedInterfaceOp<"dot_general", [Pure]> { let hasVerifier = 1; } +def MHLO_SparseDotOp: MHLO_Op<"sparse_dot", [Pure]> { + let summary = "Sparse dot operation"; + let description = [{ + Similar to `dot_general` operation, with one or both of the operands being + sparse. An additional argument provides sparsity meta information. + Disclaimer: this op is experimental / a work in progress. + }]; + let arguments = (ins + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, + Variadic:$meta, + OptionalAttr:$lhs_sparsity, + OptionalAttr:$rhs_sparsity, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config + ); + let results = (outs MHLO_Tensor); + // SparseDot op required custom exporter to pass the preferred element type + // to Xla builder. + let hasCustomHLOConverter = 1; + let hasVerifier = 1; +} + def MHLO_EinsumOp: MHLO_Op<"einsum", [Pure]> { let summary = "Einsum operation"; let description = [{ @@ -2732,7 +2847,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape", ``` }]; - let arguments = (ins MHLO_Tensor:$operand); + let arguments = (ins MHLO_AnyTensor:$operand); let results = (outs MHLO_StaticShapeTensor); let hasFolder = 1; @@ -2760,8 +2875,8 @@ def MHLO_DynamicReshapeOp: MHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { ``` }]; - let arguments = (ins MHLO_Tensor:$operand, MHLO_DimensionTensor:$output_shape); - let results = (outs MHLO_Tensor:$result); + let arguments = (ins MHLO_AnyTensor:$operand, MHLO_DimensionTensor:$output_shape); + let results = (outs MHLO_AnyTensor:$result); let hasCanonicalizer = 1; // Cannot be exported to legacy formats. @@ -2802,7 +2917,7 @@ def MHLO_ScatterOp: MHLO_Op<"scatter", }]; let arguments = (ins Variadic:$inputs, - TensorOf<[AnyInteger, Index]>:$scatter_indices, + RankedTensorOf<[AnyInteger, Index]>:$scatter_indices, Variadic:$updates, MHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedOptionalAttr:$indices_are_sorted, @@ -3045,8 +3160,8 @@ def MHLO_PartitionIdOp : MHLO_Op<"partition_id", [ %result = mhlo.partition_id : tensor ``` }]; - let results = (outs TensorOf<[UI32]>); - let results = (outs TensorOf<[UI32]>); + let results = (outs UI32RankedTensor); + let results = (outs UI32RankedTensor); let hasCustomHLOConverter = 1; let assemblyFormat = "attr-dict `:` type(results)"; @@ -3454,7 +3569,7 @@ def MHLO_XlaRngGetAndUpdateStateOp: MHLO_Op<"xla.rng_get_and_update_state", [Dec // TODO(b/230662142): Implement unknown scales/zero_point cases. def MHLO_UniformQuantizeOp : MHLO_UnaryElementwiseOp<"uniform_quantize", - [Pure], TensorOf<[HLO_Float, MHLO_QuantizedInt]>, + [Pure], RankedTensorOf<[MHLO_Float, MHLO_QuantizedInt]>, MHLO_QuantizedIntTensor> { let summary = "UniformQuantize operation"; let description = [{ @@ -3558,8 +3673,8 @@ def MHLO_BitcastOp : MHLO_Op<"bitcast", [Pure]> { let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def MHLO_ReducePrecisionOp : - MHLO_Op<"reduce_precision", [HLO_CompatibleOperandsAndResultType, Pure]> { +def MHLO_ReducePrecisionOp : MHLO_Op<"reduce_precision", + [HLO_CompatibleOperandsAndResultType, Pure, Elementwise]> { let summary = "ReducePrecision operation"; let description = [{ Performs element-wise conversion of `operand` to another floating-point type diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index 08c4619591525..2e52b72c04ea1 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -157,8 +157,9 @@ def MHLO_ArgResultAlias : AttrDef { } // Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +// optionally for collective instructions (AllToAll, AllReduce, +// CollectiveBroadcast, and CollectivePermute). Non-positive channel_id +// handle is equivalent to no channel id. def MHLO_ChannelHandle : AttrDef { let mnemonic = "channel_handle"; let parameters = (ins "int64_t":$handle, "int64_t":$type); @@ -255,6 +256,26 @@ def MHLO_FlatSymbolRefArrayAttr : let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; } +// Sparsity descriptor attribute +def MHLO_SparsityDescriptor : AttrDef { + let mnemonic = "sparsity"; + let summary = "Describes structured (N:M) sparsity configuration"; + let description = [{ + This attribute is defined for a sparse dot operation with a structured + sparse input tensor. With (N=2,M=4), every 4 consecutive logical elements + have exactly 2 non-zero physical elements in the input tensor. + + $dimension defines the index of the contracting dimension that is sparse + (it has to be the most minor dimension). The additional metadata operand + in the sparse dot operation defines which logical elements are zeroed out. + }]; + let parameters = (ins + "int64_t":$dimension, + "int64_t":$n, + "int64_t":$m); + let assemblyFormat = "`<` struct(params) `>`"; +} + //===----------------------------------------------------------------------===// // Common convolution attributes //===----------------------------------------------------------------------===// diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc index 3a70f02633bf1..ff69d84d08879 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h index db6898ae9850f..854e986764dca 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td index b04fdb14a6557..d7ce2fdd56487 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ def MHLO_Dialect : Dialect { let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; - let usePropertiesForAttributes = 0; } include "mhlo/IR/hlo_base.td" diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 729dbf48cb2c3..25375ac741da1 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td b/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td index 532915ab98022..e9baa78d2956e 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_patterns.td b/xla/mlir_hlo/mhlo/IR/hlo_patterns.td index 2c1534e6568b0..cba8995e22d5e 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_patterns.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/hlo_utils.td b/xla/mlir_hlo/mhlo/IR/hlo_utils.td index b5eb6de56fef3..079e363b826f5 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_utils.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_utils.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,8 @@ def MHLO_ConstantLikeNegInfValue : NativeCodeCall< def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; +def NullDenseI64ArrayAttr : NativeCodeCall<"DenseI64ArrayAttr()">; + def BinBroadcastDimensions : NativeCodeCall< "hlo::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; diff --git a/xla/mlir_hlo/mhlo/IR/init.cc b/xla/mlir_hlo/mhlo/IR/init.cc index 9cd5e2a926578..11f2228e5df05 100644 --- a/xla/mlir_hlo/mhlo/IR/init.cc +++ b/xla/mlir_hlo/mhlo/IR/init.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc index 9efce77c1a328..260f736abfd4b 100644 --- a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc +++ b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h index a2c78c82c4d6b..53cc3854ea7ad 100644 --- a/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h +++ b/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td b/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td index 28ea9a469ad84..e737d8cbf169f 100644 --- a/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td +++ b/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/IR/register.h b/xla/mlir_hlo/mhlo/IR/register.h index 47a89a707115a..f104221d0b889 100644 --- a/xla/mlir_hlo/mhlo/IR/register.h +++ b/xla/mlir_hlo/mhlo/IR/register.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt b/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt index 4fb958c6d687b..c5a6a0a7f1cde 100644 --- a/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt @@ -9,19 +9,3 @@ add_mlir_library(MhloAnalysis MLIRIR LmhloDialect ) - -add_mlir_library(MhloTestAnalysis - test_shape_component_analysis.cc - - LINK_COMPONENTS - Core - - DEPENDS - LMHLOTransformsPassIncGen - - LINK_LIBS PUBLIC - MLIRHLOAnalysis - MLIRAnalysis - MLIRPass - MLIRTransforms -) diff --git a/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc b/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc index 044f011c81e7a..4d21d8aecf6b0 100644 --- a/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc +++ b/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h b/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h index 3eb9fe6dd69e4..27d3a643de417 100644 --- a/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h +++ b/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc b/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc deleted file mode 100644 index 0aedbeeef388e..0000000000000 --- a/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/analysis/shape_component_analysis.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { - -#define GEN_PASS_DEF_TESTSHAPECOMPONENTANALYSIS -#include "transforms/passes.h.inc" - -using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr; - -namespace { - -struct TestShapeComponentAnalysisPass - : public impl::TestShapeComponentAnalysisBase< - TestShapeComponentAnalysisPass> { - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - void runOnOperation() override { - ShapeComponentAnalysis shapeComponent; - llvm::outs() << "Testing : " << getOperation().getName() << '\n'; - // Analyze anything that looks like a shape tensor. - getOperation().walk([&](Operation* op) { - // Skip ops with more than one result. - if (op->getNumResults() != 1) return; - Value result = op->getResults().front(); - - // Dump shape info if any. - if (auto shapeInfo = shapeComponent.GetShapeInfo(result)) { - llvm::outs() << "Shape info for " << result << ":\n"; - for (const SymbolicExpr& d : *shapeInfo) { - llvm::outs().indent(2); - d.dump(llvm::outs()); - } - } - - // Dump value info if any. - if (auto valueInfo = shapeComponent.GetValueInfo(result)) { - llvm::outs() << "Value info for " << result << ":\n"; - for (const SymbolicExpr& d : *valueInfo) { - llvm::outs().indent(2); - d.dump(llvm::outs()); - } - } - }); - } -}; - -} // end anonymous namespace - -std::unique_ptr> -createTestShapeComponentAnalysisPass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h b/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h index a0ccb1d7a2118..ba886da0d726e 100644 --- a/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h +++ b/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 1bd60985e1a0d..b7d447620a22b 100644 --- a/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,8 +52,6 @@ add_mlir_library(MhloPasses legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc legalize_shape_computations/legalize_shape_computations.cc - legalize_sparse_ops/legalize_sparse_ops.cc - legalize_sparse_ops/sparse_ops_to_custom_calls.cc legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc lower_complex/lower_complex.cc @@ -70,13 +68,11 @@ add_mlir_library(MhloPasses prepare_for_export/prepare_for_export.cc optimize_mhlo/optimize_mhlo.cc optimize_mhlo/optimize_mhlo_pass.cc - rank_specialization/rank_specialization.cc restrict_max_rank/restrict_max_rank.cc shape_legalize_to_hlo/shape_legalize_to_hlo.cc shape_reification/shape_reification_pass.cc shape_simplification/shape_simplification.cc sink_constants_to_control_flow/sink_constants_to_control_flow.cc - sparse_rewriting/sparse_rewriting.cc symbolic_shape_optimization/symbolic_shape_optimization.cc test_infer_shaped_type/test_infer_shaped_type_pass.cc unfuse_batch_norm/unfuse_batch_norm.cc @@ -107,6 +103,26 @@ add_mlir_library(MhloPasses StablehloBroadcastUtils ) +add_mlir_library(MhloQuantToIntConversion + mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + LmhloDialect + MhloDialect + MhloTypeConversion + MLIRIR + MLIRPass + MLIRMathDialect + MLIRTransforms + MLIRTransformUtils +) add_mlir_library(MhloToMemrefConversion hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -176,7 +192,6 @@ add_mlir_library(MhloToStandard ) add_mlir_library(ChloPasses - chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc DEPENDS @@ -291,6 +306,7 @@ add_library(AllMhloPasses INTERFACE) target_link_libraries(AllMhloPasses INTERFACE ChloPasses MhloPasses + MhloQuantToIntConversion MhloToArithmeticConversion MhloToMemrefConversion MhloToStandard diff --git a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index 593a6852d330c..56ea0a42e79c1 100644 --- a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc deleted file mode 100644 index 96c0b282d4f7c..0000000000000 --- a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc +++ /dev/null @@ -1,1966 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Enable the use of M_* math constants. -// NOTE: this must be first in the file to ensure that if cmath is transitively -// included by any other header it has the define set on first processing. -// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants -#define _USE_MATH_DEFINES -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/map_chlo_to_hlo_op.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/BroadcastUtils.h" -#include "stablehlo/dialect/ChloOps.h" -#include "utils/hlo_utils.h" - -namespace mlir { -namespace chlo { -namespace { - -struct ConvertConstantLikeOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ConstantLikeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getType().cast(); - - // Unranked uses are not supported. - if (!resultTy.hasRank()) return failure(); - - // Lower to MHLO constant if statically shaped. - if (resultTy.hasStaticShape()) { - auto complexAttr = op.getValue().dyn_cast(); - auto attr = complexAttr - ? DenseElementsAttr::get(resultTy, complexAttr.getValue()) - : DenseElementsAttr::get(resultTy, op.getValue()); - rewriter.replaceOpWithNewOp(op, attr); - return success(); - } - - // Lower to broadcasted constant. - auto loc = op.getLoc(); - Value constant = rewriter.create(loc, op.getValue()); - Value shape = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp( - op, resultTy, constant, shape, rewriter.getI64TensorAttr({})); - return success(); - } -}; - -template -Value materializeChebyshevPolynomialApproximation( - ConversionPatternRewriter &rewriter, Location loc, Value x, - ArrayRef coefficients) { - Value b0 = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value b1 = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value b2 = chlo::getConstantLike(rewriter, loc, 0.0, x); - for (FTy c : coefficients) { - b2 = b1; - b1 = b0; - b0 = rewriter.create(loc, x.getType(), x, b1); - b0 = rewriter.create(loc, x.getType(), b0, b2); - b0 = rewriter.create( - loc, x.getType(), b0, chlo::getConstantLike(rewriter, loc, c, x)); - } - Value result = rewriter.create(loc, x.getType(), b0, b2); - result = rewriter.create( - loc, x.getType(), result, chlo::getConstantLike(rewriter, loc, 0.5, x)); - return result; -} - -template -Value materializeBesselI1eApproximation(ConversionPatternRewriter &rewriter, - Location loc, Value x, - ArrayRef kI1eCoeffsA, - ArrayRef kI1eCoeffsB) { - Value z = rewriter.create(loc, x); - Value half = chlo::getConstantLike(rewriter, loc, 0.5, x); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value thirtyTwo = chlo::getConstantLike(rewriter, loc, 32.0, x); - Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x); - - Value tmp = rewriter.create(loc, half, z); - tmp = rewriter.create(loc, tmp, two); - - Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, - kI1eCoeffsA); - xLe8 = rewriter.create(loc, z, xLe8); - - tmp = rewriter.create(loc, thirtyTwo, z); - tmp = rewriter.create(loc, tmp, two); - Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, - kI1eCoeffsB); - xGt8 = rewriter.create(loc, xGt8, - rewriter.create(loc, z)); - - Value isLe8 = rewriter.create(loc, z, eight, - mhlo::ComparisonDirection::LE); - - Value select = rewriter.create(loc, isLe8, xLe8, xGt8); - return rewriter.create( - loc, rewriter.create(loc, x), select); -} - -Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kI1eCoeffsA[] = { - 9.38153738649577178388E-9f, -4.44505912879632808065E-8f, - 2.00329475355213526229E-7f, -8.56872026469545474066E-7f, - 3.47025130813767847674E-6f, -1.32731636560394358279E-5f, - 4.78156510755005422638E-5f, -1.61760815825896745588E-4f, - 5.12285956168575772895E-4f, -1.51357245063125314899E-3f, - 4.15642294431288815669E-3f, -1.05640848946261981558E-2f, - 2.47264490306265168283E-2f, -5.29459812080949914269E-2f, - 1.02643658689847095384E-1f, -1.76416518357834055153E-1f, - 2.52587186443633654823E-1f}; - - const float kI1eCoeffsB[] = { - -3.83538038596423702205E-9f, -2.63146884688951950684E-8f, - -2.51223623787020892529E-7f, -3.88256480887769039346E-6f, - -1.10588938762623716291E-4f, -9.76109749136146840777E-3f, - 7.78576235018280120474E-1f}; - - return materializeBesselI1eApproximation(rewriter, loc, x, kI1eCoeffsA, - kI1eCoeffsB); -} - -Value materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - const double kI1eCoeffsA[] = { - 2.77791411276104639959E-18, -2.11142121435816608115E-17, - 1.55363195773620046921E-16, -1.10559694773538630805E-15, - 7.60068429473540693410E-15, -5.04218550472791168711E-14, - 3.22379336594557470981E-13, -1.98397439776494371520E-12, - 1.17361862988909016308E-11, -6.66348972350202774223E-11, - 3.62559028155211703701E-10, -1.88724975172282928790E-9, - 9.38153738649577178388E-9, -4.44505912879632808065E-8, - 2.00329475355213526229E-7, -8.56872026469545474066E-7, - 3.47025130813767847674E-6, -1.32731636560394358279E-5, - 4.78156510755005422638E-5, -1.61760815825896745588E-4, - 5.12285956168575772895E-4, -1.51357245063125314899E-3, - 4.15642294431288815669E-3, -1.05640848946261981558E-2, - 2.47264490306265168283E-2, -5.29459812080949914269E-2, - 1.02643658689847095384E-1, -1.76416518357834055153E-1, - 2.52587186443633654823E-1}; - - const double kI1eCoeffsB[] = { - 7.51729631084210481353E-18, 4.41434832307170791151E-18, - -4.65030536848935832153E-17, -3.20952592199342395980E-17, - 2.96262899764595013876E-16, 3.30820231092092828324E-16, - -1.88035477551078244854E-15, -3.81440307243700780478E-15, - 1.04202769841288027642E-14, 4.27244001671195135429E-14, - -2.10154184277266431302E-14, -4.08355111109219731823E-13, - -7.19855177624590851209E-13, 2.03562854414708950722E-12, - 1.41258074366137813316E-11, 3.25260358301548823856E-11, - -1.89749581235054123450E-11, -5.58974346219658380687E-10, - -3.83538038596423702205E-9, -2.63146884688951950684E-8, - -2.51223623787020892529E-7, -3.88256480887769039346E-6, - -1.10588938762623716291E-4, -9.76109749136146840777E-3, - 7.78576235018280120474E-1}; - - return materializeBesselI1eApproximation(rewriter, loc, x, - kI1eCoeffsA, kI1eCoeffsB); -} - -Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args, FloatType minPrecisionTy, - Value callback(ConversionPatternRewriter &, - Location, ValueRange)) { - auto originalTy = getElementTypeOrSelf(args.front().getType()); - auto floatOriginalTy = originalTy.dyn_cast(); - bool needsUpcast = - floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth(); - - // Upcast arguments if necessary. - llvm::SmallVector castedArgs; - if (needsUpcast) { - for (Value a : args) { - castedArgs.push_back( - rewriter.create(loc, a, minPrecisionTy)); - } - args = castedArgs; - } - - Value result = callback(rewriter, loc, args); - - // Cast back if necessary. - if (needsUpcast) { - result = rewriter.create(loc, result, originalTy); - } - - return result; -} - -struct ConvertBesselI1eOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - BesselI1eOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. - // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e - if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp( - op, materializeBesselI1eApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeBesselI1eApproximationF32)); - return success(); - } -}; - -template -Value materializePolynomialApproximation(ConversionPatternRewriter &rewriter, - Location loc, Value x, - ArrayRef coefficients) { - if (coefficients.empty()) return chlo::getConstantLike(rewriter, loc, 0.0, x); - - Value poly = chlo::getConstantLike(rewriter, loc, coefficients[0], x); - for (size_t i = 1; i < coefficients.size(); ++i) { - poly = rewriter.create(loc, x.getType(), poly, x); - poly = rewriter.create( - loc, x.getType(), poly, - chlo::getConstantLike(rewriter, loc, coefficients[i], x)); - } - return poly; -} - -// Precondition is |x| >= 1. Use erf approximation, otherwise. -// -// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an -// argument and derive the final approximation for all |x| >= 1. -// This implementation is based on Cephes. -Value materializeErfcApproximationF64ForMagnituteGeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - const double kMaxlog = 7.09782712893383996843E2; - const double kErfcPCoefficients[] = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; - const double kErfcQCoefficients[] = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; - const double kErfcRCoefficients[] = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; - const double kErfcSCoefficients[] = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; - - // Let z = -x^2. - Value xSq = rewriter.create(loc, x, x); - Value z = rewriter.create(loc, xSq); - - // Materialize polynomial approximation for x in [1, 8) as - // erfc(x) = exp(z) P(|x|) / Q(|x|). - Value expZ = rewriter.create(loc, z); - Value absX = rewriter.create(loc, x); - Value polP = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients)); - Value expZMulPolyP = rewriter.create(loc, expZ, polP); - Value polQ = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients)); - Value erfcApprox18 = rewriter.create(loc, expZMulPolyP, polQ); - - // Materialize polynomial approximation for x in >= 8 as - // erfc(x) exp(z) R(|x|) / S(|x|). - Value polR = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients)); - Value expZMulPolyR = rewriter.create(loc, expZ, polR); - Value polS = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients)); - Value erfcApprox8Inf = rewriter.create(loc, expZMulPolyR, polS); - - // Combine polynomial approximations for x >= 1. - Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x); - Value absXLt8 = rewriter.create( - loc, absX, eight, mhlo::ComparisonDirection::LT); - Value erfcApprox = rewriter.create(loc, absXLt8, erfcApprox18, - erfcApprox8Inf); - - // Clamp to prevent overflow and materialize approximation for large x as - // erfc(x) = 0. - Value zLtNegMaxlog = rewriter.create( - loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), - mhlo::ComparisonDirection::LT); - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value erfcApproxClamped = - rewriter.create(loc, zLtNegMaxlog, zero, erfcApprox); - - // Derive approximation for x <= -1 as - // erfc(x) = 2 - erfc(-x). - // Reuse previously materialized approximations all of which take |x| as their - // argument. - Value xLtZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LT); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value twoSubErfcApproxClamped = - rewriter.create(loc, two, erfcApproxClamped); - return rewriter.create(loc, xLtZero, twoSubErfcApproxClamped, - erfcApproxClamped); -} - -// Precondition is |x| <= 1. Use erfc approximation, otherwise. -// This implementation is based on Cephes. -Value materializeErfApproximationF64ForMagnituteLeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - const double kErfTCoefficients[] = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; - const double kErfUCoefficients[] = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; - - // Materialize polynomial approximation for |x| <= 1 as - // erf(x) = x T(x^2) / U(x^2). - Value xSq = rewriter.create(loc, x, x); - Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); - Value xMulPolyT = rewriter.create(loc, x, polyT); - Value polyU = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients)); - return rewriter.create(loc, xMulPolyT, polyU); -} - -// This implementation is based on Cephes. -Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - // Rely on erf approximation for |x| < 1 - // erf(x) = erf_approx(x) - Value erfApprox = - materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); - - // Rely on erfc approximation for |x| >= 1 and materialize erf as - // erf(x) = 1 - erfc_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfcApprox = - materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); - Value erfcBasedApprox = - rewriter.create(loc, one, erfcApprox); - - // Materialize approximation selection based on argument. - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfApprox, - erfcBasedApprox); -} - -Value materializeErfcApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - // Rely on erfc approximation for |x| >= 1 - // erfc(x) = erfc_approx(x) - Value erfcApprox = - materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); - - // Rely on erf approximation for |x| < 1 and materialize erfc as - // erfc(x) = 1 - erf_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfApprox = - materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); - Value erfBasedApprox = rewriter.create(loc, one, erfApprox); - - // Materialize approximation selection based on argument. - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfBasedApprox, - erfcApprox); -} - -// Precondition is |x| >= 1. Use erf approximation, otherwise. -// -// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an -// argument and derive the final approximation for all |x| >= 1. -// This implementation is based on Cephes. -Value materializeErfcApproximationF32ForMagnitudeGeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const double kMaxlog = 88.72283905206835; - const float kErfcPCoefficients[] = { - +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, - -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, - +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, - }; - const float kErfcRCoefficients[] = { - -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, - +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, - -2.820767439740514E-1, +5.641895067754075E-1, - }; - - // Let z = -x^2. - Value xSq = rewriter.create(loc, x, x); - Value z = rewriter.create(loc, xSq); - - // Materialize polynomial approximation for x >= 1 as - // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2) - // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2 - Value absX = rewriter.create(loc, x); - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value reciprocalXSq = rewriter.create(loc, one, xSq); - Value expZ = rewriter.create(loc, z); - Value oneDivAbsX = rewriter.create(loc, one, absX); - Value expZMulOneDivAbsX = rewriter.create(loc, expZ, oneDivAbsX); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value absXLtTwo = rewriter.create( - loc, absX, two, mhlo::ComparisonDirection::LT); - Value polP = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients)); - Value polR = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients)); - Value poly = rewriter.create(loc, absXLtTwo, polP, polR); - Value erfcApprox = rewriter.create(loc, expZMulOneDivAbsX, poly); - - // Clamp to prevent overflow and materialize approximation for large x as - // erfc(x) = 0. - Value zLtNeqMaxlog = rewriter.create( - loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), - mhlo::ComparisonDirection::LT); - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value erfcApproxClamped = - rewriter.create(loc, zLtNeqMaxlog, zero, erfcApprox); - - // Derive approximation for x <= -1 as - // erfc(x) = 2 - erfc(-x). - // Reuse previously materialized approximations all of which take |x| as their - // argument. - Value xLtZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LT); - Value twoSubErfcApprox = - rewriter.create(loc, two, erfcApproxClamped); - return rewriter.create(loc, xLtZero, twoSubErfcApprox, - erfcApproxClamped); -} - -// Precondition is |x| <= 1. Use erfc approximation, otherwise. -// This implementation is based on Cephes. -Value materializeErfApproximationF32ForMagnitudeLeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kErfTCoefficients[] = { - +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, - -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, - +1.128379165726710E+0, - }; - - // Materialize polynomial approximation for |x| <= 1 as - // erf(x) = x T(x^2). - Value xSq = rewriter.create(loc, x, x); - Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); - return rewriter.create(loc, x, polyT); -} - -// This is the same approximation as used in Eigen. -Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kAlpha[] = { - -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, - -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, - -1.60960333262415e-02f, - }; - const float kBeta[] = { - -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, - -7.37332916720468e-03f, -1.42647390514189e-02f, - }; - - // Clamp argument between -4 and 4. - Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x); - Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x); - x = rewriter.create(loc, x.getType(), lb, x, ub); - Value xSq = rewriter.create(loc, x, x); - - // Materialize polynomial approximation for x in [-4, 4] as - // erf(x) = x * Alpha(x^2) / Beta(x^2). - Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq, - llvm::ArrayRef(kAlpha)); - Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq, - llvm::ArrayRef(kBeta)); - Value xMulAlphaPoly = rewriter.create(loc, x, alphaPoly); - Value erf = rewriter.create(loc, xMulAlphaPoly, betaPoly); - Value lbErf = chlo::getConstantLike(rewriter, loc, -1.0, x); - Value ubErf = chlo::getConstantLike(rewriter, loc, 1.0, x); - return rewriter.create(loc, erf.getType(), lbErf, erf, ubErf); -} - -Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - - // Rely on erfc approximation for |x| >= 1 - // erfc(x) = erfc_approx(x) - Value erfcApprox = - materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x); - - // Rely on erf approximation for |x| < 1 and materialize erfc as - // erfc(x) = 1 - erf_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfApprox = - materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x); - Value erfBasedApprox = rewriter.create(loc, one, erfApprox); - - // Materialize approximation selection based on argument. - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfBasedApprox, - erfcApprox); -} - -struct ConvertErfOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. - if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeErfApproximationF32)); - return success(); - } -}; - -struct ConvertErfcOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfcOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. - if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeErfcApproximationF32)); - return success(); - } -}; - -Value erfInv32(ConversionPatternRewriter &b, Location loc, ValueRange args) { - constexpr int kDegree = 9; - constexpr std::array wLessThan5Constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array wGreaterThan5Constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - Value x = args[0]; - // Compute logarithm of (1+arg) using log1p(arg) which is more precise than - // log(1+arg) when arg is close to zero. For more details, see - // https://en.cppreference.com/w/cpp/numeric/math/log1p - Value minusXSquared = - b.create(loc, x, b.create(loc, x)); - Value w = - b.create(loc, b.create(loc, minusXSquared)); - - Value lt = b.create(loc, w, getConstantLike(b, loc, 5.0, x), - mhlo::ComparisonDirection::LT); - auto coefficient = [&](int i) { - return b.create( - loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x), - getConstantLike(b, loc, wGreaterThan5Constants[i], x)); - }; - w = b.create( - loc, lt, - b.create(loc, w, getConstantLike(b, loc, 2.5, x)), - b.create(loc, b.create(loc, w), - getConstantLike(b, loc, 3.0, x))); - Value p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b.create(loc, coefficient(i), - b.create(loc, p, w)); - } - - // Result modulo edge cases. - Value result = b.create(loc, p, x); - - // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is - // indeterminate, and can give nan or -/+inf.) - return b.create( - loc, - b.create(loc, b.create(loc, x), - getConstantLike(b, loc, 1, x), - mhlo::ComparisonDirection::EQ), - b.create(loc, x, getConstantLikeInfValue(b, loc, x, false)), - result); -} - -Value erfInv64(ConversionPatternRewriter &b, Location loc, ValueRange args) { - constexpr std::array wLessThan625Constants = { - -3.6444120640178196996e-21, -1.685059138182016589e-19, - 1.2858480715256400167e-18, 1.115787767802518096e-17, - -1.333171662854620906e-16, 2.0972767875968561637e-17, - 6.6376381343583238325e-15, -4.0545662729752068639e-14, - -8.1519341976054721522e-14, 2.6335093153082322977e-12, - -1.2975133253453532498e-11, -5.4154120542946279317e-11, - 1.051212273321532285e-09, -4.1126339803469836976e-09, - -2.9070369957882005086e-08, 4.2347877827932403518e-07, - -1.3654692000834678645e-06, -1.3882523362786468719e-05, - 0.0001867342080340571352, -0.00074070253416626697512, - -0.0060336708714301490533, 0.24015818242558961693, - 1.6536545626831027356}; - constexpr std::array wLessThan16Constants = { - 2.2137376921775787049e-09, 9.0756561938885390979e-08, - -2.7517406297064545428e-07, 1.8239629214389227755e-08, - 1.5027403968909827627e-06, -4.013867526981545969e-06, - 2.9234449089955446044e-06, 1.2475304481671778723e-05, - -4.7318229009055733981e-05, 6.8284851459573175448e-05, - 2.4031110387097893999e-05, -0.0003550375203628474796, - 0.00095328937973738049703, -0.0016882755560235047313, - 0.0024914420961078508066, -0.0037512085075692412107, - 0.005370914553590063617, 1.0052589676941592334, - 3.0838856104922207635, - }; - constexpr std::array wGreaterThan16Constants = { - -2.7109920616438573243e-11, -2.5556418169965252055e-10, - 1.5076572693500548083e-09, -3.7894654401267369937e-09, - 7.6157012080783393804e-09, -1.4960026627149240478e-08, - 2.9147953450901080826e-08, -6.7711997758452339498e-08, - 2.2900482228026654717e-07, -9.9298272942317002539e-07, - 4.5260625972231537039e-06, -1.9681778105531670567e-05, - 7.5995277030017761139e-05, -0.00021503011930044477347, - -0.00013871931833623122026, 1.0103004648645343977, - 4.8499064014085844221, - }; - - Value x = args[0]; - // Compute logarithm of (1+arg) using log1p(arg) which is more precise than - // log(1+arg) when arg is close to zero. For more details, see - // https://en.cppreference.com/w/cpp/numeric/math/log1p - Value minusXSquared = - b.create(loc, x, b.create(loc, x)); - Value w = - b.create(loc, b.create(loc, minusXSquared)); - - Value lt625 = b.create( - loc, w, getConstantLike(b, loc, 6.25, x), mhlo::ComparisonDirection::LT); - Value lt16 = b.create(loc, w, getConstantLike(b, loc, 16, x), - mhlo::ComparisonDirection::LT); - - auto coefficient = [&](int i) { - Value c = getConstantLike(b, loc, wLessThan625Constants[i], x); - if (i < 19) { - c = b.create( - loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x)); - } - if (i < 17) { - c = b.create( - loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x)); - } - return c; - }; - - Value sqrtW = b.create(loc, w); - Value wMinus3125 = - b.create(loc, w, getConstantLike(b, loc, 3.125, x)); - Value select2 = - b.create(loc, lt16, getConstantLike(b, loc, 3.25, w), - getConstantLike(b, loc, 5.0, w)); - Value select2Result = b.create(loc, sqrtW, select2); - w = b.create(loc, lt625, wMinus3125, select2Result); - - Value p = coefficient(0); - for (int i = 1; i < 17; ++i) { - p = b.create(loc, coefficient(i), - b.create(loc, p, w)); - } - for (int i = 17; i < 19; ++i) { - p = b.create( - loc, lt16, - b.create(loc, coefficient(i), - b.create(loc, p, w)), - p); - } - for (int i = 19; i < 23; ++i) { - p = b.create( - loc, lt625, - b.create(loc, coefficient(i), - b.create(loc, p, w)), - p); - } - - // Result modulo edge cases. - Value result = b.create(loc, p, x); - - // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is - // indeterminate, and can give nan or -/+inf.) - return b.create( - loc, - b.create(loc, b.create(loc, x), - getConstantLike(b, loc, 1, x), - mhlo::ComparisonDirection::EQ), - b.create(loc, x, getConstantLikeInfValue(b, loc, x, false)), - result); -} - -struct ConvertErfInvOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfInvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - if (op.getResult().getType().getElementType().isF64()) { - rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands())); - return success(); - } - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &erfInv32)); - return success(); - } -}; - -// Coefficients for the Lanczos approximation of the gamma function. The -// coefficients are uniquely determined by the choice of g and n (kLanczosGamma -// and kLanczosCoefficients.size() + 1). The coefficients below correspond to -// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and -// [7, 9] seemed to be the least sensitive to the quality of the log function. -// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 -// for a particularly inaccurate log function. -constexpr double kLanczosGamma = 7; // aka g -constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; -constexpr std::array kLanczosCoefficients = { - 676.520368121885098567009190444019, -1259.13921672240287047156078755283, - 771.3234287776530788486528258894, -176.61502916214059906584551354, - 12.507343278686904814458936853, -0.13857109526572011689554707, - 9.984369578019570859563e-6, 1.50563273514931155834e-7}; - -// Compute the Lgamma function using Lanczos' approximation from "A Precision -// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis -// series B. Vol. 1: -// lgamma(z + 1) = (log(2) + log(pi)) / 2 -// + (z + 1/2) * log(t(z)) -// - t(z) + log(a(z)) -// with t(z) = z + kLanczosGamma + 1/2 -// a(z) = kBaseLanczosCoeff -// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) -Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - // If the input is less than 0.5 use Euler's reflection formula. - // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) - // Let z be - // z = -x if x < 1/2 - // z = x - 1 otheriwse - Value x = args.front(); - Value half = getConstantLike(rewriter, loc, 0.5, x); - Value needToReflect = rewriter.create( - loc, x, half, mhlo::ComparisonDirection::LT); - Value negX = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1, x); - Value xSubOne = rewriter.create(loc, x, one); - Value z = rewriter.create(loc, needToReflect, negX, xSubOne); - - // Materialize - // a(z) = kBaseLanczosCoeff - // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) - Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); - for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { - Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); - Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); - Value quotient = rewriter.create( - loc, coeff, rewriter.create(loc, z, oneBasedIndex)); - a = rewriter.create(loc, a, quotient); - } - - // To improve accuracy on platforms with less-precise log implementations, - // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the - // device. - // Materialize as - // log(t) = log(kLanczosGamma + 1/2 + z) - // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). - Value lanczosPlusHalf = - getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); - Value t = rewriter.create(loc, lanczosPlusHalf, z); - Value logTerm = - getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); - Value log1pTerm = rewriter.create( - loc, rewriter.create(loc, z, lanczosPlusHalf)); - Value logT = rewriter.create(loc, logTerm, log1pTerm); - - // Note that t(z) may be large and we need to be careful not to overflow to - // infinity in the relevant term - // r = (z + 1/2) * log(t(z)) - t(z). - // Therefore, we compute this as - // r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)). - Value tDivLogT = rewriter.create(loc, t, logT); - Value sum = rewriter.create( - loc, rewriter.create(loc, z, half), tDivLogT); - Value r = rewriter.create(loc, sum, logT); - - // Compute the final result (modulo reflection) as - // lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)). - Value logA = rewriter.create(loc, a); - Value lgamma = rewriter.create( - loc, - rewriter.create( - loc, - getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x), - r), - logA); - - // Compute the reflected value for x < 0.5 as - // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). - // - // The abs is needed because lgamma is the log of the absolute value of the - // gamma function. - // - // We have to be careful when computing the final term above. gamma(x) goes - // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x) - // term. The slope is large, so precision is particularly important. - // - // Because abs(sin(pi * x)) has period of 1 we can equivalently use - // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is - // more numerically accurate: It doesn't overflow to inf like pi * x would and - // if x is an integer it evaluates to exactly 0 which is important because we - // then take the log of this value, and log(0) is inf. - // - // We don't have a frac(x) primitive in HLO and computing it is tricky, but - // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our - // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). - // - // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close - // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain - // [0, 1] is symmetric across the line Y=0.5. - // - - // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of - // pi * abs_frac for values of abs_frac close to 1. - Value abs = rewriter.create(loc, x); - Value absFrac = rewriter.create( - loc, abs, rewriter.create(loc, abs)); - Value reduceAbsFrac = rewriter.create( - loc, half, absFrac, mhlo::ComparisonDirection::LT); - absFrac = rewriter.create( - loc, reduceAbsFrac, rewriter.create(loc, one, absFrac), - absFrac); - - // Materialize reflection. - Value reflectionDenom = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create( - loc, getConstantLike(rewriter, loc, M_PI, x), absFrac))); - Value lgammaReflection = rewriter.create( - loc, - rewriter.create( - loc, getConstantLike(rewriter, loc, std::log(M_PI), x), - reflectionDenom), - lgamma); - - // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, - // then it "wins" and the result is +/-inf. - Value finiteReflectionDenom = - rewriter.create(loc, reflectionDenom); - Value negReflectionDenom = rewriter.create(loc, reflectionDenom); - lgammaReflection = rewriter.create( - loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom); - - // Select whether or not to rely on the reflection. - lgamma = rewriter.create(loc, needToReflect, lgammaReflection, - lgamma); - - // Materialize +/-inf behavior as - // lgamma(+/-inf) = +inf. - Value xIsInf = rewriter.create(loc, x); - return rewriter.create( - loc, xIsInf, - chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), - lgamma); -} - -// Express `cosh` as -// cosh(x) = (e^x + e^-x) / 2 -// = e^(x + log(1/2)) + e^(-x + log(1/2)) -// -// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not. -// -// This incorrectly overflows to inf for two f32 input values, namely -// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The -// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so -// we deem this acceptable. -Value materializeCoshApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - CoshOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - Value logOneHalf = - rewriter.create(loc, getConstantLike(rewriter, loc, 0.5, x)); - Value expAdd = rewriter.create( - loc, rewriter.create(loc, x, logOneHalf)); - Value expSub = rewriter.create( - loc, rewriter.create(loc, logOneHalf, x)); - return rewriter.create(loc, expAdd, expSub); -} - -struct ConvertCoshOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - CoshOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - rewriter.getF32Type(), - &materializeCoshApproximation)); - return success(); - } -}; - -// Compute the Digamma function using Lanczos' approximation from "A Precision -// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis -// series B. Vol. 1: -// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z) -// with t(z) = z + kLanczosGamma + 1/2 -// a(z) = kBaseLanczosCoeff -// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) -// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) -Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - // If the input is less than 0.5 use Euler's reflection formula. - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - // Let z be - // z = -x if x < 1/2 - // z = x - 1 otheriwse - Value x = args.front(); - Value half = getConstantLike(rewriter, loc, 0.5, x); - Value needToReflect = rewriter.create( - loc, x, half, mhlo::ComparisonDirection::LT); - Value negX = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1, x); - Value xSubOne = rewriter.create(loc, x, one); - Value z = rewriter.create(loc, needToReflect, negX, xSubOne); - - // Materialize - // a(z) = kBaseLanczosCoeff - // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) - // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) - Value zero = getConstantLike(rewriter, loc, 0.0, x); - Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); - Value aPrime = zero; - for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { - Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); - Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); - Value zTerm = rewriter.create(loc, z, oneBasedIndex); - aPrime = rewriter.create( - loc, aPrime, - rewriter.create( - loc, coeff, rewriter.create(loc, zTerm, zTerm))); - a = rewriter.create( - loc, a, rewriter.create(loc, coeff, zTerm)); - } - - // To improve accuracy on platforms with less-precise log implementations, - // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the - // device. - // Materialize as - // log(t) = log(kLanczosGamma + 1/2 + z) - // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). - Value lanczosPlusHalf = - getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); - Value t = rewriter.create(loc, lanczosPlusHalf, z); - Value logTerm = - getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); - Value log1pTerm = rewriter.create( - loc, rewriter.create(loc, z, lanczosPlusHalf)); - Value logT = rewriter.create(loc, logTerm, log1pTerm); - - // Materialize the final result (modulo reflection) as - // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z). - Value aPrimeDivA = rewriter.create(loc, aPrime, a); - Value lanczosGammaDivT = rewriter.create( - loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t); - Value digamma = rewriter.create( - loc, rewriter.create(loc, logT, aPrimeDivA), - lanczosGammaDivT); - - // We need to be careful how we compute cot(pi * input) below: For - // near-integral arguments, pi * input can lose precision. - // - // Input is already known to be less than 0.5 (otherwise we don't have to - // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to - // increase precision of pi * x and the resulting cotangent. - Value reducedX = rewriter.create( - loc, x, - rewriter.create( - loc, rewriter.create( - loc, rewriter.create( - loc, x, getConstantLike(rewriter, loc, 0.5, x))))); - - // Materialize reflection for inputs less than 0.5 as - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) - Value pi = getConstantLike(rewriter, loc, M_PI, x); - Value piMulReducedX = rewriter.create(loc, pi, reducedX); - Value cos = rewriter.create(loc, piMulReducedX); - Value sin = rewriter.create(loc, piMulReducedX); - Value reflection = rewriter.create( - loc, digamma, - rewriter.create( - loc, rewriter.create(loc, pi, cos), sin)); - - // Select whether or not to rely on the reflection. - digamma = - rewriter.create(loc, needToReflect, reflection, digamma); - - // Digamma has poles at negative integers and zero; return nan for those. - Value isLeZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LE); - Value isInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::EQ); - Value isPole = rewriter.create(loc, isLeZero, isInt); - return rewriter.create( - loc, isPole, - getConstantLike(rewriter, loc, std::numeric_limits::quiet_NaN(), - x), - digamma); -} - -Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - assert(args.size() == 2); - Value x = args[0]; - Value q = args[1]; - static const std::array kZetaCoeffs{ - -7.1661652561756670113e18, - 1.8152105401943546773e17, - -4.5979787224074726105e15, - 1.1646782814350067249e14, - -2.950130727918164224e12, - 7.47242496e10, - -1.8924375803183791606e9, - 47900160.0, - -1209600.0, - 30240.0, - -720.0, - 12.0, - }; - - // For speed we'll always use 9 iterations for the initial series estimate, - // and a 12 term expansion for the Euler-Maclaurin formula. - Value a = q; - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a); - Value negPower = zero; - Value negX = rewriter.create(loc, x); - Value initialSum = rewriter.create(loc, q, negX); - Value one = chlo::getConstantLike(rewriter, loc, 1.0, a); - for (int i = 0; i < 9; ++i) { - a = rewriter.create(loc, a, one); - negPower = rewriter.create(loc, a, negX); - initialSum = rewriter.create(loc, initialSum, negPower); - } - a = rewriter.create(loc, a, one); - negPower = rewriter.create(loc, a, negX); - Value oneLikeX = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value xMinusOne = rewriter.create(loc, x, oneLikeX); - Value negPowerMulA = rewriter.create(loc, negPower, a); - Value negPowerMulADivXMinusOne = - rewriter.create(loc, negPowerMulA, xMinusOne); - Value s = - rewriter.create(loc, initialSum, negPowerMulADivXMinusOne); - Value aInverseSquare = rewriter.create( - loc, one, rewriter.create(loc, a, a)); - - Value hornerSum = zero; - Value factor = one; - // Use Horner's rule for this. - // Note this differs from Cephes which does a 'naive' polynomial evaluation. - // Using Horner's rule allows to avoid some NaN's and Infs from happening, - // resulting in more numerically stable code. - for (int i = 0; i < 11; ++i) { - Value factorLhs = rewriter.create( - loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x)); - Value factorRhs = rewriter.create( - loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x)); - factor = rewriter.create(loc, factorLhs, factorRhs); - hornerSum = rewriter.create( - loc, factor, - rewriter.create( - loc, aInverseSquare, - rewriter.create( - loc, hornerSum, - chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); - } - Value zeroPointFiveLikeNegPower = - chlo::getConstantLike(rewriter, loc, .5, negPower); - Value xDivA = rewriter.create(loc, x, a); - s = rewriter.create( - loc, s, - rewriter.create( - loc, negPower, - rewriter.create( - loc, zeroPointFiveLikeNegPower, - rewriter.create( - loc, xDivA, - rewriter.create( - loc, - chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], - a), - hornerSum))))); - - // Use the initial zeta sum without the correction term coming - // from Euler-Maclaurin if it is accurate enough. - Value absNegPower = rewriter.create(loc, negPower); - Value absInitialSum = rewriter.create(loc, initialSum); - Value output = rewriter.create( - loc, - rewriter.create( - loc, absNegPower, - rewriter.create( - loc, absInitialSum, - chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)), - mhlo::ComparisonDirection::LT), - initialSum, s); - - // Function is not defined for x < 1. - Value nan = chlo::getConstantLike( - rewriter, loc, std::numeric_limits::quiet_NaN(), x); - output = rewriter.create( - loc, - rewriter.create(loc, x, oneLikeX, - mhlo::ComparisonDirection::LT), - nan, output); - - // For q <= 0, x must be an integer. - Value qLeZero = rewriter.create( - loc, q, zero, mhlo::ComparisonDirection::LE); - Value xNotInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::NE); - Value xDomainError = rewriter.create(loc, qLeZero, xNotInt); - output = rewriter.create(loc, xDomainError, nan, output); - - // For all integer q <= 0, zeta has a pole. The limit is only defined as - // +inf if x is and even integer. - Value inf = chlo::getConstantLike(rewriter, loc, - std::numeric_limits::infinity(), x); - Value qIsInt = rewriter.create( - loc, q, rewriter.create(loc, q), - mhlo::ComparisonDirection::EQ); - Value atPole = rewriter.create(loc, qLeZero, qIsInt); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value xIsInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::EQ); - Value xIsEven = rewriter.create( - loc, rewriter.create(loc, x, two), zero, - mhlo::ComparisonDirection::EQ); - Value xIsEvenInt = rewriter.create(loc, xIsInt, xIsEven); - output = rewriter.create( - loc, atPole, rewriter.create(loc, xIsEvenInt, inf, nan), - output); - - // For x = 1, this is the harmonic series and diverges. - output = rewriter.create( - loc, - rewriter.create(loc, x, one, - mhlo::ComparisonDirection::EQ), - inf, output); - - return output; -} - -Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - PolygammaOp::Adaptor transformed(args); - Value n = transformed.getN(); - Value x = transformed.getX(); - - // Handle integer n > 0. - Value one = getConstantLike(rewriter, loc, 1.0, x); - Value two = getConstantLike(rewriter, loc, 2.0, x); - Value sign = rewriter.create( - loc, - rewriter.create(loc, two, - rewriter.create(loc, n, two)), - one); - Value nPlusOne = rewriter.create(loc, n, one); - Value expLgammaNp1 = rewriter.create( - loc, rewriter.create(loc, nPlusOne)); - Value zeta = rewriter.create(loc, nPlusOne, x); - Value result = rewriter.create( - loc, rewriter.create(loc, sign, expLgammaNp1), zeta); - - // Handle n = 0. - Value zero = getConstantLike(rewriter, loc, 0.0, x); - Value nEqZero = rewriter.create( - loc, n, zero, mhlo::ComparisonDirection::EQ); - result = rewriter.create( - loc, nEqZero, rewriter.create(loc, x), result); - - // Check that n is a natural number. Return nan, otherwise. - Value nonInt = rewriter.create( - loc, n, rewriter.create(loc, n), - mhlo::ComparisonDirection::NE); - Value negative = rewriter.create( - loc, n, zero, mhlo::ComparisonDirection::LT); - Value nonNatural = rewriter.create(loc, nonInt, negative); - return rewriter.create( - loc, nonNatural, - getConstantLike(rewriter, loc, std::numeric_limits::quiet_NaN(), - x), - result); -} - -struct ConvertLgammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - LgammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - minPrecisionTy, &materializeLgamma)); - return success(); - } -}; - -struct ConvertDigammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - DigammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - minPrecisionTy, &materializeDigamma)); - return success(); - } -}; - -Value materializeNextAfter(ConversionPatternRewriter &rewriter, Location loc, - ValueRange operands) { - NextAfterOp::Adaptor transformed(operands); - Value x = transformed.getX(); - Value y = transformed.getY(); - auto resultTy = x.getType().cast(); - auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth(); - ImplicitLocOpBuilder b(loc, rewriter); - auto intTy = resultTy.clone(b.getIntegerType(bitwidth)); - auto xAsInt = b.create(intTy, x); - auto yAsInt = b.create(intTy, y); - - // The result is NaN if either "x" or "y" are NaN. - auto xIsNan = b.create(x, x, mhlo::ComparisonDirection::NE); - auto yIsNan = b.create(y, y, mhlo::ComparisonDirection::NE); - auto nanInput = b.create(xIsNan, yIsNan); - auto resultForNan = getConstantLike( - rewriter, loc, std::numeric_limits::quiet_NaN(), x); - auto resultForNanAsInt = - b.create(intTy, resultForNan); - - // The sign bit is the MSB. - const int64_t signBit = int64_t{1} << (bitwidth - 1); - // Discard the sign bit to make the result non-negative. - auto signMask = getConstantLike(rewriter, loc, signBit, xAsInt); - auto negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt); - auto xAbs = b.create(xAsInt, negatedSignMask); - auto yAbs = b.create(yAsInt, negatedSignMask); - - // When both "x" and "y" are equal, the result is "y". - auto xAndYAreEqual = - b.create(x, y, mhlo::ComparisonDirection::EQ); - auto resultForEqual = yAsInt; - - // When both "x" and "y" are 0, the result is "y". This is a separate case - // from above because "x" and "y" might have a different sign. - auto zero = getConstantLike(rewriter, loc, 0, xAsInt); - auto xIsZero = - b.create(xAbs, zero, mhlo::ComparisonDirection::EQ); - auto yIsZero = - b.create(yAbs, zero, mhlo::ComparisonDirection::EQ); - auto resultForBothZero = yAsInt; - - auto xSign = b.create(xAsInt, signMask); - auto ySign = b.create(yAsInt, signMask); - - // If from == 0 && to != 0, we need to return the smallest subnormal number - // signed like "to". - auto one = getConstantLike(rewriter, loc, 1, xAsInt); - auto resultForXZeroYNonZero = b.create(ySign, one); - - // If the sign of "x" and "y" disagree: - // - we need to make the magnitude of "from" smaller so that it is closer to - // zero. - // - // Otherwise the signs agree: - // - "x" with a magnitude larger than "y" means we need to make the magnitude - // smaller. - // - "x" with a magnitude smaller than "y" means we need to make the magnitude - // larger. - auto signsDisagree = - b.create(xSign, ySign, mhlo::ComparisonDirection::NE); - auto xMagnitudeLargerThanY = - b.create(xAbs, yAbs, mhlo::ComparisonDirection::GT); - auto resultHasSmallerMagnitude = - b.create(xMagnitudeLargerThanY, signsDisagree); - auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt); - auto magnitudeAdjustment = - b.create(resultHasSmallerMagnitude, minusOne, one); - Value result = b.create(xAsInt, magnitudeAdjustment); - // Handle from == +-0. - result = b.create( - xIsZero, - b.create(yIsZero, resultForBothZero, - resultForXZeroYNonZero), - result); - // Handle from == to. - result = b.create(xAndYAreEqual, resultForEqual, result); - // Handle isnan(x) || isnan(y). - result = b.create(nanInput, resultForNanAsInt, result); - - // Cast back to the original type. - return b.create(resultTy, result); -} - -struct ConvertNextAfterOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - NextAfterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp( - op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands())); - return success(); - } -}; - -struct ConvertPolygammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - PolygammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &materializePolygamma)); - return success(); - } -}; - -// Sinh(x) = (e^x - e^-x) / 2 -// = e^(x + log(1/2)) - e^(-x + log(1/2)). -// -// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not -// inf. -// -// This incorrectly overflows to +/-inf for two f32 input values, namely -// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The -// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so -// we deem this acceptable. -Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - SinhOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - Value logOneHalf = - rewriter.create(loc, getConstantLike(rewriter, loc, 0.5, x)); - Value expAdd = rewriter.create( - loc, rewriter.create(loc, x, logOneHalf)); - Value expSub = rewriter.create( - loc, rewriter.create(loc, logOneHalf, x)); - return rewriter.create(loc, expAdd, expSub); -} - -// Express `sinh` as -// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 -// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. -Value materializeSinhApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - Value largeSinhResult = - materializeSinhApproximationForLargeX(rewriter, loc, operands); - - SinhOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in - // 0. - // Rewrite this to avoid that. We use expm1(x) because that preserves the - // first order term of the taylor series of e^x. - // (e^(x) - e^(-x)) / 2. = - // (e^(x) - 1 + 1 - e^(-x)) / 2. - // (expm1(x) + (e^(x) - 1) / e^x) / 2. - // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2. - Value expm1 = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1.0, x); - Value oneHalf = getConstantLike(rewriter, loc, 0.5, x); - Value expm1PlusOne = rewriter.create(loc, expm1, one); - Value ratio = rewriter.create(loc, expm1, expm1PlusOne); - Value sum = rewriter.create(loc, expm1, ratio); - Value smallSinhResult = rewriter.create(loc, oneHalf, sum); - - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, smallSinhResult, - largeSinhResult); -} - -struct ConvertSinhOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - SinhOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value x = adaptor.getOperand(); - if (x.getType().cast().getElementType().isa()) { - rewriter.replaceOp(op, materializeSinhApproximationForLargeX( - rewriter, op.getLoc(), adaptor.getOperands())); - return success(); - } - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - rewriter.getF32Type(), - &materializeSinhApproximation)); - return success(); - } -}; - -// Converts chlo.top_k to MHLO iota, sort, and slice ops. -// -// chlo.top_k sorts along last dimension of the input tensor and then returns -// the top K components' values and indices. This is translated into a few -// ops in MHLO: first generating an integer sequence for the indices, -// then sort both the original input tensor and the indices togheter, and -// at last slice out the top K components. -// -// For example, for the following IR: -// -// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> -> -// (tensor<16x8xf32>, tensor<16x8xi32>) -// -// We will get: -// -// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> -// %2 = "mhlo.sort"(%input, %1) ({ -// ^bb0(%arg1: tensor, %arg2: tensor, -// %arg3: tensor, %arg4: tensor): -// %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... -// "mhlo.return"(%7) : (tensor) -> () -// }) {dimension = 1 : i64, is_stable = true} : ... -// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ... -// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ... -// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, -// start_indices dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<16x16xf32>) -> tensor<16x8xf32> -// %6 = "mhlo.slice"(%4) ... -// -// TODO(b/284078162): Decide what to do with this pattern given that we now -// have mhlo::TopKOp. No action needed for now given that mhlo::TopKOp is -// currently categorized as `hasPrivateFeaturesNotInStablehlo`. -struct ConvertTopKOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - TopKOp op, OpAdaptor /*adaptor*/, - ConversionPatternRewriter &rewriter) const override { - auto operandType = op.getOperand().getType().dyn_cast(); - if (!operandType) return failure(); - int64_t operandRank = operandType.getRank(); - int64_t lastDimIndex = operandRank - 1; - int64_t lastDimSize = operandType.getDimSize(lastDimIndex); - int64_t lastDimResultSize = - hlo::isDynamicDimSize(lastDimSize) - ? static_cast(op.getK()) - : std::min(static_cast(op.getK()), lastDimSize); - int64_t isDynamic = !operandType.hasStaticShape(); - auto i32Type = rewriter.getIntegerType(32); - Value opShapeValue, resultShapeValue; - if (isDynamic) { - SmallVector sizesI32x1; - for (auto i = 0; i < operandType.getRank(); ++i) { - auto sizeI32 = rewriter.create( - op.getLoc(), op.getOperand(), i); - auto sizeI32x1 = rewriter.create( - op.getLoc(), RankedTensorType::get({1}, i32Type), sizeI32); - sizesI32x1.push_back(sizeI32x1); - } - opShapeValue = - rewriter.create(op.getLoc(), sizesI32x1, - /*dimension=*/0); - auto lastDimI32 = rewriter.create( - op.getLoc(), - rewriter.getI32IntegerAttr(static_cast(lastDimResultSize))); - auto lastDimI32x1 = rewriter.create( - op.getLoc(), RankedTensorType::get({1}, i32Type), lastDimI32); - sizesI32x1.back() = lastDimI32x1; - resultShapeValue = - rewriter.create(op.getLoc(), sizesI32x1, - /*dimension=*/0); - } - - // Create an Iota op for indices. - Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type); - Value iotaOp; - if (isDynamic) { - iotaOp = rewriter.create( - op.getLoc(), iotaType, opShapeValue, - rewriter.getI64IntegerAttr(lastDimIndex)); - } else { - iotaOp = rewriter.create( - op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex)); - } - - // Create the sort op. It takes two inputs, one for the original input, the - // other for the indices. Use TOTALORDER comparison type instead of the - // default comparison if the element type is of type float. - Type elementType = operandType.getElementType(); - auto sortOp = - createSortOp(&rewriter, op.getLoc(), {op.getOperand(), iotaOp}, - {elementType, i32Type}, lastDimIndex, - /*isStable=*/true, - /*direction=*/mhlo::ComparisonDirection::GT); - - // Get the sorted input and index tuple element. - auto tupleFirstElement = sortOp.getResult(0); - auto tupleSecondElement = sortOp.getResult(1); - - SmallVector beginIndices(operandRank, 0); - auto endIndices = llvm::to_vector<4>(operandType.getShape()); - endIndices.back() = lastDimResultSize; - SmallVector strides(operandRank, 1); - - // Get the slice for the top K elements. - auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type()); - Value values, indices; - if (isDynamic) { - Value startIndices = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(indicesTy, beginIndices)); - Value lastIndices = rewriter.create( - op.getLoc(), resultShapeValue, rewriter.getI64Type()); - Value stridesOp = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(indicesTy, strides)); - - SmallVector resultShape = - llvm::to_vector<4>(operandType.getShape()); - resultShape.back() = lastDimResultSize; - RankedTensorType resultType = RankedTensorType::get( - resultShape, elementType, operandType.getEncoding()); - RankedTensorType indexResultType = - RankedTensorType::get(resultShape, i32Type); - - values = rewriter.create( - op.getLoc(), resultType, tupleFirstElement, startIndices, lastIndices, - stridesOp); - indices = rewriter.create( - op.getLoc(), indexResultType, tupleSecondElement, startIndices, - lastIndices, stridesOp); - } else { - values = rewriter.create( - op.getLoc(), tupleFirstElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); - indices = rewriter.create( - op.getLoc(), tupleSecondElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); - } - - rewriter.replaceOp(op, {values, indices}); - return success(); - } -}; - -struct ConvertZetaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ZetaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &materializeZeta)); - return success(); - } -}; - -struct ConvertSelectOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - BroadcastSelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only support ranked operands. - Value pred = adaptor.getPred(); - Value onTrue = adaptor.getOnTrue(); - Value onFalse = adaptor.getOnFalse(); - auto predType = pred.getType().dyn_cast(); - auto onTrueType = onTrue.getType().dyn_cast(); - auto onFalseType = onFalse.getType().dyn_cast(); - auto resultType = op.getResult().getType().dyn_cast(); - if (!predType || !onTrueType || !onFalseType || !resultType) { - return failure(); - } - - auto loc = op.getLoc(); - - Value predShape = rewriter.createOrFold(loc, pred); - Value onTrueShape = rewriter.createOrFold(loc, onTrue); - Value onFalseShape = rewriter.createOrFold(loc, onFalse); - int64_t resultRank = std::max( - {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()}); - - Value broadcastableCstr = rewriter.createOrFold( - loc, ValueRange{predShape, onTrueShape, onFalseShape}); - auto assumingOp = rewriter.create( - loc, ArrayRef{resultType}, broadcastableCstr); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&assumingOp.getDoRegion()); - - Value resultExtents = rewriter.createOrFold( - loc, shape::getExtentTensorType(op.getContext()), - ValueRange{predShape, onTrueShape, onFalseShape}, - /*error=*/nullptr); - auto shapeType = - RankedTensorType::get({resultRank}, rewriter.getIndexType()); - resultExtents = - rewriter.createOrFold(loc, shapeType, resultExtents); - - Value broadcastedPred = pred; - // Pred has an implicit broadcast for scalars, so use that when convenient. - if (predType.getRank() > 0) { - auto predBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - predType.getRank(), resultRank)); - broadcastedPred = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - predType.getElementType()), - pred, resultExtents, - rewriter.getI64TensorAttr(predBroadcastDimensions)); - } - auto onTrueBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - onTrueType.getRank(), resultRank)); - Value broadcastedOnTrue = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - onTrueType.getElementType()), - onTrue, resultExtents, - rewriter.getI64TensorAttr(onTrueBroadcastDimensions)); - auto onFalseBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - onFalseType.getRank(), resultRank)); - Value broadcastedOnFalse = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - onFalseType.getElementType()), - onFalse, resultExtents, - rewriter.getI64TensorAttr(onFalseBroadcastDimensions)); - - // And generate the final non-broadcasted ternary op. - Value finalResult = - rewriter.create(loc, resultType, broadcastedPred, - broadcastedOnTrue, broadcastedOnFalse); - rewriter.create(loc, finalResult); - rewriter.replaceOp(op, {assumingOp.getResult(0)}); - return success(); - } -}; - -// Converts binary ops that statically are determined to not broadcast directly -// to the corresponding mhlo non-broadcasting op. -template -struct ConvertTrivialNonBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, typename ChloOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only rewrite for statically determinable non-broadcasting cases. - auto lhsType = - adaptor.getLhs().getType().template dyn_cast(); - auto rhsType = - adaptor.getRhs().getType().template dyn_cast(); - if (!lhsType || !rhsType) return failure(); - - // Requires rank broadcast. - if (lhsType.getRank() != rhsType.getRank()) return failure(); - // Any dynamic dimension may require broadcasting and requires more - // analysis. - if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) - return failure(); - - for (auto extents : llvm::zip(lhsType.getShape(), rhsType.getShape())) { - auto lhsExtent = std::get<0>(extents); - auto rhsExtent = std::get<1>(extents); - if (lhsExtent != rhsExtent) { - return failure(); - } - } - - rewriter.replaceOp(op, Adaptor::createOp(op, op.getResult().getType(), - adaptor.getOperands(), rewriter)); - return success(); - } -}; - -// Converts a binary op with ranked broadcasting operands to explicitly -// broadcast and invoke the corresponding mhlo non-broadcasting op. -// Note that dynamic broadcasting supported by this pattern is only valid for -// "numpy" broadcasting semantics as defined here: -// https://docs.scipy.org/doc/numpy/reference/ufuncs.html -// Specifically, this includes the following cases: -// - Same rank broadcast (operands have the same static rank). -// - Different-rank broadcast, either without a broadcast_dims attribte or -// with the broadcast_dims attribute set to map to a prefix padding. -// - Legal combinations of degenerate (1-dim) implicit broadcasting. -// The restriction on broadcast_dims derives from the definition of the -// `shape.broadcast` op, which only supports prefix-padding. -template -struct ConvertRankedDynamicBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, typename ChloOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only support ranked operands. - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - auto lhsType = lhs.getType().dyn_cast(); - auto rhsType = rhs.getType().dyn_cast(); - auto resultType = - op.getResult().getType().template dyn_cast(); - if (!lhsType || !rhsType || !resultType) return failure(); - - // Check for "numpy"-style rank broadcast. - auto broadcastDimensions = op.getBroadcastDimensions(); - if (broadcastDimensions && - !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) { - // Note: It is unclear whether the general specification of explicit - // broadcast_dimensions on binary ops is a feature we want to carry - // forward. While it can technically be implemented for ranked-dynamic, - // it is incompatible with unranked inputs. If this warning is emitted - // in real programs, it is an indication that the feature should be - // implemented versus just falling back on the more standard definition - // of numpy-like prefix-padding. - op.emitWarning() << "unsupported non prefix-padded dynamic rank " - << "broadcast_dimensions = " << *broadcastDimensions; - return failure(); - } - - // Compute result shape. - auto loc = op.getLoc(); - - // Insert a constraint on the shapes being broadcastable and insert all - // future code into an assuming block reliant on the constraint. - Value lhsShape = rewriter.create(loc, lhs); - Value rhsShape = rewriter.create(loc, rhs); - auto broadcastableCstr = - rewriter.create(loc, lhsShape, rhsShape); - auto assumingOp = rewriter.create( - loc, ArrayRef{resultType}, broadcastableCstr.getResult()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&assumingOp.getDoRegion()); - - int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank()); - Value resultExtents = - hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, - rewriter); - - // Note that we unconditionally emit DynamicBroadcastInDim ops and let - // downstream canonicalizations fold them away if possible. This is - // because, in the dynamic case, there are many corner cases regarding - // when it is safe to omit, and some of them require analysis to prove - // properly. - auto lhsBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - lhsType.getRank(), resultRank)); - Value broadcastedLhs = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), lhsType.getElementType()), - lhs, resultExtents, rewriter.getI64TensorAttr(lhsBroadcastDimensions)); - auto rhsBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - rhsType.getRank(), resultRank)); - Value broadcastedRhs = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), rhsType.getElementType()), - rhs, resultExtents, rewriter.getI64TensorAttr(rhsBroadcastDimensions)); - - // And generate the final non-broadcasted binary op. - Value finalResult = Adaptor::createOp( - op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter); - rewriter.create(loc, finalResult); - rewriter.replaceOp(op, {assumingOp.getResult(0)}); - return success(); - } -}; - -class ConvertDynamicReshapeOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto tensor = op.getOperand(); - auto shape = op.getOutputShape(); - - auto shapeTy = shape.getType().cast(); - auto resultTy = op.getType().cast(); - - Value inputShape = rewriter.create(loc, tensor); - Value numEls = rewriter.create(loc, inputShape); - Value cstr = rewriter.create(loc, numEls, shape); - rewriter.replaceOpWithNewOp( - op, cstr, [&](OpBuilder &b, Location l) { - Value computedShape = - b.create(l, shapeTy, numEls, shape); - SmallVector result; - result.push_back(b.create(l, resultTy, tensor, - computedShape)); - return result; - }); - - return success(); - } -}; - -#include "chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc" -} // namespace - -void populateChloBroadcastingPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - // Instantiate conversion templates for conforming binary elementwise ops - // that do not have different dtypes between operands and results and do - // not have special attributes that need to be preserved. - populateForBroadcastingBinaryOp( - context, patterns, 10); - populateForBroadcastingBinaryOp( - context, patterns, 5); - patterns - ->add( - context); -} - -void populateDecomposeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - populateWithGenerated(*patterns); - - // Other patterns. - // clang-format off - patterns->add(context); - // clang-format on -} - -} // namespace chlo -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc b/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 8c26d40aff52f..d03c05880865e 100644 --- a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,61 +19,76 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" #include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CHLOLEGALIZETOHLOPASS +#define GEN_PASS_DEF_CHLOLEGALIZETOHIGHLEVELMHLOPASS #include "mhlo/transforms/mhlo_passes.h.inc" namespace { -struct ChloLegalizeToHloPass - : public impl::ChloLegalizeToHloPassBase { - explicit ChloLegalizeToHloPass(bool legalizeBroadcasts, - bool expandCompositions) - : ChloLegalizeToHloPassBase< - ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() { - this->legalize_broadcasts_ = legalizeBroadcasts; - this->expand_compositions_ = expandCompositions; - } +struct ChloLegalizeToHighLevelMhloPass + : public impl::ChloLegalizeToHighLevelMhloPassBase< + ChloLegalizeToHighLevelMhloPass> { + using ChloLegalizeToHighLevelMhloPassBase:: + ChloLegalizeToHighLevelMhloPassBase; + + void runOnOperation() override { + MLIRContext &context = getContext(); + ConversionTarget conversionTarget(context); + RewritePatternSet conversionPatterns(&context); + + chlo::populateChloToHighLevelMhloOpPatterns(&context, &conversionPatterns); + + // Consider the mhlo dialect legal for tests. Also add helper dialects + // that are needed by the patterns. + conversionTarget.addLegalDialect(); + conversionTarget.addIllegalOp(); - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + if (failed(applyPartialConversion(getOperation(), conversionTarget, + std::move(conversionPatterns)))) { + return signalPassFailure(); + } } +}; + +struct ChloLegalizeToHloPass + : public impl::ChloLegalizeToHloPassBase { + using ChloLegalizeToHloPassBase::ChloLegalizeToHloPassBase; void runOnOperation() override { - ConversionTarget conversionTarget(getContext()); - RewritePatternSet conversionPatterns(&getContext()); - conversionTarget.addIllegalDialect(); + MLIRContext &context = getContext(); + ConversionTarget conversionTarget(context); + RewritePatternSet conversionPatterns(&context); + + stablehlo::StablehloToHloTypeConverter typeConverter; + chlo::populateChloToHloPatterns(&context, &typeConverter, + &conversionPatterns); // Consider the mhlo dialect legal for tests. Also add helper dialects // that are needed by the patterns. conversionTarget - .addLegalDialect(); + .addIllegalDialect(); + conversionTarget.addLegalDialect< + MhloDialect, mlir::arith::ArithDialect, mlir::func::FuncDialect, + mlir::tensor::TensorDialect, mlir::shape::ShapeDialect>(); conversionTarget.addLegalOp(); - if (legalize_broadcasts_) { - chlo::populateChloBroadcastingPatterns(&getContext(), - &conversionPatterns); - } - - if (expand_compositions_) { - chlo::populateDecomposeChloPatterns(&getContext(), &conversionPatterns); - } else { - conversionTarget - .addLegalOp(); - } - if (failed(applyPartialConversion(getOperation(), conversionTarget, std::move(conversionPatterns)))) { return signalPassFailure(); @@ -83,11 +98,26 @@ struct ChloLegalizeToHloPass } // namespace -std::unique_ptr> createChloLegalizeToHloPass( - bool legalizeBroadcasts, bool expandCompositions) { - return std::make_unique(legalizeBroadcasts, - expandCompositions); +} // namespace mhlo + +namespace chlo { +namespace { +#include "chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc" + +} // namespace + +void populateChloToHighLevelMhloOpPatterns(MLIRContext *, + RewritePatternSet *patterns) { + populateWithGenerated(*patterns); } -} // namespace mhlo +void populateChloToHloPatterns(MLIRContext *context, + TypeConverter *typeConverter, + RewritePatternSet *patterns) { + chlo::populateChloToHighLevelMhloOpPatterns(context, patterns); + stablehlo::populateChloToStablehloPatterns(context, patterns); + stablehlo::populateStablehloToHloPatterns(patterns, typeConverter, context); +} + +} // namespace chlo } // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td b/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td index 3090ef17a4882..497686bf2e2ab 100644 --- a/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td +++ b/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,339 +19,22 @@ limitations under the License. // ambiguous/different for various backends. Avoid patterns that are actually // lowering to non-canonical forms. -include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/IR/OpBase.td" include "mhlo/IR/hlo_ops.td" include "stablehlo/dialect/ChloOps.td" -class MHLO_ComparisonDirectionValue : - ConstantAttr; - //===----------------------------------------------------------------------===// -// Unary op patterns. +// Direct CHLO->MHLO conversions //===----------------------------------------------------------------------===// -// Expand acos for non-complex arguments to MHLO dialect as follows: -// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 -// = pi if x == -1 -// -// TODO(b/237376133): Support operands with complex element types separately -// using the following formula. -// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) -def : Pat<(CHLO_AcosOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_ConstantLike<"-1"> $input), - MHLO_ComparisonDirectionValue<"NE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_MulOp - (MHLO_ConstantLike<"2"> $input), - (MHLO_Atan2Op - (MHLO_SqrtOp - (MHLO_SubtractOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_MulOp $input, $input) - ) - ), - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - $input - ) - ) - ), - (MHLO_ConstantLike<"M_PI"> $input) - )>; - -// Expand acosh to MHLO dialect as follows: -// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 -// = log(x + sqrt((x+1)*(x-1))) -// acosh(x) = nan if x < -1 -// -// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as -// log(2*x) = log(2) + log(x). (Note this works because negative x never -// overflows; x < -1 simply yields nan. -def : Pat<(CHLO_AcoshOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_ConstantLike<"-1"> $input), - MHLO_ComparisonDirectionValue<"LT">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_ConstantLike<"NAN"> $input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_SqrtOp - (MHLO_ConstantLikeMaxFiniteValue $input) - ), - MHLO_ComparisonDirectionValue<"GE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_AddOp - (MHLO_LogOp $input), - (MHLO_LogOp - (MHLO_ConstantLike<"2"> $input) - ) - ), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_MulOp - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - $input - ), - (MHLO_AddOp - (MHLO_ConstantLike<"-1"> $input), - $input - ) - ) - ) - ) - ) - ) - )>; - -// Expand acosh for complex arguments to MHLO dialect as -// acosh(x) = log(x + sqrt((x+1)*(x-1))) -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the question of overflow if x is a -// complex type, because we don't yet have exhaustive tests for complex trig -// functions". -def : Pat<(CHLO_AcoshOp ComplexElementType:$input), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_MulOp - (MHLO_AddOp - $input, - (MHLO_ConstantLike<"1"> $input) - ), - (MHLO_SubtractOp - $input, - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - )>; - - -// Expand asin to MHLO dialect as follows: -// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -def : Pat<(CHLO_AsinOp $input), - (MHLO_MulOp - (MHLO_ConstantLike<"2"> $input), - (MHLO_Atan2Op - $input, - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_SqrtOp - (MHLO_SubtractOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_MulOp $input, $input) - ) - ) - ) - ) - )>; - -// Expand asinh for non-complex arguments to MHLO dialect as -// asinh(x) = log(x + sqrt(x^2 + 1)) -// -// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) -// as 2*x and return log(2) + log(x). -// -// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point -// arithmetic. However, we would like to retain the low order term of this, -// which is around 0.5 * x^2 using a binomial expansion. -// Let z = sqrt(a^2 + 1) -// The following rewrite retains the lower order term. -// log(a + sqrt(a^2 + 1)) -// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) -// = log((a + a^2 + 1 + a * z + z) / (1 + z)) -// = log(1 + a + a^2 / (1 + z)) -// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) -// -// If x is negative, the above would give us some trouble; we can't approximate -// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = -// -asinh(x). -def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), - (MHLO_MulOp - (MHLO_SignOp $input), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_SqrtOp - (MHLO_ConstantLikeMaxFiniteValue $input) - ), - MHLO_ComparisonDirectionValue<"GE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_AddOp - (MHLO_LogOp - (MHLO_AbsOp $input) - ), - (MHLO_LogOp - (MHLO_ConstantLike<"2"> $input) - ) - ), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_ConstantLike<"1"> $input), - MHLO_ComparisonDirectionValue<"LE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_Log1pOp - (MHLO_AddOp - (MHLO_AbsOp $input), - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_DivOp - (MHLO_AbsOp $input), - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_AbsOp $input) - ), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - ), - (MHLO_LogOp - (MHLO_AddOp - (MHLO_AbsOp $input), - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_AbsOp $input) - ), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - )>; - -// Expand asinh for complex arguments to MHLO dialect as -// asinh(x) = log(x + sqrt(x^2 + 1)) -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the question of overflow if x is a -// complex type, because we don't yet have exhaustive tests for complex trig -// functions". -def : Pat<(CHLO_AsinhOp ComplexElementType:$input), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp $input, $input), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - )>; - -// Express `atan` as -// atan(x) = atan2(x, 1) -def : Pat<(CHLO_AtanOp $input), - (MHLO_Atan2Op - $input, - (MHLO_ConstantLike<"1"> $input) - )>; - -// Express `atanh` for non-complex arguments as follows: -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 -// atanh(x) = nan otherwise -def : Pat<(CHLO_AtanhOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_ConstantLike<"1"> $input), - MHLO_ComparisonDirectionValue<"GT">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_ConstantLike<"NAN"> $input), - (MHLO_MulOp - (MHLO_SubtractOp - (MHLO_Log1pOp $input), - (MHLO_Log1pOp - (MHLO_NegOp $input) - ) - ), - (MHLO_ConstantLike<"0.5"> $input) - ) - )>; - -// Express `atanh` for complex arguments as follows: -// atanh(x) = (log(1 + x) - log(1 + (-x))) * 0.5 -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the nan edge case for complex inputs, -// because we don't yet have exhaustive tests for complex trig functions". -def : Pat<(CHLO_AtanhOp ComplexElementType:$input), - (MHLO_MulOp - (MHLO_SubtractOp - (MHLO_Log1pOp $input), - (MHLO_Log1pOp - (MHLO_NegOp $input) - ) - ), - (MHLO_ConstantLike<"0.5"> $input) - )>; - -// Express `conj` as -// conj(x) = (re(x), -im(x)). -def : Pat<(CHLO_ConjOp $v), - (MHLO_ComplexOp (MHLO_RealOp $v), (MHLO_NegOp (MHLO_ImagOp $v)))>; - -// Express `is_inf` as -// is_inf(x) = is_pos_inf(|x|) -def : Pat<(CHLO_IsInfOp NonComplexElementType:$input), - (CHLO_IsPosInfOp - (MHLO_AbsOp $input) - )>; - -// Express `is_pos_inf` as -// is_pos_inf(x) = (x == +inf) -def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input), - (MHLO_CompareOp - $input, - (MHLO_ConstantLikePosInfValue $input), - MHLO_ComparisonDirectionValue<"EQ">, - (MHLO_DEFAULT_COMPARISON_TYPE) - )>; - -// Express `is_neg_inf` as -// is_neg_inf(x) = (x == -inf) -def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input), - (MHLO_CompareOp - $input, - (MHLO_ConstantLikeNegInfValue $input), - MHLO_ComparisonDirectionValue<"EQ">, - (MHLO_DEFAULT_COMPARISON_TYPE) - )>; +def : Pat<(CHLO_TanOp $v), + (MHLO_TanOp $v), + [], [], (addBenefit 10)>; -def : Pat<(CHLO_ConstantOp $v), - (MHLO_ConstantOp $v)>; +def : Pat<(CHLO_ErfOp $v), + (MHLO_ErfOp $v), + [], [], (addBenefit 10)>; -def : Pat<(CHLO_TanOp $v), - (MHLO_TanOp $v)>; +def : Pat<(CHLO_TopKOp AnyRankedTensor:$v, $k), + (MHLO_TopKOp $v, $k, ConstBoolAttrTrue), + [], [], (addBenefit 10)>; diff --git a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index 4d1d671ce77c5..7f1e7de597b53 100644 --- a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc +++ b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc b/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc index 4dae80389f567..7dc6080c07ccb 100644 --- a/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc b/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc index ea4bd7dd66080..c95ea747ea136 100644 --- a/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc b/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc index 7f70f4339c7ee..6b514e720fc53 100644 --- a/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc +++ b/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc b/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc index 1327bdd2f5b1c..5113736c0b3bb 100644 --- a/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc +++ b/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -52,9 +53,9 @@ struct SelectAndScatterExpanderPattern PatternRewriter& rewriter) const override { // Capture original values with variables ImplicitLocOpBuilder builder(sas.getLoc(), rewriter); - TypedValue operand = sas.getOperand(); + TypedValue operand = sas.getOperand(); llvm::ArrayRef operandShape = operand.getType().getShape(); - TypedValue source = sas.getSource(); + TypedValue source = sas.getSource(); Value initValue = sas.getInitValue(); Region& select = sas.getSelect(); Region& scatter = sas.getScatter(); diff --git a/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc b/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc index 2f266d30a5527..d2058c0e23254 100644 --- a/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc +++ b/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -224,8 +225,11 @@ LogicalResult tryLowerTo1DOr2DReduction( int64_t reductionDim = leadingReduction ? 0 : 1; auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim}); Value initVal = op.getInitValues().front(); - auto reductionOp = - rewriter.create(loc, intermResult, initVal, reductionDimAttr); + SmallVector elementTypes{llvm::map_range( + op.getBody().front().getTerminator()->getOperands(), + [](Value v) { return v.getType().cast().getElementType(); })}; + auto reductionOp = rewriter.create(loc, intermResult, initVal, + reductionDimAttr, elementTypes); rewriter.inlineRegionBefore(op.getBody(), reductionOp.getBody(), reductionOp.getBody().begin()); intermResult = reductionOp->getResults().front(); diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc index eb9b5db97695b..702deddaa0a51 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc index f62debe75e5d1..fead7be62bf1c 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -207,6 +207,7 @@ void populateScalarHloToArithmeticConversionPatterns( ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, + ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc index bb97f78dac233..1df0035b53e6f 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 4774aac93d4b5..a7ef5e93d1585 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -52,7 +53,7 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { // Please let us know if we missed something, and we'll recategorize them. if (isa(hloOp.getOperation())) { return true; } @@ -140,6 +141,18 @@ std::optional getPublicFeaturesNotInStablehlo(HloOpTy hloOp) { mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) return 1; } + // StableHLO doesn't support TopK yet. + // Proposal: https://github.com/openxla/stablehlo/pull/1593 + if constexpr (std::is_same::value) { + // Version 1: Initial version for TopK. + return 1; + } + // StableHLO doesn't support TopK yet. + // Proposal: https://github.com/openxla/stablehlo/pull/1593 + if constexpr (std::is_same::value) { + // Version 1: Initial version for ErfOp. + return 1; + } return std::nullopt; } @@ -148,6 +161,78 @@ bool hasPublicFeaturesNotInStablehlo(HloOpTy op) { return getPublicFeaturesNotInStablehlo(op).has_value(); } +template +bool isDenseI64Array(mlir::StringAttr hloName) { + if (std::is_same::value && + hloName == "broadcast_sizes") + return true; + if (std::is_same::value && + hloName == "broadcast_dimensions") + return true; + if ((std::is_same::value || + std::is_same::value) && + (hloName == "window_strides" || hloName == "lhs_dilation" || + hloName == "rhs_dilation")) + return true; + if (std::is_same::value && + (hloName == "broadcast_dimensions" || + hloName == "known_expanding_dimensions" || + hloName == "known_nonexpanding_dimensions")) + return true; + if ((std::is_same::value || + std::is_same::value) && + hloName == "slice_sizes") + return true; + if (std::is_same::value && + hloName == "fft_length") + return true; + if ((std::is_same::value || + std::is_same::value || + std::is_same::value) && + hloName == "dimensions") + return true; + if (std::is_same::value && + (hloName == "edge_padding_low" || hloName == "edge_padding_high" || + hloName == "interior_padding")) + return true; + if (std::is_same::value && + (hloName == "window_dimensions" || hloName == "window_strides" || + hloName == "base_dilations" || hloName == "window_dilations")) + return true; + if (std::is_same::value && + (hloName == "window_dimensions" || hloName == "window_strides")) + return true; + if (std::is_same::value && + (hloName == "start_indices" || hloName == "limit_indices" || + hloName == "strides")) + return true; + if (std::is_same::value && + hloName == "permutation") + return true; + return false; +} + +template +Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) { + auto denseInts = hloAttr.dyn_cast(); + if (!denseInts) return {}; + + if ((std::is_same::value || + std::is_same::value) && + hloName == "window_reversal") { + return DenseBoolArrayAttr::get( + hloAttr.getContext(), llvm::to_vector(denseInts.getValues())); + } + + // Handle DenseIntElementsAttr --> DenseI64ArrayAttr for StableHLO ops that + // use dense arrays. This is temporary while MHLO integrates this change. + if (isDenseI64Array(hloName)) + return DenseI64ArrayAttr::get( + hloAttr.getContext(), llvm::to_vector(denseInts.getValues())); + + return {}; +} + #define RETURN_CONVERTED_ENUM_ATTR(Name) \ auto hloValue = mhlo::stringify##Name(attr.getValue()); \ auto stablehloValue = stablehlo::symbolize##Name(hloValue); \ @@ -494,7 +579,11 @@ class HloToStablehloOpConverter : public OpConversionPattern { hloOp.getCustomCallSchedule() == mhlo::CustomCallSchedule::NONE) continue; } - auto stablehloAttr = convertAttr(hloAttr.getValue()); + auto stablehloAttr = convertDenseArray>( + hloAttr.getName(), hloAttr.getValue()); + if (!stablehloAttr) { + stablehloAttr = convertAttr(hloAttr.getValue()); + } if (!stablehloAttr) return failure(); stablehloAttrs.push_back({hloAttr.getName(), stablehloAttr}); } @@ -505,12 +594,12 @@ class HloToStablehloOpConverter : public OpConversionPattern { // for the generic builder. HloToStablehloOp stablehloOp; if constexpr (std::is_same::value) { - stablehloOp = rewriter.replaceOpWithNewOp( - hloOp, stablehloTypes, stablehloOperands, stablehloAttrs, + stablehloOp = rewriter.create( + hloOp.getLoc(), stablehloTypes, stablehloOperands, stablehloAttrs, hloOp.getBranches().size()); } else { - stablehloOp = rewriter.replaceOpWithNewOp>( - hloOp, stablehloTypes, stablehloOperands, stablehloAttrs); + stablehloOp = rewriter.create>( + hloOp.getLoc(), stablehloTypes, stablehloOperands, stablehloAttrs); } // Finally, populate the regions while converting argument types @@ -524,6 +613,8 @@ class HloToStablehloOpConverter : public OpConversionPattern { /*entryConversion=*/nullptr))) return failure(); } + + rewriter.replaceOp(hloOp, stablehloOp); return success(); } @@ -566,7 +657,8 @@ void populateHloToStablehloPatterns(RewritePatternSet* patterns, #include "stablehlo/dialect/StablehloOps.cpp.inc" >(patterns, converter, context, allowExperimentalFeatures); - populateHloToStablehloCustomCallPatterns( + populateHloToStablehloCustomCallPatterns( patterns, converter, context, allowExperimentalFeatures); } diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc index 05bbd455f861f..22b439460821c 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_broadcast_to_broadcast_in_dim/legalize_broadcast_to_broadcast_in_dim.cc b/xla/mlir_hlo/mhlo/transforms/legalize_broadcast_to_broadcast_in_dim/legalize_broadcast_to_broadcast_in_dim.cc index 1747b8eaa0cef..94acce2ec180e 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_broadcast_to_broadcast_in_dim/legalize_broadcast_to_broadcast_in_dim.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_broadcast_to_broadcast_in_dim/legalize_broadcast_to_broadcast_in_dim.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc b/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc index dec0b535e77a3..d56a1da759109 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_create_token_to_after_all/legalize_create_token_to_after_all.cc b/xla/mlir_hlo/mhlo/transforms/legalize_create_token_to_after_all/legalize_create_token_to_after_all.cc index 67b36657112f4..1deadbdae3dfe 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_create_token_to_after_all/legalize_create_token_to_after_all.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_create_token_to_after_all/legalize_create_token_to_after_all.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc b/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc index a0862ed390918..dfc541c591a4d 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc b/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc index badbb8dec33fc..ea37c6104e62c 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index e312638579d3c..bfeeaed83f89d 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index 92fe475423427..0dc4495cdbbc6 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc b/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc index 8dcb022850320..aba642b45cdf0 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc b/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc index f0c6bd12f9fe8..7a8feb69caf9c 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc b/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc index 9e5c9a09b56f6..6446995a7e91f 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc b/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc deleted file mode 100644 index bba14c025a0d9..0000000000000 --- a/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for converting CHLO dialect to Linalg dialect. - -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mhlo/utils/legalize_to_linalg_utils.h" -#include "mhlo/utils/type_conversion.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace mhlo { - -#define GEN_PASS_DEF_LEGALIZESPARSEOPSPASS -#include "mhlo/transforms/mhlo_passes.h.inc" - -namespace { - -struct LegalizeSparseOpsPass - : public impl::LegalizeSparseOpsPassBase { - explicit LegalizeSparseOpsPass(bool legalizeToCustomCalls) - : impl::LegalizeSparseOpsPassBase< - LegalizeSparseOpsPass>::LegalizeSparseOpsPassBase() { - this->legalize_to_custom_calls_ = legalizeToCustomCalls; - } - - void getDependentDialects(DialectRegistry& registry) const override { - registry - .insert(); - } - - void runOnOperation() override { - MLIRContext* ctx = &getContext(); - RewritePatternSet patterns(ctx); - ConversionTarget target(*ctx); - mhlo::RemoveSignTypeConverter typeConverter; - if (legalize_to_custom_calls_) { - setupLegalizeToCustomCallPatterns(ctx, &patterns, typeConverter, target); - } else { - setupLegalizeSparseCHLOPatterns(ctx, &patterns, typeConverter, target); - } - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - return signalPassFailure(); - } - } - - private: - static bool isNotSparseOp(Operation* op) { - return !sparse_tensor::hasAnySparseOperandOrResult(op); - } - - static void setupLegalizeToCustomCallPatterns(MLIRContext* ctx, - RewritePatternSet* patterns, - TypeConverter& typeConverter, - ConversionTarget& target) { - mhlo::populateLegalizeSparseOpsToCustomCallPatterns(ctx, typeConverter, - patterns); - target.addIllegalDialect(); - target.addLegalOp(); - } - - static void setupLegalizeSparseCHLOPatterns(MLIRContext* ctx, - RewritePatternSet* patterns, - TypeConverter& typeConverter, - ConversionTarget& target) { - mhlo::populateLegalizeSparseCHLOPatterns(ctx, typeConverter, patterns); - target.addLegalDialect(); - /// TODO(bixia): Remove the convert of such sparse CHLO ops from - /// chlo_legalize_to_hlo. - target.addDynamicallyLegalOp(isNotSparseOp); - } -}; - -} // namespace - -namespace impl { -/// Converts unary chlo op to a scalar op. -/// -/// Since the CHLO ops require tensor operands, we first create a single element -/// from the tensor, then perform the CHLO ops, and extract the scalar result -/// from the tensor. This may introduce memory accesses overhead. -/// TODO(bixia): Remove the extra memory accesses for performance. -#define ADD_OP(OpTy) \ - template <> \ - Value mapMhloOpToStdScalarOp(Location loc, ArrayRef resultTypes, \ - ArrayRef /*arg_types*/, \ - OpTy::Adaptor adaptor, OpBuilder * b) { \ - Type innerResultTy = resultTypes[0]; \ - RankedTensorType tensorResultTy = \ - RankedTensorType::get({}, innerResultTy); \ - Value tensorArg = b->create( \ - loc, tensorResultTy, adaptor.getOperands()[0]); \ - Value tensorResult = \ - b->create(loc, tensorResultTy, ValueRange({tensorArg})); \ - Value innerResult = \ - b->create(loc, tensorResult, ValueRange({})); \ - return innerResult; \ - } - -ADD_OP(chlo::AsinOp) -ADD_OP(chlo::AsinhOp) -ADD_OP(chlo::AtanOp) -ADD_OP(chlo::AtanhOp) -ADD_OP(chlo::BesselI1eOp) -ADD_OP(chlo::SinhOp) -ADD_OP(chlo::TanOp) - -#undef ADD_OP - -} // namespace impl - -void populateLegalizeSparseCHLOPatterns(MLIRContext* context, - TypeConverter& typeConverter, - RewritePatternSet* patterns) { - patterns->add, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter>(typeConverter, - context); -} - -std::unique_ptr> createLegalizeSparseOperationsPass( - bool legalizeToCustomCalls) { - return std::make_unique(legalizeToCustomCalls); -} - -} // namespace mhlo - -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/sparse_ops_to_custom_calls.cc b/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/sparse_ops_to_custom_calls.cc deleted file mode 100644 index 6641c0451e852..0000000000000 --- a/xla/mlir_hlo/mhlo/transforms/legalize_sparse_ops/sparse_ops_to_custom_calls.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace { - -StringAttr getOperationTargetName(Operation* op) { - // Strips off `dialect` from `dialect.opName`. - StringRef opName = op->getName().getIdentifier().strref().split(".").second; - return StringAttr::get(op->getContext(), "sparse_tensor_" + opName); -} - -} // namespace -namespace mhlo { - -template -class SparseOpToCustomCallConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - NamedAttribute callTargetName = - rewriter.getNamedAttr("call_target_name", getOperationTargetName(op)); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - adaptor.getOperands(), - ArrayRef{callTargetName}); - return success(); - } -}; - -void populateLegalizeSparseOpsToCustomCallPatterns( - MLIRContext* context, TypeConverter& typeConverter, - RewritePatternSet* patterns) { - patterns->add, - SparseOpToCustomCallConverter, - SparseOpToCustomCallConverter>( - typeConverter, context); -} - -} // namespace mhlo -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc b/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc index 904791c882429..2c470ccf66d48 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1471,6 +1471,7 @@ class IotaConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { ShapedType resultShapedType = getHloOpResultType(iotaOp); if (!resultShapedType) return failure(); + Type targetElementType = resultShapedType.getElementType(); resultShapedType = this->typeConverter->convertType(resultShapedType) .template dyn_cast(); @@ -1503,9 +1504,9 @@ class IotaConverter : public OpConversionPattern { nestedBuilder.getIntegerType( unwrappedResultElementType.getIntOrFloatBitWidth()), indexOp); - castOp = mhlo::MhloOpToStdScalarOp::mapOpOfType( - nestedLoc, resultElementType, castOp.getType(), {castOp}, - &nestedBuilder); + castOp = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( + nestedLoc, targetElementType, resultElementType, castOp.getType(), + {castOp}, &nestedBuilder); nestedBuilder.create(nestedLoc, castOp); }, linalg::getPrunedAttributeList(iotaOp)); @@ -1524,6 +1525,7 @@ class IotaToMapConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { ShapedType resultTy = getHloOpResultType(iotaOp); if (!resultTy) return failure(); + Type targetElementType = resultTy.getElementType(); resultTy = this->typeConverter->convertType(resultTy) .template dyn_cast(); @@ -1538,10 +1540,9 @@ class IotaToMapConverter : public OpConversionPattern { nestedLoc, iotaOp.getIotaDimension()); index = nestedBuilder.create( nestedLoc, nestedBuilder.getI64Type(), index); - Value result = - mhlo::MhloOpToStdScalarOp::mapOpOfType( - nestedLoc, resultTy.getElementType(), index.getType(), - {ValueRange{index}}, &nestedBuilder); + Value result = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( + nestedLoc, targetElementType, resultTy.getElementType(), + index.getType(), {ValueRange{index}}, &nestedBuilder); nestedBuilder.create(nestedLoc, ValueRange{result}); }, linalg::getPrunedAttributeList(iotaOp)); @@ -3263,7 +3264,7 @@ struct ConvolutionOpGeneralConversion // Finally, create the computation auto inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}); + AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); Value emptyTensor = rewriter.create( loc, reshapedResultShape, resultType.getElementType()); @@ -3578,7 +3579,7 @@ struct ReduceWindowOpOnTensorsGenericConversion SmallVector inferredMaps(3, AffineMap::get(ctx)); if (rank > 0) inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}); + AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); SmallVector indexingMaps; @@ -4504,6 +4505,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, @@ -4563,6 +4565,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc b/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc index f1b2d403efe8a..2b8b4a051ffa5 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td b/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td index 2436dadb87017..d71af24027513 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td +++ b/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,6 +40,10 @@ def createFastMathNone : NativeCodeCall< "::mlir::arith::FastMathFlagsAttr::get(" "$_builder.getContext(), ::mlir::arith::FastMathFlags::none" ")">; +def createOverflowNone : NativeCodeCall< + "::mlir::arith::IntegerOverflowFlagsAttr::get(" + "$_builder.getContext(), ::mlir::arith::IntegerOverflowFlags::none" + ")">; // Unary Lowering Patterns. @@ -68,13 +72,13 @@ def : Pat<(MHLO_RemOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_RemFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(MHLO_AddOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), - (Arith_AddIOp $l, $r), + (Arith_AddIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(MHLO_SubtractOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), - (Arith_SubIOp $l, $r), + (Arith_SubIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(MHLO_MulOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), - (Arith_MulIOp $l, $r), + (Arith_MulIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(MHLO_DivOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_DivSIOp $l, $r), diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc index daaafd0157239..432c773354477 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc index 0a5a738c2c085..2e7018e2fd17c 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc b/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc index 6e8f0edbcb6a7..ff0e67ab60b18 100644 --- a/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc +++ b/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td b/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td index 1fc4128cba05e..693cec55e479e 100644 --- a/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td +++ b/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc b/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc index 66fcdc9dff2ea..ee672671a6859 100644 --- a/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc +++ b/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" #include "mhlo/transforms/rewriters.h" @@ -32,6 +33,7 @@ limitations under the License. #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -256,9 +258,9 @@ struct GeneralDotConvert : public OpRewritePattern { if (sparse_tensor::hasAnySparseOperandOrResult(op)) return failure(); // Compute the, possibly, transposed-reshaped operands. - lhs = llvm::cast>(processDotArg( + lhs = llvm::cast>(processDotArg( lhs, loc, lhsContractingDims, /*outerDimsFirst=*/true, rewriter)); - rhs = llvm::cast>(processDotArg( + rhs = llvm::cast>(processDotArg( rhs, loc, rhsContractingDims, /*outerDimsFirst=*/false, rewriter)); // Accept only static shaped types. diff --git a/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h b/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h index 5ba277bc4dbcc..fc043d2484c0b 100644 --- a/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h +++ b/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h b/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index 5f76f058e0f13..9761312cad5d8 100644 --- a/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "mhlo/IR/hlo_ops.h" @@ -27,7 +26,6 @@ limitations under the License. #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" @@ -83,6 +81,10 @@ struct MhloToScalarOp { using COp = ::mlir::complex::CosOp; }; template <> +struct MhloToScalarOp { + using FOp = ::mlir::math::ErfOp; +}; +template <> struct MhloToScalarOp { using FOp = ::mlir::math::ExpOp; using COp = ::mlir::complex::ExpOp; @@ -472,10 +474,9 @@ inline Value mapMhloOpToStdScalarOp( mlir::ImplicitLocOpBuilder b(loc, *builder); // Integer and float types for casting and constant generation. - auto floatType = - argTypes.front().cast().getElementType().cast(); + auto floatType = getElementTypeOrSelf(argTypes.front()).cast(); int64_t nbits = floatType.getWidth(); - auto intType = mlir::IntegerType::get(loc.getContext(), floatType.getWidth()); + auto intType = mlir::IntegerType::get(loc.getContext(), nbits); Value xAsInt = b.create(intType, adaptor.getOperand()); @@ -1074,13 +1075,18 @@ template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, mhlo::LogisticOp::Adaptor adaptor, OpBuilder* b) { - // 1.0 / (1.0 - exp(-x)) + // 1.0 / (1.0 + exp(-x)) Value negX = mapMhloOpToStdScalarOp( loc, resultTypes, resultTypes, {adaptor.getOperand()}, b); Value expNegX = mapMhloOpToStdScalarOp(loc, resultTypes, resultTypes, {{negX}}, b); - Value oneFloat = b->create(loc, b->getF32FloatAttr(1.0)); + Type type = getElementTypeOrSelf(resultTypes[0]); + Value oneFloat = + type.isa() + ? b->create(loc, b->getF32FloatAttr(1.0)) + : getConstantOrSplat(b, loc, resultTypes[0], + FloatAttr::get(type, 1.0f)); Value one = mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes, {oneFloat.getType()}, {{oneFloat}}, b); Value oneAddExprNegX = mapMhloOpToStdScalarOp( @@ -1096,8 +1102,9 @@ inline Value mapMhloOpToStdScalarOp(Location loc, mhlo::PowOp::Adaptor adaptor, OpBuilder* b) { auto lb = ImplicitLocOpBuilder(loc, *b); - // Floating point can use std::powf - auto resultType = resultTypes.front(); + // TODO: b/315868720 Consider alternate lowerings of mhlo::PowOp with integer + // operands. Floating point can use std::powf + auto resultType = getElementTypeOrSelf(resultTypes.front()); if (resultType.isa()) { return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, @@ -1156,7 +1163,7 @@ inline Value mapMhloOpToStdScalarOp(Location loc, // The accum is correct when the rhs is non-negative. When rhs is // negative, we return 0 for integer, with the exception of lhs values of 1 // and -1 which have integer results for negative exponents. Specifically, the - // calulation is the following: + // calculation is the following: // // - Return accum if the rhs is not negative. // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. @@ -1306,9 +1313,11 @@ struct MhloOpToStdScalarOp { ArrayRef argTypes, ValueRange args, OpBuilder* b) { static_assert(!std::is_same::value); - return mapOpOfType( - op.getLoc(), resultTypes, argTypes, - typename MhloOpTy::Adaptor(args, op->getAttrDictionary()), b); + typename MhloOpTy::Adaptor adaptor(args, op->getAttrDictionary(), + op->getPropertiesStorage(), + op->getRegions()); + return mapOpOfType(op.getLoc(), resultTypes, argTypes, adaptor, + b); } // Overload for mhlo::ConvertOp. static Value mapOpWithArgTypes(mhlo::ConvertOp op, ArrayRef resultTypes, @@ -1323,15 +1332,18 @@ struct MhloOpToStdScalarOp { static Value mapOpOfType(Location loc, ArrayRef resultTypes, ArrayRef argTypes, typename MhloOpTy::Adaptor adaptor, OpBuilder* b) { - if (std::is_same::value) { - // Note: this assumes that the caller is passing result/arg types with - // appropriate signedness. - return impl::mapConvertOpToStdScalarOp( - loc, resultTypes, resultTypes, argTypes, adaptor.getOperands(), b); - } return impl::mapMhloOpToStdScalarOp(loc, resultTypes, argTypes, adaptor, b); } + + static Value mapConvertOpToStdScalarOp(Location loc, + ArrayRef targetTypes, + ArrayRef resultTypes, + ArrayRef argTypes, + ValueRange args, OpBuilder* b) { + return impl::mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes, + argTypes, args, b); + } }; } // namespace mhlo diff --git a/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h b/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h index 3667563ac078e..3754b928e4b51 100644 --- a/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h +++ b/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -64,9 +64,11 @@ MAP_STABLEHLO_TO_HLO(CeilOp) MAP_STABLEHLO_TO_HLO(CholeskyOp) MAP_STABLEHLO_TO_HLO(ClampOp) MAP_STABLEHLO_TO_HLO(ClzOp) +MAP_STABLEHLO_TO_HLO(CollectiveBroadcastOp) MAP_STABLEHLO_TO_HLO(CollectivePermuteOp) MAP_STABLEHLO_TO_HLO(CompareOp) MAP_STABLEHLO_TO_HLO(ComplexOp) +MAP_STABLEHLO_TO_HLO(CompositeOp) MAP_STABLEHLO_TO_HLO(ComputeReshapeShapeOp) MAP_STABLEHLO_TO_HLO(ConcatenateOp) MAP_STABLEHLO_TO_HLO(ConstantOp) diff --git a/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc b/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc index 4b2cb559b5ea9..32538f4a8c37f 100644 --- a/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc +++ b/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc b/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc index 188b72f5e1d77..0b5f8171be9c7 100644 --- a/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc index 8af920e29be84..185b2c9d7caa1 100644 --- a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc +++ b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_dot/mhlo_canonicalize_dot.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_dot/mhlo_canonicalize_dot.cc index 14c01c1d69359..959c82837be8b 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_dot/mhlo_canonicalize_dot.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_dot/mhlo_canonicalize_dot.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc index 62cf99f5e1695..3c897e38e4b5a 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,8 +28,11 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -40,7 +43,8 @@ namespace { #include "mhlo/transforms/mhlo_passes.h.inc" // Given an input tensor, collapse dimensions 1+collapsedSliceDims[...]. -Value collapseSliceDims(ImplicitLocOpBuilder& b, TypedValue input, +Value collapseSliceDims(ImplicitLocOpBuilder& b, + TypedValue input, ArrayRef collapsedSliceDims) { if (collapsedSliceDims.empty()) return input; @@ -59,7 +63,7 @@ Value collapseSliceDims(ImplicitLocOpBuilder& b, TypedValue input, // Expands the first dimension of `input` into the shape of `startIndices`, // removing the index vector dimension. Value expandBatchDimension(ImplicitLocOpBuilder& b, - TypedValue input, + TypedValue input, GatherOp originalGatherOp) { llvm::SmallVector newShape{ originalGatherOp.getStartIndices().getType().getShape()}; @@ -87,7 +91,7 @@ Value expandBatchDimension(ImplicitLocOpBuilder& b, } Value moveOffsetDimensions(ImplicitLocOpBuilder& b, - TypedValue input, + TypedValue input, GatherOp originalGatherOp) { const auto& dims = originalGatherOp.getDimensionNumbers(); int64_t outputRank = input.getType().getRank(); @@ -152,7 +156,7 @@ struct CanonicalizeGatherPattern : public OpRewritePattern { rewriter.getContext(), offsetDims, /*collapsedSliceDims=*/{}, startIndexMap, /*indexVectorDim=*/1); - TypedValue result = + TypedValue result = b.create(operand, startIndices, newDims, b.getI64TensorAttr(permute( gatherOp.getSliceSizes().getValues(), @@ -168,16 +172,16 @@ struct CanonicalizeGatherPattern : public OpRewritePattern { result, b.getI64TensorAttr(operandPermutationInverse)); // Collapse the requested dimensions. - result = cast>( + result = cast>( collapseSliceDims(b, result, dims.getCollapsedSliceDims())); // Expand the start index dimensions. - result = - cast>(expandBatchDimension(b, result, gatherOp)); + result = cast>( + expandBatchDimension(b, result, gatherOp)); // Move the offset dims to the final locations. - result = - cast>(moveOffsetDimensions(b, result, gatherOp)); + result = cast>( + moveOffsetDimensions(b, result, gatherOp)); rewriter.replaceOp(gatherOp.getOperation(), {result}); return success(); diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc index 49397f378c3b7..bb68ec22fbac9 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -219,8 +220,12 @@ struct HloCanonicalizeReductionPass elemTy), operand, newOperandShape)); } - auto newOp = - b.create(loc, newOperands, op.getInitValues(), attr); + SmallVector elementTypes{llvm::map_range( + op.getBody().front().getTerminator()->getOperands(), [](Value v) { + return v.getType().cast().getElementType(); + })}; + auto newOp = b.create(loc, newOperands, op.getInitValues(), + attr, elementTypes); newOp.getBody().takeBody(op.getBody()); SmallVector newResults; diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc index 4c9d853c3fb0d..30b0c6f1848bd 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc index f30028421679c..f55bb8e85973d 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td b/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index e29f17382437d..62868358ffc0e 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,27 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { - let summary = "Legalize CHLO to HLO."; - let constructor = "createChloLegalizeToHloPass()"; - let options = [ - Option<"legalize_broadcasts_", "legalize-broadcasts", "bool", - /*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">, - Option<"expand_compositions_", "expand-compositions", "bool", - /*default=*/"true", "Expands client-centric compositions to HLO primitives">, - ]; +def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "func::FuncOp"> { + let summary = "Legalize CHLO's with XLA counterparts, like TopK and Erf."; + let description = [{ + Performs direct legalization of CHLO->MHLO only for high-level (non-basis) + ops with XLA support. These are MHLO ops that directly model the CHLO op, + such as TopK and Erf. + }]; + let dependentDialects = ["mhlo::MhloDialect"]; } -def LegalizeSparseOpsPass : Pass<"legalize-sparse-ops", "func::FuncOp"> { - let summary = "Legalize from sparse ops before convert MLIR to XLA computation."; - let constructor = "createLegalizeSparseOperationsPass()"; - let options = [ - Option<"legalize_to_custom_calls_", "legalize-to-custom-calls", "bool", - /*default=*/"true", "Whether legalize the sparse operations to custom_calls to be able to translate sparse operations to XLA computations">, +def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { + let summary = "Legalize CHLO to MHLO with XLA-supported ops."; + let description = [{ + Performs legalization of CHLO->StableHLO->MHLO, while also preserving MHLO + high level operations when possible (see ChloLegalizeToHighLevelMhloPass). + }]; + let dependentDialects = [ + "mhlo::MhloDialect", + "mlir::shape::ShapeDialect", + "mlir::stablehlo::StablehloDialect", + "mlir::tensor::TensorDialect" ]; } @@ -315,28 +319,6 @@ def ConvertToSignlessPass : Pass<"convert-to-signless", "ModuleOp"> { let constructor = "createConvertToSignlessPass()"; } -def SparseRewritingPass : Pass<"mhlo-sparse-rewriting", "func::FuncOp"> { - let summary = "Pass to rewrite mhlo sparse tensor types."; - let constructor = "createSparseRewritingPass()"; -} - -/// Rank specialization passes. - -def RankSpecializationClusterPass - : Pass<"mhlo-rank-specialization-cluster", "func::FuncOp"> { - let constructor = "createRankSpecializationClusterPass()"; -} - -def RankSpecializationToSCFPass - : Pass<"mhlo-rank-specialization-to-scf", "func::FuncOp"> { - let constructor = "createRankSpecializationToSCFPass()"; - let options = [ - Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8", - "The maximum supported rank after rank specialization. Any argument " - "of greater rank may result in a runtime failure.">, - ]; -} - def MhloExpandOpsSimplifierPass : Pass<"mhlo-expand-ops-simplifier", "func::FuncOp"> { let summary = "Expand feature rich mhlo ops into a set of simpler mhlo ops."; @@ -392,4 +374,21 @@ def ShapeLegalizeToHloPass : Pass<"shape-legalize-to-hlo", "func::FuncOp"> { compilation pipelines that use HLO operations to model dynamism. }]; let dependentDialects = ["mhlo::MhloDialect"]; + let options = [ + Option<"legalize_constraints_", "legalize-constraints", "bool", + /*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call"> + ]; +} + +def MhloQuantLegalizeToInt : Pass<"mhlo-quant-legalize-to-int", "mlir::func::FuncOp"> { + let summary = "Convert from MHLO quantized ops to MHLO primitive ops."; + + let description = [{ + Convert from MHLO quantized ops with MHLO quant types to MHLO primitive ops + like int ops. + }]; + let constructor = "createMhloQuantLegalizeToIntPass()"; + let dependentDialects = ["chlo::ChloDialect", "mhlo::MhloDialect", + "quant::QuantizationDialect", + "func::FuncDialect"]; } diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc new file mode 100644 index 0000000000000..4d44b97db4fde --- /dev/null +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_quant_legalize_to_int/mhlo_quant_legalize_to_int.cc @@ -0,0 +1,1399 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" + +namespace mlir::mhlo { +namespace { + +// TODO: b/311218165 - consider extract this to common utils and better ways to +// handle polymorphism. +using QuantType = std::variant; +FailureOr getQuantType(Type type) { + if (auto quantType = + getElementTypeOrSelf(type).dyn_cast()) { + return QuantType(quantType); + } + if (auto quantType = getElementTypeOrSelf(type) + .dyn_cast()) { + return QuantType(quantType); + } + return failure(); +} + +bool isPerTensorType(QuantType quantType) { + return std::holds_alternative(quantType); +} + +bool isPerChannelType(QuantType quantType) { + return std::holds_alternative(quantType); +} + +quant::UniformQuantizedType getPerTensorType(QuantType quantType) { + return std::get(quantType); +} + +quant::UniformQuantizedPerAxisType getPerChannelType(QuantType quantType) { + return std::get(quantType); +} + +// Extracts scale and zero point info from input quant type info. +void getQuantizationParams(OpBuilder &builder, Location loc, + QuantType quantType, Value &scales, + Value &zeroPoints, bool outputZeroPointInFp, + DenseI64ArrayAttr &broadcastDims) { + // Get scales/zero points for per-tensor and per-axis quantization cases. + if (auto *quantPerTensorType = + std::get_if(&quantType)) { + scales = builder.create( + loc, builder.getF32FloatAttr(quantPerTensorType->getScale())); + if (outputZeroPointInFp) { + zeroPoints = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getZeroPoint()))); + } else { + zeroPoints = builder.create( + loc, builder.getI32IntegerAttr( + static_cast(quantPerTensorType->getZeroPoint()))); + } + } else { + auto &quantPerChannelType = + std::get(quantType); + SmallVector scalesVec; + for (auto scale : quantPerChannelType.getScales()) + scalesVec.push_back(scale); + scales = builder.create( + loc, + DenseFPElementsAttr::get( + RankedTensorType::get( + {static_cast(quantPerChannelType.getScales().size())}, + builder.getF32Type()), + scalesVec)); + if (outputZeroPointInFp) { + SmallVector zeroPointsVec; + for (auto zeroPoint : quantPerChannelType.getZeroPoints()) + zeroPointsVec.push_back(zeroPoint); + zeroPoints = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get( + {static_cast( + quantPerChannelType.getZeroPoints().size())}, + builder.getF32Type()), + zeroPointsVec)); + } else { + SmallVector zeroPointsVec; + for (auto zeroPoint : quantPerChannelType.getZeroPoints()) + zeroPointsVec.push_back(zeroPoint); + zeroPoints = builder.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast( + quantPerChannelType.getZeroPoints().size())}, + builder.getI32Type()), + zeroPointsVec)); + } + broadcastDims = DenseI64ArrayAttr::get( + builder.getContext(), + {static_cast(quantPerChannelType.getQuantizedDimension())}); + } +} + +// Extracts storage min/max from input quant type info. +void getQuantizationStorageInfo(OpBuilder &builder, Location loc, + QuantType quantType, Value &storageMin, + Value &storageMax) { + if (auto *quantPerTensorType = + std::get_if(&quantType)) { + storageMin = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getStorageTypeMin()))); + storageMax = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerTensorType->getStorageTypeMax()))); + } else { + auto &quantPerChannelType = + std::get(quantType); + storageMin = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerChannelType.getStorageTypeMin()))); + storageMax = builder.create( + loc, builder.getF32FloatAttr( + static_cast(quantPerChannelType.getStorageTypeMax()))); + } +} + +// Extracts storage type of a UQ type. Return original type if it is no UQ type. +Type getQuantStorageType(Type type) { + if (auto shaped = type.dyn_cast()) { + return shaped.clone(getQuantStorageType(shaped.getElementType())); + } + + if (auto elementType = + getElementTypeOrSelf(type).dyn_cast()) { + return elementType.getStorageType(); + } + if (auto elementType = getElementTypeOrSelf(type) + .dyn_cast()) { + return elementType.getStorageType(); + } + return type; +} + +Type getQuantStorageType(QuantType type) { + if (isPerTensorType(type)) { + return getPerTensorType(type).getStorageType(); + } + return getPerChannelType(type).getStorageType(); +} + +Value applyMergedScalesAndZps(OpBuilder &builder, Location loc, + QuantType inputQuantType, + QuantType outputQuantType, + Value inputFloatTensor) { + // Use single merged scale and merged zp if both input and output are + // per-tensor quantized. Otherwise use a vector. + if (isPerTensorType(inputQuantType) && isPerTensorType(outputQuantType)) { + quant::UniformQuantizedType inputPerTensorType = + getPerTensorType(inputQuantType); + quant::UniformQuantizedType outputPerTensorType = + getPerTensorType(outputQuantType); + double mergedScaleFp = + inputPerTensorType.getScale() / outputPerTensorType.getScale(); + auto mergedScale = builder.create( + loc, builder.getF32FloatAttr(static_cast(mergedScaleFp))); + inputFloatTensor = + builder.create(loc, inputFloatTensor, mergedScale, + /*broadcast_dimensions=*/nullptr); + // Add merged_zp only when it is non-zero. + double mergedZpFp = outputPerTensorType.getZeroPoint() - + inputPerTensorType.getZeroPoint() * mergedScaleFp; + if (mergedZpFp != 0) { + Value mergedZp = builder.create( + loc, builder.getF32FloatAttr(static_cast(mergedZpFp))); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedZp, /*broadcast_dimensions=*/nullptr); + } + } else { + int64_t channelSize = + isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getScales().size() + : getPerChannelType(inputQuantType).getScales().size(); + int64_t quantizedDimension = + isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getQuantizedDimension() + : getPerChannelType(inputQuantType).getQuantizedDimension(); + SmallVector mergedScaleDouble, mergedZpDouble; + mergedScaleDouble.resize(channelSize); + mergedZpDouble.resize(channelSize); + for (int i = 0; i < channelSize; ++i) { + mergedScaleDouble[i] = + (isPerChannelType(inputQuantType) + ? getPerChannelType(inputQuantType).getScales()[i] + : getPerTensorType(inputQuantType).getScale()) / + (isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getScales()[i] + : getPerTensorType(outputQuantType).getScale()); + mergedZpDouble[i] = + (isPerChannelType(outputQuantType) + ? getPerChannelType(outputQuantType).getZeroPoints()[i] + : getPerTensorType(outputQuantType).getZeroPoint()) - + (isPerChannelType(inputQuantType) + ? getPerChannelType(inputQuantType).getZeroPoints()[i] + : getPerTensorType(inputQuantType).getZeroPoint()) * + mergedScaleDouble[i]; + } + SmallVector mergedScaleFloat(mergedScaleDouble.begin(), + mergedScaleDouble.end()), + mergedZpFloat(mergedZpDouble.begin(), mergedZpDouble.end()); + + auto broadcastDims = + DenseI64ArrayAttr::get(builder.getContext(), {quantizedDimension}); + Value mergedScale = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channelSize}, builder.getF32Type()), + mergedScaleFloat)); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedScale, broadcastDims); + if (llvm::any_of(mergedZpFloat, [](double zp) { return zp != 0; })) { + Value mergedZp = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channelSize}, builder.getF32Type()), + mergedZpFloat)); + inputFloatTensor = builder.create( + loc, inputFloatTensor, mergedZp, broadcastDims); + } + } + return inputFloatTensor; +} + +// This helper function create ops to requantize `input` tensor and returns the +// output tensor. Clamping is done if output integer bit-width < i32. It assumes +// that if both input and output tensor are per-channel quantized, they have the +// same quantization axis. +// +// Requantization is essentially dequantize --> quantize. +// +// Dequantize: (input - zp) * scale +// Quantize: input / scale + zp +// +// Hence, +// output = (input - input_zp) * input_scale / output_scale + output_zp +// +// This is simplified as: +// output = input * merged_scale + merged_zp +// where: +// merged_zp = output_zp - input_zp * merged_scale. +// merged_scale = input_scale / output_scale. +Value requantize(mlir::OpState op, Value input, QuantType inputQuantType, + QuantType outputQuantType, TensorType outputTensorType, + ConversionPatternRewriter &rewriter) { + // Skip requantization when input and result have the same type. + if (inputQuantType == outputQuantType) { + return rewriter.create(op->getLoc(), outputTensorType, + input); + } + + auto floatTensorType = outputTensorType.clone(rewriter.getF32Type()); + Value outputFloat = + rewriter.create(op->getLoc(), floatTensorType, input); + + outputFloat = applyMergedScalesAndZps(rewriter, op->getLoc(), inputQuantType, + outputQuantType, outputFloat); + + // Clamp output if the output integer bit-width <32. + if (outputTensorType.getElementType().cast().getWidth() < 32) { + Value quantizationMin, quantizationMax; + getQuantizationStorageInfo(rewriter, op->getLoc(), outputQuantType, + quantizationMin, quantizationMax); + // Clamp results by [quantizationMin, quantizationMax]. + outputFloat = rewriter.create(op->getLoc(), quantizationMin, + outputFloat, quantizationMax); + } + + outputFloat = rewriter.create( + op->getLoc(), floatTensorType, outputFloat); + return rewriter.create(op->getLoc(), outputTensorType, + outputFloat); +} + +class ConvertUniformQuantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputElementType = getElementTypeOrSelf(op.getOperand().getType()); + if (inputElementType.isF32()) { + auto quantType = getQuantType(op.getResult().getType()); + if (succeeded(quantType)) { + return matchAndRewriteQuantize(op, adaptor, rewriter, *quantType); + } + } else if (inputElementType.isa()) { + auto inputQuantType = getQuantType(inputElementType); + auto outputQuantType = getQuantType(op.getResult().getType()); + if (succeeded(inputQuantType) && succeeded(outputQuantType)) { + if (isPerChannelType(*inputQuantType) && + isPerChannelType(*outputQuantType) && + getPerChannelType(*inputQuantType).getQuantizedDimension() != + getPerChannelType(*outputQuantType).getQuantizedDimension()) { + op->emitError("Cannot requantize while changing quantization_axis"); + return failure(); + } + return matchAndRewriteRequantize(op, adaptor, rewriter, *inputQuantType, + *outputQuantType); + } + } + op->emitError("Unsupported input element type."); + return failure(); + } + + LogicalResult matchAndRewriteQuantize(mhlo::UniformQuantizeOp op, + mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + QuantType quantType) const { + Value scales, zeroPoints; + DenseI64ArrayAttr broadcastDims; + getQuantizationParams(rewriter, op->getLoc(), quantType, scales, zeroPoints, + /*outputZeroPointInFp=*/true, broadcastDims); + + Value quantizationMin, quantizationMax; + getQuantizationStorageInfo(rewriter, op->getLoc(), quantType, + quantizationMin, quantizationMax); + + auto resFloatTensorType = + op.getOperand().getType().clone(rewriter.getF32Type()); + Value resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, adaptor.getOperand(), scales, + broadcastDims); + resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resFloat, zeroPoints, broadcastDims); + + resFloat = rewriter.create(op->getLoc(), resFloatTensorType, + quantizationMin, resFloat, + quantizationMax); + resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resFloat); + auto resFinalTensorType = resFloatTensorType.clone( + getQuantStorageType(op.getResult().getType().getElementType())); + rewriter.replaceOpWithNewOp(op, resFinalTensorType, + resFloat); + return success(); + } + + LogicalResult matchAndRewriteRequantize( + mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, QuantType inputQuantType, + QuantType outputQuantType) const { + rewriter.replaceOp( + op, + requantize(op, adaptor.getOperand(), inputQuantType, outputQuantType, + /*outputTensorType=*/ + op.getResult().getType().cast().clone( + getQuantStorageType(outputQuantType)), + rewriter)); + return success(); + } +}; + +class ConvertUniformDequantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformDequantizeOp op, mhlo::UniformDequantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto quantType = getQuantType(op.getOperand().getType()); + if (failed(quantType)) { + return failure(); + } + Value scales, zeroPoints; + DenseI64ArrayAttr broadcastDims; + getQuantizationParams(rewriter, op->getLoc(), *quantType, scales, + zeroPoints, + /*outputZeroPointInFp=*/false, broadcastDims); + + Value input = adaptor.getOperand(); + // TODO: b/260280919 - Consider avoiding conversion to int32. + auto resInt32TensorType = + input.getType().cast().clone(rewriter.getI32Type()); + Value resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, input); + resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, resInt32, zeroPoints, broadcastDims); + auto resFloatTensorType = + resInt32.getType().cast().clone(rewriter.getF32Type()); + Value resFloat = rewriter.create( + op->getLoc(), resFloatTensorType, resInt32); + resFloat = rewriter.replaceOpWithNewOp( + op, resFloatTensorType, resFloat, scales, broadcastDims); + return success(); + } +}; + +class ConvertUniformQuantizedAddOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto lhsQuantType = + getQuantType(getElementTypeOrSelf(op.getLhs().getType())); + auto rhsQuantType = + getQuantType(getElementTypeOrSelf(op.getRhs().getType())); + auto resQuantType = + getQuantType(getElementTypeOrSelf(op.getResult().getType())); + + // We only handle cases where lhs, rhs and results all have quantized + // element type. + if (failed(lhsQuantType) || failed(rhsQuantType) || failed(resQuantType)) { + op->emitError( + "AddOp requires the quantized element type for all operands and " + "results"); + return failure(); + } + + if (isPerChannelType(*lhsQuantType) || isPerChannelType(*rhsQuantType) || + isPerChannelType(*resQuantType)) { + // Handle Per-Channel Quantized Types. We only support lhs/rhs/result with + // exact same per-channel quantized types with I32 storage type. + if (!isPerChannelType(*lhsQuantType) || + !isPerChannelType(*rhsQuantType) || + !isPerChannelType(*resQuantType) || + getPerChannelType(*lhsQuantType) != + getPerChannelType(*rhsQuantType) || + getPerChannelType(*lhsQuantType) != + getPerChannelType(*resQuantType)) { + op->emitError( + "Per-channel quantized AddOp requires the same quantized element " + "type for all operands and results"); + return failure(); + } + if (!getPerChannelType(*lhsQuantType).getStorageType().isInteger(32)) { + // For server-side StableHLO Quantization, add is quantized only when + // fused with conv/dot ops, whose output must be i32. + op->emitError("Per-channel quantized AddOp requires i32 storage type"); + return failure(); + } + return matchAndRewritePerChannel(op, adaptor, rewriter, + getPerChannelType(*lhsQuantType)); + } + + // TODO: b/260280919 - Consider avoiding conversion to int32. + auto resInt32TensorType = + op.getResult().getType().clone(rewriter.getI32Type()); + + // When lhs, rhs and result have different scale and zps, requantize them to + // be the same as the result. + // TODO: b/260280919 - Consider avoiding conversion to int32. + Value lhs = adaptor.getLhs(); + Value lhsInt32Tensor = requantize(op, lhs, *lhsQuantType, *resQuantType, + resInt32TensorType, rewriter); + + Value rhs = adaptor.getRhs(); + Value rhsInt32Tensor = requantize(op, rhs, *rhsQuantType, *resQuantType, + resInt32TensorType, rewriter); + + Value zeroPoint = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getZeroPoint()))); + + // Now the lhs and rhs have been coverted to the same scale and zps. + // Given: + // lhs_fp = (lhs_quant - zp) * scale + // rhs_fp = (rhs_quant - zp) * scale + // res_fp = lhs_fp + rhs_fp + // = ((lhs_quant + rhs_quant - zp) - zp) * scale + // res_quant = res_fp / scale + zp + // = lhs_quant + rhs_quant - zp + // The following add the inputs and then substract by zero point. + Value addResult = rewriter.create( + op->getLoc(), resInt32TensorType, lhsInt32Tensor, rhsInt32Tensor, + nullptr); + Value resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, addResult, zeroPoint, nullptr); + + if (getQuantStorageType(*resQuantType).isInteger(32)) { + // For i32, clamping is not needed. + rewriter.replaceOp(op, resInt32); + } else { + // Clamp results by [quantizationMin, quantizationMax] when storage type + // is not i32. + Value resultQuantizationMin = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getStorageTypeMin()))); + Value resultQuantizationMax = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + getPerTensorType(*resQuantType).getStorageTypeMax()))); + resInt32 = rewriter.create( + op->getLoc(), resInt32TensorType, resultQuantizationMin, resInt32, + resultQuantizationMax); + // Convert results back to result storage type. + auto resFinalTensorType = + resInt32TensorType.clone(getQuantStorageType(*resQuantType)); + rewriter.replaceOpWithNewOp(op, resFinalTensorType, + resInt32); + } + + return success(); + } + + LogicalResult matchAndRewritePerChannel( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + quant::UniformQuantizedPerAxisType quantType) const { + // We assume lhs/rhs/result have the same quantized type with i32 storage. + Value addResult = rewriter.create( + op->getLoc(), adaptor.getLhs(), adaptor.getRhs()); + // Add zp contribution if it is non-zero for any channel. + if (llvm::any_of(quantType.getZeroPoints(), + [](int64_t zp) { return zp != 0; })) { + SmallVector zpsVec(quantType.getZeroPoints().begin(), + quantType.getZeroPoints().end()); + Value zps = rewriter.create( + op->getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(zpsVec.size())}, + rewriter.getI32Type()), + zpsVec)); + addResult = rewriter.create( + op->getLoc(), addResult, zps, + rewriter.getDenseI64ArrayAttr( + {static_cast(quantType.getQuantizedDimension())})); + } + rewriter.replaceOp(op, addResult); + return success(); + } +}; + +// This is a convenient struct for holding dimension numbers for dot-like ops +// including DotGeneral and Convolution. So that we can share code for all +// dot-like ops. +// For Convolution, only NHWC format is supported. +// For DotGeneral, there is no contracting dims. The batching and contracting +// dimensions are defined in +// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general. +struct DotLikeDimensionNumbers { + SmallVector lhsBatchingDims; + SmallVector lhsSpatialDims; + SmallVector lhsContractingDims; + SmallVector rhsBatchingDims; + SmallVector rhsSpatialDims; + SmallVector rhsContractingDims; +}; + +// Checks if zero points of the given quantized type are zero. +bool isZeroPointZero(QuantType type) { + if (isPerTensorType(type)) { + return getPerTensorType(type).getZeroPoint() == 0; + } + if (isPerChannelType(type)) { + ArrayRef zeroPoints = getPerChannelType(type).getZeroPoints(); + return llvm::all_of(zeroPoints, [](int64_t zp) { return zp == 0; }); + } + return false; +} + +// A shared matchAndRewrite implementation for dot-like hybrid quantized +// operators. Hybrid ops are currently only interpreted as weight-only +// quantization ops, this might change in the future. +// +// All attrs of the original op are preserved after the conversion. +template +LogicalResult matchAndRewriteDotLikeHybridOp( + OpType &op, OpAdaptorType &adaptor, ConversionPatternRewriter &rewriter) { + // For dot like hybrid ops, lhs is float type, rhs is uniform + // quantized type and result is float type. + // For weight-only quantization: + // result = hybridOp(lhs, dequant(rhs)) + Value lhsFloat32Tensor = adaptor.getLhs(); + // Insert optimization_barrier to prevent constant folding of dequantize + + // quantized weights. + auto barrier = rewriter.create(op->getLoc(), + adaptor.getRhs()); + Operation::result_range resultRange = barrier.getResults(); + Value rhs = resultRange.front(); + FailureOr rhsElementQuantType = + getQuantType(op.getRhs().getType()); + if (failed(rhsElementQuantType)) { + return failure(); + } + auto resFloat32TensorType = + op.getResult().getType().template cast(); + auto rhsFloat32TensorType = + op.getRhs().getType().template cast().clone( + rewriter.getF32Type()); + + // Get scales and zero points for rhs. + Value rhsScale, rhsZeroPoint; + DenseI64ArrayAttr broadcastDims; + getQuantizationParams(rewriter, op->getLoc(), *rhsElementQuantType, rhsScale, + rhsZeroPoint, + /*outputZeroPointInFp=*/true, broadcastDims); + + // Dequantize rhs_float32_tensor. + Value rhsFloat32Tensor = + rewriter.create(op->getLoc(), rhsFloat32TensorType, rhs); + + // Subtract zero points only when it is not zero. + if (!isZeroPointZero(*rhsElementQuantType)) { + rhsFloat32Tensor = rewriter.create( + op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint, + broadcastDims); + } + rhsFloat32Tensor = rewriter.create( + op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScale, + broadcastDims); + + // Execute conversion target op. + SmallVector operands{lhsFloat32Tensor, rhsFloat32Tensor}; + rewriter.replaceOpWithNewOp(op, resFloat32TensorType, operands, + op->getAttrs()); + return success(); +} + +Value createZeroPointPartialOffset(OpBuilder &builder, Location loc, + Value tensor, const int64_t otherTensorZp, + SmallVector reductionDims) { + // This function calculates part of the zero-point-offset by using + // mhlo::Reduce to sum over the contracting dims of the tensor, and then + // multiply by zp of the other tensor. + auto outputElementType = builder.getI32Type(); + + // Calculate the output tensor shape. This is input tensor dims minus + // contracting dims. + auto rankedTensor = tensor.getType().cast(); + SmallVector outputDims; + for (int64_t i = 0; i < rankedTensor.getRank(); ++i) { + if (llvm::count(reductionDims, i) == 0) { + outputDims.push_back(rankedTensor.getDimSize(i)); + } + } + + // Convert input tensor to output type since mhlo::Reduce only supports same + // element type for input/output. + tensor = builder.create( + loc, tensor.getType().cast().clone(outputElementType), + tensor); + auto reducerTensorType = RankedTensorType::get({}, outputElementType); + + // Initial value for reduced tensor. This is set 0. + Value initValues = builder.create( + loc, DenseIntElementsAttr::get(reducerTensorType, {0})); + mhlo::ReduceOp reduce = builder.create( + loc, RankedTensorType::get(outputDims, outputElementType), tensor, + initValues, builder.getI64TensorAttr(reductionDims)); + // Define reducer function to compute sum. + Region ®ion = reduce.getBody(); + Block &block = region.emplaceBlock(); + block.addArgument(reducerTensorType, loc); + block.addArgument(reducerTensorType, loc); + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&block); + Value sum = + builder.create(loc, *firstArgument, *secondArgument); + builder.create(loc, sum); + } + Value zp = builder.create( + loc, builder.getI32IntegerAttr(otherTensorZp)); + Value mulOp = builder.create(loc, reduce.getResult(0), + zp, nullptr); + return mulOp; +} + +Value getDimValue(OpBuilder &builder, Location loc, Value tensor, + mlir::ShapedType tensorShape, int64_t idx) { + if (tensorShape.isDynamicDim(idx)) { + // Get dynamic dim using GetDimensionSizeOp and convert result from to + // <1xi64>. + Value dynamicDim = builder.create( + loc, tensor, builder.getI64IntegerAttr(idx)); + dynamicDim = builder.create( + loc, RankedTensorType::get(ArrayRef{}, builder.getI64Type()), + dynamicDim); + return builder.create( + loc, RankedTensorType::get({1}, builder.getI64Type()), dynamicDim); + } + return builder.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({1}, builder.getI64Type()), + {tensorShape.getDimSize(idx)})); +} + +Value calculateDynamicOutputDims(OpBuilder &builder, Location loc, Value output, + ShapedType outputTensorType) { + // Calculate each output tensor dim and concatenate into a 1D tensor. + SmallVector outputDims; + for (int64_t i = 0; i < outputTensorType.getRank(); ++i) { + outputDims.push_back( + getDimValue(builder, loc, output, outputTensorType, i)); + } + return builder.create(loc, outputDims, + builder.getI64IntegerAttr(0)); +} + +Value broadcastZpContribution(OpBuilder &builder, Location loc, + Value zpContribution, + ArrayRef reductionDims, + ArrayRef batchingDims, + int64_t nonBatchingStartingIdx, Value output, + TensorType outputTensorType, + Value &outputDimsValue) { + // This function calculates the dims for broadcasting from the + // zero-point-offset tensor to the final output tensor, and then do the + // broadcast. + auto zpContributionRank = + zpContribution.getType().cast().getRank(); + SmallVector broadcastDims; + broadcastDims.resize(zpContributionRank, 0); + // Result tensor will have batching dims first, then LHS result dims, then + // RHS result dims. So non-batching result dims index doesn't start from 0. + // The arg non_batching_starting_idx is used to distinguish LHS and RHS. + int64_t resultBatchingIdx = 0; + int64_t resultNonBatchingIdx = nonBatchingStartingIdx; + for (int64_t idx = 0, originalIdx = 0; idx < zpContributionRank; + ++idx, ++originalIdx) { + // zp_contribution has removed contracting/spatial dims from the tensor + // after reduction. The following recovers the index in the original tensor. + while (llvm::count(reductionDims, originalIdx) != 0) { + originalIdx++; + } + if (llvm::count(batchingDims, originalIdx) == 0) { + broadcastDims[idx] = resultNonBatchingIdx++; + } else { + broadcastDims[idx] = resultBatchingIdx++; + } + } + // Use broadcast_in_dim or dyanmic_broadcast_in_dim based on output shape + // dynamism. + if (outputTensorType.cast().hasStaticShape()) { + zpContribution = builder.create( + loc, outputTensorType, zpContribution, + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(broadcastDims.size())}, + builder.getI64Type()), + broadcastDims)); + } else { + if (!outputDimsValue) { + outputDimsValue = + calculateDynamicOutputDims(builder, loc, output, outputTensorType); + } + zpContribution = builder.create( + loc, outputTensorType, zpContribution, outputDimsValue, + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(broadcastDims.size())}, + builder.getI64Type()), + broadcastDims)); + } + return zpContribution; +} + +Value calculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, + Value rhs, Value output, int64_t lhsZp, + int64_t rhsZp, TensorType outputTensorType, + const DotLikeDimensionNumbers &dims) { + mlir::ShapedType lhsShape = lhs.getType().cast(); + mlir::ShapedType rhsShape = rhs.getType().cast(); + Value result = nullptr; + Value outputDimsValue = nullptr; + // Calculate LHS contribution when RHS zp is non-zero. + if (rhsZp != 0) { + SmallVector reductionDims = to_vector(llvm::concat( + dims.lhsSpatialDims, dims.lhsContractingDims)); + Value lhsZpContribution = + createZeroPointPartialOffset(builder, loc, lhs, rhsZp, reductionDims); + // Broadcast lhs ZP contribution to result tensor shape. + lhsZpContribution = broadcastZpContribution( + builder, loc, lhsZpContribution, reductionDims, dims.lhsBatchingDims, + dims.lhsBatchingDims.size(), output, outputTensorType, outputDimsValue); + result = lhsZpContribution; + } + // Calculate RHS contribution when LHS zp is non-zero. + if (lhsZp != 0) { + SmallVector reductionDims = to_vector(llvm::concat( + dims.rhsSpatialDims, dims.rhsContractingDims)); + Value rhsZpContribution = + createZeroPointPartialOffset(builder, loc, rhs, lhsZp, reductionDims); + // Broadcast rhs ZP contribution to result tensor shape. + rhsZpContribution = broadcastZpContribution( + builder, loc, rhsZpContribution, reductionDims, dims.rhsBatchingDims, + lhsShape.getRank() - dims.lhsContractingDims.size(), output, + outputTensorType, outputDimsValue); + if (result) { + result = builder.create(loc, result, rhsZpContribution); + } else { + result = rhsZpContribution; + } + } + + if (lhsZp != 0 && rhsZp != 0) { + // Contributions from LHS_ZP * RHS_ZP. + // This is multiplied by the product of all contracting dimensions. + int32_t contractingDimTotalInt = 1; + bool hasDynamicContractingDim = false; + Value dynamicContractingDimTotal = builder.create( + loc, builder.getI32IntegerAttr(static_cast(1))); + // Calculate the product for static/dynamic dims separately. + for (int64_t rhsIdx : llvm::concat( + dims.rhsSpatialDims, dims.rhsContractingDims)) { + if (rhsShape.isDynamicDim(rhsIdx)) { + hasDynamicContractingDim = true; + auto dim = builder.create( + loc, rhs, builder.getI64IntegerAttr(rhsIdx)); + dynamicContractingDimTotal = + builder.create(loc, dynamicContractingDimTotal, dim); + } else { + contractingDimTotalInt *= rhsShape.getDimSize(rhsIdx); + } + } + Value zpOffsetValue = builder.create( + loc, builder.getI32IntegerAttr(static_cast(lhsZp) * + static_cast(rhsZp) * + contractingDimTotalInt)); + // Multiply the static dims contribution by the dynamic one if needed. + if (hasDynamicContractingDim) { + zpOffsetValue = builder.create(loc, zpOffsetValue, + dynamicContractingDimTotal); + } + result = builder.create(loc, result, zpOffsetValue, + nullptr); + } + return result; +} + +// Generic function to create DotGeneral kernel for Dot/DotGeneral ops. +template +Value createDotLikeKernel(OpBuilder &builder, Location loc, DotLikeOp, + Type resultType, Value &lhs, Value &rhs, + ArrayRef attrs) { + return builder.create(loc, resultType, + ArrayRef{lhs, rhs}, attrs); +} + +// Template specialization for Convolution op. +// This function may pad LHS if needed. If so, lhs is updated in place. +template <> +Value createDotLikeKernel(OpBuilder &builder, Location loc, + mhlo::ConvolutionOp op, + Type resultType, Value &lhs, + Value &rhs, + ArrayRef attrs) { + // We only handle the case where RHS zp is zero. + // Explicitly pad LHS with zp and update LHS value. + SmallVector newAttrs(attrs); + if (op.getPadding().has_value() && + llvm::any_of(op.getPaddingAttr().getValues(), + [](int64_t x) { return x != 0; })) { + auto originalPadding = op.getPaddingAttr().getValues(); + + Value zp = builder.create( + loc, + DenseIntElementsAttr::get( + RankedTensorType::get({}, builder.getI8Type()), + {static_cast(getElementTypeOrSelf(op.getLhs().getType()) + .cast() + .getZeroPoint())})); + // Convert Padding attributes from mhlo::Convolution to mhlo::Pad. Note that + // Padding is applied for spatial dimensions [1...rank-1) only for + // mhlo::Convolution. But mhlo::Pad require those for all dimensions. Hence + // we add 0 to the beginning and end of the padding vectors. + int64_t rank = lhs.getType().cast().getRank(); + SmallVector paddingLow(rank, 0), paddingHigh(rank, 0), + paddingInterior(rank, 0); + for (int64_t i = 1; i < rank - 1; ++i) { + paddingLow[i] = originalPadding[i * 2 - 2]; + paddingHigh[i] = originalPadding[i * 2 - 1]; + } + lhs = builder.create( + loc, lhs, zp, + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), paddingLow), + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), paddingHigh), + DenseIntElementsAttr::get( + RankedTensorType::get({rank}, builder.getI64Type()), + paddingInterior)); + + // After explicitly padding/dilating LHS, update attributes so that LHS is + // not padded/dilated again during Convolution. + for (auto &attr : newAttrs) { + if (attr.getName().getValue() == "padding") { + attr.setValue(SplatElementsAttr::get( + RankedTensorType::get({rank - 2, 2}, builder.getI64Type()), + builder.getI64IntegerAttr(0))); + } + } + } + return builder.create( + loc, resultType, ArrayRef{lhs, rhs}, newAttrs); +} + +template +LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, + ArrayRef attrs, + const DotLikeDimensionNumbers &dims, + ConversionPatternRewriter &rewriter) { + // Lower Dot/DotGeneral UQ ops to DotGeneral int. + // Assumes that operands and results are uq types. + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + auto resInt32TensorType = + op.getResult().getType().clone(rewriter.getI32Type()); + + // Dot result + // = dot((lhs - zp_l) * scale_l, (rhs - zp_r) * scale_r) / scale_res + // + zp_res + // = dot(lhs - zp_l, rhs - zp_r) * scale_l * scale_r / scale_res + zp_res + // = dot(lhs, rhs) * combined_scale + combined_zp + // where: + // combined_scale = scale_l * scale_r / scale_res + // combined_zp = res_zp - zp_offset * combined_scale + // zp_offset = zp_l*rhs + zp_r*lhs - zp_l*zp_r + Value resI32 = createDotLikeKernel(rewriter, op->getLoc(), op, + resInt32TensorType, lhs, rhs, attrs); + + auto lhsElementQuantType = getElementTypeOrSelf(op.getLhs().getType()) + .template cast(); + auto rhsElementQuantType = + getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto rhsElementQuantPerChannelType = + getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto resElementQuantType = + getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); + auto resElementQuantPerChannelType = + getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); + + // Here we assume LHS must be per-tensor quantized. + // If RHS is per-channel quantized, it must has 0 zp. + Value zpOffset = calculateZeroPointOffset( + rewriter, op->getLoc(), lhs, rhs, resI32, + lhsElementQuantType.getZeroPoint(), + (rhsElementQuantType ? rhsElementQuantType.getZeroPoint() : 0), + resInt32TensorType, dims); + + // For per-channel quantization, we assume that result scales are proportional + // to rhs scales for each channels. + double combinedScaleFp = + rhsElementQuantType + ? lhsElementQuantType.getScale() * rhsElementQuantType.getScale() / + resElementQuantType.getScale() + : lhsElementQuantType.getScale() * + rhsElementQuantPerChannelType.getScales()[0] / + resElementQuantPerChannelType.getScales()[0]; + + // Multiply dot result and zp_offset by combined_scale only if it is not 1.0. + if (std::abs(combinedScaleFp - 1.0) > 0.001) { + Value combinedScale = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(combinedScaleFp)); + + auto resFloat32TensorType = + op.getResult().getType().clone(rewriter.getF32Type()); + Value resF32 = rewriter.create( + op->getLoc(), resFloat32TensorType, resI32); + resF32 = rewriter.create( + op->getLoc(), resFloat32TensorType, resF32, combinedScale, nullptr); + resI32 = rewriter.create(op->getLoc(), resInt32TensorType, + resF32); + + // Skip zp_offset if it is 0. + if (zpOffset) { + auto zpOffsetFloat32TensorType = + zpOffset.getType().cast().clone(rewriter.getF32Type()); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType, zpOffset); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType, zpOffset, combinedScale, + nullptr); + zpOffset = rewriter.create( + op->getLoc(), zpOffsetFloat32TensorType.clone(rewriter.getI32Type()), + zpOffset); + } + } + + // If result is per-channel quantized, it must has 0 zp. + Value combinedZp = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr( + resElementQuantType ? resElementQuantType.getZeroPoint() : 0)); + if (zpOffset) { + combinedZp = rewriter.create( + op->getLoc(), resInt32TensorType, combinedZp, zpOffset, nullptr); + } + rewriter.replaceOpWithNewOp( + op, resInt32TensorType, resI32, combinedZp, nullptr); + return success(); +} + +template +FailureOr isDotLikeOpHybrid(DotLikeOp op) { + // Checks whether a dot-like op is hybrid by looking at input/output types. + // Returns failure() when the type is not supported. + bool isLhsQuant = isa( + getElementTypeOrSelf(op.getLhs().getType())); + bool isLhsQuantPerChannel = isa( + getElementTypeOrSelf(op.getLhs().getType())); + bool isRhsQuant = isa( + getElementTypeOrSelf(op.getRhs().getType())); + bool isRhsQuantPerChannel = isa( + getElementTypeOrSelf(op.getRhs().getType())); + bool isResQuant = + isa(getElementTypeOrSelf(op.getResult())); + bool isResQuantPerChannel = isa( + getElementTypeOrSelf(op.getResult())); + + if (isLhsQuant && ((isRhsQuant && isResQuant) || + (isRhsQuantPerChannel && isResQuantPerChannel))) { + // For quantized ops, RHS and result must be both per-channel quantized or + // both per-tensor quantized. + return false; + } + if (!isLhsQuant && !isLhsQuantPerChannel && + (isRhsQuant || isRhsQuantPerChannel) && !isResQuant && + !isResQuantPerChannel) { + return true; + } + op->emitError("Invalid input/output type for Dot/Convolution op"); + return failure(); +} + +class ConvertUniformQuantizedDotOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DotOp op, mhlo::DotOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } // DotOp is a special case of DotGeneralOp, where LHS and RHS are both + // rank-2 tensors and have contracting dims of 1 and 0 respectively. + auto dims = mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1}, + /*rhsContractingDimensions=*/{0}); + SmallVector attrs(op->getAttrs()); + attrs.push_back( + {StringAttr::get(rewriter.getContext(), "dot_dimension_numbers"), + dims}); + return matchAndRewriteDotLikeOp( + op, adaptor, attrs, + DotLikeDimensionNumbers{/*lhs_batching_dims=*/{}, + /*lhs_spatial_dims=*/{}, + /*lhs_contracting_dims=*/{1}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{}, + /*rhs_contracting_dims=*/{0}}, + rewriter); + } +}; + +class ConvertUniformQuantizedDotGeneralOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp op, mhlo::DotGeneralOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } + return matchAndRewriteDotLikeOp( + op, adaptor, op->getAttrs(), + DotLikeDimensionNumbers{ + to_vector(op.getDotDimensionNumbers().getLhsBatchingDimensions()), + /*lhs_spatial_dims=*/{}, + to_vector( + op.getDotDimensionNumbers().getLhsContractingDimensions()), + to_vector(op.getDotDimensionNumbers().getRhsBatchingDimensions()), + /*rhs_spatial_dims=*/{}, + to_vector( + op.getDotDimensionNumbers().getRhsContractingDimensions())}, + rewriter); + } +}; + +bool isConvNhwc(const mhlo::ConvDimensionNumbersAttr &dims) { + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 3 && + dims.getInputSpatialDimensions().size() == 2 && + dims.getInputSpatialDimensions()[0] == 1 && + dims.getInputSpatialDimensions()[1] == 2 && + dims.getKernelInputFeatureDimension() == 2 && + dims.getKernelOutputFeatureDimension() == 3 && + dims.getKernelSpatialDimensions().size() == 2 && + dims.getKernelSpatialDimensions()[0] == 0 && + dims.getKernelSpatialDimensions()[1] == 1 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 3 && + dims.getOutputSpatialDimensions().size() == 2 && + dims.getOutputSpatialDimensions()[0] == 1 && + dims.getOutputSpatialDimensions()[1] == 2; +} + +bool isConvNDHWC(const mhlo::ConvDimensionNumbersAttr &dims) { + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 4 && + dims.getInputSpatialDimensions().size() == 3 && + dims.getInputSpatialDimensions()[0] == 1 && + dims.getInputSpatialDimensions()[1] == 2 && + dims.getInputSpatialDimensions()[2] == 3 && + dims.getKernelInputFeatureDimension() == 3 && + dims.getKernelOutputFeatureDimension() == 4 && + dims.getKernelSpatialDimensions().size() == 3 && + dims.getKernelSpatialDimensions()[0] == 0 && + dims.getKernelSpatialDimensions()[1] == 1 && + dims.getKernelSpatialDimensions()[2] == 2 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 4 && + dims.getOutputSpatialDimensions().size() == 3 && + dims.getOutputSpatialDimensions()[0] == 1 && + dims.getOutputSpatialDimensions()[1] == 2 && + dims.getOutputSpatialDimensions()[2] == 3; +} + +FailureOr verifyAndConstructDims( + mhlo::ConvolutionOp op) { + // RHS (weight) must have zero zp. + // Here assumes RHS/result must be both per-tensor or both per-channel + // quantized. + auto failedOr = getQuantType(op.getRhs().getType()); + if (failed(failedOr)) { + return failure(); + } + QuantType rhsElementQuantType = *failedOr; + bool isRhsQuantPerTensor = + std::get_if(&rhsElementQuantType); + + if (isRhsQuantPerTensor + ? (std::get(rhsElementQuantType) + .getZeroPoint() != 0) + : llvm::any_of(llvm::concat( + std::get( + rhsElementQuantType) + .getZeroPoints(), + getElementTypeOrSelf(op.getResult()) + .cast() + .getZeroPoints()), + [](int64_t zp) { return zp != 0; })) { + op->emitError("RHS/result UQ type must have zero zp."); + return failure(); + } + // For per-channel quantization, RHS quantized axis must be out channel axis. + if (!isRhsQuantPerTensor && + (std::get(rhsElementQuantType) + .getQuantizedDimension() != + op.getRhs().getType().cast().getRank() - 1)) { + op->emitError("Conv quantized axis must be out channel axis"); + return failure(); + } + // For per-channel quantization, ratio between RHS and Result scales must be + // the same for each channel. + if (!isRhsQuantPerTensor) { + auto resElementQuantPerChannelType = + getElementTypeOrSelf(op.getResult()) + .cast(); + SmallVector scaleRatios( + resElementQuantPerChannelType.getScales().size()); + for (size_t i = 0; i < scaleRatios.size(); ++i) { + scaleRatios[i] = + resElementQuantPerChannelType.getScales()[i] / + std::get(rhsElementQuantType) + .getScales()[i]; + auto diff = (scaleRatios[i] - scaleRatios[0]) / scaleRatios[0]; + // Check all ratios within a threshold. + if (std::abs(diff) > 0.001) { + op->emitError( + "Per-channel quantizated Conv must have same RHS/Result scale " + "ratio for each channel"); + return failure(); + } + } + } + // lhs_dilation must not exist. + if (op.getLhsDilation().has_value() && + llvm::any_of(op.getLhsDilationAttr().getValues(), + [](int64_t dilate) { return dilate != 1; })) { + op->emitError("lhs_dilation must be 1."); + return failure(); + } + + // We only support NHWC Conv2D and NDHWC Conv3D. + auto dims = op.getDimensionNumbers(); + if (isConvNhwc(dims)) { + // 2D Convolution. + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{1, 2}, + /*lhs_contracting_dims=*/{3}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{0, 1}, + /*rhs_contracting_dims=*/{2}}; + } + if (isConvNDHWC(dims)) { + // 3D Convolution. + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{1, 2, 3}, + /*lhs_contracting_dims=*/{4}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{0, 1, 2}, + /*rhs_contracting_dims=*/{3}}; + } + op->emitError("Convolution data format must be NHWC."); + return failure(); +} + +class ConvertUniformQuantizedConvolutionOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, mhlo::ConvolutionOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isHybrid = isDotLikeOpHybrid(op); + if (failed(isHybrid)) { + return failure(); + } + if (*isHybrid) { + return matchAndRewriteDotLikeHybridOp(op, adaptor, rewriter); + } + auto dims = verifyAndConstructDims(op); + if (failed(dims)) return failure(); + return matchAndRewriteDotLikeOp(op, adaptor, op->getAttrs(), *dims, + rewriter); + } +}; + +// This pattern lowers a generic MHLO op for uq->int. +// This pattern essentially just performs type change, with no algorithm change. +// TODO: b/310685906 - Add operand/result type validations. +class ConvertGenericOp : public ConversionPattern { + public: + explicit ConvertGenericOp(MLIRContext *ctx, TypeConverter &converter) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // This pattern only handle selected ops. + if (!isa(op)) { + return failure(); + } + + // Determine new result type: use storage type for uq types; use original + // type otherwise. + SmallVector newResultTypes; + for (auto resultType : op->getResultTypes()) { + newResultTypes.push_back(getQuantStorageType(resultType)); + } + + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResultTypes, op->getAttrs(), op->getSuccessors()); + for (Region ®ion : op->getRegions()) { + Region &newRegion = *state.addRegion(); + rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); + if (failed( + rewriter.convertRegionTypes(&newRegion, *getTypeConverter()))) { + return failure(); + } + } + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +// TypeConverter for converting UQ type to int type. +class UniformQuantizedToIntTypeConverter : public TypeConverter { + public: + UniformQuantizedToIntTypeConverter() { + addConversion([](Type type) -> Type { return getQuantStorageType(type); }); + } +}; + +#define GEN_PASS_DEF_MHLOQUANTLEGALIZETOINT +#include "mhlo/transforms/mhlo_passes.h.inc" + +class MhloQuantLegalizeToInt + : public impl::MhloQuantLegalizeToIntBase { + public: + // Performs conversion of MHLO quant ops to primitive ops. + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + + // Populate MHLO quant ops conversion patterns. + patterns.add(context); + + // uq->int convert patterns for func.func, func.return and generic ops. + UniformQuantizedToIntTypeConverter converter; + patterns.add(context, converter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateReturnOpTypeConversionPattern(patterns, converter); + + ConversionTarget target(*op->getContext()); + target.addIllegalDialect(); + auto isLegal = [&converter](Operation *op) { + return converter.isLegal(op); + }; + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect( + [&converter](Operation *op) { + if (auto func = dyn_cast(op)) { + return converter.isSignatureLegal(func.getFunctionType()); + } + return converter.isLegal(op); + }); + + LogicalResult result = + applyPartialConversion(op, target, std::move(patterns)); + if (failed(result)) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createMhloQuantLegalizeToIntPass() { + return std::make_unique(); +} + +} // namespace mlir::mhlo diff --git a/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc b/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc index e8f40574b4385..48d30f88e1a23 100644 --- a/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc b/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc index 8dead22b8547e..fbe739495a6f9 100644 --- a/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/passes.h b/xla/mlir_hlo/mhlo/transforms/passes.h index 946c56ae18ab1..b9a025cdb5586 100644 --- a/xla/mlir_hlo/mhlo/transforms/passes.h +++ b/xla/mlir_hlo/mhlo/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,10 +49,6 @@ std::unique_ptr> createLegalizeSortPass(); /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); -/// Lowers from the CHLO dialect to the HLO dialect. -std::unique_ptr> createChloLegalizeToHloPass( - bool legalizeBroadcasts = true, bool expandCompositions = true); - // Lowers from sparse ops in CHLO dialect to Linalg dialect. std::unique_ptr> createLegalizeSparseOperationsPass( bool legalizeToCustomCalls = true); @@ -144,15 +140,6 @@ std::unique_ptr> createConstraintFusionPass(); std::unique_ptr> createGroupReductionDimensionsPass( bool preferColumnsReductions = true); -/// Rank specialization passes: -/// - Find compatible operations and group them together in one rank -/// specialization cluster. -/// - Lower rank specialization clusters to SCF and ranked operations. -std::unique_ptr> -createRankSpecializationClusterPass(); -std::unique_ptr> createRankSpecializationToSCFPass( - int64_t maxTargetRank = 5); - std::unique_ptr> createOptimizeMhloPass(); std::unique_ptr> createLowerComplexPass(); @@ -196,7 +183,12 @@ std::unique_ptr> createHloLegalizeToStablehloPass(); std::unique_ptr> createStablehloLegalizeToHloPass(); // Legalizes from the Shape dialect to the MHLO dialect. -std::unique_ptr> createShapeLegalizeToHloPass(); +std::unique_ptr> createShapeLegalizeToHloPass( + bool legalizeConstraints = false); + +// Legalizes from MHLO quantized ops with MHLO quant types to MHLO primitive ops +// like int ops. +std::unique_ptr> createMhloQuantLegalizeToIntPass(); // Test passes. std::unique_ptr createTestInferShapedTypeMethodsPass(); diff --git a/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc index 01acfbf635eaf..83a3fc7f71b65 100644 --- a/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc +++ b/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. // This file implements logic for some optimizations to reduce size on export. +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -30,7 +32,9 @@ limitations under the License. #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" @@ -40,6 +44,8 @@ limitations under the License. namespace mlir { namespace mhlo { +constexpr char kShardingAttr[] = "mhlo.sharding"; + #define GEN_PASS_DEF_PREPAREFOREXPORTPASS #include "mhlo/transforms/mhlo_passes.h.inc" @@ -55,8 +61,8 @@ struct PrepareForExportPass // Materializes some splat before export because it may be more efficient in // HLOInstruction. void prepareConstantOp(Operation *op, SplatElementsAttr attr) { - // Arbitrarialy chosen "small" number. This could be chosen based on the - // proto size too. + // Arbitrarily chosen "small" number. This could be chosen based on the proto + // size too. if (attr.getNumElements() < 32) return; ShapedType returnType = op->getResultTypes().front().cast(); ImplicitLocOpBuilder b(op->getLoc(), op); @@ -72,6 +78,10 @@ void prepareConstantOp(Operation *op, SplatElementsAttr attr) { } auto broadcast = b.create(returnType, cst, b.getI64TensorAttr({})); + if (auto sharding = op->getAttrOfType(kShardingAttr)) { + // The added broadcast inherits the kShardingAttr from op. + broadcast->setAttr(kShardingAttr, sharding); + } op->replaceAllUsesWith(broadcast); op->erase(); } @@ -153,6 +163,36 @@ void prepareBroadcastInDim(BroadcastInDimOp bcast) { DenseIntElementsAttr::get(dims.getType(), transposedDim)); } +// Make implicitly captured constant explicit before exporting +void prepareExplicitCapturedConstants(Operation *op) { + for (Region ®ion : op->getRegions()) { + assert(region.getBlocks().size() == 1 && + "Only OPs with single block regions are allowed"); + llvm::SetVector implicitInputs; + // Get implicit inputs, i.e. those are used in the region + // but defined outside + getUsedValuesDefinedAbove(region, implicitInputs); + Block &block = region.getBlocks().front(); + OpBuilder builder(&block.front()); + for (Value input : implicitInputs) { + // If the captured value is defined by a constant OP, + // Create a clone constant OP within a block to make + // it explicit and replace uses within the block + Operation *definingOp = input.getDefiningOp(); + mlir::DenseElementsAttr attr; + if (matchPattern(input, m_Constant(&attr))) { + Operation *clonedOp = builder.clone(*definingOp); + // Find which uses belong to the block and replace + // with the cloned/explicit one + input.replaceUsesWithIf( + clonedOp->getResult(0), [&block](OpOperand &use) { + return block.getParentOp()->isProperAncestor(use.getOwner()); + }); + } + } + } +} + void PrepareForExportPass::runOnOperation() { getOperation().walk([&](Operation *op) { mlir::SplatElementsAttr attr; @@ -161,6 +201,11 @@ void PrepareForExportPass::runOnOperation() { if (auto whileOp = dyn_cast(op)) return prepareWhileOp(whileOp); if (auto bcastOp = dyn_cast(op)) return prepareBroadcastInDim(bcastOp); + // IfOp, CaseOp, WhileOp are already being handled during + // mhlo --> hlo translation. MapOp soon be deprecated. + if (mlir::isa(op)) + return prepareExplicitCapturedConstants(op); }); } diff --git a/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc b/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc deleted file mode 100644 index 8735eb22707a7..0000000000000 --- a/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc +++ /dev/null @@ -1,976 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { - -/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of -/// `mlir::Value`s. -static bool operator<(const Value &lhs, const Value &rhs) { - return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); -} - -namespace mhlo { - -#define GEN_PASS_DEF_RANKSPECIALIZATIONCLUSTERPASS -#define GEN_PASS_DEF_RANKSPECIALIZATIONTOSCFPASS -#include "mhlo/transforms/mhlo_passes.h.inc" - -namespace { - -/// Identify clusters of operations that can be rank-specialized together. The -/// required traits for clustered operations are: -/// - Element-wise: All operations in the group must be element-wise. This -/// allows to reshape operands before applying the operations as well as -/// reshaping the result to the desired shape afterwards. This way, we can, -/// e.g., apply unary ops to a completely flattened operand and restore the -/// original shape afterwards. -/// - Broadcasting semantics: All operations must implement broadcasting -/// semantics. Most importantly, this allows extending operand shapes such -/// that they match in rank. Operations that require all their operands to -/// be of the same shape also fulfill this requirement. -/// - Shape reification: All operations must implement -/// `InferShapedTypeOpInterface`. This is later needed to compute and to -/// restore the desired result shape. - -bool isClusterable(Operation *op) { - if (!llvm::isa(op)) return false; - if (op->getNumOperands() == 0) return false; - return (op->hasTrait() && - op->hasTrait()) || - op->hasTrait(); -} - -struct RankSpecializationClusterPattern : public RewritePattern { - explicit RankSpecializationClusterPattern(MLIRContext *ctx) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - // Only apply to operations that have not been clustered yet. - if (op->getParentOfType()) { - return failure(); - } - - // Only cluster when rank specialization is needed. - if (!isClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) { - return ty.isa(); - })) { - return failure(); - } - - // Collect all collectively rank specializable ops. - SmallVector cluster; - llvm::SmallSet operandSet; - llvm::SmallSet resultSet; - - Operation *rootOp = op; - while (rootOp->getNextNode() != nullptr && - isClusterable(rootOp->getNextNode())) - rootOp = rootOp->getNextNode(); - - Operation *it = rootOp; - while (it != nullptr && isClusterable(it)) { - // Find results that escape the cluster. - for (OpOperand &use : it->getUses()) { - if (!llvm::is_contained(cluster, use.getOwner())) - resultSet.insert(use.get()); - } - - // Update cluster operands. - for (OpResult v : it->getResults()) operandSet.erase(Value(v)); - for (OpOperand &v : it->getOpOperands()) operandSet.insert(v.get()); - - cluster.push_back(it); - it = it->getPrevNode(); - } - - // Create `RankSpecializationClusterOp`. - auto operands = llvm::to_vector<16>(operandSet); - auto results = llvm::to_vector<16>(resultSet); - auto resultTypes = llvm::to_vector<16>( - llvm::map_range(resultSet, [](Value v) { return v.getType(); })); - Location loc = op->getLoc(); - auto clusterOp = rewriter.create( - loc, resultTypes, operands); - - // Create body block. - auto operandTypes = llvm::to_vector<16>( - llvm::map_range(operandSet, [](Value v) { return v.getType(); })); - Block *block = - rewriter.createBlock(&clusterOp.getBody(), {}, operandTypes, - SmallVector(operandTypes.size(), loc)); - - // Copy operations into the body. - IRMapping bvm; - for (auto it : llvm::zip(operands, block->getArguments())) - bvm.map(std::get<0>(it), std::get<1>(it)); - rewriter.setInsertionPointToStart(block); - for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm); - - // Create `RankSpecializationClusterYieldOp`. - auto mappedResults = llvm::to_vector<16>( - llvm::map_range(results, [&](Value v) { return bvm.lookup(v); })); - rewriter.create(loc, mappedResults); - - // Replace original ops with the new results. - for (auto it : llvm::zip(results, clusterOp.getResults())) - bvm.map(std::get<0>(it), std::get<1>(it)); - for (Operation *it : cluster) { - if (it->getUses().empty()) { - rewriter.eraseOp(it); - continue; - } - auto replacements = llvm::to_vector<16>(llvm::map_range( - it->getResults(), [&](Value v) { return bvm.lookup(v); })); - rewriter.replaceOp(it, replacements); - } - - return success(); - } -}; - -struct MergeRankSpecializationClusterOpsPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, - PatternRewriter &rewriter) const override { - auto precedingOp = - llvm::dyn_cast_or_null( - op->getPrevNode()); - if (!precedingOp) return failure(); - Block *body = op.SingleBlock::getBody(); - Block *precedingBody = precedingOp.SingleBlock::getBody(); - auto yieldOp = llvm::dyn_cast( - op.SingleBlock::getBody()->getTerminator()); - auto precedingYieldOp = - llvm::dyn_cast( - precedingOp.SingleBlock::getBody()->getTerminator()); - - // Merge cluster operands. Consider only those operands of the second - // cluster that do not originate in the preceding cluster. - SmallVector newOperands; - for (Value v : precedingOp.getOperands()) newOperands.push_back(v); - for (Value v : op.getOperands()) { - if (v.getDefiningOp() != precedingOp && - !llvm::is_contained(precedingOp.getOperands(), v)) { - newOperands.push_back(v); - } - } - - // Merge cluster results. Consider only those results of the preceding - // cluster that are not exclusively used as operands to the second cluster. - SmallVector newUnmappedResults; - for (auto it : - llvm::zip(precedingOp.getResults(), precedingYieldOp.getResults())) { - Value result, innerResult; - std::tie(result, innerResult) = it; - if (!llvm::all_of(result.getUsers(), - [&](Operation *user) { return user == op; })) { - newUnmappedResults.push_back(innerResult); - } - } - for (Value v : yieldOp.getResults()) newUnmappedResults.push_back(v); - - // Create merged cluster op. - rewriter.setInsertionPoint(precedingOp); - auto loc = op.getLoc(); - auto resultTypes = llvm::to_vector<16>(llvm::map_range( - newUnmappedResults, [](Value v) { return v.getType(); })); - auto newOp = rewriter.create( - loc, resultTypes, newOperands); - auto operandTypes = llvm::to_vector<16>( - llvm::map_range(newOperands, [](Value v) { return v.getType(); })); - Block *newBody = - rewriter.createBlock(&newOp.getBody(), {}, operandTypes, - SmallVector(operandTypes.size(), loc)); - rewriter.setInsertionPointToStart(newBody); - - // Map operands and copy operations of the preceding cluster into the new - // body. - IRMapping bvm; - for (const auto &it : llvm::enumerate(precedingBody->getArguments())) - bvm.map(it.value(), newBody->getArgument(it.index())); - for (Operation &nestedOp : precedingBody->without_terminator()) - rewriter.clone(nestedOp, bvm); - - // Map operands and copy operations of the second cluster. If they result - // from the preceeding cluster, we can simply map the corresponding value - // internally. - for (auto it : llvm::zip(body->getArguments(), op.getOperands())) { - Value blockArg, operand; - std::tie(blockArg, operand) = it; - if (operand.getDefiningOp() == precedingOp) { - auto where = llvm::find(precedingOp.getResults(), operand); - assert(where.getBase() != nullptr && "expected to find "); - bvm.map(blockArg, - bvm.lookup(precedingYieldOp.getOperand(where.getIndex()))); - } else { - auto where = llvm::find(newOp.getOperands(), operand); - bvm.map(blockArg, newBody->getArgument(where.getIndex())); - } - } - for (Operation &nestedOp : body->without_terminator()) { - rewriter.clone(nestedOp, bvm); - } - - // Yield inner results. - rewriter.create( - loc, - llvm::to_vector<16>(llvm::map_range(newUnmappedResults, [&](Value v) { - return bvm.lookupOrDefault(v); - }))); - - // Replace the two cluster ops with the new corresponding results. - SmallVector precedingOpReplacements; - int64_t i = 0; - for (Value result : precedingOp.getResults()) { - Value replacement = nullptr; - if (!llvm::all_of(result.getUsers(), - [&](Operation *user) { return user == op; })) { - replacement = newOp->getResult(i++); - } - precedingOpReplacements.push_back(replacement); - } - ValueRange opReplacements = - newOp.getResults().take_back(op.getNumResults()); - rewriter.replaceOp(op, opReplacements); - rewriter.replaceOp(precedingOp, precedingOpReplacements); - - return success(); - } -}; - -struct RankSpecializationClusterPass - : public impl::RankSpecializationClusterPassBase< - RankSpecializationClusterPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - mhlo::populateRankSpecializationClusterPatterns(ctx, &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -/// Lower rank specialization cluster to SCF. - -bool isScalarTensorType(Type ty) { - auto rankedTy = ty.dyn_cast(); - return rankedTy && rankedTy.getRank() == 0; -} - -bool isScalarShapeType(Type ty) { - return ty.cast().getDimSize(0) == 0; -} - -Type deriveRankedTensorTypes(Type ty, int64_t rank) { - auto tensorTy = ty.dyn_cast(); - if (!tensorTy) return ty; - SmallVector shape(rank, ShapedType::kDynamic); - return RankedTensorType::get(shape, tensorTy.getElementType()); -} - -Type deriveUnrankedTensorTypes(Type ty) { - if (auto rankedTy = ty.dyn_cast()) - return UnrankedTensorType::get(rankedTy.getElementType()); - return ty; -} - -SmallVector materializeRankedOperations( - OpBuilder &b, Location loc, IRMapping &bvm, - chlo::RankSpecializationClusterOp op) { - // Create ranked operations. - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - auto mappedOperands = llvm::to_vector<4>(llvm::map_range( - nestedOp.getOperands(), [&](Value v) { return bvm.lookup(v); })); - int64_t targetRank = 0; - for (Value v : mappedOperands) { - targetRank = - std::max(targetRank, v.getType().cast().getRank()); - } - auto rankedResultTypes = llvm::to_vector<2>( - llvm::map_range(nestedOp.getResultTypes(), [targetRank](Type ty) { - return deriveRankedTensorTypes(ty, targetRank); - })); - OperationState rankedOpState(loc, nestedOp.getName().getStringRef(), - mappedOperands, rankedResultTypes, - nestedOp.getAttrs()); - Operation *rankedOp = b.create(rankedOpState); - for (auto it : llvm::zip(nestedOp.getResults(), rankedOp->getResults())) - bvm.map(std::get<0>(it), std::get<1>(it)); - } - - // Collect ranked results. - auto yieldOp = llvm::cast( - op.SingleBlock::getBody()->getTerminator()); - return llvm::to_vector<8>(llvm::map_range( - yieldOp.getResults(), [&](Value v) { return bvm.lookup(v); })); -} - -SmallVector materializeFinalReshape( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, ValueRange unshapedResults) { - auto yieldOp = llvm::cast( - op.SingleBlock::getBody()->getTerminator()); - assert(unshapedResults.size() == 1 && yieldOp.getResults().size() == 1 && - "Currently, rank specialization supports only one result."); - - // Reify result shape. - Operation *lastOpBeforeShapeReification = op->getPrevNode(); - SmallVector resultShape; - Value originalResult = yieldOp.getResults().front(); - auto originalResultIface = - llvm::cast(originalResult.getDefiningOp()); - if (failed(originalResultIface.reifyReturnTypeShapes( - rewriter, originalResultIface->getOperands(), resultShape))) { - return {}; - } - - // Materialize final reshape. - Value unshapedResult = unshapedResults.front(); - Value result = rewriter.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult, - resultShape.front()); - - // Reify shapes until they are independent of operations in the original - // cluster. - { - Operation *it = resultShape.front().getDefiningOp(); - while (it != nullptr && it != lastOpBeforeShapeReification) { - bool advanced = false; - if (auto shapeOfOp = llvm::dyn_cast(it)) { - Operation *def = shapeOfOp.getArg().getDefiningOp(); - if (def && def->getBlock() == op.SingleBlock::getBody()) { - // Resolve `shape_of` op because it still depends on operation in the - // original cluster. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(shapeOfOp); - SmallVector tmpShape; - auto iface = llvm::cast(def); - if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(), - tmpShape))) - return {}; - rewriter.replaceOp(shapeOfOp, tmpShape.front()); - - // Continue, including the newly created operations. - it = tmpShape.front().getDefiningOp(); - advanced = true; - } - } - - // Skip op, otherwise. - if (!advanced) it = it->getPrevNode(); - } - } - - // Replace all remaining uses of the original cluster's block args. - for (auto it : - llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) { - Value operand, barg; - std::tie(operand, barg) = it; - barg.replaceUsesWithIf(operand, [&](OpOperand &operand) { - return operand.getOwner()->getBlock() != op.SingleBlock::getBody(); - }); - } - - return {result}; -} - -Value materializeFlatShape(OpBuilder &b, Location loc, ValueRange sameShapes) { - assert(!sameShapes.empty() && "Expected at least one shape."); - Value shape = sameShapes.size() == 1 - ? sameShapes.front() - : b.create(loc, sameShapes.front().getType(), - sameShapes); - return b.create( - loc, - b.create(loc, b.getIndexType(), shape).getResult()); -} - -Value materializeScalarRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, ValueRange nonScalarsOfSameShape, - function_ref elseBuilderFn) { - // Materialize predicate: All operands are scalars, except the expected - // non-scalars. - Value one = b.create(loc, 1); - Value allOthersAreScalar; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (llvm::is_contained(nonScalarsOfSameShape, operand) || - isScalarTensorType(operand.getType())) { - continue; - } - auto literal = b.create( - loc, arith::CmpIPredicate::eq, - b.create(loc, shape), one); - allOthersAreScalar = - allOthersAreScalar - ? b.create(loc, allOthersAreScalar, literal) - .getResult() - : literal.getResult(); - } - - auto ifOp = b.create( - loc, allOthersAreScalar, - [&](OpBuilder &b, Location loc) { - // Compute flat non-scalar shape. - SmallVector nonScalarShapes; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (llvm::is_contained(nonScalarsOfSameShape, operand)) - nonScalarShapes.push_back(shape); - } - Value flatShape = materializeFlatShape(b, loc, nonScalarShapes); - - // Derive ranked operands. - auto rankedOperands = llvm::to_vector<8>( - llvm::map_range(op.getOperands(), [&](Value v) -> Value { - if (isScalarTensorType(v.getType())) return v; - if (!llvm::is_contained(nonScalarsOfSameShape, v)) { - return b - .create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/0), - v) - .getResult(); - } - return b - .create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v, - flatShape) - .getResult(); - })); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : llvm::zip(op.SingleBlock::getBody()->getArguments(), - rankedOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - Value unshapedResult = - materializeRankedOperations(b, loc, bvm, op).front(); - - // Return as unranked tensor for compatibility with the other cases. - b.create( - loc, b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), - unshapedResult) - .getDest()); - }, - elseBuilderFn); - - return ifOp.getResults().front(); -} - -Value materializeEqualShapesRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, - function_ref elseBuilderFn) { - // Materialize all shapes equal predicate. - Value allShapesEqOrScalar; - auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range( - shapes, [](Value v) { return !isScalarShapeType(v.getType()); })); - assert( - nonScalarShapes.size() >= 2 && - "Equal shapes strategy requires at least two non-scalar operand shapes."); - for (Value s : llvm::drop_begin(nonScalarShapes)) { - auto literal = b.create(loc, nonScalarShapes.front(), s); - allShapesEqOrScalar = - allShapesEqOrScalar - ? b.create(loc, allShapesEqOrScalar, literal) - .getResult() - : literal; - } - - auto ifOp = b.create( - loc, allShapesEqOrScalar, - [&](OpBuilder &b, Location loc) { - // Flatten non-scalar operands. - Value flatShape = materializeFlatShape(b, loc, nonScalarShapes); - auto flatOperands = llvm::to_vector<8>( - llvm::map_range(op.getOperands(), [&](Value v) -> Value { - if (isScalarTensorType(v.getType())) return v; - return b.create( - loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v, - flatShape); - })); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : - llvm::zip(op.SingleBlock::getBody()->getArguments(), flatOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - Value unshapedResult = - materializeRankedOperations(b, loc, bvm, op).front(); - - // Return as unranked tensor for compatibility with the other cases. - b.create( - loc, b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), - unshapedResult) - .getDest()); - }, - elseBuilderFn); - - return ifOp.getResults().front(); -} - -Value materializeTargetRankSpecializationCase( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t targetRank) { - // Reshape unranked operands to match the target rank. - RankedTensorType extentTensorTy = - shape::getExtentTensorType(b.getContext(), targetRank); - Value allOnesShape = b.create( - loc, extentTensorTy, - mlir::DenseIntElementsAttr::get(extentTensorTy, - SmallVector(targetRank, 1))); - SmallVector rankedOperands; - for (auto it : llvm::zip(op.getOperands(), shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - if (operand.getType().isa()) { - rankedOperands.push_back(operand); - continue; - } - Value rankedShape = b.create( - loc, extentTensorTy, - b.create(loc, - shape::getExtentTensorType(b.getContext()), - shape, allOnesShape, - /*error=*/nullptr)); - rankedOperands.push_back(b.create( - loc, deriveRankedTensorTypes(operand.getType(), targetRank), operand, - rankedShape)); - } - - // Materialize ranked versions of the element-wise operations. - IRMapping bvm; - for (auto it : llvm::zip(op.getBody().front().getArguments(), rankedOperands)) - bvm.map(std::get<0>(it), std::get<1>(it)); - - // Return as unranked for compatibility with other target ranks. - auto unshapedResult = materializeRankedOperations(b, loc, bvm, op).front(); - return b.create( - loc, deriveUnrankedTensorTypes(unshapedResult.getType()), unshapedResult); -} - -Value recusivelyMaterializeTargetRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, Value maxRank, int64_t minTargetRank, - int64_t maxTargetRank) { - Value condition = b.create( - loc, arith::CmpIPredicate::ule, maxRank, - b.create(loc, minTargetRank)); - - // If only a unique target rank is left, we can lower to an assert instead - // of the usual if operation. - if (minTargetRank == maxTargetRank) { - b.create( - loc, condition, - "Input for dynamic binary or n-ary op lowering was of " - "a rank greater than " + - std::to_string(maxTargetRank)); - return materializeTargetRankSpecializationCase(b, loc, op, shapes, - minTargetRank); - } - - // Materialize IR for the smallest considered target rank. - auto ifOp = b.create(loc, op->getResultTypes(), condition, - /*withElseRegion=*/true); - auto thenBuilder = ifOp.getThenBodyBuilder(); - thenBuilder.create( - loc, materializeTargetRankSpecializationCase(thenBuilder, loc, op, shapes, - minTargetRank)); - - // Recurse for all remaining target ranks. - auto elseBuilder = ifOp.getElseBodyBuilder(); - elseBuilder.create( - loc, recusivelyMaterializeTargetRankSpecializationCases( - elseBuilder, loc, op, shapes, maxRank, minTargetRank + 1, - maxTargetRank)); - - return ifOp.getResults().front(); -} - -Value materializeGenericRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t maxTargetRank) { - // Get the minimum broadcast shapes of the operands. - auto nonScalarShapes = llvm::to_vector<8>(llvm::make_filter_range( - shapes, [](Value v) { return !isScalarShapeType(v.getType()); })); - auto minBcastShapesOp = b.create( - loc, - SmallVector(nonScalarShapes.size(), - shape::getExtentTensorType(b.getContext())), - nonScalarShapes); - - // Find the maximum rank among the reduced operand shapes. - Value maxRank; - for (Value shape : minBcastShapesOp.getResults()) { - Value rank = b.create(loc, b.getIndexType(), shape); - if (!maxRank) { - maxRank = rank; - } else { - maxRank = b.create( - loc, - b.create(loc, arith::CmpIPredicate::sgt, maxRank, - rank), - maxRank, rank); - } - } - - // Collect reduced shapes. - SmallVector reducedShapes; - auto it = minBcastShapesOp.result_begin(); - for (Value s : shapes) { - if (isScalarShapeType(s.getType())) { - reducedShapes.push_back(s); - } else { - reducedShapes.push_back(*it++); - } - } - - // Materialize rank specialization for ranks 1, ... - return recusivelyMaterializeTargetRankSpecializationCases( - b, loc, op, reducedShapes, maxRank, /*minTargetRank=*/1, maxTargetRank); -} - -Value materializeDefaultRankSpecializationCases( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t maxTargetRank) { - return materializeEqualShapesRankSpecializationCase( - b, loc, op, shapes, [&](OpBuilder &b, Location loc) { - b.create(loc, materializeGenericRankSpecializationCases( - b, loc, op, shapes, maxTargetRank)); - }); -} - -SmallVector -materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, ValueRange nonScalarsOfSameShape) { - // Compute flat operand shape. - auto nonScalarShapes = - llvm::to_vector<4>(llvm::map_range(nonScalarsOfSameShape, [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - Value flatShape = materializeFlatShape(rewriter, loc, nonScalarShapes); - - // Materialize ranked variants for the element-wise operations. - IRMapping bvm; - for (auto it : - llvm::zip(op.SingleBlock::getBody()->getArguments(), op.getOperands())) { - Value operand; - Value bbArg; - std::tie(bbArg, operand) = it; - if (!isScalarTensorType(operand.getType())) { - assert(llvm::is_contained(nonScalarsOfSameShape, operand) && - "Expected all non-scalars in the same shape equivalence class."); - operand = rewriter.create( - loc, deriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand, - flatShape); - } - bvm.map(bbArg, operand); - } - SmallVector unshapedResults = - materializeRankedOperations(rewriter, loc, bvm, op); - - // Restore the results' expected shape. - Value shape = nonScalarShapes.front(); - return llvm::to_vector<8>( - llvm::map_range(unshapedResults, [&](Value v) -> Value { - return rewriter.create( - loc, deriveUnrankedTensorTypes(v.getType()), v, shape); - })); -} - -Value materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( - PatternRewriter &rewriter, Location loc, - chlo::RankSpecializationClusterOp op, - SmallVector, 4> nonScalarEqs, int64_t maxTargetRank) { - assert(nonScalarEqs.size() == 2 && - "Expect two non-scalar equivalence classes."); - auto shapes = - llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - ValueRange lhsNonScalarEqs = nonScalarEqs[0]; - ValueRange rhsNonScalarEqs = nonScalarEqs[1]; - - // Materialize all the different cases. - Value unshapedResult = materializeScalarRankSpecializationCase( - rewriter, loc, op, shapes, rhsNonScalarEqs, - [&](OpBuilder &b, Location loc) { - b.create( - loc, materializeScalarRankSpecializationCase( - b, loc, op, shapes, lhsNonScalarEqs, - [&](OpBuilder &b, Location loc) { - b.create( - loc, materializeDefaultRankSpecializationCases( - b, loc, op, shapes, maxTargetRank)); - })); - }); - - // Materialize final reshape once and for all rank specialization cases. - return materializeFinalReshape(rewriter, loc, op, unshapedResult).front(); -} - -// Materialize rank generic rank specialization. -Value materializeDefaultRankSpecialization(PatternRewriter &rewriter, - Location loc, - chlo::RankSpecializationClusterOp op, - int64_t maxTargetRank) { - auto shapes = - llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); - - // Materialize all the different cases. - Value unshapedResult = materializeDefaultRankSpecializationCases( - rewriter, loc, op, shapes, maxTargetRank); - - // Materialize final reshape once and for all rank specialization cases. - return materializeFinalReshape(rewriter, loc, op, unshapedResult).front(); -} - -// This is a very limited form of shape inference. It is correct but incomplete. -SmallVector, 4> findNonScalarShapeEquivalences( - chlo::RankSpecializationClusterOp op) { - llvm::EquivalenceClasses eqs; - - // Bridge the equivalences between operands and block arguments. - for (auto it : - llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) - eqs.unionSets(std::get<0>(it), std::get<1>(it)); - - // Find equalities through `SameOperandsAndResultShape` trait. - auto unionSets = [&](ValueRange vs) { - if (vs.empty()) return; - Value repr = vs.front(); - for (Value v : vs.drop_front()) eqs.unionSets(repr, v); - }; - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - if (nestedOp.hasTrait()) { - unionSets(nestedOp.getOperands()); - unionSets(nestedOp.getResults()); - if (!nestedOp.getOperands().empty() && !nestedOp.getResults().empty()) - eqs.unionSets(nestedOp.getResult(0), nestedOp.getOperand(0)); - } - } - - // Find shape equalities through surrounding constraints. - if (auto assumingOp = op->getParentOfType()) { - SmallVector queue; - auto appendIfNotNull = [&](Operation *op) { - if (op != nullptr) queue.push_back(op); - }; - appendIfNotNull(assumingOp.getWitness().getDefiningOp()); - while (!queue.empty()) { - Operation *it = queue.pop_back_val(); - if (auto assumingAllOp = llvm::dyn_cast(it)) { - for (Value v : assumingAllOp.getInputs()) - appendIfNotNull(v.getDefiningOp()); - } else if (auto cstrEqOp = llvm::dyn_cast(it)) { - Value refArg; - for (Value v : cstrEqOp.getShapes()) { - if (auto shapeOfOp = - dyn_cast_or_null(v.getDefiningOp())) { - if (!refArg) { - refArg = shapeOfOp.getArg(); - } else { - eqs.unionSets(refArg, shapeOfOp.getArg()); - } - } - } - } - } - } - - // Find equalities through special knowledge of ops. - // TODO(frgossen): Remove this when these shape equalities can be inferred - // from surrounding shape constraints. - for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { - if (auto selectOp = llvm::dyn_cast(nestedOp)) { - unionSets( - {selectOp.getOnTrue(), selectOp.getOnFalse(), selectOp.getResult()}); - } else if (auto clampOp = llvm::dyn_cast(nestedOp)) { - unionSets({clampOp.getOperand(), clampOp.getResult()}); - } - } - - // Convert to a list-like equivalence class representation. - SmallVector, 4> nonScalarEqs; - for (Value v : op.getOperands()) { - if (isScalarTensorType(v.getType())) continue; - bool inserted = false; - for (auto &eqClass : nonScalarEqs) { - if (eqs.isEquivalent(eqClass.front(), v)) { - eqClass.push_back(v); - inserted = true; - break; - } - } - if (!inserted) nonScalarEqs.push_back(SmallVector({v})); - } - - return nonScalarEqs; -} - -struct LowerRankSpecializationClusterPattern - : public OpRewritePattern { - LowerRankSpecializationClusterPattern(MLIRContext *ctx, int64_t maxTargetRank) - : OpRewritePattern(ctx, /*benefit=*/1), - maxTargetRank(maxTargetRank) {} - - LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, - PatternRewriter &rewriter) const override { - // Restoring the result shape currently relies on all operands being used - // for a single result. The result shape is then the broadcasted shape of - // all operands. - if (op.getNumResults() != 1) return failure(); - - // If there is only a single non-scalar shape equivalence class, we can - // flatten that operands completely. - SmallVector, 4> nonScalarEqs = - findNonScalarShapeEquivalences(op); - Location loc = op.getLoc(); - if (nonScalarEqs.size() == 1) { - rewriter.replaceOp( - op, - materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( - rewriter, loc, op, nonScalarEqs.front())); - return success(); - } - - // If there are exactly two non-scalar shape equivalence classes, we can - // consider two extra cases: If either of the operand classes turns out to - // be all-scalars at runtime, we can, again, flatten all operands. - if (nonScalarEqs.size() == 2) { - rewriter.replaceOp( - op, - materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( - rewriter, loc, op, nonScalarEqs, maxTargetRank)); - return success(); - } - - // For all other cases, reshape the operands to match in rank, apply the - // operation, and restore the expected shape. - rewriter.replaceOp(op, materializeDefaultRankSpecialization( - rewriter, loc, op, maxTargetRank)); - return success(); - } - - private: - int64_t maxTargetRank; -}; - -struct RankSpecializationToSCFPass - : public impl::RankSpecializationToSCFPassBase< - RankSpecializationToSCFPass> { - explicit RankSpecializationToSCFPass(int64_t maxTargetRank) - : RankSpecializationToSCFPassBase< - RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() { - this->max_target_rank_ = maxTargetRank; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - populateRankSpecializationToSCFPatterns(ctx, &patterns, - this->max_target_rank_); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -void populateRankSpecializationClusterPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - patterns->add(context); -} - -void populateRankSpecializationToSCFPatterns(MLIRContext *context, - RewritePatternSet *patterns, - int64_t maxTargetRank) { - patterns->add(context, maxTargetRank); - shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context); - shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context); - shape::AnyOp::getCanonicalizationPatterns(*patterns, context); -} - -std::unique_ptr> -createRankSpecializationClusterPass() { - return std::make_unique(); -} - -std::unique_ptr> createRankSpecializationToSCFPass( - int64_t maxTargetRank) { - return std::make_unique(maxTargetRank); -} - -} // namespace mhlo -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc b/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc index 4433af20fafc5..fcaf6acf2e88e 100644 --- a/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc +++ b/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/rewriters.h b/xla/mlir_hlo/mhlo/transforms/rewriters.h index 2d5bbfa888765..c40a087ab52cb 100644 --- a/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -166,13 +166,6 @@ void populateGroupReductionDimensionsPatterns(MLIRContext *context, RewritePatternSet *patterns, bool preferColumnsReductions); -/// Populate rank specialization clustering and lowering patterns. -void populateRankSpecializationClusterPatterns(MLIRContext *context, - RewritePatternSet *patterns); -void populateRankSpecializationToSCFPatterns(MLIRContext *context, - RewritePatternSet *patterns, - int64_t maxTargetRank); - /// Populate sparse tensor specific rewriting patterns. void populateSparseRewritingPatterns(RewritePatternSet *patterns, MLIRContext *ctx); @@ -191,17 +184,16 @@ void populateLegalizeSparseOpsToCustomCallPatterns(MLIRContext *context, namespace chlo { -// Populates a collection of conversion patterns for legalizing broadcasting -// client-HLO to their non-broadcasting counterparts. -void populateChloBroadcastingPatterns(MLIRContext *context, - RewritePatternSet *patterns); +// Populates direct translations between CHLO and MHLO ops for higher level +// MHLO ops like TopK and Erf. +void populateChloToHighLevelMhloOpPatterns(MLIRContext *context, + RewritePatternSet *patterns); -// Populates a collection of conversion patterns for legalizing client-HLO to -// HLO by decomposing client-operations to corresponding sequences of more -// primitive operations. This does not include the -// PopulateChloBroadcastingPatterns above. -void populateDecomposeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns); +// Populates direct translations between CHLO->MHLO high level ops +// and CHLO->StableHLO->MHLO patterns. +void populateChloToHloPatterns(MLIRContext *context, + TypeConverter *typeConverter, + RewritePatternSet *patterns); } // namespace chlo diff --git a/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc index 48a1fc1f72281..9a45db31ef3ae 100644 --- a/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -25,13 +26,17 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -111,6 +116,16 @@ Value castToIndex(PatternRewriter& rewriter, Location loc, Value value) { return cast.getResult(0); } +void insertShapeAssertionCustomCall(OpBuilder builder, Location loc, + Value assert) { + auto customCall = + builder.create(loc, TypeRange{}, ValueRange{assert}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + builder.getStringAttr("Shape assertion failed")); +} + struct ConvertComputeReshapeShapeOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -256,22 +271,156 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern { } }; +struct ConvertConstShapeOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(shape::ConstShapeOp op, + PatternRewriter& rewriter) const override { + auto operandType = op.getResult().getType().dyn_cast(); + if (!operandType) + return rewriter.notifyMatchFailure(op, "expected ranked operand"); + + llvm::SmallVector shape; + for (int i : op.getShape().getValues()) { + shape.push_back(i); + } + auto newConst = rewriter.create( + op.getLoc(), DenseElementsAttr::get( + RankedTensorType::get({operandType.getDimSize(0)}, + rewriter.getI32Type()), + ArrayRef(shape))); + auto newConstIndex = castToIndex(rewriter, op.getLoc(), newConst); + rewriter.replaceOp(op, newConstIndex); + return success(); + } +}; + +struct ConvertIndexCastOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::IndexCastOp op, + PatternRewriter& rewriter) const override { + Value result = op.getIn(); + if (hasIndexStyle(op.getIn()) && !op.getIn().getType().isa()) { + // Handle a special case of index -> i64. + // This is converted to the following sequence: + // unrealized_conversion_cast index -> tensor + // mhlo.convert tensor -> tensor + // unrealized_conversion_cast tensor -> i64 + result = castToI32(rewriter, op.getLoc(), result); + if (!op.getOut().getType().isInteger(32)) { + result = rewriter.create(op.getLoc(), result, + op.getOut().getType()); + } + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getOut().getType(), result)); + return success(); + } + if (!op.getIn().getType().isa() && hasIndexStyle(op.getOut())) { + // Handle a special case of i32 -> index. + // This is converted to the following sequence: + // unrealized_conversion_cast i32 -> tensor + // unrealized_conversion_cast tensor -> index + result = rewriter + .create( + op.getLoc(), RankedTensorType::get({}, result.getType()), + result) + .getResult(0); + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getOut().getType(), result)); + return success(); + } + + if (hasIndexStyle(result)) { + result = castToI32(rewriter, op.getLoc(), result); + } else if (!hasI32Style(result)) { + return rewriter.notifyMatchFailure(op, + "expected input with index/i32 style"); + } + + if (hasIndexStyle(op.getOut())) { + result = castToIndex(rewriter, op.getLoc(), result); + } else if (!hasI32Style(op.getOut())) { + return rewriter.notifyMatchFailure( + op, "expected output with index/i32 style"); + } + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ConvertMulIOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::MulIOp op, + PatternRewriter& rewriter) const override { + // We only handle index types. + if (!hasIndexStyle(op.getLhs()) || !hasIndexStyle(op.getRhs()) || + !hasIndexStyle(op.getResult())) { + return rewriter.notifyMatchFailure(op, "expected index type"); + } + Value lhs = op.getLhs(); + if (auto constIndex = + dyn_cast_or_null(lhs.getDefiningOp())) { + lhs = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), + static_cast(constIndex.value()))); + } else { + lhs = castToI32(rewriter, op.getLoc(), op.getLhs()); + } + Value rhs = op.getRhs(); + if (auto constIndex = + dyn_cast_or_null(rhs.getDefiningOp())) { + rhs = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), + static_cast(constIndex.value()))); + } else { + rhs = castToI32(rewriter, op.getLoc(), op.getRhs()); + } + Value result = rewriter.create(op.getLoc(), lhs, rhs); + rewriter.replaceOp(op, castToIndex(rewriter, op.getLoc(), result)); + return success(); + } +}; + +// Pads input tensor by X ones from the left. The number X is +// determined by input pad. Result is tensor<(X+N) x i32>, where the first X +// elements are ones. +Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input, + int64_t pad) { + Value padI32 = rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({pad}, rewriter.getI32Type()), 1)); + return rewriter.create(loc, ValueRange{padI32, input}, + /*dimension=*/0); +} + struct ConvertShapeBroadcastOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::BroadcastOp op, PatternRewriter& rewriter) const override { - // Only support broadcasting for two 1D tensors with same size. + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two input 1D tensors. if (op.getShapes().size() != 2) return failure(); auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); if (!shape1 || !shape2) return failure(); auto tensorType1 = shape1.getType().dyn_cast(); auto tensorType2 = shape2.getType().dyn_cast(); - if (!tensorType1 || !tensorType2 || tensorType1.getRank() != 1 || - tensorType2.getRank() != 1 || - tensorType1.getDimSize(0) != tensorType2.getDimSize(0)) - return failure(); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } // By definition, broadcasted dims are: // result[i] = lhs[i] if lhs[i] == rhs[i] @@ -314,18 +463,90 @@ struct ConvertTensorDimPattern : public OpRewritePattern { } }; +struct ConvertTensorExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter& rewriter) const override { + SmallVector indices; + auto tensorType = op.getTensor().getType(); + // We only support getting static indices. + for (auto index : op.getIndices()) { + auto constIndex = + dyn_cast_or_null(index.getDefiningOp()); + if (!constIndex) + return rewriter.notifyMatchFailure(op, "expected constant index op"); + + // Check if the index is out of range. + int idx = indices.size(); + if (tensorType.isDynamicDim(idx) || + constIndex.value() >= tensorType.getDimSize(idx)) + return rewriter.notifyMatchFailure(op, "index out of range"); + + indices.push_back(constIndex.value()); + } + auto input = castToI32(rewriter, op.getLoc(), op.getTensor()); + auto startIndices = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(indices.size())}, + rewriter.getI64Type()), + indices); + for (auto& index : indices) { + index += 1; + } + auto limitIndices = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(indices.size())}, + rewriter.getI64Type()), + indices); + + Value extractedTensor = rewriter.create( + op.getLoc(), input, startIndices, limitIndices, + /*strides=*/ + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(indices.size())}, + rewriter.getI64Type()), + 1)); + Value extractedScalarTensor = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()), + extractedTensor); + if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) { + auto extractedIndex = + castToIndex(rewriter, op.getLoc(), extractedScalarTensor); + rewriter.replaceOp(op, extractedIndex); + } else { + // For the special case when the input is a i32 tensor and output is i32, + // convert the result back to i32 to be consistent: + // unrealized_conversion_cast tensor -> i32 + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getResult().getType(), + extractedScalarTensor)); + } + return success(); + } +}; + struct ConvertTensorFromElementsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::FromElementsOp op, PatternRewriter& rewriter) const override { - // We only handle 1D tensor with index types. tensor.from_elements spec - // allows the same element type only for all input/output. auto tensorType = op.getResult().getType().dyn_cast_or_null(); - if (!tensorType || tensorType.getRank() != 1) { + if (!tensorType) { return failure(); } + if (tensorType.getRank() == 0) { + // Handle the special cast of tensor.from_elements i64 -> tensor + // This is converted to unrealized_conversin_cast i64 -> tensor, + // which is later cancelled with previous unrealized_conversin_cast op. + rewriter.replaceOp( + op, rewriter.create( + op.getLoc(), op.getResult().getType(), op.getElements()[0])); + return success(); + } + + // We only handle 1D tensor with index types. tensor.from_elements spec + // allows the same element type only for all input/output. + if (tensorType.getRank() != 1) return failure(); if (!hasIndexStyle(op.getResult())) return failure(); SmallVector elementI32x1; @@ -356,6 +577,177 @@ struct ConvertTensorFromElementsPattern } }; +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two 1D tensors. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = shape1.getType().dyn_cast(); + auto tensorType2 = shape2.getType().dyn_cast(); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + int32_t rank = + std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0)); + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } + + // Compute if each dim is broadcastable. A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({rank}, rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create( + op.getLoc(), shape1, allOne, ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create( + op.getLoc(), shape2, allOne, ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create( + op.getLoc(), shape1, shape2, ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < rank; ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getI64TensorAttr(i), + rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. + insertShapeAssertionCustomCall(rewriter, op->getLoc(), + allBroadcastableScalar); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +// As defined in tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td, the +// dynamic shape is reshapable if it has only 0 or 1 dynamic dimensions and the +// number of element can divide the product of the static dimension sizes. +struct ConvertCstrReshapableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op, + PatternRewriter& rewriter) const override { + Value numElements; + if (auto constIndex = dyn_cast_or_null( + op.getNumElements().getDefiningOp())) { + numElements = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), + static_cast(constIndex.value()))); + } else { + numElements = castToI32(rewriter, op->getLoc(), op.getNumElements()); + } + Value dyanmicShape = + castToI32(rewriter, op->getLoc(), op.getDynamicShape()); + if (!dyanmicShape || !numElements) return failure(); + auto dyanmicShapeType = + dyanmicShape.getType().dyn_cast_or_null(); + if (!dyanmicShapeType || dyanmicShapeType.getRank() != 1) return failure(); + + auto i32Type = RankedTensorType::get({}, rewriter.getI32Type()); + Value minusOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(i32Type, -1)); + Value one = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(i32Type, 1)); + Value zero = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(i32Type, 0)); + Value productAllDimSizes = one; + Value numDyanmicDim = zero; + for (auto i = 0; i < dyanmicShapeType.getDimSize(0); ++i) { + // Calculate the product of static dimension sizes. + Value dimSize = rewriter.create( + op.getLoc(), dyanmicShape, rewriter.getI64TensorAttr(i), + rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1)); + dimSize = rewriter.create(op.getLoc(), i32Type, dimSize); + productAllDimSizes = + rewriter.create(op.getLoc(), productAllDimSizes, dimSize); + // Count number of -1 dims, aka dynamic dimensions. + Value eqMinusOne = rewriter.create( + op.getLoc(), dimSize, minusOne, ComparisonDirection::EQ); + eqMinusOne = + rewriter.create(op.getLoc(), eqMinusOne, one, zero); + numDyanmicDim = + rewriter.create(op.getLoc(), numDyanmicDim, eqMinusOne); + } + + // Here we handle two situations below. Either one is a valid reshape. + // A: There is 1 dynamic dimension and the number of elements can be divided + // by the product of static dim sizes. + // B: There is no dynamic dimension and the number of elements equals the + // product of all dim sizes. + + // A.1: Check there is 1 dynamic dim. + Value exactlyOneDynamicDim = rewriter.create( + op.getLoc(), numDyanmicDim, one, ComparisonDirection::EQ); + + // A.2: Calculate product of all static dim sizes. Multiple by -1 to cancel + // with the dynamic dim size -1. + Value productStaticDimSizes = + rewriter.create(op.getLoc(), productAllDimSizes, minusOne); + + // A.3: Check number of elements can be divided by product of static dim + // sizes. + Value rem = + rewriter.create(op.getLoc(), numElements, productStaticDimSizes); + Value dynamicReshapable = rewriter.create( + op.getLoc(), rem, zero, ComparisonDirection::EQ); + + // A.4: Check both conditions for scenario A are true. + dynamicReshapable = rewriter.create(op.getLoc(), dynamicReshapable, + exactlyOneDynamicDim); + + // B.1: Check there is no dynamic dim. + Value noDynamicDim = rewriter.create( + op.getLoc(), numDyanmicDim, zero, ComparisonDirection::EQ); + + // B.2: Check product of all dim sizes equals number of elements. + Value staticReshapable = rewriter.create( + op.getLoc(), productAllDimSizes, numElements, ComparisonDirection::EQ); + + // B.3: Check both conditions for scenario B are true. + staticReshapable = + rewriter.create(op.getLoc(), noDynamicDim, staticReshapable); + + // Check if either scenario is true. + Value reshapable = + rewriter.create(op.getLoc(), dynamicReshapable, staticReshapable); + + // Add CustomCallOp and replace Cstr op with const witness, which is + // useful for canonicalizer to remove the shape.assuming region. + insertShapeAssertionCustomCall(rewriter, op->getLoc(), reshapable); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + template struct CastOperandsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -387,6 +779,12 @@ struct CastOperandsPattern : public OpRewritePattern { // needed to support bounded dynamism in MHLO export. struct ShapeLegalizeToHloPass : public impl::ShapeLegalizeToHloPassBase { + explicit ShapeLegalizeToHloPass(bool legalizeConstraints) + : impl::ShapeLegalizeToHloPassBase< + ShapeLegalizeToHloPass>::ShapeLegalizeToHloPassBase() { + this->legalize_constraints_ = legalizeConstraints; + } + void runOnOperation() override { // In order to make dynamic MHLO programs compatible with HLO, // we need to get rid of all non-MHLO ops as well as the two shape-related @@ -413,22 +811,26 @@ struct ShapeLegalizeToHloPass // is able to remove unnecessary cruft. At the moment, this pass is a // work in progress, so not all of these ops are supported. // - // The only problem (and a big problem at that) are the ops involved in - // shape constraints: cstr* ops as well as shape.assuming*. Since HLO does - // not support shape constraints, it is currently unclear what to do with - // them, unless they can be removed by --symbolic-shape-optimization. - // At the moment, this pass is a work in progress, so it does not provide - // an answer to this problem yet. + // When legalize_constraints_ is set true, cstr* ops are also legalized. + // A shape_assertion custom_call is used to check the constraint. And the + // shape.assuming region will consume a shape.const_witness that evaluate to + // true, so that it can be removed later in a canonicalizer pass. ConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalDialect(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addDynamicallyLegalDialect([](Operation* op) { return !llvm::any_of(op->getOperands(), hasIndexStyle); }); target.addLegalOp(); target.addLegalOp(); + if (this->legalize_constraints_) { + target.addLegalOp(); + } // The patterns do what one might expect, converting between MLIR-style // and HLO-style shape computations. @@ -440,12 +842,21 @@ struct ShapeLegalizeToHloPass // everything went right. RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add>(&getContext()); + patterns.add>(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); + if (this->legalize_constraints_) { + patterns.add(&getContext()); + patterns.add(&getContext()); + } if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); @@ -454,9 +865,9 @@ struct ShapeLegalizeToHloPass } // namespace -std::unique_ptr> -createShapeLegalizeToHloPass() { - return std::make_unique(); +std::unique_ptr> createShapeLegalizeToHloPass( + bool legalizeConstraints) { + return std::make_unique(legalizeConstraints); } } // namespace mhlo diff --git a/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc b/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc index 08f7551d12e67..af1738fe3c532 100644 --- a/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc index 0ec9b1a5a8d7a..22e980e96e074 100644 --- a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc +++ b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc b/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc index d150e5b416c5c..bd2db9037f036 100644 --- a/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc +++ b/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc b/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc deleted file mode 100644 index 07f5454dd6dc9..0000000000000 --- a/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements a set of sparse MHLO rewriting rules. - -#include -#include - -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace mhlo { - -#define GEN_PASS_DEF_SPARSEREWRITINGPASS -#include "mhlo/transforms/mhlo_passes.h.inc" - -namespace { - -/// Approves subsuming sparse types into operation. -// TODO(b/231360416): replace this list with "supports sparsity" trait? -bool canFuseWithSparseConvert(Operation *op) { - return isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || - isa(op) || isa(op); -} - -/// Fuses a sparse tensor type from a conversion into a mhlo operation -/// where possible, essentially rewriting something like: -/// %0 = mhlo.sign %arg : tensor<100xf64> -/// %1 = sparse_tensor.convert %0 : tensor<100xf64> to tensor<100xf64, #SV> -/// ... = ... %1 ... -/// into: -/// %0 = mhlo.sign %arg : (tensor<100xf64>) -> tensor<100xf64, #SV> -/// ... = ... %0 ... -/// This eventually yields better sparse code, since the intermediate -/// results do not need to be explicitly generated. -struct SparseConvertConverter - : public OpRewritePattern { - explicit SparseConvertConverter(MLIRContext *context) - : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(sparse_tensor::ConvertOp op, - PatternRewriter &rewriter) const override { - // Cannot fuse element-wise type conversion. - if (op.getSource().getType().getElementType() != - op.getDest().getType().getElementType()) { - return failure(); - } - if (Operation *def = op.getSource().getDefiningOp()) { - if (def->hasOneUse() && canFuseWithSparseConvert(def)) { - def->getResult(0).setType(op->getResultTypes()[0]); - rewriter.replaceOp(op, def->getResult(0)); - return success(); - } - } - return failure(); - } -}; - -struct SparseElementWiseConvertConverter - : public OpRewritePattern { - explicit SparseElementWiseConvertConverter(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(mhlo::ConvertOp op, - PatternRewriter &rewriter) const override { - if (sparse_tensor::hasAnySparseOperandOrResult(op)) { - // Uses sparse_tensor::ConvertOp to do element-wise value conversion. - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.getOperand()); - return success(); - } - return failure(); - } -}; - -/// Converts a mhlo::concatenate operation into a sparse_tensor::concatenate -/// directly when there is any sparse input/ouput. -struct SparseConcatenateConverter - : public OpRewritePattern { - explicit SparseConcatenateConverter(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(mhlo::ConcatenateOp op, - PatternRewriter &rewriter) const override { - auto resultType = op.getResult().getType(); - if (sparse_tensor::hasAnySparseOperandOrResult(op)) { - // If there is any sparse input, lower to sparse_tensor.concatenate - // directly. - rewriter.replaceOpWithNewOp( - op, resultType, op.getOperands(), - rewriter.getIndexAttr(op.getDimension())); - return success(); - } - // Pass to mhlo lowering pipeline if all input and output tensors - // are dense. - return failure(); - } -}; - -struct SparseRewritingPass - : public impl::SparseRewritingPassBase { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateSparseRewritingPatterns(&patterns, &getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -void populateSparseRewritingPatterns(RewritePatternSet *patterns, - MLIRContext *ctx) { - patterns->add(ctx); -} - -std::unique_ptr> createSparseRewritingPass() { - return std::make_unique(); -} - -} // namespace mhlo -} // namespace mlir diff --git a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index a3c401a7b1abe..81153b92ecd06 100644 --- a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,20 @@ namespace { return mhlo::Name##Attr::get(attr.getContext(), hloValue.value()) Attribute convertAttr(Attribute stablehloAttr) { + // StableHLO uses DenseArray for some attributes, MHLO is in the process + // of integrating this change. In the meantime, convert DenseArray to + // DenseElementsAttr. + if (auto attr = stablehloAttr.dyn_cast()) { + return DenseIntElementsAttr::get( + RankedTensorType::get(attr.getSize(), attr.getElementType()), + attr.asArrayRef()); + } + if (auto attr = stablehloAttr.dyn_cast()) { + return DenseIntElementsAttr::get( + RankedTensorType::get(attr.getSize(), attr.getElementType()), + attr.asArrayRef()); + } + // Handle StableHLO attributes. // The logic that handles attributes from other dialects (e.g. builtin // attributes) lives below. @@ -343,12 +357,12 @@ class StablehloToHloOpConverter : public OpConversionPattern { // for the generic builder. StablehloToHloOp hloOp; if constexpr (std::is_same::value) { - hloOp = rewriter.replaceOpWithNewOp( - stablehloOp, hloTypes, hloOperands, hloAttrs, - stablehloOp.getBranches().size()); + hloOp = rewriter.create(stablehloOp.getLoc(), hloTypes, + hloOperands, hloAttrs, + stablehloOp.getBranches().size()); } else { - hloOp = rewriter.replaceOpWithNewOp>( - stablehloOp, hloTypes, hloOperands, hloAttrs); + hloOp = rewriter.create>( + stablehloOp.getLoc(), hloTypes, hloOperands, hloAttrs); } // For backward compatibility, fix custom call with mhlo.backend_config @@ -365,6 +379,8 @@ class StablehloToHloOpConverter : public OpConversionPattern { /*entryConversion=*/nullptr))) return failure(); } + + rewriter.replaceOp(stablehloOp, hloOp); return success(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 848cf5597e819..9b4c4d4eb64f0 100644 --- a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index bf36a140ce7e8..402bf827c6393 100644 --- a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -209,12 +209,12 @@ struct AnnotateExpandingDimensionsInDynamicBroadcastInDim } // Annotate op in place. - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); op.setKnownExpandingDimensionsAttr( rewriter.getI64TensorAttr(knownExpandingDims.takeVector())); op.setKnownNonexpandingDimensionsAttr( rewriter.getI64TensorAttr(knownNonexpandingDims.takeVector())); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc index 5d371ed3e4e4c..8bd3bbc140961 100644 --- a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,8 +46,9 @@ struct InferReturnTypesPattern : public RewritePattern { SmallVector types; if (failed(definingOpInt.inferReturnTypes( op->getContext(), op->getLoc(), definingOp->getOperands(), - definingOp->getAttrDictionary(), op->getPropertiesStorage(), - definingOp->getRegions(), types))) { + definingOpInt->getAttrDictionary(), + definingOpInt->getPropertiesStorage(), definingOpInt->getRegions(), + types))) { return failure(); } diff --git a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc index 8e300a15b4aa4..c59722ded20ad 100644 --- a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc +++ b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc index 133fe78e3ff2e..7409def78d770 100644 --- a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/CMakeLists.txt b/xla/mlir_hlo/mhlo/utils/CMakeLists.txt index 4bb7ede7af873..59889fec26735 100644 --- a/xla/mlir_hlo/mhlo/utils/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/utils/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc b/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc index e6161a90154bc..2d1a5009f8354 100644 --- a/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc +++ b/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h b/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h index 218d3c084869d..1b476d64da64d 100644 --- a/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h +++ b/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.cc b/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.cc index d2e9b48da34d9..ac34bea9cc222 100644 --- a/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.cc +++ b/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.h b/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.h index ab2c5b2e5d6aa..03cd2bb9302a2 100644 --- a/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.h +++ b/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc b/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc index 5564648713a84..91db4969a88ad 100644 --- a/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc +++ b/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h b/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h index a0a52105e18fc..2a4c5d7cb3183 100644 --- a/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h +++ b/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/type_conversion.cc b/xla/mlir_hlo/mhlo/utils/type_conversion.cc index 2522785f68d73..42a5a54228369 100644 --- a/xla/mlir_hlo/mhlo/utils/type_conversion.cc +++ b/xla/mlir_hlo/mhlo/utils/type_conversion.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/mhlo/utils/type_conversion.h b/xla/mlir_hlo/mhlo/utils/type_conversion.h index ae3f0b963e8af..0bcba717981b2 100644 --- a/xla/mlir_hlo/mhlo/utils/type_conversion.h +++ b/xla/mlir_hlo/mhlo/utils/type_conversion.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tests/BUILD b/xla/mlir_hlo/tests/BUILD index 9ddfca0aa4289..4d595d5b13534 100644 --- a/xla/mlir_hlo/tests/BUILD +++ b/xla/mlir_hlo/tests/BUILD @@ -1,6 +1,6 @@ -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/xla/mlir_hlo/tests/CMakeLists.txt b/xla/mlir_hlo/tests/CMakeLists.txt index 1e4ee187e1aec..f874b69ace02b 100644 --- a/xla/mlir_hlo/tests/CMakeLists.txt +++ b/xla/mlir_hlo/tests/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir deleted file mode 100644 index ce5a493fd6e89..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir +++ /dev/null @@ -1,345 +0,0 @@ -// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s - -// Check the non-broadcast case for each registered op, then just check a -// representative op for detailed broadcast semantics. -// CHECK-LABEL: @addWithoutBroadcast -func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.add %arg0, %arg1 - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @dynamicBroadcast -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastComplex -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[RESULT:.+]] = mhlo.complex %[[ARG0_B]], %[[ARG1_B]] : tensor> - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> - func.return %0 : tensor> -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastCompare -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.+]] = mhlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor, tensor) -> tensor - // CHECK: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK: return %[[FINAL_RESULT]] : tensor - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @selectv2 -func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: mhlo.select %arg0, %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @selectv2_pred_scalar -func.func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: mhlo.select %arg0, %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_then -func.func @selectv2_broadcast_then(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK-NEXT: mhlo.select %arg0, %[[BROADCAST]], %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_else -func.func @selectv2_broadcast_else(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK-NEXT: mhlo.select %arg0, %arg1, %[[BROADCAST]] - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_pred -func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1> - // CHECK-NEXT: mhlo.select %[[BROADCAST]], %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_tensor_pred -func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: mhlo.select %[[BROADCAST]], %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> - func.return %0: tensor<2x3xf16> -} - -// CHECK-LABEL: func @selectv2_broadcast_all -func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { - // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> - // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32> - // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - // CHECK: mhlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]] - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - func.return %0: tensor<8x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_dynamic_ranked -func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { - // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex> - // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex> - // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex> - // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex> - // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) { - // CHECK-NEXT: %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1> - // CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32> - // CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32> - // CHECK-NEXT: %[[SELECT:.*]] = mhlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32> - // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32> - // CHECK-NEXT: } - // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32> - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> - func.return %0: tensor<2x?x8xi32> -} - -// ----- -// Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions -func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions -func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { - // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that invalid broadcast dimensions are rejected. -func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} - // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that invalid broadcast dimensions are rejected. -func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} - // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Note that broadcast_add is used as a proxy for all of the template -// expansions. Tests below merely verify that the op has an expansion. -// CHECK-LABEL: @andWithoutBroadcast -func.func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.and %arg0, %arg1 - %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @atan2WithoutBroadcast -func.func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.atan2 %arg0, %arg1 - %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @compareWithoutBroadcast -func.func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { - // CHECK: mhlo.compare EQ, %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @complexWithoutBroadcast -func.func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { - // CHECK: mhlo.complex %arg0, %arg1 : tensor<4xcomplex> - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - func.return %0 : tensor<4xcomplex> -} - -// ----- -// CHECK-LABEL: @divideWithoutBroadcast -func.func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.divide %arg0, %arg1 - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @maximumWithoutBroadcast -func.func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.maximum %arg0, %arg1 - %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @minimumWithoutBroadcast -func.func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.minimum %arg0, %arg1 - %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @multiplyWithoutBroadcast -func.func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.multiply %arg0, %arg1 - %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @orWithoutBroadcast -func.func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.or %arg0, %arg1 - %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @powerWithoutBroadcast -func.func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.power %arg0, %arg1 - %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @remainderWithoutBroadcast -func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.remainder %arg0, %arg1 - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @shift_leftWithoutBroadcast -func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_left %arg0, %arg1 - %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast -func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 - %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @shift_right_logicalWithoutBroadcast -func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_right_logical %arg0, %arg1 - %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @subWithoutBroadcast -func.func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.subtract %arg0, %arg1 - %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @xorWithoutBroadcast -func.func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.xor %arg0, %arg1 - %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @NextAfterWithoutBroadcast -// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>) -func.func @NextAfterWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.next_after %[[LHS]], %[[RHS]] - %0 = chlo.broadcast_next_after %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @PolygammaWithoutBroadcast -// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>) -func.func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.polygamma %[[LHS]], %[[RHS]] - %0 = chlo.broadcast_polygamma %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @ZetaWithoutBroadcast -func.func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.zeta %arg0, %arg1 - %0 = chlo.broadcast_zeta %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} diff --git a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir deleted file mode 100644 index 2a22006f2c983..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=false" %s | FileCheck %s - -// CHECK-LABEL: atan_static -// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x4xf32> -func.func @atan_static(%arg0: tensor<2x3x4xf32>) -> tuple> { - // CHECK: %[[CST:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x3x4xf32> - // CHECK: mhlo.atan2 %[[ARG]], %[[CST]] : tensor<2x3x4xf32> - %0 = chlo.atan %arg0 : tensor<2x3x4xf32> -> tensor<2x3x4xf32> - %1 = "mhlo.tuple"(%0) : (tensor<2x3x4xf32>) -> tuple> - func.return %1 : tuple> -} diff --git a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 3e72efd620c74..fdc11fedbd392 100644 --- a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -1,4 +1,5 @@ // RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-context=20 +// RUN: mlir-hlo-opt --chlo-legalize-to-high-level-mhlo --split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-HIGH-LEVEL // CHECK-LABEL: func.func @asin_bf16( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor @@ -237,7 +238,7 @@ func.func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> func.func @constant_like_dynamic_shape(%arg : tensor) -> tensor { // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor -> tensor<2xindex> - // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } : (tensor) -> tensor @@ -262,147 +263,8 @@ func.func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { // CHECK-LABEL: @erf_f64 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f64(%arg : tensor) -> tensor { - // CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<9.6049737398705161> - // CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_0]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<90.026019720384269> - // CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]] - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_0]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2232.0053459468431> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_0]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<7003.3251411280507> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_0]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<55592.301301039493> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[ARG]], %[[TMP_16]] - // CHECK: %[[TMP_20:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_22:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_0]] - // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<33.561714164750313> - // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_22]], %[[TMP_23]] - // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_24]], %[[TMP_0]] - // CHECK: %[[TMP_26:.*]] = mhlo.constant dense<521.35794978015269> - // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_25]], %[[TMP_26]] - // CHECK: %[[TMP_28:.*]] = mhlo.multiply %[[TMP_27]], %[[TMP_0]] - // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4594.3238297098014> - // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_28]], %[[TMP_29]] - // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_30]], %[[TMP_0]] - // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<22629.000061389095> - // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]] - // CHECK: %[[TMP_34:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_0]] - // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<49267.394260863592> - // CHECK: %[[TMP_36:.*]] = mhlo.add %[[TMP_34]], %[[TMP_35]] - // CHECK: %[[TMP_37:.*]] = mhlo.divide %[[TMP_17]], %[[TMP_36]] - // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_40:.*]] = mhlo.negate %[[TMP_39]] - // CHECK: %[[TMP_41:.*]] = mhlo.exponential %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_45:.*]] = mhlo.constant dense<2.4619698147353052E-10> - // CHECK: %[[TMP_47:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_42]] - // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<0.56418956483106886> - // CHECK: %[[TMP_49:.*]] = mhlo.add %[[TMP_47]], %[[TMP_48]] - // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_42]] - // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<7.4632105644226989> - // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_50]], %[[TMP_51]] - // CHECK: %[[TMP_53:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_42]] - // CHECK: %[[TMP_54:.*]] = mhlo.constant dense<48.637197098568137> - // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_53]], %[[TMP_54]] - // CHECK: %[[TMP_56:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_42]] - // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<196.5208329560771> - // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_56]], %[[TMP_57]] - // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_42]] - // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<526.44519499547732> - // CHECK: %[[TMP_61:.*]] = mhlo.add %[[TMP_59]], %[[TMP_60]] - // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_42]] - // CHECK: %[[TMP_63:.*]] = mhlo.constant dense<934.52852717195765> - // CHECK: %[[TMP_64:.*]] = mhlo.add %[[TMP_62]], %[[TMP_63]] - // CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_42]] - // CHECK: %[[TMP_66:.*]] = mhlo.constant dense<1027.5518868951572> - // CHECK: %[[TMP_67:.*]] = mhlo.add %[[TMP_65]], %[[TMP_66]] - // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_67]], %[[TMP_42]] - // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<557.53533536939938> - // CHECK: %[[TMP_70:.*]] = mhlo.add %[[TMP_68]], %[[TMP_69]] - // CHECK: %[[TMP_71:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_70]] - // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_74]], %[[TMP_42]] - // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<13.228195115474499> - // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_76]], %[[TMP_77]] - // CHECK: %[[TMP_79:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_42]] - // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<86.707214088598973> - // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_79]], %[[TMP_80]] - // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_42]] - // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<354.93777888781989> - // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_82]], %[[TMP_83]] - // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_84]], %[[TMP_42]] - // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<975.70850174320549> - // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_85]], %[[TMP_86]] - // CHECK: %[[TMP_88:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_42]] - // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1823.9091668790973> - // CHECK: %[[TMP_90:.*]] = mhlo.add %[[TMP_88]], %[[TMP_89]] - // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_42]] - // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<2246.3376081871097> - // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_91]], %[[TMP_92]] - // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_93]], %[[TMP_42]] - // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1656.6630919416134> - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_94]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_42]] - // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<557.53534081772773> - // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_97]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.divide %[[TMP_71]], %[[TMP_99]] - // CHECK: %[[TMP_103:.*]] = mhlo.constant dense<0.56418958354775506> - // CHECK: %[[TMP_105:.*]] = mhlo.multiply %[[TMP_103]], %[[TMP_42]] - // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<1.275366707599781> - // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_105]], %[[TMP_106]] - // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_107]], %[[TMP_42]] - // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<5.0190504225118051> - // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_108]], %[[TMP_109]] - // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_110]], %[[TMP_42]] - // CHECK: %[[TMP_112:.*]] = mhlo.constant dense<6.160210979930536> - // CHECK: %[[TMP_113:.*]] = mhlo.add %[[TMP_111]], %[[TMP_112]] - // CHECK: %[[TMP_114:.*]] = mhlo.multiply %[[TMP_113]], %[[TMP_42]] - // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<7.4097426995044895> - // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_114]], %[[TMP_115]] - // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_116]], %[[TMP_42]] - // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<2.9788666537210022> - // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_117]], %[[TMP_118]] - // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_119]] - // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_42]] - // CHECK: %[[TMP_126:.*]] = mhlo.constant dense<2.2605286322011726> - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_125]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_42]] - // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<9.3960352493800147> - // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_128]], %[[TMP_129]] - // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_42]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<12.048953980809666> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_131]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_42]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<17.081445074756591> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_134]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_42]] - // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<9.6089680906328585> - // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_137]], %[[TMP_138]] - // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_42]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<3.3690764510008151> - // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_140]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.divide %[[TMP_120]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_145:.*]] = mhlo.compare LT, %[[TMP_42]], %[[TMP_144]], NOTYPE - // CHECK: %[[TMP_146:.*]] = mhlo.select %[[TMP_145]], %[[TMP_100]], %[[TMP_143]] - // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<-709.78271289338397> - // CHECK: %[[TMP_148:.*]] = mhlo.compare LT, %[[TMP_40]], %[[TMP_147]], NOTYPE - // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_150:.*]] = mhlo.select %[[TMP_148]], %[[TMP_149]], %[[TMP_146]] - // CHECK: %[[TMP_152:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_149]], NOTYPE - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_154:.*]] = mhlo.subtract %[[TMP_153]], %[[TMP_150]] - // CHECK: %[[TMP_155:.*]] = mhlo.select %[[TMP_152]], %[[TMP_154]], %[[TMP_150]] - // CHECK: %[[TMP_156:.*]] = mhlo.subtract %[[TMP_38]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_38]], NOTYPE - // CHECK: %[[RESULT:.*]] = mhlo.select %[[TMP_159]], %[[TMP_37]], %[[TMP_156]] + // CHECK-HIGH-LEVEL: mhlo.erf + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -413,47 +275,8 @@ func.func @erf_f64(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f32 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f32(%arg : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00> - // CHECK-DAG: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_2:.*]] = mhlo.clamp %[[TMP_0]], %[[ARG]], %[[TMP_1]] - // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10> - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_6]], %[[TMP_3]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]] - // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4> - // CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]] - // CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]] - // CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03> - // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]] - // CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]] - // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332> - // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]] - // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5> - // CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_28]], %[[TMP_3]] - // CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4> - // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]] - // CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]] - // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702> - // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]] - // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]] - // CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925> - // CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]] - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]] - // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391> - // CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]] - // CHECK: %[[TMP_43:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]] - // CHECK-DAG: %[[TMP_44:.*]] = mhlo.constant dense<-1.000000e+00> - // CHECK-DAG: %[[TMP_45:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[RESULT:.*]] = mhlo.clamp %[[TMP_44]], %[[TMP_43]], %[[TMP_45]] + // CHECK-HIGH-LEVEL: mhlo.erf + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -464,8 +287,8 @@ func.func @erf_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK-HIGH-LEVEL: mhlo.erf + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -476,8 +299,8 @@ func.func @erf_f16(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_bf16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_bf16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor + // CHECK-HIGH-LEVEL: mhlo.erf + // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor func.return %1 : tensor @@ -1286,153 +1109,153 @@ func.func @digamma_f16(%arg : tensor) -> tensor { func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_0:.*]] = mhlo.convert %[[X]] : (tensor) -> tensor // CHECK: %[[TMP_1:.*]] = mhlo.convert %[[Q]] : (tensor) -> tensor - // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_3:.*]] = mhlo.negate %[[TMP_0]] - // CHECK: %[[TMP_4:.*]] = mhlo.power %[[TMP_1]], %[[TMP_3]] - // CHECK: %[[TMP_5:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] - // CHECK: %[[TMP_7:.*]] = mhlo.power %[[TMP_6]], %[[TMP_3]] - // CHECK: %[[TMP_8:.*]] = mhlo.add %[[TMP_4]], %[[TMP_7]] - // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_5]] - // CHECK: %[[TMP_10:.*]] = mhlo.power %[[TMP_9]], %[[TMP_3]] + // CHECK-DAG: %[[TMP_2:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_4:.*]] = mhlo.negate %[[TMP_0]] + // CHECK: %[[TMP_5:.*]] = mhlo.power %[[TMP_1]], %[[TMP_4]] + // CHECK: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_3]] + // CHECK: %[[TMP_7:.*]] = mhlo.power %[[TMP_6]], %[[TMP_4]] + // CHECK: %[[TMP_8:.*]] = mhlo.add %[[TMP_5]], %[[TMP_7]] + // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_3]] + // CHECK: %[[TMP_10:.*]] = mhlo.power %[[TMP_9]], %[[TMP_4]] // CHECK: %[[TMP_11:.*]] = mhlo.add %[[TMP_8]], %[[TMP_10]] - // CHECK: %[[TMP_12:.*]] = mhlo.add %[[TMP_9]], %[[TMP_5]] - // CHECK: %[[TMP_13:.*]] = mhlo.power %[[TMP_12]], %[[TMP_3]] + // CHECK: %[[TMP_12:.*]] = mhlo.add %[[TMP_9]], %[[TMP_3]] + // CHECK: %[[TMP_13:.*]] = mhlo.power %[[TMP_12]], %[[TMP_4]] // CHECK: %[[TMP_14:.*]] = mhlo.add %[[TMP_11]], %[[TMP_13]] - // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_5]] - // CHECK: %[[TMP_16:.*]] = mhlo.power %[[TMP_15]], %[[TMP_3]] + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_3]] + // CHECK: %[[TMP_16:.*]] = mhlo.power %[[TMP_15]], %[[TMP_4]] // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_14]], %[[TMP_16]] - // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_15]], %[[TMP_5]] - // CHECK: %[[TMP_19:.*]] = mhlo.power %[[TMP_18]], %[[TMP_3]] + // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_15]], %[[TMP_3]] + // CHECK: %[[TMP_19:.*]] = mhlo.power %[[TMP_18]], %[[TMP_4]] // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_17]], %[[TMP_19]] - // CHECK: %[[TMP_21:.*]] = mhlo.add %[[TMP_18]], %[[TMP_5]] - // CHECK: %[[TMP_22:.*]] = mhlo.power %[[TMP_21]], %[[TMP_3]] + // CHECK: %[[TMP_21:.*]] = mhlo.add %[[TMP_18]], %[[TMP_3]] + // CHECK: %[[TMP_22:.*]] = mhlo.power %[[TMP_21]], %[[TMP_4]] // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_20]], %[[TMP_22]] - // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_21]], %[[TMP_5]] - // CHECK: %[[TMP_25:.*]] = mhlo.power %[[TMP_24]], %[[TMP_3]] + // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_21]], %[[TMP_3]] + // CHECK: %[[TMP_25:.*]] = mhlo.power %[[TMP_24]], %[[TMP_4]] // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_23]], %[[TMP_25]] - // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_24]], %[[TMP_5]] - // CHECK: %[[TMP_28:.*]] = mhlo.power %[[TMP_27]], %[[TMP_3]] + // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_24]], %[[TMP_3]] + // CHECK: %[[TMP_28:.*]] = mhlo.power %[[TMP_27]], %[[TMP_4]] // CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_26]], %[[TMP_28]] - // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_27]], %[[TMP_5]] - // CHECK: %[[TMP_31:.*]] = mhlo.power %[[TMP_30]], %[[TMP_3]] + // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_27]], %[[TMP_3]] + // CHECK: %[[TMP_31:.*]] = mhlo.power %[[TMP_30]], %[[TMP_4]] // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_29]], %[[TMP_31]] - // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_5]] - // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_3]] + // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_3]] + // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_4]] // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_36:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] - // CHECK: %[[TMP_37:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] - // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_37]], %[[TMP_36]] - // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_32]], %[[TMP_38]] - // CHECK: %[[TMP_40:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] - // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_43:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_42]] - // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_44]] - // CHECK: %[[TMP_46:.*]] = mhlo.multiply %[[TMP_43]], %[[TMP_45]] - // CHECK: %[[TMP_47:.*]] = mhlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_48:.*]] = mhlo.add %[[TMP_2]], %[[TMP_47]] - // CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_48]] - // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_46]], %[[TMP_49]] - // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_52:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_51]] - // CHECK: %[[TMP_53:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_54:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_53]] - // CHECK: %[[TMP_55:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_54]] - // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_57:.*]] = mhlo.add %[[TMP_50]], %[[TMP_56]] - // CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_57]] - // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_58]] - // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_60]] - // CHECK: %[[TMP_62:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_63:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_62]] - // CHECK: %[[TMP_64:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_63]] - // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_59]], %[[TMP_65]] - // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_66]] - // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_67]] - // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_70:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_69]] - // CHECK: %[[TMP_71:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_72:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_71]] - // CHECK: %[[TMP_73:.*]] = mhlo.multiply %[[TMP_70]], %[[TMP_72]] - // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_75:.*]] = mhlo.add %[[TMP_68]], %[[TMP_74]] - // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_75]] - // CHECK: %[[TMP_77:.*]] = mhlo.multiply %[[TMP_73]], %[[TMP_76]] - // CHECK: %[[TMP_78:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_79:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_78]] - // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_81:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_80]] - // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_79]], %[[TMP_81]] - // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_77]], %[[TMP_83]] - // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_84]] - // CHECK: %[[TMP_86:.*]] = mhlo.multiply %[[TMP_82]], %[[TMP_85]] - // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_88:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_87]] - // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_90:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_89]] - // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]] - // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_86]], %[[TMP_92]] - // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.multiply %[[TMP_91]], %[[TMP_94]] - // CHECK: %[[TMP_96:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_97:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_96]] - // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_99:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.multiply %[[TMP_97]], %[[TMP_99]] - // CHECK: %[[TMP_101:.*]] = mhlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_95]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_102]] - // CHECK: %[[TMP_104:.*]] = mhlo.multiply %[[TMP_100]], %[[TMP_103]] - // CHECK: %[[TMP_105:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_106:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_105]] - // CHECK: %[[TMP_107:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_108:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.multiply %[[TMP_106]], %[[TMP_108]] - // CHECK: %[[TMP_110:.*]] = mhlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_104]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_111]] - // CHECK: %[[TMP_113:.*]] = mhlo.multiply %[[TMP_109]], %[[TMP_112]] - // CHECK: %[[TMP_114:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_115:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_114]] - // CHECK: %[[TMP_116:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_117:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.multiply %[[TMP_115]], %[[TMP_117]] - // CHECK: %[[TMP_119:.*]] = mhlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_113]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_120]] - // CHECK: %[[TMP_122:.*]] = mhlo.multiply %[[TMP_118]], %[[TMP_121]] - // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_126:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_124]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_129:.*]] = mhlo.add %[[TMP_122]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_129]] - // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_133:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_135:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_134]] - // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_138:.*]] = mhlo.add %[[TMP_131]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_138]] - // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_142:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_33]] - // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<0.0833333358> - // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_143]], %[[TMP_140]] - // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.add %[[TMP_141]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_39]], %[[TMP_147]] + // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] + // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] + // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_36]], %[[TMP_37]] + // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] + // CHECK: %[[TMP_40:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_39]] + // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_0]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_44:.*]] = mhlo.add %[[TMP_0]], %[[TMP_43]] + // CHECK: %[[TMP_45:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.constant dense<-1.39544646E-19> + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_2]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_47]] + // CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_48]] + // CHECK: %[[TMP_50:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_51:.*]] = mhlo.add %[[TMP_0]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_0]], %[[TMP_52]] + // CHECK: %[[TMP_54:.*]] = mhlo.multiply %[[TMP_51]], %[[TMP_53]] + // CHECK: %[[TMP_55:.*]] = mhlo.constant dense<5.50900303E-18> + // CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_49]], %[[TMP_55]] + // CHECK: %[[TMP_57:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_56]] + // CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_54]], %[[TMP_57]] + // CHECK: %[[TMP_59:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_0]], %[[TMP_59]] + // CHECK: %[[TMP_61:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_0]], %[[TMP_61]] + // CHECK: %[[TMP_63:.*]] = mhlo.multiply %[[TMP_60]], %[[TMP_62]] + // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<-2.17486866E-16> + // CHECK: %[[TMP_65:.*]] = mhlo.add %[[TMP_58]], %[[TMP_64]] + // CHECK: %[[TMP_66:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_63]], %[[TMP_66]] + // CHECK: %[[TMP_68:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_69:.*]] = mhlo.add %[[TMP_0]], %[[TMP_68]] + // CHECK: %[[TMP_70:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_0]], %[[TMP_70]] + // CHECK: %[[TMP_72:.*]] = mhlo.multiply %[[TMP_69]], %[[TMP_71]] + // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<8.58606213E-15> + // CHECK: %[[TMP_74:.*]] = mhlo.add %[[TMP_67]], %[[TMP_73]] + // CHECK: %[[TMP_75:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_74]] + // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_72]], %[[TMP_75]] + // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_0]], %[[TMP_77]] + // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_80:.*]] = mhlo.add %[[TMP_0]], %[[TMP_79]] + // CHECK: %[[TMP_81:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_80]] + // CHECK: %[[TMP_82:.*]] = mhlo.constant dense<-3.3896803E-13> + // CHECK: %[[TMP_83:.*]] = mhlo.add %[[TMP_76]], %[[TMP_82]] + // CHECK: %[[TMP_84:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_83]] + // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_84]] + // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_0]], %[[TMP_86]] + // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_89:.*]] = mhlo.add %[[TMP_0]], %[[TMP_88]] + // CHECK: %[[TMP_90:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_89]] + // CHECK: %[[TMP_91:.*]] = mhlo.constant dense<1.33825364E-11> + // CHECK: %[[TMP_92:.*]] = mhlo.add %[[TMP_85]], %[[TMP_91]] + // CHECK: %[[TMP_93:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_93]] + // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_0]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_98:.*]] = mhlo.add %[[TMP_0]], %[[TMP_97]] + // CHECK: %[[TMP_99:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_98]] + // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<-5.28419031E-10> + // CHECK: %[[TMP_101:.*]] = mhlo.add %[[TMP_94]], %[[TMP_100]] + // CHECK: %[[TMP_102:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_101]] + // CHECK: %[[TMP_103:.*]] = mhlo.multiply %[[TMP_99]], %[[TMP_102]] + // CHECK: %[[TMP_104:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_0]], %[[TMP_104]] + // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_0]], %[[TMP_106]] + // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_105]], %[[TMP_107]] + // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<2.08767563E-8> + // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_103]], %[[TMP_109]] + // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_110]] + // CHECK: %[[TMP_112:.*]] = mhlo.multiply %[[TMP_108]], %[[TMP_111]] + // CHECK: %[[TMP_113:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_0]], %[[TMP_113]] + // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_0]], %[[TMP_115]] + // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_114]], %[[TMP_116]] + // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<-8.26719599E-7> + // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_112]], %[[TMP_118]] + // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_119]] + // CHECK: %[[TMP_121:.*]] = mhlo.multiply %[[TMP_117]], %[[TMP_120]] + // CHECK: %[[TMP_122:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_123:.*]] = mhlo.add %[[TMP_0]], %[[TMP_122]] + // CHECK: %[[TMP_124:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_125:.*]] = mhlo.add %[[TMP_0]], %[[TMP_124]] + // CHECK: %[[TMP_126:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.constant dense<3.30687835E-5> + // CHECK: %[[TMP_128:.*]] = mhlo.add %[[TMP_121]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_128]] + // CHECK: %[[TMP_130:.*]] = mhlo.multiply %[[TMP_126]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_0]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_134:.*]] = mhlo.add %[[TMP_0]], %[[TMP_133]] + // CHECK: %[[TMP_135:.*]] = mhlo.multiply %[[TMP_132]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.constant dense<-0.00138888892> + // CHECK: %[[TMP_137:.*]] = mhlo.add %[[TMP_130]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_137]] + // CHECK: %[[TMP_139:.*]] = mhlo.multiply %[[TMP_135]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_141:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_33]] + // CHECK: %[[TMP_142:.*]] = mhlo.constant dense<0.0833333358> + // CHECK: %[[TMP_143:.*]] = mhlo.add %[[TMP_142]], %[[TMP_139]] + // CHECK: %[[TMP_144:.*]] = mhlo.multiply %[[TMP_141]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_140]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.add %[[TMP_32]], %[[TMP_38]] + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_147]], %[[TMP_146]] // CHECK: %[[TMP_149:.*]] = mhlo.abs %[[TMP_34]] // CHECK: %[[TMP_150:.*]] = mhlo.abs %[[TMP_32]] // CHECK: %[[TMP_151:.*]] = mhlo.constant dense<1.401300e-45> @@ -1459,7 +1282,7 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_172:.*]] = mhlo.and %[[TMP_169]], %[[TMP_171]] : tensor // CHECK: %[[TMP_173:.*]] = mhlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] // CHECK: %[[TMP_174:.*]] = mhlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] - // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_5]], NOTYPE + // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE // CHECK: %[[TMP_176:.*]] = mhlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] // CHECK: %[[TMP_177:.*]] = mhlo.convert %[[TMP_176]] : (tensor) -> tensor %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor @@ -1561,153 +1384,153 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7F800000> // CHECK: %[[TMP_88:.*]] = mhlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] // CHECK: %[[TMP_89:.*]] = mhlo.exponential %[[TMP_88]] - // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_91:.*]] = mhlo.negate %[[TMP_5]] - // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]] - // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]] - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]] - // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]] + // CHECK-DAG: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_91:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_92:.*]] = mhlo.negate %[[TMP_5]] + // CHECK: %[[TMP_93:.*]] = mhlo.power %[[ARG1]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_91]] + // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_92]] + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_93]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_91]] + // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_92]] // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]] - // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_91]] + // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_92]] // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]] - // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_92]] // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]] - // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]] - // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_92]] // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]] - // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_92]] // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]] - // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_92]] // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]] - // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]] - // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_92]] // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]] - // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_92]] // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]] - // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]] + // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]] - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]] - // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_148]] - // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_150]] - // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]] - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]] - // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_157]] - // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_159]] - // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]] - // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]] - // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]] - // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_166]] - // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_168]] - // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]] - // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]] - // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]] - // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_175]] - // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_177]] - // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]] - // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]] - // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]] - // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_184]] - // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_186]] - // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]] - // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]] - // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]] - // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_193]] - // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_195]] - // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]] - // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]] - // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]] - // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_202]] - // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_204]] - // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]] - // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]] - // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]] - // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_211]] - // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_213]] - // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]] - // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]] - // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]] - // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_220]] - // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_222]] - // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]] - // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]] - // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] - // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.0833333358> - // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]] - // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]] - // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]] - // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]] - // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]] + // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] + // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_5]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_5]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_132]] + // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<-1.39544646E-19> + // CHECK: %[[TMP_135:.*]] = mhlo.add %[[TMP_90]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_135]] + // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_5]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_141:.*]] = mhlo.add %[[TMP_5]], %[[TMP_140]] + // CHECK: %[[TMP_142:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_141]] + // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<5.50900303E-18> + // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_137]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_5]], %[[TMP_147]] + // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_150:.*]] = mhlo.add %[[TMP_5]], %[[TMP_149]] + // CHECK: %[[TMP_151:.*]] = mhlo.multiply %[[TMP_148]], %[[TMP_150]] + // CHECK: %[[TMP_152:.*]] = mhlo.constant dense<-2.17486866E-16> + // CHECK: %[[TMP_153:.*]] = mhlo.add %[[TMP_146]], %[[TMP_152]] + // CHECK: %[[TMP_154:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_153]] + // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_151]], %[[TMP_154]] + // CHECK: %[[TMP_156:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_157:.*]] = mhlo.add %[[TMP_5]], %[[TMP_156]] + // CHECK: %[[TMP_158:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_159:.*]] = mhlo.add %[[TMP_5]], %[[TMP_158]] + // CHECK: %[[TMP_160:.*]] = mhlo.multiply %[[TMP_157]], %[[TMP_159]] + // CHECK: %[[TMP_161:.*]] = mhlo.constant dense<8.58606213E-15> + // CHECK: %[[TMP_162:.*]] = mhlo.add %[[TMP_155]], %[[TMP_161]] + // CHECK: %[[TMP_163:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_162]] + // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_160]], %[[TMP_163]] + // CHECK: %[[TMP_165:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_166:.*]] = mhlo.add %[[TMP_5]], %[[TMP_165]] + // CHECK: %[[TMP_167:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_168:.*]] = mhlo.add %[[TMP_5]], %[[TMP_167]] + // CHECK: %[[TMP_169:.*]] = mhlo.multiply %[[TMP_166]], %[[TMP_168]] + // CHECK: %[[TMP_170:.*]] = mhlo.constant dense<-3.3896803E-13> + // CHECK: %[[TMP_171:.*]] = mhlo.add %[[TMP_164]], %[[TMP_170]] + // CHECK: %[[TMP_172:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_171]] + // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_169]], %[[TMP_172]] + // CHECK: %[[TMP_174:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_175:.*]] = mhlo.add %[[TMP_5]], %[[TMP_174]] + // CHECK: %[[TMP_176:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_177:.*]] = mhlo.add %[[TMP_5]], %[[TMP_176]] + // CHECK: %[[TMP_178:.*]] = mhlo.multiply %[[TMP_175]], %[[TMP_177]] + // CHECK: %[[TMP_179:.*]] = mhlo.constant dense<1.33825364E-11> + // CHECK: %[[TMP_180:.*]] = mhlo.add %[[TMP_173]], %[[TMP_179]] + // CHECK: %[[TMP_181:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_180]] + // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_178]], %[[TMP_181]] + // CHECK: %[[TMP_183:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_184:.*]] = mhlo.add %[[TMP_5]], %[[TMP_183]] + // CHECK: %[[TMP_185:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_186:.*]] = mhlo.add %[[TMP_5]], %[[TMP_185]] + // CHECK: %[[TMP_187:.*]] = mhlo.multiply %[[TMP_184]], %[[TMP_186]] + // CHECK: %[[TMP_188:.*]] = mhlo.constant dense<-5.28419031E-10> + // CHECK: %[[TMP_189:.*]] = mhlo.add %[[TMP_182]], %[[TMP_188]] + // CHECK: %[[TMP_190:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_189]] + // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_187]], %[[TMP_190]] + // CHECK: %[[TMP_192:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_193:.*]] = mhlo.add %[[TMP_5]], %[[TMP_192]] + // CHECK: %[[TMP_194:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_195:.*]] = mhlo.add %[[TMP_5]], %[[TMP_194]] + // CHECK: %[[TMP_196:.*]] = mhlo.multiply %[[TMP_193]], %[[TMP_195]] + // CHECK: %[[TMP_197:.*]] = mhlo.constant dense<2.08767563E-8> + // CHECK: %[[TMP_198:.*]] = mhlo.add %[[TMP_191]], %[[TMP_197]] + // CHECK: %[[TMP_199:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_198]] + // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_196]], %[[TMP_199]] + // CHECK: %[[TMP_201:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_202:.*]] = mhlo.add %[[TMP_5]], %[[TMP_201]] + // CHECK: %[[TMP_203:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_204:.*]] = mhlo.add %[[TMP_5]], %[[TMP_203]] + // CHECK: %[[TMP_205:.*]] = mhlo.multiply %[[TMP_202]], %[[TMP_204]] + // CHECK: %[[TMP_206:.*]] = mhlo.constant dense<-8.26719599E-7> + // CHECK: %[[TMP_207:.*]] = mhlo.add %[[TMP_200]], %[[TMP_206]] + // CHECK: %[[TMP_208:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_207]] + // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_205]], %[[TMP_208]] + // CHECK: %[[TMP_210:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_211:.*]] = mhlo.add %[[TMP_5]], %[[TMP_210]] + // CHECK: %[[TMP_212:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_213:.*]] = mhlo.add %[[TMP_5]], %[[TMP_212]] + // CHECK: %[[TMP_214:.*]] = mhlo.multiply %[[TMP_211]], %[[TMP_213]] + // CHECK: %[[TMP_215:.*]] = mhlo.constant dense<3.30687835E-5> + // CHECK: %[[TMP_216:.*]] = mhlo.add %[[TMP_209]], %[[TMP_215]] + // CHECK: %[[TMP_217:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_216]] + // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_214]], %[[TMP_217]] + // CHECK: %[[TMP_219:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_220:.*]] = mhlo.add %[[TMP_5]], %[[TMP_219]] + // CHECK: %[[TMP_221:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_222:.*]] = mhlo.add %[[TMP_5]], %[[TMP_221]] + // CHECK: %[[TMP_223:.*]] = mhlo.multiply %[[TMP_220]], %[[TMP_222]] + // CHECK: %[[TMP_224:.*]] = mhlo.constant dense<-0.00138888892> + // CHECK: %[[TMP_225:.*]] = mhlo.add %[[TMP_218]], %[[TMP_224]] + // CHECK: %[[TMP_226:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_225]] + // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_223]], %[[TMP_226]] + // CHECK: %[[TMP_228:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_229:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] + // CHECK: %[[TMP_230:.*]] = mhlo.constant dense<0.0833333358> + // CHECK: %[[TMP_231:.*]] = mhlo.add %[[TMP_230]], %[[TMP_227]] + // CHECK: %[[TMP_232:.*]] = mhlo.multiply %[[TMP_229]], %[[TMP_231]] + // CHECK: %[[TMP_233:.*]] = mhlo.add %[[TMP_228]], %[[TMP_232]] + // CHECK: %[[TMP_234:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_233]] + // CHECK: %[[TMP_235:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] + // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_235]], %[[TMP_234]] // CHECK: %[[TMP_237:.*]] = mhlo.abs %[[TMP_122]] // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<1.401300e-45> @@ -1734,7 +1557,7 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] @@ -1948,153 +1771,153 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7FF0000000000000> // CHECK: %[[TMP_88:.*]] = mhlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] // CHECK: %[[TMP_89:.*]] = mhlo.exponential %[[TMP_88]] - // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_91:.*]] = mhlo.negate %[[TMP_5]] - // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]] - // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]] - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]] - // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]] + // CHECK-DAG: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_91:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_92:.*]] = mhlo.negate %[[TMP_5]] + // CHECK: %[[TMP_93:.*]] = mhlo.power %[[ARG1]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_91]] + // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_92]] + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_93]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_91]] + // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_92]] // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]] - // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_91]] + // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_92]] // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]] - // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_92]] // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]] - // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]] - // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_92]] // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]] - // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_92]] // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]] - // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_92]] // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]] - // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]] - // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_92]] // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]] - // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_92]] // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]] - // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]] + // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]] - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.3954464685812522E-19> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]] - // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.5090028283602295E-18> - // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_148]] - // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_150]] - // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]] - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.1748686985580617E-16> - // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]] - // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_157]] - // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_159]] - // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]] - // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.5860620562778452E-15> - // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]] - // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]] - // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_166]] - // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_168]] - // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]] - // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896802963225832E-13> - // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]] - // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]] - // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_175]] - // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_177]] - // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]] - // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.3382536530684679E-11> - // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]] - // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]] - // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_184]] - // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_186]] - // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]] - // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.2841901386874932E-10> - // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]] - // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]] - // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_193]] - // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_195]] - // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]] - // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767569878681E-8> - // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]] - // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]] - // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_202]] - // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_204]] - // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]] - // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.2671957671957675E-7> - // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]] - // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]] - // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_211]] - // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_213]] - // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]] - // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.3068783068783071E-5> - // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]] - // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]] - // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_220]] - // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_222]] - // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]] - // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.0013888888888888889> - // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]] - // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] - // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.083333333333333329> - // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]] - // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]] - // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]] - // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]] - // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]] + // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] + // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_5]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_5]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_132]] + // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<-1.3954464685812522E-19> + // CHECK: %[[TMP_135:.*]] = mhlo.add %[[TMP_90]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_135]] + // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_5]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_141:.*]] = mhlo.add %[[TMP_5]], %[[TMP_140]] + // CHECK: %[[TMP_142:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_141]] + // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<5.5090028283602295E-18> + // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_137]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_5]], %[[TMP_147]] + // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_150:.*]] = mhlo.add %[[TMP_5]], %[[TMP_149]] + // CHECK: %[[TMP_151:.*]] = mhlo.multiply %[[TMP_148]], %[[TMP_150]] + // CHECK: %[[TMP_152:.*]] = mhlo.constant dense<-2.1748686985580617E-16> + // CHECK: %[[TMP_153:.*]] = mhlo.add %[[TMP_146]], %[[TMP_152]] + // CHECK: %[[TMP_154:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_153]] + // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_151]], %[[TMP_154]] + // CHECK: %[[TMP_156:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_157:.*]] = mhlo.add %[[TMP_5]], %[[TMP_156]] + // CHECK: %[[TMP_158:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_159:.*]] = mhlo.add %[[TMP_5]], %[[TMP_158]] + // CHECK: %[[TMP_160:.*]] = mhlo.multiply %[[TMP_157]], %[[TMP_159]] + // CHECK: %[[TMP_161:.*]] = mhlo.constant dense<8.5860620562778452E-15> + // CHECK: %[[TMP_162:.*]] = mhlo.add %[[TMP_155]], %[[TMP_161]] + // CHECK: %[[TMP_163:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_162]] + // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_160]], %[[TMP_163]] + // CHECK: %[[TMP_165:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_166:.*]] = mhlo.add %[[TMP_5]], %[[TMP_165]] + // CHECK: %[[TMP_167:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_168:.*]] = mhlo.add %[[TMP_5]], %[[TMP_167]] + // CHECK: %[[TMP_169:.*]] = mhlo.multiply %[[TMP_166]], %[[TMP_168]] + // CHECK: %[[TMP_170:.*]] = mhlo.constant dense<-3.3896802963225832E-13> + // CHECK: %[[TMP_171:.*]] = mhlo.add %[[TMP_164]], %[[TMP_170]] + // CHECK: %[[TMP_172:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_171]] + // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_169]], %[[TMP_172]] + // CHECK: %[[TMP_174:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_175:.*]] = mhlo.add %[[TMP_5]], %[[TMP_174]] + // CHECK: %[[TMP_176:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_177:.*]] = mhlo.add %[[TMP_5]], %[[TMP_176]] + // CHECK: %[[TMP_178:.*]] = mhlo.multiply %[[TMP_175]], %[[TMP_177]] + // CHECK: %[[TMP_179:.*]] = mhlo.constant dense<1.3382536530684679E-11> + // CHECK: %[[TMP_180:.*]] = mhlo.add %[[TMP_173]], %[[TMP_179]] + // CHECK: %[[TMP_181:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_180]] + // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_178]], %[[TMP_181]] + // CHECK: %[[TMP_183:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_184:.*]] = mhlo.add %[[TMP_5]], %[[TMP_183]] + // CHECK: %[[TMP_185:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_186:.*]] = mhlo.add %[[TMP_5]], %[[TMP_185]] + // CHECK: %[[TMP_187:.*]] = mhlo.multiply %[[TMP_184]], %[[TMP_186]] + // CHECK: %[[TMP_188:.*]] = mhlo.constant dense<-5.2841901386874932E-10> + // CHECK: %[[TMP_189:.*]] = mhlo.add %[[TMP_182]], %[[TMP_188]] + // CHECK: %[[TMP_190:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_189]] + // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_187]], %[[TMP_190]] + // CHECK: %[[TMP_192:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_193:.*]] = mhlo.add %[[TMP_5]], %[[TMP_192]] + // CHECK: %[[TMP_194:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_195:.*]] = mhlo.add %[[TMP_5]], %[[TMP_194]] + // CHECK: %[[TMP_196:.*]] = mhlo.multiply %[[TMP_193]], %[[TMP_195]] + // CHECK: %[[TMP_197:.*]] = mhlo.constant dense<2.08767569878681E-8> + // CHECK: %[[TMP_198:.*]] = mhlo.add %[[TMP_191]], %[[TMP_197]] + // CHECK: %[[TMP_199:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_198]] + // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_196]], %[[TMP_199]] + // CHECK: %[[TMP_201:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_202:.*]] = mhlo.add %[[TMP_5]], %[[TMP_201]] + // CHECK: %[[TMP_203:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_204:.*]] = mhlo.add %[[TMP_5]], %[[TMP_203]] + // CHECK: %[[TMP_205:.*]] = mhlo.multiply %[[TMP_202]], %[[TMP_204]] + // CHECK: %[[TMP_206:.*]] = mhlo.constant dense<-8.2671957671957675E-7> + // CHECK: %[[TMP_207:.*]] = mhlo.add %[[TMP_200]], %[[TMP_206]] + // CHECK: %[[TMP_208:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_207]] + // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_205]], %[[TMP_208]] + // CHECK: %[[TMP_210:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_211:.*]] = mhlo.add %[[TMP_5]], %[[TMP_210]] + // CHECK: %[[TMP_212:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_213:.*]] = mhlo.add %[[TMP_5]], %[[TMP_212]] + // CHECK: %[[TMP_214:.*]] = mhlo.multiply %[[TMP_211]], %[[TMP_213]] + // CHECK: %[[TMP_215:.*]] = mhlo.constant dense<3.3068783068783071E-5> + // CHECK: %[[TMP_216:.*]] = mhlo.add %[[TMP_209]], %[[TMP_215]] + // CHECK: %[[TMP_217:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_216]] + // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_214]], %[[TMP_217]] + // CHECK: %[[TMP_219:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_220:.*]] = mhlo.add %[[TMP_5]], %[[TMP_219]] + // CHECK: %[[TMP_221:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_222:.*]] = mhlo.add %[[TMP_5]], %[[TMP_221]] + // CHECK: %[[TMP_223:.*]] = mhlo.multiply %[[TMP_220]], %[[TMP_222]] + // CHECK: %[[TMP_224:.*]] = mhlo.constant dense<-0.0013888888888888889> + // CHECK: %[[TMP_225:.*]] = mhlo.add %[[TMP_218]], %[[TMP_224]] + // CHECK: %[[TMP_226:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_225]] + // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_223]], %[[TMP_226]] + // CHECK: %[[TMP_228:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_229:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] + // CHECK: %[[TMP_230:.*]] = mhlo.constant dense<0.083333333333333329> + // CHECK: %[[TMP_231:.*]] = mhlo.add %[[TMP_230]], %[[TMP_227]] + // CHECK: %[[TMP_232:.*]] = mhlo.multiply %[[TMP_229]], %[[TMP_231]] + // CHECK: %[[TMP_233:.*]] = mhlo.add %[[TMP_228]], %[[TMP_232]] + // CHECK: %[[TMP_234:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_233]] + // CHECK: %[[TMP_235:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] + // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_235]], %[[TMP_234]] // CHECK: %[[TMP_237:.*]] = mhlo.abs %[[TMP_122]] // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<4.940660e-324> @@ -2121,7 +1944,7 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] @@ -2438,12 +2261,9 @@ func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> // CHECK-LABEL: @tan_f16 // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @tan_f16(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = mhlo.convert [[ARG]] : (tensor) -> tensor - // %[[TMP_1:.*]] = mhlo.sine %[[TMP_0]] - // %[[TMP_2:.*]] = mhlo.cosine %[[TMP_0]] - // %[[TMP_3:.*]] = mhlo.divide %[[TMP_1]], %[[TMP_2]] - // %[[TMP_4:.*]] = mhlo.convert %[[TMP_3]] : (tensor) -> tensor - // return %[[TMP_4]] : tensor + // CHECK-HIGH-LEVEL: mhlo.tan + // CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor + // CHECK: return %[[RESULT]] %1 = chlo.tan %arg : tensor -> tensor func.return %1 : tensor } @@ -2453,10 +2273,9 @@ func.func @tan_f16(%arg : tensor) -> tensor { // CHECK-LABEL: @tan_f32 // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @tan_f32(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = mhlo.sine %[[ARG]] - // %[[TMP_1:.*]] = mhlo.cosine %[[ARG]] - // %[[TMP_2:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_1]] - // return %[[TMP_2]] : tensor + // CHECK-HIGH-LEVEL: mhlo.tan + // CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor + // CHECK: return %[[RESULT]] %1 = chlo.tan %arg : tensor -> tensor func.return %1 : tensor } @@ -2466,15 +2285,8 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @top_k // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[ARG]], %[[IOTA]]) ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: return %[[VAL]], %[[IDX]] + // CHECK-HIGH-LEVEL: mhlo.topk + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8, largest = true) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> } @@ -2485,28 +2297,8 @@ func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32 // CHECK-SAME: ([[ARG:%.*]]: tensor // CHECK-SAME: -> (tensor, tensor) func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { - // CHECK-NEXT: [[DIM_0_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 0 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_0_I32x1:%.*]] = mhlo.reshape [[DIM_0_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_1_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_1_I32x1:%.*]] = mhlo.reshape [[DIM_1_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_2_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 2 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_2_I32x1:%.*]] = mhlo.reshape [[DIM_2_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[IOTA_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[DIM_2_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[K_I32:%.*]] = mhlo.constant dense<2> : tensor - // CHECK-NEXT: [[K_I32x1:%.*]] = mhlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[IOTA:%.*]] = "mhlo.dynamic_iota"([[IOTA_SHAPE]]) {iota_dimension = 2 : i64} : (tensor<3xi32>) -> tensor - // CHECK-NEXT: [[SORT:%.*]]:2 = "mhlo.sort"([[ARG]], [[IOTA]]) ({ - // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): - // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: mhlo.return [[CMP]] : tensor - // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : (tensor, tensor) -> (tensor, tensor) - // CHECK-NEXT: [[STARTS:%.*]] = mhlo.constant dense<0> : tensor<3xi64> - // CHECK-NEXT: [[LIMITS:%.*]] = mhlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> tensor<3xi64> - // CHECK-NEXT: [[STRIDES:%.*]] = mhlo.constant dense<1> : tensor<3xi64> - // CHECK-NEXT: [[VAL:%.*]] = mhlo.real_dynamic_slice [[SORT]]#0, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: [[IDX:%.*]] = mhlo.real_dynamic_slice [[SORT]]#1, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: return [[VAL]], [[IDX]] : tensor, tensor + // CHECK-HIGH-LEVEL: mhlo.topk + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2, largest = true) : tensor -> (tensor, tensor) %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) return %values, %indices : tensor, tensor } diff --git a/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir b/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir deleted file mode 100644 index 1b7602b71dd3a..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: mlir-hlo-opt --legalize-sparse-ops="legalize-to-custom-calls=false" %s | FileCheck %s - -#CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) -}> - -// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> - -// CHECK-LABEL: @asinh_scalar( -// CHECK-SAME: %[[ARG:.*]]: tensor) -> tensor { -// CHECK: %[[RESULT:.*]] = chlo.asinh %[[ARG]] : tensor -> tensor -// CHECK: return %[[RESULT]] : tensor -func.func @asinh_scalar(%arg : tensor) -> tensor { - %result = "chlo.asinh"(%arg) : (tensor) -> tensor - func.return %result : tensor -} - -// CHECK-LABEL: @asinh_tensor( -// CHECK-SAME: %[[ARG:.*]]: tensor<10x20xf32, #[[$CSR]]>) -> -// CHECK-SAME: tensor<10x20xf32, #[[$CSR]]> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : -// CHECK-SAME: tensor<10x20xf32, #[[$CSR]]> -// CHECK: %[[VAL:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ARG]] : tensor<10x20xf32, -// CHECK-SAME: #sparse>) outs(%[[OUT]] -// CHECK: sparse_tensor.unary %{{.*}} : f32 to f32 -// CHECK: present = { -// CHECK: tensor.from_elements -// CHECK: chlo.asinh -// CHECK: tensor.extract -// CHECK: sparse_tensor.yield %{{.*}} : f32 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: } -func.func @asinh_tensor(%arg : tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %result = "chlo.asinh"(%arg) : (tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> - func.return %result : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func.func @tan_tensor( -// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<10x20xf32, #[[$CSR]] -// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]] -// CHECK: %[[TMP_1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, #[[$CSR]] -// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, #[[$CSR]] -// CHECK: ^bb0(%[[TMP_arg1:.*]]: f32, %[[TMP_arg2:.*]]: f32): -// CHECK: %[[TMP_2:.*]] = sparse_tensor.unary %[[TMP_arg1]] : f32 to f32 -// CHECK: present = { -// CHECK: ^bb0(%[[TMP_arg3:.*]]: f32): -// CHECK: %[[TMP_3:.*]] = tensor.from_elements %[[TMP_arg3]] : tensor -// CHECK: %[[TMP_4:.*]] = chlo.tan %[[TMP_3]] : tensor -> tensor -// CHECK: %[[TMP_5:.*]] = tensor.extract %[[TMP_4]][] : tensor -// CHECK: sparse_tensor.yield %[[TMP_5]] : f32 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: linalg.yield %[[TMP_2]] : f32 -// CHECK: } -> tensor<10x20xf32, -// CHECK: return %[[TMP_1]] : tensor<10x20xf32, -func.func @tan_tensor(%arg : tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %result = "chlo.tan"(%arg) : (tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> - func.return %result : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func.func @sinh_tensor( -// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<10x20xf32, -// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]] -// CHECK: %[[TMP_1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, #[[$CSR]] -// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, #[[$CSR]] -// CHECK: ^bb0(%[[TMP_arg1:.*]]: f32, %[[TMP_arg2:.*]]: f32): -// CHECK: %[[TMP_2:.*]] = sparse_tensor.unary %[[TMP_arg1]] : f32 to f32 -// CHECK: present = { -// CHECK: ^bb0(%[[TMP_arg3:.*]]: f32): -// CHECK: %[[TMP_3:.*]] = tensor.from_elements %[[TMP_arg3]] : tensor -// CHECK: %[[TMP_4:.*]] = chlo.sinh %[[TMP_3]] : tensor -> tensor -// CHECK: %[[TMP_5:.*]] = tensor.extract %[[TMP_4]][] : tensor -// CHECK: sparse_tensor.yield %[[TMP_5]] : f32 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: linalg.yield %[[TMP_2]] : f32 -// CHECK: } -> tensor<10x20xf32, -// CHECK: return %[[TMP_1]] : tensor<10x20xf32, -func.func @sinh_tensor(%arg : tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %result = "chlo.sinh"(%arg) : (tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> - func.return %result : tensor<10x20xf32, #CSR> -} diff --git a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir index 13077e650541f..388ce00b077a8 100644 --- a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir +++ b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir @@ -50,7 +50,7 @@ func.func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: "scf.parallel"(%[[C0]], %[[C0]], %[[C112]], %[[C112]], %[[C1]], %[[C1]]) <{{.*}}> ({ // CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): // CHECK: "memref.store"(%[[INIT]], %[[RESULT_BUF]], %[[I]], %[[J]]) -// CHECK: "scf.yield"() : () -> () +// CHECK: "scf.reduce"() : () -> () // CHECK: }) // Parallel loop over source buffer to compute scattered values. @@ -155,4 +155,4 @@ func.func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: "memref.atomic_yield"(%[[RES]]) : (f32) -> () // Parallel loop over source buffer yield -// CHECK: "scf.yield"() : () -> () +// CHECK: "scf.reduce"() : () -> () diff --git a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-affine.mlir b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-affine.mlir index a48dd2f98ae43..fc5afe7cc33c4 100644 --- a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-affine.mlir +++ b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-affine.mlir @@ -282,9 +282,9 @@ func.func @gather_2(%arg0: memref<16x11xf32>, %arg1: memref<5x2xi32>, %arg2: mem "lmhlo.copy"(%0, %arg2) : (memref<5x8x6xf32>, memref<5x8x6xf32>) -> () "lmhlo.terminator"() : () -> () } -// CHECK-NEXT: %[[zero:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %c0 = arith.constant 0 : index -// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-DAG: %[[zero:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index // CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<5x8x6xf32> // CHECK-NEXT: affine.for %{{.*}} = 0 to 5 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 8 { @@ -336,9 +336,9 @@ func.func @gather_3(%arg0: memref<16x11xf16>, %arg1: memref<4x2x5xi32>, %arg2: m "lmhlo.copy"(%0, %arg2) : (memref<4x5x8x6xf16>, memref<4x5x8x6xf16>) -> () "lmhlo.terminator"() : () -> () } -// CHECK-NEXT: %[[zero:.*]] = arith.constant 0.000000e+00 : f16 -// CHECK-NEXT: %c0 = arith.constant 0 : index -// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-DAG: %[[zero:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index // CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x8x6xf16> // CHECK-NEXT: affine.for %{{.*}} = 0 to 4 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 5 { @@ -453,11 +453,11 @@ func.func @gather_6(%arg0: memref<16x11x10x9xf32>, %arg1: memref<5x4xi32>, %arg2 "lmhlo.copy"(%0, %arg2) : (memref<5x8x6x5x4xf32>, memref<5x8x6x5x4xf32>) -> () "lmhlo.terminator"() : () -> () } -// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[ZERO_IDX:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[ONE_IDX:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[TWO_IDX:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[THREE_IDX:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_IDX:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[ONE_IDX:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[TWO_IDX:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[THREE_IDX:.*]] = arith.constant 3 : index // CHECK-NEXT: %[[RESULT:.*]] = memref.alloc() : memref<5x8x6x5x4xf32> // CHECK-NEXT: affine.for %{{.*}} = 0 to 5 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 8 { diff --git a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-parallel-loops.mlir b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-parallel-loops.mlir index 5a9afa044f0f3..15434a94a0774 100644 --- a/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-parallel-loops.mlir +++ b/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-to-parallel-loops.mlir @@ -28,7 +28,7 @@ func.func @reduce(%arg: memref<100x10x5xf32>, // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> -// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]] : f32) { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref @@ -39,10 +39,9 @@ func.func @reduce(%arg: memref<100x10x5xf32>, // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: scf.yield // CHECK: } // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: scf.yield +// CHECK: scf.reduce // ----- @@ -69,7 +68,7 @@ func.func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}} -// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]] : f32) { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref @@ -80,7 +79,6 @@ func.func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } -// CHECK: scf.yield // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] // ----- @@ -114,7 +112,7 @@ func.func @dynamic_reduce(%arg: memref, // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref -// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]] : f32) { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref @@ -125,10 +123,9 @@ func.func @dynamic_reduce(%arg: memref, // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: scf.yield // CHECK: } // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: scf.yield +// CHECK: scf.reduce // ----- @@ -182,7 +179,7 @@ func.func @reduce_window(%arg: memref<112x112xf32>, // CHECK: scf.yield [[INIT]] : f32 // CHECK: } -// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]] : f32) { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref // CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref @@ -193,10 +190,9 @@ func.func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: scf.yield // CHECK: } // CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: return // CHECK: } diff --git a/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir b/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir deleted file mode 100644 index 2370f762db439..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir +++ /dev/null @@ -1,250 +0,0 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file \ -// RUN: | mlir-hlo-opt \ -// RUN: | FileCheck %s - -// CHECK-LABEL: func @conv_forward_generic -// CHECK: lmhlo_gpu.conv_forward -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv_forward_generic(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { - %scratch = memref.alloc() : memref<32xi8> - // This defined a 2D convolution over a 8x8 single channel input using a 2x2 - // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) - "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) { - dimension_numbers = #mhlo.conv, - window_strides = dense<[1, 1]> : tensor<2xi64>, - padding = dense<[[0, 0], [1, 0]]> : tensor<2x2xi64>, - lhs_dilation = dense<[1,1]> : tensor<2xi64>, - rhs_dilation = dense<[1,1]> : tensor<2xi64>, - feature_group_count = 1, - batch_group_count = 1, - result_scale = 1.0, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - is_cudnn_reordered_int8 = false, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - > - } : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () - func.return -} - -// CHECK-LABEL: func @conv_forward -// CHECK: lmhlo_gpu.conv_forward -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { - %scratch = memref.alloc() : memref<32xi8> - // This defined a 2D convolution over a 8x8 single channel input using a 2x2 - // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) - lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) - dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], - window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { feature_group_count = 1, batch_group_count = 1, result_scale = 1.0, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - is_cudnn_reordered_int8 = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - > - } : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () - func.return -} - -// CHECK-LABEL: func @conv_backfilter -// CHECK: lmhlo_gpu.conv_backwardfilter -// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) { - %scratch = memref.alloc() : memref<23328xui8> - lmhlo_gpu.conv_backwardfilter(%input, %filter, %output, %scratch) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - is_cudnn_reordered_int8 = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> () - func.return -} - -// CHECK-LABEL: func @conv_backinput -// CHECK: lmhlo_gpu.conv_backwardinput -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} -func.func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) { - %scratch = memref.alloc() : memref<32xui8> - lmhlo_gpu.conv_backwardinput(%input, %filter, %output, %scratch) - dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], - window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} - { backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - is_cudnn_reordered_int8 = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [], - result_scale = 1.000000e+00 : f64 - } : (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> () - func.return -} - -// CHECK-LABEL: func @conv_fused -// CHECK: lmhlo_gpu.conv_forward_fused -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) { - %scratch = memref.alloc() : memref<32xui8> - lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch) - dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { activation_mode = #lmhlo_gpu, - leakyrelu_alpha = 0.0 : f64, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - is_cudnn_reordered_int8 = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#mhlo, #mhlo, #mhlo], - result_scale = 1.000000e+00 : f64 - } : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> () - func.return -} - -// CHECK-LABEL: func @conv_fused_side_input -// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} -func.func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) { - %scratch = memref.alloc() : memref<0xui8> - lmhlo_gpu.conv_forward_fused_with_side_input(%input, %filter, %bias, %side_input, %output, %scratch) - dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], - window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - { activation_mode = #lmhlo_gpu, - backend_config = #lmhlo_gpu.convolution_backend_config< - algorithm = 0, - tensor_ops_enabled = true, - knob_ids = [], - knob_values = [], - is_cudnn_frontend = false, - is_cudnn_reordered_int8 = false, - workspace_size = -1, - operand_0_layout = [3,2,1,0], - operand_1_layout = [3,2,1,0], - result_layout = [3,2,1,0] - >, - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#mhlo, #mhlo, #mhlo], - result_scale = 1.000000e+00 : f64, - side_input_scale = 1.000000e+00 : f64 - } : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> () - func.return -} - -// CHECK-LABEL: func @gemm -func.func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) { - "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [1,1], - rhs_batching_dimensions = [1,1], - lhs_contracting_dimensions = [1,1], - rhs_contracting_dimensions = [1,1] - >, - alpha_real = 0.5, - alpha_imag = 0.0, - beta = 0.0, - batch_size = 1, - lhs_stride = 20, - rhs_stride = 20, - algorithm = 0 - } : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () - func.return -} - -// CHECK-LABEL: func @cholesky -func.func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { - %scratch = memref.alloc() : memref<32xi8> - %info = memref.alloc() : memref<32xi32> - "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true } - : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () - func.return -} - -// CHECK-LABEL: func @ag_start -func.func @ag_start(%arg : memref<10x10xf32>, %out: memref<20x10xf32>) { - %0 = "lmhlo_gpu.all_gather_start"(%arg, %out) - { - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - all_gather_dimension = 0, - is_sync = false, - no_parallel_custom_call = false - } - : (memref<10x10xf32>, memref<20x10xf32>) -> (!mhlo.token) - func.return -} - -// CHECK-LABEL: func @ag_start_mixed -func.func @ag_start_mixed(%arg0 : memref<10x10xf32>, %arg1 : memref<10x10xf16>, - %out0: memref<20x10xf32>, %out1: memref<20x10xf16>) { - %0 = "lmhlo_gpu.all_gather_start"(%arg0, %arg1, %out0, %out1) - { - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - all_gather_dimension = 0, - is_sync = true, - no_parallel_custom_call = true - } - : (memref<10x10xf32>, memref<10x10xf16>, memref<20x10xf32>, memref<20x10xf16>) -> (!mhlo.token) - func.return -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir index 9af9f9f6a71f3..4bf50644127e7 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir @@ -52,8 +52,8 @@ func.func @single_bcast_ensure_order(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16 func.func @double_bcasts(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16x?xf32>, %shape0 : tensor<3xindex>, %shape1 : tensor<3xindex>) -> (tensor, tensor) { - // CHECK-DAG: %[[BCASTED_ARG00:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE0]]) [[BCAST_DIMS0:{broadcast_dimensions = dense<\[1, 2\]> : tensor<2xi64>}]] - // CHECK-DAG: %[[BCASTED_ARG01:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE1]]) [[BCAST_DIMS1:{broadcast_dimensions = dense<\[0, 2\]> : tensor<2xi64>}]] + // CHECK-DAG: %[[BCASTED_ARG00:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE0]]) [[BCAST_DIMS0:<{broadcast_dimensions = dense<\[1, 2\]> : tensor<2xi64>}>]] + // CHECK-DAG: %[[BCASTED_ARG01:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE1]]) [[BCAST_DIMS1:<{broadcast_dimensions = dense<\[0, 2\]> : tensor<2xi64>}>]] // CHECK-DAG: %[[BCASTED_ARG10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE0]]) [[BCAST_DIMS0]] // CHECK-DAG: %[[BCASTED_ARG11:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE1]]) [[BCAST_DIMS1]] // CHECK-DAG: %[[ADD0:.*]] = mhlo.add %[[BCASTED_ARG00]], %[[BCASTED_ARG10]] : [[BCAST_TY:tensor<\?x16x\?xf32>]] @@ -85,8 +85,8 @@ func.func @double_bcasts(%arg0 : tensor<16x?xf32>, %arg1 : tensor<16x?xf32>, func.func @late_output_dimensions(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]] - // CHECK-DAG: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-DAG: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-DAG: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> + // CHECK-DAG: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[SHAPE]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> // CHECK-DAG: %[[SUB:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : [[BCAST_TY:tensor<\?x\?x32xf32>]] // CHECK-DAG: %[[ADD:.*]] = mhlo.add %[[SUB]], %[[SUB]] : [[BCAST_TY]] // CHECK: return %[[ADD]] : [[BCAST_TY]] @@ -118,7 +118,7 @@ func.func @very_late_output_dimensions(%arg0 : tensor, %acc2 = mhlo.subtract %acc1, %arg1 : tensor %acc3 = mhlo.divide %acc2, %arg1 : tensor %1 = shape.shape_of %arg2 : tensor -> tensor<3xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%acc3, %1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<3xindex>) -> tensor + %3 = "mhlo.dynamic_broadcast_in_dim"(%acc3, %1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor, tensor<3xindex>) -> tensor func.return %3 : tensor } @@ -176,7 +176,7 @@ func.func @propagate_within_block_2(%arg : tensor, // CHECK-SAME: %[[ARG:.*]]: tensor<1xindex> func.func @propagate_across_bcasts_cst_src(%s : tensor<1xindex>) -> tensor { // CHECK-DAG: %[[C1:.*]] = mhlo.constant dense : tensor - // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[C1]], %[[ARG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[C1]], %[[ARG]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor // CHECK: return %[[RES]] %0 = mhlo.constant dense : tensor %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %s) @@ -193,7 +193,7 @@ func.func @propagate_across_bcasts_cst_src(%s : tensor<1xindex>) -> tensor // CHECK-LABEL: @compose_bcast_dims // CHECK-SAME: %[[ARG:.*]]: tensor, %[[S0:.*]]: tensor<3xindex>, %[[S1:.*]]: tensor<4xindex> func.func @compose_bcast_dims(%arg : tensor, %s0 : tensor<3xindex>, %s1 : tensor<4xindex>) -> tensor<1x?x1x?xi1> { - // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S1]]) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor, tensor<4xindex>) -> tensor<1x?x1x?xi1> + // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S1]]) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor, tensor<4xindex>) -> tensor<1x?x1x?xi1> // CHECK: return %[[RES]] %1 = "mhlo.dynamic_broadcast_in_dim"(%arg, %s0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} @@ -209,7 +209,7 @@ func.func @compose_bcast_dims(%arg : tensor, %s0 : tensor<3xindex>, %s1 // CHECK-LABEL: @propagate_across_bcasts // CHECK-SAME: %[[ARG:.*]]: tensor, %[[S:.*]]: tensor<3xindex> func.func @propagate_across_bcasts(%arg : tensor, %shape : tensor<3xindex>) -> tensor { - // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor, tensor<3xindex>) -> tensor + // CHECK-DAG: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[S]]) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor, tensor<3xindex>) -> tensor // CHECK: return %[[RES]] %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir index da2d9bc1f28ec..91ddb13539f2a 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir @@ -288,7 +288,7 @@ func.func @clamp_fold_float() -> tensor<6xf32> { // CHECK-LABEL: concatenate_noop func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> - %0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor<4xi32>) -> tensor<4xi32> // CHECK: return [[ARG]] func.return %0 : tensor<4xi32> @@ -298,7 +298,7 @@ func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { func.func @concatenate_noop_typecast(%arg0: tensor) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor // CHECK-NEXT: [[RES:%.+]] = tensor.cast [[ARG]] : tensor to tensor<4xi32> - %0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor) -> tensor<4xi32> // CHECK: return [[RES]] func.return %0 : tensor<4xi32> @@ -308,7 +308,7 @@ func.func @concatenate_noop_typecast(%arg0: tensor) -> tensor<4xi32> { func.func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> // CHECK: return [[ARG0]] func.return %0 : tensor<4xi32> @@ -316,10 +316,10 @@ func.func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32> // CHECK-LABEL: concatenate_forward func.func @concatenate_forward(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<12xi32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<4xi32>) -> tensor<8xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<4xi32>, tensor<4xi32>) -> tensor<8xi32> %1 = mhlo.constant dense<[0, 1, 2, 3]> : tensor<4xi32> - // CHECK: "mhlo.concatenate"(%arg0, %arg1, %0) {dimension = 0 : i64} : (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<12xi32> - %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<8xi32>, tensor<4xi32>) -> tensor<12xi32> + // CHECK: "mhlo.concatenate"(%arg0, %arg1, %0) <{dimension = 0 : i64}> : (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<12xi32> + %2 = "mhlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<8xi32>, tensor<4xi32>) -> tensor<12xi32> func.return %2 : tensor<12xi32> } @@ -327,7 +327,7 @@ func.func @concatenate_forward(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> te // CHECK-LABEL: concatenate_empty_bool func.func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { // CHECK: mhlo.constant - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> func.return %0 : tensor<0xi1> } @@ -335,7 +335,7 @@ func.func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> t // CHECK-LABEL: concatenate_empty_int func.func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { // CHECK: mhlo.constant - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> func.return %0 : tensor<0xi32> } @@ -343,7 +343,7 @@ func.func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> // CHECK-LABEL: concatenate_empty_float func.func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { // CHECK: mhlo.constant - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> func.return %0 : tensor<0xf32> } @@ -353,7 +353,7 @@ func.func @concatenate_const_1D() -> tensor<4xi32> { // CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]> %0 = mhlo.constant dense<[0, 1]> : tensor<2xi32> %1 = mhlo.constant dense<[2, 3]> : tensor<2xi32> - %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> + %2 = "mhlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> // CHECK: return [[VAL]] func.return %2 : tensor<4xi32> @@ -365,7 +365,7 @@ func.func @concatenate_const_1D_float() -> tensor<4xf32> { %0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32> %1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32> - %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> + %2 = "mhlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> // CHECK: return [[VAL]] func.return %2 : tensor<4xf32> @@ -378,7 +378,7 @@ func.func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { // CHECK-SAME: ]> %0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32> %1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32> - %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + %2 = "mhlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> // CHECK: return [[VAL]] func.return %2 : tensor<2x2xi32> @@ -391,7 +391,7 @@ func.func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { // CHECK-SAME: ]> %0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32> %1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32> - %2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + %2 = "mhlo.concatenate"(%0, %1) <{ dimension = 1 : i64 }> : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> // CHECK: return [[VAL]] func.return %2 : tensor<2x2xi32> @@ -405,10 +405,10 @@ func.func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> { } // CHECK-LABEL: constant_like_constant_dynamic -func.func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> { +func.func @constant_like_constant_dynamic(%arg0: tensor) -> tensor { // CHECK: chlo.constant_like - %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor) -> tensor + func.return %0 : tensor } // CHECK-LABEL: dynamic_update_slice_fold_length_0 @@ -440,7 +440,7 @@ func.func @dynamic_update_slice_fold_fail_dynamic_shapes(%arg0: tensor, // CHECK-LABEL: dynamic_slice_variable_start func.func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // CHECK: "mhlo.dynamic_slice" - %1 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %1 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %1 : tensor<1x4xi32> } @@ -452,7 +452,7 @@ func.func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} // CHECK: return %[[RESULT]] : tensor<2xi32> %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.dynamic_slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + %1 = "mhlo.dynamic_slice"(%arg0, %0) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> func.return %1 : tensor<2xi32> } @@ -462,7 +462,7 @@ func.func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %a // CHECK-NOT: mhlo.slice %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor<1x4xi32> + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor, tensor, tensor) -> tensor<1x4xi32> func.return %2 : tensor<1x4xi32> } @@ -475,7 +475,7 @@ func.func @dynamic_slice_constant_start_upper_bound(%arg0: tensor<8x4xi32>, %arg // CHECK: return %[[RESULT]] : tensor<1x4xi32> %0 = mhlo.constant dense<10> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %2 : tensor<1x4xi32> } @@ -488,14 +488,14 @@ func.func @dynamic_slice_constant_start_lower_bound(%arg0: tensor<8x4xi32>, %arg // CHECK: return %[[RESULT]] : tensor<1x4xi32> %0 = mhlo.constant dense<-1> : tensor %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %2 : tensor<1x4xi32> } // CHECK-LABEL: slice_2D_noop // CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) + %0 = "mhlo.slice"(%arg0) <{ limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x2xi64>) -> (tensor<2x2xi64>) // CHECK-NEXT: return [[ARG]] func.return %0 : tensor<2x2xi64> @@ -505,7 +505,7 @@ func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { func.func @slice_1D_fold() -> tensor<2xi64> { %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> // CHECK: mhlo.constant dense<[7, 9]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) func.return %1 : tensor<2xi64> } @@ -513,7 +513,7 @@ func.func @slice_1D_fold() -> tensor<2xi64> { func.func @slice_1D_fp() -> tensor<2xf32> { %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> (tensor<2xf32>) func.return %1 : tensor<2xf32> } @@ -521,7 +521,7 @@ func.func @slice_1D_fp() -> tensor<2xf32> { func.func @slice_1D_strided_fold() -> tensor<2xi64> { %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> // CHECK: mhlo.constant dense<[7, 10]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) func.return %1 : tensor<2xi64> } @@ -532,7 +532,7 @@ func.func @slice_2D_fold() -> tensor<2x2xi64> { // CHECK-SAME: [6, 7], // CHECK-SAME: [10, 11] // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<2x2xi64>) func.return %1 : tensor<2x2xi64> } @@ -542,7 +542,7 @@ func.func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { // CHECK-NEXT: mhlo.constant dense<[ // CHECK-SAME: [0, 1, 2, 3] // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<1x4xi64>) func.return %1 : tensor<1x4xi64> } @@ -552,7 +552,7 @@ func.func @slice_2D_fold_vertical() -> tensor<4x1xi64> { // CHECK-NEXT: mhlo.constant dense<[ // CHECK-SAME: [2], [6], [10], [14] // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<4x1xi64>) func.return %1 : tensor<4x1xi64> } @@ -560,39 +560,39 @@ func.func @slice_2D_fold_vertical() -> tensor<4x1xi64> { func.func @slice_zero_elements() -> tensor<0xi64> { %0 = mhlo.constant dense<> : tensor<0xi64> // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>) + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<0xi64>) -> (tensor<0xi64>) // CHECK: return %[[CONST]] : tensor<0xi64> func.return %1 : tensor<0xi64> } // CHECK-LABEL: slice_unknown_shape -func.func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> - %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @slice_unknown_shape(%arg0: tensor) -> tensor { + // CHECK: "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor + %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor + func.return %0 : tensor } // CHECK-LABEL: slice_concat_fold_first func.func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) // CHECK: return %arg0 func.return %1 : tensor<1x5xf32> } // CHECK-LABEL: slice_concat_fold_second func.func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) // CHECK: return %arg1 func.return %1 : tensor<1x5xf32> } // CHECK-LABEL: slice_concat_fold_second_with_slice func.func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x5xf32>) -> tensor<1x4xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x4xf32>) // CHECK: return [[SLICE]] func.return %1 : tensor<1x4xf32> @@ -600,9 +600,9 @@ func.func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: te // CHECK-LABEL: slice_concat_fold_middle func.func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<1x5xf32>) // CHECK: return [[SLICE]] func.return %1 : tensor<1x5xf32> @@ -610,11 +610,11 @@ func.func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf3 // CHECK-LABEL: slice_concat_fold_two func.func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { - // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} - %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) <{dimension = 0 : i64}> + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) + // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) <{limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<2x5xf32>) // CHECK: return [[SLICE]] func.return %1 : tensor<2x5xf32> @@ -622,9 +622,9 @@ func.func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, // CHECK-LABEL: slice_concat_empty func.func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<0x5xf32>) - %2 = "mhlo.concatenate"(%1, %arg2) { dimension = 0 : i64 } : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<0x5xf32>) + %2 = "mhlo.concatenate"(%1, %arg2) <{ dimension = 0 : i64 }> : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32> // CHECK: return %arg2 func.return %2 : tensor<1x5xf32> @@ -633,28 +633,28 @@ func.func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %a // CHECK-LABEL: func @broadcast_identity func.func @broadcast_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { // CHECK: return %arg0 - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[]> : tensor<0xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> func.return %0 : tensor<2x3x4xf32> } // CHECK-LABEL: func @broadcast_dynamic_shape_identity func.func @broadcast_dynamic_shape_identity(%arg0: tensor) -> tensor { // CHECK: return %arg0 - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[]> : tensor<0xi64>} : (tensor) -> tensor + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @broadcast_dynamic_shape_not_identity func.func @broadcast_dynamic_shape_not_identity(%arg0: tensor) -> tensor<20x?x?x?xf32> { // CHECK: mhlo.broadcast - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[20]> : tensor<1xi64>} : (tensor) -> tensor<20x?x?x?xf32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[20]> : tensor<1xi64>}> : (tensor) -> tensor<20x?x?x?xf32> func.return %0 : tensor<20x?x?x?xf32> } // CHECK-LABEL: func @broadcast_constant_fold_0d func.func @broadcast_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor - %b = "mhlo.broadcast"(%cst) {broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>} : (tensor) -> tensor<1x64x224x224xf32> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor) -> tensor<1x64x224x224xf32> func.return %b : tensor<1x64x224x224xf32> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> @@ -663,7 +663,7 @@ func.func @broadcast_constant_fold_0d() -> tensor<1x64x224x224xf32> { // CHECK-LABEL: func @broadcast_constant_fold func.func @broadcast_constant_fold() -> tensor<1x64x4x4xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> - %b = "mhlo.broadcast"(%cst) {broadcast_sizes = dense<[1, 64]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> func.return %b : tensor<1x64x4x4xf32> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> @@ -674,14 +674,14 @@ func.func @broadcast_constant_fold_not_splat() -> tensor<1x64x2xf32> { // CHECK: mhlo.constant %cst = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> // CHECK: mhlo.broadcast - %b = "mhlo.broadcast"(%cst) {broadcast_sizes = dense<[1, 64]> : tensor<2xi64>} : (tensor<2xf32>) -> tensor<1x64x2xf32> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<2xf32>) -> tensor<1x64x2xf32> func.return %b : tensor<1x64x2xf32> } // CHECK-LABEL: func @broadcast_constant_fold_complex func.func @broadcast_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> - %b = "mhlo.broadcast"(%cst) {broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>} : (tensor>) -> tensor<1x64x224x224xcomplex> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> func.return %b : tensor<1x64x224x224xcomplex> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> @@ -690,31 +690,31 @@ func.func @broadcast_constant_fold_complex() -> tensor<1x64x224x224xcomplex // CHECK-LABEL: func @broadcast_constant_fold_quantized_skipped func.func @broadcast_constant_fold_quantized_skipped() -> tensor<1x64x224x224x!quant.uniform> { %cst = stablehlo.constant() {value = dense<2> : tensor} : () -> tensor> - %b = "mhlo.broadcast"(%cst) {broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>} : (tensor>) -> tensor<1x64x224x224x!quant.uniform> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> func.return %b : tensor<1x64x224x224x!quant.uniform> } // CHECK-NEXT: %[[CST:.*]] = stablehlo.constant() {value = dense<2> : tensor} : () -> tensor> -// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast"(%[[CST:.*]]) {broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>} : (tensor>) -> tensor<1x64x224x224x!quant.uniform> +// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast"(%[[CST:.*]]) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> // CHECK-NEXT: return %[[RES:.*]] : tensor<1x64x224x224x!quant.uniform> // CHECK-LABEL: func @broadcast_in_dim_identity func.func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { // CHECK: return %arg0 - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> func.return %0 : tensor<2x3x4xf32> } // CHECK-LABEL: func @broadcast_in_dim_equivalent_reshape func.func @broadcast_in_dim_equivalent_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> { // CHECK: mhlo.reshape - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> func.return %0 : tensor<1x2x3x4xf32> } // CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts func.func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { // CHECK: mhlo.broadcast_in_dim - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> } @@ -722,16 +722,16 @@ func.func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: t func.func @broadcast_in_dim_equivalent_transpose(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: mhlo.transpose // CHECK-SAME: permutation = dense<[1, 0]> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x2xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> } // CHECK-LABEL: func @broadcast_in_dim_constant_fold_quantized_skipped func.func @broadcast_in_dim_constant_fold_quantized_skipped(%arg0: tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - %b = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + %b = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> func.return %b : tensor<2x2x!quant.uniform> } -// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> +// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> // CHECK-NEXT: return %[[RES:.*]] : tensor<2x2x!quant.uniform> // CHECK-LABEL: func @broadcast_consecutive @@ -739,15 +739,15 @@ func.func @broadcast_consecutive(%arg0: tensor<2x3xf32>) -> tensor<2x3x4x5xf32> // CHECK: mhlo.broadcast_in_dim // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> // CHECK-NEXT: return - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<2x3x4xf32> - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4x5xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<2x3x4xf32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4x5xf32> func.return %1 : tensor<2x3x4x5xf32> } // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> // CHECK: return %[[RESULT]] : tensor<5x4xf32> func.return %0 : tensor<5x4xf32> } @@ -755,8 +755,8 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32> // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor) -> tensor<4x32xi32> { %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32> - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x32xi32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xi32>) -> tensor + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xi32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi32>) -> tensor %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xi32>) -> tensor<4x32xi32> // CHECK: return %[[RESULT]] : tensor<4x32xi32> func.return %2 : tensor<4x32xi32> @@ -765,8 +765,8 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0 // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape(%arg0: tensor) -> tensor<4x32xf32> { %0 = shape.const_shape [4, 32] : tensor<2xindex> - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x32xf32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xindex>) -> tensor<4x32xf32> // CHECK: return %[[RESULT]] : tensor<4x32xf32> func.return %2 : tensor<4x32xf32> @@ -775,8 +775,8 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast(%arg0: tensor) -> tensor { %0 = shape.const_shape [4, 32] : tensor<2xindex> - // CHECK: %[[BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x32xf32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor // CHECK: %[[RESULT:.*]] = tensor.cast %[[BCAST]] : tensor<4x32xf32> to tensor // CHECK: return %[[RESULT]] : tensor func.return %1 : tensor @@ -784,8 +784,8 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_ca // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xi64>) -> tensor<5x4xf32> - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: %[[RESULT:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> // CHECK: return %[[RESULT]] : tensor<5x4xf32> func.return %0 : tensor<5x4xf32> } @@ -794,7 +794,7 @@ func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor func.func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor) -> tensor { // CHECK-SAME: %[[ARG:.*]]: tensor %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor // CHECK: return %[[ARG]] : tensor func.return %2 : tensor } @@ -804,51 +804,48 @@ func.func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> ten // CHECK-SAME: %[[ARG:.*]]: tensor %0 = shape.shape_of %arg0 : tensor -> !shape.shape %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor // CHECK: return %[[ARG]] : tensor func.return %2 : tensor } // CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3 -func.func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor<*xf32>) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> - %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor +func.func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> tensor %1 = tensor.cast %0 : tensor to tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor - // CHECK: return %[[RES]] : tensor + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor func.return %2 : tensor } // CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4 -func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<*xf32>) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> - %0 = shape.shape_of %arg0 : tensor<*xf32> -> !shape.shape +func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> !shape.shape %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor %2 = tensor.cast %1 : tensor to tensor<1xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor - // CHECK: return %[[RES]] : tensor + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor func.return %3 : tensor } // CHECK-LABEL: func @dynamic_broadcast_in_dim_all_dims_non_expanding -func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor<*xf32>, %arg1: tensor<1xindex>) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> +func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<0> : tensor<1xi64>, known_expanding_dimensions = dense<> : tensor<0xi64>, known_nonexpanding_dimensions = dense<0> : tensor<1xi64> - } : (tensor<*xf32>, tensor<1xindex>) -> tensor - // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor - // CHECK: return %[[RES]] : tensor + } : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor func.return %1 : tensor } // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d func.func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor - %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x64x224x224xf32> + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor<1x64x224x224xf32> func.return %b : tensor<1x64x224x224xf32> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> @@ -857,7 +854,7 @@ func.func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { // CHECK-LABEL: func @broadcast_in_dim_constant_fold func.func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> - %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> func.return %b : tensor<1x64x4x4xf32> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> @@ -866,7 +863,7 @@ func.func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> { // CHECK-LABEL: func @broadcast_in_dim_constant_fold_complex func.func @broadcast_in_dim_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> - %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor>) -> tensor<1x64x224x224xcomplex> + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> func.return %b : tensor<1x64x224x224xcomplex> } // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> @@ -895,15 +892,15 @@ func.func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomp func.func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { // CHECK: [[RESULT:%.*]] = "mhlo.iota" // CHECK: return [[RESULT]] - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xindex>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } // CHECK-LABEL: @dynamic_iota_broadcast func.func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { - // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> - // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32> - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32> // CHECK: return [[BROADCAST]] func.return %0 : tensor<5x?xi32> @@ -912,11 +909,11 @@ func.func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { // CHECK-LABEL: @dynamic_iota_broadcast_second func.func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { // CHECK-NEXT: [[CAST1:%.+]] = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi64> - // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi64>) -> tensor<1xi64> // CHECK-NEXT: [[CAST2:%.+]] = arith.index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex> - // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor - // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor<5x?xi32> - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) <{iota_dimension = 0 : i64}> : (tensor<1xindex>) -> tensor + // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32> // CHECK: return [[BROADCAST]] func.return %0 : tensor<5x?xi32> @@ -925,8 +922,8 @@ func.func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?x // CHECK-LABEL: @dynamic_iota_constant func.func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> { // CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32> - // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32> - %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<1x?xi32> // CHECK: return [[BROADCAST]] func.return %0 : tensor<1x?xi32> @@ -935,7 +932,7 @@ func.func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> { // CHECK-LABEL: @iota_constant func.func @iota_constant() -> tensor<1xi32> { // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1xi32> // CHECK: return [[CONST]] : tensor<1xi32> func.return %0 : tensor<1xi32> @@ -944,7 +941,7 @@ func.func @iota_constant() -> tensor<1xi32> { // CHECK-LABEL: @iota_constant_multi func.func @iota_constant_multi() -> tensor<1x4xi32> { // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1x4xi32> // CHECK: return [[CONST]] : tensor<1x4xi32> func.return %0 : tensor<1x4xi32> @@ -954,24 +951,24 @@ func.func @iota_constant_multi() -> tensor<1x4xi32> { func.func @iota_not_lowered_to_constant() -> tensor<4xi32> { // CHECK: [[RESULT:%.*]] = "mhlo.iota" // CHECK: return [[RESULT]] - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> func.return %0 : tensor<4xi32> } // CHECK-LABEL: @iota_broadcast func.func @iota_broadcast() -> tensor<5x4xi32> { - // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> - // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32> + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<5xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5x4xi32> func.return %0 : tensor<5x4xi32> } // CHECK-LABEL: @iota_broadcast func.func @iota_broadcast_second() -> tensor<5x4xi32> { - // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> - // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32> + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<5x4xi32> func.return %0 : tensor<5x4xi32> } @@ -979,7 +976,7 @@ func.func @iota_broadcast_second() -> tensor<5x4xi32> { // CHECK-LABEL: @unary_einsum func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} + // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) <{einsum_config = ",ab->aa"}> %0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> } @@ -1002,9 +999,9 @@ func.func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: te // CHECK-LABEL: func @shape_of_dynamic_reshape // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] -func.func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> { +func.func @shape_of_dynamic_reshape(%arg0: tensor, %shape: tensor<2xindex>) -> tensor<2xindex> { // CHECK: return [[ARG1]] - %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor %1 = shape.shape_of %0 : tensor -> tensor<2xindex> func.return %1 : tensor<2xindex> } @@ -1029,11 +1026,11 @@ func.func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor>, // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] func.func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor, %shape: tensor) -> tensor { // CHECK: return [[ARG0]] - %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor<*xf16> - %1 = shape.shape_of %0 : tensor<*xf16> -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor %2 = shape.num_elements %1 : tensor -> index %3 = tensor.from_elements %2 : tensor<1xindex> - %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor<*xf16>, tensor<1xindex>) -> tensor + %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor, tensor<1xindex>) -> tensor func.return %4 : tensor } @@ -1455,7 +1452,7 @@ func.func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1 // CHECK-LABEL: func @fold_get_dimension_size func.func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { - %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor // CHECK-NEXT: return %[[C]] @@ -1464,7 +1461,7 @@ func.func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { // CHECK-LABEL: func @fold_get_dimension_size_fail func.func @fold_get_dimension_size_fail(%I: tensor<1x128x?xf32>) -> tensor { // CHECK: "mhlo.get_dimension_size" - %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x?xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x?xf32>) -> tensor func.return %size : tensor } @@ -1522,11 +1519,11 @@ func.func @simplify_not_as_select_pred(%arg0 : tensor<4xi1>, %arg1 : tensor<4xf3 // CHECK-LABEL: func @simplify_broadcasted_not_as_select_pred( func.func @simplify_broadcasted_not_as_select_pred(%arg0 : tensor<1xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { %0 = "mhlo.not"(%arg0) : (tensor<1xi1>) -> tensor<1xi1> - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (tensor<1xi1>) -> tensor<4xi1> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64> }> : (tensor<1xi1>) -> tensor<4xi1> %2 = "mhlo.select"(%1, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %2 : tensor<4xf32> - // CHECK: %[[B:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<4xi1> + // CHECK: %[[B:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xi1>) -> tensor<4xi1> // CHECK: %[[R:.*]] = mhlo.select %[[B]], %arg2, %arg1 // CHECK: return %[[R]] } @@ -1543,7 +1540,7 @@ func.func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { indices_are_sorted = false, slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32> func.return %1 : tensor<3x6x5xf32> - // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32> // CHECK: return %[[RET]] : tensor<3x6x5xf32> } @@ -1559,7 +1556,7 @@ func.func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x indices_are_sorted = false, slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor) -> tensor<5x6x4xf32> func.return %1 : tensor<5x6x4xf32> - // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32> + // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32> // CHECK: return %[[RET]] : tensor<5x6x4xf32> } @@ -1576,7 +1573,7 @@ func.func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> indices_are_sorted = false, slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32> func.return %1 : tensor<3x6xf32> - // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32> + // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32> // CHECK: %[[V1:.*]] = mhlo.reshape %[[V0]] : (tensor<3x6x1xf32>) -> tensor<3x6xf32> // CHECK: return %[[V1]] : tensor<3x6xf32> } @@ -1593,7 +1590,7 @@ func.func @gather_to_slice_indices_clamp_upperbound(%arg0 : tensor<4x2xui32>) -> >, indices_are_sorted = true, slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32> func.return %1 : tensor<2xui32> - // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<[3, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x2xui32>) -> tensor<1x2xui32> + // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<[3, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x2xui32>) -> tensor<1x2xui32> // CHECK: %[[V1:.*]] = mhlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32> // CHECK: return %[[V1]] : tensor<2xui32> } @@ -1610,7 +1607,7 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) -> >, indices_are_sorted = true, slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32> func.return %1 : tensor<2xui32> - // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x2xui32>) -> tensor<1x2xui32> + // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x2xui32>) -> tensor<1x2xui32> // CHECK: %[[V1:.*]] = mhlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32> // CHECK: return %[[V1]] : tensor<2xui32> } @@ -2376,7 +2373,7 @@ func.func @pad_negative_fold() -> tensor<4x4xi32> { func.func @pad_fold_zero_elements() -> tensor<3xi32> { %0 = mhlo.constant dense<> : tensor<0xi32> %1 = mhlo.constant dense<7> : tensor - %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor) -> tensor<3xi32> + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<0xi32>, tensor) -> tensor<3xi32> func.return %2 : tensor<3xi32> // CHECK: mhlo.constant dense<7> : tensor<3xi32> } @@ -2385,14 +2382,14 @@ func.func @pad_fold_zero_elements() -> tensor<3xi32> { func.func @pad_float_fold() -> tensor<2xf32> { %0 = mhlo.constant dense<2.000000e+00> : tensor<1xf32> %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xf32>, tensor) -> tensor<2xf32> + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xf32>, tensor) -> tensor<2xf32> return %2 : tensor<2xf32> // CHECK: mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> } // CHECK-LABEL: @pad_zero_length func.func @pad_zero_length(%arg0: tensor<5x0xf32>, %arg1: tensor) -> tensor<7x2xf32> { - // CHECK: %[[RES:.+]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<7x2xf32> + // CHECK: %[[RES:.+]] = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<7x2xf32> %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_low = dense<1> : tensor<2xi64>, edge_padding_high = dense<1> : tensor<2xi64>, @@ -2414,7 +2411,7 @@ func.func @pad_zero_length_dyn(%arg0: tensor, %arg1: tensor) -> te // CHECK-DAG: %[[ADD1:.+]] = arith.addi %[[DIM]], %[[MUL]] // CHECK-DAG: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[C2]] // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[ADD2]], %[[C2]] : tensor<2xindex> - // CHECK-DAG: %[[BROAD:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[BROAD:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_low = dense<1> : tensor<2xi64>, edge_padding_high = dense<1> : tensor<2xi64>, @@ -2446,11 +2443,11 @@ func.func @dynamic_pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<11x15xf32 func.func @dynamic_pad_length_dyn( %arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor { - // CHECK: %[[C0:.+]] = arith.constant 0 : i32 - // CHECK: %[[C1:.+]] = arith.constant 1 : i32 - // CHECK: %[[CI0:.+]] = arith.constant 0 : index - // CHECK: %[[CI1:.+]] = arith.constant 1 : index - // CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-DAG: %[[CI0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[CI1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : tensor // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[CI0]] // CHECK: %[[CAST:.+]] = arith.index_cast %[[DIM0]] : index to i32 // CHECK: %[[EX0:.+]] = tensor.extract %arg1[%[[CI0]]] @@ -2467,7 +2464,7 @@ func.func @dynamic_pad_length_dyn( // CHECK: %[[EX4:.+]] = tensor.extract %arg2[%[[CI1]]] // CHECK: %[[ADD3:.+]] = arith.addi %[[EX3]], %[[EX4]] : i32 // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[ADD2]], %[[ADD3]] : tensor<2xi32> - // CHECK: %[[BROAD:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CST]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK: %[[BROAD:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CST]], %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> %0 = arith.constant dense<0.0> : tensor %1 = "mhlo.dynamic_pad"(%arg0, %0, %arg1, %arg2, %arg3) { } : (tensor, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor @@ -2479,7 +2476,7 @@ func.func @dynamic_pad_length_dyn( func.func @pad_complex_fold() -> tensor<2xcomplex> { %0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex> %1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> - %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xcomplex>, tensor>) -> tensor<2xcomplex> + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xcomplex>, tensor>) -> tensor<2xcomplex> return %2 : tensor<2xcomplex> // CHECK: mhlo.constant dense<[(2.000000e+00,0.000000e+00), (1.000000e+00,0.000000e+00)]> : tensor<2xcomplex> } @@ -2668,11 +2665,11 @@ func.func @sort_drop_second_arg(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> t // CHECK-LABEL: @sort_drop_second_arg // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) +// CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) <{dimension = 0 : i64, is_stable = false}> ({ // CHECK: ^bb0(%[[ARG2:.+]]: tensor, %[[ARG3:.+]]: tensor) // CHECK: %[[CMP:.+]] = mhlo.compare GT, %[[ARG2]], %[[ARG3]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[CMP]] : tensor -// CHECK: {dimension = 0 : i64, is_stable = false} : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: }) : (tensor<3xi32>) -> tensor<3xi32> // CHECK: return %[[RES]] : tensor<3xi32> func.func @sort_no_dim_provided(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { @@ -2705,7 +2702,7 @@ func.func public @reshape_splat_of_bools() -> tensor<2x1xi1> { func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi64>} : () -> tensor<2xi64> %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi64>) -> tensor<16x64x256xf16> - // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>} : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> // CHECK: return %[[RET]] return %1 : tensor<16x64x256xf16> } @@ -2714,7 +2711,7 @@ func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: ten func.func @simplify_dynamic_gather_i32(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi32>) -> tensor<16x64x256xf16> - // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>} : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> // CHECK: return %[[RET]] return %1 : tensor<16x64x256xf16> } @@ -2757,12 +2754,12 @@ func.func @simplify_real_dynamic_slice_to_dynamic_slice(%arg0: tensor, %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> %3 = mhlo.real_dynamic_slice %arg0, %arg1, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> return %3 : tensor<1x4xf32> - // CHECK: [[START_INDEX_0_1D:%.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK: [[START_INDEX_0_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: [[START_INDEX_0_0D:%.*]] = mhlo.reshape [[START_INDEX_0_1D]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: [[START_INDEX_1_0D:%.*]] = mhlo.reshape [[START_INDEX_1_1D]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: [[RESULT:%.*]] = "mhlo.dynamic_slice"(%arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]]) { + // CHECK-NEXT: [[RESULT:%.*]] = "mhlo.dynamic_slice"(%arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]]) <{ // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> - // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor<1x4xf32> + // CHECK-SAME: }> : (tensor, tensor, tensor) -> tensor<1x4xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x4xf32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir index fc937c1657e37..a770286fbc949 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir @@ -3,17 +3,17 @@ // CHECK-LABEL: func @single_operand // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.concatenate"(%arg) <{dimension = 0 : i64}> : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK-NEXT: return [[ARG]] func.return %0 : tensor<1x2xf32> } // ----- -// CHECK-LABEL: func @operand_with_unknown_rank -func.func @operand_with_unknown_rank(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: func @operand_with_dynamic_shape +func.func @operand_with_dynamic_shape(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-NEXT: mhlo.concatenate - %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{dimension = 0 : i64}> : (tensor, tensor) -> tensor // CHECK-NEXT: return - func.return %0 : tensor<*xf32> + func.return %0 : tensor } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir index 4316d22d08aaf..526ecb3968895 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir @@ -28,8 +28,8 @@ func.func @convolution_is_dot_general_swap(%arg0: tensor<5x6xf32>, %arg1: tensor func.func @conv_grouped_is_dot(%arg0: tensor<5x12xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> { // CHECK: %[[RES0:.+]] = mhlo.reshape %arg0 : (tensor<5x12xf32>) -> tensor<5x6x2xf32> // CHECK: %[[RES1:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<6x1x2xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[RES0]], %[[RES1]]) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%2) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} + // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[RES0]], %[[RES1]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> + // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%2) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}> // CHECK: %[[OUT:.+]] = mhlo.reshape %3 : (tensor<5x6x1xf32>) -> tensor<5x6xf32> %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 6 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<5x12xf32>, tensor<2x6xf32>) -> tensor<5x6xf32> // CHECK: return %[[OUT]] @@ -42,8 +42,8 @@ func.func @conv_grouped_is_dot(%arg0: tensor<5x12xf32>, %arg1: tensor<2x6xf32>) func.func @conv_grouped_is_dot_multi(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<5x6xf32> { // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32> // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} + // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> + // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}> // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32> %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<5x6xf32> // CHECK: return %[[OUT]] @@ -56,8 +56,8 @@ func.func @conv_grouped_is_dot_multi(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf func.func @conv_grouped_is_dot_transpose_rhs(%arg0: tensor<5x4xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> { // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32> // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} + // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> + // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}> // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32> %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<5x4xf32>, tensor<6x2xf32>) -> tensor<5x6xf32> // CHECK: return %[[OUT]] @@ -70,8 +70,8 @@ func.func @conv_grouped_is_dot_transpose_rhs(%arg0: tensor<5x4xf32>, %arg1: tens func.func @conv_grouped_is_dot_transpose_ins(%arg0: tensor<4x5xf32>, %arg1: tensor<6x2xf32>) -> tensor<5x6xf32> { // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<4x5xf32>) -> tensor<2x2x5xf32> // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<6x2xf32>) -> tensor<2x2x3xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} + // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> + // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}> // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] : (tensor<5x2x3xf32>) -> tensor<5x6xf32> %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [f, b]x[o, i]->[b, f], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<4x5xf32>, tensor<6x2xf32>) -> tensor<5x6xf32> // CHECK: return %[[OUT]] @@ -84,8 +84,8 @@ func.func @conv_grouped_is_dot_transpose_ins(%arg0: tensor<4x5xf32>, %arg1: tens func.func @conv_grouped_is_dot_transpose_out(%arg0: tensor<5x4xf32>, %arg1: tensor<2x6xf32>) -> tensor<6x5xf32> { // CHECK: %[[LHS:.+]] = mhlo.reshape %arg0 : (tensor<5x4xf32>) -> tensor<5x2x2xf32> // CHECK: %[[RHS:.+]] = mhlo.reshape %arg1 : (tensor<2x6xf32>) -> tensor<2x3x2xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} + // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> + // CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[DOT]]) <{permutation = dense<[0, 2, 1]> : tensor<3xi64>}> // CHECK: %[[OUT:.+]] = mhlo.reshape %[[TRANSPOSE]] %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<6x5xf32> // CHECK: return %[[OUT]] diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir index 1d62442436c7e..fcb4ae16f5fba 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir @@ -152,7 +152,7 @@ func.func @compare_large_constants() -> tensor<65537xi1> { func.func @concatenate_small_constants() -> tensor<65536xi32> { // CHECK-NOT: mhlo.concatenate %0 = mhlo.constant dense<0> : tensor<32768xi32> - %1 = "mhlo.concatenate"(%0, %0) {dimension = 0 : i64} : (tensor<32768xi32>, tensor<32768xi32>) -> tensor<65536xi32> + %1 = "mhlo.concatenate"(%0, %0) <{dimension = 0 : i64}> : (tensor<32768xi32>, tensor<32768xi32>) -> tensor<65536xi32> func.return %1 : tensor<65536xi32> } @@ -160,7 +160,7 @@ func.func @concatenate_small_constants() -> tensor<65536xi32> { func.func @concatenate_large_constants() -> tensor<65538xi32> { // CHECK: mhlo.concatenate %0 = mhlo.constant dense<0> : tensor<32769xi32> - %1 = "mhlo.concatenate"(%0, %0) {dimension = 0 : i64} : (tensor<32769xi32>, tensor<32769xi32>) -> tensor<65538xi32> + %1 = "mhlo.concatenate"(%0, %0) <{dimension = 0 : i64}> : (tensor<32769xi32>, tensor<32769xi32>) -> tensor<65538xi32> func.return %1 : tensor<65538xi32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir index 70e5f3e4b4fc9..994bc1041a687 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir @@ -70,7 +70,7 @@ func.func @or_fold() -> (tensor, tensor) { // CHECK-LABEL: func @zero_ext func.func @zero_ext(%arg0: tensor<0xi1>) -> tensor { %0 = mhlo.constant dense : tensor - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<0xi1> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<0xi1> %2 = mhlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> %3 = mhlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32> %4 = mhlo.constant dense<0> : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir index 33880f6adb139..0f98bbd754b48 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) func.func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %0 = "mhlo.reverse"(%arg0) <{dimensions = dense<[]> : tensor<0xi64>}> : (tensor<1x2xf32>) -> tensor<1x2xf32> // CHECK: return %[[ARG0]] func.return %0 : tensor<1x2xf32> } @@ -11,7 +11,7 @@ func.func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { // CHECK-LABEL: func @dim1 // CHECK-SAME: (%[[ARG0:.*]]: tensor func.func @dim1(%arg0: tensor<9x1x2x1x42xf32>) -> tensor<9x1x2x1x42xf32> { - %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[1,3]> : tensor<2xi64>} : (tensor<9x1x2x1x42xf32>) -> tensor<9x1x2x1x42xf32> + %0 = "mhlo.reverse"(%arg0) <{dimensions = dense<[1,3]> : tensor<2xi64>}> : (tensor<9x1x2x1x42xf32>) -> tensor<9x1x2x1x42xf32> // CHECK: return %[[ARG0]] func.return %0 : tensor<9x1x2x1x42xf32> } @@ -19,7 +19,7 @@ func.func @dim1(%arg0: tensor<9x1x2x1x42xf32>) -> tensor<9x1x2x1x42xf32> { // CHECK-LABEL: @noop_reverse_dynamic_shape // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @noop_reverse_dynamic_shape(%arg0 : tensor<10x?x512xf32>) -> tensor<10x?x512xf32> { - %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[0,1]> : tensor<2xi64>}: (tensor<10x?x512xf32>) -> tensor<10x?x512xf32> + %0 = "mhlo.reverse"(%arg0) <{dimensions = dense<[0,1]> : tensor<2xi64>}>: (tensor<10x?x512xf32>) -> tensor<10x?x512xf32> // CHECK-NEXT: "mhlo.reverse"([[ARG]]) func.return %0 : tensor<10x?x512xf32> } @@ -28,7 +28,7 @@ func.func @noop_reverse_dynamic_shape(%arg0 : tensor<10x?x512xf32>) -> tensor<10 func.func @reverse_fold_constant_int() -> tensor<0x2x0xi64> { %cst = mhlo.constant dense<> : tensor<0x2x0xi64> // CHECK: mhlo.constant dense<> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0,1]> : tensor<2xi64>} : (tensor<0x2x0xi64>) -> tensor<0x2x0xi64> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0,1]> : tensor<2xi64>}> : (tensor<0x2x0xi64>) -> tensor<0x2x0xi64> func.return %1 : tensor<0x2x0xi64> } @@ -36,7 +36,7 @@ func.func @reverse_fold_constant_int() -> tensor<0x2x0xi64> { func.func @reverse_fold_constant_int_0() -> tensor<0xi64> { %cst = mhlo.constant dense<> : tensor<0xi64> // CHECK: mhlo.constant dense<> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<0xi64>) -> tensor<0xi64> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<0xi64>) -> tensor<0xi64> func.return %1 : tensor<0xi64> } @@ -44,7 +44,7 @@ func.func @reverse_fold_constant_int_0() -> tensor<0xi64> { func.func @reverse_fold_constant_int_1() -> tensor<3x2xi32> { %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> // CHECK: mhlo.constant dense<{{\[\[}}6, 5], [4, 3], [2, 1]]> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0,1]> : tensor<2xi64>} : (tensor<3x2xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0,1]> : tensor<2xi64>}> : (tensor<3x2xi32>) -> tensor<3x2xi32> func.return %1 : tensor<3x2xi32> } @@ -52,7 +52,7 @@ func.func @reverse_fold_constant_int_1() -> tensor<3x2xi32> { func.func @reverse_fold_constant_int_2() -> tensor<3x2xi32> { %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> // CHECK: mhlo.constant dense<{{\[\[}}5, 6], [3, 4], [1, 2]]> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<3x2xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<3x2xi32>) -> tensor<3x2xi32> func.return %1 : tensor<3x2xi32> } @@ -60,7 +60,7 @@ func.func @reverse_fold_constant_int_2() -> tensor<3x2xi32> { func.func @reverse_fold_constant_int_3() -> tensor<3x2xi32> { %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> // CHECK: mhlo.constant dense<{{\[\[}}2, 1], [4, 3], [6, 5]]> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3x2xi32>) -> tensor<3x2xi32> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<3x2xi32>) -> tensor<3x2xi32> func.return %1 : tensor<3x2xi32> } @@ -68,7 +68,7 @@ func.func @reverse_fold_constant_int_3() -> tensor<3x2xi32> { func.func @reverse_fold_constant_int_4() -> tensor<2x3x2xi32> { %cst = mhlo.constant dense<[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]> : tensor<2x3x2xi32> // CHECK: mhlo.constant dense<{{\[\[\[}}12, 11], [10, 9], [8, 7]], {{\[\[}}6, 5], [4, 3], [2, 1]]]> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0,1,2]> : tensor<3xi64>}> : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32> func.return %1 : tensor<2x3x2xi32> } @@ -76,7 +76,7 @@ func.func @reverse_fold_constant_int_4() -> tensor<2x3x2xi32> { func.func @reverse_fold_constant_float() -> tensor<3x2xf32> { %cst = mhlo.constant dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32> // CHECK: mhlo.constant dense<{{\[\[}}6.000000e+00, 5.000000e+00], [4.000000e+00, 3.000000e+00], [2.000000e+00, 1.000000e+00]]> - %1 = "mhlo.reverse"(%cst) {dimensions = dense<[0,1]> : tensor<2xi64>} : (tensor<3x2xf32>) -> tensor<3x2xf32> + %1 = "mhlo.reverse"(%cst) <{dimensions = dense<[0,1]> : tensor<2xi64>}> : (tensor<3x2xf32>) -> tensor<3x2xf32> func.return %1 : tensor<3x2xf32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir index 1a7a6f343e933..901428df8a10c 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir @@ -84,11 +84,11 @@ func.func @scatter_full_overwrite_add( mhlo.return %2 : tensor }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<1xbf16>, tensor<0xi32>, tensor<1xbf16>) -> tensor<1xbf16> - // CHECK: "mhlo.map"(%[[ARG0]], %[[ARG2]]) ({ + // CHECK: "mhlo.map"(%[[ARG0]], %[[ARG2]]) <{dimensions = dense<0> : tensor<1xi64>}> ({ // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor // CHECK: mhlo.return %[[ADD]] : tensor - // CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + // CHECK: }) : (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> func.return %scatter : tensor<1xbf16> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir index 2b4b8cd9982a2..63d9c6d537755 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir @@ -4,7 +4,7 @@ func.func @transpose_splat_constant() -> tensor<5x10xf32> { // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<5x10xf32> %cst = mhlo.constant dense<1.000000e+00> : tensor<10x5xf32> - %0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x5xf32>) -> tensor<5x10xf32> + %0 = "mhlo.transpose"(%cst) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<10x5xf32>) -> tensor<5x10xf32> // CHECK-NEXT: return [[CST]] func.return %0 : tensor<5x10xf32> } @@ -14,7 +14,7 @@ func.func @transpose_splat_constant() -> tensor<5x10xf32> { // CHECK-LABEL: func @remove_noop // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { - %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}>: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> // CHECK-NEXT: return [[ARG]] func.return %0 : tensor<2x3x9x5xi32> } @@ -25,7 +25,7 @@ func.func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-NEXT: "mhlo.transpose"([[ARG]]) - %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}>: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> func.return %0 : tensor<3x2x5x9xi32> } @@ -35,7 +35,7 @@ func.func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { // CHECK-NEXT: "mhlo.transpose"([[ARG]]) - %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[1, 0]> : tensor<2xi64>}>: (tensor<4x4xi32>) -> tensor<4x4xi32> func.return %0 : tensor<4x4xi32> } @@ -44,8 +44,8 @@ func.func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4x // CHECK-LABEL: @eliminate_redundant_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @eliminate_redundant_transpose(%arg : tensor<3x4x16x2xf32>) -> tensor<3x2x16x4xf32> { - %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}: (tensor<3x4x16x2xf32>) -> tensor<3x2x4x16xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}: (tensor<3x2x4x16xf32>) -> tensor<3x2x16x4xf32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}>: (tensor<3x4x16x2xf32>) -> tensor<3x2x4x16xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}>: (tensor<3x2x4x16xf32>) -> tensor<3x2x16x4xf32> // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.transpose"([[ARG]]) // CHECK-SAME: dense<[0, 3, 2, 1] // CHECK-NEXT: return [[RET]] @@ -57,7 +57,7 @@ func.func @eliminate_redundant_transpose(%arg : tensor<3x4x16x2xf32>) -> tensor< // CHECK-LABEL: @simplify_transpose_case1 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @simplify_transpose_case1(%arg : tensor<10x1x512xf32>) -> tensor<1x10x512xf32> { - %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x1x512xf32>) -> tensor<1x10x512xf32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}>: (tensor<10x1x512xf32>) -> tensor<1x10x512xf32> // CHECK-NEXT: mhlo.reshape [[ARG]] func.return %0 : tensor<1x10x512xf32> } @@ -67,7 +67,7 @@ func.func @simplify_transpose_case1(%arg : tensor<10x1x512xf32>) -> tensor<1x10x // CHECK-LABEL: @simplify_transpose_case2 // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @simplify_transpose_case2(%arg : tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> { - %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 3, 0, 2]> : tensor<4xi64>}: (tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[1, 3, 0, 2]> : tensor<4xi64>}>: (tensor<10x1x512x1xf32>) -> tensor<1x1x10x512xf32> // CHECK-NEXT: mhlo.reshape [[ARG]] func.return %0 : tensor<1x1x10x512xf32> } @@ -77,7 +77,7 @@ func.func @simplify_transpose_case2(%arg : tensor<10x1x512x1xf32>) -> tensor<1x1 // CHECK-LABEL: @not_simplify_transpose_dynamic_shape // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @not_simplify_transpose_dynamic_shape(%arg : tensor<10x?x512xf32>) -> tensor { - %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 2]> : tensor<3xi64>}: (tensor<10x?x512xf32>) -> tensor + %0 = "mhlo.transpose"(%arg) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}>: (tensor<10x?x512xf32>) -> tensor // CHECK-NEXT: "mhlo.transpose"([[ARG]]) func.return %0 : tensor } @@ -87,8 +87,8 @@ func.func @not_simplify_transpose_dynamic_shape(%arg : tensor<10x?x512xf32>) -> // CHECK-LABEL: func @broadcast_transpose // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @broadcast_transpose(%arg0 : tensor<64xf32>) -> tensor<5x64x31x95xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<5x31x95x64xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<64xf32>) -> tensor<5x31x95x64xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]]) // CHECK-SAME: dense<1> // CHECK-NEXT: return [[RET]] @@ -100,8 +100,8 @@ func.func @broadcast_transpose(%arg0 : tensor<64xf32>) -> tensor<5x64x31x95xf32> // CHECK-LABEL: func @broadcast_transpose_non_dim // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @broadcast_transpose_non_dim(%arg0 : tensor) -> tensor<5x64x31x95xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<5x31x95x64xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<5x31x95x64xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]]) // CHECK-SAME: dense<> // CHECK-NEXT: return [[RET]] @@ -113,8 +113,8 @@ func.func @broadcast_transpose_non_dim(%arg0 : tensor) -> tensor<5x64x31x95 // CHECK-LABEL: func @broadcast_transpose_multi_dim // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func.func @broadcast_transpose_multi_dim(%arg0 : tensor<95x64xf32>) -> tensor<5x64x31x95xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<95x64xf32>) -> tensor<5x31x95x64xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor<95x64xf32>) -> tensor<5x31x95x64xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<5x31x95x64xf32>) -> tensor<5x64x31x95xf32> // CHECK: [[RET:%[a-zA-Z0-9]+]] = "mhlo.broadcast_in_dim"([[ARG]]) // CHECK-SAME: dense<[3, 1]> // CHECK-NEXT: return [[RET]] @@ -127,7 +127,7 @@ func.func @broadcast_transpose_multi_dim(%arg0 : tensor<95x64xf32>) -> tensor<5x // CHECK-NOT: mhlo.transpose func.func @transpose_splat_constant_quantized_per_tensor() -> tensor<5x10x!quant.uniform> { %cst = mhlo.constant() {value = dense<42> : tensor<10x5xi8>} : () -> tensor<10x5x!quant.uniform> - %0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x5x!quant.uniform>) -> tensor<5x10x!quant.uniform> + %0 = "mhlo.transpose"(%cst) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<10x5x!quant.uniform>) -> tensor<5x10x!quant.uniform> // CHECK-NEXT: [[CST:%.+]] = mhlo.constant // CHECK-SAME: tensor<5x10x!quant.uniform> // CHECK-NEXT: return [[CST]] @@ -140,7 +140,7 @@ func.func @transpose_splat_constant_quantized_per_tensor() -> tensor<5x10x!quant // CHECK-NOT: mhlo.transpose func.func @transpose_splat_constant_quantized_per_axis() -> tensor<2x10x!quant.uniform> { %cst = mhlo.constant() {value = dense<42> : tensor<10x2xi8>} : () -> tensor<10x2x!quant.uniform> - %0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<10x2x!quant.uniform>) -> tensor<2x10x!quant.uniform> + %0 = "mhlo.transpose"(%cst) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<10x2x!quant.uniform>) -> tensor<2x10x!quant.uniform> // CHECK-NEXT: [[CST:%.+]] = mhlo.constant // CHECK-SAME: tensor<2x10x!quant.uniform> // CHECK-NEXT: return [[CST]] @@ -153,9 +153,9 @@ func.func @transpose_splat_constant_quantized_per_axis() -> tensor<2x10x!quant.u // CHECK-LABEL: func @nofold_nonsplat_quant_constant func.func @nofold_nonsplat_quant_constant() -> tensor<4x2x!quant.uniform> { %cst = mhlo.constant() {value = dense<[[1, 2, 3, 4],[5, 6, 7, 8]]> : tensor<2x4xi8>} : () -> tensor<2x4x!quant.uniform> - %0 = "mhlo.transpose"(%cst) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x4x!quant.uniform>) -> tensor<4x2x!quant.uniform> + %0 = "mhlo.transpose"(%cst) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x4x!quant.uniform>) -> tensor<4x2x!quant.uniform> // CHECK: [[TRANSPOSED:%.+]] = "mhlo.transpose" // CHECK-SAME: -> tensor<4x2x!quant.uniform> // CHECK-NEXT: return [[TRANSPOSED]] func.return %0 : tensor<4x2x!quant.uniform> -} \ No newline at end of file +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir index 58f60bb795cec..ee6601b0ae08a 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir @@ -2,7 +2,10 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) <{ + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + }> ({ ^bb0(%arg3: tensor, %arg4: tensor): %2 = "mhlo.compare"(%arg3, %arg4) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor "mhlo.return"(%2) : (tensor) -> () @@ -10,10 +13,7 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) ^bb0(%arg3: tensor, %arg4: tensor): %2 = mhlo.add %arg3, %arg4 : tensor "mhlo.return"(%2) : (tensor) -> () - }) { - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> - } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> func.return %1 : tensor<10x24x24x64xf32> } @@ -24,11 +24,11 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) // CHECK-DAG: %[[NEG_1:.*]] = mhlo.constant dense<-1> : tensor // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<0.000000e+00> : tensor<10x24x24x64xf32> // CHECK-DAG: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[IOTA_0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10x24x24x64xi64> -// CHECK: %[[IOTA_1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<10x24x24x64xi64> -// CHECK: %[[IOTA_2:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<10x24x24x64xi64> -// CHECK: %[[IOTA_3:.*]] = "mhlo.iota"() {iota_dimension = 3 : i64} : () -> tensor<10x24x24x64xi64> -// CHECK: %[[REDUCE_WINDOW:.*]]:5 = "mhlo.reduce_window"(%[[OPERAND]], %[[IOTA_0]], %[[IOTA_1]], %[[IOTA_2]], %[[IOTA_3]], %[[C0]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]]) ({ +// CHECK: %[[IOTA_0:.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_1:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_2:.*]] = "mhlo.iota"() <{iota_dimension = 2 : i64}> : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_3:.*]] = "mhlo.iota"() <{iota_dimension = 3 : i64}> : () -> tensor<10x24x24x64xi64> +// CHECK: %[[REDUCE_WINDOW:.*]]:5 = "mhlo.reduce_window"(%[[OPERAND]], %[[IOTA_0]], %[[IOTA_1]], %[[IOTA_2]], %[[IOTA_3]], %[[C0]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ // CHECK: ^bb0(%[[VAL_10:.*]]: tensor, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor, %[[VAL_14:.*]]: tensor, %[[VAL_15:.*]]: tensor, %[[VAL_16:.*]]: tensor, %[[VAL_17:.*]]: tensor, %[[VAL_18:.*]]: tensor, %[[VAL_19:.*]]: tensor): // CHECK: %[[VAL_20:.*]] = mhlo.compare NE, %[[VAL_11]], %[[NEG_1]] // CHECK: %[[VAL_21:.*]] = mhlo.compare NE, %[[VAL_16]], %[[NEG_1]] @@ -42,16 +42,16 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) // CHECK: %[[SELECTED_3:.*]] = mhlo.select %[[VAL_25]], %[[VAL_13]], %[[VAL_18]] // CHECK: %[[SELECTED_4:.*]] = mhlo.select %[[VAL_25]], %[[VAL_14]], %[[VAL_19]] // CHECK: mhlo.return %[[SELECTED_0]], %[[SELECTED_1]], %[[SELECTED_2]], %[[SELECTED_3]], %[[SELECTED_4]] -// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor, tensor, tensor, tensor, tensor) -> (tensor<10x12x12x64xf32>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>) +// CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor, tensor, tensor, tensor, tensor) -> (tensor<10x12x12x64xf32>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>) // CHECK: %[[RESHAPE_0:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#1 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> // CHECK: %[[RESHAPE_1:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#2 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> // CHECK: %[[RESHAPE_2:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#3 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> // CHECK: %[[RESHAPE_3:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#4 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> -// CHECK: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[RESHAPE_0]], %[[RESHAPE_1]], %[[RESHAPE_2]], %[[RESHAPE_3]]) {dimension = 4 : i64} -// CHECK: %[[SCATTER:.*]] = "mhlo.scatter"(%[[INIT]], %[[CONCAT]], %[[SOURCE]]) ({ +// CHECK: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[RESHAPE_0]], %[[RESHAPE_1]], %[[RESHAPE_2]], %[[RESHAPE_3]]) <{dimension = 4 : i64}> +// CHECK: %[[SCATTER:.*]] = "mhlo.scatter"(%[[INIT]], %[[CONCAT]], %[[SOURCE]]) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%[[VAL_38:.*]]: tensor, %[[VAL_39:.*]]: tensor): // CHECK: %[[UPDATE:.*]] = mhlo.add %[[VAL_38]], %[[VAL_39]] : tensor // CHECK: mhlo.return %[[UPDATE]] : tensor -// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64x4xi64>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> +// CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64x4xi64>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> // CHECK: return %[[SCATTER]] : tensor<10x24x24x64xf32> -// CHECK: } \ No newline at end of file +// CHECK: } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/group_reduction_dimensions.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/group_reduction_dimensions.mlir index e49f3c12429da..ae9349a7b4367 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/group_reduction_dimensions.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/group_reduction_dimensions.mlir @@ -407,7 +407,7 @@ func.func @needs_transpose(%arg : tensor<10x11x12x13x14x15x16x17x18x19xf32>) // CHECK-ROW-RED-SAME: into tensor<110x156x210x272x342xf32> // CHECK-ROW-RED: %[[CTED:.*]] = "mhlo.transpose"(%[[CED]]) // CHECK-ROW-RED-SAME: {permutation = dense<[0, 2, 4, 1, 3]> -// CHECK-ROW-RED-SAME: : tensor<5xi64>} : (tensor<110x156x210x272x342xf32>) +// CHECK-ROW-RED-SAME: : tensor<5xi64>}> : (tensor<110x156x210x272x342xf32>) // CHECK-ROW-RED-SAME: -> tensor<110x210x342x156x272xf32> // CHECK-ROW-RED: %[[CTCED:.*]] = tensor.collapse_shape %[[CTED]] // CHECK-ROW-RED-SAME: {{\[}}[0, 1, 2], [3, 4]{{\]}} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-broadcast-to-broadcast-in-dim.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-broadcast-to-broadcast-in-dim.mlir index 343ce3467f148..f91e65b2cc46f 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-broadcast-to-broadcast-in-dim.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-broadcast-to-broadcast-in-dim.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: @broadcast_to_broadcast_in_dim func.func @broadcast_to_broadcast_in_dim(%arg0: tensor<4xi64>) -> tensor<1x2x3x4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xi64>) -> tensor<1x2x3x4xi64> + // CHECK: [[RES:%.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> : (tensor<4xi64>) -> tensor<1x2x3x4xi64> %0 = "mhlo.broadcast"(%arg0) { broadcast_sizes = dense<[1, 2, 3]> : tensor<3xi64> } : (tensor<4xi64>) -> tensor<1x2x3x4xi64> @@ -14,7 +14,7 @@ func.func @broadcast_to_broadcast_in_dim(%arg0: tensor<4xi64>) -> tensor<1x2x3x4 // CHECK-LABEL: @broadcast_to_broadcast_in_dim_dynamic_operand func.func @broadcast_to_broadcast_in_dim_dynamic_operand(%arg0: tensor) -> tensor<1x2x3x4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor) -> tensor<1x2x3x4xi64> + // CHECK: [[RES:%.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor) -> tensor<1x2x3x4xi64> %0 = "mhlo.broadcast"(%arg0) { broadcast_sizes = dense<[1, 2]> : tensor<2xi64> } : (tensor) -> tensor<1x2x3x4xi64> @@ -31,14 +31,3 @@ func.func @broadcast_to_broadcast_in_dim_dynamic_result(%arg0: tensor<3x4xi64>) } : (tensor<3x4xi64>) -> tensor<1x2x?x4xi64> func.return %0 : tensor<1x2x?x4xi64> } - -// ----- - -// CHECK-LABEL: @broadcast_to_broadcast_in_dim_unranked_result -func.func @broadcast_to_broadcast_in_dim_unranked_result(%arg0: tensor<3x4xi64>) -> tensor<*xi64> { - // CHECK: "mhlo.broadcast" - %0 = "mhlo.broadcast"(%arg0) { - broadcast_sizes = dense<[1, 2]> : tensor<2xi64> - } : (tensor<3x4xi64>) -> tensor<*xi64> - func.return %0 : tensor<*xi64> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-cross-replica-sum-to-all-reduce.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-cross-replica-sum-to-all-reduce.mlir index 70496c07d338e..69932525ec0f7 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-cross-replica-sum-to-all-reduce.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-cross-replica-sum-to-all-reduce.mlir @@ -2,11 +2,12 @@ // CHECK-LABEL: @cross_replica_sum_to_all_reduce func.func @cross_replica_sum_to_all_reduce(%arg0 : tensor<4xi64>) -> tensor<4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.all_reduce"(%arg0) ({ + // CHECK: [[RES:%.+]] = "mhlo.all_reduce"(%arg0) + // CHECK-SAME{LITERAL}: <{replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>}> ({ // CHECK: ^bb0(%arg1: tensor, %arg2: tensor): // CHECK: [[ADD:%.+]] = mhlo.add %arg1, %arg2 : tensor // CHECK: mhlo.return [[ADD]] : tensor - // CHECK-SAME{LITERAL} }) {replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4xi64>) -> tensor<4xi64> + // CHECK-NEXT: }) : (tensor<4xi64>) -> tensor<4xi64> %0 = "mhlo.cross-replica-sum"(%arg0) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> } : (tensor<4xi64>) -> tensor<4xi64> diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-general-to-dot.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-general-to-dot.mlir index b2f5eb6b25b86..9fc51f4272425 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-general-to-dot.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-general-to-dot.mlir @@ -4,7 +4,7 @@ func.func @dot_general_is_dot(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> tensor<5x?xf32> { // CHECK: %[[DOT:.+]] = "mhlo.dot"(%arg0, %arg1) // CHECK-SAME: precision_config = [#mhlo, #mhlo] - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32> // CHECK: %[[DOT]] return %0 : tensor<5x?xf32> } @@ -14,9 +14,9 @@ func.func @dot_general_is_dot(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> // CHECK-LABEL: @dot_general_is_dot_keep_attrs func.func @dot_general_is_dot_keep_attrs(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> tensor<5x?xf32> { // CHECK: %[[DOT:.+]] = "mhlo.dot"(%arg0, %arg1) - // CHECK-SAME: mhlo.frontend_attributes = {test_name = "test_value"} // CHECK-SAME: precision_config = [#mhlo, #mhlo] - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot, mhlo.frontend_attributes = {test_name = "test_value"}, precision_config = [#mhlo, #mhlo]} : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32> + // CHECK-SAME: mhlo.frontend_attributes = {test_name = "test_value"} + %0 = "mhlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> {mhlo.frontend_attributes = {test_name = "test_value"}} : (tensor<5x6xf32>, tensor<6x?xf32>) -> tensor<5x?xf32> // CHECK: %[[DOT]] return %0 : tensor<5x?xf32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-to-dot-general.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-to-dot-general.mlir index c338e5f52f71a..f45ab298ebb47 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-to-dot-general.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-dot-to-dot-general.mlir @@ -2,12 +2,12 @@ // CHECK-LABEL: @dot_to_dot_general_vector_dot_vector func.func @dot_to_dot_general_vector_dot_vector(%arg0 : tensor<4xi64>, %arg1 : tensor<4xi64>) -> tensor { - // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [0], // CHECK-SAME: rhs_contracting_dimensions = [0] // CHECK-SAME: > - // CHECK-SAME: } : (tensor<4xi64>, tensor<4xi64>) -> tensor + // CHECK-SAME: }> : (tensor<4xi64>, tensor<4xi64>) -> tensor %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<4xi64>, tensor<4xi64>) -> tensor func.return %0 : tensor } @@ -16,12 +16,12 @@ func.func @dot_to_dot_general_vector_dot_vector(%arg0 : tensor<4xi64>, %arg1 : t // CHECK-LABEL: @dot_to_dot_general_matrix_dot_vector func.func @dot_to_dot_general_matrix_dot_vector(%arg0 : tensor<4x5xi64>, %arg1 : tensor<5xi64>) -> tensor<4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [1], // CHECK-SAME: rhs_contracting_dimensions = [0] // CHECK-SAME: > - // CHECK-SAME: } : (tensor<4x5xi64>, tensor<5xi64>) -> tensor<4xi64> + // CHECK-SAME: }> : (tensor<4x5xi64>, tensor<5xi64>) -> tensor<4xi64> %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<4x5xi64>, tensor<5xi64>) -> tensor<4xi64> func.return %0 : tensor<4xi64> } @@ -30,12 +30,12 @@ func.func @dot_to_dot_general_matrix_dot_vector(%arg0 : tensor<4x5xi64>, %arg1 : // CHECK-LABEL: @dot_to_dot_general_vector_dot_matrix func.func @dot_to_dot_general_vector_dot_matrix(%arg0 : tensor<5xi64>, %arg1 : tensor<5x4xi64>) -> tensor<4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [0], // CHECK-SAME: rhs_contracting_dimensions = [0] // CHECK-SAME: > - // CHECK-SAME: } : (tensor<5xi64>, tensor<5x4xi64>) -> tensor<4xi64> + // CHECK-SAME: }> : (tensor<5xi64>, tensor<5x4xi64>) -> tensor<4xi64> %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<5xi64>, tensor<5x4xi64>) -> tensor<4xi64> func.return %0 : tensor<4xi64> } @@ -44,20 +44,12 @@ func.func @dot_to_dot_general_vector_dot_matrix(%arg0 : tensor<5xi64>, %arg1 : t // CHECK-LABEL: @dot_to_dot_general_matrix_dot_matrix func.func @dot_to_dot_general_matrix_dot_matrix(%arg0 : tensor<4x5xi64>, %arg1 : tensor<5x4xi64>) -> tensor<4x4xi64> { - // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_contracting_dimensions = [1], // CHECK-SAME: rhs_contracting_dimensions = [0] // CHECK-SAME: > - // CHECK-SAME: } : (tensor<4x5xi64>, tensor<5x4xi64>) -> tensor<4x4xi64> + // CHECK-SAME: }> : (tensor<4x5xi64>, tensor<5x4xi64>) -> tensor<4x4xi64> %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<4x5xi64>, tensor<5x4xi64>) -> tensor<4x4xi64> func.return %0 : tensor<4x4xi64> } - -// ----- - -func.func @dot_to_dot_general_unranked(%arg0 : tensor<*xi64>, %arg1 : tensor<*xi64>) -> tensor<*xi64> { - // expected-error@+1 {{unranked operands}} - %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<*xi64>, tensor<*xi64>) -> tensor<*xi64> - func.return %0 : tensor<*xi64> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-einsum-to-dot-general.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-einsum-to-dot-general.mlir index a224c31027b98..51363c9c9c304 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-einsum-to-dot-general.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-einsum-to-dot-general.mlir @@ -2,7 +2,7 @@ func.func @einsum_diag(%arg0: tensor<6x6xf32>) -> tensor<6xf32> { %0 = mhlo.constant dense<1.000000e+00> : tensor - %1 = "mhlo.einsum"(%0, %arg0) {einsum_config = ",ii->i"} : (tensor, tensor<6x6xf32>) -> tensor<6xf32> + %1 = "mhlo.einsum"(%0, %arg0) <{einsum_config = ",ii->i"}> : (tensor, tensor<6x6xf32>) -> tensor<6xf32> func.return %1 : tensor<6xf32> } // CHECK-LABEL: func @einsum_diag @@ -14,7 +14,7 @@ func.func @einsum_diag(%arg0: tensor<6x6xf32>) -> tensor<6xf32> { // CHECK: "mhlo.einsum" func.func @einsum_batched_matrix_high_rank_vector_mul(%arg0: tensor<8x2x6xf32>, %arg1: tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "bxy,bijy->bijx"} : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "bxy,bijy->bijx"}> : (tensor<8x2x6xf32>, tensor<8x5x3x6xf32>) -> tensor<8x5x3x2xf32> func.return %0 : tensor<8x5x3x2xf32> } // CHECK-LABEL: func @einsum_batched_matrix_high_rank_vector_mul @@ -32,7 +32,7 @@ func.func @einsum_batched_matrix_high_rank_vector_mul(%arg0: tensor<8x2x6xf32>, // CHECK-SAME: : (tensor<8x2x5x3xf32>) -> tensor<8x5x3x2xf32> func.func @matmul(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,jk->ik"} : (tensor, tensor) -> tensor + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ij,jk->ik"}> : (tensor, tensor) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @matmul @@ -45,7 +45,7 @@ func.func @matmul(%arg0: tensor, %arg1: tensor) -> tensor, tensor) -> tensor func.func @matvec(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ij,j->i"} : (tensor, tensor) -> tensor + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ij,j->i"}> : (tensor, tensor) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @matvec @@ -58,7 +58,7 @@ func.func @matvec(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-SAME: : (tensor, tensor) -> tensor func.func @dot(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "i,i->"} : (tensor, tensor) -> tensor + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "i,i->"}> : (tensor, tensor) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @dot diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-gather-to-torch-index-select.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-gather-to-torch-index-select.mlir index 0fd367809d202..2649723ca4fe6 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-gather-to-torch-index-select.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-gather-to-torch-index-select.mlir @@ -2,10 +2,10 @@ // CHECK-LABEL: @gather_to_index_select func.func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> { - // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) <{ // CHECK-SAME: batch_dims = 0 : i64, // CHECK-SAME: dim = 0 : i64 - // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + // CHECK-SAME: }> : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> // CHECK: [[RES:%.+]] = mhlo.reshape [[TIS]] %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-rng-to-linalg.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-rng-to-linalg.mlir index 056fbfb2eb64a..bd7b54c5c3bc7 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-rng-to-linalg.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-rng-to-linalg.mlir @@ -2,7 +2,7 @@ // RUN: FILECHECK_OPTS="" FileCheck %s func.func @three_fry_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) return %output_state, %output : tensor<2xi64>, tensor<8xi64> } @@ -98,7 +98,7 @@ func.func @three_fry_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) // ----- func.func @three_fry_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) return %output_state, %output : tensor<2xi64>, tensor<8xi32> } @@ -140,7 +140,7 @@ func.func @three_fry_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) // ----- func.func @three_fry_odd_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) return %output_state, %output : tensor<2xi64>, tensor<7x11xi32> } @@ -183,7 +183,7 @@ func.func @three_fry_odd_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x1 // ----- func.func @three_fry_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) return %output_state, %output : tensor<2xi64>, tensor<8xi16> } @@ -224,7 +224,7 @@ func.func @three_fry_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) // ----- func.func @philox_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) return %output_state, %output : tensor<2xi64>, tensor<8xi64> } @@ -349,7 +349,7 @@ func.func @philox_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { // ----- func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) return %output_state, %output : tensor<2xi64>, tensor<8xi32> } @@ -385,7 +385,7 @@ func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { // ----- func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) return %output_state, %output : tensor<2xi64>, tensor<7x11xi32> } @@ -436,7 +436,7 @@ func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi func.func @philox_i64_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) return %output_state, %output : tensor<2xi64>, tensor<3x5xi64> } @@ -475,7 +475,7 @@ func.func @philox_i64_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi6 // ----- func.func @philox_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) { - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) return %output_state, %output : tensor<2xi64>, tensor<8xi16> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir index 7726ec47587c3..3af8f3e9b6a20 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -928,7 +928,7 @@ func.func @select_mixed(%pred: tensor<2x?xi1>, %lhs: tensor, // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + %0 = "mhlo.broadcast"(%arg) <{broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>}> : (tensor) -> tensor<4x2x1xf32> func.return %0: tensor<4x2x1xf32> } // CHECK: tensor.empty() : tensor<4x2x1xf32> @@ -949,7 +949,7 @@ func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK-LABEL: func @broadcast func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + %0 = "mhlo.broadcast"(%arg) <{broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>}> : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> func.return %0: tensor<4x2x1x4x?x16xf32> } // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -1092,7 +1092,7 @@ func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @transpose func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> func.return %0 : tensor<3x2x5x9xi32> } @@ -1107,7 +1107,7 @@ func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @transpose_dynamic func.func @transpose_dynamic(%arg0: tensor) -> tensor { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>, someattr} + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> {someattr} : (tensor) -> tensor func.return %0 : tensor } @@ -1944,7 +1944,7 @@ func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota_f32 func.func @iota_f32() -> tensor<7x10xf32> { - %result = "mhlo.iota"() {iota_dimension = 1 : i64, someattr} : () -> (tensor<7x10xf32>) + %result = "mhlo.iota"() <{iota_dimension = 1 : i64}> {someattr} : () -> (tensor<7x10xf32>) func.return %result : tensor<7x10xf32> } // CHECK: tensor.empty @@ -1971,7 +1971,7 @@ func.func @iota_f32() -> tensor<7x10xf32> { // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota_i32 func.func @iota_i32() -> tensor<7x10xi32> { - %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>) + %result = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> (tensor<7x10xi32>) func.return %result : tensor<7x10xi32> } // CHECK: tensor.empty @@ -1987,7 +1987,7 @@ func.func @iota_i32() -> tensor<7x10xi32> { // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota_ui32 func.func @iota_ui32() -> tensor<7x10xui32> { - %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>) + %result = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> (tensor<7x10xui32>) func.return %result : tensor<7x10xui32> } // CHECK: tensor.empty @@ -2004,7 +2004,7 @@ func.func @iota_ui32() -> tensor<7x10xui32> { // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota_complexf32 func.func @iota_complexf32() -> tensor<7x10xcomplex> { - %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xcomplex>) + %result = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> (tensor<7x10xcomplex>) func.return %result : tensor<7x10xcomplex> } // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 @@ -2024,7 +2024,7 @@ func.func @iota_complexf32() -> tensor<7x10xcomplex> { // CHECK-LABEL: func @dynamic_iota_f32 // CHECK-SAME: %[[SHAPE:.*]]: tensor func.func @dynamic_iota_f32(%shape: tensor) -> tensor { - %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) + %result = "mhlo.dynamic_iota"(%shape) <{iota_dimension = 1 : i64}> : (tensor) -> (tensor) func.return %result : tensor } // CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] @@ -2046,7 +2046,7 @@ func.func @dynamic_iota_f32(%shape: tensor) -> tensor { // CHECK-LABEL: func @dyanmic_iota_ui32 // CHECK-SAME: %[[SHAPE:.*]]: tensor func.func @dyanmic_iota_ui32(%shape: tensor) -> tensor { - %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) + %result = "mhlo.dynamic_iota"(%shape) <{iota_dimension = 1 : i64}> : (tensor) -> (tensor) func.return %result : tensor } // CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] @@ -2211,9 +2211,9 @@ func.func @integer_pow(%lhs: tensor<2x2xi32>, // CHECK-SAME: [[SHAPE:%.*]]: tensor<1xindex> func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { %cst = mhlo.constant dense<0x7F800000> : tensor - %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) { - broadcast_dimensions = dense<> : tensor<0xi64>, someattr - } : (tensor, tensor<1xindex>) -> tensor + %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) <{ + broadcast_dimensions = dense<> : tensor<0xi64> + }> {someattr} : (tensor, tensor<1xindex>) -> tensor func.return %result : tensor } // CHECK: [[CST:%.*]] = arith.constant dense @@ -2242,9 +2242,9 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { // CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex> func.func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xindex>) -> tensor { - %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) { + %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) <{ broadcast_dimensions = dense<> : tensor<0xi64> - } : (tensor, tensor<2xindex>) -> tensor + }> : (tensor, tensor<2xindex>) -> tensor func.return %result : tensor } // CHECK: [[INIT:%.*]] = tensor.empty @@ -2293,9 +2293,9 @@ func.func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xin // fail if the %shape i32 -> index cast is not performed properly. func.func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xi32>) -> tensor { - %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) { + %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) <{ broadcast_dimensions = dense<> : tensor<0xi64> - } : (tensor, tensor<2xi32>) -> tensor + }> : (tensor, tensor<2xi32>) -> tensor func.return %result : tensor } @@ -2307,9 +2307,9 @@ func.func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xi32>) // CHECK-LABEL: func @dynamic_broadcast_in_dim( // CHECK-SAME: [[SHAPE:%.*]]: tensor<1xindex>, [[CSTARG:%.*]]: tensor func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>, %cst: tensor) -> tensor { - %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) { + %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) <{ broadcast_dimensions = dense<> : tensor<0xi64> - } : (tensor, tensor<1xindex>) -> tensor + }> : (tensor, tensor<1xindex>) -> tensor func.return %result : tensor } // CHECK: [[CST:%.*]] = builtin.unrealized_conversion_cast [[CSTARG]] : tensor to tensor @@ -2340,10 +2340,10 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>, %cst: tensor) // CHECK-PRIMITIVE-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<7xindex> func.func @dynamic_broadcast_in_dim(%arg: tensor, %shape: tensor<7xindex>) -> tensor { - %result = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { + %result = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) <{ broadcast_dimensions = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>, known_expanding_dimensions = dense<[0, 1]> : tensor<2xi64>, - known_nonexpanding_dimensions = dense<[2, 3]> : tensor<2xi64> } + known_nonexpanding_dimensions = dense<[2, 3]> : tensor<2xi64> }> : (tensor, tensor<7xindex>) -> tensor func.return %result : tensor } @@ -2696,7 +2696,7 @@ func.func @dot_general_batch_matmul_large // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK-LABEL: func @einsum_basic func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ijk,ikm->ijm", someattr}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ijk,ikm->ijm"}> {someattr} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> func.return %0 : tensor<3x4x6xf32> } // CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x5x6xf32>) @@ -2738,7 +2738,7 @@ func.func @dot_general_batch_matvec(%arg0: tensor, // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @einsum_pointwisemul func.func @einsum_pointwisemul(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "abc,abc->abc"} : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "abc,abc->abc"}> : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> func.return %0 : tensor<3x4x5xf32> } // CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x4x5xf32>) @@ -2759,7 +2759,7 @@ func.func @einsum_pointwisemul(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK: func @einsum_matmul func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ae,ed->ad"}>: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> func.return %0 : tensor<7x5xf32> } // CHECK-SAME: (%[[LHS:.*]]: tensor<7x9xf32>, %[[RHS:.*]]: tensor<9x5xf32>) @@ -2783,7 +2783,7 @@ func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tens // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d5)> // CHECK: func @einsum_broadcast4 func.func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "abcdh,hg->abcdg"}>: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> func.return %0 : tensor<3x4x5x6x8xf32> } // CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5x6x7xf32>, %[[RHS:.*]]: tensor<7x8xf32>) @@ -2807,7 +2807,7 @@ func.func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK: func @einsum_ellipsis func.func @einsum_ellipsis(%arg0: tensor<1x512x128xf32>, %arg1: tensor<128x256xf32>) -> tensor<1x512x256xf32> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "...x,xy->...y"} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "...x,xy->...y"}> : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32> func.return %0 : tensor<1x512x256xf32> } // CHECK-SAME: (%[[LHS:.*]]: tensor<1x512x128xf32>, %[[RHS:.*]]: tensor<128x256xf32>) @@ -2831,7 +2831,7 @@ func.func @einsum_ellipsis(%arg0: tensor<1x512x128xf32>, %arg1: tensor<128x256xf // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK: func @einsum_dynamic_size_broadcast_dot func.func @einsum_dynamic_size_broadcast_dot(%arg0: tensor, %arg1: tensor<4x?xf32>) -> tensor { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "abc,cd->abd"} : (tensor, tensor<4x?xf32>) -> tensor + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "abc,cd->abd"}> : (tensor, tensor<4x?xf32>) -> tensor func.return %0 : tensor } // CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor<4x?xf32>) @@ -3078,28 +3078,13 @@ func.func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi3 // CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor // CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() // CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK-PRIMITIVE: linalg.reduce { arith.addi } +// CHECK-PRIMITIVE: linalg.reduce { arith.addi {overflowFlags = #arith.overflow} } // CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) // CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) // CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr} // ----- -// CHECK-LABEL: @reduce_add_unranked -// CHECK-PRIMITIVE-LABEL: @reduce_add_unranked -func.func @reduce_add_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor<*xi32> { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4 : tensor): - %1 = mhlo.add %arg3, %arg4 : tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>, someattr} : (tensor<*xi32>, tensor) -> tensor<*xi32> - func.return %0 : tensor<*xi32> -} -// CHECK: mhlo.reduce -// CHECK-PRIMITIVE: mhlo.reduce - -// ----- - func.func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi32> { %0 = "mhlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg3: tensor, %arg4 : tensor): @@ -5011,49 +4996,6 @@ func.func @gather_non_static(%operand : tensor, %start_indices : tensor // ----- -func.func @gather_unranked(%operand : tensor<*xi32>, %start_indices : tensor) -> tensor { - %res = "mhlo.gather"(%operand, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [0, 1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = dense<[3, 4]> : tensor<2xi64> - } : (tensor<*xi32>, tensor) -> tensor - func.return %res : tensor -} - -// CHECK-LABEL: func @gather_unranked( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[RES_DIM2:.+]] = tensor.dim %[[START_INDICES]], %[[C0]] -// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[RES_DIM2]]) -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX2]], %[[C0]]] : tensor -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C3]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK-DAG: %[[OPERAND_CASTED:.+]] = tensor.cast %[[OPERAND]] : tensor<*xi32> to tensor -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND_CASTED]][%[[IN0]], %[[IDX1]]] : tensor -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: %[[CAST:.+]] = tensor.cast %[[RES]] -// CHECK: return %[[CAST]] - -// ----- - func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> { %0 = "mhlo.torch_index_select"(%arg0, %arg1) { @@ -5090,7 +5032,7 @@ func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, func.func @rng_uniform_1d(%min: tensor, %max: tensor) -> tensor<10xf32> { %shape = arith.constant dense<[10]> : tensor<1xi32> - %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi32>) -> tensor<10xf32> + %0 = "mhlo.rng"(%min, %max, %shape) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<1xi32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } // CHECK-LABEL: func @rng_uniform_1d @@ -5115,7 +5057,7 @@ func.func @rng_uniform_1d(%min: tensor, %max: tensor) -> tensor<10xf32 func.func @rng_uniform_2d(%min: tensor, %max: tensor) -> tensor<3x3xf32> { %shape = arith.constant dense<[3, 3]> : tensor<2xi32> - %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi32>) -> tensor<3x3xf32> + %0 = "mhlo.rng"(%min, %max, %shape) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<2xi32>) -> tensor<3x3xf32> func.return %0 : tensor<3x3xf32> } // CHECK-LABEL: func @rng_uniform_2d @@ -5145,7 +5087,7 @@ func.func @rng_uniform_2d(%min: tensor, %max: tensor) -> tensor<3x3xf3 func.func @rng_uniform_3d(%min: tensor, %max: tensor) -> tensor<2x2x2xf32> { %shape = arith.constant dense<[2, 2, 2]> : tensor<3xi32> - %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi32>) -> tensor<2x2x2xf32> + %0 = "mhlo.rng"(%min, %max, %shape) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<3xi32>) -> tensor<2x2x2xf32> func.return %0 : tensor<2x2x2xf32> } // CHECK-LABEL: func @rng_uniform_3d @@ -5179,7 +5121,7 @@ func.func @rng_uniform_3d(%min: tensor, %max: tensor) -> tensor<2x2x2x func.func @rng_uniform_dynamic_1d(%min: tensor, %max: tensor, %shape: tensor<1xi32>) -> tensor { - %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi32>) -> tensor + %0 = "mhlo.rng"(%min, %max, %shape) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<1xi32>) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @rng_uniform_dynamic_1d @@ -6128,23 +6070,3 @@ func.func @clamp_complex(%min: tensor<8xcomplex>, %result = mhlo.clamp %min, %operand, %max : tensor<8xcomplex> func.return %result : tensor<8xcomplex> } - -// ----- -// CHECK: #[[$ST_3D:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }> -// CHECK: #[[$ST_4D:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed) }> -// CHECK-LABEL: func @reshape_sparse_encoding -// CHECK-PRIMITIVE-LABEL: func @reshape_sparse_encoding - -#ST_3D = #sparse_tensor.encoding<{ - map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) -}> - -#ST_4D = #sparse_tensor.encoding<{ - map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed) -}> - -func.func @reshape_sparse_encoding(%arg0: tensor<1x49x16xf32, #ST_3D>) -> tensor<1x784x1x1xf32, #ST_4D> { - %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32, #ST_3D>) -> tensor<1x784x1x1xf32, #ST_4D> - func.return %0 : tensor<1x784x1x1xf32, #ST_4D> -} -// CHECK: tensor.reshape %{{.*}} : (tensor<1x49x16xf32, #[[$ST_3D]]>, tensor<4xi64>) -> tensor<1x784x1x1xf32, #[[$ST_4D]]> diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 0d06994182e37..31cf7728f48d2 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -272,20 +272,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #stablehlo rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #stablehlo rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -390,15 +390,15 @@ func.func @op_all_reduce(%arg0: tensor) -> tensor { // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, // CHECK-SAME: use_global_device_ids // CHECK-SAME: } : (tensor) -> tensor - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %1 = "mhlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) { + %0 = "mhlo.all_reduce"(%arg0) <{ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, channel_handle = #mhlo.channel_handle, use_global_device_ids - } : (tensor) -> tensor + }> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "mhlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }) : (tensor) -> tensor func.return %0 : tensor } @@ -490,18 +490,18 @@ func.func @op_bitcast_convert(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_broadcast_in_dim" func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { // CHECK: "stablehlo.broadcast_in_dim"(%arg0) { - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{ broadcast_dimensions = dense<1> : tensor<1xi64> - } : (tensor<16xf32>) -> tensor<16x16xf32> + }> : (tensor<16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } // CHECK-LABEL: "op_broadcast" func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { // CHECK: "stablehlo.broadcast"(%arg0) { - // CHECK-SAME: broadcast_sizes = dense<16> : tensor<1xi64> + // CHECK-SAME: broadcast_sizes = array // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> %0 = "mhlo.broadcast"(%arg0) { broadcast_sizes = dense<16> : tensor<1xi64> @@ -559,6 +559,19 @@ func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_collective_broadcast" +func.func @op_collective_broadcast(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> { + // CHECK: "stablehlo.collective_broadcast"(%arg0) { + // CHECK-SAME: channel_handle = #stablehlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // CHECK-SAME: } : (tensor<1x2xi64>) -> tensor<1x2xi64> + %0 = "mhlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> +} + // CHECK-LABEL: "op_collective_permute" func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { // CHECK: "stablehlo.collective_permute"(%arg0) { @@ -592,6 +605,22 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } +// CHECK-LABEL: "op_composite" +func.func @op_composite(%arg0 : tensor) -> tensor { + // CHECK: "stablehlo.composite"(%arg0) {composite_attributes = {n = 2 : i64}, decomposition = @add_n.impl, name = "mhlo.add_n"} : (tensor) -> tensor + %0 = mhlo.composite "mhlo.add_n" %arg0 { + composite_attributes = { n = 2 : i64 }, + decomposition = @add_n.impl + } : (tensor) -> tensor + func.return %0 : tensor +} + +func.func @add_n.impl(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2> : tensor + %1 = mhlo.add %arg0, %0 : tensor + func.return %1 : tensor +} + // CHECK-LABEL: "op_compute_reshape_shape" func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { // CHECK: "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> @@ -634,12 +663,12 @@ func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16 // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, - // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: lhs_dilation = array, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, // CHECK-SAME: precision_config = [#stablehlo, #stablehlo], - // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, - // CHECK-SAME: window_reversal = dense : tensor<2xi1>, - // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> + // CHECK-SAME: rhs_dilation = array, + // CHECK-SAME: window_reversal = array, + // CHECK-SAME: window_strides = array // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> %0 = "mhlo.convolution"(%arg0, %arg1) { window_strides = dense<1> : tensor<2xi64>, @@ -769,15 +798,15 @@ func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x // CHECK-LABEL: "op_dynamic_broadcast_in_dim" func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { // CHECK: "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64>, - // CHECK-SAME: known_expanding_dimensions = dense<> : tensor<0xi64>, - // CHECK-SAME: known_nonexpanding_dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array, + // CHECK-SAME: known_expanding_dimensions = array, + // CHECK-SAME: known_nonexpanding_dimensions = array // CHECK-SAME: } : (tensor, tensor<2xindex>) -> tensor - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64>, known_expanding_dimensions = dense<[]> : tensor<0xi64>, known_nonexpanding_dimensions = dense<0> : tensor<1xi64> - } : (tensor, tensor<2xindex>) -> tensor + }> : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } @@ -787,12 +816,12 @@ func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x1 // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, - // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: lhs_dilation = array, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, // CHECK-SAME: precision_config = [#stablehlo, #stablehlo], - // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, - // CHECK-SAME: window_reversal = dense : tensor<2xi1>, - // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> + // CHECK-SAME: rhs_dilation = array, + // CHECK-SAME: window_reversal = array, + // CHECK-SAME: window_strides = array // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> %0 = "mhlo.dynamic_conv"(%arg0, %arg1, %arg2) { window_strides = dense<1> : tensor<2xi64>, @@ -850,16 +879,16 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor - %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_dynamic_slice" func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { // CHECK: "stablehlo.dynamic_slice"(%arg0, %arg1) { - // CHECK-SAME: slice_sizes = dense<4> : tensor<1xi64> + // CHECK-SAME: slice_sizes = array // CHECK-SAME: } : (tensor<16xf32>, tensor) -> tensor<4xf32> %0 = "mhlo.dynamic_slice"(%arg0, %arg1) { slice_sizes = dense<4> : tensor<1xi64> @@ -902,7 +931,7 @@ func.func @op_exponential(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_fft" func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { // CHECK: "stablehlo.fft"(%arg0) { - // CHECK-SAME: fft_length = dense<16> : tensor<1xi64>, + // CHECK-SAME: fft_length = array, // CHECK-SAME: fft_type = #stablehlo // CHECK-SAME: } : (tensor<16xcomplex>) -> tensor<16xcomplex> %0 = "mhlo.fft"(%arg0) { @@ -931,7 +960,7 @@ func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> te // CHECK-SAME: index_vector_dim = 2 // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, - // CHECK-SAME: slice_sizes = dense<1> : tensor<3xi64> + // CHECK-SAME: slice_sizes = array // CHECK-SAME: } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< @@ -1049,7 +1078,7 @@ func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { // CHECK-NEXT: %[[VAL1:.*]] = "stablehlo.abs"(%[[ARG1]]) : (tensor) -> tensor // CHECK-NEXT: "stablehlo.return"(%[[VAL1]]) : (tensor) -> () // CHECK-NEXT: }) { - // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: dimensions = array // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> %0 = "mhlo.map"(%arg0) ({ ^bb0(%arg1: tensor): @@ -1124,9 +1153,9 @@ func.func @op_outfeed(%arg0: tensor, %arg1: !mhlo.token) -> !mhlo.token { // CHECK-LABEL: "op_pad" func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { // CHECK: "stablehlo.pad"(%arg0, %arg1) { - // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>, - // CHECK-SAME: edge_padding_low = dense<4> : tensor<1xi64>, - // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> + // CHECK-SAME: edge_padding_high = array, + // CHECK-SAME: edge_padding_low = array, + // CHECK-SAME: interior_padding = array // CHECK-SAME: } : (tensor<8xf32>, tensor) -> tensor<16xf32> %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_high = dense<4> : tensor<1xi64>, @@ -1239,11 +1268,11 @@ func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> // CHECK-NEXT: %[[VAL1:.*]] = "stablehlo.maximum"(%[[ARG2]], %[[ARG3]]) : (tensor, tensor) -> tensor // CHECK-NEXT: "stablehlo.return"(%[[VAL1]]) : (tensor) -> () // CHECK-NEXT: }) { - // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>, + // CHECK-SAME: base_dilations = array, // CHECK-SAME{LITERAL}: padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, - // CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, - // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64> + // CHECK-SAME: window_dilations = array, + // CHECK-SAME: window_dimensions = array, + // CHECK-SAME: window_strides = array // CHECK-SAME: } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor): @@ -1294,7 +1323,7 @@ func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_reverse" func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { // CHECK: "stablehlo.reverse"(%arg0) { - // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: dimensions = array // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> %0 = "mhlo.reverse"(%arg0) { dimensions = dense<0> : tensor<1xi64> @@ -1314,13 +1343,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK-SAME: rng_distribution = #stablehlo - // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor + // CHECK-SAME: } : (tensor, tensor, tensor<0xindex>) -> tensor %0 = "mhlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #mhlo.rng_distribution - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1390,8 +1419,8 @@ func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<1 // CHECK-NEXT: "stablehlo.return"(%[[VAL12]]) : (tensor) -> () // CHECK-NEXT: }) { // CHECK-SAME: padding = dense<0> : tensor<4x2xi64>, - // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - // CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME: window_dimensions = array, + // CHECK-SAME: window_strides = array // CHECK-SAME: } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> %0 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): @@ -1478,9 +1507,9 @@ func.func @op_sine(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_slice" func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { // CHECK: "stablehlo.slice"(%arg0) { - // CHECK-SAME: limit_indices = dense<4> : tensor<1xi64>, - // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: limit_indices = array, + // CHECK-SAME: start_indices = array, + // CHECK-SAME: strides = array // CHECK-SAME: } : (tensor<16xf32>) -> tensor<4xf32> %0 = "mhlo.slice"(%arg0) { start_indices = dense<0> : tensor<1xi64>, @@ -1543,7 +1572,16 @@ func.func @op_tanh(%arg0: tensor) -> tensor { func.return %0 : tensor } -// TopKOp aka mhlo.topk is unsupported at the moment (see negative test below). +// CHECK-LABEL: "op_topk" +func.func @op_topk(%arg0: tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) { + // CHECK: "stablehlo.custom_call"(%arg0) { + // CHECK-SAME: call_target_name = "mhlo.topk" + // CHECK-SAME{LITERAL}: mhlo.attributes = {k = 8 : i64, largest = true} + // CHECK-SAME{LITERAL}: mhlo.version = 1 : i64 + // CHECK-SAME: } : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<5x10xf32> -> (tensor<5x8xf32>, tensor<5x8xi32>) + func.return %0#0, %0#1 : tensor<5x8xf32>, tensor<5x8xi32> +} // CHECK-LABEL: "op_torch_index_select" func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { @@ -1572,7 +1610,7 @@ func.func @op_trace(%arg0: tensor) { // CHECK-LABEL: "op_transpose" func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { // CHECK: "stablehlo.transpose"(%arg0) { - // CHECK-SAME: permutation = dense<[1, 0]> : tensor<2xi64> + // CHECK-SAME: permutation = array // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<8x16xf32> %0 = "mhlo.transpose"(%arg0) { permutation = dense<[1, 0]> : tensor<2xi64> @@ -1820,18 +1858,11 @@ func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "type_dynamism_unranked" -func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - // CHECK-LABEL: "type_quantization" -func.func @type_quantization(%arg0: tensor>, %arg1: tensor) -> tensor { - // CHECK: "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor - %0 = "mhlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor - func.return %0 : tensor +func.func @type_quantization(%arg0: tensor>) -> tensor> { + // CHECK: "stablehlo.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> + %0 = "mhlo.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> } // ----- @@ -2057,14 +2088,6 @@ func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> ten // ----- -func.func @op_topk(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{failed to legalize operation 'mhlo.topk' that was explicitly marked illegal}} - %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - func.func @op_xla_rng_get_and_update_state() -> tensor<2xui64> { // expected-error@+1 {{failed to legalize operation 'mhlo.xla.rng_get_and_update_state' that was explicitly marked illegal}} %0 = "mhlo.xla.rng_get_and_update_state"() { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir index 19230564c8010..058bbe8b068fa 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-torch-index-select-to-gather.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @index_select_to_gather_convert_index_type func.func @index_select_to_gather_convert_index_type(%arg0 : tensor<5x1x5xi64>, %arg1 : tensor<2xi64>) -> tensor<2x1x5xi64> { // CHECK: [[ARG1:%.+]] = mhlo.convert %arg1 : (tensor<2xi64>) -> tensor<2xui32> - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG1]]) { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG1]]) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [1, 2], // CHECK-SAME: collapsed_slice_dims = [0], @@ -12,11 +12,11 @@ func.func @index_select_to_gather_convert_index_type(%arg0 : tensor<5x1x5xi64>, // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> - // CHECK-SAME: } : (tensor<5x1x5xi64>, tensor<2xui32>) -> tensor<2x1x5xi64> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<5x1x5xi64>, tensor<2xui32>) -> tensor<2x1x5xi64> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x1x5xi64>, tensor<2xi64>) -> tensor<2x1x5xi64> + }> : (tensor<5x1x5xi64>, tensor<2xi64>) -> tensor<2x1x5xi64> // CHECK: return [[RES]] : tensor<2x1x5xi64> func.return %0 : tensor<2x1x5xi64> } @@ -25,7 +25,7 @@ func.func @index_select_to_gather_convert_index_type(%arg0 : tensor<5x1x5xi64>, // CHECK-LABEL: @index_select_to_gather_multi_offset_dims func.func @index_select_to_gather_multi_offset_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [1, 2], // CHECK-SAME: collapsed_slice_dims = [0], @@ -34,11 +34,11 @@ func.func @index_select_to_gather_multi_offset_dims(%arg0 : tensor<5x1x5xi32>, % // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> - // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + }> : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> // CHECK: return [[RES]] : tensor<2x1x5xi32> func.return %0 : tensor<2x1x5xi32> } @@ -47,7 +47,7 @@ func.func @index_select_to_gather_multi_offset_dims(%arg0 : tensor<5x1x5xi32>, % // CHECK-LABEL: @index_select_to_gather_larger_output func.func @index_select_to_gather_larger_output(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> { - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [3], // CHECK-SAME: collapsed_slice_dims = [0], @@ -56,11 +56,11 @@ func.func @index_select_to_gather_larger_output(%arg0 : tensor<5x4xf32>, %arg1 : // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> - // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + }> : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> // CHECK: return [[RES]] : tensor<1x3x1x4xf32> func.return %0 : tensor<1x3x1x4xf32> } @@ -69,7 +69,7 @@ func.func @index_select_to_gather_larger_output(%arg0 : tensor<5x4xf32>, %arg1 : // CHECK-LABEL: @index_select_to_gather_regular_map func.func @index_select_to_gather_regular_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<2x4xi32> { - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [1], // CHECK-SAME: collapsed_slice_dims = [0], @@ -78,11 +78,11 @@ func.func @index_select_to_gather_regular_map(%arg0: tensor<3x4xi32>, %arg1: ten // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> - // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> + }> : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<2x4xi32> // CHECK: return [[RES]] : tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -91,7 +91,7 @@ func.func @index_select_to_gather_regular_map(%arg0: tensor<3x4xi32>, %arg1: ten // CHECK-LABEL: @index_select_to_gather_reverse_map func.func @index_select_to_gather_reverse_map(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi32>) -> tensor<3x2xi32> { - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) { + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, %arg1) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [0], // CHECK-SAME: collapsed_slice_dims = [1], @@ -100,11 +100,11 @@ func.func @index_select_to_gather_reverse_map(%arg0: tensor<3x4xi32>, %arg1: ten // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[3, 1]> : tensor<2xi64> - // CHECK-SAME: } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 1 : i64, batch_dims = 0 : i64 - } : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> + }> : (tensor<3x4xi32>, tensor<2xi32>) -> tensor<3x2xi32> // CHECK: return [[RES]] : tensor<3x2xi32> func.return %0 : tensor<3x2xi32> } @@ -113,10 +113,10 @@ func.func @index_select_to_gather_reverse_map(%arg0: tensor<3x4xi32>, %arg1: ten // CHECK-LABEL: @index_select_to_gather_batch_dim_greater_than_1 func.func @index_select_to_gather_batch_dim_greater_than_1(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x5xi32> { - // CHECK: [[ARG0:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x1xi32> + // CHECK: [[ARG0:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2x1xi32> // CHECK: [[ARG1:%.+]] = mhlo.reshape %arg1 : (tensor<2xi32>) -> tensor<2x1xi32> - // CHECK: [[ARG2:%.+]] = "mhlo.concatenate"([[ARG0]], [[ARG1]]) {dimension = 1 : i64} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> - // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG2]]) { + // CHECK: [[ARG2:%.+]] = "mhlo.concatenate"([[ARG0]], [[ARG1]]) <{dimension = 1 : i64}> : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + // CHECK: [[RES:%.+]] = "mhlo.gather"(%arg0, [[ARG2]]) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [1], // CHECK-SAME: collapsed_slice_dims = [0, 1], @@ -125,33 +125,22 @@ func.func @index_select_to_gather_batch_dim_greater_than_1(%arg0 : tensor<5x1x5x // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> - // CHECK-SAME: } : (tensor<5x1x5xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: }> : (tensor<5x1x5xi32>, tensor<2x2xi32>) -> tensor<2x5xi32> + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 1 : i64, batch_dims = 1 : i64 - } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x5xi32> + }> : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x5xi32> func.return %0 : tensor<2x5xi32> } // ----- -func.func @index_select_to_gather_unranked(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> { - // CHECK: mhlo.torch_index_select - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { - dim = 0 : i64, - batch_dims = 0 : i64 - } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> -} - -// ----- - func.func @index_select_to_gather_non_static_operand(%arg0 : tensor<5x1x?xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { // CHECK: mhlo.torch_index_select - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x1x?xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + }> : (tensor<5x1x?xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> func.return %0 : tensor<2x1x5xi32> } @@ -159,10 +148,10 @@ func.func @index_select_to_gather_non_static_operand(%arg0 : tensor<5x1x?xi32>, func.func @index_select_to_gather_non_static_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor) -> tensor<2x1x5xi32> { // CHECK: mhlo.torch_index_select - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x1x5xi32>, tensor) -> tensor<2x1x5xi32> + }> : (tensor<5x1x5xi32>, tensor) -> tensor<2x1x5xi32> func.return %0 : tensor<2x1x5xi32> } @@ -170,10 +159,10 @@ func.func @index_select_to_gather_non_static_index(%arg0 : tensor<5x1x5xi32>, %a func.func @index_select_to_gather_dim_less_than_batch_dims(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> { // CHECK: mhlo.torch_index_select - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 1 : i64 - } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> + }> : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> func.return %0 : tensor<2x1x5xi32> } @@ -181,9 +170,9 @@ func.func @index_select_to_gather_dim_less_than_batch_dims(%arg0 : tensor<5x1x5x func.func @index_select_to_gather_non_integer_index(%arg0 : tensor<5x1x5xi32>, %arg1 : tensor<2xf32>) -> tensor<2x1x5xi32> { // CHECK: mhlo.torch_index_select - %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + %0 = "mhlo.torch_index_select"(%arg0, %arg1) <{ dim = 0 : i64, batch_dims = 0 : i64 - } : (tensor<5x1x5xi32>, tensor<2xf32>) -> tensor<2x1x5xi32> + }> : (tensor<5x1x5xi32>, tensor<2xf32>) -> tensor<2x1x5xi32> func.return %0 : tensor<2x1x5xi32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-control-flow.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-control-flow.mlir index 2f91ca955248c..49cb1c4f82575 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-control-flow.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-control-flow.mlir @@ -64,14 +64,14 @@ func.func @while_multi_operands(%arg0: tensor<3xi32>) -> tuple, tens // CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<1> : tensor // CHECK: %[[VAL_14:.*]] = mhlo.add %[[VAL_10]], %[[VAL_13]] : tensor // CHECK: %[[VAL_15:.*]] = mhlo.convert %[[VAL_10]] : tensor - // CHECK: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3xi32> + // CHECK: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<3xi32> // CHECK: %[[VAL_17:.*]] = mhlo.add %[[VAL_11]], %[[VAL_16]] : tensor<3xi32> // CHECK: scf.yield %[[VAL_14]], %[[VAL_17]] : tensor, tensor<3xi32> %4 = mhlo.constant dense : tensor %5 = mhlo.constant dense<1> : tensor %6 = mhlo.add %arg1, %5 : tensor %7 = mhlo.convert %arg1 : (tensor) -> tensor - %8 = "mhlo.broadcast_in_dim"(%7) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3xi32> + %8 = "mhlo.broadcast_in_dim"(%7) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<3xi32> %9 = mhlo.add %arg2, %8 : tensor<3xi32> "mhlo.return"(%6, %9) : (tensor, tensor<3xi32>) -> () }) : (tensor, tensor<3xi32>) -> (tensor, tensor<3xi32>) diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-hlo-shape-computations.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-hlo-shape-computations.mlir index fe6ab0e1860da..c98d4fe442d05 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-hlo-shape-computations.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-hlo-shape-computations.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @get_dimension_size func.func @get_dimension_size(%arg0: tensor) -> (tensor) { - %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor) -> tensor + %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor func.return %1 : tensor } @@ -16,7 +16,7 @@ func.func @get_dimension_size(%arg0: tensor) -> (tensor) { // CHECK-LABEL: func @reshape_dimension_size func.func @reshape_dimension_size(%arg0: tensor) -> (tensor<1xi32>) { - %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor %1 = "mhlo.reshape"(%0) : (tensor) -> tensor<1xi32> func.return %1 : tensor<1xi32> } @@ -32,7 +32,7 @@ func.func @reshape_dimension_size(%arg0: tensor) -> (tensor<1xi32>) { // CHECK-LABEL: func @multiply_dimension_size func.func @multiply_dimension_size(%arg0: tensor) -> (tensor) { %0 = mhlo.constant dense<2> : tensor - %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor) -> tensor + %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor %2 = "mhlo.multiply"(%0, %1) : (tensor, tensor) -> tensor func.return %2 : tensor } @@ -50,10 +50,10 @@ func.func @multiply_dimension_size(%arg0: tensor) -> (tensor) { // CHECK-LABEL: func @concat_dimension_size func.func @concat_dimension_size(%arg0: tensor) -> (tensor<2xi32>) { - %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor %1 = "mhlo.reshape"(%0) : (tensor) -> tensor<1xi32> %2 = mhlo.constant dense<2> : tensor<1xi32> - %3 = "mhlo.concatenate"(%1, %2) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %3 = "mhlo.concatenate"(%1, %2) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> func.return %3 : tensor<2xi32> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-to-std.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-to-std.mlir index 59900aeae16ac..cd1ea4fdb9a55 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/legalize-to-std.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/legalize-to-std.mlir @@ -120,7 +120,7 @@ func.func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { func.func @iota.const.1() -> tensor<4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> // CHECK-NEXT: return %[[CST]] : tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -128,7 +128,7 @@ func.func @iota.const.1() -> tensor<4xi32> { // CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { func.func @iota.const.2() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -136,7 +136,7 @@ func.func @iota.const.2() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { func.func @iota.const.3() -> tensor<2x4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<2x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -144,7 +144,7 @@ func.func @iota.const.3() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { func.func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> func.return %0 : tensor<2x3x4xi32> } @@ -152,7 +152,7 @@ func.func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { func.func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> func.return %0 : tensor<2x3x4xi32> } @@ -160,7 +160,7 @@ func.func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { func.func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> - %0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 2 : i64}> : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> func.return %0 : tensor<2x3x4xi32> } @@ -168,7 +168,7 @@ func.func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.f32 func.func @iota.const.f32() -> tensor<4xf32> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xf32> // CHECK-NEXT: return %[[CST]] : tensor<4xf32> func.return %0 : tensor<4xf32> } @@ -176,7 +176,7 @@ func.func @iota.const.f32() -> tensor<4xf32> { // CHECK-LABEL: func @iota.const.f64 func.func @iota.const.f64() -> tensor<4xf64> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xf64> // CHECK-NEXT: return %[[CST]] : tensor<4xf64> func.return %0 : tensor<4xf64> } @@ -184,7 +184,7 @@ func.func @iota.const.f64() -> tensor<4xf64> { // CHECK-LABEL: func @iota.const.bf16 func.func @iota.const.bf16() -> tensor<4xbf16> { // CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16> - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xbf16> // CHECK-NEXT: return %[[CST]] : tensor<4xbf16> func.return %0 : tensor<4xbf16> } @@ -194,7 +194,7 @@ func.func @iota.const.complex.f32() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> // CHECK-NEXT: [[IMAG:%.*]] = arith.constant dense<0.000000e+00> : tensor<4xf32> // CHECK-NEXT: [[COMPLEX:%.*]] = mhlo.complex [[REAL]], [[IMAG]] - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> func.return %0 : tensor<4xcomplex> } @@ -204,7 +204,7 @@ func.func @iota.const.complex.f64() -> tensor<4xcomplex> { // CHECK-NEXT: [[REAL:%.*]] = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> // CHECK-NEXT: [[IMAG:%.*]] = arith.constant dense<0.000000e+00> : tensor<4xf64> // CHECK-NEXT: [[COMPLEX:%.*]] = mhlo.complex [[REAL]], [[IMAG]] - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xcomplex> // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> func.return %0 : tensor<4xcomplex> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/lower-complex.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/lower-complex.mlir index 13963174283cd..8c2e615baed7f 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/lower-complex.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s +// RUN: mlir-hlo-opt %s --mhlo-test-lower-complex | FileCheck %s // CHECK-LABEL: @add func.func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { @@ -15,21 +15,6 @@ func.func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf3 func.return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @add_unranked -func.func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 - %4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = mhlo.real %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = mhlo.imag %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - func.return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - // CHECK-LABEL: @sub func.func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) @@ -45,21 +30,6 @@ func.func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf3 func.return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @sub_unranked -func.func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 - %4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = mhlo.real %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = mhlo.imag %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - func.return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - // CHECK-LABEL: @mul func.func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) @@ -79,25 +49,6 @@ func.func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf3 func.return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @mul_unranked -func.func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3 - // CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3 - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]] - %4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - %5 = mhlo.real %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = mhlo.imag %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> - func.return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - // CHECK-LABEL: @div func.func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) @@ -137,44 +88,6 @@ func.func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf3 // ----- -// CHECK-LABEL: @div_unranked -func.func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.negate %arg3 - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2 - // CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]] - // CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3 - // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2 - // CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]] - // CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] - // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] - - %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) - - %5 = mhlo.real %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - %6 = mhlo.imag %4 : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - func.return %5, %6 : tensor<*xf32>, tensor<*xf32> -} - // CHECK-LABEL: @abs func.func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) @@ -223,22 +136,6 @@ func.func @exp_complex(%arg0 : tensor<2xcomplex>) -> (tensor<2xcomplex func.return %0 : tensor<2xcomplex> } -// CHECK-LABEL: @exp_unranked -func.func @exp_unranked(%arg0 : tensor<*xcomplex>) -> (tensor<*xcomplex>) { - // CHECK-DAG: [[REAL:%.+]] = mhlo.real %arg0 - // CHECK-DAG: [[IMAG:%.+]] = mhlo.imag %arg0 - // CHECK-DAG: [[EXP:%.+]] = mhlo.exponential [[REAL]] - // CHECK-DAG: [[COS:%.+]] = mhlo.cosine [[IMAG]] - // CHECK-DAG: [[SIN:%.+]] = mhlo.sine [[IMAG]] - // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] - // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] - // CHECK-DAG: [[OUT:%.+]] = mhlo.complex [[OUTR]], [[OUTI]] - %0 = mhlo.exponential %arg0 : (tensor<*xcomplex>) -> (tensor<*xcomplex>) - - // CHECK: [[OUT]] - func.return %0 : tensor<*xcomplex> -} - // CHECK-LABEL: @compare_eq // CHECK: ([[LHS:%.+]]: tensor<2xcomplex>, [[RHS:%.+]]: tensor<2xcomplex>) func.func @compare_eq(%lhs : tensor<2xcomplex>, %rhs: tensor<2xcomplex>) -> (tensor<2xi1>) { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir index 60b810ff24432..f0d93cabe5b20 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir @@ -3,15 +3,15 @@ // CHECK-LABEL: @testDebatch1 func.func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { // CHECK-DAG: [[R0:%.+]] = mhlo.reshape %arg0 : (tensor<1x1x2xf32>) -> tensor<1x2xf32> - // CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> // CHECK: [[R2:%.+]] = mhlo.reshape [[R1]] : (tensor<1x3xf32>) -> tensor<1x1x3xf32> - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0] >, precision_config = [#mhlo, #mhlo] - } : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> + }> : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> func.return %0 : tensor<1x1x3xf32> } @@ -20,19 +20,19 @@ func.func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> ten // CHECK-LABEL: @testDebatch2 func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { - // CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> - // CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> + // CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<3x2xf32> + // CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) <{permutation = dense<[2, 0, 1]> : tensor<3xi64>}> : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> // CHECK-DAG: [[R2:%.+]] = mhlo.reshape [[R1]] : (tensor<2x1x1xf32>) -> tensor<2x1xf32> - // CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = [#mhlo, #mhlo]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> + // CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> // CHECK: [[R4:%.+]] = mhlo.reshape [[R3]] : (tensor<3x1xf32>) -> tensor<3x1x1xf32> - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [2] >, precision_config = [#mhlo, #mhlo] - } : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> + }> : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> func.return %0 : tensor<3x1x1xf32> } @@ -41,7 +41,7 @@ func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> ten // CHECK-LABEL: @testBatchPassthrough func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<2x3x1xf32> { // CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1) - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], @@ -49,7 +49,7 @@ func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf3 rhs_contracting_dimensions = [2] >, precision_config = [#mhlo, #mhlo] - } : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<2x3x1xf32> + }> : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<2x3x1xf32> func.return %0 : tensor<2x3x1xf32> } @@ -59,13 +59,13 @@ func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf3 func.func @testVec(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> tensor { // CHECK-NEXT: [[R:%.+]] = "mhlo.dot"(%arg0, %arg1) // CHECK-NEXT: return [[R]] - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0] >, precision_config = [#mhlo, #mhlo] - } : (tensor<32xf32>, tensor<32xf32>) -> tensor + }> : (tensor<32xf32>, tensor<32xf32>) -> tensor func.return %0 : tensor } @@ -75,13 +75,13 @@ func.func @testVec(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> tensor func.func @testMatVec(%arg0: tensor<20x32xf32>, %arg1: tensor<32xf32>) -> tensor<20xf32> { // CHECK-NEXT: [[R:%.+]] = "mhlo.dot"(%arg0, %arg1) // CHECK-NEXT: return [[R]] - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0] >, precision_config = [#mhlo, #mhlo] - } : (tensor<20x32xf32>, tensor<32xf32>) -> tensor<20xf32> + }> : (tensor<20x32xf32>, tensor<32xf32>) -> tensor<20xf32> func.return %0 : tensor<20xf32> } @@ -89,25 +89,25 @@ func.func @testMatVec(%arg0: tensor<20x32xf32>, %arg1: tensor<32xf32>) -> tensor // CHECK-LABEL: @testMatVec func.func @testMatVec(%arg0: tensor<32x20xf32>, %arg1: tensor<32xf32>) -> tensor<20xf32> { - // CHECK-NEXT: [[T:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> + // CHECK-NEXT: [[T:%.+]] = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> // CHECK-NEXT: [[R1:%.+]] = mhlo.reshape %arg1 : (tensor<32xf32>) -> tensor<32x1xf32> // CHECK-NEXT: [[M:%.+]] = "mhlo.dot"([[T]], [[R1]]) // CHECK-NEXT: [[R:%.+]] = mhlo.reshape [[M]] // CHECK-NEXT: return [[R]] - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0] >, precision_config = [#mhlo, #mhlo] - } : (tensor<32x20xf32>, tensor<32xf32>) -> tensor<20xf32> + }> : (tensor<32x20xf32>, tensor<32xf32>) -> tensor<20xf32> func.return %0 : tensor<20xf32> } // ----- func.func @dot_general_to_dot_dynamic(%arg0: tensor<128x4x?x32xf32>, %arg1: tensor<8x?x128x4xf32>) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [], lhs_contracting_dimensions = [0, 1], @@ -115,31 +115,31 @@ func.func @dot_general_to_dot_dynamic(%arg0: tensor<128x4x?x32xf32>, %arg1: tens rhs_contracting_dimensions = [2, 3], >, precision_config = [#mhlo, #mhlo] - } : (tensor<128x4x?x32xf32>, tensor<8x?x128x4xf32>) -> tensor + }> : (tensor<128x4x?x32xf32>, tensor<8x?x128x4xf32>) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @dot_general_to_dot_dynamic // CHECK-DAG: %[[C32:.+]] = mhlo.constant dense<32> : tensor<1xi32> // CHECK-DAG: %[[C512:.+]] = mhlo.constant dense<512> : tensor<1xi32> // CHECK-DAG: %[[C8:.+]] = mhlo.constant dense<8> : tensor<1xi32> -// CHECK-DAG: %[[TRANS0:.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[2, 3, 0, 1]> : tensor<4xi64>} -// CHECK-DAG: %[[DIM0:.+]] = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} +// CHECK-DAG: %[[TRANS0:.+]] = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 3, 0, 1]> : tensor<4xi64>}> +// CHECK-DAG: %[[DIM0:.+]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 2 : i64}> // CHECK-DAG: %[[RESHAPE0:.+]] = mhlo.reshape %[[DIM0]] : (tensor) -> tensor<1xi32> // CHECK-DAG: %[[MUL0:.+]] = mhlo.multiply %[[RESHAPE0]], %[[C32]] -// CHECK-DAG: %[[CONCAT1:.+]] = "mhlo.concatenate"(%[[MUL0]], %[[C512]]) {dimension = 0 : i64} +// CHECK-DAG: %[[CONCAT1:.+]] = "mhlo.concatenate"(%[[MUL0]], %[[C512]]) <{dimension = 0 : i64}> // CHECK-DAG: %[[DR1:.+]] = mhlo.dynamic_reshape %[[TRANS0]], %[[CONCAT1]] -// CHECK-DAG: %[[TRANS1:.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 3, 0, 1]> : tensor<4xi64>} -// CHECK-DAG: %[[DIM1:.+]] = "mhlo.get_dimension_size"(%arg1) {dimension = 1 : i64} +// CHECK-DAG: %[[TRANS1:.+]] = "mhlo.transpose"(%arg1) <{permutation = dense<[2, 3, 0, 1]> : tensor<4xi64>}> +// CHECK-DAG: %[[DIM1:.+]] = "mhlo.get_dimension_size"(%arg1) <{dimension = 1 : i64}> // CHECK-DAG: %[[RESHAPE1:.+]] = mhlo.reshape %[[DIM1]] : (tensor) -> tensor<1xi32> // CHECK-DAG: %[[MUL1:.+]] = mhlo.multiply %[[RESHAPE1]], %[[C8]] -// CHECK-DAG: %[[CONCAT2:.+]] = "mhlo.concatenate"(%[[C512]], %[[MUL1]]) {dimension = 0 : i64} +// CHECK-DAG: %[[CONCAT2:.+]] = "mhlo.concatenate"(%[[C512]], %[[MUL1]]) <{dimension = 0 : i64}> // CHECK-DAG: %[[DR2:.+]] = mhlo.dynamic_reshape %[[TRANS1]], %[[CONCAT2]] // CHECK-DAG: %[[DOT:.+]] = "mhlo.dot"(%[[DR1:.+]], %[[DR2:.+]]) -// CHECK-DAG: %[[DIM2:.+]] = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} +// CHECK-DAG: %[[DIM2:.+]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 2 : i64}> // CHECK-DAG: %[[RESHAPE2:.+]] = mhlo.reshape %[[DIM2]] : (tensor) -> tensor<1xi32> -// CHECK-DAG: %[[DIM3:.+]] = "mhlo.get_dimension_size"(%arg1) {dimension = 1 : i64} +// CHECK-DAG: %[[DIM3:.+]] = "mhlo.get_dimension_size"(%arg1) <{dimension = 1 : i64}> // CHECK-DAG: %[[RESHAPE3:.+]] = mhlo.reshape %[[DIM3]] : (tensor) -> tensor<1xi32> -// CHECK-DAG: %[[CONCAT3:.+]] = "mhlo.concatenate"(%[[RESHAPE2]], %[[C32]], %[[C8]], %[[RESHAPE3]]) {dimension = 0 : i64} +// CHECK-DAG: %[[CONCAT3:.+]] = "mhlo.concatenate"(%[[RESHAPE2]], %[[C32]], %[[C8]], %[[RESHAPE3]]) <{dimension = 0 : i64}> // CHECK-DAG: %[[DR3:.+]] = mhlo.dynamic_reshape %[[DOT]], %[[CONCAT3]] // CHECK: return %[[DR3]] @@ -147,11 +147,11 @@ func.func @dot_general_to_dot_dynamic(%arg0: tensor<128x4x?x32xf32>, %arg1: tens // ----- func.func @dot_no_rhs_batch(%arg0: tensor<1x512x768xf32>, %arg1: tensor<768x12x64xf32>) -> tensor<1x512x12x64xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) { + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0]> - } : (tensor<1x512x768xf32>, tensor<768x12x64xf32>) -> tensor<1x512x12x64xf32> + }> : (tensor<1x512x768xf32>, tensor<768x12x64xf32>) -> tensor<1x512x12x64xf32> func.return %0 : tensor<1x512x12x64xf32> } @@ -165,14 +165,14 @@ func.func @dot_no_rhs_batch(%arg0: tensor<1x512x768xf32>, %arg1: tensor<768x12x6 // CHECK-LABEL: @testPrefElem func.func @testPrefElem(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf64> { - // CHECK: "mhlo.dot"({{%.*}}, {{%.*}}) {precision_config = [#mhlo, #mhlo]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf64> - %0 = "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: "mhlo.dot"({{%.*}}, {{%.*}}) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf64> + %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot< lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [0] >, precision_config = [#mhlo, #mhlo] - } : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf64> + }> : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf64> func.return %0 : tensor<1x1x3xf64> } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/materialize-broadcasts.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/materialize-broadcasts.mlir index 18829da584b4f..5c09f9193d50e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/materialize-broadcasts.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/materialize-broadcasts.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: @clampBroadcast // CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) func.func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { - // CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> - // CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) <{broadcast_sizes = dense<4> : tensor<1xi64>}> : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) <{broadcast_sizes = dense<4> : tensor<1xi64>}> : (tensor) -> tensor<4xf32> // CHECK: mhlo.clamp %[[MIN_BC]], %[[VAL]], %[[MAX_BC]] : tensor<4xf32> %0 = "mhlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> func.return %0 : tensor<4xf32> diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir new file mode 100644 index 0000000000000..5988d8efb68f4 --- /dev/null +++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir @@ -0,0 +1,2346 @@ +// RUN: mlir-hlo-opt --mhlo-quant-legalize-to-int -split-input-file %s -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @uniform_quantize_and_dequantize +func.func @uniform_quantize_and_dequantize(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL8]] : tensor + %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_convert_dequantize +func.func @uniform_quantize_convert_dequantize(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK: %[[VAL5:.*]] = mhlo.bitcast_convert %[[VAL4]] : (tensor) -> tensor + %1 = mhlo.bitcast_convert %0 : (tensor>) -> tensor + + // CHECK: %[[VAL6:.*]] = mhlo.bitcast_convert %[[VAL5]] : (tensor) -> tensor + %2 = mhlo.bitcast_convert %1 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_subtract %[[VAL7]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.convert %[[VAL8]] : (tensor) -> tensor + // CHECK: %[[VAL10:.*]] = chlo.broadcast_multiply %[[VAL9]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL10]] : tensor + %3 = mhlo.uniform_dequantize %2 : (tensor>) -> tensor + return %3 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_int4 +func.func @uniform_quantize_and_dequantize_int4(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-8.000000e+00> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<7.000000e+00> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL1]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.round_nearest_even %[[VAL2]] : tensor + // CHECK: %[[VAL4:.*]] = mhlo.convert %[[VAL3]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL5]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor + // CHECK: return %[[VAL8]] : tensor + %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_type_exensions +func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor>) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #mhlo.type_extensions> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor>, tensor) -> tensor> + %1 = mhlo.uniform_dequantize %0 : (tensor, #mhlo.type_extensions>) -> tensor> + return +} + +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> +// CHECK-LABEL: func @uniform_quantize_and_dequantize_sparse_tensor_encoding +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor, #SV> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor, tensor) -> tensor + %1 = mhlo.uniform_dequantize %0 : (tensor, #SV>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: func @quantize_per_channel +func.func @quantize_per_channel(%arg0: tensor<26x26x3x2xf32> + ) -> tensor<26x26x3x2x!quant.uniform> { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-1.000000e+01, 2.000000e+00]> + // CHECK-DAG: %[[QMIN:.*]] = mhlo.constant dense<-2.14748365E+9> : tensor + // CHECK-DAG: %[[QMAX:.*]] = mhlo.constant dense<2.14748365E+9> : tensor + // CHECK: %[[DIVIDE:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + // CHECK: %[[ADD:.*]] = chlo.broadcast_add %[[DIVIDE]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + // CHECK: %[[CLAMP:.*]] = mhlo.clamp %[[QMIN]], %[[ADD]], %[[QMAX]] + // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_even %[[CLAMP]] + // CHECK: %[[RESULT:.*]] = mhlo.convert %[[ROUND]] + // CHECK-SAME: (tensor<26x26x3x2xf32>) -> tensor<26x26x3x2xi32> + %0 = mhlo.uniform_quantize %arg0 : (tensor<26x26x3x2xf32> + ) -> tensor<26x26x3x2x!quant.uniform> + return %0 : tensor<26x26x3x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dequantize_per_channel +func.func @dequantize_per_channel( + %arg0: tensor<26x26x3x2x!quant.uniform> + ) -> tensor<26x26x3x2xf32> { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-10, 2]> : tensor<2xi32> + // CHECK: %[[SUBTRACT:.*]] = chlo.broadcast_subtract + // CHECK-SAME: %[[INPUT:.*]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xi32>, tensor<2xi32>) -> tensor<26x26x3x2xi32> + // CHECK: %[[FLOAT:.*]] = mhlo.convert %[[SUBTRACT]] + // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[FLOAT]], %[[SCALES]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> + %0 = mhlo.uniform_dequantize %arg0 : ( + tensor<26x26x3x2x!quant.uniform> + ) -> tensor<26x26x3x2xf32> + return %0 : tensor<26x26x3x2xf32> +} + +// ----- + +// CHECK-LABEL: func @add +func.func @add( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0: tensor> +} + +// ----- + +// CHECK-LABEL: func @add_i32 +func.func @add_i32( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1:.*]], %[[VAL3:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK-NEXT: return + %2 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %2 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_int4 +func.func @add_int4( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[VAL5:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL1]], %[[VAL3]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_subtract %[[VAL4]], %[[VAL5]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = mhlo.clamp %[[VAL7:.*]], %[[VAL6]], %[[VAL8:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @add_different_lhs_type +func.func @add_different_lhs_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RHS_32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor) -> tensor + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL7:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL10:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL9:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10:.*]] : (tensor) -> tensor + %2 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %2 : tensor> +} + +// ----- + +// CHECK-LABEL: @add_different_rhs_type +func.func @add_different_rhs_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL7:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL10:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL9:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL11:.*]] = mhlo.convert %[[VAL10:.*]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// CHECK-LABEL: @add_different_res_type +func.func @add_different_res_type( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[VAL11:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[VAL12:.*]] = chlo.broadcast_subtract %[[VAL11:.*]], %[[RES_ZPS:.*]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL13:.*]] = mhlo.clamp %[[QUANT_MIN:.*]], %[[VAL12:.*]], %[[QUANT_MAX:.*]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL14:.*]] = mhlo.convert %[[VAL13:.*]] : (tensor) -> tensor + %0 = mhlo.add %arg0, %arg1: ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_per_channel +func.func @add_per_channel( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<[3, 2]> : tensor<2xi32> + // CHECK: %[[BCAST_SUB:.*]] = chlo.broadcast_subtract %[[ADD]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor + // CHECK: return %[[BCAST_SUB]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func @add_per_channel_no_zp +func.func @add_per_channel_no_zp( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: return %[[ADD]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_i8( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_different_quant_types( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_per_tensor_mix( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+1 {{'mhlo.add' op requires compatible types for all operands and results}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize +func.func @requantize( + %arg0: tensor> + ) -> tensor> { + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> : tensor + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : ( + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize_merged_zp_zero +func.func @requantize_merged_zp_zero( + %arg0: tensor> + ) -> tensor> { + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL3:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL2]], %[[QUANT_MAX]] : (tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL4:.*]] = mhlo.round_nearest_even %[[VAL3]] : tensor + // CHECK: %[[VAL5:.*]] = mhlo.convert %[[VAL4]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel +func.func @requantize_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel_to_per_tensor +func.func @requantize_per_channel_to_per_tensor( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_tensor_to_per_channel +func.func @requantize_per_tensor_to_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = array + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +func.func @requantize_per_channel_change_axis( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // expected-error@+2 {{Cannot requantize while changing quantization_axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.uniform_quantize' that was explicitly marked illegal}} + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot +func.func @dot(%arg0: tensor<2x2x!quant.uniform>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi32> + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor<2x2x!quant.uniform>, + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_int4 +func.func @dot_int4( + %arg0: tensor<2x2x!quant.uniform>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x2xi4>, tensor<2x2xi4>) -> tensor<2x2xi32> + %0 = "mhlo.dot" (%arg0, %arg1): ( + tensor<2x2x!quant.uniform>, + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic +func.func @dot_dynamic( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[DOT:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: "mhlo.get_dimension_size"(%[[DOT]]) + // CHECK-SAME: <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: "mhlo.get_dimension_size"(%[[DOT]]) + // CHECK-SAME: <{dimension = 1 : i64}> : (tensor) -> tensor + // CHECK: %[[DYN_DIMS:.*]] = "mhlo.concatenate" + // CHECK-SAME: <{dimension = 0 : i64}> + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: %[[DYN_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: %[[DYN_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_int4 +func.func @dot_dynamic_int4( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: mhlo.dot_general + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_contracting_dim +func.func @dot_dynamic_contracting_dim( + %arg0: tensor<2x?x!quant.uniform>, + %arg1: tensor> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor<2x?xi8>, tensor) -> tensor<2x2xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor<2x?xi32>, tensor) -> tensor<2xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor, tensor) -> tensor<2xi32> + + // CHECK: %[[DYNAMIC_DIM_INIT:.*]] = mhlo.constant dense<1> : tensor + // CHECK: %[[DYNAMIC_DIM:.*]] = "mhlo.get_dimension_size" + // CHECK-SAME: <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %[[DYNAMIC_DIM_TOTAL:.*]] = mhlo.multiply + // CHECK-SAME: %[[DYNAMIC_DIM_INIT]], %[[DYNAMIC_DIM]] + // CHECK: %[[DIMS:.*]] = mhlo.constant dense<9> : tensor + // CHECK: %[[DIMS_1:.*]] = mhlo.multiply %[[DIMS]], %[[DYNAMIC_DIM_TOTAL]] + // CHECK: chlo.broadcast_subtract %[[ZP_OFFSET:.*]], %[[DIMS:.*]] + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor<2x?x!quant.uniform>, + tensor> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_result_dim +func.func @dot_dynamic_result_dim( + %arg0: tensor>, + %arg1: tensor<2x?x!quant.uniform> + ) -> tensor> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor<2x?xi8>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x?xi32>, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor<2x?x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_dynamic_batch_dim +func.func @dot_dynamic_batch_dim( + %arg0: tensor>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor<2x2xi8>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x2xi32>, tensor) -> tensor<2xi32> + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor + + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor<2x2x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_general +func.func @dot_general( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x6xi8>) + // CHECK-SAME: -> tensor<2x5x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [2] + // CHECK-SAME: (tensor<2x5x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5xi32>, tensor) -> tensor<2x5xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-SAME: (tensor<2x5xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x5x8xi32>, tensor) -> tensor<2x5x8xi32> + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_combined_scale_1 +func.func @dot_general_combined_scale_1( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x6xi8>) + // CHECK-SAME: -> tensor<2x5x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [2] + // CHECK-SAME: (tensor<2x5x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5xi32>, tensor) -> tensor<2x5xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-SAME: (tensor<2x5xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x5x8xi32>, tensor) -> tensor<2x5x8xi32> + + // Combine dot result with zero point offset and output final result. + // Do not multiply by combined scale since it is 1.0 and thus no-op. + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_3:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[DOT_RES]], %[[ZP_TOTAL_3]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_multiple_batching_dims +func.func @dot_general_multiple_batching_dims( + %arg0: tensor<2x5x3x7x6x!quant.uniform>, + %arg1: tensor<6x2x7x8x3x!quant.uniform> + ) -> tensor<2x3x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0, 2] + // CHECK-SAME: rhs_batching_dimensions = [1, 4] + // CHECK-SAME: lhs_contracting_dimensions = [4, 3] + // CHECK-SAME: rhs_contracting_dimensions = [0, 2]>} + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<2x5x3x7x6xi8>) + // CHECK-SAME: -> tensor<2x5x3x7x6xi32> + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [4, 3] + // CHECK-SAME: (tensor<2x5x3x7x6xi32>, tensor) + // CHECK-SAME: -> tensor<2x5x3xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x5x3xi32>, tensor) -> tensor<2x5x3xi32> + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 2, 1]> + // CHECK-SAME: (tensor<2x5x3xi32>) -> tensor<2x3x5x8xi32> + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x2x7x8x3xi8>) + // CHECK-SAME: -> tensor<6x2x7x8x3xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0, 2] + // CHECK-SAME: (tensor<6x2x7x8x3xi32>, tensor) + // CHECK-SAME: -> tensor<2x8x3xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2x8x3xi32>, tensor) -> tensor<2x8x3xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 3, 1]> + // CHECK-SAME: (tensor<2x8x3xi32>) -> tensor<2x3x5x8xi32> + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<630> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor<2x3x5x8xi32>, tensor) -> tensor<2x3x5x8xi32> + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x3x5x8xf32>) -> tensor<2x3x5x8xi32> + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xf32> + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor<2x3x5x8xf32>) -> tensor<2x3x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor<2x3x5x8xi32>) -> tensor<2x3x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [1, 4], + lhs_contracting_dimensions = [4, 3], + rhs_contracting_dimensions = [0, 2] + >} : ( + tensor<2x5x3x7x6x!quant.uniform>, + tensor<6x2x7x8x3x!quant.uniform> + ) -> tensor<2x3x5x8x!quant.uniform> + return %0 : tensor<2x3x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_rhs_zero_zp +func.func @dot_general_rhs_zero_zp( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Zero point offset contribution from LHS tensor * RHS ZP is 0 and skipped. + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x8x2xi8>) + // CHECK-SAME: -> tensor<6x8x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<6x8x2xi32>, tensor) + // CHECK-SAME: -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<8x2xi32>, tensor) -> tensor<8x2xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[RHS_ZP_CONTRIB]]) + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<2x5x8xi32> + + // Zero point offset contribution from LHS ZP * RHS ZP is 0 and skipped. + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<2x5x8xi32>) -> tensor<2x5x8xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_zero_zp +func.func @dot_general_zero_zp( + %arg0: tensor<2x5x6x!quant.uniform>, + %arg1: tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0] + // CHECK-SAME: rhs_batching_dimensions = [2] + // CHECK-SAME: lhs_contracting_dimensions = [2] + // CHECK-SAME: rhs_contracting_dimensions = [0] + + // Both LHS/RHS have zero zp. No zp contribution. + + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<1.500000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] : + // CHECK-SAME: (tensor<2x5x8xi32>) -> tensor<2x5x8xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<2x5x8xf32>) -> tensor<2x5x8xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[RES_ZP]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6x!quant.uniform>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return %0 : tensor<2x5x8x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_general_multiple_dynamic_dims +func.func @dot_general_multiple_dynamic_dims( + %arg0: tensor>, + %arg1: tensor<6x?x?x8x3x!quant.uniform> + ) -> tensor> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_batching_dimensions = [0, 2] + // CHECK-SAME: rhs_batching_dimensions = [1, 4] + // CHECK-SAME: lhs_contracting_dimensions = [4, 3] + // CHECK-SAME: rhs_contracting_dimensions = [0, 2]>} + + // Zero point offset contribution from LHS tensor * RHS ZP. + + // CHECK: %[[LHS_I32:.*]] = mhlo.convert %[[LHS:.*]] : (tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[LHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[LHS_REDUCE:.*]] = mhlo.reduce(%[[LHS_I32]] init: %[[LHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [4, 3] + // CHECK-SAME: (tensor, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[LHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[LHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor, tensor) -> tensor + + // Calculate output dynamic dims. + // CHECK: %[[DIM_1_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]]) + // CHECK-SAME: {dimension = 0 : i64} + // CHECK: %[[DIM_1_2:.*]] = mhlo.convert %[[DIM_1_1]] : (tensor) -> tensor + // CHECK: %[[DIM_1:.*]] = mhlo.reshape %[[DIM_1_2]] : (tensor) -> tensor<1xi64> + // CHECK: %[[DIM_2:.*]] = mhlo.constant dense<3> : tensor<1xi64> + // CHECK: %[[DIM_3_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]]) + // CHECK-SAME: {dimension = 2 : i64} + // CHECK: %[[DIM_3_2:.*]] = mhlo.convert %[[DIM_3_1]] : (tensor) -> tensor + // CHECK: %[[DIM_3:.*]] = mhlo.reshape %[[DIM_3_2]] : (tensor) -> tensor<1xi64> + // CHECK: %[[DIM_4:.*]] = mhlo.constant dense<8> : tensor<1xi64> + // CHECK: %[[OUTPUT_DIMS:.*]] = "mhlo.concatenate" + // CHECK-SAME: %[[DIM_1]], %[[DIM_2]], %[[DIM_3]], %[[DIM_4]] + + // CHECK: %[[LHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: (%[[LHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 2, 1]> + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<6x?x?x8x3xi8>) + // CHECK-SAME: -> tensor<6x?x?x8x3xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0, 2] + // CHECK-SAME: (tensor<6x?x?x8x3xi32>, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor, tensor) -> tensor + + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: (%[[RHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 3, 1]> + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.add %[[LHS_ZP_BCAST]], %[[RHS_ZP_BCAST]] + + // Zero point offset contribution from LHS ZP * RHS ZP. + + // CHECK: %[[ZPS_INIT:.*]] = mhlo.constant dense<1> : tensor + // CHECK: %[[DYN_DIM:.*]] = "mhlo.get_dimension_size"(%[[RHS]]) + // CHECK: %[[ZPS_1:.*]] = mhlo.multiply %[[ZPS_INIT]], %[[DYN_DIM]] + // CHECK: %[[STATIC_DIM:.*]] = mhlo.constant dense<90> : tensor + // CHECK: %[[ZPS:.*]] = mhlo.multiply %[[STATIC_DIM]], %[[ZPS_1]] + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZP_TOTAL_1]], %[[ZPS]] + // CHECK-SAME: (tensor, tensor) -> tensor + + // Combine dot result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[DOT_RES]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_3:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_5:.*]] = mhlo.convert %[[ZP_TOTAL_4]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<7> : tensor + // CHECK: %[[ZP_TOTAL_6:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_5]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_6]] + + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [1, 4], + lhs_contracting_dimensions = [4, 3], + rhs_contracting_dimensions = [0, 2] + >} : ( + tensor>, + tensor<6x?x?x8x3x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @dot_general_per_channel +func.func @dot_general_per_channel( + %arg0: tensor>, + %arg1: tensor<2x2x!quant.uniform:f32:1, {3.0,4.0}>> + ) -> tensor> { + // CHECK: %[[DOT_RES:.*]] = "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0]>} + + // Zero point offset contribution from RHS tensor * LHS ZP. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %arg1 : (tensor<2x2xi8>) + // CHECK-SAME: -> tensor<2x2xi32> + // CHECK: %[[RHS_REDUCE_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] init: %[[RHS_REDUCE_INIT]]) + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x2xi32>, tensor) + // CHECK-SAME: -> tensor<2xi32> + // CHECK: %[[RHS_ZP:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RHS_REDUCE]], %[[RHS_ZP]] : + // CHECK-SAME: (tensor<2xi32>, tensor) -> tensor<2xi32> + + // Calculate output dynamic dims. + // CHECK: %[[DIM_1_1:.*]] = "mhlo.get_dimension_size"(%[[DOT_RES]]) + // CHECK-SAME: {dimension = 0 : i64} + // CHECK: %[[DIM_1_2:.*]] = mhlo.convert %[[DIM_1_1]] : (tensor) -> tensor + // CHECK: %[[DIM_1:.*]] = mhlo.reshape %[[DIM_1_2]] : (tensor) -> tensor<1xi64> + // CHECK: %[[DIM_2:.*]] = mhlo.constant dense<2> : tensor<1xi64> + // CHECK: %[[OUTPUT_DIMS:.*]] = "mhlo.concatenate" + // CHECK-SAME: %[[DIM_1]], %[[DIM_2]] + + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: (%[[RHS_ZP_CONTRIB]], %[[OUTPUT_DIMS]]) + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor + // CHECK: %[[ZPS_INIT:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_subtract %[[ZPS_INIT]], %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_add %[[DOT_RES]], %[[ZP_TOTAL_2]] + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot} : ( + tensor>, + tensor<2x2x!quant.uniform:f32:1, {3.0,4.0}>> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @conv2d_dynamic +func.func @conv2d_dynamic( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK-NOT: mhlo.pad + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS:.*]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 2], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [2, 2]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor, tensor) -> tensor + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor, tensor) + // CHECK-SAME: -> tensor + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor, tensor<4xi64>) -> tensor + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor) -> tensor + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor) -> tensor + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 2], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [2, 2] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor>, tensor>) + -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @conv2d_static +func.func @conv2d_static( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK-NOT: mhlo.pad + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS:.*]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x26x26x128xi32> + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xi32> + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor<3x3x1x128xi32>, tensor) + // CHECK-SAME: -> tensor<128xi32> + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor<128xi32>, tensor) -> tensor<128xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor<128xi32>) -> tensor<128x26x26x128xi32> + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<128x26x26x128xf32>) -> tensor<128x26x26x128xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<128x26x26x128xf32>) -> tensor<128x26x26x128xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_default_attr +func.func @conv2d_default_attr( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK: mhlo.convolution + // CHECK-NOT: quant.uniform + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_static_padding +func.func @conv2d_static_padding( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x29x33x128x!quant.uniform> { + // Explicitly pad LHS with ZP. + + // CHECK: %[[LHS_ZP_i8:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[LHS_PAD:.*]] = "mhlo.pad"(%[[LHS:.*]], %[[LHS_ZP_i8]]) + // CHECK-SAME: edge_padding_high = dense<[0, 2, 4, 0]> + // CHECK-SAME: edge_padding_low = dense<[0, 1, 3, 0]> + // CHECK-SAME: interior_padding = dense<0> + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor) -> tensor<128x31x35x1xi8> + + // Convolution with padding removed. + + // CHECK: %[[CONV:.*]] = mhlo.convolution + // CHECK-SAME: (%[[LHS_PAD]], %[[RHS:.{1,4}]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x31x35x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x29x33x128xi32> + + // Zero point offset contribution from LHS ZP * RHS. + + // CHECK: %[[RHS_I32:.*]] = mhlo.convert %[[RHS]] + // CHECK-SAME: (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xi32> + // CHECK: %[[RHS_REDUCE:.*]] = mhlo.reduce(%[[RHS_I32]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK-SAME: (tensor<3x3x1x128xi32>, tensor) + // CHECK-SAME: -> tensor<128xi32> + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[RHS_ZP_CONTRIB:.*]] = chlo.broadcast_multiply %[[RHS_REDUCE]], %[[LHS_ZP]] + // CHECK-SAME: (tensor<128xi32>, tensor) -> tensor<128xi32> + // CHECK: %[[RHS_ZP_BCAST:.*]] = "mhlo.broadcast_in_dim" + // CHECK-SAME: %[[RHS_ZP_CONTRIB]] + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: (tensor<128xi32>) -> tensor<128x29x33x128xi32> + + // Combine conv result with zero point offset and output final result. + + // CHECK: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[RES_FP:.*]] = mhlo.convert %[[CONV]] + // CHECK-SAME: (tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xf32> + // CHECK: %[[RES_FP_1:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[RES_FP:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES_FP_1]] + // CHECK-SAME: (tensor<128x29x33x128xf32>) -> tensor<128x29x33x128xi32> + + // CHECK: %[[ZP_TOTAL_1:.*]] = mhlo.convert %[[RHS_ZP_BCAST]] + // CHECK-SAME: (tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xf32> + // CHECK: %[[ZP_TOTAL_2:.*]] = chlo.broadcast_multiply + // CHECK-SAME: %[[ZP_TOTAL_1:.*]], %[[COMBINED_SCALE]] + // CHECK: %[[ZP_TOTAL_3:.*]] = mhlo.convert %[[ZP_TOTAL_2]] + // CHECK-SAME: (tensor<128x29x33x128xf32>) -> tensor<128x29x33x128xi32> + + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<5> : tensor + // CHECK: %[[ZP_TOTAL_4:.*]] = chlo.broadcast_subtract %[[RES_ZP]], %[[ZP_TOTAL_3]] + // CHECK-SAME: (tensor, tensor<128x29x33x128xi32>) -> tensor<128x29x33x128xi32> + // CHECK: chlo.broadcast_add %[[RES_INT]], %[[ZP_TOTAL_4]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[1, 2], [3, 4]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x29x33x128x!quant.uniform> + return %0 : tensor<128x29x33x128x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv2d_per_channel +func.func @conv2d_per_channel( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: %[[CONV:.*]] = mhlo.convolution(%arg0, %arg1) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + // CHECK-SAME: window = {stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1] + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x1xi8>, tensor<3x3x1x2xi8>) -> tensor<128x26x26x2xi32> + + // CHECK: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xi32> + // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%[[RHS]] + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2] + // CHECK: %[[LHS_ZP:.*]] = mhlo.constant dense<4> : tensor + // CHECK: %[[ZP_OFFSET:.*]] = chlo.broadcast_multiply %[[REDUCE]], %[[LHS_ZP]] + // CHECK: %[[ZP_OFFSET_BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[ZP_OFFSET]]) + // CHECK: %[[RES_ZP:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[ZP_OFFSET_TOTAL:.*]] = chlo.broadcast_subtract %[[RES_ZP:.*]], %[[ZP_OFFSET_BCAST]] + // CHECK: chlo.broadcast_add %[[CONV]], %[[ZP_OFFSET_TOTAL]] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @conv3d_static +func.func @conv3d_static( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x26x128x!quant.uniform>{ + // CHECK-NOT: mhlo.pad + + // CHECK: mhlo.convolution + // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] + // CHECK-SAME: window = {stride = [1, 1, 1], pad = {{\[}}[0, 0], [0, 0], [0, 0]], + // CHECK-SAME: lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: (tensor<128x28x28x28x1xi8>, tensor<3x3x3x1x128xi8>) -> tensor<128x26x26x26x128xi32> + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0, 1, 2, 3] + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x26x128x!quant.uniform> +} + +// ----- + +func.func @conv3d_rhs_zp_not_zero( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{RHS/result UQ type must have zero zp}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x26x128x!quant.uniform> + return +} + +// ----- + +func.func @conv3d_rhs_invalid_dilate( + %arg0: tensor<128x28x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{lhs_dilation must be 1}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [2, 2, 2], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x28x1x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x53x53x53x128x!quant.uniform> + return +} + +// ----- + +func.func @conv3d_non_nhwc( + %arg0: tensor<128x1x28x28x28x!quant.uniform>, + %arg1: tensor<3x3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Convolution data format must be NHWC}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1, 2]x[0, 1, 2, i, o]->[b, f, 0, 1, 2], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x1x28x28x28x!quant.uniform>, tensor<3x3x3x1x128x!quant.uniform>) + -> tensor<128x128x26x26x26x!quant.uniform> + return +} + +// ----- + +func.func @conv2d_non_nhwc( + %arg0: tensor<128x1x28x28x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Convolution data format must be NHWC}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x1x28x28x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x128x26x26x!quant.uniform> + return +} + +// ----- + +func.func @conv2d_per_channel_rhs_zp_not_zero( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{RHS/result UQ type must have zero zp.}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_res_zp_not_zero( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{RHS/result UQ type must have zero zp.}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_rhs_only( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_res_only( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_unsupported_channel( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Conv quantized axis must be out channel axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +func.func @conv2d_per_channel_rhs_result_scale_ratio_different( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // expected-error@+2 {{Per-channel quantizated Conv must have same RHS/Result scale ratio for each channel}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1x!quant.uniform>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dot_hybrid +func.func @dot_hybrid( + %arg0: tensor, + %arg1: tensor>) -> tensor { + // CHECK: %[[VAL1:.*]] = mhlo.optimization_barrier %[[VAL0:.*]] : tensor + // CHECK: %[[VAL2:.*]] = mhlo.convert %[[VAL1:.*]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_subtract %[[VAL2]], %[[VAL3:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_multiply %[[VAL4]], %[[VAL5:.*]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = "mhlo.dot"(%arg0, %[[VAL6]]) : (tensor, tensor) -> tensor + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor, tensor>) -> tensor + return %1: tensor +} + +// ----- + +// CHECK-LABEL: func @dot_general_hybrid_per_channel +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8> +func.func @dot_general_hybrid_per_channel( + %arg0: tensor<3x2xf32>, + %arg1: tensor<2x2x!quant.uniform:f32:1, {3.000000e+00, 4.000000e+00}>> + ) -> tensor<3x2xf32> { + // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8> + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> + // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-NOT: chlo.broadcast_subtract + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32> + // CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]]) + // CHECK-SAME: (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32> + // CHECK: return %[[DOT]] + + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot} : ( + tensor<3x2xf32>, + tensor<2x2x!quant.uniform:f32:1, {3.000000e+00, 4.000000e+00}>> + ) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: func @dot_general_hybrid_per_channel_asymmetric +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8> +func.func @dot_general_hybrid_per_channel_asymmetric( + %arg0: tensor<3x2xf32>, + %arg1: tensor<2x2x!quant.uniform:f32:1, {3.000000e+00:10, 4.000000e+00:20}>> + ) -> tensor<3x2xf32> { + // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8> + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32> + // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32> + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32> + // CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]]) + // CHECK-SAME: (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32> + // CHECK: return %[[DOT]] + + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot} : ( + tensor<3x2xf32>, + tensor<2x2x!quant.uniform:f32:1, {3.000000e+00:10, 4.000000e+00:20}>> + ) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +func.func @dot_hybrid_result_type_not_float( + %arg0: tensor, + %arg1: tensor>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor, tensor> + ) -> tensor> + return +} + +// ----- + +func.func @dot_hybrid_lhs_type_not_float( + %arg0: tensor>, + %arg1: tensor) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %1 = "mhlo.dot" (%arg0, %arg1): ( + tensor>, tensor + ) -> tensor> + return +} + +// ----- + +// CHECK-LABEL: func @conv2d_static_hybrid +func.func @conv2d_static_hybrid( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128xf32> { + // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %arg1 : tensor<3x3x1x128xi8> + // CHECK-DAG: %[[ZP:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[SCALE:.*]] = mhlo.constant dense<3.000000e+00> : tensor + // CHECK: %[[RHS:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xf32> + // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[RHS]], %[[ZP]] + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALE]] + // CHECK: mhlo.convolution(%arg0, %[[MUL]]) + // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME: stride = [1, 1], pad = {{\[}}[0, 0], [0, 0]] + // CHECK-SAME: lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + // CHECK-SAME: : (tensor<128x28x28x1xf32>, tensor<3x3x1x128xf32>) -> tensor<128x26x26x128xf32> + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1xf32>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128xf32> + return %0 : tensor<128x26x26x128xf32> +} + +// ----- + +// CHECK-LABEL: func @conv2d_hybrid_per_channel +// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8> +func.func @conv2d_hybrid_per_channel( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2xf32> { + // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8> + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32> + // CHECK-NOT: chlo.broadcast_subtract + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32> + // CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32> + // CHECK: return %[[CONV]] + + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1xf32>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2xf32> + return %0 : tensor<128x26x26x2xf32> +} + +// ----- + +// CHECK-LABEL: func @conv2d_hybrid_per_channel_asymmetric +// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8> +func.func @conv2d_hybrid_per_channel_asymmetric( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x2x!quant.uniform> + ) -> tensor<128x26x26x2xf32> { + // CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8> + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32> + // CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32> + // CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32> + // CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32> + // CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32> + // CHECK: return %[[CONV]] + + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<128x28x28x1xf32>, + tensor<3x3x1x2x!quant.uniform>) + -> tensor<128x26x26x2xf32> + return %0 : tensor<128x26x26x2xf32> +} + +// ----- + +func.func @conv2d_hybrid_result_not_float( + %arg0: tensor<128x28x28x1xf32>, + %arg1: tensor<3x3x1x128x!quant.uniform>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1xf32>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return +} + +// ----- + +func.func @dot_general_hybrid_result_not_float( + %arg0: tensor<2x5x6xf32>, + %arg1: tensor<6x8x2x!quant.uniform>) { + // expected-error@+2 {{Invalid input/output type for Dot/Convolution op}} + // expected-error@+1 {{failed to legalize operation 'mhlo.dot_general' that was explicitly marked illegal}} + %0 = "mhlo.dot_general" (%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [2], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : ( + tensor<2x5x6xf32>, + tensor<6x8x2x!quant.uniform> + ) -> tensor<2x5x8x!quant.uniform> + return +} + +// ----- + +// CHECK-LABEL: func @mhlo_constant_uniform_quantized +func.func @mhlo_constant_uniform_quantized() -> tensor<1x!quant.uniform> { + // CHECK: mhlo.constant dense<9> : tensor<1xi8> + %0 = mhlo.constant() {value = dense<9> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> + return %0 : tensor<1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @mhlo_constant_uniform_quantized_per_channel +func.func @mhlo_constant_uniform_quantized_per_channel() -> () { + // CHECK: mhlo.constant dense<[9, 4]> : tensor<2xi8> + %0 = mhlo.constant() {value = dense<[9, 4]> : tensor<2xi8>} : () + -> tensor<2x!quant.uniform> + return +} + + +// ----- + +// CHECK-LABEL: func @mhlo_constant_int +func.func @mhlo_constant_int() -> tensor { + // CHECK: mhlo.constant dense<-128> : tensor + %0 = mhlo.constant() {value = dense<-128> : tensor} : () -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @broadcast +func.func @broadcast( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<2x3x1x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> + // CHECK-SAME: (tensor<1x2xi8>) -> tensor<2x3x1xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> + } : (tensor<1x2x!quant.uniform>) -> tensor<2x3x1x!quant.uniform> + return %0 : tensor<2x3x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @broadcast_per_channel +func.func @broadcast_per_channel( + %arg0: tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<3> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<128x26x26x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}>: ( + tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast +func.func @dynamic_broadcast( + %arg0: tensor<1x2x!quant.uniform>, + %arg1: tensor<3xi32> + ) -> tensor> { + // CHECK: "mhlo.dynamic_broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME: (tensor<1x2xi8>, tensor<3xi32>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + } : ( + tensor<1x2x!quant.uniform>, tensor<3xi32> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @max +func.func @max( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @max_per_channel +func.func @max_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @min +func.func @min( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @min_per_channel +func.func @min_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @function(%arg0: tensor<1x2xi8>) -> tensor<1x2xi8> +func.func @function( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: return %arg0 : tensor<1x2xi8> + return %arg0 : tensor<1x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @concatenate +func.func @concatenate( + %arg0: tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.concatenate + // CHECK-SAME: (tensor<3x2xi8>, tensor<1x2xi8>) -> tensor<4x2xi8> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{dimension = 0 : i64}> : ( + tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<4x2x!quant.uniform:f32, 5.000000e-03>> +} + +// ----- + +// CHECK-LABEL: func @pad +func.func @pad( + %arg0: tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.pad + // CHECK-SAME: (tensor<2x3xi8>, tensor) -> tensor<5x9xi8> + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<[0, 1]> : tensor<2xi64>, + edge_padding_high = dense<[2, 1]> : tensor<2xi64>, + interior_padding = dense<[1, 2]> : tensor<2xi64> + }: ( + tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<5x9x!quant.uniform:f32, 5.000000e-03>> +} + +// ----- + +// CHECK-LABEL: func @reshape +func.func @reshape( + %arg0: tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> { + // CHECK: mhlo.reshape + // CHECK-SAME: (tensor<1x3xi8>) -> tensor<3x1xi8> + %0 = "mhlo.reshape"(%arg0) : ( + tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> + return %0 : tensor<3x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dynamic_reshape +func.func @dynamic_reshape( + %arg0: tensor>, + %arg1: tensor<2xi32> + ) -> tensor> { + // CHECK: mhlo.dynamic_reshape + // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : ( + tensor>, tensor<2xi32> + ) -> tensor> + return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: func @select +func.func @select( + %arg0: tensor<1x3xi1>, + %arg1: tensor<1x3x!quant.uniform>, + %arg2: tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.select + // CHECK-SAME: tensor<1x3xi8> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : ( + tensor<1x3xi1>, + tensor<1x3x!quant.uniform>, + tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @transpose +func.func @transpose( + %arg0: tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.transpose + // CHECK-SAME: (tensor<3x1xi8>) -> tensor<1x3xi8> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : ( + tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @gather +func.func @gather( + %arg0: tensor<3x4x2x!quant.uniform>, + %arg1: tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> { + // CHECK: mhlo.gather + // CHECK-SAME: (tensor<3x4x2xi8>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi8> + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : ( + tensor<3x4x2x!quant.uniform>, + tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> + return %0 : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @slice +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: mhlo.slice + // CHECK-SAME: (tensor<3x4xi8>) -> tensor<2x2xi8> + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 2]> : tensor<2xi64>, + limit_indices = dense<[3, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @dynamic_slice +func.func @dynamic_slice( + %arg0: tensor>, + %arg1: tensor, + %arg2: tensor + ) -> tensor<1x1x!quant.uniform> { + // CHECK: mhlo.dynamic_slice + // CHECK-SAME: (tensor, tensor, tensor) -> tensor<1x1xi8> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) { + slice_sizes = dense<1> : tensor<2xi64> + } : ( + tensor>, tensor, + tensor + ) -> tensor<1x1x!quant.uniform> + return %0 : tensor<1x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @get_dimension_size +func.func @get_dimension_size( + %arg0: tensor> + ) -> tensor { + // CHECK: mhlo.get_dimension_size + // CHECK-SAME: (tensor) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : ( + tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: reduce_window +func.func @reduce_window( + %arg0: tensor<2x3x10x3x!quant.uniform>, + %arg1: tensor> + ) -> tensor<2x3x10x3x!quant.uniform> { + // CHECK: mhlo.reduce_window + // CHECK: %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor + // CHECK: %[[MAX:.*]] = mhlo.maximum %[[ARG2]], %[[ARG3]] : tensor + // CHECK: mhlo.return %[[MAX]] : tensor + // CHECK: (tensor<2x3x10x3xi8>, tensor) -> tensor<2x3x10x3xi8> + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = mhlo.maximum %arg2, %arg3 : tensor> + mhlo.return %1 : tensor> + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>} : (tensor<2x3x10x3x!quant.uniform>, tensor>) -> tensor<2x3x10x3x!quant.uniform> + return %0 : tensor<2x3x10x3x!quant.uniform> +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir index 653bd3cf913be..2d61f1c0bf16e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir @@ -62,8 +62,8 @@ func.func @collapse_scatter_dims(%dst: tensor<3x3xf32>, // CHECK: %[[IND_:.*]] = tensor.collapse_shape %[[IND]] {{\[\[}}0, 1], [2]] : tensor<2x1x2xi32> into tensor<2x2xi32> // CHECK: %[[UPD_:.*]] = tensor.collapse_shape %[[UPD]] {{\[\[}}0, 1], [2], [3]] : tensor<2x1x1x3xf32> into tensor<2x1x3xf32> -// CHECK: "mhlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) ({ -// CHECK: update_window_dims = [1, 2], +// CHECK: "mhlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) +// CHECK-SAME: update_window_dims = [1, 2], // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1], // CHECK-SAME: index_vector_dim = 1 @@ -91,7 +91,7 @@ func.func @move_index_vector_dim(%dst: tensor<3x3xf32>, // CHECK-SAME: %[[IND:.*]]: tensor<2x1xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<1x3x3xf32> -// CHECK: %[[IND_:.*]] = "mhlo.transpose"(%[[IND]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x1xi32>) -> tensor<1x2xi32> +// CHECK: %[[IND_:.*]] = "mhlo.transpose"(%[[IND]]) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x1xi32>) -> tensor<1x2xi32> // CHECK: "mhlo.scatter"(%[[DST]], %[[IND_]], %[[UPD]]) // CHECK: update_window_dims = [1, 2], // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1], @@ -121,21 +121,21 @@ func.func @transform_updates_and_operands_using_scatter_dims(%dst: tensor<3x4x5x // CHECK-SAME: %[[IND:.*]]: tensor<2x2xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<2x1x1x3xf32>) -> tensor<3x4x5xf32> { -// CHECK: %[[DST_:.*]] = "mhlo.transpose"(%[[DST]]) { +// CHECK: %[[DST_:.*]] = "mhlo.transpose"(%[[DST]]) <{ // CHECK-SAME: permutation = dense<[2, 0, 1]> : tensor<3xi64> -// CHECK-SAME: } : (tensor<3x4x5xf32>) -> tensor<5x3x4xf32> -// CHECK: %[[UPD_:.*]] = "mhlo.transpose"(%[[UPD]]) { +// CHECK-SAME: }> : (tensor<3x4x5xf32>) -> tensor<5x3x4xf32> +// CHECK: %[[UPD_:.*]] = "mhlo.transpose"(%[[UPD]]) <{ // CHECK-SAME: permutation = dense<[0, 3, 1, 2]> : tensor<4xi64> -// CHECK-SAME: } : (tensor<2x1x1x3xf32>) -> tensor<2x3x1x1xf32> +// CHECK-SAME: }> : (tensor<2x1x1x3xf32>) -> tensor<2x3x1x1xf32> // CHECK: %[[NEW_OP:.*]] = "mhlo.scatter"(%[[DST_]], %[[IND]], %[[UPD_]]) -// CHECK: update_window_dims = [1, 2, 3], +// CHECK-SAME: update_window_dims = [1, 2, 3], // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1], // CHECK-SAME: index_vector_dim = 1 -// CHECK-NEXT: "mhlo.transpose"(%[[NEW_OP:.*]]) { +// CHECK: "mhlo.transpose"(%[[NEW_OP:.*]]) <{ // CHECK-SAME: permutation = dense<[1, 2, 0]> : tensor<3xi64> -// CHECK-SAME: } : (tensor<5x3x4xf32>) -> tensor<3x4x5xf32> +// CHECK-SAME: }> : (tensor<5x3x4xf32>) -> tensor<3x4x5xf32> // ----- @@ -161,12 +161,12 @@ func.func @make_scatter_dims_leading_in_updates(%dst: tensor<3xf32>, // CHECK-SAME: %[[IND:.*]]: tensor<1x1xi32>, // CHECK-SAME: %[[UPD:.*]]: tensor<2x1xf32> -// CHECK: %[[UPD_:.*]] = "mhlo.transpose"(%[[UPD]]) { +// CHECK: %[[UPD_:.*]] = "mhlo.transpose"(%[[UPD]]) <{ // CHECK-SAME: permutation = dense<[1, 0]> : tensor<2xi64> -// CHECK-SAME: } : (tensor<2x1xf32>) -> tensor<1x2xf32> +// CHECK-SAME: }> : (tensor<2x1xf32>) -> tensor<1x2xf32> // CHECK: "mhlo.scatter"(%[[DST]], %[[IND]], %[[UPD_]] -// CHECK: update_window_dims = [1], +// CHECK-SAME: update_window_dims = [1], // CHECK-SAME: scatter_dims_to_operand_dims = [0], // CHECK-SAME: index_vector_dim = 1 @@ -197,8 +197,8 @@ func.func @zero_dim_scatter_indices(%dst: tensor<4x4xf32>, // CHECK-SAME: [0, 1]] : tensor<2xi32> into tensor<1x2xi32> // CHECK: %[[UPD_:.*]] = tensor.expand_shape %[[UPD]] [ // CHECK-SAME: [0, 1], [2]] : tensor<3x3xf32> into tensor<1x3x3xf32> -// CHECK: "mhlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) ({ -// CHECK: update_window_dims = [1, 2], +// CHECK: "mhlo.scatter"(%[[DST]], %[[IND_]], %[[UPD_]]) +// CHECK-SAME: update_window_dims = [1, 2], // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] // CHECK-SAME: index_vector_dim = 1 @@ -231,4 +231,4 @@ func.func @multiple_window_and_scatter_dims( // CHECK: %[[UPD0:.*]] = "mhlo.transpose"(%[[UPD]]) {{.*}} -> tensor<6x7x2x4xf32> // CHECK: %[[UPD1:.*]] = tensor.collapse_shape %[[UPD0]] {{.*}} into tensor<42x2x4xf32> // CHECK: %[[UPD2:.*]] = tensor.expand_shape %[[UPD1]] {{.*}} into tensor<42x1x2x1x4x1xf32> -// CHECK: "mhlo.scatter"(%[[DST]], %[[IND0]], %[[UPD2]]) \ No newline at end of file +// CHECK: "mhlo.scatter"(%[[DST]], %[[IND0]], %[[UPD2]]) diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir index 358e33df945d0..a30c82e793778 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir @@ -18,9 +18,9 @@ func.func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex func.func @select(%pred : tensor, %a : tensor, %b : tensor<1x?x3xf32>) -> tensor<1x2x3xindex> { %0 = "mhlo.select"(%pred, %a, %b) - : (tensor, tensor, tensor<1x?x3xf32>) -> tensor<*xf32> + : (tensor, tensor, tensor<1x?x3xf32>) -> tensor // CHECK: types0 = tensor<1x2x3xf32> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<1x2x3xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor<1x2x3xindex> func.return %1 : tensor<1x2x3xindex> } @@ -39,7 +39,7 @@ func.func @compare(%a : tensor<2x2xf32>, %b : tensor<2x2xf32>) -> tensor<2x2xind // CHECK-LABEL: @broadcast func.func @broadcast(%a : tensor<3xi32>) -> tensor<1x2x3xindex> { - %0 = "mhlo.broadcast"(%a) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} + %0 = "mhlo.broadcast"(%a) <{broadcast_sizes = dense<[1, 2]> : tensor<2xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> // CHECK: types0 = tensor<1x2x3xi32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<1x2x3xi32>) -> tensor<1x2x3xindex> @@ -51,7 +51,7 @@ func.func @broadcast(%a : tensor<3xi32>) -> tensor<1x2x3xindex> { func.func @broadcast(%a : tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+2 {{'mhlo.broadcast' op failed to infer returned types}} // expected-error@+1 {{Broadcast with negative dimension size -2}} - %0 = "mhlo.broadcast"(%a) {broadcast_sizes = dense<[1, -2]> : tensor<2xi64>} + %0 = "mhlo.broadcast"(%a) <{broadcast_sizes = dense<[1, -2]> : tensor<2xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -60,7 +60,7 @@ func.func @broadcast(%a : tensor<3xi32>) -> tensor<1x2x3xi32> { // CHECK-LABEL: @dynamic_slice func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xindex> { - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: types0 = tensor<1x4xi32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<1x4xi32>) -> tensor<1x4xindex> func.return %1 : tensor<1x4xindex> @@ -70,11 +70,11 @@ func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tens // CHECK-LABEL: @pad func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xindex> { - %0 = "mhlo.pad"(%arg0, %arg1) { + %0 = "mhlo.pad"(%arg0, %arg1) <{ edge_padding_high = dense<[1, 1, 0]> : tensor<3xi64>, edge_padding_low = dense<[0, 1, 2]> : tensor<3xi64>, interior_padding = dense<[0, 0, 1]> : tensor<3xi64> - } : (tensor<1x2x3xf16>, tensor) -> tensor<2x4x7xf16> + }> : (tensor<1x2x3xf16>, tensor) -> tensor<2x4x7xf16> // CHECK: types0 = tensor<2x4x7xf16> %1 = "mhlo_test.get_return_types"(%0) : (tensor<2x4x7xf16>) -> tensor<2x4x7xindex> func.return %1 : tensor<2x4x7xindex> @@ -83,29 +83,29 @@ func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xind // ----- // CHECK-LABEL: @pad_with_bounds -func.func @pad_with_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor<*xindex> { +func.func @pad_with_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor { %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_low = dense<[2, 2, 0]> : tensor<3xi64>, edge_padding_high = dense<[0, 0, 0]> : tensor<3xi64>, interior_padding = dense<[1, 1, 1]> : tensor<3xi64> - } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor<*xf16> + } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor // CHECK: types0 = tensor<7x?x?xf16, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf16>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- -func.func @pad_with_negative_inferred_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor<*xindex> { +func.func @pad_with_negative_inferred_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor { // expected-error@+2 {{'mhlo.pad' op failed to infer returned types}} // expected-error@+1 {{Padding result in negative bound for dimension 1}} %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_low = dense<[2, -10, 0]> : tensor<3xi64>, edge_padding_high = dense<[0, 0, 0]> : tensor<3xi64>, interior_padding = dense<[1, 1, 1]> : tensor<3xi64> - } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor<*xf16> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf16>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -136,16 +136,16 @@ func.func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xindex> { // ----- // CHECK-LABEL: func @alltoall_bounds -func.func @alltoall_bounds(%data: tensor<16x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { +func.func @alltoall_bounds(%data: tensor<16x?xf32, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.all_to_all"(%data) { split_dimension = 0 : i64, concat_dimension = 1 : i64, split_count = 4 : i64, replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> - } : (tensor<16x?xf32, #mhlo.type_extensions>) -> tensor<*xf32> + } : (tensor<16x?xf32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<4x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -162,7 +162,7 @@ func.func @abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xindex> { // CHECK-LABEL: @concat func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> // CHECK: types0 = tensor<3xi32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<3xi32>) -> tensor<3xindex> func.return %1 : tensor<3xindex> @@ -184,13 +184,13 @@ func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex> // CHECK-LABEL: @concat_bounds_c0 func.func @concat_bounds_c0( %arg0: tensor<5x1xi32, #mhlo.type_extensions>, - %arg1: tensor<5x2xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x2xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x1xi32, #mhlo.type_extensions>, tensor<5x2xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x3xi32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -198,20 +198,20 @@ func.func @concat_bounds_c0( // CHECK-LABEL: @concat_bounds_c1 func.func @concat_bounds_c1( %arg0: tensor<5x2xi32, #mhlo.type_extensions>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x2xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor - %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + %result_swap = "mhlo.concatenate"(%arg1, %arg0) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x2xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32> - %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor - func.return %1 : tensor<*xindex> + func.return %1 : tensor } // ----- @@ -219,20 +219,20 @@ func.func @concat_bounds_c1( // CHECK-LABEL: @concat_bounds_c2 func.func @concat_bounds_c2( %arg0: tensor<5x2xi32, #mhlo.type_extensions>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x2xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor - %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + %result_swap = "mhlo.concatenate"(%arg1, %arg0) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x2xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> - %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor - func.return %1 : tensor<*xindex> + func.return %1 : tensor } // ----- @@ -240,13 +240,13 @@ func.func @concat_bounds_c2( // CHECK-LABEL: @concat_bounds_c3 func.func @concat_bounds_c3( %arg0: tensor<5x?xi32, #mhlo.type_extensions>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -254,20 +254,20 @@ func.func @concat_bounds_c3( // CHECK-LABEL: @concat_bounds_c4 func.func @concat_bounds_c4( %arg0: tensor<5x?xi32, #mhlo.type_extensions>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor - %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + %result_swap = "mhlo.concatenate"(%arg1, %arg0) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32> - %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor - func.return %1 : tensor<*xindex> + func.return %1 : tensor } // ----- @@ -275,46 +275,13 @@ func.func @concat_bounds_c4( // CHECK-LABEL: @concat_bounds_c5 func.func @concat_bounds_c5( %arg0: tensor<5x?xi32, #mhlo.type_extensions>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor { + %result = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 1 : i64 }> : ( tensor<5x?xi32, #mhlo.type_extensions>, tensor<5x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> -} - -// ----- - -// Note: unranked input types can't be ignored, consider these input types: -// c0: (<5x?xf32>, <*xf32>) with concat dim 0 should infer -// c1: (<5x?xf32>, <*xf32>) with concat dim 1 should infer <5x?xf32> -// Instead, they should be replaced with dynamic tensors: tensor - -// CHECK-LABEL: @concat_bounds_unranked_c0 -func.func @concat_bounds_unranked_c0( - %arg0: tensor<*xi32>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : ( - tensor<*xi32>, - tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<5x?xi32> - // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<5x?xi32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> -} - -// ----- - -// CHECK-LABEL: @concat_bounds_unranked_c1 -func.func @concat_bounds_unranked_c1( - %arg0: tensor<*xi32>, - %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( - tensor<*xi32>, - tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<5x?xi32> - // CHECK: types0 = tensor<5x?xi32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<5x?xi32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -323,34 +290,34 @@ func.func @concat_bounds_unranked_c1( // CHECK-LABEL: func @if_bounds func.func @if_bounds(%pred : tensor, %true_branch_operand : tensor<2x3x4x?x?x?xf32, #mhlo.type_extensions>, - %false_branch_operand : tensor<2x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %false_branch_operand : tensor<2x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.if"(%pred) ({ "mhlo.return"(%true_branch_operand) : ( tensor<2x3x4x?x?x?xf32, #mhlo.type_extensions>) -> () }, { "mhlo.return"(%false_branch_operand) : ( tensor<2x?x?x?x?x?xf32, #mhlo.type_extensions>) -> () - }) : (tensor) -> tensor<*xf32> + }) : (tensor) -> tensor // CHECK: types0 = tensor<2x?x?x?x?x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- func.func @if_bounds_unranked(%pred : tensor, %true_branch_operand : tensor<2x3x4x?x?x?xf32, #mhlo.type_extensions>, - %false_branch_operand : tensor<*xf32>) -> tensor<*xindex> { + %false_branch_operand : tensor) -> tensor { %0 = "mhlo.if"(%pred) ({ "mhlo.return"(%true_branch_operand) : ( tensor<2x3x4x?x?x?xf32, #mhlo.type_extensions>) -> () }, { "mhlo.return"(%false_branch_operand) : ( - tensor<*xf32>) -> () - }) : (tensor) -> tensor<*xf32> - // CHECK: types0 = tensor<*xf32> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + tensor) -> () + }) : (tensor) -> tensor + // CHECK: types0 = tensor + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -360,17 +327,17 @@ func.func @if_bounds_unranked(%pred : tensor, // CHECK-LABEL: func @case_bounds func.func @case_bounds(%index : tensor, %branch_0_operand : tensor<2xf32, #mhlo.type_extensions>, - %branch_2_operand : tensor>) -> tensor<*xindex> { + %branch_2_operand : tensor>) -> tensor { %0 = "mhlo.case"(%index) ({ "mhlo.return"(%branch_0_operand) : (tensor<2xf32, #mhlo.type_extensions>) -> () }, { "mhlo.return"(%branch_0_operand) : (tensor<2xf32, #mhlo.type_extensions>) -> () }, { "mhlo.return"(%branch_2_operand) : (tensor>) -> () - }) : (tensor) -> tensor<*xf32> + }) : (tensor) -> tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -378,7 +345,7 @@ func.func @case_bounds(%index : tensor, // CHECK-LABEL: while_bounds func.func @while_bounds( %while_arg_1: tensor<2x?xi32, #mhlo.type_extensions>, - %while_arg_2: tensor<3xf32>) -> tensor<*xindex> { + %while_arg_2: tensor<3xf32>) -> tensor { %1:2 = "mhlo.while"(%while_arg_1, %while_arg_2) ({ ^bb0(%arg1: tensor<2x?xi32, #mhlo.type_extensions>, %arg2: tensor<3xf32>): %2 = mhlo.constant dense<1> : tensor @@ -386,11 +353,11 @@ func.func @while_bounds( }, { ^bb0(%arg1: tensor<2x?xi32, #mhlo.type_extensions>, %arg2: tensor<3xf32>): "mhlo.return"(%arg1, %arg2) : (tensor<2x?xi32, #mhlo.type_extensions>, tensor<3xf32>) -> () - }) : (tensor<2x?xi32, #mhlo.type_extensions>, tensor<3xf32>) -> (tensor<*xi32>, tensor<*xf32>) + }) : (tensor<2x?xi32, #mhlo.type_extensions>, tensor<3xf32>) -> (tensor, tensor) // CHECK: types0 = tensor<2x?xi32, #mhlo.type_extensions>, // CHECK-SAME: types1 = tensor<3xf32> - %3 = "mhlo_test.get_return_types"(%1) : (tensor<*xi32>) -> tensor<*xindex> - func.return %3 : tensor<*xindex> + %3 = "mhlo_test.get_return_types"(%1) : (tensor) -> tensor + func.return %3 : tensor } // ----- @@ -417,7 +384,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 // CHECK-LABEL: @gather_bounds func.func @gather_bounds(%operand : tensor>, %start_indices : tensor>) - -> tensor<*xindex> { + -> tensor { %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -431,8 +398,8 @@ func.func @gather_bounds(%operand : tensor tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%res) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%res) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -440,7 +407,7 @@ func.func @gather_bounds(%operand : tensor, %arg1: tensor) -> tensor<7xindex> { %0 = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - %1 = "mhlo.rng"(%arg0, %arg1, %0) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi64>) -> tensor<7xf32> + %1 = "mhlo.rng"(%arg0, %arg1, %0) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<1xi64>) -> tensor<7xf32> // CHECK: types0 = tensor<7xf32> %2 = "mhlo_test.get_return_types"(%1) : (tensor<7xf32>) -> tensor<7xindex> func.return %2 : tensor<7xindex> @@ -451,7 +418,7 @@ func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<7xindex> // CHECK-LABEL: func @rng_uniform func.func @rng_uniform(%a: tensor, %b: tensor) -> tensor<2x3x5xindex> { %0 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %1 = "mhlo.rng"(%a, %b, %0) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %1 = "mhlo.rng"(%a, %b, %0) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> // CHECK: types0 = tensor<2x3x5xf32> %2 = "mhlo_test.get_return_types"(%1) : (tensor<2x3x5xf32>) -> tensor<2x3x5xindex> func.return %2 : tensor<2x3x5xindex> @@ -461,7 +428,7 @@ func.func @rng_uniform(%a: tensor, %b: tensor) -> tensor<2x3x5xindex> // CHECK-LABEL: func @slice func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xindex> { - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + %0 = "mhlo.slice"(%arg0) <{start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>}> : (tensor<3x4xi32>) -> tensor<1x2xi32> // CHECK: types0 = tensor<1x2xi32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<1x2xi32>) -> tensor<1x2xindex> func.return %1 : tensor<1x2xindex> @@ -470,21 +437,21 @@ func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xindex> { // ----- // CHECK-LABEL: func @slice_with_bounds -func.func @slice_with_bounds(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 4, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xi32> +func.func @slice_with_bounds(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor { + %0 = "mhlo.slice"(%arg0) <{start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 4, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>}> : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<1x2x2xi32> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xi32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- -func.func @slice_with_index_larger_than_bound_dim(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { +func.func @slice_with_index_larger_than_bound_dim(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor { // expected-error@+2 {{'mhlo.slice' op failed to infer returned types}} // expected-error@+1 {{limit index 5 is larger than dimension bound 4 in dimension 1}} - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 5, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xi32> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xi32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %0 = "mhlo.slice"(%arg0) <{start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 5, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>}> : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -511,7 +478,7 @@ func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform // CHECK-LABEL: func @fft func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xindex> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> // CHECK: types0 = tensor<3x9xcomplex> %1 = "mhlo_test.get_return_types"(%0) : (tensor<3x9xcomplex>) -> tensor<3x9xindex> func.return %1 : tensor<3x9xindex> @@ -520,37 +487,37 @@ func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xindex> { // ----- // CHECK-LABEL: func @batch_norm_grad -func.func @batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<*xindex> { +func.func @batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor { %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) // CHECK: types0 = tensor<2x2x2x2xf32> // CHECK-SAME: types1 = tensor<2xf32> // CHECK-SAME: types2 = tensor<2xf32> - %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x2x2x2xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x2x2x2xf32>) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: func @batch_norm_train -func.func @batch_norm_train(%input: tensor<2x?x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<*xindex> { +func.func @batch_norm_train(%input: tensor<2x?x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor { %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 1 : i64} : (tensor<2x?x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x?x2x2xf32>, tensor, tensor) // CHECK: types0 = tensor<2x?x2x2xf32> // CHECK-SAME: types1 = tensor // CHECK-SAME: types2 = tensor - %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?x2x2xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?x2x2xf32>) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @batch_norm_inference -func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<*xindex>) { +func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor) { %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> // CHECK: types0 = tensor<4x256xf32> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x256xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x256xf32>) -> tensor + func.return %1 : tensor } // ----- @@ -559,13 +526,13 @@ func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf3 func.func @batch_norm_inference_bounds( %input: tensor<4x?xf32, #mhlo.type_extensions>, %scale: tensor, %offset: tensor, %mean: tensor, %variance: tensor -) -> (tensor<*xindex>) { +) -> (tensor) { %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 } : (tensor<4x?xf32, #mhlo.type_extensions>, tensor, tensor, tensor, tensor) -> tensor<4x?xf32, #mhlo.type_extensions> // CHECK: types0 = tensor<4x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x?xf32, #mhlo.type_extensions>) -> tensor + func.return %1 : tensor } // ----- @@ -577,7 +544,7 @@ func.func @batch_norm_grad_bounds( %mean: tensor>, %variance: tensor>, %grad_output: tensor<2x?xf32, #mhlo.type_extensions> -) -> tensor<*xindex> { +) -> tensor { %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) { epsilon = 0.001 : f32, feature_index = 1 : i64 } : ( @@ -595,8 +562,8 @@ func.func @batch_norm_grad_bounds( // CHECK: types0 = tensor<2x?xf32, #mhlo.type_extensions> // CHECK-SAME: types1 = tensor> // CHECK-SAME: types2 = tensor> - %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor + func.return %1 : tensor } // ----- @@ -606,7 +573,7 @@ func.func @batch_norm_train_bounds( %input: tensor<2x?xf32, #mhlo.type_extensions>, %scale: tensor>, %offset: tensor> -) -> tensor<*xindex> { +) -> tensor { %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) { epsilon = 0.001 : f32, feature_index = 1 : i64 } : ( @@ -622,8 +589,8 @@ func.func @batch_norm_train_bounds( // CHECK: types0 = tensor<2x?xf32, #mhlo.type_extensions> // CHECK-SAME: types1 = tensor> // CHECK-SAME: types2 = tensor> - %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor + func.return %1 : tensor } // ----- @@ -695,7 +662,7 @@ func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) -> tenso // ----- // CHECK-LABEL: @sort_bounds_and_unknown_rank -func.func @sort_bounds_and_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<5x?x?xi32, #mhlo.type_extensions>) { +func.func @sort_bounds_and_unknown_rank(%input0: tensor, %input1: tensor<5x?x?xi32, #mhlo.type_extensions>) { %0, %1 = "mhlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %pred = "mhlo.compare"(%arg0, %arg1) { @@ -703,12 +670,12 @@ func.func @sort_bounds_and_unknown_rank(%input0: tensor<*xf32>, %input1: tensor< } : (tensor, tensor) -> tensor "mhlo.return"(%pred) : (tensor) -> () }) { dimension = 1 : i64, is_stable = true } : ( - tensor<*xf32>, + tensor, tensor<5x?x?xi32, #mhlo.type_extensions> - ) -> (tensor<*xf32>, tensor<*xi32>) - // CHECK: types0 = tensor<*xf32> + ) -> (tensor, tensor) + // CHECK: types0 = tensor // CHECK-SAME: types1 = tensor<5x?x?xi32, #mhlo.type_extensions> - %2 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor func.return } @@ -750,7 +717,7 @@ func.func @while(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, % // CHECK-LABEL: func @get_dimension_size func.func @get_dimension_size(%arg0: tensor<4x2xf32>) -> tensor { - %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<4x2xf32>) -> tensor + %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor<4x2xf32>) -> tensor // CHECK: types0 = tensor %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor func.return %1 : tensor @@ -769,11 +736,11 @@ func.func @dynamic_update_slice(%arg0: tensor<4x4xi32>, %arg1: tensor<2x2xi32>, // ----- // CHECK-LABEL: @dynamic_update_slice_with_bounds -func.func @dynamic_update_slice_with_bounds(%input: tensor<3x?x?xi64, #mhlo.type_extensions>, %update: tensor<1x4x3xi64>, %start1: tensor, %start2: tensor, %start3 : tensor) -> tensor<*xindex> { +func.func @dynamic_update_slice_with_bounds(%input: tensor<3x?x?xi64, #mhlo.type_extensions>, %update: tensor<1x4x3xi64>, %start1: tensor, %start2: tensor, %start3 : tensor) -> tensor { %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2, %start3) : (tensor<3x?x?xi64, #mhlo.type_extensions>, tensor<1x4x3xi64>, tensor, tensor, tensor) -> tensor<3x?x?xi64> // CHECK: types0 = tensor<3x?x?xi64, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<3x?x?xi64>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<3x?x?xi64>) -> tensor + func.return %1 : tensor } // ----- @@ -866,7 +833,7 @@ func.func @scatter(%input_tensor: tensor<200x100x300xf32>, // CHECK-LABEL: func @scatter_bounds func.func @scatter_bounds(%input_tensor: tensor<200x?x?xf32, #mhlo.type_extensions>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> - tensor<*xindex> { + tensor { %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = mhlo.add %lhs, %rhs : tensor @@ -884,8 +851,8 @@ func.func @scatter_bounds(%input_tensor: tensor<200x?x?xf32, #mhlo.type_extensio tensor<200x?x?xf32> // CHECK: types0 = tensor<200x?x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<200x?x?xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<200x?x?xf32>) -> tensor + func.return %1 : tensor } // ----- @@ -982,7 +949,7 @@ func.func @reduce(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) // CHECK-LABEL: func @reduce_with_bounds func.func @reduce_with_bounds(%arg0: tensor>, %arg1 : tensor<5xf32>) - -> (tensor<*xindex>) { + -> (tensor) { %0 = "mhlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32> ): @@ -995,25 +962,9 @@ func.func @reduce_with_bounds(%arg0: tensor> %2 = "mhlo_test.get_return_types"(%0) - : (tensor>) -> tensor<*xindex> + : (tensor>) -> tensor - func.return %2: tensor<*xindex> -} - -// ----- - -// CHECK-LABEL: func @unranked_reduce -func.func @unranked_reduce(%arg0: tensor<*xf32>, %arg1 : tensor) - -> (tensor<*xindex>) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor ): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<*xf32>, tensor) -> tensor<*xf32> - - // CHECK: types0 = tensor<*xf32> - %2 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %2: tensor<*xindex> + func.return %2: tensor } // ----- @@ -1045,68 +996,57 @@ func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, //===----------------------------------------------------------------------===// // CHECK-LABEL: @tensor_bounds -func.func @tensor_bounds(%arg0: tensor<3x5xf32>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x5xf32>, tensor) -> tensor<*xf32> +func.func @tensor_bounds(%arg0: tensor<3x5xf32>, %arg1: tensor) -> tensor { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x5xf32>, tensor) -> tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @static_tensor_bounds -func.func @static_tensor_bounds(%arg0: tensor>) -> tensor<*xindex> { +func.func @static_tensor_bounds(%arg0: tensor>) -> tensor { %bounds = mhlo.constant dense<8> : tensor - %result = "mhlo.set_dimension_size"(%arg0, %bounds) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> + %result = "mhlo.set_dimension_size"(%arg0, %bounds) {dimension = 0 : i64} : (tensor>, tensor) -> tensor // CHECK: types0 = tensor<8x5xf32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @edit_tensor_bounds -func.func @edit_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @edit_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @retain_tensor_bounds -func.func @retain_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @retain_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor>, tensor) -> tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @unknown_bounds -func.func @unknown_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @unknown_bounds(%arg0: tensor>, %arg1: tensor) -> tensor { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor // CHECK: types0 = tensor> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> -} - -// ----- - -// CHECK-LABEL: @unranked_input -func.func @unranked_input(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<*xf32>, tensor) -> tensor<*xf32> - - // CHECK: types0 = tensor<*xf32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -1117,7 +1057,7 @@ func.func @unranked_input(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*x // CHECK-LABEL: @add_bounds func.func @add_bounds( %arg0: tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, - %arg1: tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %arg1: tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor { %result1 = "mhlo.add"(%arg0, %arg1) : ( tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) @@ -1128,11 +1068,11 @@ func.func @add_bounds( -> tensor // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%result1) : (tensor) -> tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result1) : (tensor) -> tensor // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> - %2 = "mhlo_test.get_return_types"(%result2) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%result2) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -1141,35 +1081,23 @@ func.func @add_bounds( // See PairwiseSameOperandAndResultType::inferDimWithBound() func.func @add_bounds_mismatch( %arg0: tensor<3xf32, #mhlo.type_extensions>, - %arg1: tensor>) -> tensor<*xindex> { + %arg1: tensor>) -> tensor { // expected-error@+1 {{requires compatible types for all operands and results}} %result = "mhlo.add"(%arg0, %arg1) : ( tensor<3xf32, #mhlo.type_extensions>, tensor>) -> tensor - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> -} - -// ----- - -// CHECK-LABEL: @add_bounds_unranked -func.func @add_bounds_unranked( - %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xindex> { - %result = "mhlo.add"(%arg0, %arg1) : ( - tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: types0 = tensor<*xf32> - %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: @partition_id -func.func @partition_id() -> tensor<*xindex> { +func.func @partition_id() -> tensor { %result = "mhlo.partition_id"() : () -> tensor // CHECK: types0 = tensor - %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -1189,9 +1117,9 @@ func.func @send(%arg0: !mhlo.token) -> !mhlo.token { // CHECK-LABEL: func @gather // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4x2xi32>, %[[ARG1:.*]]: tensor func.func @gather(%operand : tensor<3x4x2xi32>, %start_indices : tensor) -> tensor<4xindex> { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]], %[[C3]], %[[C2]], %[[C2]] : tensor<4xindex> // CHECK: return %[[RES]] : tensor<4xindex> @@ -1213,8 +1141,8 @@ func.func @gather(%operand : tensor<3x4x2xi32>, %start_indices : tensor func.func @pad(%arg0: tensor) -> tensor<4xindex> { - // CHECK: %[[CST0:.*]] = arith.constant 0 : index - // CHECK: %[[CST1:.*]] = arith.constant 48 : index + // CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[CST1:.*]] = arith.constant 48 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]], %[[CST1]], %[[CST1]], %[[CST1]] : tensor<4xindex> // CHECK: return %[[RES]] : tensor<4xindex> @@ -1231,18 +1159,18 @@ func.func @pad(%arg0: tensor) -> tensor<4xindex> { // ----- // CHECK-LABEL: func @cholesky_bounds -func.func @cholesky_bounds(%input: tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { - %0 = "mhlo.cholesky"(%input) { lower = true } : (tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor<*xf32> +func.func @cholesky_bounds(%input: tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor { + %0 = "mhlo.cholesky"(%input) { lower = true } : (tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<2x?x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // CHECK-LABEL: func @concatenate // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor func.func @concatenate(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xindex> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor @@ -1302,8 +1230,8 @@ func.func @real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg // CHECK-LABEL: func @dot_general // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor func.func @dot_general(%arg0: tensor, %arg1: tensor) -> tensor<3xindex> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor @@ -1357,7 +1285,7 @@ func.func @broadcast(%arg0: tensor) -> tensor<3xindex> { // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[RES:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[DIM]] : tensor<3xindex> // CHECK: return %[[RES]] : tensor<3xindex> - %result = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor) -> tensor<1x2x?xi32> + %result = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1, 2]> : tensor<2xi64>}> : (tensor) -> tensor<1x2x?xi32> %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor<1x2x?xi32>) -> tensor<3xindex> func.return %1: tensor<3xindex> } @@ -1367,17 +1295,17 @@ func.func @broadcast(%arg0: tensor) -> tensor<3xindex> { // CHECK-LABEL: func @transpose // CHECK-SAME: (%[[ARG0:.*]]: tensor func.func @transpose(%arg0: tensor) -> tensor<4xindex> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor // CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C3]] : tensor // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM0]], %[[DIM]], %[[DIM2]], %[[DIM1]] : tensor<4xindex> // CHECK: return %[[RES]] : tensor<4xindex> - %result = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + %result = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor) -> tensor %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<4xindex> func.return %1: tensor<4xindex> } @@ -1400,7 +1328,7 @@ func.func @dynamic_iota(%arg0: tensor<1xindex>) -> tensor<1xindex> { // CHECK: func @select_and_scatter_bound func.func @select_and_scatter_bound( %arg0: tensor>, - %arg1: tensor>) -> tensor<*xindex> { + %arg1: tensor>) -> tensor { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): @@ -1418,17 +1346,17 @@ func.func @select_and_scatter_bound( window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> } : (tensor>, tensor>, - tensor) -> tensor<*xf32> + tensor) -> tensor // CHECK: types0 = tensor> - %3 = "mhlo_test.get_return_types"(%1) : (tensor<*xf32>) -> tensor<*xindex> - func.return %3 : tensor<*xindex> + %3 = "mhlo_test.get_return_types"(%1) : (tensor) -> tensor + func.return %3 : tensor } // ----- // CHECK-LABEL: func @reduce_window_bound func.func @reduce_window_bound(%arg0: tensor<4x?x?x?xf32, #mhlo.type_extensions>, - %init0: tensor) -> (tensor<*xindex>) { + %init0: tensor) -> (tensor) { %0:1 = "mhlo.reduce_window"(%arg0, %init0) ({ ^bb0(%a0: tensor, %b0: tensor): %2 = mhlo.add %a0, %b0 : tensor @@ -1438,10 +1366,10 @@ func.func @reduce_window_bound(%arg0: tensor<4x?x?x?xf32, #mhlo.type_extensions< window_dimensions = dense<[1, 1, 5, 1]> : tensor<4xi64>, window_strides = dense<[1, 1, 3, 1]> : tensor<4xi64> } : (tensor<4x?x?x?xf32, #mhlo.type_extensions>, - tensor) -> (tensor<*xf32>) + tensor) -> (tensor) // CHECK: types0 = tensor<4x?x?x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1: tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor) -> tensor + func.return %1: tensor } // ----- @@ -1449,53 +1377,53 @@ func.func @reduce_window_bound(%arg0: tensor<4x?x?x?xf32, #mhlo.type_extensions< // CHECK-LABEL: func @triangular_solve_bounds func.func @triangular_solve_bounds( %arg0: tensor<10x5x?x4xf32, #mhlo_test.type_extensions>, - %arg1: tensor<10x5x?x?xf32, #mhlo_test.type_extensions>) -> tensor<*xindex> { + %arg1: tensor<10x5x?x?xf32, #mhlo_test.type_extensions>) -> tensor { %0 = "mhlo.triangular_solve"(%arg0, %arg1) { left_side = false, lower = true, transpose_a = #mhlo, unit_diagonal = true } : (tensor<10x5x?x4xf32, #mhlo_test.type_extensions>, - tensor<10x5x?x?xf32, #mhlo_test.type_extensions>) -> tensor<*xf32> + tensor<10x5x?x?xf32, #mhlo_test.type_extensions>) -> tensor // CHECK: types0 = tensor<10x5x?x?xf32, #mhlo_test.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } //----- // CHECK-LABEL: func @fft_bound -func.func @fft_bound(%arg0: tensor, #mhlo.type_extensions>) -> tensor<*xindex> { +func.func @fft_bound(%arg0: tensor, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo - } : (tensor, #mhlo.type_extensions>) -> tensor<*xcomplex> + } : (tensor, #mhlo.type_extensions>) -> tensor> // CHECK: types0 = tensor, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xcomplex>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor>) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: func @rfft_with_bound -func.func @rfft_with_bound(%arg0: tensor<3x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { +func.func @rfft_with_bound(%arg0: tensor<3x?x?xf32, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<3x?x?xf32, #mhlo.type_extensions>) -> tensor<*xcomplex> + } : (tensor<3x?x?xf32, #mhlo.type_extensions>) -> tensor> // CHECK: types0 = tensor<3x?x5xcomplex, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xcomplex>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor>) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: func @irfft_with_bound -func.func @irfft_with_bound(%arg0: tensor<3x?x?xcomplex, #mhlo.type_extensions>) -> tensor<*xindex> { +func.func @irfft_with_bound(%arg0: tensor<3x?x?xcomplex, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<3x?x?xcomplex, #mhlo.type_extensions>) -> tensor<*xf32> + } : (tensor<3x?x?xcomplex, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<3x?x9xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- @@ -1503,19 +1431,19 @@ func.func @irfft_with_bound(%arg0: tensor<3x?x?xcomplex, #mhlo.type_extensi // CHECK-LABEL: @select func.func @select(%pred : tensor, %a : tensor>, - %b : tensor<1x?x3x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %b : tensor<1x?x3x?xf32, #mhlo.type_extensions>) -> tensor { %0 = "mhlo.select"(%pred, %a, %b) : (tensor, tensor>, - tensor<1x?x3x?xf32, #mhlo.type_extensions>) -> tensor<*xf32> + tensor<1x?x3x?xf32, #mhlo.type_extensions>) -> tensor // CHECK: types0 = tensor<1x2x3x?xf32, #mhlo.type_extensions> - %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> - func.return %1 : tensor<*xindex> + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + func.return %1 : tensor } // ----- // CHECK-LABEL: func @dynamic_gather -func.func @dynamic_gather(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<*xindex> { +func.func @dynamic_gather(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor { %0 = mhlo.constant dense<[1, 2]> : tensor<2xi32> %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) { dimension_numbers = #mhlo.gather< @@ -1523,8 +1451,8 @@ func.func @dynamic_gather(%arg0: tensor, %arg1: tensor<1xi64>) -> tenso start_index_map = [1] >, indices_are_sorted = true - } : (tensor, tensor<1xi64>, tensor<2xi32>) -> tensor<*xf32> + } : (tensor, tensor<1xi64>, tensor<2xi32>) -> tensor // CHECK: types0 = tensor<1x2xf32> - %2 = "mhlo_test.get_return_types"(%1) : (tensor<*xf32>) -> tensor<*xindex> - func.return %2 : tensor<*xindex> + %2 = "mhlo_test.get_return_types"(%1) : (tensor) -> tensor + func.return %2 : tensor } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir index 7aad7a84cc515..d45da5fa1f7b7 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir @@ -223,8 +223,8 @@ func.func @compare_op(%arg0 : tensor<3xi32>) -> () { // CHECK-LABEL: func @extensions func.func @extensions(%arg0 : tensor>, %arg1 : tensor) -> () { - // CHECK: %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> - %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> + // CHECK: %0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor>, tensor) -> tensor + %0 = "mhlo.set_dimension_size"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor>, tensor) -> tensor "mhlo.return"() : () -> () } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir index 8435e14fee55a..4471026910f8e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir @@ -25,45 +25,6 @@ func.func @reduce_one_op_all_locs_same(%arg0: tensor, %arg1 : tensor } -// The test case is not eligible for pretty-printing reduce-op. The location of -// reduce-op is different. - -// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_1 -// CHECK-NEXT: mhlo.reduce(%arg0 init: %arg1) -// CHECK-SAME: across dimensions = [1] {foo = "bar"} -// CHECK-SAME: : (tensor, tensor) -> tensor -// CHECK-NEXT: reducer(%arg[[x:.+]]: tensor loc("foo"), %arg[[y:.+]]: tensor loc("foo")) -// CHECK-NEXT: mhlo.add %arg[[x]], %arg[[y]] : tensor loc("foo") -// CHECK-NEXT: mhlo.return %{{[0-9]+}} : tensor loc("foo") -// CHECK-NEXT: loc("not_foo") - -func.func @reduce_one_op_all_locs_not_same_1(%arg0: tensor, %arg1 : tensor) -> (tensor) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor loc("foo"), %arg3: tensor loc("foo")): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor loc("foo") - "mhlo.return"(%1) : (tensor) -> () loc("foo") - }) {dimensions = dense<[1]> : tensor<1xi64>, foo = "bar"} : (tensor, tensor) -> tensor loc("not_foo") - - func.return %0: tensor -} - -// The test case is not eligible for pretty-printing reduce-op. The location of -// block-arguments are different. - -// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_2 -// CHECK-NOT: applies - -func.func @reduce_one_op_all_locs_not_same_2(%arg0: tensor, %arg1 : tensor) -> (tensor) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor loc("foo"), %arg3: tensor loc("not_foo")): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor loc("foo") - "mhlo.return"(%1) : (tensor) -> () loc("foo") - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor loc("foo") - - func.return %0: tensor -} - - // The test case is not eligible for pretty-printing reduce-op. More than two // block-arguments which are not perfectly forwarded to inner-op. @@ -168,3 +129,16 @@ func.func @reduce_innerop_type_not_trivially_derived(%arg0: tensor<4x4xf32>, %ar func.return %0: tensor<4xf32> } + + +// The test case makes sure any custom attrs set on the reduce-op are +// printed/parsed when pretty-printed. + +// CHECK-LABEL: func @pretty_print_with_custom_attr +// CHECK: applies mhlo.add across dimensions = [1] {custom_user_attr = 1 : i64} + +func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> + return %1 : tensor<2x13xf32> +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 206fa35192889..d8f929bc8e064 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -14,56 +14,88 @@ func.func private @invalid_type() -> !mhlo.foobar // ----- func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = "mhlo.all_reduce"(%arg0) ({ - // Perform max reduction inside the region - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.maximum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { + %0 = "mhlo.all_reduce"(%arg0) <{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, channel_handle = #mhlo.channel_handle< handle = 5, type = 2 >, use_global_device_ids - } : (tensor<10xf32>) -> tensor<10xf32> + }> ({ + // Perform max reduction inside the region + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } // ----- func.func @all_reduce_tuple(%arg0: tensor<10xf32>, %arg1: tensor) -> tensor<10xf32> { - %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({ - // Perform max reduction inside the region - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.maximum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { + %0:2 = "mhlo.all_reduce"(%arg0, %arg1) <{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, channel_handle = #mhlo.channel_handle< handle = 5, type = 2 >, use_global_device_ids - } : (tensor<10xf32>, tensor) -> (tensor<10xf32>, tensor) + }> ({ + // Perform max reduction inside the region + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) : (tensor<10xf32>, tensor) -> (tensor<10xf32>, tensor) func.return %0 : tensor<10xf32> } // ----- +// CHECK-LABEL: func @all_reduce_with_promotable_types +func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor { + + %result = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + }> ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "mhlo.return"(%0) : (tensor) -> () + }) : (tensor) -> tensor + + func.return %result : tensor +} + +// ----- + +// CHECK-LABEL: func @all_reduce_with_promotable_quantized_types +func.func @all_reduce_with_promotable_quantized_types(%operand: tensor>) + -> tensor> { + + %result = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + }> ({ + ^bb0(%arg0: tensor>, %arg1: tensor>): + %0 = mhlo.add %arg0, %arg1 : tensor> + "mhlo.return"(%0) : (tensor>) -> () + }) : (tensor>) -> tensor> + + func.return %result : tensor> +} + +// ----- + func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{Reduction-region must take 2 parameters, but takes 3 parameter(s)}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, -1], [1, 3, -1, -1]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, -1], [1, 3, -1, -1]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -72,14 +104,13 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{The reduction-region expected to return some value(s)}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"() : () -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -88,14 +119,13 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{Reduction-region here must produce 1 tensors, but produces 2 instead}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max, %max) : (tensor, tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -104,15 +134,14 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{Reduction-region here must produce tensor-typed result(s), but produces 'tuple, tensor>' instead}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor %tup = "mhlo.tuple"(%max, %max) : (tensor, tensor) -> tuple, tensor> "mhlo.return"(%tup) : (tuple, tensor>) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -121,14 +150,13 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{The type of reduction-region's parameter at index 1 is different than the corresponding result type: 'tensor' vs 'tensor'}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg0 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -137,15 +165,14 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{The type of reduction-region's parameter at index 0 is different than the corresponding result type: 'tensor' vs 'tensor'}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor %maxint = "mhlo.convert"(%max) : (tensor) -> tensor "mhlo.return"(%maxint) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -153,15 +180,14 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} - %0 = "mhlo.all_reduce"(%operand) ({ + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -169,15 +195,14 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor'}} - %0 = "mhlo.all_reduce"(%operand) ({ + // expected-error@+1 {{The shape of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): %max = mhlo.maximum %arg0, %arg1 : tensor<4xf32> "mhlo.return"(%max) : (tensor<4xf32>) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -186,29 +211,28 @@ func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32 func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10x4xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{'mhlo.all_reduce' op inferred type(s) 'tensor<10xf32>' are incompatible with return type(s) of operation 'tensor<10x4xf32>'}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10x4xf32> + }) : (tensor<10xf32>) -> tensor<10x4xf32> func.return %0 : tensor<10x4xf32> } // ----- func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10xi32> { - // expected-error@+1 {{'mhlo.all_reduce' op requires the same element type for all operands and results}} - %0 = "mhlo.all_reduce"(%operand) ({ + // expected-error@+2 {{'mhlo.all_reduce' op inferred type(s) 'tensor<10xf32>' are incompatible with return type(s) of operation 'tensor<10xi32>'}} + // expected-error@+1 {{'mhlo.all_reduce' op failed to infer returned types}} + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<10xf32>) -> tensor<10xi32> + }) : (tensor<10xf32>) -> tensor<10xi32> func.return %0 : tensor<10xi32> } @@ -217,14 +241,13 @@ func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10 func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{replica groups should be a rank 2 tensor}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<0> : tensor<1xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<0> : tensor<1xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -233,14 +256,13 @@ func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor< func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{replica id #1 seen more than once}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 1, 1, 3]]> : tensor<1x4xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 1, 1, 3]]> : tensor<1x4xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -249,14 +271,13 @@ func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor< func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{replica id #2 not seen in replica groups}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<[[0, 1, 3]]> : tensor<1x3xi64> + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 1, 3]]> : tensor<1x3xi64> - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -265,15 +286,14 @@ func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor< func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{replica groups cannot be empty}} - %0 = "mhlo.all_reduce"(%operand) ({ + %0 = "mhlo.all_reduce"(%operand) <{ + replica_groups = dense<0> : tensor<0x2xi64>, + use_global_device_ids + }> ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<0> : tensor<0x2xi64>, - use_global_device_ids - } : (tensor<10xf32>) -> tensor<10xf32> + }) : (tensor<10xf32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } @@ -281,14 +301,16 @@ func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor< // CHECK-LABEL: func @reduce_scatter func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{ + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids + }> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64, - channel_handle = #mhlo.channel_handle, - use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -296,27 +318,61 @@ func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: func @reduce_scatter_dynamic func.func @reduce_scatter_dynamic(%data: tensor) -> tensor { - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{ + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids + }> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @reduce_scatter_with_promotable_types +func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, scatter_dimension = 1 : i64, channel_handle = #mhlo.channel_handle, - use_global_device_ids} : (tensor) -> tensor - func.return %0 : tensor + use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + +// ----- + +// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types +func.func @reduce_scatter_with_promotable_quantized_types( + %data: tensor<4x16x!quant.uniform>) -> + tensor<4x4x!quant.uniform> { + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids}> ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = mhlo.add %arg2, %arg3 : tensor> + "mhlo.return"(%1) : (tensor>) -> () + }) : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> + func.return %0 : tensor<4x4x!quant.uniform> } // ----- func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{expects scatter_dimension >= 0}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = -1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = -1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -324,12 +380,12 @@ func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{scatter dim should be less than operand/result rank}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 4 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 4 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -337,12 +393,12 @@ func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{replica id #1 seen more than once}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 1, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 1, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -350,12 +406,12 @@ func.func @reduce_scatter_c3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c5(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{Invalid replica id -1}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, -1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, -1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -363,12 +419,12 @@ func.func @reduce_scatter_c5(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c5(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{replica id #2 not seen in replica groups}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 3]]> : tensor<1x3xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 3]]> : tensor<1x3xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -376,13 +432,13 @@ func.func @reduce_scatter_c5(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c6(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{channel_id must be positive when useGlobalDeviceIds is set but got: 0}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + use_global_device_ids}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64, - use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -390,12 +446,12 @@ func.func @reduce_scatter_c6(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{Reduction-region must take 2 parameters, but takes 3 parameter(s)}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"() : () -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -403,12 +459,12 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{The reduction-region expected to return some value(s)}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"() : () -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -416,12 +472,12 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{Reduction-region here must produce 1 tensors, but produces 2 instead}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -429,12 +485,12 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{Reduction-region here must produce tensor-typed result(s), but produces 'tuple, tensor>' instead}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = "mhlo.tuple"(%arg2, %arg2) : (tensor, tensor) -> tuple, tensor> "mhlo.return"(%1) : (tuple, tensor>) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -442,12 +498,12 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{The type of reduction-region's parameter at index 1 is different than the corresponding result type: 'tensor' vs 'tensor'}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg2 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -455,26 +511,26 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{The type of reduction-region's parameter at index 0 is different than the corresponding result type: 'tensor' vs 'tensor'}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor %2 = "mhlo.convert"(%1) : (tensor) -> tensor "mhlo.return"(%2) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } // ----- func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} - %0 = "mhlo.reduce_scatter"(%data) ({ + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -482,12 +538,12 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4xf32> { // expected-error@+1 {{operand and result should have same rank}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4xf32> + }) : (tensor<4x16xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> } @@ -495,12 +551,12 @@ func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4xf32> { func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{operand scatter dimension has size 16, expected to be a multiple of result scatter dimension size 5}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x5xf32> + }) : (tensor<4x16xf32>) -> tensor<4x5xf32> func.return %0 : tensor<4x5xf32> } @@ -508,12 +564,12 @@ func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4x5xf32> { func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<3x4xf32> { // expected-error@+1 {{non scatter dimensions should be same for operand (4) and result (3)}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<3x4xf32> + }) : (tensor<4x16xf32>) -> tensor<3x4xf32> func.return %0 : tensor<3x4xf32> } @@ -521,12 +577,12 @@ func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<3x4xf32> { func.func @reduce_scatter_i3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{replica groups should be a rank 2 tensor}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<0> : tensor<1xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<0> : tensor<1xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32> + }) : (tensor<4x16xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -535,12 +591,12 @@ func.func @reduce_scatter_i3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // TODO(#1746): Sync verification of ReduceScatter with HLO. func.func @reduce_scatter_invalid(%data: tensor<4x16xf32>) -> tensor<4x0xf32> { // expected-error@+1 {{result dimension size at scatter_dimension cannot be zero}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x0xf32> + }) : (tensor<4x16xf32>) -> tensor<4x0xf32> func.return %0 : tensor<4x0xf32> } @@ -549,12 +605,12 @@ func.func @reduce_scatter_invalid(%data: tensor<4x16xf32>) -> tensor<4x0xf32> { // TODO(#1746): Sync verification of ReduceScatter with HLO. func.func @reduce_scatter_invalid(%data: tensor<4x0xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{operand dimension size at scatter_dimension cannot be zero}} - %0 = "mhlo.reduce_scatter"(%data) ({ + %0 = "mhlo.reduce_scatter"(%data) <{replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64}> ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () - }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, - scatter_dimension = 1 : i64} : (tensor<4x0xf32>) -> tensor<4x4xf32> + }) : (tensor<4x0xf32>) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } @@ -574,19 +630,6 @@ func.func @all_to_all(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // ----- -// CHECK-LABEL: func @all_to_all_unranked_input -func.func @all_to_all_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.all_to_all"(%data) { - split_dimension = 1 : i64, - concat_dimension = 0 : i64, - split_count = 5 : i64, - replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> - } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: func @all_to_all_dynamic_split_dim func.func @all_to_all_dynamic_split_dim(%data: tensor<4x?xf32>) -> tensor<20x?xf32> { %0 = "mhlo.all_to_all"(%data) { @@ -753,13 +796,24 @@ func.func @all_to_all_i5(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // ----- +func.func @all_gather_variadic(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) { + %0:2 = "mhlo.all_gather"(%arg0, %arg1) <{ + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + }> : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) + func.return %0#0, %0#1 : tensor<8x8xf32>, tensor<8x16xf32> +} + +// ----- + func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> tensor<128x100xf32> { // expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<128x0xf32>) -> tensor<128x100xf32> + }> : (tensor<128x0xf32>) -> tensor<128x100xf32> func.return %0 : tensor<128x100xf32> } @@ -767,11 +821,11 @@ func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> te func.func @all_gather_c1(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{all_gather_dim cannot be negative}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = -1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -779,11 +833,11 @@ func.func @all_gather_c1(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c1(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{all_gather_dim must be a valid index of operand}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 2 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -791,11 +845,11 @@ func.func @all_gather_c1(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c2(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{replica id #2 seen more than once}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 2]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -803,11 +857,11 @@ func.func @all_gather_c2(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c4(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{Invalid replica id -1}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, -1]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -815,11 +869,11 @@ func.func @all_gather_c4(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c4(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{replica id #4 not seen in replica groups}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 6, 8], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -827,12 +881,12 @@ func.func @all_gather_c4(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c5(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{channel_id cannot be negative when useGlobalDeviceIds is set}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, channel_handle = #mhlo.channel_handle, use_global_device_ids - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -840,11 +894,11 @@ func.func @all_gather_c5(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { func.func @all_gather_c6(%arg0: tensor<8x2x32xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{operand and result must have the same rank}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<8x2x32xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2x32xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -852,11 +906,11 @@ func.func @all_gather_c6(%arg0: tensor<8x2x32xf32>) -> tensor<8x8xf32> { func.func @all_gather_c6(%arg0: tensor<8x2xf32>) -> tensor<4x8xf32> { // expected-error@+1 {{operand and result should have the same shape except for the dimension size at 'all_gather_dim'}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<8x2xf32>) -> tensor<4x8xf32> + }> : (tensor<8x2xf32>) -> tensor<4x8xf32> func.return %0 : tensor<4x8xf32> } @@ -864,11 +918,11 @@ func.func @all_gather_c6(%arg0: tensor<8x2xf32>) -> tensor<4x8xf32> { func.func @all_gather_c6(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> { // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<128x32xf32>) -> tensor<128x100xf32> + }> : (tensor<128x32xf32>) -> tensor<128x100xf32> func.return %0 : tensor<128x100xf32> } @@ -876,11 +930,11 @@ func.func @all_gather_c6(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> { func.func @all_gather_i3(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // expected-error@+1 {{replica groups should be a rank 2 tensor}} - %0 = "mhlo.all_gather"(%arg0) { + %0 = "mhlo.all_gather"(%arg0) <{ all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, replica_groups = dense<[[[0], [1], [2], [3]]]> : tensor<1x4x1xi64> - } : (tensor<8x2xf32>) -> tensor<8x8xf32> + }> : (tensor<8x2xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } @@ -888,7 +942,7 @@ func.func @all_gather_i3(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { // CHECK-LABEL: func @broadcast func.func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1, 2]> : tensor<2xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -896,8 +950,8 @@ func.func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func.func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+2 {{'mhlo.broadcast' op failed to infer returned types}} - // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + // expected-error@+1 {{broadcast_sizes has rank 2 instead of required rank 1.}} + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -906,7 +960,7 @@ func.func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func.func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+2 {{'mhlo.broadcast' op failed to infer returned types}} // expected-error@+1 {{'mhlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<1x2x3xi32>'}} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[2]> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -915,7 +969,7 @@ func.func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> func.func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+2 {{'mhlo.broadcast' op failed to infer returned types}} // expected-error@+1 {{'mhlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<1x3xi32>'}} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[2]> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<1x3xi32> func.return %0 : tensor<1x3xi32> } @@ -924,7 +978,7 @@ func.func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor func.func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+2 {{'mhlo.broadcast' op failed to infer returned types}} // expected-error@+1 {{'mhlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[2]> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<2x1xi32> func.return %0 : tensor<2x1xi32> } @@ -932,7 +986,7 @@ func.func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tenso // CHECK-LABEL: func @dynamic_broadcast_in_dim func.func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } @@ -940,7 +994,7 @@ func.func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64 // CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim func.func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[2]> : tensor<1xi64>}> : (tensor<32xf32>, tensor<3xi64>) -> tensor func.return %0 : tensor } @@ -948,7 +1002,7 @@ func.func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: t // CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim func.func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[2]> : tensor<1xi64>}> : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> func.return %0 : tensor<7x8x9xf32> } @@ -956,7 +1010,7 @@ func.func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor< func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { // expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}} - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[2]> : tensor<1xi64>}> : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> func.return %0 : tensor<7x8x9xf32> } @@ -964,7 +1018,7 @@ func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}} - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[-1]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[-1]> : tensor<1xi64>}> : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> func.return %0 : tensor<7x8x9xf32> } @@ -972,7 +1026,7 @@ func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: func.func @dynamic_broadcast_in_dim_too_large(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { // expected-error@+1 {{broadcast_dimensions contains invalid value 3 for result with rank 3}} - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[3]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) <{broadcast_dimensions = dense<[3]> : tensor<1xi64>}> : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> func.return %0 : tensor<7x8x9xf32> } @@ -981,7 +1035,7 @@ func.func @dynamic_broadcast_in_dim_too_large(%arg0: tensor<1xf32>, %shape: tens // CHECK-LABEL: func @broadcast_in_dim func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<1x2xi32>) -> tensor<1x2x2xi32> func.return %0 : tensor<1x2x2xi32> } @@ -989,7 +1043,7 @@ func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { func.func @broadcast_in_dim_c2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1x2xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -997,7 +1051,7 @@ func.func @broadcast_in_dim_c2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { func.func @broadcast_in_dim_c3(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[-1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[-1, 2]> : tensor<2xi64>}> : (tensor<1x2xi32>) -> tensor<1x2x2xi32> func.return %0 : tensor<1x2x2xi32> } @@ -1005,7 +1059,7 @@ func.func @broadcast_in_dim_c3(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { func.func @broadcast_in_dim_c3(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { // expected-error@+1 {{broadcast_dimensions contains invalid value 1 for result with rank 1}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>}> : (tensor<1x2x3xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1013,7 +1067,7 @@ func.func @broadcast_in_dim_c3(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { func.func @broadcast_in_dim_c4(%arg0: tensor<1x1x3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions should not have duplicates}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,0,2]> : tensor<3xi64>} : (tensor<1x1x3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0,0,2]> : tensor<3xi64>}> : (tensor<1x1x3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -1021,15 +1075,15 @@ func.func @broadcast_in_dim_c4(%arg0: tensor<1x1x3xi32>) -> tensor<1x2x3xi32> { func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } // ----- func.func @broadcast_in_dim_i2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + // expected-error@+1 {{broadcast_dimensions size (4) does not match operand rank (2)}} + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>}> : (tensor<1x2xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -1047,18 +1101,6 @@ func.func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor) -> ten // ----- -// Regression test for b/180052624, where this crashed verification given the -// unranked operand. -// CHECK-LABEL: func @broadcast_in_dim_unranked_operand -func.func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<0> : tensor<1xi64> - } : (tensor<*xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - // CHECK-LABEL: func @if func.func @if(%pred : tensor, %branch_operand : tensor<2xf32>) -> tensor<2xf32> { %0 = "mhlo.if"(%pred) ({ @@ -1162,18 +1204,6 @@ func.func @if_i1(%pred : tensor<1xi1>, %branch_operand : tensor) -> tensor< // ----- -// CHECK-LABEL: if_unranked -func.func @if_unranked(%pred : tensor, %true_branch_operand: tensor<2xf32>, %false_branch_operand : tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.if"(%pred) ({ - "mhlo.return"(%true_branch_operand) : (tensor<2xf32>) -> () - }, { - "mhlo.return"(%false_branch_operand) : (tensor<*xf32>) -> () - }) : (tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: @case func.func @case(%index : tensor, %branch_operand : tensor) -> (tensor, tensor) { %0, %1 = "mhlo.case"(%index) ({ @@ -1274,18 +1304,6 @@ func.func @case_i1(%index : tensor<1xi32>, %branch_operand : tensor<2xf32>) -> t // ----- -// CHECK-LABEL: @case_unranked -func.func @case_unranked(%index : tensor, %branch_operand : tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.case"(%index) ({ - "mhlo.return"(%branch_operand) : (tensor<*xf32>) -> () - }, { - "mhlo.return"(%branch_operand) : (tensor<*xf32>) -> () - }) : (tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: func @comp_eq func.func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> @@ -1319,7 +1337,7 @@ func.func @comp_compatible_operand_types(%arg0: tensor<3xi32>, %arg1: tensor, %arg1: tensor<3xi32>) -> tensor<3xf16> { - // expected-error@+1 {{result #0 must be tensor of pred (AKA boolean or 1-bit integer) values, but got 'tensor<3xf16>'}} + // expected-error@+1 {{result #0 must be ranked tensor of pred (AKA boolean or 1-bit integer) values, but got 'tensor<3xf16>'}} %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xf16> func.return %0 : tensor<3xf16> } @@ -1386,7 +1404,7 @@ func.func @collective_permute_invalid_source_target_pairs(%arg0: tensor<128x32xf // CHECK-LABEL: @concatenate_1D func.func @concatenate_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1395,25 +1413,17 @@ func.func @concatenate_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor // CHECK-LABEL: @concatenate_1D // Verifies that an error is not thrown if the inferred type is compatible with // the result type. -func.func @concatenate_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> +func.func @concatenate_1D(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<3xi32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1xi32>, tensor) -> tensor<3xi32> func.return %0 : tensor<3xi32> } // ----- -// CHECK-LABEL: @concatenate_1D_unranked -func.func @concatenate_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> -} - -// ----- - func.func @concatenate_c1_c5(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -1422,7 +1432,7 @@ func.func @concatenate_c1_c5(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> ten func.func @concatenate_c2(%arg0: tensor<1xi32>, %arg1: tensor<2x2xi32>) -> tensor<3xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{operands (0) and (1) do not match rank}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2x2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1xi32>, tensor<2x2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1430,7 +1440,7 @@ func.func @concatenate_c2(%arg0: tensor<1xi32>, %arg1: tensor<2x2xi32>) -> tens func.func @concatenate_c3() -> tensor<2xi32> { // expected-error@+1 {{expected 1 or more operands, but found 0}} - %0 = "mhlo.concatenate"() { dimension = 0 : i64 } : () -> tensor<2xi32> + %0 = "mhlo.concatenate"() <{ dimension = 0 : i64 }> : () -> tensor<2xi32> func.return %0 : tensor<2xi32> } @@ -1439,17 +1449,17 @@ func.func @concatenate_c3() -> tensor<2xi32> { func.func @concatenate_c4(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{dimension -1 is negative}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = -1 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = -1 : i64 }> : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } // ----- -func.func @concatenate_c4(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { +func.func @concatenate_c4(%arg0: tensor, %arg1: tensor) -> tensor { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{dimension -1 is negative}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = -1 : i64 } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = -1 : i64 }> : (tensor, tensor) -> tensor + func.return %0 : tensor } // ----- @@ -1457,7 +1467,7 @@ func.func @concatenate_c4(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor func.func @concatenate_c4(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{rank-0 values cannot be concatenated}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<2xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor, tensor) -> tensor<2xi32> func.return %0 : tensor<2xi32> } @@ -1466,7 +1476,7 @@ func.func @concatenate_c4(%arg0: tensor, %arg1: tensor) -> tensor<2xi func.func @concatenate_c4(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} // expected-error@+1 {{dimension 10 is out-of-bounds for input rank 1}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 10 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 10 : i64 }> : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1474,8 +1484,8 @@ func.func @concatenate_c4(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor func.func @concatenate_c6(%arg0: tensor<1x3xi32>, %arg1: tensor<2x2xi32>) -> tensor<3x3xi32> { // @expected-error@+2 {{'mhlo.concatenate' op failed to infer returned types}} - // expected-error@+1 {{shapes of operand (0) and (1) do not match at non-concat index: (1, 3) != (2, 2) at non-concat index 1}} - %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x3xi32>, tensor<2x2xi32>) -> tensor<3x3xi32> + // expected-error@+1 {{shapes of operand (0) and (1) are not compatible at non-concat index 1: (1, 3) != (2, 2)}} + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x3xi32>, tensor<2x2xi32>) -> tensor<3x3xi32> func.return %0 : tensor<3x3xi32> } @@ -1565,7 +1575,7 @@ func.func @cholesky_invalid_rank(%arg0: tensor<1xf32>) -> tensor<1xf32> { // ----- func.func @cholesky_invalid_elt(%arg0: tensor<1x2x2xi32>) -> tensor<1x2x2xi32> { - // expected-error@+1 {{op operand #0 must be tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>'}} + // expected-error@+1 {{op operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>'}} %0 = "mhlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2xi32>) -> tensor<1x2x2xi32> func.return %0: tensor<1x2x2xi32> } @@ -1635,10 +1645,10 @@ func.func @dot_more_dynamic_output_type(%arg0: tensor<3xf32>, %arg1: tensor, %arg1: tensor) -> tensor<*xf32> { +func.func @dot_cannot_infer_type(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error@+1 {{expected both lhs/rhs ranks to be either 1 or 2}} - %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor } // ----- @@ -1651,31 +1661,9 @@ func.func @dot_result_type_mismatch_with_inferred_type(%arg0: tensor, % // ----- -func.func @dot_result_type_match_with_inferred_type(%arg0: tensor, %arg1: tensor<3xf32>) -> tensor<*xf32> { - %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor<3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: func @dot_legal_unranked_rank_type -func.func @dot_legal_unranked_rank_type(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<2x2xf32> { - // unrank legal test - %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - // vector dot vector - %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<3xf32> - %2 = tensor.cast %arg0 : tensor<*xf32> to tensor<3xf32> - %3 = "mhlo.dot"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor - // matrix dot vector - %4 = tensor.cast %arg0 : tensor<*xf32> to tensor<2x3xf32> - %5 = tensor.cast %arg1 : tensor<*xf32> to tensor<3xf32> - %6 = "mhlo.dot"(%4, %5) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2xf32> - // matrix dot matrix - %7 = tensor.cast %arg0 : tensor<*xf32> to tensor<2x3xf32> - %8 = tensor.cast %arg1 : tensor<*xf32> to tensor<3x2xf32> - %9 = "mhlo.dot"(%7, %8) : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32> - - func.return %9 : tensor<2x2xf32> +func.func @dot_result_type_match_with_inferred_type(%arg0: tensor, %arg1: tensor<3xf32>) -> tensor { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor<3xf32>) -> tensor + func.return %0 : tensor } // ----- @@ -1696,14 +1684,6 @@ func.func @imag_complex_input(%arg0: tensor<2x3xcomplex>) -> tensor<2x3xf32 // ----- -// CHECK-LABEL: func @imag_unranked -func.func @imag_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.imag"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple, tensor> { // expected-error@+1 {{last element of result types is expected to be of token type, but got 'tensor'}} %0:2 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout = [[[0]], [0]]} : (!mhlo.token) -> (tensor, tensor) @@ -1747,7 +1727,7 @@ func.func @main(%arg0: !mhlo.token) -> tensor<3x3xi32> { func.func @iota_scalar() -> tensor { // expected-error@+1 {{does not support scalars}} - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor func.return %0 : tensor } @@ -1755,7 +1735,7 @@ func.func @iota_scalar() -> tensor { func.func @iota_invalid_iota_dimension() -> tensor<4xi32> { // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} - %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -1796,18 +1776,6 @@ func.func @map_scalar_operands(%arg0: tensor, %arg1: tensor) -> tensor // ----- -// CHECK-LABEL: func @map_unranked -func.func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.map"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // @expected-error@+2 {{'mhlo.map' op failed to infer returned types}} // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}} @@ -1888,7 +1856,7 @@ func.func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tenso func.func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { // @expected-error@+2 {{'mhlo.map' op failed to infer returned types}} - // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}} + // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: 1, 0}} %0 = "mhlo.map"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor @@ -1939,14 +1907,6 @@ func.func @real_complex_input(%arg0: tensor<2x3xcomplex>) -> tensor<2x3xf32 // ----- -// CHECK-LABEL: func @real_unranked -func.func @real_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.real"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @recv_non_token_second_result(%token: !mhlo.token) -> tuple, tensor> { // expected-error@+1 {{last element of result types is expected to be of token type, but got 'tensor'}} %0:2 = "mhlo.recv"(%token) { @@ -1972,7 +1932,7 @@ func.func @replica_id() -> tensor { // CHECK-LABEL: func @rng_bit_generator func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { - %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) + %0, %1 = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) func.return %0, %1 : tensor<2xui64>, tensor<10x12xui32> } @@ -1980,7 +1940,7 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1 func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { // expected-error@+1 {{output state shape must be compatible with initial state shape. Got: 'tensor<2xui64>' and 'tensor<3xui64>'}} - %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>) + %0, %1 = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>) func.return %0, %1 : tensor<3xui64>, tensor<10x12xui32> } @@ -1988,7 +1948,7 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1 // CHECK-LABEL: func @rng_bit_generator_dynamic func.func @rng_bit_generator_dynamic(%arg0: tensor) -> (tensor, tensor<10x12xui32>) { - %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor) -> (tensor, tensor<10x12xui32>) + %0, %1 = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor) -> (tensor, tensor<10x12xui32>) func.return %0, %1 : tensor, tensor<10x12xui32> } @@ -1997,7 +1957,7 @@ func.func @rng_bit_generator_dynamic(%arg0: tensor) -> (tensor, // CHECK-LABEL: func @rng_normal func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf32> { %cst = "mhlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64> - %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%arg0, %arg1, %cst) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2005,25 +1965,17 @@ func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf3 // CHECK-LABEL: func @rng_normal_no_constant func.func @rng_normal_no_constant(%a: tensor, %b: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } // ----- -// CHECK-LABEL: func @rng_normal_dynamic_dim -func.func @rng_normal_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> // @expected-error@+2 {{'mhlo.rng' op failed to infer returned types}} // expected-error @+1 {{inferred type(s) 'tensor<7xf32>' are incompatible with return type(s) of operation 'tensor<12xf32>'}} - %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> + %0 = "mhlo.rng"(%arg0, %arg1, %cst) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> func.return } @@ -2032,7 +1984,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{op operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2041,7 +1993,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) - func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{op operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2050,7 +2002,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32> func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} - %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2059,7 +2011,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> // expected-error @+1 {{op operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> + %0 = "mhlo.rng"(%arg0, %arg1, %cst) <{rng_distribution = #mhlo.rng_distribution}>: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> func.return } @@ -2068,7 +2020,7 @@ func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2076,24 +2028,16 @@ func.func @rng_uniform(%a: tensor, %b: tensor) -> tensor<2x3x5xf32> { // CHECK-LABEL: func @rng_uniform_no_constant func.func @rng_uniform_no_constant(%a: tensor, %b: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } // ----- -// CHECK-LABEL: func @rng_uniform_dynamic_dim -func.func @rng_uniform_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %arg2: tensor<7xi64>) { // @expected-error@+2 {{'mhlo.rng' op failed to infer returned types}} // expected-error @+1 {{inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} - %0 = "mhlo.rng"(%arg0, %arg1, %arg2) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<7xi64>) -> tensor + %0 = "mhlo.rng"(%arg0, %arg1, %arg2) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<7xi64>) -> tensor func.return } @@ -2102,7 +2046,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %ar func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{op operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2112,7 +2056,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> ten func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{op operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2121,7 +2065,7 @@ func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> ten func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2130,7 +2074,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> t func.func @rng_uniform_invalid_type(%a: tensor>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{op operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) <{rng_distribution = #mhlo.rng_distribution}>: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2153,17 +2097,17 @@ func.func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: // ----- // CHECK-LABEL: func @select_cast_compatible_types -func.func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg2: tensor<2x3xi32>) -> tensor<*xi32> { - %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> +func.func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x3xi32>) -> tensor { + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<2x3xi32>) -> tensor + func.return %0 : tensor } // ----- // CHECK-LABEL: func @select_cast_compatible_types -func.func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { - %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> +func.func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor) -> tensor { + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor) -> tensor + func.return %0 : tensor } // ----- @@ -2193,7 +2137,7 @@ func.func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tenso // ----- func.func @select_bad_pred_type(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} + // expected-error@+1 {{must be ranked tensor of pred (AKA boolean or 1-bit integer) values}} %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> func.return %0 : tensor<2x3xi32> } @@ -2245,7 +2189,7 @@ func.func @select_element_type_mismatch(%arg0: tensor, %arg1: tensor<2x3xf32 // CHECK-LABEL: func @slice func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + %0 = "mhlo.slice"(%arg0) <{start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>}> : (tensor<3x4xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -2341,17 +2285,9 @@ func.func @slice_i2(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { // ----- -// CHECK-LABEL: func @slice_unranked -func.func @slice_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> -} - -// ----- - // CHECK-LABEL: func @dynamic_slice func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2360,7 +2296,7 @@ func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tens func.func @dynamic_slice_c2(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4]> : tensor<1xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2369,7 +2305,7 @@ func.func @dynamic_slice_c2(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: t func.func @dynamic_slice_c2(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{has mismatched number of start indices (1) and the rank of operand (2)}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = dense<[1]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1) <{slice_sizes = dense<[1]> : tensor<1xi64>}> : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2378,7 +2314,7 @@ func.func @dynamic_slice_c2(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tenso func.func @dynamic_slice_c3(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{start indices must have same element type}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2387,7 +2323,7 @@ func.func @dynamic_slice_c3(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: t func.func @dynamic_slice_c4(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{has negative size index to dynamic slice: -1}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[-1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[-1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2396,7 +2332,7 @@ func.func @dynamic_slice_c4(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: t func.func @dynamic_slice_c4(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{has slice size 10 greater than dimension size 4 in dimension 1 of operand}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 10]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 10]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2405,7 +2341,7 @@ func.func @dynamic_slice_c4(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: t func.func @dynamic_slice_c5(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<2x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor<1x4xi32>' are incompatible with return type(s) of operation 'tensor<2x4xi32>'}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<2x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -2413,7 +2349,7 @@ func.func @dynamic_slice_c5(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: t // CHECK-LABEL: func @dynamic_slice_dynamic_dim func.func @dynamic_slice_dynamic_dim(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2421,8 +2357,8 @@ func.func @dynamic_slice_dynamic_dim(%arg0: tensor, %arg1: tensor, func.func @dynamic_slice_i3(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // @expected-error@+2 {{'mhlo.dynamic_slice' op failed to infer returned types}} - // expected-error@+1 {{slice_sizes should be rank 1, but got rank 0.}} - %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<1> : tensor} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + // expected-error@+1 {{slice_sizes has rank 0 instead of required rank 1.}} + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<1> : tensor}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2490,16 +2426,16 @@ func.func @dynamic_update_slice_dynamic_dim(%operand: tensor, %update: // ----- // CHECK-LABEL: func @dynamic_update_slice_dynamic_rank_operand -func.func @dynamic_update_slice_dynamic_rank_operand(%operand: tensor<*xi64>, %update: tensor<1x4xi64>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<*xi64> { - %0 = "mhlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<*xi64>, tensor<1x4xi64>, tensor, tensor) -> tensor<*xi64> - func.return %0 : tensor<*xi64> +func.func @dynamic_update_slice_dynamic_rank_operand(%operand: tensor, %update: tensor<1x4xi64>, %start_indices0: tensor, %start_indices1: tensor) -> tensor { + %0 = "mhlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor, tensor<1x4xi64>, tensor, tensor) -> tensor + func.return %0 : tensor } // ----- // CHECK-LABEL: func @dynamic_update_slice_dynamic_rank_update -func.func @dynamic_update_slice_dynamic_rank_update(%operand: tensor<3x4xi64>, %update: tensor<*xi64>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4xi64> { - %0 = "mhlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4xi64>, tensor<*xi64>, tensor, tensor) -> tensor<3x4xi64> +func.func @dynamic_update_slice_dynamic_rank_update(%operand: tensor<3x4xi64>, %update: tensor, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4xi64> { + %0 = "mhlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4xi64>, tensor, tensor, tensor) -> tensor<3x4xi64> func.return %0 : tensor<3x4xi64> } @@ -2515,29 +2451,22 @@ func.func @dynamic_update_slice_dynamic_sizes(%operand: tensor, %update // CHECK-LABEL: func @transpose func.func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0: tensor<2x1x4x3xi32> } // ----- func.func @transpose_ranked(%arg0: tensor) -> tensor { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor) -> tensor func.return %0: tensor } // ----- -func.func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> - func.return %0: tensor<*xi32> -} - -// ----- - func.func @transpose_missing_permutation(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{requires attribute 'permutation'}} - %0 = "mhlo.transpose"(%arg0) {} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0: tensor<2x1x4x3xi32> } @@ -2545,8 +2474,8 @@ func.func @transpose_missing_permutation(%arg0: tensor<1x2x3x4xi32>) -> tensor<2 func.func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // @expected-error@+2 {{'mhlo.transpose' op failed to infer returned types}} - // expected-error@+1 {{permutation has rank 2 instead of rank 1}} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // expected-error@+1 {{permutation has rank 2 instead of required rank 1.}} + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[[1]]> : tensor<1x1xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0: tensor<2x1x4x3xi32> } @@ -2555,7 +2484,7 @@ func.func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tenso func.func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // @expected-error@+2 {{'mhlo.transpose' op failed to infer returned types}} // expected-error@+1 {{TransposeOp operand rank 4 does not match permutation size 1}} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1]> : tensor<1xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0: tensor<2x1x4x3xi32> } @@ -2563,8 +2492,8 @@ func.func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tenso func.func @transpose_bad_permutation(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // @expected-error@+2 {{'mhlo.transpose' op failed to infer returned types}} - // expected-error@+1 {{attribute permutation must be a permutation of [0, 1, 2, 3] but got dense<[1, 0, 3, 9]> : tensor<4xi64>}} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 9]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // expected-error@+1 {{attribute permutation must be a permutation of [0, 1, 2, 3] but got 1, 0, 3, 9}} + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 9]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0: tensor<2x1x4x3xi32> } @@ -2573,7 +2502,7 @@ func.func @transpose_bad_permutation(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x func.func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> tensor<2xi32> { // @expected-error@+2 {{'mhlo.transpose' op failed to infer returned types}} // expected-error@+1 {{op inferred type(s) 'tensor<2x1x4x3xi32>' are incompatible with return type(s) of operation 'tensor<2xi32>'}} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2xi32> func.return %0: tensor<2xi32> } @@ -2582,7 +2511,7 @@ func.func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> func.func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>) -> tensor { // @expected-error@+2 {{'mhlo.transpose' op failed to infer returned types}} // expected-error@+1 {{op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor}} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x?x3x?xi32>) -> tensor func.return %0: tensor } @@ -2620,30 +2549,6 @@ func.func @triangular_solve_dynamic_dims_batch(%arg0: tensor, %arg1 // ----- -// CHECK-LABEL: func @triangular_solve_unranked -func.func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: func @triangular_solve_a_is_unranked -func.func @triangular_solve_a_is_unranked(%arg0: tensor<*xf32>, %arg1: tensor<4x4xf32>) -> tensor<*xf32> { - %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<*xf32>, tensor<4x4xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: func @triangular_solve_b_is_unranked -func.func @triangular_solve_b_is_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<4x4xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { // @expected-error@+2 {{'mhlo.triangular_solve' op failed to infer returned types}} // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}} @@ -2796,7 +2701,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // expected-error@+1 {{op operand #0 must be tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}} + // expected-error@+1 {{op operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}} %0 = "mhlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -2865,19 +2770,6 @@ func.func @sort_no_operands() { // ----- -// CHECK-LABEL: func @sort_unknown_rank -func.func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0:2 = "mhlo.sort"(%input0, %input1) ({ - ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - %8 = "mhlo.select"(%7, %7, %7) : (tensor, tensor, tensor) -> tensor<*xi1> - "mhlo.return"(%8) : (tensor<*xi1>) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - func.return -} - -// ----- - func.func @sort_dynamism(%input0: tensor, %input1: tensor<16x16xi32>) { %0:2 = "mhlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): @@ -2889,18 +2781,6 @@ func.func @sort_dynamism(%input0: tensor, %input1: tensor<16x16xi32>) // ----- -func.func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0:2 = "mhlo.sort"(%input0, %input1) ({ - ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - func.return -} - -// ----- - func.func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op requires the same shape for all operands and results}} %0:2 = "mhlo.sort"(%input0, %input1) ({ @@ -3050,12 +2930,12 @@ func.func @reverse_c2(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { // ----- -func.func @reverse_c3(%operand: tensor<*xi32>) -> tensor<*xi32> { +func.func @reverse_c3(%operand: tensor) -> tensor { // expected-error @+1 {{all dimensions should be non-negative. Got dimension: -1.}} %0 = "mhlo.reverse"(%operand) { dimensions = dense<-1> : tensor<1xi64> - } : (tensor<*xi32>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> + } : (tensor) -> tensor + func.return %0 : tensor } // ----- @@ -3119,49 +2999,6 @@ func.func @dot_general(%arg0: tensor<1x?x1x?xf32>, %arg1: tensor) // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1] - > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return -} - -// ----- - -func.func @dot_general(%arg0: tensor, %arg1: tensor<*xf32>) { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1] - > - } : (tensor, tensor<*xf32>) -> tensor - func.return -} - - -// ----- - -func.func @dot_general(%arg0: tensor, %arg1: tensor<*xf32>) { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1] - > - } : (tensor, tensor<*xf32>) -> tensor - func.return -} - -// ----- - func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { @@ -3177,7 +3014,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3186,7 +3023,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3207,7 +3044,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3216,7 +3053,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [], rhs_contracting_dimensions = [1] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3237,7 +3074,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3246,7 +3083,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3267,7 +3104,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3276,7 +3113,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [1, 1], rhs_contracting_dimensions = [1, 1] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3297,7 +3134,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3306,7 +3143,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [1] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3327,7 +3164,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- -func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { +func.func @dot_general(%arg0: tensor, %arg1: tensor) { // expected-error @+1 {{has duplicated dimension from rhs_batching_dimensions and rhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -3336,7 +3173,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0] > - } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + } : (tensor, tensor) -> tensor func.return } @@ -3540,6 +3377,70 @@ func.func @dot_general_three_element_precision_config(%arg0: tensor<2x3x4xf32>, // ----- +// CHECK-LABEL: func @sparse_dot +func.func @sparse_dot(%arg0: tensor<2x16xf32>, %arg1: tensor<32x2xf32>, %meta: tensor<2x2xi16>) -> tensor<2x2xf32> { + %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { + lhs_sparsity = #mhlo.sparsity, + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x16xf32>, tensor<32x2xf32>, tensor<2x2xi16>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @sparse_dot_incorrect_dimension(%arg0: tensor<2x16xf32>, %arg1: tensor<32x2xf32>, %meta: tensor<2x2xi16>) -> tensor<2x2xf32> { + // expected-error@+1 {{sparsity dimension is incorrect}} + %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { + lhs_sparsity = #mhlo.sparsity, + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x16xf32>, tensor<32x2xf32>, tensor<2x2xi16>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @sparse_dot_incorrect_dimension(%arg0: tensor<2x16xf32>, %arg1: tensor<32x2xf32>, %meta: tensor<2x2xi16>) -> tensor<2x2xf32> { + // expected-error@+1 {{only 2:4 sparsity is supported}} + %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { + lhs_sparsity = #mhlo.sparsity, + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x16xf32>, tensor<32x2xf32>, tensor<2x2xi16>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @sparse_dot(%arg0: tensor<2x32xf32>, %arg1: tensor<32x2xf32>, %meta: tensor<2x2xi16>) -> tensor<2x2xf32> { + // expected-error@+1 {{contracting dimension sizes must match for lhs/rhs}} + %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { + lhs_sparsity = #mhlo.sparsity, + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x32xf32>, tensor<32x2xf32>, tensor<2x2xi16>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + func.func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor @@ -3606,9 +3507,9 @@ func.func @bitcast_convert_scalar(%arg: tensor) -> tensor { // ----- -func.func @bitcast_convert(%arg: tensor<*xf32>) -> tensor<*xf32> { - %0 = "mhlo.bitcast_convert"(%arg) : (tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> +func.func @bitcast_convert(%arg: tensor) -> tensor { + %0 = "mhlo.bitcast_convert"(%arg) : (tensor) -> tensor + return %0 : tensor } // ----- @@ -3645,7 +3546,7 @@ func.func @stochastic_convert(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xui32>) - // ----- func.func @invalid_stochastic_convert_disallowed_random_type(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi8> { - // expected-error@+1 {{must be tensor of 4/8/16/32/64-bit unsigned integer values, but got 'tensor<2x4xi32>'}} + // expected-error@+1 {{must be ranked tensor of 4/8/16/32/64-bit unsigned integer values, but got 'tensor<2x4xi32>'}} %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2x4xi32>) -> tensor<2x4xi8> return %0 : tensor<2x4xi8> } @@ -3700,7 +3601,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 // ----- // CHECK: gather -func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -3710,14 +3611,14 @@ func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<1x5x2xi32>) >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor<*xi32>, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> + } : (tensor, tensor<1x5x2xi32>) -> tensor<1x5x8xi32> func.return %res : tensor<1x5x8xi32> } // ----- // CHECK: gather -func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<*xi32>) -> tensor<1x5x8xi32> { +func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor) -> tensor<1x5x8xi32> { %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -3727,13 +3628,13 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<*xi32>) >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor<2x4x9xi32>, tensor<*xi32>) -> tensor<1x5x8xi32> + } : (tensor<2x4x9xi32>, tensor) -> tensor<1x5x8xi32> func.return %res : tensor<1x5x8xi32> } // ----- -func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> tensor<*xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor { %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -3743,8 +3644,8 @@ func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> t >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor) -> tensor + func.return %res : tensor } // ----- @@ -3805,7 +3706,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} - // expected-error@+1 {{slice_sizes.rank != 1}} + // expected-error@+1 {{slice_sizes has rank 2 instead of required rank 1}} %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -3839,7 +3740,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 // ----- -func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> tensor<*xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} // expected-error@+1 {{slice_sizes size (6) not equal to (implied) operand rank (3)}} %res = "mhlo.gather"(%operand, %start_indices) { @@ -3851,8 +3752,8 @@ func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> t >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8, 1, 2, 3]> : tensor<6xi64> - } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor) -> tensor + func.return %res : tensor } // ----- @@ -3875,7 +3776,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 // ----- -func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor) -> tensor<3xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor<3xi32> { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor<8x?x7x1x6x1x?xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}} %res = "mhlo.gather"(%operand, %start_indices) { @@ -3887,13 +3788,13 @@ func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor) >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8, 1, 7, 1, 6, 1]> : tensor<8xi64> - } : (tensor<*xi32>, tensor) -> tensor<3xi32> + } : (tensor, tensor) -> tensor<3xi32> func.return %res : tensor<3xi32> } // ----- -func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> tensor<*xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} // expected-error@+1 {{slice_sizes collapsed dimension 2 should <= 1 but got 8}} %res = "mhlo.gather"(%operand, %start_indices) { @@ -3905,8 +3806,8 @@ func.func @gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>) -> t >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor) -> tensor + func.return %res : tensor } // ----- @@ -3947,7 +3848,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3 // ----- -func.func @gather(%operand : tensor, %start_indices : tensor<*xi32>) -> tensor<*xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} // expected-error@+1 {{slice size (-1) is out of bounds for operand dimension (2) at index 2}} %res = "mhlo.gather"(%operand, %start_indices) { @@ -3959,13 +3860,13 @@ func.func @gather(%operand : tensor, %start_indices : tensor<*xi32>) >, indices_are_sorted = false, slice_sizes = dense<[1, 1, -1]> : tensor<3xi64> - } : (tensor, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @gather(%operand : tensor, %start_indices : tensor<*xi32>) -> tensor<*xi32> { +func.func @gather(%operand : tensor, %start_indices : tensor) -> tensor { // @expected-error@+2 {{'mhlo.gather' op failed to infer returned types}} // expected-error@+1 {{slice size (8) is out of bounds for operand dimension (2) at index 2}} %res = "mhlo.gather"(%operand, %start_indices) { @@ -3977,8 +3878,8 @@ func.func @gather(%operand : tensor, %start_indices : tensor<*xi32>) >, indices_are_sorted = false, slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor) -> tensor + func.return %res : tensor } // ----- @@ -4160,7 +4061,7 @@ func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor< // ----- -func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>, %slice_sizes : tensor<*xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -4169,13 +4070,13 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi3 start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<*xi32>, %slice_sizes : tensor<*xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor, %slice_sizes : tensor) -> tensor { %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], @@ -4184,14 +4085,13 @@ func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor< start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor<2x4x9xi32>, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor<2x4x9xi32>, tensor, tensor) -> tensor + func.return %res : tensor } - // ----- -func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, %slice_sizes : tensor<*xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{index_vector_dim 4 is out of bounds for start indices with rank 3}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4202,13 +4102,13 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, indices_are_sorted = false - } : (tensor<*xi32>, tensor, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @dynamic_gather(%operand : tensor, %start_indices : tensor<*xi32>, %slice_sizes : tensor<*xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{offset_dims size (2) plus collapse_slice_dims size (2) is not equal to operand rank (3)}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4219,13 +4119,13 @@ func.func @dynamic_gather(%operand : tensor, %start_indices : tensor< start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, %slice_sizes : tensor<*xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{start_index_map size (1) is not equal to size of index dimension (2) of start_indices (2)}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4236,13 +4136,13 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, indices_are_sorted = false - } : (tensor<*xi32>, tensor, tensor<*xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi32>, %slice_sizes : tensor) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{slice_sizes.rank != 1}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4253,13 +4153,13 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor<*xi3 start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor<*xi32>, tensor<*xi32>, tensor) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor) -> tensor + func.return %res : tensor } // ----- -func.func @dynamic_gather(%operand : tensor, %start_indices : tensor<*xi32>, %slice_sizes : tensor<2xi32>) -> tensor<*xi32> { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor<2xi32>) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{slice_sizes size (2) not equal to (implied) operand rank (3)}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4270,8 +4170,8 @@ func.func @dynamic_gather(%operand : tensor, %start_indices : tensor< start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor, tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32> - func.return %res : tensor<*xi32> + } : (tensor, tensor, tensor<2xi32>) -> tensor + func.return %res : tensor } // ----- @@ -4293,7 +4193,7 @@ func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor< // ----- -func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi32>, %slice_sizes : tensor<*xi32>) -> tensor<3xi32> { +func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi32>, %slice_sizes : tensor) -> tensor<3xi32> { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor<1x5x?xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4304,7 +4204,7 @@ func.func @dynamic_gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor< start_index_map = [0, 1] >, indices_are_sorted = false - } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor<*xi32>) -> tensor<3xi32> + } : (tensor<2x4x9xi32>, tensor<1x5x2xi32>, tensor) -> tensor<3xi32> func.return %res : tensor<3xi32> } @@ -4327,7 +4227,7 @@ func.func @dynamic_gather(%operand : tensor, %start_indices : tensor< // ----- -func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, %slice_sizes : tensor) -> tensor { +func.func @dynamic_gather(%operand : tensor, %start_indices : tensor, %slice_sizes : tensor) -> tensor { // @expected-error@+2 {{'mhlo.dynamic_gather' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} %res = "mhlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) { @@ -4338,7 +4238,7 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor, indices_are_sorted = false - } : (tensor<*xi32>, tensor, tensor) -> tensor + } : (tensor, tensor, tensor) -> tensor func.return %res : tensor } @@ -4347,14 +4247,14 @@ func.func @dynamic_gather(%operand : tensor<*xi32>, %start_indices : tensor) -> tensor { // @expected-error@+2 {{'mhlo.get_dimension_size' op failed to infer returned types}} // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}} - %size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = 3 : i64}> : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor } // ----- func.func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { - %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor } @@ -4363,7 +4263,7 @@ func.func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { func.func @get_dimension_size_negative_dimension(%I: tensor<1x128x512xf32>) -> tensor { // @expected-error@+2 {{'mhlo.get_dimension_size' op failed to infer returned types}} // expected-error@+1 {{requires non-negative dimension attribute; found (-1)}} - %size = "mhlo.get_dimension_size"(%I) {dimension = -1 : i64} : (tensor<1x128x512xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = -1 : i64}> : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor } @@ -4372,7 +4272,7 @@ func.func @get_dimension_size_negative_dimension(%I: tensor<1x128x512xf32>) -> t func.func @get_dimension_size_invalid_dimension(%I: tensor<1x128x512xf32>) -> tensor { // @expected-error@+2 {{'mhlo.get_dimension_size' op failed to infer returned types}} // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}} - %size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor + %size = "mhlo.get_dimension_size"(%I) <{dimension = 3 : i64}> : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor } @@ -5411,7 +5311,7 @@ func.func @xla.rng_get_and_update_state() -> tensor<2xui64> { // CHECK-LABEL: @fft func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xcomplex> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5419,7 +5319,7 @@ func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xcomplex> { // CHECK-LABEL: @ifft func.func @ifft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xcomplex> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5427,7 +5327,7 @@ func.func @ifft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xcomplex> { // CHECK-LABEL: @rfft func.func @rfft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x5xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x5xcomplex> func.return %0 : tensor<3x5xcomplex> } @@ -5435,24 +5335,16 @@ func.func @rfft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { // CHECK-LABEL: @irfft func.func @irfft(%arg0: tensor<3x9xcomplex>) -> tensor<3x16xf32> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x16xf32> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x16xf32> func.return %0 : tensor<3x16xf32> } // ----- -// CHECK-LABEL: @rfft_unranked -func.func @rfft_unranked(%arg0: tensor<*xf32>) -> tensor<*xcomplex> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<*xf32>) -> tensor<*xcomplex> - func.return %0 : tensor<*xcomplex> -} - -// ----- - func.func @rfft_not_float32or64(%arg0: tensor<3x9xf16>) -> tensor<3x5xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{RFFT requires f32 or f64 input type, but is given 'f16'.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xf16>) -> tensor<3x5xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xf16>) -> tensor<3x5xcomplex> func.return %0 : tensor<3x5xcomplex> } @@ -5461,7 +5353,7 @@ func.func @rfft_not_float32or64(%arg0: tensor<3x9xf16>) -> tensor<3x5xcomplex) -> tensor<3x9xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{rank must be between 1 and 3, but got 4.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<4xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<4xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5470,7 +5362,7 @@ func.func @fft_invalid_rank(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> func.func @fft_rank_mismatch(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{operand rank must not be less than fft rank of 3 for operand of type 'tensor<3x9xf32>'}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<3xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<3xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5479,7 +5371,7 @@ func.func @fft_rank_mismatch(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> func.func @rfft_invalid_dim(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{RFFT requires innermost dimensions to be compatible with fft_length. Got: 3, 9 but wanted 9, 9.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5488,7 +5380,7 @@ func.func @rfft_invalid_dim(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{IRFFT requires non-final dimensions to be compatible with fft_length. Got: 3, 9 but wanted 9, 9, and 3 != 9.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xf32> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xf32> func.return %0 : tensor<3x9xf32> } @@ -5497,7 +5389,7 @@ func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{IRFFT requires innermost dimension to be compatible with fft_length[-1]/2+1. Got: 9 but fft_length is 9.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xf32> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xf32> func.return %0 : tensor<3x9xf32> } @@ -5506,7 +5398,7 @@ func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> func.func @irfft_invalid_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{FFT/IFFT/IRFFT take a complex tensor as input, but is given 'tensor<3x9xf32>'}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -5515,7 +5407,7 @@ func.func @irfft_invalid_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> func.func @irfft_invalid_ret_elt(%arg0: tensor<3x9xcomplex>) -> tensor<3x16xcomplex> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor<3x16xf32>' are incompatible with return type(s) of operation 'tensor<3x16xcomplex>'}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x16xcomplex> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x16xcomplex> func.return %0 : tensor<3x16xcomplex> } @@ -5524,7 +5416,7 @@ func.func @irfft_invalid_ret_elt(%arg0: tensor<3x9xcomplex>) -> tensor<3x16 func.func @rfft_invalid_ret_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xf32> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor<3x5xcomplex>' are incompatible with return type(s) of operation 'tensor<3x9xf32>'}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xf32> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xf32>) -> tensor<3x9xf32> func.return %0 : tensor<3x9xf32> } @@ -5532,7 +5424,7 @@ func.func @rfft_invalid_ret_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xf32> { // CHECK-LABEL: @rfft_dynamic func.func @rfft_dynamic(%arg0: tensor) -> tensor> { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor) -> tensor> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor) -> tensor> func.return %0 : tensor> } @@ -5541,7 +5433,7 @@ func.func @rfft_dynamic(%arg0: tensor) -> tensor> { func.func @rfft_dynamic_incompatible_dims(%arg0: tensor<3x10xf32>) -> tensor> { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1{{RFFT requires innermost dimensions to be compatible with fft_length. Got: 3, 10 but wanted 9.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x10xf32>) -> tensor> + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x10xf32>) -> tensor> func.return %0 : tensor> } @@ -5549,7 +5441,7 @@ func.func @rfft_dynamic_incompatible_dims(%arg0: tensor<3x10xf32>) -> tensor>) -> tensor { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor>) -> tensor func.return %0 : tensor } @@ -5558,7 +5450,7 @@ func.func @irfft_dynamic(%arg0: tensor>) -> tensor { func.func @irfft_dynamic_incompatible_non_final_dims(%arg0: tensor>) -> tensor { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1{{IRFFT requires non-final dimensions to be compatible with fft_length. Got: -9223372036854775808, 3, 15 but wanted 4, 16, and 3 != 4}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<[4, 16]> : tensor<2xi64>, fft_type = #mhlo } : (tensor>) -> tensor + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<[4, 16]> : tensor<2xi64>, fft_type = #mhlo }> : (tensor>) -> tensor func.return %0 : tensor } @@ -5567,7 +5459,7 @@ func.func @irfft_dynamic_incompatible_non_final_dims(%arg0: tensor>) -> tensor { // @expected-error@+2 {{'mhlo.fft' op failed to infer returned types}} // expected-error@+1{{IRFFT requires innermost dimension to be compatible with fft_length[-1]/2+1. Got: 8 but fft_length is 16.}} - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor>) -> tensor func.return %0 : tensor } @@ -5575,7 +5467,7 @@ func.func @irfft_dynamic_incompatible_final_dim(%arg0: tensor>) // CHECK-LABEL: @irfft_dynamic func.func @irfft_dynamic(%arg0: tensor>) -> tensor { - %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo }> : (tensor>) -> tensor func.return %0 : tensor } @@ -5718,14 +5610,6 @@ func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform // ----- -// CHECK: func @uniform_dequantize_unranked -func.func @uniform_dequantize_unranked(%arg: tensor<*x!quant.uniform>) -> tensor<*xf32> { - %0 = mhlo.uniform_dequantize %arg : (tensor<*x!quant.uniform>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: func @quantized_constants func.func @quantized_constants() -> (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) { %0 = mhlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform> @@ -5734,7 +5618,7 @@ func.func @quantized_constants() -> (tensor<2x!quant.uniform>, t %3 = mhlo.uniform_quantize %2 : (tensor<2xf32>) -> tensor<2x!quant.uniform> %4 = mhlo.uniform_quantize %1 : (tensor<2xf32>) -> tensor<2x!quant.uniform> func.return %0, %4, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform> - // CHECK: mhlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform> + // CHECK: mhlo.constant() <{value = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2x!quant.uniform> // CHECK-NEXT: mhlo.constant dense<[1.000000e+01, 1.200000e+01]> : tensor<2xf32> // CHECK-NEXT: mhlo.constant dense<[3.000000e+00, 1.000000e+02]> : tensor<2xf32> } @@ -5767,7 +5651,7 @@ func.func @dot_i8xi8_i16(%arg0: tensor<1x2xi8>, %arg1: tensor<2x1xi8>) -> tensor // CHECK-LABEL: func @einsum_i4xi4_i8 func.func @einsum_i4xi4_i8(%arg0: tensor<1x2xi4>, %arg1: tensor<2x1xi4>) -> tensor<1x1xi8> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<1x2xi4>, tensor<2x1xi4>) -> tensor<1x1xi8> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ab,bc->ac"}> : (tensor<1x2xi4>, tensor<2x1xi4>) -> tensor<1x1xi8> func.return %0: tensor<1x1xi8> } @@ -5775,7 +5659,7 @@ func.func @einsum_i4xi4_i8(%arg0: tensor<1x2xi4>, %arg1: tensor<2x1xi4>) -> tens // CHECK-LABEL: func @einsum_i8xi8_i16 func.func @einsum_i8xi8_i16(%arg0: tensor<1x2xi8>, %arg1: tensor<2x1xi8>) -> tensor<1x1xi16> { - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<1x2xi8>, tensor<2x1xi8>) -> tensor<1x1xi16> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ab,bc->ac"}> : (tensor<1x2xi8>, tensor<2x1xi8>) -> tensor<1x1xi16> func.return %0: tensor<1x1xi16> } @@ -5933,47 +5817,6 @@ func.func @is_compatible_dynamism_mix(%arg0: tensor, %arg1: tensor<1xf32> func.return } -// TODO(b/231448733): verifyCompatibleShape allows rankedness mismatches but Elemementwise doesn't. -// Sort this out while refactoring uses of SameOperandsAndResultType and friends. -// func.func @is_compatible_dynamism_mix(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor<1xf32>) { -// %0 = "mhlo.add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> -// %1 = "mhlo.add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor -// %2 = "mhlo.add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<1xf32> -// %3 = "mhlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// %4 = "mhlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor) -> tensor -// %5 = "mhlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor) -> tensor<1xf32> -// %6 = "mhlo.add"(%arg0, %arg2) : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32> -// %7 = "mhlo.add"(%arg0, %arg2) : (tensor<*xf32>, tensor<1xf32>) -> tensor -// %8 = "mhlo.add"(%arg0, %arg2) : (tensor<*xf32>, tensor<1xf32>) -> tensor<1xf32> -// %9 = "mhlo.add"(%arg1, %arg0) : (tensor, tensor<*xf32>) -> tensor<*xf32> -// %10 = "mhlo.add"(%arg1, %arg0) : (tensor, tensor<*xf32>) -> tensor -// %11 = "mhlo.add"(%arg1, %arg0) : (tensor, tensor<*xf32>) -> tensor<1xf32> -// %12 = "mhlo.add"(%arg1, %arg1) : (tensor, tensor) -> tensor<*xf32> -// %13 = "mhlo.add"(%arg1, %arg1) : (tensor, tensor) -> tensor -// %14 = "mhlo.add"(%arg1, %arg1) : (tensor, tensor) -> tensor<1xf32> -// %15 = "mhlo.add"(%arg1, %arg2) : (tensor, tensor<1xf32>) -> tensor<*xf32> -// %16 = "mhlo.add"(%arg1, %arg2) : (tensor, tensor<1xf32>) -> tensor -// %17 = "mhlo.add"(%arg1, %arg2) : (tensor, tensor<1xf32>) -> tensor<1xf32> -// %18 = "mhlo.add"(%arg2, %arg0) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// %19 = "mhlo.add"(%arg2, %arg0) : (tensor<1xf32>, tensor<*xf32>) -> tensor -// %20 = "mhlo.add"(%arg2, %arg0) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32> -// %21 = "mhlo.add"(%arg2, %arg1) : (tensor<1xf32>, tensor) -> tensor<*xf32> -// %22 = "mhlo.add"(%arg2, %arg1) : (tensor<1xf32>, tensor) -> tensor -// %23 = "mhlo.add"(%arg2, %arg1) : (tensor<1xf32>, tensor) -> tensor<1xf32> -// %24 = "mhlo.add"(%arg2, %arg2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> -// %25 = "mhlo.add"(%arg2, %arg2) : (tensor<1xf32>, tensor<1xf32>) -> tensor -// %26 = "mhlo.add"(%arg2, %arg2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// func.return -// } - -// ----- - -func.func @is_compatible_dynamism_rankedness_mismatch(%arg0: tensor<*xf32>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} - %0 = "mhlo.add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<1xf32> - func.return -} - // ----- func.func @is_compatible_dynamism_ranked_mismatch(%arg0: tensor) { @@ -5992,16 +5835,10 @@ func.func @is_compatible_dynamism_dim_mismatch(%arg0: tensor<1x?xf32>) { // ----- -// TODO(b/230263270): For mhlo.add, the plan is to only allow fp+fp=fp, q+q=q and q+q=fp. func.func @is_compatible_quant_mix_non_quant(%arg0: tensor<1xf32>, %arg1: tensor<1x!quant.uniform>) { %0 = "mhlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %1 = "mhlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1x!quant.uniform> - %2 = "mhlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> - %3 = "mhlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> - %4 = "mhlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform>, tensor<1xf32>) -> tensor<1xf32> - %5 = "mhlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform>, tensor<1xf32>) -> tensor<1xf32> - %6 = "mhlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> - %7 = "mhlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> + %1 = "mhlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> + %2 = "mhlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> func.return } @@ -6097,10 +5934,10 @@ func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>, %arg1: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) { - %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,3]> : tensor<3xi64>} : (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.5:-20}>> - %1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<2x2x2x!quant.uniform:f32:0, {0.1:-30, 0.1:-30}>> - %2 = stablehlo.reshape %arg0 : (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> - %3 = "stablehlo.transpose"(%arg0) {permutation = dense<[0,2,1]> : tensor<3xi64>}: (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0,1,3]> : tensor<3xi64>}> : (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform:f32:3, {0.1:-30, 0.5:-20}>> + %1 = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>}> : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<2x2x2x!quant.uniform:f32:0, {0.1:-30, 0.1:-30}>> + %2 = mhlo.reshape %arg0 : (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> + %3 = "mhlo.transpose"(%arg0) <{permutation = dense<[0,2,1]> : tensor<3xi64>}> : (tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform:f32:1, {0.1:-30, 0.5:-20}>> func.return } @@ -6216,7 +6053,7 @@ func.func @complex(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>) -> tensor // ----- func.func @complex_int_input(%arg0: tensor<10x10xi32>, %arg1: tensor<10x10xi32>) -> tensor<10x10xcomplex> { - // expected-error@+1 {{operand #0 must be tensor of 32-bit float or 64-bit float values, but got 'tensor<10x10xi32>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 32-bit float or 64-bit float values, but got 'tensor<10x10xi32>'}} %0 = "mhlo.complex"(%arg0, %arg1) {} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xcomplex> func.return %0 : tensor<10x10xcomplex> } @@ -6232,7 +6069,7 @@ func.func @complex_f32_f64_mix_input(%arg0: tensor<10x10xf32>, %arg1: tensor<10x // ----- func.func @complex_f16_input(%arg0: tensor<10x10xf16>, %arg1: tensor<10x10xf16>) -> tensor<10x10xcomplex> { - // expected-error@+1 {{operand #0 must be tensor of 32-bit float or 64-bit float values, but got 'tensor<10x10xf16>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 32-bit float or 64-bit float values, but got 'tensor<10x10xf16>'}} %0 = "mhlo.complex"(%arg0, %arg1) {} : (tensor<10x10xf16>, tensor<10x10xf16>) -> tensor<10x10xcomplex> func.return %0 : tensor<10x10xcomplex> } @@ -6282,7 +6119,7 @@ func.func @async_op(%arg0: tensor<10x10xf32>) -> tensor<32xf32> } func.func @async(%arg0: tensor<10x10xf32>) -> tensor<32xf32> { - // expected-error@+1 {{component #0 of return type doesn't match callee input types}} + // expected-error@+1 {{component #0 of async bundle doesn't match callee input types}} %0 = "mhlo.async_start"(%arg0) {called_computation=@async_op, execution_thread="thread"} : (tensor<10x10xf32>) -> !mhlo.async_bundle, tensor<32xf32>, tensor> %1 = "mhlo.async_update"(%0) {called_computation=@async_op, execution_thread="thread"} : (!mhlo.async_bundle, tensor<32xf32>, tensor>) -> !mhlo.async_bundle, tensor<32xf32>, tensor> %2 = "mhlo.async_done"(%1) {called_computation=@async_op, execution_thread="thread"} : (!mhlo.async_bundle, tensor<32xf32>, tensor>) -> tensor<32xf32> @@ -6298,7 +6135,7 @@ func.func @async_op(%arg0: tensor<10x10xf32>) -> tensor<32xf32> } func.func @async(%arg0: tensor<10x10xf32>) -> tensor<32xf32> { - // expected-error@+1 {{component #1 of return type doesn't match callee result types}} + // expected-error@+1 {{component #1 of async bundle doesn't match callee result types}} %0 = "mhlo.async_start"(%arg0) {called_computation=@async_op, execution_thread="thread"} : (tensor<10x10xf32>) -> !mhlo.async_bundle, tensor, tensor> %1 = "mhlo.async_update"(%0) {called_computation=@async_op, execution_thread="thread"} : (!mhlo.async_bundle, tensor, tensor>) -> !mhlo.async_bundle, tensor, tensor> %2 = "mhlo.async_done"(%1) {called_computation=@async_op, execution_thread="thread"} : (!mhlo.async_bundle, tensor, tensor>) -> tensor<32xf32> @@ -6343,7 +6180,7 @@ func.func @async_op(%arg0: tensor<10x10xf32>) -> tensor<32xf32> } func.func @async(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { - // expected-error@+1 {{result is expected to be a bundle of at least 2 components, but got 1}} + // expected-error@+1 {{bundle is expected to have at least 2 components, but got 1}} %0 = "mhlo.async_start"(%arg0) {called_computation=@async_op, execution_thread="thread"} : (tensor<10x10xf32>) -> !mhlo.async_bundle> func.return %arg0 : tensor<10x10xf32> } @@ -6462,7 +6299,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> { // ----- func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> { - // expected-error@+1 {{op operand #0 must be tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}} + // expected-error@+1 {{op operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}} %0 = "mhlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1> func.return %0 : tensor<3xi1> } @@ -6470,7 +6307,7 @@ func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> { // ----- func.func @is_finite_mismatch_return_element_type(%arg0: tensor<3xf32>) -> tensor<3xi10> { - // expected-error@+1 {{result #0 must be tensor of pred (AKA boolean or 1-bit integer) values, but got 'tensor<3xi10>'}} + // expected-error@+1 {{result #0 must be ranked tensor of pred (AKA boolean or 1-bit integer) values, but got 'tensor<3xi10>'}} %0 = "mhlo.is_finite"(%arg0) {} : (tensor<3xf32>) -> tensor<3xi10> func.return %0 : tensor<3xi10> } @@ -6485,20 +6322,20 @@ func.func @is_finite_mismatch_return_shape(%arg0: tensor<3xf32>) -> tensor<4xi1> // ----- -func.func @negative_dimension_attr(%arg0: tensor>, %arg1: tensor) -> tensor<*xf32> { +func.func @negative_dimension_attr(%arg0: tensor>, %arg1: tensor) -> tensor { // expected-error@+2 {{'mhlo.set_dimension_size' op failed to infer returned types}} // expected-error@+1 {{requires non-negative dimension attribute; found (-1)}} - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = -1 : i64} : (tensor>, tensor) -> tensor<*xf32> - func.return %result : tensor<*xf32> + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = -1 : i64} : (tensor>, tensor) -> tensor + func.return %result : tensor } // ----- -func.func @invalid_dimension_attr(%arg0: tensor>, %arg1: tensor) -> tensor<*xf32> { +func.func @invalid_dimension_attr(%arg0: tensor>, %arg1: tensor) -> tensor { // expected-error@+2 {{'mhlo.set_dimension_size' op failed to infer returned types}} // expected-error@+1 {{requires dimension attribute in range [0, 2); found (2)}} - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 2 : i64} : (tensor>, tensor) -> tensor<*xf32> - func.return %result : tensor<*xf32> + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 2 : i64} : (tensor>, tensor) -> tensor + func.return %result : tensor } // ----- @@ -6545,13 +6382,6 @@ func.func @top_k_bounded(%arg0 : tensor) { - %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<*xf32> -> (tensor<*xf32>, tensor<*xi32>) - return -} - -// ----- - func.func @top_k_1d_false(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=false) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return @@ -6581,3 +6411,162 @@ func.func @topk_last_dimension_at_least_k(%arg0 : tensor<4xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<4xf32> -> (tensor<8xf32>, tensor<8xi32>) return } + +// ----- + +func.func @first(%arg0: tensor, %arg1: tensor) -> tensor { + func.return %arg0 : tensor +} + +func.func @composite_generic(%arg0: tensor, %arg1: tensor) { + %0 = "mhlo.composite"(%arg0, %arg1) { + name = "mhlo.first", + decomposition = @first, + version = 1 : i32, + composite_attributes = { + an_attribute = "foo" + } + } : (tensor, tensor) -> tensor + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "foo" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "." { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "foo." { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite ".foo" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "0.foo" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "foo.%" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // @expected-error@+1 {{name must be a valid namespaced op name}} + mhlo.composite "foo.foo.%" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @foo() { func.return } +func.func @composite_c1() { + // valid name + mhlo.composite "f00._.$" { decomposition = @foo } : () -> () + func.return +} + +// ----- + +func.func @composite_c2(%arg0: tensor) { + // @expected-error@+1 {{'nonexistent' does not reference a valid function}} + %0 = mhlo.composite "mhlo.nonexistent" %arg0 { + decomposition = @nonexistent + } : (tensor) -> tensor + func.return +} + +// ----- + +func.func @foo() -> !mhlo.token { + %0 = mhlo.create_token : !mhlo.token + func.return %0 : !mhlo.token +} + +func.func @composite_c3(%arg0: tensor) { + // @expected-error@+1 {{has 1 operand(s), but decomposition has 0}} + %0 = mhlo.composite "mhlo.identity" %arg0 { + decomposition = @foo + } : (tensor) -> !mhlo.token + func.return +} + +// ----- + +func.func @foo(%arg0: tensor) -> !mhlo.token { + %0 = mhlo.create_token : !mhlo.token + func.return %0 : !mhlo.token +} + +func.func @composite_c3(%arg0: tensor) { + // @expected-error@+1 {{operand at index 0 has type 'tensor', but decomposition has type 'tensor'}} + %0 = mhlo.composite "mhlo.identity" %arg0 { + decomposition = @foo + } : (tensor) -> !mhlo.token + func.return +} + +// ----- + +func.func @foo(%arg0: !mhlo.token) { + func.return +} + +func.func @composite_c4(%arg0: !mhlo.token) { + // @expected-error@+1 {{has 1 result(s), but decomposition has 0}} + %0 = mhlo.composite "mhlo.identity" %arg0 { + decomposition = @foo + } : (!mhlo.token) -> tensor + func.return +} + +// ----- + +func.func @foo(%arg0: !mhlo.token) -> tensor { + %0 = mhlo.constant dense<0.> : tensor + func.return %0 : tensor +} + +func.func @composite_c4(%arg0: !mhlo.token) { + // @expected-error@+1 {{result at index 0 has type 'tensor', but decomposition has type 'tensor'}} + %0 = mhlo.composite "mhlo.identity" %arg0 { + decomposition = @foo + } : (!mhlo.token) -> tensor + func.return +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir index f002d87fde84c..2bd6f76e7ad1e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @gather_is_slice_no_rank func.func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor) -> tensor<1x2xi32> { // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor - // CHECK: [[SLICE:%.+]] = "mhlo.dynamic_slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} + // CHECK: [[SLICE:%.+]] = "mhlo.dynamic_slice"(%arg0, %arg1, [[CST]], [[CST]]) <{slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}> // CHECK: [[RESHAPE:%.+]] = mhlo.reshape [[SLICE]] %res = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< @@ -23,7 +23,7 @@ func.func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor) func.func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> { // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor // CHECK: [[RESHAPE:%.+]] = mhlo.reshape %arg1 - // CHECK: [[SLICE:%.+]] = "mhlo.dynamic_slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} + // CHECK: [[SLICE:%.+]] = "mhlo.dynamic_slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) <{slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}> // CHECK: [[RES:%.+]] = mhlo.reshape [[SLICE]] %res = "mhlo.gather"(%arg0, %arg1) { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir index 7b300ecad1466..05612efb2ee19 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir @@ -52,7 +52,7 @@ func.func @while_with_implicit_capture(%arg0 : tensor, %arg1 : tensor<5xi32 %1 = mhlo.constant dense : tensor // Check that the iota implicit capture is made explicit // CHECK: %[[IOTA:.*]] = "mhlo.iota - %2 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> + %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> // CHECK: mhlo.while{{.*}} %[[IOTA]]) %3:2 = "mhlo.while"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3 : tensor<5xi32>): @@ -96,6 +96,159 @@ func.func @broadcast_in_dim_dimension_unsorted(%arg0: tensor<1x2xi32>) -> tensor // Unfuse the transpose from the broadcastInDim before export. // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%arg0){{.*}}permutation = dense<[1, 0]>{{.*}} -> tensor<2x1xi32> // CHECK: mhlo.broadcast_in_dim"(%[[TRANSPOSE]]){{.*}}broadcast_dimensions = dense<[1, 2]> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 1]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[2, 1]> : tensor<2xi64>}> : (tensor<1x2xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } + +// ----- + +// CHECK-LABEL: @reduce_with_multiple_implicit_captures +func.func @reduce_with_multiple_implicit_captures(%arg0: tensor<2x2xf32>) -> tuple> { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.reduce + %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor + reducer(%arg1: tensor, %arg2: tensor) { + // CHECK-DAG: mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: mhlo.constant dense<1.000000e+00> : tensor + // CHECK: mhlo.compare + %5 = mhlo.compare NE, %arg1, %1 : (tensor, tensor) -> tensor + %6 = mhlo.compare NE, %arg2, %1 : (tensor, tensor) -> tensor + %7 = mhlo.or %5, %6 : tensor + %8 = mhlo.select %7, %0, %1 : tensor, tensor + mhlo.return %8 : tensor + } + %3 = mhlo.compare NE, %2, %1 : (tensor, tensor) -> tensor + %4 = mhlo.tuple %3 {xla_shape = "(pred[])"} : tuple> + return %4 : tuple> +} + +// ----- + +// CHECK-LABEL: @all_reduce_with_implicit_capture +func.func @all_reduce_with_implicit_capture(%arg0: tensor) -> tensor { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.all_reduce + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG1]], %[[VAL1]] + %1 = mhlo.add %arg1, %c : tensor + mhlo.return %1 : tensor + }) {replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>} : (tensor) -> tensor + return %0 : tensor + } + +// ----- + +// CHECK-LABEL: @reduce_scatter_with_implicit_capture +func.func @reduce_scatter_with_implicit_capture(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.reduce_scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG1]], %[[VAL1]] + %1 = mhlo.add %arg2, %c : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_with_implicit_capture +func.func @reduce_window_with_implicit_capture(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.reduce_window + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.maximum + // CHECK-SAME: %[[ARG2]], %[[VAL1]] + %1 = mhlo.maximum %arg2, %c : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + return %0 : tensor<2x16x30x7xf32> + } + +// ----- + +// CHECK-LABEL: @scatter_with_implicit_capture +func.func @scatter_with_implicit_capture(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, + %arg2: tensor<1xi32>) -> tensor<3xi32> { + %c = mhlo.constant dense<0> : tensor + // CHECK: mhlo.scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG4]], %[[VAL1]] + %x = mhlo.add %arg4, %c : tensor + "mhlo.return"(%x) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: @select_and_scatter_with_implicit_capture +func.func @select_and_scatter_with_implicit_capture(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + %c1 = mhlo.constant dense<0.0> : tensor + %c2 = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.select_and_scatter + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + %0 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.compare + // CHECK-SAME: %[[ARG3]], %[[VAL1]] + %1 = mhlo.compare GE, %arg3, %c1, TOTALORDER : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + }, { + // CHECK: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %[[VAL2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.add + // CHECK-SAME: %[[ARG4]], %[[VAL2]] + %1 = mhlo.add %arg4, %c2 : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + return %0 : tensor<10x24x24x64xf32> + } + +// ----- + +// CHECK-LABEL: @sort_with_implicit_capture +func.func @sort_with_implicit_capture(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + %c = mhlo.constant dense<0.0> : tensor + // CHECK: mhlo.sort + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:arg.*]]: tensor, %[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + %0:2 = "mhlo.sort"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + // CHECK: %[[VAL1:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: mhlo.compare + // CHECK-SAME: %[[ARG0]], %[[VAL1]] + %7 = "mhlo.compare"(%arg0, %c) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + func.return +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/reify-result-types.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/reify-result-types.mlir index c7bdf3565962b..d6601a029e72d 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/reify-result-types.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/reify-result-types.mlir @@ -2,49 +2,49 @@ // RUN: -split-input-file %s -o - | FileCheck %s // CHECK-LABEL: @dynamic_broadcast_i32_shape -func.func @dynamic_broadcast_i32_shape(%arg0 : tensor, %arg1 : tensor<*xf32>) +func.func @dynamic_broadcast_i32_shape(%arg0 : tensor<3xi32>, %arg1 : tensor) -> index { // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor + // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : i32 to index // CHECK: return %[[CAST]] %c0 = arith.constant 0 : index %0 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %arg0) { broadcast_dimensions = dense<0> : tensor<1xi64> } - : (tensor<*xf32>, tensor) -> tensor<*xf32> - %1 = tensor.dim %0, %c0 : tensor<*xf32> + : (tensor, tensor<3xi32>) -> tensor + %1 = tensor.dim %0, %c0 : tensor func.return %1 : index } // ----- // CHECK-LABEL: @dynamic_iota_i32_shape -func.func @dynamic_iota_i32_shape(%arg0 : tensor) -> index { +func.func @dynamic_iota_i32_shape(%arg0 : tensor<3xi32>) -> index { // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor + // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : i32 to index // CHECK: return %[[CAST]] %c0 = arith.constant 0 : index %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} - : (tensor) -> tensor - %1 = tensor.dim %0, %c0 : tensor + : (tensor<3xi32>) -> tensor + %1 = tensor.dim %0, %c0 : tensor func.return %1 : index } // ----- // CHECK-LABEL: @dynamic_reshape_i32_shape -func.func @dynamic_reshape_i32_shape(%arg0 : tensor, %arg1 : tensor<*xf32>) +func.func @dynamic_reshape_i32_shape(%arg0 : tensor<3xi32>, %arg1 : tensor) -> index { // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor + // CHECK: %[[EXTRACT:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : i32 to index // CHECK: return %[[CAST]] %c0 = arith.constant 0 : index %0 = "mhlo.dynamic_reshape"(%arg1, %arg0) { broadcast_dimensions = dense<0> : tensor<1xi64> } - : (tensor<*xf32>, tensor) -> tensor<*xf32> - %1 = tensor.dim %0, %c0 : tensor<*xf32> + : (tensor, tensor<3xi32>) -> tensor + %1 = tensor.dim %0, %c0 : tensor func.return %1 : index } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/restrict_max_rank.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/restrict_max_rank.mlir index 21a5b7557df02..45405222f4c9a 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/restrict_max_rank.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/restrict_max_rank.mlir @@ -5,14 +5,14 @@ func.func @ReduceTransposeReduce4D(%arg0 : tensor<17x6x35x13xf32>) -> tensor<357x2x5x13xf32> { // CHECK: %[[OUT0:.*]] = mhlo.reshape %arg0 : (tensor<17x6x35x13xf32>) -> tensor<17x6x5x7x13xf32> - // CHECK: %[[OUT1:.*]] = "mhlo.transpose"(%[[OUT0]]) {permutation = dense<[3, 0, 1, 2, 4]> : tensor<5xi64>} : (tensor<17x6x5x7x13xf32>) -> tensor<7x17x6x5x13xf32> + // CHECK: %[[OUT1:.*]] = "mhlo.transpose"(%[[OUT0]]) <{permutation = dense<[3, 0, 1, 2, 4]> : tensor<5xi64>}> : (tensor<17x6x5x7x13xf32>) -> tensor<7x17x6x5x13xf32> // CHECK: %[[OUT2:.*]] = mhlo.reshape %[[OUT1]] : (tensor<7x17x6x5x13xf32>) -> tensor<119x2x3x5x13xf32> - // CHECK: %[[OUT3:.*]] = "mhlo.transpose"(%[[OUT2]]) {permutation = dense<[2, 0, 1, 3, 4]> : tensor<5xi64>} : (tensor<119x2x3x5x13xf32>) -> tensor<3x119x2x5x13xf32> + // CHECK: %[[OUT3:.*]] = "mhlo.transpose"(%[[OUT2]]) <{permutation = dense<[2, 0, 1, 3, 4]> : tensor<5xi64>}> : (tensor<119x2x3x5x13xf32>) -> tensor<3x119x2x5x13xf32> // CHECK: %[[OUT4:.*]] = mhlo.reshape %[[OUT3]] : (tensor<3x119x2x5x13xf32>) -> tensor<357x2x5x13xf32> // CHECK: return %[[OUT4]] %0 = "mhlo.reshape"(%arg0) : (tensor<17x6x35x13xf32>) -> tensor<17x2x3x5x7x13xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>} : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}> : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> %2 = "mhlo.reshape"(%1) : (tensor<3x7x17x2x5x13xf32>) -> tensor<357x2x5x13xf32> return %2 : tensor<357x2x5x13xf32> } @@ -23,16 +23,16 @@ func.func @ReduceTransposeReduce4D(%arg0 : tensor<17x6x35x13xf32>) -> tensor<357 func.func @ReduceTransposeReduce5D(%arg0 : tensor<17x6x35x15x13xf32>) -> tensor<1785x2x5x3x13xf32> { // CHECK: %[[OUT0:.*]] = mhlo.reshape %arg0 : (tensor<17x6x35x15x13xf32>) -> tensor<17x6x35x3x5x13xf32> - // CHECK: %[[OUT1:.*]] = "mhlo.transpose"(%[[OUT0]]) {permutation = dense<[4, 0, 1, 2, 3, 5]> : tensor<6xi64>} : (tensor<17x6x35x3x5x13xf32>) -> tensor<5x17x6x35x3x13xf32> + // CHECK: %[[OUT1:.*]] = "mhlo.transpose"(%[[OUT0]]) <{permutation = dense<[4, 0, 1, 2, 3, 5]> : tensor<6xi64>}> : (tensor<17x6x35x3x5x13xf32>) -> tensor<5x17x6x35x3x13xf32> // CHECK: %[[OUT2:.*]] = mhlo.reshape %[[OUT1]] : (tensor<5x17x6x35x3x13xf32>) -> tensor<85x6x5x7x3x13xf32> - // CHECK: %[[OUT3:.*]] = "mhlo.transpose"(%[[OUT2]]) {permutation = dense<[3, 0, 1, 2, 4, 5]> : tensor<6xi64>} : (tensor<85x6x5x7x3x13xf32>) -> tensor<7x85x6x5x3x13xf32> + // CHECK: %[[OUT3:.*]] = "mhlo.transpose"(%[[OUT2]]) <{permutation = dense<[3, 0, 1, 2, 4, 5]> : tensor<6xi64>}> : (tensor<85x6x5x7x3x13xf32>) -> tensor<7x85x6x5x3x13xf32> // CHECK: %[[OUT4:.*]] = mhlo.reshape %[[OUT3]] : (tensor<7x85x6x5x3x13xf32>) -> tensor<595x2x3x5x3x13xf32> - // CHECK: %[[OUT5:.*]] = "mhlo.transpose"(%[[OUT4]]) {permutation = dense<[2, 0, 1, 3, 4, 5]> : tensor<6xi64>} : (tensor<595x2x3x5x3x13xf32>) -> tensor<3x595x2x5x3x13xf32> + // CHECK: %[[OUT5:.*]] = "mhlo.transpose"(%[[OUT4]]) <{permutation = dense<[2, 0, 1, 3, 4, 5]> : tensor<6xi64>}> : (tensor<595x2x3x5x3x13xf32>) -> tensor<3x595x2x5x3x13xf32> // CHECK: %[[OUT6:.*]] = mhlo.reshape %[[OUT5]] : (tensor<3x595x2x5x3x13xf32>) -> tensor<1785x2x5x3x13xf32> // CHECK: return %[[OUT6]] %0 = "mhlo.reshape"(%arg0) : (tensor<17x6x35x15x13xf32>) -> tensor<17x2x3x5x7x3x5x13xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[2, 4, 6, 0, 1, 3, 5, 7]> : tensor<8xi64>} : (tensor<17x2x3x5x7x3x5x13xf32>) -> tensor<3x7x5x17x2x5x3x13xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[2, 4, 6, 0, 1, 3, 5, 7]> : tensor<8xi64>}> : (tensor<17x2x3x5x7x3x5x13xf32>) -> tensor<3x7x5x17x2x5x3x13xf32> %2 = "mhlo.reshape"(%1) : (tensor<3x7x5x17x2x5x3x13xf32>) -> tensor<1785x2x5x3x13xf32> return %2 : tensor<1785x2x5x3x13xf32> } @@ -44,9 +44,9 @@ func.func @ReduceTransposeReduce4D(%arg0 : tensor<17x6x35x13xf32>) -> tensor<357 %0 = "mhlo.reshape"(%arg0) : (tensor<17x6x35x13xf32>) -> tensor<17x2x3x5x7x13xf32> // Shouldn't modify this transpose op as it doesn't meet the criteria. - // CHECK: "mhlo.transpose"(%{{.*}}) {permutation = dense<[4, 2, 0, 1, 3, 5]> : tensor<6xi64>} : (tensor<17x2x3x5x7x13xf32>) -> tensor<7x3x17x2x5x13xf32> + // CHECK: "mhlo.transpose"(%{{.*}}) <{permutation = dense<[4, 2, 0, 1, 3, 5]> : tensor<6xi64>}> : (tensor<17x2x3x5x7x13xf32>) -> tensor<7x3x17x2x5x13xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[4, 2, 0, 1, 3, 5]> : tensor<6xi64>} : (tensor<17x2x3x5x7x13xf32>) -> tensor<7x3x17x2x5x13xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[4, 2, 0, 1, 3, 5]> : tensor<6xi64>}> : (tensor<17x2x3x5x7x13xf32>) -> tensor<7x3x17x2x5x13xf32> %2 = "mhlo.reshape"(%1) : (tensor<7x3x17x2x5x13xf32>) -> tensor<357x2x5x13xf32> return %2 : tensor<357x2x5x13xf32> } @@ -58,8 +58,8 @@ func.func @ReduceTransposeReduce4D(%arg0 : tensor<17x6x35x13xf32>) -> tensor<3x2 %0 = "mhlo.reshape"(%arg0) : (tensor<17x6x35x13xf32>) -> tensor<17x2x3x5x7x13xf32> // Shouldn't modify this transpose op as it doesn't meet the criteria. - // CHECK: "mhlo.transpose"(%{{.*}}) {permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>} : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> - %1 = "mhlo.transpose"(%0) {permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>} : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> + // CHECK: "mhlo.transpose"(%{{.*}}) <{permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}> : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> + %1 = "mhlo.transpose"(%0) <{permutation = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}> : (tensor<17x2x3x5x7x13xf32>) -> tensor<3x7x17x2x5x13xf32> %2 = "mhlo.reshape"(%1) : (tensor<3x7x17x2x5x13xf32>) -> tensor<3x238x5x13xf32> return %2 : tensor<3x238x5x13xf32> diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 0000000000000..e5b645c816a53 --- /dev/null +++ b/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,189 @@ +// RUN: mlir-hlo-opt --shape-legalize-to-hlo=legalize-constraints=true --split-input-file --verify-diagnostics %s | FileCheck %s + +// ----- + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS2_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[DIMS2]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS1_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[DIMS1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = mhlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = mhlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = mhlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = mhlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = mhlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = mhlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = mhlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = mhlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = "mhlo.slice"(%[[DIMS_BROADCASTABLE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = mhlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = mhlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @mhlo_cstr_reshapable(%arg0: index, %arg1: tensor<2xindex>) { + %0 = mhlo.cstr_reshapable %arg0, %arg1 : (index, tensor<2xindex>) -> !shape.witness + func.return + // CHECK-DAG: %[[NUM_ELEMENTS:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor + // CHECK-DAG: %[[DYNAMIC_SHAPE:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-DAG: %[[MINUS_ONE:.*]] = mhlo.constant dense<-1> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = "mhlo.slice"(%[[DYNAMIC_SHAPE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[DIM_SIZE_SCALAR_1:.*]] = mhlo.reshape %[[DIM_SIZE_1]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[ALL_DIMS_PRODUCT_1:.*]] = mhlo.multiply %[[ONE]], %[[DIM_SIZE_SCALAR_1]] : tensor + // CHECK-NEXT: %[[EQ_MINUS_ONE_1:.*]] = mhlo.compare EQ, %[[DIM_SIZE_SCALAR_1]], %[[MINUS_ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_DIM_1:.*]] = mhlo.select %[[EQ_MINUS_ONE_1]], %[[ONE]], %[[ZERO]] : tensor, tensor + // CHECK-NEXT: %[[NUM_DYNAMIC_DIM_1:.*]] = mhlo.add %[[ZERO]], %[[DYNAMIC_DIM_1]] : tensor + // CHECK-NEXT: %[[DIM_SIZE_2:.*]] = "mhlo.slice"(%[[DYNAMIC_SHAPE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[DIM_SIZE_SCALAR_2:.*]] = mhlo.reshape %[[DIM_SIZE_2]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[ALL_DIMS_PRODUCT:.*]] = mhlo.multiply %[[ALL_DIMS_PRODUCT_1]], %[[DIM_SIZE_SCALAR_2]] : tensor + // CHECK-NEXT: %[[EQ_MINUS_ONE_2:.*]] = mhlo.compare EQ, %[[DIM_SIZE_SCALAR_2]], %[[MINUS_ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_DIM_2:.*]] = mhlo.select %[[EQ_MINUS_ONE_2]], %[[ONE]], %[[ZERO]] : tensor, tensor + // CHECK-NEXT: %[[NUM_DYNAMIC_DIM:.*]] = mhlo.add %[[NUM_DYNAMIC_DIM_1]], %[[DYNAMIC_DIM_2]] : tensor + // CHECK-NEXT: %[[ONLY_ONE_DYNAMIC_DIM:.*]] = mhlo.compare EQ, %[[NUM_DYNAMIC_DIM]], %[[ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[STATIC_DIMS_PRODUCT:.*]] = mhlo.multiply %[[ALL_DIMS_PRODUCT]], %[[MINUS_ONE]] : tensor + // CHECK-NEXT: %[[REM:.*]] = mhlo.remainder %[[NUM_ELEMENTS]], %[[STATIC_DIMS_PRODUCT]] : tensor + // CHECK-NEXT: %[[NO_RESIDUAL:.*]] = mhlo.compare EQ, %[[REM]], %[[ZERO]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_RESHAPABLE:.*]] = mhlo.and %[[NO_RESIDUAL]], %[[ONLY_ONE_DYNAMIC_DIM]] : tensor + // CHECK-NEXT: %[[NO_DYNAMIC_DIM:.*]] = mhlo.compare EQ, %16, %[[ZERO]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[NUM_ELEMENTS_EQUALS:.*]] = mhlo.compare EQ, %[[ALL_DIMS_PRODUCT]], %[[NUM_ELEMENTS]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[STATIC_RESHAPABLE:.*]] = mhlo.and %[[NO_DYNAMIC_DIM]], %[[NUM_ELEMENTS_EQUALS]] : tensor + // CHECK-NEXT: %[[RESHAPABLE:.*]] = mhlo.or %[[DYNAMIC_RESHAPABLE]], %[[STATIC_RESHAPABLE]] : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[RESHAPABLE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () +} + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_const +func.func @mhlo_cstr_reshapable_const(%arg0: tensor) { + %0 = arith.constant 20 : index + %1 = mhlo.constant dense<[-1, 4]> : tensor<2xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<2xi32>) -> !shape.witness + func.return + // CHECK-DAG: %[[DYNAMIC_SHAPE:.*]] = mhlo.constant dense<[-1, 4]> : tensor<2xi32> + // CHECK-DAG: %[[NUM_ELEMENTS:.*]] = mhlo.constant dense<20> : tensor + // CHECK-DAG: %[[MINUS_ONE:.*]] = mhlo.constant dense<-1> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1> : tensor + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor + // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = "mhlo.slice"(%[[DYNAMIC_SHAPE]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[DIM_SIZE_SCALAR_1:.*]] = mhlo.reshape %[[DIM_SIZE_1]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[ALL_DIMS_PRODUCT_1:.*]] = mhlo.multiply %[[ONE]], %[[DIM_SIZE_SCALAR_1]] : tensor + // CHECK-NEXT: %[[EQ_MINUS_ONE_1:.*]] = mhlo.compare EQ, %[[DIM_SIZE_SCALAR_1]], %[[MINUS_ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_DIM_1:.*]] = mhlo.select %[[EQ_MINUS_ONE_1]], %[[ONE]], %[[ZERO]] : tensor, tensor + // CHECK-NEXT: %[[NUM_DYNAMIC_DIM_1:.*]] = mhlo.add %[[ZERO]], %[[DYNAMIC_DIM_1]] : tensor + // CHECK-NEXT: %[[DIM_SIZE_2:.*]] = "mhlo.slice"(%[[DYNAMIC_SHAPE]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[DIM_SIZE_SCALAR_2:.*]] = mhlo.reshape %[[DIM_SIZE_2]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[ALL_DIMS_PRODUCT:.*]] = mhlo.multiply %[[ALL_DIMS_PRODUCT_1]], %[[DIM_SIZE_SCALAR_2]] : tensor + // CHECK-NEXT: %[[EQ_MINUS_ONE_2:.*]] = mhlo.compare EQ, %[[DIM_SIZE_SCALAR_2]], %[[MINUS_ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_DIM_2:.*]] = mhlo.select %[[EQ_MINUS_ONE_2]], %[[ONE]], %[[ZERO]] : tensor, tensor + // CHECK-NEXT: %[[NUM_DYNAMIC_DIM:.*]] = mhlo.add %[[NUM_DYNAMIC_DIM_1]], %[[DYNAMIC_DIM_2]] : tensor + // CHECK-NEXT: %[[ONLY_ONE_DYNAMIC_DIM:.*]] = mhlo.compare EQ, %[[NUM_DYNAMIC_DIM]], %[[ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[STATIC_DIMS_PRODUCT:.*]] = mhlo.multiply %[[ALL_DIMS_PRODUCT]], %[[MINUS_ONE]] : tensor + // CHECK-NEXT: %[[REM:.*]] = mhlo.remainder %[[NUM_ELEMENTS]], %[[STATIC_DIMS_PRODUCT]] : tensor + // CHECK-NEXT: %[[NO_RESIDUAL:.*]] = mhlo.compare EQ, %[[REM]], %[[ZERO]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DYNAMIC_RESHAPABLE:.*]] = mhlo.and %[[NO_RESIDUAL]], %[[ONLY_ONE_DYNAMIC_DIM]] : tensor + // CHECK-NEXT: %[[NO_DYNAMIC_DIM:.*]] = mhlo.compare EQ, %16, %[[ZERO]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[NUM_ELEMENTS_EQUALS:.*]] = mhlo.compare EQ, %[[ALL_DIMS_PRODUCT]], %[[NUM_ELEMENTS]], NOTYPE : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[STATIC_RESHAPABLE:.*]] = mhlo.and %[[NO_DYNAMIC_DIM]], %[[NUM_ELEMENTS_EQUALS]] : tensor + // CHECK-NEXT: %[[RESHAPABLE:.*]] = mhlo.or %[[DYNAMIC_RESHAPABLE]], %[[STATIC_RESHAPABLE]] : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[RESHAPABLE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () +} + +// ----- + +func.func @mhlo_cstr_reshapable_i8(%arg0: index, %arg1: tensor<2xi8>) { + // expected-error@+1 {{failed to legalize operation 'mhlo.cstr_reshapable' that was explicitly marked illegal}} + %0 = mhlo.cstr_reshapable %arg0, %arg1 : (index, tensor<2xi8>) -> !shape.witness + func.return +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo_e2e.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo_e2e.mlir new file mode 100644 index 0000000000000..4d99d6a03294e --- /dev/null +++ b/xla/mlir_hlo/tests/Dialect/mhlo/shape_cstr_legalize_to_hlo_e2e.mlir @@ -0,0 +1,82 @@ +// RUN: mlir-hlo-opt --shape-legalize-to-hlo=legalize-constraints=true -reconcile-unrealized-casts -canonicalize --split-input-file --verify-diagnostics %s | FileCheck %s +// This test verifies e2e lowering of cstr ops result is correct for constant inputs. + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_true +func.func @mhlo_cstr_reshapable_true(%arg0: tensor) -> tensor { + %0 = arith.constant 16 : index + %1 = mhlo.constant dense<[-1, 4, 2]> : tensor<3xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<3xi32>) -> !shape.witness + %3 = shape.assuming %2 -> tensor { + shape.assuming_yield %arg0 : tensor + } + func.return %3 : tensor + // CHECK: %[[TRUE:.*]] = mhlo.constant dense : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[TRUE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_has_residual +func.func @mhlo_cstr_reshapable_has_residual(%arg0: tensor) -> tensor { + %0 = arith.constant 19 : index + %1 = mhlo.constant dense<[-1, 4]> : tensor<2xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<2xi32>) -> !shape.witness + %3 = shape.assuming %2 -> tensor { + shape.assuming_yield %arg0 : tensor + } + func.return %3 : tensor + // CHECK: %[[FALSE:.*]] = mhlo.constant dense : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[FALSE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_2_dynamic_dims +func.func @mhlo_cstr_reshapable_2_dynamic_dims(%arg0: tensor) -> tensor { + %0 = arith.constant 20 : index + %1 = mhlo.constant dense<[-1, 4, -1]> : tensor<3xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<3xi32>) -> !shape.witness + %3 = shape.assuming %2 -> tensor { + shape.assuming_yield %arg0 : tensor + } + func.return %3 : tensor + // CHECK: %[[FALSE:.*]] = mhlo.constant dense : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[FALSE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_static_true +func.func @mhlo_cstr_reshapable_static_true(%arg0: tensor) -> tensor { + %0 = arith.constant 20 : index + %1 = mhlo.constant dense<[1, 4, 5]> : tensor<3xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<3xi32>) -> !shape.witness + %3 = shape.assuming %2 -> tensor { + shape.assuming_yield %arg0 : tensor + } + func.return %3 : tensor + // CHECK: %[[TRUE:.*]] = mhlo.constant dense : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[TRUE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func.func @mhlo_cstr_reshapable_static_false +func.func @mhlo_cstr_reshapable_static_false(%arg0: tensor) -> tensor { + %0 = arith.constant 21 : index + %1 = mhlo.constant dense<[1, 4, 5]> : tensor<3xi32> + %2 = mhlo.cstr_reshapable %0, %1 : (index, tensor<3xi32>) -> !shape.witness + %3 = shape.assuming %2 -> tensor { + shape.assuming_yield %arg0 : tensor + } + func.return %3 : tensor + // CHECK: %[[FALSE:.*]] = mhlo.constant dense : tensor + // CHECK-NEXT: mhlo.custom_call @shape_assertion(%[[FALSE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: return %arg0 : tensor +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir index f81af2fa05271..dc338bca94773 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir @@ -6,10 +6,10 @@ func.func @compute_reshape_shape(%arg0: index, %arg1: tensor<2xi32>) -> tensor<2 func.return %0 : tensor<2xi32> // CHECK: %[[ARG0_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor // CHECK-NEXT: %[[TMP0:.*]] = mhlo.constant dense<-1> : tensor - // CHECK-NEXT: %[[INPUT_SIZE0x1:.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[INPUT_SIZE0x1:.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INPUT_SIZE0:.*]] = mhlo.reshape %[[INPUT_SIZE0x1]] : (tensor<1xi32>) -> tensor // CHECK-NEXT: %[[TMP1:.*]] = mhlo.multiply %[[TMP0]], %[[INPUT_SIZE0]] : tensor - // CHECK-NEXT: %[[INPUT_SIZE1x1:.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[INPUT_SIZE1x1:.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INPUT_SIZE1:.*]] = mhlo.reshape %[[INPUT_SIZE1x1]] : (tensor<1xi32>) -> tensor // CHECK-NEXT: %[[INPUT_SIZE_PRODUCT:.*]] = mhlo.multiply %[[TMP1]], %[[INPUT_SIZE1]] : tensor // CHECK-NEXT: %[[COMPUTED_SIZE:.*]] = mhlo.divide %[[ARG0_I32]], %[[INPUT_SIZE_PRODUCT]] : tensor @@ -20,7 +20,7 @@ func.func @compute_reshape_shape(%arg0: index, %arg1: tensor<2xi32>) -> tensor<2 // CHECK-NEXT: %[[INPUT_SIZE1_EQ_M1:.*]] = mhlo.compare EQ, %6, %[[M1]], NOTYPE : (tensor, tensor) -> tensor // CHECK-NEXT: %[[RESULT_SIZE1:.*]] = mhlo.select %[[INPUT_SIZE1_EQ_M1]], %[[COMPUTED_SIZE]], %6 : tensor, tensor // CHECK-NEXT: %[[RESULT_SIZE1x1:.*]] = mhlo.reshape %[[RESULT_SIZE1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.concatenate"(%[[RESULT_SIZE0x1]], %[[RESULT_SIZE1x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.concatenate"(%[[RESULT_SIZE0x1]], %[[RESULT_SIZE1x1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> } @@ -32,10 +32,10 @@ func.func @num_elements_tensor_to_index(%arg0: tensor<2xindex>) -> index { func.return %0 : index // CHECK: %[[ARG0_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> // CHECK-NEXT: %[[TMP0:.*]] = mhlo.constant dense<1> : tensor - // CHECK-NEXT: %[[SIZE0x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[SIZE0x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[SIZE0:.*]] = mhlo.reshape %[[SIZE0x1]] : (tensor<1xi32>) -> tensor // CHECK-NEXT: %[[TMP1:.*]] = mhlo.multiply %[[TMP0]], %[[SIZE0]] : tensor - // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.slice"(%[[ARG0_I32]]) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[SIZE1:.*]] = mhlo.reshape %[[SIZE1x1]] : (tensor<1xi32>) -> tensor // CHECK-NEXT: %[[RESULT_I32:.*]] = mhlo.multiply %[[TMP1]], %[[SIZE1]] : tensor // CHECK-NEXT: %[[RESULT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESULT_I32]] : tensor to index @@ -64,11 +64,11 @@ func.func @num_elements_xxx_to_size(%arg0: tensor<2xindex>) -> !shape.size { func.func @shape_of_ranked_to_index(%arg0: tensor) -> tensor<2xindex> { %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex> func.return %0 : tensor<2xindex> - // CHECK: %[[SIZE0x1:.*]] = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor) -> tensor + // CHECK: %[[SIZE0x1:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor // CHECK-NEXT: %[[SIZE0:.*]] = mhlo.reshape %[[SIZE0x1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor) -> tensor + // CHECK-NEXT: %[[SIZE1x1:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor // CHECK-NEXT: %[[SIZE1:.*]] = mhlo.reshape %[[SIZE1x1]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[RESULT_I32:.*]] = "mhlo.concatenate"(%[[SIZE0]], %[[SIZE1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[RESULT_I32:.*]] = "mhlo.concatenate"(%[[SIZE0]], %[[SIZE1]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK-NEXT: %[[RESULT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESULT_I32]] : tensor<2xi32> to tensor<2xindex> // CHECK-NEXT: return %[[RESULT_INDEX]] : tensor<2xindex> } @@ -96,7 +96,7 @@ func.func @tensor_dim(%arg0: tensor) -> index { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor func.return %dim : index - // CHECK: %[[DIM_SIZE:.*]] = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor) -> tensor + // CHECK: %[[DIM_SIZE:.*]] = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor // CHECK-NEXT: %[[DIM_SIZE_INDEX:.*]] = builtin.unrealized_conversion_cast %[[DIM_SIZE]] : tensor to index // CHECK-NEXT: return %[[DIM_SIZE_INDEX]] : index } @@ -119,7 +119,7 @@ func.func @tensor_from_elements(%arg0: index) -> tensor<2xindex> { // CHECK: %[[ELEMENT1_SCALAR:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor // CHECK-NEXT: %[[ELEMENT1:.*]] = mhlo.reshape %[[ELEMENT1_SCALAR]] : (tensor) -> tensor<1xi32> // CHECK-NEXT: %[[ELEMENT2:.*]] = mhlo.constant dense<0> : tensor<1xi32> - // CHECK-NEXT: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[ELEMENT1]], %[[ELEMENT2]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[ELEMENT1]], %[[ELEMENT2]]) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK-NEXT: %[[CONCAT_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONCAT]] : tensor<2xi32> to tensor<2xindex> // CHECK-NEXT: return %[[CONCAT_INDEX]] : tensor<2xindex> } @@ -135,6 +135,16 @@ func.func @tensor_from_elements_i8(%arg0: i8) -> tensor<2xi8> { // ----- +// CHECK-LABEL: func.func @tensor_from_elements_scalar +func.func @tensor_from_elements_scalar(%arg0: i64) -> tensor { + %0 = tensor.from_elements %arg0 : tensor + func.return %0 : tensor + // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %arg0 : i64 to tensor + // CHECK-NEXT: return %[[RESULT]] : tensor +} + +// ----- + func.func @tensor_from_elements_rank2(%arg0: index) -> tensor<2x1xindex> { %c0 = arith.constant 0 : index // expected-error@+1 {{failed to legalize operation 'tensor.from_elements' that was explicitly marked illegal}} @@ -157,6 +167,20 @@ func.func @shape_broadcast(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> te // ----- +func.func @shape_broadcast_different_dims(%arg0: tensor<4xindex>, %arg1: tensor<6xindex>) -> tensor<6xindex> { + %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<6xindex> -> tensor<6xindex> + func.return %0 : tensor<6xindex> + // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4xindex> to tensor<4xi32> + // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<6xindex> to tensor<6xi32> + // CHECK-NEXT: %[[PAD:.*]] = mhlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[LHS_PAD:.*]] = "mhlo.concatenate"(%[[PAD]], %[[LHS]]) <{dimension = 0 : i64}> : (tensor<2xi32>, tensor<4xi32>) -> tensor<6xi32> + // CHECK-NEXT: %[[BROADCAST:.*]] = mhlo.maximum %[[LHS_PAD]], %[[RHS]] : tensor<6xi32> + // CHECK-NEXT: %[[BROADCAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[BROADCAST]] : tensor<6xi32> to tensor<6xindex> + // CHECK-NEXT: return %[[BROADCAST_INDEX]] : tensor<6xindex> +} + +// ----- + func.func @shape_broadcast_result_shape(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>) -> !shape.shape { // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<4xindex> -> !shape.shape @@ -173,16 +197,202 @@ func.func @shape_broadcast_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) // ----- -func.func @shape_broadcast_different_dims(%arg0: tensor<4xindex>, %arg1: tensor<6xindex>) -> tensor<6xindex> { +func.func @shape_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) -> tensor<4xindex> { // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} - %0 = shape.broadcast %arg0, %arg1 : tensor<4xindex>, tensor<6xindex> -> tensor<6xindex> - func.return %0 : tensor<6xindex> + %0 = shape.broadcast %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> + func.return %0 : tensor<4xindex> } // ----- -func.func @shape_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) -> tensor<4xindex> { - // expected-error@+1 {{failed to legalize operation 'shape.broadcast' that was explicitly marked illegal}} - %0 = shape.broadcast %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> -> tensor<4xindex> - func.return %0 : tensor<4xindex> +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) -> !shape.witness { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + func.return %0 : !shape.witness +} + +// ----- + +func.func @mhlo_cstr_reshapable(%arg0: index, %arg1: tensor<2xindex>, %arg2: tensor) -> tensor { + // expected-error@+1 {{failed to legalize operation 'mhlo.cstr_reshapable' that was explicitly marked illegal}} + %0 = mhlo.cstr_reshapable %arg0, %arg1 : (index, tensor<2xindex>) -> !shape.witness + %1 = shape.assuming %0 -> (tensor) { + %2 = mhlo.dynamic_reshape %arg2, %arg1 : (tensor, tensor<2xindex>) -> tensor + shape.assuming_yield %2 : tensor + } + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_shape +func.func @const_shape() -> tensor<2xindex> { + %0 = shape.const_shape [6, 4] : tensor<2xindex> + return %0 : tensor<2xindex> + // CHECK: %[[CST:.*]] = mhlo.constant dense<[6, 4]> : tensor<2xi32> + // CHECK-NEXT: %[[CST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CST]] : tensor<2xi32> to tensor<2xindex> + // CHECK-NEXT: return %[[CST_INDEX]] : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func @index_cast_index_to_i32 +func.func @index_cast_index_to_i32(%arg0: tensor<2xindex>) -> tensor<2xi32> { + %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi32> + return %0 : tensor<2xi32> + // CHECK-NEXT: %[[CST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: return %[[CST_I32]] : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @index_cast_i32_to_index +func.func @index_cast_i32_to_index(%arg0: tensor<2xi32>) -> tensor<2xindex> { + %0 = arith.index_cast %arg0 : tensor<2xi32> to tensor<2xindex> + return %0 : tensor<2xindex> + // CHECK-NEXT: %[[CST_INDEX:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xi32> to tensor<2xindex> + // CHECK-NEXT: return %[[CST_INDEX]] : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func @index_cast_scalar_index_to_i32 +func.func @index_cast_scalar_index_to_i32(%arg0: index) -> i32 { + // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor + // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor to i32 + // CHECK-NEXT: return %[[CAST_INDEX]] : i32 + %0 = arith.index_cast %arg0 : index to i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: func @index_cast_scalar_index_to_i64 +func.func @index_cast_scalar_index_to_i64(%arg0: index) -> i64 { + // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor + // CHECK-NEXT: %[[CONVERT:.*]] = mhlo.convert %[[CAST_I32]] : (tensor) -> tensor + // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONVERT]] : tensor to i64 + // CHECK-NEXT: return %[[CAST_INDEX]] : i64 + %0 = arith.index_cast %arg0 : index to i64 + return %0 : i64 +} + +// ----- + +func.func @index_cast_scalar_i32_to_index(%arg0: i32) -> index { + // CHECK: %[[CAST_I32:.*]] = builtin.unrealized_conversion_cast %arg0 : i32 to tensor + // CHECK-NEXT: %[[CAST_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CAST_I32]] : tensor to index + // CHECK-NEXT: return %[[CAST_INDEX]] : index + %0 = arith.index_cast %arg0 : i32 to index + return %0 : index +} + +// ----- + +func.func @index_cast_index_to_i8(%arg0: tensor<2xindex>) -> tensor<2xi8> { + // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} + %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi8> + return %0 : tensor<2xi8> +} + +// ----- + +func.func @index_cast_i8_to_index(%arg0: tensor<2xi8>) -> tensor<2xindex> { + // expected-error@+1 {{failed to legalize operation 'arith.index_cast' that was explicitly marked illegal}} + %0 = arith.index_cast %arg0 : tensor<2xi8> to tensor<2xindex> + return %0 : tensor<2xindex> +} + + +// ----- + +// CHECK-LABEL: func @muli +func.func @muli(%arg0: index, %arg1: index) -> index { + %0 = arith.muli %arg0, %arg1 : index + return %0 : index + // CHECK: %[[LHS:.*]] = builtin.unrealized_conversion_cast %arg0 : index to tensor + // CHECK-NEXT: %[[RHS:.*]] = builtin.unrealized_conversion_cast %arg1 : index to tensor + // CHECK-NEXT: %[[RES:.*]] = mhlo.multiply %[[LHS]], %[[RHS]] : tensor + // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RES]] : tensor to index + // CHECK-NEXT: return %[[RES_INDEX]] : index +} + +// ----- + +// CHECK-LABEL: func @muli_const +func.func @muli_const() -> index { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.muli %c1, %c2 : index + return %0 : index + // CHECK: %[[LHS:.*]] = mhlo.constant dense<1> : tensor + // CHECK-NEXT: %[[RHS:.*]] = mhlo.constant dense<2> : tensor + // CHECK-NEXT: %[[RES:.*]] = mhlo.multiply %[[LHS]], %[[RHS]] : tensor + // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RES]] : tensor to index + // CHECK-NEXT: return %[[RES_INDEX]] : index +} + +// ----- + +func.func @muli_i32(%arg0: i32, %arg1: i32) -> i32 { + // expected-error@+1 {{failed to legalize operation 'arith.muli' that was explicitly marked illegal}} + %0 = arith.muli %arg0, %arg1 : i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: func @tensor_extract +func.func @tensor_extract(%arg0: tensor<3x3xindex>) -> index { + %c1 = arith.constant 0 : index + %c2 = arith.constant 1 : index + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> + return %0 : index + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x3xindex> to tensor<3x3xi32> + // CHECK-NEXT: %[[SLICE:.*]] = "mhlo.slice"(%[[CAST]]) + // CHECK-SAME: limit_indices = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME: start_indices = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: strides = dense<1> : tensor<2xi64> + // CHECK-SAME: (tensor<3x3xi32>) -> tensor<1x1xi32> + // CHECK-NEXT: %[[RESHAPE:.*]] = mhlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: %[[RES_INDEX:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to index + // CHECK-NEXT: return %[[RES_INDEX]] : index +} + +// ----- + +// CHECK-LABEL: func @tensor_extract_i32 +func.func @tensor_extract_i32(%arg0: tensor<3x3xi32>) -> i32 { + %c1 = arith.constant 0 : index + %c2 = arith.constant 1 : index + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xi32> + return %0 : i32 + // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%arg0) + // CHECK-SAME: limit_indices = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME: start_indices = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: strides = dense<1> : tensor<2xi64> + // CHECK-SAME: (tensor<3x3xi32>) -> tensor<1x1xi32> + // CHECK-NEXT: %[[RESHAPE:.*]] = mhlo.reshape %[[SLICE]] : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: %[[RES_I32:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE]] : tensor to i32 + // CHECK-NEXT: return %[[RES_I32]] : i32 +} + +// ----- + +func.func @tensor_extract_out_of_range(%arg0: tensor<3x3xindex>) -> index { + %c1 = arith.constant 4 : index + %c2 = arith.constant 4 : index + // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} + %0 = tensor.extract %arg0[%c1, %c2] : tensor<3x3xindex> + return %0 : index +} + +// ----- + +func.func @tensor_extract_dynamic(%arg0: tensor) -> index { + %c1 = arith.constant 0 : index + %c2 = arith.constant 2 : index + // expected-error@+1 {{failed to legalize operation 'tensor.extract' that was explicitly marked illegal}} + %0 = tensor.extract %arg0[%c1, %c2] : tensor + return %0 : index } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir deleted file mode 100644 index ec98a7cc1a4c6..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir +++ /dev/null @@ -1,121 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --verify-diagnostics \ -// RUN: --mhlo-test-lower-general-dot --canonicalize | FileCheck %s - -#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> -#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> -#COO = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique), d2 : singleton) }> - -// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> -// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> -// CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique), d2 : singleton) }> - -// -// Vector-vector gendot. -// -// CHECK-LABEL: func.func @sparse_vecvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf64, #[[$SV]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10xf64, #[[$SV]]>) -> tensor { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<10xf64, #[[$SV]]>, tensor<10xf64, #[[$SV]]>) -> tensor -// CHECK: return %[[DOT]] : tensor -// CHECK: } -// -func.func @sparse_vecvec(%arg0: tensor<10xf64, #SV>, - %arg1: tensor<10xf64, #SV>) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<10xf64, #SV>, - tensor<10xf64, #SV>) -> tensor - return %0 : tensor -} - -// -// Matrix-vector gendot. -// -// CHECK-LABEL: func.func @sparse_matvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf64, #[[$DCSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<5xf64, #[[$SV]]>) -> tensor<3xf64> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<3x5xf64, #[[$DCSR]]>, tensor<5xf64, #[[$SV]]>) -> tensor<3xf64> -// CHECK: return %[[DOT]] : tensor<3xf64> -// CHECK: } -// -func.func @sparse_matvec(%arg0: tensor<3x5xf64, #DCSR>, - %arg1: tensor<5xf64, #SV>) -> tensor<3xf64> { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<3x5xf64, #DCSR>, - tensor<5xf64, #SV>) -> tensor<3xf64> - return %0 : tensor<3xf64> -} - -// -// Matrix-matrix gendot, one sparse operand. -// -// CHECK-LABEL: func.func @sparse_matmat_1s( -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #[[$DCSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<32x64xf64>) -> tensor<16x64xf64> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #[[$DCSR]]>, tensor<32x64xf64>) -> tensor<16x64xf64> -// CHECK: return %[[DOT]] : tensor<16x64xf64> -// CHECK: } -// -func.func @sparse_matmat_1s(%arg0: tensor<16x32xf64, #DCSR>, - %arg1: tensor<32x64xf64>) -> tensor<16x64xf64> { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<16x32xf64, #DCSR>, - tensor<32x64xf64>) -> tensor<16x64xf64> - return %0 : tensor<16x64xf64> -} - -// -// Matrix-matrix gendot, everything sparse. -// -// CHECK-LABEL: func.func @sparse_matmat_as( -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #[[$DCSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<32x64xf64, #[[$DCSR]]>) -> tensor<16x64xf64, #[[$DCSR]]> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #[[$DCSR]]>, tensor<32x64xf64, #[[$DCSR]]>) -> tensor<16x64xf64, #[[$DCSR]]> -// CHECK: return %[[DOT]] : tensor<16x64xf64, #[[$DCSR]]> -// CHECK: } -// -func.func @sparse_matmat_as(%arg0: tensor<16x32xf64, #DCSR>, - %arg1: tensor<32x64xf64, #DCSR>) -> tensor<16x64xf64, #DCSR> { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<16x32xf64, #DCSR>, - tensor<32x64xf64, #DCSR>) -> tensor<16x64xf64, #DCSR> - return %0 : tensor<16x64xf64, #DCSR> -} - -// -// Higher-order gendot. -// -// A situation that would introduce sparse reshape operations is not rewritten. -// -// CHECK-LABEL: func.func @sparse_tensor( -// CHECK-SAME: %[[ARG0:.*]]: tensor<197x12x64xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #[[$COO]]>) -> tensor<197x768xf32> { -// CHECK: %[[R:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[R]] : tensor<197x768xf32> -func.func @sparse_tensor(%arg0: tensor<197x12x64xf32>, - %arg1: tensor<12x64x768xf32, #COO>) -> tensor<197x768xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<197x12x64xf32>, - tensor<12x64x768xf32, #COO>) -> tensor<197x768xf32> - return %0 : tensor<197x768xf32> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir deleted file mode 100644 index eb5e7d98cdf6e..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir +++ /dev/null @@ -1,273 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --verify-diagnostics \ -// RUN: --hlo-legalize-to-linalg \ -// RUN: --canonicalize | FileCheck %s - -// Verifies that different sparse input and output types are -// properly dealt with while lowering mhlo ops to linalg ops. - -#SV = #sparse_tensor.encoding<{ - map = (d0) -> (d0 : compressed) -}> - -#CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) -}> - -#DCSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) -}> - -#ST = #sparse_tensor.encoding<{ - map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) -}> - -// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> -// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> -// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }> -// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> - - -// CHECK-LABEL: func @sparse_abs_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32, #[[$DCSR]]> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$DCSR]]> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}} ins(%[[ARG0]] : tensor<10x20xf32, #[[$CSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$DCSR]]>) -// CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32): -// CHECK: %[[ABS:.*]] = math.absf %[[A]] : f32 -// CHECK: linalg.yield %[[ABS]] : f32 -// CHECK: } -> tensor<10x20xf32, #[[$DCSR]]> -// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$DCSR]]> -// CHECK: } -func.func @sparse_abs_eltwise(%arg0: tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #DCSR> { - %0 = "mhlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #DCSR> - func.return %0 : tensor<10x20xf32, #DCSR> -} - -// CHECK-LABEL: func @sparse_add_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32, #[[$CSR]]> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]]> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$CSR]]>) { -// CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32): -// CHECK: %[[ADD:.*]] = arith.addf %[[A]], %[[B]] : f32 -// CHECK: linalg.yield %[[ADD]] : f32 -// CHECK: } -> tensor<10x20xf32, #[[$CSR]]> -// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$CSR]]> -// CHECK: } -func.func @sparse_add_eltwise(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_mul_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32, #[[$CSR]]> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]]> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$CSR]]>) { -// CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32): -// CHECK: %[[ADD:.*]] = arith.mulf %[[A]], %[[B]] : f32 -// CHECK: linalg.yield %[[ADD]] : f32 -// CHECK: } -> tensor<10x20xf32, #[[$CSR]]> -// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$CSR]]> -// CHECK: } -func.func @sparse_mul_eltwise(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.multiply %arg0, %arg1 : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_math( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xf64, #[[$ST]]>) -> tensor<10x20x30xf64, #[[$ST]]> { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.absf -// CHECK: } -// CHECK: %[[T1:.*]] = linalg.generic {{{.*}}} ins(%[[T0]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.expm1 -// CHECK: } -// CHECK: %[[T2:.*]] = linalg.generic {{{.*}}} ins(%[[T1]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.log1p -// CHECK: } -// CHECK: %[[T3:.*]] = linalg.generic {{{.*}}} ins(%[[T2]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: arith.negf -// CHECK: } -// CHECK: %[[T4:.*]] = linalg.generic {{{.*}}} ins(%[[T3]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: sparse_tensor.unary %{{.*}} : f64 to f64 -// CHECK: present = { -// CHECK: math.copysign -// CHECK: sparse_tensor.yield %{{.*}} : f64 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: } -// CHECK: %[[T5:.*]] = linalg.generic {{{.*}}} ins(%[[T4]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.sin -// CHECK: } -// CHECK: %[[T6:.*]] = linalg.generic {{{.*}}} ins(%[[T5]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.sqrt -// CHECK: } -// CHECK: %[[T7:.*]] = linalg.generic {{{.*}}} ins(%[[T6]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.tanh -// CHECK: } -// CHECK: %[[T8:.*]] = linalg.generic {{{.*}}} ins(%[[T7]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.ceil -// CHECK: } -// CHECK: %[[T9:.*]] = linalg.generic {{{.*}}} ins(%[[T8]] : tensor<10x20x30xf64, #[[$ST]]>) outs -// CHECK: math.floor -// CHECK: } -// CHECK: return %[[T9]] : tensor<10x20x30xf64, #[[$ST]]> -// CHECK: } -func.func @sparse_math(%arg0: tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> { - %0 = mhlo.abs %arg0 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %1 = mhlo.exponential_minus_one %0 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %2 = mhlo.log_plus_one %1 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %3 = mhlo.negate %2 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %4 = mhlo.sign %3 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %5 = mhlo.sine %4 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %6 = mhlo.sqrt %5 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %7 = mhlo.tanh %6 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %8 = mhlo.ceil %7 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - %9 = mhlo.floor %8 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> - func.return %9 : tensor<10x20x30xf64, #ST> -} - -// CHECK-LABEL: func @sparse_sign( -// CHECK-SAME: %[[A:.*]]: tensor<100xi32, #[[$SV]]>) -> tensor<100xi32> { -// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi32, #[[$SV]]>) -// CHECK: %[[U:.*]] = sparse_tensor.unary %{{.*}} : i32 to i32 -// CHECK: present = { -// CHECK: arith.cmpi eq -// CHECK: sparse_tensor.yield %{{.*}} : i32 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: linalg.yield %[[U]] : i32 -// CHECK: } -> tensor<100xi32> -// CHECK: return %[[T]] : tensor<100xi32> -// CHECK: } -func.func @sparse_sign(%arg0: tensor<100xi32, #SV>) -> tensor<100xi32> { - %0 = mhlo.sign %arg0 : (tensor<100xi32, #SV>) -> tensor<100xi32> - func.return %0 : tensor<100xi32> -} - -// CHECK-LABEL: func @sparse_int_abs( -// CHECK-SAME: %[[A:.*]]: tensor<100xi64, #[[$SV]]>) -> tensor<100xi64> { -// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi64, #[[$SV]]>) -// CHECK: %[[U:.*]] = sparse_tensor.unary %{{.*}} : i64 to i64 -// CHECK: present = { -// CHECK: arith.cmpi sge -// CHECK: arith.subi -// CHECK: arith.select -// CHECK: sparse_tensor.yield %{{.*}} : i64 -// CHECK: } -// CHECK: absent = { -// CHECK: } -// CHECK: linalg.yield %[[U]] : i64 -// CHECK: } -> tensor<100xi64> -// CHECK: return %[[T]] : tensor<100xi64> -// CHECK: } -func.func @sparse_int_abs(%arg0: tensor<100xi64, #SV>) -> tensor<100xi64> { - %0 = mhlo.abs %arg0 : (tensor<100xi64, #SV>) -> tensor<100xi64> - func.return %0 : tensor<100xi64> -} - -// CHECK-LABEL: func @sparse_reduce( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi64, #[[$SV]]>) -> tensor { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10xi64, #[[$SV]]>) -// CHECK: arith.addi -// CHECK: } -// CHECK: return %[[T0]] : tensor -// CHECK: } -func.func @sparse_reduce(%arg0: tensor<10xi64, #SV>) -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0] : (tensor<10xi64, #SV>, tensor) -> tensor - reducer(%arg1: tensor, %arg2: tensor) { - %2 = mhlo.add %arg1, %arg2 : tensor - "mhlo.return"(%2) : (tensor) -> () - } - func.return %1 : tensor -} - -// CHECK-LABEL: func @sparse_dot( -// CHECK-SAME: %[[ARG0:.*]]: tensor, -// CHECK-SAME: %[[ARG1:.*]]: tensor) -> tensor { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: } -// CHECK: return %[[T0]] : tensor -// CHECK: } -func.func @sparse_dot(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @sparse_transpose( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$DCSR]]> { -// CHECK: %[[T0:.*]] = bufferization.alloc_tensor() : tensor<200x100xf64, #[[$DCSR]]> -// CHECK: %[[T1:.*]] = linalg.generic {{.*}} ins(%[[ARG0]] : tensor<100x200xf64, #[[$CSR]]>) outs(%[[T0]] : tensor<200x100xf64, #[[$DCSR]]>) { -// CHECK: linalg.yield -// CHECK: } -// CHECK: return %[[T1]] : tensor<200x100xf64, #[[$DCSR]]> -// CHECK: } -func.func @sparse_transpose(%arg0: tensor<100x200xf64, #CSR>) - -> tensor<200x100xf64, #DCSR> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} - : (tensor<100x200xf64, #CSR>) -> tensor<200x100xf64, #DCSR> - func.return %0 : tensor<200x100xf64, #DCSR> -} - -// CHECK-LABEL: func @sparse_expand( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64, #[[$SV]]>) -> tensor<10x10xf64, #[[$CSR]]> { -// CHECK: %[[CST:.*]] = arith.constant dense<10> : tensor<2xi64> -// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<100xf64, #[[$SV]]>, tensor<2xi64>) -> tensor<10x10xf64, #[[$CSR]]> -// CHECK: return %[[OUT]] : tensor<10x10xf64, #[[$CSR]]> -func.func @sparse_expand(%arg0: tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> { - %0 = "mhlo.reshape"(%arg0) : (tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> - return %0 : tensor<10x10xf64, #CSR> -} - -// CHECK-LABEL: func @sparse_collapse( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> tensor<100xf64, #[[$SV]]> { -// CHECK: %[[CST:.*]] = arith.constant dense<100> : tensor<1xi64> -// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<10x10xf64, #[[$CSR]]>, tensor<1xi64>) -> tensor<100xf64, #[[$SV]]> -// CHECK: return %[[OUT]] : tensor<100xf64, #[[$SV]]> -func.func @sparse_collapse(%arg0: tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> { - %0 = "mhlo.reshape"(%arg0) : (tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> - return %0 : tensor<100xf64, #SV> -} - -// CHECK-LABEL: func @sparse_tensor_dot( -// CHECK-SAME: %[[ARG0:.*]]: tensor<197x12x64xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #[[$ST]]>) -> tensor<197x768xf32, #[[$CSR]]> { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: } -// CHECK: return %[[T0]] : tensor<197x768xf32, #[[$CSR]]> -// CHECK: } -func.func @sparse_tensor_dot(%arg0: tensor<197x12x64xf32>, - %arg1: tensor<12x64x768xf32, #ST>) -> tensor<197x768xf32, #CSR> { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<197x12x64xf32>, - tensor<12x64x768xf32, #ST>) -> tensor<197x768xf32, #CSR> - return %0 : tensor<197x768xf32, #CSR> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir deleted file mode 100644 index e5ea921f6df2a..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_ops.mlir +++ /dev/null @@ -1,330 +0,0 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -allow-unregistered-dialect | FileCheck %s - -// Tests for sparse types. Note that most dense MHLO ops can be made sparse -// by simply annotating one or more of the tensor types as sparse. Other than -// subtle printing and parsing difference (due to having different input and -// output types), dense or sparse ops are semantically equivalent. - -#SV = #sparse_tensor.encoding<{ - map = (d0) -> (d0 : compressed) -}> - -#CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) -}> - -#DCSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) -}> - -// -// Dense unary and binary eltwise. -// - -// CHECK-LABEL: func @dense_abs_eltwise( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32>) -// CHECK: %[[T:.*]] = mhlo.abs %[[A]] : tensor<10x20xf32> -// CHECK: return %[[T]] : tensor<10x20xf32> -func.func @dense_abs_eltwise(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32> { - %0 = mhlo.abs %arg0 : tensor<10x20xf32> - func.return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @dense_add_eltwise( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : tensor<10x20xf32> -// CHECK: return %[[T]] : tensor<10x20xf32> -func.func @dense_add_eltwise(%arg0: tensor<10x20xf32>, - %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<10x20xf32> - func.return %0 : tensor<10x20xf32> -} - -// -// Sparse unary eltwise. -// - -// CHECK-LABEL: func @sparse_abs_eltwise1( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.abs %[[A]] : (tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32> -// CHECK: return %[[T]] : tensor<10x20xf32> -func.func @sparse_abs_eltwise1(%arg0: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> { - %0 = mhlo.abs %arg0 : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> - func.return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @sparse_abs_eltwise2( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.abs %[[A]] : tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_abs_eltwise2(%arg0: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> { - %0 = mhlo.abs %arg0 : tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_abs_eltwise3( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.abs %[[A]] : (tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_abs_eltwise3(%arg0: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> { - %0 = mhlo.abs %arg0 : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> - func.return %0 : tensor<10x20xf32, #DCSR> -} - -// CHECK-LABEL: func @sparse_abs_eltwise4( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32>) -// CHECK: %[[T:.*]] = mhlo.abs %[[A]] : (tensor<10x20xf32>) -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_abs_eltwise4(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32, #CSR> { - %0 = mhlo.abs %arg0 : (tensor<10x20xf32>) -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_conv_eltwise1( -// CHECK-SAME: %[[A:.*]]: tensor<2x3xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.convert %[[A]] : (tensor<2x3xf32, #{{.*}}>) -> tensor<2x3xi32> -// CHECK: return %[[T]] : tensor<2x3xi32> -func.func @sparse_conv_eltwise1(%arg0: tensor<2x3xf32, #CSR>) -> tensor<2x3xi32> { - %0 = mhlo.convert %arg0 : (tensor<2x3xf32, #CSR>) -> tensor<2x3xi32> - return %0 : tensor<2x3xi32> -} - -// CHECK-LABEL: func @sparse_conv_eltwise2( -// CHECK-SAME: %[[A:.*]]: tensor<2x3xf32>) -// CHECK: %[[T:.*]] = mhlo.convert %[[A]] : (tensor<2x3xf32>) -> tensor<2x3xi32, #{{.*}}> -// CHECK: return %[[T]] : tensor<2x3xi32, #{{.*}}> -func.func @sparse_conv_eltwise2(%arg0: tensor<2x3xf32>) -> tensor<2x3xi32, #CSR> { - %0 = mhlo.convert %arg0 : (tensor<2x3xf32>) -> tensor<2x3xi32, #CSR> - return %0 : tensor<2x3xi32, #CSR> -} - -// CHECK-LABEL: func @sparse_conv_eltwise3( -// CHECK-SAME: %[[A:.*]]: tensor<2x3xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.convert %[[A]] : (tensor<2x3xf32, #{{.*}}>) -> tensor<2x3xi32, #{{.*}}> -// CHECK: return %[[T]] : tensor<2x3xi32, #{{.*}}> -func.func @sparse_conv_eltwise3(%arg0: tensor<2x3xf32, #CSR>) -> tensor<2x3xi32, #CSR> { - %0 = mhlo.convert %arg0 : (tensor<2x3xf32, #CSR>) -> tensor<2x3xi32, #CSR> - return %0 : tensor<2x3xi32, #CSR> -} - -// -// Sparse binary eltwise. -// - -// CHECK-LABEL: func @sparse_add_eltwise1( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : (tensor<10x20xf32, #{{.*}}>, tensor<10x20xf32>) -> tensor<10x20xf32> -// CHECK: return %[[T]] : tensor<10x20xf32> -func.func @sparse_add_eltwise1(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { - %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32>) -> tensor<10x20xf32> - func.return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @sparse_add_eltwise2( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : (tensor<10x20xf32, #{{.*}}>, tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32> -// CHECK: return %[[T]] : tensor<10x20xf32> -func.func @sparse_add_eltwise2(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32> { - %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> - func.return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @sparse_add_eltwise3( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : (tensor<10x20xf32, #{{.*}}>, tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_add_eltwise3(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_add_eltwise4( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_add_eltwise4(%arg0: tensor<10x20xf32>, - %arg1: tensor<10x20xf32>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32>, - tensor<10x20xf32>) -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_add_eltwise5( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.add %[[A]], %[[B]] : tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_add_eltwise5(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.add %arg0, %arg1 : tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_mul_eltwise1( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.multiply %[[A]], %[[B]] : tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_mul_eltwise1(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.multiply %arg0, %arg1 : tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// CHECK-LABEL: func @sparse_mul_eltwise2( -// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32>, -// CHECK-SAME: %[[B:.*]]: tensor<10x20xf32, #{{.*}}>) -// CHECK: %[[T:.*]] = mhlo.multiply %[[A]], %[[B]] : (tensor<10x20xf32>, tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[T]] : tensor<10x20xf32, #{{.*}}> -func.func @sparse_mul_eltwise2(%arg0: tensor<10x20xf32>, - %arg1: tensor<10x20xf32, #CSR>) - -> tensor<10x20xf32, #CSR> { - %0 = mhlo.multiply %arg0, %arg1 : (tensor<10x20xf32>, - tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> - func.return %0 : tensor<10x20xf32, #CSR> -} - -// -// Sparse dot operation. -// - -// CHECK-LABEL: func @dot1( -// CHECK-SAME: %[[A:.*]]: tensor<4xf64, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<4xf64>) -> tensor { -// CHECK: %[[T:.*]] = "mhlo.dot_general"(%[[A]], %[[B]]) {{{.*}}} : (tensor<4xf64, #{{.*}}>, tensor<4xf64>) -> tensor -// CHECK: return %[[T]] : tensor -func.func @dot1(%arg0: tensor<4xf64, #SV>, - %arg1: tensor<4xf64>) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<4xf64, #SV>, tensor<4xf64>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @dot2( -// CHECK-SAME: %[[A:.*]]: tensor<4xf64>, -// CHECK-SAME: %[[B:.*]]: tensor<4xf64, #{{.*}}>) -> tensor { -// CHECK: %[[T:.*]] = "mhlo.dot_general"(%[[A]], %[[B]]) {{{.*}}} : (tensor<4xf64>, tensor<4xf64, #{{.*}}>) -> tensor -// CHECK: return %[[T]] : tensor -func.func @dot2(%arg0: tensor<4xf64>, - %arg1: tensor<4xf64, #SV>) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<4xf64>, tensor<4xf64, #SV>) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @dot3( -// CHECK-SAME: %[[A:.*]]: tensor<4xf64, #{{.*}}>, -// CHECK-SAME: %[[B:.*]]: tensor<4xf64, #{{.*}}>) -> tensor { -// CHECK: %[[T:.*]] = "mhlo.dot_general"(%[[A]], %[[B]]) {{{.*}}} : (tensor<4xf64, #{{.*}}>, tensor<4xf64, #{{.*}}>) -> tensor -// CHECK: return %[[T]] : tensor -func.func @dot3(%arg0: tensor<4xf64, #SV>, - %arg1: tensor<4xf64, #SV>) -> tensor { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<4xf64, #SV>, tensor<4xf64, #SV>) -> tensor - func.return %0 : tensor -} - -// -// Reduce. -// - -// CHECK-LABEL: func @sparse_reduce( -// CHECK-SAME: %[[A:.*]]: tensor<10xi64, #{{.*}}>) -> tensor { -// CHECK: %[[C:.*]] = mhlo.constant dense<0> : tensor -// CHECK: %[[T:.*]] = mhlo.reduce(%[[A]] init: %[[C]]) across dimensions = [0] : (tensor<10xi64, #{{.*}}>) -> tensor -// CHECK: return %[[T]] : tensor -func.func @sparse_reduce(%arg0: tensor<10xi64, #SV>) -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0] : (tensor<10xi64, #SV>, tensor) -> tensor - reducer(%arg1: tensor, %arg2: tensor) { - %2 = mhlo.add %arg1, %arg2 : tensor - "mhlo.return"(%2) : (tensor) -> () - } - func.return %1 : tensor -} - -// -// Transpose. -// - -// CHECK-LABEL: func @sparse_transpose( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #{{.*}}>) -> tensor<100x100xf64, #{{.*}}> { -// CHECK: %[[T:.*]] = "mhlo.transpose"(%[[A]]) {{{.*}}} : (tensor<100x100xf64, #{{.*}}>) -> tensor<100x100xf64, #{{.*}}> -// CHECK: return %[[T]] : tensor<100x100xf64, #{{.*}}> -func.func @sparse_transpose(%arg0: tensor<100x100xf64, #CSR>) - -> tensor<100x100xf64, #DCSR> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} - : (tensor<100x100xf64, #CSR>) -> tensor<100x100xf64, #DCSR> - func.return %0 : tensor<100x100xf64, #DCSR> -} - -// -// Math. -// - -// CHECK-LABEL: func @sparse_zero_preserving_math( -// CHECK-SAME: %[[A:.*]]: tensor<64xf64, #{{.*}}>) -// CHECK: %[[T0:.*]] = mhlo.abs %[[A]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T1:.*]] = mhlo.exponential_minus_one %[[T0]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T2:.*]] = mhlo.log_plus_one %[[T1]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T3:.*]] = mhlo.negate %[[T2]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T4:.*]] = mhlo.sign %[[T3]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T5:.*]] = mhlo.sine %[[T4]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T6:.*]] = mhlo.sqrt %[[T5]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T7:.*]] = mhlo.tanh %[[T6]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T8:.*]] = mhlo.ceil %[[T7]] : tensor<64xf64, #{{.*}}> -// CHECK: %[[T9:.*]] = mhlo.floor %[[T8]] : tensor<64xf64, #{{.*}}> -// CHECK: return %[[T9]] : tensor<64xf64, #{{.*}}> -func.func @sparse_zero_preserving_math(%arg0: tensor<64xf64, #SV>) -> tensor<64xf64, #SV> { - %0 = mhlo.abs %arg0 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %1 = mhlo.exponential_minus_one %0 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %2 = mhlo.log_plus_one %1 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %3 = mhlo.negate %2 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %4 = mhlo.sign %3 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %5 = mhlo.sine %4 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %6 = mhlo.sqrt %5 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %7 = mhlo.tanh %6 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %8 = mhlo.ceil %7 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - %9 = mhlo.floor %8 : (tensor<64xf64, #SV>) -> tensor<64xf64, #SV> - func.return %9 : tensor<64xf64, #SV> -} - -// -// Combination of quantization and sparse. -// - -// CHECK-LABEL: func @quantization_and_sparse( -// CHECK-SAME: %[[A:.*]]: tensor<1x!quant.uniform, #{{.*}}>) -// CHECK: return %[[A]] : tensor<1x!quant.uniform, #{{.*}}> -func.func @quantization_and_sparse(%arg0: tensor<1x!quant.uniform, #SV>) - -> tensor<1x!quant.uniform, #SV> { - func.return %arg0 : tensor<1x!quant.uniform, #SV> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir deleted file mode 100644 index c0386c1de2ac9..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir +++ /dev/null @@ -1,138 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --verify-diagnostics \ -// RUN: --mhlo-sparse-rewriting | FileCheck %s - -// Verifies that mhlo sparse tensor type rewriting occurs. - -#SV= #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> - -#CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) -}> - -#DCSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) -}> - -// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> -// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> -// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> - - -// CHECK-LABEL: func @rewrite_unary( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64>) -> tensor<100xf64, #[[$SV]]> { -// CHECK: %[[VAL:.*]] = mhlo.abs %[[ARG0]] : (tensor<100xf64>) -> tensor<100xf64, #[[$SV]]> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #[[$SV]]> -func.func @rewrite_unary(%arg0: tensor<100xf64>) -> tensor<100xf64, #SV> { - %0 = mhlo.abs %arg0 : tensor<100xf64> - %1 = sparse_tensor.convert %0 : tensor<100xf64> to tensor<100xf64, #SV> - return %1 : tensor<100xf64, #SV> -} - -// CHECK-LABEL: func @rewrite_binary( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<100xf64, #[[$SV]]>) -> tensor<100xf64, #[[$SV]]> { -// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<100xf64>, tensor<100xf64, #[[$SV]]> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #[[$SV]]> -func.func @rewrite_binary(%arg0: tensor<100xf64>, - %arg1: tensor<100xf64, #SV>) -> tensor<100xf64, #SV> { - %0 = mhlo.multiply %arg0, %arg1 : (tensor<100xf64>, tensor<100xf64, #SV>) -> tensor<100xf64> - %1 = sparse_tensor.convert %0 : tensor<100xf64> to tensor<100xf64, #SV> - return %1 : tensor<100xf64, #SV> -} - -// CHECK-LABEL: func @rewrite_binary_override( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$DCSR]]> { -// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<10x10xf64, #[[$CSR]]>, tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$DCSR]]> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #[[$DCSR]]> -func.func @rewrite_binary_override(%arg0: tensor<10x10xf64, #CSR>, - %arg1: tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #DCSR> { - %0 = mhlo.multiply %arg0, %arg1 : (tensor<10x10xf64, #CSR>, tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #CSR> - %1 = sparse_tensor.convert %0 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #DCSR> - return %1 : tensor<10x10xf64, #DCSR> -} - -// CHECK-LABEL: func @rewrite_convert( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64>) -> tensor<10x10xf64, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = sparse_tensor.convert %[[ARG0]] : tensor<10x10xf64> to tensor<10x10xf64, #[[$CSR]]> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #[[$CSR]]> -func.func @rewrite_convert(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64, #CSR> { - %0 = sparse_tensor.convert %arg0 : tensor<10x10xf64> to tensor<10x10xf64, #DCSR> - %1 = sparse_tensor.convert %0 : tensor<10x10xf64, #DCSR> to tensor<10x10xf64, #CSR> - %2 = sparse_tensor.convert %1 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #CSR> - return %2 : tensor<10x10xf64, #CSR> -} - -// CHECK-LABEL: func @rewrite_convert_nop( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$CSR]]> -// CHECK-NEXT: return %[[ARG0]] : tensor<10x10xf64, #[[$CSR]]> -func.func @rewrite_convert_nop(%arg0: tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #CSR> { - %0 = sparse_tensor.convert %arg0 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #DCSR> - %1 = sparse_tensor.convert %0 : tensor<10x10xf64, #DCSR> to tensor<10x10xf64, #CSR> - %2 = sparse_tensor.convert %1 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #CSR> - return %2 : tensor<10x10xf64, #CSR> -} - -// CHECK-LABEL: func @rewrite_transpose( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #[[$CSR]]> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<200x100xf64, #[[$CSR]]> -func.func @rewrite_transpose(%arg0: tensor<100x200xf64, #CSR>) -> tensor<200x100xf64, #CSR> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #CSR>) -> tensor<200x100xf64> - %1 = sparse_tensor.convert %0 : tensor<200x100xf64> to tensor<200x100xf64, #CSR> - return %1 : tensor<200x100xf64, #CSR> -} - -// CHECK-LABEL: func.func @rewrite_dot( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf64, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[VAL]] : tensor<5x5xf64, #[[$CSR]]> -func.func @rewrite_dot(%arg0: tensor<5x5xf64, #CSR>, - %arg1: tensor<5x5xf64, #CSR>) -> tensor<5x5xf64, #CSR> { - %0 = "mhlo.dot"(%arg0, %arg1) - {precision_config = [#mhlo, - #mhlo]} - : (tensor<5x5xf64, #CSR>, - tensor<5x5xf64, #CSR>) -> tensor<5x5xf64> - %1 = sparse_tensor.convert %0 : tensor<5x5xf64> to tensor<5x5xf64, #CSR> - return %1 : tensor<5x5xf64, #CSR> -} - -// CHECK-LABEL: func.func @rewrite_general_dot( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf64, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[VAL]] : tensor<5x5xf64, #[[$CSR]]> -func.func @rewrite_general_dot(%arg0: tensor<5x5xf64, #CSR>, - %arg1: tensor<5x5xf64, #CSR>) -> tensor<5x5xf64, #CSR> { - %0 = "mhlo.dot_general"(%arg0, %arg1) - {dot_dimension_numbers = #mhlo.dot, - precision_config = [#mhlo, - #mhlo]} - : (tensor<5x5xf64, #CSR>, - tensor<5x5xf64, #CSR>) -> tensor<5x5xf64> - %1 = sparse_tensor.convert %0 : tensor<5x5xf64> to tensor<5x5xf64, #CSR> - return %1 : tensor<5x5xf64, #CSR> -} - -// CHECK-LABEL: func.func @rewrite_elt_convert( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf32, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = sparse_tensor.convert %[[ARG0]] -// CHECK: return %[[VAL]] : tensor<5x5xf32, #[[$CSR]]> -func.func @rewrite_elt_convert(%arg0: tensor<5x5xf64, #CSR>) -> tensor<5x5xf32, #CSR> { - %0 = "mhlo.convert"(%arg0) : (tensor<5x5xf64, #CSR>) -> tensor<5x5xf32, #CSR> - return %0 : tensor<5x5xf32, #CSR> -} - -// CHECK-LABEL: func.func @concatenate_sparse( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<100x100xf64, #[[$CSR]]>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<100x100xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$CSR]]> { -// CHECK: %[[VAL:.*]] = sparse_tensor.concatenate %[[ARG0]], %[[ARG1]] {dimension = 0 -// CHECK: return %[[VAL]] : tensor<200x100xf64, #[[$CSR]]> -func.func @concatenate_sparse(%arg0: tensor<100x100xf64, #CSR>, %arg1: tensor<100x100xf64, #CSR>) -> tensor<200x100xf64, #CSR> { - %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<100x100xf64, #CSR>, tensor<100x100xf64, #CSR>) -> tensor<200x100xf64, #CSR> - return %0 : tensor<200x100xf64, #CSR> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir deleted file mode 100755 index ae8dc3213c7f3..0000000000000 --- a/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: mlir-hlo-opt %s \ -// RUN: --verify-diagnostics \ -// RUN: --canonicalize | FileCheck %s - -#DCSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) -}> - -// -// Tests that ensure trivial transposes are folded, -// but the simplified code still accounts for sparsity. -// - -// CHECK-LABEL: func @transpose1( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64>) -// CHECK: return %[[A]] : tensor<100x100xf64> -func.func @transpose1(%arg0: tensor<100x100xf64>) - -> tensor<100x100xf64> { - %0 = "mhlo.transpose"(%arg0) - {permutation = dense<[0, 1]> : tensor<2xi64>} - : (tensor<100x100xf64>) -> tensor<100x100xf64> - return %0 : tensor<100x100xf64> -} - -// CHECK-LABEL: func @transpose2( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse{{[0-9]*}}>) -// CHECK: return %[[A]] : tensor<100x100xf64, #sparse{{[0-9]*}}> -func.func @transpose2(%arg0: tensor<100x100xf64, #DCSR>) - -> tensor<100x100xf64, #DCSR> { - %0 = "mhlo.transpose"(%arg0) - {permutation = dense<[0, 1]> : tensor<2xi64>} - : (tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR> - return %0 : tensor<100x100xf64, #DCSR> -} - -// CHECK-LABEL: func @transpose3( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse{{[0-9]*}}>) -// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64, #sparse{{[0-9]*}}>) -> tensor<100x100xf64> -// CHECK: return %[[R]] : tensor<100x100xf64> -func.func @transpose3(%arg0: tensor<100x100xf64, #DCSR>) - -> tensor<100x100xf64> { - %0 = "mhlo.transpose"(%arg0) - {permutation = dense<[0, 1]> : tensor<2xi64>} - : (tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64> - return %0 : tensor<100x100xf64> -} - -// CHECK-LABEL: func @transpose4( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64>) -// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64>) -> tensor<100x100xf64, #sparse{{[0-9]*}}> -// CHECK: return %[[R]] : tensor<100x100xf64, #sparse{{[0-9]*}}> -func.func @transpose4(%arg0: tensor<100x100xf64>) - -> tensor<100x100xf64, #DCSR> { - %0 = "mhlo.transpose"(%arg0) - {permutation = dense<[0, 1]> : tensor<2xi64>} - : (tensor<100x100xf64>) -> tensor<100x100xf64, #DCSR> - return %0 : tensor<100x100xf64, #DCSR> -} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index e9ce4e3ce68b2..8c8c890dbf87a 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -158,7 +158,7 @@ func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomple %0 = "stablehlo.fft"(%arg0) { // CHECK: fft_type = #mhlo fft_type = #stablehlo, - fft_length = dense<16> : tensor<1xi64> + fft_length = array } : (tensor<16xcomplex>) -> tensor<16xcomplex> func.return %0 : tensor<16xcomplex> } @@ -168,27 +168,27 @@ func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcompl %0 = "stablehlo.fft"(%arg0) { // CHECK: fft_type = #mhlo fft_type = #stablehlo, - fft_length = dense<16> : tensor<1xi64> + fft_length = array } : (tensor<16xcomplex>) -> tensor<16xcomplex> func.return %0 : tensor<16xcomplex> } // CHECK-LABEL: "attr_fft_type_rfft" func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { - %0 = "mhlo.fft"(%arg0) { + %0 = "stablehlo.fft"(%arg0) { // CHECK: fft_type = #mhlo - fft_type = #mhlo, - fft_length = dense<16> : tensor<1xi64> + fft_type = #stablehlo, + fft_length = array } : (tensor<16xf32>) -> tensor<9xcomplex> func.return %0 : tensor<9xcomplex> } // CHECK-LABEL: "attr_fft_type_irfft" func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { - %0 = "mhlo.fft"(%arg0) { + %0 = "stablehlo.fft"(%arg0) { // CHECK: fft_type = #mhlo - fft_type = #mhlo, - fft_length = dense<16> : tensor<1xi64> + fft_type = #stablehlo, + fft_length = array } : (tensor<9xcomplex>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } @@ -250,20 +250,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #mhlo.rng_distribution rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #mhlo.rng_distribution rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -341,12 +341,12 @@ func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { // CHECK-LABEL: "op_all_gather" func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { - // CHECK: "mhlo.all_gather"(%arg0) { + // CHECK: "mhlo.all_gather"(%arg0) <{ // CHECK-SAME: all_gather_dim = 1 : i64, // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, // CHECK-SAME: use_global_device_ids - // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<16x16xf32> + // CHECK-SAME: }> : (tensor<16x8xf32>) -> tensor<16x16xf32> %0 = "stablehlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, @@ -358,15 +358,15 @@ func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { // CHECK-LABEL: "op_all_reduce" func.func @op_all_reduce(%arg0: tensor) -> tensor { - // CHECK: "mhlo.all_reduce"(%arg0) ({ - // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { + // CHECK: "mhlo.all_reduce"(%arg0) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, // CHECK-SAME: use_global_device_ids - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor) -> tensor %0 = "stablehlo.all_reduce"(%arg0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor @@ -381,11 +381,11 @@ func.func @op_all_reduce(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_all_reduce_tuple" func.func @op_all_reduce_tuple(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<8xf32>, tensor) { - // CHECK: "mhlo.all_reduce"(%[[ARG0:.*]], %[[ARG1:.*]]) ({ + // CHECK: "mhlo.all_reduce"(%[[ARG0:.*]], %[[ARG1:.*]]) <{replica_groups = dense<> : tensor<0x0xi64>}> ({ // CHECK-NEXT: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): // CHECK-NEXT: %[[ADD:.*]] = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor // CHECK-NEXT: "mhlo.return"(%[[ADD]]) : (tensor) -> () - // CHECK-NEXT: }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) + // CHECK-NEXT: }) : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) %0:2 = stablehlo.custom_call @mhlo.all_reduce(%arg0, %arg1) {called_computations = [@all_reduce0], mhlo.attributes = {replica_groups = dense<> : tensor<0x0xi64>}} : (tensor<8xf32>, tensor) -> (tensor<8xf32>, tensor) return %0#0, %0#1 : tensor<8xf32>, tensor } @@ -396,13 +396,13 @@ func.func @all_reduce0(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_all_to_all" func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { - // CHECK: "mhlo.all_to_all"(%arg0) { + // CHECK: "mhlo.all_to_all"(%arg0) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME: concat_dimension = 0 : i64, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, // CHECK-SAME: split_count = 4 : i64, // CHECK-SAME: split_dimension = 1 : i64 - // CHECK-SAME: } : (tensor<4x16xf32>) -> tensor<16x4xf32> + // CHECK-SAME: }> : (tensor<4x16xf32>) -> tensor<16x4xf32> %0 = "stablehlo.all_to_all"(%arg0) { split_dimension = 1 : i64, concat_dimension = 0 : i64, @@ -429,10 +429,10 @@ func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_batch_norm_grad" func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { - // CHECK: "mhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + // CHECK: "mhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ // CHECK-SAME: epsilon = 1.000000e-03 : f32, // CHECK-SAME: feature_index = 0 : i64 - // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + // CHECK-SAME: }> : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { epsilon = 0.001 : f32, feature_index = 0 : i64 @@ -442,10 +442,10 @@ func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf // CHECK-LABEL: "op_batch_norm_inference" func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { - // CHECK: "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + // CHECK: "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ // CHECK-SAME: epsilon = 1.000000e-03 : f32, // CHECK-SAME: feature_index = 0 : i64 - // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + // CHECK-SAME: }> : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { epsilon = 0.001 : f32, feature_index = 0 : i64 @@ -455,10 +455,10 @@ func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor // CHECK-LABEL: "op_batch_norm_training" func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { - // CHECK: "mhlo.batch_norm_training"(%arg0, %arg1, %arg2) { + // CHECK: "mhlo.batch_norm_training"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: epsilon = 1.000000e-03 : f32, // CHECK-SAME: feature_index = 0 : i64 - // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + // CHECK-SAME: }> : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { epsilon = 0.001 : f32, feature_index = 0 : i64 @@ -475,22 +475,22 @@ func.func @op_bitcast_convert(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_broadcast_in_dim" func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { - // CHECK: "mhlo.broadcast_in_dim"(%arg0) { + // CHECK: "mhlo.broadcast_in_dim"(%arg0) <{ // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> + // CHECK-SAME: }> : (tensor<16xf32>) -> tensor<16x16xf32> %0 = "stablehlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<1> : tensor<1xi64> + broadcast_dimensions = array } : (tensor<16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } // CHECK-LABEL: "op_broadcast" func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { - // CHECK: "mhlo.broadcast"(%arg0) { + // CHECK: "mhlo.broadcast"(%arg0) <{ // CHECK-SAME: broadcast_sizes = dense<16> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> + // CHECK-SAME: }> : (tensor<16xf32>) -> tensor<16x16xf32> %0 = "stablehlo.broadcast"(%arg0) { - broadcast_sizes = dense<16> : tensor<1xi64> + broadcast_sizes = array } : (tensor<16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } @@ -522,9 +522,9 @@ func.func @op_ceil(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_cholesky" func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { - // CHECK: "mhlo.cholesky"(%arg0) { + // CHECK: "mhlo.cholesky"(%arg0) <{ // CHECK-SAME: lower = true - // CHECK-SAME: } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + // CHECK-SAME: }> : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> %0 = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> @@ -545,12 +545,25 @@ func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_collective_broadcast" +func.func @op_collective_broadcast(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> { + // CHECK: "mhlo.collective_broadcast"(%arg0) <{ + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // CHECK-SAME: }> : (tensor<1x2xi64>) -> tensor<1x2xi64> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> +} + // CHECK-LABEL: "op_collective_permute" func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { - // CHECK: "mhlo.collective_permute"(%arg0) { + // CHECK: "mhlo.collective_permute"(%arg0) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> - // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<16x8xf32> + // CHECK-SAME: }> : (tensor<16x8xf32>) -> tensor<16x8xf32> %0 = "stablehlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_handle = #stablehlo.channel_handle @@ -560,10 +573,10 @@ func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { // CHECK-LABEL: "op_compare" func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "mhlo.compare"(%arg0, %arg1) { + // CHECK: "mhlo.compare"(%arg0, %arg1) <{ // CHECK-SAME: compare_type = #mhlo, // CHECK-SAME: comparison_direction = #mhlo - // CHECK-SAME: } : (tensor, tensor) -> tensor + // CHECK-SAME: }> : (tensor, tensor) -> tensor %0 = "stablehlo.compare"(%arg0, %arg1) { comparison_direction = #stablehlo, compare_type = #stablehlo @@ -578,6 +591,22 @@ func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> } +// CHECK-LABEL: "op_composite" +func.func @op_composite(%arg0 : tensor) -> tensor { + // CHECK: "mhlo.composite"(%arg0) <{composite_attributes = {n = 2 : i64}, decomposition = @add_n.impl, name = "stablehlo.add_n"}> : (tensor) -> tensor + %0 = stablehlo.composite "stablehlo.add_n" %arg0 { + composite_attributes = { n = 2 : i64 }, + decomposition = @add_n.impl + } : (tensor) -> tensor + func.return %0 : tensor +} + +func.func @add_n.impl(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2> : tensor + %1 = stablehlo.add %arg0, %0 : tensor + func.return %1 : tensor +} + // CHECK-LABEL: "op_compute_reshape_shape" func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { // CHECK: "mhlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> @@ -587,9 +616,9 @@ func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> ten // CHECK-LABEL: "op_concatenate" func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { - // CHECK: "mhlo.concatenate"(%arg0, %arg1) { + // CHECK: "mhlo.concatenate"(%arg0, %arg1) <{ // CHECK-SAME: dimension = 0 : i64 - // CHECK-SAME: } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + // CHECK-SAME: }> : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> @@ -598,9 +627,9 @@ func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor< // CHECK-LABEL: "op_constant" func.func @op_constant(%arg0: tensor) -> tensor { - // CHECK: "mhlo.constant"() { + // CHECK: "mhlo.constant"() <{ // CHECK-SAME: value = dense<0.000000e+00> : tensor - // CHECK-SAME: } : () -> tensor + // CHECK-SAME: }> : () -> tensor %0 = "stablehlo.constant"() { value = dense<0.0> : tensor } : () -> tensor @@ -616,7 +645,7 @@ func.func @op_convert(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_convolution" func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // CHECK: "mhlo.convolution"(%arg0, %arg1) { + // CHECK: "mhlo.convolution"(%arg0, %arg1) <{ // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, @@ -626,13 +655,13 @@ func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16 // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + // CHECK-SAME: }> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> %0 = "stablehlo.convolution"(%arg0, %arg1) { - window_strides = dense<1> : tensor<2xi64>, + window_strides = array, padding = dense<1> : tensor<2x2xi64>, - lhs_dilation = dense<1> : tensor<2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_reversal = dense : tensor<2xi1>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, @@ -657,9 +686,9 @@ func.func @op_create_token() -> !stablehlo.token { // CHECK-LABEL: "op_cross_replica_sum" func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { - // CHECK: "mhlo.cross-replica-sum"(%arg0) { + // CHECK: "mhlo.cross-replica-sum"(%arg0) <{ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> : (tensor) -> tensor %0 = "stablehlo.cross-replica-sum"(%arg0) { replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> } : (tensor) -> tensor @@ -676,7 +705,7 @@ func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.wi // CHECK-LABEL: "op_custom_call_api_version_original" func.func @called_computation() { func.return } func.func @op_custom_call_api_version_original(%arg0: tensor) -> tensor { - // CHECK: "mhlo.custom_call"(%arg0) { + // CHECK: "mhlo.custom_call"(%arg0) <{ // CHECK-SAME: api_version = 1 : i32, // CHECK-SAME: backend_config = "", // CHECK-SAME: call_target_name = "foo", @@ -689,7 +718,7 @@ func.func @op_custom_call_api_version_original(%arg0: tensor) -> tensor] // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> : (tensor) -> tensor %0 = "stablehlo.custom_call"(%arg0) { call_target_name = "foo", has_side_effect = false, @@ -708,11 +737,11 @@ func.func @op_custom_call_api_version_original(%arg0: tensor) -> tensor) -> tensor { - // CHECK: "mhlo.custom_call"(%arg0) { + // CHECK: "mhlo.custom_call"(%arg0) <{ // CHECK-SAME: api_version = 4 : i32, // CHECK-SAME: backend_config = {foo = "bar"}, // CHECK-SAME: call_target_name = "foo" - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> : (tensor) -> tensor %0 = "stablehlo.custom_call"(%arg0) { call_target_name = "mhlo.custom_call", mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"}, @@ -723,11 +752,11 @@ func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor) -> tensor<16x4xbf16> { - // CHECK: "mhlo.custom_call"(%arg0) { + // CHECK: "mhlo.custom_call"(%arg0) <{ // CHECK-SAME: api_version = 4 : i32, // CHECK-SAME: backend_config = {aggregate_to_topk = true}, // CHECK-SAME: call_target_name = "foo" - // CHECK-SAME: } : (tensor<16x256xbf16>) -> tensor<16x4xbf16> + // CHECK-SAME: }> : (tensor<16x256xbf16>) -> tensor<16x4xbf16> %4 = stablehlo.custom_call @foo(%arg0) { "mhlo.backend_config" = {aggregate_to_topk = true} } : (tensor<16x256xbf16>) -> tensor<16x4xbf16> @@ -743,7 +772,7 @@ func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_dot_general" func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { - // CHECK: "mhlo.dot_general"(%arg0, %arg1) { + // CHECK: "mhlo.dot_general"(%arg0, %arg1) <{ // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< // CHECK-SAME: lhs_batching_dimensions = [0], // CHECK-SAME: rhs_batching_dimensions = [0], @@ -751,7 +780,7 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) // CHECK-SAME: rhs_contracting_dimensions = [1] // CHECK-SAME: >, // CHECK-SAME: precision_config = [] - // CHECK-SAME: } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + // CHECK-SAME: }> : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -766,9 +795,9 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) // CHECK-LABEL: "op_dot" func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { - // CHECK: "mhlo.dot"(%arg0, %arg1) { + // CHECK: "mhlo.dot"(%arg0, %arg1) <{ // CHECK-SAME: precision_config = [] - // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + // CHECK-SAME: }> : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> %0 = "stablehlo.dot"(%arg0, %arg1) { precision_config = [] } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> @@ -777,22 +806,22 @@ func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x // CHECK-LABEL: "op_dynamic_broadcast_in_dim" func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64>, // CHECK-SAME: known_expanding_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: known_nonexpanding_dimensions = dense<0> : tensor<1xi64> - // CHECK-SAME: } : (tensor, tensor<2xindex>) -> tensor + // CHECK-SAME: }> : (tensor, tensor<2xindex>) -> tensor %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = dense<1> : tensor<1xi64>, - known_expanding_dimensions = dense<[]> : tensor<0xi64>, - known_nonexpanding_dimensions = dense<0> : tensor<1xi64> + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array } : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_dynamic_conv" func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { - // CHECK: "mhlo.dynamic_conv"(%arg0, %arg1, %arg2) { + // CHECK: "mhlo.dynamic_conv"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: batch_group_count = 1 : i64, // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, @@ -802,13 +831,13 @@ func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x1 // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + // CHECK-SAME: }> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { - window_strides = dense<1> : tensor<2xi64>, + window_strides = array, padding = dense<1> : tensor<2x2xi64>, - lhs_dilation = dense<1> : tensor<2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_reversal = dense : tensor<2xi1>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, @@ -819,7 +848,7 @@ func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x1 // CHECK-LABEL: "op_dynamic_gather" func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { - // CHECK: "mhlo.dynamic_gather"(%arg0, %arg1, %arg2) { + // CHECK: "mhlo.dynamic_gather"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [2], // CHECK-SAME: collapsed_slice_dims = [0, 1], @@ -827,7 +856,7 @@ func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32 // CHECK-SAME: index_vector_dim = 2 // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + // CHECK-SAME: }> : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { dimension_numbers = #stablehlo.gather< offset_dims = [2], @@ -842,9 +871,9 @@ func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32 // CHECK-LABEL: "op_dynamic_iota" func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { - // CHECK: "mhlo.dynamic_iota"(%arg0) { + // CHECK: "mhlo.dynamic_iota"(%arg0) <{ // CHECK-SAME: iota_dimension = 0 : i64 - // CHECK-SAME: } : (tensor<1xindex>) -> tensor + // CHECK-SAME: }> : (tensor<1xindex>) -> tensor %0 = "stablehlo.dynamic_iota"(%arg0) { iota_dimension = 0 : i64 } : (tensor<1xindex>) -> tensor @@ -859,19 +888,19 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_dynamic_slice" func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { - // CHECK: "mhlo.dynamic_slice"(%arg0, %arg1) { + // CHECK: "mhlo.dynamic_slice"(%arg0, %arg1) <{ // CHECK-SAME: slice_sizes = dense<4> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>, tensor) -> tensor<4xf32> + // CHECK-SAME: }> : (tensor<16xf32>, tensor) -> tensor<4xf32> %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { - slice_sizes = dense<4> : tensor<1xi64> + slice_sizes = array } : (tensor<16xf32>, tensor) -> tensor<4xf32> func.return %0 : tensor<4xf32> } @@ -885,9 +914,9 @@ func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, // CHECK-LABEL: "op_einsum" func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { - // CHECK: "mhlo.einsum"(%arg0, %arg1) { + // CHECK: "mhlo.einsum"(%arg0, %arg1) <{ // CHECK-SAME: einsum_config = "ab,bc->ac" - // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + // CHECK-SAME: }> : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> %0 = "stablehlo.einsum"(%arg0, %arg1) { einsum_config = "ab,bc->ac" } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> @@ -910,13 +939,13 @@ func.func @op_exponential(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_fft" func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { - // CHECK: "mhlo.fft"(%arg0) { + // CHECK: "mhlo.fft"(%arg0) <{ // CHECK-SAME: fft_length = dense<16> : tensor<1xi64>, // CHECK-SAME: fft_type = #mhlo - // CHECK-SAME: } : (tensor<16xcomplex>) -> tensor<16xcomplex> + // CHECK-SAME: }> : (tensor<16xcomplex>) -> tensor<16xcomplex> %0 = "stablehlo.fft"(%arg0) { fft_type = #stablehlo, - fft_length = dense<16> : tensor<1xi64> + fft_length = array } : (tensor<16xcomplex>) -> tensor<16xcomplex> func.return %0 : tensor<16xcomplex> } @@ -930,7 +959,7 @@ func.func @op_floor(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_gather" func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { - // CHECK: "mhlo.gather"(%arg0, %arg1) { + // CHECK: "mhlo.gather"(%arg0, %arg1) <{ // CHECK-SAME: dimension_numbers = #mhlo.gather< // CHECK-SAME: offset_dims = [2], // CHECK-SAME: collapsed_slice_dims = [0, 1], @@ -939,7 +968,7 @@ func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> te // CHECK-SAME: >, // CHECK-SAME: indices_are_sorted = false, // CHECK-SAME: slice_sizes = dense<1> : tensor<3xi64> - // CHECK-SAME: } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + // CHECK-SAME: }> : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [2], @@ -947,7 +976,7 @@ func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> te start_index_map = [0, 1], index_vector_dim = 2 >, - slice_sizes = dense<1> : tensor<3xi64>, + slice_sizes = array, indices_are_sorted = false } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> func.return %0 : tensor<1x5x1xf32> @@ -955,9 +984,9 @@ func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> te // CHECK-LABEL: "op_get_dimension_size" func.func @op_get_dimension_size(%arg0: tensor) -> tensor { - // CHECK: "mhlo.get_dimension_size"(%arg0) { + // CHECK: "mhlo.get_dimension_size"(%arg0) <{ // CHECK-SAME: dimension = 0 : i64 - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> : (tensor) -> tensor %0 = "stablehlo.get_dimension_size"(%arg0) { dimension = 0 : i64 } : (tensor) -> tensor @@ -966,9 +995,9 @@ func.func @op_get_dimension_size(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_get_tuple_element" func.func @op_get_tuple_element(%arg0: tuple, tensor, tensor, tensor, tensor>) -> tensor { - // CHECK: "mhlo.get_tuple_element"(%arg0) { + // CHECK: "mhlo.get_tuple_element"(%arg0) <{ // CHECK-SAME: index = 4 : i32 - // CHECK-SAME: } : (tuple, tensor, tensor, tensor, tensor>) -> tensor + // CHECK-SAME: }> : (tuple, tensor, tensor, tensor, tensor>) -> tensor %0 = "stablehlo.get_tuple_element"(%arg0) { index = 4 : i32 } : (tuple, tensor, tensor, tensor, tensor>) -> tensor @@ -999,10 +1028,10 @@ func.func @op_imag(%arg0: tensor>) -> tensor { // CHECK-LABEL: "op_infeed" func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { - // CHECK: "mhlo.infeed"(%arg0) { + // CHECK: "mhlo.infeed"(%arg0) <{ // CHECK-SAME: infeed_config = "", // CHECK-SAME{LITERAL}: layout = [[]] - // CHECK-SAME: } : (!mhlo.token) -> (tensor, !mhlo.token) + // CHECK-SAME: }> : (!mhlo.token) -> (tensor, !mhlo.token) %0:2 = "stablehlo.infeed"(%arg0) { infeed_config = "", layout = [[]] @@ -1012,9 +1041,9 @@ func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) // CHECK-LABEL: "op_iota" func.func @op_iota() -> tensor<16xf32> { - // CHECK: "mhlo.iota"() { + // CHECK: "mhlo.iota"() <{ // CHECK-SAME: iota_dimension = 0 : i64 - // CHECK-SAME: } : () -> tensor<16xf32> + // CHECK-SAME: }> : () -> tensor<16xf32> %0 = "stablehlo.iota"() { iota_dimension = 0 : i64 } : () -> tensor<16xf32> @@ -1051,19 +1080,19 @@ func.func @op_logistic(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_map" func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { - // CHECK: "mhlo.map"(%arg0) ({ + // CHECK: "mhlo.map"(%arg0) <{ + // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: }> ({ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor): // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.abs"(%[[ARG1]]) : (tensor) -> tensor // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { - // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: }) : (tensor<16xf32>) -> tensor<16xf32> %0 = "stablehlo.map"(%arg0) ({ ^bb0(%arg1: tensor): %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor "stablehlo.return"(%1) : (tensor) -> () }) { - dimensions = dense<0> : tensor<1xi64> + dimensions = array } : (tensor<16xf32>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } @@ -1119,9 +1148,9 @@ func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_outfeed" func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { - // CHECK: "mhlo.outfeed"(%arg0, %arg1) { + // CHECK: "mhlo.outfeed"(%arg0, %arg1) <{ // CHECK-SAME: outfeed_config = "" - // CHECK-SAME: } : (tensor, !mhlo.token) -> !mhlo.token + // CHECK-SAME: }> : (tensor, !mhlo.token) -> !mhlo.token %0 = "stablehlo.outfeed"(%arg0, %arg1) { outfeed_config = "" } : (tensor, !stablehlo.token) -> !stablehlo.token @@ -1130,15 +1159,15 @@ func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo // CHECK-LABEL: "op_pad" func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { - // CHECK: "mhlo.pad"(%arg0, %arg1) { + // CHECK: "mhlo.pad"(%arg0, %arg1) <{ // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>, // CHECK-SAME: edge_padding_low = dense<4> : tensor<1xi64>, // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> - // CHECK-SAME: } : (tensor<8xf32>, tensor) -> tensor<16xf32> + // CHECK-SAME: }> : (tensor<8xf32>, tensor) -> tensor<16xf32> %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = dense<4> : tensor<1xi64>, - edge_padding_low = dense<4> : tensor<1xi64>, - interior_padding = dense<0> : tensor<1xi64> + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array } : (tensor<8xf32>, tensor) -> tensor<16xf32> func.return %0 : tensor<16xf32> } @@ -1180,10 +1209,10 @@ func.func @op_real(%arg0: tensor>) -> tensor { // CHECK-LABEL: "op_recv" func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { - // CHECK: "mhlo.recv"(%arg0) { + // CHECK: "mhlo.recv"(%arg0) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: } : (!mhlo.token) -> (tensor, !mhlo.token) + // CHECK-SAME: }> : (!mhlo.token) -> (tensor, !mhlo.token) %0:2 = "stablehlo.recv"(%arg0) { channel_handle = #stablehlo.channel_handle, is_host_transfer = true @@ -1198,17 +1227,17 @@ func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor "stablehlo.return"(%1) : (tensor) -> () }) { - dimensions = dense<0> : tensor<1xi64> + dimensions = array } : (tensor<16xf32>, tensor) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_reduce_precision" func.func @op_reduce_precision(%arg0: tensor) -> tensor { - // CHECK: "mhlo.reduce_precision"(%arg0) { + // CHECK: "mhlo.reduce_precision"(%arg0) <{ // CHECK-SAME: exponent_bits = 8 : i32, // CHECK-SAME: mantissa_bits = 10 : i32 - // CHECK-SAME: } : (tensor) -> tensor + // CHECK-SAME: }> : (tensor) -> tensor %0 = "stablehlo.reduce_precision"(%arg0) { exponent_bits = 8 : i32, mantissa_bits = 10 : i32 @@ -1218,15 +1247,15 @@ func.func @op_reduce_precision(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_reduce_scatter" func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { - // CHECK: "mhlo.reduce_scatter"(%arg0) ({ - // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { + // CHECK: "mhlo.reduce_scatter"(%arg0) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, // CHECK-SAME: scatter_dimension = 0 : i64 - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor<16xf32>) -> tensor<16xf32> %0 = "stablehlo.reduce_scatter"(%arg0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor @@ -1241,26 +1270,26 @@ func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { // CHECK-LABEL: "op_reduce_window" func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x5x8x7xf32> { - // CHECK: "mhlo.reduce_window"(%arg0, %arg1) ({ - // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.maximum"(%[[ARG2]], %[[ARG3]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { + // CHECK: "mhlo.reduce_window"(%arg0, %arg1) <{ // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>, // CHECK-SAME{LITERAL}: padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, // CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64> - // CHECK-SAME: } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.maximum"(%[[ARG2]], %[[ARG3]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor "stablehlo.return"(%1) : (tensor) -> () }) { - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, - base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>, - window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array, padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> func.return %0 : tensor<2x5x8x7xf32> @@ -1300,20 +1329,20 @@ func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: "op_reverse" func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { - // CHECK: "mhlo.reverse"(%arg0) { + // CHECK: "mhlo.reverse"(%arg0) <{ // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-SAME: }> : (tensor<16xf32>) -> tensor<16xf32> %0 = "stablehlo.reverse"(%arg0) { - dimensions = dense<0> : tensor<1xi64> + dimensions = array } : (tensor<16xf32>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } // CHECK-LABEL: "op_rng_bit_generator" func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { - // CHECK: "mhlo.rng_bit_generator"(%arg0) { + // CHECK: "mhlo.rng_bit_generator"(%arg0) <{ // CHECK-SAME: rng_algorithm = #mhlo.rng_algorithm - // CHECK-SAME: } : (tensor) -> (tensor, tensor) + // CHECK-SAME: }> : (tensor) -> (tensor, tensor) %0:2 = "stablehlo.rng_bit_generator"(%arg0) { rng_algorithm = #stablehlo } : (tensor) -> (tensor, tensor) @@ -1321,13 +1350,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: "mhlo.rng"(%arg0, %arg1, %arg2) { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + // CHECK: "mhlo.rng"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #mhlo.rng_distribution - // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor + // CHECK-SAME: }> : (tensor, tensor, tensor<0xindex>) -> tensor %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1354,11 +1383,7 @@ func.func @op_rsqrt(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_scatter" func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG3]], %[[ARG4]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { + // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: indices_are_sorted = true, // CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter< // CHECK-SAME: update_window_dims = [1], @@ -1367,7 +1392,11 @@ func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, % // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: >, // CHECK-SAME: unique_indices = true - // CHECK-SAME: } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.add"(%[[ARG3]], %[[ARG4]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor @@ -1387,19 +1416,19 @@ func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, % // CHECK-LABEL: "op_select_and_scatter" func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { - // CHECK: "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + // CHECK: "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: padding = dense<0> : tensor<4x2xi64>, + // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + // CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME: }> ({ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: tensor, %[[ARG41:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL11:.*]] = "mhlo.compare"(%[[ARG31]], %[[ARG41]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[VAL11:.*]] = "mhlo.compare"(%[[ARG31]], %[[ARG41]]) <{compare_type = #mhlo, comparison_direction = #mhlo}> : (tensor, tensor) -> tensor // CHECK-NEXT: "mhlo.return"(%[[VAL11]]) : (tensor) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: tensor, %[[ARG42:arg.*]]: tensor): // CHECK-NEXT: %[[VAL12:.*]] = "mhlo.add"(%[[ARG32]], %[[ARG42]]) : (tensor, tensor) -> tensor // CHECK-NEXT: "mhlo.return"(%[[VAL12]]) : (tensor) -> () - // CHECK-NEXT: }) { - // CHECK-SAME: padding = dense<0> : tensor<4x2xi64>, - // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - // CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> - // CHECK-SAME: } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK-NEXT: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor @@ -1409,8 +1438,8 @@ func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<1 %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor "stablehlo.return"(%1) : (tensor) -> () }) { - window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_dimensions = array, + window_strides = array, padding = dense<0> : tensor<4x2xi64> } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> func.return %0 : tensor<10x24x24x64xf32> @@ -1425,10 +1454,10 @@ func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) // CHECK-LABEL: "op_send" func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { - // CHECK: "mhlo.send"(%arg0, %arg1) { + // CHECK: "mhlo.send"(%arg0, %arg1) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: } : (tensor, !mhlo.token) -> !mhlo.token + // CHECK-SAME: }> : (tensor, !mhlo.token) -> !mhlo.token %0 = "stablehlo.send"(%arg0, %arg1) { channel_handle = #stablehlo.channel_handle, is_host_transfer = true @@ -1438,9 +1467,9 @@ func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.to // CHECK-LABEL: "op_set_dimension_size" func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { - // CHECK: "mhlo.set_dimension_size"(%arg0, %arg1) { + // CHECK: "mhlo.set_dimension_size"(%arg0, %arg1) <{ // CHECK-SAME: dimension = 0 : i64 - // CHECK-SAME: } : (tensor, tensor) -> tensor<16xf32> + // CHECK-SAME: }> : (tensor, tensor) -> tensor<16xf32> %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<16xf32> @@ -1484,29 +1513,29 @@ func.func @op_sine(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_slice" func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { - // CHECK: "mhlo.slice"(%arg0) { + // CHECK: "mhlo.slice"(%arg0) <{ // CHECK-SAME: limit_indices = dense<4> : tensor<1xi64>, // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-SAME: strides = dense<1> : tensor<1xi64> - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<4xf32> + // CHECK-SAME: }> : (tensor<16xf32>) -> tensor<4xf32> %0 = "stablehlo.slice"(%arg0) { - start_indices = dense<0> : tensor<1xi64>, - limit_indices = dense<4> : tensor<1xi64>, - strides = dense<1> : tensor<1xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<16xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> } // CHECK-LABEL: "op_sort" func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { - // CHECK: "mhlo.sort"(%arg0) ({ - // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): - // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG2]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () - // CHECK-NEXT: }) { + // CHECK: "mhlo.sort"(%arg0) <{ // CHECK-SAME: dimension = 0 : i64, // CHECK-SAME: is_stable = true - // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG2]]) <{compare_type = #mhlo, comparison_direction = #mhlo}> : (tensor, tensor) -> tensor + // CHECK-NEXT: "mhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor<16xf32>) -> tensor<16xf32> %0 = "stablehlo.sort"(%arg0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor @@ -1553,10 +1582,10 @@ func.func @op_tanh(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_torch_index_select" func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { - // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) <{ // CHECK-SAME: batch_dims = 0 : i64, // CHECK-SAME: dim = 0 : i64 - // CHECK-SAME: } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + // CHECK-SAME: }> : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { dim = 0 : i64, batch_dims = 0 : i64 @@ -1566,9 +1595,9 @@ func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) // CHECK-LABEL: "op_trace" func.func @op_trace(%arg0: tensor) { - // CHECK: "mhlo.trace"(%arg0) { + // CHECK: "mhlo.trace"(%arg0) <{ // CHECK-SAME: tag = "foo" - // CHECK-SAME: } : (tensor) -> () + // CHECK-SAME: }> : (tensor) -> () "stablehlo.trace"(%arg0) { tag = "foo" } : (tensor) -> () @@ -1577,23 +1606,23 @@ func.func @op_trace(%arg0: tensor) { // CHECK-LABEL: "op_transpose" func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { - // CHECK: "mhlo.transpose"(%arg0) { + // CHECK: "mhlo.transpose"(%arg0) <{ // CHECK-SAME: permutation = dense<[1, 0]> : tensor<2xi64> - // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<8x16xf32> + // CHECK-SAME: }> : (tensor<16x8xf32>) -> tensor<8x16xf32> %0 = "stablehlo.transpose"(%arg0) { - permutation = dense<[1, 0]> : tensor<2xi64> + permutation = array } : (tensor<16x8xf32>) -> tensor<8x16xf32> func.return %0 : tensor<8x16xf32> } // CHECK-LABEL: "op_triangular_solve" func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { - // CHECK: "mhlo.triangular_solve"(%arg0, %arg1) { + // CHECK: "mhlo.triangular_solve"(%arg0, %arg1) <{ // CHECK-SAME: left_side = true, // CHECK-SAME: lower = true, // CHECK-SAME: transpose_a = #mhlo, // CHECK-SAME: unit_diagonal = true - // CHECK-SAME: } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + // CHECK-SAME: }> : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { left_side = true, lower = true, @@ -1612,9 +1641,9 @@ func.func @op_tuple(%arg0: tensor) -> tuple> { // CHECK-LABEL: "op_unary_einsum" func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { - // CHECK: "mhlo.unary_einsum"(%arg0) { + // CHECK: "mhlo.unary_einsum"(%arg0) <{ // CHECK-SAME: einsum_config = "ab->a" - // CHECK-SAME: } : (tensor<8x16xf32>) -> tensor<8xf32> + // CHECK-SAME: }> : (tensor<8x16xf32>) -> tensor<8xf32> %0 = "stablehlo.unary_einsum"(%arg0) { einsum_config = "ab->a" } : (tensor<8x16xf32>) -> tensor<8xf32> @@ -1824,18 +1853,11 @@ func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { func.return %0 : tensor } -// CHECK-LABEL: "type_dynamism_unranked" -func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - %0 = "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - // CHECK-LABEL: "type_quantization" -func.func @type_quantization(%arg0: tensor>, %arg1: tensor) -> tensor { - // CHECK: "mhlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor - func.return %0 : tensor +func.func @type_quantization(%arg0: tensor>) -> tensor> { + // CHECK: "mhlo.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> + %0 = "stablehlo.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> } // ----- @@ -1942,3 +1964,17 @@ func.func @op_custom_call_botched_mhlo_backend_config_version(%arg0: tensor } : (tensor) -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: "op_topk_mhlo_v1" +func.func @op_topk_mhlo_v1(%arg0: tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) { + // CHECK: "mhlo.topk"(%arg0) <{k = 8 : i64, largest = true}> : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + %0:2 = "stablehlo.custom_call"(%arg0) { + backend_config = "", + call_target_name = "mhlo.topk", + mhlo.attributes = {k = 8 : i64, largest = true}, + mhlo.version = 1 : i64 + } : (tensor<5x10xf32>) -> (tensor<5x8xf32>, tensor<5x8xi32>) + func.return %0#0, %0#1 : tensor<5x8xf32>, tensor<5x8xi32> +} diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir index 4eb4d418396de..258d78b613e71 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir @@ -396,7 +396,7 @@ func.func @reshape_integration(%arg0: tensor<512x512xf32>, // CHECK: shape.assuming_yield shape.assuming_yield %21 : tensor } - %5 = "mhlo.transpose"(%4) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} + %5 = "mhlo.transpose"(%4) <{permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>}> : (tensor) -> tensor %6 = "mhlo.transpose"(%5) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor @@ -412,7 +412,7 @@ func.func @reshape_integration(%arg0: tensor<512x512xf32>, %12 = "mhlo.reshape"(%11) : (tensor<1xi32>) -> tensor %13 = mhlo.multiply %10, %12 : tensor %14 = "mhlo.reshape"(%13) : (tensor) -> tensor<1xi32> - %15 = "mhlo.concatenate"(%14, %0) {dimension = 0 : i64} + %15 = "mhlo.concatenate"(%14, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> %16 = shape.shape_of %6 : tensor -> tensor<4xindex> %17 = shape.num_elements %16 : tensor<4xindex> -> index diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir index b947a12c3ec86..9b0d39c50ac51 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir @@ -13,10 +13,10 @@ func.func @batchNormInference_2D_inner_features( // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> @@ -48,13 +48,13 @@ func.func @batchNormTraining_2D_inner_features( // CHECK-DAG: %[[VARIANCE:.+]] = mhlo.subtract %[[EX2]], %[[E2X]] : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> - // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EX]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EX]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[EX_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_CENTER]], %[[STDDEV_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED_SCALED:.+]] = mhlo.multiply %[[X_NORMED]], %[[SCALE_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED_SCALED_OFFSET:.+]] = mhlo.add %[[X_NORMED_SCALED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> // CHECK-DAG: return %[[X_NORMED_SCALED_OFFSET]], %[[EX]], %[[VARIANCE]] : tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32> %0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset) @@ -69,7 +69,7 @@ func.func @batchNormTraining_2D_inner_features( // the verifier to enforce the rest. // CHECK-SAME: %[[X:[^:]+]] // CHECK-SAME: %[[SCALE:[^:]+]] -// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> +// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> func.func @batchNormInference_4D_middle_features( %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) @@ -101,13 +101,13 @@ func.func @batchNormTraining_4D_middle_features( // CHECK-DAG: %[[VARIANCE:.+]] = mhlo.subtract %[[EX2]], %[[E2X]] : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> - // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EX]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EX]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[EX_BCAST]] : tensor<3x4x256x6xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_CENTER]], %[[STDDEV_BCAST]] : tensor<3x4x256x6xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[X_NORMED_SCALED:.+]] = mhlo.multiply %[[X_NORMED]], %[[SCALE_BCAST]] : tensor<3x4x256x6xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[X_NORMED_SCALED_OFFSET:.+]] = mhlo.add %[[X_NORMED_SCALED]], %[[OFFSET_BCAST]] : tensor<3x4x256x6xf32> // CHECK-DAG: return %[[X_NORMED_SCALED_OFFSET]], %[[EX]], %[[VARIANCE]] : tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32> %0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset) @@ -210,14 +210,14 @@ func.func @batchNormInference_dynamic_shape( -> tensor { // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[VAR_SHAPE:.+]] = shape.shape_of %[[VARIANCE]] : tensor -> tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor -> tensor<4xindex> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor @@ -241,7 +241,7 @@ func.func @batchNormTraining_dynamic_shape( // CHECK-DAG: %[[ZERO:.+]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor // CHECK-DAG: %[[SCALE_SHAPE:.+]] = shape.shape_of %[[SCALE]] : tensor -> tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[SCALE_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[SCALE_SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor -> tensor<4xindex> // CHECK-DAG: %[[X_SIZE:.+]] = shape.num_elements %[[X_SHAPE]] : tensor<4xindex> -> index // CHECK-DAG: %[[SCALE_SIZE:.+]] = shape.num_elements %[[SCALE_SHAPE]] : tensor<1xindex> -> index @@ -250,7 +250,7 @@ func.func @batchNormTraining_dynamic_shape( // CHECK-DAG: %[[REDUCE_SIZE_TENSOR:.+]] = tensor.from_elements %[[INDEX_CAST]] : tensor<1xi64> // CHECK-DAG: %[[REDUCE_SIZE_TENSOR_FP:.+]] = mhlo.convert %[[REDUCE_SIZE_TENSOR]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK-DAG: %[[REDUCE_SIZE_RESHAPE:.+]] = mhlo.reshape %[[REDUCE_SIZE_TENSOR_FP]] : (tensor<1xf32>) -> tensor - // CHECK-DAG: %[[REDUCE_SIZE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[REDUCE_SIZE_RESHAPE]], %[[SCALE_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[REDUCE_SIZE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[REDUCE_SIZE_RESHAPE]], %[[SCALE_SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[X_SUM:.+]] = mhlo.reduce(%[[X]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0, 1, 3] : (tensor, tensor) -> tensor // CHECK-DAG: %[[X2:.+]] = mhlo.multiply %[[X]], %[[X]] : tensor // CHECK-DAG: %[[X2_SUM:.+]] = mhlo.reduce(%[[X2]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0, 1, 3] : (tensor, tensor) -> tensor @@ -260,13 +260,13 @@ func.func @batchNormTraining_dynamic_shape( // CHECK-DAG: %[[VARX:.+]] = mhlo.subtract %[[EX2]], %[[EX_2]] : tensor // CHECK-DAG: %[[VARX_EPS:.+]] = mhlo.add %[[VARX]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDX:.+]] = mhlo.sqrt %[[VARX_EPS]] : tensor - // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EX]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EX]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_SUB_EX:.+]] = mhlo.subtract %[[X]], %[[EX_BCAST]] : tensor - // CHECK-DAG: %[[STDX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDX]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[STDX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDX]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_CENTOR:.+]] = mhlo.divide %[[X_SUB_EX]], %[[STDX_BCAST]] : tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTOR]], %[[SCALE_BCAST]] : tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_SCALED]], %[[OFFSET_BCAST]] : tensor // CHECK-DAG: return %[[RESULT]], %[[EX]], %[[VARX]] : tensor, tensor, tensor %0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset) diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir index b9e831a891dda..764a93e5fb77e 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir @@ -35,22 +35,6 @@ func.func @reduce_complex_type(%arg0: tensor<1x2xcomplex>, %arg1 : tensor, %arg1 : tensor<*xf32>) - -> (tensor<*xf32>) { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - - ^bb0(%arg2: tensor<*xf32>, %arg3: tensor<*xf32> ): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - "mhlo.return"(%1) : (tensor<*xf32>) -> () - - }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - - func.return %0: tensor<*xf32> -} - -// ----- - // CHECK-LABEL: func @reduce_mixed_dynamism func.func @reduce_mixed_dynamism(%arg0: tensor<4x4xf32>, %arg1 : tensor) -> (tensor) { @@ -62,36 +46,31 @@ func.func @reduce_mixed_dynamism(%arg0: tensor<4x4xf32>, %arg1 : tensor) // ----- -// CHECK-LABEL: func @reduce_unranked -func.func @reduce_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, - %arg2: tensor<*xf32>, %arg3: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %0:2 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ +// CHECK-LABEL: func @reduce_with_promotable_types +func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg4: tensor<*xf32>, %arg5: tensor<*xf32>, %arg6: tensor<*xf32>, %arg7: tensor<*xf32>): - %1 = "mhlo.add"(%arg4, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %2 = "mhlo.add"(%arg5, %arg7) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - "mhlo.return"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> () + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> - func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> + func.return %0: tensor<4xf64> } // ----- -// CHECK-LABEL: func @reduce_mix_rank_and_unranked -func.func @reduce_mix_rank_and_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*xf32>, - %arg2: tensor<4xf32>, %arg3: tensor<*xf32>) -> (tensor<4xf32>, tensor<*xf32>) { - %0:2 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ - - ^bb0(%arg4: tensor<4xf32>, %arg5: tensor<*xf32>, %arg6: tensor<4xf32>, %arg7: tensor<*xf32>): - %1 = "mhlo.add"(%arg4, %arg6) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %2 = "mhlo.add"(%arg5, %arg7) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - "mhlo.return"(%1, %2) : (tensor<4xf32>, tensor<*xf32>) -> () - - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<*xf32>, tensor<4xf32>, tensor<*xf32>) -> (tensor<4xf32>, tensor<*xf32>) - - func.return %0#0, %0#1 : tensor<4xf32>, tensor<*xf32> +// CHECK-LABEL: func @reduce_with_promotable_quantized_types +func.func @reduce_with_promotable_quantized_types(%arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor>) -> tensor<4x!quant.uniform> { + %0 = mhlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, tensor>) -> tensor<4x!quant.uniform> + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = mhlo.add %arg2, %arg3 : tensor> + mhlo.return %1 : tensor> + } + return %0 : tensor<4x!quant.uniform> } // Next, we have the invalid testcases. @@ -157,7 +136,7 @@ func.func @reduce_diferent_input_shapes(%arg0: tensor<2x3xf32>, %arg1: tensor<3x func.func @reduce_oob_dims(%arg0: tensor, %arg1 : tensor) -> (tensor) { - // expected-error@+1 {{Out-of-bounds dimension 2, expected to be less than the input-tensor rank 2}} + // expected-error@+1 {{Out-of-bounds dimension 2, expected to be in range [0, 2)}} %0 = "mhlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor ): @@ -314,7 +293,7 @@ func.func @verify_reducer_function(%arg0: tensor, %arg1: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - // expected-error@+1 {{The type of reduction-region's result type at index 1 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 1 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0:2 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor): @@ -332,7 +311,7 @@ func.func @verify_reducer_function(%arg0: tensor, %arg1: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be 'i32', but got 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be promotable from 'i32', but got 'f32'}} %0:2 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor): @@ -492,7 +471,7 @@ func.func @reduce_verify_rettype(%arg0: tensor, %arg1 : tensor) // ----- func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op from the mhlo dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.divide across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } @@ -500,7 +479,7 @@ func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor , // ----- func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op from the mhlo dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies std.add across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } @@ -508,7 +487,7 @@ func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor , % // ----- func.func @reduce_parsing_pretty_reduce_non_binary(%arg0: tensor , %arg1: tensor ) -> tensor { - // expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}} + // expected-error@+1 {{expected the inner-op to be a commutative binary-op from the mhlo dialect, with zero region, producing single result}} %0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.reshape across dimensions = [1] : (tensor, tensor) -> tensor loc("foo") func.return %0 : tensor } diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir index 59c2159e64eea..b6b3a1ad0cd97 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir @@ -21,28 +21,6 @@ func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, // ----- -func.func @reduce_window_with_unranked_dynamic_dims(%arg0: tensor<*xf32>, - %arg1: tensor<4x?xi32>, %init0: tensor, %init1: tensor) -> - (tensor, tensor<*xi32>) { - %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ - ^bb0(%a0: tensor, %a1: tensor, - %b0: tensor, %b1: tensor): - %2 = mhlo.add %a0, %b0 : tensor - %3 = mhlo.add %a1, %b1 : tensor - "mhlo.return"(%2, %3) : (tensor, tensor) -> () - }) - { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, - window_dimensions = dense<[5, 1]> : tensor<2xi64>, - window_strides = dense<[3, 1]> : tensor<2xi64>, - base_dilations = dense<[1,1]> : tensor<2xi64>, - window_dilations = dense<[1,1]> : tensor<2xi64> } - : (tensor<*xf32>, tensor<4x?xi32>, tensor, tensor) -> - (tensor, tensor<*xi32>) - func.return %0#0, %0#1 : tensor, tensor<*xi32> -} - -// ----- - func.func @reduce_window_with_non_scalar_block_arg1(%arg0: tensor<4x2xf32>, %init0: tensor<4xf32>) -> tensor<2x1xf32> { %0 = "mhlo.reduce_window"(%arg0, %init0) ({ @@ -79,6 +57,46 @@ func.func @reduce_window_with_non_scalar_block_arg2(%arg0: tensor<4x2xf32>, // ----- +// CHECK-LABEL: func @reduce_window_with_promotable_types +func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = mhlo.add %a0, %b0 : tensor + %3 = mhlo.add %a1, %b1 : tensor + "mhlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @reduce_window_with_promotable_quantized_types +func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform>, + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + + %0 = "mhlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = mhlo.add %a0, %b0 : tensor> + "mhlo.return"(%1) : (tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + func.func @reduce_window_invalid_inputs(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { @@ -253,7 +271,7 @@ func.func @reduce_window_invalid_padding_attributes(%arg0: tensor<4x2xf32>, func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error @+1 {{expects the shape of window_dimensions attribute to be 1-D, but got {1, 2}}} + // expected-error @+1 {{window_dimensions has rank 2 instead of required rank 1}} %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -274,7 +292,7 @@ func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error @+1 {{expects the shape of window_strides attribute to be 1-D, but got {1, 2}}} + // expected-error @+1 {{window_strides has rank 2 instead of required rank 1}} %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -292,56 +310,10 @@ func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, // ----- -func.func @reduce_window_with_unranked_dynamic_dims(%arg0: tensor<*xf32>, - %arg1: tensor<4x?xi32>, %init0: tensor, %init1: tensor) -> - (tensor, tensor<*xi32>) { - // expected-error @+1 {{expects the shape of base_dilations attribute to be 1-D, but got {1, 2}}} - %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ - ^bb0(%a0: tensor, %a1: tensor, - %b0: tensor, %b1: tensor): - %2 = mhlo.add %a0, %b0 : tensor - %3 = mhlo.add %a1, %b1 : tensor - "mhlo.return"(%2, %3) : (tensor, tensor) -> () - }) - { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, - window_dimensions = dense<[5, 1]> : tensor<2xi64>, - window_strides = dense<[3, 1]> : tensor<2xi64>, - base_dilations = dense<[[1, 1]]> : tensor<1x2xi64>, - window_dilations = dense<[1, 1]> : tensor<2xi64> } - : (tensor<*xf32>, tensor<4x?xi32>, tensor, tensor) -> - (tensor, tensor<*xi32>) - func.return %0#0, %0#1 : tensor, tensor<*xi32> -} - -// ----- - -func.func @reduce_window_with_unranked_dynamic_dims(%arg0: tensor<*xf32>, - %arg1: tensor<4x?xi32>, %init0: tensor, %init1: tensor) -> - (tensor, tensor<*xi32>) { - // expected-error @+1 {{expects the shape of window_dilations attribute to be 1-D, but got {1, 2}}} - %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ - ^bb0(%a0: tensor, %a1: tensor, - %b0: tensor, %b1: tensor): - %2 = mhlo.add %a0, %b0 : tensor - %3 = mhlo.add %a1, %b1 : tensor - "mhlo.return"(%2, %3) : (tensor, tensor) -> () - }) - { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, - window_dimensions = dense<[5, 1]> : tensor<2xi64>, - window_strides = dense<[3, 1]> : tensor<2xi64>, - base_dilations = dense<[1, 1]> : tensor<2xi64>, - window_dilations = dense<[[1, 1]]> : tensor<1x2xi64> } - : (tensor<*xf32>, tensor<4x?xi32>, tensor, tensor) -> - (tensor, tensor<*xi32>) - func.return %0#0, %0#1 : tensor, tensor<*xi32> -} - -// ----- - func.func @reduce_window_invalid_attributes(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error @+1 {{expects the shape of window_dimensions attribute to be 1-D}} + // expected-error @+1 {{window_dimensions has rank 2 instead of required rank 1}} %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -668,7 +640,7 @@ func.func @reduce_window_invalid_reducer(%arg0: tensor<4x2xf32>, func.func @reduce_window_invalid_reducer(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init1, %init0) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -689,7 +661,7 @@ func.func @reduce_window_invalid_reducer(%arg0: tensor<4x2xf32>, func.func @reduce_window_invalid_reducer(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{The element-type of reduction-region's argument at index 2 is expected to be 'i32', but got 'tensor' as its type.}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 2 is expected to be promotable from 'i32', but got 'f32'}} %0:2 = "mhlo.reduce_window"(%arg1, %arg0, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir index 58789086a22c6..ea6f77a27e7ec 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir @@ -22,14 +22,16 @@ func.func @scatter(%input_tensor: tensor<200x100x300xf32>, func.return %0 : tensor<200x100x300xf32> } -// CHECK: func @scatter_with_unranked_inputs -func.func @scatter_with_unranked_inputs(%input_tensor: tensor<*xf32>, - %scatter_indices: tensor<*xi32>, %updates: tensor<*xf32>) -> - tensor<*xf32> { +// ----- + +// CHECK: func @scatter_with_promotable_types +func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %add = mhlo.add %lhs, %rhs : tensor - "mhlo.return"(%add) : (tensor) -> () + ^bb0(%lhs: tensor, %rhs: tensor): + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () }) { scatter_dimension_numbers = #mhlo.scatter< update_window_dims = [1], @@ -39,17 +41,41 @@ func.func @scatter_with_unranked_inputs(%input_tensor: tensor<*xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<*xi32>, tensor<*xf32>) -> - tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> } // ----- +// CHECK: func @scatter_with_promotable_quantized_types +func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = mhlo.add %lhs, %rhs : tensor> + "mhlo.return"(%add) : (tensor>) -> () + }) { + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, + tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> +} +// ----- + func.func @invalid_scatter(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xf32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - // expected-error @+1 {{operand #1 must be tensor of integer or index values, but got 'tensor<10x2xf32>'}} + // expected-error @+1 {{operand #1 must be ranked tensor of integer or index values, but got 'tensor<10x2xf32>'}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = mhlo.add %lhs, %rhs : tensor @@ -70,9 +96,9 @@ func.func @invalid_scatter(%input_tensor: tensor<200x100x300xf32>, // ----- -func.func @invalid_scatter(%input_tensor: tensor<*xf32>, - %scatter_indices: tensor<10x2xi32>, %updates: tensor<*xf32>) -> - tensor<*xf32> { +func.func @invalid_scatter(%input_tensor: tensor, + %scatter_indices: tensor<10x2xi32>, %updates: tensor) -> + tensor { // expected-error @+1 {{expects scatter index leaf dimension to be within [0, rank(scatter_indices) + 1. rank(scatter_indices) is 2 and scatter index leaf dimension is 3.}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): @@ -87,9 +113,9 @@ func.func @invalid_scatter(%input_tensor: tensor<*xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<10x2xi32>, tensor<*xf32>) -> - tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor, tensor<10x2xi32>, tensor) -> + tensor + func.return %0 : tensor } // ----- @@ -118,9 +144,9 @@ func.func @invalid_scatter(%input_tensor: tensor<200x100x300xf32>, // ----- -func.func @invalid_scatter(%input_tensor: tensor<*xf32>, +func.func @invalid_scatter(%input_tensor: tensor, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> - tensor<*xf32> { + tensor { // expected-error @+1 {{expects updates tensor must be of rank 3 ( == rank-of('scatter_indices') - 1 + size-of('update_window_dims'), where 'scatter_indices' is expanded by a trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')), but got 2.}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): @@ -135,8 +161,8 @@ func.func @invalid_scatter(%input_tensor: tensor<*xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor + func.return %0 : tensor } // ----- @@ -237,9 +263,9 @@ func.func @invalid_scatter_dimensions() -> tensor<512x1x6400x6400xf32> { // ----- -func.func @invalid_scatter_dimensions(%input_tensor: tensor<*xf32>, - %scatter_indices: tensor<*xi32>, %updates: tensor<*xf32>) -> - tensor<*xf32> { +func.func @invalid_scatter_dimensions(%input_tensor: tensor, + %scatter_indices: tensor, %updates: tensor) -> + tensor { // expected-error @+1 {{Expects inserted_window_dims to be sorted; got: [1, 0].}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -255,8 +281,8 @@ func.func @invalid_scatter_dimensions(%input_tensor: tensor<*xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor } // ----- @@ -311,7 +337,7 @@ func.func @invalid_scatter_dimensions(%input_tensor: tensor<200x100x300xf32>, // ----- func.func @invalid_scatter_dimensions(%input_tensor: tensor<200x100x300xf32>, - %scatter_indices: tensor<*xi32>, %updates: tensor<*xf32>) -> tensor<*xf32> { + %scatter_indices: tensor, %updates: tensor) -> tensor { // expected-error @+1 {{Expects rank-of operand to match size-of('update_window_dims') + size-of('inserted_window_dims') i.e. 4 but got 3.}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -327,15 +353,15 @@ func.func @invalid_scatter_dimensions(%input_tensor: tensor<200x100x300xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<200x100x300xf32>, tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor<200x100x300xf32>, tensor, tensor) -> tensor + func.return %0 : tensor } // ----- -func.func @invalid_scatter_dimensions(%input_tensor: tensor<*xf32>, - %scatter_indices: tensor<10x2xi32>, %updates: tensor<*xf32>) -> - tensor<*xf32> { +func.func @invalid_scatter_dimensions(%input_tensor: tensor, + %scatter_indices: tensor<10x2xi32>, %updates: tensor) -> + tensor { // expected-error @+1 {{Scatter op has 3 elements in scatter_dims_to_operand_dims and the bound of dimension index_vector_dim=1 of scatter_indices is 2. These two numbers must be equal.}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -345,20 +371,20 @@ func.func @invalid_scatter_dimensions(%input_tensor: tensor<*xf32>, }) { scatter_dimension_numbers = #mhlo.scatter< update_window_dims = [1], - inserted_window_dims = [0, 1], + inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1, 2], index_vector_dim = 1 >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<10x2xi32>, tensor<*xf32>) -> - tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor, tensor<10x2xi32>, tensor) -> + tensor + func.return %0 : tensor } func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim( - %input_tensor: tensor<*xf32>, %scatter_indices: tensor<10x?xi32>, - %updates: tensor<*xf32>) -> tensor<*xf32> { + %input_tensor: tensor, %scatter_indices: tensor<10x?xi32>, + %updates: tensor) -> tensor { %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): @@ -373,14 +399,14 @@ func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim( >, indices_are_sorted = true, unique_indices = true - } : (tensor<*xf32>, tensor<10x?xi32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor, tensor<10x?xi32>, tensor) -> tensor + func.return %0 : tensor } // ----- func.func @invalid_scatter_dimensions(%input_tensor: tensor<200x100x300xf32>, - %scatter_indices: tensor<*xi32>, %updates: tensor<*xf32>) -> tensor<*xf32> { + %scatter_indices: tensor, %updates: tensor) -> tensor { // expected-error @+1 {{Invalid scatter_dims_to_operand_dims mapping; domain is [0, 3), got: 1->3.}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -396,8 +422,8 @@ func.func @invalid_scatter_dimensions(%input_tensor: tensor<200x100x300xf32>, >, indices_are_sorted = true, unique_indices = true - } : (tensor<200x100x300xf32>, tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> + } : (tensor<200x100x300xf32>, tensor, tensor) -> tensor + func.return %0 : tensor } // ----- @@ -685,7 +711,7 @@ func.func @invalid_scatter_reducer(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi32>) -> tensor<200x100x300xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = mhlo.add %lhs, %rhs : tensor @@ -710,7 +736,7 @@ func.func @invalid_scatter_reducer(%input_tensor: tensor<200x100x300xi32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'i32', but got 'tensor' as its type.}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'f32'}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = mhlo.add %lhs, %rhs : tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir index 6f30d4af6aaa3..7a889b2d255d1 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir @@ -26,34 +26,58 @@ func.func @select_and_scatter( func.return %1 : tensor<10x24x24x64xf32> } -func.func @select_and_scatter_with_unranked_dims( - %arg0: tensor<4x5x1x1xbf16>, - %arg1: tensor<2x2x1x1xbf16>, - %arg2: tensor) -> tensor { - %0 = mhlo.constant dense<0> : tensor<4x2xi32> - %1 = mhlo.constant dense<[2, 2, 1, 1]> : tensor<4xi32> - %2 = mhlo.constant dense<[2, 3, 1, 1]> : tensor<4xi32> - - %3 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor<*xbf16>, %arg4: tensor<*xbf16>): - %4 = "mhlo.compare"(%arg3, %arg4) { + +// CHECK: func @select_and_scatter_with_promotable_types +func.func @select_and_scatter_with_promotable_types( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> () { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "mhlo.compare"(%arg3, %arg4) { + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} + + +// CHECK: func @select_and_scatter_with_promotable_quantized_types +func.func @select_and_scatter_with_promotable_quantized_types( + %arg0: tensor<10x24x24x64x!quant.uniform>, + %arg1: tensor<10x12x12x64x!quant.uniform>, + %arg2 : tensor>) -> + tensor<10x24x24x64x!quant.uniform> { + + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = "mhlo.compare"(%arg3, %arg4) { compare_type = #mhlo, - comparison_direction = #mhlo} - : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xi1> - "mhlo.return"(%4) : (tensor<*xi1>) -> () - }, { - ^bb0(%arg3: tensor<*xbf16>, %arg4: tensor<*xbf16>): - %4 = "mhlo.add"(%arg3, %arg4) : (tensor<*xbf16>, tensor<*xbf16>) -> - tensor<*xbf16> - "mhlo.return"(%4) : (tensor<*xbf16>) -> () + comparison_direction = #mhlo + } : (tensor>, tensor>) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = mhlo.add %arg3, %arg4 : tensor> + "mhlo.return"(%2) : (tensor>) -> () }) { - padding = dense<0> : tensor<4x2xi64>, - window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, - window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>} - : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> - tensor - - func.return %3 : tensor + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64x!quant.uniform>, + tensor<10x12x12x64x!quant.uniform>, + tensor>) -> + tensor<10x24x24x64x!quant.uniform> + func.return %1 : tensor<10x24x24x64x!quant.uniform> } // ----- @@ -649,7 +673,7 @@ func.func @select_and_scatter_invalid_scatter_reducer( %arg1: tensor<10x12x12x64xf32>) -> () { %0 = mhlo.constant dense<0> : tensor - // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'f32', but got 'tensor' as its type.}} + // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'f32', but got 'i32'}} %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): %2 = "mhlo.compare"(%arg3, %arg4) { @@ -677,7 +701,7 @@ func.func @select_and_scatter_invalid_scatter_reducer( %arg1: tensor<10x12x12x64xf32>) -> () { %0 = mhlo.constant dense<0.000000e+00> : tensor - // expected-error @+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error @+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): %2 = "mhlo.compare"(%arg3, %arg4) { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir index eae57c2baa5bf..0508f501fbbee 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir @@ -28,13 +28,13 @@ func.func @while_with_different_types(%arg0: tensor<3xf32>) -> tensor<3xf32> { %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -44,48 +44,24 @@ func.func @while_with_different_types(%arg0: tensor<3xf32>) -> tensor<3xf32> { // ----- // CHECK-LABEL: while_dynamic -func.func @while_dynamic(%arg0: tensor<3xf32>) -> tensor<*xf32> { +func.func @while_dynamic(%arg0: tensor<3xf32>) -> tensor { %cst_0 = arith.constant dense<0> : tensor<1xi32> %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32> %cst_2 = arith.constant dense<1.00> : tensor<1xf32> %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ - ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<*xf32>): + ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { - ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<*xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + ^bb0(%arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> - "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<*xi32>, tensor<1xf32>, tensor<3xf32>) -> () - }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<*xf32>) - func.return %1#3: tensor<*xf32> -} - - -// ----- - -// CHECK-LABEL: while_unranked -func.func @while_unranked(%arg0: tensor<3xf32>) -> tensor<*xf32> { - %cst_0 = arith.constant dense<0> : tensor<1xi32> - %cst_1 = arith.constant dense<[100, 100]> : tensor<2xi32> - %cst_2 = arith.constant dense<1.00> : tensor<1xf32> - %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ - ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<*xf32>): - %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> - %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %5 = "mhlo.select"(%4, %4, %4) : (tensor<1xi1>, tensor<1xi1>, tensor<1xi1>) -> tensor<*xi1> - "mhlo.return"(%5) : (tensor<*xi1>) -> () - }, { - ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<*xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> - %4 = mhlo.add %3, %arg4 : tensor<3xf32> - "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<*xi32>, tensor<1xf32>, tensor<3xf32>) -> () - }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<*xf32>) - func.return %1#3: tensor<*xf32> + "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor, tensor<1xf32>, tensor<3xf32>) -> () + }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor) + func.return %1#3: tensor } // Negative tests below @@ -101,13 +77,13 @@ func.func @while_with_invalid_types(%arg0: tensor<3xf32>) -> tensor<3xf32> { %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<3xf32>, tensor<1xf32>) @@ -122,12 +98,12 @@ func.func @while_with_invalid_tuples(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst_2 = arith.constant dense<1.00> : tensor<1xf32> %0 = "mhlo.tuple"(%arg0, %cst_2) : (tensor<3xf32>, tensor<1xf32>) -> tuple, tensor<1xf32>> %1 = "mhlo.tuple"(%cst_1, %0) : (tensor<2xi32>, tuple, tensor<1xf32>>) -> tuple, tuple, tensor<1xf32>>> - // expected-error @+1 {{op operand #1 must be variadic of tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values or token, but got 'tuple, tuple, tensor<1xf32>>>'}} + // expected-error @+1 {{operand #1 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values or ranked tensor of 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values or token, but got 'tuple, tuple, tensor<1xf32>>>'}} %2:2 = "mhlo.while"(%cst_0, %1) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tuple, tuple, tensor<3xf32>>>): %t0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tuple, tensor<3xf32>>>) -> tensor<2xi32> %3 = arith.constant dense<0> : tensor - %4 = "mhlo.slice"(%t0) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %4 = "mhlo.slice"(%t0) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %5 = "mhlo.compare"(%arg1, %4) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> "mhlo.return"(%5) : (tensor<1xi1>) -> () }, { @@ -136,7 +112,7 @@ func.func @while_with_invalid_tuples(%arg0: tensor<3xf32>) -> tensor<3xf32> { %t1_2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple, tuple, tensor<3xf32>>>) -> tuple, tensor<3xf32>> %t1 = "mhlo.get_tuple_element"(%t1_2) {index = 0 : i32} : (tuple, tensor<3xf32>>) -> tensor<1xf32> %t2 = "mhlo.get_tuple_element"(%t1_2) {index = 1 : i32} : (tuple, tensor<3xf32>>) -> tensor<3xf32> - %3 = "mhlo.broadcast_in_dim"(%t1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%t1) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %t2 : tensor<3xf32> %5 = "mhlo.tuple"(%t1, %4) : (tensor<1xf32>, tensor<3xf32>) -> tuple, tensor<3xf32>> %6 = "mhlo.tuple"(%t0, %5) : (tensor<2xi32>, tuple, tensor<3xf32>>) -> tuple, tuple, tensor<3xf32>>> @@ -156,13 +132,13 @@ func.func @while_with_different_types(%arg0: tensor<3xf32>) -> tensor<3xf32> { %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -179,13 +155,13 @@ func.func @while_with_different_types(%arg0: tensor<3xf32>) -> tensor<3xf32> { %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<3xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -202,13 +178,13 @@ func.func @while_with_block_count_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32 %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<3xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -225,13 +201,13 @@ func.func @while_with_block_count_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32 %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<3xi32>, %arg3: tensor<1xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %3) : (tensor<1xi32>, tensor<3xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) func.return %1#3: tensor<3xf32> @@ -250,7 +226,7 @@ func.func @while_with_cond_return_width_mismatch(%arg0: tensor<3xf32>) -> tensor "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -266,12 +242,12 @@ func.func @while_with_cond_return_rank_mismatch(%arg0: tensor<3xf32>) -> tensor< // expected-error @+1 {{expect condition block return a zero-ranked tensor of i1 but got 'tensor<1xi1>'}} %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> "mhlo.return"(%4) : (tensor<1xi1>) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -291,7 +267,7 @@ func.func @while_with_cond_return_type_mismatch(%arg0: tensor<3xf32>) -> tensor< "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -308,7 +284,7 @@ func.func @while_with_body_return_mismatch(%arg0: tensor<3xf32>) -> tensor<3xf32 %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () @@ -329,13 +305,13 @@ func.func @while_with_multiple_operand_in_cond_return(%arg0: tensor<3xf32>) -> t %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5, %5) : (tensor, tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %4) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -352,13 +328,13 @@ func.func @while_mismatch_operand_count_with_body_return(%arg0: tensor<3xf32>) - %1:4 = "mhlo.while"(%cst_0, %cst_1, %cst_2, %arg0) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): %2 = arith.constant dense<0> : tensor - %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + %3 = "mhlo.slice"(%arg2) <{limit_indices = dense<[1]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> %4 = "mhlo.compare"(%arg1, %3) {comparison_direction = #mhlo} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> %5 = "mhlo.reshape"(%4) : (tensor<1xi1>) -> tensor "mhlo.return"(%5) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %3 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %4 = mhlo.add %3, %arg4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) diff --git a/xla/mlir_hlo/tests/buffer_packing.mlir b/xla/mlir_hlo/tests/buffer_packing.mlir deleted file mode 100644 index 737dad4148de1..0000000000000 --- a/xla/mlir_hlo/tests/buffer_packing.mlir +++ /dev/null @@ -1,164 +0,0 @@ -// RUN: mlir-hlo-opt -buffer-packing -split-input-file %s | FileCheck %s - -// CHECK-LABEL: @noPackingSameLiveRange -func.func @noPackingSameLiveRange() -> (f32, f32) { - // CHECK: memref.alloc - // CHECK: memref.alloc - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - %1 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - memref.store %c2, %1[%c1] : memref<42xf32> - %2 = memref.load %0[%c1] : memref<42xf32> - %3 = memref.load %1[%c1] : memref<42xf32> - return %2, %3 : f32, f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfSameSize -func.func @packingScfIfSameSize(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: scf.if - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - %1 = memref.alloc() : memref<42xf32> - %2 = scf.if %pred -> f32 { - memref.store %c2, %0[%c1] : memref<42xf32> - %2 = memref.load %0[%c1] : memref<42xf32> - scf.yield %2 : f32 - } else { - memref.store %c2, %1[%c1] : memref<42xf32> - %2 = memref.load %1[%c1] : memref<42xf32> - scf.yield %2 : f32 - } - return %2 : f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfDifferentSize -func.func @packingScfIfDifferentSize(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<16xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = scf.if %pred -> f32 { - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - scf.yield %1 : f32 - } else { - %0 = memref.alloc() : memref<16xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %1 = memref.load %0[%c1] : memref<16xf32> - scf.yield %1 : f32 - } - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: @packingScfIfDifferentElementType -func.func @packingScfIfDifferentElementType(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<128xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<128xi8> to memref<42xf16> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<128xi8> to memref<16xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %0 = scf.if %pred -> f32 { - %c2 = arith.constant 2.0 : f16 - %0 = memref.alloc() : memref<42xf16> - memref.store %c2, %0[%c1] : memref<42xf16> - %1 = memref.load %0[%c1] : memref<42xf16> - %2 = arith.extf %1 : f16 to f32 - scf.yield %2 : f32 - } else { - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<16xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %1 = memref.load %0[%c1] : memref<16xf32> - scf.yield %1 : f32 - } - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: @packWithOutsideControlFlow -func.func @packWithOutsideControlFlow(%pred : i1) -> (f32, f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: %[[VIEW0:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW0]] - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW2]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - %2 = scf.if %pred -> f32 { - %3 = memref.alloc() : memref<42xf32> - memref.store %c2, %3[%c1] : memref<42xf32> - %4 = memref.load %3[%c1] : memref<42xf32> - scf.yield %4 : f32 - } else { - %3 = memref.alloc() : memref<42xf32> - memref.store %c2, %3[%c1] : memref<42xf32> - %4 = memref.load %3[%c1] : memref<42xf32> - scf.yield %4 : f32 - } - return %1, %2 : f32, f32 -} - -// ----- - -// CHECK-LABEL: @packTwoInOne -func.func @packTwoInOne(%pred : i1) -> (f32) { - // CHECK: %[[MEM:.*]] = memref.alloc() : memref<192xi8> - // CHECK: scf.if - // CHECK: %[[VIEW1:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<42xf32> - // CHECK: memref.load %[[VIEW1]] - // CHECK: else - // CHECK: %[[VIEW2:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<16xf32> - // CHECK: %[[VIEW3:.*]] = memref.view %[[MEM]][%{{.*}}][] : memref<192xi8> to memref<8xf32> - // CHECK: memref.load %[[VIEW2]] - // CHECK: memref.load %[[VIEW3]] - %c1 = arith.constant 1 : index - %c2 = arith.constant 2.0 : f32 - %0 = scf.if %pred -> f32 { - %0 = memref.alloc() : memref<42xf32> - memref.store %c2, %0[%c1] : memref<42xf32> - %1 = memref.load %0[%c1] : memref<42xf32> - scf.yield %1 : f32 - } else { - %0 = memref.alloc() : memref<16xf32> - %1 = memref.alloc() : memref<8xf32> - memref.store %c2, %0[%c1] : memref<16xf32> - %2 = memref.load %0[%c1] : memref<16xf32> - memref.store %c2, %1[%c1] : memref<8xf32> - %3 = memref.load %1[%c1] : memref<8xf32> - %4 = arith.addf %2, %3 : f32 - scf.yield %4 : f32 - } - return %0 : f32 -} diff --git a/xla/mlir_hlo/tests/bufferize.mlir b/xla/mlir_hlo/tests/bufferize.mlir index d91f923467f97..1341eeb1942bb 100644 --- a/xla/mlir_hlo/tests/bufferize.mlir +++ b/xla/mlir_hlo/tests/bufferize.mlir @@ -257,7 +257,7 @@ func.func @slice(%t : tensor<3xi32>) -> tensor<1xi32> { func.func @dynamic_broadcast_return(%t : tensor, %shape : tensor<2xi32>) -> tensor { // CHECK: memref.copy - %bcast = "mhlo.dynamic_broadcast_in_dim"(%t, %shape) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor + %bcast = "mhlo.dynamic_broadcast_in_dim"(%t, %shape) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor, tensor<2xi32>) -> tensor func.return %bcast : tensor } diff --git a/xla/mlir_hlo/tests/bufferize_one_shot.mlir b/xla/mlir_hlo/tests/bufferize_one_shot.mlir index 91d04a84872f2..278ac3d360ba0 100644 --- a/xla/mlir_hlo/tests/bufferize_one_shot.mlir +++ b/xla/mlir_hlo/tests/bufferize_one_shot.mlir @@ -109,7 +109,7 @@ func.func @slice(%t : tensor<3xi32>) -> tensor<1xi32> { func.func @dynamic_broadcast_return(%t : tensor, %shape : tensor<2xi32>) -> tensor { // CHECK: memref.copy - %bcast = "mhlo.dynamic_broadcast_in_dim"(%t, %shape) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor + %bcast = "mhlo.dynamic_broadcast_in_dim"(%t, %shape) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor, tensor<2xi32>) -> tensor func.return %bcast : tensor } diff --git a/xla/mlir_hlo/tests/capi_test.c b/xla/mlir_hlo/tests/capi_test.c index a37d4a8aab9ea..92b75ce9a3dc6 100644 --- a/xla/mlir_hlo/tests/capi_test.c +++ b/xla/mlir_hlo/tests/capi_test.c @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir b/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir index a49b1675cb928..584c5c8520644 100644 --- a/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir +++ b/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir @@ -12,7 +12,7 @@ func.func @parallel_2d(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) { %2 = memref.load %arg0[%arg2,%arg3] : memref<4x4xf32> %3 = math.log %2 : f32 memref.store %3, %0[%arg2,%arg3] : memref<4x4xf32> - scf.yield + scf.reduce } %1 = bufferization.to_tensor %0 : memref<4x4xf32> bufferization.materialize_in_destination %1 in writable %arg1 diff --git a/xla/mlir_hlo/tests/detensorize_scf_ops.mlir b/xla/mlir_hlo/tests/detensorize_scf_ops.mlir index 0c744d378d103..6c2b7470613f5 100644 --- a/xla/mlir_hlo/tests/detensorize_scf_ops.mlir +++ b/xla/mlir_hlo/tests/detensorize_scf_ops.mlir @@ -75,8 +75,8 @@ func.func @if_return(%cond: i1) -> tensor { // CHECK-LABEL: @if_return // CHECK-SAME: (%[[COND:.*]]: -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 // CHECK: %[[RESULT_SCALAR:.*]] = scf.if %[[COND]] -> (f32) { // CHECK: scf.yield %[[C0]] // CHECK: } else { @@ -99,9 +99,9 @@ func.func @for(%arg: tensor) -> tensor { // CHECK-LABEL: @for // CHECK-SAME: (%[[ARG:.*]]: -// CHECK: %[[CST:.*]] = arith.constant 0.0 -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.0 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 // CHECK: %[[ARG_SCALAR:.*]] = tensor.extract %[[ARG]] // CHECK: %[[RESULT_SCALAR:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[ARG_SCALAR]]) // CHECK: scf.yield %[[CST]] diff --git a/xla/mlir_hlo/tests/lit.cfg.py b/xla/mlir_hlo/tests/lit.cfg.py index 5ec92d7fb24a5..06fa09371e8d9 100644 --- a/xla/mlir_hlo/tests/lit.cfg.py +++ b/xla/mlir_hlo/tests/lit.cfg.py @@ -1,5 +1,5 @@ """Lit configuration to drive test in this repo.""" -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/xla/mlir_hlo/tests/lit.site.cfg.py.in b/xla/mlir_hlo/tests/lit.site.cfg.py.in index 0104048d180be..4dc5165acfc99 100644 --- a/xla/mlir_hlo/tests/lit.site.cfg.py.in +++ b/xla/mlir_hlo/tests/lit.site.cfg.py.in @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/xla/mlir_hlo/tests/naive_copy_removal.mlir b/xla/mlir_hlo/tests/naive_copy_removal.mlir index bec5a0822ac7b..2b6a5b191c463 100644 --- a/xla/mlir_hlo/tests/naive_copy_removal.mlir +++ b/xla/mlir_hlo/tests/naive_copy_removal.mlir @@ -65,9 +65,9 @@ func.func @target_is_subview_of_subview(%arg0: memref<8x8xf32>) %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : memref<8x8xf32> to memref> %subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : - memref> to memref> + memref> to memref> memref.copy %arg0, %subview_6 : - memref<8x8xf32> to memref> + memref<8x8xf32> to memref> return %arg0 : memref<8x8xf32> } @@ -79,32 +79,6 @@ func.func @target_is_subview_of_subview(%arg0: memref<8x8xf32>) // ----- -func.func @do_not_simplify_subview_of_subview(%arg0: memref<8x8xf32>) - -> vector<8x8xf32> { - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %cst_0 = arith.constant 0.000000e+00 : f32 - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> - %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : - memref<8x8xf32> to memref> - %subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : - memref> to memref> - memref.copy %arg0, %subview_6 : - memref<8x8xf32> to memref> - %27 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 : - memref>, vector<8x8xf32> - return %27 : vector<8x8xf32> -} - -// CHECK-LABEL: func @do_not_simplify_subview_of_subview( - -// CHECK: memref.alloc -// CHECK: memref.subview -// CHECK: memref.subview -// CHECK: memref.copy - -// ----- - func.func @do_not_simplify_subview(%arg0: memref<8x8xf32>) -> vector<8x8xf32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index diff --git a/xla/mlir_hlo/tests/python/CMakeLists.txt b/xla/mlir_hlo/tests/python/CMakeLists.txt index b9e177528db60..f869e89a2b036 100644 --- a/xla/mlir_hlo/tests/python/CMakeLists.txt +++ b/xla/mlir_hlo/tests/python/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tests/python/attributes.py b/xla/mlir_hlo/tests/python/attributes.py index 706dafc74b1ff..d92283fb67ffd 100644 --- a/xla/mlir_hlo/tests/python/attributes.py +++ b/xla/mlir_hlo/tests/python/attributes.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -216,3 +216,13 @@ def test_type_extensions(): attr = mhlo.TypeExtensions.get(bounds=[128, dyn_size]) assert attr is not None assert attr.bounds == [128, dyn_size] + + +@run +def test_sparsity_descriptor(): + attr = mhlo.SparsityDescriptor.get(dimension=1, n=2, m=4) + assert attr is not None + assert str(attr) == "#mhlo.sparsity" + assert attr.dimension == 1 + assert attr.n == 2 + assert attr.m == 4 diff --git a/xla/mlir_hlo/tests/python/smoketest.py b/xla/mlir_hlo/tests/python/smoketest.py index cd3b1c1968c48..6270430fded9c 100755 --- a/xla/mlir_hlo/tests/python/smoketest.py +++ b/xla/mlir_hlo/tests/python/smoketest.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tests/python/types.py b/xla/mlir_hlo/tests/python/types.py index d66dd70cf5486..44e3453b2cf8b 100644 --- a/xla/mlir_hlo/tests/python/types.py +++ b/xla/mlir_hlo/tests/python/types.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tests/rank-specialization.mlir b/xla/mlir_hlo/tests/rank-specialization.mlir deleted file mode 100644 index 3627fbebedc89..0000000000000 --- a/xla/mlir_hlo/tests/rank-specialization.mlir +++ /dev/null @@ -1,702 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=3 | FileCheck %s --check-prefix CHECK-SCF - -// CHECK-LABEL: @add_mul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -func.func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, - %arg2 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ({ - // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]] - // CHECK: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]]) - // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: return %[[RES]] - %0 = chlo.broadcast_multiply %arg0, %arg1 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %1 = chlo.broadcast_add %0, %arg2 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @add_mul -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-SCF-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] -// Equal shapes case: -// CHECK-SCF-DAG: %[[EQ20:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = arith.andi %[[EQ20]], %[[EQ21]] -// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[FLAT_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2 -// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = arith.cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20:.*]] = arith.select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]] -// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = arith.cmpi sgt, %[[R20]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 2: -// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C2]] -// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 3: -// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C3]] -// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3" -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = mhlo.dynamic_reshape %[[ARG2]], %[[EXT_SHAPE_ARG2_]] -// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] -// Reshape the result. -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[TMP:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] -// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[TMP]], %[[SHAPE_ARG2]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES_EQ_SHAPES]], %[[RES_SHAPE]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @compare_const_like -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>) -func.func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]]) ({ - // CHECK: ^bb0(%[[ARG1:.*]]: tensor<*xf32>): - // CHECK: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG1]]) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - // CHECK: %[[CMP_GT:.*]] = chlo.broadcast_compare %[[ARG1]], %[[ZERO]] {comparison_direction = #chlo} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[CMP_GT]]) : (tensor<*xi1>) -> () - // CHECK: }) : (tensor<*xf32>) -> tensor<*xi1> - // CHECK: return %[[RES]] : tensor<*xi1> - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - %1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = #chlo} - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - func.return %1 : tensor<*xi1> -} - -// ----- - -// Unary MHLO operation. -// CHECK-LABEL: @sqrt -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP0:.*]] = mhlo.sqrt %[[ARG_]] - // CHECK: %[[TMP1:.*]] = mhlo.sqrt %[[TMP0]] - // CHECK: %[[TMP2:.*]] = mhlo.sqrt %[[TMP1]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = mhlo.sqrt %arg : (tensor<*xf32>) -> tensor<*xf32> - %1 = mhlo.sqrt %0 : (tensor<*xf32>) -> tensor<*xf32> - %2 = mhlo.sqrt %1 : (tensor<*xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @sqrt -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[TMP0:.*]] = mhlo.sqrt %[[FLAT_ARG]] : tensor -// CHECK-SCF: %[[TMP1:.*]] = mhlo.sqrt %[[TMP0]] : tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.sqrt %[[TMP1]] : tensor -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Don't cluster ranked operations. -// CHECK-LABEL: @sqrt_ranked -// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>) -func.func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { - // CHECK-NOT: rank_specialization_cluster - %0 = mhlo.sqrt %arg : (tensor<3x?xf32>) -> tensor<3x?xf32> - %1 = mhlo.sqrt %0 : (tensor<3x?xf32>) -> tensor<3x?xf32> - %2 = mhlo.sqrt %1 : (tensor<3x?xf32>) -> tensor<3x?xf32> - func.return %2 : tensor<3x?xf32> -} - -// CHECK-SCF-LABEL: @sqrt_ranked -// CHECK-SCF-NOT: dynamic_reshape -// CHECK-SCF: return - -// ----- - -// Operation with mixed ranked and unranked operands. -// CHECK-LABEL: @select_mixed -// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>) -func.func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>, - %arg2: tensor<2xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]]) - // CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>) - // CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = "chlo.broadcast_select"(%pred, %arg1, %arg2) - : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @select_mixed -// CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor, tensor, tensor) -// CHECK-SCF: return - -// ----- - -// Unary CHLO operation. -// CHECK-LABEL: @tan -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) ({ - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG_]] - // CHECK: %[[TMP1:.*]] = chlo.tan %[[TMP0]] - // CHECK: %[[TMP2:.*]] = chlo.tan %[[TMP1]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = chlo.tan %arg : tensor<*xf32> -> tensor<*xf32> - %1 = chlo.tan %0 : tensor<*xf32> -> tensor<*xf32> - %2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @tan -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor -// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Composition of unary/binary CHLO and unary MHLO ops. -// CHECK-LABEL: @mixed -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) -func.func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>) - -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG0_]] - // CHECK: %[[TMP1:.*]] = mhlo.sqrt %[[ARG1_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] - // CHECK: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %[[ARG2_]] - // CHECK: %[[TMP4:.*]] = mhlo.sqrt %[[TMP3]] - // CHECK: %[[TMP5:.*]] = chlo.tan %[[TMP4]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP5]]) - // CHECK: return %[[RES]] - %0 = chlo.tan %arg0 : tensor<*xf32> -> tensor<*xf32> - %1 = mhlo.sqrt %arg1 : (tensor<*xf32>) -> tensor<*xf32> - %2 = chlo.broadcast_multiply %0, %1 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3 = chlo.broadcast_add %2, %arg2 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %4 = mhlo.sqrt %3 : (tensor<*xf32>) -> tensor<*xf32> - %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32> - func.return %5 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @mixed -// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = mhlo.sqrt %{{.*}} : tensor -// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP4:.*]] = mhlo.sqrt %[[TMP3]] : tensor -// CHECK-SCF: chlo.tan %[[TMP4]] : tensor - -// ----- - -// Constant cluster operand. -// CHECK-LABEL: @relu -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor): - // CHECK: %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = chlo.broadcast_maximum %0, %arg - : (tensor, tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xf32>) -// CHECK-SCF: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor, tensor) -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Cluster with binary non-broadcasting operation. -// CHECK-LABEL: @angle -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex>) -func.func @angle(%arg : tensor<*xcomplex>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) - // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex>): - // CHECK: %[[IMAG:.*]] = mhlo.imag %[[ARG_]] - // CHECK: %[[REAL:.*]] = mhlo.real %[[ARG_]] - // CHECK: %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = mhlo.imag %arg : (tensor<*xcomplex>) -> tensor<*xf32> - %1 = mhlo.real %arg : (tensor<*xcomplex>) -> tensor<*xf32> - %2 = mhlo.atan2 %0, %1 : tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @angle -// CHECK-SCF-SAME: (%[[ARG:.*]]: tensor<*xcomplex>) -// CHECK-SCF: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF: %[[FLAT_ARG:.*]] = mhlo.dynamic_reshape %[[ARG]], %[[FLAT_SHAPE]] : (tensor<*xcomplex>, tensor<1xindex>) -> tensor> -// CHECK-SCF: %[[IMAG:.*]] = mhlo.imag %[[FLAT_ARG]] : (tensor>) -// CHECK-SCF: %[[REAL:.*]] = mhlo.real %[[FLAT_ARG]] : (tensor>) -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor - // CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[SHAPE]] : (tensor, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// Scalar cluster operand. -// CHECK-LABEL: @xlogy -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[C0_:.*]]: tensor, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = #chlo} - // CHECK: %[[TMP1:.*]] = mhlo.log %[[ARG1_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]] - // CHECK: %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP3]]) - // CHECK: return %[[RES]] - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = tensor.cast %0 : tensor to tensor - %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = #chlo} - : (tensor<*xf32>, tensor) -> tensor<*xi1> - %3 = mhlo.log %arg1 : (tensor<*xf32>) -> tensor<*xf32> - %4 = chlo.broadcast_multiply %arg0, %3 - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %5 = chlo.broadcast_select %2, %1, %4 - : (tensor<*xi1>, tensor, tensor<*xf32>) -> tensor<*xf32> - func.return %5 : tensor<*xf32> -} - -// CHECK-SCF: @xlogy -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.00{{.*}}> -// Lhs scalar case: -// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = arith.cmpi eq, %[[LHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = scf.if %[[LHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG0]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[SCALAR]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[FLAT_NON_SCALAR]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Rhs scalar case: -// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = arith.cmpi eq, %[[RHS_N]], %[[C1]] -// CHECK-SCF: %{{.*}} = scf.if %[[RHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG1]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_NON_SCALAR]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[SCALAR]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Equal shapes case: -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF: %{{.*}} = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_ARG0]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[FLAT_ARG1]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = arith.cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %{{.*}} = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[REDUCED_ARG0]], %[[ZERO]] {comparison_direction = #chlo} : (tensor, tensor) -// CHECK-SCF-DAG: %[[TMP0:.*]] = mhlo.log %[[REDUCED_ARG1]] : tensor -// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[TMP0]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor, tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// ... -// Reshape the result. -// CHECK-SCF: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF: %[[S0_:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF: %[[TMP:.*]] = shape.broadcast %[[S0_]], %[[S1]] -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.broadcast %[[S0]], %[[TMP]] -// CHECK-SCF: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[RES_SHAPE]] : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @mul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): - // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) - // CHECK: return %[[RES]] - %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @mul -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-SCF-DAG: %[[C2:.*]] = arith.constant 2 -// CHECK-SCF-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// Lhs scalar case: -// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = arith.cmpi eq, %[[LHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG0]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Rhs scalar case: -// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = arith.cmpi eq, %[[RHS_N]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[SCALAR:.*]] = mhlo.reshape %[[ARG1]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Equal shapes case: -// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Find maximum reduced rank. -// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0 -// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1 -// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = arith.cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = arith.select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]] -// Generic case 1: -// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C1]] -// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_LE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 2: -// CHECK-SCF: %[[MAX_RED_RANK_LE_2:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C2]] -// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_LE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: else -// Generic case 3: -// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = arith.cmpi ule, %[[MAX_RED_RANK]], %[[C3]] -// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3" -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[EXT_SHAPE_ARG0_]] -// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[EXT_SHAPE_ARG1_]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor, tensor) -// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] -// CHECK-SCF: scf.yield %[[INNER_RES_]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]] -// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]] -// Reshape the result. -// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// CHECK-LABEL: @merge_clusters -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -func.func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) - -> tensor<*xf64> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>): - // CHECK: %[[TMP0:.*]] = mhlo.tanh %[[ARG0_]] - // CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]] - // CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = "chlo.rank_specialization_cluster"(%arg0) ({ - ^bb0(%arg0_: tensor<*xf64>): - %1 = mhlo.tanh %arg0_ : (tensor<*xf64>) -> tensor<*xf64> - "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> () - }) : (tensor<*xf64>) -> (tensor<*xf64>) - %2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({ - ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>): - %6 = "chlo.broadcast_add"(%3, %4) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - %7 = "chlo.broadcast_add"(%6, %5) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - "chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> () - }) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) - func.return %2 : tensor<*xf64> -} - -// ----- - -// CHECK-LABEL: @all_equal_shapes_inferrable -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -func.func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) - -> tensor<*xf64> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) - // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>) - // CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]]) - // CHECK: return %[[RES]] - %0 = "mhlo.add"(%arg0, %arg1) - : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - func.return %0 : tensor<*xf64> -} - -// CHECK-SCF-LABEL: @all_equal_shapes_inferrable -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_S]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_S]] -// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[FLAT_RES]], %[[S0]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// All shapes are equal, which is inferrable through the select op. -// CHECK-LABEL: @relu_grad -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) - // CHECK: %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32} - // CHECK: %[[TMP1:.*]] = mhlo.compare GT, %[[ARG0_]], %[[TMP0]] - // CHECK: %[[TMP2:.*]] = mhlo.select %[[TMP1]], %[[ARG1_]], %[[TMP0]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) - // CHECK: return %[[RES]] - %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - %2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu_grad -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} -// CHECK-SCF-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[FLAT0]], %[[ZERO]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.select %[[PRED]], %[[FLAT1]], %[[ZERO]] -// CHECK-SCF-DAG: %[[RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[S1]] -// CHECK-SCF: return %[[RES]] - -// ----- - -// Find shape equivalences through surrounding constraints. -// CHECK-LABEL: @relu_grad -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -func.func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] - // CHECK-DAG: %[[CSTR_EQ:.*]] = shape.cstr_eq %[[S0]], %[[S1]] - // CHECK: %[[RES:.*]] = shape.assuming %[[CSTR_EQ]] - // CHECK: %[[INNER_RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) - // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>): - // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}+00 : f32} - // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[ARG0_]], %[[ZERO]] - // CHECK-DAG: %[[INNER_INNER_RES:.*]] = mhlo.select %[[PRED]], %[[ARG1_]], %[[ZERO]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_INNER_RES]]) - // CHECK: shape.assuming_yield %[[INNER_RES]] - // CHECK: return %[[RES]] - %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor - %1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor - %2 = shape.cstr_eq %0, %1 : tensor, tensor - %3 = shape.assuming %2 -> tensor<*xf32> { - %4 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} - : (tensor<*xf32>) -> tensor<*xf32> - %5 = "mhlo.compare"(%arg0, %4) {comparison_direction = #mhlo} - : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> - %6 = "mhlo.select"(%5, %arg1, %4) - : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - shape.assuming_yield %6 : tensor<*xf32> - } - func.return %3 : tensor<*xf32> -} - -// CHECK-SCF-LABEL: @relu_grad -// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[CSTR_EQ:.*]] = shape.cstr_eq %0, %1 -// CHECK-SCF: %[[RES:.*]] = shape.assuming %[[CSTR_EQ]] -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] -// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] -// CHECK-SCF-DAG: %[[FLAT0:.*]] = mhlo.dynamic_reshape %[[ARG0]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[FLAT1:.*]] = mhlo.dynamic_reshape %[[ARG1]], %[[FLAT_SHAPE]] -// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} -// CHECK-SCF-DAG: %[[PRED:.*]] = mhlo.compare GT, %[[FLAT0]], %[[ZERO]] -// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.select %[[PRED]], %[[FLAT1]], %[[ZERO]] -// CHECK-SCF-DAG: %[[INNER_RES:.*]] = mhlo.dynamic_reshape %[[UNSHAPED_RES]], %[[S1]] -// CHECK-SCF: shape.assuming_yield %[[INNER_RES]] -// CHECK-SCF: return %[[RES]] diff --git a/xla/mlir_hlo/tests/shape-component-analysis.mlir b/xla/mlir_hlo/tests/shape-component-analysis.mlir deleted file mode 100644 index e088b434518bf..0000000000000 --- a/xla/mlir_hlo/tests/shape-component-analysis.mlir +++ /dev/null @@ -1,365 +0,0 @@ -// RUN: mlir-hlo-opt --test-print-shape-components --split-input-file %s | FileCheck %s - -// CHECK-LABEL: Testing : assuming -func.func @assuming(%arg0: tensor, %arg1: tensor, %arg2 : !shape.witness) -> tensor<2xi32> { - %0:2 = shape.assuming %arg2 -> (tensor, tensor) { - shape.assuming_yield %arg0, %arg1 : tensor, tensor - } - %1 = shape.shape_of %0#0 : tensor -> tensor<2xindex> - %2 = shape.shape_of %0#1 : tensor -> tensor<2xindex> - %3 = arith.index_cast %1 : tensor<2xindex> to tensor<2xi32> - %4 = arith.index_cast %2 : tensor<2xindex> to tensor<2xi32> - // CHECK: Value info for %5 = mhlo.add %3, %4 : tensor<2xi32> - // CHECK-NEXT: s0 + s1 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 1)[0] - // CHECK-NEXT: s0 + s1 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 1)[1] - %5 = mhlo.add %3, %4 : tensor<2xi32> - // CHECK: Value info for %6 = mhlo.multiply %5, %4 : tensor<2xi32> - // CHECK-NEXT: (s0 + s1) * s2 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 1)[0] - // CHECK-NEXT: s2 = shapeof( of type 'tensor' at index: 1)[0] - // CHECK-NEXT: (s0 + s1) * s2 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 1)[1] - // CHECK-NEXT: s2 = shapeof( of type 'tensor' at index: 1)[1] - %6 = mhlo.multiply %5, %4 : tensor<2xi32> - func.return %6 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: Testing : num_elements -func.func @num_elements(%arg0: tensor) -> index { - // CHECK: Value info for %0 = shape.shape_of %arg0 : tensor -> tensor<4xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: 8 - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[2] - // CHECK-NEXT: 64 - %0 = shape.shape_of %arg0 : tensor -> tensor<4xindex> - // CHECK: Value info for %1 = shape.num_elements %0 : tensor<4xindex> -> index: - // CHECK-NEXT: (s0 * s1) * 512 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 0)[2] - %1 = shape.num_elements %0 : tensor<4xindex> -> index - func.return %1 : index -} - -// ----- - -// CHECK-LABEL: Testing : dynamic_broadcast_in_dim -func.func @dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor) -> tensor<2xindex> { - %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: Value info for %2 = shape.shape_of %1 : tensor -> tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - %2 = shape.shape_of %1 : tensor -> tensor<2xindex> - func.return %2 : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : dynamic_reshape -func.func @dynamic_reshape(%arg0: tensor, %arg1: tensor) -> tensor<2xindex> { - %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex> - %1 = "mhlo.dynamic_reshape"(%arg0, %0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: Value info for %2 = shape.shape_of %1 : tensor -> tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - %2 = shape.shape_of %1 : tensor -> tensor<2xindex> - func.return %2 : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : reduce -func.func @reduce(%arg0: tensor, %arg1: tensor) -> tensor<2xindex> { - %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%a: tensor, %b: tensor): - %26 = mhlo.add %a, %b : tensor - "mhlo.return"(%26) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - // CHECK: Value info for %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[2] - %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - func.return %1 : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : transpose -func.func @transpose(%arg0: tensor) -> tensor<2xindex> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: Value info for %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - func.return %1 : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : select -func.func @select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xindex> { - %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - // CHECK: Value info for %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 1)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 1)[1] - %1 = shape.shape_of %0 : tensor -> tensor<2xindex> - func.return %1 : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : dim -func.func @dim(%arg0: tensor) -> tensor<2xindex> { - %c0 = arith.constant 0 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %t = tensor.from_elements %d0, %d0 : tensor<2xindex> - // CHECK: Value info for %{{.*}} = tensor.from_elements %{{.*}}, %{{.*}} : tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - func.return %t : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : extract -func.func @extract(%arg0: tensor) -> tensor<2xindex> { - %shape = shape.shape_of %arg0 : tensor -> tensor<2xindex> - %c1 = arith.constant 1 : index - %d0 = tensor.extract %shape[%c1] : tensor<2xindex> - // CHECK: Value info for %{{.*}} = tensor.from_elements %{{.*}}, %{{.*}} : tensor<2xindex> - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - %t = tensor.from_elements %d0, %d0 : tensor<2xindex> - func.return %t : tensor<2xindex> -} - -// ----- - -// CHECK-LABEL: Testing : symbolic_constraint -func.func @symbolic_constraint( - %arg0: tensor - {rt.symbolic_shape = dense<[-3, -2]> : tensor<2xi64>}, - %arg1: tensor - {rt.symbolic_shape = dense<[-4, -2]> : tensor<2xi64>} -) -> tensor<2xi32> { - %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex> - %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> - %2 = arith.index_cast %0 : tensor<2xindex> to tensor<2xi32> - %3 = arith.index_cast %1 : tensor<2xindex> to tensor<2xi32> - // CHECK: Value info for %4 = mhlo.add %2, %3 : tensor<2xi32>: - // CHECK-NEXT: s0 + s1 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 1)[0] - // CHECK-NEXT: s0 + s1 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - // CHECK-NEXT: s1 = shapeof( of type 'tensor' at index: 0)[1] - %4 = mhlo.add %2, %3 : tensor<2xi32> - func.return %4 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: Testing : dynamic_reshape -func.func @dynamic_reshape(%arg0: tensor, %arg1: tensor<4xi32>) - -> tensor { - %0 = shape.shape_of %arg0 : tensor -> tensor<4xindex> - %1 = shape.num_elements %0 : tensor<4xindex> -> index - %2 = mhlo.compute_reshape_shape %1, %arg1 : (index, tensor<4xi32>) - -> tensor<4xi32> - // CHECK: Shape info for %3 = mhlo.dynamic_reshape %arg0, %2 : (tensor, tensor<4xi32>) -> tensor - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = %2 = mhlo.compute_reshape_shape %1, %arg1 : (index, tensor<4xi32>) -> tensor<4xi32>[0] - // CHECK-NEXT: 8 - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = %2 = mhlo.compute_reshape_shape %1, %arg1 : (index, tensor<4xi32>) -> tensor<4xi32>[2] - // CHECK-NEXT: 64 - %3 = "mhlo.dynamic_reshape"(%arg0, %2) - : (tensor, tensor<4xi32>) -> tensor - func.return %3 : tensor -} - -// ----- - -// Larger examples. - -// CHECK-LABEL: Testing : softmax -func.func @softmax(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<-1> : tensor<1xi64> - %1 = "mhlo.convert"(%arg0) : (tensor) -> tensor - %2 = mhlo.constant dense<0xFF800000> : tensor - %3 = "mhlo.reduce"(%1, %2) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %26 = mhlo.maximum %arg1, %arg2 : tensor - "mhlo.return"(%26) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - %4 = "mhlo.convert"(%3) : (tensor) -> tensor - %cst = arith.constant dense<1> : tensor<1xi32> - // CHECK: Value info for %5 = shape.shape_of - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - %5 = shape.shape_of %4 : tensor -> tensor<1xindex> - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - // CHECK: Value info for %{{.*}} = tensor.extract - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - %6 = tensor.extract %5[%c0] : tensor<1xindex> - // CHECK: Value info for %{{.*}} = tensor.from_elements - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: 1 - %7 = tensor.from_elements %6, %c1 : tensor<2xindex> - %8 = "mhlo.dynamic_reshape"(%4, %7) : (tensor, tensor<2xindex>) -> tensor - %9 = shape.shape_of %arg0 : tensor -> tensor<2xindex> - %10 = shape.shape_of %8 : tensor -> tensor<2xindex> - %11 = shape.cstr_broadcastable %9, %10 : tensor<2xindex>, tensor<2xindex> - %12 = shape.assuming %11 -> (tensor) { - // CHECK: Value info for %{{.*}} = shape.shape_of %arg0 : tensor -> tensor<2xindex>: - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[1] - %26 = shape.shape_of %arg0 : tensor -> tensor<2xindex> - // CHECK: Value info for %{{.*}} = shape.shape_of - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: 1 - %27 = shape.shape_of %8 : tensor -> tensor<2xindex> - %28 = shape.broadcast %26, %27 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> - %29 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %28) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - %30 = "mhlo.dynamic_broadcast_in_dim"(%8, %28) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - %31 = mhlo.subtract %29, %30 : tensor - shape.assuming_yield %31 : tensor - } - %13 = "mhlo.exponential"(%12) : (tensor) -> tensor - %14 = "mhlo.convert"(%13) : (tensor) -> tensor - %15 = mhlo.constant dense<0.000000e+00> : tensor - %16 = "mhlo.reduce"(%14, %15) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %26 = mhlo.add %arg1, %arg2 : tensor - "mhlo.return"(%26) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - %17 = "mhlo.convert"(%16) : (tensor) -> tensor - %cst_0 = arith.constant dense<1> : tensor<1xi32> - %18 = shape.shape_of %17 : tensor -> tensor<1xindex> - %c1_1 = arith.constant 1 : index - %c0_2 = arith.constant 0 : index - %19 = tensor.extract %18[%c0_2] : tensor<1xindex> - %20 = tensor.from_elements %19, %c1_1 : tensor<2xindex> - %21 = "mhlo.dynamic_reshape"(%17, %20) : (tensor, tensor<2xindex>) -> tensor - %22 = shape.shape_of %13 : tensor -> tensor<2xindex> - %23 = shape.shape_of %21 : tensor -> tensor<2xindex> - %24 = shape.cstr_broadcastable %22, %23 : tensor<2xindex>, tensor<2xindex> - %25 = shape.assuming %24 -> (tensor) { - %26 = shape.shape_of %13 : tensor -> tensor<2xindex> - %27 = shape.shape_of %21 : tensor -> tensor<2xindex> - %28 = shape.broadcast %26, %27 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> - %29 = "mhlo.dynamic_broadcast_in_dim"(%13, %28) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - %30 = "mhlo.dynamic_broadcast_in_dim"(%21, %28) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - %31 = mhlo.divide %29, %30 : tensor - shape.assuming_yield %31 : tensor - } - func.return %25 : tensor -} - -// ----- - -// CHECK-LABEL: Testing : reshape_integration -func.func @reshape_integration(%arg0: tensor<512x512xf32>, %arg1: tensor, %arg2: tensor<4xi32>, %arg3: tensor<512xf32>, %arg4: tensor, %arg5: tensor<512xf32>, %arg6: tensor<512xf32>, %arg7: tensor<512x2048xf32>, %arg8: tensor<2048xf32>, %arg9: tensor<2048x512xf32>, %arg10: tensor<512xf32>, %arg11: tensor<512xf32>, %arg12: tensor<512xf32>) -> tensor { - %0 = mhlo.constant dense<512> : tensor<1xi32> - %1 = shape.shape_of %arg1 : tensor -> tensor<4xindex> - %2 = shape.num_elements %1 : tensor<4xindex> -> index - %3 = mhlo.cstr_reshapable %2, %arg2 : (index, tensor<4xi32>) -> !shape.witness - %4 = "mhlo.dynamic_reshape"(%arg1, %arg2) : (tensor, tensor<4xi32>) -> tensor - %5 = "mhlo.transpose"(%4) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor) -> tensor - %6 = "mhlo.transpose"(%5) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor - %7 = shape.shape_of %6 : tensor -> tensor<4xindex> - %8 = arith.index_cast %7 : tensor<4xindex> to tensor<4xi32> - %9 = "mhlo.slice"(%8) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> - %10 = "mhlo.reshape"(%9) : (tensor<1xi32>) -> tensor - %11 = "mhlo.slice"(%8) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32> - %12 = "mhlo.reshape"(%11) : (tensor<1xi32>) -> tensor - %13 = mhlo.multiply %10, %12 : tensor - %14 = "mhlo.reshape"(%13) : (tensor) -> tensor<1xi32> - // CHECK: Value info for %15 = "mhlo.concatenate"(%14, %0) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK-NEXT: s0 * s1 with - // CHECK-NEXT: s0 = of type 'tensor<4xi32>' at index: 2[0] - // CHECK-NEXT: s1 = of type 'tensor<4xi32>' at index: 2[2] - // CHECK-NEXT: 512 - %15 = "mhlo.concatenate"(%14, %0) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - // CHECK: Value info for %16 = shape.shape_of %6 : tensor -> tensor<4xindex>: - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = of type 'tensor<4xi32>' at index: 2[0] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = of type 'tensor<4xi32>' at index: 2[2] - // CHECK-NEXT: 64 - // CHECK-NEXT: 8 - %16 = shape.shape_of %6 : tensor -> tensor<4xindex> - %17 = shape.num_elements %16 : tensor<4xindex> -> index - %18 = mhlo.cstr_reshapable %17, %15 : (index, tensor<2xi32>) -> !shape.witness - %19 = shape.assuming %18 -> (tensor) { - %21 = "mhlo.dynamic_reshape"(%6, %15) : (tensor, tensor<2xi32>) -> tensor - shape.assuming_yield %21 : tensor - } - func.return %19 : tensor -} - -// ----- - -// CHECK-LABEL: Testing : broadcast -func.func @broadcast(%arg0 : tensor, %arg1 : tensor<1x?x7xf32>) - -> tensor<3xindex> { - %s0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> - %s1 = shape.shape_of %arg1 : tensor<1x?x7xf32> -> tensor<3xindex> - // CHECK: Value info for %2 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>: - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor' at index: 0)[0] - // CHECK-NEXT: 5 - // CHECK-NEXT: 7 - %0 = shape.broadcast %s0, %s1 : tensor<3xindex>, tensor<3xindex> - -> tensor<3xindex> - func.return %0 : tensor<3xindex> -} - -// ----- - -// CHECK-LABEL: Testing : broadcast -func.func @broadcast(%arg0 : tensor, %arg1 : tensor<1x5x?x?xf32>) - -> tensor<4xindex> { - %s0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> - %s1 = shape.shape_of %arg1 : tensor<1x5x?x?xf32> -> tensor<4xindex> - // CHECK: Value info for %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<4xindex> -> tensor<4xindex>: - // CHECK-NEXT: 1 - // CHECK-NEXT: 5 - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = shapeof( of type 'tensor<1x5x?x?xf32>' at index: 1)[2] - // CHECK-NEXT: s0 with - // CHECK-NEXT: s0 = %2 = shape.broadcast %{{.*}}, %{{.*}} : tensor<1xindex>, tensor<4xindex> -> tensor<4xindex>[3] - %0 = shape.broadcast %s0, %s1 : tensor<1xindex>, tensor<4xindex> - -> tensor<4xindex> - func.return %0 : tensor<4xindex> -} diff --git a/xla/mlir_hlo/tests/test_userange.mlir b/xla/mlir_hlo/tests/test_userange.mlir deleted file mode 100644 index 88b01dcf8c9ac..0000000000000 --- a/xla/mlir_hlo/tests/test_userange.mlir +++ /dev/null @@ -1,118 +0,0 @@ -// RUN: mlir-hlo-opt -test-print-userange -split-input-file %s | FileCheck %s - -// CHECK-LABEL: Testing : func_empty -func.func @func_empty() { - func.return -} -// CHECK: ---- UserangeAnalysis ----- -// CHECK-NEXT: --------------------------- - -// ----- - -// CHECK-LABEL: Testing : useRangeGap -func.func @useRangeGap(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - %1 = memref.alloc() : memref<2xf32> - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3 -^bb2: - "lmhlo.negate"(%arg2, %0) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg2, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3 -^bb3: - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(7, 7), (13, 13)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(9, 9), (15, 15)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc - -// ----- - -// CHECK-LABEL: Testing : loopWithNestedRegion -func.func @loopWithNestedRegion(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - %1 = memref.alloc() : memref<2xf32> - %2 = memref.alloc() : memref<2xf32> - %3 = memref.alloc() : memref<2xf32> - cf.br ^bb1 -^bb1: - %4 = scf.if %arg0 -> (memref<2xf32>) { - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - scf.yield %2 : memref<2xf32> - } else { - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - scf.yield %2 : memref<2xf32> - } - cf.br ^bb2 -^bb2: - cf.cond_br %arg0, ^bb1, ^bb3 -^bb3: - "lmhlo.negate"(%arg1, %2) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.negate"(%arg1, %3) : (memref<2xf32>, memref<2xf32>) -> () - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 23)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 23)} -// CHECK: Value: %[[A2:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 25)} -// CHECK: Value: %[[A3:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(27, 27)} -// CHECK: Value: %[[A4:.*]] = scf.if -// CHECK: Userange: {(19, 19)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc -// CHECK: %[[A2]] = memref.alloc -// CHECK: %[[A3]] = memref.alloc -// CHECK: %[[A4]] = scf.if - -// ----- - -// CHECK-LABEL: Testing : condBranchWithAlias -func.func @condBranchWithAlias(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) -{ - %0 = memref.alloc() : memref<2xf32> - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - "lmhlo.negate"(%arg1, %0) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3(%0 : memref<2xf32>) -^bb2: - %1 = memref.alloc() : memref<2xf32> - "lmhlo.negate"(%arg1, %1) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb3(%1 : memref<2xf32>) -^bb3(%2 : memref<2xf32>): - %3 = memref.alloc() : memref<2xf32> - "lmhlo.copy"(%2, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - "lmhlo.copy"(%3, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - %4 = memref.alloc() : memref<2xf32> - "lmhlo.copy"(%4, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - cf.br ^bb4(%0 : memref<2xf32>) -^bb4(%5 : memref<2xf32>): - "lmhlo.copy"(%5, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - func.return -} -// CHECK: Value: %[[A0:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(5, 7), (15, 27)} -// CHECK: Value: %[[A1:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(11, 17)} -// CHECK: Value: %[[A2:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(19, 19)} -// CHECK: Value: %[[A3:.*]] = memref.alloc -// CHECK-NEXT: Userange: {(23, 23)} -// CHECK: Value: of type 'memref<2xf32>' at index: 0 -// CHECK-SAME: Userange: {(15, 17)} -// CHECK: Value: of type 'memref<2xf32>' at index: 0 -// CHECK-SAME: Userange: {(27, 27)} -// CHECK: %[[A0]] = memref.alloc -// CHECK: %[[A1]] = memref.alloc -// CHECK: %[[A2]] = memref.alloc -// CHECK: %[[A3]] = memref.alloc diff --git a/xla/mlir_hlo/tests/tile_loops.mlir b/xla/mlir_hlo/tests/tile_loops.mlir index 8cfce78ba48c1..1357736a52da6 100644 --- a/xla/mlir_hlo/tests/tile_loops.mlir +++ b/xla/mlir_hlo/tests/tile_loops.mlir @@ -16,7 +16,7 @@ func.func @parallel_loop(%arg0: memref<16xf32>, %arg1: memref<16xf32>) { %2 = memref.load %arg0[%arg2] : memref<16xf32> %3 = math.log %2 : f32 memref.store %3, %0[%arg2] : memref<16xf32> - scf.yield + scf.reduce } %1 = bufferization.to_tensor %0 : memref<16xf32> bufferization.materialize_in_destination %1 in writable %arg1 @@ -38,14 +38,14 @@ func.func @statically_unrolled(%arg0: memref) { // CHECK: scf.parallel // CHECK: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg1, %arg0[%arg1] : memref - scf.yield + scf.reduce } scf.parallel (%arg1) = (%c0) to (%c36) step (%c3) { // CHECK: scf.parallel // CHECK: scf.parallel // CHECK: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg1, %arg0[%arg1] : memref - scf.yield + scf.reduce } "lmhlo.terminator"() : () -> () @@ -62,22 +62,22 @@ func.func @dynamically_unrolled(%arg0: memref, %arg1 : index) { scf.parallel (%arg2) = (%c0) to (%arg1) step (%c1) { // CHECK-NOT: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg2, %arg0[%arg2] : memref - scf.yield + scf.reduce } scf.parallel (%arg2) = (%c0) to (%c10) step (%c1) { // CHECK-NOT: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg2, %arg0[%arg2] : memref - scf.yield + scf.reduce } scf.parallel (%arg2) = (%c10) to (%c32) step (%c1) { // CHECK-NOT: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg2, %arg0[%arg2] : memref - scf.yield + scf.reduce } scf.parallel (%arg2) = (%c0) to (%c32) step (%c10) { // CHECK-NOT: scf.parallel {{.*}} to (%[[C4]]) memref.store %arg2, %arg0[%arg2] : memref - scf.yield + scf.reduce } "lmhlo.terminator"() : () -> () @@ -99,7 +99,7 @@ func.func @complex_access(%arg0: memref<16xf32>, %arg1: memref<4xf32>) { %2 = memref.load %arg0[%idx] : memref<16xf32> %3 = math.log %2 : f32 memref.store %3, %0[%arg2] : memref<4xf32> - scf.yield + scf.reduce } %1 = bufferization.to_tensor %0 : memref<4xf32> bufferization.materialize_in_destination %1 in writable %arg1 diff --git a/xla/mlir_hlo/tests/vectorize_copy.mlir b/xla/mlir_hlo/tests/vectorize_copy.mlir index 8c57281a7041c..fa852549c9057 100644 --- a/xla/mlir_hlo/tests/vectorize_copy.mlir +++ b/xla/mlir_hlo/tests/vectorize_copy.mlir @@ -1,7 +1,6 @@ // RUN: mlir-hlo-opt %s --vectorize-copy --split-input-file | FileCheck %s -func.func @vectorize_copy(%arg: memref<2x2xf32>) -> memref<2x2xf32> { - %subview = memref.subview %arg[0, 0] [2, 2] [1, 1] : memref<2x2xf32> to memref<2x2xf32, strided<[16, 1]>> +func.func @vectorize_copy(%subview: memref<2x2xf32, strided<[16, 1]>>) -> memref<2x2xf32> { %alloc = memref.alloc() : memref<2x2xf32> memref.copy %subview, %alloc : memref<2x2xf32, strided<[16, 1]>> to memref<2x2xf32> return %alloc : memref<2x2xf32> diff --git a/xla/mlir_hlo/tools/CMakeLists.txt b/xla/mlir_hlo/tools/CMakeLists.txt index 0f3d1c8579574..0305ff606734c 100644 --- a/xla/mlir_hlo/tools/CMakeLists.txt +++ b/xla/mlir_hlo/tools/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt b/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt index a7a65587e1190..65988d923416b 100644 --- a/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,14 +25,10 @@ set(LIBS AllMhloPasses DeallocationPasses LmhloDialect - LmhloGPUDialect LmhloPasses MLIRBufferTransforms - MLIRHLOAnalysis MLIRHLOGPUTransforms - MLIRHLOTestAnalysis MhloRegisterDialects - MhloTestAnalysis ) add_llvm_executable(mlir-hlo-opt mlir-hlo-opt.cc DEPENDS diff --git a/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc index 742bf5707b8ea..cde8923187ed5 100644 --- a/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc +++ b/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ limitations under the License. #include "deallocation/transforms/passes.h" #include "lhlo/IR/lhlo_ops.h" #include "lhlo/transforms/passes.h" -#include "lhlo_gpu/IR/lhlo_gpu_ops.h" #include "mhlo/IR/register.h" #include "mhlo/transforms/passes.h" #include "mlir/InitAllDialects.h" @@ -42,6 +41,6 @@ int main(int argc, char** argv) { registerAllExtensions(registry); mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); - registry.insert(); + registry.insert(); return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/xla/mlir_hlo/transforms/CMakeLists.txt b/xla/mlir_hlo/transforms/CMakeLists.txt index c011ce3a72422..45c8f24a796f5 100644 --- a/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/transforms/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ add_public_tablegen_target(LMHLOGPUTransformsPassIncGen) add_mlir_library(MLIRBufferTransforms alloc_to_arg_pass.cc - buffer_packing.cc bufferize.cc bufferize_pass.cc collapse_parallel_loops_to_1d_pass.cc @@ -49,7 +48,6 @@ add_mlir_library(MLIRBufferTransforms LINK_LIBS PUBLIC ChloOps MLIRGPUDialect - MLIRHLOAnalysis MLIRIR MLIRMathTransforms MLIRPass @@ -75,7 +73,6 @@ add_mlir_library(MLIRHLOGPUTransforms LINK_LIBS PUBLIC MLIRArithTransforms MLIRGPUDialect - MLIRHLOAnalysis MLIRIR MLIRMemRefTransforms MLIRPass diff --git a/xla/mlir_hlo/transforms/alloc_to_arg_pass.cc b/xla/mlir_hlo/transforms/alloc_to_arg_pass.cc index 54694553fc833..8768825ec1332 100644 --- a/xla/mlir_hlo/transforms/alloc_to_arg_pass.cc +++ b/xla/mlir_hlo/transforms/alloc_to_arg_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/buffer_packing.cc b/xla/mlir_hlo/transforms/buffer_packing.cc deleted file mode 100644 index e7f957634ee5c..0000000000000 --- a/xla/mlir_hlo/transforms/buffer_packing.cc +++ /dev/null @@ -1,494 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "analysis/userange_analysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" -#include "transforms/passes.h" -#include "utils/hlo_utils.h" - -namespace mlir { - -#define GEN_PASS_DEF_BUFFERPACKING -#define GEN_PASS_DEF_MEMORYCOUNT -#include "transforms/passes.h.inc" - -namespace { - -/// Returns the length of an userange interval. -size_t computeUserangeSize(const UseInterval &interval) { - return interval.end - interval.start + 1; -} - -/// Compute the byte size of a given Value. -size_t computeByteSize(const Value &v) { - auto type = v.getType().cast(); - return type.getNumElements() * type.getElementTypeBitWidth() / 8; -} - -/// Compute the 64 byte alinged segments of a given Value. -size_t computeAlignedSegments(const Value &v) { - size_t padding = 64; - size_t bytes = computeByteSize(v); - return std::ceil(bytes / (double)padding); -} - -/// The buffer offset information. -struct AllocBufferOffset { - public: - AllocBufferOffset(Value source, size_t offset) - : source(source), offset(offset) {} - - Value source; - size_t offset; -}; - -/// Contains the information to create a new buffer, that is used to pack -/// other buffers. -struct PackedBuffer { - public: - PackedBuffer(size_t numSegments, - std::vector &packedBuffers) - : numSegments(numSegments), allocBufferOffsets(packedBuffers) {} - - size_t numSegments; - std::vector allocBufferOffsets; -}; - -/// Contains the information about a buffers allocation for sorting and checking -/// if it fits into other buffers and vise versa. -/// This structure contains the allocation value, the first and last userangeid -/// of a buffer, the window id, the number of alligned 64 byte segments and all -/// userange intervals. -struct AllocationInfo { - public: - AllocationInfo(Value alloc, size_t allocUserangeId, size_t firstUse, - size_t lastUse, size_t numSegments, size_t windowId, - const UseInterval::Vector *userangeIntervals) - : alloc(alloc), - allocUserangeId(allocUserangeId), - firstUse(firstUse), - lastUse(lastUse), - numSegments(numSegments), - windowId(windowId), - userangeIntervals(userangeIntervals) {} - - /// The allocation value. - Value alloc; - - /// The id of allocation based on the Userange Analysis. - size_t allocUserangeId; - - /// The first use of the buffer. - size_t firstUse; - - /// The last use of the buffer based on the Userange Analysis. - size_t lastUse; - - /// The number of 64 byte aligned segments of contigous memory. - size_t numSegments; - - /// The window id of the allocation position. - size_t windowId; - - /// The userange intervals of the buffer. - const UseInterval::Vector *userangeIntervals; - - /// Compute the gaps of the alloc userange with the number of segments. The - /// maxUserangeId is used to add a dummy gap from the last used id to the - /// maxUserangeId. By default the maxUserangeId is zero and no gap is added. - std::list> computeGaps( - size_t maxUserangeId = 0) { - std::list> gaps; - - // The previous gap ending, initially set to 0. - size_t gapEnd = 0; - - for (const auto *useRangeIter = userangeIntervals->begin(); - useRangeIter < userangeIntervals->end(); ++useRangeIter) { - // Add a gap if the end is not equal to the start. - if (gapEnd < useRangeIter->start) - gaps.emplace_back(UseInterval(gapEnd, useRangeIter->start - 1), - numSegments); - gapEnd = useRangeIter->end + 1; - } - - // Add a dummy gap behind the last use of the buffer. - if (gapEnd < maxUserangeId) { - gaps.emplace_back(UseInterval(gapEnd, maxUserangeId), numSegments); - } - - return gaps; - } - - /// Compute the userange size. - size_t getUserangeSize() const { return lastUse - firstUse + 1; } -}; - -// Comparator to sort allocation informations by window id, userange and by -// number of memory segments. -class AllocInfoWinIdComparator { - public: - bool operator()(const AllocationInfo &a, const AllocationInfo &b) { - if (a.windowId == b.windowId) { - if (a.allocUserangeId == b.allocUserangeId) - return a.numSegments > b.numSegments; - return a.allocUserangeId > b.allocUserangeId; - } - return a.windowId < b.windowId; - } -}; - -// Comparator to sort the allocation informations by number of segments. -class AllocInfoMemSizeCompare { - public: - bool operator()(const AllocationInfo &a, const AllocationInfo &b) { - return a.numSegments > b.numSegments; - } -}; - -/// This approach computes an allocation information list and sorts it by -/// a given comparator. From top to bottom the algortihm tries to fill userange -/// gaps with appropriate buffers behind it, to optimze the memory. It is a bin -/// packing approach. -template -class SortedPackingStrategy { - public: - using AllocInfoList = std::vector; - - public: - /// Constructs the Sorted Packing Strategy. The window size is used as sliding - /// window size. Allocation userangepositions that are in the same range are - /// mapped to the same window id. So the information of the allocation - /// starting position is blured. - SortedPackingStrategy(size_t windowSize, CompareT compare) - : windowSize(windowSize), compare(compare) {} - - /// Optimize the buffer allocations. - void optimze(const mlir::bufferization::BufferPlacementAllocs &allocs, - const UserangeAnalysis &userangeAnalysis, - std::vector &packedBuffers) { - AllocInfoList allocInfos; - allocInfos.reserve(std::distance(allocs.begin(), allocs.end())); - - // Create allocInformations and store them in allocInfos. - size_t maxUserangeId = - computeAllocationInfos(allocInfos, userangeAnalysis, allocs); - - // Sort the allocation infos. - std::sort(allocInfos.begin(), allocInfos.end(), compare); - - for (auto currentIter = allocInfos.begin(); currentIter != allocInfos.end(); - ++currentIter) { - std::vector allocBufferOffsets{ - AllocBufferOffset(currentIter->alloc, 0)}; - - // Compute userange gaps. - std::list> gaps = - currentIter->computeGaps(maxUserangeId); - - if (gaps.empty()) continue; - - for (auto checkedAllocInfoIter = std::next(currentIter); - checkedAllocInfoIter != allocInfos.end();) { - // Check if a gap exists to pack the memory into. - // If not continue. - if (!findGapAndUpdate(gaps, allocBufferOffsets, *checkedAllocInfoIter, - *currentIter)) { - ++checkedAllocInfoIter; - continue; - } - checkedAllocInfoIter = allocInfos.erase(checkedAllocInfoIter); - } - // Add the current buffer offets to the packed infos. - packedBuffers.emplace_back(currentIter->numSegments * 64, - allocBufferOffsets); - } - } - - private: - const size_t windowSize; - const CompareT compare; - - /// We try to find an appropriate userange gap to pack the buffer into it. - /// If we find one we update only the gaps and the buffer offset map. - bool findGapAndUpdate(std::list> &gaps, - std::vector &allocBufferOffsets, - const AllocationInfo &allocToPack, - const AllocationInfo &allocToPackInto) { - // Check if the buffer to pack into has enough memory. - if (allocToPackInto.numSegments < allocToPack.numSegments) return false; - for (auto gapIter = gaps.begin(); gapIter != gaps.end();) { - // The list is sorted, so we can break here. - if (gapIter->first.start > allocToPack.firstUse) break; - - // Checks if enough contiguous memory segments are free or if the current - // gap is out of bounds. - if (gapIter->second < allocToPack.numSegments || - allocToPack.firstUse < gapIter->first.start || - allocToPack.lastUse > gapIter->first.end) { - ++gapIter; - continue; - } - - // Stores the packed buffer with the offset. - allocBufferOffsets.emplace_back( - allocToPack.alloc, - (allocToPackInto.numSegments - gapIter->second) * 64); - - // Update gap segments, will removed later if no free contigous memory - // exists. It is needed to split the interval, if not the full gap is - // used. - size_t freeContiguousMemory = gapIter->second; - gapIter->second = freeContiguousMemory - allocToPack.numSegments; - - // Check if the gap must be splitted. If so, then the current gap must be - // trimmed accordingly. Therefore, new gaps are created in front and after - // the current gap. - if (computeUserangeSize(gapIter->first) > allocToPack.getUserangeSize()) { - size_t oldStart = gapIter->first.start; - size_t oldEnd = gapIter->first.end; - gapIter->first.end = allocToPack.lastUse; - gapIter->first.start = allocToPack.firstUse; - - // Insert a new gap behind. - if (allocToPack.lastUse < oldEnd) - gaps.insert( - std::next(gapIter), - std::make_pair(UseInterval(allocToPack.lastUse + 1, oldEnd), - freeContiguousMemory)); - // Insert a new gap before. - if (allocToPack.firstUse > oldStart) - gaps.insert( - gapIter, - std::make_pair(UseInterval(oldStart, allocToPack.firstUse - 1), - freeContiguousMemory)); - } - - // If a gap interval has no free contiguous memory anymore, erease it from - // list. - if (gapIter->second <= 0) gapIter = gaps.erase(gapIter); - - return true; - } - return false; - } - - /// Aggreagtes the allocation informations of the allocs and returns the - /// maximal userange. - size_t computeAllocationInfos( - AllocInfoList &allocInfos, const UserangeAnalysis &userangeAnalysis, - const mlir::bufferization::BufferPlacementAllocs &allocs) { - // Create allocInformations and store them in allocInfos. - size_t maxUserangeId = 0; - - for (auto &allocEntry : allocs) { - Value v = std::get<0>(allocEntry); - auto userangeIntervals = userangeAnalysis.getUserangeInterval(v); - - if (!userangeIntervals) continue; - - // Computes the userange id of the allocation. - size_t allocUserangeId = userangeAnalysis.computeId(v, v.getDefiningOp()); - - // Computes the last use of the allocated buffer. - size_t lastUse = std::prev((*userangeIntervals.value()).end())->end; - - // Computes the first use of the allocated buffer. - size_t firstUse = (*userangeIntervals.value()).begin()->start; - - // Computes the number of aligend segments of the buffer. - size_t numSegments = computeAlignedSegments(v); - maxUserangeId = std::max(maxUserangeId, lastUse); - allocInfos.emplace_back(v, allocUserangeId, firstUse, lastUse, - numSegments, 0, userangeIntervals.value()); - } - - // If the window size is zero we need no sorting anymore. - if (windowSize == 0) return maxUserangeId; - // Sorts the allocation informations to compute the window id. The window id - // is used to blur the userange starting position of an allocation. - std::sort(allocInfos.begin(), allocInfos.end(), - [](const AllocationInfo &a, const AllocationInfo &b) { - return a.allocUserangeId < b.allocUserangeId; - }); - - // resize window id - size_t windowId = 0; - size_t lastAllocUserangeId = 0; - for (auto &allocationInfo : allocInfos) { - if (allocationInfo.allocUserangeId > lastAllocUserangeId + windowSize) - ++windowId; - - lastAllocUserangeId = allocationInfo.allocUserangeId; - allocationInfo.windowId = windowId; - } - return maxUserangeId; - } -}; - -/// Pass to pack buffer together to optimize the memeory consumption and to -/// save allocation operations. A strategy must be passed as a template -/// argument. -class BufferPacking : bufferization::BufferPlacementTransformationBase { - public: - template - BufferPacking(Operation *op, StrategyT strategy) - : BufferPlacementTransformationBase(op), - userangeAnalysis(op, allocs, aliases), - dominators(op) { - std::vector packedBuffers; - strategy.optimze(allocs, userangeAnalysis, packedBuffers); - - for (auto &packedBuffer : packedBuffers) { - // Find common dominators. - Block *block = findAllocationsDominator(packedBuffer.allocBufferOffsets); - // Find alloc position operation. - mlir::OpBuilder packBuilder(&(block->front())); - auto location = block->front().getLoc(); - auto memrefType = - MemRefType::get({static_cast(packedBuffer.numSegments)}, - packBuilder.getIntegerType(8)); - Value targetBuffer = - packBuilder.create(location, memrefType); - - for (auto &packInfo : packedBuffer.allocBufferOffsets) { - Value currentAlloc = packInfo.source; - size_t offset = packInfo.offset; - Operation *viewDefOp = currentAlloc.getDefiningOp(); - Location loc = viewDefOp->getLoc(); - mlir::OpBuilder viewBuilder(viewDefOp); - - // Create a arithmetic ConstantOp with the aligned offset. - Value constantOp = viewBuilder.create( - loc, viewBuilder.getIndexType(), - viewBuilder.getIntegerAttr(viewBuilder.getIndexType(), offset)); - - // Store the operands for the ViewOp. - SmallVector newOperands{targetBuffer}; - newOperands.push_back(constantOp); - - auto shape = currentAlloc.getType().cast(); - - // Create a ViewOp with the shape of the old alloc and use the created - // packed alloc and the constant for the operands. - Value viewOp = - viewBuilder.create(loc, shape, newOperands); - - // Replace all old allocs references with the created ViewOp and - // afterwards remove the old allocs. - currentAlloc.replaceAllUsesWith(viewOp); - viewDefOp->erase(); - } - } - } - - private: - UserangeAnalysis userangeAnalysis; - /// The current dominance info. - DominanceInfo dominators; - - /// Find the block that dominates all buffer allocations. - Block *findAllocationsDominator( - const std::vector &packingInfos) { - SmallPtrSet allocValues; - for (auto &packInfo : packingInfos) { - allocValues.insert(packInfo.source); - } - - // Find common dominators. - return findCommonDominator(packingInfos.begin()->source, allocValues, - dominators); - } -}; - -/// Tries to pack allocated buffer together to save allocation operations and -/// memory. The window size is used as sliding window size. Allocation -/// userangepoitions that are in the same range are mapped to the same window -/// id. The information of the allocation starting position is blured. -struct BufferPackingPass : public impl::BufferPackingBase { - explicit BufferPackingPass(unsigned windowSize) { - this->window_size_ = windowSize; - } - - void runOnOperation() override { - if (window_size_ == 0) { - SortedPackingStrategy strategy( - window_size_, AllocInfoMemSizeCompare()); - BufferPacking packing(getOperation(), strategy); - } else { - SortedPackingStrategy strategy( - window_size_, AllocInfoWinIdComparator()); - BufferPacking packing(getOperation(), strategy); - } - } -}; - -/// Pass to find all allocations and to compute memory usage. -struct MemoryCountPass : impl::MemoryCountBase { - void runOnOperation() override { - Operation *op = getOperation(); - std::vector allocs; - op->walk([&](MemoryEffectOpInterface opInterface) { - // Try to find a single allocation result. - SmallVector effects; - opInterface.getEffects(effects); - - SmallVector allocateResultEffects; - llvm::copy_if( - effects, std::back_inserter(allocateResultEffects), - [=](MemoryEffects::EffectInstance &it) { - Value value = it.getValue(); - return isa(it.getEffect()) && value && - value.isa() && - it.getResource() != - SideEffects::AutomaticAllocationScopeResource::get(); - }); - - if (allocateResultEffects.size() != 1) return; - // Insert allocation. - allocs.push_back(allocateResultEffects[0].getValue()); - }); - auto output = mlir::hlo::computeMemory(allocs); - llvm::outs() << "Memory Count Pass:\n" - << output.first << ";" << output.second << "\n"; - } -}; - -} // namespace - -std::unique_ptr> createBufferPackingPass( - unsigned windowSize) { - return std::make_unique(windowSize); -} - -std::unique_ptr> createMemoryCountPass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/xla/mlir_hlo/transforms/bufferize.cc b/xla/mlir_hlo/transforms/bufferize.cc index 6e7a85f91c980..93063d24190f6 100644 --- a/xla/mlir_hlo/transforms/bufferize.cc +++ b/xla/mlir_hlo/transforms/bufferize.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/bufferize_pass.cc b/xla/mlir_hlo/transforms/bufferize_pass.cc index 057e7289ca306..93756798df52b 100644 --- a/xla/mlir_hlo/transforms/bufferize_pass.cc +++ b/xla/mlir_hlo/transforms/bufferize_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -138,6 +138,9 @@ struct ComputeOpAndFuncBufferizePass .insert(); + arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext:: + registerBufferizableOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); mhlo::registerBufferizableOpInterfaceExternalModels(registry); shape::registerBufferizableOpInterfaceExternalModels(registry); @@ -153,8 +156,7 @@ struct ComputeOpAndFuncBufferizePass // will be migrated to BufferizableOpInterface-based bufferization. options.opFilter.allowDialect(); + shape::ShapeDialect, vector::VectorDialect>(); if (failed(bufferization::bufferizeOp(getOperation(), options))) { signalPassFailure(); diff --git a/xla/mlir_hlo/transforms/collapse_parallel_loops_to_1d_pass.cc b/xla/mlir_hlo/transforms/collapse_parallel_loops_to_1d_pass.cc index 79bb80ed0ac04..a73a9c3bc7772 100644 --- a/xla/mlir_hlo/transforms/collapse_parallel_loops_to_1d_pass.cc +++ b/xla/mlir_hlo/transforms/collapse_parallel_loops_to_1d_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" #include "transforms/passes.h" using ::mlir::scf::ParallelOp; @@ -48,12 +49,13 @@ struct CollapseParallelLoopsTo1D using namespace mlir; void mlir::CollapseParallelLoopsTo1D::runOnOperation() { + IRRewriter rewriter(&getContext()); getOperation()->walk([&](ParallelOp op) { unsigned numLoops = op.getNumLoops(); if (numLoops == 1) return; std::vector combinedLoops(numLoops); std::iota(combinedLoops.begin(), combinedLoops.end(), 0u); - mlir::collapseParallelLoops(op, {combinedLoops}); + mlir::collapseParallelLoops(rewriter, op, {combinedLoops}); }); } diff --git a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc index 6a2a898791049..196a394ece577 100644 --- a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc +++ b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index bef78953286b9..9df69afbaf55a 100644 --- a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc index 292d70ae51431..3e22aa5588832 100644 --- a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc +++ b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -20,6 +21,7 @@ limitations under the License. #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" @@ -48,6 +50,10 @@ namespace { /// that are currently required, currently mixing std, linalg and gpu. class GpuKernelToNVVMPass : public impl::GpuKernelToNVVMPassBase { + public: + explicit GpuKernelToNVVMPass(bool useBarePtrCallConv) { + this->useBarePtrCallConv = useBarePtrCallConv; + } void runOnOperation() override; }; @@ -95,9 +101,26 @@ void GpuKernelToNVVMPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions llvmOpts(&getContext(), DataLayout(getOperation())); + llvmOpts.useBarePtrCallConv = useBarePtrCallConv; LLVMTypeConverter converter(&getContext(), llvmOpts); + populateCommonPatterns(converter, patterns); populateGpuToNVVMConversionPatterns(converter, patterns); + + populateGpuMemorySpaceAttributeConversions( + converter, [](gpu::AddressSpace space) { + switch (space) { + case gpu::AddressSpace::Global: + return 1; + case gpu::AddressSpace::Workgroup: + return 3; + case gpu::AddressSpace::Private: + return 5; + } + assert(false && "unknown address space enum value"); + return 0; + }); + ConversionTarget target(getContext()); configureGpuToNVVMConversionLegality(target); if (failed( @@ -120,8 +143,9 @@ void GpuKernelToROCDLPass::runOnOperation() { } } -std::unique_ptr> createGpuKernelToNvvmPass() { - return std::make_unique(); +std::unique_ptr> createGpuKernelToNvvmPass( + bool useBarePtrCallConv) { + return std::make_unique(useBarePtrCallConv); } std::unique_ptr> createGpuKernelToRocdlPass() { diff --git a/xla/mlir_hlo/transforms/gpu_passes.cc b/xla/mlir_hlo/transforms/gpu_passes.cc index bd04deef84ce2..57d2a55c0151e 100644 --- a/xla/mlir_hlo/transforms/gpu_passes.cc +++ b/xla/mlir_hlo/transforms/gpu_passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/gpu_passes.h b/xla/mlir_hlo/transforms/gpu_passes.h index d691555e02c53..a36f45e021faa 100644 --- a/xla/mlir_hlo/transforms/gpu_passes.h +++ b/xla/mlir_hlo/transforms/gpu_passes.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ ArrayAttr getWrittenOperandsAttribute(Operation* op); /// Pass that transforms gpu modules in standard dialect to NNVM. std::unique_ptr> -createGpuKernelToNvvmPass(); +createGpuKernelToNvvmPass(bool useBarePtrCallConv = false); /// Pass that transforms gpu modules in standard dialect to ROCDL. std::unique_ptr> diff --git a/xla/mlir_hlo/transforms/gpu_passes.td b/xla/mlir_hlo/transforms/gpu_passes.td index e277fb6b2d30a..5ee6f75951f1f 100644 --- a/xla/mlir_hlo/transforms/gpu_passes.td +++ b/xla/mlir_hlo/transforms/gpu_passes.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,10 @@ include "mlir/Pass/PassBase.td" def GpuKernelToNVVMPass : Pass<"gpu-kernel-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Pass to transform a gpu module to nvvm."; let dependentDialects = ["LLVM::LLVMDialect", "NVVM::NVVMDialect"]; + let options = [ + Option<"useBarePtrCallConv", "use-bare-ptr-call-conv", "bool", + /*default=*/"false", "Use bare pointer memref to llvm lowering">, + ]; let constructor = "createGpuKernelToNvvmPass()"; } diff --git a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc index 30df1fede701b..6d67f53ca1df3 100644 --- a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc +++ b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/naive_copy_removal.cc b/xla/mlir_hlo/transforms/naive_copy_removal.cc index ddd6e6916971f..55ab2fbb2e0ee 100644 --- a/xla/mlir_hlo/transforms/naive_copy_removal.cc +++ b/xla/mlir_hlo/transforms/naive_copy_removal.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/passes.h b/xla/mlir_hlo/transforms/passes.h index ac322b01aac4a..3ecee8797e076 100644 --- a/xla/mlir_hlo/transforms/passes.h +++ b/xla/mlir_hlo/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,7 +44,6 @@ using BufferizePatternsCallback = std::function> createBufferPackingPass( - unsigned windowSize = 5); - -/// Creates a pass that tests the useranges of the UserangeAnalysis. -std::unique_ptr> createTestUserangePass(); - -/// Creates a pass that prints the analysis results of ShapeComponentsAnalysis. -std::unique_ptr> -createTestShapeComponentAnalysisPass(); - -/// Creates a pass that computes the allocated memory. -std::unique_ptr> createMemoryCountPass(); - // Pass to lower index cast on tensors to tensor dialect. std::unique_ptr> createLowerIndexCastPass(); diff --git a/xla/mlir_hlo/transforms/passes.td b/xla/mlir_hlo/transforms/passes.td index f97752add6253..88b0649dab305 100644 --- a/xla/mlir_hlo/transforms/passes.td +++ b/xla/mlir_hlo/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,26 +18,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def BufferPacking : Pass<"buffer-packing", "func::FuncOp"> { - let summary = "Pass to pack allocated buffer to reduce memory consumption."; - let description = [{The pass tries to pack smaller buffers into larger buffers. - To do this, it sorts all allocated buffers by multiple criteria depends on the - selected window-size. - After this sorting, the buffers are checked whether subsequent buffers can be - packed into them.}]; - let dependentDialects = ["func::FuncDialect","memref::MemRefDialect", - "arith::ArithDialect"]; - let constructor = "createBufferPackingPass()"; - let options = [ - Option<"window_size_", "window-size", "unsigned", - /*default=*/"5", "The window size blurs the start position of an" - "allocated buffer. Buffers allocated in the same sliding window area" - "are treated equally in terms of starting position, withing the" - "sliding window area they are sorted by memory size." - "A window size of zero sorts the buffers only by memory size.">, - ]; -} - def CollapseParallelLoopsTo1DPass : Pass<"collapse-parallel-loops-to-1d"> { let summary = "Collapses multidimensional loops."; let description = [{ The pass converts a multidimensional `scf.parallel` loop @@ -71,24 +51,6 @@ def TileLoopsPass : Pass<"tile-loops", "func::FuncOp"> { let dependentDialects = ["affine::AffineDialect"]; } -def MemoryCount : Pass<"memory-count", "func::FuncOp"> { - let summary = "Test pass to count the allocated memory of a module."; - let description = [{A test pass that prints the size of allocated memory of a - module.}]; - let constructor = "createMemoryCountPass()"; -} - -def TestUserange : Pass<"test-print-userange", "func::FuncOp"> { - let summary = "Test pass for checking userange intervals."; - let constructor = "createTestUserangePass()"; -} - -def TestShapeComponentAnalysis : Pass<"test-print-shape-components", - "func::FuncOp"> { - let summary = "Test pass for analyzing shape components."; - let constructor = "createTestShapeComponentAnalysisPass()"; -} - def LowerIndexCastPass : Pass<"lower-index-cast", "mlir::func::FuncOp"> { let summary = "Lower index cast on tensors to tensor dialect"; diff --git a/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc b/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc index 27a53e5a65a27..aa1873478a0e7 100644 --- a/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc +++ b/xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -198,7 +198,7 @@ LogicalResult PropagateStaticShapesPattern::matchAndRewrite( if (argsToDrop.none()) { return rewriter.notifyMatchFailure(funcOp, "no static shapes"); } - rewriter.updateRootInPlace(funcOp, [&] { + rewriter.modifyOpInPlace(funcOp, [&] { SmallVector argTypes; for (unsigned idx = 0; idx < argsToDrop.size(); ++idx) if (!argsToDrop[idx]) diff --git a/xla/mlir_hlo/transforms/rewriters.h b/xla/mlir_hlo/transforms/rewriters.h index b7196a358b97c..6e484b2237050 100644 --- a/xla/mlir_hlo/transforms/rewriters.h +++ b/xla/mlir_hlo/transforms/rewriters.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,11 +31,6 @@ void populateExtraBufferizePatterns( MLIRContext *context, bufferization::BufferizeTypeConverter *converter, RewritePatternSet *patterns); -/// Populate pattern to bufferize `linalg.tiled_loop`. -void populateTiledLoopBufferizePattern( - MLIRContext *context, bufferization::BufferizeTypeConverter *converter, - RewritePatternSet *patterns); - } // namespace mlir #endif // MLIR_HLO_TRANSFORMS_REWRITERS_H diff --git a/xla/mlir_hlo/transforms/test_hlo_transform_dialect_interpreter.cc b/xla/mlir_hlo/transforms/test_hlo_transform_dialect_interpreter.cc index 69a7529850649..99bfb28760403 100644 --- a/xla/mlir_hlo/transforms/test_hlo_transform_dialect_interpreter.cc +++ b/xla/mlir_hlo/transforms/test_hlo_transform_dialect_interpreter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/xla/mlir_hlo/transforms/tile_loops_pass.cc b/xla/mlir_hlo/transforms/tile_loops_pass.cc index dbfc9c54d5e1e..ee3b935cff277 100644 --- a/xla/mlir_hlo/transforms/tile_loops_pass.cc +++ b/xla/mlir_hlo/transforms/tile_loops_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/unbufferize_pass.cc b/xla/mlir_hlo/transforms/unbufferize_pass.cc index e9570004f8628..9b936bd759b58 100644 --- a/xla/mlir_hlo/transforms/unbufferize_pass.cc +++ b/xla/mlir_hlo/transforms/unbufferize_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/unroll_loops.cc b/xla/mlir_hlo/transforms/unroll_loops.cc index b638fd1a40039..94c9e4fe64c38 100644 --- a/xla/mlir_hlo/transforms/unroll_loops.cc +++ b/xla/mlir_hlo/transforms/unroll_loops.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/transforms/vectorize_copy.cc b/xla/mlir_hlo/transforms/vectorize_copy.cc index 4d1a9fa213e0b..1b68cd8b28b74 100644 --- a/xla/mlir_hlo/transforms/vectorize_copy.cc +++ b/xla/mlir_hlo/transforms/vectorize_copy.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/CMakeLists.txt b/xla/mlir_hlo/utils/CMakeLists.txt index 20cb83cb10c19..55227fdfb8ddf 100644 --- a/xla/mlir_hlo/utils/CMakeLists.txt +++ b/xla/mlir_hlo/utils/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/codegen_utils.cc b/xla/mlir_hlo/utils/codegen_utils.cc index e107cd761b3d5..fc0622ae04627 100644 --- a/xla/mlir_hlo/utils/codegen_utils.cc +++ b/xla/mlir_hlo/utils/codegen_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/codegen_utils.h b/xla/mlir_hlo/utils/codegen_utils.h index cbf9508e5a030..307d1198cb7e4 100644 --- a/xla/mlir_hlo/utils/codegen_utils.h +++ b/xla/mlir_hlo/utils/codegen_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/convert_op_folder.cc b/xla/mlir_hlo/utils/convert_op_folder.cc index 9421c6eb93aad..2468da81bfab9 100644 --- a/xla/mlir_hlo/utils/convert_op_folder.cc +++ b/xla/mlir_hlo/utils/convert_op_folder.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/convert_op_folder.h b/xla/mlir_hlo/utils/convert_op_folder.h index 1186a3014c125..39b3d56d19a0f 100644 --- a/xla/mlir_hlo/utils/convert_op_folder.h +++ b/xla/mlir_hlo/utils/convert_op_folder.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/cycle_detector.cc b/xla/mlir_hlo/utils/cycle_detector.cc index 3e36445ece0e5..e3901ae88cc74 100644 --- a/xla/mlir_hlo/utils/cycle_detector.cc +++ b/xla/mlir_hlo/utils/cycle_detector.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/cycle_detector.h b/xla/mlir_hlo/utils/cycle_detector.h index b8456c55872aa..9f08b754bff84 100644 --- a/xla/mlir_hlo/utils/cycle_detector.h +++ b/xla/mlir_hlo/utils/cycle_detector.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/cycle_detector_test.cc b/xla/mlir_hlo/utils/cycle_detector_test.cc index 4c4903e500694..dd0fdacfb3f9d 100644 --- a/xla/mlir_hlo/utils/cycle_detector_test.cc +++ b/xla/mlir_hlo/utils/cycle_detector_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/mlir_hlo/utils/hlo_utils.cc b/xla/mlir_hlo/utils/hlo_utils.cc index d14103f0ed048..fc02ada8840de 100644 --- a/xla/mlir_hlo/utils/hlo_utils.cc +++ b/xla/mlir_hlo/utils/hlo_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "utils/hlo_utils.h" #include +#include #include #include #include @@ -23,14 +24,19 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { static constexpr size_t kPaddingSize = 64; -DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, - bool allowEmpty) { +DenseI64ArrayAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, + bool allowEmpty) { TensorType xType = x.getType().dyn_cast(); TensorType yType = y.getType().dyn_cast(); if (!xType || !yType) return {}; @@ -56,9 +62,7 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y, std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), maxRank - minRank); - RankedTensorType type = - RankedTensorType::get({minRank}, b->getIntegerType(64)); - return DenseIntElementsAttr::get(type, broadcastDimensions); + return b->getDenseI64ArrayAttr(broadcastDimensions); } DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue) { @@ -165,7 +169,7 @@ DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit) { std::string lmhloToMhloOpName(llvm::StringRef opName, mlir::MLIRContext* context) { - assert(opName.startswith("lmhlo.") && "Expected an LMHLO op"); + assert(opName.starts_with("lmhlo.") && "Expected an LMHLO op"); if (opName == "lmhlo.dot") { return "mhlo.dot_general"; diff --git a/xla/mlir_hlo/utils/hlo_utils.h b/xla/mlir_hlo/utils/hlo_utils.h index 84322e23154ba..72f22992b4a94 100644 --- a/xla/mlir_hlo/utils/hlo_utils.h +++ b/xla/mlir_hlo/utils/hlo_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,10 +36,9 @@ namespace hlo { // between two ranked tensors. // If `allow_empty` is true, then null can be returned to mean that the // broadcast is an "identity". -mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, - mlir::Value x, - mlir::Value y, - bool allowEmpty = true); +mlir::DenseI64ArrayAttr getBroadcastDimensionsAttr(mlir::Builder* b, + mlir::Value x, mlir::Value y, + bool allowEmpty = true); // Get a constant splat for the given value of type. Requires value to be of // type static shaped RankedTensorType. diff --git a/xla/mlir_hlo/utils/placement_utils.h b/xla/mlir_hlo/utils/placement_utils.h index f4b8fb60a584a..8f544c95048c9 100644 --- a/xla/mlir_hlo/utils/placement_utils.h +++ b/xla/mlir_hlo/utils/placement_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/overflow_util.h b/xla/overflow_util.h index 6420af97ea8c1..00fa1c45518db 100644 --- a/xla/overflow_util.h +++ b/xla/overflow_util.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,18 +16,28 @@ limitations under the License. #ifndef XLA_OVERFLOW_UTIL_H_ #define XLA_OVERFLOW_UTIL_H_ -#include +#include #include #include +#include -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" namespace xla { -// Multiply two nonnegative int64_t's, returning negative for overflow -inline int64_t MultiplyWithoutOverflow(const int64_t x, const int64_t y) { +// Multiply two non-negative int64_t's, returning the two's complement result +// and a bool which is true when overflow or negative inputs occurs and false +// otherwise. +ABSL_ATTRIBUTE_ALWAYS_INLINE inline std::pair +OverflowSafeMultiply(const int64_t x, const int64_t y) { +#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow) + int64_t result; + bool bad = __builtin_mul_overflow(x, y, &result); + bad |= x < 0; + bad |= y < 0; + return std::make_pair(result, bad); +#else // Multiply in uint64_t rather than int64_t since signed overflow is // undefined. Negative values will wrap around to large unsigned values in the // casts (see section 4.7 [conv.integral] of the C++14 standard). @@ -35,28 +45,40 @@ inline int64_t MultiplyWithoutOverflow(const int64_t x, const int64_t y) { const uint64_t uy = y; const uint64_t uxy = ux * uy; + // Cast back to signed. + int64_t result = static_cast(uxy); + bool bad = result < 0; + // Check if we overflow uint64_t, using a cheap check if both inputs are small if (ABSL_PREDICT_FALSE((ux | uy) >> 32 != 0)) { - // Ensure nonnegativity. Note that negative numbers will appear "large" - // to the unsigned comparisons above. - CHECK(x >= 0 && y >= 0); - - // Otherwise, detect overflow using a division - if (ux != 0 && uxy / ux != uy) return -1; + if (x < 0 || y < 0) { + // Ensure nonnegativity. Note that negative numbers will appear "large" + // to the unsigned comparisons above. + bad = true; + } else if (ux != 0 && uxy / ux != uy) { + // Otherwise, detect overflow using a division + bad = true; + } } - - // Cast back to signed. Any negative value will signal an error. - return static_cast(uxy); + return std::make_pair(result, bad); +#endif } // Computes x + y and returns nullopt if it overflows. // // x and y must be signed integers. template -inline std::optional OverflowSafeAdd(T x, T y) { +ABSL_ATTRIBUTE_ALWAYS_INLINE inline std::optional OverflowSafeAdd(T x, T y) { static_assert(std::is_signed::value, "Only implemented for signed numbers T."); static_assert(std::is_integral::value, "Only implemented for integers T."); +#if ABSL_HAVE_BUILTIN(__builtin_add_overflow) + T result; + if (ABSL_PREDICT_FALSE(__builtin_add_overflow(x, y, &result))) { + return std::nullopt; + } + return result; +#else // "Signed integer overflow occurs on integer addition iff the operands have // the same sign and the sum has a sign opposite to that of the operands." // Hacker's Delight 2nd ed, p 28. @@ -69,6 +91,7 @@ inline std::optional OverflowSafeAdd(T x, T y) { return std::nullopt; } return sum; +#endif } } // namespace xla diff --git a/xla/packed_literal_reader.cc b/xla/packed_literal_reader.cc index ef8add8d91619..03cf165176e51 100644 --- a/xla/packed_literal_reader.cc +++ b/xla/packed_literal_reader.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,8 +38,8 @@ PackedLiteralReader::PackedLiteralReader(tsl::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr PackedLiteralReader::Read(const Shape& shape, - const Layout* layout) { +absl::StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ToString()); Shape literal_shape = shape; diff --git a/xla/packed_literal_reader.h b/xla/packed_literal_reader.h index 101b6236f0640..9103e7544ca57 100644 --- a/xla/packed_literal_reader.h +++ b/xla/packed_literal_reader.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,8 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. - StatusOr Read(const Shape& shape, const Layout* layout = nullptr); + absl::StatusOr Read(const Shape& shape, + const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. all available // packed literals have been read and we're at the end of the file. diff --git a/xla/parse_flags_from_env.cc b/xla/parse_flags_from_env.cc index 5cba95a543b58..0f58671ebff7d 100644 --- a/xla/parse_flags_from_env.cc +++ b/xla/parse_flags_from_env.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,8 +32,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/logging.h" -#include "tsl/util/command_line_flags.h" namespace xla { @@ -210,7 +210,7 @@ bool ParseFlagsFromEnvAndIgnoreUnknown( } } - return tsl::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); + return tsl::Flags::Parse(&env_argv->argc, env_argv->argv.data(), flag_list); } bool DieIfEnvHasUnknownFlagsLeft(absl::string_view envvar) { diff --git a/xla/parse_flags_from_env.h b/xla/parse_flags_from_env.h index 070176754e53a..01d476f22fa3d 100644 --- a/xla/parse_flags_from_env.h +++ b/xla/parse_flags_from_env.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,8 +51,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/xla/parse_flags_from_env_test.cc b/xla/parse_flags_from_env_test.cc index f8554f6931664..f00cb309c12a9 100644 --- a/xla/parse_flags_from_env_test.cc +++ b/xla/parse_flags_from_env_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,11 +24,11 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/subprocess.h" #include "tsl/platform/test.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/xla/permutation_util.cc b/xla/permutation_util.cc index e28c4bf89fbdd..040f210c0cc92 100644 --- a/xla/permutation_util.cc +++ b/xla/permutation_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ std::vector InversePermutation( DCHECK(IsPermutation(input_permutation)); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { - output_permutation.at(input_permutation.at(i)) = i; + output_permutation[input_permutation[i]] = i; } return output_permutation; } diff --git a/xla/permutation_util.h b/xla/permutation_util.h index b1e68a044e4fb..61dde9ca117c4 100644 --- a/xla/permutation_util.h +++ b/xla/permutation_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/permutation_util_test.cc b/xla/permutation_util_test.cc index 51cd0b9253efb..9597da742f09d 100644 --- a/xla/permutation_util_test.cc +++ b/xla/permutation_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index f31226380a6ee..061780fed377d 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -1,14 +1,16 @@ -# Placeholder: load py_proto_library -load("//xla:xla.bzl", "xla_cc_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load( "@tsl//tsl/platform:build_config.bzl", "tf_proto_library", ) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +# Placeholder: load py_proto_library +load("//xla:xla.bzl", "xla_cc_test") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -19,6 +21,7 @@ package_group( "//xla:internal", ], packages = [ + "//tensorflow/core/tfrt/ifrt/...", "//third_party/australis/...", "//third_party/openxla_pjrt_plugin/...", "//third_party/py/jax/...", @@ -88,7 +91,10 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/synchronization", + "@tsl//tsl/profiler/lib:connected_traceme", + "@tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -114,13 +120,17 @@ cc_library( hdrs = ["local_device_state.h"], deps = [ ":event_pool", + ":pjrt_common", ":semaphore", ":worker_thread", "//xla:status", "//xla:util", "//xla/client:local_client", "//xla/stream_executor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", ], @@ -161,13 +171,14 @@ cc_library( name = "pjrt_client", srcs = ["pjrt_client.cc"], hdrs = ["pjrt_client.h"], - visibility = ["//xla:friends"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":pjrt_common", ":pjrt_compiler", ":pjrt_device_description", ":pjrt_executable", ":pjrt_future", + ":pjrt_layout", ":utils", "//xla:literal", "//xla:shape_util", @@ -221,11 +232,13 @@ cc_library( name = "pjrt_executable", srcs = ["pjrt_executable.cc"], hdrs = ["pjrt_executable.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":compile_options_proto_cc", + ":executable_metadata_proto_cc", ":execute_options_proto_cc", ":pjrt_common", + ":pjrt_layout", "//xla:shape_layout", "//xla:shape_util", "//xla:status", @@ -247,7 +260,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", ], ) @@ -279,13 +291,12 @@ cc_library( name = "pjrt_compiler", srcs = ["pjrt_compiler.cc"], hdrs = ["pjrt_compiler.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":metrics", ":pjrt_device_description", ":pjrt_executable", "//xla/client:xla_computation", - "//xla/service:hlo_parser", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -320,13 +331,16 @@ cc_library( name = "pjrt_common", hdrs = ["pjrt_common.h"], visibility = [":friends"], + deps = [ + "@tsl//tsl/lib/gtl:int_type", + ], ) cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//xla:friends"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":layout_mode", "//xla:shape_util", @@ -353,6 +367,22 @@ cc_library( ], ) +cc_library( + name = "pjrt_layout", + hdrs = ["pjrt_layout.h"], + visibility = ["//xla:friends"], + deps = [ + "//xla:shape_util", + "//xla:statusor", + "//xla/service:hlo_parser", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "layout_mode", srcs = ["layout_mode.cc"], @@ -409,21 +439,26 @@ cc_library( name = "pjrt_stream_executor_client", srcs = ["pjrt_stream_executor_client.cc"], hdrs = ["pjrt_stream_executor_client.h"], - visibility = ["//xla:friends"], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":event_pool", + ":host_callback", ":local_device_state", ":metrics", ":mlir_to_hlo", ":pjrt_client", + ":pjrt_common", + ":pjrt_compiler", ":pjrt_executable", ":pjrt_future", + ":semaphore", ":tracked_device_buffer", ":transpose", ":utils", "//xla:cpu_function_runtime", "//xla:executable_run_options", "//xla:literal", + "//xla:shape_tree", "//xla:shape_util", "//xla:status", "//xla:statusor", @@ -434,6 +469,7 @@ cc_library( "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt/distributed:protocol_proto_cc", + "//xla/service:compiler", "//xla/service:computation_layout", "//xla/service:computation_placer", "//xla/service:executable", @@ -447,9 +483,11 @@ cc_library( "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -457,6 +495,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", "@tsl//tsl/framework:allocator", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", @@ -467,6 +506,7 @@ cc_library( "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:connected_traceme", + "@tsl//tsl/profiler/lib:context_types_hdrs", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -519,28 +559,43 @@ cc_library( visibility = [":friends"], deps = [ "//xla:status", + "//xla:statusor", + "//xla:util", "//xla/client:xla_computation", "//xla/mlir/utils:error_util", "//xla/mlir_hlo", + "//xla/mlir_hlo:hlo_dialect_registration", "//xla/mlir_hlo:mhlo_passes", "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", + "@stablehlo//:register", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", + "@stablehlo//:stablehlo_serialization", ], ) cc_library( name = "pjrt_future", hdrs = ["pjrt_future.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/functional:any_invocable", "@tsl//tsl/concurrency:async_value", @@ -548,13 +603,25 @@ cc_library( ], ) +cc_library( + name = "host_memory_spaces", + srcs = ["host_memory_spaces.cc"], + hdrs = ["host_memory_spaces.h"], + deps = [ + ":pjrt_client", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + # Transitional forwarding target. Use cpu:cpu_client instead. cc_library( name = "tfrt_cpu_pjrt_client", hdrs = ["tfrt_cpu_pjrt_client.h"], - visibility = [ + visibility = internal_visibility([ "//xla:friends", - ], + ]), deps = [ "//xla/pjrt/cpu:cpu_client", ], @@ -589,11 +656,14 @@ cc_library( visibility = [":friends"], deps = [ ":lru_cache", + "//xla:compiler_macros", + "//xla:ef57", "//xla:permutation_util", "//xla:status", "//xla:statusor", "//xla:util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -631,6 +701,7 @@ cc_library( hdrs = ["pjrt_c_api_client.h"], deps = [ ":compile_options_proto_cc", + ":mlir_to_hlo", ":pjrt_api", ":pjrt_client", ":pjrt_common", @@ -638,6 +709,7 @@ cc_library( ":pjrt_device_description", ":pjrt_executable", ":pjrt_future", + ":pjrt_layout", "//xla:literal", "//xla:shape_util", "//xla:status", @@ -651,6 +723,8 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", + "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:computation_placer_hdr", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", @@ -659,6 +733,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -745,7 +820,7 @@ cc_library( name = "host_callback", srcs = ["host_callback.cc"], hdrs = ["host_callback.h"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":pjrt_client", ":pjrt_future", @@ -790,3 +865,44 @@ tf_proto_library( srcs = ["execute_options.proto"], visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "executable_metadata_proto", + srcs = ["executable_metadata.proto"], + protodeps = [ + "//xla/service:hlo_proto", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "exceptions", + hdrs = ["exceptions.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//xla:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "status_casters", + hdrs = ["status_casters.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = [":friends"], + deps = [ + ":exceptions", + "//xla:status", + "//xla:statusor", + "@tsl//tsl/platform:macros", + ], +) diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index 580ac2258b519..357a778764f1b 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -1,4 +1,3 @@ -load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test") load( "@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -8,6 +7,7 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -53,6 +53,15 @@ cc_library( ], ) +cc_library( + name = "pjrt_c_api_custom_partitioner_extension_hdrs", + hdrs = ["pjrt_c_api_custom_partitioner_extension.h"], + visibility = ["//visibility:public"], + deps = [ + ":pjrt_c_api_hdrs", + ], +) + cc_library( name = "pjrt_c_api_wrapper_impl", srcs = ["pjrt_c_api_wrapper_impl.cc"], @@ -64,7 +73,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:status", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", @@ -78,11 +86,14 @@ cc_library( "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:computation_placer_hdr", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -92,8 +103,11 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@tsl//tsl/framework:allocator", + "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:connected_traceme", + "@tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -104,18 +118,20 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pjrt_c_api_hdrs", + ":pjrt_c_api_profiler_extension_hdrs", "//xla:shape_util", "//xla:status", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt/distributed:key_value_store_interface", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", @@ -123,6 +139,8 @@ cc_library( "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:connected_traceme", + "@tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -151,12 +169,32 @@ cc_library( ], ) +# PJRT CPU plugin. +xla_cc_binary( + name = "pjrt_c_api_cpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + ":pjrt_c_api_cpu", + ":pjrt_c_api_cpu_version_script.lds", + ], +) + cc_library( name = "pjrt_c_api_gpu_internal", srcs = ["pjrt_c_api_gpu_internal.cc"], hdrs = ["pjrt_c_api_gpu_internal.h"], visibility = ["//visibility:public"], deps = [ + ":pjrt_c_api_custom_partitioner_extension_hdrs", ":pjrt_c_api_gpu_extension_hdrs", ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", @@ -166,12 +204,23 @@ cc_library( "//xla/backends/profiler/plugin:plugin_tracer_impl", "//xla/backends/profiler/plugin:profiler_c_api_hdrs", "//xla/backends/profiler/plugin:profiler_error", + "//xla/client:local_client", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt/gpu:gpu_helpers", "//xla/pjrt/gpu:se_gpu_pjrt_client", + "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler + "//xla/python:custom_partition_callback", "//xla/python:inspect_sharding", # To register "InspectSharding" custom partitioning handler. + "//xla/service:compiler", "//xla/service:custom_call_target_registry", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -187,6 +236,9 @@ cc_library( ":pjrt_c_api_gpu_internal", ":pjrt_c_api_hdrs", ":pjrt_c_api_macros", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log:initialize", + "@tsl//tsl/platform", ], alwayslink = 1, ) @@ -254,9 +306,12 @@ xla_cc_test( "//xla:shape_util", "//xla:status", "//xla:statusor", + "//xla/ffi:ffi_api", + "//xla/ffi/api:ffi", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_future", + "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/service:custom_call_target_registry", "//xla/service:gpu_plugin", "//xla/tests:literal_test_util", @@ -285,11 +340,11 @@ xla_cc_test( "//xla:statusor", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt/distributed:in_memory_key_value_store", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", @@ -329,7 +384,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:pjrt_client", - "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/service:computation_placer_hdr", "//xla/service:hlo_parser", @@ -343,8 +397,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 5bd553df96126..ec1f828b3b98d 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,33 @@ # PJRT C API changelog +## 0.47 +* Added ``PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner``. +* Renamed host buffer semantics enum from ``PJRT_HostBufferSemantics_kZeroCopy`` + to ``PJRT_HostBufferSemantics_kImmutableZeroCopy``. + +## 0.46 (Feb 29, 2024) +* Update outdated struct sizes from previous changes to + ``PJRT_Device_AddressableMemories_Args`` and ``PJRT_ExecuteOptions``. + +## 0.45 (Feb 27, 2024) +* Breaking changes + * Added struct_size field to beginning of PJRT_Extension_Base. This is so + forwards and backwards compatibility logic can be implemented with extension + structs. + +## 0.44 (Feb 26, 2024) +* Changed all ``void*`` extension fields to have type ``PJRT_Extension_Base*`` + +## 0.43 (Feb 24, 2024) +* Added some new fields to PJRT_Executable_GetCompiledMemoryStats + +## 0.42 (Feb 13, 2024) +* Renamed all ``priv`` fields to ``extension_start`` + +## 0.41 (Feb 13, 2024) +* Renamed PJRT_Structure_Base to PJRT_Extension_Base +* Renamed PJRT_Structure_Type to PJRT_Extension_Type (and similarly for enum fields) + ## 0.40 (Nov 27, 2023) * Added PJRT_Executable_GetCompiledMemoryStats. @@ -21,7 +49,7 @@ PJRT_ExecuteOptions. * Deprecated PJRT_LoadedExecutable_Fingerprint ## 0.34 (Oct 9, 2023) -* Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler. +* Added PJRT_Extension_Type::PJRT_Extension_Type_Profiler. ## 0.33 (Oct 3, 2023) * Added PJRT_Client_CreateViewOfDeviceBuffer. @@ -30,9 +58,9 @@ PJRT_ExecuteOptions. * Added PJRT_Buffer_CopyToMemory. ## 0.31 (Sep 22, 2023) -* Added PJRT_Structure_Base. -* Added PJRT_Structure_Type. -* Renamed PJRT_Api.priv to PJRT_Api.extension_start. +* Added PJRT_Extension_Base. +* Added PJRT_Extension_Type. +* Renamed PJRT_Api.extension_start to PJRT_Api.extension_start. ## 0.30 (Sep 14, 2023) * Added PJRT_NamedValue_Type::PJRT_NamedValue_kBool. diff --git a/xla/pjrt/c/README.md b/xla/pjrt/c/README.md index cdead8e4d01b6..62855a48bde30 100644 --- a/xla/pjrt/c/README.md +++ b/xla/pjrt/c/README.md @@ -15,5 +15,7 @@ opaque to the frameworks. * [PJRT C API changelog](https://github.com/openxla/xla/blob/main/xla/pjrt/c/CHANGELOG.md) * [PJRT integration guide](https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md) +* [PJRT design docs](https://drive.google.com/corp/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN) +* [PJRT API ABI versioning and compatibility](https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit) * [PJRT Plugin Mechanism design doc](https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit) * [OpenXLA/IREE PJRT plugin implementation](https://github.com/openxla/openxla-pjrt-plugin) diff --git a/xla/pjrt/c/docs/pjrt_integration_guide.md b/xla/pjrt/c/docs/pjrt_integration_guide.md index 0dfd763dd428b..0cde6b00f632c 100644 --- a/xla/pjrt/c/docs/pjrt_integration_guide.md +++ b/xla/pjrt/c/docs/pjrt_integration_guide.md @@ -41,10 +41,13 @@ With the [wrapper](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d6 ### Step 2: Implement GetPjRtApi -You need to implement a method `GetPjRtApi` which returns a `PJRT_Api*` containing function pointers to PJRT C API implementations. Below is an example assuming implementing through wrapper (similar to [pjrt\_c\_api\_cpu.cc](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api_cpu.cc)): +You need to implement a method `GetPjRtApi` which returns a `PJRT_Api*` containing function pointers to PJRT C API implementations. Below is an example assuming implementing through wrapper (similar to [pjrt\_c\_api\_cpu.cc](https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_cpu.cc)): ``` -constexpr PJRT_Api pjrt_api = pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create); -const PJRT_Api* GetPjrtApi() { return &pjrt_api; } +const PJRT_Api* GetPjrtApi() { + static const PJRT_Api pjrt_api = + pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create); + return &pjrt_api; +} ``` ### Step 3: Test C API implementations @@ -132,17 +135,21 @@ print(jax.jit(lambda x: x * 2)(1.)) # => 2.0 # pmap -arr = jax.numpy.arange(jax.device_count()) -print(jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), - axis_name='i')(arr)) + +arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x + +jax.lax.psum(x, 'i'), axis_name='i')(arr)) + # single device: [0] + # 4 devices: [6 7 8 9] + ``` (We'll add instructions for running the jax unit tests against your plugin soon!) ## Example: JAX CUDA plugin 1. PJRT C API implementation through wrapper ([pjrt\_c\_api\_gpu.h](https://github.com/openxla/xla/blob/c23fbd601a017be25726fd6d624b22daa6a8a4e5/xla/pjrt/c/pjrt_c_api_gpu.h)). -1. Set up the entry point for the package ([setup.py](https://github.com/google/jax/blob/main/plugins/cuda/setup.py)). +1. Set up the entry point for the package ([setup.py](https://github.com/google/jax/blob/main/jax_plugins/cuda/setup.py)). 1. Implement an initialize() method ([\_\_init\_\_.py](https://github.com/google/jax/blob/a10854786b6d1bc92a65dd314916b151640789af/plugins/cuda/__init__.py#L31-L51)). 1. Can be tested with any jax tests for CUDA. +``` diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index e249630374a13..670501205c35c 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,9 +20,14 @@ limitations under the License. #include #include +// Read more on C API ABI versioning and compatibility here: +// https://docs.google.com/document/d/1TKB5NyGtdzrpgw5mpyFjVAhJjpSNdF31T6pjPl_UT2o/edit?usp=sharing + #define PJRT_STRUCT_SIZE(struct_type, last_field) \ offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field) +// Must update PJRT_DEFINE_STRUCT_TRAITS with the new `last_field` after +// adding a new member to a struct. #define PJRT_DEFINE_STRUCT_TRAITS(sname, last_field) \ typedef struct sname sname; \ enum { sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field) } @@ -31,6 +36,24 @@ limitations under the License. extern "C" { #endif +// ------------------------------- Extensions ---------------------------------- + +typedef enum { + PJRT_Extension_Type_Gpu_Custom_Call = 0, + PJRT_Extension_Type_Profiler, + PJRT_Extension_Type_Custom_Partitioner, +} PJRT_Extension_Type; + +// PJRT_Extension_Base contains a type and a pointer to next +// PJRT_Extension_Base. The framework can go through this chain to find an +// extension and identify it with the type. +typedef struct PJRT_Extension_Base { + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; +} PJRT_Extension_Base; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); + // --------------------------------- Version ----------------------------------- // Incremented when an ABI-incompatible change is made to the interface. @@ -53,14 +76,14 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 40 +#define PJRT_API_MINOR 47 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in // this header that the implementation was compiled with. struct PJRT_Api_Version { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; int major_version; // out int minor_version; // out }; @@ -77,7 +100,7 @@ typedef struct PJRT_Error PJRT_Error; struct PJRT_Error_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Error* error; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Destroy_Args, error); @@ -87,7 +110,7 @@ typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); struct PJRT_Error_Message_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const PJRT_Error* error; // Has the lifetime of `error`. const char* message; // out @@ -121,7 +144,7 @@ typedef enum { struct PJRT_Error_GetCode_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const PJRT_Error* error; PJRT_Error_Code code; // out }; @@ -151,7 +174,7 @@ typedef enum { // Named value for key-value pairs. struct PJRT_NamedValue { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const char* name; size_t name_size; PJRT_NamedValue_Type type; @@ -172,16 +195,16 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); struct PJRT_Plugin_Initialize_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, priv); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, extension_start); // One-time plugin setup. Must be called before any other functions are called. typedef PJRT_Error* PJRT_Plugin_Initialize(PJRT_Plugin_Initialize_Args* args); struct PJRT_Plugin_Attributes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // Returned attributes have the lifetime of the process. const PJRT_NamedValue* attributes; // out size_t num_attributes; // out @@ -205,7 +228,7 @@ typedef struct PJRT_Event PJRT_Event; struct PJRT_Event_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Destroy_Args, event); @@ -215,7 +238,7 @@ typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); struct PJRT_Event_IsReady_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Event* event; bool is_ready; // out }; @@ -227,7 +250,7 @@ typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); struct PJRT_Event_Error_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Error_Args, event); @@ -245,7 +268,7 @@ typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); struct PJRT_Event_Await_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Event* event; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Await_Args, event); @@ -263,7 +286,7 @@ typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg); struct PJRT_Event_OnReady_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Event* event; PJRT_Event_OnReadyCallback callback; // `user_arg` allows `callback` to be called with arbitrary arguments (e.g. @@ -297,7 +320,7 @@ typedef void (*PJRT_KeyValueGetCallback_ValueDeleter)(char* value); struct PJRT_KeyValueGetCallback_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const char* key; size_t key_size; int timeout_in_ms; @@ -323,7 +346,7 @@ typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( struct PJRT_KeyValuePutCallback_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const char* key; size_t key_size; // Only needs to stay alive for the duration of the PJRT_KeyValuePutCallback @@ -344,7 +367,7 @@ typedef PJRT_Error* (*PJRT_KeyValuePutCallback)( struct PJRT_Client_Create_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // Extra platform-specific options to create a client. const PJRT_NamedValue* create_options; size_t num_options; @@ -367,7 +390,7 @@ typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); struct PJRT_Client_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Destroy_Args, client); @@ -377,7 +400,7 @@ typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args); struct PJRT_Client_PlatformName_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // `platform_name` has the same lifetime as `client`. It is owned by `client`. const char* platform_name; // out @@ -391,7 +414,7 @@ typedef PJRT_Error* PJRT_Client_PlatformName( struct PJRT_Client_ProcessIndex_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; int process_index; // out }; @@ -404,7 +427,7 @@ typedef PJRT_Error* PJRT_Client_ProcessIndex( struct PJRT_Client_PlatformVersion_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // `platform_version` has the same lifetime as `client`. It's owned by // `client`. @@ -421,7 +444,7 @@ typedef PJRT_Error* PJRT_Client_PlatformVersion( struct PJRT_Client_TopologyDescription_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // Is owned by and has the same lifetime as `client`. PJRT_TopologyDescription* topology; // out @@ -435,7 +458,7 @@ typedef PJRT_Error* PJRT_Client_TopologyDescription( struct PJRT_Client_Devices_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; PJRT_Device* const* devices; // out size_t num_devices; // out @@ -448,7 +471,7 @@ typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args); struct PJRT_Client_AddressableDevices_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; PJRT_Device* const* addressable_devices; // out size_t num_addressable_devices; // out @@ -464,7 +487,7 @@ typedef PJRT_Error* PJRT_Client_AddressableDevices( struct PJRT_Client_LookupDevice_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; int id; // `device` has the same lifetime as `client`. It is owned by `client`. @@ -479,7 +502,7 @@ typedef PJRT_Error* PJRT_Client_LookupDevice( struct PJRT_Client_LookupAddressableDevice_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; int local_hardware_id; // `addressable_device` has the same lifetime as `client`. It is owned by @@ -496,7 +519,7 @@ typedef PJRT_Error* PJRT_Client_LookupAddressableDevice( struct PJRT_Client_AddressableMemories_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; PJRT_Memory* const* addressable_memories; // out size_t num_addressable_memories; // out @@ -512,7 +535,7 @@ typedef PJRT_Error* PJRT_Client_AddressableMemories( struct PJRT_Program { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // Serialized code in the specified format below. // String is owned by the caller. char* code; // in/out depending on usage @@ -529,7 +552,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Program, format_size); struct PJRT_Client_Compile_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // Only needs to stay alive for the duration of the Compile call. // `program->format` and `program->format_size` are owned by the caller. @@ -549,7 +572,7 @@ typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args); struct PJRT_Client_DefaultDeviceAssignment_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; int num_replicas; int num_partitions; @@ -629,11 +652,13 @@ typedef enum { PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes, // The PjRtBuffer may alias `data` internally and the runtime may use the - // `data` contents as long as the buffer is alive. The caller promises to - // keep `data` alive and not to mutate its contents as long as the buffer is - // alive; to notify the caller that the buffer may be freed, the runtime - // will call `done_with_host_buffer` when the PjRtBuffer is freed. - PJRT_HostBufferSemantics_kZeroCopy, + // `data` contents as long as the buffer is alive. The runtime promises not + // to mutate contents of the buffer (i.e. it will not use it for aliased + // output buffers). The caller promises to keep `data` alive and not to mutate + // its contents as long as the buffer is alive; to notify the caller that the + // buffer may be freed, the runtime will call `done_with_host_buffer` when the + // PjRtBuffer is freed. + PJRT_HostBufferSemantics_kImmutableZeroCopy, } PJRT_HostBufferSemantics; typedef enum { @@ -643,7 +668,7 @@ typedef enum { struct PJRT_Buffer_MemoryLayout_Tiled { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // A map from physical dimension numbers to logical dimension numbers. // The first element is the most minor physical dimension (fastest varying // index) and the last the most major (slowest varying index). The contents of @@ -661,7 +686,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Tiled, num_tiles); struct PJRT_Buffer_MemoryLayout_Strides { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // Number of bytes to traverse per dimension. Must be the same size as // the number of dimensions of the data. Caution: `byte_strides` are allowed // to be negative, in which case data may need to point to the interior of @@ -676,7 +701,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Strides, num_byte_strides); // strides. struct PJRT_Buffer_MemoryLayout { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; union { PJRT_Buffer_MemoryLayout_Tiled tiled; PJRT_Buffer_MemoryLayout_Strides strides; @@ -687,7 +712,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout, type); struct PJRT_Client_BufferFromHostBuffer_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // Pointer to the host buffer const void* data; @@ -735,7 +760,7 @@ typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer( struct PJRT_Client_CreateViewOfDeviceBuffer_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; // A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned // view of this device buffer will be created. @@ -781,7 +806,7 @@ typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( struct PJRT_DeviceDescription_Id_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; int id; // out }; @@ -795,7 +820,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Id( struct PJRT_DeviceDescription_ProcessIndex_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; int process_index; // out }; @@ -812,7 +837,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_ProcessIndex( struct PJRT_DeviceDescription_Attributes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; size_t num_attributes; // out const PJRT_NamedValue* attributes; // out @@ -826,7 +851,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Attributes( struct PJRT_DeviceDescription_Kind_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; // `device_kind` string is owned by `device` and has same lifetime as // `device`. @@ -842,7 +867,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_Kind( struct PJRT_DeviceDescription_DebugString_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; const char* debug_string; // out size_t debug_string_size; // out @@ -857,7 +882,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_DebugString( struct PJRT_DeviceDescription_ToString_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_DeviceDescription* device_description; const char* to_string; // out size_t to_string_size; // out @@ -873,7 +898,7 @@ typedef PJRT_Error* PJRT_DeviceDescription_ToString( struct PJRT_Device_GetDescription_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; PJRT_DeviceDescription* device_description; // out }; @@ -885,7 +910,7 @@ typedef PJRT_Error* PJRT_Device_GetDescription( struct PJRT_Device_IsAddressable_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; bool is_addressable; // out }; @@ -897,7 +922,7 @@ typedef PJRT_Error* PJRT_Device_IsAddressable( struct PJRT_Device_LocalHardwareId_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; int local_hardware_id; // out }; @@ -910,13 +935,13 @@ typedef PJRT_Error* PJRT_Device_LocalHardwareId( struct PJRT_Device_AddressableMemories_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; // Has the lifetime of `device`. PJRT_Memory* const* memories; // out size_t num_memories; // out }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, memories); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, num_memories); // Returns the memories that a device can address. typedef PJRT_Error* PJRT_Device_AddressableMemories( @@ -924,7 +949,7 @@ typedef PJRT_Error* PJRT_Device_AddressableMemories( struct PJRT_Device_DefaultMemory_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; // `memory` has the same lifetime as `device`. PJRT_Memory* memory; // out @@ -938,7 +963,7 @@ typedef PJRT_Error* PJRT_Device_DefaultMemory( struct PJRT_Device_MemoryStats_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Device* device; // Number of bytes in use. @@ -989,7 +1014,7 @@ typedef PJRT_Error* PJRT_Device_MemoryStats(PJRT_Device_MemoryStats_Args* args); struct PJRT_Memory_Id_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Memory* memory; int id; // out }; @@ -1000,7 +1025,7 @@ typedef PJRT_Error* PJRT_Memory_Id(PJRT_Memory_Id_Args* args); struct PJRT_Memory_Kind_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Memory* memory; // `memory_kind` has same lifetime as `memory`. const char* memory_kind; // out @@ -1013,7 +1038,7 @@ typedef PJRT_Error* PJRT_Memory_Kind(PJRT_Memory_Kind_Args* args); struct PJRT_Memory_DebugString_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Memory* memory; const char* debug_string; // out size_t debug_string_size; // out @@ -1026,7 +1051,7 @@ typedef PJRT_Error* PJRT_Memory_DebugString(PJRT_Memory_DebugString_Args* args); struct PJRT_Memory_ToString_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Memory* memory; const char* to_string; // out size_t to_string_size; // out @@ -1038,7 +1063,7 @@ typedef PJRT_Error* PJRT_Memory_ToString(PJRT_Memory_ToString_Args* args); struct PJRT_Memory_AddressableByDevices_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Memory* memory; PJRT_Device* const* devices; // out size_t num_devices; // out @@ -1053,7 +1078,7 @@ typedef PJRT_Error* PJRT_Memory_AddressableByDevices( struct PJRT_Executable_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Destroy_Args, executable); @@ -1063,7 +1088,7 @@ typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args); struct PJRT_LoadedExecutable_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Destroy_Args, executable); @@ -1075,7 +1100,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Destroy( struct PJRT_LoadedExecutable_GetExecutable_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* loaded_executable; PJRT_Executable* executable; // out }; @@ -1088,7 +1113,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_GetExecutable( struct PJRT_Executable_Name_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; // `executable_name` has the same lifetime as `executable`. It is owned by // `executable`. @@ -1103,7 +1128,7 @@ typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args); // TODO(b/269178731): Revisit whether num_replicas is needed. struct PJRT_Executable_NumReplicas_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_replicas; // out }; @@ -1115,7 +1140,7 @@ typedef PJRT_Error* PJRT_Executable_NumReplicas( struct PJRT_Executable_NumPartitions_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_partitions; // out }; @@ -1127,7 +1152,7 @@ typedef PJRT_Error* PJRT_Executable_NumPartitions( struct PJRT_LoadedExecutable_AddressableDevices_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; PJRT_Device* const* addressable_devices; // out size_t num_addressable_devices; // out @@ -1141,7 +1166,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_AddressableDevices( struct PJRT_Executable_OptimizedProgram_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; PJRT_Program* program; // out, but read below }; @@ -1175,7 +1200,7 @@ typedef PJRT_Error* PJRT_Executable_OptimizedProgram( struct PJRT_LoadedExecutable_Delete_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Delete_Args, executable); @@ -1190,7 +1215,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Delete( struct PJRT_LoadedExecutable_IsDeleted_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; bool is_deleted; // out }; @@ -1247,7 +1272,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_RecvCallbackInfo, recv_callback); struct PJRT_ExecuteOptions { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; // Callbacks for when send/recv ops are executed. The outer lists correspond // to each device returned by `PJRT_Executable_AddressableDevices` for // `executable` (i.e. they will have length `num_devices`). Each inner list @@ -1275,11 +1300,11 @@ struct PJRT_ExecuteOptions { const int64_t* non_donatable_input_indices; size_t num_non_donatable_input_indices; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, num_non_donatable_input_indices); struct PJRT_LoadedExecutable_Execute_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; // Only needs to stay alive for the duration of the Execute call. PJRT_ExecuteOptions* options; @@ -1318,7 +1343,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Execute( struct PJRT_Executable_NumOutputs_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_outputs; // out }; @@ -1330,7 +1355,7 @@ typedef PJRT_Error* PJRT_Executable_NumOutputs( struct PJRT_Executable_SizeOfGeneratedCodeInBytes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; int64_t size_in_bytes; // out }; @@ -1342,7 +1367,7 @@ typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes( struct PJRT_Executable_Fingerprint_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; // Has the lifetime of `executable` const char* executable_fingerprint; // out @@ -1360,7 +1385,7 @@ typedef PJRT_Error* PJRT_Executable_Fingerprint( struct PJRT_Executable_GetCostAnalysis_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_properties; // out // `properties` and any embedded data are owned by and have the same lifetime @@ -1378,28 +1403,37 @@ typedef PJRT_Error* PJRT_Executable_GetCostAnalysis( struct PJRT_Executable_GetCompiledMemoryStats_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; // Mirrors xla::CompiledMemoryStats. + // Device default memory (e.g., HBM for GPU/TPU) usage stats. int64_t generated_code_size_in_bytes; // out int64_t argument_size_in_bytes; // out int64_t output_size_in_bytes; // out // How much argument is reused for output. int64_t alias_size_in_bytes; // out int64_t temp_size_in_bytes; // out + + // Host memory usage stats. + int64_t host_generated_code_size_in_bytes; // out + int64_t host_argument_size_in_bytes; // out + int64_t host_output_size_in_bytes; // out + int64_t host_alias_size_in_bytes; // out + int64_t host_temp_size_in_bytes; // out }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCompiledMemoryStats_Args, - temp_size_in_bytes); + host_temp_size_in_bytes); -// Return memory stats that allow callers to estimate device memory usage -// when running this executable. +// Return memory stats that allow callers to estimate memory usage when running +// this executable. The memory stats could contain usage info from different +// memory spaces, like default memory (e.g., HBM for GPU/TPU) and host memory. typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( PJRT_Executable_GetCompiledMemoryStats_Args* args); struct PJRT_Executable_OutputElementTypes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; PJRT_Buffer_Type* output_types; // out size_t num_output_types; // out @@ -1413,7 +1447,7 @@ typedef PJRT_Error* PJRT_Executable_OutputElementTypes( struct PJRT_Executable_OutputDimensions_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_outputs; // Has length: sum of all elements in the list `dim_sizes`. @@ -1431,7 +1465,7 @@ typedef PJRT_Error* PJRT_Executable_OutputDimensions( struct PJRT_Executable_OutputMemoryKinds_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Executable* executable; size_t num_outputs; // Has length `num_outputs`. @@ -1450,7 +1484,7 @@ typedef struct PJRT_SerializedExecutable PJRT_SerializedExecutable; struct PJRT_Executable_Serialize_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const PJRT_Executable* executable; // Lives only as long as serialized_executable @@ -1473,7 +1507,7 @@ typedef PJRT_Error* PJRT_Executable_Serialize( struct PJRT_Executable_DeserializeAndLoad_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Client* client; const char* serialized_executable; size_t serialized_executable_size; @@ -1490,7 +1524,7 @@ typedef PJRT_Error* PJRT_Executable_DeserializeAndLoad( struct PJRT_LoadedExecutable_Fingerprint_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_LoadedExecutable* executable; // Has the lifetime of `executable` const char* executable_fingerprint; // out @@ -1510,7 +1544,7 @@ typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint( struct PJRT_Buffer_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Destroy_Args, buffer); @@ -1521,7 +1555,7 @@ typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args); struct PJRT_Buffer_ElementType_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; PJRT_Buffer_Type type; // out }; @@ -1532,7 +1566,7 @@ typedef PJRT_Error* PJRT_Buffer_ElementType(PJRT_Buffer_ElementType_Args* args); struct PJRT_Buffer_Dimensions_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dims`. const int64_t* dims; // out @@ -1545,7 +1579,7 @@ typedef PJRT_Error* PJRT_Buffer_Dimensions(PJRT_Buffer_Dimensions_Args* args); struct PJRT_Buffer_UnpaddedDimensions_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dims`. const int64_t* unpadded_dims; // out @@ -1565,7 +1599,7 @@ typedef PJRT_Error* PJRT_Buffer_UnpaddedDimensions( struct PJRT_Buffer_DynamicDimensionIndices_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; // Has the lifetime of `buffer` and length `num_dynamic_dims`. const size_t* dynamic_dim_indices; // out @@ -1583,7 +1617,7 @@ typedef PJRT_Error* PJRT_Buffer_DynamicDimensionIndices( struct PJRT_Buffer_GetMemoryLayout_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; // Layout data is owned by and has the lifetime of `buffer`. PJRT_Buffer_MemoryLayout layout; // out @@ -1596,7 +1630,7 @@ typedef PJRT_Error* PJRT_Buffer_GetMemoryLayout( struct PJRT_Buffer_ToHostBuffer_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* src; // The caller can specify an optional host layout. If nullptr, the layout of @@ -1622,7 +1656,7 @@ typedef PJRT_Error* PJRT_Buffer_ToHostBuffer( struct PJRT_Buffer_OnDeviceSizeInBytes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; size_t on_device_size_in_bytes; // out }; @@ -1635,7 +1669,7 @@ typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes( struct PJRT_Buffer_Delete_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Delete_Args, buffer); @@ -1649,7 +1683,7 @@ typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args); struct PJRT_Buffer_IsDeleted_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; bool is_deleted; // out }; @@ -1660,7 +1694,7 @@ typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args); struct PJRT_Buffer_CopyToDevice_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; PJRT_Device* dst_device; PJRT_Buffer* dst_buffer; // out @@ -1675,7 +1709,7 @@ typedef PJRT_Error* PJRT_Buffer_CopyToDevice( struct PJRT_Buffer_CopyToMemory_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; PJRT_Memory* dst_memory; PJRT_Buffer* dst_buffer; // out @@ -1690,7 +1724,7 @@ typedef PJRT_Error* PJRT_Buffer_CopyToMemory( struct PJRT_Buffer_IsOnCpu_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; bool is_on_cpu; // out }; @@ -1701,7 +1735,7 @@ typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); struct PJRT_Buffer_Device_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; PJRT_Device* device; // out }; @@ -1712,7 +1746,7 @@ typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args); struct PJRT_Buffer_Memory_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; PJRT_Memory* memory; // out }; @@ -1723,7 +1757,7 @@ typedef PJRT_Error* PJRT_Buffer_Memory(PJRT_Buffer_Memory_Args* args); struct PJRT_Buffer_ReadyEvent_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; // The caller is responsible for calling PJRT_Event_Destroy on `event`. PJRT_Event* event; // out @@ -1743,7 +1777,7 @@ typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args); struct PJRT_Buffer_UnsafePointer_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; uintptr_t buffer_pointer; // out }; @@ -1756,7 +1790,7 @@ typedef PJRT_Error* PJRT_Buffer_UnsafePointer( struct PJRT_Buffer_IncreaseExternalReferenceCount_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IncreaseExternalReferenceCount_Args, @@ -1772,7 +1806,7 @@ typedef PJRT_Error* PJRT_Buffer_IncreaseExternalReferenceCount( struct PJRT_Buffer_DecreaseExternalReferenceCount_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DecreaseExternalReferenceCount_Args, @@ -1786,7 +1820,7 @@ typedef PJRT_Error* PJRT_Buffer_DecreaseExternalReferenceCount( struct PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_Buffer* buffer; void* device_memory_ptr; // out }; @@ -1803,7 +1837,7 @@ typedef PJRT_Error* PJRT_Buffer_OpaqueDeviceMemoryDataPointer( struct PJRT_CopyToDeviceStream_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_CopyToDeviceStream* stream; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_Destroy_Args, stream); @@ -1814,7 +1848,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_Destroy( struct PJRT_CopyToDeviceStream_AddChunk_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_CopyToDeviceStream* stream; // Takes ownership of `chunk` (i.e. implementation will call chunk.deleter). PJRT_Chunk* chunk; @@ -1835,7 +1869,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_AddChunk( struct PJRT_CopyToDeviceStream_TotalBytes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_CopyToDeviceStream* stream; int64_t total_bytes; // out }; @@ -1847,7 +1881,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_TotalBytes( struct PJRT_CopyToDeviceStream_GranuleSize_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_CopyToDeviceStream* stream; int64_t granule_size_in_bytes; // out }; @@ -1861,7 +1895,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_GranuleSize( struct PJRT_CopyToDeviceStream_CurrentBytes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_CopyToDeviceStream* stream; int64_t current_bytes; // out }; @@ -1877,7 +1911,7 @@ typedef PJRT_Error* PJRT_CopyToDeviceStream_CurrentBytes( struct PJRT_TopologyDescription_Create_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const char* topology_name; size_t topology_name_size; // Extra platform-specific options to create a client. @@ -1894,7 +1928,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Create( struct PJRT_TopologyDescription_Destroy_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Destroy_Args, topology); @@ -1905,7 +1939,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Destroy( struct PJRT_TopologyDescription_PlatformVersion_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; // `platform_version` has the same lifetime as `topology`. It's owned by // `topology`. @@ -1922,7 +1956,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformVersion( struct PJRT_TopologyDescription_PlatformName_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; // `platform_name` has the same lifetime as `topology`. It is owned by // `topology`. @@ -1938,7 +1972,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_PlatformName( struct PJRT_TopologyDescription_GetDeviceDescriptions_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; // Has the same lifetime as topology. PJRT_DeviceDescription* const* descriptions; // out @@ -1957,7 +1991,7 @@ typedef struct PJRT_SerializedTopology PJRT_SerializedTopology; struct PJRT_TopologyDescription_Serialize_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; // Lives only as long as serialized_topology. @@ -1979,7 +2013,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Serialize( struct PJRT_TopologyDescription_Attributes_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; PJRT_TopologyDescription* topology; // Only lives as long as topology. @@ -1995,7 +2029,7 @@ typedef PJRT_Error* PJRT_TopologyDescription_Attributes( struct PJRT_Compile_Args { size_t struct_size; - void* priv; + PJRT_Extension_Base* extension_start; const PJRT_TopologyDescription* topology; // Only needs to stay alive for the duration of the Compile call. // `program->format` and `program->format_size` are owned by the caller. @@ -2016,21 +2050,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Compile_Args, executable); // PJRT_Client before execution. typedef PJRT_Error* PJRT_Compile(PJRT_Compile_Args* args); -// -------------------------------- Extension ---------------------------------- - -typedef enum { - PJRT_Structure_Type_Gpu_Custom_Call = 0, - PJRT_Structure_Type_Profiler, -} PJRT_Structure_Type; - -// PJRT_Structure_Base contains a type and a pointer to next -// PJRT_Structure_Base. The framework can go through this chain to find -// structure and identify it with the type. -typedef struct PJRT_Structure_Base { - PJRT_Structure_Type type; - const struct PJRT_Structure_Base* next; -} PJRT_Structure_Base; - // -------------------------------- API access --------------------------------- #define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type @@ -2038,7 +2057,7 @@ typedef struct PJRT_Structure_Base { // Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed. typedef struct { size_t struct_size; - void* extension_start; + PJRT_Extension_Base* extension_start; PJRT_Api_Version pjrt_api_version; diff --git a/xla/pjrt/c/pjrt_c_api_cpu.cc b/xla/pjrt/c/pjrt_c_api_cpu.cc index 2b60ad794a599..90e391f8fb122 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu.cc +++ b/xla/pjrt/c/pjrt_c_api_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_cpu.h b/xla/pjrt/c/pjrt_c_api_cpu.h index d2756ef396a86..f2599d0516c6d 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu.h +++ b/xla/pjrt/c/pjrt_c_api_cpu.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index b60296cee5555..cd5d22dae6257 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,12 +48,14 @@ PJRT_Error* PJRT_CpuDeviceTopology_Create( "Topology not supported for CPU compilation.")}; } -constexpr PJRT_Api pjrt_api = - pjrt::CreatePjrtApi(pjrt::cpu_plugin::PJRT_Client_Create, - pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, - pjrt::PJRT_Plugin_Initialize_NoOp); +const PJRT_Api* GetCpuPjrtApi() { + static const PJRT_Api pjrt_api = + pjrt::CreatePjrtApi(pjrt::cpu_plugin::PJRT_Client_Create, + pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, + pjrt::PJRT_Plugin_Initialize_NoOp); -const PJRT_Api* GetCpuPjrtApi() { return &pjrt_api; } + return &pjrt_api; +} } // namespace cpu_plugin } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_cpu_internal.h b/xla/pjrt/c/pjrt_c_api_cpu_internal.h index 794124c5ec8d8..5db2b3785c034 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu_internal.h +++ b/xla/pjrt/c/pjrt_c_api_cpu_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_cpu_test.cc b/xla/pjrt/c/pjrt_c_api_cpu_test.cc index 30d94a9479107..ff32e0cb0e212 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu_test.cc +++ b/xla/pjrt/c/pjrt_c_api_cpu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_cpu_version_script.lds b/xla/pjrt/c/pjrt_c_api_cpu_version_script.lds new file mode 100644 index 0000000000000..46cc4278883d1 --- /dev/null +++ b/xla/pjrt/c/pjrt_c_api_cpu_version_script.lds @@ -0,0 +1,12 @@ +# The symbols listed in the "global" section of this file--and only those-- +# are discoverable by programs or frameworks that `dlopen()` libtpu. +# The linker will expose those symbols and no others to frameworks. +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + }; + + local: + *; +}; diff --git a/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h b/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h new file mode 100644 index 0000000000000..825734610b863 --- /dev/null +++ b/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h @@ -0,0 +1,134 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ +#define XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ + +#include +#include + +#include "xla/pjrt/c/pjrt_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define PJRT_API_CUSTOM_PARTITIONER_EXTENSION_VERSION 0 + +struct JAX_CustomCallPartitioner_string { + const char* data; + size_t size; +}; + +struct JAX_CustomCallPartitioner_aval { + JAX_CustomCallPartitioner_string shape; + bool has_sharding; + JAX_CustomCallPartitioner_string sharding; +}; + +// General callback information containing api versions, the result error +// message and the cleanup function to free any temporary memory that is backing +// the results. Arguments are always owned by the caller, and results are owned +// by the cleanup_fn. These should never be used directly. Args and results +// should be serialized via the PopulateArgs, ReadArgs, PopulateResults, +// ConsumeResults functions defined below. +struct JAX_CustomCallPartitioner_version_and_error { + int64_t api_version; + void* data; // out + // cleanup_fn cleans up any returned results. The caller must finish with all + // uses by the point the cleanup is called. + void (*cleanup_fn)(void* data); // out + bool has_error; + PJRT_Error_Code code; // out + JAX_CustomCallPartitioner_string error_msg; // out +}; + +struct JAX_CustomCallPartitioner_Partition_Args { + JAX_CustomCallPartitioner_version_and_error header; + + size_t num_args; + JAX_CustomCallPartitioner_aval* op_args; + JAX_CustomCallPartitioner_aval op_result; + JAX_CustomCallPartitioner_string backend_config; + + // out + JAX_CustomCallPartitioner_string mlir_module; + JAX_CustomCallPartitioner_string* args_sharding; + JAX_CustomCallPartitioner_string result_sharding; +}; + +struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args { + JAX_CustomCallPartitioner_version_and_error header; + + size_t num_args; + JAX_CustomCallPartitioner_aval* op_args; + JAX_CustomCallPartitioner_string result_shape; + JAX_CustomCallPartitioner_string backend_config; + + bool has_result_sharding; + JAX_CustomCallPartitioner_string result_sharding; +}; + +struct JAX_CustomCallPartitioner_PropagateUserSharding_Args { + JAX_CustomCallPartitioner_version_and_error header; + + JAX_CustomCallPartitioner_string backend_config; + + JAX_CustomCallPartitioner_string result_shape; + + JAX_CustomCallPartitioner_string result_sharding; // inout +}; + +struct JAX_CustomCallPartitioner_Callbacks { + int64_t version; + void* private_data; + void (*dtor)(JAX_CustomCallPartitioner_Callbacks* data); + void (*partition)(JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_Partition_Args* args); + void (*infer_sharding)( + JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); + void (*propagate_user_sharding)( + JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); + bool can_side_effecting_have_replicated_sharding; +}; + +struct PJRT_Register_Custom_Partitioner_Args { + size_t struct_size; + const char* name; // lifetime of the call. + size_t name_size; + JAX_CustomCallPartitioner_Callbacks* callbacks; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Register_Custom_Partitioner_Args, callbacks); + +// Registers a custom partitioner. +typedef PJRT_Error* PJRT_Register_Custom_Partitioner( + PJRT_Register_Custom_Partitioner_Args* args); + +typedef struct PJRT_Custom_Partitioner_Extension { + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; + PJRT_Register_Custom_Partitioner* register_custom_partitioner; +} PJRT_Custom_Partitioner_Extension; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Custom_Partitioner_Extension, + register_custom_partitioner); + +#ifdef __cplusplus +} +#endif + +#endif // XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ diff --git a/xla/pjrt/c/pjrt_c_api_gpu.cc b/xla/pjrt/c/pjrt_c_api_gpu.cc index 2af7ffe2304ac..fe836332be666 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,17 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_gpu.h" +#include "absl/base/call_once.h" +#include "absl/log/initialize.h" +#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_gpu_internal.h" +#include "tsl/platform/platform.h" -const PJRT_Api* GetPjrtApi() { return pjrt::gpu_plugin::GetGpuPjrtApi(); } +const PJRT_Api* GetPjrtApi() { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + static absl::once_flag once; + absl::call_once(once, []() { absl::InitializeLog(); }); +#endif // PLATFORM_GOOGLE + return pjrt::gpu_plugin::GetGpuPjrtApi(); +} diff --git a/xla/pjrt/c/pjrt_c_api_gpu.h b/xla/pjrt/c/pjrt_c_api_gpu.h index a06f59cbcd90d..0ef092ffd9bd3 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu.h +++ b/xla/pjrt/c/pjrt_c_api_gpu.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/xla/pjrt/c/pjrt_c_api_gpu_extension.h index a94d0808b9da8..3ecdaeafb3274 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_extension.h +++ b/xla/pjrt/c/pjrt_c_api_gpu_extension.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,12 +24,13 @@ limitations under the License. extern "C" { #endif -#define PJRT_API_GPU_EXTENSION_VERSION 0 +#define PJRT_API_GPU_EXTENSION_VERSION 1 struct PJRT_Gpu_Register_Custom_Call_Args { size_t struct_size; const char* function_name; size_t function_name_size; + int api_version; // 0 for an untyped call, 1 -- for typed void* custom_call_function; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Register_Custom_Call_Args, @@ -40,10 +41,12 @@ typedef PJRT_Error* PJRT_Gpu_Register_Custom_Call( PJRT_Gpu_Register_Custom_Call_Args* args); typedef struct PJRT_Gpu_Custom_Call { - PJRT_Structure_Type type; - const void* next; + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; PJRT_Gpu_Register_Custom_Call* custom_call; } PJRT_Gpu_Custom_Call; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Custom_Call, custom_call); #ifdef __cplusplus } diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 5edae823c341e..297c678655298 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,12 @@ limitations under the License. #include "xla/backends/profiler/plugin/plugin_tracer_impl.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/backends/profiler/plugin/profiler_error.h" +#include "xla/client/local_client.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" @@ -36,7 +41,13 @@ limitations under the License. #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/service/compiler.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "tsl/platform/errors.h" namespace pjrt { @@ -53,15 +64,18 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { pjrt::ConvertFromPjRtNamedValueList(args->create_options, args->num_options); const auto kExpectedOptionNameAndTypes = - absl::flat_hash_map( - {{"platform_name", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, - {"allocator", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, - {"memory_fraction", PJRT_NamedValue_Type::PJRT_NamedValue_kFloat}, - {"preallocate", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, - {"visible_devices", - PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List}, - {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, - {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}}); + absl::flat_hash_map({ + {"platform_name", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, + {"allocator", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, + {"memory_fraction", PJRT_NamedValue_Type::PJRT_NamedValue_kFloat}, + {"preallocate", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, + {"collective_memory_size", + PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, + {"visible_devices", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List}, + {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, + {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, + {"enable_mock_nccl", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, + }); PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); @@ -96,6 +110,10 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { it != create_options.end()) { allocator_config.preallocate = std::get(it->second); } + if (auto it = create_options.find("collective_memory_size"); + it != create_options.end()) { + allocator_config.collective_memory_size = std::get(it->second); + } std::optional> visible_devices; if (auto it = create_options.find("visible_devices"); it != create_options.end()) { @@ -110,6 +128,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { if (auto it = create_options.find("num_nodes"); it != create_options.end()) { num_nodes = std::get(it->second); } + bool enable_mock_nccl = false; + if (auto it = create_options.find("enable_mock_nccl"); + it != create_options.end()) { + enable_mock_nccl = std::get(it->second); + } xla::GpuClientOptions options; options.allocator_config = allocator_config; @@ -117,10 +140,10 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_get = pjrt::ToCppKeyValueGetCallback(args->kv_get_callback, - args->kv_get_user_arg); - options.kv_put = pjrt::ToCppKeyValuePutCallback(args->kv_put_callback, - args->kv_put_user_arg); + options.kv_store = + pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, + args->kv_put_callback, args->kv_put_user_arg); + options.enable_mock_nccl = enable_mock_nccl; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetStreamExecutorGpuClient(options)); args->client = pjrt::CreateWrapperClient(std::move(client)); @@ -129,8 +152,32 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { PJRT_Error* PJRT_GpuDeviceTopology_Create( PJRT_TopologyDescription_Create_Args* args) { - return new PJRT_Error{tsl::errors::Unimplemented( - "Topology not supported for GPU compilation.")}; + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_TopologyDescription_Create_Args", + PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); + + PJRT_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, + xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, + /*allowed_devices=*/std::nullopt)); + stream_executor::StreamExecutor* executor = + xla_client->backend().default_stream_executor(); + const stream_executor::DeviceDescription& description = + executor->GetDeviceDescription(); + std::vector device_ids; + device_ids.reserve(xla_client->backend().stream_executors().size()); + for (stream_executor::StreamExecutor* executor : + xla_client->backend().stream_executors()) { + device_ids.push_back(executor->device_ordinal()); + } + auto gpu_target_config = xla::Compiler::TargetConfig(executor); + auto pjrt_topology = + std::make_unique( + xla::CudaId(), xla::CudaName(), description.name(), device_ids, + absl::flat_hash_map{ + {"target_config", + gpu_target_config.ToProto().SerializeAsString()}}); + args->topology = CreateWrapperDeviceTopology(std::move(pjrt_topology)); + return nullptr; } PLUGIN_Profiler_Api profiler_api{ @@ -147,34 +194,72 @@ PLUGIN_Profiler_Api profiler_api{ }; PJRT_Profiler_Extension profiler_extension{ - /*type=*/PJRT_Structure_Type::PJRT_Structure_Type_Profiler, + /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, /*next=*/nullptr, /*profiler_api=*/&profiler_api, }; +PJRT_Error* PJRT_Register_Custom_Partitioner( + PJRT_Register_Custom_Partitioner_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Register_Custom_Partitioner_Args", + PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE, args->struct_size)); + std::string name(args->name, args->name_size); + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(args->callbacks)); + return nullptr; +} + +PJRT_Custom_Partitioner_Extension custom_partitioner{ + /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner, + /*next=*/reinterpret_cast(&profiler_extension), + /*register_custom_partitioner=*/PJRT_Register_Custom_Partitioner, +}; + PJRT_Error* PJRT_Gpu_Register_Custom_Call( PJRT_Gpu_Register_Custom_Call_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Gpu_Register_Custom_Call_Args", PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE, args->struct_size)); std::string function_name(args->function_name, args->function_name_size); - xla::CustomCallTargetRegistry::Global()->Register( - function_name, args->custom_call_function, PJRT_GPU_PLUGIN_PLATFORM_NAME); - return nullptr; + switch (args->api_version) { + case 0: + xla::CustomCallTargetRegistry::Global()->Register( + function_name, args->custom_call_function, + PJRT_GPU_PLUGIN_PLATFORM_NAME); + return nullptr; + case 1: + xla::ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), function_name, + PJRT_GPU_PLUGIN_PLATFORM_NAME, + reinterpret_cast(args->custom_call_function)); + return nullptr; + default: + return new PJRT_Error{absl::UnimplementedError( + absl::StrFormat("API version %d not supported for PJRT GPU plugin. " + "Supported versions are 0 and 1.", + args->api_version))}; + } } -PJRT_Gpu_Custom_Call custom_call{ - /*type=*/PJRT_Structure_Type::PJRT_Structure_Type_Gpu_Custom_Call, - /*next=*/&profiler_extension, - /*custom_call=*/PJRT_Gpu_Register_Custom_Call, -}; - -constexpr PJRT_Api pjrt_api = pjrt::CreatePjrtApi( - pjrt::gpu_plugin::PJRT_Client_Create, - pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create, - pjrt::PJRT_Plugin_Initialize_NoOp, static_cast(&custom_call)); +const PJRT_Api* GetGpuPjrtApi() { + static PJRT_Gpu_Custom_Call custom_call{ + /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call, + /*next=*/reinterpret_cast(&custom_partitioner), + /*custom_call=*/PJRT_Gpu_Register_Custom_Call, + }; + static const PJRT_Api pjrt_api = + pjrt::CreatePjrtApi(pjrt::gpu_plugin::PJRT_Client_Create, + pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create, + pjrt::PJRT_Plugin_Initialize_NoOp, + reinterpret_cast(&custom_call), + pjrt::PJRT_Plugin_Attributes_Xla); -const PJRT_Api* GetGpuPjrtApi() { return &pjrt_api; } + return &pjrt_api; +} } // namespace gpu_plugin } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.h b/xla/pjrt/c/pjrt_c_api_gpu_internal.h index 4cc9a4b1d8c5e..04c26e5b21220 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.h +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/xla/pjrt/c/pjrt_c_api_gpu_test.cc index ba1a144b0d946..f5583b3878dd1 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -43,6 +45,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_test.h" #include "xla/pjrt/c/pjrt_c_api_test_base.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_future.h" @@ -81,7 +84,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args device_buffer_ptr_args; device_buffer_ptr_args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - device_buffer_ptr_args.priv = nullptr; + device_buffer_ptr_args.extension_start = nullptr; device_buffer_ptr_args.buffer = buffer.get(); PJRT_Error* device_buffer_ptr_error = api_->PJRT_Buffer_OpaqueDeviceMemoryDataPointer(&device_buffer_ptr_args); @@ -89,7 +92,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { // Looks up a device. PJRT_Buffer_Device_Args device_args = PJRT_Buffer_Device_Args{ /*struct_size=*/PJRT_Buffer_Device_Args_STRUCT_SIZE, - /*priv=*/nullptr, + /*extension_start=*/nullptr, /*buffer=*/buffer.get(), }; PJRT_Error* device_error = api_->PJRT_Buffer_Device(&device_args); @@ -99,7 +102,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { PJRT_Client_CreateViewOfDeviceBuffer_Args create_view_args; create_view_args.struct_size = PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE; - create_view_args.priv = nullptr; + create_view_args.extension_start = nullptr; create_view_args.client = client_; create_view_args.device_buffer_ptr = device_buffer_ptr_args.device_memory_ptr; xla::Shape shape = xla::ShapeUtil::MakeShape(xla::S32, {4}); @@ -133,7 +136,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { // Transfers view_buffer to host to verify. PJRT_Buffer_ToHostBuffer_Args to_host_args; to_host_args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - to_host_args.priv = nullptr; + to_host_args.extension_start = nullptr; to_host_args.src = view_buffer.get(); xla::Shape host_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); auto literal = std::make_shared(host_shape); @@ -155,43 +158,12 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { xla::LiteralUtil::CreateR1(float_data), *literal)); } -std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( - absl::flat_hash_map* kv_store, absl::Mutex& mu) { - xla::PjRtClient::KeyValueGetCallback kv_get = - [kv_store, &mu](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - absl::Duration wait_interval = absl::Milliseconds(10); - int num_retry = timeout / wait_interval; - for (int i = 0; i < num_retry; i++) { - { - absl::MutexLock lock(&mu); - auto iter = kv_store->find(k); - if (iter != kv_store->end()) { - return iter->second; - } - } - absl::SleepFor(wait_interval); - } - return absl::NotFoundError( - absl::StrCat(k, " is not found in the kv store.")); - }; - xla::PjRtClient::KeyValuePutCallback kv_put = - [kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { - { - absl::MutexLock lock(&mu); - kv_store->insert(std::pair(k, v)); - } - return tsl::OkStatus(); - }; - return ::pjrt::ConvertToCKeyValueCallbacks(kv_get, kv_put); -} - absl::StatusOr BuildCreateArg( ::pjrt::PJRT_KeyValueCallbackData* kv_callback_data, std::vector& c_options) { PJRT_Client_Create_Args args; args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.create_options = c_options.data(); args.num_options = c_options.size(); args.kv_get_callback = kv_callback_data->c_kv_get; @@ -204,11 +176,9 @@ absl::StatusOr BuildCreateArg( TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { auto api = GetPjrtApi(); - auto kv_store_ptr = - std::make_shared>(); - absl::Mutex mu; + auto kv_store = std::make_shared(); std::shared_ptr<::pjrt::PJRT_KeyValueCallbackData> kv_callback_data = - CreateTestCKVCallback(kv_store_ptr.get(), mu); + ::pjrt::ConvertToCKeyValueCallbacks(kv_store); int num_nodes = 2; std::vector threads; @@ -216,13 +186,12 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { for (int i = 0; i < num_nodes; i++) { threads.emplace_back([api, i, num_nodes, kv_callback_data = kv_callback_data, - kv_store_ptr = kv_store_ptr] { + kv_store = kv_store] { absl::flat_hash_map options = { {"num_nodes", static_cast(num_nodes)}, {"node_id", static_cast(i)}}; TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList( - options, /*api_minor_version=*/30)); + ::pjrt::ConvertToPjRtNamedValueList(options)); TF_ASSERT_OK_AND_ASSIGN( PJRT_Client_Create_Args create_arg, BuildCreateArg(kv_callback_data.get(), c_options)); @@ -231,7 +200,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_Devices_Args device_args; device_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - device_args.priv = nullptr; + device_args.extension_start = nullptr; device_args.client = create_arg.client; PJRT_Error* device_error = api->PJRT_Client_Devices(&device_args); @@ -241,7 +210,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_AddressableDevices_Args addressable_device_args; addressable_device_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - addressable_device_args.priv = nullptr; + addressable_device_args.extension_start = nullptr; addressable_device_args.client = create_arg.client; PJRT_Error* addressable_device_error = @@ -251,7 +220,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -278,12 +247,11 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) { if (allocator_option == "cuda_async") { options["preallocate"] = true; } - TF_ASSERT_OK_AND_ASSIGN( - std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -292,7 +260,7 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -307,12 +275,11 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) { {"memory_fraction", 0.5f}, {"preallocate", true}, }; - TF_ASSERT_OK_AND_ASSIGN( - std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -327,7 +294,7 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) { PJRT_Error_Destroy_Args error_destroy_args; error_destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - error_destroy_args.priv = nullptr; + error_destroy_args.extension_start = nullptr; error_destroy_args.error = error; api->PJRT_Error_Destroy(&error_destroy_args); @@ -342,12 +309,11 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { {"allocator", static_cast("default")}, {"visible_devices", xla::PjRtValueType(std::vector{0, 1})}, }; - TF_ASSERT_OK_AND_ASSIGN( - std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -356,7 +322,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { PJRT_Client_PlatformName_Args platform_name_args; platform_name_args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - platform_name_args.priv = nullptr; + platform_name_args.extension_start = nullptr; platform_name_args.client = create_arg.client; PJRT_Error* platform_name_error = @@ -370,7 +336,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = create_arg.client; PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); @@ -384,12 +350,11 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { {"allocator", static_cast("default")}, {"visible_devices", xla::PjRtValueType(std::vector{0, 1})}, }; - TF_ASSERT_OK_AND_ASSIGN( - std::vector c_options, - ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); PJRT_Client_Create_Args create_arg; create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_arg.priv = nullptr; + create_arg.extension_start = nullptr; create_arg.client = nullptr; create_arg.create_options = c_options.data(); create_arg.num_options = c_options.size(); @@ -404,27 +369,28 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { PJRT_Error_Destroy_Args error_destroy_args; error_destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - error_destroy_args.priv = nullptr; + error_destroy_args.extension_start = nullptr; error_destroy_args.error = error; api->PJRT_Error_Destroy(&error_destroy_args); } -void TestCustomCall() {} +void TestCustomCallV2() {} -TEST(PjrtCApiGpuPrivTest, CustomCall) { +TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - std::string function_name = "function_name"; + std::string function_name = "untyped_function_name"; args.function_name = function_name.c_str(); args.function_name_size = function_name.size(); - args.custom_call_function = reinterpret_cast(&TestCustomCall); + args.api_version = 0; + args.custom_call_function = reinterpret_cast(&TestCustomCallV2); auto api = GetPjrtApi(); - const PJRT_Structure_Base* next = - reinterpret_cast(api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(api->extension_start); while (next != nullptr && next->type != - PJRT_Structure_Type::PJRT_Structure_Type_Gpu_Custom_Call) { + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { next = next->next; } ASSERT_NE(next, nullptr); @@ -435,7 +401,37 @@ TEST(PjrtCApiGpuPrivTest, CustomCall) { CHECK_EQ(error, nullptr); void* custom_call = xla::CustomCallTargetRegistry::Global()->Lookup(function_name, "CUDA"); - EXPECT_EQ(custom_call, reinterpret_cast(&TestCustomCall)); + EXPECT_EQ(custom_call, reinterpret_cast(&TestCustomCallV2)); +} + +static void* kNoop = xla::ffi::Ffi::Bind() + .To([]() { return xla::ffi::Error::Success(); }) + .release(); + +TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + std::string function_name = "typed_function_name"; + args.function_name = function_name.c_str(); + args.function_name_size = function_name.size(); + args.api_version = 1; + args.custom_call_function = kNoop; + auto api = GetPjrtApi(); + const PJRT_Extension_Base* next = + reinterpret_cast(api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + ASSERT_NE(next, nullptr); + + PJRT_Error* error = + reinterpret_cast(next)->custom_call(&args); + + CHECK_EQ(error, nullptr); + auto registration = xla::ffi::FindHandler(function_name, "CUDA").value(); + EXPECT_EQ(reinterpret_cast(registration.handler), kNoop); } } // namespace diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 5e960c3cd60d3..aade90c1b534e 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -28,12 +29,15 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" @@ -41,13 +45,14 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/context_types.h" namespace pjrt { @@ -59,7 +64,7 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) { return [api](PJRT_Client* client) -> void { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = client; PJRT_Error* error = api->PJRT_Client_Destroy(&destroy_args); @@ -72,7 +77,7 @@ PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) { return [api](PJRT_Error* error) -> void { PJRT_Error_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.error = error; api->PJRT_Error_Destroy(&destroy_args); @@ -83,7 +88,7 @@ PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api) { return [api](PJRT_Buffer* buffer) -> void { PJRT_Buffer_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Buffer_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.buffer = buffer; pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Destroy(&destroy_args), api); @@ -94,7 +99,7 @@ PJRT_ExecutableDeleter MakeExecutableDeleter(const PJRT_Api* api) { return [api](PJRT_Executable* executable) -> void { PJRT_Executable_Destroy_Args args; args.struct_size = PJRT_Executable_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError(api->PJRT_Executable_Destroy(&args), api); }; @@ -104,17 +109,17 @@ PJRT_LoadedExecutableDeleter MakeLoadedExecutableDeleter(const PJRT_Api* api) { return [api](PJRT_LoadedExecutable* executable) -> void { PJRT_LoadedExecutable_Destroy_Args args; args.struct_size = PJRT_LoadedExecutable_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError(api->PJRT_LoadedExecutable_Destroy(&args), api); }; } -xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api) { - xla::Status status; +absl::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api) { + absl::Status status; if (error != nullptr) { - status = xla::Status(PjrtErrorToStatusCode(error, api), - GetPjrtErrorMessage(error, api)); + status = absl::Status(PjrtErrorToStatusCode(error, api), + GetPjrtErrorMessage(error, api)); } return status; } @@ -125,7 +130,7 @@ PJRT_TopologyDescriptionDeleter MakeTopologyDescriptionDeleter( PJRT_TopologyDescription_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.topology = topology; pjrt::LogFatalIfPjrtError( @@ -136,7 +141,7 @@ PJRT_TopologyDescriptionDeleter MakeTopologyDescriptionDeleter( PJRT_Error_Code GetErrorCode(const PJRT_Error* error, const PJRT_Api* api) { PJRT_Error_GetCode_Args args; args.struct_size = PJRT_Error_GetCode_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.error = error; pjrt::LogFatalIfPjrtError(api->PJRT_Error_GetCode(&args), api); return args.code; @@ -144,7 +149,10 @@ PJRT_Error_Code GetErrorCode(const PJRT_Error* error, const PJRT_Api* api) { absl::StatusCode PjrtErrorToStatusCode(const PJRT_Error* error, const PJRT_Api* api) { - PJRT_Error_Code code = GetErrorCode(error, api); + return PjrtErrorCodeToStatusCode(GetErrorCode(error, api)); +} + +absl::StatusCode PjrtErrorCodeToStatusCode(PJRT_Error_Code code) { switch (code) { case PJRT_Error_Code_CANCELLED: case PJRT_Error_Code_UNKNOWN: @@ -203,7 +211,7 @@ absl::string_view GetPjrtErrorMessage(const PJRT_Error* error, const PJRT_Api* api) { PJRT_Error_Message_Args message_args; message_args.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE; - message_args.priv = nullptr; + message_args.extension_start = nullptr; message_args.error = error; api->PJRT_Error_Message(&message_args); return absl::string_view(message_args.message, message_args.message_size); @@ -212,7 +220,7 @@ absl::string_view GetPjrtErrorMessage(const PJRT_Error* error, void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api) { std::unique_ptr _error( error, MakeErrorDeleter(api)); - xla::Status _status = PjrtErrorToStatus(_error.get(), api); + absl::Status _status = PjrtErrorToStatus(_error.get(), api); if (!_status.ok()) { LOG(FATAL) << "Unexpected error status " << _status.message(); } @@ -223,7 +231,7 @@ PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api) { return [api](PJRT_Event* managed) { PJRT_Event_Destroy_Args args; args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.event = managed; LogFatalIfPjrtError(api->PJRT_Event_Destroy(&args), api); @@ -341,8 +349,8 @@ const char* HostBufferSemanticsToString( switch (h) { case xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall: return "xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall"; - case xla::PjRtClient::HostBufferSemantics::kZeroCopy: - return "xla::PjRtClient::HostBufferSemantics::kZeroCopy"; + case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy: + return "xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy"; case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes: return "xla::PjRtClient::HostBufferSemantics::" "kImmutableUntilTransferCompletes"; @@ -358,8 +366,9 @@ PJRT_HostBufferSemantics ConvertToPjRtHostBufferSemantics( case xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes: return PJRT_HostBufferSemantics:: PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes; - case xla::PjRtClient::HostBufferSemantics::kZeroCopy: - return PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy; + case xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy: + return PJRT_HostBufferSemantics:: + PJRT_HostBufferSemantics_kImmutableZeroCopy; default: CHECK(false) << "Input host buffer semantics is not supported in C API layer: " @@ -377,28 +386,28 @@ xla::PjRtClient::HostBufferSemantics ConvertFromPjRtHostBufferSemantics( PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes: return xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes; - case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kZeroCopy: - return xla::PjRtClient::HostBufferSemantics::kZeroCopy; + case PJRT_HostBufferSemantics::PJRT_HostBufferSemantics_kImmutableZeroCopy: + return xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy; } } -xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, - const PJRT_Api* c_api) { - using xla::Status, xla::PjRtFuture; +xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, + const PJRT_Api* c_api) { + using absl::Status, xla::PjRtFuture; PJRT_Event_OnReady_Args event_onready_args; event_onready_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_onready_args.priv = nullptr; + event_onready_args.extension_start = nullptr; event_onready_args.event = c_event; PjRtFuture::Promise promise = PjRtFuture::CreatePromise(); event_onready_args.user_arg = new std::function( [promise, c_event, c_api](PJRT_Error* error) mutable { if (error != nullptr) { - xla::Status s = ::pjrt::PjrtErrorToStatus(error, c_api); + absl::Status s = ::pjrt::PjrtErrorToStatus(error, c_api); promise.Set(s); ::pjrt::MakeErrorDeleter(c_api)(error); } else { - promise.Set(tsl::OkStatus()); + promise.Set(absl::OkStatus()); } ::pjrt::MakeEventDeleter(c_api)(c_event); }); @@ -411,18 +420,17 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, PJRT_Error* error = c_api->PJRT_Event_OnReady(&event_onready_args); if (error != nullptr) { - xla::Status s = ::pjrt::PjrtErrorToStatus(error, c_api); + absl::Status s = ::pjrt::PjrtErrorToStatus(error, c_api); return PjRtFuture(s); } return PjRtFuture(std::move(promise)); } -static xla::StatusOr ConvertToPjRtNamedValue( - const std::string& name, const xla::PjRtValueType& value, - int api_minor_version) { +static absl::StatusOr ConvertToPjRtNamedValue( + const std::string& name, const xla::PjRtValueType& value) { PJRT_NamedValue c_value; c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; - c_value.priv = nullptr; + c_value.extension_start = nullptr; c_value.name = name.c_str(); c_value.name_size = name.size(); @@ -446,15 +454,6 @@ static xla::StatusOr ConvertToPjRtNamedValue( c_value.float_value = std::get(value); c_value.value_size = 1; } else if (std::holds_alternative(value)) { - // TODO: b/300294893 - Remove this after 12 weeks (12/06/2023) as that is - // how long we support old behavior for - if (api_minor_version < 30) { - return absl::InvalidArgumentError(absl::StrCat( - "Client cannot provide this option for API versions " - "less than 0.30. The framework PJRT API version is ", - PJRT_API_MAJOR, ".", PJRT_API_MINOR, - "and the plugin minor version is ", api_minor_version, ".")); - } c_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kBool; c_value.bool_value = std::get(value); c_value.value_size = 1; @@ -466,15 +465,13 @@ static xla::StatusOr ConvertToPjRtNamedValue( return c_value; } -xla::StatusOr> ConvertToPjRtNamedValueList( - const absl::flat_hash_map& cpp_value_map, - int api_minor_version) { +absl::StatusOr> ConvertToPjRtNamedValueList( + const absl::flat_hash_map& cpp_value_map) { std::vector c_value_list; c_value_list.reserve(cpp_value_map.size()); for (const auto& [name, value] : cpp_value_map) { - TF_ASSIGN_OR_RETURN( - PJRT_NamedValue c_value, - ConvertToPjRtNamedValue(name, value, api_minor_version)); + TF_ASSIGN_OR_RETURN(PJRT_NamedValue c_value, + ConvertToPjRtNamedValue(name, value)); c_value_list.push_back(c_value); } return c_value_list; @@ -522,7 +519,7 @@ ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list, return cpp_value_map; } -static xla::StatusOr GetPjrtNamedValueType( +static absl::StatusOr GetPjrtNamedValueType( xla::PjRtValueType cpp_value) { if (std::holds_alternative(cpp_value)) { return PJRT_NamedValue_Type::PJRT_NamedValue_kString; @@ -543,7 +540,7 @@ static xla::StatusOr GetPjrtNamedValueType( cpp_value.index()); } -xla::Status ValidateCreateOptions( +absl::Status ValidateCreateOptions( const absl::flat_hash_map& value_map, const absl::flat_hash_map& expected_name_and_types) { @@ -562,7 +559,23 @@ xla::Status ValidateCreateOptions( it->second); } } - return tsl::OkStatus(); + return absl::OkStatus(); +} + +const std::vector& GetXlaPluginCAttributes() { + constexpr absl::string_view kXlaVersion = "xla_version"; + PJRT_NamedValue c_value; + c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; + c_value.extension_start = nullptr; + c_value.name = kXlaVersion.data(); + c_value.name_size = kXlaVersion.size(); + c_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64; + // TODO(b/327203806): figure out where to keep the xla_version. + c_value.int64_value = 1; + c_value.value_size = 1; + static const std::vector* c_values = + new std::vector({c_value}); + return *c_values; } static std::string StructSizeErrorMsg(absl::string_view struct_name, @@ -578,9 +591,9 @@ static std::string StructSizeErrorMsg(absl::string_view struct_name, return error_msg; } -xla::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, - size_t expected_size, - size_t actual_size) { +absl::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, + size_t expected_size, + size_t actual_size) { if (actual_size < expected_size) { return tsl::errors::InvalidArgument( StructSizeErrorMsg(struct_name, expected_size, actual_size)); @@ -588,13 +601,13 @@ xla::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, if (actual_size > expected_size) { VLOG(2) << StructSizeErrorMsg(struct_name, expected_size, actual_size); } - return tsl::OkStatus(); + return absl::OkStatus(); } absl::string_view GetPlatformVersion(PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_PlatformVersion_Args args; args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = client; LogFatalIfPjrtError(api->PJRT_Client_PlatformVersion(&args), api); @@ -607,18 +620,18 @@ absl::string_view GetPlatformName(PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_PlatformName_Args args; args.client = client; args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(api->PJRT_Client_PlatformName(&args), api); absl::string_view platform_name(args.platform_name, args.platform_name_size); return platform_name; } -xla::StatusOr GetTopologyDescription( +absl::StatusOr GetTopologyDescription( PJRT_Client* client, const PJRT_Api* api) { PJRT_Client_TopologyDescription_Args args; args.struct_size = PJRT_Client_TopologyDescription_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = client; RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Client_TopologyDescription(&args), api); return args.topology; @@ -657,7 +670,7 @@ PJRT_DeviceDescription* GetDeviceDescription(const PJRT_Api* api, PJRT_Device* device) { PJRT_Device_GetDescription_Args args; args.struct_size = PJRT_Device_GetDescription_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device; pjrt::LogFatalIfPjrtError(api->PJRT_Device_GetDescription(&args), api); return args.device_description; @@ -667,7 +680,7 @@ absl::Span GetAddressableMemories(const PJRT_Api* api, PJRT_Device* device) { PJRT_Device_AddressableMemories_Args args; args.struct_size = PJRT_Device_AddressableMemories_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device; pjrt::LogFatalIfPjrtError(api->PJRT_Device_AddressableMemories(&args), api); return absl::MakeSpan(args.memories, args.num_memories); @@ -683,11 +696,11 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc) { static void PjRtValueDeleterCallback(char* value) { delete[] value; } static PJRT_KeyValueGetCFunc ToKVGetCFunc( - const xla::PjRtClient::KeyValueGetCallback& cpp_kv_get) { - return [&cpp_kv_get](PJRT_KeyValueGetCallback_Args* args) -> PJRT_Error* { - xla::StatusOr output = - cpp_kv_get(std::string(args->key, args->key_size), - absl::Milliseconds(args->timeout_in_ms)); + xla::KeyValueStoreInterface* kv_store) { + return [kv_store](PJRT_KeyValueGetCallback_Args* args) -> PJRT_Error* { + absl::StatusOr output = + kv_store->Get(std::string_view(args->key, args->key_size), + absl::Milliseconds(args->timeout_in_ms)); if (!output.ok()) { absl::string_view message = output.status().message(); return (*args->callback_error)( @@ -703,10 +716,11 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( } static PJRT_KeyValuePutCFunc ToKVPutCFunc( - const xla::PjRtClient::KeyValuePutCallback& cpp_kv_put) { - return [&cpp_kv_put](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { - xla::Status status = cpp_kv_put(std::string(args->key, args->key_size), - std::string(args->value, args->value_size)); + xla::KeyValueStoreInterface* kv_store) { + return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { + absl::Status status = + kv_store->Set(std::string_view(args->key, args->key_size), + std::string_view(args->value, args->value_size)); if (!status.ok()) { absl::string_view message = status.message(); return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), @@ -722,7 +736,7 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( PJRT_KeyValueGetCFunc* kv_get_c_func = reinterpret_cast(args->user_arg); if (kv_get_c_func == nullptr) { - xla::Status status = xla::InvalidArgument( + absl::Status status = xla::InvalidArgument( "got nullptr for PJRT_KeyValueGet_Args.user_arg"); return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), status.message().data(), @@ -738,7 +752,7 @@ static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func = reinterpret_cast(args->user_arg); if (kv_put_c_func == nullptr) { - xla::Status status = xla::InvalidArgument( + absl::Status status = xla::InvalidArgument( "got nullptr for PJRT_KeyValuePut_Args.user_arg"); return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), status.message().data(), @@ -749,17 +763,15 @@ static PJRT_KeyValuePutCallback ToCKVPutCallback( } std::unique_ptr ConvertToCKeyValueCallbacks( - xla::PjRtClient::KeyValueGetCallback kv_get, - xla::PjRtClient::KeyValuePutCallback kv_put) { + std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); - kv_callback_data->kv_get = std::move(kv_get); - kv_callback_data->kv_put = std::move(kv_put); - kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_callback_data->kv_get); - kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_callback_data->kv_put); + kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); + kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); + kv_callback_data->kv_store = std::move(kv_store); return kv_callback_data; } @@ -806,7 +818,7 @@ PJRT_RecvCallbackInfo CppRecvCallbackToCRecvCallback( }}; } -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( const xla::Layout& cpp_layout) { BufferMemoryLayoutData layout_data; layout_data.c_layout.type = @@ -833,7 +845,7 @@ xla::StatusOr ConvertToBufferMemoryLayoutData( return layout_data; } -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( absl::Span byte_strides) { BufferMemoryLayoutData layout_data; layout_data.c_layout.type = @@ -843,7 +855,7 @@ xla::StatusOr ConvertToBufferMemoryLayoutData( return layout_data; } -xla::StatusOr ConvertToLayout( +absl::StatusOr ConvertToLayout( const PJRT_Buffer_MemoryLayout_Tiled& c_tiled) { absl::Span minor_to_major(c_tiled.minor_to_major, c_tiled.minor_to_major_size); @@ -863,7 +875,7 @@ xla::StatusOr ConvertToLayout( PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_ElementType_Args args; args.struct_size = PJRT_Buffer_ElementType_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_ElementType(&args), api); return args.type; @@ -873,7 +885,7 @@ absl::Span GetDimensions(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_Dimensions_Args args; args.struct_size = PJRT_Buffer_Dimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_Dimensions(&args), api); return {args.dims, args.num_dims}; @@ -883,16 +895,15 @@ PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api, PJRT_Buffer* buffer) { PJRT_Buffer_GetMemoryLayout_Args args; args.struct_size = PJRT_Buffer_GetMemoryLayout_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer; LogFatalIfPjrtError(api->PJRT_Buffer_GetMemoryLayout(&args), api); return args.layout; } -xla::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, - const int64_t* dims, - size_t num_dims, - PJRT_Buffer_MemoryLayout* layout) { +absl::StatusOr BuildXlaShapeFromC( + PJRT_Buffer_Type element_type, const int64_t* dims, size_t num_dims, + PJRT_Buffer_MemoryLayout* layout) { xla::Shape shape = xla::ShapeUtil::MakeShape(ConvertFromPjRtBufferType(element_type), absl::Span(dims, num_dims)); @@ -924,7 +935,7 @@ absl::string_view PlatformName(const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc) { PJRT_TopologyDescription_PlatformName_Args args; args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = const_cast(topo_desc); LogFatalIfPjrtError(api->PJRT_TopologyDescription_PlatformName(&args), api); return {args.platform_name, args.platform_name_size}; @@ -935,7 +946,7 @@ absl::Span DeviceDescriptions( PJRT_TopologyDescription_GetDeviceDescriptions_Args args; args.struct_size = PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = const_cast(topo_desc); LogFatalIfPjrtError( api->PJRT_TopologyDescription_GetDeviceDescriptions(&args), api); @@ -944,9 +955,16 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable) { + // TODO(jieying): To be removed after 03/2024. + if (api->pjrt_api_version.major_version == 0 && + api->pjrt_api_version.minor_version < 40) { + return absl::UnimplementedError( + "GetCompiledMemoryStats requires a plugin with PJRT C API version >= " + "0.40"); + } PJRT_Executable_GetCompiledMemoryStats_Args args; args.struct_size = PJRT_Executable_GetCompiledMemoryStats_Args_STRUCT_SIZE; - args.priv = 0; + args.extension_start = nullptr; args.executable = executable; RETURN_STATUS_IF_PJRT_ERROR( api->PJRT_Executable_GetCompiledMemoryStats(&args), api); @@ -956,7 +974,28 @@ absl::StatusOr GetCompiledMemoryStats( results.output_size_in_bytes = args.output_size_in_bytes; results.alias_size_in_bytes = args.alias_size_in_bytes; results.temp_size_in_bytes = args.temp_size_in_bytes; + results.host_generated_code_size_in_bytes = + args.host_generated_code_size_in_bytes; + results.host_argument_size_in_bytes = args.host_argument_size_in_bytes; + results.host_output_size_in_bytes = args.host_output_size_in_bytes; + results.host_alias_size_in_bytes = args.host_alias_size_in_bytes; + results.host_temp_size_in_bytes = args.host_temp_size_in_bytes; return results; } +PJRT_Profiler_Extension CreatePjrtProfilerExtension( + absl::string_view traceme_name) { + tsl::profiler::TraceMeProducer producer( + traceme_name, tsl::profiler::ContextType::kPjrtLibraryCall); + int64_t traceme_context_id = producer.GetContextId(); + PJRT_Profiler_Extension profiler_extension{ + /*struct_size=*/PJRT_Profiler_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Profiler, + /*next=*/nullptr, + /*profiler_api=*/nullptr, + /*traceme_context_id=*/traceme_context_id, + }; + return profiler_extension; +} + } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index f727a00709713..bfd05baef4757 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,10 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" #include "xla/status.h" @@ -111,6 +114,7 @@ xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api); absl::StatusCode PjrtErrorToStatusCode(const PJRT_Error* error, const PJRT_Api* api); +absl::StatusCode PjrtErrorCodeToStatusCode(PJRT_Error_Code code); PJRT_Error_Code StatusCodeToPjrtErrorCode(absl::StatusCode code); // Conversion helper from xla::PrimitiveType to PJRT_Buffer_Type. @@ -137,9 +141,8 @@ xla::PjRtFuture ConvertCEventToCppFuture(PJRT_Event* c_event, // The data of returned variable-length PJRT_NamedValue list is backed by // `cpp_value_map`, so `cpp_value_map` must outlive the returned list. It will // raise errors for unsupported PjRtValueType. -xla::StatusOr> ConvertToPjRtNamedValueList( - const absl::flat_hash_map& cpp_value_map, - int api_minor_version); +absl::StatusOr> ConvertToPjRtNamedValueList( + const absl::flat_hash_map& cpp_value_map); absl::flat_hash_map ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list, @@ -153,6 +156,10 @@ xla::Status ValidateCreateOptions( const absl::flat_hash_map& expected_name_and_types); +// Returns attributes for plugin that uses XLA compiler. The attributes have the +// lifetime of the process. +const std::vector& GetXlaPluginCAttributes(); + // Helper function for checking the actual C API argument struct size is greater // than or equal to the expected size. The actual struct size can be larger if // it comes from a forwards-compatible caller built at a later version than this @@ -164,7 +171,7 @@ xla::Status ActualStructSizeIsGreaterOrEqual(absl::string_view struct_name, absl::string_view GetPlatformVersion(PJRT_Client* client, const PJRT_Api* api); absl::string_view GetPlatformName(PJRT_Client* client, const PJRT_Api* api); -xla::StatusOr GetTopologyDescription( +absl::StatusOr GetTopologyDescription( PJRT_Client* client, const PJRT_Api* api); // Releases `chunk`. @@ -193,9 +200,9 @@ struct PJRT_KeyValueCallbackData { PJRT_KeyValueCallbackData() = default; PJRT_KeyValueCallbackData(const PJRT_KeyValueCallbackData&) = delete; - xla::PjRtClient::KeyValueGetCallback kv_get; - xla::PjRtClient::KeyValuePutCallback kv_put; - // kv_get_c_func and kv_put_c_func are holding pointers to kv_get and kv_put. + std::shared_ptr kv_store; + + // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and @@ -210,8 +217,7 @@ struct PJRT_KeyValueCallbackData { // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. std::unique_ptr ConvertToCKeyValueCallbacks( - xla::PjRtClient::KeyValueGetCallback kv_get, - xla::PjRtClient::KeyValuePutCallback kv_put); + std::shared_ptr kv_store); // std::function version of PJRT_SendCallback using PJRT_SendCallbackFunction = @@ -245,12 +251,12 @@ struct BufferMemoryLayoutData { std::vector tile_dims; std::vector tile_dim_sizes; }; -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( const xla::Layout& cpp_layout); -xla::StatusOr ConvertToBufferMemoryLayoutData( +absl::StatusOr ConvertToBufferMemoryLayoutData( absl::Span byte_strides); -xla::StatusOr ConvertToLayout( +absl::StatusOr ConvertToLayout( const PJRT_Buffer_MemoryLayout_Tiled& c_tiled); PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer); @@ -259,10 +265,10 @@ absl::Span GetDimensions(const PJRT_Api* api, PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api, PJRT_Buffer* buffer); -xla::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, - const int64_t* dims, - size_t num_dims, - PJRT_Buffer_MemoryLayout* layout); +absl::StatusOr BuildXlaShapeFromC(PJRT_Buffer_Type element_type, + const int64_t* dims, + size_t num_dims, + PJRT_Buffer_MemoryLayout* layout); absl::string_view PlatformName(const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc); @@ -272,6 +278,46 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable); +// Creates a PJRT_Profiler_Extension and adds a producer trace with +// the given name. The created PJRT_Profiler_Extension will be used in argument +// structs to pass the producer traceme context id to add a corresponding +// consumer trace in the API implementation. +PJRT_Profiler_Extension CreatePjrtProfilerExtension( + absl::string_view traceme_name); + +// Traverses an extension chain to find an extension struct with type +// `type`. `in` can either be a PJRT_Api* or a pointer to an Args struct -- +// anything with an `extension_start` field. The ExtType template parameter +// specifies the C extension type of the returned struct, if found (i.e. a +// specific extension struct that is layout-compatible with +// PJRT_Extension_Base). +template +ExtType* FindExtension(InputType* in, PJRT_Extension_Type type) { + PJRT_Extension_Base* ext = in->extension_start; + while (ext != nullptr) { + if (ext->type == type) { + return reinterpret_cast(ext); + } + ext = ext->next; + } + // 'type' wasn't found in extension chain + return nullptr; +} + +// Gets a traceme context id attached to PJRT_Profiler_Extension. +// Returns -1 if there is no PJRT_Profiler_Extension in args. +template +int64_t GetTracemeContextId(InputType* args) { + PJRT_Profiler_Extension* profiler_extension = + FindExtension( + args, PJRT_Extension_Type::PJRT_Extension_Type_Profiler); + int64_t traceme_context_id = -1; + if (profiler_extension != nullptr) { + traceme_context_id = profiler_extension->traceme_context_id; + } + return traceme_context_id; +} + } // namespace pjrt #endif // XLA_PJRT_C_PJRT_C_API_HELPERS_H_ diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index d36f3cf9cd3d1..bb4e386e5a4b5 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include #include -#include #include #include @@ -26,12 +26,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/status.h" @@ -55,9 +54,8 @@ TEST(PjRtCApiHelperTest, ConvertValidPjRtValueType) { {"int64_list", int64_list}, {"float", static_cast(1.0)}}; - TF_ASSERT_OK_AND_ASSIGN( - std::vector c_map, - ConvertToPjRtNamedValueList(original_cpp_map, /*api_minor_version=*/30)); + TF_ASSERT_OK_AND_ASSIGN(std::vector c_map, + ConvertToPjRtNamedValueList(original_cpp_map)); auto converted_back_cpp_map = ConvertFromPjRtNamedValueList(c_map.data(), c_map.size()); @@ -65,30 +63,6 @@ TEST(PjRtCApiHelperTest, ConvertValidPjRtValueType) { testing::UnorderedElementsAreArray(original_cpp_map)); } -TEST(PjRtCApiHelperTest, ConvertValidPjRtValueTypeWithUnsupportedApiVersion) { - std::vector int64_list = {static_cast(1), - static_cast(2)}; - absl::flat_hash_map original_cpp_map = { - {"string", static_cast("v1")}, - {"int64", static_cast(1)}, - {"int64_list", int64_list}, - {"float", static_cast(1.0)}, - {"float", static_cast(1.0)}, - {"bool", static_cast(true)}}; - int api_minor_version = 29; - absl::StatusOr> c_map = - ConvertToPjRtNamedValueList(original_cpp_map, api_minor_version); - EXPECT_THAT( - c_map.status(), - ::tsl::testing::StatusIs( - absl::StatusCode::kInvalidArgument, - absl::StrCat( - "Client cannot provide this option for API versions less " - "than 0.30. The framework PJRT API version is ", - PJRT_API_MAJOR, ".", PJRT_API_MINOR, - "and the plugin minor version is ", api_minor_version, "."))); -} - TEST(PjRtCApiHelperTest, ValidOptionNameAndPjRtValueTypeIndex) { const auto expected = absl::flat_hash_map({ {"string", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, @@ -111,7 +85,7 @@ TEST(PjRtCApiHelperTest, InvalidOptionName) { auto status = ValidateCreateOptions(invalid_map, expected); - EXPECT_NE(status, tsl::OkStatus()); + EXPECT_NE(status, absl::OkStatus()); EXPECT_THAT(status.message(), HasSubstr("Unexpected option name passed to PJRT_Client_Create")); } @@ -126,58 +100,31 @@ TEST(PjRtCApiHelperTest, InvalidOptionTypeIndex) { auto status = ValidateCreateOptions(invalid_map, expected); - EXPECT_NE(status, tsl::OkStatus()); + EXPECT_NE(status, absl::OkStatus()); EXPECT_THAT(status.message(), HasSubstr("Option passed to PJRT_Client_Create with name string " "has type index 2 but expected type index is 0")); } TEST(PjRtCApiHelperTest, Callback) { - absl::flat_hash_map kv_store; - absl::Mutex mu; - xla::PjRtClient::KeyValueGetCallback kv_get = - [&kv_store, &mu](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - absl::Duration wait_interval = absl::Milliseconds(10); - int num_retry = timeout / wait_interval; - for (int i = 0; i < num_retry; i++) { - { - absl::MutexLock lock(&mu); - auto iter = kv_store.find(k); - if (iter != kv_store.end()) { - return iter->second; - } - } - absl::SleepFor(wait_interval); - } - return absl::NotFoundError( - absl::StrCat(k, " is not found in the kv store.")); - }; - xla::PjRtClient::KeyValuePutCallback kv_put = - [&kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { - { - absl::MutexLock lock(&mu); - kv_store[k] = v; - } - return tsl::OkStatus(); - }; - auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_get, kv_put); - auto converted_back_kv_get = ToCppKeyValueGetCallback( - kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func); - auto converted_back_kv_put = ToCppKeyValuePutCallback( + auto kv_store = std::make_shared(); + + auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); + auto converted_kv_store = ToCppKeyValueStore( + kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); - auto s = converted_back_kv_put("key", "value"); + auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); - auto v = converted_back_kv_get("key", absl::Seconds(1)); + auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { std::vector strides = {4, 8}; - xla::StatusOr layout_data = + absl::StatusOr layout_data = ConvertToBufferMemoryLayoutData(strides); EXPECT_TRUE(layout_data.ok()); diff --git a/xla/pjrt/c/pjrt_c_api_macros.h b/xla/pjrt/c/pjrt_c_api_macros.h index a6627276c2c15..1ee7d93214499 100644 --- a/xla/pjrt/c/pjrt_c_api_macros.h +++ b/xla/pjrt/c/pjrt_c_api_macros.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h index 6f620b2e9faad..c821916add71a 100644 --- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h +++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ #define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ +#include +#include + #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -26,10 +29,15 @@ extern "C" { #define PJRT_API_PROFILER_EXTENSION_VERSION 0 typedef struct PJRT_Profiler_Extension { - PJRT_Structure_Type type; - const void* next; + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; + // can be nullptr if PJRT_Profiler_Extension is used as an args extension PLUGIN_Profiler_Api* profiler_api; + // valid only when used as an args extension + int64_t traceme_context_id; } PJRT_Profiler_Extension; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Profiler_Extension, profiler_api); #ifdef __cplusplus } diff --git a/xla/pjrt/c/pjrt_c_api_test.cc b/xla/pjrt/c/pjrt_c_api_test.cc index 68cbc5d65d96a..573a43c7aa686 100644 --- a/xla/pjrt/c/pjrt_c_api_test.cc +++ b/xla/pjrt/c/pjrt_c_api_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -149,7 +149,7 @@ TEST_F(PjrtCApiTest, PlatformName) { PJRT_Client_PlatformName_Args args; args.client = client_; args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; PJRT_Error* error = api_->PJRT_Client_PlatformName(&args); ASSERT_EQ(error, nullptr); absl::string_view platform_name(args.platform_name, args.platform_name_size); @@ -160,7 +160,7 @@ TEST_F(PjrtCApiTest, ClientProcessIndex) { PJRT_Client_ProcessIndex_Args process_index_args = PJRT_Client_ProcessIndex_Args{ .struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .process_index = -1, }; @@ -199,7 +199,7 @@ TEST_F(PjrtCApiTest, LookupDevice) { PJRT_Client_LookupDevice_Args lookup_device_args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = 0, .device = nullptr, @@ -217,7 +217,7 @@ TEST_F(PjrtCApiTest, LookupAddressableDevice) { PJRT_Client_LookupAddressableDevice_Args lookup_addressable_device_args = PJRT_Client_LookupAddressableDevice_Args{ .struct_size = PJRT_Client_LookupAddressableDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .local_hardware_id = 0, .addressable_device = nullptr, @@ -239,7 +239,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentNominal) { std::vector assignment_buffer(kNumReplicas * kNumPartitions); PJRT_Client_DefaultDeviceAssignment_Args args{ .struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .num_replicas = kNumReplicas, .num_partitions = kNumPartitions, @@ -257,7 +257,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentBufferTooSmall) { std::vector assignment_buffer(kBufferSize); PJRT_Client_DefaultDeviceAssignment_Args args{ .struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .num_replicas = kNumReplicas, .num_partitions = kNumPartitions, @@ -276,7 +276,7 @@ TEST_F(PjrtCApiTest, GetDefaultDeviceAssignmentBufferTooSmall) { TEST_F(PjrtCApiTest, LookupDeviceNegativeId) { PJRT_Client_LookupDevice_Args args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = -1, .device = nullptr, @@ -296,7 +296,7 @@ TEST_F(PjrtCApiTest, LookupDeviceOutOfRangeId) { int out_of_range_id = GetNumDevices(); PJRT_Client_LookupDevice_Args args = PJRT_Client_LookupDevice_Args{ .struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, .id = out_of_range_id, .device = nullptr, @@ -318,7 +318,7 @@ void destroy_executable(PJRT_LoadedExecutable* executable, const PJRT_Api* api) { PJRT_LoadedExecutable_Destroy_Args args{ .struct_size = PJRT_LoadedExecutable_Destroy_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .executable = executable, }; PJRT_Error* error = api->PJRT_LoadedExecutable_Destroy(&args); @@ -345,7 +345,7 @@ TEST_F(PjrtCApiTest, BufferTransferImmutableUntilTransferCompletes) { PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event.get(); PJRT_Error* event_error = api_->PJRT_Event_Await(&await_args); ASSERT_EQ(event_error, nullptr); @@ -354,7 +354,7 @@ TEST_F(PjrtCApiTest, BufferTransferImmutableUntilTransferCompletes) { TEST_F(PjrtCApiTest, Compile) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; std::string options_str = BuildSingleDeviceCompileOptionStr(); @@ -365,7 +365,7 @@ TEST_F(PjrtCApiTest, Compile) { std::string program_code{module_add_one}; PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = program_code.data(), .code_size = program_code.length(), .format = format.c_str(), @@ -383,7 +383,7 @@ TEST_F(PjrtCApiTest, Compile) { TEST_F(PjrtCApiTest, CompileXlaComputation) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; xla::DeviceAssignment device_assignment(1, 1); @@ -403,7 +403,7 @@ TEST_F(PjrtCApiTest, CompileXlaComputation) { std::string format(::pjrt::kHloFormat); PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = module_str.data(), .code_size = module_str.size(), .format = format.c_str(), @@ -421,7 +421,7 @@ TEST_F(PjrtCApiTest, CompileXlaComputation) { TEST_F(PjrtCApiTest, CompileInvalidOption) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; std::string options_str = "invalid compile options"; @@ -432,7 +432,7 @@ TEST_F(PjrtCApiTest, CompileInvalidOption) { std::string program_code{module_add_one}; PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = program_code.data(), .code_size = program_code.length(), .format = format.c_str(), @@ -453,7 +453,7 @@ TEST_F(PjrtCApiTest, CompileInvalidOption) { TEST_F(PjrtCApiTest, CompileInvalidProgramFormat) { PJRT_Client_Compile_Args args = PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .client = client_, }; xla::DeviceAssignment device_assignment(1, 1); @@ -468,7 +468,7 @@ TEST_F(PjrtCApiTest, CompileInvalidProgramFormat) { std::string format("invalid"); PJRT_Program program = PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .code = nullptr, .code_size = 0, .format = format.c_str(), @@ -498,7 +498,7 @@ TEST_F(PjrtCApiTest, DeviceProcessIndex) { PJRT_DeviceDescription_ProcessIndex_Args args = PJRT_DeviceDescription_ProcessIndex_Args{ .struct_size = PJRT_DeviceDescription_ProcessIndex_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device_description = ::pjrt::GetDeviceDescription(api_, GetClientDevices()[0]), .process_index = -1, @@ -512,7 +512,7 @@ TEST_F(PjrtCApiTest, DeviceProcessIndex) { TEST_F(PjrtCApiTest, DeviceIsAddressable) { PJRT_Device_IsAddressable_Args args = PJRT_Device_IsAddressable_Args{ .struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = GetClientDevices()[0], .is_addressable = false, }; @@ -525,7 +525,7 @@ TEST_F(PjrtCApiTest, DeviceIsAddressable) { TEST_F(PjrtCApiTest, DeviceLocalHardwareId) { PJRT_Device_LocalHardwareId_Args args = PJRT_Device_LocalHardwareId_Args{ .struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = GetClientDevices()[0], .local_hardware_id = -1, }; @@ -565,7 +565,7 @@ class PjrtCApiBufferTest : public PjrtCApiTest { TEST_F(PjrtCApiBufferTest, IsDeleted) { PJRT_Buffer_IsDeleted_Args is_deleted_args; is_deleted_args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE; - is_deleted_args.priv = nullptr; + is_deleted_args.extension_start = nullptr; is_deleted_args.buffer = buffer_.get(); PJRT_Error* is_deleted_error = api_->PJRT_Buffer_IsDeleted(&is_deleted_args); ASSERT_EQ(is_deleted_error, nullptr); @@ -573,7 +573,7 @@ TEST_F(PjrtCApiBufferTest, IsDeleted) { PJRT_Buffer_Delete_Args delete_args; delete_args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE; - delete_args.priv = nullptr; + delete_args.extension_start = nullptr; delete_args.buffer = buffer_.get(); PJRT_Error* delete_error = api_->PJRT_Buffer_Delete(&delete_args); ASSERT_EQ(delete_error, nullptr); @@ -586,7 +586,7 @@ TEST_F(PjrtCApiBufferTest, IsDeleted) { TEST_F(PjrtCApiBufferTest, GetOnDeviceSizeInBytes) { PJRT_Buffer_OnDeviceSizeInBytes_Args args; args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); PJRT_Error* on_device_size_bytes_error = api_->PJRT_Buffer_OnDeviceSizeInBytes(&args); @@ -598,7 +598,7 @@ TEST_F(PjrtCApiBufferTest, GetOnDeviceSizeInBytes) { TEST_F(PjrtCApiBufferTest, ReadyEvent) { PJRT_Buffer_ReadyEvent_Args get_event_args; get_event_args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - get_event_args.priv = nullptr; + get_event_args.extension_start = nullptr; get_event_args.buffer = buffer_.get(); auto error = ToUniquePtr(api_->PJRT_Buffer_ReadyEvent(&get_event_args)); ASSERT_EQ(error, nullptr); @@ -609,7 +609,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Wait for `buffer_`'s data transfer to complete (if it hasn't already) PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event; error.reset(api_->PJRT_Event_Await(&await_args)); ASSERT_EQ(error, nullptr); @@ -617,7 +617,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Must be ready when `PJRT_Event_Await` completes PJRT_Event_IsReady_Args ready_args; ready_args.struct_size = PJRT_Event_IsReady_Args_STRUCT_SIZE; - ready_args.priv = nullptr; + ready_args.extension_start = nullptr; ready_args.event = event; error.reset(api_->PJRT_Event_IsReady(&ready_args)); ASSERT_EQ(error, nullptr); @@ -626,7 +626,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { // Clean up PJRT_Event_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.event = event; error.reset(api_->PJRT_Event_Destroy(&destroy_args)); EXPECT_EQ(error, nullptr); @@ -635,7 +635,7 @@ TEST_F(PjrtCApiBufferTest, ReadyEvent) { TEST_F(PjrtCApiBufferTest, ToHostBufferNoHostLayout) { PJRT_Buffer_ToHostBuffer_Args args; args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.src = buffer_.get(); xla::Shape host_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); auto literal = std::make_shared(host_shape); @@ -661,7 +661,7 @@ TEST_F(PjrtCApiBufferTest, IncreaseAndDecreaseReferenceCount) { PJRT_Buffer_IncreaseExternalReferenceCount_Args increase_reference_count_args; increase_reference_count_args.struct_size = PJRT_Buffer_IncreaseExternalReferenceCount_Args_STRUCT_SIZE; - increase_reference_count_args.priv = nullptr; + increase_reference_count_args.extension_start = nullptr; increase_reference_count_args.buffer = buffer_.get(); PJRT_Error* increase_reference_count_error = api_->PJRT_Buffer_IncreaseExternalReferenceCount( @@ -671,7 +671,7 @@ TEST_F(PjrtCApiBufferTest, IncreaseAndDecreaseReferenceCount) { PJRT_Buffer_DecreaseExternalReferenceCount_Args decrease_reference_count_args; decrease_reference_count_args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - decrease_reference_count_args.priv = nullptr; + decrease_reference_count_args.extension_start = nullptr; decrease_reference_count_args.buffer = buffer_.get(); PJRT_Error* decrease_reference_error = api_->PJRT_Buffer_DecreaseExternalReferenceCount( @@ -683,7 +683,7 @@ TEST_F(PjrtCApiBufferTest, DecreaseReferenceCountReturnsError) { PJRT_Buffer_DecreaseExternalReferenceCount_Args args; args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); auto error = ToUniquePtr(api_->PJRT_Buffer_DecreaseExternalReferenceCount(&args)); @@ -698,7 +698,7 @@ TEST_F(PjrtCApiBufferTest, DecreaseReferenceCountReturnsError) { TEST_F(PjrtCApiBufferTest, OpaqueDeviceMemoryDataPointer) { PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args args; args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); PJRT_Error* error = api_->PJRT_Buffer_OpaqueDeviceMemoryDataPointer(&args); EXPECT_EQ(error, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_test.h b/xla/pjrt/c/pjrt_c_api_test.h index 143f417c26b2c..768a54eaa8faa 100644 --- a/xla/pjrt/c/pjrt_c_api_test.h +++ b/xla/pjrt/c/pjrt_c_api_test.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index 3cfb43f1631b6..aa383745be279 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ namespace { PJRT_Client* CreateClient(const PJRT_Api* api) { PJRT_Client_Create_Args create_args; create_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - create_args.priv = nullptr; + create_args.extension_start = nullptr; create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; @@ -67,7 +67,7 @@ PjrtCApiTestBase::~PjrtCApiTestBase() { destroy_client(client_); } void PjrtCApiTestBase::destroy_client(PJRT_Client* client) { PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.client = client; PJRT_Error* error = api_->PJRT_Client_Destroy(&destroy_args); CHECK_EQ(error, nullptr); @@ -76,7 +76,7 @@ void PjrtCApiTestBase::destroy_client(PJRT_Client* client) { int PjrtCApiTestBase::GetDeviceId(PJRT_DeviceDescription* device_desc) const { PJRT_DeviceDescription_Id_Args args = PJRT_DeviceDescription_Id_Args{ .struct_size = PJRT_DeviceDescription_Id_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device_description = device_desc, .id = -1, }; @@ -96,7 +96,7 @@ bool PjrtCApiTestBase::IsValidDeviceId(PJRT_Device* device) const { int PjrtCApiTestBase::GetLocalHardwareId(PJRT_Device* device) const { PJRT_Device_LocalHardwareId_Args args = PJRT_Device_LocalHardwareId_Args{ .struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE, - .priv = nullptr, + .extension_start = nullptr, .device = device, .local_hardware_id = -1, }; @@ -108,7 +108,7 @@ int PjrtCApiTestBase::GetLocalHardwareId(PJRT_Device* device) const { absl::Span PjrtCApiTestBase::GetClientDevices() const { PJRT_Client_Devices_Args dev_args; dev_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - dev_args.priv = nullptr; + dev_args.extension_start = nullptr; dev_args.client = client_; PJRT_Error* error = api_->PJRT_Client_Devices(&dev_args); CHECK(error == nullptr); @@ -136,7 +136,7 @@ absl::Span PjrtCApiTestBase::GetClientAddressableDevices() const { PJRT_Client_AddressableDevices_Args addr_args; addr_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - addr_args.priv = nullptr; + addr_args.extension_start = nullptr; addr_args.client = client_; PJRT_Error* error = api_->PJRT_Client_AddressableDevices(&addr_args); CHECK(error == nullptr); @@ -151,7 +151,7 @@ PjrtCApiTestBase::CreateBufferFromHostBufferArgs( PJRT_Device* device) { PJRT_Client_BufferFromHostBuffer_Args args; args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.data = data.data(); args.type = ::pjrt::ConvertToPjRtBufferType(shape.element_type()); @@ -195,7 +195,7 @@ PjrtCApiTestBase::create_buffer(PJRT_Device* device) { PJRT_Buffer_ReadyEvent_Args get_event_args; get_event_args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - get_event_args.priv = nullptr; + get_event_args.extension_start = nullptr; get_event_args.buffer = buffer.get(); auto ready_event_error = ToUniquePtr(api_->PJRT_Buffer_ReadyEvent(&get_event_args)); diff --git a/xla/pjrt/c/pjrt_c_api_test_base.h b/xla/pjrt/c/pjrt_c_api_test_base.h index 5cd0013333929..28fe1f660d1b4 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.h +++ b/xla/pjrt/c/pjrt_c_api_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_tpu.h b/xla/pjrt/c/pjrt_c_api_tpu.h index 898dd37f8c245..469e5a3dba4f2 100644 --- a/xla/pjrt/c/pjrt_c_api_tpu.h +++ b/xla/pjrt/c/pjrt_c_api_tpu.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 3deaf7e0841e4..74b473f84650d 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -47,6 +48,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" @@ -54,18 +56,21 @@ limitations under the License. #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/framework/allocator.h" +#include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/context_types.h" namespace pjrt { @@ -118,7 +123,7 @@ static xla::Status PopulateExecutableCostAnalysis(PJRT_Executable* executable) { std::string& property_name = cost_analysis_names[i]; cost_analysis_property.struct_size = PJRT_NamedValue_STRUCT_SIZE; - cost_analysis_property.priv = nullptr; + cost_analysis_property.extension_start = nullptr; property_name = property.first; cost_analysis_property.name = property_name.c_str(); @@ -226,14 +231,17 @@ static xla::Status PopulateExecutableOutputMemoryKinds( return xla::OkStatus(); } -xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( - PJRT_KeyValueGetCallback c_callback, void* user_arg) { - if (c_callback == nullptr) { - return nullptr; - } - return [c_callback, user_arg]( - std::string_view key, - absl::Duration timeout) -> xla::StatusOr { +class CApiKeyValueStore : public xla::KeyValueStoreInterface { + public: + CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) + : c_get_callback_(c_get_callback), + get_user_arg_(get_user_arg), + c_put_callback_(c_put_callback), + put_user_arg_(put_user_arg) {} + + absl::StatusOr Get(std::string_view key, + absl::Duration timeout) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, size_t message_size) { @@ -245,24 +253,17 @@ xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( args.key_size = key.size(); args.timeout_in_ms = timeout / absl::Milliseconds(1); args.callback_error = &callback_error; - args.user_arg = user_arg; - std::unique_ptr error(c_callback(&args)); + args.user_arg = get_user_arg_; + std::unique_ptr error(c_get_callback_(&args)); if (error != nullptr) { return error->status; } auto result = std::string(args.value, args.value_size); args.value_deleter_callback(args.value); return result; - }; -} - -xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( - PJRT_KeyValuePutCallback c_callback, void* user_arg) { - if (c_callback == nullptr) { - return nullptr; } - return [c_callback, user_arg](std::string_view key, - std::string_view value) -> xla::Status { + + absl::Status Set(std::string_view key, std::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, size_t message_size) { @@ -275,13 +276,29 @@ xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( args.value = value.data(); args.value_size = value.size(); args.callback_error = &callback_error; - args.user_arg = user_arg; - std::unique_ptr error(c_callback(&args)); + args.user_arg = put_user_arg_; + std::unique_ptr error(c_put_callback_(&args)); if (error != nullptr) { return error->status; } - return xla::OkStatus(); - }; + return absl::OkStatus(); + } + + private: + PJRT_KeyValueGetCallback c_get_callback_; + void* get_user_arg_; + PJRT_KeyValuePutCallback c_put_callback_; + void* put_user_arg_; +}; + +std::shared_ptr ToCppKeyValueStore( + PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { + if (c_get_callback == nullptr || c_put_callback == nullptr) { + return nullptr; + } + return std::make_shared(c_get_callback, get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- @@ -323,11 +340,23 @@ PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) { // ---------------------------------- Plugin ----------------------------------- -PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args) { +PJRT_Error* PJRT_Plugin_Attributes_Empty(PJRT_Plugin_Attributes_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Plugin_Attributes_Args", PJRT_Plugin_Attributes_Args_STRUCT_SIZE, args->struct_size)); args->num_attributes = 0; + args->attributes = nullptr; + return nullptr; +} + +PJRT_Error* PJRT_Plugin_Attributes_Xla(PJRT_Plugin_Attributes_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Plugin_Attributes_Args", PJRT_Plugin_Attributes_Args_STRUCT_SIZE, + args->struct_size)); + const std::vector& attributes = + pjrt::GetXlaPluginCAttributes(); + args->num_attributes = attributes.size(); + args->attributes = attributes.data(); return nullptr; } @@ -449,8 +478,8 @@ PJRT_Error* PJRT_Client_AddressableMemories( } // Searches `device_list` for a PJRT_Device* that wraps a provided -// `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* is -// returned. Otherwise, returns nullptr. +// `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* +// is returned. Otherwise, returns nullptr. static PJRT_Device* FindDeviceWrapper( xla::PjRtDevice* cpp_device, absl::Span device_list) { for (PJRT_Device* device : device_list) { @@ -502,7 +531,7 @@ static void PopulatePjrtExecutableAddressableDevices( namespace { -xla::StatusOr ParseCompileOptions( +absl::StatusOr ParseCompileOptions( absl::string_view options_str) { xla::CompileOptionsProto options_proto; // Open source ParseFromString doesn't support string_view. @@ -515,7 +544,7 @@ xla::StatusOr ParseCompileOptions( using ProgramVariant = std::variant, xla::XlaComputation>; -xla::StatusOr< +absl::StatusOr< std::variant, xla::XlaComputation>> ParsePjrtProgram(std::optional& context, const PJRT_Program* program) { @@ -560,6 +589,11 @@ PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Program", PJRT_Program_STRUCT_SIZE, args->program->struct_size)); + int64_t traceme_context_id = pjrt::GetTracemeContextId(args); + tsl::profiler::TraceMeConsumer consumer( + "PJRT_Client_Compile", tsl::profiler::ContextType::kPjrtLibraryCall, + traceme_context_id); + PJRT_ASSIGN_OR_RETURN( xla::CompileOptions options, ParseCompileOptions(absl::string_view(args->compile_options, @@ -658,7 +692,7 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( xla::PjRtFuture::Promise promise = xla::PjRtFuture::CreatePromise(); - std::function on_done_with_host_buffer = [promise]() mutable { + absl::AnyInvocable on_done_with_host_buffer = [promise]() mutable { promise.Set(xla::OkStatus()); }; @@ -674,9 +708,18 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( dims, byte_strides, ::pjrt::ConvertFromPjRtHostBufferSemantics( args->host_buffer_semantics), - on_done_with_host_buffer, args->memory->memory_space, - &layout.value())); + std::move(on_done_with_host_buffer), + args->memory->memory_space, &layout.value())); } else if (has_layout_and_no_memory) { + PJRT_ASSIGN_OR_RETURN( + buffer, args->client->client->BufferFromHostBuffer( + args->data, ::pjrt::ConvertFromPjRtBufferType(args->type), + dims, byte_strides, + ::pjrt::ConvertFromPjRtHostBufferSemantics( + args->host_buffer_semantics), + std::move(on_done_with_host_buffer), args->device->device, + &layout.value())); + } else if (has_memory_and_no_layout) { PJRT_ASSIGN_OR_RETURN( buffer, args->client->client->BufferFromHostBuffer( @@ -684,16 +727,8 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( byte_strides, ::pjrt::ConvertFromPjRtHostBufferSemantics( args->host_buffer_semantics), - on_done_with_host_buffer, args->device->device, &layout.value())); - } else if (has_memory_and_no_layout) { - PJRT_ASSIGN_OR_RETURN( - buffer, args->client->client->BufferFromHostBuffer( - args->data, ::pjrt::ConvertFromPjRtBufferType(args->type), - dims, byte_strides, - ::pjrt::ConvertFromPjRtHostBufferSemantics( - args->host_buffer_semantics), - on_done_with_host_buffer, args->memory->memory_space, - /*device_layout=*/nullptr)); + std::move(on_done_with_host_buffer), args->memory->memory_space, + /*device_layout=*/nullptr)); } else { PJRT_ASSIGN_OR_RETURN( buffer, args->client->client->BufferFromHostBuffer( @@ -701,7 +736,7 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( dims, byte_strides, ::pjrt::ConvertFromPjRtHostBufferSemantics( args->host_buffer_semantics), - on_done_with_host_buffer, args->device->device)); + std::move(on_done_with_host_buffer), args->device->device)); } args->buffer = new PJRT_Buffer{std::move(buffer), args->client}; @@ -1060,8 +1095,8 @@ static xla::Status VerifyOptimizedProgramArgs( return xla::OkStatus(); } -static xla::StatusOr> GetOptimizedProgramModule( - const PJRT_Executable_OptimizedProgram_Args* args) { +static absl::StatusOr> +GetOptimizedProgramModule(const PJRT_Executable_OptimizedProgram_Args* args) { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, args->executable->get()->GetHloModules()); if (hlo_modules.empty()) { @@ -1245,8 +1280,8 @@ static xla::SendCallback CSendCallbackToCpp( PJRT_Chunk c_chunk = ConvertFromCppChunk(std::move(input)); // PJRT_CallbackError creates PJRT_Error in the implementation, but // using the caller's callback status code & message. This way, the - // caller avoids creating PJRT_Error itself, and the PJRT_Error is fully - // managed in the implementation layer. + // caller avoids creating PJRT_Error itself, and the PJRT_Error is + // fully managed in the implementation layer. PJRT_CallbackError c_callback_error = [](PJRT_Error_Code code, const char* message, size_t message_size) { return new PJRT_Error{ @@ -1257,7 +1292,7 @@ static xla::SendCallback CSendCallbackToCpp( std::unique_ptr error(callback( &c_chunk, &c_callback_error, total_size_in_bytes, done, user_arg)); if (error == nullptr) { - return tsl::OkStatus(); + return absl::OkStatus(); } return error->status; }}; @@ -1330,6 +1365,12 @@ PJRT_Error* PJRT_LoadedExecutable_Execute( PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_ExecuteOptions", PJRT_ExecuteOptions_STRUCT_SIZE, args->options->struct_size)); + + int64_t traceme_context_id = pjrt::GetTracemeContextId(args); + tsl::profiler::TraceMeConsumer consumer( + "PJRT_LoadedExecutable_Execute", + tsl::profiler::ContextType::kPjrtLibraryCall, traceme_context_id); + xla::ExecuteOptions options; options.launch_id = args->options->launch_id; options.strict_shape_checking = true; @@ -1414,7 +1455,8 @@ PJRT_Error* PJRT_LoadedExecutable_Execute( if (args->num_devices != 1) { return new PJRT_Error{xla::InvalidArgument( "num_devices and corresponding output list sizes must be 1 when " - "calling PJRT_LoadedExecutable_Execute with non-null execute_device. " + "calling PJRT_LoadedExecutable_Execute with non-null " + "execute_device. " "Got " "num_devices=%i", args->num_devices)}; @@ -1493,6 +1535,12 @@ PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( args->output_size_in_bytes = memory_stats.output_size_in_bytes; args->alias_size_in_bytes = memory_stats.alias_size_in_bytes; args->temp_size_in_bytes = memory_stats.temp_size_in_bytes; + args->host_generated_code_size_in_bytes = + memory_stats.host_generated_code_size_in_bytes; + args->host_argument_size_in_bytes = memory_stats.host_argument_size_in_bytes; + args->host_output_size_in_bytes = memory_stats.host_output_size_in_bytes; + args->host_alias_size_in_bytes = memory_stats.host_alias_size_in_bytes; + args->host_temp_size_in_bytes = memory_stats.host_temp_size_in_bytes; return nullptr; } @@ -1606,9 +1654,16 @@ PJRT_Error* PJRT_Buffer_GetMemoryLayout( { absl::MutexLock lock(&args->buffer->mu); if (!layout_data.has_value()) { - PJRT_ASSIGN_OR_RETURN( - BufferMemoryLayoutData data, - ConvertToBufferMemoryLayoutData(args->buffer->buffer->layout())); + // TODO(skyewm): change PJRT C API to also use opaque layout type + std::unique_ptr pjrt_layout = + args->buffer->buffer->layout(); + xla::PjRtXlaLayout* pjrt_xla_layout = + tensorflow::down_cast(pjrt_layout.get()); + CHECK(pjrt_xla_layout != nullptr) << "Got unexpected layout type"; + const xla::Layout& xla_layout = pjrt_xla_layout->xla_layout(); + + PJRT_ASSIGN_OR_RETURN(BufferMemoryLayoutData data, + ConvertToBufferMemoryLayoutData(xla_layout)); layout_data.emplace(std::move(data)); } } @@ -1920,7 +1975,7 @@ PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args) { if (!event->status.has_value()) { PJRT_Event_Await_Args await_args; await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; - await_args.priv = nullptr; + await_args.extension_start = nullptr; await_args.event = event; return PJRT_Event_Await(&await_args); } @@ -2061,7 +2116,7 @@ static std::vector PopulatePjrtAttributes( for (auto const& [name, value] : attributes) { PJRT_NamedValue& cur_attribute = c_attributes[ind]; cur_attribute.struct_size = PJRT_NamedValue_STRUCT_SIZE; - cur_attribute.priv = nullptr; + cur_attribute.extension_start = nullptr; cur_attribute.name = name.c_str(); cur_attribute.name_size = name.size(); if (const std::string* string_val = std::get_if(&value)) { @@ -2157,9 +2212,9 @@ static void AttachDevicesAndMemories(PJRT_Client* c_client) { } } -static xla::StatusOr> +static absl::StatusOr> GetStatusOrTopologyDescription(const xla::PjRtClient& cpp_client) { - xla::StatusOr status_or_cpp_topo = + absl::StatusOr status_or_cpp_topo = cpp_client.GetTopologyDescription(); if (!status_or_cpp_topo.ok()) { return status_or_cpp_topo.status(); @@ -2220,3 +2275,179 @@ PJRT_LoadedExecutable::PJRT_LoadedExecutable( : executable(std::move(executable)), client(client) { pjrt::PopulatePjrtExecutableAddressableDevices(this); } + +namespace pjrt { + +PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, + PJRT_TopologyDescription_Create* topology_create_fn, + PJRT_Plugin_Initialize* plugin_initialize_fn, + PJRT_Extension_Base* extension_start, + PJRT_Plugin_Attributes* plugin_attributes_fn) { + return PJRT_Api{ + /*struct_size=*/PJRT_Api_STRUCT_SIZE, + /*extension_start=*/extension_start, + + /*pjrt_api_version=*/ + PJRT_Api_Version{/*struct_size=*/PJRT_Api_Version_STRUCT_SIZE, + /*priv=*/nullptr, + /*major_version=*/PJRT_API_MAJOR, + /*minor_version=*/PJRT_API_MINOR}, + + /*PJRT_Error_Destroy=*/pjrt::PJRT_Error_Destroy, + /*PJRT_Error_Message=*/pjrt::PJRT_Error_Message, + /*PJRT_Error_GetCode=*/pjrt::PJRT_Error_GetCode, + + /*PJRT_Plugin_Initialize=*/plugin_initialize_fn, + /*PJRT_Plugin_Attributes=*/plugin_attributes_fn, + + /*PJRT_Event_Destroy=*/pjrt::PJRT_Event_Destroy, + /*PJRT_Event_IsReady=*/pjrt::PJRT_Event_IsReady, + /*PJRT_Event_Error=*/pjrt::PJRT_Event_Error, + /*PJRT_Event_Await=*/pjrt::PJRT_Event_Await, + /*PJRT_Event_OnReady=*/pjrt::PJRT_Event_OnReady, + + /*PJRT_Client_Create=*/create_fn, + /*PJRT_Client_Destroy=*/pjrt::PJRT_Client_Destroy, + /*PJRT_Client_PlatformName=*/pjrt::PJRT_Client_PlatformName, + /*PJRT_Client_ProcessIndex=*/pjrt::PJRT_Client_ProcessIndex, + /*PJRT_Client_PlatformVersion= */ pjrt::PJRT_Client_PlatformVersion, + /*PJRT_Client_Devices= */ pjrt::PJRT_Client_Devices, + /*PJRT_Client_AddressableDevices=*/ + pjrt::PJRT_Client_AddressableDevices, + /*PJRT_Client_LookupDevice=*/pjrt::PJRT_Client_LookupDevice, + /*PJRT_Client_LookupAddressableDevice=*/ + pjrt::PJRT_Client_LookupAddressableDevice, + /*PJRT_Client_AddressableMemories=*/pjrt::PJRT_Client_AddressableMemories, + /*PJRT_Client_Compile=*/pjrt::PJRT_Client_Compile, + /*PJRT_Client_DefaultDeviceAssignment=*/ + pjrt::PJRT_Client_DefaultDeviceAssignment, + /*PJRT_Client_BufferFromHostBuffer=*/ + pjrt::PJRT_Client_BufferFromHostBuffer, + + /*PJRT_DeviceDescription_Id=*/pjrt::PJRT_DeviceDescription_Id, + /*PJRT_DeviceDescription_ProcessIndex=*/ + pjrt::PJRT_DeviceDescription_ProcessIndex, + /*PJRT_DeviceDescription_Attributes=*/ + pjrt::PJRT_DeviceDescription_Attributes, + /*PJRT_DeviceDescription_Kind=*/pjrt::PJRT_DeviceDescription_Kind, + /*PJRT_DeviceDescription_DebugString=*/ + pjrt::PJRT_DeviceDescription_DebugString, + /*PJRT_DeviceDescription_ToString=*/ + pjrt::PJRT_DeviceDescription_ToString, + + /*PJRT_Device_GetDescription=*/pjrt::PJRT_Device_GetDescription, + /*PJRT_Device_IsAddressable=*/pjrt::PJRT_Device_IsAddressable, + /*PJRT_Device_LocalHardwareId=*/pjrt::PJRT_Device_LocalHardwareId, + /*PJRT_Device_AddressableMemories=*/pjrt::PJRT_Device_AddressableMemories, + /*PJRT_Device_DefaultMemory=*/pjrt::PJRT_Device_DefaultMemory, + /*PJRT_Device_MemoryStats=*/pjrt::PJRT_Device_MemoryStats, + + /*PJRT_Memory_Id=*/pjrt::PJRT_Memory_Id, + /*PJRT_Memory_Kind=*/pjrt::PJRT_Memory_Kind, + /*PJRT_Memory_DebugString=*/pjrt::PJRT_Memory_DebugString, + /*PJRT_Memory_ToString=*/pjrt::PJRT_Memory_ToString, + /*PJRT_Memory_AddressableByDevices=*/ + pjrt::PJRT_Memory_AddressableByDevices, + + /*PJRT_Executable_Destroy=*/pjrt::PJRT_Executable_Destroy, + /*PJRT_Executable_Name=*/pjrt::PJRT_Executable_Name, + /*PJRT_Executable_NumReplicas=*/pjrt::PJRT_Executable_NumReplicas, + /*PJRT_Executable_NumPartitions=*/ + pjrt::PJRT_Executable_NumPartitions, + /*PJRT_Executable_NumOutputs=*/pjrt::PJRT_Executable_NumOutputs, + /*PJRT_Executable_SizeOfGeneratedCodeInBytes=*/ + pjrt::PJRT_Executable_SizeOfGeneratedCodeInBytes, + /*PJRT_Executable_GetCostAnalysis=*/pjrt::PJRT_Executable_GetCostAnalysis, + /*PJRT_Executable_OutputMemoryKinds=*/ + pjrt::PJRT_Executable_OutputMemoryKinds, + /*PJRT_Executable_OptimizedProgram=*/ + pjrt::PJRT_Executable_OptimizedProgram, + /*PJRT_Executable_Serialize=*/pjrt::PJRT_Executable_Serialize, + + /*PJRT_LoadedExecutable_Destroy=*/pjrt::PJRT_LoadedExecutable_Destroy, + /*PJRT_LoadedExecutable_GetExecutable=*/ + pjrt::PJRT_LoadedExecutable_GetExecutable, + /*PJRT_LoadedExecutable_AddressableDevices=*/ + pjrt::PJRT_LoadedExecutable_AddressableDevices, + /*PJRT_LoadedExecutable_Delete=*/pjrt::PJRT_LoadedExecutable_Delete, + /*PJRT_LoadedExecutable_IsDeleted=*/ + pjrt::PJRT_LoadedExecutable_IsDeleted, + /*PJRT_LoadedExecutable_Execute=*/pjrt::PJRT_LoadedExecutable_Execute, + /*PJRT_Executable_DeserializeAndLoad=*/ + pjrt::PJRT_Executable_DeserializeAndLoad, + /*PJRT_LoadedExecutable_Fingerprint=*/ + pjrt::PJRT_LoadedExecutable_Fingerprint, + + /*PJRT_Buffer_Destroy=*/pjrt::PJRT_Buffer_Destroy, + /*PJRT_Buffer_ElementType=*/pjrt::PJRT_Buffer_ElementType, + /*PJRT_Buffer_Dimensions=*/pjrt::PJRT_Buffer_Dimensions, + /*PJRT_Buffer_UnpaddedDimensions=*/ + pjrt::PJRT_Buffer_UnpaddedDimensions, + /*PJRT_Buffer_DynamicDimensionIndices=*/ + pjrt::PJRT_Buffer_DynamicDimensionIndices, + /*PJRT_Buffer_GetMemoryLayout=*/ + pjrt::PJRT_Buffer_GetMemoryLayout, + /*PJRT_Buffer_OnDeviceSizeInBytes=*/ + pjrt::PJRT_Buffer_OnDeviceSizeInBytes, + /*PJRT_Buffer_Device=*/pjrt::PJRT_Buffer_Device, + /*PJRT_Buffer_Memory=*/pjrt::PJRT_Buffer_Memory, + /*PJRT_Buffer_Delete=*/pjrt::PJRT_Buffer_Delete, + /*PJRT_Buffer_IsDeleted=*/pjrt::PJRT_Buffer_IsDeleted, + /*PJRT_Buffer_CopyToDevice=*/pjrt::PJRT_Buffer_CopyToDevice, + /*PJRT_Buffer_ToHostBuffer=*/pjrt::PJRT_Buffer_ToHostBuffer, + /*PJRT_Buffer_IsOnCpu=*/pjrt::PJRT_Buffer_IsOnCpu, + /*PJRT_Buffer_ReadyEvent=*/pjrt::PJRT_Buffer_ReadyEvent, + /*PJRT_Buffer_UnsafePointer=*/pjrt::PJRT_Buffer_UnsafePointer, + /*PJRT_Buffer_IncreaseExternalReferenceCount=*/ + pjrt::PJRT_Buffer_IncreaseExternalReferenceCount, + /*PJRT_Buffer_DecreaseExternalReferenceCount=*/ + pjrt::PJRT_Buffer_DecreaseExternalReferenceCount, + /*PJRT_Buffer_OpaqueDeviceMemoryDataPointer=*/ + pjrt::PJRT_Buffer_OpaqueDeviceMemoryDataPointer, + + /*PJRT_CopyToDeviceStream_Destroy=*/ + pjrt::PJRT_CopyToDeviceStream_Destroy, + /*PJRT_CopyToDeviceStream_AddChunk=*/ + pjrt::PJRT_CopyToDeviceStream_AddChunk, + /*PJRT_CopyToDeviceStream_TotalBytes=*/ + pjrt::PJRT_CopyToDeviceStream_TotalBytes, + /*PJRT_CopyToDeviceStream_GranuleSize=*/ + pjrt::PJRT_CopyToDeviceStream_GranuleSize, + /*PJRT_CopyToDeviceStream_CurrentBytes=*/ + pjrt::PJRT_CopyToDeviceStream_CurrentBytes, + + /*PJRT_TopologyDescription_Create=*/topology_create_fn, + /*PJRT_TopologyDescription_Destroy=*/ + pjrt::PJRT_TopologyDescription_Destroy, + /*PJRT_TopologyDescription_PlatformName=*/ + pjrt::PJRT_TopologyDescription_PlatformName, + /*PJRT_TopologyDescription_PlatformVersion=*/ + pjrt::PJRT_TopologyDescription_PlatformVersion, + /*PJRT_TopologyDescription_GetDeviceDescriptions=*/ + pjrt::PJRT_TopologyDescription_GetDeviceDescriptions, + /*PJRT_TopologyDescription_Serialize=*/ + pjrt::PJRT_TopologyDescription_Serialize, + /*PJRT_TopologyDescription_Attributes=*/ + pjrt::PJRT_TopologyDescription_Attributes, + + /*PJRT_Compile=*/pjrt::PJRT_Compile, + + // Always add new fields to the end of the struct. Move fields below to + // their corresponding places after each major version bump. + /*PJRT_Executable_OutputElementTypes=*/ + pjrt::PJRT_Executable_OutputElementTypes, + /*PJRT_Executable_OutputDimensions=*/ + pjrt::PJRT_Executable_OutputDimensions, + /*PJRT_Buffer_CopyToMemory=*/ + pjrt::PJRT_Buffer_CopyToMemory, + /*PJRT_Client_CreateViewOfDeviceBuffer=*/ + pjrt::PJRT_Client_CreateViewOfDeviceBuffer, + /*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint, + /*PJRT_Client_TopologyDescription= */ + pjrt::PJRT_Client_TopologyDescription, + /*PJRT_Executable_GetCompiledMemoryStats= */ + pjrt::PJRT_Executable_GetCompiledMemoryStats, + }; +} + +} // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 2ce5220719452..a36299b6cfae4 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,13 +29,16 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/shape.h" #include "xla/status.h" +#include "tsl/platform/casts.h" struct PJRT_Error { xla::Status status; @@ -77,7 +80,7 @@ struct PJRT_Client { // `owned_memories`. absl::flat_hash_map c_memory_from_cpp_memory; - xla::StatusOr> topology; + absl::StatusOr> topology; explicit PJRT_Client(std::unique_ptr cpp_client); }; @@ -111,7 +114,7 @@ struct PJRT_Executable { // Must be shared_ptr so that we can share with PJRT_LoadedExecutable. std::shared_ptr executable; - xla::StatusOr fingerprint; + absl::StatusOr fingerprint; // Used to synchronize concurrent setting of cached values. mutable absl::Mutex mutex; @@ -208,7 +211,8 @@ void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); void PJRT_Error_Message(PJRT_Error_Message_Args* args); PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); -PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); +PJRT_Error* PJRT_Plugin_Attributes_Empty(PJRT_Plugin_Attributes_Args* args); +PJRT_Error* PJRT_Plugin_Attributes_Xla(PJRT_Plugin_Attributes_Args* args); PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); @@ -414,11 +418,9 @@ PJRT_TopologyDescription* CreateWrapperDeviceTopology( PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. -xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( - PJRT_KeyValueGetCallback c_callback, void* user_arg); - -xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( - PJRT_KeyValuePutCallback c_callback, void* user_arg); +std::shared_ptr ToCppKeyValueStore( + PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as // the implementation of PJRT_Plugin_Initialize for plugins that do not require @@ -427,177 +429,12 @@ PJRT_Error* PJRT_Plugin_Initialize_NoOp(PJRT_Plugin_Initialize_Args* args); // Creates a PJRT_Api with create_fn from the input and other functions in // pjrt_c_api_wrapper_impl. -constexpr PJRT_Api CreatePjrtApi( - PJRT_Client_Create* create_fn, - PJRT_TopologyDescription_Create* topology_create_fn, - PJRT_Plugin_Initialize* plugin_initialize_fn, - void* extension_start = nullptr) { - return PJRT_Api{ - /*struct_size=*/PJRT_Api_STRUCT_SIZE, - /*extension_start=*/extension_start, - - /*pjrt_api_version=*/ - PJRT_Api_Version{/*struct_size=*/PJRT_Api_Version_STRUCT_SIZE, - /*priv=*/nullptr, - /*major_version=*/PJRT_API_MAJOR, - /*minor_version=*/PJRT_API_MINOR}, - - /*PJRT_Error_Destroy=*/pjrt::PJRT_Error_Destroy, - /*PJRT_Error_Message=*/pjrt::PJRT_Error_Message, - /*PJRT_Error_GetCode=*/pjrt::PJRT_Error_GetCode, - - /*PJRT_Plugin_Initialize=*/plugin_initialize_fn, - /*PJRT_Plugin_Attributes=*/pjrt::PJRT_Plugin_Attributes, - - /*PJRT_Event_Destroy=*/pjrt::PJRT_Event_Destroy, - /*PJRT_Event_IsReady=*/pjrt::PJRT_Event_IsReady, - /*PJRT_Event_Error=*/pjrt::PJRT_Event_Error, - /*PJRT_Event_Await=*/pjrt::PJRT_Event_Await, - /*PJRT_Event_OnReady=*/pjrt::PJRT_Event_OnReady, - - /*PJRT_Client_Create=*/create_fn, - /*PJRT_Client_Destroy=*/pjrt::PJRT_Client_Destroy, - /*PJRT_Client_PlatformName=*/pjrt::PJRT_Client_PlatformName, - /*PJRT_Client_ProcessIndex=*/pjrt::PJRT_Client_ProcessIndex, - /*PJRT_Client_PlatformVersion= */ pjrt::PJRT_Client_PlatformVersion, - /*PJRT_Client_Devices= */ pjrt::PJRT_Client_Devices, - /*PJRT_Client_AddressableDevices=*/ - pjrt::PJRT_Client_AddressableDevices, - /*PJRT_Client_LookupDevice=*/pjrt::PJRT_Client_LookupDevice, - /*PJRT_Client_LookupAddressableDevice=*/ - pjrt::PJRT_Client_LookupAddressableDevice, - /*PJRT_Client_AddressableMemories=*/pjrt::PJRT_Client_AddressableMemories, - /*PJRT_Client_Compile=*/pjrt::PJRT_Client_Compile, - /*PJRT_Client_DefaultDeviceAssignment=*/ - pjrt::PJRT_Client_DefaultDeviceAssignment, - /*PJRT_Client_BufferFromHostBuffer=*/ - pjrt::PJRT_Client_BufferFromHostBuffer, - - /*PJRT_DeviceDescription_Id=*/pjrt::PJRT_DeviceDescription_Id, - /*PJRT_DeviceDescription_ProcessIndex=*/ - pjrt::PJRT_DeviceDescription_ProcessIndex, - /*PJRT_DeviceDescription_Attributes=*/ - pjrt::PJRT_DeviceDescription_Attributes, - /*PJRT_DeviceDescription_Kind=*/pjrt::PJRT_DeviceDescription_Kind, - /*PJRT_DeviceDescription_DebugString=*/ - pjrt::PJRT_DeviceDescription_DebugString, - /*PJRT_DeviceDescription_ToString=*/ - pjrt::PJRT_DeviceDescription_ToString, - - /*PJRT_Device_GetDescription=*/pjrt::PJRT_Device_GetDescription, - /*PJRT_Device_IsAddressable=*/pjrt::PJRT_Device_IsAddressable, - /*PJRT_Device_LocalHardwareId=*/pjrt::PJRT_Device_LocalHardwareId, - /*PJRT_Device_AddressableMemories=*/pjrt::PJRT_Device_AddressableMemories, - /*PJRT_Device_DefaultMemory=*/pjrt::PJRT_Device_DefaultMemory, - /*PJRT_Device_MemoryStats=*/pjrt::PJRT_Device_MemoryStats, - - /*PJRT_Memory_Id=*/pjrt::PJRT_Memory_Id, - /*PJRT_Memory_Kind=*/pjrt::PJRT_Memory_Kind, - /*PJRT_Memory_DebugString=*/pjrt::PJRT_Memory_DebugString, - /*PJRT_Memory_ToString=*/pjrt::PJRT_Memory_ToString, - /*PJRT_Memory_AddressableByDevices=*/ - pjrt::PJRT_Memory_AddressableByDevices, - - /*PJRT_Executable_Destroy=*/pjrt::PJRT_Executable_Destroy, - /*PJRT_Executable_Name=*/pjrt::PJRT_Executable_Name, - /*PJRT_Executable_NumReplicas=*/pjrt::PJRT_Executable_NumReplicas, - /*PJRT_Executable_NumPartitions=*/ - pjrt::PJRT_Executable_NumPartitions, - /*PJRT_Executable_NumOutputs=*/pjrt::PJRT_Executable_NumOutputs, - /*PJRT_Executable_SizeOfGeneratedCodeInBytes=*/ - pjrt::PJRT_Executable_SizeOfGeneratedCodeInBytes, - /*PJRT_Executable_GetCostAnalysis=*/pjrt::PJRT_Executable_GetCostAnalysis, - /*PJRT_Executable_OutputMemoryKinds=*/ - pjrt::PJRT_Executable_OutputMemoryKinds, - /*PJRT_Executable_OptimizedProgram=*/ - pjrt::PJRT_Executable_OptimizedProgram, - /*PJRT_Executable_Serialize=*/pjrt::PJRT_Executable_Serialize, - - /*PJRT_LoadedExecutable_Destroy=*/pjrt::PJRT_LoadedExecutable_Destroy, - /*PJRT_LoadedExecutable_GetExecutable=*/ - pjrt::PJRT_LoadedExecutable_GetExecutable, - /*PJRT_LoadedExecutable_AddressableDevices=*/ - pjrt::PJRT_LoadedExecutable_AddressableDevices, - /*PJRT_LoadedExecutable_Delete=*/pjrt::PJRT_LoadedExecutable_Delete, - /*PJRT_LoadedExecutable_IsDeleted=*/ - pjrt::PJRT_LoadedExecutable_IsDeleted, - /*PJRT_LoadedExecutable_Execute=*/pjrt::PJRT_LoadedExecutable_Execute, - /*PJRT_Executable_DeserializeAndLoad=*/ - pjrt::PJRT_Executable_DeserializeAndLoad, - /*PJRT_LoadedExecutable_Fingerprint=*/ - pjrt::PJRT_LoadedExecutable_Fingerprint, - - /*PJRT_Buffer_Destroy=*/pjrt::PJRT_Buffer_Destroy, - /*PJRT_Buffer_ElementType=*/pjrt::PJRT_Buffer_ElementType, - /*PJRT_Buffer_Dimensions=*/pjrt::PJRT_Buffer_Dimensions, - /*PJRT_Buffer_UnpaddedDimensions=*/ - pjrt::PJRT_Buffer_UnpaddedDimensions, - /*PJRT_Buffer_DynamicDimensionIndices=*/ - pjrt::PJRT_Buffer_DynamicDimensionIndices, - /*PJRT_Buffer_GetMemoryLayout=*/ - pjrt::PJRT_Buffer_GetMemoryLayout, - /*PJRT_Buffer_OnDeviceSizeInBytes=*/ - pjrt::PJRT_Buffer_OnDeviceSizeInBytes, - /*PJRT_Buffer_Device=*/pjrt::PJRT_Buffer_Device, - /*PJRT_Buffer_Memory=*/pjrt::PJRT_Buffer_Memory, - /*PJRT_Buffer_Delete=*/pjrt::PJRT_Buffer_Delete, - /*PJRT_Buffer_IsDeleted=*/pjrt::PJRT_Buffer_IsDeleted, - /*PJRT_Buffer_CopyToDevice=*/pjrt::PJRT_Buffer_CopyToDevice, - /*PJRT_Buffer_ToHostBuffer=*/pjrt::PJRT_Buffer_ToHostBuffer, - /*PJRT_Buffer_IsOnCpu=*/pjrt::PJRT_Buffer_IsOnCpu, - /*PJRT_Buffer_ReadyEvent=*/pjrt::PJRT_Buffer_ReadyEvent, - /*PJRT_Buffer_UnsafePointer=*/pjrt::PJRT_Buffer_UnsafePointer, - /*PJRT_Buffer_IncreaseExternalReferenceCount=*/ - pjrt::PJRT_Buffer_IncreaseExternalReferenceCount, - /*PJRT_Buffer_DecreaseExternalReferenceCount=*/ - pjrt::PJRT_Buffer_DecreaseExternalReferenceCount, - /*PJRT_Buffer_OpaqueDeviceMemoryDataPointer=*/ - pjrt::PJRT_Buffer_OpaqueDeviceMemoryDataPointer, - - /*PJRT_CopyToDeviceStream_Destroy=*/ - pjrt::PJRT_CopyToDeviceStream_Destroy, - /*PJRT_CopyToDeviceStream_AddChunk=*/ - pjrt::PJRT_CopyToDeviceStream_AddChunk, - /*PJRT_CopyToDeviceStream_TotalBytes=*/ - pjrt::PJRT_CopyToDeviceStream_TotalBytes, - /*PJRT_CopyToDeviceStream_GranuleSize=*/ - pjrt::PJRT_CopyToDeviceStream_GranuleSize, - /*PJRT_CopyToDeviceStream_CurrentBytes=*/ - pjrt::PJRT_CopyToDeviceStream_CurrentBytes, - - /*PJRT_TopologyDescription_Create=*/topology_create_fn, - /*PJRT_TopologyDescription_Destroy=*/ - pjrt::PJRT_TopologyDescription_Destroy, - /*PJRT_TopologyDescription_PlatformName=*/ - pjrt::PJRT_TopologyDescription_PlatformName, - /*PJRT_TopologyDescription_PlatformVersion=*/ - pjrt::PJRT_TopologyDescription_PlatformVersion, - /*PJRT_TopologyDescription_GetDeviceDescriptions=*/ - pjrt::PJRT_TopologyDescription_GetDeviceDescriptions, - /*PJRT_TopologyDescription_Serialize=*/ - pjrt::PJRT_TopologyDescription_Serialize, - /*PJRT_TopologyDescription_Attributes=*/ - pjrt::PJRT_TopologyDescription_Attributes, - - /*PJRT_Compile=*/pjrt::PJRT_Compile, - - // Always add new fields to the end of the struct. Move fields below to - // their corresponding places after each major version bump. - /*PJRT_Executable_OutputElementTypes=*/ - pjrt::PJRT_Executable_OutputElementTypes, - /*PJRT_Executable_OutputDimensions=*/ - pjrt::PJRT_Executable_OutputDimensions, - /*PJRT_Buffer_CopyToMemory=*/ - pjrt::PJRT_Buffer_CopyToMemory, - /*PJRT_Client_CreateViewOfDeviceBuffer=*/ - pjrt::PJRT_Client_CreateViewOfDeviceBuffer, - /*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint, - /*PJRT_Client_TopologyDescription= */ - pjrt::PJRT_Client_TopologyDescription, - /*PJRT_Executable_GetCompiledMemoryStats= */ - pjrt::PJRT_Executable_GetCompiledMemoryStats, - }; -} +PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, + PJRT_TopologyDescription_Create* topology_create_fn, + PJRT_Plugin_Initialize* plugin_initialize_fn, + PJRT_Extension_Base* extension_start = nullptr, + PJRT_Plugin_Attributes* plugin_attributes_fn = + pjrt::PJRT_Plugin_Attributes_Empty); } // namespace pjrt diff --git a/xla/pjrt/compile_options.proto b/xla/pjrt/compile_options.proto index 5e9c94fadfc56..4ea4af933e936 100644 --- a/xla/pjrt/compile_options.proto +++ b/xla/pjrt/compile_options.proto @@ -7,7 +7,7 @@ import "xla/xla.proto"; import "xla/xla_data.proto"; // A serialization of xla::ExecutableBuildOptions. -// Next id: 16. +// Next id: 19. message ExecutableBuildOptionsProto { // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are @@ -65,6 +65,18 @@ message ExecutableBuildOptionsProto { // which can be used to compile post-optimizations HLO modules. bool run_backend_only = 11; + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 18; + // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -84,6 +96,12 @@ message ExecutableBuildOptionsProto { bytes fdo_profile = 14; int64 device_memory_size = 15; + + // Mesh shape in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_shape = 16; + + // Mesh ids in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_ids = 17; } message OptionOverrideProto { diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index 4e00e2f472553..d8ae46a93bb2c 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -1,5 +1,7 @@ -load("//xla:xla.bzl", "xla_cc_test") +load("@tsl//tsl:tsl.bzl", "if_oss", "internal_visibility") +load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -28,6 +30,7 @@ cc_library( "//xla/runtime:cpu_event", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -95,15 +98,42 @@ cc_library( ], ) +tf_proto_library( + name = "cpu_topology_proto", + srcs = ["cpu_topology.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cpu_topology", + srcs = ["cpu_topology.cc"], + hdrs = ["cpu_topology.h"], + deps = [ + ":cpu_topology_proto_cc", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "cpu_topology_test", + srcs = ["cpu_topology_test.cc"], + deps = [ + ":cpu_topology", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "cpu_client", srcs = ["cpu_client.cc"], hdrs = ["cpu_client.h"], - visibility = [ - "//xla:friends", - ], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":abstract_tfrt_cpu_buffer", + ":cpu_topology", ":tracked_tfrt_cpu_device_buffer", "//xla:array", "//xla:debug_options_flags", @@ -122,11 +152,15 @@ cc_library( "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:semaphore", "//xla/pjrt:transpose", "//xla/pjrt:utils", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:topology_util", "//xla/runtime:cpu_event", "//xla/service:buffer_assignment", @@ -141,9 +175,12 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service/cpu:buffer_desc", + "//xla/service/cpu:collectives_interface", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_executable", + "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_xfeed", + "//xla/service/cpu:simple_orc_jit", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", @@ -153,15 +190,18 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@tsl//tsl/concurrency:async_value", "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/lib/strings:proto_serialization", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:denormal", "@tsl//tsl/platform:env", @@ -199,3 +239,82 @@ xla_cc_test( "@tsl//tsl/platform:test", ], ) + +cc_library( + name = "gloo_kv_store", + srcs = ["gloo_kv_store.cc"], + hdrs = ["gloo_kv_store.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//xla/pjrt:status_casters", + "//xla/pjrt/distributed:key_value_store_interface", + "@com_google_absl//absl/time", + "@gloo", + ], +) + +cc_library( + name = "gloo_collectives", + srcs = ["gloo_collectives.cc"], + hdrs = ["gloo_collectives.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@gloo", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "mpi_collectives", + srcs = if_oss(["mpi_collectives.cc"]), + hdrs = if_oss(["mpi_collectives.h"]), + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = if_oss([ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@mpitrampoline", + ]), +) diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 3bf5805e75888..9b484ebc4a971 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -133,14 +133,6 @@ ShapedBuffer AsShapedBuffer( } // namespace -UnpinnedHostMemorySpace::UnpinnedHostMemorySpace(int id, PjRtClient* client) - : id_(id), client_(client) { - debug_string_ = absl::StrFormat( - "UnpinnedHostMemorySpace(id=%i, process_index=%i, client=%s)", id_, - client_->process_index(), client_->platform_name()); - to_string_ = absl::StrFormat("UNPINNED_HOST_%i", id_); -} - AbstractTfrtCpuBuffer::AbstractTfrtCpuBuffer( Shape on_device_shape, std::unique_ptr tracked_device_buffer) @@ -151,7 +143,7 @@ AbstractTfrtCpuBuffer::~AbstractTfrtCpuBuffer() { AbstractTfrtCpuBuffer::Delete(); } -StatusOr AbstractTfrtCpuBuffer::logical_on_device_shape() { +absl::StatusOr AbstractTfrtCpuBuffer::logical_on_device_shape() { if (on_device_shape_.is_static()) { return on_device_shape_; } @@ -168,7 +160,7 @@ StatusOr AbstractTfrtCpuBuffer::logical_on_device_shape() { const auto& av = device_buffer->definition_event(); BlockUntilReady(av.GetAsyncValue()); if (auto* error = av.GetErrorIfPresent()) { - return InternalError("Error Execute: %s", error->message()); + return Internal("Error Execute: %s", error->message()); } ShapedBuffer shaped_buffer = @@ -180,11 +172,11 @@ StatusOr AbstractTfrtCpuBuffer::logical_on_device_shape() { return ret_shape; } -StatusOr AbstractTfrtCpuBuffer::GetOnDeviceSizeInBytes() const { +absl::StatusOr AbstractTfrtCpuBuffer::GetOnDeviceSizeInBytes() const { return ShapeUtil::ByteSizeOf(on_device_shape_); } -StatusOr> +absl::StatusOr> AbstractTfrtCpuBuffer::AcquireExternalReference() { class ScopedExternalReference : public PjRtBuffer::ExternalReference { public: @@ -239,7 +231,7 @@ class TrackedCpuDeviceBufferExternalReference std::unique_ptr tracked_device_buffer_; }; -StatusOr> +absl::StatusOr> AbstractTfrtCpuBuffer::ReleaseDeviceMemoryOwnership( bool wait_for_operations_to_complete) { if (on_device_shape_.IsTuple()) { @@ -324,7 +316,7 @@ AbstractTfrtCpuBuffer::ReleaseBufferLocked() { return std::move(tracked_device_buffer_); } -StatusOr> +absl::StatusOr> AbstractTfrtCpuBuffer::Release(bool wait_for_operations_to_complete) { std::unique_ptr device_buffer; { @@ -347,7 +339,7 @@ AbstractTfrtCpuBuffer::Release(bool wait_for_operations_to_complete) { BlockUntilReady(av.GetAsyncValue()); if (auto* error = av.GetErrorIfPresent()) { first_error.Update( - InternalError("Error Execute: %s", error->message())); + Internal("Error Execute: %s", error->message())); } } if (!first_error.ok()) return std::move(first_error); @@ -367,7 +359,7 @@ TrackedTfrtCpuDeviceBuffer* AbstractTfrtCpuBuffer::AcquireUsage( return tracked_device_buffer_.get(); } -StatusOr +absl::StatusOr AbstractTfrtCpuBuffer::AcquireDonation() { absl::MutexLock lock(&mu_); @@ -412,7 +404,7 @@ PjRtFuture AbstractTfrtCpuBuffer::ToLiteralHelper( bool should_sync_copy = device_buffer_wait_avs.empty() && literal->size_bytes() < kSmallDataTransferByteSize; - StatusOr device_shape = logical_on_device_shape(); + absl::StatusOr device_shape = logical_on_device_shape(); if (!device_shape.ok()) { return PjRtFuture(device_shape.status()); } @@ -462,7 +454,7 @@ PjRtFuture AbstractTfrtCpuBuffer::ToLiteralHelper( } } -StatusOr> +absl::StatusOr> AbstractTfrtCpuBuffer::CopyToDeviceAcrossClients(PjRtDevice* dst_device) { TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteralSync()); // Avoid use-after-free on `literal` due to unsequenced move and use. @@ -474,11 +466,11 @@ AbstractTfrtCpuBuffer::CopyToDeviceAcrossClients(PjRtDevice* dst_device) { return dst_device->client()->BufferFromHostBuffer( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } -StatusOr> +absl::StatusOr> AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { // Copy each leaf buffer to a destination buffer. auto usage_event = tsl::MakeConstructedAsyncValueRef(); @@ -637,7 +629,7 @@ void AbstractTfrtCpuBuffer::CopyFromLiteral( } } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> AbstractTfrtCpuBuffer::AllocateTrackedDeviceBuffer( const Shape& on_device_shape, absl::InlinedVector, 4> definition_events) { @@ -677,12 +669,12 @@ AbstractTfrtCpuBuffer::AllocateTrackedDeviceBuffer( } } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, PjRtClient::HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, const Shape& shape, + absl::AnyInvocable on_done_with_host_buffer, const Shape& shape, AsyncWorkRunner* async_work_runner, absl::Mutex* transpose_mu, TransposePlanCache* transpose_cache) { bool has_default_layout = @@ -695,12 +687,13 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( // code which requires it. bool can_use_zero_copy = has_default_layout && !is_int4 && - host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy && + host_buffer_semantics == + PjRtClient::HostBufferSemantics::kImmutableZeroCopy && ((absl::bit_cast(data) & (cpu_function_runtime::MinAlign() - 1)) == 0); absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> definition_events; - std::function on_delete_callback; + absl::AnyInvocable on_delete_callback; size_t byte_size = ShapeUtil::ByteSizeOf(shape); if (can_use_zero_copy) { auto device_buffer = std::make_shared( @@ -724,11 +717,13 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( { absl::InlinedVector permutation(dims.size()); absl::c_iota(permutation, 0); + TransposePlan::Options options; + options.elem_size_in_bytes = primitive_util::ByteWidth(type); + options.dims = dims; + options.permutation = permutation; + options.input_layout = TransposePlan::Striding{*byte_strides}; absl::MutexLock lock(transpose_mu); - TF_ASSIGN_OR_RETURN( - transpose, transpose_cache->GetOrCreate( - primitive_util::ByteWidth(type), dims, permutation, - TransposePlan::Striding{*byte_strides})); + TF_ASSIGN_OR_RETURN(transpose, transpose_cache->GetOrCreate(options)); } if (!is_int4) { transpose->Execute(data, dst_data_ptr); @@ -745,7 +740,7 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( PackInt4(src_data_span, dst_data_span); } if (on_done_with_host_buffer) { - on_done_with_host_buffer(); + std::move(on_done_with_host_buffer)(); on_done_with_host_buffer = nullptr; } } else { @@ -756,7 +751,7 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( if (should_sync_copy) { std::memcpy(dst_data_ptr, data, byte_size); if (on_done_with_host_buffer) { - on_done_with_host_buffer(); + std::move(on_done_with_host_buffer)(); on_done_with_host_buffer = nullptr; } } else { @@ -771,7 +766,7 @@ AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( tsl::profiler::TraceMe traceme("H2D Dispatch"); std::memcpy(dst_data_ptr, data, byte_size); if (on_done_with_host_buffer) { - on_done_with_host_buffer(); + std::move(on_done_with_host_buffer)(); on_done_with_host_buffer = nullptr; } // Signal copy is complete. diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index f1cec9f5acb17..d9ed3d356507a 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -90,37 +90,6 @@ class AsyncWorkRunner { absl::AnyInvocable work) = 0; }; -// Represents the unpinned host memory accessible to a PjRtDevice. -class UnpinnedHostMemorySpace : public PjRtMemorySpace { - public: - static constexpr absl::string_view kMemorySpaceKind = "unpinned_host"; - - UnpinnedHostMemorySpace(int id, PjRtClient* client); - - PjRtClient* client() const override { return client_; } - - absl::Span devices() const override { return devices_; } - - int id() const override { return id_; } - - absl::string_view memory_space_kind() const override { - return kMemorySpaceKind; - } - - absl::string_view DebugString() const override { return debug_string_; } - - absl::string_view ToString() const override { return to_string_; } - - void AttachDevice(PjRtDevice* device) { devices_.push_back(device); } - - private: - int id_; - PjRtClient* client_; - std::vector devices_; - std::string debug_string_; - std::string to_string_; -}; - class AbstractTfrtCpuBuffer : public PjRtBuffer { public: AbstractTfrtCpuBuffer( @@ -130,22 +99,22 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { const Shape& on_device_shape() const override { return on_device_shape_; } - StatusOr logical_on_device_shape() override; + absl::StatusOr logical_on_device_shape() override; - StatusOr> AcquireExternalReference() + absl::StatusOr> AcquireExternalReference() override; - StatusOr> ReleaseDeviceMemoryOwnership( - bool wait_for_operations_to_complete) override; + absl::StatusOr> + ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) override; - StatusOr GetOnDeviceSizeInBytes() const override; + absl::StatusOr GetOnDeviceSizeInBytes() const override; PjRtFuture CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size) override { return PjRtFuture(Unimplemented("CopyRawToHost not implemented")); } - StatusOr> CopyToMemorySpace( + absl::StatusOr> CopyToMemorySpace( PjRtMemorySpace* dst_memory_space) override { return Unimplemented("CopyToMemorySpace not implemented"); } @@ -232,7 +201,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { // serialize this donation with previous usages. After this method is called, // calls to AcquireUsage() will fail. Returns error status if the buffer is // already donated or there is outstanding external references. - StatusOr AcquireDonation(); + absl::StatusOr AcquireDonation(); // A helper function for PjRtClient::BufferFromHostLiteral. Copy the literal // to the current buffer asynchronously. `avs` is used to signal when the copy @@ -245,7 +214,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { // Allocates a new `TrackedTfrtCpuDeviceBuffer` with the given shape and // definition events. - static StatusOr> + static absl::StatusOr> AllocateTrackedDeviceBuffer( const Shape& on_device_shape, absl::InlinedVector, 4> @@ -264,14 +233,14 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { // device buffer from the host buffer (maybe zero-copy or async). // `transpose_mu` and `transpose_cache` are used to transpose the input // layout. - static StatusOr> + static absl::StatusOr> BufferFromHostBufferHelper( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, PjRtClient::HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, const Shape& shape, - AsyncWorkRunner* async_work_runner, absl::Mutex* transpose_mu, - TransposePlanCache* transpose_cache); + absl::AnyInvocable on_done_with_host_buffer, + const Shape& shape, AsyncWorkRunner* async_work_runner, + absl::Mutex* transpose_mu, TransposePlanCache* transpose_cache); protected: virtual absl::string_view buffer_name() const = 0; @@ -279,11 +248,11 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { PjRtFuture ToLiteralHelper(MutableLiteralBase* literal, AsyncWorkRunner* async_work_runner); - StatusOr> CopyToDeviceAcrossClients( + absl::StatusOr> CopyToDeviceAcrossClients( PjRtDevice* dst_device); - StatusOr> CopyToDeviceHelper( - AsyncWorkRunner* async_work_runner); + absl::StatusOr> + CopyToDeviceHelper(AsyncWorkRunner* async_work_runner); bool IsEmptyTuple() const { return on_device_shape_.IsTuple() && @@ -315,7 +284,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { // If the buffer was shared via an external reference it is the client's // responsibility that accesses via that reference do not interfere with // accesses via the buffer returned from Release. - StatusOr> Release( + absl::StatusOr> Release( bool wait_for_operations_to_complete); // Releases the device buffer by returning a unique_ptr of it. If there is diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index a869f9fb8bbd3..559a3fa77f927 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,9 @@ limitations under the License. #include #include +#include "xla/pjrt/cpu/cpu_topology.h" +#include "xla/pjrt/pjrt_compiler.h" + #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" @@ -53,6 +56,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -62,6 +66,7 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/semaphore.h" @@ -72,9 +77,12 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/buffer_desc.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/cpu_executable.h" +#include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/cpu_xfeed.h" +#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/dump.h" #include "xla/service/executable.h" @@ -93,6 +101,7 @@ limitations under the License. #include "tsl/concurrency/async_value.h" #include "tsl/concurrency/async_value_ref.h" #include "tsl/concurrency/ref_count.h" +#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/casts.h" #include "tsl/platform/denormal.h" #include "tsl/platform/env.h" @@ -109,7 +118,7 @@ namespace { using ::xla::runtime::CpuEvent; -StatusOr> AllocateDestinationBuffer( +absl::StatusOr> AllocateDestinationBuffer( const Shape& on_device_shape, absl::InlinedVector, 4> definition_events, TfrtCpuDevice* device, TfrtCpuClient* client) { @@ -121,7 +130,7 @@ StatusOr> AllocateDestinationBuffer( on_device_shape, std::move(tracked_device_buffer), client, device); } -StatusOr> AllocateDestinationBufferAndAvs( +absl::StatusOr> AllocateDestinationBufferAndAvs( const Shape& shape, absl::InlinedVector, 4>* avs, TfrtCpuDevice* device, TfrtCpuClient* client) { @@ -178,7 +187,8 @@ class ThreadPoolAsyncWorkRunner : public AsyncWorkRunner { class TfrtCpuAsyncHostToDeviceTransferManager : public AbstractAsyncHostToHostMemoryTransferManager { public: - static StatusOr> + static absl::StatusOr< + std::unique_ptr> Create(absl::Span shapes, TfrtCpuDevice* device, TfrtCpuClient* client) { absl::InlinedVector, 4> buffers; @@ -259,6 +269,36 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { return to_string_; } +/*static*/ TfrtCpuTopologyDescription TfrtCpuTopologyDescription::Create( + PjRtPlatformId platform_id, absl::string_view platform_name, + absl::string_view platform_version, + absl::Span> devices, + absl::Span machine_attributes) { + std::vector cpu_devices; + cpu_devices.reserve(devices.size()); + for (auto& device : devices) { + cpu_devices.push_back( + {device->id(), device->process_index(), device->local_hardware_id()}); + } + return TfrtCpuTopologyDescription(platform_id, platform_name, + platform_version, cpu_devices, + machine_attributes); +} + +absl::StatusOr TfrtCpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + +absl::StatusOr TfrtCpuTopologyDescription::Serialize() const { + std::string result; + if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) { + return absl::InternalError("Failed to serialize cpu_topology"); + } + return result; +} + TfrtCpuDevice::TfrtCpuDevice(int id, int process_index, int local_hardware_id, int max_inflight_computations) : description_(id, process_index, local_hardware_id), @@ -277,7 +317,7 @@ absl::Span TfrtCpuDevice::memory_spaces() const { return {}; } -StatusOr TfrtCpuDevice::default_memory_space() const { +absl::StatusOr TfrtCpuDevice::default_memory_space() const { return Unimplemented("default_memory_space is not supported"); } @@ -288,7 +328,7 @@ static int CpuDeviceCount() { return GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } -StatusOr> GetTfrtCpuClient( +absl::StatusOr> GetTfrtCpuClient( const CpuClientOptions& options) { // Need at least CpuDeviceCount threads to launch one collective. int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); @@ -311,10 +351,10 @@ StatusOr> GetTfrtCpuClient( } GlobalTopologyProto global_topology; - TF_RETURN_IF_ERROR( - ExchangeTopologies("cpu", options.node_id, options.num_nodes, - absl::Minutes(2), absl::Minutes(5), options.kv_get, - options.kv_put, local_topology, &global_topology)); + TF_RETURN_IF_ERROR(ExchangeTopologies( + "cpu", options.node_id, options.num_nodes, absl::Minutes(2), + absl::Minutes(5), options.kv_store.get(), local_topology, + &global_topology)); std::vector> devices; for (const LocalTopologyProto& node : global_topology.nodes()) { @@ -328,17 +368,27 @@ StatusOr> GetTfrtCpuClient( } return std::unique_ptr(std::make_unique( - /*process_index=*/options.node_id, std::move(devices), num_threads)); + /*process_index=*/options.node_id, std::move(devices), + std::move(options.collectives), num_threads)); +} + +static tsl::ThreadOptions GetThreadOptions() { + tsl::ThreadOptions thread_options; + // On Mac OS the default stack size is 512KiB, which is too small for some + // BLAS and LAPACK functions (https://github.com/google/jax/issues/20428). + thread_options.stack_size = 2 * 1024 * 1024; + return thread_options; } TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, - size_t num_threads) + std::shared_ptr collectives, size_t num_threads) : process_index_(process_index), owned_devices_(std::move(devices)), computation_placer_(std::make_unique()), - pjrt_client_thread_pool_(new tsl::thread::ThreadPool( - tsl::Env::Default(), "XLATfrtCpuClient", num_threads)), + pjrt_client_thread_pool_( + new tsl::thread::ThreadPool(tsl::Env::Default(), GetThreadOptions(), + "XLATfrtCpuClient", num_threads)), async_work_runner_(std::make_unique( pjrt_client_thread_pool_.get())), eigen_intraop_pool_(new tsl::thread::ThreadPool( @@ -348,7 +398,11 @@ TfrtCpuClient::TfrtCpuClient( eigen_intraop_pool_->NumThreads())), last_collective_launch_event_( tsl::MakeAvailableAsyncValueRef()), - transpose_cache_(1024) { + transpose_cache_(1024), + collectives_(std::move(collectives)), + topology_(TfrtCpuTopologyDescription::Create( + platform_id(), platform_name(), platform_version(), owned_devices_, + cpu::DetectMachineAttributes())) { for (const std::unique_ptr& device : owned_devices_) { devices_.push_back(device.get()); CHECK(id_to_device_.insert({device->id(), device.get()}).second) @@ -372,37 +426,63 @@ TfrtCpuClient::TfrtCpuClient( TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; } -StatusOr TfrtCpuClient::LookupDevice(int device_id) const { - auto it = id_to_device_.find(device_id); +absl::StatusOr TfrtCpuClient::LookupDevice(int device_id) const { + return LookupDevice(PjRtGlobalDeviceId(device_id)); +} + +absl::StatusOr TfrtCpuClient::LookupDevice( + xla::PjRtGlobalDeviceId global_device_id) const { + auto it = id_to_device_.find(global_device_id.value()); if (it != id_to_device_.end()) { return it->second; } return InvalidArgument("No matching device found for device_id %d", - device_id); + global_device_id.value()); } -StatusOr TfrtCpuClient::LookupAddressableDevice( +absl::StatusOr TfrtCpuClient::LookupAddressableDevice( int local_hardware_id) const { + return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); +} + +absl::StatusOr TfrtCpuClient::LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const { for (auto* device : addressable_devices_) { - if (local_hardware_id == device->local_hardware_id()) { + if (local_device_id == device->local_device_id()) { return device; } } - return InvalidArgument("No matching device found for local_hardware_id %d", - local_hardware_id); + return InvalidArgument("No matching device found for local_device_id %d", + local_device_id.value()); } absl::Span TfrtCpuClient::memory_spaces() const { return {}; } -StatusOr TfrtCpuClient::GetDefaultDeviceAssignment( +absl::StatusOr TfrtCpuClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { + if (num_partitions * num_replicas <= addressable_devices().size()) { + xla::DeviceAssignment assignment(num_replicas, num_partitions); + for (int i = 0; i < num_replicas; ++i) { + for (int j = 0; j < num_partitions; ++j) { + assignment(i, j) = + addressable_devices().at(i * num_partitions + j)->id(); + } + } + return assignment; + } return computation_placer_->AssignDevices(num_replicas, num_partitions); } -StatusOr> TfrtCpuClient::GetHloCostAnalysis() - const { +absl::StatusOr TfrtCpuClient::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + +absl::StatusOr> +TfrtCpuClient::GetHloCostAnalysis() const { return std::make_unique(cpu::CpuExecutable::ShapeSizeBytes); } @@ -419,7 +499,7 @@ static const InstructionValueSet& GetRootValueSet( // assignment that make up for the output buffer. It is used by // CreateResultShapedBuffer to reconstruct the output buffer from the buffer // table allocated by MemoryForAllocation. -static StatusOr> +static absl::StatusOr> FindResultBufferAllocationIndex(const BufferAssignment& assignment, const HloModule& module) { absl::InlinedVector buffer_indices; @@ -456,7 +536,7 @@ FindResultBufferAllocationIndex(const BufferAssignment& assignment, return {std::move(buffer_indices)}; } -StatusOr TfrtCpuExecutable::SerializeExecutable() const { +absl::StatusOr TfrtCpuExecutable::SerializeExecutable() const { cpu::CpuCompiler compiler; TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, compiler.Export(cpu_executable_.get())); @@ -473,7 +553,7 @@ StatusOr TfrtCpuExecutable::SerializeExecutable() const { return proto.SerializeAsString(); } -StatusOr> +absl::StatusOr> TfrtCpuClient::DeserializeExecutable(absl::string_view serialized, std::optional options) { ExecutableAndOptionsProto proto; @@ -581,19 +661,19 @@ TfrtCpuClient::DeserializeExecutable(absl::string_view serialized, return std::unique_ptr(std::move(tfrt_cpu_executable)); } -static StatusOr> JitCompile( +static absl::StatusOr> JitCompile( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, + const xla::Compiler::CompileOptions& compile_options, int num_threads) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); // Unoptimized HloModuleConfig. TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options, - execution_options.num_replicas(), - /*num_threads=*/std::nullopt, + execution_options.num_replicas(), num_threads, /*aot_options=*/nullptr)); // Unoptimized HloModule. @@ -606,20 +686,17 @@ static StatusOr> JitCompile( DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName); // Run Hlo Passes - bool allow_sparse_shapes = - hlo_module->config().debug_options().xla_cpu_use_xla_runtime(); - cpu::CpuCompiler compiler(allow_sparse_shapes); - xla::Compiler::CompileOptions dummy; - TF_ASSIGN_OR_RETURN(hlo_module, - compiler.RunHloPasses(std::move(hlo_module), - /*stream_exec=*/nullptr, dummy)); + cpu::CpuCompiler compiler; + TF_ASSIGN_OR_RETURN(hlo_module, compiler.RunHloPasses(std::move(hlo_module), + /*stream_exec=*/nullptr, + compile_options)); // Run backend. return compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr, - dummy); + compile_options); } -StatusOr> TfrtCpuClient::Compile( +absl::StatusOr> TfrtCpuClient::Compile( const XlaComputation& computation, CompileOptions options) { tsl::profiler::TraceMe traceme("TfrtCpuClient::Compile"); auto input_options = options; @@ -637,9 +714,7 @@ StatusOr> TfrtCpuClient::Compile( }, &num_replicas, &num_partitions, &device_assignment)); - // TODO(phawkins): cross-process computations aren't implemented yet. Check - // for these and error. - if (device_assignment) { + if (collectives_ == nullptr && device_assignment) { for (int replica = 0; replica < device_assignment->replica_count(); ++replica) { for (int computation = 0; @@ -648,6 +723,8 @@ StatusOr> TfrtCpuClient::Compile( int id = (*device_assignment)(replica, computation); TF_ASSIGN_OR_RETURN(auto* device, LookupDevice(id)); if (device->process_index() != process_index()) { + // TODO(phawkins): improve this error message when we're ready to + // publicize that multiprocess collectives exist. return InvalidArgument( "Multiprocess computations aren't implemented on the CPU " "backend."); @@ -698,9 +775,14 @@ StatusOr> TfrtCpuClient::Compile( computation.GetProgramShape()); ExecutionOptions execution_options = CreateExecutionOptions(build_options, &program_shape); - TF_ASSIGN_OR_RETURN(std::unique_ptr cpu_executable, - JitCompile(computation, argument_layout_pointers, - build_options, execution_options)); + xla::Compiler::CompileOptions compile_options{ + build_options.device_allocator(), build_options.compile_thread_pool(), + build_options.layout_canonicalization_callback()}; + TF_ASSIGN_OR_RETURN( + std::unique_ptr cpu_executable, + JitCompile(computation, argument_layout_pointers, build_options, + execution_options, compile_options, + eigen_intraop_device()->getPool()->NumThreads())); auto cpu_executable_ptr = tensorflow::down_cast(cpu_executable.get()); @@ -731,17 +813,18 @@ StatusOr> TfrtCpuClient::Compile( return std::unique_ptr(std::move(executable)); } -StatusOr> TfrtCpuClient::Compile( +absl::StatusOr> TfrtCpuClient::Compile( mlir::ModuleOp module, CompileOptions options) { XlaComputation xla_computation; TF_RETURN_IF_ERROR(MlirToXlaComputation( module, xla_computation, /*use_tuple_args=*/options.parameter_is_tupled_arguments, - /*return_tuple=*/false, /*legalize_sparse_ops=*/true)); + /*return_tuple=*/false)); return Compile(xla_computation, options); } -StatusOr> TfrtCpuClient::CreateViewOfDeviceBuffer( +absl::StatusOr> +TfrtCpuClient::CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback, std::optional stream) { @@ -764,7 +847,7 @@ StatusOr> TfrtCpuClient::CreateViewOfDeviceBuffer( tensorflow::down_cast(device))); } -StatusOr> TfrtCpuClient::CreateErrorBuffer( +absl::StatusOr> TfrtCpuClient::CreateErrorBuffer( Status error, const Shape& shape, PjRtDevice* device) { return std::make_unique( shape, @@ -777,8 +860,9 @@ StatusOr> TfrtCpuClient::CreateErrorBuffer( this, tensorflow::down_cast(device)); } -StatusOr> TfrtCpuClient::CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device) { +absl::StatusOr> +TfrtCpuClient::CreateUninitializedBuffer(const Shape& shape, + PjRtDevice* device) { tsl::profiler::TraceMe traceme("TfrtCpuClient::CreateUninitializedBuffer"); VLOG(1) << "TfrtCpuClient::CreateUninitializedBuffer: shape: " << shape.DebugString() << " device: " << device->DebugString(); @@ -787,7 +871,7 @@ StatusOr> TfrtCpuClient::CreateUninitializedBuffer( tensorflow::down_cast(device), this); } -StatusOr> +absl::StatusOr> TfrtCpuClient::CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) { auto* tfrt_device = tensorflow::down_cast(device); @@ -795,11 +879,12 @@ TfrtCpuClient::CreateBuffersForAsyncHostToDevice(absl::Span shapes, this); } -StatusOr> TfrtCpuClient::BufferFromHostBuffer( +absl::StatusOr> TfrtCpuClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device) { + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device) { tsl::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostBuffer"); Shape shape = ShapeUtil::MakeShape(type, dims); VLOG(2) << "TfrtCpuClient::BufferFromHostBuffer: shape: " << shape.ToString() @@ -821,8 +906,9 @@ StatusOr> TfrtCpuClient::BufferFromHostBuffer( tensorflow::down_cast(device))); } -StatusOr> TfrtCpuClient::BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device) { +absl::StatusOr> +TfrtCpuClient::BufferFromHostLiteral(const LiteralSlice& literal, + PjRtDevice* device) { tsl::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostLiteral"); VLOG(1) << "TfrtCpuClient::BufferFromHostLiteral: shape: " << literal.shape().DebugString() @@ -859,13 +945,22 @@ static std::vector> CopyAsyncValues( return avs; } -PjRtFuture TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal) { +PjRtFuture TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal) { return ToLiteralHelper(literal, client()->async_work_runner()); } +PjRtFuture TfrtCpuBuffer::LazyToLiteral( + absl::AnyInvocable() &&> generator) { + auto buffer = std::move(generator)(); + if (!buffer.ok()) { + return PjRtFuture(buffer.status()); + } + return ToLiteralHelper(buffer.value(), client()->async_work_runner()); +} + // TODO(zhangqiaorjc): Consider disallowing multiple CPU devices and assign // multiple pmap replicas to the same CPU device for multi-CPU pmap testing. -StatusOr> TfrtCpuBuffer::CopyToDevice( +absl::StatusOr> TfrtCpuBuffer::CopyToDevice( PjRtDevice* dst_device) { tsl::profiler::TraceMe traceme("TfrtCpuBuffer::CopyToDevice"); // TODO(zhangqiaorjc): Remove this restriction after removing the test that @@ -952,7 +1047,8 @@ void TfrtCpuExecutable::Delete() {} bool TfrtCpuExecutable::IsDeleted() { return false; } -StatusOr> TfrtCpuExecutable::Fingerprint() const { +absl::StatusOr> TfrtCpuExecutable::Fingerprint() + const { return std::optional(); } @@ -965,7 +1061,8 @@ Status TfrtCpuExecutable::SetUpDonation(bool tuple_inputs) { // The following few helpers are adapted from XLA:CPU to create a buffer table // and assemble the buffer pointers in order to call into CpuExecutable. -static StatusOr> MemoryForAllocation( +static absl::StatusOr> +MemoryForAllocation( const BufferAllocation& allocation, absl::Span const> arguments) { if (allocation.is_entry_computation_parameter()) { @@ -1004,7 +1101,7 @@ static StatusOr> MemoryForAllocation( return out; } -static StatusOr>> +static absl::StatusOr>> CreateBufferTable( const BufferAssignment& assignment, absl::Span const> arguments) { @@ -1062,7 +1159,7 @@ static std::vector MakeXLARuntimeDescriptorTable( return descriptor_table; } -StatusOr TfrtCpuExecutable::ExecuteHelper( +absl::StatusOr TfrtCpuExecutable::ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, tsl::AsyncValueRef last_collective_launch_event, bool fill_future, @@ -1145,8 +1242,8 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( tfrt_buffer, donation_clashes, must_donate, i, replica, partition)); if (must_donate) { ++donate_it; - StatusOr donation_transaction = - tfrt_buffer->AcquireDonation(); + absl::StatusOr + donation_transaction = tfrt_buffer->AcquireDonation(); // On CPU, we allow donation to succeed by introducing a copy. This was // added when enabling buffer donation on CPU since it turned out that a // number of users were holding external references to buffers that were @@ -1236,6 +1333,10 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( run_options.set_device_assignment(device_assignment.get()); run_options.set_intra_op_thread_pool(client_->eigen_intraop_device()); + auto cpu_run_options = std::make_shared(); + cpu_run_options->set_collectives(client_->collectives_.get()); + run_options.set_cpu_executable_run_options(cpu_run_options.get()); + // Schedule only one collective at a time. bool is_a_collective_launch = !!last_collective_launch_event; if (is_a_collective_launch) { @@ -1280,7 +1381,7 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( std::optional error_message = xla::CustomCallStatusGetMessage(&status); if (error_message) { - return InternalError("Generated function failed: %s", *error_message); + return Internal("Generated function failed: %s", *error_message); } } else { @@ -1304,6 +1405,7 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( run_options = std::move(run_options), cpu_executable_copy = cpu_executable_, device_assignment = std::move(device_assignment), + cpu_run_options = std::move(cpu_run_options), compute_reservation = std::move(compute_reservation), tuplized_arg = std::move(tuplized_arg), donation_transactions = std::move(donation_transactions), @@ -1391,7 +1493,7 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( [done_event = done_event.CopyRef(), event = execute_event.CopyRef()]() { Status s; if (auto* error = event.GetErrorIfPresent()) { - s = InternalError("Compute error: %s", error->message()); + s = Internal("Compute error: %s", error->message()); } done_event.emplace(std::move(s)); }); @@ -1434,7 +1536,7 @@ static void MaybeDumpHloSnapshot( hlo_snapshot.SerializeAsString()); } -StatusOr>>> +absl::StatusOr>>> TfrtCpuExecutable::Execute( absl::Span> argument_handles, const ExecuteOptions& options, @@ -1557,7 +1659,7 @@ TfrtCpuExecutable::Execute( return wrapped_results; } -StatusOr>> +absl::StatusOr>> TfrtCpuExecutable::ExecuteSharded( absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options, @@ -1588,7 +1690,7 @@ TfrtCpuExecutable::ExecuteSharded( device->id()); } -StatusOr>> +absl::StatusOr>> TfrtCpuExecutable::ExecutePortable( absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options, diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index a350ce7d843b9..b302243ecfecd 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -38,8 +40,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" +#include "xla/pjrt/cpu/cpu_topology.h" #include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/semaphore.h" @@ -47,11 +53,13 @@ limitations under the License. #include "xla/runtime/cpu_event.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" +#include "xla/service/cpu/collectives_interface.h" #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/status.h" +#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/concurrency/async_value_ref.h" @@ -61,6 +69,8 @@ limitations under the License. namespace xla { +class TfrtCpuDevice; // forward declare + class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { public: TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); @@ -91,6 +101,100 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { absl::flat_hash_map attributes_ = {}; }; +class TfrtCpuTopologyDescription : public PjRtTopologyDescription { + public: + static TfrtCpuTopologyDescription Create( + PjRtPlatformId platform_id, absl::string_view platform_name, + absl::string_view platform_version, + absl::Span> devices, + absl::Span machine_attributes); + + // `cpu_device_ids` is the list of logical device ids for the CPU devices and + // will be used to initialize the CPU topology. + TfrtCpuTopologyDescription( + const PjRtPlatformId platform_id, const absl::string_view platform_name, + const absl::string_view platform_version, + const std::vector cpu_devices, + absl::Span machine_attributes) + : platform_id_(platform_id), + platform_name_(platform_name), + platform_version_(platform_version), + cpu_topology_(std::move(cpu_devices), + std::vector(machine_attributes.begin(), + machine_attributes.end())) {} + + bool operator==(const TfrtCpuTopologyDescription& other) const { + return this->platform_id() == other.platform_id() && + this->platform_name() == other.platform_name() && + this->platform_version() == other.platform_version() && + this->cpu_topology().devices() == other.cpu_topology().devices(); + } + + PjRtPlatformId platform_id() const override { return platform_id_; } + + absl::string_view platform_name() const override { return platform_name_; } + + absl::string_view platform_version() const override { + return platform_version_; + } + + std::vector> DeviceDescriptions() + const override { + std::vector> devices; + devices.reserve(cpu_topology_.number_of_devices()); + for (const CpuTopology::CpuDevice& device : cpu_topology_.devices()) { + devices.push_back(std::make_unique( + device.id, device.process_index, device.local_hardware_id)); + } + return devices; + } + + const CpuTopology& cpu_topology() const { return cpu_topology_; } + const CpuTopology* cpu_topology_ptr() const { return &cpu_topology_; } + + // No subslice is supported. + bool is_subslice_topology() const override { return false; } + + // TODO(b/319478189): We support multi-host CPU computations and should + // correctly report process count. + absl::StatusOr ProcessCount() const override { return 1; } + + absl::StatusOr CoreCountOfDefaultType() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr LogicalDeviceCountOfDefaultType() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr CoreCountOfDefaultTypePerProcess() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr CoreCountOfDefaultTypePerChip() const override { + return 1; + } + + absl::StatusOr Serialize() const override; + + // Returns vendor specific attributes about the topology. + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + + private: + const PjRtPlatformId platform_id_; + const std::string platform_name_; + const std::string platform_version_; + const CpuTopology cpu_topology_; + absl::flat_hash_map attributes_; +}; + class TfrtCpuDevice final : public PjRtDevice { public: explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, @@ -112,7 +216,15 @@ class TfrtCpuDevice final : public PjRtDevice { } int local_hardware_id() const override { - return description_.local_hardware_id(); + return local_hardware_id_typed().value(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(local_hardware_id_typed().value()); + } + + PjRtLocalHardwareId local_hardware_id_typed() const override { + return PjRtLocalHardwareId(description_.local_hardware_id()); } Status TransferToInfeed(const LiteralSlice& literal) override; @@ -121,7 +233,7 @@ class TfrtCpuDevice final : public PjRtDevice { absl::Span memory_spaces() const override; - StatusOr default_memory_space() const override; + absl::StatusOr default_memory_space() const override; // Returns a semaphore for admission control on inflight computations. Semaphore& max_inflight_computations_semaphore() { @@ -147,6 +259,7 @@ class TfrtCpuClient final : public PjRtClient { public: TfrtCpuClient(int process_index, std::vector> devices, + std::shared_ptr collectives, size_t num_threads); ~TfrtCpuClient() override; @@ -164,10 +277,14 @@ class TfrtCpuClient final : public PjRtClient { return addressable_devices_; } - StatusOr LookupDevice(int device_id) const override; + absl::StatusOr LookupDevice(int device_id) const override; + absl::StatusOr LookupDevice( + PjRtGlobalDeviceId global_device_id) const override; - StatusOr LookupAddressableDevice( + absl::StatusOr LookupAddressableDevice( int local_hardware_id) const override; + absl::StatusOr LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const override; absl::Span memory_spaces() const override; @@ -181,31 +298,34 @@ class TfrtCpuClient final : public PjRtClient { PjRtRuntimeType runtime_type() const override { return kTfrt; } - StatusOr GetDefaultDeviceAssignment( + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; - StatusOr> GetHloCostAnalysis() + absl::StatusOr GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) override; + + absl::StatusOr> GetHloCostAnalysis() const override; - StatusOr> Compile( + absl::StatusOr> Compile( const XlaComputation& computation, CompileOptions options) override; - StatusOr> Compile( + absl::StatusOr> Compile( mlir::ModuleOp module, CompileOptions options) override; // For TfrtCpuClient, `options` is mandatory. // This function returns an InvalidArgument error if `std::nullopt` is passed. // TODO(b/237720161): make it actually optional - StatusOr> DeserializeExecutable( + absl::StatusOr> DeserializeExecutable( absl::string_view serialized, std::optional options) override; - StatusOr> CreateErrorBuffer( + absl::StatusOr> CreateErrorBuffer( Status error, const Shape& shape, PjRtDevice* device) override; - StatusOr> CreateUninitializedBuffer( + absl::StatusOr> CreateUninitializedBuffer( const Shape& shape, PjRtDevice* device) override; - StatusOr> + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; @@ -216,24 +336,24 @@ class TfrtCpuClient final : public PjRtClient { "CreateBuffersForAsyncHostToDevice with memory_space not implemented."); } - StatusOr> BufferFromHostBuffer( + absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device) override; - StatusOr> BufferFromHostLiteral( + absl::StatusOr> BufferFromHostLiteral( const LiteralSlice& literal, PjRtDevice* device) override; - StatusOr>> + absl::StatusOr>> MakeCrossHostReceiveBuffers(absl::Span shapes, PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { return Unimplemented("MakeCrossHostReceiveBuffers not implemented."); } - StatusOr>> + absl::StatusOr>> MakeCrossHostReceiveBuffersForGather( absl::Span shapes, std::vector gather_details, PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { @@ -241,18 +361,18 @@ class TfrtCpuClient final : public PjRtClient { "MakeCrossHostReceiveBuffersForGather not implemented."); } - StatusOr> CreateViewOfDeviceBuffer( + absl::StatusOr> CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback, std::optional stream) override; - StatusOr CreateChannelHandle() override { + absl::StatusOr CreateChannelHandle() override { return Unimplemented("CreateChannelHandle not implemented."); } - StatusOr CreateDeviceToHostChannelHandle() override { + absl::StatusOr CreateDeviceToHostChannelHandle() override { return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); } - StatusOr CreateHostToDeviceChannelHandle() override { + absl::StatusOr CreateHostToDeviceChannelHandle() override { return Unimplemented("CreateHostToDeviceChannelHandle not implemented."); } @@ -283,7 +403,14 @@ class TfrtCpuClient final : public PjRtClient { last_collective_launch_event_ = std::move(event); } + absl::StatusOr GetTopologyDescription() + const override { + return &topology_; + } + private: + friend class TfrtCpuExecutable; + int process_index_; // Includes all devices, including non-addressable devices. std::vector> owned_devices_; @@ -322,6 +449,10 @@ class TfrtCpuClient final : public PjRtClient { // major-to-minor layout. absl::Mutex transpose_mu_; TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); + + std::shared_ptr collectives_; + + xla::TfrtCpuTopologyDescription topology_; }; class TfrtCpuBuffer final : public AbstractTfrtCpuBuffer { @@ -342,8 +473,11 @@ class TfrtCpuBuffer final : public AbstractTfrtCpuBuffer { using PjRtBuffer::ToLiteralSync; PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + PjRtFuture LazyToLiteral( + absl::AnyInvocable() &&> generator) + override; - StatusOr> CopyToDevice( + absl::StatusOr> CopyToDevice( PjRtDevice* dst_device) override; private: @@ -394,18 +528,18 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { return addressable_devices_; } - StatusOr>> GetHloModules() + absl::StatusOr>> GetHloModules() const override { return std::vector>{ cpu_executable_->shared_module()}; } - StatusOr>> GetOutputMemoryKinds() - const override { + absl::StatusOr>> + GetOutputMemoryKinds() const override { return Unimplemented("GetOutputMemoryKinds is not supported."); } - StatusOr GetCompiledMemoryStats() const override { + absl::StatusOr GetCompiledMemoryStats() const override { CompiledMemoryStats memory_stats = CompiledMemoryStats(); memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); const HloProto* proto = cpu_executable_->hlo_proto(); @@ -418,21 +552,21 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { } using PjRtLoadedExecutable::Execute; - StatusOr>>> Execute( + absl::StatusOr>>> Execute( absl::Span> argument_handles, const ExecuteOptions& options, std::optional>>& returned_futures) override; using PjRtLoadedExecutable::ExecuteSharded; - StatusOr>> ExecuteSharded( + absl::StatusOr>> ExecuteSharded( absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options, std::optional>& returned_future, bool fill_future) override; using PjRtLoadedExecutable::ExecutePortable; - StatusOr>> ExecutePortable( + absl::StatusOr>> ExecutePortable( absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options, std::optional>& returned_future, @@ -442,18 +576,22 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { bool IsDeleted() override; - StatusOr SerializeExecutable() const override; + absl::StatusOr SerializeExecutable() const override; bool IsReturnedFutureSupported() const override { return true; } - StatusOr> Fingerprint() const; + absl::StatusOr> Fingerprint() const; std::shared_ptr cpu_executable() const { return cpu_executable_; } - StatusOr FingerprintExecutable() const override { + absl::StatusOr FingerprintExecutable() const override { return Unimplemented("Fingerprinting executable is not supported."); } + absl::StatusOr GetCompileOptions() const override { + return compile_options_; + } + private: friend class TfrtCpuClient; @@ -465,7 +603,7 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { absl::Span const> input_buffers) const; - StatusOr ExecuteHelper( + absl::StatusOr ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, tsl::AsyncValueRef last_collective_launch_event, @@ -532,15 +670,18 @@ struct CpuClientOptions { // My node ID. int node_id = 0; - // KV store primitives for sharing topology information. - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; + // KV store for sharing topology information. + std::shared_ptr kv_store = nullptr; + + // Distributed collectives implementation. Optional. If not provided, an + // in-process collectives implementation will be used. + std::shared_ptr collectives; }; -StatusOr> GetTfrtCpuClient( +absl::StatusOr> GetTfrtCpuClient( const CpuClientOptions& options); // Deprecated. Use the overload that takes 'options' instead. -inline StatusOr> GetTfrtCpuClient( +inline absl::StatusOr> GetTfrtCpuClient( bool asynchronous) { CpuClientOptions options; options.asynchronous = asynchronous; @@ -548,7 +689,7 @@ inline StatusOr> GetTfrtCpuClient( } // Deprecated. Use the overload that takes 'options' instead. -inline StatusOr> GetTfrtCpuClient( +inline absl::StatusOr> GetTfrtCpuClient( bool asynchronous, int cpu_device_count, int max_inflight_computations_per_device = 32) { CpuClientOptions options; diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index f23aab9f808ee..181db327a3653 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,9 @@ limitations under the License. #include "xla/pjrt/cpu/cpu_client.h" +#ifndef _WIN32 #include +#endif #include #include @@ -272,7 +274,7 @@ TEST(TfrtCpuClientTest, AsyncTransferSetBufferError) { client->CreateBuffersForAsyncHostToDevice( {shape}, client->addressable_devices()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); - transfer_manager->SetBufferError(0, InternalError("foobar")); + transfer_manager->SetBufferError(0, Internal("foobar")); EXPECT_THAT( buffer->ToLiteralSync(), tsl::testing::StatusIs(tsl::error::INTERNAL, HasSubstr("foobar"))); @@ -282,7 +284,7 @@ TEST(TfrtCpuClientTest, CreateErrorBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN( - auto buffer, client->CreateErrorBuffer(InternalError("foobar"), shape, + auto buffer, client->CreateErrorBuffer(Internal("foobar"), shape, client->addressable_devices()[0])); EXPECT_THAT( buffer->ToLiteralSync(), diff --git a/xla/pjrt/cpu/cpu_topology.cc b/xla/pjrt/cpu/cpu_topology.cc new file mode 100644 index 0000000000000..72e5ece634f45 --- /dev/null +++ b/xla/pjrt/cpu/cpu_topology.cc @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/cpu_topology.h" + +#include +#include +#include +#include +#include + +namespace xla { + +std::unique_ptr CpuTopology::FromProto( + const CpuTopologyProto& cpu_topology_proto) { + std::vector devices; + devices.reserve(cpu_topology_proto.cpu_devices_size()); + + for (size_t i = 0; i < cpu_topology_proto.cpu_devices_size(); ++i) { + auto& cpu_device_proto = cpu_topology_proto.cpu_devices(i); + devices.push_back({cpu_device_proto.id(), cpu_device_proto.process_index(), + cpu_device_proto.local_hardware_id()}); + } + + std::vector machine_attributes; + machine_attributes.reserve(cpu_topology_proto.machine_attributes_size()); + for (size_t i = 0; i < cpu_topology_proto.machine_attributes_size(); ++i) { + machine_attributes.push_back(cpu_topology_proto.machine_attributes(i)); + } + + return std::make_unique(std::move(devices), + std::move(machine_attributes)); +} + +CpuTopologyProto CpuTopology::ToProto() const { + CpuTopologyProto proto; + for (auto& cpu_device : cpu_devices_) { + auto* cpu_device_proto = proto.add_cpu_devices(); + cpu_device_proto->set_id(cpu_device.id); + cpu_device_proto->set_process_index(cpu_device.process_index); + cpu_device_proto->set_local_hardware_id(cpu_device.local_hardware_id); + } + for (const std::string& machine_attribute : machine_attributes_) { + proto.add_machine_attributes(machine_attribute); + } + return proto; +} + +} // namespace xla diff --git a/xla/pjrt/cpu/cpu_topology.h b/xla/pjrt/cpu/cpu_topology.h new file mode 100644 index 0000000000000..d698f2cb14f5e --- /dev/null +++ b/xla/pjrt/cpu/cpu_topology.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_CPU_TOPOLOGY_H_ +#define XLA_PJRT_CPU_CPU_TOPOLOGY_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/pjrt/cpu/cpu_topology.pb.h" + +namespace xla { +class CpuTopology { + public: + struct CpuDevice { + int id; + int process_index; + int local_hardware_id; + + bool operator==(const CpuDevice& other) const { + return id == other.id && process_index == other.process_index && + local_hardware_id == other.local_hardware_id; + } + }; + + explicit CpuTopology(std::vector cpu_devices, + std::vector machine_attributes) + : cpu_devices_(std::move(cpu_devices)), + machine_attributes_(std::move(machine_attributes)) {} + + int number_of_devices() const { return cpu_devices_.size(); } + absl::Span devices() const { return cpu_devices_; } + absl::Span machine_attributes() const { + return machine_attributes_; + } + + static std::unique_ptr FromProto( + const CpuTopologyProto& proto); + CpuTopologyProto ToProto() const; + + private: + const std::vector cpu_devices_; + const std::vector machine_attributes_; +}; + +} // namespace xla + +#endif // XLA_PJRT_CPU_CPU_TOPOLOGY_H_ diff --git a/xla/pjrt/cpu/cpu_topology.proto b/xla/pjrt/cpu/cpu_topology.proto new file mode 100644 index 0000000000000..85167fc5ffcff --- /dev/null +++ b/xla/pjrt/cpu/cpu_topology.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package xla; + +// A proto used to serialize CpuTopology instances. +message CpuTopologyProto { + message CpuDevice { + int32 id = 1; + int32 process_index = 2; + int32 local_hardware_id = 3; + } + repeated CpuDevice cpu_devices = 1; + repeated string machine_attributes = 4; +} diff --git a/xla/pjrt/cpu/cpu_topology_test.cc b/xla/pjrt/cpu/cpu_topology_test.cc new file mode 100644 index 0000000000000..148f1edc353d8 --- /dev/null +++ b/xla/pjrt/cpu/cpu_topology_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/cpu_topology.h" + +#include + +#include "tsl/platform/protobuf.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(CpuTopology, FromProto) { + CpuTopologyProto msg; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + cpu_devices: + [ { id: 1, process_index: 2, local_hardware_id: 3 }] + machine_attributes: [ "x86_64", "Intel" ] + )pb", + &msg)); + + std::unique_ptr cpu_topology = CpuTopology::FromProto(msg); + EXPECT_EQ(cpu_topology->devices().size(), 1); + EXPECT_EQ(cpu_topology->devices()[0].id, 1); + EXPECT_EQ(cpu_topology->devices()[0].process_index, 2); + EXPECT_EQ(cpu_topology->devices()[0].local_hardware_id, 3); + EXPECT_EQ(cpu_topology->machine_attributes().size(), 2); + EXPECT_EQ(cpu_topology->machine_attributes()[0], "x86_64"); + EXPECT_EQ(cpu_topology->machine_attributes()[1], "Intel"); +} + +TEST(CpuTopology, ToProto) { + CpuTopology cpu_topology({{1, 2, 3}}, {"ab", "cd"}); + CpuTopologyProto msg = cpu_topology.ToProto(); + EXPECT_EQ(msg.cpu_devices_size(), 1); + EXPECT_EQ(msg.cpu_devices(0).id(), 1); + EXPECT_EQ(msg.cpu_devices(0).process_index(), 2); + EXPECT_EQ(msg.cpu_devices(0).local_hardware_id(), 3); + EXPECT_EQ(msg.machine_attributes_size(), 2); + EXPECT_EQ(msg.machine_attributes(0), "ab"); + EXPECT_EQ(msg.machine_attributes(1), "cd"); +} + +} // namespace +} // namespace xla diff --git a/xla/pjrt/cpu/gloo_collectives.cc b/xla/pjrt/cpu/gloo_collectives.cc new file mode 100644 index 0000000000000..2e37676ece999 --- /dev/null +++ b/xla/pjrt/cpu/gloo_collectives.cc @@ -0,0 +1,469 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/gloo_collectives.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "gloo/algorithm.h" // from @gloo +#include "gloo/allgather.h" // from @gloo +#include "gloo/allreduce.h" // from @gloo +#include "gloo/context.h" // from @gloo +#include "gloo/math.h" // from @gloo +#include "gloo/reduce_scatter.h" // from @gloo +#include "gloo/rendezvous/context.h" // from @gloo +#include "gloo/rendezvous/prefix_store.h" // from @gloo +#include "gloo/rendezvous/store.h" // from @gloo +#include "gloo/transport/device.h" // from @gloo +#include "gloo/transport/unbound_buffer.h" // from @gloo +#include "gloo/types.h" // from @gloo +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla::cpu { + +GlooCollectivesCommunicator::GlooCollectivesCommunicator( + std::shared_ptr context) + : context_(std::move(context)) {} +GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default; + +template +static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, + const void* input_buffer, + void* output_buffer, + size_t num_elements, + gloo::AllreduceOptions& options) { + options.setInput(reinterpret_cast(const_cast(input_buffer)), + num_elements); + options.setOutput(reinterpret_cast(const_cast(output_buffer)), + num_elements); + + using ReductionFn = void (*)(void*, const void*, const void*, size_t); + + switch (reduction_kind) { + case ReductionKind::SUM: + options.setReduceFunction(static_cast(&gloo::sum)); + break; + case ReductionKind::PRODUCT: + options.setReduceFunction(static_cast(&gloo::product)); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + options.setReduceFunction(static_cast(&gloo::min)); + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + options.setReduceFunction(static_cast(&gloo::max)); + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + break; + } + return absl::OkStatus(); +} + +absl::Status GlooCollectivesCommunicator::AllReduce( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + gloo::AllreduceOptions options(context_); + // TODO(phawkins): how to do tags? + // options.setTag(tag); + switch (element_type) { + case S8: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case S16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case U16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case S32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case U32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case S64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case U64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case F16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case BF16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case F32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case F64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case C64: + TF_RETURN_IF_ERROR(SetAllReduceOptions>( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + case C128: + TF_RETURN_IF_ERROR(SetAllReduceOptions>( + reduction_kind, input_buffer, output_buffer, num_elements, options)); + break; + default: + return absl::InvalidArgumentError("Unknown datatype in allreduce"); + } + options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); + options.setTimeout(absl::ToChronoMilliseconds(timeout)); + + try { + gloo::allreduce(options); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo all-reduce failed: ", e.what())); + } + return absl::OkStatus(); +} + +static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; + +absl::Status GlooCollectivesCommunicator::CollectivePermute( + const RendezvousKey& key, size_t num_bytes, std::optional source_rank, + absl::Span target_ranks, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + uint32_t tag = 0; // TODO(phawkins): come up with better tags. + const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag); + try { + std::unique_ptr in; + std::unique_ptr out; + for (int target : target_ranks) { + if (target != context_->rank) { + VLOG(1) << "send from " << context_->rank << " to " << target; + if (!in) { + in = context_->createUnboundBuffer(const_cast(input_buffer), + num_bytes); + } + in->send(target, slot); + } + } + if (source_rank) { + if (*source_rank == context_->rank) { + std::memcpy(output_buffer, input_buffer, num_bytes); + } else { + VLOG(1) << "recv at " << context_->rank << " from " << *source_rank; + out = context_->createUnboundBuffer(output_buffer, num_bytes); + out->recv(*source_rank, slot); + } + } else { + std::memset(output_buffer, 0, num_bytes); + } + VLOG(1) << "wait for send at " << context_->rank; + auto deadline = absl::ToChronoTime(absl::Now() + timeout); + if (in) { + in->waitSend(deadline); + } + VLOG(1) << "wait for recv at " << context_->rank; + if (out) { + out->waitRecv(deadline); + } + VLOG(1) << "done waiting at " << context_->rank; + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo collective permute failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCollectivesCommunicator::AllToAll( + const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) { + // We can't use Gloo's all-to-all implementation directly because it assumes + // that the inputs and outputs are contiguous. No big deal; it's just built + // on top of send/recv and we can do the same as it. + uint32_t tag = 0; // TODO(phawkins): use better tags. + int my_rank = context_->rank; + int world_size = context_->size; + + TF_RET_CHECK(world_size == input_buffers.size()); + TF_RET_CHECK(world_size == output_buffers.size()); + + try { + const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag); + std::vector> ins( + context_->size); + std::vector> outs( + context_->size); + for (size_t i = 0; i < world_size; ++i) { + if (i != my_rank) { + ins[i] = context_->createUnboundBuffer( + const_cast(input_buffers[i]), chunk_bytes); + outs[i] = context_->createUnboundBuffer(output_buffers[i], chunk_bytes); + } + } + + for (int i = 1; i < world_size; i++) { + int send_rank = (my_rank + i) % world_size; + int recv_rank = (my_rank + world_size - i) % world_size; + ins[send_rank]->send(send_rank, slot); + outs[recv_rank]->recv(recv_rank, slot); + } + + std::memcpy(output_buffers[my_rank], input_buffers[my_rank], chunk_bytes); + + auto deadline = absl::ToChronoTime(absl::Now() + timeout); + for (int i = 0; i < world_size; i++) { + if (i != my_rank) { + ins[i]->waitSend(deadline); + outs[i]->waitRecv(deadline); + } + } + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo all-to-all failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCollectivesCommunicator::AllGather(const RendezvousKey& key, + size_t chunk_bytes, + const void* input_buffer, + void* output_buffer, + absl::Duration timeout) { + uint32_t tag = 0; // TODO(phawkins): use better tags. + + gloo::AllgatherOptions options(context_); + options.setTag(tag); + options.setTimeout(absl::ToChronoMilliseconds(timeout)); + options.setInput(reinterpret_cast(const_cast(input_buffer)), + chunk_bytes); + options.setOutput(reinterpret_cast(output_buffer), + chunk_bytes * context_->size); + + try { + gloo::allgather(options); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo AllGather failed: ", e.what())); + } + return absl::OkStatus(); +} + +template +absl::Status ReduceScatterHelper(std::shared_ptr context, + ReductionKind reduction_kind, void* buffer, + size_t chunk_elems) { + const gloo::ReductionFunction* reduction_function = nullptr; + if constexpr (is_complex_v) { + switch (reduction_kind) { + case ReductionKind::SUM: + reduction_function = gloo::ReductionFunction::sum; + break; + case ReductionKind::PRODUCT: + reduction_function = gloo::ReductionFunction::product; + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported reduction kind: ", static_cast(reduction_kind))); + } + } else { + switch (reduction_kind) { + case ReductionKind::SUM: + reduction_function = gloo::ReductionFunction::sum; + break; + case ReductionKind::PRODUCT: + reduction_function = gloo::ReductionFunction::product; + break; + case ReductionKind::MAX: + reduction_function = gloo::ReductionFunction::max; + break; + case ReductionKind::MIN: + reduction_function = gloo::ReductionFunction::min; + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported reduction kind: ", static_cast(reduction_kind))); + } + } + try { + std::vector recv_elems(context->size, chunk_elems); + gloo::ReduceScatterHalvingDoubling algorithm( + context, std::vector{reinterpret_cast(buffer)}, + chunk_elems * context->size, recv_elems, reduction_function); + algorithm.run(); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo ReduceScatter failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + size_t chunk_bytes = chunk_elems * primitive_util::ByteWidth(element_type); + std::unique_ptr temp(new char[chunk_bytes * context_->size]); + std::memcpy(temp.get(), input_buffer, chunk_bytes * context_->size); + switch (element_type) { + case S8: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case S16: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case U16: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case S32: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case U32: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case S64: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case U64: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case BF16: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case F16: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case F32: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case F64: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), chunk_elems)); + break; + case C64: + TF_RETURN_IF_ERROR(ReduceScatterHelper>( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + case C128: + TF_RETURN_IF_ERROR(ReduceScatterHelper>( + context_, reduction_kind, temp.get(), chunk_elems)); + break; + default: + return absl::InvalidArgumentError("Unknown datatype in reducescatter"); + } + std::memcpy(output_buffer, temp.get(), chunk_bytes); + return absl::OkStatus(); +} + +GlooCollectives::GlooCollectives( + std::unique_ptr store, + std::shared_ptr device) + : store_(std::move(store)), device_(std::move(device)) {} + +GlooCollectives::~GlooCollectives() = default; + +absl::StatusOr> +GlooCollectives::GetCommunicator( + absl::Span global_devices, int rank) { + absl::MutexLock lock(&mu_); + auto& context = contexts_[std::make_tuple( + std::vector(global_devices.begin(), global_devices.end()), + rank)]; + if (context) { + return context; + } + auto gloo_context = + std::make_shared(rank, global_devices.size()); + auto prefix_store = gloo::rendezvous::PrefixStore( + absl::StrCat("gloo/", + absl::StrJoin(global_devices, ",", + [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })), + *store_); + try { + gloo_context->connectFullMesh(prefix_store, device_); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo context initialization failed: ", e.what())); + } + context = + std::make_shared(std::move(gloo_context)); + return context; +} + +} // namespace xla::cpu diff --git a/xla/pjrt/cpu/gloo_collectives.h b/xla/pjrt/cpu/gloo_collectives.h new file mode 100644 index 0000000000000..e1d2852ddaf3b --- /dev/null +++ b/xla/pjrt/cpu/gloo_collectives.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ +#define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "gloo/context.h" // from @gloo +#include "gloo/rendezvous/store.h" // from @gloo +#include "gloo/transport/device.h" // from @gloo +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class GlooCollectivesCommunicator : public CollectivesCommunicator { + public: + explicit GlooCollectivesCommunicator(std::shared_ptr context); + ~GlooCollectivesCommunicator() override; + + absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, + absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + private: + std::shared_ptr context_; +}; + +class GlooCollectives : public CollectivesInterface { + public: + GlooCollectives(std::unique_ptr store, + std::shared_ptr device); + ~GlooCollectives() override; + + // Thread-safe. + absl::StatusOr> GetCommunicator( + absl::Span devices, int rank) override; + + private: + std::unique_ptr store_; + std::shared_ptr device_; + absl::Mutex mu_; + absl::flat_hash_map, int>, + std::shared_ptr> + contexts_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla::cpu + +#endif // XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ diff --git a/xla/pjrt/cpu/gloo_kv_store.cc b/xla/pjrt/cpu/gloo_kv_store.cc new file mode 100644 index 0000000000000..5747b698ee916 --- /dev/null +++ b/xla/pjrt/cpu/gloo_kv_store.cc @@ -0,0 +1,68 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/gloo_kv_store.h" + +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "gloo/rendezvous/store.h" // from @gloo +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/status_casters.h" + +namespace xla::cpu { + +GlooKeyValueStore::GlooKeyValueStore( + std::shared_ptr kv_store) + : kv_store_(std::move(kv_store)) {} + +GlooKeyValueStore::~GlooKeyValueStore() = default; + +void GlooKeyValueStore::set(const std::string& key, + const std::vector& data) { + ThrowIfError(kv_store_->Set(key, std::string_view(data.data(), data.size()))); +} + +std::vector GlooKeyValueStore::get(const std::string& key) { + std::string result = ValueOrThrow(kv_store_->Get(key, kv_get_timeout_)); + std::vector data(result.begin(), result.end()); + return data; +} + +void GlooKeyValueStore::wait(const std::vector& keys) { + wait(keys, Store::kDefaultTimeout); +} + +void GlooKeyValueStore::wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) { + // TODO(phawkins): add a wait-many feature to the distributed service. + absl::Time deadline = absl::Now() + absl::FromChrono(timeout); + for (const std::string& key : keys) { + absl::Time now = absl::Now(); + if (now >= deadline) { + throw std::runtime_error("Deadline exceeded in wait()"); + } + ThrowIfError(kv_store_->Get(key, deadline - now).status()); + } +} + +} // namespace xla::cpu diff --git a/xla/pjrt/cpu/gloo_kv_store.h b/xla/pjrt/cpu/gloo_kv_store.h new file mode 100644 index 0000000000000..0f9f8e7e18c50 --- /dev/null +++ b/xla/pjrt/cpu/gloo_kv_store.h @@ -0,0 +1,52 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_GLOO_KV_STORE_H_ +#define XLA_PJRT_CPU_GLOO_KV_STORE_H_ + +#include // NOLINT +#include +#include +#include + +#include "absl/time/time.h" +#include "gloo/rendezvous/store.h" // from @gloo +#include "xla/pjrt/distributed/key_value_store_interface.h" + +namespace xla::cpu { + +class GlooKeyValueStore : public ::gloo::rendezvous::Store { + public: + explicit GlooKeyValueStore(std::shared_ptr kv_store); + ~GlooKeyValueStore() override; + + void set(const std::string& key, const std::vector& data) override; + + std::vector get(const std::string& key) override; + + void wait(const std::vector& keys) override; + + void wait(const std::vector& keys, + const std::chrono::milliseconds& timeout) override; + + private: + std::shared_ptr kv_store_; + + absl::Duration kv_get_timeout_ = absl::Minutes(1); +}; + +} // namespace xla::cpu + +#endif // XLA_PJRT_CPU_GLOO_KV_STORE_H_ diff --git a/xla/pjrt/cpu/mpi_collectives.cc b/xla/pjrt/cpu/mpi_collectives.cc new file mode 100644 index 0000000000000..d2c93fd75450f --- /dev/null +++ b/xla/pjrt/cpu/mpi_collectives.cc @@ -0,0 +1,283 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/mpi_collectives.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla::cpu { + +absl::StatusOr PrimitiveTypeToMpiType( + PrimitiveType element_type) { + switch (element_type) { + case S8: + return MPI_INT8_T; + case U8: + case PRED: + return MPI_UINT8_T; + case S16: + return MPI_INT16_T; + case U16: + return MPI_UINT16_T; + case S32: + return MPI_INT32_T; + case U32: + return MPI_UINT32_T; + case S64: + return MPI_INT64_T; + case U64: + return MPI_UINT64_T; + case F32: + return MPI_FLOAT; + case F64: + return MPI_DOUBLE; + case C64: + return MPI_C_COMPLEX; + case C128: + return MPI_C_DOUBLE_COMPLEX; + default: + // For implementing the reduction of unsupported types + // see e.g. https://stackoverflow.com/a/29643391 + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported primitive type for reduction: ", + primitive_util::LowercasePrimitiveTypeName(element_type))); + } +} + +bool MpiTypeIsComplex(MPI_Datatype type) { + return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; +} + +absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, + MPI_Datatype type) { + switch (reduction_kind) { + case ReductionKind::SUM: + return MPI_SUM; + case ReductionKind::PRODUCT: + return MPI_PROD; + case ReductionKind::MIN: + if (!MpiTypeIsComplex(type)) { + return MPI_MIN; + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + case ReductionKind::MAX: + if (!MpiTypeIsComplex(type)) { + return MPI_MAX; + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unknown reduction kind: ", reduction_kind)); + } +} + +static absl::Status MpiErrorToAbslStatus(int error) { + if (error != MPI_SUCCESS) { + char error_str[MPI_MAX_ERROR_STRING]; + int len; + MPI_Error_string(error, error_str, &len); + return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); + } + return absl::OkStatus(); +} + +MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { + MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); + MPI_Comm_rank(comm_, &mpi_rank_); + MPI_Comm_size(comm_, &mpi_size_); +} + +MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { + MPI_Comm_free(&comm_); +}; + +absl::Status MpiCollectivesCommunicator::AllReduce( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Allreduce(input_buffer, output_buffer, + num_elements, type, op, comm_)); +} + +absl::Status MpiCollectivesCommunicator::CollectivePermute( + const RendezvousKey& key, size_t num_bytes, std::optional source_rank, + absl::Span target_ranks, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + int tag = 0; // TODO come up with better tags. + + const int rank = mpi_rank_; + + std::vector requests; + + if (source_rank) { + if (*source_rank == rank) { + std::memcpy(output_buffer, input_buffer, num_bytes); + } else { + VLOG(1) << "recv at " << rank << " from " << *source_rank; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Irecv(output_buffer, num_bytes, MPI_BYTE, *source_rank, tag, + comm_, &requests.back()))); + } + } else { + std::memset(output_buffer, 0, num_bytes); + } + + for (int target : target_ranks) { + if (target != rank) { + VLOG(1) << "send from " << rank << " to " << target; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Isend(input_buffer, num_bytes, MPI_BYTE, target, tag, comm_, + &requests.back()))); + } + } + + for (auto& request : requests) { + TF_RETURN_IF_ERROR( + MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllToAll( + const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) { + // We can't use MPI_Alltoall directly because it assumes that the inputs and + // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. + + int tag = 0; // TODO use better tags. + const int rank = mpi_rank_; + const int size = mpi_size_; + TF_RET_CHECK(size == input_buffers.size()); + TF_RET_CHECK(size == output_buffers.size()); + + std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); + + for (int i = 1; i < size; i++) { + int send_rank = (rank + i) % size; + int recv_rank = (rank + size - i) % size; + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, + tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, + recv_rank, tag, comm_, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key, + size_t chunk_bytes, + const void* input_buffer, + void* output_buffer, + absl::Duration timeout) { + return MpiErrorToAbslStatus(MPI_Allgather(input_buffer, chunk_bytes, MPI_BYTE, + output_buffer, chunk_bytes, + MPI_BYTE, comm_)); +} + +absl::Status MpiCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + const int size = mpi_size_; + std::vector recvcounts(size, chunk_elems); + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Reduce_scatter( + input_buffer, output_buffer, recvcounts.data(), type, op, comm_)); +} + +void MpiCollectives::Init() { + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); + VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; +} + +void MpiCollectives::Finalize() { + contexts_.clear(); + MPI_Finalize(); +} + +absl::StatusOr> +MpiCollectives::GetCommunicator(absl::Span global_devices, + int rank) { + int flag; + MPI_Is_thread_main(&flag); + if (!flag) { + return absl::UnknownError( + absl::StrCat("MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported.")); + } + + auto& context = contexts_[std::make_tuple( + std::vector(global_devices.begin(), global_devices.end()), + rank)]; + if (context) { + return context; + } + + int color; + int key = 0; + if (global_devices.size() > 0) { + color = static_cast(global_devices.at(0).value()); + key = rank; + } else { + color = MPI_UNDEFINED; + } + context = std::make_shared(color, key); + return context; +} + +} // namespace xla::cpu diff --git a/xla/pjrt/cpu/mpi_collectives.h b/xla/pjrt/cpu/mpi_collectives.h new file mode 100644 index 0000000000000..fdf6ec81b6dc6 --- /dev/null +++ b/xla/pjrt/cpu/mpi_collectives.h @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ +#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ + +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCollectivesCommunicator : public CollectivesCommunicator { + public: + explicit MpiCollectivesCommunicator(int color, int key); + ~MpiCollectivesCommunicator() override; + + absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, + absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + private: + MPI_Comm comm_; + int mpi_rank_; + int mpi_size_; +}; + +class MpiCollectives : public CollectivesInterface { + public: + /* + The user has to explicitly call Init() and Finalize() before and + after use. + For example, using the Python client, this can be achieved with: + + collectives = xla_client._xla.make_mpi_collectives() + collectives.Init() + atexit.register(collectives.Finalize) + */ + void Init(); + void Finalize(); + + absl::StatusOr> GetCommunicator( + absl::Span global_devices, int rank) override; + + private: + absl::Status ExchangeGlobalDeviceIds( + absl::Span global_devices, int rank); + + int mpi_world_rank_; + int mpi_world_size_; + absl::flat_hash_map, int>, + std::shared_ptr> + contexts_; +}; + +} // namespace xla::cpu + +#endif // XLA_PJRT_CPU_MPI_COLLECTIVES_H_ diff --git a/xla/pjrt/cpu/pjrt_client_test_cpu.cc b/xla/pjrt/cpu/pjrt_client_test_cpu.cc index ccc2ac8cc2575..f0aea4ec2f326 100644 --- a/xla/pjrt/cpu/pjrt_client_test_cpu.cc +++ b/xla/pjrt/cpu/pjrt_client_test_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc index 5d327e57cb401..e3406e555868f 100644 --- a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc +++ b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/runtime/cpu_event.h" @@ -79,7 +80,7 @@ TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer( bool is_tuple, absl::InlinedVector, 4> buffers, absl::InlinedVector, 4> definition_events, - std::function on_delete_callback) + absl::AnyInvocable on_delete_callback) : TrackedTfrtCpuDeviceBuffer(is_tuple, std::move(buffers), AfterAll(definition_events), std::move(on_delete_callback)) {} @@ -88,7 +89,7 @@ TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer( bool is_tuple, absl::InlinedVector, 4> buffers, tsl::AsyncValueRef definition_event, - std::function on_delete_callback) + absl::AnyInvocable on_delete_callback) : is_tuple_(is_tuple), buffers_(std::move(buffers)), definition_event_(std::move(definition_event)), @@ -110,7 +111,7 @@ TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer( TrackedTfrtCpuDeviceBuffer::~TrackedTfrtCpuDeviceBuffer() { ReleaseDeviceMemory(); if (on_delete_callback_) { - on_delete_callback_(); + std::move(on_delete_callback_)(); } } @@ -133,7 +134,7 @@ void TrackedTfrtCpuDeviceBuffer::AddUsageEvents( if (usage_events_.size() >= 1024) { int i = 0; while (i < usage_events_.size()) { - auto& event = usage_events_.at(i); + auto& event = usage_events_[i]; if (event.IsAvailable()) { using std::swap; swap(event, usage_events_.back()); diff --git a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h index 2d4b7589cbb4b..3a624b22934e1 100644 --- a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h +++ b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ class MaybeOwningCpuMemory { MaybeOwningCpuMemory& operator=(const MaybeOwningCpuMemory&) = delete; // Owning. - static StatusOr> AllocateShared( + static absl::StatusOr> AllocateShared( size_t size) { uint8_t* data = static_cast( tsl::port::AlignedMalloc(size, cpu_function_runtime::MinAlign())); @@ -88,13 +88,13 @@ class TrackedTfrtCpuDeviceBuffer { absl::InlinedVector, 4> buffers, absl::InlinedVector, 4> definition_events, - std::function on_delete_callback = nullptr); + absl::AnyInvocable on_delete_callback = nullptr); TrackedTfrtCpuDeviceBuffer( bool is_tuple, absl::InlinedVector, 4> buffers, tsl::AsyncValueRef definition_event, - std::function on_delete_callback = nullptr); + absl::AnyInvocable on_delete_callback = nullptr); // Move-only. TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default; @@ -144,7 +144,7 @@ class TrackedTfrtCpuDeviceBuffer { absl::InlinedVector, 4> usage_events_; // A callback to call when the TrackedTfrtCpuDeviceBuffer is about to be // destroyed. - std::function on_delete_callback_; + absl::AnyInvocable on_delete_callback_; }; } // namespace xla diff --git a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc index 4ca8b79a2fd88..b457cddf71910 100644 --- a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc +++ b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/distributed/BUILD b/xla/pjrt/distributed/BUILD index 39a2f9205b7e7..1f8ea2162fb50 100644 --- a/xla/pjrt/distributed/BUILD +++ b/xla/pjrt/distributed/BUILD @@ -1,7 +1,7 @@ -load("//xla:xla.bzl", "xla_cc_test") load("@tsl//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") licenses(["notice"]) @@ -47,13 +47,10 @@ xla_cc_test( name = "topology_util_test", srcs = ["topology_util_test.cc"], deps = [ + ":in_memory_key_value_store", ":protocol_proto_cc", ":topology_util", - "//xla:status", - "//xla:statusor", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", @@ -71,10 +68,14 @@ cc_library( "client.h", ], deps = [ + ":key_value_store_interface", ":util", "//xla:statusor", "//xla:types", "//xla:util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@tsl//tsl/distributed_runtime/coordination:coordination_client", @@ -106,6 +107,7 @@ cc_library( ":client", ":service", "//xla:statusor", + "@tsl//tsl/platform:grpc_credentials", ] + tsl_grpc_cc_dependencies(), ) @@ -114,6 +116,7 @@ cc_library( srcs = ["topology_util.cc"], hdrs = ["topology_util.h"], deps = [ + ":key_value_store_interface", ":protocol_proto_cc", "//xla:status", "//xla:statusor", @@ -155,3 +158,28 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ] + tsl_grpc_cc_dependencies(), ) + +cc_library( + name = "key_value_store_interface", + hdrs = ["key_value_store_interface.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "in_memory_key_value_store", + srcs = ["in_memory_key_value_store.cc"], + hdrs = ["in_memory_key_value_store.h"], + deps = [ + ":key_value_store_interface", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 60dc6878acd84..59610a246fd7a 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -22,8 +22,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "grpcpp/channel.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "tsl/distributed_runtime/coordination/coordination_client.h" #include "tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tsl/distributed_runtime/coordination/coordination_service_error_util.h" @@ -44,18 +48,18 @@ class DistributedRuntimeCoordinationServiceClient : DistributedRuntimeCoordinationServiceClient(channel, Options()) {} ~DistributedRuntimeCoordinationServiceClient() override; - xla::Status Connect() override; - xla::Status Shutdown() override; - xla::StatusOr BlockingKeyValueGet( + absl::Status Connect() override; + absl::Status Shutdown() override; + absl::StatusOr BlockingKeyValueGet( std::string_view key, absl::Duration timeout) override; - xla::StatusOr>> + absl::StatusOr>> KeyValueDirGet(std::string_view key) override; - xla::Status KeyValueSet(std::string_view key, - std::string_view value) override; - xla::Status KeyValueDelete(std::string_view key) override; - xla::Status WaitAtBarrier(std::string barrier_id, - absl::Duration timeout) override; - xla::StatusOr GetCoordinationServiceAgent() + absl::Status KeyValueSet(std::string_view key, + std::string_view value) override; + absl::Status KeyValueDelete(std::string_view key) override; + absl::Status WaitAtBarrier(std::string barrier_id, + absl::Duration timeout) override; + absl::StatusOr GetCoordinationServiceAgent() override; private: @@ -103,7 +107,7 @@ DistributedRuntimeCoordinationServiceClient:: DistributedRuntimeCoordinationServiceClient:: ~DistributedRuntimeCoordinationServiceClient() = default; -xla::Status DistributedRuntimeCoordinationServiceClient::Connect() { +absl::Status DistributedRuntimeCoordinationServiceClient::Connect() { const absl::Time deadline = absl::Now() + absl::Milliseconds(config_.cluster_register_timeout_in_ms()); @@ -126,20 +130,20 @@ xla::Status DistributedRuntimeCoordinationServiceClient::Connect() { return s; } -xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { +absl::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { LOG(INFO) << "Distributed task shutdown initiated."; Status s = coord_agent_->Shutdown(); LOG(INFO) << "Distributed task shutdown result: " << s; return s; } -xla::StatusOr +absl::StatusOr DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( std::string_view key, absl::Duration timeout) { return coord_agent_->GetKeyValue(key, timeout); } -xla::StatusOr>> +absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( std::string_view key) { // TODO(hanyangtay): Migrate to string_view for both client and coordination @@ -158,22 +162,22 @@ DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( return kvs; } -xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete( +absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete( std::string_view key) { return coord_agent_->DeleteKeyValue(key); } -xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( +absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( std::string_view key, std::string_view value) { return coord_agent_->InsertKeyValue(key, value); } -xla::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier( +absl::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier( std::string barrier_id, absl::Duration timeout) { return coord_agent_->WaitAtBarrier(barrier_id, timeout, /*tasks=*/{}); } -xla::StatusOr +absl::StatusOr DistributedRuntimeCoordinationServiceClient::GetCoordinationServiceAgent() { return coord_agent_.get(); } @@ -184,4 +188,35 @@ std::unique_ptr GetDistributedRuntimeClient( return std::make_unique( channel, options); } + +namespace { + +class DistributedKeyValueStore : public KeyValueStoreInterface { + public: + DistributedKeyValueStore(std::shared_ptr client, + std::string prefix) + : client_(std::move(client)), prefix_(std::move(prefix)) {} + + absl::StatusOr Get(std::string_view key, + absl::Duration timeout) override { + return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); + } + + absl::Status Set(std::string_view key, std::string_view value) override { + return client_->KeyValueSet(absl::StrCat(prefix_, key), value); + } + + private: + std::shared_ptr client_; + std::string prefix_; +}; + +} // namespace + +std::shared_ptr GetDistributedKeyValueStore( + std::shared_ptr client, std::string prefix) { + return std::make_shared(std::move(client), + std::move(prefix)); +} + } // namespace xla diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index c780389c63ef2..37e675cae871a 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/time/time.h" #include "grpcpp/channel.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/statusor.h" #include "xla/types.h" #include "tsl/platform/env.h" @@ -70,9 +71,9 @@ class DistributedRuntimeClient { // is reported by the coordinator, or we have not heard from the coordinator // recently. `coordinator_reported_failure` is true in the former case. // Exposed so tests can override this behavior to something non-fatal. - std::function + std::function missed_heartbeat_callback = - [](xla::Status status, bool coordinator_reported_failure) { + [](absl::Status status, bool coordinator_reported_failure) { if (coordinator_reported_failure) { LOG(QFATAL) << "Terminating process because the coordinator detected " @@ -103,19 +104,19 @@ class DistributedRuntimeClient { // connected. // Not thread-safe, i.e., calls to Connect()/Shutdown() must be serialized by // some other means. - virtual xla::Status Connect() = 0; + virtual absl::Status Connect() = 0; // Reports to the master that the client is ready to shutdown, and blocks // until all clients are ready to shutdown or the shutdown timeout expires. // Not thread-safe. - virtual xla::Status Shutdown() = 0; + virtual absl::Status Shutdown() = 0; // The following APIs are thread-safe. // Key-value store API. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). - virtual xla::StatusOr BlockingKeyValueGet( + virtual absl::StatusOr BlockingKeyValueGet( std::string_view key, absl::Duration timeout) = 0; // Get all key-value pairs under a directory (key). @@ -123,24 +124,24 @@ class DistributedRuntimeClient { // the directory. // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. - virtual xla::StatusOr>> + virtual absl::StatusOr>> KeyValueDirGet(std::string_view key) = 0; - virtual xla::Status KeyValueSet(std::string_view key, - std::string_view value) = 0; + virtual absl::Status KeyValueSet(std::string_view key, + std::string_view value) = 0; // Delete the key-value. If the key is a directory, recursively clean // up all key-values under the directory. - virtual xla::Status KeyValueDelete(std::string_view key) = 0; + virtual absl::Status KeyValueDelete(std::string_view key) = 0; // Blocks until all nodes are at the barrier or the barrier times out. // `barrier_id` should be unique across barriers. - virtual xla::Status WaitAtBarrier(std::string barrier_id, - absl::Duration timeout) = 0; + virtual absl::Status WaitAtBarrier(std::string barrier_id, + absl::Duration timeout) = 0; // Returns pointer to coordination service agent, or InternalError if the // client does not use coordination service. - virtual StatusOr + virtual absl::StatusOr GetCoordinationServiceAgent() = 0; }; @@ -149,6 +150,9 @@ std::unique_ptr GetDistributedRuntimeClient( std::shared_ptr<::grpc::Channel> channel, const DistributedRuntimeClient::Options& options); +std::shared_ptr GetDistributedKeyValueStore( + std::shared_ptr client, std::string key_prefix); + } // namespace xla #endif // XLA_PJRT_DISTRIBUTED_CLIENT_H_ diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index 7e3c05d3b154a..f9a9a5c0a5cda 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ constexpr absl::Duration kBarrierTimeout = absl::Milliseconds(200); class ClientServerTest : public testing::Test { public: - std::unique_ptr GetClient( + std::shared_ptr GetClient( int node_id, DistributedRuntimeClient::Options client_options = {}, std::shared_ptr<::grpc::Channel> channel = nullptr) { client_options.node_id = node_id; @@ -115,7 +115,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) { absl::Barrier barrier(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); // Allow the threads to call Connect one-by-one in order. @@ -155,7 +155,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -207,7 +207,7 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // client. This ensures that devices are sent out of turn (compared to their // node ids). absl::Notification n; - auto thread0_fn = [&]() -> xla::Status { + auto thread0_fn = [&]() -> absl::Status { auto client = GetClient(/*node_id=*/0); GlobalTopologyProto topology; TF_RETURN_IF_ERROR(client->Connect()); @@ -218,18 +218,12 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // Sleep a short while for the other thread to send their device info first. absl::SleepFor(absl::Seconds(1)); - auto kv_get = [&](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return client->BlockingKeyValueGet(k, timeout); - }; - auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { - return client->KeyValueSet(k, v); - }; + auto kv_store = GetDistributedKeyValueStore(client, /*key_prefix=*/""); TF_RETURN_IF_ERROR( ExchangeTopologies("cuda", /*node_id=*/0, /*num_nodes=*/2, /*get_local_topology_timeout=*/absl::Minutes(1), /*get_global_topology_timeout=*/absl::Minutes(1), - kv_get, kv_put, locals[0], &topology)); + kv_store.get(), locals[0], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); @@ -240,7 +234,7 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { TF_RET_CHECK(value == "value2"); return OkStatus(); }; - auto thread1_fn = [&]() -> xla::Status { + auto thread1_fn = [&]() -> absl::Status { auto client = GetClient(/*node_id=*/1); GlobalTopologyProto topology; TF_RETURN_IF_ERROR(client->Connect()); @@ -250,18 +244,12 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // We cannot send the notification after the call since there is a barrier // within the call that would cause a deadlock. n.Notify(); - auto kv_get = [&](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return client->BlockingKeyValueGet(k, timeout); - }; - auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { - return client->KeyValueSet(k, v); - }; + auto kv_store = GetDistributedKeyValueStore(client, /*key_prefix=*/""); TF_RETURN_IF_ERROR( ExchangeTopologies("cuda", /*node_id=*/1, /*num_nodes=*/2, /*get_local_topology_timeout=*/absl::Minutes(1), /*get_global_topology_timeout=*/absl::Minutes(1), - kv_get, kv_put, locals[1], &topology)); + kv_store.get(), locals[1], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); @@ -273,9 +261,9 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { return OkStatus(); }; - std::vector> functions = {thread0_fn, - thread1_fn}; - std::vector statuses(functions.size()); + std::vector> functions = {thread0_fn, + thread1_fn}; + std::vector statuses(functions.size()); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", functions.size()); @@ -311,29 +299,23 @@ TEST_F(ClientServerTest, EnumerateElevenDevices) { node->mutable_devices(0)->set_slice_index(i % 2); } - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); GlobalTopologyProto topology; TF_RETURN_IF_ERROR(client->Connect()); - auto kv_get = [&](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return client->BlockingKeyValueGet(k, timeout); - }; - auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { - return client->KeyValueSet(k, v); - }; + auto kv_store = GetDistributedKeyValueStore(client, /*key_prefix=*/""); TF_RETURN_IF_ERROR( ExchangeTopologies("cuda", /*node_id=*/node_id, num_nodes, /*get_local_topology_timeout=*/absl::Minutes(1), /*get_global_topology_timeout=*/absl::Minutes(1), - kv_get, kv_put, locals[node_id], &topology)); + kv_store.get(), locals[node_id], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -354,7 +336,7 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) { absl::Barrier barrier(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.init_timeout = absl::ZeroDuration(); auto client = GetClient(node_id, client_options); @@ -369,7 +351,7 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -386,11 +368,11 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = node_id != 0; client_options.missed_heartbeat_callback = - [&](xla::Status status, bool coordinator_initiated) {}; + [&](absl::Status status, bool coordinator_initiated) {}; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -405,7 +387,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -431,11 +413,11 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = (node_id != 0); absl::Notification shutdown; - client_options.missed_heartbeat_callback = [&](xla::Status status, + client_options.missed_heartbeat_callback = [&](absl::Status status, bool coordinator_initiated) { shutdown.Notify(); }; @@ -450,7 +432,7 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -474,12 +456,12 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { absl::Barrier barrier(num_nodes + 1); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.rpc_timeout = absl::Seconds(1); client_options.shutdown_timeout = absl::Seconds(10); absl::Notification shutdown; - client_options.missed_heartbeat_callback = [&](xla::Status status, + client_options.missed_heartbeat_callback = [&](absl::Status status, bool coordinator_initiated) { shutdown.Notify(); }; @@ -498,7 +480,7 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -520,7 +502,7 @@ TEST_F(ClientServerTest, LateClientsAreOk) { absl::Barrier barrier(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.init_timeout = absl::Seconds(20); client_options.rpc_timeout = absl::Milliseconds(200); @@ -533,7 +515,7 @@ TEST_F(ClientServerTest, LateClientsAreOk) { return OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -555,13 +537,13 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) { service_options.shutdown_timeout = timeout; StartService(num_nodes, service_options); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.init_timeout = timeout; client_options.rpc_timeout = timeout; // Overwrite the default error callback which invokes LOG(QFATAL). client_options.missed_heartbeat_callback = - [](xla::Status status, bool coordinator_reported_failure) { + [](absl::Status status, bool coordinator_reported_failure) { LOG(ERROR) << "Distributed client has missing heartbeats: " << status; }; auto client = GetClient(node_id, client_options); @@ -572,7 +554,7 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) { }; // Note: one fewer thread than 'num_nodes'. - std::vector statuses(num_nodes - 1); + std::vector statuses(num_nodes - 1); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -589,7 +571,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Succeed) { int num_nodes = 2; StartService(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); TF_RETURN_IF_ERROR(client->Connect()); @@ -600,7 +582,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Succeed) { return xla::OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -618,7 +600,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) { StartService(num_nodes); absl::Notification n; - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); TF_RETURN_IF_ERROR(client->Connect()); @@ -637,7 +619,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) { return xla::OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -657,7 +639,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) { int num_nodes = 2; StartService(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); TF_RETURN_IF_ERROR(client->Connect()); @@ -673,7 +655,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) { return xla::OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); @@ -691,7 +673,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) { int num_nodes = 2; StartService(num_nodes); - auto thread_fn = [&](int node_id) -> xla::Status { + auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); TF_RETURN_IF_ERROR(client->Connect()); @@ -702,7 +684,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) { return xla::OkStatus(); }; - std::vector statuses(num_nodes); + std::vector statuses(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", num_nodes); diff --git a/xla/pjrt/distributed/distributed.cc b/xla/pjrt/distributed/distributed.cc index fa7c6278c896e..4eb5c8ac65ad6 100644 --- a/xla/pjrt/distributed/distributed.cc +++ b/xla/pjrt/distributed/distributed.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,27 +15,32 @@ limitations under the License. #include "xla/pjrt/distributed/distributed.h" +#include #include -#include "grpcpp/grpcpp.h" +#include "grpcpp/channel.h" +#include "grpcpp/create_channel.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/service.h" +#include "xla/statusor.h" +#include "tsl/platform/grpc_credentials.h" namespace xla { -StatusOr> +// In OSS, insecure credentials are used as default. +constexpr bool kVerifySecureCredentials = false; + +absl::StatusOr> GetDistributedRuntimeService(std::string address, const CoordinationServiceImpl::Options& options) { - auto credentials = ::grpc::InsecureServerCredentials(); - return DistributedRuntimeService::Get(address, credentials, options); + return DistributedRuntimeService::Get( + address, tsl::GetServerCredentials(kVerifySecureCredentials), options); } std::shared_ptr GetDistributedRuntimeClient( std::string address, const DistributedRuntimeClient::Options& options) { - std::shared_ptr<::grpc::ChannelCredentials> creds = - ::grpc::InsecureChannelCredentials(); - std::shared_ptr<::grpc::Channel> channel = - ::grpc::CreateChannel(address, creds); + std::shared_ptr channel = grpc::CreateChannel( + address, tsl::GetClientCredentials(kVerifySecureCredentials)); return GetDistributedRuntimeClient(channel, options); } diff --git a/xla/pjrt/distributed/distributed.h b/xla/pjrt/distributed/distributed.h index 7140819d2d227..393bd234530bb 100644 --- a/xla/pjrt/distributed/distributed.h +++ b/xla/pjrt/distributed/distributed.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace xla { // Builds a distributed runtime service. `address` is the address on which // the service should listen, e.g., [::]:1234 . `num_nodes` is the number // of nodes in the cluster. -StatusOr> +absl::StatusOr> GetDistributedRuntimeService(std::string address, const CoordinationServiceImpl::Options& options); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc new file mode 100644 index 0000000000000..8140bb9bd80ea --- /dev/null +++ b/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -0,0 +1,51 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/distributed/in_memory_key_value_store.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" + +namespace xla { + +absl::StatusOr InMemoryKeyValueStore::Get(std::string_view key, + absl::Duration timeout) { + absl::MutexLock lock(&mu_); + auto cond = [&]() { + mu_.AssertHeld(); + return kv_store_.find(key) != kv_store_.end(); + }; + bool exists = mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); + if (!exists) { + return absl::NotFoundError( + absl::StrCat(key, " is not found in the kv store.")); + } + return kv_store_.find(key)->second; +} + +absl::Status InMemoryKeyValueStore::Set(std::string_view key, + std::string_view value) { + absl::MutexLock lock(&mu_); + kv_store_[key] = value; + return absl::OkStatus(); +} + +} // namespace xla diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h new file mode 100644 index 0000000000000..680abc5b4c9c0 --- /dev/null +++ b/xla/pjrt/distributed/in_memory_key_value_store.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_DISTRIBUTED_IN_MEMORY_KEY_VALUE_STORE_H_ +#define XLA_PJRT_DISTRIBUTED_IN_MEMORY_KEY_VALUE_STORE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" + +namespace xla { + +class InMemoryKeyValueStore : public KeyValueStoreInterface { + public: + absl::StatusOr Get(std::string_view key, + absl::Duration timeout) override; + + absl::Status Set(std::string_view key, std::string_view value) override; + + private: + absl::Mutex mu_; + absl::flat_hash_map kv_store_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // XLA_PJRT_DISTRIBUTED_IN_MEMORY_KEY_VALUE_STORE_H_ diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h new file mode 100644 index 0000000000000..38b48e6063b02 --- /dev/null +++ b/xla/pjrt/distributed/key_value_store_interface.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_INTERFACE_H_ +#define XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_INTERFACE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" + +namespace xla { + +// In the multi-node case, the caller of PjRtClient can provide a key-value +// store accessible across nodes. The caller can provide the two callbacks +// below to access the key-value store. There are a few requirements: +// (1) Get and Set must be thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different GPU plugin nodes). +class KeyValueStoreInterface { + public: + virtual ~KeyValueStoreInterface() = default; + + // Blocking Get(). + // There are no concurrency guarantees. To avoid a race / impose an ordering + // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). + virtual absl::StatusOr Get(std::string_view key, + absl::Duration timeout) = 0; + + virtual absl::Status Set(std::string_view key, std::string_view value) = 0; +}; + +} // namespace xla + +#endif // XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_INTERFACE_H_ diff --git a/xla/pjrt/distributed/protocol.proto b/xla/pjrt/distributed/protocol.proto index 010856095fd1d..955a537d1c45d 100644 --- a/xla/pjrt/distributed/protocol.proto +++ b/xla/pjrt/distributed/protocol.proto @@ -1,4 +1,4 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// Copyright 2020 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,7 +45,13 @@ message DeviceProto { // Devices with the same slice_index are connected by fast network, e.g. // NVLink on GPUs. int32 slice_index = 5; -} + + // Store vendor-specific compute capability. + string compute_capability = 6; + + // The number of cores (e.g. SMs on GPUs) on the device. + int32 core_count = 7; +}; message LocalTopologyProto { int32 node_id = 1; diff --git a/xla/pjrt/distributed/service.cc b/xla/pjrt/distributed/service.cc index 37d4c651544b6..6128d4dad6516 100644 --- a/xla/pjrt/distributed/service.cc +++ b/xla/pjrt/distributed/service.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -87,7 +87,7 @@ void CoordinationServiceImpl::StartRpcThread() { [service = coord_rpc_service_.get()] { service->HandleRPCsLoop(); })); } -xla::StatusOr> +absl::StatusOr> DistributedRuntimeService::Get( const std::string& address, std::shared_ptr<::grpc::ServerCredentials> credentials, diff --git a/xla/pjrt/distributed/service.h b/xla/pjrt/distributed/service.h index d2e308d0f8a71..3e7a6f6867206 100644 --- a/xla/pjrt/distributed/service.h +++ b/xla/pjrt/distributed/service.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ class CoordinationServiceImpl { class DistributedRuntimeService { public: - static xla::StatusOr> Get( + static absl::StatusOr> Get( const std::string& address, std::shared_ptr<::grpc::ServerCredentials> credentials, const CoordinationServiceImpl::Options& options); diff --git a/xla/pjrt/distributed/topology_util.cc b/xla/pjrt/distributed/topology_util.cc index fa3e26b6595e3..bec8c617010ce 100644 --- a/xla/pjrt/distributed/topology_util.cc +++ b/xla/pjrt/distributed/topology_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/utils.h" @@ -49,7 +50,7 @@ static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id"; // Retrieve content of /proc/sys/kernel/random/boot_id as a string. // Note that procfs file may have file size 0 which throws off generic file // readers such as tsl::ReadFileToString. -StatusOr GetBootIdString() { +absl::StatusOr GetBootIdString() { std::string boot_id_str; #ifdef __linux__ std::ifstream file(kBootIdPath); @@ -73,9 +74,9 @@ static std::string GetGlobalTopologyKey(std::string_view platform) { return absl::StrCat("global_topology/", platform); } -static StatusOr> GetAllLocalTopologies( - std::string_view platform, int num_nodes, - const PjRtClient::KeyValueGetCallback& kv_get, absl::Duration timeout) { +static absl::StatusOr> GetAllLocalTopologies( + std::string_view platform, int num_nodes, KeyValueStoreInterface* kv_store, + absl::Duration timeout) { std::vector> local_topology_strs(num_nodes); // TODO(ezhulenev): Should a thread pool become a function argument? @@ -86,8 +87,8 @@ static StatusOr> GetAllLocalTopologies( absl::Mutex mu; for (int i = 0; i < num_nodes; i++) { thread_pool.Schedule([&, i] { - StatusOr local_topology_str = - kv_get(GetLocalTopologyKey(platform, i), timeout); + absl::StatusOr local_topology_str = + kv_store->Get(GetLocalTopologyKey(platform, i), timeout); { absl::MutexLock lock(&mu); local_topology_strs[i] = local_topology_str; @@ -101,7 +102,7 @@ static StatusOr> GetAllLocalTopologies( std::vector local_topologies; int max_num_failed_message = 10; int failed_count = 0; - for (const StatusOr& str : local_topology_strs) { + for (const absl::StatusOr& str : local_topology_strs) { if (str.ok()) { LocalTopologyProto local; local.ParseFromString(*str); @@ -157,8 +158,7 @@ GlobalTopologyProto BuildGlobalTopology( Status ExchangeTopologies(std::string_view platform, int node_id, int num_nodes, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout, - const PjRtClient::KeyValueGetCallback& kv_get, - const PjRtClient::KeyValuePutCallback& kv_put, + KeyValueStoreInterface* kv_store, const LocalTopologyProto& local_topology, GlobalTopologyProto* global_topology) { VLOG(3) << "Local Topology for platform" << platform << ":\n" @@ -171,25 +171,25 @@ Status ExchangeTopologies(std::string_view platform, int node_id, int num_nodes, } return absl::OkStatus(); } - - TF_RETURN_IF_ERROR(kv_put(GetLocalTopologyKey(platform, node_id), - local_topology.SerializeAsString())); + CHECK(kv_store != nullptr); + TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id), + local_topology.SerializeAsString())); // The lead node gets all local topologies, builds the global topology and // puts it to the key-value store. std::string global_topology_key = GetGlobalTopologyKey(platform); if (node_id == 0) { TF_ASSIGN_OR_RETURN(std::vector local_topologies, - GetAllLocalTopologies(platform, num_nodes, kv_get, + GetAllLocalTopologies(platform, num_nodes, kv_store, get_local_topology_timeout)); *global_topology = BuildGlobalTopology(absl::Span(local_topologies)); - TF_RETURN_IF_ERROR( - kv_put(global_topology_key, global_topology->SerializeAsString())); + TF_RETURN_IF_ERROR(kv_store->Set(global_topology_key, + global_topology->SerializeAsString())); } else { TF_ASSIGN_OR_RETURN( std::string global_topology_str, - kv_get(global_topology_key, get_global_topology_timeout)); + kv_store->Get(global_topology_key, get_global_topology_timeout)); global_topology->ParseFromString(global_topology_str); } VLOG(3) << "Global topology for platform " << platform << ":\n" diff --git a/xla/pjrt/distributed/topology_util.h b/xla/pjrt/distributed/topology_util.h index 10e3732c71f67..a74478cdb0b1c 100644 --- a/xla/pjrt/distributed/topology_util.h +++ b/xla/pjrt/distributed/topology_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/pjrt/pjrt_client.h" #include "xla/status.h" #include "xla/statusor.h" @@ -30,7 +30,7 @@ namespace xla { // Retrieve content of /proc/sys/kernel/random/boot_id as a string. // Empty on non-Linux platforms. -StatusOr GetBootIdString(); +absl::StatusOr GetBootIdString(); // Performs a distributed exchange of topologies using a KV store. Each process // provides its local topology, and the local topologies are exchanged to @@ -38,8 +38,7 @@ StatusOr GetBootIdString(); Status ExchangeTopologies(std::string_view platform, int node_id, int num_nodes, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout, - const PjRtClient::KeyValueGetCallback& kv_get, - const PjRtClient::KeyValuePutCallback& kv_put, + KeyValueStoreInterface* kv_store, const LocalTopologyProto& local_topology, GlobalTopologyProto* global_topology); diff --git a/xla/pjrt/distributed/topology_util_test.cc b/xla/pjrt/distributed/topology_util_test.cc index c8baa3e4ac6f5..f311b398d5309 100644 --- a/xla/pjrt/distributed/topology_util_test.cc +++ b/xla/pjrt/distributed/topology_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,13 +19,9 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/status.h" -#include "xla/statusor.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" @@ -64,26 +60,7 @@ TEST(TopologyTest, ExchangeTopology) { DeviceProto* d3 = locals[1].add_devices(); d3->set_local_device_ordinal(1); - absl::Mutex mu; - absl::flat_hash_map kv; - - auto kv_get = [&](std::string_view key, - absl::Duration timeout) -> xla::StatusOr { - absl::MutexLock lock(&mu); - auto ready = [&]() { return kv.contains(key); }; - if (mu.AwaitWithTimeout(absl::Condition(&ready), timeout)) { - return kv[key]; - } - return absl::NotFoundError("key not found"); - }; - - auto kv_put = [&](std::string_view key, - std::string_view value) -> xla::Status { - absl::MutexLock lock(&mu); - kv[key] = value; - return absl::OkStatus(); - }; - + InMemoryKeyValueStore kv_store; std::vector globals(num_nodes); { tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool", @@ -94,7 +71,7 @@ TEST(TopologyTest, ExchangeTopology) { /*platform=*/"cuda", /*node_id=*/i, num_nodes, /*get_local_topology_timeout=*/ absl::Seconds(10), /*get_global_topology_timeout=*/ - absl::Seconds(10), kv_get, kv_put, locals[i], &globals[i])); + absl::Seconds(10), &kv_store, locals[i], &globals[i])); }); } } diff --git a/xla/pjrt/distributed/util.h b/xla/pjrt/distributed/util.h index 04d25d12eb4bc..8e0f290795207 100644 --- a/xla/pjrt/distributed/util.h +++ b/xla/pjrt/distributed/util.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/event_pool.cc b/xla/pjrt/event_pool.cc index 4a07a6ee75b10..34c37a1b00394 100644 --- a/xla/pjrt/event_pool.cc +++ b/xla/pjrt/event_pool.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ EventPool::Handle::~Handle() { EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse), next_sequence_number_(1) {} -StatusOr EventPool::AllocateEvent( +absl::StatusOr EventPool::AllocateEvent( se::StreamExecutor* executor) { Handle event; @@ -54,11 +54,11 @@ StatusOr EventPool::AllocateEvent( void EventPool::ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle) { absl::MutexLock lock(&mu_); - stream->ThenRecordEvent(handle.event_.get()); + stream->RecordEvent(handle.event_.get()).IgnoreError(); handle.sequence_number_ = next_sequence_number_++; } -StatusOr EventPool::ThenAllocateAndRecordEvent( +absl::StatusOr EventPool::ThenAllocateAndRecordEvent( se::Stream* stream) { TF_ASSIGN_OR_RETURN(EventPool::Handle handle, AllocateEvent(stream->parent())); diff --git a/xla/pjrt/event_pool.h b/xla/pjrt/event_pool.h index 704ef41a8c7a4..89b8f6d8161a7 100644 --- a/xla/pjrt/event_pool.h +++ b/xla/pjrt/event_pool.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,11 +76,11 @@ class EventPool { // such as cudaStreamWaitEvent capture the state of the event at the time of // the host-side call and are not affected by a later host-side // cudaEventRecord. - StatusOr ThenAllocateAndRecordEvent(se::Stream* stream); + absl::StatusOr ThenAllocateAndRecordEvent(se::Stream* stream); // Version of ThenAllocateAndRecordEvent split into two phases; this is // sometimes helpful if we want to avoid failures by preallocating events. - StatusOr AllocateEvent(se::StreamExecutor* executor); + absl::StatusOr AllocateEvent(se::StreamExecutor* executor); void ThenRecordEvent(se::Stream* stream, EventPool::Handle& handle); private: diff --git a/xla/python/exceptions.h b/xla/pjrt/exceptions.h similarity index 91% rename from xla/python/exceptions.h rename to xla/pjrt/exceptions.h index c5b7e72e61663..bf72c4f4fe5ec 100644 --- a/xla/python/exceptions.h +++ b/xla/pjrt/exceptions.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_EXCEPTIONS_H_ -#define XLA_PYTHON_EXCEPTIONS_H_ +#ifndef XLA_PJRT_EXCEPTIONS_H_ +#define XLA_PJRT_EXCEPTIONS_H_ #include #include @@ -64,4 +64,4 @@ class XlaRuntimeError : public std::runtime_error { } // namespace xla -#endif // XLA_PYTHON_EXCEPTIONS_H_ +#endif // XLA_PJRT_EXCEPTIONS_H_ diff --git a/xla/pjrt/executable_metadata.proto b/xla/pjrt/executable_metadata.proto new file mode 100644 index 0000000000000..db308d57af477 --- /dev/null +++ b/xla/pjrt/executable_metadata.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package xla; + +import "xla/service/hlo.proto"; + +// Mirror of xla::CompiledMemoryStats. +message CompiledMemoryStatsProto { + // Device default memory (e.g., HBM for GPU/TPU) usage stats. + int64 generated_code_size_in_bytes = 1; + int64 argument_size_in_bytes = 2; + int64 output_size_in_bytes = 3; + int64 alias_size_in_bytes = 4; + int64 temp_size_in_bytes = 5; + xla.HloProto hlo_proto = 6; + + // Host memory usage stats. + int64 host_generated_code_size_in_bytes = 7; + int64 host_argument_size_in_bytes = 8; + int64 host_output_size_in_bytes = 9; + int64 host_alias_size_in_bytes = 10; + int64 host_temp_size_in_bytes = 11; +} diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index 5332b597275f5..226361a748f91 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -1,15 +1,14 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("//xla:xla.bzl", "xla_cc_test") -load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm", "if_gpu_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("@tsl//tsl:tsl.bzl", "if_nccl") +load("@tsl//tsl:tsl.bzl", "if_google", "internal_visibility") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm", "if_gpu_is_configured") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -17,7 +16,7 @@ cc_library( name = "gpu_helpers", srcs = ["gpu_helpers.cc"], hdrs = ["gpu_helpers.h"], - visibility = ["//xla/pjrt:friends"], + visibility = internal_visibility(["//xla/pjrt:friends"]), deps = [ "//xla:statusor", "//xla:types", @@ -27,10 +26,10 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/tsl/util:env_var", "@com_google_absl//absl/types:span", "@tsl//tsl/framework:bfc_allocator", "@tsl//tsl/framework:device_id_impl", - "@tsl//tsl/util:env_var", ], ) @@ -39,10 +38,15 @@ cc_library( srcs = ["se_gpu_pjrt_client.cc"], hdrs = ["se_gpu_pjrt_client.h"], defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), - visibility = ["//xla/pjrt:friends"], + visibility = internal_visibility(["//xla/pjrt:friends"]), deps = [ ":gpu_helpers", + ":gpu_metrics", ":gpu_topology", + "//xla:literal", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", @@ -50,36 +54,53 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_computation", "//xla/pjrt:compile_options_proto_cc", - "//xla/pjrt:metrics", + "//xla/pjrt:event_pool", + "//xla/pjrt:local_device_state", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", "//xla/pjrt:stream_executor_executable", "//xla/pjrt:stream_executor_executable_proto_cc", "//xla/pjrt:tracked_device_buffer", "//xla/pjrt:utils", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:topology_util", "//xla/service:compiler", + "//xla/service:computation_placer_hdr", "//xla/service:executable", + "//xla/service:global_device_id", "//xla/service:platform_util", + "//xla/service:shaped_buffer", + "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", + "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", - "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", "//xla/stream_executor/integrations:device_mem_allocator", "//xla/stream_executor/integrations:tf_allocator_adapter", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:bfc_allocator", "@tsl//tsl/framework:device_id", @@ -88,16 +109,16 @@ cc_library( "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:status", "@tsl//tsl/profiler/lib:connected_traceme", - "@tsl//tsl/util:env_var", + "@tsl//tsl/profiler/lib:traceme", ] + if_cuda_or_rocm([ + ":nccl_id_store", "//xla/service/gpu:gpu_compiler", ]) + if_cuda([ - ":nccl_id_store_cuda", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", ]) + if_rocm([ - ":nccl_id_store_rocm", "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -105,7 +126,6 @@ cc_library( xla_cc_test( name = "se_gpu_pjrt_client_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_client_test.cc"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = [ "gpu", "no_oss", @@ -119,12 +139,15 @@ xla_cc_test( "//xla:literal_util", "//xla:statusor", "//xla:test", - "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_future", "//xla/pjrt:utils", + "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/service:gpu_plugin", "//xla/service:hlo_parser", "//xla/tests:literal_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@tsl//tsl/lib/core:status_test_util", @@ -136,44 +159,30 @@ xla_cc_test( ], ) -# We actually wish we could write if_cuda(if_nccl(...)) in :gpu_device, -# but Bazel does not allow nested selects. We can work around the problem using -# an intermediate library that has the conditional NCCL pieces that is only -# itself included as a dependency if CUDA is enabled. cc_library( - name = "nccl_id_store_cuda", + name = "nccl_id_store", srcs = ["nccl_id_store.cc"], hdrs = ["nccl_id_store.h"], - defines = if_nccl(["NCCL_ENABLED=1"]), deps = [ + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:global_device_id", - "//xla/service/gpu:gpu_executable_run_options", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", - ] + if_nccl(["@local_config_nccl//:nccl"]), -) - -cc_library( - name = "nccl_id_store_rocm", - srcs = ["nccl_id_store.cc"], - hdrs = ["nccl_id_store.h"], - defines = if_nccl(["NCCL_ENABLED=1"]), - deps = [ - "//xla:statusor", - "//xla:util", - "//xla/pjrt:pjrt_client", - "//xla/pjrt/distributed:client", - "//xla/service:global_device_id", - "//xla/service/gpu:gpu_executable_run_options", + "//xla/service/gpu:nccl_clique_key", + "//xla/service/gpu/runtime:nccl_api", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - ] + if_nccl(["@local_config_nccl//:nccl"]), + "@com_google_absl//absl/time", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], ) xla_cc_test( @@ -233,35 +242,46 @@ cc_library( "//xla/service:local_service", "//xla/service:local_service_utils", "//xla/service/gpu:executable_proto_cc", - "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/platform", "@com_google_absl//absl/status", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", ] + if_cuda_or_rocm([ + ":nccl_id_store", "//xla/service/gpu:gpu_compiler", ]) + if_cuda([ - ":nccl_id_store_cuda", "@local_config_cuda//cuda:cuda_headers", + "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/cuda:cuda_activation_header", "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", "//xla/service/gpu:nvptx_compiler_impl", ]) + if_rocm([ - ":nccl_id_store_rocm", "@local_config_rocm//rocm:rocm_headers", + "//xla/stream_executor/rocm:rocm_platform_id", "//xla/service/gpu:amdgpu_compiler_impl", ]), alwayslink = True, ) +cc_library( + name = "gpu_metrics", + srcs = ["gpu_metrics.cc"], + hdrs = ["gpu_metrics.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/lib/monitoring:gauge", + ], +) + xla_cc_test( name = "se_gpu_pjrt_compiler_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_compiler_test.cc"]), tags = [ - "config-cuda-only", "gpu", "no_oss", "requires-gpu-nvidia", - ], + ] + if_google(["config-cuda-only"]), deps = [ ":se_gpu_pjrt_client", ":se_gpu_pjrt_compiler", @@ -288,11 +308,10 @@ xla_cc_test( name = "se_gpu_pjrt_compiler_aot_test", srcs = if_gpu_is_configured(["se_gpu_pjrt_compiler_aot_test.cc"]), tags = [ - "config-cuda-only", "gpu", "no_oss", "requires-gpu-nvidia", - ], + ] + if_google(["config-cuda-only"]), deps = [ ":se_gpu_pjrt_client", ":se_gpu_pjrt_compiler", @@ -306,8 +325,6 @@ xla_cc_test( "//xla/service:compiler", "//xla/service:gpu_plugin", "//xla/service:hlo_parser", - "//xla/service/gpu:nvptx_compiler_impl", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -317,11 +334,15 @@ xla_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:protobuf", - "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", - ], + ] + if_cuda([ + "//xla/service/gpu:nvptx_compiler_impl", + "//xla/stream_executor/cuda:cublas_plugin", + ]) + if_rocm([ + "//xla/service/gpu:amdgpu_compiler_impl", + "//xla/stream_executor/rocm:rocblas_plugin", + ]), ) diff --git a/xla/pjrt/gpu/gpu_helpers.cc b/xla/pjrt/gpu/gpu_helpers.cc index 52ebc0bceaa2f..b9477360de119 100644 --- a/xla/pjrt/gpu/gpu_helpers.cc +++ b/xla/pjrt/gpu/gpu_helpers.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "xla/pjrt/gpu/gpu_helpers.h" +#include #include #include #include @@ -27,14 +28,14 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/framework/device_id.h" -#include "tsl/util/env_var.h" namespace xla { // Builds an xla::LocalClient for the GPU platform. -StatusOr GetGpuXlaClient( +absl::StatusOr GetGpuXlaClient( const std::optional& platform_name, const std::optional>& allowed_devices) { TF_ASSIGN_OR_RETURN( @@ -71,9 +72,9 @@ void EnablePeerAccess(absl::Span executors) { } // Builds a BFCAllocator for all local GPUs. -StatusOr> CreateBFCAllocator( +absl::StatusOr> CreateBFCAllocator( se::StreamExecutor* executor, double memory_fraction, bool preallocate, - bool garbage_collection) { + std::optional gpu_system_memory_size) { bool enable_unified_memory; Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", false, &enable_unified_memory); @@ -85,7 +86,9 @@ StatusOr> CreateBFCAllocator( int device_ordinal = executor->device_ordinal(); auto sub_allocator = std::make_unique( executor, tsl::PlatformDeviceId(device_ordinal), - /*use_unified_memory=*/enable_unified_memory, + /*memory_type=*/ + enable_unified_memory ? stream_executor::MemoryType::kUnified + : stream_executor::MemoryType::kDevice, /*alloc_visitors=*/std::vector(), /*free_visitors=*/std::vector()); @@ -102,6 +105,11 @@ StatusOr> CreateBFCAllocator( size_t allocator_memory = enable_unified_memory ? total_memory * fmax(1.0, memory_fraction) : total_memory * memory_fraction; + // If gpu_system_memory_size is set, use it instead of default value. + if (gpu_system_memory_size.has_value()) { + allocator_memory = gpu_system_memory_size.value(); + } + if (preallocate) { LOG(INFO) << "XLA backend allocating " << allocator_memory << " bytes on device " << device_ordinal << " for BFCAllocator."; @@ -112,12 +120,49 @@ StatusOr> CreateBFCAllocator( tsl::BFCAllocator::Options opts; opts.allow_growth = !preallocate; - opts.garbage_collection = garbage_collection; return std::make_unique( std::move(sub_allocator), allocator_memory, absl::StrCat("GPU_", device_ordinal, "_bfc"), opts); } +// Builds a BFCAllocator for all local GPUs that uses collective memory. +absl::StatusOr> CreateCollectiveBFCAllocator( + se::StreamExecutor* executor, double memory_fraction, + size_t collective_memory_size) { + int device_ordinal = executor->device_ordinal(); + auto sub_allocator = std::make_unique( + executor, tsl::PlatformDeviceId(device_ordinal), + /*memory_type=*/stream_executor::MemoryType::kCollective, + /*alloc_visitors=*/std::vector(), + /*free_visitors=*/std::vector()); + + int64_t free_memory; + int64_t total_memory; + if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) { + return Unavailable("Failed to query available memory from device %i", + device_ordinal); + } + bool preallocate = collective_memory_size != 0; + size_t allocator_memory = + preallocate ? collective_memory_size : total_memory * memory_fraction; + + if (preallocate) { + LOG(INFO) << "XLA backend allocating " << allocator_memory + << " bytes on device " << device_ordinal + << " for CollectiveBFCAllocator."; + } else { + LOG(INFO) << "XLA backend will use up to " << allocator_memory + << " bytes on device " << device_ordinal + << " for CollectiveBFCAllocator."; + } + + tsl::BFCAllocator::Options opts; + opts.allow_growth = !preallocate; + return std::make_unique( + std::move(sub_allocator), allocator_memory, + absl::StrCat("GPU_collectivememory_", device_ordinal, "_bfc"), opts); +} + // Returns a GPU pinned host memory allocator to use when staging host->GPU // transfers. We use a fixed 64GB pool of pinned memory. std::unique_ptr GetGpuHostAllocator( @@ -126,8 +171,18 @@ std::unique_ptr GetGpuHostAllocator( new se::DeviceHostAllocator(executor, /*numa_node=*/0, /*alloc_visitors=*/{}, /*free_visitors=*/{})); - // TODO(phawkins): allow the user to tune this. - const int64_t kGpuHostMemoryLimitBytes = 64 * (1LL << 30); + + int64_t xla_pjrt_gpu_host_memory_limit_gb; + Status status = + tsl::ReadInt64FromEnvVar("XLA_PJRT_GPU_HOST_MEMORY_LIMIT_GB", 64, + &xla_pjrt_gpu_host_memory_limit_gb); + if (!status.ok()) { + LOG(ERROR) << "Unable to read XLA_PJRT_GPU_HOST_MEMORY_LIMIT_GB: " + << status.message(); + } + + const int64_t kGpuHostMemoryLimitBytes = + xla_pjrt_gpu_host_memory_limit_gb * (1LL << 30); tsl::BFCAllocator::Options opts; opts.allow_growth = true; diff --git a/xla/pjrt/gpu/gpu_helpers.h b/xla/pjrt/gpu/gpu_helpers.h index 3b3d38e5ff22b..cdda4ff7e075c 100644 --- a/xla/pjrt/gpu/gpu_helpers.h +++ b/xla/pjrt/gpu/gpu_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PJRT_GPU_GPU_HELPERS_H_ #define XLA_PJRT_GPU_GPU_HELPERS_H_ +#include #include #include #include @@ -31,7 +32,7 @@ limitations under the License. namespace xla { // Builds an xla::LocalClient for the GPU platform. -StatusOr GetGpuXlaClient( +absl::StatusOr GetGpuXlaClient( const std::optional& platform_name, const std::optional>& allowed_devices); @@ -50,25 +51,43 @@ struct GpuAllocatorConfig { // Only used if kind == kBFC. The maximum fraction of available memory to // allocate. This is the default value of XLA_PYTHON_CLIENT_MEM_FRACTION. + // + // If `gpu_system_memory_size` is set, it determines memory allocation. + // `memory_fraction` won't be used in this case. double memory_fraction = 0.75; + // Only used if kind == kBFC. The absolute size of reserved memory space for + // GPU system in bytes. + // + // If null, the default value `memory_fraction` will be used. + std::optional gpu_system_memory_size = std::nullopt; + // Only used if kind == kBFC. If true, the allocator will immediately allocate // the maximum amount allowed by `memory_fraction`. This reduces // fragmentation, allowing more of the total memory to be used. If false, the // allocator will allocate more memory as allocations are requested. bool preallocate = true; - // activate garbage collection or not - bool garbage_collection = false; + // Amount of collective memory (ncclMemAlloc) to preallocate. If this value is + // 0, collective memory space will be grown as needed to fit the application's + // usage, with the drawback of potentially higher fragmentation. If set, + // should be set to a multiple of 512MB to avoid wasting memory due to + // granularity requirements. + size_t collective_memory_size = 0; }; std::unique_ptr GetGpuHostAllocator( se::StreamExecutor* executor); // Builds a BFCAllocator for all local GPUs. -StatusOr> CreateBFCAllocator( +absl::StatusOr> CreateBFCAllocator( se::StreamExecutor* executor, double memory_fraction, bool preallocate, - bool garbage_collection); + std::optional gpu_system_memory_size); + +// Builds a BFCAllocator for all local GPUs that uses collective memory. +absl::StatusOr> CreateCollectiveBFCAllocator( + se::StreamExecutor* executor, double memory_fraction, + size_t collective_memory_size); } // namespace xla diff --git a/xla/pjrt/gpu/gpu_metrics.cc b/xla/pjrt/gpu/gpu_metrics.cc new file mode 100644 index 0000000000000..56cc0dc6861b8 --- /dev/null +++ b/xla/pjrt/gpu/gpu_metrics.cc @@ -0,0 +1,43 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/gpu/gpu_metrics.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tsl/lib/monitoring/gauge.h" + +namespace xla { +namespace { +auto* free_gpu_system_memory = tsl::monitoring::Gauge::New( + gpu_metrics::freeGpuSystemMemoryMetricName, + "Record the free GPU system memory.", "gpu_id"); +} // namespace + +namespace gpu_metrics { + +void RecordFreeGpuSystemMemory(const int device_ordinal, + const int64_t free_memory) { + free_gpu_system_memory->GetCell(absl::StrCat(device_ordinal)) + ->Set(free_memory); +} + +int64_t GetFreeGpuSystemMemory(int gpu_id) { + return free_gpu_system_memory->GetCell(absl::StrCat(gpu_id))->value(); +} + +} // namespace gpu_metrics +} // namespace xla diff --git a/xla/pjrt/gpu/gpu_metrics.h b/xla/pjrt/gpu/gpu_metrics.h new file mode 100644 index 0000000000000..a1dd5a08e7a2d --- /dev/null +++ b/xla/pjrt/gpu/gpu_metrics.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_GPU_GPU_METRICS_H_ +#define XLA_PJRT_GPU_GPU_METRICS_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace xla { +namespace gpu_metrics { + +inline constexpr absl::string_view freeGpuSystemMemoryMetricName = + "/pjrt/gpu/free_gpu_system_memory"; + +void RecordFreeGpuSystemMemory(int device_ordinal, int64_t free_memory); + +int64_t GetFreeGpuSystemMemory(int gpu_id); + +} // namespace gpu_metrics +} // namespace xla + +#endif // XLA_PJRT_GPU_GPU_METRICS_H_ diff --git a/xla/pjrt/gpu/gpu_topology.cc b/xla/pjrt/gpu/gpu_topology.cc index 00717a9111e4e..278cd72824107 100644 --- a/xla/pjrt/gpu/gpu_topology.cc +++ b/xla/pjrt/gpu/gpu_topology.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/gpu/gpu_topology.h b/xla/pjrt/gpu/gpu_topology.h index 9da556f48c461..25d1834c26e67 100644 --- a/xla/pjrt/gpu/gpu_topology.h +++ b/xla/pjrt/gpu/gpu_topology.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/gpu/nccl_id_store.cc b/xla/pjrt/gpu/nccl_id_store.cc index 98a5a705edc9b..55726a0542788 100644 --- a/xla/pjrt/gpu/nccl_id_store.cc +++ b/xla/pjrt/gpu/nccl_id_store.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,24 +18,18 @@ limitations under the License. #include #include -#ifdef NCCL_ENABLED -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#if (TF_ROCM_VERSION >= 50200) -#include "rocm/include/rccl/rccl.h" -#else -#include "rocm/include/rccl.h" -#endif -#else -#include "third_party/nccl/nccl.h" -#endif -#endif // NCCL_ENABLED - -#include "xla/util.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { -StatusOr NcclIdStore::GetNcclUniqueId( +absl::StatusOr NcclIdStore::GetNcclUniqueId( const gpu::NcclCliqueKey& key) { // The caller must ensure that threads calling this method concurrently have // unique keys, otherwise the global key-value store may hold the wrong value. @@ -46,23 +40,18 @@ StatusOr NcclIdStore::GetNcclUniqueId( return it->second; } } - std::string id_string; + gpu::NcclCliqueId clique_id; int primary_node_id = device_to_node_.at(key.devices()[0]); if (node_id_ == primary_node_id) { -#ifdef NCCL_ENABLED - ncclUniqueId id; - ncclResult_t r = ncclGetUniqueId(&id); - TF_RET_CHECK(r == ncclSuccess); - id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES); - TF_RETURN_IF_ERROR(kv_put_(key.ToString(), id_string)); -#else - return FailedPrecondition("NCCL support was not built into XLA binary."); -#endif + TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclApi::Default()->GetUniqueId()); + TF_RETURN_IF_ERROR(kv_store_->Set(key.ToString(), clique_id.ToString())); } else { - TF_ASSIGN_OR_RETURN(id_string, kv_get_(key.ToString(), absl::Minutes(10))); + TF_ASSIGN_OR_RETURN(std::string id_str, + kv_store_->Get(key.ToString(), absl::Minutes(10))); + TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclCliqueId::FromString(id_str)); } absl::MutexLock lock(&mu_); - auto result = cache_.emplace(key, std::move(id_string)); + auto result = cache_.emplace(key, std::move(clique_id)); TF_RET_CHECK(result.second) << "Unique ID already in cache."; return result.first->second; } diff --git a/xla/pjrt/gpu/nccl_id_store.h b/xla/pjrt/gpu/nccl_id_store.h index b97a886eb5edf..70060e242b150 100644 --- a/xla/pjrt/gpu/nccl_id_store.h +++ b/xla/pjrt/gpu/nccl_id_store.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,44 +17,40 @@ limitations under the License. #define XLA_PJRT_GPU_NCCL_ID_STORE_H_ #include -#include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" -#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/global_device_id.h" -#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_clique_key.h" #include "xla/statusor.h" namespace xla { -// A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings. -// In a distributed setup the table of NCCL IDs is kept on the master node -// (node 0). The node of the first participating device will create the unique -// id. +// A table mapping NcclCliqueKeys to NcclCliqueIds. In a distributed setup the +// table of NCCL IDs is kept on the master node (node 0). The node of the first +// participating device will create the unique id. class NcclIdStore { public: NcclIdStore(int node_id, absl::flat_hash_map device_to_node, - PjRtClient::KeyValueGetCallback kv_get, - PjRtClient::KeyValuePutCallback kv_put) + std::shared_ptr kv_store) : node_id_(node_id), device_to_node_(std::move(device_to_node)), - kv_get_(kv_get), - kv_put_(kv_put) {} + kv_store_(std::move(kv_store)) {} - StatusOr GetNcclUniqueId(const gpu::NcclCliqueKey& key); + absl::StatusOr GetNcclUniqueId( + const gpu::NcclCliqueKey& key); private: const int node_id_; const absl::flat_hash_map device_to_node_; - const PjRtClient::KeyValueGetCallback kv_get_; - const PjRtClient::KeyValuePutCallback kv_put_; + const std::shared_ptr kv_store_; absl::Mutex mu_; - absl::flat_hash_map cache_ + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); }; diff --git a/xla/pjrt/gpu/pjrt_client_test_se_gpu.cc b/xla/pjrt/gpu/pjrt_client_test_se_gpu.cc index bde468f2ff4c5..08cf47ba6359a 100644 --- a/xla/pjrt/gpu/pjrt_client_test_se_gpu.cc +++ b/xla/pjrt/gpu/pjrt_client_test_se_gpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 87faabdb599f7..0b61bec9748eb 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,56 +15,80 @@ limitations under the License. #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include +#include +#include +#include +#include #include #include #include -#include #include #include #include +#include #include -#include "absl/base/attributes.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/synchronization/blocking_counter.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/client/local_client.h" #include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/topology_util.h" +#include "xla/pjrt/event_pool.h" #include "xla/pjrt/gpu/gpu_helpers.h" +#include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/pjrt/stream_executor_executable.h" #include "xla/pjrt/tracked_device_buffer.h" -#include "xla/pjrt/utils.h" #include "xla/service/compiler.h" -#include "xla/service/executable.h" +#include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" +#include "xla/service/shaped_buffer.h" +#include "xla/service/transfer_manager.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/framework/allocator.h" -#include "tsl/framework/bfc_allocator.h" #include "tsl/lib/strings/proto_serialization.h" -#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/traceme.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/gpu/gpu_metrics.h" #include "xla/pjrt/gpu/nccl_id_store.h" -#include "xla/pjrt/metrics.h" #include "xla/pjrt/stream_executor_executable.pb.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/xla.pb.h" @@ -78,24 +102,19 @@ limitations under the License. #include "rocm/rocm_config.h" #endif -#include "xla/client/client_library.h" #include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/platform_util.h" #include "xla/statusor.h" -#include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/util.h" -#include "tsl/framework/device_id.h" -#include "tsl/util/env_var.h" namespace xla { class AsyncHostToDeviceTransferManager : public xla::PjRtClient::AsyncHostToDeviceTransferManager { public: - static StatusOr> Create( - absl::Span shapes, PjRtStreamExecutorDevice* device, - PjRtStreamExecutorClient* client) { + static absl::StatusOr> + Create(absl::Span shapes, PjRtStreamExecutorDevice* device, + PjRtStreamExecutorClient* client) { absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> buffer_ptrs; absl::InlinedVector, 4> @@ -259,7 +278,10 @@ class AsyncHostToDeviceTransferManager CleanUp(buffer_index, std::move(event), stream, /*is_last_transfer=*/true, std::move(on_done)); }; - stream->ThenDoHostCallback(std::move(cleanup)); + auto status = stream->DoHostCallback(std::move(cleanup)); + if (!status.ok()) { + LOG(ERROR) << "DoHostCallback failed: " << status; + } }; se_client->thread_pool()->Schedule( ([ptr = new absl::AnyInvocable(std::move(transfer_h2d))]() { @@ -283,6 +305,27 @@ class AsyncHostToDeviceTransferManager bool is_last_transfer, absl::AnyInvocable on_done) override { auto* stream = device_->local_device_state()->host_to_device_stream(); + auto* client = + tensorflow::down_cast(device_->client()); + bool should_stage_host_to_device_transfers = + client->should_stage_host_to_device_transfers(); + std::shared_ptr staging_buffer; + if (should_stage_host_to_device_transfers) { + auto* host_memory_allocator = client->host_memory_allocator(); + if (host_memory_allocator == nullptr) { + return InvalidArgument( + "host_memory_allocator should be initialized for staging buffer " + "transfer."); + } + + void* ptr = host_memory_allocator->AllocateRaw( + tsl::Allocator::kAllocatorAlignment, transfer_size); + staging_buffer = std::shared_ptr( + ptr, [host_memory_allocator = host_memory_allocator](void* ptr) { + host_memory_allocator->DeallocateRaw(ptr); + }); + } + absl::ReleasableMutexLock l(&mu_); DCHECK_LT(buffer_index, buffer_ptrs_.size()); if (last_transfer_started_[buffer_index]) { @@ -308,24 +351,42 @@ class AsyncHostToDeviceTransferManager CHECK_LE(offset, buffer_memory.size()); CHECK_LE(transfer_size, buffer_memory.size() - offset); if (transfer_size < buffer_memory.size()) { - sub_buffer = se::DeviceMemoryBase( - reinterpret_cast(buffer_memory.opaque()) + offset, - transfer_size); + sub_buffer = buffer_memory.GetByteSlice(offset, transfer_size); } else { sub_buffer = buffer_memory; } ++transfers_in_flight_; + // Release the lock before transfer in case transfer or cleanup could be + // called on this thread, to avoid deadlock. + l.Release(); + auto event = device_->local_device_state()->event_pool().AllocateEvent( stream->parent()); + if (transfer_size != 0) { - stream->ThenMemcpy(&sub_buffer, data, transfer_size); + if (staging_buffer != nullptr) { + auto copy_to_staging_buffer = [data, transfer_size, + staging_buffer]() mutable { + std::memcpy(staging_buffer.get(), data, transfer_size); + }; + if (auto status = + stream->DoHostCallback(std::move(copy_to_staging_buffer)); + !status.ok()) { + return status; + } + if (auto status = stream->Memcpy(&sub_buffer, staging_buffer.get(), + transfer_size); + !status.ok()) { + return status; + } + } else if (auto status = stream->Memcpy(&sub_buffer, data, transfer_size); + !status.ok()) { + return status; + } } device_->local_device_state()->event_pool().ThenRecordEvent(stream, event.value()); - // Release the lock before calling ThenDoHostCallback in case cleanup - // could be called on this thread, to avoid deadlock. - l.Release(); auto cleanup = [this, buffer_index, event = std::move(event).value(), stream, is_last_transfer, @@ -333,8 +394,7 @@ class AsyncHostToDeviceTransferManager CleanUp(buffer_index, std::move(event), stream, is_last_transfer, std::move(on_done)); }; - stream->ThenDoHostCallback(std::move(cleanup)); - return OkStatus(); + return stream->DoHostCallback(std::move(cleanup)); } void SetBufferError(int buffer_index, Status error) override { @@ -419,7 +479,7 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { #endif // TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION) } -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shapes, PjRtDevice* device) { auto* stream_executor_device = @@ -428,7 +488,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( shapes, stream_executor_device, this); } -xla::StatusOr +absl::StatusOr StreamExecutorGpuClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) const { if (num_partitions == 1 && num_replicas <= addressable_devices().size()) { @@ -482,16 +542,47 @@ PjRtFuture StreamExecutorGpuClient::CopyRawSubBufferToHost( std::unique_ptr sub_buffer; if (transfer_size < device_memory.size()) { sub_buffer = std::make_unique( - reinterpret_cast(device_memory.opaque()) + offset, - transfer_size); + device_memory.GetByteSlice(offset, transfer_size)); } else { sub_buffer = std::make_unique(device_memory); } if (transfer_size != 0) { - // D2H request holds a non-owned pointer into sub_buffer base address - // that needs to outlive the transfer until the stream callback is invoked. - stream->ThenMemcpy(dst, *sub_buffer, transfer_size); + if (should_stage_host_to_device_transfers()) { + if (host_memory_allocator() == nullptr) { + return PjRtFuture(InvalidArgument( + "host_memory_allocator should be initialized for staging buffer " + "transfer.")); + } + void* ptr = host_memory_allocator()->AllocateRaw( + tsl::Allocator::kAllocatorAlignment, transfer_size); + + std::shared_ptr staging_buffer = std::shared_ptr( + ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) { + host_memory_allocator->DeallocateRaw(ptr); + }); + if (auto status = + stream->Memcpy(staging_buffer.get(), *sub_buffer, transfer_size); + !status.ok()) { + return PjRtFuture(status); + } + auto copy_to_staging_buffer = [dst, transfer_size, + staging_buffer]() mutable { + std::memcpy(dst, staging_buffer.get(), transfer_size); + }; + if (auto status = stream->DoHostCallback(copy_to_staging_buffer); + !status.ok()) { + return PjRtFuture(status); + } + } else { + // D2H request holds a non-owned pointer into sub_buffer base address + // that needs to outlive the transfer until the stream callback is + // invoked. + auto status = stream->Memcpy(dst, *sub_buffer, transfer_size); + if (!status.ok()) { + return PjRtFuture(status); + } + } } auto usage_event = @@ -504,14 +595,17 @@ PjRtFuture StreamExecutorGpuClient::CopyRawSubBufferToHost( /*reference_held=*/false); auto promise = PjRtFuture::CreatePromise(); - local_device->ThenExecuteCallback( - stream.get(), [promise, free_sub_range = sub_buffer.release(), - free_stream = stream.release(), local_device]() mutable { + auto stream_ptr = stream.get(); + auto callback_status = local_device->ThenExecuteCallback( + stream_ptr, + [promise, free_stream = stream.release(), local_device]() mutable { auto stream = std::unique_ptr(free_stream); - auto sub_range = std::unique_ptr(free_sub_range); local_device->ReturnStreamToPool(std::move(stream)); promise.Set(OkStatus()); }); + if (!callback_status.ok()) { + return PjRtFuture(callback_status); + } return PjRtFuture( std::move(promise), @@ -531,7 +625,7 @@ PjRtFuture StreamExecutorGpuClient::CopyRawSubBufferToHost( }); } -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::Compile(const XlaComputation& computation, CompileOptions options) { auto executable = PjRtStreamExecutorClient::Compile(computation, options); @@ -546,7 +640,7 @@ StreamExecutorGpuClient::Compile(const XlaComputation& computation, se::StreamExecutor* executor = local_device_state->executor(); int device_ordinal = executor->device_ordinal(); if (executor->DeviceMemoryUsage(&free_memory, &total_memory)) { - metrics::RecordFreeGpuSystemMemory(device_ordinal, free_memory); + gpu_metrics::RecordFreeGpuSystemMemory(device_ordinal, free_memory); } else { LOG(ERROR) << "Failed to query available memory for GPU " << device_ordinal; @@ -559,7 +653,7 @@ StreamExecutorGpuClient::Compile(const XlaComputation& computation, namespace { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -StatusOr> FromProto( +absl::StatusOr> FromProto( const StreamExecutorExecutableProto& proto) { TF_ASSIGN_OR_RETURN(CompileOptions compile_options, CompileOptions::FromProto(proto.compile_options())); @@ -573,12 +667,13 @@ StatusOr> FromProto( } return std::make_unique( compile_options, std::move(deserialized_aot_executables), - proto.num_replicas(), proto.num_partitions(), proto.name()); + proto.num_replicas(), proto.num_partitions(), proto.name(), + proto.fingerprint()); } #endif } // namespace -StatusOr> +absl::StatusOr> StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, std::optional options, const LoadOptions& load_options) { @@ -601,8 +696,26 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, return absl::InternalError("LoadSerialized only works with cuda or rocm."); } -StatusOr> StreamExecutorGpuClient::Load( - std::unique_ptr executable) { +absl::StatusOr> +StreamExecutorGpuClient::DeserializeExecutable( + absl::string_view serialized, std::optional options) { + if (serialized.size() > std::numeric_limits::max()) { + return Internal( + "StreamExecutorGpuClient::DeserializeExecutable proto too large " + "(>2GB)"); + } +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + StreamExecutorExecutableProto proto; + if (proto.ParseFromArray(serialized.data(), serialized.size())) { + TF_ASSIGN_OR_RETURN(auto se_executable, FromProto(proto)); + return Load(std::move(se_executable)); + } +#endif + return PjRtStreamExecutorClient::DeserializeExecutable(serialized, options); +} + +absl::StatusOr> +StreamExecutorGpuClient::Load(std::unique_ptr executable) { auto se_executable = absl::WrapUnique( tensorflow::down_cast(executable.release())); @@ -639,12 +752,13 @@ namespace { #if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 -StatusOr> CreateCudaAsyncAllocator( +absl::StatusOr> +CreateCudaAsyncAllocator( se::Platform* platform, const std::map>& addressable_devices, double memory_fraction, bool preallocate) { CHECK_GT(addressable_devices.size(), 0); - std::vector allocators; + std::vector allocators; for (auto& ordinal_and_device : addressable_devices) { se::StreamExecutor* executor = ordinal_and_device.second->executor(); @@ -678,15 +792,16 @@ StatusOr> CreateCudaAsyncAllocator( ->platform_specific_handle() .stream); allocators.emplace_back(std::move(allocator), - ordinal_and_device.second->compute_stream()); + ordinal_and_device.second->compute_stream(), + /*memory_space=*/0); } - return std::make_unique(platform, - std::move(allocators)); + return allocators; } #else // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 -StatusOr> CreateCudaAsyncAllocator( +absl::StatusOr> +CreateCudaAsyncAllocator( se::Platform* platform, const std::map>& addressable_devices, double memory_fraction, bool preallocate) { @@ -696,7 +811,7 @@ StatusOr> CreateCudaAsyncAllocator( #endif // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020 // Builds a LocalDeviceState for each GPU present. -StatusOr>> +absl::StatusOr>> BuildLocalDeviceStates(LocalClient* xla_client) { std::map> addressable_devices; for (se::StreamExecutor* executor : @@ -713,65 +828,93 @@ BuildLocalDeviceStates(LocalClient* xla_client) { // Constructs a GPU device memory allocator to use, according to the allocator // configuration the client requested. -StatusOr> +absl::StatusOr> GetStreamExecutorGpuDeviceAllocator( se::Platform* platform, const GpuAllocatorConfig& allocator_config, const std::map>& addressable_devices) { - std::unique_ptr allocator; + std::vector allocators; switch (allocator_config.kind) { case GpuAllocatorConfig::Kind::kCudaAsync: { - auto allocator_or = CreateCudaAsyncAllocator( + auto allocators_or = CreateCudaAsyncAllocator( platform, addressable_devices, allocator_config.memory_fraction, allocator_config.preallocate); - if (allocator_or.ok()) { + if (allocators_or.ok()) { LOG(INFO) << "Using CUDA async allocator."; - allocator = std::move(allocator_or.value()); + allocators = std::move(allocators_or.value()); break; } LOG(ERROR) << "Failed to initialize CUDA async allocator: " - << allocator_or.status() << "; falling back to BFC."; + << allocators_or.status() << "; falling back to BFC."; [[fallthrough]]; } case GpuAllocatorConfig::Kind::kDefault: case GpuAllocatorConfig::Kind::kBFC: { LOG(INFO) << "Using BFC allocator."; - std::vector - allocators_and_streams; for (const auto& ordinal_and_device : addressable_devices) { TF_ASSIGN_OR_RETURN( auto bfc_allocator, CreateBFCAllocator(ordinal_and_device.second->executor(), allocator_config.memory_fraction, allocator_config.preallocate, - allocator_config.garbage_collection)); - allocators_and_streams.emplace_back( - std::move(bfc_allocator), - ordinal_and_device.second->compute_stream()); + allocator_config.gpu_system_memory_size)); + allocators.emplace_back(std::move(bfc_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/0); } - allocator = std::make_unique( - platform, std::move(allocators_and_streams)); break; } case GpuAllocatorConfig::Kind::kPlatform: LOG(INFO) << "Using platform allocator."; - break; + if (allocator_config.collective_memory_size != 0) { + LOG(WARNING) + << "collective_memory_size is non-zero, but allocator kind is set " + "to \"platform\". Collective memory will not be allocated."; + } + // Returning null will cause the client to use the default backend + // allocator. + return nullptr; } - return std::move(allocator); + + // Add any additional allocators for alternate memory spaces. + for (const auto& ordinal_and_device : addressable_devices) { + TF_ASSIGN_OR_RETURN( + auto collective_bfc_allocator, + CreateCollectiveBFCAllocator( + ordinal_and_device.second->executor(), + /*memory_fraction=*/1.0 - allocator_config.memory_fraction, + allocator_config.collective_memory_size)); + allocators.emplace_back(std::move(collective_bfc_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/1); + } + + for (const auto& ordinal_and_device : addressable_devices) { + auto host_allocator = + GetGpuHostAllocator(ordinal_and_device.second->executor()); + allocators.emplace_back(std::move(host_allocator), + ordinal_and_device.second->compute_stream(), + /*memory_space=*/ + static_cast(se::MemoryType::kHost)); + } + + return std::make_unique(platform, + std::move(allocators)); } +} // namespace + Status BuildDistributedDevices( std::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, std::vector>* devices, gpu::GpuExecutableRunOptions* gpu_executable_run_options, - const PjRtClient::KeyValueGetCallback& kv_get, - const PjRtClient::KeyValuePutCallback& kv_put, - absl::Duration get_local_topology_timeout = absl::Minutes(2), - absl::Duration get_global_topology_timeout = absl::Minutes(5)) { + std::shared_ptr kv_store, bool enable_mock_nccl, + absl::Duration get_local_topology_timeout, + absl::Duration get_global_topology_timeout) { LocalTopologyProto local_topology; local_topology.set_node_id(node_id); std::string boot_id_str; @@ -792,13 +935,24 @@ Status BuildDistributedDevices( device_proto->set_local_device_ordinal(ordinal_and_device.first); device_proto->set_name(desc->name()); device_proto->set_vendor(desc->device_vendor()); + device_proto->set_compute_capability( + MakeComputeCapabilityString(desc.get())); + device_proto->set_core_count(desc->core_count()); } GlobalTopologyProto global_topology; - TF_RETURN_IF_ERROR(ExchangeTopologies( - platform_name, node_id, num_nodes, get_local_topology_timeout, - get_global_topology_timeout, kv_get, kv_put, local_topology, - &global_topology)); + if (enable_mock_nccl) { + std::vector local_topologies(num_nodes, local_topology); + for (int i = 0; i < num_nodes; ++i) { + local_topologies[i].set_node_id(i); + } + global_topology = BuildGlobalTopology(absl::MakeSpan(local_topologies)); + } else { + TF_RETURN_IF_ERROR(ExchangeTopologies( + platform_name, node_id, num_nodes, get_local_topology_timeout, + get_global_topology_timeout, kv_store.get(), local_topology, + &global_topology)); + } std::map gpu_device_ids; absl::flat_hash_map device_to_node; @@ -817,8 +971,9 @@ Status BuildDistributedDevices( } auto device = std::make_unique( device_proto.global_device_id(), std::move(local_device), - device_proto.name(), device_proto.vendor(), node.node_id(), - device_proto.slice_index()); + device_proto.name(), device_proto.vendor(), + device_proto.compute_capability(), device_proto.core_count(), + node.node_id(), device_proto.slice_index()); devices->push_back(std::move(device)); } } @@ -829,9 +984,9 @@ Status BuildDistributedDevices( std::move(gpu_device_ids)); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (num_nodes > 1) { - auto nccl_id_store = - std::make_shared(node_id, device_to_node, kv_get, kv_put); - gpu_executable_run_options->set_nccl_unique_id_callback( + auto nccl_id_store = std::make_shared(node_id, device_to_node, + std::move(kv_store)); + gpu_executable_run_options->set_nccl_clique_id_callback( [nccl_id_store](const gpu::NcclCliqueKey& key) { return nccl_id_store->GetNcclUniqueId(key); }); @@ -840,23 +995,49 @@ Status BuildDistributedDevices( return OkStatus(); } -} // namespace +std::string MakeComputeCapabilityString(const se::DeviceDescription* desc) { + se::GpuComputeCapability cc = desc->gpu_compute_capability(); + if (std::holds_alternative(cc)) { + auto nvcc = std::get(cc); + return absl::StrCat(nvcc.major, ".", nvcc.minor); + } else if (std::holds_alternative(cc)) { + auto rocmcc = std::get(cc); + return rocmcc.gfx_version(); + } else { + return "unknown"; + } +} StreamExecutorGpuDevice::StreamExecutorGpuDevice( int id, std::unique_ptr local_device_state, - std::string device_kind, std::string device_vendor, int node_id, + std::string device_kind, std::string device_vendor, + std::string compute_capability, int core_count, int node_id, int slice_index) : PjRtStreamExecutorDevice(id, std::move(local_device_state), std::move(device_kind), node_id), device_vendor_(std::move(device_vendor)), slice_index_(slice_index) { - description().SetAttributes({ - {"device_vendor", device_vendor_}, - {"slice_index", static_cast(slice_index)}, - }); + int64_t core_index = 0; + description().SetCoreOnChip(core_index); + std::array coords = {local_device_id().value()}; + description().SetCoords(coords); + std::vector v_coords(description().coords().begin(), + description().coords().end()); + + description().SetAttributes( + {{"coords", xla::PjRtDeviceAttribute(v_coords)}, + {"core_on_chip", xla::PjRtDeviceAttribute(core_index)}, + {"device_vendor", device_vendor_}, + {"slice_index", static_cast(slice_index)}, + {"compute_capability", xla::PjRtDeviceAttribute(compute_capability)}, + {"core_count", static_cast(core_count)}}); description().SetToString(absl::StrFormat( - "StreamExecutorGpuDevice(id=%i, process_index=%i, slice_index=%i)", id, - process_index(), slice_index)); + "StreamExecutorGpuDevice(device_kind=%s, id=%i, process_index=%i, " + "slice_index=%i))", + description().device_kind(), id, process_index(), slice_index)); + description().SetDebugString(absl::StrFormat("%s_%i(process=%i,(%i))", + description().device_kind(), id, + process_index(), v_coords[0])); } int StreamExecutorGpuDevice::slice_index() const { return slice_index_; } @@ -872,22 +1053,36 @@ absl::StatusOr StreamExecutorGpuDevice::GetAllocatorStats() "GetAllocatorStats() is allowed only for addressable devices"); } - TF_ASSIGN_OR_RETURN( - auto allocator, - tensorflow::down_cast( - tensorflow::down_cast(client()) - ->allocator()) - ->GetAllocator(local_hardware_id())); + auto* allocator_adapter = dynamic_cast( + tensorflow::down_cast(client())->allocator()); + if (!allocator_adapter) { + return Unimplemented( + "GetAllocatorStats() is only implemented with MultiDeviceAdapter " + "allocator"); + } + + TF_ASSIGN_OR_RETURN(auto allocator, allocator_adapter->GetAllocator( + local_device_id().value())); auto stats = allocator->GetStats(); TF_RET_CHECK(stats.has_value()); return stats.value(); } -StatusOr> GetStreamExecutorGpuClient( +absl::Span StreamExecutorGpuDevice::coords() const { + return description().coords(); +} + +int StreamExecutorGpuDevice::core_on_chip() const { + return description().core_on_chip(); +} + +absl::StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options) { #if TENSORFLOW_USE_ROCM auto pjrt_platform_name = xla::RocmName(); +#elif TENSORFLOW_USE_SYCL + auto pjrt_platform_name = xla::SyclName(); #else // TENSORFLOW_USE_ROCM auto pjrt_platform_name = xla::CudaName(); #endif // TENSORFLOW_USE_ROCM @@ -910,52 +1105,15 @@ StatusOr> GetStreamExecutorGpuClient( if (options.enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); } - absl::flat_hash_map device_maps; - absl::Mutex mu; - PjRtClient::KeyValueGetCallback kv_get = options.kv_get; - PjRtClient::KeyValuePutCallback kv_put = options.kv_put; + std::shared_ptr kv_store = options.kv_store; if (options.enable_mock_nccl) { - kv_get = [&device_maps, &mu, &options]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - std::string result; - { - absl::MutexLock lock(&mu); - if (device_maps.contains(k)) { - result = device_maps[k]; - } else { - int device_id; - std::vector tokens = absl::StrSplit(k, ':'); - if (tokens.size() != 2 || !absl::SimpleAtoi(tokens[1], &device_id)) { - device_id = options.num_nodes - 1; - } - // Return fake local topology with device_id info back. - xla::LocalTopologyProto local; - local.set_boot_id("fake_boot_id"); - local.set_node_id(device_id); - xla::DeviceProto* device = local.add_devices(); - device->set_global_device_id(device_id); - device->set_name("fake_device"); - device->set_vendor("fake_vendor"); - result = local.SerializeAsString(); - } - } - return result; - }; - kv_put = [&device_maps, &mu](std::string_view k, - std::string_view v) -> xla::Status { - { - absl::MutexLock lock(&mu); - device_maps[k] = v; - } - return xla::OkStatus(); - }; + kv_store = std::make_shared(); } - TF_RET_CHECK(options.num_nodes == 1 || kv_get != nullptr); - TF_RET_CHECK(options.num_nodes == 1 || kv_put != nullptr); + TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr); TF_RETURN_IF_ERROR(BuildDistributedDevices( pjrt_platform_name, std::move(local_device_states), options.node_id, - options.num_nodes, &devices, gpu_run_options.get(), kv_get, kv_put)); + options.num_nodes, &devices, gpu_run_options.get(), kv_store, + options.enable_mock_nccl)); return std::unique_ptr(std::make_unique( pjrt_platform_name, xla_client, std::move(devices), options.node_id, @@ -973,16 +1131,23 @@ absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() return result; } +absl::StatusOr StreamExecutorGpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + std::vector> BuildLocalDevices( std::map> local_device_states, int node_id) { std::vector> devices; for (auto& ordinal_and_device : local_device_states) { - const se::DeviceDescription& description = + const se::DeviceDescription& desc = ordinal_and_device.second->executor()->GetDeviceDescription(); auto device = std::make_unique( ordinal_and_device.first, std::move(ordinal_and_device.second), - description.name(), description.device_vendor(), node_id); + desc.name(), desc.device_vendor(), MakeComputeCapabilityString(&desc), + desc.core_count(), node_id); devices.push_back(std::move(device)); } return devices; diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index b4fcb20f5278b..39c9327683d2e 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,7 +25,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/pjrt_client.h" @@ -58,14 +60,17 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { } // `gpu_device_ids` is the list of logical device ids for the GPU devices and // will be used to initialize the GPU topology. - StreamExecutorGpuTopologyDescription(const PjRtPlatformId platform_id, - const absl::string_view platform_name, - const absl::string_view platform_version, - const std::vector& gpu_device_ids) + StreamExecutorGpuTopologyDescription( + const PjRtPlatformId platform_id, const absl::string_view platform_name, + const absl::string_view platform_version, + const std::vector& gpu_device_ids, + const absl::flat_hash_map& attributes = + {}) : platform_id_(platform_id), platform_name_(platform_name), platform_version_(platform_version), - gpu_topology_(gpu_device_ids) {} + gpu_topology_(gpu_device_ids), + attributes_(attributes) {} bool operator==(const StreamExecutorGpuTopologyDescription& other) const { return this->platform_id() == other.platform_id() && @@ -127,6 +132,10 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + private: const PjRtPlatformId platform_id_; const std::string platform_name_; @@ -140,6 +149,7 @@ class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice { StreamExecutorGpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, std::string device_vendor, + std::string compute_capability, int core_count, int node_id, int slice_index = 0); int slice_index() const; @@ -148,6 +158,10 @@ class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice { absl::StatusOr GetAllocatorStats() const override; + absl::Span coords() const; + + int core_on_chip() const; + private: std::string device_vendor_; int slice_index_; @@ -173,12 +187,12 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { tsl::Fingerprint64(platform_name), platform_name, devices_.back()->device_kind(), devices_)) {} - xla::StatusOr GetDefaultDeviceAssignment( + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; absl::string_view platform_version() const override; - StatusOr> + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; @@ -186,13 +200,13 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { int64_t offset, int64_t transfer_size) override; - StatusOr GetTopologyDescription() + absl::StatusOr GetTopologyDescription() const override { return &topology_; } // TODO(b/285385306): Enable loading a non-loaded PjRtExecutable. - StatusOr> Load( + absl::StatusOr> Load( std::unique_ptr executable, const LoadOptions& load_options) override { return absl::WrapUnique( @@ -201,16 +215,20 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { // TODO(b/296466237): Unify `Load` method after (de)serialization and tests on // existing use cases are done. - StatusOr> Load( + absl::StatusOr> Load( std::unique_ptr executable); // TODO(b/296466237): Unify `LoadSerializedExecutable` after fixing existing // tests. - StatusOr> LoadSerialized( + absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); - StatusOr> Compile( + absl::StatusOr> DeserializeExecutable( + absl::string_view serialized, + std::optional options) override; + + absl::StatusOr> Compile( const XlaComputation& computation, CompileOptions options) override; private: @@ -221,6 +239,18 @@ std::vector> BuildLocalDevices( std::map> local_device_states, int node_id); +std::string MakeComputeCapabilityString(const se::DeviceDescription* desc); + +Status BuildDistributedDevices( + std::string_view platform_name, + std::map> local_device_states, + int node_id, int num_nodes, + std::vector>* devices, + gpu::GpuExecutableRunOptions* gpu_executable_run_options, + std::shared_ptr kv_store, bool enable_mock_nccl, + absl::Duration get_local_topology_timeout = absl::Minutes(2), + absl::Duration get_global_topology_timeout = absl::Minutes(5)); + struct GpuClientOptions { GpuAllocatorConfig allocator_config; @@ -234,16 +264,13 @@ struct GpuClientOptions { bool should_stage_host_to_device_transfers = true; - // `kv_get` and `kv_put` are callbacks provided by the caller to access a - // key-value store shared between nodes. `kv_get` and `kv_put` must be - // non-null if `num_nodes` > 1. - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; + // kv_store must be non-null if num_nodes > 1. + std::shared_ptr kv_store = nullptr; bool enable_mock_nccl = false; }; -StatusOr> GetStreamExecutorGpuClient( +absl::StatusOr> GetStreamExecutorGpuClient( const GpuClientOptions& options); } // namespace xla diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 4361139b06c93..e7aac5c63b354 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,13 +28,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" +#include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/utils.h" #include "xla/service/hlo_parser.h" #include "xla/statusor.h" @@ -53,7 +56,7 @@ using ::testing::ElementsAre; using ::testing::HasSubstr; using ::tsl::testing::StatusIs; -StatusOr> CompileExecutable( +absl::StatusOr> CompileExecutable( absl::string_view program, xla::PjRtClient& client, xla::CompileOptions compile_options = xla::CompileOptions()) { TF_ASSIGN_OR_RETURN(auto hlo_module, @@ -65,8 +68,8 @@ StatusOr> CompileExecutable( // Given the result of a PjrtExecutable::Execute call (TF-status of vectors of // vectors), extract the zeroth result from the zeroth device. -StatusOr> ExtractSingleResult( - xla::StatusOr>>>& +absl::StatusOr> ExtractSingleResult( + absl::StatusOr>>>& result) { TF_RETURN_IF_ERROR(result.status()); TF_RET_CHECK(result->size() == 1); @@ -168,7 +171,7 @@ TEST(StreamExecutorGpuClientTest, SendErrorNoDeadLock) { SendCallback send_callback = { /*channel_id=*/1, [&](const PjRtTransferMetadata&, PjRtChunk, int64_t, bool) { - return InternalError("Uh-oh, can send chunk to host"); + return Internal("Uh-oh, can send chunk to host"); }}; // No-op Recv handler. @@ -248,7 +251,7 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { TF_ASSERT_OK( transfer_manager->TransferLiteralToBuffer(0, src_literal, [&]() {})); - buffer->ToLiteral(literal.get(), [&](Status s) { + buffer->ToLiteral(literal.get()).OnReady([&](Status s) { absl::MutexLock l(&mu); TF_ASSERT_OK(s); got_literal = true; @@ -282,7 +285,7 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); bool got_literal = false; - buffer->ToLiteral(literal.get(), [&](Status s) { + buffer->ToLiteral(literal.get()).OnReady([&](Status s) { absl::MutexLock l(&mu); TF_ASSERT_OK(s); got_literal = true; @@ -326,11 +329,6 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { buffers.emplace_back(transfer_manager->RetrieveBuffer(i)); } - absl::Mutex mu; - std::vector> literals; - int got_literal_count = 0; - int got_callback_count = 0; - for (int i = 0; i < src_shapes.size(); ++i) { TF_ASSERT_OK(transfer_manager->TransferRawDataToBuffer( i, @@ -339,15 +337,20 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { [&]() {})); } + absl::Mutex mu; + std::vector> literals; + int got_literal_count = 0; + int got_callback_count = 0; + for (auto& buffer : buffers) { literals.push_back(std::make_shared( ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()))); - buffer->ToLiteral(literals.back().get(), [&](Status s) { + buffer->ToLiteral(literals.back().get()).OnReady([&](Status s) { absl::MutexLock l(&mu); TF_ASSERT_OK(s); ++got_literal_count; }); - buffer->OnReady([&](Status s) { + buffer->GetReadyFuture().OnReady([&](Status s) { absl::MutexLock l(&mu); TF_ASSERT_OK(s); ++got_callback_count; @@ -425,6 +428,32 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostOutOfRange) { free(dst); } +TEST(StreamExecutorGpuClientTest, CopyRawToHostFuture) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + auto literal = xla::LiteralUtil::CreateR1({41.0f, 42.0f}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr buffer, + client->BufferFromHostLiteral(literal, client->addressable_devices()[0])); + + auto dst_promise = xla::PjRtFuture>::CreatePromise(); + xla::PjRtFuture> dst_future(dst_promise); + + TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes()); + buffer->GetReadyFuture().OnReady([dst_promise = std::move(dst_promise), + size](absl::Status status) mutable { + dst_promise.Set(aligned_alloc(size, 0)); + }); + + auto result = buffer->CopyRawToHostFuture(dst_future, 0, size); + TF_EXPECT_OK(result.Await()); + TF_ASSERT_OK_AND_ASSIGN(auto* dst, dst_future.Await()); + EXPECT_EQ(*(static_cast(dst)), 41.0f); + EXPECT_EQ(*(static_cast(dst) + 1), 42.0f); + + free(dst); +} + TEST(StreamExecutorGpuClientTest, AsyncCopyToDevice) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -470,9 +499,9 @@ TEST(StreamExecutorGpuClientTest, CreateMixOfErrorBuffers) { src_literals.emplace_back(LiteralUtil::CreateR1(data)); src_shapes.push_back(src_literals.back().shape()); } - ASSERT_OK_AND_ASSIGN(auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice( - src_shapes, client->addressable_devices()[0])); + TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + src_shapes, client->addressable_devices()[0])); std::vector> buffers; for (int i = 0; i < src_shapes.size(); ++i) { buffers.emplace_back(transfer_manager->RetrieveBuffer(i)); @@ -483,21 +512,22 @@ TEST(StreamExecutorGpuClientTest, CreateMixOfErrorBuffers) { for (int i = 0; i < 4; ++i) { auto& buffer = buffers[i]; if (i == 0 || i == 3) { - ASSERT_OK(transfer_manager->TransferLiteralToBuffer(i, src_literals[i], - [&]() {})); - buffer->OnReady([&](absl::Status s) { + TF_ASSERT_OK(transfer_manager->TransferLiteralToBuffer(i, src_literals[i], + [&]() {})); + buffer->GetReadyFuture().OnReady([&](absl::Status s) { absl::MutexLock l(&mu); - ASSERT_OK(s); + TF_ASSERT_OK(s); ++got_callback_count; }); } else { - absl::Status error = InternalError("error %d", i); + absl::Status error = Internal("error %d", i); transfer_manager->SetBufferError(i, error); - buffer->OnReady([error, &mu, &got_callback_count](absl::Status s) { - absl::MutexLock l(&mu); - ASSERT_EQ(s, error); - ++got_callback_count; - }); + buffer->GetReadyFuture().OnReady( + [error, &mu, &got_callback_count](absl::Status s) { + absl::MutexLock l(&mu); + ASSERT_EQ(s, error); + ++got_callback_count; + }); } buffer.reset(); } @@ -527,46 +557,17 @@ TEST(GpuTopology, ToProto) { EXPECT_THAT(msg.device_ids(), ElementsAre(3, 2, 1)); } -TEST(StreamExecutorGpuClientTest, DistributeInit) { - absl::flat_hash_map kv_store; - absl::Mutex mu; - PjRtClient::KeyValueGetCallback kv_get = - [&kv_store, &mu](std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - absl::Duration wait_interval = absl::Milliseconds(10); - int num_retry = timeout / wait_interval; - for (int i = 0; i < num_retry; i++) { - { - absl::MutexLock lock(&mu); - auto iter = kv_store.find(k); - if (iter != kv_store.end()) { - return iter->second; - } - } - absl::SleepFor(wait_interval); - } - return absl::NotFoundError( - absl::StrCat(k, " is not found in the kv store.")); - }; - PjRtClient::KeyValuePutCallback kv_put = - [&kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { - { - absl::MutexLock lock(&mu); - kv_store[k] = v; - } - return tsl::OkStatus(); - }; - +TEST(StreamExecutorGpuClientTest, DistributedInit) { + auto kv_store = std::make_shared(); tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "DistributeInit", 4); int num_nodes = 2; for (int i = 0; i < num_nodes; i++) { - thread_pool.Schedule([&, i] { + thread_pool.Schedule([kv_store, i, num_nodes] { GpuClientOptions options; options.node_id = i; options.num_nodes = num_nodes; - options.kv_get = kv_get; - options.kv_put = kv_put; + options.kv_store = kv_store; TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(options)); EXPECT_TRUE(client->platform_name() == "cuda" || client->platform_name() == "rocm"); @@ -592,5 +593,22 @@ TEST(StreamExecutorGpuClientTest, GetAllocatorStatsTest) { } } +TEST(StreamExecutorGpuClientTest, GpuDeviceDescriptionTest) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + for (int device_index = 0; device_index < client->device_count(); + device_index++) { + auto coords = + static_cast(client->devices()[device_index]) + ->description() + .coords(); + EXPECT_EQ(coords[0], device_index); + } + EXPECT_EQ(static_cast(client->devices()[0]) + ->description() + .core_on_chip(), + 0); +} + } // namespace } // namespace xla diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 36b5e43e542d6..ebb68c2a483f9 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/status_macros.h" +#include "xla/stream_executor/platform/initialize.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" @@ -41,20 +42,22 @@ limitations under the License. #include "xla/service/hlo_proto_util.h" #include "xla/service/local_service.h" #include "xla/service/local_service_utils.h" -#include "xla/stream_executor/cuda/cuda_platform_id.h" #endif #if GOOGLE_CUDA #include "xla/service/gpu/nvptx_compiler.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" #elif TENSORFLOW_USE_ROCM #include "xla/service/gpu/amdgpu_compiler.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" #endif namespace xla { namespace { bool IsGpuClient(const PjRtClient& client) { - return client.platform_id() == CudaId() || client.platform_id() == RocmId(); + return client.platform_id() == CudaId() || client.platform_id() == RocmId() || + client.platform_id() == SyclId(); } bool IsSameTopology(const PjRtTopologyDescription& topology1, @@ -104,13 +107,24 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, CompileOptions input_options = options; if (!options.target_config) { - if (!client) { + if (client != nullptr) { + TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); + return client->Compile(computation, options); + } + auto attr = topology.Attributes(); + if (auto it = attr.find("target_config"); it != attr.end()) { + auto target_config_str = std::get(it->second); + stream_executor::GpuTargetConfigProto gpu_target_config_proto; + if (!gpu_target_config_proto.ParseFromString(target_config_str)) { + return FailedPrecondition("Failed to parse GpuTargetConfigProto"); + } + options.target_config.emplace( + Compiler::TargetConfig(gpu_target_config_proto)); + } else { return absl::UnimplementedError( "Compilation without client and without target_config specified is " "not implemented"); } - TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client)); - return client->Compile(computation, options); } TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); std::vector argument_layout_pointers; @@ -135,18 +149,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, Compiler::CompileOptions opts; opts.target_config = options.target_config; - if (!options.executable_build_options.run_backend_only()) { - TF_ASSIGN_OR_RETURN( - hlo_module, gpu_compiler.RunHloPasses(std::move(hlo_module), - /*stream_exec=*/nullptr, opts)); - } - AotCompilationOptions aot_options(gpu_compiler.PlatformId()); aot_options.set_target_config(*options.target_config); + aot_options.set_run_backend_only( + options.executable_build_options.run_backend_only()); const int num_replicas = hlo_module->config().replica_count(); const int num_partitions = hlo_module->config().num_partitions(); const std::string name = hlo_module->name(); + const std::string fingerprint = hlo_module->GetFingerprint128(); auto unique_module_group = std::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( @@ -155,7 +166,7 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, aot_options)); return std::make_unique( std::move(input_options), std::move(aot_results), num_replicas, - num_partitions, name); + num_partitions, name, fingerprint); #else return absl::InternalError( "GPU Compilation requires the target to be built with CUDA or " @@ -183,8 +194,13 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, #endif } -REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { - PjRtRegisterCompiler(CudaName(), - std::make_unique()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { + PjRtRegisterCompiler( +#if TENSORFLOW_USE_ROCM + RocmName(), +#else + CudaName(), +#endif + std::make_unique()); }); } // namespace xla diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler.h b/xla/pjrt/gpu/se_gpu_pjrt_compiler.h index 17850fb1cb741..facd2dd5dd0bf 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 7daa467f8e6d4..3d0c2d0319ed6 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,11 @@ limitations under the License. #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/service/compiler.h" +#if GOOGLE_CUDA #include "xla/service/gpu/nvptx_compiler.h" +#elif TENSORFLOW_USE_ROCM +#include "xla/service/gpu/amdgpu_compiler.h" +#endif #include "xla/service/hlo_parser.h" #include "xla/tests/literal_test_util.h" #include "tsl/platform/casts.h" @@ -116,7 +120,11 @@ TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast(client.release())); +#if GOOGLE_CUDA auto gpu_compiler = gpu::NVPTXCompiler(); +#elif TENSORFLOW_USE_ROCM + auto gpu_compiler = gpu::AMDGPUCompiler(); +#endif Compiler::TargetConfig gpu_target_config{ se_client->client()->backend().default_stream_executor()}; StreamExecutorGpuCompiler compiler; diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index 69013b441b9f7..70a47c2400c43 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/host_callback.cc b/xla/pjrt/host_callback.cc index 938220a880f08..e59eeb6f1f815 100644 --- a/xla/pjrt/host_callback.cc +++ b/xla/pjrt/host_callback.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,13 @@ limitations under the License. namespace xla { +static thread_local int on_send_guard = 0; + +void EnterHostCallback() { ++on_send_guard; } +void LeaveHostCallback() { --on_send_guard; } + +bool ThisThreadIsInsideHostCallback() { return on_send_guard > 0; } + Status HostCallbackContext::OnSend(int arg_num, const PjRtTransferMetadata& metadata, PjRtChunk data) { @@ -72,7 +79,10 @@ Status HostCallbackContext::OnSend(int arg_num, result_ptrs.push_back(results.back().data()); } + EnterHostCallback(); auto status = host_callback_.callback(result_ptrs.data(), arg_ptrs.data()); + LeaveHostCallback(); + // TODO(chky): Consider populating garbage data in results upon errors. // Clear the arguments for this invocation. This won't race with next diff --git a/xla/pjrt/host_callback.h b/xla/pjrt/host_callback.h index 4a9b44874d895..58788c0ee8795 100644 --- a/xla/pjrt/host_callback.h +++ b/xla/pjrt/host_callback.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,6 +34,12 @@ limitations under the License. namespace xla { +bool ThisThreadIsInsideHostCallback(); + +void EnterHostCallback(); + +void LeaveHostCallback(); + // A thread-safe queue for passing PjRtChunk objects for e.g. from Send ops to // Recv ops. class ThreadSafePjRtChunkQueue { diff --git a/xla/pjrt/host_callback_test.cc b/xla/pjrt/host_callback_test.cc index 38ddbf043bb98..a1125a601736a 100644 --- a/xla/pjrt/host_callback_test.cc +++ b/xla/pjrt/host_callback_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/host_memory_spaces.cc b/xla/pjrt/host_memory_spaces.cc new file mode 100644 index 0000000000000..0273ec6da85b4 --- /dev/null +++ b/xla/pjrt/host_memory_spaces.cc @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/host_memory_spaces.h" + +#include "absl/strings/str_format.h" +#include "xla/pjrt/pjrt_client.h" + +namespace xla { + +UnpinnedHostMemorySpace::UnpinnedHostMemorySpace(int id, PjRtClient* client) + : id_(id), client_(client) { + debug_string_ = absl::StrFormat( + "UnpinnedHostMemorySpace(id=%i, process_index=%i, client=%s)", id_, + client_->process_index(), client_->platform_name()); + to_string_ = absl::StrFormat("UNPINNED_HOST_%i", id_); +} + +PinnedHostMemorySpace::PinnedHostMemorySpace(int id, PjRtClient* client) + : id_(id), client_(client) { + debug_string_ = + absl::StrFormat("PinnedHostMemory(id=%i, process_index=%i, client=%s)", + id_, client_->process_index(), client_->platform_name()); + to_string_ = absl::StrFormat("PINNED_HOST_%i", id_); +} + +} // namespace xla diff --git a/xla/pjrt/host_memory_spaces.h b/xla/pjrt/host_memory_spaces.h new file mode 100644 index 0000000000000..cc51eea4391bc --- /dev/null +++ b/xla/pjrt/host_memory_spaces.h @@ -0,0 +1,99 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_HOST_MEMORY_SPACES_H_ +#define XLA_PJRT_HOST_MEMORY_SPACES_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_client.h" + +namespace xla { + +// Represents the unpinned host memory accessible to a `PjRtDevice`. +// An "unpinned" host memory space accommodates ordinary host buffers that are +// not mapped to any virtual memory of the attached `PjRtDevice`. +class UnpinnedHostMemorySpace : public PjRtMemorySpace { + public: + static constexpr absl::string_view kMemorySpaceKind = "unpinned_host"; + + UnpinnedHostMemorySpace(int id, PjRtClient* client); + + PjRtClient* client() const override { return client_; } + + absl::Span devices() const override { return devices_; } + + int id() const override { return id_; } + + absl::string_view memory_space_kind() const override { + return kMemorySpaceKind; + } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + void AttachDevice(PjRtDevice* device) { devices_.push_back(device); } + + private: + int id_; + PjRtClient* client_; + std::vector devices_; + std::string debug_string_; + std::string to_string_; +}; + +// Represents the pinned host memory accessible to a `PjRtDevice`. +// A "pinned" host memory space accommodates host buffers that are mapped to a +// virtual memory of the attached `PjRtDevice`. The `PjRtDevice` may have the +// capability to direct-memory-access (DMA) the buffers in this memory space. +class PinnedHostMemorySpace : public PjRtMemorySpace { + public: + static constexpr absl::string_view kMemorySpaceKind = "pinned_host"; + + PinnedHostMemorySpace(int id, PjRtClient* client); + + PjRtClient* client() const override { return client_; } + + absl::Span devices() const override { + return absl::Span(&device_, device_ != nullptr ? 1 : 0); + } + + int id() const override { return id_; } + + absl::string_view memory_space_kind() const override { + return kMemorySpaceKind; + } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + void AttachDevice(PjRtDevice* device) { device_ = device; } + + private: + int id_; + PjRtClient* client_ = nullptr; + PjRtDevice* device_ = nullptr; + std::string debug_string_; + std::string to_string_; +}; + +} // namespace xla + +#endif // XLA_PJRT_HOST_MEMORY_SPACES_H_ diff --git a/xla/pjrt/interpreter_device.cc b/xla/pjrt/interpreter_device.cc index aa9c9eba92a6a..2fd5a71c957d0 100644 --- a/xla/pjrt/interpreter_device.cc +++ b/xla/pjrt/interpreter_device.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/interpreter_device.h b/xla/pjrt/interpreter_device.h index 0d820cb0912ce..1842a816e9eb2 100644 --- a/xla/pjrt/interpreter_device.h +++ b/xla/pjrt/interpreter_device.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/layout_mode.cc b/xla/pjrt/layout_mode.cc index 458144d50865c..864a03799364d 100644 --- a/xla/pjrt/layout_mode.cc +++ b/xla/pjrt/layout_mode.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/layout_mode.h b/xla/pjrt/layout_mode.h index 792bc3de9f9a1..156932a2f71c4 100644 --- a/xla/pjrt/layout_mode.h +++ b/xla/pjrt/layout_mode.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/local_device_state.cc b/xla/pjrt/local_device_state.cc index e2d5ebe394910..5c0d2b30e7032 100644 --- a/xla/pjrt/local_device_state.cc +++ b/xla/pjrt/local_device_state.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,9 +22,12 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/protobuf/error_codes.pb.h" @@ -46,8 +49,9 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, prng_seed_generator_(prng_seed_device_()), prng_seed_distribution_(std::numeric_limits::min(), std::numeric_limits::max()) { - device_ordinal_ = - device_ordinal != -1 ? device_ordinal : executor->device_ordinal(); + local_hardware_id_ = executor_->device_ordinal(); + local_device_id_ = + device_ordinal != -1 ? device_ordinal : executor_->device_ordinal(); int num_device_to_host_streams = stream_options.has_value() ? stream_options->num_device_to_host_streams @@ -55,46 +59,30 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, int num_device_to_device_streams = stream_options.has_value() ? stream_options->num_device_to_device_streams : kNumDeviceToDeviceStreams; - compute_stream_ = std::make_unique(executor); - if (stream_options.has_value()) { - compute_stream_->SetPriority(stream_options->priority); - } - host_to_device_stream_ = std::make_unique(executor); - if (stream_options.has_value()) { - host_to_device_stream_->SetPriority(stream_options->priority); - } - compute_stream_->Init(); - host_to_device_stream_->Init(); + auto create_stream = [executor, &stream_options]() { + if (stream_options.has_value()) { + return executor->CreateStream(stream_options->priority).value(); + } else { + return executor->CreateStream().value(); + } + }; + compute_stream_ = create_stream(); + host_to_device_stream_ = create_stream(); if (use_callback_stream) { callback_stream_map_ = absl::flat_hash_map>(); } device_to_host_streams_.reserve(num_device_to_host_streams); for (int i = 0; i < num_device_to_host_streams; ++i) { - auto stream = std::make_unique(executor); - if (stream_options.has_value()) { - stream->SetPriority(stream_options->priority); - } - stream->Init(); - device_to_host_streams_.push_back(std::move(stream)); + device_to_host_streams_.emplace_back(create_stream()); } device_to_device_streams_.reserve(num_device_to_device_streams); for (int i = 0; i < num_device_to_device_streams; ++i) { - auto stream = std::make_unique(executor); - if (stream_options.has_value()) { - stream->SetPriority(stream_options->priority); - } - stream->Init(); - device_to_device_streams_.push_back(std::move(stream)); + device_to_device_streams_.emplace_back(create_stream()); } external_ready_event_streams_.reserve(kNumExternalReadyEventStreams); for (int i = 0; i < kNumExternalReadyEventStreams; ++i) { - auto stream = std::make_unique(executor); - if (stream_options.has_value()) { - stream->SetPriority(stream_options->priority); - } - stream->Init(); - external_ready_event_streams_.push_back(std::move(stream)); + external_ready_event_streams_.emplace_back(create_stream()); } execute_thread_ = std::make_unique(tsl::Env::Default(), "py_xla_execute"); @@ -136,32 +124,31 @@ Status LocalDeviceState::SynchronizeAllActivity() { Status LocalDeviceState::ThenMemcpyDeviceToDevice( se::Stream* transfer_stream, se::Stream* dst_stream, se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { - // The default implementation simply calls ThenMemcpyD2D, and assumes that + // The default implementation simply calls MemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. - transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size()); - return OkStatus(); + return transfer_stream->MemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size()); } -void LocalDeviceState::ThenExecuteCallback(se::Stream* stream, - std::function callback) { +absl::Status LocalDeviceState::ThenExecuteCallback( + se::Stream* stream, std::function callback) { tsl::profiler::TraceMe traceme("ThenExecuteCallback"); if (callback_stream_map_.has_value()) { // Prevent concurrent updates to the callback stream map. absl::MutexLock lock(&callback_stream_map_mu_); auto callback_stream = callback_stream_map_->find(stream); if (callback_stream == callback_stream_map_->end()) { - auto new_stream = std::make_unique(executor_); - new_stream->Init(); + TF_ASSIGN_OR_RETURN(auto new_stream, executor_->CreateStream()); callback_stream = callback_stream_map_->insert({stream, std::move(new_stream)}).first; } - callback_stream->second->ThenWaitFor(stream); + TF_RETURN_IF_ERROR(callback_stream->second->WaitFor(stream)); stream = callback_stream->second.get(); } - stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable { - callback_thread_->Schedule(std::move(callback)); - }); + return stream->DoHostCallback( + [this, callback{std::move(callback)}]() mutable { + callback_thread_->Schedule(std::move(callback)); + }); } se::Stream* LocalDeviceState::GetDeviceToHostStream() { @@ -215,21 +202,23 @@ std::vector LocalDeviceState::GetDeviceToDeviceStreams() { } std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { - absl::MutexLock lock(&mu_); - if (usage_stream_pool_.empty()) { - auto stream = std::make_unique(compute_stream_->parent()); - stream->Init(); - return stream; - } else { - std::unique_ptr stream = std::move(usage_stream_pool_.top()); - usage_stream_pool_.pop(); - auto status = stream->RefreshStatus(); // Can return error::Unimplemented - // Stream may fail with "ABORTED: Bad connection". - if (status.code() != tsl::error::ABORTED) { - CHECK(stream->ok()) << status; + { + absl::MutexLock lock(&stream_pool_mu_); + if (!usage_stream_pool_.empty()) { + std::unique_ptr stream = std::move(usage_stream_pool_.top()); + usage_stream_pool_.pop(); + auto status = stream->RefreshStatus(); // Can return error::Unimplemented + // Stream may fail with "ABORTED: Bad connection". + if (status.code() != tsl::error::ABORTED) { + CHECK(stream->ok()) << status; + } + return stream; } - return stream; } + + // The stream pool is empty, create a new stream. + auto stream = compute_stream_->parent()->CreateStream().value(); + return stream; } void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { @@ -238,7 +227,7 @@ void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { if (status.code() != tsl::error::ABORTED) { CHECK(stream->ok()) << status; } - absl::MutexLock lock(&mu_); + absl::MutexLock lock(&stream_pool_mu_); usage_stream_pool_.push(std::move(stream)); } diff --git a/xla/pjrt/local_device_state.h b/xla/pjrt/local_device_state.h index d299e85aa67bd..73206d4fbd3ab 100644 --- a/xla/pjrt/local_device_state.h +++ b/xla/pjrt/local_device_state.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/client/local_client.h" #include "xla/pjrt/event_pool.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/semaphore.h" #include "xla/pjrt/worker_thread.h" #include "xla/status.h" @@ -97,8 +98,13 @@ class LocalDeviceState { int num_device_to_device_streams = 1; }; - // If asynchronous is false, the host will synchronize to the device after - // each execution or transfer. This is intended for debugging only. + // `device_ordinal` is the logical local device ordinal (returned by + // `local_device_id()`), and it's used to look up an addressable device local + // to a given client. If it is not set (-1 by default), the device's logical + // device ordinal will be the same as its physical device ordinal (returned by + // `local_hardware_id()`). In general, different PJRT devices have different + // logical device ordinals, and several PJRT devices can have the same + // physical device ordinal if they share the same physical device. LocalDeviceState(se::StreamExecutor* executor, LocalClient* client, AllocationModel allocation_model, int max_inflight_computations, bool allow_event_reuse, @@ -108,7 +114,8 @@ class LocalDeviceState { se::StreamExecutor* executor() const { return executor_; } - int device_ordinal() const { return device_ordinal_; } + PjRtLocalDeviceId local_device_id() { return local_device_id_; } + PjRtLocalHardwareId local_hardware_id() { return local_hardware_id_; } LocalClient* client() const { return client_; } @@ -168,7 +175,8 @@ class LocalDeviceState { // runtime and cannot perform GPU operations itself. On GPU, callbacks // execute in a separate thread. // b) ThenDoHostCallback waits for the callback to complete. - void ThenExecuteCallback(se::Stream* stream, std::function callback); + absl::Status ThenExecuteCallback(se::Stream* stream, + std::function callback); // Helpers for releasing values on a worker thread at the tail of a stream on // a worker thread. Copies `object`, and destroys the copy when the tail of @@ -177,8 +185,8 @@ class LocalDeviceState { // device callback, so it is safe if the destructor frees device resource // (e.g., GPU objects). template - void ThenRelease(se::Stream* stream, T&& object) { - ThenExecuteCallback( + absl::Status ThenRelease(se::Stream* stream, T&& object) { + return ThenExecuteCallback( stream, [object = std::forward(object)]() { /* releases object */ }); } @@ -198,7 +206,8 @@ class LocalDeviceState { // stream by the host ahead of the device. Semaphore compute_semaphore_; - int device_ordinal_; + PjRtLocalDeviceId local_device_id_; + PjRtLocalHardwareId local_hardware_id_; se::StreamExecutor* const executor_; LocalClient* const client_; std::unique_ptr compute_stream_; @@ -215,13 +224,15 @@ class LocalDeviceState { int next_device_to_host_stream_ ABSL_GUARDED_BY(mu_) = 0; int next_device_to_device_stream_ ABSL_GUARDED_BY(mu_) = 0; int next_external_ready_event_stream_ ABSL_GUARDED_BY(mu_) = 0; - std::stack> usage_stream_pool_ - ABSL_GUARDED_BY(mu_); std::random_device prng_seed_device_ ABSL_GUARDED_BY(mu_); std::mt19937 prng_seed_generator_ ABSL_GUARDED_BY(mu_); std::uniform_int_distribution<> prng_seed_distribution_ ABSL_GUARDED_BY(mu_); + absl::Mutex stream_pool_mu_; + std::stack> usage_stream_pool_ + ABSL_GUARDED_BY(stream_pool_mu_); + // Callback map pairs callback stream with a device stream and is used for // running short host-side callbacks after device side events, without // preventing the device-side stream from doing useful work. diff --git a/xla/pjrt/lru_cache.h b/xla/pjrt/lru_cache.h index ce3d1933f8c5f..2d8bca88c25ad 100644 --- a/xla/pjrt/lru_cache.h +++ b/xla/pjrt/lru_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -92,6 +92,9 @@ class LRUCache { int Size() const { return entries_.size(); } int Capacity() const { return lru_list_->Capacity(); } + auto begin() const { return entries_.begin(); } + auto end() const { return entries_.end(); } + private: LRUList* lru_list_; diff --git a/xla/pjrt/lru_cache_test.cc b/xla/pjrt/lru_cache_test.cc index 097393a1f7967..1c091bb1188a3 100644 --- a/xla/pjrt/lru_cache_test.cc +++ b/xla/pjrt/lru_cache_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/metrics.cc b/xla/pjrt/metrics.cc index 1a17cc6590e08..add1aac679c98 100644 --- a/xla/pjrt/metrics.cc +++ b/xla/pjrt/metrics.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ limitations under the License. #include -#include "absl/strings/str_cat.h" #include "tsl/lib/monitoring/counter.h" #include "tsl/lib/monitoring/gauge.h" @@ -42,10 +41,6 @@ auto* pjrt_compiler_is_compiling_module = tsl::monitoring::Gauge::New( metrics::kPjrtCompilerCompileModuleMetricName, "Whether the PjRT compiler is compiling modules."); -auto* free_gpu_system_memory = tsl::monitoring::Gauge::New( - metrics::kPjrtCompilerFreeGpuSystemMemoryMetricName, - "Record the free GPU system memory.", "gpu_id"); - } // namespace namespace metrics { @@ -69,15 +64,5 @@ void RecordPjrtCompilerCompileModuleStatus(bool is_compiling) { pjrt_compiler_is_compiling_module->GetCell()->Set(is_compiling); } -void RecordFreeGpuSystemMemory(const int device_ordinal, - const int64_t free_memory) { - free_gpu_system_memory->GetCell(absl::StrCat(device_ordinal)) - ->Set(free_memory); -} - -int64_t GetFreeGpuSystemMemory(int gpu_id) { - return free_gpu_system_memory->GetCell(absl::StrCat(gpu_id))->value(); -} - } // namespace metrics } // namespace xla diff --git a/xla/pjrt/metrics.h b/xla/pjrt/metrics.h index 870e473aac6a8..602886b740ef6 100644 --- a/xla/pjrt/metrics.h +++ b/xla/pjrt/metrics.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,6 @@ inline constexpr absl::string_view kPjrtCompilerCompileComputationMetricName = "/pjrt/compiler/is_compiling_computation"; inline constexpr absl::string_view kPjrtCompilerCompileModuleMetricName = "/pjrt/compiler/is_compiling_module"; -inline constexpr absl::string_view kPjrtCompilerFreeGpuSystemMemoryMetricName = - "/pjrt/compiler/free_gpu_system_memory"; void ReportExecutableEnqueueTime(uint64_t running_time_usecs); @@ -38,12 +36,6 @@ void RecordPjrtCompilerCompileComputationStatus(bool is_compiling); void RecordPjrtCompilerCompileModuleStatus(bool is_compiling); -// TODO(xiangll): Refactor to a more appropriate location. -void RecordFreeGpuSystemMemory(int device_ordinal, int64_t free_memory); - -// TODO(xiangll): Refactor to a more appropriate location. -int64_t GetFreeGpuSystemMemory(int gpu_id); - } // namespace metrics } // namespace xla diff --git a/xla/pjrt/mlir_to_hlo.cc b/xla/pjrt/mlir_to_hlo.cc index 7db5d56d77d6b..664012f46b978 100644 --- a/xla/pjrt/mlir_to_hlo.cc +++ b/xla/pjrt/mlir_to_hlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,43 +15,192 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" +#include +#include +#include #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "xla/mlir/utils/error_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/util.h" namespace xla { +namespace { + +static mlir::Attribute ArrayToElements(mlir::Attribute attr) { + if (auto array = attr.dyn_cast()) { + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(array.size(), array.getElementType()), + array.asArrayRef()); + } + if (auto array = attr.dyn_cast()) { + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(array.size(), array.getElementType()), + array.asArrayRef()); + } + return attr; +} + +static mlir::Attribute ElementsToArray(mlir::Attribute attr) { + if (auto elements = llvm::dyn_cast(attr)) { + if (elements.getElementType().isInteger(64)) { + return mlir::DenseI64ArrayAttr::get( + attr.getContext(), llvm::to_vector(elements.getValues())); + } + return mlir::DenseBoolArrayAttr::get( + attr.getContext(), llvm::to_vector(elements.getValues())); + } + return attr; +} + +static void ConvertAttr( + mlir::Operation* op, llvm::StringRef attr_name, + llvm::function_ref convert) { + if (auto attr = op->getAttr(attr_name)) { + op->setAttr(attr_name, convert(attr)); + } +} + +// Convert attrs that use DenseI64ArrayAttr (or DenseBoolArrayAttr) to use a +// different type of Attribute. For backwards compatibility purposes, arrays +// should be converted to DenseIntElementsAttr right before serialization, and +// converted back right after serialization. Deserialization checks the IR is +// valid by default, so you will need to disable that and do the verification +// explicitly after parsing. +void ConvertStablehloDenseAttributes( + mlir::Operation* root_op, + llvm::function_ref convert, + std::optional plugin_version) { + llvm::TypeSwitch(root_op) + .Case([&](mlir::stablehlo::BroadcastInDimOp op) { + ConvertAttr(op, "broadcast_dimensions", convert); + }) + .Case([&](mlir::stablehlo::ConvolutionOp op) { + ConvertAttr(op, "window_strides", convert); + ConvertAttr(op, "lhs_dilation", convert); + ConvertAttr(op, "rhs_dilation", convert); + ConvertAttr(op, "window_reversal", convert); + }) + .Case([&](mlir::stablehlo::DynamicBroadcastInDimOp op) { + ConvertAttr(op, "broadcast_dimensions", convert); + ConvertAttr(op, "known_expanding_dimensions", convert); + ConvertAttr(op, "known_nonexpanding_dimensions", convert); + }) + .Case([&](mlir::stablehlo::DynamicConvOp op) { + ConvertAttr(op, "window_strides", convert); + ConvertAttr(op, "lhs_dilation", convert); + ConvertAttr(op, "rhs_dilation", convert); + ConvertAttr(op, "window_reversal", convert); + }) + .Case([&](mlir::stablehlo::GatherOp op) { + ConvertAttr(op, "slice_sizes", convert); + }) + .Case([&](mlir::stablehlo::MapOp op) { + ConvertAttr(op, "dimensions", convert); + }) + .Case([&](mlir::stablehlo::ReduceOp op) { + ConvertAttr(op, "dimensions", convert); + }) + .Case([&](mlir::stablehlo::ReduceWindowOp op) { + ConvertAttr(op, "window_dimensions", convert); + ConvertAttr(op, "window_strides", convert); + ConvertAttr(op, "base_dilations", convert); + ConvertAttr(op, "window_dilations", convert); + }) + + .Case([&](mlir::stablehlo::SelectAndScatterOp op) { + ConvertAttr(op, "window_dimensions", convert); + ConvertAttr(op, "window_strides", convert); + }); + + // Use PJRT_API_MINOR 40 from Nov 27, 2023 for Dec 9, 2023 StableHLO changes. + // Always run when plugin_value is unset (used for deserialization upgrades) + // and only run when plugin version is less than 40 otherwise. + if (!plugin_version.has_value() || plugin_version.value() < 40) { + // Downgrade slice, dynamic_slice, pad, broadcast, transpose, fft, reverse + llvm::TypeSwitch(root_op) + .Case([&](mlir::stablehlo::BroadcastOp op) { + ConvertAttr(op, "broadcast_sizes", convert); + }) + .Case([&](mlir::stablehlo::DynamicSliceOp op) { + ConvertAttr(op, "slice_sizes", convert); + }) + .Case([&](mlir::stablehlo::FftOp op) { + ConvertAttr(op, "fft_length", convert); + }) + .Case([&](mlir::stablehlo::PadOp op) { + ConvertAttr(op, "edge_padding_low", convert); + ConvertAttr(op, "edge_padding_high", convert); + ConvertAttr(op, "interior_padding", convert); + }) + .Case([&](mlir::stablehlo::ReverseOp op) { + ConvertAttr(op, "dimensions", convert); + }) + .Case([&](mlir::stablehlo::SliceOp op) { + ConvertAttr(op, "start_indices", convert); + ConvertAttr(op, "limit_indices", convert); + ConvertAttr(op, "strides", convert); + }) + .Case([&](mlir::stablehlo::TransposeOp op) { + ConvertAttr(op, "permutation", convert); + }); + } +} + +void DowngradeStablehlo(mlir::ModuleOp module, + std::optional plugin_version) { + module->walk([&](mlir::Operation* op) { + ConvertStablehloDenseAttributes(op, ArrayToElements, plugin_version); + }); +} +void UpgradeStablehlo(mlir::ModuleOp module) { + module->walk([](mlir::Operation* op) { + ConvertStablehloDenseAttributes(op, ElementsToArray, + /*plugin_version=*/std::nullopt); + }); +} + +} // namespace + Status MlirToXlaComputation(mlir::ModuleOp module, XlaComputation& xla_computation, - bool use_tuple_args, bool return_tuple, - bool legalize_sparse_ops) { + bool use_tuple_args, bool return_tuple) { mlir::BaseScopedDiagnosticHandler diagnostic_handler(module->getContext()); { mlir::PassManager pm(module->getContext()); - if (legalize_sparse_ops) { - // Convert sparse operations to custom_calls in order to translate sparse - // operations into XLA HLO. - pm.addNestedPass( - mlir::mhlo::createLegalizeSparseOperationsPass( - /*legalizeToCustomCalls=*/true)); - } pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass( - /*legalizeBroadcasts=*/true, /*expandCompositions=*/true)); + mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); // In order to export to XLA, we must sink constants to control flow // regions, since XLA uses functional control flow. @@ -70,9 +219,8 @@ Status MlirToXlaComputation(mlir::ModuleOp module, } HloProto proto; - mlir::MlirToHloConversionOptions options; - TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &proto, use_tuple_args, - return_tuple, options)); + TF_RETURN_IF_ERROR( + ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple)); xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module())); return OkStatus(); @@ -80,22 +228,36 @@ Status MlirToXlaComputation(mlir::ModuleOp module, StatusOr> ParseMlirModuleString( absl::string_view mlir_module_str, mlir::MLIRContext& context) { - mlir::OwningOpRef module; mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); mlir::func::registerAllExtensions(registry); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::stablehlo::registerAllDialects(registry); context.appendDialectRegistry(registry); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); + mlir::BaseScopedDiagnosticHandler diagnostic_handler(&context); - module = mlir::parseSourceString( - llvm::StringRef(mlir_module_str.data(), mlir_module_str.size()), - &context); + mlir::OwningOpRef module = + mlir::parseSourceString( + llvm::StringRef(mlir_module_str.data(), mlir_module_str.size()), + // IR may be invalid because some fields may be using DenseElements + // instead of DenseArray. We rectify that below and verify after. + mlir::ParserConfig{&context, /*verifyAfterParse=*/false}); if (!module) { return diagnostic_handler.ConsumeStatus(); } + + // In + // https://github.com/google/jax/commit/184e3a88004680dbf34328b05c5fc0d869cc4a93, + // fields on some ops were changed to use Dense{Bool,I64}ArrayAttr instead of + // I64DenseElementsAttr (DenseIntElementsAttr). Some clients still expect + // dense elements, not dense arrays, so when serializing we always convert the + // arrays to elements. The elements need to be converted back to arrays when + // deserializing. + // TODO: b/320507168 - Remove the conversion code, and verifyAfterParse. + TF_RETURN_IF_ERROR(UpgradeVersionedStablehlo(*module)); if (failed(module->verifyInvariants())) { VLOG(1) << "MLIR verification failed."; module->dump(); @@ -114,4 +276,73 @@ Status ParseMlirModuleStringAndConvertToXlaComputation( return_tuple); } +absl::StatusOr SerializeUsingNativeBytecode( + mlir::ModuleOp module, std::optional plugin_version) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + // Pin bytecode version to 1 until transition to stable. + // TODO: b/285913864 - Remove post enabling frameworks to set it. + config.setDesiredBytecodeVersion(1); + // In + // https://github.com/google/jax/commit/184e3a88004680dbf34328b05c5fc0d869cc4a93, + // fields on some ops were changed to use Dense{Bool,I64}ArrayAttr instead of + // I64DenseElementsAttr (DenseIntElementsAttr). Some clients still expect + // dense elements, not dense arrays, so convert the arrays to elements before + // serializing. The elements need to be converted back to arrays when + // deserializing. + // TODO: b/320507168 - Remove this conversion code. + mlir::OwningOpRef cloned = module.clone(); + DowngradeStablehlo(*cloned, plugin_version); + if (mlir::failed(mlir::writeBytecodeToFile(*cloned, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +absl::StatusOr SerializeUsingVersionedStablehlo( + mlir::ModuleOp mlir_module, absl::string_view target, bool inplace) { + // Legalize CHLO -> [StableHLO+Shape] -> StableHLO + // Preserve higher-level ops with XLA support. To be replaced by composites. + mlir::PassManager pm(mlir_module->getContext()); + pm.addNestedPass( + mlir::mhlo::createChloLegalizeToHighLevelMhloPass()); + pm.addNestedPass( + mlir::stablehlo::createChloLegalizeToStablehloPass()); + pm.addNestedPass( + mlir::stablehlo::createShapeLegalizeToStablehloPass()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(mlir_module))) { + return xla::InvalidArgument("CHLO => [MHLO+Shape] => StableHLO failed"); + } + + // Avoid mutating the original module if it will be reused elsewhere + mlir::OwningOpRef cloned; + if (!inplace) { + cloned = mlir_module.clone(); + mlir_module = *cloned; + } + + // Serialize portable artifact + std::string buffer; + llvm::raw_string_ostream os(buffer); + if (failed( + mlir::stablehlo::serializePortableArtifact(mlir_module, target, os))) + return xla::InvalidArgument("Failed to serialize StableHLO"); + return buffer; +} + +Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) { + // Apply StableHLO bytecode patch + UpgradeStablehlo(mlir_module); + + // Upgrade if VHLO + mlir::PassManager pm(mlir_module->getContext()); + mlir::stablehlo::createStablehloDeserializePipeline(pm); + if (!mlir::succeeded(pm.run(mlir_module))) + return xla::InvalidArgument("Failed to upgrade versioned StableHLO."); + return OkStatus(); +} + } // namespace xla diff --git a/xla/pjrt/mlir_to_hlo.h b/xla/pjrt/mlir_to_hlo.h index 761a0d073aba4..9bad539971742 100644 --- a/xla/pjrt/mlir_to_hlo.h +++ b/xla/pjrt/mlir_to_hlo.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,14 +30,40 @@ StatusOr> ParseMlirModuleString( // Converts an CHLO/MHLO module to XLA HLO. Status MlirToXlaComputation(mlir::ModuleOp module, XlaComputation& xla_computation, - bool use_tuple_args, bool return_tuple, - bool legalize_sparse_ops = false); + bool use_tuple_args, bool return_tuple); // Converts an MHLO/CHLO module string to an XLA computation. Status ParseMlirModuleStringAndConvertToXlaComputation( absl::string_view mlir_module_str, XlaComputation& xla_computation, bool use_tuple_args, bool return_tuple); +// Serialize using MLIR Bytecode Format which does not guarantee forward or +// backward compatiblity of the dialects used. If passing StableHLO with forward +// or backward compatibility requirements, use SerializeUsingVersionedStablehlo. +absl::StatusOr SerializeUsingNativeBytecode( + mlir::ModuleOp mlir_module, std::optional plugin_version); + +// Serializes an MLIR module to a portable artifact with forward and backward +// compatibility. Supports modules using StableHLO/MHLO/CHLO/Func dialects. +// Target parameter is a StableHLO version string ("0.9.0") which can be used +// for forward compatibility to specify the target downgrade version. +// Most commonly should use: +// `mlir::stablehlo::getCurrentVersion()` for backward compat but not forward. +// `mlir::stablehlo::getMinimumVersion()` for maximum forward compatibility. +// Ideally should be the `mlir::stablehlo::getCurrentVersion()` of the plugin. +// If program contains dialects that aren't supposed in StableHLO portable +// artifacts, use SerializeUsingNativeBytecode. +absl::StatusOr SerializeUsingVersionedStablehlo( + mlir::ModuleOp mlir_module, absl::string_view target, bool inplace = false); + +// Given a module that might be a portable artifact, deserialize and upgrade it +// back to StableHLO. +// If module is not a portable artifact, this method is identity. Only fails +// on portable artifacts that are outside of the compatibility window. +// `ParseMlirModuleString` uses this method, and should be preferred to directly +// calling `UpgradeVersionedStablehlo` where possible. +Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module); + } // namespace xla #endif // XLA_PJRT_MLIR_TO_HLO_H_ diff --git a/xla/pjrt/pjrt_api.cc b/xla/pjrt/pjrt_api.cc index 60c37381e2692..6cfb7d2b07058 100644 --- a/xla/pjrt/pjrt_api.cc +++ b/xla/pjrt/pjrt_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ static std::string CanonicalizeDeviceType(absl::string_view device_type) { return absl::AsciiStrToLower(device_type); } -xla::StatusOr PjrtApi(absl::string_view device_type) { +absl::StatusOr PjrtApi(absl::string_view device_type) { std::string canonicalize_device_type = CanonicalizeDeviceType(device_type); auto iter = pjrt_apis->find(canonicalize_device_type); if (iter == pjrt_apis->end()) { @@ -67,7 +67,7 @@ xla::StatusOr PjrtApi(absl::string_view device_type) { return iter->second.first; } -xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api) { +absl::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api) { std::string canonicalize_device_type = CanonicalizeDeviceType(device_type); if (auto iter = pjrt_apis->find(canonicalize_device_type); iter != pjrt_apis->end()) { @@ -77,12 +77,12 @@ xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api) { (*pjrt_apis)[canonicalize_device_type] = std::make_pair(api, /*is_initialized=*/false); LOG(INFO) << "PJRT_Api is set for device type " << canonicalize_device_type; - return tsl::OkStatus(); + return absl::OkStatus(); } typedef const PJRT_Api* (*PjrtApiInitFn)(); -xla::StatusOr LoadPjrtPlugin(absl::string_view device_type, - absl::string_view library_path) { +absl::StatusOr LoadPjrtPlugin(absl::string_view device_type, + absl::string_view library_path) { #ifdef PLATFORM_WINDOWS return tsl::errors::Unimplemented( "LoadPjrtPlugin is not implemented on windows yet."); @@ -105,7 +105,7 @@ xla::StatusOr LoadPjrtPlugin(absl::string_view device_type, #endif } -xla::StatusOr IsPjrtPluginInitialized(absl::string_view device_type) { +absl::StatusOr IsPjrtPluginInitialized(absl::string_view device_type) { std::string canonicalize_device_type = CanonicalizeDeviceType(device_type); auto iter = pjrt_apis->find(canonicalize_device_type); if (iter == pjrt_apis->end()) { @@ -128,7 +128,7 @@ static bool IsPjRtCompatibilityEnabled() { return enabled; } -xla::Status InitializePjrtPlugin(absl::string_view device_type) { +absl::Status InitializePjrtPlugin(absl::string_view device_type) { std::string canonicalize_device_type = CanonicalizeDeviceType(device_type); auto iter = pjrt_apis->find(canonicalize_device_type); if (iter == pjrt_apis->end()) { @@ -177,7 +177,7 @@ xla::Status InitializePjrtPlugin(absl::string_view device_type) { } PJRT_Plugin_Initialize_Args args; args.struct_size = PJRT_Plugin_Initialize_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; RETURN_STATUS_IF_PJRT_ERROR(pjrt_api->PJRT_Plugin_Initialize(&args), pjrt_api); iter->second.second = true; diff --git a/xla/pjrt/pjrt_api.h b/xla/pjrt/pjrt_api.h index 8a1d3cfe7468c..eff361c3cbe7b 100644 --- a/xla/pjrt/pjrt_api.h +++ b/xla/pjrt/pjrt_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,22 +25,22 @@ namespace pjrt { // Gets and sets the global map for PJRT_Api*. Not thread safe. `device_type` is // case insensitive. -xla::StatusOr PjrtApi(absl::string_view device_type); -xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api); +absl::StatusOr PjrtApi(absl::string_view device_type); +absl::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api); // Loads a PJRT plugin. The library provided by library_path must export a // symbol called `GetPjrtApi` with function signature `const PJRT_Api* // GetPjrtApi()`. This method dlopen the plugin library, dlsym `GetPjrtApi`, // calls `GetPjrtApi` and `SetPjrtApi`. Returns the loaded PJRT_Api* if // successful. -xla::StatusOr LoadPjrtPlugin(absl::string_view device_type, - absl::string_view library_path); +absl::StatusOr LoadPjrtPlugin(absl::string_view device_type, + absl::string_view library_path); // Requires that SetPjrtApi has been successfully called on `device_type` before // calling this method. -xla::StatusOr IsPjrtPluginInitialized(absl::string_view device_type); +absl::StatusOr IsPjrtPluginInitialized(absl::string_view device_type); // Initializes a PJRT plugin with `PJRT_Plugin_Initialize`. -xla::Status InitializePjrtPlugin(absl::string_view device_type); +absl::Status InitializePjrtPlugin(absl::string_view device_type); } // namespace pjrt diff --git a/xla/pjrt/pjrt_api_test.cc b/xla/pjrt/pjrt_api_test.cc index e8abdcc8ac732..b6e13ca5d14e2 100644 --- a/xla/pjrt/pjrt_api_test.cc +++ b/xla/pjrt/pjrt_api_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index c7504baf2e333..84211a7c1e4d7 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -36,13 +37,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MLProgram/IR/MLProgram.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -55,7 +57,10 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" @@ -63,6 +68,7 @@ limitations under the License. #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" @@ -87,27 +93,21 @@ namespace xla { // Return error future if not success and frees the PJRT_Error returned by // `expr`. -#define RETURN_FUTURE_IF_ERROR(expr, c_api) \ - do { \ - PJRT_Error* error = (expr); \ - std::unique_ptr _error( \ - error, pjrt::MakeErrorDeleter(c_api)); \ - xla::Status _status = pjrt::PjrtErrorToStatus(_error.get(), c_api); \ - if (!_status.ok()) { \ - return PjRtFuture(_status); \ - } \ +#define RETURN_FUTURE_IF_ERROR(expr, c_api) \ + do { \ + PJRT_Error* error = (expr); \ + std::unique_ptr _error( \ + error, pjrt::MakeErrorDeleter(c_api)); \ + absl::Status _status = pjrt::PjrtErrorToStatus(_error.get(), c_api); \ + if (!_status.ok()) { \ + return PjRtFuture(_status); \ + } \ } while (false) // ---------------------------------- Client ----------------------------------- static StatusOr InitClientTopoDesc( const PJRT_Api* c_api, PJRT_Client* c_client) { - if (c_api->pjrt_api_version.major_version == 0 && - c_api->pjrt_api_version.minor_version < 36) { - return Unimplemented( - "Getting TopologyDescription for PJRT client requires plugin with PJRT " - "C API version >= 0.36"); - } StatusOr c_topo = pjrt::GetTopologyDescription(c_client, c_api); TF_RETURN_IF_ERROR(c_topo.status()); @@ -131,6 +131,7 @@ PjRtCApiClient::PjRtCApiClient( platform_name_(::pjrt::GetPlatformName(c_client, c_api)), platform_id_(tsl::Fingerprint64(platform_name_)) { InitDevicesAndMemorySpaces(); + InitAttributes(); LOG(INFO) << "PjRtCApiClient created."; } @@ -138,7 +139,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // Initialize devices. PJRT_Client_Devices_Args devices_args; devices_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE; - devices_args.priv = nullptr; + devices_args.extension_start = nullptr; devices_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError(c_api_->PJRT_Client_Devices(&devices_args), c_api_); @@ -159,7 +160,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // Initialize addressable devices. PJRT_Client_AddressableDevices_Args address_args; address_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE; - address_args.priv = nullptr; + address_args.extension_start = nullptr; address_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError( @@ -177,7 +178,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { // TODO(yueshengys): Initialize global memory spaces when supported. PJRT_Client_AddressableMemories_Args memory_args; memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; - memory_args.priv = nullptr; + memory_args.extension_start = nullptr; memory_args.client = c_client_.get(); std::unique_ptr client_error( @@ -209,7 +210,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { PJRT_Device* c_device = cpp_device->c_device(); PJRT_Device_AddressableMemories_Args args; args.struct_size = PJRT_Device_AddressableMemories_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = c_device; std::unique_ptr device_error( @@ -238,7 +239,7 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { PJRT_Memory* c_memory = cpp_memory->c_memory(); PJRT_Memory_AddressableByDevices_Args args; args.struct_size = PJRT_Memory_AddressableByDevices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory; pjrt::LogFatalIfPjrtError(c_api_->PJRT_Memory_AddressableByDevices(&args), c_api_); @@ -252,6 +253,15 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { } } +void PjRtCApiClient::InitAttributes() { + PJRT_Plugin_Attributes_Args args; + args.struct_size = PJRT_Plugin_Attributes_Args_STRUCT_SIZE; + args.extension_start = nullptr; + pjrt::LogFatalIfPjrtError(c_api_->PJRT_Plugin_Attributes(&args), c_api_); + attributes_ = + pjrt::ConvertFromPjRtNamedValueList(args.attributes, args.num_attributes); +} + int PjRtCApiClient::device_count() const { return devices_.size(); } int PjRtCApiClient::addressable_device_count() const { @@ -269,7 +279,7 @@ absl::Span PjRtCApiClient::addressable_devices() const { int PjRtCApiClient::process_index() const { PJRT_Client_ProcessIndex_Args process_index_args; process_index_args.struct_size = PJRT_Client_ProcessIndex_Args_STRUCT_SIZE; - process_index_args.priv = nullptr; + process_index_args.extension_start = nullptr; process_index_args.client = c_client_.get(); pjrt::LogFatalIfPjrtError( c_api_->PJRT_Client_ProcessIndex(&process_index_args), c_api_); @@ -281,6 +291,12 @@ absl::string_view PjRtCApiClient::platform_version() const { return platform_version_; } +std::optional PjRtCApiClient::plugin_attributes() const { + return PjRtPluginAttributes{c_api_->pjrt_api_version.major_version, + c_api_->pjrt_api_version.minor_version, + attributes_}; +} + static DeviceAssignment CalculateDefaultAssignment( int num_replicas, int num_partitions, absl::Span device_assignment) { @@ -298,7 +314,7 @@ StatusOr PjRtCApiClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { PJRT_Client_DefaultDeviceAssignment_Args args; args.struct_size = PJRT_Client_DefaultDeviceAssignment_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.num_replicas = num_replicas; args.num_partitions = num_partitions; @@ -314,22 +330,32 @@ StatusOr PjRtCApiClient::GetDefaultDeviceAssignment( } StatusOr PjRtCApiClient::LookupDevice(int device_id) const { + return LookupDevice(PjRtGlobalDeviceId(device_id)); +} + +StatusOr PjRtCApiClient::LookupDevice( + PjRtGlobalDeviceId global_device_id) const { PJRT_Client_LookupDevice_Args args; args.struct_size = PJRT_Client_LookupDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); - args.id = device_id; + args.id = global_device_id.value(); RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_); return GetCppDevice(args.device); } StatusOr PjRtCApiClient::LookupAddressableDevice( int local_hardware_id) const { + return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); +} + +StatusOr PjRtCApiClient::LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const { PJRT_Client_LookupAddressableDevice_Args args; args.struct_size = PJRT_Client_LookupAddressableDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); - args.local_hardware_id = local_hardware_id; + args.local_hardware_id = local_device_id.value(); RETURN_STATUS_IF_PJRT_ERROR( c_api_->PJRT_Client_LookupAddressableDevice(&args), c_api_); return GetCppDevice(args.addressable_device); @@ -347,7 +373,10 @@ static StatusOr> InitializeArgsAndCompile( const std::string& format) { PJRT_Client_Compile_Args args; args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE; - args.priv = nullptr; + PJRT_Profiler_Extension profiler_extension = + pjrt::CreatePjrtProfilerExtension("PJRT_Client_Compile linkage"); + args.extension_start = + reinterpret_cast(&profiler_extension); args.client = client; TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto, options.ToProto()); @@ -357,7 +386,7 @@ static StatusOr> InitializeArgsAndCompile( PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = const_cast(code.c_str()); program.code_size = code.size(); program.format = format.c_str(); @@ -380,19 +409,15 @@ StatusOr> PjRtCApiClient::Compile( StatusOr> PjRtCApiClient::Compile( mlir::ModuleOp module, CompileOptions options) { - std::string module_bytecode; - { - llvm::raw_string_ostream os(module_bytecode); - mlir::BytecodeWriterConfig config; - // Pin bytecode version to 1 until transition to stable. - // TODO(285913864): Remove post enabling frameworks to set it. - config.setDesiredBytecodeVersion(1); - if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) - return absl::UnknownError("writeBytecodeToFile() failed."); - } + // TODO: Once plugins are ready, use SerializeUsingVersionedStablehlo. + if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null"); + TF_ASSIGN_OR_RETURN( + std::string serialized, + xla::SerializeUsingNativeBytecode( + module, plugin_attributes()->pjrt_c_api_minor_version)); std::string format(pjrt::kMlirFormat); return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options, - module_bytecode, format); + serialized, format); } StatusOr> @@ -401,7 +426,7 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized, PJRT_Executable_DeserializeAndLoad_Args des_args; des_args.struct_size = PJRT_Executable_DeserializeAndLoad_Args_STRUCT_SIZE; - des_args.priv = nullptr; + des_args.extension_start = nullptr; des_args.client = c_client_.get(); des_args.serialized_executable = serialized.data(); des_args.serialized_executable_size = serialized.length(); @@ -441,7 +466,7 @@ StatusOr PjRtCApiClient::UnsafeBufferPointer( PJRT_Buffer_UnsafePointer_Args args; args.struct_size = PJRT_Buffer_UnsafePointer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = tensorflow::down_cast(buffer)->c_buffer(); @@ -455,23 +480,23 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, std::variant device_or_memory, const Layout* device_layout) { if (host_buffer_semantics != HostBufferSemantics::kImmutableOnlyDuringCall && - host_buffer_semantics != HostBufferSemantics::kZeroCopy && + host_buffer_semantics != HostBufferSemantics::kImmutableZeroCopy && host_buffer_semantics != HostBufferSemantics::kImmutableUntilTransferCompletes) { return Unimplemented( "PJRT C API does not support HostBufferSemantics other than " "HostBufferSemantics::kImmutableOnlyDuringCall, " - "HostBufferSemantics::kZeroCopy and " + "HostBufferSemantics::kImmutableZeroCopy and " "HostBufferSemantics::kImmutableUntilTransferCompletes."); } PJRT_Client_BufferFromHostBuffer_Args args; args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.data = data; args.type = ::pjrt::ConvertToPjRtBufferType(type); @@ -521,19 +546,19 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( if (on_done_with_host_buffer) { PJRT_Event_OnReady_Args event_args; event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_args.priv = nullptr; + event_args.extension_start = nullptr; event_args.event = event.get(); - event_args.user_arg = new std::function( + event_args.user_arg = new absl::AnyInvocable( [on_done_with_host_buffer = std::move(on_done_with_host_buffer), - c_api = c_api_](PJRT_Error* error) { + c_api = c_api_](PJRT_Error* error) mutable { if (error) { ::pjrt::MakeErrorDeleter(c_api)(error); } - on_done_with_host_buffer(); + std::move(on_done_with_host_buffer)(); }); event_args.callback = [](PJRT_Error* error, void* args) { - std::function* on_done_with_host_buffer = - reinterpret_cast*>(args); + auto* on_done_with_host_buffer = + reinterpret_cast*>(args); (*on_done_with_host_buffer)(error); delete on_done_with_host_buffer; }; @@ -549,32 +574,33 @@ StatusOr> PjRtCApiClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtMemorySpace* memory_space, const Layout* device_layout) { return BufferFromHostBufferInternalImpl( data, type, dims, byte_strides, host_buffer_semantics, - on_done_with_host_buffer, memory_space, device_layout); + std::move(on_done_with_host_buffer), memory_space, device_layout); } StatusOr> PjRtCApiClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device, const Layout* device_layout) { return BufferFromHostBufferInternalImpl( data, type, dims, byte_strides, host_buffer_semantics, - on_done_with_host_buffer, device, device_layout); + std::move(on_done_with_host_buffer), device, device_layout); } StatusOr> PjRtCApiClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device) { + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device) { return BufferFromHostBufferInternalImpl( data, type, dims, byte_strides, host_buffer_semantics, - on_done_with_host_buffer, device, /*device_layout=*/nullptr); + std::move(on_done_with_host_buffer), device, /*device_layout=*/nullptr); } StatusOr> PjRtCApiClient::CreateViewOfDeviceBuffer( @@ -583,7 +609,7 @@ StatusOr> PjRtCApiClient::CreateViewOfDeviceBuffer( std::optional stream) { PJRT_Client_CreateViewOfDeviceBuffer_Args args; args.struct_size = PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.client = c_client_.get(); args.device_buffer_ptr = device_ptr; args.dims = shape.dimensions().data(); @@ -616,11 +642,7 @@ StatusOr> PjRtCApiClient::CreateViewOfDeviceBuffer( args.stream = reinterpret_cast(nullptr); } const PJRT_Api* c_api = pjrt_c_api(); - // TODO(jieying): To be removed after 12/29/2023. - if (c_api->pjrt_api_version.minor_version < 33) { - return Unimplemented( - "The plugin does not support CreateViewOfDeviceBuffer"); - } + RETURN_STATUS_IF_PJRT_ERROR( c_api->PJRT_Client_CreateViewOfDeviceBuffer(&args), c_api); @@ -641,7 +663,7 @@ PjRtCApiDeviceDescription::PjRtCApiDeviceDescription( int PjRtCApiDeviceDescription::id() const { PJRT_DeviceDescription_Id_Args args; args.struct_size = PJRT_DeviceDescription_Id_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Id(&args), c_api_); return args.id; @@ -650,7 +672,7 @@ int PjRtCApiDeviceDescription::id() const { int PjRtCApiDeviceDescription::process_index() const { PJRT_DeviceDescription_ProcessIndex_Args args; args.struct_size = PJRT_DeviceDescription_ProcessIndex_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_ProcessIndex(&args), c_api_); @@ -661,7 +683,7 @@ void PjRtCApiDeviceDescription::InitAttributes() { attributes_ = {}; PJRT_DeviceDescription_Attributes_Args args; args.struct_size = PJRT_DeviceDescription_Attributes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Attributes(&args), c_api_); @@ -712,7 +734,7 @@ PjRtCApiDeviceDescription::Attributes() const { absl::string_view PjRtCApiDeviceDescription::device_kind() const { PJRT_DeviceDescription_Kind_Args args; args.struct_size = PJRT_DeviceDescription_Kind_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_Kind(&args), c_api_); @@ -724,7 +746,7 @@ absl::string_view PjRtCApiDeviceDescription::device_kind() const { absl::string_view PjRtCApiDeviceDescription::DebugString() const { PJRT_DeviceDescription_DebugString_Args args; args.struct_size = PJRT_DeviceDescription_DebugString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_DebugString(&args), c_api_); @@ -735,7 +757,7 @@ absl::string_view PjRtCApiDeviceDescription::DebugString() const { absl::string_view PjRtCApiDeviceDescription::ToString() const { PJRT_DeviceDescription_ToString_Args args; args.struct_size = PJRT_DeviceDescription_ToString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device_description = device_description_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_DeviceDescription_ToString(&args), c_api_); @@ -754,7 +776,7 @@ PjRtClient* PjRtCApiDevice::client() const { return client_; } bool PjRtCApiDevice::IsAddressable() const { PJRT_Device_IsAddressable_Args args; args.struct_size = PJRT_Device_IsAddressable_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Device_IsAddressable(&args), api); @@ -762,19 +784,23 @@ bool PjRtCApiDevice::IsAddressable() const { } int PjRtCApiDevice::local_hardware_id() const { + return local_hardware_id_typed().value(); +} + +PjRtLocalHardwareId PjRtCApiDevice::local_hardware_id_typed() const { PJRT_Device_LocalHardwareId_Args args; args.struct_size = PJRT_Device_LocalHardwareId_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Device_LocalHardwareId(&args), api); - return args.local_hardware_id; + return PjRtLocalHardwareId(args.local_hardware_id); } StatusOr PjRtCApiDevice::default_memory_space() const { PJRT_Device_DefaultMemory_Args args; args.struct_size = PJRT_Device_DefaultMemory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Device_DefaultMemory(&args), api); @@ -784,7 +810,7 @@ StatusOr PjRtCApiDevice::default_memory_space() const { StatusOr PjRtCApiDevice::GetAllocatorStats() const { PJRT_Device_MemoryStats_Args args; args.struct_size = PJRT_Device_MemoryStats_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Device_MemoryStats(&args), api); @@ -852,7 +878,7 @@ PjRtClient* PjRtCApiMemorySpace::client() const { return client_; } int PjRtCApiMemorySpace::id() const { PJRT_Memory_Id_Args args; args.struct_size = PJRT_Memory_Id_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_Id(&args), pjrt_c_api()); return args.id; @@ -861,7 +887,7 @@ int PjRtCApiMemorySpace::id() const { absl::string_view PjRtCApiMemorySpace::memory_space_kind() const { PJRT_Memory_Kind_Args args; args.struct_size = PJRT_Memory_Kind_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_Kind(&args), @@ -873,7 +899,7 @@ absl::string_view PjRtCApiMemorySpace::memory_space_kind() const { absl::string_view PjRtCApiMemorySpace::DebugString() const { PJRT_Memory_DebugString_Args args; args.struct_size = PJRT_Memory_DebugString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_DebugString(&args), pjrt_c_api()); @@ -883,7 +909,7 @@ absl::string_view PjRtCApiMemorySpace::DebugString() const { absl::string_view PjRtCApiMemorySpace::ToString() const { PJRT_Memory_ToString_Args args; args.struct_size = PJRT_Memory_ToString_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.memory = c_memory_; pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Memory_ToString(&args), pjrt_c_api()); @@ -903,7 +929,7 @@ absl::string_view PjRtCApiExecutable::name() const { PJRT_Executable_Name_Args args; args.executable = executable; args.struct_size = PJRT_Executable_Name_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_Name(&args), c_api); return absl::string_view(args.executable_name, args.executable_name_size); @@ -915,7 +941,7 @@ int PjRtCApiExecutable::num_replicas() const { PJRT_Executable_NumReplicas_Args args; args.executable = executable; args.struct_size = PJRT_Executable_NumReplicas_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_NumReplicas(&args), c_api); return args.num_replicas; @@ -927,7 +953,7 @@ int PjRtCApiExecutable::num_partitions() const { PJRT_Executable_NumPartitions_Args args; args.executable = executable; args.struct_size = PJRT_Executable_NumPartitions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError(c_api->PJRT_Executable_NumPartitions(&args), c_api); return args.num_partitions; @@ -939,7 +965,7 @@ int64_t PjRtCApiExecutable::SizeOfGeneratedCodeInBytes() const { PJRT_Executable_SizeOfGeneratedCodeInBytes_Args args; args.struct_size = PJRT_Executable_SizeOfGeneratedCodeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; pjrt::LogFatalIfPjrtError( @@ -952,7 +978,7 @@ PjRtCApiExecutable::GetCostAnalysis() const { // Initialize function call args PJRT_Executable_GetCostAnalysis_Args args; args.struct_size = PJRT_Executable_GetCostAnalysis_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); // Make PJRT C API call @@ -969,16 +995,11 @@ StatusOr>> PjRtCApiExecutable::GetOutputElementTypes() const { PJRT_Executable_OutputElementTypes_Args args; args.struct_size = PJRT_Executable_OutputElementTypes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); - // TODO(yueshengys): To be removed after 11/29/2023. - if (c_api->PJRT_Executable_OutputElementTypes == nullptr) { - return Unimplemented("PJRT C API does not support GetOutputElementTypes"); - } - RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_OutputElementTypes(&args), c_api); @@ -994,16 +1015,11 @@ StatusOr>> PjRtCApiExecutable::GetOutputDimensions() const { PJRT_Executable_OutputDimensions_Args args; args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); - // TODO(yueshengys): To be removed after 11/29/2023. - if (c_api->PJRT_Executable_OutputDimensions == nullptr) { - return Unimplemented("PJRT C API does not support GetOutputDimensions"); - } - RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_OutputDimensions(&args), c_api); @@ -1025,7 +1041,7 @@ StatusOr>> PjRtCApiExecutable::GetOutputMemoryKinds() const { PJRT_Executable_OutputMemoryKinds_Args args; args.struct_size = PJRT_Executable_OutputMemoryKinds_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -1047,11 +1063,11 @@ PjRtCApiExecutable::GetHloModules() const { auto* executable = c_executable(); PJRT_Executable_OptimizedProgram_Args args; args.struct_size = PJRT_Executable_OptimizedProgram_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = executable; PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = nullptr; args.program = &program; @@ -1069,7 +1085,7 @@ PjRtCApiExecutable::GetHloModules() const { absl::string_view program_format(program.format, program.format_size); if (program_format != ::pjrt::kHloWithConfigFormat && program_format != ::pjrt::kMlirFormat) { - return xla::InternalError( + return xla::Internal( "expected program format `hlo_with_config` or `mlir` but got %s", program_format); } @@ -1077,25 +1093,18 @@ PjRtCApiExecutable::GetHloModules() const { if (program_format == ::pjrt::kMlirFormat) { xla::HloProto hlo_proto; mlir::MLIRContext ctx; - mlir::DialectRegistry registry; - registry.insert(); - mlir::stablehlo::registerAllDialects(registry); - mlir::mhlo::registerAllMhloDialects(registry); - ctx.appendDialectRegistry(registry); - auto module = mlir::parseSourceString(code, &ctx); - if (!module) return xla::InternalError("failed to parse source module"); + TF_ASSIGN_OR_RETURN( // NOLINT(clang-diagnostic-pre-c++20-compat) + mlir::OwningOpRef module, + ParseMlirModuleString(code, ctx)); mlir::PassManager pm(&ctx); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); if (mlir::failed(pm.run(module.get()))) - return xla::InternalError("failed to convert to MHLO"); - mlir::MlirToHloConversionOptions options; + return xla::Internal("failed to convert to MHLO"); // TODO(jieying): Tuple args should really come from GetCompileOptions (or // equivalent) once implemented. - TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo( - module.get(), &hlo_proto, /*use_tuple_args=*/false, - /*return_tuple=*/false, options)); + TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module.get(), &hlo_proto, + /*use_tuple_args=*/false, + /*return_tuple=*/false)); xla::DebugOptions debug_options; TF_ASSIGN_OR_RETURN(xla::HloModuleConfig module_config, xla::HloModule::CreateModuleConfigFromProto( @@ -1122,7 +1131,7 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { auto* executable = c_executable(); PJRT_Executable_Serialize_Args ser_args; ser_args.struct_size = PJRT_Executable_Serialize_Args_STRUCT_SIZE; - ser_args.priv = nullptr; + ser_args.extension_start = nullptr; ser_args.executable = executable; ser_args.serialized_executable = nullptr; @@ -1136,17 +1145,9 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { StatusOr PjRtCApiExecutable::FingerprintExecutable() const { const PJRT_Api* c_api_ = pjrt_c_api(); - if (c_api_->pjrt_api_version.major_version == 0 && - c_api_->pjrt_api_version.minor_version < 35) { - // TODO(yeounoh): To be removed after 01/20/2024. - return xla::Unimplemented( - "Getting fingerprint from unloaded PJRT executable requires plugin " - "with PJRT C API version >= 0.35"); - } - PJRT_Executable_Fingerprint_Args args; args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_executable(); RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args), c_api_); @@ -1163,7 +1164,7 @@ PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable( client->pjrt_c_api())) { PJRT_LoadedExecutable_GetExecutable_Args args; args.struct_size = PJRT_LoadedExecutable_GetExecutable_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.loaded_executable = c_loaded_executable(); args.executable = nullptr; pjrt::LogFatalIfPjrtError( @@ -1176,7 +1177,7 @@ PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable( void PjRtCApiLoadedExecutable::InitDevices() { PJRT_LoadedExecutable_AddressableDevices_Args args; args.struct_size = PJRT_LoadedExecutable_AddressableDevices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); args.addressable_devices = nullptr; args.num_addressable_devices = 0; @@ -1244,9 +1245,9 @@ PJRT_SendCallbackInfo CppSendCallbackToC( // PJRT C API doesn't support // use_major_to_minor_data_layout_for_callbacks = false xla::Shape dummy_shape; - xla::Status status = send_callback(xla::PjRtTransferMetadata{dummy_shape}, - ::pjrt::ConvertToCppChunk(*chunk), - total_size_in_bytes, done); + absl::Status status = send_callback(xla::PjRtTransferMetadata{dummy_shape}, + ::pjrt::ConvertToCppChunk(*chunk), + total_size_in_bytes, done); if (!status.ok()) { absl::string_view message = status.message(); return (*callback_error)(pjrt::StatusCodeToPjrtErrorCode(status.code()), @@ -1281,7 +1282,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( PJRT_CopyToDeviceStream_TotalBytes_Args total_bytes_args; total_bytes_args.struct_size = PJRT_CopyToDeviceStream_TotalBytes_Args_STRUCT_SIZE; - total_bytes_args.priv = nullptr; + total_bytes_args.extension_start = nullptr; total_bytes_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_TotalBytes(&total_bytes_args), c_api_); @@ -1290,7 +1291,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( PJRT_CopyToDeviceStream_GranuleSize_Args granule_size_args; granule_size_args.struct_size = PJRT_CopyToDeviceStream_GranuleSize_Args_STRUCT_SIZE; - granule_size_args.priv = nullptr; + granule_size_args.extension_start = nullptr; granule_size_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_GranuleSize(&granule_size_args), c_api_); @@ -1300,7 +1301,7 @@ CApiCopyToDeviceStream::CApiCopyToDeviceStream( CApiCopyToDeviceStream::~CApiCopyToDeviceStream() { PJRT_CopyToDeviceStream_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_CopyToDeviceStream_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; + destroy_args.extension_start = nullptr; destroy_args.stream = c_stream_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_CopyToDeviceStream_Destroy(&destroy_args), c_api_); @@ -1312,14 +1313,14 @@ PjRtFuture CApiCopyToDeviceStream::AddChunk(PjRtChunk chunk) { PJRT_CopyToDeviceStream_AddChunk_Args add_chunk_args; add_chunk_args.struct_size = PJRT_CopyToDeviceStream_AddChunk_Args_STRUCT_SIZE; - add_chunk_args.priv = nullptr; + add_chunk_args.extension_start = nullptr; add_chunk_args.stream = c_stream_; add_chunk_args.chunk = &c_chunk; PJRT_CopyToDeviceStream_CurrentBytes_Args current_bytes_args; current_bytes_args.struct_size = PJRT_CopyToDeviceStream_CurrentBytes_Args_STRUCT_SIZE; - current_bytes_args.priv = nullptr; + current_bytes_args.extension_start = nullptr; current_bytes_args.stream = c_stream_; { @@ -1417,7 +1418,7 @@ static void CppRecvCallbackListsToC( } } -xla::StatusOr +absl::StatusOr PjRtCApiLoadedExecutable::GetCommonExecuteArgs( absl::Span> argument_handles, const ExecuteOptions& options, PJRT_ExecuteOptions& c_options, @@ -1439,7 +1440,6 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( PJRT_LoadedExecutable_Execute_Args args; args.struct_size = PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE; - args.priv = nullptr; args.executable = c_loaded_executable(); args.options = &c_options; args.options->struct_size = PJRT_ExecuteOptions_STRUCT_SIZE; @@ -1474,7 +1474,7 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( PJRT_Executable_NumOutputs_Args numoutputs_args; numoutputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE; - numoutputs_args.priv = nullptr; + numoutputs_args.extension_start = nullptr; numoutputs_args.executable = c_executable(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Executable_NumOutputs(&numoutputs_args), pjrt_c_api()); @@ -1544,6 +1544,11 @@ PjRtCApiLoadedExecutable::Execute( non_donatable_input_indices_storage)); args.execute_device = nullptr; + PJRT_Profiler_Extension profiler_extension = + pjrt::CreatePjrtProfilerExtension( + "PJRT_LoadedExecutable_Execute linkage"); + args.extension_start = + reinterpret_cast(&profiler_extension); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); @@ -1556,10 +1561,11 @@ PjRtCApiLoadedExecutable::Execute( args.device_complete_events[i], pjrt_c_api()); if (!callback_data->c_send_callbacks.empty() || !callback_data->c_recv_callbacks.empty()) { - device_complete_futures[i].OnReady([callback_data](xla::Status status) { - // Keeps C callbacks alive until execution completes on all - // devices. - }); + device_complete_futures[i].OnReady( + [callback_data](absl::Status status) { + // Keeps C callbacks alive until execution completes on all + // devices. + }); } } @@ -1612,6 +1618,11 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice( args.execute_device = tensorflow::down_cast(device)->c_device(); + PJRT_Profiler_Extension profiler_extension = + pjrt::CreatePjrtProfilerExtension( + "PJRT_LoadedExecutable_Execute linkage"); + args.extension_start = + reinterpret_cast(&profiler_extension); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); @@ -1646,7 +1657,7 @@ PjRtCApiLoadedExecutable::ExecutePortable( void PjRtCApiLoadedExecutable::Delete() { PJRT_LoadedExecutable_Delete_Args args; args.struct_size = PJRT_LoadedExecutable_Delete_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(c_api->PJRT_LoadedExecutable_Delete(&args), c_api); @@ -1655,7 +1666,7 @@ void PjRtCApiLoadedExecutable::Delete() { bool PjRtCApiLoadedExecutable::IsDeleted() { PJRT_LoadedExecutable_IsDeleted_Args args; args.struct_size = PJRT_LoadedExecutable_IsDeleted_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); @@ -1679,7 +1690,7 @@ StatusOr PjRtCApiLoadedExecutable::FingerprintExecutable() const { // TODO(yeounoh): To be removed after 01/20/2024. PJRT_LoadedExecutable_Fingerprint_Args args; args.struct_size = PJRT_LoadedExecutable_Fingerprint_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.executable = c_loaded_executable(); const PJRT_Api* c_api = pjrt_c_api(); std::unique_ptr error( @@ -1703,7 +1714,7 @@ PjRtCApiBuffer::PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer) PrimitiveType PjRtCApiBuffer::element_type() const { PJRT_Buffer_ElementType_Args args; args.struct_size = PJRT_Buffer_ElementType_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Buffer_ElementType(&args), pjrt_c_api()); @@ -1713,20 +1724,20 @@ PrimitiveType PjRtCApiBuffer::element_type() const { absl::Span PjRtCApiBuffer::dimensions() const { PJRT_Buffer_Dimensions_Args args; args.struct_size = PJRT_Buffer_Dimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(pjrt_c_api()->PJRT_Buffer_Dimensions(&args), pjrt_c_api()); return absl::Span(args.dims, args.num_dims); } -const Layout& PjRtCApiBuffer::layout() const { +std::unique_ptr PjRtCApiBuffer::layout() const { { absl::MutexLock lock(&mu_); if (!layout_.has_value()) { PJRT_Buffer_GetMemoryLayout_Args args; args.struct_size = PJRT_Buffer_GetMemoryLayout_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError( pjrt_c_api()->PJRT_Buffer_GetMemoryLayout(&args), pjrt_c_api()); @@ -1738,13 +1749,13 @@ const Layout& PjRtCApiBuffer::layout() const { layout_.emplace(*cpp_layout); } } - return *layout_; + return std::make_unique(*layout_); } bool PjRtCApiBuffer::has_dynamic_dimensions() const { PJRT_Buffer_DynamicDimensionIndices_Args args; args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); @@ -1769,7 +1780,7 @@ absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { PJRT_Buffer_DynamicDimensionIndices_Args args; args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error( @@ -1790,7 +1801,7 @@ absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { StatusOr> PjRtCApiBuffer::logical_dimensions() { PJRT_Buffer_UnpaddedDimensions_Args args; args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_UnpaddedDimensions(&args), pjrt_c_api()); @@ -1798,10 +1809,19 @@ StatusOr> PjRtCApiBuffer::logical_dimensions() { args.unpadded_dims + args.num_dims); } +PjRtFuture PjRtCApiBuffer::LazyToLiteral( + absl::AnyInvocable() &&> generator) { + auto buffer = std::move(generator)(); + if (!buffer.ok()) { + return PjRtFuture(buffer.status()); + } + return ToLiteral(buffer.value()); +} + PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { PJRT_Buffer_ToHostBuffer_Args args; args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.src = buffer_.get(); const xla::Shape& shape = literal->shape(); @@ -1814,7 +1834,7 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { args.dst_size = ShapeUtil::ByteSizeOfElements(shape); args.dst = literal->untyped_data(); - xla::StatusOr c_layout_data; + absl::StatusOr c_layout_data; if (literal->shape().has_layout()) { c_layout_data = pjrt::ConvertToBufferMemoryLayoutData(literal->shape().layout()); @@ -1833,7 +1853,7 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { ::pjrt::MakeErrorDeleter(api)}; if (error != nullptr) { - xla::Status s = ::pjrt::PjrtErrorToStatus(error.get(), api); + absl::Status s = ::pjrt::PjrtErrorToStatus(error.get(), api); return PjRtFuture(s); } @@ -1843,7 +1863,7 @@ PjRtFuture PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) { StatusOr PjRtCApiBuffer::GetOnDeviceSizeInBytes() const { PJRT_Buffer_OnDeviceSizeInBytes_Args args; args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); RETURN_STATUS_IF_PJRT_ERROR( client_->pjrt_c_api()->PJRT_Buffer_OnDeviceSizeInBytes(&args), @@ -1855,7 +1875,7 @@ StatusOr PjRtCApiBuffer::GetOnDeviceSizeInBytes() const { PjRtMemorySpace* PjRtCApiBuffer::memory_space() const { PJRT_Buffer_Memory_Args args; args.struct_size = PJRT_Buffer_Memory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); std::unique_ptr error( @@ -1872,7 +1892,7 @@ PjRtMemorySpace* PjRtCApiBuffer::memory_space() const { PjRtDevice* PjRtCApiBuffer::device() const { PJRT_Buffer_Device_Args args; args.struct_size = PJRT_Buffer_Device_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Device(&args), api); @@ -1882,7 +1902,7 @@ PjRtDevice* PjRtCApiBuffer::device() const { void PjRtCApiBuffer::Delete() { PJRT_Buffer_Delete_Args args; args.struct_size = PJRT_Buffer_Delete_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Delete(&args), api); @@ -1891,7 +1911,7 @@ void PjRtCApiBuffer::Delete() { bool PjRtCApiBuffer::IsDeleted() { PJRT_Buffer_IsDeleted_Args args; args.struct_size = PJRT_Buffer_IsDeleted_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsDeleted(&args), api); @@ -1903,7 +1923,7 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( if (dst_device->client() == client_) { PJRT_Buffer_CopyToDevice_Args args; args.struct_size = PJRT_Buffer_CopyToDevice_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); args.dst_device = tensorflow::down_cast(dst_device)->c_device(); @@ -1924,7 +1944,7 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } } @@ -1932,16 +1952,11 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( StatusOr> PjRtCApiBuffer::CopyToMemorySpace( PjRtMemorySpace* dst_memory) { const PJRT_Api* api = pjrt_c_api(); - // TODO(yueshengys): Remove this after 12/20/2023. - if (api->pjrt_api_version.minor_version < 32) { - return Unimplemented( - "The plugin has PJRT API version 0.32 which does not support " - "CopyToMemorySpace"); - } + if (dst_memory->client() == client_) { PJRT_Buffer_CopyToMemory_Args args; args.struct_size = PJRT_Buffer_CopyToMemory_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); args.dst_memory = tensorflow::down_cast(dst_memory)->c_memory(); @@ -1961,7 +1976,7 @@ StatusOr> PjRtCApiBuffer::CopyToMemorySpace( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_memory, /*device_layout=*/nullptr); } @@ -1970,7 +1985,7 @@ StatusOr> PjRtCApiBuffer::CopyToMemorySpace( bool PjRtCApiBuffer::IsOnCpu() const { PJRT_Buffer_IsOnCpu_Args args; args.struct_size = PJRT_Buffer_IsOnCpu_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); const PJRT_Api* api = pjrt_c_api(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_IsOnCpu(&args), api); @@ -1982,7 +1997,7 @@ PJRT_Event* PjRtCApiBuffer::GetReadyEvent() { const PJRT_Api* api = pjrt_c_api(); PJRT_Buffer_ReadyEvent_Args args; args.struct_size = PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_.get(); pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_ReadyEvent(&args), api); readiness_event_.reset(args.event); @@ -1995,7 +2010,7 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() { const PJRT_Api* api = pjrt_c_api(); PJRT_Event_OnReady_Args args; args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.event = GetReadyEvent(); args.user_arg = new std::function( [promise = readiness_promise_, api](PJRT_Error* error) -> void { @@ -2033,7 +2048,7 @@ PjRtCApiBuffer::AcquireExternalReference() { increase_reference_count_args.buffer = c_buffer(); increase_reference_count_args.struct_size = PJRT_Buffer_IncreaseExternalReferenceCount_Args_STRUCT_SIZE; - increase_reference_count_args.priv = nullptr; + increase_reference_count_args.extension_start = nullptr; RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_IncreaseExternalReferenceCount( &increase_reference_count_args), @@ -2043,7 +2058,7 @@ PjRtCApiBuffer::AcquireExternalReference() { opaque_device_memory_data_pointer_args; opaque_device_memory_data_pointer_args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; - opaque_device_memory_data_pointer_args.priv = nullptr; + opaque_device_memory_data_pointer_args.extension_start = nullptr; opaque_device_memory_data_pointer_args.buffer = c_buffer(); RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Buffer_OpaqueDeviceMemoryDataPointer( @@ -2060,7 +2075,7 @@ PjRtCApiExternalReference::~PjRtCApiExternalReference() { PJRT_Buffer_DecreaseExternalReferenceCount_Args args; args.struct_size = PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.buffer = buffer_->c_buffer(); pjrt::LogFatalIfPjrtError( client_->pjrt_c_api()->PJRT_Buffer_DecreaseExternalReferenceCount(&args), @@ -2086,7 +2101,7 @@ absl::string_view PjRtCApiTopologyDescription::platform_name() const { PJRT_TopologyDescription_PlatformName_Args args; args.topology = c_topology_; args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_PlatformName(&args), c_api_); return absl::string_view(args.platform_name, args.platform_name_size); @@ -2095,7 +2110,7 @@ absl::string_view PjRtCApiTopologyDescription::platform_name() const { absl::string_view PjRtCApiTopologyDescription::platform_version() const { PJRT_TopologyDescription_PlatformVersion_Args args; args.struct_size = PJRT_TopologyDescription_PlatformVersion_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_PlatformVersion(&args), c_api_); @@ -2107,7 +2122,7 @@ PjRtCApiTopologyDescription::DeviceDescriptions() const { PJRT_TopologyDescription_GetDeviceDescriptions_Args args; args.struct_size = PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError( c_api_->PJRT_TopologyDescription_GetDeviceDescriptions(&args), c_api_); @@ -2125,7 +2140,7 @@ PjRtCApiTopologyDescription::DeviceDescriptions() const { StatusOr PjRtCApiTopologyDescription::Serialize() const { PJRT_TopologyDescription_Serialize_Args args; args.struct_size = PJRT_TopologyDescription_Serialize_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_TopologyDescription_Serialize(&args), c_api_); @@ -2137,7 +2152,7 @@ StatusOr PjRtCApiTopologyDescription::Serialize() const { void PjRtCApiTopologyDescription::InitAttributes() { PJRT_TopologyDescription_Attributes_Args args; args.struct_size = PJRT_TopologyDescription_Attributes_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; args.topology = c_topology_; pjrt::LogFatalIfPjrtError(c_api_->PJRT_TopologyDescription_Attributes(&args), c_api_); @@ -2153,7 +2168,7 @@ static StatusOr> InitializeArgsAndCompileAot( const std::string& format) { PJRT_Compile_Args args; args.struct_size = PJRT_Compile_Args_STRUCT_SIZE; - args.priv = nullptr; + args.extension_start = nullptr; if (client == nullptr) { args.client = nullptr; } else { @@ -2171,7 +2186,7 @@ static StatusOr> InitializeArgsAndCompileAot( PJRT_Program program; program.struct_size = PJRT_Program_STRUCT_SIZE; - program.priv = nullptr; + program.extension_start = nullptr; program.code = const_cast(code.c_str()); program.code_size = code.size(); program.format = format.c_str(); @@ -2196,19 +2211,16 @@ StatusOr> PjRtCApiCompiler::Compile( StatusOr> PjRtCApiCompiler::Compile( CompileOptions options, mlir::ModuleOp module, const PjRtTopologyDescription& topology, PjRtClient* client) { - std::string module_bytecode; - { - llvm::raw_string_ostream os(module_bytecode); - mlir::BytecodeWriterConfig config; - // Pin bytecode version to 1 until transition to stable. - // TODO(285913864): Remove post enabling frameworks to set it. - config.setDesiredBytecodeVersion(1); - if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) - return absl::UnknownError("writeBytecodeToFile() failed."); + // TODO: Once plugins are ready, use SerializeUsingVersionedStablehlo. + std::optional plugin_version; + if (client) { + plugin_version = client->plugin_attributes()->pjrt_c_api_minor_version; } + TF_ASSIGN_OR_RETURN(std::string serialized, xla::SerializeUsingNativeBytecode( + module, plugin_version)); std::string format(pjrt::kMlirFormat); return InitializeArgsAndCompileAot(c_api_, client, options, topology, - module_bytecode, format); + serialized, format); } // -------------------------------- API access --------------------------------- @@ -2216,37 +2228,27 @@ StatusOr> PjRtCApiCompiler::Compile( StatusOr> GetCApiClient( absl::string_view device_type, const absl::flat_hash_map& create_options, - PjRtClient::KeyValueGetCallback kv_get, - PjRtClient::KeyValuePutCallback kv_put) { + std::shared_ptr kv_store) { TF_ASSIGN_OR_RETURN(const PJRT_Api* c_api, pjrt::PjrtApi(device_type)); if (c_api == nullptr) { - return InternalError("PJRT C API is nullptr for %s", device_type); + return Internal("PJRT C API is nullptr for %s", device_type); } PJRT_Client_Create_Args init_args; init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; - init_args.priv = nullptr; - TF_ASSIGN_OR_RETURN( - std::vector c_options, - pjrt::ConvertToPjRtNamedValueList(create_options, - c_api->pjrt_api_version.minor_version)); + init_args.extension_start = nullptr; + TF_ASSIGN_OR_RETURN(std::vector c_options, + pjrt::ConvertToPjRtNamedValueList(create_options)); init_args.create_options = c_options.data(); init_args.num_options = c_options.size(); std::unique_ptr kv_callback_data; - if (kv_get == nullptr && kv_put == nullptr) { - kv_callback_data = nullptr; - } else if (kv_get != nullptr && kv_put != nullptr) { - kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_get, kv_put); + if (kv_store) { + kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; - } else { - return InvalidArgument( - "Only one of KeyValueGetCallback and KeyValuePutCallback is set in " - "GetCApiClient for %s", - device_type); } RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Client_Create(&init_args), c_api); @@ -2261,16 +2263,19 @@ StatusOr> GetCApiTopology( const absl::flat_hash_map& create_options) { TF_ASSIGN_OR_RETURN(const PJRT_Api* c_api, pjrt::PjrtApi(device_type)); if (c_api == nullptr) { - return InternalError("PJRT C API is nullptr for %s", device_type); + return Internal("PJRT C API is nullptr for %s", device_type); } + return GetCApiTopology(c_api, topology_name, create_options); +} +absl::StatusOr> GetCApiTopology( + const PJRT_Api* c_api, absl::string_view topology_name, + const absl::flat_hash_map& create_options) { PJRT_TopologyDescription_Create_Args init_args; init_args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; - init_args.priv = nullptr; - TF_ASSIGN_OR_RETURN( - std::vector c_options, - pjrt::ConvertToPjRtNamedValueList(create_options, - c_api->pjrt_api_version.minor_version)); + init_args.extension_start = nullptr; + TF_ASSIGN_OR_RETURN(std::vector c_options, + pjrt::ConvertToPjRtNamedValueList(create_options)); init_args.create_options = c_options.data(); init_args.num_options = c_options.size(); init_args.topology_name = topology_name.data(); diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index 24c31bd9e089b..3fabee09f6d83 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,12 +40,14 @@ limitations under the License. #include "xla/literal.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" @@ -126,6 +128,7 @@ class PjRtCApiDevice : public PjRtDevice { bool IsAddressable() const override; int local_hardware_id() const override; + PjRtLocalHardwareId local_hardware_id_typed() const override; Status TransferToInfeed(const LiteralSlice& literal) override { return Unimplemented("PJRT C API does not support TransferToInfeed"); @@ -215,6 +218,12 @@ class PjRtCApiTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("PJRT C API does not support GetDefaultLayout"); + } + private: std::unique_ptr compiler_; const PJRT_Api* c_api_; @@ -246,9 +255,13 @@ class PjRtCApiClient : public PjRtClient { absl::Span addressable_devices() const override; StatusOr LookupDevice(int device_id) const override; + StatusOr LookupDevice( + PjRtGlobalDeviceId global_device_id) const override; StatusOr LookupAddressableDevice( int local_hardware_id) const override; + StatusOr LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const override; absl::Span memory_spaces() const override; @@ -258,10 +271,7 @@ class PjRtCApiClient : public PjRtClient { absl::string_view platform_version() const override; - std::optional plugin_attributes() const override { - return PjRtPluginAttributes{c_api_->pjrt_api_version.major_version, - c_api_->pjrt_api_version.minor_version}; - } + std::optional plugin_attributes() const override; // TODO(b/244756954): Rethink this function altogether PjRtRuntimeType runtime_type() const override { @@ -276,6 +286,12 @@ class PjRtCApiClient : public PjRtClient { return Unimplemented("PJRT C API does not support GetHloCostAnalysis"); } + StatusOr GetDefaultLayout(PrimitiveType element_type, + absl::Span dims) override { + // TODO(skyewm): implement + return Unimplemented("PJRT C API does not support GetDefaultLayout"); + } + StatusOr> Compile( const XlaComputation& computation, CompileOptions options) override; @@ -314,21 +330,21 @@ class PjRtCApiClient : public PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device) override; StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, - const Layout* device_layout) override; + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device, const Layout* device_layout) override; StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtMemorySpace* memory_space, const Layout* device_layout) override; StatusOr> BufferFromHostLiteral( @@ -402,12 +418,13 @@ class PjRtCApiClient : public PjRtClient { private: void InitDevicesAndMemorySpaces(); + void InitAttributes(); StatusOr> BufferFromHostBufferInternalImpl( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, std::variant device_or_memory, const Layout* device_layout); @@ -431,6 +448,7 @@ class PjRtCApiClient : public PjRtClient { const std::string platform_version_; const std::string platform_name_; const PjRtPlatformId platform_id_; + absl::flat_hash_map attributes_; }; class PjRtCApiBuffer : public PjRtBuffer { @@ -441,7 +459,7 @@ class PjRtCApiBuffer : public PjRtBuffer { absl::Span dimensions() const override; - const Layout& layout() const override; + std::unique_ptr layout() const override; // PJRT C API doesn't support tuple buffers. bool IsTuple() const override { return false; } @@ -470,7 +488,10 @@ class PjRtCApiBuffer : public PjRtBuffer { StatusOr> AcquireExternalReference() override; - PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + PjRtFuture LazyToLiteral( + absl::AnyInvocable() &&> generator) + override; StatusOr GetOnDeviceSizeInBytes() const override; @@ -535,7 +556,7 @@ class PjRtCApiBuffer : public PjRtBuffer { // we set on `readiness_event` modifies `readiness_promise_`. std::shared_ptr::Promise> readiness_promise_; // Set and cached the first time layout() is called. - mutable std::optional layout_; + mutable std::optional layout_; // Set and cached the first time is_dynamic_dimension() is called. mutable std::optional> is_dynamic_dimension_; @@ -728,7 +749,7 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable { // Gets common Execute_Args between Execute, ExecuteSharded and // ExecutePortable. device_complete_events in the return is set if the input // device_complete_events has value. - xla::StatusOr GetCommonExecuteArgs( + absl::StatusOr GetCommonExecuteArgs( absl::Span> argument_handles, const ExecuteOptions& options, PJRT_ExecuteOptions& c_options, std::vector>& c_argument_lists_storage, @@ -769,9 +790,16 @@ class CApiCopyToDeviceStream : public CopyToDeviceStream { StatusOr> GetCApiClient( absl::string_view device_type, const absl::flat_hash_map& create_options = {}, - PjRtClient::KeyValueGetCallback kv_get = nullptr, - PjRtClient::KeyValuePutCallback kv_put = nullptr); + std::shared_ptr kv_store = nullptr); + +absl::StatusOr> GetCApiTopology( + const PJRT_Api* c_api, absl::string_view topology_name, + const absl::flat_hash_map& create_options); +// A variant that takes `device_type` as an input, used for plugins that are not +// registered with standard way (xla_bridge.register_plugin). +// TODO(b/322357665): Delete this method after TPU plugin changes to use the +// standard registration. StatusOr> GetCApiTopology( absl::string_view device_type, absl::string_view topology_name, const absl::flat_hash_map& create_options = {}); diff --git a/xla/pjrt/pjrt_c_api_client_test.cc b/xla/pjrt/pjrt_c_api_client_test.cc index 9e9cb71d226fb..188e159419a2e 100644 --- a/xla/pjrt/pjrt_c_api_client_test.cc +++ b/xla/pjrt/pjrt_c_api_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_client.cc b/xla/pjrt/pjrt_client.cc index 293c3d4cb02dc..1b2bf4644643f 100644 --- a/xla/pjrt/pjrt_client.cc +++ b/xla/pjrt/pjrt_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,9 +68,17 @@ std::string CompiledMemoryStats::DebugString() const { "argument_size_in_bytes=$1, " "output_size_in_bytes=$2, " "alias_size_in_bytes=$3, " - "temp_size_in_bytes=$4)", + "temp_size_in_bytes=$4, " + "host_generated_code_size_in_bytes=$5, " + "host_argument_size_in_bytes=$6, " + "host_output_size_in_bytes=$7, " + "host_alias_size_in_bytes=$8, " + "host_temp_size_in_bytes=$9)", generated_code_size_in_bytes, argument_size_in_bytes, - output_size_in_bytes, alias_size_in_bytes, temp_size_in_bytes); + output_size_in_bytes, alias_size_in_bytes, temp_size_in_bytes, + host_generated_code_size_in_bytes, host_argument_size_in_bytes, + host_output_size_in_bytes, host_alias_size_in_bytes, + host_temp_size_in_bytes); } // Defining the first virtual non-pure method, which is usually the virtual diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index 176bb4a026b90..4d00ad6a7a9af 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,6 +47,7 @@ limitations under the License. #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" @@ -119,7 +120,39 @@ class PjRtDevice { // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // hosts' devices. This is the ID that should be used in a DeviceAssignment. - virtual int id() const { return description().id(); } + ABSL_DEPRECATED("Use global_device_id() instead") + virtual int id() const { return global_device_id().value(); } + + // There are several different IDs for a PJRT device. + // + // - global_device_id: The logical global device ID. This is unique among + // devices of this type (e.g. CPUs, GPUs). On multi-host platforms, this will + // be unique across all hosts' devices. This is the ID that should be used in + // a DeviceAssignment. + // + // - local_device_id: The logical local device ID. This will be used to look + // up an addressable device local to a given client. It is -1 if undefined. + // + // - local_hardware_id: The physical local device ID, e.g., the CUDA device + // number. Multiple PJRT devices can have the same local_hardware_id if + // these PJRT devices share the same physical device. This is useful for + // identifying which physical device when interacting with non-JAX code. In + // general, not guaranteed to be dense, and -1 if undefined. + + // TODO(b/314368788): Remove `id()` and replace it with this function. + virtual PjRtGlobalDeviceId global_device_id() const { + return PjRtGlobalDeviceId(description().id()); + } + + virtual PjRtLocalDeviceId local_device_id() const { + // By default, local_device_id is the same as local_hardware_id when there + // is only one PJRT device on a physical device. + return PjRtLocalDeviceId(local_hardware_id_typed().value()); + } + + // TODO(b/314368788): Remove `int local_hardware_id()` and rename this + // function to `local_hardware_id()`. + virtual PjRtLocalHardwareId local_hardware_id_typed() const = 0; // The index of the process that this device belongs to, i.e. is addressable // from. This is not always identical to PjRtClient::process_index() in a @@ -131,7 +164,10 @@ class PjRtDevice { // Opaque hardware ID, e.g., the CUDA device number, useful for identifying // which GPU when interacting with non-JAX code. In general, not guaranteed to // be dense, and -1 if undefined. - virtual int local_hardware_id() const = 0; + ABSL_DEPRECATED("Use local_hardware_id_typed() instead") + virtual int local_hardware_id() const { + return local_hardware_id_typed().value(); + } // A vendor-dependent string that uniquely identifies the kind of device, // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are @@ -405,6 +441,7 @@ class PjRtLoadedExecutable; struct PjRtPluginAttributes { int64_t pjrt_c_api_major_version; int64_t pjrt_c_api_minor_version; + absl::flat_hash_map attributes; }; // Encapsulates the state of Python session with XLA. @@ -456,21 +493,6 @@ struct PjRtPluginAttributes { // will eventually be able to make progress. class PjRtClient { public: - // In the multi-node case, the caller of PjRtClient can provide a key-value - // store accessible across nodes. The caller can provide the two callbacks - // below to access the key-value store. There are a few requirements: - // (1) KeyValueGetCallback and KeyValuePutCallback must be thread-safe. - // (2) The caller that provides the two callbacks is responsible for avoiding - // key collisions between different users of key-value store (i.e. between - // different plugins, but not between different GPU plugin nodes). - // (3) KeyValueGetCallback is blocking. - // Subclasses of PjRtClient can optionally take these callbacks in their - // constructors. - using KeyValueGetCallback = std::function( - std::string_view key, absl::Duration timeout)>; - using KeyValuePutCallback = - std::function; - PjRtClient() = default; explicit PjRtClient(std::unique_ptr host_memory_for_device_manager) @@ -500,12 +522,22 @@ class PjRtClient { virtual absl::Span addressable_devices() const = 0; // Lookup any PjRtDevice for a given PjRtDevice::id(). - virtual StatusOr LookupDevice(int device_id) const = 0; + ABSL_DEPRECATED("Use LookupDevice(PjRtGlobalDeviceId) instead") + virtual StatusOr LookupDevice(int device_id) const { + return LookupDevice(PjRtGlobalDeviceId(device_id)); + } + virtual StatusOr LookupDevice( + PjRtGlobalDeviceId global_device_id) const = 0; // Return an addressable PjRtDevice for a given // PjRtDevice::local_hardware_id(). + ABSL_DEPRECATED("Use LookupAddressableDevice(PjRtLocalDeviceId) instead") + virtual StatusOr LookupAddressableDevice( + int local_hardware_id) const { + return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); + } virtual StatusOr LookupAddressableDevice( - int local_hardware_id) const = 0; + PjRtLocalDeviceId local_device_id) const = 0; // Return all memory spaces owned by the client. // The memory spaces are in no particular order. @@ -549,6 +581,15 @@ class PjRtClient { return Unimplemented("Multi slice device assignment is not supported."); } + // Returns the default device layout for a buffer with `element_type` and + // `dims`. The default layout is a platform-specific layout used when no other + // layout is specified, e.g. for host-to-device transfers. When compiling, the + // default layout is used for program arguments and outputs unless + // user-specified or compiler-chosen layouts are requested via the + // "mhlo.layout_mode" attribute. + virtual StatusOr GetDefaultLayout(PrimitiveType element_type, + absl::Span dims) = 0; + // Returns a backend-specific HLO cost analysis visitor. virtual StatusOr> GetHloCostAnalysis() const = 0; @@ -716,6 +757,43 @@ class PjRtClient { CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) = 0; + // Creates a shapeless buffer on the device that can be partitioned into + // multiple PjRtBuffer. This class is an Arena version of + // `AsyncHostToDeviceTransferManager`. + // As a low-level interface, the user must make sure that invocations of + // `Slice` match properly with the writes from `TransferRawDataToSubBuffer`. + // + // For the intended application to Arena allocation / transfer, the user can + // use `GetOnDeviceSizeInBytes` to calculate the offsets for the host buffers + // that need to be transferred. + class PjRtRawDeviceBuffer { + public: + virtual ~PjRtRawDeviceBuffer() = default; + + // Transfers data to the device buffer. Data should already be in the + // device layout. + virtual Status TransferRawDataToSubBuffer( + const void* data, int64_t offset, int64_t transfer_size, + bool is_last_transfer, absl::AnyInvocable on_done) = 0; + + // The resulting buffer becomes ready when all transfers complete. + virtual StatusOr> Slice( + int64_t offset, PrimitiveType type, absl::Span dims, + const Layout& layout) = 0; + }; + // Creates a raw device buffer of a given size in bytes. + virtual StatusOr> CreateRawDeviceBuffer( + int64_t size, PjRtDevice* device) { + return Unimplemented("CreateRawDeviceBuffer is not implemented."); + } + + // On-device bytes required for a PjRt buffer with these `Shape` attributes. + virtual StatusOr GetOnDeviceSizeInBytes( + PrimitiveType type, absl::Span dims, + const Layout& layout) { + return Unimplemented("GetOnDeviceSizeInBytes is not implemented."); + }; + // Describes the semantics the caller to BufferFromHostBuffer expects from the // runtime, in a total order from most restrictive to least restrictive. enum class HostBufferSemantics { @@ -735,13 +813,14 @@ class PjRtClient { kImmutableUntilTransferCompletes, // The PjRtBuffer may alias `data` internally and the runtime may use the - // `data` contents as long as the buffer is alive. The caller promises to - // keep `data` alive and not to mutate its contents as long as the buffer is - // alive; to notify the caller that the buffer may be freed, the runtime - // will call `on_done_with_host_buffer` when the PjRtBuffer is freed. On - // non-CPU platforms this acts identically to - // kImmutableUntilTransferCompletes. - kZeroCopy, + // `data` contents as long as the buffer is alive. The runtime promises not + // to mutate contents of the buffer (i.e. it will not use it for aliased + // output buffers). The caller promises to keep `data` alive and also not to + // mutate its contents as long as the buffer is alive; to notify the caller + // that the buffer may be freed, the runtime will call + // `on_done_with_host_buffer` when the PjRtBuffer is freed. On non-CPU + // platforms this acts identically to kImmutableUntilTransferCompletes. + kImmutableZeroCopy, }; // on_done_with_host_buffer is optional and may be null. @@ -757,7 +836,8 @@ class PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device) = 0; + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device) = 0; // Variant of BufferFromHostBuffer that takes an optional device layout. It is // used when non-compact layout is preferred. @@ -767,8 +847,8 @@ class PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, - const Layout* device_layout) { + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device, const Layout* device_layout) { return tsl::errors::Unimplemented( "BufferFromHostBuffer with an optional device layout is not " "implemented on platform: ", @@ -781,7 +861,7 @@ class PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtMemorySpace* memory_space, const Layout* device_layout) { return tsl::errors::Unimplemented( "BufferFromHostBuffer with PjRtMemorySpace is not implemented on " @@ -929,9 +1009,12 @@ class PjRtBuffer { return on_device_shape().dimensions(); } - virtual const Layout& layout() const { + // The on-device memory layout of this buffer. Returned via unique_ptr to make + // memory management easier -- PjRtLayout is an abstract base class, so cannot + // be easily copied. + virtual std::unique_ptr layout() const { CHECK(on_device_shape().has_layout()); - return on_device_shape().layout(); + return std::make_unique(on_device_shape().layout()); } // PjRtBuffers can either represent a single array buffer or a tuple of array @@ -1015,21 +1098,20 @@ class PjRtBuffer { // completed. The transfer respects the layout of `literal`; to specify a // particular layout, set the layout before calling `ToLiteral`. virtual PjRtFuture ToLiteral(MutableLiteralBase* literal) = 0; - - // Copies the buffer's value into `literal`. Calls `on_ready` when the value - // (or an error) is ready. The transfer respects the layout of `literal`; to - // specify a particular layout, set the layout before calling `ToLiteral`. - ABSL_DEPRECATED("Use ToLiteral(...).OnReady() instead") - void ToLiteral(MutableLiteralBase* literal, - std::function on_ready) { - ToLiteral(literal).OnReady(std::move(on_ready)); - } + // This version of ToLiteral allows the implementation to defer the + // construction of the literal (e.g. until the underlying buffer is ready). + // The specific timing of calling `generator` is implementation defined, and + // might be done eagerly, but it is guaranteed to be earlier than when the + // returned future becomes ready. + virtual PjRtFuture LazyToLiteral( + absl::AnyInvocable() &&> + generator) = 0; // Synchronous overload of ToLiteral, as a convenience. Status ToLiteralSync(MutableLiteralBase* literal) { absl::Notification done; Status status; - ToLiteral(literal, [&](Status s) { + ToLiteral(literal).OnReady([&](Status s) { status = std::move(s); done.Notify(); }); @@ -1037,9 +1119,7 @@ class PjRtBuffer { return status; } - // Convenience synchronous overload that allocates a literal with a default - // layout. - StatusOr> ToLiteralSync() { + absl::StatusOr HostShape() { Shape device_shape; if (!IsTuple()) { absl::Span literal_dims; @@ -1053,7 +1133,8 @@ class PjRtBuffer { literal_dims = dimensions(); } device_shape = ShapeUtil::MakeShape(element_type(), literal_dims); - *device_shape.mutable_layout() = layout(); + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + *device_shape.mutable_layout() = GetXlaLayoutUnsafe(layout()); } else { // TODO(skyewm): does anything need to create tuple literals? The PJRT C // API doesn't support tuples or {logical_}on_device_shape(), so we prefer @@ -1063,8 +1144,14 @@ class PjRtBuffer { TF_ASSIGN_OR_RETURN(device_shape, logical_on_device_shape()); } } - auto literal = std::make_shared( - ShapeUtil::DeviceShapeToHostShape(device_shape)); + return ShapeUtil::DeviceShapeToHostShape(device_shape); + } + + // Convenience synchronous overload that allocates a literal with a default + // layout. + absl::StatusOr> ToLiteralSync() { + TF_ASSIGN_OR_RETURN(Shape host_shape, HostShape()); + auto literal = std::make_shared(host_shape); TF_RETURN_IF_ERROR(ToLiteralSync(literal.get())); return literal; } @@ -1293,25 +1380,6 @@ class PjRtBuffer { return s; } - // Calls callback when the buffer is ready. - // - // buf->OnReady(callback); - // - // is semantically almost identical to: - // - // ForkThread([]() { callback(buf->Await()); }); - // - // the only difference being that the callback may happen immediately on the - // calling thread. (The implementation may also be more efficient.) - // - // The interface makes no assumptions about what thread calls callback, so the - // caller must ensure that callback returns quickly and hands off long-running - // work or any blocking operation to a caller-managed threadpool. - ABSL_DEPRECATED("Use GetReadyFuture().OnReady() instead") - void OnReady(std::function callback) { - return GetReadyFuture().OnReady(std::move(callback)); - } - // Whether this buffer is on CPU and thus allows for certain optimizations. virtual bool IsOnCpu() const = 0; }; diff --git a/xla/pjrt/pjrt_client_test.cc b/xla/pjrt/pjrt_client_test.cc index f11f2379f9765..1bc5d6704abbf 100644 --- a/xla/pjrt/pjrt_client_test.cc +++ b/xla/pjrt/pjrt_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -176,7 +176,7 @@ TEST_P(PjRtClientTest, ExecuteWithTupleZeroCopy) { /*byte_strides=*/std::nullopt, // Use kZeroCopy to test the correctness of // `on_done_with_host_buffer`. - PjRtClient::HostBufferSemantics::kZeroCopy, + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, /*on_done_with_host_buffer=*/ [&data]() { // Deliberately modifying the content of `data`. A @@ -216,8 +216,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonation) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); ExecuteOptions options; options.execution_mode = GetParam(); @@ -249,8 +249,8 @@ TEST_P(PjRtClientTest, ExecuteWithDonationAbort) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); auto external_reference = buffer->AcquireExternalReference(); @@ -323,8 +323,8 @@ TEST_P(PjRtClientTest, ExecuteWithConcurrentUsageAndDonation) { auto buffer, client->BufferFromHostBuffer( data.data(), shape.element_type(), shape.dimensions(), /*byte_strides=*/std::nullopt, - PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, - client->addressable_devices()[0])); + PjRtClient::HostBufferSemantics::kImmutableZeroCopy, + nullptr, client->addressable_devices()[0])); ExecuteOptions options; options.execution_mode = GetParam(); @@ -563,5 +563,7 @@ ENTRY DuplicateDonationError() -> (f32[2, 2], f32[2, 2]) { } } +TEST(PjRtClientTest, GetDefaultLayout) {} + } // namespace } // namespace xla diff --git a/xla/pjrt/pjrt_client_test.h b/xla/pjrt/pjrt_client_test.h index acd0561d943df..95251296570a8 100644 --- a/xla/pjrt/pjrt_client_test.h +++ b/xla/pjrt/pjrt_client_test.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_common.h b/xla/pjrt/pjrt_common.h index 0de187db139b2..042d28acd12a0 100644 --- a/xla/pjrt/pjrt_common.h +++ b/xla/pjrt/pjrt_common.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "tsl/lib/gtl/int_type.h" + namespace xla { // bool comes before int64_t because when pybind11 tries to convert a Python @@ -29,6 +31,12 @@ namespace xla { using PjRtValueType = std::variant, float>; +// The strong-typed integer classes to better disambiguate different IDs for +// PJRT devices. +TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtGlobalDeviceId, int32_t); +TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtLocalDeviceId, int32_t); +TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtLocalHardwareId, int32_t); + } // namespace xla #endif // XLA_PJRT_PJRT_COMMON_H_ diff --git a/xla/pjrt/pjrt_compiler.cc b/xla/pjrt/pjrt_compiler.cc index ec918f2db2d11..9501e6f5908c0 100644 --- a/xla/pjrt/pjrt_compiler.cc +++ b/xla/pjrt/pjrt_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_compiler.h b/xla/pjrt/pjrt_compiler.h index f310bf96d2ae3..46c363ace1361 100644 --- a/xla/pjrt/pjrt_compiler.h +++ b/xla/pjrt/pjrt_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,10 @@ inline const char* RocmName() { static constexpr char kRocmName[] = "rocm"; return kRocmName; } +inline const char* SyclName() { + static constexpr char kSyclName[] = "sycl"; + return kSyclName; +} inline const char* TpuName() { static constexpr char kTpuName[] = "tpu"; return kTpuName; @@ -60,6 +64,10 @@ inline PjRtPlatformId RocmId() { static const PjRtPlatformId kRocmId = tsl::Fingerprint64(RocmName()); return kRocmId; } +inline PjRtPlatformId SyclId() { + static const PjRtPlatformId kSyclId = tsl::Fingerprint64(SyclName()); + return kSyclId; +} inline PjRtPlatformId TpuId() { static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName()); return kTpuId; @@ -131,6 +139,15 @@ class PjRtTopologyDescription { // Returns vendor specific attributes about the topology. virtual const absl::flat_hash_map& Attributes() const = 0; + + // Returns the default device layout for a buffer with `element_type` and + // `dims`. The default layout is a platform-specific layout used when no other + // layout is specified, e.g. for host-to-device transfers. When compiling, the + // default layout is used for program arguments and outputs unless + // user-specified or compiler-chosen layouts are requested via the + // "mhlo.layout_mode" attribute. + virtual StatusOr GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const = 0; }; // Abstract interface that all registered compilers must implement. diff --git a/xla/pjrt/pjrt_compiler_test.cc b/xla/pjrt/pjrt_compiler_test.cc index 87d31153b3840..98a2b8e8d5e16 100644 --- a/xla/pjrt/pjrt_compiler_test.cc +++ b/xla/pjrt/pjrt_compiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,6 +56,11 @@ class PjRtTestTopology : public PjRtTopologyDescription { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; TEST(PjRtCompilerTest, CompilerNotRegistered) { @@ -85,6 +90,11 @@ TEST(PjRtCompilerTest, CompilerRegistered) { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; PjRtTestTopology topology; diff --git a/xla/pjrt/pjrt_device_description.h b/xla/pjrt/pjrt_device_description.h index a2df027f334f6..2021f8e1cf196 100644 --- a/xla/pjrt/pjrt_device_description.h +++ b/xla/pjrt/pjrt_device_description.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index 141987063ae1d..4c4eae38f0cbc 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ limitations under the License. #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/execute_options.pb.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" @@ -321,7 +322,8 @@ PjRtExecutable::GetOutputDimensions() const { return output_dimensions; } -StatusOr> PjRtExecutable::GetParameterLayouts() const { +absl::StatusOr>> +PjRtExecutable::GetParameterLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); if (hlo_modules.size() > 1) { @@ -335,10 +337,18 @@ StatusOr> PjRtExecutable::GetParameterLayouts() const { "from executable."); } ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); - return comp_layout.FlattenedParameterLayouts(); + TF_ASSIGN_OR_RETURN(std::vector layouts, + comp_layout.FlattenedParameterLayouts()); + std::vector> result; + result.reserve(layouts.size()); + for (const Layout& layout : layouts) { + result.push_back(std::make_unique(layout)); + } + return result; } -StatusOr> PjRtExecutable::GetOutputLayouts() const { +absl::StatusOr>> +PjRtExecutable::GetOutputLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); if (hlo_modules.size() > 1) { @@ -352,7 +362,14 @@ StatusOr> PjRtExecutable::GetOutputLayouts() const { "from executable."); } ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); - return comp_layout.FlattenedResultLayouts(); + TF_ASSIGN_OR_RETURN(std::vector layouts, + comp_layout.FlattenedResultLayouts()); + std::vector> result; + result.reserve(layouts.size()); + for (const Layout& layout : layouts) { + result.push_back(std::make_unique(layout)); + } + return result; } StatusOr> @@ -424,7 +441,7 @@ CompileOptions::LoadEnvOptionOverrides( env_option_override.second.double_field())}); break; case OptionOverrideProto::VALUE_NOT_SET: - return InternalError("OptionOverrideProto value not set."); + return Internal("OptionOverrideProto value not set."); } } return result; diff --git a/xla/pjrt/pjrt_executable.h b/xla/pjrt/pjrt_executable.h index 0c603c20f1baf..e1a6c6f22af96 100644 --- a/xla/pjrt/pjrt_executable.h +++ b/xla/pjrt/pjrt_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,9 +33,12 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/executable_build_options.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/executable_metadata.pb.h" #include "xla/pjrt/execute_options.pb.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/service/compiler.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" @@ -44,7 +47,6 @@ limitations under the License. #include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" namespace xla { @@ -260,12 +262,13 @@ struct ExecuteOptions { const ExecuteOptionsProto& proto); }; -// Static device memory usage for a compiled program. +// Static memory usage for a compiled program. // The on-device memory needed to run an executable is at least // generated_code_size_in_bytes // + argument_size_in_bytes + output_size_in_bytes - alias_size_in_bytes // + temp_size_in_bytes. struct CompiledMemoryStats { + // Device default memory (e.g., HBM for GPU/TPU) usage stats. int64_t generated_code_size_in_bytes = 0; int64_t argument_size_in_bytes = 0; int64_t output_size_in_bytes = 0; @@ -273,8 +276,49 @@ struct CompiledMemoryStats { int64_t alias_size_in_bytes = 0; int64_t temp_size_in_bytes = 0; + // Host memory usage stats. + int64_t host_generated_code_size_in_bytes = 0; + int64_t host_argument_size_in_bytes = 0; + int64_t host_output_size_in_bytes = 0; + int64_t host_alias_size_in_bytes = 0; + int64_t host_temp_size_in_bytes = 0; + std::string serialized_hlo_proto = ""; std::string DebugString() const; + + CompiledMemoryStatsProto ToProto() { + CompiledMemoryStatsProto proto; + proto.set_generated_code_size_in_bytes(generated_code_size_in_bytes); + proto.set_argument_size_in_bytes(argument_size_in_bytes); + proto.set_output_size_in_bytes(output_size_in_bytes); + proto.set_alias_size_in_bytes(alias_size_in_bytes); + proto.set_temp_size_in_bytes(temp_size_in_bytes); + proto.mutable_hlo_proto()->ParseFromString(serialized_hlo_proto); + proto.set_host_generated_code_size_in_bytes( + host_generated_code_size_in_bytes); + proto.set_host_argument_size_in_bytes(host_argument_size_in_bytes); + proto.set_host_output_size_in_bytes(host_output_size_in_bytes); + proto.set_host_alias_size_in_bytes(host_alias_size_in_bytes); + proto.set_host_temp_size_in_bytes(host_temp_size_in_bytes); + return proto; + } + + static CompiledMemoryStats FromProto(const CompiledMemoryStatsProto& proto) { + CompiledMemoryStats stats; + stats.generated_code_size_in_bytes = proto.generated_code_size_in_bytes(); + stats.argument_size_in_bytes = proto.argument_size_in_bytes(); + stats.output_size_in_bytes = proto.alias_size_in_bytes(); + stats.alias_size_in_bytes = proto.alias_size_in_bytes(); + stats.temp_size_in_bytes = proto.temp_size_in_bytes(); + stats.serialized_hlo_proto = proto.hlo_proto().SerializeAsString(); + stats.host_generated_code_size_in_bytes = + proto.host_generated_code_size_in_bytes(); + stats.host_argument_size_in_bytes = proto.host_argument_size_in_bytes(); + stats.host_output_size_in_bytes = proto.host_output_size_in_bytes(); + stats.host_alias_size_in_bytes = proto.host_alias_size_in_bytes(); + stats.host_temp_size_in_bytes = proto.host_temp_size_in_bytes(); + return stats; + } }; class PjRtExecutable { @@ -309,10 +353,12 @@ class PjRtExecutable { GetOutputDimensions() const; // Returns the layout of each input parameter. - virtual StatusOr> GetParameterLayouts() const; + virtual absl::StatusOr>> + GetParameterLayouts() const; // Returns the layout of each output. - virtual StatusOr> GetOutputLayouts() const; + virtual absl::StatusOr>> + GetOutputLayouts() const; // Returns a list of lists of memory kind strings for output. The returned // value is `[num_programs, num_output]`. The size of the outer list should be diff --git a/xla/pjrt/pjrt_executable_test.cc b/xla/pjrt/pjrt_executable_test.cc index a66c0935e65c9..8fa614c1f050e 100644 --- a/xla/pjrt/pjrt_executable_test.cc +++ b/xla/pjrt/pjrt_executable_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/pjrt_future.h b/xla/pjrt/pjrt_future.h index dd748daa912de..9944b1f046b75 100644 --- a/xla/pjrt/pjrt_future.h +++ b/xla/pjrt/pjrt_future.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #define XLA_PJRT_PJRT_FUTURE_H_ #include +#include #include #include #include diff --git a/xla/pjrt/pjrt_layout.h b/xla/pjrt/pjrt_layout.h new file mode 100644 index 0000000000000..0fbf205de8b6e --- /dev/null +++ b/xla/pjrt/pjrt_layout.h @@ -0,0 +1,111 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_PJRT_LAYOUT_H_ +#define XLA_PJRT_PJRT_LAYOUT_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "xla/layout.h" +#include "xla/service/hlo_parser.h" +#include "xla/statusor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Abstract class representing the memory layout of a PjRtBuffer. +class PjRtLayout { + public: + virtual ~PjRtLayout() = default; + + // Returns the serialized layout as a string. + // TODO(b/328671718): add generic deserialize method to PjRtClient and/or + // PjRtCompiler. + virtual std::string Serialize() const = 0; + + // Human-readable string for error messages, user introspection, etc. + virtual std::string ToString() const = 0; + + virtual bool operator==(const PjRtLayout& other) const = 0; + + template + friend H AbslHashValue(H state, const PjRtLayout& layout) { + layout.Hash(absl::HashState::Create(&state)); + return std::move(state); + } + + protected: + virtual void Hash(absl::HashState state) const = 0; +}; + +// PjRtLayout backed by an xla::Layout. This is a convenience class for PJRT +// implementations that use XLA. PJRT users should use the PjRtLayout interface +// to be compatible with all implementations, e.g. PjRtCApiClient which doesn't +// have access to full xla::Layouts. +class PjRtXlaLayout : public PjRtLayout { + public: + explicit PjRtXlaLayout(Layout layout) : xla_layout_(std::move(layout)) { + // Strip memory space and set it to the default. PJRT tracks memory space + // separately from layout. + xla_layout_.set_memory_space(xla::Layout::kDefaultMemorySpace); + } + + std::string Serialize() const override { return xla_layout_.ToString(); } + + static StatusOr Deserialize(absl::string_view serialized) { + TF_ASSIGN_OR_RETURN(Layout xla_layout, ParseLayout(serialized)); + return PjRtXlaLayout(std::move(xla_layout)); + } + + std::string ToString() const override { return xla_layout_.ToString(); } + + bool operator==(const PjRtLayout& other) const override { + auto xla_other = dynamic_cast(&other); + if (xla_other == nullptr) { + return false; + } + return xla_layout_ == xla_other->xla_layout_; + }; + + const Layout& xla_layout() const { return xla_layout_; } + + protected: + void Hash(absl::HashState state) const override { + absl::HashState::combine(std::move(state), xla_layout_); + } + + private: + Layout xla_layout_; +}; + +// TODO(b/327524065): make callers use PjRtLayout directly instead of assuming +// an xla::Layout and get rid of this function. +inline Layout GetXlaLayoutUnsafe( + const std::unique_ptr& pjrt_layout) { + PjRtXlaLayout* xla_layout = + tensorflow::down_cast(pjrt_layout.get()); + CHECK(xla_layout != nullptr) << "Got unexpected layout type"; + return xla_layout->xla_layout(); +} + +} // namespace xla + +#endif // XLA_PJRT_PJRT_LAYOUT_H_ diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index 32dcffb93ac2f..beefac21941a5 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -70,6 +70,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -80,17 +81,20 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/match.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/client/xla_computation.h" #include "xla/cpu_function_runtime.h" @@ -99,14 +103,21 @@ limitations under the License. #include "xla/literal.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/event_pool.h" +#include "xla/pjrt/host_callback.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/metrics.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/semaphore.h" #include "xla/pjrt/tracked_device_buffer.h" +#include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" +#include "xla/primitive_util.h" +#include "xla/service/compiler.h" #include "xla/service/computation_layout.h" #include "xla/service/executable.h" #include "xla/service/generic_transfer_manager.h" @@ -115,13 +126,17 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/framework/allocator.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" @@ -131,6 +146,7 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme.h" namespace xla { @@ -247,7 +263,7 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient( // appeared. absl::c_sort(addressable_devices_, [](const PjRtDevice* a, const PjRtDevice* b) { - return a->local_hardware_id() < b->local_hardware_id(); + return a->local_device_id() < b->local_device_id(); }); } @@ -257,6 +273,16 @@ StatusOr PjRtStreamExecutorClient::GetDefaultDeviceAssignment( num_partitions); } +StatusOr PjRtStreamExecutorClient::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + TF_ASSIGN_OR_RETURN( + shape, + client()->backend().transfer_manager()->ChooseCompactLayoutForShape( + shape)); + return shape.layout(); +} + StatusOr> PjRtStreamExecutorClient::GetHloCostAnalysis() const { return std::make_unique( @@ -280,7 +306,10 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) { // This will stall computation but that's ok in this very rare error // case. if (stream != local_device->compute_stream()) { - local_device->compute_stream()->ThenWaitFor(stream); + auto status = local_device->compute_stream()->WaitFor(stream); + if (!status.ok()) { + LOG(ERROR) << "Stalling compute stream failed: " << status; + } } break; @@ -348,7 +377,8 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, if (buffers_to_release) { buffers_to_release->push_back(device_buffer.buffer()); } else { - buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer()); + buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer()) + .IgnoreError(); } } device_buffer.ConvertUsageHold(usage_stream, event, @@ -382,13 +412,13 @@ StatusOr> AllocateDestinationBuffer( TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer, transfer_manager->AllocateScopedShapedBuffer( on_host_shape, se_client->allocator(), - local_device->device_ordinal())); + local_device->local_device_id().value())); if (local_device->allocation_model() == LocalDeviceState::kComputeSynchronized) { if (copy_stream == nullptr) { CHECK(is_uninitialized_create); } else { - copy_stream->ThenWaitFor(local_device->compute_stream()); + CHECK(copy_stream->WaitFor(local_device->compute_stream()).ok()); } } else { DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( @@ -436,7 +466,8 @@ StatusOr> AllocateDestinationBuffer( if (tuple_table_stream != copy_stream) { if (local_device->allocation_model() == LocalDeviceState::kComputeSynchronized) { - tuple_table_stream->ThenWaitFor(local_device->compute_stream()); + DCHECK( + tuple_table_stream->WaitFor(local_device->compute_stream()).ok()); } else { DCHECK(transfer_manager->CanShapedBufferBeAccessedNow( local_device->compute_stream()->parent(), dst_buffer)); @@ -586,9 +617,7 @@ void PjRtStreamExecutorBuffer::ScopedHold::AddToInput( } } -bool PjRtStreamExecutorBuffer::IsOnCpu() const { - return client()->platform_id() == CpuId(); -} +bool PjRtStreamExecutorBuffer::IsOnCpu() const { return false; } StatusOr PjRtStreamExecutorBuffer::logical_on_device_shape() { if (on_device_shape_.is_static()) { @@ -729,9 +758,9 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency( original_definition_events.end()); auto new_device_buffer = std::make_shared( - tracked_buffer->allocator(), device()->local_hardware_id(), + tracked_buffer->allocator(), device()->local_device_id().value(), std::move(buffers), std::move(definition_events), - /*on_delete_callback=*/std::function()); + /*on_delete_callback=*/nullptr); // Make the new buffer which is identical to the old, except for the new // definition event. @@ -764,7 +793,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device, const Layout* device_layout) { tsl::profiler::TraceMe traceme( "PjRtStreamExecutorClient::BufferFromHostBuffer"); @@ -797,59 +826,6 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( ShapeUtil::ByteStrides(device_shape, absl::MakeSpan(shape_strides))); bool host_and_device_strides_equal = (size == 0 || *byte_strides == shape_strides); - // The CPU platform is special because the "host" and the "device" are in the - // same memory space. If the input shape is in the correct layout and we don't - // want to defer the copy onto a thread, we can use the following fast - // path. - bool is_cpu_platform = - local_device->executor()->platform()->id() == se::host::kHostPlatformId; - if (is_cpu_platform) { - // If we are on the host platform and the input buffer is sufficiently - // aligned, we can simply point to the input array's data without any - // further copies. At the time of writing we require a 16-byte alignment - // because XLA may generate code which requires it. - bool can_use_zero_copy = - host_buffer_semantics == HostBufferSemantics::kZeroCopy && - ((absl::bit_cast(data) & - (cpu_function_runtime::MinAlign() - 1)) == 0); - if (host_and_device_strides_equal && - (host_buffer_semantics == - HostBufferSemantics::kImmutableOnlyDuringCall || - can_use_zero_copy)) { - std::function on_delete_callback; - se::DeviceMemoryBase buffer; - // If we are on the host platform and the input buffer is sufficiently - // aligned, we can simply point to the input array's data without any - // further copies. At the time of writing we require a 16-byte alignment - // because XLA may generate code which requires it. - if (can_use_zero_copy) { - on_delete_callback = std::move(on_done_with_host_buffer); - buffer = se::DeviceMemoryBase( - const_cast(static_cast(data)), size); - } else { - void* staging_buffer = host_memory_allocator()->AllocateRaw( - cpu_function_runtime::MinAlign(), size); - buffer = se::DeviceMemoryBase(staging_buffer, size); - std::memcpy(staging_buffer, data, size); - if (on_done_with_host_buffer) { - on_done_with_host_buffer(); - } - on_delete_callback = [staging_buffer, host_memory_allocator = - host_memory_allocator()]() { - host_memory_allocator->DeallocateRaw(staging_buffer); - }; - } - absl::Span> - definition_events; - auto device_buffer = std::make_shared( - /*allocator=*/nullptr, local_device->device_ordinal(), - std::initializer_list{buffer}, - definition_events, std::move(on_delete_callback)); - return std::unique_ptr( - std::make_unique( - device_shape, std::move(device_buffer), this, device)); - } - } TF_ASSIGN_OR_RETURN( std::unique_ptr py_buffer, @@ -861,32 +837,48 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( py_buffer->GetBufferWithUsageHold()); CHECK(device_buffer.ok()); + std::shared_ptr transpose; + if (!host_and_device_strides_equal) { + absl::InlinedVector permutation(dims.size()); + absl::c_reverse_copy(device_shape.layout().minor_to_major(), + permutation.begin()); + TransposePlan::Options options; + options.elem_size_in_bytes = primitive_util::ByteWidth(type); + options.dims = dims; + options.permutation = permutation; + options.input_layout = TransposePlan::Striding{*byte_strides}; + absl::MutexLock lock(&transpose_mu_); + TF_ASSIGN_OR_RETURN(transpose, transpose_cache_.GetOrCreate(options)); + } + + bool should_pack = + primitive_util::Is4BitType(type) && transfer_manager->PackSubbyteTypes(); + int64_t packed_size; + if (should_pack) { + packed_size = CeilOfRatio(size, 2); + } else { + packed_size = size; + } + // If necessary, allocate a host-side buffer for staging host-to-device // transfers. On GPU this is a buffer in pinned memory. std::shared_ptr staging_buffer; - if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall || - should_stage_host_to_device_transfers() || - !host_and_device_strides_equal) { + bool must_use_staging_buffer = + host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall || + !host_and_device_strides_equal || packed_size != size; + // Allocating multigigabyte pinned buffers can be very slow. In that case, + // using a staging buffer is probably worse than not using one. + // TODO(phawkins): add chunking for transfers. + if (must_use_staging_buffer || (should_stage_host_to_device_transfers() && + packed_size < (int64_t{1} << 30))) { void* ptr = host_memory_allocator()->AllocateRaw( - tsl::Allocator::kAllocatorAlignment, size); + tsl::Allocator::kAllocatorAlignment, transpose ? size : packed_size); staging_buffer = std::shared_ptr( ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) { host_memory_allocator->DeallocateRaw(ptr); }); } - std::shared_ptr transpose; - if (!host_and_device_strides_equal) { - absl::InlinedVector permutation(dims.size()); - absl::c_reverse_copy(device_shape.layout().minor_to_major(), - permutation.begin()); - absl::MutexLock lock(&transpose_mu_); - TF_ASSIGN_OR_RETURN(transpose, - transpose_cache_.GetOrCreate( - primitive_util::ByteWidth(type), dims, permutation, - TransposePlan::Striding{*byte_strides})); - } - // Copy the buffer into a staging buffer before returning control to the // caller if the caller only guaranteed that the buffer is valid for the // duration of the call. Otherwise, we stage (if necessary) on a separate @@ -894,11 +886,23 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) { if (transpose) { transpose->Execute(data, staging_buffer.get()); + if (should_pack) { + PackInt4(absl::MakeConstSpan( + static_cast(staging_buffer.get()), size), + absl::MakeSpan(static_cast(staging_buffer.get()), + packed_size)); + } } else { - std::memcpy(staging_buffer.get(), data, size); + if (should_pack) { + PackInt4(absl::MakeConstSpan(static_cast(data), size), + absl::MakeSpan(static_cast(staging_buffer.get()), + packed_size)); + } else { + std::memcpy(staging_buffer.get(), data, size); + } } if (on_done_with_host_buffer) { - on_done_with_host_buffer(); + std::move(on_done_with_host_buffer)(); on_done_with_host_buffer = nullptr; } } @@ -911,11 +915,15 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( // put the transfer into the calling thread for small literals. auto transfer_h2d = [local_client = client(), transfer_manager, local_device, data, size, - movable_device_buffer{device_buffer.ToClosure()}, device_shape, - py_buffer{py_buffer.get()}, + type, packed_size, movable_device_buffer{device_buffer.ToClosure()}, + device_shape, should_pack, py_buffer{py_buffer.get()}, on_device_shape{py_buffer->on_device_shape()}, staging_buffer{std::move(staging_buffer)}, - on_done_with_host_buffer{std::move(on_done_with_host_buffer)}, + on_done_with_host_buffer = + on_done_with_host_buffer + ? std::make_shared>( + std::move(on_done_with_host_buffer)) + : nullptr, host_buffer_semantics, transpose{std::move(transpose)}]() { PjRtStreamExecutorBuffer::ScopedHold device_buffer( movable_device_buffer); @@ -925,7 +933,8 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( // memory that has already been allocated, and a possible Event // allocation. - ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape); + se::DeviceMemoryBase device_memory = device_buffer->device_memory()[0]; + // If applicable on the backend, stage the transfer via host memory // allocated via the host_memory_allocator. On GPU, this is pinned // memory. @@ -936,24 +945,29 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( HostBufferSemantics::kImmutableOnlyDuringCall) { if (transpose) { transpose->Execute(data, staging_buffer.get()); + if (should_pack) { + PackInt4( + absl::MakeConstSpan( + static_cast(staging_buffer.get()), size), + absl::MakeSpan(static_cast(staging_buffer.get()), + packed_size)); + } } else { - std::memcpy(staging_buffer.get(), data, size); + if (should_pack) { + PackInt4( + absl::MakeConstSpan(static_cast(data), size), + absl::MakeSpan(static_cast(staging_buffer.get()), + packed_size)); + } else { + std::memcpy(staging_buffer.get(), data, size); + } } } - // The buffer has the same dimension order as the on-device shape, but - // is not tiled, etc. - BorrowingLiteral literal( - static_cast(staging_buffer.get()), - ShapeUtil::DeviceShapeToHostShape(on_device_shape)); - TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - local_device->host_to_device_stream(), literal, buffer)); + TF_CHECK_OK(local_device->host_to_device_stream()->Memcpy( + &device_memory, staging_buffer.get(), packed_size)); } else { - BorrowingLiteral literal( - reinterpret_cast(data), - ShapeUtil::DeviceShapeToHostShape(on_device_shape)); - // Otherwise, just transfer the literal. - TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - local_device->host_to_device_stream(), literal, buffer)); + TF_CHECK_OK(local_device->host_to_device_stream()->Memcpy( + &device_memory, data, packed_size)); } std::shared_ptr event = @@ -962,22 +976,17 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( local_device, std::move(device_buffer), event, local_device->host_to_device_stream())); - local_device->ThenExecuteCallback( + TF_CHECK_OK(local_device->ThenExecuteCallback( local_device->host_to_device_stream(), [staging_buffer{std::move(staging_buffer)}, - on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() { + on_done_with_host_buffer{ + std::move(on_done_with_host_buffer)}]() mutable { if (on_done_with_host_buffer) { - on_done_with_host_buffer(); + std::move (*on_done_with_host_buffer)(); } - }); + })); }; - if (is_cpu_platform) { - // Using the thread_pool would be a double thread hop; the code - // already defers its work onto a stream (= thread on CPU). - transfer_h2d(); - } else { - thread_pool()->Schedule(transfer_h2d); - } + thread_pool()->Schedule(transfer_h2d); return std::unique_ptr(std::move(py_buffer)); } @@ -986,10 +995,11 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device) { - return BufferFromHostBuffer(data, type, dims, byte_strides, - host_buffer_semantics, on_done_with_host_buffer, - device, /*device_layout=*/nullptr); + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device) { + return BufferFromHostBuffer( + data, type, dims, byte_strides, host_buffer_semantics, + std::move(on_done_with_host_buffer), device, /*device_layout=*/nullptr); } StatusOr> @@ -1033,7 +1043,21 @@ PjRtStreamExecutorClient::CreateErrorBuffer(Status error, const Shape& shape, auto definition_event = std::make_shared(this->thread_pool()); definition_event->SetDefinedStatus(error); - return CreateUninitializedBuffer(shape, device, definition_event); + + // Create an empty buffer. + auto* se_client = tensorflow::down_cast(this); + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + tensorflow::down_cast(device) + ->GetLocalDeviceState()); + absl::Span buffers; + auto dummy_device_buffer = std::make_shared( + se_client->allocator(), local_device->local_device_id().value(), buffers, + absl::MakeSpan(&definition_event, 1), + /*on_delete_callback=*/nullptr); + + auto py_buffer = std::make_unique( + shape, std::move(dummy_device_buffer), this, device); + return py_buffer; } StatusOr> @@ -1079,6 +1103,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal, // allocation. se::Stream* h2d_stream = local_device->host_to_device_stream(); + ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( h2d_stream, literal, buffer)); @@ -1202,7 +1227,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( definition_stream); auto device_buffer = std::make_shared( - /*allocator=*/nullptr, device->local_hardware_id(), + /*allocator=*/nullptr, device->local_device_id().value(), std::initializer_list{buffer}, definition_events, std::move(on_delete_callback)); return std::unique_ptr(std::make_unique( @@ -1214,7 +1239,7 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) { // Only support infeed to local device. TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferToInfeedLocal( - literal, local_device->device_ordinal()); + literal, local_device->local_hardware_id().value()); } Status PjRtStreamExecutorDevice::TransferFromOutfeed( @@ -1222,7 +1247,7 @@ Status PjRtStreamExecutorDevice::TransferFromOutfeed( VLOG(1) << "PjRtStreamExecutorDevice::TransferFromOutfeed"; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferFromOutfeedLocal( - local_device->device_ordinal(), literal); + local_device->local_hardware_id().value(), literal); } absl::Span PjRtStreamExecutorDevice::memory_spaces() @@ -1250,13 +1275,18 @@ PjRtStreamExecutorDevice::GetStreamForExternalReadyEvents() const { StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( int local_hardware_id) const { + return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); +} + +StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( + xla::PjRtLocalDeviceId local_device_id) const { for (auto* device : addressable_devices_) { - if (local_hardware_id == device->local_hardware_id()) { + if (local_device_id == device->local_device_id()) { return device; } } - return InvalidArgument("No matching device found for local_hardware_id %d", - local_hardware_id); + return InvalidArgument("No matching device found for local_device_id %d", + local_device_id.value()); } absl::Span PjRtStreamExecutorClient::memory_spaces() @@ -1361,12 +1391,12 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { } if (block_stream != nullptr) { se::Stream* block_stream_ptr = block_stream.release(); - local_device_state->ThenExecuteCallback( + TF_RETURN_IF_ERROR(local_device_state->ThenExecuteCallback( block_stream_ptr, [device_buffer, block_stream_ptr, local_device_state]() { local_device_state->ReturnStreamToPool( std::unique_ptr(block_stream_ptr)); - }); + })); } } } @@ -1462,6 +1492,15 @@ void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type, } } +PjRtFuture PjRtStreamExecutorBuffer::LazyToLiteral( + absl::AnyInvocable() &&> generator) { + auto buffer = std::move(generator)(); + if (!buffer.ok()) { + return PjRtFuture(buffer.status()); + } + return ToLiteral(buffer.value()); +} + PjRtFuture PjRtStreamExecutorBuffer::ToLiteral( MutableLiteralBase* literal) { VLOG(1) << "PjRtStreamExecutorBuffer::ToLiteral"; @@ -1544,7 +1583,10 @@ PjRtFuture PjRtStreamExecutorBuffer::ToLiteral( local_device->event_pool().ThenRecordEvent(stream, event_or.value()); usage_event->SetSequencingEvent(std::move(event_or).value(), stream); - local_device->ThenRelease(stream, tracked_device_buffer); + defined_status = local_device->ThenRelease(stream, tracked_device_buffer); + if (!defined_status.ok()) { + promise.Set(defined_status); + } }; tracked_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( @@ -1586,6 +1628,29 @@ PjRtFuture PjRtStreamExecutorBuffer::CopyRawToHost( return client_->CopyRawSubBufferToHost(this, dst, offset, transfer_size); } +PjRtFuture PjRtStreamExecutorBuffer::CopyRawToHostFuture( + PjRtFuture> dst, int64_t offset, int64_t transfer_size) { + auto promise = PjRtFuture::CreatePromise(); + dst.OnReady([this, promise, offset, + transfer_size](absl::StatusOr dst) mutable { + if (dst.ok()) { + // Trampoline through a thread pool since some device types (e.g., GPUs) + // do not allow calling D2H inside the callback's context. + client_->thread_pool()->Schedule( + [this, dst = *dst, offset, transfer_size, + promise = std::move(promise)]() mutable { + CopyRawToHost(dst, offset, transfer_size) + .OnReady([promise = std::move(promise)](Status status) mutable { + promise.Set(status); + }); + }); + } else { + promise.Set(dst.status()); + } + }); + return PjRtFuture(std::move(promise)); +} + StatusOr PjRtStreamExecutorBuffer::AsShapedBuffer() const { absl::MutexLock lock(&mu_); if (device_buffer_ == nullptr) { @@ -1671,8 +1736,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( // returned, and StallStreamOnError only makes sure the // destination device is ok, so make sure that the src buffer // remains valid until after any transfers have completed. - src_local_device->ThenRelease(transfer_stream, - std::move(src_device_buffer)); + auto status = src_local_device->ThenRelease( + transfer_stream, std::move(src_device_buffer)); + if (!status.ok()) { + LOG(ERROR) << "ThenRelease failed due to: " << status; + } } return; } @@ -1693,8 +1761,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( copy_event->SetDefinedStatus(defined_status); } - src_local_device->ThenRelease(transfer_stream, - std::move(src_device_buffer)); + auto status = src_local_device->ThenRelease(transfer_stream, + std::move(src_device_buffer)); + if (!status.ok()) { + LOG(ERROR) << "ThenRelease failed due to: " << status; + } }; src_device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( @@ -1734,7 +1805,7 @@ StatusOr> PjRtStreamExecutorBuffer::CopyToDevice( literal_pointer->untyped_data(), literal_pointer->shape().element_type(), literal_pointer->shape().dimensions(), byte_strides, - PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, + PjRtStreamExecutorClient::HostBufferSemantics::kImmutableZeroCopy, [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); } @@ -1860,7 +1931,7 @@ PjRtFuture PjRtStreamExecutorBuffer::GetReadyFuture() { // the callback directly on that stream instead of bouncing through // local_device_state->ThenExecuteCallback. The direct callback // saves significant time. - stream_ptr->ThenDoHostCallback( + auto status = stream_ptr->DoHostCallback( [definition_promise, stream_ptr, local_device_state, event_with_status = device_buffer->definition_events()[0]]() mutable { @@ -1868,6 +1939,10 @@ PjRtFuture PjRtStreamExecutorBuffer::GetReadyFuture() { std::unique_ptr(stream_ptr)); definition_promise.Set(event_with_status->GetDefinedStatus()); }); + if (!status.ok()) { + definition_promise.Set(status); + return; + } } else { // All events are already complete; set the `definition_promise` // with the status of the buffer's first definition event which may @@ -1981,7 +2056,7 @@ StatusOr MakeTupleHelper( if (local_device->allocation_model() == LocalDeviceState::kComputeSynchronized) { - stream->ThenWaitFor(local_device->compute_stream()); + TF_RETURN_IF_ERROR(stream->WaitFor(local_device->compute_stream())); } else { DCHECK(transfer_manager->CanBufferBeAccessedNow( local_device->compute_stream()->parent(), root_table_memory.cref())); @@ -2214,24 +2289,22 @@ using tsl::MakeConstructedAsyncValueRef; // Converts PjRt SendCallbacks to an XLA StreamExecutor send function. static SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( - int device_ordinal, const ExecuteOptions& options, + int replica, const ExecuteOptions& options, tsl::thread::ThreadPool* thread_pool) { - // Check if we have callbacks registered for the given device ordinal. - if (device_ordinal >= options.send_callbacks.size()) { - return - [device_ordinal](int64_t channel_id, se::Stream*, const Shape&, - const se::DeviceMemoryBase&, - const absl::flat_hash_map&) { - return InvalidArgument( - "Failed to send a buffer to the channel_id=%d, there was no send " - "callbacks registered for the device_ordinal=%d", - channel_id, device_ordinal); - }; + // Check if we have callbacks registered for the given replica. + if (replica >= options.send_callbacks.size()) { + return [replica](int64_t channel_id, se::Stream*, const Shape&, + const se::DeviceMemoryBase&, + const absl::flat_hash_map&) { + return Internal( + "Don't send a buffer to the channel_id=%d, there was no send " + "callbacks registered for the replica=%d", + channel_id, replica); + }; } // SendCallbacks registered for a device ordinal. Can be empty. - absl::Span callbacks = - options.send_callbacks[device_ordinal]; + absl::Span callbacks = options.send_callbacks[replica]; return [callbacks, thread_pool]( int64_t channel_id, se::Stream* stream, const Shape& shape, @@ -2253,8 +2326,8 @@ static SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( // the device memory long enough to complete the memcpy command. auto done_event = MakeConstructedAsyncValueRef(stream->parent()); if (!done_event->Init()) - return InternalError("Failed to initialize done event (channel_id=%d)", - channel_id); + return Internal("Failed to initialize done event (channel_id=%d)", + channel_id); thread_pool->Schedule([done_event, stream, src, channel_id, shape, send] { tsl::profiler::TraceMe trace([&] { @@ -2266,8 +2339,16 @@ static SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( // Allocate chunk on the host for copying data from device. PjRtChunk chunk = PjRtChunk::AllocateDefault(src.size()); - stream->ThenMemcpy(chunk.data(), src, src.size()); - stream->ThenRecordEvent(&done_event.get()); + auto status = stream->Memcpy(chunk.data(), src, src.size()); + if (!status.ok()) { + done_event.SetError(status); + return; + } + status = stream->RecordEvent(&done_event.get()); + if (!status.ok()) { + done_event.SetError(status); + return; + } // Wait for the data to be available on the host. if (auto st = stream->BlockHostUntilDone(); !st.ok()) { @@ -2342,17 +2423,29 @@ class StreamExecutorCopyToDeviceStream : public CopyToDeviceStream { bool complete = IsCompleteLocked(); lock.Release(); - stream_->ThenMemcpy(&dst, chunk.data(), chunk.size()); + auto copied = stream_->Memcpy(&dst, chunk.data(), chunk.size()); + if (!copied.ok()) { + done_.SetError(copied); + return PjRtFuture(done_.GetError()); + } // Delete chunk once the memcpy operation completes. auto* chunk_ptr = std::make_unique(std::move(chunk)).release(); - stream_->ThenDoHostCallback([chunk_ptr]() { delete chunk_ptr; }); + auto deleted = stream_->DoHostCallback([chunk_ptr]() { delete chunk_ptr; }); + if (!deleted.ok()) { + done_.SetError(deleted); + return PjRtFuture(done_.GetError()); + } // Record done event once processed the last chunk. It is the caller // responsibility to synchronize with this event before submitting any new // computations to the stream. if (complete) { - stream_->ThenRecordEvent(&done_.get()); + auto recorded = stream_->RecordEvent(&done_.get()); + if (!recorded.ok()) { + done_.SetError(recorded); + return PjRtFuture(done_.GetError()); + } done_.SetStateConcrete(); } @@ -2371,23 +2464,21 @@ class StreamExecutorCopyToDeviceStream : public CopyToDeviceStream { } // namespace static RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( - int device_ordinal, const ExecuteOptions& options) { - // Check if we have callbacks registered for the given device ordinal. - if (device_ordinal >= options.send_callbacks.size()) { - return - [device_ordinal](int64_t channel_id, se::Stream*, const Shape&, - se::DeviceMemoryBase*, - const absl::flat_hash_map&) { - return InvalidArgument( - "Failed to receive a buffer from the channel_id=%d, there was no " - "recv callbacks registered for the device_ordinal=%d", - channel_id, device_ordinal); - }; + int replica, const ExecuteOptions& options) { + // Check if we have callbacks registered for the given replica. + if (replica >= options.send_callbacks.size()) { + return [replica](int64_t channel_id, se::Stream*, const Shape&, + se::DeviceMemoryBase*, + const absl::flat_hash_map&) { + return InvalidArgument( + "Failed to receive a buffer from the channel_id=%d, there was no " + "recv callbacks registered for the replica=%d", + channel_id, replica); + }; } // RecvCallbacks registered for a device ordinal. Can be empty. - absl::Span callbacks = - options.recv_callbacks[device_ordinal]; + absl::Span callbacks = options.recv_callbacks[replica]; return [callbacks](int64_t channel_id, se::Stream* stream, const Shape& shape, se::DeviceMemoryBase* dst, @@ -2414,8 +2505,8 @@ static RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( // `StreamExecutorCopyToDeviceStream` implementation above). auto done_event = MakeConstructedAsyncValueRef(stream->parent()); if (!done_event->Init()) - return InternalError("Failed to initialize done event (channel_id=%d)", - channel_id); + return Internal("Failed to initialize done event (channel_id=%d)", + channel_id); recv->callback({shape}, std::make_unique( channel_id, stream, *dst, done_event)); @@ -2439,7 +2530,8 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( std::vector>& compute_callbacks) const { int device_ordinal = tensorflow::down_cast(device) ->local_device_state() - ->device_ordinal(); + ->local_device_id() + .value(); LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); tsl::profiler::TraceMeConsumer activity( "PjRtStreamExecutorLoadedExecutable::EnqueueExecution", @@ -2546,9 +2638,9 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( // Create a PjRt<->StreamExecutor adaptors to send/recv device memory as // PjRt chunks via the user-provided callbacks. SendDeviceMemoryFunction send_device_memory = - ConvertSendCallbacksToSendFunction(device_ordinal, options, thread_pool); + ConvertSendCallbacksToSendFunction(replica, options, thread_pool); RecvDeviceMemoryFunction recv_device_memory = - ConvertRecvCallbacksToRecvFunction(device_ordinal, options); + ConvertRecvCallbacksToRecvFunction(replica, options); ExecutableRunOptions run_options; run_options.set_stream(device_state->compute_stream()); @@ -2696,7 +2788,8 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( CHECK_EQ(device->process_index(), client_->process_index()); int device_ordinal = tensorflow::down_cast(device) ->local_device_state() - ->device_ordinal(); + ->local_device_id() + .value(); tsl::profiler::TraceMe traceme( "PjRtStreamExecutorLoadedExecutable::ExecuteHelper"); VLOG(1) << "Replica " << replica << ", partition " << partition @@ -2765,13 +2858,13 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( compute_callbacks.push_back( [promise = std::move(promise)]() mutable { promise.Set(OkStatus()); }); } - device_state->ThenExecuteCallback( + TF_RETURN_IF_ERROR(device_state->ThenExecuteCallback( stream, [callbacks{std::move(compute_callbacks)}, buffers_to_release{std::move(buffers_to_release)}]() { for (auto& fn : callbacks) { fn(); } - }); + })); metrics::ReportExecutableEnqueueTime(tsl::Env::Default()->NowMicros() - start_time_usecs); return Result({/*future=*/std::move(future), /*buffers=*/std::move(outputs)}); @@ -2806,7 +2899,7 @@ PjRtStreamExecutorLoadedExecutable::Execute( << " num_partitions=" << num_partitions() << " num_addressable_devices=" << num_addressable_devices; std::vector> results(num_addressable_devices); - if (num_addressable_devices == 1) { + if (num_addressable_devices == 1 && !ThisThreadIsInsideHostCallback()) { // Fast-path if there is only one device — run the computation on the // current thread. const int replica = addressable_device_logical_ids_[0].replica; @@ -3014,6 +3107,22 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const { return Unimplemented("GetOutputMemoryKinds is not supported."); } +StatusOr +PjRtStreamExecutorLoadedExecutable::FingerprintExecutable() const { + if (executables_.size() != 1) { + return absl::InternalError( + "Fingerprinting multiple executables within one " + "PjRtStreamExecutorLoadedExecutable is not supported."); + } + + Executable* executable = executables_[0]->executable(); + if (executable->has_module()) { + return executable->module().GetFingerprint128(); + } else { + return absl::InternalError("Executable does not have HLO modules."); + } +} + StatusOr PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { ExecutableExtras extras; @@ -3156,7 +3265,7 @@ StatusOr PjRtStreamExecutorClient::SerializeExecutable( absl::Span> local_executables = se_executable->executables(); if (local_executables.empty()) { - return InternalError("No local executable"); + return Internal("No local executable"); } if (local_executables.size() != 1) { return Unimplemented( diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h index bcffdb5604162..26cd4d1bec861 100644 --- a/xla/pjrt/pjrt_stream_executor_client.h +++ b/xla/pjrt/pjrt_stream_executor_client.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ #define XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ +#include #include #include #include @@ -78,6 +79,10 @@ class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription { absl::string_view DebugString() const override { return debug_string_; } + int core_on_chip() const { return core_index_; } + + absl::Span coords() const { return absl::MakeSpan(coords_); } + const absl::flat_hash_map& Attributes() const override { return attributes_; @@ -94,13 +99,19 @@ class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription { void SetToString(std::string to_string) { to_string_ = std::move(to_string); } + void SetCoords(std::array coords) { coords_ = coords; } + + void SetCoreOnChip(int core_index) { core_index_ = core_index; } + private: const int id_; const int process_index_; const std::string device_kind_; + int core_index_ = -1; std::string debug_string_ = ""; std::string to_string_ = ""; absl::flat_hash_map attributes_; + std::array coords_; }; class PjRtStreamExecutorDevice : public PjRtDevice { @@ -109,8 +120,12 @@ class PjRtStreamExecutorDevice : public PjRtDevice { int id, std::unique_ptr local_device_state, std::string device_kind, int process_index = 0) : description_(id, std::move(device_kind), process_index), - device_ordinal_( - local_device_state ? local_device_state->device_ordinal() : -1), + local_device_id_(local_device_state + ? local_device_state->local_device_id() + : PjRtLocalDeviceId(-1)), + local_hardware_id_(local_device_state + ? local_device_state->local_hardware_id() + : PjRtLocalHardwareId(-1)), local_device_state_(std::move(local_device_state)) {} ~PjRtStreamExecutorDevice() override = default; @@ -137,9 +152,19 @@ class PjRtStreamExecutorDevice : public PjRtDevice { PjRtClient* client() const override { return client_; } - bool IsAddressable() const override { return device_ordinal_ != -1; } + bool IsAddressable() const override { return local_device_id_ != -1; } - int local_hardware_id() const override { return device_ordinal_; } + int local_hardware_id() const override { + return local_hardware_id_typed().value(); + } + + PjRtLocalDeviceId local_device_id() const override { + return local_device_id_; + } + + PjRtLocalHardwareId local_hardware_id_typed() const override { + return local_hardware_id_; + } // If this is a device local to this host, returns a LocalDeviceState object // that can be used to manipulate the device. Returns nullptr if the device is @@ -170,7 +195,8 @@ class PjRtStreamExecutorDevice : public PjRtDevice { private: PjRtStreamExecutorDeviceDescription description_; - const int device_ordinal_; // -1 means not local. + const PjRtLocalDeviceId local_device_id_; + const PjRtLocalHardwareId local_hardware_id_; const std::unique_ptr local_device_state_; PjRtClient* client_ = nullptr; }; @@ -199,16 +225,23 @@ class PjRtStreamExecutorClient : public PjRtClient { } StatusOr LookupDevice(int device_id) const override { - auto it = id_to_device_.find(device_id); + return LookupDevice(PjRtGlobalDeviceId(device_id)); + } + + StatusOr LookupDevice( + PjRtGlobalDeviceId global_device_id) const override { + auto it = id_to_device_.find(global_device_id.value()); if (it != id_to_device_.end()) { return it->second; } return InvalidArgument("No matching device found for device_id %d", - device_id); + global_device_id.value()); } StatusOr LookupAddressableDevice( int local_hardware_id) const override; + StatusOr LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const override; absl::Span memory_spaces() const override; @@ -225,6 +258,9 @@ class PjRtStreamExecutorClient : public PjRtClient { StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; + StatusOr GetDefaultLayout(PrimitiveType element_type, + absl::Span dims) override; + StatusOr> Compile( const XlaComputation& computation, CompileOptions options) override; StatusOr> Compile( @@ -276,14 +312,14 @@ class PjRtStreamExecutorClient : public PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, - const Layout* device_layout) override; + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device, const Layout* device_layout) override; StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device) override; StatusOr> BufferFromHostLiteral( @@ -630,13 +666,20 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { bool wait_for_operations_to_complete) override; using PjRtBuffer::ToLiteralSync; - PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + PjRtFuture LazyToLiteral( + absl::AnyInvocable() &&> generator) + override; StatusOr GetOnDeviceSizeInBytes() const override; PjRtFuture CopyRawToHost(void* dst, int64_t offset, int64_t transfer_size) override; + PjRtFuture CopyRawToHostFuture(PjRtFuture> dst, + int64_t offset, + int64_t transfer_size) override; + // Drops the buffer's reference to its associated device memory, leaving the // buffer in an invalid state. The memory will be freed lazily when all async // operations using the buffer have completed, according to the allocation @@ -875,6 +918,12 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { return executables_; } + absl::StatusOr GetCompileOptions() const override { + return compile_options_; + } + + absl::StatusOr FingerprintExecutable() const override; + protected: bool parameter_is_tupled_arguments() const { return parameter_is_tupled_arguments_; diff --git a/xla/pjrt/pjrt_stream_executor_client_test.cc b/xla/pjrt/pjrt_stream_executor_client_test.cc index e835ad746dbf4..26bbbaa300a22 100644 --- a/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ limitations under the License. namespace xla { namespace { -xla::StatusOr> GetClient() { +absl::StatusOr> GetClient() { LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); @@ -167,7 +167,7 @@ TEST(PjRtStreamExecutorClientTest, DonateWithControlDependency) { auto result_literal = std::make_shared( ShapeUtil::DeviceShapeToHostShape(blocked_buffer->on_device_shape())); bool got_literal = false; - blocked_buffer->ToLiteral(result_literal.get(), [&](absl::Status s) { + blocked_buffer->ToLiteral(result_literal.get()).OnReady([&](absl::Status s) { absl::MutexLock l(&mu); TF_ASSERT_OK(s); got_literal = true; @@ -176,7 +176,7 @@ TEST(PjRtStreamExecutorClientTest, DonateWithControlDependency) { EXPECT_FALSE(got_literal); - avr.emplace(tsl::OkStatus()); + avr.emplace(absl::OkStatus()); EXPECT_TRUE(future.IsReady()); { diff --git a/xla/pjrt/plugin/BUILD b/xla/pjrt/plugin/BUILD index 3a9984373625a..73319c1cb29d2 100644 --- a/xla/pjrt/plugin/BUILD +++ b/xla/pjrt/plugin/BUILD @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/pjrt/semaphore.cc b/xla/pjrt/semaphore.cc index 1a729455caf24..54fc9674b005f 100644 --- a/xla/pjrt/semaphore.cc +++ b/xla/pjrt/semaphore.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/semaphore.h b/xla/pjrt/semaphore.h index b1c42519cbdb6..d4871911ad22f 100644 --- a/xla/pjrt/semaphore.h +++ b/xla/pjrt/semaphore.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/semaphore_test.cc b/xla/pjrt/semaphore_test.cc index 348bed71321bd..6a971ee4a9158 100644 --- a/xla/pjrt/semaphore_test.cc +++ b/xla/pjrt/semaphore_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/status_casters.h b/xla/pjrt/status_casters.h new file mode 100644 index 0000000000000..bd4a044f6c304 --- /dev/null +++ b/xla/pjrt/status_casters.h @@ -0,0 +1,218 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_STATUS_CASTERS_H_ +#define XLA_PJRT_STATUS_CASTERS_H_ + +#include "xla/pjrt/exceptions.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "tsl/platform/macros.h" + +namespace xla { + +// C++ -> Python caster helpers. +// +// Failing statuses become Python exceptions; OK Status() becomes None. +// +// Given there can be only a single global pybind11 type_caster for the +// `absl::Status` type, and given XLA wants a custom exception being raised, +// we use a dedicated helper to implement this feature without relying on a +// global `type_caster`. +// +// For example: +// +// - Functions without arguments: +// m.def("my_func", []() { xla::ThrowIfError(MyFunc()); } +// - Classes with a single argument: +// py_class.def("delete", [](Buffer& self) { +// xla::ThrowIfError(self.Delete()); +// } +// +// For functions with more arguments, you can either inline the arguments, +// or use the `ThrowIfErrorWrapper` wrapper defined below: +// +// m.def("my_func", xla::ThrowIfErrorWrapper(MyFunc)); +// +// Nonstatic member functions can be wrapped by passing a +// pointer-to-member-function: +// xla::ThrowIfErrorWrapper(&MyClass::MyMethod) + +inline void ThrowIfError(absl::Status src) { + if (!src.ok()) { + throw xla::XlaRuntimeError(src); + } +} + +// If one does not want to have to define a lambda specifying the inputs +// arguments, on can use the `ThrowIfErrorWrapper` wrapper. +// +// There are three specializations: +// - For free functions, `Sig` is the function type and `F` is `Sig&`. +// - For callable types, `Sig` is the pointer to member function type +// and `F` is the type of the callable. +// - For a nonstatic member function of a class `C`, `Sig` is the function type +// and `F` is Sig C::*. +// +// In the first two cases, the wrapper returns a callable with signature `Sig`; +// in the third case, the wrapper returns callable with a modified signature +// that takes a C instance as the first argument. +template +struct ThrowIfErrorWrapper; + +// C++17 "deduction guide" that guides class template argument deduction (CTAD) +// For free functions. +template +ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; + +// For callable types (with operator()). +template +ThrowIfErrorWrapper(absl::Status (&)(Args...)) + -> ThrowIfErrorWrapper; + +// For unbound nonstatic member functions. +template +ThrowIfErrorWrapper(absl::Status (C::*)(Args...)) + -> ThrowIfErrorWrapper; + +// Template specializations. + +// For free functions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (&f)(Args...)) : func(f) {} + void operator()(Args... args) const { + xla::ThrowIfError(func(std::forward(args)...)); + } + absl::Status (&func)(Args...); +}; + +// For callable types (with operator()), non-const and const versions. +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) const { + xla::ThrowIfError(func(std::forward(args)...)); + } + F func; +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} + void operator()(Args... args) const { + xla::ThrowIfError(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} + void operator()(C& instance, Args... args) const { + xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...); +}; +template +struct ThrowIfErrorWrapper { + explicit ThrowIfErrorWrapper(absl::Status (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + void operator()(const C& instance, Args... args) const { + xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); + } + absl::Status (C::*ptmf)(Args...) const; +}; + +// Utilities for `StatusOr`. +template +T ValueOrThrow(StatusOr v) { + if (!v.ok()) { + throw xla::XlaRuntimeError(v.status()); + } + return std::move(v).value(); +} + +template +struct ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; + +template +ValueOrThrowWrapper(absl::StatusOr (&)(Args...)) + -> ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)>; + +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...)) + -> ValueOrThrowWrapper(Args...), C>; + +// Deduction guide for const methods. +template +ValueOrThrowWrapper(absl::StatusOr (C::*)(Args...) const) + -> ValueOrThrowWrapper(Args...) const, C>; + +template +struct ValueOrThrowWrapper(Args...), + absl::StatusOr (&)(Args...)> { + explicit ValueOrThrowWrapper(absl::StatusOr (&f)(Args...)) : func(f) {} + R operator()(Args... args) const { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + absl::StatusOr (&func)(Args...); +}; +template +struct ValueOrThrowWrapper (C::*)(Args...), F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) const { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; +template +struct ValueOrThrowWrapper (C::*)(Args...) const, F> { + explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} + R operator()(Args... args) const { + return xla::ValueOrThrow(func(std::forward(args)...)); + } + F func; +}; + +// For unbound nonstatic member functions, non-const and const versions. +// `ptmf` stands for "pointer to member function". +template +struct ValueOrThrowWrapper(Args...), C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...)) + : ptmf(ptmf) {} + R operator()(C& instance, Args... args) const { + return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...); +}; +template +struct ValueOrThrowWrapper(Args...) const, C> { + explicit ValueOrThrowWrapper(absl::StatusOr (C::*ptmf)(Args...) const) + : ptmf(ptmf) {} + R operator()(const C& instance, Args... args) const { + return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); + } + absl::StatusOr (C::*ptmf)(Args...) const; +}; + +} // namespace xla + +#endif // XLA_PJRT_STATUS_CASTERS_H_ diff --git a/xla/pjrt/stream_executor_executable.cc b/xla/pjrt/stream_executor_executable.cc index f80a1f26c5b06..41e0098eaa3eb 100644 --- a/xla/pjrt/stream_executor_executable.cc +++ b/xla/pjrt/stream_executor_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ StatusOr StreamExecutorExecutable::SerializeExecutable() const { proto.set_num_replicas(num_replicas_); proto.set_num_partitions(num_partitions_); proto.set_name(name_); + proto.set_fingerprint(fingerprint_); return proto.SerializeAsString(); } } // namespace xla diff --git a/xla/pjrt/stream_executor_executable.h b/xla/pjrt/stream_executor_executable.h index e9fe0b319558e..b12f21a14dd0a 100644 --- a/xla/pjrt/stream_executor_executable.h +++ b/xla/pjrt/stream_executor_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,12 +27,14 @@ class StreamExecutorExecutable : public PjRtExecutable { StreamExecutorExecutable( const CompileOptions& compile_options, std::vector> executables, - int num_replicas, int num_partitions, absl::string_view name) + int num_replicas, int num_partitions, absl::string_view name, + absl::string_view fingerprint) : compile_options_(compile_options), aot_executables_(std::move(executables)), num_replicas_(num_replicas), num_partitions_(num_partitions), - name_(name) {} + name_(name), + fingerprint_(fingerprint) {} StatusOr SerializeExecutable() const override; @@ -63,12 +65,17 @@ class StreamExecutorExecutable : public PjRtExecutable { return aot_executables_; } + StatusOr FingerprintExecutable() const override { + return fingerprint_; + } + private: CompileOptions compile_options_; std::vector> aot_executables_; int num_replicas_; int num_partitions_; std::string name_; + std::string fingerprint_; }; } // namespace xla diff --git a/xla/pjrt/stream_executor_executable.proto b/xla/pjrt/stream_executor_executable.proto index 20cc49de21252..c9572d9aaf874 100644 --- a/xla/pjrt/stream_executor_executable.proto +++ b/xla/pjrt/stream_executor_executable.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,4 +25,5 @@ message StreamExecutorExecutableProto { int32 num_replicas = 3; int32 num_partitions = 4; string name = 5; + string fingerprint = 6; } diff --git a/xla/pjrt/tf_pjrt_client.cc b/xla/pjrt/tf_pjrt_client.cc index 53b90631f4f44..d9d5102faf53c 100644 --- a/xla/pjrt/tf_pjrt_client.cc +++ b/xla/pjrt/tf_pjrt_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/tf_pjrt_client.h b/xla/pjrt/tf_pjrt_client.h index d92c32e402386..b440c15a82249 100644 --- a/xla/pjrt/tf_pjrt_client.h +++ b/xla/pjrt/tf_pjrt_client.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,9 +57,14 @@ class TfPjRtBuffer : public PjRtBuffer { override { return wrapped_->AcquireExternalReference(); } - PjRtFuture ToLiteral(MutableLiteralBase* literal) override { + PjRtFuture ToLiteral(MutableLiteralBase* literal) override { return wrapped_->ToLiteral(literal); } + PjRtFuture LazyToLiteral( + absl::AnyInvocable() &&> generator) + override { + return wrapped_->LazyToLiteral(std::move(generator)); + } StatusOr GetOnDeviceSizeInBytes() const override { return wrapped_->GetOnDeviceSizeInBytes(); } @@ -206,15 +211,23 @@ class TfPjRtClient : public PjRtClient { return wrapped_->addressable_devices(); } StatusOr LookupDevice(int device_id) const override { - return wrapped_->LookupDevice(device_id); + return LookupDevice(PjRtGlobalDeviceId(device_id)); + } + StatusOr LookupDevice( + PjRtGlobalDeviceId global_device_id) const override { + return wrapped_->LookupDevice(global_device_id.value()); } StatusOr LookupAddressableDevice( int local_hardware_id) const override { + return LookupAddressableDevice(PjRtLocalDeviceId(local_hardware_id)); + } + StatusOr LookupAddressableDevice( + PjRtLocalDeviceId local_device_id) const override { if (wrapped_ == nullptr) { return tsl::errors::Internal( "Wrapped PJRT client in TfPjRtClient is already destoryed."); } - return wrapped_->LookupAddressableDevice(local_hardware_id); + return wrapped_->LookupAddressableDevice(local_device_id); } absl::Span memory_spaces() const override { return wrapped_->memory_spaces(); @@ -235,6 +248,10 @@ class TfPjRtClient : public PjRtClient { int num_replicas, int num_partitions) const override { return wrapped_->GetDefaultDeviceAssignment(num_replicas, num_partitions); } + StatusOr GetDefaultLayout(PrimitiveType element_type, + absl::Span dims) override { + return wrapped_->GetDefaultLayout(element_type, dims); + } StatusOr> GetHloCostAnalysis() const override { return wrapped_->GetHloCostAnalysis(); @@ -275,21 +292,21 @@ class TfPjRtClient : public PjRtClient { const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, + absl::AnyInvocable on_done_with_host_buffer, PjRtDevice* device) override { return WrapBuffer(wrapped_->BufferFromHostBuffer( data, type, dims, byte_strides, host_buffer_semantics, - on_done_with_host_buffer, device)); + std::move(on_done_with_host_buffer), device)); } StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, PjRtDevice* device, - const Layout* device_layout) override { + absl::AnyInvocable on_done_with_host_buffer, + PjRtDevice* device, const Layout* device_layout) override { return WrapBuffer(wrapped_->BufferFromHostBuffer( data, type, dims, byte_strides, host_buffer_semantics, - on_done_with_host_buffer, device, device_layout)); + std::move(on_done_with_host_buffer), device, device_layout)); } StatusOr> BufferFromHostLiteral( const LiteralSlice& literal, PjRtDevice* device) override { @@ -328,6 +345,10 @@ class TfPjRtClient : public PjRtClient { StatusOr CreateHostToDeviceChannelHandle() override { return wrapped_->CreateHostToDeviceChannelHandle(); } + StatusOr GetTopologyDescription() + const override { + return wrapped_->GetTopologyDescription(); + } Status Defragment() override { return wrapped_->Defragment(); } PjRtClient* wrapped() const { return wrapped_.get(); } diff --git a/xla/pjrt/tf_pjrt_client_test.cc b/xla/pjrt/tf_pjrt_client_test.cc index 9e3b785fc0185..2e946459fbebd 100644 --- a/xla/pjrt/tf_pjrt_client_test.cc +++ b/xla/pjrt/tf_pjrt_client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/tfrt_cpu_pjrt_client.h b/xla/pjrt/tfrt_cpu_pjrt_client.h index 7fa97e13118f0..b76d14f48ef9c 100644 --- a/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/tracked_device_buffer.cc b/xla/pjrt/tracked_device_buffer.cc index c071a7d8a827d..c1a699e2de4b8 100644 --- a/xla/pjrt/tracked_device_buffer.cc +++ b/xla/pjrt/tracked_device_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include #include +#include "absl/functional/any_invocable.h" #include "absl/synchronization/mutex.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/utils.h" @@ -33,6 +35,8 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/event.h" #include "xla/types.h" +#include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/context_types.h" namespace xla { @@ -75,7 +79,7 @@ void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) { return; } - stream->ThenWaitFor(event_.event()); + stream->WaitFor(event_.event()).IgnoreError(); streams_defined_on_.push_back(stream); } @@ -121,11 +125,21 @@ bool BufferSequencingEvent::IsComplete() { void BufferSequencingEvent::ExecuteOrAddToFutureTasks( const std::string& task_name, std::function task) { absl::MutexLock lock(&mu_); + tsl::profiler::TraceMeProducer producer( + "BufferSequencingEvent::ExecuteOrAddToFutureTasks", + tsl::profiler::ContextType::kPjRt); + uint64_t context_id = producer.GetContextId(); + auto wrapped_task = [task = std::move(task), context_id]() { + tsl::profiler::TraceMeConsumer consumer("BufferSequencingEvent::Execute", + tsl::profiler::ContextType::kPjRt, + context_id); + task(); + }; if (defined_status_.IsConcrete()) { - thread_pool_->Schedule(std::move(task)); + thread_pool_->Schedule(std::move(wrapped_task)); return; } - on_ready_tasks_callback_[task_name] = std::move(task); + on_ready_tasks_callback_[task_name] = std::move(wrapped_task); } void BufferSequencingEvent::ExecuteFutureTasks() { @@ -207,7 +221,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, absl::Span> definition_events, - std::function on_delete_callback) + absl::AnyInvocable on_delete_callback) : allocator_(allocator), device_ordinal_(device_ordinal), device_memory_(device_memory.begin(), device_memory.end()), @@ -226,7 +240,7 @@ TrackedDeviceBuffer::~TrackedDeviceBuffer() { } } if (on_delete_callback_) { - on_delete_callback_(); + std::move(on_delete_callback_)(); } } diff --git a/xla/pjrt/tracked_device_buffer.h b/xla/pjrt/tracked_device_buffer.h index d3a7761f9345b..218d0bad85cd1 100644 --- a/xla/pjrt/tracked_device_buffer.h +++ b/xla/pjrt/tracked_device_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/utils.h" @@ -271,7 +272,7 @@ class TrackedDeviceBuffer { absl::Span device_memory, absl::Span> definition_events, - std::function on_delete_callback); + absl::AnyInvocable on_delete_callback); ~TrackedDeviceBuffer(); private: @@ -300,7 +301,7 @@ class TrackedDeviceBuffer { StreamAndEventContainer usage_events_; // A callback to call when the TrackedDeviceBuffer is about to be destroyed. - std::function on_delete_callback_; + absl::AnyInvocable on_delete_callback_; }; // Populates 'events' with the set of buffer events for buffer. If diff --git a/xla/pjrt/tracked_device_buffer_test.cc b/xla/pjrt/tracked_device_buffer_test.cc index 2c83d927a5147..9b4b238775443 100644 --- a/xla/pjrt/tracked_device_buffer_test.cc +++ b/xla/pjrt/tracked_device_buffer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/transpose.cc b/xla/pjrt/transpose.cc index acab3a8991e97..2f16afbc2f6d6 100644 --- a/xla/pjrt/transpose.cc +++ b/xla/pjrt/transpose.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,14 +85,18 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "xla/ef57.h" #include "xla/permutation_util.h" #include "xla/pjrt/transpose_kernels.h" #include "xla/status.h" +#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" @@ -133,15 +137,6 @@ struct TransposePlan::Node { bool is_inner_dim_in_b = false; }; -void ConvertF64ToEf57(const double* input, float* output, int n) { - // TODO(phawkins): vectorize this transformation. - for (int i = 0; i < n; ++i) { - std::tie(output[0], output[1]) = SplitF64ToF32(*input); - ++input; - output += 2; - } -} - template void MacroKernel(const char* __restrict a, int64_t lda, int outer_bs_a, @@ -156,10 +151,23 @@ void MacroKernel(const char* __restrict a, int64_t lda, int outer_bs_a, if constexpr (transformation == TransposePlan::Transformation::kF64ToEf57) { DCHECK_EQ(outer_bs_a * inner_bs % 2, 0); float* p = reinterpret_cast(scratch); - for (int i = 0; i < outer_bs_b * inner_bs; ++i) { - ConvertF64ToEf57(reinterpret_cast(a + lda * i), - p + outer_bs_a * inner_bs * i, - outer_bs_a * inner_bs / 2); + if (ABSL_PREDICT_TRUE(lda == sizeof(double) && + outer_bs_a * inner_bs == 2)) { + absl::Span input = absl::MakeConstSpan( + reinterpret_cast(a), outer_bs_b * inner_bs); + absl::Span output = + absl::MakeSpan(reinterpret_cast(p), input.size() * 2); + ConvertF64ToEf57(input, output); + } else { + for (int i = 0; i < outer_bs_b * inner_bs; ++i) { + absl::Span input = + absl::MakeConstSpan(reinterpret_cast(a + lda * i), + outer_bs_a * inner_bs / 2); + absl::Span output = absl::MakeSpan( + reinterpret_cast(p + outer_bs_a * inner_bs * i), + input.size() * 2); + ConvertF64ToEf57(input, output); + } } a = reinterpret_cast(scratch); lda = outer_bs_a * inner_bs * sizeof(float); @@ -482,15 +490,17 @@ void TransposePlan::Execute( execute_by_type(nodes); } } else { - absl::BlockingCounter counter(nodes_.size()); - for (absl::Span nodes : nodes_) { + absl::BlockingCounter counter(nodes_.size() - 1); + for (size_t i = 1; i < nodes_.size(); ++i) { + absl::Span nodes = nodes_[i]; schedule_work([&, nodes]() { - tsl::profiler::TraceMe traceme("Transpose::Execute", - /*level=*/2); + tsl::profiler::TraceMe traceme("Transpose::Execute", /*level=*/2); execute_by_type(nodes); counter.DecrementCount(); }); } + // Run the first chunk inline in this thread. + execute_by_type(nodes_[0]); counter.Wait(); } } @@ -660,7 +670,7 @@ static Status ParseTilingSpecification( tiling.resize(ndim, 1); if (tiling_spec.size() > ndim) { return InvalidArgument( - "Tiling (%s) must have at as many dimensions as the array (%d)", + "Tiling (%s) must have at most as many dimensions as the array (%d)", absl::StrJoin(tiling_spec, ","), ndim); } if (absl::c_find_if(tiling_spec, [](int64_t d) { return d < 1; }) != @@ -668,6 +678,11 @@ static Status ParseTilingSpecification( return InvalidArgument("Tiling sizes (%s) must be >= 1", absl::StrJoin(tiling_spec, ",")); } + if (ndim == 1) { + // Tiling doesn't do anything for a rank-1 array, except add padding. Since + // we're not going to touch any padding elements, we can ignore it. + return OkStatus(); + } int offset = ndim; offset -= tiling_spec.size(); absl::c_copy(tiling_spec, tiling.begin() + offset); @@ -867,35 +882,32 @@ void TransposePlan::BuildPlanNodes( } StatusOr> TransposePlan::Create( - size_t elem_size_in_bytes, absl::Span dims, - absl::Span permutation, - std::variant input_layout, Tiling output_tiling, - Transformation transformation, int num_threads) { + const Options& o) { auto is_negative = [](int d) { return d < 0; }; - if (absl::c_find_if(dims, is_negative) != dims.end()) { + if (absl::c_find_if(o.dims, is_negative) != o.dims.end()) { return InvalidArgument("dims must be non-negative, got %s", - absl::StrJoin(dims, ",")); + absl::StrJoin(o.dims, ",")); } - if (permutation.size() != dims.size()) { + if (o.permutation.size() != o.dims.size()) { return InvalidArgument( "dims and permutation must have equal sizes, got %d and %d", - dims.size(), permutation.size()); + o.dims.size(), o.permutation.size()); } - if (!IsPermutation(permutation)) { + if (!IsPermutation(o.permutation)) { return InvalidArgument("permutation argument is not valid, got: %s", - absl::StrJoin(permutation, ",")); + absl::StrJoin(o.permutation, ",")); } - if (num_threads < 1) { + if (o.num_threads < 1) { return InvalidArgument("num_threads argument must be >= 1, got: %d", - num_threads); + o.num_threads); } - int ndim = dims.size(); + int ndim = o.dims.size(); auto plan = std::make_unique(); - plan->num_threads_requested_ = num_threads; - plan->elem_size_in_bytes_ = elem_size_in_bytes; - switch (elem_size_in_bytes) { + plan->num_threads_requested_ = o.num_threads; + plan->elem_size_in_bytes_ = o.elem_size_in_bytes; + switch (o.elem_size_in_bytes) { case 1: case 2: case 4: @@ -904,26 +916,26 @@ StatusOr> TransposePlan::Create( break; default: return InvalidArgument("Unsupported elem_size_in_bytes=%d", - elem_size_in_bytes); + o.elem_size_in_bytes); } - plan->num_elems_ = std::accumulate(dims.begin(), dims.end(), int64_t{1}, + plan->num_elems_ = std::accumulate(o.dims.begin(), o.dims.end(), int64_t{1}, std::multiplies()); plan->original_a_dims_.resize(ndim); - absl::c_copy(dims, plan->original_a_dims_.begin()); - plan->original_b_dims_ = Permute(dims, permutation); + absl::c_copy(o.dims, plan->original_a_dims_.begin()); + plan->original_b_dims_ = Permute(o.dims, o.permutation); TF_RETURN_IF_ERROR( - ParseTilingSpecification(ndim, output_tiling.tiling, plan->b_tiling_)); + ParseTilingSpecification(ndim, o.output_tiling.tiling, plan->b_tiling_)); // Handles strides. - if (std::holds_alternative(input_layout)) { + if (std::holds_alternative(o.input_layout)) { absl::Span input_strides_in_bytes = - std::get(input_layout).strides_in_bytes; - if (input_strides_in_bytes.size() != dims.size()) { + std::get(o.input_layout).strides_in_bytes; + if (input_strides_in_bytes.size() != o.dims.size()) { return InvalidArgument( "dims and input_strides_in_bytes must have equal sizes, got %d " "and %d", - dims.size(), input_strides_in_bytes.size()); + o.dims.size(), input_strides_in_bytes.size()); } plan->original_a_strides_.resize(ndim); absl::c_copy(input_strides_in_bytes, plan->original_a_strides_.begin()); @@ -936,20 +948,20 @@ StatusOr> TransposePlan::Create( int64_t stride = input_strides_in_bytes.at(k); // If there is a dimension with size equal to the element size, sort it // last. This ensures that we place any stride-1 dimension last. - bool is_stride1 = stride == elem_size_in_bytes; + bool is_stride1 = stride == o.elem_size_in_bytes; // If there are multiple stride-1 dimensions, we'd prefer the one that // matches the stride-1 dimension of the output. // Failing that, we'd just prefer the largest stride-1 dimension last. - bool is_trailing_dim_in_b = permutation.back() == k; + bool is_trailing_dim_in_b = o.permutation.back() == k; // If we are applying ef57 conversion, we want a size-2 stride-1 // dimension last. bool ef57_even = - (is_stride1 && transformation == Transformation::kF64ToEf57 && - dims[k] == 2); + (is_stride1 && o.transformation == Transformation::kF64ToEf57 && + o.dims[k] == 2); return std::make_tuple(is_stride1, -std::abs(stride), ef57_even, - is_trailing_dim_in_b, dims[k]); + is_trailing_dim_in_b, o.dims[k]); }; absl::c_stable_sort(dim_order, [&cost](int i, int j) { return cost(i) < cost(j); }); @@ -961,18 +973,18 @@ StatusOr> TransposePlan::Create( plan->permutation_.reserve(ndim); for (int i = 0; i < ndim; ++i) { plan->lda_.push_back(input_strides_in_bytes.at(dim_order[i])); - plan->a_dims_.push_back(dims[dim_order[i]]); - plan->permutation_.push_back(inv_dim_order[permutation[i]]); + plan->a_dims_.push_back(o.dims[dim_order[i]]); + plan->permutation_.push_back(inv_dim_order[o.permutation[i]]); } plan->lda_tile_.resize(ndim, 1); plan->a_tiling_.resize(ndim, 1); } else { TF_RETURN_IF_ERROR(ParseTilingSpecification( - ndim, std::get(input_layout).tiling, plan->a_tiling_)); + ndim, std::get(o.input_layout).tiling, plan->a_tiling_)); plan->a_dims_ = plan->original_a_dims_; plan->permutation_.resize(ndim); - absl::c_copy(permutation, plan->permutation_.begin()); + absl::c_copy(o.permutation, plan->permutation_.begin()); ComputeStrides(plan->elem_size_in_bytes_, plan->a_dims_, plan->a_tiling_, plan->lda_, plan->lda_tile_); } @@ -990,15 +1002,15 @@ StatusOr> TransposePlan::Create( absl::StrJoin(plan->b_tiling_, ",")); } - plan->transformation_ = transformation; - switch (transformation) { + plan->transformation_ = o.transformation; + switch (o.transformation) { case Transformation::kNone: break; case Transformation::kF64ToEf57: - if (elem_size_in_bytes != sizeof(float)) { + if (o.elem_size_in_bytes != sizeof(float)) { return InvalidArgument( "EF57 conversion requires a element size of %d bytes, got %d", - sizeof(float), elem_size_in_bytes); + sizeof(float), o.elem_size_in_bytes); } if (plan->a_dims_.empty() || plan->a_dims_.back() % 2 != 0 || plan->lda_.back() != sizeof(float)) { @@ -1336,43 +1348,36 @@ TransposePlanCache::TransposePlanCache(int capacity) TransposePlanCache::~TransposePlanCache() = default; StatusOr> TransposePlanCache::GetOrCreate( - size_t elem_size_in_bytes, absl::Span dims, - absl::Span permutation, - std::variant input_layout, - TransposePlan::Tiling output_tiling, - TransposePlan::Transformation transformation, int num_threads) { + const TransposePlan::Options& o) { TransposePlanCacheKey key; - key.elem_size_in_bytes = elem_size_in_bytes; - key.dims.resize(dims.size()); - absl::c_copy(dims, key.dims.begin()); - key.permutation.resize(permutation.size()); - absl::c_copy(permutation, key.permutation.begin()); - if (std::holds_alternative(input_layout)) { + key.elem_size_in_bytes = o.elem_size_in_bytes; + key.dims.resize(o.dims.size()); + absl::c_copy(o.dims, key.dims.begin()); + key.permutation.resize(o.permutation.size()); + absl::c_copy(o.permutation, key.permutation.begin()); + if (std::holds_alternative(o.input_layout)) { absl::Span input_strides_in_bytes = - std::get(input_layout).strides_in_bytes; + std::get(o.input_layout).strides_in_bytes; key.input_layout = absl::InlinedVector( input_strides_in_bytes.begin(), input_strides_in_bytes.end()); key.input_layout_is_tiling = false; } else { absl::Span input_tiling = - std::get(input_layout).tiling; + std::get(o.input_layout).tiling; key.input_layout = absl::InlinedVector(input_tiling.begin(), input_tiling.end()); key.input_layout_is_tiling = true; } - key.output_tiling.resize(output_tiling.tiling.size()); - absl::c_copy(output_tiling.tiling, key.output_tiling.begin()); - key.transformation = transformation; - key.num_threads = num_threads; + key.output_tiling.resize(o.output_tiling.tiling.size()); + absl::c_copy(o.output_tiling.tiling, key.output_tiling.begin()); + key.transformation = o.transformation; + key.num_threads = o.num_threads; return cache_.GetOrCreateIfAbsent( key, [&](const TransposePlanCacheKey& key) -> StatusOr> { - TF_ASSIGN_OR_RETURN( - std::unique_ptr plan, - TransposePlan::Create(elem_size_in_bytes, dims, permutation, - input_layout, output_tiling, transformation, - num_threads)); + TF_ASSIGN_OR_RETURN(std::unique_ptr plan, + TransposePlan::Create(o)); return std::shared_ptr(std::move(plan)); }); } diff --git a/xla/pjrt/transpose.h b/xla/pjrt/transpose.h index 7bf3bf48dfc89..767d707fdb88e 100644 --- a/xla/pjrt/transpose.h +++ b/xla/pjrt/transpose.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #ifndef XLA_PJRT_TRANSPOSE_H_ #define XLA_PJRT_TRANSPOSE_H_ +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "absl/types/variant.h" #include "xla/pjrt/lru_cache.h" #include "xla/statusor.h" @@ -86,13 +88,18 @@ class TransposePlan { kF64ToEf57 = 1, }; + struct Options { + size_t elem_size_in_bytes; + absl::Span dims; + absl::Span permutation; + std::variant input_layout = Tiling{}; + Tiling output_tiling; + Transformation transformation = Transformation::kNone; + int num_threads = 1; + }; + static StatusOr> Create( - size_t elem_size_in_bytes, absl::Span dims, - absl::Span permutation, - std::variant input_layout = Tiling{}, - Tiling output_tiling = Tiling{}, - Transformation transformation = Transformation::kNone, - int num_threads = 1); + const Options& options); TransposePlan(); ~TransposePlan(); @@ -276,14 +283,7 @@ class TransposePlanCache { // Creates or returns a cached copy of a transpose plan. StatusOr> GetOrCreate( - size_t elem_size_in_bytes, absl::Span dims, - absl::Span permutation, - std::variant - input_layout = TransposePlan::Tiling{}, - TransposePlan::Tiling output_tiling = TransposePlan::Tiling{}, - TransposePlan::Transformation transformation = - TransposePlan::Transformation::kNone, - int num_threads = 1); + const TransposePlan::Options& options); private: LRUCache #include -#if (defined(__GNUC__) || defined(__clang__)) && defined(__SSE2__) -#define XLA_HAS_SSE2 -#elif defined(_MSC_VER) && !defined(_M_ARM64EC) && defined(_M_X64) -#define XLA_HAS_SSE2 -#elif defined(_MSC_VER) && !defined(_M_ARM64EC) && \ - (defined(_M_IX86_FP) && _M_IX86_FP >= 2) -#define XLA_HAS_SSE2 -#elif defined(__AVX__) -#define XLA_HAS_SSE2 -#endif +#include "xla/compiler_macros.h" -#if defined(__ARM_NEON) && !defined(__ARM_BIG_ENDIAN) -#define XLA_HAS_ARM_NEON -#endif +namespace xla { #ifdef XLA_HAS_SSE2 #include // IWYU pragma: keep @@ -43,31 +32,11 @@ limitations under the License. #ifdef XLA_HAS_ARM_NEON #include -#endif +#endif // XLA_HAS_ARM_NEON #if defined(XLA_HAS_SSE2) || defined(XLA_HAS_ARM_NEON) #define XLA_HAS_VEC128 -#endif - -namespace xla { - -#pragma push_macro("XLA_UNROLL") -#if defined(__clang__) -#define XLA_UNROLL _Pragma("unroll") -#elif defined(__GNUC__) -#define XLA_UNROLL _Pragma("GCC unroll 128") -#else -#define XLA_UNROLL -#endif - -#pragma push_macro("XLA_FLATTEN") -#if defined(__GNUC__) || defined(__clang__) -#define XLA_FLATTEN __attribute__((flatten)) -#elif defined(_MSC_VER) -#define XLA_FLATTEN [[msvc::flatten]] -#else -#define XLA_FLATTEN -#endif +#endif // defined(XLA_HAS_SSE2) || defined(XLA_HAS_ARM_NEON) // The transpose microkernels use a general approach of zipping elements from // different rows together. We start zipping together elements of size 1, size 2 @@ -722,9 +691,6 @@ struct TransposeMicroKernel { } }; -#pragma pop_macro("XLA_FLATTEN") -#pragma pop_macro("XLA_UNROLL") - } // namespace xla #endif // XLA_PJRT_TRANSPOSE_KERNELS_H_ diff --git a/xla/pjrt/transpose_test.cc b/xla/pjrt/transpose_test.cc index 62e507b962f42..011a1ca5eb478 100644 --- a/xla/pjrt/transpose_test.cc +++ b/xla/pjrt/transpose_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #include "xla/pjrt/transpose.h" #include +#include +#include #include #include #include @@ -119,10 +121,17 @@ TEST(TransposeTest, CoalesceDimensions) { } TEST(TransposeTest, InvalidTilings) { - auto plan = - TransposePlan::Create(sizeof(float), {3, 4, 5}, {0, 1, 2}, - /*input_layout=*/TransposePlan::Tiling{{8, 128}}, - /*output_tiling=*/TransposePlan::Tiling{{4}}); + TransposePlan::Options options; + std::vector dims = {3, 4, 5}; + std::vector perm = {0, 1, 2}; + options.elem_size_in_bytes = sizeof(float); + options.dims = dims; + options.permutation = perm; + std::vector input_tiling = {8, 128}; + std::vector output_tiling = {4}; + options.input_layout = TransposePlan::Tiling{input_tiling}; + options.output_tiling = TransposePlan::Tiling{output_tiling}; + auto plan = TransposePlan::Create(options); EXPECT_EQ(plan.status().code(), tsl::error::UNIMPLEMENTED); EXPECT_THAT( plan.status().message(), @@ -360,12 +369,15 @@ class TransposeTest : public ::testing::TestWithParam { tsl::thread::ThreadPool threadpool(tsl::Env::Default(), "Transpose", parallelism); std::vector output_dims = Permute(test.dims, test.permutation); - TF_ASSERT_OK_AND_ASSIGN( - auto plan, TransposePlan::Create( - sizeof(T), test.dims, test.permutation, - TransposePlan::Tiling{test.input_tiling}, - TransposePlan::Tiling{test.output_tiling}, - TransposePlan::Transformation::kNone, parallelism)); + TransposePlan::Options options; + options.elem_size_in_bytes = sizeof(T); + options.dims = test.dims; + options.permutation = test.permutation; + options.input_layout = TransposePlan::Tiling{test.input_tiling}; + options.output_tiling = TransposePlan::Tiling{test.output_tiling}; + options.transformation = TransposePlan::Transformation::kNone; + options.num_threads = parallelism; + TF_ASSERT_OK_AND_ASSIGN(auto plan, TransposePlan::Create(options)); VLOG(1) << plan->ToString(); xla::Array untiled_input(test.dims); untiled_input.FillIota(0); @@ -406,10 +418,15 @@ TEST(TransposeTest, NegativeStrides1D) { std::vector expected(n); absl::c_iota(input, int32_t{7}); std::iota(expected.rbegin(), expected.rend(), 7); - TF_ASSERT_OK_AND_ASSIGN( - auto plan, TransposePlan::Create( - sizeof(int32_t), {n}, /*permutation=*/{0}, - TransposePlan::Striding{{-int64_t{sizeof(int32_t)}}})); + std::vector dims = {n}; + std::vector permutation = {0}; + TransposePlan::Options options; + options.elem_size_in_bytes = sizeof(int32_t); + options.dims = dims; + options.permutation = permutation; + std::vector strides = {-int64_t{sizeof(int32_t)}}; + options.input_layout = TransposePlan::Striding{strides}; + TF_ASSERT_OK_AND_ASSIGN(auto plan, TransposePlan::Create(options)); plan->Execute(input.data() + (n - 1), output.data()); EXPECT_EQ(expected, output); } @@ -427,11 +444,16 @@ TEST(TransposeTest, NegativeStrides2D) { {1, 5, 9}, }; xla::Array output({4, 3}); - TF_ASSERT_OK_AND_ASSIGN( - auto plan, TransposePlan::Create( - sizeof(int16_t), {3, 4}, /*permutation=*/{1, 0}, - TransposePlan::Striding{ - {4 * sizeof(int16_t), -int64_t{sizeof(int16_t)}}})); + std::vector dims = {3, 4}; + std::vector permutation = {1, 0}; + TransposePlan::Options options; + options.elem_size_in_bytes = sizeof(int16_t); + options.dims = dims; + options.permutation = permutation; + std::vector strides = {4 * sizeof(int16_t), + -int64_t{sizeof(int16_t)}}; + options.input_layout = TransposePlan::Striding{strides}; + TF_ASSERT_OK_AND_ASSIGN(auto plan, TransposePlan::Create(options)); plan->Execute(input.data() + 3, output.data()); EXPECT_EQ(expected, output); } @@ -497,11 +519,15 @@ static void BM_Eigen_float(const TransposeTestCase& bm, int parallelism, template void BM_Transpose(const TransposeTestCase& bm, int parallelism, ::testing::benchmark::State& state) { - TF_ASSERT_OK_AND_ASSIGN( - auto plan, - TransposePlan::Create(sizeof(T), bm.dims, bm.permutation, - TransposePlan::Tiling{}, TransposePlan::Tiling{}, - TransposePlan::Transformation::kNone, parallelism)); + TransposePlan::Options options; + options.elem_size_in_bytes = sizeof(T); + options.dims = bm.dims; + options.permutation = bm.permutation; + options.input_layout = TransposePlan::Tiling{}; + options.output_tiling = TransposePlan::Tiling{}; + options.transformation = TransposePlan::Transformation::kNone; + options.num_threads = parallelism; + TF_ASSERT_OK_AND_ASSIGN(auto plan, TransposePlan::Create(options)); Array input(bm.dims); input.FillIota(0); std::vector output_dims = Permute(bm.dims, bm.permutation); @@ -556,25 +582,31 @@ static void* benchmarks = []() { }(); TEST(TransposePlanCache, Basics) { + std::vector dims = {1, 2, 3}; + std::vector permutation_210 = {2, 1, 0}; + std::vector permutation_120 = {1, 2, 0}; + std::vector permutation_012 = {0, 1, 2}; TransposePlanCache cache(2); - TF_ASSERT_OK_AND_ASSIGN( - auto p1, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3}, - /*permutation=*/{2, 1, 0})); - TF_ASSERT_OK_AND_ASSIGN( - auto p1a, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3}, - /*permutation=*/{2, 1, 0})); + TransposePlan::Options o; + o.elem_size_in_bytes = 4; + o.dims = dims; + o.permutation = permutation_210; + TF_ASSERT_OK_AND_ASSIGN(auto p1, cache.GetOrCreate(o)); + TF_ASSERT_OK_AND_ASSIGN(auto p1a, cache.GetOrCreate(o)); EXPECT_TRUE(p1.get() == p1a.get()); - TF_ASSERT_OK_AND_ASSIGN( - auto p2, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3}, - /*permutation=*/{1, 2, 0})); + TransposePlan::Options o2; + o2.elem_size_in_bytes = 4; + o2.dims = dims; + o2.permutation = permutation_120; + TF_ASSERT_OK_AND_ASSIGN(auto p2, cache.GetOrCreate(o2)); EXPECT_TRUE(p1.get() != p2.get()); - TF_ASSERT_OK_AND_ASSIGN( - auto p3, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3}, - /*permutation=*/{0, 1, 2})); + TransposePlan::Options o3; + o3.elem_size_in_bytes = 4; + o3.dims = dims; + o3.permutation = permutation_012; + TF_ASSERT_OK_AND_ASSIGN(auto p3, cache.GetOrCreate(o3)); EXPECT_TRUE(p3.get() != p1.get()); - TF_ASSERT_OK_AND_ASSIGN( - auto p1b, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3}, - /*permutation=*/{2, 1, 0})); + TF_ASSERT_OK_AND_ASSIGN(auto p1b, cache.GetOrCreate(o)); EXPECT_TRUE(p1.get() != p1b.get()); } diff --git a/xla/pjrt/utils.cc b/xla/pjrt/utils.cc index 59d12ced99c16..c5a0f6016c5ad 100644 --- a/xla/pjrt/utils.cc +++ b/xla/pjrt/utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/pjrt/layout_mode.h" #include "xla/primitive_util.h" @@ -206,6 +207,54 @@ static StatusOr> MlirAttrsToLayoutModes( return result; } +// TODO(b/329428415): Make this generic enough to be used by the GPU and TPU +// compilers. +StatusOr GetMemorySpaceColor(const std::string& memory_kind) { + // TODO(yashkatariya,zce): Unpinned_host is not valid for compiler. Only + // pinned_host matters. So should there be a different lowering for + // unpinned_host? + if (memory_kind == "unpinned_host" || memory_kind == "pinned_host") { + return xla::Layout::kHostMemorySpace; + } else if (memory_kind == "device") { + return xla::Layout::kDefaultMemorySpace; + } else { + return InvalidArgument("Unknown memory kind %s", memory_kind); + } +} + +// Helper method that takes an ArrayAttr of DictionaryAttrs for each arg or +// result of a function, and looks for "mhlo.layout_mode". `all_attrs` can be +// nullptr. `num_values` is the number of arguments or results. +static absl::StatusOr> MlirAttrsToMemoryKinds( + mlir::ArrayAttr all_attrs, size_t num_values) { + if (all_attrs == nullptr) { + return std::vector(num_values, + xla::Layout::kDefaultMemorySpace); + } + if (all_attrs.size() != num_values) { + return InvalidArgument( + "MlirAttrsToMemoryKinds got unexpected number of attributes: %d, " + "expected: %d", + all_attrs.size(), num_values); + } + + std::vector result; + result.reserve(all_attrs.size()); + for (const mlir::Attribute& dict_attr : all_attrs) { + mlir::StringAttr attr = + dict_attr.cast().getAs( + "mhlo.memory_kind"); + if (attr != nullptr) { + TF_ASSIGN_OR_RETURN(MemorySpaceColor memory_space, + GetMemorySpaceColor(attr.getValue().str())); + result.emplace_back(memory_space); + } else { + result.emplace_back(xla::Layout::kDefaultMemorySpace); + } + } + return result; +} + // Helper function for getting default LayoutModes for tupled arguments or // outputs. Returns nullopt if the arguments/outputs are not tupled. Raises an // error if layout modes are requested on tupled values. @@ -231,6 +280,33 @@ static StatusOr>> GetTupleLayoutModes( return std::vector(types[0].cast().size()); } +// Helper function for getting default LayoutModes for tupled arguments or +// outputs. Returns nullopt if the arguments/outputs are not tupled. Raises an +// error if layout modes are requested on tupled values. +static absl::StatusOr>> +GetTupleMemoryKinds(mlir::ArrayRef types, + mlir::ArrayAttr all_attrs) { + if (types.size() != 1 || !llvm::isa(types[0])) { + return std::nullopt; + } + if (all_attrs != nullptr) { + if (all_attrs.size() != 1) { + return InvalidArgument( + "GetTupleMemoryKinds expected single tuple attr, got %d attrs", + all_attrs.size()); + } + mlir::StringAttr attr = + all_attrs.begin()->cast().getAs( + "mhlo.memory_kind"); + if (attr != nullptr) { + return Unimplemented("mhlo.memory_kind not supported with tupled values"); + } + } + // Use default layout for all outputs. + return std::vector(types[0].cast().size(), + xla::Layout::kDefaultMemorySpace); +} + StatusOr> GetArgLayoutModes(mlir::ModuleOp module) { mlir::func::FuncOp main = module.lookupSymbol("main"); if (main == nullptr) { @@ -263,11 +339,47 @@ StatusOr> GetOutputLayoutModes(mlir::ModuleOp module) { return MlirAttrsToLayoutModes(main.getAllResultAttrs(), main.getNumResults()); } +absl::StatusOr> GetArgMemoryKinds( + mlir::ModuleOp module) { + mlir::func::FuncOp main = module.lookupSymbol("main"); + if (main == nullptr) { + return InvalidArgument( + "GetArgMemoryKinds passed module without main function"); + } + + // Special case: tupled arguments + TF_ASSIGN_OR_RETURN( + std::optional> maybe_tuple_result, + GetTupleMemoryKinds(main.getFunctionType().getInputs(), + main.getAllArgAttrs())); + if (maybe_tuple_result) return *maybe_tuple_result; + + return MlirAttrsToMemoryKinds(main.getAllArgAttrs(), main.getNumArguments()); +} + +absl::StatusOr> GetOutputMemoryKinds( + mlir::ModuleOp module) { + mlir::func::FuncOp main = module.lookupSymbol("main"); + if (main == nullptr) { + return InvalidArgument( + "GetOutputMemoryKinds passed module without main function"); + } + + // Special case: tupled outputs + TF_ASSIGN_OR_RETURN( + std::optional> maybe_tuple_result, + GetTupleMemoryKinds(main.getFunctionType().getResults(), + main.getAllResultAttrs())); + if (maybe_tuple_result) return *maybe_tuple_result; + + return MlirAttrsToMemoryKinds(main.getAllResultAttrs(), main.getNumResults()); +} + // Make sure to choose delimiter that will never show up in Layout strings. -static const char* kLayoutModeDelimiter = ";"; +static const char* kDelimiter = ";"; static std::string GetFrontendAttr(absl::Span layout_modes) { - return absl::StrJoin(layout_modes, kLayoutModeDelimiter, + return absl::StrJoin(layout_modes, kDelimiter, [](std::string* out, const LayoutMode& mode) { absl::StrAppend(out, mode.ToString()); }); @@ -290,11 +402,39 @@ Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, return OkStatus(); } +static std::string GetFrontendAttrForMemorySpace( + const std::vector& memory_spaces) { + return absl::StrJoin( + memory_spaces, kDelimiter, + [](std::string* out, const MemorySpaceColor memory_kind) { + absl::StrAppend(out, memory_kind); + }); +} + +Status AddMemoryKindsToFrontendAttrs(mlir::ModuleOp module, + XlaComputation& xla_computation) { + TF_ASSIGN_OR_RETURN(std::vector arg_memory_spaces, + GetArgMemoryKinds(module)); + TF_ASSIGN_OR_RETURN(std::vector out_memory_spaces, + GetOutputMemoryKinds(module)); + + // Type is string->string proto map. Using auto here to deal with different + // build environments. + auto& frontend_attrs = *xla_computation.mutable_proto() + ->mutable_frontend_attributes() + ->mutable_map(); + frontend_attrs["arg_memory_spaces"] = + GetFrontendAttrForMemorySpace(arg_memory_spaces); + frontend_attrs["out_memory_spaces"] = + GetFrontendAttrForMemorySpace(out_memory_spaces); + return OkStatus(); +} + static StatusOr> GetLayoutModesFromFrontendAttr( absl::string_view attr) { // SkipEmpty() needed to avoid returning the empty string when attr is empty. std::vector str_modes = - absl::StrSplit(attr, kLayoutModeDelimiter, absl::SkipEmpty()); + absl::StrSplit(attr, kDelimiter, absl::SkipEmpty()); std::vector result; for (const std::string& str_mode : str_modes) { TF_ASSIGN_OR_RETURN(LayoutMode mode, LayoutMode::FromString(str_mode)); @@ -315,6 +455,35 @@ static StatusOr> GetLayoutModes( return GetLayoutModesFromFrontendAttr(iter->second); } +static StatusOr> GetMemoryKindsFromFrontendAttr( + absl::string_view attr) { + // SkipEmpty() needed to avoid returning the empty string when attr is empty. + std::vector str_memory_spaces = + absl::StrSplit(attr, kDelimiter, absl::SkipEmpty()); + + std::vector result; + result.reserve(str_memory_spaces.size()); + for (const std::string& str_mem_space : str_memory_spaces) { + MemorySpaceColor memory_space; + CHECK(absl::SimpleAtoi(str_mem_space, &memory_space)); + result.emplace_back(memory_space); + } + return result; +} + +static StatusOr> GetMemoryKinds( + const XlaComputation& computation, absl::string_view frontend_attr_name, + size_t num_values) { + const auto& frontend_attrs = computation.proto().frontend_attributes().map(); + auto iter = frontend_attrs.find(frontend_attr_name); + if (iter == frontend_attrs.end()) { + // Return all default memory space i.e. 0 if frontend attr isn't present. + return std::vector(num_values, + xla::Layout::kDefaultMemorySpace); + } + return GetMemoryKindsFromFrontendAttr(iter->second); +} + StatusOr> GetArgLayoutModes( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, @@ -326,6 +495,17 @@ StatusOr> GetArgLayoutModes( return GetLayoutModes(computation, "arg_layout_modes", num_args); } +StatusOr> GetArgMemoryKinds( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + size_t num_args = program_shape.parameters_size() == 1 && + program_shape.parameters(0).IsTuple() + ? program_shape.parameters(0).tuple_shapes_size() + : program_shape.parameters_size(); + return GetMemoryKinds(computation, "arg_memory_spaces", num_args); +} + StatusOr> GetOutputLayoutModes( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, @@ -336,6 +516,16 @@ StatusOr> GetOutputLayoutModes( return GetLayoutModes(computation, "out_layout_modes", num_outputs); } +StatusOr> GetOutputMemoryKinds( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + size_t num_outputs = program_shape.result().IsTuple() + ? program_shape.result().tuple_shapes_size() + : 1; + return GetMemoryKinds(computation, "out_memory_spaces", num_outputs); +} + static StatusOr LayoutModeToXlaShape( const LayoutMode& layout_mode, const Shape& unsharded_shape, const Shape& sharded_shape, @@ -378,6 +568,8 @@ static StatusOr LayoutModeToXlaShape( StatusOr, Shape>> LayoutModesToXlaShapes( const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, + const std::vector& arg_memory_spaces, + const std::vector& out_memory_spaces, std::function(Shape)> choose_compact_layout_for_shape_function) { // Compute sharded argument and output shapes. @@ -431,6 +623,10 @@ StatusOr, Shape>> LayoutModesToXlaShapes( // Convert each LayoutMode to an xla::Shape with the appropriate Layout set or // unset. + if (arg_memory_spaces.size() != arg_layout_modes.size()) { + return InvalidArgument( + "The sizes of arg_memory_spaces and arg_layout_modes don't match"); + } std::vector flat_arg_layouts; flat_arg_layouts.reserve(arg_layout_modes.size()); for (int i = 0; i < arg_layout_modes.size(); ++i) { @@ -439,8 +635,17 @@ StatusOr, Shape>> LayoutModesToXlaShapes( LayoutModeToXlaShape(arg_layout_modes[i], unsharded_arg_shapes[i], sharded_arg_shapes[i], choose_compact_layout_for_shape_function)); + // When layout is AUTO, memory space can't be set since it will be partial. + if (layout.has_layout()) { + layout.mutable_layout()->set_memory_space(arg_memory_spaces[i]); + } flat_arg_layouts.emplace_back(std::move(layout)); } + + if (out_memory_spaces.size() != out_layout_modes.size()) { + return InvalidArgument( + "The sizes of out_memory_spaces and out_layout_modes don't match"); + } std::vector flat_out_layouts; flat_out_layouts.reserve(out_layout_modes.size()); for (int i = 0; i < out_layout_modes.size(); ++i) { @@ -449,6 +654,10 @@ StatusOr, Shape>> LayoutModesToXlaShapes( LayoutModeToXlaShape(out_layout_modes[i], unsharded_out_shapes[i], sharded_out_shapes[i], choose_compact_layout_for_shape_function)); + // When layout is AUTO, memory space can't be set since it will be partial. + if (layout.has_layout()) { + layout.mutable_layout()->set_memory_space(out_memory_spaces[i]); + } flat_out_layouts.emplace_back(std::move(layout)); } @@ -468,12 +677,15 @@ StatusOr, std::vector>> LayoutModesToXla(const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, + const std::vector& arg_memory_spaces, + const std::vector& out_memory_spaces, std::function(Shape)> choose_compact_layout_for_shape_function, ExecutableBuildOptions& build_options) { TF_ASSIGN_OR_RETURN( auto pair, LayoutModesToXlaShapes(computation, arg_layout_modes, out_layout_modes, + arg_memory_spaces, out_memory_spaces, choose_compact_layout_for_shape_function)); std::vector& arg_layouts = pair.first; Shape& out_layout = pair.second; @@ -576,7 +788,7 @@ StatusOr> ComputeParametersThatMustBeDonated( const HloInputOutputAliasConfig& config = module.input_output_alias_config(); TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { + const HloInputOutputAliasConfig::Alias& alias) -> absl::Status { if (tuple_inputs) { if (alias.parameter_number != 0) { return InvalidArgument( diff --git a/xla/pjrt/utils.h b/xla/pjrt/utils.h index 7c423afc6dafb..3134b49702eee 100644 --- a/xla/pjrt/utils.h +++ b/xla/pjrt/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,6 +37,8 @@ limitations under the License. namespace xla { +using MemorySpaceColor = int; + // Returns the num_replicas, num_partitions and device assignment given a // ExecutableBuildOptions and whether we want a portable executable. Status ParseDeviceAssignmentCompileOptions( @@ -55,11 +57,27 @@ StatusOr> GetArgLayoutModes(mlir::ModuleOp module); // LayoutMode::Mode::kDefault. StatusOr> GetOutputLayoutModes(mlir::ModuleOp module); +// Returns the memory space for each argument of the computations. Checks +// for the "mhlo.memory_kind" frontend attribute, and if not present, assumes 0. +StatusOr> GetArgMemoryKinds( + mlir::ModuleOp module); +// Returns the memory space for each output of the computations. Checks for +// the "mhlo.memory_kind" frontend attribute, and if not present, assumes 0. +StatusOr> GetOutputMemoryKinds( + mlir::ModuleOp module); + // Populates the frontend attributes "arg_layout_mode" and "out_layout_mode" in // xla_computation based on `module`. This function must be called before the // LayoutMode getters below work correctly on `computation`. Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, XlaComputation& xla_computation); + +// Populates the frontend attributes "arg_memory_kinds" and "out_memory_kinds" +// in xla_computation based on `module`. This function must be called before the +// LayoutMode getters below work correctly on `computation`. +Status AddMemoryKindsToFrontendAttrs(mlir::ModuleOp module, + XlaComputation& xla_computation); + // Returns the LayoutMode for each argument of the computations. Checks for the // "arg_layout_mode" frontend attribute, and if not present, assumes // LayoutMode::Mode::kDefault. @@ -71,11 +89,22 @@ StatusOr> GetArgLayoutModes( StatusOr> GetOutputLayoutModes( const XlaComputation& computation); +// Returns the memory space for each argument of the computations. Checks for +// the "arg_memory_kind" frontend attribute, and if not present, assumes 0. +StatusOr> GetArgMemoryKinds( + const XlaComputation& computation); +// Returns the memory space for each argument of the computations. Checks for +// the "out_memory_kind" frontend attribute, and if not present, assumes 0. +StatusOr> GetOutputMemoryKinds( + const XlaComputation& computation); + // Returns (arg shapes, output shape) with properly-set Layouts that can // be passed to XLA to reflect arg_layout_modes and out_layout_modes. StatusOr, Shape>> LayoutModesToXlaShapes( const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, + const std::vector& arg_memory_spaces, + const std::vector& out_memory_spaces, std::function(Shape)> choose_compact_layout_for_shape_function); @@ -87,6 +116,8 @@ StatusOr, std::vector>> LayoutModesToXla(const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, + const std::vector& arg_memory_spaces, + const std::vector& out_memory_spaces, std::function(Shape)> choose_compact_layout_for_shape_function, ExecutableBuildOptions& build_options); diff --git a/xla/pjrt/worker_thread.cc b/xla/pjrt/worker_thread.cc index ca37c538752f2..3e51b5f4c3ea9 100644 --- a/xla/pjrt/worker_thread.cc +++ b/xla/pjrt/worker_thread.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/pjrt/worker_thread.h b/xla/pjrt/worker_thread.h index 0782e7ed0da0a..bf76c46389c5d 100644 --- a/xla/pjrt/worker_thread.h +++ b/xla/pjrt/worker_thread.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index b0b7d751f40a5..75f263ced5993 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "xla/statusor.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -31,12 +33,9 @@ namespace xla { namespace primitive_util { int SignificandWidth(PrimitiveType type) { - return PrimitiveTypeSwitch( + return FloatingPointTypeSwitch( [&](auto constant_type) -> int { - if constexpr (IsFloatingPointType(constant_type)) { - return std::numeric_limits>::digits; - } - LOG(FATAL) << "Not a floating data type " << type; + return std::numeric_limits>::digits; }, type); } @@ -60,12 +59,9 @@ int UnderflowExponent(PrimitiveType type) { // normalized floating-point number." as such it does not actually yield the // minimum exponent but one above the minimum exponent that a normalized // number can have. - return PrimitiveTypeSwitch( + return FloatingPointTypeSwitch( [&](auto constant_type) -> int { - if constexpr (IsFloatingPointType(constant_type)) { - return std::numeric_limits>::min_exponent; - } - LOG(FATAL) << "Not a floating data type " << type; + return std::numeric_limits>::min_exponent; }, type); } @@ -76,12 +72,9 @@ int OverflowExponent(PrimitiveType type) { // representable finite floating-point number." as such it does not actually // yield the maximum exponent but the exponent of the first integer which // overflows. - return PrimitiveTypeSwitch( + return FloatingPointTypeSwitch( [&](auto constant_type) -> int { - if constexpr (IsFloatingPointType(constant_type)) { - return std::numeric_limits>::max_exponent; - } - LOG(FATAL) << "Not a floating data type " << type; + return std::numeric_limits>::max_exponent; }, type); } @@ -91,14 +84,25 @@ int ExponentBias(PrimitiveType type) { } bool HasInfinity(PrimitiveType type) { - return PrimitiveTypeSwitch( - [&](auto constant_type) -> bool { - if constexpr (IsFloatingPointType(constant_type)) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { return std::numeric_limits>::has_infinity; - } - return false; - }, - type); + }, + type); + } + return false; +} + +bool HasNegativeZero(PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { + return has_negative_zero_v>; + }, + type); + } + return false; } xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) { @@ -174,7 +178,7 @@ GetPrimitiveTypeStringMap() { } // namespace -StatusOr StringToPrimitiveType(absl::string_view name) { +absl::StatusOr StringToPrimitiveType(absl::string_view name) { const auto& map = GetPrimitiveTypeStringMap(); auto found = map.find(name); if (found == map.end()) { diff --git a/xla/primitive_util.h b/xla/primitive_util.h index 63fa4e19359f4..33309d2ee9148 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,8 +34,8 @@ limitations under the License. #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/ml_dtypes.h" namespace xla { namespace primitive_util { @@ -68,6 +68,9 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); +// Returns whether the type has a value for negative zero. +bool HasNegativeZero(PrimitiveType type); + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template @@ -324,67 +327,11 @@ template using PrimitiveTypeConstant = std::integral_constant; -template -constexpr R PrimitiveTypeSwitch(F&& f, PrimitiveType type) { - switch (type) { - case PRED: - return std::forward(f)(PrimitiveTypeConstant()); - case S4: - return std::forward(f)(PrimitiveTypeConstant()); - case S8: - return std::forward(f)(PrimitiveTypeConstant()); - case S16: - return std::forward(f)(PrimitiveTypeConstant()); - case S32: - return std::forward(f)(PrimitiveTypeConstant()); - case S64: - return std::forward(f)(PrimitiveTypeConstant()); - case U4: - return std::forward(f)(PrimitiveTypeConstant()); - case U8: - return std::forward(f)(PrimitiveTypeConstant()); - case U16: - return std::forward(f)(PrimitiveTypeConstant()); - case U32: - return std::forward(f)(PrimitiveTypeConstant()); - case U64: - return std::forward(f)(PrimitiveTypeConstant()); - case F8E4M3FN: - return std::forward(f)( - PrimitiveTypeConstant()); - case F8E4M3B11FNUZ: - return std::forward(f)( - PrimitiveTypeConstant()); - case F8E4M3FNUZ: - return std::forward(f)( - PrimitiveTypeConstant()); - case F8E5M2: - return std::forward(f)(PrimitiveTypeConstant()); - case F8E5M2FNUZ: - return std::forward(f)( - PrimitiveTypeConstant()); - case F16: - return std::forward(f)(PrimitiveTypeConstant()); - case BF16: - return std::forward(f)(PrimitiveTypeConstant()); - case F32: - return std::forward(f)(PrimitiveTypeConstant()); - case F64: - return std::forward(f)(PrimitiveTypeConstant()); - case C64: - return std::forward(f)(PrimitiveTypeConstant()); - case C128: - return std::forward(f)(PrimitiveTypeConstant()); - case TUPLE: - return std::forward(f)(PrimitiveTypeConstant()); - case OPAQUE_TYPE: - return std::forward(f)( - PrimitiveTypeConstant()); - case TOKEN: - return std::forward(f)(PrimitiveTypeConstant()); - default: - LOG(FATAL) << "unhandled type " << type; - } +// Returns true if values of the given primitive type are held in array shapes. +inline constexpr bool IsArrayType(PrimitiveType primitive_type) { + return primitive_type != TUPLE && primitive_type != OPAQUE_TYPE && + primitive_type != TOKEN && primitive_type > PRIMITIVE_TYPE_INVALID && + primitive_type < PrimitiveType_ARRAYSIZE; } constexpr bool IsF8Type(PrimitiveType type) { @@ -417,11 +364,121 @@ constexpr bool Is4BitType(PrimitiveType type) { return type == S4 || type == U4; } -// Returns true if values of the given primitive type are held in array shapes. -inline constexpr bool IsArrayType(PrimitiveType primitive_type) { - return primitive_type > PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && - primitive_type != OPAQUE_TYPE && primitive_type != TOKEN && - primitive_type < PrimitiveType_ARRAYSIZE; +template +constexpr R IntegralTypeSwitch(F&& f, PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsIntegralType(type))) { + switch (type) { + case S4: + return std::forward(f)(PrimitiveTypeConstant()); + case S8: + return std::forward(f)(PrimitiveTypeConstant()); + case S16: + return std::forward(f)(PrimitiveTypeConstant()); + case S32: + return std::forward(f)(PrimitiveTypeConstant()); + case S64: + return std::forward(f)(PrimitiveTypeConstant()); + case U4: + return std::forward(f)(PrimitiveTypeConstant()); + case U8: + return std::forward(f)(PrimitiveTypeConstant()); + case U16: + return std::forward(f)(PrimitiveTypeConstant()); + case U32: + return std::forward(f)(PrimitiveTypeConstant()); + case U64: + return std::forward(f)(PrimitiveTypeConstant()); + default: + ABSL_UNREACHABLE(); + } + } + LOG(FATAL) << "Not an integral data type " << type; +} + +template +constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + switch (type) { + case F8E4M3FN: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E4M3B11FNUZ: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E4M3FNUZ: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E5M2: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E5M2FNUZ: + return std::forward(f)( + PrimitiveTypeConstant()); + case F16: + return std::forward(f)(PrimitiveTypeConstant()); + case BF16: + return std::forward(f)(PrimitiveTypeConstant()); + case F32: + return std::forward(f)(PrimitiveTypeConstant()); + case F64: + return std::forward(f)(PrimitiveTypeConstant()); + default: + ABSL_UNREACHABLE(); + } + } + LOG(FATAL) << "Not a floating point data type " << type; +} + +template +constexpr R ComplexTypeSwitch(F&& f, PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsComplexType(type))) { + switch (type) { + case C64: + return std::forward(f)(PrimitiveTypeConstant()); + case C128: + return std::forward(f)(PrimitiveTypeConstant()); + default: + ABSL_UNREACHABLE(); + } + } + LOG(FATAL) << "Not a complex data type " << type; +} + +template +constexpr R ArrayTypeSwitch(F&& f, PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsArrayType(type))) { + if (IsFloatingPointType(type)) { + return FloatingPointTypeSwitch(std::forward(f), type); + } + if (IsIntegralType(type)) { + return IntegralTypeSwitch(std::forward(f), type); + } + if (IsComplexType(type)) { + return ComplexTypeSwitch(std::forward(f), type); + } + if (type == PRED) { + return std::forward(f)(PrimitiveTypeConstant()); + } + } + LOG(FATAL) << "Not an array data type " << type; +} + +template +constexpr R PrimitiveTypeSwitch(F&& f, PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsArrayType(type))) { + return ArrayTypeSwitch(std::forward(f), type); + } + if (type == TUPLE) { + return std::forward(f)(PrimitiveTypeConstant()); + } + if (type == TOKEN) { + return std::forward(f)(PrimitiveTypeConstant()); + } + if (type == OPAQUE_TYPE) { + return std::forward(f)( + PrimitiveTypeConstant()); + } + LOG(FATAL) << "unhandled type " << type; } namespace internal { @@ -470,20 +527,10 @@ inline constexpr auto kByteWidths = ByteWidthArrayHelper( template & kWidths> inline constexpr int WidthForType(PrimitiveType type) { - if (ABSL_PREDICT_FALSE(type == TOKEN)) { - // Tokens require no space. - return 0; - } - if (ABSL_PREDICT_FALSE(type == TUPLE)) { - LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; + if (ABSL_PREDICT_TRUE(IsArrayType(type))) { + return kWidths[type]; } - if (ABSL_PREDICT_FALSE(type == OPAQUE_TYPE)) { - LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth"; - } - if (ABSL_PREDICT_FALSE(!IsArrayType(type))) { - LOG(FATAL) << "Unhandled primitive type " << type; - } - return kWidths[type]; + LOG(FATAL) << "Unhandled primitive type " << type; } } // namespace internal @@ -661,7 +708,7 @@ const std::string& LowercasePrimitiveTypeName(PrimitiveType s); // Returns the PrimitiveType matching the given name. The given name is expected // to be lower-case. -StatusOr StringToPrimitiveType(absl::string_view name); +absl::StatusOr StringToPrimitiveType(absl::string_view name); // Returns true if the given name is a primitive type string (lower-case). bool IsPrimitiveTypeName(absl::string_view name); @@ -698,14 +745,11 @@ bool IsCanonicalRepresentation(PrimitiveType type) { } inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) { - return primitive_util::PrimitiveTypeSwitch( + return primitive_util::IntegralTypeSwitch( [&](auto primitive_type) -> bool { - if constexpr (primitive_util::IsIntegralType(primitive_type)) { - using NativeT = primitive_util::NativeTypeOf; - return std::numeric_limits::min() <= x && - std::numeric_limits::max() >= x; - } - LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty); + using NativeT = primitive_util::NativeTypeOf; + return std::numeric_limits::min() <= x && + std::numeric_limits::max() >= x; }, ty); } diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index 8f9a67ff37556..c61f73438e8a5 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/printer.cc b/xla/printer.cc index ed7c77ee3d275..7aae58aa22f7f 100644 --- a/xla/printer.cc +++ b/xla/printer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/printer.h b/xla/printer.h index 2dd402fdbdc4c..cef4658f0db05 100644 --- a/xla/printer.h +++ b/xla/printer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/protobuf_util.cc b/xla/protobuf_util.cc index f9b40c1b8f4e1..022cbb745baac 100644 --- a/xla/protobuf_util.cc +++ b/xla/protobuf_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/protobuf_util.h b/xla/protobuf_util.h index 70ab3c9bf87f8..e739353cf826b 100644 --- a/xla/protobuf_util.h +++ b/xla/protobuf_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/BUILD b/xla/python/BUILD index 27d0882db3d71..e7e7f80b80f55 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -1,16 +1,11 @@ -load("//xla:pytype.default.bzl", "pytype_strict_library") -load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load( - "//xla:xla.bzl", - "xla_cc_test", - "xla_py_test_deps", -) load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( "@tsl//tsl:tsl.bzl", "if_cuda_or_rocm", + "if_google", + "internal_visibility", ) load("@tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") load("@tsl//tsl/platform:build_config.bzl", "pyx_library", "tf_proto_library") @@ -19,10 +14,19 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:pytype.default.bzl", "pytype_strict_library") +load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") +load( + "//xla:xla.bzl", + "xla_cc_test", + "xla_py_test_deps", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([ + ":friends", + ]), licenses = ["notice"], ) @@ -37,15 +41,19 @@ package_group( pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ - ":xla_extension", + ":xla_extension", # buildcleaner: keep "@ml_dtypes", ], ) -exports_files(["xla_client.pyi"]) +exports_files([ + "xla_client.py", + "xla_client.pyi", +]) pyx_library( name = "custom_call_for_test", @@ -125,10 +133,9 @@ py_strict_test( python_version = "PY3", srcs_version = "PY3", tags = [ - "config-cuda-only", "no_oss", "requires-gpu-nvidia", - ], # TODO(phawkins): This test passes, but requires --config=monolithic. + ] + if_google(["config-cuda-only"]), # TODO(phawkins): This test passes, but requires --config=monolithic. deps = [ ":xla_client", ":xla_extension", @@ -139,32 +146,16 @@ py_strict_test( ] + xla_py_test_deps(), ) -cc_library( - name = "status_casters", - hdrs = ["status_casters.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":exceptions", - "//xla:status", - "//xla:statusor", - "@pybind11", - "@tsl//tsl/platform:macros", - ], -) - tsl_pybind_extension( name = "status_casters_ext", srcs = ["status_casters_ext.cc"], visibility = ["//visibility:private"], deps = [ - ":exceptions", - ":status_casters", - "@pybind11", + "//xla/pjrt:exceptions", + "//xla/pjrt:status_casters", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@nanobind", ], ) @@ -181,21 +172,6 @@ py_strict_test( ] + xla_py_test_deps(), ) -cc_library( - name = "exceptions", - hdrs = ["exceptions.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "//xla:status", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "types", srcs = ["types.cc"], @@ -208,20 +184,28 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], deps = [ - ":exceptions", + ":nb_helpers", + ":nb_numpy", "//xla:literal", "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:statusor", - "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/pjrt:exceptions", "//xla/python/ifrt", + "//xla/python/pjrt_ifrt", + "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:absl_casters", - "@tsl//tsl/platform:protobuf", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_config_python//:python_headers", # buildcleaner: keep + "@nanobind", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -240,23 +224,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@pybind11", - ], -) - -cc_library( - name = "python_utils", - hdrs = ["python_utils.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "//xla:status_macros", - "//xla:util", "@local_config_python//:python_headers", # buildcleaner: keep - "@pybind11", + "@nanobind", ], ) @@ -272,15 +241,18 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], deps = [ - ":exceptions", + ":nb_class_ptr", ":python_ref_manager", # placeholder for index annotation deps "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_python//:python_headers", # buildcleaner: keep + "//xla/pjrt:exceptions", + "@tsl//tsl/platform", "@tsl//tsl/platform:logging", - "@pybind11", + "@nanobind", ], ) @@ -295,11 +267,12 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":traceback", - "//xla:statusor", "//xla:util", "@com_google_absl//absl/container:flat_hash_map", - "@pybind11", + "@com_google_absl//absl/status:statusor", + "@local_config_python//:python_headers", # buildcleaner: keep + "@nanobind", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:protobuf", "@tsl//tsl/profiler/protobuf:profile_proto_cc", ], @@ -309,23 +282,27 @@ cc_library( name = "py_client", srcs = [ "py_array.cc", - "py_buffer.cc", "py_client.cc", "py_compile_only_client.cc", + "py_device.cc", "py_device_list.cc", "py_executable.cc", "py_host_callback.cc", + "py_memory_space.cc", + "py_program.cc", "py_values.cc", "sharding.cc", ], hdrs = [ "py_array.h", - "py_buffer.h", "py_client.h", "py_compile_only_client.h", + "py_device.h", "py_device_list.h", "py_executable.h", "py_host_callback.h", + "py_memory_space.h", + "py_program.h", "py_values.h", "sharded_device_array.h", "sharding.h", @@ -341,13 +318,14 @@ cc_library( features = ["-use_header_modules"], deps = [ ":callback", - ":exceptions", + ":nb_absl_span", + ":nb_class_ptr", + ":nb_helpers", + ":nb_numpy", ":pprof_profile_builder", ":py_client_gpu", ":py_host_callback_proto_cc", ":python_ref_manager", - ":python_utils", - ":status_casters", ":traceback", ":transfer_guard_lib", ":types", @@ -356,48 +334,71 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@local_config_python//:python_headers", # buildcleaner: keep "//xla:comparison_util", + "//xla:literal", "//xla:shape_util", + "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/hlo/ir:hlo", + "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", "//xla/pjrt:lru_cache", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:pjrt_stream_executor_client", + "//xla/pjrt:status_casters", "//xla/pjrt:transpose", "//xla/python/ifrt", + "//xla/python/ifrt:plugin_program", + "//xla/python/ifrt:plugin_program_serdes", "//xla/python/pjrt_ifrt", "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:computation_placer_hdr", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", + "//xla/tsl/python/lib/core:numpy", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/framework:allocator", + "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:fingerprint", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/python/lib/core:numpy", "@com_google_protobuf//:protobuf", "@llvm-project//llvm:Support", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:absl_casters", + "@llvm-project//mlir:IR", + "@nanobind", ] + if_cuda([ "@local_config_cuda//cuda:cuda_headers", + "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm([ "@local_config_rocm//rocm:rocm_headers", ]), @@ -411,21 +412,29 @@ cc_library( hdrs = [ "callback.h", ], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], deps = [ + ":nb_numpy", ":python_ref_manager", "//xla:comparison_util", "//xla:xla_data_proto_cc", + "//xla/pjrt:host_callback", "//xla/pjrt:transpose", "//xla/service:custom_call_status", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@pybind11", + "@local_config_python//:python_headers", # buildcleaner: keep + "@nanobind", "@tsl//tsl/platform:statusor", ], ) @@ -438,6 +447,7 @@ cc_library( hdrs = if_cuda_or_rocm([ "py_client_gpu.h", ]), + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -448,12 +458,14 @@ cc_library( features = ["-use_header_modules"], deps = [ ":callback", - ":exceptions", + ":nb_numpy", "//xla:comparison_util", + "//xla/pjrt:exceptions", + "//xla/pjrt:host_callback", "//xla/service:custom_call_status", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", - "@pybind11", + "@nanobind", "@tsl//tsl/platform:errors", ] + if_cuda([ "@local_config_cuda//cuda:cuda_headers", @@ -473,21 +485,33 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", ":python_ref_manager", ":traceback", + ":types", ":util", - "//xla:types", + "//xla:shape_util", + "//xla:status_macros", "//xla:util", + "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/pjrt_ifrt", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@dlpack", + "@llvm-project//llvm:Support", "@local_config_python//:python_headers", # buildcleaner: keep - "@pybind11", + "@nanobind", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -503,33 +527,25 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], # For the functions to access C++ flags/thread-local variables deps = [ - ":exceptions", + ":nb_helpers", ":py_client", ":python_ref_manager", - ":python_utils", ":pytree", - ":status_casters", ":types", - ":util", # placeholder for index annotation deps - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_python//:python_headers", # build_cleaner: keep - "//xla:shape_util", - "//xla:statusor", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/pjrt:lru_cache", "//xla/pjrt:pjrt_client", - "//xla/python/ifrt", - "@tsl//tsl/platform:status", + "//xla/pjrt:status_casters", + "@tsl//tsl/platform:logging", "@tsl//tsl/profiler/lib:traceme", - "@pybind11", + "@nanobind", ], ) @@ -546,6 +562,31 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "custom_partition_callback", + srcs = ["custom_partition_callback.cc"], + hdrs = ["custom_partition_callback.h"], + deps = [ + "//xla:debug_options_flags", + "//xla:util", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "//xla/pjrt/c:pjrt_c_api_hdrs", + "//xla/pjrt/c:pjrt_c_api_helpers", + "//xla/service:call_inliner", + "//xla/service:custom_call_sharding_helper", + "//xla/service:hlo_pass_pipeline", + "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "custom_call_sharding", srcs = ["custom_call_sharding.cc"], @@ -558,18 +599,23 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ + ":custom_partition_callback", ":inspect_sharding", - ":status_casters", # placeholder for index annotation deps - "//xla/client:xla_computation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "//xla:shape_util", + "//xla:status", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", - "//xla/service:call_inliner", - "//xla/service:custom_call_sharding_helper", - "//xla/service:hlo_pass_pipeline", - "//xla/service/spmd:spmd_partitioner", - "@tsl//tsl/platform:errors", - "@pybind11", + "//xla/pjrt:status_casters", + "//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "//xla/pjrt/c:pjrt_c_api_hdrs", + "//xla/pjrt/c:pjrt_c_api_helpers", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + "@nanobind", ], ) @@ -584,7 +630,8 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":status_casters", + ":nb_absl_span", + ":nb_helpers", ":types", # placeholder for index annotation deps "@com_google_absl//absl/types:span", @@ -600,7 +647,8 @@ cc_library( "//xla/client/lib:self_adjoint_eig", "//xla/client/lib:sorting", "//xla/client/lib:svd", - "@pybind11", + "//xla/pjrt:status_casters", + "@nanobind", ], ) @@ -638,20 +686,36 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":jax_jit", + ":nb_helpers", + ":nb_numpy", ":py_client", - ":python_utils", + ":python_ref_manager", ":pytree", - ":status_casters", + ":traceback", ":transfer_guard_lib", - ":util", # placeholder for index annotation deps + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_config_python//:python_headers", # buildcleaner: keep + "//xla:util", + "//xla/pjrt:exceptions", "//xla/pjrt:lru_cache", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_future", "//xla/python/ifrt", - "//xla/python/pjrt_ifrt", + "@tsl//tsl/concurrency:ref_count", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", - "@pybind11", + "@nanobind", ], ) @@ -667,30 +731,39 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ - ":exceptions", ":jax_jit", + ":nb_class_ptr", + ":nb_helpers", + ":nb_numpy", ":py_client", - ":python_utils", + ":python_ref_manager", ":pytree", - ":status_casters", + ":traceback", ":types", - ":util", # placeholder for index annotation deps + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", + "@local_config_python//:python_headers", # buildcleaner: keep + "//xla:status_macros", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_client", + "//xla/pjrt:status_casters", "//xla/python/ifrt", - "//xla/python/pjrt_ifrt", + "//xla/tsl/python/lib/core:numpy", + "@tsl//tsl/concurrency:ref_count", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:absl_casters", + "@nanobind", ], ) @@ -724,17 +797,22 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":outfeed_receiver", ":py_client", - ":status_casters", ":types", # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "//xla:literal", "//xla/client:executable_build_options", "//xla/client:xla_builder", "//xla/pjrt:pjrt_client", - "@pybind11", + "//xla/pjrt:status_casters", + "@tsl//tsl/platform:logging", + "@nanobind", ], ) @@ -772,7 +850,7 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], deps = [ - ":exceptions", + ":nb_class_ptr", ":pytree_proto_cc", # placeholder for index annotation deps "@com_google_absl//absl/algorithm:container", @@ -781,9 +859,11 @@ cc_library( "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_config_python//:python_headers", # buildcleaner: keep + "//xla/pjrt:exceptions", "@tsl//tsl/platform:logging", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:absl_casters", + "@nanobind", ], ) @@ -799,19 +879,23 @@ cc_library( features = ["-use_header_modules"], deps = [ ":refine_polymorphic_shapes", - ":status_casters", - ":types", # placeholder for index annotation deps + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "//xla:status", "//xla/client:xla_computation", "//xla/mlir/utils:error_util", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:status_casters", "//xla/service/llvm_ir:llvm_util", "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", @@ -819,7 +903,8 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SparseTensorDialect", - "@pybind11", + "@llvm-project//mlir:Support", + "@nanobind", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -834,6 +919,7 @@ cc_library( "//xla/mlir/utils:error_util", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", @@ -860,8 +946,6 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":exceptions", - ":status_casters", ":types", ":xplane_to_profile_instructions", # placeholder for index annotation deps @@ -871,16 +955,20 @@ cc_library( "//xla/backends/profiler/cpu:python_tracer", "//xla/backends/profiler/plugin:plugin_tracer", "//xla/backends/profiler/plugin:profiler_c_api_hdrs", + "//xla/pjrt:exceptions", + "//xla/pjrt:status_casters", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", - "//xla/python/profiler/internal:traceme_wrapper", + "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/profiler/lib:profiler_factory", "@tsl//tsl/profiler/lib:profiler_interface", "@tsl//tsl/profiler/lib:profiler_session", + "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/rpc:profiler_server_impl", "@tsl//tsl/profiler/rpc/client:capture_profile", "@tsl//tsl/profiler/rpc/client:profiler_client_impl", - "@pybind11", + "@nanobind", ] + select({ ":gpu_enabled": [ "//xla/backends/profiler/gpu:device_tracer", @@ -901,13 +989,11 @@ cc_library( features = ["-use_header_modules"], visibility = [":friends"], deps = [ - ":status_casters", # placeholder for index annotation deps - "@com_google_absl//absl/base:core_headers", - "//xla:status", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", "//xla:util", - "@pybind11", - "@pybind11_abseil//pybind11_abseil:absl_casters", + "@nanobind", ], ) @@ -928,7 +1014,7 @@ cc_library( "//xla/pjrt:pjrt_future", "//xla/python/ifrt", "@com_google_absl//absl/strings:str_format", - "@pybind11", + "@com_google_absl//absl/types:span", ], ) @@ -944,12 +1030,15 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ + ":nb_helpers", # placeholder for index annotation deps + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "//xla/pjrt:lru_cache", - "@pybind11", + "@nanobind", ], ) @@ -964,17 +1053,21 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":exceptions", + ":nb_absl_span", + ":nb_helpers", + ":nb_numpy", ":py_client", - ":status_casters", ":types", # placeholder for index annotation deps "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "//xla:array", "//xla:debug_options_flags", + "//xla:literal", "//xla:shape_util", "//xla:statusor", "//xla:util", @@ -983,8 +1076,14 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/pjrt:exceptions", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:status_casters", "//xla/service:call_inliner", "//xla/service:computation_placer", "//xla/service:custom_call_target_registry", @@ -998,7 +1097,8 @@ cc_library( "//xla/service:name_uniquer", "//xla/service:tuple_simplifier", "@tsl//tsl/lib/strings:proto_serialization", - "@pybind11", + "@tsl//tsl/platform:logging", + "@nanobind", ], ) @@ -1053,9 +1153,27 @@ cc_library( ]), ) +cc_library( + name = "logging", + srcs = ["logging.cc"], + hdrs = ["logging.h"], + deps = [ + "@com_google_absl//absl/log:initialize", + ], +) + tsl_pybind_extension( name = "xla_extension", - srcs = ["xla_extension.cc"], + srcs = ["xla.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + defines = select({ + ":gpu_enabled": ["XLA_PYTHON_ENABLE_GPU=1"], + "//conditions:default": [], + }), + features = ["-use_header_modules"], linkopts = select({ ":use_jax_cuda_pip_rpaths": [ "-Wl,-rpath,$$ORIGIN/../nvidia/cuda_cupti/lib", @@ -1073,37 +1191,15 @@ tsl_pybind_extension( ], pytype_srcs = glob(["xla_extension/*.pyi"]), visibility = ["//visibility:public"], - deps = [ - ":xla_extension_library", - "@pybind11", - ], -) - -cc_library( - name = "xla_extension_library", - srcs = [ - "logging.cc", - "logging.h", - "xla.cc", - ], - hdrs = [ - "xla.h", - ], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - defines = select({ - ":gpu_enabled": ["XLA_PYTHON_ENABLE_GPU=1"], - "//conditions:default": [], - }), - features = ["-use_header_modules"], deps = [ ":custom_call_sharding", ":dlpack", ":jax_jit", + ":logging", ":mlir", + ":nb_absl_flat_hash_map", + ":nb_absl_span", + ":nb_class_ptr", ":ops", ":outfeed_receiver_py", ":pjit", @@ -1114,7 +1210,6 @@ cc_library( ":python_ref_manager", ":pytree", ":refine_polymorphic_shapes", - ":status_casters", ":traceback", ":transfer_guard_lib", ":types", @@ -1124,7 +1219,10 @@ cc_library( # placeholder for index annotation deps "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -1137,27 +1235,58 @@ cc_library( "//xla:statusor", "//xla:types", "//xla:util", + "//xla/ffi:ffi_api", + "//xla/pjrt:exceptions", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_api", "//xla/pjrt:pjrt_c_api_client", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_layout", + "//xla/pjrt:status_casters", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/cpu:cpu_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:protocol_proto_cc", "//xla/pjrt/distributed:service", + "//xla/pjrt/gpu:gpu_helpers", "//xla/python/ifrt", + "//xla/python/ifrt:plugin_program", + "//xla/python/ifrt:plugin_program_serdes", + "//xla/python/ifrt_proxy/client:py_module", "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service/cpu:collectives_interface", + "//xla/tsl/python/lib/core:numpy", + "@tsl//tsl/concurrency:ref_count", "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", "@tsl//tsl/platform", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform/cloud:gcs_file_system", - "@tsl//tsl/python/lib/core:numpy", - "@pybind11", + "@llvm-project//mlir:IR", + "@nanobind", ] + select({ + # gloo transport only builds on linux + "@tsl//tsl:macos": [], + "@tsl//tsl:windows": [], + "//conditions:default": [ + "//xla/pjrt/cpu:gloo_collectives", + "//xla/pjrt/cpu:gloo_kv_store", + "@gloo//:transport_tcp", + ], + }) + select({ + # mpitrampoline does not build on windows + "@tsl//tsl:windows": [], + "//conditions:default": [ + "//xla/pjrt/cpu:mpi_collectives", + ], + }) + select({ ":gpu_enabled": [ "//xla/pjrt/gpu:se_gpu_pjrt_client", ], @@ -1180,7 +1309,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1215,3 +1343,76 @@ xla_cc_test( "@tsl//tsl/profiler/utils:xplane_schema", ], ) + +cc_library( + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = ["@nanobind"], +) + +cc_library( + name = "nb_helpers", + srcs = ["nb_helpers.cc"], + hdrs = ["nb_helpers.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@local_config_python//:python_headers", + "@nanobind", + ], +) + +cc_library( + name = "nb_numpy", + srcs = ["nb_numpy.cc"], + hdrs = ["nb_numpy.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/types:span", + "@local_config_python//:python_headers", + "@nanobind", + ], +) + +cc_library( + name = "nb_absl_span", + hdrs = ["nb_absl_span.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/types:span", + "@nanobind", + ], +) + +cc_library( + name = "nb_absl_flat_hash_map", + hdrs = ["nb_absl_flat_hash_map.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@nanobind", + ], +) + +cc_library( + name = "nb_absl_flat_hash_set", + hdrs = ["nb_absl_flat_hash_set.h"], + compatible_with = [], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:flat_hash_set", + "@nanobind", + ], +) diff --git a/xla/python/callback.cc b/xla/python/callback.cc index fbba7e6841299..b3e96ce313808 100644 --- a/xla/python/callback.cc +++ b/xla/python/callback.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,54 +15,79 @@ limitations under the License. #include "xla/python/callback.h" +#include #include +#include #include #include #include #include +#include #include +#include +#include "absl/base/casts.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/python_ref_manager.h" #include "xla/service/custom_call_status.h" #include "tsl/platform/statusor.h" -namespace py = pybind11; +namespace nb = nanobind; namespace xla { -Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) { +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto& arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCallInternal(void* result, + void** arg_ptrs) { absl::Span inputs(arg_ptrs, args_.size()); absl::Span outputs(reinterpret_cast(result), results_.size()); - py::gil_scoped_acquire gil; - py::tuple args(inputs.size()); + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); for (size_t i = 0; i < inputs.size(); ++i) { if (args_[i].type == xla::TOKEN) { - args[i] = py::none(); + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); } else { - static_assert(sizeof(ssize_t) == sizeof(int64_t)); - absl::Span strides( - reinterpret_cast(args_[i].strides.data()), - args_[i].strides.size()); - args[i] = py::array(args_[i].dtype, args_[i].dims, strides, - const_cast(inputs[i])); - args[i].attr("flags").attr("writeable") = Py_False; + nb_numpy_ndarray array = + nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); } } TF_ASSIGN_OR_RETURN(auto result_tuple, CallInternal(std::move(args))); for (size_t i = 0; i < results_.size(); ++i) { - py::object output = py::reinterpret_borrow( - PyTuple_GetItem(result_tuple.ptr(), i)); - py::array array = py::cast(std::move(output)); + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); absl::Span dims( reinterpret_cast(array.shape()), array.ndim()); absl::Span strides( @@ -70,11 +95,14 @@ Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) { if (strides == results_[i].expected_strides) { std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); } else { - xla::StatusOr> plan = - transpose_cache_.GetOrCreate( - xla::primitive_util::ByteWidth(results_[i].type), dims, - results_[i].reversed_layout, - /*input_layout=*/xla::TransposePlan::Striding{strides}); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); if (!plan.ok()) { return std::move(plan).status(); } @@ -82,7 +110,7 @@ Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) { } } - return OkStatus(); + return absl::OkStatus(); } void CpuCallback::PrepareAndCall(void* result, void** arg_ptrs, @@ -95,43 +123,50 @@ void CpuCallback::PrepareAndCall(void* result, void** arg_ptrs, } } -Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) { +absl::Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) { return PrepareAndCallInternal(result, arg_ptrs); } -StatusOr CpuCallback::CallInternal(py::tuple args) { - py::object result_object; - try { - result_object = callable_(*py::reinterpret_borrow(args)); - } catch (py::error_already_set& e) { - PyErr_Clear(); +absl::StatusOr CpuCallback::CallInternal(nb::tuple args) { + auto py_error_to_status = [](nb::python_error& e) { std::string error_message = e.what(); return absl::InternalError( absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error& e) { + return py_error_to_status(e); } if (!PyTuple_Check(result_object.ptr())) { return absl::InternalError( absl::StrFormat("CPU callback expected a tuple result, got %s", - static_cast(py::repr(result_object)))); + nb::cast(nb::repr(result_object)))); } if (PyTuple_Size(result_object.ptr()) != results_.size()) { return absl::InternalError( absl::StrFormat("CPU callback expected a tuple with %d results, got %d", results_.size(), PyTuple_Size(result_object.ptr()))); } - py::tuple result_tuple = py::cast(result_object); + nb::tuple result_tuple = nb::cast(result_object); for (size_t i = 0; i < results_.size(); ++i) { - py::object output = py::reinterpret_borrow( - PyTuple_GetItem(result_tuple.ptr(), i)); + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); if (results_[i].type == xla::TOKEN) { if (!output.is_none()) { return absl::InternalError(absl::StrFormat( "Token output from Python callback should be None, got %s", - static_cast(py::repr(output)))); + nb::cast(nb::repr(output)))); } continue; } - py::array array = py::cast(std::move(output)); + nb_numpy_ndarray array; + try { + array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } static_assert(sizeof(ssize_t) == sizeof(int64_t), "Expected ssize_t to be of equal size to int64_t"); absl::Span dims( @@ -147,15 +182,15 @@ StatusOr CpuCallback::CallInternal(py::tuple args) { return result_tuple; } -StatusOr CpuCallback::Call(py::tuple args) { +absl::StatusOr CpuCallback::Call(nb::tuple args) { return CallInternal(std::move(args)); } -std::optional CpuCallback::Call(py::tuple args, +std::optional CpuCallback::Call(nb::tuple args, XlaCustomCallStatus* status) { auto statusor = CallInternal(std::move(args)); if (!statusor.ok()) { - absl::string_view msg = statusor.status().message(); + std::string_view msg = statusor.status().message(); XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); return std::nullopt; } diff --git a/xla/python/callback.h b/xla/python/callback.h index 401554825f071..c7d209db48788 100644 --- a/xla/python/callback.h +++ b/xla/python/callback.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,18 @@ limitations under the License. #ifndef XLA_PYTHON_CALLBACK_H_ #define XLA_PYTHON_CALLBACK_H_ +#include +#include #include #include #include -#include "pybind11/numpy.h" // from @pybind11 +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind #include "xla/pjrt/transpose.h" -#include "xla/python/python_ref_manager.h" +#include "xla/python/nb_numpy.h" #include "xla/service/custom_call_status.h" #include "xla/xla_data.pb.h" @@ -32,7 +37,7 @@ class CpuCallback { public: struct Arg { xla::PrimitiveType type; // XLA type - pybind11::dtype dtype; // NumPy type, for array types. + nb_dtype dtype; // NumPy type, for array types. absl::InlinedVector dims; // Dimensions, for array types. std::vector strides; // Byte strides, for array types. size_t size_in_bytes; // Size of the array in bytes. @@ -50,24 +55,14 @@ class CpuCallback { size_t size_in_bytes; }; - explicit CpuCallback(pybind11::function callable, std::vector args, + explicit CpuCallback(nanobind::callable callable, std::vector args, std::vector results) : callable_(std::move(callable)), args_(std::move(args)), results_(std::move(results)), transpose_cache_(/*capacity=*/16) {} - ~CpuCallback() { - // The destructor may be called without GIL held. In that case, we defer it - // to GlobalPyRefManager. - std::vector objects; - objects.push_back(std::move(callable_)); - for (auto& arg : args_) { - objects.push_back(std::move(arg.dtype)); - } - - GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); - } + ~CpuCallback(); const std::vector& args() const { return args_; } size_t num_args() const { return args_.size(); } @@ -79,17 +74,17 @@ class CpuCallback { void PrepareAndCall(void* result, void** arg_ptrs, XlaCustomCallStatus* status); - Status PrepareAndCall(void* result, void** arg_ptrs); + absl::Status PrepareAndCall(void* result, void** arg_ptrs); - std::optional Call(pybind11::tuple args, + std::optional Call(nanobind::tuple args, XlaCustomCallStatus* status); - StatusOr Call(pybind11::tuple args); + absl::StatusOr Call(nanobind::tuple args); private: - Status PrepareAndCallInternal(void* result, void** arg_ptrs); - StatusOr CallInternal(pybind11::tuple args); + absl::Status PrepareAndCallInternal(void* result, void** arg_ptrs); + absl::StatusOr CallInternal(nanobind::tuple args); - pybind11::function callable_; + nanobind::callable callable_; std::vector args_; std::vector results_; xla::TransposePlanCache transpose_cache_; diff --git a/xla/python/custom_call_sharding.cc b/xla/python/custom_call_sharding.cc index 9b5e5e1dfef81..f1b67ed1c1596 100644 --- a/xla/python/custom_call_sharding.cc +++ b/xla/python/custom_call_sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,221 +15,183 @@ limitations under the License. #include "xla/python/custom_call_sharding.h" #include -#include +#include #include #include #include +#include #include #include #include -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 -#include "xla/client/xla_computation.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_partition_callback.h" #include "xla/python/inspect_sharding.h" -#include "xla/python/status_casters.h" -#include "xla/service/call_inliner.h" -#include "xla/service/custom_call_sharding_helper.h" -#include "xla/service/hlo_pass_pipeline.h" -#include "xla/service/spmd/spmd_partitioner_util.h" -#include "tsl/platform/errors.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { -namespace py = ::pybind11; +namespace nb = ::nanobind; -std::vector GetArgShapes(const HloInstruction* instruction) { - std::vector result; - result.reserve(instruction->operand_count()); - for (HloInstruction* operand : instruction->operands()) { - result.push_back(operand->shape()); - } - return result; -} - -std::vector> GetArgShardings( - const HloInstruction* instruction) { - std::vector> result; - result.reserve(instruction->operand_count()); - for (HloInstruction* operand : instruction->operands()) { - if (operand->has_sharding()) { - result.push_back(operand->sharding()); - } else { - result.push_back(std::nullopt); - } +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; } - return result; -} - -HloInstruction* InlineHloComputation(HloInstruction* instruction, - HloComputation* computation, - HloComputation::Builder* builder, - std::vector operands, - std::function new_channel, - const std::string& suffix) { - HloCloneContext context(instruction->GetModule(), suffix); - absl::flat_hash_map replacements; - auto resolve = [&](HloInstruction* inst) { - auto it = replacements.find(inst); - if (it == replacements.end()) { - throw py::key_error( - absl::StrCat("Could not find mapping for: ", inst->ToString())); + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); } - return it->second; - }; + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + std::string_view backend_config = std::move(std::get<4>(args_tuple)); - for (auto* inst : computation->MakeInstructionPostOrder()) { - if (inst->opcode() == HloOpcode::kParameter) { - replacements.emplace(inst, operands[inst->parameter_number()]); - } else { - std::vector new_operands; - new_operands.reserve(inst->operand_count()); - for (HloInstruction* operand : inst->mutable_operands()) { - new_operands.push_back(resolve(operand)); - } - auto* new_inst = builder->AddInstruction( - inst->CloneWithNewOperands(inst->shape(), new_operands, &context)); - HloChannelInstruction* channel_instr = - DynCast(new_inst); - if (channel_instr && channel_instr->channel_id().has_value()) { - new_inst->set_channel_id(new_channel()); + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); } - replacements.emplace(inst, new_inst); } } - return resolve(computation->root_instruction()); -} -class PyCustomCallPartitioner : public CustomCallPartitioner { - public: - PyCustomCallPartitioner(py::object prop_user_sharding, py::object partition, - py::object infer_sharding_from_operands, - bool can_side_effecting_have_replicated_sharding) - : prop_user_sharding_(prop_user_sharding), - partition_(partition), - infer_sharding_from_operands_(infer_sharding_from_operands), - can_side_effecting_have_replicated_sharding_( - can_side_effecting_have_replicated_sharding) {} - xla::Status Partition(spmd::SpmdPartitioningVisitor* partitioner, - HloInstruction* instruction) const override { - py::gil_scoped_acquire gil; - try { - auto py_result = - partition_(GetArgShapes(instruction), GetArgShardings(instruction), - instruction->shape(), instruction->sharding(), - py::bytes(instruction->raw_backend_config_string())); + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::string_view backend_config = std::move(std::get<3>(args_tuple)); - const XlaComputation* computation = nullptr; // Kept alive by py_result. - std::vector arg_shardings; - std::optional result_sharding; - try { - std::tie(computation, arg_shardings, result_sharding) = - py::cast, - HloSharding>>(py_result); - } catch (const py::cast_error& e) { - return xla::InternalError( - "Shardings returned from partitioning %s: expected " - "Tuple[XlaComputation, List[HloSharding], HloSharding] got: %s", - instruction->ToString(), py::repr(py_result)); - } - auto hlo_module_config = - xla::HloModule::CreateModuleConfigFromProto( - computation->proto(), xla::DefaultDebugOptionsIgnoringFlags()) - .value(); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - xla::HloModule::CreateFromProto(computation->proto(), - hlo_module_config)); - std::vector operands; - operands.reserve(instruction->operand_count()); - if (arg_shardings.size() != instruction->operand_count()) { - return xla::InternalError( - "Shardings returned from partitioning %s must match: %d vs %d", - instruction->ToString(), arg_shardings.size(), - instruction->operand_count()); - } - for (size_t i = 0; i < instruction->operand_count(); ++i) { - operands.push_back( - partitioner->GetPartitionedHlo(instruction->mutable_operand(i)) - .Reshard(arg_shardings[i]) - .hlo()); + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; } - - // The custom call module does not go through the main compiler pipeline, - // so inline all calls here explicitly, since some targets require it. - HloPassPipeline pipeline("custom-call-inliner"); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module.get(), {}).status()); - - auto* partitioned_hlo = InlineHloComputation( - instruction, hlo_module->entry_computation(), partitioner->builder(), - operands, [partitioner]() { return partitioner->NewChannel(); }, - "_custom_call_lowering_rule"); - partitioned_hlo->set_sharding(result_sharding.value()); - - spmd::PartitionedHlo result_partitioned = - spmd::PartitionedHlo(partitioned_hlo, instruction->shape(), - partitioner->MakePartitioningState()) - .Reshard(instruction->sharding()); - - partitioner->SetPartitionedHlo(instruction, result_partitioned); - return xla::OkStatus(); - } catch (const pybind11::error_already_set& e) { - return xla::InternalError("custom_partitioner: %s", e.what()); + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); } } - HloSharding PropagateUserSharding( - const HloInstruction* instruction, const HloInstruction* user, - const HloSharding& sharding) const override { - py::gil_scoped_acquire gil; + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + std::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; try { // TODO(parkers): expand this API to handle the `user` sharding. // The user is used when the custom call returns a Tuple and // the user is a get-tuple-element. In this case we must update only // part of the sharding spec. - auto result = py::cast(prop_user_sharding_( - sharding, instruction->shape(), - py::bytes(instruction->raw_backend_config_string()))); + auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); return result; - } catch (const pybind11::error_already_set& e) { - LOG(FATAL) << absl::StrFormat("custom_partitioner: %s", e.what()); - } - } - std::optional InferShardingFromOperands( - const HloInstruction* instruction) const override { - std::optional result; - std::vector arg_shapes = GetArgShapes(instruction); - auto arg_shardings = GetArgShardings(instruction); - py::gil_scoped_acquire gil; - try { - auto py_result = infer_sharding_from_operands_( - arg_shapes, arg_shardings, instruction->shape(), - py::bytes(instruction->raw_backend_config_string())); - if (py_result.is_none()) { - return std::nullopt; - } - return py::cast(py_result); - } catch (const pybind11::error_already_set& e) { - LOG(FATAL) << absl::StrFormat("custom_partitioner: %s", e.what()); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); } - return result; - } - bool IsCustomCallShardable(const HloInstruction* instruction) const override { - return true; } - bool CanSideEffectingHaveReplicatedSharding() const override { - return can_side_effecting_have_replicated_sharding_; + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); } - absl::Status status_set_; - py::object prop_user_sharding_; - py::object partition_; - py::object infer_sharding_from_operands_; - bool can_side_effecting_have_replicated_sharding_; + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; }; namespace { @@ -240,26 +202,59 @@ void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { return; } try { - py::gil_scoped_acquire gil; - py::handle(reinterpret_cast(obj))(*std::move(arg)); - } catch (const pybind11::error_already_set& e) { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { jax::InspectShardingSetError(args, std::string(e.what())); } } } // namespace -void BuildCustomCallShardingPybindAPI(pybind11::module& m) { +void BuildCustomCallShardingPybindAPI(nb::module_& m) { m.def( "register_custom_call_partitioner", - [](std::string name, py::object prop_user_sharding, py::object partition, - py::object infer_sharding_from_operands, - bool can_side_effecting_have_replicated_sharding) { - RegisterCustomCallPartitioner( - name, - std::make_unique( - prop_user_sharding, partition, infer_sharding_from_operands, - can_side_effecting_have_replicated_sharding)); + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (std::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); }, R"(Registers a partitioner for a custom-call operation. @@ -274,21 +269,24 @@ void BuildCustomCallShardingPybindAPI(pybind11::module& m) { Takes operand sharding and returns the instruction sharding. can_side_effecting_have_replicated_sharding: Side effecting ops are not allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension )", - py::arg("name"), py::arg("prop_user_sharding"), py::arg("partition"), - py::arg("infer_sharding_from_operands"), - py::arg("can_side_effecting_have_replicated_sharding") = false); + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); m.def("encode_inspect_sharding_callback", - [](py::object handler) -> py::bytes { + [](nb::object handler) -> nb::bytes { JAX_InspectSharding_Callback cb; cb.call = &CallInspectSharding; cb.data = handler.ptr(); char bytes[sizeof(JAX_InspectSharding_Callback)]; - memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); - return py::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); }); - py::module hlo_sharding_util_m = m.def_submodule( + nb::module_ hlo_sharding_util_m = m.def_submodule( "hlo_sharding_util", "Utilities for manipulating HloSharding."); hlo_sharding_util_m.def( "PartiallyReplicateTiledShardingOnDims", diff --git a/xla/python/custom_call_sharding.h b/xla/python/custom_call_sharding.h index 9b339062444d5..ed840d5509f69 100644 --- a/xla/python/custom_call_sharding.h +++ b/xla/python/custom_call_sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildCustomCallShardingPybindAPI(pybind11::module& m); +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); } // namespace xla diff --git a/xla/python/custom_partition_callback.cc b/xla/python/custom_partition_callback.cc new file mode 100644 index 0000000000000..d9bcb596bf599 --- /dev/null +++ b/xla/python/custom_partition_callback.cc @@ -0,0 +1,505 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/python/custom_partition_callback.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/client/xla_computation.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/service/call_inliner.h" +#include "xla/service/custom_call_sharding_helper.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/util.h" + +namespace xla { + +absl::StatusOr InlineHloComputation( + HloInstruction* instruction, HloComputation* computation, + HloComputation::Builder* builder, std::vector operands, + std::function new_channel, const std::string& suffix) { + HloCloneContext context(instruction->GetModule(), suffix); + + absl::flat_hash_map replacements; + auto resolve = [&](HloInstruction* inst) -> absl::StatusOr { + auto it = replacements.find(inst); + if (it == replacements.end()) { + return absl::InternalError( + absl::StrCat("Could not find mapping for: ", inst->ToString())); + } + return it->second; + }; + + for (auto* inst : computation->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kParameter) { + replacements.emplace(inst, operands[inst->parameter_number()]); + } else { + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->mutable_operands()) { + TF_ASSIGN_OR_RETURN(auto* new_operand, resolve(operand)); + new_operands.push_back(new_operand); + } + auto* new_inst = builder->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), new_operands, &context)); + HloChannelInstruction* channel_instr = + DynCast(new_inst); + if (channel_instr && channel_instr->channel_id().has_value()) { + new_inst->set_channel_id(new_channel()); + } + replacements.emplace(inst, new_inst); + } + } + return resolve(computation->root_instruction()); +} + +class CApiCustomCallPartitioner : public xla::CustomCallPartitioner { + public: + explicit CApiCustomCallPartitioner(JAX_CustomCallPartitioner_Callbacks* c_fns) + : c_fns_(c_fns) {} + ~CApiCustomCallPartitioner() override { c_fns_->dtor(c_fns_); } + absl::Status Partition(spmd::SpmdPartitioningVisitor* partitioner, + HloInstruction* instruction) const override { + JAX_CustomCallPartitioner_Partition_Args args; + auto scratch = jax::PopulateArgs(&args, instruction); + c_fns_->partition(c_fns_, &args); + + XlaComputation computation; + std::vector arg_shardings; + std::optional result_sharding; + std::string mlir_module; + TF_ASSIGN_OR_RETURN(std::tie(mlir_module, arg_shardings, result_sharding), + jax::ConsumeResults(&args)); + TF_RETURN_IF_ERROR(ParseMlirModuleStringAndConvertToXlaComputation( + mlir_module, computation, /*use_tuple_args=*/false, + /*return_tuple=*/false)); + auto hlo_module_config = + xla::HloModule::CreateModuleConfigFromProto( + computation.proto(), xla::DefaultDebugOptionsIgnoringFlags()) + .value(); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(computation.proto(), + hlo_module_config)); + std::vector operands; + operands.reserve(instruction->operand_count()); + if (arg_shardings.size() != instruction->operand_count()) { + return xla::Internal( + "Shardings returned from partitioning %s must match: %d vs %d", + instruction->ToString(), arg_shardings.size(), + instruction->operand_count()); + } + for (size_t i = 0; i < instruction->operand_count(); ++i) { + operands.push_back( + partitioner->GetPartitionedHlo(instruction->mutable_operand(i)) + .Reshard(arg_shardings[i]) + .hlo()); + } + + // The custom call module does not go through the main compiler pipeline, + // so inline all calls here explicitly, since some targets require it. + HloPassPipeline pipeline("custom-call-inliner"); + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module.get(), {}).status()); + + TF_ASSIGN_OR_RETURN( + auto* partitioned_hlo, + InlineHloComputation( + instruction, hlo_module->entry_computation(), + partitioner->builder(), operands, + [partitioner]() { return partitioner->NewChannel(); }, + "_custom_call_lowering_rule")); + partitioned_hlo->set_sharding(result_sharding.value()); + + spmd::PartitionedHlo result_partitioned = + spmd::PartitionedHlo(partitioned_hlo, instruction->shape(), + partitioner->MakePartitioningState()) + .Reshard(instruction->sharding()); + + partitioner->SetPartitionedHlo(instruction, result_partitioned); + return absl::OkStatus(); + } + HloSharding PropagateUserSharding( + const HloInstruction* instruction, const HloInstruction* user, + const HloSharding& sharding) const override { + JAX_CustomCallPartitioner_PropagateUserSharding_Args args; + auto scratch = jax::PopulateArgs(&args, instruction, sharding); + c_fns_->propagate_user_sharding(c_fns_, &args); + auto status_or_result = jax::ConsumeResults(&args); + TF_CHECK_OK(status_or_result.status()); + return *status_or_result; + } + std::optional InferShardingFromOperands( + const HloInstruction* instruction) const override { + JAX_CustomCallPartitioner_InferShardingFromOperands_Args args; + auto scratch = jax::PopulateArgs(&args, instruction); + c_fns_->infer_sharding(c_fns_, &args); + auto status_or_result = jax::ConsumeResults(&args); + TF_CHECK_OK(status_or_result.status()); + return *status_or_result; + } + bool IsCustomCallShardable(const HloInstruction* instruction) const override { + return true; + } + bool CanSideEffectingHaveReplicatedSharding() const override { + return c_fns_->can_side_effecting_have_replicated_sharding; + } + + JAX_CustomCallPartitioner_Callbacks* c_fns_; +}; + +} // namespace xla + +namespace jax { + +namespace { + +void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result, + std::vector& scratch) { + scratch.push_back(std::move(result)); + out.data = scratch.back().data(); + out.size = scratch.back().size(); +} + +std::string_view ToStringView(JAX_CustomCallPartitioner_string data) { + return std::string_view(data.data, data.size); +} + +void SetCAPIAval(JAX_CustomCallPartitioner_aval& result, + const xla::HloInstruction* inst, + std::vector& scratch) { + SetCAPIString(result.shape, inst->shape().SerializeAsString(), scratch); + if (inst->has_sharding()) { + result.has_sharding = true; + SetCAPIString(result.sharding, + inst->sharding().ToProto().SerializeAsString(), scratch); + } else { + result.has_sharding = false; + } +} + +} // namespace + +struct ResultScratch { + absl::Status status; + std::vector strings; + std::vector op_args_sharding_storage; +}; + +absl::StatusOr ReadHloSharding( + JAX_CustomCallPartitioner_string data) { + xla::OpSharding proto; + if (data.size > std::numeric_limits::max() || + !proto.ParseFromArray(data.data, data.size)) { + return absl::InternalError( + "custom_call_sharding.cc: error parsing OpShardingProto"); + } + return xla::HloSharding::FromProto(std::move(proto)); +} + +absl::StatusOr ReadHloShape(JAX_CustomCallPartitioner_string data) { + xla::ShapeProto proto; + if (data.size > std::numeric_limits::max() || + !proto.ParseFromArray(data.data, data.size)) { + return absl::InternalError( + "custom_call_sharding.cc: error parsing xla::Shape"); + } + return xla::Shape(proto); +} + +bool PopulateErrorHeader(JAX_CustomCallPartitioner_version_and_error& header, + absl::Status status) { + header.has_error = !status.ok(); + if (header.has_error) { + auto* status_copy = new absl::Status(status); + header.data = status_copy; + header.cleanup_fn = reinterpret_cast( + +[](absl::Status* data) { delete data; }); + header.code = pjrt::StatusCodeToPjrtErrorCode(status_copy->code()); + header.error_msg.data = status_copy->message().data(); + header.error_msg.size = status_copy->message().size(); + } + return header.has_error; +} + +absl::Status ConsumeHeader( + JAX_CustomCallPartitioner_version_and_error& header) { + if (header.has_error) { + return absl::Status(pjrt::PjrtErrorCodeToStatusCode(header.code), + ToStringView(header.error_msg)); + } + return absl::OkStatus(); +} + +void PopulateResults( + absl::StatusOr, + xla::HloSharding>> + results, + JAX_CustomCallPartitioner_Partition_Args* args) { + if (PopulateErrorHeader(args->header, results.status())) { + return; + } + auto* scratch = new ResultScratch; + args->header.data = scratch; + args->header.cleanup_fn = reinterpret_cast( + +[](ResultScratch* data) { delete data; }); + auto& [mlir_module, shardings, result_shardings] = *results; + scratch->strings.reserve(2 + args->num_args); + SetCAPIString(args->mlir_module, std::move(mlir_module), scratch->strings); + SetCAPIString(args->result_sharding, + result_shardings.ToProto().SerializeAsString(), + scratch->strings); + scratch->op_args_sharding_storage.resize(args->num_args); + for (size_t i = 0; i < args->num_args; ++i) { + SetCAPIString(scratch->op_args_sharding_storage[i], + shardings[i].ToProto().SerializeAsString(), scratch->strings); + } + args->args_sharding = scratch->op_args_sharding_storage.data(); +} + +absl::StatusOr< + std::tuple, xla::HloSharding>> +ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args) { + absl::Cleanup cleanup = [args] { + args->header.cleanup_fn(args->header.data); + }; + TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); + TF_ASSIGN_OR_RETURN(auto result_sharding, + ReadHloSharding(args->result_sharding)); + std::vector arg_shardings; + arg_shardings.reserve(args->num_args); + for (size_t i = 0; i < args->num_args; ++i) { + TF_ASSIGN_OR_RETURN(auto arg_sharding, + ReadHloSharding(args->args_sharding[i])); + arg_shardings.push_back(std::move(arg_sharding)); + } + return std::tuple, + xla::HloSharding>( + std::string(ToStringView(args->mlir_module)), std::move(arg_shardings), + std::move(result_sharding)); +} + +PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, + const xla::HloInstruction* instruction) { + args->header.api_version = 0; + args->header.data = nullptr; + args->header.cleanup_fn = nullptr; + PartitionScratch scratch; + scratch.op_args_storage.resize(instruction->operand_count()); + scratch.strings.reserve(instruction->operand_count() * 2 + 2); + size_t i = 0; + for (xla::HloInstruction* operand : instruction->operands()) { + SetCAPIAval(scratch.op_args_storage[i], operand, scratch.strings); + ++i; + } + args->num_args = instruction->operand_count(); + args->op_args = scratch.op_args_storage.data(); + SetCAPIAval(args->op_result, instruction, scratch.strings); + args->backend_config.data = instruction->raw_backend_config_string().data(); + args->backend_config.size = instruction->raw_backend_config_string().size(); + return scratch; +} + +absl::StatusOr, std::vector>, + xla::Shape, std::optional, std::string_view>> +ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { + std::vector shapes; + std::vector> shardings; + shapes.reserve(args->num_args); + shardings.reserve(args->num_args); + for (size_t i = 0; i < args->num_args; ++i) { + TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); + shapes.push_back(shape); + if (args->op_args[i].has_sharding) { + TF_ASSIGN_OR_RETURN(auto sharding, + ReadHloSharding(args->op_args[i].sharding)); + shardings.push_back(std::move(sharding)); + } else { + shardings.push_back(std::nullopt); + } + } + + TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->op_result.shape)); + std::optional result_sharding; + if (args->op_result.has_sharding) { + TF_ASSIGN_OR_RETURN(result_sharding, + ReadHloSharding(args->op_result.sharding)); + } + return std::tuple, + std::vector>, xla::Shape, + std::optional, std::string_view>( + std::move(shapes), std::move(shardings), std::move(result_shape), + std::move(result_sharding), ToStringView(args->backend_config)); +} + +absl::StatusOr, + std::vector>, + xla::Shape, std::string_view>> +ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + std::vector shapes; + std::vector> shardings; + shapes.reserve(args->num_args); + shardings.reserve(args->num_args); + for (size_t i = 0; i < args->num_args; ++i) { + TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); + shapes.push_back(shape); + if (args->op_args[i].has_sharding) { + TF_ASSIGN_OR_RETURN(auto sharding, + ReadHloSharding(args->op_args[i].sharding)); + shardings.push_back(std::move(sharding)); + } else { + shardings.push_back(std::nullopt); + } + } + + TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape)); + return std::tuple, + std::vector>, xla::Shape, + std::string_view>(std::move(shapes), std::move(shardings), + std::move(result_shape), + ToStringView(args->backend_config)); +} + +PartitionScratch PopulateArgs( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args, + const xla::HloInstruction* instruction) { + args->header.api_version = 0; + args->header.data = nullptr; + args->header.cleanup_fn = nullptr; + PartitionScratch scratch; + scratch.op_args_storage.resize(instruction->operand_count()); + scratch.strings.reserve(instruction->operand_count() * 2 + 2); + size_t i = 0; + for (xla::HloInstruction* operand : instruction->operands()) { + SetCAPIAval(scratch.op_args_storage[i], operand, scratch.strings); + ++i; + } + args->num_args = instruction->operand_count(); + args->op_args = scratch.op_args_storage.data(); + SetCAPIString(args->result_shape, instruction->shape().SerializeAsString(), + scratch.strings); + args->backend_config.data = instruction->raw_backend_config_string().data(); + args->backend_config.size = instruction->raw_backend_config_string().size(); + return scratch; +} + +void PopulateResults( + absl::StatusOr> result, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + if (PopulateErrorHeader(args->header, result.status())) { + return; + } + args->has_result_sharding = result->has_value(); + if (result->has_value()) { + auto* data = new std::string((*result)->ToProto().SerializeAsString()); + args->header.data = data; + args->header.cleanup_fn = reinterpret_cast( + +[](std::string* data) { delete data; }); + args->result_sharding.data = data->data(); + args->result_sharding.size = data->size(); + } else { + args->header.cleanup_fn = +[](void*) {}; + } +} +absl::StatusOr> ConsumeResults( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + absl::Cleanup cleanup = [args] { + args->header.cleanup_fn(args->header.data); + }; + TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); + if (!args->has_result_sharding) { + return std::nullopt; + } + return ReadHloSharding(args->result_sharding); +} + +absl::StatusOr> +ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape)); + TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding)); + return std::tuple( + std::move(sharding), std::move(shape), + ToStringView(args->backend_config)); +} +PartitionScratch PopulateArgs( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args, + const xla::HloInstruction* instruction, const xla::HloSharding& sharding) { + args->header.api_version = 0; + args->header.data = nullptr; + args->header.cleanup_fn = nullptr; + PartitionScratch scratch; + scratch.strings.reserve(2); + SetCAPIString(args->result_sharding, sharding.ToProto().SerializeAsString(), + scratch.strings); + SetCAPIString(args->result_shape, instruction->shape().SerializeAsString(), + scratch.strings); + args->backend_config.data = instruction->raw_backend_config_string().data(); + args->backend_config.size = instruction->raw_backend_config_string().size(); + return scratch; +} + +void PopulateResults( + absl::StatusOr result, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + if (PopulateErrorHeader(args->header, result.status())) { + return; + } + auto* data = new std::string(result->ToProto().SerializeAsString()); + args->header.data = data; + args->header.cleanup_fn = reinterpret_cast( + +[](std::string* data) { delete data; }); + args->result_sharding.data = data->data(); + args->result_sharding.size = data->size(); +} +absl::StatusOr ConsumeResults( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + absl::Cleanup cleanup = [args] { + args->header.cleanup_fn(args->header.data); + }; + TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); + return ReadHloSharding(args->result_sharding); +} + +std::unique_ptr CreateCApiCustomCallPartitioner( + JAX_CustomCallPartitioner_Callbacks* c_fns) { + return std::make_unique(c_fns); +} + +} // namespace jax diff --git a/xla/python/custom_partition_callback.h b/xla/python/custom_partition_callback.h new file mode 100644 index 0000000000000..33cc31e75fc9b --- /dev/null +++ b/xla/python/custom_partition_callback.h @@ -0,0 +1,81 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_PYTHON_CUSTOM_PARTITION_CALLBACK_H_ +#define XLA_PYTHON_CUSTOM_PARTITION_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/service/custom_call_sharding_helper.h" + +namespace jax { + +struct PartitionScratch { + std::vector strings; + std::vector op_args_storage; +}; +PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, + const xla::HloInstruction* instruction); +absl::StatusOr, std::vector>, + xla::Shape, std::optional, std::string_view>> +ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args); +void PopulateResults( + absl::StatusOr, + xla::HloSharding>> + results, + JAX_CustomCallPartitioner_Partition_Args* args); +absl::StatusOr< + std::tuple, xla::HloSharding>> +ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args); + +absl::StatusOr, + std::vector>, + xla::Shape, std::string_view>> +ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); +PartitionScratch PopulateArgs( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args, + const xla::HloInstruction* instruction); +void PopulateResults( + absl::StatusOr> result, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); +absl::StatusOr> ConsumeResults( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); + +absl::StatusOr> +ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); +PartitionScratch PopulateArgs( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args, + const xla::HloInstruction* instruction, const xla::HloSharding& sharding); +void PopulateResults( + absl::StatusOr result, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); +absl::StatusOr ConsumeResults( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); + +// Wraps c-api callbacks with the custom-call partitioner. +std::unique_ptr CreateCApiCustomCallPartitioner( + JAX_CustomCallPartitioner_Callbacks* c_fns); + +} // namespace jax + +#endif // XLA_PYTHON_CUSTOM_PARTITION_CALLBACK_H_ diff --git a/xla/python/dlpack.cc b/xla/python/dlpack.cc index 42353e5b61559..6ac2ebed0d7db 100644 --- a/xla/python/dlpack.cc +++ b/xla/python/dlpack.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,31 +15,49 @@ limitations under the License. #include "xla/python/dlpack.h" +#include + #include #include #include #include #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "pybind11/gil.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/py_array.h" +#include "xla/python/py_client.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" +#include "xla/python/types.h" #include "xla/python/util.h" -#include "xla/types.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" -namespace py = pybind11; +namespace nb = nanobind; namespace xla { namespace { @@ -50,7 +68,7 @@ struct DLPackTensor { ~DLPackTensor(); // `buffer_reference` is populated if we have shared (read-only) access. - py::object buffer_reference; + nb::object buffer_reference; // `external_reference` is always populated. std::unique_ptr external_reference; @@ -73,7 +91,7 @@ void DLPackTensorDeleter(DLManagedTensor* t) { } } -StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { switch (type) { case S8: return DLDataType{kDLInt, 8, 1}; @@ -100,7 +118,7 @@ StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { case BF16: return DLDataType{kDLBfloat, 16, 1}; case PRED: - return DLDataType{kDLUInt, 8, 1}; + return DLDataType{kDLBool, 8, 1}; case C64: return DLDataType{kDLComplex, 64, 1}; case C128: @@ -111,12 +129,21 @@ StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { } } -StatusOr DLDataTypeToPrimitiveType(DLDataType type) { +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { if (type.lanes != 1) { return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", type.lanes); } switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } case kDLInt: switch (type.bits) { case 8: @@ -183,7 +210,7 @@ StatusOr DLDataTypeToPrimitiveType(DLDataType type) { } } -StatusOr> StridesToLayout( +absl::StatusOr> StridesToLayout( absl::Span dims, absl::Span strides) { CHECK_EQ(dims.size(), strides.size()); std::vector minor_to_major(dims.size()); @@ -214,7 +241,7 @@ StatusOr> StridesToLayout( return minor_to_major; } -StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { +absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { if (device.client()->platform_id() == CpuId()) { return kDLCPU; } else if (device.client()->platform_id() == CudaId()) { @@ -226,16 +253,16 @@ StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { device.DebugString()); } -StatusOr DLDeviceForDevice(const PjRtDevice& device) { +absl::StatusOr DLDeviceForDevice(const PjRtDevice& device) { DLDevice context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); context.device_id = device.local_hardware_id(); return context; } -StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, - const PjRtClient* gpu_client, - const DLDevice& context) { +absl::StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, + const PjRtClient* gpu_client, + const DLDevice& context) { switch (context.device_type) { case kDLCPU: if (cpu_client == nullptr) { @@ -266,15 +293,20 @@ StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, } // namespace -StatusOr BufferToDLPackManagedTensor( - py::handle py_buffer, std::optional stream) { - ifrt::Array* ifrt_array = py::cast(py_buffer).ifrt_array(); - auto pack = std::make_unique(); +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); if (ifrt_array == nullptr) { return Unimplemented( "BufferToDLPackManagedTensor called on deleted array."); } - PjRtBuffer* pjrt_buffer = IfrtHelpers::pjrt_buffer(ifrt_array); + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + if (pjrt_buffer->IsTuple()) { return Unimplemented( "BufferToDLPackManagedTensor is not implemented for tuple " @@ -284,21 +316,23 @@ StatusOr BufferToDLPackManagedTensor( return Unimplemented("DynamicShape is not implemented in DLPack."); } + auto pack = std::make_unique(); DLTensor& dt = pack->tensor.dl_tensor; { // AcquireExternalReference may block; there are no API guarantees. GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(pack->external_reference, pjrt_buffer->AcquireExternalReference()); if (stream) { TF_RETURN_IF_ERROR( pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); } else { - TF_RETURN_IF_ERROR(AwaitBuffersReady(ifrt_array)); + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); } } - pack->buffer_reference = py::reinterpret_borrow(py_buffer); + pack->buffer_reference = nb::borrow(py_buffer); dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); pack->tensor.manager_ctx = pack.get(); @@ -311,44 +345,54 @@ StatusOr BufferToDLPackManagedTensor( pack->shape = std::vector(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end()); - pack->strides = - StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), - pjrt_buffer->layout()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + dt.shape = reinterpret_cast(pack->shape.data()); dt.strides = reinterpret_cast(pack->strides.data()); dt.byte_offset = 0; - py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName, - [](PyObject* obj) { - DLManagedTensor* dlmt = static_cast( - PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); - if (dlmt) { - DLPackTensorDeleter(dlmt); - } else { - // The tensor has been deleted. Clear any error from - // PyCapsule_GetPointer. - PyErr_Clear(); - } - }); + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } return capsule; } -StatusOr DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, std::shared_ptr cpu_client, - std::shared_ptr gpu_client) { +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, std::optional> cpu_client, + std::optional> gpu_client) { // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex // multiple PjRt clients. Devices from these PjRt clients could be expressed // as a unified set of IFRT devices. - auto* cpu_pjrt_client = cpu_client ? cpu_client->pjrt_client() : nullptr; - auto* gpu_pjrt_client = gpu_client ? gpu_client->pjrt_client() : nullptr; + auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; + auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; - if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - absl::string_view(tensor.name())); + std::string_view(tensor.name())); } - DLManagedTensor* dlmt = static_cast(tensor); + DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { return InvalidArgument( "Number of dimensions in DLManagedTensor must be nonnegative, got %d", @@ -377,6 +421,32 @@ StatusOr DLPackManagedTensorToBuffer( Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, minor_to_major); + // Raise an error if the resulting PjRtBuffer would have a non-default layout. + // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + absl::StatusOr default_layout_from_client = + device->client()->GetDefaultLayout(element_type, dimensions); + Layout default_layout; + if (default_layout_from_client.ok()) { + default_layout = *default_layout_from_client; + } else if (absl::IsUnimplemented(default_layout_from_client.status())) { + // TODO(skyewm): consider remove the fallback path when GetDefaultLayout is + // unimplemented. + Shape host_shape = ShapeUtil::MakeShape(element_type, dimensions); + default_layout = LayoutUtil::GetWithDefaultLayout(host_shape).layout(); + } else { + return default_layout_from_client.status(); + } + if (shape.layout() != default_layout) { + return Unimplemented( + "from_dlpack got array with non-default layout with minor-to-major " + "dimensions (%s), expected (%s)", + absl::StrJoin(shape.layout().minor_to_major(), ","), + absl::StrJoin(default_layout.minor_to_major(), ",")); + } + std::function on_delete_callback; if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; @@ -393,8 +463,8 @@ StatusOr DLPackManagedTensorToBuffer( // TODO(phawkins): simplify the expression below once we know cpu_client is // always non-null. auto client = (cpu_client && device->client() == cpu_pjrt_client) - ? std::move(cpu_client) - : std::move(gpu_client); + ? std::move(*cpu_client) + : std::move(*gpu_client); auto* ifrt_client = llvm::dyn_cast_or_null(client->ifrt_client()); if (ifrt_client == nullptr) { @@ -407,16 +477,16 @@ StatusOr DLPackManagedTensorToBuffer( std::move(ifrt_array), false, true); } -StatusOr DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, PjRtDevice* device, - std::shared_ptr client, std::optional stream) { - if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, PjRtDevice* device, + nb_class_ptr client, std::optional stream) { + if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - absl::string_view(tensor.name())); + std::string_view(tensor.name())); } - DLManagedTensor* dlmt = static_cast(tensor); + DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { return InvalidArgument( "Number of dimensions in DLManagedTensor must be nonnegative, got %d", diff --git a/xla/python/dlpack.h b/xla/python/dlpack.h index 1da97448c64da..abad969abaaae 100644 --- a/xla/python/dlpack.h +++ b/xla/python/dlpack.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,13 @@ limitations under the License. #ifndef XLA_PYTHON_DLPACK_H_ #define XLA_PYTHON_DLPACK_H_ -#include +#include +#include -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/python/py_buffer.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" namespace xla { @@ -31,16 +34,17 @@ namespace xla { // stream, if set, is a GPU stream, e.g. cudaStream_t for CUDA GPUs, that should // be synchronized to the buffer as per // https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. -StatusOr BufferToDLPackManagedTensor( - pybind11::handle buffer, std::optional stream); +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); -StatusOr DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, std::shared_ptr cpu_client, - std::shared_ptr gpu_client); +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client); -StatusOr DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, PjRtDevice* device, - std::shared_ptr client, std::optional stream); +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, ifrt::Device* device, + nb_class_ptr client, std::optional stream); } // namespace xla diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 602c1b294d43f..d346f7a55c17b 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -1,5 +1,7 @@ -load("//xla:xla.bzl", "xla_cc_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") +load("//xla:xla.bzl", "xla_cc_test") package_group( name = "friends", @@ -20,10 +22,10 @@ package_group( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ ":friends", ":internal", - ], + ]), ) exports_files([ @@ -66,14 +68,19 @@ cc_library( "tuple.h", "value.h", ], + compatible_with = get_compatible_with_portable(), deps = [ + ":device_proto_cc", + ":dtype_proto_cc", ":serdes", - ":types_proto_cc", + ":shape_proto_cc", + ":sharding_proto_cc", "//xla:status", "//xla:statusor", "//xla:util", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_layout", "//xla/python/ifrt/ir", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -81,14 +88,19 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -153,7 +165,11 @@ xla_cc_test( srcs = ["shape_test.cc"], deps = [ ":ifrt", + ":shape_proto_cc", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -200,6 +216,7 @@ cc_library( ":ifrt", ":mock", ":test_util", + "//xla/pjrt:pjrt_common", "@tsl//tsl/platform:test", ], ) @@ -291,12 +308,16 @@ cc_library( hdrs = ["mock.h"], deps = [ ":ifrt", + "//xla:literal", "//xla:test", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_layout", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@tsl//tsl/concurrency:ref_count", @@ -307,6 +328,7 @@ cc_library( name = "serdes", srcs = ["serdes.cc"], hdrs = ["serdes.h"], + compatible_with = get_compatible_with_portable(), deps = [ ":serdes_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -346,11 +368,16 @@ cc_library( name = "sharding_serdes", srcs = ["sharding_serdes.cc"], hdrs = ["sharding_serdes.h"], + compatible_with = get_compatible_with_portable(), deps = [ ":ifrt", ":serdes", ":sharding_proto_cc", - "//xla:statusor", + ":sharding_serdes_proto_cc", + "//xla:util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@tsl//tsl/platform:statusor", ], @@ -363,20 +390,113 @@ xla_cc_test( deps = [ ":ifrt", ":serdes", + ":serdes_proto_cc", ":sharding_serdes", ":sharding_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:statusor", ], ) tf_proto_library( - name = "types_proto", - srcs = ["types.proto"], + name = "device_proto", + srcs = ["device.proto"], +) + +xla_cc_test( + name = "device_test", + size = "small", + srcs = ["device_test.cc"], + deps = [ + ":device_proto_cc", + ":ifrt", + ":sharding_test_util", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:platform_port", + "@tsl//tsl/platform:statusor", + ], +) + +tf_proto_library( + name = "dtype_proto", + srcs = ["dtype.proto"], +) + +xla_cc_test( + name = "dtype_test", + size = "small", + srcs = ["dtype_test.cc"], + deps = [ + ":dtype_proto_cc", + ":ifrt", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +tf_proto_library( + name = "shape_proto", + srcs = ["shape.proto"], ) tf_proto_library( name = "sharding_proto", srcs = ["sharding.proto"], - protodeps = [":types_proto"], + protodeps = [":serdes_proto"], +) + +tf_proto_library( + name = "sharding_serdes_proto", + srcs = ["sharding_serdes.proto"], + protodeps = [ + ":device_proto", + ":shape_proto", + ], +) + +cc_library( + name = "plugin_program", + srcs = ["plugin_program.cc"], + hdrs = ["plugin_program.h"], + deps = [ + ":ifrt", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "plugin_program_serdes", + srcs = ["plugin_program_serdes.cc"], + deps = [ + ":plugin_program", + ":serdes", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + ], + alwayslink = True, +) + +xla_cc_test( + name = "plugin_program_serdes_test", + srcs = ["plugin_program_serdes_test.cc"], + deps = [ + ":plugin_program", + ":plugin_program_serdes", + ":serdes", + ":serdes_proto_cc", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@tsl//tsl/protobuf:status_proto_cc", + ], ) diff --git a/xla/python/ifrt/array.cc b/xla/python/ifrt/array.cc index a36d485c4ce1f..aa44b46927210 100644 --- a/xla/python/ifrt/array.cc +++ b/xla/python/ifrt/array.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/array.h b/xla/python/ifrt/array.h index b09671b4ba0f6..b63d9f4c90096 100644 --- a/xla/python/ifrt/array.h +++ b/xla/python/ifrt/array.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include #include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/shape.h" @@ -36,6 +37,8 @@ namespace ifrt { class Client; +using Layout = ::xla::PjRtLayout; + // Semantics for operations that may copy or move sharded buffers in an array. enum class ArrayCopySemantics : int { // Always creates new buffers to construct an output array. Mutation of the @@ -69,16 +72,20 @@ class Array : public llvm::RTTIExtends { virtual const Shape& shape() const = 0; virtual const Sharding& sharding() const = 0; virtual std::shared_ptr shared_ptr_sharding() const = 0; + // The device memory layout for each shard of the Array. All shards are + // assumed to have the same layout. Cannot be nullptr; implementations should + // return UNIMPLEMENTED instead. + virtual absl::StatusOr> layout() const = 0; // Breaks an array up into per-device arrays. This is the elimination // counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`. - virtual StatusOr>> + virtual absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) = 0; // Returns a shard of an Array which is fully replicated. This is an // optimization so that instead of disassembling into all the shards when // the Array is fully replicated, we can just get 1 shard out and create an // Array from it. - virtual StatusOr> FullyReplicatedShard( + virtual absl::StatusOr> FullyReplicatedShard( ArrayCopySemantics semantics) = 0; // Fetches the array to host and stores it as unreplicated, unsharded data. @@ -129,7 +136,7 @@ class Array : public llvm::RTTIExtends { // // It may fail if the buffer data would be sent from/to an unaddressable // device. - virtual StatusOr> Reshard( + virtual absl::StatusOr> Reshard( std::shared_ptr new_sharding, ArrayCopySemantics semantics) = 0; diff --git a/xla/python/ifrt/array_impl_test_lib.cc b/xla/python/ifrt/array_impl_test_lib.cc index 9394c2f795c0a..ec1f81424274f 100644 --- a/xla/python/ifrt/array_impl_test_lib.cc +++ b/xla/python/ifrt/array_impl_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ TEST_P(ArrayImplWithHostBufferSemanticsTest, // Regardless of the host buffer semantics chosen, the host buffer must not be // used by the runtime once `on_done_with_host_buffer` has been called. - if (semantics == Client::HostBufferSemantics::kZeroCopy) { + if (semantics == Client::HostBufferSemantics::kImmutableZeroCopy) { // `on_done_with_host_buffer` is called only when the `Array` is destroyed // if the runtime implements `kZeroCopy`. A deadlock will occur if we keep // the `Array` instance. @@ -108,7 +108,7 @@ INSTANTIATE_TEST_CASE_P( testing::Values( Client::HostBufferSemantics::kImmutableOnlyDuringCall, Client::HostBufferSemantics::kImmutableUntilTransferCompletes, - Client::HostBufferSemantics::kZeroCopy)); + Client::HostBufferSemantics::kImmutableZeroCopy)); TEST(ArrayImplTest, MakeArrayFromHostBufferImmutableOnlyDuringCall) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); @@ -184,12 +184,12 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferZeroCopy) { std::shared_ptr sharding = SingleDeviceSharding::Create(device, MemoryKind()); - TF_ASSERT_OK_AND_ASSIGN( - auto array, - client->MakeArrayFromHostBuffer(data->data(), dtype, shape, - /*byte_strides=*/std::nullopt, sharding, - Client::HostBufferSemantics::kZeroCopy, - /*on_done_with_host_buffer=*/nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto array, + client->MakeArrayFromHostBuffer( + data->data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableZeroCopy, + /*on_done_with_host_buffer=*/nullptr)); // The `Array` may alias the host buffer, but once the transfer is done and // the `Array` is destroyed, the host buffer is not accessed. This test would diff --git a/xla/python/ifrt/array_test.cc b/xla/python/ifrt/array_test.cc index cfa4093ee21d3..ec94659eb7fe5 100644 --- a/xla/python/ifrt/array_test.cc +++ b/xla/python/ifrt/array_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/client.cc b/xla/python/ifrt/client.cc index dc5c231ce9110..25233a2150ff2 100644 --- a/xla/python/ifrt/client.cc +++ b/xla/python/ifrt/client.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/client.h b/xla/python/ifrt/client.h index 418b08017be88..1018fef99bf9a 100644 --- a/xla/python/ifrt/client.h +++ b/xla/python/ifrt/client.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/tuple.h" @@ -96,20 +97,21 @@ class Client : public llvm::RTTIExtends { // // TODO(hyeontaek): Consider changing `on_done_with_host_buffer` into a // returned `Future` for consistency with other IFRT APIs. - virtual StatusOr> MakeArrayFromHostBuffer( + virtual absl::StatusOr> MakeArrayFromHostBuffer( const void* data, DType dtype, Shape shape, std::optional> byte_strides, std::shared_ptr sharding, HostBufferSemantics semantics, std::function on_done_with_host_buffer) = 0; // Builds a larger array out of individual per-device shards. - virtual StatusOr> AssembleArrayFromSingleDeviceArrays( + virtual absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) = 0; // Builds a tuple from a sequence of values. - virtual StatusOr> MakeTuple( + virtual absl::StatusOr> MakeTuple( absl::Span> values) = 0; // The following APIs are taken from `xla::PjRtClient` for fast prototyping. @@ -124,6 +126,14 @@ class Client : public llvm::RTTIExtends { virtual absl::string_view platform_version() const = 0; virtual PlatformId platform_id() const = 0; + // Returns the attributes of the client. In principle, these try to describe + // capabilities of a client rather than being a "feature flag". + // + // List of officially supported attributes: + // + // * supports_executable_serialization (bool; default = true): Whether IFRT + // executables produced by this client are serializable. If false, all + // executables from this client are considered not serializable. using ClientAttribute = xla::PjRtValueType; virtual absl::flat_hash_map attributes() const = 0; @@ -136,10 +146,10 @@ class Client : public llvm::RTTIExtends { // TODO(hyeontaek): Consider removing this API. This API is potentially not // being used by JAX or will be replaced with explicit device assignment. - virtual StatusOr GetDefaultDeviceAssignment( + virtual absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const = 0; - virtual StatusOr LookupDevice(int device_id) const = 0; - virtual StatusOr LookupAddressableDevice( + virtual absl::StatusOr LookupDevice(int device_id) const = 0; + virtual absl::StatusOr LookupAddressableDevice( int local_hardware_id) const = 0; // TODO(hyeontaek): Potentially remove this method to encourage supporting @@ -147,9 +157,16 @@ class Client : public llvm::RTTIExtends { virtual Compiler* GetDefaultCompiler() = 0; // Returns a topology description for that covers the provided devices. - virtual StatusOr> + virtual absl::StatusOr> GetTopologyForDevices(absl::Span devices) const = 0; + // Returns the default layout on `device` for a buffer with `dtype` and + // single-shard dimensions `dims`. + // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of + // single-shard dimensions and device. + virtual absl::StatusOr> GetDefaultLayoutForDevice( + DType dtype, absl::Span dims, Device* device) const = 0; + static char ID; // NOLINT }; diff --git a/xla/python/ifrt/client_impl_test_lib.cc b/xla/python/ifrt/client_impl_test_lib.cc index 6b3fdbd0b3fc8..5e739951ffb5d 100644 --- a/xla/python/ifrt/client_impl_test_lib.cc +++ b/xla/python/ifrt/client_impl_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/compiler.cc b/xla/python/ifrt/compiler.cc index 4b50f8b862599..77b8526808136 100644 --- a/xla/python/ifrt/compiler.cc +++ b/xla/python/ifrt/compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/compiler.h b/xla/python/ifrt/compiler.h index 64d333ff2a876..6bcce6334a1e9 100644 --- a/xla/python/ifrt/compiler.h +++ b/xla/python/ifrt/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -65,7 +65,7 @@ class Compiler : public llvm::RTTIExtends { public: // Compiles `mlir_module` and returns a `LoadedExecutable`. // TODO(hyeontaek): Move executable loading to `Client`. - virtual StatusOr> Compile( + virtual absl::StatusOr> Compile( std::unique_ptr program, std::unique_ptr options) = 0; @@ -73,7 +73,7 @@ class Compiler : public llvm::RTTIExtends { // `LoadedExecutable::Serialize()`. The compatibility of `serialized` is // implementation specific. // TODO(hyeontaek): Move executable loading to `Client`. - virtual StatusOr> + virtual absl::StatusOr> DeserializeLoadedExecutable( absl::string_view serialized, std::unique_ptr options) = 0; diff --git a/xla/python/ifrt/device.cc b/xla/python/ifrt/device.cc index d999f35e91985..0e48b6f5edf5e 100644 --- a/xla/python/ifrt/device.cc +++ b/xla/python/ifrt/device.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,25 @@ limitations under the License. #include "xla/python/ifrt/device.h" +#include +#include #include +#include #include #include -#include "xla/python/ifrt/types.pb.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/python/ifrt/device.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { -DeviceList::DeviceList(Devices devices) { +DeviceList::DeviceList(Devices devices) : hash_(kUnsetHash) { if (devices.size() <= kInlineDeviceSize) { state_ = State{std::move(devices)}; } else { @@ -32,8 +41,30 @@ DeviceList::DeviceList(Devices devices) { } } -StatusOr DeviceList::FromProto(LookupDeviceFunc lookup_device, - const DeviceListProto& proto) { +DeviceList::DeviceList(const DeviceList& other) + : state_(other.state_), + hash_(other.hash_.load(std::memory_order_relaxed)) {} + +DeviceList::DeviceList(DeviceList&& other) + : state_(std::move(other.state_)), + hash_(other.hash_.load(std::memory_order_relaxed)) {} + +DeviceList& DeviceList::operator=(const DeviceList& other) { + state_ = other.state_; + hash_.store(other.hash_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + return *this; +} + +DeviceList& DeviceList::operator=(DeviceList&& other) { + state_ = std::move(other.state_); + hash_.store(other.hash_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + return *this; +} + +absl::StatusOr DeviceList::FromProto(LookupDeviceFunc lookup_device, + const DeviceListProto& proto) { DeviceList::Devices devices; devices.reserve(proto.device_ids_size()); for (int device_id : proto.device_ids()) { @@ -52,6 +83,28 @@ DeviceListProto DeviceList::ToProto() const { return proto; } +uint64_t DeviceList::hash() const { + uint64_t hash = hash_.load(std::memory_order_relaxed); + if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { + hash = absl::HashOf(devices()); + if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { + ++hash; + } + hash_.store(hash, std::memory_order_relaxed); + } + return hash; +} + +std::string DeviceList::DebugString() const { + return absl::StrCat("[", + absl::StrJoin(devices(), ",", + [](std::string* out, Device* device) { + absl::StrAppend(out, + device->DebugString()); + }), + "]"); +} + std::vector GetDeviceIds(DeviceList device_list) { std::vector ids; ids.reserve(device_list.devices().size()); diff --git a/xla/python/ifrt/device.h b/xla/python/ifrt/device.h index caa816a6ebfea..f0a441b80665f 100644 --- a/xla/python/ifrt/device.h +++ b/xla/python/ifrt/device.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,20 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_DEVICE_H_ #define XLA_PYTHON_IFRT_DEVICE_H_ +#include +#include #include +#include #include #include #include #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/python/ifrt/types.pb.h" +#include "xla/python/ifrt/device.pb.h" namespace xla { namespace ifrt { @@ -50,19 +54,19 @@ class DeviceList { // Constructor with a pre-populated `devices`. explicit DeviceList(Devices devices); - DeviceList(const DeviceList& devices) = default; - DeviceList(DeviceList&& devices) = default; - DeviceList& operator=(const DeviceList& other) = default; - DeviceList& operator=(DeviceList&& other) = default; + DeviceList(const DeviceList& other); + DeviceList(DeviceList&& other); + DeviceList& operator=(const DeviceList& other); + DeviceList& operator=(DeviceList&& other); // Function that matches the semantics of `Client::LookupDevice()`. - using LookupDeviceFunc = absl::FunctionRef(int)>; + using LookupDeviceFunc = absl::FunctionRef(int)>; // Constructs `DeviceList` from `DeviceListProto`. Devices are looked up using // `lookup_device`. Device ids in the proto must be consistent with the // devices returned by `lookup_device`. - static StatusOr FromProto(LookupDeviceFunc lookup_device, - const DeviceListProto& proto); + static absl::StatusOr FromProto(LookupDeviceFunc lookup_device, + const DeviceListProto& proto); // Returns a `DeviceListProto` representation. DeviceListProto ToProto() const; @@ -70,11 +74,19 @@ class DeviceList { absl::Span devices() const { return state().devices; } bool operator==(const DeviceList& other) const { + const std::shared_ptr* lhs = + std::get_if>(&state_); + const std::shared_ptr* rhs = + std::get_if>(&other.state_); + if (lhs != nullptr && rhs != nullptr && lhs->get() == rhs->get()) { + return true; + } return devices() == other.devices(); } - bool operator!=(const DeviceList& other) const { - return devices() != other.devices(); - } + bool operator!=(const DeviceList& other) const { return !(*this == other); } + + // Returns the hash of devices. This hash is stable only within the process. + uint64_t hash() const; int size() const { return state().devices.size(); } bool empty() const { return state().devices.empty(); } @@ -89,6 +101,8 @@ class DeviceList { auto end() const { return state().devices.end(); } auto cend() const { return state().devices.cend(); } + std::string DebugString() const; + private: // Internal state that may be shared across `DeviceList` instances. struct State { @@ -122,6 +136,11 @@ class DeviceList { } std::variant> state_; + + // Cached hash. 0 indicates the hash needs to be computed and cached. + // May be written multiple times with the same non-zero value. + static constexpr uint64_t kUnsetHash = 0; + mutable std::atomic hash_; }; // Returns the id of each device in `device_list`. @@ -132,7 +151,7 @@ std::vector GetDeviceIds(DeviceList device_list); // d2->id()"). template H AbslHashValue(H h, const DeviceList& devices) { - return H::combine(std::move(h), devices.devices()); + return H::combine(std::move(h), devices.hash()); } } // namespace ifrt diff --git a/xla/python/ifrt/device.proto b/xla/python/ifrt/device.proto new file mode 100644 index 0000000000000..e5ce1f8e301a1 --- /dev/null +++ b/xla/python/ifrt/device.proto @@ -0,0 +1,25 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for `DeviceList`. +message DeviceListProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + repeated int32 device_ids = 1; +} diff --git a/xla/python/ifrt/device_test.cc b/xla/python/ifrt/device_test.cc new file mode 100644 index 0000000000000..ca988be158cca --- /dev/null +++ b/xla/python/ifrt/device_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/device.h" + +#include +#include +#include +#include + +#include +#include "absl/status/statusor.h" +#include "absl/synchronization/blocking_counter.h" +#include "xla/python/ifrt/device.pb.h" +#include "xla/python/ifrt/sharding_test_util.h" +#include "tsl/platform/cpu_info.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace { + +class DeviceListTest : public test_util::ShardingTest {}; + +TEST_P(DeviceListTest, ToFromProto) { + auto device_list = GetDevices({0, 1}); + DeviceListProto proto = device_list.ToProto(); + auto lookup_device_func = [&](int device_id) -> absl::StatusOr { + return client()->LookupDevice(device_id); + }; + TF_ASSERT_OK_AND_ASSIGN(auto device_list_copy, + DeviceList::FromProto(lookup_device_func, proto)); + EXPECT_EQ(device_list_copy, device_list); +} + +TEST_P(DeviceListTest, IdenticalHashFromConcurrentCalls) { + auto device_list = GetDevices({0, 1}); + + const int num_threads = 16; + absl::BlockingCounter counter(num_threads); + tsl::thread::ThreadPool thread_pool( + tsl::Env::Default(), tsl::ThreadOptions(), "test_pool", + std::min(num_threads, tsl::port::MaxParallelism())); + std::vector hashes(num_threads); + for (int i = 0; i < num_threads; ++i) { + thread_pool.Schedule([&, i]() { + hashes[i] = device_list.hash(); + counter.DecrementCount(); + }); + } + + counter.Wait(); + for (int i = 0; i < num_threads; ++i) { + EXPECT_EQ(hashes[i], device_list.hash()); + } + EXPECT_NE(device_list.hash(), 0); +} + +TEST_P(DeviceListTest, EqualityTest) { + auto device_list1 = GetDevices({0, 1}); + auto device_list2 = GetDevices({0, 1}); + EXPECT_EQ(device_list1, device_list2); + + auto device_list3 = device_list1; + EXPECT_EQ(device_list1, device_list3); + + auto device_list4 = std::move(device_list2); + EXPECT_EQ(device_list1, device_list4); + + auto device_list5 = GetDevices({0}); + EXPECT_NE(device_list1, device_list5); + + auto device_list6 = GetDevices({1, 0}); + EXPECT_NE(device_list1, device_list6); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, DeviceListTest, + testing::Values(test_util::ShardingTestParam{ + /*num_devices=*/2, + /*num_addressable_devices=*/2})); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index e6747641307a0..b032d350a8156 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/python/ifrt/dtype.pb.h" namespace xla { namespace ifrt { @@ -78,6 +80,89 @@ std::optional DType::bit_size() const { } } +absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { + switch (dtype_proto.kind()) { + case DTypeProto::KIND_PRED: + return DType(DType::Kind::kPred); + case DTypeProto::KIND_TOKEN: + return DType(DType::Kind::kToken); +#define CASE(X) \ + case DTypeProto::KIND_##X: \ + return DType(DType::Kind::k##X); + CASE(S4); + CASE(S8); + CASE(S16); + CASE(S32); + CASE(S64); + CASE(U4); + CASE(U8); + CASE(U16); + CASE(U32); + CASE(U64); + CASE(F16); + CASE(F32); + CASE(F64); + CASE(BF16); + CASE(C64); + CASE(C128); + CASE(F8E4M3FN); + CASE(F8E4M3B11FNUZ); + CASE(F8E4M3FNUZ); + CASE(F8E5M2); + CASE(F8E5M2FNUZ); +#undef CASE + case DTypeProto::KIND_STRING: + return DType(DType::Kind::kString); + default: + return DType(DType::Kind::kInvalid); + } +} + +DTypeProto DType::ToProto() const { + DTypeProto dtype_proto; + switch (kind()) { + case DType::Kind::kPred: + dtype_proto.set_kind(DTypeProto::KIND_PRED); + break; + case DType::Kind::kToken: + dtype_proto.set_kind(DTypeProto::KIND_TOKEN); + break; +#define CASE(X) \ + case DType::Kind::k##X: \ + dtype_proto.set_kind(DTypeProto::KIND_##X); \ + break; + CASE(S4); + CASE(S8); + CASE(S16); + CASE(S32); + CASE(S64); + CASE(U4); + CASE(U8); + CASE(U16); + CASE(U32); + CASE(U64); + CASE(F16); + CASE(F32); + CASE(F64); + CASE(BF16); + CASE(C64); + CASE(C128); + CASE(F8E4M3FN); + CASE(F8E4M3B11FNUZ); + CASE(F8E4M3FNUZ); + CASE(F8E5M2); + CASE(F8E5M2FNUZ); +#undef CASE + case DType::Kind::kString: + dtype_proto.set_kind(DTypeProto::KIND_STRING); + break; + default: + dtype_proto.set_kind(DTypeProto::KIND_UNSPECIFIED); + break; + } + return dtype_proto; +} + std::string DType::DebugString() const { switch (kind_) { case kInvalid: diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index 184e3597f56f1..517d35207705d 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.pb.h" + namespace xla { namespace ifrt { @@ -112,6 +115,12 @@ class DType { // std::nullopt if there is no fixed size. std::optional bit_size() const; + // Constructs `DType` from `DTypeProto`. + static absl::StatusOr FromProto(const DTypeProto& proto); + + // Returns a `DTypeProto` representation. + DTypeProto ToProto() const; + std::string DebugString() const; private: diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto new file mode 100644 index 0000000000000..b23ffa75c2a8b --- /dev/null +++ b/xla/python/ifrt/dtype.proto @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Data type kinds. Mirrors `xla::ifrt::DType`. +message DTypeProto { + enum Kind { + KIND_UNSPECIFIED = 0; + + // Predicates are two-state booleans. + KIND_PRED = 1; + + // Signed integral values of fixed width. + KIND_S4 = 21; + KIND_S8 = 2; + KIND_S16 = 3; + KIND_S32 = 4; + KIND_S64 = 5; + + // Unsigned integral values of fixed width. + KIND_U4 = 22; + KIND_U8 = 6; + KIND_U16 = 7; + KIND_U32 = 8; + KIND_U64 = 9; + + // Floating-point values of fixed width. + KIND_F16 = 10; + KIND_F32 = 11; + KIND_F64 = 12; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the + // exponent and 7 bits for the mantissa. + KIND_BF16 = 16; + + // Complex values of fixed width. + KIND_C64 = 15; // Paired F32 (real, imag), as in std::complex. + KIND_C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A token type threaded between side-effecting operations. Shapes of this + // dtype will have empty dimensions. + KIND_TOKEN = 17; + + KIND_F8E4M3FN = 20; + KIND_F8E4M3B11FNUZ = 23; + KIND_F8E4M3FNUZ = 25; + KIND_F8E5M2 = 19; + KIND_F8E5M2FNUZ = 24; + + // Variable-length string represented as raw bytes, as in `bytes` in Python, + // i.e., no encoding enforcement. String is not support in XLA. DType.Kind + // needs to match xla.PrimitiveType enum, so choose a large enum to avoid + // collision. + KIND_STRING = 99; + } + Kind kind = 1; +} diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc new file mode 100644 index 0000000000000..4295c5e09ce29 --- /dev/null +++ b/xla/python/ifrt/dtype_test.cc @@ -0,0 +1,40 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt/dtype.h" + +#include +#include "xla/python/ifrt/dtype.pb.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace { + +TEST(DTypeTest, FromToFromProto) { + for (int i = 0; i < DTypeProto::Kind_descriptor()->value_count(); ++i) { + DTypeProto proto; + proto.set_kind(static_cast( + DTypeProto::Kind_descriptor()->value(i)->number())); + TF_ASSERT_OK_AND_ASSIGN(DType dtype, DType::FromProto(proto)); + TF_ASSERT_OK_AND_ASSIGN(DType dtype_copy, + DType::FromProto(dtype.ToProto())); + EXPECT_EQ(dtype_copy, dtype); + } +} + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/executable.cc b/xla/python/ifrt/executable.cc index 0ae7bfd8a1547..77cabe7f6a938 100644 --- a/xla/python/ifrt/executable.cc +++ b/xla/python/ifrt/executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/executable.h b/xla/python/ifrt/executable.h index 885fd50834784..612827023d4da 100644 --- a/xla/python/ifrt/executable.h +++ b/xla/python/ifrt/executable.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" @@ -40,18 +41,19 @@ class Executable : public llvm::RTTIExtends { virtual absl::string_view name() const = 0; // Returns a fingerprint of this executable. - virtual StatusOr> Fingerprint() const = 0; + virtual absl::StatusOr> Fingerprint() const = 0; // Serializes this executable into a string. The compatibility of the // serialized executable is implementation-specific. - virtual StatusOr Serialize() const = 0; + virtual absl::StatusOr Serialize() const = 0; // The following APIs are taken from `xla::PjRtExecutable` for fast // prototyping. TODO(hyeontaek): Factor some of them out as // `XlaCompatibleExecutable`. virtual int num_devices() const = 0; virtual int64_t SizeOfGeneratedCodeInBytes() const = 0; - virtual StatusOr GetCompiledMemoryStats() const = 0; + virtual absl::StatusOr GetCompiledMemoryStats() + const = 0; // TODO(hyeontaek): Move the following XLA-specific methods to // pjrt_executable.h and put it in an `XlaCompatibleExecutable`. @@ -61,20 +63,22 @@ class Executable : public llvm::RTTIExtends { const = 0; // Returns a list of output `OpSharding`. virtual std::optional> GetOutputShardings() const = 0; - // Returns a list of parameter `xla::Layout`s. - virtual StatusOr> GetParameterLayouts() const = 0; - // Returns a list of output/result `xla::Layout`s. - virtual StatusOr> GetOutputLayouts() const = 0; + // Returns a list of parameter layouts. + virtual absl::StatusOr>> + GetParameterLayouts() const = 0; + // Returns a list of output/result layouts. + virtual absl::StatusOr>> + GetOutputLayouts() const = 0; // Returns an `HloModule` (optimized) per partition. - virtual StatusOr>> GetHloModules() - const = 0; + virtual absl::StatusOr>> + GetHloModules() const = 0; using CostAnalysisValue = xla::PjRtValueType; // Returns named values for cost properties of this executable (such as // operations, size of input/outputs, and run time estimate). Properties may // differ for different implementations and platforms. - virtual StatusOr> + virtual absl::StatusOr> GetCostAnalysis() const = 0; static char ID; // NOLINT @@ -94,11 +98,21 @@ class LoadedExecutable virtual absl::string_view name() const = 0; // Returns a fingerprint of this executable. - virtual StatusOr> Fingerprint() const = 0; + virtual absl::StatusOr> Fingerprint() const = 0; // Serializes this executable into a string. The compatibility of the // serialized executable is implementation-specific. - virtual StatusOr Serialize() const = 0; + virtual absl::StatusOr Serialize() const = 0; + + // Returns a future that becomes ready when the executable is ready to be + // used for execution. + // + // This can be used by implementations that support async compilation, where + // `Compiler::Compile()` returns an executable ~immediately and does heavy + // compilation work in the background. Implementations must still ensure that + // all other methods can be used even without explicitly waiting for the ready + // future (e.g., via blocking). + virtual Future GetReadyFuture() const = 0; // The following APIs are taken from `xla::PjRtExecutable` for fast // prototyping. @@ -106,7 +120,8 @@ class LoadedExecutable // TODO(hyeontaek): Factor some of them out as `XlaCompatibleExecutable`. virtual int num_devices() const = 0; virtual int64_t SizeOfGeneratedCodeInBytes() const = 0; - virtual StatusOr GetCompiledMemoryStats() const = 0; + virtual absl::StatusOr GetCompiledMemoryStats() + const = 0; // The following APIs are taken from `xla::PjRtLoadedExecutable` for fast // prototyping. @@ -118,24 +133,26 @@ class LoadedExecutable const = 0; // Returns a list of output OpSharding. virtual std::optional> GetOutputShardings() const = 0; - // Returns a list of parameter `xla::Layout`s. - virtual StatusOr> GetParameterLayouts() const = 0; - // Returns a list of output/result `xla::Layout`s. - virtual StatusOr> GetOutputLayouts() const = 0; + // Returns a list of parameter layouts. + virtual absl::StatusOr>> + GetParameterLayouts() const = 0; + // Returns a list of output/result layouts. + virtual absl::StatusOr>> + GetOutputLayouts() const = 0; // Return an HloModule (optimized) per partition. - virtual StatusOr>> GetHloModules() - const = 0; + virtual absl::StatusOr>> + GetHloModules() const = 0; // Returns a list of lists of memory kind strings for output. The returned // value is `[num_programs, num_output]`. The size of the outer list should be // equal to `GetHloModules()`. Under SPMD, one can use // `GetOutputMemoryKinds().front()`. - virtual StatusOr>> + virtual absl::StatusOr>> GetOutputMemoryKinds() const = 0; // Returns named values for cost properties of this executable (such as // operations, size of input/outputs, and run time estimate). Properties may // differ for different implementations and platforms. - virtual StatusOr< + virtual absl::StatusOr< absl::flat_hash_map> GetCostAnalysis() const = 0; @@ -169,7 +186,7 @@ class LoadedExecutable // incrementally. We need to have a stricter way to control this behavior // (e.g., having per-argument/output booleans or providing a separate barrier // API). - virtual StatusOr Execute( + virtual absl::StatusOr Execute( absl::Span> args, const ExecuteOptions& options, std::optional devices) = 0; diff --git a/xla/python/ifrt/future.cc b/xla/python/ifrt/future.cc index c0a71d20e52ca..c4032c6274804 100644 --- a/xla/python/ifrt/future.cc +++ b/xla/python/ifrt/future.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/future.h b/xla/python/ifrt/future.h index fa3b289510b45..358b38b05bb88 100644 --- a/xla/python/ifrt/future.h +++ b/xla/python/ifrt/future.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/future_test.cc b/xla/python/ifrt/future_test.cc index 649fe019f368a..5d61f84b72aab 100644 --- a/xla/python/ifrt/future_test.cc +++ b/xla/python/ifrt/future_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/host_callback.cc b/xla/python/ifrt/host_callback.cc index 5ad6393049b1c..e3b2d1833deba 100644 --- a/xla/python/ifrt/host_callback.cc +++ b/xla/python/ifrt/host_callback.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/host_callback.h b/xla/python/ifrt/host_callback.h index c8827b030b7b7..3432f6f037019 100644 --- a/xla/python/ifrt/host_callback.h +++ b/xla/python/ifrt/host_callback.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ class LoadedHostCallback // // TODO(hyeontaek): Change `Serialize()` to return `HostCallback` instead of a // serialized host callback directly. - virtual StatusOr Serialize() const = 0; + virtual absl::StatusOr Serialize() const = 0; static char ID; // NOLINT }; diff --git a/xla/python/ifrt/index.cc b/xla/python/ifrt/index.cc index ac6d48108b01e..bd065c0161e52 100644 --- a/xla/python/ifrt/index.cc +++ b/xla/python/ifrt/index.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/index.h b/xla/python/ifrt/index.h index b6c0382f46d10..02d30709512ce 100644 --- a/xla/python/ifrt/index.h +++ b/xla/python/ifrt/index.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/index_domain.cc b/xla/python/ifrt/index_domain.cc index 0bc30f32bed60..556a22e924af9 100644 --- a/xla/python/ifrt/index_domain.cc +++ b/xla/python/ifrt/index_domain.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/index_domain.h b/xla/python/ifrt/index_domain.h index 8796b30b41ba6..159c2aef38af5 100644 --- a/xla/python/ifrt/index_domain.h +++ b/xla/python/ifrt/index_domain.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/index_domain_test.cc b/xla/python/ifrt/index_domain_test.cc index dc78e6f747b97..b98c7fd5e5ca8 100644 --- a/xla/python/ifrt/index_domain_test.cc +++ b/xla/python/ifrt/index_domain_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/index_test.cc b/xla/python/ifrt/index_test.cc index c6262d051be22..69dedbee1f68f 100644 --- a/xla/python/ifrt/index_test.cc +++ b/xla/python/ifrt/index_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/BUILD b/xla/python/ifrt/ir/BUILD index 8080024b9b5bf..d404ecdd7ce61 100644 --- a/xla/python/ifrt/ir/BUILD +++ b/xla/python/ifrt/ir/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -12,6 +13,7 @@ td_library( "ifrt_interfaces.td", "ifrt_ops.td", ], + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "@llvm-project//mlir:AttrTdFiles", @@ -22,6 +24,7 @@ td_library( gentbl_cc_library( name = "ifrt_dialect_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -74,6 +77,7 @@ gentbl_cc_library( gentbl_cc_library( name = "ifrt_ops_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( ["-gen-op-decls"], @@ -92,14 +96,23 @@ gentbl_cc_library( gentbl_cc_library( name = "ifrt_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ + ( + ["-gen-attr-interface-decls"], + "ifrt_attr_interfaces.h.inc", + ), + ( + ["-gen-attr-interface-defs"], + "ifrt_attr_interfaces.cc.inc", + ), ( ["-gen-op-interface-decls"], - "ifrt_interfaces.h.inc", + "ifrt_op_interfaces.h.inc", ), ( ["-gen-op-interface-defs"], - "ifrt_interfaces.cc.inc", + "ifrt_op_interfaces.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -123,15 +136,21 @@ cc_library( "ifrt_ops.h", "sharding_param.h", ], + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ ":ifrt_dialect_inc_gen", ":ifrt_interfaces_inc_gen", ":ifrt_ops_inc_gen", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:CallOpInterfaces", # buildcleaner: keep "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@tsl//tsl/platform:errors", ], ) @@ -139,8 +158,10 @@ cc_library( name = "compiler", srcs = ["compiler.cc"], hdrs = ["compiler.h"], + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ + "//xla:statusor", "//xla/python/ifrt", "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//llvm:Support", diff --git a/xla/python/ifrt/ir/compiler.cc b/xla/python/ifrt/ir/compiler.cc index ab036ad6a182e..8922d23ec30f2 100644 --- a/xla/python/ifrt/ir/compiler.cc +++ b/xla/python/ifrt/ir/compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace ifrt { char IfrtIRProgram::ID = 0; char IfrtIRCompileOptions::ID = 0; -StatusOr> GetIfrtIRCompileOptions( +absl::StatusOr> GetIfrtIRCompileOptions( std::unique_ptr options) { if (!llvm::isa(options.get())) { return absl::InvalidArgumentError("options must be IfrtIRCompileOptions"); diff --git a/xla/python/ifrt/ir/compiler.h b/xla/python/ifrt/ir/compiler.h index 87511c0809c37..94dd1e19f15a5 100644 --- a/xla/python/ifrt/ir/compiler.h +++ b/xla/python/ifrt/ir/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/executable.h" +#include "xla/statusor.h" namespace xla { namespace ifrt { @@ -47,9 +48,13 @@ struct IfrtIRCompileOptions explicit IfrtIRCompileOptions( std::vector device_assignments, absl::flat_hash_map loaded_exec_binding = - {}) + {}, + std::shared_ptr>> + compile_options_overrides = {}) : device_assignments(std::move(device_assignments)), - loaded_exec_binding(std::move(loaded_exec_binding)) {} + loaded_exec_binding(std::move(loaded_exec_binding)), + compile_options_overrides(std::move(compile_options_overrides)) {} // Map from logical device ids in MLIR module to runtime device ids obtained // from IFRT client. @@ -60,11 +65,18 @@ struct IfrtIRCompileOptions // outlive the LoadedExecutable to be compiled. absl::flat_hash_map loaded_exec_binding; + // Mapping from values of `ifrt.compile_option_key` attribute of a `CallOp` to + // compile options. If a `CallOp` does not have have the attribute set or does + // not have an entry in this map then default compile options are used. + std::shared_ptr>> + compile_options_overrides; + static char ID; // NOLINT }; // Gets `xla::ifrt::IfrtIRCompileOptions` from `xla::ifrt::CompileOptions`. -StatusOr> GetIfrtIRCompileOptions( +absl::StatusOr> GetIfrtIRCompileOptions( std::unique_ptr options); } // namespace ifrt diff --git a/xla/python/ifrt/ir/constants.h b/xla/python/ifrt/ir/constants.h index 9b539dc8db9c2..cd1bc06bf85ee 100644 --- a/xla/python/ifrt/ir/constants.h +++ b/xla/python/ifrt/ir/constants.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,11 @@ inline constexpr llvm::StringLiteral kIfrtDonatedArgAttrName = "ifrt.donated"; // in "local" view (i.e., already sharded). inline constexpr llvm::StringLiteral kIfrtLocalViewAttrName = "ifrt.local_view"; +// Name of StringAttr on CallOp used to store an optional key to use into a +// mapping of user-provided compile options. +inline constexpr llvm::StringLiteral kIfrtCompileOptionsKey = + "ifrt.compile_options_key"; + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/ir/ifrt_dialect.cc b/xla/python/ifrt/ir/ifrt_dialect.cc index a86c88e767625..15e7cd2b22fb5 100644 --- a/xla/python/ifrt/ir/ifrt_dialect.cc +++ b/xla/python/ifrt/ir/ifrt_dialect.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" @@ -35,6 +36,7 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_interfaces.h" #include "xla/python/ifrt/ir/ifrt_ops.h" // Generated definitions. @@ -118,45 +120,90 @@ mlir::LogicalResult IfrtDialect::verifyRegionArgAttribute( return mlir::success(); } -mlir::LogicalResult IfrtShardingAttr::verify( +//===----------------------------------------------------------------------===// +// IfrtShardingParamAttr +//===----------------------------------------------------------------------===// + +mlir::LogicalResult IfrtShardingParamAttr::verify( + llvm::function_ref emitError, + ShardingParam sharding_param) { + return sharding_param.verify(emitError); +} + +mlir::LogicalResult IfrtShardingParamAttr::CanApplyTo( llvm::function_ref emitError, - ShardingParam sharding) { - return sharding.verify(emitError); + mlir::RankedTensorType shape, llvm::ArrayRef device_ids) const { + return getSharding().CanApplyTo(emitError, shape, device_ids); +} + +absl::StatusOr> +IfrtShardingParamAttr::GlobalShapeFromLocalShape( + llvm::ArrayRef local_shape) const { + return getSharding().GlobalShapeFromLocalShape(local_shape); +} + +absl::StatusOr> +IfrtShardingParamAttr::LocalShapeFromGlobalShape( + llvm::ArrayRef global_shape) const { + return getSharding().LocalShapeFromGlobalShape(global_shape); +} + +// Returns the number of devices the sharding applies to. +int IfrtShardingParamAttr::NumDevices() const { + return getSharding().NumDevices(); +}; + +//===----------------------------------------------------------------------===// +// IfrtUnspecifiedShardingAttr +//===----------------------------------------------------------------------===// + +mlir::LogicalResult IfrtUnspecifiedShardingAttr::CanApplyTo( + llvm::function_ref emitError, + mlir::RankedTensorType shape, llvm::ArrayRef device_ids) const { + // The unspecified sharding can be applied to any array. + return mlir::success(); } +absl::StatusOr> +IfrtUnspecifiedShardingAttr::GlobalShapeFromLocalShape( + llvm::ArrayRef local_shape) const { + // Unspecified sharding does not change the shape. + llvm::SmallVector global_shape(local_shape.begin(), + local_shape.end()); + return global_shape; +} + +absl::StatusOr> +IfrtUnspecifiedShardingAttr::LocalShapeFromGlobalShape( + llvm::ArrayRef global_shape) const { + // Unspecified sharding does not change the shape. + llvm::SmallVector local_shape(global_shape.begin(), + global_shape.end()); + return local_shape; +} + +int IfrtUnspecifiedShardingAttr::NumDevices() const { return 0; } + +//===----------------------------------------------------------------------===// +// IfrtArrayType +//===----------------------------------------------------------------------===// + +// Returns an array of logical device ids. llvm::ArrayRef IfrtArrayType::getDevices() const { return getDevicesAttr().getIds(); } mlir::LogicalResult IfrtArrayType::verify( llvm::function_ref emitError, - mlir::RankedTensorType shape, ShardingParam sharding, + mlir::RankedTensorType shape, IfrtShardingAttrInterface sharding_attr, IfrtDevicesAttr devices) { - if (mlir::failed(sharding.verify(emitError))) { - return mlir::failure(); - } - - if (shape.getRank() != sharding.dim_shards().size()) { - return emitError() << "Requires dim shards to have the same rank as the " - "array. Array rank is " - << shape.getRank() << " vs dim shards rank of " - << sharding.dim_shards().size(); - } - - int devices_in_mesh = 1; - for (const int axis_size : sharding.minor_to_major().axis_sizes) { - devices_in_mesh *= axis_size; - } - if (llvm::ArrayRef ids = devices.getIds(); - devices_in_mesh != ids.size()) { - return emitError() << "Requires the same amount of `devices` and from " - "`sharding`. Actual: " - << ids.size() << " vs " << devices_in_mesh; - } - - return mlir::success(); + return sharding_attr.CanApplyTo(emitError, shape, devices.getIds()); } +//===----------------------------------------------------------------------===// +// IfrtDevicesAttr +//===----------------------------------------------------------------------===// + IfrtDevicesAttr::operator llvm::ArrayRef() const { return getIds(); } mlir::LogicalResult IfrtDevicesAttr::verify( @@ -165,10 +212,10 @@ mlir::LogicalResult IfrtDevicesAttr::verify( llvm::SmallSet device_set; for (int id : ids) { if (id < 0) { - return emitError() << "Device list has negative id " << id; + return emitError() << "Device list has negative logical id " << id; } if (auto [unused_it, inserted] = device_set.insert(id); !inserted) { - return emitError() << "Device list has duplicate id " << id; + return emitError() << "Device list has duplicate logical id " << id; } } diff --git a/xla/python/ifrt/ir/ifrt_dialect.h b/xla/python/ifrt/ir/ifrt_dialect.h index 5b0259b9fc476..a8fcd0bc500b5 100644 --- a/xla/python/ifrt/ir/ifrt_dialect.h +++ b/xla/python/ifrt/ir/ifrt_dialect.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #define XLA_PYTHON_IFRT_IR_IFRT_DIALECT_H_ #include "mlir/IR/Dialect.h" // from @llvm-project +#include "xla/python/ifrt/ir/ifrt_interfaces.h" #include "xla/python/ifrt/ir/sharding_param.h" // Generated definitions. diff --git a/xla/python/ifrt/ir/ifrt_dialect.td b/xla/python/ifrt/ir/ifrt_dialect.td index 248827a83efed..a9cec4b6d218a 100644 --- a/xla/python/ifrt/ir/ifrt_dialect.td +++ b/xla/python/ifrt/ir/ifrt_dialect.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,11 @@ limitations under the License. include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/DialectBase.td" +include "xla/python/ifrt/ir/ifrt_interfaces.td" + +//===--------------------------------------------------------------------------- +// Dialect +//===--------------------------------------------------------------------------- def Ifrt_Dialect : Dialect { let name = "ifrt"; @@ -31,14 +36,19 @@ def Ifrt_Dialect : Dialect { let usePropertiesForAttributes = 0; } -def Ifrt_ShardingParameter : - AttrOrTypeParameter<"::xla::ifrt::ShardingParam", ""> { - let parser = "::xla::ifrt::ShardingParam::Parse($_parser)"; -} +//===--------------------------------------------------------------------------- +// Attributes +//===--------------------------------------------------------------------------- def Ifrt_DevicesAttr : AttrDef { let mnemonic = "devices"; - let summary = "Represents a list of device ids."; + let summary = "Represents a list of logical device ids."; + let description = [{ + Logical device ids are 0-based indices within a list of concrete devices + an IFRT module is compiled for. These ids are logical in order to + ensure that an IFRT IR module is hermetically specified, and thus can + be compiled for different concrete device ids without modifications. + }]; let parameters = (ins ArrayRefParameter<"int">:$ids); let assemblyFormat = "`[` $ids `]`"; @@ -54,8 +64,15 @@ def Ifrt_DevicesAttr : AttrDef { }]; } -def Ifrt_ShardingAttr : AttrDef { - let mnemonic = "sharding"; +def Ifrt_ShardingParameter : + AttrOrTypeParameter<"::xla::ifrt::ShardingParam", ""> { + let parser = "::xla::ifrt::ShardingParam::Parse($_parser)"; +} + +def Ifrt_ShardingParamAttr : AttrDef +]> { + let mnemonic = "sharding_param"; let summary = "ShardingParam as an attribute."; let parameters = (ins Ifrt_ShardingParameter:$sharding); @@ -64,31 +81,65 @@ def Ifrt_ShardingAttr : AttrDef { let genVerifyDecl = 1; } +def Ifrt_UnspecifiedShardingAttr : AttrDef +]> { + let mnemonic = "sharding_unspecified"; + let summary = "Attribute to be used when sharding is unspecified."; + + let parameters = (ins); + let assemblyFormat = ""; + + let genVerifyDecl = 1; +} + +//===--------------------------------------------------------------------------- +// Types +//===--------------------------------------------------------------------------- + def Ifrt_ArrayType : TypeDef { let mnemonic = "array"; let summary = "An Ifrt array sharded on a set of devices."; let parameters = (ins Builtin_RankedTensor:$shape, - Ifrt_ShardingParameter:$sharding, + "::xla::ifrt::IfrtShardingAttrInterface":$sharding_attr, Ifrt_DevicesAttr:$devices_attr); let builders = [ + TypeBuilder<(ins + "::mlir::RankedTensorType":$shape, + "::xla::ifrt::IfrtShardingAttrInterface":$sharding_attr, + "::llvm::ArrayRef":$devices), [{ + return Base::get( + $_ctxt, shape, sharding_attr, + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + }]>, + TypeBuilder<(ins + "::mlir::RankedTensorType":$shape, + "::llvm::ArrayRef":$devices), [{ + return Base::get( + $_ctxt, shape, ::xla::ifrt::IfrtUnspecifiedShardingAttr::get($_ctxt), + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + }]>, TypeBuilder<(ins "::mlir::RankedTensorType":$shape, "::xla::ifrt::ShardingParam":$sharding, "::llvm::ArrayRef":$devices), [{ - return Base::get($_ctxt, shape, sharding, - ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); + return Base::get( + $_ctxt, shape, + ::xla::ifrt::IfrtShardingParamAttr::get($_ctxt, sharding), + ::xla::ifrt::IfrtDevicesAttr::get($_ctxt, devices)); }]> ]; - let assemblyFormat = "`<` $shape`,` $sharding`,` $devices_attr`>`"; + let assemblyFormat = "`<` $shape`,` $sharding_attr`,` $devices_attr`>`"; let genVerifyDecl = 1; let extraClassDeclaration = [{ - // Get device ids from `devices_attr`. + // Get logical device ids from `devices_attr`. ::llvm::ArrayRef getDevices() const; }]; } diff --git a/xla/python/ifrt/ir/ifrt_interfaces.cc b/xla/python/ifrt/ir/ifrt_interfaces.cc index f6d0473fa1e74..a079f75fbdd59 100644 --- a/xla/python/ifrt/ir/ifrt_interfaces.cc +++ b/xla/python/ifrt/ir/ifrt_interfaces.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,12 @@ limitations under the License. #include "xla/python/ifrt/ir/constants.h" // Generated definitions. -#include "xla/python/ifrt/ir/ifrt_interfaces.cc.inc" + +#define GET_ATTR_INTERFACE_CLASSES +#include "xla/python/ifrt/ir/ifrt_attr_interfaces.cc.inc" + +#define GET_OP_INTERFACE_CLASSES +#include "xla/python/ifrt/ir/ifrt_op_interfaces.cc.inc" namespace mlir { namespace OpTrait { diff --git a/xla/python/ifrt/ir/ifrt_interfaces.h b/xla/python/ifrt/ir/ifrt_interfaces.h index 5dadf8951b08c..b496f4b57d3be 100644 --- a/xla/python/ifrt/ir/ifrt_interfaces.h +++ b/xla/python/ifrt/ir/ifrt_interfaces.h @@ -1,5 +1,6 @@ #include "xla/python/ifrt/ir/constants.h" -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +#include "xla/python/ifrt/ir/ifrt_dialect.h" +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -80,8 +81,15 @@ class IfrtCallLikeTrait { } // namespace OpTrait } // namespace mlir +// IWYU pragma: begin_exports + // Generated definitions. +#define GET_ATTR_INTERFACE_CLASSES +#include "xla/python/ifrt/ir/ifrt_attr_interfaces.h.inc" + #define GET_OP_INTERFACE_CLASSES -#include "xla/python/ifrt/ir/ifrt_interfaces.h.inc" // IWYU pragma: export +#include "xla/python/ifrt/ir/ifrt_op_interfaces.h.inc" + +// IWYU pragma: end_exports #endif // XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_H_ diff --git a/xla/python/ifrt/ir/ifrt_interfaces.td b/xla/python/ifrt/ir/ifrt_interfaces.td index 738d5b445830e..f1a94a517b91b 100644 --- a/xla/python/ifrt/ir/ifrt_interfaces.td +++ b/xla/python/ifrt/ir/ifrt_interfaces.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,10 @@ limitations under the License. include "mlir/IR/OpBase.td" +//===--------------------------------------------------------------------------- +// Op interfaces +//===--------------------------------------------------------------------------- + def Ifrt_SpmdExpandableInterface : OpInterface<"IfrtSpmdExpandable"> { let cppNamespace = "::xla::ifrt"; @@ -63,4 +67,49 @@ def NestedInIfrtFunc : NativeOpTrait<"xla::ifrt::NestedInIfrtFuncTrait">; class IfrtCallLike : ParamNativeOpTrait<"xla::ifrt::IfrtCallLikeTrait", callee_op_type>; +//===--------------------------------------------------------------------------- +// Attribute interfaces +//===--------------------------------------------------------------------------- + +class Ifrt_AttrInterface : AttrInterface { + let cppNamespace = "::xla::ifrt"; +} + +def Ifrt_ShardingAttrInterface : Ifrt_AttrInterface<"IfrtShardingAttrInterface"> { + let description = [{ + Interface that all IFRT IR sharding attributes must implement. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/"Verifies if the sharding can be applied to the array.", + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"CanApplyTo", + /*args=*/(ins + "::llvm::function_ref":$emitError, + "::mlir::RankedTensorType":$shape, + "llvm::ArrayRef":$device_ids + ) + >, + InterfaceMethod< + /*desc=*/"Returns the shape of the global array from a local array shape.", + /*retTy=*/"::absl::StatusOr>", + /*methodName=*/"GlobalShapeFromLocalShape", + /*args=*/(ins "llvm::ArrayRef":$local_shape) + >, + InterfaceMethod< + /*desc=*/"Returns the shape of the local array from a global array shape.", + /*retTy=*/"::absl::StatusOr>", + /*methodName=*/"LocalShapeFromGlobalShape", + /*args=*/(ins "llvm::ArrayRef":$global_shape) + >, + InterfaceMethod< + /*desc=*/"Returns the number of devices this sharding applied to.", + /*retTy=*/"int", + /*methodName=*/"NumDevices", + /*args=*/(ins) + > + ]; +} + #endif // XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_TD_ diff --git a/xla/python/ifrt/ir/ifrt_ops.cc b/xla/python/ifrt/ir/ifrt_ops.cc index a8bc0dfbe92db..c61e3b82b7e59 100644 --- a/xla/python/ifrt/ir/ifrt_ops.cc +++ b/xla/python/ifrt/ir/ifrt_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,10 +32,12 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_interfaces.h" // Generated definitions. #define GET_OP_CLASSES @@ -61,18 +63,16 @@ mlir::FailureOr GetGlobalShape(mlir::Value value) { } mlir::FailureOr GetGlobalShapeFromLocal( - mlir::Type type, ShardingParam shard_param) { + mlir::Type type, IfrtShardingAttrInterface sharding_attr) { if (auto local_ranked_tensor = type.dyn_cast()) { - llvm::SmallVector global_shape; - auto local_shape = local_ranked_tensor.getShape(); - if (local_shape.size() != shard_param.dim_shards().size()) { + auto global_shape = + sharding_attr.GlobalShapeFromLocalShape(local_ranked_tensor.getShape()); + if (global_shape.ok()) { + return mlir::RankedTensorType::get(global_shape.value(), + local_ranked_tensor.getElementType()); + } else { return mlir::failure(); } - for (auto [idx, dim_shard] : llvm::enumerate(shard_param.dim_shards())) { - global_shape.push_back(dim_shard * local_shape[idx]); - } - return mlir::RankedTensorType::get(global_shape, - local_ranked_tensor.getElementType()); } else { // IFRT arrays cannot be in the local view. return mlir::failure(); @@ -122,7 +122,7 @@ mlir::LogicalResult VerifyGlobalLocalShapesEquivalent( // Convert from local shape to global shape using the sharding provided // by the CallOp func signature. mlir::FailureOr callee_shape = - GetGlobalShapeFromLocal(callee_type, array.getSharding()); + GetGlobalShapeFromLocal(callee_type, array.getShardingAttr()); if (mlir::failed(callee_shape)) { return op->emitOpError() << "fails to get global shape from " << callee_mnemonic << ": " << callee_type; diff --git a/xla/python/ifrt/ir/ifrt_ops.h b/xla/python/ifrt/ir/ifrt_ops.h index 70aeeb8d6b098..312b7c4966d54 100644 --- a/xla/python/ifrt/ir/ifrt_ops.h +++ b/xla/python/ifrt/ir/ifrt_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/ifrt_ops.td b/xla/python/ifrt/ir/ifrt_ops.td index 1bd2b248e77e4..b6fb4ba10f144 100644 --- a/xla/python/ifrt/ir/ifrt_ops.td +++ b/xla/python/ifrt/ir/ifrt_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/sharding_param.cc b/xla/python/ifrt/ir/sharding_param.cc index 6c26f947a8ba9..67bec9cb2abe0 100644 --- a/xla/python/ifrt/ir/sharding_param.cc +++ b/xla/python/ifrt/ir/sharding_param.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,14 +18,22 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tsl/platform/errors.h" namespace xla { namespace ifrt { @@ -66,14 +74,37 @@ void PopulateDevices(llvm::ArrayRef permutation, } // namespace +absl::Status ShardingParam::MinorToMajor::verify() const { + if (permutation.size() != axis_sizes.size() || axis_sizes.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect same non-zero size for `permutation` and `axis_sizes`. Actual ", + permutation.size(), " vs ", axis_sizes.size())); + } + llvm::DenseSet permutation_set(permutation.begin(), permutation.end()); + if (permutation_set.size() != permutation.size()) { + return absl::InvalidArgumentError( + absl::StrCat("`permutation` [", absl::StrJoin(permutation, ","), + "] has duplicate values")); + } + for (const int index : permutation) { + if (index < 0 || index >= axis_sizes.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Out of range axis ", index, " to the mesh of [", + absl::StrJoin(permutation, ","), "] on ", + absl::StrJoin(axis_sizes, "x"))); + } + } + return absl::OkStatus(); +} + mlir::LogicalResult ShardingParam::MinorToMajor::verify( llvm::function_ref emit_error) const { - if (permutation.size() != axis_sizes.size() || axis_sizes.empty()) { - return emit_error() << "Expect same non-zero size for `permutation` and " - "`axis_sizes`. Actual " - << permutation.size() << " vs " << axis_sizes.size(); + auto status = verify(); + if (status.ok()) { + return mlir::success(); + } else { + return emit_error() << status.message(); } - return mlir::success(); } void ShardingParam::MinorToMajor::ToDeviceList( @@ -120,12 +151,8 @@ mlir::FailureOr ShardingParam::Parse( return ShardingParam(dim_shards, minor_to_major); } -mlir::LogicalResult ShardingParam::verify( - llvm::function_ref emit_error) const { - if (mlir::failed(minor_to_major().verify(emit_error))) { - return mlir::failure(); - } - +absl::Status ShardingParam::verify() const { + TF_RETURN_IF_ERROR(minor_to_major().verify()); int dim_index = 0; int cum_size = 1; for (const int index : minor_to_major().permutation) { @@ -135,19 +162,10 @@ mlir::LogicalResult ShardingParam::verify( if (dim_index == dim_shards().size()) { break; } - if (index < 0 || index >= minor_to_major().axis_sizes.size()) { - return emit_error() << "Out of range axis " << index << " to the mesh of " - << minor_to_major().permutation << " on " - << minor_to_major().axis_sizes; - } - cum_size *= minor_to_major().axis_sizes[index]; - if (cum_size > dim_shards()[dim_index]) { - return emit_error() << "Dimension #" << dim_index << " of " - << dim_shards()[dim_index] - << " shards can't be assigned to the axes"; - } else if (cum_size == dim_shards()[dim_index]) { - cum_size = 1; + while (dim_index < dim_shards().size() && + cum_size % dim_shards()[dim_index] == 0) { + cum_size /= dim_shards()[dim_index]; dim_index++; } } @@ -155,12 +173,22 @@ mlir::LogicalResult ShardingParam::verify( dim_index++; } if (dim_index != dim_shards().size()) { - return emit_error() << "Can't shard the dims " << dim_shards() - << " to the mesh of " << minor_to_major().permutation - << " on " << minor_to_major().axis_sizes; + return absl::InvalidArgumentError(absl::StrCat( + "Can't shard the dims ", absl::StrJoin(dim_shards(), "x"), + " to the mesh of [", absl::StrJoin(minor_to_major().permutation, ","), + "] on ", absl::StrJoin(minor_to_major().axis_sizes, "x"))); } + return absl::OkStatus(); +} - return mlir::success(); +mlir::LogicalResult ShardingParam::verify( + llvm::function_ref emit_error) const { + auto status = verify(); + if (status.ok()) { + return mlir::success(); + } else { + return emit_error() << status.message(); + } } std::string ShardingParam::DebugString() const { @@ -170,6 +198,70 @@ std::string ShardingParam::DebugString() const { return result; } +mlir::LogicalResult ShardingParam::CanApplyTo( + llvm::function_ref emitError, + mlir::RankedTensorType shape, llvm::ArrayRef device_ids) const { + if (mlir::failed(verify(emitError))) { + return mlir::failure(); + } + + if (shape.getRank() != dim_shards().size()) { + return emitError() << "Requires dim shards to have the same rank as the " + "array. Array rank is " + << shape.getRank() << " vs dim shards rank of " + << dim_shards().size(); + } + + auto devices_in_mesh = NumDevices(); + if (devices_in_mesh != device_ids.size()) { + return emitError() << "Requires the same amount of `devices` and from " + "`sharding`. Actual: " + << device_ids.size() << " vs " << devices_in_mesh; + } + + return mlir::success(); +} + +absl::StatusOr> +ShardingParam::GlobalShapeFromLocalShape( + llvm::ArrayRef local_shape) const { + llvm::SmallVector global_shape; + if (local_shape.size() != dim_shards().size()) { + return absl::InvalidArgumentError( + "Rank of local tensor differs from rank of `dim_shards`."); + } + for (auto [idx, dim_shard] : llvm::enumerate(dim_shards())) { + global_shape.push_back(dim_shard * local_shape[idx]); + } + return global_shape; +} + +absl::StatusOr> +ShardingParam::LocalShapeFromGlobalShape( + llvm::ArrayRef global_shape) const { + auto num_shards = dim_shards(); + llvm::SmallVector local_shape; + local_shape.reserve(global_shape.size()); + for (int i = 0; i < num_shards.size(); ++i) { + if (global_shape[i] % num_shards[i] != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Global shape is not divisible by the number of shards in dimension ", + i, ". Global size: ", global_shape[i], + ", number of shards: ", num_shards[i], ".")); + } + local_shape.push_back(global_shape[i] / num_shards[i]); + } + return local_shape; +} + +int ShardingParam::NumDevices() const { + int devices_in_mesh = 1; + for (const int axis_size : minor_to_major().axis_sizes) { + devices_in_mesh *= axis_size; + } + return devices_in_mesh; +} + llvm::hash_code hash_value(ShardingParam sharding) { return sharding.hash_value(); } diff --git a/xla/python/ifrt/ir/sharding_param.h b/xla/python/ifrt/ir/sharding_param.h index 5f29e7e246769..b2b4020d4a2ee 100644 --- a/xla/python/ifrt/ir/sharding_param.h +++ b/xla/python/ifrt/ir/sharding_param.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,11 +19,14 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -66,8 +69,6 @@ namespace ifrt { // in axis-0. // // See `support` directory for conversions with other sharding annotations. -// -// TODO(b/271129892): Should we support maximal sharding here? class ShardingParam { public: // Represents a permutation of mesh dimensions from minor to major. @@ -79,6 +80,7 @@ class ShardingParam { // The size of mesh dimensions before the permutation. llvm::SmallVector axis_sizes; + absl::Status verify() const; mlir::LogicalResult verify( llvm::function_ref emit_error) const; @@ -94,9 +96,24 @@ class ShardingParam { : dim_shards_(dim_shards), minor_to_major_(minor_to_major) {} static mlir::FailureOr Parse(mlir::AsmParser& ods_parser); + absl::Status verify() const; mlir::LogicalResult verify( llvm::function_ref emit_error) const; + // Verifies if the sharding can be applied to the array. + mlir::LogicalResult CanApplyTo( + llvm::function_ref emitError, + mlir::RankedTensorType shape, llvm::ArrayRef device_ids) const; + + absl::StatusOr> GlobalShapeFromLocalShape( + llvm::ArrayRef local_shape) const; + + absl::StatusOr> LocalShapeFromGlobalShape( + llvm::ArrayRef global_shape) const; + + // Returns the number of devices the array is sharded over. + int NumDevices() const; + llvm::ArrayRef dim_shards() const { return dim_shards_; } const MinorToMajor& minor_to_major() const { return minor_to_major_; } diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index a65fad7799fe7..b5ba2b9553ad4 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -1,4 +1,4 @@ -load("//xla:glob_lit_test.bzl", "glob_lit_tests") +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") load("//xla:xla.bzl", "xla_cc_test") package( @@ -6,20 +6,31 @@ package( licenses = ["notice"], ) -glob_lit_tests( +lit_test_suite( name = "all_tests", - data = [":test_utilities"], - driver = "//xla:run_lit.sh", - test_file_exts = ["mlir"], -) - -filegroup( - name = "test_utilities", - testonly = True, - data = [ + srcs = enforce_glob( + [ + "ifrt_duplicated_callee_elimination.mlir", + "ifrt_verify_sharding_specified.mlir", + "spmd_expansion.mlir", + "spmd_interface_verification.mlir", + "verify_array.mlir", + "verify_assemble.mlir", + "verify_attrs.mlir", + "verify_call.mlir", + "verify_call_loaded_executable.mlir", + "verify_disassemble.mlir", + "verify_loaded_executable.mlir", + "verify_reshard.mlir", + ], + include = [ + "*.mlir", + ], + ), + cfg = "//xla:lit.cfg.py", + tools = [ ":ifrt-opt", "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", ], ) @@ -27,6 +38,7 @@ cc_binary( name = "ifrt-opt", srcs = ["ifrt-opt.cc"], deps = [ + "//xla/mlir_hlo:hlo_dialect_registration", "//xla/python/ifrt/ir", "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", "//xla/python/ifrt/ir/transforms:passes", diff --git a/xla/python/ifrt/ir/tests/executable_impl_test_base.cc b/xla/python/ifrt/ir/tests/executable_impl_test_base.cc index 2f85ef0cd24c1..dca71fb3a5fa0 100644 --- a/xla/python/ifrt/ir/tests/executable_impl_test_base.cc +++ b/xla/python/ifrt/ir/tests/executable_impl_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/tests/executable_impl_test_base.h b/xla/python/ifrt/ir/tests/executable_impl_test_base.h index 9ff9d3db0a777..26f9b5840b9ce 100644 --- a/xla/python/ifrt/ir/tests/executable_impl_test_base.h +++ b/xla/python/ifrt/ir/tests/executable_impl_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc b/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc index b3170deeb0a11..678677ead34d3 100644 --- a/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc +++ b/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,7 +50,8 @@ class IfrtIrExecutableImplTest TEST_F(IfrtIrExecutableImplTest, CallXla) { std::string source = R"( -!array = !ifrt.array, 2x1 to [0] on 2, [0,1]> +!array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> module { func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] @@ -97,13 +98,15 @@ module { TEST_F(IfrtIrExecutableImplTest, Reshard) { std::string source = R"( module { - func.func @main(%arg0: !ifrt.array, 1 to [0] on 1, [0]>) - -> !ifrt.array, 1 to [0] on 1, [1]> + func.func @main(%arg0: !ifrt.array, + #ifrt.sharding_param<1 to [0] on 1>, [0]>) + -> !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [1]> attributes {ifrt.function} { %0 = "ifrt.Reshard"(%arg0) - : (!ifrt.array, 1 to [0] on 1, [0]>) - -> !ifrt.array, 1 to [0] on 1, [1]> - return %0 : !ifrt.array, 1 to [0] on 1, [1]> + : (!ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [0]>) + -> !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [1]> + return %0 : !ifrt.array, + #ifrt.sharding_param<1 to [0] on 1>, [1]> } } )"; @@ -137,7 +140,8 @@ module { TEST_F(IfrtIrExecutableImplTest, ZeroInput) { std::string source = R"( -!array = !ifrt.array, 2x1 to [0] on 2, [0,1]> +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> module { func.func @main() -> !array attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @one() on devices [0,1] : () -> !array @@ -172,7 +176,8 @@ module { TEST_F(IfrtIrExecutableImplTest, ZeroOutput) { std::string source = R"( -!array = !ifrt.array, 2x1 to [0] on 2, [0,1]> +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> module { func.func @main(%arg0: !array) attributes {ifrt.function} { %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] : (!array) -> () @@ -214,7 +219,8 @@ module { TEST_F(IfrtIrExecutableImplTest, BufferDonation) { std::string source = R"( -!array = !ifrt.array, 2x1 to [0] on 2, [0,1]> +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> module { func.func @main(%arg0: !array {ifrt.donated}) -> !array attributes {ifrt.function} { @@ -299,7 +305,8 @@ module { std::make_unique(std::move(xla_options)))); std::string source = R"( -!array = !ifrt.array, 2x1 to [0] on 2, [0,1]> +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> module { func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { %0, %ctrl_0 = ifrt.CallLoadedExecutable @add_one(%arg0) diff --git a/xla/python/ifrt/ir/tests/ifrt-opt.cc b/xla/python/ifrt/ir/tests/ifrt-opt.cc index 5cde072cc2a40..38a8979a14a98 100644 --- a/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" #include "xla/python/ifrt/ir/transforms/passes.h" @@ -23,6 +24,7 @@ limitations under the License. int main(int argc, char** argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); registry.insert(); xla::ifrt::registerIfrtIrPasses(); diff --git a/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir b/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir index fe6dee2e68736..5cf62e23e0e59 100644 --- a/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir @@ -1,7 +1,8 @@ // RUN: ifrt-opt %s -ifrt-duplicated-callee-elimination | FileCheck %s // CHECK-LABEL: @main -func.func @main(%arg0: !ifrt.array, 1x1 to [0] on 1, [0]>) +func.func @main(%arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>) attributes {ifrt.function} { // CHECK: %[[CTRL:.+]] = ifrt.Call @callee %ctrl_0 = ifrt.Call @callee() on devices [0,1] : () -> () @@ -17,7 +18,8 @@ func.func @main(%arg0: !ifrt.array, 1x1 to [0] on 1, [0]>) // CHECK-NOT: ifrt.Call @callee // CHECK: ifrt.Call @callee_different_signature %ctrl_4 = ifrt.Call @callee_different_signature(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 1, [0]>) -> () + : (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>) -> () return } diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir new file mode 100644 index 0000000000000..6be0131833826 --- /dev/null +++ b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir @@ -0,0 +1,69 @@ +// RUN: ifrt-opt %s -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @good_arrays +#sharding = #ifrt.sharding_param<2 to [0] on 2> +module @good_arrays { + func.func @main(%arg0: !ifrt.array, #sharding, [0,1]>) + -> !ifrt.array, #sharding, [2,3]> + attributes {ifrt.function} { + %0, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + : (!ifrt.array, #sharding, [0,1]>) + -> !ifrt.array, #sharding, [0,1]> + %1 = "ifrt.Reshard"(%0) + : (!ifrt.array, #sharding, [0,1]>) + -> !ifrt.array, #sharding, [2,3]> + return %1 : !ifrt.array, #sharding, [2,3]> + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +module @main_arg_sharding_unspecified { + // expected-error @+1 {{'func.func' op argument 0 has unspecified sharding.}} + func.func @main( + %arg0: !ifrt.array, #ifrt.sharding_unspecified, [0,1]>) + attributes {ifrt.function} { + return + } +} + +// ----- + +#sharding = #ifrt.sharding_param<2 to [0] on 2> +module @main_result_sharding_unspecified { + func.func @main() + -> !ifrt.array, #ifrt.sharding_unspecified, [0,1]> + attributes {ifrt.function} { + // expected-error @+1 {{'ifrt.Call' op result 0 has unspecified sharding.}} + %0, %ctrl_1 = ifrt.Call @create_array() on devices [0,1] + : () -> !ifrt.array, #ifrt.sharding_unspecified, [0,1]> + return %0 : !ifrt.array, #ifrt.sharding_unspecified, [0,1]> + } + + func.func private @create_array() -> tensor<2xi32> { + %0 = mhlo.constant dense<1> : tensor<2xi32> + return %0 : tensor<2xi32> + } +} + +// ----- + +#sharding = #ifrt.sharding_param<2 to [0] on 2> +module @reshard_with_unspecified_sharding { + func.func @main(%arg0: !ifrt.array, #sharding, [0,1]>) + -> !ifrt.array, #sharding, [2,3]> + attributes {ifrt.function} { + // expected-error @+1 {{'ifrt.Reshard' op result 0 has unspecified sharding.}} + %0 = ifrt.Reshard(%arg0) + : (!ifrt.array, #sharding, [0,1]>) + -> !ifrt.array, #ifrt.sharding_unspecified, [2,3]> + %1 = ifrt.Reshard(%0) + : (!ifrt.array, #ifrt.sharding_unspecified, [2,3]>) + -> !ifrt.array, #sharding, [2,3]> + return %1 : !ifrt.array, #sharding, [2,3]> + } +} diff --git a/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/xla/python/ifrt/ir/tests/spmd_expansion.mlir index f15908703b9ec..28a1dda2b3f77 100644 --- a/xla/python/ifrt/ir/tests/spmd_expansion.mlir +++ b/xla/python/ifrt/ir/tests/spmd_expansion.mlir @@ -1,7 +1,7 @@ // RUN: ifrt-opt %s -spmd-expansion -split-input-file -verify-diagnostics | FileCheck %s #device = #ifrt -#sharding = #ifrt.sharding<2x1 to [0] on 2> +#sharding = #ifrt.sharding_param<2x1 to [0] on 2> // CHECK-LABEL: @identity_axis0_sharded module @identity_axis0_sharded attributes {ifrt.devices = #device} { // CHECK-NEXT: func.func @main @@ -20,7 +20,7 @@ module @identity_axis0_sharded attributes {ifrt.devices = #device} { // ----- #device = #ifrt -#sharding = #ifrt.sharding<1x2 to [0] on 2> +#sharding = #ifrt.sharding_param<1x2 to [0] on 2> // CHECK-LABEL: @identity_axis1_sharded module @identity_axis1_sharded attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { @@ -40,7 +40,7 @@ module @identity_axis1_sharded // ----- #device = #ifrt -#sharding = #ifrt.sharding<3x2 to [1,0] on 2x3> +#sharding = #ifrt.sharding_param<3x2 to [1,0] on 2x3> // CHECK-LABEL: @identify_both_axes_sharded module @identify_both_axes_sharded attributes {ifrt.devices = #device} { // CHECK-NEXT: func.func @main @@ -70,10 +70,12 @@ module @with_func_call attributes {ifrt.devices = #device} { // CHECK: return // CHECK-SAME: tensor<1x2xi32> func.func @main( - %arg0: tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding<2x1 to [0] on 2>, - ifrt.devices = #device}) - -> (tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding<2x1 to [0] on 2>, - ifrt.devices = #device}) { + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #device}) + -> (tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #device}) { %0 = func.call @identify(%arg0) : (tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -102,10 +104,12 @@ module @with_nested_func_call attributes {ifrt.devices = #device} { // CHECK: return // CHECK-SAME: tensor<1x2xi32> func.func @main( - %arg0: tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding<2x1 to [0] on 2>, - ifrt.devices = #device}) - -> (tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding<2x1 to [0] on 2>, - ifrt.devices = #device}) { + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #device}) + -> (tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #device}) { %0 = func.call @call_identify(%arg0) : (tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -136,7 +140,7 @@ module @with_nested_func_call attributes {ifrt.devices = #device} { // ----- #device = #ifrt -#sharding = #ifrt.sharding<1x2 to [0] on 2> +#sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `main`}} module @missing_main_function attributes {ifrt.devices = #device} { @@ -145,7 +149,7 @@ module @missing_main_function // ----- #device = #ifrt -#sharding = #ifrt.sharding<1x2 to [0] on 2> +#sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `entry_func`}} module @missing_entry_function attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { @@ -161,7 +165,7 @@ module @missing_entry_function // ----- #device = #ifrt -#sharding = #ifrt.sharding<2x1 to [0] on 2> +#sharding = #ifrt.sharding_param<2x1 to [0] on 2> module @non_divisible_global_shape attributes {ifrt.devices = #device} { // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}} func.func @main( diff --git a/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir b/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir index b6c81d613b602..0bd8ef268e0fa 100644 --- a/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir +++ b/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir @@ -2,11 +2,14 @@ module @good_return_only { func.func @main( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @simple_return(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -17,11 +20,14 @@ module @good_return_only { module @good_non_expandable_on_one_device{ func.func @main( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, + [0]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @math_absi(%arg0) on devices [0] - : (!ifrt.array, 1x1 to [0] on 1, [0]>) - -> !ifrt.array, 1x1 to [0] on 1, [0]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, + [0]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, + [0]> return } @@ -33,11 +39,14 @@ module @good_non_expandable_on_one_device{ module @good_excluded_dialect_on_two_devices { func.func @main( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @arith_self_add(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -51,11 +60,14 @@ module @good_excluded_dialect_on_two_devices { module @unexpandable_on_two_devices { func.func @main( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @math_absi(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } diff --git a/xla/python/ifrt/ir/tests/verify_array.mlir b/xla/python/ifrt/ir/tests/verify_array.mlir index 86eafd22d9508..a1379bdd644eb 100644 --- a/xla/python/ifrt/ir/tests/verify_array.mlir +++ b/xla/python/ifrt/ir/tests/verify_array.mlir @@ -13,14 +13,17 @@ func.func @good_array() { /// The equivalent HloSharding is /// {devices=[4,1,3]0,2,1,3,4,6,5,7,8,10,9,11 replicate_on_last_dim} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 4x1 to [1,0,2] on 2x2x3, [0,1,2,3,4,5,6,7,8,9,10,11]> + !ifrt.array, + #ifrt.sharding_param<4x1 to [1,0,2] on 2x2x3>, + [0,1,2,3,4,5,6,7,8,9,10,11]> return } #devices = #ifrt func.func @good_array_with_aliased_devices() { %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 4x1 to [0,1] on 2x2, #devices> + !ifrt.array, #ifrt.sharding_param<4x1 to [0,1] on 2x2>, + #devices> return } @@ -28,54 +31,60 @@ func.func @good_array_with_aliased_devices() { func.func @good_array_scalar() { %0 = builtin.unrealized_conversion_cast to - !ifrt.array, to [0,1] on 2x2, [0,1,2,3]> + !ifrt.array,#ifrt.sharding_param< to [0,1] on 2x2>, [0,1,2,3]> return } // ----- func.func @array_devices_should_be_distinct() { - // expected-error@+3 {{Device list has duplicate id 0}} + // expected-error@+3 {{Device list has duplicate logical id 0}} // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'devices_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 1x1 to [0] on 2, [0,0]> + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,0]> return } // ----- func.func @array_devices_should_be_non_negative() { - // expected-error@+3 {{Device list has negative id -1}} - // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'devices_attr'}} + // expected-error@+4 {{Device list has negative logical id -1}} + // expected-error@+3 {{failed to parse Ifrt_ArrayType parameter 'devices_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 1x1 to [0] on 2, [-1,0]> + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [-1,0]> return } // ----- func.func @array_requires_same_permutation_and_axis_sizes() { - // expected-error@+2 {{Expect same non-zero size for `permutation` and `axis_sizes`. Actual 2 vs 1}} + // expected-error@+3 {{Expect same non-zero size for `permutation` and `axis_sizes`. Actual 2 vs 1}} + // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 1x1 to [0,1] on 2, [0,1]> + !ifrt.array, #ifrt.sharding_param<1x1 to [0,1] on 2>, + [0,1]> return } // ----- func.func @array_requires_enough_devices() { - // expected-error@+2 {{Can't shard the dims 2, 2 to the mesh of 0 on 2}} + // expected-error@+3 {{Can't shard the dims 2x2 to the mesh of [0] on 2}} + // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 2x2 to [0] on 2, [0,1]> + !ifrt.array, #ifrt.sharding_param<2x2 to [0] on 2>, [0,1]> return } // ----- func.func @array_requires_shard_distributable_to_axes() { - // expected-error@+2 {{Dimension #1 of 2 shards can't be assigned to the axes}} + // expected-error@+3 {{Can't shard the dims 1x2 to the mesh of [0] on 3}} + // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 1x2 to [0] on 3, [0,1,2]> + !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 3>, + [0,1,2]> return } @@ -84,7 +93,8 @@ func.func @array_requires_shard_distributable_to_axes() { func.func @array_requires_same_size_of_devices_and_from_axes() { // expected-error@+2 {{Requires the same amount of `devices` and from `sharding`. Actual: 3 vs 4}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 2x2 to [0,1] on 2x2, [0,1,2]> + !ifrt.array, #ifrt.sharding_param<2x2 to [0,1] on 2x2>, + [0,1,2]> return } @@ -93,15 +103,18 @@ func.func @array_requires_same_size_of_devices_and_from_axes() { func.func @array_requires_rank_matching_dim_shards() { // expected-error@+2 {{Requires dim shards to have the same rank as the array. Array rank is 2 vs dim shards rank of 0}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, to [0,1] on 2x2, [0,1,2,3]> + !ifrt.array, #ifrt.sharding_param< to [0,1] on 2x2>, + [0,1,2,3]> return } // ----- func.func @array_requires_non_empty_permutation() { - // expected-error@+2 {{Expect same non-zero size for `permutation` and `axis_sizes`. Actual 0 vs 0}} + // expected-error@+3 {{Expect same non-zero size for `permutation` and `axis_sizes`. Actual 0 vs 0}} + // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = builtin.unrealized_conversion_cast to - !ifrt.array, 2x2 to [] on , [0,1,2,3]> + !ifrt.array, #ifrt.sharding_param<2x2 to [] on>, + [0,1,2,3]> return } diff --git a/xla/python/ifrt/ir/tests/verify_assemble.mlir b/xla/python/ifrt/ir/tests/verify_assemble.mlir index dea843152255f..9c3416961102d 100644 --- a/xla/python/ifrt/ir/tests/verify_assemble.mlir +++ b/xla/python/ifrt/ir/tests/verify_assemble.mlir @@ -1,57 +1,77 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics func.func @good_assemble( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]>, - %arg1: !ifrt.array, 1x1 to [0] on 1, [1]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + %arg1: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) attributes {ifrt.function} { %0 = "ifrt.Assemble"(%arg0, %arg1) {operandSegmentSizes=array} - : (!ifrt.array, 1x1 to [0] on 1, [0]>, - !ifrt.array, 1x1 to [0] on 1, [1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) + -> !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> return } // ----- func.func @assemble_requires_in_ifrt_function( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]>, - %arg1: !ifrt.array, 1x1 to [0] on 1, [1]>) { + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + %arg1: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) { // expected-error@+1 {{'ifrt.Assemble' op must be in a FuncOp with attr `ifrt.function`}} %0 = "ifrt.Assemble"(%arg0, %arg1) {operandSegmentSizes=array} - : (!ifrt.array, 1x1 to [0] on 1, [0]>, - !ifrt.array, 1x1 to [0] on 1, [1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) + -> !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> return } // ----- func.func @assemble_requires_inputs_on_single_devices( - %arg0: !ifrt.array, 1x2 to [0] on 2, [0,1]>, - %arg1: !ifrt.array, 1x1 to [0] on 1, [2]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>, + %arg1: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Assemble' op requires every input to be a single device array. Actual: '!ifrt.array, 1x2 to [0] on 2, [0, 1]>'}} + // expected-error@+1 {{'ifrt.Assemble' op requires every input to be a single device array. Actual: '!ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 1]>'}} %0 = "ifrt.Assemble"(%arg0, %arg1) {operandSegmentSizes=array} - : (!ifrt.array, 1x2 to [0] on 2, [0,1]>, - !ifrt.array, 1x1 to [0] on 1, [2]>) - -> !ifrt.array, 1x3 to [0] on 3, [0,1,2]> + : (!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]>) + -> !ifrt.array, + #ifrt.sharding_param<1x3 to [0] on 3>, [0,1,2]> return } // ----- func.func @assemble_requires_same_device_list( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]>, - %arg1: !ifrt.array, 1x1 to [0] on 1, [1]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + %arg1: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Assemble' op requires the same input/output device list. Input 0, 1 vs Output 1, 2}} %0 = "ifrt.Assemble"(%arg0, %arg1) {operandSegmentSizes=array} - : (!ifrt.array, 1x1 to [0] on 1, [0]>, - !ifrt.array, 1x1 to [0] on 1, [1]>) - -> !ifrt.array, 1x2 to [0] on 2, [1,2]> + : (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) + -> !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [1,2]> return } diff --git a/xla/python/ifrt/ir/tests/verify_attrs.mlir b/xla/python/ifrt/ir/tests/verify_attrs.mlir index 654a65617ed15..45c18e7149367 100644 --- a/xla/python/ifrt/ir/tests/verify_attrs.mlir +++ b/xla/python/ifrt/ir/tests/verify_attrs.mlir @@ -5,7 +5,9 @@ func.func @good_function_attr() attributes {ifrt.function} { } func.func @good_donated_attr( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]> {ifrt.donated}) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, + [0]> {ifrt.donated}) attributes {ifrt.function} { return } @@ -26,7 +28,8 @@ module @func_attr_should_be_on_func_op attributes {ifrt.function} {} // expected-error@+1 {{'func.func' op has `ifrt.donated` arg attr that is not a UnitAttr}} func.func @donated_attr_should_be_unit( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]> + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]> {ifrt.donated = "1"}) attributes {ifrt.function} { return @@ -36,6 +39,8 @@ func.func @donated_attr_should_be_unit( // expected-error@+1 {{'func.func' op has `ifrt.donated` arg attr but not has `ifrt.function` attr}} func.func @donated_attr_should_be_with_func_attr( - %arg0: !ifrt.array, 1x1 to [0] on 1, [0]> {ifrt.donated}) { + %arg0: !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, + [0]> {ifrt.donated}) { return } diff --git a/xla/python/ifrt/ir/tests/verify_call.mlir b/xla/python/ifrt/ir/tests/verify_call.mlir index 444de6ae58b87..4e34630dd4121 100644 --- a/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/xla/python/ifrt/ir/tests/verify_call.mlir @@ -1,41 +1,53 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics func.func @good_call( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } func.func @good_call_with_control_dep( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>, + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, %arg1: !ifrt.control) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @callee(%arg0) after %arg1 on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } func.func @good_call_with_io_aliases( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } #devices = #ifrt func.func @good_call_with_aliased_devices( - %arg0: !ifrt.array, 1x1 to [0] on 2, #devices>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + #devices>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices #devices - : (!ifrt.array, 1x1 to [0] on 2, #devices>) - -> !ifrt.array, 1x1 to [0] on 2, #devices> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + #devices>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + #devices> return } @@ -47,11 +59,14 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { func.func @call_requires_in_ifrt_function( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) { + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) { // expected-error@+1 {{'ifrt.Call' op must be in a FuncOp with attr `ifrt.function`}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -62,24 +77,30 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @call_requires_valid_reference( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires '@missing_reference' to reference a valid `func.func`}} %0, %ctrl_0 = ifrt.Call @missing_reference(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } // ----- func.func @call_requires_same_input_size( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires the same input size. Input 1 vs Callee 0}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -91,12 +112,15 @@ func.func @callee() -> (tensor<4x4xi32>) { // ----- func.func @call_requires_same_input_shape( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires the same global shape. Input #0 'tensor<2x2xi32>' vs Callee 'tensor<2x4xi32>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -108,12 +132,15 @@ func.func @callee(%arg0: tensor<2x4xi32>) -> tensor<4x4xi32> { // ----- func.func @call_requires_same_output_size( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires the same output size. Output 1 vs Callee 0}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -124,12 +151,15 @@ func.func @callee(%arg0: tensor<2x2xi32>) { // ----- func.func @call_requires_same_output_shape( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires the same global shape. Output #0 'tensor<4x4xi32>' vs Callee 'tensor<2x4xi32>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -141,12 +171,15 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x4xi32> { // ----- func.func @call_requires_non_negative_devices_attr( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' Device list has negative id -1}} + // expected-error@+1 {{'ifrt.Call' Device list has negative logical id -1}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1,-1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -158,12 +191,15 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<4x4xi32> { // ----- func.func @call_requires_unique_devices_attr( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' Device list has duplicate id 0}} + // expected-error@+1 {{'ifrt.Call' Device list has duplicate logical id 0}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,0] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } @@ -175,12 +211,14 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<4x4xi32> { // ----- func.func @call_requires_input_place_on_devices( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,2]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,2]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op requires all inputs placed on `devices` attr. The following input is placed on device 2 not found in `devices` attr. '!ifrt.array, 1x1 to [0] on 2, [0, 2]>'}} + // expected-error@+1 {{'ifrt.Call' op requires all inputs placed on `devices` attr. The following input is placed on device 2 not found in `devices` attr. '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 2]>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,2]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,2]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> return } @@ -192,12 +230,15 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<4x4xi32> { // ----- func.func @call_requires_output_place_on_devices( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op requires all outputs placed on `devices` attr. The following output is placed on device 2 not found in `devices` attr. '!ifrt.array, 1x2 to [0] on 2, [0, 2]>'}} + // expected-error@+1 {{'ifrt.Call' op requires all outputs placed on `devices` attr. The following output is placed on device 2 not found in `devices` attr. '!ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 2]>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,2]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,2]> return } @@ -209,13 +250,16 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<4x4xi32> { // ----- func.func @io_aliases_should_be_pairs( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op attribute 'io_aliases' failed to satisfy constraint: Array of pairs of aliased input/output indices}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -226,13 +270,16 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @io_aliases_should_have_valid_input_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op can't alias input #1 to output #0 as only having 1 inputs}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -243,14 +290,18 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @io_aliases_should_only_alias_input_once( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op can't alias input #0 more than once}} %0, %1, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array, array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> (!ifrt.array, 1x1 to [0] on 2, [0,1]>, - !ifrt.array, 1x1 to [0] on 2, [0,1]>) + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) return } @@ -262,13 +313,16 @@ func.func @callee(%arg0: tensor<2x2xi32>) // ----- func.func @io_aliases_should_have_valid_output_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #1 as only having 1 outputs}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -279,14 +333,18 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @io_aliases_should_only_alias_output_once( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op can't alias output #0 more than once}} %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg0) on devices [0,1] {io_aliases=[array, array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>, - !ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } @@ -298,13 +356,16 @@ func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) // ----- func.func @io_aliases_should_have_same_type( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different types: '!ifrt.array, 1x1 to [0] on 2, [0, 1]>' vs '!ifrt.array, 2x1 to [0] on 2, [0, 1]>'}} + // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different types: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 2x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> return } @@ -315,11 +376,14 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @good_call_local_view( - %arg0: !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1,2,3] {ifrt.local_view} - : (!ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) - -> !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]> + : (!ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]>) + -> !ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]> return } @@ -330,12 +394,15 @@ func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // ----- func.func @call_local_view_should_have_valid_shape( - %arg0: !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Call' op requires the same global shape. Input #0 'tensor<4x4xi32>' vs Callee 'tensor<8x8xi32>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1,2,3] {ifrt.local_view} - : (!ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) - -> !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]> + : (!ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]>) + -> !ifrt.array, + #ifrt.sharding_param<2x2 to [0, 1] on 2x2>, [0,1,2,3]> return } diff --git a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index ad6ae18f2dc04..3b0fbc80ca434 100644 --- a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -1,42 +1,55 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics func.func @good( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } func.func @good_with_control_dep( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>, + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, %arg1: !ifrt.control) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) after %arg1 - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> // ----- func.func @requires_in_ifrt_function( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) { + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op must be in a FuncOp with attr `ifrt.function`}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> // ----- @@ -49,7 +62,8 @@ func.func @requires_valid_reference() attributes {ifrt.function} { // ----- func.func @requires_loaded_executable_callee( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op requires '@wrong_reference' to reference a valid `ifrt.LoadedExecutable`}} %ctrl_0 = ifrt.CallLoadedExecutable @wrong_reference() : () -> () @@ -63,119 +77,156 @@ func.func @wrong_reference() { // ----- func.func @requires_matching_signature( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op requires callee signature matching '(!ifrt.array, 1x1 to [0] on 2, [0, 1]>) -> !ifrt.array, 1x2 to [0] on 2, [0, 1]>'. Actual '(!ifrt.array, 1x1 to [0] on 2, [0, 1]>) -> !ifrt.array, 1x2 to [0] on 2, [0, 1]>'}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op requires callee signature matching '(!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>) -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 1]>'. Actual '(!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>) -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 1]>'}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_be_pairs( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op attribute 'io_aliases' failed to satisfy constraint: Array of pairs of aliased input/output indices}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_have_valid_input_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #1 to output #0 as only having 1 inputs}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_only_alias_input_once( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 more than once}} %0, %1, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array, array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> (!ifrt.array, 1x1 to [0] on 2, [0,1]>, - !ifrt.array, 1x1 to [0] on 2, [0,1]>) + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_have_valid_output_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #1 as only having 1 outputs}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_only_alias_output_once( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias output #0 more than once}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg0) {io_aliases=[array, array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>, - !ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> // ----- func.func @io_aliases_should_have_same_type( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different types: '!ifrt.array, 1x1 to [0] on 2, [0, 1]>' vs '!ifrt.array, 2x1 to [0] on 2, [0, 1]>'}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different types: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array]} - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 2x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> return } ifrt.LoadedExecutable @callee on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 2x1 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> diff --git a/xla/python/ifrt/ir/tests/verify_disassemble.mlir b/xla/python/ifrt/ir/tests/verify_disassemble.mlir index e7000af3b5fd7..e36946470c23b 100644 --- a/xla/python/ifrt/ir/tests/verify_disassemble.mlir +++ b/xla/python/ifrt/ir/tests/verify_disassemble.mlir @@ -1,53 +1,69 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics func.func @good_disassemble( - %arg0: !ifrt.array, 1x2 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) attributes {ifrt.function} { %0, %1 = "ifrt.Disassemble"(%arg0) {operand_segment_sizes=array} - : (!ifrt.array, 1x2 to [0] on 2, [0,1]>) - -> (!ifrt.array, 1x1 to [0] on 1, [0]>, - !ifrt.array, 1x1 to [0] on 1, [1]>) + : (!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) + -> (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) return } // ----- func.func @disassemble_requires_in_ifrt_function( - %arg0: !ifrt.array, 1x2 to [0] on 2, [0,1]>) { + %arg0: !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) { // expected-error@+1 {{'ifrt.Disassemble' op must be in a FuncOp with attr `ifrt.function`}} %0, %1 = "ifrt.Disassemble"(%arg0) {operand_segment_sizes=array} - : (!ifrt.array, 1x2 to [0] on 2, [0,1]>) - -> (!ifrt.array, 1x1 to [0] on 1, [0]>, - !ifrt.array, 1x1 to [0] on 1, [1]>) + : (!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) + -> (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>) return } // ----- func.func @disassemble_requires_outputs_on_single_devices( - %arg0: !ifrt.array, 1x4 to [0, 1] on 2x2, [0,1,2,3]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x4 to [0, 1] on 2x2>, [0,1,2,3]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Disassemble' op requires every output to be a single device array. Actual: '!ifrt.array, 1x2 to [0] on 2, [0, 1]>'}} + // expected-error@+1 {{'ifrt.Disassemble' op requires every output to be a single device array. Actual: '!ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 1]>'}} %0, %1 = "ifrt.Disassemble"(%arg0) {operand_segment_sizes=array} - : (!ifrt.array, 1x4 to [0, 1] on 2x2, [0,1,2,3]>) - -> (!ifrt.array, 1x2 to [0] on 2, [0,1]>, - !ifrt.array, 1x2 to [0] on 2, [2,3]>) + : (!ifrt.array, + #ifrt.sharding_param<1x4 to [0, 1] on 2x2>, [0,1,2,3]>) + -> (!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>, + !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]>) return } // ----- func.func @disassemble_requires_same_device_list( - %arg0: !ifrt.array, 1x2 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Disassemble' op requires the same input/output device list. Input 0, 1 vs Output 1, 2}} %0, %1 = "ifrt.Disassemble"(%arg0) {operand_segment_sizes=array} - : (!ifrt.array, 1x2 to [0] on 2, [0,1]>) - -> (!ifrt.array, 1x1 to [0] on 1, [1]>, - !ifrt.array, 1x1 to [0] on 1, [2]>) + : (!ifrt.array, + #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]>) + -> (!ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [1]>, + !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]>) return } diff --git a/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir b/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir index 9839e57fcc327..23d6d9759ff42 100644 --- a/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir +++ b/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir @@ -1,13 +1,17 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics ifrt.LoadedExecutable @good on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> #devices = #ifrt ifrt.LoadedExecutable @good_with_aliased_devices on devices #devices - : (!ifrt.array, 1x1 to [0] on 2, #devices>) - -> !ifrt.array, 1x2 to [0] on 2, #devices> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + #devices>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + #devices> // ----- @@ -23,21 +27,27 @@ ifrt.LoadedExecutable @requires_array_output on devices [0,1] // ----- -// expected-error@+1 {{'ifrt.LoadedExecutable' Device list has duplicate id 0}} +// expected-error@+1 {{'ifrt.LoadedExecutable' Device list has duplicate logical id 0}} ifrt.LoadedExecutable @requires_unique_devices_attr on devices [0,0] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> // ----- -// expected-error@+1 {{'ifrt.LoadedExecutable' op requires all inputs placed on `devices` attr. The following input is placed on device 2 not found in `devices` attr. '!ifrt.array, 1x1 to [0] on 2, [0, 2]>'}} +// expected-error@+1 {{'ifrt.LoadedExecutable' op requires all inputs placed on `devices` attr. The following input is placed on device 2 not found in `devices` attr. '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 2]>'}} ifrt.LoadedExecutable @requires_input_place_on_devices on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,2]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,1]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,2]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> // ----- -// expected-error@+1 {{'ifrt.LoadedExecutable' op requires all outputs placed on `devices` attr. The following output is placed on device 2 not found in `devices` attr. '!ifrt.array, 1x2 to [0] on 2, [0, 2]>'}} +// expected-error@+1 {{'ifrt.LoadedExecutable' op requires all outputs placed on `devices` attr. The following output is placed on device 2 not found in `devices` attr. '!ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0, 2]>'}} ifrt.LoadedExecutable @requires_output_place_on_devices on devices [0,1] - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [0] on 2, [0,2]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,2]> diff --git a/xla/python/ifrt/ir/tests/verify_reshard.mlir b/xla/python/ifrt/ir/tests/verify_reshard.mlir index cc8370e81f9ad..7731b9664ce2b 100644 --- a/xla/python/ifrt/ir/tests/verify_reshard.mlir +++ b/xla/python/ifrt/ir/tests/verify_reshard.mlir @@ -1,67 +1,87 @@ // RUN: ifrt-opt %s -split-input-file -verify-diagnostics func.func @good_reshard( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { %0 = ifrt.Reshard(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 4, [0,1,2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 4>, + [0,1,2,3]> return } func.func @good_reshard_with_control_dep( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>, + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>, %arg1: !ifrt.control) attributes {ifrt.function} { %0 = ifrt.Reshard(%arg0) after %arg1 - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 4, [0,1,2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 4>, + [0,1,2,3]> return } // ----- func.func @reshard_requires_in_ifrt_function( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) { + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) { // expected-error@+1 {{'ifrt.Reshard' op must be in a FuncOp with attr `ifrt.function`}} %0 = ifrt.Reshard(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 4, [0,1,2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 4>, + [0,1,2,3]> return } // ----- func.func @reshard_requires_same_global_shape( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { // expected-error@+1 {{'ifrt.Reshard' op requires the same global shape. Input 'tensor<2x2xi32>' vs Output 'tensor<2x1xi32>'}} %0 = ifrt.Reshard(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x1 to [0] on 2, [2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> return } // ----- func.func @reshard_requires_non_negative_axis_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+3 {{Out of range axis -1 to the mesh of -1 on 2}} + // expected-error@+5 {{Out of range axis -1 to the mesh of [-1] on 2}} + // expected-error@+4 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = ifrt.Reshard(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [-1] on 2, [2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [-1] on 2>, + [2,3]> return } // ----- func.func @reshard_requires_valid_axis_index( - %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) + %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) attributes {ifrt.function} { - // expected-error@+3 {{Out of range axis 1234567890 to the mesh of 1234567890 on 2}} + // expected-error@+6 {{Out of range axis 1234567890 to the mesh of [1234567890] on 2}} + // expected-error@+5 {{failed to parse Ifrt_ArrayType parameter 'sharding_attr'}} %0 = ifrt.Reshard(%arg0) - : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) - -> !ifrt.array, 1x2 to [1234567890] on 2, [2,3]> + : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, + #ifrt.sharding_param<1x2 to [1234567890] on 2>, [2,3]> return } diff --git a/xla/python/ifrt/ir/transforms/BUILD b/xla/python/ifrt/ir/transforms/BUILD index a3c7f51edfd81..356249f57f98d 100644 --- a/xla/python/ifrt/ir/transforms/BUILD +++ b/xla/python/ifrt/ir/transforms/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -8,6 +9,7 @@ package( gentbl_cc_library( name = "passes_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -28,10 +30,12 @@ cc_library( name = "passes", srcs = [ "ifrt_duplicated_callee_elimination_pass.cc", + "ifrt_verify_sharding_specified_pass.cc", "spmd_expandable_interface_verification_pass.cc", "spmd_expansion_pass.cc", ], hdrs = ["passes.h"], + compatible_with = get_compatible_with_portable(), deps = [ ":constants", ":passes_inc_gen", @@ -40,7 +44,6 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -53,6 +56,7 @@ cc_library( name = "built_in_spmd_expansions", srcs = ["built_in_spmd_expansions.cc"], hdrs = ["built_in_spmd_expansions.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//xla/python/ifrt/ir/transforms/spmd_expanders:spmd_expander", "@llvm-project//mlir:FuncDialect", @@ -63,5 +67,6 @@ cc_library( cc_library( name = "constants", hdrs = ["constants.h"], + compatible_with = get_compatible_with_portable(), deps = ["@llvm-project//llvm:Support"], ) diff --git a/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.cc b/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.cc index 80bd5e693f4e2..b0f25459f7fc9 100644 --- a/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.cc +++ b/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h b/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h index 257d4c51741a0..9151e6ca0030e 100644 --- a/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h +++ b/xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/constants.h b/xla/python/ifrt/ir/transforms/constants.h index cd65bfcb0b507..98bfd12e2c19b 100644 --- a/xla/python/ifrt/ir/transforms/constants.h +++ b/xla/python/ifrt/ir/transforms/constants.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/ifrt_duplicated_callee_elimination_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_duplicated_callee_elimination_pass.cc index 02a45776ac6a0..c708e80c571ca 100644 --- a/xla/python/ifrt/ir/transforms/ifrt_duplicated_callee_elimination_pass.cc +++ b/xla/python/ifrt/ir/transforms/ifrt_duplicated_callee_elimination_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/ifrt_verify_sharding_specified_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_verify_sharding_specified_pass.cc new file mode 100644 index 0000000000000..9aaa990832a26 --- /dev/null +++ b/xla/python/ifrt/ir/transforms/ifrt_verify_sharding_specified_pass.cc @@ -0,0 +1,106 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/transforms/passes.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTVERIFYSHARDINGSPECIFIEDPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +bool IsArrayWithUnspecifiedSharding(mlir::Type type) { + auto array_type = llvm::dyn_cast_or_null(type); + if (array_type == nullptr) { + return false; + } + return array_type.getShardingAttr().isa(); +} + +class IfrtVerifyShardingSpecifiedPass + : public impl::IfrtVerifyShardingSpecifiedPassBase< + IfrtVerifyShardingSpecifiedPass> { + public: + void runOnOperation() override; +}; + +void IfrtVerifyShardingSpecifiedPass::runOnOperation() { + mlir::ModuleOp module_op = getOperation(); + mlir::WalkResult result = + module_op.walk([](mlir::Operation* op) -> mlir::WalkResult { + auto func_op = llvm::dyn_cast_or_null(op); + if (func_op != nullptr) { + mlir::FunctionType func_type = func_op.getFunctionType(); + for (const auto [idx, input_type] : + llvm::enumerate(func_type.getInputs())) { + if (IsArrayWithUnspecifiedSharding(input_type)) { + return op->emitOpError() + << "argument " << idx << " has unspecified sharding."; + } + } + for (const auto [idx, result_type] : + llvm::enumerate(func_type.getResults())) { + if (IsArrayWithUnspecifiedSharding(result_type)) { + return op->emitOpError() + << "result " << idx << " has unspecified sharding."; + } + } + } else { + for (const auto [idx, operand_type] : + llvm::enumerate(op->getOperandTypes())) { + if (IsArrayWithUnspecifiedSharding(operand_type)) { + return op->emitOpError() + << "argument " << idx << " has unspecified sharding."; + } + } + for (const auto [idx, result_type] : + llvm::enumerate(op->getResultTypes())) { + if (IsArrayWithUnspecifiedSharding(result_type)) { + return op->emitOpError() + << "result " << idx << " has unspecified sharding."; + } + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtVerifyShardingSpecifiedPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/transforms/passes.h b/xla/python/ifrt/ir/transforms/passes.h index d6735fe858b03..546eb55e68179 100644 --- a/xla/python/ifrt/ir/transforms/passes.h +++ b/xla/python/ifrt/ir/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,6 +39,9 @@ std::unique_ptr> CreateSpmdExpansionPass(); std::unique_ptr> CreateIfrtDuplicatedCalleeEliminationPass(); +std::unique_ptr> +CreateIfrtVerifyShardingSpecifiedPass(); + // Generated definitions. This should be placed after all Pass creations. #define GEN_PASS_REGISTRATION #include "xla/python/ifrt/ir/transforms/passes.h.inc" // IWYU pragma: export diff --git a/xla/python/ifrt/ir/transforms/passes.td b/xla/python/ifrt/ir/transforms/passes.td index 15350134e8017..fe9709004f35f 100644 --- a/xla/python/ifrt/ir/transforms/passes.td +++ b/xla/python/ifrt/ir/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -62,7 +62,7 @@ For example, the following: ```mlir #device = #ifrt -#sharding = #ifrt.sharding<2x1 to [0] on 2> +#sharding = #ifrt.sharding_param<2x1 to [0] on 2> module attributes {ifrt.devices = #device} { func.func @main( %arg0: tensor<2x2xi32> {ifrt.sharding = #sharding, @@ -79,7 +79,7 @@ will be transformed into: ```mlir // The function's input and output shapes are now local. #device = #ifrt -#sharding = #ifrt.sharding<2x1 to [0] on 2> +#sharding = #ifrt.sharding_param<2x1 to [0] on 2> module attributes {ifrt.devices = #device} { func.func @main( %arg0: tensor<1x2xi32> {ifrt.sharding = #sharding, @@ -107,4 +107,16 @@ them. The duplicated callee `FuncOp` will not be removed. let constructor = "CreateIfrtDuplicatedCalleeEliminationPass()"; } +def IfrtVerifyShardingSpecifiedPass : + Pass<"ifrt-verify-sharding-specified", "mlir::ModuleOp"> { + let summary = "Verify that all `!ifrt.array` have sharding specified."; + let description = [{ +Verify that each `!ifrt.array` has sharding attribute that is not of type +`!ifrt.sharding_unspecified`. + }]; + + let constructor = "CreateIfrtVerifyShardingSpecifiedPass()"; +} + + #endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_PASSES_TD_ diff --git a/xla/python/ifrt/ir/transforms/spmd_expandable_interface_verification_pass.cc b/xla/python/ifrt/ir/transforms/spmd_expandable_interface_verification_pass.cc index 6a034df2bfeda..b7b8dee3ec512 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expandable_interface_verification_pass.cc +++ b/xla/python/ifrt/ir/transforms/spmd_expandable_interface_verification_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD b/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD index f98f3f53872e2..da4979e16b981 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD +++ b/xla/python/ifrt/ir/transforms/spmd_expanders/BUILD @@ -1,3 +1,5 @@ +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], @@ -11,6 +13,7 @@ cc_library( hdrs = glob([ "*spmd_expander.h", ]), + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla/python/ifrt/ir", diff --git a/xla/python/ifrt/ir/transforms/spmd_expanders/noop_ifrt_spmd_expander.h b/xla/python/ifrt/ir/transforms/spmd_expanders/noop_ifrt_spmd_expander.h index f5542d73ec5ad..dc900f3a4ec8b 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expanders/noop_ifrt_spmd_expander.h +++ b/xla/python/ifrt/ir/transforms/spmd_expanders/noop_ifrt_spmd_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/spmd_expanders/terminator_ifrt_spmd_expander.h b/xla/python/ifrt/ir/transforms/spmd_expanders/terminator_ifrt_spmd_expander.h index ca69d7a96d6ea..9ebd91251def8 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expanders/terminator_ifrt_spmd_expander.h +++ b/xla/python/ifrt/ir/transforms/spmd_expanders/terminator_ifrt_spmd_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/spmd_expanders/unimplemented_ifrt_spmd_expander.h b/xla/python/ifrt/ir/transforms/spmd_expanders/unimplemented_ifrt_spmd_expander.h index 5d4a518abf428..cab76e43d7839 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expanders/unimplemented_ifrt_spmd_expander.h +++ b/xla/python/ifrt/ir/transforms/spmd_expanders/unimplemented_ifrt_spmd_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc b/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc index ecf4fdd76a530..cc13c47df669c 100644 --- a/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc +++ b/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/span.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/ifrt_interfaces.h" -#include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/ir/transforms/constants.h" #include "xla/python/ifrt/ir/transforms/passes.h" @@ -140,23 +145,6 @@ mlir::Operation* TopologicalIterator::next() { bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); } -absl::StatusOr> LocalShapeFromGlobalShape( - absl::Span global_shape, ShardingParam sharding_param) { - auto num_shards = sharding_param.dim_shards(); - std::vector local_shape; - local_shape.reserve(global_shape.size()); - for (int i = 0; i < num_shards.size(); ++i) { - if (global_shape[i] % num_shards[i] != 0) { - return absl::InvalidArgumentError(absl::StrCat( - "Global shape is not divisible by the number of shards in dimension ", - i, ". Global size: ", global_shape[i], - ", number of shards: ", num_shards[i], ".")); - } - local_shape.push_back(global_shape[i] / num_shards[i]); - } - return local_shape; -} - // Updates `function` input signature operand at `argument_index` with // `new_shape`. void UpdateFunctionInputShape(const int argument_index, @@ -178,14 +166,13 @@ mlir::LogicalResult UpdateFunctionArgsUsingSharding( // can have resource type as input. for (int i = 0; i < function.getNumArguments(); ++i) { auto arg_sharding_attr = - function.getArgAttrOfType(i, kIfrtShardingAttrName); + function.getArgAttrOfType( + i, kIfrtShardingAttrName); if (arg_sharding_attr == nullptr) { return function.emitOpError() << "requires `" << kIfrtShardingAttrName << "` attribute on arg " << i; } - ShardingParam sharding = arg_sharding_attr.getSharding(); - auto value = function.getFunctionType().getInput(i); mlir::RankedTensorType ranked_type = @@ -196,8 +183,8 @@ mlir::LogicalResult UpdateFunctionArgsUsingSharding( } llvm::ArrayRef arg_shape = ranked_type.getShape(); - absl::StatusOr> arg_local_shape = - LocalShapeFromGlobalShape(arg_shape, sharding); + absl::StatusOr> arg_local_shape = + arg_sharding_attr.LocalShapeFromGlobalShape(arg_shape); if (!arg_local_shape.ok()) { return function.emitOpError() << arg_local_shape.status().message(); } diff --git a/xla/python/ifrt/memory.cc b/xla/python/ifrt/memory.cc index 27a7ad8d6d551..caaa3e720fd9c 100644 --- a/xla/python/ifrt/memory.cc +++ b/xla/python/ifrt/memory.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/memory.h b/xla/python/ifrt/memory.h index 8cd1a9538a035..0a5525e693f92 100644 --- a/xla/python/ifrt/memory.h +++ b/xla/python/ifrt/memory.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/memory_test.cc b/xla/python/ifrt/memory_test.cc index bb223471c4477..ccd6d02895559 100644 --- a/xla/python/ifrt/memory_test.cc +++ b/xla/python/ifrt/memory_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/mock.cc b/xla/python/ifrt/mock.cc index 57d9004813de1..24e7186f0e447 100644 --- a/xla/python/ifrt/mock.cc +++ b/xla/python/ifrt/mock.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,29 @@ limitations under the License. #include "xla/python/ifrt/mock.h" +#include #include #include #include +#include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/value.h" +#include "tsl/concurrency/ref_count.h" + namespace xla { namespace ifrt { @@ -57,6 +75,10 @@ MockArray::MockArray(tsl::RCReference delegated) ON_CALL(*this, shared_ptr_sharding).WillByDefault([this]() { return delegated_->shared_ptr_sharding(); }); + ON_CALL(*this, layout) + .WillByDefault([this]() -> absl::StatusOr> { + return delegated_->layout(); + }); ON_CALL(*this, DisassembleIntoSingleDeviceArrays) .WillByDefault([this](ArrayCopySemantics semantics) { return delegated_->DisassembleIntoSingleDeviceArrays(semantics); @@ -156,6 +178,12 @@ MockClient::MockClient(std::unique_ptr delegated) .WillByDefault([this](absl::Span devices) { return delegated_->GetTopologyForDevices(devices); }); + ON_CALL(*this, GetDefaultLayoutForDevice) + .WillByDefault([this](xla::ifrt::DType dtype, + absl::Span dims, + xla::ifrt::Device* device) { + return delegated_->GetDefaultLayoutForDevice(dtype, dims, device); + }); } // LINT.ThenChange() @@ -171,10 +199,18 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) { .WillByDefault([this]() -> const xla::PjRtDeviceDescription& { return delegated_->description(); }); - ON_CALL(*this, id).WillByDefault([this]() { return delegated_->id(); }); + ON_CALL(*this, global_device_id).WillByDefault([this]() { + return delegated_->global_device_id(); + }); ON_CALL(*this, process_index).WillByDefault([this]() { return delegated_->process_index(); }); + ON_CALL(*this, local_device_id).WillByDefault([this]() { + return delegated_->local_device_id(); + }); + ON_CALL(*this, local_hardware_id_typed).WillByDefault([this]() { + return delegated_->local_hardware_id_typed(); + }); ON_CALL(*this, local_hardware_id).WillByDefault([this]() { return delegated_->local_hardware_id(); }); @@ -187,9 +223,12 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) { ON_CALL(*this, ToString).WillByDefault([this]() { return delegated_->ToString(); }); - ON_CALL(*this, Attributes).WillByDefault([this]() { - return delegated_->Attributes(); - }); + ON_CALL(*this, Attributes) + .WillByDefault( + [this]() + -> const absl::flat_hash_map& { + return delegated_->Attributes(); + }); ON_CALL(*this, CreateAsyncTrackingEvent) .WillByDefault([this](absl::string_view description) { return delegated_->CreateAsyncTrackingEvent(description); diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index 9ba0583d0015d..cff44976c0a68 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" @@ -64,17 +65,19 @@ class MockArray final : public llvm::RTTIExtends { MOCK_METHOD(const Sharding&, sharding, (), (const, final)); MOCK_METHOD(std::shared_ptr, shared_ptr_sharding, (), (const, final)); - MOCK_METHOD(StatusOr>>, + MOCK_METHOD(absl::StatusOr>, layout, (), + (const, final)); + MOCK_METHOD(absl::StatusOr>>, DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics), (final)); - MOCK_METHOD(StatusOr>, FullyReplicatedShard, + MOCK_METHOD(absl::StatusOr>, FullyReplicatedShard, (ArrayCopySemantics semantics), (final)); MOCK_METHOD(Future, CopyToHostBuffer, (void* data, std::optional> byte_strides, ArrayCopySemantics semantics), (final)); - MOCK_METHOD(StatusOr>, Reshard, + MOCK_METHOD(absl::StatusOr>, Reshard, (std::shared_ptr new_sharding, ArrayCopySemantics semantics), (final)); @@ -96,20 +99,20 @@ class MockClient final : public llvm::RTTIExtends { explicit MockClient(std::unique_ptr delegated); // LINT.IfChange - MOCK_METHOD(StatusOr>, MakeArrayFromHostBuffer, + MOCK_METHOD(absl::StatusOr>, MakeArrayFromHostBuffer, (const void* data, DType dtype, Shape shape, std::optional> byte_strides, std::shared_ptr sharding, HostBufferSemantics semantics, std::function on_done_with_host_buffer), (final)); - MOCK_METHOD(StatusOr>, + MOCK_METHOD(absl::StatusOr>, AssembleArrayFromSingleDeviceArrays, (Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics), (final)); - MOCK_METHOD(StatusOr>, MakeTuple, + MOCK_METHOD(absl::StatusOr>, MakeTuple, (absl::Span> values), (final)); MOCK_METHOD(absl::string_view, runtime_type, (), (const, final)); MOCK_METHOD(absl::string_view, platform_name, (), (const, final)); @@ -123,16 +126,22 @@ class MockClient final : public llvm::RTTIExtends { MOCK_METHOD(absl::Span, addressable_devices, (), (const, final)); MOCK_METHOD(int, process_index, (), (const, final)); - MOCK_METHOD(StatusOr, GetDefaultDeviceAssignment, + MOCK_METHOD(absl::StatusOr, GetDefaultDeviceAssignment, (int num_replicas, int num_partitions), (const, final)); - MOCK_METHOD(StatusOr, LookupDevice, (int device_id), (const, final)); - MOCK_METHOD(StatusOr, LookupAddressableDevice, + MOCK_METHOD(absl::StatusOr, LookupDevice, (int device_id), + (const, final)); + MOCK_METHOD(absl::StatusOr, LookupAddressableDevice, (int local_hardware_id), (const, final)); MOCK_METHOD(Compiler*, GetDefaultCompiler, (), (final)); MOCK_METHOD( absl::StatusOr>, GetTopologyForDevices, (absl::Span devices), (const, final)); + MOCK_METHOD(absl::StatusOr>, + GetDefaultLayoutForDevice, + (xla::ifrt::DType dtype, absl::Span dims, + xla::ifrt::Device* device), + (const, final)); // LINT.ThenChange(mock.cc:MockClientDelegation) xla::ifrt::Client* delegated() const { return delegated_.get(); } @@ -147,11 +156,11 @@ class MockClient final : public llvm::RTTIExtends { class MockCompiler final : public llvm::RTTIExtends { public: - MOCK_METHOD(StatusOr>, Compile, + MOCK_METHOD(absl::StatusOr>, Compile, (std::unique_ptr program, std::unique_ptr options), (final)); - MOCK_METHOD(StatusOr>, + MOCK_METHOD(absl::StatusOr>, DeserializeLoadedExecutable, (absl::string_view serialized, std::unique_ptr options), @@ -172,9 +181,12 @@ class MockDevice final : public Device { MOCK_METHOD(bool, IsAddressable, (), (const, final)); MOCK_METHOD(const xla::PjRtDeviceDescription&, description, (), (const, final)); - MOCK_METHOD(int, id, (), (const, final)); + MOCK_METHOD(xla::PjRtGlobalDeviceId, global_device_id, (), (const, final)); MOCK_METHOD(int, process_index, (), (const, final)); MOCK_METHOD(int, local_hardware_id, (), (const, final)); + MOCK_METHOD(xla::PjRtLocalDeviceId, local_device_id, (), (const, final)); + MOCK_METHOD(xla::PjRtLocalHardwareId, local_hardware_id_typed, (), + (const, final)); MOCK_METHOD(absl::string_view, device_kind, (), (const, final)); MOCK_METHOD(absl::string_view, DebugString, (), (const, final)); MOCK_METHOD(absl::string_view, ToString, (), (const, final)); @@ -187,9 +199,9 @@ class MockDevice final : public Device { MOCK_METHOD(Status, TransferToInfeed, (const LiteralSlice& literal), (final)); MOCK_METHOD(Status, TransferFromOutfeed, (MutableBorrowingLiteral literal), (final)); - MOCK_METHOD(StatusOr, default_memory_space, (), + MOCK_METHOD(absl::StatusOr, default_memory_space, (), (const, final)); - MOCK_METHOD(StatusOr, GetAllocatorStats, (), + MOCK_METHOD(absl::StatusOr, GetAllocatorStats, (), (const, final)); MOCK_METHOD(absl::Span, memory_spaces, (), (const, final)); @@ -219,25 +231,26 @@ class MockExecutable final : public llvm::RTTIExtends { public: MOCK_METHOD(absl::string_view, name, (), (const, final)); - MOCK_METHOD(StatusOr>, Fingerprint, (), + MOCK_METHOD(absl::StatusOr>, Fingerprint, (), (const, final)); - MOCK_METHOD(StatusOr, Serialize, (), (const, final)); + MOCK_METHOD(absl::StatusOr, Serialize, (), (const, final)); MOCK_METHOD(int, num_devices, (), (const, final)); MOCK_METHOD(int64_t, SizeOfGeneratedCodeInBytes, (), (const, final)); - MOCK_METHOD(StatusOr, GetCompiledMemoryStats, (), + MOCK_METHOD(absl::StatusOr, GetCompiledMemoryStats, (), (const, final)); MOCK_METHOD(std::optional>, GetParameterShardings, (), (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(StatusOr>, GetParameterLayouts, (), - (const, final)); - MOCK_METHOD(StatusOr>, GetOutputLayouts, (), - (const, final)); - MOCK_METHOD(StatusOr>>, GetHloModules, - (), (const, final)); - MOCK_METHOD((StatusOr>), - GetCostAnalysis, (), (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetParameterLayouts, (), (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetOutputLayouts, (), (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetHloModules, (), (const, final)); + MOCK_METHOD( + (absl::StatusOr>), + GetCostAnalysis, (), (const, final)); static char ID; // NOLINT }; @@ -247,29 +260,31 @@ class MockLoadedExecutable final public: MOCK_METHOD(Client*, client, (), (const, final)); MOCK_METHOD(absl::string_view, name, (), (const, final)); - MOCK_METHOD(StatusOr>, Fingerprint, (), + MOCK_METHOD(absl::StatusOr>, Fingerprint, (), (const, final)); - MOCK_METHOD(StatusOr, Serialize, (), (const, final)); + MOCK_METHOD(absl::StatusOr, Serialize, (), (const, final)); + MOCK_METHOD(Future, GetReadyFuture, (), (const, override)); MOCK_METHOD(int, num_devices, (), (const, final)); MOCK_METHOD(int64_t, SizeOfGeneratedCodeInBytes, (), (const, final)); - MOCK_METHOD(StatusOr, GetCompiledMemoryStats, (), + MOCK_METHOD(absl::StatusOr, GetCompiledMemoryStats, (), (const, final)); MOCK_METHOD(std::optional>, GetParameterShardings, (), (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(StatusOr>, GetParameterLayouts, (), - (const, final)); - MOCK_METHOD(StatusOr>, GetOutputLayouts, (), - (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetParameterLayouts, (), (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetOutputMemoryKinds, (), (const, final)); - MOCK_METHOD(StatusOr>>, GetHloModules, - (), (const, final)); - MOCK_METHOD((StatusOr>), - GetCostAnalysis, (), (const, final)); - MOCK_METHOD(StatusOr, Execute, + MOCK_METHOD(absl::StatusOr>>, + GetHloModules, (), (const, final)); + MOCK_METHOD( + (absl::StatusOr< + absl::flat_hash_map>), + GetCostAnalysis, (), (const, final)); + MOCK_METHOD(absl::StatusOr, Execute, (absl::Span> args, const ExecuteOptions& options, std::optional devices), @@ -298,7 +313,7 @@ class MockLoadedHostCallback final : public llvm::RTTIExtends { public: MOCK_METHOD(Client*, client, (), (const, final)); - MOCK_METHOD(StatusOr, Serialize, (), (const, final)); + MOCK_METHOD(absl::StatusOr, Serialize, (), (const, final)); static char ID; // NOLINT }; @@ -308,10 +323,10 @@ class MockLoadedHostCallback final class MockSharding : public llvm::RTTIExtends { public: MOCK_METHOD( - (StatusOr< + (absl::StatusOr< std::vector>>>), Disassemble, (const Shape& shape), (const, final)); - MOCK_METHOD(StatusOr>, IndexDomains, + MOCK_METHOD(absl::StatusOr>, IndexDomains, (const Shape& shape), (const, final)); MOCK_METHOD(std::string, DebugString, (), (const, final)); diff --git a/xla/python/ifrt/no_impl_test_main.cc b/xla/python/ifrt/no_impl_test_main.cc index eed1ff461880f..b2f7a1adfea35 100644 --- a/xla/python/ifrt/no_impl_test_main.cc +++ b/xla/python/ifrt/no_impl_test_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/plugin_program.cc b/xla/python/ifrt/plugin_program.cc new file mode 100644 index 0000000000000..4e27deb38bdbe --- /dev/null +++ b/xla/python/ifrt/plugin_program.cc @@ -0,0 +1,24 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt/plugin_program.h" + +namespace xla { +namespace ifrt { + +char PluginProgram::ID = 0; +char PluginCompileOptions::ID = 0; + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/plugin_program.h b/xla/python/ifrt/plugin_program.h new file mode 100644 index 0000000000000..54c59e9a72666 --- /dev/null +++ b/xla/python/ifrt/plugin_program.h @@ -0,0 +1,57 @@ +/* + * Copyright 2024 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PLUGIN_PROGRAM_H_ +#define XLA_PYTHON_IFRT_PLUGIN_PROGRAM_H_ + +#include + +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/compiler.h" + +namespace xla { +namespace ifrt { + +// `PluginProgram` is a subclass of `xla::ifrt::Program` used mainly with +// the IFRT proxy as of Apr 2024, and facilitates generic RPCs from the IFRT +// frontend (on the proxy-client) to the IFRT backend (on the proxy-server). A +// `PluginProgram` and its compiled executable need not be associated with a +// particular `xla::ifrt::Device`; instead, IFRT backends are expected to +// intercept and act on the compilation and subsequent executions of +// PluginProgram without passing them to particular devices. +// +// Another way to think of `PluginProgram` is that it is associated with a +// 'controller device', as opposed to CPU or GPU devices, where the term +// 'controller' means the same as in 'JAX uses a multi-controller programming +// model'. +struct PluginProgram + : public llvm::RTTIExtends { + std::string data; + static char ID; // NOLINT +}; + +struct PluginCompileOptions + : llvm::RTTIExtends { + PluginCompileOptions() = default; + ~PluginCompileOptions() override = default; + + static char ID; // NOLINT +}; + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PLUGIN_PROGRAM_H_ diff --git a/xla/python/ifrt/plugin_program_serdes.cc b/xla/python/ifrt/plugin_program_serdes.cc new file mode 100644 index 0000000000000..b5ff9ff1d2cb6 --- /dev/null +++ b/xla/python/ifrt/plugin_program_serdes.cc @@ -0,0 +1,102 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/serdes.h" + +namespace xla { +namespace ifrt { + +namespace { + +constexpr absl::string_view kSerializationPrefix = + "__serialized_plugin_program "; + +class PluginProgramSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::PluginProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + return absl::StrCat(kSerializationPrefix, + llvm::cast(serializable).data); + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr) override { + if (!absl::StartsWith(serialized, kSerializationPrefix)) { + return absl::InvalidArgumentError( + absl::StrCat("Bad serialized ", type_name())); + } + absl::string_view data(serialized); + data.remove_prefix(kSerializationPrefix.size()); + auto result = std::make_unique(); + result->data = data; + return result; + } + + static char ID; // NOLINT +}; + +char PluginProgramSerDes::ID = 0; + +bool register_plugin_program_serdes = ([]() { + RegisterSerDes( + std::make_unique()); +}(), true); + +class PluginCompileOptionsSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::PluginCompileOptions"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +char PluginCompileOptionsSerDes::ID = 0; + +bool register_plugin_compile_options_serdes = ([]() { + RegisterSerDes( + std::make_unique()); +}(), true); + +} // namespace + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/plugin_program_serdes_test.cc b/xla/python/ifrt/plugin_program_serdes_test.cc new file mode 100644 index 0000000000000..1168c691d5a03 --- /dev/null +++ b/xla/python/ifrt/plugin_program_serdes_test.cc @@ -0,0 +1,57 @@ +// Copyright 2024 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/serdes.pb.h" +#include "tsl/platform/statusor.h" +#include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::IsNull; +using ::testing::Not; + +TEST(PluginProgramSerDesTest, RoundTrip) { + PluginProgram orig; + orig.data = "foo"; + TF_ASSERT_OK_AND_ASSIGN(Serialized serialized, Serialize(orig)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr deserialized, + Deserialize(serialized, /*options=*/nullptr)); + + auto deserialized_program = llvm::dyn_cast(deserialized); + ASSERT_THAT(deserialized_program, Not(IsNull())); + EXPECT_EQ(deserialized_program->data, "foo"); +} + +TEST(PluginCompileOptionsSerDesTest, RoundTrip) { + PluginCompileOptions orig; + TF_ASSERT_OK_AND_ASSIGN(Serialized serialized, Serialize(orig)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr deserialized, + Deserialize(serialized, /*options=*/nullptr)); + ASSERT_THAT(llvm::dyn_cast(deserialized), + Not(IsNull())); +} + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/serdes.cc b/xla/python/ifrt/serdes.cc index bb05fad0e16c2..4131a50b40097 100644 --- a/xla/python/ifrt/serdes.cc +++ b/xla/python/ifrt/serdes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/python/ifrt/serdes.pb.h" @@ -87,7 +88,8 @@ absl::StatusOr Serialize(Serializable& serializable) { auto it = r->type_id_to_serdes.find(serializable.dynamicClassID()); if (it == r->type_id_to_serdes.end()) { return absl::UnimplementedError( - "Serializable has no associated SerDes implementation"); + "Serialize call failed. Serializable has no associated SerDes " + "implementation"); } serdes = it->second; } @@ -107,8 +109,9 @@ absl::StatusOr> Deserialize( absl::MutexLock l(&r->mu); auto it = r->name_to_serdes.find(serialized.type_name()); if (it == r->name_to_serdes.end()) { - return absl::UnimplementedError( - "Serializable has no associated SerDes implementation"); + return absl::UnimplementedError(absl::StrCat( + "Deserialize call failed. Serializable has no associated SerDes ", + "implementation. type_name: ", serialized.type_name())); } serdes = it->second; } diff --git a/xla/python/ifrt/serdes.h b/xla/python/ifrt/serdes.h index 3cb0d853eb9cb..f6d173489b543 100644 --- a/xla/python/ifrt/serdes.h +++ b/xla/python/ifrt/serdes.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/serdes.proto b/xla/python/ifrt/serdes.proto index 98ff64677d7e1..4693b645f7151 100644 --- a/xla/python/ifrt/serdes.proto +++ b/xla/python/ifrt/serdes.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/serdes_test.cc b/xla/python/ifrt/serdes_test.cc index edcb0c64b6eb4..6f80578d33051 100644 --- a/xla/python/ifrt/serdes_test.cc +++ b/xla/python/ifrt/serdes_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/shape.cc b/xla/python/ifrt/shape.cc index 7ee0e9accc46e..0d6fb7c1963cb 100644 --- a/xla/python/ifrt/shape.cc +++ b/xla/python/ifrt/shape.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,41 @@ limitations under the License. #include "xla/python/ifrt/shape.h" +#include #include #include #include +#include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "xla/python/ifrt/types.pb.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/shape.pb.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { -StatusOr Shape::FromProto(const ShapeProto& proto) { +namespace { + +// Helper type for the visitor. +template +struct overloaded : Ts... { + using Ts::operator()...; +}; + +// Explicit deduction guide. +template +overloaded(Ts...) -> overloaded; + +} // namespace + +absl::StatusOr Shape::FromProto(const ShapeProto& proto) { Shape::Dimensions dims; dims.reserve(proto.dims_size()); for (int64_t dim : proto.dims()) { @@ -60,9 +83,105 @@ std::string Shape::DebugString() const { return absl::StrCat("[", absl::StrJoin(dims_, ","), "]"); } +absl::StatusOr BoundedDynamicShapeTag::FromProto( + const BoundedDynamicShapeTagProto& proto) { + BoundedDynamicShapeTag::DynamicDimensions dynamic_dims; + dynamic_dims.reserve(proto.is_dynamic_dims_size()); + for (bool dynamic_dim : proto.is_dynamic_dims()) { + dynamic_dims.push_back(dynamic_dim); + } + return BoundedDynamicShapeTag(std::move(dynamic_dims)); +} + +BoundedDynamicShapeTagProto BoundedDynamicShapeTag::ToProto() const { + BoundedDynamicShapeTagProto proto; + proto.mutable_is_dynamic_dims()->Reserve(dynamic_dims_.size()); + for (bool dynamic_dim : dynamic_dims_) { + proto.mutable_is_dynamic_dims()->AddAlreadyReserved(dynamic_dim); + } + return proto; +} + +absl::StatusOr DynamicShape::Create(Shape shape, + DynamicShapeTag tag) { + TF_RETURN_IF_ERROR(std::visit( + overloaded{ + [&](const BoundedDynamicShapeTag& tag) -> absl::Status { + if (tag.DynamicDims().size() != shape.dims().size()) { + return InvalidArgument( + "Shape and tag must have the same number of dimensions."); + } + return xla::OkStatus(); + }, + }, + tag)); + return DynamicShape(std::move(shape), std::move(tag)); +} + +absl::StatusOr DynamicShape::GetPaddedShape() const { + return std::visit( + overloaded{ + [this](BoundedDynamicShapeTag tag) { return shape_; }, + }, + tag_); +} + +bool DynamicShape::IsDynamicDim(int dimension) const { + return std::visit( + overloaded{ + [dimension](BoundedDynamicShapeTag tag) { + return tag.DynamicDims().at(dimension); + }, + }, + tag_); +} + +absl::StatusOr DynamicShape::FromProto( + const DynamicShapeProto& proto) { + TF_ASSIGN_OR_RETURN(Shape shape, Shape::FromProto(proto.shape())); + if (proto.has_bounded_dynamic_shape_tag()) { + TF_ASSIGN_OR_RETURN( + BoundedDynamicShapeTag tag, + BoundedDynamicShapeTag::FromProto(proto.bounded_dynamic_shape_tag())); + return DynamicShape::Create(std::move(shape), std::move(tag)); + } + return InvalidArgument("Only support bounded dynamic shape."); +} + +DynamicShapeProto DynamicShape::ToProto() const { + DynamicShapeProto proto; + *proto.mutable_shape() = shape_.ToProto(); + std::visit( + overloaded{ + [&proto](BoundedDynamicShapeTag tag) { + *proto.mutable_bounded_dynamic_shape_tag() = tag.ToProto(); + }, + }, + tag_); + return proto; +} + +std::string DynamicShape::DebugString() const { + return std::visit( + overloaded{[this](BoundedDynamicShapeTag tag) { + absl::InlinedVector dim_reps; + dim_reps.reserve(shape_.dims().size()); + for (int i = 0; i < shape_.dims().size(); ++i) { + absl::string_view prefix = tag.DynamicDims()[i] ? "<=" : ""; + dim_reps.push_back(absl::StrCat(prefix, shape_.dims()[i])); + } + return absl::StrCat("[", absl::StrJoin(dim_reps, ","), "]"); + }}, + tag_); +} + std::ostream& operator<<(std::ostream& os, const Shape& shape) { return os << shape.DebugString(); } +std::ostream& operator<<(std::ostream& os, const DynamicShape& dynamic_shape) { + return os << dynamic_shape.DebugString(); +} + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/shape.h b/xla/python/ifrt/shape.h index 662197c8c47cd..56cbb5fbb82bd 100644 --- a/xla/python/ifrt/shape.h +++ b/xla/python/ifrt/shape.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,26 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_SHAPE_H_ #define XLA_PYTHON_IFRT_SHAPE_H_ +#include + #include #include #include +#include +#include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/types/span.h" -#include "xla/python/ifrt/types.pb.h" -#include "xla/statusor.h" +#include "xla/python/ifrt/shape.pb.h" namespace xla { namespace ifrt { -// Shape of an array. Only supports static shapes at the moment. Every dimension -// size must be equal to or greater than 0. +// Shape of an array. Only supports static shapes (dynamic shapes are supported +// through `ifrt::DynamicShape`). Every dimension size must be equal to or +// greater than 0. class Shape { public: // Maximum dimensions to inline. @@ -45,7 +51,7 @@ class Shape { Shape& operator=(Shape&&) = default; // Constructs `Shape` from `ShapeProto`. - static StatusOr FromProto(const ShapeProto& proto); + static absl::StatusOr FromProto(const ShapeProto& proto); // Returns a `ShapeProto` representation. ShapeProto ToProto() const; @@ -64,7 +70,103 @@ class Shape { Dimensions dims_; }; +// A tag for `Shape` to indicate bounded dynamism. Should be used together with +// `Shape` to represent a bounded dynamic shape where the number of dimensions +// of the shape is fixed, but certain dimensions in the shape have no fixed +// size and only a size upper bound. +class BoundedDynamicShapeTag { + public: + // Maximum dimensions to inline. + static constexpr int kInlineDimensionSize = 6; + + using DynamicDimensions = absl::InlinedVector; + + explicit BoundedDynamicShapeTag(absl::Span dynamic_dims) + : dynamic_dims_( + DynamicDimensions(dynamic_dims.begin(), dynamic_dims.end())) { + CHECK(absl::c_any_of(dynamic_dims_, [](bool b) { return b; })) + << "At least one dimension needs to be dynamically sized."; + } + + BoundedDynamicShapeTag(const BoundedDynamicShapeTag&) = default; + BoundedDynamicShapeTag(BoundedDynamicShapeTag&&) = default; + BoundedDynamicShapeTag& operator=(const BoundedDynamicShapeTag&) = default; + BoundedDynamicShapeTag& operator=(BoundedDynamicShapeTag&&) = default; + + absl::Span DynamicDims() const { return dynamic_dims_; } + + bool operator==(const BoundedDynamicShapeTag& other) const { + return dynamic_dims_ == other.dynamic_dims_; + } + + bool operator!=(const BoundedDynamicShapeTag& other) const { + return !(*this == other); + } + + // Constructs `BoundedDynamicShapeTag` from `BoundedDynamicShapeTagProto`. + static absl::StatusOr FromProto( + const BoundedDynamicShapeTagProto& proto); + + // Returns a `BoundedDynamicShapeTagProto` representation. + BoundedDynamicShapeTagProto ToProto() const; + + private: + // This vector is the same size as `Shape`'s 'dims()' and indicates whether + // the respective dimension is dynamically sized. + DynamicDimensions dynamic_dims_; +}; + +// Use static polymorphism to facilitate type checking. Currently only support +// one type of dynamism. +using DynamicShapeTag = std::variant; + +// Shape with dynamism in dimension sizes, etc. +class DynamicShape { + public: + // Constructs `DynamicShape` from `Shape` and `DynamicShapeTag`. Fails if + // the dimensions mismatch. + // + // When `tag` is a `BoundedDynamicShapeTag`: for any dimension that is dynamic + // as indicated by `tag`, the corresponding dimension in `shape` represents + // the upper bound of the dimension size. + static absl::StatusOr Create(Shape shape, DynamicShapeTag tag); + + DynamicShape(const DynamicShape&) = default; + DynamicShape(DynamicShape&&) = default; + DynamicShape& operator=(const DynamicShape&) = default; + DynamicShape& operator=(DynamicShape&&) = default; + + const DynamicShapeTag& GetTag() const { return tag_; } + + bool operator==(const DynamicShape& other) const { + return tag_ == other.tag_ && shape_ == other.shape_; + } + bool operator!=(const DynamicShape& other) const { return !(*this == other); } + + // Gets the shape after padding. Only works for bounded dynamic shape for now. + absl::StatusOr GetPaddedShape() const; + + // Returns whether a certain dimension in the shape is dynamic. + bool IsDynamicDim(int dimension) const; + + // Constructs `DynamicShape` from `DynamicShapeProto`. + static absl::StatusOr FromProto(const DynamicShapeProto& proto); + + // Returns a `DynamicShapeProto` representation. + DynamicShapeProto ToProto() const; + + std::string DebugString() const; + + private: + DynamicShape(Shape shape, DynamicShapeTag tag) + : shape_(std::move(shape)), tag_(std::move(tag)) {} + + Shape shape_; + DynamicShapeTag tag_; +}; + std::ostream& operator<<(std::ostream& os, const Shape& shape); +std::ostream& operator<<(std::ostream& os, const DynamicShape& dynamic_shape); } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/shape.proto b/xla/python/ifrt/shape.proto new file mode 100644 index 0000000000000..354383dbd43a2 --- /dev/null +++ b/xla/python/ifrt/shape.proto @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for `Shape`. Currently support static shapes with all dimension +// sizes greater than or equal to 0. +message ShapeProto { + repeated int64 dims = 1; +} + +// Wire format for `BoundedDynamicShapeTag`. +message BoundedDynamicShapeTagProto { + repeated bool is_dynamic_dims = 1; +} + +// Wire format for `DynamicShape`. Currently only support bounded dynamic shape. +message DynamicShapeProto { + ShapeProto shape = 1; + oneof tag { + BoundedDynamicShapeTagProto bounded_dynamic_shape_tag = 2; + } +} diff --git a/xla/python/ifrt/shape_test.cc b/xla/python/ifrt/shape_test.cc index 2c007fb062e7a..1477f842a7952 100644 --- a/xla/python/ifrt/shape_test.cc +++ b/xla/python/ifrt/shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,21 +15,31 @@ limitations under the License. #include "xla/python/ifrt/shape.h" +#include #include #include +#include #include #include #include +#include "absl/status/status.h" +#include "xla/python/ifrt/shape.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { namespace { +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::tsl::testing::StatusIs; + TEST(ShapeTest, LargeDim) { Shape shape({std::numeric_limits::max()}); - EXPECT_THAT(shape.dims(), - testing::ElementsAre(std::numeric_limits::max())); + EXPECT_THAT(shape.dims(), ElementsAre(std::numeric_limits::max())); } TEST(ShapeTest, ManyDims) { @@ -37,7 +47,7 @@ TEST(ShapeTest, ManyDims) { std::vector dims(kNumDims); std::iota(dims.begin(), dims.end(), 0); Shape shape(dims); - EXPECT_THAT(shape.dims(), testing::ElementsAreArray(dims)); + EXPECT_THAT(shape.dims(), ElementsAreArray(dims)); } TEST(ShapeTest, ScalarNumElements) { @@ -75,6 +85,115 @@ TEST(ShapeTest, NonZeroDimsNumElements) { } } +TEST(ShapeTest, ToFromProto) { + { + Shape shape({}); + ShapeProto proto = shape.ToProto(); + TF_ASSERT_OK_AND_ASSIGN(Shape shape_copy, shape.FromProto(proto)); + EXPECT_EQ(shape_copy, shape); + } + { + Shape shape({1, 2}); + ShapeProto proto = shape.ToProto(); + TF_ASSERT_OK_AND_ASSIGN(Shape shape_copy, shape.FromProto(proto)); + EXPECT_EQ(shape_copy, shape); + } +} + +TEST(BoundedDynamicShapeTagDeathTest, NoDynamicDim) { + EXPECT_DEATH(BoundedDynamicShapeTag tag({false, false}), + "At least one dimension needs to be dynamically sized"); +} + +TEST(BoundedDynamicShapeTagTest, ToFromProto) { + BoundedDynamicShapeTag tag({true, false}); + BoundedDynamicShapeTagProto proto = tag.ToProto(); + TF_ASSERT_OK_AND_ASSIGN(BoundedDynamicShapeTag tag_copy, + tag.FromProto(proto)); + EXPECT_EQ(tag_copy, tag); +} + +TEST(DynamicShapeTest, SizeMismatch) { + Shape shape({1, 2, 3}); + BoundedDynamicShapeTag tag({true, true}); + EXPECT_THAT(DynamicShape::Create(shape, tag), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must have the same number of dimensions"))); +} + +TEST(DynamicShapeTest, Equality) { + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape1, + DynamicShape::Create(Shape({2, 4}), + BoundedDynamicShapeTag({true, false}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape2, + DynamicShape::Create(Shape({3, 4}), + BoundedDynamicShapeTag({true, false}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape3, + DynamicShape::Create(Shape({2, 4}), + BoundedDynamicShapeTag({true, true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape4, + DynamicShape::Create(Shape({2, 4, 3}), + BoundedDynamicShapeTag({true, false, true}))); + EXPECT_EQ(shape1, shape1); + EXPECT_NE(shape1, shape2); + EXPECT_NE(shape1, shape3); + EXPECT_NE(shape1, shape4); +} + +TEST(DynamicShapeTest, IsDynamicDim) { + Shape shape({1, 2, 3}); + BoundedDynamicShapeTag tag({true, false, true}); + TF_ASSERT_OK_AND_ASSIGN(DynamicShape dynamic_shape, + DynamicShape::Create(shape, tag)); + EXPECT_TRUE(dynamic_shape.IsDynamicDim(0)); + EXPECT_FALSE(dynamic_shape.IsDynamicDim(1)); + EXPECT_TRUE(dynamic_shape.IsDynamicDim(2)); +} + +TEST(DynamicShapeTest, GetPaddedShape) { + Shape shape({1, 2, 3}); + BoundedDynamicShapeTag tag({true, true, true}); + TF_ASSERT_OK_AND_ASSIGN(DynamicShape dynamic_shape, + DynamicShape::Create(shape, tag)); + TF_ASSERT_OK_AND_ASSIGN(Shape padded_shape, dynamic_shape.GetPaddedShape()); + EXPECT_EQ(padded_shape, shape); +} + +TEST(DynamicShapeTest, ToFromProto) { + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape, + DynamicShape::Create(Shape({2, 4}), + BoundedDynamicShapeTag({true, false}))); + DynamicShapeProto proto = shape.ToProto(); + TF_ASSERT_OK_AND_ASSIGN(DynamicShape shape_copy, shape.FromProto(proto)); + EXPECT_EQ(shape_copy, shape); +} + +TEST(DynamicShapeTest, ToString) { + { + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape, + DynamicShape::Create(Shape({2, 4}), + BoundedDynamicShapeTag({true, true}))); + std::ostringstream output; + output << shape; + EXPECT_EQ(output.str(), "[<=2,<=4]"); + } + { + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shape, + DynamicShape::Create(Shape({2, 4}), + BoundedDynamicShapeTag({false, true}))); + std::ostringstream output; + output << shape; + EXPECT_EQ(output.str(), "[2,<=4]"); + } +} + } // namespace } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/sharding.cc b/xla/python/ifrt/sharding.cc index 3dbea364c4e0c..5fbc4727055a4 100644 --- a/xla/python/ifrt/sharding.cc +++ b/xla/python/ifrt/sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,19 +21,26 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/memory.h" -#include "xla/statusor.h" +#include "xla/python/ifrt/shape.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { @@ -131,8 +138,8 @@ std::vector GetTileIndices(absl::Span dim_shards) { // Returns the tile shape after disassembling `shape` with `sharding_param`. // // Fails if can't shard evenly. -StatusOr GetDisassembledShape(const ShardingParam& sharding_param, - const Shape& shape) { +absl::StatusOr GetDisassembledShape(const ShardingParam& sharding_param, + const Shape& shape) { std::vector dims; dims.reserve(shape.dims().size()); for (const auto [dim, dim_shards] : @@ -166,17 +173,22 @@ std::unique_ptr SingleDeviceSharding::Create( new SingleDeviceSharding(device, memory_kind)); } -StatusOr>>> +absl::StatusOr>>> SingleDeviceSharding::Disassemble(const Shape& shape) const { DCHECK(this); - std::vector>> result; - result.reserve(1); - result.push_back( - {shape, SingleDeviceSharding::Create(devices_[0], memory_kind_)}); - return result; + return std::vector>>{ + {shape, SingleDeviceSharding::Create(devices_[0], memory_kind_)}}; +} + +absl::StatusOr< + std::vector>>> +SingleDeviceSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + return std::vector>>{ + {dynamic_shape, SingleDeviceSharding::Create(devices_[0], memory_kind_)}}; } -StatusOr> SingleDeviceSharding::IndexDomains( +absl::StatusOr> SingleDeviceSharding::IndexDomains( const Shape& shape) const { DCHECK(this); std::vector result; @@ -202,14 +214,22 @@ OpaqueSharding::OpaqueSharding(DeviceList devices, MemoryKind memory_kind) : llvm::RTTIExtends(std::move(devices), memory_kind) {} -StatusOr>>> +absl::StatusOr>>> OpaqueSharding::Disassemble(const Shape& shape) const { DCHECK(this); return InvalidArgument( "OpaqueSharding does not have shard shape information"); } -StatusOr> OpaqueSharding::IndexDomains( +absl::StatusOr< + std::vector>>> +OpaqueSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + return InvalidArgument( + "OpaqueSharding does not have shard shape information"); +} + +absl::StatusOr> OpaqueSharding::IndexDomains( const Shape& shape) const { DCHECK(this); return InvalidArgument( @@ -236,6 +256,15 @@ std::unique_ptr ConcreteSharding::Create( std::move(shard_shapes))); } +std::unique_ptr ConcreteSharding::Create( + DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, + std::vector shard_dynamic_shapes) { + CHECK_EQ(devices.size(), shard_dynamic_shapes.size()); + return std::unique_ptr(new ConcreteSharding( + std::move(devices), memory_kind, std::move(dynamic_shape), + std::move(shard_dynamic_shapes))); +} + ConcreteSharding::ConcreteSharding(DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes) : llvm::RTTIExtends(std::move(devices), @@ -243,25 +272,69 @@ ConcreteSharding::ConcreteSharding(DeviceList devices, MemoryKind memory_kind, shape_(std::move(shape)), shard_shapes_(std::move(shard_shapes)) {} -StatusOr>>> +ConcreteSharding::ConcreteSharding( + DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, + std::vector shard_dynamic_shapes) + : llvm::RTTIExtends(std::move(devices), + memory_kind), + shape_(std::move(dynamic_shape)), + shard_shapes_(std::move(shard_dynamic_shapes)) {} + +absl::StatusOr>>> ConcreteSharding::Disassemble(const Shape& shape) const { DCHECK(this); - if (shape != shape_) { + if (!has_static_shape()) { + return InvalidArgument( + "ConcreteSharding holds dynamic shape, but was asked " + "to disassemble static shape %s", + shape.DebugString()); + } + if (shape != std::get(shape_)) { return InvalidArgument( "ConcreteSharding can only disassemble shape %s, but was asked " "to disassemble shape %s", - shape_.DebugString(), shape.DebugString()); + std::get(shape_).DebugString(), shape.DebugString()); } std::vector>> result; result.reserve(devices_.size()); + const std::vector& shard_shapes = + std::get>(shard_shapes_); + for (int i = 0; i < devices_.size(); ++i) { + result.push_back({shard_shapes[i], + SingleDeviceSharding::Create(devices_[i], memory_kind_)}); + } + return result; +} + +absl::StatusOr< + std::vector>>> +ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + if (!has_dynamic_shape()) { + return InvalidArgument( + "ConcreteSharding holds static shape, but was asked " + "to disassemble dynamic shape %s", + dynamic_shape.DebugString()); + } + if (dynamic_shape != std::get(shape_)) { + return InvalidArgument( + "ConcreteSharding can only disassemble dynamic shape %s, but was asked " + "to disassemble dynamic shape %s", + std::get(shape_).DebugString(), + dynamic_shape.DebugString()); + } + std::vector>> result; + result.reserve(devices_.size()); + const std::vector& shard_dynamic_shapes = + std::get>(shard_shapes_); for (int i = 0; i < devices_.size(); ++i) { - result.push_back({shard_shapes_[i], + result.push_back({shard_dynamic_shapes[i], SingleDeviceSharding::Create(devices_[i], memory_kind_)}); } return result; } -StatusOr> ConcreteSharding::IndexDomains( +absl::StatusOr> ConcreteSharding::IndexDomains( const Shape& shape) const { DCHECK(this); return InvalidArgument( @@ -270,19 +343,23 @@ StatusOr> ConcreteSharding::IndexDomains( std::string ConcreteSharding::DebugString() const { DCHECK(this); - return absl::StrFormat( - "ConcreteSharding(devices: %s, shape: %s, shard_shapes: %s, memory_kind: " - "%s)", - absl::StrJoin(devices_, ",", - [](std::string* out, const Device* device) { - absl::StrAppend(out, device->ToString()); - }), - shape_.DebugString(), - absl::StrJoin(shard_shapes_, ",", - [](std::string* out, const Shape& shard_shape) { - absl::StrAppend(out, shard_shape.DebugString()); - }), - memory_kind_.DebugString()); + return std::visit( + [this](const auto& shape, const auto& shard_shapes) { + return absl::StrFormat( + "ConcreteSharding(devices: %s, shape: %s, shard_shapes: %s, " + "memory_kind: %s)", + absl::StrJoin(devices_, ",", + [](std::string* out, const Device* device) { + absl::StrAppend(out, device->ToString()); + }), + shape.DebugString(), + absl::StrJoin(shard_shapes, ",", + [](std::string* out, const auto& shard_shape) { + absl::StrAppend(out, shard_shape.DebugString()); + }), + memory_kind_.DebugString()); + }, + shape_, shard_shapes_); } std::unique_ptr ConcreteEvenSharding::Create( @@ -301,7 +378,7 @@ ConcreteEvenSharding::ConcreteEvenSharding(DeviceList devices, shape_(std::move(shape)), shard_shape_(std::move(shard_shape)) {} -StatusOr>>> +absl::StatusOr>>> ConcreteEvenSharding::Disassemble(const Shape& shape) const { DCHECK(this); if (shape != shape_) { @@ -319,7 +396,16 @@ ConcreteEvenSharding::Disassemble(const Shape& shape) const { return result; } -StatusOr> ConcreteEvenSharding::IndexDomains( +absl::StatusOr< + std::vector>>> +ConcreteEvenSharding::Disassemble(const DynamicShape& dynamic_shape) const { + return InvalidArgument( + "ConcreteEvenSharding can only disassemble static shape, but was asked " + "to disassemble dynamic shape %s", + dynamic_shape.DebugString()); +} + +absl::StatusOr> ConcreteEvenSharding::IndexDomains( const Shape& shape) const { DCHECK(this); return InvalidArgument( @@ -339,8 +425,9 @@ std::string ConcreteEvenSharding::DebugString() const { memory_kind_.DebugString()); } -StatusOr> ShardingParamSharding::Create( - ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind) { +absl::StatusOr> +ShardingParamSharding::Create(ShardingParam sharding_param, DeviceList devices, + MemoryKind memory_kind) { int64_t device_count = absl::c_accumulate(sharding_param.minor_to_major().axis_sizes, 1, std::multiplies()); @@ -354,7 +441,7 @@ StatusOr> ShardingParamSharding::Create( std::move(sharding_param), std::move(devices), memory_kind)); } -StatusOr>>> +absl::StatusOr>>> ShardingParamSharding::Disassemble(const Shape& shape) const { DCHECK(this); if (shape.dims().size() != sharding_param_.dim_shards().size()) { @@ -375,7 +462,16 @@ ShardingParamSharding::Disassemble(const Shape& shape) const { return result; } -StatusOr> ShardingParamSharding::IndexDomains( +absl::StatusOr< + std::vector>>> +ShardingParamSharding::Disassemble(const DynamicShape& dynamic_shape) const { + return InvalidArgument( + "ShardingParamSharding can only disassemble static shape, but was asked " + "to disassemble dynamic shape %s", + dynamic_shape.DebugString()); +} + +absl::StatusOr> ShardingParamSharding::IndexDomains( const Shape& shape) const { DCHECK(this); diff --git a/xla/python/ifrt/sharding.h b/xla/python/ifrt/sharding.h index 67d9607d2409c..71561394a31a1 100644 --- a/xla/python/ifrt/sharding.h +++ b/xla/python/ifrt/sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/check.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/index_domain.h" @@ -29,7 +31,7 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/shape.h" -#include "xla/statusor.h" +#include "xla/python/ifrt/sharding.pb.h" namespace xla { namespace ifrt { @@ -54,16 +56,21 @@ class Sharding : public llvm::RTTIExtends { // Breaks a shape up into per-device shapes and shardings. See // Array::DisassembleIntoSingleDeviceArrays(). It may return an error if // disassembly is unsupported. - virtual StatusOr< + virtual absl::StatusOr< std::vector>>> Disassemble(const Shape& shape) const = 0; + // Variant of `Disassemble` that takes a dynamic shape. + virtual absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const = 0; + // Maps each shard to an `IndexDomain` over `shape`. The result is a list of // `index_domain_i` such that `array[index_domain_i] = disassembled_array_i`. // Note that multiple shards may map onto equal `IndexDomain`. For instance, a // fully replicated sharding would return a vector of `[IndexDomain(shape)] * // devices().size()`. - virtual StatusOr> IndexDomains( + virtual absl::StatusOr> IndexDomains( const Shape& shape) const = 0; virtual std::string DebugString() const = 0; @@ -80,6 +87,9 @@ class Sharding : public llvm::RTTIExtends { std::ostream& operator<<(std::ostream& os, const Sharding& sharding); +// TODO(hyeontaek): Move the subclasses of `Sharding` to a seperate file, +// making this sharding.{h,cc} only define interface and common functions. + // Single-device sharding. // // TODO(hyeontaek): `SingleDeviceSharding` tends to be created or consumed in a @@ -96,10 +106,14 @@ class SingleDeviceSharding final ~SingleDeviceSharding() override = default; - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; - StatusOr> IndexDomains( + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; + + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; @@ -124,10 +138,14 @@ class OpaqueSharding : public llvm::RTTIExtends { ~OpaqueSharding() override = default; - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; - StatusOr> IndexDomains( + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; + + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; @@ -145,28 +163,63 @@ class OpaqueSharding : public llvm::RTTIExtends { class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. - // REQUIRES: devices.size() == shard_shapes.size() + // REQUIRES: `devices`.size() == `shard_shapes`.size() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes); - Shape shape() const { + // Creates a concrete sharding that may contain non-identical shard dynamic + // shapes. + // REQUIRES: `devices`.size() == `shard_dynamic_shapes`.size() + static std::unique_ptr Create( + DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, + std::vector shard_dynamic_shapes); + + bool has_dynamic_shape() const { DCHECK(this); - return shape_; + return std::holds_alternative(shape_) && + std::holds_alternative>(shard_shapes_); + } + + bool has_static_shape() const { + DCHECK(this); + return std::holds_alternative(shape_) && + std::holds_alternative>(shard_shapes_); } + + const Shape& shape() const { + DCHECK(has_static_shape()); + return std::get(shape_); + } + + const DynamicShape& dynamic_shape() const { + DCHECK(has_dynamic_shape()); + return std::get(shape_); + } + const std::vector& shard_shapes() const { DCHECK(this); - return shard_shapes_; + DCHECK(std::holds_alternative>(shard_shapes_)); + return std::get>(shard_shapes_); + } + + const std::vector& shard_dynamic_shapes() const { + DCHECK(this); + DCHECK(std::holds_alternative>(shard_shapes_)); + return std::get>(shard_shapes_); } // Sharding implementation. ~ConcreteSharding() override = default; - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; - StatusOr> IndexDomains( + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; @@ -177,8 +230,12 @@ class ConcreteSharding : public llvm::RTTIExtends { ConcreteSharding(DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes); - Shape shape_; - std::vector shard_shapes_; + ConcreteSharding(DeviceList devices, MemoryKind memory_kind, + DynamicShape dynamic_shape, + std::vector shard_dynamic_shapes); + + std::variant shape_; + std::variant, std::vector> shard_shapes_; }; // Opaque sharding that does not define a fixed semantics for conversion between @@ -206,10 +263,13 @@ class ConcreteEvenSharding ~ConcreteEvenSharding() override = default; - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; - StatusOr> IndexDomains( + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; @@ -228,15 +288,18 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: - static StatusOr> Create( + static absl::StatusOr> Create( ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind); const ShardingParam& sharding_param() const { return sharding_param_; } - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; - StatusOr> IndexDomains( + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; diff --git a/xla/python/ifrt/sharding.proto b/xla/python/ifrt/sharding.proto index 65b481deed59b..5a6c1484922d0 100644 --- a/xla/python/ifrt/sharding.proto +++ b/xla/python/ifrt/sharding.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,34 +17,10 @@ syntax = "proto3"; package xla.ifrt; -import "xla/python/ifrt/types.proto"; +import "xla/python/ifrt/serdes.proto"; -// Wire format for `SingleDeviceSharding`. -message SingleDeviceShardingProto { - // Serialization and deserialization are expected to ensure that device ids - // are stable across proto construction and consumption. - int32 device_id = 1; - optional string memory_kind = 2; -} - -// Wire format for `OpaqueSharding`. -message OpaqueShardingProto { - DeviceListProto devices = 1; - optional string memory_kind = 2; -} - -// Wire format for `ConcreteSharding`. -message ConcreteShardingProto { - DeviceListProto devices = 1; - optional string memory_kind = 4; - ShapeProto shape = 2; - repeated ShapeProto shard_shapes = 3; -} - -// Wire format for `ConcreteEvenSharding`. -message ConcreteEvenShardingProto { - DeviceListProto devices = 1; - optional string memory_kind = 4; - ShapeProto shape = 2; - ShapeProto shard_shape = 3; +// Wire format for `Sharding`. A suitable serializer and deserializer +// implementation must be registered. +message ShardingProto { + xla.ifrt.Serialized serialized_sharding = 1; } diff --git a/xla/python/ifrt/sharding_serdes.cc b/xla/python/ifrt/sharding_serdes.cc index fe88e5f97f0df..e7c237dbe65fb 100644 --- a/xla/python/ifrt/sharding_serdes.cc +++ b/xla/python/ifrt/sharding_serdes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,12 +20,18 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding.pb.h" +#include "xla/python/ifrt/sharding_serdes.pb.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -35,6 +41,9 @@ char DeserializeShardingOptions::ID = 0; namespace { +// TODO(hyeontaek): Move SerDes for the subclasses of `Sharding` to a separate +// file, making this sharding_serdes.{h,cc} only define common functions. + // Serialization/deserialization for `SingleDeviceSharding`. class SingleDeviceShardingSerDes : public llvm::RTTIExtends { @@ -136,9 +145,17 @@ class ConcreteShardingSerDes if (sharding.memory_kind().memory_kind().has_value()) { proto.set_memory_kind(std::string(*sharding.memory_kind().memory_kind())); } - *proto.mutable_shape() = sharding.shape().ToProto(); - for (const Shape& shape : sharding.shard_shapes()) { - *proto.add_shard_shapes() = shape.ToProto(); + if (sharding.has_static_shape()) { + *proto.mutable_shape() = sharding.shape().ToProto(); + for (const Shape& shape : sharding.shard_shapes()) { + *proto.add_shard_shapes() = shape.ToProto(); + } + } else { + *proto.mutable_dynamic_shape() = sharding.dynamic_shape().ToProto(); + for (const DynamicShape& dynamic_shape : + sharding.shard_dynamic_shapes()) { + *proto.add_shard_dynamic_shapes() = dynamic_shape.ToProto(); + } } return proto.SerializeAsString(); } @@ -162,16 +179,35 @@ class ConcreteShardingSerDes if (proto.has_memory_kind()) { memory_kind = MemoryKind(proto.memory_kind()); } - TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); - std::vector shard_shapes; - shard_shapes.reserve(proto.shard_shapes_size()); - for (const auto& shard_shape_proto : proto.shard_shapes()) { - TF_ASSIGN_OR_RETURN(auto shard_shape, - Shape::FromProto(shard_shape_proto)); - shard_shapes.push_back(std::move(shard_shape)); + if (proto.has_shape()) { + TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); + std::vector shard_shapes; + shard_shapes.reserve(proto.shard_shapes_size()); + for (const auto& shard_shape_proto : proto.shard_shapes()) { + TF_ASSIGN_OR_RETURN(auto shard_shape, + Shape::FromProto(shard_shape_proto)); + shard_shapes.push_back(std::move(shard_shape)); + } + return ConcreteSharding::Create(std::move(devices), memory_kind, + std::move(shape), + std::move(shard_shapes)); + } + if (!proto.has_dynamic_shape()) { + return absl::InvalidArgumentError( + "ConcreteSharding must have Shape or DynamicShape."); + } + TF_ASSIGN_OR_RETURN(auto dynamic_shape, + DynamicShape::FromProto(proto.dynamic_shape())); + std::vector shard_dynamic_shapes; + shard_dynamic_shapes.reserve(proto.shard_dynamic_shapes_size()); + for (const auto& shard_dynamic_shape_proto : proto.shard_dynamic_shapes()) { + TF_ASSIGN_OR_RETURN(auto dynamic_shape, + DynamicShape::FromProto(shard_dynamic_shape_proto)); + shard_dynamic_shapes.push_back(std::move(dynamic_shape)); } return ConcreteSharding::Create(std::move(devices), memory_kind, - std::move(shape), std::move(shard_shapes)); + std::move(dynamic_shape), + std::move(shard_dynamic_shapes)); } static char ID; // NOLINT @@ -259,7 +295,7 @@ bool register_concrete_even_sharding_serdes = ([]{ } // namespace -StatusOr> +absl::StatusOr> GetDeserializeShardingOptions(std::unique_ptr options) { if (!llvm::isa(options.get())) { return xla::InvalidArgument("options must be DeserializeShardingOptions"); @@ -268,5 +304,25 @@ GetDeserializeShardingOptions(std::unique_ptr options) { static_cast(options.release())); } +// TODO(hyeontaek): Move this common logic into Sharding::FromProto() and +// Sharding::ToProto(). + +absl::StatusOr> FromShardingProto( + DeviceList::LookupDeviceFunc lookup_device, + const ShardingProto& sharding_proto) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + Deserialize(sharding_proto.serialized_sharding(), + std::make_unique( + std::move(lookup_device)))); + return std::unique_ptr(llvm::cast(sharding.release())); +} + +absl::StatusOr ToShardingProto(const Sharding& sharding) { + ShardingProto sharding_proto; + TF_ASSIGN_OR_RETURN(*sharding_proto.mutable_serialized_sharding(), + Serialize(const_cast(sharding))); + return sharding_proto; +} + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/sharding_serdes.h b/xla/python/ifrt/sharding_serdes.h index 49762102c8635..e5a53237a49a2 100644 --- a/xla/python/ifrt/sharding_serdes.h +++ b/xla/python/ifrt/sharding_serdes.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,12 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/serdes.h" -#include "xla/statusor.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding.pb.h" namespace xla { namespace ifrt { @@ -43,9 +45,20 @@ struct DeserializeShardingOptions }; // Casts `DeserializeOptions` into `DeserializeShardingOptions`. -StatusOr> +absl::StatusOr> GetDeserializeShardingOptions(std::unique_ptr options); +// TODO(hyeontaek): Remove these functions from xla::ifrt, once migration to +// Sharding::FromProto() and Sharding::ToProto() is done. + +// Deserializes `ShardingProto` into `Sharding`. +absl::StatusOr> FromShardingProto( + DeviceList::LookupDeviceFunc lookup_device, + const ShardingProto& sharding_proto); + +// Serializes `Sharding` into `ShardingProto`. +absl::StatusOr ToShardingProto(const Sharding& sharding); + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/sharding_serdes.proto b/xla/python/ifrt/sharding_serdes.proto new file mode 100644 index 0000000000000..dcd8fc8fe6b65 --- /dev/null +++ b/xla/python/ifrt/sharding_serdes.proto @@ -0,0 +1,55 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/device.proto"; +import "xla/python/ifrt/shape.proto"; + +// Wire format for `SingleDeviceSharding`. +message SingleDeviceShardingProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + int32 device_id = 1; + optional string memory_kind = 2; +} + +// Wire format for `OpaqueSharding`. +message OpaqueShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 2; +} + +// Wire format for `ConcreteSharding`. +message ConcreteShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 4; + oneof shape_or_dynamic_shape { + ShapeProto shape = 2; + DynamicShapeProto dynamic_shape = 5; + } + repeated ShapeProto shard_shapes = 3; + repeated DynamicShapeProto shard_dynamic_shapes = 6; +} + +// Wire format for `ConcreteEvenSharding`. +message ConcreteEvenShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 4; + ShapeProto shape = 2; + ShapeProto shard_shape = 3; +} diff --git a/xla/python/ifrt/sharding_serdes_test.cc b/xla/python/ifrt/sharding_serdes_test.cc index 1e94a07cae6de..28f04c98d0e96 100644 --- a/xla/python/ifrt/sharding_serdes_test.cc +++ b/xla/python/ifrt/sharding_serdes_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,9 +21,14 @@ limitations under the License. #include #include #include "absl/functional/bind_front.h" +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/serdes.pb.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/sharding_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { @@ -90,6 +95,41 @@ TEST_P(ShardingSerDesTest, ConcreteShardingRoundTrip) { ElementsAreArray(sharding->shard_shapes())); } +TEST_P(ShardingSerDesTest, ConcreteShardingWithDynamicShapeRoundTrip) { + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape dynamic_shape, + DynamicShape::Create(Shape({10, 20}), + BoundedDynamicShapeTag({false, true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape1, + DynamicShape::Create(Shape({3, 20}), + BoundedDynamicShapeTag({false, true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape2, + DynamicShape::Create(Shape({7, 20}), + BoundedDynamicShapeTag({false, true}))); + auto sharding = ConcreteSharding::Create( + GetDevices({0, 1}), MemoryKind("abc"), + /*dynamic_shape=*/dynamic_shape, + /*shard_dynamic_shapes=*/{shard_dynamic_shape1, shard_dynamic_shape2}); + + TF_ASSERT_OK_AND_ASSIGN(Serialized serialized, Serialize(*sharding)); + + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized, + Deserialize(serialized, + std::make_unique( + absl::bind_front(&Client::LookupDevice, client())))); + + const auto* out_sharding = + llvm::dyn_cast(deserialized.get()); + ASSERT_NE(out_sharding, nullptr); + EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices())); + EXPECT_THAT(out_sharding->dynamic_shape(), sharding->dynamic_shape()); + EXPECT_THAT(out_sharding->shard_dynamic_shapes(), + ElementsAreArray(sharding->shard_dynamic_shapes())); +} + TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { auto sharding = ConcreteEvenSharding::Create(GetDevices({0, 1}), MemoryKind("abc"), @@ -114,7 +154,8 @@ TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 2, .num_addressable_devices = 2})); + /*num_devices=*/2, + /*num_addressable_devices=*/2})); } // namespace } // namespace ifrt diff --git a/xla/python/ifrt/sharding_test.cc b/xla/python/ifrt/sharding_test.cc index f3997f5288daa..9be2bf217d16d 100644 --- a/xla/python/ifrt/sharding_test.cc +++ b/xla/python/ifrt/sharding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,11 @@ limitations under the License. #include #include "llvm/Support/Casting.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/index.h" +#include "xla/python/ifrt/index_domain.h" #include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" @@ -60,15 +64,32 @@ TEST_P(SingleDeviceShardingTest, Disassemble) { std::shared_ptr sharding = SingleDeviceSharding::Create(device_list.devices().front(), MemoryKind()); - Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(1)); - const auto& [result_shape, result_sharding] = disassembled[0]; - ASSERT_EQ(shape, result_shape); - ASSERT_TRUE(llvm::isa(*result_sharding)); - EXPECT_THAT(result_sharding->devices().devices(), - ElementsAreArray(device_list.devices())); + { // Disassemble static shape. + Shape shape({10, 20}); + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + ASSERT_EQ(shape, result_shape); + ASSERT_TRUE(llvm::isa(*result_sharding)); + EXPECT_THAT(result_sharding->devices().devices(), + ElementsAreArray(device_list.devices())); + } + { // Disassemble dynamic shape. + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape dynamic_shape, + DynamicShape::Create(Shape({10, 20}), + BoundedDynamicShapeTag({true, true}))); + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(dynamic_shape)); + + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + ASSERT_EQ(dynamic_shape, result_shape); + ASSERT_TRUE(llvm::isa(*result_sharding)); + EXPECT_THAT(result_sharding->devices().devices(), + ElementsAreArray(device_list.devices())); + } } TEST_P(OpaqueShardingTest, FailedToDisassemble) { @@ -81,6 +102,15 @@ TEST_P(OpaqueShardingTest, FailedToDisassemble) { StatusIs( tsl::error::INVALID_ARGUMENT, HasSubstr("OpaqueSharding does not have shard shape information"))); + + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape dynamic_shape, + DynamicShape::Create(Shape({30}), BoundedDynamicShapeTag({true}))); + EXPECT_THAT( + sharding->Disassemble(dynamic_shape), + StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("OpaqueSharding does not have shard shape information"))); } TEST_P(OpaqueShardingTest, IndexDomainsFails) { @@ -116,6 +146,36 @@ TEST_P(ConcreteShardingTest, Disassemble) { } } +TEST_P(ConcreteShardingTest, DisassembleDynamicShape) { + DeviceList device_list = GetDevices({0, 1}); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape dynamic_shape, + DynamicShape::Create(Shape({10}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape1, + DynamicShape::Create(Shape({3}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape2, + DynamicShape::Create(Shape({7}), BoundedDynamicShapeTag({true}))); + std::vector shard_dynamic_shapes{ + std::move(shard_dynamic_shape1), std::move(shard_dynamic_shape2)}; + auto sharding = ConcreteSharding::Create(device_list, MemoryKind(), + dynamic_shape, shard_dynamic_shapes); + EXPECT_THAT(sharding->Disassemble(Shape({10})), + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("ConcreteSharding holds dynamic shape"))); + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(DynamicShape(dynamic_shape))); + ASSERT_THAT(disassembled, SizeIs(2)); + for (int i = 0; i < disassembled.size(); ++i) { + const auto& [dynamic_shape, sharding] = disassembled[i]; + EXPECT_EQ(dynamic_shape, shard_dynamic_shapes[i]); + EXPECT_TRUE(llvm::isa(*sharding)); + EXPECT_THAT(sharding->devices().devices(), + ElementsAre(device_list.devices()[i])); + } +} + TEST_P(ConcreteShardingTest, DisassembleFailsForUnexpectedShape) { auto device_list = GetDevices({0, 1}); std::vector shard_shapes; @@ -306,19 +366,24 @@ TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { INSTANTIATE_TEST_SUITE_P(NumDevices, SingleDeviceShardingTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 6})); + /*num_devices=*/6, + /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, OpaqueShardingTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 6})); + /*num_devices=*/6, + /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteShardingTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 6})); + /*num_devices=*/6, + /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteEvenShardingTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 6})); + /*num_devices=*/6, + /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamShardingTest, testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 4})); + /*num_devices=*/6, + /*num_addressable_devices=*/4})); } // namespace } // namespace ifrt diff --git a/xla/python/ifrt/sharding_test_util.cc b/xla/python/ifrt/sharding_test_util.cc index c145ec90d6bdb..2f97ce7a738af 100644 --- a/xla/python/ifrt/sharding_test_util.cc +++ b/xla/python/ifrt/sharding_test_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "xla/pjrt/pjrt_common.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/mock.h" #include "xla/python/ifrt/test_util.h" @@ -52,14 +53,16 @@ std::shared_ptr MakeShardingTestClient( for (int i = 0; i < num_addressable_devices; ++i) { auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, global_device_id) + .WillByDefault(Return(PjRtGlobalDeviceId(i + 10))); ON_CALL(*device, IsAddressable).WillByDefault(Return(true)); state->devices.push_back(device.get()); state->device_map.insert({i + 10, std::move(device)}); } for (int i = num_addressable_devices; i < num_devices; ++i) { auto device = std::make_unique(); - ON_CALL(*device, id).WillByDefault(Return(i + 10)); + ON_CALL(*device, global_device_id) + .WillByDefault(Return(PjRtGlobalDeviceId(i + 10))); ON_CALL(*device, IsAddressable).WillByDefault(Return(false)); state->devices.push_back(device.get()); state->device_map.insert({i + 10, std::move(device)}); @@ -70,7 +73,7 @@ std::shared_ptr MakeShardingTestClient( .WillByDefault( [state]() -> absl::Span { return state->devices; }); ON_CALL(*client, LookupDevice) - .WillByDefault([state](int device_id) -> StatusOr { + .WillByDefault([state](int device_id) -> absl::StatusOr { auto it = state->device_map.find(device_id); if (it == state->device_map.end()) { return InvalidArgument("Unexpected device id: %d", device_id); diff --git a/xla/python/ifrt/sharding_test_util.h b/xla/python/ifrt/sharding_test_util.h index 043063b8e16c9..b311811820927 100644 --- a/xla/python/ifrt/sharding_test_util.h +++ b/xla/python/ifrt/sharding_test_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/support/BUILD b/xla/python/ifrt/support/BUILD index f8e596a2cfc6d..f524761a3a545 100644 --- a/xla/python/ifrt/support/BUILD +++ b/xla/python/ifrt/support/BUILD @@ -6,35 +6,41 @@ package( ) cc_library( - name = "sharding_param_to_op_sharding", - srcs = ["sharding_param_to_op_sharding.cc"], - hdrs = ["sharding_param_to_op_sharding.h"], + name = "sharding_conversions", + srcs = ["sharding_conversions.cc"], + hdrs = ["sharding_conversions.h"], visibility = ["//xla/python/ifrt:friends"], deps = [ - "//xla:statusor", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/python/ifrt/ir", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@tsl//tsl/platform:errors", ], ) xla_cc_test( - name = "sharding_param_to_op_sharding_test", - srcs = ["sharding_param_to_op_sharding_test.cc"], + name = "sharding_conversions_test", + srcs = ["sharding_conversions_test.cc"], deps = [ - ":sharding_param_to_op_sharding", - "//xla:statusor", + ":sharding_conversions", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", "//xla/python/ifrt", "//xla/python/ifrt:sharding_test_util", "//xla/python/ifrt/ir", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:errors", + "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], diff --git a/xla/python/ifrt/support/sharding_conversions.cc b/xla/python/ifrt/support/sharding_conversions.cc new file mode 100644 index 0000000000000..1c6a2b3f6f5a7 --- /dev/null +++ b/xla/python/ifrt/support/sharding_conversions.cc @@ -0,0 +1,207 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/support/sharding_conversions.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { +namespace support { + +absl::StatusOr ToOpSharding(const ShardingParam& sharding_param, + absl::Span device_mapping) { + OpSharding op_sharding; + { + bool all_dim_replicated = true; + for (const int64_t dim_shard : sharding_param.dim_shards()) { + if (dim_shard != 1) { + all_dim_replicated = false; + break; + } + } + if (all_dim_replicated) { + op_sharding.set_type(OpSharding::REPLICATED); + return op_sharding; + } + } + op_sharding.set_type(OpSharding::OTHER); + + // Populate tile_assignment_dimensions. + auto* tile_assignment_dims = op_sharding.mutable_tile_assignment_dimensions(); + int64_t cum_size = 1; + tile_assignment_dims->Reserve(sharding_param.dim_shards().size() + 1); + for (const int64_t dim_shard : sharding_param.dim_shards()) { + cum_size *= dim_shard; + tile_assignment_dims->Add(dim_shard); + } + int device_count = 1; + for (const int axis_size : sharding_param.minor_to_major().axis_sizes) { + device_count *= axis_size; + } + if (device_count != cum_size) { + op_sharding.set_replicate_on_last_tile_dim(true); + tile_assignment_dims->Add(device_count / cum_size); + } + + // Populate tile_assignment_devices. + llvm::SmallVector devices; + sharding_param.minor_to_major().ToDeviceList(devices); + auto* tile_assignment_devices = op_sharding.mutable_tile_assignment_devices(); + tile_assignment_devices->Reserve(devices.size()); + for (const int device : devices) { + if (device < 0 || device >= device_mapping.size()) { + return absl::OutOfRangeError( + absl::StrCat("Can't map device with logical id ", device, + ". The logical device id should be within [0, ", + device_mapping.size(), ").")); + } + tile_assignment_devices->Add(device_mapping[device]); + } + + return op_sharding; +} + +absl::StatusOr ToHloSharding(const ShardingParam& sharding_param) { + auto axis_sizes = sharding_param.minor_to_major().axis_sizes; + llvm::SmallVector reshape_dims; + reshape_dims.reserve(axis_sizes.size()); + int device_count = 1; + for (auto axis_size : llvm::reverse(axis_sizes)) { + reshape_dims.push_back(axis_size); + device_count *= axis_size; + } + if (device_count == 1) { + // Generate single-device sharding as TileMaximal. + return HloSharding::Replicate(); + } + int64_t cum_size = 1; + llvm::SmallVector dims; + dims.reserve(sharding_param.dim_shards().size()); + for (const int64_t dim_shard : sharding_param.dim_shards()) { + cum_size *= dim_shard; + dims.push_back(dim_shard); + } + // Applies the inverse of the transposes from `ToShardingParam`. + llvm::SmallVector permutation; + int num_axis = sharding_param.minor_to_major().permutation.size(); + permutation.reserve(num_axis); + for (const int axis_id : + llvm::reverse(sharding_param.minor_to_major().permutation)) { + permutation.push_back(num_axis - axis_id - 1); + } + if (device_count != cum_size) { + // Add the replicated dimension. + dims.push_back(device_count / cum_size); + return HloSharding::PartialTile( + TileAssignment(dims, reshape_dims, permutation)); + } else { + return HloSharding::IotaTile(dims, reshape_dims, permutation); + } +} + +absl::StatusOr ToShardingParam(const HloSharding& hlo_sharding, + int rank, int num_devices) { + // `dim_shards` has size equal to the rank of the array, with each entry + // representing the number of shards for the corresponding dimension. + // `minor_to_major.permutation` and `minor_to_major.axis_sizes` must be + // of the same size, and specify how the shards are mapped over the axis in + // `minor_to_major` order. + ShardingParam::MinorToMajor minor_to_major; + + if (hlo_sharding.IsReplicated() || + (hlo_sharding.IsTileMaximal() && hlo_sharding.HasUniqueDevice() && + num_devices == 1)) { + // Convert replicated or TileMaximal. Only single-device TileMaximal + // conversion is supported. + llvm::SmallVector dim_shards(rank, 1); + minor_to_major.permutation.push_back(0); + minor_to_major.axis_sizes.push_back(num_devices); + return ShardingParam(dim_shards, std::move(minor_to_major)); + } else if (hlo_sharding.IsTiled()) { + const xla::TileAssignment& tile_assignment = hlo_sharding.tile_assignment(); + if (!tile_assignment.iota()) { + return absl::InvalidArgumentError(absl::StrCat( + "Conversion from `HloSharding` without `IotaTileAssignment` is not " + "supported; sharding=", + hlo_sharding.ToString())); + } + if (rank != hlo_sharding.TiledDataRank()) { + return absl::InvalidArgumentError(absl::StrFormat( + "`TiledData` expected to have have %d dimensions, but has %d " + "dimensions; sharding=%s", + rank, hlo_sharding.TiledDataRank(), hlo_sharding.ToString())); + } + if (hlo_sharding.subgroup_types().size() > 1 || + (hlo_sharding.subgroup_types().size() == 1 && + hlo_sharding.subgroup_types()[0] != xla::OpSharding::REPLICATED)) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported conversion to `ShardingParam` from `HloSharding` that " + "has more than a subgroup or a subgroup that is not REPLICATED; " + "sharding=", + hlo_sharding.ToString())); + } + // Get the `dim_shards` from the tile assignment. + llvm::SmallVector dim_shards(tile_assignment.dimensions().begin(), + tile_assignment.dimensions().end()); + if (hlo_sharding.ReplicateOnLastTileDim() || + (hlo_sharding.subgroup_types().size() == 1 && + hlo_sharding.subgroup_types()[0] == xla::OpSharding::REPLICATED)) { + dim_shards.pop_back(); + } + if (tile_assignment.iota()->reshape_dims().empty()) { + // If there are no reshape_dims, then the array is replicated. + minor_to_major.permutation.push_back(0); + minor_to_major.axis_sizes.push_back(num_devices); + } else { + for (auto reshape_dim : + llvm::reverse(tile_assignment.iota()->reshape_dims())) { + minor_to_major.axis_sizes.push_back(reshape_dim); + } + // The devices generated by HloSharding + // np.arange(ndevices).reshape(reshape_dims).transpose(transpose_perm) + // must be equal to the devices ShardingParam + // np.arange(ndevices).reshape(reverse(axis_size)).T.transpose(perm).T + // Step 1: Compute transpose(transpose_perm).T. + // Step 2: Compute T.transpose(transpose_perm).T. + int num_axis = tile_assignment.iota()->transpose_perm().size(); + for (int axis_id : + llvm::reverse(tile_assignment.iota()->transpose_perm())) { + minor_to_major.permutation.push_back(num_axis - axis_id - 1); + } + } + return ShardingParam(dim_shards, std::move(minor_to_major)); + } + return absl::UnimplementedError( + absl::StrCat("Unsupported conversion to `ShardingParam` from " + "`HloSharding`; sharding=", + hlo_sharding.ToString())); +} + +} // namespace support +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/support/sharding_conversions.h b/xla/python/ifrt/support/sharding_conversions.h new file mode 100644 index 0000000000000..1a9c2d3e728a5 --- /dev/null +++ b/xla/python/ifrt/support/sharding_conversions.h @@ -0,0 +1,62 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_SUPPORT_SHARDING_CONVERSIONS_H_ +#define XLA_PYTHON_IFRT_SUPPORT_SHARDING_CONVERSIONS_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { +namespace support { + +// Converts ShardingParam and a device_mapping to OpSharding. +// +// The function assumes that `sharding_param` is valid. The logical device +// ids from `sharding_param` are used as indices into the device_mapping to +// obtain the device ids to create the OpSharding. +// +// Returns error when `device_mapping` can't map the logical devices in +// `sharding_param`. +absl::StatusOr ToOpSharding(const ShardingParam& sharding_param, + absl::Span device_mapping); + +// Converts ShardingParam to HloSharding. +// +// This assumes that `sharding_param` is valid. +// The returned HloSharding uses the same logical device ids as the +// given ShardingParam. +absl::StatusOr ToHloSharding(const ShardingParam& sharding_param); + +// Converts HloSharding to ShardingParam. +// +// It assumes that `hlo_sharding` is valid. +// +// Returns error when `hlo_sharding` cannot be converted to sharding param. +// Only a subset of HloShardings are supported: REPLICATED (including MAXIMAL +// on single-device), partially replicated, fully partitioned shardings. +// (Non-fully-replicated) MAXIMAL and MANUAL shardings are not supported. +absl::StatusOr ToShardingParam(const HloSharding& hlo_sharding, + int rank, int num_devices); + +} // namespace support +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_SUPPORT_SHARDING_CONVERSIONS_H_ diff --git a/xla/python/ifrt/support/sharding_conversions_test.cc b/xla/python/ifrt/support/sharding_conversions_test.cc new file mode 100644 index 0000000000000..d6619f4693d0e --- /dev/null +++ b/xla/python/ifrt/support/sharding_conversions_test.cc @@ -0,0 +1,333 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/support/sharding_conversions.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/tile_assignment.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_test_util.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace support { +namespace { + +using ::tsl::testing::StatusIs; +using xla::HloSharding; + +absl::StatusOr ToHloShardingViaOpSharding( + const ShardingParam& sharding_param, absl::Span device_list) { + TF_ASSIGN_OR_RETURN(xla::OpSharding op_sharding, + ToOpSharding(sharding_param, device_list)); + return HloSharding::FromProto(op_sharding); +} + +TEST(ShardingConversionsTest, Replicated) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{1, 1, 1}, + {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, + ToHloSharding(expected_sharding_param)); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); + EXPECT_EQ(hlo_sharding.ToString(), "{replicated}"); + EXPECT_EQ(hlo_sharding, hlo_iota_sharding); + TF_ASSERT_OK_AND_ASSIGN(auto sharding_param, + ToShardingParam(hlo_iota_sharding, 3, 6)); + // We do not compare expected_sharding_param and sharding_param because they + // haven't been canonicalized (1x1x1 to [0, 1] on 2x3 vs. 1x1x1 to [0] on 6). + TF_ASSERT_OK_AND_ASSIGN(const HloSharding actual_hlo_sharding, + ToHloSharding(sharding_param)); + EXPECT_EQ(hlo_iota_sharding, actual_hlo_sharding); +} + +TEST(ShardingConversionsTest, SingleDeviceReplicated) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{1, 1}, {/*permutation=*/{0}, /*axis_sizes=*/{1}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, + ToHloSharding(expected_sharding_param)); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {0})); + EXPECT_EQ(hlo_sharding.ToString(), "{replicated}"); + EXPECT_EQ(hlo_sharding, hlo_iota_sharding); + TF_ASSERT_OK_AND_ASSIGN(auto sharding_param, + ToShardingParam(hlo_iota_sharding, 2, 1)); + EXPECT_EQ(expected_sharding_param, sharding_param); +} + +TEST(ShardingConversionsTest, Permutation) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{2, 1, 3}, + {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, + ToHloSharding(expected_sharding_param)); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); + EXPECT_EQ(hlo_sharding.ToString(), "{devices=[2,1,3]0,3,1,4,2,5}"); + EXPECT_EQ(hlo_sharding, hlo_iota_sharding); + TF_ASSERT_OK_AND_ASSIGN(auto sharding_param, + ToShardingParam(hlo_iota_sharding, 3, 6)); + EXPECT_EQ(expected_sharding_param, sharding_param); +} + +TEST(ShardingConversionsTest, Partial) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, + ToHloSharding(expected_sharding_param)); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); + EXPECT_EQ(hlo_sharding.ToString(), + "{devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}"); + EXPECT_EQ(hlo_sharding, hlo_iota_sharding); + TF_ASSERT_OK_AND_ASSIGN(auto sharding_param, + ToShardingParam(hlo_iota_sharding, 2, 6)); + // We do not compare expected_sharding_param and sharding_param because they + // haven't been canonicalized (2x1 to [0, 1] on 2x3 vs. 2x1 to [0] on 6). + TF_ASSERT_OK_AND_ASSIGN(const HloSharding actual_hlo_sharding, + ToHloSharding(sharding_param)); + EXPECT_EQ(hlo_iota_sharding, actual_hlo_sharding); +} + +TEST(ShardingConversionsTest, OneDimToTwoAxes) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{4}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{2, 2}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, + ToHloSharding(expected_sharding_param)); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3})); + EXPECT_EQ(hlo_sharding.ToString(), "{devices=[4]0,2,1,3}"); + EXPECT_EQ(hlo_sharding, hlo_iota_sharding); + TF_ASSERT_OK_AND_ASSIGN(auto sharding_param, + ToShardingParam(hlo_iota_sharding, 1, 4)); + EXPECT_EQ(expected_sharding_param, sharding_param); +} + +TEST(ShardingConversionsTest, NonTrivialDeviceAssignment) { + ShardingParam expected_sharding_param{ + /*dim_shards=*/{2, 1, 3}, + {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(expected_sharding_param, {6, 5, 4, 3, 2, 1})); + EXPECT_EQ(hlo_sharding.ToString(), "{devices=[2,1,3]6,3,5,2,4,1}"); +} + +TEST(ShardingConversionsTest, VerifyIncorrectShardings) { + ShardingParam different_permutation_and_axis{ + /*dim_shards=*/{1, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2}}}; + EXPECT_FALSE(different_permutation_and_axis.verify().ok()); + ShardingParam too_many_slices{/*dim_shards=*/{2, 2}, + {/*permutation=*/{0}, /*axis_sizes=*/{2}}}; + EXPECT_FALSE(too_many_slices.verify().ok()); + ShardingParam incorrect_permutation{ + /*dim_shards=*/{4, 1}, + {/*permutation=*/{0, 1, 1}, /*axis_sizes=*/{2, 2, 2}}}; + EXPECT_FALSE(incorrect_permutation.verify().ok()); +} + +TEST(ShardingConversionsTest, ErrorOnDeviceAssignment) { + ShardingParam sharding_param{/*dim_shards=*/{2, 1, 3}, + {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(sharding_param.verify()); + EXPECT_THAT( + ToHloShardingViaOpSharding(sharding_param, {6, 5, 4, 3, 2}), + StatusIs(absl::StatusCode::kOutOfRange, + ::testing::HasSubstr("Can't map device with logical id 5"))); +} + +struct HloShardingTestStruct { + HloSharding hlo_sharding; + int rank; + int num_devices; +}; + +using HloShardingToShardingParamTest = + ::testing::TestWithParam; + +TEST_P(HloShardingToShardingParamTest, HloShardingToShardingParam) { + const auto& param = GetParam(); + TF_ASSERT_OK_AND_ASSIGN( + auto sharding_param, + ToShardingParam(param.hlo_sharding, param.rank, param.num_devices)); + EXPECT_TRUE(sharding_param.verify().ok()); + TF_ASSERT_OK_AND_ASSIGN(auto actual_hlo_sharding, + ToHloSharding(sharding_param)); + EXPECT_EQ(param.hlo_sharding, actual_hlo_sharding); + // Verify that the conversion to OpSharding is also correct. + std::vector device_ids(param.num_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + TF_ASSERT_OK_AND_ASSIGN( + auto hlo_via_op_sharding, + ToHloShardingViaOpSharding(sharding_param, device_ids)); + EXPECT_EQ(param.hlo_sharding, hlo_via_op_sharding); +} + +INSTANTIATE_TEST_SUITE_P( + HloShardingConversionTests, HloShardingToShardingParamTest, + testing::ValuesIn({ + {HloSharding::IotaTile({4, 2}), 2, 8}, + {HloSharding::IotaTile({2, 4}, {4, 2}, {1, 0}), 2, 8}, + {HloSharding::IotaTile({8, 1}), 2, 8}, + {HloSharding::IotaTile({8, 1}, {4, 2}, {1, 0}), 2, 8}, + {HloSharding::PartialTile(TileAssignment({4, 1, 2}, {8}, {0})), 2, 8}, + {HloSharding::PartialTile(TileAssignment({2, 1, 4}, {4, 2}, {1, 0})), 2, + 8}, + {HloSharding::PartialTile(TileAssignment({1, 4, 2}, {8}, {0})), 2, 8}, + {HloSharding::PartialTile(TileAssignment({1, 2, 4}, {4, 2}, {1, 0})), 2, + 8}, + {HloSharding::PartialTile(TileAssignment({4, 3, 2}, {2, 3, 4}, + {2, 1, 0})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({4, 2, 3}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({6, 1, 4}, {24}, {0})), 2, 24}, + {HloSharding::PartialTile(TileAssignment({12, 1, 2}, {2, 12}, {1, 0})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({8, 1, 3}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({2, 1, 12}, {24}, {0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({3, 1, 8}, {2, 3, 4}, + {1, 0, 2})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({1, 4, 6}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({1, 12, 2}, {2, 12}, {1, 0})), + 2, 24}, + + {HloSharding::PartialTile(TileAssignment({3, 2, 1, 4}, {2, 3, 4}, + {1, 0, 2})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({2, 4, 1, 3}, {2, 3, 4}, + {0, 2, 1})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({4, 3, 1, 2}, {2, 3, 4}, + {2, 1, 0})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({12, 1, 1, 2}, {2, 12}, + {1, 0})), + 3, 24}, + })); + +class ShardingConversionsEquivalentTest : public test_util::ShardingTest { + public: + void AssertSameTiling(const ShardingParam& sharding_param, + const HloSharding& hlo_sharding, const Shape& shape) { + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr sharding, + ShardingParamSharding::Create( + sharding_param, device_list, MemoryKind())); + const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); + + TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, + sharding->IndexDomains(shape)); + ASSERT_EQ(index_domains.size(), + hlo_sharding.tile_assignment().num_elements()); + const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); + for (int i = 0; i < index_domains.size(); ++i) { + SCOPED_TRACE(absl::StrCat("on device ", i)); + EXPECT_EQ(index_domains[i].origin().elements(), + hlo_sharding.TileOffsetForDevice(xla_shape, i)); + EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); + } + } + + private: + std::shared_ptr client_; +}; + +TEST_P(ShardingConversionsEquivalentTest, ShardingParamFullySharded) { + ShardingParam sharding_param{/*dim_shards=*/{2, 3}, + {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); + AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); +} + +TEST_P(ShardingConversionsEquivalentTest, ShardingParamWithPermutation) { + ShardingParam sharding_param{/*dim_shards=*/{2, 3}, + {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); + AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); +} + +TEST_P(ShardingConversionsEquivalentTest, ShardingParamWithReplication) { + ShardingParam sharding_param{/*dim_shards=*/{2, 1}, + {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN( + const HloSharding hlo_sharding, + ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); + AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); +} + +TEST_P(ShardingConversionsEquivalentTest, OpShardingReplicated) { + OpSharding op_sharding; + op_sharding.set_type(OpSharding::REPLICATED); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_sharding, + HloSharding::FromProto(op_sharding)); + TF_ASSERT_OK_AND_ASSIGN(auto actual, ToShardingParam(hlo_sharding, 2, 6)); + ShardingParam expected{/*dim_shards=*/{1, 1}, + {/*permutation=*/{0}, /*axis_sizes=*/{6}}}; + TF_EXPECT_OK(expected.verify()); + EXPECT_EQ(actual, expected); +} + +INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingConversionsEquivalentTest, + testing::Values(test_util::ShardingTestParam{ + .num_devices = 6, .num_addressable_devices = 4})); + +} // namespace +} // namespace support +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/support/sharding_param_to_op_sharding.cc b/xla/python/ifrt/support/sharding_param_to_op_sharding.cc deleted file mode 100644 index a8b5bfeefac98..0000000000000 --- a/xla/python/ifrt/support/sharding_param_to_op_sharding.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/ifrt/support/sharding_param_to_op_sharding.h" - -#include - -#include "absl/types/span.h" -#include "llvm/ADT/SmallVector.h" -#include "xla/python/ifrt/ir/sharding_param.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace ifrt { -namespace support { - -StatusOr ToOpSharding(const ShardingParam& sharding_param, - absl::Span device_mapping) { - OpSharding op_sharding; - op_sharding.set_type(OpSharding::OTHER); - - // Populate tile_assignment_dimensions. - auto* tile_assignment_dims = op_sharding.mutable_tile_assignment_dimensions(); - int64_t cum_size = 1; - tile_assignment_dims->Reserve(sharding_param.dim_shards().size() + 1); - for (const int64_t dim_shard : sharding_param.dim_shards()) { - cum_size *= dim_shard; - tile_assignment_dims->Add(dim_shard); - } - int device_count = 1; - for (const int axis_size : sharding_param.minor_to_major().axis_sizes) { - device_count *= axis_size; - } - if (device_count != cum_size) { - op_sharding.set_replicate_on_last_tile_dim(true); - tile_assignment_dims->Add(device_count / cum_size); - } - - // Populate tile_assignment_devices. - llvm::SmallVector devices; - sharding_param.minor_to_major().ToDeviceList(devices); - auto* tile_assignment_devices = op_sharding.mutable_tile_assignment_devices(); - tile_assignment_devices->Reserve(devices.size()); - for (const int device : devices) { - if (device < 0 || device >= device_mapping.size()) { - return tsl::errors::OutOfRange("Can't map device ", device); - } - tile_assignment_devices->Add(device_mapping[device]); - } - - return op_sharding; -} - -} // namespace support -} // namespace ifrt -} // namespace xla diff --git a/xla/python/ifrt/support/sharding_param_to_op_sharding.h b/xla/python/ifrt/support/sharding_param_to_op_sharding.h deleted file mode 100644 index 92635491d3773..0000000000000 --- a/xla/python/ifrt/support/sharding_param_to_op_sharding.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_IFRT_SUPPORT_SHARDING_PARAM_TO_OP_SHARDING_H_ -#define XLA_PYTHON_IFRT_SUPPORT_SHARDING_PARAM_TO_OP_SHARDING_H_ - -#include "absl/types/span.h" -#include "xla/python/ifrt/ir/sharding_param.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace ifrt { -namespace support { - -// Converts ShardingParam to OpSharding. -// -// This assumes that `sharding_param` is valid. -// -// Returns error when `device_mapping` can't map the logical devices in -// `sharding_param`. -StatusOr ToOpSharding(const ShardingParam& sharding_param, - absl::Span device_mapping); - -} // namespace support -} // namespace ifrt -} // namespace xla - -#endif // XLA_PYTHON_IFRT_SUPPORT_SHARDING_PARAM_TO_OP_SHARDING_H_ diff --git a/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc b/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc deleted file mode 100644 index bc40f4011af11..0000000000000 --- a/xla/python/ifrt/support/sharding_param_to_op_sharding_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/ifrt/support/sharding_param_to_op_sharding.h" - -#include -#include - -#include -#include -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/ifrt/ir/sharding_param.h" -#include "xla/python/ifrt/shape.h" -#include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace ifrt { -namespace support { -namespace { - -using ::tsl::testing::StatusIs; - -StatusOr ToHloSharding(const ShardingParam& sharding_param, - absl::Span device_list) { - TF_ASSIGN_OR_RETURN(xla::OpSharding op_sharding, - ToOpSharding(sharding_param, device_list)); - return xla::HloSharding::FromProto(op_sharding); -} - -TEST(ShardingParamToOpShardingTest, Replicated) { - ShardingParam sharding_param{/*dim_shards=*/{1, 1, 1}, - {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - EXPECT_EQ(actual.ToString(), "{replicated}"); -} - -TEST(ShardingParamToOpShardingTest, Maximal) { - ShardingParam sharding_param{/*dim_shards=*/{1, 1}, - {/*permutation=*/{0}, /*axis_sizes=*/{1}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {0})); - EXPECT_EQ(actual.ToString(), "{maximal device=0}"); -} - -TEST(ShardingParamToOpShardingTest, Permutation) { - ShardingParam sharding_param{/*dim_shards=*/{2, 1, 3}, - {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - EXPECT_EQ(actual.ToString(), "{devices=[2,1,3]0,3,1,4,2,5}"); -} - -TEST(ShardingParamToOpShardingTest, Partial) { - ShardingParam sharding_param{/*dim_shards=*/{2, 1}, - {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - EXPECT_EQ(actual.ToString(), - "{devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}"); -} - -TEST(ShardingParamToOpShardingTest, OneDimToTwoAxes) { - ShardingParam sharding_param{/*dim_shards=*/{4}, - {/*permutation=*/{1, 0}, /*axis_sizes=*/{2, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {0, 1, 2, 3})); - EXPECT_EQ(actual.ToString(), "{devices=[4]0,2,1,3}"); -} - -TEST(ShardingParamToOpShardingTest, NonTrivialDeviceAssignment) { - ShardingParam sharding_param{/*dim_shards=*/{2, 1, 3}, - {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual, - ToHloSharding(sharding_param, {6, 5, 4, 3, 2, 1})); - EXPECT_EQ(actual.ToString(), "{devices=[2,1,3]6,3,5,2,4,1}"); -} - -TEST(ShardingParamToOpShardingTest, ErrorOnDeviceAssignment) { - ShardingParam sharding_param{/*dim_shards=*/{2, 1, 3}, - {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - EXPECT_THAT(ToHloSharding(sharding_param, {6, 5, 4, 3, 2}), - StatusIs(tsl::error::OUT_OF_RANGE, "Can't map device 5")); -} - -class ShardingParamToOpShardingEquivalentTest : public test_util::ShardingTest { - public: - void AssertSameTiling(const ShardingParam& sharding_param, - const HloSharding& hlo_sharding, const Shape& shape) { - auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr sharding, - ShardingParamSharding::Create( - sharding_param, device_list, MemoryKind())); - const xla::Shape xla_shape(PrimitiveType::F16, shape.dims(), {}, {}); - - TF_ASSERT_OK_AND_ASSIGN(const std::vector index_domains, - sharding->IndexDomains(shape)); - ASSERT_EQ(index_domains.size(), - hlo_sharding.tile_assignment().num_elements()); - const xla::Shape xla_tile_shape = hlo_sharding.TileShape(xla_shape); - for (int i = 0; i < index_domains.size(); ++i) { - SCOPED_TRACE(absl::StrCat("on device ", i)); - EXPECT_EQ(index_domains[i].origin().elements(), - hlo_sharding.TileOffsetForDevice(xla_shape, i)); - EXPECT_EQ(index_domains[i].shape().dims(), xla_tile_shape.dimensions()); - } - } - - private: - std::shared_ptr client_; -}; - -TEST_P(ShardingParamToOpShardingEquivalentTest, FullySharded) { - ShardingParam sharding_param{/*dim_shards=*/{2, 3}, - {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); -} - -TEST_P(ShardingParamToOpShardingEquivalentTest, WithPermutation) { - ShardingParam sharding_param{/*dim_shards=*/{2, 3}, - {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); -} - -TEST_P(ShardingParamToOpShardingEquivalentTest, WithReplication) { - ShardingParam sharding_param{/*dim_shards=*/{2, 1}, - {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_sharding, - ToHloSharding(sharding_param, {0, 1, 2, 3, 4, 5})); - AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); -} - -INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamToOpShardingEquivalentTest, - testing::Values(test_util::ShardingTestParam{ - .num_devices = 6, .num_addressable_devices = 4})); - -} // namespace -} // namespace support -} // namespace ifrt -} // namespace xla diff --git a/xla/python/ifrt/test_util.cc b/xla/python/ifrt/test_util.cc index 228db89e355e0..2da9bf8b4d209 100644 --- a/xla/python/ifrt/test_util.cc +++ b/xla/python/ifrt/test_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,20 +34,21 @@ namespace { class ClientFactory { public: - void Register(std::function>()> factory) { + void Register( + std::function>()> factory) { absl::MutexLock lock(&mu_); CHECK(!factory_) << "Client factory has been already registered."; factory_ = std::move(factory); } - std::function>()> Get() const { + std::function>()> Get() const { absl::MutexLock lock(&mu_); return factory_; } private: mutable absl::Mutex mu_; - std::function>()> factory_ + std::function>()> factory_ ABSL_GUARDED_BY(mu_); }; @@ -59,11 +60,11 @@ ClientFactory& GetGlobalClientFactory() { } // namespace void RegisterClientFactory( - std::function>()> factory) { + std::function>()> factory) { GetGlobalClientFactory().Register(std::move(factory)); } -StatusOr> GetClient() { +absl::StatusOr> GetClient() { auto factory = GetGlobalClientFactory().Get(); CHECK(factory) << "Client factory has not been registered."; return factory(); diff --git a/xla/python/ifrt/test_util.h b/xla/python/ifrt/test_util.h index 218711753a4e5..5ac6ecdea5fad 100644 --- a/xla/python/ifrt/test_util.h +++ b/xla/python/ifrt/test_util.h @@ -1,5 +1,5 @@ #include "tsl/lib/core/status_test_util.h" -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,13 +39,13 @@ namespace test_util { // Registers an IFRT client factory function. Must be called only once. void RegisterClientFactory( - std::function>()> factory); + std::function>()> factory); // Returns true iff an IFRT client factory function has been registered. bool IsClientFactoryRegistered(); // Gets a new IFRT client using the registered client factory. -StatusOr> GetClient(); +absl::StatusOr> GetClient(); // Set a default test filter if user doesn't provide one using --gtest_filter. void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter); diff --git a/xla/python/ifrt/tuple.cc b/xla/python/ifrt/tuple.cc index c14ace8e3cb96..a4f9d919d498f 100644 --- a/xla/python/ifrt/tuple.cc +++ b/xla/python/ifrt/tuple.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/tuple.h b/xla/python/ifrt/tuple.h index 523a2e3030aa2..076455354a856 100644 --- a/xla/python/ifrt/tuple.h +++ b/xla/python/ifrt/tuple.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/tuple_impl_test_lib.cc b/xla/python/ifrt/tuple_impl_test_lib.cc index 9716eff1fa9d5..aba8c7a58a388 100644 --- a/xla/python/ifrt/tuple_impl_test_lib.cc +++ b/xla/python/ifrt/tuple_impl_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ namespace xla { namespace ifrt { namespace { -StatusOr> MakeArray(Client* client) { +absl::StatusOr> MakeArray(Client* client) { DType dtype(DType::kF32); Shape shape({2, 3}); std::vector data(6); diff --git a/xla/python/ifrt/types.proto b/xla/python/ifrt/types.proto deleted file mode 100644 index e9c799bcc1ed6..0000000000000 --- a/xla/python/ifrt/types.proto +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package xla.ifrt; - -// Wire format for `DeviceList`. -message DeviceListProto { - // Serialization and deserialization are expected to ensure that device ids - // are stable across proto construction and consumption. - repeated int32 device_ids = 1; -} - -// Wire format for `Shape`. Currently support static shapes with all dimension -// sizes greater than or equal to 0. -message ShapeProto { - repeated int64 dims = 1; -} diff --git a/xla/python/ifrt/value.cc b/xla/python/ifrt/value.cc index e0cdde43c95e0..752b365853dcb 100644 --- a/xla/python/ifrt/value.cc +++ b/xla/python/ifrt/value.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt/value.h b/xla/python/ifrt/value.h index a7a8a389098a9..11bcf519ce6dc 100644 --- a/xla/python/ifrt/value.h +++ b/xla/python/ifrt/value.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/ifrt_proxy/client/BUILD b/xla/python/ifrt_proxy/client/BUILD new file mode 100644 index 0000000000000..87939859fce1b --- /dev/null +++ b/xla/python/ifrt_proxy/client/BUILD @@ -0,0 +1,531 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@tsl//tsl:tsl.bzl", "if_google") +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +cc_library( + name = "grpc_client_session", + srcs = [ + "grpc_client_session.cc", + ], + hdrs = ["grpc_client_session.h"], + deps = [ + ":client_session", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:unbounded_work_queue", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_client_session_test", + srcs = [ + "grpc_client_session_test.cc", + ], + deps = [ + ":grpc_client_session", + ":version", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:gpr", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_sink_registry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "rpc_helper", + srcs = [ + "rpc_helper.cc", + ], + hdrs = ["rpc_helper.h"], + deps = [ + ":client_session", + ":host_buffer", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:status_to_from_proto", + ] + if_google(["@com_google_absl//absl/types:source_location"]), +) + +cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + ":array", + ":compiler", + ":device", + ":memory", + ":rpc_helper", + "//xla:xla_data_proto_cc", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:common_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "client_test", + srcs = ["client_test.cc"], + deps = [ + ":client", + ":client_session", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/pjrt:pjrt_device_description", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/service:computation_placer_hdr", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "device", + srcs = ["device.cc"], + hdrs = ["device.h"], + deps = [ + "//xla:literal", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_future", + "//xla/python/ifrt", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "array", + srcs = ["array.cc"], + hdrs = ["array.h"], + deps = [ + ":rpc_helper", + "//xla:status_macros", + "//xla/python/ifrt", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:array_util", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "array_test", + srcs = ["array_test.cc"], + deps = [ + ":array", + ":client_session", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "client_session", + hdrs = ["client_session.h"], + deps = [ + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_client_session", + testonly = True, + hdrs = ["mock_client_session.h"], + deps = [ + ":client_session", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":executable", + ":rpc_helper", + "//xla/pjrt:host_callback", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/server:host_callback", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:status_to_from_proto", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + deps = [ + ":client_session", + ":compiler", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "executable", + srcs = ["executable.cc"], + hdrs = ["executable.h"], + deps = [ + ":array", + ":host_buffer", + ":rpc_helper", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/pjrt:host_callback", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:platform_port", + "@tsl//tsl/platform:status_to_from_proto", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "host_buffer", + hdrs = ["host_buffer.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_library( + name = "mock_host_buffer", + testonly = True, + hdrs = ["mock_host_buffer.h"], + deps = [ + ":host_buffer", + "//xla/python/ifrt", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "grpc_host_buffer", + srcs = ["grpc_host_buffer.cc"], + hdrs = ["grpc_host_buffer.h"], + deps = [ + ":host_buffer", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:unbounded_work_queue", + "@tsl//tsl/protobuf:status_proto_cc", + ], +) + +cc_library( + name = "grpc_client", + srcs = ["grpc_client.cc"], + deps = [ + ":client", + ":grpc_client_session", + ":grpc_host_buffer", + ":registry", + ":rpc_helper", + ":version", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/log:log_sink", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + +cc_library( + name = "registry", + srcs = ["registry.cc"], + hdrs = ["registry.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "memory", + hdrs = ["memory.h"], + deps = [ + "//xla/pjrt:pjrt_client", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "version", + hdrs = ["version.h"], +) + +ifrt_proxy_cc_test( + name = "executable_test", + srcs = ["executable_test.cc"], + deps = [ + ":array", + ":client_session", + ":executable", + ":host_buffer", + ":mock_client_session", + ":mock_host_buffer", + ":rpc_helper", + ":version", + "//xla:shape_util", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "py_module", + srcs = ["py_module.cc"], + hdrs = ["py_module.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//xla/python:__pkg__"], + deps = [ + ":grpc_client", + ":registry", + "//xla/pjrt:status_casters", + "//xla/python:nb_class_ptr", + "//xla/python:py_client", + "//xla/python/ifrt", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/log:log_sink", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@nanobind", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/python/ifrt_proxy/client/README.md b/xla/python/ifrt_proxy/client/README.md new file mode 100644 index 0000000000000..b97be206439a1 --- /dev/null +++ b/xla/python/ifrt_proxy/client/README.md @@ -0,0 +1,11 @@ +This directory implements the IFRT proxy client. + +## Expected behavior when connection to the IFRT proxy server fails + +If a connection to the proxy server fails abruptly, any in-progress or further +IFRT API calls and `Future`s are expected to either return valid values (if the +value was already fetched from the server and is being cached locally) or an +error from `rpc_helper.cc`'s `WrapAsConnectionError()`. They are expected to +neither "hang" beyond the brief period required to determine whether the +connection has failed nor crash the process internally within the proxy client +library. diff --git a/xla/python/ifrt_proxy/client/array.cc b/xla/python/ifrt_proxy/client/array.cc new file mode 100644 index 0000000000000..21758ea7d8f04 --- /dev/null +++ b/xla/python/ifrt_proxy/client/array.cc @@ -0,0 +1,345 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/array.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/array_util.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/status_macros.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +char Array::ID = 0; + +absl::StatusOr> +Array::MakeArrayFromHostBuffer( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + TF_ASSIGN_OR_RETURN(const auto array_mem_region, + ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data, dtype, shape, byte_strides)); + + const uint64_t host_buffer_handle = + rpc_helper->host_buffer_store()->NextHandle(); + TF_RETURN_IF_ERROR( + rpc_helper->host_buffer_store() + ->Store(host_buffer_handle, array_mem_region.mem_region()) + .Await()); + + auto req = std::make_unique(); + req->set_host_buffer_handle(host_buffer_handle); + *req->mutable_dtype() = dtype.ToProto(); + *req->mutable_shape() = shape.ToProto(); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*sharding)); + if (byte_strides.has_value()) { + *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); + } + + TF_ASSIGN_OR_RETURN( + auto response, + rpc_helper->MakeArrayFromHostBuffer(std::move(req)).Await()); + const ArrayHandle handle{response->array_handle()}; + + if (on_done_with_host_buffer != nullptr) { + std::move(on_done_with_host_buffer)(); + } + + return tsl::RCReference( + tsl::MakeRef(client, std::move(rpc_helper), dtype, + std::move(shape), std::move(sharding), handle)); +} + +void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { + auto req = std::make_unique(); + req->set_array_handle(handle.handle); + rpc_helper->DestructArray(std::move(req)) + .OnReady( + [](absl::StatusOr> response) { + if (!response.ok()) { + LOG(WARNING) + << "Server returned an error when asked to destruct array: " + << response.status(); + } + }); +} + +Future Array::GetReadyFuture() const { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + auto promise = Future::CreatePromise(); + rpc_helper_->CheckArrayReady(std::move(req)) + .OnReady( + [promise](absl::StatusOr> + resp) mutable -> void { promise.Set(resp.status()); }); + return Future(std::move(promise)); +} + +Future Array::Delete() { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + absl::StatusOr> response = + rpc_helper_->DeleteArray(std::move(req)).Await(); + if (!response.ok()) { + return Future(response.status()); + } + + // TODO(b/266635130): So that the caller is not blocked until the server + // replies with the deletion's response, from within + // `Future(status_handle_promise).OnReady()`, schedule `CheckFuture()` on a + // separate thread. + return rpc_helper_->CheckFuture((*response)->deletion_future_handle()); +} + +bool Array::IsDeleted() const { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + + absl::StatusOr> response = + rpc_helper_->IsArrayDeleted(std::move(req)).Await(); + if (response.ok()) { + return (*response)->deleted(); + } else { + LOG(ERROR) << "Internal error from proxy server during Array::IsDeleted(): " + << response.status(); + // Return false so that the user likely queries the array with some + // method that returns an absl::Status, and ends up with the real + // error being returned to them by that method. + return false; + } +} + +absl::StatusOr> +Array::AssembleArrayFromSingleDeviceArrays( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + TF_RET_CHECK(!arrays.empty()); + *req->mutable_shape() = shape.ToProto(); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*sharding)); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + for (const tsl::RCReference& rcref : arrays) { + Array* array = llvm::dyn_cast(rcref.get()); + if (array == nullptr) { + return absl::InvalidArgumentError(absl::Substitute( + "Array at $0 supplied to AssembleArrayFromSingleDeviceArrays() is " + "not a xla::ifrt::proxy::Array.", + rcref.get())); + } + req->add_single_device_array_handles(array->handle_.handle); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper->AssembleArrayFromSingleDeviceArrays(std::move(req)).Await()); + ArrayHandle handle{response->array_handle()}; + + return tsl::RCReference( + tsl::MakeRef(client, std::move(rpc_helper), arrays[0]->dtype(), + std::move(shape), std::move(sharding), handle)); +} + +absl::StatusOr>> +Array::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->DisassembleIntoSingleDeviceArrays(std::move(req)).Await()); + std::vector handles; + for (auto& handle : response->single_device_array_handles()) { + handles.push_back(ArrayHandle{handle}); + } + + TF_ASSIGN_OR_RETURN(auto shape_and_shardings, sharding_->Disassemble(shape_)); + CHECK_EQ(handles.size(), shape_and_shardings.size()) + << " " << absl::StrJoin(handles, ",") << " " << shape_ << " " + << *sharding_ << " "; + + std::vector> result; + result.reserve(handles.size()); + for (int i = 0; i < handles.size(); ++i) { + result.push_back(tsl::RCReference(tsl::MakeRef( + client_, rpc_helper_, dtype_, std::move(shape_and_shardings[i].first), + std::move(shape_and_shardings[i].second), handles[i]))); + } + + return result; +} + +absl::StatusOr> Array::FullyReplicatedShard( + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->FullyReplicatedShard(std::move(req)).Await()); + + ArrayHandle handle{response->array_handle()}; + + // We are making the assumption the Array returned by the server corresponds + // to the first device. Revisit this when IFRT supports: (1) an inexpensive + // way to derive a SingleDeviceSharding from a fully replicated Array's + // sharding and (2) A generalized `Reshard` API that allows the user to + // request an Array to be made out of a specific single shard. + std::unique_ptr single_device_sharding = + xla::ifrt::SingleDeviceSharding::Create(sharding_->devices()[0], + sharding_->memory_kind()); + + return tsl::RCReference( + tsl::MakeRef(client_, rpc_helper_, dtype_, shape_, + std::move(single_device_sharding), handle)); +} + +absl::StatusOr> Array::Reshard( + std::shared_ptr new_sharding, + ArrayCopySemantics semantics) { + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), ToShardingProto(*new_sharding)); + req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + rpc_helper_->Reshard(std::move(req)).Await()); + ArrayHandle handle{response->array_handle()}; + + return tsl::RCReference(tsl::MakeRef( + client_, rpc_helper_, dtype_, shape_, std::move(new_sharding), handle)); +} + +Future Array::CopyToHostBuffer( + void* data, std::optional> byte_strides, + ArrayCopySemantics semantics) { + const auto mem_region = ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data, dtype_, shape_, byte_strides); + if (!mem_region.ok()) { + return Future(mem_region.status()); + } + + auto req = std::make_unique(); + req->set_array_handle(handle_.handle); + if (byte_strides.has_value()) { + *req->mutable_byte_strides() = ToByteStridesProto(*byte_strides); + } + const uint64_t host_buffer_handle = + rpc_helper_->host_buffer_store()->NextHandle(); + req->set_host_buffer_handle(host_buffer_handle); + + auto promise = Future::CreatePromise(); + auto on_ready = [host_buffer_store = rpc_helper_->host_buffer_store(), + promise, host_buffer_handle, + mem_region = mem_region->mem_region()]( + absl::StatusOr> + resp) mutable { + if (!resp.ok()) { + promise.Set(resp.status()); + return; + } + + auto host_buffer = host_buffer_store->Lookup(host_buffer_handle); + host_buffer.OnReady( + [promise, mem_region, host_buffer_store, + host_buffer_handle](absl::StatusOr data) mutable { + absl::Cleanup cleanup = [&]() { + host_buffer_store->Delete(host_buffer_handle) + .OnReady([buffer_status = data.status()](absl::Status status) { + if (!status.ok()) { + LOG(WARNING) << "Failed to delete host buffer: " << status + << " (buffer status: " << buffer_status << ")"; + } + }); + }; + + if (!data.ok()) { + promise.Set(data.status()); + return; + } + if (data->size() != mem_region.size()) { + auto status = absl::InternalError( + absl::StrCat("During CopyToHostBuffer, size mismatch in " + "response from proxy: ", + mem_region.size(), " vs ", data->size())); + LOG(ERROR) << status; + promise.Set(status); + return; + } +#if defined(PLATFORM_GOOGLE) + data->CopyToArray(const_cast(mem_region.data())); +#else + std::memcpy(const_cast(mem_region.data()), + data->Flatten().data(), data->size()); +#endif + promise.Set(absl::OkStatus()); + }); + }; + rpc_helper_->CopyToHostBuffer(std::move(req)).OnReady(std::move(on_ready)); + return Future(std::move(promise)); +} + +xla::ifrt::Client* Array::client() const { return client_; } + +std::string Array::DebugString() const { + return absl::Substitute("proxy::Array, this=$0, handle=$1", this, + handle_.handle); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/array.h b/xla/python/ifrt_proxy/client/array.h new file mode 100644 index 0000000000000..3b5e8d9d5e114 --- /dev/null +++ b/xla/python/ifrt_proxy/client/array.h @@ -0,0 +1,148 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation of the xla::ifrt::Array interface. +class Array final : public llvm::RTTIExtends { + public: + // `Array::MakeArrayFromHostBuffer()` implements + // `Client::MakeArrayFromHostBuffer()`. + // TODO(b/261226026): Implement logic directly in client.cc. + static absl::StatusOr> + MakeArrayFromHostBuffer(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer); + + // `Array::AssembleArrayFromSingleDeviceArrays()` implements + // `Client::AssembleArrayFromSingleDeviceArrays()`. + // TODO(b/261226026): Implement logic directly in client.cc. + static absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics); + + // Destructs the array associated with the given handle. The corresponding + // array becomes unusable afterwards. + static void Destruct(RpcHelper* rpc_helper, ArrayHandle handle); + + Array(xla::ifrt::Client* const client, std::shared_ptr rpc_helper, + DType dtype, Shape shape, std::shared_ptr sharding, + ArrayHandle handle) + : client_(client), + rpc_helper_(std::move(rpc_helper)), + dtype_(dtype), + shape_(std::move(shape)), + sharding_(std::move(sharding)), + handle_(handle) {} + + ~Array() override { Destruct(rpc_helper_.get(), handle_); } + + ArrayHandle handle() const { return handle_; } + + xla::ifrt::Client* client() const override; + Future GetReadyFuture() const override; + Future Delete() override; + bool IsDeleted() const override; + std::string DebugString() const override; + + DType dtype() const override { return dtype_; } + const Shape& shape() const override { return shape_; } + const Sharding& sharding() const override { return *sharding_; } + std::shared_ptr shared_ptr_sharding() const override { + return sharding_; + } + absl::StatusOr> layout() const override { + return absl::UnimplementedError( + "Array::layout() not implemented for IFRT proxy"); + }; + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; + + absl::StatusOr> FullyReplicatedShard( + xla::ifrt::ArrayCopySemantics semantics) override; + + ABSL_MUST_USE_RESULT + Future CopyToHostBuffer( + void* data, std::optional> byte_strides, + ArrayCopySemantics semantics) override; + + absl::StatusOr> Reshard( + std::shared_ptr new_sharding, + ArrayCopySemantics semantics) override; + + static char ID; // NOLINT + + private: + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Not owned. Used only for implementing `client()` interface method. Note + // that `client()` will still return the pointer even if the pointed-to memory + // is freed; this unfortunate behavior currently exists in all IFRT + // implementations. + xla::ifrt::Client* const client_; + + const std::shared_ptr rpc_helper_; + const DType dtype_; + const Shape shape_; + const std::shared_ptr sharding_; + const ArrayHandle handle_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_ARRAY_H_ diff --git a/xla/python/ifrt_proxy/client/array_test.cc b/xla/python/ifrt_proxy/client/array_test.cc new file mode 100644 index 0000000000000..686f533387bde --- /dev/null +++ b/xla/python/ifrt_proxy/client/array_test.cc @@ -0,0 +1,138 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/array.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::Pointee; +using ::testing::Return; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class ArrayTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests, but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ArrayTest, Destruction) { + IfrtResponse response; + EXPECT_CALL( + *session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(destruct_array_request { + array_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + tsl::MakeRef(&client, rpc_helper_, DType(DType::Kind::kBF16), + Shape({}), /*sharding=*/nullptr, ArrayHandle{1234}); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ArrayTest, FullyReplicatedShard) { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(response_metadata {} + fully_replicated_shard_response { array_handle: 5678 })pb", + &response)); + + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(fully_replicated_shard_request { + array_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + MockDevice mock_device; + + auto sharding = xla::ifrt::SingleDeviceSharding::Create( + &mock_device, xla::ifrt::MemoryKind()); + + auto array = + tsl::MakeRef(&client, rpc_helper_, DType(DType::Kind::kBF16), + Shape({}), std::move(sharding), ArrayHandle{1234}); + + ASSERT_THAT(array->FullyReplicatedShard(ArrayCopySemantics::kAlwaysCopy), + IsOk()); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/client.cc b/xla/python/ifrt_proxy/client/client.cc new file mode 100644 index 0000000000000..7153ad21420c1 --- /dev/null +++ b/xla/python/ifrt_proxy/client/client.cc @@ -0,0 +1,209 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/client.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/device.h" +#include "xla/python/ifrt_proxy/client/memory.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +char Client::ID = 0; + +absl::StatusOr> Client::Create( + std::shared_ptr rpc_helper, InitResponse init_response) { + absl::flat_hash_set addressable_device_ids( + init_response.addressable_device_ids().begin(), + init_response.addressable_device_ids().end()); + + absl::flat_hash_map> memories; + for (const auto& m : init_response.memories()) { + auto memory = std::make_unique(m.id(), m.memory_space_kind(), + m.debug_string(), m.to_string()); + memories.insert({m.id(), std::move(memory)}); + } + + absl::flat_hash_map> devices; + std::vector device_ptrs; + std::vector addressable_device_ptrs; + + for (const auto& d : init_response.devices()) { + absl::flat_hash_map attributes; + for (const auto& [key, attr] : d.attributes()) { + TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value, + FromVariantProto(attr)); + attributes.insert({key, std::move(value)}); + } + + DeviceDescription desc(d.id(), init_response.process_index(), + d.device_kind(), d.debug_string(), d.to_string(), + std::move(attributes)); + bool is_addressable = addressable_device_ids.contains(d.id()); + + auto device = + std::make_unique(std::move(desc), d.local_device_id(), + d.local_hardware_id(), is_addressable); + device_ptrs.push_back(device.get()); + if (is_addressable) { + addressable_device_ptrs.push_back(device.get()); + } + + if (d.has_default_memory_id()) { + const auto it = memories.find(d.default_memory_id()); + if (it == memories.end()) { + return absl::NotFoundError( + absl::StrCat("Memory ", d.default_memory_id(), " not found")); + } + device->default_memory_space_ = it->second.get(); + } + for (const int memory_id : d.memory_ids()) { + const auto it = memories.find(memory_id); + if (it == memories.end()) { + return absl::NotFoundError( + absl::StrCat("Memory ", memory_id, " not found")); + } + device->memory_spaces_.push_back(it->second.get()); + } + + devices.insert({d.id(), std::move(device)}); + } + + for (const auto& m : init_response.memories()) { + Memory* memory = memories.at(m.id()).get(); + for (const int device_id : m.device_ids()) { + const auto device = devices.find(device_id); + if (device == devices.end()) { + return absl::NotFoundError( + absl::StrCat("Device ", device_id, " not found")); + } + memory->devices_.push_back(device->second.get()); + } + } + + // Prefix the runtime_type string received from the server with "proxy/" so + // that the users (of this proxy client, such as JAX) do not erroneously + // conclude that they are talking with the backend runtime directly. + std::string runtime_type = + absl::StrCat("proxy/", init_response.runtime_type()); + + return absl::WrapUnique(new Client( + std::move(rpc_helper), init_response.session_id(), + init_response.platform_name(), init_response.platform_version(), + init_response.platform_id(), init_response.process_index(), runtime_type, + std::move(devices), std::move(device_ptrs), + std::move(addressable_device_ptrs), std::move(memories))); +} + +Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, + std::string platform_name, std::string platform_version, + uint64_t platform_id, uint64_t process_index, + std::string runtime_type, + absl::flat_hash_map> devices, + std::vector device_ptrs, + std::vector addressable_device_ptrs, + absl::flat_hash_map> memories) + : rpc_helper_(rpc_helper), + platform_name_(std::move(platform_name)), + platform_version_(std::move(platform_version)), + platform_id_(platform_id), + process_index_(process_index), + runtime_type_(std::move(runtime_type)), + devices_(std::move(devices)), + device_ptrs_(device_ptrs), + addressable_device_ptrs_(std::move(addressable_device_ptrs)), + memories_(std::move(memories)), + default_compiler_(this, rpc_helper) {} + +Client::~Client() { rpc_helper_->Disconnect(); } + +absl::StatusOr Client::LookupDevice(int device_id) const { + auto it = devices_.find(device_id); + if (it == devices_.end()) { + return absl::NotFoundError( + absl::StrCat("Device ", device_id, " not found.")); + } + return it->second.get(); +} + +absl::StatusOr> +Client::MakeArrayFromHostBuffer( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + xla::ifrt::Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + return Array::MakeArrayFromHostBuffer( + this, rpc_helper_, data, dtype, std::move(shape), std::move(byte_strides), + std::move(sharding), semantics, std::move(on_done_with_host_buffer)); +} + +absl::StatusOr> +Client::AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) { + return Array::AssembleArrayFromSingleDeviceArrays( + this, rpc_helper_, std::move(shape), sharding, arrays, semantics); +} + +absl::StatusOr Client::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + auto req = std::make_unique(); + req->set_num_replicas(num_replicas); + req->set_num_partitions(num_partitions); + + auto future = rpc_helper_->GetDefaultDeviceAssignment(std::move(req)); + TF_ASSIGN_OR_RETURN(auto response, future.Await()); + + TF_ASSIGN_OR_RETURN( + auto assignment_to_return, + DeviceAssignment::Deserialize(response->device_assignment())); + + return *std::move(assignment_to_return); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/client.h b/xla/python/ifrt_proxy/client/client.h new file mode 100644 index 0000000000000..9a8da13fbd59c --- /dev/null +++ b/xla/python/ifrt_proxy/client/client.h @@ -0,0 +1,163 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/compiler.h" +#include "xla/python/ifrt_proxy/client/device.h" +#include "xla/python/ifrt_proxy/client/memory.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation of the xla::ifrt::Client interface. +class Client final : public llvm::RTTIExtends { + public: + static absl::StatusOr> Create( + std::shared_ptr rpc_helper, InitResponse init_response); + + ~Client() override; + + absl::StatusOr> MakeArrayFromHostBuffer( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, HostBufferSemantics semantics, + std::function on_done_with_host_buffer) override; + + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics semantics) override; + + absl::StatusOr> MakeTuple( + absl::Span> values) override { + return absl::UnimplementedError( + "MakeTuple is not supported for the IFRT proxy client."); + } + + absl::string_view runtime_type() const override { return runtime_type_; } + absl::string_view platform_name() const override { return platform_name_; } + absl::string_view platform_version() const override { + return platform_version_; + } + PlatformId platform_id() const override { return platform_id_; } + absl::flat_hash_map attributes() + const override { + // TODO(b/309059940): Forward the backend attributes to the client. + return {}; + } + int device_count() const override { return devices().size(); } + int addressable_device_count() const override { + return addressable_devices().size(); + } + absl::Span devices() const override { + return device_ptrs_; + } + absl::Span addressable_devices() const override { + return addressable_device_ptrs_; + } + int process_index() const override { return process_index_; } + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + absl::StatusOr LookupDevice(int device_id) const override; + absl::StatusOr LookupAddressableDevice( + int local_hardware_id) const override { + return absl::UnimplementedError( + "LookupAddressableDevice is not supported for the IFRT proxy client."); + } + xla::ifrt::Compiler* GetDefaultCompiler() override { + return &default_compiler_; + } + absl::StatusOr> + GetTopologyForDevices( + absl::Span devices) const override { + return absl::UnimplementedError( + "GetTopologyForDevices is not supported for the IFRT proxy client."); + } + absl::StatusOr> GetDefaultLayoutForDevice( + xla::ifrt::DType dtype, absl::Span dims, + xla::ifrt::Device* device) const override { + return absl::UnimplementedError( + "GetDefaultLayout is not supported for the IFRT proxy client."); + } + + // For llvm::RTTIExtends. + static char ID; // NOLINT + + private: + Client(std::shared_ptr rpc_helper, uint64_t session_id, + std::string platform_name, std::string platform_version, + uint64_t platform_id, uint64_t process_index, std::string runtime_type, + absl::flat_hash_map> devices, + std::vector device_ptrs, + std::vector addressable_device_ptrs, + absl::flat_hash_map> memories); + + // rpc_helper_ will be referenced by various IFRT objects whose lifetime is + // managed by the layer above the IFRT interface, so shared_ptr is + // appropriate. + const std::shared_ptr rpc_helper_; + + const std::string platform_name_; + const std::string platform_version_; + const uint64_t platform_id_; + const uint64_t process_index_; + const std::string runtime_type_; + + const absl::flat_hash_map> devices_; + const std::vector device_ptrs_; + const std::vector addressable_device_ptrs_; + + const absl::flat_hash_map> memories_; + + Compiler default_compiler_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_H_ diff --git a/xla/python/ifrt_proxy/client/client_session.h b/xla/python/ifrt_proxy/client/client_session.h new file mode 100644 index 0000000000000..9bd795825e50c --- /dev/null +++ b/xla/python/ifrt_proxy/client/client_session.h @@ -0,0 +1,59 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Base class that defines the interface between IFRT service protocol and the +// stream implementation that is responsible for sending requests and receiving +// responses. +// +// `ClientSession` implementation must be thread-safe. +class ClientSession { + public: + // `Response` represents either an `IfrtResponse` value, or an `absl::Status` + // value corresponding to termination of the session stream. Value will never + // be a nullptr with OK status. + using Response = absl::StatusOr>; + + virtual ~ClientSession() = default; + + // Enqueues `request` to be sent via the stream; enqueued requests are sent in + // FIFO order. The caller must ensure that `request->op_id()` is unique + // throughout the stream's lifetime. The returned future becomes ready when a + // response for the given op id becomes ready. + virtual Future Enqueue(std::unique_ptr request) = 0; + + // Terminates the `ClientSession` if it has not already been terminated. + virtual void Finish(const absl::Status& s) {} +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_CLIENT_SESSION_H_ diff --git a/xla/python/ifrt_proxy/client/client_test.cc b/xla/python/ifrt_proxy/client/client_test.cc new file mode 100644 index 0000000000000..95fabee93c1a3 --- /dev/null +++ b/xla/python/ifrt_proxy/client/client_test.cc @@ -0,0 +1,216 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/client.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/service/computation_placer.h" +#include "tsl/platform/platform.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::ElementsAre; +using ::testing::Not; +using ::testing::Pair; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class ClientTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + InitResponse response; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + key: "name" + value { string_value: "device0" } + } + } + devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + key: "name" + value { string_value: "device1" } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + )pb", + &response)); + TF_ASSERT_OK_AND_ASSIGN(client_, Client::Create(rpc_helper_, response)); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; + std::unique_ptr client_; +}; + +TEST_F(ClientTest, Init) { + EXPECT_EQ(client_->platform_name(), "ifrt-service"); + EXPECT_EQ(client_->platform_version(), "n/a"); + EXPECT_EQ(client_->platform_id(), 42); + EXPECT_EQ(client_->process_index(), 1); + EXPECT_EQ(client_->runtime_type(), "proxy/ifrt-service"); + + ASSERT_EQ(client_->device_count(), 2); + ASSERT_EQ(client_->addressable_device_count(), 1); + + TF_ASSERT_OK_AND_ASSIGN(auto* const device0, client_->LookupDevice(0)); + EXPECT_EQ(device0->id(), 0); + EXPECT_EQ(device0->local_hardware_id(), 1234); + EXPECT_EQ(device0->device_kind(), "mock"); + EXPECT_THAT(device0->Attributes(), + ElementsAre(Pair( + "name", xla::PjRtDeviceAttribute(std::string("device0"))))); + + ASSERT_THAT(device0->memory_spaces(), SizeIs(1)); + auto* const memory0 = device0->memory_spaces()[0]; + EXPECT_EQ(memory0->id(), 0); + EXPECT_EQ(memory0->memory_space_kind(), "mock"); + EXPECT_THAT(memory0->devices(), UnorderedElementsAre(device0)); + EXPECT_THAT(device0->default_memory_space(), IsOkAndHolds(memory0)); + + TF_ASSERT_OK_AND_ASSIGN(auto* const device1, client_->LookupDevice(1)); + EXPECT_EQ(device1->id(), 1); + EXPECT_EQ(device1->local_hardware_id(), 1234); + EXPECT_EQ(device1->device_kind(), "mock"); + EXPECT_THAT(device1->Attributes(), + ElementsAre(Pair( + "name", xla::PjRtDeviceAttribute(std::string("device1"))))); + + ASSERT_THAT(device1->memory_spaces(), SizeIs(1)); + auto* const memory1 = device1->memory_spaces()[0]; + EXPECT_EQ(memory1->id(), 1); + EXPECT_EQ(memory1->memory_space_kind(), "mock"); + EXPECT_THAT(memory1->devices(), UnorderedElementsAre(device1)); + EXPECT_THAT(device1->default_memory_space(), IsOkAndHolds(memory1)); + + EXPECT_THAT(client_->addressable_devices(), ElementsAre(device1)); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ClientTest, GetDefaultDeviceAssignmentSuccess) { + IfrtResponse response; + xla::DeviceAssignment assignment(1, 3); + ASSERT_THAT(assignment.Serialize( + response.mutable_get_default_device_assignment_response() + ->mutable_device_assignment()), + IsOk()); + + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb( + get_default_device_assignment_request { + num_replicas: 1 + num_partitions: 3 + } + )pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + TF_ASSERT_OK_AND_ASSIGN(auto assignment_got, + client_->GetDefaultDeviceAssignment(1, 3)); + EXPECT_EQ(assignment_got.replica_count(), 1); + EXPECT_EQ(assignment_got.computation_count(), 3); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(ClientTest, GetDefaultDeviceAssignmentFailure) { + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb( + get_default_device_assignment_request { + num_replicas: 1 + num_partitions: 3 + } + )pb"))))) + .WillOnce(Return(Future( + absl::InternalError("injected from test")))); + + EXPECT_THAT(client_->GetDefaultDeviceAssignment(1, 3), Not(IsOk())); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/compiler.cc b/xla/python/ifrt_proxy/client/compiler.cc new file mode 100644 index 0000000000000..55132de4cec64 --- /dev/null +++ b/xla/python/ifrt_proxy/client/compiler.cc @@ -0,0 +1,159 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/compiler.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt_proxy/client/executable.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +Compiler::Compiler(xla::ifrt::Client* client, + std::shared_ptr rpc_helper) + : client_(client), rpc_helper_(std::move(rpc_helper)) {} + +absl::StatusOr> Compiler::Compile( + std::unique_ptr program, + std::unique_ptr options) { + auto request = std::make_unique(); + TF_ASSIGN_OR_RETURN(*request->mutable_program(), Serialize(*program)); + + // Extract host callbacks from the XLA compile options. `XlaCompileOptions`'s + // SerDes fails when it contains host callbacks, so the following + // implementation handles host callback serialization out of band until we can + // natively support IFRT host callback on IFRT proxy. + std::vector> + loaded_host_callbacks; + if (auto* xla_options = + llvm::dyn_cast(options.get())) { + for (const auto& loaded_host_callback : + xla_options->loaded_host_callbacks) { + auto* pjrt_host_callback = + llvm::dyn_cast( + loaded_host_callback.get()); + if (pjrt_host_callback == nullptr) { + return absl::UnimplementedError("Unsupported host callback type"); + } + + const xla::HostCallback& xla_host_callback = + pjrt_host_callback->host_callback(); + + // The proxy server runs `RemoteLoadedHostCallback` that delegates actual + // host callback execution to the proxy client. + auto remote_loaded_host_callback = tsl::MakeRef( + client_, xla_host_callback.operands, xla_host_callback.results, + /*queue=*/nullptr); + TF_ASSIGN_OR_RETURN(*request->add_host_callbacks(), + remote_loaded_host_callback->Serialize()); + } + + loaded_host_callbacks.swap(xla_options->loaded_host_callbacks); + } + + TF_ASSIGN_OR_RETURN(*request->mutable_compile_options(), Serialize(*options)); + + // TODO(b/266635130): Avoid blocking the caller. + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + rpc_helper_->Compile(std::move(request)).Await()); + + std::vector + addressable_device_logical_device_ids; + addressable_device_logical_device_ids.reserve( + response->addressable_device_logical_ids_size()); + for (const auto& logical_device_id : + response->addressable_device_logical_ids()) { + xla::ifrt::LoadedExecutable::LogicalDeviceIds id{ + logical_device_id.replica(), logical_device_id.partition()}; + addressable_device_logical_device_ids.push_back(id); + } + + std::vector addressable_devices; + addressable_devices.reserve(response->addressable_device_ids_size()); + for (const int32_t device_id : response->addressable_device_ids()) { + TF_ASSIGN_OR_RETURN(xla::ifrt::Device* const device, + client_->LookupDevice(device_id)); + addressable_devices.push_back(device); + } + + absl::StatusOr> fingerprint; + switch (response->fingerprint_case()) { + case CompileResponse::kFingerprintValue: + fingerprint = response->fingerprint_value(); + break; + case CompileResponse::kFingerprintError: + fingerprint = tsl::StatusFromProto(response->fingerprint_error()); + break; + default: + fingerprint = std::nullopt; + break; + } + + Future ready_future = + rpc_helper_->CheckFuture(response->ready_future_handle()); + + std::vector loaded_host_callback_handles( + response->loaded_host_callback_handles().begin(), + response->loaded_host_callback_handles().end()); + + return std::make_unique( + client_, rpc_helper_, response->loaded_executable_handle(), + response->name(), response->num_devices(), + std::move(addressable_device_logical_device_ids), + std::move(addressable_devices), std::move(fingerprint), + std::move(ready_future), std::move(loaded_host_callbacks), + std::move(loaded_host_callback_handles)); +} + +absl::StatusOr> +Compiler::DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) { + return absl::UnimplementedError( + "IFRT service compiler does not support `DeserializeLoadedExecutable` " + "since the underlying serialization format is not stable"); +} + +char Compiler::ID = 0; + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/compiler.h b/xla/python/ifrt_proxy/client/compiler.h new file mode 100644 index 0000000000000..6bfc814766d11 --- /dev/null +++ b/xla/python/ifrt_proxy/client/compiler.h @@ -0,0 +1,57 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class Compiler final : public llvm::RTTIExtends { + public: + Compiler(xla::ifrt::Client* client, std::shared_ptr rpc_helper); + + absl::StatusOr> Compile( + std::unique_ptr program, + std::unique_ptr options) override; + + absl::StatusOr> + DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) + override; + + static char ID; // NOLINT + + private: + xla::ifrt::Client* client_; + std::shared_ptr rpc_helper_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_COMPILER_H_ diff --git a/xla/python/ifrt_proxy/client/compiler_test.cc b/xla/python/ifrt_proxy/client/compiler_test.cc new file mode 100644 index 0000000000000..fbefdec79b00d --- /dev/null +++ b/xla/python/ifrt_proxy/client/compiler_test.cc @@ -0,0 +1,225 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/compiler.h" + +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::FieldsAre; +using ::testing::Invoke; +using ::testing::Optional; +using ::testing::Pointee; +using ::testing::Return; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +struct TestProgram : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgram::ID = 0; // NOLINT + +class TestProgramSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgramSerDes::ID = 0; // NOLINT + +struct TestCompileOptions + : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptions::ID = 0; // NOLINT + +class TestCompileOptionsSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestCompileOptions"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptionsSerDes::ID = 0; // NOLINT + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class CompilerTest : public testing::Test { + protected: + static void SetUpTestSuite() { + RegisterSerDes(std::make_unique()); + RegisterSerDes( + std::make_unique()); + } + + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(CompilerTest, Compile) { + std::vector devices(2); + + MockClient client; + ON_CALL(client, LookupDevice(_)).WillByDefault(Invoke([&](int id) { + return &devices[id]; + })); + + Compiler compiler(&client, rpc_helper_); + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(compile_response { + loaded_executable_handle: 1234 + name: "foo-executable" + num_devices: 2 + addressable_device_logical_ids { replica: 0 partition: 0 } + addressable_device_logical_ids { replica: 0 partition: 1 } + addressable_device_ids: [ 0, 1 ] + fingerprint_value: "fingerprint" + ready_future_handle: 5678 + })pb", + &response)); + EXPECT_CALL(*session_, + Enqueue(Pointee(Partially(EquivToProto( + R"pb(compile_request { + program { type_name: "xla::ifrt::proxy::TestProgram" } + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + response_metadata { + status { + code: 2 # UNKNOWN + message: "injected error" + } + } + )pb", + &response)); + EXPECT_CALL(*session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(check_future_request { + future_handle: 5678 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + compiler.Compile(std::make_unique(), + std::make_unique())); + + EXPECT_EQ(executable->name(), "foo-executable"); + EXPECT_EQ(executable->num_devices(), 2); + EXPECT_THAT(executable->addressable_device_logical_ids(), + ElementsAre(FieldsAre(0, 0), FieldsAre(0, 1))); + EXPECT_THAT(executable->addressable_devices(), + ElementsAre(&devices[0], &devices[1])); + EXPECT_THAT(executable->Fingerprint(), + IsOkAndHolds(Optional(std::string("fingerprint")))); + EXPECT_THAT(executable->GetReadyFuture().Await(), + StatusIs(absl::StatusCode::kUnknown, "injected error")); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/device.cc b/xla/python/ifrt_proxy/client/device.cc new file mode 100644 index 0000000000000..f43d9aec101f7 --- /dev/null +++ b/xla/python/ifrt_proxy/client/device.cc @@ -0,0 +1,61 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/device.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +std::unique_ptr Device::CreateAsyncTrackingEvent( + absl::string_view description) const { + return nullptr; +} + +absl::Status Device::TransferToInfeed(const xla::LiteralSlice& literal) { + return absl::UnimplementedError("Device does not support TransferToInfeed"); +} + +absl::Status Device::TransferFromOutfeed(xla::MutableBorrowingLiteral literal) { + return absl::UnimplementedError( + "Device does not support TransferFromOutfeed"); +} + +absl::Span Device::memory_spaces() const { + return memory_spaces_; +} + +absl::StatusOr Device::default_memory_space() const { + if (default_memory_space_ == nullptr) { + return absl::UnimplementedError( + "Device does not support default_memory_space"); + } + return default_memory_space_; +} + +char Device::ID = 0; // NOLINT + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/device.h b/xla/python/ifrt_proxy/client/device.h new file mode 100644 index 0000000000000..6cb461865818d --- /dev/null +++ b/xla/python/ifrt_proxy/client/device.h @@ -0,0 +1,139 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class Client; + +class DeviceDescription final : public xla::PjRtDeviceDescription { + public: + DeviceDescription( + int id, int process_index, std::string device_kind, + std::string debug_string, std::string to_string, + absl::flat_hash_map attributes) + : id_(id), + process_index_(process_index), + device_kind_(device_kind), + debug_string_(std::move(debug_string)), + to_string_(std::move(to_string)), + attributes_(std::move(attributes)) {} + + int id() const override { return id_; } + + int process_index() const override { return process_index_; } + + absl::string_view device_kind() const override { return device_kind_; } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + private: + int id_; + int process_index_; + std::string device_kind_; + std::string debug_string_; + std::string to_string_; + absl::flat_hash_map attributes_; +}; + +class Device final : public xla::ifrt::Device { + public: + Device(DeviceDescription description, int local_device_id, + int local_hardware_id, bool is_addressable) + : description_(std::move(description)), + local_device_id_(local_device_id), + local_hardware_id_(local_hardware_id), + is_addressable_(is_addressable) {} + + xla::PjRtClient* client() const override { return nullptr; } + + bool IsAddressable() const override { return is_addressable_; } + + const xla::PjRtDeviceDescription& description() const override { + return description_; + } + + int local_hardware_id() const override { + return local_hardware_id_typed().value(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(local_device_id_); + } + + PjRtLocalHardwareId local_hardware_id_typed() const override { + return PjRtLocalHardwareId(local_hardware_id_); + } + + std::unique_ptr CreateAsyncTrackingEvent( + absl::string_view description) const override; + + absl::Status TransferToInfeed(const xla::LiteralSlice& literal) override; + + absl::Status TransferFromOutfeed( + xla::MutableBorrowingLiteral literal) override; + + absl::Span memory_spaces() const override; + + absl::StatusOr default_memory_space() const override; + + static char ID; // NOLINT + + private: + friend class Client; // For `memory_spaces_` initialization. + + const DeviceDescription description_; + const int local_device_id_; + const int local_hardware_id_; + const bool is_addressable_; + + std::vector memory_spaces_; + xla::ifrt::Memory* default_memory_space_ = nullptr; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_DEVICE_H_ diff --git a/xla/python/ifrt_proxy/client/executable.cc b/xla/python/ifrt_proxy/client/executable.cc new file mode 100644 index 0000000000000..a8c9cc0788d31 --- /dev/null +++ b/xla/python/ifrt_proxy/client/executable.cc @@ -0,0 +1,581 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/executable.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/cpu_info.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Locally executes the loaded host callback with given operand buffer from the +// IFRT proxy server and returns a result buffer to be sent back. +absl::StatusOr ExecuteLoadedHostCallback( + xla::ifrt::LoadedHostCallback* loaded_host_callback, + absl::Cord operand_buffer) { +#if defined(PLATFORM_GOOGLE) + auto* pjrt_host_callback = + llvm::dyn_cast( + loaded_host_callback); + if (pjrt_host_callback == nullptr) { + return absl::UnimplementedError( + "Non-PjRt host callbacks cannot be executed"); + } + const xla::HostCallback& xla_host_callback = + pjrt_host_callback->host_callback(); + + // The following allocates both operands and results using `aligned_alloc` in + // order to (loosely) emulate the XLA implementation where host callbacks are + // often called with aligned operand/result buffers. While this may not be + // strictly necessary for some callbacks, this reduces the chances of proxied + // callbacks behaving differently on a best-effort basis. + constexpr int kAlignment = 32; + + struct Deleter { + void operator()(void* p) { free(p); } + }; + + std::vector> operands; + operands.reserve(xla_host_callback.operands.size()); + std::vector operand_ptrs; + operand_ptrs.reserve(xla_host_callback.operands.size()); + + absl::CordReader reader(operand_buffer); + for (const auto& spec : xla_host_callback.operands) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(spec.shape); + void* p; + CHECK_EQ(posix_memalign(&p, kAlignment, size), 0); + std::unique_ptr buffer(reinterpret_cast(p)); + + if (reader.Available() < size) { + return absl::InternalError(absl::StrCat( + "Buffer overflow while reading host callback execution operands; ", + "range: [", reader.Position(), ", ", reader.Position() + size, "), ", + "buffer size: ", operand_buffer.size())); + } + reader.ReadN(size, buffer.get()); + + operand_ptrs.push_back(buffer.get()); + operands.push_back(std::move(buffer)); + } + if (reader.Available() > 0) { + return absl::InternalError(absl::StrCat( + "Host callback execution did not consume the entire operand buffer; " + "size: ", + operand_buffer.size(), "; consumed: ", reader.Available())); + } + + absl::Cord result_buffer; + std::vector result_ptrs; + result_ptrs.reserve(xla_host_callback.results.size()); + + for (const auto& spec : xla_host_callback.results) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(spec.shape); + void* data; + CHECK_EQ(posix_memalign(&data, kAlignment, size), 0); + + result_ptrs.push_back(data); + result_buffer.AppendExternalMemory( + absl::string_view(reinterpret_cast(data), size), data, &free); + } + + TF_RETURN_IF_ERROR( + xla_host_callback.callback(result_ptrs.data(), operand_ptrs.data())); + + return result_buffer; +#else + return absl::UnimplementedError("ExecuteLoadedHostCallback is unsupported."); +#endif +} + +// Same as `ExecuteLoadedHostCallback`, except that it uses host buffer store to +// retrieve operands and store results. +absl::StatusOr PrepareAndExecuteLoadedHostCallback( + ClientHostBufferStore* host_buffer_store, + xla::ifrt::LoadedHostCallback* loaded_host_callback, + uint64_t operand_handle) { + TF_ASSIGN_OR_RETURN(absl::Cord operands, + host_buffer_store->Lookup(operand_handle).Await()); + absl::Cleanup cleanup = [&]() { + host_buffer_store->Delete(operand_handle).OnReady([](absl::Status status) { + if (!status.ok()) { + LOG(ERROR) << "Failed to delete host callback operands: " << status; + } + }); + }; + + TF_ASSIGN_OR_RETURN( + absl::Cord results, + ExecuteLoadedHostCallback(loaded_host_callback, std::move(operands))); + + const uint64_t result_handle = host_buffer_store->NextHandle(); + TF_RETURN_IF_ERROR(host_buffer_store->Store(result_handle, results).Await()); + return result_handle; +} + +} // namespace + +LoadedExecutable::LoadedExecutable( + xla::ifrt::Client* client, std::shared_ptr rpc_helper, + uint64_t handle, std::string name, int num_devices, + std::vector + addressable_device_logical_device_ids, + std::vector addressable_devices, + absl::StatusOr> fingerprint, + Future ready_future, + std::vector> + loaded_host_callbacks, + std::vector loaded_host_callback_handles) + : client_(client), + rpc_helper_(std::move(rpc_helper)), + handle_(handle), + name_(std::move(name)), + num_devices_(num_devices), + addressable_device_logical_device_ids_( + std::move(addressable_device_logical_device_ids)), + addressable_devices_(std::move(addressable_devices)), + fingerprint_(std::move(fingerprint)), + ready_future_(std::move(ready_future)) { + // Start host callback pollers. + CHECK_EQ(loaded_host_callbacks.size(), loaded_host_callback_handles.size()); + if (!loaded_host_callbacks.empty()) { + for (int i = 0; i < loaded_host_callbacks.size(); ++i) { + PollLoadedHostCallback(loaded_host_callback_handles[i], + loaded_host_callbacks[i]); + } + } + + // Asynchronously fetch shardings. Since users of `LoadedExecutable` typically + // require sharding information to invoke the executable, it is beneficial to + // eagerly schedule this fetch since, in some implementations, it may take a + // long time for sharding information to be available. + + auto promise = + Future>>::CreatePromise(); + metadata_future_ = Future>>(promise); + + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + auto on_done = + [promise]( + absl::StatusOr> + response) mutable { + if (!response.ok()) { + LOG(ERROR) << "LoadedExecutableMetadata: Got " << response.status(); + promise.Set(response.status()); + return; + } + + auto info = std::make_shared(); + + if (response.value()->has_parameter_shardings()) { + const auto& p = response.value()->parameter_shardings().shardings(); + info->parameter_shardings.emplace(p.begin(), p.end()); + } + if (response.value()->has_output_shardings()) { + const auto& o = response.value()->output_shardings().shardings(); + info->output_shardings.emplace(o.begin(), o.end()); + } + + auto parse_layouts = + [](const LoadedExecutableMetadataResponse::LayoutList& list) { + std::vector layouts; + layouts.reserve(list.layouts_size()); + for (const auto& layout : list.layouts()) { + layouts.push_back(xla::Layout::CreateFromProto(layout)); + } + return layouts; + }; + + if (response.value()->has_parameter_layouts_list()) { + info->parameter_layouts = + parse_layouts(response.value()->parameter_layouts_list()); + } else if (response.value()->has_parameter_layouts_error()) { + info->parameter_layouts = + tsl::StatusFromProto(response.value()->parameter_layouts_error()); + } else { + info->parameter_layouts = absl::UnimplementedError( + "IFRT Proxy server did not return parameter layouts"); + } + if (response.value()->has_output_layouts_list()) { + info->output_layouts = + parse_layouts(response.value()->output_layouts_list()); + } else if (response.value()->has_output_layouts_error()) { + info->output_layouts = + tsl::StatusFromProto(response.value()->output_layouts_error()); + } else { + info->output_layouts = absl::UnimplementedError( + "IFRT Proxy server did not return output layouts"); + } + + if (const absl::Status s = tsl::StatusFromProto( + response.value()->output_memory_kinds().status()); + !s.ok()) { + info->output_memory_kinds = s; + } else { + std::vector> output_memory_kinds; + for (const auto& list : + response.value()->output_memory_kinds().memory_kind_lists()) { + std::vector kinds; + kinds.reserve(list.memory_kinds_size()); + for (const absl::string_view kind : list.memory_kinds()) { + const auto it = + info->memory_kinds.insert(std::string(kind)).first; + kinds.push_back(*it); + } + output_memory_kinds.push_back(std::move(kinds)); + } + info->output_memory_kinds = std::move(output_memory_kinds); + } + + promise.Set(std::move(info)); + }; + rpc_helper_->LoadedExecutableMetadata(std::move(req)) + .OnReady(std::move(on_done)); +} + +LoadedExecutable::~LoadedExecutable() { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + rpc_helper_->LoadedExecutableDestruct(std::move(req)) + .OnReady( + [](absl::StatusOr> + response) { + if (!response.ok()) { + LOG(ERROR) << "Failed to destroy `LoadedExecutable`: " + << response.status(); + } + }); +} + +xla::ifrt::Client* LoadedExecutable::client() const { return client_; } + +absl::string_view LoadedExecutable::name() const { return name_; } + +absl::StatusOr> LoadedExecutable::Fingerprint() + const { + return fingerprint_; +} + +absl::StatusOr LoadedExecutable::Serialize() const { + return absl::UnimplementedError( + "IFRT service executable does not support `Serialize` since the " + "underlying serialization format is not stable"); +} + +Future LoadedExecutable::GetReadyFuture() const { + return ready_future_; +} + +int LoadedExecutable::num_devices() const { return num_devices_; } + +int64_t LoadedExecutable::SizeOfGeneratedCodeInBytes() const { + LOG(FATAL) << "Unimplemented"; +} + +absl::StatusOr LoadedExecutable::GetCompiledMemoryStats() + const { + return absl::UnimplementedError("Unimplemented"); +} + +std::optional> LoadedExecutable::GetParameterShardings() + const { + auto info = metadata_future_.Await(); + if (!info.ok()) { + return std::nullopt; + } + return (*info)->parameter_shardings; +} + +std::optional> LoadedExecutable::GetOutputShardings() + const { + auto info = metadata_future_.Await(); + if (!info.ok()) { + return std::nullopt; + } + return (*info)->output_shardings; +} + +absl::StatusOr>> +LoadedExecutable::GetParameterLayouts() const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + TF_RETURN_IF_ERROR(info->parameter_layouts.status()); + + std::vector> result; + result.reserve(info->parameter_layouts->size()); + for (const xla::Layout& layout : *info->parameter_layouts) { + result.push_back(std::make_unique(layout)); + } + return result; +} + +absl::StatusOr>> +LoadedExecutable::GetOutputLayouts() const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + TF_RETURN_IF_ERROR(info->output_layouts.status()); + + std::vector> result; + result.reserve(info->output_layouts->size()); + for (const xla::Layout& layout : *info->output_layouts) { + result.push_back(std::make_unique(layout)); + } + return result; +} + +absl::StatusOr>> +LoadedExecutable::GetOutputMemoryKinds() const { + TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); + return info->output_memory_kinds; +} + +absl::StatusOr>> +LoadedExecutable::GetHloModules() const { + return absl::UnimplementedError( + "IFRT service does not support LoadedExecutable::GetHloModules() since " + "HloModule does not provide stable serialization"); +} + +absl::StatusOr< + absl::flat_hash_map> +LoadedExecutable::GetCostAnalysis() const { + return absl::UnimplementedError("Unimplemented"); +} + +absl::StatusOr +LoadedExecutable::Execute(absl::Span> args, + const ExecuteOptions& options, + std::optional devices) { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + for (const auto& arg : args) { + auto* array = llvm::dyn_cast_or_null(arg.get()); + if (array == nullptr) { + return absl::InvalidArgumentError( + "Invalid IFRT array type provided to `LoadedExecutable::Execute`"); + } + req->add_args_handles(array->handle().handle); + } + TF_ASSIGN_OR_RETURN(*req->mutable_execute_options(), options.ToProto()); + if (devices.has_value()) { + for (const auto* device : *devices) { + req->add_device_ids(device->id()); + } + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr response, + rpc_helper_->LoadedExecutableExecute(std::move(req)).Await()); + + // NOTE: All future and array handles in `response` must have an owner + // locally, or be requested to be destructed remotely, before returning. + + xla::ifrt::LoadedExecutable::ExecuteResult result; + + // Populate the execution status future. `CheckFuture` deletes the server-side + // futures after its completion. + result.status = rpc_helper_->CheckFuture(response->status_handle()); + + // Create output arrays. The cleanup logic ensures that all handles are + // properly cleaned up on early return. + absl::Cleanup cleanup = [&]() { + int index = result.outputs.size(); + result.outputs.clear(); // Cleaned up by `~Array()`. + + for (; index < response->outputs_size(); ++index) { + Array::Destruct(rpc_helper_.get(), + ArrayHandle{response->outputs(index).array_handle()}); + } + }; + const auto lookup_device = absl::bind_front(&Client::LookupDevice, client()); + for (const auto& output : response->outputs()) { + TF_ASSIGN_OR_RETURN(DType dtype, DType::FromProto(output.dtype())); + TF_ASSIGN_OR_RETURN(Shape shape, Shape::FromProto(output.shape())); + TF_ASSIGN_OR_RETURN(auto sharding, + FromShardingProto(lookup_device, output.sharding())); + result.outputs.push_back(tsl::MakeRef( + client(), rpc_helper_, dtype, std::move(shape), std::move(sharding), + ArrayHandle{output.array_handle()})); + } + std::move(cleanup).Cancel(); + + return result; +} + +Future LoadedExecutable::Delete() { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + absl::StatusOr> response = + rpc_helper_->LoadedExecutableDelete(std::move(req)).Await(); + if (!response.ok()) { + return Future(response.status()); + } + return rpc_helper_->CheckFuture((*response)->future_handle()); +} + +bool LoadedExecutable::IsDeleted() const { + auto req = std::make_unique(); + req->set_loaded_executable_handle(handle_); + + absl::StatusOr> response = + rpc_helper_->LoadedExecutableIsDeleted(std::move(req)).Await(); + if (!response.ok()) { + LOG(ERROR) << "Failed to query the deletion status of `LoadedExecutable`: " + << response.status(); + return false; + } + return (*response)->is_deleted(); +} + +absl::Span +LoadedExecutable::addressable_device_logical_ids() const { + return addressable_device_logical_device_ids_; +} + +absl::Span LoadedExecutable::addressable_devices() + const { + return addressable_devices_; +} + +namespace { + +static tsl::ThreadOptions GetThreadOptions() { + tsl::ThreadOptions thread_options; + // Ensure the threads' stack is large enough for arbitrary Python code. + thread_options.stack_size = 2 * 1024 * 1024; // 2 MiB + return thread_options; +} + +} // namespace + +void LoadedExecutable::PollLoadedHostCallback( + uint64_t handle, + tsl::RCReference loaded_host_callback) { + // Note: individual host callbacks may live longer than the executable as the + // destruction of an IFRT executable is not required to block until all + // in-flight executions are complete. Therefore, the following lambda must not + // capture `this` and is scheduled on the default thread pool. + auto f = [rpc_helper = rpc_helper_, handle, + loaded_host_callback = std::move(loaded_host_callback)]() { + while (true) { + const uint64_t operand_handle = + rpc_helper->host_buffer_store()->NextHandle(); + + auto poll_req = std::make_unique(); + poll_req->set_loaded_host_callback_handle(handle); + poll_req->set_operand_host_buffer_handle(operand_handle); + auto response = + rpc_helper->LoadedHostCallbackPoll(std::move(poll_req)).Await(); + + if (!response.ok()) { + LOG_EVERY_N_SEC(ERROR, 60) + << "Failed to poll host callback execution: " << response.status(); + continue; + } + + if (!(*response)->has_host_callback_execution_handle()) { + // The host callback is destructed from the server. + break; + } + + auto ret_req = std::make_unique(); + ret_req->set_host_callback_execution_handle( + (*response)->host_callback_execution_handle()); + + absl::StatusOr result_handle = + PrepareAndExecuteLoadedHostCallback( + rpc_helper->host_buffer_store().get(), loaded_host_callback.get(), + operand_handle); + if (result_handle.ok()) { + ret_req->set_result_host_buffer_handle(*result_handle); + } else { + *ret_req->mutable_error() = tsl::StatusToProto(result_handle.status()); + } + + rpc_helper->LoadedHostCallbackReturn(std::move(ret_req)) + .OnReady([](absl::StatusOr< + std::shared_ptr> + response) { + if (!response.ok()) { + LOG(ERROR) << "Failed to return host callback results: " + << response.status(); + } + }); + } + }; + + static auto* global_pool = new tsl::thread::ThreadPool( + tsl::Env::Default(), GetThreadOptions(), "XLAIFRTProxy", + std::min(16, tsl::port::MaxParallelism())); + global_pool->Schedule(std::move(f)); +} + +char LoadedExecutable::ID = 0; // NOLINT + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/executable.h b/xla/python/ifrt_proxy/client/executable.h new file mode 100644 index 0000000000000..d6b12d9721191 --- /dev/null +++ b/xla/python/ifrt_proxy/client/executable.h @@ -0,0 +1,147 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class LoadedExecutable final + : public llvm::RTTIExtends { + public: + LoadedExecutable(xla::ifrt::Client* client, + std::shared_ptr rpc_helper, uint64_t handle, + std::string name, int num_devices, + std::vector + addressable_device_logical_device_ids, + std::vector addressable_devices, + absl::StatusOr> fingerprint, + Future ready_future, + std::vector> + loaded_host_callbacks, + std::vector loaded_host_callback_handles); + + ~LoadedExecutable() override; + + xla::ifrt::Client* client() const override; + absl::string_view name() const override; + absl::StatusOr> Fingerprint() const override; + absl::StatusOr Serialize() const override; + Future GetReadyFuture() const override; + + int num_devices() const override; + int64_t SizeOfGeneratedCodeInBytes() const override; + absl::StatusOr GetCompiledMemoryStats() const override; + + std::optional> GetParameterShardings() const override; + std::optional> GetOutputShardings() const override; + absl::StatusOr>> GetParameterLayouts() + const override; + absl::StatusOr>> GetOutputLayouts() + const override; + absl::StatusOr>> + GetOutputMemoryKinds() const override; + absl::StatusOr>> GetHloModules() + const override; + + absl::StatusOr> + GetCostAnalysis() const override; + + absl::StatusOr Execute( + absl::Span> args, + const ExecuteOptions& options, + std::optional devices) override; + + Future Delete() override; + bool IsDeleted() const override; + + absl::Span addressable_device_logical_ids() + const override; + absl::Span addressable_devices() const override; + + static char ID; // NOLINT + + private: + struct Metadata { + std::optional> parameter_shardings; + std::optional> output_shardings; + + absl::StatusOr> parameter_layouts; + absl::StatusOr> output_layouts; + + // Elements in `output_memory_kinds` point to elements in `memory_kinds`. + // Required since `GetOutputMemoryKinds()` returns `absl::string_view`. + // `memory_kinds` uses `absl::node_hash_set` for pointer stability. + absl::node_hash_set memory_kinds; + absl::StatusOr>> + output_memory_kinds; + }; + + void PollLoadedHostCallback( + uint64_t handle, + tsl::RCReference loaded_host_callback); + + xla::ifrt::Client* client_; + std::shared_ptr rpc_helper_; + + const uint64_t handle_; + const std::string name_; + const int num_devices_; + const std::vector + addressable_device_logical_device_ids_; + const std::vector addressable_devices_; + const absl::StatusOr> fingerprint_; + const Future ready_future_; + + // Metadata queried when the executable is created. Declared as `mutable` + // since `Future::Await()` is not const. + mutable Future>> metadata_future_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_EXECUTABLE_H_ diff --git a/xla/python/ifrt_proxy/client/executable_test.cc b/xla/python/ifrt_proxy/client/executable_test.cc new file mode 100644 index 0000000000000..33a33276eb7ba --- /dev/null +++ b/xla/python/ifrt_proxy/client/executable_test.cc @@ -0,0 +1,355 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/executable.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/layout_util.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/client/array.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/mock_host_buffer.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Optional; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::Partially; +#endif + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kClientMinVersion); + return version; +} + +class LoadedExecutableTest : public ::testing::Test { + protected: + void SetUp() override { + session_ = std::make_shared(); + rpc_helper_ = std::make_shared(Version(), session_); + + host_buffer_store_ = std::make_shared(); + rpc_helper_->set_host_buffer_store(host_buffer_store_); + + // Default handler that ignores all uninteresting requests, but still + // invokes the callback in order to avoid hanging the caller forever. + EXPECT_CALL(*session_, Enqueue(_)) + .WillRepeatedly(Return(Future( + absl::InternalError("Request has no mock handlers")))); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + std::shared_ptr host_buffer_store_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Metadata) { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_metadata_response { + parameter_shardings { + shardings { type: REPLICATED } + shardings { + type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ] + } + } + output_shardings { shardings { type: REPLICATED } } + parameter_layouts_list { + layouts { minor_to_major: 0 } + layouts { minor_to_major: [ 1, 0 ] } + } + output_layouts_list { layouts { minor_to_major: [ 1, 0 ] } } + output_memory_kinds { memory_kind_lists { memory_kinds: [ "foo" ] } } + } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_metadata_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + MockClient client; + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*ready_future=*/Future(absl::OkStatus()), + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + EXPECT_THAT( + executable.GetParameterShardings(), + Optional(ElementsAre( + EquivToProto(R"pb(type: REPLICATED)pb"), + EquivToProto(R"pb(type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ])pb")))); + EXPECT_THAT(executable.GetOutputShardings(), + Optional(ElementsAre(EquivToProto(R"pb(type: REPLICATED)pb")))); + ASSERT_OK_AND_ASSIGN(auto parameter_layouts, + executable.GetParameterLayouts()); + EXPECT_EQ(parameter_layouts.size(), 2); + EXPECT_EQ( + tensorflow::down_cast(parameter_layouts[0].get()) + ->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1)); + EXPECT_EQ( + tensorflow::down_cast(parameter_layouts[1].get()) + ->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); + ASSERT_OK_AND_ASSIGN(auto output_layouts, executable.GetOutputLayouts()); + EXPECT_EQ(output_layouts.size(), 1); + EXPECT_EQ(tensorflow::down_cast(output_layouts[0].get()) + ->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); + EXPECT_THAT(executable.GetOutputMemoryKinds(), + IsOkAndHolds(ElementsAre(ElementsAre("foo")))); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Execute) { + MockDevice device; + ON_CALL(device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(1))); + + MockClient client; + ON_CALL(client, LookupDevice(1)).WillByDefault(Return(&device)); + + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*ready_future=*/Future(absl::OkStatus()), + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + loaded_executable_execute_response { + status_handle: 2000 + outputs { + dtype { kind: KIND_F32 } + shape { dims: [ 4, 4 ] } + array_handle: 3000 + } + outputs { + dtype { kind: KIND_F16 } + shape { dims: [ 8 ] } + array_handle: 3001 + } + } + )pb", + &response)); + { + auto* outputs = response.mutable_loaded_executable_execute_response() + ->mutable_outputs(); + TF_ASSERT_OK_AND_ASSIGN( + *(*outputs)[0].mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + TF_ASSERT_OK_AND_ASSIGN( + *(*outputs)[1].mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + } + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_execute_request { + loaded_executable_handle: 1234 + args_handles: [ 1000, 1001 ] + device_ids: [ 1 ] + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + response_metadata { + status { + code: 2 # UNKNOWN + message: "injected error" + } + } + )pb", + &response)); + EXPECT_CALL(*session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(check_future_request { + future_handle: 2000 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + DeviceList devices({&device}); + + std::vector> args; + for (const uint64_t handle : {1000, 1001}) { + args.push_back(tsl::MakeRef( + &client, rpc_helper_, DType(DType::kF32), Shape({2, 2}), + OpaqueSharding::Create(devices, MemoryKind()), ArrayHandle{handle})); + } + + TF_ASSERT_OK_AND_ASSIGN( + auto result, executable.Execute( + absl::MakeSpan(args), + xla::ifrt::LoadedExecutable::ExecuteOptions(), devices)); + + EXPECT_THAT(result.status.Await(), + StatusIs(absl::StatusCode::kUnknown, "injected error")); + + ASSERT_THAT(result.outputs, SizeIs(2)); + + const auto output0 = result.outputs[0]; + EXPECT_EQ(output0->dtype(), DType(DType::kF32)); + EXPECT_EQ(output0->shape(), Shape({4, 4})); + EXPECT_EQ(llvm::cast(output0.get())->handle().handle, 3000); + + const auto output1 = result.outputs[1]; + EXPECT_EQ(output1->dtype(), DType(DType::kF16)); + EXPECT_EQ(output1->shape(), Shape({8})); + EXPECT_EQ(llvm::cast(output1.get())->handle().handle, 3001); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(LoadedExecutableTest, Delete) { + MockClient client; + LoadedExecutable executable( + &client, rpc_helper_, /*handle=*/1234, /*name=*/"foo", + /*num_devices=*/2, /*addressable_device_logical_device_ids=*/{}, + /*addressable_devices=*/{}, /*fingerprint=*/"fingerprint", + /*ready_future=*/Future(absl::OkStatus()), + /*loaded_host_callbacks=*/{}, /*loaded_host_callback_handles=*/{}); + + { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_delete_response { future_handle: 2000 } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_delete_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + response_metadata { + status { + code: 2 # UNKNOWN + message: "injected error" + } + } + )pb", + &response)); + EXPECT_CALL( + *session_, + Enqueue(Pointee(Partially(EquivToProto(R"pb(check_future_request { + future_handle: 2000 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + Future result = executable.Delete(); + EXPECT_THAT(result.Await(), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); + } + + { + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_is_deleted_response { is_deleted: true } + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_is_deleted_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); + + EXPECT_TRUE(executable.IsDeleted()); + } + + IfrtResponse response; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + loaded_executable_destruct_response {} + )pb", + &response)); + EXPECT_CALL(*session_, Enqueue(Pointee(Partially(EquivToProto( + R"pb(loaded_executable_destruct_request { + loaded_executable_handle: 1234 + })pb"))))) + .WillOnce(MockClientSessionReturnResponse(response)); +} +#endif + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_client.cc b/xla/python/ifrt_proxy/client/grpc_client.cc new file mode 100644 index 0000000000000..4279ab7d4d430 --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_client.cc @@ -0,0 +1,189 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/log/log_sink.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "grpcpp/client_context.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client.h" +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Attempts to establish a session to the proxy-server and returns a `Client` +// based on the session if successful. `on_disconnect` will be invoked exactly +// once if this function returns successfully, and not invoked if this function +// returns a non-OK status. +absl::StatusOr> AttemptConnection( + absl::string_view server_address, + std::function on_disconnect, int attempt_no, + absl::AnyInvocable log_initial_connection) { + std::unique_ptr rpc_helper; + auto init_response_promise = + Future>>::CreatePromise(); + + if (on_disconnect == nullptr) { + on_disconnect = [](absl::Status s) { + LOG(WARNING) << "IFRT proxy server disconnected: " << s; + }; + } + + // TODO(b/266635130): Move gRPC stub creation to be outside of `Client` so + // that we can pass mock `ClientSession` to the client. + auto stub = CreateGrpcStub(server_address); + + auto session_disconnect_cb = + [init_response = Future>>( + init_response_promise), + on_disconnect = std::move(on_disconnect), + attempt_no](absl::Status s) mutable { + // If the `rpc_helper->Init().OnReady(cb)` statement below has returned, + // the callback cb in that statement (which sets `init_response`) is + // guaranteed by `GrpcClientSession::Create()` to be called before + // `session_disconnect_cb`. + // TODO(madthanu): The above statement is false (even if we wanted to, + // we cannot meaningfully enforce or document the guarantee of + // the returned Future's OnReady being called before another callback), + // although the exact way init_response_promise is set below makes it + // work most of the time. + if (init_response.IsReady() && init_response.Await().ok()) { + // If the init RPC has already completed successfully, we have + // already or will be returning OK from the `AttemptConnection` call. + // So, invoke `on_disconnect`. + on_disconnect(s); + } else { + // Otherwise, we are going to return an error from + // `AttemptConnection`. So do not invoke `on_disconnect`. + VLOG(0) << "GrpcClientSession attempt " << attempt_no + << " failed: " << s; + } + }; + + GrpcIfrtSessionMetadata metadata; + { + GrpcGetVersionRequest request; + request.mutable_min_version()->set_protocol_version(kClientMinVersion); + request.mutable_max_version()->set_protocol_version(kClientMaxVersion); + + ::grpc::ClientContext context; + GrpcGetVersionResponse response; + TF_RETURN_IF_ERROR( + xla::FromGrpcStatus(stub->GetVersion(&context, request, &response))); + + CHECK_GE(response.version().protocol_version(), kClientMinVersion); + CHECK_LE(response.version().protocol_version(), kClientMaxVersion); + *metadata.mutable_version() = response.version(); + } + + auto session = + GrpcClientSession::Create(stub, metadata, session_disconnect_cb); + rpc_helper = + std::make_unique(metadata.version(), std::move(session)); + + log_initial_connection(absl::StrCat("Sending InitRequest and waiting for ", + "response (attempt ", attempt_no, ").")); + + // TODO(b/282757875): Use a separate Request that will indicate quickly + // whether the grpc_client<->grpc_server session has been established or + // not, instead of combining it with the Request that will fetch device + // information (which can take a while, depending on the IFRT backend). + rpc_helper->Init(std::make_unique()) + .OnReady([&](auto resp) mutable { init_response_promise.Set(resp); }); + + TF_ASSIGN_OR_RETURN(auto init_response, + Future>>( + init_response_promise) + .Await()); + + auto host_buffer_store = std::make_unique( + stub, metadata.version(), init_response->session_id()); + rpc_helper->set_host_buffer_store(std::move(host_buffer_store)); + + return Client::Create(std::move(rpc_helper), std::move(*init_response)); +} + +absl::StatusOr> CreateGrpcClient( + absl::string_view server_address, const ClientConnectionOptions& options) { + auto log_initial_connection = + [f = std::move(options.on_connection_update)](absl::string_view msg) { + VLOG(0) << msg; + if (f) { + f(absl::StrCat(absl::Now(), ": ", msg)); + } + }; + + absl::Time start_time = absl::Now(); + absl::Status last_status; + for (int i = 0; absl::Now() - start_time < options.connection_timeout; ++i) { + log_initial_connection(absl::StrCat("Connecting to IFRT proxy server at ", + server_address, ", attempt #", i, + "...")); + absl::StatusOr> result = AttemptConnection( + server_address, options.on_disconnect, i, log_initial_connection); + if (result.ok()) { + log_initial_connection(absl::StrCat("Connected to IFRT proxy server on ", + "attempt #", i, ".")); + return result; + } else { + last_status = result.status(); + log_initial_connection( + absl::StrCat("Connection to IFRT proxy server attempt #", i, + "failed: ", last_status.ToString())); + } + absl::SleepFor(absl::Seconds(1)); + } + + // We want to prepend a human-friendly error message to status before + // returning. + auto err_msg = + absl::StrCat("Unable to establish connection to ifrt_proxy server, ", + "please check provided address '", server_address, + "'; detailed error: ", last_status.message()); + log_initial_connection(err_msg); + return tsl::errors::CreateWithUpdatedMessage(last_status, err_msg); +} + +} // namespace + +bool register_client_factory = + ([] { RegisterClientFactory("grpc", CreateGrpcClient); }(), true); + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_client_session.cc b/xla/python/ifrt_proxy/client/grpc_client_session.cc new file mode 100644 index 0000000000000..a70e633d6767e --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_client_session.cc @@ -0,0 +1,266 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "grpc/grpc.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/support/channel_arguments.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/threadpool.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +using OpId = int64_t; + +// Logically equivalent to a map, but thread-safe and +// with various convenience functions. +class GrpcClientSession::ResponseCallbackTable { + public: + absl::Status Add(OpId op_id, ResponseCallback callback) { + absl::MutexLock l(&mu_); + const bool inserted = table_.insert({op_id, std::move(callback)}).second; + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("Op id ", op_id, " already exists")); + } + return absl::OkStatus(); + } + + std::optional Pop(OpId op_id) { + absl::MutexLock l(&mu_); + auto it = table_.find(op_id); + if (it == table_.end()) { + return std::nullopt; + } + auto cb = std::move(it->second); + table_.erase(it); + return std::move(cb); + } + + absl::flat_hash_map PopAll() { + absl::flat_hash_map result; + absl::MutexLock l(&mu_); + result = std::move(table_); + table_ = absl::flat_hash_map(); + return result; + } + + private: + absl::Mutex mu_; + absl::flat_hash_map table_ ABSL_GUARDED_BY(mu_); +}; + +std::shared_ptr GrpcClientSession::Create( + std::shared_ptr stub, + GrpcIfrtSessionMetadata metadata, + StreamTerminatedCallback stream_terminated_cb) { + auto context = std::make_unique<::grpc::ClientContext>(); + context->AddMetadata("ifrt-proxy-grpc-ifrt-session-metadata-bin", + metadata.SerializeAsString()); + std::shared_ptr result(new GrpcClientSession( + std::move(stub), std::move(context), std::move(stream_terminated_cb))); + return result; +} + +GrpcClientSession::GrpcClientSession( + std::shared_ptr stub, + std::unique_ptr<::grpc::ClientContext> context, + StreamTerminatedCallback stream_terminated_cb) + : response_callbacks_(std::make_unique()), + reader_thread_(std::make_unique( + tsl::Env::Default(), "ifrt_proxy_client_grpc_reader", + /*num_threads=*/1)), + stub_(std::move(stub)), + context_(std::move(context)), + stream_(stub_->IfrtSession(context_.get())), + stream_terminated_cb_(std::move(stream_terminated_cb)), + user_futures_work_queue_(std::make_unique( + tsl::Env::Default(), "GrpcClientSessionUserFuturesWorkQueue")) { + reader_thread_->Schedule( + absl::bind_front(&GrpcClientSession::ReadLoop, this)); +} + +Future GrpcClientSession::Enqueue( + std::unique_ptr request) { + auto promise = Future::CreatePromise(); + absl::Status status = Enqueue( + std::move(request), [promise, queue = user_futures_work_queue_.get()]( + Response response) mutable { + queue->Schedule([promise = std::move(promise), + response = std::move(response)]() mutable -> void { + promise.Set(std::move(response)); + }); + }); + if (!status.ok()) { + user_futures_work_queue_->Schedule([promise, status]() mutable -> void { + promise.Set(std::move(status)); + }); + } + return Future(std::move(promise)); +} + +absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, + ResponseCallback callback) { + const OpId op_id = req->request_metadata().op_id(); + + absl::MutexLock l(&writer_mu_); + if (writes_stopped_) { + return absl::FailedPreconditionError( + "GrpcClientSession: writes no longer allowed."); + } + + TF_RETURN_IF_ERROR(response_callbacks_->Add(op_id, std::move(callback))); + + if (!stream_->Write(*req)) { + CHECK(response_callbacks_->Pop(op_id).has_value()); + return absl::UnknownError("GrpcClientSession: writing to stream failed."); + } + + return absl::OkStatus(); +} + +void GrpcClientSession::ReadLoop() { + while (true) { + auto read_buffer = std::make_unique(); + if (!stream_->Read(read_buffer.get())) { + LOG(INFO) << "GrpcClientSession: reader loop is exiting."; + break; + } + + const OpId op_id = read_buffer->response_metadata().op_id(); + std::optional callback = response_callbacks_->Pop(op_id); + + if (callback.has_value()) { + VLOG(1) << "GrpcClientSession: Issuing callback for " << op_id; + (*callback)(std::move(read_buffer)); + VLOG(1) << "GrpcClientSession: Done with callback for " << op_id; + } else { + LOG(ERROR) << "Received response with no remaining registered callback: " + << read_buffer->DebugString(); + } + } + + reader_thread_stopped_.Notify(); + Finish(absl::OkStatus()); +} + +void GrpcClientSession::Finish(const absl::Status& client_status) { + LOG(INFO) << "GrpcClientSession: Finish() called with client status " + << client_status; + + absl::call_once(finish_once_, [&] { + context_->TryCancel(); + + LOG(INFO) << "GrpcClientSession: Waiting for reader thread to stop."; + reader_thread_stopped_.WaitForNotification(); + + auto finish_stream_and_get_server_status = [&]() -> absl::Status { + LOG(INFO) << "GrpClientSession: Attempting to call stream->Finish()"; + absl::MutexLock l(&writer_mu_); + // Note: stream_->Finish() counts as a write, and needs to be serialized + // with stream->Write(). + LOG(INFO) << "GrpClientSession: Attempting to call stream->Finish(), " + "mutex acquired"; + absl::Status server_status = xla::FromGrpcStatus(stream_->Finish()); + LOG(INFO) << "GrpClientSession: stream->Finish() returned server status " + << server_status; + + CHECK(!writes_stopped_); + writes_stopped_ = true; + + return server_status; + }; + + absl::Status combined_status = finish_stream_and_get_server_status(); + combined_status.Update(client_status); + + auto all_callbacks = response_callbacks_->PopAll(); + for (auto& [_, cb] : all_callbacks) { + if (combined_status.ok()) { + cb(absl::AbortedError("Finish(OK) called.")); + } else { + cb(combined_status); + } + } + + LOG(INFO) << "GrpClientSession::Finish(): calling terminated cb with " + << combined_status; + stream_terminated_cb_(combined_status); + }); +} + +GrpcClientSession::~GrpcClientSession() { + GrpcClientSession::Finish(absl::CancelledError("~GrpcClientSession called.")); + reader_thread_.reset(); // Wait until the reader thread exits. + LOG(INFO) << "Deleting GrpcClientSession.user_futures_work_queue_ ..."; + user_futures_work_queue_.reset(); + LOG(INFO) << "Deleted GrpcClientSession.user_futures_work_queue_."; +} + +std::shared_ptr CreateGrpcStub( + absl::string_view server_address) { + ::grpc::ChannelArguments args; + // Remove message size limit to accommodate large messages exchanged during + // model compilation. + args.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); + args.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel( + std::string(server_address), GetClientCredentials(), args); + VLOG(0) << " Established channel."; + CHECK(channel != nullptr); + + std::shared_ptr stub = + grpc::GrpcIfrtService::NewStub(channel); + VLOG(0) << " Created stub."; + CHECK(stub != nullptr); + + return stub; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_client_session.h b/xla/python/ifrt_proxy/client/grpc_client_session.h new file mode 100644 index 0000000000000..9ca8219760a15 --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_client_session.h @@ -0,0 +1,144 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ + +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "grpcpp/client_context.h" +#include "grpcpp/support/client_callback.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/threadpool.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// `GrpcClientSession` implements the client side of an `IfrtSession` +// stream(ing RPC) and allows users to enqueue `IfrtRequest`s on the +// stream and register callbacks for when `IfrtResponse`s are received. +class GrpcClientSession : public ClientSession { + public: + // `StreamTerminatedCallback` represents a function that will be called when + // the underlying streaming RPC is terminated permanently. The callback may be + // invoked by the "primary" thread and with various mutex locks held, so the + // callback should both return soon and not block on any events (deadlocks may + // happen otherwise). + using StreamTerminatedCallback = std::function; + + // Returns an instantiation of GrpcClientSession on the given `stub`. + // `stream_terminated_cb` is guaranteed to be called exactly once (unless the + // process terminates beforehand). It is guaranteed that no registered + // `ResponseCallback` (see below) will be called after `stream_terminated_cb`. + static std::shared_ptr Create( + std::shared_ptr stub, + GrpcIfrtSessionMetadata metadata, + StreamTerminatedCallback stream_terminated_cb); + + Future Enqueue(std::unique_ptr request) override; + + // `ResponseCallback` represents a function that can be invoked when + // `ClientSession` receives an `IfrtResponse`. May be invoked by the "primary" + // thread and with various mutex locks held. + using ResponseCallback = std::function; + + absl::Status Enqueue(std::unique_ptr req, + ResponseCallback callback); + + // Terminates the `GrpcClientSession` if it has not already been terminated. + // Waits until `stream_terminated_cb` returns. + void Finish(const absl::Status& client_status) override; + + // Not copyable (or moveable) + GrpcClientSession(const GrpcClientSession&) = delete; + GrpcClientSession& operator=(const GrpcClientSession&) = delete; + + // Calls `Finish()`. Also waits for the destruction of + // `user_futures_work_queue_` (see below) and thus can block on user-level + // callbacks. + ~GrpcClientSession() override; + + private: + class ResponseCallbackTable; + + GrpcClientSession(std::shared_ptr stub, + std::unique_ptr<::grpc::ClientContext> context, + StreamTerminatedCallback stream_terminated_cb); + + // Repeatedly waits for a `IfrtResponse` message to arrive; for each message, + // looks up the corresponding callback registered in `response_callbacks_` and + // invokes it inline. + void ReadLoop(); + + // Thread-safe table that logically maps from RequestMetadata.OpId to + // ResponseCallback. + const std::unique_ptr response_callbacks_; + + // Thread that invokes `ReadLoop()`. + std::unique_ptr reader_thread_; + + // A notification (waited on by `Finish()`) for when `ReadLoop()` exits. + absl::Notification reader_thread_stopped_; + + // Set by `Finish()`, respected by `Enqueue()` calls. + bool writes_stopped_ ABSL_GUARDED_BY(writer_mu_) = false; + + // A mutex that ensures serialization between various `Enqueue()` calls, since + // only one thread is allowed to write to the gRPC stream at a time. + absl::Mutex writer_mu_; + + // Ensures logic inside `Finish()` is internally called only once. + absl::once_flag finish_once_; + + // References to gRPC objects used to read and write to the stream. + const std::shared_ptr stub_; + const std::unique_ptr<::grpc::ClientContext> context_; + const std::unique_ptr< + ::grpc::ClientReaderWriterInterface> + stream_; + + const StreamTerminatedCallback stream_terminated_cb_; + + // Threadpool used to perform `Future<>::Promise::Set()` for Futures returned + // to callers of `Enqueue(std::unique_ptr request)`. We do this + // because `Set()` may block on arbitrary `OnReady` callbacks set by those + // callers. + std::unique_ptr user_futures_work_queue_; +}; + +// Creates a gRPC stub that connects to `server_address`. It can be used for +// `GrpcClientSession`. The same stub can be reused across multiple sessions. +std::shared_ptr CreateGrpcStub( + absl::string_view server_address); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ diff --git a/xla/python/ifrt_proxy/client/grpc_client_session_test.cc b/xla/python/ifrt_proxy/client/grpc_client_session_test.cc new file mode 100644 index 0000000000000..18f1bb1328de2 --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_client_session_test.cc @@ -0,0 +1,481 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_client_session.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_sink_registry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "grpc/support/time.h" +#include "grpcpp/channel.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; + +constexpr int kOp1 = 1; +constexpr int kOp2 = 2; + +// Sufficient time for all processing (that are not explicitly waiting for +// further input) to have finished. +constexpr absl::Duration kSufficientTime = absl::Seconds(5); + +GrpcIfrtSessionMetadata Metadata() { + GrpcIfrtSessionMetadata metadata; + metadata.mutable_version()->set_protocol_version(kClientMaxVersion); + return metadata; +} + +absl::Status TestError() { return absl::UnknownError("test error"); } + +// A thread-safe queue of `absl::Status` values. +class Queue { + public: + void Push(absl::Status t) { + absl::MutexLock l(&mu_); + queue_.push_back(std::move(t)); + } + + std::optional PopOrTimeout( + absl::Duration timeout = kSufficientTime) { + absl::MutexLock l(&mu_); + auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { + return !queue_.empty(); + }; + mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); + if (queue_.empty()) { + return std::nullopt; + } + absl::Status result = std::move(queue_.front()); + queue_.pop_front(); + return result; + } + + absl::Status Pop(absl::Duration timeout = kSufficientTime) { + auto result = PopOrTimeout(timeout); + CHECK(result.has_value()) << "Timeout!"; + return *result; + } + + void PopAllDuringDestruction() { + absl::MutexLock l(&mu_); + allow_non_empty_destruction_ = true; + } + + ~Queue() { + absl::MutexLock l(&mu_); + if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; + } + + private: + absl::Mutex mu_; + std::deque queue_ ABSL_GUARDED_BY(mu_); + bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +}; + +// Checks that the input is a list of zero-or-more OK statuses followed by +// zero-or-more NOT-OK statuses. Succeeds for {OK, NOT_OK, NOT_OK}, but fails +// for {OK, NOT_OK, OK}. +void ExpectHeadAndTail( + std::vector, absl::Status>> var_list) { + std::vector status_list; + for (const auto& v : var_list) { + if (std::holds_alternative>(v)) { + status_list.push_back(std::get>(v).status()); + } else { + status_list.push_back(std::get(v)); + } + } + bool seen_not_ok = false; + std::string str; + for (const auto& s : status_list) { + absl::StrAppend(&str, "\n", s.ToString(), "\n-----\n"); + } + for (const auto& s : status_list) { + if (!s.ok()) seen_not_ok = true; + if (seen_not_ok) { + EXPECT_THAT(s, Not(IsOk())) << str; + } + } +} + +using ServerStream = ::grpc::ServerReaderWriter; +using SessionAction = bool; +constexpr SessionAction kContinueSession = true; +constexpr SessionAction kStopSession = false; +using OnSessionStart = std::function; +using OnReqReceived = + std::function; + +// A simple implementation of IfrtService with various test-hooks. +class SimpleIfrtService : public grpc::GrpcIfrtService::Service { + public: + SimpleIfrtService(OnReqReceived on_req_received, + OnSessionStart on_session_start) + : on_req_received_(std::move(on_req_received)), + on_session_start_(std::move(on_session_start)) {} + + ::grpc::Status IfrtSession(::grpc::ServerContext* context, + ServerStream* stream) override { + if (on_session_start_ && on_session_start_() == kStopSession) { + return ::grpc::Status::OK; + } + + { + absl::MutexLock l(&mu_); + CHECK(contexts_.insert(context).second); + } + + while (true) { + IfrtRequest request; + LOG(INFO) << "Server: waiting on Read()."; + if (!stream->Read(&request)) { + LOG(INFO) << "Server: Read() returned false."; + break; + } + LOG(INFO) << "Server: Read() returned true."; + if (!on_req_received_) { + IfrtResponse response; + response.mutable_response_metadata()->set_op_id( + request.request_metadata().op_id()); + stream->Write(response); + } else if (on_req_received_(request, stream) == kStopSession) { + break; + } + } + { + absl::MutexLock l(&mu_); + CHECK_EQ(contexts_.erase(context), 1); + } + + LOG(INFO) << "Finishing IFRT session"; + return ::grpc::Status::OK; + } + + void CancelAllServerSessions() { + absl::MutexLock l(&mu_); + for (const auto& context : contexts_) { + context->TryCancel(); + } + } + + private: + const OnReqReceived on_req_received_; + const OnSessionStart on_session_start_; + + // Keeps track of `::grpc::ServerContext` for all ongoing sessions. + absl::Mutex mu_; + absl::flat_hash_set<::grpc::ServerContext*> contexts_ ABSL_GUARDED_BY(mu_); +}; + +// Encapsulates objects related to a client and server instance of +// `grpc::GrpcIfrtService`. +class ClientAndServer { + public: + explicit ClientAndServer(OnReqReceived on_req_received = nullptr, + OnSessionStart on_session_start = nullptr) { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address, GetServerCredentials()); + ifrt_service_ = + std::make_unique(on_req_received, on_session_start); + builder.RegisterService(ifrt_service_.get()); + server_ = builder.BuildAndStart(); + + LOG(INFO) << "Server started and listening on " << address; + absl::FlushLogSinks(); + + std::shared_ptr<::grpc::Channel> channel = + ::grpc::CreateChannel(address, GetClientCredentials()); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); + LOG(INFO) << "conn_state = " << channel->GetState(/*try_to_connect=*/false); + + auto stub = grpc::GrpcIfrtService::NewStub(channel); + CHECK(stub != nullptr); + + client_session_ = GrpcClientSession::Create( + std::move(stub), Metadata(), [this](absl::Status s) { + client_finished_q_.Push(s); + client_finished_notification_.Notify(); + }); + + client_finished_q_.PopAllDuringDestruction(); + } + + void StopServer() { + ifrt_service_->CancelAllServerSessions(); + server_->Shutdown(); + server_->Wait(); + } + + ~ClientAndServer() { + StopServer(); + client_session_->Finish(absl::CancelledError("~ClientAndServer")); + client_finished_notification_.WaitForNotificationWithTimeout( + kSufficientTime); + CHECK(client_finished_notification_.HasBeenNotified()); + } + + GrpcClientSession* client_session() { return client_session_.get(); } + + Queue* client_finished_q() { return &client_finished_q_; } + + absl::StatusOr SendSimpleRequest(int op_id) { + owned_queues_.push_back(std::make_unique()); + Queue* q = owned_queues_.back().get(); + + auto req = std::make_unique(); + req->mutable_request_metadata()->set_op_id(op_id); + TF_RETURN_IF_ERROR(client_session_->Enqueue( + std::move(req), + [q](GrpcClientSession::Response resp) { q->Push(resp.status()); })); + + return q; + } + + private: + std::vector> owned_queues_; + Queue client_finished_q_; + absl::Notification client_finished_notification_; + std::shared_ptr client_session_; + + std::unique_ptr<::grpc::Server> server_; + std::unique_ptr ifrt_service_; +}; + +TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + + EXPECT_THAT(response_q->Pop(), IsOk()); + + EXPECT_EQ(cs.client_finished_q()->PopOrTimeout(), std::nullopt); + + cs.StopServer(); + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, HappyCaseTwoRequestsWithClientFinish) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest(kOp2)); + + EXPECT_THAT(response_q_1->Pop(), IsOk()); + EXPECT_THAT(response_q_2->Pop(), IsOk()); + + EXPECT_EQ(cs.client_finished_q()->PopOrTimeout(), std::nullopt); + + cs.client_session()->Finish(TestError()); + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ServerFinishesDuringFirstRead) { + ClientAndServer cs( + /*on_req_received=*/[](auto, auto) { return kStopSession; }); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); + + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + EXPECT_THAT(response_q_2.status(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ServerFinishesDuringConstruction) { + ClientAndServer cs(/*on_req_received=*/nullptr, + /*on_session_start=*/[]() { return kStopSession; }); + + absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + ExpectHeadAndTail({response_q_1, response_q_2}); + if (response_q_1.ok()) EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); + if (response_q_2.ok()) EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesAfterServerConsumesFirstRequest) { + std::atomic session_ptr; + ClientAndServer cs( + /*on_req_received=*/[session_ptr = &session_ptr](auto, auto) { + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); + + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + EXPECT_THAT(response_q_2.status(), Not(IsOk())); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesAfterServerWritesFirstResponse) { + std::atomic session_ptr; + ClientAndServer cs( + /*on_req_received=*/[session_ptr = &session_ptr](const IfrtRequest& r, + ServerStream* s) { + IfrtResponse response; + response.mutable_response_metadata()->set_op_id( + r.request_metadata().op_id()); + s->Write(response); + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + // The client may or may not terminate before the first response arrives. + response_q_1->Pop().IgnoreError(); + + // The client may or may not terminate before the second request could be + // enqueued. If it could be enqueued, the client will die without the server + // sending the corresponding response. + if (response_q_2.ok()) { + EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + } + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { + std::atomic session_ptr; + absl::Notification init_done; + ClientAndServer cs(/*on_req_received=*/nullptr, + /*on_session_start=*/[session_ptr = &session_ptr, + init_done = &init_done]() { + init_done->WaitForNotification(); + session_ptr->load()->Finish(TestError()); + return kContinueSession; + }); + session_ptr.store(cs.client_session()); + init_done.Notify(); + + absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + + if (response_q_1.ok()) { + EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); + } + if (response_q_2.ok()) { + EXPECT_THAT(response_q_2.value()->Pop(), Not(IsOk())); + } + + ExpectHeadAndTail({response_q_1, response_q_2}); + + EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); +} + +TEST(GrpcClientSessionTest, MethodsAfterFinishReturnError) { + ClientAndServer cs; + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + cs.client_session()->Finish(TestError()); + + EXPECT_THAT(cs.SendSimpleRequest(kOp2), Not(IsOk())); + + response_q_1->PopAllDuringDestruction(); +} + +TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { + ClientAndServer cs( + /*on_req_received=*/[](const IfrtRequest& r, ServerStream* s) mutable { + IfrtResponse resp; + resp.mutable_response_metadata()->set_op_id(kOp2); + s->Write(resp); + resp.mutable_response_metadata()->set_op_id( + r.request_metadata().op_id()); + s->Write(resp); + return kContinueSession; + }); + + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + + EXPECT_THAT(response_q->Pop(), IsOk()); +} + +TEST(GrpcClientSessionTest, BadInitialChannelFailsPromptly) { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + + std::shared_ptr<::grpc::Channel> channel = + ::grpc::CreateChannel(address, GetClientCredentials()); + + std::unique_ptr stub = + grpc::GrpcIfrtService::NewStub(channel); + EXPECT_TRUE(stub != nullptr); + + auto session_finished = std::make_shared(); + auto session = GrpcClientSession::Create( + std::move(stub), Metadata(), + [session_finished](absl::Status s) { session_finished->Push(s); }); + + EXPECT_THAT(session_finished->Pop(), Not(IsOk())); +} + +} // namespace + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc new file mode 100644 index 0000000000000..c5a69737f057d --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -0,0 +1,182 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "grpcpp/client_context.h" +#include "grpcpp/support/client_callback.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/unbounded_work_queue.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +static constexpr int64_t kChunkSize = 1024 * 1024; + +GrpcClientHostBufferStore::GrpcClientHostBufferStore( + std::shared_ptr stub, + IfrtProxyVersion version, uint64_t session_id) + : stub_(std::move(stub)), + version_(std::move(version)), + session_id_(session_id), + lookup_work_queue_(std::make_unique( + tsl::Env::Default(), "HostBufferStoreLookupsWorkQueue")) {} + +GrpcClientHostBufferStore::~GrpcClientHostBufferStore() { + LOG(INFO) << "Waiting for destruction of HostBufferStoreLookupsWorkQueue..."; + lookup_work_queue_.reset(); + LOG(INFO) << "Destructed HostBufferStoreLookupsWorkQueue."; +} + +uint64_t GrpcClientHostBufferStore::NextHandle() { + return next_handle_.fetch_add(1, std::memory_order_relaxed); +} + +Future GrpcClientHostBufferStore::Store(uint64_t handle, + absl::string_view data) { + // The current implementation synchronously sends host buffer chunks. We may + // consider making it asynchronous if the caller can leverage such asynchrony. + + GrpcHostBufferStoreMetadata metadata; + metadata.set_session_id(session_id_); + metadata.set_handle(handle); + metadata.set_buffer_size(data.size()); + + ::grpc::ClientContext context; + context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", + metadata.SerializeAsString()); + + GrpcHostBufferStoreResponse response; + auto writer = stub_->HostBufferStore(&context, &response); + + for (int64_t offset = 0; offset < data.size(); offset += kChunkSize) { + GrpcHostBufferStoreRequest request; +#if defined(PLATFORM_GOOGLE) + request.set_alias_data(data.substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + request.set_data(std::string(data.substr(offset, kChunkSize))); +#endif + writer->Write(request); + } + + if (!writer->WritesDone()) { + return Future( + absl::InternalError("Failed to write all host buffer chunks")); + } + + return Future(xla::FromGrpcStatus(writer->Finish())); +} + +Future GrpcClientHostBufferStore::Store(uint64_t handle, + const absl::Cord& data) { + // The current implementation synchronously sends host buffer chunks. We may + // consider making it asynchronous if the caller can leverage such asynchrony. + + GrpcHostBufferStoreMetadata metadata; + metadata.set_session_id(session_id_); + metadata.set_handle(handle); + metadata.set_buffer_size(data.size()); + + ::grpc::ClientContext context; + context.AddMetadata("ifrt-proxy-grpc-host-buffer-store-metadata-bin", + metadata.SerializeAsString()); + + GrpcHostBufferStoreResponse response; + auto writer = stub_->HostBufferStore(&context, &response); + + for (absl::string_view chunk : data.Chunks()) { + for (int64_t offset = 0; offset < chunk.size(); offset += kChunkSize) { + GrpcHostBufferStoreRequest request; +#if defined(PLATFORM_GOOGLE) + request.set_alias_data(chunk.substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + request.set_data(std::string(chunk.substr(offset, kChunkSize))); +#endif + writer->Write(request); + } + } + if (!writer->WritesDone()) { + return Future( + absl::InternalError("Failed to write all host buffer chunks")); + } + + return Future(xla::FromGrpcStatus(writer->Finish())); +} + +Future> GrpcClientHostBufferStore::Lookup( + uint64_t handle) { + auto promise = Future>::CreatePromise(); + + lookup_work_queue_->Schedule([this, handle, promise]() mutable -> void { + GrpcHostBufferLookupRequest request; + request.set_handle(handle); + request.set_session_id(session_id_); + + ::grpc::ClientContext context; + + std::unique_ptr<::grpc::ClientReaderInterface> + stream = stub_->HostBufferLookup(&context, request); + + absl::Cord data; + GrpcHostBufferLookupResponse response; + while (stream->Read(&response)) { + data.Append(response.data()); + } + + absl::Status status = xla::FromGrpcStatus(stream->Finish()); + if (status.ok()) { + promise.Set(std::move(data)); + } else { + promise.Set(status); + } + }); + + return Future>(promise); +} + +Future GrpcClientHostBufferStore::Delete(uint64_t handle) { + GrpcHostBufferDeleteRequest request; + request.set_session_id(session_id_); + request.set_handle(handle); + + ::grpc::ClientContext context; + GrpcHostBufferDeleteResponse response; + return Future(xla::FromGrpcStatus( + stub_->HostBufferDelete(&context, request, &response))); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.h b/xla/python/ifrt_proxy/client/grpc_host_buffer.h new file mode 100644 index 0000000000000..bbf9b9eecfeef --- /dev/null +++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.h @@ -0,0 +1,71 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "tsl/platform/unbounded_work_queue.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class GrpcClientHostBufferStore : public ClientHostBufferStore { + public: + GrpcClientHostBufferStore( + std::shared_ptr stub, + IfrtProxyVersion version, uint64_t session_id); + + ~GrpcClientHostBufferStore() override; + + // Implements ClientHostBufferStore. + + uint64_t NextHandle() override; + Future Store(uint64_t handle, absl::string_view data) override; + Future Store(uint64_t handle, const absl::Cord& data) override; + Future> Lookup(uint64_t handle) override; + Future Delete(uint64_t handle) override; + + private: + const std::shared_ptr stub_; + const IfrtProxyVersion version_; + const uint64_t session_id_; + std::atomic next_handle_ = 0; + + // Implementation note: `lookup_work_queue_` may have closures that invoke + // user-defined code. Each `Lookup()` call is associated with a scheduled + // closure, and the closure is used to first perform synchronous reads of the + // streaming RPC, and then to do `promise.Set()` for the Future returned to + // the caller. + std::unique_ptr lookup_work_queue_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_HOST_BUFFER_H_ diff --git a/xla/python/ifrt_proxy/client/host_buffer.h b/xla/python/ifrt_proxy/client/host_buffer.h new file mode 100644 index 0000000000000..ceaf51debc7d8 --- /dev/null +++ b/xla/python/ifrt_proxy/client/host_buffer.h @@ -0,0 +1,62 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class ClientHostBufferStore { + public: + virtual ~ClientHostBufferStore() = default; + + virtual uint64_t NextHandle() = 0; + + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + virtual Future Store(uint64_t handle, + absl::string_view data) = 0; + + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + // TODO(b/315023499) Find a way to increase the chunk size + virtual Future Store(uint64_t handle, + const absl::Cord& data) = 0; + + // Retrieves the data associated with the handle. Returns an error if the + // handle does not exist. + virtual Future> Lookup(uint64_t handle) = 0; + + // Deletes the host buffer associated with the handle. Returns an error if the + // handle does not exist. + virtual Future Delete(uint64_t handle) = 0; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_HOST_BUFFER_H_ diff --git a/xla/python/ifrt_proxy/client/memory.h b/xla/python/ifrt_proxy/client/memory.h new file mode 100644 index 0000000000000..e33c3a1a30ac8 --- /dev/null +++ b/xla/python/ifrt_proxy/client/memory.h @@ -0,0 +1,80 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class Client; + +class Memory : public xla::ifrt::Memory { + public: + Memory(int id, std::string memory_space_kind, std::string debug_string, + std::string to_string) + : id_(id), + memory_space_kind_(std::move(memory_space_kind)), + debug_string_(std::move(debug_string)), + to_string_(std::move(to_string)) {} + + // Not copyable or movable: IFRT expects `string_view` from + // `memory_space_kind()` to be stable throughout the client's lifetime. + Memory(const Memory& other) = delete; + Memory& operator=(const Memory& other) = delete; + + PjRtClient* client() const override { return nullptr; } + + absl::Span devices() const override { + return devices_; + } + + int id() const override { return id_; } + + absl::string_view memory_space_kind() const override { + return memory_space_kind_; + } + + absl::string_view DebugString() const override { return debug_string_; } + + absl::string_view ToString() const override { return to_string_; } + + private: + friend class Client; // For `devices_` initialization. + + int id_; + std::vector devices_; + std::string memory_space_kind_; + std::string debug_string_; + std::string to_string_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MEMORY_H_ diff --git a/xla/python/ifrt_proxy/client/mock_client_session.h b/xla/python/ifrt_proxy/client/mock_client_session.h new file mode 100644 index 0000000000000..6b2a5bda24989 --- /dev/null +++ b/xla/python/ifrt_proxy/client/mock_client_session.h @@ -0,0 +1,50 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ + +#include + +#include +#include "absl/status/status.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockClientSession final : public ClientSession { + public: + MOCK_METHOD(Future, Enqueue, (std::unique_ptr req), + (override)); + MOCK_METHOD(void, Finish, (const absl::Status& s), (override)); +}; + +ACTION_P(MockClientSessionReturnResponse, response_proto) { + auto response = std::make_unique(response_proto); + response->mutable_response_metadata()->set_op_id( + arg0->request_metadata().op_id()); + return Future(std::move(response)); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_CLIENT_SESSION_H_ diff --git a/xla/python/ifrt_proxy/client/mock_host_buffer.h b/xla/python/ifrt_proxy/client/mock_host_buffer.h new file mode 100644 index 0000000000000..81d70cc4e9301 --- /dev/null +++ b/xla/python/ifrt_proxy/client/mock_host_buffer.h @@ -0,0 +1,50 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ + +#include + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockClientHostBufferStore final : public ClientHostBufferStore { + public: + MOCK_METHOD(uint64_t, NextHandle, (), (override)); + MOCK_METHOD(Future, Store, + (uint64_t handle, absl::string_view data), (override)); + MOCK_METHOD(Future, Store, + (uint64_t handle, const absl::Cord& data), (override)); + MOCK_METHOD(Future>, Lookup, (uint64_t handle), + (override)); + MOCK_METHOD(Future, Delete, (uint64_t handle), (override)); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_MOCK_HOST_BUFFER_H_ diff --git a/xla/python/ifrt_proxy/client/py_module.cc b/xla/python/ifrt_proxy/client/py_module.cc new file mode 100644 index 0000000000000..0a6d4346afa1e --- /dev/null +++ b/xla/python/ifrt_proxy/client/py_module.cc @@ -0,0 +1,121 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "xla/python/ifrt_proxy/client/py_module.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/log/log_sink.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/function.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return xla::PyClient::Make(std::move(client)); +} + +} // namespace + +void BuildIfrtProxySubmodule(nb::module_& m) { + nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); + + nb::class_(sub_module, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()); + + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/py_module.h b/xla/python/ifrt_proxy/client/py_module.h new file mode 100644 index 0000000000000..239f693499f33 --- /dev/null +++ b/xla/python/ifrt_proxy/client/py_module.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "nanobind/nanobind.h" // from @nanobind + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(nanobind::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/xla/python/ifrt_proxy/client/registry.cc b/xla/python/ifrt_proxy/client/registry.cc new file mode 100644 index 0000000000000..11680771b8b49 --- /dev/null +++ b/xla/python/ifrt_proxy/client/registry.cc @@ -0,0 +1,102 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/registry.h" + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/client.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using FactoryFn = + std::function>( + absl::string_view, const ClientConnectionOptions&)>; + +struct Registry { + absl::Mutex mu; + absl::flat_hash_map factories ABSL_GUARDED_BY(mu); +}; + +Registry* registry() { + static auto* r = new Registry(); + return r; +} + +} // namespace + +void RegisterClientFactory(absl::string_view transport_name, + FactoryFn factory) { + absl::MutexLock l(®istry()->mu); + const bool inserted = + registry() + ->factories.insert({std::string(transport_name), factory}) + .second; + CHECK(inserted) << "IFRT proxy transport '" << transport_name + << "' already registered"; +} + +absl::StatusOr> CreateClient( + absl::string_view proxy_server_address, + const ClientConnectionOptions& options) { + const size_t pos = proxy_server_address.find("://"); + if (pos == std::string::npos) { + return absl::InvalidArgumentError( + absl::StrCat("IFRT proxy server address must be " + "'://' (e.g., " + "'grpc://localhost'), but got ", + proxy_server_address)); + } + + const absl::string_view transport_name = proxy_server_address.substr(0, pos); + const absl::string_view address = proxy_server_address.substr(pos + 3); + + FactoryFn factory; + { + absl::MutexLock l(®istry()->mu); + const auto it = registry()->factories.find(transport_name); + if (it == registry()->factories.end()) { + return absl::NotFoundError( + absl::StrCat("IFRT proxy transport '", transport_name, + "' not found; available transports are: ", + absl::StrJoin(registry()->factories, ", ", + [](std::string* out, const auto& it) { + out->append(it.first); + }))); + } + factory = it->second; + } + + return factory(address, options); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/registry.h b/xla/python/ifrt_proxy/client/registry.h new file mode 100644 index 0000000000000..ebf04532b278e --- /dev/null +++ b/xla/python/ifrt_proxy/client/registry.h @@ -0,0 +1,68 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/python/ifrt/client.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +struct ClientConnectionOptions { + // Timeout for establishing the connection. + absl::Duration connection_timeout = absl::Minutes(2); + + // A callback that (if it is not set to nullptr) will be called if there was a + // successful connection to the proxy server, but there was a later + // disconnect. The callback may be called synchronously from a thread that + // performs various important activities, and therefore should not block on + // any events (or deadlocks may happen). + std::function on_disconnect = nullptr; + + // Captures logs related to establishing the connection. Logs may be generated + // synchronously from a thread that performs various important activities, + // so the function should not block (or deadlocks may happen). + std::function on_connection_update = nullptr; +}; + +// Registers a new factory for client backend implementation. Crashes if the +// same backend name is registered more than once. +void RegisterClientFactory( + absl::string_view transport_name, + std::function>( + absl::string_view address, const ClientConnectionOptions& options)> + factory); + +// Creates a client for the given backend target. The backend target string must +// be in the form of `:`. +absl::StatusOr> CreateClient( + absl::string_view proxy_server_address, + const ClientConnectionOptions& options = ClientConnectionOptions()); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_REGISTRY_H_ diff --git a/xla/python/ifrt_proxy/client/rpc_helper.cc b/xla/python/ifrt_proxy/client/rpc_helper.cc new file mode 100644 index 0000000000000..e07f689513e6a --- /dev/null +++ b/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -0,0 +1,175 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#if defined(PLATFORM_GOOGLE) +#include "absl/types/source_location.h" +#endif +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/status_to_from_proto.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// DoRpc is a templated function that implements the logic of all RPC-wrapping +// functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. +template +Future>> DoRpc( + ClientSession* session, RequestMetadata metadata, + void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), + bool (IfrtResponse::*has_resp)() const, std::unique_ptr req) { + auto ifrt_req = std::make_unique(); + *ifrt_req->mutable_request_metadata() = metadata; + (ifrt_req.get()->*set_req)(req.release()); + + auto promise = Future>>::CreatePromise(); + auto on_ready = [promise, has_resp, + get_resp](ClientSession::Response r) mutable { + if (!r.ok()) { + LOG(ERROR) << "Connection to IFRT proxy server was terminated: " + << r.status(); + promise.Set(absl::UnavailableError( + absl::StrCat("Connection to IFRT proxy server was terminated: ", + r.status().ToString()))); + return; + } + + std::shared_ptr response = *std::move(r); + if (!response->has_response_metadata()) { + promise.Set(absl::InternalError( + absl::StrCat("IFRT server sent a message without metadata: ", + response->DebugString()))); + return; + } + + const absl::Status metadata_status = + tsl::StatusFromProto(response->response_metadata().status()); + const bool has_expected_response = (response.get()->*has_resp)(); + const auto has_some_response = + response->response_case() != IfrtResponse::RESPONSE_NOT_SET; + + if (metadata_status.ok() && !has_some_response) { + promise.Set(absl::InternalError( + absl::StrCat("OK response with no actual response set: ", + response->DebugString()))); + return; + } + + if (!has_expected_response && has_some_response) { + promise.Set(absl::InternalError(absl::StrCat( + "Response with wrong type (expected ", Resp::GetDescriptor()->name(), + "): ", response->DebugString()))); + return; + } + + // If the metadata_status is not-OK, according to ifrt_service.proto, + // there may be an error _instead_ of an actual response value. So, check if + // an actual response value exists, and if so return it irrespective of what + // the metadata_status says. + if (!has_some_response) { + promise.Set(metadata_status); + } else { + promise.Set( + std::make_shared(*std::move((response.get()->*get_resp)()))); + } + }; + session->Enqueue(std::move(ifrt_req)).OnReady(on_ready); + + return Future>>(promise); +} + +RequestMetadata RpcHelper::ManufactureRequestMetadata() { + RequestMetadata result; + { + absl::MutexLock l(&mu_); + result.set_op_id(next_op_id_++); + } + int prev_op_id = result.op_id() - 1; + if (prev_op_id != 0) { + // TODO(b/266635130): Depend only on necessary prior operations. + result.add_dependencies(prev_op_id); + } + // TODO(b/282757875): Add a ClearOps RPC for old dependencies. + return result; +} + +void RpcHelper::Disconnect() { + session_->Finish(absl::CancelledError("Disconnected by client")); +} + +// TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros +// go against the style guide, but are convenient as we are introducing more +// RPCs and are making changes to the exact signature of the DoRpc function. +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc(session_.get(), ManufactureRequestMetadata(), \ + &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req)); \ + } + +RPC(Init, init); +RPC(GetDefaultDeviceAssignment, get_default_device_assignment); +RPC(CheckFuture, check_future); +RPC(MakeArrayFromHostBuffer, make_array_from_host_buffer); +RPC(AssembleArrayFromSingleDeviceArrays, + assemble_array_from_single_device_arrays); +RPC(DisassembleIntoSingleDeviceArrays, disassemble_into_single_device_arrays); +RPC(CopyToHostBuffer, copy_to_host_buffer); +RPC(CheckArrayReady, check_array_ready); +RPC(IsArrayDeleted, is_array_deleted); +RPC(DestructArray, destruct_array) +RPC(Reshard, reshard); +RPC(FullyReplicatedShard, fully_replicated_shard); +RPC(DeleteArray, delete_array); +RPC(Compile, compile); +RPC(LoadedExecutableMetadata, loaded_executable_metadata); +RPC(LoadedExecutableExecute, loaded_executable_execute); +RPC(LoadedExecutableDelete, loaded_executable_delete); +RPC(LoadedExecutableIsDeleted, loaded_executable_is_deleted); +RPC(LoadedExecutableDestruct, loaded_executable_destruct); +RPC(LoadedHostCallbackPoll, loaded_host_callback_poll); +RPC(LoadedHostCallbackReturn, loaded_host_callback_return); + +Future RpcHelper::CheckFuture(uint64_t handle) { + auto req = std::make_unique(); + req->set_future_handle(handle); + + auto promise = Future::CreatePromise(); + CheckFuture(std::move(req)) + .OnReady( + [promise](absl::StatusOr> + response) mutable { promise.Set(response.status()); }); + + return Future(promise); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/client/rpc_helper.h b/xla/python/ifrt_proxy/client/rpc_helper.h new file mode 100644 index 0000000000000..b5c3bf6340241 --- /dev/null +++ b/xla/python/ifrt_proxy/client/rpc_helper.h @@ -0,0 +1,150 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/host_buffer.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// RpcHelper helps establish a connection with the IFRT server and perform +// logical RPCs on the connection. +// +// TODO(b/266635130): RpcHelper currently makes each logical RPC order-dependent +// on the previous RPC it was asked to make. Instead, allow users of RpcHelper +// specify the necessary dependency. +class RpcHelper { + public: + RpcHelper(IfrtProxyVersion version, std::shared_ptr session) + : version_(std::move(version)), session_(std::move(session)) {} + + void Disconnect(); + + RpcHelper(const RpcHelper&) = delete; + RpcHelper& operator=(const RpcHelper&) = delete; + ~RpcHelper() { Disconnect(); } + + // IFRT Proxy version negotiated between the client and the server. + const IfrtProxyVersion& version() const { return version_; } + + // Initializes the host buffer store for this RpcHelper instance. This must be + // called exactly once during initialization before `host_buffer_store()` is + // called. + void set_host_buffer_store( + std::shared_ptr host_buffer_store) { + CHECK(host_buffer_store_ == nullptr); + host_buffer_store_ = std::move(host_buffer_store); + } + + const std::shared_ptr& host_buffer_store() const { + return host_buffer_store_; + } + + template + using ResponseFuture = Future>>; + + // Wrapper function for various logical RPCs defined in ifrt_service.proto. + // Whenever the RPC finishes, `on_done` will be called with the result or the + // return status. `on_done` can be called with various locks held and should + // return quickly without blocking on any event. `on_done` is guaranteed to be + // called exactly once. + // + // The functions can be invoked after the connection is broken, but will + // result in `on_done` getting called with an error (see + // "WrapAsConnectionError" in `rpc_helper.cc`). + + ResponseFuture Init(std::unique_ptr req); + ResponseFuture GetDefaultDeviceAssignment( + std::unique_ptr req); + + ResponseFuture CheckFuture( + std::unique_ptr req); + + ResponseFuture MakeArrayFromHostBuffer( + std::unique_ptr req); + ResponseFuture + AssembleArrayFromSingleDeviceArrays( + std::unique_ptr req); + ResponseFuture + DisassembleIntoSingleDeviceArrays( + std::unique_ptr req); + ResponseFuture CopyToHostBuffer( + std::unique_ptr req); + ResponseFuture CheckArrayReady( + std::unique_ptr req); + ResponseFuture Reshard(std::unique_ptr req); + ResponseFuture FullyReplicatedShard( + std::unique_ptr req); + ResponseFuture IsArrayDeleted( + std::unique_ptr req); + ResponseFuture DeleteArray( + std::unique_ptr req); + ResponseFuture DestructArray( + std::unique_ptr req); + + ResponseFuture Compile(std::unique_ptr req); + + ResponseFuture LoadedExecutableMetadata( + std::unique_ptr req); + ResponseFuture LoadedExecutableExecute( + std::unique_ptr req); + ResponseFuture LoadedExecutableDelete( + std::unique_ptr req); + ResponseFuture LoadedExecutableIsDeleted( + std::unique_ptr req); + ResponseFuture LoadedExecutableDestruct( + std::unique_ptr req); + + ResponseFuture LoadedHostCallbackPoll( + std::unique_ptr req); + ResponseFuture LoadedHostCallbackReturn( + std::unique_ptr req); + + // Utility functions for common functions. + + Future CheckFuture(uint64_t handle); + + private: + RequestMetadata ManufactureRequestMetadata() ABSL_LOCKS_EXCLUDED(mu_); + + const IfrtProxyVersion version_; + const std::shared_ptr session_; + std::shared_ptr host_buffer_store_; + + absl::Mutex mu_; + uint64_t next_op_id_ ABSL_GUARDED_BY(mu_) = 1; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_RPC_HELPER_H_ diff --git a/xla/python/ifrt_proxy/client/version.h b/xla/python/ifrt_proxy/client/version.h new file mode 100644 index 0000000000000..06df1e0c70b00 --- /dev/null +++ b/xla/python/ifrt_proxy/client/version.h @@ -0,0 +1,32 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ + +namespace xla { +namespace ifrt { +namespace proxy { + +// TODO(b/296144873): Document the version upgrade policy. +inline constexpr int kClientMinVersion = 1; +inline constexpr int kClientMaxVersion = 1; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_VERSION_H_ diff --git a/xla/python/ifrt_proxy/common/BUILD b/xla/python/ifrt_proxy/common/BUILD new file mode 100644 index 0000000000000..a0474c001c79d --- /dev/null +++ b/xla/python/ifrt_proxy/common/BUILD @@ -0,0 +1,184 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@tsl//tsl:tsl.bzl", "if_google") +load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") +# copybara:uncomment load("@bazel_skylib//:bzl_library.bzl", "bzl_library") + +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +# Export headers referenced by the google-internal-version of grpc_credentials. +exports_files( + ["grpc_credentials.h"], + visibility = if_google( + ["//xla/python/ifrt_proxy/common/google:__pkg__"], + ["//visibility:private"], + ), +) + +cc_library( + name = "grpc_credentials", + hdrs = ["grpc_credentials.h"], + deps = if_google( + ["//xla/python/ifrt_proxy/common/google:grpc_credentials_lib"], + [":grpc_credentials_oss_lib"], + ) + ["@com_github_grpc_grpc//:grpc++"], +) + +cc_library( + name = "grpc_credentials_oss_lib", + srcs = [ + "grpc_credentials.cc", + "grpc_credentials.h", + ], + visibility = ["//visibility:private"], + deps = [ + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform", + ], + alwayslink = True, +) + +tf_proto_library( + name = "types_proto", + srcs = ["types.proto"], +) + +tf_proto_library( + name = "ifrt_service_proto", + srcs = ["ifrt_service.proto"], + protodeps = [ + ":types_proto", + # copybara:uncomment "//google/protobuf:any", + "//xla:xla_data_proto", + "//xla/pjrt:execute_options_proto", + "//xla/python/ifrt:dtype_proto", + "//xla/python/ifrt:serdes_proto", + "//xla/python/ifrt:shape_proto", + "//xla/python/ifrt:sharding_proto", + "@tsl//tsl/protobuf:status_proto", + ], +) + +tf_proto_library( + name = "grpc_ifrt_service_proto", + srcs = ["grpc_ifrt_service.proto"], + has_services = True, + create_go_proto = False, + create_grpc_library = True, + create_java_proto = False, + create_kotlin_proto = False, + protodeps = [":ifrt_service_proto"], +) + +cc_library( + name = "types", + srcs = ["types.cc"], + hdrs = ["types.h"], + deps = [ + ":ifrt_service_proto_cc", + ":types_proto_cc", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt:sharding_serdes", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:statusor", + ] + if_google(["@com_google_absl//absl/types:source_location"]), +) + +ifrt_proxy_cc_test( + name = "types_test", + srcs = ["types_test.cc"], + deps = [ + ":types", + ":types_proto_cc", + "//xla/pjrt:pjrt_common", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "array_util", + srcs = ["array_util.cc"], + hdrs = ["array_util.h"], + deps = [ + "//xla/python/ifrt", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "array_util_test", + srcs = ["array_util_test.cc"], + deps = [ + ":array_util", + "//xla/python/ifrt", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + ], +) + +# common_serdes is a collection of all common libraries that register SerDes implementations. +cc_library( + name = "common_serdes", + deps = ["//xla/python/pjrt_ifrt:xla_program_serdes"], + alwayslink = True, +) + +cc_library( + name = "proto_util", + srcs = ["proto_util.cc"], + hdrs = ["proto_util.h"], + deps = [ + ":ifrt_service_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:status_to_from_proto", + ], +) + +# copybara:uncomment_begin +# bzl_library( +# name = "ifrt_proxy_bzl", +# srcs = ["ifrt_proxy.bzl"], +# parse_tests = False, +# visibility = ["//visibility:private"], +# ) +# copybara:uncomment_end diff --git a/xla/python/ifrt_proxy/common/array_util.cc b/xla/python/ifrt_proxy/common/array_util.cc new file mode 100644 index 0000000000000..bdcf8a13dfcc8 --- /dev/null +++ b/xla/python/ifrt_proxy/common/array_util.cc @@ -0,0 +1,156 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/array_util.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +std::string StridesAsStr(const ArrayMemRegion::ByteStrides& strides) { + if (!strides.has_value()) return "strides{nullopt}"; + return absl::StrCat("strides{", absl::StrJoin(*strides, ","), "}"); +} + +} // namespace + +absl::StatusOr> DefaultByteStrides(const DType dtype, + const Shape& shape) { + if (!dtype.byte_size().has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type to query byte-strides for: ", + dtype.DebugString())); + } + std::vector result(shape.dims().size()); + int64_t stride = *dtype.byte_size(); + for (int i = static_cast(shape.dims().size()) - 1; i >= 0; --i) { + result[i] = stride; + stride *= shape.dims()[i]; + } + return result; +} + +absl::StatusOr ArrayMemRegion::FromZerothElementPointer( + const void* zeroth_element, const DType dtype, const Shape& shape, + ByteStrides byte_strides) { + if (!dtype.byte_size().has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type to construct ArrayMemRegion: ", + dtype.DebugString())); + } + // Below, we return an error for all situations where the zeroth_element + // is different from mem_region_start. + void* const mem_region_start = const_cast(zeroth_element); + + if (!byte_strides.has_value() || + (byte_strides->empty() && shape.dims().empty())) { + return ArrayMemRegion(mem_region_start, + dtype.byte_size().value() * shape.num_elements()); + } + if (shape.num_elements() == 0) { + return ArrayMemRegion(mem_region_start, 0); + } + if (shape.dims().size() != byte_strides->size()) { + return absl::InvalidArgumentError( + absl::StrCat("Shape has different dimensions from byte_strides: ", + shape.DebugString(), " vs ", StridesAsStr(byte_strides))); + } + // Logic based on + // https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html + // + // So long as all strides are positive, the array's memory region begins at + // the zeroth element, and the last element of the array is farthest off from + // the beginning. We use the offset of the last element of the array to + // calculate the memory region. Note that this reasoning does not apply to + // negative strides, since the zeroth element can then be in the middle of the + // memory region (as an example, consider shape=[10, 10] and + // element_strides=[10,-1]). + uint64_t last_element_byte_offset = 0; + for (int i = 0; i < byte_strides->size(); ++i) { + int stride = (*byte_strides)[i]; + if (shape.dims()[i] < 0) { + return absl::InvalidArgumentError( + absl::StrCat("A shape dimension is negative: ", shape.DebugString())); + } else if (shape.dims()[i] == 1) { + // The stride shouldn't matter in this case, so continue without checking + // validity of the given stride. + continue; + } else if (stride <= 0) { + return absl::UnimplementedError( + absl::StrCat("Negative or zero strides are not fully supported: ", + StridesAsStr(byte_strides))); + } else if (stride % dtype.byte_size().value() != 0) { + return absl::UnimplementedError(absl::StrCat( + "byte_stride[", i, "] is not a multiple of the data-type's size: ", + StridesAsStr(byte_strides), ", dtype=", dtype.DebugString())); + } else { + // `shape.dims()[i]` cannot be negative (we explicitly check for this + // above) or zero (we return early for `shape.num_elements() == 0`). + DCHECK_GT(shape.dims()[i], 0); + last_element_byte_offset += (stride * (shape.dims()[i] - 1)); + } + } + return ArrayMemRegion(mem_region_start, + last_element_byte_offset + dtype.byte_size().value()); +} + +absl::StatusOr ArrayMemRegion::FromMinimalMemRegion( + absl::string_view mem_region, const DType dtype, const Shape& shape, + ByteStrides byte_strides) { + // FromZerothElementPointer() currently returns an error for any situation + // where the zeroth_element will is not equal to the place where the minimal + // memory region starts. + TF_ASSIGN_OR_RETURN( + auto result, + FromZerothElementPointer(mem_region.data(), dtype, shape, byte_strides)); + + if (result.mem_region().size() != mem_region.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Incorrect size ", result.mem_region().size(), " vs ", + mem_region.size(), "; is provided memory region minimal? ", + dtype.DebugString(), " ", shape.DebugString(), " ", + StridesAsStr(byte_strides))); + } + CHECK_EQ(result.mem_region().data(), mem_region.data()); + return result; +} + +absl::string_view ArrayMemRegion::mem_region() const { + return absl::string_view(static_cast(mem_region_start_), nbytes_); +} + +void* ArrayMemRegion::zeroth_element() const { + // ArrayMemRegion cannot yet be constructed for situations where the + // zeroth element pointer is different from mem_region_start_. + return mem_region_start_; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/common/array_util.h b/xla/python/ifrt_proxy/common/array_util.h new file mode 100644 index 0000000000000..2ba8ff7ce4256 --- /dev/null +++ b/xla/python/ifrt_proxy/common/array_util.h @@ -0,0 +1,78 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Returns the byte-strides corresponding to the compact major-to-minor layout. +absl::StatusOr> DefaultByteStrides(DType dtype, + const Shape& shape); + +// Denotes a chunk of contiguous memory that contains all elements of the +// in-host (RAM) representation of an Array. +class ArrayMemRegion { + public: + // Nullopt implies compact major-to-minor layout, as returned by + // `DefaultByteStrides()`. + using ByteStrides = std::optional>; + + // Constructs an ArrayMemRegion given `mem_region`, where `mem_region` is + // minimal, i.e., the lower-most and upper-most addresses of `mem_region` are + // necessary to retrieve elements from the array. + static absl::StatusOr FromMinimalMemRegion( + absl::string_view mem_region, DType dtype, const Shape& shape, + ByteStrides byte_strides); + + // Constructs an ArrayMemRegion given a pointer to the zeroth-element of the + // (in-host representation of the) Array. + static absl::StatusOr FromZerothElementPointer( + const void* zeroth_element, DType dtype, const Shape& shape, + ByteStrides byte_strides); + + // Returns a region of memory whose lower-most and upper-most addresses are + // necessary to retrieve elements of the (in-host representation of) the + // array. + absl::string_view mem_region() const; + + // Returns a pointer to the zeroth-element of the (in-host representation of + // the) Array. + void* zeroth_element() const; + + private: + ArrayMemRegion(void* mem_region_start, size_t nbytes) + : mem_region_start_(mem_region_start), nbytes_(nbytes) {} + + void* const mem_region_start_; + const size_t nbytes_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_ARRAY_UTIL_H_ diff --git a/xla/python/ifrt_proxy/common/array_util_test.cc b/xla/python/ifrt_proxy/common/array_util_test.cc new file mode 100644 index 0000000000000..51e189bb9ffc7 --- /dev/null +++ b/xla/python/ifrt_proxy/common/array_util_test.cc @@ -0,0 +1,201 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/array_util.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Not; +using ::testing::TestWithParam; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +constexpr DType::Kind kF64 = DType::Kind::kF64; +constexpr DType::Kind kS32 = DType::Kind::kS32; +constexpr DType::Kind kString = DType::Kind::kString; +using Strides = std::vector; + +TEST(DefaultByteStrides, ErrorsIfBadDtype) { + EXPECT_THAT(DefaultByteStrides(DType(kString), Shape({1})), Not(IsOk())); +} + +TEST(DefaultByteStrides, HappyCase) { + EXPECT_THAT(DefaultByteStrides(DType(kF64), Shape({4, 3, 5})), + IsOkAndHolds(ElementsAre(120, 40, 8))); +} + +// TC represents a testcase. +struct TC { + const std::string test_name; + const DType::Kind dtype_kind; + const std::vector shape; + const std::optional> byte_strides; + const std::optional expected_size; +}; +std::string PrintToString(const TC& tc) { return tc.test_name; } + +class ArrayMemRegionSuccess : public TestWithParam {}; +INSTANTIATE_TEST_SUITE_P( + Tests, ArrayMemRegionSuccess, + testing::Values( + // F64 + TC{"DefaultF64", kF64, {4, 3, 5}, std::nullopt}, + TC{"MajorToMinorStridesF64", kF64, {4, 3, 5}, Strides({120, 40, 8})}, + TC{"NotMajorToMinorF64", kF64, {3, 4, 5}, Strides({40, 120, 8})}, + TC{"TransposedF64", kF64, {5, 3, 4}, Strides({8, 40, 120})}, + // S32 + TC{"DefaultS32", kS32, {4, 3, 5}, std::nullopt}, + TC{"MajorToMinorStridesS32", kS32, {4, 3, 5}, Strides({60, 20, 4})}, + TC{"NotMajorToMinorS32", kS32, {3, 4, 5}, Strides({20, 60, 4})}, + TC{"TransposedS32", kS32, {5, 3, 4}, Strides({4, 20, 60})}, + // Scalar + TC{"ScalarF64DefaultStrides", kF64, {}, std::nullopt}, + TC{"ScalarF64EmptyStrides", kF64, {}, Strides({})}, + // Zero elements + TC{"NoColsDefaultStrides", kF64, {5, 0}, std::nullopt}, + TC{"NoColsStridesNonZero", kF64, {5, 0}, Strides({40, 4})}, + TC{"NoColsStridesZero", kF64, {5, 0}, Strides({0, 0})}, + TC{"NoRowsDefaultStrides", kF64, {0, 5}, std::nullopt}, + TC{"NoRowsStridesNonZero", kF64, {0, 5}, Strides({40, 4})}, + TC{"NoRowsStridesZero", kF64, {0, 5}, Strides({0, 0})}, + // Dimension with size 1 + TC{"SingleElementArbitraryStrides", kF64, {1, 1}, Strides({100, 100})}, + TC{"OneRowArbitraryColStride", kF64, {1, 5}, Strides({100, 8})}, + TC{"OneColArbitraryRowStride", kF64, {5, 1}, Strides({8, 100})}, + TC{"OneRowZeroColStride", kF64, {1, 5}, Strides({0, 8})}, + TC{"OneColZeroRowStride", kF64, {5, 1}, Strides({8, 0})}, + // Non-compact strides. + TC{"NonCompactSingleDimension", kS32, {5}, Strides({16}), 68}, + TC{"NonCompactDim0", kS32, {4, 3, 5}, Strides({120, 20, 4}), 420}, + TC{"PaddedElements", kS32, {4, 3, 5}, Strides({120, 40, 8}), 476}), + testing::PrintToStringParamName()); +TEST_P(ArrayMemRegionSuccess, TestCase) { + const TC tc = GetParam(); + const DType dtype(tc.dtype_kind); + const Shape shape(tc.shape); + const size_t expected_size = tc.expected_size.value_or( + dtype.byte_size().value() * shape.num_elements()); + std::string data(expected_size, 'a'); + + TF_ASSERT_OK_AND_ASSIGN(auto mem_region1, + ArrayMemRegion::FromZerothElementPointer( + data.data(), dtype, shape, tc.byte_strides)); + EXPECT_EQ(mem_region1.zeroth_element(), data.data()); + // Note: `EXPECT_EQ(mem_region.mem_region(), absl::string_view(data))` can + // cause asan to complain if the expectation fails. + EXPECT_EQ(mem_region1.mem_region().data(), data.data()); + EXPECT_EQ(mem_region1.mem_region().size(), data.size()); + + TF_ASSERT_OK_AND_ASSIGN( + auto mem_region2, ArrayMemRegion::FromMinimalMemRegion(data, dtype, shape, + tc.byte_strides)); + EXPECT_EQ(mem_region2.zeroth_element(), data.data()); + EXPECT_EQ(mem_region2.mem_region().data(), data.data()); + EXPECT_EQ(mem_region2.mem_region().size(), data.size()); +} + +class ArrayMemRegionFailure : public TestWithParam {}; +INSTANTIATE_TEST_SUITE_P( + Tests, ArrayMemRegionFailure, + testing::Values( + // Will not be supported + TC{"OneString", kString, {}, std::nullopt}, + TC{"ManyStrings", kString, {5}, std::nullopt}, + // Currently unimplemented + TC{"NegativeByteStrides", kS32, {4, 3, 5}, Strides({-60, -20, -4})}, + TC{"ZeroByteStride", kS32, {5, 5}, Strides({0, 0})}, + TC{"SmallerByteStrideThanDataType", kS32, {5, 5}, Strides({1, 1})}, + TC{"ByteStrideIndivisibleByDataType", kS32, {5, 5}, Strides({7, 7})}, + // Bad arguments + TC{"NegativeShapeDimension", kS32, {-5, -5}, Strides({20, 4})}), + testing::PrintToStringParamName()); +TEST_P(ArrayMemRegionFailure, TestCase) { + const TC tc = GetParam(); + const DType dtype(tc.dtype_kind); + const Shape shape(tc.shape); + char const* kSomeAddr = reinterpret_cast(1UL << 48); + + auto mem_region1 = ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/kSomeAddr, dtype, shape, tc.byte_strides); + EXPECT_THAT(mem_region1.status(), Not(IsOk())); + + const size_t kSomeSize = 1024; + auto mem_region2 = ArrayMemRegion::FromMinimalMemRegion( + absl::string_view(kSomeAddr, kSomeSize), dtype, shape, tc.byte_strides); + EXPECT_THAT(mem_region2.status(), Not(IsOk())); +} + +TEST(ArrayMemRegion, FromBadMemRegionSizeFails) { + const DType kDType(kS32); + const Shape kShape({5, 5}); + const size_t kDataBytes = kDType.byte_size().value() * kShape.num_elements(); + + const size_t kExtraSuffixBytes = 10; + std::string data_with_extra_suffix(kDataBytes + kExtraSuffixBytes, 'a'); + + // If we know that the zeroth_element is at the beginning, then we + // can construct the ArrayMemoryRegion; the constructed ArrayMemoryRegion + // will not contain the suffix. + TF_ASSERT_OK_AND_ASSIGN( + auto mem_region1, + ArrayMemRegion::FromZerothElementPointer( + /*zeroth_element=*/data_with_extra_suffix.data(), kDType, kShape, + /*byte_strides=*/std::nullopt)); + EXPECT_EQ(mem_region1.mem_region().data(), data_with_extra_suffix.data()); + EXPECT_EQ(mem_region1.zeroth_element(), data_with_extra_suffix.data()); + EXPECT_LT(mem_region1.mem_region().size(), data_with_extra_suffix.size()); + EXPECT_EQ(mem_region1.mem_region().size(), kDataBytes); + + // But given the data_with_extra_suffix region, we cannot discover where + // within it the zeroth-element points to, so we cannot construct an + // ArrayMemoryRegion from it. + auto mem_region2 = ArrayMemRegion::FromMinimalMemRegion( + data_with_extra_suffix, kDType, kShape, + /*byte_strides=*/std::nullopt); + EXPECT_THAT(mem_region2.status(), Not(IsOk())); + + // Similarly, if we provided `FromMinimalMemRegion` a `data` that was smaller + // than what the constructed `ArrayMemoryRegion` should point to, that will + // be detected as an error. + std::string data_without_some_bytes(kDataBytes - kExtraSuffixBytes, 'a'); + auto mem_region3 = ArrayMemRegion::FromMinimalMemRegion( + data_without_some_bytes, kDType, kShape, + /*byte_strides=*/std::nullopt); + EXPECT_THAT(mem_region3.status(), Not(IsOk())); +} + +} // namespace + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/common/grpc_credentials.cc b/xla/python/ifrt_proxy/common/grpc_credentials.cc new file mode 100644 index 0000000000000..f72424b859577 --- /dev/null +++ b/xla/python/ifrt_proxy/common/grpc_credentials.cc @@ -0,0 +1,71 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "tsl/platform/platform.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +bool UseInsecureCredentials() { + // Use insecure only with `bazel test`. + const bool insecure = (getenv("TEST_UNDECLARED_OUTPUTS_DIR") != nullptr); + + if (insecure) { + // We should not be getting to this point at all in the google-internal + // code, but check to be sure. + CHECK_EQ(TSL_IS_IN_OSS, 1); + } + + return insecure; +} + +} // namespace + +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials() { + if (UseInsecureCredentials()) { + LOG(WARNING) << "Using insecure client credentials for gRPC."; + return ::grpc::InsecureChannelCredentials(); // NOLINT + } else { + LOG(INFO) << "Using ALTS client credentials for gRPC."; + return ::grpc::experimental::AltsCredentials( + ::grpc::experimental::AltsCredentialsOptions()); + } +} + +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials() { + if (UseInsecureCredentials()) { + LOG(WARNING) << "Using insecure server credentials for gRPC."; + return ::grpc::InsecureServerCredentials(); // NOLINT + } else { + LOG(INFO) << "Using ALTS server credentials for gRPC."; + return ::grpc::experimental::AltsServerCredentials( + ::grpc::experimental::AltsServerCredentialsOptions()); + } +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/common/grpc_credentials.h b/xla/python/ifrt_proxy/common/grpc_credentials.h new file mode 100644 index 0000000000000..46435a4adebef --- /dev/null +++ b/xla/python/ifrt_proxy/common/grpc_credentials.h @@ -0,0 +1,41 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ + +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Get credentials to use in the client gRPC. +// TODO(b/323079791): Migrate to use utility library from tsl/platform. +std::shared_ptr<::grpc::ChannelCredentials> GetClientCredentials(); + +// Get credentials to use in the server gRPC. +// TODO(b/323079791): Migrate to use utility library from tsl/platform. +std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_GRPC_CREDENTIALS_H_ diff --git a/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto b/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto new file mode 100644 index 0000000000000..6741e5d98af8a --- /dev/null +++ b/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto @@ -0,0 +1,107 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "xla/python/ifrt_proxy/common/ifrt_service.proto"; + +service GrpcIfrtService { + // Returns the IFRT Proxy version that both the client and the server + // supports. Returns an error if there's no such version. + rpc GetVersion(GrpcGetVersionRequest) returns (GrpcGetVersionResponse) {} + + // IfrtSession is a stream of IFRT requests (from the client) and responses + // from the server. + // + // Clients can optionally start the stream with an InitRequest to configure + // startup options and to retrieve basic run-time system details such as the + // number and handles of the available devices (see InitResponse). But clients + // that are fine with the default options and do not immediately need the info + // from the InitResponse can start with any other request. + // + // TODO(b/282757875): Investigate if there are useful details that client + // should supply to the server even before the first InitRequest message - may + // be via gRPC metadata. + rpc IfrtSession(stream IfrtRequest) returns (stream IfrtResponse) {} + + // Sends a host buffer from the client to the server. Uses client-side + // streaming to allow sending buffers that exceed the 2GiB protobuf + // serialization limit. + rpc HostBufferStore(stream GrpcHostBufferStoreRequest) + returns (GrpcHostBufferStoreResponse); + + // Reads a host buffer from the server to the client. Uses server-side + // streaming to allow >2GiB host buffer transfer. + rpc HostBufferLookup(GrpcHostBufferLookupRequest) + returns (stream GrpcHostBufferLookupResponse); + + // Deletes a host buffer from the server. + rpc HostBufferDelete(GrpcHostBufferDeleteRequest) + returns (GrpcHostBufferDeleteResponse); +} + +message GrpcGetVersionRequest { + IfrtProxyVersion min_version = 1; + IfrtProxyVersion max_version = 2; +} + +message GrpcGetVersionResponse { + IfrtProxyVersion version = 1; +} + +// Metadata for `IfrtSession` requests, sent as client metadata associated with +// key "ifrt-proxy-grpc-ifrt-session-metadata-bin". +message GrpcIfrtSessionMetadata { + IfrtProxyVersion version = 1; +} + +// Metadata for `Store` requests, sent as client metadata associated with key +// "ifrt-proxy-grpc-host-buffer-store-metadata-bin". +message GrpcHostBufferStoreMetadata { + fixed64 session_id = 1; + fixed64 handle = 2; + int64 buffer_size = 3; +} + +// `Store` request that contains actual data, potentially chunked. All requests +// in a transfer must be sent in order and the server simply concatenate `bytes` +// in the response under this assumption. +message GrpcHostBufferStoreRequest { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +message GrpcHostBufferStoreResponse {} + +// `Lookup` request that specifies which host buffer in the server to read. +message GrpcHostBufferLookupRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +// `Lookup` response that returns the (potentially chunked) host buffer +// contents. As in `GrpcHostBufferStoreRequest`, all responses must be sent in +// order and the client simply concatenates `data`. +message GrpcHostBufferLookupResponse { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +// `Delete` request that specifies the host buffer to delete. +message GrpcHostBufferDeleteRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +message GrpcHostBufferDeleteResponse {} diff --git a/xla/python/ifrt_proxy/common/ifrt_proxy.bzl b/xla/python/ifrt_proxy/common/ifrt_proxy.bzl new file mode 100644 index 0000000000000..9dd5c3e3ad996 --- /dev/null +++ b/xla/python/ifrt_proxy/common/ifrt_proxy.bzl @@ -0,0 +1,8 @@ +"""Common libraries for IFRT proxy.""" + +load("//xla:xla.bzl", "xla_cc_test") + +def ifrt_proxy_cc_test(**kwargs): + xla_cc_test(**kwargs) + +default_ifrt_proxy_visibility = ["//xla/python/ifrt_proxy:__subpackages__"] diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto new file mode 100644 index 0000000000000..4274ba5f58122 --- /dev/null +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -0,0 +1,491 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "google/protobuf/any.proto"; +import "xla/pjrt/execute_options.proto"; +import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/serdes.proto"; +import "xla/python/ifrt/shape.proto"; +import "xla/python/ifrt/sharding.proto"; +import "xla/python/ifrt_proxy/common/types.proto"; +import "xla/xla_data.proto"; +import "tsl/protobuf/status.proto"; + +option cc_enable_arenas = true; + +message IfrtProxyVersion { + int32 protocol_version = 1; +} + +message IfrtRequest { + RequestMetadata request_metadata = 1; + + oneof request { + InitRequest init_request = 2; + + // ===== Future ===== + CheckFutureRequest check_future_request = 3; + + // ===== Array ===== + MakeArrayFromHostBufferRequest make_array_from_host_buffer_request = 4; + AssembleArrayFromSingleDeviceArraysRequest + assemble_array_from_single_device_arrays_request = 5; + CopyToHostBufferRequest copy_to_host_buffer_request = 6; + DisassembleIntoSingleDeviceArraysRequest + disassemble_into_single_device_arrays_request = 7; + CheckArrayReadyRequest check_array_ready_request = 8; + DeleteArrayRequest delete_array_request = 9; + ReshardRequest reshard_request = 10; + FullyReplicatedShardRequest fully_replicated_shard_request = 20; + IsArrayDeletedRequest is_array_deleted_request = 11; + DestructArrayRequest destruct_array_request = 12; + + // ==== Compiler ==== + CompileRequest compile_request = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataRequest loaded_executable_metadata_request = 14; + LoadedExecutableExecuteRequest loaded_executable_execute_request = 15; + LoadedExecutableDeleteRequest loaded_executable_delete_request = 16; + LoadedExecutableIsDeletedRequest loaded_executable_is_deleted_request = 17; + LoadedExecutableDestructRequest loaded_executable_destruct_request = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollRequest loaded_host_callback_poll_request = 21; + LoadedHostCallbackReturnRequest loaded_host_callback_return_request = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentRequest get_default_device_assignment_request = + 19; + } +} + +message IfrtResponse { + ResponseMetadata response_metadata = 1; + + oneof response { + InitResponse init_response = 2; + + // ===== Future ===== + CheckFutureResponse check_future_response = 3; + + // ===== Array ===== + MakeArrayFromHostBufferResponse make_array_from_host_buffer_response = 4; + AssembleArrayFromSingleDeviceArraysResponse + assemble_array_from_single_device_arrays_response = 5; + CopyToHostBufferResponse copy_to_host_buffer_response = 6; + DisassembleIntoSingleDeviceArraysResponse + disassemble_into_single_device_arrays_response = 7; + CheckArrayReadyResponse check_array_ready_response = 8; + DeleteArrayResponse delete_array_response = 9; + ReshardResponse reshard_response = 10; + FullyReplicatedShardResponse fully_replicated_shard_response = 20; + IsArrayDeletedResponse is_array_deleted_response = 11; + DestructArrayResponse destruct_array_response = 12; + + // ===== Compiler ===== + CompileResponse compile_response = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataResponse loaded_executable_metadata_response = 14; + LoadedExecutableExecuteResponse loaded_executable_execute_response = 15; + LoadedExecutableDeleteResponse loaded_executable_delete_response = 16; + LoadedExecutableIsDeletedResponse loaded_executable_is_deleted_response = + 17; + LoadedExecutableDestructResponse loaded_executable_destruct_response = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollResponse loaded_host_callback_poll_response = 21; + LoadedHostCallbackReturnResponse loaded_host_callback_return_response = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentResponse get_default_device_assignment_response = + 19; + } +} + +// Metadata of an IFRT Request. +message RequestMetadata { + // Identifies a logical IFRT Operation (equivalent to an IFRT API call). + // + // For the operations that require chunking (e.g.: MakeArrayFromHostBuffer) + // all the request proto messages share the same op_id. + // + // Must be unique and monotonically increasing across the life of a client - + // may stretch across multiple successive IfrtSessions used to reconnect and + // resync after transient connectivity failures. + fixed64 op_id = 1; + + // List of one or more prior ops this current op is "dependent" + // upon. Currently this allows the client to define the order in which the + // server starts the execution of requests. Future versions may add other + // types of dependencies. For instance, a separate list of dependencies that + // must *complete* executing before the current one can start to execute. + // + // An op_id that has not yet been seen by the server is treated as an error + // that fails the op. + repeated fixed64 dependencies = 2; + + // UserContext is a basic provenance mechanism that allows the server-side + // actions and artifacts (say, allocating a buffer) to be associated with the + // corresponding client-side context that triggered those actions. + // + // The optional UserContextId is generated by the client and are used as an + // opaque label by the server and the run-time systems behind it. + // TODO(b/282757875): Add a pointer to Usercontext bugs/design doc. + fixed64 user_context_id = 3; + + // Additional implementation-specific payloads. + repeated google.protobuf.Any payloads = 4; +} + +// Metadata of an IFRT Response. + +message ResponseMetadata { + // ID of the operation this response belongs to. + fixed64 op_id = 1; + + // Status of the operation. + // + // In case of "chunked" responses (i.e., the full logical response is + // spread across a sequence of IfrtResponse protos), the actual sequence of + // IfrtResponse messages will follow only if this Status is OK in the very + // first message. That is, in case of errors, server sends a single + // IfrtResponse with the appropriate error included. + // + // In case of "batched" operations (i.e., where the response is carrying + // the outcomes of multiple requests that were "batched" in the same + // IfrtRequest proto - such as deleting a bunch of Arrays) this Status + // field provides a way to quickly check if none of the individual + // operations encountered errors. Clients should not rely on specific error + // type or string when this is not OK, they should check the response + // message for individual Statuses. + tensorflow.StatusProto status = 2; +} + +// InitRequest allows the client to specify the optional startup configuration +// parameters such as an idle timeout for this `IfrtSession`, backend servers +// addresses, and whether to turn on tracing, etc. +// +// Initialization of a a session is optional, but if a client chooses to do it, +// it must be the very first op i.e., the InitRequest must be the very first +// request of the session. +message InitRequest {} + +// InitResponse contains basic runtime system info (such as the available +// devices, and name and type of the platform) that most clients can immediately +// make use of. It may also carry the status for whether the optional +// configuration requested by the InitRequest has been successfully applied. +message InitResponse { + uint64 session_id = 8; + + string platform_name = 1; // == ifrt::Client::platform_name() + string platform_version = 2; // == ifrt::Client::platform_version() + uint64 platform_id = 3; // == ifrt::Client::platform_id() + uint64 process_index = 4; // == ifrt::Client::process_index() + string runtime_type = 5; // == ifrt::Client::runtime_type() + + message Device { + int32 id = 1; + int32 local_device_id = 9; + int32 local_hardware_id = 2; + string device_kind = 3; + optional int32 default_memory_id = 7; + repeated int32 memory_ids = 8; + string debug_string = 4; + string to_string = 5; + map attributes = 6; + } + + repeated Device devices = 6; // == ifrt::Client::devices() + repeated int32 addressable_device_ids = + 7; // == ifrt::Client::addressable_devices() + + message Memory { + int32 id = 1; + string memory_space_kind = 2; + repeated int32 device_ids = 3; + string debug_string = 4; + string to_string = 5; + } + + repeated Memory memories = 9; +} + +// ================ Future-related operations ================ + +// Checks if the given Futures are ready on the server. This is a destructive +// read, i.e., the given future will no longer be able to be referenced. +message CheckFutureRequest { + fixed64 future_handle = 1; +} +message CheckFutureResponse {} + +// ================ Array-related operations ================ + +// In the current context of the IFRT proxy service, the term `Host` in the +// proto names below refers to the host where the proxy client and the user code +// (e.g.: a Jax application) are running. + +// Makes an IFRT Array from the contents of a HostBuffer. +// Equivalent to `ifrt::Client::MakeArrayFromHostBuffer`. +message MakeArrayFromHostBufferRequest { + DTypeProto dtype = 1; + ShapeProto shape = 2; + ShardingProto sharding = 3; + fixed64 host_buffer_handle = 4; + optional proto.ByteStrides byte_strides = 5; +} +message MakeArrayFromHostBufferResponse { + fixed64 array_handle = 1; +} + +// Makes an IFRT Array from a set of single-device Arrays. +// Equivalent to ifrt::Client::AssembleArrayFromSingleDeviceArrays. +message AssembleArrayFromSingleDeviceArraysRequest { + ShapeProto shape = 1; + ShardingProto sharding = 2; + repeated fixed64 single_device_array_handles = 3; + proto.ArrayCopySemantics copy_semantics = 4; +} +message AssembleArrayFromSingleDeviceArraysResponse { + fixed64 array_handle = 1; +} + +// Reads the contents of a given IFRT Array. +// Equivalent to ifrt::Array::CopyToHostBuffer. +message CopyToHostBufferRequest { + fixed64 array_handle = 2; + optional proto.ByteStrides byte_strides = 3; + fixed64 host_buffer_handle = 1; +} +message CopyToHostBufferResponse {} + +// Breaks the given Array into its constituent per-device Arrays. +// Equivalent to ifrt::Array::DisassmebleIntoSingleDeviceArrays. +message DisassembleIntoSingleDeviceArraysRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message DisassembleIntoSingleDeviceArraysResponse { + repeated fixed64 single_device_array_handles = 1; +} + +message ReshardRequest { + fixed64 array_handle = 1; + ShardingProto sharding = 2; + proto.ArrayCopySemantics copy_semantics = 3; +} +message ReshardResponse { + fixed64 array_handle = 1; +} + +message FullyReplicatedShardRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message FullyReplicatedShardResponse { + fixed64 array_handle = 1; +} + +// Checks if the given Arrays are ready on the server. +message CheckArrayReadyRequest { + fixed64 array_handle = 1; +} +message CheckArrayReadyResponse {} + +// Deletes the given Array. Response contains the handle for a Future that +// becomes ready when the deletion completes. +message DeleteArrayRequest { + fixed64 array_handle = 1; +} +message DeleteArrayResponse { + fixed64 deletion_future_handle = 1; +} + +message IsArrayDeletedRequest { + fixed64 array_handle = 1; +} +message IsArrayDeletedResponse { + bool deleted = 1; +} + +message DestructArrayRequest { + fixed64 array_handle = 1; +} +message DestructArrayResponse {} + +// ================ Compiler-related operations ================ + +// Modeled after `xla::PjRtLoadedExecutable::LogicalDeviceIds`. +// +// TODO(hyeontaek): this XLA-specific type is temporary and will be removed when +// `addressable_device_logical_ids()` is removed from `LoadedExecutable` or +// moved to a type-erased proto field. +message LogicalDeviceIds { + int32 replica = 1; + int32 partition = 2; +} + +// Compiles `mlir_module` and returns a `LoadedExecutable`. +message CompileRequest { + xla.ifrt.Serialized program = 1; + xla.ifrt.Serialized compile_options = 2; + repeated bytes host_callbacks = 3; +} +message CompileResponse { + fixed64 loaded_executable_handle = 1; + repeated fixed64 loaded_host_callback_handles = 8; + + // A subset of LoadedExecutable's fields that are cheap to calculate. See + // `LoadedExecutableMetadataResponse` for the rest of metadata. + string name = 2; + int32 num_devices = 3; + repeated LogicalDeviceIds addressable_device_logical_ids = 4; + repeated int32 addressable_device_ids = 5; + oneof fingerprint { + bytes fingerprint_value = 6; + tensorflow.StatusProto fingerprint_error = 7; + } + fixed64 ready_future_handle = 9; +} + +// ================ LoadedExecutable-related operations ================ + +// Reads `LoadedExecutable`'s metadata that's typically available only after +// compilation. Metadata fields that are cheaper to calculate are available +// immediately as part of `CompileResponse`. +message LoadedExecutableMetadataRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableMetadataResponse { + message ShardingList { + repeated xla.OpSharding shardings = 1; + } + + optional ShardingList parameter_shardings = 1; + optional ShardingList output_shardings = 2; + + message LayoutList { + repeated xla.LayoutProto layouts = 1; + } + + oneof parameter_layouts { + LayoutList parameter_layouts_list = 4; + tensorflow.StatusProto parameter_layouts_error = 5; + } + oneof output_layouts { + LayoutList output_layouts_list = 6; + tensorflow.StatusProto output_layouts_error = 7; + } + + message MemoryKindList { + repeated string memory_kinds = 1; + } + + message OutputMemoryKind { + tensorflow.StatusProto status = 1; + repeated MemoryKindList memory_kind_lists = 2; + } + + OutputMemoryKind output_memory_kinds = 3; +} + +// Mirrors `LoadedExecutable::Execute`. Returns output array handles and a +// future handle that becomes ready when the execution completes. The latter can +// be checked by issuing `CheckFutureRequest`. +message LoadedExecutableExecuteRequest { + fixed64 loaded_executable_handle = 1; + repeated fixed64 args_handles = 2; + xla.ExecuteOptionsProto execute_options = 3; + repeated int32 device_ids = 4; +} +message LoadedExecutableExecuteResponse { + fixed64 status_handle = 1; + + message Output { + DTypeProto dtype = 1; + ShapeProto shape = 2; + ShardingProto sharding = 3; + fixed64 array_handle = 4; + } + + repeated Output outputs = 2; +} + +// Mirrors `LoadedExecutable::Delete`. Returns a handle of a future that becomes +// ready when the deletion completes. +message LoadedExecutableDeleteRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDeleteResponse { + fixed64 future_handle = 1; +} + +// Mirrors `LoadedExecutable::IsDeleted`. +message LoadedExecutableIsDeletedRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableIsDeletedResponse { + bool is_deleted = 1; +} + +// Mirrors `LoadedExecutable::~LoadedExecutable`. The LoadedExecutable handle +// becomes unusable after this request. +message LoadedExecutableDestructRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDestructResponse {} + +// ================ LoadedHostCallback-related operations ================ + +// Waits for the given host callback on the server to have any pending execution +// and retrieves its execution identifier and operands. The server serializes +// all operands, concatenates them in the argument order, stores it as a single +// host buffer assocatiated with the given handle. +message LoadedHostCallbackPollRequest { + fixed64 loaded_host_callback_handle = 1; + fixed64 operand_host_buffer_handle = 2; +} +message LoadedHostCallbackPollResponse { + optional fixed64 host_callback_execution_handle = 1; +} + +// Returns the results of a client-side host callback execution, requested by +// `LoadedHostCallbackPollResponse`. The client concatenates all serialized +// results and stores them as a single host buffer associated with the given +// handle. +message LoadedHostCallbackReturnRequest { + fixed64 host_callback_execution_handle = 1; + oneof result { + fixed64 result_host_buffer_handle = 3; + tensorflow.StatusProto error = 2; + } +} +message LoadedHostCallbackReturnResponse {} + +// ============= Operations supported by the IFRT `Client` class ============= + +// Mirrors Client::GetDefaultDeviceAssignment. +message GetDefaultDeviceAssignmentRequest { + fixed64 num_replicas = 1; + fixed64 num_partitions = 2; +} +message GetDefaultDeviceAssignmentResponse { + xla.DeviceAssignmentProto device_assignment = 1; +} diff --git a/xla/python/ifrt_proxy/common/proto_util.cc b/xla/python/ifrt_proxy/common/proto_util.cc new file mode 100644 index 0000000000000..a9d057c2139a5 --- /dev/null +++ b/xla/python/ifrt_proxy/common/proto_util.cc @@ -0,0 +1,38 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/proto_util.h" + +#include +#include + +#include "absl/status/status.h" +#include "tsl/platform/status_to_from_proto.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +std::unique_ptr NewIfrtResponse(uint64_t op_id, + absl::Status status) { + auto ifrt_resp = std::make_unique(); + auto* response_metadata = ifrt_resp->mutable_response_metadata(); + response_metadata->set_op_id(op_id); + *response_metadata->mutable_status() = tsl::StatusToProto(status); + return ifrt_resp; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/common/proto_util.h b/xla/python/ifrt_proxy/common/proto_util.h new file mode 100644 index 0000000000000..d999d14f97836 --- /dev/null +++ b/xla/python/ifrt_proxy/common/proto_util.h @@ -0,0 +1,57 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Makes an IfrtResponse proto with the given metadata. +std::unique_ptr NewIfrtResponse( + uint64_t op_id, absl::Status status = absl::OkStatus()); + +// Converts an `absl::string_view` into a type that is appropriate for doing +// `proto->set_string_field(...)`. This type can be absl::string_view in the +// newest versions of protobuf, but needs to be std::string for previous +// versions. (As of Feb 2024, OpenXLA uses an old version.) +#if defined(PLATFORM_GOOGLE) +inline absl::string_view AsProtoStringData( + absl::string_view s ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return s; +} +#else +inline std::string AsProtoStringData(absl::string_view s) { + LOG_FIRST_N(WARNING, 5) << "AsProtoStringData(): copying string_view->string"; + return std::string(s); +} +#endif + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_PROTO_UTIL_H_ diff --git a/xla/python/ifrt_proxy/common/types.cc b/xla/python/ifrt_proxy/common/types.cc new file mode 100644 index 0000000000000..05e9dff6272d7 --- /dev/null +++ b/xla/python/ifrt_proxy/common/types.cc @@ -0,0 +1,119 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/types.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::StatusOr FromVariantProto( + const proto::Variant& variant_proto) { + switch (variant_proto.value_case()) { + case proto::Variant::kStringValue: + return variant_proto.string_value(); + case proto::Variant::kInt64Value: + return variant_proto.int64_value(); + case proto::Variant::kInt64List: { + const auto& values = variant_proto.int64_list().values(); + return std::vector(values.begin(), values.end()); + } + case proto::Variant::kFloatValue: + return variant_proto.float_value(); + default: + return absl::UnimplementedError(absl::StrCat( + "Unknown xla.ifrt.proto.Variant case: ", variant_proto.value_case())); + } +} + +absl::StatusOr ToVariantProto(const xla::PjRtValueType& value) { + proto::Variant variant; + if (auto* s = std::get_if(&value)) { + variant.set_string_value(*s); + } else if (auto* i = std::get_if(&value)) { + variant.set_int64_value(*i); + } else if (auto* is = std::get_if>(&value)) { + for (const int64_t i : *is) { + variant.mutable_int64_list()->add_values(i); + } + } else if (auto* f = std::get_if(&value)) { + variant.set_float_value(*f); + } else { + return absl::UnimplementedError("Unknown xla::PjRtValueType type"); + } + return variant; +} + +proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s) { + switch (s) { + case ArrayCopySemantics::kAlwaysCopy: + return proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY; + case ArrayCopySemantics::kDonateInput: + return proto::ARRAY_COPY_SEMANTICS_DONATE_INPUT; + case ArrayCopySemantics::kReuseInput: + return proto::ARRAY_COPY_SEMANTICS_REUSE_INPUT; + } +} + +absl::StatusOr FromArrayCopySemanticsProto( + proto::ArrayCopySemantics s) { + MakeArrayFromHostBufferRequest req; + switch (s) { + case proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY: + return ArrayCopySemantics::kAlwaysCopy; + case proto::ARRAY_COPY_SEMANTICS_DONATE_INPUT: + return ArrayCopySemantics::kDonateInput; + case proto::ARRAY_COPY_SEMANTICS_REUSE_INPUT: + return ArrayCopySemantics::kReuseInput; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unhandled proto-enum value ", s, ":", + proto::ArrayCopySemantics_Name(s))); + } +} + +std::vector FromByteStridesProto(const proto::ByteStrides& strides) { + std::vector result; + result.reserve(strides.strides_size()); + for (auto x : strides.strides()) { + result.push_back(x); + } + return result; +} + +proto::ByteStrides ToByteStridesProto(const absl::Span strides) { + proto::ByteStrides result; + for (auto x : strides) { + result.add_strides(x); + } + return result; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/common/types.h b/xla/python/ifrt_proxy/common/types.h new file mode 100644 index 0000000000000..0c517e2da054a --- /dev/null +++ b/xla/python/ifrt_proxy/common/types.h @@ -0,0 +1,58 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +struct ArrayHandle { + uint64_t handle; + + template + friend void AbslStringify(Sink& sink, const ArrayHandle& h) { + absl::Format(&sink, "arr_%v", h.handle); + } +}; + +absl::StatusOr FromArrayCopySemanticsProto( + proto::ArrayCopySemantics s); +proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s); + +absl::StatusOr FromVariantProto( + const proto::Variant& variant_proto); +absl::StatusOr ToVariantProto(const xla::PjRtValueType& value); + +std::vector FromByteStridesProto(const proto::ByteStrides& strides); +proto::ByteStrides ToByteStridesProto(absl::Span strides); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_TYPES_H_ diff --git a/xla/python/ifrt_proxy/common/types.proto b/xla/python/ifrt_proxy/common/types.proto new file mode 100644 index 0000000000000..ca3829891d762 --- /dev/null +++ b/xla/python/ifrt_proxy/common/types.proto @@ -0,0 +1,43 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proto; + +// Mirrors `xla::PjRtValueType`, which is used in IFRT to model +// polymorphic-typed values, e.g., `xla::ifrt::Executable::CostAnalysisValue`. +message Variant { + message Int64List { + repeated sfixed64 values = 1; + } + + oneof value { + bytes string_value = 1; + sfixed64 int64_value = 2; + Int64List int64_list = 3; + float float_value = 4; + } +} + +enum ArrayCopySemantics { + ARRAY_COPY_SEMANTICS_UNSPECIFIED = 0; + ARRAY_COPY_SEMANTICS_ALWAYS_COPY = 1; + ARRAY_COPY_SEMANTICS_REUSE_INPUT = 2; + ARRAY_COPY_SEMANTICS_DONATE_INPUT = 3; +} + +message ByteStrides { + repeated int64 strides = 1; +} diff --git a/xla/python/ifrt_proxy/common/types_test.cc b/xla/python/ifrt_proxy/common/types_test.cc new file mode 100644 index 0000000000000..fdbf3ff123cc8 --- /dev/null +++ b/xla/python/ifrt_proxy/common/types_test.cc @@ -0,0 +1,84 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/common/types.h" + +#include +#include +#include + +#include +#include +#include "xla/pjrt/pjrt_common.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class VariantTest : public testing::TestWithParam {}; + +TEST_P(VariantTest, ToFromVariantProto) { + const auto& variant = GetParam(); + TF_ASSERT_OK_AND_ASSIGN(proto::Variant variant_proto, + ToVariantProto(variant)); + EXPECT_THAT(FromVariantProto(variant_proto), IsOkAndHolds(variant)); +} + +INSTANTIATE_TEST_SUITE_P( + Variant, VariantTest, + testing::Values(xla::PjRtValueType(std::string("foo")), + xla::PjRtValueType(static_cast(1234)), + xla::PjRtValueType(std::vector{1, 2}), + xla::PjRtValueType(3.14f))); + +class ByteStridesTest : public testing::TestWithParam> {}; + +TEST_P(ByteStridesTest, ToFromProto) { + std::vector strides = GetParam(); + EXPECT_EQ(FromByteStridesProto(ToByteStridesProto(strides)), strides); +} + +INSTANTIATE_TEST_SUITE_P( + ByteStrides, ByteStridesTest, + testing::ValuesIn(std::vector>{ + {}, {1}, {0}, {4, 8}, {8, 4}, {1, 2, 3, 4}, {0, 4}, {4, 0}})); + +TEST(ArrayCopySemanticsTest, FromToFromProto) { + for (int i = 0; i < proto::ArrayCopySemantics_descriptor()->value_count(); + ++i) { + const auto proto_enum = static_cast( + proto::ArrayCopySemantics_descriptor()->value(i)->number()); + if (proto_enum == proto::ARRAY_COPY_SEMANTICS_UNSPECIFIED) { + continue; + } + TF_ASSERT_OK_AND_ASSIGN(const auto cpp_enum, + FromArrayCopySemanticsProto(proto_enum)); + TF_ASSERT_OK_AND_ASSIGN( + const auto cpp_enum_copy, + FromArrayCopySemanticsProto(ToArrayCopySemanticsProto(cpp_enum))); + EXPECT_EQ(cpp_enum_copy, cpp_enum); + } +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/integration_tests/BUILD b/xla/python/ifrt_proxy/integration_tests/BUILD new file mode 100644 index 0000000000000..736520601eee4 --- /dev/null +++ b/xla/python/ifrt_proxy/integration_tests/BUILD @@ -0,0 +1,107 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +cc_library( + name = "register_pjrt_cpu_for_ifrt_api_tests", + testonly = True, + srcs = ["register_pjrt_cpu_for_ifrt_api_tests.cc"], + deps = [ + "//xla/pjrt:pjrt_client", + "//xla/pjrt/cpu:cpu_client", + "//xla/python/ifrt", + "//xla/python/ifrt:test_util", + "//xla/python/ifrt_proxy/client:grpc_client", + "//xla/python/ifrt_proxy/client:registry", + "//xla/python/ifrt_proxy/server:grpc_server", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], + alwayslink = True, +) + +ifrt_proxy_cc_test( + name = "client_impl_test_tfrt_cpu", + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", + "//xla/python/ifrt:client_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +ifrt_proxy_cc_test( + name = "array_impl_test_tfrt_cpu", + srcs = ["array_impl_test_tfrt_cpu.cc"], + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", + "//xla/python/ifrt:array_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +ifrt_proxy_cc_test( + name = "executable_impl_test_tfrt_cpu", + timeout = "moderate", + srcs = ["executable_impl_test_tfrt_cpu.cc"], + deps = [ + ":register_pjrt_cpu_for_ifrt_api_tests", # buildcleaner: keep + "//xla/python/ifrt:test_util", + "//xla/python/ifrt/ir/tests:executable_impl_test_lib", + "//xla/python/pjrt_ifrt:xla_executable_impl_test_lib", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +ifrt_proxy_cc_test( + name = "mock_array_test", + size = "small", + srcs = ["mock_array_test.cc"], + deps = [ + "//xla:status", + "//xla/pjrt/cpu:cpu_client", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt_proxy/client", + "//xla/python/ifrt_proxy/client:grpc_client", + "//xla/python/ifrt_proxy/client:registry", + "//xla/python/ifrt_proxy/server:grpc_server", + "//xla/python/pjrt_ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) diff --git a/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc b/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc new file mode 100644 index 0000000000000..4ce343a1cb72d --- /dev/null +++ b/xla/python/ifrt_proxy/integration_tests/array_impl_test_tfrt_cpu.cc @@ -0,0 +1,42 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +int main(int argc, char** argv) { + const std::string disabled[] = { + // TfrtCpuBuffer::ToLiteral() currently does not respect the layout of the + // destination literal. + "ArrayImplTest.MakeArrayFromHostBufferAndCopyToHostBufferWithByteStrides", + + // `ShardingParamSharding` does not support serialization yet. + // TODO(b/282757875): Enable the test once IFRT implements + // `ShardingParamShardingSerDes`. + "ArrayImplTest.AssembleAndDisassembleArray", + }; + + const std::string filter = absl::StrCat("-", absl::StrJoin(disabled, ":")); +#ifdef GTEST_FLAG_SET + GTEST_FLAG_SET(filter, filter.c_str()); +#else + testing::GTEST_FLAG(filter) = filter.c_str(); +#endif + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc b/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc new file mode 100644 index 0000000000000..2f9b1e4731916 --- /dev/null +++ b/xla/python/ifrt_proxy/integration_tests/executable_impl_test_tfrt_cpu.cc @@ -0,0 +1,48 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/test_util.h" + +int main(int argc, char** argv) { + const std::string disabled[] = { + // Executable::IsDeleted always returns false with TFRT CPU backend. + "LoadedExecutableImplTest.IsDeleted", + + // Enable this when Serialization support for IFRT IR is available. + "IfrtIrExecutableImplTest.CallXla", + "IfrtIrExecutableImplTest.Reshard", + "IfrtIrExecutableImplTest.ZeroInput", + "IfrtIrExecutableImplTest.ZeroOutput", + "IfrtIrExecutableImplTest.BufferDonation", + "IfrtIrExecutableImplTest.LoadedExecBinding", + "ProgramLoadedExecutableImplTest.MultipleAtomProgramsNeedDummyInputs", + }; + + const std::string filter = absl::StrCat("-", absl::StrJoin(disabled, ":")); + +#ifdef GTEST_FLAG_SET + GTEST_FLAG_SET(filter, filter.c_str()); +#else + testing::GTEST_FLAG(filter) = filter.c_str(); +#endif + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc new file mode 100644 index 0000000000000..3d4cd8fd800d5 --- /dev/null +++ b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -0,0 +1,274 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt_proxy/client/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/status.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +constexpr absl::StatusCode kInternal = absl::StatusCode::kInternal; + +constexpr absl::Duration kSomeTime = absl::Seconds(1); + +class MockArrayTest : public testing::Test { + public: + void SetUp() override { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + TF_ASSERT_OK_AND_ASSIGN( + server_, GrpcServer::CreateFromIfrtClientFactory( + address, [this] { return CreateMockBackend(); })); + TF_ASSERT_OK_AND_ASSIGN(client_, + CreateClient(absl::StrCat("grpc://", address))); + } + + struct ArrayPair { + // IFRT array exposed to the proxy's user. Not a mock. + tsl::RCReference proxy_client_array; + // IFRT array owned by the proxy server whose behavior should be + // reflected by proxy_client_array. Mock but delegated. + tsl::RCReference backend_array; + }; + + absl::StatusOr NewArray() { + DType dtype(DType::kF32); + Shape shape({2, 3}); + auto data = std::make_unique>(6); + std::iota(data->begin(), data->end(), 0); + xla::ifrt::Device* device = client_->addressable_devices().at(0); + std::shared_ptr sharding = + SingleDeviceSharding::Create(device, MemoryKind()); + + TF_ASSIGN_OR_RETURN( + auto client_arr, + client_->MakeArrayFromHostBuffer( + data->data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr)); + + // When the above `MakeArrayFromHostBuffer` results in the server issuing a + // `MakeArrayFromHostBuffer()` to the underlying mock backend, the mock + // backend enqueues the returned mock array onto `mock_arrays_` (this code + // is in `CreateMockBackend()`). + absl::MutexLock l(&mu_); + CHECK_EQ(mock_arrays_.size(), 1); + auto mock = mock_arrays_.back(); + mock_arrays_.pop_back(); + return ArrayPair{client_arr, mock}; + } + + std::unique_ptr server_; + std::unique_ptr client_; + + private: + absl::StatusOr> CreateMockBackend() { + // TODO(b/292339723): Use reference backend as the delegate while mocking. + CpuClientOptions options; + options.asynchronous = true; + options.cpu_device_count = 2; + TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options)); + auto mock_backend = std::make_unique( + /*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client))); + + ON_CALL(*mock_backend, MakeArrayFromHostBuffer) + .WillByDefault( + [this, mock_backend = mock_backend.get()]( + const void* data, DType dtype, Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + Client::HostBufferSemantics semantics, + std::function on_done_with_host_buffer) + -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN( + auto delegated, + mock_backend->delegated()->MakeArrayFromHostBuffer( + data, dtype, shape, byte_strides, sharding, semantics, + on_done_with_host_buffer)); + auto result = tsl::MakeRef(delegated); + + absl::MutexLock l(&mu_); + mock_arrays_.push_back(result); + return result; + }); + + return mock_backend; + } + + absl::Mutex mu_; + std::vector> mock_arrays_ ABSL_GUARDED_BY(mu_); +}; + +TEST_F(MockArrayTest, ReadyFutureWaitsUntilReady) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { + wait_ready.WaitForNotification(); + return arr.backend_array->delegated()->GetReadyFuture(); + }); + + auto ready = arr.proxy_client_array->GetReadyFuture(); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(ready.IsReady()); + + wait_ready.Notify(); + EXPECT_THAT(ready.Await(), IsOk()); +} + +TEST_F(MockArrayTest, ReadyFuturePropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + EXPECT_CALL(*arr.backend_array, GetReadyFuture).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + EXPECT_THAT(arr.proxy_client_array->GetReadyFuture().Await(), + StatusIs(kInternal)); +} + +TEST_F(MockArrayTest, DeletionFutureWaitsUntilDeleted) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + tsl::thread::ThreadPool threads(tsl::Env::Default(), "t", /*num_threads=*/1); + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { + // TODO(b/266635130): Write a version of this testcase where the Delete() + // call of the MockArray blocks on `wait_ready`, instead of the Future it + // returns being blocked on `wait_ready`. That version of the testcase does + // not currently work since both the client and the server synchronously + // block until the MockArray's Delete() returns. + auto promise = Future::CreatePromise(); + threads.Schedule([&, promise]() mutable { + wait_ready.WaitForNotification(); + promise.Set(arr.backend_array->delegated()->Delete().Await()); + }); + return Future(promise); + }); + + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + auto deleted_future = arr.proxy_client_array->Delete(); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(deleted_future.IsReady()); + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + + wait_ready.Notify(); + EXPECT_THAT(deleted_future.Await(), IsOk()); + EXPECT_TRUE(arr.proxy_client_array->IsDeleted()); +} + +TEST_F(MockArrayTest, DeletionPropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); + EXPECT_THAT(arr.proxy_client_array->Delete().Await(), StatusIs(kInternal)); +} + +TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, CopyToHostBuffer) + .WillOnce([&](auto data, auto byte_strides, auto semantics) { + wait_ready.WaitForNotification(); + return arr.backend_array->delegated()->CopyToHostBuffer( + data, byte_strides, semantics); + }); + + char data[1000]; + auto copied = arr.proxy_client_array->CopyToHostBuffer( + data, /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + + absl::SleepFor(kSomeTime); + EXPECT_FALSE(copied.IsReady()); + + wait_ready.Notify(); + EXPECT_THAT(copied.Await(), IsOk()); +} + +TEST_F(MockArrayTest, CopyToHostFuturePropagatesError) { + TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); + + absl::Notification wait_ready; + + EXPECT_CALL(*arr.backend_array, CopyToHostBuffer).WillOnce([&] { + return Future(absl::InternalError("testing")); + }); + + char data[1000]; + auto copied = arr.proxy_client_array->CopyToHostBuffer( + data, /*byte_strides=*/std::nullopt, ArrayCopySemantics::kAlwaysCopy); + + EXPECT_THAT(copied.Await(), StatusIs(kInternal)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc b/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc new file mode 100644 index 0000000000000..6b344e011ede7 --- /dev/null +++ b/xla/python/ifrt_proxy/integration_tests/register_pjrt_cpu_for_ifrt_api_tests.cc @@ -0,0 +1,80 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file registers a factory with `xla::ifrt::test_util` that will spawn a +// IFRT proxy client connected to an instance of a proxy server that is backed +// by the IFRT-PjRt-CPU backend. +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace test_util { +namespace { + +absl::StatusOr> CreateIfrtBackendClient() { + TF_ASSIGN_OR_RETURN(std::unique_ptr tfrt_cpu_client, + xla::GetTfrtCpuClient(/*asynchronous=*/true, + /*cpu_device_count=*/2)); + return xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client)); +} + +const bool kUnused = + (xla::ifrt::test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + std::string address = + absl::StrCat("localhost:", tsl::testing::PickUnusedPortOrDie()); + TF_ASSIGN_OR_RETURN(auto server, + GrpcServer::CreateFromIfrtClientFactory( + address, CreateIfrtBackendClient)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr client, + CreateClient(absl::StrCat("grpc://", address))); + + return std::shared_ptr( + client.release(), /*deleter=*/ + [server = server.release()](xla::ifrt::Client* client) { + // Client has to be destructed before the server since the + // server's destructor (as of Jul 2023) waits for the client to + // end its session. + // TODO(b/282757875): Make the server cancel the client's + // session if the server is getting destructed. + delete client; + delete server; + }); + }), + true); + +} // namespace +} // namespace test_util +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/jax/BUILD b/xla/python/ifrt_proxy/jax/BUILD new file mode 100644 index 0000000000000..b05846d91e0d2 --- /dev/null +++ b/xla/python/ifrt_proxy/jax/BUILD @@ -0,0 +1,50 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Jax library for IFRT proxy. +load("//xla:pytype.default.bzl", "pytype_strict_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], +) + +pytype_strict_library( + name = "ifrt_proxy_internal", + srcs = ["ifrt_proxy_internal.py"], + # copybara:uncomment_begin + # visibility = [ + # "//xla/python/ifrt_proxy/common/google:friends", + # "//xla/python/ifrt_proxy/common/google:jax_users", + # ], + # copybara:uncomment_end + deps = [ + "//xla/python:xla_client", + ], +) + +# copybara:uncomment_begin(ifrt_proxy.py is not exported to github) +# pytype_strict_library( +# name = "ifrt_proxy", +# srcs = ["ifrt_proxy.py"], +# visibility = [ +# "//xla/python/ifrt_proxy/common/google:friends", +# "//xla/python/ifrt_proxy/common/google:jax_users", +# ], +# deps = [ +# ":ifrt_proxy_internal", +# "//third_party/py/jax", +# ], +# ) +# copybara:uncomment_end diff --git a/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py b/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py new file mode 100644 index 0000000000000..4b46a2ea0317b --- /dev/null +++ b/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py @@ -0,0 +1,77 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to help create a IFRT proxy client. + +This library is no longer recommended nor used in OSS; it is used internally +within google code. TODO(madthanu): Remove library. +""" + +import dataclasses +from typing import Callable, Optional + +from xla.python import xla_client + + +@dataclasses.dataclass +class ConnectionOptions: + """Various connection options. + + Attributes: + on_disconnect: Optional, a callback that will be called if there was a + successful connection to the proxy server and Jax commands could be + issued, but there was a later disconnect before the Client is destroyed. + on_connection_update: Optional, a callback that will be called with status + updates about initial connection establishment. The updates will be + provided as human-readable strings, and an end-user may find them helpful. + """ + + on_disconnect: Optional[Callable[[str], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + + +_backend_created: bool = False +_connection_options: ConnectionOptions = ConnectionOptions() + + +def get_client(proxy_server_address: str) -> xla_client.Client: + """Creates an IFRT Proxy client for the given server address.""" + global _backend_created + _backend_created = True + py_module = xla_client._xla.ifrt_proxy # pylint: disable=protected-access + cpp_options = py_module.ClientConnectionOptions() + cpp_options.on_disconnect = _connection_options.on_disconnect + cpp_options.on_connection_update = _connection_options.on_connection_update + client = py_module.get_client(proxy_server_address, cpp_options) + return client + + +def set_connection_options( + options: ConnectionOptions, +) -> None: + """Sets the connection options for the "proxy" jax_platforms. + + Args: + options: See documentation for ConnectionOptions class. + + Raises: + ValueError: If this function is called after the proxy backend has already + been created. + """ + global _connection_options + if _backend_created: + raise ValueError( + "set_connection_options() called after proxy backend was created." + ) + _connection_options = options diff --git a/xla/python/ifrt_proxy/server/BUILD b/xla/python/ifrt_proxy/server/BUILD new file mode 100644 index 0000000000000..2f4fa44484c20 --- /dev/null +++ b/xla/python/ifrt_proxy/server/BUILD @@ -0,0 +1,324 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = default_ifrt_proxy_visibility, +) + +cc_library( + name = "grpc_server", + srcs = ["grpc_server.cc"], + hdrs = ["grpc_server.h"], + deps = [ + ":grpc_service_impl", + ":host_buffer", + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:grpc_credentials", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_server_test", + srcs = ["grpc_server_test.cc"], + tags = ["no_aarch64"], # TODO(b/326080238): Fix this. + deps = [ + ":grpc_server", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "grpc_service_impl", + srcs = ["grpc_service_impl.cc"], + hdrs = ["grpc_service_impl.h"], + deps = [ + ":host_buffer", + ":ifrt_backend", + ":ifrt_session_handler", + ":version", + "//xla/pjrt/distributed:util", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + ], +) + +ifrt_proxy_cc_test( + name = "grpc_service_impl_test", + size = "small", + srcs = ["grpc_service_impl_test.cc"], + tags = ["no_aarch64"], # TODO(b/326080238): Fix this. + deps = [ + ":grpc_server", + ":grpc_service_impl", + ":host_buffer", + ":version", + "//xla/python/ifrt_proxy/client:grpc_host_buffer", + "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "ifrt_session_handler", + srcs = ["ifrt_session_handler.cc"], + hdrs = ["ifrt_session_handler.h"], + deps = [ + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "ifrt_session_handler_test", + srcs = ["ifrt_session_handler_test.cc"], + deps = [ + ":ifrt_backend", + ":ifrt_session_handler", + "//xla/python/ifrt", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + ], +) + +cc_library( + name = "ifrt_backend", + srcs = ["ifrt_backend.cc"], + hdrs = ["ifrt_backend.h"], + deps = [ + ":host_buffer", + ":host_callback", + ":version", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:array_util", + "//xla/python/ifrt_proxy/common:common_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:proto_util", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "//xla/python/pjrt_ifrt:xla_ifrt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status_to_from_proto", + "@tsl//tsl/platform:statusor", + ], +) + +ifrt_proxy_cc_test( + name = "ifrt_backend_test", + srcs = ["ifrt_backend_test.cc"], + deps = [ + ":host_buffer", + ":host_callback", + ":ifrt_backend", + ":version", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/pjrt:host_callback", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt:sharding_serdes", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:computation_placer_hdr", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:status_to_from_proto", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@tsl//tsl/protobuf:status_proto_cc", + ], +) + +cc_library( + name = "mock_ifrt_backend", + testonly = True, + hdrs = ["mock_ifrt_backend.h"], + deps = [ + ":ifrt_backend", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "host_buffer", + srcs = ["host_buffer.cc"], + hdrs = ["host_buffer.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "host_callback", + srcs = ["host_callback.cc"], + hdrs = ["host_callback.h"], + deps = [ + "//xla:shape_util", + "//xla/pjrt:host_callback", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:proto_util", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "version", + srcs = ["version.cc"], + hdrs = ["version.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +ifrt_proxy_cc_test( + name = "version_test", + srcs = ["version_test.cc"], + deps = [ + ":version", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + ], +) + +ifrt_proxy_cc_test( + name = "host_buffer_test", + srcs = ["host_buffer_test.cc"], + deps = [ + ":host_buffer", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + ], +) diff --git a/xla/python/ifrt_proxy/server/grpc_server.cc b/xla/python/ifrt_proxy/server/grpc_server.cc new file mode 100644 index 0000000000000..fbd3b6952eb10 --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_server.cc @@ -0,0 +1,99 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_server.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "grpc/grpc.h" +#include "grpcpp/completion_queue.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/common/grpc_credentials.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +GrpcServer::~GrpcServer() { + server_->Shutdown(); + server_->Wait(); +} + +absl::StatusOr> GrpcServer::Create( + absl::string_view address, + std::unique_ptr impl) { + if (impl == nullptr) { + return absl::InvalidArgumentError( + "Service implementation cannot be a nullptr."); + } + + ::grpc::ServerBuilder builder; + // Remove message size limit to accommodate large messages exchanged during + // model compilation. + builder.AddChannelArgument(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, -1); + builder.AddChannelArgument(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, -1); + builder.RegisterService(impl.get()); + builder.AddListeningPort(std::string(address), GetServerCredentials()); + auto server = builder.BuildAndStart(); + if (server == nullptr) { + return absl::UnavailableError( + absl::StrCat("Failed to initialize gRPC server at address:", address)); + } + + return absl::WrapUnique( + new GrpcServer(address, std::move(impl), std::move(server))); +} + +absl::StatusOr> +GrpcServer::CreateFromIfrtClientFactory( + absl::string_view address, + absl::AnyInvocable>()> + backend_ifrt_client_factory) { + if (backend_ifrt_client_factory == nullptr) { + return absl::InvalidArgumentError( + "backend_ifrt_client_factory cannot be nullptr."); + } + + auto service = std::make_unique( + [ifrt_client_factory = std::move(backend_ifrt_client_factory)]( + IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr host_buffer_store) mutable + -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN(auto ifrt_client, ifrt_client_factory()); + return IfrtBackend::Create(version, session_id, std::move(ifrt_client), + std::move(host_buffer_store)); + }); + + return Create(address, std::move(service)); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/grpc_server.h b/xla/python/ifrt_proxy/server/grpc_server.h new file mode 100644 index 0000000000000..d9bd31dcee376 --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_server.h @@ -0,0 +1,79 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "grpcpp/server.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Makes and runs a gRPC server with the given implementation and address. +// Destroying this object shuts down the underlying gRPC server, and so can +// block. +class GrpcServer { + public: + // The address parameter must be in the standard URI format - as needed by the + // ::grpc::ServerBuilder::AddListentingPort. See the ::grpc::ServerBuilder + // documentation for more details. + static absl::StatusOr> Create( + absl::string_view address, + std::unique_ptr impl); + + static absl::StatusOr> + CreateFromIfrtClientFactory( + absl::string_view address, + absl::AnyInvocable>()> + backend_ifrt_client_factory); + + // Starts shutting down the server and waits until it properly shuts down. + ~GrpcServer(); + + // Address this server is listening on. + std::string address() const { return address_; } + + // Blocks until the server shuts down. + void Wait() { server_->Wait(); } + + private: + GrpcServer(absl::string_view address, + std::unique_ptr impl, + std::unique_ptr<::grpc::Server> server) + : address_(address), impl_(std::move(impl)), server_(std::move(server)) {} + + const std::string address_; // Address this server is listening on. + + // Make sure that impl_ outlives the server_. + std::unique_ptr impl_; + std::unique_ptr<::grpc::Server> server_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVER_H_ diff --git a/xla/python/ifrt_proxy/server/grpc_server_test.cc b/xla/python/ifrt_proxy/server/grpc_server_test.cc new file mode 100644 index 0000000000000..40216bca7876c --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_server_test.cc @@ -0,0 +1,72 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_server.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// A fake IFRT service that fails all the Session creation attempts. +class FakeIfrtService : public grpc::GrpcIfrtService::Service {}; + +TEST(GrpcServerTest, CreationTest) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + auto grpc_service_impl = std::make_unique(); + ASSERT_THAT(GrpcServer::Create(addr, std::move(grpc_service_impl)), IsOk()); + // Also implicitly tests that the destruction of the GrpcServer object. +} + +TEST(GrpcServerTest, CreationFailsIfImplIsNullptr) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + EXPECT_THAT(GrpcServer::Create(addr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(GrpcServerTest, CreationFailsWithInvalidAddress) { + auto grpc_service_impl = std::make_unique(); + EXPECT_THAT(GrpcServer::Create(/*address=*/"invalid-address", + std::move(grpc_service_impl)), + Not(IsOk())); +} + +TEST(GrpcServerTest, RetrievingServerAddressWorks) { + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + auto grpc_service_impl = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto grpc_server, GrpcServer::Create(addr, std::move(grpc_service_impl))); + EXPECT_EQ(grpc_server->address(), addr); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/grpc_service_impl.cc b/xla/python/ifrt_proxy/server/grpc_service_impl.cc new file mode 100644 index 0000000000000..8f89d253affcb --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_service_impl.cc @@ -0,0 +1,241 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/pjrt/distributed/util.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" +#include "xla/python/ifrt_proxy/server/version.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +::grpc::Status GrpcServiceImpl::GetVersion(::grpc::ServerContext* context, + const GrpcGetVersionRequest* request, + GrpcGetVersionResponse* response) { + auto protocol_version = + ChooseVersion(request->min_version().protocol_version(), + request->max_version().protocol_version()); + if (!protocol_version.ok()) { + return xla::ToGrpcStatus(protocol_version.status()); + } + response->mutable_version()->set_protocol_version(*protocol_version); + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::IfrtSession( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter* stream) { + GrpcIfrtSessionMetadata metadata; + { + const auto it = context->client_metadata().find( + "ifrt-proxy-grpc-ifrt-session-metadata-bin"); + if (it == context->client_metadata().end()) { + return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, + "Missing metadata for GrpcIfrtService.IfrtSession: " + "ifrt-proxy-grpc-ifrt-session-metadata-bin"); + } + if (!metadata.ParseFromString(AsProtoStringData( + absl::string_view(it->second.data(), it->second.size())))) { + return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, + "Unable to parse GrpcIfrtSessionMetadata"); + } + } + + const uint64_t session_id = + next_session_id_.fetch_add(1, std::memory_order_relaxed); + + VLOG(0) << "Starting a new IFRT session with session_id=" << session_id; + + // Create a host buffer store for the session. + auto host_buffer_store = + std::make_shared(); + { + absl::MutexLock l(&host_buffer_store_mu_); + CHECK(host_buffer_stores_.insert({session_id, host_buffer_store}).second); + } + absl::Cleanup cleanup = [&] { + absl::MutexLock l(&host_buffer_store_mu_); + CHECK_GT(host_buffer_stores_.erase(session_id), 0); + }; + + absl::Mutex writer_mu; + + auto session_handler = IfrtSessionHandler::Create( + session_id, + [this, version = metadata.version(), + host_buffer_store = std::move(host_buffer_store)](uint64_t session_id) { + return backend_factory_(version, session_id, host_buffer_store); + }); + + if (!session_handler.ok()) { + LOG(INFO) << "Creating session " << session_id + << " failed: " << session_handler.status(); + return xla::ToGrpcStatus(session_handler.status()); + } + + bool first_request_read = false; + while (true) { + auto request = std::make_unique(); + if (!stream->Read(request.get())) { + break; + } + if (!first_request_read) { + VLOG(0) << "First request read for session " << session_id; + first_request_read = true; + } + (*session_handler) + ->NewIncomingRequest(std::move(request), + [&](std::shared_ptr response) { + absl::MutexLock l(&writer_mu); + stream->Write(*response); + }); + } + + VLOG(0) << "Finishing IFRT session " << session_id; + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::HostBufferStore( + ::grpc::ServerContext* context, + ::grpc::ServerReader* stream, + GrpcHostBufferStoreResponse* response) { + const auto it = context->client_metadata().find( + "ifrt-proxy-grpc-host-buffer-store-metadata-bin"); + if (it == context->client_metadata().end()) { + return ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + "Missing gRPC metadata for GrpcHostBufferService.Store"); + } + + GrpcHostBufferStoreMetadata metadata; + if (!metadata.ParseFromString(AsProtoStringData( + absl::string_view(it->second.data(), it->second.size())))) { + return ::grpc::Status(::grpc::StatusCode::DATA_LOSS, + "Unable to parse GrpcHostBufferStoreMetadata"); + } + + std::string data; + data.reserve(metadata.buffer_size()); + + GrpcHostBufferStoreRequest request; + while (stream->Read(&request)) { + data.append(request.data()); + } + if (data.size() != metadata.buffer_size()) { + return ::grpc::Status( + ::grpc::StatusCode::DATA_LOSS, + absl::StrCat("Potential data loss for host buffers: expected ", + metadata.buffer_size(), " bytes but got ", data.size(), + " bytes")); + } + + auto store = GetHostBufferStore(metadata.session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + return xla::ToGrpcStatus((*store)->Store(metadata.handle(), std::move(data))); +} + +::grpc::Status GrpcServiceImpl::HostBufferLookup( + ::grpc::ServerContext* context, const GrpcHostBufferLookupRequest* request, + ::grpc::ServerWriter* stream) { + static constexpr int64_t kChunkSize = 1024 * 1024; + + auto store = GetHostBufferStore(request->session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + auto data = (*store)->Lookup(request->handle()); + if (!data.ok()) { + return xla::ToGrpcStatus(data.status()); + } + + GrpcHostBufferLookupResponse response; + if (!(*data)->empty()) { + for (int64_t offset = 0; offset < (*data)->size(); offset += kChunkSize) { +#if defined(PLATFORM_GOOGLE) + response.set_alias_data( + absl::string_view(**data).substr(offset, kChunkSize)); +#else + // TODO(b/325306748): Find a way to not do a memory-copy. + response.set_data((*data)->substr(offset, kChunkSize)); +#endif + stream->Write(response); + response.Clear(); + } + } else { + // Send at least one response even if the buffer is empty. + stream->Write(response); + } + + return ::grpc::Status::OK; +} + +::grpc::Status GrpcServiceImpl::HostBufferDelete( + ::grpc::ServerContext* context, const GrpcHostBufferDeleteRequest* request, + GrpcHostBufferDeleteResponse* response) { + auto store = GetHostBufferStore(request->session_id()); + if (!store.ok()) { + return xla::ToGrpcStatus(store.status()); + } + return xla::ToGrpcStatus((*store)->Delete(request->handle())); +} + +bool GrpcServiceImpl::Test_InsertHostBufferStore( + uint64_t session_id, + std::shared_ptr store) { + absl::MutexLock l(&host_buffer_store_mu_); + return host_buffer_stores_.insert({session_id, std::move(store)}).second; +} + +bool GrpcServiceImpl::Test_DeleteHostBufferStore(uint64_t session_id) { + absl::MutexLock l(&host_buffer_store_mu_); + return host_buffer_stores_.erase(session_id) > 0; +} + +absl::StatusOr> +GrpcServiceImpl::GetHostBufferStore(uint64_t session_id) { + absl::MutexLock l(&host_buffer_store_mu_); + const auto it = host_buffer_stores_.find(session_id); + if (it == host_buffer_stores_.end()) { + return absl::NotFoundError( + absl::StrCat("Session id ", session_id, " does not exist")); + } + return it->second; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/grpc_service_impl.h b/xla/python/ifrt_proxy/server/grpc_service_impl.h new file mode 100644 index 0000000000000..c75709b4e6ff9 --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_service_impl.h @@ -0,0 +1,107 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/die_if_null.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "grpcpp/support/sync_stream.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Implementation for `GrpcIfrtService`. +class GrpcServiceImpl : public grpc::GrpcIfrtService::Service { + public: + using BackendFactory = + absl::AnyInvocable>( + IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr + host_buffer_store)>; + + explicit GrpcServiceImpl(BackendFactory backend_factory) + : backend_factory_(ABSL_DIE_IF_NULL(std::move(backend_factory))) {} + + ::grpc::Status GetVersion(::grpc::ServerContext* context, + const GrpcGetVersionRequest* request, + GrpcGetVersionResponse* response) override; + + ::grpc::Status IfrtSession( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter* stream) override; + + ::grpc::Status HostBufferStore( + ::grpc::ServerContext* context, + ::grpc::ServerReader* stream, + GrpcHostBufferStoreResponse* response) override; + + ::grpc::Status HostBufferLookup( + ::grpc::ServerContext* context, + const GrpcHostBufferLookupRequest* request, + ::grpc::ServerWriter* stream) override; + + ::grpc::Status HostBufferDelete( + ::grpc::ServerContext* context, + const GrpcHostBufferDeleteRequest* request, + GrpcHostBufferDeleteResponse* response) override; + + // Test-only method that adds a new session in the host buffer store map. + // Returns false if the session id already exists. + bool Test_InsertHostBufferStore( + uint64_t session_id, + std::shared_ptr store); + + // Test-only method that removes the given session id from the host buffer + // store map. Returns false if the session id does not exist. + bool Test_DeleteHostBufferStore(uint64_t session_id); + + private: + absl::StatusOr> + GetHostBufferStore(uint64_t session_id) + ABSL_LOCKS_EXCLUDED(host_buffer_store_mu_); + + BackendFactory backend_factory_; + std::atomic next_session_id_ = 1; + + absl::Mutex host_buffer_store_mu_; + absl::flat_hash_map> + host_buffer_stores_ ABSL_GUARDED_BY(host_buffer_store_mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_GRPC_SERVICE_IMPL_H_ diff --git a/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc b/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc new file mode 100644 index 0000000000000..2f8f553794dc4 --- /dev/null +++ b/xla/python/ifrt_proxy/server/grpc_service_impl_test.cc @@ -0,0 +1,184 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/grpc_service_impl.h" + +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/support/channel_arguments.h" +#include "grpcpp/support/status.h" +#include "xla/python/ifrt_proxy/client/grpc_host_buffer.h" +#include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/grpc_server.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kServerMaxVersion); + return version; +} + +// Sets up fresh GrpcServer for testing. +absl::StatusOr> MakeGrpcServer() { + // TODO(b/282993619): For external/GKE uses, we may need to find (or build) + // a utility function that works similar to PickUnusedPortorDie(). + auto addr = absl::StrCat("[::1]:", tsl::testing::PickUnusedPortOrDie()); + return GrpcServer::CreateFromIfrtClientFactory(addr, []() { + return absl::UnimplementedError( + "IFRT client creation fails. This test is not expected to " + "instantiate any IFRT client"); + }); +} + +TEST(GrpcServiceImplTest, CanBeUsedToSetupAnGrpcServer) { + ASSERT_THAT(MakeGrpcServer(), IsOk()); + // Also implicitly tests that destruction of both the server and the + // implementation objects. +} + +class GrpcIfrtServiceImplHostBufferTest + : public testing::TestWithParam { + protected: + GrpcIfrtServiceImplHostBufferTest() + : impl_([](IfrtProxyVersion version, uint64_t session_id, + std::shared_ptr host_buffer_store) { + return absl::UnimplementedError( + "IFRT backend creation is not implemented"); + }) { + ::grpc::ServerBuilder builder; + builder.RegisterService(&impl_); + server_ = builder.BuildAndStart(); + + stub_ = grpc::GrpcIfrtService::NewStub( + server_->InProcessChannel(::grpc::ChannelArguments())); + } + + // Returns a string to be stored as a host buffer. The length is parameterized + // so that we can test chunking. + std::string GetTestData() const { + std::string data; + for (int i = 0; i < GetParam(); ++i) { + data.push_back(i % 7); + } + return data; + } + + GrpcServiceImpl impl_; + std::unique_ptr<::grpc::Server> server_; + std::shared_ptr stub_; +}; + +TEST_P(GrpcIfrtServiceImplHostBufferTest, StoreAndLookupStringView) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + absl::string_view source(data); + + ASSERT_THAT(client.Store(kHandle, source).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, StoreAndLookupCord) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + + absl::Cord source(data); + ASSERT_THAT(client.Store(kHandle, source).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, Lookup) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + ASSERT_THAT(store->Store(kHandle, data), IsOk()); + + EXPECT_THAT(client.Lookup(kHandle).Await(), IsOkAndHolds(data)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +TEST_P(GrpcIfrtServiceImplHostBufferTest, Delete) { + static constexpr uint64_t kSessionId = 1; + + auto store = std::make_shared(); + ASSERT_TRUE(impl_.Test_InsertHostBufferStore(kSessionId, store)); + GrpcClientHostBufferStore client(stub_, Version(), kSessionId); + + constexpr uint64_t kHandle = 2; + const std::string data = GetTestData(); + ASSERT_THAT(store->Store(kHandle, data), IsOk()); + + ASSERT_THAT(client.Delete(kHandle).Await(), IsOk()); + EXPECT_THAT(client.Lookup(kHandle).Await(), + StatusIs(absl::StatusCode::kNotFound)); + + EXPECT_TRUE(impl_.Test_DeleteHostBufferStore(kSessionId)); +} + +INSTANTIATE_TEST_SUITE_P( + DataSize, GrpcIfrtServiceImplHostBufferTest, + testing::Values(0, // Empty host buffer. + 16, // Small enough to fit in one chunk. + 3 * 1024 * 1024)); // Requires multiple chunks + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/host_buffer.cc b/xla/python/ifrt_proxy/server/host_buffer.cc new file mode 100644 index 0000000000000..4b9dd7391ec81 --- /dev/null +++ b/xla/python/ifrt_proxy/server/host_buffer.cc @@ -0,0 +1,65 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_buffer.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::Status HostBufferStore::Store(uint64_t handle, std::string data) { + absl::MutexLock lock(&mu_); + const bool inserted = + buffers_.insert({handle, std::make_shared(std::move(data))}) + .second; + if (!inserted) { + return absl::AlreadyExistsError( + absl::StrCat("Host buffer handle ", handle, " already exists")); + } + return absl::OkStatus(); +} + +absl::StatusOr> HostBufferStore::Lookup( + uint64_t handle) { + absl::MutexLock lock(&mu_); + const auto it = buffers_.find(handle); + if (it == buffers_.end()) { + return absl::NotFoundError( + absl::StrCat("Host buffer handle ", handle, " not found")); + } + return it->second; +} + +absl::Status HostBufferStore::Delete(uint64_t handle) { + absl::MutexLock lock(&mu_); + if (buffers_.erase(handle) == 0) { + return absl::NotFoundError( + absl::StrCat("Host buffer handle ", handle, " not found")); + } + return absl::OkStatus(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/host_buffer.h b/xla/python/ifrt_proxy/server/host_buffer.h new file mode 100644 index 0000000000000..f9b07a40f30e9 --- /dev/null +++ b/xla/python/ifrt_proxy/server/host_buffer.h @@ -0,0 +1,61 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Keeps host buffers transferred from the client so that `IfrtBackend` can +// access them when requests with pointers to host buffers arrive. +// +// We expect one `HostBufferStore` to exist per session (i.e., per `IfrtBackend` +// instance) so that host buffers are cleaned up on session termination. +class HostBufferStore { + public: + // Stores the data associated with the given handle. Returns an error if the + // handle already exists. + absl::Status Store(uint64_t handle, std::string data); + + // Retrieves the data associated with the handle. Returns an error if the + // handle does not exist. + absl::StatusOr> Lookup(uint64_t handle); + + // Deletes the host buffer associated with the handle. Returns an error if the + // handle does not exist. + absl::Status Delete(uint64_t handle); + + private: + absl::Mutex mu_; + absl::flat_hash_map> buffers_ + ABSL_GUARDED_BY(mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_HOST_BUFFER_H_ diff --git a/xla/python/ifrt_proxy/server/host_buffer_test.cc b/xla/python/ifrt_proxy/server/host_buffer_test.cc new file mode 100644 index 0000000000000..7adc31658dda3 --- /dev/null +++ b/xla/python/ifrt_proxy/server/host_buffer_test.cc @@ -0,0 +1,57 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_buffer.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Pointee; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +TEST(HostBufferStoreTest, ReadAfterWrite) { + HostBufferStore store; + const uint64_t kHandle = 1; + + ASSERT_THAT(store.Store(kHandle, "foo"), IsOk()); + EXPECT_THAT(store.Lookup(kHandle), IsOkAndHolds(Pointee(std::string("foo")))); + + ASSERT_THAT(store.Delete(kHandle), IsOk()); + EXPECT_THAT(store.Lookup(kHandle), StatusIs(absl::StatusCode::kNotFound)); +} + +TEST(HostBufferStoreTest, UnknownHandle) { + HostBufferStore store; + const uint64_t kHandle = 1; + + EXPECT_THAT(store.Lookup(kHandle), StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(store.Delete(kHandle), StatusIs(absl::StatusCode::kNotFound)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/host_callback.cc b/xla/python/ifrt_proxy/server/host_callback.cc new file mode 100644 index 0000000000000..43e13700c293b --- /dev/null +++ b/xla/python/ifrt_proxy/server/host_callback.cc @@ -0,0 +1,195 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/host_callback.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/functional/bind_front.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +RemoteLoadedHostCallbackQueue::~RemoteLoadedHostCallbackQueue() { Close(); } + +absl::Status RemoteLoadedHostCallbackQueue::Push(ExecutionRequest request) { + absl::MutexLock l(&mu_); + if (closed_) { + return absl::CancelledError( + "RemoteLoadedHostCallback has stopped accepting new execution " + "requests"); + } + requests_.push_back(std::move(request)); + return absl::OkStatus(); +} + +std::optional +RemoteLoadedHostCallbackQueue::Pop() { + auto not_empty = [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return !requests_.empty() || closed_; + }; + absl::MutexLock l(&mu_, absl::Condition(¬_empty)); + if (closed_) { + return std::nullopt; + } + ExecutionRequest request = std::move(requests_.front()); + requests_.pop_front(); + return request; +} + +void RemoteLoadedHostCallbackQueue::Close() { + std::deque requests; + { + absl::MutexLock l(&mu_); + if (!closed_) { + requests.swap(requests_); + } + closed_ = true; + } + for (auto& request : requests) { + request.status.Set(absl::CancelledError( + "RemoteLoadedHostCallback execution has been cancelled")); + } +} + +absl::StatusOr> +RemoteLoadedHostCallback::CreateFromSerialized( + xla::ifrt::Client* client, absl::string_view serialized, + std::shared_ptr queue) { + xla::ifrt::XlaHostCallbackProto proto; + if (!proto.ParseFromString(AsProtoStringData(serialized))) { + return absl::DataLossError( + "Unable to deserialize RemoteLoadedHostCallback"); + } + + auto from_proto = + [](const auto& arg_protos) -> std::vector { + std::vector args; + args.reserve(arg_protos.size()); + for (const xla::ifrt::XlaHostCallbackProto::ArgInfo& arg_proto : + arg_protos) { + xla::HostCallbackArgInfo& arg = args.emplace_back(); + arg.channel_id = static_cast(arg_proto.channel_id()); + arg.shape = xla::Shape(arg_proto.shape()); + } + return args; + }; + + return tsl::MakeRef( + client, from_proto(proto.operands()), from_proto(proto.results()), + std::move(queue)); +} + +RemoteLoadedHostCallback::RemoteLoadedHostCallback( + xla::ifrt::Client* client, std::vector operands, + std::vector results, + std::shared_ptr queue) + : llvm::RTTIExtends( + client, + [&]() { + auto xla_host_callback = std::make_unique(); + xla_host_callback->operands = std::move(operands); + xla_host_callback->results = std::move(results); + xla_host_callback->callback = + absl::bind_front(&RemoteLoadedHostCallback::Execute, this); + return xla_host_callback; + }()), + queue_(std::move(queue)) {} + +RemoteLoadedHostCallback::~RemoteLoadedHostCallback() { + if (queue_ != nullptr) { + queue_->Close(); + } +} + +absl::Status RemoteLoadedHostCallback::Execute(void** result_ptrs, + void** operand_ptrs) { + if (queue_ == nullptr) { + return absl::FailedPreconditionError( + "RemoteLoadedHostCallback without queue cannot be executed"); + } + + RemoteLoadedHostCallbackQueue::ExecutionRequest request; + + auto to_buffer = + [&](absl::Span args, void** ptrs, + std::vector& buffers) { + buffers.reserve(args.size()); + for (int i = 0; i < args.size(); ++i) { + const int64_t size = xla::ShapeUtil::ByteSizeOf(args[i].shape); + buffers.push_back( + RemoteLoadedHostCallbackQueue::Buffer{ptrs[i], size}); + } + }; + to_buffer(host_callback().operands, operand_ptrs, request.operands); + to_buffer(host_callback().results, result_ptrs, request.results); + + request.status = Future::CreatePromise(); + Future status(request.status); + + // Enqueue the execution request. `IfrtBackend` retrieves this by calling + // `PopExecutionRequest` and fulfills the `results` promise. + TF_RETURN_IF_ERROR(queue_->Push(std::move(request))); + + // Block until the execution finishes and return its status. + return status.Await(); +} + +absl::StatusOr RemoteLoadedHostCallback::Serialize() const { + xla::ifrt::XlaHostCallbackProto proto; + + auto to_proto = [](absl::Span args, + auto* args_proto) { + args_proto->Reserve(args.size()); + for (const auto& arg : args) { + auto* arg_proto = args_proto->Add(); + arg_proto->set_channel_id(arg.channel_id); + *arg_proto->mutable_shape() = arg.shape.ToProto(); + } + }; + to_proto(host_callback().operands, proto.mutable_operands()); + to_proto(host_callback().results, proto.mutable_results()); + + return proto.SerializeAsString(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/host_callback.h b/xla/python/ifrt_proxy/server/host_callback.h new file mode 100644 index 0000000000000..e2d6ea834e7d6 --- /dev/null +++ b/xla/python/ifrt_proxy/server/host_callback.h @@ -0,0 +1,126 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// Command queue interface between `RemoteLoadedHostCallback` and `IfrtBackend`. +// Responsible for keeping track of in-flight execution requests. +class RemoteLoadedHostCallbackQueue { + public: + struct Buffer { + void* data; + int64_t size; + }; + + // Encapsulates a host buffer execution. Operand and result buffers are + // pre-allocated and the caller is expected to fill them in-place before + // fulfilling the `status` promise. + struct ExecutionRequest { + std::vector operands; + std::vector results; + Future::Promise status; + }; + + ~RemoteLoadedHostCallbackQueue(); + + // Pushes a new execution request to the queue. Returns an error if the queue + // has already been closed. + absl::Status Push(ExecutionRequest request); + + // Blocks until this host callback queue has at least one pending execution + // and returns its information needed to perform execution. Returns nullopt if + // the request queue has already been closed by `Close()`. + std::optional Pop(); + + // Closes this request queue. After this call, all pending executions are + // unblocked with an error and no more executions can be enqueued. + void Close(); + + private: + absl::Mutex mu_; + bool closed_ ABSL_GUARDED_BY(mu_) = false; + std::deque requests_ ABSL_GUARDED_BY(mu_); +}; + +// Host callback that delegates its execution to an external executor. The +// executor waits for execution requests to be enqueued to the given +// `RemoteLoadedHostCallbackQueue` and returns results after execution by +// fulfilling the returned promise. +// +// This class is thread-safe. +// +// Note: The current implementation inherits from PjRt's host callback +// implementation. Even though this is a violation of the IFRT proxy's layering +// principle, it is unavoidable right now because the base `LoadedHostCallback` +// in IFRT has no associated execution semantics. For now, the IFRT proxy +// focuses on supporting host callbacks on PjRt-like IFRT implementations. +class RemoteLoadedHostCallback + : public llvm::RTTIExtends { + public: + // Creates from a serialized string returned by `Serialize()`. + static absl::StatusOr> + CreateFromSerialized(xla::ifrt::Client* client, absl::string_view serialized, + std::shared_ptr queue); + + // Create from operand/result specs. + RemoteLoadedHostCallback( + xla::ifrt::Client* client, std::vector operands, + std::vector results, + std::shared_ptr queue); + + ~RemoteLoadedHostCallback() override; + + // Serializes the remote host callback instance. The returned string can be + // deserialized into `RmeoteLoadedHostCallback` using `CreateFromSerialized`. + absl::StatusOr Serialize() const override; + + private: + // Implements the interface required by `xla::HostCallback`. + absl::Status Execute(void** result_ptrs, void** operand_ptrs); + + std::shared_ptr queue_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_HOST_CALLBACK_H_ diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc new file mode 100644 index 0000000000000..2758b714d8fa8 --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -0,0 +1,1206 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/common/array_util.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +// Convenient wrapper for `xla::ifrt::Deserialize()`. +template +absl::StatusOr> Deserialize( + const Serialized& serialized, + std::unique_ptr options = nullptr) { + TF_ASSIGN_OR_RETURN(auto deserialized, + Deserialize(serialized, std::move(options))); + auto obj = absl::WrapUnique(llvm::dyn_cast(deserialized.release())); + if (obj == nullptr) { + return absl::InvalidArgumentError("Deserialization type mismatch"); + } + return obj; +} + +} // namespace + +IfrtBackend::IfrtBackend(IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store) + : version_(std::move(version)), + session_id_(session_id), + client_(std::move(ifrt_client)), + host_buffer_store_(std::move(host_buffer_store)), + compile_thread_pool_( + tsl::Env::Default(), + []() { + tsl::ThreadOptions options; + // Use a larger stack size since XLA often requires larger stacks + // for compilation. + options.stack_size = 240 * 1024; + return options; + }(), + "IfrtBackend", + // TODO(b/282757875): Consider making this configurable. + /*num_threads=*/32) {} + +absl::StatusOr> IfrtBackend::Create( + IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store) { + if (ifrt_client == nullptr) { + return absl::InvalidArgumentError("ifrt_client cannot be a nullptr."); + } + if (version.protocol_version() < kServerMinVersion || + version.protocol_version() > kServerMaxVersion) { + return absl::FailedPreconditionError(absl::StrCat( + "Protocol version ", version.protocol_version(), + " is unsupported by IFRT Proxy server; supported versions: [", + kServerMinVersion, ",", kServerMaxVersion, "]")); + } + return absl::WrapUnique( + new IfrtBackend(std::move(version), session_id, std::move(ifrt_client), + std::move(host_buffer_store))); +} + +IfrtBackend::~IfrtBackend() { + // Cancel all in-flight host callback executions. + absl::flat_hash_map + host_callback_executions; + { + absl::MutexLock lock(&host_callback_executions_mutex_); + host_callback_executions.swap(host_callback_executions_); + } + for (auto& [handle, execution_request] : host_callback_executions) { + std::move(execution_request) + .status.Set(absl::CancelledError("IFRT backend has shut down")); + } + + // Wait until all async work from `AsyncExecute` finishes execution. + { + auto done = [this]() ABSL_SHARED_LOCKS_REQUIRED(in_flight_count_mutex_) { + return in_flight_count_ == 0; + }; + absl::MutexLock lock(&in_flight_count_mutex_, absl::Condition(&done)); + } +} + +Future IfrtBackend::Process( + std::unique_ptr request) { + switch (request->request_case()) { + case IfrtRequest::RequestCase::kInitRequest: + return Future(HandleInit(std::move(request))); + case IfrtRequest::RequestCase::kCheckFutureRequest: + return HandleCheckFutureRequest(std::move(request)); + case IfrtRequest::RequestCase::kMakeArrayFromHostBufferRequest: + return Future( + HandleMakeArrayFromHostBufferRequest(std::move(request))); + case IfrtRequest::RequestCase::kAssembleArrayFromSingleDeviceArraysRequest: + return Future( + HandleAssembleArrayFromSingleDeviceArraysRequest(std::move(request))); + case IfrtRequest::RequestCase::kCopyToHostBufferRequest: + return HandleCopyToHostBufferRequest(std::move(request)); + case IfrtRequest::RequestCase::kDisassembleIntoSingleDeviceArraysRequest: + return Future( + HandleDisassembleIntoSingleDeviceArraysRequest(std::move(request))); + case IfrtRequest::RequestCase::kCheckArrayReadyRequest: + return Future(HandleCheckArrayReadyRequest(std::move(request))); + case IfrtRequest::RequestCase::kReshardRequest: + return Future(HandleReshardRequest(std::move(request))); + case IfrtRequest::RequestCase::kFullyReplicatedShardRequest: + return Future( + HandleFullyReplicatedShardRequest(std::move(request))); + case IfrtRequest::RequestCase::kDeleteArrayRequest: + return Future(HandleDeleteArrayRequest(std::move(request))); + case IfrtRequest::RequestCase::kIsArrayDeletedRequest: + return Future(HandleIsArrayDeletedRequest(std::move(request))); + case IfrtRequest::RequestCase::kDestructArrayRequest: + return Future(HandleDestructArrayRequest(std::move(request))); + case IfrtRequest::RequestCase::kCompileRequest: + return Future(HandleCompileRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableMetadataRequest: + return HandleLoadedExecutableMetadataRequest(std::move(request)); + case IfrtRequest::RequestCase::kLoadedExecutableExecuteRequest: + return Future( + HandleLoadedExecutableExecuteRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableDeleteRequest: + return Future( + HandleLoadedExecutableDeleteRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableIsDeletedRequest: + return Future( + HandleLoadedExecutableIsDeletedRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedExecutableDestructRequest: + return Future( + HandleLoadedExecutableDestructRequest(std::move(request))); + case IfrtRequest::RequestCase::kLoadedHostCallbackPollRequest: + return HandleLoadedHostCallbackPollRequest(std::move(request)); + case IfrtRequest::RequestCase::kLoadedHostCallbackReturnRequest: + return Future( + HandleLoadedHostCallbackReturnRequest(std::move(request))); + case IfrtRequest::RequestCase::kGetDefaultDeviceAssignmentRequest: + return Future( + HandleGetDefaultDeviceAssignmentRequest(std::move(request))); + default: + return Future(absl::UnimplementedError(absl::StrCat( + "Got unimplemented request type: ", request->request_case()))); + } +} + +uint64_t IfrtBackend::HandleGenerator::New() { + absl::MutexLock lock(&mu_); + return current_++; +} + +void IfrtBackend::HandleGenerator::BulkNew(absl::Span handles) { + absl::MutexLock lock(&mu_); + std::iota(handles.begin(), handles.end(), current_); + current_ += handles.size(); +} + +Future IfrtBackend::AsyncExecute( + std::function handle_fn, tsl::thread::ThreadPool* thread_pool) { + { + absl::MutexLock lock(&in_flight_count_mutex_); + ++in_flight_count_; + } + auto promise = Future::CreatePromise(); + auto f = [this, promise, handle_fn = std::move(handle_fn)]() mutable { + promise.Set(handle_fn()); + { + absl::MutexLock lock(&in_flight_count_mutex_); + --in_flight_count_; + } + }; + if (thread_pool != nullptr) { + thread_pool->Schedule(std::move(f)); + } else { + tsl::Env::Default()->SchedClosure(std::move(f)); + } + return Future(std::move(promise)); +} + +///////////////////////////////////////////////////////////////////////////// +// +// Handlers for individual request types +// + +BackendInterface::Response IfrtBackend::HandleInit( + std::unique_ptr request) { + std::unique_ptr response = + NewIfrtResponse(request->request_metadata().op_id()); + auto* init_resp = response->mutable_init_response(); + init_resp->set_session_id(session_id_); + init_resp->set_platform_name(AsProtoStringData(client_->platform_name())); + init_resp->set_platform_version( + AsProtoStringData(client_->platform_version())); + init_resp->set_platform_id(client_->platform_id()); + init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type())); + init_resp->set_process_index(client_->process_index()); + + for (auto* device : client_->devices()) { + InitResponse::Device* d = init_resp->add_devices(); + d->set_id(device->id()); + d->set_local_device_id(device->local_device_id().value()); + d->set_local_hardware_id(device->local_hardware_id_typed().value()); + d->set_device_kind(AsProtoStringData(device->device_kind())); + if (auto default_memory_space = device->default_memory_space(); + default_memory_space.ok()) { + d->set_default_memory_id((*default_memory_space)->id()); + } + for (const auto* memory : device->memory_spaces()) { + d->add_memory_ids(memory->id()); + } + d->set_debug_string(AsProtoStringData(device->DebugString())); + d->set_to_string(AsProtoStringData(device->ToString())); + for (const auto& [name, attr] : device->Attributes()) { + TF_ASSIGN_OR_RETURN((*d->mutable_attributes())[name], + ToVariantProto(attr)); + } + } + for (auto* addressable_device : client_->addressable_devices()) { + init_resp->add_addressable_device_ids(addressable_device->id()); + } + + absl::flat_hash_map memories; + for (auto* device : client_->devices()) { + for (xla::ifrt::Memory* memory : device->memory_spaces()) { + const auto [it, inserted] = memories.insert({memory->id(), memory}); + if (!inserted && it->second != memory) { + return absl::FailedPreconditionError(absl::StrCat( + "Two memories cannot have the same id: ", memory->ToString(), + " vs. ", it->second->ToString())); + } + } + } + for (const auto& [id, memory] : memories) { + auto* m = init_resp->add_memories(); + m->set_id(id); + m->set_memory_space_kind(AsProtoStringData(memory->memory_space_kind())); + for (const auto* device : memory->devices()) { + m->add_device_ids(device->id()); + } + m->set_debug_string(AsProtoStringData(memory->DebugString())); + m->set_to_string(AsProtoStringData(memory->ToString())); + } + + return response; +} + +Future IfrtBackend::HandleCheckFutureRequest( + std::unique_ptr request) { + const CheckFutureRequest& check_request = request->check_future_request(); + + Future future; + { + absl::MutexLock lock(&futures_mutex_); + const auto it = futures_.find(check_request.future_handle()); + if (it == futures_.end()) { + return Future(absl::NotFoundError(absl::StrCat( + "Unknown future handle: ", check_request.future_handle()))); + } + future = std::move(it->second); + futures_.erase(it); + } + + auto promise = Future::CreatePromise(); + // With PjRtFuture, the `Future` needs to be owned by one or more owners until + // `OnReady()`'s lambda gets executed. So, capture a copy of `future` in the + // lambda, making the lambda itself an owner of `future`. + future.OnReady([op_id = request->request_metadata().op_id(), promise, + hold = future](absl::Status status) mutable { + if (!status.ok()) { + promise.Set(std::move(status)); + return; + } + auto ifrt_resp = NewIfrtResponse(op_id); + ifrt_resp->mutable_check_future_response(); + promise.Set(std::move(ifrt_resp)); + }); + + return Future(std::move(promise)); +} + +BackendInterface::Response IfrtBackend::HandleMakeArrayFromHostBufferRequest( + std::unique_ptr request) { + if (!request->has_make_array_from_host_buffer_request()) { + return absl::InternalError( + "MakeArrayFromHostBuffer got an IfrtRequest with no " + "MakeArrayFromHostBufferRequest in it."); + } + auto* make_array_request = + request->mutable_make_array_from_host_buffer_request(); + + TF_ASSIGN_OR_RETURN( + auto sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + make_array_request->sharding())); + + const auto byte_strides = [&]() -> std::optional> { + if (!make_array_request->has_byte_strides()) return std::nullopt; + return FromByteStridesProto(make_array_request->byte_strides()); + }(); + TF_ASSIGN_OR_RETURN(const auto shape, + Shape::FromProto(make_array_request->shape())); + TF_ASSIGN_OR_RETURN(const auto dtype, + DType::FromProto(make_array_request->dtype())); + + const uint64_t host_buffer_handle = make_array_request->host_buffer_handle(); + absl::Cleanup cleanup = [&] { + CHECK_OK(host_buffer_store_->Delete(host_buffer_handle)); + }; + TF_ASSIGN_OR_RETURN(std::shared_ptr host_buffer, + host_buffer_store_->Lookup(host_buffer_handle)); + std::move(cleanup).Invoke(); + + TF_ASSIGN_OR_RETURN(const auto mem_region, + ArrayMemRegion::FromMinimalMemRegion( + *host_buffer, dtype, shape, byte_strides)); + + TF_ASSIGN_OR_RETURN( + auto array, + client_->MakeArrayFromHostBuffer( + mem_region.zeroth_element(), dtype, std::move(shape), + std::move(byte_strides), std::move(sharding), + xla::ifrt::Client::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [hold = std::move(host_buffer)]() mutable { hold.reset(); })); + + // TODO(b/282757875): Consider merging the handle_generator with the + // arrays_. + uint64_t handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({handle, std::move(array)}); + } + + std::unique_ptr response = + NewIfrtResponse(request->request_metadata().op_id()); + auto* make_array_resp = + response->mutable_make_array_from_host_buffer_response(); + make_array_resp->set_array_handle(handle); + + return response; +} + +BackendInterface::Response +IfrtBackend::HandleAssembleArrayFromSingleDeviceArraysRequest( + std::unique_ptr request) { + const auto& assemble_request = + request->assemble_array_from_single_device_arrays_request(); + + std::vector> arrays; + { + absl::ReaderMutexLock lock(&arrays_mutex_); + for (const uint64_t handle : + assemble_request.single_device_array_handles()) { + TF_ASSIGN_OR_RETURN(arrays.emplace_back(), GetArrayLocked(handle)); + } + } + + TF_ASSIGN_OR_RETURN(Shape shape, Shape::FromProto(assemble_request.shape())); + TF_ASSIGN_OR_RETURN( + auto sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + assemble_request.sharding())); + TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( + assemble_request.copy_semantics())); + + TF_ASSIGN_OR_RETURN(auto array, client_->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(sharding), + absl::MakeSpan(arrays), semantics)); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + uint64_t handle = handle_generator_.New(); + ifrt_resp->mutable_assemble_array_from_single_device_arrays_response() + ->set_array_handle(handle); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({handle, std::move(array)}); + } + + return ifrt_resp; +} + +Future IfrtBackend::HandleCopyToHostBufferRequest( + std::unique_ptr request) { + const CopyToHostBufferRequest& copy_to_host = + request->copy_to_host_buffer_request(); + + auto array = GetArray(copy_to_host.array_handle()); + if (!array.ok()) { + return Future(array.status()); + } + + // Determine the size and allocate the host buffer. + // TODO(b/282757875): We may need to redo this to account for byte_strides, + // padding, and alignment requirements. + std::optional element_size = (*array)->dtype().byte_size(); + if (element_size == std::nullopt) { + return Future( + absl::InternalError("Array element size is unknown.")); + } + int64_t host_buffer_size = + (*array)->shape().num_elements() * element_size.value(); + // Use `std::unique_ptr` for pointer stability. + auto host_buffer = std::make_unique(); + host_buffer->resize(host_buffer_size); + + const auto byte_strides = [&]() -> std::optional> { + if (!copy_to_host.has_byte_strides()) { + return std::nullopt; + } + return FromByteStridesProto(copy_to_host.byte_strides()); + }(); + const auto mem_region = ArrayMemRegion::FromMinimalMemRegion( + absl::string_view(*host_buffer), (*array)->dtype(), (*array)->shape(), + byte_strides); + if (!mem_region.ok()) { + return Future(mem_region.status()); + } + + // TODO(b/282757875): Consider other ArrayCopySemantics. + Future copy_status = + (*array)->CopyToHostBuffer(mem_region->zeroth_element(), byte_strides, + ArrayCopySemantics::kAlwaysCopy); + + auto resp_promise = Future::CreatePromise(); + Future resp_future(resp_promise); + auto on_ready = [this, op_id = request->request_metadata().op_id(), + host_buffer = std::move(host_buffer), + host_buffer_handle = copy_to_host.host_buffer_handle()]( + absl::Status status) mutable + -> absl::StatusOr> { + TF_RETURN_IF_ERROR(status); + + TF_RETURN_IF_ERROR( + host_buffer_store_->Store(host_buffer_handle, *std::move(host_buffer))); + + std::unique_ptr response = NewIfrtResponse(op_id); + response->mutable_copy_to_host_buffer_response(); + return response; + }; + copy_status.OnReady( + [promise = std::move(resp_promise), on_ready = std::move(on_ready)]( + absl::Status status) mutable { promise.Set(on_ready(status)); }); + + return resp_future; +} + +BackendInterface::Response +IfrtBackend::HandleDisassembleIntoSingleDeviceArraysRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN( + auto array, + GetArray(request->disassemble_into_single_device_arrays_request() + .array_handle())); + + // TODO(b/282757875): Consider other ArrayCopySemantics. + TF_ASSIGN_OR_RETURN(auto single_device_arrays, + array->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); + + // Set up an IfrtResponse with pre-allocated space for the right number of + // single device array handles. + int64_t num_arrays = single_device_arrays.size(); + auto response = NewIfrtResponse(request->request_metadata().op_id()); + + // Pre-allocate space in the response proto and fill it in with bulk allocated + // new handles. + auto* handles = + response->mutable_disassemble_into_single_device_arrays_response() + ->mutable_single_device_array_handles(); + handles->Reserve(num_arrays); + uint64_t* handles_buf = handles->AddNAlreadyReserved(num_arrays); + handle_generator_.BulkNew(absl::MakeSpan(handles_buf, num_arrays)); + + // Install the newly created arrays into the arrays_. + { + absl::MutexLock lock(&arrays_mutex_); + for (int i = 0; i < num_arrays; ++i) { + arrays_.insert({handles_buf[i], single_device_arrays[i]}); + } + } + + return response; +} + +Future IfrtBackend::HandleCheckArrayReadyRequest( + std::unique_ptr request) { + auto array = GetArray(request->check_array_ready_request().array_handle()); + if (!array.ok()) { + return Future(array.status()); + } + + auto ifrt_response_promise = + Future::CreatePromise(); + Future ifrt_response_future( + ifrt_response_promise); + + (*array)->GetReadyFuture().OnReady( + [op_id = request->request_metadata().op_id(), + promise = std::move(ifrt_response_promise)]( + absl::Status status) mutable -> void { + if (!status.ok()) { + promise.Set(std::move(status)); + return; + } + auto ifrt_response = NewIfrtResponse(op_id); + ifrt_response->mutable_check_array_ready_response(); + promise.Set(std::move(ifrt_response)); + }); + return ifrt_response_future; +} + +BackendInterface::Response IfrtBackend::HandleReshardRequest( + std::unique_ptr request) { + const auto& reshard_request = request->reshard_request(); + TF_ASSIGN_OR_RETURN(auto array, GetArray(reshard_request.array_handle())); + TF_ASSIGN_OR_RETURN( + std::shared_ptr sharding, + FromShardingProto(absl::bind_front(&Client::LookupDevice, client_.get()), + reshard_request.sharding())); + TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( + reshard_request.copy_semantics())); + + TF_ASSIGN_OR_RETURN(auto resharded_array, + array->Reshard(sharding, semantics)); + + uint64_t resharded_array_handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({resharded_array_handle, std::move(resharded_array)}); + } + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_reshard_response()->set_array_handle( + resharded_array_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleFullyReplicatedShardRequest( + std::unique_ptr request) { + const auto& fully_replicated_shard_request = + request->fully_replicated_shard_request(); + TF_ASSIGN_OR_RETURN(auto array, + GetArray(fully_replicated_shard_request.array_handle())); + TF_ASSIGN_OR_RETURN(auto semantics, + FromArrayCopySemanticsProto( + fully_replicated_shard_request.copy_semantics())); + + // Here we are making the assumption that the `FullyReplicatedShard` returns + // the Array corresponding to the first device in the sharding - as needed by + // the proxy client for making the SingleDeviceSharding corresponding to the + // newly created array. Revisit this when IFRT supports: (1) an inexpensive + // way to derive a SingleDeviceSharding from a fully replicated Array's + // sharding and (2) A generalized Reshard API that allows the user to request + // an Array to be made out of a specific single shard. + TF_ASSIGN_OR_RETURN(auto new_array, array->FullyReplicatedShard(semantics)); + + uint64_t new_array_handle = handle_generator_.New(); + { + absl::MutexLock lock(&arrays_mutex_); + arrays_.insert({new_array_handle, std::move(new_array)}); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_fully_replicated_shard_response()->set_array_handle( + new_array_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleDeleteArrayRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN(auto array, + GetArray(request->delete_array_request().array_handle())); + + auto deletion_future = array->Delete(); + uint64_t future_handle = handle_generator_.New(); + { + absl::MutexLock lock(&futures_mutex_); + futures_.insert({future_handle, std::move(deletion_future)}); + } + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_delete_array_response()->set_deletion_future_handle( + future_handle); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleIsArrayDeletedRequest( + std::unique_ptr request) { + TF_ASSIGN_OR_RETURN( + auto array, GetArray(request->is_array_deleted_request().array_handle())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_is_array_deleted_response()->set_deleted( + array->IsDeleted()); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleDestructArrayRequest( + std::unique_ptr request) { + { + absl::MutexLock lock(&arrays_mutex_); + bool deleted = + arrays_.erase(request->destruct_array_request().array_handle()); + if (!deleted) { + return absl::NotFoundError( + absl::StrCat("Unknown array handle: ", + request->destruct_array_request().array_handle())); + } + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + // Currently DestructArrayResponse is an empty message, but proxy clients may + // rely on its presence for correct demuxing. + ifrt_resp->mutable_destruct_array_response(); + return ifrt_resp; +} + +Future IfrtBackend::HandleCompileRequest( + std::unique_ptr request) { + // Perform compilation on a thread pool in order to (1) avoid blocking the RPC + // thread during compilation and (2) run compilation with bigger stacks (often + // necessary for XLA). + auto f = [this, request = std::shared_ptr( + std::move(request))]() -> Response { + const CompileRequest& compile_request = request->compile_request(); + + TF_ASSIGN_OR_RETURN(auto program, Deserialize( + compile_request.program())); + TF_ASSIGN_OR_RETURN(auto options, Deserialize( + compile_request.compile_options())); + + // Deserialize host callbacks. IFRT proxy currently allows only one type of + // host callbacks from the client (`RemoteLoadedHostCallback`) and this is + // serialized out of band into its own field in the request proto. + std::vector> + host_callback_queues; + { + std::vector> + loaded_host_callbacks; + for (int i = 0; i < compile_request.host_callbacks_size(); ++i) { + host_callback_queues.emplace_back( + std::make_shared()); + TF_ASSIGN_OR_RETURN( + loaded_host_callbacks.emplace_back(), + RemoteLoadedHostCallback::CreateFromSerialized( + client_.get(), compile_request.host_callbacks(i), + host_callback_queues.back())); + } + if (!loaded_host_callbacks.empty()) { + if (auto xla_options = + llvm::dyn_cast(options.get())) { + xla_options->loaded_host_callbacks = std::move(loaded_host_callbacks); + } else { + return absl::UnimplementedError( + "Host callbacks are supported only for XLA-like IFRT " + "implementations using `xla::ifrt::XlaCompileOptions`"); + } + } + } + + TF_ASSIGN_OR_RETURN(auto executable, + client_->GetDefaultCompiler()->Compile( + std::move(program), std::move(options))); + + std::unique_ptr ifrt_resp = + NewIfrtResponse(request->request_metadata().op_id()); + auto* compile_resp = ifrt_resp->mutable_compile_response(); + + uint64_t handle = handle_generator_.New(); + compile_resp->set_loaded_executable_handle(handle); + + std::vector host_callback_handles(host_callback_queues.size()); + handle_generator_.BulkNew(absl::MakeSpan(host_callback_handles)); + compile_resp->mutable_loaded_host_callback_handles()->Add( + host_callback_handles.begin(), host_callback_handles.end()); + + // Populate executable metadata. + compile_resp->set_name(AsProtoStringData(executable->name())); + compile_resp->set_num_devices(executable->num_devices()); + for (const auto& logical_device_id : + executable->addressable_device_logical_ids()) { + LogicalDeviceIds* proto = + compile_resp->add_addressable_device_logical_ids(); + proto->set_replica(logical_device_id.replica); + proto->set_partition(logical_device_id.partition); + } + for (const auto* device : executable->addressable_devices()) { + compile_resp->add_addressable_device_ids(device->id()); + } + // TODO(b/282757875): Consider making fingerprint calculation asynchronous + // if it is expected to take long. + auto fingerprint = executable->Fingerprint(); + if (!fingerprint.ok()) { + *compile_resp->mutable_fingerprint_error() = + tsl::StatusToProto(fingerprint.status()); + } else if (fingerprint->has_value()) { + compile_resp->set_fingerprint_value(std::move(fingerprint)->value()); + } + // Register the ready future to `futures_`. Caller is expected to call + // `CheckFuture` exactly once to check for its status and erase it. In + // future, we may introduce separate mechanisms to remove futures from + // `futures_` without checking its status for situations where futures are + // not used. + { + absl::MutexLock lock(&futures_mutex_); + compile_resp->set_ready_future_handle(handle_generator_.New()); + futures_.insert( + {compile_resp->ready_future_handle(), executable->GetReadyFuture()}); + } + + { + absl::MutexLock lock(&executables_mutex_); + executables_.insert({handle, std::move(executable)}); + } + { + absl::MutexLock lock(&host_callback_queues_mutex_); + for (int i = 0; i < host_callback_queues.size(); ++i) { + host_callback_queues_.insert( + {host_callback_handles[i], std::move(host_callback_queues[i])}); + } + } + + return ifrt_resp; + }; + return AsyncExecute(std::move(f), &compile_thread_pool_); +} + +Future +IfrtBackend::HandleLoadedExecutableMetadataRequest( + std::unique_ptr request) { + // Call `GetParameterShardings` and `GetOutputShardings` on a thread pool + // since some implementations may block until compilation completes. + return AsyncExecute([this, request = std::shared_ptr( + std::move(request))]() -> Response { + const uint64_t handle = request->loaded_executable_metadata_request() + .loaded_executable_handle(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(handle)); + + std::unique_ptr ifrt_resp = + NewIfrtResponse(request->request_metadata().op_id()); + auto* metadata_resp = + ifrt_resp->mutable_loaded_executable_metadata_response(); + + if (auto parameter_shardings = executable->GetParameterShardings(); + parameter_shardings.has_value()) { + metadata_resp->mutable_parameter_shardings()->mutable_shardings()->Add( + parameter_shardings->begin(), parameter_shardings->end()); + } + if (auto output_shardings = executable->GetOutputShardings(); + output_shardings.has_value()) { + metadata_resp->mutable_output_shardings()->mutable_shardings()->Add( + output_shardings->begin(), output_shardings->end()); + } + + if (auto parameter_layouts = executable->GetParameterLayouts(); + parameter_layouts.ok()) { + auto* const layouts = + metadata_resp->mutable_parameter_layouts_list()->mutable_layouts(); + for (const std::unique_ptr& parameter_layout : + *parameter_layouts) { + // TODO(b/329165105): use PjRtLayout::Serialize instead + const xla::PjRtXlaLayout* layout = + dynamic_cast(parameter_layout.get()); + TF_RET_CHECK(layout != nullptr) + << "IFRT proxy only supports PjRtXlaLayout, got a different " + "subclass"; + layouts->Add(layout->xla_layout().ToProto()); + } + } else { + *metadata_resp->mutable_parameter_layouts_error() = + tsl::StatusToProto(parameter_layouts.status()); + } + if (auto output_layouts = executable->GetOutputLayouts(); + output_layouts.ok()) { + auto* const layouts = + metadata_resp->mutable_output_layouts_list()->mutable_layouts(); + for (const std::unique_ptr& output_layout : + *output_layouts) { + // TODO(b/329165105): use PjRtLayout::Serialize instead + const xla::PjRtXlaLayout* layout = + dynamic_cast(output_layout.get()); + TF_RET_CHECK(layout != nullptr) + << "IFRT proxy only supports PjRtXlaLayout, got a different " + "subclass"; + layouts->Add(layout->xla_layout().ToProto()); + } + } else { + *metadata_resp->mutable_output_layouts_error() = + tsl::StatusToProto(output_layouts.status()); + } + + auto output_memory_kinds = executable->GetOutputMemoryKinds(); + if (output_memory_kinds.ok()) { + for (const auto& memory_kinds : *output_memory_kinds) { + auto* const list = metadata_resp->mutable_output_memory_kinds() + ->add_memory_kind_lists() + ->mutable_memory_kinds(); + list->Reserve(memory_kinds.size()); + list->Add(memory_kinds.begin(), memory_kinds.end()); + } + } else { + *metadata_resp->mutable_output_memory_kinds()->mutable_status() = + tsl::StatusToProto(output_memory_kinds.status()); + } + + return ifrt_resp; + }); +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableExecuteRequest( + std::unique_ptr request) { + const LoadedExecutableExecuteRequest& execute = + request->loaded_executable_execute_request(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(execute.loaded_executable_handle())); + + std::vector> args; + args.reserve(execute.args_handles_size()); + { + absl::ReaderMutexLock lock(&arrays_mutex_); + for (const uint64_t handle : execute.args_handles()) { + TF_ASSIGN_OR_RETURN(args.emplace_back(), GetArrayLocked(handle)); + } + } + + TF_ASSIGN_OR_RETURN(auto execute_options, + xla::ifrt::LoadedExecutable::ExecuteOptions::FromProto( + execute.execute_options())); + + std::optional devices; + if (!execute.device_ids().empty()) { + DeviceList::Devices d; + d.reserve(execute.device_ids_size()); + for (const int32_t device_id : execute.device_ids()) { + TF_ASSIGN_OR_RETURN(d.emplace_back(), client_->LookupDevice(device_id)); + } + devices = DeviceList(std::move(d)); + } + + TF_ASSIGN_OR_RETURN( + xla::ifrt::LoadedExecutable::ExecuteResult result, + executable->Execute(absl::MakeSpan(args), execute_options, devices)); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + LoadedExecutableExecuteResponse* execute_response = + ifrt_resp->mutable_loaded_executable_execute_response(); + + // Register the future to `futures_`. Caller is expected to call + // `CheckFuture` exactly once to check for its status and erase it. In future, + // we may introduce separate mechanisms to remove futures from `futures_` + // without checking its status for situations where futures are not used. + { + absl::MutexLock lock(&futures_mutex_); + execute_response->set_status_handle(handle_generator_.New()); + futures_.insert( + {execute_response->status_handle(), std::move(result.status)}); + } + + // Register output arrays. At this point, we should never early return because + // doing so will leak futures or output arrays registered so far. + std::vector output_handles(result.outputs.size()); + handle_generator_.BulkNew(absl::MakeSpan(output_handles)); + { + absl::MutexLock lock(&arrays_mutex_); + for (int i = 0; i < result.outputs.size(); ++i) { + tsl::RCReference& array = result.outputs[i]; + + LoadedExecutableExecuteResponse::Output* output = + execute_response->add_outputs(); + *output->mutable_dtype() = array->dtype().ToProto(); + *output->mutable_shape() = array->shape().ToProto(); + TF_ASSIGN_OR_RETURN(*output->mutable_sharding(), + ToShardingProto(array->sharding())); + output->set_array_handle(output_handles[i]); + + arrays_.insert({output_handles[i], std::move(array)}); + } + } + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableDeleteRequest( + std::unique_ptr request) { + const auto& del = request->loaded_executable_delete_request(); + TF_ASSIGN_OR_RETURN(std::shared_ptr executable, + GetLoadedExecutable(del.loaded_executable_handle())); + + Future future = executable->Delete(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* del_response = ifrt_resp->mutable_loaded_executable_delete_response(); + + { + absl::MutexLock lock(&futures_mutex_); + del_response->set_future_handle(handle_generator_.New()); + futures_.insert({del_response->future_handle(), std::move(future)}); + } + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableIsDeletedRequest( + std::unique_ptr request) { + const auto& is_deleted = request->loaded_executable_is_deleted_request(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr executable, + GetLoadedExecutable(is_deleted.loaded_executable_handle())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* is_deleted_response = + ifrt_resp->mutable_loaded_executable_is_deleted_response(); + is_deleted_response->set_is_deleted(executable->IsDeleted()); + + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleLoadedExecutableDestructRequest( + std::unique_ptr request) { + const auto& destruct = request->loaded_executable_destruct_request(); + + std::shared_ptr executable; + { + absl::MutexLock lock(&executables_mutex_); + const auto it = executables_.find(destruct.loaded_executable_handle()); + if (it == executables_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded executable handle: ", + destruct.loaded_executable_handle())); + } + executable = std::move(it->second); + executables_.erase(it); + } + executable.reset(); + + // `RemoteLoadedHostCallback`'s request queue is closed when the host callback + // objects are destroyed by the underlying IFRT implementation when there are + // no more host callback executions to be done. + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_executable_destruct_response(); + return ifrt_resp; +} + +Future +IfrtBackend::HandleLoadedHostCallbackPollRequest( + std::unique_ptr request) { + return AsyncExecute([this, request = std::shared_ptr( + std::move(request))]() -> Response { + const auto& poll = request->loaded_host_callback_poll_request(); + const uint64_t handle = poll.loaded_host_callback_handle(); + + // Find the host callback queue associated with the given handle. + std::shared_ptr queue; + { + absl::MutexLock lock(&host_callback_queues_mutex_); + auto it = host_callback_queues_.find(handle); + if (it == host_callback_queues_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded host callback handle: ", handle)); + } + queue = it->second; + } + + // Block until the host callback has any pending execution and pop its + // execution info. May return a nullopt if the host callback has been + // deleted by the underlying IFRT implementation. + auto execution_request = queue->Pop(); + if (!execution_request.has_value()) { + { + absl::MutexLock lock(&host_callback_queues_mutex_); + host_callback_queues_.erase(handle); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_host_callback_poll_response(); + return ifrt_resp; + } + + // After this point, we must fulfill the promise eventually in order to + // avoid deadlock (`absl::Cleanup` ensures this). + + absl::Cleanup cleanup = [&] { + std::move(execution_request) + ->status.Set(absl::UnknownError( + "Unable to enqueue the host callback execution")); + }; + + // Store the operands as a single contiguous buffer in the host buffer + // store. The client retrieves it by invoking `HostBufferLookup`. + { + std::string buffer; + for (const auto& operand : execution_request->operands) { + buffer.append(static_cast(operand.data), operand.size); + } + TF_RETURN_IF_ERROR(host_buffer_store_->Store( + poll.operand_host_buffer_handle(), std::move(buffer))); + } + + const uint64_t execution_handle = handle_generator_.New(); + { + absl::MutexLock lock(&host_callback_executions_mutex_); + host_callback_executions_.insert( + {execution_handle, *std::move(execution_request)}); + } + std::move(cleanup).Cancel(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + auto* poll_response = + ifrt_resp->mutable_loaded_host_callback_poll_response(); + poll_response->set_host_callback_execution_handle(execution_handle); + return ifrt_resp; + }); +} + +BackendInterface::Response IfrtBackend::HandleLoadedHostCallbackReturnRequest( + std::unique_ptr request) { + const auto& ret = request->loaded_host_callback_return_request(); + + RemoteLoadedHostCallbackQueue::ExecutionRequest execution_request; + { + absl::MutexLock lock(&host_callback_executions_mutex_); + const auto it = + host_callback_executions_.find(ret.host_callback_execution_handle()); + if (it == host_callback_executions_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown host callback execution: ", + ret.host_callback_execution_handle())); + } + execution_request = std::move(it->second); + host_callback_executions_.erase(it); + } + absl::Cleanup cleanup = [&] { + std::move(execution_request) + .status.Set(absl::UnknownError( + "Unable to process the host callback execution results")); + }; + + // Copy the results from the host buffer store to the preallocated result + // buffers from `RemoteLoadedHostCallback`. Must be done before fulfilling the + // promise since the buffers may not be alive after that. + absl::Status status; + if (ret.has_result_host_buffer_handle()) { + TF_ASSIGN_OR_RETURN( + std::shared_ptr buffer, + host_buffer_store_->Lookup(ret.result_host_buffer_handle())); + absl::Cleanup cleanup = [&] { + CHECK_OK(host_buffer_store_->Delete(ret.result_host_buffer_handle())); + }; + + int64_t offset = 0; + for (const auto& result : execution_request.results) { + if (offset + result.size > buffer->size()) { + return absl::InternalError( + absl::StrCat("Buffer overflow while reading host callback " + "execution results; ", + "range: [", offset, ", ", offset + result.size, "), ", + "buffer size: ", buffer->size())); + } + std::memcpy(result.data, buffer->data() + offset, result.size); + offset += result.size; + } + if (offset != buffer->size()) { + return absl::InternalError( + absl::StrCat("Host callback execution did not consume the entire " + "result buffer; size: ", + buffer->size(), "; consumed: ", offset)); + } + } else { + status = tsl::StatusFromProto(ret.error()); + } + + // Fulfill the result promise. This unblocks the execution of the associated + // `RemoteLoadedHostCallback`. It is unsafe to access `execution_request` + // after this since the buffers may not be alive. + std::move(execution_request).status.Set(std::move(status)); + std::move(cleanup).Cancel(); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + ifrt_resp->mutable_loaded_host_callback_return_response(); + return ifrt_resp; +} + +BackendInterface::Response IfrtBackend::HandleGetDefaultDeviceAssignmentRequest( + std::unique_ptr request) { + const auto& get_default_device_assignment_request = + request->get_default_device_assignment_request(); + TF_ASSIGN_OR_RETURN( + auto assignment, + client_->GetDefaultDeviceAssignment( + get_default_device_assignment_request.num_replicas(), + get_default_device_assignment_request.num_partitions())); + + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); + + // Currently, the xla::DeviceAssignment::Serialize does not fail. If test + // coverage for this error is needed, consider using testing::test_value to + // inject one. + TF_RETURN_IF_ERROR(assignment.Serialize( + ifrt_resp->mutable_get_default_device_assignment_response() + ->mutable_device_assignment())); + + return ifrt_resp; +} + +absl::StatusOr> +IfrtBackend::GetLoadedExecutable(uint64_t handle) { + absl::MutexLock lock(&executables_mutex_); + auto it = executables_.find(handle); + if (it == executables_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown loaded executable handle: ", handle)); + } + return it->second; +} + +absl::StatusOr> IfrtBackend::GetArray( + uint64_t array_handle) { + absl::ReaderMutexLock lock(&arrays_mutex_); + return GetArrayLocked(array_handle); +} + +absl::StatusOr> IfrtBackend::GetArrayLocked( + uint64_t array_handle) { + auto it = arrays_.find(array_handle); + if (it == arrays_.end()) { + return absl::NotFoundError( + absl::StrCat("Unknown array handle: ", array_handle)); + } + return it->second; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.h b/xla/python/ifrt_proxy/server/ifrt_backend.h new file mode 100644 index 0000000000000..9dd57c66dd2a1 --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_backend.h @@ -0,0 +1,207 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// The abstract class `BackendInterface` defines the interface used by the IFRT +// service to interact with a variety of backend runtime system it can utilize. +class BackendInterface { + public: + virtual ~BackendInterface() = default; + + // Currently, responses (particularly those that carry buffer contents) can be + // of non-trivial size. Once we figured out how best to move the data, we may + // want to revise the shared_ptr below to the `IfrtResponse` proto itself. + // Also, if and when we have a move-only Future in xla::ifrt, we may consider + // changing it to std::unique_ptr. + using Response = absl::StatusOr>; + + // Processes a given IFRT Request and returns a Future of an IfrtResponse. + virtual Future Process(std::unique_ptr request) = 0; +}; + +// IfrtBackend implements a backend that already has a linkable C++ client that +// conforms to the xla::ifrt API. +class IfrtBackend final : public BackendInterface { + public: + // Creates an returns an IfrtBackend that uses the given IFRT Client to + // process the incoming proxy client requests. The `ifrt_client` param cannot + // be a nullptr. + static absl::StatusOr> Create( + IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store); + + ~IfrtBackend() override; + + // IFRT Proxy version negotiated between the client and the server. + const IfrtProxyVersion& version() const { return version_; } + + Future Process(std::unique_ptr request) override; + + private: + // Generates unique handles for returning to the client. All object types + // currently use this single "handle space". + class HandleGenerator { + public: + uint64_t New(); + + // Bulk allocates a given number of handles and saves them into the provided + // Span. + void BulkNew(absl::Span handles); + + private: + absl::Mutex mu_; + uint64_t current_ ABSL_GUARDED_BY(mu_) = 1; + }; + + IfrtBackend(IfrtProxyVersion version, uint64_t session_id, + std::unique_ptr ifrt_client, + std::shared_ptr host_buffer_store); + + // Executes the given function on the given thread pool and returns a future + // that becomes ready when the function returns. If the thread pool is not + // given, uses a default thread pool implementation that does not limit the + // maximum number of threads. + Future AsyncExecute(std::function handle_fn, + tsl::thread::ThreadPool* thread_pool = nullptr); + + ////////////////////////////////////////////////////////////////////// + // Handlers for individual requests + // + + Response HandleInit(std::unique_ptr request); + + Future HandleCheckFutureRequest( + std::unique_ptr request); + + Response HandleMakeArrayFromHostBufferRequest( + std::unique_ptr request); + Response HandleAssembleArrayFromSingleDeviceArraysRequest( + std::unique_ptr request); + Future HandleCopyToHostBufferRequest( + std::unique_ptr request); + Response HandleDisassembleIntoSingleDeviceArraysRequest( + std::unique_ptr request); + Response HandleReshardRequest(std::unique_ptr request); + Response HandleFullyReplicatedShardRequest( + std::unique_ptr request); + Future HandleCheckArrayReadyRequest( + std::unique_ptr request); + Response HandleDeleteArrayRequest(std::unique_ptr request); + Response HandleIsArrayDeletedRequest(std::unique_ptr request); + Response HandleDestructArrayRequest(std::unique_ptr request); + + Future HandleCompileRequest(std::unique_ptr request); + + Future HandleLoadedExecutableMetadataRequest( + std::unique_ptr request); + Response HandleLoadedExecutableExecuteRequest( + std::unique_ptr request); + Response HandleLoadedExecutableDeleteRequest( + std::unique_ptr request); + Response HandleLoadedExecutableIsDeletedRequest( + std::unique_ptr request); + Response HandleLoadedExecutableDestructRequest( + std::unique_ptr request); + + Future HandleLoadedHostCallbackPollRequest( + std::unique_ptr request); + Response HandleLoadedHostCallbackReturnRequest( + std::unique_ptr request); + + Response HandleGetDefaultDeviceAssignmentRequest( + std::unique_ptr request); + + ////////////////////////////////////////////////////////////////////// + // Convenient methods for object lookups + // + + absl::StatusOr> + GetLoadedExecutable(uint64_t handle); + + absl::StatusOr> GetArray(uint64_t handle); + absl::StatusOr> GetArrayLocked( + uint64_t handle) ABSL_SHARED_LOCKS_REQUIRED(arrays_mutex_); + + HandleGenerator handle_generator_; + + // Must not change during the life of this object. + const IfrtProxyVersion version_; + const uint64_t session_id_; + const std::unique_ptr client_; + const std::shared_ptr host_buffer_store_; + + absl::Mutex futures_mutex_; + absl::flat_hash_map> futures_ + ABSL_GUARDED_BY(futures_mutex_); + + absl::Mutex arrays_mutex_; + absl::flat_hash_map> arrays_ + ABSL_GUARDED_BY(arrays_mutex_); + + absl::Mutex executables_mutex_; + absl::flat_hash_map> + executables_ ABSL_GUARDED_BY(executables_mutex_); + + absl::Mutex host_callback_queues_mutex_; + absl::flat_hash_map> + host_callback_queues_ ABSL_GUARDED_BY(host_callback_queues_mutex_); + + absl::Mutex host_callback_executions_mutex_; + absl::flat_hash_map + host_callback_executions_ + ABSL_GUARDED_BY(host_callback_executions_mutex_); + + absl::Mutex in_flight_count_mutex_; + int64_t in_flight_count_ ABSL_GUARDED_BY(in_flight_count_mutex_) = 0; + + // Use a separate thread pool for compilation as XLA compilation often + // requires a bigger stack. + tsl::thread::ThreadPool compile_thread_pool_; +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_BACKEND_H_ diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc new file mode 100644 index 0000000000000..085c4239144f4 --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -0,0 +1,1414 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/sharding_serdes.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "xla/python/ifrt_proxy/server/host_buffer.h" +#include "xla/python/ifrt_proxy/server/host_callback.h" +#include "xla/python/ifrt_proxy/server/version.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/service/computation_placer.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/test.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::DoAll; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::Invoke; +using ::testing::Not; +using ::testing::NotNull; +using ::testing::Pointee; +using ::testing::Return; +using ::testing::ReturnRef; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if defined(PLATFORM_GOOGLE) +using ::testing::EquivToProto; +using ::testing::proto::IgnoringRepeatedFieldOrdering; +using ::testing::proto::Partially; +#endif + +constexpr uint64_t kSessionId = 12345; + +IfrtProxyVersion Version() { + IfrtProxyVersion version; + version.set_protocol_version(kServerMaxVersion); + return version; +} + +// Makes an empty request with the given op_id. Does not fail. +std::unique_ptr NewIfrtRequest(uint64_t op_id) { + auto ifrt_request = std::make_unique(); + auto* request_metadata = ifrt_request->mutable_request_metadata(); + request_metadata->set_op_id(op_id); + return ifrt_request; +} + +TEST(IfrtBackendTest, CreationFailsWithNullIfrtClient) { + EXPECT_THAT(IfrtBackend::Create(Version(), kSessionId, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(IfrtBackendTest, SuccessfulCreation) { + auto ifrt_client = std::make_unique(); + ASSERT_THAT(IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared()), + IsOk()); +} + +TEST(IfrtBackendTest, ShutdownSucceeds) { + auto ifrt_client = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto ifrt_backend, + IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared())); +} + +TEST(IfrtBackendTest, ProcessFailsWithNoRequestSet) { + auto ifrt_client = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto ifrt_backend, + IfrtBackend::Create(Version(), kSessionId, std::move(ifrt_client), + std::make_shared())); + + // Make a new request but leave the `OneOf` `request` field unset. And, that + // should fail the Process call. + auto request = std::make_unique(); + auto process_status = ifrt_backend->Process(std::move(request)).Await(); + ASSERT_THAT(process_status, Not(IsOk())); +} + +struct TestProgram : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgram::ID = 0; // NOLINT + +class TestProgramSerDes : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestProgramSerDes::ID = 0; // NOLINT + +struct TestCompileOptions + : llvm::RTTIExtends { + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptions::ID = 0; // NOLINT + +class TestCompileOptionsSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::proxy::TestCompileOptions"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + CHECK(llvm::isa(serializable)); + return ""; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr options) override { + return std::make_unique(); + } + + static char ID; // NOLINT +}; + +[[maybe_unused]] char TestCompileOptionsSerDes::ID = 0; // NOLINT + +class IfrtBackendHandlerTest : public testing::Test { + protected: + static void SetUpTestSuite() { + RegisterSerDes(std::make_unique()); + RegisterSerDes( + std::make_unique()); + } + + void SetUp() override { + auto mock_client = std::make_unique(); + + std::vector raw_device_ptrs; + for (int i = 0; i < 2; ++i) { + auto mock_device = std::make_unique(); + ON_CALL(*mock_device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(i))); + raw_device_ptrs.push_back(mock_device.get()); + mock_devices_.push_back(std::move(mock_device)); + } + + ON_CALL(*mock_client, devices()).WillByDefault(Return(raw_device_ptrs)); + ON_CALL(*mock_client, LookupDevice(_)) + .WillByDefault( + Invoke([this](int id) -> absl::StatusOr { + if (id < 0 || id >= mock_devices_.size()) { + return absl::NotFoundError( + absl::StrCat("Unknown device id: ", id)); + } + return mock_devices_[id].get(); + })); + + // Remembering a raw pointer to the mock client here is OK, since most tests + // anyway have to make the basic and tacit assumption that the backend will + // call into the mock client --and thus keep it alive-- for the duration of + // the test. + mock_client_ = mock_client.get(); + + EXPECT_CALL(*mock_client_, GetDefaultCompiler) + .WillRepeatedly(Return(&mock_compiler_)); + + host_buffer_store_ = std::make_shared(); + TF_ASSERT_OK_AND_ASSIGN( + backend_, + IfrtBackend::Create(Version(), kSessionId, std::move(mock_client), + host_buffer_store_)); + } + + absl::StatusOr> CallBackend( + std::unique_ptr request) { + auto response_future = backend_->Process(std::move(request)); + return std::move(response_future).Await(); + } + + uint64_t NewOpId() { + absl::MutexLock lock(&mu_); + return current_op_id_++; + } + + uint64_t NewHostBufferHandle() { return current_host_buffer_handle_++; } + + // Utility method to set up a given MockArray (in the backend) that can then + // be the target of the other Array-specific methods. Returns the array + // handle. + absl::StatusOr MakeTestArray(tsl::RCReference mock_array) { + EXPECT_CALL(*mock_client_, MakeArrayFromHostBuffer(_, _, _, _, _, _, _)) + .WillOnce(Return(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + const uint64_t host_buffer_handle = NewHostBufferHandle(); + TF_RETURN_IF_ERROR( + host_buffer_store_->Store(host_buffer_handle, "01234567")); + + auto* make_array = + ifrt_request->mutable_make_array_from_host_buffer_request(); + make_array->mutable_dtype()->set_kind(DTypeProto::KIND_S32); + make_array->mutable_shape()->add_dims(2); + make_array->set_host_buffer_handle(host_buffer_handle); + + TF_ASSIGN_OR_RETURN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSIGN_OR_RETURN( + *make_array->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + TF_ASSIGN_OR_RETURN(auto make_array_response, + CallBackend(std::move(ifrt_request))); + + TF_RETURN_IF_ERROR(tsl::StatusFromProto( + make_array_response->response_metadata().status())); + return make_array_response->make_array_from_host_buffer_response() + .array_handle(); + } + + absl::StatusOr CompileTestLoadedExecutable( + absl::StatusOr> loaded_executable) { + auto request = NewIfrtRequest(NewOpId()); + CompileRequest* compile_request = request->mutable_compile_request(); + TestProgram program; + TF_ASSIGN_OR_RETURN(*compile_request->mutable_program(), + Serialize(program)); + TestCompileOptions compile_options; + TF_ASSIGN_OR_RETURN(*compile_request->mutable_compile_options(), + Serialize(compile_options)); + + EXPECT_CALL(mock_compiler_, Compile(_, _)) + .WillOnce(Return(ByMove(std::move(loaded_executable)))); + + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + CallBackend(std::move(request))); + + TF_RET_CHECK(response->has_compile_response()); + return response->compile_response(); + } + + absl::Status CheckFuture(uint64_t handle) { + auto request = NewIfrtRequest(NewOpId()); + request->mutable_check_future_request()->set_future_handle(handle); + TF_ASSIGN_OR_RETURN(std::shared_ptr response, + CallBackend(std::move(request))); + return tsl::StatusFromProto(response->response_metadata().status()); + } + + xla::ifrt::MockClient* mock_client_; + xla::ifrt::MockCompiler mock_compiler_; + std::vector> mock_devices_; + std::shared_ptr host_buffer_store_; + + private: + absl::Mutex mu_; + uint64_t current_op_id_ ABSL_GUARDED_BY(mu_) = 1; + uint64_t current_host_buffer_handle_ = 1; + + std::unique_ptr backend_; +}; + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, Init) { + EXPECT_CALL(*mock_client_, platform_name()) + .WillRepeatedly(Return("ifrt_backend")); + EXPECT_CALL(*mock_client_, platform_version()).WillRepeatedly(Return("n/a")); + EXPECT_CALL(*mock_client_, platform_id()).WillRepeatedly(Return(42)); + EXPECT_CALL(*mock_client_, process_index()).WillRepeatedly(Return(1)); + EXPECT_CALL(*mock_client_, runtime_type()) + .WillRepeatedly(Return("ifrt-service")); + + std::vector> mock_memory_devices; + mock_memory_devices.reserve(mock_devices_.size()); + for (const auto& mock_device : mock_devices_) { + mock_memory_devices.push_back({mock_device.get()}); + } + + std::vector mock_memories(mock_devices_.size()); + for (int i = 0; i < mock_memories.size(); ++i) { + MockMemory& memory = mock_memories[i]; + EXPECT_CALL(memory, devices()) + .WillRepeatedly(Return(mock_memory_devices[i])); + EXPECT_CALL(memory, id()).WillRepeatedly(Return(i)); + EXPECT_CALL(memory, memory_space_kind()).WillRepeatedly(Return("mock")); + } + + std::vector> device_memories; + device_memories.reserve(mock_devices_.size()); + for (int i = 0; i < mock_devices_.size(); ++i) { + device_memories.push_back({&mock_memories[i]}); + } + + using AttributeMap = + absl::flat_hash_map; + std::vector device_attributes(mock_devices_.size()); + + const uint32_t kLocalHardwareId = 1234; + for (int i = 0; i < mock_devices_.size(); ++i) { + device_attributes[i].insert({"name", absl::StrCat("device", i)}); + + MockDevice& mock_device = *mock_devices_[i]; + // TODO(b/314368788): Clean up PJRT device ID APIs. + EXPECT_CALL(mock_device, local_hardware_id_typed()) + .WillRepeatedly(Return(xla::PjRtLocalHardwareId(kLocalHardwareId))); + EXPECT_CALL(mock_device, local_hardware_id()) + .WillRepeatedly(Return(kLocalHardwareId)); + EXPECT_CALL(mock_device, local_device_id()) + .WillRepeatedly(Return(xla::PjRtLocalDeviceId(kLocalHardwareId))); + EXPECT_CALL(mock_device, device_kind()).WillRepeatedly(Return("mock")); + EXPECT_CALL(mock_device, memory_spaces()) + .WillRepeatedly(Return(device_memories[i])); + EXPECT_CALL(mock_device, default_memory_space()) + .WillRepeatedly(Return(&mock_memories[i])); + EXPECT_CALL(mock_device, Attributes()) + .WillRepeatedly(ReturnRef(device_attributes[i])); + } + + auto request = NewIfrtRequest(NewOpId()); + request->mutable_init_request(); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + devices { + id: 0 + local_device_id: 1234 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + key: "name" + value { string_value: "device0" } + } + } + devices { + id: 1 + local_device_id: 1234 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + key: "name" + value { string_value: "device1" } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + } + )pb")))))); +} +#endif + +// TODO(b/282757875): Use the MockRuntime fixture to cover the error cases for +// MakeArrayFromHostBuffer and CopyToHostBuffer methods as well. + +// Consider redoing the happy-path test below with PjRt CPU-only backend for +// non-SingleDeviceSharding. +TEST_F(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { + // Set up a mock source array that returns two single device arrays on + // disassembly. + std::vector> single_device_arrays; + single_device_arrays.push_back(tsl::MakeRef()); + single_device_arrays.push_back(tsl::MakeRef()); + tsl::RCReference source_mock_array = + tsl::MakeRef(); + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + .WillOnce(Return(std::move(single_device_arrays))); + + // Inject the mock_array. + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(source_mock_array))); + + // Disassemble. + auto disassemble_request = NewIfrtRequest(NewOpId()); + disassemble_request->mutable_disassemble_into_single_device_arrays_request() + ->set_array_handle(array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto disassemble_response, + CallBackend(std::move(disassemble_request))); + + // We must have gotten back two handles corresponding to the two single device + // arrays we injected. + EXPECT_THAT( + disassemble_response->disassemble_into_single_device_arrays_response() + .single_device_array_handles(), + SizeIs(2)); +} + +TEST_F(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { + // Given the below shape, dtype, and compact byte_strides, the size of the + // array data needs to be 480 bytes. + const uint64_t kHostBufferHandle = 1234; + ASSERT_THAT( + host_buffer_store_->Store(kHostBufferHandle, std::string(480, 'a')), + IsOk()); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + auto* make_array = + ifrt_request->mutable_make_array_from_host_buffer_request(); + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb( + dtype { kind: KIND_F64 } + shape { dims: [ 5, 3, 4 ] } + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + make_array)); + make_array->set_host_buffer_handle(kHostBufferHandle); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *make_array->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + + const Shape expected_shape({5, 3, 4}); + const std::vector expected_byte_strides_vec = {8, 40, 120}; + const std::optional> expected_byte_strides = + absl::Span(expected_byte_strides_vec); + + tsl::RCReference mock_array = + tsl::MakeRef(); + + EXPECT_CALL(*mock_client_, + MakeArrayFromHostBuffer(_, DType(DType::kF64), expected_shape, + expected_byte_strides, _, _, _)) + .WillOnce(Return(std::move(mock_array))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->make_array_from_host_buffer_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + { + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + shape { dims: [ 2, 2 ] } + copy_semantics: ARRAY_COPY_SEMANTICS_ALWAYS_COPY + )pb", + ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request())); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request() + ->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + } + + std::vector> single_device_arrays; + for (int i = 0; i < 2; ++i) { + auto array = tsl::MakeRef(); + single_device_arrays.push_back(array); + + TF_ASSERT_OK_AND_ASSIGN(uint64_t array_handle, MakeTestArray(array)); + ifrt_request->mutable_assemble_array_from_single_device_arrays_request() + ->add_single_device_array_handles(array_handle); + } + + tsl::RCReference result = + tsl::MakeRef(); + const Shape expected_shape({2, 2}); + + EXPECT_CALL(*mock_client_, + AssembleArrayFromSingleDeviceArrays( + expected_shape, _, ElementsAreArray(single_device_arrays), _)) + .WillOnce(Return(std::move(result))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->assemble_array_from_single_device_arrays_response() + .array_handle(), + 0); +} + +TEST_F(IfrtBackendHandlerTest, CopyToHostSuccess) { + Shape shape({5, 3, 4}); + tsl::RCReference array = + tsl::MakeRef(); + ON_CALL(*array, shape()).WillByDefault(ReturnRef(shape)); + ON_CALL(*array, dtype()).WillByDefault(Return(DType(DType::kF64))); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, MakeTestArray(array)); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* copy_to_host = ifrt_request->mutable_copy_to_host_buffer_request(); + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb( + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + copy_to_host)); + copy_to_host->set_array_handle(array_handle); + const uint64_t host_buffer_handle = NewHostBufferHandle(); + copy_to_host->set_host_buffer_handle(host_buffer_handle); + + const std::vector expected_byte_strides_vec = {8, 40, 120}; + const std::optional> expected_byte_strides = + absl::Span(expected_byte_strides_vec); + EXPECT_CALL(*array, CopyToHostBuffer(_, expected_byte_strides, _)) + .WillOnce(Return(Future(absl::OkStatus()))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + // Given the above shape, dtype, and compact byte_strides, the size of the + // array data needs to be 480 bytes. + EXPECT_THAT(host_buffer_store_->Lookup(host_buffer_handle), + IsOkAndHolds(Pointee(SizeIs(480)))); +} + +TEST_F(IfrtBackendHandlerTest, CopyToHostFailsWithNonExistentArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + byte_strides { strides: [ 8, 40, 120 ] } + )pb", + ifrt_request->mutable_copy_to_host_buffer_request())); + ifrt_request->mutable_copy_to_host_buffer_request()->set_array_handle(0); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + DisassembleIntoSingleArrayFailsWhenBackendRuntimeFails) { + // Set up a mock source array that fails the disassembly. + constexpr absl::string_view kDisassembleErrorMessage = + "Some test-injected error message that is unlikely to match other error " + "messages - 1234"; + tsl::RCReference source_mock_array = + tsl::MakeRef(); + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + .WillOnce(Return(absl::UnknownError(kDisassembleErrorMessage))); + + // Set up the mock client to return the source_mock_array when the test tries + // to MakeArrayFromHostBuffer. + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(source_mock_array))); + + // Disassembly must fail with the error we injected. + auto disassemble_request = NewIfrtRequest(NewOpId()); + disassemble_request->mutable_disassemble_into_single_device_arrays_request() + ->set_array_handle(array_handle); + ASSERT_THAT( + CallBackend(std::move(disassemble_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq(kDisassembleErrorMessage))); +} + +TEST_F(IfrtBackendHandlerTest, ReshardSuccess) { + auto src_mock_array = tsl::MakeRef(); + auto resharded_mock_array = tsl::MakeRef(); + EXPECT_CALL(*src_mock_array, Reshard(_, _)) + .WillOnce(Return(std::move(resharded_mock_array))); + TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle, + MakeTestArray(std::move(src_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(src_array_handle); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request->mutable_reshard_request()->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + + EXPECT_THAT(tsl::StatusFromProto(response->response_metadata().status()), + IsOk()); + EXPECT_NE(response->reshard_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { + auto fully_replicated_mock_array = tsl::MakeRef(); + auto resultant_array = tsl::MakeRef(); + EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) + .WillOnce(Return(std::move(resultant_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto fully_replicated_array_handle, + MakeTestArray(std::move(fully_replicated_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle( + fully_replicated_array_handle); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); + EXPECT_NE(response->fully_replicated_shard_response().array_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, FullyReplicatedShardFailure) { + auto fully_replicated_mock_array = tsl::MakeRef(); + EXPECT_CALL(*fully_replicated_mock_array, FullyReplicatedShard(_)) + .WillOnce(Return(absl::UnknownError("injected error"))); + TF_ASSERT_OK_AND_ASSIGN( + auto fully_replicated_array_handle, + MakeTestArray(std::move(fully_replicated_mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle( + fully_replicated_array_handle); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +TEST_F(IfrtBackendHandlerTest, + FullyReplicatedShardFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* fully_replicated_shard_request = + ifrt_request->mutable_fully_replicated_shard_request(); + fully_replicated_shard_request->set_array_handle(0); + fully_replicated_shard_request->set_copy_semantics( + proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { + auto mock_array = tsl::MakeRef(); + EXPECT_CALL(*mock_array, Reshard(_, _)) + .WillOnce(Return(absl::UnknownError("injected error"))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(array_handle); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(1)); + TF_ASSERT_OK_AND_ASSIGN( + *ifrt_request->mutable_reshard_request()->mutable_sharding(), + ToShardingProto(*SingleDeviceSharding::Create(device, MemoryKind()))); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +TEST_F(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + auto* reshard_request = ifrt_request->mutable_reshard_request(); + reshard_request->set_array_handle(0); + reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); + reshard_request->mutable_sharding(); + + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + CheckArrayReadyRequestRelaysTheResultFromBackend) { + auto mock_array = tsl::MakeRef(); + EXPECT_CALL(*mock_array, GetReadyFuture()) + .WillOnce(Return(Future(absl::OkStatus()))) + .WillOnce( + Return(Future(absl::UnknownError("injected error")))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto ifrt_response, + CallBackend(std::move(ifrt_request))); + + EXPECT_THAT(ifrt_response->response_metadata().status().code(), + tensorflow::error::OK); + EXPECT_TRUE(ifrt_response->has_check_array_ready_response()); + } + + { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle( + array_handle); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); + } +} + +TEST_F(IfrtBackendHandlerTest, + CheckArrayReadyRequestFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_check_array_ready_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, DeleteArraySuccess) { + tsl::RCReference mock_array = + tsl::MakeRef(); + EXPECT_CALL(*mock_array, Delete()) + .WillOnce(Return(Future(absl::OkStatus()))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + uint64_t op_id = NewOpId(); + auto ifrt_request = NewIfrtRequest(op_id); + ifrt_request->mutable_delete_array_request()->set_array_handle(array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + EXPECT_THAT(tsl::StatusFromProto(resp->response_metadata().status()), IsOk()); + EXPECT_NE(resp->delete_array_response().deletion_future_handle(), 0); +} + +TEST_F(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_delete_array_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, + IsDeleteRelaysBackTheReturnValueFromBackendRuntime) { + tsl::RCReference mock_array = + tsl::MakeRef(); + + EXPECT_CALL(*mock_array, IsDeleted()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + EXPECT_TRUE(resp->is_array_deleted_response().deleted()); + + ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(resp, CallBackend(std::move(ifrt_request))); + EXPECT_FALSE(resp->is_array_deleted_response().deleted()); +} + +TEST_F(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_is_array_deleted_request()->set_array_handle(0); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(IfrtBackendHandlerTest, DestructArrayTest) { + tsl::RCReference mock_array = + tsl::MakeRef(); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle, + MakeTestArray(std::move(mock_array))); + + auto ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_destruct_array_request()->set_array_handle( + array_handle); + TF_ASSERT_OK_AND_ASSIGN(auto ifrt_resp, CallBackend(std::move(ifrt_request))); + EXPECT_TRUE(ifrt_resp->has_destruct_array_response()); + + // Retrying DestructArray should fail. And, this establishes that: (1) the + // handle no longer exists on the server, (2) DestructArray fails for + // non-existent arrays and (3) DestructArray is not idempotent. + ifrt_request = NewIfrtRequest(NewOpId()); + ifrt_request->mutable_destruct_array_request()->set_array_handle( + array_handle); + EXPECT_THAT(CallBackend(std::move(ifrt_request)), + StatusIs(absl::StatusCode::kNotFound)); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, CompileSuccess) { + std::vector devices(4); + for (int i = 0; i < 4; ++i) { + EXPECT_CALL(devices[i], global_device_id()) + .WillOnce(Return(xla::PjRtGlobalDeviceId(i))); + } + + std::vector + addressable_device_logical_ids; + std::vector addressable_devices; + for (int i = 0; i < 4; ++i) { + xla::ifrt::LoadedExecutable::LogicalDeviceIds id{i / 2, i % 2}; + addressable_device_logical_ids.push_back(id); + addressable_devices.push_back(&devices[i]); + } + + auto executable = std::make_unique(); + EXPECT_CALL(*executable, name()).WillOnce(Return("executable_name")); + EXPECT_CALL(*executable, num_devices()).WillOnce(Return(4)); + EXPECT_CALL(*executable, addressable_device_logical_ids()) + .WillOnce(Return(absl::MakeSpan(addressable_device_logical_ids))); + EXPECT_CALL(*executable, addressable_devices()) + .WillOnce(Return(absl::MakeSpan(addressable_devices))); + EXPECT_CALL(*executable, Fingerprint()).WillOnce(Return("fingerprint")); + EXPECT_CALL(*executable, GetReadyFuture()) + .WillOnce(Return(Future(absl::OkStatus()))); + + ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(executable))); + EXPECT_THAT(response, Partially(EquivToProto(R"pb( + name: "executable_name" + num_devices: 4 + addressable_device_logical_ids { replica: 0 partition: 0 } + addressable_device_logical_ids { replica: 0 partition: 1 } + addressable_device_logical_ids { replica: 1 partition: 0 } + addressable_device_logical_ids { replica: 1 partition: 1 } + addressable_device_ids: [ 0, 1, 2, 3 ] + fingerprint_value: "fingerprint" + )pb"))); + EXPECT_OK(CheckFuture(response.ready_future_handle())); +} +#endif + +TEST_F(IfrtBackendHandlerTest, CompileFailure) { + ASSERT_THAT( + CompileTestLoadedExecutable(absl::InternalError("injected error")), + StatusIs(absl::StatusCode::kInternal, StrEq("injected error"))); +} + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableMetadata) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + OpSharding op_sharding1; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(type: REPLICATED)pb", &op_sharding1)); + + OpSharding op_sharding2; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ])pb", + &op_sharding2)); + + EXPECT_CALL(*executable, GetParameterShardings()) + .WillOnce(Return(std::vector{op_sharding1, op_sharding2})); + + EXPECT_CALL(*executable, GetOutputShardings()) + .WillOnce(Return(std::vector{op_sharding1})); + + std::vector> parameter_layouts; + parameter_layouts.push_back(std::make_unique( + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1))); + parameter_layouts.push_back(std::make_unique( + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); + EXPECT_CALL(*executable, GetParameterLayouts()) + .WillOnce(Return(std::move(parameter_layouts))); + + std::vector> output_layouts; + output_layouts.push_back(std::make_unique( + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); + EXPECT_CALL(*executable, GetOutputLayouts()) + .WillOnce(Return(std::move(output_layouts))); + EXPECT_CALL(*executable, GetOutputMemoryKinds()) + .WillOnce(Return(std::vector>{{"foo"}})); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableMetadataRequest* metadata_request = + request->mutable_loaded_executable_metadata_request(); + metadata_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee(Partially(EquivToProto(R"pb( + loaded_executable_metadata_response { + parameter_shardings { + shardings { type: REPLICATED } + shardings { + type: OTHER + tile_shape { + element_type: BF16 + dimensions: [ 2, 2 ] + } + tile_assignment_dimensions: [ 0, 1 ] + } + } + output_shardings { shardings { type: REPLICATED } } + parameter_layouts_list { + layouts { minor_to_major: 0 } + layouts { minor_to_major: [ 1, 0 ] } + } + output_layouts_list { layouts { minor_to_major: [ 1, 0 ] } } + output_memory_kinds { + memory_kind_lists { memory_kinds: [ "foo" ] } + } + } + )pb"))))); + } + + { + EXPECT_CALL(*executable, GetParameterShardings()) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL(*executable, GetOutputShardings()) + .WillOnce(Return(std::nullopt)); + EXPECT_CALL(*executable, GetParameterLayouts()) + .WillOnce(Return(absl::UnimplementedError("unimplemented"))); + EXPECT_CALL(*executable, GetOutputLayouts()) + .WillOnce(Return(absl::UnimplementedError("unimplemented"))); + EXPECT_CALL(*executable, GetOutputMemoryKinds()) + .WillOnce(Return(std::vector>{})); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableMetadataRequest* metadata_request = + request->mutable_loaded_executable_metadata_request(); + metadata_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + const auto& metadata_response = + response->loaded_executable_metadata_response(); + EXPECT_FALSE(metadata_response.has_parameter_shardings()); + EXPECT_FALSE(metadata_response.has_output_shardings()); + EXPECT_TRUE(metadata_response.has_parameter_layouts_error()); + EXPECT_TRUE(metadata_response.has_output_layouts_error()); + } +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableExecute) { + MockDevice device; + ON_CALL(device, global_device_id()) + .WillByDefault(Return(xla::PjRtGlobalDeviceId(0))); + + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + constexpr int kNumArgs = 3; + constexpr int kNumOutputs = 2; + + Shape shape({2, 2}); + auto sharding = SingleDeviceSharding::Create(&device, MemoryKind()); + + auto make_array = [&]() { + auto array = tsl::MakeRef(); + ON_CALL(*array, dtype()).WillByDefault(Return(DType(DType::kF32))); + ON_CALL(*array, shape()).WillByDefault(ReturnRef(shape)); + ON_CALL(*array, sharding()).WillByDefault(ReturnRef(*sharding)); + return array; + }; + + std::vector> outputs; + outputs.reserve(kNumOutputs); + for (int i = 0; i < kNumOutputs; ++i) { + outputs.push_back(make_array()); + } + + EXPECT_CALL(*executable, Execute(SizeIs(kNumArgs), _, _)) + .WillOnce( + Invoke([&](absl::Span> args, + const xla::ifrt::LoadedExecutable::ExecuteOptions& options, + std::optional devices) + -> absl::StatusOr { + return LoadedExecutable::ExecuteResult{ + .status = + Future(absl::InternalError("injected error")), + .outputs = outputs, + }; + })); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableExecuteRequest* execute_request = + request->mutable_loaded_executable_execute_request(); + for (int i = 0; i < kNumArgs; ++i) { + TF_ASSERT_OK_AND_ASSIGN(uint64_t arg_handle, MakeTestArray(make_array())); + execute_request->add_args_handles(arg_handle); + } + execute_request->set_loaded_executable_handle(handle); + TF_ASSERT_OK_AND_ASSIGN( + *execute_request->mutable_execute_options(), + xla::ifrt::LoadedExecutable::ExecuteOptions().ToProto()); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + EXPECT_THAT(response, Pointee(Partially(EquivToProto(R"pb( + loaded_executable_execute_response { + outputs { + dtype { kind: KIND_F32 } + shape { dims: [ 2, 2 ] } + } + outputs { + dtype { kind: KIND_F32 } + shape { dims: [ 2, 2 ] } + } + } + )pb")))); + TF_ASSERT_OK_AND_ASSIGN( + auto sharding_proto, + ToShardingProto(*SingleDeviceSharding::Create(&device, MemoryKind()))); + for (const auto& output : + response->loaded_executable_execute_response().outputs()) { + EXPECT_THAT(output.sharding(), EquivToProto(sharding_proto)); + EXPECT_NE(output.array_handle(), 0); + } + + EXPECT_THAT( + CheckFuture( + response->loaded_executable_execute_response().status_handle()), + StatusIs(absl::StatusCode::kInternal, StrEq("injected error"))); + + // The second call to `CheckFuture` fails since `CheckFuture` above performs a + // destructive read. + EXPECT_THAT( + CheckFuture( + response->loaded_executable_execute_response().status_handle()), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Unknown future handle"))); +} +#endif + +// TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS +#if defined(PLATFORM_GOOGLE) +TEST_F(IfrtBackendHandlerTest, LoadedExecutableDelete) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + EXPECT_CALL(*executable, Delete()) + .WillOnce(Return(Future(absl::OkStatus()))); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDeleteRequest* delete_request = + request->mutable_loaded_executable_delete_request(); + delete_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_executable_delete_response()); + + EXPECT_THAT( + CheckFuture( + response->loaded_executable_delete_response().future_handle()), + IsOk()); + } + + { + EXPECT_CALL(*executable, IsDeleted()).WillOnce(Return(true)); + + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableIsDeletedRequest* is_deleted_request = + request->mutable_loaded_executable_is_deleted_request(); + is_deleted_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee(Partially(EquivToProto(R"pb( + loaded_executable_is_deleted_response { is_deleted: true } + )pb"))))); + } +} +#endif + +TEST_F(IfrtBackendHandlerTest, LoadedExecutableDestruct) { + MockLoadedExecutable* executable; + uint64_t handle; + { + auto e = std::make_unique(); + executable = e.get(); + TF_ASSERT_OK_AND_ASSIGN(CompileResponse response, + CompileTestLoadedExecutable(std::move(e))); + handle = response.loaded_executable_handle(); + } + + { + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDestructRequest* destruct_request = + request->mutable_loaded_executable_destruct_request(); + destruct_request->set_loaded_executable_handle(handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_executable_destruct_response()); + } + + // Any attempt to access the loaded executable handle should now return an + // error. + { + auto request = NewIfrtRequest(NewOpId()); + LoadedExecutableDestructRequest* destruct_request = + request->mutable_loaded_executable_destruct_request(); + destruct_request->set_loaded_executable_handle(handle); + + EXPECT_THAT(CallBackend(std::move(request)), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("Unknown loaded executable handle"))); + } +} + +TEST_F(IfrtBackendHandlerTest, LoadedHostCallbackExecute) { + // Build a remote host callback with one F32 argument and one F32 result. + std::vector hcb_args = {{ + .channel_id = 1, + .shape = xla::ShapeUtil::MakeShape(xla::F32, {}), + }}; + std::vector hcb_results = {{ + .channel_id = 2, + .shape = xla::ShapeUtil::MakeShape(xla::F32, {}), + }}; + auto hcb = tsl::MakeRef( + mock_client_, std::move(hcb_args), std::move(hcb_results), + /*queue=*/nullptr); + + // Compile an executable with the above host callback. The resulting loaded + // host callback handle and `xla::HostCallback` are kept for triggering host + // callback execution. + // + // The setup code must use `xla::ifrt::XlaCompileOptions` for now since this + // is the only allowed compile options type that is currently recognized as + // supporting host callbacks. + MockLoadedExecutable* executable; + tsl::RCReference loaded_host_callback; + uint64_t loaded_host_callback_handle; + { + auto request = NewIfrtRequest(NewOpId()); + CompileRequest* compile_request = request->mutable_compile_request(); + + TestProgram program; + TF_ASSERT_OK_AND_ASSIGN(*compile_request->mutable_program(), + Serialize(program)); + xla::ifrt::XlaCompileOptions compile_options; + TF_ASSERT_OK_AND_ASSIGN(*compile_request->mutable_compile_options(), + Serialize(compile_options)); + + TF_ASSERT_OK_AND_ASSIGN(std::string host_callback_serialized, + hcb->Serialize()); + compile_request->add_host_callbacks(std::move(host_callback_serialized)); + + auto e = std::make_unique(); + executable = e.get(); + + EXPECT_CALL(mock_compiler_, Compile(_, _)) + .WillOnce(DoAll( + Invoke( + [&](const std::unique_ptr& program, + const std::unique_ptr& options) { + auto* xla_compile_options = + llvm::cast(options.get()); + auto& loaded_host_callbacks = + xla_compile_options->loaded_host_callbacks; + ASSERT_EQ(loaded_host_callbacks.size(), 1); + loaded_host_callback = loaded_host_callbacks.front(); + }), + Return(ByMove(std::move(e))))); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + + ASSERT_TRUE(response->has_compile_response()); + CompileResponse compile_response = response->compile_response(); + + loaded_host_callback_handle = + compile_response.loaded_host_callback_handles(0); + ASSERT_THAT(loaded_host_callback, NotNull()); + } + + // Enqueue a host callback execution. This is done on a separate thread since + // `LoadedHostCallbackPollRequest` blocks until there is a pending execution. + auto host_callback_thread = absl::WrapUnique(tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "HostCallback", [&]() { + xla::Literal x = xla::LiteralUtil::CreateR0(1.0f); + + std::vector operands; + operands.push_back(x.untyped_data()); + + xla::Literal out = xla::LiteralUtil::CreateR0(0.0f); + std::vector results; + results.push_back(out.untyped_data()); + + const xla::HostCallback* xla_host_callback = + &llvm::cast(loaded_host_callback.get()) + ->host_callback(); + ASSERT_THAT( + xla_host_callback->callback(results.data(), operands.data()), + IsOk()); + EXPECT_EQ(out, xla::LiteralUtil::CreateR0(2.0f)); + })); + + // Poll for a host callback execution and verify its argument against the one + // passed by the execution thread above. + uint64_t host_callback_execution_handle; + { + const uint64_t operand_host_buffer_handle = NewHostBufferHandle(); + + auto request = NewIfrtRequest(NewOpId()); + LoadedHostCallbackPollRequest* poll_request = + request->mutable_loaded_host_callback_poll_request(); + poll_request->set_loaded_host_callback_handle(loaded_host_callback_handle); + poll_request->set_operand_host_buffer_handle(operand_host_buffer_handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + + ASSERT_TRUE(response->has_loaded_host_callback_poll_response()); + const LoadedHostCallbackPollResponse& poll_response = + response->loaded_host_callback_poll_response(); + host_callback_execution_handle = + poll_response.host_callback_execution_handle(); + + TF_ASSERT_OK_AND_ASSIGN( + const std::shared_ptr operands, + host_buffer_store_->Lookup(operand_host_buffer_handle)); + EXPECT_EQ(xla::BorrowingLiteral(operands->data(), + xla::ShapeUtil::MakeShape(xla::F32, {})), + xla::LiteralUtil::CreateR0(1.0f)); + } + + // Return the execution result. This will unblock the execution thread above, + // which also verifies the result. + { + auto result = xla::LiteralUtil::CreateR0(2.0f); + std::string result_buffer(absl::string_view( + static_cast(result.untyped_data()), result.size_bytes())); + + const uint64_t result_host_buffer_handle = NewHostBufferHandle(); + ASSERT_THAT(host_buffer_store_->Store(result_host_buffer_handle, + std::move(result_buffer)), + IsOk()); + + auto request = NewIfrtRequest(NewOpId()); + LoadedHostCallbackReturnRequest* ret_request = + request->mutable_loaded_host_callback_return_request(); + ret_request->set_host_callback_execution_handle( + host_callback_execution_handle); + ret_request->set_result_host_buffer_handle(result_host_buffer_handle); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, + CallBackend(std::move(request))); + ASSERT_TRUE(response->has_loaded_host_callback_return_response()); + } +} + +TEST_F(IfrtBackendHandlerTest, GetDefaultDeviceAssignmentSuccess) { + const int kNumReplicas = 1; + const int kNumPartitions = 3; + + EXPECT_CALL(*mock_client_, + GetDefaultDeviceAssignment(kNumReplicas, kNumPartitions)) + .WillOnce(Return(xla::DeviceAssignment(kNumReplicas, kNumPartitions))); + + auto request = NewIfrtRequest(NewOpId()); + auto* default_device_assignment_request = + request->mutable_get_default_device_assignment_request(); + default_device_assignment_request->set_num_replicas(kNumReplicas); + default_device_assignment_request->set_num_partitions(kNumPartitions); + + TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(request))); + TF_ASSERT_OK_AND_ASSIGN(auto assignment_got, + xla::DeviceAssignment::Deserialize( + response->get_default_device_assignment_response() + .device_assignment())); + EXPECT_EQ(assignment_got->replica_count(), kNumReplicas); + EXPECT_EQ(assignment_got->computation_count(), kNumPartitions); +} + +TEST_F(IfrtBackendHandlerTest, + GetDefaultDeviceAssignmentFailsIfTheBackendFails) { + const int kNumReplicas = 1; + const int kNumPartitions = 3; + + EXPECT_CALL(*mock_client_, + GetDefaultDeviceAssignment(kNumReplicas, kNumPartitions)) + .WillOnce(Return(absl::UnknownError("injected error"))); + + auto request = NewIfrtRequest(NewOpId()); + auto* default_device_assignment_request = + request->mutable_get_default_device_assignment_request(); + default_device_assignment_request->set_num_replicas(kNumReplicas); + default_device_assignment_request->set_num_partitions(kNumPartitions); + + EXPECT_THAT(CallBackend(std::move(request)), + StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/ifrt_session_handler.cc b/xla/python/ifrt_proxy/server/ifrt_session_handler.cc new file mode 100644 index 0000000000000..a4b0a95f2866f --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_session_handler.cc @@ -0,0 +1,117 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/proto_util.h" + +// The tsl include below is needed only for the Status macros such as +// ASSIGN_OR_RETURN, since the OSS absl package does not have the counterparts +// yet. +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::StatusOr> IfrtSessionHandler::Create( + uint64_t id, BackendFactory backend_factory) { + if (backend_factory == nullptr) { + return absl::InvalidArgumentError("BackendFactory cannot be nullptr."); + } + return absl::WrapUnique( + new IfrtSessionHandler(id, std::move(backend_factory))); +} + +IfrtSessionHandler::IfrtSessionHandler(uint64_t id, + BackendFactory backend_factory) + : session_id_(id), backend_factory_(std::move(backend_factory)) {} + +void IfrtSessionHandler::NewIncomingRequest( + std::unique_ptr request, + std::function)> on_done) { + VLOG(2) << "NewIncomingRequest: " << request->DebugString(); + + const uint64_t op_id = request->request_metadata().op_id(); + + // The current implementation exploits the async nature of the backend_ IFRT + // client to minimize the amount of work we do per request. However, using a + // threadpool here might make sense as a performance optimization. + + auto result = [&]() -> Future { + if (request->has_init_request()) { + return ProcessInitRequest(std::move(request)); + } + if (auto status = SetupBackendIfNeeded(); !status.ok()) { + return Future(status); + } + absl::ReaderMutexLock read_lock(&backend_mu_); + return backend_->Process(std::move(request)); + }(); + + // Consider maintaining a count of in-flight requests (that won't complete + // until the following OnReady callback happens) so we can safely deleting the + // reactor_. + result.OnReady([op_id, on_done = std::move(on_done)]( + absl::StatusOr> result) { + if (result.ok()) { + on_done(*std::move(result)); + } else { + on_done(NewIfrtResponse(op_id, result.status())); + } + }); +} + +Future IfrtSessionHandler::ProcessInitRequest( + std::unique_ptr request) { + absl::MutexLock lock(&backend_mu_); + if (backend_ != nullptr) { + // Currently backends cannot be reinitialized. + return Future(absl::FailedPreconditionError( + "This session has already been initialized.")); + } + + auto backend = backend_factory_(session_id_); + if (!backend.ok()) { + return Future(backend.status()); + } + backend_ = *std::move(backend); + + return backend_->Process(std::move(request)); +} + +absl::Status IfrtSessionHandler::SetupBackendIfNeeded() { + absl::MutexLock lock(&backend_mu_); + if (backend_ != nullptr) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(backend_, backend_factory_(session_id_)); + return absl::OkStatus(); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/ifrt_session_handler.h b/xla/python/ifrt_proxy/server/ifrt_session_handler.h new file mode 100644 index 0000000000000..505341a693495 --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_session_handler.h @@ -0,0 +1,82 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// IfrtSessionHandler glues an incoming stream to a stack of backend runtimes +// abstracted out by a `BackendInterface`. It utilizes the provided `Backend` to +// process the incoming client requests after ensuring that dependencies as +// specified by the client are honored and the chunked requests are fully +// re-assembled. +class IfrtSessionHandler { + public: + using BackendFactory = + absl::AnyInvocable>( + uint64_t session_id)>; + + using Response = BackendInterface::Response; + + // Makes a new IfrtSessionHandler with the given Session ID that uniquely + // identifies this session. The backend_factory cannot be a nullptr. + static absl::StatusOr> Create( + uint64_t id, BackendFactory backend_factory); + + uint64_t session_id() const { return session_id_; } + + // Top-level handler the transport implementation calls to hand off a new + // incoming request. `on_done` is called asynchronously to return responses. + void NewIncomingRequest( + std::unique_ptr request, + std::function)> on_done); + + private: + IfrtSessionHandler(uint64_t id, BackendFactory backend_factory); + + // InitRequest is treated somewhat differently than the rest since it triggers + // the creation of the backend_ + Future ProcessInitRequest(std::unique_ptr request) + ABSL_LOCKS_EXCLUDED(backend_mu_); + + // Sets up the backaned_ only if needed - i.e., only if it is a nullptr. + absl::Status SetupBackendIfNeeded() ABSL_LOCKS_EXCLUDED(backend_mu_); + + const uint64_t session_id_; // Unique ID of this Session. + + // The backend_ runtime(s) this session relies on for processing the incoming + // requests. It is instantiated at the start of a new Bidi stream, and + // currently does not change for the life of this object. + BackendFactory backend_factory_; + absl::Mutex backend_mu_; + std::unique_ptr backend_ ABSL_GUARDED_BY(backend_mu_); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_IFRT_SESSION_HANDLER_H_ diff --git a/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc b/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc new file mode 100644 index 0000000000000..b5a1e0bc316d5 --- /dev/null +++ b/xla/python/ifrt_proxy/server/ifrt_session_handler_test.cc @@ -0,0 +1,70 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/server/ifrt_session_handler.h" + +#include +#include +#include + +#include +#include +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::testing::Not; +using ::tsl::testing::IsOk; + +// FakeBackend. Currently: Fails or returns illegal values where possible. +// All other methods return dummy strings or empty vectors. Individual tests +// can make derived classes that override specific methods as needed. +class FakeBackend : public BackendInterface { + public: + FakeBackend() = default; + ~FakeBackend() override = default; + + Future Process( + std::unique_ptr request) override { + return Future(std::make_unique()); + } +}; + +TEST(IfrtSessionHandlerTest, NullptrForBackendMakerFails) { + EXPECT_THAT(IfrtSessionHandler::Create(1234, nullptr), Not(IsOk())); +} + +TEST(IfrtSessionHandlerTest, SuccessfulCreation) { + std::unique_ptr backend = std::make_unique(); + EXPECT_THAT( + IfrtSessionHandler::Create( + 1234, [&](uint64_t session_id) { return std::move(backend); }), + IsOk()); +} + +// TODO(b/282757875) Add "end-to-end" tests that cover the entire path from the +// Server/BidiReactor to the backend. Since IfrtSessionHandler writes the +// responses (IfrtResponse messages) directly to the Bidi Reactor, tests for the +// actual processing of requests need a full server and a fake client that +// allows us retrieve and examine the responses. + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/mock_ifrt_backend.h b/xla/python/ifrt_proxy/server/mock_ifrt_backend.h new file mode 100644 index 0000000000000..620808f912be8 --- /dev/null +++ b/xla/python/ifrt_proxy/server/mock_ifrt_backend.h @@ -0,0 +1,42 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ + +#include + +#include +#include "absl/status/status.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/server/ifrt_backend.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +class MockIfrtBackend final : public BackendInterface { + public: + MOCK_METHOD(Future, Process, (std::unique_ptr request), + (final)); +}; + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_MOCK_IFRT_BACKEND_H_ diff --git a/xla/python/ifrt_proxy/server/version.cc b/xla/python/ifrt_proxy/server/version.cc new file mode 100644 index 0000000000000..b4f5298203a5a --- /dev/null +++ b/xla/python/ifrt_proxy/server/version.cc @@ -0,0 +1,48 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "xla/python/ifrt_proxy/server/version.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +absl::StatusOr ChooseVersion(int client_min_version, + int client_max_version, + int server_min_version, + int server_max_version) { + const int version = std::min(server_max_version, client_max_version); + + if (version < server_min_version || version < client_min_version) { + return absl::InvalidArgumentError(absl::StrCat( + "IFRT Proxy client and server failed to agree on the " + "protocol version; supported versions: client = [", + client_min_version, ", ", client_max_version, "], server = [", + server_min_version, ", ", server_max_version, "]")); + } + + return version; +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt_proxy/server/version.h b/xla/python/ifrt_proxy/server/version.h new file mode 100644 index 0000000000000..2556b5656f618 --- /dev/null +++ b/xla/python/ifrt_proxy/server/version.h @@ -0,0 +1,41 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ +#define XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ + +#include "absl/status/statusor.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// TODO(b/296144873): Document the version upgrade policy. +inline constexpr int kServerMinVersion = 1; +inline constexpr int kServerMaxVersion = 1; + +// Returns a version that both the client and the server support, or an error if +// there is no such a version. +absl::StatusOr ChooseVersion(int client_min_version, + int client_max_version, + int server_min_version = kServerMinVersion, + int server_max_version = kServerMaxVersion); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_SERVER_VERSION_H_ diff --git a/xla/python/ifrt_proxy/server/version_test.cc b/xla/python/ifrt_proxy/server/version_test.cc new file mode 100644 index 0000000000000..efebcad9d65d9 --- /dev/null +++ b/xla/python/ifrt_proxy/server/version_test.cc @@ -0,0 +1,69 @@ +/* + * Copyright 2023 The OpenXLA Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "xla/python/ifrt_proxy/server/version.h" + +#include +#include +#include "absl/status/status.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +struct Param { + int client_min_version; + int client_max_version; + int server_min_version; + int server_max_version; +}; + +class CompatibleVersionTest : public ::testing::TestWithParam {}; + +TEST_P(CompatibleVersionTest, Verify) { + const Param& param = GetParam(); + EXPECT_THAT(ChooseVersion(param.client_min_version, param.client_max_version, + param.server_min_version, param.server_max_version), + IsOk()); +} + +INSTANTIATE_TEST_SUITE_P(CompatibleVersionTest, CompatibleVersionTest, + ::testing::Values(Param{1, 1, 1, 1}, Param{1, 2, 2, 2}, + Param{2, 2, 1, 2}, + Param{1, 3, 3, 4})); + +class IncompatibleVersionTest : public ::testing::TestWithParam {}; + +TEST_P(IncompatibleVersionTest, Verify) { + const Param& param = GetParam(); + EXPECT_THAT(ChooseVersion(param.client_min_version, param.client_max_version, + param.server_min_version, param.server_max_version), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +INSTANTIATE_TEST_SUITE_P(IncompatibleVersionTest, IncompatibleVersionTest, + ::testing::Values(Param{1, 2, 3, 3}, Param{1, 3, 4, 6}, + Param{1, 1, 2, 2})); + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/xla/python/inspect_sharding.cc b/xla/python/inspect_sharding.cc index e7f05d750fc3b..81fb65606433d 100644 --- a/xla/python/inspect_sharding.cc +++ b/xla/python/inspect_sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ class InspectShardingCallPartitioner : public xla::CustomCallPartitioner { HloInstruction* instruction) const override { const HloInstruction* operand = instruction->operand(0); if (!operand->has_sharding()) { - return xla::InternalError( + return xla::Internal( "Inspect sharding called but no sharding is available."); } std::string sharding_spec = @@ -71,14 +71,14 @@ class InspectShardingCallPartitioner : public xla::CustomCallPartitioner { args.error_txt = nullptr; const auto& str = instruction->raw_backend_config_string(); if (str.size() != sizeof(JAX_InspectSharding_Callback)) { - return xla::InternalError("Invalid config string for inspect sharding."); + return xla::Internal("Invalid config string for inspect sharding."); } JAX_InspectSharding_Callback cb; memcpy(&cb, str.data(), sizeof(JAX_InspectSharding_Callback)); cb.call(cb.data, &args); if (args.error_txt) { - auto result = xla::InternalError("Error calling inspect_sharding: %s", - args.error_txt); + auto result = + xla::Internal("Error calling inspect_sharding: %s", args.error_txt); args.free_error(&args); return result; } diff --git a/xla/python/inspect_sharding.h b/xla/python/inspect_sharding.h index bb0ee0477b37a..4afc3a63875a0 100644 --- a/xla/python/inspect_sharding.h +++ b/xla/python/inspect_sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/jax_jit.cc b/xla/python/jax_jit.cc index cd9407fce89cc..0e6eec0f35024 100644 --- a/xla/python/jax_jit.cc +++ b/xla/python/jax_jit.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,34 +29,37 @@ limitations under the License. #include #include -#include -#include -#include #include #include #include -#include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/py_values.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" -#include "tsl/platform/status.h" +#include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" namespace jax { -namespace py = pybind11; +namespace nb = nanobind; // TODO(phawkins): Add support for Tracers. // TODO(jblespiau): Use absl Status. @@ -66,7 +69,7 @@ namespace { // `thread_local_state.extra_jit_context` is set from Python. It's done when // loading the Python jax modules on the main-thread. For other threads, we // need to initialize the field the first time we access `thread_local_state`. -py::object& initialize_local_state = *new py::object(); +nb::object& initialize_local_state = *new nb::object(); } // namespace @@ -84,7 +87,7 @@ JitState& ThreadLocalJitState() { if (thread_local_state.extra_jit_context == std::nullopt) { CHECK(initialize_local_state.ptr() != nullptr); // Avoids reentrant calls to the initialization function. - thread_local_state.extra_jit_context = py::none(); + thread_local_state.extra_jit_context = nb::none(); initialize_local_state(); } return thread_local_state; @@ -104,7 +107,7 @@ bool GetEnableX64() { return thread_local_state.enable_x64.value_or(*global_state.enable_x64); } -std::optional GetDefaultDevice() { +std::optional GetDefaultDevice() { auto& global_state = GlobalJitState(); auto& thread_local_state = ThreadLocalJitState(); return thread_local_state.default_device.has_value() @@ -112,7 +115,7 @@ std::optional GetDefaultDevice() { : global_state.default_device; } -std::optional GetPostHook() { +std::optional GetPostHook() { auto& global_state = GlobalJitState(); auto& thread_local_state = ThreadLocalJitState(); return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook @@ -120,9 +123,9 @@ std::optional GetPostHook() { } static std::string OptionalDebugString( - const std::optional optional) { + const std::optional optional) { if (optional.has_value()) { - return py::cast(py::str(optional.value())); + return nb::cast(nb::str(optional.value())); } else { return "None"; } @@ -136,13 +139,60 @@ bool FetchMemoriesFlag() { *global_state.enable_memories); } -std::string CallSignature::DebugString() const { - auto py_object_formatter = [](std::string* out, const py::object& o) { - out->append(py::cast(py::str(o))); +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); }; auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { out->append(d.ToString()); }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; auto signature_formatter = [](std::string* out, const xla::PyArgSignature& s) { out->append(s.DebugString()); @@ -151,25 +201,20 @@ std::string CallSignature::DebugString() const { out->append(o ? "true" : "false"); }; return absl::StrFormat( - "static args (positional + keyword): %s\nstatic arg keyword names: %s\n" + "arg signature: %s\n" "dynamic arg signatures (positional + keyword): %s\n" "dynamic arg shardings: %s\n" "committed args: %s\n" - "dynamic arg keyword names: %s\n" - "dynamic arg treedefs: %s\n" "device: %s\n" "default_device: %s\n" "jax_enable_x64: %d\n" "jax_enable_memories: %d\n" "global_extra_jit_context: %s\n" "thread_local_extra_jit_context: %s\n", - absl::StrJoin(static_args, ",", py_object_formatter), - absl::StrJoin(static_arg_names, ",", py_object_formatter), + arg_signature.DebugString(), absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), absl::StrJoin(committed_args, ",", bool_formatter), - absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), - absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter), // new line device != nullptr ? device->DebugString() : "nullptr", OptionalDebugString(default_device), jax_enable_x64, jax_enable_memories, OptionalDebugString(global_extra_jit_context), @@ -177,162 +222,152 @@ std::string CallSignature::DebugString() const { } bool CallSignature::operator==(const CallSignature& other) const { - // TODO(chky): Consider implementing hashing and equality for sharding in cpp - // instead of hashing and checking sharding's pointer values. - return std::tie(dynamic_arg_treedefs, dynamic_arg_names, - dynamic_arg_signatures, device, jax_enable_x64, - jax_enable_memories, static_arg_names, committed_args) == - std::tie(other.dynamic_arg_treedefs, other.dynamic_arg_names, - other.dynamic_arg_signatures, other.device, - other.jax_enable_x64, other.jax_enable_memories, - other.static_arg_names, other.committed_args) && - // `==` on py:objects is the Python `is`. We need equal. - std::equal(dynamic_arg_shardings.begin(), dynamic_arg_shardings.end(), - other.dynamic_arg_shardings.begin(), - other.dynamic_arg_shardings.end(), - [](const py::object& a, const py::object& b) { - return ShardingEqual(a, b); - }) && - std::equal( - static_args.begin(), static_args.end(), other.static_args.begin(), - other.static_args.end(), - [this](const py::object& a, const py::object& b) { - try { - return py::type::handle_of(a) == py::type::handle_of(b) && - a.equal(b); - } catch (const py::error_already_set& e) { - throw std::invalid_argument(absl::StrCat( - "static arguments should be comparable using __eq__." - "The following error was raised during a call to '", - function_name, "' when comparing two objects of types ", - py::cast(py::str(py::type::of(a))), " and ", - py::cast(py::str(py::type::of(b))), - ". The error was:\n", e.what())); - } - }) && - (global_extra_jit_context.has_value() == - other.global_extra_jit_context.has_value()) && - (!global_extra_jit_context.has_value() || - global_extra_jit_context->equal(*other.global_extra_jit_context)) && - (default_device.has_value() == other.default_device.has_value()) && - (!default_device.has_value() || - default_device->equal(*other.default_device)) && - (thread_local_extra_jit_context.has_value() == - other.thread_local_extra_jit_context.has_value()) && - (!thread_local_extra_jit_context.has_value() || - thread_local_extra_jit_context->equal( - *other.thread_local_extra_jit_context)); + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (jax_enable_memories != other.jax_enable_memories) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + ShardingEqual) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)); } // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -xla::Status ParseArguments(absl::Span positional_args, - absl::Span keyword_args, - py::handle kwnames, - absl::Span static_argnums, - absl::Span static_argnames, - xla::PyTreeRegistry* pytree_registry, - ParsedArgumentsAsBuffers& arguments) { +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { tsl::profiler::TraceMe traceme("ParseArguments"); - arguments.flat_dynamic_args.reserve(positional_args.size() + - keyword_args.size()); + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); if (static_argnums.empty()) { - arguments.signature.dynamic_arg_treedefs.reserve(positional_args.size()); + signature.dynamic_arg_treedefs.reserve(positional_args.size()); // Positional arguments. for (int i = 0; i < positional_args.size(); ++i) { - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); - pytree_def.Flatten(positional_args[i], arguments.flat_dynamic_args); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); } } else { - arguments.signature.dynamic_arg_treedefs.reserve(positional_args.size()); + signature.dynamic_arg_treedefs.reserve(positional_args.size()); // Positional arguments. for (int i = 0; i < positional_args.size(); ++i) { if (std::find(static_argnums.begin(), static_argnums.end(), i) == static_argnums.end()) { - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); - pytree_def.Flatten(positional_args[i], arguments.flat_dynamic_args); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); } else { - arguments.signature.static_args.emplace_back( - py::reinterpret_borrow(positional_args[i])); + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); } } } // Keyword arguments. if (!keyword_args.empty()) { - std::vector> kwargs(keyword_args.size()); + std::vector> kwargs(keyword_args.size()); // We first intern the keys, then sort them (by name, as in the Python path) // (see also xla::PyTreeDef::Flatten) and then create the signatures. // TODO(jblespiau): We should be able to sort the keys by interned-key // pointers, but this requires the Python compilation to do the same. for (int i = 0; i < keyword_args.size(); ++i) { // Intern the key if not already interned. - kwargs[i].first = py::handle(PyTuple_GET_ITEM(kwnames.ptr(), i)); - kwargs[i].first.inc_ref(); - kwargs[i].second = py::handle(keyword_args[i]); - if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) { - PyUnicode_InternInPlace(&kwargs[i].first.ptr()); + PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; } std::sort(kwargs.begin(), kwargs.end(), - [](const std::pair& a, - const std::pair& b) { + [](const std::pair& a, + const std::pair& b) { return a.first < b.first; }); - auto kwarg_is_static = [&](py::handle name) { + auto kwarg_is_static = [&](nb::handle name) { for (const auto& kw : static_argnames) { if (kw.ptr() == name.ptr()) return true; } return false; }; - arguments.signature.dynamic_arg_names.reserve(keyword_args.size()); + signature.dynamic_arg_names.reserve(keyword_args.size()); for (int i = 0; i < keyword_args.size(); ++i) { if (kwarg_is_static(kwargs[i].first)) { - arguments.signature.static_arg_names.push_back( - py::reinterpret_steal(kwargs[i].first)); - arguments.signature.static_args.push_back( - py::reinterpret_borrow(kwargs[i].second)); + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); } else { - arguments.signature.dynamic_arg_names.push_back( - py::reinterpret_steal(kwargs[i].first)); - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); - pytree_def.Flatten(kwargs[i].second, arguments.flat_dynamic_args); + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); } } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -void BuildJaxjitSubmodule(py::module& m) { - py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); - py::class_ jit_state_(jitlib, "JitState"); - jit_state_.def_readwrite("disable_jit", &JitState::disable_jit); - jit_state_.def_readwrite("enable_x64", &JitState::enable_x64); - jit_state_.def_readwrite("enable_memories", &JitState::enable_memories); - jit_state_.def_readwrite("default_device", &JitState::default_device); - jit_state_.def_readwrite("extra_jit_context", &JitState::extra_jit_context); - jit_state_.def_readwrite("post_hook", &JitState::post_hook); + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("enable_memories", &JitState::enable_memories, + nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); GetEnableMemories = +[] { return FetchMemoriesFlag(); }; jitlib.def( "global_state", [&]() { return &GlobalJitState(); }, - py::return_value_policy::reference); + nb::rv_policy::reference); jitlib.def( "thread_local_state", [&]() { return &ThreadLocalJitState(); }, - py::return_value_policy::reference); + nb::rv_policy::reference); jitlib.def( "swap_thread_local_state_disable_jit", @@ -342,30 +377,24 @@ void BuildJaxjitSubmodule(py::module& m) { tls->disable_jit = value; return result; }, - py::return_value_policy::reference); + nb::arg("value").none(), nb::rv_policy::reference); - jitlib.def("jit_is_disabled", &GetDisableJit); jitlib.def("get_enable_x64", &GetEnableX64); jitlib.def("set_thread_local_state_initialization_callback", - [](py::object f) { initialize_local_state = f; }); + [](nb::object f) { initialize_local_state = f; }); - // TODO(yashkatariya, phawkins): Remove after 3 months from March 20, 2023. - struct CompiledFunction {}; - py::class_ give_me_a_name(m, "CompiledFunction"); - - py::class_ arg_signature(jitlib, "PyArgSignature"); + nb::class_ arg_signature(jitlib, "PyArgSignature"); arg_signature - .def_property_readonly( + .def_prop_ro( "dtype", [](const xla::PyArgSignature& sig) { - return xla::ValueOrThrow(PrimitiveTypeToDtype(sig.dtype)); - }) - .def_property_readonly( - "shape", - [](const xla::PyArgSignature& sig) { - return xla::SpanToTuple(absl::MakeConstSpan(sig.shape)); + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); }) - .def_readonly("weak_type", &xla::PyArgSignature::weak_type); + .def_prop_ro("shape", + [](const xla::PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); jitlib.def("_ArgSignatureOfValue", xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); diff --git a/xla/python/jax_jit.h b/xla/python/jax_jit.h index 1a14478352355..26b75de374785 100644 --- a/xla/python/jax_jit.h +++ b/xla/python/jax_jit.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,24 +16,29 @@ limitations under the License. #ifndef XLA_PYTHON_JAX_JIT_H_ #define XLA_PYTHON_JAX_JIT_H_ -#include +#include + +#include #include #include #include +#include #include #include // placeholder for index annotation headers #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind #include "xla/pjrt/pjrt_client.h" -#include "xla/python/ifrt/array.h" +#include "xla/python/nb_helpers.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" +#include "tsl/platform/logging.h" namespace jax { @@ -47,7 +52,7 @@ struct JitState { if (extra_jit_context) { // We likely do not hold the GIL if this JitState is thread-local, so we // hand the Python object to the global reference manager to destroy. - pybind11::object o = std::move(*extra_jit_context); + nanobind::object o = std::move(*extra_jit_context); xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); extra_jit_context = std::nullopt; } @@ -61,15 +66,15 @@ struct JitState { // in global state, indicating there is no manual override. // TODO(skyewm): make this a C++ type when all JAX backends support a single // C++ device interface - std::optional default_device; + std::optional default_device; // Extra context that should be included in the JIT cache key. Must be // hashable and have an equality defined. - std::optional extra_jit_context; + std::optional extra_jit_context; // A callback that, if present, is called when a JITted function is executed // from cache. May be unset even in global state. - std::optional post_hook; + std::optional post_hook; }; JitState& GlobalJitState(); @@ -84,8 +89,90 @@ bool GetEnableX64(); // TODO(skyewm): return a C++ type when all JAX backends support a single C++ // device interface -std::optional GetDefaultDevice(); -std::optional GetPostHook(); +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = xla::nb_hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments +// pytree_registry: the registry to use to convert the arguments to pytrees +// arguments: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); // The signature of Python jitted function call, partitioned into: // - dynamic positional arguments (i.e. positional args which are not static) @@ -98,29 +185,17 @@ std::optional GetPostHook(); // (a) equality (delegated to Python) of the static arguments. struct CallSignature { // Not part of the signature, but we need it for error messages. - absl::string_view function_name; + std::string_view function_name; + + ArgumentSignature arg_signature; - // A PyTreeDef for each dynamic argument, positional arguments first - // followed by keyword arguments. Keyword arguments are in the order given - // by dynamic_arg_names. - absl::InlinedVector dynamic_arg_treedefs; - // Dynamic keyword argument names. Interned, and sorted by the keyword - // name. - std::vector dynamic_arg_names; // Shape and dtype for both the dynamic positional arguments and the keyword // arguments (sorted by keyword name). absl::InlinedVector dynamic_arg_signatures; // The sharding of the jax.Array arguments. This is only used by pjit with // jax.Array enabled. - std::vector dynamic_arg_shardings; - - // Static arguments. Contains the positional arguments sorted in argument - // order, followed by static keyword arguments in the order given by - // `static_arg_names`. - std::vector static_args; - // Static keyword argument names. Interned, and sorted by keyword name. - std::vector static_arg_names; + std::vector dynamic_arg_shardings; absl::InlinedVector committed_args; @@ -133,11 +208,11 @@ struct CallSignature { // For JIT on PJIT, we need to fallback to python whenever default_device // changes. - std::optional default_device; + std::optional default_device; // Opaque additional context that should be included as part of the cache key. - std::optional global_extra_jit_context; - std::optional thread_local_extra_jit_context; + std::optional global_extra_jit_context; + std::optional thread_local_extra_jit_context; bool operator==(const CallSignature& other) const; bool operator!=(const CallSignature& other) const { @@ -149,8 +224,7 @@ struct CallSignature { template H AbslHashValue(H h, const CallSignature& s) { - h = H::combine(std::move(h), s.dynamic_arg_treedefs, - s.dynamic_arg_signatures); + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); DCHECK(s.dynamic_arg_shardings.empty() || s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); @@ -159,39 +233,11 @@ H AbslHashValue(H h, const CallSignature& s) { // slow python hashing function. Consider implementing hashing function and // equality checks in C++ in jax::Sharding and use those here. for (const auto& sharding : s.dynamic_arg_shardings) { - h = H::combine(std::move(h), ShardingHash(sharding)); - } - - for (const auto& name : s.dynamic_arg_names) { - h = H::combine(std::move(h), name.ptr()); + // TODO(phawkins): remove .ptr() after nanobind transition is complete. + h = H::combine(std::move(h), ShardingHash(sharding.ptr())); } - h = H::combine(std::move(h), s.committed_args); - - h = H::combine(std::move(h), s.dynamic_arg_names.size()); - for (const auto& static_arg : s.static_args) { - ssize_t hash; - try { - hash = pybind11::hash(static_arg); - } catch (const pybind11::error_already_set& e) { - if (!e.matches(PyExc_TypeError)) throw; - throw std::invalid_argument(absl::StrCat( - "Non-hashable static arguments are not supported. An error occurred " - "during a call to '", - s.function_name, "' while trying to hash an object of type ", - pybind11::cast( - pybind11::str(pybind11::type::of(static_arg))), - ", ", pybind11::cast(pybind11::str(static_arg)), - ". The error was:\n", e.what(), "\n")); - } - h = H::combine(std::move(h), hash); - } - h = H::combine(std::move(h), s.static_args.size()); - for (const auto& name : s.static_arg_names) { - h = H::combine(std::move(h), name.ptr()); - } - h = H::combine(std::move(h), s.static_arg_names.size()); - h = H::combine(std::move(h), s.device, s.jax_enable_x64); + h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); // We do not hash the extra_jit_context fields since calling Python hash // functions is expensive (~300ns) and we don't expect a large number of @@ -199,35 +245,8 @@ H AbslHashValue(H h, const CallSignature& s) { return h; } -// The resulting information of the parsing and conversion of the arguments. -struct ParsedArgumentsAsBuffers { - // The call signature will be filled during 2 steps: - // - `ParseArguments` will fill the static arguments and the pytree - // structures - // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. - CallSignature signature; - // The concatenation of the dynamic positional arguments and the sorted - // keyword arguments. - absl::InlinedVector flat_dynamic_args; - std::vector keep_alive_objects; - - xla::ifrt::Client* ifrt_client; - // The following is only valid if the parsing succeeds. - std::vector> ifrt_arg_arrays; -}; - -// Filter out static arguments, flatten and concatenate other arguments (i.e. -// dynamic positional and keyword arguments), filling `arguments` in place. -xla::Status ParseArguments(absl::Span positional_args, - absl::Span keyword_args, - pybind11::handle kwnames, - absl::Span static_argnums, - absl::Span static_argnames, - xla::PyTreeRegistry* pytree_registry, - ParsedArgumentsAsBuffers& arguments); - // The function to call in `xla.cc` to add the bindings for this module. -void BuildJaxjitSubmodule(pybind11::module& m); +void BuildJaxjitSubmodule(nanobind::module_& m); } // namespace jax diff --git a/xla/python/logging.cc b/xla/python/logging.cc index 01265becf09ab..2d8261f025974 100644 --- a/xla/python/logging.cc +++ b/xla/python/logging.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/logging.h b/xla/python/logging.h index 8d9fe83b0e259..9a791611df913 100644 --- a/xla/python/logging.h +++ b/xla/python/logging.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/mlir.cc b/xla/python/mlir.cc index 77dc23bc7b640..53369a3314fdc 100644 --- a/xla/python/mlir.cc +++ b/xla/python/mlir.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,9 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include #include "mhlo/transforms/passes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -24,11 +28,14 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo @@ -37,21 +44,22 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" -#include "xla/python/status_casters.h" -#include "xla/python/types.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status.h" #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" -namespace py = pybind11; +namespace nb = nanobind; namespace xla { namespace { -StatusOr> ParseModule( - mlir::MLIRContext* context, std::string str) { +absl::StatusOr> ParseModule( + mlir::MLIRContext* context, std::string_view str) { mlir::OwningOpRef module; context->loadDialect(); context->loadDialect(); @@ -86,6 +94,16 @@ std::string PrintModule(mlir::ModuleOp module) { return s; } +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; @@ -96,7 +114,7 @@ void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { // Exists for backwards compatibility. // TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules // instead and delete this function. -StatusOr PyXlaComputationToMlirModule( +absl::StatusOr PyXlaComputationToMlirModule( const XlaComputation& computation, bool emit_stable_hlo) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); @@ -121,9 +139,8 @@ StatusOr PyXlaComputationToMlirModule( return PrintModule(*module); } -StatusOr PyMlirModuleToXlaComputation(std::string mlir_module, - bool use_tuple_args, - bool return_tuple) { +absl::StatusOr PyMlirModuleToXlaComputation( + std::string_view mlir_module, bool use_tuple_args, bool return_tuple) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseModule(&context, mlir_module)); @@ -133,7 +150,7 @@ StatusOr PyMlirModuleToXlaComputation(std::string mlir_module, return computation; } -StatusOr PyMhloToStablehlo(std::string mlir_module) { +absl::StatusOr PyMhloToStablehlo(std::string_view mlir_module) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); // JAX can be customized in a way that involves operations from custom @@ -151,59 +168,56 @@ StatusOr PyMhloToStablehlo(std::string mlir_module) { if (!mlir::succeeded(pm.run(*module))) { return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); } - return PrintModule(*module); + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); } -StatusOr PyStablehloToMhlo(std::string mlir_module) { +absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); // See PyMhloToStablehlo for an explanation of why we're allowing unregistered // dialects here. context.allowUnregisteredDialects(true); - TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, - ParseModule(&context, mlir_module)); + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef module, + ParseModule(&context, + std::string_view(mlir_module.c_str(), mlir_module.size()))); mlir::PassManager pm(&context); if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); if (!mlir::succeeded(pm.run(*module))) { return tsl::errors::InvalidArgument("StableHLO => MHLO failed"); } - return PrintModule(*module); + + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); } -StatusOr PySerializePortableArtifact(std::string mlir_module, - std::string target) { +absl::StatusOr PySerializePortableArtifact( + std::string_view mlir_module, std::string_view target) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseModule(&context, mlir_module)); - // Legalize CHLO -> [MHLO+Shape] -> StableHLO - mlir::PassManager pm(&context); - if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createShapeLegalizeToHloPass()); - pm.addPass(mlir::createReconcileUnrealizedCastsPass()); - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); - if (!mlir::succeeded(pm.run(*module))) { - return tsl::errors::InvalidArgument( - "CHLO => [MHLO+Shape] => StableHLO failed"); - } - // Serialize portable artifact - std::string buffer; - llvm::raw_string_ostream os(buffer); - if (failed(mlir::stablehlo::serializePortableArtifact(*module, target, os))) - return tsl::errors::InvalidArgument("Failed to serialize StableHLO"); - return py::bytes(buffer); + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); } -StatusOr PyDeserializePortableArtifact(std::string bytecode_str) { +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { mlir::MLIRContext context; mlir::OwningOpRef module = - mlir::stablehlo::deserializePortableArtifact(bytecode_str, &context); + mlir::stablehlo::deserializePortableArtifact( + std::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); if (!module) return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); return PrintModule(*module); @@ -211,40 +225,64 @@ StatusOr PyDeserializePortableArtifact(std::string bytecode_str) { } // namespace -void BuildMlirSubmodule(py::module& m) { - py::module mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); mlir_module.def("xla_computation_to_mlir_module", xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), - py::arg("computation"), py::arg("emit_stable_hlo") = true); + nb::arg("computation"), nb::arg("emit_stable_hlo") = true); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + std::string_view(bytecode.c_str(), bytecode.size()), use_tuple_args, + return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); mlir_module.def("mlir_module_to_xla_computation", xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), - py::arg("mlir_module"), py::arg("use_tuple_args") = false, - py::arg("return_tuple") = false); + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + std::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); mlir_module.def("mhlo_to_stablehlo", xla::ValueOrThrowWrapper(PyMhloToStablehlo), - py::arg("mlir_module")); + nb::arg("mlir_module")); mlir_module.def("stablehlo_to_mhlo", xla::ValueOrThrowWrapper(PyStablehloToMhlo), - py::arg("mlir_module")); + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, std::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + std::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); mlir_module.def("serialize_portable_artifact", xla::ValueOrThrowWrapper(PySerializePortableArtifact), - py::arg("mlir_module"), py::arg("target")); + nb::arg("mlir_module"), nb::arg("target")); mlir_module.def("deserialize_portable_artifact", xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), - py::arg("mlir_module")); + nb::arg("mlir_module")); mlir_module.def( "refine_polymorphic_shapes", - [](std::string mlir_module, bool enable_shape_assertions, - bool validate_static_shapes) -> py::bytes { + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes) -> nb::bytes { std::string buffer; llvm::raw_string_ostream os(buffer); xla::ThrowIfError(RefinePolymorphicShapes( - mlir_module, os, enable_shape_assertions, validate_static_shapes)); - return py::bytes(buffer); + std::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes)); + return nb::bytes(buffer.data(), buffer.size()); }, - py::arg("mlir_module"), py::arg("enable_shape_assertions") = true, - py::arg("validate_static_shapes") = true, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, R"(Refines the dynamic shapes for a module. The "main" function must have static shapes and all the intermediate dynamic shapes depend only on the input static diff --git a/xla/python/mlir.h b/xla/python/mlir.h index 99c9a5af921ce..c6ac71e4c6bc5 100644 --- a/xla/python/mlir.h +++ b/xla/python/mlir.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_MLIR_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildMlirSubmodule(pybind11::module& m); +void BuildMlirSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/nb_absl_flat_hash_map.h b/xla/python/nb_absl_flat_hash_map.h new file mode 100644 index 0000000000000..b2a89027d530f --- /dev/null +++ b/xla/python/nb_absl_flat_hash_map.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_NB_ABSL_FLAT_HASH_MAP_H_ +#define XLA_PYTHON_NB_ABSL_FLAT_HASH_MAP_H_ + +#include "absl/container/flat_hash_map.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/detail/nb_dict.h" // from @nanobind + +namespace nanobind { +namespace detail { + +template +struct type_caster> + : dict_caster, Key, T> {}; + +} // namespace detail +} // namespace nanobind + +#endif // XLA_PYTHON_NB_ABSL_FLAT_HASH_MAP_H_ diff --git a/xla/python/nb_absl_flat_hash_set.h b/xla/python/nb_absl_flat_hash_set.h new file mode 100644 index 0000000000000..f2d606f5b8ba4 --- /dev/null +++ b/xla/python/nb_absl_flat_hash_set.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_NB_ABSL_FLAT_HASH_SET_H_ +#define XLA_PYTHON_NB_ABSL_FLAT_HASH_SET_H_ + +#include "absl/container/flat_hash_set.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/detail/nb_set.h" // from @nanobind + +namespace nanobind { +namespace detail { + +template +struct type_caster> + : set_caster, Key> {}; + +} // namespace detail +} // namespace nanobind + +#endif // XLA_PYTHON_NB_ABSL_FLAT_HASH_SET_H_ diff --git a/xla/python/nb_absl_span.h b/xla/python/nb_absl_span.h new file mode 100644 index 0000000000000..819098f53c378 --- /dev/null +++ b/xla/python/nb_absl_span.h @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_NB_ABSL_SPAN_H_ +#define XLA_PYTHON_NB_ABSL_SPAN_H_ + +#include +#include + +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/detail/nb_list.h" // from @nanobind +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep + +namespace nanobind { +namespace detail { + +template +struct type_caster> { + NB_TYPE_CASTER(absl::Span, + const_name("Span[") + make_caster::Name + const_name("]")) + + using Caster = make_caster; + + list_caster, T> vec_caster; + + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (!vec_caster.from_python(src, flags, cleanup)) { + return false; + } + value = vec_caster.value; + return true; + } + + static handle from_cpp(absl::Span src, rv_policy policy, + cleanup_list *cleanup) noexcept { + object ret = steal(PyList_New(src.size())); + if (ret.is_valid()) { + Py_ssize_t i = 0; + for (const T &value : src) { + handle h = Caster::from_cpp(value, policy, cleanup); + if (!h.is_valid()) { + ret.reset(); + break; + } + PyList_SET_ITEM(ret.ptr(), i++, h.ptr()); + } + } + return ret.release(); + } +}; + +} // namespace detail +} // namespace nanobind + +#endif // XLA_PYTHON_NB_ABSL_SPAN_H_ diff --git a/xla/python/nb_class_ptr.h b/xla/python/nb_class_ptr.h new file mode 100644 index 0000000000000..5ab613230ce2d --- /dev/null +++ b/xla/python/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_NB_CLASS_PTR_H_ +#define XLA_PYTHON_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" // from @nanobind + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return h.type().is(type); + }; + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // XLA_PYTHON_NB_CLASS_PTR_H_ diff --git a/xla/python/nb_helpers.cc b/xla/python/nb_helpers.cc new file mode 100644 index 0000000000000..3026345b9c1d2 --- /dev/null +++ b/xla/python/nb_helpers.cc @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/nb_helpers.h" + +#include + +#include "nanobind/nanobind.h" // from @nanobind + +namespace nb = nanobind; + +namespace xla { + +Py_hash_t nb_hash(nb::handle o) { + Py_hash_t h = PyObject_Hash(o.ptr()); + if (h == -1) { + throw nb::python_error(); + } + return h; +} + +bool nb_isinstance(nanobind::handle inst, nanobind::handle cls) { + int ret = PyObject_IsInstance(inst.ptr(), cls.ptr()); + if (ret == -1) { + throw nb::python_error(); + } + return ret; +} +} // namespace xla diff --git a/xla/python/nb_helpers.h b/xla/python/nb_helpers.h new file mode 100644 index 0000000000000..cf5e25018c00e --- /dev/null +++ b/xla/python/nb_helpers.h @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_NB_HELPERS_H_ +#define XLA_PYTHON_NB_HELPERS_H_ + +#include + +#include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" // from @nanobind + +namespace xla { + +// Calls Python hash() on an object. +// TODO(phawkins): consider upstreaming this to nanobind. +Py_hash_t nb_hash(nanobind::handle o); + +// Calls Python isinstance(inst, cls). +// TODO(phawkins): consider upstreaming this to nanobind. +bool nb_isinstance(nanobind::handle inst, nanobind::handle cls); + +// Issues a Python deprecation warning. Throws a C++ exception if issuing the +// Python warning causes a Python exception to be raised. +template +void PythonDeprecationWarning(const absl::FormatSpec& format, + const Args&... args) { + if (PyErr_WarnEx(PyExc_DeprecationWarning, + absl::StrFormat(format, args...).c_str(), 1) < 0) { + throw nanobind::python_error(); + } +} + +// Variant of NB_TYPE_CASTER that doesn't define from_cpp() +#define NB_TYPE_CASTER_FROM_PYTHON_ONLY(Value_, descr) \ + using Value = Value_; \ + static constexpr auto Name = descr; \ + template \ + using Cast = movable_cast_t; \ + explicit operator Value*() { return &value; } \ + explicit operator Value&() { return (Value&)value; } \ + explicit operator Value&&() { return (Value&&)value; } \ + Value value; + +template +nanobind::object nb_property_readonly(Func&& get) { + nanobind::handle property(reinterpret_cast(&PyProperty_Type)); + return property(nanobind::cpp_function(std::forward(get)), + nanobind::none(), nanobind::none(), ""); +} + +template +nanobind::object nb_property(GetFunc&& get, SetFunc&& set) { + nanobind::handle property(reinterpret_cast(&PyProperty_Type)); + return property(nanobind::cpp_function(std::forward(get)), + nanobind::cpp_function(std::forward(set)), + nanobind::none(), ""); +} + +} // namespace xla + +#endif // XLA_PYTHON_NB_HELPERS_H_ diff --git a/xla/python/nb_numpy.cc b/xla/python/nb_numpy.cc new file mode 100644 index 0000000000000..4f8f177bdf618 --- /dev/null +++ b/xla/python/nb_numpy.cc @@ -0,0 +1,171 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/nb_numpy.h" + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/tsl/python/lib/core/numpy.h" + +namespace nb = nanobind; + +namespace xla { + +/*static*/ nb_dtype nb_dtype::from_args(const nb::object& args) { + PyArray_Descr* descr; + if (!PyArray_DescrConverter(args.ptr(), &descr) || !descr) { + throw nb::python_error(); + } + return nb::steal(reinterpret_cast(descr)); +} + +nb_numpy_ndarray::nb_numpy_ndarray( + nb_dtype dtype, absl::Span shape, + std::optional> strides, const void* ptr, + nb::handle base) { + const int64_t* strides_ptr = nullptr; + if (strides) { + if (shape.size() != strides->size()) { + throw std::invalid_argument("shape and strides must have the same size."); + } + strides_ptr = strides->data(); + } + int flags = 0; + if (base && ptr) { + nb_numpy_ndarray base_array; + if (nb::try_cast(base, base_array)) { + flags = base_array.flags() & ~NPY_ARRAY_OWNDATA; + } else { + flags = NPY_ARRAY_WRITEABLE; + } + } + // The reinterpret_cast below assumes that npy_intp and int64_t are the same + // width. If that changes, then the code should be updated to convert instead. + static_assert(sizeof(int64_t) == sizeof(npy_intp)); + nb::object array = nb::steal(PyArray_NewFromDescr( + &PyArray_Type, reinterpret_cast(dtype.release().ptr()), + shape.size(), reinterpret_cast(shape.data()), + reinterpret_cast(strides_ptr), const_cast(ptr), + flags, + /*obj=*/nullptr)); + if (!array) { + throw nb::python_error(); + } + if (ptr) { + if (base) { + PyArray_SetBaseObject(reinterpret_cast(array.ptr()), + base.inc_ref().ptr()); + } else { + array = nb::steal(PyArray_NewCopy( + reinterpret_cast(array.ptr()), NPY_ANYORDER)); + } + } + m_ptr = array.release().ptr(); +} + +/*static*/ nb_numpy_ndarray nb_numpy_ndarray::from_any(nanobind::handle h, + int extra_requirements) { + nb::handle out = PyArray_FromAny( + h.ptr(), /*dtype=*/nullptr, /*min_depth=*/0, + /*max_depth=*/0, + /*requirements=*/NPY_ARRAY_ENSUREARRAY | extra_requirements, + /*context=*/nullptr); + if (PyErr_Occurred()) { + throw nb::python_error(); + } + return nb::steal(out); +} + +/*static*/ nb_numpy_ndarray nb_numpy_ndarray::ensure(nanobind::handle h, + int extra_requirements) { + nb::handle out = PyArray_FromAny( + h.ptr(), /*dtype=*/nullptr, /*min_depth=*/0, + /*max_depth=*/0, + /*requirements=*/NPY_ARRAY_ENSUREARRAY | extra_requirements, + /*context=*/nullptr); + if (!out) { + PyErr_Clear(); + } + return nb::steal(out); +} + +nb_dtype nb_numpy_ndarray::dtype() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return nb::borrow(reinterpret_cast(PyArray_DESCR(self))); +} + +npy_intp nb_numpy_ndarray::ndim() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_NDIM(self); +} + +const npy_intp* nb_numpy_ndarray::shape() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_SHAPE(self); +} + +npy_intp nb_numpy_ndarray::shape(npy_intp dim) const { + PyArrayObject* self = reinterpret_cast(ptr()); + if (dim < 0 || dim >= PyArray_NDIM(self)) { + throw std::invalid_argument("Invalid dimension."); + } + return PyArray_SHAPE(self)[dim]; +} + +const npy_intp* nb_numpy_ndarray::strides() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_STRIDES(self); +} + +npy_intp nb_numpy_ndarray::strides(npy_intp dim) const { + PyArrayObject* self = reinterpret_cast(ptr()); + if (dim < 0 || dim >= PyArray_NDIM(self)) { + throw std::invalid_argument("Invalid dimension."); + } + return PyArray_STRIDES(self)[dim]; +} + +npy_intp nb_numpy_ndarray::itemsize() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_ITEMSIZE(self); +} + +npy_intp nb_numpy_ndarray::size() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_SIZE(self); +} + +const void* nb_numpy_ndarray::data() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_DATA(self); +} + +void* nb_numpy_ndarray::mutable_data() { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_DATA(self); +} + +int nb_numpy_ndarray::flags() const { + PyArrayObject* self = reinterpret_cast(ptr()); + return PyArray_FLAGS(self); +} + +} // namespace xla diff --git a/xla/python/nb_numpy.h b/xla/python/nb_numpy.h new file mode 100644 index 0000000000000..7dc495c4de724 --- /dev/null +++ b/xla/python/nb_numpy.h @@ -0,0 +1,107 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Nanobind wrappers for NumPy types. +// +// Unlike pybind11, nanobind does not provide direct wrappers for NumPy types. +// This file provides nanobind equivalents of pybind11::dtype and +// pybind11::array. + +#ifndef XLA_PYTHON_NB_NUMPY_H_ +#define XLA_PYTHON_NB_NUMPY_H_ + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/tsl/python/lib/core/numpy.h" + +#if NPY_ABI_VERSION < 0x02000000 +#define PyDataType_ELSIZE(descr) ((descr)->elsize) +#endif + +namespace xla { + +// Caution: to use this type you must call tsl::ImportNumpy() in your module +// initialization function. Otherwise PyArray_DescrCheck will be nullptr. +class nb_dtype : public nanobind::object { + public: + NB_OBJECT_DEFAULT(nb_dtype, object, "dtype", PyArray_DescrCheck); // NOLINT + + explicit nb_dtype(const nanobind::str& format) + : nb_dtype(from_args(format)) {} + explicit nb_dtype(std::string_view format) + : nb_dtype(from_args(nanobind::str(format.data(), format.size()))) {} + + static nb_dtype from_args(const nanobind::object& args); + + int char_() const { + auto* descr = reinterpret_cast(ptr()); + return descr->type; + } + + int itemsize() const { + auto* descr = reinterpret_cast(ptr()); + return PyDataType_ELSIZE(descr); + } + + /// Single-character code for dtype's kind. + /// For example, floating point types are 'f' and integral types are 'i'. + char kind() const { + auto* descr = reinterpret_cast(ptr()); + return descr->kind; + } +}; + +class nb_numpy_ndarray : public nanobind::object { + public: + NB_OBJECT_DEFAULT(nb_numpy_ndarray, object, "ndarray", + PyArray_Check); // NOLINT + + nb_numpy_ndarray(nb_dtype dtype, absl::Span shape, + std::optional> strides, + const void* ptr = nullptr, + nanobind::handle base = nanobind::handle()); + + // Ensures that the given handle is a numpy array. If provided, + // extra_requirements flags (NPY_ARRAY_...) are passed to PyArray_FromAny. + // In case of an error, nullptr is returned and the Python error is cleared. + static nb_numpy_ndarray ensure(nanobind::handle h, + int extra_requirements = 0); + + // Constructs a numpy ndarray via the PyArray_From Any API. This throws an + // error if an exception occurs. + static nb_numpy_ndarray from_any(nanobind::handle h, int extra_requirements); + + nb_dtype dtype() const; + npy_intp ndim() const; + const npy_intp* shape() const; + npy_intp shape(npy_intp dim) const; + const npy_intp* strides() const; + npy_intp strides(npy_intp dim) const; + npy_intp itemsize() const; + npy_intp size() const; + const void* data() const; + void* mutable_data(); + int flags() const; +}; + +} // namespace xla + +#endif // XLA_PYTHON_NB_NUMPY_H_ diff --git a/xla/python/ops.cc b/xla/python/ops.cc index 2d44a5bf41523..1b583873b9715 100644 --- a/xla/python/ops.cc +++ b/xla/python/ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/python/ops.h" +#include +#include #include #include #include @@ -22,8 +24,13 @@ limitations under the License. #include #include "absl/types/span.h" -#include "pybind11/attr.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/pair.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep #include "xla/client/lib/approx_topk.h" #include "xla/client/lib/approx_topk_shape.h" #include "xla/client/lib/comparators.h" @@ -35,225 +42,475 @@ limitations under the License. #include "xla/client/lib/svd.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" -#include "xla/python/status_casters.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" #include "xla/python/types.h" #include "xla/xla_data.pb.h" -namespace xla { +namespace nb = nanobind; + +namespace nanobind { +namespace detail { + +// XLA protocol buffers +// We don't actually care that these are the protocol buffers, we merely want +// objects that duck type as protocol buffers. The client code currently avoids +// depending on Python protocol buffers to avoid conflicting definitions from +// different modules that both include XLA. + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY( + xla::ConvolutionDimensionNumbers, + const_name("xla::ConvolutionDimensionNumbers")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + value.set_input_batch_dimension( + cast(getattr(handle, "input_batch_dimension"))); + value.set_input_feature_dimension( + cast(getattr(handle, "input_feature_dimension"))); + value.set_output_batch_dimension( + cast(getattr(handle, "output_batch_dimension"))); + value.set_output_feature_dimension( + cast(getattr(handle, "output_feature_dimension"))); + value.set_kernel_input_feature_dimension( + cast(getattr(handle, "kernel_input_feature_dimension"))); + value.set_kernel_output_feature_dimension( + cast(getattr(handle, "kernel_output_feature_dimension"))); + std::vector dims; + dims = cast>( + getattr(handle, "input_spatial_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_input_spatial_dimensions())); + dims = cast>( + getattr(handle, "kernel_spatial_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_kernel_spatial_dimensions())); + dims = cast>( + getattr(handle, "output_spatial_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_output_spatial_dimensions())); + return true; + } catch (...) { + return false; + } + } +}; + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::DotDimensionNumbers, + const_name("xla::DotDimensionNumbers")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t flags, cleanup_list*) noexcept { + try { + std::vector dims = cast>( + getattr(handle, "lhs_contracting_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_lhs_contracting_dimensions())); + dims = cast>( + getattr(handle, "rhs_contracting_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_rhs_contracting_dimensions())); + dims = + cast>(getattr(handle, "lhs_batch_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_lhs_batch_dimensions())); + dims = + cast>(getattr(handle, "rhs_batch_dimensions")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_rhs_batch_dimensions())); + return true; + } catch (...) { + return false; + } + } +}; + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::GatherDimensionNumbers, + const_name("xla::GatherDimensionNumbers")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + std::vector dims; + dims = cast>(getattr(handle, "offset_dims")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_offset_dims())); + dims = + cast>(getattr(handle, "collapsed_slice_dims")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_collapsed_slice_dims())); + dims = cast>(getattr(handle, "start_index_map")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_start_index_map())); + value.set_index_vector_dim( + cast(getattr(handle, "index_vector_dim"))); + return true; + } catch (...) { + return false; + } + } +}; + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::ScatterDimensionNumbers, + const_name("xla::ScatterDimensionNumbers")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + std::vector dims; + dims = cast>(getattr(handle, "update_window_dims")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_update_window_dims())); + dims = + cast>(getattr(handle, "inserted_window_dims")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_inserted_window_dims())); + dims = cast>( + getattr(handle, "scatter_dims_to_operand_dims")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_scatter_dims_to_operand_dims())); + value.set_index_vector_dim( + cast(getattr(handle, "index_vector_dim"))); + return true; + } catch (...) { + return false; + } + } +}; + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::ReplicaGroup, + const_name("xla::ReplicaGroup")); -namespace py = pybind11; + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + auto dims = cast>(getattr(handle, "replica_ids")); + std::copy(dims.begin(), dims.end(), + tsl::protobuf::RepeatedFieldBackInserter( + value.mutable_replica_ids())); + return true; + } catch (...) { + return false; + } + } +}; -void BuildOpsSubmodule(py::module* m) { +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::PaddingConfig, + const_name("xla::PaddingConfig")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + sequence dimensions = borrow(getattr(handle, "dimensions")); + + for (const auto& dimension : dimensions) { + xla::PaddingConfig::PaddingConfigDimension* config_dim = + value.add_dimensions(); + config_dim->set_edge_padding_low( + cast(getattr(dimension, "edge_padding_low"))); + config_dim->set_edge_padding_high( + cast(getattr(dimension, "edge_padding_high"))); + config_dim->set_interior_padding( + cast(getattr(dimension, "interior_padding"))); + } + return true; + } catch (...) { + return false; + } + } +}; + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::PrecisionConfig, + const_name("xla::PrecisionConfig")); + + // PyObject -> C++ conversion. + bool from_python(handle handle, uint8_t, cleanup_list*) { + try { + if (handle.is_none()) { + return true; + } + + sequence operand_precisions = + borrow(getattr(handle, "operand_precision")); + + for (const auto& operand_precision : operand_precisions) { + value.add_operand_precision( + cast(operand_precision)); + } + return true; + } catch (...) { + return false; + } + } +}; + +} // namespace detail +} // namespace nanobind + +namespace xla { + +void BuildOpsSubmodule(nb::module_& m) { // ops submodule, containing free functions that add operators to an // XlaBuilder. - py::module ops = m->def_submodule("ops", "XLA operations"); + nb::module_ ops = m.def_submodule("ops", "XLA operations"); - py::enum_( + nb::enum_( ops, "TriangularSolveOptions_Transpose") .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) .value("ADJOINT", TriangularSolveOptions::ADJOINT); - py::enum_(ops, "RandomAlgorithm") + nb::enum_(ops, "RandomAlgorithm") .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT) .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY) .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX); - py::enum_(ops, "CustomCallSchedule") + nb::enum_(ops, "CustomCallSchedule") .value("SCHEDULE_NONE", CustomCallSchedule::SCHEDULE_NONE) .value("SCHEDULE_LATEST", CustomCallSchedule::SCHEDULE_LATEST) .value("SCHEDULE_EARLIEST", CustomCallSchedule::SCHEDULE_EARLIEST); - py::enum_(ops, "CustomCallApiVersion") + nb::enum_(ops, "CustomCallApiVersion") .value("API_VERSION_ORIGINAL", CustomCallApiVersion::API_VERSION_ORIGINAL) .value("API_VERSION_STATUS_RETURNING", CustomCallApiVersion::API_VERSION_STATUS_RETURNING) .value("API_VERSION_STATUS_RETURNING_UNIFIED", - CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED); - - ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); - ops.def("AllGather", &AllGather, py::arg("operand"), - py::arg("all_gather_dimension"), py::arg("shard_count"), - py::arg("replica_groups") = py::list(), - py::arg("channel_id") = std::nullopt, - py::arg("shape_with_layout") = std::nullopt, - py::arg("use_global_device_ids") = std::nullopt); + CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED) + .value("API_VERSION_TYPED_FFI", + CustomCallApiVersion::API_VERSION_TYPED_FFI); + + ops.def("AfterAll", &AfterAll, nb::arg("builder"), nb::arg("tokens")); + ops.def("AllGather", &AllGather, nb::arg("operand"), + nb::arg("all_gather_dimension"), nb::arg("shard_count"), + nb::arg("replica_groups") = nb::list(), + nb::arg("channel_id") = std::nullopt, + nb::arg("shape_with_layout") = std::nullopt, + nb::arg("use_global_device_ids") = std::nullopt); ops.def("AllReduce", static_cast, const std::optional&, const std::optional&, const std::optional)>(&AllReduce), - py::arg("operand"), py::arg("computation"), - py::arg("replica_groups") = py::list(), - py::arg("channel_id") = std::nullopt, - py::arg("shape_with_layout") = std::nullopt, - py::arg("use_global_device_ids") = std::nullopt); - ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"), - py::arg("computation"), py::arg("scatter_dimension"), - py::arg("shard_count"), py::arg("replica_groups") = py::list(), - py::arg("channel_id") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("use_global_device_ids") = std::nullopt); - ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), - py::arg("concat_dimension"), py::arg("split_count"), - py::arg("replica_groups") = py::list(), - py::arg("layout") = std::nullopt, - py::arg("channel_id") = std::nullopt); - ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"), - py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"), - py::arg("comparator"), py::arg("recall_target") = 0.9, - py::arg("aggregate_to_topk") = true, - py::arg("reduction_input_size_override") = -1); - ops.def("ApproxTopKFallback", &ApproxTopKFallback, py::arg("builder"), - py::arg("operands"), py::arg("init_values"), py::arg("top_k"), - py::arg("reduction_dim"), py::arg("comparator"), - py::arg("recall_target") = 0.9, py::arg("aggregate_to_topk") = true, - py::arg("reduction_input_size_override") = -1); + nb::arg("operand"), nb::arg("computation"), + nb::arg("replica_groups") = nb::list(), + nb::arg("channel_id") = std::nullopt, + nb::arg("shape_with_layout") = std::nullopt, + nb::arg("use_global_device_ids") = std::nullopt); + ops.def("ReduceScatter", &ReduceScatter, nb::arg("operand"), + nb::arg("computation"), nb::arg("scatter_dimension"), + nb::arg("shard_count"), nb::arg("replica_groups") = nb::list(), + nb::arg("channel_id") = std::nullopt, + nb::arg("layout") = std::nullopt, + nb::arg("use_global_device_ids") = std::nullopt); + ops.def("AllToAll", &AllToAll, nb::arg("operand"), nb::arg("split_dimension"), + nb::arg("concat_dimension"), nb::arg("split_count"), + nb::arg("replica_groups") = nb::list(), + nb::arg("layout") = std::nullopt, + nb::arg("channel_id") = std::nullopt); + ops.def("ApproxTopK", &ApproxTopK, nb::arg("builder"), nb::arg("operands"), + nb::arg("init_values"), nb::arg("top_k"), nb::arg("reduction_dim"), + nb::arg("comparator"), nb::arg("recall_target") = 0.9, + nb::arg("aggregate_to_topk") = true, + nb::arg("reduction_input_size_override") = -1); + ops.def("ApproxTopKFallback", &ApproxTopKFallback, nb::arg("builder"), + nb::arg("operands"), nb::arg("init_values"), nb::arg("top_k"), + nb::arg("reduction_dim"), nb::arg("comparator"), + nb::arg("recall_target") = 0.9, nb::arg("aggregate_to_topk") = true, + nb::arg("reduction_input_size_override") = -1); ops.def("ApproxTopKReductionOutputSize", xla::ValueOrThrowWrapper(ApproxTopKReductionOutputSize), - py::arg("input_size"), py::arg("rank"), py::arg("top_k"), - py::arg("recall_target"), py::arg("aggregate_to_topk") = true, - py::arg("input_size_override") = -1); - ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), - py::arg("new_element_type")); - ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); - ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), - py::arg("shape"), py::arg("broadcast_dimensions")); - ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), - py::arg("operands")); - ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); - ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); - ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); - ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), - py::arg("source_target_pairs"), py::arg("channel_id") = std::nullopt); - ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), - py::arg("dimension")); + nb::arg("input_size"), nb::arg("rank"), nb::arg("top_k"), + nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, + nb::arg("input_size_override") = -1); + ops.def("BitcastConvertType", &BitcastConvertType, nb::arg("operand"), + nb::arg("new_element_type")); + ops.def("Broadcast", &Broadcast, nb::arg("operand"), nb::arg("sizes")); + ops.def("BroadcastInDim", &BroadcastInDim, nb::arg("operand"), + nb::arg("shape"), nb::arg("broadcast_dimensions")); + ops.def("Call", &Call, nb::arg("builder"), nb::arg("computation"), + nb::arg("operands")); + ops.def("Cholesky", &Cholesky, nb::arg("a"), nb::arg("lower") = true); + ops.def("Clamp", &Clamp, nb::arg("min"), nb::arg("operand"), nb::arg("max")); + ops.def("Collapse", &Collapse, nb::arg("operand"), nb::arg("dimensions")); + ops.def("CollectivePermute", &CollectivePermute, nb::arg("operand"), + nb::arg("source_target_pairs"), nb::arg("channel_id") = std::nullopt); + ops.def("ConcatInDim", &ConcatInDim, nb::arg("builder"), nb::arg("operands"), + nb::arg("dimension")); ops.def("Conditional", static_cast, absl::Span)>(&Conditional), - py::arg("branch_index"), py::arg("branch_computations"), - py::arg("branch_operands")); + nb::arg("branch_index"), nb::arg("branch_computations"), + nb::arg("branch_operands")); ops.def("Conditional", static_cast(&Conditional), - py::arg("predicate"), py::arg("true_operand"), - py::arg("true_computation"), py::arg("false_operand"), - py::arg("false_computation")); - ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); - ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), - py::arg("literal")); - ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), - py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), - py::arg("lhs_dilation"), py::arg("rhs_dilation"), - py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, - py::arg("batch_group_count") = 1, - py::arg("precision_config") = nullptr, - py::arg("preferred_element_type") = std::nullopt, - py::arg("window_reversal") = std::nullopt); - ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), - py::arg("new_element_type")); - ops.def("CreateToken", &CreateToken, py::arg("builder")); + nb::arg("predicate"), nb::arg("true_operand"), + nb::arg("true_computation"), nb::arg("false_operand"), + nb::arg("false_computation")); + ops.def("Constant", &ConstantLiteral, nb::arg("builder"), nb::arg("literal")); + ops.def("ConstantLiteral", &ConstantLiteral, nb::arg("builder"), + nb::arg("literal")); + ops.def("ConvGeneralDilated", &ConvGeneralDilated, nb::arg("lhs"), + nb::arg("rhs"), nb::arg("window_strides"), nb::arg("padding"), + nb::arg("lhs_dilation"), nb::arg("rhs_dilation"), + nb::arg("dimension_numbers"), nb::arg("feature_group_count") = 1, + nb::arg("batch_group_count") = 1, + nb::arg("precision_config") = nullptr, + nb::arg("preferred_element_type") = std::nullopt, + nb::arg("window_reversal") = std::nullopt); + ops.def("ConvertElementType", &ConvertElementType, nb::arg("operand"), + nb::arg("new_element_type")); + ops.def("CreateToken", &CreateToken, nb::arg("builder")); ops.def("CrossReplicaSum", static_cast)>( &CrossReplicaSum), - py::arg("operand"), py::arg("replica_groups") = py::list()); + nb::arg("operand"), nb::arg("replica_groups") = nb::list()); ops.def( "CustomCall", - [](XlaBuilder* builder, const py::bytes& call_target_name, + [](XlaBuilder* builder, const nb::bytes& call_target_name, absl::Span operands, const Shape& shape, - const py::bytes& opaque, bool has_side_effect, + const nb::bytes& opaque, bool has_side_effect, CustomCallSchedule schedule, CustomCallApiVersion api_version) -> XlaOp { - return CustomCall(builder, call_target_name, operands, shape, opaque, - has_side_effect, /*output_operand_aliasing=*/{}, + std::string call_target_name_str(call_target_name.c_str(), + call_target_name.size()); + std::string opaque_str(opaque.c_str(), opaque.size()); + return CustomCall(builder, call_target_name_str, operands, shape, + opaque_str, has_side_effect, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, schedule, api_version); }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape"), py::arg("opaque") = py::bytes(""), - py::arg("has_side_effect") = false, - py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, - py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); + nb::arg("builder"), nb::arg("call_target_name"), nb::arg("operands"), + nb::arg("shape"), nb::arg("opaque") = nb::bytes(""), + nb::arg("has_side_effect") = false, + nb::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, + nb::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); ops.def( "CustomCallWithLayout", - [](XlaBuilder* builder, const py::bytes& call_target_name, + [](XlaBuilder* builder, const nb::bytes& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, - const py::bytes& opaque, bool has_side_effect, + const nb::bytes& opaque, bool has_side_effect, CustomCallSchedule schedule, CustomCallApiVersion api_version) -> XlaOp { + std::string call_target_name_str(call_target_name.c_str(), + call_target_name.size()); + std::string opaque_str(opaque.c_str(), opaque.size()); return CustomCallWithLayout( - builder, call_target_name, operands, shape_with_layout, - operand_shapes_with_layout, opaque, has_side_effect, + builder, call_target_name_str, operands, shape_with_layout, + operand_shapes_with_layout, opaque_str, has_side_effect, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, schedule, api_version); }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false, - py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, - py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); + nb::arg("builder"), nb::arg("call_target_name"), nb::arg("operands"), + nb::arg("shape_with_layout"), nb::arg("operand_shapes_with_layout"), + nb::arg("opaque") = nb::bytes(""), nb::arg("has_side_effect") = false, + nb::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, + nb::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); ops.def( "CustomCallWithAliasing", - [](XlaBuilder* builder, const py::bytes& call_target_name, + [](XlaBuilder* builder, const nb::bytes& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, - const py::bytes& opaque, bool has_side_effect, + const nb::bytes& opaque, bool has_side_effect, absl::Span>> output_operand_aliasing, const Literal* literal, CustomCallSchedule schedule, CustomCallApiVersion api_version) -> XlaOp { + std::string call_target_name_str(call_target_name.c_str(), + call_target_name.size()); + std::string opaque_str(opaque.c_str(), opaque.size()); return CustomCallWithLayout( - builder, call_target_name, operands, shape_with_layout, - operand_shapes_with_layout, opaque, has_side_effect, + builder, call_target_name_str, operands, shape_with_layout, + operand_shapes_with_layout, opaque_str, has_side_effect, output_operand_aliasing, literal, schedule, api_version); }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false, - py::arg("output_operand_aliasing"), py::arg("literal") = nullptr, - py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, - py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); + nb::arg("builder"), nb::arg("call_target_name"), nb::arg("operands"), + nb::arg("shape_with_layout"), nb::arg("operand_shapes_with_layout"), + nb::arg("opaque") = nb::bytes(""), nb::arg("has_side_effect") = false, + nb::arg("output_operand_aliasing"), nb::arg("literal") = nullptr, + nb::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, + nb::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); ops.def( "CustomCallWithComputation", - [](XlaBuilder* builder, const std::string& call_target_name, + [](XlaBuilder* builder, const nb::bytes& call_target_name, absl::Span operands, const XlaComputation& computation, - const Shape& shape, const std::string& opaque, bool has_side_effect, + const Shape& shape, const nb::bytes& opaque, bool has_side_effect, absl::Span>> output_operand_aliasing, const Literal* literal, CustomCallSchedule schedule, CustomCallApiVersion api_version) -> XlaOp { + std::string call_target_name_str(call_target_name.c_str(), + call_target_name.size()); + std::string opaque_str(opaque.c_str(), opaque.size()); return CustomCallWithComputation( - builder, call_target_name, operands, computation, shape, opaque, - has_side_effect, output_operand_aliasing, literal, schedule, - api_version); + builder, call_target_name_str, operands, computation, shape, + opaque_str, has_side_effect, output_operand_aliasing, literal, + schedule, api_version); }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("computation"), py::arg("shape"), - py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false, - py::arg("output_operand_aliasing"), py::arg("literal") = nullptr, - py::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, - py::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); - ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), - py::arg("precision_config") = nullptr, - py::arg("preferred_element_type") = std::nullopt); - ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), - py::arg("dimension_numbers"), py::arg("precision_config") = nullptr, - py::arg("preferred_element_type") = std::nullopt); + nb::arg("builder"), nb::arg("call_target_name"), nb::arg("operands"), + nb::arg("computation"), nb::arg("shape"), + nb::arg("opaque") = nb::bytes(""), nb::arg("has_side_effect") = false, + nb::arg("output_operand_aliasing"), nb::arg("literal") = nullptr, + nb::arg("schedule") = CustomCallSchedule::SCHEDULE_NONE, + nb::arg("api_version") = CustomCallApiVersion::API_VERSION_ORIGINAL); + ops.def("Dot", &Dot, nb::arg("lhs"), nb::arg("rhs"), + nb::arg("precision_config") = nullptr, + nb::arg("preferred_element_type") = std::nullopt); + ops.def("DotGeneral", &DotGeneral, nb::arg("lhs"), nb::arg("rhs"), + nb::arg("dimension_numbers"), nb::arg("precision_config") = nullptr, + nb::arg("preferred_element_type") = std::nullopt); ops.def("DynamicReshape", static_cast, absl::Span, const std::vector&)>(&DynamicReshape), - py::arg("operand"), py::arg("dim_sizes"), py::arg("new_size_bounds"), - py::arg("dims_are_dynamic")); + nb::arg("operand"), nb::arg("dim_sizes"), nb::arg("new_size_bounds"), + nb::arg("dims_are_dynamic")); ops.def("DynamicSlice", static_cast, absl::Span)>(&DynamicSlice), - py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); + nb::arg("operand"), nb::arg("start_indices"), nb::arg("slice_sizes")); ops.def("DynamicUpdateSlice", static_cast)>( &DynamicUpdateSlice), - py::arg("operand"), py::arg("update"), py::arg("start_indices")); + nb::arg("operand"), nb::arg("update"), nb::arg("start_indices")); ops.def( "Eigh", [](XlaOp a, bool lower, int64_t max_iter, float epsilon, @@ -262,51 +519,51 @@ void BuildOpsSubmodule(py::module* m) { SelfAdjointEig(a, lower, max_iter, epsilon, sort_eigenvalues); return std::make_pair(eigh.v, eigh.w); }, - py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 15, - py::arg("epsilon") = 1e-5, py::arg("sort_eigenvalues") = true); - ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), - py::arg("fft_length")); - ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), - py::arg("dimension_numbers"), py::arg("slice_sizes"), - py::arg("indices_are_sorted") = false); - ops.def("GetDimensionSize", &GetDimensionSize, py::arg("operand"), - py::arg("dimension")); - ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), - py::arg("index")); - ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), - py::arg("shape"), py::arg("config") = ""); + nb::arg("a"), nb::arg("lower") = true, nb::arg("max_iter") = 15, + nb::arg("epsilon") = 1e-5, nb::arg("sort_eigenvalues") = true); + ops.def("Fft", &Fft, nb::arg("operand"), nb::arg("fft_type"), + nb::arg("fft_length")); + ops.def("Gather", &Gather, nb::arg("a"), nb::arg("start_indices"), + nb::arg("dimension_numbers"), nb::arg("slice_sizes"), + nb::arg("indices_are_sorted") = false); + ops.def("GetDimensionSize", &GetDimensionSize, nb::arg("operand"), + nb::arg("dimension")); + ops.def("GetTupleElement", &GetTupleElement, nb::arg("tuple_data"), + nb::arg("index")); + ops.def("InfeedWithToken", &InfeedWithToken, nb::arg("token"), + nb::arg("shape"), nb::arg("config") = ""); ops.def("Iota", static_cast(&Iota), - py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); + nb::arg("builder"), nb::arg("shape"), nb::arg("iota_dimension")); ops.def("Iota", static_cast(&Iota), - py::arg("builder"), py::arg("type"), py::arg("size")); + nb::arg("builder"), nb::arg("type"), nb::arg("size")); ops.def( "LU", [](XlaOp a) -> std::tuple { LuDecompositionResult lu = LuDecomposition(a); return std::make_tuple(lu.lu, lu.pivots, lu.permutation); }, - py::arg("operand")); - ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), - py::arg("computation"), py::arg("dimensions"), - py::arg("static_operands") = py::list()); - ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); - ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), - py::arg("token"), py::arg("shape_with_layout"), - py::arg("outfeed_config") = ""); - ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), - py::arg("padding_config")); + nb::arg("operand")); + ops.def("Map", &Map, nb::arg("builder"), nb::arg("operands"), + nb::arg("computation"), nb::arg("dimensions"), + nb::arg("static_operands") = nb::list()); + ops.def("NextAfter", &NextAfter, nb::arg("from"), nb::arg("to")); + ops.def("OutfeedWithToken", &OutfeedWithToken, nb::arg("operand"), + nb::arg("token"), nb::arg("shape_with_layout"), + nb::arg("outfeed_config") = ""); + ops.def("Pad", &Pad, nb::arg("operand"), nb::arg("padding_value"), + nb::arg("padding_config")); ops.def("Parameter", static_cast&)>( &Parameter), - py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), - py::arg("name") = "", - py::arg("replicated_at_leaf_buffers") = std::vector()); + nb::arg("builder"), nb::arg("parameter_number"), nb::arg("shape"), + nb::arg("name") = "", + nb::arg("replicated_at_leaf_buffers") = std::vector()); ops.def("ProductOfElementaryHouseholderReflectors", - &ProductOfElementaryHouseholderReflectors, py::arg("a"), - py::arg("taus")); + &ProductOfElementaryHouseholderReflectors, nb::arg("a"), + nb::arg("taus")); ops.def( "QR", [](XlaOp a, bool full_matrices) -> std::pair { @@ -314,24 +571,24 @@ void BuildOpsSubmodule(py::module* m) { QrExplicit(a, full_matrices, q, r); return std::make_pair(q, r); }, - py::arg("operand"), py::arg("full_matrices")); + nb::arg("operand"), nb::arg("full_matrices")); ops.def( "QrDecomposition", [](XlaOp a) -> std::pair { QrDecomposition d = Qr(a); return std::make_pair(d.q_and_r, d.taus); }, - py::arg("operand")); - ops.def("RecvFromHost", &RecvFromHost, py::arg("token"), py::arg("shape"), - py::arg("handle")); + nb::arg("operand")); + ops.def("RecvFromHost", &RecvFromHost, nb::arg("token"), nb::arg("shape"), + nb::arg("handle")); ops.def("Reduce", static_cast, absl::Span, const XlaComputation&, absl::Span)>(&Reduce), - py::arg("builder"), py::arg("operands"), py::arg("init_values"), - py::arg("computation"), py::arg("dimensions_to_reduce")); - ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), - py::arg("exponent_bits"), py::arg("mantissa_bits")); + nb::arg("builder"), nb::arg("operands"), nb::arg("init_values"), + nb::arg("computation"), nb::arg("dimensions_to_reduce")); + ops.def("ReducePrecision", &ReducePrecision, nb::arg("operand"), + nb::arg("exponent_bits"), nb::arg("mantissa_bits")); ops.def("ReduceWindowWithGeneralPadding", static_cast, @@ -339,10 +596,10 @@ void BuildOpsSubmodule(py::module* m) { absl::Span, absl::Span>)>( &ReduceWindowWithGeneralPadding), - py::arg("operand"), py::arg("init_value"), py::arg("computation"), - py::arg("window_dimensions"), py::arg("window_strides"), - py::arg("base_dilations"), py::arg("window_dilations"), - py::arg("padding")); + nb::arg("operand"), nb::arg("init_value"), nb::arg("computation"), + nb::arg("window_dimensions"), nb::arg("window_strides"), + nb::arg("base_dilations"), nb::arg("window_dilations"), + nb::arg("padding")); ops.def("ReduceWindowWithGeneralPadding", static_cast, absl::Span, @@ -351,59 +608,59 @@ void BuildOpsSubmodule(py::module* m) { absl::Span, absl::Span>)>( &ReduceWindowWithGeneralPadding), - py::arg("operands"), py::arg("init_values"), py::arg("computation"), - py::arg("window_dimensions"), py::arg("window_strides"), - py::arg("base_dilations"), py::arg("window_dilations"), - py::arg("padding")); - ops.def("RemoveDynamicDimension", &RemoveDynamicDimension, py::arg("operand"), - py::arg("dimension")); - ops.def("ReplicaId", &ReplicaId, py::arg("builder")); + nb::arg("operands"), nb::arg("init_values"), nb::arg("computation"), + nb::arg("window_dimensions"), nb::arg("window_strides"), + nb::arg("base_dilations"), nb::arg("window_dilations"), + nb::arg("padding")); + ops.def("RemoveDynamicDimension", &RemoveDynamicDimension, nb::arg("operand"), + nb::arg("dimension")); + ops.def("ReplicaId", &ReplicaId, nb::arg("builder")); ops.def("Reshape", static_cast, absl::Span)>(&Reshape), - py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); + nb::arg("operand"), nb::arg("dimensions"), nb::arg("new_sizes")); ops.def("Reshape", static_cast)>(&Reshape), - py::arg("operand"), py::arg("new_sizes")); - ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); - ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"), - py::arg("initial_state"), py::arg("shape")); - ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), - py::arg("shape")); - ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), - py::arg("shape")); + nb::arg("operand"), nb::arg("new_sizes")); + ops.def("Rev", &Rev, nb::arg("operand"), nb::arg("dimensions")); + ops.def("RngBitGenerator", &RngBitGenerator, nb::arg("algorithm"), + nb::arg("initial_state"), nb::arg("shape")); + ops.def("RngNormal", &RngNormal, nb::arg("mu"), nb::arg("sigma"), + nb::arg("shape")); + ops.def("RngUniform", &RngUniform, nb::arg("a"), nb::arg("b"), + nb::arg("shape")); ops.def("Scatter", static_cast( &Scatter), - py::arg("input"), py::arg("scatter_indices"), py::arg("updates"), - py::arg("update_computation"), py::arg("dimension_numbers"), - py::arg("indices_are_sorted") = false, - py::arg("unique_indices") = false); + nb::arg("input"), nb::arg("scatter_indices"), nb::arg("updates"), + nb::arg("update_computation"), nb::arg("dimension_numbers"), + nb::arg("indices_are_sorted") = false, + nb::arg("unique_indices") = false); ops.def("Scatter", static_cast, XlaOp, absl::Span, const XlaComputation&, const ScatterDimensionNumbers&, bool, bool)>( &Scatter), - py::arg("inputs"), py::arg("scatter_indices"), py::arg("updates"), - py::arg("update_computation"), py::arg("dimension_numbers"), - py::arg("indices_are_sorted") = false, - py::arg("unique_indices") = false); - ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), - py::arg("on_false")); + nb::arg("inputs"), nb::arg("scatter_indices"), nb::arg("updates"), + nb::arg("update_computation"), nb::arg("dimension_numbers"), + nb::arg("indices_are_sorted") = false, + nb::arg("unique_indices") = false); + ops.def("Select", &Select, nb::arg("pred"), nb::arg("on_true"), + nb::arg("on_false")); ops.def("SelectAndScatterWithGeneralPadding", - &SelectAndScatterWithGeneralPadding, py::arg("operand"), - py::arg("select"), py::arg("window_dimensions"), - py::arg("window_strides"), py::arg("padding"), py::arg("source"), - py::arg("init_value"), py::arg("scatter")); - ops.def("SendToHost", &SendToHost, py::arg("operand"), py::arg("token"), - py::arg("shape_with_layout"), py::arg("handle")); - ops.def("SetDimensionSize", &SetDimensionSize, py::arg("operand"), - py::arg("val"), py::arg("dimension")); - ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), - py::arg("limit_indices"), py::arg("strides")); - ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), - py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); + &SelectAndScatterWithGeneralPadding, nb::arg("operand"), + nb::arg("select"), nb::arg("window_dimensions"), + nb::arg("window_strides"), nb::arg("padding"), nb::arg("source"), + nb::arg("init_value"), nb::arg("scatter")); + ops.def("SendToHost", &SendToHost, nb::arg("operand"), nb::arg("token"), + nb::arg("shape_with_layout"), nb::arg("handle")); + ops.def("SetDimensionSize", &SetDimensionSize, nb::arg("operand"), + nb::arg("val"), nb::arg("dimension")); + ops.def("Slice", &Slice, nb::arg("operand"), nb::arg("start_indices"), + nb::arg("limit_indices"), nb::arg("strides")); + ops.def("SliceInDim", &SliceInDim, nb::arg("operand"), nb::arg("start_index"), + nb::arg("limit_index"), nb::arg("stride"), nb::arg("dimno")); ops.def( "Sort", [](XlaBuilder* builder, absl::Span operands, @@ -426,9 +683,9 @@ void BuildOpsSubmodule(py::module* m) { } }); }, - py::arg("builder"), py::arg("operands"), - py::arg("comparator") = std::nullopt, py::arg("dimension") = -1, - py::arg("is_stable") = false); + nb::arg("builder"), nb::arg("operands"), + nb::arg("comparator") = std::nullopt, nb::arg("dimension") = -1, + nb::arg("is_stable") = false); ops.def( "SVD", [](XlaOp a, int64_t max_iter, @@ -436,28 +693,28 @@ void BuildOpsSubmodule(py::module* m) { auto svd = SVD(a, max_iter, epsilon); return std::make_tuple(svd.u, svd.d, svd.v); }, - py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); + nb::arg("a"), nb::arg("max_iter") = 100, nb::arg("epsilon") = 1e-6); ops.def( "TopK", [](XlaOp input, int64_t k) { return TopK(input, k, /*index_type=*/PrimitiveType::S32); }, - py::arg("input"), py::arg("k")); - ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); - ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), - py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), - py::arg("transpose_a")); - ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); - ops.def("While", &While, py::arg("condition"), py::arg("body"), - py::arg("init")); - - ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); - ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); - ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); - ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); - ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), - py::arg("b"), py::arg("x")); - ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q")); + nb::arg("input"), nb::arg("k")); + ops.def("Transpose", &Transpose, nb::arg("operand"), nb::arg("permutation")); + ops.def("TriangularSolve", &TriangularSolve, nb::arg("a"), nb::arg("b"), + nb::arg("left_side"), nb::arg("lower"), nb::arg("unit_diagonal"), + nb::arg("transpose_a")); + ops.def("Tuple", &Tuple, nb::arg("builder"), nb::arg("elements")); + ops.def("While", &While, nb::arg("condition"), nb::arg("body"), + nb::arg("init")); + + ops.def("Igamma", &Igamma, nb::arg("a"), nb::arg("x")); + ops.def("Igammac", &Igammac, nb::arg("a"), nb::arg("x")); + ops.def("IgammaGradA", &IgammaGradA, nb::arg("a"), nb::arg("x")); + ops.def("RandomGammaGrad", &RandomGammaGrad, nb::arg("a"), nb::arg("x")); + ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, nb::arg("a"), + nb::arg("b"), nb::arg("x")); + ops.def("Zeta", &Zeta, nb::arg("x"), nb::arg("q")); #define BINARY_OP(op) \ ops.def( \ @@ -465,8 +722,8 @@ void BuildOpsSubmodule(py::module* m) { [](XlaOp a, XlaOp b, std::optional> dims) { \ return dims ? op(a, b, *dims) : op(a, b); \ }, \ - py::arg("lhs"), py::arg("rhs"), \ - py::arg("broadcast_dimensions") = std::nullopt) + nb::arg("lhs"), nb::arg("rhs"), \ + nb::arg("broadcast_dimensions") = std::nullopt) BINARY_OP(Eq); BINARY_OP(Ne); BINARY_OP(Ge); diff --git a/xla/python/ops.h b/xla/python/ops.h index 636b5be20346c..715fc6a99e7c5 100644 --- a/xla/python/ops.h +++ b/xla/python/ops.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_OPS_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildOpsSubmodule(pybind11::module* m); +void BuildOpsSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/outfeed_receiver.cc b/xla/python/outfeed_receiver.cc index 7e965a4781e3b..2f30052b87e54 100644 --- a/xla/python/outfeed_receiver.cc +++ b/xla/python/outfeed_receiver.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -171,10 +171,10 @@ class OutfeedReceiverImpl { void Start(); - StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx); + absl::StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, + std::vector arrays, + uint32_t device_idx); private: bool CallbackQueueHasSpace() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -192,8 +192,8 @@ class OutfeedReceiverImpl { Status SendShutdownOutfeedHeader(int device_idx); // Receives a raw Literal from a device outfeed. - StatusOr> ReceiveRawFromOutfeed(PjRtDevice* device, - const Shape& shape); + absl::StatusOr> ReceiveRawFromOutfeed( + PjRtDevice* device, const Shape& shape); // Enqueues received data in the callbaback queue. void EnqueueReceivedData(uint32_t device_idx, @@ -352,8 +352,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData( callback_queues_[device_idx].push(std::move(received)); } -StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( - PjRtDevice* device, const Shape& shape) { +absl::StatusOr> +OutfeedReceiverImpl::ReceiveRawFromOutfeed(PjRtDevice* device, + const Shape& shape) { auto literal = std::make_unique(shape); TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get())); return literal; @@ -442,7 +443,7 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { return OkStatus(); } -StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( +absl::StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( XlaBuilder* builder, XlaOp token, uint32_t consumer_id, std::vector arrays, uint32_t device_idx) { XlaOp data = Tuple(builder, std::move(arrays)); @@ -498,11 +499,9 @@ OutfeedReceiver::~OutfeedReceiver() = default; void OutfeedReceiver::Start() { p_impl_->Start(); } -StatusOr OutfeedReceiver::AddOutfeedToBuilder(XlaBuilder* builder, - XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx) { +absl::StatusOr OutfeedReceiver::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector arrays, uint32_t device_idx) { if (consumer_id == kOutfeedCidShutdown) { return InvalidArgument("Consumer ID cannot be a reserved value: %d", consumer_id); diff --git a/xla/python/outfeed_receiver.h b/xla/python/outfeed_receiver.h index c15814f3ff20e..4ca1f959e3e9e 100644 --- a/xla/python/outfeed_receiver.h +++ b/xla/python/outfeed_receiver.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -69,10 +69,10 @@ class OutfeedReceiver { // Returns error status if the outfeed shape is different than the // previously used shape for the same consumer_id or the consumer id is // invalid. - StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx); + absl::StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, + std::vector arrays, + uint32_t device_idx); private: std::unique_ptr p_impl_; diff --git a/xla/python/outfeed_receiver_py.cc b/xla/python/outfeed_receiver_py.cc index addd479f7bb6e..6cddc4c15fbf0 100644 --- a/xla/python/outfeed_receiver_py.cc +++ b/xla/python/outfeed_receiver_py.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,21 +23,28 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/functional.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/function.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep #include "xla/client/executable_build_options.h" #include "xla/client/xla_builder.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/outfeed_receiver.h" #include "xla/python/py_client.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" +#include "tsl/platform/logging.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; namespace { @@ -47,11 +54,11 @@ class OutfeedReceiverForPython { public: // A callback to Python takes: consumer id, received literal. using CallbackToPython = - std::function, uint32_t, pybind11::object)>; + std::function, uint32_t, nb::object)>; OutfeedReceiverForPython( CallbackToPython callback_python, - std::vector> clients, + std::vector> clients, ssize_t max_callback_queue_size_bytes, const std::optional& executable_build_options) : callback_python_(std::move(callback_python)), @@ -63,7 +70,7 @@ class OutfeedReceiverForPython { }; std::vector client_ptrs(clients_.size()); absl::c_transform(clients_, client_ptrs.begin(), - [](const std::shared_ptr& client) { + [](const nb_class_ptr& client) { return client->pjrt_client(); }); outfeed_receiver_ = std::make_unique( @@ -84,15 +91,16 @@ class OutfeedReceiverForPython { absl::MutexLock lock(&mu_); outfeed_receiver_shutting_down_ = true; } - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; outfeed_receiver_ = nullptr; // Shutdown the outfeed receiver. } void Start() { outfeed_receiver_->Start(); } - StatusOr AddOutfeed(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, std::vector arrays, - uint32_t device_idx) { + absl::StatusOr AddOutfeed(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, + std::vector arrays, + uint32_t device_idx) { return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id, arrays, device_idx); } @@ -108,15 +116,16 @@ class OutfeedReceiverForPython { } // We expect the number of clients to be small, so an O(n) search is fine. auto it = absl::c_find_if( - clients_, [device](const std::shared_ptr& client) { + clients_, [device](const nb_class_ptr& client) { return client->pjrt_client() == device->client(); }); CHECK(it != clients_.end()); - py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython - py::object literal_python = LiteralToPython(std::move(literal)).value(); + PyClient* client = it->get(); + nb::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython + nb::object literal_python = LiteralToPython(std::move(literal)).value(); // The callback_ should handle all exceptions in user-code. If we get // an exception here, it is a bug in the callback and we should stop. - callback_python_(WrapWithClient(*it, device), consumer_id, + callback_python_(client->GetPyDevice(device), consumer_id, std::move(literal_python)); } @@ -124,31 +133,32 @@ class OutfeedReceiverForPython { CallbackToPython callback_python_; absl::Mutex mu_; bool outfeed_receiver_shutting_down_ ABSL_GUARDED_BY(mu_) = false; - std::vector> clients_; + std::vector> clients_; std::unique_ptr outfeed_receiver_; }; } // namespace -void BuildOutfeedReceiverSubmodule(py::module* m) { - py::module outfeed_receiver = - m->def_submodule("outfeed_receiver", "Outfeed receiver"); +void BuildOutfeedReceiverSubmodule(nb::module_& m) { + nb::module_ outfeed_receiver = + m.def_submodule("outfeed_receiver", "Outfeed receiver"); outfeed_receiver.def( "start", [](OutfeedReceiverForPython::CallbackToPython callback_to_python, - std::vector> clients, - ssize_t max_callback_queue_size_bytes, + nb::sequence clients, ssize_t max_callback_queue_size_bytes, std::optional executable_build_options) -> std::unique_ptr { auto server = std::make_unique( - callback_to_python, clients, max_callback_queue_size_bytes, - executable_build_options); + std::move(callback_to_python), + SequenceToVector>(clients), + max_callback_queue_size_bytes, executable_build_options); + nb::gil_scoped_release gil_release; server->Start(); return server; }, - py::arg("callback_to_python"), py::arg("backends"), - py::arg("max_queue_size_bytes") = 256 * 1024 * 1024, - py::arg("executable_build_options") = std::nullopt, + nb::arg("callback_to_python"), nb::arg("backends"), + nb::arg("max_queue_size_bytes") = 256 * 1024 * 1024, + nb::arg("executable_build_options").none() = nb::none(), R"(Starts a multithreaded outfeed receiver. There is one thread for each of the specified devices. When Python @@ -163,23 +173,22 @@ void BuildOutfeedReceiverSubmodule(py::module* m) { * max_queue_size_bytes: an optional integer to bound the maximum size of arrays in the callback queue. When this limit is reached the device listener pauses. - )", - py::call_guard()); + )"); - py::class_ outfeed_receiver_class( + nb::class_ outfeed_receiver_class( outfeed_receiver, "OutfeedReceiverForPython"); outfeed_receiver_class.def( "add_outfeed", xla::ValueOrThrowWrapper(&OutfeedReceiverForPython::AddOutfeed), - py::arg("builder"), py::arg("token"), py::arg("consumer_id"), - py::arg("arrays"), py::arg("device_idx"), + nb::arg("builder"), nb::arg("token"), nb::arg("consumer_id"), + nb::arg("arrays"), nb::arg("device_idx"), R"(Adds an outfeed into the given computation builder. Has the side-effect of registering the sent shape along with the consumer ID. Returns error if the outfeed shape is not compatible with previously used shape for the same consumer ID.)", - py::call_guard()); + nb::call_guard()); } } // namespace xla diff --git a/xla/python/outfeed_receiver_py.h b/xla/python/outfeed_receiver_py.h index c3d3c03cbe8c3..4605204102cda 100644 --- a/xla/python/outfeed_receiver_py.h +++ b/xla/python/outfeed_receiver_py.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildOutfeedReceiverSubmodule(pybind11::module* m); +void BuildOutfeedReceiverSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/outfeed_receiver_test.cc b/xla/python/outfeed_receiver_test.cc index 0f67839f4cdfc..cb7c0040f64cd 100644 --- a/xla/python/outfeed_receiver_test.cc +++ b/xla/python/outfeed_receiver_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -258,7 +258,7 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); XlaOp data1 = Iota(&builder, shape1, 0); // A different shape for the same consumer ID. - StatusOr send1 = outfeed_receiver->AddOutfeedToBuilder( + absl::StatusOr send1 = outfeed_receiver->AddOutfeedToBuilder( &builder, send0, consumer_id0, {data1}, 0); EXPECT_FALSE(send1.ok()); EXPECT_THAT(send1.status().ToString(), @@ -283,7 +283,7 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) { XlaBuilder builder("execute_test_outfeed"); const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); XlaOp data0 = Iota(&builder, shape0, 0); - StatusOr send0 = outfeed_receiver->AddOutfeedToBuilder( + absl::StatusOr send0 = outfeed_receiver->AddOutfeedToBuilder( &builder, CreateToken(&builder), 0, {data0}, 0); EXPECT_FALSE(send0.ok()); diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index 76e3dfc5a9e99..e2e30cfe0e520 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,48 +15,76 @@ limitations under the License. #include "xla/python/pjit.h" +#include + #include +#include +#include #include #include #include #include #include +#include #include // NOLINT -#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" #include "xla/python/jax_jit.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" #include "xla/python/py_array.h" #include "xla/python/py_executable.h" #include "xla/python/py_values.h" -#include "xla/python/python_utils.h" +#include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" +#include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" -#include "xla/python/util.h" +#include "xla/util.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace jax { namespace { -namespace py = pybind11; +namespace nb = nanobind; struct PjitCacheEntry { explicit PjitCacheEntry(xla::PyTreeRegistry* registry) : out_pytree_def(registry) {} std::shared_ptr executable; - std::vector in_shardings; - std::vector out_avals; - std::vector out_dtypes; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; std::vector> out_shapes; std::vector out_weak_types; - std::vector out_shardings; + std::vector out_shardings; std::vector out_committed; xla::PyTreeDef out_pytree_def; // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` @@ -92,7 +120,7 @@ class PjitFunctionCache { // We include as part of the cache key `donate_argnums` (and any other fields // that aren't subsumed by the CallSignature we compute for each call). - std::shared_ptr Lookup(pybind11::handle function, + std::shared_ptr Lookup(nb::handle function, absl::Span donate_argnums); std::shared_ptr DefaultCache(); @@ -102,15 +130,15 @@ class PjitFunctionCache { private: struct Key { - pybind11::handle function; // Does not hold a reference. + nb::handle function; // Does not hold a reference. // Other fields that are part of the arguments to `jit`, but are not // otherwise part of CallSignature. std::vector donate_argnums; bool operator==(const Key& other) const { - return std::tie(function, donate_argnums) == - std::tie(other.function, other.donate_argnums); + return function.ptr() == other.function.ptr() && + donate_argnums == other.donate_argnums; } }; @@ -131,7 +159,7 @@ class PjitFunctionCache { // We use a weak pointer because we want to allow caching across multiple // calls to `pjit(f)` if `f` remains alive, but we do not want the cache // to keep `f` alive if all other references are dropped. - pybind11::weakref weakref; + std::optional weakref; }; Cache::LRUList lru_list_; @@ -145,7 +173,7 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { } std::shared_ptr PjitFunctionCache::Lookup( - pybind11::handle function, absl::Span donate_argnums) { + nb::handle function, absl::Span donate_argnums) { Key key; key.function = function; key.donate_argnums = @@ -155,15 +183,15 @@ std::shared_ptr PjitFunctionCache::Lookup( return insert.first->second->cache; } std::shared_ptr cache = std::make_shared(&lru_list_); - pybind11::cpp_function callback( - [this, key{std::move(key)}](pybind11::handle weakref) { + auto callback = + nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) { functions_.erase(key); }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); if (weakref) { std::unique_ptr& entry = insert.first->second; entry = std::make_unique(cache); - entry->weakref = pybind11::reinterpret_steal(weakref); + entry->weakref = nb::steal(weakref); } else { PyErr_Clear(); // `function` is not weak-referenceable. Don't bother adding it to the @@ -176,11 +204,12 @@ std::shared_ptr PjitFunctionCache::Lookup( class PjitFunction { public: - PjitFunction(std::string function_name, std::optional fun, - py::function cache_miss, std::vector static_argnums, - std::vector static_argnames, + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, std::vector donate_argnums, std::shared_ptr pytree_registry, + nb::callable shard_arg_fallback, std::shared_ptr cache); ~PjitFunction(); @@ -189,39 +218,40 @@ class PjitFunction { PjitFunction(PjitFunction&&) = default; PjitFunction& operator=(PjitFunction&&) = default; - // pybind11::object typed subclass for PjitFunction objects. - class pyobject : public py::object { + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { public: - PYBIND11_OBJECT(pyobject, // NOLINT - py::object, PjitFunction::IsPjitFunction); + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); pyobject() = default; PjitFunction* func() const { return PjitFunction::AsPjitFunctionUnchecked(*this); } }; - // Alias as ::object; outside the scope above we won't confuse pybind11's + // Alias as ::object; outside the scope above we won't confuse nanobind's // macros. using object = pyobject; // Returns true if `h` is a PjitFunction. - static bool IsPjitFunction(py::handle handle); + static bool IsPjitFunction(nb::handle handle); // Converts `handle` to a PjitFunction*. Does not do any checking. - static PjitFunction* AsPjitFunctionUnchecked(py::handle handle); + static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); - xla::StatusOr Call(py::handle callable, PyObject* const* args, - size_t nargs, PyObject* kwnames); + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); void ClearPythonReferences(); const std::string& function_name() const { return function_name_; } - const std::optional& fun() const { return fun_; } - const py::function& cache_miss() const { return cache_miss_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } const std::shared_ptr& pytree_registry() const { return pytree_registry_; } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } const std::vector& static_argnums() const { return static_argnums_; } - const std::vector& static_argnames() const { + const std::vector& static_argnames() const { return static_argnames_; } const std::vector& donate_argnums() const { return donate_argnums_; } @@ -231,31 +261,36 @@ class PjitFunction { void ClearCache() { executables_->Clear(); } - py::object PythonSignature() { + nb::object PythonSignature() { if (!fun_.has_value()) { - throw py::value_error(absl::StrFormat( - "Calling __signature__ on PjitFunction(%s) not supported.", - function_name_)); + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); } - static const auto* inspect = new py::module(py::module::import("inspect")); + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); return inspect->attr("signature")(*fun_); } private: - xla::Status UpdateArgsSignature(ParsedArgumentsAsBuffers& arguments); + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); void PopulateCacheEntry(PjitCacheEntry& cache_entry, - const CallSignature& signature, - const py::tuple& out_and_fastpath_data); + const nb::tuple& out_and_fastpath_data); std::string function_name_; - std::optional fun_; - py::function cache_miss_; + std::optional fun_; + nb::callable cache_miss_; std::vector static_argnums_; - std::vector static_argnames_; + std::vector static_argnames_; std::vector donate_argnums_; std::shared_ptr pytree_registry_; + nb::callable shard_arg_fallback_; std::shared_ptr cache_; std::shared_ptr executables_; }; @@ -283,25 +318,26 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() { return *store; } -PjitFunction::PjitFunction(std::string function_name, - std::optional fun, - py::function cache_miss, - std::vector static_argnums, - std::vector static_argnames, - std::vector donate_argnums, - std::shared_ptr pytree_registry, - std::shared_ptr cache) +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, std::vector donate_argnums, + std::shared_ptr pytree_registry, + nb::callable shard_arg_fallback, std::shared_ptr cache) : function_name_(std::move(function_name)), fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), static_argnums_(std::move(static_argnums)), - static_argnames_(std::move(static_argnames)), donate_argnums_(donate_argnums), pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), cache_(std::move(cache)) { std::sort(static_argnums_.begin(), static_argnums_.end()); - for (py::str& s : static_argnames_) { - PyUnicode_InternInPlace(&s.ptr()); + static_argnames.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); } if (!fun_.has_value()) { executables_ = cache_->DefaultCache(); @@ -314,37 +350,54 @@ PjitFunction::PjitFunction(std::string function_name, PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } +void CallShardArgFallback( + nb::handle arg, nb::handle sharding, const nb::callable& fallback, + std::vector>& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + // Prepares the input PjRtBuffers from the python arguments. This is equivalent // to shard_args() in pxla.py but for only a few supported cases. -xla::StatusOr>> +absl::StatusOr>> PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, - ParsedArgumentsAsBuffers& arguments, - const std::vector& kept_args) { - const auto& addressable_devices = executable.AddressableDevices(); - int num_args = arguments.flat_dynamic_args.size(); + absl::Span flat_dynamic_args, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + int num_args = flat_dynamic_args.size(); std::vector> num_args_arrays; num_args_arrays.reserve(num_args); xla::DevicePutOptions options; - options.squash_64bit_types = !arguments.signature.jax_enable_x64; + options.squash_64bit_types = !enable_x64; options.allow_zero_copy = true; xla::PjRtDevice* data_device = nullptr; if (executable.ifrt_loaded_executable()->num_devices() == 1) { data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; } - + int dce_i = 0; for (int i = 0; i < num_args; ++i) { if (!kept_args[i]) { continue; } - const py::object& arg = arguments.flat_dynamic_args[i]; + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; auto transfer_guard_formatter = [] { return std::string(""); }; - if (arg.get_type() != xla::PyArray::type()) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { if (data_device != nullptr) { - py::handle arg = arguments.flat_dynamic_args[i]; TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); TF_ASSIGN_OR_RETURN( @@ -354,16 +407,18 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, num_args_arrays.push_back(std::move(on_device.ifrt_array)); if (on_device.owning_pybuffer) { - arguments.keep_alive_objects.push_back( - std::move(on_device.owning_pybuffer)); + keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); } continue; + } else { + CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; } - - return xla::Unimplemented("Unhandled non PyArray argument."); } - xla::PyArray py_array = arg; + xla::PyArray py_array = nb::borrow(arg); const auto& sharding = py_array.sharding(); int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); @@ -373,27 +428,30 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, DCHECK(py_array.committed() || (!py_array.committed() && sharding_num_devices == 1)); - if (sharding.get_type() == jax::PmapSharding::type()) { - return xla::Unimplemented( - "Handling PyArray in PmapSharding is not implemented."); + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; } if (py_array.num_shards() != addressable_devices.size()) { - return xla::InvalidArgument( - "Expected PyArray to have %d shards, but got %d", - addressable_devices.size(), py_array.num_shards()); + CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; } xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); // PyArray inputs should have already been checked in // `xla::PyArgSignatureOfValue()` called by - // `PjitFunction::UpdateArgsSignature()`. + // `PjitFunction::ComputeCallSignature()`. DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; - if (sharding_num_devices == 1 && ifrt_array->sharding().devices().front() != - addressable_devices[0].get()) { + if (sharding_num_devices == 1 && + ifrt_array->sharding().devices().front() != addressable_devices[0]) { xla::ifrt::DeviceList::Devices ifrt_devices; - ifrt_devices.push_back(addressable_devices[0].get()); + ifrt_devices.push_back(addressable_devices[0]); auto sharding = xla::ifrt::OpaqueSharding::Create( xla::ifrt::DeviceList(std::move(ifrt_devices)), ifrt_array->sharding().memory_kind()); @@ -406,18 +464,17 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, num_args_arrays.push_back(tsl::FormRef(ifrt_array)); } - arguments.keep_alive_objects.push_back(arg); + keep_alive_objects.push_back(arg); } return num_args_arrays; } -xla::StatusOr PjitFunction::Call(py::handle callable, - PyObject* const* args, - size_t nargs, PyObject* kwnames) { +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { tsl::profiler::TraceMe traceme( [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); - ParsedArgumentsAsBuffers arguments; // Make sure we trigger a garbage collection on JIT function calls. Otherwise // code like @@ -429,19 +486,20 @@ xla::StatusOr PjitFunction::Call(py::handle callable, if (GetDisableJit()) { if (!fun_.has_value()) { - throw py::value_error( + throw nb::value_error( absl::StrFormat("Disable jit is not supported in the AOT path since " "the function is not available for (%s)", - function_name_)); + function_name_) + .c_str()); } - return py::reinterpret_steal( + return nb::steal( PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); } // Calls the cache_miss_ function. This just calls the Python function; it may // return nullptr value if a Python exception is thrown. - auto cache_miss = [&]() -> py::tuple { - return py::reinterpret_steal( + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); }; @@ -449,11 +507,11 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // the fastpath data. If the cache miss returns a Python error, returns // nullptr and leaves the Python error set. auto fallback_to_cache_miss = [&]() { - py::tuple cache_miss_output = cache_miss(); + nb::tuple cache_miss_output = cache_miss(); if (!cache_miss_output.ptr()) { - return py::object(); + return nb::object(); } - return py::object(cache_miss_output[0]); + return nb::object(cache_miss_output[0]); }; size_t num_positional_args = PyVectorcall_NARGS(nargs); @@ -461,9 +519,13 @@ xla::StatusOr PjitFunction::Call(py::handle callable, absl::Span positional_args(args, num_positional_args); absl::Span keyword_args(args + num_positional_args, num_keyword_args); - auto status = - ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, - static_argnames_, pytree_registry_.get(), arguments); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); if (!status.ok()) { VLOG(2) << "ParseArguments failed: " << status; return fallback_to_cache_miss(); @@ -473,18 +535,15 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it // will fallback to python. For jit, numpy arrays and scalars are also // allowed, which we will check later. - for (const auto& arg : arguments.flat_dynamic_args) { - if (arg.get_type() != xla::PyArray::type()) { + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { continue; } - xla::PyArray py_array = arg; - if (!py_array.fastpath_enabled()) { - return fallback_to_cache_miss(); - } + xla::PyArray py_array = nb::borrow(arg); // Only allow committed PyArray in cpp pjit for now as the logic on handling - // sharding for uncommited PyArray is complicated and still under + // sharding for uncommitted PyArray is complicated and still under // development. // // TODO(chky): Consider support uncommitted PyArray in cpp when the python @@ -497,17 +556,17 @@ xla::StatusOr PjitFunction::Call(py::handle callable, } } - status = UpdateArgsSignature(arguments); + status = ComputeCallSignature(flat_dynamic_args, call_signature); if (!status.ok()) { - VLOG(2) << "UpdateArgsSignature failed: " << status; + VLOG(2) << "ComputeCallSignature failed: " << status; return fallback_to_cache_miss(); } - VLOG(2) << "CallSignature:\n" << arguments.signature.DebugString(); + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); bool inserted = false; std::shared_ptr cache_entry = executables_->GetOrCreateIfAbsent( - arguments.signature, [this, &inserted](const CallSignature& unused) { + call_signature, [this, &inserted](const CallSignature& unused) { inserted = true; return std::make_shared(pytree_registry_.get()); }); @@ -516,19 +575,19 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // In case of several threads attempting to compile the executable, only // the one that inserted the item will perform the compilation. if (inserted) { - py::object out_and_fastpath_data; - py::tuple out_tuple; - VLOG(2) << "Cache miss for " << arguments.signature.DebugString(); + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); try { // Calls Python and may release the GIL. May also throw if // compilation/tracing fails. out_and_fastpath_data = cache_miss(); if (!out_and_fastpath_data.ptr()) { - throw py::error_already_set(); + throw nb::python_error(); } - out_tuple = py::cast(out_and_fastpath_data); + out_tuple = nb::cast(out_and_fastpath_data); - PopulateCacheEntry(*cache_entry, arguments.signature, out_tuple); + PopulateCacheEntry(*cache_entry, out_tuple); } catch (const std::exception& e) { VLOG(2) << "cache miss fail: " << e.what(); cache_entry->fall_back_to_python = true; @@ -540,17 +599,17 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // We have already computed the result in the miss path so we can return // it. We are even *required* to do so if there are donated arguments, // because any donated buffers will now be invalid. - return py::object(out_tuple[0]); + return nb::object(out_tuple[0]); } else { if (cache_entry->thread_id == std::this_thread::get_id()) { auto error_string = absl::StrCat("Recursively calling jit: ", - arguments.signature.DebugString()); + call_signature.DebugString()); PyErr_SetString(PyExc_RecursionError, error_string.c_str()); - throw pybind11::error_already_set(); + throw nb::python_error(); } // Release the GIL while we wait, making sure the compile thread can // lock it. - py::gil_scoped_release release; + nb::gil_scoped_release release; cache_entry->compilation_complete.WaitForNotification(); } } @@ -561,8 +620,10 @@ xla::StatusOr PjitFunction::Call(py::handle callable, } // A vector of [num_inputs]. - auto num_args_arrays = PrepareIfrtInputs(*cache_entry->executable, arguments, - cache_entry->kept_var_bitvec); + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, + cache_entry->in_shardings, shard_arg_fallback_, keep_alive_objects); if (!num_args_arrays.ok()) { VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); @@ -572,7 +633,7 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // A vector of [num_outputs]. std::vector> output_arrays; { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto result, cache_entry->executable->ifrt_executable()->Execute( absl::MakeSpan(*num_args_arrays), @@ -586,7 +647,7 @@ xla::StatusOr PjitFunction::Call(py::handle callable, // Convert the ifrt::Array objects to PyArray. int num_outputs = output_arrays.size(); - absl::InlinedVector outputs; + absl::InlinedVector outputs; outputs.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { // Creating the PyArray result. In addition to the IFRT arrays, the metadata @@ -602,73 +663,75 @@ xla::StatusOr PjitFunction::Call(py::handle callable, outputs.push_back(std::move(py_array)); } - py::object out = cache_entry->out_pytree_def.Unflatten(outputs); + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); // If there is a post-hook function, call it with the inputs and the outputs. - std::optional post_hook = GetPostHook(); + std::optional post_hook = GetPostHook(); if (post_hook) { - py::tuple args_tuple(num_positional_args); + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); for (size_t i = 0; i < num_positional_args; ++i) { - args_tuple[i] = args[i]; + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); } - py::dict kwargs; + nb::dict kwargs; if (kwnames) { for (size_t i = 0; i < num_keyword_args; ++i) { - kwargs[py::handle(PyTuple_GET_ITEM(kwnames, i))] = - args[num_positional_args + i]; + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); } } - (*post_hook)(callable, args_tuple, kwargs, out); + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); } return out; } -xla::Status PjitFunction::UpdateArgsSignature( - ParsedArgumentsAsBuffers& arguments) { - arguments.signature.function_name = function_name_; +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; // Get dynamic argument signatures. JitState& global_state = jax::GlobalJitState(); JitState& tls = jax::ThreadLocalJitState(); bool jax_enable_x64 = GetEnableX64(); - arguments.signature.default_device = GetDefaultDevice(); - arguments.signature.jax_enable_x64 = jax_enable_x64; - arguments.signature.jax_enable_memories = GetEnableMemories(); + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + signature.jax_enable_memories = GetEnableMemories(); - auto& dynamic_arg_signatures = arguments.signature.dynamic_arg_signatures; - dynamic_arg_signatures.reserve(arguments.flat_dynamic_args.size()); - auto& dynamic_arg_shardings = arguments.signature.dynamic_arg_shardings; - dynamic_arg_shardings.reserve(arguments.flat_dynamic_args.size()); + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); - for (py::handle arg : arguments.flat_dynamic_args) { - TF_ASSIGN_OR_RETURN(auto signature, + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, xla::PyArgSignatureOfValue(arg, jax_enable_x64)); - arguments.signature.dynamic_arg_signatures.push_back(std::move(signature)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); // It should be already checked previously in the entry point of // PjitFunction::Call(). - if (arg.get_type() == xla::PyArray::type()) { - auto py_array = py::reinterpret_borrow(arg); - - arguments.signature.dynamic_arg_shardings.push_back(py_array.sharding()); - arguments.signature.committed_args.push_back(py_array.committed()); + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + signature.committed_args.push_back(py_array.committed()); } else { - arguments.signature.dynamic_arg_shardings.push_back(py::none()); - arguments.signature.committed_args.push_back(false); + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.committed_args.push_back(false); } } - arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context; - arguments.signature.global_extra_jit_context = global_state.extra_jit_context; + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; - return xla::OkStatus(); + return absl::OkStatus(); } void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, - const CallSignature& signature, - const py::tuple& out_and_fastpath_data) { + const nb::tuple& out_and_fastpath_data) { DCHECK_EQ(out_and_fastpath_data.size(), 2); if (out_and_fastpath_data[1].is_none()) { @@ -677,82 +740,89 @@ void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, return; } - py::tuple fastpath_data = py::cast(out_and_fastpath_data[1]); + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); - cache_entry.executable = py::cast>( + cache_entry.executable = nb::cast>( fastpath_data.attr("xla_executable")); - py::list in_shardings = fastpath_data.attr("in_shardings"); - cache_entry.in_shardings.reserve(in_shardings.size()); - for (py::handle sharding : in_shardings) { - cache_entry.in_shardings.push_back( - py::reinterpret_borrow(sharding)); + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); } - py::list out_shardings = fastpath_data.attr("out_shardings"); - cache_entry.out_shardings.reserve(out_shardings.size()); - for (py::handle sharding : out_shardings) { - cache_entry.out_shardings.push_back( - py::reinterpret_borrow(sharding)); + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); } - py::list out_committed = fastpath_data.attr("out_committed"); - cache_entry.out_committed.reserve(out_committed.size()); - for (py::handle c : out_committed) { - cache_entry.out_committed.push_back(py::cast(c)); + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); } - py::list out_avals = fastpath_data.attr("out_avals"); - cache_entry.out_avals.reserve(out_avals.size()); - cache_entry.out_dtypes.reserve(out_avals.size()); - cache_entry.out_shapes.reserve(out_avals.size()); - cache_entry.out_weak_types.reserve(out_avals.size()); - for (py::handle aval : out_avals) { - cache_entry.out_avals.push_back(py::reinterpret_borrow(aval)); + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); cache_entry.out_dtypes.push_back(aval.attr("dtype")); cache_entry.out_shapes.push_back( - py::cast>(aval.attr("shape"))); + nb::cast>(aval.attr("shape"))); cache_entry.out_weak_types.push_back( - py::cast(aval.attr("weak_type"))); + nb::cast(aval.attr("weak_type"))); } - cache_entry.out_pytree_def = - py::cast(fastpath_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); - py::list kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); - cache_entry.kept_var_bitvec.reserve(kept_var_bitvec.size()); - for (py::handle k : kept_var_bitvec) { - cache_entry.kept_var_bitvec.push_back(py::cast(k)); + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); } } // Helper function used by the tp_clear GC method. void PjitFunction::ClearPythonReferences() { - py::function cache_miss; + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; // Swap values for nulls before they are destroyed. See the Python // Py_CLEAR() documentation for a discussion of this topic. std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); } struct PjitFunctionObject { PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 PyObject* dict; // Dictionary for __dict__ PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 vectorcallfunc vectorcall; PjitFunction fun; }; PyObject* PjitFunction_Type = nullptr; -bool PjitFunction::IsPjitFunction(py::handle handle) { - return handle.get_type() == PjitFunction_Type; +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; } -PjitFunction* PjitFunction::AsPjitFunctionUnchecked(py::handle handle) { +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { return &(reinterpret_cast(handle.ptr())->fun); } -PjitFunction* AsPjitFunction(py::handle handle) { +PjitFunction* AsPjitFunction(nb::handle handle) { if (!PjitFunction::IsPjitFunction(handle)) { throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); } @@ -768,16 +838,17 @@ PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); }); try { - xla::StatusOr out = o->fun.Call(callable, args, nargs, kwnames); + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); if (!out.ok()) { PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); return nullptr; } return out.value().release().ptr(); - } catch (py::error_already_set& e) { + } catch (nb::python_error& e) { e.restore(); return nullptr; - } catch (py::cast_error& e) { + } catch (nb::cast_error& e) { PyErr_SetString(PyExc_ValueError, e.what()); return nullptr; } catch (std::invalid_argument& e) { @@ -794,8 +865,10 @@ PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, PjitFunctionObject* self = reinterpret_cast(subtype->tp_alloc(subtype, 0)); if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 self->dict = nullptr; self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 self->vectorcall = PjitFunction_tp_vectorcall; return reinterpret_cast(self); } @@ -804,29 +877,44 @@ void PjitFunction_tp_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); PyTypeObject* tp = Py_TYPE(self); PjitFunctionObject* o = reinterpret_cast(self); - if (o->weakrefs) { - PyObject_ClearWeakRefs(self); - } + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 Py_CLEAR(o->dict); +#else + _PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 o->fun.~PjitFunction(); tp->tp_free(self); Py_DECREF(tp); } int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit PjitFunctionObject* o = reinterpret_cast(self); -#if PY_VERSION_HEX >= 0x03090000 // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse Py_VISIT(Py_TYPE(self)); -#endif +#if PY_VERSION_HEX < 0x030C0000 Py_VISIT(o->dict); +#else + _PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } return 0; } int PjitFunction_tp_clear(PyObject* self) { PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 Py_CLEAR(o->dict); +#else + _PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 o->fun.ClearPythonReferences(); return 0; } @@ -843,42 +931,18 @@ PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj, return PyMethod_New(self, obj); } -// Support d = instance.__dict__. -PyObject* PjitFunction_get_dict(PyObject* self, void*) { - PjitFunctionObject* o = reinterpret_cast(self); - if (!o->dict) { - o->dict = PyDict_New(); - } - Py_XINCREF(o->dict); - return o->dict; -} - -int PjitFunction_set_dict(PyObject* self, PyObject* new_dict, void*) { - PjitFunctionObject* o = reinterpret_cast(self); - if (!PyDict_Check(new_dict)) { - PyErr_Format(PyExc_TypeError, - "__dict__ must be set to a dictionary, not a '%s'", - Py_TYPE(new_dict)->tp_name); - return -1; - } - Py_INCREF(new_dict); - Py_CLEAR(o->dict); - o->dict = new_dict; - return 0; -} - static PyGetSetDef PjitFunction_tp_getset[] = { // Having a __dict__ seems necessary to allow !functool.wraps to override // __doc__. - {const_cast("__dict__"), PjitFunction_get_dict, - PjitFunction_set_dict, nullptr, nullptr}, + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, {nullptr, nullptr, nullptr, nullptr, nullptr}}; PyObject* PjitFunction_tp_repr(PyObject* self) { try { const std::string& repr = absl::StrFormat( "", - static_cast(py::repr(py::getattr(self, "__wrapped__")))); + nb::cast(nb::repr(nb::getattr(self, "__wrapped__")))); return PyUnicode_FromString(repr.c_str()); } catch (...) { // Ignore all errors when accessing a repr. @@ -890,24 +954,26 @@ PyObject* PjitFunction_tp_repr(PyObject* self) { void InitializePjitFunction( PjitFunctionObject* fn_obj, std::string function_name, - std::optional fun, py::function cache_miss, - std::vector static_argnums, std::vector static_argnames, + std::optional fun, nb::callable cache_miss, + std::vector static_argnums, std::vector static_argnames, std::vector donate_argnums, std::shared_ptr pytree_registry, - std::shared_ptr cache) { + nb::callable shard_arg_fallback, std::shared_ptr cache) { new (&fn_obj->fun) PjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), std::move(cache)); + std::move(donate_argnums), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); } -py::object MakePjitFunction( - std::string function_name, std::optional fun, - py::function cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, +nb::object MakePjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, std::vector donate_argnums, std::shared_ptr pytree_registry, - std::shared_ptr cache) { - py::object obj = py::reinterpret_steal(PjitFunction_tp_new( + nb::callable shard_arg_fallback, + std::optional> cache) { + nb::object obj = nb::steal(PjitFunction_tp_new( reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); PjitFunctionObject* fn_obj = reinterpret_cast(obj.ptr()); if (!cache) { @@ -917,7 +983,8 @@ py::object MakePjitFunction( InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), std::move(donate_argnums), - std::move(pytree_registry), std::move(cache)); + std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); return obj; } @@ -925,83 +992,95 @@ py::object MakePjitFunction( // PjitFunction. Increment these if changing them. const int kPjitFunctionPickleVersion = 1; +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + } // namespace -void BuildPjitSubmodule(py::module& m) { - py::class_> cache( - m, "PjitFunctionCache"); - cache.def(py::init(), - py::arg("capacity") = PjitFunctionCache::kDefaultCapacity); +void BuildPjitSubmodule(nb::module_& m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); cache.def("size", &PjitFunctionCache::Size); cache.def("capacity", &PjitFunctionCache::Capacity); cache.def("clear", &PjitFunctionCache::Clear); cache.def_static("clear_all", []() { GetGlobalPjitFunctionStore().ClearFunctionCache(); }); - cache.def(py::pickle( - // __getstate__ - // Pickles as an empty cache; the client can repopulate as needed. - [](const PjitFunctionCache& cache) { - py::dict pickle; - pickle["version"] = kPjitFunctionPickleVersion; - pickle["capacity"] = cache.Capacity(); - return pickle; - }, - // __setstate__ - [](const py::dict& pickle) { - int version = py::cast(pickle["version"]); - if (version != kPjitFunctionPickleVersion) { - throw std::invalid_argument(absl::StrFormat( - "Invalid PjitFunction pickle version, got %d, expected %d", - version, kPjitFunctionPickleVersion)); - } - int capacity = py::cast(pickle["capacity"]); - return std::make_shared(capacity); - })); + cache.def("__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); // We need to use heap-allocated type objects because we want to add // additional methods dynamically. - py::object cfun; - { - py::str name = py::str("PjitFunction"); - py::str qualname = py::str("PjitFunction"); - PyHeapTypeObject* heap_type = reinterpret_cast( - PyType_Type.tp_alloc(&PyType_Type, 0)); - // Caution: we must not call any functions that might invoke the GC until - // PyType_Ready() is called. Otherwise the GC might see a half-constructed - // type object. - CHECK(heap_type) << "Unable to create heap type object"; - heap_type->ht_name = name.release().ptr(); - heap_type->ht_qualname = qualname.release().ptr(); - PyTypeObject* type = &heap_type->ht_type; - type->tp_name = "PjitFunction"; - type->tp_basicsize = sizeof(PjitFunctionObject); - type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | - Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_VECTORCALL; - type->tp_new = PjitFunction_tp_new; - type->tp_dealloc = PjitFunction_tp_dealloc; - type->tp_dictoffset = offsetof(PjitFunctionObject, dict); - type->tp_traverse = PjitFunction_tp_traverse; - type->tp_clear = PjitFunction_tp_clear; - type->tp_weaklistoffset = offsetof(PjitFunctionObject, weakrefs); - type->tp_getset = PjitFunction_tp_getset; - type->tp_descr_get = PjitFunction_tp_descr_get; - type->tp_call = PyVectorcall_Call; - type->tp_vectorcall_offset = offsetof(PjitFunctionObject, vectorcall); - type->tp_repr = PjitFunction_tp_repr; - CHECK_EQ(PyType_Ready(type), 0); - PjitFunction_Type = reinterpret_cast(type); - cfun = py::reinterpret_borrow(PjitFunction_Type); + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); } - py::object cfun_type = py::reinterpret_borrow(PjitFunction_Type); + nb::object cfun = nb::borrow(PjitFunction_Type); // Add PjitFunction to the xla_extension module so it can be pickled. - m.attr("PjitFunction") = cfun_type; - cfun.attr("__module__") = m.attr("__name__"); - - cfun.attr("__getstate__") = py::cpp_function( + m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( [](const PjitFunction::object& self) { PjitFunction* fn = self.func(); - py::dict pickle; + nb::dict pickle; pickle["version"] = kPjitFunctionPickleVersion; pickle["function_name"] = fn->function_name(); if (fn->fun().has_value()) { @@ -1009,16 +1088,17 @@ void BuildPjitSubmodule(py::module& m) { } pickle["cache_miss"] = fn->cache_miss(); pickle["static_argnums"] = fn->static_argnums(); - pickle["static_argnames"] = fn->static_argnames(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); pickle["donate_argnums"] = fn->donate_argnums(); - pickle["pytree_registry"] = fn->pytree_registry(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); pickle["cache"] = fn->cache(); return pickle; }, - py::is_method(cfun_type)); - cfun.attr("__setstate__") = py::cpp_function( - [](py::object& self, const py::dict& pickle) { - int version = py::cast(pickle["version"]); + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); if (version != kPjitFunctionPickleVersion) { throw std::invalid_argument(absl::StrFormat( "Invalid PjitFunction pickle version, got %d, expected %d. " @@ -1027,66 +1107,71 @@ void BuildPjitSubmodule(py::module& m) { version, kPjitFunctionPickleVersion)); } std::string function_name = - py::cast(pickle["function_name"]); - std::optional fun; + nb::cast(pickle["function_name"]); + std::optional fun; if (pickle.contains("fun")) { - fun = py::cast(pickle["fun"]); + fun = nb::cast(pickle["fun"]); } - py::function cache_miss = py::cast(pickle["cache_miss"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); std::vector static_argnums = - py::cast>(pickle["static_argnums"]); - std::vector static_argnames = - py::cast>(pickle["static_argnames"]); + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); std::vector donate_argnums = - py::cast>(pickle["donate_argnums"]); + nb::cast>(pickle["donate_argnums"]); std::shared_ptr pytree_registry = - py::cast>( - pickle["pytree_registry"]); + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); std::shared_ptr cache = - py::cast>(pickle["cache"]); + nb::cast>(pickle["cache"]); InitializePjitFunction( reinterpret_cast(self.ptr()), std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), std::move(donate_argnums), std::move(pytree_registry), - std::move(cache)); + std::move(shard_arg_fallback), std::move(cache)); }, - py::is_method(cfun_type)); + nb::is_method()); cfun.attr("__signature__") = - property_readonly([](py::handle self) -> py::object { + xla::nb_property_readonly([](nb::handle self) -> nb::object { return AsPjitFunction(self)->PythonSignature(); }); cfun.attr("_cache_miss") = - property_readonly([](py::handle self) -> py::object { + xla::nb_property_readonly([](nb::handle self) -> nb::object { return AsPjitFunction(self)->cache_miss(); }); // All private members are only for testing/debugging purposes - cfun.attr("_cache_size") = py::cpp_function( - [](py::handle self) -> int { + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { return AsPjitFunction(self)->cache_capacity(); }, - py::is_method(cfun)); - cfun.attr("_clear_cache") = py::cpp_function( - [](py::handle self) { AsPjitFunction(self)->ClearCache(); }, - py::is_method(cfun)); + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); m.def( "pjit", - [](std::string function_name, std::optional fun, - py::function cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, - std::shared_ptr pytree_registry, - std::shared_ptr cache) { + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, std::vector donate_argnums, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + std::shared_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); return MakePjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), - std::move(cache)); + std::move(donate_argnums), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); }, - py::arg("function_name"), py::arg("fun"), py::arg("cache_miss"), - py::arg("static_argnums"), py::arg("static_argnames"), - py::arg("donate_argnums"), py::arg("pytree_registry"), - py::arg("cache") = nullptr); + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("donate_argnums"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); } } // namespace jax diff --git a/xla/python/pjit.h b/xla/python/pjit.h index 251e3d85d83d0..2d605d8db11c3 100644 --- a/xla/python/pjit.h +++ b/xla/python/pjit.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,11 @@ limitations under the License. #define XLA_PYTHON_PJIT_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace jax { -void BuildPjitSubmodule(pybind11::module& m); - +void BuildPjitSubmodule(nanobind::module_& m); } #endif // XLA_PYTHON_PJIT_H_ diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 07f7718d25967..978cb76e16eca 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -1,5 +1,7 @@ -load("//xla:xla.bzl", "xla_cc_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") +load("//xla:xla.bzl", "xla_cc_test") package_group( name = "friends", @@ -20,10 +22,10 @@ package_group( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ + default_visibility = internal_visibility([ ":friends", ":internal", - ], + ]), ) exports_files([ @@ -41,8 +43,10 @@ cc_library( "xla_compiler.h", "xla_sharding.h", ], + compatible_with = get_compatible_with_portable(), deps = [ ":xla_compiler_proto_cc", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_executable", "//xla/python/ifrt", @@ -75,20 +79,21 @@ tf_proto_library( cc_library( name = "xla_program_serdes", srcs = ["xla_program_serdes.cc"], + compatible_with = get_compatible_with_portable(), deps = [ ":xla_ifrt", "//xla/mlir_hlo:mhlo_passes", + "//xla/pjrt:mlir_to_hlo", "//xla/python/ifrt:serdes", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", + "@tsl//tsl/platform:status", ], alwayslink = True, ) @@ -115,7 +120,7 @@ tf_proto_library( srcs = ["xla_sharding.proto"], protodeps = [ "//xla:xla_data_proto", - "//xla/python/ifrt:types_proto", + "//xla/python/ifrt:device_proto", ], ) @@ -184,6 +189,7 @@ xla_cc_test( ":tfrt_cpu_client_test_lib", ":xla_ifrt", "//xla:xla_data_proto_cc", + "//xla/python/ifrt", "//xla/python/ifrt:sharding_test_util", "//xla/python/ifrt:tuple_impl_test_lib", "@com_google_googletest//:gtest_main", @@ -210,6 +216,7 @@ cc_library( "pjrt_host_callback.h", "pjrt_tuple.h", ], + compatible_with = get_compatible_with_portable(), deps = [ ":xla_ifrt", "//xla:literal", @@ -220,20 +227,25 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/pjrt:host_callback", + "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:utils", "//xla/python/ifrt", "//xla/service:hlo_proto_cc", "//xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index 048003004c3ec..1f676a1bfe423 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,33 +15,108 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" +#include #include #include #include #include +#include #include #include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/utils.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { +namespace { + +// Validates the sharding and PjRtBuffers have consistent device and memory +// kind. +Status ValidateArrayCreationInput(std::shared_ptr sharding, + const PjRtArray::PjRtBuffers& pjrt_buffers) { + if (pjrt_buffers.empty()) { + return InvalidArgument("pjrt_buffers must be non-empty"); + } + if (sharding->devices().size() != pjrt_buffers.size()) { + return InvalidArgument("device and buffer counts mismatch: %d vs. %d", + sharding->devices().size(), pjrt_buffers.size()); + } + + // Canonicalize memory kind in case it hasn't been done before. + MemoryKind canonicalized_sharding_memory_kind = CanonicalizeMemoryKind( + sharding->memory_kind(), sharding->devices().front()); + for (int i = 0; i < sharding->devices().size(); ++i) { + if (pjrt_buffers[i]->device() != sharding->devices()[i]) { + return InvalidArgument( + "PjRtBuffer's memory space is addressed by device %s vs sharding is " + "on device %s", + pjrt_buffers[i]->device()->DebugString(), + sharding->devices()[i]->DebugString()); + } + MemoryKind buffer_memory_kind = + MakeMemoryKindFromPjRtBuffer(pjrt_buffers[i].get()); + if (canonicalized_sharding_memory_kind != buffer_memory_kind) { + return InvalidArgument( + "PjRtBuffer's memory kind does not match sharding's memory kind. Got " + "PjRtBuffer's memory kind: %s vs shardings's memory kind: %s", + buffer_memory_kind.DebugString(), + canonicalized_sharding_memory_kind.DebugString()); + } + } + return OkStatus(); +} + +// Validates the PjRtBuffers have consistent memory kind and returns the memory +// kind. +absl::StatusOr GetMemoryKindFromPjRtBuffers( + const PjRtArray::PjRtBuffers& pjrt_buffers) { + const auto first_memory_kind = + MakeMemoryKindFromPjRtBuffer(pjrt_buffers.front().get()); + const MemoryKind canonical_first_memory_kind = + CanonicalizeMemoryKind(first_memory_kind, pjrt_buffers.front()->device()); + for (const auto& pjrt_buffer : pjrt_buffers) { + if (auto memory_kind = MakeMemoryKindFromPjRtBuffer(pjrt_buffer.get()); + canonical_first_memory_kind != + CanonicalizeMemoryKind(memory_kind, pjrt_buffer->device())) { + return InvalidArgument( + "Memory kind mismatch between PjRtBuffers. Got one buffer with " + "memory kind: %s and another with memory_kind: %s", + first_memory_kind.DebugString(), memory_kind.DebugString()); + } + } + return first_memory_kind; +} + +} // namespace + char PjRtCompatibleArray::ID = 0; char PjRtArray::ID = 0; -StatusOr ToPrimitiveType(DType dtype) { +absl::StatusOr ToPrimitiveType(DType dtype) { switch (dtype.kind()) { #define CASE(DT, PT) \ case DT: \ @@ -80,7 +155,7 @@ StatusOr ToPrimitiveType(DType dtype) { return InvalidArgument("Invalid DType: %d", static_cast(dtype.kind())); } -StatusOr ToDType(xla::PrimitiveType primitive_type) { +absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { switch (primitive_type) { case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID: case xla::PrimitiveType::PRED: @@ -120,43 +195,23 @@ MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer) { return MemoryKind(pjrt_buffer->memory_space()->memory_space_kind()); } -StatusOr> PjRtArray::Create( +absl::StatusOr> PjRtArray::Create( PjRtCompatibleClient* client, DType dtype, Shape shape, std::shared_ptr sharding, PjRtBuffers pjrt_buffers) { - if (pjrt_buffers.empty()) { - return InvalidArgument("pjrt_buffers must be non-empty"); - } - if (sharding->devices().size() != pjrt_buffers.size()) { - return InvalidArgument("device and buffer counts mismatch: %d vs. %d", - sharding->devices().size(), pjrt_buffers.size()); - } - - // Canonicalize memory kind in case it hasn't been done before. - MemoryKind canonicalized_sharding_memory_kind = CanonicalizeMemoryKind( - sharding->memory_kind(), sharding->devices().front()); - for (int i = 0; i < sharding->devices().size(); ++i) { - if (pjrt_buffers[i]->device() != sharding->devices()[i]) { - return InvalidArgument( - "PjRtBuffer's memory space is addressed by device %s vs sharding is " - "on device %s", - pjrt_buffers[i]->device()->DebugString(), - sharding->devices()[i]->DebugString()); - } - MemoryKind buffer_memory_kind = - MakeMemoryKindFromPjRtBuffer(pjrt_buffers[i].get()); - if (canonicalized_sharding_memory_kind != buffer_memory_kind) { - return InvalidArgument( - "PjRtBuffer's memory kind does not match sharding's memory kind. Got " - "PjRtBuffer's memory kind: %s vs shardings's memory kind: %s", - buffer_memory_kind.DebugString(), - canonicalized_sharding_memory_kind.DebugString()); - } - } + TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers)); return tsl::MakeRef(client, dtype, std::move(shape), std::move(sharding), std::move(pjrt_buffers)); } -StatusOr> PjRtArray::Create( +absl::StatusOr> PjRtArray::Create( + PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, + std::shared_ptr sharding, PjRtBuffers pjrt_buffers) { + TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers)); + return tsl::MakeRef(client, dtype, std::move(dynamic_shape), + std::move(sharding), std::move(pjrt_buffers)); +} + +absl::StatusOr> PjRtArray::Create( PjRtCompatibleClient* client, std::shared_ptr pjrt_buffer) { TF_ASSIGN_OR_RETURN(auto dtype, ToDType(pjrt_buffer->element_type())); Shape shape(pjrt_buffer->dimensions()); @@ -167,7 +222,7 @@ StatusOr> PjRtArray::Create( PjRtBuffers({std::move(pjrt_buffer)})); } -StatusOr> PjRtArray::FullyReplicatedShard( +absl::StatusOr> PjRtArray::FullyReplicatedShard( ArrayCopySemantics semantics) { return PjRtArray::Create(client(), GetPjRtBuffer(semantics, 0)); } @@ -188,39 +243,61 @@ std::shared_ptr PjRtArray::GetPjRtBuffer( } } -StatusOr> PjRtArray::Create( +absl::StatusOr> PjRtArray::Create( PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers) { TF_ASSIGN_OR_RETURN(auto dtype, xla::ifrt::ToDType(pjrt_buffers.front()->element_type())); + TF_ASSIGN_OR_RETURN(MemoryKind memory_kind, + GetMemoryKindFromPjRtBuffers(pjrt_buffers)); + DeviceList::Devices devices; devices.reserve(pjrt_buffers.size()); std::vector shapes; shapes.reserve(pjrt_buffers.size()); - const auto first_memory_kind = - MakeMemoryKindFromPjRtBuffer(pjrt_buffers.front().get()); - const MemoryKind canonical_first_memory_kind = - CanonicalizeMemoryKind(first_memory_kind, pjrt_buffers.front()->device()); for (const auto& pjrt_buffer : pjrt_buffers) { devices.push_back(pjrt_buffer->device()); shapes.push_back(Shape(pjrt_buffer->dimensions())); - if (auto memory_kind = MakeMemoryKindFromPjRtBuffer(pjrt_buffer.get()); - canonical_first_memory_kind != - CanonicalizeMemoryKind(memory_kind, devices.back())) { - return InvalidArgument( - "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind: %s and another with memory_kind: %s", - first_memory_kind.DebugString(), memory_kind.DebugString()); - } } auto sharding = ifrt::ConcreteSharding::Create(DeviceList(std::move(devices)), - first_memory_kind, + memory_kind, /*shape=*/shape, /*shard_shapes=*/shapes); return PjRtArray::Create(client, dtype, std::move(shape), std::move(sharding), std::move(pjrt_buffers)); } +absl::StatusOr> PjRtArray::Create( + PjRtCompatibleClient* client, DynamicShape dynamic_shape, + PjRtBuffers pjrt_buffers) { + TF_ASSIGN_OR_RETURN(auto dtype, + xla::ifrt::ToDType(pjrt_buffers.front()->element_type())); + TF_ASSIGN_OR_RETURN(auto memory_kind, + GetMemoryKindFromPjRtBuffers(pjrt_buffers)); + + DeviceList::Devices devices; + devices.reserve(pjrt_buffers.size()); + std::vector dynamic_shapes; + dynamic_shapes.reserve(pjrt_buffers.size()); + + for (const auto& pjrt_buffer : pjrt_buffers) { + devices.push_back(pjrt_buffer->device()); + TF_ASSIGN_OR_RETURN( + DynamicShape dynamic_shape, + // Extracts dynamic shape info from the buffers. + DynamicShape::Create( + Shape(pjrt_buffer->dimensions()), + BoundedDynamicShapeTag(pjrt_buffer->is_dynamic_dimension()))); + dynamic_shapes.push_back(std::move(dynamic_shape)); + } + auto sharding = ifrt::ConcreteSharding::Create( + DeviceList(std::move(devices)), memory_kind, + /*dynamic_shape=*/dynamic_shape, + /*shard_dynamic_shapes=*/dynamic_shapes); + return PjRtArray::Create(client, dtype, std::move(dynamic_shape), + std::move(sharding), std::move(pjrt_buffers)); +} + PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape, std::shared_ptr sharding, PjRtBuffers pjrt_buffers) @@ -230,23 +307,41 @@ PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape, sharding_(std::move(sharding)), pjrt_buffers_(std::move(pjrt_buffers)) {} -StatusOr>> +PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, + DynamicShape dynamic_shape, + std::shared_ptr sharding, + PjRtBuffers pjrt_buffers) + : client_(client), + dtype_(dtype), + shape_(std::move(dynamic_shape)), + sharding_(std::move(sharding)), + pjrt_buffers_(std::move(pjrt_buffers)) {} + +absl::StatusOr>> PjRtArray::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { DCHECK(this); std::vector> result; result.reserve(sharding_->devices().size()); - TF_ASSIGN_OR_RETURN(auto shape_and_shardings, sharding_->Disassemble(shape_)); - for (int i = 0; i < sharding_->devices().size(); ++i) { - PjRtBuffers buffers; - buffers.reserve(1); - buffers.push_back(GetPjRtBuffer(semantics, i)); - TF_ASSIGN_OR_RETURN( - auto array, PjRtArray::Create(client_, dtype_, - std::move(shape_and_shardings[i].first), - std::move(shape_and_shardings[i].second), - std::move(buffers))); - result.push_back(std::move(array)); - } + TF_RETURN_IF_ERROR(std::visit( + [&](const auto& this_shape) { + TF_ASSIGN_OR_RETURN(auto shape_and_shardings, + sharding_->Disassemble(this_shape)); + for (int i = 0; i < sharding_->devices().size(); ++i) { + PjRtBuffers buffers; + buffers.reserve(1); + buffers.push_back(GetPjRtBuffer(semantics, i)); + TF_ASSIGN_OR_RETURN( + auto array, + PjRtArray::Create(client_, dtype_, + std::move(shape_and_shardings[i].first), + std::move(shape_and_shardings[i].second), + std::move(buffers))); + result.push_back(std::move(array)); + } + return xla::OkStatus(); + }, + shape_)); + return result; } @@ -267,15 +362,16 @@ Future PjRtArray::CopyToHostBuffer( PjRtBuffer* pjrt_buffer = pjrt_buffers_.front().get(); absl::Span dims; - StatusOr> logical_dims; + absl::StatusOr> logical_dims; if (!pjrt_buffer->has_dynamic_dimensions()) { - dims = shape_.dims(); + dims = std::get(shape_).dims(); } else { // TODO(b/182461453): This is a blocking call. If we further implemented // populating dynamic shape metadata while fetching the literal, we wouldn't // need this static approach. // TODO(hyeontaek): Clean up this dynamic shape access once we formalize // dynamic shape support in IFRT. + // TODO(b/314805296): Use the new dynamic shape here. logical_dims = pjrt_buffer->logical_dimensions(); if (!logical_dims.ok()) { return Future(std::move(logical_dims).status()); @@ -311,7 +407,7 @@ Future PjRtArray::CopyToHostBuffer( } // TODO(yashkatariya): Maybe move this to ifrt::Device? -StatusOr GetMemorySpaceFromMemoryKind( +absl::StatusOr GetMemorySpaceFromMemoryKind( ifrt::Device* device, ifrt::MemoryKind memory_kind) { PjRtMemorySpace* memory_space = nullptr; for (PjRtMemorySpace* ms : device->memory_spaces()) { @@ -332,7 +428,7 @@ StatusOr GetMemorySpaceFromMemoryKind( return memory_space; } -StatusOr> PjRtArray::Reshard( +absl::StatusOr> PjRtArray::Reshard( std::shared_ptr new_sharding, ArrayCopySemantics semantics) { DCHECK(this); @@ -417,8 +513,12 @@ StatusOr> PjRtArray::Reshard( } } } - return PjRtArray::Create(client_, dtype_, shape_, std::move(new_sharding), - std::move(buffers)); + return std::visit( + [this, &new_sharding, &buffers](const auto& shape) { + return PjRtArray::Create(client_, dtype_, shape, + std::move(new_sharding), std::move(buffers)); + }, + shape_); } Future PjRtArray::GetReadyFuture() const { @@ -452,9 +552,32 @@ bool PjRtArray::IsDeleted() const { std::string PjRtArray::DebugString() const { DCHECK(this); - return absl::StrFormat("PjRtArray(dtype=%s; shape=%s; sharding=%s)", - dtype_.DebugString(), shape_.DebugString(), - sharding_->DebugString()); + absl::StatusOr> layout_ptr = layout(); + std::string layout_str = + layout_ptr.ok() ? (*layout_ptr)->ToString() : ""; + + return absl::StrFormat( + "PjRtArray(dtype=%s; shape=%s; sharding=%s; layout=%s)", + dtype_.DebugString(), + std::visit([](const auto& shape) { return shape.DebugString(); }, shape_), + sharding_->DebugString(), layout_str); +} + +// TODO(b/330198879): populate layout at construction instead of accessing PJRT +// buffer directly for consistency with Pathways. +absl::StatusOr> PjRtArray::layout() const { + CHECK(!pjrt_buffers_.empty()); + std::unique_ptr layout = pjrt_buffers_[0]->layout(); +#ifndef NDEBUG + for (int i = 1; i < pjrt_buffers_.size(); ++i) { + std::unique_ptr layout_i = pjrt_buffers_[i]->layout(); + DCHECK(*layout == *layout_i) + << "PjRtArray has mismatched layouts across shards! " + << "shard 0: " << layout->ToString() << ", shard " << i << ": " + << layout_i->ToString(); + } +#endif + return layout; } } // namespace ifrt diff --git a/xla/python/pjrt_ifrt/pjrt_array.h b/xla/python/pjrt_ifrt/pjrt_array.h index 11193e025aea8..5542ae9e54e11 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.h +++ b/xla/python/pjrt_ifrt/pjrt_array.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/inlined_vector.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "tsl/concurrency/ref_count.h" @@ -32,10 +34,10 @@ namespace xla { namespace ifrt { // Converts IFRT `DType` into `xla::PrimitiveType`. -StatusOr ToPrimitiveType(DType dtype); +absl::StatusOr ToPrimitiveType(DType dtype); // Converts `xla::PrimitiveType` into IFRT `DType`. -StatusOr ToDType(xla::PrimitiveType primitive_type); +absl::StatusOr ToDType(xla::PrimitiveType primitive_type); // Creates IFRT `MemoryKind` from an XLA `PjRtBuffer`. MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer); @@ -46,7 +48,7 @@ class PjRtCompatibleArray public: // APIs that allow direct access to `PjRtBuffer` for PjRt-only operations. virtual absl::Span> pjrt_buffers() = 0; - virtual StatusOr>> + virtual absl::StatusOr>> mutable_pjrt_buffers() = 0; static char ID; // NOLINT @@ -60,34 +62,45 @@ class PjRtArray final using PjRtBuffers = absl::InlinedVector, kPjRtBufferInlineSize>; - // General array construction. - static StatusOr> Create( + // General array construction (with static shape). + static absl::StatusOr> Create( PjRtCompatibleClient* client, DType dtype, Shape shape, std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + // General array construction (with dynamic shape). + static absl::StatusOr> Create( + PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, + std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + // Shorthand for a single-shard array construction. - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, std::shared_ptr pjrt_buffer); // Shorthand for a multi-shard array construction using ConcreteSharding. // TODO(hyeontaek): Remove this once IFRT Sharding and JAX Sharding is unified // so that ConcreteSharding can be replaced with a real Sharding. - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers); + // Shorthand for a multi-shard array construction using ConcreteSharding with + // DynamicShape. + static absl::StatusOr> Create( + PjRtCompatibleClient* client, DynamicShape dynamic_shape, + PjRtBuffers pjrt_buffers); + // PjRtCompatibleArray implementation. absl::Span> pjrt_buffers() override { DCHECK(this); return pjrt_buffers_; } - StatusOr>> mutable_pjrt_buffers() + absl::StatusOr>> mutable_pjrt_buffers() override { DCHECK(this); return absl::MakeSpan(pjrt_buffers_); } - StatusOr> FullyReplicatedShard( + absl::StatusOr> FullyReplicatedShard( ArrayCopySemantics semantics) override; // Array implementation. @@ -103,10 +116,27 @@ class PjRtArray final DCHECK(this); return dtype_; } - const Shape& shape() const override { + + bool has_dynamic_shape() const { DCHECK(this); - return shape_; + return std::holds_alternative(shape_); } + + bool has_static_shape() const { + DCHECK(this); + return std::holds_alternative(shape_); + } + + const Shape& shape() const override { + DCHECK(has_static_shape()); + return std::get(shape_); + } + + const DynamicShape& dynamic_shape() const { + DCHECK(has_dynamic_shape()); + return std::get(shape_); + } + const Sharding& sharding() const override { DCHECK(this); return *sharding_; @@ -116,7 +146,9 @@ class PjRtArray final return sharding_; } - StatusOr>> + absl::StatusOr> layout() const override; + + absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; ABSL_MUST_USE_RESULT @@ -124,7 +156,7 @@ class PjRtArray final void* data, std::optional> byte_strides, ArrayCopySemantics semantics) override; - StatusOr> Reshard( + absl::StatusOr> Reshard( std::shared_ptr new_sharding, ArrayCopySemantics semantics) override; @@ -144,12 +176,16 @@ class PjRtArray final PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape, std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + PjRtArray(PjRtCompatibleClient* client, DType dtype, + DynamicShape dynamic_shape, + std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + template friend tsl::RCReference tsl::MakeRef(Args&&... args); PjRtCompatibleClient* client_; DType dtype_; - Shape shape_; + std::variant shape_; std::shared_ptr sharding_; PjRtBuffers pjrt_buffers_; }; diff --git a/xla/python/pjrt_ifrt/pjrt_array_impl_test_tfrt_cpu.cc b/xla/python/pjrt_ifrt/pjrt_array_impl_test_tfrt_cpu.cc index 5685f84a028b0..8f3429bffcde2 100644 --- a/xla/python/pjrt_ifrt/pjrt_array_impl_test_tfrt_cpu.cc +++ b/xla/python/pjrt_ifrt/pjrt_array_impl_test_tfrt_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 578f1c4ad14ff..ad313646f2bb3 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,9 +21,16 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" #include "absl/memory/memory.h" +#include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" @@ -34,6 +41,16 @@ limitations under the License. namespace xla { namespace ifrt { +namespace { + +// A nullptr std::function implicitly converts to a non-nullptr +// absl::AnyInvocable, which later crashes when being invoked. absl team +// explicitly said this is WAI. See b/258212655#comment10. +absl::AnyInvocable FromStdFunction(std::function&& f) { + return f ? std::move(f) : absl::AnyInvocable(); +} + +} // namespace char PjRtCompatibleClient::ID = 0; char PjRtClient::ID = 0; @@ -43,21 +60,43 @@ std::unique_ptr PjRtClient::Create( return absl::WrapUnique(new PjRtClient(std::move(pjrt_client))); } -StatusOr> PjRtClient::CreatePjRtArray( - std::shared_ptr pjrt_buffer) { +absl::flat_hash_map +PjRtClient::attributes() const { + absl::flat_hash_map attributes; + attributes.insert({"supports_executable_serialization", true}); + + if (std::optional plugin_attributes = + pjrt_client_->plugin_attributes(); + plugin_attributes.has_value()) { + attributes.insert( + {"pjrt_c_api_major_version", + ClientAttribute(plugin_attributes->pjrt_c_api_major_version)}); + attributes.insert( + {"pjrt_c_api_minor_version", + ClientAttribute(plugin_attributes->pjrt_c_api_minor_version)}); + for (const auto& [key, value] : plugin_attributes->attributes) { + attributes.insert({key, value}); + } + } + + return attributes; +} + +absl::StatusOr> +PjRtClient::CreatePjRtArray(std::shared_ptr pjrt_buffer) { TF_ASSIGN_OR_RETURN(auto array, PjRtArray::Create(this, std::move(pjrt_buffer))); return tsl::RCReference(std::move(array)); } -StatusOr> PjRtClient::CreatePjRtArray( - Shape shape, PjRtBuffers pjrt_buffers) { +absl::StatusOr> +PjRtClient::CreatePjRtArray(Shape shape, PjRtBuffers pjrt_buffers) { TF_ASSIGN_OR_RETURN(auto array, PjRtArray::Create(this, std::move(shape), std::move(pjrt_buffers))); return tsl::RCReference(std::move(array)); } -StatusOr> PjRtClient::MakeArrayFromHostBuffer( +absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( const void* data, DType dtype, Shape shape, std::optional> byte_strides, std::shared_ptr sharding, @@ -95,24 +134,25 @@ StatusOr> PjRtClient::MakeArrayFromHostBuffer( absl::StrAppend(out, ms->memory_space_kind()); })); } - TF_ASSIGN_OR_RETURN( - buffer, pjrt_client_->BufferFromHostBuffer( - data, primitive_type, shape.dims(), byte_strides, semantics, - std::move(on_done_with_host_buffer), memory_space, - /*device_layout=*/nullptr)); - } else { TF_ASSIGN_OR_RETURN( buffer, pjrt_client_->BufferFromHostBuffer( data, primitive_type, shape.dims(), byte_strides, semantics, - std::move(on_done_with_host_buffer), sharding->devices().front())); + FromStdFunction(std::move(on_done_with_host_buffer)), memory_space, + /*device_layout=*/nullptr)); + } else { + TF_ASSIGN_OR_RETURN( + buffer, pjrt_client_->BufferFromHostBuffer( + data, primitive_type, shape.dims(), byte_strides, semantics, + FromStdFunction(std::move(on_done_with_host_buffer)), + sharding->devices().front())); } return PjRtArray::Create( this, dtype, std::move(shape), std::move(sharding), PjRtArray::PjRtBuffers({std::shared_ptr(buffer.release())})); } -StatusOr> +absl::StatusOr> PjRtClient::AssembleArrayFromSingleDeviceArrays( Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) { @@ -181,12 +221,12 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays( std::move(buffers)); } -StatusOr> PjRtClient::MakeTuple( +absl::StatusOr> PjRtClient::MakeTuple( absl::Span> values) { return PjRtTuple::Create(this, values); } -StatusOr> +absl::StatusOr> PjRtClient::GetTopologyForDevices(absl::Span devices) const { // TODO(parkers): Consider constructing a sub-slice topology based on the // provided devices. @@ -195,5 +235,15 @@ PjRtClient::GetTopologyForDevices(absl::Span devices) const { topology); } +absl::StatusOr> +PjRtClient::GetDefaultLayoutForDevice(DType dtype, + absl::Span dims, + Device* device) const { + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); + TF_ASSIGN_OR_RETURN(xla::Layout layout, + pjrt_client_->GetDefaultLayout(element_type, dims)); + return std::make_unique(std::move(layout)); +} + } // namespace ifrt } // namespace xla diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index 4877e36098166..32e7759118054 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,9 +48,9 @@ class PjRtCompatibleClient // operations. virtual xla::PjRtClient* pjrt_client() = 0; virtual std::shared_ptr shared_ptr_pjrt_client() = 0; - virtual StatusOr> CreatePjRtArray( + virtual absl::StatusOr> CreatePjRtArray( std::shared_ptr pjrt_buffer) = 0; - virtual StatusOr> CreatePjRtArray( + virtual absl::StatusOr> CreatePjRtArray( Shape shape, PjRtBuffers pjrt_buffers) = 0; static char ID; // NOLINT @@ -70,28 +70,28 @@ class PjRtClient final std::shared_ptr shared_ptr_pjrt_client() override { return pjrt_client_; } - StatusOr> CreatePjRtArray( + absl::StatusOr> CreatePjRtArray( std::shared_ptr pjrt_buffer) override; - StatusOr> CreatePjRtArray( + absl::StatusOr> CreatePjRtArray( Shape shape, PjRtBuffers pjrt_buffers) override; // Client implementation. ~PjRtClient() override = default; - StatusOr> MakeArrayFromHostBuffer( + absl::StatusOr> MakeArrayFromHostBuffer( const void* data, DType dtype, Shape shape, std::optional> byte_strides, std::shared_ptr sharding, Client::HostBufferSemantics semantics, std::function on_done_with_host_buffer) override; - StatusOr> AssembleArrayFromSingleDeviceArrays( + absl::StatusOr> AssembleArrayFromSingleDeviceArrays( Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) override; - StatusOr> MakeTuple( + absl::StatusOr> MakeTuple( absl::Span> values) override; absl::string_view runtime_type() const override { @@ -111,18 +111,8 @@ class PjRtClient final DCHECK(this); return pjrt_client_->platform_id(); } - absl::flat_hash_map attributes() - const override { - std::optional attributes = - pjrt_client_->plugin_attributes(); - if (!attributes.has_value()) { - return {}; - } - return {{"pjrt_c_api_major_version", - ClientAttribute(attributes->pjrt_c_api_major_version)}, - {"pjrt_c_api_minor_version", - ClientAttribute(attributes->pjrt_c_api_minor_version)}}; - } + + absl::flat_hash_map attributes() const override; int device_count() const override { DCHECK(this); @@ -141,18 +131,18 @@ class PjRtClient final return pjrt_client_->addressable_devices(); } int process_index() const override { return pjrt_client_->process_index(); } - StatusOr GetDefaultDeviceAssignment( + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { DCHECK(this); return pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions); } - StatusOr LookupDevice(int device_id) const override { + absl::StatusOr LookupDevice(int device_id) const override { DCHECK(this); return pjrt_client_->LookupDevice(device_id); } - StatusOr LookupAddressableDevice( + absl::StatusOr LookupAddressableDevice( int local_hardware_id) const override { DCHECK(this); return pjrt_client_->LookupAddressableDevice(local_hardware_id); @@ -163,9 +153,13 @@ class PjRtClient final return &default_compiler_; } - StatusOr> + absl::StatusOr> GetTopologyForDevices(absl::Span devices) const override; + absl::StatusOr> GetDefaultLayoutForDevice( + DType dtype, absl::Span dims, + Device* device) const override; + static char ID; // NOLINT private: diff --git a/xla/python/pjrt_ifrt/pjrt_compiler.cc b/xla/python/pjrt_ifrt/pjrt_compiler.cc index 0c9074a46671c..bb7aaad5e02df 100644 --- a/xla/python/pjrt_ifrt/pjrt_compiler.cc +++ b/xla/python/pjrt_ifrt/pjrt_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ namespace ifrt { char PjRtCompiler::ID = 0; -StatusOr> PjRtCompiler::Compile( +absl::StatusOr> PjRtCompiler::Compile( std::unique_ptr program, std::unique_ptr options) { DCHECK(this); const auto* xla_program = llvm::dyn_cast(program.get()); @@ -46,7 +46,7 @@ StatusOr> PjRtCompiler::Compile( std::move(xla_compile_options->loaded_host_callbacks)); } -StatusOr> +absl::StatusOr> PjRtCompiler::DeserializeLoadedExecutable( absl::string_view serialized, std::unique_ptr options) { diff --git a/xla/python/pjrt_ifrt/pjrt_compiler.h b/xla/python/pjrt_ifrt/pjrt_compiler.h index cfd62d352bac9..55a0c763ba3b5 100644 --- a/xla/python/pjrt_ifrt/pjrt_compiler.h +++ b/xla/python/pjrt_ifrt/pjrt_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,11 +38,11 @@ class PjRtCompiler final : public llvm::RTTIExtends { ~PjRtCompiler() override = default; - StatusOr> Compile( + absl::StatusOr> Compile( std::unique_ptr program, std::unique_ptr options) override; - StatusOr> DeserializeLoadedExecutable( + absl::StatusOr> DeserializeLoadedExecutable( absl::string_view serialized, std::unique_ptr options) override; diff --git a/xla/python/pjrt_ifrt/pjrt_executable.cc b/xla/python/pjrt_ifrt/pjrt_executable.cc index e93cd5916a939..8f9b2109eb362 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/host_callback.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" @@ -58,7 +59,6 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { @@ -66,7 +66,7 @@ namespace ifrt { namespace { // Returns the op sharding of the root instruction in the entry computation. -StatusOr FindRootInstruction( +absl::StatusOr FindRootInstruction( const HloModuleProto& proto) { for (const auto& computation : proto.computations()) { if (computation.id() == proto.entry_computation_id()) { @@ -82,7 +82,8 @@ StatusOr FindRootInstruction( // Returns the output element types of the first module in a // `PjRtLoadedExecutable`. -StatusOr> GetFirstModuleOutputElementTypes( +absl::StatusOr> +GetFirstModuleOutputElementTypes( xla::PjRtLoadedExecutable* pjrt_loaded_executable) { auto element_types = pjrt_loaded_executable->GetOutputElementTypes(); TF_RETURN_IF_ERROR(element_types.status()); @@ -94,7 +95,8 @@ StatusOr> GetFirstModuleOutputElementTypes( // Returns the output dimensions of the first module in a // `PjRtLoadedExecutable`. -StatusOr> GetFirstModuleOutputDimensions( +absl::StatusOr> +GetFirstModuleOutputDimensions( xla::PjRtLoadedExecutable* pjrt_loaded_executable) { auto dimensions = pjrt_loaded_executable->GetOutputDimensions(); TF_RETURN_IF_ERROR(dimensions.status()); @@ -106,7 +108,7 @@ StatusOr> GetFirstModuleOutputDimensions( // Returns the output shardings of the first module in a // `PjRtLoadedExecutable`. -StatusOr> GetFirstModuleOutputSharding( +absl::StatusOr> GetFirstModuleOutputSharding( xla::PjRtLoadedExecutable* pjrt_loaded_executable, const xla::Shape& shape) { auto output_shardings = pjrt_loaded_executable->GetOutputShardings(); @@ -129,7 +131,7 @@ StatusOr> GetFirstModuleOutputSharding( // Returns the flattened output memory_kinds of the first module in a // `UnimplementedError` will be converted into `std::nullopt`. -StatusOr>> +absl::StatusOr>> GetFirstModuleOutputMemoryKinds( xla::PjRtLoadedExecutable* pjrt_loaded_executable) { auto output_memory_kinds = pjrt_loaded_executable->GetOutputMemoryKinds(); @@ -151,7 +153,7 @@ struct ShapePartialInfo { std::vector dimensions; }; -StatusOr CreateShapePartialInfo( +absl::StatusOr CreateShapePartialInfo( absl::Span shapes) { ShapePartialInfo partial_info; partial_info.element_types.reserve(shapes.size()); @@ -176,29 +178,29 @@ char PjRtCompatibleLoadedExecutable::ID = 0; char PjRtExecutable::ID = 0; char PjRtLoadedExecutable::ID = 0; -StatusOr> PjRtExecutable::Create( +absl::StatusOr> PjRtExecutable::Create( std::unique_ptr pjrt_executable) { return std::unique_ptr(new PjRtExecutable( std::shared_ptr(pjrt_executable.release()))); } -StatusOr> PjRtExecutable::Create( +absl::StatusOr> PjRtExecutable::Create( std::shared_ptr pjrt_executable) { return std::unique_ptr( new PjRtExecutable(std::move(pjrt_executable))); } -StatusOr> PjRtExecutable::Fingerprint() const { +absl::StatusOr> PjRtExecutable::Fingerprint() const { DCHECK(this); return pjrt_executable_->FingerprintExecutable(); } -StatusOr PjRtExecutable::Serialize() const { +absl::StatusOr PjRtExecutable::Serialize() const { DCHECK(this); return pjrt_executable_->SerializeExecutable(); } -StatusOr> PjRtLoadedExecutable::Create( +absl::StatusOr> PjRtLoadedExecutable::Create( PjRtCompatibleClient* client, std::unique_ptr pjrt_loaded_executable, std::vector> loaded_host_callbacks) { @@ -208,7 +210,7 @@ StatusOr> PjRtLoadedExecutable::Create( std::move(loaded_host_callbacks)); } -StatusOr> PjRtLoadedExecutable::Create( +absl::StatusOr> PjRtLoadedExecutable::Create( PjRtCompatibleClient* client, std::shared_ptr pjrt_loaded_executable, std::vector> loaded_host_callbacks) { @@ -231,7 +233,7 @@ StatusOr> PjRtLoadedExecutable::Create( result_memory_kinds, loaded_host_callbacks); } -static StatusOr> ResultShapesOfModule( +static absl::StatusOr> ResultShapesOfModule( mlir::ModuleOp module) { auto main = module.lookupSymbol("main"); if (!main) { @@ -247,7 +249,7 @@ static StatusOr> ResultShapesOfModule( return result_shapes; } -StatusOr> PjRtLoadedExecutable::Create( +absl::StatusOr> PjRtLoadedExecutable::Create( PjRtCompatibleClient* client, mlir::ModuleOp module, xla::CompileOptions compile_options, std::vector> loaded_host_callbacks) { @@ -261,6 +263,7 @@ StatusOr> PjRtLoadedExecutable::Create( build_options.use_spmd_partitioning() && build_options.num_partitions() > 1 && (build_options.use_auto_spmd_partitioning() || + build_options.any_allow_spmd_sharding_propagation_to_parameters() || build_options.any_allow_spmd_sharding_propagation_to_output()); TF_ASSIGN_OR_RETURN( auto pjrt_loaded_executable, @@ -317,7 +320,7 @@ StatusOr> PjRtLoadedExecutable::Create( } } -StatusOr> +absl::StatusOr> PjRtLoadedExecutable::CreateInternal( PjRtCompatibleClient* client, std::shared_ptr pjrt_loaded_executable, @@ -470,9 +473,10 @@ PjRtLoadedExecutable::PjRtLoadedExecutable( PjRtLoadedExecutable::~PjRtLoadedExecutable() = default; -StatusOr PjRtLoadedExecutable::Execute( - absl::Span> args, const ExecuteOptions& options, - std::optional devices) { +absl::StatusOr +PjRtLoadedExecutable::Execute(absl::Span> args, + const ExecuteOptions& options, + std::optional devices) { DCHECK(this); // TODO(hyeontaek): Check input sharding consistency. @@ -529,7 +533,7 @@ StatusOr PjRtLoadedExecutable::Execute( auto opts = options; if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) { - return InternalError( + return Internal( "Host callback not supported without returned future support in " "runtime: %s", client_->runtime_type()); @@ -609,7 +613,7 @@ StatusOr PjRtLoadedExecutable::Execute( if (pjrt_outputs.size() != num_computations) { return FailedPrecondition( "Unexpected number of computations in outputs: %d vs. %d", - pjrt_outputs.front().size(), num_computations); + pjrt_outputs.size(), num_computations); } const int num_outputs = pjrt_outputs.front().size(); if (num_outputs != output_dtypes_.size()) { @@ -667,9 +671,10 @@ StatusOr PjRtLoadedExecutable::Execute( return result; } -StatusOr> PjRtLoadedExecutable::Fingerprint() const { +absl::StatusOr> PjRtLoadedExecutable::Fingerprint() + const { DCHECK(this); - StatusOr fingerprint = + absl::StatusOr fingerprint = pjrt_loaded_executable_->FingerprintExecutable(); if (fingerprint.ok()) { return {fingerprint.value()}; @@ -681,7 +686,7 @@ StatusOr> PjRtLoadedExecutable::Fingerprint() const { } } -StatusOr PjRtLoadedExecutable::Serialize() const { +absl::StatusOr PjRtLoadedExecutable::Serialize() const { DCHECK(this); return pjrt_loaded_executable_->SerializeExecutable(); } diff --git a/xla/python/pjrt_ifrt/pjrt_executable.h b/xla/python/pjrt_ifrt/pjrt_executable.h index d2651d66ea449..f667508a287be 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.h +++ b/xla/python/pjrt_ifrt/pjrt_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" @@ -81,9 +82,9 @@ class PjRtExecutable final : public llvm::RTTIExtends { public: // Creates PjRtExecutable from xla::PjRtExecutable. - static StatusOr> Create( + static absl::StatusOr> Create( std::unique_ptr pjrt_executable); - static StatusOr> Create( + static absl::StatusOr> Create( std::shared_ptr pjrt_executable); // PjRtCompatibleExecutable implementation. @@ -113,19 +114,21 @@ class PjRtExecutable final return pjrt_executable_->GetOutputShardings(); } - StatusOr> GetParameterLayouts() const override { + absl::StatusOr>> GetParameterLayouts() + const override { DCHECK(this); return pjrt_executable_->GetParameterLayouts(); } - StatusOr> GetOutputLayouts() const override { + absl::StatusOr>> GetOutputLayouts() + const override { DCHECK(this); return pjrt_executable_->GetOutputLayouts(); } - StatusOr> Fingerprint() const override; + absl::StatusOr> Fingerprint() const override; - StatusOr Serialize() const override; + absl::StatusOr Serialize() const override; int num_devices() const override { DCHECK(this); @@ -136,18 +139,19 @@ class PjRtExecutable final DCHECK(this); return pjrt_executable_->SizeOfGeneratedCodeInBytes(); } - StatusOr GetCompiledMemoryStats() const override { + absl::StatusOr GetCompiledMemoryStats() const override { DCHECK(this); return pjrt_executable_->GetCompiledMemoryStats(); } - StatusOr>> GetHloModules() + absl::StatusOr>> GetHloModules() const override { DCHECK(this); return pjrt_executable_->GetHloModules(); } - StatusOr> + absl::StatusOr< + absl::flat_hash_map> GetCostAnalysis() const override { return pjrt_executable_->GetCostAnalysis(); } @@ -172,11 +176,11 @@ class PjRtLoadedExecutable final // Creates PjRtExecutable from xla::PjRtLoadedExecutable. We expect that // xla::PjRtLoadedExecutable has fixed output dtypes/shapes/shardings. // PjRtLoadedExecutable::GetHloModules() must be implemented. - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, std::unique_ptr pjrt_loaded_executable, std::vector> loaded_host_callbacks); - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, std::shared_ptr pjrt_loaded_executable, std::vector> loaded_host_callbacks); @@ -186,7 +190,7 @@ class PjRtLoadedExecutable final // options.executable_build_options has use_auto_spmd_partitioning or // allow_spmd_sharding_propagation_to_output enabled, // PjRtLoadedExecutable::GetHloModules() must be implemented. - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, mlir::ModuleOp module, xla::CompileOptions compile_options, std::vector> loaded_host_callbacks); @@ -212,6 +216,12 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->name(); } + Future GetReadyFuture() const override { + // PjRtCompiler blocks until compilation finishes and returns only the + // executables that are ready. + return Future(absl::OkStatus()); + } + std::optional> GetParameterShardings() const override { DCHECK(this); @@ -223,19 +233,21 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->GetOutputShardings(); } - StatusOr> GetParameterLayouts() const override { + absl::StatusOr>> GetParameterLayouts() + const override { DCHECK(this); return pjrt_loaded_executable_->GetParameterLayouts(); } - StatusOr> GetOutputLayouts() const override { + absl::StatusOr>> GetOutputLayouts() + const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputLayouts(); } - StatusOr> Fingerprint() const override; + absl::StatusOr> Fingerprint() const override; - StatusOr Serialize() const override; + absl::StatusOr Serialize() const override; int num_devices() const override { DCHECK(this); @@ -246,19 +258,19 @@ class PjRtLoadedExecutable final DCHECK(this); return pjrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); } - StatusOr GetCompiledMemoryStats() const override { + absl::StatusOr GetCompiledMemoryStats() const override { DCHECK(this); return pjrt_loaded_executable_->GetCompiledMemoryStats(); } - StatusOr>> GetHloModules() + absl::StatusOr>> GetHloModules() const override { DCHECK(this); return pjrt_loaded_executable_->GetHloModules(); } - StatusOr>> GetOutputMemoryKinds() - const override { + absl::StatusOr>> + GetOutputMemoryKinds() const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputMemoryKinds(); } @@ -267,9 +279,9 @@ class PjRtLoadedExecutable final DCHECK(this); return client_; } - StatusOr Execute(absl::Span> args, - const ExecuteOptions& options, - std::optional devices) override; + absl::StatusOr Execute( + absl::Span> args, const ExecuteOptions& options, + std::optional devices) override; Future Delete() override; bool IsDeleted() const override { @@ -287,7 +299,8 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->addressable_devices(); } - StatusOr> + absl::StatusOr< + absl::flat_hash_map> GetCostAnalysis() const override { return pjrt_loaded_executable_->GetCostAnalysis(); } @@ -295,7 +308,7 @@ class PjRtLoadedExecutable final static char ID; // NOLINT private: - static StatusOr> CreateInternal( + static absl::StatusOr> CreateInternal( PjRtCompatibleClient* client, std::shared_ptr pjrt_loaded_executable, absl::Span result_element_types, diff --git a/xla/python/pjrt_ifrt/pjrt_executable_impl_test_tfrt_cpu.cc b/xla/python/pjrt_ifrt/pjrt_executable_impl_test_tfrt_cpu.cc index 462bd2bc202af..ad013b7789b7d 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable_impl_test_tfrt_cpu.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable_impl_test_tfrt_cpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/pjrt_host_callback.cc b/xla/python/pjrt_ifrt/pjrt_host_callback.cc index 345efb9597e69..988789bf69ecc 100644 --- a/xla/python/pjrt_ifrt/pjrt_host_callback.cc +++ b/xla/python/pjrt_ifrt/pjrt_host_callback.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,8 @@ namespace ifrt { char PjRtHostSendAndRecvLoadedHostCallback::ID = 0; -StatusOr PjRtHostSendAndRecvLoadedHostCallback::Serialize() const { +absl::StatusOr PjRtHostSendAndRecvLoadedHostCallback::Serialize() + const { return Unimplemented( "PjRtHostSendAndRecvLoadedHostCallback serialization is not supported"); } diff --git a/xla/python/pjrt_ifrt/pjrt_host_callback.h b/xla/python/pjrt_ifrt/pjrt_host_callback.h index ed898d503992a..b192cbf76144e 100644 --- a/xla/python/pjrt_ifrt/pjrt_host_callback.h +++ b/xla/python/pjrt_ifrt/pjrt_host_callback.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ class PjRtHostSendAndRecvLoadedHostCallback Client* client() const override { return client_; } - StatusOr Serialize() const override; + absl::StatusOr Serialize() const override; static char ID; // NOLINT diff --git a/xla/python/pjrt_ifrt/pjrt_tuple.cc b/xla/python/pjrt_ifrt/pjrt_tuple.cc index 66dc45400922b..c18b2ed6c1d1e 100644 --- a/xla/python/pjrt_ifrt/pjrt_tuple.cc +++ b/xla/python/pjrt_ifrt/pjrt_tuple.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ limitations under the License. namespace xla { namespace ifrt { -/*static*/ StatusOr> PjRtTuple::Create( +/*static*/ absl::StatusOr> PjRtTuple::Create( PjRtCompatibleClient* client, absl::Span> values) { return tsl::MakeRef(client, values); } diff --git a/xla/python/pjrt_ifrt/pjrt_tuple.h b/xla/python/pjrt_ifrt/pjrt_tuple.h index dc79eca54408b..3d4bab2c5d71b 100644 --- a/xla/python/pjrt_ifrt/pjrt_tuple.h +++ b/xla/python/pjrt_ifrt/pjrt_tuple.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ namespace ifrt { class PjRtTuple final : public llvm::RTTIExtends { public: - static StatusOr> Create( + static absl::StatusOr> Create( PjRtCompatibleClient* client, absl::Span> values); ~PjRtTuple() override = default; diff --git a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index b790c4c560616..356c2f9e5d2f3 100644 --- a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,9 +25,9 @@ namespace ifrt { namespace { const bool kUnused = (test_util::RegisterClientFactory( - []() -> StatusOr> { + []() -> absl::StatusOr> { CpuClientOptions options; - options.cpu_device_count = 2; + options.cpu_device_count = 4; TF_ASSIGN_OR_RETURN(auto pjrt_client, xla::GetTfrtCpuClient(options)); return std::shared_ptr( diff --git a/xla/python/pjrt_ifrt/xla_compiler.cc b/xla/python/pjrt_ifrt/xla_compiler.cc index be06e363156f8..165742ad28904 100644 --- a/xla/python/pjrt_ifrt/xla_compiler.cc +++ b/xla/python/pjrt_ifrt/xla_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -86,7 +86,7 @@ char XlaProgram::ID = 0; char XlaCompileOptions::ID = 0; char XlaDeserializeExecutableOptions::ID = 0; -StatusOr> GetXlaCompileOptions( +absl::StatusOr> GetXlaCompileOptions( std::unique_ptr options) { if (!llvm::isa(options.get())) { return xla::InvalidArgument("options must be XlaCompileOptions"); @@ -95,7 +95,7 @@ StatusOr> GetXlaCompileOptions( static_cast(options.release())); } -StatusOr> +absl::StatusOr> GetXlaDeserializeExecutableOptions( std::unique_ptr options) { if (!llvm::isa(options.get())) { diff --git a/xla/python/pjrt_ifrt/xla_compiler.h b/xla/python/pjrt_ifrt/xla_compiler.h index 7581d51615936..dc670d52418a6 100644 --- a/xla/python/pjrt_ifrt/xla_compiler.h +++ b/xla/python/pjrt_ifrt/xla_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -104,12 +104,12 @@ struct XlaDeserializeExecutableOptions }; // Gets `xla::ifrt::XlaCompileOptions` from `xla::ifrt::CompileOptions`. -StatusOr> GetXlaCompileOptions( +absl::StatusOr> GetXlaCompileOptions( std::unique_ptr options); // Gets `xla::ifrt::XlaDeserializeExecutableOptions` from // `xla::ifrt::DeserializeExecutableOptions`. -StatusOr> +absl::StatusOr> GetXlaDeserializeExecutableOptions( std::unique_ptr options); diff --git a/xla/python/pjrt_ifrt/xla_compiler.proto b/xla/python/pjrt_ifrt/xla_compiler.proto index f5a1e2e639d8c..f9f97b10074c6 100644 --- a/xla/python/pjrt_ifrt/xla_compiler.proto +++ b/xla/python/pjrt_ifrt/xla_compiler.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 4ac037f58258b..e0a221d4ba753 100644 --- a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { }})"; // Compiles an MLIR module on specified devices. -StatusOr> CompileOnDevices( +absl::StatusOr> CompileOnDevices( Client* client, Compiler* compiler, absl::string_view mlir_module_str, absl::Span devices, bool replicated) { mlir::MLIRContext context; diff --git a/xla/python/pjrt_ifrt/xla_host_callback.proto b/xla/python/pjrt_ifrt/xla_host_callback.proto index 56045ae35e79a..e583c5f946b29 100644 --- a/xla/python/pjrt_ifrt/xla_host_callback.proto +++ b/xla/python/pjrt_ifrt/xla_host_callback.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/xla_program_serdes.cc b/xla/python/pjrt_ifrt/xla_program_serdes.cc index d8ce5cdce50a1..83348b7b8b6e2 100644 --- a/xla/python/pjrt_ifrt/xla_program_serdes.cc +++ b/xla/python/pjrt_ifrt/xla_program_serdes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/Support/ExtensibleRTTI.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -30,8 +28,10 @@ limitations under the License. #include "stablehlo/api/PortableApi.h" // from @stablehlo #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "tsl/platform/status.h" namespace xla { namespace ifrt { @@ -72,25 +72,10 @@ class XlaProgramSerDes : public llvm::RTTIExtends { mlir::OwningOpRef module( llvm::cast(program.mlir_module->clone())); - mlir::PassManager pm(module->getContext()); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createShapeLegalizeToHloPass()); - pm.addPass(mlir::createReconcileUnrealizedCastsPass()); - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); - if (!mlir::succeeded(pm.run(*module))) { - return absl::InvalidArgumentError( - "CHLO => [MHLO+Shape] => StableHLO failed"); - } - // Serialize portable artifact. - std::string serialized; - llvm::raw_string_ostream os(serialized); - if (mlir::failed(mlir::stablehlo::serializePortableArtifact( - *module, mlir::stablehlo::getCurrentVersion(), os))) { - return absl::InvalidArgumentError("Failed to serialize StableHLO"); - } + TF_ASSIGN_OR_RETURN(std::string serialized, + xla::SerializeUsingVersionedStablehlo( + *module, mlir::stablehlo::getCurrentVersion())); return serialized; } diff --git a/xla/python/pjrt_ifrt/xla_program_serdes_test.cc b/xla/python/pjrt_ifrt/xla_program_serdes_test.cc index 8309d04e4f024..33c1336477c4c 100644 --- a/xla/python/pjrt_ifrt/xla_program_serdes_test.cc +++ b/xla/python/pjrt_ifrt/xla_program_serdes_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/xla_sharding.cc b/xla/python/pjrt_ifrt/xla_sharding.cc index f9f8ec4664c60..d1204b4357aa0 100644 --- a/xla/python/pjrt_ifrt/xla_sharding.cc +++ b/xla/python/pjrt_ifrt/xla_sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -90,7 +92,7 @@ std::unique_ptr HloSharding::Create( std::move(devices), memory_kind, std::move(xla_hlo_sharding))); } -StatusOr>>> +absl::StatusOr>>> HloSharding::Disassemble(const Shape& shape) const { TF_ASSIGN_OR_RETURN(auto index_domains, IndexDomains(shape)); std::vector>> result; @@ -102,7 +104,16 @@ HloSharding::Disassemble(const Shape& shape) const { return result; } -StatusOr> HloSharding::IndexDomains( +absl::StatusOr< + std::vector>>> +HloSharding::Disassemble(const DynamicShape& dynamic_shape) const { + return InvalidArgument( + "HloSharding can only disassemble static shape, but was asked " + "to disassemble dynamic shape %s", + dynamic_shape.DebugString()); +} + +absl::StatusOr> HloSharding::IndexDomains( const Shape& shape) const { auto format_shape = [&] { return absl::StrCat("[", absl::StrJoin(shape.dims(), ","), "]"); diff --git a/xla/python/pjrt_ifrt/xla_sharding.h b/xla/python/pjrt_ifrt/xla_sharding.h index 9711538b99912..ab1eb91e39c5b 100644 --- a/xla/python/pjrt_ifrt/xla_sharding.h +++ b/xla/python/pjrt_ifrt/xla_sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" namespace xla { @@ -56,10 +57,13 @@ class HloSharding final ~HloSharding() override = default; - StatusOr>>> + absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble(const DynamicShape& dynamic_shape) const override; - StatusOr> IndexDomains( + absl::StatusOr> IndexDomains( const Shape& shape) const override; std::string DebugString() const override; diff --git a/xla/python/pjrt_ifrt/xla_sharding.proto b/xla/python/pjrt_ifrt/xla_sharding.proto index aba858b50f5a5..6867d9713c39e 100644 --- a/xla/python/pjrt_ifrt/xla_sharding.proto +++ b/xla/python/pjrt_ifrt/xla_sharding.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ syntax = "proto3"; package xla.ifrt; -import "xla/python/ifrt/types.proto"; +import "xla/python/ifrt/device.proto"; import "xla/xla_data.proto"; // Wire format for `HloSharding`. diff --git a/xla/python/pjrt_ifrt/xla_sharding_serdes.cc b/xla/python/pjrt_ifrt/xla_sharding_serdes.cc index e403752f0e98e..2cd284f025db1 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_serdes.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_serdes.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index b530875bd9099..8adfa7ddc5209 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/pjrt_ifrt/xla_sharding_test.cc b/xla/python/pjrt_ifrt/xla_sharding_test.cc index 46ee68784713e..a5cd85294fe42 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -327,6 +328,21 @@ TEST_P(HloShardingTest, DisassembleFailsWithMismatchingShapeDimsSize) { HasSubstr("shape must have 2 dimensions, but has 1 dimensions"))); } +TEST_P(HloShardingTest, DisassembleFailsWithDynamicShape) { + auto device_list = GetDevices({0, 1}); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2})); + std::shared_ptr sharding = + HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); + + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape dynamic_shape, + DynamicShape::Create(Shape({10}), BoundedDynamicShapeTag({true}))); + EXPECT_THAT(sharding->Disassemble(dynamic_shape), + StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("can only disassemble static shape"))); +} + INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, testing::Values(test_util::ShardingTestParam{ .num_devices = 6, .num_addressable_devices = 4})); diff --git a/xla/python/pmap_lib.cc b/xla/python/pmap_lib.cc index aed4ca042331e..aefe984dc0630 100644 --- a/xla/python/pmap_lib.cc +++ b/xla/python/pmap_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,11 @@ limitations under the License. #include "xla/python/pmap_lib.h" +#include + #include +#include +#include #include #include #include @@ -25,42 +29,56 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/synchronization/notification.h" #include "absl/types/span.h" -#include "absl/types/variant.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/exceptions.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/jax_jit.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" #include "xla/python/py_array.h" -#include "xla/python/py_buffer.h" +#include "xla/python/py_client.h" +#include "xla/python/py_device.h" #include "xla/python/py_executable.h" #include "xla/python/py_values.h" -#include "xla/python/python_utils.h" +#include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharded_device_array.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" +#include "xla/python/traceback.h" #include "xla/python/types.h" -#include "xla/python/util.h" +#include "xla/status_macros.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace jax { -namespace py = pybind11; +namespace nb = nanobind; namespace { @@ -68,21 +86,21 @@ namespace { // from `sharding_specs` and the argument shape, we cache derived computations // for performance. struct InputSpec { - InputSpec(py::object indices, py::object array_sharding) + InputSpec(nb::object indices, nb::object array_sharding) : indices(std::move(indices)), array_sharding(std::move(array_sharding)) {} - py::object indices; - py::object array_sharding; + nb::object indices; + nb::object array_sharding; }; // An object containing the arguments to create Array from the // output buffers. struct ResultSpec { public: - explicit ResultSpec(py::object aval) + explicit ResultSpec(nb::object aval) : out_aval(std::move(aval)), - weak_type(py::cast(out_aval.attr("weak_type"))) {} - py::object out_aval; + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; bool weak_type; }; @@ -92,7 +110,7 @@ struct ShardArgResult { // ifrt_array->sharding().num_shards() == `num_devices`. tsl::RCReference ifrt_array; // The Python argument will be always be copied to `owning_sda`. - py::object owning_sda; + nb::object owning_sda; }; // Shars a single argument over devices. @@ -112,56 +130,53 @@ struct ShardArgResult { // need to fallback to Python. // // Both `devices` and `sharding_spec` has the same length. -xla::StatusOr ShardArg( - py::handle arg, absl::Span devices, - const InputSpec& input_spec, py::handle py_devices, - const py::function& python_fallback) { - if (arg.get_type() == xla::PyArray::type()) { - auto py_array = py::reinterpret_borrow(arg); - if (py_array.fastpath_enabled()) { - if (py_array.sharding().get_type() == - input_spec.array_sharding.get_type()) { - auto* pmap_sharding = py_array.sharding().cast(); - auto* cached_pmap_sharding = - input_spec.array_sharding.cast(); - - if (pmap_sharding->sharding_spec() == - cached_pmap_sharding->sharding_spec()) { - ShardArgResult result; - result.owning_sda = py::reinterpret_borrow(arg); - result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); - if (result.ifrt_array == nullptr) { - return xla::InvalidArgument("Array has been deleted."); - } - if (result.ifrt_array->sharding().devices().devices() != devices) { - xla::ifrt::DeviceList::Devices ifrt_devices; - ifrt_devices.reserve(devices.size()); - ifrt_devices.insert(ifrt_devices.end(), devices.begin(), - devices.end()); - // pmap does not support memory_kind for now. - auto sharding = xla::ifrt::OpaqueSharding::Create( - xla::ifrt::DeviceList(std::move(ifrt_devices)), - xla::ifrt::MemoryKind()); - TF_ASSIGN_OR_RETURN( - auto copied_ifrt_array, - result.ifrt_array->Reshard( - std::move(sharding), - xla::ifrt::ArrayCopySemantics::kReuseInput)); - result.ifrt_array = std::move(copied_ifrt_array); - } - return result; +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec& input_spec, nb::handle py_devices, + const nb::callable& python_fallback) { + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto* pmap_sharding = + nb::cast(nb::handle(py_array.sharding().ptr())); + auto* cached_pmap_sharding = nb::cast( + nb::handle(input_spec.array_sharding.ptr())); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg.ptr()); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices().devices() != devices) { + xla::ifrt::DeviceList::Devices ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto sharding = xla::ifrt::OpaqueSharding::Create( + xla::ifrt::DeviceList(std::move(ifrt_devices)), + xla::ifrt::MemoryKind()); + TF_ASSIGN_OR_RETURN(auto copied_ifrt_array, + result.ifrt_array->Reshard( + std::move(sharding), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_array); } + return result; } } } - static auto ndarray_type = py::module::import("numpy").attr("ndarray").ptr(); - auto ndarray = py::array::ensure(arg); - if (ndarray && py::type::of(arg) == ndarray_type && + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); - py::list indices = input_spec.indices; - py::list py_devices_list = py::cast(py_devices); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); auto n_devices = py_devices_list.size(); if (indices.size() != n_devices) { return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", @@ -177,7 +192,7 @@ xla::StatusOr ShardArg( std::vector shapes; shapes.reserve(n_devices); - py::list owning_pylist(n_devices); + nb::list owning_pylist; ShardArgResult result; result.owning_sda = owning_pylist; const bool jax_enable_x64 = GetEnableX64(); @@ -186,16 +201,15 @@ xla::StatusOr ShardArg( options.squash_64bit_types = !jax_enable_x64; options.allow_zero_copy = true; for (size_t i = 0; i < n_devices; ++i) { - auto to_device = - py::cast>(py_devices_list[i]); - if (to_device.get_client() == nullptr) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { return xla::InvalidArgument("Cannot copy to unattached devices."); } TF_ASSIGN_OR_RETURN( xla::DevicePutResult on_device, - DevicePut(arg[indices[i]], to_device.get_client()->ifrt_client(), - to_device.get(), options, xla::ifrt::MemoryKind())); + DevicePut(arg[indices[i]], to_device->client()->ifrt_client(), + to_device->device(), options, xla::ifrt::MemoryKind())); per_device_arrays.push_back(std::move(on_device.ifrt_array)); devices.push_back(per_device_arrays.back()->sharding().devices().front()); @@ -224,12 +238,11 @@ xla::StatusOr ShardArg( return result; } tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); - auto py_array_or_bufs = python_fallback(arg, py_devices, input_spec.indices, - input_spec.array_sharding); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); - auto py_array = py::cast(py_array_or_bufs); + auto py_array = nb::cast(py_array_or_bufs); ShardArgResult result; - result.owning_sda = py_array_or_bufs; + result.owning_sda = nb::borrow(py_array_or_bufs.ptr()); result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); return result; } @@ -239,15 +252,15 @@ struct PmapCacheEntry { : out_pytree_def(registry) {} std::shared_ptr executable; // The value `backend.local_devices()`. - py::object py_devices; // To pass back to Python. + nb::object py_devices; // To pass back to Python. std::vector devices; std::vector input_specs; xla::PyTreeDef out_pytree_def; // Objects necessary to build the out Array objects. std::vector out_result_specs; - std::vector out_array_shardings; - std::vector out_dtypes; + std::vector out_array_shardings; + std::vector out_dtypes; std::vector> out_shapes; std::vector out_committed; @@ -268,9 +281,9 @@ struct PmapCacheEntry { // the correct underlying `PyLoadedExecutable`. This class is thread-safe. class PmapFunction { public: - PmapFunction(py::function fun, py::function cache_miss, + PmapFunction(nb::callable fun, nb::callable cache_miss, std::vector static_argnums, - py::function python_shard_arg_fallback, + nb::callable python_shard_arg_fallback, std::shared_ptr pytree_registry) : fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), @@ -279,7 +292,8 @@ class PmapFunction { python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { std::sort(static_argnums_.begin(), static_argnums_.end()); - function_name_ = py::str(py::getattr(fun_, "__name__", fun_)); + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); } PmapFunction(const PmapFunction&) = delete; PmapFunction& operator=(const PmapFunction& other) = delete; @@ -292,49 +306,50 @@ class PmapFunction { // (c) call the executable // (d) construct `Array` objects from the outputs // (e) reconstruct the `PyTree`. - xla::StatusOr Call(py::handle callable, PyObject* const* args, - size_t nargs, PyObject* kwnames); + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); - py::object PythonSignature() { - static const auto* inspect = new py::module(py::module::import("inspect")); + nb::object PythonSignature() { + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); return inspect->attr("signature")(fun_); } int cache_size() const { return executables_.size(); } void cache_clear() { return executables_.clear(); } - const py::function& fun() const { return fun_; } - const py::function& cache_miss() const { return cache_miss_; } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } const std::string& function_name() const { return function_name_; } const std::shared_ptr& pytree_registry() const { return pytree_registry_; } - const py::function& python_shard_arg_fallback() const { + const nb::callable& python_shard_arg_fallback() const { return python_shard_arg_fallback_; } const std::vector& static_argnums() const { return static_argnums_; } - // pybind11::object typed subclass for PmapFunction objects. - class pyobject : public py::object { + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { public: - PYBIND11_OBJECT(pyobject, // NOLINT - py::object, PmapFunction::IsPmapFunction); + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); pyobject() = default; PmapFunction* func() const { return PmapFunction::AsPmapFunctionUnchecked(*this); } }; - // Alias as ::object; outside the scope above we won't confuse pybind11's + // Alias as ::object; outside the scope above we won't confuse nanobind's // macros. using object = pyobject; // Returns true if `h` is a PmapFunction. - static bool IsPmapFunction(py::handle handle); + static bool IsPmapFunction(nb::handle handle); // Converts `handle` to a PmapFunction*. Does not do any checking. - static PmapFunction* AsPmapFunctionUnchecked(py::handle handle); + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); // Helper function used by the tp_clear GC method. void ClearPythonReferences() { - py::function fun, cache_miss, python_shard_arg_fallback; + nb::callable fun, cache_miss, python_shard_arg_fallback; // Swap values for nulls before they are destroyed. See the Python // Py_CLEAR() documentation for a discussion of this topic. std::swap(fun_, fun); @@ -346,40 +361,29 @@ class PmapFunction { // // It deals with the arguments signatures and also of the global and // thread-local jit context. - xla::Status UpdateArgsSignature(ParsedArgumentsAsBuffers& arguments) { - arguments.signature.function_name = function_name_; + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; // Get dynamic argument signatures. JitState& global_state = jax::GlobalJitState(); JitState& tls = jax::ThreadLocalJitState(); const bool jax_enable_x64 = GetEnableX64(); - arguments.signature.jax_enable_x64 = jax_enable_x64; - for (py::handle arg : arguments.flat_dynamic_args) { + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); if (!signature_or_error.ok()) { VLOG(2) << "PyArgSignatureOfValue failed: " << signature_or_error.status(); return signature_or_error.status(); } - arguments.signature.dynamic_arg_signatures.push_back( + signature.dynamic_arg_signatures.push_back( std::move(signature_or_error).value()); } - try { - py::object pxla_module = py::module::import("jax").attr("config"); - py::object sda = py::getattr(pxla_module, "_trace_context", py::none()); - if (!sda.is_none()) { - arguments.signature.thread_local_extra_jit_context = sda(); - } - } catch (const py::error_already_set& e) { - // Ignore; jax may not be present. - } - if (!arguments.signature.thread_local_extra_jit_context.has_value()) { - arguments.signature.thread_local_extra_jit_context = - tls.extra_jit_context; - arguments.signature.global_extra_jit_context = - global_state.extra_jit_context; - } - return xla::Status(); + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + return absl::Status(); } // Returns, for debugging purposes (e.g. finding why some call misses the @@ -399,15 +403,14 @@ class PmapFunction { private: // Mutates `cache_entry` in place. void PopulateCacheEntry(PmapCacheEntry& cache_entry, - const CallSignature& signature, - const py::tuple& out_and_fastpath_data); + const nb::tuple& out_and_fastpath_data); bool always_fallback_to_python_ = false; - py::function fun_; // The Python function to pmap. + nb::callable fun_; // The Python function to pmap. std::string function_name_; // See JAX _cpp_pmap in api.py for documentation. - py::function cache_miss_; + nb::callable cache_miss_; // We need to know the static arguments to remove them from the arguments // passed to the underlying PyLoadedExecutable. In sorted order. @@ -419,54 +422,54 @@ class PmapFunction { // The fallback function to use with `ShardArgs`. // TODO(jblespiau): Add support for more types from C++. - py::function python_shard_arg_fallback_; + nb::callable python_shard_arg_fallback_; }; void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, - const CallSignature& signature, - const py::tuple& out_and_fastpath_data) { + const nb::tuple& out_and_fastpath_data) { CHECK_EQ(out_and_fastpath_data.size(), 2); if (out_and_fastpath_data[1].is_none()) { cache_entry.fall_back_to_python = true; return; } - py::tuple pmap_data = py::cast(out_and_fastpath_data[1]); - if (py::cast(pmap_data.attr("version")) != 1) { + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { throw xla::XlaRuntimeError(absl::StrCat( "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " "expected, but got ", - py::cast(pmap_data.attr("version")), + nb::cast(pmap_data.attr("version")), "Upgrade jaxlib and jax. Provided data was:", - py::cast(py::str(py::repr(pmap_data))))); + nb::cast(nb::str(nb::repr(pmap_data))))); } - // See api.py::_PmapFastpathData in the JAX code base for the expected + // See api.nb::_PmapFastpathData in the JAX code base for the expected // namedtuple. std::shared_ptr executable; try { - executable = py::cast>( + executable = nb::cast>( pmap_data.attr("xla_executable")); - } catch (const py::cast_error& e) { + } catch (const nb::cast_error& e) { // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; always_fallback_to_python_ = true; return; } cache_entry.executable = std::move(executable); - const std::vector>& client_and_devices = + const std::vector>& devices = cache_entry.executable->AddressableDevices(); - cache_entry.devices.reserve(client_and_devices.size()); - for (auto& client_and_device : client_and_devices) { - cache_entry.devices.push_back(client_and_device.get()); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); } // Inputs shard args details. - py::list input_indices = pmap_data.attr("input_indices"); + nb::list input_indices = pmap_data.attr("input_indices"); cache_entry.py_devices = pmap_data.attr("input_devices"); - auto input_devices = - py::cast>(pmap_data.attr("input_devices")); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); - py::list input_array_shardings = pmap_data.attr("input_array_shardings"); + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); cache_entry.input_specs.reserve(input_array_shardings.size()); @@ -476,9 +479,10 @@ void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, } // Outputs specs. - auto out_tree = py::cast(pmap_data.attr("out_pytree_def")); + auto out_tree = nb::cast( + nb::handle(pmap_data.attr("out_pytree_def").ptr())); cache_entry.out_pytree_def = std::move(out_tree); - py::list out_avals = pmap_data.attr("out_avals"); + nb::list out_avals = pmap_data.attr("out_avals"); cache_entry.out_result_specs.reserve(out_avals.size()); cache_entry.out_dtypes.reserve(out_avals.size()); @@ -487,40 +491,40 @@ void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, for (int i = 0; i < out_avals.size(); ++i) { cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); cache_entry.out_shapes.push_back( - py::cast>(out_avals[i].attr("shape"))); + nb::cast>(out_avals[i].attr("shape"))); cache_entry.out_result_specs.emplace_back(out_avals[i]); } - py::list out_array_shardings = pmap_data.attr("out_array_shardings"); + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); - DCHECK(out_array_shardings.empty() || + DCHECK(out_array_shardings.size() == 0 || out_avals.size() == out_array_shardings.size()); cache_entry.out_array_shardings.reserve(out_array_shardings.size()); - for (py::handle out_array_sharding : out_array_shardings) { + for (nb::handle out_array_sharding : out_array_shardings) { cache_entry.out_array_shardings.push_back( - py::reinterpret_borrow(out_array_sharding)); + nb::borrow(out_array_sharding)); } - py::list out_committed = pmap_data.attr("out_committed"); + nb::list out_committed = pmap_data.attr("out_committed"); - DCHECK(out_committed.empty() || out_avals.size() == out_committed.size()); + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); cache_entry.out_committed.reserve(out_committed.size()); - for (py::handle c : out_committed) { - cache_entry.out_committed.push_back(py::cast(c)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); } } -xla::StatusOr PmapFunction::Call(py::handle callable, - PyObject* const* args, - size_t nargs, PyObject* kwnames) { +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { xla::GlobalPyRefManager()->MaybeCollectGarbage(); // Calls the cache_miss_ function. This just calls the Python function; it may // return nullptr value if a Python exception is thrown. - auto cache_miss = [&]() -> py::tuple { - return py::reinterpret_steal( + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); }; @@ -528,11 +532,11 @@ xla::StatusOr PmapFunction::Call(py::handle callable, // the fastpath data. If the cache miss returns a Python error, returns // nullptr and leaves the Python error set. auto fallback_to_cache_miss = [&]() { - py::tuple cache_miss_output = cache_miss(); + nb::tuple cache_miss_output = cache_miss(); if (!cache_miss_output.ptr()) { - return py::object(); + return nb::object(); } - return py::object(cache_miss_output[0]); + return nb::object(cache_miss_output[0]); }; if (always_fallback_to_python_) { @@ -544,16 +548,19 @@ xla::StatusOr PmapFunction::Call(py::handle callable, absl::Span positional_args(args, num_positional_args); absl::Span keyword_args(args + num_positional_args, num_keyword_args); - ParsedArgumentsAsBuffers arguments; - xla::Status status = + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, - /*static_argnames=*/{}, pytree_registry_.get(), arguments); + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); if (!status.ok()) { VLOG(2) << "ParseArguments failed: " << status; return fallback_to_cache_miss(); } - status = UpdateArgsSignature(arguments); + status = ComputeCallSignature(flat_dynamic_args, call_signature); if (!status.ok()) { return fallback_to_cache_miss(); } @@ -563,7 +570,7 @@ xla::StatusOr PmapFunction::Call(py::handle callable, it; bool inserted; std::tie(it, inserted) = executables_.try_emplace( - arguments.signature, std::unique_ptr()); + call_signature, std::unique_ptr()); if (inserted) { it->second = std::make_unique(pytree_registry_.get()); } @@ -573,18 +580,19 @@ xla::StatusOr PmapFunction::Call(py::handle callable, // In case of several threads attempting to compile the executable, only // the one that inserted the item will perform the compilation. if (inserted) { - py::object out_and_fastpath_data; - py::tuple out_tuple; - VLOG(2) << "Cache miss for " << arguments.signature.DebugString(); + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); try { // Calls Python and may release the GIL. May also throw if // compilation/tracing fails. out_and_fastpath_data = cache_miss(); if (!out_and_fastpath_data.ptr()) { - throw py::error_already_set(); + throw nb::python_error(); } - out_tuple = py::cast(out_and_fastpath_data); - PopulateCacheEntry(cache_entry, arguments.signature, out_tuple); + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); } catch (const std::exception& e) { cache_entry.fall_back_to_python = true; cache_entry.compilation_complete.Notify(); @@ -595,11 +603,11 @@ xla::StatusOr PmapFunction::Call(py::handle callable, // We have already computed the result in the miss path so we can return // it. We are even *required* to do so if there are donated arguments, // because any donated buffers will now be invalid. - return py::object(out_tuple[0]); + return nb::object(out_tuple[0]); } else { // Release the GIL while we wait, making sure the compile thread can // lock it. - py::gil_scoped_release release; + nb::gil_scoped_release release; cache_entry.compilation_complete.WaitForNotification(); } } @@ -610,26 +618,26 @@ xla::StatusOr PmapFunction::Call(py::handle callable, // 1. Parse arguments. std::vector& input_devices = cache_entry.devices; std::vector& input_specs = cache_entry.input_specs; - const int num_args = arguments.flat_dynamic_args.size(); + const int num_args = flat_dynamic_args.size(); // We need [num_args] for the `Execute` call below. std::vector> num_args_arrays(num_args); for (int i = 0; i < num_args; ++i) { TF_ASSIGN_OR_RETURN( ShardArgResult sharded_arg, - ShardArg(arguments.flat_dynamic_args[i], input_devices, input_specs[i], + ShardArg(flat_dynamic_args[i].ptr(), input_devices, input_specs[i], cache_entry.py_devices, python_shard_arg_fallback_)); num_args_arrays[i] = std::move(sharded_arg.ifrt_array); if (sharded_arg.owning_sda) { - arguments.keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); } } // A vector of [num_outputs]. std::vector> output_arrays; { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; auto ifrt_executable = cache_entry.executable->ifrt_executable(); TF_ASSIGN_OR_RETURN( auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), @@ -644,12 +652,12 @@ xla::StatusOr PmapFunction::Call(py::handle callable, // we access them from Python. auto traceback = xla::Traceback::Get(); // TODO(jblespiau): Change the `client` function to return a reference. - std::shared_ptr client = cache_entry.executable->client(); + xla::nb_class_ptr client = cache_entry.executable->client(); // Convert the PjRtBuffer objects to PyBuffer, and invert the order from // [num_devices, num_args] to [num_args, num_devices]. const int num_outputs = output_arrays.size(); - std::vector flat_sharded_device_arrays; + std::vector flat_sharded_device_arrays; flat_sharded_device_arrays.reserve(num_outputs); const auto& output_specs = cache_entry.out_result_specs; @@ -666,25 +674,28 @@ xla::StatusOr PmapFunction::Call(py::handle callable, flat_sharded_device_arrays.push_back(std::move(py_array)); } - py::object out = + nb::object out = cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); // If there is a post-hook function, call it with the inputs and the outputs. - std::optional post_hook = GetPostHook(); + std::optional post_hook = GetPostHook(); if (post_hook) { - py::tuple args_tuple(num_positional_args); + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); for (size_t i = 0; i < num_positional_args; ++i) { - args_tuple[i] = args[i]; + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); } - py::dict kwargs; + nb::dict kwargs; if (kwnames) { for (size_t i = 0; i < num_keyword_args; ++i) { - kwargs[py::handle(PyTuple_GET_ITEM(kwnames, i))] = - args[num_positional_args + i]; + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); } } - (*post_hook)(callable, args_tuple, kwargs, out); + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); } return out; @@ -700,15 +711,15 @@ struct JaxPmapFunctionObject { PyObject* JaxPmapFunction_Type = nullptr; -bool PmapFunction::IsPmapFunction(py::handle handle) { - return handle.get_type() == JaxPmapFunction_Type; +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; } -PmapFunction* PmapFunction::AsPmapFunctionUnchecked(py::handle handle) { +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { return &(reinterpret_cast(handle.ptr())->fun); } -xla::StatusOr AsPmapFunction(py::handle handle) { +absl::StatusOr AsPmapFunction(nb::handle handle) { if (!PmapFunction::IsPmapFunction(handle)) { return xla::InvalidArgument("Expected a PmapFunction"); } @@ -727,16 +738,17 @@ PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); }); try { - xla::StatusOr out = o->fun.Call(callable, args, nargs, kwnames); + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); if (!out.ok()) { PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); return nullptr; } return out.value().release().ptr(); - } catch (py::error_already_set& e) { + } catch (nb::python_error& e) { e.restore(); return nullptr; - } catch (py::cast_error& e) { + } catch (nb::cast_error& e) { PyErr_SetString(PyExc_ValueError, e.what()); return nullptr; } catch (std::invalid_argument& e) { @@ -833,11 +845,11 @@ static PyGetSetDef JaxPmapFunction_tp_getset[] = { } // extern "C" -py::object MakePmapFunction( - py::function fun, py::function cache_miss, std::vector static_argnums, - py::function python_shard_arg_fallback, +nb::object MakePmapFunction( + nb::callable fun, nb::callable cache_miss, std::vector static_argnums, + nb::callable python_shard_arg_fallback, std::shared_ptr pytree_registry) { - py::object obj = py::reinterpret_steal(JaxPmapFunction_tp_new( + nb::object obj = nb::steal(JaxPmapFunction_tp_new( reinterpret_cast(JaxPmapFunction_Type), nullptr, nullptr)); JaxPmapFunctionObject* buf = reinterpret_cast(obj.ptr()); @@ -853,65 +865,77 @@ const int kPmapFunctionPickleVersion = 1; } // namespace -void BuildPmapSubmodule(py::module& m) { - py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); +void BuildPmapSubmodule(nb::module_& m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + nb::module_ pmap_lib_nb = nb::cast(nb::borrow(pmap_lib.ptr())); - py::class_ no_sharding(pmap_lib, "NoSharding"); - no_sharding.def(py::init<>()) - .def(py::pickle([](const NoSharding& self) { return py::make_tuple(); }, - [](py::tuple t) { return NoSharding{}; })) + nb::class_ no_sharding(pmap_lib_nb, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding& self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); }) .def("__repr__", [](const NoSharding& chuncked) { return "NoSharding()"; }) .def("__eq__", - [](const NoSharding& self, py::object obj) { - return py::isinstance(obj); + [](const NoSharding& self, nb::object obj) { + return nb::isinstance(obj); }) .def("__hash__", [](const NoSharding& self) { const size_t hash = absl::HashOf(self); - return py::int_(hash); + return nb::int_(hash); }); - py::class_ chunked(pmap_lib, "Chunked"); - chunked.def(py::init>()) - .def(py::pickle( - [](const Chunked& self) { return py::make_tuple(self.chunks); }, - [](py::tuple t) { return Chunked{t[0].cast>()}; })) - .def_readonly("chunks", &Chunked::chunks) + nb::class_ chunked(pmap_lib_nb, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked& self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked& self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) .def("__repr__", [](const Chunked& chuncked) { return absl::StrCat("Chunked(", absl::StrJoin(chuncked.chunks, ","), ")"); }) - .def("__eq__", [](const Chunked& self, py::object other) { - if (!py::isinstance(other)) { + .def("__eq__", [](const Chunked& self, nb::object other) { + if (!nb::isinstance(other)) { return false; } - return self == py::cast(other); + return self == nb::cast(other); }); - py::class_ unstacked(pmap_lib, "Unstacked"); - unstacked.def(py::init()) - .def(py::pickle( - [](const Unstacked& self) { return py::make_tuple(self.size); }, - [](py::tuple t) { return Unstacked{t[0].cast()}; })) - .def_readonly("size", &Unstacked::size) + nb::class_ unstacked(pmap_lib_nb, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked& self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked& self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) .def("__repr__", [](const Unstacked& x) { return absl::StrCat("Unstacked(", x.size, ")"); }) - .def("__eq__", [](const Unstacked& self, py::object other) { - if (!py::isinstance(other)) { + .def("__eq__", [](const Unstacked& self, nb::object other) { + if (!nb::isinstance(other)) { return false; } - return self == py::cast(other); + return self == nb::cast(other); }); - py::class_ sharded_axis(pmap_lib, "ShardedAxis"); - sharded_axis.def(py::init()) - .def(py::pickle( - [](const ShardedAxis& self) { return py::make_tuple(self.axis); }, - [](py::tuple t) { return ShardedAxis{t[0].cast()}; })) - .def_readonly("axis", &ShardedAxis::axis) + nb::class_ sharded_axis(pmap_lib_nb, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis& self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) .def("__repr__", [](const ShardedAxis& x) { return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); @@ -920,12 +944,15 @@ void BuildPmapSubmodule(py::module& m) { return self == other; }); - py::class_ replicated(pmap_lib, "Replicated"); - replicated.def(py::init()) - .def(py::pickle( - [](const Replicated& self) { return py::make_tuple(self.replicas); }, - [](py::tuple t) { return Replicated{t[0].cast()}; })) - .def_readonly("replicas", &Replicated::replicas) + nb::class_ replicated(pmap_lib_nb, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated& self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated& self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) .def("__repr__", [](const Replicated& x) { return absl::StrCat("Replicated(replicas=", x.replicas, ")"); @@ -934,45 +961,47 @@ void BuildPmapSubmodule(py::module& m) { return self == other; }); - py::class_ sharding_spec(pmap_lib, "ShardingSpec"); + nb::class_ sharding_spec(pmap_lib_nb, "ShardingSpec"); sharding_spec - .def(py::init(), py::arg("sharding"), - py::arg("mesh_mapping")) - .def(py::pickle( - [](const ShardingSpec& self) { - auto sharding = - xla::SpanToTuple(absl::MakeConstSpan(self.GetSharding())); - auto mesh_mapping = - xla::SpanToTuple(absl::MakeConstSpan(self.GetMeshMapping())); - return py::make_tuple(sharding, mesh_mapping); - }, - [](py::tuple t) { - return ShardingSpec{t[0].cast>(), - t[1].cast>()}; - })) - .def_property_readonly( + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec& self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec& self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( "sharding", [](const ShardingSpec& self) { - return xla::SpanToTuple(absl::MakeConstSpan(self.GetSharding())); - }) - .def_property_readonly( - "mesh_mapping", - [](const ShardingSpec& self) { - return xla::SpanToTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) .def("__eq__", [](const ShardingSpec& self, const ShardingSpec& other) { return self == other; }) .def("__hash__", [](const ShardingSpec& self) { const size_t hash = absl::HashOf(self); - return py::int_(hash); + return nb::int_(hash); }); // We need to use heap-allocated type objects because we want to add // additional methods dynamically. - py::object cfun; + nb::object cfun; { - py::str name = py::str("PmapFunction"); - py::str qualname = py::str("PmapFunction"); + nb::str name = nb::str("PmapFunction"); + nb::str qualname = nb::str("PmapFunction"); PyHeapTypeObject* heap_type = reinterpret_cast( PyType_Type.tp_alloc(&PyType_Type, 0)); // Caution: we must not call any functions that might invoke the GC until @@ -998,41 +1027,40 @@ void BuildPmapSubmodule(py::module& m) { type->tp_vectorcall_offset = offsetof(JaxPmapFunctionObject, vectorcall); CHECK_EQ(PyType_Ready(type), 0); JaxPmapFunction_Type = reinterpret_cast(type); - cfun = py::reinterpret_borrow(JaxPmapFunction_Type); + cfun = nb::borrow(JaxPmapFunction_Type); } - py::object cfun_type = - py::reinterpret_borrow(JaxPmapFunction_Type); + nb::object cfun_type = nb::borrow(JaxPmapFunction_Type); // Add PmapFunction to the xla_extension module so it can be pickled. m.attr("PmapFunction") = cfun_type; cfun.attr("__signature__") = - property_readonly([](py::handle self) -> py::object { + xla::nb_property_readonly([](nb::handle self) -> nb::object { PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); return fun->PythonSignature(); }); // Required by `post_hook`. cfun.attr("_cache_miss") = - property_readonly([](py::handle self) -> py::object { + xla::nb_property_readonly([](nb::handle self) -> nb::object { PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); return fun->cache_miss(); }); - cfun.attr("__getstate__") = py::cpp_function( + cfun.attr("__getstate__") = nb::cpp_function( [](const PmapFunction::object& self) { PmapFunction* fn = self.func(); - py::dict pickle; + nb::dict pickle; pickle["version"] = kPmapFunctionPickleVersion; pickle["fun"] = fn->fun(); pickle["cache_miss"] = fn->cache_miss(); pickle["static_argnums"] = fn->static_argnums(); pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); - pickle["pytree_registry"] = fn->pytree_registry(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); return pickle; }, - py::is_method(cfun_type)); - cfun.attr("__setstate__") = py::cpp_function( - [](PmapFunction::object& self, const py::dict& pickle) { - int version = py::cast(pickle["version"]); + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); if (version != kPmapFunctionPickleVersion) { throw std::invalid_argument(absl::StrFormat( "Invalid PmapFunction pickle version, got %d, expected %d. " @@ -1040,55 +1068,58 @@ void BuildPmapSubmodule(py::module& m) { "versions is not supported.", version, kPmapFunctionPickleVersion)); } - py::function fun = py::cast(pickle["fun"]); - py::function cache_miss = py::cast(pickle["cache_miss"]); + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); std::vector static_argnums = - py::cast>(pickle["static_argnums"]); - py::function python_shard_arg_fallback = - py::cast(pickle["python_shard_arg_fallback"]); - auto pytree_registry = - pickle["pytree_registry"] - .cast>(); + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + std::shared_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); new (&(reinterpret_cast(self.ptr())->fun)) PmapFunction(std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(python_shard_arg_fallback), std::move(pytree_registry)); }, - py::is_method(cfun_type)); + nb::is_method()); // This is only for testing/debugging purposes. cfun.attr("_cache_size") = - property_readonly([](py::handle self) -> py::object { + xla::nb_property_readonly([](nb::handle self) -> nb::object { PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); - return py::cast(fun->cache_size()); + return nb::cast(fun->cache_size()); }); - cfun.attr("_cache_clear") = py::cpp_function( - [](py::handle self) { + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); fun->cache_clear(); }, - py::is_method(cfun)); + nb::is_method()); - cfun.attr("_debug_cache_keys") = py::cpp_function( - [](py::handle self) -> std::string { + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); return fun->DebugCacheKeys(); }, - py::is_method(cfun_type)); + nb::is_method()); pmap_lib.def( "pmap", - [](py::function fun, py::function cache_miss, - std::vector static_argnums, py::function shard_arg_fallback, - std::shared_ptr pytree_registry) -> py::object { + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + std::shared_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); return MakePmapFunction( std::move(fun), std::move(cache_miss), std::move(static_argnums), - std::move(shard_arg_fallback), std::move(pytree_registry)); + std::move(shard_arg_fallback), std::move(registry)); }, - py::arg("fun"), py::arg("cache_miss"), py::arg("static_argnums"), - py::arg("shard_arg_fallback"), py::arg("pytree_registry")); + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); } } // namespace jax diff --git a/xla/python/pmap_lib.h b/xla/python/pmap_lib.h index b76b05a98ae0f..307f22c3a4a3a 100644 --- a/xla/python/pmap_lib.h +++ b/xla/python/pmap_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,15 +21,7 @@ limitations under the License. #include // placeholder for index annotation headers -#include "absl/types/variant.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" -#include "xla/python/py_buffer.h" -#include "xla/python/sharded_device_array.h" -#include "xla/python/types.h" +#include "nanobind/nanobind.h" // from @nanobind // TODO(jblespiau): The current implementation moves the Python logic to C++, // as a preliminary step to executing the `pmap` execution path from C++. @@ -38,7 +30,7 @@ limitations under the License. namespace jax { -void BuildPmapSubmodule(pybind11::module& m); +void BuildPmapSubmodule(nanobind::module_& m); } // namespace jax diff --git a/xla/python/pprof_profile_builder.cc b/xla/python/pprof_profile_builder.cc index 50a788f6da676..0953f86f885bc 100644 --- a/xla/python/pprof_profile_builder.cc +++ b/xla/python/pprof_profile_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,24 +15,29 @@ limitations under the License. #include "xla/python/pprof_profile_builder.h" +#include // IWYU pragma: keep + #include +#include #include -#include "xla/python/traceback.h" -#include "xla/statusor.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep #include "xla/util.h" +#include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; PprofProfileBuilder::PprofProfileBuilder() { CHECK_EQ(0, StringId("")); } -int PprofProfileBuilder::StringId(const std::string& s) { +int PprofProfileBuilder::StringId(std::string_view s) { auto ret = strings_.emplace(s, profile_.string_table_size()); if (ret.second) { - profile_.add_string_table(s); + profile_.add_string_table(s.data(), s.size()); } return ret.first->second; } @@ -43,10 +48,11 @@ int PprofProfileBuilder::FunctionId(PyCodeObject* code) { if (ret.second) { auto* function = profile_.add_function(); function->set_id(ret.first->second); - int name = StringId(py::str(code->co_name)); + int name = StringId(nb::cast(nb::str(code->co_name))); function->set_name(name); function->set_system_name(name); - function->set_filename(StringId(py::str(code->co_filename))); + function->set_filename( + StringId(nb::cast(nb::str(code->co_filename)))); function->set_start_line(code->co_firstlineno); } return ret.first->second; @@ -66,7 +72,7 @@ int PprofProfileBuilder::LocationId(PyCodeObject* code, int instruction) { return ret.first->second; } -StatusOr JsonToPprofProfile(std::string json) { +absl::StatusOr JsonToPprofProfile(std::string json) { tensorflow::tfprof::pprof::Profile profile; auto status = tsl::protobuf::util::JsonStringToMessage(json, &profile); if (!status.ok()) { @@ -76,12 +82,13 @@ StatusOr JsonToPprofProfile(std::string json) { return InvalidArgument("JSON parsing failed: %s", std::string{status.message()}); } - return py::bytes(profile.SerializeAsString()); + std::string s = profile.SerializeAsString(); + return nb::bytes(s.data(), s.size()); } -StatusOr PprofProfileToJson(py::bytes binary_proto) { +absl::StatusOr PprofProfileToJson(nb::bytes binary_proto) { tensorflow::tfprof::pprof::Profile profile; - profile.ParseFromString(binary_proto); + profile.ParseFromArray(binary_proto.c_str(), binary_proto.size()); std::string output; auto status = tsl::protobuf::util::MessageToJsonString(profile, &output); if (!status.ok()) { diff --git a/xla/python/pprof_profile_builder.h b/xla/python/pprof_profile_builder.h index e9cc0ef0ae692..9902f367d5bd5 100644 --- a/xla/python/pprof_profile_builder.h +++ b/xla/python/pprof_profile_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,15 @@ limitations under the License. #ifndef XLA_PYTHON_PPROF_PROFILE_BUILDER_H_ #define XLA_PYTHON_PPROF_PROFILE_BUILDER_H_ +#include + #include +#include #include #include "absl/container/flat_hash_map.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/statusor.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind #include "tsl/profiler/protobuf/profile.pb.h" namespace xla { @@ -33,7 +36,7 @@ class PprofProfileBuilder { tensorflow::tfprof::pprof::Profile& profile() { return profile_; } // Adds or returns the ID of `s` in the table. - int StringId(const std::string& s); + int StringId(std::string_view s); // Adds or returns the ID of a function. int FunctionId(PyCodeObject* code); @@ -56,10 +59,10 @@ class PprofProfileBuilder { // extensions that contain the same protocol buffer message. Instead, we accept // a JSON representation from Python and use this function to serialize it to // a uncompressed binary protocol buffer. -StatusOr JsonToPprofProfile(std::string json); +absl::StatusOr JsonToPprofProfile(std::string json); // The reverse, useful for testing. -StatusOr PprofProfileToJson(pybind11::bytes binary_proto); +absl::StatusOr PprofProfileToJson(nanobind::bytes binary_proto); } // namespace xla diff --git a/xla/python/profiler.cc b/xla/python/profiler.cc index 52b611bca6239..31f76a44ae09e 100644 --- a/xla/python/profiler.cc +++ b/xla/python/profiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,38 +18,88 @@ limitations under the License. #include #include #include +#include #include -#include -#include "absl/strings/string_view.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "absl/strings/str_cat.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep #include "xla/backends/profiler/plugin/plugin_tracer.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" -#include "xla/python/exceptions.h" -#include "xla/python/profiler/internal/traceme_wrapper.h" -#include "xla/python/status_casters.h" -#include "xla/python/types.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/xplane_to_profile_instructions.h" -#include "xla/status.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/profiler_session.h" +#include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/rpc/client/capture_profile.h" #include "tsl/profiler/rpc/profiler_server.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; namespace { -// Adds a trivial forwarding class so these Python bindings and TensorFlow's -// bindings of the same thing don't register the same class with pybind11. -class TraceMeWrapper : public xla::profiler::TraceMeWrapper { + +// Wraps TraceMe with an interface that takes python types. +class TraceMeWrapper { public: - using xla::profiler::TraceMeWrapper::TraceMeWrapper; + // nb::str and nb::kwargs are taken by const reference to avoid + // python reference-counting overhead. + TraceMeWrapper(const nb::str& name, const nb::kwargs& kwargs) + : traceme_( + [&]() { + std::string name_and_metadata = nb::cast(name); + if (kwargs.size() > 0) { + AppendMetadata(&name_and_metadata, kwargs); + } + return name_and_metadata; + }, + /*level=*/1) {} + + // nb::kwargs is taken by const reference to avoid python + // reference-counting overhead. + void SetMetadata(const nb::kwargs& kwargs) { + if (TF_PREDICT_FALSE(kwargs.size() > 0)) { + traceme_.AppendMetadata([&]() { + std::string metadata; + AppendMetadata(&metadata, kwargs); + return metadata; + }); + } + } + + void Stop() { traceme_.Stop(); } + + static bool IsEnabled() { return tsl::profiler::TraceMe::Active(); } + + private: + // Converts kwargs to strings and appends them to name encoded as TraceMe + // metadata. + static void AppendMetadata(std::string* name, const nb::kwargs& kwargs) { + name->push_back('#'); + for (const auto& kv : kwargs) { + absl::StrAppend(name, nb::cast(kv.first), "=", + EncodePyObject(kv.second), ","); + } + name->back() = '#'; + } + + static std::string EncodePyObject(nb::handle handle) { + if (nb::isinstance(handle)) { + return nb::cast(handle) ? "1" : "0"; + } + return nb::cast(nb::str(handle)); + } + + tsl::profiler::TraceMe traceme_; }; tensorflow::ProfileOptions DefaultPythonProfileOptions() { @@ -60,10 +110,10 @@ tensorflow::ProfileOptions DefaultPythonProfileOptions() { } const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) { - const PJRT_Structure_Base* next = - reinterpret_cast(pjrt_api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(pjrt_api->extension_start); while (next != nullptr && - next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) { + next->type != PJRT_Extension_Type::PJRT_Extension_Type_Profiler) { next = next->next; } if (next == nullptr) { @@ -73,12 +123,40 @@ const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) { } } // namespace -void BuildProfilerSubmodule(py::module* m) { - py::module profiler = - m->def_submodule("profiler", "TensorFlow profiler integration"); - py::class_> - profiler_server_class(profiler, "ProfilerServer"); +// nanobind requires in-place construction of types, but tsl::ProfilerSession +// can only be created by its factory function. No matter, we just box it +// ourselves. +struct ProfilerSessionWrapper { + explicit ProfilerSessionWrapper(std::unique_ptr session) + : session(std::move(session)) {} + + std::unique_ptr session; +}; + +static std::string GetFdoProfile(const std::string& xspace, + bool as_textproto = false) { + tensorflow::profiler::XSpace xspace_proto; + // TODO(phawkins): change to std::string_view when protobuf is + // updated in XLA. + xspace_proto.ParseFromString(std::string(xspace.c_str(), xspace.size())); + tensorflow::profiler::ProfiledInstructionsProto fdo_profile; + xla::ThrowIfError(xla::ConvertXplaneToProfiledInstructionsProto( + {xspace_proto}, &fdo_profile)); + if (as_textproto) { + std::string textproto; + if (tsl::protobuf::TextFormat::PrintToString(fdo_profile, &textproto)) { + return textproto; + } + throw xla::XlaRuntimeError("Unable to serialize format to textproto"); + } + return fdo_profile.SerializeAsString(); +} + +void BuildProfilerSubmodule(nb::module_& m) { + nb::module_ profiler = + m.def_submodule("profiler", "TensorFlow profiler integration"); + nb::class_ profiler_server_class( + profiler, "ProfilerServer"); profiler.def( "start_server", [](int port) -> std::unique_ptr { @@ -86,14 +164,14 @@ void BuildProfilerSubmodule(py::module* m) { server->StartProfilerServer(port); return server; }, - py::arg("port")); - profiler.def("register_plugin_profiler", [](py::capsule c_api) -> void { - if (absl::string_view(c_api.name()) != "pjrt_c_api") { + nb::arg("port")); + profiler.def("register_plugin_profiler", [](nb::capsule c_api) -> void { + if (std::string_view(c_api.name()) != "pjrt_c_api") { throw xla::XlaRuntimeError( "Argument to register_plugin_profiler was not a pjrt_c_api capsule."); } const PLUGIN_Profiler_Api* profiler_api = - FindProfilerApi(static_cast(c_api)); + FindProfilerApi(static_cast(c_api.data())); std::function( const tensorflow::ProfileOptions&)> create_func = [profiler_api = profiler_api]( @@ -104,101 +182,120 @@ void BuildProfilerSubmodule(py::module* m) { tsl::profiler::RegisterProfilerFactory(std::move(create_func)); }); - py::class_ profiler_session_class(profiler, - "ProfilerSession"); + nb::class_ profiler_session_class(profiler, + "ProfilerSession"); profiler_session_class - .def(py::init([]() { - return tsl::ProfilerSession::Create(DefaultPythonProfileOptions()); - })) - .def(py::init([](const tensorflow::ProfileOptions& options) { - return tsl::ProfilerSession::Create(options); - })) + .def("__init__", + [](ProfilerSessionWrapper* wrapper) { + new (wrapper) ProfilerSessionWrapper( + tsl::ProfilerSession::Create(DefaultPythonProfileOptions())); + }) + .def("__init__", + [](ProfilerSessionWrapper* wrapper, + const tensorflow::ProfileOptions& options) { + new (wrapper) + ProfilerSessionWrapper(tsl::ProfilerSession::Create(options)); + }) .def("stop_and_export", - [](tsl::ProfilerSession* sess, + [](ProfilerSessionWrapper* sess, const std::string& tensorboard_dir) -> void { tensorflow::profiler::XSpace xspace; // Disables the ProfilerSession - xla::ThrowIfError(sess->CollectData(&xspace)); + xla::ThrowIfError(sess->session->CollectData(&xspace)); xla::ThrowIfError(tsl::profiler::ExportToTensorBoard( xspace, tensorboard_dir, /* also_export_trace_json= */ true)); }) .def("stop", - [](tsl::ProfilerSession* sess) -> pybind11::bytes { + [](ProfilerSessionWrapper* sess) -> nb::bytes { tensorflow::profiler::XSpace xspace; // Disables the ProfilerSession - xla::ThrowIfError(sess->CollectData(&xspace)); - return xspace.SerializeAsString(); + xla::ThrowIfError(sess->session->CollectData(&xspace)); + std::string xspace_str = xspace.SerializeAsString(); + return nb::bytes(xspace_str.data(), xspace_str.size()); }) .def("export", - [](tsl::ProfilerSession* sess, const std::string& xspace, + [](ProfilerSessionWrapper* sess, nb::bytes xspace, const std::string& tensorboard_dir) -> void { tensorflow::profiler::XSpace xspace_proto; - xspace_proto.ParseFromString(xspace); + // TODO(phawkins): change to std::string_view when protobuf is + // updated in XLA. + xspace_proto.ParseFromString( + std::string(xspace.c_str(), xspace.size())); xla::ThrowIfError(tsl::profiler::ExportToTensorBoard( xspace_proto, tensorboard_dir, /* also_export_trace_json= */ true)); }); - py::class_ profile_options_class( + nb::class_ profile_options_class( profiler, "ProfileOptions"); - profile_options_class.def(py::init(&DefaultPythonProfileOptions)) - .def_property("include_dataset_ops", - &tensorflow::ProfileOptions::include_dataset_ops, - &tensorflow::ProfileOptions::set_include_dataset_ops) - .def_property("host_tracer_level", - &tensorflow::ProfileOptions::host_tracer_level, - &tensorflow::ProfileOptions::set_host_tracer_level) - .def_property("python_tracer_level", - &tensorflow::ProfileOptions::python_tracer_level, - &tensorflow::ProfileOptions::set_python_tracer_level) - .def_property("enable_hlo_proto", - &tensorflow::ProfileOptions::enable_hlo_proto, - &tensorflow::ProfileOptions::set_enable_hlo_proto) - .def_property("start_timestamp_ns", - &tensorflow::ProfileOptions::start_timestamp_ns, - &tensorflow::ProfileOptions::set_start_timestamp_ns) - .def_property("duration_ms", &tensorflow::ProfileOptions::duration_ms, - &tensorflow::ProfileOptions::set_duration_ms) - .def_property( + profile_options_class + .def("__init__", + [](tensorflow::ProfileOptions* options) { + new (options) + tensorflow::ProfileOptions(DefaultPythonProfileOptions()); + }) + .def_prop_rw("include_dataset_ops", + &tensorflow::ProfileOptions::include_dataset_ops, + &tensorflow::ProfileOptions::set_include_dataset_ops) + .def_prop_rw("host_tracer_level", + &tensorflow::ProfileOptions::host_tracer_level, + &tensorflow::ProfileOptions::set_host_tracer_level) + .def_prop_rw("python_tracer_level", + &tensorflow::ProfileOptions::python_tracer_level, + &tensorflow::ProfileOptions::set_python_tracer_level) + .def_prop_rw("enable_hlo_proto", + &tensorflow::ProfileOptions::enable_hlo_proto, + &tensorflow::ProfileOptions::set_enable_hlo_proto) + .def_prop_rw("start_timestamp_ns", + &tensorflow::ProfileOptions::start_timestamp_ns, + &tensorflow::ProfileOptions::set_start_timestamp_ns) + .def_prop_rw("duration_ms", &tensorflow::ProfileOptions::duration_ms, + &tensorflow::ProfileOptions::set_duration_ms) + .def_prop_rw( "repository_path", &tensorflow::ProfileOptions::repository_path, [](tensorflow::ProfileOptions* options, const std::string& path) { options->set_repository_path(path); }); - py::class_ traceme_class(profiler, "TraceMe", - py::module_local()); - traceme_class.def(py::init()) - .def("__enter__", [](py::object self) -> py::object { return self; }) - .def("__exit__", - [](py::object self, const py::object& ex_type, - const py::object& ex_value, - const py::object& traceback) -> py::object { - py::cast(self)->Stop(); - return py::none(); - }) + nb::class_ traceme_class(profiler, "TraceMe"); + traceme_class.def(nb::init()) + .def("__enter__", [](nb::object self) -> nb::object { return self; }) + .def( + "__exit__", + [](nb::object self, const nb::object& ex_type, + const nb::object& ex_value, + const nb::object& traceback) -> nb::object { + nb::cast(self)->Stop(); + return nb::none(); + }, + nb::arg("ex_type").none(), nb::arg("ex_value").none(), + nb::arg("traceback").none()) .def("set_metadata", &TraceMeWrapper::SetMetadata) .def_static("is_enabled", &TraceMeWrapper::IsEnabled); profiler.def( "get_profiled_instructions_proto", - [](py::str tensorboard_dir) -> pybind11::bytes { + [](std::string tensorboard_dir) -> nb::bytes { tensorflow::profiler::ProfiledInstructionsProto profile_proto; xla::ThrowIfError( xla::ConvertXplaneUnderLogdirToProfiledInstructionsProto( tensorboard_dir, &profile_proto)); - return profile_proto.SerializeAsString(); + std::string profile_proto_str = profile_proto.SerializeAsString(); + return nb::bytes(profile_proto_str.data(), profile_proto_str.size()); }, - py::arg("tensorboard_dir")); + nb::arg("tensorboard_dir")); - profiler.def( - "get_fdo_profile", [](const std::string& xspace) -> pybind11::bytes { - tensorflow::profiler::XSpace xspace_proto; - xspace_proto.ParseFromString(xspace); - tensorflow::profiler::ProfiledInstructionsProto fdo_profile; - xla::ThrowIfError(xla::ConvertXplaneToProfiledInstructionsProto( - {xspace_proto}, &fdo_profile)); - return fdo_profile.SerializeAsString(); - }); + profiler.def("get_fdo_profile", + [](nb::bytes xspace, bool as_textproto = false) -> nb::object { + std::string out = GetFdoProfile( + std::string(xspace.c_str(), xspace.size()), as_textproto); + return nb::bytes(out.data(), out.size()); + }); + + profiler.def("get_fdo_profile", [](nb::bytes xspace) -> nb::object { + std::string out = GetFdoProfile(std::string(xspace.c_str(), xspace.size())); + return nb::bytes(out.data(), out.size()); + }); } } // namespace xla diff --git a/xla/python/profiler.h b/xla/python/profiler.h index 1b06ce61fbf59..f977589a32518 100644 --- a/xla/python/profiler.h +++ b/xla/python/profiler.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_PROFILER_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildProfilerSubmodule(pybind11::module* m); +void BuildProfilerSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/profiler/internal/BUILD b/xla/python/profiler/internal/BUILD index 1262c593b19f9..f030939e857df 100644 --- a/xla/python/profiler/internal/BUILD +++ b/xla/python/profiler/internal/BUILD @@ -1,6 +1,7 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("@tsl//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -15,10 +16,10 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. - visibility = [ - "//tensorflow/python/profiler/internal:__subpackages__", + visibility = internal_visibility([ "//xla/backends/profiler:__subpackages__", - ], + "//tensorflow/python/profiler/internal:__subpackages__", + ]), deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -41,10 +42,10 @@ cc_library( name = "traceme_wrapper", hdrs = ["traceme_wrapper.h"], copts = tf_profiler_copts(), - visibility = [ - "//tensorflow/python/profiler/internal:__pkg__", + visibility = internal_visibility([ "//xla/python:__pkg__", - ], + "//tensorflow/python/profiler/internal:__pkg__", + ]), deps = [ "@com_google_absl//absl/strings", "@pybind11", diff --git a/xla/python/profiler/internal/python_hooks.cc b/xla/python/profiler/internal/python_hooks.cc index 8e0d7a3e15f8a..6d2ff7318c273 100644 --- a/xla/python/profiler/internal/python_hooks.cc +++ b/xla/python/profiler/internal/python_hooks.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/profiler/internal/python_hooks.h b/xla/python/profiler/internal/python_hooks.h index b81f49356fd97..165d135a88750 100644 --- a/xla/python/profiler/internal/python_hooks.h +++ b/xla/python/profiler/internal/python_hooks.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/profiler/internal/traceme_wrapper.h b/xla/python/profiler/internal/traceme_wrapper.h index 5e2e34d479e53..63a595fd2040e 100644 --- a/xla/python/profiler/internal/traceme_wrapper.h +++ b/xla/python/profiler/internal/traceme_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index 8986cae185fcb..d511952f25bc2 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,10 @@ limitations under the License. #include "xla/python/py_array.h" +#include + +#include +#include #include #include #include @@ -26,38 +30,112 @@ limitations under the License. #include #include +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" -#include "xla/python/py_buffer.h" +#include "xla/python/py_client.h" +#include "xla/python/py_device.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/python_utils.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" +#include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" +#include "xla/python/types.h" #include "xla/python/util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#if GOOGLE_CUDA +#include "xla/stream_executor/cuda/cuda_driver.h" +#endif #include "xla/util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -namespace py = pybind11; +namespace nb = nanobind; + +PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) { + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, + std::optional& scratch) { + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + scratch = std::move(shape); + } + return &scratch.value(); +} tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( - py::object dtype, absl::Span shape, + nb_dtype dtype, absl::Span shape, absl::Span py_arrays) { if (py_arrays.empty()) { // TODO(hyeontaek): Return a Status. - throw py::value_error("At least one array must be provided."); + throw nb::value_error("At least one array must be provided."); } std::vector> ifrt_arrays; ifrt_arrays.reserve(py_arrays.size()); @@ -82,19 +160,21 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( if (canonical_first_memory_kind != ifrt::CanonicalizeMemoryKind( ifrt_arrays.back()->sharding().memory_kind(), devices.back())) { - throw py::value_error(absl::StrFormat( - "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_arrays.back()->sharding().memory_kind().DebugString())); + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between PjRtBuffers. Got one buffer with " + "memory kind '%s' and another with memory_kind '%s'", + first_memory_kind.DebugString(), + ifrt_arrays.back()->sharding().memory_kind().DebugString()) + .c_str()); } } ifrt::Client* client = ifrt_arrays.front()->client(); - auto ifrt_dtype = ToIfRtDType(dtype); + auto ifrt_dtype = DtypeToIfRtDType(dtype); if (!ifrt_dtype.ok()) { // TODO(hyeontaek): Return a Status. - throw py::value_error(ifrt_dtype.status().ToString()); + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); } auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( ifrt::Shape(shape), @@ -105,27 +185,29 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput); if (!ifrt_array.ok()) { // TODO(hyeontaek): Return a Status. - throw py::value_error(ifrt_array.status().ToString()); + throw nb::value_error(ifrt_array.status().ToString().c_str()); } return *std::move(ifrt_array); } // Creates an IFRT `MemoryKind` from a JAX `Sharding`. -ifrt::MemoryKind CreateIfRtMemoryKindFromSharding(const py::object& sharding) { - py::object py_memory_kind = py::none(); +ifrt::MemoryKind CreateIfRtMemoryKindFromSharding(const nb::object& sharding) { + nb::object py_memory_kind = nb::none(); // sharding.attr("memory_kind") can crash if sharding was originally created // from C++ and casted into a Python Sharding object. Thus, we cast sharding // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python // attribute access. - auto type = sharding.get_type(); + nb::handle type = sharding.type(); if (type.is(jax::NamedSharding::type())) { - py_memory_kind = py::cast(sharding).memory_kind(); + py_memory_kind = + nb::cast(sharding)->memory_kind(); } else if (type.is(jax::GSPMDSharding::type())) { - py_memory_kind = py::cast(sharding).memory_kind(); + py_memory_kind = + nb::cast(sharding)->memory_kind(); } else if (type.is(jax::SingleDeviceSharding::type())) { py_memory_kind = - py::cast(sharding).memory_kind(); + nb::cast(sharding)->memory_kind(); } else { py_memory_kind = sharding.attr("memory_kind"); } @@ -133,12 +215,15 @@ ifrt::MemoryKind CreateIfRtMemoryKindFromSharding(const py::object& sharding) { if (py_memory_kind.is_none()) { return ifrt::MemoryKind(); } - return ifrt::MemoryKind(py::cast(py_memory_kind)); + return ifrt::MemoryKind(nb::cast(py_memory_kind)); } struct PyArrayObject { PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030B0000 alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; }; static_assert(std::is_standard_layout::value); @@ -158,14 +243,15 @@ extern "C" void PyArray_tp_dealloc(PyObject* self) { PyTypeObject* tp = Py_TYPE(self); auto* obj = reinterpret_cast(self); - if (obj->weakrefs) { - PyObject_ClearWeakRefs(self); - } - GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_CLEAR(dict); +#else + _PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 tp->tp_free(self); Py_DECREF(tp); @@ -174,40 +260,26 @@ extern "C" void PyArray_tp_dealloc(PyObject* self) { // dynamic_attr: Allow the garbage collector to traverse the internal instance // `__dict__`. extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_VISIT(dict); -// https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse -#if PY_VERSION_HEX >= 0x03090000 +#else + _PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse Py_VISIT(Py_TYPE(self)); -#endif return 0; } // dynamic_attr: Allow the GC to clear the dictionary. extern "C" int PyArray_tp_clear(PyObject* self) { +#if PY_VERSION_HEX < 0x030C0000 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_CLEAR(dict); - return 0; -} - -// Give instances of this type a `__dict__` and opt into garbage collection. -void EnableDynamicAttribute(PyHeapTypeObject* heap_type) { - auto* type = &heap_type->ht_type; - type->tp_flags |= Py_TPFLAGS_HAVE_GC; -#if PY_VERSION_HEX < 0x030B0000 - type->tp_dictoffset = type->tp_basicsize; // place dict at the end - type->tp_basicsize += - (ssize_t)sizeof(PyObject*); // and allocate enough space for it #else - type->tp_flags |= Py_TPFLAGS_MANAGED_DICT; -#endif - type->tp_traverse = PyArray_tp_traverse; - type->tp_clear = PyArray_tp_clear; - - static PyGetSetDef getset[] = {{"__dict__", PyObject_GenericGetDict, - PyObject_GenericSetDict, nullptr, nullptr}, - {nullptr, nullptr, nullptr, nullptr, nullptr}}; - type->tp_getset = getset; + _PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; } template @@ -232,33 +304,35 @@ struct ShapedArrayCacheKey { }; // Constructing ShapedArrays has gotten slow. Cache it. -py::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { using CacheT = - LRUCache>>; + LRUCache>>; static auto* lru_list = new CacheT::LRUList(4096); static auto* cache = new CacheT(lru_list); - static const py::handle* shaped_array = nullptr; - if (shaped_array == nullptr) { - auto* jax_core = PyImport_ImportModule("jax.core"); - if (jax_core != nullptr) { - shaped_array = new py::handle( - py::reinterpret_steal(jax_core).attr("ShapedArray")); - } else { - PyErr_Clear(); - return py::none(); + static const nb::object* shaped_array = []() -> nb::object* { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return nullptr; } + return new nb::object(jax_core.attr("ShapedArray")); + }(); + if (!shaped_array) { + return nb::none(); } auto value = cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { - return std::make_shared>(); + return std::make_shared>(); }); if (!value->has_value()) { - auto dtype = IfrtDtypeToDtype(key.dtype).value(); - py::object aval = (*shaped_array)( - SpanToTuple(absl::Span(key.dims)), dtype, key.weak_type); + nb_dtype dtype = IfrtDtypeToNbDtype(key.dtype).value(); + nb::object aval = + (*shaped_array)(SpanToNbTuple(absl::Span(key.dims)), + dtype, key.weak_type); *value = aval; return aval; } @@ -267,15 +341,15 @@ py::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { } // namespace -PyArray_Storage::PyArray_Storage(pybind11::object aval, bool weak_type, - pybind11::dtype dtype, +PyArray_Storage::PyArray_Storage(nb::object aval, bool weak_type, + xla::nb_dtype dtype, std::vector shape, - pybind11::object sharding, bool committed, - std::shared_ptr py_client, - std::shared_ptr traceback, - tsl::RCReference ifrt_array) - : fastpath_enabled(true), - aval(std::move(aval)), + nb::object sharding, bool committed, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, + xla::PjRtFuture result_status) + : aval(std::move(aval)), weak_type(weak_type), dtype(std::move(dtype)), shape(std::move(shape)), @@ -283,7 +357,8 @@ PyArray_Storage::PyArray_Storage(pybind11::object aval, bool weak_type, committed(committed), py_client(std::move(py_client)), traceback(std::move(traceback)), - ifrt_array(std::move(ifrt_array)) { + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { next = this->py_client->arrays_; this->py_client->arrays_ = this; if (next) { @@ -292,36 +367,28 @@ PyArray_Storage::PyArray_Storage(pybind11::object aval, bool weak_type, prev = nullptr; } -PyArray_Storage::PyArray_Storage(DisableFastpath) : fastpath_enabled(false) {} - -void PyArray::PyInit(py::object self, py::object aval, py::object sharding, +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, absl::Span py_arrays, bool committed, bool skip_checks) { - auto dtype = aval.attr("dtype"); - auto shape = pybind11::cast>(aval.attr("shape")); + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays(dtype, shape, py_arrays); Construct(reinterpret_cast(self.ptr()), aval, - pybind11::cast(aval.attr("weak_type")), std::move(dtype), + nb::cast(aval.attr("weak_type")), std::move(dtype), std::move(shape), std::move(sharding), committed, py_arrays.at(0).py_client(), Traceback::Get(), - std::move(ifrt_array)); - - PyArray py_array = self; + std::move(ifrt_array), xla::PjRtFuture()); if (!skip_checks) { - py_array.CheckAndRearrange(); + self.CheckAndRearrange(); } } -void PyArray::PyInit(py::object self, DisableFastpath) { - Construct(reinterpret_cast(self.ptr()), - PyArray_Storage::DisableFastpath()); -} - PyArray PyArray::MakeFromSingleDeviceArray( - std::shared_ptr py_client, std::shared_ptr traceback, - tsl::RCReference ifrt_array, bool weak_type, bool committed) { + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture result_status) { if (!llvm::isa(ifrt_array->sharding())) { throw XlaRuntimeError( InvalidArgument("Constructing single device jax.Array from non-single " @@ -333,23 +400,24 @@ PyArray PyArray::MakeFromSingleDeviceArray( key.dtype = ifrt_array->dtype(); key.weak_type = weak_type; auto aval = MakeShapedArrayCached(key); - auto dtype = IfrtDtypeToDtype(key.dtype).value(); + auto dtype = IfrtDtypeToNbDtype(key.dtype).value(); const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); - auto py_memory_kind = + nb::object py_memory_kind = (jax::GetEnableMemories() && memory_kind.memory_kind().has_value()) - ? py::object(py::str(*memory_kind.memory_kind())) - : py::none(); - auto sharding = py::cast(std::make_unique( - py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind))); + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), std::move(sharding), std::move(py_client), std::move(traceback), std::move(ifrt_array), committed, - /*skip_checks=*/true); + /*skip_checks=*/true, std::move(result_status)); } PyArray PyArray::MakeFromIfrtArrayAndSharding( - std::shared_ptr py_client, std::shared_ptr traceback, - tsl::RCReference ifrt_array, py::object sharding, + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nb::object sharding, bool weak_type, bool committed, bool skip_checks) { auto shape_span = ifrt_array->shape().dims(); ShapedArrayCacheKey key; @@ -357,54 +425,58 @@ PyArray PyArray::MakeFromIfrtArrayAndSharding( key.dtype = ifrt_array->dtype(); key.weak_type = weak_type; auto aval = MakeShapedArrayCached(key); - auto dtype = IfrtDtypeToDtype(key.dtype).value(); + auto dtype = IfrtDtypeToNbDtype(key.dtype).value(); return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), std::move(sharding), std::move(py_client), std::move(traceback), std::move(ifrt_array), committed, skip_checks); } -PyArrayResultHandler::PyArrayResultHandler(py::object aval, py::object sharding, +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, bool committed, bool skip_checks) : aval_(std::move(aval)), sharding_(std::move(sharding)), committed_(committed), skip_checks_(skip_checks) { - weak_type_ = pybind11::cast(aval_.attr("weak_type")); - dtype_ = aval_.attr("dtype"); - shape_ = pybind11::cast>(aval_.attr("shape")); + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); } PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { - return Call(py_arrays.at(0).py_client(), - CreateIfRtArrayFromSingleDeviceShardedPyArrays(dtype_, shape_, - py_arrays)); + return Call( + py_arrays.at(0).py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays(dtype_, shape_, py_arrays), + xla::PjRtFuture()); } PyArray PyArrayResultHandler::Call( - std::shared_ptr py_client, - tsl::RCReference ifrt_array) const { + nb_class_ptr py_client, tsl::RCReference ifrt_array, + xla::PjRtFuture result_status) const { return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, std::move(py_client), Traceback::Get(), std::move(ifrt_array), - committed_, skip_checks_); + committed_, skip_checks_, std::move(result_status)); } PyArray PyArrayResultHandler::Call(PyArray py_array) const { - return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array())); + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture()); } -PyArray::PyArray(py::object aval, bool weak_type, py::dtype dtype, - std::vector shape, py::object sharding, - std::shared_ptr py_client, - std::shared_ptr traceback, +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, tsl::RCReference ifrt_array, bool committed, - bool skip_checks) { + bool skip_checks, + xla::PjRtFuture result_status) { auto* self = PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); - ptr() = self; + m_ptr = self; Construct(reinterpret_cast(self), std::move(aval), weak_type, std::move(dtype), std::move(shape), std::move(sharding), committed, - std::move(py_client), std::move(traceback), std::move(ifrt_array)); + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); if (!skip_checks) { CheckAndRearrange(); @@ -432,53 +504,54 @@ const std::vector& PyArray::py_arrays_cached() { auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( ifrt::ArrayCopySemantics::kReuseInput); if (!ifrt_arrays.ok()) { - throw py::value_error( + throw nb::value_error( absl::StrCat("Failed to disassemble into single-device arrays: ", - ifrt_arrays.status().ToString())); + ifrt_arrays.status().ToString()) + .c_str()); } py_arrays.reserve(ifrt_arrays->size()); for (auto& ifrt_array : *ifrt_arrays) { py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( py_client(), traceback(), std::move(ifrt_array), weak_type(), - committed())); + committed(), result_status())); } } return py_arrays; } -py::object PyArray::arrays() { +nb::object PyArray::arrays() { // For performance, we only keep pjrt buffers by default. But on python side // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" // should return the same PyArrays (to avoid duplicate device to host // transfers). So we create PyArrays the first time it is called and reuse // them later. - if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return py::none(); + if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); if (llvm::isa(&ifrt_array()->sharding())) { std::vector py_arrays; py_arrays.push_back(*this); - return py::cast(py_arrays); + return nb::cast(py_arrays); } - return py::cast(py_arrays_cached()); + return nb::cast(py_arrays_cached()); } -Status PyArray::set_arrays(py::object obj) { +Status PyArray::set_arrays(nb::object obj) { if (obj.is_none()) { SetIfrtArray(tsl::RCReference()); py_arrays().clear(); return OkStatus(); } - if (!py::isinstance(obj)) { + if (!nb::isinstance(obj)) { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - py::cast(py::str(obj.get_type()))); + nb::cast(nb::str(obj.type()))); } - py::list list = obj; + nb::list list(obj); - if (list.empty()) return OkStatus(); + if (list.size() == 0) return OkStatus(); SetIfrtArray(tsl::RCReference()); py_arrays().clear(); @@ -488,10 +561,10 @@ Status PyArray::set_arrays(py::object obj) { devices.reserve(list.size()); std::vector shapes; shapes.reserve(list.size()); - for (py::handle obj : list) { - if (obj.get_type().is(PyArray::type())) { - auto py_array = py::reinterpret_borrow(obj); - if (py_array.py_client() != py_client()) { + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { return InvalidArgument("Client mismatch when assigning to _arrays."); } if (py_array.num_shards() != 1) { @@ -503,7 +576,7 @@ Status PyArray::set_arrays(py::object obj) { shapes.push_back(ifrt_arrays.back()->shape()); } else { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - py::cast(py::str(obj.get_type()))); + nb::cast(nb::str(obj.type()))); } } const ifrt::MemoryKind first_memory_kind = @@ -518,11 +591,14 @@ Status PyArray::set_arrays(py::object obj) { ifrt::CanonicalizeMemoryKind( ifrt_array->sharding().memory_kind(), ifrt_array->sharding().devices().front())) { - throw py::value_error(absl::StrFormat( - "Memory kind mismatch between single-device arrays. Got one array " - "with memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_array->sharding().memory_kind().DebugString())); + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array " + "with memory kind '%s' and another with memory_kind '%s'", + first_memory_kind.DebugString(), + ifrt_array->sharding().memory_kind().DebugString()) + .c_str()); } } @@ -550,16 +626,17 @@ StatusOr PyArray::FullyReplicatedShard() { ifrt::ArrayCopySemantics::kReuseInput)); return MakeFromSingleDeviceArray(py_client(), traceback(), std::move(fully_replicated_ifrt_shard), - weak_type(), committed()); + weak_type(), committed(), result_status()); } Status PyArray::BlockUntilReady() const { - pybind11::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; if (ifrt_array() == nullptr) { return InvalidArgument( "BlockHostUntilReady() called on deleted or donated buffer"); } - return AwaitBuffersReady(ifrt_array()); + ifrt::Array* ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); } StatusOr PyArray::GetOnDeviceSizeInBytes() { @@ -568,10 +645,9 @@ StatusOr PyArray::GetOnDeviceSizeInBytes() { "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); } - TF_ASSIGN_OR_RETURN( - size_t shard_size, - IfrtHelpers::pjrt_buffer(ifrt_array())->GetOnDeviceSizeInBytes()); - return shard_size * py::len(sharding().attr("device_set")); + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); } StatusOr PyArray::FetchSingleShard(std::string_view api) { @@ -590,20 +666,35 @@ StatusOr PyArray::FetchSingleShard(std::string_view api) { return py_arrays[0]; } -StatusOr PyArray::SingleDeviceArrayToNumpyArray() { +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). + nb::gil_scoped_release release_gil; + return result_status.Await(); + } + return result_status.Await(); +} + +StatusOr PyArray::SingleDeviceArrayToNumpyArray() { TF_ASSIGN_OR_RETURN(auto arr, FetchSingleShard("SingleDeviceArrayToNumpyArray")); - return PyHostValue::AsNumPyArray(arr.GetStorage().host_value, - arr.GetStorage().dynamic_shape, - arr.ifrt_array(), arr); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; } Status PyArray::CopySingleDeviceArrayToHostAsync() { TF_ASSIGN_OR_RETURN(auto arr, FetchSingleShard("CopySingleDeviceArrayToHostAsync")); - return PyHostValue::CopyToHostAsync(arr.GetStorage().host_value, - arr.GetStorage().dynamic_shape, - arr.ifrt_array()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); } StatusOr PyArray::AssertUnsharded(std::string_view api) { @@ -626,14 +717,197 @@ StatusOr PyArray::UnsafeBufferPointer() { TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); return py_client()->pjrt_client()->UnsafeBufferPointer( - IfrtHelpers::pjrt_buffer(arr.ifrt_array())); + GetPjrtBuffer(arr.ifrt_array())); } -StatusOr PyArray::CudaArrayInterface() { - TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); - return IfrtHelpers::CudaArrayInterface(arr.ifrt_array(), - arr.GetStorage().dynamic_shape); + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +StatusOr CudaArrayInterfaceToBuffer(const nb::dict& cai, + nb_class_ptr client) { +#ifndef GOOGLE_CUDA + throw XlaRuntimeError("This operation requires CUDA support."); +#else + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + // cannot determine device_id/stream when device pointer is NULL. + int device_id = + (data_value == 0 + ? 0 + : stream_executor::gpu::CreatedContexts::GetDeviceOrdinal(data_ptr)); + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, device->device(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +#endif // GOOGLE_CUDA } Status PyArray::Delete() { @@ -665,19 +939,16 @@ PyArray PyArray::Clone() const { return PyArray(aval(), weak_type(), dtype(), std::vector(shape().begin(), shape().end()), sharding(), py_client(), traceback(), std::move(out), - committed(), /*skip_checks=*/true); + committed(), /*skip_checks=*/true, result_status()); } -py::handle PyArray::Storage::AsHandle() { +nb::handle PyArray::Storage::AsHandle() { return reinterpret_cast(reinterpret_cast(this) - offsetof(PyArrayObject, array_storage)); } PyArray::Storage::~PyArray_Storage() { CHECK(PyGILState_Check()); - if (!fastpath_enabled) { - return; - } if (py_client->arrays_ == this) { py_client->arrays_ = next; } @@ -689,8 +960,8 @@ PyArray::Storage::~PyArray_Storage() { } } -StatusOr PyArray::CopyToDeviceWithSharding( - ifrt::DeviceList devices, pybind11::object dst_sharding) { +StatusOr PyArray::CopyToDeviceWithSharding(ifrt::DeviceList devices, + nb::object dst_sharding) { auto* ifrt_array_ptr = ifrt_array(); ifrt::MemoryKind dst_memory_kind = CreateIfRtMemoryKindFromSharding(dst_sharding); @@ -704,14 +975,15 @@ StatusOr PyArray::CopyToDeviceWithSharding( { auto transfer_guard_formatter = [this, &dst_sharding] { return absl::StrCat( - "aval=", py::cast(py::repr(aval())), - ", sharding=", py::cast(py::repr(sharding())), - ", dst_sharding=", py::cast(py::repr(dst_sharding))); + "aval=", nb::cast(nb::repr(aval())), + ", sharding=", nb::cast(nb::repr(sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); }; TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; std::shared_ptr ifrt_sharding; // The sharding conversions are tried in the order of narrowness (e.g., // ShardingParamSharding is an IFRT-level sharding, whereas HloSharding is @@ -766,29 +1038,31 @@ StatusOr PyArray::CopyToDeviceWithSharding( return PyArray(aval(), weak_type(), dtype(), std::vector(shape_span.begin(), shape_span.end()), dst_sharding, py_client(), std::move(traceback), - std::move(out_array), committed(), /*skip_checks=*/true); + std::move(out_array), committed(), + /*skip_checks=*/true, result_status()); } StatusOr PyArray::BatchedDevicePut( - py::object aval, py::object sharding, std::vector xs, - std::vector> dst_devices, bool committed, + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, bool jax_enable_x64) { if (dst_devices.size() != xs.size() || xs.empty()) { - throw py::value_error( + throw nb::value_error( absl::StrCat("Argument sizes (xs and devices) must match %zu vs " "%zu and be nonzero", - dst_devices.size(), xs.size())); + dst_devices.size(), xs.size()) + .c_str()); } - for (ClientAndPtr& device : dst_devices) { - if (device.get_client() == nullptr) { + for (const PyDevice* device : dst_devices) { + if (device->client().get() == nullptr) { return InvalidArgument("Cannot copy to unattached devices."); } } auto transfer_guard_formatter = [&aval, &sharding] { return absl::StrCat( - "aval=", py::cast(py::repr(aval)), - ", dst_sharding=", py::cast(py::repr(sharding))); + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); }; GlobalPyRefManager()->CollectGarbage(); @@ -798,10 +1072,10 @@ StatusOr PyArray::BatchedDevicePut( DevicePutOptions options; options.squash_64bit_types = !jax_enable_x64; options.allow_zero_copy = - (!force_copy && - (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy)); + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); - py::list owning_pylist(dst_devices.size()); + nb::list owning_pylist; std::vector> ifrt_arrays; xla::ifrt::DeviceList::Devices devices; @@ -822,8 +1096,8 @@ StatusOr PyArray::BatchedDevicePut( } TF_ASSIGN_OR_RETURN( DevicePutResult on_device, - DevicePut(x, dst_devices[i].get_client()->ifrt_client(), - dst_devices[i].get(), options, dst_memory_kind)); + DevicePut(x, dst_devices[i]->client()->ifrt_client(), + dst_devices[i]->device(), options, dst_memory_kind)); ifrt_arrays.push_back(std::move(on_device.ifrt_array)); devices.push_back(ifrt_arrays.back()->sharding().devices().front()); shapes.push_back(ifrt_arrays.back()->shape()); @@ -833,9 +1107,12 @@ StatusOr PyArray::BatchedDevicePut( ++i; } - auto weak_type = pybind11::cast(aval.attr("weak_type")); + // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't + // consumed here. Look into this. + + auto weak_type = nb::cast(aval.attr("weak_type")); auto dtype = aval.attr("dtype"); - auto shape = pybind11::cast>(aval.attr("shape")); + auto shape = nb::cast>(aval.attr("shape")); TF_ASSIGN_OR_RETURN( auto ifrt_array, ifrt_arrays.front()->client()->AssembleArrayFromSingleDeviceArrays( @@ -848,17 +1125,44 @@ StatusOr PyArray::BatchedDevicePut( xla::ifrt::ArrayCopySemantics::kReuseInput)); return PyArray(aval, weak_type, dtype, std::move(shape), sharding, - dst_devices[0].client(), Traceback::Get(), ifrt_array, - committed, /*skip_checks=*/true); + dst_devices[0]->client(), Traceback::Get(), + std::move(ifrt_array), committed, /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. + // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); } -std::vector PyClient::LiveArrays() { - std::vector result; +std::vector PyClient::LiveArrays() const { + std::vector result; for (PyArray::Storage* array = arrays_; array; array = array->next) { bool all_deleted = (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); if (!all_deleted) { - result.push_back(py::reinterpret_borrow(array->AsHandle())); + result.push_back(nb::borrow(array->AsHandle())); } } return result; @@ -885,9 +1189,15 @@ struct ExtraBufferInfo { std::unique_ptr external_reference_hold; }; +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const Layout& layout) { + return LayoutUtil::IsMonotonicWithDim0Major(layout) && layout.tiles().empty(); +} + int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { - Status status = [&]() { - PyArray py_array = py::reinterpret_borrow(exporter); + Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); if (py_array.ifrt_array() == nullptr) { // TODO(phawkins): why is this happening? return InvalidArgument("Array is null"); @@ -912,7 +1222,7 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { "Python buffer protocol is only defined for buffers with a single " "shard."); } - if (!py_array.sharding().get_type().is(jax::SingleDeviceSharding::type())) { + if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) { return InvalidArgument( "Python buffer protocol is only defined for single-device sharded " "buffers."); @@ -932,7 +1242,7 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { // Py_buffer objects are POD C structures, so we don't need to hold the GIL. // Additionally we call BlockHostUntilReady() below, which may block. - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; if (buffer.IsTuple()) { return InvalidArgument( @@ -948,17 +1258,26 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { return InvalidArgument("Deleted buffer used in buffer protocol."); } + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = GetXlaLayoutUnsafe(buffer.layout()); + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || (flags & PyBUF_STRIDES) == PyBUF_ND) && - !LayoutUtil::IsMonotonicWithDim0Major(buffer.layout())) { + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { return InvalidArgument("Buffer is not in C-contiguous layout."); } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && - !LayoutUtil::IsMonotonicWithDim0Minor(buffer.layout())) { + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { return InvalidArgument("Buffer is not in F-contiguous layout."); } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && - !LayoutUtil::IsMonotonicWithDim0Major(buffer.layout()) && - !LayoutUtil::IsMonotonicWithDim0Minor(buffer.layout())) { + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in pinned + // memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); } std::memset(view, 0, sizeof(Py_buffer)); const void* root_ptr = @@ -980,8 +1299,8 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { view->shape = reinterpret_cast( const_cast(buffer.dimensions().data())); if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { - extra->strides = ByteStridesForShape( - buffer.element_type(), buffer.dimensions(), buffer.layout()); + extra->strides = ByteStridesForShape(buffer.element_type(), + buffer.dimensions(), xla_layout); view->strides = reinterpret_cast( const_cast(extra->strides.data())); } @@ -1008,127 +1327,273 @@ void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { delete extra; } -PyBufferProcs PyArray_tp_as_buffer = []() { - PyBufferProcs procs; - procs.bf_getbuffer = &PyArray_bf_getbuffer; - procs.bf_releasebuffer = &PyArray_bf_releasebuffer; - return procs; -}(); +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major_size() - 1 - i) { + return false; + } + } + } + return true; +} +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = buf->layout() == nullptr || + HasDefaultLayout(GetXlaLayoutUnsafe(buf->layout())); + // On CPU for non-int4 values, we can return the value in a zero-copy way. + // For int4 values, we must copy in order to unpack the array. + return buf->IsOnCpu() && !primitive_util::Is4BitType(buf->element_type()) && + has_default_layout; +} } // namespace -Status PyArray::SetUpType() { - static constexpr char kName[] = "ArrayImpl"; +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; - py::str name = py::str(kName); - py::str qualname = py::str(kName); +StatusOr PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for non-int4 values, we can return the value in a zero-copy way. + // For int4 values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + tsl::RCReference buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + TF_ASSIGN_OR_RETURN(hold->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + hold->buffer = tsl::FormRef(ifrt_array); + void* data = + hold->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::capsule hold_capsule(hold.release(), [](void* h) noexcept { + delete static_cast(h); + }); + nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + { + nb::gil_scoped_release gil; + TF_RETURN_IF_ERROR(ifrt_array->GetReadyFuture().Await()); + } + return array; + } + } - auto* heap_type = reinterpret_cast( - PyType_Type.tp_alloc(&PyType_Type, 0)); - // Caution: we must not call any functions that might invoke the GC until - // PyType_Ready() is called below. Otherwise the GC might see a - // half-constructed type object. - if (!heap_type) { - return Internal("Unable to create heap type object"); + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + TF_RETURN_IF_ERROR(ready_.Await()); + } else { + TF_RETURN_IF_ERROR(ready_.Await()); } - heap_type->ht_name = name.release().ptr(); - heap_type->ht_qualname = qualname.release().ptr(); - PyTypeObject* type = &heap_type->ht_type; - type->tp_name = kName; - type->tp_basicsize = sizeof(PyArrayObject); - type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE; - type->tp_new = PyArray_tp_new; - type->tp_dealloc = PyArray_tp_dealloc; + return value_; +} - // Supported protocols - type->tp_as_number = &heap_type->as_number; - type->tp_as_sequence = &heap_type->as_sequence; - type->tp_as_mapping = &heap_type->as_mapping; - type->tp_as_buffer = &PyArray_tp_as_buffer; +Status PyHostValue::CopyToHostAsync(std::optional& dynamic_shape_holder, + ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return OkStatus(); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=()", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), + ", device=", ifrt_array->sharding().devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } - // Allow dynamic attributes. - EnableDynamicAttribute(heap_type); + xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout + // of the host buffer literal. On the other hand, the runtime often knows + // better about an efficient layout for the host buffer. It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return OkStatus(); +} - // Allow weak references to DeviceArray objects. - type->tp_weaklistoffset = offsetof(PyArrayObject, weakrefs); +namespace { +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; - TF_RET_CHECK(PyType_Ready(type) == 0); +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; - PyArray::type_ = reinterpret_cast(type); +} // namespace - return OkStatus(); -} +Status PyArray::RegisterTypes(nb::module_& m) { + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; -Status PyArray::RegisterTypes(py::module& m) { - TF_RETURN_IF_ERROR(PyArray::SetUpType()); - auto type = py::reinterpret_borrow(type_); + type_ = PyType_FromSpec(&PyArray_spec); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); m.attr("ArrayImpl") = type; - type.attr("__init__") = py::cpp_function( - [](py::object self, py::object aval, py::object sharding, py::list arrays, + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, bool committed, bool skip_checks) { - if (arrays[0].get_type().is(PyArray::type())) { - auto py_arrays = py::cast>(arrays); + if (arrays[0].type().is(PyArray::type())) { + auto py_arrays = nb::cast>(arrays); PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, committed, skip_checks); } else { - throw py::type_error( - absl::StrCat("Unsupported type for elements in `arrays`: ", - std::string(py::str(arrays[0].get_type())))); + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); } }, - py::is_method(type), py::arg("aval"), py::arg("sharding"), - py::arg("arrays"), py::arg("committed"), py::arg("_skip_checks") = false); - // TODO(yashkatariya): remove this once the transition completes. - type.attr("_init_with_fastpath_disabled") = py::cpp_function( - [](py::object self) { - PyArray::PyInit(self, PyArray::DisableFastpath()); - }, - py::is_method(type)); - type.attr("delete") = - py::cpp_function([](PyArray& self) { xla::ThrowIfError(self.Delete()); }, - py::is_method(type)); - type.attr("_sharding") = jax::property_readonly(&PyArray::sharding); - type.attr("aval") = jax::property(&PyArray::aval, &PyArray::set_aval); + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); type.attr("_arrays") = - jax::property(&PyArray::arrays, [](PyArray& self, py::object obj) { + nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { xla::ThrowIfError(self.set_arrays(obj)); }); - type.attr("_fully_replicated_shard") = py::cpp_function( + type.attr("_fully_replicated_shard") = nb::cpp_function( [](PyArray self) { return xla::ValueOrThrow(self.FullyReplicatedShard()); }, - py::is_method(type)); + nb::is_method()); type.attr("_npy_value") = - jax::property(&PyArray::npy_value, &PyArray::set_npy_value); - type.attr("_committed") = jax::property_readonly(&PyArray::committed); - type.attr("unsafe_buffer_pointer") = py::cpp_function( + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( [](PyArray self) { return xla::ValueOrThrow(self.UnsafeBufferPointer()); }, - py::is_method(type)); - type.attr("__cuda_array_interface__") = - jax::property_readonly([](PyArray self) { - return xla::ValueOrThrow(self.CudaArrayInterface()); - }); - type.attr("on_device_size_in_bytes") = py::cpp_function( + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), - py::is_method(type)); - type.attr("_single_device_array_to_np_array") = py::cpp_function( + nb::is_method()); + type.attr("_single_device_array_to_np_array") = nb::cpp_function( xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArray), - py::is_method(type)); - type.attr("_copy_single_device_array_to_host_async") = py::cpp_function( + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( [](PyArray& self) { xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); }, - py::is_method(type)); - type.attr("block_until_ready") = py::cpp_function( - [](PyArray self) -> py::object { + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { xla::ThrowIfError(self.BlockUntilReady()); return self; }, - py::is_method(type)); - type.attr("platform") = py::cpp_function( + nb::is_method()); + type.attr("platform") = nb::cpp_function( [](PyArray self) { if (self.ifrt_array()->client()->platform_name() == "cuda" || self.ifrt_array()->client()->platform_name() == "rocm") { @@ -1137,37 +1602,37 @@ Status PyArray::RegisterTypes(py::module& m) { return self.ifrt_array()->client()->platform_name(); } }, - py::is_method(type)); - type.attr("is_ready") = py::cpp_function( + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, - py::is_method(type)); + nb::is_method()); type.attr("is_deleted") = - py::cpp_function(&PyArray::IsDeleted, py::is_method(type)); - type.attr("traceback") = jax::property_readonly(&PyArray::traceback); - type.attr("clone") = py::cpp_function(&PyArray::Clone, py::is_method(type)); + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); type.attr("__module__") = m.attr("__name__"); - m.attr("copy_array_to_devices_with_sharding") = py::cpp_function( - [](PyArray self, std::vector> dst_devices, - py::object sharding) { + m.attr("copy_array_to_devices_with_sharding") = nb::cpp_function( + [](PyArray self, absl::Span dst_devices, + nb::object sharding) { ifrt::DeviceList::Devices devices; devices.reserve(dst_devices.size()); for (auto& d : dst_devices) { - devices.push_back(d.get()); + devices.push_back(d->device()); } return xla::ValueOrThrow(self.CopyToDeviceWithSharding( ifrt::DeviceList(devices), std::move(sharding))); }); - m.attr("array_result_handler") = py::cpp_function( - [](py::object aval, py::object sharding, bool committed, - bool skip_checks) -> std::unique_ptr { - return std::make_unique( + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( std::move(aval), std::move(sharding), committed, skip_checks); }, - py::arg("aval"), py::arg("sharding"), py::arg("committed"), - py::arg("_skip_checks") = false); + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); - py::class_(m, "ResultHandler") + nb::class_(m, "ResultHandler") .def("__call__", [](const PyArrayResultHandler& self, PyArray arg) { return self.Call(arg); }) .def("__call__", diff --git a/xla/python/py_array.h b/xla/python/py_array.h index 88701d84f08eb..0d53b02c1103b 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_PYTHON_PY_ARRAY_H_ #define XLA_PYTHON_PY_ARRAY_H_ +#include + +#include +#include #include #include #include @@ -23,50 +27,84 @@ limitations under the License. #include // placeholder for index annotation headers +#include "absl/status/status.h" +#include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/py_buffer.h" -#include "xla/python/types.h" +#include "xla/python/py_client.h" +#include "xla/python/traceback.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/concurrency/ref_count.h" namespace xla { // Private to PyArray, but you cannot forward declare member classes. -struct PyArray_Storage { - PyArray_Storage(pybind11::object aval, bool weak_type, pybind11::dtype dtype, - std::vector shape, pybind11::object sharding, - bool committed, std::shared_ptr py_client, - std::shared_ptr traceback, - tsl::RCReference ifrt_array); +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); - // TODO(yashkatariya): remove this once the transition completes. - struct DisableFastpath {}; - explicit PyArray_Storage(DisableFastpath); + PyHostValue(const PyHostValue&) = delete; + PyHostValue(PyHostValue&&) = delete; + PyHostValue& operator=(const PyHostValue&) = delete; + PyHostValue& operator=(PyHostValue&&) = delete; - ~PyArray_Storage(); - pybind11::handle AsHandle(); + Status CopyToHostAsync(std::optional& dynamic_shape_holder, + ifrt::Array* ifrt_array); + + absl::StatusOr AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + private: + ifrt::Future ready_; + nb_numpy_ndarray value_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, + xla::PjRtFuture result_status); - // TODO(yashkatariya): remove this once the transition completes. - bool fastpath_enabled; + ~PyArray_Storage(); + nanobind::handle AsHandle(); - pybind11::object aval; + nanobind::object aval; bool weak_type = false; - pybind11::dtype dtype; + nb_dtype dtype; std::vector shape; - pybind11::object sharding; - pybind11::object npy_value = pybind11::none(); + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); bool committed = false; - std::shared_ptr py_client; - std::shared_ptr traceback; + nb_class_ptr py_client; + std::optional traceback; tsl::RCReference ifrt_array; // optional field, used only in python std::vector py_arrays; - std::shared_ptr host_value; // Protected by the GIL. + PyHostValue host_value; // Protected by the GIL. std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture result_status; // Doubly-linked list of all PyArrays known to the client. Protected by the // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be @@ -78,73 +116,77 @@ struct PyArray_Storage { // The C++ implementation of jax.Array. A few key methods and data members are // implemented in C++ for performance, while most of the functionalities are // still implemented in python. -class PyArray : public pybind11::object { +class PyArray : public nanobind::object { public: - PYBIND11_OBJECT(PyArray, pybind11::object, PyArray::IsPyArray); + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); PyArray() = default; // "__init__" methods. Only used in python - static void PyInit(pybind11::object self, pybind11::object aval, - pybind11::object sharding, + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, absl::Span py_arrays, bool committed, bool skip_checks); - // TODO(yashkatariya): remove this once the transition completes. - struct DisableFastpath {}; - static void PyInit(pybind11::object self, DisableFastpath); - // Only used in C++. `skip_checks` should only be set for Arrays created by // jax that cannot possibly have consistency issues (e.g. `sharding` devices // different than `ifrt_array` devices). Arrays created by users should be // checked. - PyArray(pybind11::object aval, bool weak_type, pybind11::dtype dtype, - std::vector shape, pybind11::object sharding, - std::shared_ptr py_client, - std::shared_ptr traceback, + PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, + std::optional traceback, tsl::RCReference ifrt_array, bool committed, - bool skip_checks); + bool skip_checks, + xla::PjRtFuture result_status = + xla::PjRtFuture()); static PyArray MakeFromSingleDeviceArray( - std::shared_ptr py_client, std::shared_ptr traceback, - tsl::RCReference ifrt_array, bool weak_type, bool committed); + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture result_status = + xla::PjRtFuture()); static PyArray MakeFromIfrtArrayAndSharding( - std::shared_ptr py_client, std::shared_ptr traceback, - tsl::RCReference ifrt_array, pybind11::object sharding, + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nanobind::object sharding, bool weak_type, bool committed, bool skip_checks); - static Status RegisterTypes(pybind11::module& m); + static Status RegisterTypes(nanobind::module_& m); using Storage = PyArray_Storage; - const pybind11::object& aval() const { return GetStorage().aval; } - void set_aval(pybind11::object aval) { GetStorage().aval = std::move(aval); } + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } bool weak_type() const { return GetStorage().weak_type; } - const pybind11::dtype& dtype() const { return GetStorage().dtype; } + const nb_dtype& dtype() const { return GetStorage().dtype; } absl::Span shape() const { return GetStorage().shape; } - const pybind11::object& sharding() const { return GetStorage().sharding; } + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } bool committed() const { return GetStorage().committed; } - const pybind11::object& npy_value() const { return GetStorage().npy_value; } - void set_npy_value(pybind11::object v) { + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { GetStorage().npy_value = std::move(v); } - const std::shared_ptr& py_client() const { + const nb_class_ptr& py_client() const { return GetStorage().py_client; } - const std::shared_ptr& traceback() const { + const std::optional& traceback() const { return GetStorage().traceback; } // Returns xla::InvalidArgument if the buffer has been deleted. // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. - StatusOr IsReady() { + absl::StatusOr IsReady() { ifrt::Array* ifrt_array_ptr = ifrt_array(); if (ifrt_array_ptr->IsDeleted()) { return InvalidArgument("Array has been deleted."); @@ -152,6 +194,10 @@ class PyArray : public pybind11::object { return ifrt_array_ptr->GetReadyFuture().IsReady(); } + const xla::PjRtFuture& result_status() const { + return GetStorage().result_status; + } + ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } // Short-term escape hatch to get PjRtBuffers from PyArray. @@ -190,9 +236,9 @@ class PyArray : public pybind11::object { } const std::vector& py_arrays_cached(); - pybind11::object arrays(); - Status set_arrays(pybind11::object obj); - StatusOr FullyReplicatedShard(); + nanobind::object arrays(); + Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); int num_shards() const { ifrt::Array* ifrt_array_ptr = ifrt_array(); @@ -202,25 +248,24 @@ class PyArray : public pybind11::object { return ifrt_array_ptr->sharding().devices().size(); } - // TODO(yashkatariya): remove this once the transition completes. - bool fastpath_enabled() const { return GetStorage().fastpath_enabled; } - - static pybind11::handle type() { + static nanobind::handle type() { DCHECK(type_); - return pybind11::handle(type_); + return nanobind::handle(type_); } - static bool IsPyArray(pybind11::handle arg) { - return arg.get_type().is(PyArray::type()); + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); } Status BlockUntilReady() const; - StatusOr GetOnDeviceSizeInBytes(); - StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); Status CopySingleDeviceArrayToHostAsync(); - StatusOr CudaArrayInterface(); - StatusOr UnsafeBufferPointer(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); Status Delete(); @@ -228,19 +273,22 @@ class PyArray : public pybind11::object { PyArray Clone() const; - StatusOr CopyToDeviceWithSharding(ifrt::DeviceList devices, - pybind11::object dst_sharding); + absl::StatusOr CopyToDeviceWithSharding( + ifrt::DeviceList devices, nanobind::object dst_sharding); - static StatusOr BatchedDevicePut( - pybind11::object aval, pybind11::object sharding, - std::vector xs, - std::vector> dst_devices, bool committed, + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, bool jax_enable_x64); + static absl::Status BatchedBlockUntilReady( + std::vector objs); + private: - StatusOr FetchSingleShard(std::string_view api); - StatusOr AssertUnsharded(std::string_view api); + absl::StatusOr FetchSingleShard(std::string_view api); + absl::StatusOr AssertUnsharded(std::string_view api); void CheckAndRearrange(); @@ -249,33 +297,36 @@ class PyArray : public pybind11::object { Storage& GetStorage(); const Storage& GetStorage() const; - static Status SetUpType(); - inline static PyObject* type_ = nullptr; }; class PyArrayResultHandler { public: - PyArrayResultHandler(pybind11::object aval, pybind11::object sharding, + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, bool committed, bool skip_checks); PyArray Call(absl::Span py_arrays) const; PyArray Call(PyArray py_array) const; - PyArray Call(std::shared_ptr py_client, - tsl::RCReference ifrt_array) const; + PyArray Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture result_status = + xla::PjRtFuture()) const; private: - pybind11::object aval_; - pybind11::object sharding_; + nanobind::object aval_; + nanobind::object sharding_; bool weak_type_; bool committed_; bool skip_checks_; - pybind11::object dtype_; + nb_dtype dtype_; std::vector shape_; }; +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client); + } // namespace xla #endif // XLA_PYTHON_PY_ARRAY_H_ diff --git a/xla/python/py_buffer.cc b/xla/python/py_buffer.cc deleted file mode 100644 index ed98c84c81fa2..0000000000000 --- a/xla/python/py_buffer.cc +++ /dev/null @@ -1,332 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/py_buffer.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/casts.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_compiler.h" -#include "xla/primitive_util.h" -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/py_client.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/python_utils.h" -#include "xla/python/status_casters.h" -#include "xla/python/transfer_guard_lib.h" -#include "xla/python/types.h" -#include "xla/python/util.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -namespace xla { - -namespace py = pybind11; - -namespace { - -// Returns if shape has a major-to-minor layout. -bool HasMajorToMinorLayout(const xla::Shape& shape) { - if (shape.has_layout()) { - for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { - if (shape.layout().minor_to_major(i) != - shape.layout().minor_to_major_size() - 1 - i) { - return false; - } - } - } - return true; -} - -// Returns byte_strides if shape has a non-major-to-minor layout. -std::optional> ByteStridesOrDefaultForShapeInt64( - const Shape& shape) { - if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { - return std::nullopt; - } - return ByteStridesForShape(shape); -} - -} // namespace - -/* static */ PjRtBuffer* IfrtHelpers::pjrt_buffer(ifrt::Array* ifrt_array) { - auto* arr = llvm::dyn_cast_or_null(ifrt_array); - if (arr == nullptr) { - throw XlaRuntimeError( - "This operation is implemented for a PjRt-compatible backend only."); - } - return arr->pjrt_buffers().front().get(); -} - -/* static */ PjRtDevice* IfrtHelpers::pjrt_device(ifrt::Array* ifrt_array) { - return ifrt_array->sharding().devices().front(); -} - -/* static */ StatusOr IfrtHelpers::xla_dynamic_shape( - ifrt::Array* ifrt_array, std::optional& scratch) { - auto* pjrt_buffer = IfrtHelpers::pjrt_buffer(ifrt_array); - - if (!scratch) { - absl::Span dims; - std::optional> logical_dims_storage; - if (pjrt_buffer->has_dynamic_dimensions()) { - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(std::vector logical_dims, - pjrt_buffer->logical_dimensions()); - logical_dims_storage.emplace(std::move(logical_dims)); - } - dims = *logical_dims_storage; - } else { - dims = pjrt_buffer->dimensions(); - } - Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); - *shape.mutable_layout() = pjrt_buffer->layout(); - scratch = std::move(shape); - } - return &scratch.value(); -} - -pybind11::tuple IfrtHelpers::python_shape(ifrt::Array* ifrt_array) { - return SpanToTuple(ifrt_array->shape().dims()); -} - -pybind11::dtype IfrtHelpers::python_dtype(ifrt::Array* ifrt_array) { - // TODO(hyeontaek): Support non-XLA types such as xla::ifrt::DType::kString. - PrimitiveType primitive = ifrt::ToPrimitiveType(ifrt_array->dtype()).value(); - return PrimitiveTypeToDtype(primitive).value(); -} - -/* static */ StatusOr> IfrtHelpers::CopyToDevice( - ifrt::Array* ifrt_array, PjRtDevice* dst_device) { - CHECK(dst_device != nullptr); - auto transfer_guard_formatter = [ifrt_array, dst_device] { - auto shape = py::cast(py::str(python_shape(ifrt_array))); - auto dtype = py::cast(py::str(python_dtype(ifrt_array))); - return absl::StrCat("shape=", shape, ", dtype=", dtype, - ", device=", pjrt_device(ifrt_array)->DebugString(), - ", dst_device=", dst_device->DebugString()); - }; - TF_RETURN_IF_ERROR( - jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); - - GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; - // TODO(yashkatariya): Plumb sharding or memory_kind here. - return ifrt_array->Reshard( - ifrt::SingleDeviceSharding::Create(dst_device, ifrt::MemoryKind()), - ifrt::ArrayCopySemantics::kReuseInput); -} - -/* static */ StatusOr PyHostValue::AsNumPyArray( - std::shared_ptr& host_value, - std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array, - pybind11::handle this_obj) { - if (ifrt_array->IsDeleted()) { - return InvalidArgument("DeviceArray has been deleted."); - } - auto* arr = llvm::dyn_cast_or_null(ifrt_array); - if (arr != nullptr) { - auto* pjrt_buffer = arr->pjrt_buffers().front().get(); - TF_RET_CHECK(!pjrt_buffer->IsTuple()); - // On CPU for non-int4 values, we can return the value in a zero-copy way. - // For int4 values, we must copy in order to unpack the array. - if (pjrt_buffer->IsOnCpu() && - !primitive_util::Is4BitType(pjrt_buffer->element_type())) { - TF_ASSIGN_OR_RETURN( - const auto* shape, - IfrtHelpers::xla_dynamic_shape(ifrt_array, dynamic_shape_holder)); - TF_ASSIGN_OR_RETURN(py::dtype dtype, - PrimitiveTypeToDtype(shape->element_type())); - // Objects that must be kept alive while the array is alive. - struct Hold { - tsl::RCReference buffer; - std::unique_ptr external_reference_hold; - }; - auto hold = std::make_unique(); - TF_ASSIGN_OR_RETURN(hold->external_reference_hold, - pjrt_buffer->AcquireExternalReference()); - hold->buffer = tsl::FormRef(ifrt_array); - void* data = - hold->external_reference_hold->OpaqueDeviceMemoryDataPointer(); - py::capsule hold_capsule(hold.release(), - [](void* h) { delete static_cast(h); }); - py::array array(dtype, shape->dimensions(), ByteStridesForShape(*shape), - data, hold_capsule); - array.attr("flags").attr("writeable") = Py_False; - { - py::gil_scoped_release gil; - TF_RETURN_IF_ERROR(ifrt_array->GetReadyFuture().Await()); - } - return array; - } - } - - TF_RETURN_IF_ERROR( - CopyToHostAsync(host_value, dynamic_shape_holder, ifrt_array)); - if (!host_value->ready.HasBeenNotified()) { - py::gil_scoped_release gil; - host_value->ready.WaitForNotification(); - } - TF_RETURN_IF_ERROR(host_value->status); - TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value->value)); - array.attr("flags").attr("writeable") = Py_False; - return array; -} - -/* static */ Status PyHostValue::CopyToHostAsync( - std::shared_ptr& host_value, - std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { - if (host_value) { - return OkStatus(); - } - auto* arr = llvm::dyn_cast_or_null(ifrt_array); - if (arr != nullptr) { - auto* pjrt_buffer = arr->pjrt_buffers().front().get(); - if (pjrt_buffer->IsOnCpu() && - !primitive_util::Is4BitType(pjrt_buffer->element_type())) { - return OkStatus(); - } - } - auto transfer_guard_formatter = [ifrt_array] { - auto shape = - py::cast(py::str(IfrtHelpers::python_shape(ifrt_array))); - auto dtype = - py::cast(py::str(IfrtHelpers::python_dtype(ifrt_array))); - return absl::StrCat("shape=", shape, ", dtype=", dtype, ", device=", - IfrtHelpers::pjrt_device(ifrt_array)->DebugString()); - }; - TF_RETURN_IF_ERROR( - jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); - - auto host_value_copy = std::make_shared(); - host_value = host_value_copy; - // TODO(b/182461453): This is a blocking call. If we further implemented - // populating dynamic shape metadata while fetching the literal, we wouldn't - // need this static approach. - const xla::Shape* dynamic_shape; - std::optional shape_holder; - if (llvm::isa(ifrt_array)) { - TF_ASSIGN_OR_RETURN(dynamic_shape, IfrtHelpers::xla_dynamic_shape( - ifrt_array, dynamic_shape_holder)); - } else { - // Skip querying the dynamic shape for a non-PjRt Array. - TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, - ifrt::ToPrimitiveType(ifrt_array->dtype())); - shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( - type, ifrt_array->shape().dims()); - dynamic_shape = &*shape_holder; - } - - py::gil_scoped_release gil; - xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); - // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses - // the same transposition as the device buffer. This is different from - // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout - // of the host buffer literal. On the other hand, the runtime often knows - // better about an efficient layout for the host buffer. It will be useful - // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is - // desirable for the runtime to choose the layout. - host_value_copy->value = std::make_shared(host_shape); - ifrt::Future copy_future = ifrt_array->CopyToHostBuffer( - host_value_copy->value->untyped_data(), - ByteStridesOrDefaultForShapeInt64(host_shape), - ifrt::ArrayCopySemantics::kReuseInput); - copy_future.OnReady([host_value{std::move(host_value_copy)}](Status status) { - host_value->status = std::move(status); - host_value->ready.Notify(); - }); - return OkStatus(); -} - -StatusOr IfrtHelpers::CudaArrayInterface( - ifrt::Array* ifrt_array, std::optional& scratch) { - auto* pjrt_buffer = IfrtHelpers::pjrt_buffer(ifrt_array); - if (pjrt_buffer->client()->platform_id() != CudaId()) { - return InvalidArgument( - "__cuda_array_interface__ is only defined for NVidia GPU buffers."); - } - if (pjrt_buffer->IsTuple()) { - return InvalidArgument( - "__cuda_array_interface__ is only defined for array buffers."); - } - if (pjrt_buffer->element_type() == BF16) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for bfloat16 buffers."); - } - if (pjrt_buffer->element_type() == F8E4M3FN) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for F8E4M3FN buffers."); - } - if (pjrt_buffer->element_type() == F8E4M3B11FNUZ) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for F8E4M3B11FNUZ buffers."); - } - if (pjrt_buffer->element_type() == F8E5M2) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for F8E5M2 buffers."); - } - if (pjrt_buffer->element_type() == F8E4M3FNUZ) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for F8E4M3FNUZ buffers."); - } - if (pjrt_buffer->element_type() == F8E5M2FNUZ) { - return InvalidArgument( - "__cuda_array_interface__ is not supported for F8E5M2FNUZ buffers."); - } - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(pjrt_buffer->layout())); - - py::dict result; - TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, - IfrtHelpers::xla_dynamic_shape(ifrt_array, scratch)); - result["shape"] = SpanToTuple(dynamic_shape->dimensions()); - TF_ASSIGN_OR_RETURN(py::str typestr, TypeDescriptorForPrimitiveType( - pjrt_buffer->element_type())); - result["typestr"] = std::move(typestr); - TF_ASSIGN_OR_RETURN( - std::unique_ptr external_reference_hold, - pjrt_buffer->AcquireExternalReference()); - const void* root_ptr = - external_reference_hold->OpaqueDeviceMemoryDataPointer(); - py::tuple data(2); - data[0] = py::int_(absl::bit_cast(root_ptr)); - data[1] = py::bool_(true); // read-only - result["data"] = std::move(data); - result["version"] = py::int_(2); - return result; -} - -StatusOr ToIfRtDType(py::dtype dtype) { - TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); - return ifrt::ToDType(primitive_type); -} - -StatusOr ToPybind11DType(ifrt::DType dtype) { - TF_ASSIGN_OR_RETURN(auto primitive_type, ifrt::ToPrimitiveType(dtype)); - return PrimitiveTypeToDtype(primitive_type); -} - -} // namespace xla diff --git a/xla/python/py_buffer.h b/xla/python/py_buffer.h deleted file mode 100644 index a501dbf8196d4..0000000000000 --- a/xla/python/py_buffer.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_PY_BUFFER_H_ -#define XLA_PYTHON_PY_BUFFER_H_ - -#include -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "absl/synchronization/notification.h" -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/python/ifrt/array.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/py_client.h" -#include "xla/python/traceback.h" -#include "xla/statusor.h" -#include "xla/types.h" - -namespace xla { - -// TODO(parkers): Move everything in this file to a better home. -struct PyHostValue { - static Status CopyToHostAsync(std::shared_ptr& host_value, - std::optional& dynamic_shape_holder, - ifrt::Array* ifrt_array); - - static StatusOr AsNumPyArray( - std::shared_ptr& host_value, - std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array, - pybind11::handle this_obj); - - absl::Notification ready; - Status status; - std::shared_ptr value; -}; - -struct IfrtHelpers { - static StatusOr xla_dynamic_shape( - ifrt::Array* ifrt_array, std::optional& scratch); - static StatusOr> CopyToDevice( - ifrt::Array* ifrt_array, PjRtDevice* dst_device); - static PjRtBuffer* pjrt_buffer(ifrt::Array* ifrt_array); - static PjRtDevice* pjrt_device(ifrt::Array* ifrt_array); - static pybind11::tuple python_shape(ifrt::Array* ifrt_array); - static pybind11::dtype python_dtype(ifrt::Array* ifrt_array); - static StatusOr CudaArrayInterface( - ifrt::Array* ifrt_array, std::optional& scratch); -}; - -// TODO(hyeontaek): Move the following functions to a separate file. -StatusOr ToIfRtDType(pybind11::dtype dtype); -StatusOr ToPybind11DType(ifrt::DType dtype); - -} // namespace xla - -#endif // XLA_PYTHON_PY_BUFFER_H_ diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 2caed4eb9562f..ba09f96c570f1 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,37 +15,84 @@ limitations under the License. #include "xla/python/py_client.h" +#include + +#include +#include #include #include #include #include +#include #include +#include #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/pair.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/pjrt_stream_executor_client.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/callback.h" -#include "xla/python/exceptions.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/py_array.h" -#include "xla/python/py_buffer.h" +#include "xla/python/py_device.h" #include "xla/python/py_executable.h" #include "xla/python/py_host_callback.h" +#include "xla/python/py_memory_space.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" +#include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/service/platform_util.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -54,7 +101,14 @@ limitations under the License. namespace xla { -namespace py = pybind11; +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} PyClient::PyClient(std::shared_ptr ifrt_client) : ifrt_client_(std::move(ifrt_client)), @@ -62,58 +116,81 @@ PyClient::PyClient(std::shared_ptr ifrt_client) CHECK(ifrt_client_); } +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (PjRtMemorySpace* memory : device->memory_spaces()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + PyClient::~PyClient() { - py::gil_scoped_release gil; + nb::gil_scoped_release gil; ifrt_client_ = nullptr; } -std::vector> PyClient::Devices() { - std::vector> devices; +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + PjRtMemorySpace* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; auto span = ifrt_client_->devices(); devices.reserve(span.size()); for (PjRtDevice* device : span) { - devices.push_back(WrapWithClient(shared_from_this(), device)); + devices.push_back(GetPyDevice(device)); } return devices; } -std::vector> PyClient::LocalDevices() { - std::vector> devices; +std::vector> PyClient::LocalDevices() { + std::vector> devices; devices.reserve(ifrt_client_->addressable_devices().size()); for (ifrt::Device* device : ifrt_client_->addressable_devices()) { - devices.push_back(WrapWithClient(shared_from_this(), device)); + devices.push_back(GetPyDevice(device)); } return devices; } -StatusOr> PyClient::DeviceFromLocalHardwareId( +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( int local_hardware_id) { TF_ASSIGN_OR_RETURN(PjRtDevice * device, ifrt_client_->LookupAddressableDevice(local_hardware_id)); - return WrapWithClient(shared_from_this(), device); + return GetPyDevice(device); } -std::vector PyClient::LiveBuffers() { +nb::list PyClient::LiveExecutables() { CHECK(PyGILState_Check()); - std::vector buffers; - for (py::object& array : LiveArrays()) { - buffers.push_back(std::move(array)); - } - return buffers; -} - -std::vector> PyClient::LiveExecutables() { - CHECK(PyGILState_Check()); - std::vector> executables; + nb::list executables; for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { if (!exec->is_deleted()) { - executables.push_back(exec->shared_from_this()); + executables.append(nb::find(exec)); } } return executables; } -Status PyClient::Defragment() { +absl::Status PyClient::Defragment() { CHECK(PyGILState_Check()); auto runtime_type = ifrt_client_->runtime_type(); if (runtime_type == PjRtRuntimeTypeString(PjRtRuntimeType::kTfrt)) { @@ -195,31 +272,33 @@ Status PyClient::Defragment() { // TODO(skyewm): delete executables? } - return OkStatus(); + return absl::OkStatus(); } -StatusOr PyClient::BufferFromPyval( - pybind11::handle argument, PjRtDevice* device, bool force_copy, - ifrt::Client::HostBufferSemantics host_buffer_semantics) { +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, PjRtDevice* device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { if (device == nullptr) { - TF_RET_CHECK(!ifrt_client_->addressable_devices().empty()); - device = ifrt_client_->addressable_devices().front(); + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); } CHECK(device != nullptr); auto transfer_guard_formatter = [&argument, dst_device = device] { - auto type = py::cast(py::str(argument.get_type())); + auto type = nb::cast(nb::str(argument.type())); // Catch exceptions because shape and dtype properties convertible to str // are not guaranteed to present in an arbitrary argument. std::string shape; std::string dtype; try { - shape = py::cast(py::str(argument.attr("shape"))); + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); } catch (const std::exception& e) { shape = ""; } try { - dtype = py::cast(py::str(argument.attr("dtype"))); + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); } catch (const std::exception& e) { dtype = ""; } @@ -230,87 +309,86 @@ StatusOr PyClient::BufferFromPyval( jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); TF_ASSIGN_OR_RETURN(PjRtDevice * found_device, - ifrt_client_->LookupDevice(device->id())); + client->ifrt_client_->LookupDevice(device->id())); if (found_device != device) { return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", device->DebugString(), - ifrt_client_->platform_name()); + client->ifrt_client_->platform_name()); } GlobalPyRefManager()->CollectGarbage(); DevicePutOptions options; options.squash_64bit_types = false; options.allow_zero_copy = - (!force_copy && - (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kZeroCopy)); + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + // TODO(phawkins): remove .ptr() after nanobind transition is complete. TF_ASSIGN_OR_RETURN(DevicePutResult put, - DevicePut(argument, ifrt_client_.get(), device, options, - ifrt::MemoryKind())); + DevicePut(argument.ptr(), client->ifrt_client_.get(), + device, options, ifrt::MemoryKind())); if (put.ifrt_array) { auto traceback = Traceback::Get(); return PyArray::MakeFromSingleDeviceArray( - shared_from_this(), std::move(traceback), std::move(put.ifrt_array), + std::move(client), std::move(traceback), std::move(put.ifrt_array), /*weak_type=*/false, /*committed=*/false); } else { - return py::reinterpret_borrow(put.owning_pybuffer); + return put.owning_pybuffer; } } -StatusOr>> -PyClient::MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtDevice* device) { +/* static */ absl::StatusOr PyClient::MakeCrossHostReceiveBuffers( + nb_class_ptr client, absl::Span shapes, + PjRtDevice* device) { CHECK(device != nullptr); absl::Mutex mu; - StatusOr> recv_descriptors_or; + absl::StatusOr> recv_descriptors_or; bool done = false; TF_ASSIGN_OR_RETURN( - auto buffers, pjrt_client()->MakeCrossHostReceiveBuffers( - shapes, device, - [&done, &recv_descriptors_or, - &mu](StatusOr recv_state_or) { - absl::MutexLock l(&mu); - if (recv_state_or.ok()) { - py::gil_scoped_acquire gil; - recv_descriptors_or = - std::move(recv_state_or->descriptors); - } else { - recv_descriptors_or = recv_state_or.status(); - } - done = true; - })); + auto buffers, + client->pjrt_client()->MakeCrossHostReceiveBuffers( + shapes, device, + [&done, &recv_descriptors_or, + &mu](absl::StatusOr recv_state_or) { + absl::MutexLock l(&mu); + if (recv_state_or.ok()) { + nb::gil_scoped_acquire gil; + recv_descriptors_or = std::move(recv_state_or->descriptors); + } else { + recv_descriptors_or = recv_state_or.status(); + } + done = true; + })); { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; absl::MutexLock l(&mu); mu.Await(absl::Condition(&done)); } TF_RETURN_IF_ERROR(recv_descriptors_or.status()); CHECK_EQ(buffers.size(), recv_descriptors_or->size()); - std::vector> result; - result.reserve(buffers.size()); + nb::list result; for (int i = 0; i < buffers.size(); ++i) { auto& descriptors = recv_descriptors_or->at(i); CHECK_EQ(descriptors.serialized_descriptors.size(), 1); const std::string& desc = descriptors.serialized_descriptors[0]; - pybind11::bytes py_desc = pybind11::bytes(desc); - auto traceback = Traceback::Get(); - auto* client = - llvm::dyn_cast_or_null(ifrt_client()); - if (client == nullptr) { + nb::bytes py_desc = nb::bytes(desc.data(), desc.size()); + auto* ifrt_client = llvm::dyn_cast_or_null( + client->ifrt_client()); + if (ifrt_client == nullptr) { throw XlaRuntimeError( "This operation is implemented for a PjRt-compatible backend only."); } TF_ASSIGN_OR_RETURN(auto ifrt_array, - client->CreatePjRtArray(std::move(buffers[i]))); - auto py_buf = PyArray::MakeFromSingleDeviceArray( - shared_from_this(), Traceback::Get(), std::move(ifrt_array), - /*weak_type=*/false, - /*committed=*/false); - result.push_back(std::make_pair(std::move(py_desc), std::move(py_buf))); + ifrt_client->CreatePjRtArray(std::move(buffers[i]))); + auto py_buf = PyArray::MakeFromSingleDeviceArray(client, Traceback::Get(), + std::move(ifrt_array), + /*weak_type=*/false, + /*committed=*/false); + result.append(nb::make_tuple(std::move(py_desc), std::move(py_buf))); } return result; } @@ -320,7 +398,7 @@ namespace { // Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host // callbacks. std::unique_ptr MakeIfrtCompileOptions( - CompileOptions options, std::vector host_callbacks) { + CompileOptions options, std::vector host_callbacks) { std::vector> ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); @@ -328,8 +406,8 @@ std::unique_ptr MakeIfrtCompileOptions( // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or // `PyClient::GetEmitPythonCallbackDescriptor()`. for (auto& host_callback : host_callbacks) { - ifrt_loaded_host_callbacks.push_back( - tsl::FormRef(host_callback.get_pointer())); + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); } return std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); @@ -338,9 +416,8 @@ std::unique_ptr MakeIfrtCompileOptions( // Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and // optional host callbacks. std::unique_ptr -MakeIfrtDeserializeExecutableOptions( - std::optional options, - std::vector host_callbacks) { +MakeIfrtDeserializeExecutableOptions(std::optional options, + std::vector host_callbacks) { std::vector> ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); @@ -348,8 +425,8 @@ MakeIfrtDeserializeExecutableOptions( // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or // `PyClient::GetEmitPythonCallbackDescriptor()`. for (auto& host_callback : host_callbacks) { - ifrt_loaded_host_callbacks.push_back( - tsl::FormRef(host_callback.get_pointer())); + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); } return std::make_unique( std::move(options), std::move(ifrt_loaded_host_callbacks)); @@ -357,14 +434,19 @@ MakeIfrtDeserializeExecutableOptions( } // namespace -StatusOr> PyClient::Compile( - std::string mlir_module, CompileOptions options, - std::vector host_callbacks) { - // Pass allocated device memory size to compile options for pjrt compatible - // backends. +/* static */ absl::StatusOr> +PyClient::CompileIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { auto* pjrt_compatible_client = - llvm::dyn_cast_or_null(ifrt_client_.get()); - if (pjrt_compatible_client != nullptr) { + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; auto addressable_devices = pjrt_compatible_client->pjrt_client()->addressable_devices(); if (!addressable_devices.empty()) { @@ -383,50 +465,58 @@ StatusOr> PyClient::Compile( std::unique_ptr ifrt_loaded_executable; std::optional fingerprint; - auto ifrt_compile_options = - MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks)); { - py::gil_scoped_release gil_release; - mlir::MLIRContext context; - TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, - ParseMlirModuleString(mlir_module, context)); - TF_ASSIGN_OR_RETURN( - ifrt_loaded_executable, - ifrt_client_->GetDefaultCompiler()->Compile( - std::make_unique(module.get()), - std::move(ifrt_compile_options))); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->Compile( + std::move(ifrt_program), std::move(ifrt_options))); TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); } auto traceback = Traceback::Get(); - return std::make_shared( - shared_from_this(), std::move(ifrt_loaded_executable), + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), std::move(traceback), std::move(fingerprint)); } -StatusOr PyClient::SerializeExecutable( +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + return CompileIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); +} + +absl::StatusOr PyClient::SerializeExecutable( const PyLoadedExecutable& executable) const { - return executable.ifrt_loaded_executable()->Serialize(); + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); } -StatusOr> PyClient::DeserializeExecutable( - const std::string& serialized, std::optional options, - std::vector host_callbacks) { +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { std::unique_ptr ifrt_loaded_executable; std::optional fingerprint; auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( std::move(options), std::move(host_callbacks)); { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN( ifrt_loaded_executable, - ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( - serialized, std::move(ifrt_deserialize_options))); - TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + std::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); } TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); auto traceback = Traceback::Get(); - return std::make_shared( - shared_from_this(), std::move(ifrt_loaded_executable), + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), std::move(traceback), std::move(fingerprint)); } @@ -463,7 +553,7 @@ H AbslHashValue(H h, const HeapProfileKey& key) { } // namespace -StatusOr PyClient::HeapProfile() { +absl::StatusOr PyClient::HeapProfile() { CHECK(PyGILState_Check()); absl::flat_hash_set buffer_set; absl::flat_hash_map entries; @@ -477,7 +567,7 @@ StatusOr PyClient::HeapProfile() { buffer->device()}; ++entries[key]; } - return OkStatus(); + return absl::OkStatus(); }; for (PyArray_Storage* array = arrays_; array; array = array->next) { @@ -493,16 +583,17 @@ StatusOr PyClient::HeapProfile() { "only."); } for (const auto& buffer : arr->pjrt_buffers()) { - TF_RETURN_IF_ERROR( - add_buffer_to_profile(buffer.get(), array->traceback.get())); + TF_RETURN_IF_ERROR(add_buffer_to_profile( + buffer.get(), array->traceback ? array->traceback->get() : nullptr)); } } for (PyLoadedExecutable* executable = executables_; executable; executable = executable->next_) { if (!executable->is_deleted()) { - HeapProfileKey key{executable->traceback(), - executable->SizeOfGeneratedCodeInBytes(), nullptr}; + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; ++entries[key]; } } @@ -535,46 +626,49 @@ StatusOr PyClient::HeapProfile() { kind_label->set_str(buffer_string_id); auto* device_label = sample->add_label(); device_label->set_key(device_string_id); - device_label->set_str( - builder.StringId(std::string(entry.first.device->DebugString()))); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); } else { kind_label->set_str(executable_string_id); } } - return py::bytes(builder.profile().SerializeAsString()); + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); } -StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( - pybind11::function callable, absl::Span operand_shapes, +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, absl::Span result_shapes, absl::Span send_channel_ids, - absl::Span recv_channel_ids, - pybind11::function serializer) { + absl::Span recv_channel_ids, nb::callable serializer) { TF_ASSIGN_OR_RETURN( auto loaded_host_callback, PyHostSendAndRecvLoadedHostCallback::Create( ifrt_client(), std::move(callable), operand_shapes, result_shapes, send_channel_ids, recv_channel_ids, std::move(serializer))); - py::capsule callback_capsule(loaded_host_callback.release(), [](void* ptr) { - static_cast(ptr)->DropRef(); - }); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); return callback_capsule; } -StatusOr> -PyClient::GetEmitPythonCallbackDescriptor( - pybind11::function callable, absl::Span operand_shapes, - absl::Span result_shapes) { - TF_ASSIGN_OR_RETURN( - auto loaded_host_callback, - PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable), - operand_shapes, result_shapes)); +absl::StatusOr> +PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable, + nb::object operand_shapes, + nb::object result_shapes) { + TF_ASSIGN_OR_RETURN(auto loaded_host_callback, + PyCpuLoadedHostCallback::Create( + ifrt_client(), std::move(callable), + nb::cast>(operand_shapes), + nb::cast>(result_shapes))); const uint64_t descriptor = loaded_host_callback->descriptor(); - py::capsule callback_capsule(loaded_host_callback.release(), [](void* ptr) { - static_cast(ptr)->DropRef(); - }); - return std::make_pair(descriptor, py::object(std::move(callback_capsule))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return std::make_pair(descriptor, nb::object(std::move(callback_capsule))); } XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", @@ -586,4 +680,154 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value())); #endif +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> + memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "make_cross_host_receive_buffers", + [](nb_class_ptr client, absl::Span shapes, + PjRtDevice* device) { + return ValueOrThrow(PyClient::MakeCrossHostReceiveBuffers( + std::move(client), shapes, device)); + }, + nb::arg("shapes"), nb::arg("device")) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. + .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("get_emit_python_callback_descriptor", + xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes").none() = nb::none()) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) -> std::unique_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow( + self.ifrt_client()->GetDefaultLayoutForDevice( + ifrt_type, dims, device->device())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, std::string_view name) -> nb::object { + const auto& attrs = client.attributes(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + } // namespace xla diff --git a/xla/python/py_client.h b/xla/python/py_client.h index 581801831ffec..b5f86cc31e857 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,98 +16,49 @@ limitations under the License. #ifndef XLA_PYTHON_PY_CLIENT_H_ #define XLA_PYTHON_PY_CLIENT_H_ +#include + #include #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" -#include "pybind11/pybind11.h" // from @pybind11 +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" // from @nanobind #include "xla/client/xla_builder.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_common.h" -#include "xla/python/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/statusor.h" -#include "xla/types.h" +#include "xla/shape.h" namespace xla { class PyClient; class PyLoadedExecutable; class PyArray; +class PyDevice; +class PyMemorySpace; struct PyArray_Storage; -// Custom holder types. -// -// We must keep the PyClient object alive as long as any of the runtime -// objects are alive. Since we don't have a lot of control over Python -// destructor ordering, we keep the PyClient object as a std::shared_ptr<>, -// and ensure that each Python runtime object holds a reference to the -// PyClient. An alternative design would be to keep a single global -// singleton PyClient, although this seems less flexible, especially for -// writing tests. -// -// To maintain PyClient references, we define pybind11 holder classes that -// are custom smart pointers that also keep a reference to a PyClient. -// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't -// seem sufficiently flexible to describe ownership relationships in cases where -// the ownership doesn't pertain to a direct argument or return value of a -// function. Another alternative to the holder classes would be to create proxy -// objects that contain both a reference and a runtime class; holder classes -// seem less tedious to define. - -// A pair of a PyClient reference and an unowned pointer to T. -template -class ClientAndPtr { - public: - ClientAndPtr() = default; - // pybind11 requires that we define a constructor that takes a raw pointer, - // but it should be unreachable. - explicit ClientAndPtr(T*) { - LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient."; - } - - ClientAndPtr(const ClientAndPtr&) = default; - ClientAndPtr(ClientAndPtr&&) = default; - ClientAndPtr& operator=(const ClientAndPtr&) = default; - ClientAndPtr& operator=(ClientAndPtr&&) = default; - - PyClient* get_client() const { return client_; } - - std::shared_ptr client() const { - return std::shared_ptr(contents_, client_); - } - - T* get() const { return contents_.get(); } - T* operator->() const { return contents_.get(); } - T& operator*() const { return *contents_; } - - private: - template - friend ClientAndPtr WrapWithClient(std::shared_ptr client, - U* contents); - std::shared_ptr contents_; - PyClient* client_; -}; - -// By defining a templated helper function, we can use return type deduction -// and avoid specifying types at the caller. -template -ClientAndPtr WrapWithClient(std::shared_ptr client, T* contents) { - ClientAndPtr result; - result.client_ = client.get(); - result.contents_ = std::shared_ptr(std::move(client), contents); - return result; -} - // Python wrapper around PjRtClient. // We use a wrapper class to add Python-specific functionality. -class PyClient : public std::enable_shared_from_this { +class PyClient { public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. explicit PyClient(std::shared_ptr ifrt_client); virtual ~PyClient(); @@ -139,7 +90,7 @@ class PyClient : public std::enable_shared_from_this { return shared_ptr_pjrt_client(); } - absl::string_view platform_name() const { + std::string_view platform_name() const { // TODO(phawkins): this is a temporary backwards compatibility shim. We // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but // we haven't yet updated JAX clients that expect "gpu". Migrate users and @@ -151,12 +102,10 @@ class PyClient : public std::enable_shared_from_this { return ifrt_client_->platform_name(); } } - absl::string_view platform_version() const { + std::string_view platform_version() const { return ifrt_client_->platform_version(); } - absl::string_view runtime_type() const { - return ifrt_client_->runtime_type(); - } + std::string_view runtime_type() const { return ifrt_client_->runtime_type(); } // Returns implementation-specific attributes about this client, e.g. the PJRT // C API version if applicable. @@ -171,44 +120,53 @@ class PyClient : public std::enable_shared_from_this { int device_count() const { return ifrt_client_->device_count(); } int process_index() const { return ifrt_client_->process_index(); } - std::vector> Devices(); - std::vector> LocalDevices(); - StatusOr> DeviceFromLocalHardwareId( + std::vector> Devices(); + std::vector> LocalDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( int local_hardware_id); + // Returns the PyDevice associated with the given PjRtDevice. + nb_class_ptr GetPyDevice(PjRtDevice* device); + + // Returns the PyMemorySpace associated with the given PjRtMemorySpace. + nb_class_ptr GetPyMemorySpace(PjRtMemorySpace* memory_space); + // Returns a vector of live PyArray objects. PyArray objects may share // PjRtBuffers, so there may be duplicates of the same underlying device // buffer. - std::vector LiveBuffers(); - std::vector LiveBuffersOnDevice(PjRtDevice* device); + std::vector LiveBuffersOnDevice(PjRtDevice* device); - // Returns a vector of live PyLoadedExecutable objects. - // note: must return std::shared_ptr instead of raw ptrs - // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#std-shared-ptr - std::vector> LiveExecutables(); + nanobind::list LiveExecutables(); // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. - Status Defragment(); + absl::Status Defragment(); - StatusOr>> - MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtDevice* device); + static absl::StatusOr MakeCrossHostReceiveBuffers( + nb_class_ptr client, absl::Span shapes, + PjRtDevice* device); - StatusOr BufferFromPyval( - pybind11::handle argument, PjRtDevice* device, bool force_copy, + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + PjRtDevice* device, bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics); - StatusOr> Compile( - std::string mlir_module, CompileOptions options, - std::vector host_callbacks); + static absl::StatusOr> CompileIfrtProgram( + nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks); - StatusOr SerializeExecutable( + absl::StatusOr SerializeExecutable( const PyLoadedExecutable& executable) const; - StatusOr> DeserializeExecutable( - const std::string& serialized, std::optional options, - std::vector host_callbacks); + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + std::optional options, + std::vector host_callbacks); - StatusOr HeapProfile(); + absl::StatusOr HeapProfile(); // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that // takes in arguments of shapes `operand_shapes` and returns values of shapes @@ -222,19 +180,21 @@ class PyClient : public std::enable_shared_from_this { // The callable receives as arguments NumPy arrays for arguments with array // types, and None for Token argument. The callable must return a tuple of // either arrays or None values. - StatusOr> - GetEmitPythonCallbackDescriptor(pybind11::function callable, - absl::Span operand_shapes, - absl::Span result_shapes); + // TODO(phawkins): pass operand_shapes and result_shapes as + // absl::Span when nanobind transition is complete. + absl::StatusOr> + GetEmitPythonCallbackDescriptor(nanobind::callable callable, + nanobind::object operand_shapes, + nanobind::object result_shapes); // Deprecated; please switch to emitting a `CustomCallOp` directly. - StatusOr EmitPythonCallbackFromDescriptor( + absl::StatusOr EmitPythonCallbackFromDescriptor( XlaBuilder& builder, uint64_t descriptor, absl::Span operands, absl::Span result_shapes, std::optional> operand_layouts, bool has_side_effect); // Deprecated; please switch to using `GetEmitPythonCallbackDescriptor` // and then emitting a `CustomCall` op instead. - StatusOr> EmitPythonCallback( - pybind11::function callable, XlaBuilder& builder, + absl::StatusOr> EmitPythonCallback( + nanobind::callable callable, XlaBuilder& builder, absl::Span operands, absl::Span result_shapes, std::optional> operand_layouts, bool has_side_effect); @@ -253,20 +213,29 @@ class PyClient : public std::enable_shared_from_this { // The callable receives as arguments NumPy arrays for arguments with array // types, and None for Token argument. The callable must return a tuple of // either arrays or None values. - StatusOr MakePythonCallbackUsingHostSendAndRecv( - pybind11::function callable, absl::Span operand_shapes, + absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, absl::Span result_shapes, absl::Span send_channel_ids, absl::Span recv_channel_ids, - pybind11::function serializer); + nanobind::callable serializer); - std::vector LiveArrays(); + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); private: friend class PyLoadedExecutable; friend class PyArray; friend struct PyArray_Storage; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + std::shared_ptr ifrt_client_; absl::flat_hash_map client_attributes_; @@ -276,10 +245,12 @@ class PyClient : public std::enable_shared_from_this { PyLoadedExecutable* executables_ = nullptr; PyArray_Storage* arrays_ = nullptr; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; }; } // namespace xla -PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr); - #endif // XLA_PYTHON_PY_CLIENT_H_ diff --git a/xla/python/py_client_gpu.cc b/xla/python/py_client_gpu.cc index 6a5caf680243d..100d9fd59942c 100644 --- a/xla/python/py_client_gpu.cc +++ b/xla/python/py_client_gpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,10 +24,12 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #endif -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/host_callback.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" -#include "xla/python/exceptions.h" +#include "xla/python/nb_numpy.h" #if TENSORFLOW_USE_ROCM #define gpuSuccess hipSuccess @@ -45,7 +47,7 @@ limitations under the License. #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice #endif -namespace py = pybind11; +namespace nb = nanobind; namespace xla { @@ -79,36 +81,40 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, } CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) << "Failed to gpuStreamSynchronize"; - py::gil_scoped_acquire gil; - py::tuple host_input_arrays(arity); + nb::gil_scoped_acquire gil; + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); for (size_t i = 0; i < arity; ++i) { CpuCallback::Arg arg = callback->args()[i]; if (arg.type == TOKEN) { - host_input_arrays[i] = py::none(); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } - py::capsule base(host_input_buffers[i], - [](void* ptr) { delete[] static_cast(ptr); }); - host_input_arrays[i] = - py::array(arg.dtype, arg.dims, arg.strides, - const_cast(host_input_buffers[i]), /*base=*/base); - host_input_arrays[i].attr("flags").attr("writeable") = Py_False; + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto array = nb_numpy_ndarray(arg.dtype, arg.dims, arg.strides, + const_cast(host_input_buffers[i]), + /*base=*/base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } - std::optional maybe_result_tuple = + EnterHostCallback(); + std::optional maybe_result_tuple = callback->Call(host_input_arrays, status); + LeaveHostCallback(); if (!maybe_result_tuple) { return; } - py::tuple result_tuple = maybe_result_tuple.value(); + nb::tuple result_tuple = maybe_result_tuple.value(); std::vector temp_buffers; for (size_t i = 0; i < callback->results().size(); ++i) { CpuCallback::Result result = callback->results()[i]; if (result.type == TOKEN) { continue; } - py::object output = py::reinterpret_borrow( - PyTuple_GetItem(result_tuple.ptr(), i)); - py::array array = py::cast(std::move(output)); + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); absl::Span dims( reinterpret_cast(array.shape()), array.ndim()); absl::Span strides( @@ -121,11 +127,13 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, } else { void* temp = new char[result.size_in_bytes]; temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(result.type); + options.dims = dims; + options.permutation = result.reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; xla::StatusOr> plan = - callback->transpose_cache().GetOrCreate( - xla::primitive_util::ByteWidth(result.type), dims, - result.reversed_layout, - /*input_layout=*/xla::TransposePlan::Striding{strides}); + callback->transpose_cache().GetOrCreate(options); if (!plan.ok()) { throw xla::XlaRuntimeError(plan.status().ToString()); } @@ -136,7 +144,7 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } } - py::gil_scoped_release release; + nb::gil_scoped_release release; CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) << "Failed to gpuStreamSynchronize"; for (int i = 0; i < temp_buffers.size(); ++i) { diff --git a/xla/python/py_client_gpu.h b/xla/python/py_client_gpu.h index 5ce8079b79f37..d7675e1b6ad03 100644 --- a/xla/python/py_client_gpu.h +++ b/xla/python/py_client_gpu.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index 78908290c64b0..58880ecab15cc 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,20 +15,59 @@ limitations under the License. #include "xla/python/py_compile_only_client.h" +#include #include #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "pybind11/stl.h" // from @pybind11 +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/literal.h" #include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" -#include "xla/python/status_casters.h" -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" +#include "xla/service/computation_placer.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace nb = nanobind; namespace xla { @@ -45,21 +84,32 @@ class PjRtCompileOnlyDevice : public PjRtDevice { PjRtClient* client() const override { return nullptr; } bool IsAddressable() const override { return false; } - int local_hardware_id() const override { return -1; } + int local_hardware_id() const override { + return local_hardware_id_typed().value(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(local_hardware_id_typed().value()); + } + + PjRtLocalHardwareId local_hardware_id_typed() const override { + return PjRtLocalHardwareId(-1); + } + std::unique_ptr CreateAsyncTrackingEvent( absl::string_view description) const override { return nullptr; } - Status TransferToInfeed(const LiteralSlice& literal) override { + absl::Status TransferToInfeed(const LiteralSlice& literal) override { return Unimplemented("TransferToInfeed is not supported"); } - Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { + absl::Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { return Unimplemented("TransferFromOutfeed is not supported"); } absl::Span memory_spaces() const override { return {}; } - StatusOr default_memory_space() const override { + absl::StatusOr default_memory_space() const override { return Unimplemented("default_memory_space is not supported"); } @@ -70,13 +120,14 @@ class PjRtCompileOnlyDevice : public PjRtDevice { class InvalidIfrtCompiler final : public llvm::RTTIExtends { public: - StatusOr> Compile( + absl::StatusOr> Compile( std::unique_ptr program, std::unique_ptr options) override { return Unimplemented("Compile not implemented."); } - StatusOr> DeserializeLoadedExecutable( + absl::StatusOr> + DeserializeLoadedExecutable( absl::string_view serialized, std::unique_ptr options) override { return Unimplemented("DeserializeLoadedExecutable not implemented."); @@ -100,7 +151,7 @@ class CompileOnlyIfRtClient final } } - StatusOr> MakeArrayFromHostBuffer( + absl::StatusOr> MakeArrayFromHostBuffer( const void* data, ifrt::DType dtype, ifrt::Shape shape, std::optional> byte_strides, std::shared_ptr sharding, @@ -110,7 +161,8 @@ class CompileOnlyIfRtClient final "MakeArrayFromHostBuffer not available with compile-only client."); } - StatusOr> AssembleArrayFromSingleDeviceArrays( + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( ifrt::Shape shape, std::shared_ptr sharding, absl::Span> arrays, ifrt::ArrayCopySemantics semantics) override { @@ -119,7 +171,7 @@ class CompileOnlyIfRtClient final "client."); } - StatusOr> MakeTuple( + absl::StatusOr> MakeTuple( absl::Span> values) override { return Unimplemented("MakeTuple not available with compile-only client."); } @@ -149,17 +201,17 @@ class CompileOnlyIfRtClient final return {}; } int process_index() const override { return 0; } - StatusOr GetDefaultDeviceAssignment( + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { return Unimplemented( "GetDefaultDeviceAssignment not available with compile-only client."); } - StatusOr LookupDevice(int device_id) const override { + absl::StatusOr LookupDevice(int device_id) const override { return Unimplemented( "LookupDevice not available with compile-only client."); } - StatusOr LookupAddressableDevice( + absl::StatusOr LookupAddressableDevice( int local_hardware_id) const override { return Unimplemented( "LookupAddressableDevice not available with compile-only client."); @@ -171,12 +223,19 @@ class CompileOnlyIfRtClient final const PjRtTopologyDescription& topology() const { return *topology_; } - StatusOr> + absl::StatusOr> GetTopologyForDevices( absl::Span devices) const override { return topology_; } + absl::StatusOr> GetDefaultLayoutForDevice( + ifrt::DType dtype, absl::Span dims, + ifrt::Device* device) const override { + return absl::UnimplementedError( + "GetDefaultLayout not supported for CompileOnlyIfRtClient."); + } + private: InvalidIfrtCompiler default_compiler_; std::shared_ptr topology_; @@ -191,15 +250,24 @@ class CompileOnlyPyClient : public PyClient { public: using PyClient::PyClient; - StatusOr> CompileUnloaded( - std::string mlir_module, CompileOptions options, - std::vector host_callbacks) { + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr> CompileUnloaded( + std::string_view mlir_module, CompileOptions options, + std::vector host_callbacks) { if (!host_callbacks.empty()) { return Unimplemented( "Compiling with host_callbacks not available with compile-only " "client."); } - pybind11::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -210,25 +278,36 @@ class CompileOnlyPyClient : public PyClient { return PjRtCompile(std::move(options), module.get(), ifrt_client->topology()); } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } }; } // namespace -std::shared_ptr MakeCompileOnlyClient( +nb_class_ptr MakeCompileOnlyClient( std::shared_ptr topology) { - return std::make_shared( - std::make_unique(std::move(topology))); + return CompileOnlyPyClient::Make(std::move(topology)); } -void RegisterCompileOnlyClient(pybind11::module& m) { - pybind11::class_>(m, - "CompileOnlyPyClient") - .def("compile", - xla::ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), - pybind11::arg("computation"), - pybind11::arg("compile_options") = CompileOptions(), - pybind11::arg("host_callbacks") = std::vector()); +void RegisterCompileOnlyClient(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(self.CompileUnloaded( + std::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); } } // namespace xla diff --git a/xla/python/py_compile_only_client.h b/xla/python/py_compile_only_client.h index 84afa54f01f6c..67350c849bc5e 100644 --- a/xla/python/py_compile_only_client.h +++ b/xla/python/py_compile_only_client.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ limitations under the License. #include // placeholder for index annotation headers +#include "nanobind/nanobind.h" // from @nanobind #include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" namespace xla { @@ -33,10 +35,10 @@ namespace xla { // Python duck typing to treat the unloaded executable like a loaded executable // (except it will raise errors if you try to run it, which is what we want for // AOT environments). -std::shared_ptr MakeCompileOnlyClient( +nb_class_ptr MakeCompileOnlyClient( std::shared_ptr); -void RegisterCompileOnlyClient(pybind11::module& m); +void RegisterCompileOnlyClient(nanobind::module_& m); } // namespace xla diff --git a/xla/python/py_device.cc b/xla/python/py_device.cc new file mode 100644 index 0000000000000..69ede5ff4eec5 --- /dev/null +++ b/xla/python/py_device.cc @@ -0,0 +1,321 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/py_client.h" +#include "xla/python/py_memory_space.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/framework/allocator.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->id(); } + +int PyDevice::process_index() const { return device_->process_index(); } + +std::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return std::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +std::string_view PyDevice::device_kind() const { + return device_->device_kind(); +} + +std::optional PyDevice::local_hardware_id() const { + int local_hardware_id = device_->local_hardware_id(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +std::string_view PyDevice::Str() const { return device_->DebugString(); } + +std::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return device_->TransferToInfeed(literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(device_->TransferFromOutfeed(literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + std::string_view kind) const { + xla::PjRtMemorySpace* result_memory_space = nullptr; + for (auto* memory_space : device_->memory_spaces()) { + if (memory_space->memory_space_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->memory_spaces(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, memory_space->memory_space_kind()); + }); + auto device_kind = device_->device_kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = + absl::StrJoin(device_->memory_spaces(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, memory_space->memory_space_kind()); + }); + auto device_kind = device_->device_kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. " + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->default_memory_space()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->memory_spaces()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + absl::StatusOr maybe_stats = + device_->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + return device_->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto& attrs = device->device_->Attributes(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = + std::visit([](auto&& v) { return nb::cast(v); }, it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) { + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/xla/python/py_device.h b/xla/python/py_device.h new file mode 100644 index 0000000000000..e3e3d8fbeada3 --- /dev/null +++ b/xla/python/py_device.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PY_DEVICE_H_ +#define XLA_PYTHON_PY_DEVICE_H_ + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + std::string_view platform() const; + std::string_view device_kind() const; + std::optional local_hardware_id() const; + + std::string_view Str() const; + std::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + std::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device* device_; +}; + +} // namespace xla + +#endif // XLA_PYTHON_PY_DEVICE_H_ diff --git a/xla/python/py_device_list.cc b/xla/python/py_device_list.cc index 558a33302575c..a6d6a52b04bd1 100644 --- a/xla/python/py_device_list.cc +++ b/xla/python/py_device_list.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,77 +15,75 @@ limitations under the License. #include "xla/python/py_device_list.h" +#include + #include #include -#include #include #include #include #include #include "absl/hash/hash.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "pybind11/attr.h" // from @pybind11 -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/detail/common.h" // from @pybind11 -#include "pybind11/gil.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 // NOLINT +#include "nanobind/make_iterator.h" // from @nanobind +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep #include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/device.h" -#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" #include "xla/python/py_client.h" +#include "xla/python/py_device.h" #include "xla/python/python_ref_manager.h" #include "xla/python/sharding.h" -#include "xla/statusor.h" +#include "xla/python/types.h" #include "xla/util.h" namespace jax { -namespace py = ::pybind11; +namespace nb = ::nanobind; -PyDeviceList::PyDeviceList(std::shared_ptr py_client, +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, xla::ifrt::DeviceList device_list) : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} -PyDeviceList::PyDeviceList(py::tuple py_device_assignment) +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) : device_list_(py_device_assignment) { // Attempt to convert to Python devices into `ifrt::DeviceList`. - if (py_device_assignment.empty()) { + if (py_device_assignment.size() == 0) { device_list_ = xla::ifrt::DeviceList({}); return; } xla::ifrt::DeviceList::Devices devices; - devices.reserve(devices.size()); - for (py::handle obj : py_device_assignment) { - if (!py::isinstance(obj)) { - // Non-`xla::PjRtDevice` is used on an alternative JAX backend with device + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device // duck typing. Use Python device objects already set in `device_list_`. return; } - auto py_device = py::cast>(obj); - if (py_client_ == nullptr) { - py_client_ = py_device.client(); - } else if (py_device.client() != py_client_) { + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { // If the list contains multiple clients, fall back to device duck typing. return; } - devices.push_back(py_device.get()); + devices.push_back(py_device->device()); } device_list_ = xla::ifrt::DeviceList(std::move(devices)); } PyDeviceList::~PyDeviceList() { if (device_list_.index() == 1) { - py::object py_device_assignment = - py::cast(std::get<1>(std::move(device_list_))); xla::GlobalPyRefManager()->AddGarbage( - absl::MakeSpan(&py_device_assignment, 1)); + std::move(std::get<1>(std::move(device_list_)))); } } -xla::StatusOr PyDeviceList::ifrt_device_list() const { +absl::StatusOr PyDeviceList::ifrt_device_list() const { switch (device_list_.index()) { case 0: return std::get<0>(device_list_); @@ -103,109 +101,111 @@ int64_t PyDeviceList::Hash() { hash_ = absl::HashOf(std::get<0>(device_list_)); break; case 1: - hash_ = py::hash(std::get<1>(device_list_)); + hash_ = xla::nb_hash(std::get<1>(device_list_)); break; default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } return *hash_; } -bool PyDeviceList::operator==(py::handle other) { - if (!py::isinstance(other)) { +bool PyDeviceList::operator==(nb::handle other) { + if (!nb::isinstance(other)) { return false; } - auto o = py::cast>(other); + auto o = nb::cast(other); // Fast-path using a pointer equality check. - if (this == o.get()) { + if (this == o) { return true; } if (Hash() != o->Hash()) { return false; } if (device_list_.index() == 0 && o->device_list_.index() == 0) { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; return std::get<0>(device_list_) == std::get<0>(o->device_list_); } else { return AsTuple().equal(o->AsTuple()); } } -bool PyDeviceList::operator!=(py::handle other) { return !(*this == other); } +bool PyDeviceList::operator!=(nb::handle other) { return !(*this == other); } int PyDeviceList::Len() const { switch (device_list_.index()) { case 0: return std::get<0>(device_list_).size(); case 1: - return py::len(std::get<1>(device_list_)); + return nb::len(std::get<1>(device_list_)); default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } -py::object PyDeviceList::GetItem(int index) { +nb::object PyDeviceList::GetItem(int index) { switch (device_list_.index()) { case 0: { const xla::ifrt::DeviceList& device_list = std::get<0>(device_list_); if (index < -device_list.size() || index >= device_list.size()) { - throw py::index_error(); + throw nb::index_error(); } else if (index < 0) { index += device_list.size(); } - return py::cast(xla::WrapWithClient(py_client_, device_list[index])); + return py_client_->GetPyDevice(device_list[index]); } case 1: return std::get<1>(device_list_).attr("__getitem__")(index); default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } -py::object PyDeviceList::GetSlice(py::slice slice) { +nb::object PyDeviceList::GetSlice(nb::slice slice) { switch (device_list_.index()) { case 0: { const xla::ifrt::DeviceList& device_list = std::get<0>(device_list_); - size_t start, stop, step, slicelength; - if (!slice.compute(device_list.size(), &start, &stop, &step, - &slicelength)) { - throw py::error_already_set(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), device_list.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); } - std::vector> out; - out.reserve(slicelength); + nb::tuple out = nb::steal(PyTuple_New(slicelength)); for (size_t i = 0; i < slicelength; ++i) { - out.push_back(xla::WrapWithClient(py_client_, device_list[start])); + nb::object d = py_client_->GetPyDevice(device_list[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); start += step; } - return py::cast(out); + return std::move(out); } case 1: return std::get<1>(device_list_).attr("__getitem__")(slice); default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } -py::tuple PyDeviceList::AsTuple() { +nb::tuple PyDeviceList::AsTuple() const { switch (device_list_.index()) { case 0: { const xla::ifrt::DeviceList& device_list = std::get<0>(device_list_); - std::vector> out; - out.reserve(device_list.size()); + nb::tuple out = nb::steal(PyTuple_New(device_list.size())); + int i = 0; for (xla::ifrt::Device* device : device_list) { - out.push_back(xla::WrapWithClient(py_client_, device)); + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; } - return py::cast(out); + return out; } case 1: return std::get<1>(device_list_); default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } -py::iterator PyDeviceList::Iter() { +nb::iterator PyDeviceList::Iter() { switch (device_list_.index()) { case 0: { // Iterator whose deference converts `xla::ifrt::Device*` into JAX @@ -213,33 +213,32 @@ py::iterator PyDeviceList::Iter() { struct Iterator { void operator++() { ++it; } bool operator==(const Iterator& other) const { return it == other.it; } - xla::ClientAndPtr operator*() const { - return xla::WrapWithClient(py_client, *it); + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); } - const std::shared_ptr& py_client; + xla::nb_class_ptr py_client; xla::ifrt::DeviceList::Devices::const_iterator it; }; - return py::make_iterator( + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", Iterator{py_client_, std::get<0>(device_list_).begin()}, Iterator{py_client_, std::get<0>(device_list_).end()}); } case 1: - return py::make_iterator(std::get<1>(device_list_).begin(), - std::get<1>(device_list_).end()); + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } -std::string PyDeviceList::Str() { return py::str(AsTuple()); } - -py::tuple PyDeviceList::Dump() { return AsTuple(); } - -std::shared_ptr PyDeviceList::Load( - py::tuple py_device_assignment) { - return std::make_shared(std::move(py_device_assignment)); +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); } +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + bool PyDeviceList::IsFullyAddressable() { if (!is_fully_addressable_.has_value()) { is_fully_addressable_ = true; @@ -256,9 +255,9 @@ bool PyDeviceList::IsFullyAddressable() { break; } case 1: { - for (py::handle device : std::get<1>(device_list_)) { - if (py::cast(device.attr("process_index")) != - py::cast(device.attr("client").attr("process_index")())) { + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) != + nb::cast(device.attr("client").attr("process_index")())) { is_fully_addressable_ = false; break; } @@ -266,50 +265,55 @@ bool PyDeviceList::IsFullyAddressable() { break; } default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } return *is_fully_addressable_; } -std::shared_ptr PyDeviceList::AddressableDeviceList() { - if (IsFullyAddressable()) { +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + if (self->IsFullyAddressable()) { // Do not cache this result in `addressable_device_list_`. Otherwise, it // will create a cycle that prevents deletion of this object. - return shared_from_this(); + return self; } - if (!addressable_device_list_.has_value()) { - switch (device_list_.index()) { + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { case 0: { xla::ifrt::DeviceList::Devices addressable_devices; - const int process_index = py_client_ ? py_client_->process_index() : 0; - for (xla::ifrt::Device* device : std::get<0>(device_list_).devices()) { + const int process_index = + self->py_client_ ? self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_).devices()) { if (device->process_index() == process_index) { addressable_devices.push_back(device); } } - addressable_device_list_ = std::make_shared( - py_client_, xla::ifrt::DeviceList(std::move(addressable_devices))); + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, + xla::ifrt::DeviceList(std::move(addressable_devices))); break; } case 1: { - std::vector addressable_py_device_assignment; - for (py::handle device : std::get<1>(device_list_)) { - if (py::cast(device.attr("process_index")) == - py::cast(device.attr("client").attr("process_index")())) { - addressable_py_device_assignment.push_back( - py::cast(device)); + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); } } - addressable_device_list_ = std::make_shared( - py::cast(std::move(addressable_py_device_assignment))); + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); break; } default: - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } } - return *addressable_device_list_; + return *self->addressable_device_list_; } void PyDeviceList::PopulateMemoryKindInfo() { @@ -319,11 +323,11 @@ void PyDeviceList::PopulateMemoryKindInfo() { return; } if (device_list_.index() != 0) { - throw py::value_error("Unrecognized DeviceList type"); + throw nb::value_error("Unrecognized DeviceList type"); } MemoryKindInfo info; if (!GetEnableMemories()) { - info.default_memory_kind = py::none(); + info.default_memory_kind = nb::none(); memory_kind_info_ = std::move(info); return; } @@ -336,7 +340,7 @@ void PyDeviceList::PopulateMemoryKindInfo() { } } if (addressable_device == nullptr) { - info.default_memory_kind = py::none(); + info.default_memory_kind = nb::none(); memory_kind_info_ = std::move(info); return; } @@ -348,50 +352,54 @@ void PyDeviceList::PopulateMemoryKindInfo() { return; } info.default_memory_kind = - py::cast(std::string((*default_memory)->memory_space_kind())); - std::vector memory_kinds; - memory_kinds.reserve(addressable_device->memory_spaces().size()); - for (xla::ifrt::Memory* memory : addressable_device->memory_spaces()) { - memory_kinds.push_back(std::string(memory->memory_space_kind())); + nb::cast(std::string((*default_memory)->memory_space_kind())); + nb::tuple memory_kinds = nb::steal( + PyTuple_New(addressable_device->memory_spaces().size())); + for (size_t i = 0; i < addressable_device->memory_spaces().size(); ++i) { + auto* memory = addressable_device->memory_spaces()[i]; + nb::str s = nb::str(memory->memory_space_kind().data(), + memory->memory_space_kind().size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); } - info.memory_kinds = py::cast(memory_kinds); + info.memory_kinds = std::move(memory_kinds); memory_kind_info_ = std::move(info); } void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { MemoryKindInfo info; if (!GetEnableMemories()) { - info.default_memory_kind = py::none(); + info.default_memory_kind = nb::none(); // info.memory_kinds is default-initialized to an empty tuple. memory_kind_info_ = std::move(info); return; } try { - py::handle addressable_device; - for (py::handle device : std::get<1>(device_list_)) { - if (py::cast(device.attr("process_index")) == - py::cast(device.attr("client").attr("process_index")())) { + nb::handle addressable_device; + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { addressable_device = device; break; } } if (!addressable_device) { - info.default_memory_kind = py::none(); + info.default_memory_kind = nb::none(); // info.memory_kinds is default-initialized to an empty tuple. memory_kind_info_ = std::move(info); return; } auto default_memory = addressable_device.attr("default_memory")(); info.default_memory_kind = default_memory.attr("kind"); - info.memory_kinds = addressable_device.attr("addressable_memories")(); + info.memory_kinds = nb::tuple( + nb::object(addressable_device.attr("addressable_memories")())); memory_kind_info_ = std::move(info); - } catch (py::error_already_set& e) { + } catch (nb::python_error& e) { // Cache the error. memory_kind_info_ = xla::InvalidArgument("%s", e.what()); } } -xla::StatusOr PyDeviceList::MemoryKinds() { +absl::StatusOr PyDeviceList::MemoryKinds() { if (!memory_kind_info_.has_value()) { PopulateMemoryKindInfo(); } @@ -401,7 +409,7 @@ xla::StatusOr PyDeviceList::MemoryKinds() { return (*memory_kind_info_)->memory_kinds; } -xla::StatusOr PyDeviceList::DefaultMemoryKind() { +absl::StatusOr PyDeviceList::DefaultMemoryKind() { if (!memory_kind_info_.has_value()) { PopulateMemoryKindInfo(); } @@ -411,38 +419,40 @@ xla::StatusOr PyDeviceList::DefaultMemoryKind() { return (*memory_kind_info_)->default_memory_kind; } -void RegisterDeviceList(py::module& m) { - py::class_>(m, "DeviceList") - .def(py::init()) +void RegisterDeviceList(nb::module_& m) { + nb::class_(m, "DeviceList") + .def(nb::init()) .def("__hash__", &PyDeviceList::Hash) .def("__eq__", &PyDeviceList::operator==) .def("__ne__", &PyDeviceList::operator!=) .def("__len__", &PyDeviceList::Len) .def("__getitem__", &PyDeviceList::GetItem) .def("__getitem__", &PyDeviceList::GetSlice) - .def("__iter__", &PyDeviceList::Iter, py::keep_alive<0, 1>()) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) .def("__str__", &PyDeviceList::Str) - .def(py::pickle([](PyDeviceList* l) { return l->Dump(); }, - [](py::tuple t) { return PyDeviceList::Load(t); })) - .def_property_readonly("is_fully_addressable", - &PyDeviceList::IsFullyAddressable) - .def_property_readonly("addressable_device_list", - &PyDeviceList::AddressableDeviceList) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList& self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) // `xla::ValueOrThrowWrapper` does not work with - // `def_property_readonly()`. Manually convert an error into an exception. - .def_property_readonly( - "default_memory_kind", - [](PyDeviceList* l) { - auto kind = l->DefaultMemoryKind(); - if (!kind.ok()) { - throw py::value_error(kind.status().ToString()); - } - return *kind; - }) - .def_property_readonly("memory_kinds", [](PyDeviceList* l) { + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro("default_memory_kind", + [](PyDeviceList* l) { + auto kind = l->DefaultMemoryKind(); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }) + .def_prop_ro("memory_kinds", [](PyDeviceList* l) { auto kinds = l->MemoryKinds(); if (!kinds.ok()) { - throw py::value_error(kinds.status().ToString()); + throw nb::value_error(kinds.status().ToString().c_str()); } return *kinds; }); diff --git a/xla/python/py_device_list.h b/xla/python/py_device_list.h index 1484aaf2c9dc0..910f081f37119 100644 --- a/xla/python/py_device_list.h +++ b/xla/python/py_device_list.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,20 +22,20 @@ limitations under the License. #include #include -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind #include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" -#include "xla/statusor.h" namespace jax { // Device list with various caching and direct access to IFRT DeviceList. -class PyDeviceList : public std::enable_shared_from_this { +class PyDeviceList { public: - PyDeviceList(std::shared_ptr py_client, + PyDeviceList(xla::nb_class_ptr py_client, xla::ifrt::DeviceList device_list); - explicit PyDeviceList(pybind11::tuple py_device_assignment); + explicit PyDeviceList(nanobind::tuple py_device_assignment); ~PyDeviceList(); PyDeviceList(const PyDeviceList&) = delete; @@ -44,32 +44,31 @@ class PyDeviceList : public std::enable_shared_from_this { PyDeviceList& operator=(PyDeviceList&&) = delete; // These two methods are safe to call from C++ without GIL. - std::shared_ptr py_client() const { return py_client_; } - xla::StatusOr ifrt_device_list() const; + xla::nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; // Methods below require GIL. int64_t Hash(); - bool operator==(pybind11::handle other); - bool operator!=(pybind11::handle other); + bool operator==(nanobind::handle other); + bool operator!=(nanobind::handle other); int Len() const; - pybind11::object GetItem(int index); - pybind11::object GetSlice(pybind11::slice slice); - pybind11::iterator Iter(); + nanobind::object GetItem(int index); + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); std::string Str(); - pybind11::tuple Dump(); - static std::shared_ptr Load( - pybind11::tuple py_device_assignment); + nanobind::tuple Dump() const; bool IsFullyAddressable(); - std::shared_ptr AddressableDeviceList(); - xla::StatusOr DefaultMemoryKind(); - xla::StatusOr MemoryKinds(); + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + absl::StatusOr DefaultMemoryKind(); + absl::StatusOr MemoryKinds(); private: - pybind11::tuple AsTuple(); + nanobind::tuple AsTuple() const; // Finds the memory kind info from an addressable device. void PopulateMemoryKindInfo(); @@ -79,26 +78,26 @@ class PyDeviceList : public std::enable_shared_from_this { // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and // non-empty. - std::shared_ptr py_client_; + xla::nb_class_ptr py_client_; // Either C++ `ifrt::DeviceList` or Python duck-type devices. // TODO(hyeontaek): Remove support for Python duck-type devices once all // JAX backends and tests are migrated to use an `xla::ifrt::Device` type // for JAX devices. - std::variant device_list_; + std::variant device_list_; std::optional hash_; // Populated on demand. // TODO(hyeontaek): Make the following property cached within // `xla::ifrt::DeviceList`. std::optional is_fully_addressable_; // Populated on demand. - std::optional> + std::optional> addressable_device_list_; // Populated on demand. struct MemoryKindInfo { - pybind11::object default_memory_kind; - pybind11::tuple memory_kinds; + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; }; - std::optional> + std::optional> memory_kind_info_; // Populated on demand. }; @@ -108,7 +107,7 @@ class PyDeviceList : public std::enable_shared_from_this { // module_arg {} // } // go/pywald-pybind-annotation END -void RegisterDeviceList(pybind11::module& m); +void RegisterDeviceList(nanobind::module_& m); } // namespace jax diff --git a/xla/python/py_executable.cc b/xla/python/py_executable.cc index 786872913beec..65226d4e8a13c 100644 --- a/xla/python/py_executable.cc +++ b/xla/python/py_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,36 +15,59 @@ limitations under the License. #include "xla/python/py_executable.h" +#include + +#include #include #include #include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" -#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/py_device.h" +#include "xla/python/traceback.h" +#include "tsl/concurrency/ref_count.h" #include "tsl/platform/fingerprint.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; -Status PyToken::Await() { +absl::Status PyToken::Await() { CHECK(future_.IsValid()); - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; return future_.Await(); } -Status PyShardedToken::Await() { - py::gil_scoped_release gil_release; - Status status = OkStatus(); +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); for (auto& future : futures_) { auto s = future.Await(); if (!s.ok()) status = std::move(s); @@ -53,9 +76,9 @@ Status PyShardedToken::Await() { } PyLoadedExecutable::PyLoadedExecutable( - std::shared_ptr client, + nb_class_ptr client, std::unique_ptr ifrt_loaded_executable, - std::shared_ptr traceback, + std::optional traceback, std::optional fingerprint) : client_(std::move(client)), ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), @@ -90,12 +113,12 @@ PyLoadedExecutable::~PyLoadedExecutable() { } } -std::vector> PyLoadedExecutable::AddressableDevices() +std::vector> PyLoadedExecutable::AddressableDevices() const { - std::vector> devices; + std::vector> devices; devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { - devices.push_back(WrapWithClient(client_, device)); + devices.push_back(client_->GetPyDevice(device)); } return devices; } @@ -110,7 +133,6 @@ template <> struct ShardedBufferAdapter { static int num_devices(const ExecuteShardedArg& arg) { if (std::holds_alternative(arg)) { - CHECK(std::get(arg).fastpath_enabled()); return std::get(arg).num_addressable_shards(); } else { return std::get>(arg).size(); @@ -119,7 +141,6 @@ struct ShardedBufferAdapter { static tsl::RCReference GetIfRtArray( const ExecuteShardedArg& arg) { if (std::holds_alternative(arg)) { - CHECK(std::get(arg).fastpath_enabled()); return tsl::FormRef(std::get(arg).ifrt_array()); } auto& arg_vector = std::get>(arg); @@ -154,9 +175,10 @@ struct ShardedBufferAdapter { }; void PopulateExecuteShardedResults( - const std::shared_ptr& client, + const nb_class_ptr& client, std::vector> ifrt_arrays, - int num_computations, std::vector>& outputs) { + const xla::PjRtFuture& result_status, int num_computations, + std::vector>& outputs) { auto traceback = Traceback::Get(); DCHECK_GT(num_computations, 0); int num_output_buffers = ifrt_arrays.size(); @@ -169,21 +191,24 @@ void PopulateExecuteShardedResults( TF_CHECK_OK(exploded_arrays.status()); for (auto& exploded_array : *exploded_arrays) { outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( - client, traceback, std::move(exploded_array), false, true)); + client, traceback, std::move(exploded_array), false, true, + result_status)); } } } template > -StatusOr ExecuteShardedOnLocalDevicesInternal( - const ExecuteOptions& options, const std::shared_ptr& client, +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ExecuteOptions& options, const nb_class_ptr& client, ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, - std::optional>>& returned_futures) { + std::optional>>& returned_futures, + bool attach_status_to_results) { std::vector> output_arrays; - std::unique_ptr> returned_future; + std::unique_ptr> returned_future; int num_computations = ifrt_loaded_executable->addressable_devices().size(); + xla::PjRtFuture result_status; { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; for (const auto& arg : args) { if (ArgAdapter::num_devices(arg) != num_computations) { return xla::InvalidArgument( @@ -203,6 +228,11 @@ StatusOr ExecuteShardedOnLocalDevicesInternal( absl::MakeSpan(arg_arrays), options, /*devices=*/std::nullopt)); output_arrays = std::move(result.outputs); + // attach_status_to_results is only supposed to be true when the computation + // has tokens. + if (attach_status_to_results) { + result_status = result.status; + } if (returned_futures.has_value()) { returned_futures->resize(num_computations, std::move(result.status)); } @@ -217,23 +247,25 @@ StatusOr ExecuteShardedOnLocalDevicesInternal( : PyShardedToken(); return PyExecuteResults(client, std::move(output_arrays), num_computations, - std::move(py_sharded_token)); + std::move(py_sharded_token), result_status); } } // namespace PyExecuteResults::PyExecuteResults( - const std::shared_ptr& client, + const nb_class_ptr& client, std::vector> ifrt_arrays, - int num_computations, PyShardedToken token) + int num_computations, PyShardedToken token, + xla::PjRtFuture result_status) : client_(client), ifrt_arrays_(std::move(ifrt_arrays)), num_computations_(num_computations), - token_(std::move(token)) {} + token_(std::move(token)), + result_status_(std::move(result_status)) {} void PyExecuteResults::CheckNotDisassembled() const { if (is_exploded_) { - throw py::value_error("ExecuteResults already exploded."); + throw nb::value_error("ExecuteResults already exploded."); } } @@ -245,7 +277,7 @@ std::vector> PyExecuteResults::Consume() { PyShardedToken PyExecuteResults::ConsumeToken() { if (token_consumed_) { - throw py::value_error("ExecuteResults token already consumed."); + throw nb::value_error("ExecuteResults token already consumed."); } token_consumed_ = true; return std::move(token_); @@ -254,7 +286,8 @@ PyShardedToken PyExecuteResults::ConsumeToken() { std::vector> PyExecuteResults::DisassembleIntoSingleDeviceArrays() { std::vector> outputs; - PopulateExecuteShardedResults(client_, Consume(), num_computations_, outputs); + PopulateExecuteShardedResults(client_, Consume(), result_status_, + num_computations_, outputs); return outputs; } @@ -262,9 +295,10 @@ std::vector> PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { CheckNotDisassembled(); if (n > ifrt_arrays_.size()) { - throw py::value_error( + throw nb::value_error( absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", - ifrt_arrays_.size())); + ifrt_arrays_.size()) + .c_str()); } std::vector> ifrt_arrays; ifrt_arrays.reserve(ifrt_arrays_.size() - n); @@ -274,113 +308,128 @@ PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); std::swap(ifrt_arrays_, ifrt_arrays); std::vector> outputs; - PopulateExecuteShardedResults(client_, std::move(ifrt_arrays), + PopulateExecuteShardedResults(client_, std::move(ifrt_arrays), result_status_, num_computations_, outputs); return outputs; } -std::vector PyExecuteResults::ConsumeWithHandlers( - std::vector> +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> out_handlers) { - std::vector outputs; + std::vector outputs; auto ifrt_arrays = Consume(); auto traceback = Traceback::Get(); DCHECK_GT(num_computations_, 0); int num_output_buffers = ifrt_arrays.size(); outputs.reserve(num_output_buffers); if (out_handlers.size() != num_output_buffers) { - throw py::value_error(absl::StrCat( - "Mismatch between out_handlers and num_results: ", out_handlers.size(), - " vs ", num_output_buffers)); + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); } for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { auto& handler = out_handlers[buffer_id]; if (std::holds_alternative(handler)) { outputs.push_back(std::get(handler)->Call( - client_, std::move(ifrt_arrays[buffer_id]))); + client_, std::move(ifrt_arrays[buffer_id]), result_status_)); } else { tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); - std::vector bufs; - bufs.reserve(num_computations_); auto disassembled_arrays = ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( ifrt::ArrayCopySemantics::kReuseInput); TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; for (auto& disassembled_array : *disassembled_arrays) { - bufs.push_back(PyArray::MakeFromSingleDeviceArray( - client_, traceback, std::move(disassembled_array), false, true)); + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; } - outputs.push_back(std::get(handler)(std::move(bufs))); + outputs.push_back(std::get(handler)(std::move(bufs))); } } return outputs; } -StatusOr>> +absl::StatusOr>> PyLoadedExecutable::ExecuteShardedOnLocalDevices( absl::Span args) { - std::optional>> returned_futures; - TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options_, client_, ifrt_loaded_executable_.get(), - args, returned_futures)); + std::optional>> returned_futures; + TF_ASSIGN_OR_RETURN( + auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options_, client_, ifrt_loaded_executable_.get(), args, + returned_futures, /*attach_status_to_results=*/false)); return outputs_and_tokens.DisassembleIntoSingleDeviceArrays(); } -StatusOr>, PyShardedToken>> +absl::StatusOr>, PyShardedToken>> PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens( absl::Span args) { - std::optional>> returned_futures; + std::optional>> returned_futures; returned_futures.emplace(); - TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options_, client_, ifrt_loaded_executable_.get(), - args, returned_futures)); + TF_ASSIGN_OR_RETURN( + auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options_, client_, ifrt_loaded_executable_.get(), args, + returned_futures, /*attach_status_to_results=*/true)); return std::make_pair(outputs_and_tokens.DisassembleIntoSingleDeviceArrays(), outputs_and_tokens.ConsumeToken()); } -StatusOr PyLoadedExecutable::ExecuteSharded( +absl::StatusOr PyLoadedExecutable::ExecuteSharded( std::vector args, bool with_tokens) { - std::optional>> returned_futures; + std::optional>> returned_futures; if (with_tokens) { returned_futures.emplace(); } absl::Span span_args = args; - return ExecuteShardedOnLocalDevicesInternal(options_, client_, - ifrt_loaded_executable_.get(), - span_args, returned_futures); + return ExecuteShardedOnLocalDevicesInternal( + options_, client_, ifrt_loaded_executable_.get(), span_args, + returned_futures, /*attach_status_to_results=*/with_tokens); } -StatusOr>> +absl::StatusOr>> PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetHloModules(); } -StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputMemoryKinds(); } -StatusOr> PyLoadedExecutable::GetParameterLayouts() const { +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetParameterLayouts(); } -StatusOr> PyLoadedExecutable::GetOutputLayouts() const { +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputLayouts(); } std::optional> PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetParameterShardings(); } std::optional> PyLoadedExecutable::GetOutputShardings() const { + nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputShardings(); } -void PyLoadedExecutable::KeepAlive(py::object obj) { +void PyLoadedExecutable::KeepAlive(nb::object obj) { keepalives_.push_back(std::move(obj)); } diff --git a/xla/python/py_executable.h b/xla/python/py_executable.h index 8e722c908235b..e3c4428fb9246 100644 --- a/xla/python/py_executable.h +++ b/xla/python/py_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,38 +16,55 @@ limitations under the License. #ifndef XLA_PYTHON_PY_EXECUTABLE_H_ #define XLA_PYTHON_PY_EXECUTABLE_H_ +#include +#include #include #include #include +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "pybind11/gil.h" // from @pybind11 +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/py_array.h" #include "xla/python/py_client.h" #include "xla/python/traceback.h" -#include "xla/statusor.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/status.h" namespace xla { class PyToken { public: PyToken() = default; - explicit PyToken(PjRtFuture future) : future_(std::move(future)) {} + explicit PyToken(PjRtFuture future) + : future_(std::move(future)) {} static PyToken ReadyPyToken() { - return PyToken(PjRtFuture(OkStatus())); + return PyToken(PjRtFuture(absl::OkStatus())); } - Status Await(); + absl::Status Await(); private: - PjRtFuture future_; + PjRtFuture future_; }; // PyShardedToken contains a PyToken for each device's execution. @@ -55,7 +72,7 @@ class PyShardedToken { public: // Default construction creates a always-ready token. PyShardedToken() = default; - explicit PyShardedToken(std::vector> futures) + explicit PyShardedToken(std::vector> futures) : futures_(std::move(futures)) {} PyToken GetPyToken(int device_id) const { @@ -63,25 +80,27 @@ class PyShardedToken { return PyToken(futures_.at(device_id)); } - Status Await(); + absl::Status Await(); private: - std::vector> futures_; + std::vector> futures_; }; class PyExecuteResults { public: - PyExecuteResults(const std::shared_ptr& client, + PyExecuteResults(const nb_class_ptr& client, std::vector> ifrt_arrays, - int num_computations, PyShardedToken token); + int num_computations, PyShardedToken token, + xla::PjRtFuture result_status = + xla::PjRtFuture()); std::vector> DisassembleIntoSingleDeviceArrays(); std::vector> DisassemblePrefixIntoSingleDeviceArrays( size_t n); - std::vector ConsumeWithHandlers( - std::vector> + std::vector ConsumeWithHandlers( + std::vector> out_handlers); std::vector> Consume(); @@ -98,10 +117,12 @@ class PyExecuteResults { private: bool is_exploded_ = false; bool token_consumed_ = false; - std::shared_ptr client_; + nb_class_ptr client_; std::vector> ifrt_arrays_; int num_computations_; PyShardedToken token_; + // Only set if the computation has tokens. + xla::PjRtFuture result_status_; }; using ExecuteShardedArg = std::variant>; @@ -109,17 +130,16 @@ using ExecuteShardedArg = std::variant>; // Python wrapper around PjRtExecutable. We use a wrapper class: // a) to keep the PyClient alive via a std::shared_ptr<> // b) to add Python-specific functionality. -class PyLoadedExecutable - : public std::enable_shared_from_this { +class PyLoadedExecutable { public: PyLoadedExecutable( - std::shared_ptr client, + nb_class_ptr client, std::unique_ptr ifrt_loaded_executable, - std::shared_ptr traceback, + std::optional traceback, std::optional fingerprint); ~PyLoadedExecutable(); - std::shared_ptr client() const { return client_; } + nb_class_ptr client() const { return client_; } ifrt::LoadedExecutable* ifrt_loaded_executable() const { return ifrt_loaded_executable_.get(); } @@ -129,24 +149,24 @@ class PyLoadedExecutable return ifrt_loaded_executable_->addressable_device_logical_ids(); } - std::vector> AddressableDevices() const; + std::vector> AddressableDevices() const; int64_t SizeOfGeneratedCodeInBytes() const { return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); } - StatusOr GetCompiledMemoryStats() const { - pybind11::gil_scoped_release scope; + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; return ifrt_loaded_executable_->GetCompiledMemoryStats(); } - StatusOr> GetCostAnalysis() - const { + absl::StatusOr> + GetCostAnalysis() const { return ifrt_loaded_executable_->GetCostAnalysis(); } void Delete() { - // TODO(hyeontaek): Return Status. + // TODO(hyeontaek): Return absl::Status. TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); } @@ -156,30 +176,32 @@ class PyLoadedExecutable // PjRtExecutable::Execute. The result is similarly transposed back into the // argid,deviceid format. // args is [num_args x num_devices]. - StatusOr>> ExecuteShardedOnLocalDevices( - absl::Span args); + absl::StatusOr>> + ExecuteShardedOnLocalDevices(absl::Span args); - StatusOr>, PyShardedToken>> + absl::StatusOr>, PyShardedToken>> ExecuteShardedOnLocalDevicesWithTokens( absl::Span args); - StatusOr ExecuteSharded(std::vector args, - bool with_tokens); + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); - StatusOr>> HloModules() const; + absl::StatusOr>> HloModules() const; - StatusOr>> GetOutputMemoryKinds() - const; + absl::StatusOr>> + GetOutputMemoryKinds() const; - StatusOr> GetParameterLayouts() const; + absl::StatusOr>> GetParameterLayouts() + const; - StatusOr> GetOutputLayouts() const; + absl::StatusOr>> GetOutputLayouts() + const; std::optional> GetParameterShardings() const; std::optional> GetOutputShardings() const; - Traceback* traceback() { return traceback_.get(); } + const std::optional& traceback() { return traceback_; } ifrt::LoadedExecutable* ifrt_executable() const { return ifrt_loaded_executable_.get(); @@ -210,14 +232,14 @@ class PyLoadedExecutable const std::optional& fingerprint() const { return fingerprint_; } // Keep `obj` alive as long as PyLoadedExecutable. - void KeepAlive(pybind11::object obj); + void KeepAlive(nanobind::object obj); private: friend class PyClient; - std::shared_ptr client_; + nb_class_ptr client_; std::unique_ptr ifrt_loaded_executable_; - std::shared_ptr traceback_; + std::optional traceback_; // Identical executables (i.e. representing the same program) will have the // same fingerprint. nullopt on platforms or executables where fingerprints @@ -228,7 +250,7 @@ class PyLoadedExecutable ExecuteOptions options_; // Python objects to keep alive as requested by user. - std::vector keepalives_; + std::vector keepalives_; // Doubly-linked list of all executables known to the client. Protected by the // GIL. diff --git a/xla/python/py_host_callback.cc b/xla/python/py_host_callback.cc index 6905540f9307e..3088a9b408f94 100644 --- a/xla/python/py_host_callback.cc +++ b/xla/python/py_host_callback.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "xla/python/py_host_callback.h" +#include #include #include #include @@ -22,9 +23,12 @@ limitations under the License. #include #include "google/protobuf/any.pb.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "pybind11/gil.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind #include "xla/layout_util.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_compiler.h" @@ -36,9 +40,10 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/util.h" +namespace nb = nanobind; + namespace xla { char PyCpuLoadedHostCallback::ID = 0; @@ -46,7 +51,7 @@ char PyHostSendAndRecvLoadedHostCallback::ID = 0; namespace { -StatusOr> CreateCallbackArgs( +absl::StatusOr> CreateCallbackArgs( absl::Span operand_shapes) { std::vector callback_args(operand_shapes.size()); for (int i = 0; i < operand_shapes.size(); ++i) { @@ -62,7 +67,7 @@ StatusOr> CreateCallbackArgs( callback_args[i].type = shape.element_type(); callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout); TF_ASSIGN_OR_RETURN(callback_args[i].dtype, - PrimitiveTypeToDtype(shape.element_type())); + PrimitiveTypeToNbDtype(shape.element_type())); } else if (shape.IsToken()) { callback_args[i].type = TOKEN; } else { @@ -75,7 +80,7 @@ StatusOr> CreateCallbackArgs( return callback_args; } -StatusOr> CreateCallbackResults( +absl::StatusOr> CreateCallbackResults( absl::Span result_shapes) { std::vector callback_results(result_shapes.size()); for (int i = 0; i < result_shapes.size(); ++i) { @@ -107,9 +112,9 @@ StatusOr> CreateCallbackResults( } // namespace -StatusOr> +absl::StatusOr> PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, - pybind11::function callable, + nb::callable callable, absl::Span operand_shapes, absl::Span result_shapes) { ifrt::PlatformId platform_id = ifrt_client->platform_id(); @@ -131,19 +136,18 @@ PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, std::move(cpu_callback))); } -StatusOr PyCpuLoadedHostCallback::Serialize() const { +absl::StatusOr PyCpuLoadedHostCallback::Serialize() const { return Unimplemented( "PyHostSendAndRecvLoadedHostCallback serialization is not supported"); } -StatusOr> +absl::StatusOr> PyHostSendAndRecvLoadedHostCallback::Create( - ifrt::Client* ifrt_client, pybind11::function callable, + ifrt::Client* ifrt_client, nb::callable callable, absl::Span operand_shapes, absl::Span result_shapes, absl::Span send_channel_ids, - absl::Span recv_channel_ids, - pybind11::function serializer) { + absl::Span recv_channel_ids, nb::callable serializer) { TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); TF_ASSIGN_OR_RETURN(auto callback_results, CreateCallbackResults(result_shapes)); @@ -188,11 +192,11 @@ PyHostSendAndRecvLoadedHostCallback::Create( PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( ifrt::Client* ifrt_client, - std::unique_ptr xla_host_callback, - pybind11::function callable, absl::Span operand_shapes, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, absl::Span result_shapes, absl::Span send_channel_ids, - absl::Span recv_channel_ids, pybind11::function serializer) + absl::Span recv_channel_ids, nb::callable serializer) : llvm::RTTIExtends( ifrt_client, std::move(xla_host_callback)), @@ -205,12 +209,13 @@ PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { GlobalPyRefManager()->AddGarbage( - absl::MakeSpan(static_cast(&callable_), 1)); + absl::MakeSpan(static_cast(&callable_), 1)); GlobalPyRefManager()->AddGarbage( - absl::MakeSpan(static_cast(&serializer_), 1)); + absl::MakeSpan(static_cast(&serializer_), 1)); } -StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() const { +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { if (serializer_.is_none()) { return InvalidArgument( "Host callback cannot be serialized because serializer was not " @@ -236,10 +241,11 @@ StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() const { std::string callable; { - pybind11::gil_scoped_acquire gil_acquire; + nb::gil_scoped_acquire gil_acquire; try { - callable = pybind11::cast(serializer_(callable_)); - } catch (const pybind11::error_already_set& e) { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error& e) { return absl::InternalError(absl::StrCat( "Unable to pickle the host_callback callable: ", e.what())); } catch (const std::exception& e) { diff --git a/xla/python/py_host_callback.h b/xla/python/py_host_callback.h index c9518a4553896..6678361b22625 100644 --- a/xla/python/py_host_callback.h +++ b/xla/python/py_host_callback.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,18 +16,24 @@ limitations under the License. #ifndef XLA_PYTHON_PY_HOST_CALLBACK_H_ #define XLA_PYTHON_PY_HOST_CALLBACK_H_ +#include #include #include #include #include +#include "absl/base/casts.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/pjrt/host_callback.h" #include "xla/python/callback.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" #include "xla/shape.h" +#include "tsl/concurrency/ref_count.h" namespace xla { @@ -44,8 +50,8 @@ class PyCpuLoadedHostCallback final : public llvm::RTTIExtends { public: - static StatusOr> Create( - ifrt::Client* ifrt_client, pybind11::function callable, + static absl::StatusOr> Create( + ifrt::Client* ifrt_client, nanobind::callable callable, absl::Span operand_shapes, absl::Span result_shapes); @@ -60,7 +66,7 @@ class PyCpuLoadedHostCallback final ifrt::Client* client() const override { return ifrt_client_; } - StatusOr Serialize() const override; + absl::StatusOr Serialize() const override; static char ID; // NOLINT @@ -89,19 +95,19 @@ class PyHostSendAndRecvLoadedHostCallback final : public llvm::RTTIExtends { public: - static StatusOr> Create( - ifrt::Client* ifrt_client, pybind11::function callable, - absl::Span operand_shapes, - absl::Span result_shapes, - absl::Span send_channel_ids, - absl::Span recv_channel_ids, - pybind11::function serializer); + static absl::StatusOr> + Create(ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); // PjRtLoadedHostCallback implementation. ~PyHostSendAndRecvLoadedHostCallback() override; - StatusOr Serialize() const override; + absl::StatusOr Serialize() const override; static char ID; // NOLINT @@ -109,22 +115,22 @@ class PyHostSendAndRecvLoadedHostCallback final PyHostSendAndRecvLoadedHostCallback( ifrt::Client* ifrt_client, std::unique_ptr xla_host_callback, - pybind11::function callable, absl::Span operand_shapes, + nanobind::callable callable, absl::Span operand_shapes, absl::Span result_shapes, absl::Span send_channel_ids, absl::Span recv_channel_ids, - pybind11::function serializer); + nanobind::callable serializer); template friend tsl::RCReference tsl::MakeRef(Args&&... args); // Retained arguments for host callback serialization. - pybind11::function callable_; + nanobind::callable callable_; std::vector operand_shapes_; std::vector result_shapes_; std::vector send_channel_ids_; std::vector recv_channel_ids_; - pybind11::function serializer_; + nanobind::callable serializer_; }; } // namespace xla diff --git a/xla/python/py_host_callback.proto b/xla/python/py_host_callback.proto index 642ddf377ac74..f91e122af0571 100644 --- a/xla/python/py_host_callback.proto +++ b/xla/python/py_host_callback.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/py_memory_space.cc b/xla/python/py_memory_space.cc new file mode 100644 index 0000000000000..e60b9d2fa6514 --- /dev/null +++ b/xla/python/py_memory_space.cc @@ -0,0 +1,107 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/py_memory_space.h" + +#include + +#include +#include + +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + PjRtMemorySpace* memory_space) + : client_(std::move(client)), memory_space_(memory_space) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +std::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return std::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +std::string_view PyMemorySpace::kind() const { + return memory_space_->memory_space_kind(); +} + +std::string_view PyMemorySpace::Str() const { + return memory_space_->DebugString(); +} + +std::string_view PyMemorySpace::Repr() const { + return memory_space_->ToString(); +} + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_space_->devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/xla/python/py_memory_space.h b/xla/python/py_memory_space.h new file mode 100644 index 0000000000000..9bcea77d43ad2 --- /dev/null +++ b/xla/python/py_memory_space.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PY_MEMORY_SPACE_H_ +#define XLA_PYTHON_PY_MEMORY_SPACE_H_ + +#include + +#include + +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, PjRtMemorySpace* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. + PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + PjRtMemorySpace* memory_space() const { return memory_space_; } + + int process_index() const; + std::string_view platform() const; + std::string_view kind() const; + + std::string_view Str() const; + std::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + PjRtMemorySpace* memory_space_; +}; + +} // namespace xla + +#endif // XLA_PYTHON_PY_MEMORY_SPACE_H_ diff --git a/xla/python/py_program.cc b/xla/python/py_program.cc new file mode 100644 index 0000000000000..4cedf9d2b6b9b --- /dev/null +++ b/xla/python/py_program.cc @@ -0,0 +1,126 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/py_program.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeXlaProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeXlaProgramFromString( + std::string mlir_module) { + return MakeXlaProgram(mlir_module); +} + +absl::StatusOr> MakeXlaProgramFromBytes( + nb::bytes mlir_module) { + return MakeXlaProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_xla_program", + xla::ValueOrThrowWrapper(MakeXlaProgramFromString)) + .def("make_xla_program", + xla::ValueOrThrowWrapper(MakeXlaProgramFromBytes)) + .def("make_plugin_program", + xla::ValueOrThrowWrapper(MakePluginProgramFromString)) + .def("make_plugin_program", + xla::ValueOrThrowWrapper(MakePluginProgramFromBytes)) + .def("make_xla_compile_options", + xla::ValueOrThrowWrapper(MakeXlaCompileOptions)) + .def("make_plugin_compile_options", + xla::ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/xla/python/py_program.h b/xla/python/py_program.h new file mode 100644 index 0000000000000..2c9c2732db4f6 --- /dev/null +++ b/xla/python/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PY_PROGRAM_H_ +#define XLA_PYTHON_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" // from @nanobind + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // XLA_PYTHON_PY_PROGRAM_H_ diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 85de0aa4a588d..206c87521a46a 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,65 +12,77 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Must be included first -// clang-format off -#include "xla/python/ifrt/memory.h" -#include "tsl/python/lib/core/numpy.h" //NOLINT -// clang-format on #include "xla/python/py_values.h" -// NOLINTBEGIN +#include + +#include #include #include #include #include +#include +#include #include -// NOLINTEND -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/complex.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/py_array.h" -#include "xla/python/py_buffer.h" #include "xla/python/python_ref_manager.h" #include "xla/python/sharding.h" #include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/float8.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" -namespace py = pybind11; +namespace nb = nanobind; namespace xla { namespace { -using DevicePutFunc = std::function( - py::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, +using DevicePutFunc = std::function( + nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind)>; template -StatusOr HandlePythonScalar(py::handle obj, - ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandlePythonScalar( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { T data; try { - data = py::cast(obj); + data = nb::cast(obj); } catch (const std::exception& e) { return InvalidArgument( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - py::repr(obj)); + nb::cast(nb::repr(obj))); } void* ptr; @@ -89,7 +101,7 @@ StatusOr HandlePythonScalar(py::handle obj, } // Must release the GIL before BufferFromHostBuffer because backends may // decide to block/sleep for device buffer allocation. - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); // TODO(yashkatariya): Plumb sharding or memory_kind here. TF_ASSIGN_OR_RETURN( @@ -102,10 +114,9 @@ StatusOr HandlePythonScalar(py::handle obj, return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); } -StatusOr HandlePythonInt(py::handle obj, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandlePythonInt( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { void* ptr; PrimitiveType type; int64_t data_int64; @@ -113,32 +124,32 @@ StatusOr HandlePythonInt(py::handle obj, ifrt::Client* client, if (options.squash_64bit_types) { try { - data_int32 = py::cast(obj); + data_int32 = nb::cast(obj); } catch (const std::exception& e) { return InvalidArgument( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - py::repr(obj)); + nb::cast(nb::repr(obj))); } ptr = &data_int32; type = S32; } else { try { - data_int64 = py::cast(obj); + data_int64 = nb::cast(obj); } catch (const std::exception& e) { return InvalidArgument( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - py::repr(obj)); + nb::cast(nb::repr(obj))); } ptr = &data_int64; type = S64; } // Must release the GIL before BufferFromHostBuffer because backends may // decide to block/sleep for device buffer allocation. - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); // TODO(yashkatariya): Plumb sharding or memory_kind here. TF_ASSIGN_OR_RETURN( @@ -152,10 +163,9 @@ StatusOr HandlePythonInt(py::handle obj, ifrt::Client* client, } template -StatusOr HandleNumpyScalar(py::handle h, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { +absl::StatusOr HandleNumpyScalar( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { T data; SquashedT data_squashed; void* ptr; @@ -197,7 +207,7 @@ StatusOr HandleNumpyScalar(py::handle h, ifrt::Client* client, } // Must release the GIL before BufferFromHostBuffer because backends may // decide to block/sleep for device buffer allocation. - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); // TODO(yashkatariya): Plumb sharding or memory_kind here. TF_ASSIGN_OR_RETURN( @@ -210,20 +220,19 @@ StatusOr HandleNumpyScalar(py::handle h, ifrt::Client* client, return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); } -StatusOr HandleNumpyArray(py::handle h, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { - py::array array = py::cast(h); +absl::StatusOr HandleNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); PrimitiveType squashed_type; if (options.squash_64bit_types) { squashed_type = Squash64BitTypes(type); if (squashed_type != type) { - TF_ASSIGN_OR_RETURN(py::dtype squashed_dtype, - PrimitiveTypeToDtype(squashed_type)); - array = py::reinterpret_steal(PyArray_CastToType( + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( reinterpret_cast(array.ptr()), reinterpret_cast(squashed_dtype.release().ptr()), /*fortran=*/0)); @@ -248,11 +257,12 @@ StatusOr HandleNumpyArray(py::handle h, ifrt::Client* client, on_done_with_host_buffer = [py_buffer_ref{ std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; - host_buffer_semantics = ifrt::Client::HostBufferSemantics::kZeroCopy; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; } // Must release the GIL before BufferFromHostBuffer because backends may // decide to block/sleep for device buffer allocation. - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type)); TF_ASSIGN_OR_RETURN( auto ifrt_array, @@ -263,11 +273,12 @@ StatusOr HandleNumpyArray(py::handle h, ifrt::Client* client, return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); } -StatusOr HandlePyArray(py::handle obj, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { - auto py_array = py::reinterpret_borrow(obj); +absl::StatusOr HandlePyArray(nb::handle obj, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind) { + auto py_array = nb::borrow(obj); // We only allow single device case for PyArray in device put. if (py_array.num_shards() != 1) { @@ -283,7 +294,7 @@ StatusOr HandlePyArray(py::handle obj, ifrt::Client* client, } // Fallback to python for non-matching clients or pmap sharding. - if (py_array.sharding().get_type() == jax::PmapSharding::type() || + if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || ifrt_array->sharding().devices().front()->client() != to_device->client()) { return HandleNumpyArray(obj.attr("_value"), client, to_device, options, @@ -294,9 +305,8 @@ StatusOr HandlePyArray(py::handle obj, ifrt::Client* client, (!to_memory_kind.memory_kind().has_value() || !ifrt_array->sharding().memory_kind().memory_kind().has_value() || ifrt_array->sharding().memory_kind() == to_memory_kind)) { - return DevicePutResult( - tsl::FormRef(ifrt_array), py_array.weak_type(), - /*owning_pybuffer=*/py::reinterpret_borrow(obj)); + return DevicePutResult(tsl::FormRef(ifrt_array), py_array.weak_type(), + /*owning_pybuffer=*/nb::borrow(obj)); } else { TF_ASSIGN_OR_RETURN( tsl::RCReference copied_ifrt_array, @@ -309,10 +319,10 @@ StatusOr HandlePyArray(py::handle obj, ifrt::Client* client, } // namespace -StatusOr DevicePut(py::handle arg, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind) { +absl::StatusOr DevicePut(nb::handle arg, ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind) { tsl::profiler::TraceMe traceme("DevicePut"); static const absl::flat_hash_map* const handlers = [] { @@ -329,7 +339,7 @@ StatusOr DevicePut(py::handle arg, ifrt::Client* client, (*p)[reinterpret_cast(&PyComplex_Type)] = HandlePythonScalar; - const auto numpy = py::module::import("numpy"); + const auto numpy = nb::module_::import_("numpy"); (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray; // Numpy scalar types. For some of them, we share the handler with @@ -371,16 +381,14 @@ StatusOr DevicePut(py::handle arg, ifrt::Client* client, return p; }(); - if (arg.get_type() == PyArray::type()) { - auto array = py::reinterpret_borrow(arg); - if (array.fastpath_enabled()) { - return HandlePyArray(arg, client, to_device, options, to_memory_kind); - } + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, options, to_memory_kind); } - auto res = handlers->find(arg.get_type().ptr()); + auto res = handlers->find(arg.type().ptr()); if (res == handlers->end()) { - for (auto base_class : arg.get_type().attr("__mro__")) { + for (auto base_class : arg.type().attr("__mro__")) { res = handlers->find(base_class.ptr()); if (res != handlers->end()) { return res->second(arg, client, to_device, options, to_memory_kind); @@ -391,16 +399,16 @@ StatusOr DevicePut(py::handle arg, ifrt::Client* client, "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - py::cast(py::str(arg.get_type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, client, to_device, options, to_memory_kind); } -bool IsFloat0(py::array arg) { +bool IsFloat0(xla::nb_numpy_ndarray arg) { static const auto* dtypes_module = - new py::module(py::module::import("jax.dtypes")); + new nb::module_(nb::module_::import_("jax.dtypes")); static const auto* float0_dtype = - new py::handle(dtypes_module->attr("float0")); + new nb::handle(dtypes_module->attr("float0")); return float0_dtype->is(arg.attr("dtype")); } @@ -415,10 +423,10 @@ std::string PyArgSignature::DebugString() const { } using ToPyArgSignatureHandler = - std::function(py::handle, bool)>; + std::function(nb::handle, bool)>; -StatusOr PyArgSignatureOfValue(py::handle arg, - bool jax_enable_x64) { +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { static const absl::flat_hash_map* const handlers = [] { auto p = new absl::flat_hash_map(); @@ -427,11 +435,12 @@ StatusOr PyArgSignatureOfValue(py::handle arg, // The 4 Python native types. ToPyArgSignatureHandler bool_handler = - [](py::handle, bool) -> StatusOr { + [](nb::handle, bool) -> absl::StatusOr { return PyArgSignature(PrimitiveType::PRED, {}, true); }; ToPyArgSignatureHandler int_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { // TODO(phawkins): we should consider checking for integer overflow. if (jax_enable_x64) { return PyArgSignature(PrimitiveType::S64, {}, true); @@ -440,10 +449,10 @@ StatusOr PyArgSignatureOfValue(py::handle arg, } }; ToPyArgSignatureHandler float_handler = - [&dtypes](py::handle h, - bool jax_enable_x64) -> StatusOr { + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { // Only Python native types has a True weak_type. - bool weak_type = !py::isinstance(h, dtypes.np_float64); + bool weak_type = !xla::nb_isinstance(h, dtypes.np_float64); if (jax_enable_x64) { return PyArgSignature(PrimitiveType::F64, {}, weak_type); } else { @@ -451,12 +460,12 @@ StatusOr PyArgSignatureOfValue(py::handle arg, } }; ToPyArgSignatureHandler complex_handler = - [&dtypes](py::handle h, - bool jax_enable_x64) -> StatusOr { + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { // Note that this branch is also taken for np.complex128: // isinstance(np.complex128(3), complex) returns True // isinstance(np.complex64(3), complex) returns False - bool weak_type = !py::isinstance(h, dtypes.np_complex128); + bool weak_type = !xla::nb_isinstance(h, dtypes.np_complex128); if (jax_enable_x64) { return PyArgSignature(PrimitiveType::C128, {}, weak_type); } else { @@ -470,8 +479,10 @@ StatusOr PyArgSignatureOfValue(py::handle arg, (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; ToPyArgSignatureHandler numpy_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { - py::array numpy_array = py::cast(h); + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); TF_ASSIGN_OR_RETURN(PrimitiveType dtype, DtypeToPrimitiveType(numpy_array.dtype())); if (!jax_enable_x64) { @@ -489,11 +500,12 @@ StatusOr PyArgSignatureOfValue(py::handle arg, numpy_array.ndim()), /*weak_type=*/false); }; - const auto numpy = py::module::import("numpy"); + const auto numpy = nb::module_::import_("numpy"); (*p)[numpy.attr("ndarray").ptr()] = numpy_handler; ToPyArgSignatureHandler np_uint64_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { if (jax_enable_x64) { return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); } else { @@ -501,7 +513,8 @@ StatusOr PyArgSignatureOfValue(py::handle arg, } }; ToPyArgSignatureHandler np_int_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { if (jax_enable_x64) { return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); } else { @@ -509,7 +522,8 @@ StatusOr PyArgSignatureOfValue(py::handle arg, } }; ToPyArgSignatureHandler numpy_array_handler = - [](py::handle h, bool jax_enable_x64) -> StatusOr { + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { // This block deals with all numpy scalar types, except for int64_dt, // float64_dt and complex128_dt which are taken care of in previous if // blocks. @@ -547,23 +561,21 @@ StatusOr PyArgSignatureOfValue(py::handle arg, return p; }(); - if (arg.get_type() == PyArray::type()) { - auto array = py::reinterpret_borrow(arg); - if (array.fastpath_enabled()) { - ifrt::Array* ifrt_array = array.ifrt_array(); - if (ifrt_array == nullptr) { - return xla::InvalidArgument("Array has been deleted."); - } - TF_ASSIGN_OR_RETURN(auto primitive_type, - ifrt::ToPrimitiveType(ifrt_array->dtype())); - return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array* ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); } - auto res = handlers->find(arg.get_type().ptr()); + auto res = handlers->find(arg.type().ptr()); if (res == handlers->end()) { // We attempt to look at the MRO classes - for (auto base_class : arg.get_type().attr("__mro__")) { + for (auto base_class : arg.type().attr("__mro__")) { res = handlers->find(base_class.ptr()); if (res != handlers->end()) { return res->second(arg, jax_enable_x64); @@ -575,7 +587,7 @@ StatusOr PyArgSignatureOfValue(py::handle arg, "Buffer/DeviceArray, Numpy " "arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - py::cast(py::str(arg.get_type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, jax_enable_x64); } diff --git a/xla/python/py_values.h b/xla/python/py_values.h index 3ee50e45f7626..425fc730208cc 100644 --- a/xla/python/py_values.h +++ b/xla/python/py_values.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,23 +18,27 @@ limitations under the License. #ifndef XLA_PYTHON_PY_VALUES_H_ #define XLA_PYTHON_PY_VALUES_H_ -#include +#include #include #include #include -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/py_client.h" +#include "xla/python/nb_numpy.h" +#include "tsl/concurrency/ref_count.h" namespace xla { struct DevicePutResult { explicit DevicePutResult( tsl::RCReference ifrt_array, bool weak_type, - pybind11::object owning_pybuffer = pybind11::object()) + nanobind::object owning_pybuffer = nanobind::object()) : ifrt_array(std::move(ifrt_array)), weak_type(weak_type), owning_pybuffer(owning_pybuffer) {} @@ -43,7 +47,7 @@ struct DevicePutResult { tsl::RCReference ifrt_array; bool weak_type; - pybind11::object owning_pybuffer; + nanobind::object owning_pybuffer; }; // Copies a buffer-like object to be on device. @@ -53,19 +57,20 @@ struct DevicePutResult { // If the value is known to be a PyBuffer object, py_buffer can be passed as // an optimization to avoid a Python->C++ cast. // -// May throw exceptions from pybind11 in addition to failing via an error +// May throw exceptions from nanobind in addition to failing via an error // Status. (We could catch these if needed, but there seems little point.) struct DevicePutOptions { bool squash_64bit_types = false; bool allow_zero_copy = true; }; -StatusOr DevicePut(pybind11::handle arg, ifrt::Client* client, - ifrt::Device* to_device, - const DevicePutOptions& options, - ifrt::MemoryKind to_memory_kind); +absl::StatusOr DevicePut(nanobind::handle arg, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind); // Returns `true` if `arg` is a JAX float0 array. -bool IsFloat0(pybind11::array arg); +bool IsFloat0(xla::nb_numpy_ndarray arg); // Describes the abstract shape and dtype of an argument. struct PyArgSignature { @@ -90,8 +95,8 @@ struct PyArgSignature { // Returns the PyArgSignature associated with an argument. Returns an error if // the argument is not supported. -StatusOr PyArgSignatureOfValue(pybind11::handle arg, - bool jax_enable_x64); +absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); template H AbslHashValue(H h, const xla::PyArgSignature& s) { diff --git a/xla/python/python_ref_manager.cc b/xla/python/python_ref_manager.cc index 990c74d80bbe4..4b0640c75ea0c 100644 --- a/xla/python/python_ref_manager.cc +++ b/xla/python/python_ref_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,21 +15,26 @@ limitations under the License. #include "xla/python/python_ref_manager.h" +#include #include #include #include +#include #include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -namespace py = pybind11; +namespace nb = nanobind; PythonRefManager::ManagedPyObjects::ManagedPyObjects( - PythonRefManager* manager, absl::Span objects) + PythonRefManager* manager, absl::Span objects) : manager_(manager) { objects_.reserve(objects.size()); - for (pybind11::object& object : objects) { + for (nb::object& object : objects) { objects_.push_back(std::move(object)); } } @@ -41,21 +46,28 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { } std::shared_ptr -PythonRefManager::ManageReference(py::object object) { +PythonRefManager::ManageReference(nb::object object) { return std::make_shared(this, - absl::Span(&object, 1)); + absl::Span(&object, 1)); } std::shared_ptr -PythonRefManager::ManageReferences(absl::Span objects) { +PythonRefManager::ManageReferences(absl::Span objects) { return std::make_shared(this, objects); } -void PythonRefManager::AddGarbage(absl::Span garbage) { +void PythonRefManager::AddGarbage(nb::object garbage) { absl::MutexLock lock(&mu_); // We want to collect arbitrary python garbage (e.g., buffers) aggressively. garbage_count_.fetch_add(100, std::memory_order_relaxed); - for (py::object& o : garbage) { + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object& o : garbage) { python_garbage_.push_back(std::move(o)); } } @@ -68,14 +80,13 @@ void PythonRefManager::AddGarbage( // process. garbage_count_.fetch_add(1, std::memory_order_relaxed); for (const auto& o : garbage) { - python_garbage_.push_back(py::reinterpret_steal( - reinterpret_cast(o.first))); + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); } } void PythonRefManager::CollectGarbage() { // TODO(phawkins): we should CHECK(PyGILState_Check()); - std::deque garbage; + std::deque garbage; { absl::MutexLock lock(&mu_); garbage_count_ = 0; diff --git a/xla/python/python_ref_manager.h b/xla/python/python_ref_manager.h index f4428598cb715..87e55e5b58e65 100644 --- a/xla/python/python_ref_manager.h +++ b/xla/python/python_ref_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_PYTHON_PYTHON_REF_MANAGER_H_ #define XLA_PYTHON_PYTHON_REF_MANAGER_H_ +#include + #include #include #include @@ -25,7 +27,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { @@ -42,13 +44,13 @@ class PythonRefManager { public: PythonRefManager() = default; - // Holds references to a set of pybind11::objects, adding the references to + // Holds references to a set of nanobind::objects, adding the references to // the PythonRefManager on destruction. class ManagedPyObjects { public: ManagedPyObjects() = default; ManagedPyObjects(PythonRefManager* manager, - absl::Span objects); + absl::Span objects); ~ManagedPyObjects(); @@ -59,18 +61,19 @@ class PythonRefManager { private: PythonRefManager* manager_ = nullptr; - absl::InlinedVector objects_; + absl::InlinedVector objects_; }; // Creates a managed std::shared_ptr to an object. When the shared_ptr is // destroyed, the reference to 'object' will be added to python_garbage_, // and collected next time CollectGarbage() is called. - std::shared_ptr ManageReference(pybind11::object object); + std::shared_ptr ManageReference(nanobind::object object); std::shared_ptr ManageReferences( - absl::Span objects); + absl::Span objects); // Adds garbage objects to the manager. - void AddGarbage(absl::Span garbage); + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); void AddGarbage(absl::Span const> garbage); // Releases the contents of python_garbage_. Requires that the GIL is held. @@ -89,7 +92,7 @@ class PythonRefManager { private: absl::Mutex mu_; - std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); // Writes to garbage_count_ are protected by mu_, reads are not protected. std::atomic garbage_count_{0}; diff --git a/xla/python/python_utils.h b/xla/python/python_utils.h deleted file mode 100644 index 7268b7286198d..0000000000000 --- a/xla/python/python_utils.h +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_PYTHON_UTILS_H_ -#define XLA_PYTHON_PYTHON_UTILS_H_ - -#include - -#include -#include - -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/status_macros.h" -#include "xla/util.h" - -namespace jax { - -// This file contains utilities to write Python wrapers using the C API. -// It's used for performance critical code such as PyBuffer, jax.jit or -// jax.pmap. - -// Helpers for building Python properties -template -pybind11::object property_readonly(Func&& get) { - pybind11::handle property(reinterpret_cast(&PyProperty_Type)); - return property(pybind11::cpp_function(std::forward(get)), - pybind11::none(), pybind11::none(), ""); -} - -template -pybind11::object property(GetFunc&& get, SetFunc&& set) { - pybind11::handle property(reinterpret_cast(&PyProperty_Type)); - return property(pybind11::cpp_function(std::forward(get)), - pybind11::cpp_function(std::forward(set)), - pybind11::none(), ""); -} - -template -pybind11::object def_static(Constructor&& constructor) { - pybind11::handle property(reinterpret_cast(&PyProperty_Type)); - return pybind11::staticmethod( - pybind11::cpp_function(std::forward(constructor))); -} - -} // namespace jax - -#endif // XLA_PYTHON_PYTHON_UTILS_H_ diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index 1a47247bc5a21..19ebdd7b92dea 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,11 @@ limitations under the License. #include "xla/python/pytree.h" +#include + #include +#include +#include #include #include #include @@ -31,29 +35,33 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "pybind11/attr.h" // from @pybind11 -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/exceptions.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/pair.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_class_ptr.h" #include "tsl/platform/logging.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, bool enable_list, bool enable_dict) { auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { - py::object type = py::reinterpret_borrow( - reinterpret_cast(type_obj)); + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); auto registration = std::make_unique(); registration->kind = kind; registration->type = type; @@ -74,8 +82,8 @@ PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, } } -void PyTreeRegistry::Register(py::object type, py::function to_iterable, - py::function from_iterable) { +void PyTreeRegistry::Register(nb::object type, nb::callable to_iterable, + nb::callable from_iterable) { auto registration = std::make_unique(); registration->kind = PyTreeKind::kCustom; registration->type = type; @@ -85,14 +93,36 @@ void PyTreeRegistry::Register(py::object type, py::function to_iterable, if (!it.second) { throw std::invalid_argument( absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", - py::repr(type))); + nb::cast(nb::repr(type)))); } } +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + // Computes the node kind of a given Python object. PyTreeKind PyTreeRegistry::KindOfObject( - py::handle obj, PyTreeRegistry::Registration const** custom) const { - const PyTreeRegistry::Registration* registration = Lookup(obj.get_type()); + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); if (registration) { if (registration->kind == PyTreeKind::kCustom) { *custom = registration; @@ -100,7 +130,7 @@ PyTreeKind PyTreeRegistry::KindOfObject( *custom = nullptr; } return registration->kind; - } else if (py::isinstance(obj) && py::hasattr(obj, "_fields")) { + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { // We can only identify namedtuples heuristically, here by the presence of // a _fields attribute. return PyTreeKind::kNamedTuple; @@ -110,7 +140,7 @@ PyTreeKind PyTreeRegistry::KindOfObject( } /*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( - py::handle type) const { + nb::handle type) const { auto it = registrations_.find(type); return it == registrations_.end() ? nullptr : it->second.get(); } @@ -124,28 +154,34 @@ std::shared_ptr DefaultPyTreeRegistry() { return registry; } -/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { - std::vector keys; +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; keys.reserve(PyDict_Size(py_dict)); PyObject* key; Py_ssize_t pos = 0; while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { - keys.push_back(py::reinterpret_borrow(key)); + keys.push_back(nb::borrow(key)); } - std::stable_sort( - keys.begin(), keys.end(), [](const py::object& a, const py::object& b) { - int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); - if (cmp == -1) { - throw py::error_already_set(); - } - return cmp; - }); + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } return keys; } -/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, - absl::Span rhs) { +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -182,17 +218,73 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { return true; } +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: + case PyTreeKind::kList: + return nb::make_tuple(nb::borrow(x), nb::none()); + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + PyTuple_SET_ITEM(values.ptr(), i, + nb::object(dict[sorted_keys[i]]).release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + template -void PyTreeDef::FlattenImpl(py::handle handle, T& leaves, - const std::optional& leaf_predicate) { +void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, + const std::optional& leaf_predicate) { Node node; const int start_num_nodes = traversal_.size(); const int start_num_leaves = leaves.size(); - if (leaf_predicate && (*leaf_predicate)(handle).cast()) { - leaves.push_back(py::reinterpret_borrow(handle)); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o = (*leaf_predicate)(handle); + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + if constexpr (std::is_same_v) { + leaves.append(nb::borrow(handle)); + } else { + leaves.push_back(nb::borrow(handle)); + } } else { node.kind = registry_->KindOfObject(handle, &node.custom); - auto recurse = [this, &leaf_predicate, &leaves](py::handle child) { + auto recurse = [this, &leaf_predicate, &leaves](nb::handle child) { if (Py_EnterRecursiveCall( " in flatten; PyTree may have cyclical node references.")) { return; @@ -219,10 +311,10 @@ void PyTreeDef::FlattenImpl(py::handle handle, T& leaves, break; } case PyTreeKind::kDict: { - py::dict dict = py::reinterpret_borrow(handle); + nb::dict dict = nb::borrow(handle); - std::vector keys = GetSortedPyDictKeys(dict.ptr()); - for (py::handle key : keys) { + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::handle key : keys) { recurse(dict[key]); } node.arity = dict.size(); @@ -230,31 +322,31 @@ void PyTreeDef::FlattenImpl(py::handle handle, T& leaves, break; } case PyTreeKind::kCustom: { - py::tuple out = py::cast(node.custom->to_iterable(handle)); - if (out.size() != 2) { - throw xla::XlaRuntimeError( - "PyTree custom to_iterable function should return a pair"); - } - node.node_data = out[1]; + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); node.arity = 0; - for (py::handle entry : py::cast(out[0])) { + for (nb::handle entry : leaves) { ++node.arity; recurse(entry); } break; } case PyTreeKind::kNamedTuple: { - py::tuple tuple = py::reinterpret_borrow(handle); + nb::tuple tuple = nb::borrow(handle); node.arity = tuple.size(); - node.node_data = py::reinterpret_borrow(tuple.get_type()); - for (py::handle entry : tuple) { + node.node_data = nb::borrow(tuple.type()); + for (nb::handle entry : tuple) { recurse(entry); } break; } default: DCHECK(node.kind == PyTreeKind::kLeaf); - leaves.push_back(py::reinterpret_borrow(handle)); + if constexpr (std::is_same_v) { + leaves.append(nb::borrow(handle)); + } else { + leaves.push_back(nb::borrow(handle)); + } } } node.num_nodes = traversal_.size() - start_num_nodes + 1; @@ -262,40 +354,44 @@ void PyTreeDef::FlattenImpl(py::handle handle, T& leaves, traversal_.push_back(std::move(node)); } -void PyTreeDef::Flatten(py::handle handle, - absl::InlinedVector& leaves, - std::optional leaf_predicate) { +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { FlattenImpl(handle, leaves, leaf_predicate); } -void PyTreeDef::Flatten(py::handle handle, std::vector& leaves, - std::optional leaf_predicate) { +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { FlattenImpl(handle, leaves, leaf_predicate); } -/*static*/ std::pair, std::unique_ptr> -PyTreeDef::Flatten(pybind11::handle x, - std::optional leaf_predicate, +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + FlattenImpl(handle, leaves, leaf_predicate); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, std::optional leaf_predicate, std::shared_ptr registry) { - auto def = std::make_unique(registry ? registry - : DefaultPyTreeRegistry()); - std::vector leaves; - def->Flatten(x, leaves); + auto def = + make_nb_class(registry ? registry : DefaultPyTreeRegistry()); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); return std::make_pair(std::move(leaves), std::move(def)); } /*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, - const py::iterable& x) { + const nb::iterable& x) { const PyTreeRegistry::Registration* custom; - for (const py::handle& h : x) { + for (const nb::handle& h : x) { if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; } return true; } template -py::object PyTreeDef::UnflattenImpl(T leaves) const { - absl::InlinedVector agenda; +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; auto it = leaves.begin(); int leaf_count = 0; for (const Node& node : traversal_) { @@ -309,7 +405,7 @@ py::object PyTreeDef::UnflattenImpl(T leaves) const { "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), leaf_count)); } - agenda.push_back(py::reinterpret_borrow(*it)); + agenda.push_back(nb::borrow(*it)); ++it; ++leaf_count; break; @@ -321,11 +417,11 @@ py::object PyTreeDef::UnflattenImpl(T leaves) const { case PyTreeKind::kDict: case PyTreeKind::kCustom: { const int size = agenda.size(); - absl::Span span; + absl::Span span; if (node.arity > 0) { - span = absl::Span(&agenda[size - node.arity], node.arity); + span = absl::Span(&agenda[size - node.arity], node.arity); } - py::object o = MakeNode(node, span); + nb::object o = MakeNode(node, span); agenda.resize(size - node.arity); agenda.push_back(o); break; @@ -342,16 +438,16 @@ py::object PyTreeDef::UnflattenImpl(T leaves) const { return std::move(agenda.back()); } -py::object PyTreeDef::Unflatten(py::iterable leaves) const { +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { return UnflattenImpl(leaves); } -py::object PyTreeDef::Unflatten(absl::Span leaves) const { +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { return UnflattenImpl(leaves); } -/*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, - absl::Span children) { +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { if (children.size() != node.arity) { throw std::logic_error("Node arity mismatch."); } @@ -360,31 +456,31 @@ py::object PyTreeDef::Unflatten(absl::Span leaves) const { throw std::logic_error("MakeNode not implemented for leaves."); case PyTreeKind::kNone: - return py::none(); + return nb::none(); case PyTreeKind::kTuple: case PyTreeKind::kNamedTuple: { - py::tuple tuple(node.arity); + nb::object tuple = nb::steal(PyTuple_New(node.arity)); for (int i = 0; i < node.arity; ++i) { - tuple[i] = std::move(children[i]); + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); } if (node.kind == PyTreeKind::kNamedTuple) { return node.node_data(*tuple); } else { - return std::move(tuple); + return tuple; } } case PyTreeKind::kList: { - py::list list(node.arity); + nb::object list = nb::steal(PyList_New(node.arity)); for (int i = 0; i < node.arity; ++i) { - list[i] = std::move(children[i]); + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); } - return std::move(list); + return list; } case PyTreeKind::kDict: { - py::dict dict; + nb::dict dict; for (int i = 0; i < node.arity; ++i) { dict[node.sorted_dict_keys[i]] = std::move(children[i]); } @@ -392,9 +488,9 @@ py::object PyTreeDef::Unflatten(absl::Span leaves) const { break; } case PyTreeKind::kCustom: { - py::tuple tuple(node.arity); + nb::object tuple = nb::steal(PyTuple_New(node.arity)); for (int i = 0; i < node.arity; ++i) { - tuple[i] = std::move(children[i]); + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); } return node.custom->from_iterable(node.node_data, tuple); } @@ -402,19 +498,20 @@ py::object PyTreeDef::Unflatten(absl::Span leaves) const { throw std::logic_error("Unreachable code."); } -py::list PyTreeDef::FlattenUpTo(py::handle xs) const { - py::list leaves(num_leaves()); - std::vector agenda; - agenda.push_back(py::reinterpret_borrow(xs)); +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); auto it = traversal_.rbegin(); int leaf = num_leaves() - 1; while (!agenda.empty()) { if (it == traversal_.rend()) { throw std::invalid_argument(absl::StrFormat( - "Tree structures did not match: %s vs %s", py::repr(xs), ToString())); + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); } const Node& node = *it; - py::object object = agenda.back(); + nb::object object = agenda.back(); agenda.pop_back(); ++it; @@ -423,7 +520,7 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const { if (leaf < 0) { throw std::logic_error("Leaf count mismatch."); } - leaves[leaf] = py::reinterpret_borrow(object); + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); --leaf; break; @@ -433,16 +530,17 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const { case PyTreeKind::kTuple: { if (!PyTuple_CheckExact(object.ptr())) { throw std::invalid_argument( - absl::StrFormat("Expected tuple, got %s.", py::repr(object))); + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); } - py::tuple tuple = py::reinterpret_borrow(object); + nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { - throw std::invalid_argument( - absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.", - tuple.size(), node.arity, py::repr(object))); + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); } - for (py::handle entry : tuple) { - agenda.push_back(py::reinterpret_borrow(entry)); + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); } break; } @@ -450,16 +548,17 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const { case PyTreeKind::kList: { if (!PyList_CheckExact(object.ptr())) { throw std::invalid_argument( - absl::StrFormat("Expected list, got %s.", py::repr(object))); + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); } - py::list list = py::reinterpret_borrow(object); + nb::list list = nb::borrow(object); if (list.size() != node.arity) { - throw std::invalid_argument( - absl::StrFormat("List arity mismatch: %d != %d; list: %s.", - list.size(), node.arity, py::repr(object))); + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); } - for (py::handle entry : list) { - agenda.push_back(py::reinterpret_borrow(entry)); + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); } break; } @@ -467,87 +566,92 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const { case PyTreeKind::kDict: { if (!PyDict_CheckExact(object.ptr())) { throw std::invalid_argument( - absl::StrFormat("Expected dict, got %s.", py::repr(object))); + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); } - py::dict dict = py::reinterpret_borrow(object); - std::vector keys = GetSortedPyDictKeys(dict.ptr()); + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { - // Convert to a py::list for py::repr to avoid having to stringify a + // Convert to a nb::list for nb::repr to avoid having to stringify a // vector. This is error path so it is fine to pay conversion cost. - throw std::invalid_argument(absl::StrFormat( - "Dict key mismatch; expected keys: %s; dict: %s.", - py::repr(py::cast(node.sorted_dict_keys)), py::repr(object))); + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); } - for (py::handle key : keys) { + for (nb::handle key : keys) { agenda.push_back(dict[key]); } break; } case PyTreeKind::kNamedTuple: { - if (!py::isinstance(object) || - !py::hasattr(object, "_fields")) { - throw std::invalid_argument(absl::StrFormat( - "Expected named tuple, got %s.", py::repr(object))); + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); } - py::tuple tuple = py::reinterpret_borrow(object); + nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), - node.arity, py::repr(object))); + node.arity, nb::cast(nb::repr(object)))); } - if (tuple.get_type().not_equal(node.node_data)) { + if (tuple.type().not_equal(node.node_data)) { throw std::invalid_argument(absl::StrFormat( "Named tuple type mismatch: expected type: %s, tuple: %s.", - py::repr(node.node_data), py::repr(object))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); } - for (py::handle entry : tuple) { - agenda.push_back(py::reinterpret_borrow(entry)); + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); } break; } case PyTreeKind::kCustom: { - auto* registration = registry_->Lookup(object.get_type()); + auto* registration = registry_->Lookup(object.type()); if (registration != node.custom) { throw std::invalid_argument(absl::StrFormat( "Custom node type mismatch: expected type: %s, value: %s.", - py::repr(node.custom->type), py::repr(object))); + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); } - py::tuple out = py::cast(node.custom->to_iterable(object)); - if (out.size() != 2) { - throw xla::XlaRuntimeError( - "PyTree custom to_iterable function should return a pair"); - } - if (node.node_data.not_equal(out[1])) { + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom node data: %s != %s; value: %s.", - py::repr(node.node_data), py::repr(out[1]), py::repr(object))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); } int arity = 0; - for (py::handle entry : py::cast(out[0])) { + for (nb::handle entry : leaves) { ++arity; - agenda.push_back(py::reinterpret_borrow(entry)); + agenda.push_back(nb::borrow(entry)); } if (arity != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", arity, - node.arity, py::repr(object))); + node.arity, nb::cast(nb::repr(object)))); } break; } } } if (it != traversal_.rend() || leaf != -1) { - throw std::invalid_argument(absl::StrFormat( - "Tree structures did not match: %s vs %s", py::repr(xs), ToString())); + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); } return leaves; } -py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, - py::iterable leaves) const { - std::vector agenda; +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; auto it = leaves.begin(); for (const Node& node : traversal_) { switch (node.kind) { @@ -556,7 +660,7 @@ py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, throw std::invalid_argument("Too few leaves for PyTreeDef"); } - py::object leaf = py::reinterpret_borrow(*it); + nb::object leaf = nb::borrow(*it); agenda.push_back(f_leaf.is_none() ? std::move(leaf) : f_leaf(std::move(leaf))); ++it; @@ -572,17 +676,17 @@ py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, if (agenda.size() < node.arity) { throw std::logic_error("Too few elements for custom type."); } - py::tuple tuple(node.arity); + nb::object tuple = nb::steal(PyTuple_New(node.arity)); for (int i = node.arity - 1; i >= 0; --i) { - tuple[i] = agenda.back(); + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); agenda.pop_back(); } - py::object node_data = node.node_data; + nb::object node_data = node.node_data; if (node.kind == PyTreeKind::kDict) { - // Convert to a py::list for f_node invocation. - node_data = py::cast(node.sorted_dict_keys); + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); } - agenda.push_back(f_node(tuple, node_data ? node_data : py::none())); + agenda.push_back(f_node(tuple, node_data ? node_data : nb::none())); } } } @@ -595,8 +699,8 @@ py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, return std::move(agenda.back()); } -py::object PyTreeDef::FromIterableTreeHelper( - py::handle xs, +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, absl::InlinedVector::const_reverse_iterator* it) const { if (*it == traversal_.rend()) { throw std::invalid_argument("Tree structures did not match."); @@ -604,13 +708,13 @@ py::object PyTreeDef::FromIterableTreeHelper( const Node& node = **it; ++*it; if (node.kind == PyTreeKind::kLeaf) { - return py::reinterpret_borrow(xs); + return nb::borrow(xs); } - py::iterable iterable = py::reinterpret_borrow(xs); - std::vector ys; + nb::iterable iterable = nb::borrow(xs); + std::vector ys; ys.reserve(node.arity); - for (py::handle x : iterable) { - ys.push_back(py::reinterpret_borrow(x)); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); } if (ys.size() != node.arity) { throw std::invalid_argument("Arity mismatch between trees"); @@ -622,21 +726,21 @@ py::object PyTreeDef::FromIterableTreeHelper( return MakeNode(node, absl::MakeSpan(ys)); } -py::object PyTreeDef::FromIterableTree(py::handle xs) const { +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { auto it = traversal_.rbegin(); - py::object out = FromIterableTreeHelper(xs, &it); + nb::object out = FromIterableTreeHelper(xs, &it); if (it != traversal_.rend()) { throw std::invalid_argument("Tree structures did not match."); } return out; } -std::unique_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { if (inner.registry_ != registry_) { throw std::invalid_argument( "PyTree registries of PyTreeDefs passed to Compose() must match."); } - auto out = std::make_unique(registry_->shared_from_this()); + auto out = make_nb_class(registry_->shared_from_this()); out->traversal_.reserve(static_cast(num_leaves()) * inner.num_nodes() + num_nodes() - num_leaves()); @@ -651,12 +755,12 @@ std::unique_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { return out; } -/*static*/ std::unique_ptr PyTreeDef::Tuple( - std::shared_ptr registry, - absl::Span defs) { - auto out = std::make_unique(std::move(registry)); +/*static*/ nb_class_ptr PyTreeDef::Tuple( + std::shared_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); int num_leaves = 0; - for (const PyTreeDef* def : defs) { + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); if (def->registry() != out->registry()) { throw std::invalid_argument( "PyTree registries of PyTreeDefs passed to Tuple() must match."); @@ -673,8 +777,8 @@ std::unique_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { return out; } -std::vector> PyTreeDef::Children() const { - std::vector> children; +std::vector> PyTreeDef::Children() const { + std::vector> children; if (traversal_.empty()) { return children; } @@ -682,7 +786,7 @@ std::vector> PyTreeDef::Children() const { children.resize(root.arity); int pos = traversal_.size() - 1; for (int i = root.arity - 1; i >= 0; --i) { - children[i] = std::make_unique(registry_->shared_from_this()); + children[i] = make_nb_class(registry_->shared_from_this()); const Node& node = traversal_.at(pos - 1); if (pos < node.num_nodes) { throw std::logic_error("children() walked off start of array"); @@ -730,9 +834,10 @@ std::string PyTreeDef::ToString() const { representation = "{"; std::string separator; auto child_iter = agenda.end() - node.arity; - for (const py::handle& key : node.sorted_dict_keys) { + for (const nb::handle& key : node.sorted_dict_keys) { absl::StrAppendFormat(&representation, "%s%s: %s", separator, - py::repr(key), *child_iter); + nb::cast(nb::repr(key)), + *child_iter); child_iter++; separator = ", "; } @@ -749,13 +854,15 @@ std::string PyTreeDef::ToString() const { if (node.node_data) { // Node data for named tuples is the type. data = absl::StrFormat( - "[%s]", py::str(py::getattr(node.node_data, "__name__"))); + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); } } else { - kind = static_cast( - py::str(py::getattr(node.custom->type, "__name__"))); + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); if (node.node_data) { - data = absl::StrFormat("[%s]", py::str(node.node_data)); + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); } } @@ -773,46 +880,40 @@ std::string PyTreeDef::ToString() const { return absl::StrCat("PyTreeDef(", agenda.back(), ")"); } -py::object PyTreeDef::ToPickle() const { - py::list traversal; +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; for (const auto& node : traversal_) { - py::object node_data = node.node_data; + nb::object node_data = node.node_data; if (node.kind == PyTreeKind::kDict) { - // Convert to a py::list for pickling to avoid having to pickle a vector. + // Convert to a nb::list for pickling to avoid having to pickle a vector. // Pickle should be a rare operation so this conversion cost is hopefully // on non-critical path. - node_data = py::cast(node.sorted_dict_keys); + node_data = nb::cast(node.sorted_dict_keys); } traversal.append( - py::make_tuple(static_cast(node.kind), node.arity, - node_data ? node_data : py::none(), - node.custom != nullptr ? node.custom->type : py::none(), + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), node.num_leaves, node.num_nodes)); } - return py::make_tuple(py::cast(registry_->shared_from_this()), traversal); + return nb::make_tuple(nb::cast(registry_->shared_from_this()), traversal); } -PyTreeDef PyTreeDef::FromPickle(py::object pickleable) { - py::tuple pickle = pickleable.cast(); - if (pickle.size() != 2) { - throw xla::XlaRuntimeError("Malformed pickled PyTreeDef, expected 2-tuple"); - } - auto registry = py::cast>(pickle[0]); - PyTreeDef tree(registry); - for (const auto& item : pickle[1].cast()) { - auto t = item.cast(); +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); if (t.size() != 6) { throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); } - Node& node = tree.traversal_.emplace_back(); - node.kind = static_cast(t[0].cast()); - node.arity = t[1].cast(); + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); switch (node.kind) { case PyTreeKind::kNamedTuple: - node.node_data = t[2].cast(); + node.node_data = t[2]; break; case PyTreeKind::kDict: - node.sorted_dict_keys = t[2].cast>(); + node.sorted_dict_keys = nb::cast>(t[2]); break; case PyTreeKind::kCustom: node.node_data = t[2]; @@ -824,21 +925,20 @@ PyTreeDef PyTreeDef::FromPickle(py::object pickleable) { break; } if (node.kind == PyTreeKind::kCustom) { - node.custom = t[3].is_none() ? nullptr : registry->Lookup(t[3]); + node.custom = t[3].is_none() ? nullptr : registry()->Lookup(t[3]); if (node.custom == nullptr) { throw xla::XlaRuntimeError( absl::StrCat("Unknown custom type in pickled PyTreeDef: ", - static_cast(py::repr(t[3])))); + nb::cast(nb::repr(t[3])))); } } else { if (!t[3].is_none()) { throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); } } - node.num_leaves = t[4].cast(); - node.num_nodes = t[5].cast(); + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); } - return tree; } void PyTreeDef::SetNumLeavesAndNumNodes() { @@ -889,13 +989,13 @@ void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { case PyTreeKind::kDict: node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); for (auto& key : node.sorted_dict_keys) { - if (!py::isinstance(key)) { + if (!nb::isinstance(key)) { throw std::invalid_argument( "Only string keys are supported in proto pytree " "serialization."); } node_data->mutable_dict_keys()->add_str_id( - intern_str(py::cast(key))); + intern_str(nb::cast(key))); } break; default: @@ -908,17 +1008,19 @@ void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { } } -PyTreeDef PyTreeDef::DeserializeFrom(std::shared_ptr registry, - const jax::PyTreeDefProto& input) { - std::vector interned_strings; +nb_class_ptr PyTreeDef::DeserializeFrom( + std::shared_ptr registry, + const jax::PyTreeDefProto& input) { + std::vector interned_strings; interned_strings.reserve(input.interned_strings().size()); for (auto& s : input.interned_strings()) { - interned_strings.push_back(py::str(s)); + interned_strings.push_back(nb::cast(s)); } - PyTreeDef result(std::move(registry)); + nb_class_ptr result = + make_nb_class(std::move(registry)); for (auto& node_proto : input.nodes()) { - result.traversal_.emplace_back(); - auto& node = result.traversal_.back(); + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); node.arity = node_proto.arity(); node.custom = nullptr; switch (node_proto.type()) { @@ -950,60 +1052,58 @@ PyTreeDef PyTreeDef::DeserializeFrom(std::shared_ptr registry, break; } } - result.SetNumLeavesAndNumNodes(); + result->SetNumLeavesAndNumNodes(); return result; } -std::optional> -PyTreeDef::GetNodeData() const { +std::optional> PyTreeDef::GetNodeData() + const { if (traversal_.empty()) { throw std::logic_error("empty PyTreeDef traversal."); } auto builtin_type = [](PyTypeObject* type_obj) { - return py::reinterpret_borrow( - reinterpret_cast(type_obj)); + return nb::borrow(reinterpret_cast(type_obj)); }; const auto& node = traversal_.back(); switch (node.kind) { case PyTreeKind::kLeaf: return std::nullopt; case PyTreeKind::kNone: - return std::make_pair(builtin_type(Py_TYPE(Py_None)), py::none()); + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); case PyTreeKind::kTuple: - return std::make_pair(builtin_type(&PyTuple_Type), py::none()); + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); case PyTreeKind::kList: - return std::make_pair(builtin_type(&PyList_Type), py::none()); + return std::make_pair(builtin_type(&PyList_Type), nb::none()); case PyTreeKind::kDict: return std::make_pair(builtin_type(&PyDict_Type), - py::cast(node.sorted_dict_keys)); + nb::cast(node.sorted_dict_keys)); case PyTreeKind::kNamedTuple: - return std::make_pair(py::cast(node.node_data), - py::none()); + return std::make_pair(node.node_data, nb::none()); case PyTreeKind::kCustom: - return std::make_pair(py::cast(node.custom->type), - node.node_data); + return std::make_pair(node.custom->type, node.node_data); } } -PyTreeDef PyTreeDef::MakeFromNodeDataAndChildren( +nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( std::shared_ptr registry, - std::optional> node_data, - pybind11::iterable children) { - PyTreeDef result(std::move(registry)); + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); int num_leaves = 0; int arity = 0; - for (pybind11::handle pchild : children) { - const PyTreeDef& child = py::cast(pchild); - absl::c_copy(child.traversal_, std::back_inserter(result.traversal_)); + for (nb::handle pchild : children) { + const PyTreeDef& child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); num_leaves += child.num_leaves(); ++arity; } - result.traversal_.emplace_back(); - auto& node = result.traversal_.back(); + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); node.arity = arity; node.custom = nullptr; node.num_leaves = num_leaves; - node.num_nodes = result.traversal_.size(); + node.num_nodes = result->traversal_.size(); if (node_data == std::nullopt) { node.kind = PyTreeKind::kLeaf; ++node.num_leaves; @@ -1012,18 +1112,18 @@ PyTreeDef PyTreeDef::MakeFromNodeDataAndChildren( int is_nt = PyObject_IsSubclass(node_data->first.ptr(), reinterpret_cast(&PyTuple_Type)); if (is_nt == -1) { - throw py::error_already_set(); + throw nb::python_error(); } - if (is_nt != 0 && py::hasattr(node_data->first, "_fields")) { + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { node.kind = PyTreeKind::kNamedTuple; node.node_data = node_data->first; return result; } - auto* registration = result.registry()->Lookup(node_data->first); + auto* registration = result->registry()->Lookup(node_data->first); if (registration == nullptr) { - throw std::logic_error( - absl::StrFormat("Could not find type: %s.", - py::repr(node_data->first).cast())); + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); } node.kind = registration->kind; if (node.kind == PyTreeKind::kCustom) { @@ -1032,36 +1132,40 @@ PyTreeDef PyTreeDef::MakeFromNodeDataAndChildren( } else if (node.kind == PyTreeKind::kNamedTuple) { node.node_data = node_data->first; } else if (node.kind == PyTreeKind::kDict) { - node.sorted_dict_keys = node_data->second.cast>(); + node.sorted_dict_keys = + nb::cast>(node_data->second); } return result; } -void BuildPytreeSubmodule(py::module& m) { - py::module pytree = m.def_submodule("pytree", "Python tree library"); - pytree.attr("version") = py::int_(3); +void BuildPytreeSubmodule(nb::module_& m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef"); - py::class_ treedef(pytree, "PyTreeDef"); + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr()); - py::class_> registry( - m, "PyTreeRegistry", py::dynamic_attr()); - registry.def(py::init(), py::kw_only(), - py::arg("enable_none") = true, py::arg("enable_tuple") = true, - py::arg("enable_namedtuple") = true, - py::arg("enable_list") = true, py::arg("enable_dict") = true); + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); registry.def( "flatten", - [](std::shared_ptr registry, pybind11::handle x, - std::optional leaf_predicate) { - std::vector leaves; - PyTreeDef def(std::move(registry)); - def.Flatten(x, leaves, leaf_predicate); - return std::make_pair(std::move(leaves), std::move(def)); + [](std::shared_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); }, - py::arg("tree"), py::arg("leaf_predicate") = std::nullopt); + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); registry.def("register_node", &PyTreeRegistry::Register); registry.def("__reduce__", - [](py::object self) { return self.attr("__name__"); }); + [](nb::object self) { return self.attr("__name__"); }); pytree.def("default_registry", &DefaultPyTreeRegistry); pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); @@ -1069,19 +1173,19 @@ void BuildPytreeSubmodule(py::module& m) { pytree.def("all_leaves", &PyTreeDef::AllLeaves); treedef.def("unflatten", - static_cast(&PyTreeDef::Unflatten)); - treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo); + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); treedef.def("compose", &PyTreeDef::Compose); treedef.def( "walk", &PyTreeDef::Walk, "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " "at leaves", - py::arg("f_node"), py::arg("f_leaf"), py::arg("leaves")); + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); treedef.def("children", &PyTreeDef::Children); - treedef.def_property_readonly("num_leaves", &PyTreeDef::num_leaves); - treedef.def_property_readonly("num_nodes", &PyTreeDef::num_nodes); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); treedef.def("__repr__", &PyTreeDef::ToString); treedef.def("__eq__", [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }); @@ -1091,13 +1195,14 @@ void BuildPytreeSubmodule(py::module& m) { treedef.def("serialize_using_proto", [](const PyTreeDef& a) { jax::PyTreeDefProto result; a.SerializeTo(result); - return py::bytes(result.SerializeAsString()); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); }); treedef.def_static( "deserialize_using_proto", - [](std::shared_ptr registry, py::bytes data) { + [](std::shared_ptr registry, nb::bytes data) { jax::PyTreeDefProto input; - std::string_view serialized = data; + std::string_view serialized(data.c_str(), data.size()); if (serialized.size() > std::numeric_limits::max()) { throw xla::XlaRuntimeError( "Pytree serialization too large to deserialize."); @@ -1107,16 +1212,25 @@ void BuildPytreeSubmodule(py::module& m) { } return PyTreeDef::DeserializeFrom(std::move(registry), input); }, - py::arg("registry"), py::arg("data")); + nb::arg("registry"), nb::arg("data")); treedef.def("node_data", &PyTreeDef::GetNodeData, "Returns None if a leaf-pytree, else (type, node_data)"); treedef.def_static( "make_from_node_data_and_children", - &PyTreeDef::MakeFromNodeDataAndChildren, + &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), "Reconstructs a pytree from `node_data()` and `children()`."); - treedef.def( - py::pickle([](const PyTreeDef& t) { return t.ToPickle(); }, - [](py::object o) { return PyTreeDef::FromPickle(o); })); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); } } // namespace xla diff --git a/xla/python/pytree.h b/xla/python/pytree.h index 521ffde6f8a31..87fb085c3bee8 100644 --- a/xla/python/pytree.h +++ b/xla/python/pytree.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ limitations under the License. // See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation // about pytree. +#include #include #include -#include #include #include #include @@ -30,9 +30,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 // IWYU pragma: keep +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/python/nb_class_ptr.h" #include "xla/python/pytree.pb.h" namespace xla { @@ -52,52 +52,67 @@ class PyTreeRegistry : public std::enable_shared_from_this { public: PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + struct Registration { PyTreeKind kind; // The following values are populated for custom types. // The Python type object, used to identify the type. - pybind11::object type; + nanobind::object type; // A function with signature: object -> (iterable, aux_data) - pybind11::function to_iterable; + nanobind::callable to_iterable; // A function with signature: (aux_data, iterable) -> object - pybind11::function from_iterable; + nanobind::callable from_iterable; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; }; // Registers a new custom type. Objects of `type` will be treated as container // node types in PyTrees. - void Register(pybind11::object type, pybind11::function to_iterable, - pybind11::function from_iterable); + void Register(nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable); // Finds the custom type registration for `type`. Returns nullptr if none // exists. - const Registration* Lookup(pybind11::handle type) const; + const Registration* Lookup(nanobind::handle type) const; - PyTreeKind KindOfObject(pybind11::handle obj, + PyTreeKind KindOfObject(nanobind::handle obj, PyTreeRegistry::Registration const** custom) const; + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + private: struct TypeHash { using is_transparent = void; - size_t operator()(const pybind11::object& t) const { + size_t operator()(const nanobind::object& t) const { return absl::HashOf(t.ptr()); } - size_t operator()(const pybind11::handle& t) const { + size_t operator()(const nanobind::handle& t) const { return absl::HashOf(t.ptr()); } }; struct TypeEq { using is_transparent = void; - bool operator()(const pybind11::object& a, - const pybind11::object& b) const { + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { return a.ptr() == b.ptr(); } - bool operator()(const pybind11::object& a, - const pybind11::handle& b) const { + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { return a.ptr() == b.ptr(); } }; - absl::flat_hash_map, TypeHash, + absl::flat_hash_map, TypeHash, TypeEq> registrations_; bool enable_namedtuple_; @@ -121,55 +136,56 @@ class PyTreeDef { // Flattens a Pytree into a list of leaves and a PyTreeDef. // Returns references to the flattened objects, which might be temporary // objects in the case of custom pytype handlers. - static std::pair, std::unique_ptr> - Flatten(pybind11::handle x, - std::optional leaf_predicate = std::nullopt, + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, + std::optional leaf_predicate = std::nullopt, std::shared_ptr registry = nullptr); // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). // `leaves` owns references to the flattened objects, which might be // temporary objects in the case of custom pytype handlers. - void Flatten(pybind11::handle handle, std::vector& leaves, - std::optional leaf_predicate = std::nullopt); - void Flatten(pybind11::handle handle, - absl::InlinedVector& leaves, - std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); // Tests whether the given list is a flat list of leaves. - static bool AllLeaves(PyTreeRegistry* registry, const pybind11::iterable& x); + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of // the tree-structure of 'x'. For example, if we flatten a value // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the // list of leaves [1, (2, 3), {"foo": 4}]. - pybind11::list FlattenUpTo(pybind11::handle x) const; + nanobind::list FlattenUpTo(nanobind::handle x) const; // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. - pybind11::object Unflatten(pybind11::iterable leaves) const; - pybind11::object Unflatten(absl::Span leaves) const; + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; // Composes two PyTreeDefs, replacing the leaves of this tree with copies of // `inner`. The returned PyTreeDef holds a reference to its registry. - std::unique_ptr Compose(const PyTreeDef& inner) const; + nb_class_ptr Compose(const PyTreeDef& inner) const; // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. - static std::unique_ptr Tuple( - std::shared_ptr registry, - absl::Span defs); + static nb_class_ptr Tuple(std::shared_ptr registry, + nanobind::list defs); // The returned PyTreeDefs hold a reference to the registry. - std::vector> Children() const; + std::vector> Children() const; // Maps a function over a PyTree structure, applying f_leaf to each leaf, and // f_node(node, node_data) to each container node. - pybind11::object Walk(const pybind11::function& f_node, - pybind11::handle f_leaf, - pybind11::iterable leaves) const; + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; // Given a tree of iterables with the same node/leaf structure as this PyTree, // build the corresponding PyTree. // TODO(phawkins): use flattening everywhere instead and delete this method. - pybind11::object FromIterableTree(pybind11::handle xs) const; + nanobind::object FromIterableTree(nanobind::handle xs) const; int num_leaves() const { if (traversal_.empty()) { @@ -191,24 +207,25 @@ class PyTreeDef { // Transforms the PyTreeDef into a pickleable object. Used to implement // `PyTreeDef.__getstate__`. - pybind11::object ToPickle() const; + nanobind::object ToPickle() const; // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used // to implement `PyTreeDef.__setstate__`. - static PyTreeDef FromPickle(pybind11::object pickleable); + void FromPickle(nanobind::object pickleable); void SerializeTo(jax::PyTreeDefProto& result) const; - static PyTreeDef DeserializeFrom(std::shared_ptr registry, - const jax::PyTreeDefProto& input); + static nb_class_ptr DeserializeFrom( + std::shared_ptr registry, + const jax::PyTreeDefProto& input); - std::optional> GetNodeData() + std::optional> GetNodeData() const; - static PyTreeDef MakeFromNodeDataAndChildren( + static nb_class_ptr MakeFromNodeDataAndChildren( std::shared_ptr registry, - std::optional> node_data, - pybind11::iterable children); + std::optional> node_data, + nanobind::iterable children); private: void SetNumLeavesAndNumNodes(); @@ -222,14 +239,14 @@ class PyTreeDef { // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type // object. For a kDict, use `sorted_dict_keys` field below. For a kCustom // type, contains the auxiliary data returned by the `to_iterable` function. - pybind11::object node_data; + nanobind::object node_data; // Kind-specific auxiliary data specialized for kDict. Use a c++ vector // to hold the sorted dict keys instead of a py::list to avoid creating // a new python list object when flattening kDict. For deeply nested dict, // using c++ vector instead of py::list avoids creating too many python // objects that make python gc sweep slow. - std::vector sorted_dict_keys; + std::vector sorted_dict_keys; // Custom type registration. Must be null for non-custom types. const PyTreeRegistry::Registration* custom = nullptr; @@ -247,21 +264,21 @@ class PyTreeDef { friend H AbslHashValue(H h, const PyTreeDef& t); // Helper that manufactures an instance of a node given its children. - static pybind11::object MakeNode(const Node& node, - absl::Span children); + static nanobind::object MakeNode(const Node& node, + absl::Span children); // Recursive helper used to implement FromIterableTree() - pybind11::object FromIterableTreeHelper( - pybind11::handle xs, + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, absl::InlinedVector::const_reverse_iterator* it) const; template - void FlattenImpl(pybind11::handle handle, T& leaves, - const std::optional& leaf_predicate); + void FlattenImpl(nanobind::handle handle, T& leaves, + const std::optional& leaf_predicate); template - pybind11::object UnflattenImpl(T leaves) const; + nanobind::object UnflattenImpl(T leaves) const; // Pytree registry. Not owned. PyTreeRegistry* registry_; @@ -287,7 +304,7 @@ H AbslHashValue(H h, const PyTreeDef& t) { return h; } -void BuildPytreeSubmodule(pybind11::module& m); +void BuildPytreeSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/pytree_test.py b/xla/python/pytree_test.py index 2bba921799e4e..4125d7a28257a 100644 --- a/xla/python/pytree_test.py +++ b/xla/python/pytree_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/refine_polymorphic_shapes.cc b/xla/python/refine_polymorphic_shapes.cc index 31961654cc2e1..465a3d3565b54 100644 --- a/xla/python/refine_polymorphic_shapes.cc +++ b/xla/python/refine_polymorphic_shapes.cc @@ -255,6 +255,7 @@ absl::Status RefinePolymorphicShapes(mlir::ModuleOp module, // TODO(necula): we should not need the inliner. pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::stablehlo::experimental::createChloRecomposeOpsPass()); pm.addPass(mlir::stablehlo::experimental::createStablehloRefineShapesPass()); pm.addNestedPass( mlir::stablehlo::experimental::createStablehloCanonicalizeDynamismPass()); diff --git a/xla/python/sharded_device_array.h b/xla/python/sharded_device_array.h index 8dd9acbb04221..12f3e22327bf8 100644 --- a/xla/python/sharded_device_array.h +++ b/xla/python/sharded_device_array.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,13 @@ limitations under the License. #ifndef XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_ #define XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_ -#include -#include #include #include #include #include "absl/types/variant.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" -#include "xla/python/py_buffer.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep #include "xla/python/types.h" // TODO(jblespiau): The current implementation moves the Python logic to C++, @@ -178,8 +172,8 @@ class ShardingSpec { std::vector mesh_mapping) : sharding_(std::move(sharding)), mesh_mapping_(std::move(mesh_mapping)) {} - ShardingSpec(pybind11::iterable py_sharding, - pybind11::iterable py_mesh_mapping) + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) : sharding_(xla::IterableToVector(py_sharding)), mesh_mapping_( xla::IterableToVector(py_mesh_mapping)) {} diff --git a/xla/python/sharding.cc b/xla/python/sharding.cc index e0d5038776c50..2351f3491bb49 100644 --- a/xla/python/sharding.cc +++ b/xla/python/sharding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,28 +15,33 @@ limitations under the License. #include "xla/python/sharding.h" +#include + #include -#include -#include #include +#include #include +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/detail/common.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/pjrt/pjrt_client.h" +#include "absl/strings/str_join.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" #include "xla/python/py_client.h" #include "xla/python/py_device_list.h" -#include "xla/python/util.h" -#include "xla/statusor.h" +#include "xla/python/sharded_device_array.h" +#include "tsl/platform/logging.h" namespace jax { -namespace py = pybind11; +namespace nb = nanobind; bool (*GetEnableMemories)() = +[] { static bool fetch_memory_kind_on_executable = [] { @@ -49,100 +54,92 @@ bool (*GetEnableMemories)() = +[] { return fetch_memory_kind_on_executable; }; -py::object CheckAndCanonicalizeMemoryKind(py::object memory_kind, - PyDeviceList* device_list) { +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, + const xla::nb_class_ptr& device_list) { if (!memory_kind.is_none()) { // If memory kind is not None, check if it's supported by the devices // mentioned in the Sharding. auto supported_memory_kinds = device_list->MemoryKinds(); if (!supported_memory_kinds.ok()) { - supported_memory_kinds = py::tuple(); + supported_memory_kinds = nb::tuple(); } - for (py::handle supported_memory_kind : *supported_memory_kinds) { + for (nb::handle supported_memory_kind : *supported_memory_kinds) { if (supported_memory_kind.equal(memory_kind)) { return memory_kind; } } - auto device_kind = py::cast( - device_list->AddressableDeviceList()->GetItem(0).attr("device_kind")); - throw py::value_error(absl::StrCat( - "Could not find memory addressable by device ", device_kind, - ". Device ", device_kind, " can address the following memory kinds: ", - py::cast( - py::str(", ").attr("join")(*supported_memory_kinds)), - ". Got memory kind: ", py::cast(memory_kind))); + nb::object device_kind = PyDeviceList::AddressableDeviceList(device_list) + ->GetItem(0) + .attr("device_kind"); + std::string_view device_kind_str = nb::cast(device_kind); + auto py_str_formatter = [](std::string* out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); } // If memory kind is None, canonicalize to default memory. - xla::StatusOr default_memory_kind = + absl::StatusOr default_memory_kind = device_list->DefaultMemoryKind(); if (!default_memory_kind.ok()) { - return py::none(); + return nb::none(); } return *std::move(default_memory_kind); } -int Sharding::SafeNumDevices(pybind11::handle sharding) { +int Sharding::SafeNumDevices(nb::handle sharding) { // Pure python shardings are not initialized, so we should not // even be casting if they are not initialized. - bool is_safe_to_cast = [&]() { - if (!xla::is_pybind_reinterpret_cast_ok(sharding)) { - return false; - } - auto* instance = - reinterpret_cast(sharding.ptr()); - for (auto vh : pybind11::detail::values_and_holders(instance)) { - if (!vh.holder_constructed()) { - return false; - } - } - - return true; - }(); - - if (is_safe_to_cast) { - auto* cpp_sharding = sharding.cast(); + if (nb::inst_check(sharding) && nb::inst_ready(sharding)) { + const auto* cpp_sharding = nb::cast(sharding); if (cpp_sharding->num_devices_.has_value()) { return (*cpp_sharding->num_devices_); } } - pybind11::set device_set = sharding.attr("device_set"); + nb::set device_set = sharding.attr("device_set"); return device_set.size(); } -size_t ShardingHash(const pybind11::object& sharding) { - auto type = sharding.get_type(); +size_t ShardingHash(nb::handle sharding) { + auto type = sharding.type(); if (type.is(NamedSharding::type())) { - const auto* named_sharding = xla::fast_cast(sharding); + const auto* named_sharding = nb::inst_ptr(sharding); return absl::Hash()(named_sharding->mesh().ptr()); } if (type.is(GSPMDSharding::type())) { - auto* gspmd_sharding = xla::fast_cast(sharding); + auto* gspmd_sharding = nb::inst_ptr(sharding); return gspmd_sharding->Hash(); } if (type.is(SingleDeviceSharding::type())) { - auto* single_device_sharding = - xla::fast_cast(sharding); + auto* single_device_sharding = nb::inst_ptr(sharding); return absl::Hash()(single_device_sharding->device().ptr()); } - return py::hash(sharding); + return xla::nb_hash(sharding); } -bool ShardingEqual(const pybind11::object& a, const pybind11::object& b) { +bool ShardingEqual(nb::handle a, nb::handle b) { if (a.ptr() == b.ptr()) return true; - auto a_type = a.get_type(); - auto b_type = b.get_type(); + auto a_type = a.type(); + auto b_type = b.type(); if (!a_type.is(b_type)) return false; if (a_type.is(NamedSharding::type())) { - auto* a_named_sharding = xla::fast_cast(a); - auto* b_named_sharding = xla::fast_cast(b); + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && a_named_sharding->spec().equal(b_named_sharding->spec()) && @@ -153,17 +150,17 @@ bool ShardingEqual(const pybind11::object& a, const pybind11::object& b) { } if (a_type.is(GSPMDSharding::type())) { - auto* a_gspmd_sharding = xla::fast_cast(a); - auto* b_gspmd_sharding = xla::fast_cast(b); + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); return a_gspmd_sharding == b_gspmd_sharding; } if (a_type.is(SingleDeviceSharding::type())) { auto* a_single_device_sharding = - xla::fast_cast(a); + nb::inst_ptr(a); auto* b_single_device_sharding = - xla::fast_cast(b); + nb::inst_ptr(b); return a_single_device_sharding->device().ptr() == b_single_device_sharding->device().ptr() && @@ -174,50 +171,11 @@ bool ShardingEqual(const pybind11::object& a, const pybind11::object& b) { return a.equal(b); } -xla::ClientAndPtr GetMemory( - const xla::ClientAndPtr& device, const std::string& kind) { - xla::PjRtMemorySpace* result_memory_space = nullptr; - for (auto* memory_space : device->memory_spaces()) { - if (memory_space->memory_space_kind() == kind) { - if (result_memory_space != nullptr) { - std::string memories = absl::StrJoin( - device->memory_spaces(), ", ", - [](std::string* out, const auto& memory_space) { - absl::StrAppend(out, memory_space->memory_space_kind()); - }); - auto device_kind = device->device_kind(); - xla::ThrowIfError( - xla::InvalidArgument("Found more than one addressable memory for " - "kind %s which is not allowed. There can only " - "be one memory for each " - "kind. Device %s can address the following " - "memory kinds: %s", - kind, device_kind, memories)); - } - result_memory_space = memory_space; - } - } - if (result_memory_space == nullptr) { - std::string memories = - absl::StrJoin(device->memory_spaces(), ", ", - [](std::string* out, const auto& memory_space) { - absl::StrAppend(out, memory_space->memory_space_kind()); - }); - auto device_kind = device->device_kind(); - xla::ThrowIfError(xla::InvalidArgument( - "Could not find memory addressable by device %s. Device %s " - "can address the following memory kinds: %s. " - "Got memory kind: %s", - device_kind, device_kind, memories, kind)); - } - return WrapWithClient(device.client(), result_memory_space); -} - -NamedSharding::NamedSharding(py::object mesh, py::object spec, - py::object memory_kind, py::object parsed_pspec, - py::object manual_axes) +NamedSharding::NamedSharding(nb::object mesh, nb::object spec, + nb::object memory_kind, nb::object parsed_pspec, + nb::object manual_axes) : XLACompatibleSharding(/*num_devices=*/[&mesh]() { - py::array devices = mesh.attr("devices"); + xla::nb_numpy_ndarray devices = mesh.attr("devices"); return devices.size(); }()), mesh_(std::move(mesh)), @@ -225,127 +183,132 @@ NamedSharding::NamedSharding(py::object mesh, py::object spec, memory_kind_(std::move(memory_kind)), parsed_pspec_(std::move(parsed_pspec)), manual_axes_(std::move(manual_axes)) { - py::cast(this).attr("_preprocess")(); - internal_device_list_ = py::cast>( - mesh_.attr("_internal_device_list")); + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + internal_device_list_ = nb::cast>( + nb::object(mesh_.attr("_internal_device_list"))); memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_.get()); + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); + + nb::module_ si = nb::module_::import_("jax._src.sharding_impls"); + parsed_pspec_ = si.attr("preprocess")(mesh_, spec_, parsed_pspec_); } -SingleDeviceSharding::SingleDeviceSharding(py::object device, - py::object memory_kind) +SingleDeviceSharding::SingleDeviceSharding(nb::object device, + nb::object memory_kind) : XLACompatibleSharding(/*num_devices=*/1), device_(device), memory_kind_(std::move(memory_kind)), - internal_device_list_(std::make_shared( - pybind11::make_tuple(std::move(device)))) { + internal_device_list_( + xla::make_nb_class(nb::make_tuple(std::move(device)))) { memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_.get()); + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } SingleDeviceSharding::SingleDeviceSharding( - std::shared_ptr client, xla::ifrt::DeviceList device_list, - pybind11::object memory_kind) + xla::nb_class_ptr client, xla::ifrt::DeviceList device_list, + nb::object memory_kind) : XLACompatibleSharding(/*num_devices=*/1), - device_(py::cast(WrapWithClient(client, device_list.front()))), + device_(client->GetPyDevice(device_list.front())), memory_kind_(std::move(memory_kind)), - internal_device_list_(std::make_shared( + internal_device_list_(xla::make_nb_class( std::move(client), std::move(device_list))) { memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_.get()); + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } -PmapSharding::PmapSharding(py::array devices, ShardingSpec sharding_spec) +PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, + ShardingSpec sharding_spec) : XLACompatibleSharding(/*num_devices=*/devices.size()), devices_(std::move(devices)), - sharding_spec_(std::move(sharding_spec)), - internal_device_list_(std::make_shared( - py::cast(devices_.attr("flat")))) {} + sharding_spec_(std::move(sharding_spec)) { + nb::object flat_devices = devices_.attr("flat"); + internal_device_list_ = + xla::make_nb_class(nb::tuple(flat_devices)); +} -GSPMDSharding::GSPMDSharding(py::tuple devices, xla::HloSharding op_sharding, - py::object memory_kind) - : XLACompatibleSharding(/*num_devices=*/devices.size()), - devices_(std::move(devices)), +GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, + nb::object memory_kind, nb::object device_list) + : XLACompatibleSharding(/*num_devices=*/nb::len(devices.ptr())), + devices_(nb::tuple(devices)), hlo_sharding_(std::move(op_sharding)), - memory_kind_(std::move(memory_kind)), - internal_device_list_(std::make_shared(devices_)) { + memory_kind_(std::move(memory_kind)) { + if (device_list.is_none()) { + internal_device_list_ = xla::make_nb_class(devices_); + } else { + internal_device_list_ = + nb::cast>(std::move(device_list)); + } // This checks in python if the memory kind is correct for the given // devices. Currently in python this check is optimized but we want to // move that check to C++ after which we can remove this call. - CHECK(!devices_.empty()) + CHECK(devices_.size() != 0) << "Devices given to GSPMDSharding must not be empty"; memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_.get()); + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } -void RegisterSharding(py::module& m) { - py::object abc_module = py::module::import("abc"); - py::object abc_meta = abc_module.attr("ABCMeta"); - py::object abc_init = abc_module.attr("_abc_init"); - - // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_(m, "Sharding", py::metaclass(abc_meta)); - abc_init(py::type::of()); - - // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_(m, "XLACompatibleSharding", - py::metaclass(abc_meta)); - abc_init(py::type::of()); - - py::class_(m, "NamedSharding", - py::dynamic_attr()) - .def(py::init(), - py::arg("mesh"), py::arg("spec"), py::kw_only(), - py::arg("memory_kind") = py::none(), - py::arg("_parsed_pspec") = py::none(), - py::arg("_manual_axes") = py::frozenset(py::set())) - .def_property_readonly("mesh", &NamedSharding::mesh) - .def_property_readonly("spec", &NamedSharding::spec) - .def_property_readonly("_memory_kind", &NamedSharding::memory_kind) - .def_property_readonly("_manual_axes", &NamedSharding::manual_axes) - .def_property("_parsed_pspec", &NamedSharding::parsed_pspec, - &NamedSharding::set_parsed_pspec) - .def_property_readonly("_internal_device_list", - &NamedSharding::internal_device_list); - - py::class_( - m, "SingleDeviceSharding", py::dynamic_attr()) - .def(py::init(), py::arg("device"), py::kw_only(), - py::arg("memory_kind") = py::none()) - .def_property_readonly("_device", &SingleDeviceSharding::device) - .def_property_readonly("_memory_kind", &SingleDeviceSharding::memory_kind) - .def_property_readonly("_internal_device_list", - &SingleDeviceSharding::internal_device_list); - - py::class_(m, "PmapSharding", - py::dynamic_attr()) - .def(py::init(), py::arg("devices"), - py::arg("sharding_spec")) - .def_property_readonly("devices", &PmapSharding::devices) - .def_property_readonly("sharding_spec", &PmapSharding::sharding_spec) - .def_property_readonly("_internal_device_list", - &PmapSharding::internal_device_list); - - py::class_(m, "GSPMDSharding", - py::dynamic_attr()) - .def(py::init(), - py::arg("devices"), py::arg("op_sharding"), py::kw_only(), - py::arg("memory_kind") = py::none()) - .def(py::init(), - py::arg("devices"), py::arg("op_sharding"), py::kw_only(), - py::arg("memory_kind") = py::none()) - .def(py::init(), - py::arg("devices"), py::arg("op_sharding"), py::kw_only(), - py::arg("memory_kind") = py::none()) - .def(py::init(), - py::arg("devices"), py::arg("op_sharding"), py::kw_only(), - py::arg("memory_kind") = py::none()) - .def_property_readonly("_devices", &GSPMDSharding::devices) - .def_property_readonly("_hlo_sharding", &GSPMDSharding::hlo_sharding) - .def_property_readonly("_memory_kind", &GSPMDSharding::memory_kind) - .def_property_readonly("_internal_device_list", - &GSPMDSharding::internal_device_list); +void RegisterSharding(nb::module_& m) { + nb::class_(m, "Sharding").def(nb::init<>()); + + nb::class_(m, "XLACompatibleSharding") + .def(nb::init<>()); + + nb::class_(m, "NamedSharding", + nb::dynamic_attr()) + .def(nb::init(), + nb::arg("mesh"), nb::arg("spec").none(), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_parsed_pspec").none() = nb::none(), + nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr))) + .def_prop_ro("mesh", &NamedSharding::mesh) + .def_prop_ro("spec", &NamedSharding::spec) + .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) + .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) + .def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec, + &NamedSharding::set_parsed_pspec) + .def_prop_ro("_internal_device_list", + &NamedSharding::internal_device_list); + + nb::class_( + m, "SingleDeviceSharding", nb::dynamic_attr()) + .def(nb::init(), nb::arg("device"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_device", &SingleDeviceSharding::device) + .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &SingleDeviceSharding::internal_device_list); + + nb::class_(m, "PmapSharding", + nb::dynamic_attr()) + .def( + "__init__", + [](PmapSharding* self, nb::object devices, + ShardingSpec sharding_spec) { + new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices), + std::move(sharding_spec)); + }, + nb::arg("devices"), nb::arg("sharding_spec")) + .def_prop_ro("devices", &PmapSharding::devices) + .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) + .def_prop_ro("_internal_device_list", + &PmapSharding::internal_device_list); + + nb::class_(m, "GSPMDSharding", + nb::dynamic_attr()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def_prop_ro("_devices", &GSPMDSharding::devices) + .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding) + .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &GSPMDSharding::internal_device_list); } } // namespace jax diff --git a/xla/python/sharding.h b/xla/python/sharding.h index 0e780b10c3783..8037594e30781 100644 --- a/xla/python/sharding.h +++ b/xla/python/sharding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,25 +16,21 @@ limitations under the License. #ifndef XLA_PYTHON_SHARDING_H_ #define XLA_PYTHON_SHARDING_H_ -#include +#include #include -#include #include -#include -#include // placeholder for index annotation headers -#include "absl/types/span.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" +#include "absl/hash/hash.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" #include "xla/python/py_client.h" #include "xla/python/py_device_list.h" #include "xla/python/sharded_device_array.h" -#include "xla/python/status_casters.h" #include "xla/xla_data.pb.h" namespace jax { @@ -49,7 +45,7 @@ class Sharding { virtual ~Sharding() = default; - static int SafeNumDevices(pybind11::handle sharding); + static int SafeNumDevices(nanobind::handle sharding); private: std::optional num_devices_; @@ -59,18 +55,16 @@ extern bool (*GetEnableMemories)(); // Checks if the memory kind is valid, and canonicalizes the // memory kind to default memory on backends that support memories. -pybind11::object CheckAndCanonicalizeMemoryKind(pybind11::object memory_kind, - PyDeviceList* device_list); +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr& device_list); // Returns a hash that may sometimes return different hashes for equal values. // It is not a correct implementation of `__hash__` in python, but it's fine // for jit/pjit dispatch since it only causes spurious cache misses. -size_t ShardingHash(const pybind11::object& sharding); +size_t ShardingHash(nanobind::handle sharding); -bool ShardingEqual(const pybind11::object& a, const pybind11::object& b); - -xla::ClientAndPtr GetMemory( - const xla::ClientAndPtr& device, const std::string& kind); +bool ShardingEqual(nanobind::handle a, nanobind::handle b); class XLACompatibleSharding : public Sharding { public: @@ -81,118 +75,106 @@ class XLACompatibleSharding : public Sharding { class NamedSharding : public XLACompatibleSharding { public: - NamedSharding(pybind11::object mesh, pybind11::object spec, - pybind11::object memory_kind, pybind11::object parsed_pspec, - pybind11::object manual_axes); - - const pybind11::object& mesh() const { return mesh_; } - const pybind11::object& spec() const { return spec_; } - const pybind11::object& memory_kind() const { return memory_kind_; } - const pybind11::object& parsed_pspec() const { return parsed_pspec_; } - const pybind11::object& manual_axes() const { return manual_axes_; } - void set_parsed_pspec(pybind11::object parsed_pspec) { + NamedSharding(nanobind::object mesh, nanobind::object spec, + nanobind::object memory_kind, nanobind::object parsed_pspec, + nanobind::object manual_axes); + + const nanobind::object& mesh() const { return mesh_; } + const nanobind::object& spec() const { return spec_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + const nanobind::object& parsed_pspec() const { return parsed_pspec_; } + const nanobind::object& manual_axes() const { return manual_axes_; } + void set_parsed_pspec(nanobind::object parsed_pspec) { parsed_pspec_ = std::move(parsed_pspec); } - static pybind11::handle type() { - static auto type = pybind11::type::handle_of(); + static nanobind::handle type() { + static auto type = nanobind::type(); return type; } - std::shared_ptr internal_device_list() const { + xla::nb_class_ptr internal_device_list() const { return internal_device_list_; } private: - pybind11::object mesh_; - pybind11::object spec_; - pybind11::object memory_kind_; - pybind11::object parsed_pspec_; - pybind11::object manual_axes_; - std::shared_ptr internal_device_list_; + nanobind::object mesh_; + nanobind::object spec_; + nanobind::object memory_kind_; + nanobind::object parsed_pspec_; + nanobind::object manual_axes_; + xla::nb_class_ptr internal_device_list_; }; class SingleDeviceSharding : public XLACompatibleSharding { public: explicit SingleDeviceSharding( - pybind11::object device, pybind11::object memory_kind = pybind11::none()); + nanobind::object device, nanobind::object memory_kind = nanobind::none()); // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. - SingleDeviceSharding(std::shared_ptr client, + SingleDeviceSharding(xla::nb_class_ptr client, xla::ifrt::DeviceList device_list, - pybind11::object memory_kind); + nanobind::object memory_kind); - const pybind11::object& device() const { return device_; } - const pybind11::object& memory_kind() const { return memory_kind_; } + const nanobind::object& device() const { return device_; } + const nanobind::object& memory_kind() const { return memory_kind_; } - static pybind11::handle type() { - static auto type = pybind11::type::handle_of(); + static nanobind::handle type() { + static auto type = nanobind::type(); return type; } - std::shared_ptr internal_device_list() const { + xla::nb_class_ptr internal_device_list() const { return internal_device_list_; } private: - pybind11::object device_; - pybind11::object memory_kind_; - std::shared_ptr internal_device_list_; + nanobind::object device_; + nanobind::object memory_kind_; + xla::nb_class_ptr internal_device_list_; }; // The C++ implementation of jax.PmapSharding in python. It contains a few key // data members and methods that are performance-critical. class PmapSharding : public XLACompatibleSharding { public: - PmapSharding(pybind11::array devices, ShardingSpec sharding_spec); + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); ~PmapSharding() override = default; - pybind11::array devices() const { return devices_; } + xla::nb_numpy_ndarray devices() const { return devices_; } const ShardingSpec& sharding_spec() const { return sharding_spec_; } - static pybind11::handle type() { - static auto type = pybind11::type::handle_of(); + static nanobind::handle type() { + static auto type = nanobind::type(); return type; } - std::shared_ptr internal_device_list() const { + xla::nb_class_ptr internal_device_list() const { return internal_device_list_; } private: - pybind11::array devices_; + xla::nb_numpy_ndarray devices_; ShardingSpec sharding_spec_; - std::shared_ptr internal_device_list_; + xla::nb_class_ptr internal_device_list_; }; class GSPMDSharding : public XLACompatibleSharding { public: - GSPMDSharding(pybind11::list devices, xla::OpSharding op_sharding, - pybind11::object memory_kind = pybind11::none()) - : GSPMDSharding( - pybind11::tuple(devices), - xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), - std::move(memory_kind)) {} - - GSPMDSharding(pybind11::tuple devices, xla::OpSharding op_sharding, - pybind11::object memory_kind = pybind11::none()) + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list) : GSPMDSharding( std::move(devices), xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), - std::move(memory_kind)) {} - - GSPMDSharding(pybind11::list devices, xla::HloSharding op_sharding, - pybind11::object memory_kind = pybind11::none()) - : GSPMDSharding(pybind11::tuple(devices), std::move(op_sharding), - std::move(memory_kind)) {} + std::move(memory_kind), std::move(device_list)) {} - GSPMDSharding(pybind11::tuple devices, xla::HloSharding op_sharding, - pybind11::object memory_kind = pybind11::none()); + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list); - const pybind11::tuple& devices() const { return devices_; } - const pybind11::object& memory_kind() const { return memory_kind_; } + const nanobind::tuple& devices() const { return devices_; } + const nanobind::object& memory_kind() const { return memory_kind_; } size_t Hash() { if (!hash_.has_value()) { @@ -201,8 +183,8 @@ class GSPMDSharding : public XLACompatibleSharding { return *hash_; } - static pybind11::handle type() { - static auto type = pybind11::type::handle_of(); + static nanobind::handle type() { + static auto type = nanobind::type(); return type; } @@ -214,7 +196,7 @@ class GSPMDSharding : public XLACompatibleSharding { this->memory_kind().equal(other.memory_kind()); } - std::shared_ptr internal_device_list() const { + xla::nb_class_ptr internal_device_list() const { return internal_device_list_; } @@ -246,14 +228,14 @@ class GSPMDSharding : public XLACompatibleSharding { return hlo_sharding().IsReplicated(); } - pybind11::tuple devices_; + nanobind::tuple devices_; xla::HloSharding hlo_sharding_; - pybind11::object memory_kind_; + nanobind::object memory_kind_; std::optional hash_; - std::shared_ptr internal_device_list_; + xla::nb_class_ptr internal_device_list_; }; -void RegisterSharding(pybind11::module& m); +void RegisterSharding(nanobind::module_& m); } // namespace jax diff --git a/xla/python/status_casters.h b/xla/python/status_casters.h deleted file mode 100644 index 1b7020005360d..0000000000000 --- a/xla/python/status_casters.h +++ /dev/null @@ -1,220 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_STATUS_CASTERS_H_ -#define XLA_PYTHON_STATUS_CASTERS_H_ - -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "tsl/platform/macros.h" - -namespace xla { - -// C++ -> Python caster helpers. -// -// Failing statuses become Python exceptions; OK Status() becomes None. -// -// Given there can be only a single global pybind11 type_caster for the -// `absl::Status` type, and given XLA wants a custom exception being raised, -// we use a dedicated helper to implement this feature without relying on a -// global `type_caster`. -// -// For example: -// -// - Functions without arguments: -// m.def("my_func", []() { xla::ThrowIfError(MyFunc()); } -// - Classes with a single argument: -// py_class.def("delete", [](Buffer& self) { -// xla::ThrowIfError(self.Delete()); -// } -// -// For functions with more arguments, you can either inline the arguments, -// or use the `ThrowIfErrorWrapper` wrapper defined below: -// -// m.def("my_func", xla::ThrowIfErrorWrapper(MyFunc)); -// -// Nonstatic member functions can be wrapped by passing a -// pointer-to-member-function: -// xla::ThrowIfErrorWrapper(&MyClass::MyMethod) - -inline void ThrowIfError(xla::Status src) { - if (!src.ok()) { - throw xla::XlaRuntimeError(src); - } -} - -// If one does not want to have to define a lambda specifying the inputs -// arguments, on can use the `ThrowIfErrorWrapper` wrapper. -// -// There are three specializations: -// - For free functions, `Sig` is the function type and `F` is `Sig&`. -// - For callable types, `Sig` is the pointer to member function type -// and `F` is the type of the callable. -// - For a nonstatic member function of a class `C`, `Sig` is the function type -// and `F` is Sig C::*. -// -// In the first two cases, the wrapper returns a callable with signature `Sig`; -// in the third case, the wrapper returns callable with a modified signature -// that takes a C instance as the first argument. -template -struct ThrowIfErrorWrapper; - -// C++17 "deduction guide" that guides class template argument deduction (CTAD) -// For free functions. -template -ThrowIfErrorWrapper(F) -> ThrowIfErrorWrapper; - -// For callable types (with operator()). -template -ThrowIfErrorWrapper(xla::Status (&)(Args...)) - -> ThrowIfErrorWrapper; - -// For unbound nonstatic member functions. -template -ThrowIfErrorWrapper(xla::Status (C::*)(Args...)) - -> ThrowIfErrorWrapper; - -// Template specializations. - -// For free functions. -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (&f)(Args...)) : func(f) {} - void operator()(Args... args) { - xla::ThrowIfError(func(std::forward(args)...)); - } - xla::Status (&func)(Args...); -}; - -// For callable types (with operator()), non-const and const versions. -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} - void operator()(Args... args) { - xla::ThrowIfError(func(std::forward(args)...)); - } - F func; -}; -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(F&& f) : func(std::move(f)) {} - void operator()(Args... args) const { - xla::ThrowIfError(func(std::forward(args)...)); - } - F func; -}; - -// For unbound nonstatic member functions, non-const and const versions. -// `ptmf` stands for "pointer to member function". -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...)) : ptmf(ptmf) {} - void operator()(C& instance, Args... args) { - xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); - } - xla::Status (C::*ptmf)(Args...); -}; -template -struct ThrowIfErrorWrapper { - explicit ThrowIfErrorWrapper(xla::Status (C::*ptmf)(Args...) const) - : ptmf(ptmf) {} - void operator()(const C& instance, Args... args) const { - xla::ThrowIfError((instance.*ptmf)(std::forward(args)...)); - } - xla::Status (C::*ptmf)(Args...) const; -}; - -// Utilities for `StatusOr`. -template -T ValueOrThrow(StatusOr v) { - if (!v.ok()) { - throw xla::XlaRuntimeError(v.status()); - } - return std::move(v).value(); -} - -template -struct ValueOrThrowWrapper; - -template -ValueOrThrowWrapper(F) -> ValueOrThrowWrapper; - -template -ValueOrThrowWrapper(xla::StatusOr (&)(Args...)) - -> ValueOrThrowWrapper(Args...), - xla::StatusOr (&)(Args...)>; - -template -ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...)) - -> ValueOrThrowWrapper(Args...), C>; - -// Deduction guide for const methods. -template -ValueOrThrowWrapper(xla::StatusOr (C::*)(Args...) const) - -> ValueOrThrowWrapper(Args...) const, C>; - -template -struct ValueOrThrowWrapper(Args...), - xla::StatusOr (&)(Args...)> { - explicit ValueOrThrowWrapper(xla::StatusOr (&f)(Args...)) : func(f) {} - R operator()(Args... args) { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - xla::StatusOr (&func)(Args...); -}; -template -struct ValueOrThrowWrapper (C::*)(Args...), F> { - explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} - R operator()(Args... args) { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - F func; -}; -template -struct ValueOrThrowWrapper (C::*)(Args...) const, F> { - explicit ValueOrThrowWrapper(F&& f) : func(std::move(f)) {} - R operator()(Args... args) const { - return xla::ValueOrThrow(func(std::forward(args)...)); - } - F func; -}; - -// For unbound nonstatic member functions, non-const and const versions. -// `ptmf` stands for "pointer to member function". -template -struct ValueOrThrowWrapper(Args...), C> { - explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...)) - : ptmf(ptmf) {} - R operator()(C& instance, Args... args) { - return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); - } - xla::StatusOr (C::*ptmf)(Args...); -}; -template -struct ValueOrThrowWrapper(Args...) const, C> { - explicit ValueOrThrowWrapper(xla::StatusOr (C::*ptmf)(Args...) const) - : ptmf(ptmf) {} - R operator()(const C& instance, Args... args) const { - return xla::ValueOrThrow((instance.*ptmf)(std::forward(args)...)); - } - xla::StatusOr (C::*ptmf)(Args...) const; -}; - -} // namespace xla - -#endif // XLA_PYTHON_STATUS_CASTERS_H_ diff --git a/xla/python/status_casters_ext.cc b/xla/python/status_casters_ext.cc index 8f903da84427e..2b18b8b15f478 100644 --- a/xla/python/status_casters_ext.cc +++ b/xla/python/status_casters_ext.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,45 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" -#include "xla/python/status_casters.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" namespace xla { -namespace py = ::pybind11; +namespace nb = ::nanobind; namespace { -xla::Status MyFunc() { return xla::OkStatus(); } +absl::Status MyFunc() { return absl::OkStatus(); } class MyClass { public: - xla::Status MyMethod(int a, int b) { return xla::OkStatus(); } - xla::Status MyMethodConst(int a, int b) const { return xla::OkStatus(); } + absl::Status MyMethod(int a, int b) { return absl::OkStatus(); } + absl::Status MyMethodConst(int a, int b) const { return absl::OkStatus(); } - xla::StatusOr MyStatusOrMethod(int a, int b) { return a + b; } - xla::StatusOr MyStatusOrMethodConst(int a, int b) const { return a + b; } + absl::StatusOr MyStatusOrMethod(int a, int b) { return a + b; } + absl::StatusOr MyStatusOrMethodConst(int a, int b) const { + return a + b; + } }; -xla::StatusOr StatusOrIdentity(int i) { return i; } +absl::StatusOr StatusOrIdentity(int i) { return i; } -PYBIND11_MODULE(status_casters_ext, m) { +NB_MODULE(status_casters_ext, m) { // Exceptions - py::register_exception(m, "XlaRuntimeError", - PyExc_RuntimeError); + nb::exception(m, "XlaRuntimeError", PyExc_RuntimeError); m.def("my_lambda", - xla::ThrowIfErrorWrapper([]() { return xla::OkStatus(); })); + xla::ThrowIfErrorWrapper([]() { return absl::OkStatus(); })); m.def("my_lambda2", xla::ThrowIfErrorWrapper(MyFunc)); m.def("my_lambda_statusor", - xla::ValueOrThrowWrapper([]() -> xla::StatusOr { return 1; })); + xla::ValueOrThrowWrapper([]() -> absl::StatusOr { return 1; })); m.def("status_or_identity", xla::ValueOrThrowWrapper(StatusOrIdentity)); - py::class_ my_class(m, "MyClass"); - my_class.def(py::init<>()); + nb::class_ my_class(m, "MyClass"); + my_class.def(nb::init<>()); my_class.def("my_method", xla::ThrowIfErrorWrapper(&MyClass::MyMethod)); my_class.def("my_method_const", xla::ThrowIfErrorWrapper(&MyClass::MyMethod)); my_class.def("my_method_status_or", diff --git a/xla/python/status_casters_test.py b/xla/python/status_casters_test.py index 62901ec28ab25..cf15b238197ec 100644 --- a/xla/python/status_casters_test.py +++ b/xla/python/status_casters_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/tools/BUILD b/xla/python/tools/BUILD new file mode 100644 index 0000000000000..aa81e68cf3d92 --- /dev/null +++ b/xla/python/tools/BUILD @@ -0,0 +1,94 @@ +load("@tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") + +# NOTE: We can't use `pytype_pybind_extension` nor `pytype_strict_contrib_test` +# because the OSS versions of these files do not include ports of those rules. +# We must instead use `tsl_pybind_extension` and `py_strict_test`. +load("//xla:pytype.default.bzl", "pytype_strict_library") +load("//xla:strict.default.bzl", "py_strict_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +exports_files([ + "__init__.py", + "types.py", + "_types.pyi", +]) + +# NOTE: This wrapper library is necessary in order to capture the Python +# dependencies of our extension (namely `ml_dtypes`). Although the +# underlying `pybind_extension` rule has a `py_deps` argument for capturing +# such dependencies directly, the `tsl_pybind_extension` rule doesn't expose +# that `py_deps` argument for us to use. +# +# NOTE: On the OSS side, the `pytype_strict_library` rule is changed into +# the non-typed rule, which in turn causes an error about the `pytype_srcs` +# field. The "..:xla_client" target gets around this by adding a custom +# copybara rule; but in lieu of adding yet another custom rule to maintain, +# we just use the generic copybara mechanism for commenting the field out +# on the OSS side. +# TODO(wrengr,phawkins): Once cl/619904840 lands, we can remove the +# pragma and the preceding commentary. +pytype_strict_library( + name = "types", + srcs = ["types.py"], + # copybara:uncomment pytype_srcs = ["_types.pyi"], + srcs_version = "PY3", + # Cannot build this on OSS because the ":xla_data_proto_py_pb2" + # dependency isn't part of the public API. + tags = ["no_oss"], + visibility = ["//visibility:public"], + deps = [ + ":_types", # buildcleaner: keep + "//third_party/py/numpy", + "//xla:xla_data_proto_py_pb2", + "@ml_dtypes", + ], +) + +# NOTE: Copybara detects the `tsl_pybind_extension` rule and automatically +# injects the "@com_google_protobuf//:protobuf_python" python dependency +# required by "@pybind11_protobuf//pybind11_protobuf:native_proto_caster". +tsl_pybind_extension( + name = "_types", + srcs = ["_types.cc"], + pytype_deps = ["//third_party/py/numpy"], + pytype_srcs = ["_types.pyi"], + # Users should depend on ":types" instead. + visibility = ["//visibility:private"], + deps = [ + "//xla:literal", + "//xla:xla_data_proto_cc", + "//xla/pjrt:status_casters", + "//xla/python:logging", + "//xla/python:nb_numpy", + "//xla/python:types", + "//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/strings", + "@nanobind", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +) + +py_strict_test( + name = "types_test", + size = "small", + srcs = ["types_test.py"], + python_version = "PY3", + srcs_version = "PY3", + # Cannot build this on OSS because the ":xla_data_proto_py_pb2" + # dependency isn't part of the public API. + tags = ["no_oss"], + deps = [ + ":types", + "//third_party/py/google/protobuf:use_fast_cpp_protos", # Automatically added go/proto_python_upb_flip + "//third_party/py/numpy", + "//xla:xla_data_proto_py_pb2", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/xla/python/tools/__init__.py b/xla/python/tools/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/xla/python/tools/_types.cc b/xla/python/tools/_types.cc new file mode 100644 index 0000000000000..1f3cde76cccc3 --- /dev/null +++ b/xla/python/tools/_types.cc @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_cat.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/numpy.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +// The "third_party/pybind11_abseil/status_casters.h" header says +// it's deprecated and that we should import the other headers directly. +#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/logging.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/xla_data.pb.h" +// NOTE: The tsl-numpy header forbids importing the actual NumPy arrayobject.h +// header before tsl-numpy (whereas, importing pybind11-numpy before tsl-numpy +// is fine); however, tsl-numpy does reexport NumPy's arrayobject.h header. +// Since one of the TF headers above already includes tsl-numpy, therefore +// we must include it down here rather than including actual NumPy directly. +#include "xla/tsl/python/lib/core/numpy.h" + +namespace py = ::pybind11; +namespace nb = ::nanobind; + +namespace { +py::object MakeNdarray(const xla::LiteralProto& proto) { + auto m_lit = xla::Literal::CreateFromProto(proto); + if (!m_lit.ok()) { + // NOTE: The OSS version of XLA is still using an old version of + // Abseil (LTS branch, Aug 2023, Patch 1) which does not have the + // `AbslStringify` interface for implicitly converting `absl::Status` + // into the `absl::AlphaNum` required by `absl::StrCat`. Therefore we + // inline the latest definition of the `AbslStringify` overload. + throw py::value_error(absl::StrCat( + "Cannot `xla::Literal::CreateFromProto`: ", + m_lit.status().ToString(absl::StatusToStringMode::kWithEverything))); + } + + // Move (not copy) the literal onto the heap, for sharing with Python. + auto lit = std::make_shared(std::move(m_lit).value()); + + auto nbobj = xla::ValueOrThrow(xla::LiteralToPython(std::move(lit))); + return py::reinterpret_steal(nbobj.release().ptr()); +} + +// Partial reversion of cl/617156835, until we can get the proto-casters +// (and hence the extension) switched over to nanobind. +// TODO(wrengr): Or can we mix `{py,nb}::module_::def` calls?? +xla::PrimitiveType DtypeToEtype(const py::dtype& py_d) { + auto nb_d = nb::borrow(py_d.ptr()); + return xla::ValueOrThrow(xla::DtypeToPrimitiveType(nb_d)); +} + +py::dtype EtypeToDtype(xla::PrimitiveType p) { + auto nb_d = xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(p)); + return py::reinterpret_steal(nb_d.release().ptr()); +} +} // namespace + +// NOTE: It seems insurmountable to get "native_proto_caster.h" to work +// with nanobind modules; therefore, we define our extension as a pybind11 +// module so that we can use `pybind11::module_::def`. +PYBIND11_MODULE(_types, py_m) { + // Initialize ABSL logging because code within XLA uses it. + // (As per `xla::Init` in "xla.cc"; though we don't need it ourselves.) +#ifndef PLATFORM_GOOGLE + xla::InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // Normally this would happen at the start of NB_MODULE, but since + // this is a pybind11 module we have to do this ourselves. + // (As per `xla::Init` in "xla.cc".) + nb::detail::init(NB_DOMAIN_STR); + + // Import implicit conversions from Python protobuf objects to C++ + // protobuf objects. + pybind11_protobuf::ImportNativeProtoCasters(); + + // Import dependencies for converting `absl::StatusOr` to Python exceptions. + // This also brings into scope pybind11 casters for doing conversions + // implicitly; however, towards the goal of converting everything to + // nanobind, we call `xla::ValueOrThrow` to make make the conversions + // explicit (since `nb::detail::type_caster` disallows raising exceptions, + // and therefore nanobind cannot do this implicitly). + py::google::ImportStatusModule(); + + // Import the 'ml_dtypes' module; which is implicitly required by + // `xla::LiteralToPython`. + // NOTE: If the `tsl_pybind_extension` build rule allowed us to specify + // this as a py_dep, then importing the module here would mean that + // client Python code need not import the hidden dependency themselves. + // However, since `tsl_pybind_extension` does not allow specifying py_deps, + // if client rules do not themselves declare the dependency then this will + // generate a `ModuleNotFoundError` / `ImportError` exception. Hence why + // we define the "types.py" wrapper library to encapsulate the dependency. + py::module_::import("ml_dtypes"); + + // Ensure that tsl-numpy initializes datastructures of the actual-NumPy + // implementation, and does whatever else tsl-numpy needs. This is + // also necessary for using the `xla::nb_dtype` type. + tsl::ImportNumpy(); + + // Declare that C++ can `nb::cast` from `std::shared_ptr` + // to `nb::object`; which is implicitly required by `xla::LiteralToPython`. + // (FWIW: This also enables using `nb::type()` to get + // the Python-type-object associated with the C++ class.) + // + // NOTE: This does *not* mean that C++ can `py::cast` from `xla::Literal` + // to `py::object`. It's unclear whether we can simultaneously provide + // both nanobind and pybind11 bindings (if we wanted the latter). + nb::module_ nb_m = nb::cast(nb::borrow(py_m.ptr())); + nb::class_(nb_m, "Literal") + .def("__repr__", &xla::Literal::ToString); + + // We do not define `py_m.doc()` here, since it wouldn't be inherited + // by the "types.py" wrapper library. See there for the python docstring. + + // LINT.IfChange + py_m.def("make_ndarray", &MakeNdarray, py::arg("proto").none(false), + py::pos_only(), R"pbdoc( + Converts `tensorflow.compiler.xla.xla_data_pb2.LiteralProto` + into an `xla::Literal` and then converts that literal into a tree + of tuples with leaves being `numpy.ndarray` views of array-shaped + sub-literals. + )pbdoc"); + + // This method name is based on `xla_client.dtype_to_etype`. + // NOTE: `xla_client` uses a Python class wrapping the protobuf-enum, + // rather than using the protobuf-enum directly. See the module docstring + // in "types.py" for more explanation on why. + py_m.def("dtype_to_etype", &DtypeToEtype, py::arg("dtype").none(false), + py::pos_only(), R"pbdoc( + Converts `numpy.dtype` into + `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType`. + )pbdoc"); + + py_m.def("etype_to_dtype", &EtypeToDtype, py::arg("ptype").none(false), + py::pos_only(), R"pbdoc( + Converts `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType` into + `numpy.dtype`. + )pbdoc"); + // LINT.ThenChange(_types.pyi) +} diff --git a/xla/python/tools/_types.pyi b/xla/python/tools/_types.pyi new file mode 100644 index 0000000000000..f355656f05b67 --- /dev/null +++ b/xla/python/tools/_types.pyi @@ -0,0 +1,25 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Union +import numpy as np +from xla import xla_data_pb2 + +# LINT.IfChange +NdarrayTree = Union[np.ndarray, tuple['NdarrayTree', ...]] +def make_ndarray(proto: xla_data_pb2.LiteralProto, /) -> NdarrayTree: ... +def dtype_to_etype(dtype: np.dtype, /) -> xla_data_pb2.PrimitiveType: ... +def etype_to_dtype(ptype: xla_data_pb2.PrimitiveType, /) -> np.dtype: ... +# LINT.ThenChange(types.py, _types.cc) diff --git a/xla/python/tools/types.py b/xla/python/tools/types.py new file mode 100644 index 0000000000000..bea14ffcc3a2b --- /dev/null +++ b/xla/python/tools/types.py @@ -0,0 +1,53 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tensorflow.compiler.xla.python.tools.types. + +This module provides Python bindings for various functions in +'tensorflow/compiler/xla/python/types.h'. It is primarily intended +to assist internal users in debugging things; and is not considered +part of the public API for OpenXLA. + +NOTE: This module *does* depend on Python protocol buffers; so beware! +The XLA Python bindings are currently packaged both as part of jaxlib and +as part of TensorFlow. Therefore, since we use protocol buffers here, +importing both jaxlib and TensorFlow may fail with duplicate protocol +buffer message definitions. +""" + +from typing import Union +# NOTE: `ml_dtypes` is implicitly required by `xla::LiteralToPython`. +# The entire goal of this wrapper library is to capture this dependency, +# so that client code need not be aware of it. +import ml_dtypes # pylint: disable=unused-import +import numpy +# NOTE: These protos are not part of TensorFlow's public API, therefore +# we cannot abide by [g-direct-tensorflow-import]. +# pylint: disable=g-direct-tensorflow-import,unused-import +from xla import xla_data_pb2 +# pylint: enable=g-direct-tensorflow-import,unused-import + +# NOTE: `import as ` is required for names to be exported. +# See PEP 484 & +# pylint: disable=g-importing-member,useless-import-alias,unused-import,g-multiple-import +# LINT.IfChange +from ._types import ( + make_ndarray as make_ndarray, + dtype_to_etype as dtype_to_etype, + etype_to_dtype as etype_to_dtype, +) +# TODO(wrengr): We can't import the `NdarrayTree` defined in the pyi file. +# So re-defining it here for now. +NdarrayTree = Union[numpy.ndarray, tuple['NdarrayTree', ...]] +# LINT.ThenChange(_types.pyi) diff --git a/xla/python/tools/types_test.py b/xla/python/tools/types_test.py new file mode 100644 index 0000000000000..34c941ee1ef96 --- /dev/null +++ b/xla/python/tools/types_test.py @@ -0,0 +1,181 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import itertools +import math +import re +from typing import List, NamedTuple + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +# NOTE: These protos are not part of the public API, therefore we cannot +# abide by [g-direct-tensorflow-import]. +# pylint: disable=g-direct-tensorflow-import +from xla import xla_data_pb2 +from xla.python.tools import types +# pylint: enable=g-direct-tensorflow-import + + +class MakeNdarrayInvalidTest(absltest.TestCase): + """Tests for invalid/unsupported arguments to `make_ndarray`.""" + + def setUp(self): + super().setUp() + self.assert_cannot_create_from_proto = self.assertRaisesRegex( + ValueError, re.escape('Cannot `xla::Literal::CreateFromProto`') + ) + + # NOTE: The `Literal(const Shape&, bool, ArrayValueState)` ctor does + # a CHECK forbidding `element_size_in_bits` from being specified; + # so we can't test anything about custom sizes here. + + def testMissingLayout(self): + # NOTE: `CreateFromProto` requires explicit `shape.layout.minor_to_major`. + # Though in principle it could use a default ctor instead, like we + # do in `make_named_parameter` below`. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testMissingMinorToMajor(self): + # NOTE: `CreateFromProto` requires explicit `shape.layout.minor_to_major`. + # Though in principle it could use a default ctor instead, like we + # do in `make_named_parameter` below`. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + layout=xla_data_pb2.LayoutProto(), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testInvalidPrimitiveType(self): + # NOTE: The `is_dynamic_dimension` field isn't required by + # `CreateFromProto`; however, the `Shape(const ShapeProto&)` ctor + # will log warnings if we leave it unspecified. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.PRIMITIVE_TYPE_INVALID, + dimensions=[1, 2, 3], + is_dynamic_dimension=[False, False, False], + layout=xla_data_pb2.LayoutProto( + minor_to_major=[0, 1, 2], + ), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testHasDimLevelTypes(self): + # NOTE: `CreateFromProto` forbids `dim_level_types` (even if all-dense). + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + is_dynamic_dimension=[False, False, False], + layout=xla_data_pb2.LayoutProto( + dim_level_types=[ + xla_data_pb2.DimLevelType.DIM_DENSE, + xla_data_pb2.DimLevelType.DIM_DENSE, + xla_data_pb2.DimLevelType.DIM_DENSE, + ], + minor_to_major=[0, 1, 2], + ), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + +class MakeNdarrayValidTestParameter(NamedTuple): + testcase_name: str + proto: xla_data_pb2.LiteralProto + arr: np.ndarray + + +def make_named_parameter( + testcase_name: str, + dimensions: List[int], + data: List[float], +) -> MakeNdarrayValidTestParameter: + """Helper function to construct parameters for `MakeNdarrayValidTest`.""" + assert math.prod(dimensions) == len(data) + nd = len(dimensions) + proto = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=dimensions, + is_dynamic_dimension=itertools.repeat(False, nd), + layout=xla_data_pb2.LayoutProto( + minor_to_major=range(nd), + ), + ), + f64s=data, + ) + arr = types.make_ndarray(proto) + return MakeNdarrayValidTestParameter(testcase_name, proto, arr) + + +@parameterized.named_parameters( + make_named_parameter('A', [2, 3], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + make_named_parameter('B', [1, 2, 3], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + make_named_parameter('C', [2, 3], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]), + make_named_parameter('D', [3, 2], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]), +) +class MakeNdarrayValidTest(parameterized.TestCase): + """Correctness tests for valid arguments to `make_ndarray`.""" + + def testHasCorrectDtype(self, proto, arr): + """Test that the result has the right dtype.""" + e = proto.shape.element_type + d = arr.dtype + with self.subTest(msg='etype_to_dtype'): + self.assertEqual(types.etype_to_dtype(e), d) + with self.subTest(msg='dtype_to_etype'): + self.assertEqual(e, types.dtype_to_etype(d)) + + def testHasCorrectRank(self, proto, arr): + """Test that the result has the right rank.""" + self.assertLen(proto.shape.dimensions, arr.ndim) + + def testHasCorrectShape(self, proto, arr): + """Test that the result has the same/right shape.""" + self.assertTupleEqual(tuple(proto.shape.dimensions), arr.shape) + + def testHasCorrectData(self, proto, arr): + """Test that the result has the same/right data.""" + # TODO(wrengr): Figure out a way to abstract away the name of the + # proto field containing the data; so that we can test multiple types. + self.assertSequenceAlmostEqual(proto.f64s, list(np.nditer(arr))) + + # TODO(wrengr): Add tests for: + # * dynamic dimension sizes. + # * non-trivial `minor_to_major`. + # * problematic types {PRED,F16,C64,C128} are all handled correctly. + # * BF16 is handled correctly. + # * tuples are handled correctly + + +if __name__ == '__main__': + absltest.main() diff --git a/xla/python/tpu_driver/BUILD b/xla/python/tpu_driver/BUILD deleted file mode 100644 index 65682631cecaa..0000000000000 --- a/xla/python/tpu_driver/BUILD +++ /dev/null @@ -1,126 +0,0 @@ -load( - "//xla/python/tpu_driver:platform/external/tools.bzl", - "external_deps", - "go_grpc_library", -) -load("@tsl//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") -load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") - -licenses(["notice"]) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_driver_proto", - srcs = ["tpu_driver.proto"], - cc_api_version = 2, - protodeps = [], -) - -tf_proto_library( - name = "tpu_service_proto", - srcs = ["tpu_service.proto"], - has_services = 1, - cc_api_version = 2, - create_grpc_library = True, - protodeps = [ - ":tpu_driver_proto", - "//xla:xla_data_proto", - "//xla:xla_proto", - "//xla/service:hlo_proto", - ], - use_grpc_namespace = True, -) - -cc_library( - name = "tpu_driver", - srcs = [ - "tpu_driver.cc", - ], - hdrs = [ - "event_id.h", - "platform/external/compat.h", - "tpu_driver.h", - ], - deps = [ - ":tpu_driver_proto_cc", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:hlo_proto_cc", - "@tsl//tsl/platform:logging", - ] + external_deps(), -) - -cc_library( - name = "grpc_tpu_driver", - srcs = [ - "grpc_tpu_driver.cc", - ], - hdrs = ["grpc_tpu_driver.h"], - deps = [ - ":tpu_driver", - ":tpu_driver_proto_cc", - ":tpu_service_cc_grpc_proto", - ":tpu_service_proto_cc", - "//xla:status", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:hlo_proto_cc", - "@tsl//tsl/platform:logging", - ] + tsl_grpc_cc_dependencies() + external_deps(), - alwayslink = 1, -) - -cc_library( - name = "recording_tpu_driver", - srcs = [ - "recording_tpu_driver.cc", - ], - deps = [ - ":tpu_driver", - ":tpu_driver_proto_cc", - ":tpu_service_cc_grpc_proto", - ":tpu_service_proto_cc", - "//xla:status", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/base", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:logging", - ] + external_deps(), - alwayslink = 1, -) - -cc_library( - name = "pod_tpu_driver", - srcs = ["pod_tpu_driver.cc"], - deps = [ - ":grpc_tpu_driver", - ":tpu_driver", - ":tpu_driver_proto_cc", - "//xla/pjrt:semaphore", - "//xla/pjrt:worker_thread", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_set", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - ] + tsl_grpc_cc_dependencies() + external_deps(), - alwayslink = 1, -) - -go_grpc_library( - name = "tpu_service_go_grpc", - srcs = [":tpu_service_proto"], - compatible_with = ["//buildenv/target:gce"], - deps = [":tpu_service_go_proto"], -) diff --git a/xla/python/tpu_driver/README.md b/xla/python/tpu_driver/README.md deleted file mode 100644 index 5b31df30ecc20..0000000000000 --- a/xla/python/tpu_driver/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# TPU Driver API - -This repository contains the TPU driver API and network (gRPC) transport -implementation for high-performance access to TPU hardware. - -# Building - -Bazel is used to build the driver library and tests. Remote tests will require -access to a Cloud TPU. - -## Fetching Bazel - -Download the latest copy of Bazel from -https://github.com/bazelbuild/bazel/releases. - -## Building - -`bazel build ...` - -## Testing - -`bazel test ...` diff --git a/xla/python/tpu_driver/client/BUILD b/xla/python/tpu_driver/client/BUILD deleted file mode 100644 index efb5329d8d4c6..0000000000000 --- a/xla/python/tpu_driver/client/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], - licenses = ["notice"], -) - -filegroup( - name = "header_and_client", - srcs = glob([ - "c_api*", - "libtpu*", - ]), - visibility = ["//visibility:public"], -) - -cc_library( - name = "libtpu", - hdrs = ["libtpu.h"], -) diff --git a/xla/python/tpu_driver/client/libtpu.h b/xla/python/tpu_driver/client/libtpu.h deleted file mode 100644 index 746083ffeaecf..0000000000000 --- a/xla/python/tpu_driver/client/libtpu.h +++ /dev/null @@ -1,312 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ -#define XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ - -#include -#include - -#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default"))) - -#ifdef __cplusplus -extern "C" { -#endif - -// ------------------- TPU Driver Support ----------------------- - -struct TpuDriverFn; - -typedef struct TpuDriver TpuDriver; - -typedef struct TpuEvent TpuEvent; - -typedef struct TpuBufferHandleInternal TpuBufferHandleInternal; - -typedef struct TpuCompiledProgramHandleInternal - TpuCompiledProgramHandleInternal; - -typedef struct TpuLoadedProgramHandleInternal TpuLoadedProgramHandleInternal; - -typedef struct TpuBufferHandle { - TpuBufferHandleInternal* internal_handle; - TpuEvent* event; - int64_t size_in_bytes; -} TpuBufferHandle; - -typedef struct TpuCompiledProgramHandle { - TpuCompiledProgramHandleInternal* internal_handle; - TpuEvent* event; -} TpuCompiledProgramHandle; - -typedef struct TpuLoadedProgramHandle { - TpuLoadedProgramHandleInternal* internal_handle; - TpuEvent* event; -} TpuLoadedProgramHandle; - -// HloProto is a serialized xla::HloProto buffer. -typedef struct HloProto { - void* buffer; - int32_t size; -} HloProto; - -typedef struct DebugOptions { - void* buffer; - int32_t size; -} DebugOptions; - -// DeviceAssignment is a serialized xla::DeviceAssignmentProto buffer. -typedef struct DeviceAssignment { - void* bytes; - int32_t size; -} DeviceAssignment; - -typedef struct TpuStatus { - int32_t code; - char* msg; -} TpuStatus; - -typedef struct CompiledProgramShape { - struct TpuStatus* status; - void* bytes; - int32_t size; -} CompiledProgramShape; - -typedef struct TpuAllocationShape { - void* bytes; - int32_t size; -} TpuAllocationShape; - -typedef struct TpuSystemInfo { - void* bytes; - int32_t size; -} TpuSystemInfo; - -typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn, - bool initialize); -typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker); -typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); -typedef struct TpuStatus*(PrototypeTpuDriver_Reset)(struct TpuDriver* driver); - -typedef struct TpuSystemInfo*(PrototypeTpuDriver_QuerySystemInfo)( - struct TpuDriver* driver); - -typedef void(PrototypeTpuDriver_FreeSystemInfo)(struct TpuSystemInfo* info); - -// TODO(frankchn): Make this not a hard-coded constant. -const int32_t MemoryRegion_HBM = 1; - -typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)( - struct TpuDriver* driver, const struct TpuAllocationShape shape); - -typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)( - struct TpuDriver* driver, void* dst, const void* src, - const struct TpuAllocationShape shape); - -typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)( - struct TpuDriver* driver, void* dst, const void* src, - const struct TpuAllocationShape shape); - -typedef struct TpuCompiledProgramHandle*( - PrototypeTpuDriver_CompileProgram)(struct TpuDriver* driver, - const struct HloProto hlo_proto, - int32_t num_replicas, - const struct DebugOptions debug_options, - int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuCompiledProgramHandle*( - PrototypeTpuDriver_CompileProgramFromText)(struct TpuDriver* driver, - const char* hlo_text, - int32_t num_replicas, - int32_t eventc, - struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuCompiledProgramHandle. You have to call FreeEvent separately to ensure - * that memory does not leak. - */ -typedef void(PrototypeTpuDriver_FreeCompiledProgramHandle)( - struct TpuCompiledProgramHandle* handle); - -typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)( - struct TpuDriver* driver, int32_t core_id, - const struct TpuCompiledProgramHandle* compiled_program_handle, - int32_t eventc, struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuLoadedProgramHandle. You have to call FreeEvent separately to ensure that - * memory does not leak. - */ -typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)( - struct TpuDriver* driver, - struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_ExecuteProgram)( - struct TpuDriver* driver, struct TpuLoadedProgramHandle* handle, - int32_t inputc, struct TpuBufferHandle** input_buffer_handle, - int32_t outputc, struct TpuBufferHandle** output_buffer_handle, - struct DeviceAssignment device_assignment, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateTuple)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - int32_t bufferc, struct TpuBufferHandle** buffer_handle, int32_t eventc, - struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)( - struct TpuDriver* driver, int32_t core_id, int32_t memory_region, - const struct TpuAllocationShape shape, int32_t eventc, - struct TpuEvent** eventv); - -/* Note: We are not responsible for freeing the event within the - * TpuBufferHandle. You have to call FreeEvent separately to ensure that memory - * does not leak. - */ -typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)( - struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferToDevice)( - struct TpuDriver* driver, const void* src, struct TpuBufferHandle* dst, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDevice)( - struct TpuDriver* driver, struct TpuBufferHandle* src, void* dst, - int32_t eventc, struct TpuEvent** eventv); - -typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)( - struct TpuDriver* driver, struct TpuBufferHandle* src, - struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv); - -typedef struct CompiledProgramShape*( - PrototypeTpuDriver_GetCompiledProgramShape)( - struct TpuCompiledProgramHandle* handle); - -typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)( - struct CompiledProgramShape* shape); - -typedef void(PrototypeTpuDriver_EventAddCallback)( - struct TpuEvent* event, - void (*callback_fn)(struct TpuStatus*, void* additional_info), - void* additional_info); - -typedef struct TpuStatus*(PrototypeTpuDriver_EventAwait)(struct TpuEvent* event, - int64_t timeout_in_us); - -typedef void(PrototypeTpuDriver_FreeEvent)(struct TpuEvent* event); - -typedef void(PrototypeTpuDriver_FreeStatus)(struct TpuStatus* status); - -typedef const char*(PrototypeTpuDriver_Version)(); - -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Reset TpuDriver_Reset; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_QuerySystemInfo - TpuDriver_QuerySystemInfo; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeSystemInfo - TpuDriver_FreeSystemInfo; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape - TpuDriver_ComputeLinearizedBytesFromShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape - TpuDriver_LinearizeShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape - TpuDriver_DelinearizeShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram - TpuDriver_CompileProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText - TpuDriver_CompileProgramFromText; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramHandle - TpuDriver_FreeCompiledProgramHandle; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LoadProgram - TpuDriver_LoadProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_UnloadProgram - TpuDriver_UnloadProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram - TpuDriver_ExecuteProgram; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple - TpuDriver_AllocateTuple; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape - TpuDriver_AllocateShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice - TpuDriver_TransferToDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice - TpuDriver_TransferFromDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice - TpuDriver_TransferFromDeviceToDevice; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape - TpuDriver_GetCompiledProgramShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape - TpuDriver_FreeCompiledProgramShape; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback - TpuDriver_EventAddCallback; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeEvent TpuDriver_FreeEvent; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeStatus TpuDriver_FreeStatus; -TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version; - -#ifdef __cplusplus -} -#endif - -struct TpuDriverFn { - PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT - PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT - PrototypeTpuDriver_Reset* TpuDriver_Reset; // NOLINT - PrototypeTpuDriver_ComputeLinearizedBytesFromShape* - TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT - PrototypeTpuDriver_QuerySystemInfo* TpuDriver_QuerySystemInfo; // NOLINT - PrototypeTpuDriver_FreeSystemInfo* TpuDriver_FreeSystemInfo; // NOLINT - PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT - PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT - PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT - PrototypeTpuDriver_CompileProgramFromText* - TpuDriver_CompileProgramFromText; // NOLINT - PrototypeTpuDriver_FreeCompiledProgramHandle* - TpuDriver_FreeCompiledProgramHandle; // NOLINT - PrototypeTpuDriver_LoadProgram* TpuDriver_LoadProgram; // NOLINT - PrototypeTpuDriver_UnloadProgram* TpuDriver_UnloadProgram; // NOLINT - PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT - PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT - PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT - PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT - PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT - PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT - PrototypeTpuDriver_TransferFromDevice* - TpuDriver_TransferFromDevice; // NOLINT - PrototypeTpuDriver_TransferFromDeviceToDevice* - TpuDriver_TransferFromDeviceToDevice; // NOLINT - PrototypeTpuDriver_GetCompiledProgramShape* - TpuDriver_GetCompiledProgramShape; // NOLINT - PrototypeTpuDriver_FreeCompiledProgramShape* - TpuDriver_FreeCompiledProgramShape; // NOLINT - PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT - PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT - PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT - PrototypeTpuDriver_FreeStatus* TpuDriver_FreeStatus; // NOLINT - - PrototypeTpuDriver_Version* TpuDriver_Version; // NOLINT -}; - -#endif // XLA_PYTHON_TPU_DRIVER_CLIENT_LIBTPU_H_ diff --git a/xla/python/tpu_driver/client/libtpu_client.c b/xla/python/tpu_driver/client/libtpu_client.c deleted file mode 100644 index d759a786d45cd..0000000000000 --- a/xla/python/tpu_driver/client/libtpu_client.c +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Before you start, make sure libtpu.so, libtpu.h and libtpu_client.c are in -// the same working directory. -// -// To compile: gcc -o libtpu_client libtpu_client.c -ldl -// To run: sudo ./libtpu_client - -#include -#include -#include - -#include "libtpu.h" - -void* LoadAndInitializeDriver(const char* shared_lib, - struct TpuDriverFn* driver_fn) { - void* handle; - handle = dlopen(shared_lib, RTLD_NOW); - if (!handle) { - fprintf(stderr, "Error: %s\n", dlerror()); - exit(EXIT_FAILURE); - } - - PrototypeTpuDriver_Initialize* initialize_fn; - *(void**)(&initialize_fn) = dlsym(handle, "TpuDriver_Initialize"); - initialize_fn(driver_fn, true); - - return handle; -} - -int main(int argc, char** argv) { - char* api_path = "libtpu.so"; - if (argc == 2) { - api_path = argv[1]; - } - - struct TpuDriverFn driver_fn; - void* handle = LoadAndInitializeDriver(api_path, &driver_fn); - - fprintf(stdout, "------ Going to Query Version ------\n"); - fprintf(stdout, "TPU Driver Version: %s\n", driver_fn.TpuDriver_Version()); - - fprintf(stdout, "------ Going to Open a TPU Driver ------\n"); - struct TpuDriver* driver = driver_fn.TpuDriver_Open("local://"); - - fprintf(stdout, "------ Going to Query for System Information ------\n"); - struct TpuSystemInfo* info = driver_fn.TpuDriver_QuerySystemInfo(driver); - driver_fn.TpuDriver_FreeSystemInfo(info); - - // An example of simple program to sum two parameters. - const char* hlo_module_text = R"(HloModule add_vec_module - ENTRY %add_vec (a: s32[256], b: s32[256]) -> s32[256] { - %a = s32[256] parameter(0) - %b = s32[256] parameter(1) - ROOT %sum = s32[256] add(%a, %b) - } - )"; - - fprintf(stdout, "------ Going to Compile a TPU program ------\n"); - struct TpuCompiledProgramHandle* cph = - driver_fn.TpuDriver_CompileProgramFromText(driver, hlo_module_text, - /*num_replicas=*/1, /*eventc=*/0, /*eventv*/NULL); - - TpuEvent* compile_events[] = {cph->event}; - fprintf(stdout, "------ Going to Load a TPU program ------\n"); - struct TpuLoadedProgramHandle* lph = - driver_fn.TpuDriver_LoadProgram(driver, /*core_id=*/0, cph, - /*eventc=*/1, /*eventv=*/compile_events); - - const int size = 1024; - - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_a_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_b_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - fprintf(stdout, "------ Going to Allocate a TPU Buffer ------\n"); - struct TpuBufferHandle* buf_sum_handle = - driver_fn.TpuDriver_Allocate(driver, /*core-id=*/0, /*memory_region=*/1, - /*bytes=*/size, /*eventc=*/0, /*eventv=*/NULL); - - char a_src[size], b_src[size], sum_src[size]; - for (int i = 0; i < size; ++i) { - a_src[i] = 1; - b_src[i] = 2; - sum_src[i] = 0; - } - - TpuEvent* allocate_buf_a_events[] = {buf_a_handle->event}; - fprintf(stdout, "------ Going to Transfer To Device ------\n"); - struct TpuEvent* transfer_ev1 = - driver_fn.TpuDriver_TransferToDevice(driver, a_src, buf_a_handle, - /*eventc=*/1, /*eventv=*/allocate_buf_a_events); - TpuEvent* allocate_buf_b_events[] = {buf_a_handle->event}; - fprintf(stdout, "------ Going to Transfer To Device ------\n"); - struct TpuEvent* transfer_ev2 = - driver_fn.TpuDriver_TransferToDevice(driver, b_src, buf_b_handle, - /*eventc=*/1, /*eventv=*/allocate_buf_b_events); - - fprintf(stdout, "------ Going to Execute a TPU program ------\n"); - DeviceAssignment device_assignment = {NULL, 0}; - TpuBufferHandle* input_buffer_handle[] = {buf_a_handle, buf_b_handle}; - TpuBufferHandle* output_buffer_handle[] = {buf_sum_handle}; - TpuEvent* transfer_events[] = {transfer_ev1, transfer_ev2}; - struct TpuEvent* execute_event = - driver_fn.TpuDriver_ExecuteProgram(driver, lph, - /*inputc=*/2, /*input_buffer_handle=*/input_buffer_handle, - /*outputc=*/1, /*output_buffer_handle=*/output_buffer_handle, - device_assignment, - /*eventc=*/2, /*eventv*/transfer_events); - - fprintf(stdout, "------ Going to Transfer From Device ------\n"); - TpuEvent* execute_events[] = {execute_event}; - struct TpuEvent* transfer_sum_event = - driver_fn.TpuDriver_TransferFromDevice(driver, buf_sum_handle, sum_src, - /*eventc=*/1, /*eventv=*/execute_events); - - TpuStatus* status = driver_fn.TpuDriver_EventAwait(transfer_sum_event, - 10000000); - if (status->code != 0) { - fprintf(stdout, "Transfer Event Await: Code: %d, Message: %s\n", - status->code, status->msg); - } - - fprintf(stdout, "------ Going to Unload a TPU program ------\n"); - struct TpuEvent* unload_program_event = driver_fn.TpuDriver_UnloadProgram( - driver, lph, /*eventc=*/1, /*eventv=*/execute_events); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev1 = driver_fn.TpuDriver_Deallocate(driver, - buf_a_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev1); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev2 = driver_fn.TpuDriver_Deallocate(driver, - buf_b_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev2); - - fprintf(stdout, "------ Going to Deallocate a TPU Buffer ------\n"); - struct TpuEvent* dealloc_ev3 = driver_fn.TpuDriver_Deallocate(driver, - buf_sum_handle, /*eventc=*/0, /*eventv=*/NULL); - driver_fn.TpuDriver_FreeEvent(dealloc_ev3); - - fprintf(stdout, "sum:\n"); - for (size_t i = 0; i < size; ++i) { - fprintf(stdout, "%d ", sum_src[i]); - } - - dlclose(handle); - exit(EXIT_SUCCESS); -} diff --git a/xla/python/tpu_driver/event_id.h b/xla/python/tpu_driver/event_id.h deleted file mode 100644 index ee006a992f565..0000000000000 --- a/xla/python/tpu_driver/event_id.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -#ifndef XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ -#define XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" - -namespace tpu_driver { - -// For gRPC serialization, events are represented as a pair of -// {client, operation} ids. To simplify serialization, these are encoded as a -// single integer field. -// -// This class provides a typed interface for these values as well as support for -// hashing and ostreams (for logging). -struct EventId { - uint64_t client_id; - uint64_t operation_id; - - template - friend H AbslHashValue(H h, const EventId& c) { - return H::combine(std::move(h), c.client_id, c.operation_id); - } - - bool operator==(const EventId& r) const { - return r.client_id == client_id && r.operation_id == operation_id; - } - - friend std::ostream& operator<<(std::ostream& os, EventId r) { - return os << r.client_id << ":" << r.operation_id; - } - - std::string ToString() const { - return absl::StrCat(client_id, ":", operation_id); - } - - uint64_t AsInt() const { return client_id << 44 | operation_id; } - - static EventId FromInt(uint64_t value) { - return EventId{value >> 44, value & 0xfffffffffff}; - } -}; - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_EVENT_ID_H_ diff --git a/xla/python/tpu_driver/grpc_tpu_driver.cc b/xla/python/tpu_driver/grpc_tpu_driver.cc deleted file mode 100644 index 52430c8cb87fa..0000000000000 --- a/xla/python/tpu_driver/grpc_tpu_driver.cc +++ /dev/null @@ -1,1108 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/strings/strip.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "grpcpp/grpcpp.h" -#include "xla/python/tpu_driver/event_id.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/python/tpu_driver/tpu_service.grpc.pb.h" -#include "xla/util.h" - -namespace tpu_driver { -namespace { - -using xla::OkStatus; -using xla::Status; - -const int64_t kMaxStreamWriteSize = 10 * 1000 * 1000; -const absl::Duration kWriteEpochDuration = absl::Microseconds(10); - -constexpr char kGrpcProtocol[] = "grpc://"; - -class GrpcTpuStream; -class GrpcTpuDriver; - -class GrpcEvent : public Event { - public: - explicit GrpcEvent(EventId id, GrpcTpuStream* stream) - : id_(id), stream_(stream) {} - ~GrpcEvent() override; - - xla::Status Await() override; - std::optional AwaitWithTimeout(absl::Duration duration) override; - void AddCallback(std::function callback) override; - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; -}; - -class ErrorEvent : public GrpcEvent { - public: - explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) { - status_ = status; - } - - xla::Status Await() override { return status_; } - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return status_; - } - void AddCallback(std::function callback) override { - callback(status_); - } - - private: - Status status_; -}; - -class GrpcBufferHandle : public BufferHandle { - public: - explicit GrpcBufferHandle(EventId id, std::shared_ptr event, - int64_t bytes, - std::optional shape = std::nullopt) - : id_(id), - stream_(event->stream()), - event_(std::move(event)), - bytes_(bytes), - shape_(shape) {} - - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return bytes_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - std::optional shape() override { return shape_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; - int64_t bytes_; - std::optional shape_; -}; - -class GrpcCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit GrpcCompiledProgramHandle(EventId id, - std::shared_ptr event) - : id_(id), - stream_(event->stream()), - event_(std::move(event)), - metadata_(std::make_shared()) {} - - std::shared_ptr OnReady() override { return event_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - Status program_shape(xla::ProgramShapeProto* program_shape) override { - auto opt_status = OnReady()->AwaitWithTimeout(absl::Hours(1)); - if (!opt_status.has_value()) { - return xla::InternalError("Compile failed to finish within 1 hour."); - } - - Status status = opt_status.value(); - if (!status.ok()) { - return status; - } - *program_shape = metadata_->program_shape(); - return OkStatus(); - } - - std::shared_ptr metadata() { return metadata_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; - - // Using a shared pointer here because the program handle can go out of scope - // before we get a response back, but we want a valid location to write things - // into regardless. - std::shared_ptr metadata_; -}; - -class GrpcLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit GrpcLoadedProgramHandle(EventId id, std::shared_ptr event) - : id_(id), stream_(event->stream()), event_(std::move(event)) {} - - std::shared_ptr OnReady() override { return event_; } - - EventId id() const { return id_; } - GrpcTpuStream* stream() const { return stream_; } - - private: - const EventId id_; - GrpcTpuStream* stream_; - std::shared_ptr event_; -}; - -class GrpcTpuStream { - public: - explicit GrpcTpuStream(int32_t id, GrpcTpuDriver* driver, - std::unique_ptr stub); - virtual ~GrpcTpuStream(); - - std::unique_ptr Allocate(int32_t core_id, MemoryRegion region, - int64_t num_bytes, - absl::Span wait_for); - std::unique_ptr Allocate(int32_t core_id, MemoryRegion region, - const xla::ShapeProto& shape, - absl::Span wait_for); - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for); - std::shared_ptr Deallocate(std::unique_ptr handle, - absl::Span wait_for); - - std::shared_ptr TransferToDevice(const void* src, BufferHandle* dst, - absl::Span wait_for); - std::shared_ptr TransferFromDevice(const BufferHandle* src, void* dst, - absl::Span wait_for); - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for); - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options); - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for); - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for); - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for); - - private: - friend class GrpcEvent; - friend class GrpcTpuDriver; - - struct EventInfo { - bool all_deps_done = false; - bool done = false; // response received - bool deleted = false; // deleted by the user - Status status; - absl::InlinedVector, 1> callbacks; - // Most events should have <= 2 requirement events. - absl::InlinedVector deps; - }; - - struct TransferInfo { - explicit TransferInfo(void* dst, int64_t num_bytes) - : dst(dst), num_bytes(num_bytes) {} - - void* const dst; - const uint64_t num_bytes; - }; - - struct CompileMetadataInfo { - explicit CompileMetadataInfo( - std::shared_ptr metadata) { - compiled_metadata = metadata; - } - std::shared_ptr compiled_metadata; - }; - - // Every public method above should call this first. - void InitializeRequest(StreamRequest::Entry* req, - absl::Span wait_for) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - // The first update to an event marks it done and calls registered callbacks. - // All subsequent updates must have the same OK-ness as the first update. - // Among non-OK updates, only the first error status is remembered. - void UpdateEventStatus(EventId id, Status status) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(events_mutex_); - - // To ensure callbacks are still triggered, after this is called, we do not - // remove the event from the event mapping until a response is received from - // the server. - void DeleteEvent(EventId id) ABSL_LOCKS_EXCLUDED(events_mutex_); - - // Wait at most `duration` for event `id` to complete. Returns the event - // status or an empty optional if the event does not complete in time. - std::optional WaitForEvent(EventId id, absl::Duration duration) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - void AddEventCallback(EventId id, std::function callback) - ABSL_LOCKS_EXCLUDED(events_mutex_); - - void AddWriteRequest(std::unique_ptr req) { - absl::MutexLock m(&request_lock_); - VLOG(2) << "Adding request: " << req->DebugString(); - requests_.push_back(std::move(req)); - } - - // Unique identifier for this stream. - int32_t id_; - // The parent driver that created this stream. - GrpcTpuDriver* driver_; - - std::unique_ptr stub_; - ::grpc::ClientContext ctx_; - std::unique_ptr< - ::grpc::ClientReaderWriterInterface> - stream_; - - absl::Mutex request_lock_; - std::deque> requests_ - ABSL_GUARDED_BY(request_lock_); - int64_t num_pending_requests_ ABSL_GUARDED_BY(request_lock_) = 0; - - bool shutting_down_ ABSL_GUARDED_BY(request_lock_) = false; - - void StreamWriterFn(); - Thread writer_thread_; - - void StreamReaderFn(); - Thread reader_thread_; - - // Map from operation ID to event information. - absl::Mutex events_mutex_; - absl::flat_hash_map events_ - ABSL_GUARDED_BY(events_mutex_); - - // Map from operation ID to transfer information. - // When a D2H transfer completes, received data is copied into the `dst` - // pointer in `TransferInfo`. - absl::Mutex transfers_mutex_; - absl::flat_hash_map transfers_ - ABSL_GUARDED_BY(transfers_mutex_); - - absl::Mutex compiles_mutex_; - absl::flat_hash_map compiles_ - ABSL_GUARDED_BY(compiles_mutex_); -}; - -class GrpcTpuDriver : public TpuDriver { - public: - explicit GrpcTpuDriver(const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds, - int32_t client_id) - : config_(config), creds_(creds), client_id_(client_id) { - SystemInfo system_info; - QuerySystemInfo(&system_info); - for (auto& chip_info : system_info.tpu_chip()) { - for (auto& core_info : chip_info.core()) { - int32_t core_id = core_info.id(); - // We have one stream per core, so use core ID as stream ID. - streams_[core_id] = AllocateStream(core_id); - } - } - CHECK_GT(streams_.size(), 0) << "Can't find any TPU chip in the system."; - - host_stream_ = AllocateStream(-1); - } - - ~GrpcTpuDriver() override { - if (closed_) { - return; - } - auto status = Close(); - if (!status.ok()) { - LOG(ERROR) << status; - } - } - - void QuerySystemInfo(SystemInfo* system_info) override; - Status Reset() override; - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - return streams_[core_id]->Allocate(core_id, region, num_bytes, wait_for); - } - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - return streams_[core_id]->Allocate(core_id, region, shape, wait_for); - } - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - return streams_[core_id]->AllocateTuple(core_id, region, children, - wait_for); - } - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - auto* stream = static_cast(handle.get())->stream(); - return stream->Deallocate(std::move(handle), wait_for); - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - auto* stream = static_cast(dst)->stream(); - return stream->TransferToDevice(src, dst, wait_for); - } - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - auto* stream = static_cast(src)->stream(); - return stream->TransferFromDevice(src, dst, wait_for); - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto* stream = static_cast(src)->stream(); - return stream->TransferFromDeviceToDevice(src, dst, wait_for); - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - // Always compile using the first/default core's stream. - return streams_[0]->CompileProgram(source, num_replicas, wait_for, - debug_options); - } - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - return streams_[core_id]->LoadProgram(core_id, handle, wait_for); - } - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - auto* stream = - static_cast(handle.get())->stream(); - return stream->UnloadProgram(std::move(handle), wait_for); - } - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - auto* stream = - static_cast(program)->stream(); - return stream->ExecuteProgram(program, inputs, outputs, device_assignment, - wait_for); - } - - EventId NewOperationId() { return EventId{client_id_, ++operation_id_}; } - - static std::unique_ptr CreateTpuDriverStub( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds); - - uint32_t client_id() const { return client_id_; } - - private: - Status Close(); - std::unique_ptr AllocateStream(int32_t core_id); - - const TpuDriverConfig config_; - std::shared_ptr<::grpc::ChannelCredentials> creds_; - const uint32_t client_id_; - // Map from stream IDs to streams. - absl::flat_hash_map> streams_; - std::unique_ptr host_stream_; - // Shared by all streams. - std::atomic operation_id_{0}; - std::atomic closed_{false}; -}; // namespace - -GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); } - -Status GrpcEvent::Await() { - auto opt_status = stream_->WaitForEvent(id_, absl::InfiniteDuration()); - return opt_status.value(); -} - -std::optional GrpcEvent::AwaitWithTimeout(absl::Duration duration) { - return stream_->WaitForEvent(id_, duration); -} - -void GrpcEvent::AddCallback(std::function callback) { - stream_->AddEventCallback(id_, std::move(callback)); -} - -GrpcTpuStream::GrpcTpuStream(int32_t id, GrpcTpuDriver* driver, - std::unique_ptr stub) - : id_(id), - driver_(driver), - stub_(std::move(stub)), - stream_(stub_->StreamExecute(&ctx_)), - writer_thread_(&GrpcTpuStream::StreamWriterFn, this), - reader_thread_(&GrpcTpuStream::StreamReaderFn, this) {} - -GrpcTpuStream::~GrpcTpuStream() { - { - absl::MutexLock lock(&request_lock_); - shutting_down_ = true; - } - - VLOG(1) << "Shutting down stream."; - { - // Mark all remaining events invalid. - absl::MutexLock lock(&events_mutex_); - for (const auto& e : events_) { - if (!e.second.done) { - LOG(ERROR) << "Resetting: " << e.first; - UpdateEventStatus(e.first, xla::Status(absl::StatusCode::kAborted, - "Driver was closed.")); - } - } - } - VLOG(1) << "Closing stream."; - stream_->WritesDone(); - stream_->Finish().IgnoreError(); - VLOG(1) << "Waiting for writer."; - writer_thread_.join(); - VLOG(1) << "Waiting for reader."; - reader_thread_.join(); -} - -void GrpcTpuStream::InitializeRequest(StreamRequest::Entry* req, - absl::Span wait_for) { - auto operation_id = driver_->NewOperationId(); - EventInfo event_info; - - req->set_operation_id(operation_id.AsInt()); - if (wait_for.empty()) { - event_info.all_deps_done = true; - } else { - event_info.deps.reserve(wait_for.size()); - for (auto* event : wait_for) { - auto grpc_event = static_cast(event); - req->add_wait_for_id(grpc_event->id().AsInt()); - event_info.deps.push_back(grpc_event->id()); - } - } - - absl::MutexLock lock(&events_mutex_); - events_[operation_id] = event_info; -} - -void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) { - auto it = events_.find(id); - - // These should only happen when the server shuts down, and our local event - // cancellation interleaves with server responses. It should be safe to ignore - // the second updates in these situations. - if (it == events_.end()) { - VLOG(1) << "Received a status update: " << status - << ", but cannot find GrpcEvent " << id; - return; - } - if (it->second.done) { - // Done and deleted events must have already been removed. - CHECK(!it->second.deleted); - VLOG(1) << "Received a second status update: " << status.message() - << ", for GrpcEvent " << id - << " already done with status: " << it->second.status.message(); - return; - } - - // This is the first time this event finishes. Remember the results and call - // the callbacks. - VLOG(1) << "Response received for GrpcEvent " << id << ". " << status - << ". Firing " << it->second.callbacks.size() << " callbacks."; - it->second.done = true; - it->second.status = status; - for (const auto& callback : it->second.callbacks) { - callback(status); - } - - // Truly remove the event if it's both done and deleted. - if (it->second.deleted) { - events_.erase(it); - } -} - -void GrpcTpuStream::DeleteEvent(EventId id) { - absl::MutexLock lock(&events_mutex_); - auto it = events_.find(id); - CHECK(it != events_.end()); - CHECK(!it->second.deleted); - it->second.deleted = true; - // Truly remove the event if it's both done and deleted. - if (it->second.done) { - events_.erase(it); - } -} - -std::optional GrpcTpuStream::WaitForEvent(EventId id, - absl::Duration duration) { - events_mutex_.Lock(); - auto it = events_.find(id); - - if (it == events_.end()) { - // This event has already been marked as done and deleted. Assume success. - events_mutex_.Unlock(); - return OkStatus(); - } - - if (!it->second.all_deps_done) { - absl::InlinedVector deps = it->second.deps; - events_mutex_.Unlock(); - for (auto dep : deps) { - // If a requirement event timed out, no point in any further waiting. - if (!WaitForEvent(dep, duration)) { - return std::nullopt; - } - } - events_mutex_.Lock(); - } - - // Set the flag here, as we're guaranteed they have all completed at this - // point. This helps terminate recursion on a chain of completed events as - // soon as possible, at this event. - it = events_.find(id); - if (it != events_.end()) { - it->second.all_deps_done = true; - } - - auto done = [this, id]() { - events_mutex_.AssertHeld(); - return !events_.contains(id) || events_[id].done; - }; - if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) { - auto status = events_.contains(id) ? events_[id].status : OkStatus(); - events_mutex_.Unlock(); - return status; - } - events_mutex_.Unlock(); - return std::nullopt; -} - -void GrpcTpuStream::AddEventCallback(EventId id, - std::function callback) { - absl::MutexLock lock(&events_mutex_); - auto it = events_.find(id); - if (it == events_.end()) { - callback(Status()); - return; - } - if (it->second.done) { - callback(it->second.status); - return; - } - it->second.callbacks.push_back(std::move(callback)); -} - -static bool ShouldBeginWriting(int64_t* pending_requests) { - return *pending_requests > 32; -} - -void GrpcTpuStream::StreamWriterFn() { - while (true) { - request_lock_.LockWhenWithTimeout( - absl::Condition(&ShouldBeginWriting, &num_pending_requests_), - kWriteEpochDuration); - if (shutting_down_) { - request_lock_.Unlock(); - return; - } - - if (requests_.empty()) { - request_lock_.Unlock(); - continue; - } - - std::vector reqs; - int64_t request_bytes = 0; - while (!requests_.empty()) { - StreamRequest::Entry* e = requests_.front().release(); - requests_.pop_front(); - const int64_t entry_bytes = e->ByteSizeLong(); - if (reqs.empty() || request_bytes + entry_bytes > kMaxStreamWriteSize) { - reqs.push_back(StreamRequest()); - request_bytes = 0; - } - VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id()); - VLOG(2) << "Sending request: " << e->DebugString(); - reqs.back().mutable_entry()->AddAllocated(e); - } - num_pending_requests_ = 0; - request_lock_.Unlock(); - - for (const auto& r : reqs) { - TraceMe activity("GrpcTpuStream::Send "); - ::grpc::WriteOptions opts; - opts.set_no_compression().clear_buffer_hint(); - stream_->Write(r, opts); - } - } -} - -void GrpcTpuStream::StreamReaderFn() { - StreamResponse resp; - while (stream_->Read(&resp)) { - VLOG(2) << "Received response: " << resp.DebugString(); - for (const StreamResponse::Entry& entry : resp.entry()) { - EventId event_id = EventId::FromInt(entry.operation_id()); - VLOG(1) << "Received response for: " << event_id; - - TraceMe activity("GrpcTpuStream::RequestComplete"); - if (entry.has_transfer_from()) { - TraceMe activity("GrpcTpuStream::TransferFromComplete"); - absl::MutexLock lock(&transfers_mutex_); - auto it = transfers_.find(event_id); - CHECK(it != transfers_.end()); - VLOG(1) << "Copying: " << it->second.num_bytes << " to position " - << it->second.dst; - if (entry.transfer_from().data().size() != it->second.num_bytes) { - absl::MutexLock lock(&events_mutex_); - UpdateEventStatus( - event_id, - Status( - absl::StatusCode::kDataLoss, - absl::StrCat("Expected ", it->second.num_bytes, " received ", - entry.transfer_from().data().size()))); - continue; - } - memcpy(it->second.dst, entry.transfer_from().data().data(), - it->second.num_bytes); - } - - if (entry.has_compile()) { - TraceMe activity("GrpcTpuStream::CompileComplete"); - absl::MutexLock lock(&compiles_mutex_); - auto it = compiles_.find(event_id); - CHECK(it != compiles_.end()); - *it->second.compiled_metadata = entry.compile().metadata(); - } - - absl::MutexLock lock(&events_mutex_); - if (entry.status().code() != tsl::error::Code::OK) { - UpdateEventStatus( - event_id, - Status(static_cast(entry.status().code()), - entry.status().message())); - } else { - UpdateEventStatus(event_id, OkStatus()); - } - } - } -} - -std::unique_ptr GrpcTpuStream::Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Allocate(num_bytes)"); - req->mutable_alloc()->set_core_id(core_id); - req->mutable_alloc()->set_region(region); - req->mutable_alloc()->set_num_bytes(num_bytes); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), std::move(event), - num_bytes); -} - -std::unique_ptr GrpcTpuStream::Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Allocate(shape)"); - req->mutable_alloc()->set_core_id(core_id); - req->mutable_alloc()->set_region(region); - *req->mutable_alloc()->mutable_shape() = shape; - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique( - event->id(), std::move(event), ComputeBytesFromShape(shape), shape); -} - -std::unique_ptr GrpcTpuStream::AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::AllocateTuple"); - req->mutable_alloc_tuple()->set_core_id(core_id); - req->mutable_alloc_tuple()->set_region(region); - for (auto child : children) { - auto grpc_child = static_cast(child); - req->mutable_alloc_tuple()->add_children(grpc_child->id().AsInt()); - } - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), std::move(event), 0); -} - -std::shared_ptr GrpcTpuStream::Deallocate( - std::unique_ptr handle, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::Deallocate"); - auto grpc_handle = static_cast(handle.get()); - req->mutable_dealloc()->set_handle(grpc_handle->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferToDevice( - const void* src, BufferHandle* dst, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::TransferToDevice"); - req->mutable_transfer_to()->mutable_data()->assign( - static_cast(src), dst->size_in_bytes()); - req->mutable_transfer_to()->set_target_handle( - static_cast(dst)->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferFromDevice( - const BufferHandle* src, void* dst, absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::TransferFromDevice"); - req->mutable_transfer_from()->set_source_handle( - static_cast(src)->id().AsInt()); - EventId event_id = EventId::FromInt(req->operation_id()); - { - absl::MutexLock lock(&transfers_mutex_); - TransferInfo info(dst, const_cast(src)->size_in_bytes()); - transfers_.insert(std::make_pair(event_id, info)); - } - auto event = std::make_shared(event_id, this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity([&req] { - return absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice", - req->operation_id()); - }); - - req->mutable_transfer_from_to()->set_source_handle( - static_cast(src)->id().AsInt()); - req->mutable_transfer_from_to()->set_target_handle( - static_cast(dst)->id().AsInt()); - EventId event_id = EventId::FromInt(req->operation_id()); - auto event = std::make_shared(event_id, this); - AddWriteRequest(std::move(req)); - return event; -} - -std::unique_ptr GrpcTpuStream::CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, const xla::DebugOptions& debug_options) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::CompileProgram"); - *req->mutable_compile()->mutable_hlo_program() = source; - req->mutable_compile()->set_num_replicas(num_replicas); - *req->mutable_compile()->mutable_debug_options() = debug_options; - EventId event_id = EventId::FromInt(req->operation_id()); - - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - - auto handle = std::make_unique(event->id(), - std::move(event)); - { - absl::MutexLock lock(&compiles_mutex_); - CompileMetadataInfo info(handle->metadata()); - compiles_.insert(std::make_pair(event_id, info)); - } - - AddWriteRequest(std::move(req)); - return std::move(handle); -} - -std::unique_ptr GrpcTpuStream::LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::LoadProgram"); - req->mutable_load()->set_core_id(core_id); - auto grpc_handle = static_cast(handle); - if (grpc_handle->id().client_id != driver_->client_id()) { - auto event = std::make_shared( - xla::InvalidArgument("Invalid program handle (wrong client id). Did " - "you restart the server or use a stale handle?")); - return std::make_unique(event->id(), - std::move(event)); - } - req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return std::make_unique(event->id(), - std::move(event)); -} - -std::shared_ptr GrpcTpuStream::UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - TraceMe activity("GrpcTpuStream::UnloadProgram"); - req->mutable_unload()->set_loaded_program_handle( - static_cast(handle.get())->id().AsInt()); - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -std::shared_ptr GrpcTpuStream::ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) { - auto req = std::make_unique(); - InitializeRequest(req.get(), wait_for); - auto program_handle = static_cast(program); - if (program_handle->id().client_id != driver_->client_id()) { - return std::make_shared( - xla::InvalidArgument("Invalid program handle (wrong client id). Did " - "you restart the server or use a stale handle?")); - } - - req->mutable_execute()->set_loaded_program_handle( - program_handle->id().AsInt()); - - for (BufferHandle* input : inputs) { - auto* grpc_handle = static_cast(input); - if (grpc_handle->id().client_id != driver_->client_id()) { - return std::make_shared(xla::InvalidArgument( - "Invalid input buffer (wrong client id). Did you restart the server " - "or use a stale handle?")); - } - req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt()); - } - - for (BufferHandle* output : outputs) { - auto* grpc_handle = static_cast(output); - if (grpc_handle->id().client_id != driver_->client_id()) { - return std::make_shared(xla::InvalidArgument( - "Invalid output buffer (wrong client id). Did you restart the server " - "or use a stale handle?")); - } - req->mutable_execute()->add_output_handle( - static_cast(output)->id().AsInt()); - } - // Only pass along device_assignment if it's not default constructed. - if (!(device_assignment.replica_count() == 0 && - device_assignment.computation_count() == 0)) { - *req->mutable_execute()->mutable_device_assignment() = device_assignment; - } - auto event = - std::make_shared(EventId::FromInt(req->operation_id()), this); - AddWriteRequest(std::move(req)); - return event; -} - -/*static*/ std::unique_ptr -GrpcTpuDriver::CreateTpuDriverStub( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - ::grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - args.SetMaxSendMessageSize(std::numeric_limits::max()); - - // Send at least 20 keep-alives before giving up. - int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000; - int keepalive_interval_ms = keepalive_timeout_ms / 20; - - grpc_arg client_arg_vals[] = { - {.type = GRPC_ARG_INTEGER, - .key = const_cast( - GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS), - .value = {.integer = keepalive_interval_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA), - .value = {.integer = 0}}, // unlimited - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_TIME_MS), - .value = {.integer = keepalive_interval_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_TIMEOUT_MS), - .value = {.integer = keepalive_timeout_ms}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS), - .value = {.integer = 1}}, - {.type = GRPC_ARG_INTEGER, - .key = const_cast(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE), - .value = {.integer = 64 * 1000 * 1000}}}; - - grpc_channel_args client_args = {.num_args = 6, .args = client_arg_vals}; - args.SetChannelArgs(&client_args); - - // strips out 'grpc://' - auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol); - std::shared_ptr<::grpc::Channel> channel = - ::grpc::CreateCustomChannel(std::string(worker_addr), creds, args); - return grpc::CloudTpuDriver::NewStub(channel); -} - -std::unique_ptr GrpcTpuDriver::AllocateStream(int32_t id) { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - return std::make_unique(id, this, std::move(stub)); -} - -void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - - QuerySystemInfoRequest req; - QuerySystemInfoResponse resp; - ::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return; - } - *system_info = resp.system_info(); -} - -Status GrpcTpuDriver::Reset() { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - ResetRequest req; - ResetResponse resp; - ::grpc::Status status = stub->Reset(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return xla::Status(absl::StatusCode(status.error_code()), - absl::StrCat("Failed to reset TPU driver. Error was: ", - status.error_message(), - ". Details: ", status.error_details())); - } - streams_.clear(); - host_stream_.reset(); - return Close(); -} - -Status GrpcTpuDriver::Close() { - auto stub = CreateTpuDriverStub(config_, creds_); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10)); - CloseRequest req; - req.set_client_id(client_id_); - CloseResponse resp; - ::grpc::Status status = stub->Close(&ctx, req, &resp); - if (!status.ok()) { - return xla::Status(absl::StatusCode(status.error_code()), - absl::StrCat("Failed to close TPU driver. Error was: ", - status.error_message(), - ". Details: ", status.error_details())); - } - closed_ = true; - return OkStatus(); -} -} // namespace - -xla::StatusOr> CreateGrpcTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - auto stub = GrpcTpuDriver::CreateTpuDriverStub(config, creds); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline( - std::chrono::system_clock::now() + - std::chrono::seconds(config.grpc().connection_timeout_secs())); - OpenRequest req; - OpenResponse resp; - ::grpc::Status status = stub->Open(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code() - << ": " << status.error_message() << ": " - << status.error_details(); - return xla::Status( - absl::StatusCode(status.error_code()), - absl::StrCat( - "Failed to connect to remote server at address: ", config.worker(), - ". Error from gRPC: ", status.error_message(), - ". Details: ", status.error_details())); - } - return std::unique_ptr( - new GrpcTpuDriver(config, creds, resp.client_id())); -} - -REGISTER_TPU_DRIVER( - "grpc://", - [](const TpuDriverConfig& config) - -> xla::StatusOr> { - if (absl::StartsWith(config.worker(), "grpc://localhost")) { - LOG(INFO) << "Using local credentials for localhost: connection."; - return CreateGrpcTpuDriver( - config, ::grpc::experimental::LocalCredentials(LOCAL_TCP)); - } else { - return CreateGrpcTpuDriver(config, - ::grpc::InsecureChannelCredentials()); - } - }); - -} // namespace tpu_driver diff --git a/xla/python/tpu_driver/grpc_tpu_driver.h b/xla/python/tpu_driver/grpc_tpu_driver.h deleted file mode 100644 index 88cacbe8c030b..0000000000000 --- a/xla/python/tpu_driver/grpc_tpu_driver.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ -#define XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ - -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#include - -#include "grpcpp/grpcpp.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" - -namespace tpu_driver { - -xla::StatusOr> CreateGrpcTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr credentials); - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_GRPC_TPU_DRIVER_H_ diff --git a/xla/python/tpu_driver/platform/external/compat.h b/xla/python/tpu_driver/platform/external/compat.h deleted file mode 100644 index 934f6d5bccf6f..0000000000000 --- a/xla/python/tpu_driver/platform/external/compat.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#ifndef XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ -#define XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ - -#include // NOLINT - -#include "absl/strings/string_view.h" - -namespace tpu_driver { - -class Thread { - public: - template - explicit Thread(Function&& f, Args&&... args) - : thread_(std::forward(f), std::forward(args)...) {} - void join() { thread_.join(); } - - private: - std::thread thread_; -}; - -class TraceMe { - public: - explicit TraceMe(absl::string_view name, int level = 1) {} - explicit TraceMe(std::string&& name, int level = 1) = delete; - explicit TraceMe(const std::string& name, int level = 1) = delete; - explicit TraceMe(const char* raw, int level = 1) - : TraceMe(absl::string_view(raw), level) {} - template - explicit TraceMe(NameGeneratorT name_generator, int level = 1) {} - ~TraceMe() {} -}; - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_PLATFORM_EXTERNAL_COMPAT_H_ diff --git a/xla/python/tpu_driver/platform/external/tools.bzl b/xla/python/tpu_driver/platform/external/tools.bzl deleted file mode 100644 index 80961fd9f62da..0000000000000 --- a/xla/python/tpu_driver/platform/external/tools.bzl +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Build dependencies and utilities for the TPU driver interface. -""" - -def go_grpc_library(**_kwargs): - # A dummy macro placeholder for compatibility reason. - pass - -def external_deps(): - return [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ] diff --git a/xla/python/tpu_driver/pod_tpu_driver.cc b/xla/python/tpu_driver/pod_tpu_driver.cc deleted file mode 100644 index 40d4532280542..0000000000000 --- a/xla/python/tpu_driver/pod_tpu_driver.cc +++ /dev/null @@ -1,991 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_split.h" -#include "absl/synchronization/mutex.h" -#include "xla/pjrt/semaphore.h" -#include "xla/pjrt/worker_thread.h" -#include "xla/python/tpu_driver/grpc_tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" - -namespace tpu_driver { -namespace { - -#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id) \ - { \ - auto p = CheckHandleExists(container, target_op_id, operation_id); \ - if (p != nullptr) return p; \ - } - -using xla::OkStatus; -using xla::Status; -using xla::WorkerThread; - -const char kPodTpuDriverPrefix[] = "grpc+pod://"; - -class PodTpuDriver; - -class PodEvent : public Event { - public: - explicit PodEvent(PodTpuDriver* driver, int64_t operation_id) - : driver_(driver), operation_id_(operation_id) {} - int64_t operation_id() const { return operation_id_; } - - xla::Status Await() override; - - std::optional AwaitWithTimeout(absl::Duration duration) override; - - void AddCallback(std::function callback) override; - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; -}; - -class ErrorEvent : public PodEvent { - public: - explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status) - : PodEvent(driver, operation_id) { - status_ = status; - } - - xla::Status Await() override { return status_; } - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return status_; - } - void AddCallback(std::function callback) override { - callback(status_); - } - - private: - Status status_; -}; - -class CombinedEvent : public PodEvent { - public: - explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, - std::vector> events) - : PodEvent(driver, operation_id), events_(events) { - for (auto& event : events_) { - event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); }); - } - } - - xla::Status Await() override { - for (auto& event : events_) { - TF_RETURN_IF_ERROR(event->Await()); - } - return OkStatus(); - } - - std::optional AwaitWithTimeout( - absl::Duration duration) override { - for (auto& event : events_) { - auto start_time = absl::Now(); - auto status = event->AwaitWithTimeout(duration); - duration -= absl::Now() - start_time; - if (status == std::nullopt) { - return std::nullopt; - } else { - TF_RETURN_IF_ERROR(status.value()); - } - } - return OkStatus(); - } - - void AddCallback(std::function callback) - ABSL_LOCKS_EXCLUDED(mu_) override { - bool all_events_completed = false; - { - absl::MutexLock l(&mu_); - all_events_completed = events_completed_ == events_.size(); - } - if (all_events_completed) { - callback(event_status_); - } else { - absl::MutexLock l(&mu_); - callbacks_.push_back(std::move(callback)); - } - } - - private: - void IncrementAndCheckComplete(Status s) ABSL_LOCKS_EXCLUDED(mu_) { - std::vector> callbacks; - { - absl::MutexLock l(&mu_); - - event_status_ = s; - events_completed_++; - if (events_completed_ == events_.size()) { - // Copy callbacks to a temporary to be invoked outside the mutex. - callbacks.assign(callbacks_.begin(), callbacks_.end()); - callbacks_.clear(); - } else { - return; - } - } - - for (const auto& callback : callbacks) { - callback(event_status_); - } - } - - absl::Mutex mu_; - std::vector> events_; - std::vector> callbacks_ ABSL_GUARDED_BY(mu_); - int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0; - Status event_status_; -}; - -class PodBufferHandle : public BufferHandle { - public: - explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id, - int64_t size_in_bytes, - std::optional shape, - int64_t core_id) - : driver_(driver), - operation_id_(operation_id), - size_in_bytes_(size_in_bytes), - shape_(shape), - event_(std::make_shared(driver_, operation_id_)), - core_id_(core_id) {} - - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return size_in_bytes_; } - std::optional shape() override { return shape_; } - - int64_t operation_id() const { return operation_id_; } - int64_t core_id() const { return core_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - const int64_t size_in_bytes_; - const std::optional shape_; - std::shared_ptr event_; - const int64_t core_id_; -}; - -class PodCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id) - : driver_(driver), - operation_id_(operation_id), - event_(std::make_shared(driver_, operation_id_)) {} - - std::shared_ptr OnReady() override { return event_; } - - xla::Status program_shape(xla::ProgramShapeProto* program_shape) override; - - int64_t operation_id() const { return operation_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - std::shared_ptr event_; -}; - -class PodLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id, - int64_t core_id) - : driver_(driver), - operation_id_(operation_id), - core_id_(core_id), - event_(std::make_shared(driver_, operation_id_)) {} - - std::shared_ptr OnReady() override { return event_; } - - int64_t operation_id() const { return operation_id_; } - int64_t core_id() const { return core_id_; } - - private: - PodTpuDriver* driver_; - const int64_t operation_id_; - const int64_t core_id_; - std::shared_ptr event_; -}; - -struct EventInFlight { - EventInFlight() - : underlying_event(nullptr), - create_fn(nullptr), - incomplete_deps(), - callbacks() {} - - std::shared_ptr underlying_event; - std::function(void)> create_fn; - - absl::flat_hash_set incomplete_deps; - std::vector> callbacks; -}; - -class PodTpuDriver : public TpuDriver { - public: - explicit PodTpuDriver(const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) - : config_(config), - creds_(creds), - event_thread_(tsl::Env::Default(), "grpc_pod_event_thread") { - std::vector workers = absl::StrSplit( - absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ','); - - int worker_count = 0; - - // Flag for environments where local core # == all cores in TPU system #, - // which means that we are connecting to separate TPU systems or we are in - // a test environment. - bool in_local_core_environment = false; - - for (const auto& worker : workers) { - TpuDriverConfig worker_config(config_); - *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker); - auto tpu_driver = CreateGrpcTpuDriver(worker_config, creds_).value(); - - SystemInfo driver_info; - tpu_driver->QuerySystemInfo(&driver_info); - - if (driver_info.core_count() == driver_info.local_core_size()) { - drivers_.insert({worker_count, std::move(tpu_driver)}); - in_local_core_environment = true; - } else { - drivers_.insert({driver_info.host_id(), std::move(tpu_driver)}); - } - - worker_count++; - } - - absl::flat_hash_set> processed_chips; - - for (int driver_num = 0; driver_num < workers.size(); ++driver_num) { - SystemInfo driver_info; - drivers_[driver_num]->QuerySystemInfo(&driver_info); - - for (const auto& tpu_chip : driver_info.tpu_chip()) { - std::tuple coord{tpu_chip.chip_coord().x(), - tpu_chip.chip_coord().y(), - tpu_chip.chip_coord().z()}; - // We only want to add chips that we have not seen before if we are in a - // TPU pod slice, or we are only seeing local cores (e.g. we are - // connected to individual TPUs or we are in a test environment). - if (!processed_chips.contains(coord) || - driver_info.core_count() == driver_info.local_core_size()) { - *(pod_info_.add_tpu_chip()) = tpu_chip; - processed_chips.insert(coord); - } - } - - *(pod_info_.mutable_cpu()) = driver_info.cpu(); - } - - // Process all the unique chips that we have seen. - int core_count = 0; - for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) { - for (auto& tpu_core : *tpu_chip.mutable_core()) { - int current_core = tpu_core.id(); - if (in_local_core_environment) { - current_core = core_count; - } - - core_to_driver_.insert( - {current_core, drivers_[tpu_chip.host_id()].get()}); - core_to_driver_id_.insert({current_core, tpu_chip.host_id()}); - core_to_driver_core_.insert({current_core, tpu_core.id()}); - - tpu_core.set_id(current_core); - tpu_core.set_core_on_host_index(current_core); - *(pod_info_.add_local_core()) = tpu_core; - - core_count++; - } - - // We are setting host_id to zero because we want this to look like one - // host with many cores from the perspective of tpu_client.cc. - tpu_chip.set_host_id(0); - } - - pod_info_.set_chip_count(pod_info_.tpu_chip_size()); - pod_info_.set_core_count(pod_info_.local_core_size()); - - // We want this to look like one host with many TPU chips/cores connected. - pod_info_.set_host_count(1); - pod_info_.set_host_id(0); - } - - ~PodTpuDriver() override { - // TODO(frankchn): Unload all handles, and wait for all events to finish. - } - - void QuerySystemInfo(SystemInfo* system_info) override { - *system_info = pod_info_; - } - - xla::Status Reset() override { - for (auto& driver : drivers_) { - TF_RETURN_IF_ERROR(driver.second->Reset()); - } - return OkStatus(); - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, core_id, region, num_bytes, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - underlying_buffers_.insert( - {operation_id, - core_to_driver_[core_id]->Allocate( - core_to_driver_core_[core_id], region, num_bytes, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, num_bytes, - std::nullopt, core_id); - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, core_id, region, shape, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - underlying_buffers_.insert( - {operation_id, - core_to_driver_[core_id]->Allocate( - core_to_driver_core_[core_id], region, shape, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique( - this, operation_id, ComputeBytesFromShape(shape), shape, core_id); - } - - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - std::vector children_ids; - const size_t children_ids_size = children.size(); - children_ids.reserve(children_ids_size); - for (size_t i = 0; i < children_ids_size; ++i) { - auto child_op_id = - static_cast(children[i])->operation_id(); - deps.insert(child_op_id); - children_ids.push_back(child_op_id); - } - - ScheduleRequest( - operation_id, - [this, core_id, region, children_ids, operation_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - std::vector child_buffers; - child_buffers.reserve(children_ids.size()); - for (size_t i = 0; i < children_ids.size(); ++i) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i], - operation_id); - child_buffers.push_back( - underlying_buffers_[children_ids[i]].get()); - } - - underlying_buffers_.insert( - {operation_id, core_to_driver_[core_id]->AllocateTuple( - core_to_driver_core_[core_id], region, - child_buffers, {})}); - return underlying_buffers_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, 0, - std::nullopt, core_id); - } - - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(handle.get())->operation_id()); - - auto op_id = static_cast(handle.get())->operation_id(); - auto core_id = static_cast(handle.get())->core_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - - auto buf_iter = underlying_buffers_.find(op_id); - auto underlying_hn = std::move(buf_iter->second); - underlying_buffers_.erase(buf_iter); - - return core_to_driver_[core_id]->Deallocate( - std::move(underlying_hn), {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(dst)->operation_id()); - - auto op_id = static_cast(dst)->operation_id(); - auto core_id = static_cast(dst)->core_id(); - - ScheduleRequest( - operation_id, - [this, src, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - - auto buf_iter = underlying_buffers_.find(op_id); - return core_to_driver_[core_id]->TransferToDevice( - src, buf_iter->second.get(), {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(src)->operation_id()); - - auto op_id = static_cast(src)->operation_id(); - auto core_id = static_cast(src)->core_id(); - - ScheduleRequest( - operation_id, - [this, dst, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); - auto buf_iter = underlying_buffers_.find(op_id); - return core_to_driver_[core_id]->TransferFromDevice( - buf_iter->second.get(), dst, {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto src_core_id = static_cast(src)->core_id(); - auto dst_core_id = static_cast(dst)->core_id(); - - auto src_driver_id = core_to_driver_id_[src_core_id]; - auto dst_driver_id = core_to_driver_id_[dst_core_id]; - - if (src_driver_id == dst_driver_id) { - // They are in the same host, we can schedule it normally - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(src)->operation_id()); - deps.insert(static_cast(dst)->operation_id()); - - auto src_op_id = static_cast(src)->operation_id(); - auto dst_op_id = static_cast(dst)->operation_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, src_op_id, dst_op_id, dst_core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id, - operation_id); - CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id, - operation_id); - - auto src_iter = underlying_buffers_.find(src_op_id); - auto dst_iter = underlying_buffers_.find(dst_op_id); - return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice( - src_iter->second.get(), dst_iter->second.get(), {}); - }, - deps); - return std::make_shared(this, operation_id); - } else { - // src and dst are on different hosts, we have to bounce through us. - auto dst_size = dst->size_in_bytes(); - char* host_buf = new char[dst_size]; - - auto src_event = TransferFromDevice(src, host_buf, wait_for); - auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()}); - dst_event->AddCallback( - [src_event, host_buf](xla::Status status) { delete[] host_buf; }); - return dst_event; - } - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - - ScheduleRequest( - operation_id, - [this, operation_id, source, num_replicas, - debug_options]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - auto cph_iterator = - underlying_cph_ - .insert( - {operation_id, - std::vector>()}) - .first; - - std::vector> collected_events; - for (int i = 0; i < drivers_.size(); ++i) { - auto current_cph = drivers_[i]->CompileProgram(source, num_replicas, - {}, debug_options); - cph_iterator->second.push_back(std::move(current_cph)); - collected_events.push_back(cph_iterator->second[i]->OnReady()); - } - return std::make_shared(this, operation_id, - collected_events); - }, - deps); - - return std::make_unique(this, operation_id); - } - - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert( - static_cast(handle)->operation_id()); - auto cph_op_id = - static_cast(handle)->operation_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, cph_op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id); - auto cph_iter = underlying_cph_.find(cph_op_id); - - underlying_lph_.insert( - {operation_id, - core_to_driver_[core_id]->LoadProgram( - core_to_driver_core_[core_id], - cph_iter->second[core_to_driver_id_[core_id]].get(), - {})}); - - return underlying_lph_[operation_id]->OnReady(); - }, - deps); - - return std::make_unique(this, operation_id, - core_id); - } - - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert( - static_cast(handle.get())->operation_id()); - auto op_id = - static_cast(handle.get())->operation_id(); - auto core_id = - static_cast(handle.get())->core_id(); - - ScheduleRequest( - operation_id, - [this, operation_id, op_id, core_id]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); - auto lph_iter = underlying_lph_.find(op_id); - auto event = core_to_driver_[core_id]->UnloadProgram( - std::move(lph_iter->second), {}); - underlying_lph_.erase(lph_iter); - - return event; - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(program)->operation_id()); - - auto op_id = static_cast(program)->operation_id(); - auto core_id = static_cast(program)->core_id(); - - std::vector input_op_ids; - std::vector output_op_ids; - input_op_ids.reserve(inputs.size()); - output_op_ids.reserve(outputs.size()); - - for (auto* input : inputs) { - auto input_dep = - static_cast(input)->operation_id(); - input_op_ids.push_back(input_dep); - deps.insert(input_dep); - } - for (auto* output : outputs) { - auto output_dep = - static_cast(output)->operation_id(); - output_op_ids.push_back(output_dep); - deps.insert(output_dep); - } - - ScheduleRequest( - operation_id, - [this, operation_id, core_id, op_id, input_op_ids, output_op_ids, - device_assignment]() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { - std::vector underlying_inputs; - std::vector underlying_outputs; - - underlying_inputs.reserve(input_op_ids.size()); - for (auto input_op_id : input_op_ids) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id, - operation_id); - underlying_inputs.push_back( - underlying_buffers_[input_op_id].get()); - } - underlying_outputs.reserve(output_op_ids.size()); - for (auto output_op_id : output_op_ids) { - CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id, - operation_id); - underlying_outputs.push_back( - underlying_buffers_[output_op_id].get()); - } - - CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); - LoadedProgramHandle* handle = underlying_lph_[op_id].get(); - return core_to_driver_[core_id]->ExecuteProgram( - handle, underlying_inputs, underlying_outputs, - device_assignment, {}); - }, - deps); - - return std::make_shared(this, operation_id); - } - - std::unique_ptr GetLinearizer() override { - return drivers_[0]->GetLinearizer(); - } - - // Helper methods for Event scheduling - - std::optional WaitForEvent(int64_t event_id, absl::Duration duration) - ABSL_LOCKS_EXCLUDED(mu_) { - std::shared_ptr underlying_event; - - { - absl::MutexLock l(&mu_); - auto event = events_.find(event_id); - - if (event == events_.end()) { - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - return OkStatus(); - } else { - return event_status->second; - } - } - - auto done = [this, event_id]() { - mu_.AssertHeld(); - // The event was either completed and erased from the map or we have - // an underlying event available to us. - return events_.count(event_id) == 0 || - (events_[event_id]->underlying_event != nullptr && - events_[event_id]->underlying_event.use_count() != 0); - }; - - auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); - if (!status) { - return std::nullopt; - } - - if (events_.count(event_id) > 0) { - underlying_event = events_[event_id]->underlying_event; - } else { - underlying_event = nullptr; - } - } - - // Wait for the underlying event without holding on to the event_lock_, or - // else incoming events will not be processed. - if (underlying_event != nullptr) { - return underlying_event->AwaitWithTimeout(duration); - } else { - absl::MutexLock l(&mu_); - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - return OkStatus(); - } else { - return event_status->second; - } - } - } - - void AddCallbackForEvent(int64_t event_id, std::function fn) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - auto event = events_.find(event_id); - - if (event == events_.end()) { - auto event_status = abnormal_event_status_.find(event_id); - if (event_status == abnormal_event_status_.end()) { - fn(OkStatus()); - } else { - fn(event_status->second); - } - } else { - if (event->second->underlying_event != nullptr && - event->second->underlying_event.use_count() != 0) { - event->second->underlying_event->AddCallback(fn); - } else { - event->second->callbacks.push_back(std::move(fn)); - } - } - } - - xla::Status GetCompiledProgramShape(int64_t op_id, - xla::ProgramShapeProto* program_shape) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - - auto done = [this, op_id]() { - mu_.AssertHeld(); - return underlying_cph_.contains(op_id); - }; - mu_.Await(absl::Condition(&done)); - - return underlying_cph_[op_id][0]->program_shape(program_shape); - } - - private: - const TpuDriverConfig& config_; - std::shared_ptr<::grpc::ChannelCredentials> creds_; - - absl::flat_hash_map> drivers_; - absl::flat_hash_map core_to_driver_id_; - absl::flat_hash_map core_to_driver_; - absl::flat_hash_map core_to_driver_core_; - SystemInfo pod_info_; - - absl::Mutex mu_; - - absl::flat_hash_map> - underlying_buffers_ ABSL_GUARDED_BY(mu_); - absl::flat_hash_map>> - underlying_cph_ ABSL_GUARDED_BY(mu_); - absl::flat_hash_map> - underlying_lph_ ABSL_GUARDED_BY(mu_); - - absl::btree_map> events_ - ABSL_GUARDED_BY(mu_); - absl::flat_hash_map abnormal_event_status_ - ABSL_GUARDED_BY(mu_); - - std::atomic operation_id_counter_{0}; - - WorkerThread event_thread_; - - int64_t GetOperationId() { return operation_id_counter_++; } - - absl::flat_hash_set GetDependencyOperationIds( - absl::Span wait_for) { - absl::flat_hash_set deps; - for (auto* event : wait_for) { - deps.insert(static_cast(event)->operation_id()); - } - return deps; - } - - // EventCompleted is executed on the event_thread_ worker thread. We want - // to propagate the fact that the event is completed to any subsequent events - // that might depend on this event. - void EventCompleted(int64_t event_id, Status status) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - - absl::btree_map>::iterator - curr_event; - if (!status.ok()) abnormal_event_status_.insert({event_id, status}); - curr_event = events_.find(event_id); - - DCHECK(curr_event->second->callbacks.empty()); - DCHECK(curr_event->second->incomplete_deps.empty()); - - for (auto& event : events_) { - event.second->incomplete_deps.erase(event_id); - // The if statement conditions on both - // - all previous events have completed (incomplete_deps.empty()) - // - the op creating this event has not been called yet - // (event.second.create_fn != nullptr) - // We call the create_fn that creates the event and adds any relevant - // callbacks to the actual event, before setting create_fn to nullptr - // to indicate that it has already been called - if (event.second->incomplete_deps.empty() && - event.second->create_fn != nullptr) { - // We were the last unfilled dependency, all other dependencies are - // filled. We can now fire the create function. - event.second->underlying_event = event.second->create_fn(); - for (auto& fn : event.second->callbacks) { - event.second->underlying_event->AddCallback(std::move(fn)); - } - event.second->callbacks.clear(); - event.second->create_fn = nullptr; - } - } - - // We erase the current event to signal that it has finished. - events_.erase(curr_event); - } - - void ScheduleRequest(int64_t operation_id, - std::function(void)> fn, - const absl::flat_hash_set& deps) - ABSL_LOCKS_EXCLUDED(mu_) { - absl::MutexLock l(&mu_); - absl::btree_map>::iterator event; - absl::flat_hash_set incomplete_deps; - - event = - events_.insert({operation_id, std::make_unique()}).first; - for (const auto& dep : deps) { - if (events_.count(dep) > 0) incomplete_deps.insert(dep); - } - - if (incomplete_deps.empty()) { - // All dependencies have been fulfilled, we execute the request - // immediately and add a callback to inform our event fulfilled thread - // when it is done. - event->second->create_fn = nullptr; - event->second->underlying_event = fn(); - event->second->underlying_event->AddCallback( - [this, operation_id](Status status) { - event_thread_.Schedule([this, operation_id, status]() { - EventCompleted(operation_id, status); - }); - }); - } else { - // There are some dependencies that are not yet fulfilled. We attach - // the request to the event, and will execute it in the EventFulfilled - // worker thread when all its dependencies are fulfilled. - event->second->create_fn = std::move(fn); - event->second->incomplete_deps = std::move(incomplete_deps); - event->second->callbacks.push_back([this, operation_id](Status status) { - event_thread_.Schedule([this, operation_id, status]() { - EventCompleted(operation_id, status); - }); - }); - } - } - - template - std::shared_ptr CheckHandleExists( - absl::flat_hash_map& container, int64_t target_op_id, - int64_t operation_id) { - if (container.count(target_op_id) == 0) { - return std::make_shared( - this, operation_id, - tsl::errors::InvalidArgument("Handle ", target_op_id, - " does not exist.")); - } - return nullptr; - } -}; - -xla::Status PodEvent::Await() { - return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value(); -} - -std::optional PodEvent::AwaitWithTimeout(absl::Duration duration) { - return driver_->WaitForEvent(operation_id_, duration); -} - -void PodEvent::AddCallback(std::function callback) { - driver_->AddCallbackForEvent(operation_id_, std::move(callback)); -} - -xla::StatusOr> CreatePodTpuDriver( - const TpuDriverConfig& config, - std::shared_ptr<::grpc::ChannelCredentials> creds) { - return std::unique_ptr(new PodTpuDriver(config, creds)); -} - -xla::Status PodCompiledProgramHandle::program_shape( - xla::ProgramShapeProto* program_shape) { - return driver_->GetCompiledProgramShape(operation_id(), program_shape); -} - -} // namespace - -REGISTER_TPU_DRIVER(kPodTpuDriverPrefix, - [](const TpuDriverConfig& config) - -> xla::StatusOr> { - return CreatePodTpuDriver( - config, - ::grpc::InsecureChannelCredentials()); // NOLINT - }); - -} // namespace tpu_driver diff --git a/xla/python/tpu_driver/recording_tpu_driver.cc b/xla/python/tpu_driver/recording_tpu_driver.cc deleted file mode 100644 index fb53c008a3fd7..0000000000000 --- a/xla/python/tpu_driver/recording_tpu_driver.cc +++ /dev/null @@ -1,590 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/internal/sysinfo.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/python/tpu_driver/tpu_service.grpc.pb.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/threadpool.h" - -/* - * The ReplayDriver wraps a concrete TpuDriver implementation and records the - * stream of operations to a log file. This log can be later replayed and - * analyzed for debugging. - */ - -namespace tpu_driver { -namespace { - -static std::atomic id_counter(0); - -using xla::Status; - -class RecordingTpuDriver; - -class RecordingEvent : public Event { - public: - explicit RecordingEvent(std::shared_ptr event) - : shared_event_(std::move(event)), id_(id_counter++) {} - - explicit RecordingEvent(std::shared_ptr event, int64_t id) - : shared_event_(event), id_(id) {} - - ~RecordingEvent() override = default; - - xla::Status Await() override { return shared_event_->Await(); } - - std::optional AwaitWithTimeout( - absl::Duration duration) override { - return shared_event_->AwaitWithTimeout(duration); - } - - void AddCallback(std::function callback) override { - return shared_event_->AddCallback(callback); - } - - private: - std::shared_ptr shared_event_; - - int64_t id_; - friend class RecordingTpuDriver; -}; - -class RecordingBufferHandle : public BufferHandle { - public: - explicit RecordingBufferHandle(std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - std::optional shape() override { return handle_->shape(); } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingCompiledProgramHandle : public CompiledProgramHandle { - public: - explicit RecordingCompiledProgramHandle( - std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - xla::Status program_shape(xla::ProgramShapeProto* program_shape) override { - return handle_->program_shape(program_shape); - } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingLoadedProgramHandle : public LoadedProgramHandle { - public: - explicit RecordingLoadedProgramHandle( - std::unique_ptr handle) - : handle_(std::move(handle)), - id_(id_counter++), - event_(std::make_shared(handle_->OnReady(), id_)) {} - std::shared_ptr OnReady() override { return event_; } - int64_t size_in_bytes() override { return handle_->size_in_bytes(); } - - private: - std::unique_ptr handle_; - int64_t id_; - std::shared_ptr event_; - friend class RecordingTpuDriver; -}; - -class RecordingTpuDriver : public TpuDriver { - public: - explicit RecordingTpuDriver(std::unique_ptr driver, - const std::string recording_path, - const bool flush) - : driver_(std::move(driver)), - recording_path_(recording_path), - flush_(flush) { - auto file_status = - tsl::Env::Default()->NewAppendableFile(recording_path_, &log_file_); - if (!file_status.ok()) { - LOG(FATAL) << "Unable to open " << recording_path_ - << " for appending. Error: " << file_status; - } - } - ~RecordingTpuDriver() override { - { - log_file_->Flush().IgnoreError(); - log_file_->Close().IgnoreError(); - log_file_ = nullptr; - } - } - - void QuerySystemInfo(SystemInfo* system_info) override { - // TODO(frankchn): Should we even save this event, since it is out-of-band. - driver_->QuerySystemInfo(system_info); - } - - Status Reset() override { return driver_->Reset(); } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto handle = - driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc()->set_core_id(core_id); - r.mutable_alloc()->set_region(region); - r.mutable_alloc()->set_num_bytes(num_bytes); - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc()->set_core_id(core_id); - r.mutable_alloc()->set_region(region); - *(r.mutable_alloc()->mutable_shape()) = shape; - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - std::vector unwrapped_children; - std::vector child_ids; - const auto children_size = children.size(); - unwrapped_children.reserve(children_size); - child_ids.reserve(children_size); - for (auto child : children) { - BufferHandle* unwrapped_child = - static_cast(child)->handle_.get(); - unwrapped_children.push_back(unwrapped_child); - child_ids.push_back( - static_cast(child)->id_); - } - - auto thread_id = GetCurrentThreadId(); - auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children, - unwrapped_wait_for); - auto recording_handle = - std::make_unique(std::move(handle)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - r.mutable_alloc_tuple()->set_core_id(core_id); - r.mutable_alloc_tuple()->set_region(region); - - for (auto child : child_ids) { - r.mutable_alloc_tuple()->add_children(child); - } - - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = static_cast(handle.get()); - int64_t recording_handle_id = recording_handle->id_; - auto event = driver_->Deallocate(std::move(recording_handle->handle_), - unwrapped_wait_for); - auto recording_event = std::make_shared(std::move(event)); - int64_t event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_dealloc()->set_handle(recording_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) override { - int64_t num_bytes = dst->size_in_bytes(); - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = static_cast(dst); - int64_t recording_handle_id = recording_handle->id_; - auto recording_event = - std::make_shared(driver_->TransferToDevice( - src, static_cast(dst)->handle_.get(), - unwrapped_wait_for)); - int64_t event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_to()->set_target_handle(recording_handle_id); - if (num_bytes > 0) { - r.mutable_transfer_to()->mutable_data()->assign( - static_cast(src), num_bytes); - } else { - *r.mutable_transfer_to()->mutable_data() = ""; - } - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto src_handle_id = static_cast(src)->id_; - auto recording_event = - std::make_shared(driver_->TransferFromDevice( - static_cast(src)->handle_.get(), dst, - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_from()->set_source_handle(src_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto src_handle_id = static_cast(src)->id_; - auto dst_handle_id = static_cast(dst)->id_; - auto recording_event = - std::make_shared(driver_->TransferFromDeviceToDevice( - static_cast(src)->handle_.get(), - static_cast(dst)->handle_.get(), - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_transfer_from_to()->set_source_handle(src_handle_id); - r.mutable_transfer_from_to()->set_target_handle(dst_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto recording_handle = std::make_unique( - driver_->CompileProgram(source, num_replicas, unwrapped_wait_for, - debug_options)); - auto handle_id = recording_handle->id_; - - { - StreamRequest::Entry r; - *r.mutable_compile()->mutable_hlo_program() = source; - r.mutable_compile()->set_num_replicas(num_replicas); - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto compiled_handle_id = - static_cast(handle)->id_; - auto recording_handle = - std::make_unique(driver_->LoadProgram( - core_id, - static_cast(handle) - ->handle_.get(), - unwrapped_wait_for)); - auto handle_id = recording_handle->id_; - { - StreamRequest::Entry r; - r.mutable_load()->set_core_id(core_id); - r.mutable_load()->set_compiled_program_handle(compiled_handle_id); - PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id); - } - - return recording_handle; - } - - std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto loaded_handle_id = - static_cast(handle.get())->id_; - auto recording_event = - std::make_shared(driver_->UnloadProgram( - std::move(static_cast(handle.get()) - ->handle_), - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_unload()->set_loaded_program_handle(loaded_handle_id); - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) override { - auto unwrapped_wait_for = UnwrapWaitFor(wait_for); - - auto thread_id = GetCurrentThreadId(); - auto program_handle_id = - static_cast(program)->id_; - - std::vector unwrapped_inputs; - std::vector input_ids; - const auto inputs_size = inputs.size(); - unwrapped_inputs.reserve(inputs_size); - input_ids.reserve(inputs_size); - for (auto input : inputs) { - BufferHandle* unwrapped_input = - static_cast(input)->handle_.get(); - unwrapped_inputs.push_back(unwrapped_input); - input_ids.push_back( - static_cast(input)->id_); - } - - std::vector unwrapped_outputs; - std::vector output_ids; - const auto output_size = outputs.size(); - unwrapped_outputs.reserve(output_size); - output_ids.reserve(output_size); - for (auto output : outputs) { - BufferHandle* unwrapped_output = - static_cast(output)->handle_.get(); - unwrapped_outputs.push_back(unwrapped_output); - output_ids.push_back( - static_cast(output)->id_); - } - - auto recording_event = - std::make_shared(driver_->ExecuteProgram( - static_cast(program)->handle_.get(), - unwrapped_inputs, unwrapped_outputs, device_assignment, - unwrapped_wait_for)); - auto event_id = recording_event->id_; - - { - StreamRequest::Entry r; - r.mutable_execute()->set_loaded_program_handle(program_handle_id); - for (auto input_id : input_ids) { - r.mutable_execute()->add_input_handle(input_id); - } - for (auto output_id : output_ids) { - r.mutable_execute()->add_output_handle(output_id); - } - *r.mutable_execute()->mutable_device_assignment() = device_assignment; - - PopulateAndSaveEntry(&r, wait_for, event_id, thread_id); - } - - return recording_event; - } - - std::unique_ptr GetLinearizer() override { - return driver_->GetLinearizer(); - } - - private: - std::unique_ptr driver_; - const std::string recording_path_; - const bool flush_; - - std::unique_ptr log_file_; - - void PopulateAndSaveEntry(StreamRequest::Entry* r, - absl::Span wait_for, - int64_t handle_id, int64_t thread_id) { - for (auto event : wait_for) { - auto recording_event = static_cast(event); - r->add_wait_for_id(recording_event->id_); - } - r->set_operation_id(handle_id); - r->set_thread_id(thread_id); - - uint64_t data_size = r->ByteSizeLong(); - std::vector buffer; - buffer.resize(sizeof(data_size) + data_size); - memcpy(buffer.data(), &data_size, sizeof(data_size)); - r->SerializeToArray(buffer.data() + sizeof(data_size), data_size); - - { - if (log_file_ == nullptr) { - LOG(WARNING) << "The TPU driver has been shut down before all logging " - "has been written."; - return; - } - - absl::string_view buffer_sp(buffer.data(), buffer.size()); - auto data_status = log_file_->Append(buffer_sp); - if (!data_status.ok()) { - LOG(WARNING) << "Unable to write data to log file. File possibly " - "corrupt. Error: " - << data_status; - } - - if (flush_) { - auto flush_status = log_file_->Flush(); - if (!flush_status.ok()) { - LOG(WARNING) << "Unable to flush data to log file. File possibly " - "corrupt. Error: " - << flush_status; - } - - auto sync_status = log_file_->Sync(); - if (!sync_status.ok()) { - LOG(WARNING) << "Unable to sync log file. File possibly " - "corrupt. Error: " - << sync_status; - } - } - } - } - - std::vector UnwrapWaitFor(absl::Span wait_for) { - std::vector unwrapped_events; - for (auto event : wait_for) { - Event* unwrapped_event = - static_cast(event)->shared_event_.get(); - unwrapped_events.push_back(unwrapped_event); - } - return unwrapped_events; - } - - int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); } -}; - -xla::StatusOr> RegisterRecordingTpuDriver( - const TpuDriverConfig& config) { - std::vector configs = absl::StrSplit(config.worker(), '|'); - - std::string file; - std::string worker; - bool flush = false; - - for (const auto& config : configs) { - std::vector kv = - absl::StrSplit(config, absl::MaxSplits('=', 1)); - if (kv[0] == "file") { - file = kv[1]; - } - if (kv[0] == "worker") { - worker = kv[1]; - } - if (kv[0] == "flush") { - if (kv[1] == "true" || kv[1] == "1") { - flush = true; - } - } - } - - TpuDriverConfig worker_config; - worker_config.set_worker(worker); - - auto driver_status = TpuDriverRegistry::Open(worker_config); - if (!driver_status.ok()) return driver_status.status(); - return std::unique_ptr( - new RecordingTpuDriver(std::move(driver_status).value(), file, flush)); -} - -// To record a sequence of operations, set the worker configuration string to -// record://|file=|worker=grpc://1.2.3.4:8470 (for GRPC). -REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver); - -} // namespace -} // namespace tpu_driver diff --git a/xla/python/tpu_driver/tpu_driver.cc b/xla/python/tpu_driver/tpu_driver.cc deleted file mode 100644 index b1113601a7ee5..0000000000000 --- a/xla/python/tpu_driver/tpu_driver.cc +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "xla/python/tpu_driver/tpu_driver.h" - -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "absl/synchronization/mutex.h" -#include "xla/util.h" - -namespace tpu_driver { - -namespace { - -typedef absl::flat_hash_map< - std::string, std::function>( - const TpuDriverConfig&)>> - DriverRegistryMap; - -DriverRegistryMap* GetDriverRegistryMap() { - static DriverRegistryMap* driver_registry = new DriverRegistryMap(); - return driver_registry; -} - -int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { - switch (primitive_type) { - case xla::PrimitiveType::PRED: - return sizeof(int8_t); - case xla::PrimitiveType::S8: - return sizeof(int8_t); - case xla::PrimitiveType::S16: - return sizeof(int16_t); - case xla::PrimitiveType::S32: - return sizeof(int32_t); - case xla::PrimitiveType::S64: - return sizeof(int64_t); - case xla::PrimitiveType::U8: - return sizeof(uint8_t); - case xla::PrimitiveType::U16: - return sizeof(uint16_t); - case xla::PrimitiveType::U32: - return sizeof(uint32_t); - case xla::PrimitiveType::U64: - return sizeof(uint64_t); - case xla::PrimitiveType::BF16: - return sizeof(float) / 2; - case xla::PrimitiveType::F16: - return sizeof(float) / 2; - case xla::PrimitiveType::F32: - return sizeof(float); - case xla::PrimitiveType::F64: - return sizeof(double); - case xla::PrimitiveType::C64: - return sizeof(std::complex); - case xla::PrimitiveType::C128: - return sizeof(std::complex); - case xla::PrimitiveType::TOKEN: - case xla::PrimitiveType::TUPLE: - case xla::PrimitiveType::OPAQUE_TYPE: - LOG(FATAL) << PrimitiveType_Name(primitive_type) - << " primitive type has no definitive size"; - default: - LOG(FATAL) << "Unhandled primitive type " << primitive_type; - } -} - -} // namespace - -/*static*/ int TpuDriverRegistry::RegisterDriver( - const std::string& prefix, - const std::function>( - const TpuDriverConfig&)>& creator) { - (*GetDriverRegistryMap())[prefix] = creator; - return GetDriverRegistryMap()->size(); -} - -/*static*/ xla::StatusOr> TpuDriverRegistry::Open( - const TpuDriverConfig& config) { - for (const auto& driver : *GetDriverRegistryMap()) { - if (absl::StartsWith(config.worker(), driver.first)) { - return driver.second(config); - } - } - return xla::NotFound("Unable to find driver in registry given worker: %s", - config.worker()); -} - -int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { - if (shape.tuple_shapes_size() > 0) { - LOG(FATAL) << "Tuples are not supported at the moment."; - } - - int64_t num_elems = 1; - for (auto dim : shape.dimensions()) { - num_elems *= dim; - } - - return ByteSizeOfPrimitiveType(shape.element_type()) * num_elems; -} - -} // namespace tpu_driver diff --git a/xla/python/tpu_driver/tpu_driver.h b/xla/python/tpu_driver/tpu_driver.h deleted file mode 100644 index ebf247b6672bc..0000000000000 --- a/xla/python/tpu_driver/tpu_driver.h +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -#ifndef XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ -#define XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/python/tpu_driver/platform/external/compat.h" -#include "xla/python/tpu_driver/tpu_driver.pb.h" -#include "xla/service/hlo.pb.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" - -// This API is EXPERIMENTAL and under active development. It is subject to -// change without notice. - -namespace tpu_driver { - -int64_t ComputeBytesFromShape(const xla::ShapeProto& shape); - -// Represents the deferred completion of a scheduled operation. -// -// Events may be blocked on, or used as `wait_for` arguments to enforce -// inter-operation dependencies. -class Event { - public: - virtual ~Event() = default; - - // Blocks until the event completes and returns the result status. - virtual xla::Status Await() = 0; - // Returns an empty result if the wait times out. - virtual std::optional AwaitWithTimeout( - absl::Duration duration) = 0; - - // If the event is already done, the callback is called immediately. - virtual void AddCallback(std::function callback) = 0; -}; - -// Represents a device memory allocation. -class BufferHandle { - public: - virtual ~BufferHandle() = default; - - // This event completes after the device memory is actually allocated. - // - // Methods that take a buffer handle, such as ExecuteProgram and Transfer*, - // automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() = 0; - virtual std::optional shape() = 0; -}; - -// Represents a compiled program on the host. -class CompiledProgramHandle { - public: - virtual ~CompiledProgramHandle() = default; - - // This Event completes after the program is actually compiled on the host. - // - // Methods that take a compiled program handle, including LoadProgram, - // automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() { - LOG(FATAL) << "Unimplemented."; - return 0; - } - - // Returns the shape of the compiled program. Blocks until compile completes. - virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0; -}; - -// Represents a program loaded on the device. -class LoadedProgramHandle { - public: - virtual ~LoadedProgramHandle() = default; - - // This Event completes after the program is actually loaded on the device. - // - // Methods that take a loaded program handle, including ExecuteProgram and - // UnloadProgram, automatically add this event as a dependency. - virtual std::shared_ptr OnReady() = 0; - - virtual int64_t size_in_bytes() { - LOG(FATAL) << "Unimplemented."; - return 0; - } -}; - -// A TpuLinearizer manages the linearization and delinearization of user buffers -// in the TPU driver. This interface is not yet implemented. -class TpuLinearizer { - public: - virtual ~TpuLinearizer() = default; - - int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { - return ::tpu_driver::ComputeBytesFromShape(shape); - } - virtual int64_t ComputeLinearizedBytesFromShape( - const xla::ShapeProto& shape) = 0; - - virtual xla::Status LinearizeShape(void* dst, const void* src, - const xla::ShapeProto& shape) = 0; - virtual xla::Status DelinearizeShape(void* dst, const void* src, - const xla::ShapeProto& shape) = 0; -}; - -// A TpuDriver manages a set of operations scheduled to run on a TPU system. -// -// By default, two independently scheduled operations may execute in any order. -// Ordering can be imposed in one of two ways: -// -// 1. Users can specify event dependencies via the `wait_for` argument. -// 2. Operations using buffer or program handles implicitly wait for the handles -// to become ready before executing. -// -// For returned handle objects, the user is responsible for calling the release -// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr -// arguments and free up device resources. For returned event objects, there is -// no release method; the user can let them go out of scope naturally. As soon -// as those methods accepting plain-pointer arguments return, the user can let -// the corresponding smart-pointer objects be released or go out of scope, -// regardless of whether the scheduled device operations have started execution. -class TpuDriver { - public: - virtual ~TpuDriver() = default; - - virtual void QuerySystemInfo(SystemInfo* system_info) = 0; - // Synchronous. Reset the state of the TPU driver. After Reset(), this TPU - // driver object is no longer usable. Users must destroy this object and - // create a new one. - // - // All running programs will be terminated and all allocations reset. All - // events and buffer handles created prior to Reset() will be invalid, and any - // use will result in undefined behavior. - virtual xla::Status Reset() = 0; - - virtual std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, int64_t num_bytes, - absl::Span wait_for) = 0; - virtual std::unique_ptr Allocate( - int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, - absl::Span wait_for) = 0; - - // Allocate a buffer representing a tuple of `children` buffers. - // - // The returned tuple buffer handle does not manage the memory of `children`: - // all `children` buffer handles must outlive the last usage of this tuple - // buffer handle. One way to guarantee that is to deallocate the tuple buffer - // handle before deallocating any buffer handle in `children`. - // - // All `children` buffers must exist in the same `core_id` and `region`. - // If `children` is empty, a zero-sized tuple will be allocated in `region`. - virtual std::unique_ptr AllocateTuple( - int32_t core_id, MemoryRegion region, - absl::Span children, - absl::Span wait_for) = 0; - virtual std::shared_ptr Deallocate( - std::unique_ptr handle, - absl::Span wait_for) = 0; - - /* For buffers declared with an xla::ShapeProto rather than a raw size, - * `src` must be laid out in consecutive row-major format for ingestion, and - * each element must take up the number of bytes specified by the type. - * - * For example, for a [3,3,3] tensor with a Float32 type, the memory layout - * would be as follows: - * - * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ... - * [1,2,2], [2,0,0], ..., [2,2,2], - * - * and the entire buffer will be 108 bytes (27 elements x 4 bytes). - * - * See - * https://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays - * for a more detailed description. - * - * `TransferFromDevice` will write out the shape back in this order as well. - */ - virtual std::shared_ptr TransferToDevice( - const void* src, BufferHandle* dst, - absl::Span wait_for) = 0; - virtual std::shared_ptr TransferFromDevice( - const BufferHandle* src, void* dst, - absl::Span wait_for) = 0; - - virtual std::shared_ptr TransferFromDeviceToDevice( - const BufferHandle* src, BufferHandle* dst, - absl::Span wait_for) = 0; - - virtual std::unique_ptr CompileProgram( - const xla::HloProto& source, int32_t num_replicas, - absl::Span wait_for, - const xla::DebugOptions& debug_options) = 0; - virtual std::unique_ptr LoadProgram( - int32_t core_id, const CompiledProgramHandle* handle, - absl::Span wait_for) = 0; - virtual std::shared_ptr UnloadProgram( - std::unique_ptr handle, - absl::Span wait_for) = 0; - virtual std::shared_ptr ExecuteProgram( - LoadedProgramHandle* program, absl::Span inputs, - absl::Span outputs, - const xla::DeviceAssignmentProto& device_assignment, - absl::Span wait_for) = 0; - - virtual std::unique_ptr GetLinearizer() { return nullptr; } -}; - -class TpuDriverRegistry { - public: - static xla::StatusOr> Open( - const TpuDriverConfig& config); - static int RegisterDriver( - const std::string& prefix, - const std::function>( - const TpuDriverConfig&)>& creator); -}; - -#define REGISTER_TPU_DRIVER(prefix, fn) \ - REGISTER_TPU_DRIVER_HELPER(__COUNTER__, prefix, fn) -#define REGISTER_TPU_DRIVER_HELPER(ctr, prefix, fn) \ - static int register_tpu_driver_count_unused_##ctr = \ - ::tpu_driver::TpuDriverRegistry::RegisterDriver(prefix, fn); - -} // namespace tpu_driver - -#endif // XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_ diff --git a/xla/python/tpu_driver/tpu_driver.proto b/xla/python/tpu_driver/tpu_driver.proto deleted file mode 100644 index f9f2494eaf1c3..0000000000000 --- a/xla/python/tpu_driver/tpu_driver.proto +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -syntax = "proto2"; - -package tpu_driver; - -enum MemoryRegion { HBM = 1; } - -message ChipCoordinate { - required int32 x = 1; - required int32 y = 2; - required int32 z = 3; -} - -message TpuCoreInfo { - required int32 id = 1; - optional int32 core_on_chip_index = 2; - optional int32 core_on_host_index = 3; - optional int64 hbm_bytes_available = 100; - optional int64 hbm_bytes_allocatable = 101; -} - -message TpuChipInfo { - repeated TpuCoreInfo core = 1; - optional int32 host_id = 2; - optional ChipCoordinate chip_coord = 3; -} - -message CpuInfo { - required int32 num_cpu_cores = 1; - required float cpu_load_average_1min = 2; - required int64 ram_bytes_total = 100; - required int64 ram_bytes_available = 101; -} - -message SystemInfo { - repeated TpuChipInfo tpu_chip = 1; - required CpuInfo cpu = 2; - repeated TpuCoreInfo local_core = 3; - optional int32 host_id = 4; - optional int32 host_count = 5; - optional int32 chip_count = 6; - optional int32 core_count = 7; -} - -message TpuDriverConfig { - optional string worker = 1; - - message GrpcConfig { - // Time in seconds before the initial connection to the server will timeout. - optional int64 connection_timeout_secs = 1 [default = 30]; - - // Time in seconds the server may be unresponsive before terminating the - // connection. - optional int64 keepalive_timeout_secs = 2 [default = 30]; - } - - optional GrpcConfig grpc = 2; -} diff --git a/xla/python/tpu_driver/tpu_service.proto b/xla/python/tpu_driver/tpu_service.proto deleted file mode 100644 index 0ef37ce755700..0000000000000 --- a/xla/python/tpu_driver/tpu_service.proto +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -syntax = "proto2"; - -package tpu_driver; - -import "xla/python/tpu_driver/tpu_driver.proto"; -import "xla/service/hlo.proto"; -import "xla/xla.proto"; -import "xla/xla_data.proto"; - -option optimize_for = SPEED; - -message StatusMessage { - required int32 code = 1; - optional string message = 2; -} - -message AllocateRequest { - required int32 core_id = 1; - required MemoryRegion region = 2; - oneof size { - int64 num_bytes = 3; - xla.ShapeProto shape = 4; - } -} - -message AllocateTupleRequest { - required int32 core_id = 1; - required MemoryRegion region = 2; - repeated int64 children = 3; -} - -message DeallocateRequest { - required int64 handle = 1; -} - -message TransferToDeviceRequest { - required int64 target_handle = 1; - required bytes data = 2; -} - -message TransferFromDeviceRequest { - required int64 source_handle = 1; -} - -message TransferFromDeviceResponse { - required bytes data = 2; -} - -message TransferFromDeviceToDeviceRequest { - required int64 source_handle = 1; - required int64 target_handle = 2; -} - -message CompileRequest { - required xla.HloProto hlo_program = 1; - optional int64 num_replicas = 2; - optional xla.DebugOptions debug_options = 3; -} - -message CompiledProgramMetadata { - required xla.ProgramShapeProto program_shape = 1; -} - -message CompileResponse { - required CompiledProgramMetadata metadata = 1; -} - -message LoadProgramRequest { - required int32 core_id = 1; - required int64 compiled_program_handle = 2; -} - -message UnloadProgramRequest { - required int64 loaded_program_handle = 1; -} - -message ExecuteRequest { - required int64 loaded_program_handle = 1; - repeated int64 input_handle = 2; - repeated int64 output_handle = 3; - optional xla.DeviceAssignmentProto device_assignment = 4; -} - -message StreamRequest { - message Entry { - oneof request { - AllocateRequest alloc = 1; - AllocateTupleRequest alloc_tuple = 2; - DeallocateRequest dealloc = 3; - TransferToDeviceRequest transfer_to = 4; - TransferFromDeviceRequest transfer_from = 5; - TransferFromDeviceToDeviceRequest transfer_from_to = 10; - CompileRequest compile = 6; - LoadProgramRequest load = 7; - UnloadProgramRequest unload = 8; - ExecuteRequest execute = 9; - } - // If specified, a list of encoded EventId values. - repeated int64 wait_for_id = 20; - // A unique, encoded EventId value. - // For Allocate, Compile, and Load, this also defines the result handle. - required int64 operation_id = 21; - - // A unique identifier for the thread that issued this request. Currently - // for debugging purposes only. - optional int64 thread_id = 22; - } - - repeated Entry entry = 30; -} - -message StreamResponse { - message Entry { - oneof response { - TransferFromDeviceResponse transfer_from = 3; - CompileResponse compile = 4; - } - required StatusMessage status = 10; - // Echos the given encoded EventId value. - required int64 operation_id = 11; - } - - repeated Entry entry = 20; -} - -message OpenRequest { - // The version number for this client. Versions are bumped in case of - // backwards incompatible client-server protocol changes. Servers will reject - // clients with an unsupported version. - optional int32 client_version = 1 [default = 0]; -} - -message OpenResponse { - required uint32 client_id = 1; - - // Maximum time this client can be idle before it is GC'ed and all resources - // released. - optional int32 max_idle_time_seconds = 2 [default = 3600]; -} - -message CloseRequest { - required fixed32 client_id = 1; -} - -message CloseResponse {} - -message ResetRequest {} - -message ResetResponse {} - -message QuerySystemInfoRequest {} - -message QuerySystemInfoResponse { - required SystemInfo system_info = 1; -} - -service CloudTpuDriver { - // Open the driver. If the driver is already open, return an error. - rpc Open(OpenRequest) returns (OpenResponse); - - // Close the driver. Any outstanding requests will be terminated. - rpc Close(CloseRequest) returns (CloseResponse); - - // Reset the driver. All connected clients will be disconnected. - rpc Reset(ResetRequest) returns (ResetResponse); - - // Query the driver for current system performance information. - rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse); - - // Enqueue an operation to be executed when its dependencies are satisfied. - rpc StreamExecute(stream StreamRequest) returns (stream StreamResponse); -} diff --git a/xla/python/traceback.cc b/xla/python/traceback.cc index 88d6965b56e23..6f50d6715b5d2 100644 --- a/xla/python/traceback.cc +++ b/xla/python/traceback.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,34 @@ limitations under the License. #include "xla/python/traceback.h" -#include -#include +#include #include +#include #include #include #include "absl/hash/hash.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "pybind11/pytypes.h" // from @pybind11 -#include "xla/python/exceptions.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_class_ptr.h" #include "xla/python/python_ref_manager.h" -#include "tsl/platform/logging.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE namespace xla { -namespace py = pybind11; +namespace nb = nanobind; bool Traceback::enabled_ = true; @@ -54,7 +65,22 @@ Traceback::Traceback() { Py_INCREF(py_frame->f_code); frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); } -#else // PY_VERSION_HEX < 0x030b0000 +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames_.emplace_back(f->f_code, + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + } +#else // PLATFORM_GOOGLE PyFrameObject* next; for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); py_frame != nullptr; py_frame = next) { @@ -62,6 +88,8 @@ Traceback::Traceback() { next = PyFrame_GetBack(py_frame); Py_XDECREF(py_frame); } +#endif // PLATFORM_GOOGLE + #endif // PY_VERSION_HEX < 0x030b0000 } @@ -80,7 +108,8 @@ Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) { } std::string Traceback::Frame::ToString() const { - return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name); + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); } std::string Traceback::ToString() const { @@ -98,36 +127,28 @@ std::vector Traceback::Frames() const { std::vector frames; frames.reserve(frames_.size()); for (const auto& frame : frames_) { - frames.push_back(Frame{ - std::string(py::reinterpret_borrow(frame.first->co_filename)), - std::string(py::reinterpret_borrow(frame.first->co_name)), - frame.first->co_firstlineno, - PyCode_Addr2Line(frame.first, frame.second)}); + frames.push_back(Frame{nb::borrow(frame.first->co_filename), + nb::borrow(frame.first->co_name), + frame.first->co_firstlineno, + PyCode_Addr2Line(frame.first, frame.second)}); } return frames; } -std::shared_ptr Traceback::Get() { +std::optional> Traceback::Get() { DCHECK(PyGILState_Check()); if (!enabled_) { - return nullptr; + return std::nullopt; } - return std::make_shared(); -} - -void Traceback::SafeDestroy(Traceback traceback) { - // We want Traceback objects to be safe to destroy without holding the - // GIL, so we defer destruction of the strings. - GlobalPyRefManager()->AddGarbage(traceback.frames_); - traceback.frames_.clear(); + return make_nb_class(); } void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; } -py::object Traceback::AsPythonTraceback() const { - py::object traceback = py::none(); - py::dict globals; - py::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); +nb::object Traceback::AsPythonTraceback() const { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); for (const std::pair& frame : frames_) { int lineno = PyCode_Addr2Line(frame.first, frame.second); // Under Python 3.11 we observed crashes when using a fake PyFrameObject @@ -148,8 +169,7 @@ py::object Traceback::AsPythonTraceback() const { traceback = traceback_type( /*tb_next=*/std::move(traceback), /*tb_frame=*/ - py::reinterpret_steal( - reinterpret_cast(py_frame)), + nb::steal(reinterpret_cast(py_frame)), /*tb_lasti=*/0, /*tb_lineno=*/ PyCode_Addr2Line(frame.first, frame.second)); @@ -157,23 +177,23 @@ py::object Traceback::AsPythonTraceback() const { return traceback; } -void BuildTracebackSubmodule(py::module& m) { - py::class_(m, "Frame") - .def_readonly("file_name", &Traceback::Frame::file_name) - .def_readonly("function_name", &Traceback::Frame::function_name) - .def_readonly("function_start_line", - &Traceback::Frame::function_start_line) - .def_readonly("line_num", &Traceback::Frame::line_num) +void BuildTracebackSubmodule(nb::module_& m) { + nb::class_(m, "Frame") + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) .def("__repr__", [](const Traceback::Frame& frame) { - return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name, - frame.line_num); + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); }); - py::class_> traceback( - m, "Traceback", "Represents a Python stack trace."); - traceback.def_property_static( - "enabled", [](py::object /* cls */) { return Traceback::enabled(); }, - [](py::object /* cls */, bool enabled) { + nb::class_ traceback(m, "Traceback", + "Represents a Python stack trace."); + traceback.def_prop_rw_static( + "enabled", [](nb::object /* cls */) { return Traceback::enabled(); }, + [](nb::object /* cls */, bool enabled) { return Traceback::SetEnabled(enabled); }); traceback.def_static( @@ -186,20 +206,23 @@ void BuildTracebackSubmodule(py::module& m) { collection has a small overhead, so it is disabled by default. If traceback collection is disabled, returns ``None``. )doc"); - traceback.def_property_readonly("frames", &Traceback::Frames); - traceback.def("raw_frames", [](const Traceback& tb) -> py::tuple { + traceback.def_prop_ro("frames", &Traceback::Frames); + traceback.def("raw_frames", [](const Traceback& tb) -> nb::tuple { // We return a tuple of lists, rather than a list of tuples, because it // is cheaper to allocate only three Python objects for everything rather // than one per frame. - py::list out_code(tb.raw_frames().size()); - py::list out_lasti(tb.raw_frames().size()); + nb::list out_code = nb::steal(PyList_New(tb.raw_frames().size())); + nb::list out_lasti = + nb::steal(PyList_New(tb.raw_frames().size())); for (size_t i = 0; i < tb.raw_frames().size(); ++i) { const auto& frame = tb.raw_frames()[i]; - out_code[i] = py::reinterpret_borrow( - reinterpret_cast(frame.first)); - out_lasti[i] = py::int_(frame.second); + PyObject* code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); } - return py::make_tuple(out_code, out_lasti); + return nb::make_tuple(out_code, out_lasti); }); traceback.def("__str__", &Traceback::ToString); traceback.def("__eq__", @@ -210,7 +233,7 @@ void BuildTracebackSubmodule(py::module& m) { traceback.def_static( "code_addr2line", - [](py::handle code, int lasti) { + [](nb::handle code, int lasti) { if (!PyCode_Check(code.ptr())) { throw xla::XlaRuntimeError("code argument must be a code object"); } @@ -222,7 +245,7 @@ void BuildTracebackSubmodule(py::module& m) { #if PY_VERSION_HEX >= 0x030b0000 traceback.def_static( "code_addr2location", - [](py::handle code, int lasti) { + [](nb::handle code, int lasti) { if (!PyCode_Check(code.ptr())) { throw xla::XlaRuntimeError("code argument must be a code object"); } @@ -230,9 +253,9 @@ void BuildTracebackSubmodule(py::module& m) { if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), lasti, &start_line, &start_column, &end_line, &end_column)) { - throw py::error_already_set(); + throw nb::python_error(); } - return py::make_tuple(start_line, start_column, end_line, end_column); + return nb::make_tuple(start_line, start_column, end_line, end_column); }, "Python wrapper around the Python C API function PyCode_Addr2Location"); #endif // PY_VERSION_HEX >= 0x030b0000 @@ -242,7 +265,7 @@ void BuildTracebackSubmodule(py::module& m) { // Python thread. m.def( "replace_thread_exc_traceback", - [](py::object tb) { + [](nb::object tb) { if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) { throw xla::XlaRuntimeError( "argument must be a traceback object or None"); @@ -258,7 +281,7 @@ void BuildTracebackSubmodule(py::module& m) { thread_state->exc_info->exc_traceback = new_tb; Py_XDECREF(old_exc_traceback); }, - py::arg("traceback")); + nb::arg("traceback").none()); #endif // PY_VERSION_HEX < 0x30b0000 } } // namespace xla diff --git a/xla/python/traceback.h b/xla/python/traceback.h index 6551a7628c21a..14d903905ef9a 100644 --- a/xla/python/traceback.h +++ b/xla/python/traceback.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,36 +17,34 @@ limitations under the License. #define XLA_PYTHON_TRACEBACK_H_ #include +#include #include #include #include // placeholder for index annotation headers #include "absl/container/inlined_vector.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/python/nb_class_ptr.h" namespace xla { -// Represents a Python traceback. +// Represents a Python traceback. This object is designed to be allocated on +// the Python heap; creating or destroying a traceback requires the GIL. class Traceback { public: - // Require GIL. Creates a Traceback object that requires destructor to be + // Requires GIL. Creates a Traceback object that requires destructor to be // invoked with GIL held as well. - static std::shared_ptr Get(); + static std::optional> Get(); - // Safely destroy the traceback object regardless of whether GIL is held or - // not. - static void SafeDestroy(Traceback traceback); - - // Require GIL. + // Requires GIL. static bool enabled() { return enabled_; } - // Require GIL. + // Requires GIL. static void SetEnabled(bool enabled); - // Require GIL. + // Requires GIL. Don't call this directly, you're looking for Get(). Traceback(); - // Require GIL. + // Requires GIL. ~Traceback(); Traceback(const Traceback&) = delete; @@ -58,8 +56,8 @@ class Traceback { std::string ToString() const; struct Frame { - pybind11::str file_name; - pybind11::str function_name; + nanobind::str file_name; + nanobind::str function_name; int function_start_line; int line_num; @@ -74,7 +72,7 @@ class Traceback { // Returns the traceback as a fake Python Traceback object, suitable for // using as an exception traceback. - pybind11::object AsPythonTraceback() const; + nanobind::object AsPythonTraceback() const; bool operator==(const Traceback& other) const { return frames_ == other.frames_; @@ -95,13 +93,15 @@ class Traceback { static bool enabled_; }; +using nb_traceback = nb_class_ptr; + template H AbslHashValue(H h, const Traceback& traceback) { h = H::combine(std::move(h), traceback.raw_frames()); return h; } -void BuildTracebackSubmodule(pybind11::module& m); +void BuildTracebackSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/transfer_guard_lib.cc b/xla/python/transfer_guard_lib.cc index 94a720cab3631..3f4809d4e28de 100644 --- a/xla/python/transfer_guard_lib.cc +++ b/xla/python/transfer_guard_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,21 +18,18 @@ limitations under the License. #include "xla/python/transfer_guard_lib.h" -#include #include #include -#include "absl/base/attributes.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil -#include "xla/python/status_casters.h" -#include "xla/status.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep #include "xla/util.h" namespace jax { -namespace py = ::pybind11; +namespace nb = ::nanobind; namespace { @@ -102,7 +99,7 @@ TransferGuardAction GetTransferGuardActionForDeviceToHost() { } // namespace -xla::Status ApplyTransferGuardToHostToDevice( +absl::Status ApplyTransferGuardToHostToDevice( absl::FunctionRef formatter) { switch (GetTransferGuardActionForHostToDevice()) { case TransferGuardAction::kAllow: @@ -114,10 +111,10 @@ xla::Status ApplyTransferGuardToHostToDevice( return xla::InvalidArgument("Disallowed host-to-device transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -xla::Status ApplyTransferGuardToDeviceToDevice( +absl::Status ApplyTransferGuardToDeviceToDevice( absl::FunctionRef formatter) { switch (GetTransferGuardActionForDeviceToDevice()) { case TransferGuardAction::kAllow: @@ -129,10 +126,10 @@ xla::Status ApplyTransferGuardToDeviceToDevice( return xla::InvalidArgument("Disallowed device-to-device transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -xla::Status ApplyTransferGuardToDeviceToHost( +absl::Status ApplyTransferGuardToDeviceToHost( absl::FunctionRef formatter) { switch (GetTransferGuardActionForDeviceToHost()) { case TransferGuardAction::kAllow: @@ -144,36 +141,38 @@ xla::Status ApplyTransferGuardToDeviceToHost( return xla::InvalidArgument("Disallowed device-to-host transfer: %s", formatter()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -void BuildTransferGuardSubmodule(py::module& m) { - py::module tglib = m.def_submodule("transfer_guard_lib", - "Jax transfer guard support library"); +void BuildTransferGuardSubmodule(nb::module_& m) { + nb::module_ tglib = m.def_submodule("transfer_guard_lib", + "Jax transfer guard support library"); - py::enum_ tglevel(tglib, "TransferGuardLevel"); + nb::enum_ tglevel(tglib, "TransferGuardLevel"); tglevel.value("ALLOW", TransferGuardLevel::kAllow); tglevel.value("LOG", TransferGuardLevel::kLog); tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); - py::class_ tgstate(tglib, "TransferGuardState"); - tgstate.def_readwrite("host_to_device", &TransferGuardState::host_to_device); - tgstate.def_readwrite("device_to_device", - &TransferGuardState::device_to_device); - tgstate.def_readwrite("device_to_host", &TransferGuardState::device_to_host); - tgstate.def_readwrite("explicit_device_put", - &TransferGuardState::explicit_device_put); - tgstate.def_readwrite("explicit_device_get", - &TransferGuardState::explicit_device_get); + nb::class_ tgstate(tglib, "TransferGuardState"); + tgstate.def_rw("host_to_device", &TransferGuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &TransferGuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", &TransferGuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", + &TransferGuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", + &TransferGuardState::explicit_device_get); tglib.def( "global_state", [&]() { return &global_state; }, - py::return_value_policy::reference); + nb::rv_policy::reference); tglib.def( "thread_local_state", [&]() { return &thread_local_state; }, - py::return_value_policy::reference); + nb::rv_policy::reference); } } // namespace jax diff --git a/xla/python/transfer_guard_lib.h b/xla/python/transfer_guard_lib.h index 1381bfd642507..12a7216939991 100644 --- a/xla/python/transfer_guard_lib.h +++ b/xla/python/transfer_guard_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,9 @@ limitations under the License. #include // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/status.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" // from @nanobind namespace jax { @@ -75,23 +76,23 @@ enum class TransferGuardAction { // Guards a host-to-device transfer. formatter is called to describe the // transfer in a log message or error status. // REQUIRES: Python GIL. -xla::Status ApplyTransferGuardToHostToDevice( +absl::Status ApplyTransferGuardToHostToDevice( absl::FunctionRef formatter); // Guards a device-to-device transfer. formatter is called to describe the // transfer in a log message or error status. // REQUIRES: Python GIL. -xla::Status ApplyTransferGuardToDeviceToDevice( +absl::Status ApplyTransferGuardToDeviceToDevice( absl::FunctionRef formatter); // Guards a device-to-host transfer. formatter is called to describe the // transfer in a log message or error status. // REQUIRES: Python GIL. -xla::Status ApplyTransferGuardToDeviceToHost( +absl::Status ApplyTransferGuardToDeviceToHost( absl::FunctionRef formatter); // The function to call in `xla.cc` to add the bindings for this module. -void BuildTransferGuardSubmodule(pybind11::module& m); +void BuildTransferGuardSubmodule(nanobind::module_& m); } // namespace jax diff --git a/xla/python/types.cc b/xla/python/types.cc index 288fcf92543fb..ed84d68622c25 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,55 +15,75 @@ limitations under the License. #include "xla/python/types.h" -#include -#include -#include +#include + +#include #include #include -#include +#include #include #include #include #include "absl/container/flat_hash_map.h" -#include "xla/python/exceptions.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/ndarray.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { -namespace py = pybind11; +namespace nb = nanobind; namespace { struct CustomDtypes { - py::dtype bfloat16; - py::dtype float8_e4m3fn; - py::dtype float8_e4m3b11fnuz; - py::dtype float8_e4m3fnuz; - py::dtype float8_e5m2; - py::dtype float8_e5m2fnuz; - py::dtype int4; - py::dtype uint4; + nb_dtype bfloat16; + nb_dtype float8_e4m3fn; + nb_dtype float8_e4m3b11fnuz; + nb_dtype float8_e4m3fnuz; + nb_dtype float8_e5m2; + nb_dtype float8_e5m2fnuz; + nb_dtype int4; + nb_dtype uint4; }; const CustomDtypes& GetCustomDtypes() { static const CustomDtypes& custom_dtypes = *[]() { - py::module ml_dtypes = py::module::import("ml_dtypes"); + nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; - dtypes->bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")); + dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); dtypes->float8_e4m3fn = - py::dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); - dtypes->float8_e5m2 = py::dtype::from_args(ml_dtypes.attr("float8_e5m2")); + nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); + dtypes->float8_e5m2 = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2")); dtypes->float8_e4m3b11fnuz = - py::dtype::from_args(ml_dtypes.attr("float8_e4m3b11fnuz")); + nb_dtype::from_args(ml_dtypes.attr("float8_e4m3b11fnuz")); dtypes->float8_e4m3fnuz = - py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); + nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = - py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); - dtypes->int4 = py::dtype::from_args(ml_dtypes.attr("int4")); - dtypes->uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")); + nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); + dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); + dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); return dtypes; }(); return custom_dtypes; @@ -71,7 +91,7 @@ const CustomDtypes& GetCustomDtypes() { } // namespace -xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { +absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { static auto& builtin_dtypes = *new absl::flat_hash_map, PrimitiveType>({ {{'?', 'b', 1}, PRED}, @@ -100,17 +120,17 @@ xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { } struct DtypeEq { - bool operator()(const py::dtype& a, const py::dtype& b) const { + bool operator()(const nb_dtype& a, const nb_dtype& b) const { return a.equal(b); } }; struct DtypeHash { - ssize_t operator()(const py::dtype& key) const { return py::hash(key); } + ssize_t operator()(const nb_dtype& key) const { return nb_hash(key); } }; static auto* custom_dtype_map = []() { const CustomDtypes& custom_dtypes = GetCustomDtypes(); auto* map = - new absl::flat_hash_map(); + new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); map->emplace(custom_dtypes.float8_e4m3fn, F8E4M3FN); map->emplace(custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); @@ -127,35 +147,39 @@ xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { return custom_it->second; } return InvalidArgument("Unknown NumPy dtype %s char %c kind %c itemsize %d", - static_cast(py::repr(np_type)), + nb::cast(nb::repr(np_type)), np_type.char_(), np_type.kind(), np_type.itemsize()); } -xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { +absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { const CustomDtypes& custom_dtypes = GetCustomDtypes(); + auto to_nb_dtype = [](int typenum) -> nb_dtype { + return nb::steal( + reinterpret_cast(PyArray_DescrFromType(typenum))); + }; switch (type) { case PRED: - return py::dtype::of(); + return to_nb_dtype(NPY_BOOL); case S4: return custom_dtypes.int4; case S8: - return py::dtype::of(); + return to_nb_dtype(NPY_INT8); case S16: - return py::dtype::of(); + return to_nb_dtype(NPY_INT16); case S32: - return py::dtype::of(); + return to_nb_dtype(NPY_INT32); case S64: - return py::dtype::of(); + return to_nb_dtype(NPY_INT64); case U4: return custom_dtypes.uint4; case U8: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT8); case U16: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT16); case U32: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT32); case U64: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT64); case F8E4M3FN: return custom_dtypes.float8_e4m3fn; case F8E4M3B11FNUZ: @@ -169,58 +193,62 @@ xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { case BF16: return custom_dtypes.bfloat16; case F16: - return py::dtype("e"); // PEP 3118 code for "float16 + return to_nb_dtype(NPY_HALF); case F32: - return py::dtype::of(); + return to_nb_dtype(NPY_FLOAT); case F64: - return py::dtype::of(); + return to_nb_dtype(NPY_DOUBLE); case C64: - return py::dtype::of>(); + return to_nb_dtype(NPY_COMPLEX64); case C128: - return py::dtype::of>(); + return to_nb_dtype(NPY_COMPLEX128); default: return Unimplemented("Unimplemented primitive type %s", PrimitiveType_Name(type)); } } -StatusOr IfrtDtypeToDtype(ifrt::DType dtype) { +absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { const CustomDtypes& custom_dtypes = GetCustomDtypes(); + auto to_nb_dtype = [](int typenum) -> nb_dtype { + return nb::steal( + reinterpret_cast(PyArray_DescrFromType(typenum))); + }; switch (dtype.kind()) { case ifrt::DType::kPred: - return py::dtype::of(); + return to_nb_dtype(NPY_BOOL); case ifrt::DType::kS4: return custom_dtypes.int4; case ifrt::DType::kS8: - return py::dtype::of(); + return to_nb_dtype(NPY_INT8); case ifrt::DType::kS16: - return py::dtype::of(); + return to_nb_dtype(NPY_INT16); case ifrt::DType::kS32: - return py::dtype::of(); + return to_nb_dtype(NPY_INT32); case ifrt::DType::kS64: - return py::dtype::of(); + return to_nb_dtype(NPY_INT64); case ifrt::DType::kU4: return custom_dtypes.uint4; case ifrt::DType::kU8: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT8); case ifrt::DType::kU16: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT16); case ifrt::DType::kU32: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT32); case ifrt::DType::kU64: - return py::dtype::of(); + return to_nb_dtype(NPY_UINT64); case ifrt::DType::kF16: - return py::dtype("e"); // PEP 3118 code for "float16" + return to_nb_dtype(NPY_HALF); case ifrt::DType::kF32: - return py::dtype::of(); + return to_nb_dtype(NPY_FLOAT); case ifrt::DType::kF64: - return py::dtype::of(); + return to_nb_dtype(NPY_DOUBLE); case ifrt::DType::kBF16: return custom_dtypes.bfloat16; case ifrt::DType::kC64: - return py::dtype::of>(); + return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: - return py::dtype::of>(); + return to_nb_dtype(NPY_COMPLEX128); case ifrt::DType::kF8E4M3FN: return custom_dtypes.float8_e4m3fn; case ifrt::DType::kF8E4M3B11FNUZ: @@ -238,43 +266,48 @@ StatusOr IfrtDtypeToDtype(ifrt::DType dtype) { // part of dtype. Using 'O' allows us to represent variable-length bytes // and is also consistent with TensorFlow's tensor -> ndarray conversion // logic (see `TF_DataType_to_PyArray_TYPE`). - return py::dtype("O"); + return to_nb_dtype(NPY_OBJECT); default: return Unimplemented("Unimplemented primitive type %s", dtype.DebugString()); } } +absl::StatusOr DtypeToIfRtDType(nb_dtype dtype) { + TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); + return ifrt::ToDType(primitive_type); +} + const NumpyScalarTypes& GetNumpyScalarTypes() { static const NumpyScalarTypes* singleton = []() { NumpyScalarTypes* dtypes = new NumpyScalarTypes(); - py::module numpy = py::module::import("numpy"); - py::module ml_dtypes = py::module::import("ml_dtypes"); - dtypes->np_bool = py::object(numpy.attr("bool_")); - dtypes->np_int4 = py::object(ml_dtypes.attr("int4")); - dtypes->np_int8 = py::object(numpy.attr("int8")); - dtypes->np_int16 = py::object(numpy.attr("int16")); - dtypes->np_int32 = py::object(numpy.attr("int32")); - dtypes->np_int64 = py::object(numpy.attr("int64")); - dtypes->np_uint4 = py::object(ml_dtypes.attr("uint4")); - dtypes->np_uint8 = py::object(numpy.attr("uint8")); - dtypes->np_uint16 = py::object(numpy.attr("uint16")); - dtypes->np_uint32 = py::object(numpy.attr("uint32")); - dtypes->np_uint64 = py::object(numpy.attr("uint64")); - dtypes->np_bfloat16 = py::object(ml_dtypes.attr("bfloat16")); - dtypes->np_float8_e4m3fn = py::object(ml_dtypes.attr("float8_e4m3fn")); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); + dtypes->np_bool = nb::object(numpy.attr("bool_")); + dtypes->np_int4 = nb::object(ml_dtypes.attr("int4")); + dtypes->np_int8 = nb::object(numpy.attr("int8")); + dtypes->np_int16 = nb::object(numpy.attr("int16")); + dtypes->np_int32 = nb::object(numpy.attr("int32")); + dtypes->np_int64 = nb::object(numpy.attr("int64")); + dtypes->np_uint4 = nb::object(ml_dtypes.attr("uint4")); + dtypes->np_uint8 = nb::object(numpy.attr("uint8")); + dtypes->np_uint16 = nb::object(numpy.attr("uint16")); + dtypes->np_uint32 = nb::object(numpy.attr("uint32")); + dtypes->np_uint64 = nb::object(numpy.attr("uint64")); + dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + dtypes->np_float8_e4m3fn = nb::object(ml_dtypes.attr("float8_e4m3fn")); dtypes->np_float8_e4m3b11fnuz = - py::object(ml_dtypes.attr("float8_e4m3b11fnuz")); - dtypes->np_float8_e5m2 = py::object(ml_dtypes.attr("float8_e5m2")); - dtypes->np_float8_e4m3fnuz = py::object(ml_dtypes.attr("float8_e4m3fnuz")); - dtypes->np_float8_e5m2fnuz = py::object(ml_dtypes.attr("float8_e5m2fnuz")); - dtypes->np_float16 = py::object(numpy.attr("float16")); - dtypes->np_float32 = py::object(numpy.attr("float32")); - dtypes->np_float64 = py::object(numpy.attr("float64")); - dtypes->np_complex64 = py::object(numpy.attr("complex64")); - dtypes->np_complex128 = py::object(numpy.attr("complex128")); - dtypes->np_longlong = py::object(numpy.attr("longlong")); - dtypes->np_intc = py::object(numpy.attr("intc")); + nb::object(ml_dtypes.attr("float8_e4m3b11fnuz")); + dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); + dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); + dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); + dtypes->np_float16 = nb::object(numpy.attr("float16")); + dtypes->np_float32 = nb::object(numpy.attr("float32")); + dtypes->np_float64 = nb::object(numpy.attr("float64")); + dtypes->np_complex64 = nb::object(numpy.attr("complex64")); + dtypes->np_complex128 = nb::object(numpy.attr("complex128")); + dtypes->np_longlong = nb::object(numpy.attr("longlong")); + dtypes->np_intc = nb::object(numpy.attr("intc")); return dtypes; }(); return *singleton; @@ -318,7 +351,7 @@ const char* PEP3118FormatDescriptorForPrimitiveType(PrimitiveType type) { } } -StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type) { +absl::StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type) { #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #define ENDIAN_PREFIX "<" #else @@ -326,35 +359,35 @@ StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type) { #endif switch (type) { case PRED: - return py::str("|b1"); + return nb::str("|b1"); case S8: - return py::str("|i1"); + return nb::str("|i1"); case S16: - return py::str(ENDIAN_PREFIX "i2"); + return nb::str(ENDIAN_PREFIX "i2"); case S32: - return py::str(ENDIAN_PREFIX "i4"); + return nb::str(ENDIAN_PREFIX "i4"); case S64: - return py::str(ENDIAN_PREFIX "i8"); + return nb::str(ENDIAN_PREFIX "i8"); case U8: - return py::str("|u1"); + return nb::str("|u1"); case U16: - return py::str(ENDIAN_PREFIX "u2"); + return nb::str(ENDIAN_PREFIX "u2"); case U32: - return py::str(ENDIAN_PREFIX "u4"); + return nb::str(ENDIAN_PREFIX "u4"); case U64: - return py::str(ENDIAN_PREFIX "u8"); + return nb::str(ENDIAN_PREFIX "u8"); case BF16: - return py::str(ENDIAN_PREFIX "V2"); + return nb::str(ENDIAN_PREFIX "V2"); case F16: - return py::str(ENDIAN_PREFIX "f2"); + return nb::str(ENDIAN_PREFIX "f2"); case F32: - return py::str(ENDIAN_PREFIX "f4"); + return nb::str(ENDIAN_PREFIX "f4"); case F64: - return py::str(ENDIAN_PREFIX "f8"); + return nb::str(ENDIAN_PREFIX "f8"); case C64: - return py::str(ENDIAN_PREFIX "c8"); + return nb::str(ENDIAN_PREFIX "c8"); case C128: - return py::str(ENDIAN_PREFIX "c16"); + return nb::str(ENDIAN_PREFIX "c16"); default: return Unimplemented("Unimplemented primitive type %s", PrimitiveType_Name(type)); @@ -414,17 +447,18 @@ std::vector StridesForShape(PrimitiveType element_type, /*innermost_stride_size=*/1); } -StatusOr LiteralToPython(std::shared_ptr literal) { +absl::StatusOr LiteralToPython( + std::shared_ptr literal) { xla::Literal& m = *literal; if (m.shape().IsTuple()) { std::vector elems = m.DecomposeTuple(); - std::vector arrays(elems.size()); + std::vector arrays(elems.size()); for (int i = 0; i < elems.size(); ++i) { TF_ASSIGN_OR_RETURN( arrays[i], LiteralToPython(std::make_unique(std::move(elems[i])))); } - py::tuple result(elems.size()); + nb::tuple result = nb::steal(PyTuple_New(elems.size())); for (int i = 0; i < elems.size(); ++i) { PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr()); } @@ -432,68 +466,25 @@ StatusOr LiteralToPython(std::shared_ptr literal) { } TF_RET_CHECK(m.shape().IsArray()); - py::object literal_object = py::cast(literal); - TF_ASSIGN_OR_RETURN(py::dtype dtype, - PrimitiveTypeToDtype(m.shape().element_type())); - return py::array(dtype, m.shape().dimensions(), - ByteStridesForShape(m.shape()), m.untyped_data(), - literal_object); -} - -StatusOr GetPythonBufferTree(const py::object& argument) { - PythonBufferTree tree; - if (py::isinstance(argument)) { - py::tuple tuple = py::reinterpret_borrow(argument); - std::vector host_shapes(tuple.size()); - for (int i = 0; i < host_shapes.size(); ++i) { - TF_ASSIGN_OR_RETURN(PythonBufferTree subtree, - GetPythonBufferTree(tuple[i])); - tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size()); - std::move(subtree.leaves.begin(), subtree.leaves.end(), - std::back_inserter(tree.leaves)); - tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size()); - std::move(subtree.arrays.begin(), subtree.arrays.end(), - std::back_inserter(tree.arrays)); - host_shapes[i] = std::move(subtree.shape); - } - tree.shape = ShapeUtil::MakeTupleShape(host_shapes); - } else { - pybind11::detail::type_caster caster; - if (!caster.load(argument, /*convert=*/true)) { - return InvalidArgument("Invalid array value."); - } - DCHECK_EQ(caster.arrays.size(), 1); - tree.arrays.push_back(std::move(caster.arrays.front())); - tree.leaves.push_back(std::move(*caster)); - tree.shape = tree.leaves.front().shape(); - } - return tree; + nb::object literal_object = nb::cast(literal); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(m.shape().element_type())); + return nb_numpy_ndarray(dtype, m.shape().dimensions(), + ByteStridesForShape(m.shape()), m.untyped_data(), + literal_object); } -template -static py::tuple IntSpanToTupleHelper(absl::Span xs) { - py::tuple out(xs.size()); +nb::tuple MutableSpanToNbTuple(absl::Span xs) { + nb::tuple out = nb::steal(PyTuple_New(xs.size())); for (int i = 0; i < xs.size(); ++i) { - out[i] = py::int_(xs[i]); + PyTuple_SET_ITEM(out.ptr(), i, xs[i].release().ptr()); } return out; } -template <> -pybind11::tuple SpanToTuple(absl::Span xs) { - return IntSpanToTupleHelper(xs); -} -template <> -pybind11::tuple SpanToTuple(absl::Span xs) { - return IntSpanToTupleHelper(xs); -} - -std::optional CastToArray(py::handle h) { - py::array array = py::array::ensure( - h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_); - if (!array) { - return std::nullopt; - } +std::optional CastToArray(nb::handle h) { + auto array = + nb_numpy_ndarray::ensure(h, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED); auto type_or_status = DtypeToPrimitiveType(array.dtype()); if (!type_or_status.ok()) { throw xla::XlaRuntimeError(type_or_status.status()); diff --git a/xla/python/types.h b/xla/python/types.h index 1b20f84222c40..f155ec226d8f7 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,70 +16,72 @@ limitations under the License. #ifndef XLA_PYTHON_TYPES_H_ #define XLA_PYTHON_TYPES_H_ -#include +#include + +#include #include #include -#include #include #include "absl/container/inlined_vector.h" -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl.h" // from @pybind11 -#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil +#include "absl/types/span.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "xla/layout.h" #include "xla/literal.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/nb_numpy.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status.h" #include "xla/statusor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" namespace xla { // Converts a NumPy dtype to a PrimitiveType. -StatusOr DtypeToPrimitiveType(const pybind11::dtype& np_type); +absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type); // Converts a PrimitiveType to a Numpy dtype. -StatusOr PrimitiveTypeToDtype(PrimitiveType type); +absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type); // Converts an IFRT dtype to a NumPy dtype. -StatusOr IfrtDtypeToDtype(ifrt::DType dtype); +absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype); + +absl::StatusOr DtypeToIfRtDType(nb_dtype dtype); // Returns a Python buffer protocol (PEP 3118) format descriptor string for // `type`. Return nullptr if there is no suitable choice of format string. const char* PEP3118FormatDescriptorForPrimitiveType(PrimitiveType type); // Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str -StatusOr TypeDescriptorForPrimitiveType(PrimitiveType type); +absl::StatusOr TypeDescriptorForPrimitiveType( + PrimitiveType type); struct NumpyScalarTypes { - pybind11::object np_bool; - pybind11::object np_int4; - pybind11::object np_int8; - pybind11::object np_int16; - pybind11::object np_int32; - pybind11::object np_int64; - pybind11::object np_uint4; - pybind11::object np_uint8; - pybind11::object np_uint16; - pybind11::object np_uint32; - pybind11::object np_uint64; - pybind11::object np_bfloat16; - pybind11::object np_float8_e4m3fn; - pybind11::object np_float8_e4m3b11fnuz; - pybind11::object np_float8_e4m3fnuz; - pybind11::object np_float8_e5m2; - pybind11::object np_float8_e5m2fnuz; - pybind11::object np_float16; - pybind11::object np_float32; - pybind11::object np_float64; - pybind11::object np_complex64; - pybind11::object np_complex128; - pybind11::object np_longlong; - pybind11::object np_intc; + nanobind::object np_bool; + nanobind::object np_int4; + nanobind::object np_int8; + nanobind::object np_int16; + nanobind::object np_int32; + nanobind::object np_int64; + nanobind::object np_uint4; + nanobind::object np_uint8; + nanobind::object np_uint16; + nanobind::object np_uint32; + nanobind::object np_uint64; + nanobind::object np_bfloat16; + nanobind::object np_float8_e4m3fn; + nanobind::object np_float8_e4m3b11fnuz; + nanobind::object np_float8_e4m3fnuz; + nanobind::object np_float8_e5m2; + nanobind::object np_float8_e5m2fnuz; + nanobind::object np_float16; + nanobind::object np_float32; + nanobind::object np_float64; + nanobind::object np_complex64; + nanobind::object np_complex128; + nanobind::object np_longlong; + nanobind::object np_intc; }; const NumpyScalarTypes& GetNumpyScalarTypes(); @@ -100,53 +102,38 @@ std::vector StridesForShape(PrimitiveType element_type, // buffers with the literals. Takes ownership of `literal` and keeps the // necessary pieces alive using Python reference counting. // Requires the GIL. -StatusOr LiteralToPython(std::shared_ptr literal); - -// Converts a Python object into an XLA shape and a vector of leaf buffers. -// The leaf buffers correspond to a depth-first, left-to-right traversal of -// the Python value. -// Requires the GIL. -struct PythonBufferTree { - // Holds a reference to the arrays pointed to by `leaves`, since we may - // need to make a copy if the array is not in a C-style layout. - absl::InlinedVector arrays; - absl::InlinedVector leaves; - Shape shape; -}; -StatusOr GetPythonBufferTree( - const pybind11::object& argument); +absl::StatusOr LiteralToPython( + std::shared_ptr literal); -// Converts a sequence of C++ ints to a Python tuple of ints. -// Pybind11 by default converts a std::vector to a Python list; -// we frequently want a tuple instead e.g. for shapes. template -pybind11::tuple SpanToTuple(absl::Span xs) { - pybind11::tuple out(xs.size()); +nanobind::tuple SpanToNbTuple(absl::Span xs) { + nanobind::tuple out = + nanobind::steal(PyTuple_New(xs.size())); for (int i = 0; i < xs.size(); ++i) { - out[i] = pybind11::cast(xs[i]); + PyTuple_SET_ITEM(out.ptr(), i, nanobind::cast(xs[i]).release().ptr()); } return out; } -template <> -pybind11::tuple SpanToTuple(absl::Span xs); -template <> -pybind11::tuple SpanToTuple(absl::Span xs); -// Converts a Python iterable/sequence of T to std::vector +// Converts a sequence of Python objects to a Python tuple, stealing the +// references to the objects. +nanobind::tuple MutableSpanToNbTuple(absl::Span xs); + + template -std::vector IterableToVector(const pybind11::iterable& iterable) { +std::vector IterableToVector(const nanobind::iterable& iterable) { std::vector output; for (auto item : iterable) { - output.push_back(item.cast()); + output.push_back(nanobind::cast(item)); } return output; } template -std::vector SequenceToVector(const pybind11::sequence& sequence) { +std::vector SequenceToVector(const nanobind::sequence& sequence) { std::vector output; - output.reserve(sequence.size()); + output.reserve(PySequence_Size(sequence.ptr())); for (auto item : sequence) { - output.push_back(item.cast()); + output.push_back(nanobind::cast(item)); } return output; } @@ -155,19 +142,15 @@ std::vector SequenceToVector(const pybind11::sequence& sequence) { // xla::BorrowingLiteral. Converts a Python array-like object into a buffer // pointer and shape. struct CastToArrayResult { - pybind11::object array; // Holds a reference to the array to keep it alive. + nanobind::object array; // Holds a reference to the array to keep it alive. const char* buf_ptr; xla::Shape shape; }; -std::optional CastToArray(pybind11::handle h); +std::optional CastToArray(nanobind::handle h); } // namespace xla -// This namespace is a documented pybind11 extension point. -// Caution: Unusually for Google code, this code uses C++ exceptions because -// they are the only mechanism for reporting cast failures to pybind11. However, -// the exceptions are local to the binding code. -namespace pybind11 { +namespace nanobind { namespace detail { // Literals. @@ -179,23 +162,29 @@ namespace detail { template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral")); + using Value = xla::BorrowingLiteral; + static constexpr auto Name = const_name("xla::BorrowingLiteral"); // NOLINT + template + using Cast = movable_cast_t; + explicit operator Value*() { return &value; } + explicit operator Value&() { return (Value&)value; } + explicit operator Value&&() { return (Value&&)value; } + Value value; // Pybind appears to keep type_casters alive until the callee has run. - absl::InlinedVector arrays; + absl::InlinedVector arrays; - bool load(handle input, bool) { + bool from_python(handle input, uint8_t, cleanup_list*) { // TODO(b/79707221): support nested tuples if/when XLA adds support for // nested BorrowingLiterals. - if (pybind11::isinstance(input)) { - pybind11::tuple tuple = - pybind11::reinterpret_borrow(input); + if (nanobind::isinstance(input)) { + nanobind::tuple tuple = nanobind::borrow(input); std::vector shapes; std::vector buffers; arrays.reserve(tuple.size()); shapes.reserve(tuple.size()); buffers.reserve(tuple.size()); - for (pybind11::handle entry : tuple) { + for (nanobind::handle entry : tuple) { auto c = xla::CastToArray(entry); if (!c) { return false; @@ -221,241 +210,28 @@ struct type_caster { template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice")); + NB_TYPE_CASTER(xla::LiteralSlice, const_name("xla::LiteralSlice")); // Pybind appears to keep type_casters alive until the callee has run. type_caster literal_caster; - bool load(handle handle, bool convert) { - if (!literal_caster.load(handle, convert)) { + bool from_python(handle handle, uint8_t flags, cleanup_list* cleanup) { + if (!literal_caster.from_python(handle, flags, cleanup)) { return false; } value = static_cast(literal_caster); return true; } -}; - -// XLA protocol buffers -// We don't actually care that these are the protocol buffers, we merely want -// objects that duck type as protocol buffers. The client code currently avoids -// depending on Python protocol buffers to avoid conflicting definitions from -// different modules that both include XLA. - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers, - _("xla::ConvolutionDimensionNumbers")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - value.set_input_batch_dimension( - getattr(handle, "input_batch_dimension").cast()); - value.set_input_feature_dimension( - getattr(handle, "input_feature_dimension").cast()); - value.set_output_batch_dimension( - getattr(handle, "output_batch_dimension").cast()); - value.set_output_feature_dimension( - getattr(handle, "output_feature_dimension").cast()); - value.set_kernel_input_feature_dimension( - getattr(handle, "kernel_input_feature_dimension").cast()); - value.set_kernel_output_feature_dimension( - getattr(handle, "kernel_output_feature_dimension").cast()); - std::vector dims; - dims = getattr(handle, "input_spatial_dimensions") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_input_spatial_dimensions())); - dims = getattr(handle, "kernel_spatial_dimensions") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_kernel_spatial_dimensions())); - dims = getattr(handle, "output_spatial_dimensions") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_output_spatial_dimensions())); - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - std::vector dims; - dims = getattr(handle, "lhs_contracting_dimensions") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_lhs_contracting_dimensions())); - dims = getattr(handle, "rhs_contracting_dimensions") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_rhs_contracting_dimensions())); - dims = getattr(handle, "lhs_batch_dimensions").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_lhs_batch_dimensions())); - dims = getattr(handle, "rhs_batch_dimensions").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_rhs_batch_dimensions())); - return true; - } -}; -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers, - _("xla::GatherDimensionNumbers")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - std::vector dims; - dims = getattr(handle, "offset_dims").cast>(); - std::copy( - dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter(value.mutable_offset_dims())); - dims = getattr(handle, "collapsed_slice_dims").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_collapsed_slice_dims())); - dims = getattr(handle, "start_index_map").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_start_index_map())); - value.set_index_vector_dim( - getattr(handle, "index_vector_dim").cast()); - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers, - _("xla::ScatterDimensionNumbers")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - std::vector dims; - dims = getattr(handle, "update_window_dims").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_update_window_dims())); - dims = getattr(handle, "inserted_window_dims").cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_inserted_window_dims())); - dims = getattr(handle, "scatter_dims_to_operand_dims") - .cast>(); - std::copy(dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter( - value.mutable_scatter_dims_to_operand_dims())); - value.set_index_vector_dim( - getattr(handle, "index_vector_dim").cast()); - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - std::vector dims; - dims = getattr(handle, "replica_ids").cast>(); - std::copy( - dims.begin(), dims.end(), - tsl::protobuf::RepeatedFieldBackInserter(value.mutable_replica_ids())); - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - sequence dimensions = - reinterpret_borrow(getattr(handle, "dimensions")); - - for (const auto& dimension : dimensions) { - xla::PaddingConfig::PaddingConfigDimension* config_dim = - value.add_dimensions(); - config_dim->set_edge_padding_low( - getattr(dimension, "edge_padding_low").cast()); - config_dim->set_edge_padding_high( - getattr(dimension, "edge_padding_high").cast()); - config_dim->set_interior_padding( - getattr(dimension, "interior_padding").cast()); - } - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - pybind11::handle op_type = getattr(handle, "op_type"); - if (!op_type.is_none()) { - value.set_op_type(op_type.cast()); - } - pybind11::handle op_name = getattr(handle, "op_name"); - if (!op_name.is_none()) { - value.set_op_name(op_name.cast()); - } - pybind11::handle source_file = getattr(handle, "source_file"); - if (!source_file.is_none()) { - value.set_source_file(source_file.cast()); - } - pybind11::handle source_line = getattr(handle, "source_line"); - if (!source_line.is_none()) { - value.set_source_line(source_line.cast()); - } - return true; - } -}; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig")); - - // PyObject -> C++ conversion. - bool load(handle handle, bool) { - if (handle.is_none()) { - return true; - } - - sequence operand_precisions = - reinterpret_borrow(getattr(handle, "operand_precision")); - - for (const auto& operand_precision : operand_precisions) { - value.add_operand_precision( - operand_precision.cast()); - } - return true; + static handle from_cpp(xla::LiteralSlice src, rv_policy policy, + cleanup_list* cleanup) noexcept { + PyErr_Format(PyExc_NotImplementedError, + "LiteralSlice::from_cpp not implemented"); + return handle(); } }; } // namespace detail -} // namespace pybind11 +} // namespace nanobind #endif // XLA_PYTHON_TYPES_H_ diff --git a/xla/python/util.cc b/xla/python/util.cc index 4507aad9a55b5..9db18e31534db 100644 --- a/xla/python/util.cc +++ b/xla/python/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,16 +19,30 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/future.h" #include "xla/status.h" #include "xla/util.h" namespace xla { -Status AwaitBuffersReady(ifrt::Array* ifrt_array) { - Status s = ifrt_array->GetReadyFuture().Await(); +Status AwaitBuffersReady(absl::Span ifrt_arrays) { + ifrt::Future future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector> futures; + futures.reserve(ifrt_arrays.size()); + for (ifrt::Array* const ifrt_array : ifrt_arrays) { + futures.push_back(ifrt_array->GetReadyFuture()); + } + future = ifrt::JoinFutures(absl::MakeSpan(futures)); + } + + Status s = future.Await(); if (!s.ok()) { // Fix up error string because some clients rely on it. if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { diff --git a/xla/python/util.h b/xla/python/util.h index f231f1c0e783c..3e9f26d2bbd08 100644 --- a/xla/python/util.h +++ b/xla/python/util.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,76 +19,15 @@ limitations under the License. #include #include -#include "absl/strings/str_format.h" -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/pjrt/pjrt_client.h" +#include "absl/types/span.h" #include "xla/python/ifrt/array.h" #include "xla/status.h" namespace xla { -template -bool is_pybind_reinterpret_cast_ok(pybind11::handle h) { - static pybind11::detail::type_info* const type_info = []() { - auto* type_info = - pybind11::detail::get_type_info(typeid(T), /*throw_if_missing=*/false); - CHECK(type_info); - CHECK(type_info->simple_type); - return type_info; - }(); - PyTypeObject* srctype = Py_TYPE(h.ptr()); - // Exact type match. - if (srctype == type_info->type) { - return true; - } - // If we have a subtype, then look for a base type that matches. - if (PyType_IsSubtype(srctype, type_info->type)) { - const auto& bases = pybind11::detail::all_type_info(srctype); - for (auto* base : bases) { - if (PyType_IsSubtype(base->type, type_info->type)) { - return true; - } - } - } - return false; -} - -// Faster version of the pybind11 cast x.cast. -// pybind11's cast is fairly slow because it looks up the type information -// in a global hash table. It's not a particularly fast hash table and the -// lookup is pointless when we know the target type and can cache the lookup. -// This function does depend on a number of pybind11 internals; -// if it ever bitrots, one option is to replace it with a pybind11 cast. -// Return nullptr if the cast fails. -template -T* fast_cast(pybind11::handle h) { - if (!is_pybind_reinterpret_cast_ok(h)) { - // Fall back to pybind11's usual cast. - return h.cast(); - } - auto* instance = reinterpret_cast(h.ptr()); - if (instance->simple_layout) { - return reinterpret_cast(instance->simple_value_holder[0]); - } else { - return reinterpret_cast( - pybind11::detail::values_and_holders(instance).begin()->value_ptr()); - } -} - -// Issues a Python deprecation warning. Throws a C++ exception if issuing the -// Python warning causes a Python exception to be raised. -template -void PythonDeprecationWarning(const absl::FormatSpec& format, - const Args&... args) { - if (PyErr_WarnEx(PyExc_DeprecationWarning, - absl::StrFormat(format, args...).c_str(), 1) < 0) { - throw pybind11::error_already_set(); - } -} - // Requests if given buffers are ready, awaits for results and returns OK if // all of the buffers are ready or the last non-ok status. -Status AwaitBuffersReady(ifrt::Array* ifrt_array); +Status AwaitBuffersReady(absl::Span ifrt_arrays); } // namespace xla diff --git a/xla/python/weakref_lru_cache.cc b/xla/python/weakref_lru_cache.cc index 20bde65e9c928..a64cf93a6c4dd 100644 --- a/xla/python/weakref_lru_cache.cc +++ b/xla/python/weakref_lru_cache.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/python/weakref_lru_cache.h" +#include +#include #include #include #include @@ -22,27 +24,35 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/node_hash_map.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/gil.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep #include "xla/pjrt/lru_cache.h" +#include "xla/python/nb_helpers.h" + +namespace nb = nanobind; namespace jax { namespace { -// Minimal wrapper to expose a pybind11::dict_iterator's value as something +// Minimal wrapper to expose a nb::dict_iterator's value as something // hashable with Abseil. class HashablePyDictValue { protected: - using Iter = pybind11::detail::dict_iterator; + using Iter = nb::detail::dict_iterator; template friend H AbslHashValue(H h, const HashablePyDictValue& value) { - return H::combine(std::move(h), pybind11::hash(value.iter_->first), - pybind11::hash(value.iter_->second)); + auto kv = *value.iter_; + return H::combine(std::move(h), xla::nb_hash(kv.first), + xla::nb_hash(kv.second)); } explicit HashablePyDictValue(const Iter& iter) : iter_(iter) {} @@ -50,7 +60,7 @@ class HashablePyDictValue { Iter iter_; }; -// Similarly, a minimalist adaptor around the pybind11::detail::dict_iterator +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator // itself. Note that the iterator "is" also a Value. Does not meet the full // standard iterator requirements, only enough to support H::combine_unordered. class HashablePyDictIter : protected HashablePyDictValue { @@ -72,9 +82,9 @@ class HashablePyDictIter : protected HashablePyDictValue { class WeakrefLRUCache : public std::enable_shared_from_this { public: struct Key { - pybind11::object context; - pybind11::args args; - pybind11::kwargs kwargs; + nb::object context; + nb::args args; + nb::kwargs kwargs; bool operator==(const Key& other) const { return context.equal(other.context) && args.equal(other.args) && @@ -83,8 +93,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { template friend H AbslHashValue(H h, const Key& key) { - h = H::combine(std::move(h), pybind11::hash(key.context), - pybind11::hash(key.args)); + h = H::combine(std::move(h), xla::nb_hash(key.context), + xla::nb_hash(key.args)); h = H::combine_unordered(std::move(h), HashablePyDictIter(key.kwargs.begin()), HashablePyDictIter(key.kwargs.end())); @@ -95,7 +105,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this { struct CacheEntry { bool has_result = false; - pybind11::object result; + nb::object result; absl::Notification completed; std::thread::id thread_id = std::this_thread::get_id(); }; @@ -108,13 +118,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this { }; struct UnboundWeakrefCacheEntry { - pybind11::handle object; + nb::handle object; WeakrefLRUCache* cache; size_t cached_hash; }; struct WeakrefCacheEntry { - pybind11::weakref weakref; + nb::weakref weakref; size_t cached_hash; }; @@ -141,13 +151,12 @@ class WeakrefLRUCache : public std::enable_shared_from_this { if (obj == Py_None) { return false; } - return pybind11::reinterpret_borrow(obj).equal( - rhs.object); + return nb::borrow(obj).equal(rhs.object); } }; using Cache = xla::LRUCache>; - WeakrefLRUCache(pybind11::function cache_context_fn, pybind11::function fn, + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} @@ -156,17 +165,16 @@ class WeakrefLRUCache : public std::enable_shared_from_this { if (it != entries_.end()) { return (it->second); } - pybind11::weakref weakref( - key.object, pybind11::cpp_function([this_weak = weak_from_this(), - cached_hash = key.cached_hash]( - pybind11::handle weakref) { + nb::weakref weakref( + key.object, + nb::cpp_function([this_weak = weak_from_this(), + cached_hash = key.cached_hash](nb::handle weakref) { auto cache = this_weak.lock(); if (cache == nullptr) { return; } - auto it = cache->entries_.find(WeakrefCacheEntry{ - pybind11::reinterpret_borrow(weakref), - cached_hash}); + auto it = cache->entries_.find( + WeakrefCacheEntry{nb::borrow(weakref), cached_hash}); // Create temp-var to avoid re-entrant erase. auto tmp = std::move(it->second); cache->entries_.erase(it); @@ -177,29 +185,36 @@ class WeakrefLRUCache : public std::enable_shared_from_this { .first->second); } - pybind11::object Call(pybind11::object weakref_key, pybind11::args args, - pybind11::kwargs kwargs) { - pybind11::object context = cache_context_fn_(); + nb::object Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); std::shared_ptr cache_ptr = GetCache(UnboundWeakrefCacheEntry{ - weakref_key, this, static_cast(pybind11::hash(weakref_key))}); + weakref_key, this, static_cast(xla::nb_hash(weakref_key))}); Cache& cache = *cache_ptr; ++total_queries_; bool inserted = false; + std::shared_ptr entry; { // Because the gil can be released during cache insertion, this forces // the lock order to be mu_ then gil so we must release the gil first. - pybind11::gil_scoped_release release; + nb::gil_scoped_release release; // Acquire a mutex to avoid problems where the gil is released during // cache insertion and then a second thread invalidates the cache order. mu_.Lock(); } - Key key{context, args, kwargs}; - auto entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { - inserted = true; - return std::make_shared(); - }); - mu_.Unlock(); + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. + absl::Cleanup unlock = [this]() + ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + Key key{context, args, kwargs}; + entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { + inserted = true; + return std::make_shared(); + }); + } if (!entry->completed.HasBeenNotified()) { if (inserted) { ++misses_; @@ -208,14 +223,14 @@ class WeakrefLRUCache : public std::enable_shared_from_this { entry->has_result = true; } else { if (entry->thread_id == std::this_thread::get_id()) { - auto error_string = absl::StrCat( - "Recursively calling ", - pybind11::cast(pybind11::repr(weakref_key)), - pybind11::cast(pybind11::repr(args))); + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); PyErr_SetString(PyExc_RecursionError, error_string.c_str()); - throw pybind11::error_already_set(); + throw nb::python_error(); } - pybind11::gil_scoped_release release; + nb::gil_scoped_release release; entry->completed.WaitForNotification(); } } @@ -227,6 +242,20 @@ class WeakrefLRUCache : public std::enable_shared_from_this { return fn_(weakref_key, *args, **kwargs); } } + std::vector GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto& wr_key : entries_) { + for (const auto& rest : *wr_key.second) { + nb::tuple result = + nb::make_tuple(wr_key.first.weakref, rest.first.context, + rest.first.args, rest.first.kwargs); + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; + } CacheInfo GetCacheInfo() const { CacheInfo result; result.hits = total_queries_ - misses_; @@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { deferred_deletes.clear(); } - pybind11::function cache_context_fn_; - pybind11::function fn_; + nb::callable cache_context_fn_; + nb::callable fn_; Cache::LRUList lru_list_; absl::node_hash_map, WeakrefKeyHash, WeakrefKeyEq> @@ -256,23 +285,20 @@ class WeakrefLRUCache : public std::enable_shared_from_this { absl::Mutex mu_; }; -namespace { -namespace py = ::pybind11; -} // namespace - -void BuildWeakrefLRUCacheAPI(pybind11::module& m) { +void BuildWeakrefLRUCacheAPI(nb::module_& m) { auto weakref_lru_cache = - py::class_>( - m, "WeakrefLRUCache") + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable()) .def("__call__", &WeakrefLRUCache::Call) + .def("cache_keys", &WeakrefLRUCache::GetKeys) .def("cache_info", &WeakrefLRUCache::GetCacheInfo) .def("cache_clear", &WeakrefLRUCache::Clear); - py::class_(weakref_lru_cache, + nb::class_(weakref_lru_cache, "WeakrefLRUCacheInfo") - .def_readonly("hits", &WeakrefLRUCache::CacheInfo::hits) - .def_readonly("misses", &WeakrefLRUCache::CacheInfo::misses) - .def_readonly("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) - .def_readonly("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) + .def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) { return absl::StrCat( "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, @@ -280,12 +306,10 @@ void BuildWeakrefLRUCacheAPI(pybind11::module& m) { }); m.def( "weakref_lru_cache", - [](pybind11::function cache_context_fn, pybind11::function fn, - int64_t maxsize) { + [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) { return std::make_shared(cache_context_fn, fn, maxsize); }, - pybind11::arg("cache_context_fn"), pybind11::arg("fn"), - pybind11::arg("maxsize") = 2048); + nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048); } } // namespace jax diff --git a/xla/python/weakref_lru_cache.h b/xla/python/weakref_lru_cache.h index d15bc3b6ba9d3..4565ce98e7f3e 100644 --- a/xla/python/weakref_lru_cache.h +++ b/xla/python/weakref_lru_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_WEAKREF_LRU_CACHE_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace jax { -void BuildWeakrefLRUCacheAPI(pybind11::module& m); +void BuildWeakrefLRUCacheAPI(nanobind::module_& m); } // namespace jax diff --git a/xla/python/weakref_lru_cache_test.py b/xla/python/weakref_lru_cache_test.py index 213d8a9c23aed..ad5f07bee0bf7 100644 --- a/xla/python/weakref_lru_cache_test.py +++ b/xla/python/weakref_lru_cache_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,6 +94,58 @@ def CacheFn(obj, kwkey1, kwkey2): self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. + + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + if __name__ == "__main__": absltest.main() diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 4e02186dfe93b..85acf36fe17ae 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,62 +13,84 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/xla.h" +#include #include -#include #include #include #include #include #include -#include #include #include #include #include "absl/base/casts.h" -// clang-format off -// Must be included first +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/python/py_client.h" -#include "tsl/python/lib/core/numpy.h" //NOLINT -// clang-format on - -#include "absl/strings/ascii.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "pybind11/attr.h" // from @pybind11 -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/detail/common.h" // from @pybind11 -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl_bind.h" // from @pybind11 -#include "xla/layout_util.h" +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/nb_defs.h" // from @nanobind +#include "nanobind/stl/function.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/pair.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/set.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" -#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt_proxy/client/py_module.h" +#include "xla/python/py_client.h" +#include "xla/python/py_program.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT #ifdef XLA_PYTHON_ENABLE_GPU #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #endif // XLA_PYTHON_ENABLE_GPU + +#ifdef __linux__ +#include "gloo/transport/tcp/attr.h" // from @gloo +#include "gloo/transport/tcp/device.h" // from @gloo +#include "xla/pjrt/cpu/gloo_collectives.h" +#include "xla/pjrt/cpu/gloo_kv_store.h" +#endif // __linux__ + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/pjrt/cpu/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + #include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/custom_call_sharding.h" #include "xla/python/dlpack.h" #include "xla/python/jax_jit.h" -#include "xla/python/logging.h" +#include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/mlir.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" #include "xla/python/outfeed_receiver_py.h" #include "xla/python/pjit.h" @@ -77,33 +99,28 @@ limitations under the License. #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" #include "xla/python/py_array.h" -#include "xla/python/py_buffer.h" #include "xla/python/py_compile_only_client.h" +#include "xla/python/py_device.h" #include "xla/python/py_device_list.h" #include "xla/python/py_executable.h" +#include "xla/python/py_memory_space.h" #include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "xla/python/status_casters.h" #include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" -#include "xla/python/types.h" -#include "xla/python/util.h" #include "xla/python/weakref_lru_cache.h" #include "xla/python/xla_compiler.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/util.h" #include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "tsl/platform/platform.h" +#include "tsl/platform/status.h" // TODO(phawkins): remove host_id properties after JAX is update to avoid them. namespace xla { namespace { -namespace py = pybind11; +namespace nb = nanobind; bool IsOptimizedBuild() { #if NDEBUG @@ -143,26 +160,32 @@ bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } } // namespace -static void Init(py::module_& m) { +NB_MODULE(xla_extension, m_nb) { // Initialize ABSL logging because code within XLA uses it. #ifndef PLATFORM_GOOGLE InitializeAbslLogging(); #endif // PLATFORM_GOOGLE + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + tsl::ImportNumpy(); // Exceptions - py::register_exception(m, "XlaRuntimeError", - PyExc_RuntimeError); + nb::exception xla_runtime_error(m_nb, "XlaRuntimeError", + PyExc_RuntimeError); // Types - py::enum_(m, "PrimitiveType") + nb::enum_(m_nb, "PrimitiveType") .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) .value("PRED", PRED) + .value("S4", S4) .value("S8", S8) .value("S16", S16) .value("S32", S32) .value("S64", S64) + .value("U4", U4) .value("U8", U8) .value("U16", U16) .value("U32", U32) @@ -183,525 +206,318 @@ static void Init(py::module_& m) { .value("TOKEN", TOKEN); // Must be before PyClient.compile. - BuildXlaCompilerSubmodule(m); - - py::class_> device( - m, "Device", - "A descriptor of an available device.\n\nSubclasses are used to " - "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may " - "have additional properties specific to that device type."); - device - .def_property_readonly( - "id", &PjRtDevice::id, - "Integer ID of this device.\n\nUnique across all available devices " - "of this type, including remote devices on multi-host platforms.") - .def_property_readonly( - "process_index", &PjRtDevice::process_index, - "Integer index of this device's process.\n\n" - "This is always 0 except on multi-process platforms.") - .def_property_readonly("host_id", &PjRtDevice::process_index, - "Deprecated; please use process_index") - .def_property_readonly("task_id", &PjRtDevice::process_index, - "Deprecated; please use process_index") - .def_property_readonly("platform", - [](const ClientAndPtr& device) { - // TODO(phawkins): this is a temporary backwards - // compatibility shim. We changed the name PJRT - // reports for GPU platforms to "cuda" or "rocm", - // but we haven't yet updated JAX clients that - // expect "gpu". Migrate users and remove this - // code. - if (device.client()->platform_name() == "cuda" || - device.client()->platform_name() == "rocm") { - return absl::string_view("gpu"); - } else { - return device.client()->platform_name(); - } - }) - .def_property_readonly("device_kind", &PjRtDevice::device_kind) - .def_property_readonly("client", - [](const ClientAndPtr& device) { - return device.client(); - }) - .def_property_readonly( - "local_hardware_id", - [](const ClientAndPtr& device) -> std::optional { - int local_hardware_id = device->local_hardware_id(); - if (local_hardware_id == -1) { - return std::nullopt; - } - return local_hardware_id; - }, - "Opaque hardware ID, e.g., the CUDA device number. In general, not " - "guaranteed to be dense, and not guaranteed to be defined on all " - "platforms.") - .def("__str__", &PjRtDevice::DebugString) - .def("__repr__", &PjRtDevice::ToString) - .def("transfer_to_infeed", - [](PjRtDevice& device, const LiteralSlice& literal) { - GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; - xla::ThrowIfError(device.TransferToInfeed(literal)); + BuildXlaCompilerSubmodule(m_nb); + + PyDevice::RegisterPythonType(m_nb); + PyMemorySpace::RegisterPythonType(m_nb); + PyClient::RegisterPythonTypes(m_nb); + + nb::class_(m_nb, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout& layout, + const PjRtLayout& other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout& layout) { return absl::HashOf(layout); }); + + nb::class_(m_nb, "PjRtXlaLayout") + .def("__getstate__", + [](const PjRtXlaLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); }) - .def("transfer_from_outfeed", - [](PjRtDevice& device, const Shape& shape) -> py::object { - GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr literal; - { - py::gil_scoped_release gil_release; - Shape shape_with_layout = shape; - ShapeUtil::ForEachMutableSubshape( - &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { - if (!subshape->has_layout()) { - LayoutUtil::SetToDefaultLayout(subshape); - } - }); - literal = std::make_shared(shape_with_layout); - xla::ThrowIfError(device.TransferFromOutfeed(literal.get())); - } - return ValueOrThrow(LiteralToPython(std::move(literal))); - }) - .def( - "memory", - [](const ClientAndPtr& device, const std::string& kind) { - return jax::GetMemory(device, kind); - }, - py::arg("kind")) - // Returns the default memory of a device. - .def("default_memory", - [](const ClientAndPtr& device) { - auto* memory_space = - xla::ValueOrThrow(device->default_memory_space()); - return WrapWithClient(device.client(), memory_space); - }) - // Returns all the memories that a device can address. - .def("addressable_memories", - [](const ClientAndPtr& device) { - std::vector> memory_spaces; - auto span = device->memory_spaces(); - memory_spaces.reserve(span.size()); - for (auto* memory_space : span) { - memory_spaces.push_back( - WrapWithClient(device.client(), memory_space)); - } - return memory_spaces; - }) - .def("live_buffers", - [](const ClientAndPtr& device) { - PythonDeprecationWarning( - "Per device live_buffers() is going to be deprecated. Please " - "use the jax.live_arrays() for jax.Arrays instead."); - return py::list(); - }) - .def( - "memory_stats", - [](const PjRtDevice& device) - -> std::optional> { - GlobalPyRefManager()->CollectGarbage(); - xla::StatusOr maybe_stats = - device.GetAllocatorStats(); - if (absl::IsUnimplemented(maybe_stats.status())) { - return std::nullopt; - } - // Raise error if any status other than Unimplemented is returned. - ThrowIfError(maybe_stats.status()); - - std::map result; - result["num_allocs"] = maybe_stats->num_allocs; - result["bytes_in_use"] = maybe_stats->bytes_in_use; - result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; - result["largest_alloc_size"] = maybe_stats->largest_alloc_size; - if (maybe_stats->bytes_limit) { - result["bytes_limit"] = *maybe_stats->bytes_limit; - } - result["bytes_reserved"] = maybe_stats->bytes_reserved; - result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; - if (maybe_stats->bytes_reservable_limit) { - result["bytes_reservable_limit"] = - *maybe_stats->bytes_reservable_limit; - } - result["largest_free_block_bytes"] = - maybe_stats->largest_free_block_bytes; - if (maybe_stats->pool_bytes) { - result["pool_bytes"] = *maybe_stats->pool_bytes; - } - if (maybe_stats->peak_pool_bytes) { - result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; - } - return result; - }, - "Returns memory statistics for this device keyed by name. May not be " - "implemented on all platforms, and different platforms may return " - "different stats, or -1 for unavailable stats. 'bytes_in_use' is " - "usually available. Intended for diagnostic use.") - .def("get_stream_for_external_ready_events", - xla::ValueOrThrowWrapper( - &PjRtDevice::GetStreamForExternalReadyEvents)); - static PyMethodDef get_attr_method = { - "__getattr__", - +[](PyObject* self, PyObject* args) -> PyObject* { - PyObject* key; - if (!PyArg_ParseTuple(args, "O", &key)) { - PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); - return nullptr; - } - try { - auto device = py::cast(py::handle(self)); - auto name = py::cast(py::handle(key)); - const auto& attrs = device->Attributes(); - auto it = attrs.find(name); - if (it != attrs.end()) { - auto result = - std::visit([](auto&& v) { return py::cast(v); }, it->second); - return result.release().ptr(); - } - PyErr_SetNone(PyExc_AttributeError); - return nullptr; - } catch (std::exception& e) { - PyErr_Format(PyExc_SystemError, - "Some unhandled pybind11 exception: %s", e.what()); - return nullptr; - } catch (...) { - PyErr_SetString(PyExc_SystemError, - "Some unhandled pybind11 exception."); - return nullptr; - } - }, - METH_VARARGS, - nullptr, - }; - device.attr("__getattr__") = - py::reinterpret_steal(PyDescr_NewMethod( - reinterpret_cast(device.ptr()), &get_attr_method)); - - py::class_> memory_space( - m, "Memory"); - memory_space - .def_property_readonly( - "process_index", - [](const ClientAndPtr& memory_space) { - return memory_space.client()->process_index(); - }) - .def_property_readonly( - "platform", - [](const ClientAndPtr& memory_space) { - // TODO(phawkins): this is a temporary backwards - // compatibility shim. We changed the name PJRT - // reports for GPU platforms to "cuda" or "rocm", - // but we haven't yet updated JAX clients that - // expect "gpu". Migrate users and remove this - // code. - if (memory_space.client()->platform_name() == "cuda" || - memory_space.client()->platform_name() == "rocm") { - return absl::string_view("gpu"); - } else { - return memory_space.client()->platform_name(); - } - }) - .def_property_readonly("kind", &PjRtMemorySpace::memory_space_kind) - .def("__str__", &PjRtMemorySpace::DebugString) - .def("__repr__", &PjRtMemorySpace::ToString) - // Returns the devices that can address this `Memory`. - .def("addressable_by_devices", - [](const ClientAndPtr& memory_space) { - std::vector> devices; - auto span = memory_space->devices(); - devices.reserve(span.size()); - for (PjRtDevice* device : span) { - devices.push_back(WrapWithClient(memory_space.client(), device)); - } - return devices; - }); + .def("__setstate__", [](PjRtXlaLayout* self, nb::tuple t) { + // TODO(b/328671718): don't assume PjRtXlaLayout. We probably want a + // generic method on PjRtCompiler instead, although we'll have + // somehow have to attach a compiler to this PjRtLayout (something + // like ClientAndPtr). + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr layout = PjRtXlaLayout::Deserialize( + std::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtXlaLayout(std::move(*layout)); + }); - // Local XLA client methods. + jax::BuildWeakrefLRUCacheAPI(m_nb); - py::enum_(m, "HostBufferSemantics") - .value("IMMUTABLE_ONLY_DURING_CALL", - PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) - .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", - PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) - .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy); + nb::class_ cpu_collectives(m_nb, + "CpuCollectives"); - jax::BuildWeakrefLRUCacheAPI(m); + m_nb.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, - py::class_> py_local_client(m, "Client"); - py_local_client.def_property_readonly("platform", &PyClient::platform_name) - .def_property_readonly("platform_version", &PyClient::platform_version) - .def_property_readonly("runtime_type", &PyClient::runtime_type) - .def("device_count", &PyClient::device_count) - .def("local_device_count", &PyClient::addressable_device_count) - .def("devices", &PyClient::Devices) - .def("local_devices", &PyClient::LocalDevices) - .def("device_from_local_hardware_id", - xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) - .def("live_buffers", &PyClient::LiveBuffers) - .def("live_executables", &PyClient::LiveExecutables) - .def("live_arrays", &PyClient::LiveArrays) - .def("process_index", &PyClient::process_index) - .def("host_id", &PyClient::process_index) - .def("task_id", &PyClient::process_index) - .def( - "buffer_from_pyval", - [](py::handle py_client, py::handle argument, py::handle py_device, - bool force_copy, - PjRtClient::HostBufferSemantics host_buffer_semantics) { - PyClient* client = fast_cast(py_client); - PjRtDevice* device = py_device.is_none() - ? nullptr - : fast_cast(py_device); - return ValueOrThrow(client->BufferFromPyval( - argument, device, force_copy, host_buffer_semantics)); - }, - py::arg("argument"), py::arg("device") = nullptr, - py::arg("force_copy") = false, - py::arg("host_buffer_semantics") = - PjRtClient::HostBufferSemantics::kZeroCopy) - .def("make_cross_host_receive_buffers", - xla::ValueOrThrowWrapper(&PyClient::MakeCrossHostReceiveBuffers), - py::arg("shapes"), py::arg("device")) - .def("compile", xla::ValueOrThrowWrapper(&PyClient::Compile), - py::arg("computation"), - py::arg("compile_options") = CompileOptions(), - py::arg("host_callbacks") = std::vector()) - .def("serialize_executable", - xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) - .def("deserialize_executable", - xla::ValueOrThrowWrapper(&PyClient::DeserializeExecutable), - py::arg("serialized"), py::arg("compile_options") = std::nullopt, - py::arg("host_callbacks") = std::vector()) - .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) - // TODO(zhangqiaorjc): Experimental. - .def("defragment", - [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) - .def("get_emit_python_callback_descriptor", - xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), - py::arg("callable"), py::arg("operand_shapes"), - py::arg("result_shapes") = std::nullopt) - .def("make_python_callback_from_host_send_and_recv", - xla::ValueOrThrowWrapper( - &PyClient::MakePythonCallbackUsingHostSendAndRecv), - py::arg("callable"), py::arg("operand_shapes"), - py::arg("result_shapes"), py::arg("send_channel_ids"), - py::arg("recv_channel_ids"), py::arg("serializer") = py::none()) - .def("__getattr__", [](PyClient& client, std::string name) -> py::object { - const auto& attrs = client.attributes(); - auto it = attrs.find(name); - if (it != attrs.end()) { - return std::visit([](auto&& v) { return py::cast(v); }, it->second); + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#ifdef __linux__ + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); } - throw py::attribute_error(absl::StrCat("Unknown attribute ", name)); - }); + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#else // __linux__ + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux"); +#endif // __linux__ + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m_nb, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE - m.def( + m_nb.def( "get_tfrt_cpu_client", [](bool asynchronous, std::shared_ptr distributed_client, - int node_id, int num_nodes) -> std::shared_ptr { - py::gil_scoped_release gil_release; - CpuClientOptions options; - if (distributed_client != nullptr) { - std::string key_prefix = "cpu:"; - options.kv_get = - [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - options.kv_put = [distributed_client, key_prefix]( - std::string_view k, - std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), - v); - }; - options.node_id = node_id; - options.num_nodes = num_nodes; - } + int node_id, int num_nodes, + std::shared_ptr collectives) + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + CpuClientOptions options; + if (distributed_client != nullptr) { + options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + options.node_id = node_id; + options.num_nodes = num_nodes; + + options.collectives = std::move(collectives); + } - options.asynchronous = asynchronous; - std::unique_ptr client = - xla::ValueOrThrow(GetTfrtCpuClient(options)); - return std::make_shared( - ifrt::PjRtClient::Create(std::move(client))); + options.asynchronous = asynchronous; + std::unique_ptr client = + xla::ValueOrThrow(GetTfrtCpuClient(options)); + ifrt_client = ifrt::PjRtClient::Create(std::move(client)); + } + return PyClient::Make(std::move(ifrt_client)); }, - py::arg("asynchronous") = true, py::arg("distributed_client") = nullptr, - py::arg("node_id") = 0, py::arg("num_nodes") = 1); - m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { - xla::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr()); + m_nb.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); }); - m.def("load_pjrt_plugin", - [](std::string platform_name, std::string library_path) -> py::capsule { + m_nb.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { const PJRT_Api* api = xla::ValueOrThrow( - pjrt::LoadPjrtPlugin(platform_name, library_path)); - return py::capsule(absl::bit_cast(api), "pjrt_c_api"); - }); - m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m_nb.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); }); - m.def("initialize_pjrt_plugin", [](std::string platform_name) { + m_nb.def("initialize_pjrt_plugin", [](std::string platform_name) { return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); }); #ifdef XLA_PYTHON_ENABLE_GPU - py::class_ alloc_config(m, "GpuAllocatorConfig"); - alloc_config.def(py::init<>()) - .def_readwrite("kind", &GpuAllocatorConfig::kind) - .def_readwrite("memory_fraction", &GpuAllocatorConfig::memory_fraction) - .def_readwrite("preallocate", &GpuAllocatorConfig::preallocate); - py::enum_(alloc_config, "Kind") + nb::class_ alloc_config(m_nb, "GpuAllocatorConfig"); + alloc_config.def(nb::init<>()) + .def_rw("kind", &GpuAllocatorConfig::kind) + .def_rw("memory_fraction", &GpuAllocatorConfig::memory_fraction) + .def_rw("preallocate", &GpuAllocatorConfig::preallocate) + .def_rw("collective_memory_size", + &GpuAllocatorConfig::collective_memory_size); + nb::enum_(alloc_config, "Kind") .value("DEFAULT", GpuAllocatorConfig::Kind::kDefault) .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform) .value("BFC", GpuAllocatorConfig::Kind::kBFC) .value("CUDA_ASYNC", GpuAllocatorConfig::Kind::kCudaAsync); - m.def( + m_nb.def( "get_gpu_client", [](bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id, int num_nodes, std::optional> allowed_devices, std::optional platform_name, - std::optional mock = false) -> std::shared_ptr { - py::gil_scoped_release gil_release; - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; - if (distributed_client != nullptr) { - // Use the plugin name as key prefix. - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), - v); - }; + std::optional mock = false) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); + } + GpuClientOptions options; + options.allocator_config = allocator_config; + options.node_id = node_id; + options.num_nodes = num_nodes; + options.allowed_devices = allowed_devices; + options.platform_name = platform_name; + options.kv_store = kv_store; + options.enable_mock_nccl = mock.value_or(false); + std::unique_ptr pjrt_client = + xla::ValueOrThrow(GetStreamExecutorGpuClient(options)); + ifrt_client = ifrt::PjRtClient::Create(std::move(pjrt_client)); } - GpuClientOptions options; - options.allocator_config = allocator_config; - options.node_id = node_id; - options.num_nodes = num_nodes; - options.allowed_devices = allowed_devices; - options.platform_name = platform_name; - options.kv_get = kv_get; - options.kv_put = kv_put; - options.enable_mock_nccl = mock.value_or(false); - std::unique_ptr client = - xla::ValueOrThrow(GetStreamExecutorGpuClient(options)); - return std::make_shared( - ifrt::PjRtClient::Create(std::move(client))); + return PyClient::Make(std::move(ifrt_client)); }, - py::arg("asynchronous") = true, - py::arg("allocator_config") = GpuAllocatorConfig(), - py::arg("distributed_client") = nullptr, py::arg("node_id") = 0, - py::arg("num_nodes") = 1, py::arg("allowed_devices") = std::nullopt, - py::arg("platform_name") = std::nullopt, py::arg("mock") = std::nullopt); + nb::arg("asynchronous") = true, + nb::arg("allocator_config") = GpuAllocatorConfig(), + nb::arg("distributed_client") = nullptr, nb::arg("node_id") = 0, + nb::arg("num_nodes") = 1, + nb::arg("allowed_devices").none() = std::nullopt, + nb::arg("platform_name").none() = std::nullopt, + nb::arg("mock").none() = std::nullopt); #endif // XLA_PYTHON_ENABLE_GPU - m.def( + m_nb.def( "get_c_api_client", [](std::string platform_name, const absl::flat_hash_map& options, std::shared_ptr distributed_client) - -> std::shared_ptr { - py::gil_scoped_release gil_release; - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; - if (distributed_client != nullptr) { - kv_get = [distributed_client, platform_name](std::string_view k, - absl::Duration timeout) { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(platform_name, ":", k), timeout); - }; - kv_put = [distributed_client, platform_name](std::string_view k, - std::string_view v) { - return distributed_client->KeyValueSet( - absl::StrCat(platform_name, ":", k), v); - }; + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); } - std::unique_ptr c_api_client = xla::ValueOrThrow( - GetCApiClient(platform_name, options, kv_get, kv_put)); - return std::make_shared( - ifrt::PjRtClient::Create(std::move(c_api_client))); + return PyClient::Make(std::move(ifrt_client)); }, - py::arg("platform_name"), - py::arg("options") = absl::flat_hash_map(), - py::arg("distributed_client") = nullptr); - m.def("get_default_c_api_topology", - [](std::string platform_name, std::string topology_name, - const absl::flat_hash_map& options) - -> std::shared_ptr { - return xla::ValueOrThrow( - GetCApiTopology(platform_name, topology_name, options)); - }); - m.def("get_topology_for_devices", - [](std::vector> devices_and_clients) { - if (devices_and_clients.empty()) { - throw py::value_error( - "get_topology_for_devices requires >= 1 devices."); - } - auto client = devices_and_clients[0].client(); - std::vector devices; - devices.reserve(devices_and_clients.size()); - for (const ClientAndPtr& device : devices_and_clients) { - if (device.get_client() != client.get()) { - throw py::value_error( - "devices passed to get_topology_for_devices come from " - "different clients."); - } - devices.push_back(device.get()); + nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m_nb.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options)); + }); + m_nb.def( + "get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options)); + }); + m_nb.def( + "get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + std::vector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); } - return xla::ValueOrThrow(client->ifrt_client()->GetTopologyForDevices( - absl::MakeSpan(devices))); - }); - - TF_CHECK_OK(PyArray::RegisterTypes(m)); - jax::RegisterDeviceList(m); - jax::RegisterSharding(m); + ifrt_devices.push_back(py_device->device()); + } + return xla::ValueOrThrow(client->ifrt_client()->GetTopologyForDevices( + absl::MakeSpan(ifrt_devices))); + }); - py::class_(m, "CompiledMemoryStats") - .def_readwrite("generated_code_size_in_bytes", - &CompiledMemoryStats::generated_code_size_in_bytes) - .def_readwrite("argument_size_in_bytes", - &CompiledMemoryStats::argument_size_in_bytes) - .def_readwrite("output_size_in_bytes", - &CompiledMemoryStats::output_size_in_bytes) - .def_readwrite("alias_size_in_bytes", - &CompiledMemoryStats::alias_size_in_bytes) - .def_readwrite("temp_size_in_bytes", - &CompiledMemoryStats::temp_size_in_bytes) - .def_property_readonly("serialized_hlo_proto", - [](const CompiledMemoryStats& cms) -> py::bytes { - return py::bytes(cms.serialized_hlo_proto); - }) + TF_CHECK_OK(PyArray::RegisterTypes(m_nb)); + jax::RegisterDeviceList(m_nb); + jax::RegisterSharding(m_nb); + + nb::class_(m_nb, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_hlo_proto", + [](const CompiledMemoryStats& cms) -> nb::bytes { + return nb::bytes(cms.serialized_hlo_proto.data(), + cms.serialized_hlo_proto.size()); + }) .def("__str__", &CompiledMemoryStats::DebugString); - py::class_(m, "ExecuteResults") + nb::class_(m_nb, "ExecuteResults") .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) .def("disassemble_into_single_device_arrays", - [](PyExecuteResults& results) { - return results.DisassembleIntoSingleDeviceArrays(); - }) + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) .def("disassemble_prefix_into_single_device_arrays", - [](PyExecuteResults& results, size_t n) { - return results.DisassemblePrefixIntoSingleDeviceArrays(n); - }) - .def("consume_with_handlers", - [](PyExecuteResults& results, - std::vector> - out_handlers) { - return results.ConsumeWithHandlers(std::move(out_handlers)); - }) - .def("consume_token", - [](PyExecuteResults& results) { return results.ConsumeToken(); }); + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); - py::class_> - loaded_executable(m, "LoadedExecutable"); - loaded_executable.def_property_readonly("client", &PyLoadedExecutable::client) + nb::class_(m_nb, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) .def("local_logical_device_ids", [](PyLoadedExecutable* exec) { auto span = exec->addressable_device_logical_ids(); @@ -723,17 +539,16 @@ static void Init(py::module_& m) { .def("execute_sharded_on_local_devices", xla::ValueOrThrowWrapper( &PyLoadedExecutable::ExecuteShardedOnLocalDevices), - py::arg("arguments")) + nb::arg("arguments")) .def("execute_sharded_on_local_devices_with_tokens", xla::ValueOrThrowWrapper( &PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens), - py::arg("arguments")) + nb::arg("arguments")) // TODO(parkers): Switch execute_sharded_on_local_devices* to this. .def("execute_sharded", xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), - py::arg("arguments"), py::arg("with_tokens") = false) - .def("hlo_modules", - xla::ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) .def("get_output_memory_kinds", xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) @@ -746,63 +561,74 @@ static void Init(py::module_& m) { .def("keep_alive", &PyLoadedExecutable::KeepAlive) .def("compile_options", [](const PyLoadedExecutable& self) { - return ValueOrThrow(self.pjrt_executable()->GetCompileOptions()); + return xla::ValueOrThrow( + self.pjrt_executable()->GetCompileOptions()); }) .def("cost_analysis", xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCostAnalysis)) - .def_property_readonly("traceback", &PyLoadedExecutable::traceback) - .def_property_readonly("fingerprint", - [](PyLoadedExecutable* exec) -> py::object { - if (exec->fingerprint().has_value()) { - return py::bytes(*exec->fingerprint()); - } else { - return py::none(); - } - }); - py::class_ token(m, "Token"); + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m_nb, "Token"); token.def("block_until_ready", [](PyToken& self) { xla::ThrowIfError(self.Await()); }); - py::class_ sharded_token(m, "ShardedToken"); + + nb::class_ sharded_token(m_nb, "ShardedToken"); sharded_token.def("block_until_ready", [](PyShardedToken& self) { xla::ThrowIfError(self.Await()); }); sharded_token.def("get_token", &PyShardedToken::GetPyToken); - m.def("buffer_to_dlpack_managed_tensor", - xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), - py::arg("buffer"), py::arg("stream") = py::none()); - m.def("dlpack_managed_tensor_to_buffer", - [](const pybind11::capsule& tensor, ClientAndPtr device, - std::optional stream) { - return xla::ValueOrThrow(DLPackManagedTensorToBuffer( - tensor, device.get(), device.client(), stream)); - }); + m_nb.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m_nb.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); // Legacy overload - m.def( + m_nb.def( "dlpack_managed_tensor_to_buffer", - [](const pybind11::capsule& tensor, std::shared_ptr cpu_client, - std::shared_ptr gpu_client) { + [](const nb::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client) { return xla::ValueOrThrow(DLPackManagedTensorToBuffer( tensor, std::move(cpu_client), std::move(gpu_client))); }, - py::arg("dlpack"), py::arg("cpu_backend") = nullptr, - py::arg("gpu_backend") = nullptr); - - BuildProfilerSubmodule(&m); - BuildOpsSubmodule(&m); - BuildOutfeedReceiverSubmodule(&m); - BuildPytreeSubmodule(m); - jax::BuildJaxjitSubmodule(m); - jax::BuildPmapSubmodule(m); - jax::BuildPjitSubmodule(m); - jax::BuildTransferGuardSubmodule(m); - BuildTracebackSubmodule(m); - BuildMlirSubmodule(m); - BuildCustomCallShardingPybindAPI(m); - - py::class_> - preemption_sync_manager(m, "PreemptionSyncManager"); + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + m_nb.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer)); + + BuildIfrtProgramsSubmodule(m_nb); + BuildProfilerSubmodule(m_nb); + BuildOpsSubmodule(m_nb); + BuildOutfeedReceiverSubmodule(m_nb); + BuildPytreeSubmodule(m_nb); + jax::BuildJaxjitSubmodule(m_nb); + jax::BuildPmapSubmodule(m_nb); + jax::BuildPjitSubmodule(m_nb); + jax::BuildTransferGuardSubmodule(m_nb); + BuildTracebackSubmodule(m_nb); + BuildMlirSubmodule(m_nb); + BuildCustomCallShardingPybindAPI(m_nb); + + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m_nb); + + nb::class_ preemption_sync_manager( + m_nb, "PreemptionSyncManager"); preemption_sync_manager .def( "initialize", @@ -812,32 +638,30 @@ static void Init(py::module_& m) { xla::ValueOrThrow(client->GetCoordinationServiceAgent()); xla::ThrowIfError(manager.Initialize(agent)); }, - py::arg("distributed_client")) + nb::arg("distributed_client")) .def("reached_sync_point", [](tsl::PreemptionSyncManager& manager, int step_counter) { return manager.ReachedSyncPoint(step_counter); }); - m.def("create_preemption_sync_manager", - []() { return tsl::CreatePreemptionSyncManager(); }); + m_nb.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); - py::class_> - distributed_runtime_service(m, "DistributedRuntimeService"); + nb::class_ distributed_runtime_service( + m_nb, "DistributedRuntimeService"); distributed_runtime_service.def("shutdown", &DistributedRuntimeService::Shutdown, - py::call_guard()); - py::class_> - distributed_runtime_client(m, "DistributedRuntimeClient"); + nb::call_guard()); + nb::class_ distributed_runtime_client( + m_nb, "DistributedRuntimeClient"); distributed_runtime_client .def("connect", [](DistributedRuntimeClient& self) { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; xla::ThrowIfError(self.Connect()); }) .def("shutdown", [](DistributedRuntimeClient& self) { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; xla::ThrowIfError(self.Shutdown()); }) // This method assumes that the value is a Python string. Use @@ -847,32 +671,32 @@ static void Init(py::module_& m) { "blocking_key_value_get", [](DistributedRuntimeClient& client, std::string key, int64_t timeout_in_ms) { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; return xla::ValueOrThrow(client.BlockingKeyValueGet( key, absl::Milliseconds(timeout_in_ms))); }, - py::arg("key"), py::arg("timeout_in_ms")) + nb::arg("key"), nb::arg("timeout_in_ms")) // Same as `blocking_key_value_get()`, but retrieves the raw Python byte // values explicitly. .def( "blocking_key_value_get_bytes", [](DistributedRuntimeClient& client, std::string key, - int64_t timeout_in_ms) -> py::bytes { - py::gil_scoped_release gil_release; + int64_t timeout_in_ms) -> nb::bytes { + nb::gil_scoped_release gil_release; std::string result = xla::ValueOrThrow(client.BlockingKeyValueGet( key, absl::Milliseconds(timeout_in_ms))); - return py::bytes(result); + return nb::bytes(result.data(), result.size()); }, - py::arg("key"), py::arg("timeout_in_ms")) + nb::arg("key"), nb::arg("timeout_in_ms")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, int64_t timeout_in_ms) { - py::gil_scoped_release gil_release; + nb::gil_scoped_release gil_release; xla::ThrowIfError(client.WaitAtBarrier( barrier_id, absl::Milliseconds(timeout_in_ms))); }, - py::arg("barrier_id"), py::arg("timeout_in_ms")) + nb::arg("barrier_id"), nb::arg("timeout_in_ms")) // The key must be a string, but the value can either be a Python string // or bytes object. // With Python string values, use `key_value_set()` and @@ -881,48 +705,70 @@ static void Init(py::module_& m) { // `blocking_key_value_get_bytes()`. .def( "key_value_set", - [](DistributedRuntimeClient& client, std::string key, - std::string value) { - py::gil_scoped_release gil_release; + [](DistributedRuntimeClient& client, std::string_view key, + std::string_view value) { + nb::gil_scoped_release gil_release; xla::ThrowIfError(client.KeyValueSet(key, value)); }, - py::arg("key"), py::arg("value")) + nb::arg("key"), nb::arg("value")) + .def( + "key_value_set", + [](DistributedRuntimeClient& client, std::string_view key, + nb::bytes value) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, std::string_view(value.c_str(), value.size()))); + }, + nb::arg("key"), nb::arg("value")) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, std::string_view key, + nb::bytes value) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, std::string_view(value.c_str(), value.size()))); + }, + nb::arg("key"), nb::arg("value")) // Assumes that all values in the directory are Python strings. .def( "key_value_dir_get", - [](DistributedRuntimeClient& client, std::string key) { - py::gil_scoped_release gil_release; + [](DistributedRuntimeClient& client, std::string_view key) { + nb::gil_scoped_release gil_release; return xla::ValueOrThrow(client.KeyValueDirGet(key)); }, - py::arg("key")) + nb::arg("key")) // Assumes that all values in the directory are Python byte objects. // Same as `key_value_dir_get()`, but retrieves Python byte values // explicitly. .def( "key_value_dir_get_bytes", - [](DistributedRuntimeClient& client, std::string key) - -> std::vector> { - py::gil_scoped_release gil_release; + [](DistributedRuntimeClient& client, std::string_view key) + -> std::vector> { + nb::gil_scoped_release gil_release; std::vector> result = xla::ValueOrThrow(client.KeyValueDirGet(key)); - // Convert std::string values to py::bytes. - std::vector> kvs; + // Convert std::string values to nb::bytes. + std::vector> kvs; kvs.reserve(result.size()); for (const auto& kv : result) { - kvs.push_back(std::pair(kv.first, py::bytes(kv.second))); + kvs.push_back(std::pair( + kv.first, nb::bytes(kv.second.data(), kv.second.size()))); } return kvs; }, - py::arg("key")) + nb::arg("key")) .def( "key_value_delete", - [](DistributedRuntimeClient& client, std::string key) { - py::gil_scoped_release gil_release; + [](DistributedRuntimeClient& client, std::string_view key) { + nb::gil_scoped_release gil_release; return client.KeyValueDelete(key); }, - py::arg("key")); + nb::arg("key")); - m.def( + m_nb.def( "get_distributed_runtime_service", [](std::string address, int num_nodes, std::optional heartbeat_interval, @@ -949,19 +795,19 @@ static void Init(py::module_& m) { xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); return service; }, - py::arg("address"), py::arg("num_nodes"), py::kw_only(), - py::arg("heartbeat_interval") = std::nullopt, - py::arg("max_missing_heartbeats") = std::nullopt, - py::arg("cluster_register_timeout") = std::nullopt, - py::arg("shutdown_timeout") = std::nullopt); + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); - m.def( + m_nb.def( "get_distributed_runtime_client", [](std::string address, int node_id, std::optional rpc_timeout, std::optional init_timeout, std::optional shutdown_timeout, std::optional heartbeat_interval, std::optional max_missing_heartbeats, - std::optional> missed_heartbeat_callback, std::optional shutdown_on_destruction) @@ -992,66 +838,67 @@ static void Init(py::module_& m) { } return GetDistributedRuntimeClient(address, options); }, - py::arg("address"), py::arg("node_id"), py::kw_only(), - py::arg("rpc_timeout") = std::nullopt, - py::arg("init_timeout") = std::nullopt, - py::arg("shutdown_timeout") = std::nullopt, - py::arg("heartbeat_interval") = std::nullopt, - py::arg("max_missing_heartbeats") = std::nullopt, - py::arg("missed_heartbeat_callback") = std::nullopt, - py::arg("shutdown_on_destruction") = std::nullopt); - - m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); - - m.def("is_optimized_build", &IsOptimizedBuild); - - m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), - "Encodes the JSON representation of a pprof Profile into its binary " - "protocol buffer encoding."); - m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), - "Decodes an uncompressed pprof Profile protocol buffer into a JSON " - "representation"); - - RegisterCompileOnlyClient(m); - py::class_>( - m, "DeviceTopology") + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt); + + m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + m_nb.def("is_optimized_build", &IsOptimizedBuild); + + m_nb.def("json_to_pprof_profile", + xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m_nb.def("pprof_profile_to_json", + xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m_nb); + nb::class_(m_nb, "DeviceTopology") .def("_make_compile_only_devices", [](std::shared_ptr topology) { return MakeCompileOnlyClient(topology)->Devices(); }) - .def_property_readonly("platform", - [](PjRtTopologyDescription& topology) { - return topology.platform_name(); - }) - .def_property_readonly("platform_version", - [](PjRtTopologyDescription& topology) { - return topology.platform_version(); - }) + .def_prop_ro("platform", + [](PjRtTopologyDescription& topology) { + return topology.platform_name(); + }) + .def_prop_ro("platform_version", + [](PjRtTopologyDescription& topology) { + return topology.platform_version(); + }) .def("serialize", - [](PjRtTopologyDescription& topology) -> py::bytes { - return py::bytes(ValueOrThrow(topology.Serialize())); + [](PjRtTopologyDescription& topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); }) - .def( - "__getattr__", - [](PjRtTopologyDescription& topology, - std::string name) -> py::object { - const auto& attrs = topology.Attributes(); - auto it = attrs.find(name); - if (it != attrs.end()) { - return std::visit([](auto&& v) { return py::cast(v); }, - it->second); - } - throw py::attribute_error(absl::StrCat("Unknown attribute ", name)); - }); + .def("__getattr__", + [](PjRtTopologyDescription& topology, + std::string_view name) -> nb::object { + const auto& attrs = topology.Attributes(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); - py::class_>(m, "Executable") - .def("hlo_modules", - xla::ValueOrThrowWrapper(&PjRtExecutable::GetHloModules)) + nb::class_(m_nb, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&PjRtExecutable::GetHloModules)) .def("get_output_memory_kinds", xla::ValueOrThrowWrapper(&PjRtExecutable::GetOutputMemoryKinds)) .def("get_output_shardings", &PjRtExecutable::GetOutputShardings) .def("get_parameter_layouts", - xla::ValueOrThrowWrapper(&PjRtExecutable::GetParameterLayouts)) + ValueOrThrowWrapper(&PjRtExecutable::GetParameterLayouts)) .def("get_output_layouts", xla::ValueOrThrowWrapper(&PjRtExecutable::GetOutputLayouts)) .def("get_parameter_shardings", &PjRtExecutable::GetParameterShardings) @@ -1059,55 +906,42 @@ static void Init(py::module_& m) { xla::ValueOrThrowWrapper(&PjRtExecutable::GetCompiledMemoryStats)) .def("compile_options", xla::ValueOrThrowWrapper(&PjRtExecutable::GetCompileOptions)) - .def("serialize", [](const PjRtExecutable& exec) -> py::bytes { - return ValueOrThrow(exec.SerializeExecutable()); - }); + .def("serialize", + [](const PjRtExecutable& exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.SerializeExecutable()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", + xla::ValueOrThrowWrapper(&PjRtExecutable::GetCostAnalysis)); - m.def("is_asan", IsAsan); - m.def("is_msan", IsMsan); - m.def("is_tsan", IsTsan); - m.def("is_sanitized", IsSanitized); + m_nb.def("is_asan", IsAsan); + m_nb.def("is_msan", IsMsan); + m_nb.def("is_tsan", IsTsan); + m_nb.def("is_sanitized", IsSanitized); - m.def( + m_nb.def( "batched_device_put", - [](py::object aval, py::object sharding, std::vector xs, - std::vector> dst_devices, bool committed, + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, bool force_copy, - PjRtClient::HostBufferSemantics host_buffer_semantics) -> PyArray { + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { return ValueOrThrow(PyArray::BatchedDevicePut( - std::move(aval), std::move(sharding), std::move(xs), + nb::borrow(aval.ptr()), nb::borrow(sharding.ptr()), std::move(xs), std::move(dst_devices), committed, force_copy, host_buffer_semantics, jax::GetEnableX64())); }, - py::arg("aval"), py::arg("sharding"), py::arg("xs"), py::arg("devices"), - py::arg("committed") = true, py::arg("force_copy") = false, - py::arg("host_buffer_semantics") = - PjRtClient::HostBufferSemantics::kZeroCopy); - m.def( - "check_and_canonicalize_memory_kind", - [](py::object memory_kind, jax::PyDeviceList* device_list) -> py::object { - return jax::CheckAndCanonicalizeMemoryKind(memory_kind, device_list); - }); -} // NOLINT(readability/fn_size) + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); -// This code in essence is a copy of PYBIND11_MODULE(). We can't just call -// PYBIND11_MODULE because we want the entry point of the module to be in -// the py_extension() translation unit but we don't want anything else to be -// defined there. Inside Google, py_extension() translation units are linked -// differently and they end up with a different instance of the -// py::module_local() state, breaking that feature of pybind11. -static py::module_::module_def xla_module_def; + m_nb.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); -PyObject* InitializeXlaExtension() { - PYBIND11_CHECK_PYTHON_VERSION - PYBIND11_ENSURE_INTERNALS_READY - auto m = py::module_::create_extension_module("xla_extension", nullptr, - &xla_module_def); - try { - Init(m); - return m.ptr(); - } - PYBIND11_CATCH_INIT_EXCEPTIONS -} + m_nb.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); +} // NOLINT(readability/fn_size) } // namespace xla diff --git a/xla/python/xla.h b/xla/python/xla.h deleted file mode 100644 index 690606897542d..0000000000000 --- a/xla/python/xla.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_XLA_H_ -#define XLA_PYTHON_XLA_H_ - -// placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 - -namespace xla { - -PyObject *InitializeXlaExtension(); - -} // namespace xla - -#endif // XLA_PYTHON_XLA_H_ diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 919eb67733697..8b63b072bd3a3 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import logging import os import threading -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, Union import ml_dtypes import numpy as np @@ -43,15 +43,16 @@ # Pylint has false positives for type annotations. # pylint: disable=invalid-sequence-index +ifrt_programs = _xla.ifrt_programs ops = _xla.ops profiler = _xla.profiler # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 218 +_version = 255 # Version number for MLIR:Python components. -mlir_api_version = 54 +mlir_api_version = 56 xla_platform_names = { 'cpu': 'Host', @@ -67,6 +68,7 @@ def make_cpu_client( distributed_client=None, node_id=0, num_nodes=1, + collectives=None ) -> ...: register_custom_call_handler('cpu', _xla.register_custom_call_target) return _xla.get_tfrt_cpu_client( @@ -74,6 +76,7 @@ def make_cpu_client( distributed_client=distributed_client, node_id=node_id, num_nodes=num_nodes, + collectives=collectives, ) @@ -88,10 +91,6 @@ def make_gpu_client( """Returns a GPU client. BFC allocator is used by default.""" options = generate_pjrt_gpu_plugin_options() allocator = options['allocator'] - memory_fraction = ( - options['memory_fraction'] if 'memory_fraction' in options else None - ) - preallocate = options['preallocate'] if 'preallocate' in options else None config = _xla.GpuAllocatorConfig() if allocator == 'default': config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT @@ -101,9 +100,12 @@ def make_gpu_client( config.kind = _xla.GpuAllocatorConfig.Kind.BFC if allocator == 'cuda_async': config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC - if memory_fraction: - config.memory_fraction = float(memory_fraction) - config.preallocate = preallocate not in ('0', 'false', 'False') + if 'memory_fraction' in options: + config.memory_fraction = options['memory_fraction'] + if 'preallocate' in options: + config.preallocate = options['preallocate'] + if 'collective_memory_size' in options: + config.collective_memory_size = options['collective_memory_size'] register_custom_call_handler('CUDA', _xla.register_custom_call_target) register_custom_call_handler('ROCM', _xla.register_custom_call_target) @@ -139,12 +141,23 @@ def make_tfrt_tpu_c_api_device_topology( return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + def pjrt_plugin_loaded(plugin_name: str) -> bool: return _xla.pjrt_plugin_loaded(plugin_name) def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: - return _xla.load_pjrt_plugin(plugin_name, library_path) + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) def pjrt_plugin_initialized(plugin_name: str) -> bool: @@ -193,25 +206,21 @@ def make_tpu_client(library_path: Optional[str] = None): return make_tfrt_tpu_c_api_client() -def generate_pjrt_gpu_plugin_options( - visible_devices: str = 'all', -) -> _NameValueMapping: +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: """Generates the PjRt GPU plugin options. - Args: - visible_devices: A string of visible cuda devices. - Returns: A dictionary of plugin options. """ options = {} - if visible_devices != 'all': - options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - options['platform_name'] = 'cuda' + options['platform_name'] = 'cuda' allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): raise ValueError( 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' @@ -222,6 +231,8 @@ def generate_pjrt_gpu_plugin_options( options['memory_fraction'] = float(memory_fraction) if preallocate: options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) return options @@ -257,10 +268,12 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): XLA_ELEMENT_TYPE_TO_DTYPE = { PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S4: np.dtype('int4'), PrimitiveType.S8: np.dtype('int8'), PrimitiveType.S16: np.dtype('int16'), PrimitiveType.S32: np.dtype('int32'), PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U4: np.dtype('uint4'), PrimitiveType.U8: np.dtype('uint8'), PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), @@ -531,6 +544,7 @@ def window_padding_type_to_pad_values( SingleDeviceSharding = _xla.SingleDeviceSharding PmapSharding = _xla.PmapSharding GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout def LoadedExecutable_execute(self, arguments, device=None): @@ -552,14 +566,22 @@ def LoadedExecutable_execute_with_token(self, arguments, device=None): LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token -_custom_callback_handler: dict[str, Any] = {} -# Key is xla_platform_name, value is (function_name, function) -_custom_callback: dict[str, list[Tuple[str, Any]]] = {} +class CustomCallHandler(Protocol): + + def __call__( + self, name: str, fn: Any, platform: str, /, api_version: int = ... + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[str, list[tuple[str, Any, int]]] = {} _custom_callback_lock = threading.Lock() def register_custom_call_target( - name: str, fn: Any, platform: str = 'cpu' + name: str, fn: Any, platform: str = 'cpu', api_version: int = 0 ) -> None: """Registers a custom call target. @@ -567,18 +589,26 @@ def register_custom_call_target( name: bytes containing the name of the function. fn: a PyCapsule object containing the function pointer. platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. """ # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" # Since that is hardcoded to CUDA, we are using the following as workaround. xla_platform_name = xla_platform_names.get(platform, platform) with _custom_callback_lock: if xla_platform_name in _custom_callback_handler: - _custom_callback_handler[xla_platform_name](name, fn, xla_platform_name) + _custom_callback_handler[xla_platform_name]( + name, fn, xla_platform_name, api_version + ) else: - _custom_callback.setdefault(xla_platform_name, []).append((name, fn)) + _custom_callback.setdefault(xla_platform_name, []).append( + (name, fn, api_version) + ) -def register_custom_call_handler(platform: str, handler: Any) -> None: +def register_custom_call_handler( + platform: str, handler: CustomCallHandler +) -> None: """Registers a custom handler and use it to register existing custom calls. If a custom call handler for the platform already exist, calling this method @@ -598,8 +628,8 @@ def register_custom_call_handler(platform: str, handler: Any) -> None: return _custom_callback_handler[xla_platform_name] = handler if xla_platform_name in _custom_callback: - for name, fn in _custom_callback[xla_platform_name]: - handler(name, fn, xla_platform_name) + for name, fn, api_version in _custom_callback[xla_platform_name]: + handler(name, fn, xla_platform_name, api_version) del _custom_callback[xla_platform_name] @@ -904,5 +934,7 @@ def heap_profile(client: Client) -> bytes: array_result_handler = _xla.array_result_handler copy_array_to_devices_with_sharding = _xla.copy_array_to_devices_with_sharding batched_device_put = _xla.batched_device_put +batched_block_until_ready = _xla.batched_block_until_ready check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index 04e22b8e7cf41..98e6da4c193f4 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import numpy from . import xla_extension as _xla from .xla_extension import Shape as Shape from .xla_extension import Layout as Layout +from .xla_extension import ifrt_programs as ifrt_programs from .xla_extension import ops as ops from .xla_extension import profiler as profiler @@ -42,6 +43,7 @@ from .xla_extension import OpSharding as OpSharding from .xla_extension import HloSharding as HloSharding from .xla_extension import PrimitiveType as PrimitiveType from .xla_extension import Traceback as Traceback +from .xla_extension import PjRtLayout as PjRtLayout from .xla_extension import XlaBuilder as XlaBuilder from .xla_extension import XlaComputation as XlaComputation from .xla_extension import XlaOp as XlaOp @@ -85,6 +87,7 @@ def make_cpu_client( distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., num_nodes: int = ..., + collectives: Optional[_xla.CpuCollectives] = ..., ) -> Client: ... @@ -103,6 +106,9 @@ def make_tfrt_tpu_c_api_client(options: Optional[_NameValueMapping] = None) -> C def make_tfrt_tpu_c_api_device_topology(topology_name: Optional[str] = None, **kwargs) -> DeviceTopology: ... +def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology: + ... + def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... @@ -121,15 +127,16 @@ def pjrt_plugin_loaded(plugin_name: str) -> bool: def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: ... +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + ... + def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(plugin_name: str) -> None: ... -def generate_pjrt_gpu_plugin_options( - visible_devices: str = 'all', -) -> _NameValueMapping: +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: ... class OpMetadata: @@ -223,6 +230,8 @@ def copy_array_to_devices_with_sharding(self: ArrayImpl, devices: List[Device], def batched_device_put(aval: Any, sharding: Any, shards: Sequence[Any], devices: List[Device]) -> ArrayImpl: ... +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + def check_and_canonicalize_memory_kind( memory_kind: Optional[str], device_list: DeviceList) -> Optional[str]: ... @@ -234,7 +243,7 @@ def array_result_handler( ... def register_custom_call_target( - name: str, fn: Callable, platform: str = ... + name: str, fn: Callable, platform: str = ..., api_version: int = ... ) -> None: ... @@ -242,3 +251,5 @@ def register_custom_call_handler(xla_platform_name: str, handler: Any) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... + +def custom_call_targets(platform: str) -> dict[str, Any]: ... diff --git a/xla/python/xla_client_backend_independent_test.py b/xla/python/xla_client_backend_independent_test.py index 4a739a83fe5e8..05e85430ac333 100644 --- a/xla/python/xla_client_backend_independent_test.py +++ b/xla/python/xla_client_backend_independent_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 4bb0c60f97c60..1cefa2dd9984f 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -82,6 +82,8 @@ def jax_array_copy_to_host_async(self): # use widely for parameterizing tests. # pylint: disable=g-complex-comprehension +_CUSTOM_CALLS_REGISTERED = False + def TestFactory(xla_backend, cloud_tpu=False, @@ -109,6 +111,12 @@ def setUp(self): super(ComputationTest, self).setUp() self.backend = xla_backend() + global _CUSTOM_CALLS_REGISTERED + if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_custom_call_target(name, fn, platform="cpu") + _CUSTOM_CALLS_REGISTERED = True + def _NewComputation(self, name=None): if name is None: name = self.id() @@ -223,8 +231,9 @@ def testFingerprint(self): executable = self.backend.compile( xla_computation_to_mlir_module(computation)) fingerprint = executable.fingerprint - if self.backend.platform == "tpu" and not (cloud_tpu or pathways or - pathways_ifrt): + if ( + self.backend.platform == "tpu" or self.backend.platform == "gpu" + ) and not (cloud_tpu or pathways or pathways_ifrt): logging.info("fingerprint: %s", fingerprint) self.assertNotEmpty(fingerprint) else: @@ -402,8 +411,6 @@ def testCustomCall(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") c = self._NewComputation() - for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_custom_call_target(name, fn, platform="cpu") ops.CustomCallWithLayout( c, b"test_subtract_f32", @@ -425,8 +432,6 @@ def testCustomCallWithUnifiedApi(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") c = self._NewComputation() - for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_custom_call_target(name, fn, platform="cpu") opaque_str = b"foo" ops.CustomCallWithLayout( @@ -448,6 +453,22 @@ def testCustomCallWithUnifiedApi(self): .API_VERSION_STATUS_RETURNING_UNIFIED) self._ExecuteAndCompareClose(c, expected=[1.25 + len(opaque_str)]) + def testCustomCallLookup(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + if xla_client._version < 241: + self.skipTest("Test requires jaxlib version 241") + + self.assertTrue(_CUSTOM_CALLS_REGISTERED) + xla_client.make_cpu_client() + self.assertContainsSubset( + [ + call.decode() + for call in custom_call_for_test.cpu_custom_call_targets.keys() + ], + xla_client.custom_call_targets("Host").keys(), + ) + tests.append(ComputationsWithConstantsTest) class ComputationFromProtoTest(absltest.TestCase): @@ -512,6 +533,12 @@ def testScalarMinusVectorExplicitNumbering(self, dtype): class LayoutsTest(ComputationTest): """Tests related to getting and setting on-device memory layouts.""" + def _minor_to_major(self, layout: xla_client.PjRtLayout): # pylint: disable=invalid-name + m2m_str = re.search("{([0-9,]*)", str(layout)).group(1) + if not m2m_str: + return () + return tuple(int(x) for x in m2m_str.split(",")) + @unittest.skipIf(pathways, "not implemented") def testGetArgumentLayouts(self): # Create computation with a few parameters. @@ -536,9 +563,9 @@ def MakeArg(shape, dtype): # Test that compiled executable returns plausible layouts. layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() self.assertLen(layouts, 3) - self.assertLen(layouts[0].minor_to_major(), 3) - self.assertLen(layouts[1].minor_to_major(), 2) - self.assertEmpty(layouts[2].minor_to_major()) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertLen(self._minor_to_major(layouts[1]), 2) + self.assertEmpty(self._minor_to_major(layouts[2])) @unittest.skipIf(pathways, "not implemented") def testGetArgumentLayoutsTupled(self): @@ -569,9 +596,9 @@ def testGetArgumentLayoutsTupled(self): # Test that compiled executable returns plausible layouts. layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() self.assertLen(layouts, 3) - self.assertLen(layouts[0].minor_to_major(), 3) - self.assertEmpty(layouts[1].minor_to_major()) - self.assertLen(layouts[2].minor_to_major(), 1) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) @unittest.skipIf(pathways, "not implemented") def testGetOutputLayouts(self): @@ -595,9 +622,9 @@ def testGetOutputLayouts(self): # Test that compiled executable returns plausible layouts. layouts: Sequence[xla_client.Layout] = executable.get_output_layouts() self.assertLen(layouts, 3) - self.assertLen(layouts[0].minor_to_major(), 2) - self.assertEmpty(layouts[1].minor_to_major()) - self.assertLen(layouts[2].minor_to_major(), 1) + self.assertLen(self._minor_to_major(layouts[0]), 2) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) @unittest.skipIf(pathways, "not implemented") def testSetArgumentLayouts(self): @@ -631,9 +658,9 @@ def testSetArgumentLayouts(self): # Check input layouts. input_layouts = executable.get_parameter_layouts() self.assertLen(input_layouts, 3) - self.assertEqual(input_layouts[0].minor_to_major(), (0, 1, 2)) - self.assertEqual(input_layouts[1].minor_to_major(), ()) - self.assertEqual(input_layouts[2].minor_to_major(), (0,)) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) + self.assertEqual(self._minor_to_major(input_layouts[2]), (0,)) # Compile a version with default arg0 layout so we can make sure we # actually set it above. @@ -641,8 +668,9 @@ def testSetArgumentLayouts(self): module_str.replace('"{0,1,2}"', '"default"') ) self.assertNotEqual( - input_layouts[0].minor_to_major(), - default_executable.get_parameter_layouts()[0].minor_to_major()) + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) @unittest.skipIf(pathways or pathways_ifrt, "not implemented") def testSetArgumentLayoutsLegacy(self): @@ -685,8 +713,10 @@ def MakeArg(shape, dtype, layout): executable.get_parameter_layouts()) self.assertEqual(len(actual_layouts), len(expected_layouts)) for actual, expected in zip(actual_layouts, expected_layouts): - self.assertEqual(actual.minor_to_major(), - expected.layout().minor_to_major()) + self.assertEqual( + self._minor_to_major(actual), + expected.layout().minor_to_major(), + ) @unittest.skipIf(pathways, "not implemented") def testSetOutputLayouts(self): @@ -720,9 +750,9 @@ def testSetOutputLayouts(self): # Check output layouts. output_layouts = executable.get_output_layouts() self.assertLen(output_layouts, 3) - self.assertEqual(output_layouts[0].minor_to_major(), (0, 1, 2)) - self.assertEqual(output_layouts[1].minor_to_major(), ()) - self.assertEqual(output_layouts[2].minor_to_major(), (0,)) + self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(output_layouts[1]), ()) + self.assertEqual(self._minor_to_major(output_layouts[2]), (0,)) # Compile a version with default first output layout so we can make sure # we actually set it above. @@ -730,8 +760,9 @@ def testSetOutputLayouts(self): module_str.replace('"{0,1,2}"', '"default"') ) self.assertNotEqual( - output_layouts[0].minor_to_major(), - default_executable.get_output_layouts()[0].minor_to_major()) + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) @unittest.skipIf(pathways, "not implemented") def SetLayoutsSharded(self): @@ -767,13 +798,13 @@ def SetLayoutsSharded(self): # Check input layouts. input_layouts = executable.get_parameter_layouts() self.assertLen(input_layouts, 2) - self.assertEqual(input_layouts[0].minor_to_major(), (0, 1)) - self.assertEqual(input_layouts[1].minor_to_major(), ()) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) # Check output layout. output_layouts = executable.get_output_layouts() self.assertLen(output_layouts, 1) - self.assertEqual(input_layouts[0].minor_to_major(), (0, 1)) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) # Compile a version with default layouts so we can make sure we actually # set it above. @@ -781,11 +812,13 @@ def SetLayoutsSharded(self): module_str.replace('"{0,1}"', '"default"') ) self.assertNotEqual( - input_layouts[0].minor_to_major(), - default_executable.get_parameter_layouts()[0].minor_to_major()) + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) self.assertNotEqual( - output_layouts[0].minor_to_major(), - default_executable.get_output_layouts()[0].minor_to_major()) + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) @unittest.skipIf(pathways, "not implemented") def testAutoArgumentLayouts(self): @@ -817,8 +850,8 @@ def testAutoArgumentLayouts(self): # Check input layouts. input_layouts = executable.get_parameter_layouts() - self.assertEqual(input_layouts[0].minor_to_major(), (1, 0)) - self.assertEqual(input_layouts[1].minor_to_major(), (2, 0, 1)) + self.assertEqual(self._minor_to_major(input_layouts[0]), (1, 0)) + self.assertEqual(self._minor_to_major(input_layouts[1]), (2, 0, 1)) # Compile a version with default layouts so we can make sure the compiler # is actually choosing above. @@ -828,8 +861,8 @@ def testAutoArgumentLayouts(self): # We expect the compiler to choose a non-default layout for the second # (1024,8,128) argument. self.assertNotEqual( - input_layouts[1].minor_to_major(), - default_executable.get_parameter_layouts()[1].minor_to_major(), + self._minor_to_major(input_layouts[1]), + self._minor_to_major(default_executable.get_parameter_layouts()[1]), ) @unittest.skipIf(pathways, "not implemented") @@ -860,7 +893,7 @@ def testAutoOutputLayouts(self): # Check output layout output_layout, = executable.get_output_layouts() - self.assertEqual(output_layout.minor_to_major(), (2, 0, 1)) + self.assertEqual(self._minor_to_major(output_layout), (2, 0, 1)) # Compile a version with default layouts so we can make sure the compiler # is actually choosing above. @@ -869,8 +902,8 @@ def testAutoOutputLayouts(self): ) # We expect the compiler to choose a non-default output layout. self.assertNotEqual( - output_layout.minor_to_major(), - default_executable.get_output_layouts()[0].minor_to_major(), + self._minor_to_major(output_layout), + self._minor_to_major(default_executable.get_output_layouts()[0]), ) tests.append(LayoutsTest) @@ -2512,6 +2545,12 @@ def testPlatform(self): for device in self.backend.local_devices(): self.assertEqual(device.platform, self.backend.platform) + def testCoreCount(self): + if self.backend.platform != "gpu": + self.skipTest("core_count is only supported on GPU") + for device in self.backend.local_devices(): + self.assertGreater(device.core_count, 0) + def testLocalHardwareId(self): for device in self.backend.devices(): local_hardware_id = device.local_hardware_id @@ -2907,6 +2946,34 @@ def testNotExistPjRtCApiVersion(self): with self.assertRaises(AttributeError): self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement + @unittest.skipIf(pathways or pathways_ifrt, "has different behavior") + def testPluginProgramDoesNotCompile(self): + program = xla_client.ifrt_programs.make_plugin_program("foobar") + options = xla_client.ifrt_programs.make_plugin_compile_options() + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, "PjRtCompiler requires an XlaProgram" + ): + self.backend.compile_ifrt_program(program, options) + + @unittest.skipIf(pathways, "does not work with non-ifrt legacy pathways") + def testXlaProgramViaIfrtProgram(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + program = xla_client.ifrt_programs.make_xla_program( + xla_computation_to_mlir_module(c.build()) + ) + options = xla_client.ifrt_programs.make_xla_compile_options( + xla_client.CompileOptions(), [] + ) + + compiled_c = self.backend.compile_ifrt_program(program, options) + results = xla_client.execute_with_python_values( + compiled_c, arguments=(), backend=self.backend + ) + + self.assertLen(results, 1) + np.testing.assert_equal(results[0], np.arange(10, dtype=np.float32)) + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, "not implemented") def testExecutableSerialization(self): diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index ac0b0933eae6a..9231c95c2aa02 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,35 +19,49 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/hash/hash.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "pybind11/attr.h" // from @pybind11 -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/numpy.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 -#include "pybind11/stl_bind.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/ndarray.h" // from @nanobind +#include "nanobind/stl/optional.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/pair.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep #include "xla/array.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/layout.h" #include "xla/layout_util.h" -#include "xla/python/exceptions.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" #include "xla/python/py_client.h" -#include "xla/python/status_casters.h" #include "xla/python/types.h" #include "xla/service/call_inliner.h" #include "xla/service/computation_placer.h" @@ -68,11 +82,45 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/strings/proto_serialization.h" +#include "tsl/platform/logging.h" + +namespace nanobind { +namespace detail { + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata, + const_name("xla::OpMetadata")); + + bool from_python(handle h, uint8_t, cleanup_list*) { + handle op_type = getattr(h, "op_type"); + if (!op_type.is_none()) { + value.set_op_type(cast(op_type)); + } + handle op_name = getattr(h, "op_name"); + if (!op_name.is_none()) { + value.set_op_name(cast(op_name)); + } + handle source_file = getattr(h, "source_file"); + if (!source_file.is_none()) { + value.set_source_file(cast(source_file)); + } + handle source_line = getattr(h, "source_line"); + if (!source_line.is_none()) { + value.set_source_line(cast(source_line)); + } + return true; + } +}; + +} // namespace detail +} // namespace nanobind namespace xla { namespace { -namespace py = pybind11; +namespace nb = nanobind; struct Uniquer { absl::Mutex mu; @@ -91,29 +139,29 @@ static std::string UniquifyName(const std::string& name) { } // Converts a computation to a serialized HloModuleProto. -StatusOr GetComputationSerializedProto( +StatusOr GetComputationSerializedProto( const XlaComputation& computation) { std::string result; if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { return Unknown("Failed to serialize the HloModuleProto."); } - return py::bytes(result); + return nb::bytes(result.data(), result.size()); } // Converts a hlo module to a serialized HloModuleProto. -StatusOr GetHloModuleSerializedProto(const HloModule& module) { +StatusOr GetHloModuleSerializedProto(const HloModule& module) { std::string result; if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { return Unknown("Failed to serialize the HloModuleProto."); } - return py::bytes(result); + return nb::bytes(result.data(), result.size()); } // Converts a serialized HloModuleProto into a HloModule. StatusOr> HloModuleFromSerializedProto( - const py::bytes& bytes) { + const nb::bytes& bytes) { HloModuleProto proto; - proto.ParseFromString(bytes); + proto.ParseFromArray(bytes.c_str(), bytes.size()); TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, HloModule::CreateModuleConfigFromProto( proto, GetDebugOptionsFromFlags())); @@ -243,28 +291,36 @@ StatusOr IotaTileHelper( subgroup_types); } -// Registers a 'fn_capsule' as a CPU custom call target. -// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, -// with name "xla._CUSTOM_CALL_TARGET". +// Registers a 'fn_capsule' as a custom call target. +// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object. // 'platform' is an XLA platform name, e.g., "Host" or "CUDA". -Status PyRegisterCustomCallTarget(const std::string& fn_name, - py::capsule capsule, - const std::string& platform) { - static const char* const kName = "xla._CUSTOM_CALL_TARGET"; - if (absl::string_view(capsule.name()) != kName) { - return InvalidArgument( - "Argument to RegisterCustomCallTargetRegistry was not a " - "xla._CUSTOM_CALL_TARGET capsule."); +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + nb::capsule capsule, + const std::string& platform, + int api_version) { + switch (api_version) { + case 0: + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + case 1: + ffi::Ffi::RegisterStaticHandler(xla::ffi::GetXlaFfiApi(), fn_name, + platform, + reinterpret_cast( + static_cast(capsule.data()))); + return absl::OkStatus(); + default: + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); } - CustomCallTargetRegistry::Global()->Register( - fn_name, static_cast(capsule), platform); - return OkStatus(); } template -void DefRepeatedProperty(py::class_& cls, const char* name, +void DefRepeatedProperty(nb::class_& cls, const char* name, Container* (T::*getter)()) { - cls.def_property( + cls.def_prop_rw( name, [getter](T& obj) { Container* elems = (obj.*getter)(); @@ -285,15 +341,12 @@ void DefRepeatedProperty(py::class_& cls, const char* name, } // namespace -void BuildXlaCompilerSubmodule(py::module& m) { +void BuildXlaCompilerSubmodule(nb::module_& m) { // Shapes - py::class_ layout_class(m, "Layout"); - layout_class - .def(py::init([](absl::Span minor_to_major) { - return std::make_unique(minor_to_major); - })) + nb::class_ layout_class(m, "Layout"); + layout_class.def(nb::init>()) .def("minor_to_major", - [](Layout layout) { return SpanToTuple(layout.minor_to_major()); }) + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) .def("__eq__", [](const Layout& layout, const Layout& other) { return layout == other; }) .def("__ne__", [](const Layout& layout, @@ -301,29 +354,31 @@ void BuildXlaCompilerSubmodule(py::module& m) { .def("__hash__", [](const Layout& layout) { return absl::HashOf(layout); }) .def("to_string", &Layout::ToString) - .def(py::pickle( - [](const Layout& self) -> py::tuple { - auto proto = self.ToProto(); - std::string result; - if (!tsl::SerializeToStringDeterministic(proto, &result)) { - // throw converted by PyBind to a Python RuntimeError. - throw XlaRuntimeError( - absl::StrCat("Layout.py_pickle: ", - "SerializeToStringDeterministic failed")); - } - return py::make_tuple(py::bytes(result)); - }, - [](py::tuple t) { - LayoutProto result; - result.ParseFromString(t[0].cast()); - return Layout::CreateFromProto(result); - })); + .def("__getstate__", + [](const Layout& self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout* self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(Layout::CreateFromProto(result)); + }); - py::class_ shape_class(m, "Shape"); + nb::class_ shape_class(m, "Shape"); shape_class - .def(py::init([](const std::string& s) { - return std::make_unique(ValueOrThrow(ParseShape(s))); - })) + .def("__init__", + [](Shape* self, const std::string& s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) .def_static( "tuple_shape", [](std::vector shapes) -> Shape { @@ -332,8 +387,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { "Constructs a tuple shape.") .def_static("array_shape", xla::ValueOrThrowWrapper( - [](PrimitiveType type, py::object dims_seq, - std::optional layout_seq, + [](PrimitiveType type, nb::sequence dims_seq, + std::optional layout_seq, std::optional> dynamic_dimensions) -> StatusOr { std::vector dims = @@ -348,14 +403,14 @@ void BuildXlaCompilerSubmodule(py::module& m) { type, dims, std::nullopt, dynamic_dimensions); } }), - "Constructs an array shape.", py::arg("type"), - py::arg("dims"), py::arg("layout") = std::nullopt, - py::arg("dynamic_dimensions") = std::nullopt) + "Constructs an array shape.", nb::arg("type"), + nb::arg("dims"), nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) .def_static( "array_shape", xla::ValueOrThrowWrapper( - [](py::dtype dtype, py::object dims_seq, - std::optional layout_seq, + [](nb_dtype dtype, nb::sequence dims_seq, + std::optional layout_seq, std::optional> dynamic_dimensions) -> StatusOr { PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); @@ -370,26 +425,26 @@ void BuildXlaCompilerSubmodule(py::module& m) { dynamic_dimensions); } }), - "Constructs an array shape.", py::arg("type"), py::arg("dims"), - py::arg("layout") = std::nullopt, - py::arg("dynamic_dimensions") = std::nullopt) + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) .def_static( "scalar_shape", [](PrimitiveType type) -> Shape { return ShapeUtil::MakeScalarShape(type); }, - "Constructs a scalar shape.", py::arg("type")) + "Constructs a scalar shape.", nb::arg("type")) .def_static( "scalar_shape", - [](py::dtype dtype) -> Shape { + [](nb_dtype dtype) -> Shape { PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); return ShapeUtil::MakeScalarShape(type); }, - "Constructs a scalar shape.", py::arg("type")) + "Constructs a scalar shape.", nb::arg("type")) .def("dimensions", - [](const Shape& shape) -> py::tuple { - return SpanToTuple(shape.dimensions()); + [](const Shape& shape) -> nb::tuple { + return SpanToNbTuple(shape.dimensions()); }) .def("layout", [](const Shape& shape) -> Layout { return shape.layout(); }) @@ -397,15 +452,15 @@ void BuildXlaCompilerSubmodule(py::module& m) { .def("element_type", [](const Shape& shape) { return xla::ValueOrThrow( - PrimitiveTypeToDtype(shape.element_type())); + PrimitiveTypeToNbDtype(shape.element_type())); }) .def("numpy_dtype", [](const Shape& shape) { if (shape.IsTuple()) { - return py::dtype("O"); + return nb_dtype("O"); } return xla::ValueOrThrow( - PrimitiveTypeToDtype(shape.element_type())); + PrimitiveTypeToNbDtype(shape.element_type())); }) .def("is_tuple", &Shape::IsTuple) .def("is_array", &Shape::IsArray) @@ -413,14 +468,15 @@ void BuildXlaCompilerSubmodule(py::module& m) { .def("is_static", &Shape::is_static) .def("is_dynamic", &Shape::is_dynamic) .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, - py::arg("dimension")) + nb::arg("dimension")) .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, - py::arg("dimension"), py::arg("is_dynamic")) + nb::arg("dimension"), nb::arg("is_dynamic")) .def("rank", &Shape::rank) .def("to_serialized_proto", [](const Shape& shape) { ShapeProto proto = shape.ToProto(); - return py::bytes(proto.SerializeAsString()); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); }) .def("tuple_shapes", [](const Shape& shape) { @@ -451,26 +507,27 @@ void BuildXlaCompilerSubmodule(py::module& m) { return shape.ToString(/*print_layout=*/true); }); - py::class_(m, "ProgramShape") - .def(py::init( - [](absl::Span params, Shape result) -> ProgramShape { - ProgramShape program_shape; + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape* self, absl::Span params, Shape result) { + new (self) ProgramShape(); for (const Shape& param : params) { - *program_shape.add_parameters() = param; + *self->add_parameters() = param; } - *program_shape.mutable_result() = result; - return program_shape; - })) + *self->mutable_result() = result; + }) .def("parameter_shapes", static_cast& (ProgramShape::*)() const>( &ProgramShape::parameters)) .def("result_shape", &ProgramShape::result) .def("__repr__", &ProgramShape::ToString); - py::class_(m, "ShapeIndex") - .def(py::init([](const std::vector& v) { - return std::make_unique(v.begin(), v.end()); - })) + nb::class_(m, "ShapeIndex") + .def("__init__", + [](ShapeIndex* self, const std::vector& v) { + new (self) ShapeIndex(v.begin(), v.end()); + }) .def("__repr__", &ShapeIndex::ToString) .def("__eq__", [](const ShapeIndex& shape_ind, const ShapeIndex& other) { return shape_ind == other; }) @@ -480,16 +537,17 @@ void BuildXlaCompilerSubmodule(py::module& m) { [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); }); // Literals - py::class_>(m, "Literal") - .def("__repr__", &Literal::ToString); + nb::class_(m, "Literal").def("__repr__", &Literal::ToString); - py::class_(m, "XlaComputation") - .def(py::init([](const py::bytes& serialized_hlo_module_proto) - -> std::unique_ptr { - HloModuleProto proto; - proto.ParseFromString(std::string(serialized_hlo_module_proto)); - return std::make_unique(proto); - })) + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation* self, + const nb::bytes& serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) .def("program_shape", xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) @@ -497,62 +555,59 @@ void BuildXlaCompilerSubmodule(py::module& m) { .def("as_serialized_hlo_module_proto", xla::ValueOrThrowWrapper(GetComputationSerializedProto)) .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), - py::arg("print_large_constants") = false) + nb::arg("print_large_constants") = false) .def("as_hlo_dot_graph", xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) .def("hash", xla::ValueOrThrowWrapper(HashComputation)) .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); - py::class_ hlo_print_options_class(m, "HloPrintOptions"); - hlo_print_options_class.def(py::init<>()) + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) .def_static("short_parsable", &HloPrintOptions::ShortParsable) .def_static("canonical", &HloPrintOptions::Canonical) .def_static("fingerprint", &HloPrintOptions::Fingerprint) - .def_property("print_large_constants", - &HloPrintOptions::print_large_constants, - &HloPrintOptions::set_print_large_constants) - .def_property("print_metadata", &HloPrintOptions::print_metadata, - &HloPrintOptions::set_print_metadata) - .def_property("print_backend_config", - &HloPrintOptions::print_backend_config, - &HloPrintOptions::set_print_backend_config) - .def_property("print_result_shape", &HloPrintOptions::print_result_shape, - &HloPrintOptions::set_print_result_shape) - .def_property("print_operand_shape", - &HloPrintOptions::print_operand_shape, - &HloPrintOptions::set_print_operand_shape) - .def_property("print_operand_names", - &HloPrintOptions::print_operand_names, - &HloPrintOptions::set_print_operand_names) - .def_property("print_ids", &HloPrintOptions::print_ids, - &HloPrintOptions::set_print_ids) - .def_property("print_extra_attributes", - &HloPrintOptions::print_extra_attributes, - &HloPrintOptions::set_print_extra_attributes) - .def_property("print_program_shape", - &HloPrintOptions::print_program_shape, - &HloPrintOptions::set_print_program_shape) - .def_property("print_percent", &HloPrintOptions::print_percent, - &HloPrintOptions::set_print_percent) - .def_property("print_control_dependencies", - &HloPrintOptions::print_control_dependencies, - &HloPrintOptions::set_print_control_dependencies) - .def_property("compact_operands", &HloPrintOptions::compact_operands, - &HloPrintOptions::set_compact_operands) - .def_property("include_layout_in_shapes", - &HloPrintOptions::include_layout_in_shapes, - &HloPrintOptions::set_include_layout_in_shapes) - .def_property("canonicalize_instruction_names", - &HloPrintOptions::canonicalize_instruction_names, - &HloPrintOptions::set_canonicalize_instruction_names) - .def_property("canonicalize_computations", - &HloPrintOptions::canonicalize_computations, - &HloPrintOptions::set_canonicalize_computations) - .def_property("indent_amount", &HloPrintOptions::indent_amount, - &HloPrintOptions::set_indent_amount) - .def_property("is_in_nested_computation", - &HloPrintOptions::is_in_nested_computation, - &HloPrintOptions::set_is_in_nested_computation); + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + &HloPrintOptions::set_compact_operands) + .def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", + &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); // HloModule.computations() returns raw pointers. // pybind seems to prefer smart pointers. @@ -582,20 +637,18 @@ void BuildXlaCompilerSubmodule(py::module& m) { const std::shared_ptr module; }; - py::class_> - hlo_computation_class(m, "HloComputation"); + nb::class_ hlo_computation_class(m, "HloComputation"); - hlo_computation_class.def_property_readonly("name", &ComputationWrapper::name) + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) .def("render_html", &ComputationWrapper::render_html); - py::class_> hlo_module_class( - m, "HloModule"); - hlo_module_class.def_property_readonly("name", &HloModule::name) + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) .def( "to_string", static_cast( &HloModule::ToString), - py::arg("options") = HloPrintOptions()) + nb::arg("options") = HloPrintOptions()) .def("as_serialized_hlo_module_proto", xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) .def("from_serialized_hlo_module_proto", @@ -609,40 +662,37 @@ void BuildXlaCompilerSubmodule(py::module& m) { std::make_shared(comp, m)); return computations; }) - .def_property_readonly( - "spmd_output_sharding", - [](const HloModule& m) -> std::optional { - if (!m.has_spmd_output_sharding()) return std::nullopt; - return m.spmd_output_sharding().ToProto(); - }) - .def_property_readonly( - "spmd_parameters_shardings", - [](const HloModule& m) - -> std::optional> { - if (!m.has_spmd_parameters_shardings()) return std::nullopt; - std::vector param_shardings; - for (const auto& parameter_sharding : - m.spmd_parameters_shardings()) { - param_shardings.push_back(parameter_sharding.ToProto()); - } - return param_shardings; - }); + .def_prop_ro("spmd_output_sharding", + [](const HloModule& m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule& m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto& parameter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); - py::class_> - hlo_module_group_class(m, "HloModuleGroup"); + nb::class_ hlo_module_group_class(m, "HloModuleGroup"); hlo_module_group_class - .def(py::init( - [](const std::string& name, - const std::vector>& hlo_modules) - -> std::shared_ptr { - std::vector> modules; - modules.reserve(hlo_modules.size()); - for (const auto& m : hlo_modules) { - modules.push_back(m->Clone(/*suffix=*/"")); - } - return std::make_shared(name, std::move(modules)); - })) - .def_property_readonly("name", &HloModuleGroup::name) + .def("__init__", + [](HloModuleGroup* self, const std::string& name, + const std::vector>& hlo_modules) { + std::vector> modules; + modules.reserve(hlo_modules.size()); + for (const auto& m : hlo_modules) { + modules.push_back(m->Clone(/*suffix=*/"")); + } + new (self) HloModuleGroup(name, std::move(modules)); + }) + .def_prop_ro("name", &HloModuleGroup::name) .def("to_string", &HloModuleGroup::ToString) .def("to_modules", [](HloModuleGroup& m) -> std::vector> { @@ -662,21 +712,21 @@ void BuildXlaCompilerSubmodule(py::module& m) { *hlo_module.entry_computation(), /*label=*/"", hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); }); - m.def("hlo_module_cost_analysis", - xla::ValueOrThrowWrapper( - [](PyClient* client, const HloModule& module) - -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto analysis, - client->pjrt_client()->GetHloCostAnalysis()); - TF_RETURN_IF_ERROR( - module.entry_computation()->Accept(analysis.get())); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](PyClient* client, const HloModule& module) + -> StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); - // Convert from HloCostAnalysis::Properties to a standard map. - absl::flat_hash_map ret; - analysis->properties().ForEach( - [&](absl::string_view key, float val) { ret[key] = val; }); - return ret; - })); + // Convert from HloCostAnalysis::Properties to a standard map. + nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); m.def("hlo_module_from_text", xla::ValueOrThrowWrapper([](const std::string& hlo_module_text) -> StatusOr> { @@ -687,12 +737,13 @@ void BuildXlaCompilerSubmodule(py::module& m) { return result; })); - py::class_ xla_op_class(m, "XlaOp"); + nb::class_ xla_op_class(m, "XlaOp"); - py::class_(m, "XlaBuilder") - .def(py::init([](const std::string& name) -> std::unique_ptr { - return std::make_unique(UniquifyName(name)); - })) + nb::class_(m, "XlaBuilder") + .def("__init__", + [](XlaBuilder* self, const std::string& name) { + new (self) XlaBuilder(UniquifyName(name)); + }) // TODO(phawkins): delete capitalized names after updating callers. .def("Build", xla::ValueOrThrowWrapper( @@ -700,7 +751,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { return root ? builder.Build(*root) : builder.Build(); }), "Builds a computation from the contents of the builder.", - py::arg("root") = std::nullopt) + nb::arg("root") = std::nullopt) .def("GetShape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) .def("build", xla::ValueOrThrowWrapper( @@ -708,7 +759,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { return root ? builder.Build(*root) : builder.Build(); }), "Builds a computation from the contents of the builder.", - py::arg("root") = std::nullopt) + nb::arg("root") = std::nullopt) .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) .def("get_shape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) .def( @@ -718,7 +769,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { return root ? builder.GetProgramShape(*root) : builder.GetProgramShape(); }, - py::arg("root") = std::nullopt) + nb::arg("root") = std::nullopt) .def("is_constant", xla::ValueOrThrowWrapper(&XlaBuilder::IsConstant)) .def("set_op_metadata", &XlaBuilder::SetOpMetadata) .def("set_sharding", &XlaBuilder::SetSharding) @@ -735,69 +786,72 @@ void BuildXlaCompilerSubmodule(py::module& m) { }); // Device assignments - py::class_(m, "DeviceAssignment") + nb::class_(m, "DeviceAssignment") .def_static( "create", - xla::ValueOrThrowWrapper( - [](py::array_t array) -> StatusOr { - if (array.ndim() != 2) { - return InvalidArgument( - "Argument to DeviceAssignment constructor must be a " - "2D array, received an %dD array.", - array.ndim()); - } - DeviceAssignment result(array.shape(0), array.shape(1)); - for (int i = 0; i < array.shape(0); ++i) { - for (int j = 0; j < array.shape(1); ++j) { - result(i, j) = array.at(i, j); - } - } - return result; - })) + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) .def("replica_count", &DeviceAssignment::replica_count) .def("computation_count", &DeviceAssignment::computation_count) .def("__repr__", &DeviceAssignment::ToString) .def("serialize", xla::ValueOrThrowWrapper([](const DeviceAssignment& da) - -> StatusOr { + -> StatusOr { DeviceAssignmentProto proto; TF_RETURN_IF_ERROR(da.Serialize(&proto)); std::string result; if (!tsl::SerializeToStringDeterministic(proto, &result)) { return Unknown("Failed to serialize the DeviceAssignmentProto."); } - return py::bytes(result); + return nb::bytes(result.data(), result.size()); })); - py::class_ compile_options(m, "CompileOptions"); + nb::class_ compile_options(m, "CompileOptions"); compile_options - .def(py::init([]() -> CompileOptions { - CompileOptions options; - DebugOptions* debug_options = - options.executable_build_options.mutable_debug_options(); - // Sets fast-math-disabling default options expected by JAX. - debug_options->set_xla_cpu_enable_fast_min_max(false); - debug_options->set_xla_gpu_enable_fast_min_max(false); - return options; - })) - .def(py::pickle( - [](const CompileOptions& self) -> py::tuple { - auto proto = ValueOrThrow(self.ToProto()); - std::string result; - if (!tsl::SerializeToStringDeterministic(proto, &result)) { - // throw converted by PyBind to a Python RuntimeError. - throw XlaRuntimeError( - absl::StrCat("CompileOptions.py_pickle: ", - "SerializeToStringDeterministic failed")); - } - return py::make_tuple(py::bytes(result)); - }, - [](py::tuple t) { - CompileOptionsProto result; - result.ParseFromString(t[0].cast()); - return ValueOrThrow(CompileOptions::FromProto(result)); - })) + .def("__init__", + [](CompileOptions* self) { + new (self) CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) .def("SerializeAsString", - [](const CompileOptions& self) -> py::bytes { + [](const CompileOptions& self) -> nb::bytes { auto proto = ValueOrThrow(self.ToProto()); std::string result; if (!tsl::SerializeToStringDeterministic(proto, &result)) { @@ -806,28 +860,26 @@ void BuildXlaCompilerSubmodule(py::module& m) { absl::StrCat("CompileOptions.SerializeAsString: ", "SerializeToStringDeterministic failed")); } - return py::bytes(result); + return nb::bytes(result.data(), result.size()); }) .def_static("ParseFromString", - [](py::bytes s) { + [](nb::bytes s) { CompileOptionsProto result; - result.ParseFromString(s); + result.ParseFromArray(s.c_str(), s.size()); return ValueOrThrow(CompileOptions::FromProto(result)); }) - .def_readwrite("argument_layouts", &CompileOptions::argument_layouts) - .def_readwrite("parameter_is_tupled_arguments", - &CompileOptions::parameter_is_tupled_arguments) - .def_readwrite("compile_portable_executable", - &CompileOptions::compile_portable_executable) - .def_readonly("executable_build_options", - &CompileOptions::executable_build_options) - .def_readwrite("env_option_overrides", - &CompileOptions::env_option_overrides) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) // TODO(phawkins): the following fields exist for backward compatibility. // Remove them after JAX has been updated not to use them. - .def_readwrite("tuple_arguments", - &CompileOptions::parameter_is_tupled_arguments) - .def_property( + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( "num_replicas", [](const CompileOptions& options) { return options.executable_build_options.num_replicas(); @@ -835,7 +887,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { [](CompileOptions& options, int num_replicas) { options.executable_build_options.set_num_replicas(num_replicas); }) - .def_property( + .def_prop_rw( "num_partitions", [](const CompileOptions& options) { return options.executable_build_options.num_partitions(); @@ -843,13 +895,13 @@ void BuildXlaCompilerSubmodule(py::module& m) { [](CompileOptions& options, int num_partitions) { options.executable_build_options.set_num_partitions(num_partitions); }) - .def_property( + .def_prop_rw( "profile_version", [](const CompileOptions& options) { return options.profile_version; }, [](CompileOptions& options, int64_t profile_version) { options.profile_version = profile_version; }) - .def_property( + .def_prop_rw( "device_assignment", [](const CompileOptions& options) -> std::optional { return options.executable_build_options.has_device_assignment() @@ -865,52 +917,88 @@ void BuildXlaCompilerSubmodule(py::module& m) { }); // Custom-call targets. - m.def("register_custom_call_target", - [](const std::string& fn_name, py::capsule capsule, - const std::string& platform) { - xla::ThrowIfError(PyRegisterCustomCallTarget( - fn_name, std::move(capsule), platform)); - }); + m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::capsule capsule, + const std::string& platform, const int api_version) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(capsule), platform, api_version)); + }, + nb::arg("fn_name"), nb::arg("capsule"), nb::arg("platform"), + nb::arg("api_version") = 0); - py::class_(m, "DebugOptions") + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + for (const auto& [name, registration] : + ffi::StaticRegisteredHandlers(platform)) { + targets[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(registration.handler)); + } + return targets; + }, + nb::arg("platform")); + + nb::class_(m, "DebugOptions") .def("__repr__", &DebugOptions::DebugString) - .def_property("xla_backend_optimization_level", - &DebugOptions::xla_backend_optimization_level, - &DebugOptions::set_xla_backend_optimization_level) - .def_property("xla_cpu_enable_fast_math", - &DebugOptions::xla_cpu_enable_fast_math, - &DebugOptions::set_xla_cpu_enable_fast_math) - .def_property("xla_cpu_enable_xprof_traceme", - &DebugOptions::xla_cpu_enable_xprof_traceme, - &DebugOptions::set_xla_cpu_enable_xprof_traceme) - .def_property("xla_cpu_fast_math_honor_infs", - &DebugOptions::xla_cpu_fast_math_honor_infs, - &DebugOptions::set_xla_cpu_fast_math_honor_infs) - .def_property("xla_cpu_fast_math_honor_nans", - &DebugOptions::xla_cpu_fast_math_honor_nans, - &DebugOptions::set_xla_cpu_fast_math_honor_nans) - .def_property("xla_cpu_fast_math_honor_division", - &DebugOptions::xla_cpu_fast_math_honor_division, - &DebugOptions::set_xla_cpu_fast_math_honor_division) - .def_property("xla_cpu_fast_math_honor_functions", - &DebugOptions::xla_cpu_fast_math_honor_functions, - &DebugOptions::set_xla_cpu_fast_math_honor_functions) - .def_property("xla_detailed_logging", &DebugOptions::xla_detailed_logging, - &DebugOptions::set_xla_detailed_logging) - .def_property("xla_enable_dumping", &DebugOptions::xla_enable_dumping, - &DebugOptions::set_xla_enable_dumping) - .def_property("xla_gpu_enable_fast_min_max", - &DebugOptions::xla_gpu_enable_fast_min_max, - &DebugOptions::set_xla_gpu_enable_fast_min_max) - .def_property("xla_gpu_cuda_data_dir", - &DebugOptions::xla_gpu_cuda_data_dir, - [](DebugOptions* self, std::string value) { - self->set_xla_gpu_cuda_data_dir(value); - }) - .def_property("xla_llvm_disable_expensive_passes", - &DebugOptions::xla_llvm_disable_expensive_passes, - &DebugOptions::set_xla_llvm_disable_expensive_passes) - .def_property( + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( "xla_disable_hlo_passes", [](DebugOptions* self) { return absl::StrJoin(self->xla_disable_hlo_passes(), ","); @@ -922,7 +1010,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { self->add_xla_disable_hlo_passes(passname); } }) - .def_property( + .def_prop_rw( "xla_enable_hlo_passes_only", [](DebugOptions* self) { return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); @@ -934,84 +1022,95 @@ void BuildXlaCompilerSubmodule(py::module& m) { self->add_xla_enable_hlo_passes_only(passname); } }) - .def_property("xla_test_all_input_layouts", - &DebugOptions::xla_test_all_input_layouts, - &DebugOptions::set_xla_test_all_input_layouts) - .def_property("xla_force_host_platform_device_count", - &DebugOptions::xla_force_host_platform_device_count, - &DebugOptions::set_xla_force_host_platform_device_count) - .def_property("xla_dump_to", &DebugOptions::xla_dump_to, - [](DebugOptions* self, std::string value) { - self->set_xla_dump_to(value); - }) - .def_property("xla_dump_hlo_module_re", - &DebugOptions::xla_dump_hlo_module_re, - [](DebugOptions* self, std::string value) { - self->set_xla_dump_hlo_module_re(value); - }) - .def_property("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, - [](DebugOptions* self, std::string value) { - self->set_xla_dump_hlo_pass_re(value); - }) - .def_property("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, - &DebugOptions::set_xla_dump_hlo_as_text) - .def_property("xla_dump_hlo_as_proto", - &DebugOptions::xla_dump_hlo_as_proto, - &DebugOptions::set_xla_dump_hlo_as_proto) - .def_property("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, - &DebugOptions::set_xla_dump_hlo_as_dot) - .def_property("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, - &DebugOptions::set_xla_dump_hlo_as_url) - .def_property("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, - &DebugOptions::set_xla_dump_hlo_as_html) - .def_property("xla_dump_fusion_visualization", - &DebugOptions::xla_dump_fusion_visualization, - &DebugOptions::set_xla_dump_fusion_visualization) - .def_property("xla_dump_hlo_snapshots", - &DebugOptions::xla_dump_hlo_snapshots, - &DebugOptions::set_xla_dump_hlo_snapshots) - .def_property("xla_dump_max_hlo_modules", - &DebugOptions::xla_dump_max_hlo_modules, - &DebugOptions::set_xla_dump_max_hlo_modules) - .def_property("xla_dump_module_metadata", - &DebugOptions::xla_dump_module_metadata, - &DebugOptions::set_xla_dump_module_metadata) - .def_property("xla_dump_compress_protos", - &DebugOptions::xla_dump_compress_protos, - &DebugOptions::set_xla_dump_compress_protos) - .def_property("xla_dump_hlo_as_long_text", - &DebugOptions::xla_dump_hlo_as_long_text, - &DebugOptions::set_xla_dump_hlo_as_long_text) - .def_property("xla_dump_disable_metadata", - &DebugOptions::xla_dump_disable_metadata, - &DebugOptions::set_xla_dump_disable_metadata) - .def_property("xla_dump_hlo_pipeline_re", - &DebugOptions::xla_dump_hlo_pipeline_re, - [](DebugOptions* self, std::string value) { - self->set_xla_dump_hlo_pipeline_re(value); - }) - .def_property("xla_gpu_enable_async_all_reduce", - &DebugOptions::xla_gpu_enable_async_all_reduce, - &DebugOptions::set_xla_gpu_enable_async_all_reduce) - .def_property("xla_gpu_enable_async_all_gather", - &DebugOptions::xla_gpu_enable_async_all_gather, - &DebugOptions::set_xla_gpu_enable_async_all_gather) - .def_property("xla_gpu_enable_async_collective_permute", - &DebugOptions::xla_gpu_enable_async_collective_permute, - &DebugOptions::set_xla_gpu_enable_async_collective_permute) - .def_property("xla_gpu_enable_async_all_to_all", - &DebugOptions::xla_gpu_enable_async_all_to_all, - &DebugOptions::set_xla_gpu_enable_async_all_to_all) - .def_property("xla_gpu_enable_async_reduce_scatter", - &DebugOptions::xla_gpu_enable_async_reduce_scatter, - &DebugOptions::set_xla_gpu_enable_async_reduce_scatter); + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + &DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_enable_async_all_reduce", + &DebugOptions::xla_gpu_enable_async_all_reduce, + &DebugOptions::set_xla_gpu_enable_async_all_reduce) + .def_prop_rw("xla_gpu_enable_async_all_gather", + &DebugOptions::xla_gpu_enable_async_all_gather, + &DebugOptions::set_xla_gpu_enable_async_all_gather) + .def_prop_rw("xla_gpu_enable_async_collective_broadcast", + &DebugOptions::xla_gpu_enable_async_collective_broadcast, + &DebugOptions::set_xla_gpu_enable_async_collective_broadcast) + .def_prop_rw("xla_gpu_enable_async_collective_permute", + &DebugOptions::xla_gpu_enable_async_collective_permute, + &DebugOptions::set_xla_gpu_enable_async_collective_permute) + .def_prop_rw("xla_gpu_enable_async_all_to_all", + &DebugOptions::xla_gpu_enable_async_all_to_all, + &DebugOptions::set_xla_gpu_enable_async_all_to_all) + .def_prop_rw("xla_gpu_enable_async_reduce_scatter", + &DebugOptions::xla_gpu_enable_async_reduce_scatter, + &DebugOptions::set_xla_gpu_enable_async_reduce_scatter); - py::class_(m, "ExecutableBuildOptions") - .def(py::init<>()) + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) .def("__repr__", &ExecutableBuildOptions::ToString) - .def_property("fdo_profile", &ExecutableBuildOptions::fdo_profile, - &ExecutableBuildOptions::set_fdo_profile) - .def_property( + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( "result_layout", [](const ExecutableBuildOptions& options) -> std::optional { return options.result_layout() @@ -1019,14 +1118,14 @@ void BuildXlaCompilerSubmodule(py::module& m) { : std::nullopt; }, &ExecutableBuildOptions::set_result_layout) - .def_property("num_replicas", &ExecutableBuildOptions::num_replicas, - &ExecutableBuildOptions::set_num_replicas) - .def_property("num_partitions", &ExecutableBuildOptions::num_partitions, - &ExecutableBuildOptions::set_num_partitions) - .def_property_readonly( - "debug_options", &ExecutableBuildOptions::mutable_debug_options, - py::return_value_policy::reference, py::keep_alive<1, 0>()) - .def_property( + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( "device_assignment", [](const ExecutableBuildOptions& options) -> std::optional { @@ -1036,21 +1135,31 @@ void BuildXlaCompilerSubmodule(py::module& m) { : std::nullopt; }, &ExecutableBuildOptions::set_device_assignment) - .def_property("use_spmd_partitioning", - &ExecutableBuildOptions::use_spmd_partitioning, - &ExecutableBuildOptions::set_use_spmd_partitioning) - .def_property("use_auto_spmd_partitioning", - &ExecutableBuildOptions::use_auto_spmd_partitioning, - &ExecutableBuildOptions::set_use_auto_spmd_partitioning) - .def_property( + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( "auto_spmd_partitioning_mesh_shape", &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) - .def_property( - "auto_spmd_partitioning_mesh_ids", - &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, - &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) - .def_property( + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( "allow_spmd_sharding_propagation_to_output", [](const ExecutableBuildOptions& options) -> std::vector { return std::vector( @@ -1062,7 +1171,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { options.set_allow_spmd_sharding_propagation_to_output(v); }); - py::enum_ op_sharding_type(m, "OpSharding_Type"); + nb::enum_ op_sharding_type(m, "OpSharding_Type"); op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) .value("MAXIMAL", OpSharding::MAXIMAL) .value("MANUAL", OpSharding::MANUAL) @@ -1070,49 +1179,53 @@ void BuildXlaCompilerSubmodule(py::module& m) { .value("OTHER", OpSharding::OTHER) .value("UNKNOWN", OpSharding::UNKNOWN); - py::enum_ op_sharding_shard_group_type( + nb::enum_ op_sharding_shard_group_type( m, "OpSharding_ShardGroupType"); op_sharding_shard_group_type.value("AS", OpSharding::AS) .value("LIKE", OpSharding::LIKE); - py::class_ op_sharding(m, "OpSharding"); + nb::class_ op_sharding(m, "OpSharding"); op_sharding - .def_property_readonly_static( + .def_prop_ro_static( "Type", - [op_sharding_type](const py::object&) { return op_sharding_type; }) - .def_property_readonly_static( - "ShardGroupType", - [op_sharding_shard_group_type](const py::object&) { - return op_sharding_shard_group_type; - }) - .def(py::init<>()) - .def(py::pickle( - [](const OpSharding& self) { - return py::make_tuple(py::bytes(self.SerializeAsString())); - }, - [](py::tuple t) { - OpSharding result; - result.ParseFromString(t[0].cast()); - return result; - })) - .def_property("type", &xla::OpSharding::type, &xla::OpSharding::set_type) - .def_property("replicate_on_last_tile_dim", - &xla::OpSharding::replicate_on_last_tile_dim, - &xla::OpSharding::set_replicate_on_last_tile_dim) - .def_property("is_shard_group", &xla::OpSharding::is_shard_group, - &xla::OpSharding::set_is_shard_group) - .def_property("shard_group_id", &xla::OpSharding::shard_group_id, - &xla::OpSharding::set_shard_group_id) - .def_property("shard_group_type", &xla::OpSharding::shard_group_type, - &xla::OpSharding::set_shard_group_type) - .def("__repr__", &xla::OpSharding::DebugString) + [op_sharding_type](const nb::object&) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object&) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) .def("ParseFromString", - [](OpSharding& sharding, const std::string& s) { - sharding.ParseFromString(s); + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); }) .def("SerializeToString", [](const OpSharding& sharding) { - return py::bytes(sharding.SerializeAsString()); + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); }) .def("clone", [](const OpSharding& sharding) { return OpSharding(sharding); }); @@ -1129,7 +1242,7 @@ void BuildXlaCompilerSubmodule(py::module& m) { DefRepeatedProperty(op_sharding, "last_tile_dims", &xla::OpSharding::mutable_last_tile_dims); - py::class_ hlo_sharding(m, "HloSharding"); + nb::class_ hlo_sharding(m, "HloSharding"); hlo_sharding .def_static("from_proto", xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) @@ -1143,10 +1256,10 @@ void BuildXlaCompilerSubmodule(py::module& m) { "Constructs a tuple sharding.") .def_static( "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), - py::arg("dims"), - py::arg("reshape_dims") = absl::Span(), - py::arg("transpose_perm") = absl::Span(), - py::arg("subgroup_types") = absl::Span()) + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) .def_static("manual", [] { return HloSharding::Manual(); }) .def_static("replicate", [] { return HloSharding::Replicate(); }) .def_static("unknown", [] { return HloSharding::Unknown(); }) @@ -1172,12 +1285,18 @@ void BuildXlaCompilerSubmodule(py::module& m) { }) .def("tile_assignment_dimensions", [](const xla::HloSharding& self) { - return self.tile_assignment().dimensions(); + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; }) .def("tile_assignment_devices", [](const xla::HloSharding& self) { - return absl::MakeConstSpan(self.tile_assignment().array().data(), - self.tile_assignment().num_elements()); + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; }) .def("replicate_on_last_tile_dim", &xla::HloSharding::ReplicateOnLastTileDim) @@ -1186,27 +1305,27 @@ void BuildXlaCompilerSubmodule(py::module& m) { [](const xla::HloSharding& self) { return self.ToString(); }) .def("to_proto", &xla::HloSharding::ToProto); - py::class_ frontend_attributes(m, "FrontendAttributes"); - frontend_attributes.def(py::init<>()) + nb::class_ frontend_attributes(m, "FrontendAttributes"); + frontend_attributes.def(nb::init<>()) .def("__setitem__", [](FrontendAttributes* attr, std::string key, std::string value) { (*attr->mutable_map())[key] = value; }); - py::enum_(m, "PrecisionConfig_Precision") + nb::enum_(m, "PrecisionConfig_Precision") .value("DEFAULT", PrecisionConfig::DEFAULT) .value("HIGH", PrecisionConfig::HIGH) .value("HIGHEST", PrecisionConfig::HIGHEST); - py::enum_(m, "FftType") + nb::enum_(m, "FftType") .value("FFT", FftType::FFT) .value("IFFT", FftType::IFFT) .value("RFFT", FftType::RFFT) .value("IRFFT", FftType::IRFFT); // Hlo Module Passes - py::class_ hlo_pass_interface(m, "HloPassInterface"); - hlo_pass_interface.def_property_readonly("name", &HloPassInterface::name) + nb::class_ hlo_pass_interface(m, "HloPassInterface"); + hlo_pass_interface.def_prop_ro("name", &HloPassInterface::name) .def("is_pass_pipeline", &HloPassInterface::IsPassPipeline) .def("run", [](HloPassInterface& pass, HloModule* module) -> bool { @@ -1217,11 +1336,11 @@ void BuildXlaCompilerSubmodule(py::module& m) { return xla::ValueOrThrow(pass.RunOnModuleGroup(module_group)); }); - py::class_(m, "HloDCE").def(py::init<>()); - py::class_(m, "CallInliner").def(py::init<>()); - py::class_(m, "FlattenCallGraph") - .def(py::init<>()); - py::class_(m, "TupleSimplifier") - .def(py::init<>()); + nb::class_(m, "HloDCE").def(nb::init<>()); + nb::class_(m, "CallInliner").def(nb::init<>()); + nb::class_(m, "FlattenCallGraph") + .def(nb::init<>()); + nb::class_(m, "TupleSimplifier") + .def(nb::init<>()); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/xla/python/xla_compiler.h b/xla/python/xla_compiler.h index 60b2a3a9c9a70..0f32f8fe9a632 100644 --- a/xla/python/xla_compiler.h +++ b/xla/python/xla_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,11 @@ limitations under the License. #define XLA_PYTHON_XLA_COMPILER_H_ // placeholder for index annotation headers -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind namespace xla { -void BuildXlaCompilerSubmodule(pybind11::module& m); +void BuildXlaCompilerSubmodule(nanobind::module_& m); } // namespace xla diff --git a/xla/python/xla_extension.cc b/xla/python/xla_extension.cc deleted file mode 100644 index 9ee5299519fc4..0000000000000 --- a/xla/python/xla_extension.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "pybind11/pybind11.h" // from @pybind11 -#include "xla/python/xla.h" - -extern "C" PYBIND11_EXPORT PyObject *PyInit_xla_extension() { - return xla::InitializeXlaExtension(); -} diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 2f9a147c37992..275a9f94efff8 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,6 +37,8 @@ from typing import ( import numpy as np +from . import ifrt_programs +from . import ifrt_proxy from . import jax_jit from . import mlir from . import ops @@ -59,10 +61,12 @@ class XlaRuntimeError(RuntimeError): class PrimitiveType(enum.IntEnum): PRIMITIVE_TYPE_INVALID: PrimitiveType PRED: PrimitiveType + S4: PrimitiveType S8: PrimitiveType S16: PrimitiveType S32: PrimitiveType S64: PrimitiveType + U4: PrimitiveType U8: PrimitiveType U16: PrimitiveType U32: PrimitiveType @@ -256,7 +260,7 @@ class CompileOptions: env_option_overrides: List[Tuple[str, str]] def register_custom_call_target( - fn_name: str, capsule: Any, platform: str + fn_name: str, capsule: Any, platform: str, api_version: int = ..., ) -> _Status: ... def register_custom_call_partitioner( name: str, @@ -264,6 +268,7 @@ def register_custom_call_partitioner( partition: Callable, infer_sharding_from_operands: Callable, can_side_effecting_have_replicated_sharding: bool, + c_api: Optional[Any], ) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... @@ -300,11 +305,15 @@ class DebugOptions: xla_dump_hlo_pipeline_re: str xla_gpu_enable_async_all_reduce: bool xla_gpu_enable_async_all_gather: bool + xla_gpu_enable_async_collective_broadcast: bool xla_gpu_enable_async_collective_permute: bool xla_gpu_enable_async_all_to_all: bool xla_gpu_enable_async_reduce_scatter: bool + xla_gpu_cuda_data_dir: str xla_detailed_logging: bool xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str class CompiledMemoryStats: generated_code_size_in_bytes: int @@ -312,6 +321,11 @@ class CompiledMemoryStats: output_size_in_bytes: int alias_size_in_bytes: int temp_size_in_bytes: int + host_generated_code_size_in_bytes: int + host_argument_size_in_bytes: int + host_output_size_in_bytes: int + host_alias_size_in_bytes: int + host_temp_size_in_bytes: int serialized_hlo_proto: bytes def __str__(self) -> str: ... @@ -439,6 +453,13 @@ class Memory: def __str__(self) -> str: ... def addressable_by_devices(self) -> List[Device]: ... +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, other: PjRtLayout) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + class GpuAllocatorConfig: class Kind(enum.IntEnum): DEFAULT: int @@ -451,6 +472,7 @@ class GpuAllocatorConfig: kind: Kind = ..., memory_fraction: float = ..., preallocate: bool = ..., + collective_memory_size: int = ..., ) -> None: ... class HostBufferSemantics(enum.IntEnum): @@ -488,6 +510,11 @@ class Client: compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... def deserialize_executable( self, @@ -512,13 +539,25 @@ class Client: recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ..., ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + ) -> PjRtLayout: ... def __getattr__(self, name: str) -> Any: ... +class CpuCollectives: ... + +def make_gloo_tcp_collectives( + distributed_client: Optional[DistributedRuntimeClient] = ..., + hostname: Optional[str] = ..., + interface: Optional[str] = ..., +) -> CpuCollectives: ... + def get_tfrt_cpu_client( asynchronous: bool = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., num_nodes: int = ..., + collectives: Optional[CpuCollectives] = ..., ) -> Client: ... def get_gpu_client( asynchronous: bool = ..., @@ -549,7 +588,7 @@ def get_default_c_api_topology( options: Dict[str, Union[str, int, List[int], float]], ) -> DeviceTopology: ... def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... -def load_pjrt_plugin(platform_name: str, library_path: str) -> _Status: ... +def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(platform_name: str) -> _Status: ... @@ -567,8 +606,6 @@ ArrayImpl = Any # _skip_checks: bool = ...): ... # def block_until_ready(self) -> ArrayImpl: ... # def is_deleted(self) -> bool: ... -# # TODO(yashkatariya): remove this once the transition completes. -# def _init_with_fastpath_disabled(self) -> None: ... # def is_ready(self) -> bool: ... # def delete(self): ... # def unsafe_buffer_pointer(self) -> Any: ... @@ -588,18 +625,19 @@ ArrayImpl = Any def copy_array_to_devices_with_sharding( self: ArrayImpl, devices: List[Device], sharding: Any ) -> ArrayImpl: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + def batched_device_put( aval: Any, sharding: Any, shards: Sequence[Any], devices: List[Device], committed: bool = True, -) -> ArrayImpl: - ... - +) -> ArrayImpl: ... def check_and_canonicalize_memory_kind( - memory_kind: Optional[str], device_list: DeviceList) -> Optional[str]: ... - + memory_kind: Optional[str], device_list: DeviceList +) -> Optional[str]: ... def array_result_handler( aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... ) -> Callable: ... @@ -662,6 +700,7 @@ class Executable: def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def serialize(self) -> str: ... def compile_options(self) -> CompileOptions: ... + def cost_analysis(self) -> Dict[str, Any]: ... class DeviceTopology: platform: str @@ -673,17 +712,28 @@ class DeviceTopology: def buffer_to_dlpack_managed_tensor( buffer: ArrayImpl, stream: int | None = None ) -> Any: ... +@overload def dlpack_managed_tensor_to_buffer( tensor: Any, device: Device, stream: int | None ) -> ArrayImpl: ... - -# Legacy overload -def dlpack_managed_tensor_to_buffer( +@overload +def dlpack_managed_tensor_to_buffer( # Legacy overload tensor: Any, cpu_backend: Optional[Client] = ..., gpu_backend: Optional[Client] = ..., ) -> ArrayImpl: ... +def cuda_array_interface_to_buffer( + cai: Dict[str, Union[ + str, int, None, + Tuple[int, ...], Tuple[int, bool], + List[Tuple[str, str]], + List[Tuple[str, str, Tuple[int, ...]]]] + ], + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... + + # === BEGIN py_traceback.cc class Frame: @@ -725,6 +775,7 @@ class DistributedRuntimeClient: def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str) -> _Status: ... + def key_value_set_bytes(self, key: str, value: bytes) -> _Status: ... def key_value_delete(self, key: str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> _Status: ... @@ -734,7 +785,8 @@ def get_distributed_runtime_service( heartbeat_interval: Optional[int] = ..., max_missing_heartbeats: Optional[int] = ..., cluster_register_timeout: Optional[int] = ..., - shutdown_timeout: Optional[int] = ...) -> DistributedRuntimeService: ... + shutdown_timeout: Optional[int] = ..., +) -> DistributedRuntimeService: ... def get_distributed_runtime_client( address: str, node_id: int, @@ -757,7 +809,6 @@ def is_optimized_build() -> bool: ... def json_to_pprof_profile(json: str) -> bytes: ... def pprof_profile_to_json(proto: bytes) -> str: ... -CompiledFunction = Any class PmapFunction: def __call__(self, *args, **kwargs) -> Any: ... @@ -780,6 +831,7 @@ class DeviceList: def __getitem__(self, index: Any) -> Any: ... def __iter__(self) -> Iterator[Device]: ... def __str__(self) -> str: ... + def __repr__(self) -> str: ... def __getstate__(self) -> Any: ... def __setstate__(self, state: Any): ... @property @@ -832,6 +884,7 @@ class GSPMDSharding(XLACompatibleSharding): op_sharding: Union[OpSharding, HloSharding], *, memory_kind: Optional[str] = None, + _device_list: Optional[DeviceList] = None, ): ... _devices: Tuple[Device, ...] _hlo_sharding: HloSharding @@ -859,6 +912,7 @@ def pjit( static_argnames: Sequence[str], donate_argnums: Sequence[int], pytree_registry: pytree.PyTreeRegistry, + shard_arg_fallback: Callable, cache: Optional[PjitFunctionCache] = ..., ) -> PjitFunction: ... diff --git a/xla/python/xla_extension/ifrt_programs.pyi b/xla/python/xla_extension/ifrt_programs.pyi new file mode 100644 index 0000000000000..58b0996b75797 --- /dev/null +++ b/xla/python/xla_extension/ifrt_programs.pyi @@ -0,0 +1,33 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable, Sequence, Union + +from xla.python import xla_extension + +class Program: ... + +class CompileOptions: ... + +def make_xla_program(mlir_module: Union[str, bytes]) -> Program: ... + +def make_plugin_program(data: Union[str, bytes]) -> Program: ... + +def make_xla_compile_options( + compile_options: xla_extension.CompileOptions, + host_callbacks: Sequence[Any] +) -> CompileOptions: ... + +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/xla/python/xla_extension/ifrt_proxy.pyi b/xla/python/xla_extension/ifrt_proxy.pyi new file mode 100644 index 0000000000000..f65685025e516 --- /dev/null +++ b/xla/python/xla_extension/ifrt_proxy.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable + +from xla.python import xla_extension + +_Status = Any +Client = xla_extension.Client + + +class ClientConnectionOptions: + on_disconnect: Optional[Callable[[_Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... diff --git a/xla/python/xla_extension/jax_jit.pyi b/xla/python/xla_extension/jax_jit.pyi index 9bf5e30d6c890..b0428919a20a7 100644 --- a/xla/python/xla_extension/jax_jit.pyi +++ b/xla/python/xla_extension/jax_jit.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ from xla.python import xla_extension Client = xla_extension.Client Device = xla_extension.Device -CompiledFunction = xla_extension.CompiledFunction class JitState: disable_jit: Optional[bool] @@ -34,7 +33,6 @@ class JitState: def global_state() -> JitState: ... def thread_local_state() -> JitState: ... -def jit_is_disabled() -> bool: ... def get_enable_x64() -> bool: ... def set_thread_local_state_initialization_callback( function: Callable[[], None]): ... diff --git a/xla/python/xla_extension/mlir.pyi b/xla/python/xla_extension/mlir.pyi index f62c6565dc0a4..ff886a1bdf966 100644 --- a/xla/python/xla_extension/mlir.pyi +++ b/xla/python/xla_extension/mlir.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,14 +16,20 @@ from typing import Union from . import XlaComputation -def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def xla_computation_to_mlir_module( + computation: XlaComputation, emit_stable_hlo: bool = ... +) -> str: ... def mlir_module_to_xla_computation( - mlir_module: str, use_tuple_args: bool = ..., - return_tuple: bool = ...) -> XlaComputation: ... -def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> str: ... -def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> str: ... -def serialize_portable_artifact(mlir_module: str, target:str) -> bytes: ... + mlir_module: Union[bytes, str], + use_tuple_args: bool = ..., + return_tuple: bool = ..., +) -> XlaComputation: ... +def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ... +def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... -def refine_polymorphic_shapes(mlir_module: Union[bytes, str], - enable_shape_assertions: bool = ..., - validate_static_shapes: bool = ...) -> bytes: ... +def refine_polymorphic_shapes( + mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., +) -> bytes: ... diff --git a/xla/python/xla_extension/ops.pyi b/xla/python/xla_extension/ops.pyi index a8538d7e05487..55624f47446a6 100644 --- a/xla/python/xla_extension/ops.pyi +++ b/xla/python/xla_extension/ops.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ class CustomCallApiVersion(enum.IntEnum): API_VERSION_ORIGINAL: int API_VERSION_STATUS_RETURNING: int API_VERSION_STATUS_RETURNING_UNIFIED: int + API_VERSION_TYPED_FFI: int def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... def AllGather( diff --git a/xla/python/xla_extension/outfeed_receiver.pyi b/xla/python/xla_extension/outfeed_receiver.pyi index 3960a92732937..b0850355de65a 100644 --- a/xla/python/xla_extension/outfeed_receiver.pyi +++ b/xla/python/xla_extension/outfeed_receiver.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/xla_extension/pmap_lib.pyi b/xla/python/xla_extension/pmap_lib.pyi index 817b379e0c686..4ede6eebed836 100644 --- a/xla/python/xla_extension/pmap_lib.pyi +++ b/xla/python/xla_extension/pmap_lib.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/xla_extension/profiler.pyi b/xla/python/xla_extension/profiler.pyi index ff39fe2ec29b9..92dbb02639b7f 100644 --- a/xla/python/xla_extension/profiler.pyi +++ b/xla/python/xla_extension/profiler.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # ============================================================================== from types import TracebackType -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Union _Status = Any @@ -24,7 +24,9 @@ def start_server(port: int) -> ProfilerServer: ... def register_plugin_profiler(c_api: Any) -> None: ... def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... -def get_fdo_profile(xspace: bytes) -> bytes: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> Union[bytes, str]: ... class ProfilerSession: def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ... diff --git a/xla/python/xla_extension/pytree.pyi b/xla/python/xla_extension/pytree.pyi index 8fbe5047ed9cd..e493fda1dfde5 100644 --- a/xla/python/xla_extension/pytree.pyi +++ b/xla/python/xla_extension/pytree.pyi @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,6 +29,9 @@ class PyTreeRegistry: tree: Any, leaf_predicate: Optional[Callable[[Any], bool]] = ..., ) -> Tuple[List[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: Any + ) -> Optional[Tuple[Iterable[Any], Any]]: ... def register_node( self, __type: Type[_T], @@ -53,7 +56,6 @@ class PyTreeDef: def children(self) -> List[PyTreeDef]: ... @staticmethod def make_from_node_data_and_children( - self, registry: PyTreeRegistry, node_data: Optional[Tuple[Type, Any]], children: Iterable[PyTreeDef], @@ -71,10 +73,9 @@ class PyTreeDef: def serialize_using_proto(self) -> bytes: ... @staticmethod def deserialize_using_proto( - self, registry: PyTreeRegistry, data: bytes + registry: PyTreeRegistry, data: bytes ) -> PyTreeDef: ... - _Children = TypeVar("_Children", bound=Iterable[Any]) _AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/xla/python/xla_extension/transfer_guard_lib.pyi b/xla/python/xla_extension/transfer_guard_lib.pyi index 645b64d9c5a40..41240a475bec8 100644 --- a/xla/python/xla_extension/transfer_guard_lib.pyi +++ b/xla/python/xla_extension/transfer_guard_lib.pyi @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/python/xplane_to_profile_instructions.cc b/xla/python/xplane_to_profile_instructions.cc index 14f62dafcc29c..772e1821a82a6 100644 --- a/xla/python/xplane_to_profile_instructions.cc +++ b/xla/python/xplane_to_profile_instructions.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/xplane_to_profile_instructions.h b/xla/python/xplane_to_profile_instructions.h index 00be1718bc321..8971dc651030e 100644 --- a/xla/python/xplane_to_profile_instructions.h +++ b/xla/python/xplane_to_profile_instructions.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python/xplane_to_profile_instructions_test.cc b/xla/python/xplane_to_profile_instructions_test.cc index bfb4d86e24954..05395eff20f8e 100644 --- a/xla/python/xplane_to_profile_instructions_test.cc +++ b/xla/python/xplane_to_profile_instructions_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/python_api/types_.py b/xla/python_api/types_.py index 91f7446f046b1..d81cd086f46aa 100644 --- a/xla/python_api/types_.py +++ b/xla/python_api/types_.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/xla/python_api/xla_literal.py b/xla/python_api/xla_literal.py index e484d8eb3e398..7d0c05d42721a 100644 --- a/xla/python_api/xla_literal.py +++ b/xla/python_api/xla_literal.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/xla/python_api/xla_literal_test.py b/xla/python_api/xla_literal_test.py index 480c09c9f507e..5f93651744281 100644 --- a/xla/python_api/xla_literal_test.py +++ b/xla/python_api/xla_literal_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/xla/python_api/xla_shape.py b/xla/python_api/xla_shape.py index 7a024abedb4df..12a186099542f 100644 --- a/xla/python_api/xla_shape.py +++ b/xla/python_api/xla_shape.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/xla/python_api/xla_shape_test.py b/xla/python_api/xla_shape_test.py index bb08ddc950b71..9ffd5815676af 100644 --- a/xla/python_api/xla_shape_test.py +++ b/xla/python_api/xla_shape_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/xla/pytype.default.bzl b/xla/pytype.default.bzl index 05143e8a71518..b63011cc1b8e4 100644 --- a/xla/pytype.default.bzl +++ b/xla/pytype.default.bzl @@ -10,5 +10,6 @@ def pytype_strict_binary(name, **kwargs): native.py_binary(name = name, **kwargs) # Placeholder to use until bazel supports pytype_strict_library. -def pytype_strict_library(name, **kwargs): +def pytype_strict_library(name, pytype_deps = [], pytype_srcs = [], **kwargs): + _ = (pytype_deps, pytype_srcs) # @unused native.py_library(name = name, **kwargs) diff --git a/xla/refcounting_hash_map.h b/xla/refcounting_hash_map.h index 1a407d260e597..3f7e328024531 100644 --- a/xla/refcounting_hash_map.h +++ b/xla/refcounting_hash_map.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/refcounting_hash_map_test.cc b/xla/refcounting_hash_map_test.cc index a2a39bb3a02f5..71211cc36c02e 100644 --- a/xla/refcounting_hash_map_test.cc +++ b/xla/refcounting_hash_map_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/reference_util.cc b/xla/reference_util.cc index 897e43f4a3371..fb0348c2f278c 100644 --- a/xla/reference_util.cc +++ b/xla/reference_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/reference_util.h b/xla/reference_util.h index 46c3e6192cf94..418237b33f8c3 100644 --- a/xla/reference_util.h +++ b/xla/reference_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/reference_util_test.cc b/xla/reference_util_test.cc index 9cfee58ae34e6..1ade118862480 100644 --- a/xla/reference_util_test.cc +++ b/xla/reference_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runlit.cfg.py b/xla/runlit.cfg.py deleted file mode 100644 index 1c28ed57e8b3c..0000000000000 --- a/xla/runlit.cfg.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Lit runner configuration.""" - -import os -import platform -import sys -import lit.formats -from lit.llvm import llvm_config -from lit.llvm.subst import ToolSubst - -# Lint for undefined variables is disabled as config is not defined inside this -# file, instead config is injected by way of evaluating runlit.cfg.py from -# runlit.site.cfg.py which in turn is evaluated by lit.py. The structure is -# common for lit tests and intended to only persist temporarily (b/136126535). -# pylint: disable=undefined-variable -# Configuration file for the 'lit' test runner. - -# name: The name of this test suite. -config.name = 'MLIR ' + os.path.basename(config.mlir_test_dir) - -config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) - -# suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.cc', '.hlo', '.hlotxt', '.json', '.mlir', '.pbtxt', '.py'] - -# test_source_root: The root path where tests are located. -config.test_source_root = config.mlir_test_dir - -# test_exec_root: The root path where tests should be run. -config.test_exec_root = os.environ['RUNFILES_DIR'] - -if platform.system() == 'Windows': - tool_patterns = [ - ToolSubst('FileCheck.exe', unresolved='fatal'), - # Handle these specially as they are strings searched for during testing. - ToolSubst('count.exe', unresolved='fatal'), - ToolSubst('not.exe', unresolved='fatal') - ] - - llvm_config.config.substitutions.append( - ('%python', '"%s"' % (sys.executable))) - - llvm_config.add_tool_substitutions(tool_patterns, - [llvm_config.config.llvm_tools_dir]) -else: - llvm_config.use_default_substitutions() - -subst_marker = 'SUBST_' -subst_marker_len = len(subst_marker) -# Include aditional substitutions that may be defined via params -llvm_config.config.substitutions.extend( - ('%%{%s}' % key, val) - for key, val in lit_config.params.items() - if not key.startswith(subst_marker) -) - -# Include ir substitutions for FileCheck -llvm_config.config.substitutions.append(( - '%{IR_SUBST}', - ' '.join( - "-D{}='{}'".format(key[subst_marker_len:], val.replace('[SPACE]', ' ')) - for key, val in lit_config.params.items() - if key.startswith(subst_marker) - ), -)) - -# Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) - -tool_dirs = config.mlir_tf_tools_dirs + [ - config.mlir_tools_dir, config.llvm_tools_dir -] -tool_names = [ - 'hlo_to_llvm_ir', - 'ifrt-opt', - 'kernel-gen-opt', - 'mlir-bisect', - 'mlir-hlo-opt', - 'mlir-opt', - 'mlir-translate', - 'xla-cpu-opt', - 'xla-gpu-opt', - 'xla-mlir-gpu-opt', - 'xla-runtime-opt', - 'xla-translate', - 'xla-translate-gpu-opt', - 'xla-translate-opt', - 'hlo-opt', -] -tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] -llvm_config.add_tool_substitutions(tools, tool_dirs) -# pylint: enable=undefined-variable diff --git a/xla/runlit.site.cfg.py b/xla/runlit.site.cfg.py deleted file mode 100644 index 5fc38b581a0ba..0000000000000 --- a/xla/runlit.site.cfg.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2019 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Lit runner site configuration.""" - -import os -import platform -import lit.llvm - -# Handle the test srcdir for platforms. On windows, things are weird with bazel. -if platform.system() == "Windows": - srcdir = os.environ["TEST_SRCDIR"] - real_test_srcdir = srcdir[:srcdir.find("xla/")] - external_srcdir = os.path.join(real_test_srcdir, "external") -else: - real_test_srcdir = os.environ["TEST_SRCDIR"] - external_srcdir = real_test_srcdir - -# Lint for undefined variables is disabled as config is not defined inside this -# file, instead config is injected by lit.py. The structure is common for lit -# tests and intended to only persist temporarily (b/136126535). -# pylint: disable=undefined-variable -config.llvm_tools_dir = os.path.join(external_srcdir, "llvm-project", "llvm") -config.mlir_obj_root = os.path.join(real_test_srcdir) -config.mlir_tools_dir = os.path.join(external_srcdir, "llvm-project", "mlir") -# TODO(jpienaar): Replace with suffices in build rule. -config.suffixes = [".td", ".mlir", ".pbtxt"] - -xla_root_dir = "xla/" -mlir_tf_tools_dirs = [ - "mlir/backends/cpu", - "mlir/backends/gpu", - "mlir/runtime", - "mlir/tools/mlir_bisect", - "mlir_hlo", - "python/ifrt/ir/tests", - "service/gpu/tests", - "service/mlir_gpu", - "translate", - "translate/mhlo_to_lhlo_with_xla", - "tools", -] -config.mlir_tf_tools_dirs = [ - os.path.join(real_test_srcdir, os.environ["TEST_WORKSPACE"], xla_root_dir, - s) for s in mlir_tf_tools_dirs -] -test_dir = os.environ["TEST_TARGET"] -test_dir = test_dir.strip("/").rsplit(":", 1)[0] -config.mlir_test_dir = os.path.join(real_test_srcdir, - os.environ["TEST_WORKSPACE"], test_dir) - -if platform.system() == "Windows": - # Configure this to work with msys2, TF's preferred windows bash. - config.lit_tools_dir = "/usr/bin" - -lit.llvm.initialize(lit_config, config) - - -# Let the main config do the real work. -lit_config.load_config( - config, - os.path.join( - os.path.join(real_test_srcdir, os.environ["TEST_WORKSPACE"], - xla_root_dir + "runlit.cfg.py"))) -# pylint: enable=undefined-variable diff --git a/xla/runtime/BUILD b/xla/runtime/BUILD index 19d1c0042d998..929cd024921e4 100644 --- a/xla/runtime/BUILD +++ b/xla/runtime/BUILD @@ -1,11 +1,20 @@ -load("//xla:xla.bzl", "xla_cc_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") -load("@tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_platform_deps") +load("@tsl//tsl/platform:build_config.bzl", "tf_platform_deps") +load( + "@tsl//tsl/platform:build_config_root.bzl", + "if_llvm_aarch32_available", + "if_llvm_aarch64_available", + "if_llvm_powerpc_available", + "if_llvm_system_z_available", + "if_llvm_x86_available", +) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla:internal"], + default_visibility = internal_visibility(["//xla:internal"]), licenses = ["notice"], ) @@ -80,6 +89,7 @@ cc_library( deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], ) @@ -140,27 +150,6 @@ cc_library( ], ) -xla_cc_test( - name = "custom_call_test", - srcs = ["custom_call_test.cc"], - deps = [ - ":arguments", - ":async_runtime", - ":custom_call", - ":custom_call_registry", - ":diagnostics", - ":jit_executable", - ":module", - ":state", - "//xla/mlir/runtime/ir/tests:testlib", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", - ], -) - cc_library( name = "custom_call_registry", srcs = ["custom_call_registry.cc"], @@ -236,30 +225,6 @@ cc_library( ], ) -xla_cc_test( - name = "executable_test", - srcs = ["executable_test.cc"], - tags = ["nomsan"], # TODO(ezhulenev): Find msan error in LLVM coroutine passes - deps = [ - ":arguments", - ":async_runtime", - ":custom_call_registry", - ":jit_executable", - ":logical_result", - ":results", - ":types", - "//xla/mlir/runtime/transforms:compilation_pipeline_options", - "//xla/mlir/runtime/transforms/tests:testlib_pipeline", - "//xla/mlir/runtime/utils:async_runtime_api", - "@com_google_absl//absl/base:dynamic_annotations", - "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", - ], -) - cc_library( name = "execution_engine", srcs = ["execution_engine.cc"], @@ -277,26 +242,21 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TransformUtils", + ] + if_llvm_aarch32_available([ + "@llvm-project//llvm:ARMAsmParser", + "@llvm-project//llvm:ARMCodeGen", + ]) + if_llvm_aarch64_available([ + "@llvm-project//llvm:AArch64AsmParser", + "@llvm-project//llvm:AArch64CodeGen", + ]) + if_llvm_powerpc_available([ + "@llvm-project//llvm:PowerPCAsmParser", + "@llvm-project//llvm:PowerPCCodeGen", + ]) + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZAsmParser", + "@llvm-project//llvm:SystemZCodeGen", + ]) + if_llvm_x86_available([ "@llvm-project//llvm:X86AsmParser", "@llvm-project//llvm:X86CodeGen", - ] + select({ - "@tsl//tsl:arm_any": [ - "@llvm-project//llvm:AArch64AsmParser", # fixdeps: keep - "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep - ], - "@tsl//tsl:linux_ppc64le": [ - "@llvm-project//llvm:PowerPCAsmParser", # fixdeps: keep - "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep - ], - "@tsl//tsl:macos_arm64": [ - "@llvm-project//llvm:AArch64AsmParser", # fixdeps: keep - "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep - ], - "//conditions:default": [ - ], - }) + if_llvm_system_z_available([ - "@llvm-project//llvm:SystemZAsmParser", # fixdeps: keep - "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep ]), ) diff --git a/xla/runtime/aot_ffi.cc b/xla/runtime/aot_ffi.cc index fac277a9817aa..950a0d1395a9b 100644 --- a/xla/runtime/aot_ffi.cc +++ b/xla/runtime/aot_ffi.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/runtime/aot_ffi.h b/xla/runtime/aot_ffi.h index 9800fa69aadd1..4d5a101f8a863 100644 --- a/xla/runtime/aot_ffi.h +++ b/xla/runtime/aot_ffi.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/runtime/aot_ffi_c_symbols.cc b/xla/runtime/aot_ffi_c_symbols.cc index cc355d45a38e1..7851172d4b8de 100644 --- a/xla/runtime/aot_ffi_c_symbols.cc +++ b/xla/runtime/aot_ffi_c_symbols.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/aot_ffi_c_symbols.h b/xla/runtime/aot_ffi_c_symbols.h index 202ea92f8acfe..694c880521e36 100644 --- a/xla/runtime/aot_ffi_c_symbols.h +++ b/xla/runtime/aot_ffi_c_symbols.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/aot_ffi_execution_context.h b/xla/runtime/aot_ffi_execution_context.h index 59565ee00bafb..3eeae717c5637 100644 --- a/xla/runtime/aot_ffi_execution_context.h +++ b/xla/runtime/aot_ffi_execution_context.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/runtime/arguments.cc b/xla/runtime/arguments.cc index 9ef28c42bad4b..4f69b70099a2a 100644 --- a/xla/runtime/arguments.cc +++ b/xla/runtime/arguments.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/arguments.h b/xla/runtime/arguments.h index 7604098e4efbf..527c7b32db22d 100644 --- a/xla/runtime/arguments.h +++ b/xla/runtime/arguments.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/async_runtime.cc b/xla/runtime/async_runtime.cc index 64ebf4939d369..753c0e336a9a6 100644 --- a/xla/runtime/async_runtime.cc +++ b/xla/runtime/async_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/async_runtime.h b/xla/runtime/async_runtime.h index 4e077af9ba361..9562c5c0fdeca 100644 --- a/xla/runtime/async_runtime.h +++ b/xla/runtime/async_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/async_values_cache.h b/xla/runtime/async_values_cache.h index 0a72c5806de00..03d16efa24e60 100644 --- a/xla/runtime/async_values_cache.h +++ b/xla/runtime/async_values_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/compiler.h b/xla/runtime/compiler.h index 4c045cf1b9214..68a72e4d2fe2e 100644 --- a/xla/runtime/compiler.h +++ b/xla/runtime/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/constraints.cc b/xla/runtime/constraints.cc index 5086d25a31bd2..0f70240446491 100644 --- a/xla/runtime/constraints.cc +++ b/xla/runtime/constraints.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" namespace xla { namespace runtime { diff --git a/xla/runtime/constraints.h b/xla/runtime/constraints.h index 3a0366f55adb7..a9a2d624b56d9 100644 --- a/xla/runtime/constraints.h +++ b/xla/runtime/constraints.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/cpu_event.h b/xla/runtime/cpu_event.h index e7caaf8a932ab..3611d7fd728cf 100644 --- a/xla/runtime/cpu_event.h +++ b/xla/runtime/cpu_event.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/custom_call.cc b/xla/runtime/custom_call.cc index 60bb02e7ae72b..af411d63976d0 100644 --- a/xla/runtime/custom_call.cc +++ b/xla/runtime/custom_call.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/custom_call.h b/xla/runtime/custom_call.h index aa4ff57a35801..27c64483b55d2 100644 --- a/xla/runtime/custom_call.h +++ b/xla/runtime/custom_call.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -808,7 +808,7 @@ struct Decode, checks> { if (auto decoded = DecodeUserData(ctx.user_data); LLVM_LIKELY(succeeded(decoded))) return decoded; - return ctx.diagnostic->EmitError(InternalError( + return ctx.diagnostic->EmitError(Internal( "failed to decode UserData of type %s", typeid(T).name())); } }; diff --git a/xla/runtime/custom_call_registry.cc b/xla/runtime/custom_call_registry.cc index 5926655974911..596d7c7a3ea39 100644 --- a/xla/runtime/custom_call_registry.cc +++ b/xla/runtime/custom_call_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/custom_call_registry.h b/xla/runtime/custom_call_registry.h index d82fb51257323..344d608c302c2 100644 --- a/xla/runtime/custom_call_registry.h +++ b/xla/runtime/custom_call_registry.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/custom_call_test.cc b/xla/runtime/custom_call_test.cc deleted file mode 100644 index 96d873ab5931f..0000000000000 --- a/xla/runtime/custom_call_test.cc +++ /dev/null @@ -1,2073 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/runtime/custom_call.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "xla/mlir/runtime/ir/tests/testlib.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/arguments.h" -#include "xla/runtime/async_runtime.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/diagnostics.h" -#include "xla/runtime/jit_executable.h" -#include "xla/runtime/module.h" -#include "xla/runtime/state.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" - -namespace xla { -namespace runtime { - -using absl::StatusOr; - -//===----------------------------------------------------------------------===// -// A helper function that compiles `module` to XLA runtime executable and runs -// `test` function with the given arguments. Caller can also register custom -// calls (direct or dynamic) and custom types. -//===----------------------------------------------------------------------===// - -struct CustomCallRegistry { - std::function dynamic_custom_calls; - std::function direct_custom_calls; -}; - -static absl::StatusOr Compile( - std::string_view source, const CustomCallRegistry& registry, - const CompilationPipelineOptions& copts, - const TypeConverter& type_converter = {}, - absl::Span exported = {"test"}) { - JitExecutable::Options opts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.symbols_binding = ToSymbolsBinding( - registry.direct_custom_calls, copts.populate_type_id_names); - opts.compiler.type_converter = type_converter; - - opts.compiler.register_dialects = [&](DialectRegistry& dialects) { - RegisterTestlibDialect(dialects); - RegisterDefaultXlaGpuRuntimeDialects(dialects); - }; - - opts.compiler.create_compilation_pipeline = [=](PassManager& passes) { - CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts, - /*add_async_passes=*/true); - }; - - return JitExecutable::Instantiate(source, opts, exported); -} - -static absl::Status CompileAndExecute( - std::string_view source, ArgumentsRef args, - const CustomCallRegistry& registry, - const CompilationPipelineOptions& copts = {}, - const TypeConverter& type_converter = {}, - absl::Span exported = {"test"}, - CustomCall::UserData user_data = {}) { - StatusOr jit_executable = - Compile(source, registry, copts, type_converter, exported); - if (!jit_executable.ok()) return jit_executable.status(); - - AsyncValuePtr executable = jit_executable->DefaultExecutable(); - if (executable.IsError()) - return absl::InternalError(executable.GetError().message()); - - // Register all dynamic custom calls. - DynamicCustomCallRegistry dynamic_custom_calls; - if (registry.dynamic_custom_calls) - registry.dynamic_custom_calls(dynamic_custom_calls); - - // Always add a pointer to `self` to user data. - user_data.insert(&executable.get()); - - // Collect all emitted diangostics to a string; - std::string error; - DiagnosticEngine diagnostic_engine; - diagnostic_engine.AddHandler([&](Diagnostic& diagnostic) -> LogicalResult { - error.append(diagnostic.status().message()); - return success(); - }); - - Executable::ExecuteOpts execute_opts; - execute_opts.custom_call_registry = &dynamic_custom_calls; - execute_opts.diagnostic_engine = &diagnostic_engine; - execute_opts.custom_call_data = &user_data; - execute_opts.async_task_runner = - reinterpret_cast(0XDEADBEEF); - - // We do not support returning results from tests. - NoResultConverter converter; - - auto executed = executable->Execute(args, converter, execute_opts); - if (!executed.ok()) - return absl::InternalError( - absl::StrFormat("%s: %s", executed.status().message(), error)); - - return absl::OkStatus(); -} - -template -static absl::StatusOr> CompileAndExecute( - std::string_view source, ArgumentsRef args, const StatefulModule& m, - absl::Span exported = {"test"}) { - CustomCallRegistry registry = { - [&](DynamicCustomCallRegistry& registry) { m.Export(registry); }, - [&](DirectCustomCallRegistry& registry) { m.Export(registry); }, - }; - auto state = m.CreateModuleState(); - if (!state.ok()) return state.status(); - - CustomCall::UserData user_data; - auto initialized = m.InitializeUserData(state->get(), user_data); - if (!initialized.ok()) return initialized; - - auto executed = CompileAndExecute(source, args, registry, /*copts=*/{}, - /*type_converter=*/{}, exported, user_data); - if (!executed.ok()) return executed; - - return state; -} - -// No-Op custom call with a single `i32` argument. -static void I32NoOp(DynamicCustomCallRegistry& registry) { - registry.Register( - CustomCall::Bind("test.custom_call").Arg().To([](int32_t) { - return success(); - })); -} - -//===----------------------------------------------------------------------===// -// A test for stateful module with a direct custom call. -//===----------------------------------------------------------------------===// - -struct Counter : public Module::State { - int32_t value = 0; -}; - -// Package custom call that updates a `Counter` as a runtime module. -struct CounterModule : public StatefulModule { - CounterModule() : StatefulModule("counter") {} - - static bool Inc(ExecutionContext* e, void** args, void** attrs, void** rets) { - auto impl = CustomCall::Bind("test.increment") - .UserData() // counter - .Arg() // value - .To([](Counter* counter, int32_t value) { - return success(counter->value += value); - }); - return succeeded(Executable::Call(e, *impl, args, attrs, rets)); - } - - void Export(DirectCustomCallRegistry& registry) const final { - registry.Register("test.increment", Inc); - } - - absl::StatusOr> CreateModuleState() const final { - return std::make_unique(); - } - - absl::Status InitializeUserData(Counter* state, - CustomCall::UserData& user_data) const final { - user_data.insert(state); - return absl::OkStatus(); - } -}; - -TEST(CustomCallTest, DirectCustomCall) { - absl::string_view source = R"( - func.func private @increment(%arg0: i32) - attributes { rt.custom_call = "test.increment" } - - func.func @test() { - %0 = arith.constant 42 : i32 - call @increment(%0) : (i32) -> () - return - } - )"; - - auto counter = CompileAndExecute(source, /*args=*/{}, CounterModule()); - ASSERT_TRUE(counter.ok()); - EXPECT_EQ((*counter)->value, 42); -} - -//===----------------------------------------------------------------------===// -// All other tests use dynamic custom calls and do not use modules. -//===----------------------------------------------------------------------===// - -TEST(CustomCallTest, ScalarArgs) { - absl::string_view source = R"( - func.func private @custom_call(%arg0: i1, %arg1: i32, %arg2: i64, - %arg3: f32, %arg4: f64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant false - %1 = arith.constant 42 : i32 - %2 = arith.constant 42 : i64 - %3 = arith.constant 42.0 : f32 - %4 = arith.constant 42.0 : f64 - call @custom_call(%0, %1, %2, %3, %4) : (i1, i32, i64, f32, f64) -> () - return - } - )"; - - bool i1 = true; - int32_t i32 = 0; - int64_t i64 = 0; - float f32 = 0.0; - double f64 = 0.0; - - auto f = [&](bool arg0, int32_t arg1, int64_t arg2, float arg3, double arg4) { - (i1 = arg0, i32 = arg1, i64 = arg2, f32 = arg3, f64 = arg4); - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - - EXPECT_EQ(i1, false); - EXPECT_EQ(i32, 42); - EXPECT_EQ(i64, 42); - EXPECT_EQ(f32, 42.0); - EXPECT_EQ(f64, 42.0); -} - -TEST(CustomCallTest, ScalarRets) { - absl::string_view source = R"( - func.func private @custom_call_result() -> (i1, i32, i64, f32, f64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call_result" } - - func.func private @custom_call(%arg0: i1, %arg1: i32, %arg2: i64, - %arg3: f32, %arg4: f64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0, %1, %2, %3, %4 = call @custom_call_result() - : () -> (i1, i32, i64, f32, f64) - call @custom_call(%0, %1, %2, %3, %4) : (i1, i32, i64, f32, f64) -> () - return - } - )"; - - bool i1 = true; - int32_t i32 = 0; - int64_t i64 = 0; - float f32 = 0.0; - double f64 = 0.0; - - auto f_result = [&](Result ret0, Result ret1, - Result ret2, Result ret3, - Result ret4) { - ret0.Set(false); - ret1.Set(42); - ret2.Set(42); - ret3.Set(42.0); - ret4.Set(42.0); - return success(); - }; - - auto f = [&](bool arg0, int32_t arg1, int64_t arg2, float arg3, double arg4) { - (i1 = arg0, i32 = arg1, i64 = arg2, f32 = arg3, f64 = arg4); - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_result") - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .To(f_result)); - - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - - EXPECT_EQ(i1, false); - EXPECT_EQ(i32, 42); - EXPECT_EQ(i64, 42); - EXPECT_EQ(f32, 42.0); - EXPECT_EQ(f64, 42.0); -} - -TEST(CustomCallTest, StatusOrRet) { - absl::string_view source = R"( - func.func private @custom_call_return(%arg0: i32) -> (i64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - func.func private @custom_call(%arg64 : i64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 42 : i32 - %1 = call @custom_call_return(%0) : (i32) -> (i64) - call @custom_call(%1) : (i64) -> () - return - } - )"; - - int64_t i64 = 0; - auto f_result = [](int32_t arg) -> absl::StatusOr { return arg; }; - auto f = [&](int64_t arg) { - i64 = arg; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Arg() - .Ret() - .To(f_result)); - - registry.Register( - CustomCall::Bind("test.custom_call").Arg().To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(i64, 42); -} - -TEST(CustomCallTest, StatusOrAsyncToken) { - absl::string_view source = R"( - func.func private @custom_call_return() -> !async.token - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - func.func @test() { - %0 = call @custom_call_return() : () -> !async.token - async.await %0 : !async.token - return - } - )"; - - auto f_result = []() -> absl::StatusOr> { - return tsl::MakeAvailableAsyncValueRef(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Ret>() - .To(f_result)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); -} - -TEST(CustomCallTest, StatusOrAsyncScalarValue) { - absl::string_view source = R"( - func.func private @custom_call_return() -> !async.value - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - func.func private @custom_call(%arg32 : i32) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = call @custom_call_return() : () -> !async.value - %1 = async.await %0 : !async.value - call @custom_call(%1) : (i32) -> () - return - } - )"; - - auto f_result = []() -> absl::StatusOr> { - return tsl::MakeAvailableAsyncValueRef(42); - }; - - int32_t i32 = 0; - auto f = [&](int32_t arg) { - i32 = arg; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Ret>() - .To(f_result)); - - registry.Register( - CustomCall::Bind("test.custom_call").Arg().To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(i32, 42); -} - -TEST(CustomCallTest, StatusOrTupleRets) { - absl::string_view source = R"( - func.func private @custom_call_return(%arg0 : i64, %arg1 : i64) -> (i64, - i64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - func.func private @custom_call(%arg0 : i64, %arg1 : i64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 42 : i64 - %1 = arith.constant 43 : i64 - %2, %3 = call @custom_call_return(%0, %1) : (i64, i64) -> (i64, i64) - call @custom_call(%2, %3) : (i64, i64) -> () - return - } - )"; - - int64_t a = 0; - int64_t b = 0; - auto f_result = - [](int64_t arg0, - int64_t arg1) -> absl::StatusOr> { - return std::make_tuple(arg0, arg1); - }; - auto f = [&](int64_t arg0, int64_t arg1) { - a = arg0; - b = arg1; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Arg() - .Ret() - .Arg() - .Ret() - .To(f_result)); - - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() - .Arg() - .To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(a, 42); - EXPECT_EQ(b, 43); -} - -TEST(CustomCallTest, OpaqueArgs) { - absl::string_view source = R"( - func.func private @use(%arg0: !rt.opaque) - attributes { rt.dynamic, rt.custom_call = "test.use" } - - func.func @test(%arg0: !rt.opaque) { - call @use(%arg0) : (!rt.opaque) -> () - return - } - )"; - - // We'll pass around an opaque pointer to this string in our custom calls. - std::string message = ""; - - auto use = [&](void* arg0) { - std::string* str = reinterpret_cast(arg0); - (*str) += "foo"; - return success(); - }; - - OpaqueArg arg0(&message); - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.use").Arg().To(use)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, {arg0}, registry).ok()); - EXPECT_EQ(message, "foo"); -} - -TEST(CustomCallTest, OpaqueArgsAndRets) { - absl::string_view source = R"( - func.func private @make() -> (!rt.opaque) - attributes { rt.dynamic, rt.custom_call = "test.make" } - - func.func private @use(%arg0: !rt.opaque) - attributes { rt.dynamic, rt.custom_call = "test.use" } - - func.func @test() { - %0 = call @make() : () -> (!rt.opaque) - call @use(%0) : (!rt.opaque) -> () - return - } - )"; - - // We'll pass around an opaque pointer to this string in our custom calls. - std::string message = ""; - - auto make = [&](Result res) { - res.Set(&message); - return success(); - }; - - auto use = [&](void* arg0) { - std::string* str = reinterpret_cast(arg0); - (*str) += "foo"; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.make").Ret().To(make)); - registry.Register(CustomCall::Bind("test.use").Arg().To(use)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(message, "foo"); -} - -// Instead of passing a pointer to value of underlying type we pass it wrapped -// into a typed reference, for example this would allow to automatically cast -// type-erased `AsyncValue *` to typed `AsyncValuePtr`. -struct ValueRef { - std::string* value = nullptr; -}; - -// Register decoding for `ValueRef` (!testlib.value) arguments and results. -XLA_RUNTIME_REGISTER_OPAQUE_ARG_DECODING(ValueRef, std::string*); -XLA_RUNTIME_REGISTER_OPAQUE_RET_DECODING(ValueRef, std::string*); - -// Register mapping from custom type id to its unique symbol name. -static void RegisterTypeName(TypeIDNameRegistry& registry) { - registry.Register>("__type_id_testlib_value"); -} - -// Register custom call argument encoding for a custom value type. -static void RegisterArgEncoding(CustomCallArgEncodingSet& encoding) { - encoding.Add(OpaqueArgEncoding::Match(), - TypeID::get>()); -} - -// Register custom call result encoding for a custom value type. -static void RegisterRetEncoding(CustomCallRetEncodingSet& encoding) { - encoding.Add(OpaqueRetEncoding::Match(), - TypeID::get>()); -} - -// Conversion from argument compile-time type to the argument run-time types. -static std::unique_ptr ConvertArgTypeToOpaqueArg(ValueType arg) { - return std::make_unique(); -} - -// Compilation pipeline options with `testlib` and custom args/rets support. -CompilationPipelineOptions TestlibCopts() { - CompilationPipelineOptions copts; - copts.populate_type_id_names = RegisterTypeName; - copts.populate_arg_encodings = RegisterArgEncoding; - copts.populate_ret_encodings = RegisterRetEncoding; - copts.populate_type_conversions = AddTestlibTypeConversions; - return copts; -} - -TEST(CustomCallTest, CustomArgAsOpaqueArg) { - absl::string_view source = R"( - func.func private @use(%arg0: !testlib.value) - attributes { rt.dynamic, rt.custom_call = "test.use" } - - func.func @test(%arg0: !testlib.value) { - call @use(%arg0) : (!testlib.value) -> () - return - } - )"; - - // We'll pass around an opaque pointer to this string in our custom calls. - std::string message = ""; - - auto use = [&](ValueRef arg0) { - (*arg0.value) += "foo"; - return success(); - }; - - OpaqueArg arg0(&message); - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.use").Arg().To(use)); - }}; - - CompilationPipelineOptions copts = TestlibCopts(); - TypeConverter type_converter(ConvertArgTypeToOpaqueArg); - - ASSERT_TRUE( - CompileAndExecute(source, {arg0}, registry, copts, type_converter).ok()); - EXPECT_EQ(message, "foo"); -} - -// In the test above we relied on the conversion of custom argument type to -// opaque type and opaque argument. In this test we introduce a custom type and -// argument to preserve the type information at run time. -struct ValueArgType : public llvm::RTTIExtends { - static constexpr char ID = 0; // NOLINT - StatusOr AsArgument() const final { return ArgumentAbi{1}; } - std::string ToString() const final { return "!testlib.value"; } -}; - -// Value argument passed as a single pointer to the XLA executable. -struct ValueArg final : public llvm::RTTIExtends { - static constexpr char ID = 0; // NOLINT - - explicit ValueArg(std::string* ptr) : ptr(ptr) {} - - absl::Status Verify(const Type& type) const final { - return llvm::isa(type) - ? absl::OkStatus() - : absl::InvalidArgumentError("unsupported type"); - } - - void Pack(absl::Span args) const final { - args[0] = const_cast(reinterpret_cast(&ptr)); - } - - std::string ToString() const final { return "!testlib.value"; } - - std::string* ptr; -}; - -// Converts `!testlib.value` type to the `ValueArgType` run-time type. -static std::unique_ptr ConvertArgTypeToValueArg(ValueType arg) { - return std::make_unique(); -} - -TEST(CustomCallTest, CustomArg) { - absl::string_view source = R"( - func.func private @use(%arg0: !testlib.value) - attributes { rt.dynamic, rt.custom_call = "test.use" } - - func.func @test(%arg0: !testlib.value) { - call @use(%arg0) : (!testlib.value) -> () - return - } - )"; - - // We'll pass around an opaque pointer to this string in our custom calls. - std::string message = ""; - - auto use = [&](ValueRef arg0) { - (*arg0.value) += "bar"; - return success(); - }; - - ValueArg arg0(&message); - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.use").Arg().To(use)); - }}; - - CompilationPipelineOptions copts = TestlibCopts(); - TypeConverter type_converter(ConvertArgTypeToValueArg); - - ASSERT_TRUE( - CompileAndExecute(source, {arg0}, registry, copts, type_converter).ok()); - EXPECT_EQ(message, "bar"); -} - -TEST(CustomCallTest, CustomArgsAndRets) { - absl::string_view source = R"( - func.func private @make() -> (!testlib.value) - attributes { rt.dynamic, rt.custom_call = "test.make" } - - func.func private @use(%arg0: !testlib.value) - attributes { rt.dynamic, rt.custom_call = "test.use" } - - func.func @test() { - %0 = call @make() : () -> (!testlib.value) - call @use(%0) : (!testlib.value) -> () - return - } - )"; - - // Our `!testlib.value` type at run time will be just a pointer to a string, - // and it will be encoded similar to the `!rt.opaque` test above. - std::string message = ""; - - auto make = [&](Result res) { - res.Set(&message); - return success(); - }; - - auto use = [&](ValueRef arg0) { - (*arg0.value) += "foo"; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.make").Ret().To(make)); - registry.Register(CustomCall::Bind("test.use").Arg().To(use)); - }}; - - CompilationPipelineOptions copts = TestlibCopts(); - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, copts).ok()); - EXPECT_EQ(message, "foo"); -} - -TEST(CustomCallTest, MemRefRets) { - absl::string_view source = R"( - func.func private @custom_call_result() -> memref<2x2xf32> - attributes { rt.dynamic, rt.custom_call = "test.custom_call_result" } - - func.func private @custom_call(%arg0: memref<2x2xf32>) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = call @custom_call_result() : () -> (memref<2x2xf32>) - call @custom_call(%0) : (memref<2x2xf32>) -> () - return - } - )"; - - // Allocate storage for arguments. - std::vector input = {1.0, 2.0, 3.0, 4.0}; - - // Observe returned memref by capturing memref argument shape and data. - std::vector arg_shape; - std::vector arg_data; - - auto f_result = [&](Result ret0) { - std::vector dims = {ret0.GetDims().begin(), ret0.GetDims().end()}; - ret0.Set({ret0.GetDType(), input.data(), dims}); - return success(); - }; - - auto f = [&](MemrefView arg0) { - absl::Span data = {reinterpret_cast(arg0.data), 4}; - arg_shape = {arg0.sizes.begin(), arg0.sizes.end()}; - arg_data = {data.begin(), data.end()}; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_result") - .Ret() // ret0 - .To(f_result)); - - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() // arg0 - .To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(arg_shape, std::vector({2, 2})); - EXPECT_EQ(arg_data, input); -} - -TEST(CustomCallTest, AsyncMemRefRets) { - absl::string_view source = R"( - func.func private @custom_call_result() -> !async.value> - attributes { rt.dynamic, rt.custom_call = "test.custom_call_result" } - - func.func private @custom_call(%arg0: memref<2x2xf32>) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = call @custom_call_result() : () -> (!async.value>) - %1 = async.await %0 : !async.value> - call @custom_call(%1) : (memref<2x2xf32>) -> () - return - } - )"; - - // Allocate storage for arguments. - std::vector input = {1.0, 2.0, 3.0, 4.0}; - - // Observe returned memref by capturing memref argument shape and data. - std::vector arg_shape; - std::vector arg_data; - - auto f_result = [&](Result> ret0) { - std::vector dims = {ret0.GetDims().begin(), ret0.GetDims().end()}; - auto async_value = tsl::MakeAvailableAsyncValueRef( - ret0.GetDType(), input.data(), dims); - ret0.Set(async_value); - return success(); - }; - - auto f = [&](MemrefView arg0) { - llvm::ArrayRef data = {reinterpret_cast(arg0.data), 4}; - arg_shape = {arg0.sizes.begin(), arg0.sizes.end()}; - arg_data = {data.begin(), data.end()}; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_result") - .Ret>() // ret0 - .To(f_result)); - - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() // arg0 - .To(f)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, {}, registry).ok()); - EXPECT_EQ(arg_shape, std::vector({2, 2})); - EXPECT_EQ(arg_data, input); -} - -TEST(CustomCallTest, ArgSizeCheck) { - // Try to pass two argument to a custom call that expects one. - absl::string_view source = R"( - func.func private @custom_call(%arg0: i32, %arg1: i32) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 42 : i32 - call @custom_call(%0, %0) : (i32, i32) -> () - return - } - )"; - - std::string error = ""; - - CustomCallRegistry registry = {I32NoOp}; - - auto status = CompileAndExecute(source, /*args=*/{}, registry); - EXPECT_FALSE(status.ok()); - EXPECT_EQ(status.message(), - "run time error: custom call 'test.custom_call' failed: Wrong " - "number of arguments: expected 1 got 2"); -} - -TEST(CustomCallTest, ArgTypeCheck) { - // Try to pass `i64` argument to a custom call that expects `i32`. - absl::string_view source = R"( - func.func private @custom_call(%arg1: i64) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 42 : i64 - call @custom_call(%0) : (i64) -> () - return - } - )"; - - std::string error = ""; - - CustomCallRegistry registry = {I32NoOp}; - - auto status = CompileAndExecute(source, /*args=*/{}, registry); - EXPECT_FALSE(status.ok()); - EXPECT_EQ(status.message(), - "run time error: custom call 'test.custom_call' failed: Failed to " - "decode all custom call operands (bad operads at: 0)"); -} - -// Register custom call attribute decoding for `testlib.enum_type`. -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(EnumType); - -TEST(CustomCallTest, EnumAttr) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { enum = #testlib.enum_type }: () -> () - return - } - )"; - - std::vector enums; - - auto handler = [&](EnumType value) -> LogicalResult { - enums.push_back(value); - return success(); - }; - - auto types = [](TypeIDNameRegistry& registry) { - registry.Register>("__type_id_testlib_enum"); - }; - - auto attrs = [](CustomCallAttrEncodingSet& encoding) { - encoding.Add>(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr("enum") - .To(handler)); - }}; - - CompilationPipelineOptions copts; - copts.populate_type_id_names = types; - copts.populate_attr_encodings = attrs; - - EXPECT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, copts).ok()); - ASSERT_EQ(enums.size(), 1); - EXPECT_EQ(enums.front(), EnumType::Baz); -} - -// Map enum defined by MLIR to a custom enum class. -enum class MyEnumType : uint32_t { kFoo, kBar, kBaz }; - -MyEnumType FromEnumType(EnumType value) { - switch (value) { - case EnumType::Foo: - return MyEnumType::kFoo; - case EnumType::Bar: - return MyEnumType::kBar; - case EnumType::Baz: - return MyEnumType::kBaz; - } -} - -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(MyEnumType); - -TEST(CustomCallTest, MappedEnumAttr) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { enum = #testlib.enum_type }: () -> () - return - } - )"; - - std::vector enums; - - auto handler = [&](MyEnumType value) -> LogicalResult { - enums.push_back(value); - return success(); - }; - - auto types = [](TypeIDNameRegistry& registry) { - registry.Register>("__type_id_my_enum"); - }; - - auto attrs = [](CustomCallAttrEncodingSet& encoding) { - encoding.Add>( - FromEnumType); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr("enum") - .To(handler)); - }}; - - CompilationPipelineOptions copts; - copts.populate_type_id_names = types; - copts.populate_attr_encodings = attrs; - - EXPECT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, copts).ok()); - ASSERT_EQ(enums.size(), 1); - EXPECT_EQ(enums.front(), MyEnumType::kBaz); -} - -// Structure corresponding to the MLIR attribute. -struct PairOfDims { - int64_t rank; - absl::Span a; - absl::Span b; -}; - -// Register aggregate attribute decoding. -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - PairOfDims, AggregateMember("rank"), - AggregateMember>("a"), - AggregateMember>("b")); - -TEST(CustomCallTest, StructAttr) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { - dims = #testlib.pair_of_dims<2, [1, 1], [2, 2]> - }: () -> () - return - } - )"; - - int64_t rank = 0; - std::vector a; - std::vector b; - - auto handler = [&](PairOfDims value) -> LogicalResult { - rank = value.rank; - a.assign(value.a.begin(), value.a.end()); - b.assign(value.b.begin(), value.b.end()); - return success(); - }; - - auto types = [](TypeIDNameRegistry& registry) { - registry.Register>("__type_id_pair_of_dims"); - }; - - auto attrs = [](CustomCallAttrEncodingSet& encoding) { - encoding.Add>( - encoding, AggregateAttrDef() - .Add("rank", &PairOfDimsAttr::getRank) - .Add("a", &PairOfDimsAttr::getA) - .Add("b", &PairOfDimsAttr::getB)); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr("dims") - .To(handler)); - }}; - - CompilationPipelineOptions copts; - copts.populate_type_id_names = types; - copts.populate_attr_encodings = attrs; - - EXPECT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, copts).ok()); - EXPECT_EQ(rank, 2); - EXPECT_EQ(a, std::vector(2, 1)); - EXPECT_EQ(b, std::vector(2, 2)); -} - -TEST(CustomCallTest, FunctionOrdinalAttr) { - using FunctionOrdinal = CustomCall::FunctionOrdinal; - - absl::string_view source = R"( - func.func private @init() - attributes { rt.dynamic, rt.custom_call = "test.init" } - - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - // We use a nested call to `@init` custom call as a simple way of proving - // that `@call_init` was called from `@custom_call` handler. - func.func @call_init() { - call @init() : () -> () - return - } - - func.func @test() { - call @custom_call() { func = @call_init }: () -> () - return - } - )"; - - bool called_init = false; - - // Custom call handler for `@init` custom call. - auto init = [&]() { - called_init = true; - return success(); - }; - - // Dynamic custom call registry for resolving nested custom calls. - DynamicCustomCallRegistry nested_registry; - nested_registry.Register(CustomCall::Bind("test.init").To(init)); - - // Execute options for nested custom calls. - Executable::ExecuteOpts execute_opts; - execute_opts.custom_call_registry = &nested_registry; - execute_opts.async_task_runner = - reinterpret_cast(0XDEADBEEF); - - // Custom call handler for `@custom_call` custom call. - auto handler = [&](Executable* executable, FunctionOrdinal exported) { - FunctionRef fn = executable->function_ref(exported.ordinal); - return success(fn({}, NoResultConverter{}, execute_opts).ok()); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.init").To(init)); - registry.Register(CustomCall::Bind("test.custom_call") - .UserData() - .Attr("func") - .To(handler)); - }}; - - std::vector exported = {"test", "call_init"}; - EXPECT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, /*copts=*/{}, - /*type_converter=*/{}, exported) - .ok()); - EXPECT_TRUE(called_init); -} - -TEST(CustomCallTest, OptionalAttr) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { attr0, attr1 = 42 : i64 }: () -> () - return - } - )"; - - std::vector> attrs; - - auto handler = [&](std::optional attr0, - std::optional attr1) -> LogicalResult { - attrs.push_back(attr0); - attrs.push_back(attr1); - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr>("attr0") - .Attr>("attr1") - .To(handler)); - }}; - - EXPECT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - ASSERT_EQ(attrs.size(), 2); - EXPECT_EQ(attrs[0], std::nullopt); - EXPECT_EQ(attrs[1], 42); -} - -TEST(CustomCallTest, StateArg) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { id = 0 : i64 } : () -> () - return - } - )"; - - auto handler = [](int64_t id, State state0, State state1) { - state0.GetOrCreate([] { return 42; }).IgnoreError(); - state1.GetOrCreate([] { return 42; }).IgnoreError(); - return success(); - }; - - StateVector state_i32; - StateVector state_i64; - - StateVector::Snapshot snapshot_i32 = state_i32.snapshot(); - StateVector::Snapshot snapshot_i64 = state_i64.snapshot(); - CustomCall::UserData user_data(&snapshot_i32, &snapshot_i64); - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr("id") - .State("id") - .State("id") - .To(handler)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry, /*copts=*/{}, - /*type_converter=*/{}, {"test"}, user_data) - .ok()); - ASSERT_EQ(*state_i32[0], 42); - ASSERT_EQ(*state_i64[0], 42); -} - -TEST(CustomCallTest, DictionaryAttr) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { - dict = { foo = "Uh oh", bar = 42 : i32, baz = array } - }: () -> () - return - } - )"; - - std::string foo; - int32_t bar = 0; - std::vector baz; - std::vector dictionary_keys; - - auto handler = [&](Dictionary dict) -> LogicalResult { - if (dict.size() != 3) return failure(); - - foo = *dict.get("foo"); - bar = *dict.get("bar"); - auto span = dict.get>("baz"); - baz = std::vector(span->begin(), span->end()); - - // Need to copy to vector of strings since strings string_view points to - // will no longer exist once this runs. - for (auto key : dict.keys()) { - dictionary_keys.push_back(std::string(key)); - } - - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Attr("dict") - .To(handler)); - }}; - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, registry).ok()); - EXPECT_EQ(foo, "Uh oh"); - EXPECT_EQ(bar, 42); - EXPECT_EQ(baz, std::vector({1, 2})); - EXPECT_EQ(dictionary_keys, std::vector({"bar", "baz", "foo"})); -} - -TEST(CustomCallTest, MemrefF8Arg) { - absl::string_view source = R"( - func.func private @custom_call(%arg0: memref) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test(%arg0: memref) { - call @custom_call(%arg0) : (memref) -> () - return - } - )"; - - xla::PrimitiveType dtype = xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; - std::vector sizes; - - auto handler = [&](StridedMemrefView arg0) { - dtype = arg0.dtype; - sizes.assign(arg0.sizes.begin(), arg0.sizes.end()); - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() - .To(handler)); - }}; - - std::vector data(42); - MemrefDesc arg0(PrimitiveType::F8E4M3FN, data.data(), 0, {42}, {1}); - - Arguments args(1); - args.emplace_back(std::move(arg0)); - - ASSERT_TRUE(CompileAndExecute(source, args, registry).ok()); - EXPECT_EQ(dtype, PrimitiveType::F8E4M3FN); - EXPECT_EQ(sizes.size(), 1); - EXPECT_EQ(sizes[0], 42); -} - -TEST(CustomCallTest, MemrefDynamicOffset) { - absl::string_view source = R"( - func.func private @custom_call(%arg: memref<1xi32, strided<[1], offset: ?>>) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - func.func @test(%arg0: memref<8xi32>, %arg1: i32) { - %0 = arith.index_castui %arg1 : i32 to index - %1 = memref.subview %arg0[%0] [1] [1] - : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>> - call @custom_call(%1) : (memref<1xi32, strided<[1], offset: ?>>) -> () - return - } - )"; - - int32_t value = 0; - - auto handler = [&](StridedMemrefView arg0) { - value = reinterpret_cast(arg0.data)[0]; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call") - .Arg() - .To(handler)); - }}; - - std::vector data(8); - std::iota(data.begin(), data.end(), 0); - MemrefDesc arg0(PrimitiveType::S32, data.data(), 0, {8}, {1}); - - Arguments args(2); - args.emplace_back(std::move(arg0)); - args.emplace_back(int32_t{4}); - - ASSERT_TRUE(CompileAndExecute(source, args, registry).ok()); - EXPECT_EQ(value, 4); -} - -//===----------------------------------------------------------------------===// -// Performance benchmarks are below. -//===----------------------------------------------------------------------===// - -namespace bm = ::testing::benchmark; - -using DirectCustomCall = DirectCustomCallRegistry::DirectCustomCall; -using RuntimeChecks = CustomCall::RuntimeChecks; - -// Give short aliases to enums for benchmarks pretty printing. -static constexpr RuntimeChecks all = RuntimeChecks::kDefault; -static constexpr RuntimeChecks less = RuntimeChecks::kLess; -static constexpr RuntimeChecks none = RuntimeChecks::kNone; - -static void BenchmarkCustomCall( - bm::State& state, std::string_view module, ArgumentsRef args, - std::string_view name, DirectCustomCall custom_call, - std::function types = {}, - std::function attrs = {}, - const CustomCall::UserData& user_data = {}) { - CustomCallRegistry registry; - - // Wrap benchmarked custom call into a direct custom call registry. - registry.direct_custom_calls = [&](DirectCustomCallRegistry& registry) { - registry.Register(name, custom_call); - }; - - CompilationPipelineOptions copts; - copts.populate_type_id_names = std::move(types); - copts.populate_attr_encodings = std::move(attrs); - - StatusOr jit_executable = Compile(module, registry, copts); - CHECK(jit_executable.ok()) << jit_executable.status(); - - AsyncValuePtr executable = jit_executable->DefaultExecutable(); - CHECK(!executable.IsError()) << executable.GetError().message(); - - // Prepare the call frame outside of a benchmark loop. - Executable::CallFrame call_frame; - CHECK(executable->InitializeCallFrame(args, &call_frame).ok()); - - Executable::ExecuteOpts execute_opts; - execute_opts.custom_call_data = &user_data; - execute_opts.async_task_runner = - reinterpret_cast(0XDEADBEEF); - - DiagnosticEngine diagnostic_engine; - execute_opts.diagnostic_engine = &diagnostic_engine; - - for (auto _ : state) { - call_frame.args[0] = nullptr; // reset execution context - executable->Execute(call_frame, execute_opts); - CHECK(!call_frame.is_error) << call_frame.error; - } -} - -//===----------------------------------------------------------------------===// -// Custom call with a single i32 argument. -//===----------------------------------------------------------------------===// - -template -static bool I32X1(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("test.custom_call") - .Arg() - .To([](int32_t arg0) { - benchmark::DoNotOptimize(arg0); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void I32X1(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call(%arg0: i32) - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 0 : i32 - call @custom_call(%0) : (i32) -> () - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", &I32X1); -} - -static void BM_I32X1All(bm::State& s) { I32X1(s); } -static void BM_I32X1None(bm::State& s) { I32X1(s); } - -BENCHMARK(BM_I32X1All); -BENCHMARK(BM_I32X1None); - -//===----------------------------------------------------------------------===// -// Custom call with twelve i32 argument. -//===----------------------------------------------------------------------===// - -template -static bool I32X12(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .Arg() - .To([](int32_t arg0, int32_t arg1, int32_t arg2, int32_t arg3, - int32_t arg4, int32_t arg5, int32_t arg6, int32_t arg7, - int32_t arg8, int32_t arg9, int32_t arg10, - int32_t arg11) { - benchmark::DoNotOptimize(arg0 + arg1 + arg2 + arg3 + arg4 + arg5 + - arg6 + arg7 + arg8 + arg9 + arg10 + arg11); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void I32X12(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call(%arg0: i32, %arg1: i32, %arg2: i32, - %arg3: i32, %arg4: i32, %arg5: i32, - %arg6: i32, %arg7: i32, %arg8: i32, - %arg9: i32, %arg10: i32, %arg11: i32) - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = arith.constant 0 : i32 - %1 = arith.constant 1 : i32 - %2 = arith.constant 2 : i32 - %3 = arith.constant 3 : i32 - %4 = arith.constant 4 : i32 - %5 = arith.constant 5 : i32 - %6 = arith.constant 6 : i32 - %7 = arith.constant 7 : i32 - %8 = arith.constant 8 : i32 - %9 = arith.constant 9 : i32 - %10 = arith.constant 10 : i32 - %11 = arith.constant 11 : i32 - call @custom_call(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11) - : (i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> () - func.return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", &I32X12); -} - -static void BM_I32X12All(bm::State& s) { I32X12(s); } -static void BM_I32X12None(bm::State& s) { I32X12(s); } - -BENCHMARK(BM_I32X12All); -BENCHMARK(BM_I32X12None); - -//===----------------------------------------------------------------------===// -// Custom call with a single i32 result. -//===----------------------------------------------------------------------===// - -template -static bool RetI32X1(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .Ret() - .To([]() -> absl::StatusOr { return 42; }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void RetI32X1(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() -> i32 - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = call @custom_call() : () -> (i32) - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", &RetI32X1); -} - -static void BM_RetI32X1All(bm::State& s) { RetI32X1(s); } -static void BM_RetI32X1None(bm::State& s) { RetI32X1(s); } - -BENCHMARK(BM_RetI32X1All); -BENCHMARK(BM_RetI32X1None); - -//===----------------------------------------------------------------------===// -// Custom call with twelve i32 results. -//===----------------------------------------------------------------------===// - -template -static bool RetI32X12(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .Ret() - .To( - []() -> absl::StatusOr> { - return std::make_tuple(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void RetI32X12(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() - -> (i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32) - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 = call @custom_call() - : () -> (i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32) - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &RetI32X12); -} - -static void BM_RetI32X12All(bm::State& s) { RetI32X12(s); } -static void BM_RetI32X12None(bm::State& s) { RetI32X12(s); } - -BENCHMARK(BM_RetI32X12All); -BENCHMARK(BM_RetI32X12None); - -//===----------------------------------------------------------------------===// -// Custom call with a single memref argument. -//===----------------------------------------------------------------------===// - -using Flat = FlatMemrefView; -using Strided = StridedMemrefView; - -template -static bool MemrefX1(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("test.custom_call") - .Arg() - .template To([](MemrefType arg0) { - benchmark::DoNotOptimize(arg0); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void MemrefX1(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call(%arg0: memref<4x4xf32>) - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = memref.alloca() : memref<4x4xf32> - call @custom_call(%0) : (memref<4x4xf32>) -> () - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &MemrefX1); -} - -static void BM_FlatMemrefX1All(bm::State& s) { MemrefX1(s); } -static void BM_FlatMemrefX1None(bm::State& s) { MemrefX1(s); } -static void BM_MemrefX1All(bm::State& s) { MemrefX1(s); } -static void BM_MemrefX1None(bm::State& s) { MemrefX1(s); } -static void BM_StridedMemrefX1All(bm::State& s) { MemrefX1(s); } -static void BM_StridedMemrefX1None(bm::State& s) { MemrefX1(s); } - -BENCHMARK(BM_FlatMemrefX1All); -BENCHMARK(BM_FlatMemrefX1None); - -BENCHMARK(BM_MemrefX1All); -BENCHMARK(BM_MemrefX1None); - -BENCHMARK(BM_StridedMemrefX1All); -BENCHMARK(BM_StridedMemrefX1None); - -//===----------------------------------------------------------------------===// -// Custom call with twelve memref argument. -//===----------------------------------------------------------------------===// - -template -static bool MemrefX12(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template Arg() - .template To( - [](MemrefType arg0, MemrefType arg1, MemrefType arg2, - MemrefType arg3, MemrefType arg4, MemrefType arg5, - MemrefType arg6, MemrefType arg7, MemrefType arg8, - MemrefType arg9, MemrefType arg10, MemrefType arg11) { - benchmark::DoNotOptimize(arg0); - benchmark::DoNotOptimize(arg1); - benchmark::DoNotOptimize(arg2); - benchmark::DoNotOptimize(arg3); - benchmark::DoNotOptimize(arg4); - benchmark::DoNotOptimize(arg5); - benchmark::DoNotOptimize(arg6); - benchmark::DoNotOptimize(arg7); - benchmark::DoNotOptimize(arg8); - benchmark::DoNotOptimize(arg9); - benchmark::DoNotOptimize(arg10); - benchmark::DoNotOptimize(arg11); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void MemrefX12(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call( - %arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>, %arg2: memref<4x4xf32>, - %arg3: memref<4x4xf32>, %arg4: memref<4x4xf32>, %arg5: memref<4x4xf32>, - %arg6: memref<4x4xf32>, %arg7: memref<4x4xf32>, %arg8: memref<4x4xf32>, - %arg9: memref<4x4xf32>, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32> - ) attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = memref.alloca() : memref<4x4xf32> - call @custom_call(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0) - : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, - memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, - memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32> - ) -> () - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &MemrefX12); -} - -static void BM_FlatMemrefX12All(bm::State& s) { MemrefX12(s); } -static void BM_FlatMemrefX12None(bm::State& s) { MemrefX12(s); } -static void BM_MemrefX12All(bm::State& s) { MemrefX12(s); } -static void BM_MemrefX12None(bm::State& s) { MemrefX12(s); } -static void BM_StridedMemrefX12All(bm::State& s) { MemrefX12(s); } -static void BM_StridedMemrefX12None(bm::State& s) { - MemrefX12(s); -} - -BENCHMARK(BM_FlatMemrefX12All); -BENCHMARK(BM_FlatMemrefX12None); - -BENCHMARK(BM_MemrefX12All); -BENCHMARK(BM_MemrefX12None); - -BENCHMARK(BM_StridedMemrefX12All); -BENCHMARK(BM_StridedMemrefX12None); - -//===----------------------------------------------------------------------===// -// Custom call with a single i32 attribute. -//===----------------------------------------------------------------------===// - -template -static bool I32AttrX1(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("test.custom_call") - .Attr("attr0") - .To([](int32_t attr0) { - benchmark::DoNotOptimize(attr0); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void I32AttrX1(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { attr0 = 42 : i32 }: () -> () - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &I32AttrX1); -} - -static void BM_I32AttrX1All(bm::State& s) { I32AttrX1(s); } -static void BM_I32AttrX1None(bm::State& s) { I32AttrX1(s); } -static void BM_I32AttrX1Less(bm::State& s) { I32AttrX1(s); } - -BENCHMARK(BM_I32AttrX1All); -BENCHMARK(BM_I32AttrX1Less); -BENCHMARK(BM_I32AttrX1None); - -//===----------------------------------------------------------------------===// -// Custom call with twelve i32 attributes. -//===----------------------------------------------------------------------===// - -template -static bool I32AttrX12(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .Attr("attr0") - .Attr("attr1") - .Attr("attr2") - .Attr("attr3") - .Attr("attr4") - .Attr("attr5") - .Attr("attr6") - .Attr("attr7") - .Attr("attr8") - .Attr("attr9") - .Attr("attr10") - .Attr("attr11") - .To([](int32_t attr0, int32_t attr1, int32_t attr2, - int32_t attr3, int32_t attr4, int32_t attr5, - int32_t attr6, int32_t attr7, int32_t attr8, - int32_t attr9, int32_t attr10, int32_t attr11) { - benchmark::DoNotOptimize(attr0 + attr1 + attr2 + attr3 + attr4 + - attr5 + attr6 + attr7 + attr8 + attr9 + - attr10 + attr11); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void I32AttrX12(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() - { "attr0" = 0 : i32, "attr1" = 1 : i32, "attr2" = 2 : i32, - "attr3" = 3 : i32, "attr4" = 4 : i32, "attr5" = 5 : i32, - "attr6" = 6 : i32, "attr7" = 7 : i32, "attr8" = 8 : i32, - "attr9" = 9 : i32, "attr10" = 10 : i32, "attr11" = 11 : i32 - } : () -> () - func.return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &I32AttrX12); -} - -static void BM_I32AttrX12All(bm::State& s) { I32AttrX12(s); } -static void BM_I32AttrX12None(bm::State& s) { I32AttrX12(s); } -static void BM_I32AttrX12Types(bm::State& s) { I32AttrX12(s); } - -BENCHMARK(BM_I32AttrX12All); -BENCHMARK(BM_I32AttrX12Types); -BENCHMARK(BM_I32AttrX12None); - -//===----------------------------------------------------------------------===// -// Custom call with a single PairOfDims attribute. -//===----------------------------------------------------------------------===// - -template -static bool AggregateAttrX1(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("test.custom_call") - .Attr("attr0") - .To([](PairOfDims attr0) { - benchmark::DoNotOptimize(attr0); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void AggregateAttrX1(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() { - attr0 = #testlib.pair_of_dims<2, [1, 1], [2, 2]> - }: () -> () - return - } - )"; - - auto types = [](TypeIDNameRegistry& registry) { - registry.Register>("__type_id_pair_of_dims"); - }; - - auto attrs = [](CustomCallAttrEncodingSet& encoding) { - encoding.Add>( - encoding, AggregateAttrDef() - .Add("rank", &PairOfDimsAttr::getRank) - .Add("a", &PairOfDimsAttr::getA) - .Add("b", &PairOfDimsAttr::getB)); - }; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &AggregateAttrX1, types, attrs); -} - -static void BM_AggregateAttrX1All(bm::State& s) { AggregateAttrX1(s); } -static void BM_AggregateAttrX1None(bm::State& s) { AggregateAttrX1(s); } -static void BM_AggregateAttrX1Less(bm::State& s) { AggregateAttrX1(s); } - -BENCHMARK(BM_AggregateAttrX1All); -BENCHMARK(BM_AggregateAttrX1Less); -BENCHMARK(BM_AggregateAttrX1None); - -//===----------------------------------------------------------------------===// -// Custom call with UserData arguments. -//===----------------------------------------------------------------------===// - -// Use std::integral_constant to fake multiple unique UserData types. -template -using Data = std::integral_constant; - -// Benchmark how long it takes to prepare UserData. -static void BM_PrepareUserData(bm::State& state) { - Data<0> data0; - Data<1> data1; - Data<2> data2; - Data<3> data3; - Data<4> data4; - Data<5> data5; - Data<6> data6; - Data<7> data7; - Data<8> data8; - Data<9> data9; - - for (auto _ : state) { - CustomCall::UserData user_data(&data0, &data1, &data2, &data3, &data4, - &data5, &data6, &data7, &data8, &data9); - benchmark::DoNotOptimize(user_data); - } -} - -BENCHMARK(BM_PrepareUserData); - -template -static bool UserDataX12(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = - CustomCall::Bind("test.custom_call") - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .UserData*>() - .To([](Data<0>* data0, Data<1>* data1, Data<2>* data2, - Data<3>* data3, Data<4>* data4, Data<5>* data5, - Data<6>* data6, Data<7>* data7, Data<8>* data8, - Data<9>* data9, Data<10>* data10, Data<11>* data11) { - benchmark::DoNotOptimize(data0); - benchmark::DoNotOptimize(data1); - benchmark::DoNotOptimize(data2); - benchmark::DoNotOptimize(data3); - benchmark::DoNotOptimize(data4); - benchmark::DoNotOptimize(data5); - benchmark::DoNotOptimize(data6); - benchmark::DoNotOptimize(data7); - benchmark::DoNotOptimize(data8); - benchmark::DoNotOptimize(data9); - benchmark::DoNotOptimize(data10); - benchmark::DoNotOptimize(data11); - return success(); - }) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void UserDataX12(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call() - attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - call @custom_call() : () -> () - return - } - )"; - - Data<0> data0; - Data<1> data1; - Data<2> data2; - Data<3> data3; - Data<4> data4; - Data<5> data5; - Data<6> data6; - Data<7> data7; - Data<8> data8; - Data<9> data9; - Data<10> data10; - Data<11> data11; - - CustomCall::UserData user_data; - user_data.insert_all(&data0, &data1, &data2, &data3, &data4, &data5, &data6, - &data7, &data8, &data9, &data10, &data11); - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &UserDataX12, {}, {}, user_data); -} - -static void BM_UserDataX12All(bm::State& s) { UserDataX12(s); } -static void BM_UserDataX12None(bm::State& s) { UserDataX12(s); } -static void BM_UserDataX12Less(bm::State& s) { UserDataX12(s); } - -BENCHMARK(BM_UserDataX12All); -BENCHMARK(BM_UserDataX12Less); -BENCHMARK(BM_UserDataX12None); - -//===----------------------------------------------------------------------===// -// Benchmark memref encoding for a sequence of custom calls. -//===----------------------------------------------------------------------===// - -static LogicalResult Sink(CustomCall::RemainingArgs) { return success(); } - -template -static bool RemainingArgsSink(ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("test.custom_call") - .RemainingArgs() - .To(CustomCall::FunctionWrapper()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -template -static void MemrefEncoding(bm::State& state) { - absl::string_view source = R"( - func.func private @custom_call( - %arg0: memref<4x4xf32>, %arg1: memref<5x5xf32>, %arg2: memref<6x6xf32>, - %arg3: memref<4x4xf32>, %arg4: memref<5x5xf32>, %arg5: memref<6x6xf32> - ) attributes { rt.custom_call = "test.custom_call" } - - func.func @test() { - %0 = memref.alloca() : memref<4x4xf32> - %1 = memref.alloca() : memref<5x5xf32> - %2 = memref.alloca() : memref<6x6xf32> - - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - call @custom_call(%0, %1, %2, %0, %1, %2) - : (memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>, - memref<4x4xf32>, memref<5x5xf32>, memref<6x6xf32>) -> () - return - } - )"; - - BenchmarkCustomCall(state, source, {}, "test.custom_call", - &RemainingArgsSink); -} - -static void BM_MemrefEncoding(bm::State& s) { MemrefEncoding(s); } - -BENCHMARK(BM_MemrefEncoding); - -} // namespace runtime -} // namespace xla - -// Add explicit dense type ids for all data types passed as UserData to measure -// the effects of explicit type id declaration/definition. -#define DEFINE_DENSE_TYPE_ID(n) \ - XLA_RUNTIME_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall, \ - xla::runtime::Data); \ - XLA_RUNTIME_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall, \ - xla::runtime::Data) - -DEFINE_DENSE_TYPE_ID(0); -DEFINE_DENSE_TYPE_ID(1); -DEFINE_DENSE_TYPE_ID(2); -DEFINE_DENSE_TYPE_ID(3); -DEFINE_DENSE_TYPE_ID(4); -DEFINE_DENSE_TYPE_ID(5); -DEFINE_DENSE_TYPE_ID(6); -DEFINE_DENSE_TYPE_ID(7); -DEFINE_DENSE_TYPE_ID(8); -DEFINE_DENSE_TYPE_ID(9); -DEFINE_DENSE_TYPE_ID(10); -DEFINE_DENSE_TYPE_ID(11); - -#undef DEFINE_DENSE_TYPE_ID diff --git a/xla/runtime/default/async_values_cache.h b/xla/runtime/default/async_values_cache.h index cda0c543a65e7..a5175a7f5f494 100644 --- a/xla/runtime/default/async_values_cache.h +++ b/xla/runtime/default/async_values_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/default/memory_mapper.h b/xla/runtime/default/memory_mapper.h index 96ee65606550d..44de6730e2f01 100644 --- a/xla/runtime/default/memory_mapper.h +++ b/xla/runtime/default/memory_mapper.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/diagnostics.cc b/xla/runtime/diagnostics.cc index 8e60e4cb27e7e..a111e01bc5898 100644 --- a/xla/runtime/diagnostics.cc +++ b/xla/runtime/diagnostics.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/diagnostics.h b/xla/runtime/diagnostics.h index 8acbd3d6853ec..f2d8207723337 100644 --- a/xla/runtime/diagnostics.h +++ b/xla/runtime/diagnostics.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -98,7 +98,7 @@ class InFlightDiagnostic { // Example: // // LogicalResult call(DiagnosticEngine diag, ...) { - // if () return diag.EmitError(InternalError("oops")); + // if () return diag.EmitError(Internal("oops")); // ... // } // diff --git a/xla/runtime/diagnostics_test.cc b/xla/runtime/diagnostics_test.cc index fc28fe25fdb7d..a1995b08f0a7c 100644 --- a/xla/runtime/diagnostics_test.cc +++ b/xla/runtime/diagnostics_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/errors.h b/xla/runtime/errors.h index 3f29a35423c0f..95e96cc08e0ee 100644 --- a/xla/runtime/errors.h +++ b/xla/runtime/errors.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ absl::Status InvalidArgument(const absl::FormatSpec& format, } template -absl::Status InternalError(const absl::FormatSpec& format, +absl::Status Internal(const absl::FormatSpec& format, const Args&... args) { return absl::InternalError(absl::StrFormat(format, args...)); } diff --git a/xla/runtime/executable.cc b/xla/runtime/executable.cc index 46d82be23a3b3..0b9b1b67f41e0 100644 --- a/xla/runtime/executable.cc +++ b/xla/runtime/executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -160,7 +160,7 @@ Executable::GetArgumentsMemoryLayout(const FunctionType& signature) { continue; } - return InternalError("unknown operand #%i argument ABI: %s", i, + return Internal("unknown operand #%i argument ABI: %s", i, type->ToString()); } @@ -195,7 +195,7 @@ Executable::GetResultsMemoryLayout(const FunctionType& signature) { continue; } - return InternalError("unknown result #%i argument ABI: %s", i, + return Internal("unknown result #%i argument ABI: %s", i, type->ToString()); } @@ -381,7 +381,7 @@ Status Executable::ReturnResults(unsigned ordinal, CallFrame* call_frame) const { // If execution failed, forward error to all results. if (call_frame->is_error) { - auto err = InternalError("run time error: %s", call_frame->error); + auto err = Internal("run time error: %s", call_frame->error); return (results.ReturnError(err), err); } @@ -400,7 +400,7 @@ Status Executable::ReturnResults(unsigned ordinal, } if (LLVM_UNLIKELY(!converted)) - return InternalError("failed to convert all returned values"); + return Internal("failed to convert all returned values"); else return absl::OkStatus(); } diff --git a/xla/runtime/executable.h b/xla/runtime/executable.h index 05a3abe461883..b71c9901795e7 100644 --- a/xla/runtime/executable.h +++ b/xla/runtime/executable.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -149,6 +149,8 @@ class Executable { std::string_view name() const { return name_; } + std::string&& take_ir_module_string() { return std::move(ir_module_string_); } + std::optional specialization() const { return specialization_; } // Returns the number of exported functions. Functions are indexed by their @@ -375,13 +377,15 @@ class Executable { std::unique_ptr engine, std::vector functions, std::optional specialization, - std::chrono::milliseconds time_to_compile) + std::chrono::milliseconds time_to_compile, + std::string&& ir_module_string = "") : name_(name), memory_mapper_(std::move(memory_mapper)), engine_(std::move(engine)), functions_(std::move(functions)), specialization_(specialization), - time_to_compile_(time_to_compile) { + time_to_compile_(time_to_compile), + ir_module_string_(ir_module_string) { // All exported functions must have a non-null function pointer. assert(llvm::all_of(functions_, [](const Function& f) { return f.fptr; })); } @@ -403,6 +407,10 @@ class Executable { // The time it took to compile this binary. std::chrono::milliseconds time_to_compile_; + + // The (optional) string containing the LLVM module, if requested by + // compilation or set explicitly. + std::string ir_module_string_; }; // Function reference provides a function-like API for a function exported from diff --git a/xla/runtime/executable_test.cc b/xla/runtime/executable_test.cc deleted file mode 100644 index 123b67e30e38e..0000000000000 --- a/xla/runtime/executable_test.cc +++ /dev/null @@ -1,1002 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/dynamic_annotations.h" -#include "absl/synchronization/barrier.h" -#include "absl/synchronization/blocking_counter.h" -#include "absl/synchronization/notification.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" -#include "xla/mlir/runtime/transforms/tests/testlib_pipeline.h" -#include "xla/mlir/runtime/utils/async_runtime_api.h" -#include "xla/runtime/arguments.h" -#include "xla/runtime/async_runtime.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/jit_executable.h" -#include "xla/runtime/logical_result.h" -#include "xla/runtime/results.h" -#include "xla/runtime/types.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/threadpool.h" - -namespace xla { -namespace runtime { - -using absl::StatusOr; - -//===----------------------------------------------------------------------===// -// A helper function that compiles the given `module` to an XLA runtime -// executable and runs the module's `test` function with the given arguments. -// Results are returned to the caller via the user-provided result converter. -//===----------------------------------------------------------------------===// - -static AsyncTaskRunner* NoRunner() { - return reinterpret_cast(0XDEADBEEF); -} - -// Lazily execute tasks -class LazyAsyncTaskRunner : public AsyncTaskRunner { - public: - void Schedule(Task task) final { tasks_.push_back(std::move(task)); } - void Run() { - while (!tasks_.empty()) { - tasks_.back()(); - tasks_.pop_back(); - break; - } - } - - private: - std::vector tasks_; -}; - -struct CustomCallRegistry { - std::function dynamic_custom_calls; - std::function direct_custom_calls; -}; - -static absl::StatusOr Compile( - std::string_view module, absl::Span exported, - const CustomCallRegistry& registry = {}) { - JitExecutable::Options opts; - CompilationPipelineOptions copts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.symbols_binding = ToSymbolsBinding( - registry.direct_custom_calls, copts.populate_type_id_names); - opts.compiler.register_dialects = [&](DialectRegistry& dialects) { - RegisterXlaRuntimeTestlibDialects(dialects); - }; - opts.compiler.create_compilation_pipeline = CreateXlaRuntimeTestlibPipeline; - - return JitExecutable::Instantiate(module, opts, exported); -} - -static absl::StatusOr Execute( - JitExecutable& jit_executable, unsigned ordinal, ArgumentsRef args, - ResultConverter& results, AsyncTaskRunner* async_task_runner = NoRunner(), - const CustomCallRegistry& registry = {}, bool use_lazy_runner = false) { - AsyncValuePtr executable = jit_executable.DefaultExecutable(); - if (executable.IsError()) return executable.GetError(); - - // Register all dynamic custom calls. - DynamicCustomCallRegistry dynamic_custom_calls; - if (registry.dynamic_custom_calls) - registry.dynamic_custom_calls(dynamic_custom_calls); - - CustomCall::UserData user_data; - // Always add a pointer to `self` to user data. - user_data.insert(&executable.get()); - - Executable::ExecuteOpts execute_opts; - execute_opts.custom_call_registry = &dynamic_custom_calls; - execute_opts.custom_call_data = &user_data; - execute_opts.async_task_runner = async_task_runner; - if (use_lazy_runner) { - LazyAsyncTaskRunner runner; - execute_opts.async_task_runner = &runner; - FunctionRef function_ref = executable->function_ref(ordinal); - auto status = function_ref(args, results, execute_opts); - runner.Run(); - return status; - } - - FunctionRef function_ref = executable->function_ref(ordinal); - return function_ref(args, results, execute_opts); -} - -static absl::StatusOr CompileAndExecute( - std::string_view module, ArgumentsRef args, ResultConverter& results, - AsyncTaskRunner* async_task_runner = NoRunner(), - const CustomCallRegistry& registry = {}, bool use_lazy_runner = false) { - StatusOr jit_executable = Compile(module, {"test"}, registry); - if (!jit_executable.ok()) return jit_executable.status(); - - return Execute(*jit_executable, 0, args, results, async_task_runner, registry, - use_lazy_runner); -} - -//===----------------------------------------------------------------------===// - -namespace { - -// An owning wrapper around Memref desciptor that releases the underlying buffer -// when destructed. Used for testing passing ownerhip of memrefs allocated in -// the compiled executables to the C++ caller. -struct OwnedMemref { - ~OwnedMemref() { - if (desc.has_value()) std::free(desc->data()); - } - - MemrefDesc* operator->() { return &desc.value(); } - - std::optional desc; -}; - -} // namespace - -//===----------------------------------------------------------------------===// - -static void AssertNoError(const absl::Status& status) { - assert(false && "Unexpected call to `ReturnError`"); -} - -static void IgnoreError(const absl::Status& status) {} - -void Emplace(void* int_ptr, AsyncValue* dst) { - auto& v = dst->get(); - v = *reinterpret_cast(int_ptr); -} - -template -struct ReturnScalar { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* ret) const { - PrimitiveType dtype = primitive_util::NativeToPrimitiveType(); - - if (auto* s = llvm::dyn_cast(type); s && s->type() == dtype) { - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(ret, sizeof(T)); - *ptr = *reinterpret_cast(ret); - return success(); - } - - return failure(); - } - - T* ptr = nullptr; -}; - -struct ReturnMemref { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* ret) const { - auto* memref = llvm::dyn_cast(runtime_type); - if (!memref) return failure(); - - auto desc = ConvertReturnedMemref(*this, memref, ret); - if (failed(desc)) return failure(); - - ptr->desc = std::move(*desc); - return success(); - } - - MemrefDesc operator()(PrimitiveType element_type, void* base_ptr, - void* data_ptr, int64_t offset, - absl::Span sizes, - absl::Span strides) const { - return MemrefDesc(element_type, base_ptr, offset, sizes, strides); - } - - OwnedMemref* ptr = nullptr; -}; - -struct ReturnAsyncToken { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* result_ptr) const { - if (!llvm::isa(type)) return failure(); - - // Load the pointer to the async token from a pointer to result storage. - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(result_ptr, sizeof(void*)); - void* ret = *reinterpret_cast(result_ptr); - auto* token = static_cast(ret); - auto* async_value = AsyncRuntime::GetAsyncValue(token); - CHECK(async_value->IsAvailable()); - chain.SetStateConcrete(); - AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(token)); - return success(); - } - - AsyncValuePtr chain; -}; - -struct ReturnAsyncI32 { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* result_ptr) const { - auto* value_type = llvm::dyn_cast(type); - if (!value_type) return failure(); - - // Load the pointer to the async value from a pointer to result storage. - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(result_ptr, sizeof(void*)); - void* ret = *reinterpret_cast(result_ptr); - auto* value = static_cast(ret); - auto* scalar = llvm::dyn_cast(&value_type->value_type()); - - if (scalar && scalar->type() == PrimitiveType::S32) { - ExtractAsyncValue(value, ptr.value(), Emplace); - return success(); - } - - return failure(); - } - - AsyncValuePtr ptr; -}; - -template -struct FetchMemrefDescFromAsyncValue { - void operator()(AsyncValue* value, MemrefDesc&& desc) const; -}; - -template <> -struct FetchMemrefDescFromAsyncValue { - void operator()(AsyncValue* value, MemrefDesc&& desc) const { - value->get().desc = std::move(desc); - } -}; - -template <> -struct FetchMemrefDescFromAsyncValue { - void operator()(AsyncValue* value, MemrefDesc&& desc) const { - value->get() = std::move(desc); - } -}; - -template -struct ReturnAsyncMemref { - LogicalResult operator()(unsigned result_index, const Type* type, - const Type* runtime_type, void* result_ptr) const { - auto* value_type = llvm::dyn_cast(type); - if (!value_type) return failure(); - - // Load the pointer to the async memref from a pointer to result storage. - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(result_ptr, sizeof(void*)); - void* ret = *reinterpret_cast(result_ptr); - auto* value = static_cast(ret); - auto* memref = llvm::dyn_cast(&value_type->value_type()); - - if (memref) { - ExtractAsyncValue( - value, ptr.value(), - [converter = *this, m = *memref](void* data, AsyncValue* dst) { - auto desc = ConvertReturnedMemref(converter, &m, data); - if (succeeded(desc)) { - FetchMemrefDescFromAsyncValue()(dst, - std::move(*desc)); - dst->SetStateConcrete(); - } - }); - return success(); - } - - return failure(); - } - - MemrefDesc operator()(PrimitiveType element_type, void* base_ptr, - void* data_ptr, int64_t offset, - absl::Span sizes, - absl::Span strides) const { - return MemrefDesc(element_type, base_ptr, offset, sizes, strides); - } - - AsyncValuePtr ptr; -}; - -using ReturnAsyncOwnedMemref = ReturnAsyncMemref; -using ReturnAsyncMemrefDesc = ReturnAsyncMemref; - -// Execute all tasks in the caller thread immediately. -class InlineAsyncTaskRunner : public AsyncTaskRunner { - public: - void Schedule(Task task) final { (task(), num_executed_++); } - size_t num_executed() const { return num_executed_; } - - private: - size_t num_executed_ = 0; -}; - -//===----------------------------------------------------------------------===// - -TEST(ExecutableTest, ReturnScalar) { - absl::string_view module = R"( - func.func @test() -> i32 { - %0 = arith.constant 42 : i32 - return %0 : i32 - } - )"; - - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ASSERT_TRUE(CompileAndExecute(module, {}, converter).ok()); - EXPECT_EQ(result, 42); -} - -TEST(ExecutableTest, ReturnMemref) { - absl::string_view module = R"( - func.func @test() -> memref { - %0 = arith.constant 1 : index - %1 = arith.constant 2 : index - %2 = memref.alloc(%0, %1) : memref - return %2 : memref - } - )"; - - OwnedMemref result; - ResultConverterSet converter(AssertNoError, ReturnMemref{&result}); - - ASSERT_TRUE(CompileAndExecute(module, {}, converter).ok()); - ASSERT_TRUE(result.desc.has_value()); - EXPECT_EQ(result->rank(), 2); - EXPECT_EQ(result->size(0), 1); - EXPECT_EQ(result->size(1), 2); -} - -TEST(ExecutableTest, ScalarArgs) { - absl::string_view module = R"( - func.func @test(%arg0: i32, %arg1: i32) -> i32 { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - )"; - - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - ASSERT_TRUE(CompileAndExecute(module, {arg0, arg1}, converter).ok()); - EXPECT_EQ(result, 42); -} - -TEST(ExecutableTest, MemrefF8Arg) { - absl::string_view module = R"( - func.func @test(%arg0: memref) -> index { - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - return %0 : index - } - )"; - - int64_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - MemrefDesc arg0(PrimitiveType::F8E4M3FN, nullptr, 0, {42}, {1}); - - Arguments args(1); - args.emplace_back(std::move(arg0)); - - ASSERT_TRUE(CompileAndExecute(module, args, converter).ok()); - EXPECT_EQ(result, 42); -} - -TEST(ExecutableTest, MultipleFunctions) { - absl::string_view module = R"( - func.func @add(%arg0: i32, %arg1: i32) -> i32 { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - - func.func @mul(%arg0: i32, %arg1: i32) -> i32 { - %0 = arith.muli %arg0, %arg1 : i32 - return %0 : i32 - } - )"; - - absl::StatusOr compiled = Compile(module, {"add", "mul"}); - ASSERT_TRUE(compiled.ok()); - EXPECT_EQ(compiled->num_functions(), 2); - - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - ASSERT_TRUE(Execute(*compiled, /*ordinal=*/0, {arg0, arg1}, converter).ok()); - EXPECT_EQ(result, 20 + 22); - - ASSERT_TRUE(Execute(*compiled, /*ordinal=*/1, {arg0, arg1}, converter).ok()); - EXPECT_EQ(result, 20 * 22); -} - -TEST(ExecutableTest, AssertionFailure) { - absl::string_view module = R"( - func.func @test(%arg0: i32) { - %c42 = arith.constant 42 : i32 - %0 = arith.cmpi ne, %c42, %arg0 : i32 - cf.assert %0, "Oops, argument can't be 42" - return - } - )"; - - NoResultConverter converter; - - { - ScalarArg arg0(int32_t{20}); - EXPECT_TRUE(CompileAndExecute(module, {arg0}, converter).ok()); - } - - { - ScalarArg arg0(int32_t{42}); - auto executed = CompileAndExecute(module, {arg0}, converter); - EXPECT_FALSE(executed.ok()); - EXPECT_EQ(executed.status().message(), - "run time error: Oops, argument can't be 42"); - } -} - -TEST(ExecutableTest, AssertionFailureOrResult) { - absl::string_view module = R"( - func.func @test(%arg0: i32) -> i32 { - %c42 = arith.constant 42 : i32 - %0 = arith.cmpi ne, %c42, %arg0 : i32 - cf.assert %0, "Oops, argument can't be 42" - %1 = arith.addi %arg0, %c42 : i32 - return %1 : i32 - } - )"; - - { - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ScalarArg arg0(int32_t{20}); - EXPECT_TRUE(CompileAndExecute(module, {arg0}, converter).ok()); - EXPECT_EQ(result, 62); - } - - { - int32_t result = 0; - ResultConverterSet converter(IgnoreError, ReturnScalar{&result}); - - ScalarArg arg0(int32_t{42}); - auto executed = CompileAndExecute(module, {arg0}, converter); - EXPECT_FALSE(executed.ok()); - EXPECT_EQ(executed.status().message(), - "run time error: Oops, argument can't be 42"); - EXPECT_EQ(result, 0); - } -} - -TEST(ExecutableTest, AsyncExecuteAndAwait) { - absl::string_view module = R"( - func.func @test(%arg0: i32, %arg1: i32) -> i32 { - %token, %result = async.execute -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - async.yield %0 : i32 - } - %1 = async.await %result : !async.value - return %1 : i32 - } - )"; - - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - InlineAsyncTaskRunner runner; - - ASSERT_TRUE(CompileAndExecute(module, {arg0, arg1}, converter, &runner).ok()); - EXPECT_EQ(runner.num_executed(), 1); - EXPECT_EQ(result, 42); -} - -TEST(ExecutableTest, AsyncTokenRet) { - absl::string_view module = R"( - async.func @test() -> !async.token { - return - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncToken{result.AsPtr()}); - - ASSERT_TRUE(CompileAndExecute(module, {}, converter).ok()); - EXPECT_EQ(result.IsAvailable(), true); -} - -TEST(ExecutableTest, AsyncScalarRet) { - absl::string_view module = R"( - async.func @test(%arg0: i32, %arg1: i32) -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - ASSERT_TRUE(CompileAndExecute(module, {arg0, arg1}, converter).ok()); - EXPECT_EQ(result.get(), 42); -} - -TEST(ExecutableTest, AsyncTokenArg) { - absl::string_view module = R"( - async.func @test(%arg0: !async.token, %arg1: i32) -> !async.value { - async.await %arg0 : !async.token - return %arg1 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - AsyncValueRef ch = tsl::MakeAvailableAsyncValueRef(); - - Arguments arguments(2); - arguments.emplace_back(AsyncTokenArg(ch)); - arguments.push_back(ScalarArg(static_cast(22))); - - ASSERT_TRUE(CompileAndExecute(module, arguments, converter).ok()); - EXPECT_EQ(result.get(), 22); -} - -TEST(ExecutableTest, AsyncScalarArg) { - absl::string_view module = R"( - async.func @test(%arg0: !async.value, %arg1: i32) -> !async.value { - %0 = async.await %arg0 : !async.value - %1 = arith.addi %0, %arg1 : i32 - return %1 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - AsyncValueRef async_val = - tsl::MakeAvailableAsyncValueRef(20); - AsyncScalarArg arg0(async_val); - ScalarArg arg1(static_cast(22)); - - Arguments arguments(2); - arguments.push_back(arg0); - arguments.push_back(arg1); - - ASSERT_TRUE(CompileAndExecute(module, arguments, converter).ok()); - EXPECT_EQ(result.get(), 42); -} - -TEST(ExecutableTest, AsyncMemrefArg) { - absl::string_view module = R"( - async.func @test(%arg0: !async.value>) -> - !async.value> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %0 = async.await %arg0 : !async.value> - %dim0 = memref.dim %0, %c0 : memref - %dim1 = memref.dim %0, %c1 : memref - %1 = memref.alloc(%dim0, %dim1) : memref - - memref.copy %0, %1 : memref to memref - - return %1 : memref - } - )"; - - AsyncValueRef result = - MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, - ReturnAsyncOwnedMemref{result.AsPtr()}); - std::vector input = {42.0, 42.0, 42.0, 42.0, 42.0, 42.0, 42.0, 42.0}; - MemrefDesc memref{ - PrimitiveType::F32, input.data(), 0, {4, 2}, {4, 2} /*fake strides*/}; - AsyncValueRef async_memref = - tsl::MakeAvailableAsyncValueRef(std::move(memref)); - - AsyncMemrefArg arg0(async_memref); - - ASSERT_TRUE(CompileAndExecute(module, {arg0}, converter).ok()); - ASSERT_TRUE(result.get().desc.has_value()); - EXPECT_EQ(result.get()->rank(), 2); - EXPECT_EQ(result.get()->size(0), 4); - EXPECT_EQ(result.get()->size(1), 2); - - float* data = reinterpret_cast(result.get()->data()); - EXPECT_TRUE(std::all_of(data, data + 8, [](float v) { return v == 42.0f; })); -} - -TEST(ExecutableTest, AsyncMemrefRet) { - absl::string_view module = R"( - async.func @test(%arg0: index) -> !async.value> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %0 = memref.alloc(%arg0) : memref - scf.for %i = %c0 to %arg0 step %c1 { - %c42 = arith.constant 42.0 : f32 - memref.store %c42, %0[%i] : memref - } - - return %0 : memref - } - )"; - - AsyncValueRef result = - MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, - ReturnAsyncOwnedMemref{result.AsPtr()}); - - ScalarArg arg0(static_cast(32)); - - ASSERT_TRUE(CompileAndExecute(module, {arg0}, converter).ok()); - ASSERT_TRUE(result.get().desc.has_value()); - EXPECT_EQ(result.get()->rank(), 1); - EXPECT_EQ(result.get()->size(0), 32); - - float* data = reinterpret_cast(result.get()->data()); - EXPECT_TRUE(std::all_of(data, data + 32, [](float v) { return v == 42.0f; })); -} - -TEST(ExecutableTest, AsyncMemrefInputsAndRets) { - absl::string_view module = R"( - func.func private @custom_call(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>) - attributes { rt.dynamic, rt.custom_call = "test.double" } - - async.func @test(%input: !async.value>, - %output: memref<2x2xf32>) - -> !async.value> { - %token, %result = execute -> !async.value> { - %0 = async.await %input : !async.value> - func.call @custom_call(%0, %output) - : (memref<2x2xf32>, memref<2x2xf32>) -> () - async.yield %output : memref<2x2xf32> - } - %1 = async.await %result : !async.value> - return %1 : memref<2x2xf32> - } - )"; - - // Doubles every element in the array. - auto test_double = [&](MemrefView input, MemrefView output) { - float* in = reinterpret_cast(input.data); - float* out = reinterpret_cast(output.data); - for (int i = 0; i < 4; ++i) { - out[i] = in[i] * 2; - } - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.double") - .Arg() // input - .Arg() // output - .To(test_double)); - }}; - - // Allocates storage and sets the initial data. - // In this test case, this buffer is shared across all inputs and outputs, - // which mimics the buffer reuse behavior in XLA. - std::array storage = {1.0, 2.0, 3.0, 4.0}; - std::array sizes = {2, 2}; - const auto& fake_strides = sizes; - - // Constructs inputs and output for the first run. - AsyncValueRef input_1 = - tsl::MakeAvailableAsyncValueRef( - PrimitiveType::F32, storage.data(), 0, sizes, fake_strides); - // Wraps the output fed in the parameter packs as an async output. - auto result_1 = MakeConstructedAsyncValueRef( - PrimitiveType::F32, storage.data(), 0, sizes, fake_strides); - ResultConverterSet first_converter(AssertNoError, - ReturnAsyncMemrefDesc{result_1.AsPtr()}); - - Arguments args_1(2); - args_1.emplace_back(AsyncMemrefArg(input_1)); - args_1.push_back( - MemrefDesc(PrimitiveType::F32, storage.data(), 0, sizes, fake_strides)); - - LazyAsyncTaskRunner runner; - auto exec_ref = - CompileAndExecute(module, args_1, first_converter, &runner, registry, - /*use_lazy_runner=*/true); - ASSERT_TRUE(exec_ref.ok()); - result_1.AndThen([exec_ref = *std::move(exec_ref)] {}); - - // Constructs inputs and output for the second run. - auto result_2 = MakeConstructedAsyncValueRef( - MemrefDesc(PrimitiveType::F32, storage.data(), 0, sizes, fake_strides)); - ResultConverterSet second_converter(AssertNoError, - ReturnAsyncMemrefDesc{result_2.AsPtr()}); - Arguments args_2(2); - args_2.emplace_back(AsyncMemrefArg(result_1)); - args_2.push_back( - MemrefDesc(PrimitiveType::F32, storage.data(), 0, sizes, fake_strides)); - exec_ref = - CompileAndExecute(module, args_2, second_converter, &runner, registry, - /*use_lazy_runner=*/true); - result_2.AndThen([exec_ref = *std::move(exec_ref)] {}); - tsl::BlockUntilReady(result_2.GetAsyncValue()); - - EXPECT_THAT(storage, testing::ElementsAre(4.0, 8.0, 12.0, 16.0)); -} - -TEST(ExecutableTest, AsyncWaiting) { - absl::string_view module = R"( - async.func @test2(%arg0: i32, %arg1: i32) -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - async.func @test(%arg0: i32, %arg1:i32) -> !async.value { - %0 = async.call @test2(%arg0, %arg1) : (i32, i32) -> !async.value - %1 = async.await %0 : !async.value - return %1 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - ASSERT_TRUE(CompileAndExecute(module, {arg0, arg1}, converter).ok()); - EXPECT_EQ(result.get(), 42); -} - -TEST(ExecutableTest, AsyncCustomCall) { - absl::string_view source = R"( - func.func private @custom_call_return() -> !async.value - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - func.func private @custom_call(%arg32 : i32) - attributes { rt.dynamic, rt.custom_call = "test.custom_call" } - - async.func @test() -> !async.token { - %0 = func.call @custom_call_return() : () -> !async.value - %1 = async.await %0 : !async.value - func.call @custom_call(%1) : (i32) -> () - return - } - )"; - - auto f_result = []() -> absl::StatusOr> { - return tsl::MakeAvailableAsyncValueRef(42); - }; - - int32_t i32 = 0; - auto f = [&](int32_t arg) { - i32 = arg; - return success(); - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Ret>() - .To(f_result)); - - registry.Register( - CustomCall::Bind("test.custom_call").Arg().To(f)); - }}; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncToken{result.AsPtr()}); - - ASSERT_TRUE( - CompileAndExecute(source, /*args=*/{}, converter, NoRunner(), registry) - .ok()); - EXPECT_EQ(i32, 42); -} - -TEST(ExecutableTest, AsyncExecute) { - absl::string_view source = R"( - module { - func.func private @custom_call_return() -> !async.value - attributes { rt.dynamic, rt.custom_call = "test.custom_call_return" } - - async.func @test() -> !async.value { - %token, %result = async.execute -> !async.value { - %0 = func.call @custom_call_return() : () -> !async.value - %1 = async.await %0 : !async.value - async.yield %1 : i32 - } - %1 = async.await %result : !async.value - return %1 : i32 - } - } - )"; - - LazyAsyncTaskRunner runner; - - auto async_result = tsl::MakeAvailableAsyncValueRef(42); - auto f_result = [&]() -> absl::StatusOr> { - return async_result; - }; - - CustomCallRegistry registry = {[&](DynamicCustomCallRegistry& registry) { - registry.Register(CustomCall::Bind("test.custom_call_return") - .Ret>() - .To(f_result)); - }}; - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - ASSERT_TRUE(CompileAndExecute(source, /*args=*/{}, converter, &runner, - registry, /*use_lazy_runner=*/true) - .ok()); - - EXPECT_EQ(result.get(), 42); -} - -//===----------------------------------------------------------------------===// -// Multi-threaded compilation to detect tsan errors. -//===----------------------------------------------------------------------===// - -TEST(ExecutableTest, ConcurrentCompilation) { - CustomCallRegistry registry; - - absl::string_view module = R"( - func.func @test() -> i32 { - %0 = arith.constant 42 : i32 - return %0 : i32 - } - )"; - - tsl::thread::ThreadPool pool(tsl::Env::Default(), "test", 32); - - int num_tasks = 256; - - absl::Notification wait; - absl::BlockingCounter done(num_tasks); - - for (int i = 0; i < num_tasks; ++i) { - pool.Schedule([&] { - wait.WaitForNotification(); - - StatusOr jit_executable = - Compile(module, {"test"}, registry); - EXPECT_TRUE(jit_executable.ok()); - - done.DecrementCount(); - }); - } - - wait.Notify(); - done.Wait(); -} - -//===----------------------------------------------------------------------===// -// Performance benchmarks are below. -//===----------------------------------------------------------------------===// - -static void CompileAndBenchmark( - benchmark::State& state, std::string_view module, ArgumentsRef args, - ResultConverter& results, AsyncTaskRunner* async_task_runner = NoRunner()) { - JitExecutable::Options opts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.register_dialects = RegisterXlaRuntimeTestlibDialects; - opts.compiler.create_compilation_pipeline = CreateXlaRuntimeTestlibPipeline; - - StatusOr jit_executable = - JitExecutable::Instantiate(module, "test", opts); - CHECK(jit_executable.ok()) << jit_executable.status().message(); - - AsyncValuePtr executable = jit_executable->DefaultExecutable(); - CHECK(!executable.IsError()) << executable.GetError().message(); - - Executable::CallFrame call_frame; - auto initialized = executable->InitializeCallFrame(args, &call_frame); - CHECK(initialized.ok()) << initialized.message(); - - Executable::ExecuteOpts execute_opts; - execute_opts.async_task_runner = async_task_runner; - - for (auto _ : state) { - call_frame.args[0] = nullptr; // reset execution context - executable->Execute(call_frame, execute_opts); - CHECK(!call_frame.is_error) << call_frame.error; - absl::Status returned = executable->ReturnResults(results, &call_frame); - CHECK(returned.ok()) << returned.message(); - } -} - -void BM_AsyncExecuteAndAwait(benchmark::State& state) { - absl::string_view module = R"( - func.func @test(%arg0: i32, %arg1: i32) -> i32 { - %token, %result = async.execute -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - async.yield %0 : i32 - } - %1 = async.await %result : !async.value - return %1 : i32 - } - )"; - - int32_t result = 0; - ResultConverterSet converter(AssertNoError, ReturnScalar{&result}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - InlineAsyncTaskRunner runner; - CompileAndBenchmark(state, module, {arg0, arg1}, converter, &runner); -} - -void BM_AsyncFunc(benchmark::State& state) { - absl::string_view module = R"( - async.func @test(%arg0: i32, %arg1: i32) -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - InlineAsyncTaskRunner runner; - CompileAndBenchmark(state, module, {arg0, arg1}, converter, &runner); -} - -void BM_AsyncFuncCall(benchmark::State& state) { - absl::string_view module = R"( - async.func @test2(%arg0: i32, %arg1: i32) -> !async.value { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - async.func @test(%arg0: i32, %arg1:i32) -> !async.value { - %0 = async.call @test2(%arg0, %arg1) : (i32, i32) -> !async.value - %1 = async.await %0 : !async.value - return %1 : i32 - } - )"; - - AsyncValueRef result = MakeConstructedAsyncValueRef(); - ResultConverterSet converter(AssertNoError, ReturnAsyncI32{result.AsPtr()}); - - ScalarArg arg0(static_cast(20)); - ScalarArg arg1(static_cast(22)); - - InlineAsyncTaskRunner runner; - CompileAndBenchmark(state, module, {arg0, arg1}, converter, &runner); -} - -BENCHMARK(BM_AsyncExecuteAndAwait); -BENCHMARK(BM_AsyncFunc); -BENCHMARK(BM_AsyncFuncCall); - -} // namespace runtime -} // namespace xla diff --git a/xla/runtime/execution_engine.cc b/xla/runtime/execution_engine.cc index 55ac3b2773549..3396c10e0c1be 100644 --- a/xla/runtime/execution_engine.cc +++ b/xla/runtime/execution_engine.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -166,9 +166,9 @@ absl::Status ExportWithXlaRuntimeAbi(llvm::Module &module, // Check that we have a function with a valid type. llvm::Function *func = module.getFunction(original_name); if (!func) - return InternalError("exported function not found: %s", original_name); + return Internal("exported function not found: %s", original_name); if (!func->getReturnType()->isVoidTy()) - return InternalError("exported function must return void"); + return Internal("exported function must return void"); // Add an XLA interface function for the exported function. llvm::FunctionType *xla_runtime_type = @@ -296,7 +296,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, // Set up the target machine details. if (!options.target_machine) - return InternalError("target machine was not provided"); + return Internal("target machine was not provided"); module->setDataLayout(options.target_machine->createDataLayout()); module->setTargetTriple(options.target_machine->getTargetTriple().str()); @@ -305,7 +305,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, if (auto status = ExportWithXlaRuntimeAbi(*module, name, GetExportedName(name)); !status.ok()) { - return InternalError( + return Internal( "failed to set up exported function %s interface: %s", name, status.message()); } @@ -320,7 +320,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, auto transformer = options.make_optimizing_transformer(options.target_machine.get()); if (auto err = transformer(module_ptr)) - return InternalError("failed to run optimization pipeline: %s", + return Internal("failed to run optimization pipeline: %s", ToString(err)); // Callback to create the object layer with a user-provided section memory @@ -364,7 +364,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, nullptr, std::make_unique()); if (auto err = executorProcessControl.takeError()) - return InternalError("failed to create executor process control: %s", + return Internal("failed to create executor process control: %s", ToString(err)); // TODO(b/286475799): Concurrent compilation leads to spurious memory @@ -382,14 +382,14 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, .create(); if (auto err = jit.takeError()) - return InternalError("failed to construct LLJIT: %s", ToString(err)); + return Internal("failed to construct LLJIT: %s", ToString(err)); lljit_lock.reset(); // Register input module with the LLJIT. ThreadSafeModule tsm(std::move(module), std::move(ctx)); if (auto err = (*jit)->addIRModule(std::move(tsm))) - return InternalError("failed to add source module: %s", ToString(err)); + return Internal("failed to add source module: %s", ToString(err)); llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib(); llvm::DataLayout data_layout = (*jit)->getDataLayout(); @@ -400,7 +400,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, data_layout); auto symbols = absoluteSymbols(options.symbols_binding(mangle)); if (auto err = main_jd.define(symbols)) - return InternalError("failed to add symbols bindings: %s", ToString(err)); + return Internal("failed to add symbols bindings: %s", ToString(err)); } // Resolve all exported functions to function pointers. @@ -408,13 +408,13 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, // Trigger compilation by looking up the exported function. Expected addr = (*jit)->lookup(GetExportedName(name)); if (auto err = addr.takeError()) - return InternalError("failed to compile exported function %s: %s", name, + return Internal("failed to compile exported function %s: %s", name, ToString(err)); // Check that we found an address of an exported function. auto ptr = addr->toPtr(); if (!ptr) - return InternalError("exported function %s resolved to null", name); + return Internal("exported function %s resolved to null", name); engine->exported_.push_back(ptr); } @@ -425,7 +425,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, options.save_compiled_obj_file ? obj_cache->stealObject(module_ptr) : nullptr; if (options.save_compiled_obj_file && !obj_file) - return InternalError("could not find object file for the XLA module"); + return Internal("could not find object file for the XLA module"); // Fill remaining fields and return constructed ExecutionEngine to the caller. engine->jit_ = std::move(*jit); @@ -473,10 +473,10 @@ ExecutionEngine::CreateFromObjFile( .setObjectLinkingLayerCreator(obj_layer_creator) .create(); if (auto err = jit.takeError()) - return InternalError("failed to construct LLJIT: %s", ToString(err)); + return Internal("failed to construct LLJIT: %s", ToString(err)); if (auto err = (*jit)->addObjectFile(std::move(obj_file))) - return InternalError("failed to add object file: %s", ToString(err)); + return Internal("failed to add object file: %s", ToString(err)); llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib(); llvm::DataLayout data_layout = (*jit)->getDataLayout(); @@ -485,7 +485,7 @@ ExecutionEngine::CreateFromObjFile( auto generator = DynamicLibrarySearchGenerator::GetForCurrentProcess( data_layout.getGlobalPrefix()); if (auto err = generator.takeError()) - return InternalError("failed to construct DyLib search generator"); + return Internal("failed to construct DyLib search generator"); main_jd.addGenerator(std::move(*generator)); // Register user-provided symbols. @@ -494,7 +494,7 @@ ExecutionEngine::CreateFromObjFile( data_layout); auto symbols = absoluteSymbols(options.symbols_binding(mangle)); if (auto err = main_jd.define(symbols)) - return InternalError("failed to add symbols bindings: %s", ToString(err)); + return Internal("failed to add symbols bindings: %s", ToString(err)); } // Resolve all exported functions to function pointers. @@ -502,13 +502,13 @@ ExecutionEngine::CreateFromObjFile( // Lookup exported function in the loaded object file. Expected addr = (*jit)->lookup(GetExportedName(name)); if (auto err = addr.takeError()) - return InternalError("failed to look up the exported function %s: %s", + return Internal("failed to look up the exported function %s: %s", name, ToString(err)); // Check that we found an address of an exported function. auto ptr = addr->toPtr(); if (!ptr) - return InternalError("exported function %s resolved to null", name); + return Internal("exported function %s resolved to null", name); engine->exported_.push_back(ptr); } diff --git a/xla/runtime/execution_engine.h b/xla/runtime/execution_engine.h index f7e3c299a738e..19c71d5329040 100644 --- a/xla/runtime/execution_engine.h +++ b/xla/runtime/execution_engine.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/ffi/BUILD b/xla/runtime/ffi/BUILD index 91b1fbf3f7a46..f2f0af9e6b096 100644 --- a/xla/runtime/ffi/BUILD +++ b/xla/runtime/ffi/BUILD @@ -1,5 +1,5 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/xla/runtime/ffi/ffi_abi.h b/xla/runtime/ffi/ffi_abi.h index 3bb01f9a70230..c0185e6e61263 100644 --- a/xla/runtime/ffi/ffi_abi.h +++ b/xla/runtime/ffi/ffi_abi.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/ffi/ffi_api.h b/xla/runtime/ffi/ffi_api.h index 448d145d68e35..f6f9a9e25d992 100644 --- a/xla/runtime/ffi/ffi_api.h +++ b/xla/runtime/ffi/ffi_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/ffi/ffi_c_api.h b/xla/runtime/ffi/ffi_c_api.h index ee98223d266b0..c7f3f31f4bc0a 100644 --- a/xla/runtime/ffi/ffi_c_api.h +++ b/xla/runtime/ffi/ffi_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/jit_executable.cc b/xla/runtime/jit_executable.cc index cee4ef7fb2e29..a18eda247ba95 100644 --- a/xla/runtime/jit_executable.cc +++ b/xla/runtime/jit_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -142,7 +142,7 @@ static bool HasStaticShapeOperands(const FunctionType& signature) { // the operands have unresolved constraints. if (opts.specialization == Specialization::kDisabled && IsSpecializationOnly(fn.constraints)) - return InternalError( + return Internal( "compilation options disabled specialization, yet operands " "have unresolved constraints: [%s]", absl::StrJoin(fn.constraints, ", ")); @@ -279,7 +279,7 @@ StatusOr> JitExecutable::GetExecutable( } assert(false && "failed to detect incorrect operand"); - return InternalError("failed to resolve symbolic shapes"); + return Internal("failed to resolve symbolic shapes"); } // Combine with a hash value computed from the value constrained operands. @@ -319,7 +319,7 @@ StatusOr> JitExecutable::GetExecutable( if (auto specialized = (*compiler)->Specialize(0, arguments, *symbolic_shapes, fn.constraints, listener); !specialized.ok()) { - return InternalError("failed to specialize executable: %s", + return Internal("failed to specialize executable: %s", specialized.message()); } diff --git a/xla/runtime/jit_executable.h b/xla/runtime/jit_executable.h index ef08b6b7adf2d..2cf8878f0fcb9 100644 --- a/xla/runtime/jit_executable.h +++ b/xla/runtime/jit_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/logical_result.h b/xla/runtime/logical_result.h index 572c98248bd78..2702010e5273a 100644 --- a/xla/runtime/logical_result.h +++ b/xla/runtime/logical_result.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/map_by_type.h b/xla/runtime/map_by_type.h index e24b0756940c8..cbd3740811b96 100644 --- a/xla/runtime/map_by_type.h +++ b/xla/runtime/map_by_type.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/map_by_type_test.cc b/xla/runtime/map_by_type_test.cc index adcc1a6c2bbde..0b3b37d876e69 100644 --- a/xla/runtime/map_by_type_test.cc +++ b/xla/runtime/map_by_type_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/memory_mapper.cc b/xla/runtime/memory_mapper.cc index 93e0b2c73df56..77ce86f0f8ce0 100644 --- a/xla/runtime/memory_mapper.cc +++ b/xla/runtime/memory_mapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/memory_mapper.h b/xla/runtime/memory_mapper.h index e47a6839d3c7f..1b9271abab64b 100644 --- a/xla/runtime/memory_mapper.h +++ b/xla/runtime/memory_mapper.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/memref_view.h b/xla/runtime/memref_view.h index 7bce51a3000ff..1abefcbe44eb7 100644 --- a/xla/runtime/memref_view.h +++ b/xla/runtime/memref_view.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/module.h b/xla/runtime/module.h index 7b9b3a0602993..c7ba0a6ff4a79 100644 --- a/xla/runtime/module.h +++ b/xla/runtime/module.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/module_registry.cc b/xla/runtime/module_registry.cc index 626f940a010ae..8c7fbae80fb99 100644 --- a/xla/runtime/module_registry.cc +++ b/xla/runtime/module_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/module_registry.h b/xla/runtime/module_registry.h index 33f76bf4b4d22..63e183714665d 100644 --- a/xla/runtime/module_registry.h +++ b/xla/runtime/module_registry.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/module_test.cc b/xla/runtime/module_test.cc index 4ca1ac838b23e..f342c79cba663 100644 --- a/xla/runtime/module_test.cc +++ b/xla/runtime/module_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/results.h b/xla/runtime/results.h index 6afd6cd96174d..45b81f0d5f80b 100644 --- a/xla/runtime/results.h +++ b/xla/runtime/results.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/results_test.cc b/xla/runtime/results_test.cc index fbdbb6a18811c..eb571de7c955b 100644 --- a/xla/runtime/results_test.cc +++ b/xla/runtime/results_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/runtime.h b/xla/runtime/runtime.h index 56e81c8910d0c..f7ad385b0fcf2 100644 --- a/xla/runtime/runtime.h +++ b/xla/runtime/runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/state.h b/xla/runtime/state.h index c505dd96ffea7..8a5c4e09f4f18 100644 --- a/xla/runtime/state.h +++ b/xla/runtime/state.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/state_test.cc b/xla/runtime/state_test.cc index cc2bacfd9c18b..61ccce7aaa698 100644 --- a/xla/runtime/state_test.cc +++ b/xla/runtime/state_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/symbolic_shape.cc b/xla/runtime/symbolic_shape.cc index b3fb1ef5fe22b..a562b3af50928 100644 --- a/xla/runtime/symbolic_shape.cc +++ b/xla/runtime/symbolic_shape.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/symbolic_shape.h b/xla/runtime/symbolic_shape.h index 9f3e1192001ad..6bf8a4466a1b9 100644 --- a/xla/runtime/symbolic_shape.h +++ b/xla/runtime/symbolic_shape.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/symbolic_shape_test.cc b/xla/runtime/symbolic_shape_test.cc index 8621e25349a3b..fe4f599366f8b 100644 --- a/xla/runtime/symbolic_shape_test.cc +++ b/xla/runtime/symbolic_shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/tracing.h b/xla/runtime/tracing.h index 6b9336bcc7be7..2d5b8e3b04921 100644 --- a/xla/runtime/tracing.h +++ b/xla/runtime/tracing.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/type_id.cc b/xla/runtime/type_id.cc index 8af0158e400e9..cc34ef81c59a5 100644 --- a/xla/runtime/type_id.cc +++ b/xla/runtime/type_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/type_id.h b/xla/runtime/type_id.h index e1a430ad05f26..73c1f3af9b0f7 100644 --- a/xla/runtime/type_id.h +++ b/xla/runtime/type_id.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/type_id_test.cc b/xla/runtime/type_id_test.cc index fd7cf815766d3..984a767d001cb 100644 --- a/xla/runtime/type_id_test.cc +++ b/xla/runtime/type_id_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/types.cc b/xla/runtime/types.cc index 9aee029c021fb..2e2b7dfdaafb5 100644 --- a/xla/runtime/types.cc +++ b/xla/runtime/types.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/runtime/types.h b/xla/runtime/types.h index 4cca54fa0f923..e4cae8bf95cde 100644 --- a/xla/runtime/types.h +++ b/xla/runtime/types.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/BUILD b/xla/service/BUILD index f89094ace260c..09d5c00060465 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1,26 +1,14 @@ # Description: # XLA service implementation. -load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") -load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load( - "//xla:xla.bzl", - "xla_cc_binary", - "xla_cc_test", - "xla_py_proto_library", - "xla_py_test_deps", - "xla_symbol_repository_deps", -) -load("//xla/service:xla_compile.bzl", "xla_aot_compile_cpu", "xla_aot_compile_gpu", "xla_aot_compile_gpu_runtime_autotuning") -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", "if_rocm_is_configured", ) -load("@tsl//tsl:tsl.bzl", "if_google", "if_libtpu") +load("@tsl//tsl:tsl.bzl", "if_google", "if_libtpu", "internal_visibility", "tsl_copts") load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable", "internal_hlo_deps") load( "@tsl//tsl/platform:build_config.bzl", @@ -32,10 +20,22 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") +load( + "//xla:xla.bzl", + "xla_cc_binary", + "xla_cc_test", + "xla_py_proto_library", + "xla_py_test_deps", + "xla_symbol_repository_deps", +) +load("//xla/service:xla_compile.bzl", "xla_aot_compile_cpu", "xla_aot_compile_gpu", "xla_aot_compile_gpu_runtime_autotuning") +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -79,6 +79,7 @@ tf_proto_library( name = "metrics_proto", srcs = ["metrics.proto"], cc_api_version = 2, + visibility = ["//visibility:public"], ) xla_py_proto_library( @@ -114,6 +115,7 @@ cc_library( ":hlo_pass", ":shape_inference", "//xla:frontend_attributes", + "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", "@tsl//tsl/platform:errors", @@ -135,27 +137,6 @@ xla_cc_test( ], ) -cc_library( - name = "async_op_canonicalizer", - srcs = ["async_op_canonicalizer.cc"], - hdrs = ["async_op_canonicalizer.h"], - deps = [ - ":hlo_pass", - "//xla/hlo/ir:hlo", - ], -) - -xla_cc_test( - name = "async_op_canonicalizer_test", - srcs = ["async_op_canonicalizer_test.cc"], - deps = [ - ":async_op_canonicalizer", - ":hlo_dce", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "all_reduce_key", srcs = ["all_reduce_key.cc"], @@ -319,6 +300,7 @@ cc_library( srcs = ["float_normalization.cc"], hdrs = ["float_normalization.h"], deps = [ + ":call_graph", ":float_support", ":hlo_dce", ":hlo_pass", @@ -341,12 +323,15 @@ xla_cc_test( ":hlo_creation_utils", ":hlo_verifier", "//xla:shape_util", + "//xla:statusor", "//xla:test", "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -397,6 +382,8 @@ cc_library( deps = [ ":collective_ops_utils", ":hlo_pass", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/service/graphcycles", @@ -409,15 +396,14 @@ xla_cc_test( name = "collective_permute_decomposer_test", srcs = ["collective_permute_decomposer_test.cc"], deps = [ + ":collective_ops_utils", ":collective_permute_decomposer", ":hlo_parser", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test_main", ], ) @@ -534,11 +520,13 @@ xla_cc_test( ":collective_pipeliner", ":hlo_parser", ":hlo_pass_pipeline", + "//xla:status", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) @@ -611,14 +599,17 @@ xla_cc_test( "//xla:statusor", "//xla:test", "//xla:test_helpers", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:padding", "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -681,14 +672,19 @@ xla_cc_test( ":hlo_parser", ":sharding_propagation", "//xla:protobuf_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms:hlo_constant_splitter", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) @@ -795,10 +791,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo_parser", - "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:ptrvec", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -845,6 +842,20 @@ xla_cc_test( ], ) +xla_cc_test( + name = "hlo_dfs_reachability_test", + srcs = ["hlo_dfs_reachability_test.cc"], + deps = [ + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_dfs_reachability", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/random", + "@tsl//tsl/platform:test_benchmark", + ], +) + xla_cc_test( name = "hlo_reachability_test", srcs = ["hlo_reachability_test.cc"], @@ -856,12 +867,15 @@ xla_cc_test( "//xla/hlo/ir:hlo_reachability", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/random", + "@tsl//tsl/platform:test_benchmark", ], ) xla_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], + tags = ["no_aarch64"], deps = [ "//xla:literal", "//xla:protobuf_util", @@ -870,6 +884,7 @@ xla_cc_test( "//xla:test_helpers", "//xla:util", "//xla:window_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", @@ -919,7 +934,6 @@ xla_cc_test( deps = [ ":call_graph", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1045,6 +1059,7 @@ cc_library( "//xla:types", "//xla:util", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", @@ -1082,6 +1097,7 @@ cc_library( name = "service", srcs = ["service.cc"], hdrs = ["service.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":allocation_tracker", ":backend", @@ -1128,6 +1144,7 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:protobuf", + "@tsl//tsl/profiler/lib:scoped_annotation", ], alwayslink = 1, ) @@ -1430,6 +1447,7 @@ cc_library( ":dump", ":hlo_execution_profile", ":hlo_graph_dumper", + ":hlo_module_config", ":hlo_proto_cc", ":maybe_owning_device_memory", ":shaped_buffer", @@ -1444,7 +1462,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", @@ -1477,6 +1494,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/stream_executor", + "//xla/stream_executor:dnn", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1486,6 +1504,31 @@ cc_library( ], ) +xla_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + backend_tags = { + "gpu": if_google(["requires-gpu-nvidia"]), + }, + backends = [ + "gpu", + "cpu", + ], + deps = [ + ":compiler", + "//xla:autotune_results_proto_cc", + "//xla/stream_executor", + "//xla/stream_executor:device_description_proto_cc", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ] + if_gpu_is_configured([ + "//xla/stream_executor/gpu:gpu_init", + ]), +) + cc_library( name = "llvm_compiler", srcs = ["llvm_compiler.cc"], @@ -1503,24 +1546,27 @@ cc_library( hdrs = ["transfer_manager.h"], deps = [ ":compiler", - ":executable", ":maybe_owning_device_memory", ":shaped_buffer", "//xla:literal", + "//xla:shape_tree", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor", "//xla/stream_executor:device_memory", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:notification", + "@tsl//tsl/platform:statusor", ], ) @@ -1566,7 +1612,6 @@ cc_library( srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ - "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", @@ -1608,7 +1653,6 @@ cc_library( deps = [ ":buffer_assignment_proto_cc", ":buffer_value_containers", - ":heap_simulator", ":hlo_alias_analysis", ":hlo_buffer", ":hlo_dataflow_analysis", @@ -1623,6 +1667,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", + "//xla/service/heap_simulator", "//xla/service/memory_space_assignment", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -1643,9 +1688,7 @@ xla_cc_test( name = "buffer_assignment_test", srcs = ["buffer_assignment_test.cc"], deps = [ - ":async_op_canonicalizer", ":buffer_assignment", - ":buffer_assignment_proto_cc", ":buffer_value", ":call_graph", ":copy_insertion", @@ -1714,62 +1757,6 @@ xla_cc_test( ], ) -cc_library( - name = "heap_simulator", - srcs = ["heap_simulator.cc"], - hdrs = ["heap_simulator.h"], - deps = [ - ":buffer_value", - ":buffer_value_containers", - ":hlo_alias_analysis", - ":hlo_buffer", - ":hlo_dataflow_analysis", - ":hlo_ordering", - ":hlo_proto_cc", - ":time_utils", - ":tuple_points_to_analysis", - "//xla:comparison_util", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_live_range", - "//xla/service/memory_space_assignment:repacking", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -xla_cc_test( - name = "heap_simulator_test", - srcs = ["heap_simulator_test.cc"], - deps = [ - ":async_op_canonicalizer", - ":buffer_value", - ":heap_simulator", - ":hlo_dce", - ":hlo_ordering", - ":hlo_parser", - ":hlo_value", - ":tuple_points_to_analysis", - "//xla:literal", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", - ], -) - xla_cc_test( name = "hlo_module_group_test", srcs = ["hlo_module_group_test.cc"], @@ -1893,7 +1880,6 @@ cc_library( srcs = ["hlo_memory_scheduler.cc"], hdrs = ["hlo_memory_scheduler.h"], deps = [ - ":heap_simulator", ":hlo_alias_analysis", ":hlo_pass", ":logical_buffer", @@ -1904,6 +1890,7 @@ cc_library( "//xla:types", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service/heap_simulator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@tsl//tsl/lib/gtl:map_util", @@ -1916,7 +1903,6 @@ xla_cc_test( name = "hlo_memory_scheduler_test", srcs = ["hlo_memory_scheduler_test.cc"], deps = [ - ":heap_simulator", ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", @@ -1924,6 +1910,7 @@ xla_cc_test( "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", @@ -2010,6 +1997,7 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/client/lib:comparators", @@ -2099,12 +2087,15 @@ cc_library( srcs = ["op_expander_pass.cc"], hdrs = ["op_expander_pass.h"], deps = [ - ":hlo_creation_utils", ":hlo_pass", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -2138,12 +2129,17 @@ cc_library( srcs = ["comparison_expander.cc"], hdrs = ["comparison_expander.h"], deps = [ - ":hlo_creation_utils", - ":hlo_pass", ":op_expander_pass", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", "//xla:util", - "//xla/client/lib:comparators", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -2412,6 +2408,7 @@ cc_library( ":hlo_creation_utils", ":hlo_module_config", ":hlo_pass", + ":host_memory_offload_annotations_hdr", ":pattern_matcher", ":shape_inference", "//xla:comparison_util", @@ -2472,6 +2469,7 @@ xla_cc_test( ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", + ":host_memory_offload_annotations_hdr", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", @@ -2512,8 +2510,11 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -2558,6 +2559,7 @@ xla_cc_test( name = "logistic_expander_test", srcs = ["logistic_expander_test.cc"], deps = [ + ":dynamic_padder", ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", @@ -2568,6 +2570,7 @@ xla_cc_test( ":shape_inference", "//xla:literal", "//xla:shape_util", + "//xla:statusor", "//xla:test", "//xla:types", "//xla:window_util", @@ -2716,7 +2719,6 @@ cc_library( ":collective_combiner_utils", ":hlo_domain_map", ":hlo_pass", - "//xla:array2d", "//xla:shape_util", "//xla:status", "//xla:status_macros", @@ -2725,7 +2727,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_sharding_util", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -3143,16 +3144,6 @@ cc_library( ], ) -cc_library( - name = "sparse_util", - srcs = ["sparse_util.cc"], - hdrs = ["sparse_util.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - ], -) - xla_cc_test( name = "space_to_batch_converter_test", size = "small", @@ -3173,14 +3164,15 @@ cc_library( srcs = ["while_loop_unroller.cc"], hdrs = ["while_loop_unroller.h"], deps = [ - ":async_op_canonicalizer", ":call_inliner", + ":collective_ops_utils", ":flatten_call_graph", ":hlo_cse", ":hlo_pass", ":tuple_simplifier", ":while_loop_analysis", ":while_loop_constant_sinking", + "//xla:comparison_util", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -3195,6 +3187,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", @@ -3213,10 +3206,12 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) @@ -3226,8 +3221,10 @@ cc_library( hdrs = ["while_loop_analysis.h"], deps = [ ":pattern_matcher", + "//xla:comparison_util", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_reachability", @@ -3240,15 +3237,19 @@ xla_cc_test( name = "while_loop_analysis_test", srcs = ["while_loop_analysis_test.cc"], deps = [ - ":hlo_parser", ":while_loop_analysis", - "//xla:literal_util", + "//xla:comparison_util", + "//xla:statusor", "//xla:test", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -3266,14 +3267,21 @@ cc_library( "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", "//xla:statusor", "//xla:union_find", + "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], @@ -3283,8 +3291,6 @@ xla_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ - ":algebraic_simplifier", - ":hlo_cse", ":hlo_dce", ":hlo_parser", ":tuple_simplifier", @@ -3292,10 +3298,12 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:test", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", @@ -3312,6 +3320,9 @@ cc_library( "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", ], ) @@ -3319,14 +3330,12 @@ xla_cc_test( name = "while_loop_trip_count_annotator_test", srcs = ["while_loop_trip_count_annotator_test.cc"], deps = [ - ":pattern_matcher", - ":while_loop_simplifier", ":while_loop_trip_count_annotator", - "//xla:status_macros", "//xla:test", + "//xla:xla_data_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -3382,15 +3391,18 @@ cc_library( hdrs = ["dot_decomposer.h"], deps = [ ":hlo_pass", - ":sparse_util", - "//xla:permutation_util", + ":shape_inference", "//xla:shape_util", - "//xla:status_macros", - "//xla:types", + "//xla:status", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -3399,11 +3411,15 @@ xla_cc_test( srcs = ["dot_decomposer_test.cc"], deps = [ ":dot_decomposer", - ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", ], ) @@ -3417,8 +3433,14 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", ], ) @@ -3430,6 +3452,8 @@ xla_cc_test( ":hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -3440,24 +3464,39 @@ cc_library( deps = [ ":hlo_pass", ":shape_inference", + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/graphcycles", - ], -) - -xla_cc_test( + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( name = "dot_merger_test", srcs = ["dot_merger_test.cc"], deps = [ ":algebraic_simplifier", ":dot_merger", - ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -3934,12 +3973,19 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/service:shaped_buffer", "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:statusor", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -3957,6 +4003,7 @@ xla_cc_test( "//xla:shape_util", "//xla:types", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "//xla/tests:literal_test_util", @@ -3984,6 +4031,7 @@ cc_library( "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "@tsl//tsl/lib/gtl:map_util", "@tsl//tsl/platform:errors", ], @@ -4050,6 +4098,7 @@ xla_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], deps = [ + ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", "//xla:literal", @@ -4194,6 +4243,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -4206,7 +4256,6 @@ xla_cc_test( name = "hlo_dataflow_analysis_test", srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ - ":async_op_canonicalizer", ":flatten_call_graph", ":hlo_creation_utils", ":hlo_dataflow_analysis", @@ -4264,6 +4313,7 @@ cc_library( ":hlo_value", "//xla:shape_tree", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status", "//xla:statusor", "//xla:util", @@ -4271,9 +4321,11 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -4344,6 +4396,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -4530,7 +4583,6 @@ cc_library( deps = [ ":call_graph", ":computation_layout", - ":hlo_alias_analysis", ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", @@ -4540,6 +4592,7 @@ cc_library( "//xla:permutation_util", "//xla:shape_layout", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", @@ -4557,8 +4610,8 @@ cc_library( "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -4612,10 +4665,11 @@ xla_cc_test( deps = [ ":copy_insertion", ":hlo_graph_dumper", + ":hlo_module_config", ":hlo_parser", - ":hlo_runner", + "//xla:comparison_util", "//xla:debug_options_flags", - "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", @@ -4625,7 +4679,9 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_benchmark", ], @@ -4652,6 +4708,7 @@ cc_library( deps = [ ":hlo_dataflow_analysis", ":hlo_pass", + "//xla:shape_util", "//xla/hlo/ir:hlo", ], ) @@ -4715,12 +4772,14 @@ cc_library( hdrs = ["hlo_verifier.h"], deps = [ ":collective_ops_utils", + ":hlo_module_config", ":hlo_pass", - ":pattern_matcher", ":shape_inference", "//xla:comparison_util", "//xla:permutation_util", + "//xla:shape_layout", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", @@ -4728,8 +4787,15 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4836,7 +4902,6 @@ xla_cc_test( deps = [ ":hlo_rematerialization_test_utils", "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], ) @@ -4947,6 +5012,7 @@ cc_library( hdrs = [ "hlo_pass_pipeline.h", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":compilation_stats", ":dump", @@ -4965,6 +5031,7 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", + "@tsl//tsl/profiler/lib:scoped_annotation", ], ) @@ -4993,6 +5060,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:errors", ], ) @@ -5014,6 +5082,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) @@ -5216,7 +5285,6 @@ cc_library( "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", "@llvm-project//llvm:Core", ], ) @@ -5226,6 +5294,7 @@ cc_library( srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], deps = [ + ":algorithm_util", ":float8_fnuz_ir_emitter", "//xla:permutation_util", "//xla:shape_util", @@ -5241,9 +5310,12 @@ cc_library( "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/service/llvm_ir:math_ops", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", @@ -5259,20 +5331,31 @@ xla_test( "cpu", "gpu", ], + # TODO(b/332870133): Enable when it passes on H100. + disabled_backends = ["gpu_h100"], tags = [ "no_windows", # TODO(b/152037541) ], deps = [ + ":elemental_ir_emitter", + ":hlo_module_config", ":hlo_parser", "//xla:error_spec", "//xla:execution_options_util", + "//xla:literal", + "//xla:literal_util", "//xla:status_macros", "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service/llvm_ir:ir_array", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:statusor", ], ) @@ -5333,8 +5416,11 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":pattern_matcher", + "//xla:comparison_util", "//xla:literal", "//xla:shape_util", + "//xla:status", + "//xla:statusor", "//xla:types", "//xla:util", "//xla:window_util", @@ -5342,21 +5428,30 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/stream_executor", "//xla/stream_executor:dnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@tsl//tsl/lib/gtl:map_util", "@tsl//tsl/lib/io:zlib_compression_options", "@tsl//tsl/lib/io:zlib_outputbuffer", "@tsl//tsl/platform:base64", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:regexp", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:thread_annotations", ], alwayslink = 1, ) @@ -5373,7 +5468,9 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -5428,12 +5525,12 @@ cc_library( ":hlo_pass", "//xla:literal", "//xla:shape_util", - "//xla:status_macros", + "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", ], ) @@ -5441,18 +5538,16 @@ xla_cc_test( name = "zero_sized_hlo_elimination_test", srcs = ["zero_sized_hlo_elimination_test.cc"], deps = [ - ":shape_inference", ":zero_sized_hlo_elimination", - "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:status_macros", "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -5462,6 +5557,7 @@ cc_library( hdrs = ["stream_pool.h"], deps = [ "//xla/stream_executor", + "@com_google_absl//absl/strings:str_format", ], ) @@ -5472,6 +5568,7 @@ xla_cc_test( ":stream_pool", "//xla:test_helpers", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/tests:xla_internal_test_main", ], @@ -5551,6 +5648,7 @@ cc_library( "@eigen_archive//:eigen3", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -5704,6 +5802,193 @@ xla_cc_test( ], ) +cc_library( + name = "host_memory_offload_annotations_hdr", + hdrs = ["host_memory_offload_annotations.h"], + deps = [ + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "convert_memory_placement_to_internal_annotations", + srcs = ["convert_memory_placement_to_internal_annotations.cc"], + hdrs = ["convert_memory_placement_to_internal_annotations.h"], + deps = [ + ":host_memory_offload_annotations_hdr", + "//xla:side_effect_util", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_memory_placement_to_internal_annotations_test", + srcs = ["convert_memory_placement_to_internal_annotations_test.cc"], + deps = [ + ":convert_memory_placement_to_internal_annotations", + ":host_memory_offload_annotations_hdr", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "host_memory_transfer_asyncifier", + srcs = ["host_memory_transfer_asyncifier.cc"], + hdrs = ["host_memory_transfer_asyncifier.h"], + deps = [ + ":hlo_pass", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_memory_transfer_asyncifier_test", + srcs = ["host_memory_transfer_asyncifier_test.cc"], + deps = [ + ":host_memory_transfer_asyncifier", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "host_offload_legalize", + srcs = ["host_offload_legalize.cc"], + hdrs = ["host_offload_legalize.h"], + deps = [ + ":call_graph", + ":hlo_alias_analysis", + ":hlo_pass", + ":hlo_value", + ":host_memory_offload_annotations_hdr", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offload_legalize_test", + srcs = ["host_offload_legalize_test.cc"], + shard_count = 12, + deps = [ + ":host_memory_offload_annotations_hdr", + ":host_offload_legalize", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "host_offloader", + srcs = ["host_offloader.cc"], + hdrs = ["host_offloader.h"], + deps = [ + ":hlo_alias_analysis", + ":hlo_buffer", + ":hlo_pass", + ":hlo_value", + ":host_memory_offload_annotations_hdr", + ":pattern_matcher", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offloader_test", + srcs = ["host_offloader_test.cc"], + shard_count = 12, + deps = [ + ":host_memory_offload_annotations_hdr", + ":host_offload_legalize", + ":host_offloader", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "while_util", srcs = ["while_util.cc"], @@ -5735,13 +6020,16 @@ xla_cc_test( srcs = ["while_util_test.cc"], deps = [ ":while_util", + "//xla:statusor", "//xla:test", "//xla:util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@tsl//tsl/platform:statusor", ], ) @@ -5837,7 +6125,6 @@ cc_library( deps = [ ":hlo_dce", ":hlo_pass", - ":tuple_util", ":while_loop_analysis", ":while_util", "//xla:shape_util", @@ -5848,7 +6135,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -5858,11 +6149,17 @@ xla_cc_test( deps = [ ":hlo_parser", ":while_loop_invariant_code_motion", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -5897,6 +6194,40 @@ xla_cc_test( ], ) +cc_library( + name = "fusion_constant_sinking", + srcs = ["fusion_constant_sinking.cc"], + hdrs = ["fusion_constant_sinking.h"], + deps = [ + ":hlo_dce", + ":hlo_pass", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_constant_sinking_test", + srcs = ["fusion_constant_sinking_test.cc"], + deps = [ + ":fusion_constant_sinking", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "while_loop_constant_sinking", srcs = ["while_loop_constant_sinking.cc"], @@ -5904,6 +6235,7 @@ cc_library( deps = [ ":hlo_pass", ":while_util", + "//xla:shape_util", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", @@ -5925,6 +6257,39 @@ xla_cc_test( ], ) +cc_library( + name = "while_loop_fusible_sinking", + srcs = ["while_loop_fusible_sinking.cc"], + hdrs = ["while_loop_fusible_sinking.h"], + deps = [ + ":hlo_pass", + ":while_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "while_loop_fusible_sinking_test", + srcs = ["while_loop_fusible_sinking_test.cc"], + deps = [ + ":while_loop_fusible_sinking", + "//xla:test", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "despecializer", srcs = ["despecializer.cc"], @@ -6274,6 +6639,17 @@ cc_library( visibility = ["//visibility:public"], ) +cc_test( + name = "custom_call_target_registry_test", + srcs = ["custom_call_target_registry_test.cc"], + deps = [ + ":custom_call_status", + ":custom_call_target_registry", + "//xla:test", + "@tsl//tsl/platform:test_main", + ], +) + # Exposes the public interface only and hides internal details. Suitable for # linking into a static library or binary. cc_library( @@ -6296,13 +6672,13 @@ filegroup( "custom_call_status.h", "custom_call_status_internal.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) filegroup( name = "custom_call_status_srcs", srcs = ["custom_call_status.cc"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) # Internal version that exposes internal details and private interfaces. For @@ -6313,10 +6689,10 @@ cc_library( "custom_call_status_internal.h", ], compatible_with = get_compatible_with_portable(), - visibility = [ + visibility = internal_visibility([ ":__subpackages__", "//tensorflow/compiler/tf2xla:__pkg__", - ], + ]), deps = [ ":custom_call_status", "@com_google_absl//absl/strings", @@ -6431,10 +6807,12 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_memory", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:blocking_counter", ], ) @@ -6448,6 +6826,12 @@ cc_library( ":hlo_pass", "//xla:statusor", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -6456,10 +6840,16 @@ xla_cc_test( srcs = ["collective_transformation_reorderer_test.cc"], deps = [ ":collective_transformation_reorderer", + ":hlo_verifier", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -6470,10 +6860,15 @@ xla_cc_test( ":collective_ops_utils", ":computation_placer", ":global_device_id", + ":hlo_parser", + "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -6486,6 +6881,7 @@ cc_library( ":hlo_pass", ":pattern_matcher", "//xla:shape_util", + "//xla:util", "//xla/client:xla_builder", "//xla/client/lib:comparators", "//xla/hlo/ir:hlo", @@ -6525,8 +6921,15 @@ cc_library( ":hlo_creation_utils", ":op_expander_pass", ":shape_inference", + "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -6536,10 +6939,12 @@ xla_cc_test( deps = [ ":operand_upcaster", "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", ], ) @@ -6550,7 +6955,11 @@ cc_library( deps = [ ":op_expander_pass", ":shape_inference", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -6560,10 +6969,12 @@ xla_cc_test( deps = [ ":result_caster", "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", ], ) @@ -6585,10 +6996,9 @@ cc_library( hdrs = ["convert_operand_folding.h"], deps = [ ":op_expander_pass", - "//xla:comparison_util", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/base:core_headers", ], ) @@ -6617,7 +7027,10 @@ cc_library( ":hlo_proto_cc", ":hlo_proto_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:status", ], ) @@ -6626,11 +7039,15 @@ xla_cc_test( name = "xla_debug_info_manager_test", srcs = ["xla_debug_info_manager_test.cc"], deps = [ + ":hlo_module_config", ":hlo_proto_cc", ":xla_debug_info_manager", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest_main", ], ) @@ -6690,20 +7107,63 @@ xla_cc_test( ], ) +cc_library( + name = "lockable", + hdrs = ["lockable.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "lockable_test", + srcs = ["lockable_test.cc"], + deps = [ + ":lockable", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "rendezvous", srcs = ["rendezvous.cc"], hdrs = ["rendezvous.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", ], ) +xla_cc_test( + name = "rendezvous_test", + srcs = ["rendezvous_test.cc"], + deps = [ + ":rendezvous", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "compilation_environments", srcs = ["compilation_environments.cc"], @@ -6856,9 +7316,11 @@ cc_library( name = "change_op_data_type", srcs = ["change_op_data_type.cc"], hdrs = ["change_op_data_type.h"], + copts = tsl_copts(), deps = [ ":hlo_creation_utils", ":hlo_pass", + "//xla/service/cpu:onednn_matmul_rewriter", ], ) @@ -7000,42 +7462,17 @@ build_test( xla_cc_binary( name = "xla_compile", srcs = ["xla_compile_main.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), visibility = ["//visibility:public"], deps = [ - ":compiler", - ":executable", - ":export_hlo", - ":hlo_module_config", - ":symbol_repository", - ":xla_compile_result_proto_cc_impl", - "//xla:autotune_results_proto_cc", - "//xla:debug_options_flags", - "//xla:statusor", - "//xla:util", - "//xla/mlir_hlo", - "//xla/pjrt:mlir_to_hlo", - "//xla/service:cpu_plugin", - "//xla/service/gpu:autotuner_util", - "//xla/service/gpu:gpu_symbol_repository", - "//xla/tools:hlo_module_loader", + ":cpu_plugin", + "//xla:status", "//xla/tools:xla_compile_lib", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/status", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@stablehlo//:register", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:path", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:protobuf", - "@tsl//tsl/platform:status_to_from_proto", "@tsl//tsl/platform:types", - "@tsl//tsl/util:command_line_flags", ] + if_cuda_is_configured([ "//xla/service/gpu:executable_proto_cc", "//xla/service/gpu:gpu_compiler", @@ -7089,19 +7526,6 @@ xla_aot_compile_gpu( module = "xla_aot_compile_test_constant.mlir", ) -xla_aot_compile_gpu( - name = "xla_aot_compile_test_gpu_executable_gemm", - autotune_results = "xla_aot_compile_test_autotune_results.prototxt", - gpu_target_config = "xla_aot_compile_test_gpu_target_config.prototxt", - module = "xla_aot_compile_test_gemm.mlir", -) - -xla_aot_compile_gpu_runtime_autotuning( - name = "xla_aot_compile_test_gpu_executable_gemm_runtime_autotuning", - gpu_target_config = "xla_aot_compile_test_gpu_target_config.prototxt", - module = "xla_aot_compile_test_gemm.mlir", -) - xla_aot_compile_gpu( name = "xla_aot_compile_test_gpu_executable_convolution", autotune_results = "xla_aot_compile_test_autotune_results.prototxt", @@ -7119,15 +7543,22 @@ xla_cc_test( name = "xla_aot_compile_cpu_test", srcs = ["xla_aot_compile_cpu_test.cc"], data = [":xla_aot_compile_test_cpu_executable"], - tags = ["no_oss"], + tags = [ + "no_oss", + "notap", + ], deps = [ ":cpu_plugin", ":platform_util", + ":shaped_buffer", "//xla:executable_run_options", + "//xla:literal", "//xla:literal_util", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", @@ -7144,9 +7575,13 @@ xla_cc_test( deps = [ ":cpu_plugin", ":platform_util", + ":shaped_buffer", + "//xla:error_spec", "//xla:executable_run_options", + "//xla:literal", "//xla:literal_util", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", "//xla/tests:literal_test_util", @@ -7166,14 +7601,9 @@ xla_cc_test( ":xla_aot_compile_test_gpu_executable", ":xla_aot_compile_test_gpu_executable_hlo", ":xla_aot_compile_test_gpu_executable_constant", - ":xla_aot_compile_test_gpu_executable_gemm", - ":xla_aot_compile_test_gpu_executable_gemm_runtime_autotuning", ":xla_aot_compile_test_gpu_executable_convolution", ":xla_aot_compile_test_gpu_executable_convolution_runtime_autotuning", ]), - env = { - "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", - }, tags = [ "gpu", "no_oss", @@ -7184,12 +7614,18 @@ xla_cc_test( deps = if_cuda_is_configured([ ":gpu_plugin_impl", ":platform_util", + ":shaped_buffer", "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "//xla:executable_run_options", "//xla:literal_util", "//xla/client:client_library", "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", + "//xla/client:executable_build_options", + "//xla:literal", + "//xla:shape_util", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", @@ -7222,14 +7658,14 @@ cc_library( deps = [ ":compilation_environments", "//xla:parse_flags_from_env", + "//xla:status", "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", - "@tsl//tsl/util:command_line_flags", ], ) @@ -7241,7 +7677,6 @@ xla_cc_test( ":compilation_environments", ":gpu_compilation_environment", "//xla:parse_flags_from_env", - "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", @@ -7258,7 +7693,6 @@ cc_library( ":compiler", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -7286,4 +7720,19 @@ tf_proto_library( visibility = ["//visibility:public"], ) +cc_library( + name = "algorithm_util", + srcs = ["algorithm_util.cc"], + hdrs = ["algorithm_util.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", + "//xla/stream_executor:blas", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index 57f46a51316f7..df2a0fb60bf39 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -55,6 +56,7 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/host_memory_offload_annotations.h" #include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" @@ -1078,7 +1080,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( HloInstruction* conjunction) { HloInstruction *lhs, *rhs; if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { @@ -1122,6 +1124,15 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( return false; } +Status AlgebraicSimplifierVisitor::HandleAllToAll(HloInstruction* all_to_all) { + if (all_to_all->shape().IsArray() && + Match(all_to_all->mutable_operand(0), + m::Broadcast(m::ConstantScalar()))) { + return ReplaceInstruction(all_to_all, all_to_all->mutable_operand(0)); + } + return OkStatus(); +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -1495,10 +1506,19 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return OkStatus(); } - if (HloInstruction* bitcast_operand = - BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { - ReplaceWithBitcast(copy, bitcast_operand); - return OkStatus(); + const bool copy_is_to_different_memory_space = + options_.is_layout_sensitive() && copy->shape().has_layout() && + copy->operand(0)->shape().has_layout() && + copy->shape().layout().memory_space() != + copy->operand(0)->shape().layout().memory_space(); + if (!copy_is_to_different_memory_space) { + // Do not replace a copy between different memory spaces with a bitcast. + HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(copy, options_); + if (bitcast_operand != nullptr) { + ReplaceWithBitcast(copy, bitcast_operand); + return OkStatus(); + } } // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast. @@ -1789,7 +1809,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return OkStatus(); } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalBitcastConvert( HloInstruction* bitcast) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcastConvert); @@ -1868,12 +1888,14 @@ AlgebraicSimplifierVisitor::TryRemoveUpcastAndDowncastSurroundingBinaryOp( const PrimitiveType bin_op_type = bin_op_instr->shape().element_type(); if (!primitive_util::IsIntegralType(final_type) || !primitive_util::IsIntegralType(bin_op_type) || + primitive_util::Is4BitType(final_type) || + primitive_util::Is4BitType(bin_op_type) || (primitive_util::IsSignedIntegralType(final_type) != primitive_util::IsSignedIntegralType(bin_op_type)) || (primitive_util::IsUnsignedIntegralType(final_type) != primitive_util::IsUnsignedIntegralType(bin_op_type))) { // So far, only the safety of this transformation with same signedness - // integer types has been verified. + // non-4-bit integer types has been verified. // TODO(b/277095299): Add support for floating point types. return OkStatus(); } @@ -2229,8 +2251,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( - HloInstruction* dot) { +absl::StatusOr +AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( + HloDotInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); int64_t num_degenerate_lhs_dims = 0; std::vector lhs_dimension_map(lhs_shape.rank(), -1); @@ -2283,6 +2306,30 @@ StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( } } + std::vector sparsity(dot->sparsity().begin(), + dot->sparsity().end()); + std::vector sparse_meta(sparsity.size()); + for (int i = 0; i < sparsity.size(); ++i) { + // Update sparse dimension number in the descriptor. + SparsityDescriptor& descriptor = sparsity[i]; + const std::vector& dimension_map = + descriptor.index() == 0 ? lhs_dimension_map : rhs_dimension_map; + CHECK_LT(static_cast(descriptor.dimension()), dimension_map.size()); + int preceding_dims_elided = absl::c_count_if( + absl::MakeSpan(dimension_map.data(), descriptor.dimension()), + [&](int64_t dim) { return dim == -1; }); + descriptor.set_dimension(descriptor.dimension() - preceding_dims_elided); + + // Reshape sparsity metadata operand, if affected. + HloInstruction* meta = + dot->mutable_operand(HloDotInstruction::kOperands + i); + Shape new_shape = ShapeUtil::DropDegenerateDimensions(meta->shape()); + if (!ShapeUtil::Equal(new_shape, meta->shape())) { + TF_ASSIGN_OR_RETURN(meta, MakeReshapeHlo(new_shape, meta)); + } + sparse_meta[i] = meta; + } + HloInstruction* new_lhs = num_degenerate_lhs_dims > 0 ? dot->parent()->AddInstruction(HloInstruction::CreateReshape( @@ -2298,8 +2345,9 @@ StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( TF_ASSIGN_OR_RETURN( auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(), - /*preferred_element_type=*/dot->shape().element_type())); + dot->shape().element_type(), sparsity, sparse_meta)); dot->SetupDerivedInstruction(new_dot); + if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); } else { @@ -2387,8 +2435,9 @@ Status AlgebraicSimplifierVisitor::SimplifyTransposeOfBroadcast( transpose->shape())); } -StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( - HloInstruction* dot) { +absl::StatusOr +AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( + HloDotInstruction* dot) { const int64_t rank = dot->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); HloInstruction* lhs = dot->mutable_operand(0); @@ -2449,7 +2498,7 @@ StatusOr AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( return true; } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::NormalizeDotOperandToBatchMajorAndContractingMinor( HloInstruction* dot_operand, absl::Span batch_dimensions, absl::Span contracting_dimensions) { @@ -2482,7 +2531,7 @@ HloInstruction* AlgebraicSimplifierVisitor::AddReduce( shape, hlo, zero, dims, AddReduce_computation)); } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( HloInstruction* dot) { const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1 || @@ -2508,7 +2557,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( lhs_contracting_dim, /*swapped=*/true); } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( +absl::StatusOr +AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && @@ -2625,7 +2675,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } -StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction* dot) { const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1 || @@ -2765,7 +2815,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // associative, so as long as we permute the elements of the contracting // dimensions on both sides of the dot in the same way, the result of the // dot is not affected. -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( HloInstruction* dot) { // This transformation assumes layout is not assigned yet. @@ -2975,8 +3025,13 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( // If appropriate, reorder operation on dot operand to the mirror operation on // the other dot operand -StatusOr -AlgebraicSimplifierVisitor::AssociativeReorderDotOperator(HloInstruction* dot) { +absl::StatusOr +AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( + HloDotInstruction* dot) { + if (dot->sparse_operands()) { + return nullptr; + } + DotDimensionNumbers dnums = dot->dot_dimension_numbers(); HloInstruction* lhs = dot->mutable_operand(0); HloInstruction* rhs = dot->mutable_operand(1); @@ -3232,6 +3287,7 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator(HloInstruction* dot) { Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { CHECK(computation_ == dot->parent()); + HloDotInstruction* dot_cast = Cast(dot); const auto& dnums = dot->dot_dimension_numbers(); HloInstruction *lhs, *rhs; @@ -3297,7 +3353,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Reorder nested dots with associativity using flops as a heuristic - if (options_.use_associative_reordering()) { + if (options_.use_associative_reordering() && !dot_cast->sparse_operands()) { HloInstruction *inner, *outer; HloInstruction *new_inner, *new_outer; @@ -3305,11 +3361,13 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { bool outer_lhs_dot = false; bool outer_rhs_dot = false; - if (lhs->opcode() == HloOpcode::kDot) { + if (lhs->opcode() == HloOpcode::kDot && + !Cast(lhs)->sparse_operands()) { outer = dot; inner = lhs; outer_lhs_dot = true; - } else if (rhs->opcode() == HloOpcode::kDot) { + } else if (rhs->opcode() == HloOpcode::kDot && + !Cast(rhs)->sparse_operands()) { outer = dot; inner = rhs; outer_rhs_dot = true; @@ -3658,7 +3716,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { if (options_.use_associative_reordering()) { TF_ASSIGN_OR_RETURN(HloInstruction * dot_operator_reordered, - AssociativeReorderDotOperator(dot)); + AssociativeReorderDotOperator(dot_cast)); if (dot_operator_reordered) { VLOG(10) << "Reordering dot operand to its mirror"; return ReplaceInstruction(dot, dot_operator_reordered); @@ -3763,13 +3821,13 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, - RemoveDegenerateDimensionFromDot(dot)); + RemoveDegenerateDimensionFromDot(dot_cast)); if (removed_degenerate_dimensions) { return OkStatus(); } TF_ASSIGN_OR_RETURN(bool removed_transposes, - RemoveTransposesFromDotOperands(dot)); + RemoveTransposesFromDotOperands(dot_cast)); if (removed_transposes) { return OkStatus(); } @@ -3834,7 +3892,7 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { } namespace { -StatusOr> MinMaxToClamp( +absl::StatusOr> MinMaxToClamp( HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp, HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) { HloInstruction* clamp_lower_bound; @@ -3899,6 +3957,24 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { } } + // max(max(x, y), y) -> max(x, y) + // max(max(x, y), x) -> max(x, y) + if (Match(lhs, m::MaximumAnyOrder(m::Op(), m::Op().Is(rhs)))) { + return ReplaceInstruction(maximum, lhs); + } + // max(x, max(x, y)) -> max(x, y) + if (Match(rhs, m::Maximum(m::Op().Is(lhs), m::Op()))) { + return ReplaceInstruction(maximum, rhs); + } + // max(y, max(x, y)) -> max(y, x) + // Note that we cannot simplify to max(x, y) here, as for the case that x and + // y are NaN but with different sign, it will make a difference. + if (Match(rhs, m::Maximum(m::Op(), m::Op().Is(lhs)))) { + TF_RETURN_IF_ERROR(maximum->ReplaceOperandWith(1, rhs->mutable_operand(0))); + MarkAsChanged(); + return OkStatus(); + } + HloInstruction* clamp_upper_bound_bcast; HloInstruction* clamp_lower_bound_bcast; HloInstruction* to_clamp; @@ -3953,6 +4029,24 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { } } + // min(min(x, y), y) -> min(x, y) + // min(min(x, y), x) -> min(x, y) + if (Match(lhs, m::MinimumAnyOrder(m::Op(), m::Op().Is(rhs)))) { + return ReplaceInstruction(minimum, lhs); + } + // min(x, min(x, y)) -> min(x, y) + if (Match(rhs, m::Minimum(m::Op().Is(lhs), m::Op()))) { + return ReplaceInstruction(minimum, rhs); + } + // min(y, min(x, y)) -> min(y, x) + // Note that we cannot simplify to min(x, y) here, as for the case that x and + // y are NaN but with different sign, it will make a difference. + if (Match(rhs, m::Minimum(m::Op(), m::Op().Is(lhs)))) { + TF_RETURN_IF_ERROR(minimum->ReplaceOperandWith(1, rhs->mutable_operand(0))); + MarkAsChanged(); + return OkStatus(); + } + HloInstruction* clamp_upper_bound_bcast; HloInstruction* clamp_lower_bound_bcast; HloInstruction* to_clamp; @@ -4468,16 +4562,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); } - // A degenerate broadcast that has the same input and output rank can be - // converted into a transpose. + // A broadcast that has the same input and output rank can be converted into a + // transpose with the inverse of broadcast's dimensions. if (broadcast->shape().rank() == operand->shape().rank() && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { - VLOG(10) << "transform broadcast(X) -> transpose(X) where " - "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, - HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); + broadcast, HloInstruction::CreateTranspose( + broadcast->shape(), operand, + InversePermutation(broadcast->dimensions()))); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -4656,6 +4749,72 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { } } } + + // Below is a common JAX code issue encountered when generating a Causal mask + // The user either neglected to specify `dtype=bool` in `ones()` + // or mistakenly applied `.astype(bool)` to the result of `tril()` instead of + // to `ones()`. Consequently, the mask will be converted from f32 to bool, + // resulting in suboptimal HLO. + // + // mask = jnp.tril(jnp.ones((seq_len, seq_len))) + // res = jnp.where(mask, x, -jnp.inf) + // + // # it will be lowered to the following suboptimal HLO + // %cmp0 = pred compare(s32, s32, direction=GE) + // %sel0 = f32 select(%cmp0, ones, zeros) + // %cmp1 = pred compare(%sel0, zeros, direction=NE) + // + // # which can be simplified to just + // %cmp0 = pred compare(s32, s32, direction=GE) + // + // Simplification: + // Ne(select(Ge(a, b), ones, zeros), zeros) -> Ge(a, b) + if (compare->comparison_direction() == ComparisonDirection::kNe && + IsAll(rhs, 0)) { + HloInstruction* compare0; + HloInstruction* sel_on_true; + HloInstruction* sel_on_false; + if (Match(lhs, + m::Select(m::Op(&compare0) + .WithOpcode(HloOpcode::kCompare) + .WithComparisonDirection(ComparisonDirection::kGe), + m::Op(&sel_on_true), m::Op(&sel_on_false))) && + IsAll(sel_on_true, 1) && IsAll(sel_on_false, 0) && + SameShape(compare->shape(), compare0->shape())) { + return ReplaceInstruction(compare, compare0); + } + } + + // Gt(Max(a,b), a) -> Gt(b,a) + // Gt(Max(a,b), b) -> Gt(a,b) + // Gt(a, Min(a,b)) -> Gt(a,b) + // Gt(b, Min(a,b)) -> Gt(b,a) + if (compare->comparison_direction() == ComparisonDirection::kGt) { + HloInstruction* a; + HloInstruction* b; + if (Match(lhs, m::Maximum(m::Op(&a), m::Op(&b)))) { + if (rhs == a) { // Gt(Max(a,b), a) -> Gt(b,a) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, b)); + MarkAsChanged(); + return OkStatus(); + } else if (rhs == b) { // Gt(Max(a,b), b) -> Gt(a,b) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, a)); + MarkAsChanged(); + return OkStatus(); + } + } else if (Match(rhs, m::Minimum(m::Op(&a), m::Op(&b)))) { + if (lhs == a) { // Gt(a, Min(a,b)) -> Gt(a,b) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, b)); + MarkAsChanged(); + return OkStatus(); + } else if (lhs == b) { // Gt(b, Min(a,b)) -> Gt(b,a) + TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, a)); + MarkAsChanged(); + return OkStatus(); + } + } + } + return OkStatus(); } @@ -4688,6 +4847,27 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return TryRemoveUpcastAndDowncastSurroundingBinaryOp(convert); } +Status AlgebraicSimplifierVisitor::HandleCustomCall( + HloInstruction* custom_call) { + // Remove redundant slice to dynamic of pad to static + HloInstruction *pad_to_static0, *pad_to_static1, *pad_to_static_operand; + if (Match( + custom_call, + m::CustomCall( + {"SliceToDynamic"}, + m::GetTupleElement(m::CustomCall(&pad_to_static0, {"PadToStatic"}, + m::Op(&pad_to_static_operand)), + 0), + m::GetTupleElement( + m::CustomCall(&pad_to_static1, {"PadToStatic"}, m::Op()), + 1))) && + pad_to_static0 == pad_to_static1 && + SameShape(custom_call->shape(), pad_to_static_operand->shape())) { + return ReplaceInstruction(custom_call, pad_to_static_operand); + } + return OkStatus(); +} + // Complex(Real(c), Imag(c)) -> c Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { HloInstruction *c0, *c1; @@ -4979,7 +5159,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return OkStatus(); } -StatusOr +absl::StatusOr AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* broadcast) { TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); @@ -5064,9 +5244,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( new_operands.push_back(operand); } } - VLOG(4) << "Sinking broadcast after user:" - << "\n old broadcast: " << broadcast->ToString() - << "\n old user: " << user->ToString(); + VLOG(4) << "Sinking broadcast after user:" << "\n old broadcast: " + << broadcast->ToString() << "\n old user: " << user->ToString(); changed_shape = ShapeUtil::ChangeElementType(operand->shape(), user->shape().element_type()); simplifier_->UpdateLayout(&changed_shape); @@ -5606,7 +5785,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( +absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { // Only try to do this for effective scalars. We could do the same for slicing // out larger pieces of padding (replacing with a broadcast of the padding @@ -5658,7 +5837,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } -StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( +absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( HloInstruction* slice) { CHECK_EQ(slice->opcode(), HloOpcode::kSlice); if (!IsUnstridedSlice(slice)) { @@ -5714,7 +5893,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( // Allowing a slice to move through a reverse with any necessary updates to the // slice config. -StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( +absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( HloInstruction* slice) { VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" << slice->ToString(); @@ -5904,7 +6083,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { options_.use_associative_reordering() && slice->operand(0)->opcode() == HloOpcode::kDot) { // Unpack the dot operands - HloInstruction* dot = slice->mutable_operand(0); + HloDotInstruction* dot = Cast(slice->mutable_operand(0)); HloInstruction* lhs = dot->mutable_operand(0); HloInstruction* rhs = dot->mutable_operand(1); DotDimensionNumbers dnums = dot->dot_dimension_numbers(); @@ -5918,50 +6097,61 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { bool slice_lhs = false; bool slice_rhs = false; + // Sparse metadata may need to be sliced. + std::array sparse_meta = {nullptr, nullptr}; + for (int i = 0; i < dot->sparse_operands(); ++i) { + const SparsityDescriptor& descriptor = dot->sparsity()[i]; + sparse_meta[descriptor.index()] = + dot->mutable_operand(HloDotInstruction::kOperands + i); + } + auto slice_meta = [&](const DimensionVector& operand_start_indices, + const DimensionVector& operand_limit_indices, + const DimensionVector& operand_strides, + HloInstruction* meta, int dimension) { + DimensionVector start_indices, limit_indices, strides; + for (int64_t i = 0; i < meta->shape().rank(); ++i) { + start_indices.push_back(operand_start_indices[i]); + limit_indices.push_back(i != dimension ? operand_limit_indices[i] + : meta->shape().dimensions(i)); + strides.push_back(operand_strides[i]); + } + return MakeSliceHlo(meta, start_indices, limit_indices, strides); + }; + // Here we build up the slice dimensions for lhs DimensionVector lhs_start_indices, lhs_limit_indices, lhs_strides; for (int64_t lhs_index = 0; lhs_index < lhs->shape().rank(); ++lhs_index) { - int64_t start = 0; - int64_t limit = lhs->shape().dimensions(lhs_index); - int64_t stride = 1; - if (map_lhs_dot[lhs_index] != -1) { - // If it is not a contracting dimension, we slice it according to the - // slicing of the corresponding dimension in dot - int64_t dot_index = map_lhs_dot[lhs_index]; - start = slice->slice_starts(dot_index); - limit = slice->slice_limits(dot_index); - stride = slice->slice_strides(dot_index); - } + int64_t size = lhs->shape().dimensions(lhs_index); + // If it is not a contracting dimension, we slice it according to the + // slicing of the corresponding dimension in dot + int64_t i = map_lhs_dot[lhs_index]; + int64_t start = i >= 0 ? slice->slice_starts(i) : 0; + int64_t limit = i >= 0 ? slice->slice_limits(i) : size; + int64_t stride = i >= 0 ? slice->slice_strides(i) : 1; lhs_start_indices.push_back(start); lhs_limit_indices.push_back(limit); lhs_strides.push_back(stride); // Record if any slicing occurs here - if (start != 0 || limit < lhs->shape().dimensions(lhs_index)) { - slice_lhs = true; - } + bool update = start != 0 || limit < size || stride != 1; + slice_lhs |= update; } // Here we do the same for rhs DimensionVector rhs_start_indices, rhs_limit_indices, rhs_strides; for (int64_t rhs_index = 0; rhs_index < rhs->shape().rank(); ++rhs_index) { - int64_t start = 0; - int64_t limit = rhs->shape().dimensions(rhs_index); - int64_t stride = 1; - if (map_rhs_dot[rhs_index] != -1) { - // If it is not a contracting dimension, we slice it according to the - // slicing of the corresponding dimension in dot - int64_t dot_index = map_rhs_dot[rhs_index]; - start = slice->slice_starts(dot_index); - limit = slice->slice_limits(dot_index); - stride = slice->slice_strides(dot_index); - } + int64_t size = rhs->shape().dimensions(rhs_index); + // If it is not a contracting dimension, we slice it according to the + // slicing of the corresponding dimension in dot + int64_t i = map_rhs_dot[rhs_index]; + int64_t start = i >= 0 ? slice->slice_starts(i) : 0; + int64_t limit = i >= 0 ? slice->slice_limits(i) : size; + int64_t stride = i >= 0 ? slice->slice_strides(i) : 1; rhs_start_indices.push_back(start); rhs_limit_indices.push_back(limit); rhs_strides.push_back(stride); // Record if any slicing occurs here - if (start != 0 || limit < rhs->shape().dimensions(rhs_index)) { - slice_rhs = true; - } + bool update = start != 0 || limit < size || stride != 1; + slice_rhs |= update; } // Create Hlo for new slices @@ -5978,11 +6168,37 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { MakeSliceHlo(rhs, rhs_start_indices, rhs_limit_indices, rhs_strides)); } + // Create Hlo for new metadata (for sparse dot) + std::vector new_sparsity; + std::vector new_meta; + if (dot->sparse_operands()) { + if (auto& lhs = dot->sparsity().front(); lhs.index() == 0) { + if (slice_lhs) { + TF_ASSIGN_OR_RETURN( + sparse_meta[0], + slice_meta(lhs_start_indices, lhs_limit_indices, lhs_strides, + sparse_meta[0], lhs.dimension())); + } + new_sparsity.push_back(lhs); + new_meta.push_back(sparse_meta[0]); + } + if (auto& rhs = dot->sparsity().back(); rhs.index() == 1) { + if (slice_rhs) { + TF_ASSIGN_OR_RETURN( + sparse_meta[1], + slice_meta(rhs_start_indices, rhs_limit_indices, rhs_strides, + sparse_meta[1], rhs.dimension())); + } + new_sparsity.push_back(rhs); + new_meta.push_back(sparse_meta[1]); + } + } + // Finally, create Hlo for the new dot and reorder - HloInstruction* new_dot; TF_ASSIGN_OR_RETURN( - new_dot, MakeDotHlo(new_lhs, new_rhs, dnums, dot->precision_config(), - dot->shape().element_type())); + HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, dnums, dot->precision_config(), + dot->shape().element_type(), new_sparsity, new_meta)); // We should only do this reorder if both new_lhs and new_rhs have free // dimensions. Otherwise, it will conflict with an existing optimization @@ -6225,15 +6441,21 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( transpose->dimensions())); } - // Convert a dynamic slice into a slice if all offsets are constant and the - // operand is not constant. + // Convert a dynamic slice into a slice if all offsets are constant, the + // operand is not constant, and the input and output memory spaces are the + // same. if (operand->opcode() != HloOpcode::kConstant && absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, dynamic_slice->operands().end()), [](HloInstruction* operand) { return operand->opcode() == HloOpcode::kConstant && ShapeUtil::ElementIsIntegral(operand->shape()); - })) { + }) && + (!options_.is_layout_sensitive() || + (dynamic_slice->shape().has_layout() && + dynamic_slice->operand(0)->shape().has_layout() && + dynamic_slice->shape().layout().memory_space() == + dynamic_slice->operand(0)->shape().layout().memory_space()))) { const int64_t rank = operand->shape().rank(); std::vector slice_starts(rank); std::vector slice_limits(rank); @@ -6395,6 +6617,20 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( compatible = false; } } + + const auto custom_call_pattern = m::CustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget}); + if (Match(dus_update, + m::AnyOf(m::Reshape(custom_call_pattern), + m::Bitcast(custom_call_pattern), + custom_call_pattern))) { + // If this dynamic-update-slice is used for host memory offloading, it + // should not be converted into a pad. Also allow for a reshape or a + // bitcast between the host-offloading custom-call and the + // dynamic-update-slice. + compatible = false; + } + PaddingConfig padding_config; if (compatible) { for (int64_t dim = 0; dim < updated_shape.rank(); ++dim) { @@ -6710,6 +6946,20 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { return OkStatus(); } + HloInstruction* negate_arg; + if (ShapeUtil::ElementIsFloating(reduce->shape()) && + Match(arg, m::Negate(m::Op(&negate_arg))) && + IsScalarConstantZero(init_value) && + Match(reduce->to_apply()->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { + TF_RETURN_IF_ERROR(reduce->ReplaceOperandWith(0, negate_arg)); + auto users = reduce->users(); + auto* negated_reduce = arg->AddInstruction(HloInstruction::CreateUnary( + reduce->shape(), HloOpcode::kNegate, reduce)); + MarkAsChanged(); + return reduce->ReplaceUsesWith(users, negated_reduce); + } + // Try to reorder reduce(dot(A, B)) to dot(A, reduce(B)) if (options_.use_associative_reordering()) { HloInstruction *a, *b; @@ -6720,7 +6970,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { IsScalarConstantZero(init_value) && Match(reduce->to_apply()->root_instruction(), m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) && - arg->dot_dimension_numbers().lhs_batch_dimensions().empty()) { + arg->dot_dimension_numbers().lhs_batch_dimensions().empty() && + !Cast(arg)->sparse_operands()) { // Create maps for converting AB dimensions to A and B DotDimensionNumbers ab_dnums = arg->dot_dimension_numbers(); std::vector map_ab_a, map_ab_b; @@ -7020,9 +7271,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(reduce->to_apply()->root_instruction(), m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) && - absl::c_any_of(reduce->dimensions(), [&](int64_t dim) { - return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size(); - })) { + absl::c_any_of( + reduce->dimensions(), + [&](int64_t dim) { + return dim < + dot->dot_dimension_numbers().lhs_batch_dimensions_size(); + }) && + !Cast(dot)->sparse_operands()) { const auto& dnums = dot->dot_dimension_numbers(); DotDimensionNumbers new_dnums = dnums; new_dnums.clear_lhs_batch_dimensions(); @@ -7125,6 +7380,29 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } + // For Computation equal to Min, Max, And or Or, replace Reduce(Broadcast(x), + // a, Computation()) with Computation(x, a) when x is a scalar and the + // broadcast is reduced to a scalar. + if (HloInstruction * broadcast_arg; + Match(arg, m::Broadcast(m::Op(&broadcast_arg))) && + (Match(function->root_instruction(), + m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::AndAnyOrder(m::Parameter(0), m::Parameter(1))) || + Match(function->root_instruction(), + m::OrAnyOrder(m::Parameter(0), m::Parameter(1))))) { + if (broadcast_arg->shape().rank() == 0 && + reduce->dimensions().size() == arg->shape().rank()) { + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBinary( + reduce_result_shape, function->root_instruction()->opcode(), + broadcast_arg, reduce->mutable_operand(1))); + } + } + return OkStatus(); } @@ -7659,7 +7937,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { } // Convert transpose(dot(a,b)) to dot(b,a). - auto do_transpose_of_dot = [&]() -> StatusOr { + auto do_transpose_of_dot = [&]() -> absl::StatusOr { if (options_.supports_non_canonical_dots() || operand->opcode() != HloOpcode::kDot || operand->user_count() != 1) { return false; @@ -7716,7 +7994,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { if (options_.supports_non_canonical_dots() && Match(operand, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs))) && dot->user_count() == 1) { - TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> absl::StatusOr { const auto& dnums = dot->dot_dimension_numbers(); const int64_t num_batch_dims = dnums.lhs_batch_dimensions_size(); for (int64_t i = 0; i < num_batch_dims; ++i) { @@ -7753,10 +8031,23 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { new_dnums.add_rhs_batch_dimensions( dnums.lhs_batch_dimensions(transpose->dimensions(batch_dim))); } + + HloDotInstruction* dot_cast = Cast(dot); + int size = dot_cast->sparse_operands(); // 0..2 + std::vector sparsity(size); + std::vector sparse_meta(size); + for (int i = 0; i < size; ++i) { + SparsityDescriptor descriptor = dot_cast->sparsity()[i]; + descriptor.set_index(1 - descriptor.index()); + sparsity[size - i - 1] = descriptor; + sparse_meta[size - i - 1] = + dot_cast->mutable_operand(HloDotInstruction::kOperands + i); + } + HloInstruction* new_dot = MakeDotHlo(rhs, lhs, new_dnums, SwapOperandsInDotPrecisionConfig(dot->precision_config()), - dot->shape().element_type()) + dot->shape().element_type(), sparsity, sparse_meta) .value(); *new_dot->mutable_shape()->mutable_layout() = transpose->shape().layout(); @@ -7790,14 +8081,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { // Replace reshape of a transpose of a reshape with concatenated slicing if // the reshape/transpose combination can be interpreted as a space-to-depth // transformation. - if (operand->opcode() == HloOpcode::kReshape && + if (!options_.is_layout_sensitive() && + operand->opcode() == HloOpcode::kReshape && transpose->user_count() == 1 && HloOpcode::kReshape == transpose->users()[0]->opcode()) { VLOG(2) << "trying depth-to-space transform"; HloInstruction* reshape_operand = operand->mutable_operand(0); HloInstruction* outer_reshape = transpose->users()[0]; TF_ASSIGN_OR_RETURN( - bool did_transform, ([&]() -> StatusOr { + bool did_transform, ([&]() -> absl::StatusOr { if (operand->shape().dimensions_size() != reshape_operand->shape().dimensions_size() + 1) { return false; @@ -7941,7 +8233,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return OkStatus(); } -StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( +absl::StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { HloInstruction *lhs, *a, *b; if (Match(convolution, @@ -8002,7 +8294,7 @@ StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( return false; } -StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( +absl::StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); auto* rhs = convolution->mutable_operand(1); @@ -8068,7 +8360,7 @@ StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( return true; } -StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( +absl::StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( HloInstruction* convolution) { if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) { return false; @@ -8205,7 +8497,7 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( return true; } -StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( +absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); auto* rhs = convolution->mutable_operand(1); @@ -8327,7 +8619,7 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( return true; } -StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( +absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( HloInstruction* convolution) { if (options_.is_layout_sensitive() || absl::c_linear_search(convolution->precision_config().operand_precision(), @@ -8515,7 +8807,7 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { return ReplaceWithNewInstruction(map, std::move(clone)); } -StatusOr AlgebraicSimplifier::Run( +absl::StatusOr AlgebraicSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/algebraic_simplifier.h b/xla/service/algebraic_simplifier.h index 23207449735e5..6ec43b9928dce 100644 --- a/xla/service/algebraic_simplifier.h +++ b/xla/service/algebraic_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/util.h" @@ -71,6 +72,11 @@ class AlgebraicSimplifierOptions { return conv_is_lowerable_callback_(reverse_dims); } + void set_conv_is_lowerable_callback( + ConvIsLowerableCallback conv_is_lowerable_callback) { + conv_is_lowerable_callback_ = std::move(conv_is_lowerable_callback); + } + // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. void set_is_layout_sensitive(bool is_layout_sensitive) { @@ -282,7 +288,7 @@ class AlgebraicSimplifier : public HloModulePass { // Run algebraic simplification on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -313,6 +319,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleAdd(HloInstruction* add) override; + Status HandleAllToAll(HloInstruction* all_to_all) override; + Status HandleAnd(HloInstruction* logical_and) override; Status HandleBitcast(HloInstruction* bitcast) override; @@ -333,6 +341,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleComplex(HloInstruction* complex) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleReal(HloInstruction* real) override; Status HandleImag(HloInstruction* imag) override; @@ -441,7 +451,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { private: // Removes degenerate dimension from dot. - StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); + absl::StatusOr RemoveDegenerateDimensionFromDot(HloDotInstruction* dot); // Moves the transpose to the broadcast if possible. Can also be called with a // bitcast transpose. @@ -464,7 +474,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Transposes a dot operand such that the batch dimensions are the most major, // and the contracting dimensions are most minor. - StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + absl::StatusOr + NormalizeDotOperandToBatchMajorAndContractingMinor( HloInstruction* dot_operand, absl::Span batch_dimensions, absl::Span contracting_dimensions); @@ -475,7 +486,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // // LHS [batch dims..., non-contracting dim, contracting dim] // RHS [batch dims..., contracting dim, non-contracting dim]. - StatusOr RemoveTransposesFromDotOperands(HloInstruction* dot); + absl::StatusOr RemoveTransposesFromDotOperands(HloDotInstruction* dot); // Helper method to perform and add reduction on a list of dimensions. HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims, @@ -519,20 +530,21 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // A Broadcast that feeds an element-wise operation with a unique non-scalar // operand can sink to after the operation. - StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + absl::StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* broadcast); - StatusOr OptimizeDotOfConcat(HloInstruction* dot); - StatusOr OptimizeDotOfConcatHelper( + absl::StatusOr OptimizeDotOfConcat(HloInstruction* dot); + absl::StatusOr OptimizeDotOfConcatHelper( HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped); - StatusOr OptimizeDotOfGather(HloInstruction* dot); + absl::StatusOr OptimizeDotOfGather(HloInstruction* dot); - StatusOr OptimizeDotOfReorderContractingDims( + absl::StatusOr OptimizeDotOfReorderContractingDims( HloInstruction* dot); - StatusOr AssociativeReorderDotOperator(HloInstruction* dot); + absl::StatusOr AssociativeReorderDotOperator( + HloDotInstruction* dot); HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { HloComputation*& scalar_add_computation = scalar_add_computations_[type]; @@ -556,37 +568,39 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to fold a kPad in the input or filter into the convolution // instruction's window. - virtual StatusOr FoldConvInputPad(HloInstruction* convolution); - StatusOr FoldConvFilterPad(HloInstruction* convolution); + virtual absl::StatusOr FoldConvInputPad(HloInstruction* convolution); + absl::StatusOr FoldConvFilterPad(HloInstruction* convolution); // Tries to swap convolution operands if they would result in a more efficient // convolution. - StatusOr SwapConvOperands(HloInstruction* convolution); + absl::StatusOr SwapConvOperands(HloInstruction* convolution); // Tries to use a kDot in place of the given convolution. - StatusOr SimplifyConvToDot(HloInstruction* convolution); + absl::StatusOr SimplifyConvToDot(HloInstruction* convolution); // Tries to use a multiplication in place of the given convolution. - StatusOr SimplifyConvToMultiply(HloInstruction* convolution); + absl::StatusOr SimplifyConvToMultiply(HloInstruction* convolution); // Tries to simplify a slice where the result of the slice is a scalar. - StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + absl::StatusOr TrySimplifyScalarSlice(HloInstruction* slice); // Tries to convert slice(reshape(X)) into reshape(slice(X)) - StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + absl::StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); // Tries to convert slice(reverse(X)) into reverse(slice(X)) - StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); + absl::StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into // `(< a N)`. This is crucial for being able to figure out the loop trip // count. // // Assumes that the input is conjunction. - StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + absl::StatusOr TrySimplifyTautologicalCompare( + HloInstruction* conjunction); // Tries to simlplify (bitcast-convert (concat (bitcast-convert A) ...)) where // the types of inner and outer bitcast-convert cancel out. - StatusOr TrySimplifyTautologicalBitcastConvert(HloInstruction* bitcast); + absl::StatusOr TrySimplifyTautologicalBitcastConvert( + HloInstruction* bitcast); // Tries to remove surrounding converts around a binary op where the op has a // more precise type than its inputs and output. diff --git a/xla/service/algebraic_simplifier_overflow_test.cc b/xla/service/algebraic_simplifier_overflow_test.cc index 3196b610cd41a..8e011d6d24edf 100644 --- a/xla/service/algebraic_simplifier_overflow_test.cc +++ b/xla/service/algebraic_simplifier_overflow_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/algebraic_simplifier_proof_distributive_property.py b/xla/service/algebraic_simplifier_proof_distributive_property.py index 0f6a8ce67522b..6c4e78ba9a82e 100644 --- a/xla/service/algebraic_simplifier_proof_distributive_property.py +++ b/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index a156077655ab1..83515ae5615cf 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/host_memory_offload_annotations.h" #include "xla/service/layout_assignment.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -928,6 +929,78 @@ TEST_F(AlgebraicSimplifierTest, ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, ReduceOfNegate) { + const char* kModuleStr = R"( + HloModule m + add_f32 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] add(p0, p1) + } + + ENTRY test { + p = f32[15,7] parameter(0) + n = negate(p) + ROOT reduce = f32[15] reduce(n, f32[] constant(0)), dimensions={1}, to_apply=add_f32 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options = default_options_; + options.set_unconditionally_simplify_reduce_of_transpose_or_reshape(true); + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Negate(m::Reduce(m::Parameter(0), m::ConstantScalar(0))))); +} + +TEST_F(AlgebraicSimplifierTest, ReduceBroadcastOfScalar) { + // Test Reduce(Broadcast(x), a, Max) + const char* kModuleStrForMax = R"( + HloModule m + max_f32 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] maximum(p0, p1) + } + + ENTRY test { + p = f32[] parameter(0) + b = f32[1000,1000] broadcast(p), dimensions={} + ROOT reduce = f32[] reduce(b, f32[] constant(0)), dimensions={0,1}, to_apply=max_f32 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kModuleStrForMax)); + AlgebraicSimplifierOptions options = default_options_; + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MaximumAnyOrder(m::Parameter(0), m::ConstantScalar(0)))); + + // Test Reduce(Broadcast(x), a, And) + const char* kModuleStrForAnd = R"( + HloModule m + and_u4 { + p0 = u4[] parameter(0) + p1 = u4[] parameter(1) + ROOT r = u4[] and(p0, p1) + } + + ENTRY test { + p = u4[] parameter(0) + b = u4[1000,1000] broadcast(p), dimensions={} + ROOT reduce = u4[] reduce(b, u4[] constant(0)), dimensions={0,1}, to_apply=and_u4 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kModuleStrForAnd)); + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), m::ConstantScalar(0)))); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { auto m = CreateNewVerifiedModule(); @@ -2427,6 +2500,150 @@ ENTRY test { } } +TEST_F(AlgebraicSimplifierTest, MinimumOfMinimum1) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + min1 = f32[] minimum(x, y) + ROOT min = f32[] minimum(min1, y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Minimum(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, MinimumOfMinimum2) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + min1 = f32[] minimum(x, y) + ROOT min = f32[] minimum(min1, x) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Minimum(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, MinimumOfMinimum3) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + min1 = f32[] minimum(x, y) + ROOT min = f32[] minimum(y, min1) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Minimum(m::Parameter(1), m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, MinimumOfMinimum4) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + min1 = f32[] minimum(x, y) + ROOT min = f32[] minimum(x, min1) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Minimum(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, MaximumOfMaximum1) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + max1 = f32[] maximum(x, y) + ROOT max = f32[] maximum(max1, y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Maximum(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, MaximumOfMaximum2) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + max1 = f32[] maximum(x, y) + ROOT max = f32[] maximum(max1, x) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Maximum(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, MaximumOfMaximum3) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + max1 = f32[] maximum(x, y) + ROOT max = f32[] maximum(y, max1) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Maximum(m::Parameter(1), m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, MaximumOfMaximum4) { + const char* const hlo_string = R"( +HloModule test + +ENTRY main { + x = f32[] parameter(0) + y = f32[] parameter(1) + max1 = f32[] maximum(x, y) + ROOT max = f32[] maximum(x, max1) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Maximum(m::Parameter(0), m::Parameter(1)))); +} + TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Add) { const char* const hlo_string = R"( HloModule test @@ -2883,7 +3100,6 @@ TEST_F(AlgebraicSimplifierTest, DoNotRemoveUnaryConcatenateWithCtrlDep) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Concatenate(m::Parameter(0)))); - LOG(ERROR) << "module: " << m->ToString(); AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).value()); @@ -3296,6 +3512,37 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { EXPECT_THAT(computation->root_instruction(), param0); } +// Test that a simplification which changes copy to a bitcast is not performed +// if layout sensitive is true. +TEST_F(AlgebraicSimplifierTest, CopyWithDifferentMemorySpaces) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); + HloInstruction* copy = builder.AddInstruction( + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + + // Set to different memory spaces. + param0->mutable_shape()->mutable_layout()->set_memory_space(0); + copy->mutable_shape()->mutable_layout()->set_memory_space(123); + + HloComputation* computation = + m->AddEntryComputationWithLayouts(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(m.get()).value()); + + // Copy has not been removed. + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); +} + // Test that a reshape which could be replaced with a bitcast is not if // add_bitcasts is false. TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { @@ -5415,7 +5662,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeReshapeToConcatSlice) { HloModule TransposeReshapeDepthToSpace ENTRY entry { - %param = f32[8,14,14,128]{0,1,2,3} parameter(0) + %param = f32[8,14,14,128] parameter(0) %reshape.1 = f32[8,14,14,2,64] reshape(%param) %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4} ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose) @@ -5442,7 +5689,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeReshapeTooLarge) { HloModule TransposeReshapeDepthToSpaceBig ENTRY entry { - %param = f32[8,14,14,128]{0,1,2,3} parameter(0) + %param = f32[8,14,14,128] parameter(0) %reshape.1 = f32[8,14,14,8,16] reshape(%param) %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4} ROOT %reshape.2 = f32[8,112,14,16] reshape(%transpose) @@ -5462,7 +5709,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeReshapeNotDepthToSpace) { HloModule TransposeReshapeDepthToSpace ENTRY entry { - %param = f32[8,14,14,128]{0,1,2,3} parameter(0) + %param = f32[8,14,14,128] parameter(0) %reshape.1 = f32[8,14,14,2,64] reshape(%param) %transpose = transpose(%reshape.1), dimensions={0,3,1,2,4} ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose) @@ -6076,6 +6323,30 @@ TEST_F(AlgebraicSimplifierTest, SliceDotReorder) { GmockMatch(m::Dot(m::Slice(m::Parameter(0)), m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, SliceDotReorderWithStrides) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[2048,2] parameter(0) + b = f32[2,2048] parameter(1) + dot = f32[2048,2048] dot(a,b), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + ROOT slice = f32[16,256] slice(dot), slice={[0:128:8],[0:2048:8]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).value()); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Slice(m::Parameter(0)), m::Slice(m::Parameter(1))))); +} + TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDot) { const char* hlo_string = R"( HloModule module @@ -6162,6 +6433,11 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfNonCanonicalBatchDotCantSimplify) { } TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTranspose) { + // This test is without layouts so we have to set the verifier to be layout + // insensitive. + verifier_layout_sensitive_ = false; + instruction_can_change_layout_func_ = {}; + const char* hlo_string = R"( HloModule module @@ -7063,6 +7339,148 @@ ENTRY AddBroadcastZeroWithDynamicSlice { EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad); } +// Test that dynamic-update-slice with a scalar broadcast does not become a pad +// if the dynamic-update-slice is for host memory offload. +TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadHostOffload) { + const std::string hlo_string = absl::StrFormat( + R"( +HloModule DynamicUpdateSliceOfBroadcastToPadHostOffload + +ENTRY DynamicUpdateSliceOfBroadcastToPadHostOffload { + constant_bf16_0 = bf16[] constant(0) + broadcast_0 = bf16[56,2,2048,2,128] broadcast(constant_bf16_0), dimensions={} + param_0 = bf16[1,2,2048,2,128] parameter(0) + custom_call = bf16[1,2,2048,2,128] custom-call(param_0), custom_call_target="%s" + constant_s32_0 = s32[] constant(0) + ROOT dynamic_update_slice = bf16[56,2,2048,2,128] dynamic-update-slice(broadcast_0, custom_call, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0) +} +)", + host_memory_offload_annotations::kMoveToHostCustomCallTarget); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + VLOG(2) << "Before rewrite dus->pad\n" << module->ToString(); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(module.get()).value()); + VLOG(2) << "After rewrite dus->pad\n" << module->ToString(); + // Look for the following pattern: + // constant(0) param(0) + // | | + // broadcast custom-call constant(0) + // | | / + // | | / + // | | / + // | | / + // | | / + // dynamic-update-slice + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice( + m::Broadcast(m::ConstantScalar(0)), + m::CustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget}, + m::Parameter(0)), + m::ConstantScalar(0), m::ConstantScalar(0), m::ConstantScalar(0), + m::ConstantScalar(0), m::ConstantScalar(0)))); +} + +// Test that dynamic-update-slice with a scalar broadcast does not become a pad +// if the dynamic-update-slice is for host memory offload. Also disable +// optimization if there is a reshape between the custom-call and the +// dynamic-update-slice. +TEST_F(AlgebraicSimplifierTest, + DynamicUpdateSliceOfBroadcastToPadHostOffloadWithReshape) { + const std::string hlo_string = absl::StrFormat( + R"( +HloModule DynamicUpdateSliceOfBroadcastToPadHostOffloadWithReshape + +ENTRY DynamicUpdateSliceOfBroadcastToPadHostOffloadWithReshape { + constant_bf16_0 = bf16[] constant(0) + broadcast_0 = bf16[56,2,2048,2,128] broadcast(constant_bf16_0), dimensions={} + param_0 = bf16[2,2048,2,128] parameter(0) + custom_call = bf16[2,2048,2,128] custom-call(param_0), custom_call_target="%s" + reshape = bf16[1,2,2048,2,128] reshape(custom_call) + constant_s32_0 = s32[] constant(0) + ROOT dynamic_update_slice = bf16[56,2,2048,2,128] dynamic-update-slice(broadcast_0, reshape, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0) +} +)", + host_memory_offload_annotations::kMoveToHostCustomCallTarget); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + VLOG(2) << "Before rewrite dus->pad\n" << module->ToString(); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(module.get()).value()); + VLOG(2) << "After rewrite dus->pad\n" << module->ToString(); + // Look for the following pattern: + // param(0) + // | + // constant(0) custom-call + // | | + // broadcast reshape constant(0) + // | | / + // | | / + // | | / + // | | / + // dynamic-update-slice + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice( + m::Broadcast(m::ConstantScalar(0)), + m::Reshape(m::CustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget}, + m::Parameter(0))), + m::ConstantScalar(0), m::ConstantScalar(0), m::ConstantScalar(0), + m::ConstantScalar(0), m::ConstantScalar(0)))); +} + +// Test that dynamic-update-slice with a scalar broadcast does not become a pad +// if the dynamic-update-slice is for host memory offload. Also disable +// optimization if there is a bitcast between the custom-call and the +// dynamic-update-slice. +TEST_F(AlgebraicSimplifierTest, + DynamicUpdateSliceOfBroadcastToPadHostOffloadWithBitcast) { + const std::string hlo_string = absl::StrFormat( + R"( +HloModule DynamicUpdateSliceOfBroadcastToPadHostOffloadWithBitcast + +ENTRY DynamicUpdateSliceOfBroadcastToPadHostOffloadWithBitcast { + constant_bf16_0 = bf16[] constant(0) + broadcast_0 = bf16[56,2,2048,2,128] broadcast(constant_bf16_0), dimensions={} + param_0 = bf16[2,2048,2,128] parameter(0) + custom_call = bf16[2,2048,2,128] custom-call(param_0), custom_call_target="%s" + bitcast = bf16[1,2,2048,2,128] bitcast(custom_call) + constant_s32_0 = s32[] constant(0) + ROOT dynamic_update_slice = bf16[56,2,2048,2,128] dynamic-update-slice(broadcast_0, bitcast, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0) +} +)", + host_memory_offload_annotations::kMoveToHostCustomCallTarget); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + VLOG(2) << "Before rewrite dus->pad\n" << module->ToString(); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(module.get()).value()); + VLOG(2) << "After rewrite dus->pad\n" << module->ToString(); + // Look for the following pattern: + // param(0) + // | + // constant(0) custom-call + // | | + // broadcast bitcast constant(0) + // | | / + // | | / + // | | / + // | | / + // dynamic-update-slice + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice( + m::Broadcast(m::ConstantScalar(0)), + m::Bitcast(m::CustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget}, + m::Parameter(0))), + m::ConstantScalar(0), m::ConstantScalar(0), m::ConstantScalar(0), + m::ConstantScalar(0), m::ConstantScalar(0)))); +} + // Test of dynamic-update-slice with dims where update and result have the same // size so we can replace indices to 0. TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceTrivialIndices) { @@ -7545,6 +7963,7 @@ TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { // This test is without layouts so we have to set the verifier to be layout // insensitive. verifier_layout_sensitive_ = false; + instruction_can_change_layout_func_ = {}; Shape shape = ShapeUtil::MakeShape(F32, {}); shape.clear_layout(); @@ -7897,6 +8316,78 @@ TEST_F(AlgebraicSimplifierTest, ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, CompareGtMaxA) { + // Gt(Max(a,b), a) -> Gt(b,a) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] maximum(a, b) + ROOT compare = pred[4] compare(m0, a), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(1), m::Parameter(0)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtMaxB) { + // Gt(Max(a,b), b) -> Gt(a,b) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] maximum(a, b) + ROOT compare = pred[4] compare(m0, b), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(0), m::Parameter(1)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtAMin) { + // Gt(a, Min(a,b)) -> Gt(a,b) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] minimum(a, b) + ROOT compare = pred[4] compare(a, m0), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(0), m::Parameter(1)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareGtBMin) { + // Gt(b, Min(a,b)) -> Gt(b,a) + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4] parameter(0) + b = f32[4] parameter(1) + m0 = f32[4] minimum(a, b) + ROOT compare = pred[4] compare(b, m0), direction=GT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(1), m::Parameter(0)) + .WithComparisonDirection(ComparisonDirection::kGt))); +} + TEST_F(AlgebraicSimplifierTest, CompareIota) { const char* kModuleStr = R"( HloModule m @@ -8162,6 +8653,30 @@ TEST_F(AlgebraicSimplifierTest, EqTrue2) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, CompareSelectCompare) { + // Causal mask suboptimal HLO simplification + // Ne(select(Ge(a, b), ones, zeros), zeros) -> Ge(a, b) + const char* kModuleStr = R"( + HloModule m + test { + a = s32[4,4] parameter(0) + b = s32[4,4] parameter(1) + %cmp0 = pred[4,4] compare(a, b), direction=GE + %c1 = f32[] constant(1) + %ones = f32[4,4] broadcast(f32[] %c1) + %c0 = f32[] constant(0) + %zeros = f32[4,4] broadcast(f32[] %c0) + %sel0 = f32[4,4] select(%cmp0, %ones, %zeros) + ROOT %cmp1 = pred[4,4] compare(%sel0, %zeros), direction=NE + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(0), m::Parameter(1)) + .WithComparisonDirection(ComparisonDirection::kGe))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply @@ -9925,13 +10440,12 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcast) { EXPECT_TRUE( RunHloPass(AlgebraicSimplifier(default_options_), m.get()).value()); SCOPED_TRACE(m->ToString()); - EXPECT_THAT( - m->entry_computation()->root_instruction(), - GmockMatch( - m::Broadcast(m::Parameter(0)) - .WithPredicate([](const HloInstruction* instr) { - return instr->dimensions() == std::vector({0, 3}); - }))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)) + .WithPredicate([](const HloInstruction* instr) { + return instr->dimensions() == + std::vector({0, 3}); + }))); } TEST_F(AlgebraicSimplifierTest, TransposeBitcastOfBroadcast) { @@ -9947,13 +10461,12 @@ TEST_F(AlgebraicSimplifierTest, TransposeBitcastOfBroadcast) { options.set_is_layout_sensitive(true); EXPECT_TRUE(RunHloPass(AlgebraicSimplifier(options), m.get()).value()); SCOPED_TRACE(m->ToString()); - EXPECT_THAT( - m->entry_computation()->root_instruction(), - GmockMatch( - m::Broadcast(m::Parameter(0)) - .WithPredicate([](const HloInstruction* instr) { - return instr->dimensions() == std::vector({0, 3}); - }))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)) + .WithPredicate([](const HloInstruction* instr) { + return instr->dimensions() == + std::vector({0, 3}); + }))); } TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastWithLayoutCheckSkipped) { @@ -10019,7 +10532,7 @@ TEST_F(AlgebraicSimplifierTest, DontSinkInstructionsInDSAsyncComputation) { dynamic_slice_sizes={1} ROOT %dynamic-slice-done = f32[1]{0} dynamic-slice-done(((f32[10]{0}, s32[]), f32[1]{0}, u32[]) - %dynamic-slice-start), dynamic_slice_sizes={1} + %dynamic-slice-start) } )"; @@ -10030,6 +10543,69 @@ TEST_F(AlgebraicSimplifierTest, DontSinkInstructionsInDSAsyncComputation) { EXPECT_FALSE(changed); } +TEST_F(AlgebraicSimplifierTest, NoOpSliceToDynamicOfPadToStatic) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[<=512] parameter(0) + c = (f32[512], s32[]) custom-call(p0), custom_call_target="PadToStatic" + gte0 = f32[512] get-tuple-element(c), index=0 + gte1 = s32[] get-tuple-element(c), index=1 + ROOT c2 = f32[<=512] custom-call(gte0, gte1), custom_call_target="SliceToDynamic" + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AlgebraicSimplifierTest, DiffShapeSliceToDynamicOfPadToStatic) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[<=512] parameter(0) + c = (f32[512], s32[]) custom-call(p0), custom_call_target="PadToStatic" + gte0 = f32[512] get-tuple-element(c), index=0 + gte1 = s32[] get-tuple-element(c), index=1 + ROOT c2 = f32[<=1024] custom-call(gte0, gte1), custom_call_target="SliceToDynamic" + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + +TEST_F(AlgebraicSimplifierTest, DiffShapeSliceToDynamicDifferentPadToStatic) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[<=512] parameter(0) + c = (f32[512], s32[]) custom-call(p0), custom_call_target="PadToStatic" + p1 = f32[<=512] parameter(1) + c1 = (f32[512], s32[]) custom-call(p1), custom_call_target="PadToStatic" + gte0 = f32[512] get-tuple-element(c), index=0 + gte1 = s32[] get-tuple-element(c1), index=1 + ROOT c2 = f32[<=512] custom-call(gte0, gte1), custom_call_target="SliceToDynamic" + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + +TEST_F(AlgebraicSimplifierTest, NotPadToStaticSizeDynamicDifferentPadToStatic) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[<=512] parameter(0) + c = (f32[512], s32[]) custom-call(p0), custom_call_target="PadToStatic" + gte0 = f32[512] get-tuple-element(c), index=0 + gte1 = s32[] parameter(1) + ROOT c2 = f32[<=512] custom-call(gte0, gte1), custom_call_target="SliceToDynamic" + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} class AlgebraicSimplifierUpcastDowncastTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< @@ -10150,5 +10726,127 @@ GetUpcastDowncastTestCases() { INSTANTIATE_TEST_SUITE_P(AllTypes, AlgebraicSimplifierUpcastDowncastTest, ::testing::ValuesIn(GetUpcastDowncastTestCases())); +template +auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { + return match::Op() + .WithOpcode(HloOpcode::kDot) + .WithOperand(0, std::forward(arg0)) + .WithOperand(1, std::forward(arg1)) + .WithOperand(2, std::forward(arg2)); +} + +TEST_F(AlgebraicSimplifierTest, SparseDotRemoveDegenerateDimensions) { + const char* kHlo = R"( + HloModule m + ENTRY test { + %lhs = f32[1,5,10,16,1] parameter(0) + %rhs = f32[5,1,20,1,32] parameter(1) + %meta = u16[1,5,10,2,1] parameter(2) + ROOT %dot = f32[1,5,10,20] dot(%lhs, %rhs, %meta), + lhs_batch_dims={0,1}, rhs_batch_dims={1,0}, + lhs_contracting_dims={3,4}, rhs_contracting_dims={4,3}, + sparsity=L.3@2:4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(module.get()).value()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Reshape(SparseDotMatcher(m::Reshape(m::Parameter(0)), + m::Reshape(m::Parameter(1)), + m::Reshape(m::Parameter(2))) + .WithShape(F32, {5, 10, 20})))); + auto dot = Cast(root->operand(0)); + auto descriptor = dot->sparsity().front(); + EXPECT_EQ(descriptor.index(), 0); + EXPECT_EQ(descriptor.dimension(), 2); +} + +TEST_F(AlgebraicSimplifierTest, SparseDotMoveSliceToOperands) { + const char* kHlo = R"( + HloModule m + ENTRY test { + %lhs = f32[7,12,16] parameter(0) + %rhs = f32[7,22,32] parameter(1) + %meta = u16[7,12,2] parameter(2) + %dot = f32[7,12,22] dot(%lhs, %rhs, %meta), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sparsity=L.2@2:4 + ROOT %slice = f32[5,10,20] slice(%dot), slice={[0:5], [0:10], [0:20]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + ASSERT_TRUE(AlgebraicSimplifier(options).Run(module.get()).value()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(SparseDotMatcher(m::Slice(m::Parameter(0)), + m::Slice(m::Parameter(1)), + m::Slice(m::Parameter(2))) + .WithShape(F32, {5, 10, 20}))); + auto dot = Cast(root); + auto descriptor = dot->sparsity().front(); + EXPECT_EQ(descriptor.index(), 0); + EXPECT_EQ(descriptor.dimension(), 2); +} + +TEST_F(AlgebraicSimplifierTest, SparseDotTranspose) { + const char* hlo_string = R"( + HloModule m + ENTRY test { + %lhs = f32[10,16] parameter(0) + %rhs = f32[32,20] parameter(1) + %meta = u16[10,2] parameter(2) + %dot = f32[10,20] dot(%lhs, %rhs, %meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + sparsity=L.1@2:4 + ROOT %transpose = f32[20,10] transpose(%dot), dimensions={1,0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(module.get()).value()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(SparseDotMatcher(m::Parameter(1), m::Parameter(0), + m::Parameter(2)) + .WithShape(F32, {20, 10}))); + auto dot = Cast(root); + auto descriptor = dot->sparsity().front(); + EXPECT_EQ(descriptor.index(), 1); + EXPECT_EQ(descriptor.dimension(), 1); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastToTranspose) { + const std::string hlo_string = R"( + HloModule broadcast_module + ENTRY %main { + input = f32[6,4,3] parameter(0) + ROOT output = f32[4,3,6] broadcast(input), dimensions={2,0,1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0)))); + EXPECT_EQ(root->dimensions(), std::vector({1, 2, 0})); +} + +TEST_F(AlgebraicSimplifierTest, BroadcastToTranspose2) { + const std::string hlo_string = R"( + HloModule broadcast_module + ENTRY %main { + input = f32[6,4,3] parameter(0) + ROOT output = f32[4,6,3] broadcast(input), dimensions={1,0,2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0)))); + EXPECT_EQ(root->dimensions(), std::vector({1, 0, 2})); +} + } // namespace } // namespace xla diff --git a/xla/service/algorithm_util.cc b/xla/service/algorithm_util.cc new file mode 100644 index 0000000000000..26f0e274cf04a --- /dev/null +++ b/xla/service/algorithm_util.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/algorithm_util.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace algorithm_util { + +namespace { +namespace se = stream_executor; +} // namespace + +absl::StatusOr GetBlasComputationType( + PrecisionConfig::Algorithm algorithm) { + // Note: If we will support other algorithm & storage type combinations, such + // as ALG_DOT_BF16_BF16_F32 with F32 input and output storage types, then + // we'll have to also depend on the storage types here. For the mentioned + // example, the computation type would be kBF16AsF32. + // Only the currently supported algorithms are listed here. + switch (algorithm) { + case PrecisionConfig::ALG_DOT_F16_F16_F16: + return se::blas::ComputationType::kF16; + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + case PrecisionConfig::ALG_DOT_F16_F16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_F32_F32_F32: + return se::blas::ComputationType::kF32; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + return se::blas::ComputationType::kTF32AsF32; + case PrecisionConfig::ALG_DOT_F64_F64_F64: + return se::blas::ComputationType::kF64; + default: + return absl::InternalError( + absl::StrFormat("GetBlasComputationType: unsupported algorithm %s", + xla::PrecisionConfig::Algorithm_Name(algorithm))); + } +} + +absl::StatusOr GetDotAccumulatorType( + PrecisionConfig::Algorithm algorithm) { + // All dot algorithms should be listed here. + switch (algorithm) { + case PrecisionConfig::ALG_DOT_F16_F16_F16: + return F16; + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + case PrecisionConfig::ALG_DOT_F16_F16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: + case PrecisionConfig::ALG_DOT_F32_F32_F32: + return F32; + case PrecisionConfig::ALG_DOT_BF16_BF16_BF16: + return BF16; + case PrecisionConfig::ALG_DOT_F64_F64_F64: + return F64; + case PrecisionConfig::ALG_UNSET: + default: + return absl::InternalError( + absl::StrFormat("GetDotAccumulatorType: unsupported algorithm %s", + xla::PrecisionConfig::Algorithm_Name(algorithm))); + } +} + +bool HasTf32InputType(PrecisionConfig::Algorithm algorithm) { + return algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32 || + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3; +} + +bool HasFastAccum(PrecisionConfig::Algorithm algorithm) { + return algorithm == PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM; +} + +// It's clear that those libraries could support more, but we only list the ones +// which we explicitly test for now. +bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm) { + switch (algorithm) { + case PrecisionConfig::ALG_UNSET: + case PrecisionConfig::ALG_DOT_F16_F16_F32: + case PrecisionConfig::ALG_DOT_F32_F32_F32: + case PrecisionConfig::ALG_DOT_F64_F64_F64: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + return true; + default: + return false; + } +} + +// Checks if we support the given algorithm using cuDNN. +bool IsSupportedByCudnn(PrecisionConfig::Algorithm algorithm) { + switch (algorithm) { + // When the CuDnn backend starts supporting specific algorithms, then + // those should be listed here. + case PrecisionConfig::ALG_UNSET: + return true; + default: + return false; + } +} + +bool IsSupportedByElementalIrEmitter(PrecisionConfig::Algorithm algorithm) { + switch (algorithm) { + // Probably more can be added. + case PrecisionConfig::ALG_DOT_F32_F32_F32: + case PrecisionConfig::ALG_UNSET: + return true; + default: + return false; + } +} + +// Is the given algorithm supported on GPU with the given compute capability and +// input/output storage types. +bool IsSupportedDotAlgorithmOnGpu( + PrecisionConfig::Algorithm algorithm, + stream_executor::GpuComputeCapability gpu_compute_capability, + PrimitiveType input_storage_type, PrimitiveType output_storage_type) { + // Note: We may want to add some complex types here if people request that. + const bool is_cuda_ge_ampere = + std::holds_alternative( + gpu_compute_capability) && + std::get(gpu_compute_capability) + .IsAtLeastAmpere(); + + const bool is_cuda_ge_ada = + std::holds_alternative( + gpu_compute_capability) && + std::get(gpu_compute_capability) + .IsAtLeast(8, 9); + + switch (algorithm) { + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: + case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: + // Other F8 types are actually not supported by NVIDIA GPUs. + return is_cuda_ge_ada && + (input_storage_type == F8E5M2 || input_storage_type == F8E4M3FN) && + (output_storage_type == F8E5M2 || + output_storage_type == F8E4M3FN || output_storage_type == F16 || + output_storage_type == BF16 || output_storage_type == F32); + case PrecisionConfig::ALG_DOT_F16_F16_F32: + return input_storage_type == F16 && + (output_storage_type == F16 || output_storage_type == F32); + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + return is_cuda_ge_ampere && input_storage_type == BF16 && + (output_storage_type == BF16 || output_storage_type == F32); + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + return is_cuda_ge_ampere && input_storage_type == F32 && + output_storage_type == F32; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + return is_cuda_ge_ampere && input_storage_type == F32 && + output_storage_type == F32; + case PrecisionConfig::ALG_DOT_F32_F32_F32: + return input_storage_type == F32 && output_storage_type == F32; + case PrecisionConfig::ALG_DOT_F64_F64_F64: + return input_storage_type == F64 && output_storage_type == F64; + default: + return false; + } +} + +} // namespace algorithm_util +} // namespace xla diff --git a/xla/service/algorithm_util.h b/xla/service/algorithm_util.h new file mode 100644 index 0000000000000..9ce28f552dd5a --- /dev/null +++ b/xla/service/algorithm_util.h @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_ALGORITHM_UTIL_H_ +#define XLA_SERVICE_ALGORITHM_UTIL_H_ + +#include "absl/status/statusor.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// We try to keep most algorithm-specific queries in this file, so that we only +// have to update one file when we add a new one. +// We can also add some platform-specific queries as long as we don't need to +// depend on specific targets, such as the "gpu" folder. +namespace algorithm_util { + +// Get the ComputationType corresponding to an algorithm. See the +// ComputationType definition for more info. +absl::StatusOr GetBlasComputationType( + PrecisionConfig::Algorithm algorithm); + +// Get the accumulator type of an algorithm. +absl::StatusOr GetDotAccumulatorType( + PrecisionConfig::Algorithm algorithm); + +// Are the AType & BType TF32? +bool HasTf32InputType(PrecisionConfig::Algorithm algorithm); + +// Checks if the algorithm uses fast accumulation as in +// CUBLASLT_MATMUL_DESC_FAST_ACCUM. +bool HasFastAccum(PrecisionConfig::Algorithm algorithm); + +// Checks if we support the given algorithm using cuBLAS or cuBLASLt. +// +// It's clear that those libraries could support more, but we only list the ones +// which we explicitly test for now. +// +// We may want to also check storage types, but for now those are checked in +// IsSupportedDotAlgorithmOnGpu. +bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm); + +// Checks if we support the given algorithm using cuDNN. +bool IsSupportedByCudnn(PrecisionConfig::Algorithm algorithm); + +// Checks if we support the given algorithm using the elemental IR emitter. +bool IsSupportedByElementalIrEmitter(PrecisionConfig::Algorithm algorithm); + +// Is the given algorithm supported on GPU with the given compute capability and +// input/output storage types. +bool IsSupportedDotAlgorithmOnGpu( + PrecisionConfig::Algorithm algorithm, + stream_executor::GpuComputeCapability gpu_compute_capability, + PrimitiveType input_storage_type, PrimitiveType output_storage_type); + +} // namespace algorithm_util +} // namespace xla + +#endif // XLA_SERVICE_ALGORITHM_UTIL_H_ diff --git a/xla/service/all_gather_broadcast_reorder.cc b/xla/service/all_gather_broadcast_reorder.cc index 6345d701b8e77..31a72c2d82750 100644 --- a/xla/service/all_gather_broadcast_reorder.cc +++ b/xla/service/all_gather_broadcast_reorder.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ limitations under the License. namespace xla { -StatusOr AllGatherBroadcastReorder::Run( +absl::StatusOr AllGatherBroadcastReorder::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { if (hlo_query::ContainsLayoutConstrainedCollective(*module, diff --git a/xla/service/all_gather_broadcast_reorder.h b/xla/service/all_gather_broadcast_reorder.h index 57b63518e2e81..5746f2424fd95 100644 --- a/xla/service/all_gather_broadcast_reorder.h +++ b/xla/service/all_gather_broadcast_reorder.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class AllGatherBroadcastReorder : public HloModulePass { absl::string_view name() const override { return "all-gather-bcast-reorder"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/all_gather_broadcast_reorder_test.cc b/xla/service/all_gather_broadcast_reorder_test.cc index cb69f2edebd17..0c7eb62232d13 100644 --- a/xla/service/all_gather_broadcast_reorder_test.cc +++ b/xla/service/all_gather_broadcast_reorder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_gather_combiner.cc b/xla/service/all_gather_combiner.cc index b42fb21a8554b..ecb9fb42474d6 100644 --- a/xla/service/all_gather_combiner.cc +++ b/xla/service/all_gather_combiner.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -47,22 +48,26 @@ limitations under the License. namespace xla { namespace { +// Returns the most frequent all-gather dim if it can be a valid gather dim +// for all shapes involved, else returns 0. int64_t FindMostFrequentGatherDim( absl::Span to_combine) { assert(!to_combine.empty()); // Count frequencies. + int64_t min_rank = std::numeric_limits::max(); std::vector frequency; for (const HloInstruction* it : to_combine) { int64_t dim = Cast(it)->all_gather_dimension(); frequency.resize(std::max(dim + 1, static_cast(frequency.size())), 0); frequency[dim]++; + min_rank = std::min(min_rank, it->shape().rank()); } int64_t most_frequent_dim = std::distance( frequency.begin(), std::max_element(frequency.begin(), frequency.end())); - return most_frequent_dim; + return most_frequent_dim < min_rank ? most_frequent_dim : 0; } // Combines the elements of to_combine into a single AllGather op. All entries @@ -193,7 +198,7 @@ AllGatherCombiner::AllGatherCombiner(int64_t combine_threshold_in_bytes, combine_threshold_count_(combine_threshold_count), combine_by_dim_(combine_by_dim) {} -StatusOr AllGatherCombiner::Run( +absl::StatusOr AllGatherCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllGatherCombiner with threshold of " diff --git a/xla/service/all_gather_combiner.h b/xla/service/all_gather_combiner.h index 4208cce37f828..8e7a0e062799c 100644 --- a/xla/service/all_gather_combiner.h +++ b/xla/service/all_gather_combiner.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class AllGatherCombiner : public HloModulePass { absl::string_view name() const override { return "all-gather-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_gather_combiner_test.cc b/xla/service/all_gather_combiner_test.cc index 20b3e57ff8394..4635e3a1bedd2 100644 --- a/xla/service/all_gather_combiner_test.cc +++ b/xla/service/all_gather_combiner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -312,7 +312,7 @@ ENTRY entry { op::Sharding("{{maximal device=0}, {maximal device=0}}")); } -TEST_F(AllGatherCombinerTest, CombineAllGathersIrrespectiveOfDim) { +TEST_F(AllGatherCombinerTest, CombineAllGathersDifferentDims) { const char* const hlo_string = R"( HloModule Module @@ -343,7 +343,7 @@ ENTRY entry { op::Bitcast(op::GetTupleElement(combined_all_gather, 1)))); } -TEST_F(AllGatherCombinerTest, CombineManyAllGathersIrrespectiveOfDim) { +TEST_F(AllGatherCombinerTest, CombineManyAllGathersDifferentDims) { const char* const hlo_string = R"( HloModule Module @@ -392,7 +392,7 @@ ENTRY entry { ASSERT_EQ(0, all_gathers.front()->all_gather_dimension()); } -TEST_F(AllGatherCombinerTest, CombineManyAllGathersIrrespectiveOfDimRank4) { +TEST_F(AllGatherCombinerTest, CombineManyAllGathersDifferentDimsRank4) { const char* const hlo_string = R"( HloModule Module @@ -442,6 +442,59 @@ ENTRY entry { ASSERT_EQ(0, all_gathers.front()->all_gather_dimension()); } +TEST_F(AllGatherCombinerTest, CombineManyAllGathersDifferentDimsMixedRanks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + param0 = f32[2,7]{1,0} parameter(0) + param1 = f32[3,8]{1,0} parameter(1) + param2 = f32[4,9]{0,1} parameter(2) + param3 = f32[5,10]{0,1} parameter(3) + param4 = f32[6]{0} parameter(4) + allgather0 = f32[2,28]{1,0} all-gather(param0), replica_groups={}, + dimensions={1} + allgather1 = f32[3,32]{1,0} all-gather(param1), replica_groups={}, + dimensions={1} + allgather2 = f32[4,36]{0,1} all-gather(param2), replica_groups={}, + dimensions={1} + allgather3 = f32[5,40]{0,1} all-gather(param3), replica_groups={}, + dimensions={1} + allgather4 = f32[24]{0} all-gather(param4), replica_groups={}, + dimensions={0} + ROOT tuple = (f32[2,28]{1,0}, f32[3,32]{1,0}, f32[4,36]{0,1}, f32[5,40]{0,1}, + f32[24]{0}) tuple(allgather0, allgather1, allgather2, allgather3, + allgather4) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllGatherCombiner combine(1024 * 1024, kMaxCombineCount, + /*combine_by_dim=*/false); + ASSERT_EQ(AllGatherCount(*module), 5); + TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get())); + EXPECT_TRUE(changed); + + Matcher combined_all_gather = op::AllGather( + op::Bitcast(op::Parameter(0)), op::Bitcast(op::Parameter(1)), + op::Bitcast(op::Parameter(2)), op::Bitcast(op::Parameter(3)), + op::Parameter(4)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Bitcast(op::GetTupleElement(combined_all_gather, 0)), + op::Bitcast(op::GetTupleElement(combined_all_gather, 1)), + op::Bitcast(op::GetTupleElement(combined_all_gather, 2)), + op::Bitcast(op::GetTupleElement(combined_all_gather, 3)), + op::GetTupleElement(combined_all_gather, 4))); + std::vector all_gathers = FindAllGathers(*module); + ASSERT_EQ(1, all_gathers.size()); + + // when using different ranks and the most frequent AG dim (1) is not valid + // for rank 1 shape, we use default dim 0. + ASSERT_EQ(0, all_gathers.front()->all_gather_dimension()); +} + TEST_F(AllGatherCombinerTest, CombineAllGathersByDim) { const char* const hlo_string = R"( HloModule Module diff --git a/xla/service/all_gather_decomposer.cc b/xla/service/all_gather_decomposer.cc index 7e5325a52a595..24db8fd359947 100644 --- a/xla/service/all_gather_decomposer.cc +++ b/xla/service/all_gather_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -101,7 +101,7 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) { return OkStatus(); } -StatusOr AllGatherDecomposer::Run( +absl::StatusOr AllGatherDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/all_gather_decomposer.h b/xla/service/all_gather_decomposer.h index e08c638385eaa..da56d0c402303 100644 --- a/xla/service/all_gather_decomposer.h +++ b/xla/service/all_gather_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class AllGatherDecomposer : public HloModulePass { // Run AllGatherDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_gather_decomposer_test.cc b/xla/service/all_gather_decomposer_test.cc index 4eb629a347311..ccb2a7cadd505 100644 --- a/xla/service/all_gather_decomposer_test.cc +++ b/xla/service/all_gather_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_combiner.cc b/xla/service/all_reduce_combiner.cc index 1f2dbf0dee978..5d7b9b6ee4c5b 100644 --- a/xla/service/all_reduce_combiner.cc +++ b/xla/service/all_reduce_combiner.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -108,7 +108,7 @@ AllReduceCombiner::AllReduceCombiner(int64_t combine_threshold_in_bytes, : combine_threshold_in_bytes_(combine_threshold_in_bytes), combine_threshold_count_(combine_threshold_count) {} -StatusOr AllReduceCombiner::Run( +absl::StatusOr AllReduceCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllReduceCombiner with threshold of " diff --git a/xla/service/all_reduce_combiner.h b/xla/service/all_reduce_combiner.h index abc7cdd2e1499..4ef9e96125825 100644 --- a/xla/service/all_reduce_combiner.h +++ b/xla/service/all_reduce_combiner.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class AllReduceCombiner : public HloModulePass { absl::string_view name() const override { return "all-reduce-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_reduce_combiner_test.cc b/xla/service/all_reduce_combiner_test.cc index 5c8e0769294c7..33399c2bae720 100644 --- a/xla/service/all_reduce_combiner_test.cc +++ b/xla/service/all_reduce_combiner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_contiguous.cc b/xla/service/all_reduce_contiguous.cc index 9a5d35ec12e8e..7f07b2fa756df 100644 --- a/xla/service/all_reduce_contiguous.cc +++ b/xla/service/all_reduce_contiguous.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ Status ReplaceWithContiguousAllReduce(HloAllReduceInstruction* all_reduce) { } } // namespace -StatusOr AllReduceContiguous::Run( +absl::StatusOr AllReduceContiguous::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllReduceContiguous"; diff --git a/xla/service/all_reduce_contiguous.h b/xla/service/all_reduce_contiguous.h index dd89e97041dd1..d81582536fba4 100644 --- a/xla/service/all_reduce_contiguous.h +++ b/xla/service/all_reduce_contiguous.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ class AllReduceContiguous : public HloModulePass { absl::string_view name() const override { return "all-reduce-contiguous"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/all_reduce_contiguous_test.cc b/xla/service/all_reduce_contiguous_test.cc index aef70985eb661..ccd1effdbc6c3 100644 --- a/xla/service/all_reduce_contiguous_test.cc +++ b/xla/service/all_reduce_contiguous_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_folder.cc b/xla/service/all_reduce_folder.cc index 6870a0cf83beb..9d034dd45ef60 100644 --- a/xla/service/all_reduce_folder.cc +++ b/xla/service/all_reduce_folder.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -136,7 +136,7 @@ std::optional> FoldReplicaGroups( } // namespace -StatusOr AllReduceFolder::Run( +absl::StatusOr AllReduceFolder::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { diff --git a/xla/service/all_reduce_folder.h b/xla/service/all_reduce_folder.h index c603cc42eb742..e175a65677163 100644 --- a/xla/service/all_reduce_folder.h +++ b/xla/service/all_reduce_folder.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class AllReduceFolder : public HloModulePass { absl::string_view name() const override { return "all-reduce-folder"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/all_reduce_folder_test.cc b/xla/service/all_reduce_folder_test.cc index a30b0b51e56b2..57d2c7518838d 100644 --- a/xla/service/all_reduce_folder_test.cc +++ b/xla/service/all_reduce_folder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,15 +30,15 @@ using ::testing::HasSubstr; class AllReduceFolderTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module, - bool expect_change) { + absl::StatusOr> RunPass( + absl::string_view hlo_module, bool expect_change) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); auto changed = AllReduceFolder().Run(module.get()); if (!changed.ok()) { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t AllReduceCount(std::unique_ptr &module) { diff --git a/xla/service/all_reduce_key.cc b/xla/service/all_reduce_key.cc index 34b33a2ae4120..ddc0ac79a41aa 100644 --- a/xla/service/all_reduce_key.cc +++ b/xla/service/all_reduce_key.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_key.h b/xla/service/all_reduce_key.h index 126b182499af0..ae560edecce82 100644 --- a/xla/service/all_reduce_key.h +++ b/xla/service/all_reduce_key.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_promotion.cc b/xla/service/all_reduce_promotion.cc index 30965128a8152..b0328759c7d31 100644 --- a/xla/service/all_reduce_promotion.cc +++ b/xla/service/all_reduce_promotion.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ AllReducePromotion::AllReducePromotion( absl::Span const> from_to_types) : pass_(from_to_types, IsAllReduce, CloneAllReduce) {} -StatusOr AllReducePromotion::Run( +absl::StatusOr AllReducePromotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return pass_.Run(module, execution_threads); diff --git a/xla/service/all_reduce_promotion.h b/xla/service/all_reduce_promotion.h index 83a3d3facad3e..a1ad33033187f 100644 --- a/xla/service/all_reduce_promotion.h +++ b/xla/service/all_reduce_promotion.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ class AllReducePromotion : public HloModulePass { absl::string_view name() const override { return "all-reduce-promotion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_reduce_promotion_test.cc b/xla/service/all_reduce_promotion_test.cc index 08ee9e4f582da..380c1c3cf8e24 100644 --- a/xla/service/all_reduce_promotion_test.cc +++ b/xla/service/all_reduce_promotion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_reduce_reassociate.cc b/xla/service/all_reduce_reassociate.cc index dc7e7990e383e..84d83b0b736c6 100644 --- a/xla/service/all_reduce_reassociate.cc +++ b/xla/service/all_reduce_reassociate.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -171,7 +171,7 @@ bool MatchOperandsToAllReduceWithOptionalConvert(HloInstruction* inst, } } // namespace -StatusOr AllReduceReassociate::Run( +absl::StatusOr AllReduceReassociate::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { diff --git a/xla/service/all_reduce_reassociate.h b/xla/service/all_reduce_reassociate.h index 7f030c3fd1b2e..228d2f5cd15b5 100644 --- a/xla/service/all_reduce_reassociate.h +++ b/xla/service/all_reduce_reassociate.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class AllReduceReassociate : public HloModulePass { absl::string_view name() const override { return "all-reduce-reassociate"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_reduce_reassociate_test.cc b/xla/service/all_reduce_reassociate_test.cc index 058909ade40b5..aa1f13eaf04a7 100644 --- a/xla/service/all_reduce_reassociate_test.cc +++ b/xla/service/all_reduce_reassociate_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ using ::testing::_; class AllReduceSimplifierTest : public HloTestBase { public: - StatusOr> RunPass( + absl::StatusOr> RunPass( absl::string_view hlo_module, bool expect_change, bool reassociate_converted_ar = false) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); @@ -46,7 +46,7 @@ class AllReduceSimplifierTest : public HloTestBase { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t AllReduceCount(std::unique_ptr& module) { diff --git a/xla/service/all_reduce_simplifier.cc b/xla/service/all_reduce_simplifier.cc index 235ea2dcd41be..67aadf41b9e98 100644 --- a/xla/service/all_reduce_simplifier.cc +++ b/xla/service/all_reduce_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ limitations under the License. namespace xla { -StatusOr AllReduceSimplifier::Run( +absl::StatusOr AllReduceSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN( diff --git a/xla/service/all_reduce_simplifier.h b/xla/service/all_reduce_simplifier.h index 702fca84f06fa..72bc60923dc3b 100644 --- a/xla/service/all_reduce_simplifier.h +++ b/xla/service/all_reduce_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class AllReduceSimplifier : public HloModulePass { // Run all-reduce simplification on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/all_reduce_simplifier_test.cc b/xla/service/all_reduce_simplifier_test.cc index 788c7da35910a..0843fc6df1a87 100644 --- a/xla/service/all_reduce_simplifier_test.cc +++ b/xla/service/all_reduce_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/all_to_all_decomposer.cc b/xla/service/all_to_all_decomposer.cc index 7042dd500c8c7..241b242f693fc 100644 --- a/xla/service/all_to_all_decomposer.cc +++ b/xla/service/all_to_all_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ bool AllToAllDecomposer::InstructionMatchesPattern( } return all_to_all->shape().rank() < min_array_rank_; } -StatusOr AllToAllDecomposer::ExpandInstruction( +absl::StatusOr AllToAllDecomposer::ExpandInstruction( HloInstruction* instruction) { auto* all_to_all = Cast(instruction); int64_t split_dim = *all_to_all->split_dimension(); diff --git a/xla/service/all_to_all_decomposer.h b/xla/service/all_to_all_decomposer.h index 33313cadf25de..3ef1891a41266 100644 --- a/xla/service/all_to_all_decomposer.h +++ b/xla/service/all_to_all_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ class AllToAllDecomposer : public OpExpanderPass { private: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; bool decompose_to_tuple_; int64_t min_array_rank_; diff --git a/xla/service/allocation_tracker.cc b/xla/service/allocation_tracker.cc index c628e4053e129..a2cbad64b7596 100644 --- a/xla/service/allocation_tracker.cc +++ b/xla/service/allocation_tracker.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ limitations under the License. namespace xla { -StatusOr AllocationTracker::Register( +absl::StatusOr AllocationTracker::Register( ScopedShapedBuffer shaped_buffer, const std::string& tag) { absl::MutexLock lock(&mutex_); VLOG(2) << "Register"; @@ -40,7 +40,7 @@ StatusOr AllocationTracker::Register( return RegisterInternal(std::move(replicated_buffers), tag); } -StatusOr AllocationTracker::RegisterReplicatedBuffers( +absl::StatusOr AllocationTracker::RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag) { absl::MutexLock lock(&mutex_); @@ -57,7 +57,7 @@ static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) { } template -StatusOr AllocationTracker::RegisterInternal( +absl::StatusOr AllocationTracker::RegisterInternal( std::vector replicated_buffers, const std::string& tag) { static_assert(std::is_same::value || std::is_same::value, @@ -126,8 +126,8 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { return OkStatus(); } -StatusOr> AllocationTracker::DeconstructTuple( - const GlobalDataHandle& data) { +absl::StatusOr> +AllocationTracker::DeconstructTuple(const GlobalDataHandle& data) { absl::MutexLock lock(&mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, @@ -164,13 +164,13 @@ StatusOr> AllocationTracker::DeconstructTuple( return std::move(element_handles); } -StatusOr> AllocationTracker::Resolve( +absl::StatusOr> AllocationTracker::Resolve( const GlobalDataHandle& data) const { absl::MutexLock lock(&mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveForReplica( +absl::StatusOr AllocationTracker::ResolveForReplica( const GlobalDataHandle& data, int replica_id) const { absl::MutexLock lock(&mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, @@ -184,8 +184,8 @@ StatusOr AllocationTracker::ResolveForReplica( return replicated_buffers[replica_id]; } -StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) const { +absl::StatusOr> +AllocationTracker::ResolveInternal(const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { diff --git a/xla/service/allocation_tracker.h b/xla/service/allocation_tracker.h index efbbdd68240c5..c8359c0fc3097 100644 --- a/xla/service/allocation_tracker.h +++ b/xla/service/allocation_tracker.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,12 +42,12 @@ class AllocationTracker { // Registers a shaped buffer of device memory, and returns a corresponding // handle that can be used for talking to XLA clients. The given shaped buffer // will be treated as the buffer corresponding to the only replica. - StatusOr Register(ScopedShapedBuffer shaped_buffer, - const std::string& tag); + absl::StatusOr Register(ScopedShapedBuffer shaped_buffer, + const std::string& tag); // Registers a vector of shaped buffers of device memory, one per replica, and // returns a corresponding handle that can be used for talking to XLA clients. - StatusOr RegisterReplicatedBuffers( + absl::StatusOr RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag); @@ -55,20 +55,20 @@ class AllocationTracker { Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. - StatusOr> DeconstructTuple( + absl::StatusOr> DeconstructTuple( const GlobalDataHandle& Data); // Resolve a handle from an XLA client to a vector of shaped buffers, one per // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). - StatusOr> Resolve( + absl::StatusOr> Resolve( const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). - StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id) const; + absl::StatusOr ResolveForReplica( + const GlobalDataHandle& data, int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -83,7 +83,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. - StatusOr> ResolveInternal( + absl::StatusOr> ResolveInternal( const GlobalDataHandle& data) const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per @@ -91,7 +91,7 @@ class AllocationTracker { // it's ShapedBuffer, all of the given buffers must already be tracked by this // object -- presumably this is a call from DeconstructTuple. template - StatusOr RegisterInternal( + absl::StatusOr RegisterInternal( std::vector replicated_buffers, const std::string& tag) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); diff --git a/xla/service/ar_crs_combiner.cc b/xla/service/ar_crs_combiner.cc index a14c73d95e09f..1ad71ca8d7e56 100644 --- a/xla/service/ar_crs_combiner.cc +++ b/xla/service/ar_crs_combiner.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,8 +42,8 @@ namespace { // divide by the number of partitions. Depending on the topology and the // implementation of the all-reduce for the backend, this may give a better // performance. -StatusOr ReplaceReplicatedAllReduce(HloModule* module, - int64_t partition_count) { +absl::StatusOr ReplaceReplicatedAllReduce(HloModule* module, + int64_t partition_count) { TF_ASSIGN_OR_RETURN( auto replication_analysis, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); @@ -534,7 +534,7 @@ Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( return OkStatus(); } -StatusOr ArCrsCombiner::RewriteGraph() { +absl::StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } @@ -600,7 +600,7 @@ StatusOr ArCrsCombiner::RewriteGraph() { return true; } -StatusOr ArCrsCombiner::Run( +absl::StatusOr ArCrsCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { call_graph_ = CallGraph::Build(module); diff --git a/xla/service/ar_crs_combiner.h b/xla/service/ar_crs_combiner.h index 64df84bd1ba76..7b537b7dd8742 100644 --- a/xla/service/ar_crs_combiner.h +++ b/xla/service/ar_crs_combiner.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -77,7 +77,7 @@ class ArCrsCombiner : public HloModulePass { spmd_partition_(spmd_partition) {} absl::string_view name() const override { return "ar-crs-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -161,7 +161,7 @@ class ArCrsCombiner : public HloModulePass { // Performs the graph rewrite that eliminates the early AllReduce and turns // the later CRS into an AllReduce. - StatusOr RewriteGraph(); + absl::StatusOr RewriteGraph(); int num_spatial_partitions_; diff --git a/xla/service/ar_crs_combiner_test.cc b/xla/service/ar_crs_combiner_test.cc index 87c66c9179f0c..119f0f41b02fe 100644 --- a/xla/service/ar_crs_combiner_test.cc +++ b/xla/service/ar_crs_combiner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/async_collective_creator.cc b/xla/service/async_collective_creator.cc index e2641d3b0e039..2903a87bede47 100644 --- a/xla/service/async_collective_creator.cc +++ b/xla/service/async_collective_creator.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,8 +24,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/shape_inference.h" +#include "xla/util.h" #include "tsl/platform/errors.h" namespace xla { @@ -36,7 +38,8 @@ struct ReplacedAsync { HloInstruction* done; }; -StatusOr CreateAsyncAllReduce(HloInstruction* instruction) { +absl::StatusOr CreateAsyncAllReduce( + HloInstruction* instruction) { HloComputation* computation = instruction->parent(); auto* ar = Cast(instruction); HloInstruction* start = @@ -50,7 +53,8 @@ StatusOr CreateAsyncAllReduce(HloInstruction* instruction) { return ReplacedAsync{start, done}; } -StatusOr CreateAsyncAllGather(HloInstruction* instruction) { +absl::StatusOr CreateAsyncAllGather( + HloInstruction* instruction) { HloComputation* computation = instruction->parent(); auto* ag = Cast(instruction); std::vector operand_shapes; @@ -74,7 +78,7 @@ StatusOr CreateAsyncAllGather(HloInstruction* instruction) { return ReplacedAsync{start, done}; } -StatusOr CreateAsyncCollectivePermute( +absl::StatusOr CreateAsyncCollectivePermute( HloInstruction* instruction, absl::Span context_shapes) { HloComputation* computation = instruction->parent(); auto* cp = Cast(instruction); @@ -111,7 +115,7 @@ StatusOr CreateAsyncCollectivePermute( return ReplacedAsync{start, done}; } -StatusOr CreateAsyncStartDone( +absl::StatusOr CreateAsyncStartDone( HloInstruction* instruction, absl::Span context_shapes) { HloComputation* computation = instruction->parent(); TF_ASSIGN_OR_RETURN( @@ -125,97 +129,113 @@ StatusOr CreateAsyncStartDone( } // namespace -StatusOr AsyncCollectiveCreator::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { +// Find all supported collective ops first as we can't modify the instructions +// while iterating through them. +std::vector AsyncCollectiveCreator::MatchCollectives( + HloComputation* computation) { + std::vector supported_collectives; + for (HloInstruction* instruction : computation->instructions()) { + const HloOpcode op = instruction->opcode(); + if ((op == HloOpcode::kAllReduce && + config_.convert_all_reduce(instruction)) || + (op == HloOpcode::kAllGather && + config_.convert_all_gather(instruction)) || + (op == HloOpcode::kCollectiveBroadcast && + config_.convert_collective_broadcast(instruction)) || + (op == HloOpcode::kCollectivePermute && + config_.convert_collective_permute(instruction)) || + (op == HloOpcode::kAllToAll && + config_.convert_all_to_all(instruction)) || + (op == HloOpcode::kReduceScatter && + config_.convert_reduce_scatter(instruction))) { + supported_collectives.push_back(instruction); + } + } + return supported_collectives; +} + +absl::StatusOr AsyncCollectiveCreator::ReplaceCollectives( + HloComputation* computation, + std::vector& supported_collectives) { bool changed = false; - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - // Find all supported collective ops first as we can't modify the - // instructions while iterating through them. - std::vector supported_collectives; - for (HloInstruction* instruction : computation->instructions()) { - const HloOpcode op = instruction->opcode(); - if ((op == HloOpcode::kAllReduce && - config_.convert_all_reduce(instruction)) || - (op == HloOpcode::kAllGather && - config_.convert_all_gather(instruction)) || - (op == HloOpcode::kCollectivePermute && - config_.convert_collective_permute(instruction)) || - (op == HloOpcode::kAllToAll && - config_.convert_all_to_all(instruction)) || - (op == HloOpcode::kReduceScatter && - config_.convert_reduce_scatter(instruction))) { - supported_collectives.push_back(instruction); - } + HloModule* module = computation->parent(); + absl::flat_hash_map replaced_pairs; + const bool should_update_schedule = + module->has_schedule() && + module->schedule().is_computation_scheduled(computation); + for (HloInstruction* instruction : supported_collectives) { + absl::StatusOr async_pair; + switch (instruction->opcode()) { + case HloOpcode::kAllReduce: + async_pair = CreateAsyncAllReduce(instruction); + break; + case HloOpcode::kAllGather: + async_pair = CreateAsyncAllGather(instruction); + break; + case HloOpcode::kCollectivePermute: + async_pair = CreateAsyncCollectivePermute( + instruction, config_.get_context_shapes(instruction)); + break; + case HloOpcode::kCollectiveBroadcast: + case HloOpcode::kAllToAll: + case HloOpcode::kReduceScatter: + async_pair = CreateAsyncStartDone( + instruction, config_.get_context_shapes(instruction)); + break; + default: + return Internal("Unexpected opcode %s", + HloOpcodeString(instruction->opcode())); } - if (supported_collectives.empty()) { - continue; + TF_RETURN_IF_ERROR(async_pair.status()); + async_pair->start->set_metadata(instruction->metadata()); + async_pair->start->CopyBackendConfigFrom(instruction); + if (should_update_schedule) { + replaced_pairs[instruction] = *async_pair; } - absl::flat_hash_map replaced_pairs; - const bool should_update_schedule = - module->has_schedule() && - module->schedule().is_computation_scheduled(computation); - for (HloInstruction* instruction : supported_collectives) { - StatusOr async_pair; - switch (instruction->opcode()) { - case HloOpcode::kAllReduce: - async_pair = CreateAsyncAllReduce(instruction); - break; - case HloOpcode::kAllGather: - async_pair = CreateAsyncAllGather(instruction); - break; - case HloOpcode::kCollectivePermute: - async_pair = CreateAsyncCollectivePermute( - instruction, config_.get_context_shapes(instruction)); - break; - case HloOpcode::kAllToAll: - case HloOpcode::kReduceScatter: - async_pair = CreateAsyncStartDone( - instruction, config_.get_context_shapes(instruction)); - break; - default: - return InternalError("Unexpected opcode %s", - HloOpcodeString(instruction->opcode())); - } - TF_RETURN_IF_ERROR(async_pair.status()); - async_pair->start->set_metadata(instruction->metadata()); - async_pair->start->CopyBackendConfigFrom(instruction); - if (should_update_schedule) { - replaced_pairs[instruction] = *async_pair; - } + // Update control dependencies if present. + TF_RETURN_IF_ERROR( + instruction->CopyAllControlDepsTo(async_pair->start, async_pair->done)); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); - // Update control dependencies if present. - for (HloInstruction* pred : instruction->control_predecessors()) { - TF_RETURN_IF_ERROR(pred->AddControlDependencyTo(async_pair->start)); - } - for (HloInstruction* succ : instruction->control_successors()) { - TF_RETURN_IF_ERROR(async_pair->done->AddControlDependencyTo(succ)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + computation->ReplaceInstruction(instruction, async_pair->done), + "replacing ", instruction->ToShortString()); + changed = true; + } + if (should_update_schedule) { + std::vector new_sequence; + const HloInstructionSequence& sequence = + module->schedule().sequence(computation); + new_sequence.reserve(sequence.size() + replaced_pairs.size()); + for (HloInstruction* instr : sequence.instructions()) { + auto it = replaced_pairs.find(instr); + if (it != replaced_pairs.end()) { + new_sequence.push_back(it->second.start); + new_sequence.push_back(it->second.done); + continue; } - TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); - - TF_RETURN_WITH_CONTEXT_IF_ERROR( - computation->ReplaceInstruction(instruction, async_pair->done), - "replacing ", instruction->ToShortString()); - changed = true; + new_sequence.push_back(instr); } - if (should_update_schedule) { - std::vector new_sequence; - const HloInstructionSequence& sequence = - module->schedule().sequence(computation); - new_sequence.reserve(sequence.size() + replaced_pairs.size()); - for (HloInstruction* instr : sequence.instructions()) { - auto it = replaced_pairs.find(instr); - if (it != replaced_pairs.end()) { - new_sequence.push_back(it->second.start); - new_sequence.push_back(it->second.done); - continue; - } - new_sequence.push_back(instr); - } - module->schedule().set_sequence(computation, new_sequence); + module->schedule().set_sequence(computation, new_sequence); + } + return changed; +} + +absl::StatusOr AsyncCollectiveCreator::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + std::vector supported_collectives = + MatchCollectives(computation); + if (supported_collectives.empty()) { + continue; } + TF_ASSIGN_OR_RETURN(bool comp_changed, + ReplaceCollectives(computation, supported_collectives)); + changed |= comp_changed; } return changed; } diff --git a/xla/service/async_collective_creator.h b/xla/service/async_collective_creator.h index 25ff6931905aa..077a934e5dc11 100644 --- a/xla/service/async_collective_creator.h +++ b/xla/service/async_collective_creator.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,14 +31,15 @@ class AsyncCollectiveCreator : public HloModulePass { // Function to query the shape of the "context" for collectives that use // HLO async-start/async-done. using ContextShapeQuery = - std::function(const HloInstruction*)>; + std::function(const HloInstruction *)>; struct CollectiveCreatorConfig { HloPredicate convert_all_reduce = HloPredicateFalse; HloPredicate convert_all_gather = HloPredicateFalse; + HloPredicate convert_collective_broadcast = HloPredicateFalse; HloPredicate convert_collective_permute = HloPredicateFalse; HloPredicate convert_all_to_all = HloPredicateFalse; HloPredicate convert_reduce_scatter = HloPredicateFalse; - ContextShapeQuery get_context_shapes = [](const HloInstruction*) { + ContextShapeQuery get_context_shapes = [](const HloInstruction *) { return std::vector{}; }; }; @@ -47,9 +48,15 @@ class AsyncCollectiveCreator : public HloModulePass { absl::string_view name() const override { return "async-collective-creator"; } using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; + absl::StatusOr Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) override; + + std::vector MatchCollectives(HloComputation *computation); + absl::StatusOr ReplaceCollectives( + HloComputation *computation, + std::vector &supported_collectives); + const CollectiveCreatorConfig *config() const { return &config_; } private: CollectiveCreatorConfig config_; diff --git a/xla/service/async_collective_creator_test.cc b/xla/service/async_collective_creator_test.cc index 63eae4cecc876..8c9b574003da9 100644 --- a/xla/service/async_collective_creator_test.cc +++ b/xla/service/async_collective_creator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -177,6 +177,33 @@ TEST_F(AsyncAllReduceCreatorTest, SplitsSingleCollectivePermuteScheduled) { original_instr_sequence_size + 1); } +TEST_F(AsyncAllReduceCreatorTest, SplitsSingleCollectiveBroadcast) { + constexpr absl::string_view hlo_string = R"( + HloModule test + ENTRY entry { + p0 = f32[8,16] parameter(0) + ROOT cb = f32[8,16] collective-broadcast(p0), replica_groups={{7,0,1,2,3,4,5,6}} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + AsyncCollectiveCreator::CollectiveCreatorConfig config; + config.convert_collective_broadcast = HloPredicateTrue; + TF_ASSERT_OK(AsyncCollectiveCreator(config).Run(hlo_module.get()).status()); + + HloComputation* computation = hlo_module->entry_computation(); + ASSERT_THAT(computation, NotNull()); + ASSERT_EQ(computation->instruction_count(), 3); + const HloInstruction* done = computation->root_instruction(); + EXPECT_EQ(done->opcode(), HloOpcode::kAsyncDone); + ASSERT_THAT(done->operands(), SizeIs(1)); + const HloInstruction* start = done->operand(0); + EXPECT_EQ(start->opcode(), HloOpcode::kAsyncStart); + ASSERT_THAT(start->async_wrapped_instruction(), NotNull()); + EXPECT_THAT(start->async_wrapped_opcode(), HloOpcode::kCollectiveBroadcast); +} + TEST_F(AsyncAllReduceCreatorTest, SplitsSingleAllToAll) { constexpr absl::string_view hlo_string = R"( HloModule test diff --git a/xla/service/async_op_canonicalizer.cc b/xla/service/async_op_canonicalizer.cc deleted file mode 100644 index 7b40b0014b956..0000000000000 --- a/xla/service/async_op_canonicalizer.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/async_op_canonicalizer.h" - -namespace xla { - -namespace { - -struct AsyncGroup { - std::optional id; - std::vector instructions; -}; - -StatusOr CreateAsyncGroups( - HloModule* module, - const absl::flat_hash_set& execution_threads, - std::vector& async_groups) { - absl::flat_hash_map async_groups_by_id; - absl::flat_hash_map - async_groups_by_instruction; - - for (const HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kAsyncStart) { - std::optional group_id = instruction->async_group_id(); - // We expect that there weren't any other async-starts with the same - // group id. Treat it as an error in case there is a collision. - TF_RET_CHECK(!group_id.has_value() || - !async_groups_by_id.contains(*group_id)) - << "The group id was taken by another group already."; - async_groups.push_back({group_id, {instruction}}); - async_groups_by_instruction[instruction] = &async_groups.back(); - if (group_id.has_value()) { - async_groups_by_id[*group_id] = &async_groups.back(); - } - } else if (instruction->opcode() == HloOpcode::kAsyncUpdate || - instruction->opcode() == HloOpcode::kAsyncDone) { - // We expect the instruction group id to match the operand's id. - TF_RET_CHECK(instruction->async_group_id() == - instruction->operand(0)->async_group_id()); - // Use the operand to find the async group (not the group id) because - // the instruction might not have a group id assigned yet. - auto async_group_it = - async_groups_by_instruction.find(instruction->operand(0)); - TF_RET_CHECK(async_group_it != async_groups_by_instruction.end()); - AsyncGroup* async_group = async_group_it->second; - async_group->instructions.push_back(instruction); - async_groups_by_instruction[instruction] = async_group; - } - } - } - - // Assign ids to async groups that don't have one. - int64_t next_id = 0; - auto get_next_id = [&]() { - while (async_groups_by_id.contains(next_id)) { - ++next_id; - } - return next_id; - }; - bool modified = false; - for (AsyncGroup& async_group : async_groups) { - if (!async_group.id.has_value()) { - async_group.id = get_next_id(); - async_groups_by_id[*async_group.id] = &async_group; - } - for (HloInstruction* instruction : async_group.instructions) { - modified |= async_group.id != instruction->async_group_id(); - instruction->set_async_group_id(async_group.id); - } - } - - return modified; -} - -} // namespace - -StatusOr AsyncOpCanonicalizer::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES( - 1, module->ToString(HloPrintOptions().set_syntax_sugar_async_ops(false))); - - std::vector async_groups; - TF_ASSIGN_OR_RETURN( - bool modified, - CreateAsyncGroups(module, execution_threads, async_groups)); - - for (const AsyncGroup& async_group : async_groups) { - HloComputation* computation = - async_group.instructions[0]->async_wrapped_computation(); - for (int i = 1; i < async_group.instructions.size(); ++i) { - HloInstruction* instruction = async_group.instructions[i]; - if (instruction->async_wrapped_computation() != computation) { - instruction->async_wrapped_computation()->RemoveAsyncInstruction( - instruction); - instruction->ReplaceCalledComputations( - [&](HloComputation*) { return computation; }); - computation->AddAsyncInstruction(*instruction); - } - } - } - XLA_VLOG_LINES( - 1, module->ToString(HloPrintOptions().set_syntax_sugar_async_ops(false))); - return modified; -} - -} // namespace xla diff --git a/xla/service/async_op_canonicalizer.h b/xla/service/async_op_canonicalizer.h deleted file mode 100644 index 0cf9bab610664..0000000000000 --- a/xla/service/async_op_canonicalizer.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_ASYNC_OP_CANONICALIZER_H_ -#define XLA_SERVICE_ASYNC_OP_CANONICALIZER_H_ - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" - -namespace xla { - -// This pass looks at all of the async operations in the module and assigns the -// async operations that participate in the same async action a unique async -// group id. Async operations in the same group id typically consist of one -// async-start operation, one async-done operation, and zero or more -// async-update operations. Then, this pass ensures all of the async operations -// with the same group id wrap the same computation such that each async -// computation is associated with all of the async operations that have the same -// group id. -class AsyncOpCanonicalizer : public HloModulePass { - public: - ~AsyncOpCanonicalizer() override = default; - absl::string_view name() const override { return "async-op-canonicalizer"; } - using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla - -#endif // XLA_SERVICE_ASYNC_OP_CANONICALIZER_H_ diff --git a/xla/service/async_op_canonicalizer_test.cc b/xla/service/async_op_canonicalizer_test.cc deleted file mode 100644 index 94ead1454b903..0000000000000 --- a/xla/service/async_op_canonicalizer_test.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/async_op_canonicalizer.h" - -#include - -#include "xla/service/hlo_dce.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -using AsyncOpCanonicalizerTest = HloTestBase; - -TEST_F(AsyncOpCanonicalizerTest, AsyncCallsSingleComputation) { - std::string hlo_string = R"( -HloModule AsyncCall - -%called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] { - %param_0 = f32[4096]{0} parameter(0) - %param_1 = f32[4096]{0} parameter(1) - %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0) - %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1) - ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1) -} - -%async_wrapped (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { - %async_param = f32[4096]{0} parameter(0) - %async_param.1 = f32[4096]{0} parameter(1) - ROOT %call = f32[4096]{0} call(f32[4096]{0} %async_param, f32[4096]{0} %async_param.1), to_apply=%called_computation -} - -ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { - %a = f32[4096]{0} parameter(0) - %b = f32[4096]{0} parameter(1) - %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %a, f32[4096]{0} %b), calls=%async_wrapped - %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a) - %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), calls=%async_wrapped - %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b) - %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3) - %async-done = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update), calls=%async_wrapped - ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) -} - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AsyncOpCanonicalizer canonicalizer; - EXPECT_TRUE(canonicalizer.Run(module.get()).value()); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - HloInstruction* async_update = FindInstruction(module.get(), "async-update"); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - - EXPECT_EQ(async_start->async_group_id(), 0); - EXPECT_EQ(async_update->async_group_id(), 0); - EXPECT_EQ(async_done->async_group_id(), 0); - EXPECT_EQ(async_start->async_wrapped_computation(), - async_update->async_wrapped_computation()); - EXPECT_EQ(async_start->async_wrapped_computation(), - async_done->async_wrapped_computation()); -} - -TEST_F(AsyncOpCanonicalizerTest, AsyncCallsMultipleComputations) { - std::string hlo_string = R"( -HloModule AsyncCall - -%called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] { - %param_0 = f32[4096]{0} parameter(0) - %param_1 = f32[4096]{0} parameter(1) - %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0) - %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1) - ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1) -} - -%async_wrapped.1 (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { - %async_param = f32[4096]{0} parameter(0) - %async_param.1 = f32[4096]{0} parameter(1) - ROOT %call = f32[4096]{0} call(f32[4096]{0} %async_param, f32[4096]{0} %async_param.1), to_apply=%called_computation -} - -%async_wrapped.2 (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { - %async_param = f32[4096]{0} parameter(0) - %async_param.1 = f32[4096]{0} parameter(1) - ROOT %call = f32[4096]{0} call(f32[4096]{0} %async_param, f32[4096]{0} %async_param.1), to_apply=%called_computation -} - -%async_wrapped.3 (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { - %async_param = f32[4096]{0} parameter(0) - %async_param.1 = f32[4096]{0} parameter(1) - ROOT %call = f32[4096]{0} call(f32[4096]{0} %async_param, f32[4096]{0} %async_param.1), to_apply=%called_computation -} - -ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { - %a = f32[4096]{0} parameter(0) - %b = f32[4096]{0} parameter(1) - %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %a, f32[4096]{0} %b), calls=%async_wrapped.1 - %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a) - %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), calls=%async_wrapped.2 - %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b) - %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3) - %async-done = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update), calls=%async_wrapped.3 - ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) -} - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AsyncOpCanonicalizer canonicalizer; - EXPECT_TRUE(canonicalizer.Run(module.get()).value()); - HloDCE dce; - dce.Run(module.get()).value(); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - HloInstruction* async_update = FindInstruction(module.get(), "async-update"); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - - EXPECT_EQ(async_start->async_group_id(), 0); - EXPECT_EQ(async_update->async_group_id(), 0); - EXPECT_EQ(async_done->async_group_id(), 0); - EXPECT_EQ(async_start->async_wrapped_computation(), - async_update->async_wrapped_computation()); - EXPECT_EQ(async_start->async_wrapped_computation(), - async_done->async_wrapped_computation()); -} - -} // namespace -} // namespace xla diff --git a/xla/service/backend.cc b/xla/service/backend.cc index d37eba45cf3c7..aac138f395caf 100644 --- a/xla/service/backend.cc +++ b/xla/service/backend.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,7 +78,7 @@ struct Backend::IntraOpThreadPool { std::unique_ptr device; }; -/* static */ StatusOr> Backend::CreateBackend( +/* static */ absl::StatusOr> Backend::CreateBackend( const BackendOptions& options) { se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); @@ -95,7 +95,7 @@ struct Backend::IntraOpThreadPool { return std::move(backend); } -/* static */ StatusOr> +/* static */ absl::StatusOr> Backend::CreateDefaultBackend() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); @@ -104,33 +104,32 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal, - se::StreamPriority priority) { +absl::StatusOr Backend::BorrowStream( + int device_ordinal, se::StreamPriority priority) { TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); return BorrowStream(executor, priority); } -StatusOr Backend::BorrowStream(se::StreamExecutor* executor, - se::StreamPriority priority) { +absl::StatusOr Backend::BorrowStream( + se::StreamExecutor* executor, se::StreamPriority priority) { absl::MutexLock l(&mu_); if (!stream_pools_.contains(executor)) { - stream_pools_.emplace(executor, std::make_unique()); + stream_pools_.emplace(executor, std::make_unique(executor)); } - return stream_pools_.at(executor)->BorrowStream(executor, priority); + return stream_pools_.at(executor)->BorrowStream(priority); } -StatusOr> Backend::BorrowStreams( +absl::StatusOr> Backend::BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority) { absl::MutexLock l(&mu_); TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); if (!stream_pools_.contains(executor)) { - stream_pools_.emplace(executor, std::make_unique()); + stream_pools_.emplace(executor, std::make_unique(executor)); } std::vector ptrs; for (int i = 0; i < num_streams; i++) { - StreamPool::Ptr ptr = - stream_pools_.at(executor)->BorrowStream(executor, priority); + StreamPool::Ptr ptr = stream_pools_.at(executor)->BorrowStream(priority); ptrs.push_back(std::move(ptr)); } return ptrs; @@ -181,7 +180,7 @@ tsl::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { return intra_op_thread_pool_->pool.get(); } -StatusOr Backend::stream_executor( +absl::StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || device_ordinal > stream_executors_.back()->device_ordinal()) { @@ -198,8 +197,8 @@ StatusOr Backend::stream_executor( device_name(device_ordinal)); } -StatusOr Backend::devices_equivalent(int device_ordinal_a, - int device_ordinal_b) { +absl::StatusOr Backend::devices_equivalent(int device_ordinal_a, + int device_ordinal_b) { // Use the name from device description to determine equivalence. This is a // bit crude but works for GPUs which is the important case where we compile // an executable for one GPU and want to know if it will run (well) on diff --git a/xla/service/backend.h b/xla/service/backend.h index 8e2e2f89ef97b..fb8a324a4d320 100644 --- a/xla/service/backend.h +++ b/xla/service/backend.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -74,12 +74,12 @@ class BackendOptions { class Backend { public: // Creates a new backend. - static StatusOr> CreateBackend( + static absl::StatusOr> CreateBackend( const BackendOptions& options); // Creates a backend for the default platform. The default platform is defined // in PlatformUtil. - static StatusOr> CreateDefaultBackend(); + static absl::StatusOr> CreateDefaultBackend(); ~Backend(); @@ -109,7 +109,7 @@ class Backend { } // Returns the stream executor for the given device ordinal. - StatusOr stream_executor(int device_ordinal) const; + absl::StatusOr stream_executor(int device_ordinal) const; // Returns the stream executor for the default device ordinal. This stream // executor can only be used when the number of computations is 1 (replication @@ -122,13 +122,13 @@ class Backend { // Borrows a stream for use by the caller with a given priority, either by // grabbing it from an internal pool, or by constructing/initializating it, // and returns the result to the caller. - StatusOr BorrowStream( + absl::StatusOr BorrowStream( int device_ordinal, se::StreamPriority priority = se::StreamPriority::Default); - StatusOr BorrowStream( + absl::StatusOr BorrowStream( se::StreamExecutor* executor, se::StreamPriority priority = se::StreamPriority::Default); - StatusOr> BorrowStreams( + absl::StatusOr> BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority = se::StreamPriority::Default); @@ -136,8 +136,8 @@ class Backend { // as `BorrowStreams` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function>(int, int, - se::StreamPriority)> + std::function>( + int, int, se::StreamPriority)> StreamBorrowerWithPriority() { return [this](int device_ordinal, int num_streams, se::StreamPriority priority) { @@ -159,7 +159,8 @@ class Backend { // Returns true if the devices with the given ordinals are equivalent from // XLA's perspective. That is, an executable compiled for one device would // be equivalent to an executable compiled for the other. - StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); + absl::StatusOr devices_equivalent(int device_ordinal_a, + int device_ordinal_b); // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. diff --git a/xla/service/batch_dot_simplification.cc b/xla/service/batch_dot_simplification.cc index 638a1a1253dc1..2941182ccc4a6 100644 --- a/xla/service/batch_dot_simplification.cc +++ b/xla/service/batch_dot_simplification.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,20 @@ limitations under the License. #include "xla/service/batch_dot_simplification.h" #include "absl/algorithm/container.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/hlo_creation_utils.h" namespace xla { -StatusOr +absl::StatusOr BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( HloInstruction* batch_dot) { + // Sparse dots are not supported on CPU. + if (Cast(batch_dot)->sparse_operands()) { + return false; + } + // This pass assumes the lhs and rhs batch dimensions are equal and strictly // ascending. const auto& is_iota = [](absl::Span dims) { @@ -108,7 +115,7 @@ absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } -StatusOr BatchDotSimplification::Run( +absl::StatusOr BatchDotSimplification::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/batch_dot_simplification.h b/xla/service/batch_dot_simplification.h index eecf8f705914d..0f5238386429d 100644 --- a/xla/service/batch_dot_simplification.h +++ b/xla/service/batch_dot_simplification.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,13 +28,13 @@ namespace xla { class BatchDotSimplification : public HloModulePass { public: using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; absl::string_view name() const override; private: - StatusOr ElideDegenerateBatchDimensionFromBatchDot( + absl::StatusOr ElideDegenerateBatchDimensionFromBatchDot( HloInstruction* batch_dot); }; } // namespace xla diff --git a/xla/service/batch_dot_simplification_test.cc b/xla/service/batch_dot_simplification_test.cc index a472895da0330..fd60e8f2a3ade 100644 --- a/xla/service/batch_dot_simplification_test.cc +++ b/xla/service/batch_dot_simplification_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/batchnorm_expander.cc b/xla/service/batchnorm_expander.cc index 0e2edccf19137..592cb4a210cfe 100644 --- a/xla/service/batchnorm_expander.cc +++ b/xla/service/batchnorm_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -580,7 +580,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( return OkStatus(); } -StatusOr BatchNormExpander::Run( +absl::StatusOr BatchNormExpander::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString()); diff --git a/xla/service/batchnorm_expander.h b/xla/service/batchnorm_expander.h index 0e47fbc99fcc5..ab2c13f56bc2c 100644 --- a/xla/service/batchnorm_expander.h +++ b/xla/service/batchnorm_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class BatchNormExpander : public HloModulePass { // Run operation expander on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/batchnorm_expander_test.cc b/xla/service/batchnorm_expander_test.cc index 29319e9fa89bf..9922dd601940d 100644 --- a/xla/service/batchnorm_expander_test.cc +++ b/xla/service/batchnorm_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/bfloat16_conversion_folding.cc b/xla/service/bfloat16_conversion_folding.cc index ed35d81ebf843..c8e94d576c5c0 100644 --- a/xla/service/bfloat16_conversion_folding.cc +++ b/xla/service/bfloat16_conversion_folding.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -252,7 +252,7 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { return OkStatus(); } -StatusOr BFloat16ConversionFolding::Run( +absl::StatusOr BFloat16ConversionFolding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/bfloat16_conversion_folding.h b/xla/service/bfloat16_conversion_folding.h index c107ce253ce13..707738dd8491c 100644 --- a/xla/service/bfloat16_conversion_folding.h +++ b/xla/service/bfloat16_conversion_folding.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ class BFloat16ConversionFolding : public HloModulePass { // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/bfloat16_conversion_folding_test.cc b/xla/service/bfloat16_conversion_folding_test.cc index c02e148749d6b..99cf031b565d5 100644 --- a/xla/service/bfloat16_conversion_folding_test.cc +++ b/xla/service/bfloat16_conversion_folding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,7 +75,7 @@ class BFloat16ConversionFoldingTest : public HloTestBase { bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); - StatusOr result = fold.Run(module); + absl::StatusOr result = fold.Run(module); EXPECT_IS_OK(result.status()); return result.value(); } diff --git a/xla/service/bfloat16_propagation.cc b/xla/service/bfloat16_propagation.cc index bd976d154692e..38eb493081f30 100644 --- a/xla/service/bfloat16_propagation.cc +++ b/xla/service/bfloat16_propagation.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -831,7 +831,7 @@ Status BFloat16Propagation::SkipNoopConversions( // their users. During the backward pass, the potential changes are stored in // changes_to_bf16_ which are subject to further adjustments then applied to the // HLOs. -StatusOr BFloat16Propagation::Run( +absl::StatusOr BFloat16Propagation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { consider_using_bfloat16_.clear(); diff --git a/xla/service/bfloat16_propagation.h b/xla/service/bfloat16_propagation.h index bfd78416785e9..21625e7337573 100644 --- a/xla/service/bfloat16_propagation.h +++ b/xla/service/bfloat16_propagation.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ class BFloat16Propagation : public HloModulePass { // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/bfloat16_propagation_test.cc b/xla/service/bfloat16_propagation_test.cc index d9702591efedc..c52c6a37d67c3 100644 --- a/xla/service/bfloat16_propagation_test.cc +++ b/xla/service/bfloat16_propagation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,7 @@ class BFloat16PropagationTest : public HloTestBase { bool PropagatePrecision(HloModule* module) { TestBFloat16Support bfloat16_support; BFloat16Propagation propagation(&bfloat16_support); - StatusOr result = propagation.Run(module); + absl::StatusOr result = propagation.Run(module); EXPECT_IS_OK(result.status()); return result.value(); } diff --git a/xla/service/bitcast_dtypes_expander.cc b/xla/service/bitcast_dtypes_expander.cc index ad3ec9c409583..f4cc6809599cd 100644 --- a/xla/service/bitcast_dtypes_expander.cc +++ b/xla/service/bitcast_dtypes_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ limitations under the License. namespace xla { -StatusOr BitcastDtypesExpander::ExpandInstruction( +absl::StatusOr BitcastDtypesExpander::ExpandInstruction( HloInstruction* instruction) { HloInstruction* input = instruction->mutable_operand(0); const Shape& from_shape = input->shape(); diff --git a/xla/service/bitcast_dtypes_expander.h b/xla/service/bitcast_dtypes_expander.h index 3d07c7541bf34..ce7663d47fc0e 100644 --- a/xla/service/bitcast_dtypes_expander.h +++ b/xla/service/bitcast_dtypes_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class BitcastDtypesExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; private: diff --git a/xla/service/bitcast_dtypes_expander_test.cc b/xla/service/bitcast_dtypes_expander_test.cc index b400b28cf8824..b145e8ceb7b5f 100644 --- a/xla/service/bitcast_dtypes_expander_test.cc +++ b/xla/service/bitcast_dtypes_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/broadcast_canonicalizer.cc b/xla/service/broadcast_canonicalizer.cc index 84cb13c769ca8..e763b4d60d0e2 100644 --- a/xla/service/broadcast_canonicalizer.cc +++ b/xla/service/broadcast_canonicalizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace xla { BroadcastCanonicalizer::BroadcastCanonicalizer() {} -StatusOr BroadcastCanonicalizer::Run( +absl::StatusOr BroadcastCanonicalizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/broadcast_canonicalizer.h b/xla/service/broadcast_canonicalizer.h index 30a7dbaabcc24..0206d187942d8 100644 --- a/xla/service/broadcast_canonicalizer.h +++ b/xla/service/broadcast_canonicalizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class BroadcastCanonicalizer : public HloModulePass { absl::string_view name() const override { return "broadcast_canonicalizer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/broadcast_canonicalizer_test.cc b/xla/service/broadcast_canonicalizer_test.cc index 5214cd6342e79..d0ce32e6c62f5 100644 --- a/xla/service/broadcast_canonicalizer_test.cc +++ b/xla/service/broadcast_canonicalizer_test.cc @@ -1,5 +1,5 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/buffer_assignment.cc b/xla/service/buffer_assignment.cc index 46b19bc805a79..a94ee0720828e 100644 --- a/xla/service/buffer_assignment.cc +++ b/xla/service/buffer_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" #include "xla/service/buffer_value_containers.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" @@ -74,7 +74,7 @@ absl::flat_hash_map BuildIdToHloInstructionMap( return id_to_hlo_instruction; } -StatusOr> +absl::StatusOr> BuildIdToLogicalBufferMap( const BufferAssignmentProto& proto, const absl::flat_hash_map& @@ -205,7 +205,7 @@ Status GatherComputationsByAllocationType( true)); // Thread local. break; default: - return InternalError("Unexpected calling opcode: %s", + return Internal("Unexpected calling opcode: %s", HloOpcodeString(instruction->opcode())); } } @@ -470,7 +470,7 @@ bool BufferAssignment::HasTopLevelAllocation( return HasAllocationAt(instruction, /*index=*/{}); } -StatusOr BufferAssignment::GetUniqueSlice( +absl::StatusOr BufferAssignment::GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const { VLOG(3) << "Trying to find unique slice for " << instruction->name() << " [" << index << "]"; @@ -502,7 +502,8 @@ StatusOr BufferAssignment::GetUniqueSlice( return result; } -StatusOr BufferAssignment::GetUniqueTopLevelSlice( +absl::StatusOr +BufferAssignment::GetUniqueTopLevelSlice( const HloInstruction* instruction) const { return GetUniqueSlice(instruction, /*index=*/{}); } @@ -523,7 +524,8 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, SliceSet slices; Status status = ShapeUtil::ForEachSubshapeWithStatus( instr->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) { + [&](const Shape& /*subshape*/, + const ShapeIndex& index) -> absl::Status { auto shape_slices = GetAllSlices(instr, index); if (shape_slices.empty()) { return InvalidArgument("No slices assigned to part of instr."); @@ -548,7 +550,7 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, }); } -StatusOr +absl::StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( module_->entry_computation()->root_instruction()); @@ -995,7 +997,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { } /* static */ -StatusOr> BufferAssignment::FromProto( +absl::StatusOr> BufferAssignment::FromProto( const BufferAssignmentProto& proto, const HloModule* module, BufferValue::SizeFunction buffer_size, HloDataflowAnalysis::CanShareBuffer can_share_buffer) { @@ -1022,11 +1024,21 @@ StatusOr> BufferAssignment::FromProto( // Process each buffer allocation entry in the proto to create a new // allocation. for (const auto& alloc_proto : proto.buffer_allocations()) { - auto* allocation = buffer_assignment->NewEmptyAllocation( + BufferAllocation* allocation = buffer_assignment->NewEmptyAllocation( alloc_proto.size(), alloc_proto.color()); + + // We don't copy allocation index as it gets automatically assigned. CHECK(allocation->index() == alloc_proto.index()) << "Expected allocations in BufferAssignment proto to be sorted by " "index."; + + // Set allocation properties for a newly constructed BufferAllocation. + allocation->set_is_thread_local(alloc_proto.is_thread_local()); + allocation->set_is_tuple(alloc_proto.is_tuple()); + allocation->set_constant(alloc_proto.is_constant()); + + // If allocation corresponds to an entry computation parameter, copy + // parameter properties to a BufferAllocation. if (alloc_proto.is_entry_computation_parameter()) { std::vector shape_idx_vals; absl::c_copy(alloc_proto.parameter_shape_index(), @@ -1044,6 +1056,9 @@ StatusOr> BufferAssignment::FromProto( buffer_assignment->AddAssignment(allocation, *buffer_val, assignee.offset(), assignee.size()); } + + // We don't set `maybe_live_out` as it is inferred automatically by + // buffer assignment when we call `AddAssignment` above. CHECK_EQ(allocation->maybe_live_out(), alloc_proto.maybe_live_out()) << "Dataflow analysis differs from proto."; } @@ -1059,7 +1074,7 @@ StatusOr> BufferAssignment::FromProto( } /* static */ -StatusOr> BufferAssigner::Run( +absl::StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, @@ -1970,7 +1985,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( } } -StatusOr> BufferAssigner::CreateAssignment( +absl::StatusOr> +BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, diff --git a/xla/service/buffer_assignment.h b/xla/service/buffer_assignment.h index 94d4cdf28226f..d365d9fd365b6 100644 --- a/xla/service/buffer_assignment.h +++ b/xla/service/buffer_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_assignment.pb.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -72,7 +72,6 @@ class BufferAllocation { BufferAllocation(Index index, int64_t size, LogicalBuffer::Color color) : index_(index), size_(size), color_(color) {} - ~BufferAllocation() {} // Returns the index of this allocation. Index index() const { return index_; } @@ -421,16 +420,16 @@ class BufferAssignment { // Convenience function which returns the unique slice containing the buffer // at the given index of the given instruction. If a slice is not assigned or // the slice cannot be determined at compile time then an error is returned. - StatusOr GetUniqueSlice( + absl::StatusOr GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const; // Like GetUniqueSlice but fixes the index to the top-level of the shape // (index = {}). - StatusOr GetUniqueTopLevelSlice( + absl::StatusOr GetUniqueTopLevelSlice( const HloInstruction* instruction) const; // Like GetUniqueTopLevelSlice but returns the slice for the output of the // entry computation of the HLO module (ie, the result of the XLA // computation). - StatusOr GetUniqueTopLevelOutputSlice() const; + absl::StatusOr GetUniqueTopLevelOutputSlice() const; // Returns the set BufferValues which may be the source of the value at the // given index and instruction. @@ -480,7 +479,7 @@ class BufferAssignment { // Convert BufferAssignment to or from a proto. BufferAssignmentProto ToProto() const; - static StatusOr> FromProto( + static absl::StatusOr> FromProto( const BufferAssignmentProto& proto, const HloModule* module, BufferValue::SizeFunction buffer_size, HloDataflowAnalysis::CanShareBuffer can_share_buffer); @@ -637,7 +636,7 @@ class BufferAssigner { // LogicalBuffer. If preset_assignments is provided, those pre-set assignment // offsets will be used. The caller guarantees that those assignments are // valid and they do not overwrite each other. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, @@ -665,7 +664,7 @@ class BufferAssigner { virtual ~BufferAssigner() = default; // Create a buffer assignment. - StatusOr> CreateAssignment( + absl::StatusOr> CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, diff --git a/xla/service/buffer_assignment.proto b/xla/service/buffer_assignment.proto index 88c90b9ed5cba..98d9287bdb8d9 100644 --- a/xla/service/buffer_assignment.proto +++ b/xla/service/buffer_assignment.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/buffer_assignment_test.cc b/xla/service/buffer_assignment_test.cc index eeab8bacb6bac..5380368cad491 100644 --- a/xla/service/buffer_assignment_test.cc +++ b/xla/service/buffer_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/literal.h" -#include "xla/service/async_op_canonicalizer.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/copy_insertion.h" @@ -104,7 +103,7 @@ class BufferAssignmentTest : public HloTestBase { .value(); } - StatusOr> ConvertToProtoAndBack( + absl::StatusOr> ConvertToProtoAndBack( const BufferAssignment* buffers, const HloModule* module) { // Dump proto for buffer assignments. auto proto = buffers->ToProto(); @@ -842,7 +841,8 @@ TEST_F(BufferAssignmentTest, PresetAssignments) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); Shape f32vec100_color1 = ShapeUtil::MakeShapeWithDenseLayout( - F32, {100}, {0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + F32, {100}, {0}, /*tiles=*/{}, /*tail_padding_alignment_in_elements=*/1, + /*element_size_in_bits=*/0, /*memory_space=*/1); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_color1, HloOpcode::kMultiply, broadcast, param0)); @@ -904,7 +904,8 @@ TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) { // HloValue and HloBuffer (i.e., a while loop). auto module = CreateNewVerifiedModule(); Shape f32vec10_color1 = ShapeUtil::MakeShapeWithDenseLayout( - F32, {10}, {0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + F32, {10}, {0}, /*tiles=*/{}, /*tail_padding_alignment_in_elements=*/1, + /*element_size_in_bits=*/0, /*memory_space=*/1); Shape t_s32_f32v10_color1 = ShapeUtil::MakeTupleShape({s32_, f32vec10_color1}); @@ -2752,16 +2753,12 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { %negate_6 = f32[4096]{0} negate(f32[4096]{0} %negate_5) %negate_7 = f32[4096]{0} negate(f32[4096]{0} %negate_6) %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_4, f32[4096]{0} %negate_7) - %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation + %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start) ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) } )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); - AsyncOpCanonicalizer async_op_canonicalizer; - EXPECT_TRUE(async_op_canonicalizer.Run(m.get()).ok()); - HloDCE dce; - EXPECT_TRUE(dce.Run(m.get()).ok()); auto buffers = RunBufferAssignmentWithSequentialOrdering(m.get()); @@ -2813,16 +2810,12 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { %negate_6 = f32[4096]{0} negate(f32[4096]{0} %negate_5) %negate_7 = f32[4096]{0} negate(f32[4096]{0} %negate_6) %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_4, f32[4096]{0} %negate_7) - %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), async_execution_thread="foobar", to_apply=%called_computation + %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start) ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) } )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); - AsyncOpCanonicalizer async_op_canonicalizer; - EXPECT_TRUE(async_op_canonicalizer.Run(m.get()).ok()); - HloDCE dce; - EXPECT_TRUE(dce.Run(m.get()).ok()); auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { for (const HloBuffer& buffer : alias_analysis->buffers()) { @@ -2925,18 +2918,14 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { %negate_8 = f32[4096]{0} negate(f32[4096]{0} %negate_7) %negate_9 = f32[4096]{0} negate(f32[4096]{0} %negate_8) %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_6, f32[4096]{0} %negate_9) - %async-done.1 = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.1), async_execution_thread="foobar", to_apply=%called_computation1 - %async-done.2 = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.2), async_execution_thread="foobar", to_apply=%called_computation2 + %async-done.1 = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.1) + %async-done.2 = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.2) %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done.1) ROOT %add_2 = f32[4096]{0} add(f32[4096]{0} %add_1, f32[4096]{0} %async-done.2) } )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); - AsyncOpCanonicalizer async_op_canonicalizer; - EXPECT_TRUE(async_op_canonicalizer.Run(m.get()).ok()); - HloDCE dce; - EXPECT_TRUE(dce.Run(m.get()).ok()); auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { for (const HloBuffer& buffer : alias_analysis->buffers()) { @@ -3032,16 +3021,12 @@ TEST_F(BufferAssignmentTest, AsyncCallImplicitSharding) { ENTRY entry { p0 = f32[8] parameter(0) call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation - ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation + ROOT call-done = f32[8] call-done(call-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); - AsyncOpCanonicalizer canonicalizer; - TF_ASSERT_OK(canonicalizer.Run(module.get()).status()); - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); auto buffers = RunBufferAssignmentWithSequentialOrdering(module.get()); diff --git a/xla/service/buffer_value.cc b/xla/service/buffer_value.cc index 4eda8d2a204a0..70b02a09dc002 100644 --- a/xla/service/buffer_value.cc +++ b/xla/service/buffer_value.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/buffer_value.h b/xla/service/buffer_value.h index 1a9e873741061..980ae7fd6fe40 100644 --- a/xla/service/buffer_value.h +++ b/xla/service/buffer_value.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace xla { // TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. // // XLA arrays are trivially a single BufferValue. Tuples are made up of more -// than one BufferValue: an BufferValue for the pointer vector, and an +// than one BufferValue: a BufferValue for the pointer vector, and a // BufferValue for each child element. // // Every BufferValue is defined by a particular instruction and most diff --git a/xla/service/buffer_value_containers.h b/xla/service/buffer_value_containers.h index fe81e31d74422..2e02dd8df7dec 100644 --- a/xla/service/buffer_value_containers.h +++ b/xla/service/buffer_value_containers.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/call_graph.cc b/xla/service/call_graph.cc index 41d7e961f7fc1..fcf0d239a2dab 100644 --- a/xla/service/call_graph.cc +++ b/xla/service/call_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -148,15 +148,13 @@ CallGraph::CallGraph( const CallGraphNode& CallGraph::GetNode( const HloComputation* computation) const { - auto it = node_indices_.find(computation); - CHECK(it != node_indices_.end()); - return nodes_[it->second]; + DCHECK(node_indices_.contains(computation)); + return nodes_[node_indices_.find(computation)->second]; } CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { - auto it = node_indices_.find(computation); - CHECK(it != node_indices_.end()); - return nodes_[it->second]; + DCHECK(node_indices_.contains(computation)); + return nodes_[node_indices_.find(computation)->second]; } bool CallGraph::DominatesHelper( @@ -190,6 +188,21 @@ bool CallGraph::Dominates(const HloComputation* a, return DominatesHelper(a, b, &visited); } +bool CallGraph::CanReach(const HloComputation* a, + const HloComputation* b) const { + if (a == b) { + return true; + } + + const CallGraphNode& b_node = GetNode(b); + for (const HloComputation* b_caller : b_node.callers()) { + if (CanReach(a, b_caller)) { + return true; + } + } + return false; +} + namespace { // Returns the call context of a computation which is called from contexts 'a' // and 'b'. diff --git a/xla/service/call_graph.h b/xla/service/call_graph.h index 9cce636926c30..d0fa22157f363 100644 --- a/xla/service/call_graph.h +++ b/xla/service/call_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -217,6 +217,10 @@ class CallGraph { // 'a'. Trivially, a computation dominates itself. bool Dominates(const HloComputation* a, const HloComputation* b) const; + // Returns true if 'a' can reach 'b' in the call graph. 'a' can reach 'b' if + // 'a' is 'b' or 'a' can reach one of the callers of 'b'. + bool CanReach(const HloComputation* a, const HloComputation* b) const; + // Returns whether 'instruction' is contained in 'computation' either directly // ('instruction->parent' is 'computation') or indirectly ('computation' // dominates 'instruction->parent' in the call graph). diff --git a/xla/service/call_graph_test.cc b/xla/service/call_graph_test.cc index ba0460fcf0ce9..9337bec456fa5 100644 --- a/xla/service/call_graph_test.cc +++ b/xla/service/call_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -707,7 +707,7 @@ TEST_F(CallGraphTest, VisitWithError) { std::unique_ptr call_graph = CallGraph::Build(module.get()); Status status = call_graph->VisitNodes( - [](const CallGraphNode&) { return InternalError("Visitation failed"); }); + [](const CallGraphNode&) { return Internal("Visitation failed"); }); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.code(), tsl::error::INTERNAL); diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index e300476047fa4..7b1a46f778e0d 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -92,7 +92,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO // has not been mapped. - StatusOr Resolve(HloInstruction* subcomputation_hlo) { + absl::StatusOr Resolve(HloInstruction* subcomputation_hlo) { auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo); if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( @@ -123,8 +123,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { } // namespace -/* static */ StatusOr CallInliner::Inline( - HloInstruction* call) { +/* static */ absl::StatusOr +CallInliner::Inline(HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -136,7 +136,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return visitor.ConsumeInstructionMap(); } -StatusOr CallInliner::Run( +absl::StatusOr CallInliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::unique_ptr call_graph = CallGraph::Build(module); diff --git a/xla/service/call_inliner.h b/xla/service/call_inliner.h index f203b2f52eeb6..2ce5e7054a923 100644 --- a/xla/service/call_inliner.h +++ b/xla/service/call_inliner.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class CallInliner : public HloModulePass { // Inlines one call instruction. Returns a mapping from the original // instructions to their inlined versions. - static StatusOr Inline(HloInstruction* call); + static absl::StatusOr Inline(HloInstruction* call); // If single_call_site is true, only functions with a single call site will be // inlined. @@ -44,7 +44,7 @@ class CallInliner : public HloModulePass { absl::string_view name() const override { return "CallInliner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index 45212bea68552..4248c01244480 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -269,7 +269,7 @@ HloModule inline_specified_threads_only %main_inner () -> u32[] { %co.0 = u32[] constant(0) %async-start = ((), u32[], u32[]) call-start(), async_execution_thread="secondary_thread", to_apply=secondary_outer - %async-done = u32[] call-done(((), u32[], u32[]) %async-start), async_execution_thread="secondary_thread", to_apply=secondary_outer + %async-done = u32[] call-done(((), u32[], u32[]) %async-start) ROOT %add.2 = add(%co.0, %async-done) } diff --git a/xla/service/change_op_data_type.cc b/xla/service/change_op_data_type.cc index b86ebd5798572..7f06bc76acc4d 100644 --- a/xla/service/change_op_data_type.cc +++ b/xla/service/change_op_data_type.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,9 @@ limitations under the License. #include #include "xla/service/hlo_creation_utils.h" +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_matmul_rewriter.h" +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 namespace xla { namespace { @@ -35,7 +38,7 @@ std::optional GetUniformOperandType( } } // namespace -StatusOr ChangeOpDataType::Run( +absl::StatusOr ChangeOpDataType::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -59,6 +62,12 @@ StatusOr ChangeOpDataType::Run( if (it == to_type_map_.end()) { continue; } +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + if (instr->opcode() == HloOpcode::kDot && + cpu::OneDnnMatMulRewriter::ShouldRewrite(instr)) { + continue; + } +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 const PrimitiveType to_type = it->second; absl::InlinedVector new_operands; for (HloInstruction* operand : instr->mutable_operands()) { diff --git a/xla/service/change_op_data_type.h b/xla/service/change_op_data_type.h index d72567306fd61..5aa5042535e87 100644 --- a/xla/service/change_op_data_type.h +++ b/xla/service/change_op_data_type.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -63,7 +63,8 @@ class ChangeOpDataType : public HloModulePass { } absl::string_view name() const override { return "change-op-data-type"; } - StatusOr Run( + using HloPassInterface::Run; + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/change_op_data_type_test.cc b/xla/service/change_op_data_type_test.cc index d11f160987c33..2bd746b4bc6bd 100644 --- a/xla/service/change_op_data_type_test.cc +++ b/xla/service/change_op_data_type_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/channel_tracker.cc b/xla/service/channel_tracker.cc index 580193b0d3094..8ad2445f082ef 100644 --- a/xla/service/channel_tracker.cc +++ b/xla/service/channel_tracker.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ limitations under the License. namespace xla { -StatusOr ChannelTracker::NewChannel( +absl::StatusOr ChannelTracker::NewChannel( ChannelHandle::ChannelType type) { if (type != ChannelHandle::DEVICE_TO_DEVICE && type != ChannelHandle::HOST_TO_DEVICE && diff --git a/xla/service/channel_tracker.h b/xla/service/channel_tracker.h index 8d4e8fc6a0155..87b90a2c83c5f 100644 --- a/xla/service/channel_tracker.h +++ b/xla/service/channel_tracker.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class ChannelTracker { // Creates a new Channel object and returns the corresponding // ChannelHandle for it. - StatusOr NewChannel(ChannelHandle::ChannelType type); + absl::StatusOr NewChannel(ChannelHandle::ChannelType type); private: // Guards the channel mapping. diff --git a/xla/service/cholesky_expander.cc b/xla/service/cholesky_expander.cc index efcf8f2fe37d3..95d46d0f323a5 100644 --- a/xla/service/cholesky_expander.cc +++ b/xla/service/cholesky_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,7 +52,7 @@ namespace xla { // l = temp / l[..., j, j) * mask + l // return l // Returns a (result, error) pair. -StatusOr> CholeskyExpander::CholeskyUnblocked( +absl::StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -73,8 +73,9 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](XlaOp i, absl::Span loop_vars, - XlaBuilder* body_builder) -> StatusOr> { + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> absl::StatusOr> { std::vector row_shape_dims(major_dims.begin(), major_dims.end()); std::vector col_shape_dims(major_dims.begin(), major_dims.end()); auto body_a = loop_vars[0]; @@ -126,7 +127,7 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int ndims = a_shape.rank(); if (ndims < 2) { @@ -217,7 +218,7 @@ bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCholesky; } -StatusOr CholeskyExpander::ExpandInstruction( +absl::StatusOr CholeskyExpander::ExpandInstruction( HloInstruction* instruction) { const CholeskyOptions& options = instruction->cholesky_options(); const std::string name = absl::StrFormat( diff --git a/xla/service/cholesky_expander.h b/xla/service/cholesky_expander.h index bf91721b03c32..3178d36e949b1 100644 --- a/xla/service/cholesky_expander.h +++ b/xla/service/cholesky_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,10 +29,10 @@ class CholeskyExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; - virtual StatusOr> CholeskyUnblocked( + virtual absl::StatusOr> CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision); private: diff --git a/xla/service/collective_combiner_utils.h b/xla/service/collective_combiner_utils.h index c22ba04cc7955..b54296d764a87 100644 --- a/xla/service/collective_combiner_utils.h +++ b/xla/service/collective_combiner_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ namespace xla { // together. Instructions will be combined until the threshold for output byte // size or instruction count is reached. template -StatusOr CombineInstructionsByKey( +absl::StatusOr CombineInstructionsByKey( HloComputation* computation, absl::FunctionRef(const HloInstruction*)> key_fn, absl::FunctionRef)> combine_fn, diff --git a/xla/service/collective_decomposer_utils.cc b/xla/service/collective_decomposer_utils.cc index dd4535544d827..d86c6b5ae4e91 100644 --- a/xla/service/collective_decomposer_utils.cc +++ b/xla/service/collective_decomposer_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ limitations under the License. namespace xla { // Create the start indices for decompositing the given collective. -StatusOr> +absl::StatusOr> CreateStartIndicesForCollectiveDecomposition( CollectiveOpGroupMode group_mode, absl::Span replica_groups, const Shape &shard_shape, diff --git a/xla/service/collective_decomposer_utils.h b/xla/service/collective_decomposer_utils.h index 6aaedfd79f1d1..905ab12c24069 100644 --- a/xla/service/collective_decomposer_utils.h +++ b/xla/service/collective_decomposer_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -StatusOr> +absl::StatusOr> CreateStartIndicesForCollectiveDecomposition( CollectiveOpGroupMode group_mode, absl::Span replica_groups, const Shape &shard_shape, diff --git a/xla/service/collective_ops_utils.cc b/xla/service/collective_ops_utils.cc index 87d9d9445cdb5..a980b88753d12 100644 --- a/xla/service/collective_ops_utils.cc +++ b/xla/service/collective_ops_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,17 @@ limitations under the License. #include #include +#include #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/global_device_id.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -86,8 +89,9 @@ std::optional GetReductionIdentity(ReductionKind kind, } } -StatusOr> GetParticipatingIDs( - int current_id, std::optional total_participant_count, +absl::StatusOr> GetParticipatingIDs( + CollectiveOpGroupMode group_mode, int current_id, + std::optional total_participant_count, absl::Span groups) { // Empty replica_groups() means that all replicas participate. if (groups.empty()) { @@ -97,24 +101,37 @@ StatusOr> GetParticipatingIDs( return all_participants; } + // Formatter for printing replica groups in StrJoin. + auto group_formatter = [](std::string* out, const ReplicaGroup& group) { + out->append("["); + out->append(absl::StrJoin(group.replica_ids(), ", ")); + out->append("]"); + }; + // Figure out the other replicas that go together with this one. std::optional group; for (const ReplicaGroup& g : groups) { if (absl::c_linear_search(g.replica_ids(), current_id)) { TF_RET_CHECK(!group.has_value()) - << "ID " << current_id << " appears twice in replica groups"; + << "Replica ID " << current_id << " appears twice in replica groups" + << "; group_mode=" << CollectiveOpGroupModeToString(group_mode) + << "; groups_size=" << groups.size() + << "; groups= " << absl::StrJoin(groups, ", ", group_formatter); group = g; } } TF_RET_CHECK(group.has_value()) - << "ID " << current_id << " doesn't appear in replica groups"; + << "Replica ID " << current_id << " doesn't appear in replica groups" + << "; group_mode=" << CollectiveOpGroupModeToString(group_mode) + << "; groups_size=" << groups.size() + << "; groups= " << absl::StrJoin(groups, ", ", group_formatter); return std::vector(group->replica_ids().begin(), group->replica_ids().end()); } // Returns the group formation mode implied by (a) whether the operation has // channel_id and (b) if it has use_global_device_ids and if yes, its value. -StatusOr GetCollectiveOpGroupMode( +absl::StatusOr GetCollectiveOpGroupMode( bool has_channel_id, std::optional use_global_device_ids) { if (!has_channel_id) { if (!use_global_device_ids.has_value() || !*use_global_device_ids) { @@ -148,7 +165,7 @@ absl::string_view CollectiveOpGroupModeToString( } } -StatusOr>> +absl::StatusOr>> GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -257,7 +274,7 @@ GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, } } -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -287,7 +304,7 @@ StatusOr> GetParticipatingFlattenedIdGroups( return flattened_id_groups; } -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( absl::Span replica_groups, CollectiveOpGroupMode replica_group_mode, int replica_count, int partition_count) { @@ -358,7 +375,7 @@ StatusOr> GetParticipatingFlattenedIdGroups( return flattened_replica_groups; } -StatusOr> GetParticipatingDevices( +absl::StatusOr> GetParticipatingDevices( GlobalDeviceId device_id, const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -369,6 +386,11 @@ StatusOr> GetParticipatingDevices( device_assignment.LogicalIdForDevice(device_id)); int current_replica_id = logical_id.replica_id; int current_partition_id = logical_id.computation_id; + TF_RET_CHECK(0 <= current_replica_id && current_replica_id < replica_count) + << current_replica_id << " " << replica_count; + TF_RET_CHECK(0 <= current_partition_id && + current_partition_id < partition_count) + << current_partition_id << " " << partition_count; std::vector participants; switch (group_mode) { @@ -377,13 +399,15 @@ StatusOr> GetParticipatingDevices( // use current replica id to find the set of participating replicas. If // replica groups are empty, assume a group with all replicas. TF_ASSIGN_OR_RETURN(std::vector participating_replicas, - GetParticipatingIDs(current_replica_id, replica_count, - replica_groups)); + GetParticipatingIDs(group_mode, current_replica_id, + replica_count, replica_groups)); // The set of participating devices is the replicas from the current // partition. participants.reserve(participating_replicas.size()); for (int replica_id : participating_replicas) { + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; participants.emplace_back( device_assignment(replica_id, current_partition_id)); } @@ -394,10 +418,12 @@ StatusOr> GetParticipatingDevices( // replica_groups contain partition_id, group contains all partitions for // the current replica. TF_ASSIGN_OR_RETURN(std::vector participating_partitions, - GetParticipatingIDs(current_partition_id, + GetParticipatingIDs(group_mode, current_partition_id, partition_count, replica_groups)); participants.reserve(participating_partitions.size()); for (int partition_id : participating_partitions) { + TF_RET_CHECK(0 <= partition_id && partition_id < partition_count) + << partition_id << " " << partition_count; participants.emplace_back( device_assignment(current_replica_id, partition_id)); } @@ -408,10 +434,12 @@ StatusOr> GetParticipatingDevices( // replica_groups contain replica_ids. Group contains replicas for all // partitions. TF_ASSIGN_OR_RETURN(std::vector participating_replicas, - GetParticipatingIDs(current_replica_id, replica_count, - replica_groups)); + GetParticipatingIDs(group_mode, current_replica_id, + replica_count, replica_groups)); participants.reserve(participating_replicas.size() * partition_count); for (int replica_id : participating_replicas) { + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; for (int partition_id = 0; partition_id < partition_count; ++partition_id) { participants.emplace_back( @@ -433,7 +461,7 @@ StatusOr> GetParticipatingDevices( // so no need to pass in total_participant_count. TF_ASSIGN_OR_RETURN( std::vector participating_flattened_ids, - GetParticipatingIDs(current_flattened_id, + GetParticipatingIDs(group_mode, current_flattened_id, /*total_participant_count=*/std::nullopt, replica_groups)); @@ -441,6 +469,8 @@ StatusOr> GetParticipatingDevices( for (int flattened_id : participating_flattened_ids) { // Map from flattened id back to replica_id, partition_id. int replica_id = flattened_id / partition_count; + TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) + << replica_id << " " << replica_count; int partition_id = flattened_id % partition_count; participants.emplace_back(device_assignment(replica_id, partition_id)); } @@ -449,7 +479,7 @@ StatusOr> GetParticipatingDevices( } } -StatusOr> GetPariticipantCountsForReplicaGroups( +absl::StatusOr> GetPariticipantCountsForReplicaGroups( int64_t num_replicas, int64_t num_partitions, absl::Span replica_groups, CollectiveOpGroupMode group_mode) { @@ -554,6 +584,7 @@ bool IsCollective(const HloInstruction* instruction) { case HloOpcode::kAllGatherStart: case HloOpcode::kAllGatherDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: @@ -567,9 +598,43 @@ bool IsCollective(const HloInstruction* instruction) { } } return false; + case HloOpcode::kAsyncStart: + case HloOpcode::kAsyncUpdate: + case HloOpcode::kAsyncDone: + return IsCollective(instruction->async_wrapped_instruction()); + default: + return false; + } +} + +bool IsCollectiveWithChannelId(const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: + return instruction->channel_id().has_value(); + case HloOpcode::kFusion: + for (const auto* inner_inst : instruction->fused_instructions()) { + if (IsCollectiveWithChannelId(inner_inst)) { + return true; + } + } + return false; default: return false; } } +bool IsSyncCollective(const HloInstruction* instr) { + auto backend_config = instr->backend_config(); + if (!backend_config.ok()) { + return false; + } + return backend_config->collective_backend_config().is_sync(); +} + } // end namespace xla diff --git a/xla/service/collective_ops_utils.h b/xla/service/collective_ops_utils.h index 55af5f51f113b..d8e046bee4df9 100644 --- a/xla/service/collective_ops_utils.h +++ b/xla/service/collective_ops_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -52,14 +53,6 @@ std::optional MatchReductionComputation( std::optional GetReductionIdentity(ReductionKind kind, PrimitiveType type); -// Figures out which IDs are participating in the collective subgroup. -// An empty `groups` indicates that all [0, total_participant_count) IDs -// are participating. Note that for CollectiveOpGroupMode::kFlattenedID, -// groups cannot be empty, so `total_participant_count` is an optional. -StatusOr> GetParticipatingIDs( - int current_id, std::optional total_participant_count, - absl::Span groups); - // There are broadly 4 modes that collective communication ops use to describe // which sets of devices are participating with a given device in the operation. // These modes are determined by the values of channel_id (optional) and @@ -100,12 +93,21 @@ enum class CollectiveOpGroupMode { kFlattenedID, }; +// Figures out which IDs are participating in the collective subgroup. +// An empty `groups` indicates that all [0, total_participant_count) IDs +// are participating. Note that for CollectiveOpGroupMode::kFlattenedID, +// groups cannot be empty, so `total_participant_count` is an optional. +absl::StatusOr> GetParticipatingIDs( + CollectiveOpGroupMode group_mode, int current_id, + std::optional total_participant_count, + absl::Span groups); + absl::string_view CollectiveOpGroupModeToString( CollectiveOpGroupMode group_mode); // Returns the group formation mode implied by (a) whether the operation has // channel_id and (b) if it has use_global_device_ids and if yes, its value. -StatusOr GetCollectiveOpGroupMode( +absl::StatusOr GetCollectiveOpGroupMode( bool has_channel_id, std::optional use_global_device_ids); // Figures out subgroups of participating devices from given replica_groups and @@ -121,32 +123,32 @@ StatusOr GetCollectiveOpGroupMode( // // This functions returns {{33, 34}, {44, 45, 55, 56}} // There are 2 subgroups of participating devices {33, 34}, {44, 45, 55, 56}. -StatusOr>> +absl::StatusOr>> GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Same as above, except that it returns the flattened id in the replica groups // instead of device id. -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Same as above, but take replica/partition count instead of device assignment. -StatusOr> GetParticipatingFlattenedIdGroups( +absl::StatusOr> GetParticipatingFlattenedIdGroups( absl::Span replica_groups, CollectiveOpGroupMode replica_group_mode, int replica_count, int partition_count); // Figures out which devices are participating in the collective subgroup. -StatusOr> GetParticipatingDevices( +absl::StatusOr> GetParticipatingDevices( GlobalDeviceId device_id, const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); // Figures out how many ranks are participating in each collective subgroup. -StatusOr> GetPariticipantCountsForReplicaGroups( +absl::StatusOr> GetPariticipantCountsForReplicaGroups( int64_t num_replicas, int64_t num_partitions, absl::Span replica_groups, CollectiveOpGroupMode group_mode); @@ -170,6 +172,13 @@ inline constexpr absl::string_view kNopReturnTokenCustomCallTarget = // Returns true if instruction is a collective op or a collective fusion. bool IsCollective(const HloInstruction* instruction); +// Returns true if instruction is a collective op (or a collective fusion) with +// channel_id. +bool IsCollectiveWithChannelId(const HloInstruction* instruction); + +// Returns true if instruction is a synchronous collective op. +bool IsSyncCollective(const HloInstruction* instr); + // Key that identifies a particular Rendezvous object in our global hashtable. // This determines which calls to ExecuteOnStream communicate with each other. // The rules are as follows. @@ -374,9 +383,42 @@ class Rendezvous { std::make_shared(key_.num_local_participants)}; }; +// We only pipeline Send-Recv chains with channel_id > 0, where each chain +// has a unique channel_id, and allows multiple Send-Recv chains using +// channel_id 0. +inline bool MayPipelineSendRecvChannel(int64_t channel_id) { + return channel_id > 0; +} + constexpr char kSendRecvSourceTargetPairsAttr[] = "_xla_send_recv_source_target_pairs"; +// When a Send or Recv is annotated with frontend attribute +// _xla_send_recv_pipeline="1", asynchronous stream kP2P1 is used to execute the +// Send or Recv. For all other cases, asynchronous stream kP2P0 is used. +constexpr char kSendRecvPipelineAttr[] = "_xla_send_recv_pipeline"; + +// This frontend attribute conveys the following information: +// (1) _xla_send_recv_validation="invalid": the runtime should skip sending or +// receiving data when the instruction is executed. +// (2) the absent of the attribute: the runtime should faithfully perform the +// Send or Recv operation when the instruction is executed. +// (3) _xla_send_recv_validation={list-of-bounds}: the list-of-bounds +// corresponds to the value of _xla_send_recv_source_target_pairs, and specifies +// the execution instances for which the runtime should faithfully perform the +// Send or Recv operation. Here is an example: +// _xla_send_recv_source_target_pairs={{0,1}, {1,2}} +// _xla_send_recv_validation={{2,3}, {5,7}} +// The Send or Recv instruction with the above two attributes have the +// following semantics: +// The communication between device 0 and 1 will only send or receive data +// for execution instances 2 and 3 of the instruction on devices 0 and 1. +// For execution instances 0, 1, and beyond 3, the runtime should skip sending +// or receiving any data. +// Similarly, the communication between device 1 and 2 will only send or +// receive data on execution instances 5 and 7. +constexpr char kSendRecvValidationAttr[] = "_xla_send_recv_validation"; + } // end namespace xla #endif // XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_ diff --git a/xla/service/collective_ops_utils_test.cc b/xla/service/collective_ops_utils_test.cc index c313332833091..9bb6d2239e0d2 100644 --- a/xla/service/collective_ops_utils_test.cc +++ b/xla/service/collective_ops_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,26 +15,35 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" +#include #include #include #include #include +#include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_NoReplicaGroups) { - std::vector actual = GetParticipatingIDs( - /*current_id=*/0, /*total_participant_count=*/3, - /*groups=*/{}) - .value(); + std::vector actual = + GetParticipatingIDs(CollectiveOpGroupMode::kFlattenedID, + /*current_id=*/0, /*total_participant_count=*/3, + /*groups=*/{}) + .value(); std::vector expected = {0, 1, 2}; EXPECT_EQ(actual, expected); } @@ -49,14 +58,61 @@ TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_ReplicaGroups) { replica_groups[2].add_replica_ids(3); std::vector actual = - GetParticipatingIDs( - /*current_id=*/1, /*total_participant_count=*/std::nullopt, - replica_groups) + GetParticipatingIDs(CollectiveOpGroupMode::kFlattenedID, + /*current_id=*/1, + /*total_participant_count=*/std::nullopt, + replica_groups) .value(); std::vector expected = {1, 5}; EXPECT_EQ(actual, expected); } +TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY %cluster { + %param0 = f32[512]{0} parameter(0) + %copy0 = f32[512]{0} copy(param0) + %reshape0 = f32[1,1,512]{2,0,1} reshape(f32[512]{0} %copy0) + %all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0), channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + %copy1 = f32[1,4,512]{2,0,1} copy(all-gather) + ROOT root = f32[1,4,512]{2,1,0} copy(%copy1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + HloInstruction *all_gather = + module->entry_computation()->GetInstructionWithName("all-gather"); + + EXPECT_TRUE(IsCollectiveWithChannelId(all_gather)); +} + +TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) { + ReplicaGroup group; + for (int64_t i = 0; i < 8; i++) { + group.add_replica_ids(i); + } + + auto builder = HloComputation::Builder("CollectiveWithChannelId2"); + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * param_0, + builder.AddParameter(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 512, 4096}), "p0"))); + HloInstruction *instr = + builder.AddInstruction(HloInstruction::CreateAllGather( + ShapeUtil::MakeShape(BF16, {1, 4096, 4096}), {param_0}, 1, {group}, + true, 231, true)); + auto computation = builder.Build( + builder.AddInstruction(HloInstruction::CreateTuple({instr}))); + auto fusion = + HloInstruction::CreateFusion(ShapeUtil::MakeShape(BF16, {1, 4096, 4096}), + HloInstruction::FusionKind::kOutput, + {param_0}, computation.get(), "fusion"); + + EXPECT_TRUE(IsCollectiveWithChannelId(fusion.get())); +} + } // namespace // Tests for GetCollectOpGroupMode @@ -96,7 +152,7 @@ class GetCollectOpGroupModeTest : public testing::TestWithParam {}; TEST_P(GetCollectOpGroupModeTest, Test) { const TestCase &tc = GetParam(); - StatusOr actual = + absl::StatusOr actual = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); if (tc.expected) { TF_ASSERT_OK(actual.status()); @@ -139,7 +195,7 @@ struct TestCase { // modes and their behavior. std::string TestCase::ToString() const { std::ostringstream s; - StatusOr group_mode = + absl::StatusOr group_mode = GetCollectiveOpGroupMode(has_channel_id, use_global_device_ids); if (group_mode.ok()) { s << CollectiveOpGroupModeToString(*group_mode); @@ -393,7 +449,7 @@ TEST_P(GetParticipatingDevicesTest, Test) { return group; }); - StatusOr group_mode = + absl::StatusOr group_mode = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); if (!group_mode.ok()) { @@ -403,7 +459,7 @@ TEST_P(GetParticipatingDevicesTest, Test) { // Execute each sub-test. for (const TestCase::CurrentIdAndOutput &subtest : tc.subtests) { - StatusOr> actual = + absl::StatusOr> actual = GetParticipatingDevices(GlobalDeviceId(subtest.current_id), device_assignment, replica_groups, *group_mode); if (!actual.ok()) { @@ -417,9 +473,9 @@ TEST_P(GetParticipatingDevicesTest, Test) { EXPECT_EQ(*actual, expected); } - StatusOr>> actual_device_groups = - GetParticipatingDevicesGroups(device_assignment, replica_groups, - *group_mode); + absl::StatusOr>> + actual_device_groups = GetParticipatingDevicesGroups( + device_assignment, replica_groups, *group_mode); if (!actual_device_groups.ok()) { EXPECT_TRUE(tc.expected_failure); diff --git a/xla/service/collective_opt_utils.cc b/xla/service/collective_opt_utils.cc index 426cc43f123a3..cbc7a4c8867bd 100644 --- a/xla/service/collective_opt_utils.cc +++ b/xla/service/collective_opt_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -267,13 +267,46 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return true; } +std::optional SpecFromReduceScatterInstr( + const HloInstruction* rs_instr, int64_t num_partitions, + int64_t num_replicas, int64_t min_rank, bool is_constrain_layout, + bool use_global_device_ids, bool is_cross_module) { + if (rs_instr->shape().rank() < min_rank) { + return std::nullopt; + } + CHECK(rs_instr->opcode() == HloOpcode::kReduceScatter); + ReduceScatterSpec spec; + spec.split_dim = rs_instr->dimensions(0); + if (!is_cross_module) { + spec.sharded_replicas = num_replicas; + spec.group_size = rs_instr->replica_groups().empty() + ? num_replicas + : rs_instr->replica_groups()[0].replica_ids_size(); + } else if (use_global_device_ids) { + spec.sharded_replicas = num_replicas; + spec.sharded_partitions = num_partitions; + spec.group_size = rs_instr->replica_groups()[0].replica_ids_size(); + } else { + spec.sharded_partitions = num_partitions; + spec.group_size = num_partitions; + } + spec.original_split_dims = {spec.split_dim}; + spec.dynamic_slice = nullptr; + return spec; +} + } // namespace std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, + const HloAllReduceInstructionBase* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id) { + if (ar->opcode() == HloOpcode::kReduceScatter) { + return SpecFromReduceScatterInstr( + ar, num_partitions, num_replicas, min_rank, ar->constrain_layout(), + ar->use_global_device_ids(), ar->channel_id().has_value()); + } auto spec = MatchWithDynamicSlice( ar, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, diff --git a/xla/service/collective_opt_utils.h b/xla/service/collective_opt_utils.h index d26c2ab1ad50d..7d044be3c3456 100644 --- a/xla/service/collective_opt_utils.h +++ b/xla/service/collective_opt_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ struct ReduceScatterSpec { // Matches the given all-reduce operation to a reduce-scatter pattern. std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, + const HloAllReduceInstructionBase* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc index 7d8d54a10d05b..784d3c7318ec2 100644 --- a/xla/service/collective_permute_decomposer.cc +++ b/xla/service/collective_permute_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/collective_permute_decomposer.h" #include +#include #include #include #include @@ -29,14 +30,20 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/graphcycles/graphcycles.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { namespace { + +using SourceTargetPair = std::pair; +using SourceTargetPairs = std::vector; + // Returns true if the (source, target) relationship has a cycle. // -bool hasCycles(const std::vector>& pairs) { +bool HasCycles(const SourceTargetPairs& pairs) { // Build a direct graph to check for cycles in (source, target) relationship. tensorflow::GraphCycles graph; @@ -65,48 +72,55 @@ bool hasCycles(const std::vector>& pairs) { return false; } -// Returns true if the CollectivePermuteStart instruction should be transformed -// to Send/Recv. We currently limit the transformation to asynchronous -// CollectivePermuteStart without any cycle in the (source, target) -// relationship, with only one input and without any context data. +// Returns true if the CollectivePermute instruction should be transformed +// to Send/Recv. We currently limit the transformation to CollectivePermute +// operations without any cycle in their (source, target) relationship, +// with only one input and without any context data. bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute, int64_t threshold_in_bytes) { - auto backend_config = - collective_permute.backend_config() - .value(); - if (backend_config.is_sync()) { - return false; - } - if (collective_permute.operand_count() != 1) { + // TODO(b/316043789): enable the transformation for the no channel_id case. + if (!collective_permute.channel_id().has_value()) { return false; } const Shape& result_shape = collective_permute.shape(); - // Skip the transformation if there is any context data. - if (result_shape.tuple_shapes_size() != 2) { + // Skip the transformation if result is not an array, such as containing + // context data. + if (!result_shape.IsArray()) { return false; } - const Shape& shape = result_shape.tuple_shapes(0); - CHECK(shape.IsArray()); - if (ShapeUtil::ByteSizeOf(shape) < threshold_in_bytes) { + if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) { return false; } - return !hasCycles(collective_permute.source_target_pairs()); + return !HasCycles(collective_permute.source_target_pairs()); } +// Returns true for a pipelineable collective-permute. As a simple heuristic, +// currently only pipeline a collective-permute with a loop input as its send +// data. +bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) { + const HloInstruction* data = collective_permute.operand(0); + return (data->opcode() == HloOpcode::kGetTupleElement && + data->operand(0)->opcode() == HloOpcode::kParameter); +} + +// Decomposes a collective-permute and adds frontend attributes to record +// pipeline decision. The present of the frontend attribute means that the +// collective-permute will be pipelined and the value of the attribute +// represents the runtime stream to execute the instruction. Without the +// frontend attribute, the collective-permute will not be pipelined. Status DecomposeCollectivePermute( HloCollectivePermuteInstruction* collective_permute, - HloComputation* computation) { - // The HLO verifier ensures that CollectivePermuteStart's single user is - // CollectivePermuteDone. - HloInstruction* collective_permute_done = collective_permute->users().front(); - // Encode no channel_id in CP as channel_id 0. - int64_t channel_id = collective_permute->channel_id().value_or(0); + HloComputation* computation, const std::string& pipeline_decision) { + // We currently only decompose collective-permute with a channel_id. + int64_t channel_id = collective_permute->channel_id().value(); HloInstruction* data = collective_permute->mutable_operand(0); const Shape& data_shape = data->shape(); const OpMetadata& metadata = collective_permute->metadata(); + const xla::FrontendAttributes& old_attributes = + collective_permute->frontend_attributes(); xla::FrontendAttributes attributes; std::string source_target_pairs_string = "{" + @@ -120,7 +134,8 @@ Status DecomposeCollectivePermute( absl::StrAppend(out, value, "}"); })) + "}"; - + attributes.mutable_map()->insert(old_attributes.map().begin(), + old_attributes.map().end()); (*attributes.mutable_map())[kSendRecvSourceTargetPairsAttr] = source_target_pairs_string; @@ -128,47 +143,197 @@ Status DecomposeCollectivePermute( computation->AddInstruction(HloInstruction::CreateToken()); HloInstruction* recv = computation->AddInstruction( HloInstruction::CreateRecv(data_shape, after_all, channel_id)); - recv->set_frontend_attributes(attributes); + recv->add_frontend_attributes(attributes); recv->set_metadata(metadata); HloInstruction* send = computation->AddInstruction( HloInstruction::CreateSend(data, after_all, channel_id)); - send->set_frontend_attributes(attributes); + send->add_frontend_attributes(attributes); send->set_metadata(metadata); HloInstruction* recv_done = computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - computation->AddInstruction(HloInstruction::CreateSendDone(send)); + HloInstruction* send_done = + computation->AddInstruction(HloInstruction::CreateSendDone(send)); HloInstruction* recv_data = computation->AddInstruction( HloInstruction::CreateGetTupleElement(recv_done, 0)); - TF_RETURN_IF_ERROR(collective_permute_done->ReplaceAllUsesWith(recv_data)); - TF_RETURN_IF_ERROR( - computation->RemoveInstructionAndUnusedOperands(collective_permute_done)); + TF_RETURN_IF_ERROR(collective_permute->ReplaceAllUsesWith(recv_data)); TF_RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(collective_permute)); + if (!pipeline_decision.empty()) { + xla::FrontendAttributes attributes; + (*attributes.mutable_map())[kSendRecvPipelineAttr] = pipeline_decision; + send->add_frontend_attributes(attributes); + send_done->add_frontend_attributes(attributes); + recv->add_frontend_attributes(attributes); + recv_done->add_frontend_attributes(attributes); + } + return OkStatus(); } + +// Returns true if the (source, target) pairs form a forward cycle with all +// participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume that +// the (source, target) pairs are ordered via increasing source IDs, as they are +// currently generated by SPMD partitioning. +// +bool IsForwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others) { + int64_t num_pairs = others.size() + 1; + if (backedge.first != num_pairs - 1 || backedge.second != 0) { + return false; + } + for (int64_t i = 0; i < num_pairs - 1; ++i) { + const SourceTargetPair& pair = others[i]; + if (pair.first != i || pair.second != i + 1) { + return false; + } + } + return true; +} + +// Returns true if the (source, target) pairs form a backward cycle with all +// participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume that +// the (source, target) pairs are ordered via increasing source IDs, as they are +// currently generated by SPMD partitioning. +// +bool IsBackwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others) { + int64_t num_pairs = others.size() + 1; + if (backedge.first != 0 || backedge.second != num_pairs - 1) { + return false; + } + for (int64_t i = 0; i < num_pairs - 1; ++i) { + const SourceTargetPair& pair = others[i]; + if (pair.first != i + 1 || pair.second != i) { + return false; + } + } + return true; +} + +// Checks whether the two collective-permutes for a forward cycle or a backward +// cycle for pipelining. If the two collective-permutes form a cycle, returns +// a pair of the collective-permutes with the one for the backward edge of the +// cycle as the first entry in the pair. +std::optional> +CheckCyclePatterns(HloCollectivePermuteInstruction* cp0, + HloCollectivePermuteInstruction* cp1) { + const SourceTargetPairs& cp0_pairs = cp0->source_target_pairs(); + const SourceTargetPairs& cp1_pairs = cp1->source_target_pairs(); + if (cp0_pairs.size() == 1) { + if (IsForwardCycle(cp0_pairs.front(), cp1_pairs) || + IsBackwardCycle(cp0_pairs.front(), cp1_pairs)) { + // cp0 represents the backedge for the cycle. + return std::make_pair(cp0, cp1); + } + } + if (cp1_pairs.size() == 1) { + if (IsForwardCycle(cp1_pairs.front(), cp0_pairs) || + IsBackwardCycle(cp1_pairs.front(), cp0_pairs)) { + // cp1 represents the forward edge for the cycle. + return std::make_pair(cp1, cp0); + } + } + return std::nullopt; +} + } // namespace -StatusOr CollectivePermuteDecomposer::Run( +absl::StatusOr CollectivePermuteDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - for (auto comp : module->computations(execution_threads)) { - for (auto hlo : comp->MakeInstructionPostOrder()) { - if (hlo->opcode() != HloOpcode::kCollectivePermuteStart) { + std::vector all_computations = + module->MakeComputationPostOrder(execution_threads); + absl::flat_hash_set while_bodies; + // Process the computation from callers to callees and collect while-body + // along the way. When we process a computation, we know whether it is a + // while-body computation or not. + for (auto iter = all_computations.rbegin(); iter != all_computations.rend(); + ++iter) { + HloComputation* computation = *iter; + bool may_pipeline = while_bodies.contains(computation); + // Record the collective-permute to be decomposed as well as at most two + // collective-permute for which the decomposed Send-Recv chains will be + // pipelined. + // + // Currently, we simply choose the first pipelineable collect-permute we + // encounter, along with another pipelineable collective-permute that forms + // and cycle with the first collective-permute. We consider a + // collective-permute pipelineable if the send-data is a loop parameter. + // When two collective-permutes that form a cycle are selected, + // cp0_to_pipeline records the collective-permute for the backedge of the + // cycle. + std::vector cps_to_decompose; + HloCollectivePermuteInstruction* cp0_to_pipeline = nullptr; + HloCollectivePermuteInstruction* cp1_to_pipeline = nullptr; + for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() == HloOpcode::kWhile) { + // Collect while-body computations. + while_bodies.insert(hlo->while_body()); continue; } - auto collective_permute = Cast(hlo); - if (ShouldDecompose(*collective_permute, threshold_in_bytes_)) { - TF_RETURN_IF_ERROR( - DecomposeCollectivePermute(collective_permute, comp)); - changed = true; + if (hlo->opcode() != HloOpcode::kCollectivePermute) { + continue; } + + HloCollectivePermuteInstruction* cp = + Cast(hlo); + if (!ShouldDecompose(*cp, threshold_in_bytes_)) { + continue; + } + // Record collective-permute to be decomposed. + cps_to_decompose.push_back(cp); + + if (!while_bodies.contains(computation) || !may_pipeline) { + continue; + } + if (cp0_to_pipeline != nullptr && cp1_to_pipeline != nullptr) { + // Already find a pair of collective-permute that forms a cycle to + // pipeline. + continue; + } + if (!MayPipeline(*cp)) { + continue; + } + if (cp0_to_pipeline == nullptr) { + // Record the first pipelineable collective-permute. + cp0_to_pipeline = cp; + continue; + } + auto optional_pair = CheckCyclePatterns(cp0_to_pipeline, cp); + if (optional_pair.has_value()) { + // Add another pipelineable collective-permute that forms a cycle with + // the first pipelineable collect-permute. + + // Collective-permute for the backward edge. + cp0_to_pipeline = optional_pair.value().first; + // Collective-permute for the forward edges. + cp1_to_pipeline = optional_pair.value().second; + } + } + + // Decompose the collective-permute, may add frontend attribute to record + // pipeline decision. + for (HloCollectivePermuteInstruction* cp : cps_to_decompose) { + std::string pipeline_decision; + if (cp0_to_pipeline == cp) { + pipeline_decision = "0"; + } else if (cp1_to_pipeline == cp) { + pipeline_decision = "1"; + } + TF_RETURN_IF_ERROR( + DecomposeCollectivePermute(cp, computation, pipeline_decision)); + } + if (!cps_to_decompose.empty()) { + changed = true; } } + return changed; } diff --git a/xla/service/collective_permute_decomposer.h b/xla/service/collective_permute_decomposer.h index 321c5e1275819..f0d4c0c0df9ae 100644 --- a/xla/service/collective_permute_decomposer.h +++ b/xla/service/collective_permute_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,15 +21,15 @@ limitations under the License. namespace xla { -// CollectivePermuteDecomposer is a pass that converts asynchronous -// CollectivePermute operations without any cycle in the (source, target) -// relationship to Send/Recv. We currently restrict this transformation to -// CollectivePermuteStart with one input and without any context data. +// CollectivePermuteDecomposer is a pass that (1) converts CollectivePermute +// operations without any cycle in their (source, target) relationship to +// Send/Recv, and (2) annotates the Send/Recv for pipelining with a frontend +// frontend attribute. We currently restrict the decomposition to +// CollectivePermute with one input and without any context data. // // before transformation: -// start = (, ) collective-permute-start(data), +// cp = (, ) collective-permute(data), // source_target_pairs={...} -// done = collective-permute-done(start) // // after transformation: // after-all = token[] after-all() @@ -41,7 +41,16 @@ namespace xla { // recv-done = (, token[]) recv-done(recv), channel_id=0 // send-done = token[] send-done(send), channel_id=0, // control-predecessors={recv-done} -// done = get-tuple-element(recv-done), index=0 +// cp = get-tuple-element(recv-done), index=0 +// +// For pipelining, we first make pipelining decision on CollectivePermute +// operations, and then record the decision on the decomposed Send/Recv via +// frontend attributes. We currently only pipeline CollectivePermute operations +// that send loop input data. As a simple heuristics, we pick the first +// encountered pipelineable CollectivePermute for pipelining. Then, if there is +// another pipelineable CollectivePermute that forms a forward or backward +// cycle with the first CollectivePermute, we mark both CollectivePermute +// for pipelining. Otherwise, we only mark one CollectivePermute for pipelining. // class CollectivePermuteDecomposer : public HloModulePass { public: @@ -54,7 +63,7 @@ class CollectivePermuteDecomposer : public HloModulePass { using HloPassInterface::Run; // Runs CollectivePermuteDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index c0c3a8d7bc824..e7a707743109d 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" @@ -33,33 +35,13 @@ using ::testing::HasSubstr; namespace op = xla::testing::opcode_matchers; using CollectivePermuteDecomposerTest = HloTestBase; -TEST_F(CollectivePermuteDecomposerTest, SyncNotTransformed) { - const absl::string_view kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - start = (u32[], u32[]) collective-permute-start(p), - source_target_pairs={{0,1}, {1,2}}, - backend_config="{\"is_sync\":true}" - ROOT done = u32[] collective-permute-done(start) - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); -} - TEST_F(CollectivePermuteDecomposerTest, WithCycleNotTransformed) { const absl::string_view kModuleStr = R"( HloModule test ENTRY test_computation { - p = (u32[], u32[]) replica-id() - start = u32[] collective-permute-start(p), + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), channel_id=1, source_target_pairs={{0,1}, {1,0}} - ROOT done = u32[] collective-permute-done(start) } )"; @@ -75,9 +57,8 @@ TEST_F(CollectivePermuteDecomposerTest, WithContextDataNotTransformed) { HloModule test ENTRY test_computation { p = u32[] replica-id() - start = (u32[], u32[], u32[], u32[]) collective-permute-start(p), + ROOT cp = (u32[], u32[], u32[], u32[]) collective-permute(p), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} - ROOT done = u32[] collective-permute-done(start) } )"; @@ -88,15 +69,14 @@ TEST_F(CollectivePermuteDecomposerTest, WithContextDataNotTransformed) { EXPECT_FALSE(changed); } -TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { +TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { p = u32[] replica-id() - start = (u32[], u32[]) collective-permute-start(p), + ROOT cp = u32[] collective-permute(p), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - ROOT done = u32[] collective-permute-done(start) } )"; @@ -112,26 +92,34 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { EXPECT_EQ(inst->metadata().source_line(), 35); }; + auto check_not_pipelined = [](const HloInstruction* instr) { + const FrontendAttributes& attributes = instr->frontend_attributes(); + EXPECT_EQ(attributes.map().end(), + attributes.map().find(kSendRecvPipelineAttr)); + }; + HloInstruction* after_all = FindInstruction(module.get(), "after-all"); HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->operand(0), after_all); - EXPECT_EQ(recv->channel_id().value(), 0); + EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT( recv->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); check_metadata(recv); + check_not_pipelined(recv); HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); EXPECT_EQ(recv_done->operand(0), recv); HloInstruction* send = FindInstruction(module.get(), "send"); EXPECT_EQ(send->operand(1), after_all); - EXPECT_EQ(send->channel_id().value(), 0); + EXPECT_EQ(send->channel_id().value(), 1); EXPECT_THAT( send->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); check_metadata(send); + check_not_pipelined(send); HloInstruction* send_done = FindInstruction(module.get(), "send-done"); EXPECT_EQ(send_done->operand(0), send); @@ -139,14 +127,13 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedDefaultChannelId) { EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); } -TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { +TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { p = u32[] replica-id() - start = (u32[], u32[]) collective-permute-start(p), channel_id=2, + ROOT cp = u32[] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} - ROOT done = u32[] collective-permute-done(start) } )"; @@ -154,12 +141,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { ParseAndReturnUnverifiedModule((kModuleStr))); CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->channel_id().value(), 2); - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_EQ(send->channel_id().value(), 2); + EXPECT_FALSE(changed); } TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) { @@ -167,10 +149,9 @@ TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) { HloModule test ENTRY test_computation { p = u32[] replica-id() - start = (u32[], u32[]) collective-permute-start(p), + ROOT cp = u32[] collective-permute(p), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - ROOT done = u32[] collective-permute-done(start) } )"; @@ -181,5 +162,230 @@ TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) { EXPECT_FALSE(changed); } +TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { + const char* const kModuleStr = R"( + HloModule module + cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(2) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + recv-data = u32[2] collective-permute(send-data), channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + frontend_attributes={_xla_other_attribute="xyz"} + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + HloInstruction* recv = FindInstruction(module.get(), "recv"); + EXPECT_EQ(recv->channel_id().value(), 1); + EXPECT_THAT( + recv->ToString(), + HasSubstr( + "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); + EXPECT_THAT(recv->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); + HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); + EXPECT_THAT(recv_done->ToString(), + HasSubstr("_xla_send_recv_pipeline=\"0\"")); + + HloInstruction* send = FindInstruction(module.get(), "send"); + EXPECT_EQ(send->channel_id().value(), 1); + EXPECT_THAT( + send->ToString(), + HasSubstr( + "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); + EXPECT_THAT(send->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); + HloInstruction* send_done = FindInstruction(module.get(), "send-done"); + EXPECT_THAT(send_done->ToString(), + HasSubstr("_xla_send_recv_pipeline=\"0\"")); +} + +TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { + const char* const kModuleStr = R"( + HloModule module + cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(2) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + recv-data.0 = u32[2] collective-permute(send-data), channel_id=1, + source_target_pairs={{3,0}} + + recv-data.1 = u32[2] collective-permute(send-data), channel_id=2, + source_target_pairs={{0,1}, {1,2}, {2,3}} + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + HloInstruction* recv = FindInstruction(module.get(), "recv"); + EXPECT_EQ(recv->channel_id().value(), 1); + EXPECT_THAT(recv->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\"")); + EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); + HloInstruction* send = FindInstruction(module.get(), "send"); + EXPECT_THAT(send->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\"")); + EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); + + HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); + EXPECT_EQ(recv1->channel_id().value(), 2); + EXPECT_THAT( + recv1->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\"")); + EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); + HloInstruction* recv_done1 = FindInstruction(module.get(), "recv-done.1"); + EXPECT_THAT(recv_done1->ToString(), + HasSubstr("_xla_send_recv_pipeline=\"1\"")); + HloInstruction* send1 = FindInstruction(module.get(), "send.1"); + EXPECT_THAT( + send1->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\"")); + EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); + HloInstruction* send_done1 = FindInstruction(module.get(), "send-done.1"); + EXPECT_THAT(send_done1->ToString(), + HasSubstr("_xla_send_recv_pipeline=\"1\"")); +} + +TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { + const char* const kModuleStr = R"( + HloModule module + cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(2) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + recv-data.0 = u32[2] collective-permute(send-data), channel_id=1, + source_target_pairs={{1,0},{2,1},{3,2}} + + recv-data.1 = u32[2] collective-permute(send-data), channel_id=2, + source_target_pairs={{0,3}} + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=NE + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + HloInstruction* recv = FindInstruction(module.get(), "recv"); + EXPECT_EQ(recv->channel_id().value(), 1); + EXPECT_THAT( + recv->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\"")); + EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); + HloInstruction* send = FindInstruction(module.get(), "send"); + EXPECT_THAT( + send->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\"")); + EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); + + HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); + EXPECT_EQ(recv1->channel_id().value(), 2); + EXPECT_THAT(recv1->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\"")); + EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); + HloInstruction* send1 = FindInstruction(module.get(), "send.1"); + EXPECT_THAT(send1->ToString(), + HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\"")); + EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); +} + } // namespace } // namespace xla diff --git a/xla/service/collective_pipeliner.cc b/xla/service/collective_pipeliner.cc index 57a4c45d0110b..040c2042e7b73 100644 --- a/xla/service/collective_pipeliner.cc +++ b/xla/service/collective_pipeliner.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -66,6 +66,10 @@ namespace { using InstructionMap = absl::flat_hash_map; +// Record the loop invariant parameters used in a chain as well as their +// parameter indices. +using LoopVariantParameterInfo = + std::vector>; // Update all control dependencies for a cloned instruction to connect other // cloned instructions rather than originals. @@ -449,7 +453,8 @@ std::vector CollectDependenciesToPipeline( std::optional> CollectIndependentOperandChain( HloInstruction* instr, int64_t loop_iter, - const absl::flat_hash_set& loop_invariant_params) { + const absl::flat_hash_set& loop_invariant_params, + HloPredicate should_allow_loop_variant_parameter_in_chain) { std::vector chain; absl::flat_hash_set visited_set({instr}); std::vector> stack(1, {instr, 0}); @@ -475,7 +480,8 @@ std::optional> CollectIndependentOperandChain( if (curr_operand->opcode() == HloOpcode::kParameter) { continue; } - if (is_loop_variant_parameter_input(curr_operand)) { + if (is_loop_variant_parameter_input(curr_operand) && + !should_allow_loop_variant_parameter_in_chain(curr_operand)) { return std::nullopt; } if (visited_set.insert(curr_operand).second) { @@ -499,7 +505,10 @@ std::optional> CollectIndependentOperandChain( const bool is_scalar_shaped = ShapeUtil::IsEffectiveScalar(chain_instr->shape()); if (!all_users_in_chain) { - if (!loop_invariant_params.contains(chain_instr) && !is_scalar_shaped) { + if (!loop_invariant_params.contains(chain_instr) && !is_scalar_shaped && + (chain_instr->opcode() != HloOpcode::kGetTupleElement || + chain_instr->operand(0)->opcode() != HloOpcode::kParameter || + !should_allow_loop_variant_parameter_in_chain(chain_instr))) { return std::nullopt; } } @@ -516,12 +525,14 @@ std::optional> CollectIndependentOperandChain( std::optional> CollectChainsToPushBackwards( HloInstruction* instr, int64_t loop_iter, const HloComputation* while_body, int64_t level_to_operate_on, - const absl::flat_hash_set& loop_invariant_params) { - if (instr->user_count() != 1 || instr->HasControlDependencies()) { + const absl::flat_hash_set& loop_invariant_params, + HloPredicate should_allow_loop_variant_parameter_in_chain) { + if (instr->HasControlDependencies()) { return std::nullopt; } - return CollectIndependentOperandChain(instr, loop_iter, - loop_invariant_params); + return CollectIndependentOperandChain( + instr, loop_iter, loop_invariant_params, + should_allow_loop_variant_parameter_in_chain); } // Given a dynamic-update-slice find the output index of the loop we feed into. @@ -590,25 +601,39 @@ void UpdateInstructionChannelId(HloInstruction* cloned_instr, } } if (auto* channel_instr = DynCast(cloned_instr)) { + if (channel_instr->opcode() == HloOpcode::kSendDone || + channel_instr->opcode() == HloOpcode::kRecvDone) { + auto* operand = channel_instr->operand(0); + CHECK(operand->opcode() == HloOpcode::kSend || + operand->opcode() == HloOpcode::kRecv); + channel_instr->set_channel_id( + Cast(operand)->channel_id()); + return; + } if (channel_instr->channel_id()) { channel_instr->set_channel_id(next_channel_id++); } } } -// Clones a chain of instructions from a move_info for backward movement. +// Clones a chain of instructions from a move_info for backward movement, and +// returns the cloned of the last instruction in the chain. The last instruction +// in the chain is the collective instruction being pipelined and shouldn't be +// shared by multiple chains. As such, the last_cloned being returned shouldn't +// be nullptr. template -StatusOr CloneBackwardChain(Comp& target_computation, - const WhileMoveInfo& move_info, - InstructionMap& clone_map, - int64_t loop_iter_idx, - int64_t& next_channel_id) { +absl::StatusOr CloneBackwardChain( + Comp& target_computation, const WhileMoveInfo& move_info, + InstructionMap& clone_map, int64_t loop_iter_idx, int64_t& next_channel_id, + LoopVariantParameterInfo* loop_variant_parameter_info = nullptr) { std::vector to_clone(move_info.formatting_ops.begin(), move_info.formatting_ops.end()); to_clone.push_back(move_info.collective_to_move); HloInstruction* last_cloned = nullptr; for (auto* chain_op : to_clone) { - if (IsLoopIterator(chain_op, loop_iter_idx)) { + // Do not clone a loop iterator or an op that is already cloned. + if (IsLoopIterator(chain_op, loop_iter_idx) || + clone_map.contains(chain_op)) { continue; } auto new_operands = MapNewOperands(chain_op->operands(), clone_map); @@ -618,6 +643,13 @@ StatusOr CloneBackwardChain(Comp& target_computation, UpdateInstructionChannelId(cloned, next_channel_id); clone_map[chain_op] = cloned; last_cloned = cloned; + if (loop_variant_parameter_info != nullptr && + chain_op->opcode() == HloOpcode::kGetTupleElement && + chain_op->operand(0)->opcode() == HloOpcode::kParameter && + chain_op->tuple_index() != loop_iter_idx) { + loop_variant_parameter_info->push_back( + std::make_pair(chain_op->tuple_index(), cloned)); + } } CHECK_NE(last_cloned, nullptr); return last_cloned; @@ -655,7 +687,9 @@ class WhileLoopAnalysis { void CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, - HloPredicate should_process, HloPredicate acceptable_formatting); + HloPredicate should_process, HloPredicate acceptable_formatting, + HloPredicate should_allow_loop_variant_parameter_in_chain = + HloPredicateFalse); HloInstruction* while_loop_instruction() const { return while_; } private: @@ -763,7 +797,8 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() { void WhileLoopAnalysis::CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, - HloPredicate should_process, HloPredicate acceptable_formatting) { + HloPredicate should_process, HloPredicate acceptable_formatting, + HloPredicate should_allow_loop_variant_parameter_in_chain) { move_infos_.clear(); HloComputation* while_body = while_->while_body(); const HloInstruction* loop_parameter = @@ -970,7 +1005,8 @@ void WhileLoopAnalysis::CollectCollectivesToMove( CHECK_EQ(direction, CollectivePipeliner::PipeliningDirection::kBackward); auto chain_collected = CollectChainsToPushBackwards( instr, *loop_iteration_idx_, while_body, level_to_operate_on, - invariant_loop_parameters_); + invariant_loop_parameters_, + should_allow_loop_variant_parameter_in_chain); if (!chain_collected.has_value()) { VLOG(5) << "Skipping " << instr->name() << " because didn't find compatible slice of parameter"; @@ -1036,12 +1072,17 @@ bool IsLoopInvariant( absl::flat_hash_set visited; while (!stack.empty()) { auto& current = stack.back(); + invariant_cache[std::get<0>(current)] = true; if (std::get<0>(current)->HasSideEffect() || std::get<0>(current)->opcode() == HloOpcode::kParameter) { invariant_cache[std::get<0>(current)] = false; + stack.pop_back(); + continue; } if (std::get<0>(current)->operands().empty()) { invariant_cache[std::get<0>(current)] = true; + stack.pop_back(); + continue; } if (std::get<1>(current) > 0) { auto* current_operand = @@ -1347,7 +1388,7 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis, [&next_channel_id, insert_non_alias_custom_call, level_to_operate_on]( HloInstruction* stacked_data, const InstructionMap& pipelined_values_map, - const WhileMoveInfo& move_info) -> StatusOr { + const WhileMoveInfo& move_info) -> absl::StatusOr { HloInstruction* processed = stacked_data->parent()->AddInstruction( move_info.collective_to_move->CloneWithNewOperands( move_info.collective_to_move->shape(), {stacked_data})); @@ -1374,11 +1415,11 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis, return processed; }; auto extract_and_process_slice = - [&process_slice](HloInstruction* stacked_data, - HloInstruction* data_to_slice, - const WhileMoveInfo& move_info, - const InstructionMap& pipelined_values_map, - HloInstruction* dus_index) -> StatusOr { + [&process_slice]( + HloInstruction* stacked_data, HloInstruction* data_to_slice, + const WhileMoveInfo& move_info, + const InstructionMap& pipelined_values_map, + HloInstruction* dus_index) -> absl::StatusOr { HloComputation* computation = stacked_data->parent(); const Shape& slice_target_shape = move_info.collective_to_move->operand(0)->shape(); @@ -1608,6 +1649,9 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, auto pipelined_instrs = CollectDependenciesToPipeline( move_info.collective_to_move, absl::MakeSpan(move_info.formatting_ops)); for (auto* pipelined : pipelined_instrs) { + if (pipelined->opcode() == HloOpcode::kConstant) { + continue; + } const bool is_loop_invariant = IsLoopInvariant(pipelined, invariant_cache); is_output_instruction[pipelined] = new_init_operands.size(); @@ -1625,8 +1669,26 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, new_init_operands.push_back(CreateZero(loop_computation, expanded_shape, expanded_shape.element_type())); indices_to_insert.insert(new_root_operands.size()); + Shape extra_trivial_dim_shape = + ShapeUtil::PrependMajorDimension(1, pipelined->shape()); HloInstruction* reshaped = body_computation->AddInstruction( - HloInstruction::CreateReshape(expanded_shape, pipelined)); + HloInstruction::CreateReshape(extra_trivial_dim_shape, pipelined)); + std::vector indices( + expanded_shape.dimensions_size(), + CreateZero(body_computation, + move_info.dynamic_update_slice->index_shapes()[0], + move_info.dynamic_update_slice->index_shapes()[0] + .element_type())); + indices[0] = move_info.dynamic_update_slice->index_operands()[0]; + HloInstruction* input = + body_computation->AddInstruction(HloInstruction::CreateCustomCall( + expanded_shape, + {body_computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0((int32_t)new_root_operands.size())))}, + "PlaceHolder")); + reshaped = body_computation->AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(expanded_shape, input, + reshaped, indices)); new_root_operands.push_back(reshaped); } } @@ -1729,7 +1791,7 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); TF_RETURN_IF_ERROR( old_operand_param->parent()->RemoveInstruction(old_operand_param)); - if (insert_non_alias_custom_call) { + if (insert_non_alias_custom_call && original_to_move_indices.contains(i)) { auto* old_operand = output->mutable_operand(1); auto* custom_call = cloned_body->AddInstruction(HloInstruction::CreateCustomCall( @@ -1754,6 +1816,9 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, auto pipelined_instrs = CollectDependenciesToPipeline( to_move.collective_to_move, absl::MakeSpan(to_move.formatting_ops)); for (auto* original_pipelined : pipelined_instrs) { + if (original_pipelined->opcode() == HloOpcode::kConstant) { + continue; + } const bool is_loop_invariant = IsLoopInvariant(original_pipelined, invariant_cache); CHECK(is_output_instruction.contains(original_pipelined)); @@ -1782,28 +1847,44 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, {to_sink})); UpdateInstructionChannelId(pipelined_instr_cloned, next_channel_id); pipelined_map[to_move.collective_to_move] = pipelined_instr_cloned; - auto collect_operands = [&pipelined_map](HloInstruction* instr) { + absl::flat_hash_set to_add_batch_set; + auto collect_operands = [&pipelined_map, &to_add_batch_set, + loop_computation, + &to_move](HloInstruction* instr) { std::vector operands; for (auto* operand : instr->mutable_operands()) { + if (operand->opcode() == HloOpcode::kConstant) { + HloInstruction* cloned_constant = loop_computation->AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {})); + if (!to_add_batch_set.contains(instr)) { + operands.push_back(cloned_constant); + continue; + } + Shape full_shape = + ComputeFullOutputShape(to_move, cloned_constant->shape()); + absl::InlinedVector operand_dims; + operand_dims.resize(cloned_constant->shape().dimensions_size()); + absl::c_iota(operand_dims, 1); + HloInstruction* broadcasted = + loop_computation->AddInstruction(HloInstruction::CreateBroadcast( + full_shape, cloned_constant, operand_dims)); + operands.push_back(broadcasted); + continue; + } auto it = pipelined_map.find(operand); CHECK(it != pipelined_map.end()); operands.push_back(it->second); } return operands; }; - absl::flat_hash_set to_add_batch_set; absl::flat_hash_set formatting_ops_set( to_move.formatting_ops.begin(), to_move.formatting_ops.end()); std::vector stack(1, to_move.collective_to_move); - while (!stack.empty()) { - auto* current = stack.back(); - stack.pop_back(); - to_add_batch_set.insert(current); - for (auto* u : current->users()) { - if (formatting_ops_set.contains(u)) { - stack.push_back(u); - } + for (auto* current : to_move.formatting_ops) { + if (IsLoopInvariant(current, invariant_cache)) { + continue; } + to_add_batch_set.insert(current); } // We are adding a batch dimension to the formatting ops, so we need to // specially rewrite each instruction potentially if adding dimensions has @@ -1815,12 +1896,12 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloInstruction* cloned_not_to_batch = loop_computation->AddInstruction( formatting_op->CloneWithNewOperands( formatting_op->shape(), collect_operands(formatting_op))); + UpdateInstructionChannelId(cloned_not_to_batch, next_channel_id); pipelined_map[formatting_op] = cloned_not_to_batch; continue; } if (formatting_op->IsElementwise() || formatting_op->opcode() == HloOpcode::kReshape || - formatting_op->opcode() == HloOpcode::kReduce || formatting_op->opcode() == HloOpcode::kAllReduce || formatting_op->opcode() == HloOpcode::kConvert || formatting_op->opcode() == HloOpcode::kCollectivePermute) { @@ -1831,6 +1912,26 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, pipelined_map[formatting_op] = cloned_elementwise; continue; } + if (formatting_op->opcode() == HloOpcode::kReduce) { + auto operands = collect_operands(formatting_op); + std::vector dimensions(formatting_op->dimensions().begin(), + formatting_op->dimensions().end()); + for (auto& dim : dimensions) { + ++dim; + } + // Look through broadcast for reduce init value. + if (operands[1]->opcode() == HloOpcode::kBroadcast) { + CHECK(operands[1]->operand(0)->opcode() == HloOpcode::kConstant); + operands[1] = operands[1]->mutable_operand(0); + } + HloInstruction* expanded_reduce = + loop_computation->AddInstruction(HloInstruction::CreateReduce( + ComputeFullOutputShape(to_move, formatting_op->shape()), + operands[0], operands[1], dimensions, + formatting_op->to_apply())); + pipelined_map[formatting_op] = expanded_reduce; + continue; + } if (formatting_op->opcode() == HloOpcode::kBroadcast) { CHECK(formatting_op->dimensions().empty()); auto operands = collect_operands(formatting_op); @@ -1974,13 +2075,13 @@ Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, // x_ag = p0_ag_next // } // x_last = computation(p0_ag_next) -static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, - bool insert_non_alias_custom_call, - int64_t level_to_operate_on, - bool process_different_sized_ops, - HloPredicate should_process, - HloPredicate acceptable_formatting, - int64_t& next_channel_id) { +static Status TransformLoopBackward( + const WhileLoopAnalysis& loop_analysis, bool insert_non_alias_custom_call, + int64_t level_to_operate_on, bool process_different_sized_ops, + HloPredicate should_process, HloPredicate acceptable_formatting, + CollectivePipeliner::HloPostprocessor postprocess_peeled, + CollectivePipeliner::HloPostprocessor postprocess_rotated, + int64_t& next_channel_id) { // Defining some maps/sets to keep track of instructions duplicated. absl::flat_hash_map while_body_to_peeled; absl::flat_hash_map collective_to_move_map; @@ -2076,6 +2177,10 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, loop_analysis.GetMoveInfos()[i], chain_clone_map, *loop_analysis.GetLoopIterationIdx(), next_channel_id)); + + if (postprocess_peeled.has_value()) { + TF_RETURN_IF_ERROR(postprocess_peeled.value()(new_init_operands[idx])); + } } ConstantValue next_loop_iteration = loop_analysis.GetLoopStart()->add(*loop_analysis.GetLoopIncrement()); @@ -2103,6 +2208,8 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, collective_to_move_clone_map[u] = loop_iterator_for_pipelined_instrs; } } + // Record the loop variant parameters used in the backward chain. + LoopVariantParameterInfo loop_variant_parameter_info; // Clone loop in the body of the new loop. We change some things like // input/output shapes and how we connect loop iterator to the original // chains that we are pipelining. @@ -2116,10 +2223,15 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, if (it != collective_to_move_map.end()) { TF_ASSIGN_OR_RETURN( cloned_instr, - CloneBackwardChain( - body_builder, loop_analysis.GetMoveInfos()[it->second], - collective_to_move_clone_map, - *loop_analysis.GetLoopIterationIdx(), next_channel_id)); + CloneBackwardChain(body_builder, + loop_analysis.GetMoveInfos()[it->second], + collective_to_move_clone_map, + *loop_analysis.GetLoopIterationIdx(), + next_channel_id, &loop_variant_parameter_info)); + + if (postprocess_rotated.has_value()) { + TF_RETURN_IF_ERROR(postprocess_rotated.value()(cloned_instr)); + } } else { auto new_operands = MapNewOperands(instr->operands(), while_body_replacement_map); @@ -2140,6 +2252,18 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, } while_body_replacement_map[instr] = cloned_instr; } + // For each loop variant parameter used in the backward chain, we temporarily + // use a newly added loop parameter in the cloned loop. We now need to replace + // this temporary value with an element in the loop output tuple. The index + // of the element in the tuple is the same as the index of the loop variant + // parameter before we pipeline the loop. + for (const auto& [idx, value] : loop_variant_parameter_info) { + auto it = while_body_replacement_map.find(new_root_operands[idx]); + CHECK(it != while_body_replacement_map.end()) + << new_root_operands[idx]->ToString() << " not present in map"; + TF_RETURN_IF_ERROR(value->ReplaceAllUsesWith(it->second)); + } + new_root_operands.back() = body_builder.AddInstruction(HloInstruction::CreateBinary( loop_index_shape, HloOpcode::kAdd, @@ -2255,7 +2379,7 @@ static Status TransformLoopBackward(const WhileLoopAnalysis& loop_analysis, return OkStatus(); } -StatusOr CollectivePipeliner::Run( +absl::StatusOr CollectivePipeliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { CHECK(config_.acceptable_formatting); @@ -2289,7 +2413,8 @@ StatusOr CollectivePipeliner::Run( << loop_analysis.GetLoopIterationCount()->ToString(); loop_analysis.CollectCollectivesToMove( config_.level_to_operate_on, config_.pipelining_direction, - config_.should_process, config_.acceptable_formatting); + config_.should_process, config_.acceptable_formatting, + config_.should_allow_loop_variant_parameter_in_chain); if (loop_analysis.GetMoveInfos().empty()) { continue; } @@ -2322,7 +2447,8 @@ StatusOr CollectivePipeliner::Run( TF_RETURN_IF_ERROR(TransformLoopBackward( loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.process_different_sized_ops, config_.should_process, - config_.acceptable_formatting, next_channel_id)); + config_.acceptable_formatting, config_.postprocess_backward_peeled_op, + config_.postprocess_backward_rorated_op, next_channel_id)); } ++transformed_loops; changed = true; diff --git a/xla/service/collective_pipeliner.h b/xla/service/collective_pipeliner.h index 8553dd9652193..384bafe4df26b 100644 --- a/xla/service/collective_pipeliner.h +++ b/xla/service/collective_pipeliner.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,6 +61,12 @@ class CollectivePipeliner : public HloModulePass { kForward, kForwardSink, }; + + // Postprocessing cloned collective instructions, such as for modifying loop + // iteration related frontend attributes to reflect loop pipelining. + using HloPostprocessor = + std::optional>; + struct Config { int64_t level_to_operate_on = 0; // Maximum number of HLOs to pipeline per loop. (Meant to help controlling @@ -82,6 +88,13 @@ class CollectivePipeliner : public HloModulePass { // buffer we are storing the value in in the output loop for forward // pipelining. This function allows to not do it for certain ops. HloPredicate reuse_pipelined_op_buffer; + // Determine whether a loop variant parameter should be allowed in + // pipelining chains. This is currently only used to support kBackward + // pipelinining. + HloPredicate should_allow_loop_variant_parameter_in_chain = + HloPredicateFalse; + HloPostprocessor postprocess_backward_peeled_op = std::nullopt; + HloPostprocessor postprocess_backward_rorated_op = std::nullopt; }; static const char* const kInsertedByPreviousStep; static const char* const kSunkByPreviousStep; @@ -113,7 +126,7 @@ class CollectivePipeliner : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/collective_pipeliner_test.cc b/xla/service/collective_pipeliner_test.cc index efdd4f95bf69d..f28d00b4d2fbd 100644 --- a/xla/service/collective_pipeliner_test.cc +++ b/xla/service/collective_pipeliner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,12 @@ limitations under the License. #include "xla/service/collective_pipeliner.h" #include +#include #include #include #include +#include "absl/log/check.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -29,6 +31,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -53,16 +56,20 @@ class CollectivePipelinerTest : public HloTestBase { HloModuleConfig config_; }; -StatusOr RunOptimizer( +absl::StatusOr RunOptimizer( HloModule* module, bool last_run, int64_t level_to_operate_on = 0, bool pipeline_use_tree = false, bool process_different_sized_ops = true, CollectivePipeliner::PipeliningDirection direction = CollectivePipeliner::PipeliningDirection::kForward, HloPredicate should_process = HloPredicateIsOp, - HloPredicate acceptable_formatting = - [](const HloInstruction*) { return true; }, - HloPredicate reuse_pipelined_op_buffer = - [](const HloInstruction* i) { return true; }) { + HloPredicate acceptable_formatting = HloPredicateTrue, + HloPredicate reuse_pipelined_op_buffer = HloPredicateTrue, + HloPredicate should_allow_loop_variant_parameter_in_chain = + HloPredicateFalse, + CollectivePipeliner::HloPostprocessor postprocess_backward_peeled = + std::nullopt, + CollectivePipeliner::HloPostprocessor postprocess_backward_rotated = + std::nullopt) { CollectivePipeliner::Config config = { /*level_to_operate_on=*/level_to_operate_on, /*max_pipelining_per_loop=*/INT64_MAX, @@ -74,7 +81,8 @@ StatusOr RunOptimizer( /*should_process=*/should_process, /*acceptable_formatting=*/acceptable_formatting, /*reuse_pipelined_op_buffer=*/reuse_pipelined_op_buffer, - }; + should_allow_loop_variant_parameter_in_chain, postprocess_backward_peeled, + postprocess_backward_rotated}; HloPassPipeline pass("optimizer"); pass.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); @@ -163,6 +171,71 @@ ENTRY entry { EXPECT_EQ(get_tuple_index->tuple_index(), 3); } +TEST_F(CollectivePipelinerTest, UpdateSendRecvChannelIdForHostTransfers) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + after-all = after-all() + send.88 = (s32[], u32[], token[]) send( + add.232, after-all), channel_id=2, is_host_transfer=true + send-done.88 = token[] send-done(send.88), channel_id=2, is_host_transfer=true + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true).value()); + XLA_VLOG_LINES(1, module->ToString()); + auto* entry_comp = module->entry_computation(); + auto* unrolled_send_done = entry_comp->GetInstructionWithName("send-done.0"); + ASSERT_THAT(unrolled_send_done, ::testing::NotNull()); + auto* unrolled_send = unrolled_send_done->operand(0); + auto channel_id = [](const HloInstruction* instr) { + return DynCast(instr)->channel_id(); + }; + EXPECT_EQ(channel_id(unrolled_send), channel_id(unrolled_send_done)); +} + TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNoReuse) { constexpr absl::string_view hlo_string = R"( HloModule module @@ -1678,12 +1751,16 @@ TEST_F(CollectivePipelinerTest, TransformRecvSendBackwards) { after-all = token[] after-all() recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0, 1}, {1, 2}, {2, 3}, {3, 4}}" + _xla_send_recv_source_target_pairs="{{0, 1}, {1, 2}, {2, 3}, {3, 4}}", + _xla_send_recv_pipeline="0" } send = (f32[1, 1024, 1024], u32[], token[]) send(p, after-all), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0, 1}, {1, 2}, {2, 3}, {3, 4}}" + _xla_send_recv_source_target_pairs="{{0, 1}, {1, 2}, {2, 3}, {3, 4}}", + _xla_send_recv_pipeline="0" + } + recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1, frontend_attributes={ + _xla_send_recv_pipeline="0" } - recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1 recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 replica = u32[] replica-id() @@ -1696,7 +1773,9 @@ TEST_F(CollectivePipelinerTest, TransformRecvSendBackwards) { d = f32[1, 1024, 1024] tan(c) s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} - send-done = token[] send-done(send), channel_id=1 + send-done = token[] send-done(send), channel_id=1, frontend_attributes={ + _xla_send_recv_pipeline="0" + } ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, s) } @@ -1747,6 +1826,140 @@ TEST_F(CollectivePipelinerTest, TransformRecvSendBackwards) { EXPECT_EQ(recv1->channel_id(), send1->channel_id()); } +TEST_F(CollectivePipelinerTest, + TransformRecvSendBackwardsWithLoopVariantParameter) { + constexpr absl::string_view hlo_string = R"( + HloModule module + cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(2) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + after-all.0 = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_other_attr="0" + } + after-all.0.s = token[] after-all() + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.s), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_other_attr="0" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv-data = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + auto should_pipeline = [](const HloInstruction* instr) { + if (!HloPredicateIsOp(instr) && + !HloPredicateIsOp(instr)) + return false; + const HloSendRecvInstruction* send_recv = + dynamic_cast(instr); + // Check that the Send or Recv is used for non-trivial computation, which + // also help avoid repeatedly pipelining a loop. + return (send_recv->user_count() == 1 && send_recv->parent() != nullptr && + send_recv->users()[0] != send_recv->parent()->root_instruction()); + }; + auto should_allow_loop_variant_parameter = [](const HloInstruction* instr) { + CHECK(instr->opcode() == HloOpcode::kGetTupleElement && + instr->operand(0)->opcode() == HloOpcode::kParameter); + return true; + }; + const char* kAttr = "_xla_other_attr"; + // Mutate an existing attribute. + auto postprocess_peeled = [&](HloInstruction* instr) { + xla::FrontendAttributes attributes = instr->frontend_attributes(); + (*attributes.mutable_map())[kAttr] = "1"; + instr->set_frontend_attributes(attributes); + return OkStatus(); + }; + auto postprocess_rotated = [&](HloInstruction* instr) { + xla::FrontendAttributes attributes = instr->frontend_attributes(); + (*attributes.mutable_map())[kAttr] = "2"; + instr->set_frontend_attributes(attributes); + return OkStatus(); + }; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/false, + CollectivePipeliner::PipeliningDirection::kBackward, + should_pipeline, + /*acceptable_formatting=*/HloPredicateTrue, + /*reuse_pipelined_op_buffer=*/HloPredicateTrue, + should_allow_loop_variant_parameter, + postprocess_peeled, postprocess_rotated) + .value()); + XLA_VLOG_LINES(10, module->ToString()); + auto while_op = FindInstruction(module.get(), "while"); + EXPECT_EQ(while_op->opcode(), HloOpcode::kWhile); + EXPECT_EQ(while_op->shape().tuple_shapes().size(), 5); + auto recv1 = + DynCast(FindInstruction(module.get(), "recv.1")); + EXPECT_NE(recv1, nullptr); + auto recv2 = + DynCast(FindInstruction(module.get(), "recv.2")); + EXPECT_NE(recv2, nullptr); + EXPECT_EQ(recv1->channel_id(), recv2->channel_id()); + + auto send1 = + DynCast(FindInstruction(module.get(), "send.1")); + EXPECT_NE(send1, nullptr); + auto send2 = + DynCast(FindInstruction(module.get(), "send.2")); + EXPECT_NE(send2, nullptr); + EXPECT_EQ(send1->channel_id(), send2->channel_id()); + + EXPECT_EQ(recv1->channel_id(), send1->channel_id()); + + const char* kSourceTarget = "_xla_send_recv_source_target_pairs=\"{{3,0}}\""; + const char* kPeeledAttr = "_xla_other_attr=\"1\""; + const char* kRotatedAttr = "_xla_other_attr=\"2\""; + EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kSourceTarget)); + EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kSourceTarget)); + EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kSourceTarget)); + EXPECT_THAT(recv2->ToString(), ::testing::HasSubstr(kSourceTarget)); + EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kPeeledAttr)); + EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kPeeledAttr)); + EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kRotatedAttr)); + EXPECT_THAT(recv2->ToString(), ::testing::HasSubstr(kRotatedAttr)); +} + TEST_F(CollectivePipelinerTest, MultiUsesElementwiseMerge) { constexpr absl::string_view hlo_string = R"( HloModule module diff --git a/xla/service/collective_transformation_reorderer.cc b/xla/service/collective_transformation_reorderer.cc index 010b730dbe40d..ba04054de22fb 100644 --- a/xla/service/collective_transformation_reorderer.cc +++ b/xla/service/collective_transformation_reorderer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,19 @@ limitations under the License. #include "xla/service/collective_transformation_reorderer.h" +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_dce.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -131,9 +136,30 @@ GetAllGatherTransformations(HloInstruction* all_gather) { } return transformations; } + +// Find a list of reshapes feeding the all-reduce that could be moved to after +// the all-reduce. +std::vector GetAllReduceTransformations( + HloInstruction* all_reduce) { + HloAllReduceInstruction* all_reduce_instruction = + DynCast(all_reduce); + CHECK_NE(all_reduce_instruction, nullptr); + if (all_reduce_instruction->constrain_layout()) { + return {}; + } + std::vector transformation_hlos; + HloInstruction* transformation_hlo = all_reduce->mutable_operand(0); + while (transformation_hlo->opcode() == HloOpcode::kReshape && + transformation_hlo->user_count() == 1) { + transformation_hlos.push_back(transformation_hlo); + transformation_hlo = transformation_hlo->mutable_operand(0); + } + return transformation_hlos; +} } // namespace -StatusOr CollectiveTransformationReorder::ReorderAllGatherTransformations( +absl::StatusOr +CollectiveTransformationReorder::ReorderAllGatherTransformations( HloModule* module, const absl::flat_hash_set& execution_threads) { // First, find all all-gathers and reshapes that are eligible for this @@ -205,16 +231,76 @@ StatusOr CollectiveTransformationReorder::ReorderAllGatherTransformations( computation->set_root_instruction(new_all_gather); } } - // Remove the original all-gather and reshapes. - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); return true; } -StatusOr CollectiveTransformationReorder::Run( +absl::StatusOr +CollectiveTransformationReorder::ReorderAllReduceTransformations( HloModule* module, const absl::flat_hash_set& execution_threads) { - return ReorderAllGatherTransformations(module, execution_threads); + // First, find all reshapes and all-reduces that are eligible for this + // transformation. + absl::flat_hash_map> + all_reduce_to_transformations; + for (HloComputation* computation : + module->MakeComputationPostOrder(execution_threads)) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kAllReduce) { + if (instruction->user_count() != 1 || + computation->root_instruction() == instruction) { + continue; + } + std::vector reshapes = + GetAllReduceTransformations(instruction); + if (reshapes.empty()) { + continue; + } + all_reduce_to_transformations[instruction] = std::move(reshapes); + } + } + } + if (all_reduce_to_transformations.empty()) { + return false; + } + for (auto& [inst, reshapes] : all_reduce_to_transformations) { + HloComputation* computation = inst->parent(); + HloAllReduceInstruction* all_reduce = + DynCast(inst); + CHECK(!reshapes.empty()); + HloInstruction* cur_operand = reshapes.back()->mutable_operand(0); + HloInstruction* new_all_reduce = + computation->AddInstruction(HloInstruction::CreateAllReduce( + cur_operand->shape(), {cur_operand}, all_reduce->to_apply(), + all_reduce->replica_groups(), all_reduce->constrain_layout(), + all_reduce->channel_id(), all_reduce->use_global_device_ids())); + + // For each eligible reshape on the old all-reduce's operand, we reshape the + // new all-reduce result instead. + cur_operand = new_all_reduce; + for (int64_t i = reshapes.size() - 1; i >= 0; --i) { + cur_operand = computation->AddInstruction( + HloInstruction::CreateReshape(reshapes[i]->shape(), cur_operand)); + } + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(all_reduce, cur_operand)); + } + return true; +} + +absl::StatusOr CollectiveTransformationReorder::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN(bool ag_changed, ReorderAllGatherTransformations( + module, execution_threads)); + TF_ASSIGN_OR_RETURN(bool ar_changed, ReorderAllReduceTransformations( + module, execution_threads)); + if (ag_changed || ar_changed) { + // Remove the original all-gathers/all-reduces and reshapes. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); + } + return ag_changed || ar_changed; } } // namespace xla diff --git a/xla/service/collective_transformation_reorderer.h b/xla/service/collective_transformation_reorderer.h index af5d16e3e1b44..fe730bd5bfe50 100644 --- a/xla/service/collective_transformation_reorderer.h +++ b/xla/service/collective_transformation_reorderer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ #define XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -24,11 +27,16 @@ limitations under the License. namespace xla { -// Transforms all-gather + reshape into reshape + all-gather when the reshape -// only changes the shape of the all-gather shards, i.e., it does not reshape -// across the all-gather dimension. +// Transforms +// -- all-gather + reshape into reshape + all-gather and +// -- reshape + all-reduce into all-reduce + reshape. +// Both transformations require that there are no other users affected, i.e., +// reshape user count should be 1. +// all-gather transformation requires the reshape to only change the shape of +// the all-gather shards, i.e., not reshaping across the all-gather dimension. +// all-reduce transformation requires all-reduce to be not layout constrained. -// Generally speaking, +// all-gather + reshape example: // input = [C_0, C_1, ..., C_i, ..., C_{n-1}, C_n] ... // all-gather = [C_0, C_1, ..., P*C_i, ... C_{n-1}, C_n] all-gather(input) @@ -52,12 +60,16 @@ class CollectiveTransformationReorder : public HloModulePass { "collective-transformation-reorderer"; return kName; } - StatusOr Run( + using HloPassInterface::Run; + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr ReorderAllGatherTransformations( + absl::StatusOr ReorderAllGatherTransformations( + HloModule* module, + const absl::flat_hash_set& execution_threads); + absl::StatusOr ReorderAllReduceTransformations( HloModule* module, const absl::flat_hash_set& execution_threads); }; diff --git a/xla/service/collective_transformation_reorderer_test.cc b/xla/service/collective_transformation_reorderer_test.cc index 00b72ccd2d853..3721406e64901 100644 --- a/xla/service/collective_transformation_reorderer_test.cc +++ b/xla/service/collective_transformation_reorderer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,19 @@ limitations under the License. #include "xla/service/collective_transformation_reorderer.h" +#include + #include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -26,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; class CollectiveTransformationReordererTest : public HloTestBase { public: - StatusOr RunCollectiveTransformationReorderer(HloModule* module) { + absl::StatusOr RunCollectiveTransformationReorderer(HloModule* module) { CollectiveTransformationReorder reorderer; return reorderer.Run(module, {}); } @@ -144,6 +154,151 @@ TEST_F(CollectiveTransformationReordererTest, EXPECT_FALSE(changed); } +TEST_F(CollectiveTransformationReordererTest, AllReduceSingleReshape) { + absl::string_view hlo_string = R"( + HloModule module + + add { + a = bf16[] parameter(0) + b = bf16[] parameter(1) + ROOT s = bf16[] add(a, b) + } + + ENTRY entry { + param = bf16[16384,6144] parameter(0) + reshape = bf16[1,16384,6144] reshape(param) + all-reduce = bf16[1,16384,6144] all-reduce(reshape), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=add + constant = s32[] constant(0) + ROOT dynamic-slice = bf16[1,16384,384] dynamic-slice(all-reduce, constant, constant, constant), dynamic_slice_sizes={1,16384,384} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunCollectiveTransformationReorderer(module.get())); + EXPECT_TRUE(changed); + TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Reshape(op::AllReduce(op::Parameter())), + op::Constant(), op::Constant(), op::Constant())); +} + +TEST_F(CollectiveTransformationReordererTest, AllReduceTwoReshapes) { + absl::string_view hlo_string = R"( + HloModule module + + add { + a = bf16[] parameter(0) + b = bf16[] parameter(1) + ROOT s = bf16[] add(a, b) + } + + ENTRY entry { + param = bf16[16384,3072,2] parameter(0) + reshape.1 = bf16[16384,6144] reshape(param) + reshape.2 = bf16[1,16384,6144] reshape(reshape.1) + all-reduce = bf16[1,16384,6144] all-reduce(reshape.2), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=add + constant = s32[] constant(0) + ROOT dynamic-slice = bf16[1,16384,384] dynamic-slice(all-reduce, constant, constant, constant), dynamic_slice_sizes={1,16384,384} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunCollectiveTransformationReorderer(module.get())); + EXPECT_TRUE(changed); + TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Reshape(op::Reshape(op::AllReduce(op::Parameter()))), + op::Constant(), op::Constant(), op::Constant())); +} + +TEST_F(CollectiveTransformationReordererTest, AllReduceReshapeWithTwoUsers) { + absl::string_view hlo_string = R"( + HloModule module + + add { + a = bf16[] parameter(0) + b = bf16[] parameter(1) + ROOT s = bf16[] add(a, b) + } + + ENTRY entry { + param = bf16[16384,6144] parameter(0) + reshape = bf16[1,16384,6144] reshape(param) + all-reduce = bf16[1,16384,6144] all-reduce(reshape), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=add + constant = s32[] constant(0) + dynamic-slice = bf16[1,16384,384] dynamic-slice(all-reduce, constant, constant, constant), dynamic_slice_sizes={1,16384,384} + copy = bf16[1,16384,6144] copy(reshape) + ROOT tuple = (bf16[1,16384,6144], bf16[1,16384,384]) tuple(copy, dynamic-slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunCollectiveTransformationReorderer(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectiveTransformationReordererTest, AllReduceWithTwoUsersReshape) { + absl::string_view hlo_string = R"( + HloModule module + + add { + a = bf16[] parameter(0) + b = bf16[] parameter(1) + ROOT s = bf16[] add(a, b) + } + + ENTRY entry { + param = bf16[16384,6144] parameter(0) + reshape = bf16[1,16384,6144] reshape(param) + all-reduce = bf16[1,16384,6144] all-reduce(reshape), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=add + constant = s32[] constant(0) + dynamic-slice = bf16[1,16384,384] dynamic-slice(all-reduce, constant, constant, constant), dynamic_slice_sizes={1,16384,384} + copy = bf16[1,16384,6144] copy(all-reduce) + ROOT tuple = (bf16[1,16384,6144], bf16[1,16384,384]) tuple(copy, dynamic-slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunCollectiveTransformationReorderer(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectiveTransformationReordererTest, AllReduceConstrainLayout) { + absl::string_view hlo_string = R"( + HloModule module + + add { + a = bf16[] parameter(0) + b = bf16[] parameter(1) + ROOT s = bf16[] add(a, b) + } + + ENTRY entry { + param = bf16[16384,6144] parameter(0) + reshape = bf16[1,16384,6144] reshape(param) + all-reduce = bf16[1,16384,6144] all-reduce(reshape), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, constrain_layout=true, to_apply=add + constant = s32[] constant(0) + ROOT dynamic-slice = bf16[1,16384,384] dynamic-slice(all-reduce, constant, constant, constant), dynamic_slice_sizes={1,16384,384} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunCollectiveTransformationReorderer(module.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/xla/service/collectives_schedule_linearizer.cc b/xla/service/collectives_schedule_linearizer.cc index 12e26ac38d9cb..a367831a1d0fe 100644 --- a/xla/service/collectives_schedule_linearizer.cc +++ b/xla/service/collectives_schedule_linearizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ limitations under the License. namespace xla { // TODO(b/181653482): Fix for interprocedural collectives as well. -StatusOr CollectivesScheduleLinearizer::Run( +absl::StatusOr CollectivesScheduleLinearizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (is_enabled_ && !is_enabled_(module)) { diff --git a/xla/service/collectives_schedule_linearizer.h b/xla/service/collectives_schedule_linearizer.h index 33c5866145c88..ad722dc395887 100644 --- a/xla/service/collectives_schedule_linearizer.h +++ b/xla/service/collectives_schedule_linearizer.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ class CollectivesScheduleLinearizer : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/collectives_schedule_linearizer_test.cc b/xla/service/collectives_schedule_linearizer_test.cc index 42bf2e88a8b67..eeb9b8b936e55 100644 --- a/xla/service/collectives_schedule_linearizer_test.cc +++ b/xla/service/collectives_schedule_linearizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/comparison_expander.cc b/xla/service/comparison_expander.cc index c08bc76fe35f1..4a7ff3d5a4462 100644 --- a/xla/service/comparison_expander.cc +++ b/xla/service/comparison_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,44 +15,59 @@ limitations under the License. #include "xla/service/comparison_expander.h" -#include "xla/client/lib/comparators.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include +#include + +#include "absl/algorithm/container.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_creation_utils.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { HloInstruction* BitcastConvertFloatingPointToIntegral( - HloComputation* computation, HloInstruction* value, - const Shape& signed_shape, const Shape& unsigned_shape, - HloInstruction* zero, HloInstruction* max_value) { + HloComputation* computation, HloInstruction* value, HloInstruction* zero, + HloInstruction* min_value, HloInstruction* max_value) { // Switch from a floating point value to a integer value in such a way that // when using the integer value to compare, we get the same result for normal // values, and -Nan is treated as the smallest value, and Nan is treated as // the largest value. // If f is a float, and // x = bit_cast(f); - // y = x < 0 ? numeric_limits::max() - x : x; + // y = x < 0 ? numeric_limits::max() ^ x : x; // then y is ordered as an int32_t such that finite values have the obvious // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning // and end of the ordering. - // Note that in order to avoid -x to overflow, we calculate - // numeric_limits::max() - x as unsigned, and then convert back to - // signed. + auto signed_shape = max_value->shape(); auto signed_value = computation->AddInstruction( HloInstruction::CreateBitcastConvert(signed_shape, value)); - auto unsigned_value = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(unsigned_shape, value)); - auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( - unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value)); - flipped_value = computation->AddInstruction( - HloInstruction::CreateBitcastConvert(signed_shape, flipped_value)); - auto compare_shape = signed_shape; - compare_shape.set_element_type(PRED); + auto compare_shape = ShapeUtil::ChangeElementType(signed_shape, PRED); + HloInstruction* flipped_value; + if (primitive_util::HasNegativeZero(value->shape().element_type())) { + flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + signed_shape, HloOpcode::kXor, max_value, signed_value)); + } else { + // There is no -0 so min_denorm() must take its place, this is the same as + // adding one to flipped_value. + flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + signed_shape, HloOpcode::kSubtract, min_value, signed_value)); + + // NaN is the smallest value as it is negative. + auto nan_bit_pattern = min_value; + auto is_nan = computation->AddInstruction(HloInstruction::CreateCompare( + compare_shape, signed_value, nan_bit_pattern, + ComparisonDirection::kEq)); + flipped_value = computation->AddInstruction(HloInstruction::CreateTernary( + signed_shape, HloOpcode::kSelect, is_nan, min_value, flipped_value)); + } auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare( compare_shape, signed_value, zero, ComparisonDirection::kLt)); return computation->AddInstruction( @@ -63,9 +78,9 @@ HloInstruction* BitcastConvertFloatingPointToIntegral( bool ComparisonExpander::InstructionMatchesPattern( HloInstruction* instruction) { if (HloCompareInstruction* compare = - dynamic_cast(instruction)) { + DynCast(instruction)) { HloInstruction* lhs = instruction->operands()[0]; - if (compare->type() == Comparison::Type::kFloatTotalOrder && + if (compare->order() == Comparison::Order::kTotal && primitive_util::IsFloatingPointType(lhs->shape().element_type())) { return true; } @@ -73,60 +88,64 @@ bool ComparisonExpander::InstructionMatchesPattern( return false; } -StatusOr ComparisonExpander::ExpandInstruction( +absl::StatusOr ComparisonExpander::ExpandInstruction( HloInstruction* instruction) { - CHECK(instruction->opcode() == HloOpcode::kCompare); + CHECK_EQ(instruction->opcode(), HloOpcode::kCompare); HloCompareInstruction* compare = static_cast(instruction); - CHECK(compare->type() == Comparison::Type::kFloatTotalOrder); + CHECK(compare->order() == Comparison::Order::kTotal) + << ComparisonOrderToString(compare->order()); HloComputation* computation = instruction->parent(); HloInstruction* lhs = instruction->operands()[0]; HloInstruction* rhs = instruction->operands()[1]; - Shape compare_shape = lhs->shape(); - PrimitiveType compare_type = compare_shape.element_type(); + PrimitiveType compare_type = lhs->shape().element_type(); CHECK(primitive_util::IsFloatingPointType(compare_type)); - // Special-case handling for BF16. We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. - if (compare_type == BF16) { - compare_type = F32; - compare_shape.set_element_type(compare_type); - lhs = computation->AddInstruction( - HloInstruction::CreateConvert(compare_shape, lhs)); - rhs = computation->AddInstruction( - HloInstruction::CreateConvert(compare_shape, rhs)); + if (auto do_upcast = absl::c_find_if( + expand_via_upcast_, + [compare_type](std::pair upcast) { + return upcast.first == compare_type; + }); + do_upcast != expand_via_upcast_.end()) { + CHECK(primitive_util::CastPreservesValues(do_upcast->first, + do_upcast->second)); + compare_type = do_upcast->second; + lhs = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(lhs->shape(), compare_type), lhs)); + rhs = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(compare_type); + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); PrimitiveType signed_type = primitive_util::SignedIntegralTypeForBitWidth(bit_width); - PrimitiveType unsigned_type = - primitive_util::UnsignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = compare_shape; - signed_shape.set_element_type(signed_type); - auto unsigned_shape = compare_shape; - unsigned_shape.set_element_type(unsigned_type); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + auto zero_value = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast( - signed_shape, zero_value, zero_value->shape().dimensions())); - auto max_signed = computation->AddInstruction( + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MinValue(signed_shape.element_type()))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - auto max_shape = max_signed->shape(); - max_shape.set_element_type(unsigned_type); - auto max_unsigned = computation->AddInstruction( - HloInstruction::CreateConvert(max_shape, max_signed)); - auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast( - unsigned_shape, max_unsigned, max_shape.dimensions())); - lhs = BitcastConvertFloatingPointToIntegral( - computation, lhs, signed_shape, unsigned_shape, zero_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral( - computation, rhs, signed_shape, unsigned_shape, zero_value, max_value); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( instruction->shape(), lhs, rhs, compare->direction(), Comparison::Type::kSigned)); + VLOG(2) << "New comparison instruction for total order:" - << new_compare->ToString() << "\n"; + << new_compare->ToString(); return new_compare; } diff --git a/xla/service/comparison_expander.h b/xla/service/comparison_expander.h index 2d42393484d64..d95b6df78c123 100644 --- a/xla/service/comparison_expander.h +++ b/xla/service/comparison_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,15 @@ limitations under the License. #define XLA_SERVICE_COMPARISON_EXPANDER_H_ #include +#include -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/primitive_util.h" #include "xla/service/op_expander_pass.h" +#include "xla/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { @@ -28,7 +33,11 @@ namespace xla { // order comparison of floating point numbers. class ComparisonExpander : public OpExpanderPass { public: - explicit ComparisonExpander() = default; + explicit ComparisonExpander( + absl::Span> + expand_via_upcast = {}) + : expand_via_upcast_(expand_via_upcast.begin(), expand_via_upcast.end()) { + } ~ComparisonExpander() override = default; absl::string_view name() const override { return "comparison-expander"; } @@ -38,8 +47,10 @@ class ComparisonExpander : public OpExpanderPass { // Returns a replacement for `instruction`, or nullptr if no replacement is // needed (e.g. only the to_apply subcomputation of the instruction was // modified). - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; + + std::vector> expand_via_upcast_; }; } // namespace xla diff --git a/xla/service/compilation_cache.cc b/xla/service/compilation_cache.cc index 220403a2706b6..9187924b6d6a9 100644 --- a/xla/service/compilation_cache.cc +++ b/xla/service/compilation_cache.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ ExecutionHandle CompilationCache::Insert( return handle; } -StatusOr> CompilationCache::LookUp( +absl::StatusOr> CompilationCache::LookUp( const ExecutionHandle& handle) const { absl::MutexLock lock(&mutex_); diff --git a/xla/service/compilation_cache.h b/xla/service/compilation_cache.h index c76cacea6ba08..65384bf8340c4 100644 --- a/xla/service/compilation_cache.h +++ b/xla/service/compilation_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ class CompilationCache { // Lookup the Executable for the specified handle in the cache. Return a // shared_ptr to the Executable if it exists in the cache. - StatusOr> LookUp( + absl::StatusOr> LookUp( const ExecutionHandle& handle) const; protected: diff --git a/xla/service/compilation_environments.cc b/xla/service/compilation_environments.cc index 48a6ea0f1cd0b..0c2569b92dfcf 100644 --- a/xla/service/compilation_environments.cc +++ b/xla/service/compilation_environments.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -124,7 +124,7 @@ CompilationEnvironments& CompilationEnvironments::operator=( return *this; } -StatusOr> +absl::StatusOr> CompilationEnvironments::CreateFromProto( const CompilationEnvironmentsProto& proto) { auto envs = std::make_unique(); diff --git a/xla/service/compilation_environments.h b/xla/service/compilation_environments.h index 6c7b592ac59b1..3ffea24bb53b1 100644 --- a/xla/service/compilation_environments.h +++ b/xla/service/compilation_environments.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ namespace xla { class CompilationEnvironments { public: using ProcessNewEnvFn = - std::function>( + std::function>( std::unique_ptr)>; CompilationEnvironments() = default; @@ -56,8 +56,8 @@ class CompilationEnvironments { ~CompilationEnvironments() = default; // Deserializes the given CompilationEnvironments proto. - static StatusOr> CreateFromProto( - const CompilationEnvironmentsProto& proto); + static absl::StatusOr> + CreateFromProto(const CompilationEnvironmentsProto& proto); // Whenever an environment is added to CompilationEnvironments, even when // GetEnv() adds a lazily initialized one, it is passed to the function @@ -97,6 +97,8 @@ class CompilationEnvironments { T& GetMutableEnv(); template const T& GetEnv(); + template + bool HasEnv(); // Removes all added environments. void Clear() { environments_.clear(); } @@ -148,6 +150,12 @@ const T& CompilationEnvironments::GetEnv() { return GetMutableEnv(); } +template +bool CompilationEnvironments::HasEnv() { + auto descriptor = T::descriptor(); + return environments_.find(descriptor) != environments_.end(); +} + } // namespace xla #endif // XLA_SERVICE_COMPILATION_ENVIRONMENTS_H_ diff --git a/xla/service/compilation_environments_test.cc b/xla/service/compilation_environments_test.cc index 23004e5434480..f0e43fac33fed 100644 --- a/xla/service/compilation_environments_test.cc +++ b/xla/service/compilation_environments_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -210,6 +210,13 @@ TEST_F(CompilationEnvironmentsTest, ProtoRoundTrip) { 20); } +TEST_F(CompilationEnvironmentsTest, EnvTypePresenceCheck) { + CompilationEnvironments envs; + EXPECT_FALSE(envs.HasEnv()); + envs.GetEnv(); + EXPECT_TRUE(envs.HasEnv()); +} + } // namespace } // namespace test } // namespace xla diff --git a/xla/service/compilation_stats.cc b/xla/service/compilation_stats.cc index 74f7cef41f3ac..d08994e4273b5 100644 --- a/xla/service/compilation_stats.cc +++ b/xla/service/compilation_stats.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/compilation_stats.h b/xla/service/compilation_stats.h index f37d597be5c6e..d5c851e17d9ad 100644 --- a/xla/service/compilation_stats.h +++ b/xla/service/compilation_stats.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/compile_only_service.cc b/xla/service/compile_only_service.cc index e1b8f42cd7e3e..ab3acea65840d 100644 --- a/xla/service/compile_only_service.cc +++ b/xla/service/compile_only_service.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,14 +33,14 @@ limitations under the License. namespace xla { -/* static */ StatusOr> +/* static */ absl::StatusOr> CompileOnlyService::NewService(se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); } -/* static */ StatusOr> +/* static */ absl::StatusOr> CompileOnlyService::NewService(const ServiceOptions& options) { se::Platform* platform = options.platform(); if (platform == nullptr) { @@ -58,7 +58,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, Compiler* compiler) : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {} -StatusOr>> +absl::StatusOr>> CompileOnlyService::CompileAheadOfTime( absl::Span computations, const AotCompilationOptions& options, diff --git a/xla/service/compile_only_service.h b/xla/service/compile_only_service.h index a523ee87edf83..09ca0534454b4 100644 --- a/xla/service/compile_only_service.h +++ b/xla/service/compile_only_service.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,9 +33,9 @@ class CompileOnlyService : public Service { // Factory for creating a CompileOnlyService. The parameter platform is the // platform that the service should target. If platform is null then the // default platform is used. - static StatusOr> NewService( + static absl::StatusOr> NewService( se::Platform* platform); - static StatusOr> NewService( + static absl::StatusOr> NewService( const ServiceOptions& options); // A description of a xla computation to compile using CompileAheadOfTime. @@ -48,7 +48,7 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/xla/service/compile_time_cap.h b/xla/service/compile_time_cap.h index 75da192e7e56a..875448e646dff 100644 --- a/xla/service/compile_time_cap.h +++ b/xla/service/compile_time_cap.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/compiler.cc b/xla/service/compiler.cc index 3fc707e334c0f..b0feb9a62ae81 100644 --- a/xla/service/compiler.cc +++ b/xla/service/compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "xla/stream_executor/dnn.h" #include "xla/util.h" #include "tsl/platform/logging.h" @@ -30,7 +31,16 @@ namespace xla { Compiler::TargetConfig::TargetConfig(se::StreamExecutor* s) : device_description(s->GetDeviceDescription().ToGpuProto()), - platform_name(s->platform()->Name()) {} + platform_name(s->platform()->Name()), + device_description_str(s->GetDeviceDescription().name()) { + se::dnn::DnnSupport* dnn = s->AsDnn(); + if (dnn != nullptr) { + absl::StatusOr dnn_version = dnn->GetVersion(); + if (dnn_version.ok()) { + dnn_version_info = *dnn_version; + } + } +} Compiler::TargetConfig::TargetConfig(const se::GpuTargetConfigProto& proto) : device_description({proto.gpu_device_info()}), @@ -61,7 +71,7 @@ std::unique_ptr Compiler::ComputeDefaultBackendConfig( } // Define a default version where metadata is not used. -StatusOr>> +absl::StatusOr>> Compiler::CompileAheadOfTime( std::unique_ptr module_group, const AotCompilationOptions& options, @@ -98,7 +108,7 @@ Compiler::GetPlatformCompilers() { (*factories)[platform_id] = std::move(compiler_factory); } -/* static */ StatusOr Compiler::GetForPlatform( +/* static */ absl::StatusOr Compiler::GetForPlatform( const se::Platform* platform) { absl::MutexLock lock(&platform_compiler_mutex_); diff --git a/xla/service/compiler.h b/xla/service/compiler.h index f6cc2d222e956..321e4a621921a 100644 --- a/xla/service/compiler.h +++ b/xla/service/compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,10 @@ limitations under the License. #include "tsl/platform/protobuf.h" #include "tsl/platform/threadpool.h" +namespace mlir { +class DialectRegistry; +} // namespace mlir + namespace xla { // The following types are used for ahead of time compilation. @@ -63,15 +67,20 @@ class AotCompilationResult { virtual ~AotCompilationResult() = default; - virtual StatusOr SerializeAsString() const { + virtual absl::StatusOr SerializeAsString() const { return Unimplemented("SerializeAsString unimplemented."); } - virtual StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) { + virtual absl::StatusOr> LoadExecutable( + Compiler* compiler, const se::StreamExecutor* executor) const { return Unimplemented("LoadExecutable unimplemented."); } + // Returns the optimized HLO module if one was computed and the implementation + // supports it. + virtual const HloModule* optimized_module() const = 0; + virtual std::unique_ptr consume_optimized_module() = 0; + protected: AotCompilationResult() = default; }; @@ -136,7 +145,7 @@ class Compiler { // An optional thread pool for parallel compilation. tsl::thread::ThreadPool* thread_pool = nullptr; - std::function, Shape>>( + std::function, Shape>>( const HloModule& module)> layout_canonicalization_callback = {}; @@ -145,6 +154,10 @@ class Compiler { // AOT device description. If provided, used instead of querying the device // on which compilation is performed. std::optional target_config; + + // Registry of MLIR dialects and plugins to be loaded during optimization. + // If non-null, it will be used to construct relevant MLIR contexts. + mlir::DialectRegistry* registry = nullptr; }; virtual ~Compiler() = default; @@ -154,25 +167,16 @@ class Compiler { // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. - virtual StatusOr> RunHloPasses( + virtual absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, const CompileOptions& options) = 0; - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { return RunHloPasses(std::move(module), executor, CompileOptions{device_allocator}); } - // Performs scheduling and buffer assignment and returns the buffer - // assignments. - // The returned 'BufferAssignment' retains a pointer to the 'HloModule', so - // the module must live at least as long as the buffer assignments. - virtual StatusOr> AssignBuffers( - HloModule* module, se::StreamExecutor* executor) { - return Unimplemented("This compiler does not support this method"); - } - // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses @@ -181,10 +185,10 @@ class Compiler { // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. - virtual StatusOr> RunBackend( + virtual absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, const CompileOptions& options) = 0; - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { return RunBackend(std::move(module), executor, @@ -197,7 +201,8 @@ class Compiler { // Note: The default implementation of the API here does not utilize the given // buffer assignment. Different backends are a expected to override the // following method to achieve this functionality. - virtual StatusOr> RunBackendWithBufferAssignment( + virtual absl::StatusOr> + RunBackendWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* /*buffer_assignment_proto*/, se::StreamExecutor* executor, const CompileOptions& options) { @@ -205,7 +210,7 @@ class Compiler { return RunBackend(std::move(module), executor, options); } - StatusOr> RunBackendWithBufferAssignment( + absl::StatusOr> RunBackendWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, se::StreamExecutor* executor, @@ -217,7 +222,7 @@ class Compiler { // Returns a (deserialized) AotCompilationResult from a serialized // AotCompilationResult. - virtual StatusOr> + virtual absl::StatusOr> LoadAotCompilationResult(const std::string& serialized_aot_result) { return Unimplemented("LoadAotCompilationResult unimplemented."); } @@ -228,11 +233,11 @@ class Compiler { // // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. - virtual StatusOr>> Compile( + virtual absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) = 0; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, se::DeviceMemoryAllocator* device_allocator) { @@ -261,13 +266,13 @@ class Compiler { // Compiles the HLO module group for ahead-of-time execution. This is // intended for use in static compilation. - virtual StatusOr>> + virtual absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) = 0; // Similar to CompileAheadOfTime above but AotCompilationMetadata // has an argument that can be populated during compilation. - virtual StatusOr>> + virtual absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata); @@ -287,7 +292,7 @@ class Compiler { // Returns the compiler singleton pointer if it is available for the given // platform, or an error status if it is not. - static StatusOr GetForPlatform(const se::Platform* platform); + static absl::StatusOr GetForPlatform(const se::Platform* platform); // Returns a function that computes the size in bytes of the logical // buffer that contains a shape. @@ -307,7 +312,7 @@ class Compiler { } // Returns an AotCompilationResult of the executable for serialization. - virtual StatusOr> Export( + virtual absl::StatusOr> Export( Executable* executable) const { return Unimplemented("Export unimplemented"); } @@ -434,7 +439,7 @@ class AotCompilationOptions { return target_config_; } void set_target_config(const Compiler::TargetConfig& target_config) { - target_config_ = std::move(target_config); + target_config_ = target_config; } protected: diff --git a/xla/service/compiler_test.cc b/xla/service/compiler_test.cc new file mode 100644 index 0000000000000..c2743c15aff88 --- /dev/null +++ b/xla/service/compiler_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/compiler.h" + +#include +#include "xla/autotune_results.pb.h" +#include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tests/test_macros.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +TEST(TargetConfigTest, DISABLED_ON_CPU(ExecutorConstructorFillsAllFields)) { + TF_ASSERT_OK(stream_executor::ValidateGPUMachineManager()); + TF_ASSERT_OK_AND_ASSIGN( + stream_executor::StreamExecutor * executor, + stream_executor::GPUMachineManager()->ExecutorForDevice(0)); + Compiler::TargetConfig config(executor); + stream_executor::GpuTargetConfigProto target = config.ToProto(); + + // We don't attempt to validate values because doing so would require talking + // to the driver directly. + EXPECT_GT(target.dnn_version_info().major(), 0) << target.DebugString(); + EXPECT_GT(target.gpu_device_info().threads_per_block_limit(), 0) + << target.DebugString(); + EXPECT_NE(target.device_description_str(), "") << target.DebugString(); + EXPECT_NE(target.platform_name(), "") << target.DebugString(); + EXPECT_EQ(target.autotune_results().version(), 0); + + EXPECT_EQ(5, + stream_executor::GpuTargetConfigProto::descriptor()->field_count()) + << "Make sure all the fields in GpuTargetConfigProto are set and " + "validated!"; +} + +TEST(TargetConfigTest, ProtoConstructorFillsAllFields) { + stream_executor::GpuTargetConfigProto config_proto; + config_proto.set_platform_name("platform"); + config_proto.mutable_dnn_version_info()->set_major(2); + config_proto.mutable_gpu_device_info()->set_threads_per_block_limit(5); + config_proto.set_device_description_str("foo"); + + Compiler::TargetConfig config(config_proto); + stream_executor::GpuTargetConfigProto target = config.ToProto(); + + EXPECT_EQ(target.dnn_version_info().major(), + config_proto.dnn_version_info().major()) + << target.DebugString(); + EXPECT_EQ(target.gpu_device_info().threads_per_block_limit(), 5) + << target.DebugString(); + EXPECT_EQ(target.device_description_str(), "foo") << target.DebugString(); + EXPECT_EQ(target.platform_name(), "platform") << target.DebugString(); + EXPECT_EQ(target.autotune_results().version(), 0); + + EXPECT_EQ(5, + stream_executor::GpuTargetConfigProto::descriptor()->field_count()) + << "Make sure all the fields in GpuTargetConfigProto are set and " + "validated!"; +} + +} // namespace +} // namespace xla diff --git a/xla/service/computation_layout.cc b/xla/service/computation_layout.cc index 5f83c31343bf5..675c9635a233e 100644 --- a/xla/service/computation_layout.cc +++ b/xla/service/computation_layout.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,13 +59,14 @@ bool ComputationLayout::AnyLayoutSet() const { result_layout_.LayoutIsSet(); } -StatusOr> ComputationLayout::FlattenedParameterLayouts() - const { +absl::StatusOr> +ComputationLayout::FlattenedParameterLayouts() const { std::vector result; for (int i = 0; i < parameter_count(); ++i) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( parameter_shape(i), - [this, &result](const Shape& subshape, const ShapeIndex& index) { + [this, &result](const Shape& subshape, + const ShapeIndex& index) -> absl::Status { if (subshape.IsTuple()) { return OkStatus(); } @@ -88,12 +89,13 @@ StatusOr> ComputationLayout::FlattenedParameterLayouts() return result; } -StatusOr> ComputationLayout::FlattenedResultLayouts() +absl::StatusOr> ComputationLayout::FlattenedResultLayouts() const { std::vector result; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( result_shape(), - [this, &result](const Shape& subshape, const ShapeIndex& index) { + [this, &result](const Shape& subshape, + const ShapeIndex& index) -> absl::Status { if (subshape.IsTuple()) { return OkStatus(); } diff --git a/xla/service/computation_layout.h b/xla/service/computation_layout.h index 659ce36220182..b6c947b2b7c9e 100644 --- a/xla/service/computation_layout.h +++ b/xla/service/computation_layout.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -86,12 +86,12 @@ class ComputationLayout { // Returns a list of each parameter's layout. If the parameters are tupled, // returns an untupled list. Must only be called if all parameters have // layouts set (check with LayoutIsSet()). - StatusOr> FlattenedParameterLayouts() const; + absl::StatusOr> FlattenedParameterLayouts() const; // Returns a list of each output's layout. If the result shape is a tuple, // returns an untupled list. Must only be called if all outputs have layouts // set (check with LayoutIsSet()). - StatusOr> FlattenedResultLayouts() const; + absl::StatusOr> FlattenedResultLayouts() const; // Prints a string representation of this object. void Print(Printer* printer) const; diff --git a/xla/service/computation_placer.cc b/xla/service/computation_placer.cc index 9410c30d66ec5..b896c7d10cc40 100644 --- a/xla/service/computation_placer.cc +++ b/xla/service/computation_placer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,14 +42,14 @@ using absl::StrCat; namespace xla { -StatusOr DeviceAssignment::LogicalIdForDevice( - GlobalDeviceId device_id) const { +absl::StatusOr +DeviceAssignment::LogicalIdForDevice(GlobalDeviceId device_id) const { std::optional logical_id; for (int r = 0; r < replica_count(); ++r) { for (int c = 0; c < computation_count(); ++c) { if ((*this)(r, c) == device_id.value()) { if (logical_id.has_value()) { - return InternalError( + return Internal( "Device %d appears twice in DeviceAssignment: %s", device_id.value(), ToString()); } @@ -60,12 +60,12 @@ StatusOr DeviceAssignment::LogicalIdForDevice( if (logical_id.has_value()) { return *logical_id; } else { - return InternalError("Device %d doesn't appear in DeviceAssignment: %s", + return Internal("Device %d doesn't appear in DeviceAssignment: %s", device_id.value(), ToString()); } } -StatusOr DeviceAssignment::ReplicaIdForDevice( +absl::StatusOr DeviceAssignment::ReplicaIdForDevice( GlobalDeviceId device_id) const { TF_ASSIGN_OR_RETURN(const LogicalID logical_id, LogicalIdForDevice(device_id)); @@ -98,7 +98,7 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { return OkStatus(); } -/* static */ StatusOr> +/* static */ absl::StatusOr> DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { @@ -135,16 +135,16 @@ std::string DeviceAssignment::ToString() const { return output; } -StatusOr ComputationPlacer::DeviceId(int replica, int computation, - int replica_count, - int computation_count) { +absl::StatusOr ComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { TF_RET_CHECK(replica < replica_count); TF_RET_CHECK(computation < computation_count); return computation * replica_count + replica; } -StatusOr ComputationPlacer::AssignDevices( +absl::StatusOr ComputationPlacer::AssignDevices( int replica_count, int computation_count) { DeviceAssignment assignment(replica_count, computation_count); for (int replica = 0; replica < replica_count; ++replica) { @@ -165,7 +165,7 @@ StatusOr ComputationPlacer::AssignDevices( auto* computation_placers = GetPlatformComputationPlacers(); if (computation_placers->find(platform_id) != computation_placers->end()) { // TODO(b/282059652): Consider logging the platform name using - // MultiPlatformManager::PlatformWithId(). No doing that for now to avoid + // PlatformManager::PlatformWithId(). No doing that for now to avoid // introducing unwanted dependency. LOG(WARNING) << "computation placer already registered. Please check " "linkage and avoid linking the same target more than once."; @@ -173,8 +173,8 @@ StatusOr ComputationPlacer::AssignDevices( (*computation_placers)[platform_id].creation_function = creation_function; } -/* static */ StatusOr ComputationPlacer::GetForPlatform( - const se::Platform* platform) { +/* static */ absl::StatusOr +ComputationPlacer::GetForPlatform(const se::Platform* platform) { absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); diff --git a/xla/service/computation_placer.h b/xla/service/computation_placer.h index 2e6b5e3eeb4cb..12facc7950562 100644 --- a/xla/service/computation_placer.h +++ b/xla/service/computation_placer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,9 +55,9 @@ class DeviceAssignment : public Array2D { }; // Finds the (replica ID, computation ID) pair for the given device. - StatusOr LogicalIdForDevice(GlobalDeviceId device_id) const; + absl::StatusOr LogicalIdForDevice(GlobalDeviceId device_id) const; // Finds the replica ID for the given device. - StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const; + absl::StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const; // Returns a map from device ID to logical ID. Querying this map is much more // efficient than `LogicalIdForDevice` if queried repeatedly. absl::flat_hash_map GetDeviceToLogicalIdMap() @@ -69,7 +69,7 @@ class DeviceAssignment : public Array2D { // Return a std::unique_ptr instead of a DeviceAssignment // directly because one of the supported TF platforms (mac) does not compile // due to a StatusOr of an incomplete type (DeviceAssignment). - static StatusOr> Deserialize( + static absl::StatusOr> Deserialize( const DeviceAssignmentProto& proto); std::string ToString() const; @@ -85,13 +85,14 @@ class ComputationPlacer { // Returns the device id assigned to the given replica and computation // instance for [replica_count x computation_count] setup. The returned device // id must match the assignment from PlaceReplicatedComputation(). - virtual StatusOr DeviceId(int replica, int computation, - int replica_count, int computation_count); + virtual absl::StatusOr DeviceId(int replica, int computation, + int replica_count, + int computation_count); // Returns the device ids assigned to a set of replicated computations, given // the number of replicas and the number of computations. - virtual StatusOr AssignDevices(int replica_count, - int computation_count); + virtual absl::StatusOr AssignDevices(int replica_count, + int computation_count); using ComputationPlacerCreationFunction = std::unique_ptr (*)(); @@ -103,7 +104,7 @@ class ComputationPlacer { // Returns the computation placer singleton pointer if it is available for the // given platform, or an error status if it is not. - static StatusOr GetForPlatform( + static absl::StatusOr GetForPlatform( const se::Platform* platform); private: diff --git a/xla/service/conditional_canonicalizer.cc b/xla/service/conditional_canonicalizer.cc index b409a9725d311..22eb2cb6500fe 100644 --- a/xla/service/conditional_canonicalizer.cc +++ b/xla/service/conditional_canonicalizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ Status CanonicalizeNonTupleConditional(HloInstruction* conditional) { } } // namespace -StatusOr ConditionalCanonicalizer::Run( +absl::StatusOr ConditionalCanonicalizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/conditional_canonicalizer.h b/xla/service/conditional_canonicalizer.h index 86cab823277e3..efe9506be9766 100644 --- a/xla/service/conditional_canonicalizer.h +++ b/xla/service/conditional_canonicalizer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,11 +28,11 @@ namespace xla { class ConditionalCanonicalizer : public HloModulePass { public: absl::string_view name() const override { - return "conditional canonicalizer"; + return "conditional-canonicalizer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/conditional_canonicalizer_test.cc b/xla/service/conditional_canonicalizer_test.cc index dab5567c48f37..3d5e1e976da0d 100644 --- a/xla/service/conditional_canonicalizer_test.cc +++ b/xla/service/conditional_canonicalizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/conditional_code_motion.cc b/xla/service/conditional_code_motion.cc index e163017429b4c..1fed4183420b6 100644 --- a/xla/service/conditional_code_motion.cc +++ b/xla/service/conditional_code_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -533,8 +533,8 @@ Status RestructureConditionalInstruction(HloComputation* computation, return OkStatus(); } -StatusOr ConvertSpecialMove(HloInstruction* conditional, - bool is_layout_sensitive) { +absl::StatusOr ConvertSpecialMove(HloInstruction* conditional, + bool is_layout_sensitive) { int branch_count = conditional->branch_count(); if (branch_count <= 0) { return false; @@ -673,7 +673,7 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, // are the shape of the operands are identical and their properties are // identical. Will start from the root instruction of each branch and get // the identical ops to hoist. -StatusOr ConditionalCodeMotion::MoveInstructionOut( +absl::StatusOr ConditionalCodeMotion::MoveInstructionOut( HloInstruction* conditional, std::vector& to_move_out, std::vector& new_boundaries) { if (to_move_out.empty()) { @@ -780,7 +780,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( } // Hoist conditional users from outside to inside the branches. -StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( +absl::StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( HloInstruction* conditional, std::vector& to_move_in) { if (to_move_in.empty()) { return false; @@ -1235,7 +1235,7 @@ class MoveOperandIntoBranch { }; // Hoist operands of a conditional from outside to inside the branches. -StatusOr ConditionalCodeMotion::MoveOperandInstructionsIn( +absl::StatusOr ConditionalCodeMotion::MoveOperandInstructionsIn( HloInstruction* conditional, std::vector& to_move_in) { // Mapping boundaries to be moved to their new representations. int64_t to_move_in_size = to_move_in.size(); @@ -1944,7 +1944,7 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( return Decision(Decision::Direction::kNoChange, 0); } -StatusOr ConditionalCodeMotion::Run( +absl::StatusOr ConditionalCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Begin a new pass of conditional code motion optimization.\n"; diff --git a/xla/service/conditional_code_motion.h b/xla/service/conditional_code_motion.h index 0f8b4600ed6fb..caf92342900c9 100644 --- a/xla/service/conditional_code_motion.h +++ b/xla/service/conditional_code_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -181,7 +181,7 @@ class ConditionalCodeMotion : public HloModulePass { absl::string_view name() const override { return "conditional-code-motion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -228,13 +228,13 @@ class ConditionalCodeMotion : public HloModulePass { // moved. int64_t memory_increase_allowance_ = 5000; int64_t memory_increase_ = 0; - StatusOr MoveInstructionOut(HloInstruction* conditional, - std::vector& to_move_out, - std::vector& new_boundaries); - StatusOr MoveUserInstructionsIn(HloInstruction* conditional, - std::vector& to_move_in); - StatusOr MoveOperandInstructionsIn(HloInstruction* conditional, - std::vector& to_move_in); + absl::StatusOr MoveInstructionOut( + HloInstruction* conditional, std::vector& to_move_out, + std::vector& new_boundaries); + absl::StatusOr MoveUserInstructionsIn( + HloInstruction* conditional, std::vector& to_move_in); + absl::StatusOr MoveOperandInstructionsIn( + HloInstruction* conditional, std::vector& to_move_in); void SetDefaultMoveConfig(); }; } // namespace conditional_opt diff --git a/xla/service/conditional_code_motion_test.cc b/xla/service/conditional_code_motion_test.cc index 5ca81eb1e628e..fcfe91d7a21df 100644 --- a/xla/service/conditional_code_motion_test.cc +++ b/xla/service/conditional_code_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/conditional_simplifier.cc b/xla/service/conditional_simplifier.cc index b6ae8972cf5bd..d4d568dc6ed9e 100644 --- a/xla/service/conditional_simplifier.cc +++ b/xla/service/conditional_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) { return empty_operations && contains_array; } -StatusOr TryRemoveUnusedConditionalOperands( +absl::StatusOr TryRemoveUnusedConditionalOperands( HloComputation* computation, const absl::flat_hash_set& calling_conditionals) { HloInstruction* param = computation->parameter_instruction(0); @@ -439,7 +439,7 @@ bool MergeDuplicateTupleElements(HloInstruction* conditional) { // inline that computation. // // Returns true if it made a change to the graph. -StatusOr ConditionalSimplifier::TryRemoveConditional( +absl::StatusOr ConditionalSimplifier::TryRemoveConditional( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); // Do not remove conditionals that contain side-effecting instructions or @@ -601,7 +601,7 @@ static bool InstructionCallsChannelInstructions( return false; } -StatusOr ConditionalSimplifier::Run( +absl::StatusOr ConditionalSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/conditional_simplifier.h b/xla/service/conditional_simplifier.h index 6faee234f23b2..8eeab8279dd8f 100644 --- a/xla/service/conditional_simplifier.h +++ b/xla/service/conditional_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,12 +29,12 @@ class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr TryRemoveConditional(HloInstruction* conditional); + absl::StatusOr TryRemoveConditional(HloInstruction* conditional); }; } // namespace xla diff --git a/xla/service/conditional_simplifier_test.cc b/xla/service/conditional_simplifier_test.cc index a00a4346b680f..083ef03453d67 100644 --- a/xla/service/conditional_simplifier_test.cc +++ b/xla/service/conditional_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/conditional_to_select.cc b/xla/service/conditional_to_select.cc index aee877071d151..3b2fd71003830 100644 --- a/xla/service/conditional_to_select.cc +++ b/xla/service/conditional_to_select.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ limitations under the License. namespace xla { -static StatusOr DoConditionalToSelect(HloInstruction* conditional) { +static absl::StatusOr DoConditionalToSelect(HloInstruction* conditional) { // Only allow conditional to select if the called computations // do not have side effects. if (conditional->true_computation()->HasSideEffect() || @@ -66,7 +66,7 @@ static StatusOr DoConditionalToSelect(HloInstruction* conditional) { return true; } -StatusOr ConditionalToSelect::Run( +absl::StatusOr ConditionalToSelect::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::unique_ptr call_graph = CallGraph::Build(module); diff --git a/xla/service/conditional_to_select.h b/xla/service/conditional_to_select.h index c66abcfb36116..cbc9cff571a90 100644 --- a/xla/service/conditional_to_select.h +++ b/xla/service/conditional_to_select.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class ConditionalToSelect : public HloModulePass { // Run conditional to select on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/conditional_to_select_test.cc b/xla/service/conditional_to_select_test.cc index b7b2181f98674..c6f4b7cc721b7 100644 --- a/xla/service/conditional_to_select_test.cc +++ b/xla/service/conditional_to_select_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/constant_value.cc b/xla/service/constant_value.cc index 1d69fc248c85a..a5b6c9c30f2f0 100644 --- a/xla/service/constant_value.cc +++ b/xla/service/constant_value.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,10 +19,11 @@ limitations under the License. namespace xla { -StatusOr ConstantValue::FromLiteral(const Literal& literal) { +absl::StatusOr ConstantValue::FromLiteral( + const Literal& literal) { CHECK_EQ(literal.shape().dimensions_size(), 0) << "Expected scalar literal"; return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { return ConstantValue( static_cast( diff --git a/xla/service/constant_value.h b/xla/service/constant_value.h index 97f3a8ff04795..2a88afc3e1b21 100644 --- a/xla/service/constant_value.h +++ b/xla/service/constant_value.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class ConstantValue { static ConstantValue GetUnsigned(uint64_t value, int32_t bitwidth) { return ConstantValue(value, bitwidth, /*is_signed=*/false); } - static StatusOr FromLiteral(const Literal& literal); + static absl::StatusOr FromLiteral(const Literal& literal); ConstantValue add(const ConstantValue& other) const { return ConstantValue(value_ + other.value_, bitwidth_, is_signed_); } diff --git a/xla/service/constant_value_test.cc b/xla/service/constant_value_test.cc index d5e3bd9655f65..fd468fbe5f01f 100644 --- a/xla/service/constant_value_test.cc +++ b/xla/service/constant_value_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/convert_async_collectives_to_sync.cc b/xla/service/convert_async_collectives_to_sync.cc index 736e2aff5e015..60ec4a8788f68 100644 --- a/xla/service/convert_async_collectives_to_sync.cc +++ b/xla/service/convert_async_collectives_to_sync.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,8 +30,8 @@ limitations under the License. namespace xla { -StatusOr CreateSyncVariant(HloInstruction* async_start, - HloInstruction* async_done) { +absl::StatusOr CreateSyncVariant(HloInstruction* async_start, + HloInstruction* async_done) { HloInstruction* sync_instruction = nullptr; HloComputation* computation = async_start->parent(); @@ -74,7 +74,7 @@ StatusOr CreateSyncVariant(HloInstruction* async_start, break; } default: - return InternalError("Unexpected async start op %s", + return Internal("Unexpected async start op %s", HloOpcodeString(async_start->opcode())); } @@ -90,20 +90,6 @@ StatusOr CreateSyncVariant(HloInstruction* async_start, TF_RETURN_IF_ERROR(async_start->DropAllControlDeps()); TF_RETURN_IF_ERROR(async_done->DropAllControlDeps()); - // For the generic async-start/done, we also need to disconnect them from - // the called computations. - if (async_start_op == HloOpcode::kAsyncStart) { - auto disconnect_called_computation = - [](HloInstruction* async_op) -> Status { - TF_RET_CHECK(async_op->called_computations().size() == 1); - HloComputation* called = async_op->called_computations().front(); - called->RemoveAsyncInstruction(async_op); - return OkStatus(); - }; - TF_RETURN_IF_ERROR(disconnect_called_computation(async_start)); - TF_RETURN_IF_ERROR(disconnect_called_computation(async_done)); - } - // When we remove the async-done (and its unused operands), in most cases, // the async-start may not be deleted if its considered as having side effects // but in some cases it will be (e.g., the generic HLO kAsyncStart). Track its @@ -158,7 +144,7 @@ ConvertAsyncCollectivesToSync::ReplaceAsyncInstructionsWithSync( return OkStatus(); } -StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( +absl::StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( HloComputation* computation) { HloModule* module = computation->parent(); std::vector> async_pairs; @@ -171,10 +157,10 @@ StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( absl::flat_hash_set in_flight_ops; for (HloInstruction* instruction : sequence.instructions()) { - if (hlo_query::IsAsyncCollectiveStartOp(instruction->opcode())) { + if (hlo_query::IsAsyncCollectiveStartOp(instruction)) { in_flight_ops.insert(instruction); VLOG(3) << "Found async start " << instruction->ToString(); - } else if (hlo_query::IsAsyncCollectiveDoneOp(instruction->opcode())) { + } else if (hlo_query::IsAsyncCollectiveDoneOp(instruction)) { // If this done is matching with the previous start and all intervening // ops are nops (i.e., prev_async_start was not reset to null), then we // were unable to schedule an independent op to overlap with this async @@ -207,7 +193,7 @@ StatusOr ConvertAsyncCollectivesToSync::RunOnComputation( return true; } -StatusOr ConvertAsyncCollectivesToSync::Run( +absl::StatusOr ConvertAsyncCollectivesToSync::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!module->has_schedule()) { diff --git a/xla/service/convert_async_collectives_to_sync.h b/xla/service/convert_async_collectives_to_sync.h index 14db7a185e09e..2b37c6ee7fa46 100644 --- a/xla/service/convert_async_collectives_to_sync.h +++ b/xla/service/convert_async_collectives_to_sync.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class ConvertAsyncCollectivesToSync : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -58,7 +58,7 @@ class ConvertAsyncCollectivesToSync : public HloModulePass { "async_collective_name"; private: - StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnComputation(HloComputation* computation); HloPredicate is_nop_; }; } // namespace xla diff --git a/xla/service/convert_async_collectives_to_sync_test.cc b/xla/service/convert_async_collectives_to_sync_test.cc index 05aed10705bec..355ba78b7e2fc 100644 --- a/xla/service/convert_async_collectives_to_sync_test.cc +++ b/xla/service/convert_async_collectives_to_sync_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/convert_memory_placement_to_internal_annotations.cc b/xla/service/convert_memory_placement_to_internal_annotations.cc new file mode 100644 index 0000000000000..a0d7887532651 --- /dev/null +++ b/xla/service/convert_memory_placement_to_internal_annotations.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/side_effect_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { + +absl::StatusOr ConvertMemoryPlacementToInternalAnnotations::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* c : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : c->MakeInstructionPostOrder()) { + if (instruction->IsCustomCall( + host_memory_offload_annotations::kDevicePlacement)) { + const auto& frontend_attributes = instruction->frontend_attributes(); + const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr); + if (it == frontend_attributes.map().end()) { + continue; + } + // XLA currently does not differentiate between pinned and unpinned host + // memory. + const bool is_to_host_case = + (it->second == + host_memory_offload_annotations::kMemoryTargetPinnedHost || + it->second == + host_memory_offload_annotations::kMemoryTargetUnpinnedHost); + const bool is_to_device_case = + (it->second == + host_memory_offload_annotations::kMemoryTargetDevice); + if (!is_to_host_case && !is_to_device_case) { + continue; + } + if (is_to_host_case) { + VLOG(1) << "Process forward case: " << instruction->ToString(); + if (instruction->users().size() != 1) { + VLOG(1) << "Skip because of too many users on instruction"; + continue; + } + if (instruction->operand_count() != 1) { + return Internal( + "Custom calls with target %s must have exactly one operand. %s " + "has %d.", + host_memory_offload_annotations::kDevicePlacement, + instruction->name(), instruction->operand_count()); + } + HloInstruction* input = instruction->mutable_operand(0); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( + c->AddInstruction(HloInstruction::CreateCustomCall( + input->shape(), {input}, + host_memory_offload_annotations:: + kMoveToHostCustomCallTarget)))); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } else if (is_to_device_case) { + VLOG(1) << "Process backward case: " << instruction->ToString(); + HloInstruction* custom_call_operand = instruction->mutable_operand(0); + if (custom_call_operand->users().size() != 1) { + VLOG(1) << "Skip because operand is used by more than one user"; + continue; + } + HloInstruction* new_result = + c->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_operand->shape(), {custom_call_operand}, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); + TF_RETURN_IF_ERROR( + c->RemoveInstructionAndUnusedOperands(instruction)); + changed = true; + } + } + } + } + return changed; +} + +} // namespace xla diff --git a/xla/service/convert_memory_placement_to_internal_annotations.h b/xla/service/convert_memory_placement_to_internal_annotations.h new file mode 100644 index 0000000000000..36d3a0b4fc887 --- /dev/null +++ b/xla/service/convert_memory_placement_to_internal_annotations.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ +#define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass { + public: + ConvertMemoryPlacementToInternalAnnotations() = default; + + absl::string_view name() const override { + return "convert-memory-placement-to-internal-annotations"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git a/xla/service/convert_memory_placement_to_internal_annotations_test.cc b/xla/service/convert_memory_placement_to_internal_annotations_test.cc new file mode 100644 index 0000000000000..4ee97fc1fbfba --- /dev/null +++ b/xla/service/convert_memory_placement_to_internal_annotations_test.cc @@ -0,0 +1,488 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "xla/service/convert_memory_placement_to_internal_annotations.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ConvertMemoryPlacementToInternalAnnotationsTest : public HloTestBase { + public: + ConvertMemoryPlacementToInternalAnnotationsTest() = default; +}; + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, ConvertPinnedHostTest) { + const char* hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +region_0.9 { + arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0 + constant.15 = s32[] constant(1) + add.33 = s32[] add(get-tuple-element.11, constant.15) + get-tuple-element.12 = f32[16]{0} get-tuple-element(arg_tuple.10), index=1 + sine.18 = f32[16]{0} sine(get-tuple-element.12) + sine.19 = f32[16]{0} sine(sine.18) + sine.20 = f32[16]{0} sine(sine.19) + get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2 + custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.23 = f32[1,16]{1,0} reshape(custom-call.21) + constant.17 = s32[] constant(0) + compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + constant.16 = s32[] constant(16) + add.25 = s32[] add(get-tuple-element.11, constant.16) + select.26 = s32[] select(compare.24, add.25, get-tuple-element.11) + dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17) + get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3 + custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.28 = f32[1,16]{1,0} reshape(custom-call.22) + compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + add.30 = s32[] add(get-tuple-element.11, constant.16) + select.31 = s32[] select(compare.29, add.30, get-tuple-element.11) + dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17) + ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32) +} + +region_1.35 { + arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1 + get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2 + get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3 + get-tuple-element.37 = s32[] get-tuple-element(arg_tuple.36), index=0 + constant.41 = s32[] constant(16) + ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT +} + +core_closed_call.43 { + constant.47 = s32[] constant(0) + Arg_0.44 = f32[16]{0} parameter(0) + constant.45 = f32[] constant(0) + broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={} + tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46) + while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9 + get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0 + get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1 + get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2 + get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3 + ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53) +} + +region_2.65 { + arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0 + constant.74 = s32[] constant(1) + add.108 = s32[] add(get-tuple-element.67, constant.74) + get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6 + constant.76 = s32[] constant(0) + compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + constant.75 = s32[] constant(16) + add.83 = s32[] add(get-tuple-element.67, constant.75) + select.84 = s32[] select(compare.82, add.83, get-tuple-element.67) + dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16} + reshape.86 = f32[16]{0} reshape(dynamic-slice.85) + custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2 + get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1 + cosine.88 = f32[16]{0} cosine(get-tuple-element.68) + reshape.93 = f32[1,16]{1,0} reshape(cosine.88) + compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.95 = s32[] add(get-tuple-element.67, constant.75) + select.96 = s32[] select(compare.94, add.95, get-tuple-element.67) + dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76) + get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3 + sine.89 = f32[16]{0} sine(get-tuple-element.68) + cosine.90 = f32[16]{0} cosine(sine.89) + reshape.98 = f32[1,16]{1,0} reshape(cosine.90) + compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.100 = s32[] add(get-tuple-element.67, constant.75) + select.101 = s32[] select(compare.99, add.100, get-tuple-element.67) + dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76) + get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4 + get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5 + compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.78 = s32[] add(get-tuple-element.67, constant.75) + select.79 = s32[] select(compare.77, add.78, get-tuple-element.67) + dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16} + reshape.81 = f32[16]{0} reshape(dynamic-slice.80) + custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + cosine.92 = f32[16]{0} cosine(custom-call.91) + reshape.103 = f32[1,16]{1,0} reshape(cosine.92) + compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.105 = s32[] add(get-tuple-element.67, constant.75) + select.106 = s32[] select(compare.104, add.105, get-tuple-element.67) + dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76) + ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73) +} + +region_3.110 { + arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1 + get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2 + get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3 + get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4 + get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5 + get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6 + get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0 + constant.119 = s32[] constant(16) + ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT +} + +region_4.130 { + arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0 + constant.140 = s32[] constant(1) + add.164 = s32[] add(get-tuple-element.132, constant.140) + get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1 + get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2 + broadcast.159 = f32[16]{0} broadcast(get-tuple-element.134), dimensions={} + add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159) + get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5 + constant.141 = s32[] constant(16) + subtract.142 = s32[] subtract(constant.141, get-tuple-element.132) + subtract.143 = s32[] subtract(subtract.142, constant.140) + constant.139 = s32[] constant(0) + compare.154 = pred[] compare(subtract.143, constant.139), direction=LT + add.155 = s32[] add(subtract.143, constant.141) + select.156 = s32[] select(compare.154, add.155, subtract.143) + dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16} + reshape.158 = f32[16]{0} reshape(dynamic-slice.157) + multiply.161 = f32[16]{0} multiply(add.160, reshape.158) + get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4 + compare.149 = pred[] compare(subtract.143, constant.139), direction=LT + add.150 = s32[] add(subtract.143, constant.141) + select.151 = s32[] select(compare.149, add.150, subtract.143) + dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16} + reshape.153 = f32[16]{0} reshape(dynamic-slice.152) + multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153) + get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3 + compare.144 = pred[] compare(subtract.143, constant.139), direction=LT + add.145 = s32[] add(subtract.143, constant.141) + select.146 = s32[] select(compare.144, add.145, subtract.143) + dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16} + reshape.148 = f32[16]{0} reshape(dynamic-slice.147) + multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148) + constant.138 = f32[] constant(0) + ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137) +} + +region_5.166 { + arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1 + get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2 + get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3 + get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4 + get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5 + get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0 + constant.174 = s32[] constant(16) + ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT +} + +ENTRY main.183 { + constant.6 = s32[] constant(0) + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43 + get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0 + get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1 + constant.7 = f32[] constant(1) + tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7) + opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58) + get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2 + constant.4 = f32[] constant(0) + broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={} + get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0 + get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1 + tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61) + while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65 + get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0 + get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1 + get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5 + get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6 + constant.2 = f32[] constant(0) + broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={} + get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3 + get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2 + get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3 + get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4 + tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126) + while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130 + get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0 + ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1 + get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2 + get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3 + get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4 + get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t custom_calls_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget) || + instr->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + ++custom_calls_count; + } + } + } + EXPECT_EQ(custom_calls_count, 4); +} + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertUnpinnedHostTest) { + const char* hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +region_0.9 { + arg_tuple.10 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.11 = s32[] get-tuple-element(arg_tuple.10), index=0 + constant.15 = s32[] constant(1) + add.33 = s32[] add(get-tuple-element.11, constant.15) + get-tuple-element.12 = f32[16]{0} get-tuple-element(arg_tuple.10), index=1 + sine.18 = f32[16]{0} sine(get-tuple-element.12) + sine.19 = f32[16]{0} sine(sine.18) + sine.20 = f32[16]{0} sine(sine.19) + get-tuple-element.13 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=2 + custom-call.21 = f32[16]{0} custom-call(sine.19), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="unpinned_host"} + reshape.23 = f32[1,16]{1,0} reshape(custom-call.21) + constant.17 = s32[] constant(0) + compare.24 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + constant.16 = s32[] constant(16) + add.25 = s32[] add(get-tuple-element.11, constant.16) + select.26 = s32[] select(compare.24, add.25, get-tuple-element.11) + dynamic-update-slice.27 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.13, reshape.23, select.26, constant.17) + get-tuple-element.14 = f32[16,16]{1,0} get-tuple-element(arg_tuple.10), index=3 + custom-call.22 = f32[16]{0} custom-call(sine.20), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="unpinned_host"} + reshape.28 = f32[1,16]{1,0} reshape(custom-call.22) + compare.29 = pred[] compare(get-tuple-element.11, constant.17), direction=LT + add.30 = s32[] add(get-tuple-element.11, constant.16) + select.31 = s32[] select(compare.29, add.30, get-tuple-element.11) + dynamic-update-slice.32 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.14, reshape.28, select.31, constant.17) + ROOT tuple.34 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.33, sine.20, dynamic-update-slice.27, dynamic-update-slice.32) +} + +region_1.35 { + arg_tuple.36 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.38 = f32[16]{0} get-tuple-element(arg_tuple.36), index=1 + get-tuple-element.39 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=2 + get-tuple-element.40 = f32[16,16]{1,0} get-tuple-element(arg_tuple.36), index=3 + get-tuple-element.37 = s32[] get-tuple-element(arg_tuple.36), index=0 + constant.41 = s32[] constant(16) + ROOT compare.42 = pred[] compare(get-tuple-element.37, constant.41), direction=LT +} + +core_closed_call.43 { + constant.47 = s32[] constant(0) + Arg_0.44 = f32[16]{0} parameter(0) + constant.45 = f32[] constant(0) + broadcast.46 = f32[16,16]{1,0} broadcast(constant.45), dimensions={} + tuple.48 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.47, Arg_0.44, broadcast.46, broadcast.46) + while.49 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.48), condition=region_1.35, body=region_0.9 + get-tuple-element.50 = s32[] get-tuple-element(while.49), index=0 + get-tuple-element.51 = f32[16]{0} get-tuple-element(while.49), index=1 + get-tuple-element.52 = f32[16,16]{1,0} get-tuple-element(while.49), index=2 + get-tuple-element.53 = f32[16,16]{1,0} get-tuple-element(while.49), index=3 + ROOT tuple.54 = (f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(get-tuple-element.52, get-tuple-element.53) +} + +region_2.65 { + arg_tuple.66 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(arg_tuple.66), index=0 + constant.74 = s32[] constant(1) + add.108 = s32[] add(get-tuple-element.67, constant.74) + get-tuple-element.73 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=6 + constant.76 = s32[] constant(0) + compare.82 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + constant.75 = s32[] constant(16) + add.83 = s32[] add(get-tuple-element.67, constant.75) + select.84 = s32[] select(compare.82, add.83, get-tuple-element.67) + dynamic-slice.85 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.73, select.84, constant.76), dynamic_slice_sizes={1,16} + reshape.86 = f32[16]{0} reshape(dynamic-slice.85) + custom-call.87 = f32[16]{0} custom-call(reshape.86), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + get-tuple-element.69 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=2 + get-tuple-element.68 = f32[16]{0} get-tuple-element(arg_tuple.66), index=1 + cosine.88 = f32[16]{0} cosine(get-tuple-element.68) + reshape.93 = f32[1,16]{1,0} reshape(cosine.88) + compare.94 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.95 = s32[] add(get-tuple-element.67, constant.75) + select.96 = s32[] select(compare.94, add.95, get-tuple-element.67) + dynamic-update-slice.97 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.69, reshape.93, select.96, constant.76) + get-tuple-element.70 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=3 + sine.89 = f32[16]{0} sine(get-tuple-element.68) + cosine.90 = f32[16]{0} cosine(sine.89) + reshape.98 = f32[1,16]{1,0} reshape(cosine.90) + compare.99 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.100 = s32[] add(get-tuple-element.67, constant.75) + select.101 = s32[] select(compare.99, add.100, get-tuple-element.67) + dynamic-update-slice.102 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.70, reshape.98, select.101, constant.76) + get-tuple-element.71 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=4 + get-tuple-element.72 = f32[16,16]{1,0} get-tuple-element(arg_tuple.66), index=5 + compare.77 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.78 = s32[] add(get-tuple-element.67, constant.75) + select.79 = s32[] select(compare.77, add.78, get-tuple-element.67) + dynamic-slice.80 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.72, select.79, constant.76), dynamic_slice_sizes={1,16} + reshape.81 = f32[16]{0} reshape(dynamic-slice.80) + custom-call.91 = f32[16]{0} custom-call(reshape.81), custom_call_target="annotate_device_placement", frontend_attributes={_xla_buffer_placement="device"} + cosine.92 = f32[16]{0} cosine(custom-call.91) + reshape.103 = f32[1,16]{1,0} reshape(cosine.92) + compare.104 = pred[] compare(get-tuple-element.67, constant.76), direction=LT + add.105 = s32[] add(get-tuple-element.67, constant.75) + select.106 = s32[] select(compare.104, add.105, get-tuple-element.67) + dynamic-update-slice.107 = f32[16,16]{1,0} dynamic-update-slice(get-tuple-element.71, reshape.103, select.106, constant.76) + ROOT tuple.109 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(add.108, custom-call.87, dynamic-update-slice.97, dynamic-update-slice.102, dynamic-update-slice.107, get-tuple-element.72, get-tuple-element.73) +} + +region_3.110 { + arg_tuple.111 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) parameter(0) + get-tuple-element.113 = f32[16]{0} get-tuple-element(arg_tuple.111), index=1 + get-tuple-element.114 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=2 + get-tuple-element.115 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=3 + get-tuple-element.116 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=4 + get-tuple-element.117 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=5 + get-tuple-element.118 = f32[16,16]{1,0} get-tuple-element(arg_tuple.111), index=6 + get-tuple-element.112 = s32[] get-tuple-element(arg_tuple.111), index=0 + constant.119 = s32[] constant(16) + ROOT compare.120 = pred[] compare(get-tuple-element.112, constant.119), direction=LT +} + +region_4.130 { + arg_tuple.131 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.132 = s32[] get-tuple-element(arg_tuple.131), index=0 + constant.140 = s32[] constant(1) + add.164 = s32[] add(get-tuple-element.132, constant.140) + get-tuple-element.133 = f32[16]{0} get-tuple-element(arg_tuple.131), index=1 + get-tuple-element.134 = f32[] get-tuple-element(arg_tuple.131), index=2 + broadcast.159 = f32[16]{0} broadcast(get-tuple-element.134), dimensions={} + add.160 = f32[16]{0} add(get-tuple-element.133, broadcast.159) + get-tuple-element.137 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=5 + constant.141 = s32[] constant(16) + subtract.142 = s32[] subtract(constant.141, get-tuple-element.132) + subtract.143 = s32[] subtract(subtract.142, constant.140) + constant.139 = s32[] constant(0) + compare.154 = pred[] compare(subtract.143, constant.139), direction=LT + add.155 = s32[] add(subtract.143, constant.141) + select.156 = s32[] select(compare.154, add.155, subtract.143) + dynamic-slice.157 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.137, select.156, constant.139), dynamic_slice_sizes={1,16} + reshape.158 = f32[16]{0} reshape(dynamic-slice.157) + multiply.161 = f32[16]{0} multiply(add.160, reshape.158) + get-tuple-element.136 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=4 + compare.149 = pred[] compare(subtract.143, constant.139), direction=LT + add.150 = s32[] add(subtract.143, constant.141) + select.151 = s32[] select(compare.149, add.150, subtract.143) + dynamic-slice.152 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.136, select.151, constant.139), dynamic_slice_sizes={1,16} + reshape.153 = f32[16]{0} reshape(dynamic-slice.152) + multiply.162 = f32[16]{0} multiply(multiply.161, reshape.153) + get-tuple-element.135 = f32[16,16]{1,0} get-tuple-element(arg_tuple.131), index=3 + compare.144 = pred[] compare(subtract.143, constant.139), direction=LT + add.145 = s32[] add(subtract.143, constant.141) + select.146 = s32[] select(compare.144, add.145, subtract.143) + dynamic-slice.147 = f32[1,16]{1,0} dynamic-slice(get-tuple-element.135, select.146, constant.139), dynamic_slice_sizes={1,16} + reshape.148 = f32[16]{0} reshape(dynamic-slice.147) + multiply.163 = f32[16]{0} multiply(multiply.162, reshape.148) + constant.138 = f32[] constant(0) + ROOT tuple.165 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(add.164, multiply.163, constant.138, get-tuple-element.135, get-tuple-element.136, get-tuple-element.137) +} + +region_5.166 { + arg_tuple.167 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) parameter(0) + get-tuple-element.169 = f32[16]{0} get-tuple-element(arg_tuple.167), index=1 + get-tuple-element.170 = f32[] get-tuple-element(arg_tuple.167), index=2 + get-tuple-element.171 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=3 + get-tuple-element.172 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=4 + get-tuple-element.173 = f32[16,16]{1,0} get-tuple-element(arg_tuple.167), index=5 + get-tuple-element.168 = s32[] get-tuple-element(arg_tuple.167), index=0 + constant.174 = s32[] constant(16) + ROOT compare.175 = pred[] compare(get-tuple-element.168, constant.174), direction=LT +} + +ENTRY main.183 { + constant.6 = s32[] constant(0) + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + call.55 = (f32[16,16]{1,0}, f32[16,16]{1,0}) call(Arg_0.1), to_apply=core_closed_call.43 + get-tuple-element.56 = f32[16,16]{1,0} get-tuple-element(call.55), index=0 + get-tuple-element.57 = f32[16,16]{1,0} get-tuple-element(call.55), index=1 + constant.7 = f32[] constant(1) + tuple.58 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) tuple(get-tuple-element.56, get-tuple-element.57, Arg_0.1, constant.7) + opt-barrier.59 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16]{0}, f32[]) opt-barrier(tuple.58) + get-tuple-element.62 = f32[16]{0} get-tuple-element(opt-barrier.59), index=2 + constant.4 = f32[] constant(0) + broadcast.5 = f32[16,16]{1,0} broadcast(constant.4), dimensions={} + get-tuple-element.60 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=0 + get-tuple-element.61 = f32[16,16]{1,0} get-tuple-element(opt-barrier.59), index=1 + tuple.64 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(constant.6, get-tuple-element.62, broadcast.5, broadcast.5, broadcast.5, get-tuple-element.60, get-tuple-element.61) + while.121 = (s32[], f32[16]{0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}, f32[16,16]{1,0}) while(tuple.64), condition=region_3.110, body=region_2.65 + get-tuple-element.122 = s32[] get-tuple-element(while.121), index=0 + get-tuple-element.123 = f32[16]{0} get-tuple-element(while.121), index=1 + get-tuple-element.127 = f32[16,16]{1,0} get-tuple-element(while.121), index=5 + get-tuple-element.128 = f32[16,16]{1,0} get-tuple-element(while.121), index=6 + constant.2 = f32[] constant(0) + broadcast.3 = f32[16]{0} broadcast(constant.2), dimensions={} + get-tuple-element.63 = f32[] get-tuple-element(opt-barrier.59), index=3 + get-tuple-element.124 = f32[16,16]{1,0} get-tuple-element(while.121), index=2 + get-tuple-element.125 = f32[16,16]{1,0} get-tuple-element(while.121), index=3 + get-tuple-element.126 = f32[16,16]{1,0} get-tuple-element(while.121), index=4 + tuple.129 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) tuple(constant.6, broadcast.3, get-tuple-element.63, get-tuple-element.124, get-tuple-element.125, get-tuple-element.126) + while.176 = (s32[], f32[16]{0}, f32[], f32[16,16]{1,0}, f32[16,16]{1,0}, /*index=5*/f32[16,16]{1,0}) while(tuple.129), condition=region_5.166, body=region_4.130 + get-tuple-element.177 = s32[] get-tuple-element(while.176), index=0 + ROOT get-tuple-element.178 = f32[16]{0} get-tuple-element(while.176), index=1 + get-tuple-element.179 = f32[] get-tuple-element(while.176), index=2 + get-tuple-element.180 = f32[16,16]{1,0} get-tuple-element(while.176), index=3 + get-tuple-element.181 = f32[16,16]{1,0} get-tuple-element(while.176), index=4 + get-tuple-element.182 = f32[16,16]{1,0} get-tuple-element(while.176), index=5 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t custom_calls_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget) || + instr->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + ++custom_calls_count; + } + } + } + EXPECT_EQ(custom_calls_count, 4); +} + +} // namespace +} // namespace xla diff --git a/xla/service/convert_mover.cc b/xla/service/convert_mover.cc index cb7681658d659..b847c199753b8 100644 --- a/xla/service/convert_mover.cc +++ b/xla/service/convert_mover.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,11 +30,11 @@ static bool IsLosslesslyConvertibleTo(const Literal& literal, // The only reason Convert() should fail is if we don't support converting // from x to y, which indeed means it's not losslessly-convertible. - StatusOr converted1 = literal.Convert(dst_ty); + absl::StatusOr converted1 = literal.Convert(dst_ty); if (!converted1.ok()) { return false; } - StatusOr converted2 = converted1->Convert(orig_ty); + absl::StatusOr converted2 = converted1->Convert(orig_ty); if (!converted2.ok()) { return false; } @@ -64,7 +64,7 @@ bool OpCommutesWithConvert(HloOpcode opcode) { } } -StatusOr MoveConvertPrecisionOps(HloComputation* comp) { +absl::StatusOr MoveConvertPrecisionOps(HloComputation* comp) { bool changed = false; // Move increase_precision "down" the graph: @@ -196,7 +196,7 @@ StatusOr MoveConvertPrecisionOps(HloComputation* comp) { } // anonymous namespace -StatusOr ConvertMover::Run( +absl::StatusOr ConvertMover::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/convert_mover.h b/xla/service/convert_mover.h index 0c47e2c97843a..57ae70de0bb94 100644 --- a/xla/service/convert_mover.h +++ b/xla/service/convert_mover.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,8 @@ class ConvertMover : public HloModulePass { ConvertMover() = default; absl::string_view name() const override { return "convert-mover"; } - StatusOr Run( + using HloPassInterface::Run; + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/convert_mover_test.cc b/xla/service/convert_mover_test.cc index 21bb2ba77a287..cda7a4eb04046 100644 --- a/xla/service/convert_mover_test.cc +++ b/xla/service/convert_mover_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/convert_operand_folding.cc b/xla/service/convert_operand_folding.cc index daa28c0fdb7cb..97760f1605527 100644 --- a/xla/service/convert_operand_folding.cc +++ b/xla/service/convert_operand_folding.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,70 @@ limitations under the License. #include "xla/service/convert_operand_folding.h" -#include "absl/base/attributes.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" namespace xla { namespace { bool IsUpcastConvert(const HloInstruction* hlo) { - if (hlo->opcode() != HloOpcode::kConvert) { + if (!hlo->shape().IsArray()) { return false; } - return primitive_util::CastPreservesValues( - hlo->operand(0)->shape().element_type(), hlo->shape().element_type()); + switch (hlo->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: + case HloOpcode::kReshape: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: { + return IsUpcastConvert(hlo->operand(0)); + } + case HloOpcode::kReduce: { + if (ShapeUtil::ElementsIn(hlo->shape()) == + ShapeUtil::ElementsIn(hlo->operand(0)->shape())) { + return IsUpcastConvert(hlo->operand(0)); + } + return false; + } + case HloOpcode::kConvert: + return primitive_util::CastPreservesValues( + hlo->operand(0)->shape().element_type(), hlo->shape().element_type()); + default: + return false; + } +} + +HloInstruction* EffectiveOperand(HloInstruction* hlo) { + switch (hlo->opcode()) { + case HloOpcode::kBroadcast: + case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: + case HloOpcode::kReshape: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: { + HloInstruction* operand = EffectiveOperand(hlo->mutable_operand(0)); + HloInstruction* clone = hlo->AddInstruction(hlo->Clone()); + *(clone->mutable_shape()) = ShapeUtil::ChangeElementType( + clone->shape(), operand->shape().element_type()); + clone->ReplaceOperandWithDifferentShape(0, operand).IgnoreError(); + return clone; + } + case HloOpcode::kReduce: { + // Reduce is a reshape in the case the the hlo chain was an upcast. + HloInstruction* operand = EffectiveOperand(hlo->mutable_operand(0)); + return hlo->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::ChangeElementType(hlo->shape(), + operand->shape().element_type()), + operand)); + } + case HloOpcode::kConvert: + return hlo->mutable_operand(0); + default: + return nullptr; + } } } // namespace @@ -46,13 +97,13 @@ bool ConvertOperandFolding::InstructionMatchesPattern( return false; } -StatusOr ConvertOperandFolding::ExpandInstruction( +absl::StatusOr ConvertOperandFolding::ExpandInstruction( HloInstruction* instruction) { for (int i = 0; i < instruction->operand_count(); ++i) { auto* operand = instruction->mutable_operand(i); if (IsUpcastConvert(operand)) { TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape( - i, operand->mutable_operand(0))); + i, EffectiveOperand(operand))); } } return nullptr; diff --git a/xla/service/convert_operand_folding.h b/xla/service/convert_operand_folding.h index 739d941e20bd4..5ab6750f8f49e 100644 --- a/xla/service/convert_operand_folding.h +++ b/xla/service/convert_operand_folding.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ class ConvertOperandFolding : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/convert_operand_folding_test.cc b/xla/service/convert_operand_folding_test.cc index c1a8b5d459cfb..95841a369a3a4 100644 --- a/xla/service/convert_operand_folding_test.cc +++ b/xla/service/convert_operand_folding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -142,5 +142,73 @@ TEST_F(ConvertOperandFoldingTest, OneOperandFolded) { op::Shape("s16[2,2]{1,0}"))); } +TEST_F(ConvertOperandFoldingTest, FoldedWithFormatting) { + absl::string_view module_string = R"( + HloModule module + sum { + a = s16[] parameter(0) + b = s16[] parameter(1) + ROOT r = add(a,b) + } + + ENTRY main { + p0 = s8[3,10] parameter(0) + c0 = s16[3,10] convert(p0) + r0 = s16[3,2,5] reshape(c0) + t0 = s16[2,5,3] transpose(r0), dimensions={1,2,0} + s0 = s16[2,1,3] slice(t0), slice={[0:2], [2:3], [0:3]} + rs0 = s16[2,3] reshape(s0) + p1 = s8[3,1,2] parameter(1) + c1 = s16[3,1,2] convert(p1) + r1 = s16[1,3,2] transpose(c1), dimensions={1,0,2} + z = s16[] constant(0) + rr1 = s16[3,2] reduce(r1,z), dimensions={0}, to_apply=sum + ROOT dot = s16[2,2] dot(rs0, rr1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool folded, + ConvertOperandFolding().Run(module.get())); + EXPECT_TRUE(folded); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Dot( + op::Reshape(op::Slice(op::Transpose(op::Reshape(op::Parameter(0))))), + op::Reshape(op::Transpose(op::Parameter(1))))); +} + +TEST_F(ConvertOperandFoldingTest, FoldedWithDSAndGather) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = s8[100,3] parameter(0) + c0 = s16[100,3] convert(p0) + ids = s32[20] parameter(2) + g = s16[20,3] gather(c0, ids), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,3} + t = s16[3,20] transpose(g), dimensions={1,0} + + p1 = s8[25,3] parameter(1) + c1 = s16[25,3] convert(p1) + z = s32[] constant(0) + s = s32[] parameter(3) + ds = s16[20,3] dynamic-slice(c1, s, z), dynamic_slice_sizes={20,3} + + ROOT dot = s16[3,3] dot(t, ds), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool folded, + ConvertOperandFolding().Run(module.get())); + EXPECT_TRUE(folded); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Dot(op::Transpose(op::Gather(op::Parameter(0), op::Parameter(2))), + op::DynamicSlice(op::Parameter(1), op::Parameter(3), + op::Constant()))); +} + } // namespace } // namespace xla diff --git a/xla/service/convolution_4d_expander.cc b/xla/service/convolution_4d_expander.cc index 742f85c589c63..594e77434cd55 100644 --- a/xla/service/convolution_4d_expander.cc +++ b/xla/service/convolution_4d_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ bool Convolution4DExpander::InstructionMatchesPattern( return false; } -StatusOr Convolution4DExpander::ExpandInstruction( +absl::StatusOr Convolution4DExpander::ExpandInstruction( HloInstruction* instruction) { HloComputation* computation = instruction->parent(); ConvolutionDimensionNumbers dim_nums = diff --git a/xla/service/convolution_4d_expander.h b/xla/service/convolution_4d_expander.h index baa4ce43edfa8..dd4ed80f1e1cc 100644 --- a/xla/service/convolution_4d_expander.h +++ b/xla/service/convolution_4d_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class Convolution4DExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/convolution_4d_expander_test.cc b/xla/service/convolution_4d_expander_test.cc index 41121b09dd63a..547a5a761cb9d 100644 --- a/xla/service/convolution_4d_expander_test.cc +++ b/xla/service/convolution_4d_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/convolution_group_converter.cc b/xla/service/convolution_group_converter.cc index e42878e195044..4dfba405d7a7f 100644 --- a/xla/service/convolution_group_converter.cc +++ b/xla/service/convolution_group_converter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -674,7 +674,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { } // namespace -StatusOr ConvolutionGroupConverter::Run( +absl::StatusOr ConvolutionGroupConverter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/convolution_group_converter.h b/xla/service/convolution_group_converter.h index a55633effe73d..e41b4e711b8c0 100644 --- a/xla/service/convolution_group_converter.h +++ b/xla/service/convolution_group_converter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ class ConvolutionGroupConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/convolution_group_converter_test.cc b/xla/service/convolution_group_converter_test.cc index a0d73a53bb242..16f7dcbd49acf 100644 --- a/xla/service/convolution_group_converter_test.cc +++ b/xla/service/convolution_group_converter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/convolution_pred_expander.cc b/xla/service/convolution_pred_expander.cc index 8f223353e86f1..f5a828684bc23 100644 --- a/xla/service/convolution_pred_expander.cc +++ b/xla/service/convolution_pred_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ bool ConvolutionPredExpander::InstructionMatchesPattern( .WithElementType(PRED)); } -StatusOr ConvolutionPredExpander::ExpandInstruction( +absl::StatusOr ConvolutionPredExpander::ExpandInstruction( HloInstruction* instruction) { HloComputation* computation = instruction->parent(); diff --git a/xla/service/convolution_pred_expander.h b/xla/service/convolution_pred_expander.h index df6ea22ca6683..121a33de56975 100644 --- a/xla/service/convolution_pred_expander.h +++ b/xla/service/convolution_pred_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ class ConvolutionPredExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/convolution_pred_expander_test.cc b/xla/service/convolution_pred_expander_test.cc index ff99afd44c1d6..f3c95f8a7cd28 100644 --- a/xla/service/convolution_pred_expander_test.cc +++ b/xla/service/convolution_pred_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/copy_insertion.cc b/xla/service/copy_insertion.cc index 546f25e431eb2..5e99bcb5fa4a2 100644 --- a/xla/service/copy_insertion.cc +++ b/xla/service/copy_insertion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -132,7 +133,7 @@ bool ShouldCopyRootValue(const HloValue& value, // \ / // Tuple // -StatusOr> +absl::StatusOr> DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, const ShapeTree& indices_to_copy) { DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); @@ -481,6 +482,13 @@ class LiveRangeRegions { // create a new HloValue aliasing without defining a new value (cannot be // value_definition). bool is_definition; + + std::string ToString() const { + return absl::StrCat( + "is_definition: ", std::to_string(is_definition), + ", value_definition: ", + value_definition ? value_definition->name() : "nullptr"); + } }; // Map instructions that use a value to the defining instruction of the value. // Because all values must belong to the same live range, an instruction can @@ -532,6 +540,20 @@ class LiveRangeRegions { return instr_map.find(instr) != instr_map.end(); } + std::string ToString() const { + std::string result; + + for (const auto* computation : computation_vector_) { + StrAppend(&result, "computation: ", computation->name(), "\n"); + for (const auto& entry : computation_map_.at(computation)) { + StrAppend(&result, " entry: ", entry.first->name(), ", ", + entry.second.ToString(), "\n"); + } + } + + return result; + } + private: ComputationMap computation_map_; absl::InlinedVector computation_vector_; @@ -723,7 +745,7 @@ class ComputeRelativeLocation { typedef LiveRangeRegions::InstructionEntry InstructionEntry; explicit ComputeRelativeLocation(HloOrdering* ordering) : ordering_(ordering) { - VLOG(3) << "New analysis\n"; + VLOG(3) << "New analysis"; } // Compute locationing constraints between two instructions. Here entry2 is @@ -752,7 +774,7 @@ class ComputeRelativeLocation { // computation, then any modification is considered intercepting. if (def->opcode() == HloOpcode::kParameter && use == use->parent()->root_instruction()) { - VLOG(3) << "Setting interception due to parameter/root relation\n"; + VLOG(3) << "Setting interception due to parameter/root relation"; return Relation(order, true); } @@ -781,11 +803,61 @@ class ComputeRelativeLocation { return Relation(order, false); } + // Special case for conditional instruction when in one branch two results + // can be put in one buffers and another branch returns two results from a + // multi-output instruction, e.g. fusion or variadic reduction. + // + // branch_0 { + // exp = f64[] exp(...) + // ROOT tuple = (f64[], f64[]) tuple(exp, exp) + // } + // + // fused_computation { + // abs = f64[] abs(...) + // negate = f64[] negate(...) + // ROOT tuple = (f64[], f64[]) tuple(abs, negate) + // } + // + // branch_1 { + // ROOT fusion = (f64[], f64[]) fusion(...), calls=%fused_computation + // } + // + // ENTRY main { + // ROOT root = (f64[], f64[]) conditional(...), + // branch_computations={%branch_0, %branch_1} + // } + // + // `branch_0` can use one buffer for both result. `branch_1` must use two + // different buffers. + // + // During live range analysis of results of `branch_0` this function will be + // called when entry1 and entry2 are different outputs on `fusion` in + // `branch_1`. `fusion` defines two buffers, but `value_definition` in + // LiveRangeRegions::InstructionInfo does not track the output index. The + // analysis will say that they are not interfering and assign the same + // buffer to both. + // + // This check makes sure that outputs of multi-output instructions are + // always interfering and can not be combined. It can be a false positive + // when entry1 and entry2 correspond to the same output, but we prefer that + // over correctness issues. + // + // A proper solution would be to track output index in + // LiveRangeRegions::InstructionInfo. + if (use->parent() == def->parent() && + def->parent()->IsConditionalBranchComputation() && + def == entry2.first && def->shape().IsTuple()) { + VLOG(3) << "Setting interception for multi-output instruction inside " + "conditional branch: " + << def->name(); + return Relation(order, true); + } + if (Relation::UseImpliesInterception(order)) { auto order2 = ComputeRuntimeOrdering(entry2.first, def); if (Relation::DefinitionImpliesInterception(order2)) { VLOG(3) << "Setting interception for " << def->ToString() - << " with use:" << entry1.first->ToString() << "\n"; + << " with use: " << entry1.first->ToString(); intercept = true; } } @@ -797,8 +869,7 @@ class ComputeRelativeLocation { Relation Compute(const LiveRangeRegions& range1, const LiveRangeRegions& range2) { Relation dir_src_dest; - for (int64_t index = 0; index < range1.size(); index++) { - auto* computation1 = range1.Computation(index); + for (const auto* computation1 : range1) { for (const auto* computation2 : range2) { for (auto instr_entry2 : range2[computation2]) { if (!ordering_->call_graph().Dominates(computation1, computation2)) { @@ -813,15 +884,15 @@ class ComputeRelativeLocation { bool unordered_intercept = false; for (auto instr_entry1 : range1[computation1]) { auto rel = Compute(instr_entry1, instr_entry2, instr2_can_modify); - VLOG(3) << "new relation with:" << instr_entry1.first->ToString() - << " = " << rel.ToString() << "\n"; + VLOG(3) << "New relation with " << instr_entry1.first->name() + << ": " << rel.ToString(); if (!rel.RuntimeOrderIsUnordered()) { instr2_relation.UnionRelationFromSameSource(rel); } else { unordered_ops.push_back(instr_entry1); unordered_intercept |= rel.InterceptDefUse(); } - VLOG(3) << "instr2 relation:" << instr2_relation.ToString() << "\n"; + VLOG(3) << "instr2 relation: " << instr2_relation.ToString(); } // Here instru2_relation is guaranteed to have at most a single entry, // because it was initialized to be empty, and has been updated only @@ -829,12 +900,12 @@ class ComputeRelativeLocation { // maintains that the updated result has only a single entry. if (!ForceRuntimeOrder(unordered_ops, instr_entry2, instr2_relation.GetRuntimeOrder())) { - VLOG(3) << "Unable to force ordering of unordered ops\n"; + VLOG(3) << "Unable to force ordering of unordered ops"; instr2_relation.UnionRelationFromSameSource(Relation( Relation::kBeforeStartOrAfterEnd, unordered_intercept)); } dir_src_dest.UnionRelationFromDifferentSource(instr2_relation); - VLOG(3) << "Resulting relation : " << dir_src_dest.ToString() << "\n"; + VLOG(3) << "Resulting relation: " << dir_src_dest.ToString(); } } } @@ -859,8 +930,8 @@ class ComputeRelativeLocation { for (const auto& instr_it : comp_it.second) { HloInstruction* entry1 = instr_it.first; for (HloInstruction* entry2 : instr_it.second) { - VLOG(3) << "Add control dependence between " << entry2->ToString(); - VLOG(3) << "\n vs " << entry1->ToString() << "\n"; + VLOG(3) << "Add control dependence between " << entry2->name() + << " vs " << entry1->name(); TF_CHECK_OK(entry2->AddControlDependencyTo(entry1)); } reachability_map.UpdateReachabilityThroughInstruction(entry1); @@ -925,8 +996,7 @@ class ComputeRelativeLocation { if (succ->opcode() == HloOpcode::kCopy && ModifiesNonCopy(pred, succ->operand(0))) { VLOG(3) << "Failed to force unordered op ordering due to copy ordering " - << " between " << pred->ToString() << "\n"; - VLOG(3) << " vs. " << succ->ToString() << "\n"; + << " between " << pred->name() << " vs " << succ->name(); return false; } } @@ -1056,8 +1126,8 @@ class ComputeRelativeLocation { (relation == Relation::kBeforeStart) ? entry1 : entry2; HloInstruction* succ = (relation == Relation::kBeforeStart) ? entry2 : entry1; - VLOG(3) << "Save unordered relation: " << pred->ToString() << "\n"; - VLOG(3) << " vs " << succ->ToString() << "\n"; + VLOG(3) << "Save unordered relation: " << pred->name() << " vs " + << succ->name(); CHECK_EQ(succ->parent(), pred->parent()); auto& dep_vec = ctrl_deps_[succ->parent()][succ]; for (HloInstruction*& op : dep_vec) { @@ -1071,8 +1141,8 @@ class ComputeRelativeLocation { return relation; } } - VLOG(2) << "Forcing unordered:" << pred->ToString() << "\n"; - VLOG(2) << " vs " << succ->ToString() << "\n"; + VLOG(2) << "Forcing unordered: " << pred->name() << " vs " + << succ->name(); dep_vec.push_back(pred); } return relation; @@ -1083,8 +1153,8 @@ class ComputeRelativeLocation { HloInstruction* instr2) { auto saved_relation = AlreadyComputed(instr1, instr2); if (saved_relation.first != kNotComputed) { - VLOG(3) << "Already computed between " << instr1->ToString() << "\n vs " - << instr2->ToString() << "\n"; + VLOG(3) << "Already computed between " << instr1->name() << " vs " + << instr2->name(); return saved_relation.second; } auto constraint = ordering_->GetExecutionConstraint(instr1, instr2); @@ -1120,15 +1190,15 @@ class ComputeRelativeLocation { if (absl::c_any_of(ctrl_deps[instr2], [&](HloInstruction* pred2) { return ControlDependenceBefore(instr1, pred2); })) { - VLOG(2) << "control-dependent: " << instr1->ToString() << "\n"; - VLOG(2) << "vs " << instr2->ToString() << "\n"; + VLOG(2) << "control-dependent: " << instr1->name() << " vs " + << instr2->name(); return Save(instr1, instr2, Relation::kBeforeStart); } else if (absl::c_any_of( ctrl_deps[instr1], [&](HloInstruction* pred1) { return ControlDependenceBefore(instr2, pred1); })) { - VLOG(2) << "control-dependent: " << instr2->ToString() << "\n"; - VLOG(2) << "vs " << instr1->ToString() << "\n"; + VLOG(2) << "control-dependent: " << instr2->name() << " vs " + << instr1->name(); return Save(instr1, instr2, Relation::kAfterEnd); } } @@ -1489,15 +1559,15 @@ class CopyRemover { CHECK_NE(src, nullptr); CHECK_NE(dest, nullptr); if (!use_region_analysis) { - VLOG(2) << "Configured to not use region-based analysis.\n"; + VLOG(2) << "Configured to not use region-based analysis."; return true; } *region_analysis_limit += live_range_size1 * live_range_size2; if (ValuesInterfere(src, dest, option)) { - VLOG(2) << "Region-based interference is true. \n"; + VLOG(2) << "Region-based interference is true."; return true; } - VLOG(2) << "Region-based interference is false. \n"; + VLOG(2) << "Region-based interference is false."; return false; }; @@ -1576,7 +1646,7 @@ class CopyRemover { CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)) && // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src)); - VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n"; + VLOG(2) << "LiveRangeBefore result: " << live_range_before; if (!live_range_before && CheckLiveRangeInterference(copy_node.src, copy_node.dest, kMergeFirstDestInSource)) { @@ -1606,11 +1676,11 @@ class CopyRemover { CheckLiveRangeBefore(Prev(*copy_node.dest), copy_node.src->next) && // Live range of 'last_src' must be before next_dest d_{y+1}. CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)); - VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n"; + VLOG(2) << "LiveRangeBefore result: " << live_range_before; if (!live_range_before && CheckLiveRangeInterference(copy_node.src, copy_node.dest, kMergeLastSourceInDest)) { - VLOG(2) << "Region-based analysis concludes interference.\n"; + VLOG(2) << "Region-based analysis concludes interference."; return false; } VLOG(2) << "Splice src after prev of dest."; @@ -1683,8 +1753,8 @@ class CopyRemover { VLOG(2) << "Empty uses for " << *a.value; return ordering_->IsDefinedBefore(*a.value, *b.value); } - VLOG(3) << "Checking live ranges before :" << ValueListToString(&a) - << " vs " << ValueListToString(&b) << "\n"; + VLOG(3) << "Checking live ranges before: " << ValueListToString(&a) + << " vs " << ValueListToString(&b); // If any of the positions of the "a" value is a root of the same // computation as "b", "a"'s live range cannot be before "b"'s. This catches // the cases where the root may not be the last instruction in the @@ -1756,27 +1826,34 @@ class CopyRemover { // Get the entire range of values sharing the buffers in src and dest. auto src_live_range = ComputeLiveRangeRegions(src); auto dest_live_range = ComputeLiveRangeRegions(dest); + + VLOG(5) << "src value: " << src->value->ToString(); + VLOG(5) << "src live range:\n" << src_live_range.ToString(); + + VLOG(5) << "dest value: " << dest->value->ToString(); + VLOG(5) << "dest live range:\n" << dest_live_range.ToString(); + ComputeRelativeLocation relative_location_analysis(ordering_); auto rel1 = relative_location_analysis.Compute(src_live_range, dest_live_range); - VLOG(3) << "Location of dest in relation to src:" << rel1.ToString() - << " with interception set to " << rel1.InterceptDefUse() << "\n"; + VLOG(3) << "Location of dest in relation to src: " << rel1.ToString() + << " with interception set to " << rel1.InterceptDefUse(); auto rel2 = relative_location_analysis.Compute(dest_live_range, src_live_range); - VLOG(3) << "Location of src in relation to dest:" << rel2.ToString() - << " with interception set to " << rel1.InterceptDefUse() << "\n"; + VLOG(3) << "Location of src in relation to dest: " << rel2.ToString() + << " with interception set to " << rel2.InterceptDefUse(); // If src and dest are interleaved with each other, they interfere. if (rel1.RuntimeOrderOverlap() && rel2.RuntimeOrderOverlap()) { - VLOG(3) << "Both relations are overlap.\n"; + VLOG(3) << "Both relations are overlap."; return true; } // If src and dest belong to the same group of computations and do not // overlap, they do not interfere. if (rel1.RuntimeOrderOverlap() || rel2.RuntimeOrderOverlap()) { - VLOG(3) << "At least one relation is overlap.\n"; + VLOG(3) << "At least one relation is overlap."; if (rel1.RuntimeOrderOverlap()) { VLOG(3) << "rel1 is overlap, with interception = " - << rel1.InterceptDefUse() << "\n"; + << rel1.InterceptDefUse(); if (rel1.InterceptDefUse() || (merge_location != kMergeFirstDestInSource && rel2.InterceptDefUse())) { @@ -1784,7 +1861,7 @@ class CopyRemover { } } else { VLOG(3) << "rel2 is overlap, with interception = " - << rel2.InterceptDefUse() << "\n"; + << rel2.InterceptDefUse(); // Here src is at the end of a nested computation inside dest. if (rel2.InterceptDefUse() || (merge_location != kMergeLastSourceInDest && @@ -1927,6 +2004,9 @@ Status CopyInsertion::AddCopiesToResolveInterference( HloAliasAnalysis::Run(module, can_share_buffer_)); for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { + if (computation->IsAsyncComputation()) { + continue; + } for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { @@ -1940,7 +2020,13 @@ Status CopyInsertion::AddCopiesToResolveInterference( // have been copied. absl::flat_hash_set copied_operands; for (const auto& operand_and_output_index : - HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + HloDataflowAnalysis::GetInPlaceInputOutputPairs( + // Input/output buffer aliasing analysis needs to be done + // directly with the wrapped instruction when the compiler sees + // an async box. + instruction->opcode() == HloOpcode::kAsyncStart + ? instruction->async_wrapped_instruction() + : instruction)) { const HloOperandIndex& operand_index = operand_and_output_index.first; if (copied_operands.contains(operand_index.operand_number)) { continue; @@ -2162,7 +2248,8 @@ static int64_t GetNumExistingCopies( Status CopyInsertion::RemoveUnnecessaryCopies( HloModule* module, bool check_live_range_ordering, const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES(4, module->ToString()); + XLA_VLOG_LINES( + 4, module->ToString(HloPrintOptions().set_syntax_sugar_async_ops(false))); // Use SequentialHloOrdering if the module has a schedule. The schedule can // provide more information on the ordering, allowing for detecting more @@ -2192,7 +2279,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies( bool changed = true; int64_t num_iterations = -1; VLOG(6) << "Copy Insertion analyzing module with instruction count = " - << module->instruction_count() << "\n"; + << module->instruction_count(); BoundNonLinearCompilerAnalysis allowance(module, name(), 10); while (changed) { CHECK_LE(++num_iterations, num_existing_copies); @@ -2201,9 +2288,10 @@ Status CopyInsertion::RemoveUnnecessaryCopies( << " of copy elision"; for (HloComputation* computation : module->computations(execution_threads)) { - VLOG(2) << "computation:" << computation->name() << "\n"; + VLOG(2) << "computation:" << computation->name(); for (HloInstruction* instruction : computation->instructions()) { - VLOG(2) << instruction->ToString() << "\n"; + if (instruction->opcode() != HloOpcode::kCopy) continue; + // The region_analysis_cost_now is always set to // use_region_based_live_range_analysis_ if it is < 0, in which case the // analysis is always performed. @@ -2212,22 +2300,19 @@ Status CopyInsertion::RemoveUnnecessaryCopies( ? 0 : std::min(allowance.analysis_allowance(), use_region_based_live_range_analysis_); - if (instruction->opcode() == HloOpcode::kCopy) { - if (copy_remover.TryElideCopy(instruction, - ®ion_analysis_cost_now)) { - changed = true; - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( - instruction->mutable_operand(0))); - VLOG(6) << "succeeded in eliminating copy.\n"; - } - if (allowance.ContinueAnalysis() && region_analysis_cost_now > 0) { - VLOG(6) << "Copy Insertion analyzing module cost: " - << region_analysis_cost_now << "\n"; - VLOG(6) << "instruction:" << instruction->ToString() << "\n"; - allowance.DeductCost(region_analysis_cost_now); - VLOG(6) << "allowance:" << allowance.analysis_allowance() << "\n"; - } + if (copy_remover.TryElideCopy(instruction, ®ion_analysis_cost_now)) { + changed = true; + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + VLOG(6) << "succeeded in eliminating copy."; + } + if (allowance.ContinueAnalysis() && region_analysis_cost_now > 0) { + VLOG(6) << "Copy Insertion analyzing module cost: " + << region_analysis_cost_now; + VLOG(6) << "instruction:" << instruction->ToString(); + allowance.DeductCost(region_analysis_cost_now); + VLOG(6) << "allowance:" << allowance.analysis_allowance(); } } } @@ -2235,7 +2320,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies( return OkStatus(); } -StatusOr CopyInsertion::Run( +absl::StatusOr CopyInsertion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Copy insertion is performed in three steps: diff --git a/xla/service/copy_insertion.h b/xla/service/copy_insertion.h index 9cb49ddbb936d..a062e6d133d5d 100644 --- a/xla/service/copy_insertion.h +++ b/xla/service/copy_insertion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -63,7 +63,7 @@ class CopyInsertion : public HloModulePass { // Run the pass on the given module. Returns whether the module was changed // (copies were inserted). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/copy_insertion_test.cc b/xla/service/copy_insertion_test.cc index ab39471fb91d6..5250e9842895d 100644 --- a/xla/service/copy_insertion_test.cc +++ b/xla/service/copy_insertion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,34 @@ limitations under the License. #include "xla/service/copy_insertion.h" +#include #include -#include +#include +#include +#include #include #include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" -#include "xla/service/hlo_runner.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" @@ -42,6 +51,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +using ::testing::NotNull; using ::testing::UnorderedElementsAre; int64_t CountCopies(const HloComputation& computation) { @@ -3160,6 +3170,113 @@ ENTRY TestComputation { CHECK_EQ(tuple6->opcode(), HloOpcode::kParameter); } +TEST_F(CopyInsertionTest, ConditionalWithMultiOutputFusion) { + const std::string& hlo_string = R"( +HloModule TestModule + +branch_0 { + param_0 = f64[] parameter(0) + negate.2 = f64[] negate(f64[] param_0) + ROOT tuple = (f64[], f64[]) tuple(f64[] negate.2, f64[] negate.2) +} + +fused_computation { + param_0.1 = f64[] parameter(0) + abs.2 = f64[] abs(f64[] param_0.1) + negate.1 = f64[] negate(f64[] param_0.1) + ROOT %tuple.2 = (f64[], f64[]) tuple(f64[] negate.1, f64[] abs.2) +} + +branch_1 { + param_0.2 = f64[] parameter(0) + ROOT fusion = (f64[], f64[]) fusion(f64[] param_0.2), kind=kLoop, calls=%fused_computation +} + +ENTRY main { + pred.0 = s32[] parameter(0) + param_1 = f64[] parameter(1) + param_2 = f64[] parameter(2) + ROOT conditional.0 = (f64[], f64[]) conditional(s32[] pred.0, f64[] param_1, f64[] param_2), branch_computations={%branch_0, %branch_1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(3) << module->ToString(); + + // `branch_0` returns the same result of `negate.2` twice. Normally the result + // would be put into one buffer and tuple would return two pointers to the + // same buffer. + // `branch_1` returns two results of multi-output fusion that should be put + // into different buffers. + // One copy is inserted in `branch_0` to ensure that result are put into two + // different buffers. + EXPECT_EQ(CountCopies(*module->GetComputationWithName("branch_0")), 1); + + EXPECT_EQ(CountCopies(*module->GetComputationWithName("branch_1")), 0); + EXPECT_EQ(CountCopies(*module->GetComputationWithName("main")), 0); +} + +TEST_F(CopyInsertionTest, ConditionalWithVariadicReduce) { + const std::string& hlo_string = R"( +HloModule TestModule + +branch_0 { + empty_tuple.0 = () parameter(0) + c_0 = f64[] constant(0) + ROOT tuple.3 = (f64[], f64[]) tuple(c_0, c_0) +} + +fused_computation { + param_0.1 = f64[] parameter(0) + abs.2 = f64[] abs(f64[] param_0.1) + negate.1 = f64[] negate(f64[] param_0.1) + ROOT %tuple.2 = (f64[], f64[]) tuple(f64[] negate.1, f64[] abs.2) +} + +reduce_region { + param_0.0 = f64[] parameter(0) + param_2.0 = f64[] parameter(2) + add.1.0 = f64[] add(param_0.0, param_2.0) + param_1.0 = f64[] parameter(1) + param_3.0 = f64[] parameter(3) + multiply.1.0 = f64[] multiply(param_1.0, param_3.0) + ROOT tuple.0.0 = (f64[], f64[]) tuple(add.1.0, multiply.1.0) +} + +branch_1 { + c_0 = f64[] constant(0) + param_0.1 = f64[128]{0} parameter(0) + ROOT reduce = (f64[], f64[]) reduce(param_0.1, param_0.1, c_0, c_0), dimensions={0}, to_apply=reduce_region +} + +ENTRY main { + pred.0 = s32[] parameter(0) + empty_tuple = () tuple() + param_2 = f64[128] parameter(1), sharding={replicated} + ROOT conditional.0 = (f64[], f64[]) conditional(s32[] pred.0, () empty_tuple, f64[128] param_2), branch_computations={%branch_0, %branch_1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(3) << module->ToString(); + + // `branch_0` returns the same constant twice. Without copies it would return + // pointers to read-only buffer for constant. + // `branch_1` returns two results of a that should be put into different + // buffers. + // `conditional` needs to buffers for results, so the constant in `branch_0` + // should be copied twice. + EXPECT_EQ(CountCopies(*module->GetComputationWithName("branch_0")), 2); + EXPECT_EQ(CountCopies(*module->GetComputationWithName("branch_1")), 0); + EXPECT_EQ(CountCopies(*module->GetComputationWithName("main")), 0); +} + TEST_F(CopyInsertionTest, RootInstructionNotLast) { // This is a test for b/189219227. When the root instruction is scheduled not // as the last instruction, it still lives out. So, we make sure that the copy @@ -3441,8 +3558,8 @@ HloModule async_call ENTRY %main { %input.1 = s32[1024]{0} parameter(0) %buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer" - %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped - ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped + %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_execution_thread="foobar", calls=%async_wrapped + ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_execution_thread="foobar", calls=%async_wrapped } )"; @@ -3480,8 +3597,8 @@ HloModule async_call ENTRY %main { %input.1 = s32[1024]{0} parameter(0) %input.2 = s32[1024]{0} parameter(1) - %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %input.2), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped - ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped + %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %input.2), async_execution_thread="foobar", calls=%async_wrapped + ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_execution_thread="foobar", calls=%async_wrapped } )"; @@ -3649,5 +3766,108 @@ ENTRY main { EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kGetTupleElement); } +TEST_F(CopyInsertionTest, DontInsertCopiesInAsyncComputation) { + constexpr absl::string_view kModuleString = R"( +HloModule test + +%async_computation { + %param_0 = f32[10,32,512]{2,1,0:T(8,128)S(5)} parameter(0) + %param_1 = f32[1,32,512]{2,1,0:T(8,128)} parameter(1) + %param_2 = s32[]{:T(128)} parameter(2) + %param_3 = s32[]{:T(128)} parameter(3) + %param_4 = s32[]{:T(128)} parameter(4) + ROOT %dynamic-update-slice.1 = f32[10,32,512]{2,1,0:T(8,128)S(5)} + dynamic-update-slice(%param_0, %param_1, %param_2, %param_3, %param_4) +} + +ENTRY %main { + %param.1 = (s32[]{:T(128)}, f32[32,512]{1,0:T(8,128)}, + f32[10,32,512]{2,1,0:T(8,128)S(5)}) parameter(0) + %get-tuple-element.132 = f32[10,32,512]{2,1,0:T(8,128)S(5)} get-tuple-element( + %param.1), index=2 + %get-tuple-element.131 = f32[32,512]{1,0:T(8,128)} get-tuple-element( + %param.1), index=1 + %cosine.0 = f32[32,512]{1,0:T(8,128)} cosine(%get-tuple-element.131) + %reshape.6 = f32[1,32,512]{2,1,0:T(8,128)} reshape(%cosine.0) + %get-tuple-element.130 = s32[]{:T(128)} get-tuple-element(%param.1), index=0 + %constant.49 = s32[]{:T(128)} constant(0) + %compare.13 = pred[]{:T(512)} compare( + %get-tuple-element.130, %constant.49), direction=LT + %constant.50 = s32[]{:T(128)} constant(10) + %add.22 = s32[]{:T(128)} add(%get-tuple-element.130, %constant.50) + %select.6 = s32[]{:T(128)} select( + %compare.13, %add.22, %get-tuple-element.130) + %dynamic-update-slice-start = ( + (f32[10,32,512]{2,1,0:T(8,128)S(5)}, f32[1,32,512]{2,1,0:T(8,128)}, + s32[]{:T(128)}, s32[]{:T(128)}, s32[]{:T(128)}), + f32[10,32,512]{2,1,0:T(8,128)S(5)}, u32[]) async-start( + %get-tuple-element.132, %reshape.6, %select.6, + %constant.49, %constant.49), calls=%async_computation + ROOT %dynamic-update-slice-done = f32[10,32,512]{2,1,0:T(8,128)S(5)} + async-done(%dynamic-update-slice-start), calls=%async_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + LOG(INFO) << module->ToString(); + + auto* async_computation = module->GetComputationWithName("async_computation"); + ASSERT_THAT(async_computation, NotNull()); + EXPECT_EQ(CountCopies(*async_computation), 0); + + auto* main_computation = module->GetComputationWithName("main"); + ASSERT_THAT(main_computation, NotNull()); + EXPECT_EQ(CountCopies(*main_computation), 1); +} + +TEST_F(CopyInsertionTest, AsyncDUSInLoop) { + constexpr absl::string_view kModuleString = R"( +HloModule module + +async_wrapped { + async_param.1 = s32[1024]{0} parameter(0) + async_param.2 = s32[256]{0} parameter(1) + async_param.3 = s32[] parameter(2) + ROOT dus = s32[1024]{0} dynamic-update-slice(async_param.1, async_param.2, async_param.3) +} + +condition { + input_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0) + ROOT cond = pred[] get-tuple-element(input_tuple), index=3 +} + +body { + input_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) parameter(0) + input.1 = s32[1024]{0} get-tuple-element(input_tuple), index=0 + input.2 = s32[256]{0} get-tuple-element(input_tuple), index=1 + input.3 = s32[] get-tuple-element(input_tuple), index=2 + input.4 = pred[] get-tuple-element(input_tuple), index=3 + async-start = ((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) async-start(input.1, input.2, input.3), calls=%async_wrapped + async-done = s32[1024]{0} async-done(async-start), calls=async_wrapped + ROOT tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(async-done, input.2, input.3, input.4) +} + +ENTRY main { + input.1 = s32[256]{0} parameter(0) + input.2 = s32[] parameter(1) + input.3 = pred[] parameter(2) + broadcast = s32[1024]{0} broadcast(input.2), dimensions={} + while_tuple = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) tuple(broadcast, input.1, input.2, input.3) + while = (s32[1024]{0}, s32[256]{0}, s32[], pred[]) while(while_tuple), condition=condition, body=body + ROOT gte = s32[1024]{0} get-tuple-element(while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + EXPECT_EQ(CountCopies(*module), 0); +} + } // namespace } // namespace xla diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 77f04ded91208..cb62da6a94ae8 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -2,30 +2,38 @@ # LLVM-based CPU backend for XLA. load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@tsl//tsl:tsl.bzl", "internal_visibility", "tf_openmp_copts", "tsl_copts") +load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load( - "//xla:xla.bzl", - "ORC_JIT_MEMORY_MAPPER_TARGETS", - "xla_cc_binary", - "xla_cc_test", + "@tsl//tsl/platform:build_config_root.bzl", + "if_llvm_aarch64_available", + "if_llvm_powerpc_available", + "if_llvm_system_z_available", + "if_llvm_x86_available", ) +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//third_party/compute_library:build_defs.bzl", "acl_deps", "if_enable_acl", ) -load("@tsl//tsl:tsl.bzl", "tf_openmp_copts", "tsl_copts") -load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( - "@tsl//tsl/mkl:build_defs.bzl", + "//xla:xla.bzl", + "ORC_JIT_MEMORY_MAPPER_TARGETS", + "xla_cc_binary", + "xla_cc_test", +) +load( + "//xla/tsl/mkl:build_defs.bzl", "mkl_deps", ) -load("@tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_proto_library") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load(":build_defs.bzl", "runtime_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -45,13 +53,22 @@ filegroup( ]), ) +bool_flag( + name = "experimental_mlir_gpu", + build_setting_default = False, +) + +config_setting( + name = "experimental_mlir_gpu_enabled", + flag_values = { + ":experimental_mlir_gpu": "True", + }, +) + cc_library( name = "test_header_helper", testonly = True, hdrs = ["test_target_triple_helper.h"], - deps = [ - "@tsl//tsl/platform:test", - ], ) # When using mlir based HloLowering, the following utils will sometimes be needed to define used symbols. @@ -95,7 +112,7 @@ filegroup( "runtime_matmul_s32.cc", "runtime_fork_join.cc", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) filegroup( @@ -122,7 +139,7 @@ filegroup( "runtime_lightweight_check.h", "runtime_matmul.h", ], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), ) cc_library( @@ -139,13 +156,10 @@ cc_library( "//xla:statusor", "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/service:hlo_cost_analysis", "//xla/service:shaped_buffer", "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:notification", @@ -171,9 +185,9 @@ cc_library( "//xla/service:generic_transfer_manager", "//xla/service:transfer_manager", "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/base", - "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -203,6 +217,7 @@ cc_library( ":compiler_functor", ":conv_canonicalization", ":cpu_executable", + ":cpu_float_support", ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", @@ -211,7 +226,8 @@ cc_library( ":hlo_xla_runtime_pipeline", ":ir_emission_utils", ":ir_emitter", - ":onednn_rewriter", + ":onednn_matmul_rewriter", + ":onednn_ops_rewriter", ":parallel_task_assignment", ":simple_orc_jit", ":target_machine_features", @@ -238,14 +254,12 @@ cc_library( "//xla/mlir/runtime/transforms:jit_compiler", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:lhlo", "//xla/mlir_hlo:mhlo_passes", "//xla/mlir_hlo:transforms_passes", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", - "//xla/service:all_gather_decomposer", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", "//xla/service:batch_dot_simplification", @@ -299,9 +313,9 @@ cc_library( "//xla/service:map_inliner", "//xla/service:operand_upcaster", "//xla/service:optimization_barrier_expander", + "//xla/service:optimize_input_output_buffer_alias", "//xla/service:qr_expander", "//xla/service:reduce_decomposer", - "//xla/service:reduce_scatter_decomposer", "//xla/service:reshape_decomposer", "//xla/service:reshape_mover", "//xla/service:result_caster", @@ -311,6 +325,7 @@ cc_library( "//xla/service:select_and_scatter_expander", "//xla/service:sharding_propagation", "//xla/service:sharding_remover", + "//xla/service:simplify_fp_conversions", "//xla/service:slice_sinker", "//xla/service:slow_operation_alarm", "//xla/service:sort_simplifier", @@ -344,6 +359,7 @@ cc_library( "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -357,7 +373,6 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TargetParser", - "@llvm-project//llvm:X86CodeGen", # fixdeps: keep "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", @@ -387,20 +402,14 @@ cc_library( "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", - ] + select({ - "@tsl//tsl:arm_any": [ - "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep - ], - "@tsl//tsl:linux_ppc64le": [ - "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep - ], - "@tsl//tsl:macos_arm64": [ - "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep - ], - "//conditions:default": [ - ], - }) + if_llvm_system_z_available([ + ] + if_llvm_aarch64_available([ + "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep + ]) + if_llvm_powerpc_available([ + "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep + ]) + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep + ]) + if_llvm_x86_available([ + "@llvm-project//llvm:X86CodeGen", # fixdeps: keep ]), ) @@ -415,10 +424,12 @@ cc_library( "cpu_compiler_pure", ":executable_proto_cc", ":target_machine_features", + ":xla_framework", "//xla:cpu_function_runtime", "//xla:status", "//xla:statusor", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/service:buffer_assignment", @@ -442,6 +453,7 @@ tf_proto_library( protodeps = [ ":xla_framework_proto", "//xla/service:hlo_proto", + "//xla:xla_proto", ], ) @@ -461,6 +473,10 @@ cc_library( name = "hlo_xla_runtime_pipeline", srcs = ["hlo_xla_runtime_pipeline.cc"], hdrs = ["hlo_xla_runtime_pipeline.h"], + local_defines = select({ + ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"], + "//conditions:default": [], + }), deps = [ "//xla:status", "//xla/mlir/backends/cpu/transforms:passes", @@ -475,8 +491,6 @@ cc_library( "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", @@ -486,6 +500,7 @@ cc_library( "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:SparseTensorTransforms", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorToLinalg", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:Transforms", @@ -494,7 +509,13 @@ cc_library( "@llvm-project//mlir:VectorTransforms", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - ], + ] + select({ + ":experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + ], + "//conditions:default": [], + }), alwayslink = 1, # has pipeline registration ) @@ -510,7 +531,9 @@ cc_library( deps = [ ":compiler_functor", ":cpu_runtime", + ":onednn_layer_norm", ":onednn_matmul", + ":onednn_softmax", ":orc_jit_memory_mapper", ":runtime_conv2d", ":runtime_conv2d_acl", @@ -532,7 +555,9 @@ cc_library( "//xla:types", "//xla:util", "//xla/service:custom_call_target_registry", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:MC", # fixdeps: keep @@ -671,6 +696,7 @@ cc_library( "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/service/llvm_ir:math_ops", "//xla/service/llvm_ir:tuple_ops", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -745,7 +771,6 @@ cc_library( "//xla/service/llvm_ir:loop_emitter", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Core", - "@tsl//tsl/platform:logging", ], ) @@ -761,7 +786,6 @@ cc_library( "//xla/service/llvm_ir:kernel_support_library", "//xla/service/llvm_ir:llvm_util", "@llvm-project//llvm:Core", - "@tsl//tsl/platform:logging", ], ) @@ -845,10 +869,9 @@ cc_library( "//xla/runtime:execution_engine", "//xla/service:llvm_compiler", "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/functional:any_invocable", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Core", - "@llvm-project//llvm:IPO", "@llvm-project//llvm:Instrumentation", "@llvm-project//llvm:MC", "@llvm-project//llvm:Object", @@ -873,11 +896,13 @@ cc_library( copts = runtime_copts(), deps = [ ":collectives_interface", + ":cpu_executable_run_options", ":in_process_collectives", "//xla:executable_run_options", "//xla:shape_util", "//xla:statusor", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -886,7 +911,10 @@ cc_library( "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -928,12 +956,11 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@tsl//tsl/framework/contraction:eigen_contraction_kernel", "@tsl//tsl/framework/convolution:eigen_helpers", - "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -949,12 +976,11 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@tsl//tsl/framework/contraction:eigen_contraction_kernel", "@tsl//tsl/framework/convolution:eigen_helpers", - "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -966,7 +992,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/service:custom_call_status_internal", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -1027,7 +1052,7 @@ cc_library( "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@tsl//tsl/framework/contraction:eigen_contraction_kernel", - "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -1041,6 +1066,7 @@ cc_library( ":runtime_lightweight_check", ":runtime_matmul", "//xla:executable_run_options", + "@com_google_absl//absl/base", "@eigen_archive//:eigen3", "@tsl//tsl/platform:dynamic_annotations", "@tsl//tsl/platform:logging", @@ -1058,8 +1084,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":runtime_conv2d", + ":runtime_lightweight_check", ":runtime_single_threaded_conv2d", "//xla:executable_run_options", + "@com_google_absl//absl/base", "@eigen_archive//:eigen3", "@tsl//tsl/framework/convolution:eigen_helpers", "@tsl//tsl/platform:dynamic_annotations", @@ -1078,12 +1106,11 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - ":runtime_lightweight_check", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@tsl//tsl/framework/contraction:eigen_contraction_kernel", "@tsl//tsl/framework/convolution:eigen_helpers", + "@tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -1097,12 +1124,11 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - ":runtime_lightweight_check", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@eigen_archive//:eigen3", "@tsl//tsl/framework/contraction:eigen_contraction_kernel", "@tsl//tsl/framework/convolution:eigen_helpers", + "@tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -1139,6 +1165,7 @@ cc_library( deps = [ "@com_google_absl//absl/base:core_headers", "@eigen_archive//:eigen3", + "@tsl//tsl/framework/contraction:eigen_contraction_kernel_no_mkl", ], ) @@ -1150,9 +1177,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":runtime_single_threaded_matmul_impl", - "@com_google_absl//absl/base:core_headers", "@eigen_archive//:eigen3", - "@tsl//tsl/framework/contraction:eigen_contraction_kernel", ], ) @@ -1201,14 +1226,12 @@ cc_library( deps = [ "//xla:executable_run_options", "//xla/service:custom_call_status_internal", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:mutex", ], ) @@ -1225,11 +1248,9 @@ xla_cc_test( ":runtime_single_threaded_matmul", "//xla:array2d", "//xla:types", - "//xla:util", "//xla/client:local_client", "//xla/service:custom_call_status_internal", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", "@tsl//tsl/platform:env", @@ -1241,6 +1262,7 @@ xla_cc_test( xla_cc_test( name = "cpu_instruction_fusion_test", srcs = ["cpu_instruction_fusion_test.cc"], + tags = ["no_aarch64"], deps = [ ":cpu_instruction_fusion", "//xla:shape_util", @@ -1251,7 +1273,6 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:logging", ], ) @@ -1275,7 +1296,6 @@ cc_library( srcs = ["cpu_instruction_fusion.cc"], hdrs = ["cpu_instruction_fusion.h"], deps = [ - ":ir_emission_utils", "//xla/hlo/ir:hlo", "//xla/service:fusion_node_indexing_evaluation", "//xla/service:instruction_fusion", @@ -1305,10 +1325,6 @@ xla_cc_test( ":ir_emission_utils", ":target_machine_features_fake", "//xla:test", - "//xla:test_helpers", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -1322,7 +1338,9 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features", + "//xla:shape_util", "//xla:util", + "//xla/hlo/ir:hlo", "//xla/service:computation_layout", "//xla/service:layout_assignment", "@com_google_absl//absl/container:flat_hash_map", @@ -1372,7 +1390,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", ], ) @@ -1445,23 +1462,10 @@ xla_cc_test( ":cpu_executable", ":parallel_task_assignment", ":target_machine_features_fake", - "//xla:literal", - "//xla:shape_layout", - "//xla:shape_util", "//xla:test", - "//xla:test_helpers", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:algebraic_simplifier", - "//xla/service:computation_layout", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", ], ) @@ -1472,7 +1476,6 @@ cc_library( deps = [ "//xla/service:hlo_module_config", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:logging", ], ) @@ -1502,7 +1505,6 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", - "@tsl//tsl/platform:logging", ], ) @@ -1533,7 +1535,6 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@llvm-project//llvm:Core", "@llvm-project//llvm:MC", - "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", ], ) @@ -1576,6 +1577,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":runtime_lightweight_check", + "//xla:literal", "//xla:shape_util", "//xla:status_macros", "//xla:statusor", @@ -1602,7 +1604,55 @@ cc_library( srcs = ["onednn_matmul.cc"], hdrs = [ "onednn_matmul.h", - "@tsl//tsl/util:onednn_util_hdrs", + "//xla/tsl/util:onednn_util_hdrs", + ], + copts = runtime_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + ":backend_config_proto_cc", + ":onednn_memory_util", + ":runtime_lightweight_check", + "//xla:executable_run_options", + "//xla:shape_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:blocking_counter", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", + ] + mkl_deps(), +) + +cc_library( + name = "onednn_layer_norm", + srcs = ["onednn_layer_norm.cc"], + hdrs = [ + "onednn_layer_norm.h", + "//xla/tsl/util:onednn_util_hdrs", + ], + copts = runtime_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + ":backend_config_proto_cc", + ":onednn_memory_util", + ":runtime_lightweight_check", + "//xla:executable_run_options", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:blocking_counter", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:platform_port", + ] + mkl_deps(), +) + +cc_library( + name = "onednn_softmax", + srcs = ["onednn_softmax.cc"], + hdrs = [ + "onednn_softmax.h", + "//xla/tsl/util:onednn_util_hdrs", ], copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], @@ -1613,6 +1663,7 @@ cc_library( "//xla:executable_run_options", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/log:check", "@eigen_archive//:eigen3", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:env", @@ -1621,22 +1672,75 @@ cc_library( ) cc_library( - name = "onednn_rewriter", - srcs = ["onednn_rewriter.cc"], - hdrs = ["onednn_rewriter.h"], + name = "onednn_util", + hdrs = ["onednn_util.h"], +) + +cc_library( + name = "onednn_matmul_rewriter", + srcs = ["onednn_matmul_rewriter.cc"], + hdrs = [ + "onednn_matmul.h", + "onednn_matmul_rewriter.h", + "//xla/tsl/util:onednn_util_hdrs", + ], copts = tsl_copts(), deps = [ ":backend_config_proto_cc", + ":onednn_matmul", ":onednn_memory_util", + ":onednn_util", + "//xla:executable_run_options", + "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:blocking_counter", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", ] + mkl_deps(), ) +cc_library( + name = "onednn_ops_rewriter", + srcs = ["onednn_ops_rewriter.cc"], + hdrs = [ + "onednn_ops_rewriter.h", + ], + copts = tsl_copts(), + deps = [ + ":backend_config_proto_cc", + ":onednn_memory_util", + ":onednn_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@tsl//tsl/platform:platform_port", + ] + mkl_deps(), +) + +cc_library( + name = "cpu_float_support", + srcs = ["cpu_float_support.cc"], + hdrs = ["cpu_float_support.h"], + copts = tsl_copts(), + deps = [ + ":onednn_matmul_rewriter", + "//xla/service:float_support", + ], +) + cc_library( name = "cpu_symbol_repository", hdrs = ["cpu_symbol_repository.h"], @@ -1682,3 +1786,9 @@ cc_library( "@tsl//tsl/platform:errors", ], ) + +cc_library( + name = "cpu_executable_run_options", + hdrs = ["cpu_executable_run_options.h"], + deps = [":collectives_interface"], +) diff --git a/xla/service/cpu/backend_config.proto b/xla/service/cpu/backend_config.proto index 5ce04477008fb..7e82b08f3da5d 100644 --- a/xla/service/cpu/backend_config.proto +++ b/xla/service/cpu/backend_config.proto @@ -10,9 +10,12 @@ message BackendConfig { repeated int64 outer_dimension_partitions = 1; // Configuration to be used by oneDNN matmul OneDnnMatMulConfig onednn_matmul_config = 2; + OneDnnLayerNormConfig onednn_layer_norm_config = 3; } message OneDnnMatMulConfig { + bool transpose_a = 1; + bool transpose_b = 2; // These enum needs to be mapped to oneDNN enum for post_op algorithm. // TODO(intel-tf): Add kinds supported by oneDNN. enum FusionKind { @@ -22,6 +25,25 @@ message OneDnnMatMulConfig { TANH = 3; GELU_ERF = 4; GELU_TANH = 5; + BINARY_ADD = 6; + LINEAR = 7; } repeated FusionKind fused_ops = 3; + bool bias_broadcast = 4; + // To avoid protobuf failures for specific decimal values, + // the original float value alpha is type-casted to int32. + int32 alpha_typecast = 5; +} + +message OneDnnLayerNormConfig { + // These enum needs to be mapped to oneDNN enum for post_op algorithm. + // TODO(intel-tf): Add kinds supported by oneDNN. + enum FusionKind { + UNDEFINED = 0; + SCALE = 1; + SHIFT = 2; + SCALE_AND_SHIFT = 3; + } + FusionKind fused_ops = 1; + int32 epsilon_typecast = 2; } diff --git a/xla/service/cpu/buffer_desc.h b/xla/service/cpu/buffer_desc.h index ec2e03b3dc8e5..8d606250cb6c8 100644 --- a/xla/service/cpu/buffer_desc.h +++ b/xla/service/cpu/buffer_desc.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/buffer_info_util.cc b/xla/service/cpu/buffer_info_util.cc index 672ef49fa3d2c..f0b618efd0929 100644 --- a/xla/service/cpu/buffer_info_util.cc +++ b/xla/service/cpu/buffer_info_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/buffer_info_util.h b/xla/service/cpu/buffer_info_util.h index 5c3bf93e2182f..7ba98c1fb6787 100644 --- a/xla/service/cpu/buffer_info_util.h +++ b/xla/service/cpu/buffer_info_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h index bd518db3a780b..54b6a280f5991 100644 --- a/xla/service/cpu/collectives_interface.h +++ b/xla/service/cpu/collectives_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,9 +59,20 @@ class CollectivesCommunicator { // The all-to-all chunks are passed separately and do not have to be // contiguous in memory. virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffer, - absl::Span output_buffer, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) = 0; + + // Performs an all-gather. + virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) = 0; + + // Performs a reduce-scatter + virtual absl::Status ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) = 0; }; class CollectivesInterface { diff --git a/xla/service/cpu/compiler_functor.cc b/xla/service/cpu/compiler_functor.cc index 5a877ae5ecee7..f909cb6ff2075 100644 --- a/xla/service/cpu/compiler_functor.cc +++ b/xla/service/cpu/compiler_functor.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/compiler_functor.h b/xla/service/cpu/compiler_functor.h index 107492790e875..3c8752b344688 100644 --- a/xla/service/cpu/compiler_functor.h +++ b/xla/service/cpu/compiler_functor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,19 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ #define XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ -#include +#include #include #include #include +#include "absl/functional/any_invocable.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/IR/FMF.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Target/TargetMachine.h" #include "xla/service/llvm_compiler.h" -#include "tsl/platform/logging.h" namespace xla { namespace cpu { @@ -43,8 +43,8 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler { bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, LLVMCompiler::ModuleHook post_optimization_hook = nullptr, - std::function post_codegen_hook = - nullptr, + absl::AnyInvocable + post_codegen_hook = nullptr, bool dfsan_enabled = false, const std::vector& dfsan_abi_list_files = {}, const std::vector& convert_to_xla_runtime_abi = {}) @@ -75,7 +75,7 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler { const llvm::FastMathFlags fast_math_flags_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; - std::function post_codegen_hook_; + absl::AnyInvocable post_codegen_hook_; const bool dfsan_enabled_ = false; const std::vector dfsan_abi_list_files_; const std::vector convert_to_xla_runtime_abi_; diff --git a/xla/service/cpu/conv_canonicalization.cc b/xla/service/cpu/conv_canonicalization.cc index 90a8fcb00ec87..c35b67410cc20 100644 --- a/xla/service/cpu/conv_canonicalization.cc +++ b/xla/service/cpu/conv_canonicalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr ConvCanonicalization::Run( +absl::StatusOr ConvCanonicalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/cpu/conv_canonicalization.h b/xla/service/cpu/conv_canonicalization.h index 7e7ad06b8e914..5bdecaf7b3ec7 100644 --- a/xla/service/cpu/conv_canonicalization.h +++ b/xla/service/cpu/conv_canonicalization.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ class ConvCanonicalization : public HloModulePass { return "convolution-canonicalization"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/cpu/conv_canonicalization_test.cc b/xla/service/cpu/conv_canonicalization_test.cc index f21e4443e6b79..76b5121693d34 100644 --- a/xla/service/cpu/conv_canonicalization_test.cc +++ b/xla/service/cpu/conv_canonicalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index ae2a2cd5672ae..dcd9781dfef05 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,12 +34,14 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -61,7 +63,9 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Host.h" #include "llvm/TargetParser/Triple.h" +#ifdef TF_LLVM_X86_AVAILABLE #include "llvm/TargetParser/X86TargetParser.h" +#endif #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -82,6 +86,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project @@ -107,7 +112,6 @@ limitations under the License. #include "xla/mlir/runtime/transforms/compilation_pipeline_cpu.h" #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir/runtime/transforms/jit_compiler.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" @@ -115,7 +119,6 @@ limitations under the License. #include "xla/runtime/executable.h" #include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" -#include "xla/service/all_gather_decomposer.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" #include "xla/service/batch_dot_simplification.h" @@ -153,7 +156,6 @@ limitations under the License. #include "xla/service/cpu/runtime/xfeed.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/target_machine_features.h" -#include "xla/service/cpu/xla_framework.h" #include "xla/service/cpu_gpu_shape_verifier.h" #include "xla/service/dot_decomposer.h" #include "xla/service/dump.h" @@ -190,9 +192,9 @@ limitations under the License. #include "xla/service/map_inliner.h" #include "xla/service/operand_upcaster.h" #include "xla/service/optimization_barrier_expander.h" +#include "xla/service/optimize_input_output_buffer_alias.h" #include "xla/service/qr_expander.h" #include "xla/service/reduce_decomposer.h" -#include "xla/service/reduce_scatter_decomposer.h" #include "xla/service/reshape_decomposer.h" #include "xla/service/reshape_mover.h" #include "xla/service/result_caster.h" @@ -202,6 +204,7 @@ limitations under the License. #include "xla/service/select_and_scatter_expander.h" #include "xla/service/sharding_propagation.h" #include "xla/service/sharding_remover.h" +#include "xla/service/simplify_fp_conversions.h" #include "xla/service/slow_operation_alarm.h" #include "xla/service/sort_simplifier.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" @@ -236,7 +239,9 @@ limitations under the License. #include "tsl/platform/statusor.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) -#include "xla/service/cpu/onednn_rewriter.h" +#include "xla/service/cpu/cpu_float_support.h" +#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_ops_rewriter.h" #endif namespace { @@ -271,6 +276,7 @@ xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_n_dim(), xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_k_dim()}; } +#ifdef TF_LLVM_X86_AVAILABLE options.enable_avx2 = [&] { // Derive whether this is an x86 CPU with AVX2 enabled. if (!target_triple.isX86()) return false; @@ -278,6 +284,9 @@ xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( llvm::X86::getFeaturesForCPU(cpu_name, cpu_features); return llvm::is_contained(cpu_features, "avx2"); }(); +#else + options.enable_avx2 = false; +#endif options.cpu_name = cpu_name; if (xla::GetDebugOptionsFromFlags().xla_cpu_enable_mlir_fusion_outlining()) { options.enable_fusion_outlining = true; @@ -350,20 +359,6 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { return se::host::kHostPlatformId; } -CpuXlaRuntimeAotCompilationResult::CpuXlaRuntimeAotCompilationResult( - HloModuleProto hlo, std::string_view obj_file, std::string_view mlir_module, - XlaFrameworkMapping xla_framework_mapping) { - XlaRuntimeExecutableProto xla_runtime_executable; - *xla_runtime_executable.mutable_hlo_module_proto() = hlo; - xla_runtime_executable.set_obj_file(std::string(obj_file)); - xla_runtime_executable.set_mlir_module(std::string(mlir_module)); - - *xla_runtime_cpu_executable_.mutable_xla_runtime_executable() = - xla_runtime_executable; - *xla_runtime_cpu_executable_.mutable_xla_framework_mapping() = - xla_framework_mapping.ToProto(); -} - namespace { namespace runtime = ::xla::runtime; @@ -390,17 +385,22 @@ class FlattenTuplesAndBufferizeTypeConverter : public mlir::TypeConverter { }; runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( - const HloModule& module) { + const HloModule& module, mlir::DialectRegistry* custom_registry) { runtime::CpuPipelineOptions copts; runtime::JitExecutable::Options opts; copts.xla_cpu_sparse_cuda_threads = GetDebugOptionsFromFlags().xla_cpu_sparse_cuda_threads(); + std::optional maybeOverriddenPipeline = + options::ExperimentalOverriddenPipeline(module.config()); opts.specialization = runtime::JitExecutable::Specialization::kDisabled; opts.compiler.register_dialects = - [](xla::runtime::DialectRegistry& dialects) { - dialects->insert(); + [custom_registry](xla::runtime::DialectRegistry& dialects) { + dialects->insert(); runtime::RegisterDefaultXlaCpuRuntimeDialects(dialects); RegisterHloXlaRuntimePipelineDialects(*dialects); + if (custom_registry) { + custom_registry->appendTo(*dialects); + } }; opts.compiler.symbols_binding = runtime::ToSymbolsBinding( [](runtime::DirectCustomCallRegistry& registry) { @@ -412,7 +412,25 @@ runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( PopulateXlaXfeedCall(registry); }); opts.compiler.create_compilation_pipeline = - [copts](xla::runtime::PassManager& passes) { + [copts, maybeOverriddenPipeline = std::move(maybeOverriddenPipeline)]( + xla::runtime::PassManager& passes) { + if (maybeOverriddenPipeline.has_value()) { + std::string error_message; + llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline( + maybeOverriddenPipeline.value(), *passes, error_stream); + if (mlir::failed(result)) { + LOG(ERROR) + << "Failed to parse experimental CPU compilation pipeline: " + << error_stream.str(); + return absl::InternalError( + "Failed to parse experimental CPU compilation pipeline."); + } + LOG(INFO) << "Experimental CPU compilation pipeline: " + << maybeOverriddenPipeline.value(); + return absl::OkStatus(); + } + HloXlaRuntimePipelineOptions options = GetHloXlaRuntimePipelineOptions( llvm::Triple(llvm::sys::getProcessTriple()), llvm::sys::getHostCPUName()); @@ -421,61 +439,41 @@ runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( Status status = CreateHloXlaRuntimePipeline(passes, options); if (!status.ok()) { - LOG(FATAL) << "HLO-XLA Runtime pipeline failed with: " + LOG(ERROR) << "HLO-XLA Runtime pipeline failed with: " << status.message(); + return status; } runtime::CreateDefaultXlaCpuRuntimeCompilationPipeline(passes, copts); + return absl::OkStatus(); }; opts.compiler.calling_convention = runtime::ResultsToOutsCallingConvention( FlattenTuplesAndBufferizeTypeConverter()); + opts.compiler.embed_ir_in_executable = + module.config().debug_options().xla_embed_ir_in_executable(); return opts; } } // namespace -StatusOr> -CpuXlaRuntimeAotCompilationResult::LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) { - XlaRuntimeExecutableProto xla_runtime_executable = - xla_runtime_cpu_executable_.xla_runtime_executable(); - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - xla_runtime_executable.hlo_module_proto(), - GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(xla_runtime_executable.hlo_module_proto(), - hlo_module_config)); - - XlaFrameworkMapping xla_framework_mapping; - xla_framework_mapping.FromProto( - xla_runtime_cpu_executable_.xla_framework_mapping()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, - compiler->AssignBuffers(hlo_module.get(), executor)); - - // TODO(b/232263665): JitOptions should be used only for JIT case because it - // has details irrelevant to AOT. - runtime::JitExecutable::Options opts = - GetXlaRuntimeJitExecutableOptions(*hlo_module); - - return CpuExecutable::LoadFromObjFile( - std::move(hlo_module), xla_runtime_executable.obj_file(), - xla_runtime_executable.mlir_module(), std::move(buffer_assignment), - xla_framework_mapping, opts); -} - CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, std::vector buffer_infos, - int64_t result_buffer_index, + int64_t result_buffer_index, std::unique_ptr module, std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), buffer_infos_(std::move(buffer_infos)), result_buffer_index_(result_buffer_index), + module_(std::move(module)), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} -CpuCompiler::CpuCompiler(bool allow_sparse_shapes) - : allow_sparse_shapes_(allow_sparse_shapes) { +const HloModule* CpuAotCompilationResult::optimized_module() const { + return module_.get(); +} + +std::unique_ptr CpuAotCompilationResult::consume_optimized_module() { + return std::move(module_); +} + +CpuCompiler::CpuCompiler() { // Initialize LLVM the first time the CpuCompiler is initialized. static bool llvm_initialized = []() { InitializeLLVMTarget(); @@ -484,9 +482,7 @@ CpuCompiler::CpuCompiler(bool allow_sparse_shapes) (void)llvm_initialized; } -CpuCompiler::CpuCompiler() : CpuCompiler(false) {} - -StatusOr>> CpuCompiler::Compile( +absl::StatusOr>> CpuCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, const CompileOptions& options) { @@ -519,7 +515,7 @@ absl::once_flag llvm_command_line_options_initialized; // recorded. class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: - static StatusOr> + static absl::StatusOr> GetCandidatesForComputation( const HloComputation& computation, const absl::flat_hash_map& @@ -595,16 +591,11 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; // Adds the HloVerifier for CPU to the given pipeline. -void AddHloVerifier(HloPassPipeline* pipeline, bool allow_sparse_shapes, - HloVerifierOpts&& opts = {}, bool debug_only = false) { - std::unique_ptr verifier_metadata; - if (allow_sparse_shapes) { - verifier_metadata = - std::make_unique(std::move(opts)); - } else { - verifier_metadata = - std::make_unique(std::move(opts)); - } +void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, + bool debug_only = false) { + auto verifier_metadata = + std::make_unique(std::move(opts)); + if (debug_only) { pipeline->AddInvariantCheckerDebug( std::move(verifier_metadata), "hlo verifier (debug)"); @@ -629,20 +620,21 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline spmd_pipeline("spmd-partitioner"); // Run some IR cleanup passes before running the SPMD partitioning // passes. - AddHloVerifier(&spmd_pipeline, allow_sparse_shapes_); + AddHloVerifier(&spmd_pipeline); spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); spmd_pipeline.AddPass( /*is_spmd=*/true, /*propagate_metadata=*/false, - module->config().allow_spmd_sharding_propagation_to_output()); + module->config().allow_spmd_sharding_propagation_to_output(), + module->config().allow_spmd_sharding_propagation_to_parameters()); spmd_pipeline.AddPass( num_partitions, module->config().replica_count()); TF_RETURN_IF_ERROR(spmd_pipeline.Run(module).status()); } else { HloPassPipeline sharding_removal_pipeline("sharding-removal"); - AddHloVerifier(&sharding_removal_pipeline, allow_sparse_shapes_); + AddHloVerifier(&sharding_removal_pipeline); // Remove redundant sharding ops when partition_count == 1. sharding_removal_pipeline.AddPass(); sharding_removal_pipeline.AddPass(); @@ -650,9 +642,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( } { - // Int4Packer must be run before the rest of the pipeline since it modifies - // the layout of the entry computation inputs/outputs, which is passed to - // LayoutAssignment. + // Int4Packer must be run before the rest of the pipeline since it + // modifies the layout of the entry computation inputs/outputs, which is + // passed to LayoutAssignment. HloPassPipeline int4_packer_pipeline("Int4Packer pipeline"); int4_packer_pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); @@ -660,7 +652,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( } HloPassPipeline pipeline("HLO passes through layout assignment"); - AddHloVerifier(&pipeline, allow_sparse_shapes_); + AddHloVerifier(&pipeline); pipeline.AddPass(); pipeline.AddPass(); @@ -680,14 +672,18 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. + pipeline.AddPass([&](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kTopK; + }); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); // Inline computations with a single call site. @@ -699,8 +695,12 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) // AOT compiled code runs in single thread. if (!is_aot_compile) { - // Temporarily disabling oneDNN rewriter because it causes JAX regression. - // pipeline.AddPass(); + // Placing OneDnnOpsRewriter here to match the flax patterns + // TODO: Decide where would be the appropriate place for this pass to make + // it more generic + // TODO - intel: Name of the pass might seem redundant as oneDnnRewriter, + // but in future plan to rename oneDNNrewriter to specific to onednn matmul + pipeline.AddPass(); } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 @@ -712,7 +712,16 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( // backend can support BF16/F8 operations without directly implementing a // BF16/F8 lowering for most ops. FloatSupport bf16_support(BF16); +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + CpuFloatSupport onednn_bf16_support(BF16); + if (!is_aot_compile) { + pipeline.AddPass(&onednn_bf16_support); + } else { + pipeline.AddPass(&bf16_support); + } +#else pipeline.AddPass(&bf16_support); +#endif FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); FloatSupport f8e4m3fn_support(F8E4M3FN, F16); @@ -785,7 +794,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( // Run the following passes to a fixed point. [&pipeline = pipeline.AddPass>("simplification"), this] { - AddHloVerifier(&pipeline, allow_sparse_shapes_, HloVerifierOpts{}, + AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true); AlgebraicSimplifierOptions options; @@ -820,9 +829,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); }(); pipeline.AddPass(); - pipeline.AddPass([&](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kTopK; - }); // XLA lowers topk to a libcall while the MLIR based pipeline does not yet // support libcalls. Disable this for now. @@ -833,7 +839,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( } pipeline.AddPass(); pipeline.AddPass( - [&](const HloInstruction& dot, int64_t operand) -> StatusOr { + [&](const HloInstruction& dot, int64_t operand) -> absl::StatusOr { if (DotImplementationCanHandleTranspose(dot, *target_machine_features)) { return TransposeFolding::IsRowColumnTransposeDotOperand(dot, operand); @@ -869,7 +875,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( Status CpuCompiler::RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, - LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) { + LLVMTargetMachineFeatures* target_machine_features, + const CompileOptions& compile_options, bool is_mlir_compile) { HloPassPipeline pipeline("HLO passes after layout assignment"); // CopyInsertion is still needed by BufferAssignment. MLIR passes will handle @@ -890,11 +897,36 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( // After layout assignment, use a layout-sensitive verifier. pipeline.AddPass("after layout assignment"); - AddHloVerifier(&pipeline, allow_sparse_shapes_, - HloVerifierOpts{}.MakeLayoutSensitive(), /*debug_only=*/true); + AddHloVerifier(&pipeline, HloVerifierOpts{}.MakeLayoutSensitive(), + /*debug_only=*/true); pipeline.AddPass(); + const int max_parallelism = + module->config().intra_op_parallelism_threads() > 0 + ? module->config().intra_op_parallelism_threads() + : tsl::port::NumSchedulableCPUs(); + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + // AOT compiled code runs in single thread. + if (!is_aot_compile) { + auto debug_options = module->config().debug_options(); + // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it + // easier to match. + // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. + if (debug_options.xla_allow_excess_precision()) { + pipeline.AddPass(); + } + pipeline.AddPass(max_parallelism, + compile_options.thread_pool); + // Run SimplifyFPConversions pass again to remove redundant Convert ops + // that may exist as a result of running OneDnnMatMulRewriter pass. + if (debug_options.xla_allow_excess_precision()) { + pipeline.AddPass(); + } + } +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 + // Add a fusion pass now that layout assignment is done. pipeline.AddPass(); @@ -905,7 +937,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( "simplification after layout assignment"), this] { AddHloVerifier( - &pipeline, allow_sparse_shapes_, + &pipeline, HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout( LayoutAssignment::InstructionCanChangeLayout), /*debug_only=*/true); @@ -922,10 +954,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( }(); // Outline ops in the entry computation into calls to subcomputations. - const int max_parallelism = - module->config().intra_op_parallelism_threads() > 0 - ? module->config().intra_op_parallelism_threads() - : tsl::port::NumSchedulableCPUs(); if (!is_aot_compile) { // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. // Note this is not run for AOT because it would bring in thread pool @@ -942,6 +970,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( // before (and sometime after) copy insertion, to avoid dead code from // interfering with the rewrites. pipeline.AddPass(); + pipeline.AddPass(true); pipeline.AddPass(); pipeline.AddPass(); return pipeline.Run(module).status(); @@ -949,13 +978,15 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, bool is_mlir_compile) { LLVMTargetMachineFeatures target_machine_features(target_machine); TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn( module, is_aot_compile, &target_machine_features, is_mlir_compile)); return RunHloPassesAfterLayoutAssn(module, is_aot_compile, - &target_machine_features, is_mlir_compile); + &target_machine_features, compile_options, + is_mlir_compile); } namespace { @@ -1068,77 +1099,39 @@ Status CreateHloProfilingArtifacts( } // namespace -StatusOr> CpuCompiler::RunHloPasses( +absl::StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, - const CompileOptions& /*options*/) { + const CompileOptions& options) { std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config())); - TF_RETURN_IF_ERROR(RunHloPasses( - module.get(), /*is_aot_compile=*/false, jit_target_machine.get(), - /*is_mlir_compile=*/ - module->config().debug_options().xla_cpu_use_xla_runtime())); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get(), + /*compile_options=*/options, + /*is_mlir_compile=*/false)); return std::move(module); } -StatusOr> CpuCompiler::AssignBuffers( - HloModule* module, se::StreamExecutor* /*stream_exec*/) { - // Select an order for emitting the HLO instructions for each computation. - // Using this sequence enables tighter buffer liveness analysis and reduced - // memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(module, BufferSizeBytesFunction(), - ComputationSchedulerToModuleScheduler( - DFSMemoryScheduler))); - TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); - - // Run buffer allocation on the HLO graph. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run( - module, std::make_unique(module->schedule()), - BufferSizeBytesFunction(), memory_alignment, - /*allocate_buffers_for_constants=*/true)); - - return std::move(assignment); -} - namespace { // Post-compilation callback functor for use by SimpleOrcJIT. // // Dumps machine code if dumping is enabled for the module. -struct OrcJITPostCompilationHook { - // Gets an std::function that implements this hook. - static std::function Create( - const HloModule* module) { - // This struct is not copyable, but std::functions must be. So to create an - // std::function out of this struct, we have to wrap it in a shared_ptr. - auto wrapped = std::make_shared(module); - return [wrapped](const llvm::object::ObjectFile& obj_file) { - (*wrapped)(obj_file); - }; - } +static absl::AnyInvocable +CreateOrcJITPostCompilationHook(const HloModule* module, + std::vector* obj_files) { + return [=](const llvm::object::ObjectFile& obj_file) { + if (obj_files) obj_files->push_back(obj_file.getData().str()); - // Constructor can't be private because we want to call it from - // std::make_shared, but users should call Create() instead. - explicit OrcJITPostCompilationHook(const HloModule* module) - : module(module) {} - - private: - void operator()(const llvm::object::ObjectFile& obj_file) { - if (!DumpingEnabledForHloModule(*module)) { - return; + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", + absl::string_view(obj_file.getData().data(), + obj_file.getData().size())); } - DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", - absl::string_view(obj_file.getData().data(), - obj_file.getData().size())); - } - - const HloModule* module; -}; + }; +} void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { llvm_ir::InitializeLLVMCommandLineOptions( @@ -1178,7 +1171,7 @@ Status LowerMLIRModule(HloModule* module, mlir::ModuleOp mlir_module, return absl::OkStatus(); } -StatusOr> createMLIRModule( +absl::StatusOr> createMLIRModule( HloModule* module, mlir::MLIRContext& mlir_context, BufferAssignment* assignment, XlaFrameworkMapping* export_mapping = nullptr) { @@ -1313,8 +1306,8 @@ std::vector SubcomputationEmissionOrder( instruction->opcode() == HloOpcode::kAllReduce || instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kReduceWindow; - for (auto it = instruction->called_computations().rbegin(); - it != instruction->called_computations().rend(); ++it) { + auto cc = absl::MakeSpan(instruction->called_computations()); + for (auto it = cc.rbegin(); it != cc.rend(); ++it) { HloComputation* called_computation = *it; ComputationToEmit callee{ called_computation, c.allow_reassociation || allow_reassociation}; @@ -1332,7 +1325,7 @@ std::vector SubcomputationEmissionOrder( } // namespace -StatusOr> +absl::StatusOr> CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; @@ -1347,6 +1340,10 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { auto llvm_module = std::make_unique("__compute_module", *llvm_context); + // We collect compiled object files (machine code) so we can export + // CpuExecutable to an AOT compilation result. + std::vector obj_files; + auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), @@ -1355,10 +1352,9 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { options::SlpVectorizerDisabled(module->config()), llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, - OrcJITPostCompilationHook::Create(module.get())); + CreateOrcJITPostCompilationHook(module.get(), &obj_files)); if (!jit) { - return InternalError("Creating JIT failed: %s", - llvm::toString(jit.takeError())); + return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } llvm_module->setDataLayout((*jit)->data_layout()); llvm_module->setTargetTriple((*jit)->target_triple().getTriple()); @@ -1471,6 +1467,8 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); + cpu_executable->set_obj_files(std::move(obj_files)); + if (embed_ir_in_executable) { cpu_executable->set_ir_module_string(ir_module_string); } @@ -1488,19 +1486,21 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { namespace { -StatusOr> GetXlaRuntimeCpuExecutable( - const HloModule& hlo_module, mlir::ModuleOp mlir_module, - absl::string_view entry_point, - const XlaFrameworkMapping& xla_framework_mapping) { +absl::StatusOr> +GetXlaRuntimeCpuExecutable(const HloModule& hlo_module, + mlir::ModuleOp mlir_module, + absl::string_view entry_point, + const XlaFrameworkMapping& xla_framework_mapping, + mlir::DialectRegistry* registry) { runtime::JitExecutable::Options opts = - GetXlaRuntimeJitExecutableOptions(hlo_module); + GetXlaRuntimeJitExecutableOptions(hlo_module, registry); std::string serialized_mlir = llvm_ir::DumpToString(mlir_module); absl::StatusOr jit_executable = runtime::JitExecutable::Instantiate(serialized_mlir, entry_point, opts); if (!jit_executable.ok()) { - return InternalError("Failed to compile XLA Runtime program: %s", - jit_executable.status().message()); + return Internal("Failed to compile XLA Runtime program: %s", + jit_executable.status().message()); } return std::make_unique( @@ -1509,9 +1509,9 @@ StatusOr> GetXlaRuntimeCpuExecutable( } } // namespace -StatusOr> +absl::StatusOr> CpuCompiler::CompileXlaRuntimeCpuExecutable( - std::unique_ptr hlo_module) { + std::unique_ptr hlo_module, mlir::DialectRegistry* registry) { // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). @@ -1547,6 +1547,9 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( } mlir::MLIRContext mlir_context; + if (registry) { + mlir_context.appendDialectRegistry(*registry); + } XlaFrameworkMapping xla_framework_mapping; TF_ASSIGN_OR_RETURN( auto mlir_module, @@ -1556,7 +1559,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( TF_ASSIGN_OR_RETURN( auto xla_runtime_executable, GetXlaRuntimeCpuExecutable(*hlo_module, *mlir_module, "main", - xla_framework_mapping)); + xla_framework_mapping, registry)); if (DumpingEnabledForHloModule(*hlo_module)) { TF_ASSIGN_OR_RETURN(std::string_view obj_file, @@ -1571,10 +1574,10 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable( std::move(xla_runtime_executable)); } -StatusOr> CpuCompiler::RunBackend( +absl::StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, [[maybe_unused]] se::StreamExecutor* stream_exec, - [[maybe_unused]] const CompileOptions& options) { + const CompileOptions& options) { VLOG(1) << "Compiling: " << module->name(); XLA_SCOPED_LOGGING_TIMER( absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); @@ -1586,13 +1589,8 @@ StatusOr> CpuCompiler::RunBackend( &InitializeLLVMCommandLineOptions, module->config()); std::unique_ptr cpu_executable; - if (module->config().debug_options().xla_cpu_use_xla_runtime()) { - TF_ASSIGN_OR_RETURN(cpu_executable, - CompileXlaRuntimeCpuExecutable(std::move(module))); - } else { - TF_ASSIGN_OR_RETURN(cpu_executable, - CompileLegacyCpuExecutable(std::move(module))); - } + TF_ASSIGN_OR_RETURN(cpu_executable, + CompileLegacyCpuExecutable(std::move(module))); cpu_executable->set_debug_info( cpu_executable->buffer_assignment().GetStats().ToString()); @@ -1600,7 +1598,7 @@ StatusOr> CpuCompiler::RunBackend( return std::unique_ptr(std::move(cpu_executable)); } -StatusOr>> +absl::StatusOr>> CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) { TF_RET_CHECK(!module_group->empty()); @@ -1648,7 +1646,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", error); + return Internal("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; @@ -1699,172 +1697,177 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR( - RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get(), - /*is_mlir_compile=*/options.use_mlir_hlo_lowering())); + if (!module->has_schedule()) { + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get(), + /*dummy*/ CompileOptions{}, + /*is_mlir_compile=*/options.use_mlir_hlo_lowering())); - TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module, BufferSizeBytesFunction())); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run(module, - std::make_unique(schedule), - BufferSizeBytesFunction(), memory_alignment, - /*allocate_buffers_for_constants=*/true)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - if (DumpingEnabledForHloModule(*module)) { - DumpToFileInDirOrStdout(*module, "", "buffer_assignment", - assignment->ToString()); - } - DumpHloModuleIfEnabled(*module, *assignment, - absl::StrCat("cpu_", kAfterOptimizationsDumpName)); - - absl::flat_hash_map - instruction_to_profile_idx; - absl::flat_hash_map - computation_to_profile_idx; - std::unique_ptr hlo_profile_index_map; - std::unique_ptr hlo_profile_printer_data; - - if (module->config().hlo_profiling_enabled()) { - TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( - *module, &instruction_to_profile_idx, &computation_to_profile_idx, - &hlo_profile_index_map, &hlo_profile_printer_data)); - } + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(module, + std::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allocate_buffers_for_constants=*/true)); + // BufferAssignment::ToString() includes a header, so no need for us to + // print one ourselves. + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "", "buffer_assignment", + assignment->ToString()); + } + DumpHloModuleIfEnabled(*module, *assignment, + absl::StrCat("cpu_", kAfterOptimizationsDumpName)); + + absl::flat_hash_map + instruction_to_profile_idx; + absl::flat_hash_map + computation_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer_data; + + if (module->config().hlo_profiling_enabled()) { + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); + } - LLVMTargetMachineFeatures target_machine_features(target_machine.get()); - std::vector buffer_infos = - CreateBufferInfosFromBufferAssignment(*module, *assignment); - HloComputation* computation = module->entry_computation(); + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); + std::vector buffer_infos = + CreateBufferInfosFromBufferAssignment(*module, *assignment); + HloComputation* computation = module->entry_computation(); - if (options.use_mlir_hlo_lowering()) { - TF_ASSIGN_OR_RETURN( - auto mlir_module, - createMLIRModule(module, mlir_context, assignment.get())); - TF_RETURN_IF_ERROR( - xla::runtime::ExportMainWithOrdinal0(*mlir_module, mlir_context)); - TF_RETURN_IF_ERROR( - LowerMLIRModule(module, *mlir_module, mlir_context, *target_machine)); + if (options.use_mlir_hlo_lowering()) { + TF_ASSIGN_OR_RETURN( + auto mlir_module, + createMLIRModule(module, mlir_context, assignment.get())); + TF_RETURN_IF_ERROR( + xla::runtime::ExportMainWithOrdinal0(*mlir_module, mlir_context)); + TF_RETURN_IF_ERROR(LowerMLIRModule(module, *mlir_module, mlir_context, + *target_machine)); - llvm::cast(mlir_module->lookupSymbol("main")) - .setName(options.entry_point_name()); + llvm::cast(mlir_module->lookupSymbol("main")) + .setName(options.entry_point_name()); - llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, llvm_context); - if (!llvm_module) { - return InternalError("Failed to translate module to LLVM IR"); - } - // Set missing information - llvm_module->setDataLayout(target_machine->createDataLayout()); - llvm_module->setTargetTriple(triple.getTriple()); - if (pic_level != llvm::PICLevel::NotPIC) { - llvm_module->setPICLevel(pic_level); - } - if (pie_level != llvm::PIELevel::Default) { - llvm_module->setPIELevel(pie_level); - } - } else { - // Set required information before emitting IR - llvm_module = - std::make_unique("__compute_module", llvm_context); - llvm_module->setDataLayout(target_machine->createDataLayout()); - llvm_module->setTargetTriple(triple.getTriple()); - if (pic_level != llvm::PICLevel::NotPIC) { - llvm_module->setPICLevel(pic_level); - } - if (pie_level != llvm::PIELevel::Default) { - llvm_module->setPIELevel(pie_level); - } - IrEmitter ir_emitter( - &mlir_context, *module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - ModuleComputationsTransitivelyContainCustomCall(*module), - &target_machine_features, - // TODO(b/66051036): Run full msan for AOT. - /*emit_code_for_msan=*/false); - - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); - - for (ComputationToEmit subcomputation : - SubcomputationEmissionOrder(computation)) { - if (subcomputation.computation->IsFusionComputation()) { - continue; + llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, llvm_context); + if (!llvm_module) { + return Internal("Failed to translate module to LLVM IR"); } - TF_RETURN_IF_ERROR( - ir_emitter - .EmitComputation(subcomputation.computation, - subcomputation.computation->name(), - /*is_top_level_computation=*/false, - schedule.sequence(subcomputation.computation) - .instructions(), - subcomputation.allow_reassociation) - .status()); + // Set missing information + llvm_module->setDataLayout(target_machine->createDataLayout()); + llvm_module->setTargetTriple(triple.getTriple()); + if (pic_level != llvm::PICLevel::NotPIC) { + llvm_module->setPICLevel(pic_level); + } + if (pie_level != llvm::PIELevel::Default) { + llvm_module->setPIELevel(pie_level); + } + } else { + // Set required information before emitting IR + llvm_module = + std::make_unique("__compute_module", llvm_context); + llvm_module->setDataLayout(target_machine->createDataLayout()); + llvm_module->setTargetTriple(triple.getTriple()); + if (pic_level != llvm::PICLevel::NotPIC) { + llvm_module->setPICLevel(pic_level); + } + if (pie_level != llvm::PIELevel::Default) { + llvm_module->setPIELevel(pie_level); + } + IrEmitter ir_emitter( + &mlir_context, *module, *assignment, llvm_module.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + ModuleComputationsTransitivelyContainCustomCall(*module), + &target_machine_features, + // TODO(b/66051036): Run full msan for AOT. + /*emit_code_for_msan=*/false); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + + for (ComputationToEmit subcomputation : + SubcomputationEmissionOrder(computation)) { + if (subcomputation.computation->IsFusionComputation()) { + continue; + } + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(subcomputation.computation, + subcomputation.computation->name(), + /*is_top_level_computation=*/false, + schedule.sequence(subcomputation.computation) + .instructions(), + subcomputation.allow_reassociation) + .status()); + } + const std::string& entry_point_name = options.entry_point_name(); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + schedule.sequence(computation).instructions(), + /*allow_reassociation=*/false)); + + CHECK(entry_function->getName() == entry_point_name); } - const std::string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, - ir_emitter.EmitComputation( - computation, entry_point_name, - /*is_top_level_computation=*/true, - schedule.sequence(computation).instructions(), - /*allow_reassociation=*/false)); - - CHECK(entry_function->getName() == entry_point_name); - } - ModuleHook pre_optimization_ir_hook; - ModuleHook post_optimization_ir_hook; - std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = - GetIRModuleHooks(*module, user_pre_optimization_hook_, - user_post_optimization_hook_); - - // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run - // the pre-optimization IR dump hook before returning. - { - Status verify_status = VerifyLlvmModule(*llvm_module); - if (!verify_status.ok() && pre_optimization_ir_hook) { - pre_optimization_ir_hook(*llvm_module); + ModuleHook pre_optimization_ir_hook; + ModuleHook post_optimization_ir_hook; + std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = + GetIRModuleHooks(*module, user_pre_optimization_hook_, + user_post_optimization_hook_); + + // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run + // the pre-optimization IR dump hook before returning. + { + Status verify_status = VerifyLlvmModule(*llvm_module); + if (!verify_status.ok() && pre_optimization_ir_hook) { + pre_optimization_ir_hook(*llvm_module); + } + TF_RETURN_IF_ERROR(verify_status); } - TF_RETURN_IF_ERROR(verify_status); - } - auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) { - if (!DumpingEnabledForHloModule(*module)) { - return; + auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) { + if (!DumpingEnabledForHloModule(*module)) { + return; + } + DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", + absl::string_view(obj_file.getData().data(), + obj_file.getData().size())); + }; + + std::vector xla_runtime_abi_conversions; + if (options.use_mlir_hlo_lowering()) { + xla_runtime_abi_conversions.push_back(options.entry_point_name()); } - DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", - absl::string_view(obj_file.getData().data(), - obj_file.getData().size())); - }; - std::vector xla_runtime_abi_conversions; - if (options.use_mlir_hlo_lowering()) { - xla_runtime_abi_conversions.push_back(options.entry_point_name()); + CompilerFunctor compiler_functor( + target_machine.get(), static_cast(opt_level), + options::OptimizeForSizeRequested(module->config()), + module->config().debug_options().xla_llvm_disable_expensive_passes(), + options::SlpVectorizerDisabled(module->config()), + llvm_ir::GetCpuFastMathFlags(module->config()), + pre_optimization_ir_hook, post_optimization_ir_hook, + post_codegen_hook, aot_options.sanitize_dataflow(), + aot_options.sanitize_abilists_dataflow(), + xla_runtime_abi_conversions); + std::unique_ptr object_file = + cantFail(compiler_functor(*llvm_module)); + ObjectFileData object_file_data(object_file->getBufferStart(), + object_file->getBufferEnd()); + + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment->GetUniqueTopLevelOutputSlice()); + + results.emplace_back(std::make_unique( + std::move(object_file_data), std::move(buffer_infos), + result_slice.index(), std::move(modules[i]), + std::move(hlo_profile_printer_data))); } - - CompilerFunctor compiler_functor( - target_machine.get(), static_cast(opt_level), - options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_llvm_disable_expensive_passes(), - options::SlpVectorizerDisabled(module->config()), - llvm_ir::GetCpuFastMathFlags(module->config()), - pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook, - aot_options.sanitize_dataflow(), - aot_options.sanitize_abilists_dataflow(), xla_runtime_abi_conversions); - std::unique_ptr object_file = - cantFail(compiler_functor(*llvm_module)); - ObjectFileData object_file_data(object_file->getBufferStart(), - object_file->getBufferEnd()); - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment->GetUniqueTopLevelOutputSlice()); - - results.emplace_back(std::make_unique( - std::move(object_file_data), std::move(buffer_infos), - result_slice.index(), std::move(hlo_profile_printer_data))); } VLOG(1) << "Compilation finished"; @@ -1875,36 +1878,143 @@ se::Platform::Id CpuCompiler::PlatformId() const { return se::host::kHostPlatformId; } -// A special version that assigns zero size to sparse types -// and passes all other shapes to the cpu executable function. -static int64_t ShapeSizeBytesZeroSparse(const Shape& shape) { - if (LayoutUtil::IsSparseArray(shape)) { - return 0; - } - return CpuExecutable::ShapeSizeBytes(shape); +HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { + return CpuExecutable::ShapeSizeBytes; } -HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { - return allow_sparse_shapes_ ? ShapeSizeBytesZeroSparse - : CpuExecutable::ShapeSizeBytes; +namespace { + +// This is a result of exporting JIT compiled CpuExecutable to AOT compilation +// result that can be saved on disk and shipped over the wire. +class CpuExecutableAotCompilationResult : public AotCompilationResult { + public: + CpuExecutableAotCompilationResult(const HloModule* hlo_module, + const BufferAssignment* buffer_assignment, + std::string_view function_name, + std::string_view obj_file) { + *proto_.mutable_hlo_module()->mutable_hlo_module() = hlo_module->ToProto(); + *proto_.mutable_buffer_assignment() = buffer_assignment->ToProto(); + proto_.set_entry_function_name(std::string(function_name)); + proto_.set_obj_file(std::string(obj_file)); + *proto_.mutable_hlo_module()->mutable_config() = + *hlo_module->config().ToProto(); + module_ = hlo_module->Clone(); + } + + absl::StatusOr SerializeAsString() const override { + return proto_.SerializeAsString(); + } + + static absl::StatusOr> + FromString(const std::string& serialized) { + CompilationResultProto proto; + if (!proto.ParseFromString(serialized)) { + return Internal( + "Failed to parse serialized CpuExecutableAotCompilationResult."); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProtoWithConfig(proto.hlo_module())); + + return std::unique_ptr( + new CpuExecutableAotCompilationResult(proto, std::move(module))); + } + + absl::StatusOr> LoadExecutable( + Compiler* compiler, const se::StreamExecutor* stream_exec) const override; + + const HloModule* optimized_module() const override { return module_.get(); } + + std::unique_ptr consume_optimized_module() override { + return std::move(module_); + } + + private: + explicit CpuExecutableAotCompilationResult(CompilationResultProto proto, + std::unique_ptr module) + : proto_(std::move(proto)), module_(std::move(module)) {} + + CompilationResultProto proto_; + std::unique_ptr module_; +}; + +} // namespace + +absl::StatusOr> +CpuExecutableAotCompilationResult::LoadExecutable( + Compiler* compiler, const se::StreamExecutor* stream_exec) const { + // Recreate HloModule from proto. + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProtoWithConfig(proto_.hlo_module())); + + // Recreate BufferAssignment from proto. + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer_assignment, + BufferAssignment::FromProto(proto_.buffer_assignment(), module.get(), + compiler->BufferSizeBytesFunction(), + /*can_share_buffer=*/nullptr)); + + auto jit = SimpleOrcJIT::Create( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config()), + options::OptimizeForSizeRequested(module->config()), + module->config().debug_options().xla_llvm_disable_expensive_passes(), + options::SlpVectorizerDisabled(module->config()), + llvm_ir::GetCpuFastMathFlags(module->config()), + /*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr, + /*post_codegen_hook=*/nullptr); + if (!jit) { + return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); + } + + // Create a named buffer from compiled object file. + llvm::StringRef data(proto_.obj_file().data(), proto_.obj_file().size()); + auto obj_file = + llvm::MemoryBuffer::getMemBuffer(data, proto_.entry_function_name()); + + cantFail((*jit)->AddObjFile(std::move(obj_file))); + + TF_ASSIGN_OR_RETURN( + auto cpu_executable, + CpuExecutable::Create(std::move(*jit), std::move(buffer_assignment), + std::move(module), proto_.entry_function_name(), + nullptr, nullptr)); + + // Dump computation proto state and buffer assignment for + // GetCompiledMemoryStats results. + auto hlo_proto = std::make_unique(); + *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto(); + *hlo_proto->mutable_buffer_assignment() = + cpu_executable->buffer_assignment().ToProto(); + cpu_executable->set_hlo_proto(std::move(hlo_proto)); + + return cpu_executable; } -StatusOr> CpuCompiler::Export( +absl::StatusOr> CpuCompiler::Export( Executable* executable) const { auto* cpu_executable = tensorflow::down_cast(executable); if (!cpu_executable) return Internal("Could not downcast Executable to CpuExecutable"); - HloModuleProto module_proto = cpu_executable->module().ToProto(); - TF_ASSIGN_OR_RETURN(auto obj_file, cpu_executable->GetObjFile()); - TF_ASSIGN_OR_RETURN(auto mlir_module, cpu_executable->GetMlirModule()); - TF_ASSIGN_OR_RETURN(XlaFrameworkMapping xla_framework_mapping, - cpu_executable->GetXlaFrameworkMapping()); + if (cpu_executable->obj_files().size() != 1) { + return absl::InternalError( + absl::StrCat("Can't export CPU execuable, expected exactly one object " + "file but got: ", + cpu_executable->obj_files().size())); + } + + return {std::make_unique( + &cpu_executable->module(), &cpu_executable->buffer_assignment(), + cpu_executable->module_name(), cpu_executable->obj_files()[0])}; +} - std::unique_ptr result = - std::make_unique( - module_proto, obj_file, mlir_module, xla_framework_mapping); - return result; +absl::StatusOr> +CpuCompiler::LoadAotCompilationResult( + const std::string& serialized_aot_result) { + return CpuExecutableAotCompilationResult::FromString(serialized_aot_result); } } // namespace cpu diff --git a/xla/service/cpu/cpu_compiler.h b/xla/service/cpu/cpu_compiler.h index 0d7ce6eb53862..69ba7f498f3db 100644 --- a/xla/service/cpu/cpu_compiler.h +++ b/xla/service/cpu/cpu_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/cpu/executable.pb.h" #include "xla/service/cpu/target_machine_features.h" +#include "xla/service/cpu/xla_framework.h" #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" @@ -40,11 +41,14 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" +namespace mlir { +class DialectRegistry; +} // namespace mlir + namespace xla { namespace cpu { class CpuExecutable; -class XlaFrameworkMapping; // This class wraps the configurability options that LLVM exposes including: the // target triple, the target cpu and the target features. It also includes the @@ -96,44 +100,12 @@ class CpuAotCompilationOptions : public AotCompilationOptions { bool use_mlir_hlo_lowering_ = false; }; -class CpuXlaRuntimeAotCompilationResult : public AotCompilationResult { - public: - CpuXlaRuntimeAotCompilationResult(HloModuleProto hlo, - std::string_view obj_file, - std::string_view mlir_module, - XlaFrameworkMapping xla_framework_mapping); - - explicit CpuXlaRuntimeAotCompilationResult( - XlaRuntimeCpuExecutableProto executable) - : xla_runtime_cpu_executable_(executable) {} - - StatusOr SerializeAsString() const override { - return xla_runtime_cpu_executable_.SerializeAsString(); - } - - static StatusOr> - FromString(const std::string& serialized) { - XlaRuntimeCpuExecutableProto xla_runtime_cpu_executable; - if (!xla_runtime_cpu_executable.ParseFromString(serialized)) { - return InternalError("Failed to parse serialized JitRtExecutableProto."); - } - return std::make_unique( - xla_runtime_cpu_executable); - } - - StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) override; - - private: - XlaRuntimeCpuExecutableProto xla_runtime_cpu_executable_; -}; - class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( ObjectFileData object_file_data, std::vector buffer_infos, - int64_t result_buffer_index, + int64_t result_buffer_index, std::unique_ptr module, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult() override = default; @@ -147,6 +119,9 @@ class CpuAotCompilationResult : public AotCompilationResult { } int64_t result_buffer_index() const { return result_buffer_index_; } + const HloModule* optimized_module() const override; + std::unique_ptr consume_optimized_module() override; + private: // Contains the compiled computation: an object file. const ObjectFileData object_file_data_; @@ -160,6 +135,9 @@ class CpuAotCompilationResult : public AotCompilationResult { // parameter when calling the compiled computation. const int64_t result_buffer_index_; + // Contains the optimized HLO module. + std::unique_ptr module_; + // Contains an instance of HloProfilePrinterData if HLO profiling is enabled, // otherwise is nullptr. std::unique_ptr hlo_profile_printer_data_; @@ -173,26 +151,22 @@ class CpuAotCompilationResult : public AotCompilationResult { class CpuCompiler : public LLVMCompiler { public: CpuCompiler(); - explicit CpuCompiler(bool allow_sparse_shapes); ~CpuCompiler() override = default; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, const CompileOptions& options) override; - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr> AssignBuffers( - HloModule* module, se::StreamExecutor* stream_exec) override; - - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) override; @@ -200,18 +174,20 @@ class CpuCompiler : public LLVMCompiler { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; - StatusOr> Export( + absl::StatusOr> Export( Executable* executable) const override; // Returns a (deserialized) AotCompilationResult from a serialized // AotCompilationResult. - StatusOr> LoadAotCompilationResult( - const std::string& serialized_aot_result) override { - return CpuXlaRuntimeAotCompilationResult::FromString(serialized_aot_result); - } + absl::StatusOr> + LoadAotCompilationResult(const std::string& serialized_aot_result) override; - StatusOr> CompileXlaRuntimeCpuExecutable( - std::unique_ptr module); + // The optional `registry` supports MLIR dialects and plugins to be loaded + // during optimization. If non-null, it will be used to construct relevant + // MLIR contexts. + absl::StatusOr> CompileXlaRuntimeCpuExecutable( + std::unique_ptr module, + mlir::DialectRegistry* registry = nullptr); private: // Initialize the LLVM target. @@ -221,6 +197,7 @@ class CpuCompiler : public LLVMCompiler { // correctness. Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine, + const CompileOptions& compile_options, bool is_mlir_compile = false); // Runs HLO passes up to and including layout assignment. @@ -232,17 +209,14 @@ class CpuCompiler : public LLVMCompiler { // Runs HLO passes after layout assignment. Status RunHloPassesAfterLayoutAssn( HloModule* module, bool is_aot_compile, - LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile); + LLVMTargetMachineFeatures* target_machine_features, + const CompileOptions& compile_options, bool is_mlir_compile); - StatusOr> CompileLegacyCpuExecutable( + absl::StatusOr> CompileLegacyCpuExecutable( std::unique_ptr module); CpuCompiler(const CpuCompiler&) = delete; CpuCompiler& operator=(const CpuCompiler&) = delete; - - // Flag that can be used to override bail-out on sparse shapes. - // When set, buffer assignment assigns zero sizes to these shapes. - const bool allow_sparse_shapes_ = false; }; } // namespace cpu diff --git a/xla/service/cpu/cpu_compiler_registerer.cc b/xla/service/cpu/cpu_compiler_registerer.cc index 45cda9ae5d4d5..a1acf3aa24524 100644 --- a/xla/service/cpu/cpu_compiler_registerer.cc +++ b/xla/service/cpu/cpu_compiler_registerer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index e72a27424ec12..0852490baa0a1 100644 --- a/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index 404a2557d864e..13da05afe2025 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ namespace cpu { namespace runtime = ::xla::runtime; -StatusOr> CpuExecutable::Create( +absl::StatusOr> CpuExecutable::Create( std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, @@ -91,7 +91,7 @@ StatusOr> CpuExecutable::Create( return executable; } -StatusOr> CpuExecutable::Create( +absl::StatusOr> CpuExecutable::Create( std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map, @@ -100,6 +100,9 @@ StatusOr> CpuExecutable::Create( std::unique_ptr executable(new CpuExecutable( std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map), std::move(assignment))); + executable->set_ir_module_string( + xla_runtime_executable->GetExecutable().take_ir_module_string()); + executable->module_name_ = "main"; executable->xla_runtime_executable_ = std::move(xla_runtime_executable); return executable; } @@ -124,7 +127,7 @@ CpuExecutable::~CpuExecutable() { } } -static StatusOr MemoryForAllocation( +static absl::StatusOr MemoryForAllocation( const BufferAllocation& allocation, absl::Span arguments, se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) { @@ -160,9 +163,10 @@ static StatusOr MemoryForAllocation( return MaybeOwningDeviceMemory{std::move(out)}; } -StatusOr> CpuExecutable::CreateBufferTable( - se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, - absl::Span arguments) { +absl::StatusOr> +CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator, + int device_ordinal, + absl::Span arguments) { std::vector buffers( assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() @@ -258,14 +262,14 @@ Status CpuExecutable::ExecuteComputeFunction( std::optional error_message = CustomCallStatusGetMessage(&status); if (error_message) { - return InternalError("CustomCall failed: %s", *error_message); + return Internal("CustomCall failed: %s", *error_message); } } return OkStatus(); } -StatusOr> CpuExecutable::LoadFromObjFile( +absl::StatusOr> CpuExecutable::LoadFromObjFile( std::unique_ptr hlo_module, absl::string_view obj_file, absl::string_view mlir_module, std::unique_ptr buffer_assignment, @@ -282,7 +286,7 @@ StatusOr> CpuExecutable::LoadFromObjFile( // Load MLIR module behind the compiled object file. auto module = mlir::parseSourceString(mlir_module, ctx.get()); - if (!module) return InternalError("Failed to parse AOT compiled module"); + if (!module) return Internal("Failed to parse AOT compiled module"); llvm::StringRef data(obj_file.data(), obj_file.size()); auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name()); @@ -293,16 +297,16 @@ StatusOr> CpuExecutable::LoadFromObjFile( absl::StatusOr sig = opts.compiler.type_converter.Convert(func_type); if (!sig.ok()) - return InternalError("Type converter failed to convert function type"); + return Internal("Type converter failed to convert function type"); mlir::FunctionType runtime_type = opts.compiler.calling_convention(func_type); if (!runtime_type) - return InternalError("Calling convention failed to convert function type"); + return Internal("Calling convention failed to convert function type"); absl::StatusOr runtime_sig = opts.compiler.type_converter.Convert(runtime_type); if (!runtime_sig.ok()) - return InternalError( + return Internal( "Type converter failed to convert runtime function type"); // Cpu executable has a single exported function. @@ -315,7 +319,7 @@ StatusOr> CpuExecutable::LoadFromObjFile( opts.compiler.symbols_binding); if (!executable.ok()) - return InternalError("Failed to load XLA Runtime executable: %s", + return Internal("Failed to load XLA Runtime executable: %s", executable.status().message()); // Move runtime::Executable ownership to the XlaRuntimeCpuExecutable. @@ -329,7 +333,7 @@ StatusOr> CpuExecutable::LoadFromObjFile( std::move(xla_runtime_executable)); } -StatusOr CpuExecutable::CreateResultShapedBuffer( +absl::StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, absl::Span buffers, absl::Span arguments) { @@ -433,12 +437,12 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( // which should point to a runtime::MemrefType. // Note: 'descriptor_index' and 'operand_index' are just used for error // reporting. -static StatusOr BufferToMemref( +static absl::StatusOr BufferToMemref( const BufferDesc& descriptor, const runtime::Type& operand_type, size_t descriptor_index, size_t operand_index) { auto* memref = llvm::dyn_cast(&operand_type); if (!memref) { - return InternalError( + return Internal( "Cannot convert descriptor %zu (operand_index %zu): " "the corresponding type in the signature is a %s, " "not a MemrefType.", @@ -492,7 +496,7 @@ Status XlaRuntimeCpuExecutable::Execute( // Verify that the number of arguments in the mapping matches the signature. // Add one to num_arguments to account for the signature's execution context. if (num_arguments + 1 != signature.num_operands()) { - return InternalError( + return Internal( "Wrong number of arguments: got %zu via XLA FrameworkMapping, expected " "%d.", num_arguments, static_cast(signature.num_operands()) - 1); @@ -508,7 +512,7 @@ Status XlaRuntimeCpuExecutable::Execute( size_t operand_index = arguments.size() + 1; const runtime::Type* operand_type = signature.operand(operand_index); - StatusOr memref = BufferToMemref( + absl::StatusOr memref = BufferToMemref( descriptor, *operand_type, descriptor_index, operand_index); if (!memref.ok()) { return memref.status(); @@ -548,7 +552,7 @@ Status XlaRuntimeCpuExecutable::Execute( GetExecutable().InitializeCallFrame(arguments, &call_frame, /*verify_arguments=*/false); !status.ok()) { - return InternalError("Failed to initialize call frame: %s.", + return Internal("Failed to initialize call frame: %s.", status.message()); } @@ -578,14 +582,14 @@ Status XlaRuntimeCpuExecutable::Execute( GetExecutable().Execute(call_frame, opts); if (auto status = GetExecutable().ReturnResults(converter, &call_frame); !status.ok()) { - return InternalError("Failed to execute XLA Runtime executable: %s%s%s.", + return Internal("Failed to execute XLA Runtime executable: %s%s%s.", status.message(), diagnostic.empty() ? "" : ": ", diagnostic); } return OkStatus(); } -StatusOr CpuExecutable::ExecuteAsyncOnStream( +absl::StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { diff --git a/xla/service/cpu/cpu_executable.h b/xla/service/cpu/cpu_executable.h index dcf43b56e9423..c0c0f51f4194d 100644 --- a/xla/service/cpu/cpu_executable.h +++ b/xla/service/cpu/cpu_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,10 +76,10 @@ class XlaRuntimeCpuExecutable { } } - StatusOr GetObjFile() const { + absl::StatusOr GetObjFile() const { if (!std::holds_alternative>( executable_)) { - return InternalError("No JitExecutable"); + return Internal("No JitExecutable"); } runtime::JitExecutable* jit_executable = @@ -87,15 +87,15 @@ class XlaRuntimeCpuExecutable { std::unique_ptr obj_file = jit_executable->DefaultExecutable()->obj_file(); if (!obj_file) - return InternalError("XlaRuntimeCpuExecutable didn't save the obj file"); + return Internal("XlaRuntimeCpuExecutable didn't save the obj file"); return std::string_view(obj_file->getBuffer()); } - StatusOr GetMlirModule() const { + absl::StatusOr GetMlirModule() const { if (!std::holds_alternative>( executable_)) { - return InternalError("No JitExecutable"); + return Internal("No JitExecutable"); } runtime::JitExecutable* jit_executable = @@ -124,7 +124,7 @@ class XlaRuntimeCpuExecutable { // architecture, so JIT-ed code and host code share the same ABI. class CpuExecutable : public Executable { public: - static StatusOr> Create( + static absl::StatusOr> Create( std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, @@ -132,7 +132,7 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); // XLA Runtime factory method. - static StatusOr> Create( + static absl::StatusOr> Create( std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map, @@ -149,7 +149,7 @@ class CpuExecutable : public Executable { return xla_runtime_executable_->Execute(descriptor_table, run_options); } - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; @@ -163,13 +163,19 @@ class CpuExecutable : public Executable { // Returns an Executable that is loaded from an object file (XLA program // compiled to a native function using the XLA Runtime stack). - static StatusOr> LoadFromObjFile( + static absl::StatusOr> LoadFromObjFile( std::unique_ptr hlo_module, absl::string_view obj_file, absl::string_view mlir_module, std::unique_ptr buffer_assignment, XlaFrameworkMapping xla_framework_mapping, runtime::JitExecutable::Options opts); + absl::Span obj_files() const { return obj_files_; } + + void set_obj_files(std::vector obj_files) { + obj_files_ = std::move(obj_files); + } + // This should be called after set_ir_module_string. const std::string& ir_module_string() const { return ir_module_string_; } @@ -177,6 +183,8 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } + const std::string& module_name() const { return module_name_; } + static int64_t ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. @@ -193,17 +201,17 @@ class CpuExecutable : public Executable { int64_t SizeOfGeneratedCodeInBytes() const override; - StatusOr GetObjFile() const { + absl::StatusOr GetObjFile() const { if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable"); return xla_runtime_executable_->GetObjFile(); } - StatusOr GetMlirModule() const { + absl::StatusOr GetMlirModule() const { if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable"); return xla_runtime_executable_->GetMlirModule(); } - StatusOr GetXlaFrameworkMapping() const { + absl::StatusOr GetXlaFrameworkMapping() const { if (!IsXlaRuntime()) return Unimplemented("Not an XLA Runtime executable"); return xla_runtime_executable_->xla_framework_mapping(); } @@ -226,7 +234,7 @@ class CpuExecutable : public Executable { // // - buffers_to_free: buffers whose ownership was donated by the caller that // are to be freed by the caller. - StatusOr> CreateBufferTable( + absl::StatusOr> CreateBufferTable( se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, absl::Span arguments); @@ -234,7 +242,7 @@ class CpuExecutable : public Executable { // result of the computation, moving buffers out of allocated_buffers and into // the result as appropriate. The addresses are set according to buffer // assignment. - StatusOr CreateResultShapedBuffer( + absl::StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, absl::Span buffers, absl::Span arguments); @@ -246,6 +254,11 @@ class CpuExecutable : public Executable { // The JIT containing compiled modules. std::unique_ptr jit_; + // Object files (machine code) compiled from an HLO module by the JIT + // compiler. We capture all object files created by SimpleOrcJIT so we can + // export them to AOT compilation result. + std::vector obj_files_; + // Buffer assignment for the buffers we need to allocate. const std::unique_ptr assignment_; diff --git a/xla/service/cpu/cpu_executable_run_options.h b/xla/service/cpu/cpu_executable_run_options.h new file mode 100644 index 0000000000000..ee1a47e138228 --- /dev/null +++ b/xla/service/cpu/cpu_executable_run_options.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_ +#define XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_ + +#include "xla/service/cpu/collectives_interface.h" + +namespace xla::cpu { + +// CPU-specific executable options. +// We keep these separate from ExecutableRunOptions to avoid adding +// dependencies to ExecutableRunOptions. +class CpuExecutableRunOptions { + public: + CpuExecutableRunOptions& set_collectives(CollectivesInterface* collectives) { + collectives_ = collectives; + return *this; + } + CollectivesInterface* collectives() const { return collectives_; } + + private: + // For cross-process collectives, use this collective implementation to + // communicate. + CollectivesInterface* collectives_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_ diff --git a/xla/service/cpu/cpu_float_support.cc b/xla/service/cpu/cpu_float_support.cc new file mode 100644 index 0000000000000..0bb4dd8e875a7 --- /dev/null +++ b/xla/service/cpu/cpu_float_support.cc @@ -0,0 +1,65 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/cpu/cpu_float_support.h" + +#include "xla/service/cpu/onednn_matmul_rewriter.h" + +namespace xla { +namespace cpu { + +bool CpuFloatSupport::IsSupported(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + // oneDNN rewritable ops + case HloOpcode::kDot: + return LowPrecisionType() == BF16 && + OneDnnMatMulRewriter::ShouldRewrite(&hlo); + // Collective ops. + case HloOpcode::kAllGather: + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kAllReduceDone: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kReduceScatter: + // Data movement only ops. + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kGather: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kScatter: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: + // Other special ops. + case HloOpcode::kBitcast: + return true; + default: + return false; + } +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/cpu_float_support.h b/xla/service/cpu/cpu_float_support.h new file mode 100644 index 0000000000000..38c6a9bd81610 --- /dev/null +++ b/xla/service/cpu/cpu_float_support.h @@ -0,0 +1,52 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_CPU_FLOAT_SUPPORT_H_ +#define XLA_SERVICE_CPU_CPU_FLOAT_SUPPORT_H_ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/float_support.h" + +namespace xla { +namespace cpu { + +class CpuFloatSupport : public FloatSupport { + public: + explicit CpuFloatSupport(PrimitiveType low_precision_type) + : FloatSupport(low_precision_type) {} + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return FloatSupport::SupportsLowPrecisionOperand(hlo, operand_index) || + IsSupported(hlo); + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return FloatSupport::SupportsLowPrecisionOutput(hlo) || IsSupported(hlo); + } + + private: + bool IsSupported(const HloInstruction& hlo) const; + // Performs early check for things that cannot be delayed becuase some later + // passes may change the shape of dot inputs. + bool DotSupported(const HloInstruction& hlo) const; +}; + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_CPU_FLOAT_SUPPORT_H_ diff --git a/xla/service/cpu/cpu_instruction_fusion.cc b/xla/service/cpu/cpu_instruction_fusion.cc index a19bbbc3290eb..1743cad5c2c5e 100644 --- a/xla/service/cpu/cpu_instruction_fusion.cc +++ b/xla/service/cpu/cpu_instruction_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/fusion_node_indexing_evaluation.h" +#include "xla/service/instruction_fusion.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" namespace xla { @@ -96,10 +97,7 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return "Fusion is not profitable."; } - if (auto fusible = InstructionFusion::ShouldFuse(consumer, operand_index); - !fusible) { - return fusible; - } + RETURN_IF_NOT_FUSIBLE(InstructionFusion::ShouldFuse(consumer, operand_index)); // Fuse constants in general but avoid creating 2-instruction fusions with // just a constant and another node. diff --git a/xla/service/cpu/cpu_instruction_fusion.h b/xla/service/cpu/cpu_instruction_fusion.h index 1e090cb69b396..15a7316ba3afc 100644 --- a/xla/service/cpu/cpu_instruction_fusion.h +++ b/xla/service/cpu/cpu_instruction_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,9 +31,9 @@ class CpuInstructionFusion : public InstructionFusion { ~CpuInstructionFusion() override = default; using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { fusion_node_evaluations_.clear(); return InstructionFusion::Run(module, execution_threads); } diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index c9f0d4de950b0..de8edeb109675 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_layout_assignment.cc b/xla/service/cpu/cpu_layout_assignment.cc index 8b124ddaa6039..35371ea946369 100644 --- a/xla/service/cpu/cpu_layout_assignment.cc +++ b/xla/service/cpu/cpu_layout_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,18 @@ limitations under the License. #include "xla/service/cpu/cpu_layout_assignment.h" +#include #include +#include +#include #include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/map_util.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" +#include "xla/shape_util.h" #include "tsl/platform/errors.h" namespace xla { @@ -78,12 +84,17 @@ static optional ShouldMakeOperandColumnMajor( return it->second ? operand_idx : nullopt; } -static Shape RowMajorShape(const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; +static Shape RowMajorShape(Shape shape) { + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex& index) { + if (!subshape->IsArray()) { + return; + } + std::vector dimension_order(subshape->dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *subshape->mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + }); + return shape; } static Shape ColMajorShape(const Shape& old_shape) { @@ -103,6 +114,8 @@ static bool OperandsAndResultMustHaveRowMajorLayout( } else if (instr.opcode() == HloOpcode::kDot) { return DotOperandsAndResultMustHaveRowMajorLayout(instr, target_machine_features); + } else if (instr.opcode() == HloOpcode::kCustomCall) { + return instr.custom_call_target() == "TopK"; } return false; } @@ -126,6 +139,20 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR( SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx)); + } else if (instruction->opcode() == HloOpcode::kReduceScatter) { + // XLA:CPU can only support reduce-scatter where the scatter dimension + // is the most major dimension in the layout. + auto ars = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()), + ars)); + } else if (instruction->opcode() == HloOpcode::kAllGather) { + // XLA:CPU can only support all-gathers where the gather dimension is the + // most major dimension in the layout. + auto ag = Cast(instruction); + TF_RETURN_IF_ERROR(SetInstructionLayout( + ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()), + ag)); } else { for (int64_t operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { diff --git a/xla/service/cpu/cpu_layout_assignment.h b/xla/service/cpu/cpu_layout_assignment.h index 3c0bdb7fc2ea3..35ecde418bc60 100644 --- a/xla/service/cpu/cpu_layout_assignment.h +++ b/xla/service/cpu/cpu_layout_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_layout_assignment_test.cc b/xla/service/cpu/cpu_layout_assignment_test.cc index 1bb2d9e034616..3824dfdf14384 100644 --- a/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/xla/service/cpu/cpu_layout_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -261,7 +261,7 @@ struct DotOutputFusionLayoutAssignmentResult { const HloInstruction* addend_fusion_param; }; -static StatusOr RunDotOutputFusion( +static absl::StatusOr RunDotOutputFusion( HloModule* module, const std::string& test_name, int m, int k, int n, const int64_t dot_operand_idx_in_add) { DotOutputFusionLayoutAssignmentResult result; diff --git a/xla/service/cpu/cpu_options.cc b/xla/service/cpu/cpu_options.cc index 6f4d37ab9d95b..2ceff37b653a4 100644 --- a/xla/service/cpu/cpu_options.cc +++ b/xla/service/cpu/cpu_options.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,8 @@ const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; const char* const kDisableSlpVectorizer = "xla_cpu_disable_slp_vectorizer"; +const char* const kXlaCpuExperimentalOverridePipeline = + "xla_cpu_experimental_override_pipeline"; } // namespace @@ -105,6 +107,17 @@ std::optional> LlvmIrGemmTileSize( tile_size_n_in_vector_width); } +std::optional ExperimentalOverriddenPipeline( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kXlaCpuExperimentalOverridePipeline); + if (it == extra_options_map.end()) { + return std::nullopt; + } + return it->second; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/cpu_options.h b/xla/service/cpu/cpu_options.h index 41b60b0f71da6..23476ce22e20a 100644 --- a/xla/service/cpu/cpu_options.h +++ b/xla/service/cpu/cpu_options.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_CPU_OPTIONS_H_ #define XLA_SERVICE_CPU_CPU_OPTIONS_H_ +#include + #include "xla/service/hlo_module_config.h" // Helper functions for querying options that are specific to the CPU backend. @@ -31,6 +33,8 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); std::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); std::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); +std::optional ExperimentalOverriddenPipeline( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/xla/service/cpu/cpu_runtime.cc b/xla/service/cpu/cpu_runtime.cc index a4090d67fc8d9..a85ae39067809 100644 --- a/xla/service/cpu/cpu_runtime.cc +++ b/xla/service/cpu/cpu_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,12 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -38,6 +43,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/in_process_collectives.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" @@ -46,6 +52,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -142,6 +149,9 @@ extern const char* const kTracingStartSymbolName = extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce"; +extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather"; +extern const char* const kReduceScatterSymbolName = + "__xla_cpu_runtime_ReduceScatter"; extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll"; extern const char* const kCollectivePermuteSymbolName = "__xla_cpu_runtime_CollectivePermute"; @@ -150,12 +160,18 @@ extern const char* const kPartitionIdSymbolName = extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId"; extern const char* const kOneDnnMatMulSymbolName = "__xla_cpu_runtime_OneDnnMatMul"; +extern const char* const kOneDnnSoftmaxSymbolName = + "__xla_cpu_runtime_OneDnnSoftmax"; +extern const char* const kOneDnnLayerNormSymbolName = + "__xla_cpu_runtime_OneDnnLayerNorm"; +extern const char* const kOneDnnMatMulReorderSymbolName = + "__xla_cpu_runtime_OneDnnMatMulReorder"; namespace { // Inverses the encoding of a Shape protobuf into an LLVM global variable. -StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, - int32_t size_bytes) { +absl::StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, + int32_t size_bytes) { ShapeProto shape_proto; if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) { return tsl::errors::Internal("Failed parsing the shape proto"); @@ -169,7 +185,7 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, } std::string ShapeString(const void* shape_ptr, int32_t shape_length) { - StatusOr shape = + absl::StatusOr shape = DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); if (shape.ok()) { return ShapeUtil::HumanStringWithLayout(shape.value()); @@ -221,7 +237,7 @@ void ReleaseInfeedBufferAfterDequeueImpl( << device_ordinal; XfeedManager* xfeed = GetXfeedManager(device_ordinal); - StatusOr shape = + absl::StatusOr shape = DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, std::move(shape)); @@ -259,7 +275,7 @@ void ReleaseOutfeedBufferAfterPopulationImpl( << device_ordinal; XfeedManager* xfeed = GetXfeedManager(device_ordinal); - StatusOr shape = + absl::StatusOr shape = DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, std::move(shape)); @@ -312,7 +328,29 @@ CollectivesInterface* GetInProcessCollectivesImpl() { return c; } -absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); } +CollectivesInterface* GetCollectivesImpl( + const ExecutableRunOptions* run_options) { + if (run_options->cpu_executable_run_options() && + run_options->cpu_executable_run_options()->collectives()) { + return run_options->cpu_executable_run_options()->collectives(); + } + return GetInProcessCollectivesImpl(); +} + +absl::Duration DefaultCollectiveTimeout() { return absl::Minutes(30); } + +absl::StatusOr RankInGlobalDevices( + absl::Span devices, GlobalDeviceId device) { + auto it = absl::c_find(devices, device); + if (it == devices.end()) { + return InvalidArgument( + "Device %d not present in global devices %s.", device.value(), + absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })); + } + return std::distance(devices.begin(), it); +} ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllToAllImpl(const ExecutableRunOptions* run_options, @@ -330,12 +368,14 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, group, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + CollectivesInterface* collectives = GetCollectivesImpl(run_options); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(source_buffers, + sizeof(void*) * num_buffers); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(destination_buffers, + sizeof(void*) * num_buffers); auto communicator = collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); TF_CHECK_OK(communicator->AllToAll( @@ -345,6 +385,62 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, DefaultCollectiveTimeout())); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void AllGatherImpl(const ExecutableRunOptions* run_options, + int32_t channel_id_present, int32_t use_global_device_ids, + int64_t op_id, const void* replica_groups_str, + int32_t replica_groups_str_size, int64_t buffer_size, + void* source_buffer, void* destination_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + use_global_device_ids, op_id); + + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + + CollectivesInterface* collectives = GetCollectivesImpl(run_options); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size, + source_buffer, destination_buffer, + DefaultCollectiveTimeout())); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY +void ReduceScatterImpl(const ExecutableRunOptions* run_options, + const void* replica_groups_str, + int32_t replica_groups_str_size, + int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + int32_t reduction_kind, int32_t element_type, + int64_t chunk_elems, void* input_buffer, + void* output_buffer) { + GlobalDeviceId device(GetDeviceOrdinal(run_options)); + std::string_view replica_groups_serialized( + static_cast(replica_groups_str), replica_groups_str_size); + std::vector group = + ParseReplicaGroupsOnly(replica_groups_serialized).value(); + RendezvousKey rendezvous_key = + GetRendezvousKey(run_options, device, group, channel_id_present, + use_global_device_ids, op_id); + + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); + + CollectivesInterface* collectives = GetCollectivesImpl(run_options); + + auto communicator = + collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + TF_CHECK_OK(communicator->ReduceScatter( + rendezvous_key, static_cast(reduction_kind), + static_cast(element_type), chunk_elems, input_buffer, + output_buffer, DefaultCollectiveTimeout())); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void AllReduceImpl(const ExecutableRunOptions* run_options, const void* replica_groups_str, @@ -370,11 +466,9 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, CHECK((num_buffers > 1 && shape.IsTuple()) || (num_buffers == 1 && LayoutUtil::IsDenseArray(shape))); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + CollectivesInterface* collectives = GetCollectivesImpl(run_options); auto communicator = collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); @@ -421,11 +515,9 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, GetRendezvousKey(run_options, device, {}, channel_id_present, /*use_global_device_ids=*/std::nullopt, op_id); - auto it = absl::c_find(rendezvous_key.global_devices, device); - CHECK(it != rendezvous_key.global_devices.end()); - int rank = std::distance(rendezvous_key.global_devices.begin(), it); + int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetInProcessCollectivesImpl(); + CollectivesInterface* collectives = GetCollectivesImpl(run_options); auto communicator = collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); @@ -503,6 +595,31 @@ void __xla_cpu_runtime_AllToAll(const xla::ExecutableRunOptions* run_options, destination_buffers); } +void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options, + int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + const void* replica_groups_str, + int32_t replica_groups_str_size, + int64_t buffer_size, void* source_buffer, + void* destination_buffer) { + return xla::cpu::runtime::AllGatherImpl( + run_options, channel_id_present, use_global_device_ids, op_id, + replica_groups_str, replica_groups_str_size, buffer_size, source_buffer, + destination_buffer); +} + +void __xla_cpu_runtime_ReduceScatter( + const xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id, + int32_t reduction_kind, int32_t element_type, int64_t chunk_elems, + void* input_buffer, void* output_buffer) { + return xla::cpu::runtime::ReduceScatterImpl( + run_options, replica_groups_str, replica_groups_str_size, + channel_id_present, use_global_device_ids, op_id, reduction_kind, + element_type, chunk_elems, input_buffer, output_buffer); +} + void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options, const void* replica_groups_str, int32_t replica_groups_str_size, diff --git a/xla/service/cpu/cpu_runtime.h b/xla/service/cpu/cpu_runtime.h index 361a116e7c300..e4fc06fc85bd5 100644 --- a/xla/service/cpu/cpu_runtime.h +++ b/xla/service/cpu/cpu_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,7 +84,12 @@ extern const char* const kReplicaIdSymbolName; extern const char* const kTracingStartSymbolName; extern const char* const kTracingEndSymbolName; extern const char* const kAllToAllSymbolName; +extern const char* const kAllGatherSymbolName; +extern const char* const kReduceScatterSymbolName; extern const char* const kOneDnnMatMulSymbolName; +extern const char* const kOneDnnSoftmaxSymbolName; +extern const char* const kOneDnnLayerNormSymbolName; +extern const char* const kOneDnnMatMulReorderSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. @@ -195,6 +200,19 @@ extern void __xla_cpu_runtime_AllToAll( int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size, void** source_buffers, void** destination_buffers); +extern void __xla_cpu_runtime_AllGather( + const xla::ExecutableRunOptions* run_options, int32_t channel_id_present, + int32_t use_global_device_ids, int64_t op_id, + const void* replica_groups_str, int32_t replica_groups_str_size, + int64_t buffer_size, void* source_buffer, void* destination_buffer); + +void __xla_cpu_runtime_ReduceScatter( + const xla::ExecutableRunOptions* run_options, + const void* replica_groups_str, int32_t replica_groups_str_size, + int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id, + int32_t reduction_kind, int32_t element_type, int64_t chunk_elems, + void* input_buffer, void* output_buffer); + // Write the partition ID into the output buffer. extern void __xla_cpu_runtime_PartitionId( const xla::ExecutableRunOptions* run_options, void* output_buffer); diff --git a/xla/service/cpu/cpu_runtime_test.cc b/xla/service/cpu/cpu_runtime_test.cc index ee0c92fb702fa..107119e32e00c 100644 --- a/xla/service/cpu/cpu_runtime_test.cc +++ b/xla/service/cpu/cpu_runtime_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_symbol_repository.h b/xla/service/cpu/cpu_symbol_repository.h index da5bd03459a25..425ca211b43ef 100644 --- a/xla/service/cpu/cpu_symbol_repository.h +++ b/xla/service/cpu/cpu_symbol_repository.h @@ -1,7 +1,7 @@ #ifndef XLA_SERVICE_CPU_CPU_SYMBOL_REPOSITORY_H_ #define XLA_SERVICE_CPU_CPU_SYMBOL_REPOSITORY_H_ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_transfer_manager.cc b/xla/service/cpu/cpu_transfer_manager.cc index aa4eeb1a338e3..27e62fc1af723 100644 --- a/xla/service/cpu/cpu_transfer_manager.cc +++ b/xla/service/cpu/cpu_transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" #include "xla/util.h" @@ -61,7 +62,7 @@ Status CpuTransferManager::ReadDynamicShapes(se::Stream* stream, device_shape); } TF_ASSIGN_OR_RETURN(auto platform, - se::MultiPlatformManager::PlatformWithId(PlatformId())); + se::PlatformManager::PlatformWithId(PlatformId())); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); return ReadDynamicShapesOnCpu(device_buffer, device_shape, compiler->ShapeSizeBytesFunction()); diff --git a/xla/service/cpu/cpu_transfer_manager.h b/xla/service/cpu/cpu_transfer_manager.h index bf9b10fff4702..9cdf0478a47cb 100644 --- a/xla/service/cpu/cpu_transfer_manager.h +++ b/xla/service/cpu/cpu_transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/cpu_xfeed.cc b/xla/service/cpu/cpu_xfeed.cc index fbee9570b654a..b6fea6cb0d385 100644 --- a/xla/service/cpu/cpu_xfeed.cc +++ b/xla/service/cpu/cpu_xfeed.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer { int32_t length() override { return length_; } void* data() override { return buffer_; } - void Done(StatusOr /*shape*/) override { delete this; } + void Done(absl::StatusOr /*shape*/) override { delete this; } private: int32_t length_; @@ -62,14 +62,14 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { CpuOutfeedBuffer(void* destination, int32_t length) : destination_(destination), length_(length) {} - StatusOr WaitForNotification() { + absl::StatusOr WaitForNotification() { done_.WaitForNotification(); return status_; } int32_t length() override { return length_; } void* data() override { return destination_; } - void Done(StatusOr shape) override { + void Done(absl::StatusOr shape) override { status_ = std::move(shape); done_.Notify(); } @@ -77,13 +77,13 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { private: void* destination_; int32_t length_; - StatusOr status_; + absl::StatusOr status_; tsl::Notification done_; }; // Transfers infeed data to device. InfeedBuffer->Done() must be called to // clean up the memory allocated for InfeedBuffer. -StatusOr TransferBufferToInfeedInternal( +absl::StatusOr TransferBufferToInfeedInternal( int64_t size, const void* source) { if (size > std::numeric_limits::max()) { return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes", @@ -114,7 +114,7 @@ Status TransferBufferToInfeed(int device_ordinal, int64_t size, return OkStatus(); } -StatusOr TransferBuffersFromOutfeedInternal( +absl::StatusOr TransferBuffersFromOutfeedInternal( int device_ordinal, absl::Span> buffer_data, bool is_tuple) { std::vector> buffers; @@ -160,14 +160,14 @@ StatusOr TransferBuffersFromOutfeedInternal( return std::move(outfed_shapes[0]); } -StatusOr TransferArrayBufferFromOutfeed(int device_ordinal, - void* destination, - int64_t size_bytes) { +absl::StatusOr TransferArrayBufferFromOutfeed(int device_ordinal, + void* destination, + int64_t size_bytes) { return TransferBuffersFromOutfeedInternal( device_ordinal, {{destination, size_bytes}}, /*is_tuple=*/false); } -StatusOr TransferTupleBuffersFromOutfeed( +absl::StatusOr TransferTupleBuffersFromOutfeed( int device_ordinal, absl::Span> buffer_data) { return TransferBuffersFromOutfeedInternal(device_ordinal, buffer_data, @@ -281,7 +281,8 @@ Status ReadDynamicShapesOnCpu( TF_RET_CHECK(device_shape->is_dynamic()); Shape original_device_shape = *device_shape; TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachElementWithStatus( - [&](const ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + [&](const ShapeIndex& index, + const se::DeviceMemoryBase& buffer) -> absl::Status { const Shape& buffer_shape = ShapeUtil::GetSubshape(*device_shape, index); if (buffer_shape.IsTuple()) { diff --git a/xla/service/cpu/cpu_xfeed.h b/xla/service/cpu/cpu_xfeed.h index 418768f6c9e3a..26512839d5017 100644 --- a/xla/service/cpu/cpu_xfeed.h +++ b/xla/service/cpu/cpu_xfeed.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/dot_op_emitter.cc b/xla/service/cpu/dot_op_emitter.cc index e7c1371cbdbd6..c6bad1e54ec04 100644 --- a/xla/service/cpu/dot_op_emitter.cc +++ b/xla/service/cpu/dot_op_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -326,7 +326,7 @@ Status DotOpEmitter::EmitLinalgMatmul() { /*outputs=*/mlir::ValueRange{a}, /*indexingMaps=*/ mlir::AffineMap::inferFromExprList( - {b_exprs, c_exprs, parallel_exprs}), + {b_exprs, c_exprs, parallel_exprs}, context), /*iteratorTypes=*/iteratorTypes, [](mlir::OpBuilder& b, mlir::Location loc, mlir::ValueRange args) { mlir::ArithBuilder ab(b, loc); diff --git a/xla/service/cpu/dot_op_emitter.h b/xla/service/cpu/dot_op_emitter.h index df59c5410264c..9d4f56179601f 100644 --- a/xla/service/cpu/dot_op_emitter.h +++ b/xla/service/cpu/dot_op_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/dot_op_emitter_internal.h b/xla/service/cpu/dot_op_emitter_internal.h index 714d3251d779a..dc53817f89b7a 100644 --- a/xla/service/cpu/dot_op_emitter_internal.h +++ b/xla/service/cpu/dot_op_emitter_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/elemental_ir_emitter.cc b/xla/service/cpu/elemental_ir_emitter.cc index 05b5a94daed07..2c5049ade90e2 100644 --- a/xla/service/cpu/elemental_ir_emitter.cc +++ b/xla/service/cpu/elemental_ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/math_ops.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -33,7 +34,7 @@ using xla::llvm_ir::IrArray; namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitAtan2( +absl::StatusOr CpuElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs, absl::string_view /*name*/) { std::string function_name; @@ -70,8 +71,8 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( return result; } -StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr CpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) { bool cast_result_to_fp16 = false; std::string function_name; switch (prim_type) { @@ -105,5 +106,32 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return result; } +absl::StatusOr CpuElementalIrEmitter::EmitErf( + PrimitiveType prim_type, llvm::Value* value) { + if (prim_type == F64) { + std::string function_name = "erf"; + // Create a function declaration. + llvm::Function* function = llvm::dyn_cast( + module() + ->getOrInsertFunction(function_name, value->getType(), + value->getType()) + .getCallee()); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. + llvm::Value* result = Call(function, value); + return result; + } + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); + if (type == b()->getFloatTy()) { + llvm::Value* x = FPCast(value, type); + auto* result = llvm_ir::EmitErfF32(b(), x); + return FPCast(result, value->getType()); + } + return Unimplemented("erf"); +} + } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/elemental_ir_emitter.h b/xla/service/cpu/elemental_ir_emitter.h index 8d2efb0fce4a5..3056047e0787c 100644 --- a/xla/service/cpu/elemental_ir_emitter.h +++ b/xla/service/cpu/elemental_ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,13 +36,15 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { ir_emitter_(ir_emitter) {} protected: - StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs, - absl::string_view name) override; - StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, llvm::Value* rhs, + absl::string_view name) override; + absl::StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) override; + absl::StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr> EmitThreadLocalCall( + absl::StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view name, bool is_reducer) override { return ir_emitter_->EmitThreadLocalCall(callee, parameters, name, diff --git a/xla/service/cpu/executable.proto b/xla/service/cpu/executable.proto index ada719977f4dd..2c48a51f4b043 100644 --- a/xla/service/cpu/executable.proto +++ b/xla/service/cpu/executable.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -syntax = "proto2"; +syntax = "proto3"; package xla.cpu; import "xla/service/cpu/xla_framework.proto"; import "xla/service/hlo.proto"; +import "xla/xla.proto"; message XlaRuntimeCpuExecutableProto { optional XlaRuntimeExecutableProto xla_runtime_executable = 1; optional XlaFrameworkMappingProto xla_framework_mapping = 2; } + +message CompilationResultProto { + HloModuleProtoWithConfig hlo_module = 1; + BufferAssignmentProto buffer_assignment = 2; + string entry_function_name = 3; + bytes obj_file = 4; +} diff --git a/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/xla/service/cpu/hlo_xla_runtime_pipeline.cc index caff8ee13e90f..331f5d933671a 100644 --- a/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project @@ -33,7 +32,6 @@ limitations under the License. #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project @@ -46,6 +44,8 @@ limitations under the License. #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/runtime/transforms/compiler.h" @@ -56,6 +56,11 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU + namespace xla { namespace cpu { namespace { @@ -81,51 +86,6 @@ mlir::bufferization::OneShotBufferizationOptions GetBufferizationOptions( return options; } -void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, - int32_t xla_cpu_sparse_cuda_threads) { - // Sparse GPU acceleration for sparsified code. - // Setting 0 threads means no acceleration (default). - // Setting 1 thread means cuSPARSE libgen. - // Otherwise direct CUDA codegen. - const bool gpu_codegen = xla_cpu_sparse_cuda_threads > 0; - const bool gpu_libgen = xla_cpu_sparse_cuda_threads == 1; - mlir::SparsificationOptions sparsification_options; - sparsification_options.enableRuntimeLibrary = false; - if (gpu_codegen && !gpu_libgen) { - sparsification_options.parallelizationStrategy = - mlir::SparseParallelizationStrategy::kDenseOuterLoop; - } - // Sparsification set up. - pm.addNestedPass(mlir::createLinalgGeneralizationPass()); - pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass()); - pm.addPass(mlir::createSparsificationAndBufferizationPass( - GetBufferizationOptions(new_deallocator), sparsification_options, - /*createSparseDeallocs=*/false, - /*enableRuntimeLibrary=*/false, - /*enableBufferInitialization=*/false, - /*vectorLength=*/0, - /*enableVLAVectorization=*/false, - /*enableSIMDIndex32=*/false, - /*enableGPULibgen=*/gpu_libgen)); - pm.addPass(mlir::createStorageSpecifierToLLVMPass()); - pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass( - mlir::bufferization::createFinalizingBufferizePass()); - // Sparse GPU acceleration lowers to GPU dialect. - if (gpu_codegen) { - pm.addPass( - mlir::createSparseGPUCodegenPass(xla_cpu_sparse_cuda_threads, false)); - pm.addNestedPass(mlir::createStripDebugInfoPass()); - pm.addNestedPass(mlir::createConvertSCFToCFPass()); - pm.addNestedPass( - mlir::createConvertGpuOpsToNVVMOps()); - } -} - -void AddSparsificationPassPipeline(mlir::OpPassManager& pm) { - AddSparsificationPasses(pm, false, /*xla_cpu_sparse_cuda_threads=*/0); -} - } // namespace // -------------------------------------------------------------------------- // @@ -150,22 +110,6 @@ static Status CreateHloXlaPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - // Some early sparse rewriting rules. - if (options.sparse_bufferization) { - pm.addNestedPass(createSparseCustomCallRewritingPass()); - // We wrap some CHLO unary operations with custom calls to preserve the - // sparsity information for those operations during the roundtrip. We now - // invoke the needed passes to lower such CHLO operations to HLO after we - // rewrite the custom calls back to such CHLO unary operations. - pm.addNestedPass( - mlir::mhlo::createLegalizeSparseOperationsPass( - /*legalizeToCustomCalls=*/false)); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createSparseRewritingPass()); - } - // Transform HLO operations to Linalg. pm.addNestedPass( mlir::mhlo::createLegalizeControlFlowPass()); @@ -210,7 +154,11 @@ static Status CreateHloXlaPipeline( // one-shot-bufferize generates unnecessary allocs for. The detensorize pass // replaces these linalg.generics with scalar ops. auto detensorize = mlir::createLinalgDetensorizePass(); - if (detensorize->initializeOptions("aggressive-mode=true").failed()) { + if (detensorize + ->initializeOptions( + "aggressive-mode=true", + [](const mlir::Twine&) { return mlir::failure(); }) + .failed()) { return tsl::errors::Internal("Failed to set up detensorize pass."); } pm.addNestedPass(std::move(detensorize)); @@ -221,13 +169,7 @@ static Status CreateHloXlaPipeline( // Always run canonicalizer (which does dead code removal) before // bufferizing anything. pm.addPass(mlir::createCanonicalizerPass()); - - if (options.sparse_bufferization) { - // Convert Sparse tensors. - AddSparsificationPasses(pm, false, options.xla_cpu_sparse_cuda_threads); - } else { - pm.addPass(mlir::hlo::createOneShotBufferizePass()); - } + pm.addPass(mlir::hlo::createOneShotBufferizePass()); pm.addNestedPass(createRewriteReallocToAllocPass()); pm.addNestedPass(mlir::createVectorizeCopyPass()); pm.addNestedPass(mlir::createNaiveCopyRemovalPass()); @@ -243,13 +185,13 @@ static Status CreateHloXlaPipeline( } pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - mlir::bufferization::BufferResultsToOutParamsOptions out_params_options; - out_params_options.filterFn = [](mlir::func::FuncOp* func) { + mlir::bufferization::BufferResultsToOutParamsOpts out_params_opts; + out_params_opts.filterFn = [](mlir::func::FuncOp* func) { // Only transform the entry point. return func->getSymName() == "main"; }; - pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass( - out_params_options)); + pm.addPass( + mlir::bufferization::createBufferResultsToOutParamsPass(out_params_opts)); pm.addNestedPass( mlir::bufferization::createPromoteBuffersToStackPass(nullptr)); @@ -317,10 +259,5 @@ static mlir::PassPipelineRegistration<> hlo_xla_runtime_pipeline( } }); -static mlir::PassPipelineRegistration<> sparsification_pipeline( - "hlo-xla-runtime-sparsification", - "Sparsification passes from HLO-XLA Runtime pipeline", - AddSparsificationPassPipeline); - } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/hlo_xla_runtime_pipeline.h b/xla/service/cpu/hlo_xla_runtime_pipeline.h index c2fa5f6f4b4d2..5b5a970d1352f 100644 --- a/xla/service/cpu/hlo_xla_runtime_pipeline.h +++ b/xla/service/cpu/hlo_xla_runtime_pipeline.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/in_process_collectives.cc b/xla/service/cpu/in_process_collectives.cc index fc13c6c387037..b307807f28739 100644 --- a/xla/service/cpu/in_process_collectives.cc +++ b/xla/service/cpu/in_process_collectives.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { @@ -93,7 +94,7 @@ template constexpr bool always_false_v = false; template -void Reduce(absl::Span acc, absl::Span const> inputs) { +void ReduceHelper(absl::Span acc, absl::Span inputs) { // TODO(penporn): make sure this gets vectorized. if constexpr (reduction_kind == ReductionKind::SUM) { for (size_t j = 0; j < inputs.size(); ++j) { @@ -124,6 +125,49 @@ void Reduce(absl::Span acc, absl::Span const> inputs) { } } +template +absl::Status ReduceScatter(ReductionKind reduction_kind, + absl::Span inputs, void* output, + int64_t num_elems) { + using T = typename primitive_util::PrimitiveTypeToNative::type; + T initial_value = GetInitialValue(reduction_kind); + + absl::Span out_chunk = + absl::MakeSpan(reinterpret_cast(output), num_elems); + for (int64_t i = 0; i < num_elems; ++i) { + out_chunk[i] = initial_value; + } + + absl::Span input_chunks( + reinterpret_cast(inputs.data()), inputs.size()); + switch (reduction_kind) { + case ReductionKind::SUM: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::PRODUCT: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); +} + class CpuAllReduceRendezvous : public Rendezvous { public: @@ -146,110 +190,86 @@ class CpuAllReduceRendezvous return nullptr; } + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; + void* reduce_output = + reinterpret_cast(me.destination_data) + chunk_offset; + + std::vector inputs; + inputs.reserve(world_size); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_data) + + chunk_offset); + } + switch (me.primitive_type) { case S8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case PRED: case U8: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case S64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case U64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F16: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F32: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case F64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C64: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; case C128: - TF_RETURN_IF_ERROR(DoAllReduce(me, start_elem, chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems)); break; default: return absl::UnimplementedError("Unexpected datatype"); } - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; + // All-gather the reduced chunks. for (const auto& p : participants_) { if (p->local_rank != me.local_rank) { - std::memcpy( - reinterpret_cast(p->destination_data) + chunk_offset, - reinterpret_cast(me.destination_data) + chunk_offset, - chunk_bytes); + std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, + reduce_output, chunk_bytes); } } return nullptr; } - - template - absl::Status DoAllReduce(const AllReduceParticipantData& me, - int64_t start_elem, int64_t num_elems) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - T initial_value = GetInitialValue(me.reduction_kind); - T* acc = reinterpret_cast(me.destination_data); - for (int64_t i = start_elem; i < start_elem + num_elems; ++i) { - acc[i] = initial_value; - } - - absl::Span out_chunk = absl::MakeSpan( - reinterpret_cast(me.destination_data) + start_elem, num_elems); - std::vector> inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(absl::Span( - reinterpret_cast(p->source_data) + start_elem, num_elems)); - } - switch (me.reduction_kind) { - case ReductionKind::SUM: - Reduce(out_chunk, inputs); - break; - case ReductionKind::PRODUCT: - Reduce(out_chunk, inputs); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - Reduce(out_chunk, inputs); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); - } }; struct CollectivePermuteParticipantData : ParticipantData { @@ -340,6 +360,147 @@ class CpuAllToAllRendezvous } }; +struct AllGatherParticipantData : ParticipantData { + AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + const void* source_buffer; + void* destination_buffer; + size_t chunk_size; + + std::string ToString() const override { + return absl::StrFormat( + "AllGatherParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_size); + } +}; + +class CpuAllGatherRendezvous + : public Rendezvous { + public: + explicit CpuAllGatherRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const AllGatherParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + char* out = static_cast(p.destination_buffer); + for (int i = 0; i < world_size; ++i, out += p.chunk_size) { + std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); + } + return nullptr; + } +}; + +struct ReduceScatterParticipantData : ParticipantData { + ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + ReductionKind reduction_kind; + PrimitiveType element_type; + const void* source_buffer; + void* destination_buffer; + size_t chunk_elems; + + std::string ToString() const override { + return absl::StrFormat( + "ReduceScatterParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_elems=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_elems); + } +}; + +class CpuReduceScatterRendezvous + : public Rendezvous { + public: + explicit CpuReduceScatterRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + CollectivesInterface* collectives_; + absl::StatusOr RunCollectiveOp( + const ReduceScatterParticipantData& me) override { + auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); + int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; + + std::vector inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_buffer) + + chunk_offset); + } + + switch (me.element_type) { + case S8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case S64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case U64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F16: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F32: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case F64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C64: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + case C128: + TF_RETURN_IF_ERROR(ReduceScatter( + me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); + break; + default: + return absl::UnimplementedError("Unexpected datatype"); + } + + return nullptr; + } +}; + } // namespace struct InProcessCollectivesState { @@ -349,6 +510,10 @@ struct InProcessCollectivesState { collective_permute_rendezvous_map; RefcountingHashMap all_to_all_rendezvous_map; + RefcountingHashMap + all_gather_rendezvous_map; + RefcountingHashMap + reduce_scatter_rendezvous_map; }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( @@ -429,6 +594,46 @@ absl::Status InProcessCollectivesCommunicator::AllToAll( .status(); } +absl::Status InProcessCollectivesCommunicator::AllGather( + const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + AllGatherParticipantData participant(key, rank_); + participant.chunk_size = chunk_bytes; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllGatherRendezvous::SubmitParticipant( + [&] { + return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + ReduceScatterParticipantData participant(key, rank_); + participant.element_type = element_type; + participant.reduction_kind = reduction_kind; + participant.chunk_elems = chunk_elems; + participant.source_buffer = input_buffer; + participant.destination_buffer = output_buffer; + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuReduceScatterRendezvous::SubmitParticipant( + [&] { + return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; diff --git a/xla/service/cpu/in_process_collectives.h b/xla/service/cpu/in_process_collectives.h index fb25fd3528d60..4551644585a6f 100644 --- a/xla/service/cpu/in_process_collectives.h +++ b/xla/service/cpu/in_process_collectives.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,6 +55,16 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { absl::Span output_buffers, absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + private: InProcessCollectivesState* state_; int rank_; diff --git a/xla/service/cpu/ir_emission_utils.cc b/xla/service/cpu/ir_emission_utils.cc index 6d9a62c14dec0..8f2ab1ce52652 100644 --- a/xla/service/cpu/ir_emission_utils.cc +++ b/xla/service/cpu/ir_emission_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/ir_emission_utils.h b/xla/service/cpu/ir_emission_utils.h index 74687c6fca0e4..f1eb47483d151 100644 --- a/xla/service/cpu/ir_emission_utils.h +++ b/xla/service/cpu/ir_emission_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/ir_emission_utils_test.cc b/xla/service/cpu/ir_emission_utils_test.cc index ffc50a805160c..6fb2ef0931403 100644 --- a/xla/service/cpu/ir_emission_utils_test.cc +++ b/xla/service/cpu/ir_emission_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 18523dec84411..221e6554647fc 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -164,7 +164,7 @@ void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) { } } -StatusOr IrEmitter::EmitComputation( +absl::StatusOr IrEmitter::EmitComputation( HloComputation* computation, absl::string_view function_name_prefix, bool is_top_level_computation, absl::Span instruction_order, @@ -1169,35 +1169,36 @@ Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) { return OkStatus(); } +// Data types supported by ReduceScatter and AllReduce. +static bool DataTypeIsSupportedByReduceScatter(PrimitiveType datatype) { + // TODO(cheshire): Fix duplication wrt. cpu_runtime + switch (datatype) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case C64: + case C128: + return true; + default: + return false; + } +} + Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) { CHECK_GE(crs->operand_count(), 1); PrimitiveType datatype = crs->operand(0)->shape().element_type(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); - bool is_datatype_supported = [&] { - // TODO(cheshire): Fix duplication wrt. cpu_runtime - switch (datatype) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: - case C64: - case C128: - return true; - default: - return false; - } - }(); - - if (!is_datatype_supported) { + if (!DataTypeIsSupportedByReduceScatter(datatype)) { return Unimplemented("AllReduce for datatype '%s' is not supported", primitive_util::LowercasePrimitiveTypeName(datatype)); } @@ -1285,7 +1286,59 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { } Status IrEmitter::HandleReduceScatter(HloInstruction* rs) { - return Unimplemented("ReduceScatter is not implemented on CPU."); + CHECK_EQ(rs->operand_count(), 1); + PrimitiveType datatype = rs->operand(0)->shape().element_type(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rs)); + + if (!DataTypeIsSupportedByReduceScatter(datatype)) { + return Unimplemented("ReduceScatter for datatype '%s' is not supported", + primitive_util::LowercasePrimitiveTypeName(datatype)); + } + + if (!MatchReductionComputation(rs->to_apply()).has_value()) { + return Unimplemented("ReduceScatter for computation '%s' is not supported", + rs->to_apply()->ToString()); + } + + std::string replica_groups = ReplicaGroupsToString(rs->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + Shape shape = rs->operand(0)->shape(); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, + assignment_.GetUniqueSlice(rs->operand(0), {})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + assignment_.GetUniqueSlice(rs, {})); + llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape); + llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape); + + bool use_global_device_ids = + Cast(rs)->use_global_device_ids(); + + EmitCallToFunc( + runtime::kReduceScatterSymbolName, + {/*run_options=*/GetExecutableRunOptionsArgument(), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + + /*channel_id_present=*/ + b_.getInt32(static_cast(rs->channel_id().has_value())), + /*use_global_device_ids=*/ + b_.getInt32(static_cast(use_global_device_ids)), + /*op_id=*/ + b_.getInt64(rs->channel_id().has_value() ? *rs->channel_id() + : rs->GetModule()->unique_id()), + /*reduction_kind=*/ + b_.getInt32( + static_cast(*MatchReductionComputation(rs->to_apply()))), + /*element_type=*/ + b_.getInt32(static_cast(datatype)), + /*shape=*/b_.getInt64(ShapeUtil::ElementsIn(rs->shape())), + /*input_buffer=*/input_buffer, + /*output_buffer=*/output_buffer}, + b_.getVoidTy()); + + return OkStatus(); } Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { @@ -1344,6 +1397,57 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { return OkStatus(); } +Status IrEmitter::HandleAllGather(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction)); + + std::string replica_groups = + ReplicaGroupsToString(instruction->replica_groups()); + int32_t replica_groups_size = replica_groups.size(); + llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups); + + std::vector input_buffer_ptrs; + std::vector output_buffer_ptrs; + + const HloInstruction* op = instruction->operand(0); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice in_slice, + assignment_.GetUniqueSlice(op, {})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(instruction, {})); + const Shape& operand_shape = op->shape(); + CHECK(op->shape().IsArray()) + << "Operand to all-gather must be arrays: " << instruction->ToString(); + llvm::Value* output_buffer = EmitBufferPointer(out_slice, operand_shape); + llvm::Value* input_buffer = GetEmittedValueFor(op); + int64_t buffer_size = in_slice.size(); + + bool use_global_device_ids = + Cast(instruction)->use_global_device_ids(); + + EmitCallToFunc( + runtime::kAllGatherSymbolName, + { + /*run_options=*/GetExecutableRunOptionsArgument(), + /*channel_id_present=*/ + b_.getInt32( + static_cast(instruction->channel_id().has_value())), + /*use_global_device_ids=*/ + b_.getInt32(static_cast(use_global_device_ids)), + /*op_id=*/ + b_.getInt64(instruction->channel_id().has_value() + ? *instruction->channel_id() + : instruction->GetModule()->unique_id()), + /*replica_groups_str=*/replica_groups_v, + /*replica_groups_str_size=*/b_.getInt32(replica_groups_size), + /*buffer_size=*/b_.getInt64(buffer_size), + /*source_buffer=*/input_buffer, + /*destination_buffer=*/output_buffer, + }, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_); + return OkStatus(); +} + Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); @@ -1622,7 +1726,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( return sharded_vector_type; } -StatusOr +absl::StatusOr IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, @@ -1722,7 +1826,7 @@ void IrEmitter::EmitShardedVectorStore( } } -StatusOr IrEmitter::EmitVectorizedReduce( +absl::StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, absl::Span dimensions, HloComputation* function, std::string* failure_reason) { @@ -2112,7 +2216,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { - return InternalErrorStrCat( + return InternalStrCat( "Encountered negative padding in IrEmitter on CPU. " "This should have been eliminated at the HLO level. ", pad->ToString()); @@ -2378,13 +2482,16 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { const HloInstruction* input = hlo->operand(0); const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back(); const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2; - TF_RET_CHECK(input->shape().element_type() == F32); + TF_RET_CHECK(input->shape().element_type() == F32) << hlo->ToString(); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( - hlo->shape().tuple_shapes(0).layout())); + hlo->shape().tuple_shapes(0).layout())) + << hlo->ToString(); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( - hlo->shape().tuple_shapes(1).layout())); + hlo->shape().tuple_shapes(1).layout())) + << hlo->ToString(); TF_RET_CHECK( - LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())); + LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())) + << hlo->ToString(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice, assignment_.GetUniqueSlice(hlo->operand(0), {})); @@ -2410,38 +2517,191 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) -Status IrEmitter::HandleOneDnnMatMul(HloInstruction* custom_call) { - auto lhs = custom_call->operand(0); - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - auto lhs_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, lhs_array); +Status IrEmitter::HandleOneDnnMatMulCalls(HloInstruction* custom_call, + std::string runtime_symbol_name) { + // We would like to emit LLVM IR for the following function call + // custom_call_target(void* result, void** args) + // args can be thought of an array of pointers allocated on the stack, + // i.e., alloca [nargs x ptr], as such + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnMatMulConfig + // args[3...]: ptrs to operands + // This allows us to pass variable number of operands to the + // custom_call_target function. + // + // Currently, we assume that neither operands nor results are packed into + // tuple(s). + + // First three arguments: nargs, ExecutableRunOptions, and + // OneDnnMatMulConfig. + const int nargs_offset = 3; + const int num_operands = custom_call->operand_count(); + const int nargs = nargs_offset + num_operands; + int arg_indx = 0; + + llvm::Type* i64_type = b_.getInt64Ty(); + llvm::Type* ptr_type = b_.getPtrTy(); + llvm::ArrayType* ptr_array_type = llvm::ArrayType::get(ptr_type, nargs); + llvm::Value* args_val = llvm::UndefValue::get(ptr_array_type); + + // Insert nargs. + llvm::Value* nargs_val = b_.getInt64(nargs); + llvm::Value* nargs_ptr = + llvm_ir::EmitAllocaAtFunctionEntry(i64_type, "nargs", &b_); + b_.CreateLifetimeStart(nargs_ptr, b_.getInt64(-1)); + b_.CreateStore(nargs_val, nargs_ptr); + args_val = b_.CreateInsertValue(args_val, nargs_ptr, arg_indx++); + + // Insert ExecutableRunOptions. + llvm::Value* run_opts_val = GetExecutableRunOptionsArgument(); + args_val = b_.CreateInsertValue(args_val, run_opts_val, arg_indx++); + + // Insert OneDnnMatMulConfig. - auto rhs = custom_call->operand(1); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - auto rhs_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, rhs_array); + auto typed_custom_call = Cast(custom_call); + auto backend_config = typed_custom_call->backend_config(); + OneDnnMatMulConfig matmul_config; + matmul_config.CopyFrom(backend_config->onednn_matmul_config()); + std::string str_config; + matmul_config.SerializeToString(&str_config); + llvm::Value* matmul_config_val = + b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)); + args_val = b_.CreateInsertValue(args_val, matmul_config_val, arg_indx++); + + // Insert operands. + std::vector operands_stack_alloca; + operands_stack_alloca.reserve(num_operands); + absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), + [this](HloInstruction* instr) { + llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); + return GetAllocaAndEmitMemrefInfo(b_, ir_array); + }); + for (int i = 0; i < num_operands; ++i) { + args_val = b_.CreateInsertValue(args_val, operands_stack_alloca[i].value, + arg_indx++); + } + TF_RET_CHECK(nargs == arg_indx) + << "Number of arguments don't equal the last argument index."; + + llvm::Value* args_ptr = + llvm_ir::EmitAllocaAtFunctionEntry(ptr_array_type, "matmul.args", &b_); + b_.CreateLifetimeStart(args_ptr, b_.getInt64(-1)); + b_.CreateStore(args_val, args_ptr); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); llvm_ir::IrArray result_array = GetIrArrayFor(custom_call); auto result_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, result_array); + EmitCallToFunc(std::move(runtime_symbol_name), + {result_stack_alloca.value, args_ptr}, b_.getVoidTy()); + + // Lifetime ends for all stack allocations. + b_.CreateLifetimeEnd(nargs_ptr, b_.getInt64(-1)); + for (int i = 0; i < num_operands; ++i) { + operands_stack_alloca[i].EmitLifetimeEnd(); + } + b_.CreateLifetimeEnd(args_ptr, b_.getInt64(-1)); + result_stack_alloca.EmitLifetimeEnd(); + + return OkStatus(); +} + +Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnLayerNormConfig + // args[3...]: ptrs to operands + + // First three arguments: nargs, ExecutableRunOptions, and + // OneDnnLayerNormConfig. + const int nargs_offset = 3; + const int num_operands = custom_call->operand_count(); + const int nargs = nargs_offset + num_operands; + int arg_indx = 0; + + llvm::Type* i64_type = b_.getInt64Ty(); + llvm::Type* ptr_type = b_.getPtrTy(); + llvm::ArrayType* ptr_array_type = llvm::ArrayType::get(ptr_type, nargs); + llvm::Value* args_val = llvm::UndefValue::get(ptr_array_type); + + // Insert nargs. + llvm::Value* nargs_val = b_.getInt64(nargs); + llvm::Value* nargs_ptr = + llvm_ir::EmitAllocaAtFunctionEntry(i64_type, "nargs", &b_); + b_.CreateLifetimeStart(nargs_ptr, b_.getInt64(-1)); + b_.CreateStore(nargs_val, nargs_ptr); + args_val = b_.CreateInsertValue(args_val, nargs_ptr, arg_indx++); + + // Insert ExecutableRunOptions. + llvm::Value* run_opts_val = GetExecutableRunOptionsArgument(); + args_val = b_.CreateInsertValue(args_val, run_opts_val, arg_indx++); + + // Insert OneDnnLayerNormConfig. auto typed_custom_call = Cast(custom_call); auto backend_config = typed_custom_call->backend_config(); - OneDnnMatMulConfig matmul_config; - matmul_config.CopyFrom(backend_config->onednn_matmul_config()); + OneDnnLayerNormConfig ln_config; + ln_config.CopyFrom(backend_config->onednn_layer_norm_config()); std::string str_config; - matmul_config.SerializeToString(&str_config); + ln_config.SerializeToString(&str_config); + llvm::Value* ln_config_val = + b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)); + args_val = b_.CreateInsertValue(args_val, ln_config_val, arg_indx++); + + // Insert operands. + std::vector operands_stack_alloca; + operands_stack_alloca.reserve(num_operands); + absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), + [this](HloInstruction* instr) { + llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); + return GetAllocaAndEmitMemrefInfo(b_, ir_array); + }); + for (int i = 0; i < num_operands; ++i) { + args_val = b_.CreateInsertValue(args_val, operands_stack_alloca[i].value, + arg_indx++); + } + + llvm::Value* args_ptr = + llvm_ir::EmitAllocaAtFunctionEntry(ptr_array_type, "layernorm.args", &b_); + b_.CreateLifetimeStart(args_ptr, b_.getInt64(-1)); + b_.CreateStore(args_val, args_ptr); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + llvm_ir::IrArray result_array = GetIrArrayFor(custom_call); + auto result_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, result_array); + + EmitCallToFunc(runtime::kOneDnnLayerNormSymbolName, + {result_stack_alloca.value, args_ptr}, b_.getVoidTy()); + + // Lifetime ends for all stack allocations. + b_.CreateLifetimeEnd(nargs_ptr, b_.getInt64(-1)); + for (int i = 0; i < num_operands; ++i) { + operands_stack_alloca[i].EmitLifetimeEnd(); + } + b_.CreateLifetimeEnd(args_ptr, b_.getInt64(-1)); + result_stack_alloca.EmitLifetimeEnd(); + + return OkStatus(); +} + +Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { + auto input = custom_call->operand(0); + llvm_ir::IrArray input_array(GetIrArrayFor(input)); + auto input_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, input_array); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + llvm_ir::IrArray result_array = GetIrArrayFor(custom_call); + auto result_stack_alloca = GetAllocaAndEmitMemrefInfo(b_, result_array); - EmitCallToFunc(runtime::kOneDnnMatMulSymbolName, + EmitCallToFunc(runtime::kOneDnnSoftmaxSymbolName, { GetExecutableRunOptionsArgument(), - lhs_stack_alloca.value, - rhs_stack_alloca.value, + input_stack_alloca.value, result_stack_alloca.value, - b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)), }, b_.getVoidTy()); - lhs_stack_alloca.EmitLifetimeEnd(); - rhs_stack_alloca.EmitLifetimeEnd(); + input_stack_alloca.EmitLifetimeEnd(); result_stack_alloca.EmitLifetimeEnd(); return OkStatus(); @@ -2460,7 +2720,18 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) if (custom_call->custom_call_target() == "__onednn$matmul") { - return HandleOneDnnMatMul(custom_call); + return HandleOneDnnMatMulCalls(custom_call, + runtime::kOneDnnMatMulSymbolName); + } + if (custom_call->custom_call_target() == "__onednn$softmax") { + return HandleOneDnnSoftmax(custom_call); + } + if (custom_call->custom_call_target() == "__onednn$layernorm") { + return HandleOneDnnLayerNorm(custom_call); + } + if (custom_call->custom_call_target() == "__onednn$matmul_reorder") { + return HandleOneDnnMatMulCalls(custom_call, + runtime::kOneDnnMatMulReorderSymbolName); } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 absl::Span operands(custom_call->operands()); @@ -2528,7 +2799,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { break; } default: - return InternalError( + return Internal( "Unknown custom-call API version enum value: %d (%s)", typed_custom_call->api_version(), CustomCallApiVersion_Name(typed_custom_call->api_version())); @@ -2550,13 +2821,13 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { [this, &xla_while](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { auto check = [this](const HloInstruction* a, const HloInstruction* b, - const ShapeIndex& index) { + const ShapeIndex& index) -> absl::Status { const BufferAllocation::Slice slice_a = assignment_.GetUniqueSlice(a, index).value(); const BufferAllocation::Slice slice_b = assignment_.GetUniqueSlice(b, index).value(); if (slice_a != slice_b) { - return InternalError( + return Internal( "instruction %s %s does not share slice with " "instruction %s %s", a->ToString(), slice_a.ToString(), b->ToString(), @@ -2628,7 +2899,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return OkStatus(); } -StatusOr IrEmitter::EmitFastConcatenate( +absl::StatusOr IrEmitter::EmitFastConcatenate( HloInstruction* concatenate, absl::Span operands, std::string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 3a194d054cb5f..93516138be50e 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -112,7 +112,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // // If 'allow_reassociation' is true, the fast-math reassociation flag will // be enabled in the function's body. This is used when emitting reducers. - StatusOr EmitComputation( + absl::StatusOr EmitComputation( HloComputation* computation, absl::string_view function_name_prefix, bool is_top_level_computation, absl::Span instruction_order, @@ -134,6 +134,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // special in some way are handled explicitly in HandleFoo methods. Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* instruction) override; Status HandleAllToAll(HloInstruction* instruction) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant) override; @@ -194,7 +195,10 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - Status HandleOneDnnMatMul(HloInstruction* hlo); + Status HandleOneDnnMatMulCalls(HloInstruction* hlo, + std::string runtime_symbol_name); + Status HandleOneDnnSoftmax(HloInstruction* hlo); + Status HandleOneDnnLayerNorm(HloInstruction* hlo); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const std::string& function_name); @@ -362,12 +366,10 @@ class IrEmitter : public DfsHloVisitorWithDefault, // concepts that generalize over other vectorizable operations. We should // consider pulling out these abstractions into a VectorizingIrEmitter or // something similar. - StatusOr EmitVectorizedReduce(HloInstruction* reduce, - HloInstruction* arg, - HloInstruction* init_value, - absl::Span dimensions, - HloComputation* function, - std::string* failure_reason); + absl::StatusOr EmitVectorizedReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + absl::Span dimensions, HloComputation* function, + std::string* failure_reason); // We'd like to keep one or two one cache-line's worth of data in registers // without generating IR with illegal (e.g. excessively large or @@ -413,7 +415,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits the inner loop nest that runs the reduction. Helper function for // EmitVectorizedReduce. - StatusOr EmitInnerLoopForVectorizedReduction( + absl::StatusOr EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, @@ -423,9 +425,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Tries to emit a fast concatenate operation using memcpy. Returns true if // successful, and false on failure. On failure, sets "failure_reason" to a // string describing why it could not emit a fast concatenate. - StatusOr EmitFastConcatenate(HloInstruction* concatenate, - absl::Span operands, - std::string* failure_reason); + absl::StatusOr EmitFastConcatenate( + HloInstruction* concatenate, absl::Span operands, + std::string* failure_reason); // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". @@ -660,7 +662,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, struct LiteralPtrEqualityFunctor { bool operator()(const Literal* lhs, const Literal* rhs) const { - return *lhs == *rhs; + return *lhs == *rhs && lhs->shape().layout() == rhs->shape().layout(); } }; diff --git a/xla/service/cpu/ir_function.cc b/xla/service/cpu/ir_function.cc index 09ab5ed239b18..69961501a37b8 100644 --- a/xla/service/cpu/ir_function.cc +++ b/xla/service/cpu/ir_function.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/ir_function.h b/xla/service/cpu/ir_function.h index c1841c13dfa8f..47034675bdf8e 100644 --- a/xla/service/cpu/ir_function.h +++ b/xla/service/cpu/ir_function.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc index f9c0d0ae622b5..f3f0c8d816035 100644 --- a/xla/service/cpu/llvm_ir_runtime.cc +++ b/xla/service/cpu/llvm_ir_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -156,7 +156,7 @@ void RewriteCalls( llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input, int32_t /*vector_width*/) { - return llvm_ir::EmitFastTanh(b, input); + return llvm_ir::EmitFastTanh(b, input, /*with_fma=*/true); } llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, diff --git a/xla/service/cpu/llvm_ir_runtime.h b/xla/service/cpu/llvm_ir_runtime.h index a45b83f7937d2..8e5c0410c348d 100644 --- a/xla/service/cpu/llvm_ir_runtime.h +++ b/xla/service/cpu/llvm_ir_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/mlir_emitter.cc b/xla/service/cpu/mlir_emitter.cc index fb805157b54c6..8d3a28815ee01 100644 --- a/xla/service/cpu/mlir_emitter.cc +++ b/xla/service/cpu/mlir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/mlir_emitter.h b/xla/service/cpu/mlir_emitter.h index 3787e2d753106..c7a8480e4e9d7 100644 --- a/xla/service/cpu/mlir_emitter.h +++ b/xla/service/cpu/mlir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/onednn_layer_norm.cc b/xla/service/cpu/onednn_layer_norm.cc new file mode 100644 index 0000000000000..d2109a1bc2f95 --- /dev/null +++ b/xla/service/cpu/onednn_layer_norm.cc @@ -0,0 +1,107 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/cpu/onednn_layer_norm.h" + +#include +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "dnnl.hpp" +#include "absl/base/dynamic_annotations.h" +#include "xla/executable_run_options.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/runtime_lightweight_check.h" +#include "xla/tsl/util/onednn_threadpool.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace xla { +namespace cpu { +namespace { +using dnnl::engine; +using dnnl::layer_normalization_forward; +using dnnl::memory; +using dnnl::normalization_flags; +using dnnl::prop_kind; +using dnnl::stream; +} // namespace + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnLayerNorm( + void* result, void** args) { + // args[0]: ptr to nargs. We don't use nargs here. + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnLayerNormConfig + // args[3...]: ptrs to operands + int arg_indx = 1; + const xla::ExecutableRunOptions* run_options = + static_cast(args[arg_indx++]); + XLA_LIGHTWEIGHT_CHECK(run_options != nullptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); + tsl::OneDnnThreadPool thread_pool( + run_options->intra_op_thread_pool()->getPool(), false); + engine cpu_engine(engine::kind::cpu, 0); +#ifndef ENABLE_ONEDNN_OPENMP + auto onednn_stream = + stream(dnnl::threadpool_interop::make_stream(cpu_engine, &thread_pool)); +#else + auto onednn_stream = stream(cpu_engine); +#endif // ENABLE_ONEDNN_OPENMP + std::string config_str(static_cast(args[arg_indx++])); + OneDnnLayerNormConfig ln_config; + ln_config.ParseFromString(config_str); + + MemrefInfo layer_minfo(args[arg_indx++]); + MemrefInfo gamma_minfo(args[arg_indx++]); + MemrefInfo beta_minfo(args[arg_indx++]); + MemrefInfo result_minfo(result); + + auto src_md = layer_minfo.GetOneDnnMemDesc(); + auto dst_md = result_minfo.GetOneDnnMemDesc(); + auto scaleshift_md = beta_minfo.GetOneDnnMemDesc(); + + auto src_mem = memory(src_md, cpu_engine, layer_minfo.Data()); + auto dst_mem = memory(dst_md, cpu_engine, result_minfo.Data()); + auto scale_mem = memory(scaleshift_md, cpu_engine, gamma_minfo.Data()); + auto shift_mem = memory(scaleshift_md, cpu_engine, beta_minfo.Data()); + + // TODO(intel-tf): Move epsilon to OneDnnLayerNormConfig. + float epsilon; + *(reinterpret_cast(&epsilon)) = ln_config.epsilon_typecast(); + + auto lnorm_pd = layer_normalization_forward::primitive_desc( + cpu_engine, prop_kind::forward_inference, src_md, dst_md, epsilon, + normalization_flags::use_scale | normalization_flags::use_shift); + + auto lnorm_prim = layer_normalization_forward(lnorm_pd); + + std::unordered_map ln_args; + ln_args.insert({DNNL_ARG_SRC, src_mem}); + ln_args.insert({DNNL_ARG_SCALE, scale_mem}); + ln_args.insert({DNNL_ARG_SHIFT, shift_mem}); + ln_args.insert({DNNL_ARG_DST, dst_mem}); + + lnorm_prim.execute(onednn_stream, ln_args); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/onednn_layer_norm.h b/xla/service/cpu/onednn_layer_norm.h new file mode 100644 index 0000000000000..e3f1634eb7bba --- /dev/null +++ b/xla/service/cpu/onednn_layer_norm.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ONEDNN_LAYER_NORM_H_ +#define XLA_SERVICE_CPU_ONEDNN_LAYER_NORM_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +namespace xla { +namespace cpu { + +extern "C" { +extern void __xla_cpu_runtime_OneDnnLayerNorm(void* result, void** args); +} // extern "C" + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_LAYER_NORM_H_ diff --git a/xla/service/cpu/onednn_matmul.cc b/xla/service/cpu/onednn_matmul.cc index acb2143ed244c..4c01c732a96da 100644 --- a/xla/service/cpu/onednn_matmul.cc +++ b/xla/service/cpu/onednn_matmul.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include #include +#include #include #define EIGEN_USE_THREADS @@ -31,7 +33,10 @@ limitations under the License. #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" -#include "tsl/util/onednn_threadpool.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/util/onednn_threadpool.h" +#include "tsl/platform/logging.h" namespace xla { namespace cpu { @@ -40,57 +45,338 @@ using dnnl::engine; using dnnl::matmul; using dnnl::memory; using dnnl::stream; + +dnnl::memory::desc Transpose(const dnnl::memory::desc& md) { + int64_t ndims = md.get_ndims(); + // Do not transpose 1D + if (ndims == 1) { + return md; + } + + std::vector permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[ndims - 1], permutation[ndims - 2]); + return md.permute_axes(permutation); +} + +dnnl::memory::desc ShapeToMemDesc(const Shape& shape, bool transpose = false) { + auto dimensions = shape.dimensions(); + if (dimensions.size() == 0) { + return dnnl::memory::desc{}; + } + + auto dims = dnnl::memory::dims(dimensions.begin(), dimensions.end()); + + dnnl::memory::dims strides(dims.size()); + dnnl::memory::dim stride = 1; + for (auto i : shape.layout().minor_to_major()) { + strides.at(i) = stride; + stride *= dims.at(i); + } + + auto dt = ToOneDnnDataType(static_cast(shape.element_type())); + + return transpose ? Transpose(dnnl::memory::desc(dims, dt, strides)) + : dnnl::memory::desc(dims, dt, strides); +} + +dnnl::memory::desc OneDnnMatMulOptWeightsDesc( + const dnnl::engine& engine, const dnnl::memory::desc& input_md, + const dnnl::memory::desc& weights_md, const dnnl::memory::desc& bias_md, + const dnnl::memory::desc& output_md) { + auto weights_any_md = + memory::desc(weights_md.get_dims(), weights_md.get_data_type(), + dnnl::memory::format_tag::any); + + auto matmul_pd = matmul::primitive_desc(engine, input_md, weights_any_md, + bias_md, output_md); + + return matmul_pd.weights_desc(); +} + +dnnl::memory::desc OneDnnMatMulOptWeightsDesc( + const dnnl::engine& engine, const Shape& input_shape, + const Shape& weights_shape, const Shape& bias_shape, + const Shape& output_shape, const OneDnnMatMulConfig* matmul_config) { + auto input_md = ShapeToMemDesc(input_shape, matmul_config->transpose_a()); + auto weights_md = ShapeToMemDesc(weights_shape, matmul_config->transpose_b()); + auto bias_md = + absl::c_count(matmul_config->fused_ops(), OneDnnMatMulConfig::BIAS) > 0 + ? ShapeToMemDesc(bias_shape) + : dnnl::memory::desc{}; + auto output_md = ShapeToMemDesc(output_shape); + + // extend bias rank to match result rank + auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (!bias_md.is_zero() && missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + + return OneDnnMatMulOptWeightsDesc(engine, input_md, weights_md, bias_md, + output_md); +} + +Shape MemDescToXlaShape(const dnnl::memory::desc& md) { + auto dtype = md.get_data_type(); + auto element_size = dnnl::memory::data_type_size(dtype); + int64_t bytes_num = md.get_size(); + XLA_LIGHTWEIGHT_CHECK(bytes_num % element_size == 0); + int64_t elements_num = static_cast(bytes_num / element_size); + return ShapeUtil::MakeShape(ToXlaPrimitiveType(dtype), {elements_num}); +} + +std::unique_ptr CreateOneDnnThreadPool( + const xla::ExecutableRunOptions* run_options) { +#ifndef ENABLE_ONEDNN_OPENMP + if (run_options != nullptr && + run_options->intra_op_thread_pool() != nullptr) { + return std::make_unique( + run_options->intra_op_thread_pool()->getPool(), false); + } else { + return nullptr; + } +#else + return nullptr; +#endif // ENABLE_ONEDNN_OPENMP +} + +dnnl::stream MakeOneDnnStream( + const dnnl::engine& cpu_engine, + dnnl::threadpool_interop::threadpool_iface* thread_pool) { + if (thread_pool != nullptr) { + return dnnl::threadpool_interop::make_stream(cpu_engine, thread_pool); + } else { + return dnnl::stream(cpu_engine); + } +} + } // namespace +Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, + const Shape& weights_shape, + const Shape& bias_shape, + const Shape& output_shape, + const OneDnnMatMulConfig* matmul_config) { + engine cpu_engine(engine::kind::cpu, 0); + auto optimized_weights_md = + OneDnnMatMulOptWeightsDesc(cpu_engine, input_shape, weights_shape, + bias_shape, output_shape, matmul_config); + return MemDescToXlaShape(optimized_weights_md); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( - const void* run_options_ptr, void* lhs, void* rhs, void* result, - void* config) { + void* result, void** args) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnMatMulConfig + // args[3...]: ptrs to operands + int arg_indx = 0; + const int64_t num_args = *(static_cast(args[arg_indx++])); + const xla::ExecutableRunOptions* run_options = - static_cast(run_options_ptr); + static_cast(args[arg_indx++]); XLA_LIGHTWEIGHT_CHECK(run_options != nullptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tsl::OneDnnThreadPool thread_pool( - run_options->intra_op_thread_pool()->getPool(), false); - engine cpu_engine(engine::kind::cpu, 0); -#ifndef ENABLE_ONEDNN_OPENMP - auto onednn_stream = - stream(dnnl::threadpool_interop::make_stream(cpu_engine, &thread_pool)); -#else - auto onednn_stream = stream(cpu_engine); -#endif // ENABLE_ONEDNN_OPENMP - MemrefInfo lhs_minfo(lhs); - MemrefInfo rhs_minfo(rhs); - MemrefInfo result_minfo(result); + auto thread_pool = CreateOneDnnThreadPool(run_options); + engine cpu_engine(engine::kind::cpu, 0); + auto onednn_stream = MakeOneDnnStream(cpu_engine, thread_pool.get()); - std::string config_str(static_cast(config)); + std::string config_str(static_cast(args[arg_indx++])); OneDnnMatMulConfig matmul_config; matmul_config.ParseFromString(config_str); - // Currently, no fusion is supported. - XLA_LIGHTWEIGHT_CHECK(matmul_config.fused_ops().empty()); + MemrefInfo lhs_minfo(args[arg_indx++]); + MemrefInfo rhs_minfo(args[arg_indx++]); + MemrefInfo result_minfo(result); + + auto lhs_md = lhs_minfo.GetOneDnnMemDesc(); + auto rhs_md = rhs_minfo.GetOneDnnMemDesc(); + auto bias_md = memory::desc(); + auto result_md = result_minfo.GetOneDnnMemDesc(); + + // Update dims and strides for transposed inputs. + if (matmul_config.transpose_a()) { + lhs_md = Transpose(lhs_md); + } + + if (matmul_config.transpose_b()) { + rhs_md = Transpose(rhs_md); + } + auto bias_mem = memory(nullptr); + std::vector> postop_args; + + // Currently, GELU/ReLU only fusion is supported. + dnnl::post_ops post_ops; + for (auto& fused_op : matmul_config.fused_ops()) { + switch (fused_op) { + case OneDnnMatMulConfig::RELU: + post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f); + break; + case OneDnnMatMulConfig::TANH: + post_ops.append_eltwise(dnnl::algorithm::eltwise_tanh, 0.f, 0.f); + break; + case OneDnnMatMulConfig::GELU_TANH: + post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); + break; + case OneDnnMatMulConfig::GELU_ERF: + post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + break; + case OneDnnMatMulConfig::BIAS: { + MemrefInfo bias_minfo(args[arg_indx++]); + bias_md = bias_minfo.GetOneDnnMemDesc(); + + // Extend bias rank to match result rank. + auto missed_rank = result_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + bias_mem = memory(bias_md, cpu_engine, bias_minfo.Data()); + } break; + case OneDnnMatMulConfig::BINARY_ADD: { + MemrefInfo binary_minfo(args[arg_indx++]); + auto binary_md = binary_minfo.GetOneDnnMemDesc(); + auto arg_idx = + DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; + post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); + postop_args.emplace_back( + arg_idx, dnnl::memory(binary_md, cpu_engine, binary_minfo.Data())); + } break; + case OneDnnMatMulConfig::LINEAR: { + float const_float; + *(reinterpret_cast(&const_float)) = + matmul_config.alpha_typecast(); + post_ops.append_eltwise(dnnl::algorithm::eltwise_linear, const_float, + 0.f); + } break; + default: + LOG(FATAL) << __FILE__ << ":" << __LINE__ + << " Attempt to call OneDNN MatMul runtime library with " + "unsupported post op." + << std::endl; + } + } - auto src_md = lhs_minfo.GetOneDnnMemDesc(); - auto weights_md = rhs_minfo.GetOneDnnMemDesc(); - auto dst_md = result_minfo.GetOneDnnMemDesc(); + XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); - auto src_mem = memory(src_md, cpu_engine, lhs_minfo.Data()); - auto weights_mem = memory(weights_md, cpu_engine, rhs_minfo.Data()); - auto dst_mem = memory(dst_md, cpu_engine, result_minfo.Data()); + dnnl::primitive_attr attrs; + if (post_ops.len() > 0) { + attrs.set_post_ops(post_ops); + } - auto matmul_pd = - matmul::primitive_desc(cpu_engine, src_md, weights_md, dst_md); + bool weights_packed = rhs_md.get_ndims() == 1 && + rhs_md.get_dims().front() != lhs_md.get_dims().back(); + if (weights_packed) { + // expected 2D buffer with last dim of input and last dim of output + auto rhs_any_md = + memory::desc({lhs_md.get_dims().back(), result_md.get_dims().back()}, + rhs_md.get_data_type(), memory::format_tag::any); + + rhs_md = OneDnnMatMulOptWeightsDesc(cpu_engine, lhs_md, rhs_any_md, bias_md, + result_md); + } + + auto lhs_mem = memory(lhs_md, cpu_engine, lhs_minfo.Data()); + auto rhs_mem = memory(rhs_md, cpu_engine, rhs_minfo.Data()); + auto result_mem = memory(result_md, cpu_engine, result_minfo.Data()); + + auto matmul_pd = matmul::primitive_desc(cpu_engine, lhs_md, rhs_md, bias_md, + result_md, attrs); + + if (std::strstr(matmul_pd.impl_info_str(), "ref") != nullptr) { + LOG(WARNING) << "[Perf]: MatMul reference implementation being executed"; + } auto matmul_prim = matmul(matmul_pd); - std::unordered_map matmul_args; - matmul_args.insert({DNNL_ARG_SRC, src_mem}); - matmul_args.insert({DNNL_ARG_WEIGHTS, weights_mem}); - matmul_args.insert({DNNL_ARG_DST, dst_mem}); + std::unordered_map matmul_args{{DNNL_ARG_SRC, lhs_mem}, + {DNNL_ARG_WEIGHTS, rhs_mem}, + {DNNL_ARG_BIAS, bias_mem}, + {DNNL_ARG_DST, result_mem}}; + + matmul_args.insert(postop_args.begin(), postop_args.end()); matmul_prim.execute(onednn_stream, matmul_args); } +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMulReorder( + void* result, void** args) { + // args[0]: ptr to nargs + // args[1]: ptr to ExecutableRunOptions + // args[2]: ptr to OneDnnMatMulConfig + // args[3...]: ptrs to operands + int arg_indx = 0; + const int64_t num_args = *(static_cast(args[arg_indx++])); + + const xla::ExecutableRunOptions* run_options = + static_cast(args[arg_indx++]); + + auto thread_pool = CreateOneDnnThreadPool(run_options); + engine cpu_engine(engine::kind::cpu, 0); + auto onednn_stream = MakeOneDnnStream(cpu_engine, thread_pool.get()); + + std::string config_str(static_cast(args[arg_indx++])); + OneDnnMatMulConfig matmul_config; + matmul_config.ParseFromString(config_str); + + MemrefInfo input_minfo(args[arg_indx++]); + MemrefInfo weight_minfo(args[arg_indx++]); + MemrefInfo output_minfo(args[arg_indx++]); + MemrefInfo result_minfo(result); + + auto input_md = input_minfo.GetOneDnnMemDesc(); + auto weight_md = weight_minfo.GetOneDnnMemDesc(); + auto output_md = output_minfo.GetOneDnnMemDesc(); + + auto bias_md = dnnl::memory::desc{}; + if (absl::c_count(matmul_config.fused_ops(), OneDnnMatMulConfig::BIAS) > 0) { + MemrefInfo bias_minfo(args[arg_indx++]); + bias_md = bias_minfo.GetOneDnnMemDesc(); + } + + XLA_LIGHTWEIGHT_CHECK(num_args >= arg_indx); + + // Update dims and strides for transposed inputs. + bool transpose_a = matmul_config.transpose_a(); + if (transpose_a) { + input_md = Transpose(input_md); + } + bool transpose_b = matmul_config.transpose_b(); + if (transpose_b) { + weight_md = Transpose(weight_md); + } + + // extend bias rank to match result rank + if (!bias_md.is_zero()) { + auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + } + + auto result_md = OneDnnMatMulOptWeightsDesc(cpu_engine, input_md, weight_md, + bias_md, output_md); + + XLA_LIGHTWEIGHT_CHECK(result_minfo.GetOneDnnMemDesc().get_size() == + result_md.get_size()); + + auto weight_mem = dnnl::memory{weight_md, cpu_engine, weight_minfo.Data()}; + auto result_mem = dnnl::memory{result_md, cpu_engine, result_minfo.Data()}; + + dnnl::reorder rdr{weight_mem, result_mem}; + rdr.execute(onednn_stream, weight_mem, result_mem); + onednn_stream.wait(); +} + } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/onednn_matmul.h b/xla/service/cpu/onednn_matmul.h index ac43d26a2f05f..6647eee2621a9 100644 --- a/xla/service/cpu/onednn_matmul.h +++ b/xla/service/cpu/onednn_matmul.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,23 +17,21 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_MATMUL_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/shape.h" + namespace xla { namespace cpu { +Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, + const Shape& weights_shape, + const Shape& bias_shape, + const Shape& output_shape, + const OneDnnMatMulConfig* matmul_config); + extern "C" { -// TODO(intel-tf): Change the function signature as -// void onednn_matmul(void* result, void** args) -// where -// args[0]: num_args (>=3, including itself) -// args[1]: ExecutableRunOption -// args[2]: OneDnnMatMulConfig -// args[3...]: Actual Operands -// so that it can take variable number of arguments. -// -// For now, we are using a fixed number of arguments. -extern void __xla_cpu_runtime_OneDnnMatMul(const void* run_options_ptr, - void* lhs, void* rhs, void* result, - void* config); +extern void __xla_cpu_runtime_OneDnnMatMul(void* result, void** args); +extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args); } // extern "C" } // namespace cpu diff --git a/xla/service/cpu/onednn_matmul_rewriter.cc b/xla/service/cpu/onednn_matmul_rewriter.cc new file mode 100644 index 0000000000000..c9b9e9b4a04a4 --- /dev/null +++ b/xla/service/cpu/onednn_matmul_rewriter.cc @@ -0,0 +1,865 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#define EIGEN_USE_THREADS + +#include "xla/service/cpu/onednn_matmul_rewriter.h" + +#include "xla/executable_run_options.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_matmul.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/pattern_matcher.h" +#include "xla/status_macros.h" +#include "xla/tsl/util/onednn_threadpool.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep + +namespace xla { +namespace cpu { + +namespace { +namespace m = match; + +inline Status ValidateDotDimensionNumbers( + const DotDimensionNumbers& dim_numbers) { + // Checks some invariants that do not hold in general, but DotDecomposer + // should have established for us. + TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); + std::vector batch_dim_numbers( + dim_numbers.lhs_batch_dimensions_size()); + absl::c_iota(batch_dim_numbers, 0); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); + return OkStatus(); +} + +template +auto ElementwiseSafeIntermediate(HloInstruction** instr, Pattern pattern) { + return m::AnyOf(m::Broadcast(instr, pattern.WithOneUser()), + m::Slice(instr, pattern.WithOneUser()), + m::Bitcast(instr, pattern.WithOneUser()), + m::Reshape(instr, pattern.WithOneUser()), + pattern); +} + +inline auto OneDnnMatmulInstr(HloInstruction** instr) { + return m::CustomCall(instr, {"__onednn$matmul"}); +} + +inline auto ConvertBF16ToF32(HloInstruction** instr) { + return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16)) + .WithElementType(PrimitiveType::F32); +} + +inline auto BcastConstScalar(HloInstruction** instr, double value) { + return m::Broadcast(instr, m::ConstantScalar(value)); +} + +inline auto BcastConstScalar(double value) { + return BcastConstScalar(nullptr, value); +} + +auto ConstScalarNear(double value) { + return m::ConstantScalar().WithPredicate( + [expected = value](const HloInstruction* instr) { + // Not a very robust floating-point comparison, but good enough for our + // purposes. + std::optional actual = + static_cast(instr) + ->literal() + .GetAsDouble({}); + if (!actual.has_value()) return false; + double epsilon; + switch (instr->shape().element_type()) { + case F16: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case BF16: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case F32: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + case F64: + epsilon = 128 * std::numeric_limits::epsilon(); + break; + default: + return false; + } + return abs(*actual - expected) < (abs(*actual + expected) * epsilon); + }); +} + +bool IsScalar(const HloInstruction* instr) { + return ShapeUtil::IsEffectiveScalar(instr->shape()); +} + +std::optional GetConstantValueAsFloat32(const HloInstruction* inst) { + if (!IsScalar(inst)) { + return std::nullopt; + } + switch (inst->shape().element_type()) { + case F16: + return inst->literal().GetFirstElement(); + case BF16: + return inst->literal().GetFirstElement(); + case F32: + return inst->literal().GetFirstElement(); + default: + return std::nullopt; + } +} + +inline auto BcastConstScalarNear(double value) { + return m::Broadcast(ConstScalarNear(value)); +} + +auto GELUActivation(HloInstruction* instr, HloInstruction** src) { + // Attempt to match GELU_TANH activation + // (https://arxiv.org/abs/1606.08415), where: + // gelu_tanh(x) = x * cdf(x) + // cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)) + HloInstruction* errf; + return Match(instr, m::MultiplyAnyOrder( + m::Op(src), + m::MultiplyAnyOrder( + BcastConstScalar(0.5), + m::AddAnyOrder(BcastConstScalar(1.0), + m::Op(&errf).WithOneUser())))) && + Match(errf, + m::Tanh(m::MultiplyAnyOrder( + BcastConstScalarNear(sqrt(M_2_PI)), + m::AddAnyOrder( + m::Op().Is(*src), + m::MultiplyAnyOrder( + BcastConstScalarNear(0.044715), + m::MultiplyAnyOrder( + m::Op().Is(*src), + m::MultiplyAnyOrder(m::Op().Is(*src), + m::Op().Is(*src)) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()); +} + +// OneDNN matmul can fuse add operation with automatic broadcasting along the +// addend's dimensions that are 1s. When compatible, Broadcast can be replaced +// by Bitcast, which is much cheaper. Compute new shape for the Bitcast. +StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, + const Shape& dot_shape) { + if (broadcast_instr->opcode() != HloOpcode::kBroadcast) { + return absl::InvalidArgumentError( + "Hlo instruction is not a Broadcast insruction."); + } + auto bcast = Cast(broadcast_instr); + Shape new_shape = bcast->shape(); + // Broadcast instruction has "dimensions" parameter along which its input's + // dimensions should not change. For example, + // dot = f32[3,4,5,6] dot(...) + // arg = f32[3,6]{1,0} parameter(0) + // broad = f32[3,4,5,6]{3,2,1,0} broadcast(arg), dimensions={0,3} + // add = f32[3,4,5,6]{3,2,1,0} add(dot, arg) + // can be replaced with the following + // arg = f32[3,6]{1,0} parameter(0) + // bitcast = f32[3,1,1,6]{3,2,1,0} bitcast(arg) + // fused = f32[3,4,5,6]{3,2,1,0} custom-call((..., bitcast) + auto kept_dimensions = bcast->dimensions(); + for (int i = 0; i < new_shape.rank(); i++) { + if (!absl::c_linear_search(kept_dimensions, i)) { + new_shape.set_dimensions(i, 1); + } + } + + // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be + // deleted from the new_shape. + int64_t rank_difference = new_shape.rank() - dot_shape.rank(); + auto new_dims = new_shape.dimensions(); + std::vector dims_to_delete; + for (int i = 0; i < rank_difference; ++i) { + if (new_dims[i] == 1) { + dims_to_delete.push_back(i); + } + } + new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape); + + // New shape for bias should satisfy the condition: + // rank(new_shape) <= rank(dot). + if (new_shape.rank() > dot_shape.rank()) { + return absl::CancelledError( + "Bias shape could not be adjusted for a fusion."); + } + + return new_shape; +}; + +inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) { + // Check if the operand's shape is compatible with matmul for fusion. + // An operand is fusable if + // 1. rank(operand) <= rank(dot) and + // 2. Starting from the last dim in backward direction, the dimension + // size of operand is either 1 or same to dot. + auto operand_dims = operand->shape().dimensions(); + auto dot_dims = dot->shape().dimensions(); + if (operand_dims.size() > dot_dims.size()) return false; + int operand_idx = operand_dims.size() - 1; + int dot_idx = dot_dims.size() - 1; + for (; operand_idx >= 0; --operand_idx, --dot_idx) { + if (operand_dims[operand_idx] != 1 && + operand_dims[operand_idx] != dot_dims[dot_idx]) + return false; + } + return true; +} + +inline bool IsRowMajor(const Shape& shape) { + return LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); +} + +// Whether the element type of instr is compatible with oneDNN kernels. +// TODO(intel-tf): Restict compatible types based on instruction kind. +inline bool CompatibleElementType(const HloInstruction* instr) { + PrimitiveType element_type = instr->shape().element_type(); + return element_type == BF16 || element_type == F32 || element_type == F16; +} + +// Type conversion from and to any of BF16, F16 and FP32. +// TODO(intel-tf): Support more types when enabled. +template +inline auto SupportedConvert(Pattern pattern) { + auto supported_convert = [](const HloInstruction* instr) -> bool { + return CompatibleElementType(instr) && + CompatibleElementType(instr->operand(0)); + }; + return m::Convert(pattern).WithPredicate(supported_convert); +} + +template +inline auto SupportedConvert(HloInstruction** convert, Pattern pattern) { + auto supported_convert = [](const HloInstruction* instr) -> bool { + return CompatibleElementType(instr) && + CompatibleElementType(instr->operand(0)); + }; + return m::Convert(convert, pattern).WithPredicate(supported_convert); +} + +template +inline auto BitcastWithReshapeSemantics(HloInstruction** bitcast, + Pattern pattern) { + // TODO(intel-tf): Add stronger condition that Bitcast does not have transpose + // semantics. Some of the HLO passes replaces Transpose with Bitcast. Here + // the layouts are checked to be rowmajor since the current pass runs after + // the layout assignment and oneDNN matmul is enabled for rowmajor layouts. + auto is_reshape = [](const HloInstruction* instr) -> bool { + if (!instr) return false; + auto input_shape = instr->operand(0)->shape(); + auto output_shape = instr->shape(); + bool is_same_type = ShapeUtil::SameElementType(input_shape, output_shape); + bool has_equal_num_elems = ShapeUtil::ElementsIn(input_shape) == + ShapeUtil::ElementsIn(output_shape); + bool has_rowmajor_layout = + IsRowMajor(input_shape) && IsRowMajor(output_shape); + return is_same_type && has_equal_num_elems && has_rowmajor_layout; + }; + return m::Bitcast(bitcast, pattern).WithPredicate(is_reshape); +} + +template +inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert, + HloInstruction** optional_bitcast, + Pattern pattern) { + // Checks the presence of some intermediate operations that can be moved / + // folded to allow dot fusion with add. + // Try to match either of the following: + // 1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast + // 2. pattern-root -> bf16/f16-to-fp32 convert + // 3. pattern-root -> bitcast + // 4. pattern-root + auto common = + m::AnyOf( + SupportedConvert(optional_convert, std::move(pattern).WithOneUser()) + .WithElementType(PrimitiveType::F32), + std::move(pattern).WithOneUser()) + .WithOneUser(); + return m::AnyOf( + BitcastWithReshapeSemantics(optional_bitcast, common), common); +} + +} // namespace + +bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { + // Currently, blocking control dependencies + if (dot_instr->HasControlDependencies()) return false; + if (!IsSupportedType(dot_instr->shape().element_type())) return false; + if (dot_instr->operands().size() != 2) return false; + + // Currently, we rewrite when the data type is F32 or BF16. Note we do not + // need to check equality of contraction dim-size of the operands. HLO + // verifier already does the job. We, however, need to check if contraction + // is over only 1 dimension (a.k.a. K dimension in matrix-multiplication + // parlance). We also restrict that batch dimensions of the operands + // match. + const Shape& lhs_shape = dot_instr->operand(0)->shape(); + const Shape& rhs_shape = dot_instr->operand(1)->shape(); + const Shape& output_shape = dot_instr->shape(); + // None of the operands and result should be ZeroElementArray. + if (ShapeUtil::IsZeroElementArray(lhs_shape) || + ShapeUtil::IsZeroElementArray(rhs_shape) || + ShapeUtil::IsZeroElementArray(output_shape)) { + return false; + } + // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting + // dimensions. We should not rewrite if any of these conditions are violated. + if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims || + rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims || + output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(), + static_cast(kOneDnnMaxNDims)})) { + return false; + } + + // Layout should be row-major, contraction dimensions captures transpose + // scenarios in last two dimensions. + if (!IsRowMajor(lhs_shape) || !IsRowMajor(rhs_shape) || + !IsRowMajor(output_shape)) { + return false; + } + + auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); + int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0); + int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0); + // Supported contraction is only in one of last two dimensions. + if (lhs_dim_k < lhs_shape.rank() - 2 || rhs_dim_k < rhs_shape.rank() - 2) { + return false; + } + + // OneDNN matmul has scratch allocation and copy overheads. The overheads + // can be amortized if there is sufficient number of flops. We don't rewrite + // for small cases (determined empirically). + // TODO(intel-tf): Relax the condition when more optimizations in oneDNN + // matmul is achieved. + auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape, + dot_dim_numbers); + auto rank = output_shape.rank(); + auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19); + return (num_flops >= flops_threshold); +} + +class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { + public: + // Matches patterns for possible MatMul fusions that are supported by oneDNN + // library. Matched HLO instruction(s) are replaced by custom call. + Status HandleDot(HloInstruction* instr) override { + HloInstruction* dot_instr; + auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot); + if (!Match(instr, pattern)) return OkStatus(); + + TF_RETURN_IF_ERROR( + ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers())); + if (!OneDnnMatMulRewriter::ShouldRewrite(dot_instr)) return OkStatus(); + TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr)); + auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); + const Shape& lhs_shape = dot_instr->operand(0)->shape(); + const Shape& rhs_shape = dot_instr->operand(1)->shape(); + const Shape& output_shape = dot_instr->shape(); + + int64_t lhs_dim_k = dot_dim_numbers.lhs_contracting_dimensions(0); + int64_t rhs_dim_k = dot_dim_numbers.rhs_contracting_dimensions(0); + + HloInstruction* matmul_call = + dot_instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)}, + "__onednn$matmul")); + // Set additional info via config, e.g., transpose and fusion info. + BackendConfig backend_config; + OneDnnMatMulConfig* matmul_config = + backend_config.mutable_onednn_matmul_config(); + bool transpose_a = (lhs_dim_k != lhs_shape.rank() - 1); + bool transpose_b = (rhs_dim_k != rhs_shape.rank() - 2); + matmul_config->set_transpose_a(transpose_a); + matmul_config->set_transpose_b(transpose_b); + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call)); + return OkStatus(); + } + + Status HandleAdd(HloInstruction* instr) override { + // Try to do a fusion for Dot(onednn-matmul) + Add. However, + // HLO Add instruction might receive the addends after additional + // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw + // addends. Here, the following possible pattern is matched. + // + // clang-format off + // + // Dot addend + // | | + // v v + // optional instructions optional instructions + // (e.g, Convert, Bitcast) (e.g, Convert, Broadcast) + // | | + // +--------------+-------------------+ + // | + // v + // Add + // + // clang-format on + + HloInstruction *addend_intermediate, *dot; + HloInstruction* optional_dot_bitcast = nullptr; + HloInstruction* optional_dot_convert = nullptr; + + auto pattern = m::AddAnyOrder( + &instr, + OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast, + OneDnnMatmulInstr(&dot)) + .WithOneUser(), + m::Op(&addend_intermediate)); + + if (Match(instr, pattern)) { + if (!IsSupportedType(dot->shape().element_type())) return OkStatus(); + // TODO(intel-tf): Remove the condition below when the fusion Dot + + // Add(bias) + Add(e.g., residual) is enabled. + if (!dot->backend_config() + ->mutable_onednn_matmul_config() + ->fused_ops() + .empty() && + dot->backend_config() + ->mutable_onednn_matmul_config() + ->fused_ops(0) == OneDnnMatMulConfig::BIAS) { + return OkStatus(); + } + std::vector new_operands; + for (auto operand : dot->operands()) { + new_operands.push_back(operand); + } + + // At this point, the addend could have one of the following + // possiblities that the current fusion can handle: + // + // - addend -> Convert -> Broadcast -> Add + // - addend -> Broadcast -> Convert -> Add + // - addend -> Convert + // - addend -> Broadcast + // - addend + // + // Hunt for addend through possible sequences above and check the addend + // is compatible to onednn-matmul fusion. + HloInstruction* addend = nullptr; + HloInstruction* optional_addend_broadcast = nullptr; + auto addend_pattern = m::AnyOf( + m::Broadcast(&optional_addend_broadcast, + m::Convert(&addend, m::Op())), + m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))), + m::Convert(&addend, m::Op()), + m::Broadcast(&optional_addend_broadcast, m::Op(&addend)), + m::Op(&addend)); + if (!Match(addend_intermediate, addend_pattern)) return OkStatus(); + + if (optional_addend_broadcast && addend->shape().rank() != 1) { + auto new_shape = + AdjustBiasShape(optional_addend_broadcast, dot->shape()); + if (new_shape.ok()) { + addend = addend->AddInstruction( + HloInstruction::CreateBitcast(new_shape.value(), addend)); + } else { + VLOG(2) << new_shape.status(); + return OkStatus(); + } + } + + // Validate addend for fusion. + if (CompatibleElementType(addend) && IsOperandFusible(addend, dot)) { + new_operands.push_back(addend); + } else { + return OkStatus(); + } + + auto matmul_call = Cast(instr->AddInstruction( + dot->CloneWithNewOperands(dot->shape(), new_operands))); + + auto backend_config = matmul_call->backend_config(); + backend_config->mutable_onednn_matmul_config()->add_fused_ops( + addend->shape().rank() != 1 ? OneDnnMatMulConfig::BINARY_ADD + : OneDnnMatMulConfig::BIAS); + if (optional_addend_broadcast) { + backend_config->mutable_onednn_matmul_config()->set_bias_broadcast( + true); + } + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + + HloInstruction* new_instr; + // If matched pattern has custom-call -> bitcast -> add, then we need to + // insert bitcast after the new fusion to maintain the correct shape + // (new-custom-call -> bitcast). Also, this will optionally be followed + // by -> convert for bf16 case to avoid datatype mismatch. + if (optional_dot_bitcast != nullptr && + optional_dot_bitcast->opcode() == HloOpcode::kBitcast) { + if (optional_dot_convert != nullptr && + optional_dot_convert->opcode() == HloOpcode::kConvert) { + auto bitcast_call = + matmul_call->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::ChangeElementType( + instr->shape(), matmul_call->shape().element_type()), + matmul_call)); + new_instr = + bitcast_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + bitcast_call->shape(), + optional_dot_convert->shape().element_type()), + bitcast_call)); + } else { + new_instr = matmul_call->AddInstruction( + HloInstruction::CreateBitcast(instr->shape(), matmul_call)); + } + } else { + if (optional_dot_convert != nullptr && + optional_dot_convert->opcode() == HloOpcode::kConvert) { + new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + matmul_call->shape(), + optional_dot_convert->shape().element_type()), + matmul_call)); + } else { + new_instr = matmul_call; + } + } + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); + } + + return OkStatus(); + } + + Status HandleMaximum(HloInstruction* instr) override { + HloInstruction* matmul_call; + HloInstruction* intermediate_instr = nullptr; + // Attempt to elide maximum and fuse ReLU activation into GEMM, including + // when slicing or bitcasting is applied to the result. + if (Match(instr, m::MaximumAnyOrder(ElementwiseSafeIntermediate( + &intermediate_instr, + OneDnnMatmulInstr(&matmul_call)) + .WithOneUser(), + BcastConstScalar(0)))) { + return FuseActivation(OneDnnMatMulConfig::RELU, instr, matmul_call, + intermediate_instr); + } + return OkStatus(); + } + + Status HandleMultiply(HloInstruction* instr) override { + HloInstruction* matmul_call; + HloInstruction* intermediate_instr = nullptr; + HloInstruction* src; + if (GELUActivation(instr, &src)) { + if (Match(src, + ElementwiseSafeIntermediate(&intermediate_instr, + OneDnnMatmulInstr(&matmul_call)))) { + return FuseActivation(OneDnnMatMulConfig::GELU_TANH, instr, matmul_call, + intermediate_instr); + } + } + + HloInstruction *dot, *constant; + auto pattern = m::Op(&instr) + .WithOpcode(HloOpcode::kMultiply) + .WithBinaryOperandsAnyOrder( + m::Op(&dot) + .WithOneUser() + .WithOpcode(HloOpcode::kCustomCall) + .WithCustomCallTarget({"__onednn$matmul"}), + m::Broadcast(m::Constant(&constant))); + + if (Match(instr, pattern)) { + std::vector new_operands; + auto constant_value = *GetConstantValueAsFloat32(constant); + + for (auto operand : dot->operands()) { + new_operands.push_back(operand); + } + auto matmul_call = Cast(instr->AddInstruction( + dot->CloneWithNewOperands(instr->shape(), new_operands))); + auto backend_config = matmul_call->backend_config(); + backend_config->mutable_onednn_matmul_config()->add_fused_ops( + OneDnnMatMulConfig::LINEAR); + // Casting to int32 because of issues in proto config for decimal types + // handling. + backend_config->mutable_onednn_matmul_config()->set_alpha_typecast( + *(reinterpret_cast(&constant_value))); + TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, matmul_call)); + } + return OkStatus(); + } + + Status FuseActivation(OneDnnMatMulConfig_FusionKind kind, + HloInstruction* activation, HloInstruction* matmul, + HloInstruction* intermediate_instr = nullptr) { + TF_ASSIGN_OR_RETURN(auto backend_config, + matmul->backend_config()); + auto* matmul_config = backend_config.mutable_onednn_matmul_config(); + matmul_config->add_fused_ops(kind); + + std::unique_ptr output = matmul->Clone(); + TF_RETURN_IF_ERROR(output->set_backend_config(backend_config)); + + if (intermediate_instr) { + output = intermediate_instr->CloneWithNewOperands( + intermediate_instr->shape(), + {matmul->parent()->AddInstruction(std::move(output))}); + } + + return ReplaceWithNewInstruction(activation, std::move(output)); + } + + // This function changes dot instruction for supported matrix + // multiplication scenarios. In particular, it changes the shape + // of lhs, rhs and result arrays. + // - lhs configuration scenario + // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim] + // result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim] + // + // - rhs configuration scenario + // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1] + // result: [batch_dims,feature_dim] to [batch_dims,feature_dim, 1] + // + // - both lhs and rhs configuration scenario + // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim] + // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1] + // result: [batch_dims] to [batch_dims,1,1] + StatusOr ReconfigureDotDimensions( + HloInstruction* dot_instr) { + HloInstruction* lhs = dot_instr->mutable_operand(0); + HloInstruction* rhs = dot_instr->mutable_operand(1); + DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers(); + + auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions(); + auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions(); + bool is_lhs_vector = lhs->shape().rank() == + (lhs_batch_dims.size() + lhs_contraction_dims.size()); + + auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions(); + auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions(); + bool is_rhs_vector = rhs->shape().rank() == + (rhs_batch_dims.size() + rhs_contraction_dims.size()); + + if (!is_lhs_vector && !is_rhs_vector) return dot_instr; + + std::vector adjusted_lhs_dims(lhs->shape().dimensions().begin(), + lhs->shape().dimensions().end()); + std::vector adjusted_rhs_dims(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + std::vector adjusted_dot_dims( + dot_instr->shape().dimensions().begin(), + dot_instr->shape().dimensions().end()); + + if (is_lhs_vector) { + auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size(); + adjusted_lhs_dims.insert(lhs_it, 1, 1); + auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size(); + adjusted_dot_dims.insert(result_it, 1, 1); + auto lhs_contraction_dim = + dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0); + dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1); + lhs = lhs->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims), + lhs)); + } + + if (is_rhs_vector) { + auto it = adjusted_rhs_dims.end(); + adjusted_rhs_dims.insert(it, 1, 1); + auto result_it = adjusted_dot_dims.end(); + adjusted_dot_dims.insert(result_it, 1, 1); + rhs = rhs->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims), + rhs)); + } + + HloInstruction* adjusted_dot = + dot_instr->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(dot_instr->shape().element_type(), + adjusted_dot_dims), + lhs, rhs, dim_numbers, dot_instr->precision_config())); + + HloInstruction* replacement_instr = adjusted_dot->AddInstruction( + HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot)); + + TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr)); + return adjusted_dot; + } +}; + +class OneDnnMatMulReorderVisitor : public DfsHloRewriteVisitor { + public: + OneDnnMatMulReorderVisitor(int intra_op_parallelism, + const tsl::thread::ThreadPool* compile_threadpool) + : intra_op_parallelism_(intra_op_parallelism > 0 + ? intra_op_parallelism + : tsl::port::MaxParallelism()), + evaluator_(/*max_loop_iterations=*/0) { + if (compile_threadpool) { + threadpool_device_.reset( + new Eigen::ThreadPoolDevice(compile_threadpool->AsEigenThreadPool(), + compile_threadpool->NumThreads())); + } else { + threadpool_handle_.reset(new tsl::thread::ThreadPool( + tsl::Env::Default(), "XLACpuCompile", tsl::port::MaxParallelism())); + threadpool_device_.reset( + new Eigen::ThreadPoolDevice(threadpool_handle_->AsEigenThreadPool(), + threadpool_handle_->NumThreads())); + } + + evaluator_.set_custom_call_handler( + [this](const HloInstruction* custom_call_instr, + absl::Span operands) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto backend_config, + custom_call_instr->backend_config()); + auto& matmul_config = backend_config.onednn_matmul_config(); + + auto output = Literal::CreateFromShape(custom_call_instr->shape()); + + int64_t nargs = operands.size() + 3; + std::vector args; + args.push_back(&nargs); + + ExecutableRunOptions run_options; + run_options.set_intra_op_thread_pool(threadpool_device_.get()); + args.push_back(&run_options); // No ExecutableRunOptions. + + // OneDnnMatMulConfig + std::string config; + matmul_config.SerializeToString(&config); + args.push_back(config.data()); + + std::vector minfo_ptrs(operands.size()); + std::transform(operands.begin(), operands.end(), minfo_ptrs.begin(), + CreateMemrefInfoFromLiteral); + for (auto& minfo_ptr : minfo_ptrs) { + args.push_back(static_cast(minfo_ptr.get())); + } + + auto result_ptr = CreateMemrefInfoFromLiteral(&output); + __xla_cpu_runtime_OneDnnMatMulReorder(result_ptr.get(), args.data()); + + return output; + }); + } + + Status HandleCustomCall(HloInstruction* custom_call) override { + HloInstruction* matmul; + if (Match(custom_call, OneDnnMatmulInstr(&matmul))) { + TF_ASSIGN_OR_RETURN(auto backend_config, + matmul->backend_config()); + auto& matmul_config = backend_config.onednn_matmul_config(); + + auto operands = custom_call->operands(); + auto input = operands[0]; + auto weight = operands[1]; // assuming weights is the second operand + + auto input_shape = input->shape(); + auto weight_shape = weight->shape(); + if (weight_shape.rank() != 2) { + // pre-pack only 2D weights + return DefaultAction(custom_call); + } + + auto bias_shape = + absl::c_count(matmul_config.fused_ops(), OneDnnMatMulConfig::BIAS) > 0 + ? operands.at(2)->shape() + : Shape(); + + auto output_shape = custom_call->shape(); + +#ifndef ENABLE_ONEDNN_OPENMP + // set oneDNN cuncurrency settings (which is thread-local) + tsl::OneDnnThreadPool::set_onednn_max_threads(intra_op_parallelism_); +#endif + auto new_weight_shape = OneDnnMatMulOptWeightsShape( + input_shape, weight_shape, bias_shape, output_shape, &matmul_config); + + auto cmpt = custom_call->parent(); + std::vector new_operands{ + cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(input_shape))), + weight, + cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(output_shape))), + }; + + if (ShapeUtil::IsInitialized(bias_shape)) { + new_operands.push_back(cmpt->AddInstruction( + HloInstruction::CreateConstant(Literal(bias_shape)))); + } + + HloInstruction* reorder_call = + custom_call->AddInstruction(HloInstruction::CreateCustomCall( + new_weight_shape, new_operands, "__onednn$matmul_reorder")); + + reorder_call->CopyBackendConfigFrom(custom_call); + + Literal result; + + if (evaluator_.TryEvaluate(reorder_call, &result, true)) { + HloInstruction* reordered_weight = custom_call->AddInstruction( + HloInstruction::CreateConstant(std::move(result))); + return custom_call->ReplaceOperandWithDifferentShape(1, + reordered_weight); + + } else { + return DefaultAction(custom_call); + } + } + return DefaultAction(custom_call); + } + + private: + int intra_op_parallelism_; + HloEvaluator evaluator_; + std::unique_ptr threadpool_handle_; + std::unique_ptr threadpool_device_; +}; + +StatusOr OneDnnMatMulRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + OneDnnMatMulRewriteVisitor visitor; + TF_ASSIGN_OR_RETURN(auto result, + visitor.RunOnModule(module, execution_threads)); + + OneDnnMatMulReorderVisitor reorder_visitor(intra_op_parallelism_, + compile_threadpool_); + TF_ASSIGN_OR_RETURN(auto result2, + reorder_visitor.RunOnModule(module, execution_threads)); + + return {result || result2}; +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/onednn_matmul_rewriter.h b/xla/service/cpu/onednn_matmul_rewriter.h new file mode 100644 index 0000000000000..36cab7ee949c3 --- /dev/null +++ b/xla/service/cpu/onednn_matmul_rewriter.h @@ -0,0 +1,59 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ +#define XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include + +#include "absl/algorithm/container.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace cpu { + +// This pass pattern-matches HLO Dot instructions and rewrites into custom +// calls. +class OneDnnMatMulRewriter : public HloModulePass { + public: + OneDnnMatMulRewriter(int intra_op_parallelism, + const tsl::thread::ThreadPool* compile_threadpool) + : intra_op_parallelism_(intra_op_parallelism), + compile_threadpool_(compile_threadpool) {} + OneDnnMatMulRewriter() = default; + absl::string_view name() const override { return "onednn-matmul-rewriter"; } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + static bool ShouldRewrite(const HloInstruction* dot_instr); + + private: + int intra_op_parallelism_; + const tsl::thread::ThreadPool* compile_threadpool_; +}; + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ diff --git a/xla/service/cpu/onednn_memory_util.cc b/xla/service/cpu/onednn_memory_util.cc index 2bb7edfa1ad4d..372ce97c27893 100644 --- a/xla/service/cpu/onednn_memory_util.cc +++ b/xla/service/cpu/onednn_memory_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,6 +50,27 @@ struct MemrefInfoPOD { void* data; }; +MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) { + MemrefInfoHandler result(new MemrefInfoPOD); + + const auto& shape = literal->shape(); + result->dtype = shape.element_type(); + result->rank = shape.rank(); + auto dimensions = shape.dimensions(); + std::copy(dimensions.begin(), dimensions.end(), + absl::MakeSpan(result->dims).begin()); + + int64_t stride = 1; + for (int i : shape.layout().minor_to_major()) { + result->strides[i] = stride; + stride *= dimensions.at(i); + } + + result->data = const_cast(literal->untyped_data()); + + return result; +} + StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, const llvm_ir::IrArray& ir_array) { const Shape& shape = ir_array.GetShape(); @@ -96,10 +117,8 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, // Allocate MemrefInfo on the stack llvm::Value* memref_info_ptr = llvm_ir::EmitAllocaAtFunctionEntry( memref_info_type, "memref.info", &builder); - llvm::Value* memref_life_start = - builder.CreateLifetimeStart(memref_info_ptr, builder.getInt64(-1)); - llvm::Value* memref_store = - builder.CreateStore(memref_info_val, memref_info_ptr); + builder.CreateLifetimeStart(memref_info_ptr, builder.getInt64(-1)); + builder.CreateStore(memref_info_val, memref_info_ptr); return {&builder, memref_info_ptr}; } @@ -146,6 +165,10 @@ void MemrefInfo::Print() { std::cout << "]\n"; } +int64_t MemrefInfo::GetChannels() const { return pod_->dims[pod_->rank - 1]; } + +int64_t MemrefInfo::GetRank() const { return pod_->rank; } + } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 5373d71fcd284..fb5292843b5a6 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,14 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_MEMORY_UTIL_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include + #include "dnnl.hpp" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Value.h" +#include "xla/literal.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/xla_data.pb.h" @@ -40,6 +43,9 @@ struct StackAlloca { // Declare as opaque to put structure definition together with dependant code. struct MemrefInfoPOD; +using MemrefInfoHandler = std::shared_ptr; + +MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal); StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, const llvm_ir::IrArray& ir_array); @@ -105,6 +111,9 @@ class MemrefInfo { void Print(); + int64_t GetChannels() const; + int64_t GetRank() const; + private: MemrefInfoPOD* pod_; }; diff --git a/xla/service/cpu/onednn_ops_rewriter.cc b/xla/service/cpu/onednn_ops_rewriter.cc new file mode 100644 index 0000000000000..058355223cd5e --- /dev/null +++ b/xla/service/cpu/onednn_ops_rewriter.cc @@ -0,0 +1,511 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/service/cpu/onednn_ops_rewriter.h" + +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/service/pattern_matcher.h" +#include "xla/status_macros.h" + +namespace xla { +namespace cpu { + +namespace { +namespace m = match; + +template +auto OptionalConvert(Pattern pattern) { + return m::AnyOf(m::Convert(pattern), std::move(pattern)); +} + +inline auto OneDnnConvertibleInstr(HloInstruction** instr) { + return m::AnyOf(m::CustomCall(instr, {"__onednn$layernorm"}), + m::CustomCall(instr, {"__onednn$softmax"})); +} + +HloInstruction* FindLayerNormScale(HloInstruction* instr) { + HloInstruction* scale = nullptr; + auto scalePattern = m::Multiply().WithBinaryOperandsAnyOrder( + m::Broadcast(m::Op(&scale).WithOpcode(HloOpcode::kReshape)), + m::Broadcast(m::Reshape(m::Broadcast(m::Rsqrt()))).WithOneUser()); + Match(instr, scalePattern); + return scale; +} + +HloInstruction* FindLayerNormShift(HloInstruction* instr) { + HloInstruction* shift = nullptr; + Match(instr, + m::Add().WithBinaryOperandsAnyOrder( + m::Multiply() + .WithBinaryOperandsAnyOrder( + m::Op(), m::Subtract(m::Op(), m::Broadcast().WithOneUser()) + .WithOneUser()) + .WithOneUser(), + m::Broadcast(m::Op(&shift)))); + return shift; +} + +std::optional MatchSoftmax(HloInstruction* instr) { + // + // producer + // | \ + // | reduce_max + // | | + // | reshape + // | | + // | broadcast + // | | + // | reshape + // | | + // | broadcast + // | / + // subtract + // | + // exponential + // | \ + // | reduce_sum + // | | + // | reshape + // | | + // | broadcast + // | | + // | reshape + // | | + // | broadcast + // | / + // divide // (instr parameter) + // + // where both reductions occur only on the last axis. + HloInstruction* left_exponential; + HloInstruction* right_exponential; + HloInstruction* left_producer; + HloInstruction* right_producer; + + // Lower diamond + if (!Match(instr, + m::Divide( + m::Exp(&left_exponential, m::Op()), + m::Broadcast(m::Reshape( + m::Broadcast(OptionalConvert(m::Reshape(OptionalConvert( + m::Reduce(OptionalConvert( + m::Exp(&right_exponential, m::Op())), + m::Op()) + .WithPredicate([](const HloInstruction* reduce) { + HloComputation* reducer = reduce->to_apply(); + return (reducer->root_instruction()->opcode() == + HloOpcode::kAdd && + reduce->dimensions().size() == 1 && + reduce->dimensions()[0] != + reduce->shape().rank() - 1); + }) + .WithOneUse()))))))))) { + return std::nullopt; + } + + if (left_exponential != right_exponential || + left_exponential->user_count() != 2) { + return std::nullopt; + } + + // Upper diamond + if (!Match(left_exponential->mutable_operand(0), + m::Subtract( + m::Op(&left_producer), + m::Broadcast( + m::Reshape(m::Broadcast(m::Reshape( + m::Reduce(m::Op(&right_producer), m::Op()) + .WithPredicate([](const HloInstruction* reduce) { + HloComputation* reducer = reduce->to_apply(); + return (reducer->root_instruction()->opcode() == + HloOpcode::kMaximum && + reduce->dimensions().size() == 1 && + reduce->dimensions()[0] != + reduce->shape().rank() - 1); + }) + .WithOneUse())))) + .WithOneUse()) + .WithOneUse())) { + return std::nullopt; + } + + if (left_producer != right_producer || left_producer->user_count() != 2) { + return std::nullopt; + } + + return left_producer; +} + +auto MeanPattern(HloInstruction** input) { + return m::Reshape( + m::Convert(m::Divide(m::Reduce(m::Convert(m::Op(input)), m::Op()), + m::Broadcast(m::Convert())))); +} + +template +auto Square(Pattern pattern) { + return m::Multiply() + .WithBinaryOperandsAnyOrder(pattern, pattern) + .WithPredicate([](const HloInstruction* instr) { + return instr->unique_operands().size() == 1; + }); +} + +std::optional MatchTFKerasLayerNorm(HloInstruction* instr, + HloInstruction** src, + HloInstruction** scale, + HloInstruction** bias, float* eps) { + // variance = Mean((X - Mean(x))^2) + // Z = scale / sqrt(variance + eps) + // LN(X) = X*Z + Bias - Mean(X)*Z + + HloInstruction *src_a, *src_b, *src_c; + HloInstruction *bias_node, *scaled_norm_a, *scaled_norm_b, *mean0_a, *epsilon, + *sqrd_diff_mean, *scale_node, *sqrd_diff; + + // First Match X*Z + Bias - Mean(X)*Z + if (!Match( + instr, + m::Add().WithBinaryOperandsAnyOrder( + m::Multiply() + .WithBinaryOperandsAnyOrder(m::Op(src), m::Op(&scaled_norm_a)) + .WithOneUser(), + m::Subtract(m::Op(&bias_node), + m::Multiply().WithBinaryOperandsAnyOrder( + m::Broadcast(m::Reshape(m::Op(&mean0_a))), + m::Op(&scaled_norm_b))) + .WithOneUser()))) { + return std::nullopt; + } + + if (scaled_norm_a != scaled_norm_b) return std::nullopt; + + const Shape& src_shape = (*src)->shape(); + if (!IsSupportedType(src_shape.element_type())) return std::nullopt; + + // Get bias + if (!Match(bias_node, m::Broadcast(m::Op(bias)))) return std::nullopt; + + // Match Z = scale / sqrt(variance + eps) + if (!Match(scaled_norm_a, + m::Multiply().WithBinaryOperandsAnyOrder( + m::Op(&scale_node), + m::Broadcast( + m::Reshape(m::Rsqrt(m::Add().WithBinaryOperandsAnyOrder( + m::Broadcast(m::ConstantScalar(&epsilon)), + m::Op(&sqrd_diff_mean)))))))) { + return std::nullopt; + } + + // get epsilon + *eps = static_cast(epsilon->literal().GetAsDouble({}).value()); + // get scale + if (!Match(scale_node, m::Broadcast(m::Op(scale)))) return std::nullopt; + + // match variance + if (!Match(sqrd_diff_mean, MeanPattern(&sqrd_diff))) return std::nullopt; + + if (!Match(sqrd_diff, Square(m::Subtract( + m::Op(&src_a), + m::Broadcast(m::Reshape(MeanPattern(&src_b))))))) { + return std::nullopt; + } + + if (src_a != src_b && src_a != *src) return std::nullopt; + + // Match mean from Bias - Mean(X)*Z + if (!Match(mean0_a, MeanPattern(&src_c))) return std::nullopt; + + if (src_c != *src) return std::nullopt; + + return true; +} + +bool MatchFlaxLayerNorm(HloInstruction* instr, HloInstruction** src, + HloInstruction** scale, HloInstruction** bias, + float* eps, bool* is_bf16orfp16_convert, + bool* is_producer_bf16orfp16, + HloInstruction** convert_instr) { + HloInstruction *prod_s, *hinge; + HloInstruction *div0, *div1, *div_red; + HloInstruction *mul_in0, *mul_in1, *main_pipe_mul_in0; + HloInstruction *reduce_in0, *epsilon; + HloInstruction *broadcast0, *broadcast1; + + bool scaleFound = false; + bool shiftFound = false; + + auto spine = m::Add().WithBinaryOperandsAnyOrder( + m::Broadcast(), + m::Multiply() + .WithBinaryOperandsAnyOrder( + m::Op(&hinge).WithOneUser(), + m::Subtract( + OptionalConvert(m::Op(&prod_s)), + m::Broadcast( + m::Reshape( + m::Broadcast(m::Reshape(m::Op(&div_red).WithOpcode( + HloOpcode::kDivide)) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()); + + if (!Match(instr, spine)) return false; + + const Shape& prod_shape = prod_s->shape(); + if (!IsSupportedType(prod_shape.element_type())) return false; + + HloInstruction* shift = FindLayerNormShift(instr); + shiftFound = (shift != nullptr); + + HloInstruction* scale_gamma = FindLayerNormScale(hinge); + scaleFound = (scale_gamma != nullptr); + + // Currently patterns without scale and shift are not supported. + // OneDNN only supports 2 <= rank <= 5 + if (!(prod_shape.rank() >= 2 && prod_shape.rank() <= 5) || !shiftFound || + !scaleFound) { + return false; + } + + // NOLINTBEGIN + auto main_pipeline = m::Multiply().WithBinaryOperandsAnyOrder( + m::Op(), + m::Broadcast( + m::Reshape( + m::Broadcast( + m::Rsqrt( + m::Add() + .WithBinaryOperandsAnyOrder( + m::Broadcast(m::ConstantScalar(&epsilon)), + m::Reshape( + m::Maximum() + .WithBinaryOperandsAnyOrder( + m::Broadcast(), + m::Subtract( + m::Op(&div0).WithOpcode( + HloOpcode::kDivide), + m::Multiply() + .WithBinaryOperandsAnyOrder( + m::Op(&main_pipe_mul_in0), + m::Op(&div1).WithOpcode( + HloOpcode::kDivide)) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()) + .WithOneUser()); + // NOLINTEND + + if (!Match(hinge, main_pipeline)) return false; + + if ((div_red != div1) || (main_pipe_mul_in0 != div1)) return false; + + auto div_red_mul_src = + m::Divide() + .WithOperand(0, m::Reduce(m::Multiply().WithBinaryOperandsAnyOrder( + OptionalConvert(m::Op(&mul_in0)), + OptionalConvert(m::Op(&mul_in1))), + m::Constant()) + .WithPredicate([](const HloInstruction* reduce) { + HloComputation* reducer = reduce->to_apply(); + return (reducer->root_instruction()->opcode() == + HloOpcode::kAdd && + reduce->dimensions().size() == 1 && + reduce->dimensions()[0] == + reduce->shape().rank()); + })) + .WithOperand(1, m::Op(&broadcast0).WithOpcode(HloOpcode::kBroadcast)) + .WithOneUser(); + + if (!Match(div0, div_red_mul_src)) return false; + + if (mul_in0 != mul_in1) return false; + + auto div_red_subgraph = + m::Divide() + .WithOperand( + 0, + m::Reduce(OptionalConvert(m::Op(&reduce_in0)), m::Constant()) + .WithPredicate([](const HloInstruction* reduce) { + HloComputation* reducer = reduce->to_apply(); + return (reducer->root_instruction()->opcode() == + HloOpcode::kAdd && + reduce->dimensions().size() == 1 && + reduce->dimensions()[0] == reduce->shape().rank()); + })) + .WithOperand(1, m::Op(&broadcast1).WithOpcode(HloOpcode::kBroadcast)); + + if (!Match(div1, div_red_subgraph)) return false; + + if (broadcast1 != broadcast0 || reduce_in0 != mul_in0 || mul_in0 != prod_s) { + return false; + } + + *is_producer_bf16orfp16 = + (prod_s->shape().element_type() == PrimitiveType::F16) || + (prod_s->shape().element_type() == PrimitiveType::BF16); + if (instr->user_count() == 1 && + instr->users().at(0)->opcode() == HloOpcode::kConvert) { + *convert_instr = instr->users().at(0); + *is_bf16orfp16_convert = + ((*convert_instr)->shape().element_type() == PrimitiveType::F16 || + (*convert_instr)->shape().element_type() == PrimitiveType::BF16); + } + + *src = prod_s; + *scale = scale_gamma; + *bias = shift; + // get epsilon + *eps = static_cast(epsilon->literal().GetAsDouble({}).value()); + + return true; +} + +} // namespace + +class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { + public: + Status HandleAdd(HloInstruction* instr) override { + HloInstruction *src, *scale, *bias; + float eps; + bool is_bf16orfp16_convert = false; + bool is_producer_bf16orfp16 = false; + HloInstruction* convert_instr; + + bool found_ln = + MatchTFKerasLayerNorm(instr, &src, &scale, &bias, &eps).value_or(false); + + if (!found_ln) { + found_ln = MatchFlaxLayerNorm(instr, &src, &scale, &bias, &eps, + &is_bf16orfp16_convert, + &is_producer_bf16orfp16, &convert_instr); + } + + if (!found_ln) return OkStatus(); + + const Shape& src_shape = src->shape(); + auto scale_type = scale->shape().element_type(); + auto bias_type = bias->shape().element_type(); + HloInstruction* scale_operand = scale; + HloInstruction* bias_operand = bias; + + // oneDNN requires scale and shift float32 + if ((scale_type == PrimitiveType::BF16) || + (scale_type == PrimitiveType::F16)) { + scale_operand = instr->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(scale->shape(), PrimitiveType::F32), + scale)); + } + + if ((bias_type == PrimitiveType::BF16) || + (bias_type == PrimitiveType::F16)) { + bias_operand = instr->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(bias->shape(), PrimitiveType::F32), + bias)); + } + + HloInstruction* ln_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + src_shape, {src, scale_operand, bias_operand}, + "__onednn$layernorm")); + BackendConfig backend_config; + OneDnnLayerNormConfig* ln_config = + backend_config.mutable_onednn_layer_norm_config(); + ln_config->set_fused_ops(OneDnnLayerNormConfig::SCALE_AND_SHIFT); + ln_config->set_epsilon_typecast(*(reinterpret_cast(&eps))); + TF_RETURN_IF_ERROR(ln_call->set_backend_config(backend_config)); + + if (convert_instr != nullptr && is_bf16orfp16_convert && + is_producer_bf16orfp16) { + TF_RETURN_IF_ERROR(ReplaceInstruction(convert_instr, ln_call)); + } else { + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, ln_call)); + } + + return OkStatus(); + } + + Status HandleConvert(HloInstruction* instr) override { + HloInstruction* custom_call; + HloInstruction* convert_instr; + auto pattern = + m::Op(&convert_instr) + .WithOpcode(HloOpcode::kConvert) + .WithOperand(0, OneDnnConvertibleInstr(&custom_call) + .WithOneUser() + .WithElementType(PrimitiveType::F32)); + + if (!IsSupportedType(instr->shape().element_type())) return OkStatus(); + if (Match(instr, pattern)) { + bool is_bf16orfp16_convert = + (convert_instr->shape().element_type() == PrimitiveType::BF16) || + (convert_instr->shape().element_type() == PrimitiveType::F16); + if (!is_bf16orfp16_convert) return OkStatus(); + HloInstruction* producer = instr->mutable_operand(0)->mutable_operand(0); + HloInstruction* newinp = + producer->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(producer->shape(), + instr->shape().element_type()), + producer)); + absl::InlinedVector newoperands = + custom_call->mutable_operands(); + newoperands.at(0) = newinp; + HloInstruction* updated_call = instr->AddInstruction( + custom_call->CloneWithNewOperands(instr->shape(), newoperands)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, updated_call)); + } + + return OkStatus(); + } + + Status HandleDivide(HloInstruction* divide_instr) override { + if (divide_instr->HasControlDependencies()) return OkStatus(); + if (!IsSupportedType(divide_instr->shape().element_type())) + return OkStatus(); + std::optional producer = MatchSoftmax(divide_instr); + if (producer == std::nullopt) return OkStatus(); + + const Shape& output_shape = divide_instr->shape(); + HloInstruction* softmax_call = + divide_instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {producer.value()}, "__onednn$softmax")); + TF_RETURN_IF_ERROR(ReplaceInstruction(divide_instr, softmax_call)); + + return OkStatus(); + } +}; + +StatusOr OneDnnOpsRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + OneDnnOpsRewriterVisitor visitor; + return visitor.RunOnModule(module, execution_threads); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/onednn_ops_rewriter.h b/xla/service/cpu/onednn_ops_rewriter.h new file mode 100644 index 0000000000000..ea62f33ebcfb9 --- /dev/null +++ b/xla/service/cpu/onednn_ops_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_CPU_ONEDNN_OPS_REWRITER_H_ +#define XLA_SERVICE_CPU_ONEDNN_OPS_REWRITER_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include + +#include "absl/algorithm/container.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace cpu { + +// This pass fuses hlo instructions that can be fused into single oneDNN +// operation and rewrites into custom calls. +class OneDnnOpsRewriter : public HloModulePass { + public: + absl::string_view name() const override { return "onednn-ops-rewriter"; } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_OPS_REWRITER_H_ diff --git a/xla/service/cpu/onednn_rewriter.cc b/xla/service/cpu/onednn_rewriter.cc deleted file mode 100644 index 4452381ea6f19..0000000000000 --- a/xla/service/cpu/onednn_rewriter.cc +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - -#include "xla/service/cpu/onednn_rewriter.h" - -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/cpu/backend_config.pb.h" -#include "xla/service/cpu/onednn_memory_util.h" -#include "xla/service/pattern_matcher.h" -#include "xla/status_macros.h" -#include "tsl/platform/cpu_info.h" - -namespace xla { -namespace cpu { - -namespace { -namespace m = match; - -Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { - // Checks some invariants that do not hold in general, but DotDecomposer - // should have established for us. - TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); - std::vector batch_dim_numbers( - dim_numbers.lhs_batch_dimensions_size()); - absl::c_iota(batch_dim_numbers, 0); - TF_RET_CHECK( - absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); - TF_RET_CHECK( - absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); - return OkStatus(); -} - -bool IsSupportedType(xla::PrimitiveType dtype) { - using tsl::port::TestCPUFeature; - using tsl::port::CPUFeature; - switch (dtype) { - case F32: - return true; - case BF16: - return TestCPUFeature(CPUFeature::AVX512_BF16) || - TestCPUFeature(CPUFeature::AMX_BF16); - default: - return false; - } - return false; -} - -} // namespace - -class OneDnnRewriterVisitor : public DfsHloRewriteVisitor { - public: - // Matches patterns for possible MatMul fusions that are supported by oneDNN - // library. Matched hlo instruction(s) are replaced by custom call. - Status HandleDot(HloInstruction* instr) override { - // Currently, blocking control dependencies - if (instr->HasControlDependencies()) return OkStatus(); - HloInstruction* dot_instr; - auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot); - if (!Match(instr, pattern)) return OkStatus(); - - // TODO(intel-tf): The rewrite pass runs after dot-decomposition pass. - // Adjust the rewrite condition when the rewrite pass is moved to a - // different point in the pass-pipeline. - - // Currently, we rewrite when the data type is F32 or BF16. Note we do not - // need to check equality of contraction dim-size of the operands. HLO - // verifier already does the job. We, however, need to check if contraction - // is over only 1 dimension (a.k.a. K dimension in matrix-multiplication - // parlance). We also restrict that batch dimensions of the operands - // match. - if (!IsSupportedType(dot_instr->shape().element_type())) return OkStatus(); - auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); - TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot_dim_numbers)); - const Shape& lhs_shape = dot_instr->operand(0)->shape(); - const Shape& rhs_shape = dot_instr->operand(1)->shape(); - const Shape& output_shape = dot_instr->shape(); - bool should_rewrite = true; - // None of the operands and result should be ZeroElementArray. - should_rewrite &= !ShapeUtil::IsZeroElementArray(lhs_shape); - should_rewrite &= !ShapeUtil::IsZeroElementArray(rhs_shape); - should_rewrite &= !ShapeUtil::IsZeroElementArray(output_shape); - // OneDNN only supports 2 <= rank <= kOneDnnMaxNDims. - should_rewrite &= (lhs_shape.rank() == rhs_shape.rank()); - should_rewrite &= (rhs_shape.rank() == output_shape.rank()); - should_rewrite &= - (lhs_shape.rank() >= 2 && lhs_shape.rank() <= kOneDnnMaxNDims); - if (!should_rewrite) return OkStatus(); - // Transpose scenario needs some care and blocked for oneDNN rewrite for - // now. - // TODO(intel-tf): Add transpose scenarios - should_rewrite &= LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()); - if (!should_rewrite) return OkStatus(); - should_rewrite &= LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()); - if (!should_rewrite) return OkStatus(); - should_rewrite &= - LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()); - if (!should_rewrite) return OkStatus(); - - // Check contracting dimensions: [..., M, K] x [..., K, N] - should_rewrite &= - (dot_dim_numbers.lhs_contracting_dimensions(0) == lhs_shape.rank() - 1); - should_rewrite &= - (dot_dim_numbers.rhs_contracting_dimensions(0) == rhs_shape.rank() - 2); - if (!should_rewrite) return OkStatus(); - - HloInstruction* matmul_call = - dot_instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {dot_instr->mutable_operand(0), dot_instr->mutable_operand(1)}, - "__onednn$matmul")); - // Set additional info via config, e.g., fusion info. - BackendConfig backend_config; - // No fusion is supported now, so nothing to add to the config. - TF_RETURN_IF_ERROR(matmul_call->set_backend_config(backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, matmul_call)); - return OkStatus(); - } -}; - -StatusOr OneDnnRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - OneDnnRewriterVisitor visitor; - return visitor.RunOnModule(module, execution_threads); -} - -} // namespace cpu -} // namespace xla - -#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/onednn_rewriter.h b/xla/service/cpu/onednn_rewriter.h index c0e5013ad80f5..a1ba3205c9680 100644 --- a/xla/service/cpu/onednn_rewriter.h +++ b/xla/service/cpu/onednn_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/onednn_softmax.cc b/xla/service/cpu/onednn_softmax.cc new file mode 100644 index 0000000000000..5af6de5407859 --- /dev/null +++ b/xla/service/cpu/onednn_softmax.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_softmax.h" + +#include +#include +#include +#include + +// Both "absl/log/check.h" and "third_party/tsl/platform/logging.h" +// are transitively included in bazel. Both of them define similar CHECK macros. +// Explicitly including the Abseil header first because the TSL version has +// undefs. + +// Otherwise, we would get redefinition error. +// clang-format off +#include "absl/log/check.h" +// clang-format on + +#include "dnnl.hpp" +#include "absl/base/dynamic_annotations.h" +#include "xla/executable_run_options.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_memory_util.h" +#include "xla/service/cpu/runtime_lightweight_check.h" +#include "xla/tsl/util/onednn_threadpool.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace xla { +namespace cpu { + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnSoftmax( + const void* run_options_ptr, void* input, void* result) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + XLA_LIGHTWEIGHT_CHECK(run_options != nullptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); + tsl::OneDnnThreadPool thread_pool( + run_options->intra_op_thread_pool()->getPool(), false); + dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0); +#ifndef ENABLE_ONEDNN_OPENMP + auto onednn_stream = dnnl::stream( + dnnl::threadpool_interop::make_stream(cpu_engine, &thread_pool)); +#else + auto onednn_stream = dnnl::stream(cpu_engine); +#endif // ENABLE_ONEDNN_OPENMP + + MemrefInfo input_minfo(input); + MemrefInfo result_minfo(result); + + auto src_md = input_minfo.GetOneDnnMemDesc(); + auto dst_md = result_minfo.GetOneDnnMemDesc(); + + auto src_mem = dnnl::memory(src_md, cpu_engine, input_minfo.Data()); + auto dst_mem = dnnl::memory(dst_md, cpu_engine, result_minfo.Data()); + + int axis = (input_minfo.GetOneDnnDims().size()) - 1; + + auto softmax_pd = dnnl::softmax_forward::primitive_desc( + cpu_engine, dnnl::prop_kind::forward_inference, + dnnl::algorithm::softmax_accurate, src_md, dst_md, axis); + + auto softmax_prim = dnnl::softmax_forward(softmax_pd); + + std::unordered_map softmax_args; + softmax_args.insert({DNNL_ARG_SRC, src_mem}); + softmax_args.insert({DNNL_ARG_DST, dst_mem}); + + softmax_prim.execute(onednn_stream, softmax_args); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/service/cpu/onednn_softmax.h b/xla/service/cpu/onednn_softmax.h new file mode 100644 index 0000000000000..978551a013157 --- /dev/null +++ b/xla/service/cpu/onednn_softmax.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ONEDNN_SOFTMAX_H_ +#define XLA_SERVICE_CPU_ONEDNN_SOFTMAX_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +namespace xla { +namespace cpu { + +extern "C" { +extern void __xla_cpu_runtime_OneDnnSoftmax(const void* run_options_ptr, + void* input, void* result); +} // extern "C" + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_SOFTMAX_H_ diff --git a/xla/service/cpu/onednn_util.h b/xla/service/cpu/onednn_util.h new file mode 100644 index 0000000000000..0b8a7c65b0bf4 --- /dev/null +++ b/xla/service/cpu/onednn_util.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ONEDNN_UTIL_H_ +#define XLA_SERVICE_CPU_ONEDNN_UTIL_H_ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/xla_data.pb.h" +#include "tsl/platform/cpu_info.h" + +namespace xla { +namespace cpu { + +inline bool IsSupportedType(xla::PrimitiveType dtype) { + using tsl::port::CPUFeature; + // TODO(intel-tf): Enable more types. + switch (dtype) { + case F32: + return true; + case BF16: + return TestCPUFeature(CPUFeature::AVX512F) || + TestCPUFeature(CPUFeature::AVX_NE_CONVERT) || + TestCPUFeature(CPUFeature::AMX_BF16); + case F16: + return TestCPUFeature(CPUFeature::AVX512BW) && + (TestCPUFeature(CPUFeature::AVX512_FP16) || + TestCPUFeature(CPUFeature::AMX_FP16) || + TestCPUFeature(CPUFeature::AVX_NE_CONVERT)); + default: + return false; + } + return false; +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 +#endif // XLA_SERVICE_CPU_ONEDNN_UTIL_H_ diff --git a/xla/service/cpu/orc_jit_memory_mapper.cc b/xla/service/cpu/orc_jit_memory_mapper.cc index 68e5534487c3e..fe7922f47ccf8 100644 --- a/xla/service/cpu/orc_jit_memory_mapper.cc +++ b/xla/service/cpu/orc_jit_memory_mapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/orc_jit_memory_mapper.h b/xla/service/cpu/orc_jit_memory_mapper.h index 2dbfc9ec8a01a..e40af88c61a2f 100644 --- a/xla/service/cpu/orc_jit_memory_mapper.h +++ b/xla/service/cpu/orc_jit_memory_mapper.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/parallel_loop_emitter.cc b/xla/service/cpu/parallel_loop_emitter.cc index d61764a48a5f7..8c0cedf3220e5 100644 --- a/xla/service/cpu/parallel_loop_emitter.cc +++ b/xla/service/cpu/parallel_loop_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/parallel_loop_emitter.h b/xla/service/cpu/parallel_loop_emitter.h index 3f338851ddbaf..5082ef48a4151 100644 --- a/xla/service/cpu/parallel_loop_emitter.h +++ b/xla/service/cpu/parallel_loop_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/parallel_task_assignment.cc b/xla/service/cpu/parallel_task_assignment.cc index 59cdef0380734..6d10375958589 100644 --- a/xla/service/cpu/parallel_task_assignment.cc +++ b/xla/service/cpu/parallel_task_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -188,7 +188,7 @@ int64_t ParallelTaskAssignment::GetTargetParallelTaskCount( return 1; } -StatusOr ParallelTaskAssigner::Run( +absl::StatusOr ParallelTaskAssigner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); diff --git a/xla/service/cpu/parallel_task_assignment.h b/xla/service/cpu/parallel_task_assignment.h index c48fcab8bf4e3..2fad81a5d9b87 100644 --- a/xla/service/cpu/parallel_task_assignment.h +++ b/xla/service/cpu/parallel_task_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -91,7 +91,7 @@ class ParallelTaskAssigner : public HloModulePass { // `execution_threads` in 'module'. By default, all `execution_threads` are // included. Returns true if the computation was changed, false otherwise. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/cpu/parallel_task_assignment_test.cc b/xla/service/cpu/parallel_task_assignment_test.cc index 9a842a65fa874..a16cd24a5381a 100644 --- a/xla/service/cpu/parallel_task_assignment_test.cc +++ b/xla/service/cpu/parallel_task_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ class ParallelTaskAssignmentTest : public HloTestBase { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} - StatusOr RunParallelTaskAssigner(HloModule* module) { + absl::StatusOr RunParallelTaskAssigner(HloModule* module) { return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, &target_machine_features_) .Run(module); diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index 42358f07634b0..6a16997de5705 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -31,7 +31,12 @@ cc_library( "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service/cpu:cpu_runtime", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -42,12 +47,13 @@ cc_library( hdrs = ["convolution.h"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla/runtime:memref_view", "//xla/service/cpu:runtime_conv2d", "//xla/service/cpu:runtime_conv3d", - "//xla/service/cpu:runtime_fft", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", ], ) @@ -58,9 +64,14 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":convolution", + "//xla:xla_data_proto_cc", "//xla/runtime:aot_ffi", "//xla/runtime:aot_ffi_execution_context", + "//xla/runtime:memref_view", "//xla/runtime/ffi:ffi_api", + "//xla/runtime/ffi:ffi_c_api_hdrs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", ], ) @@ -74,7 +85,9 @@ cc_library( "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", ], ) @@ -88,9 +101,14 @@ cc_library( "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_status_public_headers", "//xla/service:custom_call_target_registry", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -101,6 +119,7 @@ cc_library( hdrs = ["fft_call.h"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", @@ -111,6 +130,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", ], ) @@ -121,11 +142,15 @@ cc_library( deps = [ "//xla:executable_run_options", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", "//xla/service/cpu:cpu_runtime", - "@llvm-project//mlir:IR", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", ], ) @@ -136,8 +161,10 @@ cc_library( hdrs = ["rng.h"], deps = [ "//xla:executable_run_options", + "//xla:xla_data_proto_cc", "//xla/runtime:memref_view", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) @@ -151,6 +178,8 @@ cc_library( "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", + "//xla/runtime:memref_view", + "@llvm-project//mlir:Support", ], ) @@ -161,8 +190,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":rng", + "//xla:xla_data_proto_cc", "//xla/runtime:aot_ffi", "//xla/runtime:aot_ffi_execution_context", + "//xla/runtime:memref_view", "//xla/runtime/ffi:ffi_api", + "//xla/runtime/ffi:ffi_c_api_hdrs", + "@com_google_absl//absl/status", ], ) diff --git a/xla/service/cpu/runtime/collectives.cc b/xla/service/cpu/runtime/collectives.cc index f6e3847fc7913..6034cc600245b 100644 --- a/xla/service/cpu/runtime/collectives.cc +++ b/xla/service/cpu/runtime/collectives.cc @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,11 +26,19 @@ #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/xla/service/cpu/runtime/collectives.h b/xla/service/cpu/runtime/collectives.h index de277e67608e7..043a3aaeb1222 100644 --- a/xla/service/cpu/runtime/collectives.h +++ b/xla/service/cpu/runtime/collectives.h @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/convolution.cc b/xla/service/cpu/runtime/convolution.cc index 05ac062f6cdb6..bc2c7ef29b253 100644 --- a/xla/service/cpu/runtime/convolution.cc +++ b/xla/service/cpu/runtime/convolution.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,9 +19,13 @@ #include #include "absl/status/status.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive #include "xla/executable_run_options.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime_conv2d.h" #include "xla/service/cpu/runtime_conv3d.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/runtime/convolution.h b/xla/service/cpu/runtime/convolution.h index dfabfc8ededaf..fe4433774a704 100644 --- a/xla/service/cpu/runtime/convolution.h +++ b/xla/service/cpu/runtime/convolution.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/runtime/memref_view.h" diff --git a/xla/service/cpu/runtime/convolution_call.cc b/xla/service/cpu/runtime/convolution_call.cc index 01484e7b621cc..793f6285da40c 100644 --- a/xla/service/cpu/runtime/convolution_call.cc +++ b/xla/service/cpu/runtime/convolution_call.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,9 +24,13 @@ #include #include +#include "absl/types/span.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" +#include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/convolution.h" namespace xla { diff --git a/xla/service/cpu/runtime/convolution_call.h b/xla/service/cpu/runtime/convolution_call.h index eea35e4b852d5..07bc96c51b4bf 100644 --- a/xla/service/cpu/runtime/convolution_call.h +++ b/xla/service/cpu/runtime/convolution_call.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/convolution_ffi.cc b/xla/service/cpu/runtime/convolution_ffi.cc index d17e9805a3e5b..9673938a05eea 100644 --- a/xla/service/cpu/runtime/convolution_ffi.cc +++ b/xla/service/cpu/runtime/convolution_ffi.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,10 +14,15 @@ #include "xla/service/cpu/runtime/convolution_ffi.h" +#include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/runtime/aot_ffi.h" #include "xla/runtime/aot_ffi_execution_context.h" #include "xla/runtime/ffi/ffi_api.h" +#include "xla/runtime/ffi/ffi_c_api.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/convolution.h" +#include "xla/xla_data.pb.h" namespace xla { struct ExecutableRunOptions; diff --git a/xla/service/cpu/runtime/convolution_ffi.h b/xla/service/cpu/runtime/convolution_ffi.h index c517c7461e064..7ca9319269a54 100644 --- a/xla/service/cpu/runtime/convolution_ffi.h +++ b/xla/service/cpu/runtime/convolution_ffi.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/custom_call.cc b/xla/service/cpu/runtime/custom_call.cc index e5b3659b95684..6b45f3d1a3671 100644 --- a/xla/service/cpu/runtime/custom_call.cc +++ b/xla/service/cpu/runtime/custom_call.cc @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,11 +25,17 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/primitive_util.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" +#include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/hlo.pb.h" diff --git a/xla/service/cpu/runtime/custom_call.h b/xla/service/cpu/runtime/custom_call.h index ee8c0d8726bfb..c4992e60cb9a2 100644 --- a/xla/service/cpu/runtime/custom_call.h +++ b/xla/service/cpu/runtime/custom_call.h @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/fft_call.cc b/xla/service/cpu/runtime/fft_call.cc index 5a619f20c8fd4..c62b57422a154 100644 --- a/xla/service/cpu/runtime/fft_call.cc +++ b/xla/service/cpu/runtime/fft_call.cc @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,6 +27,8 @@ #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" @@ -35,6 +37,7 @@ #include "xla/service/cpu/runtime_fft.h" #include "xla/service/hlo.pb.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/runtime/fft_call.h b/xla/service/cpu/runtime/fft_call.h index 4865511d8dc26..7e728824fefd1 100644 --- a/xla/service/cpu/runtime/fft_call.h +++ b/xla/service/cpu/runtime/fft_call.h @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/retain.cc b/xla/service/cpu/runtime/retain.cc index 5e96ca5f4f460..431c0f75a8c8c 100644 --- a/xla/service/cpu/runtime/retain.cc +++ b/xla/service/cpu/runtime/retain.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/rng.cc b/xla/service/cpu/runtime/rng.cc index a1cc6c046d2ee..7f2edd42b56b2 100644 --- a/xla/service/cpu/runtime/rng.cc +++ b/xla/service/cpu/runtime/rng.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,7 +18,10 @@ #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "xla/executable_run_options.h" +#include "xla/runtime/memref_view.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/runtime/rng.h b/xla/service/cpu/runtime/rng.h index ff36b6fd1dd3b..dc724ec15eb8f 100644 --- a/xla/service/cpu/runtime/rng.h +++ b/xla/service/cpu/runtime/rng.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/rng_call.cc b/xla/service/cpu/runtime/rng_call.cc index 97d46c38fdc6a..6bcbe0fe0bf7e 100644 --- a/xla/service/cpu/runtime/rng_call.cc +++ b/xla/service/cpu/runtime/rng_call.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,10 +17,12 @@ #include #include +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/rng.h" namespace xla { diff --git a/xla/service/cpu/runtime/rng_call.h b/xla/service/cpu/runtime/rng_call.h index ba95a1bbd1ef0..f189b90084076 100644 --- a/xla/service/cpu/runtime/rng_call.h +++ b/xla/service/cpu/runtime/rng_call.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/rng_ffi.cc b/xla/service/cpu/runtime/rng_ffi.cc index e4d485a68eda3..8efd9aabfade0 100644 --- a/xla/service/cpu/runtime/rng_ffi.cc +++ b/xla/service/cpu/runtime/rng_ffi.cc @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,10 +14,14 @@ #include "xla/service/cpu/runtime/rng_ffi.h" +#include "absl/status/status.h" #include "xla/runtime/aot_ffi.h" #include "xla/runtime/aot_ffi_execution_context.h" #include "xla/runtime/ffi/ffi_api.h" +#include "xla/runtime/ffi/ffi_c_api.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/runtime/rng.h" +#include "xla/xla_data.pb.h" namespace xla { struct ExecutableRunOptions; diff --git a/xla/service/cpu/runtime/rng_ffi.h b/xla/service/cpu/runtime/rng_ffi.h index fd23e56ef99c7..4383f96ae4520 100644 --- a/xla/service/cpu/runtime/rng_ffi.h +++ b/xla/service/cpu/runtime/rng_ffi.h @@ -1,4 +1,4 @@ -// Copyright 2023 The TensorFlow Authors +// Copyright 2023 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime/xfeed.cc b/xla/service/cpu/runtime/xfeed.cc index b208ff705fe2a..38bb2eb34644e 100644 --- a/xla/service/cpu/runtime/xfeed.cc +++ b/xla/service/cpu/runtime/xfeed.cc @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,15 +26,20 @@ #include #include -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/executable_run_options.h" #include "xla/primitive_util.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" +#include "xla/runtime/memref_view.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/runtime/xfeed.h b/xla/service/cpu/runtime/xfeed.h index d4d5c9dca98ce..abdb7f117edc7 100644 --- a/xla/service/cpu/runtime/xfeed.h +++ b/xla/service/cpu/runtime/xfeed.h @@ -1,4 +1,4 @@ -// Copyright 2022 The TensorFlow Authors +// Copyright 2022 The OpenXLA Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d.cc b/xla/service/cpu/runtime_conv2d.cc index 9278f2af8f550..335da93b83cf5 100644 --- a/xla/service/cpu/runtime_conv2d.cc +++ b/xla/service/cpu/runtime_conv2d.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d.h b/xla/service/cpu/runtime_conv2d.h index 42726f2d21a21..affe727e159dc 100644 --- a/xla/service/cpu/runtime_conv2d.h +++ b/xla/service/cpu/runtime_conv2d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d_acl.cc b/xla/service/cpu/runtime_conv2d_acl.cc index 7cfcd83e5f250..6f3738300b7df 100644 --- a/xla/service/cpu/runtime_conv2d_acl.cc +++ b/xla/service/cpu/runtime_conv2d_acl.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d_acl.h b/xla/service/cpu/runtime_conv2d_acl.h index 756642e0da819..69a2429ff49f6 100644 --- a/xla/service/cpu/runtime_conv2d_acl.h +++ b/xla/service/cpu/runtime_conv2d_acl.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d_mkl.cc b/xla/service/cpu/runtime_conv2d_mkl.cc index 68de54ecbca0a..1b794305580c7 100644 --- a/xla/service/cpu/runtime_conv2d_mkl.cc +++ b/xla/service/cpu/runtime_conv2d_mkl.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv2d_mkl.h b/xla/service/cpu/runtime_conv2d_mkl.h index 6bb1312bc7c8a..f5300e8ddf46d 100644 --- a/xla/service/cpu/runtime_conv2d_mkl.h +++ b/xla/service/cpu/runtime_conv2d_mkl.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv3d.cc b/xla/service/cpu/runtime_conv3d.cc index a8e0d0a7a7200..95dc56ac4546b 100644 --- a/xla/service/cpu/runtime_conv3d.cc +++ b/xla/service/cpu/runtime_conv3d.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv3d.h b/xla/service/cpu/runtime_conv3d.h index 2797608543775..0ad8d19df0873 100644 --- a/xla/service/cpu/runtime_conv3d.h +++ b/xla/service/cpu/runtime_conv3d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_conv_impl.h b/xla/service/cpu/runtime_conv_impl.h index 613c14b1a5c6a..ad25c2de8bf01 100644 --- a/xla/service/cpu/runtime_conv_impl.h +++ b/xla/service/cpu/runtime_conv_impl.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_custom_call_status.cc b/xla/service/cpu/runtime_custom_call_status.cc index 42ad966245ea6..f88745bc911d2 100644 --- a/xla/service/cpu/runtime_custom_call_status.cc +++ b/xla/service/cpu/runtime_custom_call_status.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_custom_call_status.h b/xla/service/cpu/runtime_custom_call_status.h index 5300b4daecc37..e243b46e0e3bc 100644 --- a/xla/service/cpu/runtime_custom_call_status.h +++ b/xla/service/cpu/runtime_custom_call_status.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_fft.cc b/xla/service/cpu/runtime_fft.cc index 3d78d0faf45d0..d36ba626d4856 100644 --- a/xla/service/cpu/runtime_fft.cc +++ b/xla/service/cpu/runtime_fft.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_fft.h b/xla/service/cpu/runtime_fft.h index ad9b553d11aaa..2997d04ffb580 100644 --- a/xla/service/cpu/runtime_fft.h +++ b/xla/service/cpu/runtime_fft.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_fork_join.cc b/xla/service/cpu/runtime_fork_join.cc index 7e8ab842fe83f..f179f23c959dc 100644 --- a/xla/service/cpu/runtime_fork_join.cc +++ b/xla/service/cpu/runtime_fork_join.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -62,8 +62,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void** buffer_table, void* status, uint64_t* prof_counters, int32_t num_partitions, int64_t* partitions, int32_t num_partitioned_dims, void* function_ptr) { - VLOG(2) << "ParallelForkJoin ENTRY" - << " num_partitions: " << num_partitions + VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); @@ -97,8 +96,8 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( } // Call first compute function inline. - function(result_ptr, run_options_ptr, params, buffer_table, &statuses[0], - &partitions[0], prof_counters); + function(result_ptr, run_options_ptr, params, buffer_table, statuses.data(), + partitions, prof_counters); VLOG(3) << "ParallelForkJoin partition 0 done."; bc.Wait(); diff --git a/xla/service/cpu/runtime_fork_join.h b/xla/service/cpu/runtime_fork_join.h index 4ecf4f07617d8..11fc141e09f79 100644 --- a/xla/service/cpu/runtime_fork_join.h +++ b/xla/service/cpu/runtime_fork_join.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_fp16.cc b/xla/service/cpu/runtime_fp16.cc index f63b24f17d416..4c7acae967baf 100644 --- a/xla/service/cpu/runtime_fp16.cc +++ b/xla/service/cpu/runtime_fp16.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_fp16.h b/xla/service/cpu/runtime_fp16.h index 3f7af5197766a..c86d6dc37f0d5 100644 --- a/xla/service/cpu/runtime_fp16.h +++ b/xla/service/cpu/runtime_fp16.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_key_value_sort.cc b/xla/service/cpu/runtime_key_value_sort.cc index 148984ae9930a..3afd3936788bd 100644 --- a/xla/service/cpu/runtime_key_value_sort.cc +++ b/xla/service/cpu/runtime_key_value_sort.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_key_value_sort.h b/xla/service/cpu/runtime_key_value_sort.h index 2b45028877090..8df8503adf9e4 100644 --- a/xla/service/cpu/runtime_key_value_sort.h +++ b/xla/service/cpu/runtime_key_value_sort.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_lightweight_check.h b/xla/service/cpu/runtime_lightweight_check.h index a882bd56f2087..49fc9cb3457a5 100644 --- a/xla/service/cpu/runtime_lightweight_check.h +++ b/xla/service/cpu/runtime_lightweight_check.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul.h b/xla/service/cpu/runtime_matmul.h index ec6e9bf09682b..d28fa6147663d 100644 --- a/xla/service/cpu/runtime_matmul.h +++ b/xla/service/cpu/runtime_matmul.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_acl.cc b/xla/service/cpu/runtime_matmul_acl.cc index b8728b47a12b8..feeadff718110 100644 --- a/xla/service/cpu/runtime_matmul_acl.cc +++ b/xla/service/cpu/runtime_matmul_acl.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_acl.h b/xla/service/cpu/runtime_matmul_acl.h index 57d522c3d04f2..94f4f56d65f10 100644 --- a/xla/service/cpu/runtime_matmul_acl.h +++ b/xla/service/cpu/runtime_matmul_acl.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_c128.cc b/xla/service/cpu/runtime_matmul_c128.cc index c237763d1bc60..0890c1c659942 100644 --- a/xla/service/cpu/runtime_matmul_c128.cc +++ b/xla/service/cpu/runtime_matmul_c128.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_c64.cc b/xla/service/cpu/runtime_matmul_c64.cc index e526061f7fad6..0152cf74927f1 100644 --- a/xla/service/cpu/runtime_matmul_c64.cc +++ b/xla/service/cpu/runtime_matmul_c64.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_common.h b/xla/service/cpu/runtime_matmul_common.h index a08be9d36680e..899c204933239 100644 --- a/xla/service/cpu/runtime_matmul_common.h +++ b/xla/service/cpu/runtime_matmul_common.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_f16.cc b/xla/service/cpu/runtime_matmul_f16.cc index 72d80d39f0cf4..fae796201a518 100644 --- a/xla/service/cpu/runtime_matmul_f16.cc +++ b/xla/service/cpu/runtime_matmul_f16.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_f32.cc b/xla/service/cpu/runtime_matmul_f32.cc index 7e40231590f40..e49e53c0de3e3 100644 --- a/xla/service/cpu/runtime_matmul_f32.cc +++ b/xla/service/cpu/runtime_matmul_f32.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_f64.cc b/xla/service/cpu/runtime_matmul_f64.cc index d75c400e4a5e7..318ef91babe69 100644 --- a/xla/service/cpu/runtime_matmul_f64.cc +++ b/xla/service/cpu/runtime_matmul_f64.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_matmul_s32.cc b/xla/service/cpu/runtime_matmul_s32.cc index 69c8634426d13..6bc5a9bb2abd7 100644 --- a/xla/service/cpu/runtime_matmul_s32.cc +++ b/xla/service/cpu/runtime_matmul_s32.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_pow.cc b/xla/service/cpu/runtime_pow.cc index d391a1409e83d..afd3c96e39b4c 100644 --- a/xla/service/cpu/runtime_pow.cc +++ b/xla/service/cpu/runtime_pow.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_pow.h b/xla/service/cpu/runtime_pow.h index 0dee0c607900e..ae3196cba89cc 100644 --- a/xla/service/cpu/runtime_pow.h +++ b/xla/service/cpu/runtime_pow.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_conv2d.cc b/xla/service/cpu/runtime_single_threaded_conv2d.cc index 32b0cb4a46824..c98ed373b6368 100644 --- a/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_conv2d.h b/xla/service/cpu/runtime_single_threaded_conv2d.h index 76c1bbb221e88..89107792b2432 100644 --- a/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_conv3d.cc b/xla/service/cpu/runtime_single_threaded_conv3d.cc index 154d4369e3e3c..64da24fca18d0 100644 --- a/xla/service/cpu/runtime_single_threaded_conv3d.cc +++ b/xla/service/cpu/runtime_single_threaded_conv3d.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_conv3d.h b/xla/service/cpu/runtime_single_threaded_conv3d.h index b9f3e3558a8f6..1d008f4419cb2 100644 --- a/xla/service/cpu/runtime_single_threaded_conv3d.h +++ b/xla/service/cpu/runtime_single_threaded_conv3d.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_fft.cc b/xla/service/cpu/runtime_single_threaded_fft.cc index d573321d73d89..a11e52da4e8ce 100644 --- a/xla/service/cpu/runtime_single_threaded_fft.cc +++ b/xla/service/cpu/runtime_single_threaded_fft.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_fft.h b/xla/service/cpu/runtime_single_threaded_fft.h index 7019df54a779a..23a84bad931ff 100644 --- a/xla/service/cpu/runtime_single_threaded_fft.h +++ b/xla/service/cpu/runtime_single_threaded_fft.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul.h b/xla/service/cpu/runtime_single_threaded_matmul.h index da336bc0ec6a7..407a7d29ca296 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/xla/service/cpu/runtime_single_threaded_matmul.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_c128.cc b/xla/service/cpu/runtime_single_threaded_matmul_c128.cc index 921b5b4317496..a6897e5494be1 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_c128.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_c128.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_c64.cc b/xla/service/cpu/runtime_single_threaded_matmul_c64.cc index 64341138b638a..64963d8119e91 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_c64.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_c64.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_common.h b/xla/service/cpu/runtime_single_threaded_matmul_common.h index f6e3cd5c34209..461fd95266d13 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_common.h +++ b/xla/service/cpu/runtime_single_threaded_matmul_common.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_f16.cc b/xla/service/cpu/runtime_single_threaded_matmul_f16.cc index f9f44ff12899e..785d90c705753 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_f16.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_f16.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_f32.cc b/xla/service/cpu/runtime_single_threaded_matmul_f32.cc index 85339d895af4f..23b552e470c69 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_f32.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_f32.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_f64.cc b/xla/service/cpu/runtime_single_threaded_matmul_f64.cc index 989fc520de41d..6d6c42726e888 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_f64.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_f64.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_single_threaded_matmul_s32.cc b/xla/service/cpu/runtime_single_threaded_matmul_s32.cc index 5f14070da155a..7602f188a2102 100644 --- a/xla/service/cpu/runtime_single_threaded_matmul_s32.cc +++ b/xla/service/cpu/runtime_single_threaded_matmul_s32.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_topk.cc b/xla/service/cpu/runtime_topk.cc index d59ccea13df3b..b239d056d94cc 100644 --- a/xla/service/cpu/runtime_topk.cc +++ b/xla/service/cpu/runtime_topk.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/runtime_topk.h b/xla/service/cpu/runtime_topk.h index e33766c7a83cf..13e922d16401c 100644 --- a/xla/service/cpu/runtime_topk.h +++ b/xla/service/cpu/runtime_topk.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/sample_harness.cc b/xla/service/cpu/sample_harness.cc index f09119ded2069..1f21f22d2846a 100644 --- a/xla/service/cpu/sample_harness.cc +++ b/xla/service/cpu/sample_harness.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. xla::ExecutionProfile profile; - xla::StatusOr result = client->ExecuteAndTransfer( + absl::StatusOr result = client->ExecuteAndTransfer( computation, /*arguments=*/{param0_data.get(), param1_data.get()}, /*execution_options=*/nullptr, diff --git a/xla/service/cpu/shape_partition.cc b/xla/service/cpu/shape_partition.cc index f30782f9a4a12..197c8a013c7ed 100644 --- a/xla/service/cpu/shape_partition.cc +++ b/xla/service/cpu/shape_partition.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/shape_partition.h b/xla/service/cpu/shape_partition.h index ac6667ae53264..e3ec5947e8bae 100644 --- a/xla/service/cpu/shape_partition.h +++ b/xla/service/cpu/shape_partition.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/shape_partition_test.cc b/xla/service/cpu/shape_partition_test.cc index efb7d4ccb95e7..4b270f7754f03 100644 --- a/xla/service/cpu/shape_partition_test.cc +++ b/xla/service/cpu/shape_partition_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/simple_orc_jit.cc b/xla/service/cpu/simple_orc_jit.cc index 8895b4f6451d5..04f522c770cc8 100644 --- a/xla/service/cpu/simple_orc_jit.cc +++ b/xla/service/cpu/simple_orc_jit.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,9 +22,13 @@ limitations under the License. #include #include #include +#include #include // NOLINT #include +#include +#include "absl/functional/any_invocable.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" @@ -35,6 +39,7 @@ limitations under the License. #include "llvm/Support/Alignment.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Memory.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Process.h" #include "llvm/TargetParser/Host.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project @@ -64,7 +69,9 @@ limitations under the License. #include "tsl/platform/logging.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_layer_norm.h" #include "xla/service/cpu/onednn_matmul.h" +#include "xla/service/cpu/onednn_softmax.h" #endif // Provided by compiler-rt and MLIR. @@ -75,10 +82,9 @@ extern "C" uint16_t __truncdfbf2(double); namespace xla { namespace cpu { -namespace { -llvm::SmallVector DetectMachineAttributes() { - llvm::SmallVector result; +std::vector DetectMachineAttributes() { + std::vector result; llvm::StringMap host_features; if (llvm::sys::getHostCPUFeatures(host_features)) { for (auto& feature : host_features) { @@ -89,6 +95,8 @@ llvm::SmallVector DetectMachineAttributes() { return result; } +namespace { + class DefaultMemoryMapper final : public llvm::SectionMemoryManager::MemoryMapper { public: @@ -289,6 +297,8 @@ bool ContiguousSectionMemoryManager::finalizeMemory(std::string* err_msg) { SimpleOrcJIT::InferTargetMachineForJIT( const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level) { + std::vector attrs = DetectMachineAttributes(); + llvm::SmallVector llvm_attrs(attrs.begin(), attrs.end()); std::unique_ptr target_machine( llvm::EngineBuilder() .setTargetOptions(target_options) @@ -296,7 +306,7 @@ SimpleOrcJIT::InferTargetMachineForJIT( .selectTarget( /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", /*MCPU=*/llvm::sys::getHostCPUName(), - /*MAttrs=*/DetectMachineAttributes())); + /*MAttrs=*/llvm_attrs)); CHECK(target_machine != nullptr); return target_machine; } @@ -309,7 +319,7 @@ SimpleOrcJIT::SimpleOrcJIT( bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook) + absl::AnyInvocable post_codegen_hook) : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), target_triple_(target_machine_->getTargetTriple()), data_layout_(target_machine_->createDataLayout()), @@ -386,7 +396,8 @@ llvm::Expected> SimpleOrcJIT::Create( bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook) { + absl::AnyInvocable + post_codegen_hook) { auto SSP = std::make_shared(); auto target_process_control = llvm::orc::SelfExecutorProcessControl::Create(std::move(SSP)); @@ -441,6 +452,11 @@ void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) { gdb_jit_event_listener_->notifyFreeingObject(key); } +llvm::Error SimpleOrcJIT::AddObjFile( + std::unique_ptr obj_file) { + return object_layer_.add(*main_jit_dylib_, std::move(obj_file)); +} + llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) { return compile_layer_.add(*main_jit_dylib_, std::move(module)); } @@ -485,6 +501,8 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute); REGISTER_CPU_RUNTIME_SYMBOL(AllToAll); + REGISTER_CPU_RUNTIME_SYMBOL(AllGather); + REGISTER_CPU_RUNTIME_SYMBOL(ReduceScatter); REGISTER_CPU_RUNTIME_SYMBOL(PartitionId); REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); REGISTER_CPU_RUNTIME_SYMBOL(MKLConv2DF32); @@ -525,6 +543,9 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul); + REGISTER_CPU_RUNTIME_SYMBOL(OneDnnSoftmax); + REGISTER_CPU_RUNTIME_SYMBOL(OneDnnLayerNorm); + REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMulReorder); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee), diff --git a/xla/service/cpu/simple_orc_jit.h b/xla/service/cpu/simple_orc_jit.h index a6fa7f896a28a..8190a910b0a68 100644 --- a/xla/service/cpu/simple_orc_jit.h +++ b/xla/service/cpu/simple_orc_jit.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/functional/any_invocable.h" #include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" @@ -27,6 +28,7 @@ limitations under the License. #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" #include "llvm/IR/Module.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Triple.h" #include "xla/service/cpu/compiler_functor.h" @@ -62,7 +64,8 @@ class SimpleOrcJIT : public llvm::JITEventListener { llvm::FastMathFlags fast_math_flags, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook); + absl::AnyInvocable + post_codegen_hook); static llvm::Expected> Create( const llvm::TargetOptions& target_options, @@ -71,7 +74,8 @@ class SimpleOrcJIT : public llvm::JITEventListener { llvm::FastMathFlags fast_math_flags, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook); + absl::AnyInvocable + post_codegen_hook); ~SimpleOrcJIT() override; @@ -79,6 +83,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { const llvm::Triple& target_triple() const { return target_triple_; } + llvm::Error AddObjFile(std::unique_ptr obj_file); llvm::Error AddModule(llvm::orc::ThreadSafeModule module); // Discards objects we no longer need once we are done compiling. @@ -132,6 +137,8 @@ class SimpleOrcJIT : public llvm::JITEventListener { llvm::JITEventListener* perf_jit_event_listener_; }; +std::vector DetectMachineAttributes(); + } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/target_machine_features.cc b/xla/service/cpu/target_machine_features.cc index ee45ddb7827a2..6072356be0897 100644 --- a/xla/service/cpu/target_machine_features.cc +++ b/xla/service/cpu/target_machine_features.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/target_machine_features.h b/xla/service/cpu/target_machine_features.h index 65b7621a80ded..d1dcddf732aaf 100644 --- a/xla/service/cpu/target_machine_features.h +++ b/xla/service/cpu/target_machine_features.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/target_machine_features_fake.h b/xla/service/cpu/target_machine_features_fake.h index d8503da832037..2823770177f00 100644 --- a/xla/service/cpu/target_machine_features_fake.h +++ b/xla/service/cpu/target_machine_features_fake.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/test_target_triple_helper.h b/xla/service/cpu/test_target_triple_helper.h index 2e45e0832d04c..0b057bf400179 100644 --- a/xla/service/cpu/test_target_triple_helper.h +++ b/xla/service/cpu/test_target_triple_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index d74915c467dce..8cf504bd4c957 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -1,9 +1,9 @@ # Description: # Tests for LLVM-based CPU backend for XLA. +load("@tsl//tsl:tsl.bzl", "tsl_copts") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("//xla:xla.bzl", "xla_cc_test") -load("@tsl//tsl:tsl.bzl", "tsl_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -38,16 +38,41 @@ cc_library( ], ) +xla_cc_test( + name = "cpu_aot_export_test", + srcs = ["cpu_aot_export_test.cc"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", + "//xla/service:cpu_plugin", + "//xla/service:executable", + "//xla/service:platform_util", + "//xla/service/cpu:cpu_compiler", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep + "@llvm-project//llvm:X86CodeGen", # fixdeps: keep + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "cpu_dyn_shape_test", srcs = ["cpu_dyn_shape_test.cc"], deps = [ + ":cpu_codegen_test", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/service/cpu/tests:cpu_codegen_test", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:test_main", ], ) @@ -56,16 +81,17 @@ xla_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], deps = [ + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:cpu_plugin", "//xla/service/cpu:cpu_instruction_fusion", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", - "@com_google_absl//absl/memory", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -76,7 +102,9 @@ xla_cc_test( srcs = ["cpu_bytesizeof_test.cc"], deps = [ "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/service/llvm_ir:llvm_util", + "@llvm-project//llvm:ir_headers", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -86,11 +114,12 @@ xla_cc_test( name = "cpu_external_constants_test", srcs = ["cpu_external_constants_test.cc"], deps = [ + ":cpu_codegen_test", "//xla:array2d", + "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service/cpu/tests:cpu_codegen_test", - "//xla/tests:filecheck", "@tsl//tsl/platform:test", ], ) @@ -99,18 +128,23 @@ xla_cc_test( name = "cpu_noalias_test", srcs = ["cpu_noalias_test.cc"], deps = [ + ":cpu_codegen_test", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", + "//xla:status", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service/cpu/tests:cpu_codegen_test", + "//xla/service:hlo_ordering", + "//xla/service:logical_buffer", "//xla/service/llvm_ir:alias_analysis", + "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/tests:filecheck", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -121,13 +155,17 @@ xla_cc_test( srcs = ["cpu_intrinsic_test.cc"], deps = [ ":cpu_codegen_test", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", + "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -139,13 +177,14 @@ xla_cc_test( copts = tsl_copts(), tags = ["no_mac_arm64"], deps = [ + ":cpu_codegen_test", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/service/cpu/tests:cpu_codegen_test", "//xla/tests:test_utils", - "@com_google_absl//absl/strings", - "@tsl//tsl/platform:logging", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -159,10 +198,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -173,22 +212,7 @@ xla_cc_test( srcs = ["tree_reduction_rewriter_test.cc"], deps = [ ":cpu_codegen_test", - "//xla:statusor", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service/cpu:cpu_compiler", - "//xla/tests:codegen_test_base", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "//xla/tests:test_utils", - "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -198,9 +222,10 @@ xla_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], deps = [ + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:statusor", "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", @@ -212,7 +237,6 @@ xla_cc_test( "//xla/tests:client_library_test_base", "//xla/tests:literal_test_util", "@tsl//tsl/platform:env", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -222,13 +246,12 @@ xla_cc_test( name = "cpu_literal_caching_test", srcs = ["cpu_literal_caching_test.cc"], deps = [ + ":cpu_codegen_test", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/service/cpu/tests:cpu_codegen_test", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -237,12 +260,11 @@ xla_cc_test( name = "cpu_outfeed_test", srcs = ["cpu_outfeed_test.cc"], deps = [ - "//xla/hlo/ir:hlo", + ":cpu_codegen_test", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/service/cpu/tests:cpu_codegen_test", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -251,12 +273,11 @@ xla_cc_test( name = "cpu_key_value_sort_test", srcs = ["cpu_key_value_sort_test.cc"], deps = [ - "//xla/hlo/ir:hlo", + ":cpu_codegen_test", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/service/cpu/tests:cpu_codegen_test", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -266,12 +287,12 @@ xla_cc_test( srcs = ["cpu_spmd_compile_test.cc"], deps = [ ":cpu_codegen_test", - "//xla/hlo/utils:hlo_query", + "//xla:debug_options_flags", + "//xla/service:executable", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -283,13 +304,17 @@ xla_cc_test( srcs = ["cpu_topk_test.cc"], deps = [ ":cpu_codegen_test", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/client/lib:sorting", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -299,13 +324,18 @@ xla_cc_test( srcs = ["cpu_vectorization_test.cc"], deps = [ ":cpu_codegen_test", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -316,14 +346,12 @@ xla_cc_test( srcs = ["cpu_while_test.cc"], deps = [ ":cpu_codegen_test", - "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", - "@com_google_absl//absl/strings", + "//xla/tests:literal_test_util", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep - "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/cpu/tests/cpu_aot_export_test.cc b/xla/service/cpu/tests/cpu_aot_export_test.cc new file mode 100644 index 0000000000000..39528634f7cd8 --- /dev/null +++ b/xla/service/cpu/tests/cpu_aot_export_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/service/compiler.h" +#include "xla/service/executable.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace cpu { + +using CpuAotCompilationTest = HloTestBase; + +TEST_F(CpuAotCompilationTest, ExportAndLoadExecutable) { + const absl::string_view hlo_string = R"( + HloModule Test + + ENTRY main { + a = f32[2, 2]{1,0} parameter(0) + ROOT b = f32[2, 2]{1,0} add(a, a) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto compiler = backend().compiler(); + auto name = absl::AsciiStrToUpper( + PlatformUtil::CanonicalPlatformName("host").value()); + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName(name)); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + // JIT compile executable + auto module_group = std::make_unique(std::move(module)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> executables, + compiler->Compile(std::move(module_group), {{stream_exec}}, nullptr)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr exported_aot_result, + compiler->Export(executables[0].get())); + + // Serialize-deserialize AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + exported_aot_result->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr loaded_aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + // Load Executable from AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + loaded_aot_result->LoadExecutable(compiler, stream_exec)); +} + +} // namespace cpu +} // namespace xla diff --git a/xla/service/cpu/tests/cpu_bytesizeof_test.cc b/xla/service/cpu/tests/cpu_bytesizeof_test.cc index e282b150cd357..64b80d81b63cb 100644 --- a/xla/service/cpu/tests/cpu_bytesizeof_test.cc +++ b/xla/service/cpu/tests/cpu_bytesizeof_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/IR/DataLayout.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" class CpuByteSizeOfTest : public ::testing::Test {}; diff --git a/xla/service/cpu/tests/cpu_codegen_test.h b/xla/service/cpu/tests/cpu_codegen_test.h index 9cc6d3d149cb4..786d0bf91c32a 100644 --- a/xla/service/cpu/tests/cpu_codegen_test.h +++ b/xla/service/cpu/tests/cpu_codegen_test.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/tests/cpu_dyn_shape_test.cc b/xla/service/cpu/tests/cpu_dyn_shape_test.cc index c0b2dfbc36e81..ab279dc384269 100644 --- a/xla/service/cpu/tests/cpu_dyn_shape_test.cc +++ b/xla/service/cpu/tests/cpu_dyn_shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,16 @@ limitations under the License. #include +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index d088a530b325d..07e91b247aae4 100644 --- a/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,12 +18,15 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/shape_util.h" #include "xla/tests/test_utils.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/tests/cpu_external_constants_test.cc b/xla/service/cpu/tests/cpu_external_constants_test.cc index 2a885c7095403..20bcc3f973c2c 100644 --- a/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,9 +20,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/shape_util.h" -#include "xla/tests/filecheck.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/tests/cpu_fusion_test.cc b/xla/service/cpu/tests/cpu_fusion_test.cc index fea0191105f37..fcea12916cbd8 100644 --- a/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/xla/service/cpu/tests/cpu_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,18 @@ limitations under the License. #include #include +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/cpu/cpu_instruction_fusion.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" diff --git a/xla/service/cpu/tests/cpu_infeed_test.cc b/xla/service/cpu/tests/cpu_infeed_test.cc index a939288c6ba24..329faac5d2dd3 100644 --- a/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/xla/service/cpu/tests/cpu_infeed_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/error_spec.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#ifndef _WIN32 #include +#endif #include @@ -24,7 +31,6 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/xla/service/cpu/tests/cpu_intrinsic_test.cc b/xla/service/cpu/tests/cpu_intrinsic_test.cc index 7f22dc09874f6..0e9d32beb5ae1 100644 --- a/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,20 @@ limitations under the License. #include #include +#include #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm-c/Target.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { @@ -135,15 +143,15 @@ IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { IntrinsicTestSpec{ HloOpcode::kTanh, kTriple_x86_64, "", - R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, IntrinsicTestSpec{ HloOpcode::kTanh, kTriple_x86_64, "+avx", - R"(CHECK: fcmp fast uge <8 x float> %wide.load, )"}, + R"(CHECK: fcmp fast uge <8 x float> %wide.load, )"}, IntrinsicTestSpec{ HloOpcode::kTanh, kTriple_android_arm, "", - R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, IntrinsicTestSpec{ HloOpcode::kLog, kTriple_x86_64, "", diff --git a/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 5094265aea76e..2a8d9d8a24442 100644 --- a/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,11 @@ limitations under the License. #include +#include #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "tsl/platform/statusor.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/cpu_literal_caching_test.cc b/xla/service/cpu/tests/cpu_literal_caching_test.cc index c010cc7cf79ab..23f3fd39c9448 100644 --- a/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" -#include "xla/service/hlo_parser.h" +#include "tsl/platform/statusor.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/cpu_noalias_test.cc b/xla/service/cpu/tests/cpu_noalias_test.cc index b46669888d896..ca64a1263c3e5 100644 --- a/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/xla/service/cpu/tests/cpu_noalias_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,18 +16,32 @@ limitations under the License. #include #include +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/service/hlo_ordering.h" #include "xla/service/llvm_ir/alias_analysis.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/logical_buffer.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/tests/filecheck.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/tests/cpu_outfeed_test.cc b/xla/service/cpu/tests/cpu_outfeed_test.cc index c3c3c2c01458a..4dd41d5d8c9dc 100644 --- a/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,11 @@ limitations under the License. #include +#include #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "tsl/platform/statusor.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/cpu_profiling_test.cc b/xla/service/cpu/tests/cpu_profiling_test.cc index 92ec226e8ff06..80acbd762715c 100644 --- a/xla/service/cpu/tests/cpu_profiling_test.cc +++ b/xla/service/cpu/tests/cpu_profiling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" +#include +#include "absl/strings/string_view.h" #include "llvm-c/Target.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/cpu/cpu_compiler.h" diff --git a/xla/service/cpu/tests/cpu_spmd_compile_test.cc b/xla/service/cpu/tests/cpu_spmd_compile_test.cc index ae0a56d70c7d4..6dc8cb9f7bb08 100644 --- a/xla/service/cpu/tests/cpu_spmd_compile_test.cc +++ b/xla/service/cpu/tests/cpu_spmd_compile_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ limitations under the License. #include #include -#include "xla/hlo/utils/hlo_query.h" +#include "absl/status/statusor.h" +#include "xla/debug_options_flags.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" @@ -50,7 +50,7 @@ ENTRY entry { auto hlo_module = ParseAndReturnVerifiedModule(hlo_string, config).value(); // Verify that compilation succeeded. - StatusOr> executable = + absl::StatusOr> executable = CompileToExecutable(std::move(hlo_module)); TF_EXPECT_OK(executable.status()); } diff --git a/xla/service/cpu/tests/cpu_topk_test.cc b/xla/service/cpu/tests/cpu_topk_test.cc index 2b44e5941a89b..618fd0f02a904 100644 --- a/xla/service/cpu/tests/cpu_topk_test.cc +++ b/xla/service/cpu/tests/cpu_topk_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,19 @@ limitations under the License. #include +#include #include "xla/client/lib/sorting.h" #include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/cpu_vectorization_test.cc b/xla/service/cpu/tests/cpu_vectorization_test.cc index 0df598d885b78..ec29e43b3aff9 100644 --- a/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,20 @@ limitations under the License. #include #include +#include +#include "absl/algorithm/container.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "llvm-c/Target.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/tests/cpu_while_test.cc b/xla/service/cpu/tests/cpu_while_test.cc index 3504b1c080579..934aed068cb18 100644 --- a/xla/service/cpu/tests/cpu_while_test.cc +++ b/xla/service/cpu/tests/cpu_while_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,10 @@ limitations under the License. #include #include -#include "xla/service/cpu/cpu_compiler.h" +#include #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/tests/literal_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/tests/tree_reduction_rewriter_test.cc b/xla/service/cpu/tests/tree_reduction_rewriter_test.cc index 3b182c91d307a..e46b4ac15ed42 100644 --- a/xla/service/cpu/tests/tree_reduction_rewriter_test.cc +++ b/xla/service/cpu/tests/tree_reduction_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,7 @@ limitations under the License. #include -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" -#include "xla/statusor.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/llvm_irgen_test_base.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/tiled_dot_emitter.cc b/xla/service/cpu/tiled_dot_emitter.cc index b3cb7228810d0..e06fb90f9b764 100644 --- a/xla/service/cpu/tiled_dot_emitter.cc +++ b/xla/service/cpu/tiled_dot_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/tiled_dot_emitter.h b/xla/service/cpu/tiled_dot_emitter.h index e5060453a5155..2d32854f65089 100644 --- a/xla/service/cpu/tiled_dot_emitter.h +++ b/xla/service/cpu/tiled_dot_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/vector_support_library.cc b/xla/service/cpu/vector_support_library.cc index 6cd7a076707a8..ef4967ac6debe 100644 --- a/xla/service/cpu/vector_support_library.cc +++ b/xla/service/cpu/vector_support_library.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/vector_support_library.h b/xla/service/cpu/vector_support_library.h index 7250cf2d8bad4..958ad312cb0fb 100644 --- a/xla/service/cpu/vector_support_library.h +++ b/xla/service/cpu/vector_support_library.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc index 78ab194036440..e522c68d22e04 100644 --- a/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc +++ b/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,8 @@ namespace xla { namespace { class CodegenReduceOnArchWithNoVectorRegisters : public HloTestBase {}; -StatusOr GetTargetVectorRegisterByteSize(std::string triple) { +absl::StatusOr GetTargetVectorRegisterByteSize( + std::string triple) { // Unfortunately we need a lot of boilerplate to get to an // llvm::TargetMachine. @@ -34,7 +35,7 @@ StatusOr GetTargetVectorRegisterByteSize(std::string triple) { const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple, error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", error); + return Internal("TargetRegistry::lookupTarget failed: %s", error); } llvm::LLVMContext context; diff --git a/xla/service/cpu/windows_compatibility.cc b/xla/service/cpu/windows_compatibility.cc index 3805bce16d422..f6da04d750f7e 100644 --- a/xla/service/cpu/windows_compatibility.cc +++ b/xla/service/cpu/windows_compatibility.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/windows_compatibility.h b/xla/service/cpu/windows_compatibility.h index ce213977eb2e3..4e10087e5fa74 100644 --- a/xla/service/cpu/windows_compatibility.h +++ b/xla/service/cpu/windows_compatibility.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/xfeed_manager.cc b/xla/service/cpu/xfeed_manager.cc index c5f3064f082c1..34e56780c6c84 100644 --- a/xla/service/cpu/xfeed_manager.cc +++ b/xla/service/cpu/xfeed_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { } void XfeedQueueManager::ReleaseCurrentBuffer(int32_t length, void* data, - StatusOr shape) { + absl::StatusOr shape) { VLOG(3) << "Releasing buffer with shape: " << (shape.ok() ? ShapeUtil::HumanString(shape.value()) : ""); @@ -81,7 +81,7 @@ void XfeedQueueManager::ReleaseCurrentBuffer(int32_t length, void* data, } int64_t GetByteSizeRequirement(const Shape& shape, int64_t pointer_size) { - if (shape.is_static() || shape.IsTuple()) { + if (shape.IsTuple() || shape.is_static()) { return ShapeUtil::ByteSizeOf(shape, pointer_size); } int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); diff --git a/xla/service/cpu/xfeed_manager.h b/xla/service/cpu/xfeed_manager.h index 0d5d5ee34707c..e29d3b7d945eb 100644 --- a/xla/service/cpu/xfeed_manager.h +++ b/xla/service/cpu/xfeed_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ class XfeedBuffer { // The 'shape' parameter reflects what shape the embedded program was // expecting / producing with respect to this XfeedBuffer. E.g. this will // contain information about the layout of an outfed buffer. - virtual void Done(StatusOr shape) = 0; + virtual void Done(absl::StatusOr shape) = 0; }; // Reusable component for managing the infeed and outfeed queue state. @@ -82,7 +82,8 @@ class XfeedQueueManager { // error status. In the case of outfeed, this indicates the layout of the // shape that has been outfed. In the case of infeed, this can be used for // sanity checking purposes. - void ReleaseCurrentBuffer(int32_t length, void* data, StatusOr shape); + void ReleaseCurrentBuffer(int32_t length, void* data, + absl::StatusOr shape); private: const std::string queue_name_; diff --git a/xla/service/cpu/xfeed_manager_test.cc b/xla/service/cpu/xfeed_manager_test.cc index 046ed55576638..5c6be64e6a7dd 100644 --- a/xla/service/cpu/xfeed_manager_test.cc +++ b/xla/service/cpu/xfeed_manager_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class TestInfeedBuffer : public cpu::runtime::XfeedBuffer { int32_t length() override { return length_; } void* data() override { return nullptr; } - void Done(StatusOr shape) override { + void Done(absl::StatusOr shape) override { CHECK(!done_called_); done_called_ = true; TF_ASSERT_OK(shape.status()); diff --git a/xla/service/cpu/xla_framework.h b/xla/service/cpu/xla_framework.h index bd7d572fc3c6e..e3a0ff07facb9 100644 --- a/xla/service/cpu/xla_framework.h +++ b/xla/service/cpu/xla_framework.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu/xla_framework.proto b/xla/service/cpu/xla_framework.proto index 1530007487065..ce4b3874de871 100644 --- a/xla/service/cpu/xla_framework.proto +++ b/xla/service/cpu/xla_framework.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,4 +22,4 @@ message XlaFrameworkMappingProto { repeated int64 flattened_outputs = 2 [packed = true]; optional int64 result = 3 [default = -1]; optional bool output_is_tuple = 4; -} \ No newline at end of file +} diff --git a/xla/service/cpu_gpu_shape_verifier.cc b/xla/service/cpu_gpu_shape_verifier.cc index 1e39bd945421c..d67269aa51bf2 100644 --- a/xla/service/cpu_gpu_shape_verifier.cc +++ b/xla/service/cpu_gpu_shape_verifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,12 +32,15 @@ Status VerifyS4U4Usage(HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kBitcast: case HloOpcode::kConstant: + case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kParameter: + case HloOpcode::kSlice: case HloOpcode::kTuple: + case HloOpcode::kWhile: break; default: TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( diff --git a/xla/service/cpu_gpu_shape_verifier.h b/xla/service/cpu_gpu_shape_verifier.h index ef43fcc59670e..78fd43649b37d 100644 --- a/xla/service/cpu_gpu_shape_verifier.h +++ b/xla/service/cpu_gpu_shape_verifier.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/cpu_gpu_shape_verifier_test.cc b/xla/service/cpu_gpu_shape_verifier_test.cc index 12fb0fc2ca075..18e852b4d594d 100644 --- a/xla/service/cpu_gpu_shape_verifier_test.cc +++ b/xla/service/cpu_gpu_shape_verifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_sharding_helper.cc b/xla/service/custom_call_sharding_helper.cc index 9ed93009a0593..8bc09a4ca79ae 100644 --- a/xla/service/custom_call_sharding_helper.cc +++ b/xla/service/custom_call_sharding_helper.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,12 @@ bool CustomCallShardingHelper::IsCustomCallShardable( return false; } -xla::Status CustomCallPartitioner::Partition( +bool CustomCallShardingHelper::CanPropagateShardingToOperands( + const HloInstruction* instruction) const { + return true; +} + +absl::Status CustomCallPartitioner::Partition( spmd::SpmdPartitioningVisitor* partitioner, HloInstruction* hlo) const { return xla::Unimplemented("Implement sharding for %s", hlo->ToString()); } diff --git a/xla/service/custom_call_sharding_helper.h b/xla/service/custom_call_sharding_helper.h index e3432342b8ae7..f4287e39ba685 100644 --- a/xla/service/custom_call_sharding_helper.h +++ b/xla/service/custom_call_sharding_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,6 +47,10 @@ class CustomCallShardingHelper { HloInstruction* instruction) const { return {}; } + // Returns if the given custom-call instruction can propagate sharding to its + // operands. + virtual bool CanPropagateShardingToOperands( + const HloInstruction* instruction) const; virtual ~CustomCallShardingHelper() = default; }; @@ -58,8 +62,8 @@ class SpmdPartitioningVisitor; // policies. class CustomCallPartitioner : public CustomCallShardingHelper { public: - virtual xla::Status Partition(spmd::SpmdPartitioningVisitor* partitioner, - HloInstruction* hlo) const; + virtual absl::Status Partition(spmd::SpmdPartitioningVisitor* partitioner, + HloInstruction* hlo) const; // Returns if the given side-effecting custom-call is allowed to have // replicated sharding. diff --git a/xla/service/custom_call_status.cc b/xla/service/custom_call_status.cc index 2e07ef3c51fcd..44b89148d2c25 100644 --- a/xla/service/custom_call_status.cc +++ b/xla/service/custom_call_status.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_status.h b/xla/service/custom_call_status.h index db6da7e36c338..68287b22776de 100644 --- a/xla/service/custom_call_status.h +++ b/xla/service/custom_call_status.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_status_internal.h b/xla/service/custom_call_status_internal.h index 4758a19986653..e2ccc5118719e 100644 --- a/xla/service/custom_call_status_internal.h +++ b/xla/service/custom_call_status_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_status_test.cc b/xla/service/custom_call_status_test.cc index cb253d89dc5e1..11900d7485c89 100644 --- a/xla/service/custom_call_status_test.cc +++ b/xla/service/custom_call_status_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_status_test_c_caller.c b/xla/service/custom_call_status_test_c_caller.c index bb53b6aef8830..3739b21a26912 100644 --- a/xla/service/custom_call_status_test_c_caller.c +++ b/xla/service/custom_call_status_test_c_caller.c @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_status_test_c_caller.h b/xla/service/custom_call_status_test_c_caller.h index cb3b8e9693c6c..380f91f59a1a5 100644 --- a/xla/service/custom_call_status_test_c_caller.h +++ b/xla/service/custom_call_status_test_c_caller.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/custom_call_target_registry.cc b/xla/service/custom_call_target_registry.cc index de1fa77851431..d19db6ab0da60 100644 --- a/xla/service/custom_call_target_registry.cc +++ b/xla/service/custom_call_target_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,13 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" +#include +#include +#include // NOLINT +#include +#include +#include + namespace xla { CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { @@ -26,7 +33,17 @@ void CustomCallTargetRegistry::Register(const std::string& symbol, void* address, const std::string& platform) { std::lock_guard lock(mu_); - registered_symbols_[std::make_pair(symbol, platform)] = address; + const auto [it, inserted] = + registered_symbols_.insert({{symbol, platform}, address}); + if (!inserted && it->second != address) { + std::cerr << "Duplicate custom call registration detected for symbol \"" + << symbol << "\" with different addresses " << address + << "(current) and " << it->second << " (previous) on platform " + << platform + << "Rejecting the registration to avoid confusion about which " + "symbol would actually get used at runtime.\n"; + std::exit(1); + } } void* CustomCallTargetRegistry::Lookup(const std::string& symbol, @@ -36,4 +53,18 @@ void* CustomCallTargetRegistry::Lookup(const std::string& symbol, return it == registered_symbols_.end() ? nullptr : it->second; } +std::unordered_map +CustomCallTargetRegistry::registered_symbols( + const std::string& platform) const { + std::unordered_map calls; + std::lock_guard lock(mu_); + for (const auto& [metadata, address] : registered_symbols_) { + if (metadata.second == platform) { + calls[metadata.first] = address; + } + } + + return calls; +} + } // namespace xla diff --git a/xla/service/custom_call_target_registry.h b/xla/service/custom_call_target_registry.h index 51ca96033d163..e78d188426ee3 100644 --- a/xla/service/custom_call_target_registry.h +++ b/xla/service/custom_call_target_registry.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,14 @@ namespace xla { // The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline // linker; so when compiling in CPU AOT mode, you *also* need to make sure the // name of the callee (presumably implemented in C++) matches up with the -// symbolic name used in the CustomCall. +// symbolic name used in the CustomCall. Be careful with the name of the symbol +// you register with the macros: C++ namespaces are not included, including +// anonymous namespaces,so if two libraries attempt to register functions with +// the same name in separate namespaces the registrations will collide. Either +// call the registration macro from the global namespace so that you have to +// refer to the function in a fully-qualified manner (which also requires you to +// emit HLO-based calls to it by the fully-qualified name *and* complicates +// future refactoring!) or use C-style namespacing directly in the symbol name. // // We maintain the registry in both the JIT and the AOT cases for simplicity, // but we only use it when running in JIT mode. @@ -47,6 +54,9 @@ class CustomCallTargetRegistry { const std::string& platform); void* Lookup(const std::string& symbol, const std::string& platform) const; + std::unordered_map registered_symbols( + const std::string& platform) const; + private: // hash> is surprisingly not provided by default in stl. It would // be better to use absl's hash function, but we're avoiding an absl diff --git a/xla/service/custom_call_target_registry_test.cc b/xla/service/custom_call_target_registry_test.cc new file mode 100644 index 0000000000000..1b423449953ba --- /dev/null +++ b/xla/service/custom_call_target_registry_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/custom_call_target_registry.h" + +#include "xla/service/custom_call_status.h" +#include "xla/test.h" + +namespace xla { +namespace { + +using ::testing::_; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +void custom_call(void*, const void**, XlaCustomCallStatus*) {} +void custom_call2(void*, const void**, XlaCustomCallStatus*) {} + +TEST(CustomCallRegistryTest, Registers) { + CustomCallTargetRegistry registry; + EXPECT_EQ(registry.Lookup("custom_call", "Host"), nullptr); + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); + EXPECT_EQ(custom_call, registry.Lookup("custom_call", "Host")); + // A registration with a different name is fine. + registry.Register("custom_call2", reinterpret_cast(&custom_call), + "Host"); + + EXPECT_EQ(registry.Lookup("custom_call", "CUDA"), nullptr); + // A registration on a different platform is fine. + registry.Register("custom_call", reinterpret_cast(custom_call), + "CUDA"); + EXPECT_EQ(custom_call, registry.Lookup("custom_call", "CUDA")); + + // A second registration of the same function is fine. + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); + + EXPECT_THAT( + registry.registered_symbols("Host"), + UnorderedElementsAre(Pair("custom_call", _), Pair("custom_call2", _))); + EXPECT_THAT(registry.registered_symbols("CUDA"), + UnorderedElementsAre(Pair("custom_call", _))); +} + +TEST(CustomCallRegistryDeathTest, RejectsDuplicateRegistrations) { + CustomCallTargetRegistry registry; + registry.Register("custom_call", reinterpret_cast(custom_call), + "Host"); + EXPECT_DEATH(registry.Register("custom_call", + reinterpret_cast(custom_call2), "Host"), + "Duplicate custom call"); +} + +} // namespace +} // namespace xla diff --git a/xla/service/defuser.cc b/xla/service/defuser.cc index 77ba8ebe98919..2e8f9f03a8237 100644 --- a/xla/service/defuser.cc +++ b/xla/service/defuser.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,59 +36,7 @@ limitations under the License. namespace xla { -namespace { - -// Copy all the instructions in the given fusion instruction into the fusion -// instruction's parent computation and replace the use of the fusion -// instruction with the copy of the fusion expression root. -Status Defuse(HloInstruction* fusion_instruction) { - VLOG(2) << "Defusing instruction: " << fusion_instruction->ToString(); - - HloComputation* fused_computation = - fusion_instruction->fused_instructions_computation(); - - // A map from fused instruction to its defused clone. - absl::flat_hash_map - defused_instructions; - // Initialize map to contain the fusion instruction parameters mapping - // to the operands of the fusion instruction. - for (int64_t i = 0; i < fusion_instruction->operand_count(); ++i) { - defused_instructions[fused_computation->parameter_instruction(i)] = - fusion_instruction->mutable_operand(i); - } - - // Create a clone of each instruction of the fused computation in the same - // computation as the fusion instruction itself. - // TODO(b/68227302): Moving instruction to new computation rather than - // cloning and deleting. - for (HloInstruction* fused_instruction : - fused_computation->MakeInstructionPostOrder()) { - if (fused_instruction->opcode() == HloOpcode::kParameter) { - continue; - } - std::vector new_operands; - for (HloInstruction* operand : fused_instruction->operands()) { - new_operands.push_back(defused_instructions.at(operand)); - } - HloInstruction* defused_instruction = - fusion_instruction->parent()->AddInstruction( - fused_instruction->CloneWithNewOperands(fused_instruction->shape(), - new_operands)); - defused_instructions[fused_instruction] = defused_instruction; - } - - TF_RETURN_IF_ERROR(fusion_instruction->ReplaceAllUsesWith( - defused_instructions.at(fusion_instruction->fused_expression_root()))); - - HloModule* module = fusion_instruction->GetModule(); - TF_RETURN_IF_ERROR( - fusion_instruction->parent()->RemoveInstruction(fusion_instruction)); - return module->RemoveEmbeddedComputation(fused_computation); -} - -} // namespace - -StatusOr Defuser::Run( +absl::StatusOr Defuser::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Defusing module " << module->name(); @@ -102,7 +50,7 @@ StatusOr Defuser::Run( TF_RET_CHECK(call_graph_node.caller_callsites().size() == 1); HloInstruction* fusion_instruction = call_graph_node.caller_callsites()[0].instruction(); - TF_RETURN_IF_ERROR(Defuse(fusion_instruction)); + TF_RETURN_IF_ERROR(fusion_instruction->Defuse()); changed = true; } return OkStatus(); diff --git a/xla/service/defuser.h b/xla/service/defuser.h index 9cf57c79357f7..8d7f4eaefa40f 100644 --- a/xla/service/defuser.h +++ b/xla/service/defuser.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ class Defuser : public HloModulePass { // Run defusion on the given module. Returns whether the module was // changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/defuser_test.cc b/xla/service/defuser_test.cc index 1b23d371c60dc..ad70f7998c66a 100644 --- a/xla/service/defuser_test.cc +++ b/xla/service/defuser_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/despecializer.cc b/xla/service/despecializer.cc index c63aeb3690573..d7c413f6bf36f 100644 --- a/xla/service/despecializer.cc +++ b/xla/service/despecializer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,13 +40,13 @@ void Despecializer::AddReduceWindowToReduceBroadcastDeconstruct() { pipeline_.AddPass(); } -StatusOr Despecializer::Run( +absl::StatusOr Despecializer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return pipeline_.Run(module, execution_threads); } -StatusOr DeconstructReduceWindowToReduceBroadcast::Run( +absl::StatusOr DeconstructReduceWindowToReduceBroadcast::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/despecializer.h b/xla/service/despecializer.h index 12d3a8212bc7d..054b99353a0f2 100644 --- a/xla/service/despecializer.h +++ b/xla/service/despecializer.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ class Despecializer : public HloModulePass { void AddReduceWindowToReduceBroadcastDeconstruct(); absl::string_view name() const override { return "despecializer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -58,7 +58,7 @@ class DeconstructReduceWindowToReduceBroadcast : public HloModulePass { return "ReduceWindowToReduceAndBroadcast"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; @@ -70,9 +70,9 @@ class ControlDepRemover : public HloModulePass { absl::string_view name() const override { return "control-dep-remover"; } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { bool changed = false; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { diff --git a/xla/service/despecializer_test.cc b/xla/service/despecializer_test.cc index b026081891db4..6ba16f6b8f32e 100644 --- a/xla/service/despecializer_test.cc +++ b/xla/service/despecializer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dfs_hlo_visitor_with_default_test.cc b/xla/service/dfs_hlo_visitor_with_default_test.cc index 3b4e068294349..a3a2329992ffd 100644 --- a/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dot_as_convolution_util.cc b/xla/service/dot_as_convolution_util.cc index 58073e4378823..e22dddcf7cee6 100644 --- a/xla/service/dot_as_convolution_util.cc +++ b/xla/service/dot_as_convolution_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -132,7 +132,7 @@ bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size, return dims; } -StatusOr> +absl::StatusOr> CreateShardedConvForDotGeneralConvolution( const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) { diff --git a/xla/service/dot_as_convolution_util.h b/xla/service/dot_as_convolution_util.h index c84ea3c9a2dda..01236f8c7ec9d 100644 --- a/xla/service/dot_as_convolution_util.h +++ b/xla/service/dot_as_convolution_util.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,7 @@ DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv); // - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'. // - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result // convolution instruction. -StatusOr> +absl::StatusOr> CreateShardedConvForDotGeneralConvolution( const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); diff --git a/xla/service/dot_decomposer.cc b/xla/service/dot_decomposer.cc index 45d8554556217..33e77654ff60d 100644 --- a/xla/service/dot_decomposer.cc +++ b/xla/service/dot_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,20 +15,28 @@ limitations under the License. #include "xla/service/dot_decomposer.h" +#include +#include +#include #include +#include #include "absl/algorithm/container.h" -#include "absl/strings/str_join.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/permutation_util.h" -#include "xla/service/sparse_util.h" +#include "xla/service/shape_inference.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/status.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -40,13 +48,26 @@ namespace { // * Batch dimensions are the most major dimensions. // This requires transposing and reshaping of the lhs and rhs, and reshaping the // output batch to the original shape. -Status CanonicalizeDot(HloInstruction* original_dot) { +Status CanonicalizeDot(HloDotInstruction* original_dot) { auto computation = original_dot->parent(); const auto& original_dnums = original_dot->dot_dimension_numbers(); const int64_t num_batch_dims = original_dnums.lhs_batch_dimensions_size(); const int64_t num_contracting_dims = original_dnums.lhs_contracting_dimensions_size(); + // Sparse dimension (if present), must be at the end of the contracting + // dimensions list. + int lhs_sparse_dim = -1, rhs_sparse_dim = -1; + for (const SparsityDescriptor& descriptor : original_dot->sparsity()) { + (descriptor.index() == 0 ? lhs_sparse_dim : rhs_sparse_dim) = + descriptor.dimension(); + } + auto move_dim_to_end = [&](std::vector& dims, int sparse_dim) { + if (sparse_dim < 0) return; + auto it = std::remove(dims.begin(), dims.end(), sparse_dim); + *it = sparse_dim; // Effectively the same as erase+push_back. + }; + const auto& lhs_shape = original_dot->operand(0)->shape(); const int64_t lhs_rank = lhs_shape.rank(); const int64_t num_lhs_non_contracting_dims = @@ -89,6 +110,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) { lhs_transpose.insert(lhs_transpose.end(), original_dnums.lhs_contracting_dimensions().begin(), original_dnums.lhs_contracting_dimensions().end()); + move_dim_to_end(lhs_transpose, lhs_sparse_dim); HloInstruction* lhs_operand = original_dot->mutable_operand(0); HloInstruction* transposed_lhs = computation->AddInstruction( HloInstruction::CreateTranspose( @@ -145,6 +167,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) { rhs_transpose.insert(rhs_transpose.end(), original_dnums.rhs_contracting_dimensions().begin(), original_dnums.rhs_contracting_dimensions().end()); + move_dim_to_end(rhs_transpose, rhs_sparse_dim); rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), rhs_non_contracting_dims.end()); HloInstruction* rhs_operand = original_dot->mutable_operand(1); @@ -190,10 +213,47 @@ Status CanonicalizeDot(HloInstruction* original_dot) { num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0)); dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); + // Build sparsity data for the new dot. + std::vector sparsity; + std::vector sparse_meta; + sparsity.reserve(original_dot->sparse_operands()); + sparse_meta.reserve(original_dot->sparse_operands()); + auto transpose_meta = [&](HloInstruction* original_meta, + absl::Span transpose) { + return computation->AddInstruction( + HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(transpose, original_meta->shape()), + original_meta, transpose), + &original_meta->metadata()); + }; + for (int i = 0; i < original_dot->sparse_operands(); ++i) { + SparsityDescriptor descriptor = original_dot->sparsity()[i]; + descriptor.set_dimension(num_batch_dims + (descriptor.index() == 0 && + lhs_non_contracting_size > 1)); + sparsity.push_back(descriptor); + HloInstruction* meta = + original_dot->mutable_operand(HloDotInstruction::kOperands + i); + HloInstruction* meta_operand; + if (descriptor.index() == 0) { + meta = transpose_meta(meta, lhs_transpose); + meta_operand = reshaped_lhs; + } else { + meta = transpose_meta(meta, rhs_transpose); + meta_operand = reshaped_rhs; + } + TF_ASSIGN_OR_RETURN(Shape result_shape, + ShapeInference::InferSparseDotMetadataShape( + meta_operand->shape(), dot_dnums, descriptor)); + meta = computation->AddInstruction( + HloInstruction::CreateReshape(result_shape, meta), &meta->metadata()); + sparse_meta.push_back(meta); + } + HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims, dot_dynamic_dims), - reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); + reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config(), + sparsity, sparse_meta)); original_dot->SetupDerivedInstruction(dot); std::unique_ptr replacement = @@ -208,7 +268,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) { } // namespace -StatusOr DotDecomposer::Run( +absl::StatusOr DotDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Gather all Non-canonical Dot operations. @@ -219,11 +279,6 @@ StatusOr DotDecomposer::Run( if (instruction->opcode() != HloOpcode::kDot) { continue; } - // Skips sparse instruction as DotDecomposer does not know how to handle - // sparse input yet. - if (SparseUtil::HasSparseInOut(instruction)) { - continue; - } const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); // A dot it not canonical if there is more than one contracting dimension. if (dnums.lhs_contracting_dimensions_size() != 1) { @@ -256,7 +311,7 @@ StatusOr DotDecomposer::Run( } bool changed = false; for (auto* dot : non_canonical_dots) { - TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); + TF_RETURN_IF_ERROR(CanonicalizeDot(Cast(dot))); changed = true; } return changed; diff --git a/xla/service/dot_decomposer.h b/xla/service/dot_decomposer.h index af826ed8c7ad2..4fbe3123f63c5 100644 --- a/xla/service/dot_decomposer.h +++ b/xla/service/dot_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_DECOMPOSER_H_ #define XLA_SERVICE_DOT_DECOMPOSER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -31,7 +34,7 @@ class DotDecomposer : public HloModulePass { // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/dot_decomposer_test.cc b/xla/service/dot_decomposer_test.cc index 24f0921032c7b..70cd99d44d6bf 100644 --- a/xla/service/dot_decomposer_test.cc +++ b/xla/service/dot_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,27 @@ limitations under the License. #include "xla/service/dot_decomposer.h" +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" - -namespace op = xla::testing::opcode_matchers; +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace xla { namespace { +namespace m = ::xla::match; +namespace op = ::xla::testing::opcode_matchers; + using DotDecomposerTest = HloTestBase; TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { @@ -120,5 +131,72 @@ TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) { op::Shape("f32[64,2]")))); } +template +auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { + return match::Op() + .WithOpcode(HloOpcode::kDot) + .WithOperand(0, std::forward(arg0)) + .WithOperand(1, std::forward(arg1)) + .WithOperand(2, std::forward(arg2)); +} + +TEST_F(DotDecomposerTest, CanonicalizeSparseLhs) { + absl::string_view kHlo = R"( + HloModule module + + ENTRY main { + lhs = f32[16,4,3,7] parameter(0) + rhs = f32[32,4,5,7] parameter(1) + meta = u16[2,4,3,7] parameter(2) + ROOT dot = f32[7,3,5] dot(lhs, rhs, meta), sparsity=L.0@2:4, + lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, + lhs_batch_dims={3}, rhs_batch_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(SparseDotMatcher( + m::Reshape(m::Transpose(m::Parameter(0))), + m::Reshape(m::Transpose(m::Parameter(1))), + m::Reshape(m::Transpose(m::Parameter(2))))))); + auto dot = Cast(root->operand(0)); + auto descriptor = dot->sparsity().front(); + EXPECT_EQ(descriptor.index(), 0); + EXPECT_EQ(descriptor.dimension(), 2); +} + +TEST_F(DotDecomposerTest, CanonicalizeSparseRhs) { + absl::string_view kHlo = R"( + HloModule module + + ENTRY main { + lhs = f32[32,4,3,7] parameter(0) + rhs = f32[16,4,5,7] parameter(1) + meta = u16[2,4,5,7] parameter(2) + ROOT dot = f32[7,3,5] dot(lhs, rhs, meta), sparsity=R.0@2:4, + lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}, + lhs_batch_dims={3}, rhs_batch_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(SparseDotMatcher( + m::Reshape(m::Transpose(m::Parameter(0))), + m::Reshape(m::Transpose(m::Parameter(1))), + m::Reshape(m::Transpose(m::Parameter(2))))))); + auto dot = Cast(root->operand(0)); + auto descriptor = dot->sparsity().front(); + EXPECT_EQ(descriptor.index(), 1); + EXPECT_EQ(descriptor.dimension(), 1); +} + } // namespace } // namespace xla diff --git a/xla/service/dot_dimension_merger.cc b/xla/service/dot_dimension_merger.cc index 0730ab93af6d2..ba2f26bcca4a2 100644 --- a/xla/service/dot_dimension_merger.cc +++ b/xla/service/dot_dimension_merger.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,13 +21,21 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout_util.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" #include "xla/status.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -119,6 +127,24 @@ class BatchDimensionMerger : public DfsHloRewriteVisitor { shifted_contracting_dimensions.end()); } + // Update sparsity descriptors, if present. + auto sparsity = Cast(dot)->sparsity(); + std::vector new_sparsity(sparsity.begin(), + sparsity.end()); + std::vector sparse_meta(sparsity.size()); + for (int i = 0; i < sparsity.size(); ++i) { + SparsityDescriptor& descriptor = new_sparsity[i]; + int64_t sparse_batch_dim = + descriptor.index() == 0 ? lhs_batch_dimension : rhs_batch_dimension; + if (descriptor.dimension() > sparse_batch_dim) + descriptor.set_dimension(descriptor.dimension() - + (batch_dimension_count - 1)); + HloInstruction* meta = + dot->mutable_operand(HloDotInstruction::kOperands + i); + Shape new_meta_shape = merge_batch_dims(meta->shape(), sparse_batch_dim); + TF_ASSIGN_OR_RETURN(sparse_meta[i], MakeReshapeHlo(new_meta_shape, meta)); + } + TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_lhs, MakeReshapeHlo(new_lhs_shape, dot->mutable_operand(0))); @@ -129,7 +155,8 @@ class BatchDimensionMerger : public DfsHloRewriteVisitor { HloInstruction* new_dot = dot->parent()->AddInstruction( HloInstruction::CreateDot(new_dot_shape, reshaped_lhs, reshaped_rhs, new_dot_dimension_numbers, - dot->precision_config()), + dot->precision_config(), new_sparsity, + sparse_meta), &dot->metadata()); dot->SetupDerivedInstruction(new_dot); @@ -141,7 +168,7 @@ class BatchDimensionMerger : public DfsHloRewriteVisitor { } // namespace -StatusOr DotDimensionMerger::Run( +absl::StatusOr DotDimensionMerger::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return BatchDimensionMerger().RunOnModule(module, execution_threads); diff --git a/xla/service/dot_dimension_merger.h b/xla/service/dot_dimension_merger.h index 9bb969732797c..53cee5d29b43a 100644 --- a/xla/service/dot_dimension_merger.h +++ b/xla/service/dot_dimension_merger.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_DIMENSION_MERGER_H_ #define XLA_SERVICE_DOT_DIMENSION_MERGER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -29,7 +32,7 @@ class DotDimensionMerger : public HloModulePass { // Run the pass on computations in 'module'. // Return whether the 'module' was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/dot_dimension_merger_test.cc b/xla/service/dot_dimension_merger_test.cc index 83118d959abe2..a41b904f5259a 100644 --- a/xla/service/dot_dimension_merger_test.cc +++ b/xla/service/dot_dimension_merger_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,13 @@ limitations under the License. #include "xla/service/dot_dimension_merger.h" +#include #include +#include #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -135,5 +138,32 @@ ENTRY e { EXPECT_FALSE(modified); } +TEST_F(DotDimensionMergerTest, SparseDotUpdatesDescriptor) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = bf16[3,4,5,6,16] parameter(0) + p1 = bf16[3,4,5,32,6] parameter(1) + meta = u16[3,4,5,6,2] parameter(2) + ROOT d = bf16[4,5,6,6] dot(p0, p1, meta), sparsity=L.4@2:4, + lhs_batch_dims={1,2}, lhs_contracting_dims={0,4}, + rhs_batch_dims={1,2}, rhs_contracting_dims={0,3} +})"; + + RunAndFilecheckHloRewrite(kHloText, DotDimensionMerger(), R"( +; CHECK: %[[R0:.*]] = bf16[3,20,6,16]{3,2,1,0} reshape(%p0) +; CHECK: %[[R1:.*]] = bf16[3,20,32,6]{3,2,1,0} reshape(%p1) +; CHECK: %[[R2:.*]] = u16[3,20,6,2]{3,2,1,0} reshape(%meta) +; CHECK: %[[DOT:.*]] = bf16[20,6,6]{2,1,0} dot(%[[R0]], %[[R1]], %[[R2]]) +; CHECK-SAME: lhs_batch_dims={1} +; CHECK-SAME: lhs_contracting_dims={0,3} +; CHECK-SAME: rhs_batch_dims={1} +; CHECK-SAME: rhs_contracting_dims={0,2} +; CHECK-SAME: sparsity=L.3@2:4 +; CHECK-NEXT: ROOT {{.+}} = bf16[4,5,6,6]{3,2,1,0} reshape(%[[DOT]]) + )"); +} + } // namespace } // namespace xla diff --git a/xla/service/dot_merger.cc b/xla/service/dot_merger.cc index 020ade5b62b5f..71dbd9e73d3b5 100644 --- a/xla/service/dot_merger.cc +++ b/xla/service/dot_merger.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,32 @@ limitations under the License. #include "xla/service/dot_merger.h" -#include +#include #include #include #include #include -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/protobuf_util.h" #include "xla/service/graphcycles/graphcycles.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -49,8 +66,8 @@ namespace { // - `a` does not transitively depend on the value of `b`, and `b` does not // transitively depend on the value of `a`. // -StatusOr TryMergeSameOperand(HloInstruction* a, - HloInstruction* b) { +absl::StatusOr TryMergeSameOperand(HloInstruction* a, + HloInstruction* b) { if (a->shape().layout() != b->shape().layout()) { VLOG(3) << "Can't merge dots because they have a different layout:\n" << "\t" << a->ToString() << "\n" @@ -112,6 +129,17 @@ StatusOr TryMergeSameOperand(HloInstruction* a, return nullptr; } + HloDotInstruction* dot_a = Cast(a); + HloDotInstruction* dot_b = Cast(b); + if (!absl::c_equal(dot_a->sparsity(), dot_b->sparsity(), + protobuf_util::ProtobufEquals)) { + VLOG(3) << "Can't merge dots because they have mismatching sparsity " + "descriptors:\n" + << "\t" << a->ToString() << "\n" + << "\t" << b->ToString(); + return nullptr; + } + VLOG(2) << "Merging dots sharing an operand:\n" << "\t" << a->ToString() << "\n" << "\t" << b->ToString(); @@ -172,6 +200,32 @@ StatusOr TryMergeSameOperand(HloInstruction* a, ++outer_dim; } + std::vector sparsity(dot_a->sparsity().begin(), + dot_a->sparsity().end()); + std::vector sparse_meta(sparsity.size()); + for (int i = 0; i < sparsity.size(); ++i) { + HloInstruction* meta = a->mutable_operand(HloDotInstruction::kOperands + i); + HloInstruction* other_meta = + b->mutable_operand(HloDotInstruction::kOperands + i); + if (sparsity[i].index() == (lhs_same ? 1 : 0)) { + TF_ASSIGN_OR_RETURN( + Shape meta_concat_shape, + ShapeInference::InferConcatOpShape( + {&meta->shape(), &other_meta->shape()}, outer_dim)); + meta = meta->AddInstruction(HloInstruction::CreateConcatenate( + meta_concat_shape, {meta, other_meta}, outer_dim)); + } else { + if (other_meta != meta) { + VLOG(3) + << "Can't merge dots because the sparsity metadata is different:\n" + << "\t" << a->ToString() << "\n" + << "\t" << b->ToString(); + return nullptr; + } + } + sparse_meta[i] = meta; + } + TF_ASSIGN_OR_RETURN( Shape concat_shape, ShapeInference::InferConcatOpShape( @@ -187,10 +241,11 @@ StatusOr TryMergeSameOperand(HloInstruction* a, Shape new_dot_shape, ShapeInference::InferDotOpShape( dot_lhs->shape(), dot_rhs->shape(), dnums, - /*preferred_element_type=*/a->shape().element_type())); + /*preferred_element_type=*/a->shape().element_type(), sparsity)); *new_dot_shape.mutable_layout() = a->shape().layout(); - HloInstruction* new_dot = a->AddInstruction(HloInstruction::CreateDot( - new_dot_shape, dot_lhs, dot_rhs, dnums, a->precision_config())); + HloInstruction* new_dot = a->AddInstruction( + HloInstruction::CreateDot(new_dot_shape, dot_lhs, dot_rhs, dnums, + a->precision_config(), sparsity, sparse_meta)); // We can't keep both. But one is better then none. if (!a->metadata().op_name().empty()) { @@ -223,7 +278,8 @@ StatusOr TryMergeSameOperand(HloInstruction* a, return new_dot; } -StatusOr MergeDots(HloComputation* comp, int64_t max_size_to_merge) { +absl::StatusOr MergeDots(HloComputation* comp, + int64_t max_size_to_merge) { auto is_merge_candidate = [&](HloInstruction* instr) { int64_t bytes = ShapeUtil::ByteSizeOfElements(instr->shape()); for (const HloInstruction* operand : instr->operands()) { @@ -336,9 +392,10 @@ StatusOr MergeDots(HloComputation* comp, int64_t max_size_to_merge) { int32_t b_id = graph_id(b); if (dead_instrs.contains(a) || dead_instrs.contains(b) || + (!is_merge_candidate(a) && !is_merge_candidate(b)) || + // Perform reachability checks last since they can be expensive. graph.IsReachableNonConst(a_id, b_id) || - graph.IsReachableNonConst(b_id, a_id) || - (!is_merge_candidate(a) && !is_merge_candidate(b))) { + graph.IsReachableNonConst(b_id, a_id)) { continue; } @@ -373,7 +430,7 @@ StatusOr MergeDots(HloComputation* comp, int64_t max_size_to_merge) { } // anonymous namespace -StatusOr DotMerger::Run( +absl::StatusOr DotMerger::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/dot_merger.h b/xla/service/dot_merger.h index 7bc0e80e6281a..5f7fa58eaf1e8 100644 --- a/xla/service/dot_merger.h +++ b/xla/service/dot_merger.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_MERGER_H_ #define XLA_SERVICE_DOT_MERGER_H_ -#include -#include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -58,7 +61,7 @@ class DotMerger : public HloModulePass { absl::string_view name() const override { return "dot-merger"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/dot_merger_test.cc b/xla/service/dot_merger_test.cc index 23f4307f0a82f..97b9da0d0c279 100644 --- a/xla/service/dot_merger_test.cc +++ b/xla/service/dot_merger_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,22 @@ limitations under the License. #include "xla/service/dot_merger.h" +#include #include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/algebraic_simplifier.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -708,5 +715,103 @@ TEST_F(DotMergerTest, MergeWithTypeUpgrade) { EXPECT_EQ(d0, d1); } +TEST_F(DotMergerTest, MergeSparseDotsSameMetadata) { + absl::string_view kHlo = R"( + HloModule test + ENTRY main { + lhs0 = f16[5,10,32] parameter(0) + lhs1 = f16[5,10,32] parameter(1) + rhs = f16[5,10,16] parameter(2) + meta = u16[5,10,2] parameter(3) + dot0 = f32[5,10,10] dot(lhs0, rhs, meta), sparsity=R.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + dot1 = f32[5,10,10] dot(lhs1, rhs, meta), sparsity=R.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + ROOT tuple = (f32[5,10,10], f32[5,10,10]) tuple(dot0, dot1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + DotMerger pass(/*max_size_to_merge=*/std::numeric_limits::max()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + const HloInstruction *d0, *d1; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Slice(m::Op(&d0) + .WithOpcode(HloOpcode::kDot) + .WithOperand(0, m::Concatenate(m::Parameter(0), + m::Parameter(1))) + .WithOperand(1, m::Parameter(2)) + .WithOperand(2, m::Parameter(3)) + .WithShape(F32, {5, 20, 10})), + m::Slice(m::Op(&d1))))); + EXPECT_EQ(d0, d1); + EXPECT_EQ(d0->operand(2)->shape(), ShapeUtil::MakeShape(U16, {5, 10, 2})); +} + +TEST_F(DotMergerTest, MergeSparseDotsConcatMetadata) { + absl::string_view kHlo = R"( + HloModule test + ENTRY main { + lhs0 = f16[5,10,16] parameter(0) + lhs1 = f16[5,10,16] parameter(1) + rhs = f16[5,10,32] parameter(2) + meta0 = u16[5,10,2] parameter(3) + meta1 = u16[5,10,2] parameter(4) + dot0 = f32[5,10,10] dot(lhs0, rhs, meta0), sparsity=L.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + dot1 = f32[5,10,10] dot(lhs1, rhs, meta1), sparsity=L.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + ROOT tuple = (f32[5,10,10], f32[5,10,10]) tuple(dot0, dot1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + DotMerger pass(/*max_size_to_merge=*/std::numeric_limits::max()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + const HloInstruction *d0, *d1; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Slice(m::Op(&d0) + .WithOpcode(HloOpcode::kDot) + .WithOperand(0, m::Concatenate(m::Parameter(0), + m::Parameter(1))) + .WithOperand(1, m::Parameter(2)) + .WithOperand(2, m::Concatenate(m::Parameter(3), + m::Parameter(4))) + .WithShape(F32, {5, 20, 10})), + m::Slice(m::Op(&d1))))); + EXPECT_EQ(d0, d1); + EXPECT_EQ(d0->operand(2)->shape(), ShapeUtil::MakeShape(U16, {5, 20, 2})); +} + +TEST_F(DotMergerTest, MergeSparseDotsDifferentMetadata) { + absl::string_view kHlo = R"( + HloModule test + ENTRY main { + lhs0 = f16[5,10,32] parameter(0) + lhs1 = f16[5,10,32] parameter(1) + rhs = f16[5,10,16] parameter(2) + meta1 = u16[5,10,2] parameter(3) + meta2 = u16[5,10,2] parameter(4) + dot0 = f32[5,10,10] dot(lhs0, rhs, meta1), sparsity=R.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + dot1 = f32[5,10,10] dot(lhs1, rhs, meta2), sparsity=R.2@2:4, + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} + ROOT tuple = (f32[5,10,10], f32[5,10,10]) tuple(dot0, dot1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + DotMerger pass(/*max_size_to_merge=*/std::numeric_limits::max()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/xla/service/dump.cc b/xla/service/dump.cc index 247b1fdada5ea..5adf5957d523b 100644 --- a/xla/service/dump.cc +++ b/xla/service/dump.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ std::string RenderGraph(absl::string_view label, const HloModule& module, bool show_fusion_subcomputations) { HloRenderOptions hlo_render_options; hlo_render_options.show_fusion_subcomputations = show_fusion_subcomputations; - StatusOr rendered_graph = + absl::StatusOr rendered_graph = RenderGraph(*module.entry_computation(), label, module.config().debug_options(), format, hlo_render_options); if (rendered_graph.ok()) { @@ -472,7 +472,8 @@ static std::vector DumpHloModuleImpl( continue; } - StatusOr rendered_graph = WrapFusionExplorer(*computation); + absl::StatusOr rendered_graph = + WrapFusionExplorer(*computation); if (!rendered_graph.ok()) { VLOG(1) << "Skipping fusion visualization" << " for computation " << computation->name() @@ -623,7 +624,7 @@ void DumpToFileInDirOrStdout(const HloModule& module, string_view file_prefix, CanonicalDebugOptions opts(module.config().debug_options()); if (opts.dumping_to_stdout()) return op->dump(); - mlir::OpPrintingFlags print_flags = mlir::OpPrintingFlags().useLocalScope(); + mlir::OpPrintingFlags print_flags = mlir::OpPrintingFlags(); // Enable debug info so that it is easier to see the corresponding HLO node. if (file_prefix == "lmhlo") { print_flags.enableDebugInfo(/*enable=*/true, @@ -639,7 +640,7 @@ void DumpToFileInDirOrStdout(const HloModule& module, string_view file_prefix, void DumpProtobufToFile(const tsl::protobuf::Message& proto, const DebugOptions& debug_options, absl::string_view filename, - absl::AnyInvocable( + absl::AnyInvocable( tsl::Env*, const tsl::protobuf::Message&)> text_formatter) { CanonicalDebugOptions opts(debug_options); @@ -680,12 +681,13 @@ void DumpProtobufToFile(const tsl::protobuf::Message& proto, } } -void DumpPerModuleProtobufToFile( - const HloModule& module, const tsl::protobuf::Message& proto, - const DebugOptions& debug_options, absl::string_view name, - absl::AnyInvocable(tsl::Env*, - const tsl::protobuf::Message&)> - text_formatter) { +void DumpPerModuleProtobufToFile(const HloModule& module, + const tsl::protobuf::Message& proto, + const DebugOptions& debug_options, + absl::string_view name, + absl::AnyInvocable( + tsl::Env*, const tsl::protobuf::Message&)> + text_formatter) { const std::string filename = FilenameFor(module, TimestampFor(module), name); DumpProtobufToFile(proto, debug_options, filename, std::move(text_formatter)); } diff --git a/xla/service/dump.h b/xla/service/dump.h index 40e46862cd17b..3020fa4f5fc53 100644 --- a/xla/service/dump.h +++ b/xla/service/dump.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -89,7 +89,7 @@ void DumpToFileInDirOrStdout(const HloModule& module, void DumpProtobufToFile(const tsl::protobuf::Message& proto, const DebugOptions& debug_options, absl::string_view filename, - absl::AnyInvocable( + absl::AnyInvocable( tsl::Env*, const tsl::protobuf::Message&)> text_formatter = nullptr); @@ -100,12 +100,13 @@ std::string RenderGraph(absl::string_view label, const HloModule& module, // Similar to above, but the filename depends on module's information and the // given name. Also allows for the optional serialization function. -void DumpPerModuleProtobufToFile( - const HloModule& module, const tsl::protobuf::Message& proto, - const DebugOptions& debug_options, absl::string_view name, - absl::AnyInvocable(tsl::Env*, - const tsl::protobuf::Message&)> - text_formatter = nullptr); +void DumpPerModuleProtobufToFile(const HloModule& module, + const tsl::protobuf::Message& proto, + const DebugOptions& debug_options, + absl::string_view name, + absl::AnyInvocable( + tsl::Env*, const tsl::protobuf::Message&)> + text_formatter = nullptr); // Dumps the given HLO module if dumping is enabled for the module. Exactly // where and in what formats it's dumped is determined by the module's config. diff --git a/xla/service/dynamic_dimension_inference.cc b/xla/service/dynamic_dimension_inference.cc index 5e76198b462f9..ee74ae367d93f 100644 --- a/xla/service/dynamic_dimension_inference.cc +++ b/xla/service/dynamic_dimension_inference.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ namespace xla { namespace { // Replace `narrow_comp` with a new computation with `wide_shape` as input. -StatusOr> +absl::StatusOr> WidenComputation(HloComputation* narrow_comp, const Shape& wide_shape) { TF_RET_CHECK(wide_shape.IsTuple()); const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape(); @@ -119,7 +119,7 @@ class DynamicDimensionInferenceVisitor : public DfsHloRewriteVisitor { Status DefaultAction(HloInstruction* hlo) override; - static StatusOr Run( + static absl::StatusOr Run( HloComputation* computation, HloDataflowAnalysis& dataflow_analysis, const DynamicParameterBinding& param_bindings, DynamicDimensionInference* parent, @@ -266,8 +266,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloRewriteVisitor { // (including uses across control flow, but only within the same thread). The // given `ShapeIndex` is the leaf array returned by the given instruction that // will be considered. - StatusOr RequiresPadToStatic(HloInstruction* instr, - ShapeIndex shape_index); + absl::StatusOr RequiresPadToStatic(HloInstruction* instr, + ShapeIndex shape_index); // Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it // and `RequiresPadToStatic` is true for all leaves. If the instruction @@ -485,8 +485,10 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) { TF_RETURN_IF_ERROR(custom_call_handler_(hlo, parent_)); } else { TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, - int64_t operand_index, HloInstruction* dynamic_size) { + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, + int64_t operand_index, + HloInstruction* dynamic_size) -> absl::Status { // Resize custom call should propagate dynamic batch (0) and channel // (3) dimensions. if (hlo->custom_call_target() == "SliceToDynamic" || @@ -565,8 +567,9 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { return OkStatus(); } return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, - int64_t operand_index, HloInstruction* dynamic_size) { + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, + int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { if (operand_index != 0) { return Unimplemented( "Dynamic dimension on padding value is not supported"); @@ -803,8 +806,9 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution( return OkStatus(); } return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, - int64_t operand_index, HloInstruction* dynamic_size) { + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, + int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { HloInstruction* conv = hlo; const ConvolutionDimensionNumbers& dimension_numbers = conv->convolution_dimension_numbers(); @@ -1077,7 +1081,7 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( const Shape& subshape = ShapeUtil::GetSubshape(hlo->shape(), index); auto* element = dynamic_sizes.mutable_element(index); element->resize(subshape.rank(), nullptr); - element->at(dimension) = dynamic_size; + (*element)[dimension] = dynamic_size; return OkStatus(); })); dynamic_sizes.ForEachElement([&](const ShapeIndex& index, const auto& sizes) { @@ -1415,7 +1419,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape( SetDynamicSizes(hlo, {}, dynamic_sizes); return OkStatus(); } - return InternalError( + return Internal( "Need inferred dimension to be set to " "flatten-unflatten pair. %s", hlo->ToString()); @@ -1655,7 +1659,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow( auto* leaf_dynamic_sizes = dynamic_sizes.mutable_element(reduce_window_result_index); leaf_dynamic_sizes->resize(subshape.rank(), nullptr); - leaf_dynamic_sizes->at(dimension) = dynamic_size; + (*leaf_dynamic_sizes)[dimension] = dynamic_size; }); return OkStatus(); @@ -1844,7 +1848,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) { } ++operand_dimension; } - return InternalError("Invalid instruction: %s", hlo->ToString()); + return Internal("Invalid instruction: %s", hlo->ToString()); } return Unimplemented( "Detects a dynamic dimension on the data input of gather, which " @@ -2120,7 +2124,8 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex dynamic_index, int64_t dimension, - int64_t operand_index, HloInstruction* operand_dynamic_size) { + int64_t operand_index, + HloInstruction* operand_dynamic_size) -> absl::Status { if (operand_index == 0) { SetDynamicSize(hlo, {}, dimension, operand_dynamic_size); return OkStatus(); @@ -2392,7 +2397,7 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension( return OkStatus(); } -StatusOr DynamicDimensionInferenceVisitor::RequiresPadToStatic( +absl::StatusOr DynamicDimensionInferenceVisitor::RequiresPadToStatic( HloInstruction* instr, ShapeIndex shape_index) { TF_RET_CHECK(ShapeUtil::IsLeafIndex(instr->shape(), shape_index)) << instr->shape() << " @ " << shape_index; @@ -2669,7 +2674,7 @@ void DynamicDimensionInference::CopyMapping( } /* static */ -StatusOr DynamicDimensionInference::Run( +absl::StatusOr DynamicDimensionInference::Run( HloModule* module, OpSupportsDynamismHandler op_supports_dynamism_handler, CustomCallInferenceHandler custom_call_handler, ShapeCheckMode shape_check_mode, diff --git a/xla/service/dynamic_dimension_inference.h b/xla/service/dynamic_dimension_inference.h index 681ba709d7e98..4ce0b0c3087c1 100644 --- a/xla/service/dynamic_dimension_inference.h +++ b/xla/service/dynamic_dimension_inference.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -77,7 +77,7 @@ class DynamicDimensionInference { // false. using AssertionGenerator = std::function; - static StatusOr Run( + static absl::StatusOr Run( HloModule* module, OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr, CustomCallInferenceHandler custom_call_handler = nullptr, diff --git a/xla/service/dynamic_dimension_inference_test.cc b/xla/service/dynamic_dimension_inference_test.cc index 0ca66d5f01900..24b65c063c95e 100644 --- a/xla/service/dynamic_dimension_inference_test.cc +++ b/xla/service/dynamic_dimension_inference_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1351,8 +1351,8 @@ ENTRY computation { /*opaque=*/std::string{}, API_VERSION_STATUS_RETURNING)); })); - StatusOr filecheck_result = RunFileCheck(module_->ToString({}), - R"( + absl::StatusOr filecheck_result = RunFileCheck(module_->ToString({}), + R"( // CHECK: compare = pred[] compare(s32[] %a_size_1, s32[] %b_size_1), direction=EQ // CHECK: compare.5 = pred[] compare(s32[] %a_size_2, s32[] %b_size_2), direction=EQ // CHECK: and.2 = pred[] and(pred[] %compare, pred[] %compare.5) diff --git a/xla/service/dynamic_dimension_simplifier.cc b/xla/service/dynamic_dimension_simplifier.cc index ff3cdd9d921de..003a499d28ae4 100644 --- a/xla/service/dynamic_dimension_simplifier.cc +++ b/xla/service/dynamic_dimension_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace xla { namespace { // Concat(Concat(A, B), C) => Concat(A, B, C) -StatusOr ConcatForwarding(HloInstruction* concat) { +absl::StatusOr ConcatForwarding(HloInstruction* concat) { if (concat->opcode() != HloOpcode::kConcatenate) { return false; } @@ -51,7 +51,7 @@ StatusOr ConcatForwarding(HloInstruction* concat) { } // Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An -StatusOr SliceConcatForwarding(HloInstruction* slice) { +absl::StatusOr SliceConcatForwarding(HloInstruction* slice) { if (slice->opcode() != HloOpcode::kSlice) { return false; } @@ -90,7 +90,7 @@ StatusOr SliceConcatForwarding(HloInstruction* slice) { } // Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A -StatusOr ReshapeBroadcastForwarding(HloInstruction* reshape) { +absl::StatusOr ReshapeBroadcastForwarding(HloInstruction* reshape) { if (reshape->opcode() != HloOpcode::kReshape) { return false; } @@ -118,7 +118,7 @@ StatusOr ReshapeBroadcastForwarding(HloInstruction* reshape) { } // Reshape(Reshape(A, []->[1]), [1]->[]) ==> A -StatusOr ReshapeReshapeForwarding(HloInstruction* reshape) { +absl::StatusOr ReshapeReshapeForwarding(HloInstruction* reshape) { if (reshape->opcode() != HloOpcode::kReshape) { return false; } @@ -137,7 +137,7 @@ StatusOr ReshapeReshapeForwarding(HloInstruction* reshape) { } // Convert(A, T->T) ==> A -StatusOr IdentityConvertRemoving(HloInstruction* convert) { +absl::StatusOr IdentityConvertRemoving(HloInstruction* convert) { if (convert->opcode() != HloOpcode::kConvert) { return false; } @@ -150,7 +150,7 @@ StatusOr IdentityConvertRemoving(HloInstruction* convert) { } // Reshape(A, S->S) ==> A -StatusOr IdentityReshapeRemoving(HloInstruction* reshape) { +absl::StatusOr IdentityReshapeRemoving(HloInstruction* reshape) { if (reshape->opcode() != HloOpcode::kReshape) { return false; } @@ -164,7 +164,7 @@ StatusOr IdentityReshapeRemoving(HloInstruction* reshape) { } // namespace -StatusOr DynamicDimensionSimplifier::Run( +absl::StatusOr DynamicDimensionSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/dynamic_dimension_simplifier.h b/xla/service/dynamic_dimension_simplifier.h index aae92ddd59a23..435dc579dca28 100644 --- a/xla/service/dynamic_dimension_simplifier.h +++ b/xla/service/dynamic_dimension_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,11 +27,11 @@ namespace xla { class DynamicDimensionSimplifier : public HloModulePass { public: absl::string_view name() const override { - return "dynamic dimension simplifier"; + return "dynamic-dimension-simplifier"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/dynamic_dimension_simplifier_test.cc b/xla/service/dynamic_dimension_simplifier_test.cc index e9ad54f465f59..94e48eca1104e 100644 --- a/xla/service/dynamic_dimension_simplifier_test.cc +++ b/xla/service/dynamic_dimension_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dynamic_index_splitter.cc b/xla/service/dynamic_index_splitter.cc index 50a2d5a709ce6..cf4e21c997e97 100644 --- a/xla/service/dynamic_index_splitter.cc +++ b/xla/service/dynamic_index_splitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ limitations under the License. namespace xla { -StatusOr DynamicIndexSplitter::Run( +absl::StatusOr DynamicIndexSplitter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/dynamic_index_splitter.h b/xla/service/dynamic_index_splitter.h index 723c686239de9..babd86913da2b 100644 --- a/xla/service/dynamic_index_splitter.h +++ b/xla/service/dynamic_index_splitter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class DynamicIndexSplitter : public HloModulePass { DynamicIndexSplitter() = default; absl::string_view name() const override { return "dynamic-index-splitter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/dynamic_index_splitter_test.cc b/xla/service/dynamic_index_splitter_test.cc index 4259a9498aaba..e3daa8a6ded59 100644 --- a/xla/service/dynamic_index_splitter_test.cc +++ b/xla/service/dynamic_index_splitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dynamic_padder.cc b/xla/service/dynamic_padder.cc index fd14724f742ad..f5e7ae0fbba8a 100644 --- a/xla/service/dynamic_padder.cc +++ b/xla/service/dynamic_padder.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -74,8 +74,8 @@ auto* dynamic_padding_gauge = tsl::monitoring::Gauge::New( // instruction. // // nullopt is returned if padding doesn't need to be reset. -StatusOr ChooseIdentityValue(HloInstruction* inst, - int64_t operand_number) { +absl::StatusOr ChooseIdentityValue(HloInstruction* inst, + int64_t operand_number) { // Padding on elementwise operation doesn't affect the result of the effective // data. if (inst->IsElementwise()) { @@ -173,7 +173,7 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, } } -StatusOr ReplaceGetSize( +absl::StatusOr ReplaceGetSize( HloInstruction* instr, DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { @@ -210,7 +210,7 @@ StatusOr ReplaceGetSize( return true; } -StatusOr ReplaceSetSize(HloInstruction* instr) { +absl::StatusOr ReplaceSetSize(HloInstruction* instr) { if (instr->opcode() != HloOpcode::kSetDimensionSize) { return false; } @@ -225,7 +225,7 @@ StatusOr ReplaceSetSize(HloInstruction* instr) { return true; } -StatusOr ReplaceSetBound(HloInstruction* instr) { +absl::StatusOr ReplaceSetBound(HloInstruction* instr) { if (instr->opcode() != HloOpcode::kCustomCall || instr->custom_call_target() != "SetBound") { return false; @@ -487,7 +487,7 @@ HloInstruction* GenerateBinaryMask( // [[a,b,P] // [c,d,P]] // -StatusOr RewriteDynamicReshapeSplitInput( +absl::StatusOr RewriteDynamicReshapeSplitInput( HloInstruction* reshape, int64_t input_dim, absl::Span output_dims, absl::Span output_dynamic_dims, @@ -670,7 +670,7 @@ StatusOr RewriteDynamicReshapeSplitInput( // | // [a,b,c,d,P,P] // -StatusOr RewriteDynamicReshapeCombineInput( +absl::StatusOr RewriteDynamicReshapeCombineInput( HloInstruction* reshape, absl::Span input_dims, int64_t output_dim, absl::Span input_dynamic_dims, DynamicDimensionInference* dynamic_dimension_inference) { @@ -790,7 +790,7 @@ StatusOr RewriteDynamicReshapeCombineInput( return true; } -StatusOr RewriteDynamicReshapeSingleGroup( +absl::StatusOr RewriteDynamicReshapeSingleGroup( HloInstruction* reshape, absl::Span input_dims, absl::Span output_dims, absl::Span input_dynamic_dims, @@ -831,7 +831,7 @@ StatusOr RewriteDynamicReshapeSingleGroup( return false; } -StatusOr RewriteReverse( +absl::StatusOr RewriteReverse( HloInstruction* reverse, DynamicDimensionInference* dynamic_dimension_inference) { // When we have [A, B, C, D, E] and reverse them, we get [E, D, C, B, A]. @@ -980,7 +980,7 @@ HloInstruction* RewriteInputWithDynamicPadding( return input; } -StatusOr RewriteDynamicConvolutionInputGrad( +absl::StatusOr RewriteDynamicConvolutionInputGrad( HloInstruction* custom_call_conv, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* grad = custom_call_conv->mutable_operand(1); @@ -1049,7 +1049,7 @@ StatusOr RewriteDynamicConvolutionInputGrad( return true; } -StatusOr RewriteDynamicConvolutionForward( +absl::StatusOr RewriteDynamicConvolutionForward( HloInstruction* custom_call_conv, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* input = custom_call_conv->mutable_operand(0); @@ -1108,7 +1108,7 @@ StatusOr RewriteDynamicConvolutionForward( return true; } -StatusOr RewriteDynamicConvolutionKernelGrad( +absl::StatusOr RewriteDynamicConvolutionKernelGrad( HloInstruction* custom_call_conv, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* activations = custom_call_conv->mutable_operand(0); @@ -1192,7 +1192,7 @@ StatusOr RewriteDynamicConvolutionKernelGrad( return true; } -StatusOr RewriteDynamicReduceWindowSamePadding( +absl::StatusOr RewriteDynamicReduceWindowSamePadding( HloInstruction* hlo, DynamicDimensionInference* dynamic_dimension_inference) { if (hlo->shape().IsTuple()) { @@ -1236,7 +1236,7 @@ StatusOr RewriteDynamicReduceWindowSamePadding( return true; } -StatusOr RewriteDynamicSelectAndScatterSamePadding( +absl::StatusOr RewriteDynamicSelectAndScatterSamePadding( HloInstruction* hlo, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* input = hlo->mutable_operand(0); @@ -1312,7 +1312,7 @@ StatusOr RewriteDynamicSelectAndScatterSamePadding( return true; } -StatusOr RewriteDynamicConcat( +absl::StatusOr RewriteDynamicConcat( HloInstruction* concat, DynamicDimensionInference* dynamic_dimension_inference) { const int64_t concat_dim = concat->concatenate_dimension(); @@ -1359,7 +1359,7 @@ StatusOr RewriteDynamicConcat( return true; } -StatusOr RewriteDynamicSort( +absl::StatusOr RewriteDynamicSort( HloInstruction* hlo, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* dynamic_size = nullptr; @@ -1448,7 +1448,7 @@ StatusOr RewriteDynamicSort( return true; } -StatusOr RewriteDynamicBinaryOp( +absl::StatusOr RewriteDynamicBinaryOp( HloInstruction* binary, DynamicDimensionInference* dynamic_dimension_inference) { HloInstruction* operand_0 = binary->mutable_operand(0); @@ -1559,7 +1559,7 @@ StatusOr RewriteDynamicBinaryOp( return changed; } -StatusOr RewriteDynamicUpdateSlice( +absl::StatusOr RewriteDynamicUpdateSlice( HloInstruction* hlo, DynamicDimensionInference* dynamic_dimension_inference) { HloDynamicUpdateSliceInstruction* dus = @@ -1661,7 +1661,7 @@ StatusOr RewriteDynamicUpdateSlice( return true; } -StatusOr RewriteDynamicReshape( +absl::StatusOr RewriteDynamicReshape( HloInstruction* reshape, DynamicDimensionInference* dynamic_dimension_inference) { bool changed = false; @@ -1799,7 +1799,7 @@ StatusOr RewriteDynamicReshape( continue; } if (input_dims.size() > 1 && output_dims.size() > 1) { - return InternalError( + return Internal( "Should be handled by decomposing reshape into " "flatten-unflatten pair. %s", reshape->ToString()); @@ -1870,7 +1870,7 @@ class DynamicShapeRemovingVisitor : public DfsHloRewriteVisitor { Status HandleGetDimensionSize(HloInstruction* hlo) override; Status HandleSetDimensionSize(HloInstruction* hlo) override; - static StatusOr Run( + static absl::StatusOr Run( HloComputation* computation, const OpSupportsDynamismHandler& op_supports_dynamism_handler, DynamicDimensionInference* dynamic_shape_inference, @@ -1896,7 +1896,7 @@ class DynamicShapeRemovingVisitor : public DfsHloRewriteVisitor { private: // If a tensor produced by `inst` is in static form, convert it to dynamic and // returns the new instruction. - StatusOr ConvertToDynamic(HloInstruction* inst); + absl::StatusOr ConvertToDynamic(HloInstruction* inst); // Same as above, but for all of the instructions operands. The operands will // be replaced by dynamic operands as needed. @@ -1909,7 +1909,7 @@ class DynamicShapeRemovingVisitor : public DfsHloRewriteVisitor { absl::flat_hash_set execution_threads_; }; -StatusOr DynamicShapeRemovingVisitor::ConvertToDynamic( +absl::StatusOr DynamicShapeRemovingVisitor::ConvertToDynamic( HloInstruction* inst) { if (!dynamic_dimension_inference_->HasDynamicDimension(inst)) { return OkStatus(); @@ -2058,7 +2058,7 @@ Status DynamicShapeRemovingVisitor::HandleSetDimensionSize( } // namespace -StatusOr DynamicPadder::Run( +absl::StatusOr DynamicPadder::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Pre DynamicPadder HLO:"; @@ -2220,7 +2220,7 @@ StatusOr DynamicPadder::Run( // their called computation to only take static tensors. for (auto it = computations.rbegin(); it != computations.rend(); ++it) { HloComputation* computation = *it; - if (!call_graph->Dominates(module->entry_computation(), computation)) { + if (!call_graph->CanReach(module->entry_computation(), computation)) { continue; } // if slice_dynamic_output_ is set and this is entry computation, we need @@ -2242,7 +2242,7 @@ StatusOr DynamicPadder::Run( } for (auto* computation : module->computations(execution_threads)) { - if (!call_graph->Dominates(module->entry_computation(), computation)) { + if (!call_graph->CanReach(module->entry_computation(), computation)) { continue; } for (auto instruction : computation->MakeInstructionPostOrder()) { @@ -2253,7 +2253,7 @@ StatusOr DynamicPadder::Run( } for (auto* computation : module->computations(execution_threads)) { - if (!call_graph->Dominates(module->entry_computation(), computation)) { + if (!call_graph->CanReach(module->entry_computation(), computation)) { continue; } for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/xla/service/dynamic_padder.h b/xla/service/dynamic_padder.h index ac6c388f271a6..71430c31c2ec8 100644 --- a/xla/service/dynamic_padder.h +++ b/xla/service/dynamic_padder.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,7 +75,7 @@ class DynamicPadder : public HloModulePass { absl::string_view name() const override { return "dynamic_padder"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/dynamic_padder_test.cc b/xla/service/dynamic_padder_test.cc index b8113dec97caa..20421c57f0116 100644 --- a/xla/service/dynamic_padder_test.cc +++ b/xla/service/dynamic_padder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -107,7 +107,7 @@ class DynamicPadderTest : public HloTestBase { return module; } - StatusOr RunPadder( + absl::StatusOr RunPadder( bool slice_dynamic_output = false, OpSupportsDynamismHandler op_supports_dynamism_handler = OpHasDynamismSupport, diff --git a/xla/service/dynamic_parameter_binding_test.cc b/xla/service/dynamic_parameter_binding_test.cc index 55e8b378c2f8d..11dfbcdbec961 100644 --- a/xla/service/dynamic_parameter_binding_test.cc +++ b/xla/service/dynamic_parameter_binding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dynamic_update_slice_test.cc b/xla/service/dynamic_update_slice_test.cc index 1632bc62800a6..eb8932a947809 100644 --- a/xla/service/dynamic_update_slice_test.cc +++ b/xla/service/dynamic_update_slice_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dynamic_window_utils.cc b/xla/service/dynamic_window_utils.cc index dd72a33e1dfea..25b20e3ea3c2f 100644 --- a/xla/service/dynamic_window_utils.cc +++ b/xla/service/dynamic_window_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/dynamic_window_utils.h b/xla/service/dynamic_window_utils.h index 01f3445f9ba19..1c82b6b82e622 100644 --- a/xla/service/dynamic_window_utils.h +++ b/xla/service/dynamic_window_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/eigh_expander.cc b/xla/service/eigh_expander.cc index 4ee454caf7ef6..93d08aa184f40 100644 --- a/xla/service/eigh_expander.cc +++ b/xla/service/eigh_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -94,8 +94,8 @@ struct Eigh2x2 { // rt1 = w_tl - t * w_tr // rt2 = w_br + t * w_tr // return rt1, rt2, c, s -StatusOr HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, - XlaOp w_br) { +absl::StatusOr HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, + XlaOp w_br) { TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl)); bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type()); @@ -268,8 +268,8 @@ struct FrobeniusNorms { XlaOp frobenius_sq_norm; }; -StatusOr ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr, - XlaOp w_bl, XlaOp w_br) { +absl::StatusOr ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr, + XlaOp w_bl, XlaOp w_br) { XlaBuilder* builder = w_tl.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl)); const int64_t num_dims = shape.rank(); @@ -299,12 +299,11 @@ StatusOr ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr, return norms; } -StatusOr> Sweeps(absl::Span initial_values, - int64_t n, int max_iters, - PrimitiveType index_type, - XlaBuilder* builder) { +absl::StatusOr> Sweeps( + absl::Span initial_values, int64_t n, int max_iters, + PrimitiveType index_type, XlaBuilder* builder) { auto while_cond_fn = [&](absl::Span values, - XlaBuilder* cond_builder) -> StatusOr { + XlaBuilder* cond_builder) -> absl::StatusOr { auto iter_cond = Lt(values[0], ScalarLike(values[0], max_iters)); XlaOp w_tl, w_tr, w_bl, w_br; @@ -322,14 +321,14 @@ StatusOr> Sweeps(absl::Span initial_values, auto while_body_fn = [&](absl::Span values, - XlaBuilder* body_builder) -> StatusOr> { + XlaBuilder* body_builder) -> absl::StatusOr> { std::vector sweep_values(values.begin() + 1, values.end()); TF_ASSIGN_OR_RETURN( sweep_values, ForEachIndex( n - 1, S32, [&](XlaOp iter, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + XlaBuilder* builder) -> absl::StatusOr> { XlaOp tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br; std::tie(tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br) = std::make_tuple(values[0], values[1], values[2], values[3], @@ -430,7 +429,7 @@ Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) { XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol, bool sort_eigenvalues) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t num_dims = a_shape.rank(); if (num_dims < 2) { @@ -543,7 +542,7 @@ bool EighExpander::InstructionMatchesPattern(HloInstruction* instruction) { instruction->custom_call_target() == kEighCustomCallName; } -StatusOr EighExpander::ExpandInstruction( +absl::StatusOr EighExpander::ExpandInstruction( HloInstruction* instruction) { const std::string name = absl::StrFormat("xla.%s_%s", instruction->custom_call_target(), diff --git a/xla/service/eigh_expander.h b/xla/service/eigh_expander.h index db86358aaae0c..ed53fac030afd 100644 --- a/xla/service/eigh_expander.h +++ b/xla/service/eigh_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ class EighExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; virtual XlaOp BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol, diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 8f6d656923b92..4f69d8d53d248 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,11 +26,14 @@ limitations under the License. // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Value.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -39,10 +42,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/float8_fnuz_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/math_ops.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/statusor.h" @@ -63,7 +68,7 @@ using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz; namespace { -StatusOr EmitReducePrecisionIR( +absl::StatusOr EmitReducePrecisionIR( PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits, int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) { using llvm::APInt; @@ -206,30 +211,8 @@ StatusOr EmitReducePrecisionIR( return result; } -StatusOr DefaultEmitF32ToBF16Impl(llvm::Value* f32_value, - llvm::IRBuilder<>* b) { - TF_ASSIGN_OR_RETURN( - auto reduced_precision, - EmitReducePrecisionIR( - /*src_ty=*/F32, f32_value, - /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16), - /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1, - /*quiet_nans=*/true, b)); - auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty()); - auto shifted = b->CreateLShr(as_int32, 16); - auto truncated = b->CreateTrunc(shifted, b->getInt16Ty()); - return b->CreateBitCast(truncated, b->getInt16Ty()); -} - -llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) { - auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty()); - auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty()); - auto shifted = b->CreateShl(as_int32, 16); - return b->CreateBitCast(shifted, b->getFloatTy()); -} - -StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, - llvm::IRBuilder<>* b) { +absl::StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { TF_ASSIGN_OR_RETURN( llvm::Value * reduced_precision, EmitReducePrecisionIR( @@ -280,7 +263,7 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { // f8E4M3FN's NaN representations, so don't use ReducePrecision to handle // exponent reduction. Denormal values are not handled properly here and are // dealt with later in this function. - StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( /*src_ty=*/F16, f16_value, /*dest_exponent_bits=*/5, /*dest_mantissa_bits=*/3, @@ -571,7 +554,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } } -StatusOr ElementalIrEmitter::EmitUnaryOp( +absl::StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || op->operand(0)->shape().element_type() == PRED) { @@ -583,7 +566,7 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( } } -StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( +absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { @@ -606,10 +589,6 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (to_type == BF16) { - return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type, - F32, module_, b_)); - } if (to_type == F8E5M2) { return EmitF16ToF8e5m2( EmitIntegralToFloating(operand_value, from_type, F16, module_, @@ -727,7 +706,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } } -StatusOr ElementalIrEmitter::EmitFloatUnaryOp( +absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { @@ -739,7 +718,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (from_type == BF16) { TF_RET_CHECK(to_type != BF16); - operand_value = EmitBF16ToF32(operand_value, b_); + // The code below expects the source type to be F32. + operand_value = b_->CreateFPExt(operand_value, b_->getFloatTy()); from_type = F32; if (from_type == to_type) { return operand_value; @@ -794,13 +774,11 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( nullptr); } if (to_type == BF16) { - // Cast to F32 first. Other floating point formats are not supported by - // EmitReducePrecisionIR. - if (from_type != F32) { - operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_)); + // F16 to BF16 has to go through an intermediate F32. + if (from_type == F16) { + operand_value = b_->CreateFPExt(operand_value, b_->getFloatTy()); } - return EmitF32ToBF16(operand_value); + return FPCast(operand_value, b_->getBFloatTy()); } if (to_type == F8E5M2) { // Cast to F16 first. Casts to F8E5M2 must be from F16. @@ -902,6 +880,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } + case HloOpcode::kErf: + return EmitErf(op->shape().element_type(), operand_value); case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value, ""); case HloOpcode::kExpm1: @@ -982,7 +962,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } } -StatusOr ElementalIrEmitter::EmitComplexUnaryOp( +absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = @@ -994,21 +974,65 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComplexLog(op, operand_value); } case HloOpcode::kLog1p: { - // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) - // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) - // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) + // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) + // + // that is accurate only when |a| is relatively small while + // large |a| and |b| lead to multiplication overflow in the real + // part. + // + // The following expression for the real part: + // + // log1p(a+bi).real = log(hypot(a+1, b)) + // = log(max(|a+1|, |b|) * sqrt(1 + (min(|a+1|, |b|) / + // max(|a+1|, b))^2)) [to fix overflow for maximal values + // of |a+1| and |b|] = log(max(|a+1|, |b|)) + log(sqrt(1 + // + (min(|a+1|, |b|) / max(|a+1|, b))^2)) = + // log(max(|a+1|, |b|)) + 0.5*log1p((min(|a+1|, |b|) / + // max(|a+1|, b))^2) [to fix inaccuracies for small a, + // we'll use log1p] = log1p((1 + a > |b| ? a : max(|a+1|, + // |b|) - 1) + 0.5*log1p((min(|a+1|, |b|) / max(|a+1|, + // b))^2) + // + // is accurate on the whole complex plane except when |b| is + // small and a is very close to -|b|^2/2 that leads to + // substraction errors when adding the two log1p values as in + // log1p(-|b|^2) + log1p(|b|^2) + // TODO: improve the accuracy for the case above. + auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto two = llvm::ConstantFP::get(llvm_ty, 2.0); - auto a_plus_one = FAdd(a, one); - auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b)); - TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq)); - TF_ASSIGN_OR_RETURN(auto angle, - EmitAtan2(component_type, b, a_plus_one, "")); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); + auto half = llvm::ConstantFP::get(llvm_ty, 0.5); + + auto a1 = FAdd(a, one); + auto abs_a1 = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a1}, + {llvm_ty}, b_); + auto abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b}, + {llvm_ty}, b_); + + auto max_abs_of_a1_and_b = EmitFloatMax(abs_a1, abs_b, ""); + auto min_abs_of_a1_and_b = EmitFloatMin(abs_a1, abs_b, ""); + + auto max_abs_of_a1_and_b_minus_one = + Select(FCmpOGT(a1, abs_b), a, FSub(max_abs_of_a1_and_b, one)); + auto min_max_ratio = FDiv(min_abs_of_a1_and_b, max_abs_of_a1_and_b); + + TF_ASSIGN_OR_RETURN( + auto log_of_max_abs_of_a1_and_b, + EmitLog1p(component_type, max_abs_of_a1_and_b_minus_one)); + TF_ASSIGN_OR_RETURN( + auto log_of_sqrt_part, + EmitLog1p(component_type, FMul(min_max_ratio, min_max_ratio))); + + auto r = FAdd(FMul(half, log_of_sqrt_part), log_of_max_abs_of_a1_and_b); + auto real_part = Select(FCmpUNO(r, r), min_abs_of_a1_and_b, + r); // handles nan and inf values correctly + + TF_ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, "")); + return EmitComposeComplex(op, real_part, imag_part); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -1039,16 +1063,21 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i - TF_ASSIGN_OR_RETURN( - auto exp_a, - EmitExp(component_type, EmitExtractReal(operand_value), "")); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = FSub(FMul(exp_a, cos_b), one); - auto imag_result = FMul(exp_a, sin_b); + // [handle inaccuracies when a and/or b are small] + // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i + // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + auto zero = llvm::ConstantFP::get(b->getType(), 0.0); + auto one = llvm::ConstantFP::get(b->getType(), 1.0); + auto b_is_zero = FCmpOEQ(b, zero); + TF_ASSIGN_OR_RETURN(auto expm1_a, EmitExpm1(component_type, a)); + auto exp_a = FAdd(expm1_a, one); + TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + TF_ASSIGN_OR_RETURN(auto cos_b_minus_one, EmitCosm1(component_type, b)); + auto cos_b = FAdd(cos_b_minus_one, one); + auto real_result = FAdd(FMul(expm1_a, cos_b), cos_b_minus_one); + auto imag_result = Select(b_is_zero, zero, FMul(exp_a, sin_b)); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: @@ -1268,7 +1297,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } } -StatusOr ElementalIrEmitter::EmitBinaryOp( +absl::StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (operand_type == PRED) { @@ -1284,7 +1313,7 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( } } -StatusOr ElementalIrEmitter::EmitFloatBinaryOp( +absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kComplex: @@ -1308,10 +1337,7 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // matches C++'s semantics. case HloOpcode::kCompare: { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (operand_type == BF16) { - lhs_value = EmitBF16ToF32(lhs_value, b_); - rhs_value = EmitBF16ToF32(rhs_value, b_); - } else if (operand_type == F8E5M2) { + if (operand_type == F8E5M2) { lhs_value = EmitF8e5m2ToF16(lhs_value, b_); rhs_value = EmitF8e5m2ToF16(rhs_value, b_); } else if (operand_type == F8E4M3FN) { @@ -1365,20 +1391,25 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use // sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2)) // = |a| * sqrt(1 + (b/a)^2) -// With the assumption that |a| >= |b|. +// With the assumption that |a| >= |b|. This assumption is enforced by swapping +// a and b during the computation if necessary, since the final result is the +// same. // // This method returns the min, max, and sqrt term for this calculation. This is // done to prevent potential overflow errors that can occur from multiplying the // max with the sqrt term. (i.e. when calculating the sqrt of the absolute // value, we can take the sqrt of the max and the sqrt term before multiplying -// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of -// sqrt(1 + (b/a)^2). -StatusOr> +// them together.) If return_sqrt is false, it returns 1 + (min/max)^2 instead +// of sqrt(1 + (min/max)^2). +// +// Note that the precision of this computation can be improved by implementing +// another algorithm: +// Carlos F. Borges - An Improved Algorithm for hypot(a,b): +// https://arxiv.org/pdf/1904.09481.pdf +absl::StatusOr> ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type, - llvm::Value* operand_value, + llvm::Value* real, llvm::Value* imag, bool return_sqrt) { - llvm::Value* real = EmitExtractReal(operand_value); - llvm::Value* imag = EmitExtractImag(operand_value); llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {real}, {real->getType()}, b_); llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic( @@ -1394,14 +1425,16 @@ ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type, return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq); } -StatusOr ElementalIrEmitter::EmitComplexAbs( +absl::StatusOr ElementalIrEmitter::EmitComplexAbs( PrimitiveType prim_type, llvm::Value* operand_value) { llvm::Value* min; llvm::Value* max; llvm::Value* sqrt; + llvm::Value* real = EmitExtractReal(operand_value); + llvm::Value* imag = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( std::tie(min, max, sqrt), - EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true)); + EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/true)); llvm::Value* result = FMul(max, sqrt); // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. // In such cases, we return `min` instead of `result`. @@ -1410,14 +1443,16 @@ StatusOr ElementalIrEmitter::EmitComplexAbs( // Calculates ComplexAbs in the same way, except using: // sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25) -StatusOr ElementalIrEmitter::EmitSqrtComplexAbs( +absl::StatusOr ElementalIrEmitter::EmitSqrtComplexAbs( PrimitiveType prim_type, llvm::Value* operand_value) { llvm::Value* min; llvm::Value* max; llvm::Value* one_p_div_sq; + llvm::Value* real = EmitExtractReal(operand_value); + llvm::Value* imag = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( std::tie(min, max, one_p_div_sq), - EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false)); + EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/false)); TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max)); TF_ASSIGN_OR_RETURN(llvm::Value * pow, EmitPow(prim_type, one_p_div_sq, @@ -1430,14 +1465,16 @@ StatusOr ElementalIrEmitter::EmitSqrtComplexAbs( // Calculates ComplexAbs in the same way, except using: // rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2)) -StatusOr ElementalIrEmitter::EmitRsqrtComplexAbs( +absl::StatusOr ElementalIrEmitter::EmitRsqrtComplexAbs( PrimitiveType prim_type, llvm::Value* operand_value) { llvm::Value* min; llvm::Value* max; llvm::Value* sqrt; + llvm::Value* real = EmitExtractReal(operand_value); + llvm::Value* imag = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( std::tie(min, max, sqrt), - EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true)); + EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/true)); TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max)); TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt)); llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt); @@ -1447,21 +1484,21 @@ StatusOr ElementalIrEmitter::EmitRsqrtComplexAbs( return Select(FCmpUNO(result, result), rsqrt_min, result); } -StatusOr ElementalIrEmitter::EmitComplexAdd( +absl::StatusOr ElementalIrEmitter::EmitComplexAdd( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { return EmitComposeComplex( op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); } -StatusOr ElementalIrEmitter::EmitComplexSubtract( +absl::StatusOr ElementalIrEmitter::EmitComplexSubtract( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { return EmitComposeComplex( op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); } -StatusOr ElementalIrEmitter::EmitComplexMultiply( +absl::StatusOr ElementalIrEmitter::EmitComplexMultiply( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { return EmitComposeComplex( op, @@ -1471,7 +1508,7 @@ StatusOr ElementalIrEmitter::EmitComplexMultiply( FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); } -StatusOr ElementalIrEmitter::EmitComplexDivide( +absl::StatusOr ElementalIrEmitter::EmitComplexDivide( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { // Division of complex numbers is implemented here, taking into account // over/underflow, NaN and Inf values. @@ -1588,7 +1625,7 @@ StatusOr ElementalIrEmitter::EmitComplexDivide( result); } -StatusOr ElementalIrEmitter::EmitComplexLog( +absl::StatusOr ElementalIrEmitter::EmitComplexLog( const HloInstruction* op, llvm::Value* operand_value) { // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a) PrimitiveType component_type = @@ -1609,7 +1646,7 @@ StatusOr ElementalIrEmitter::EmitComplexLog( // = sqrt(r) * [cos(t/2) + i*sin(t/2)] // where r = |a+bi| and t = atan2(b,a) // TODO(bixia): See doc for implementation without atan2. -StatusOr ElementalIrEmitter::EmitComplexSqrt( +absl::StatusOr ElementalIrEmitter::EmitComplexSqrt( const HloInstruction* op, PrimitiveType prim_type, llvm::Value* operand_value) { llvm::Type* type = static_cast(operand_value->getType()) @@ -1665,7 +1702,7 @@ StatusOr ElementalIrEmitter::EmitComplexSqrt( // = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)] // = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)] // where r = |a+bi| and t = atan2(b,a). -StatusOr ElementalIrEmitter::EmitComplexRsqrt( +absl::StatusOr ElementalIrEmitter::EmitComplexRsqrt( const HloInstruction* op, PrimitiveType prim_type, llvm::Value* operand_value) { llvm::Type* type = static_cast(operand_value->getType()) @@ -1725,68 +1762,78 @@ StatusOr ElementalIrEmitter::EmitComplexRsqrt( return EmitComposeComplex(op, real_part, imag_part); } -// (a+bi)^(c+di) = -// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), +// lhs_value^rhs_value +// = (a+bi)^(c+di) +// = (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) -StatusOr ElementalIrEmitter::EmitComplexPower( - const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c, - llvm::Value* d) { +// = |lhs|^c * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)) +// where q = c*atan2(b,a)+d*ln(|lhs|) +absl::StatusOr ElementalIrEmitter::EmitComplexPower( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType component_type = primitive_util::ComplexComponentType(op->shape().element_type()); - llvm::Value* inf = llvm::ConstantFP::getInfinity(a->getType()); - auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); - auto zero = llvm::ConstantFP::get(a->getType(), 0); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto one = llvm::ConstantFP::get(a->getType(), 1); - auto half_c = FMul(one_half, c); + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); - TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, - EmitPow(component_type, aa_p_bb, half_c, "")); + TF_ASSIGN_OR_RETURN(auto abs, EmitComplexAbs(component_type, lhs_value)); + TF_ASSIGN_OR_RETURN(auto abs_to_c, EmitPow(component_type, abs, c, "")); auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, "")); auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs, "")); - auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = FMul(one_half, d); - auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); + auto coeff = FMul(abs_to_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_abs, EmitLog(component_type, abs)); + auto q = FAdd(FMul(c, arg_lhs), FMul(d, ln_abs)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + llvm::Value* inf = llvm::ConstantFP::getInfinity(a->getType()); + auto zero = llvm::ConstantFP::get(a->getType(), 0); + auto one = llvm::ConstantFP::get(a->getType(), 1); + // Case 0: // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. - auto cutoff_0 = Select( - And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), - EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), - EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); + auto cutoff_0 = + Select(And(And(FCmpOEQ(abs, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), + EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), + EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); // Case 1: - // 1^(c + d*i) = 1 + 0*i - auto cutoff_1 = Select(And(FCmpOEQ(a, one), FCmpOEQ(b, zero)), + // x^0 is defined to be 1 for any x, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + auto cutoff_1 = Select(And(FCmpOEQ(zero, c), FCmpOEQ(d, zero)), EmitComposeComplex(op, one, zero), cutoff_0); // Case 2: + // 1^(c + d*i) = 1 + 0*i + auto cutoff_2 = Select(And(FCmpOEQ(a, one), FCmpOEQ(b, zero)), + EmitComposeComplex(op, one, zero), cutoff_1); + + // Case 3: // inf^(c + 0*i) = inf + 0*i, c > 0 - auto cutoff_2 = Select( + auto cutoff_3 = Select( And(FCmpOEQ(a, inf), And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOGT(c, zero)))), - EmitComposeComplex(op, inf, zero), cutoff_1); + EmitComposeComplex(op, inf, zero), cutoff_2); - // Case 3: + // Case 4: // inf^(c + 0*i) = 0 + 0*i, c < 0 - auto cutoff_3 = Select( + auto cutoff_4 = Select( And(FCmpOEQ(a, inf), And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOLT(c, zero)))), - EmitComposeComplex(op, zero, zero), cutoff_2); + EmitComposeComplex(op, zero, zero), cutoff_3); - return cutoff_3; + return cutoff_4; } -StatusOr ElementalIrEmitter::EmitComplexBinaryOp( +absl::StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: @@ -1828,11 +1875,7 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( } } case HloOpcode::kPower: { - auto a = EmitExtractReal(lhs_value); - auto b = EmitExtractImag(lhs_value); - auto c = EmitExtractReal(rhs_value); - auto d = EmitExtractImag(rhs_value); - return EmitComplexPower(op, a, b, c, d); + return EmitComplexPower(op, lhs_value, rhs_value); } case HloOpcode::kAtan2: { // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) @@ -1879,14 +1922,14 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name); } -StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } -StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1935,39 +1978,77 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } -StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType, + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value}, {value->getType()}, b_); } -StatusOr ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitRsqrt( + PrimitiveType prim_type, llvm::Value* value) { TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value)); return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt); } -StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } -StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } -StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value, - absl::string_view name) { +absl::StatusOr ElementalIrEmitter::EmitCosm1( + PrimitiveType prim_type, llvm::Value* value) { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + auto negative_one = llvm::ConstantFP::get(type, -1.0); + + // Algorithm copied from cephes cosm1: + // cosm1(x) = -0.5 * x^2 + x^4 * P(x^2); + // that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1. + // + // This is an alternative algorithm + // cosm1(x) = -2 * sin(x/2)^2 + // that is only slightly less accurate around abs(x) == 0.1 but + // otherwise equivalent accuracy-wise compared to cephes cosm1. + // However, we are not using it because it is notably less + // performant than cephes cosm1. + + // TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy + // for large x values that are close to 2*pi*n where n is some integer. + static const std::array kCoeffs{ + 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, + 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, + 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, + 4.1666666666666666609054E-2, + }; + TF_ASSIGN_OR_RETURN(auto cos_x, EmitCos(prim_type, x)); + auto for_large_x = FAdd(cos_x, negative_one); + + auto xx = FMul(x, x); + auto xxxx = FMul(xx, xx); + TF_ASSIGN_OR_RETURN(auto poly, EvaluatePolynomial(type, xx, kCoeffs)); + auto for_small_x = FAdd(FMul(xxxx, poly), FMul(negative_half, xx)); + + // (pi/4)^2 is approximately 0.61685 + return Select(FCmpOGT(xx, llvm::ConstantFP::get(type, 0.61685)), for_large_x, + for_small_x); +} + +absl::StatusOr ElementalIrEmitter::EmitExp( + PrimitiveType prim_type, llvm::Value* value, absl::string_view name) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_, name); } -StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1992,16 +2073,15 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, return expm1_of_x; } -StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs, - absl::string_view name) { +absl::StatusOr ElementalIrEmitter::EmitPow( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs, + absl::string_view name) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_, name); } -StatusOr ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitCbrt( + PrimitiveType prim_type, llvm::Value* value) { auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); auto abs_value = @@ -2013,19 +2093,24 @@ StatusOr ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, return signed_res; } -StatusOr ElementalIrEmitter::EmitAtan2( +absl::StatusOr ElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/, absl::string_view /*name*/) { return Unimplemented("atan2"); } -StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) { return Unimplemented("tanh"); } -StatusOr ElementalIrEmitter::EmitTan(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr ElementalIrEmitter::EmitErf( + PrimitiveType prim_type, llvm::Value* value) { + return Unimplemented("erf"); +} + +absl::StatusOr ElementalIrEmitter::EmitTan( + PrimitiveType prim_type, llvm::Value* value) { auto sin_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); auto cos_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, @@ -2033,7 +2118,7 @@ StatusOr ElementalIrEmitter::EmitTan(PrimitiveType prim_type, return FDiv(sin_x, cos_x); } -StatusOr ElementalIrEmitter::EmitReducePrecision( +absl::StatusOr ElementalIrEmitter::EmitReducePrecision( const HloInstruction* hlo, llvm::Value* x) { return EmitReducePrecisionIR( /*src_ty=*/hlo->operand(0)->shape().element_type(), x, @@ -2174,7 +2259,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base, b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero)); } -StatusOr ElementalIrEmitter::EmitPredBinaryOp( +absl::StatusOr ElementalIrEmitter::EmitPredBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { // Per the reference interpreter, pred arithmetic should behave like // `int8_t(x) OP int8_t(y) != 0`. For most permitted ops, we can just emit @@ -2216,8 +2301,8 @@ StatusOr ElementalIrEmitter::EmitPredBinaryOp( case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: - return InternalError("Invalid binary op '%s' for pred", - HloOpcodeString(op->opcode())); + return Internal("Invalid binary op '%s' for pred", + HloOpcodeString(op->opcode())); default: return Unimplemented("binary pred op '%s'", @@ -2225,7 +2310,7 @@ StatusOr ElementalIrEmitter::EmitPredBinaryOp( } } -StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( +absl::StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) { switch (op->opcode()) { @@ -2320,7 +2405,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, lhs_value, rhs_value); } -StatusOr ElementalIrEmitter::EmitElementalSelect( +absl::StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -2334,7 +2419,7 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( on_false_value); } -StatusOr ElementalIrEmitter::EmitElementalClamp( +absl::StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -2357,7 +2442,7 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( } } -StatusOr ElementalIrEmitter::EmitElementalConcatenate( +absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& source_index) { @@ -2515,7 +2600,7 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( return output; } -StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( +absl::StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -2563,7 +2648,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( return operand_to_generator.at(input_hlo)(input_index); } -StatusOr ElementalIrEmitter::EmitElementalGather( +absl::StatusOr ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -2682,7 +2767,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( return operand_generator(operand_index); } -StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( +absl::StatusOr +ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -2769,7 +2855,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( return Load(ret_value_addr->getAllocatedType(), ret_value_addr); } -StatusOr ElementalIrEmitter::EmitElementalPad( +absl::StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index) { @@ -2830,10 +2916,21 @@ StatusOr ElementalIrEmitter::EmitElementalPad( return Load(ret_value_addr->getAllocatedType(), ret_value_addr); } -StatusOr ElementalIrEmitter::EmitElementalDot( +absl::StatusOr ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index) { + if (!algorithm_util::IsSupportedByElementalIrEmitter( + hlo->precision_config().algorithm())) { + return absl::InvalidArgumentError(absl::StrFormat( + "Algorithm not supported by the ElementalIrEmitter: %s", + PrecisionConfig::Algorithm_Name(hlo->precision_config().algorithm()))); + } + const HloDotInstruction* dot = Cast(hlo); + if (dot->sparse_operands()) { + return Unimplemented("Sparse dot is supported by Triton emitter only."); + } + auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -2926,8 +3023,8 @@ StatusOr ElementalIrEmitter::EmitElementalDot( TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); if (primitive_type == BF16) { - lhs_value = EmitBF16ToF32(lhs_value, b_); - rhs_value = EmitBF16ToF32(rhs_value, b_); + lhs_value = b_->CreateFPExt(lhs_value, b_->getFloatTy()); + rhs_value = b_->CreateFPExt(rhs_value, b_->getFloatTy()); } llvm::Value* next_accumulator = @@ -2939,7 +3036,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( llvm::Value* result = Load(accumulator_alloca->getAllocatedType(), accumulator_alloca); - return primitive_type == BF16 ? EmitF32ToBF16(result) : result; + return primitive_type == BF16 ? FPTrunc(result, b_->getBFloatTy()) : result; } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( @@ -2954,6 +3051,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -2973,7 +3071,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kTan: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); return EmitUnaryOp(hlo, operand_value); @@ -2996,7 +3094,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kShiftRightLogical: case HloOpcode::kSubtract: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { const HloInstruction* lhs = hlo->operand(0); const HloInstruction* rhs = hlo->operand(1); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, @@ -3007,30 +3105,32 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kSelect: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { return EmitElementalSelect(hlo, operand_to_generator, index); }; case HloOpcode::kClamp: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { return EmitElementalClamp(hlo, operand_to_generator, index); }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); return EmitReducePrecision(hlo, operand_value); }; case HloOpcode::kConcatenate: - return [this, hlo, &operand_to_generator]( - const IrArray::Index target_index) -> StatusOr { - return EmitElementalConcatenate(hlo, operand_to_generator, - target_index); - }; + return + [this, hlo, &operand_to_generator](const IrArray::Index target_index) + -> absl::StatusOr { + return EmitElementalConcatenate(hlo, operand_to_generator, + target_index); + }; case HloOpcode::kReverse: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& target_index) -> StatusOr { + return [this, hlo, + &operand_to_generator](const IrArray::Index& target_index) + -> absl::StatusOr { const HloInstruction* operand = hlo->operand(0); std::vector source_multi_index = target_index.multidim(); for (int64_t dim : hlo->dimensions()) { @@ -3043,18 +3143,19 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return operand_to_generator.at(operand)(source_index); }; case HloOpcode::kBroadcast: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& target_index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - // The `dimensions` member of the broadcast instruction maps from - // input dimensions to output dimensions. - return operand_to_generator.at(operand)( - target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), - hlo->dimensions(), b_)); - }; + return + [this, hlo, &operand_to_generator](const IrArray::Index& target_index) + -> absl::StatusOr { + const HloInstruction* operand = hlo->operand(0); + // The `dimensions` member of the broadcast instruction maps from + // input dimensions to output dimensions. + return operand_to_generator.at(operand)( + target_index.SourceIndexOfBroadcast( + hlo->shape(), operand->shape(), hlo->dimensions(), b_)); + }; case HloOpcode::kIota: - return [this, hlo]( - const IrArray::Index& target_index) -> StatusOr { + return [this, hlo](const IrArray::Index& target_index) + -> absl::StatusOr { auto* iota = Cast(hlo); PrimitiveType element_type = iota->shape().element_type(); IrArray::Index elem_index = @@ -3088,9 +3189,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == BF16) { - float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); - } else if (component_element_type == F8E4M3FNUZ) { + if (component_element_type == F8E4M3FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); } else if (component_element_type == F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); @@ -3100,10 +3199,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } llvm::Value* float_val = b_->CreateUIToFP(elem_index_linear, float_ir_type); - if (component_element_type == BF16) { - TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val)); - } else if (component_element_type == F8E4M3FNUZ || - component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ || + component_element_type == F8E5M2FNUZ) { TF_ASSIGN_OR_RETURN( iota_result, EmitFloatingToF8fnuz(F16, float_val, component_element_type, b_)); @@ -3119,7 +3216,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { IrArray::Index sliced_index = index.SourceIndexOfSlice( /*operand_shape=*/hlo->operand(0)->shape(), /*starts=*/hlo->slice_starts(), @@ -3128,18 +3225,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kDynamicSlice: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { return EmitElementalDynamicSlice(hlo, operand_to_generator, index); }; case HloOpcode::kGather: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { return EmitElementalGather(hlo, operand_to_generator, index); }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator, index); }; @@ -3160,8 +3257,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_)); }; case HloOpcode::kCopy: - return [hlo, &operand_to_generator]( - const IrArray::Index& target_index) -> StatusOr { + return [hlo, &operand_to_generator](const IrArray::Index& target_index) + -> absl::StatusOr { IrArray::Index source_index(target_index.multidim(), hlo->operand(0)->shape(), target_index.GetType()); @@ -3177,20 +3274,21 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions())); }; case HloOpcode::kPad: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& padded_index) -> StatusOr { - return EmitElementalPad(hlo, operand_to_generator, padded_index); - }; + return + [this, hlo, &operand_to_generator](const IrArray::Index& padded_index) + -> absl::StatusOr { + return EmitElementalPad(hlo, operand_to_generator, padded_index); + }; case HloOpcode::kDot: return [this, hlo, &operand_to_generator](const IrArray::Index& dot_result_index) - -> StatusOr { + -> absl::StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; case HloOpcode::kMap: return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { + const IrArray::Index& index) -> absl::StatusOr { std::vector operands; for (int i = 0; i < hlo->operand_count(); i++) { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -3250,11 +3348,6 @@ llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { return ExtractValue(value, {1}); } -StatusOr ElementalIrEmitter::EmitF32ToBF16( - llvm::Value* f32_value) { - return DefaultEmitF32ToBF16Impl(f32_value, b_); -} - llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, llvm::Value* imag) { @@ -3290,7 +3383,7 @@ llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs, return Add(accumulator, Mul(lhs, rhs)); } -StatusOr ElementalIrEmitter::EmitElementalMap( +absl::StatusOr ElementalIrEmitter::EmitElementalMap( const HloMapInstruction* map_instr, absl::Span elemental_operands) { TF_ASSIGN_OR_RETURN( @@ -3301,7 +3394,7 @@ StatusOr ElementalIrEmitter::EmitElementalMap( return values[0]; } -StatusOr ElementalIrEmitter::EmitElementalReduceWindow( +absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( const HloReduceWindowInstruction* reduce_window, std::vector input_generators, std::vector initial_value_generators, @@ -3427,7 +3520,7 @@ StatusOr ElementalIrEmitter::EmitElementalReduceWindow( reduce_window->shape().IsTuple()); } -StatusOr ElementalIrEmitter::EmitElementalReduce( +absl::StatusOr ElementalIrEmitter::EmitElementalReduce( const HloReduceInstruction* reduce, std::vector input_generators, std::vector initial_value_generators, @@ -3518,7 +3611,7 @@ StatusOr ElementalIrEmitter::EmitElementalReduce( return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic); } -StatusOr ElementalIrEmitter::EmitAccumResult( +absl::StatusOr ElementalIrEmitter::EmitAccumResult( absl::Span accumulator_addrs, llvm::ArrayRef accumulator_types, bool is_variadic) { TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size()); @@ -3539,7 +3632,7 @@ StatusOr ElementalIrEmitter::EmitAccumResult( } } -StatusOr ElementalIrEmitter::EmitConvolution( +absl::StatusOr ElementalIrEmitter::EmitConvolution( const HloInstruction* convolution, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { @@ -3698,7 +3791,7 @@ StatusOr ElementalIrEmitter::EmitConvolution( } // Evaluate polynomial using Horner's method. -StatusOr ElementalIrEmitter::EvaluatePolynomial( +absl::StatusOr ElementalIrEmitter::EvaluatePolynomial( llvm::Type* type, llvm::Value* x, absl::Span coefficients) { llvm::Value* poly = llvm::ConstantFP::get(type, 0.0); for (const double c : coefficients) { diff --git a/xla/service/elemental_ir_emitter.h b/xla/service/elemental_ir_emitter.h index 5dccb034e66e9..b636cb82df5ea 100644 --- a/xla/service/elemental_ir_emitter.h +++ b/xla/service/elemental_ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -69,31 +70,28 @@ class ElementalIrEmitter : public IrBuilderMixin { b_); } - virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitFloatBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); virtual llvm::Value* EmitExtractReal(llvm::Value* value); virtual llvm::Value* EmitExtractImag(llvm::Value* value); - virtual StatusOr EmitF32ToBF16(llvm::Value* f32_value); - private: - virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value); + virtual absl::StatusOr EmitUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, - llvm::Value* operand_value); + virtual absl::StatusOr EmitIntegerUnaryOp( + const HloInstruction* op, llvm::Value* operand_value); - virtual StatusOr EmitFloatUnaryOp(const HloInstruction* op, - llvm::Value* operand_value); + virtual absl::StatusOr EmitFloatUnaryOp( + const HloInstruction* op, llvm::Value* operand_value); - virtual StatusOr EmitComplexUnaryOp(const HloInstruction* op, - llvm::Value* operand_value); + virtual absl::StatusOr EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value); llvm::Value* IsZero(llvm::Value* v); llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); @@ -109,18 +107,15 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs, bool is_signed); - virtual StatusOr EmitPredBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitPredBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); - virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value, - bool is_signed); + virtual absl::StatusOr EmitIntegerBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed); - virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -136,91 +131,96 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed); - virtual StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, llvm::Value* rhs, - absl::string_view name); + virtual absl::StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs, + absl::string_view name); + + virtual absl::StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitSqrt(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitCbrt(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitCbrt(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitRsqrt(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitCosm1(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitTan(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitTan(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value, - absl::string_view name); + virtual absl::StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value, + absl::string_view name); - virtual StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, llvm::Value* rhs, - absl::string_view name); + virtual absl::StatusOr EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs, + absl::string_view name); - virtual StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value); + virtual absl::StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x); + virtual absl::StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value); - virtual StatusOr> - EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value, - bool return_sqrt); + virtual absl::StatusOr EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x); - virtual StatusOr EmitComplexAbs(PrimitiveType prim_type, - llvm::Value* operand_value); + virtual absl::StatusOr> + EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* real, + llvm::Value* imag, bool return_sqrt); - virtual StatusOr EmitSqrtComplexAbs(PrimitiveType prim_type, - llvm::Value* operand_value); - virtual StatusOr EmitRsqrtComplexAbs( + virtual absl::StatusOr EmitComplexAbs( PrimitiveType prim_type, llvm::Value* operand_value); - virtual StatusOr EmitComplexAdd(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitSqrtComplexAbs( + PrimitiveType prim_type, llvm::Value* operand_value); + virtual absl::StatusOr EmitRsqrtComplexAbs( + PrimitiveType prim_type, llvm::Value* operand_value); - virtual StatusOr EmitComplexSubtract(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitComplexAdd(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitComplexMultiply(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitComplexSubtract( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); - virtual StatusOr EmitComplexDivide(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); + virtual absl::StatusOr EmitComplexMultiply( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); - virtual StatusOr EmitComplexLog(const HloInstruction* op, - llvm::Value* operand_value); + virtual absl::StatusOr EmitComplexDivide( + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); - virtual StatusOr EmitComplexSqrt(const HloInstruction* op, - PrimitiveType prim_type, - llvm::Value* operand_value); + virtual absl::StatusOr EmitComplexLog( + const HloInstruction* op, llvm::Value* operand_value); - virtual StatusOr EmitComplexRsqrt(const HloInstruction* op, - PrimitiveType prim_type, - llvm::Value* operand_value); + virtual absl::StatusOr EmitComplexSqrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value); - StatusOr EmitAccumResult( + virtual absl::StatusOr EmitComplexRsqrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value); + + absl::StatusOr EmitAccumResult( absl::Span accumulator_addrs, llvm::ArrayRef accumulator_types, bool is_variadic); @@ -233,81 +233,78 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* accumulator, xla::PrimitiveType primitive_type); - // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } - - StatusOr EmitElementalSelect( + absl::StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalClamp( + absl::StatusOr EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalConcatenate( + absl::StatusOr EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& target_index); - StatusOr EmitElementalDynamicSlice( + absl::StatusOr EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalGather( + absl::StatusOr EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalDynamicUpdateSlice( + absl::StatusOr EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalPad( + absl::StatusOr EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index); - StatusOr EmitElementalDot( + absl::StatusOr EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index); - virtual StatusOr> EmitThreadLocalCall( + virtual absl::StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view name, bool is_reducer) = 0; - StatusOr EmitElementalMap( + absl::StatusOr EmitElementalMap( const HloMapInstruction* map_instr, absl::Span elemental_operands); - StatusOr EmitElementalReduceWindow( + absl::StatusOr EmitElementalReduceWindow( const HloReduceWindowInstruction* reduce_window, std::vector input_generators, std::vector initial_value_generators, const llvm_ir::IrArray::Index& index); - StatusOr EmitElementalReduce( + absl::StatusOr EmitElementalReduce( const HloReduceInstruction* reduce, std::vector input_generators, std::vector initial_value_generators, const llvm_ir::IrArray::Index& index); - virtual StatusOr EmitConvolution( + virtual absl::StatusOr EmitConvolution( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index); - // Computes the complex power function, returns (a + i*b)^(c + i*d). - StatusOr EmitComplexPower(const HloInstruction* op, - llvm::Value* a, llvm::Value* b, - llvm::Value* c, llvm::Value* d); + // Computes the complex power function. + absl::StatusOr EmitComplexPower(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); // Evaluates a polynomial using Horner's method. - StatusOr EvaluatePolynomial( + absl::StatusOr EvaluatePolynomial( llvm::Type* type, llvm::Value* x, absl::Span coefficients); virtual bool fast_min_max() = 0; @@ -315,6 +312,30 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::IRBuilder<>* const b_; llvm::Module* module_; + + friend class ElementalIrEmitterForTests; +}; + +// Allow to instantiate IR emitter in tests. +class ElementalIrEmitterForTests : public ElementalIrEmitter { + public: + ElementalIrEmitterForTests(llvm::Module* module, llvm::IRBuilder<>* builder) + : ElementalIrEmitter(module, builder) {} + + absl::Status TestElementalDot(const HloInstruction* hlo, + const llvm_ir::IrArray::Index& index) { + return EmitElementalDot(hlo, generator_map_, index).status(); + } + + private: + absl::StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name, bool is_reducer) override { + return absl::UnimplementedError(""); + } + bool fast_min_max() override { return false; } + + HloToElementGeneratorMap generator_map_; }; } // namespace xla diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 9823d9101eb5f..9ee2680065a26 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/elemental_ir_emitter.h" + +#include +#include +#include +#include +#include + #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" #include "xla/error_spec.h" -#include "xla/execution_options_util.h" -#include "xla/service/hlo_parser.h" -#include "xla/status_macros.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/test.h" -#include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -698,8 +713,10 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +// TODO(b/324385428): Failing on GPU at head due to an LLVM integrate. Re-enable +// once this has been fixed. XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, - MinimumHandlesNaNsOnTheRight) { + DISABLED_MinimumHandlesNaNsOnTheRight) { constexpr absl::string_view kHloText = R"( HloModule t @@ -811,5 +828,32 @@ ENTRY e { /*arel=*/1e-3})); } +class ElementalIrEmitterInternalTest : public HloTestBase {}; + +XLA_TEST_F(ElementalIrEmitterInternalTest, SparseDotIsUnsupported) { + constexpr absl::string_view kHloText = R"( +HloModule test + +ENTRY main { + lhs = f16[5,16] parameter(0) + rhs = f16[32,10] parameter(1) + meta = u16[5,2] parameter(2) + ROOT dot = f32[5,10] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + llvm::LLVMContext llvm_context; + llvm::Module llvm_module("", llvm_context); + llvm::IRBuilder<> builder(llvm_context); + ElementalIrEmitterForTests emitter(&llvm_module, &builder); + + llvm_ir::IrArray::Index test_index{builder.getInt64Ty()}; + auto result = emitter.TestElementalDot(root, test_index); + EXPECT_FALSE(result.ok()); +} + } // namespace } // namespace xla diff --git a/xla/service/executable.cc b/xla/service/executable.cc index f75975bcce805..7349471ea2439 100644 --- a/xla/service/executable.cc +++ b/xla/service/executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, unowned_indices_.insert(index); } -StatusOr ExecutionInput::ToShapedBuffer( +absl::StatusOr ExecutionInput::ToShapedBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal) const { const Shape& input_shape = shape(); ShapedBuffer shaped_buffer(input_shape, device_ordinal); @@ -77,11 +77,11 @@ StatusOr ExecutionInput::ToShapedBuffer( return std::move(shaped_buffer); } -StatusOr Executable::ExecuteOnStream( +absl::StatusOr Executable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { - StatusOr result = + absl::StatusOr result = ExecuteAsyncOnStream(run_options, arguments, hlo_execution_profile); Status blocking_status = run_options->stream()->BlockHostUntilDone(); TF_RETURN_IF_ERROR(result.status()); @@ -99,7 +99,7 @@ static ExecutionInput MakeMaybeOwningDeviceMemoryTree( return result; } -StatusOr Executable::ExecuteAsyncOnStream( +absl::StatusOr Executable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { @@ -114,11 +114,11 @@ StatusOr Executable::ExecuteAsyncOnStream( return out.ConsumeResult(); } -StatusOr Executable::ExecuteOnStream( +absl::StatusOr Executable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { - StatusOr result = ExecuteAsyncOnStream( + absl::StatusOr result = ExecuteAsyncOnStream( run_options, std::move(arguments), hlo_execution_profile); Status blocking_status = run_options->stream()->BlockHostUntilDone(); TF_RETURN_IF_ERROR(result.status()); @@ -126,7 +126,7 @@ StatusOr Executable::ExecuteOnStream( return result; } -StatusOr> Executable::ExecuteOnStreams( +absl::StatusOr> Executable::ExecuteOnStreams( absl::Span run_options, absl::Span> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); @@ -158,10 +158,10 @@ StatusOr> Executable::ExecuteOnStreams( return std::move(return_values); } -StatusOr Executable::ExecuteOnStreamWrapper( +absl::StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, absl::Span arguments) { - StatusOr result = + absl::StatusOr result = ExecuteAsyncOnStreamWrapper(run_options, arguments); Status block_status = run_options->stream()->BlockHostUntilDone(); TF_RETURN_IF_ERROR(result.status()); @@ -169,10 +169,10 @@ StatusOr Executable::ExecuteOnStreamWrapper( return result; } -StatusOr Executable::ExecuteOnStreamWrapper( +absl::StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, std::vector arguments) { - StatusOr result = + absl::StatusOr result = ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments)); Status block_status = run_options->stream()->BlockHostUntilDone(); TF_RETURN_IF_ERROR(result.status()); @@ -233,22 +233,22 @@ Status ExecuteWrapperAfterExecution( return return_status; } -StatusOr Executable::ExecuteAsyncOnStreamWrapper( +absl::StatusOr Executable::ExecuteAsyncOnStreamWrapper( const ServiceExecutableRunOptions* run_options, absl::Span arguments) { auto state = ExecuteWrapperBeforeExecution(*this, run_options); - StatusOr return_value = + absl::StatusOr return_value = ExecuteAsyncOnStream(run_options, arguments, nullptr); TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution( this, state, return_value.status(), run_options->stream())); return return_value; } -StatusOr Executable::ExecuteAsyncOnStreamWrapper( +absl::StatusOr Executable::ExecuteAsyncOnStreamWrapper( const ServiceExecutableRunOptions* run_options, std::vector arguments) { auto state = ExecuteWrapperBeforeExecution(*this, run_options); - StatusOr return_value = + absl::StatusOr return_value = ExecuteAsyncOnStream(run_options, std::move(arguments), nullptr); TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution( this, state, return_value.status(), run_options->stream())); diff --git a/xla/service/executable.h b/xla/service/executable.h index dcda9507fc0e3..573ff79a564c3 100644 --- a/xla/service/executable.h +++ b/xla/service/executable.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -97,7 +97,7 @@ class ExecutionInput { Status SetDynamicShape(Shape dynamic_shape); - xla::StatusOr ToShapedBuffer( + absl::StatusOr ToShapedBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal) const; void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) { @@ -260,7 +260,7 @@ class Executable { // enabled. // // Returns a shaped buffer containing the result of the computation. - StatusOr ExecuteOnStream( + absl::StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile); @@ -283,19 +283,19 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. Note that profiling is tricky to use correctly, as the profiling // objects (when they exist) must out-live the task. - virtual StatusOr ExecuteAsyncOnStream( + virtual absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile); // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to // complete. - StatusOr ExecuteOnStream( + absl::StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile); - virtual StatusOr ExecuteAsyncOnStream( + virtual absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) = 0; @@ -304,26 +304,26 @@ class Executable { // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr> ExecuteOnStreams( + virtual absl::StatusOr> ExecuteOnStreams( absl::Span run_options, absl::Span> arguments); // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. - StatusOr ExecuteOnStreamWrapper( + absl::StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, absl::Span arguments); - StatusOr ExecuteOnStreamWrapper( + absl::StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, std::vector arguments); - StatusOr ExecuteAsyncOnStreamWrapper( + absl::StatusOr ExecuteAsyncOnStreamWrapper( const ServiceExecutableRunOptions* run_options, absl::Span arguments); - StatusOr ExecuteAsyncOnStreamWrapper( + absl::StatusOr ExecuteAsyncOnStreamWrapper( const ServiceExecutableRunOptions* run_options, std::vector arguments); @@ -381,7 +381,7 @@ class Executable { } HloProto const* hlo_proto() const { - if (!hlo_proto_->has_hlo_module()) { + if (hlo_proto_ != nullptr && !hlo_proto_->has_hlo_module()) { *hlo_proto_->mutable_hlo_module() = module().ToProto(); } return hlo_proto_.get(); diff --git a/xla/service/execution_tracker.cc b/xla/service/execution_tracker.cc index b69151f7739e3..655aca424f190 100644 --- a/xla/service/execution_tracker.cc +++ b/xla/service/execution_tracker.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -73,7 +73,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { return OkStatus(); } -StatusOr ExecutionTracker::Resolve( +absl::StatusOr ExecutionTracker::Resolve( const ExecutionHandle& handle) { absl::MutexLock lock(&execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); diff --git a/xla/service/execution_tracker.h b/xla/service/execution_tracker.h index d194e75010103..3bd5f32edced5 100644 --- a/xla/service/execution_tracker.h +++ b/xla/service/execution_tracker.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,7 +78,7 @@ class ExecutionTracker { // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the // execution is not yet registered or already unregistered. - StatusOr Resolve(const ExecutionHandle& handle); + absl::StatusOr Resolve(const ExecutionHandle& handle); private: // The next handle to assign to an execution. diff --git a/xla/service/export_hlo.h b/xla/service/export_hlo.h index 499fcfa5208dd..7e111019bd46f 100644 --- a/xla/service/export_hlo.h +++ b/xla/service/export_hlo.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/flatten_call_graph.cc b/xla/service/flatten_call_graph.cc index cd09c9b129c8e..191aca4b46e13 100644 --- a/xla/service/flatten_call_graph.cc +++ b/xla/service/flatten_call_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -62,15 +62,6 @@ void ReplaceCalledComputation(HloInstruction* instruction, } break; } - case HloOpcode::kAsyncStart: - case HloOpcode::kAsyncUpdate: - case HloOpcode::kAsyncDone: { - computation->RemoveAsyncInstruction(instruction); - instruction->ReplaceCalledComputations( - [&](HloComputation*) { return new_computation; }); - new_computation->AddAsyncInstruction(*instruction); - break; - } default: LOG(FATAL) << "unexpected opcode: " << instruction->opcode(); } @@ -79,11 +70,6 @@ void ReplaceCalledComputation(HloInstruction* instruction, // Flatten a single call graph node. Expects to visit nodes in postorder. Status FlattenNode(const CallGraphNode& node) { HloComputation* computation = node.computation(); - // Flatten async computations so that different async ops that belong to the - // same async group id call the same computation but async ops that have - // different async group ids call a different computation. This map maps from - // the async group id to the associated computation. - absl::flat_hash_map async_computations; HloModule* module = computation->parent(); // Clone callee for all call-sites except the first one. for (int i = 0; i < node.caller_callsites().size(); ++i) { @@ -100,33 +86,13 @@ Status FlattenNode(const CallGraphNode& node) { continue; } - // For async computations, look up in the async computations map and use the - // computation for the group id, if available. Otherwise, clone the - // computation. - HloComputation* clone; if (computation->IsAsyncComputation()) { - HloInstruction* caller = call_site.instruction(); - TF_RET_CHECK(caller->async_group_id().has_value()); - auto async_computation_it = - async_computations.find(*caller->async_group_id()); - if (async_computation_it != async_computations.end()) { - if (computation != async_computation_it->second) { - ReplaceCalledComputation(call_site.instruction(), computation, - async_computation_it->second); - } - continue; - } else if (async_computations.empty()) { - async_computations[*caller->async_group_id()] = computation; - continue; - } - clone = module->AddEmbeddedComputation(computation->Clone()); - ReplaceCalledComputation(call_site.instruction(), computation, clone); - async_computations[*call_site.instruction()->async_group_id()] = clone; - } else { - // Clone computation for the remaining sequential context call sites. - clone = module->AddEmbeddedComputation(computation->Clone()); - ReplaceCalledComputation(call_site.instruction(), computation, clone); + continue; } + // Clone computation for the remaining sequential context call sites. + HloComputation* clone = + module->AddEmbeddedComputation(computation->Clone()); + ReplaceCalledComputation(call_site.instruction(), computation, clone); // Clone the sub-tree of all computations called from this node. std::vector worklist; worklist.push_back(clone); @@ -152,7 +118,7 @@ Status FlattenNode(const CallGraphNode& node) { } // namespace -StatusOr FlattenCallGraph::Run( +absl::StatusOr FlattenCallGraph::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString()); diff --git a/xla/service/flatten_call_graph.h b/xla/service/flatten_call_graph.h index bb3b43b4b2026..66b3d9f15e5a9 100644 --- a/xla/service/flatten_call_graph.h +++ b/xla/service/flatten_call_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class FlattenCallGraph : public HloModulePass { // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/flatten_call_graph_test.cc b/xla/service/flatten_call_graph_test.cc index 4f1a7b665dbcf..57498209c756c 100644 --- a/xla/service/flatten_call_graph_test.cc +++ b/xla/service/flatten_call_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -87,7 +87,7 @@ class FlattenCallGraphTest : public HloTestBase { return builder.Build(); } - StatusOr RunFlattenCallGraph(HloModule* module) { + absl::StatusOr RunFlattenCallGraph(HloModule* module) { FlattenCallGraph flatten; TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module)); return result; @@ -254,55 +254,5 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { EXPECT_EQ(1, sub_node.caller_callsites().size()); } -TEST_F(FlattenCallGraphTest, AsyncCall) { - std::string hlo_string = R"( -HloModule AsyncCall - -%called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] { - %param_0 = f32[4096]{0} parameter(0) - %param_1 = f32[4096]{0} parameter(1) - ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %param_0, f32[4096]{0} %param_1) -} - -%async_wrapped (async_param: f32[4096], async_param.1: f32[4096]) -> f32[4096] { - %async_param = f32[4096]{0} parameter(0) - %async_param.1 = f32[4096]{0} parameter(1) - ROOT %call = f32[4096]{0} call(f32[4096]{0} %async_param, f32[4096]{0} %async_param.1), to_apply=%called_computation -} - -ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { - %a = f32[4096]{0} parameter(0) - %b = f32[4096]{0} parameter(1) - %async-start.0 = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %a, f32[4096]{0} %b), async_group_id=0, calls=%async_wrapped - %async-done.0 = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.0), async_group_id=0, calls=%async_wrapped - %async-start.1 = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) async-start(f32[4096]{0} %async-done.0, f32[4096]{0} %b), async_group_id=1, calls=%async_wrapped - %async-done.1 = f32[4096]{0} async-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start.1), async_group_id=1, calls=%async_wrapped - ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %a, f32[4096]{0} %async-done.1) -} - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); - EXPECT_TRUE(result); - - // We expect the entry computation, two async_wrapped computations and two - // called_computation computations. - EXPECT_EQ(5, module->computation_count()); - - EXPECT_EQ(FindInstruction(module.get(), "async-start.0") - ->async_wrapped_computation(), - FindInstruction(module.get(), "async-done.0") - ->async_wrapped_computation()); - EXPECT_EQ(FindInstruction(module.get(), "async-start.1") - ->async_wrapped_computation(), - FindInstruction(module.get(), "async-done.1") - ->async_wrapped_computation()); - EXPECT_NE(FindInstruction(module.get(), "async-start.0") - ->async_wrapped_computation(), - FindInstruction(module.get(), "async-start.1") - ->async_wrapped_computation()); -} - } // namespace } // namespace xla diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index 33de411cc4b87..42c421b7ee7b9 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ using primitive_util::UnderflowExponent; namespace { -StatusOr PrimitiveTypeToAPFloatSemantics( +absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { case F8E4M3B11FNUZ: @@ -63,8 +63,8 @@ StatusOr PrimitiveTypeToAPFloatSemantics( } } -StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, - PrimitiveType type) { +absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, + PrimitiveType type) { switch (type) { case F8E4M3B11FNUZ: case F8E4M3FN: @@ -73,7 +73,7 @@ StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, case F8E5M2FNUZ: return b->getInt8Ty(); case BF16: - return b->getInt16Ty(); + return b->getBFloatTy(); case F16: return b->getHalfTy(); case F32: @@ -93,9 +93,9 @@ StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, // // The result is provided as a uint64_t containing the bit encoding of the // maximum value. -StatusOr ComputeMaximumValue(PrimitiveType input_type, - PrimitiveType output_type, - llvm::IRBuilder<>* b) { +absl::StatusOr ComputeMaximumValue(PrimitiveType input_type, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { // Sanity check inputs. TF_RET_CHECK(primitive_util::IsFloatingPointType(input_type)); TF_RET_CHECK(primitive_util::IsFloatingPointType(output_type)); @@ -136,10 +136,9 @@ StatusOr ComputeMaximumValue(PrimitiveType input_type, // Tests whether the input value can be represented in the output type as a // finite value. This takes into account rounding. -StatusOr IsInputOutsideOutputRange(PrimitiveType input_type, - llvm::Value* value, - PrimitiveType output_type, - llvm::IRBuilder<>* b) { +absl::StatusOr IsInputOutsideOutputRange( + PrimitiveType input_type, llvm::Value* value, PrimitiveType output_type, + llvm::IRBuilder<>* b) { const uint64_t shift = BitWidth(input_type) - 1; const uint64_t bit_mask = (0x1ull << shift) - 1; @@ -296,10 +295,10 @@ llvm::Value* ExtractMantissa(PrimitiveType type, llvm::Value* value, // ExtractMantissa(value) = 0b0000000011110111 // ^- third mantissa bit is at bit 4. // result = LastMantissaBit(BF16, 247.0, F8E4M3FNUZ, b) = 4 -StatusOr LastMantissaBit(PrimitiveType input_type, - llvm::Value* value, - PrimitiveType output_type, - llvm::IRBuilder<>* b) { +absl::StatusOr LastMantissaBit(PrimitiveType input_type, + llvm::Value* value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { const int src_mantissa_bits = SignificandWidth(input_type) - 1; const int dest_mantissa_bits = SignificandWidth(output_type) - 1; llvm::Type* int_type = b->getIntNTy(BitWidth(input_type)); @@ -368,10 +367,10 @@ StatusOr LastMantissaBit(PrimitiveType input_type, // Compute the rounding bias for round-to-nearest-even for the input value. // This takes into account whether the input value is a normal number and // whether it will map to a normal number in the output type. -StatusOr DynamicRoundingBias(PrimitiveType input_type, - llvm::Value* value, - PrimitiveType output_type, - llvm::IRBuilder<>* b) { +absl::StatusOr DynamicRoundingBias(PrimitiveType input_type, + llvm::Value* value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { llvm::Type* int_type = b->getIntNTy(BitWidth(input_type)); // Find the bit position of the last mantissa bit. @@ -496,17 +495,17 @@ llvm::Value* BuildOutputSign(llvm::Value* sign, PrimitiveType output_type, return b->CreateShl(sign, BitWidth(output_type) - 1); } -StatusOr GetQNaN(PrimitiveType type) { +absl::StatusOr GetQNaN(PrimitiveType type) { TF_ASSIGN_OR_RETURN(auto semantics, PrimitiveTypeToAPFloatSemantics(type)); return llvm::APFloat::getQNaN(*semantics).bitcastToAPInt().getZExtValue(); } } // namespace -StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, - llvm::Value* input_value, - PrimitiveType output_type, - llvm::IRBuilder<>* b) { +absl::StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, + llvm::Value* input_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b) { // Sanity check for supported types. TF_RET_CHECK(input_type == BF16 || input_type == F16 || input_type == F32 || input_type == F64); @@ -567,11 +566,11 @@ StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, result); } -StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, - llvm::Value* f8_value, - PrimitiveType output_type, - llvm::IRBuilder<>* b, - llvm::Module* module) { +absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, + llvm::Value* f8_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b, + llvm::Module* module) { // Sanity check for supported types. TF_RET_CHECK(input_type == F8E4M3FNUZ || input_type == F8E5M2FNUZ); TF_RET_CHECK(primitive_util::IsFloatingPointType(output_type)); diff --git a/xla/service/float8_fnuz_ir_emitter.h b/xla/service/float8_fnuz_ir_emitter.h index 6aaad33767233..5e98b8ebeab51 100644 --- a/xla/service/float8_fnuz_ir_emitter.h +++ b/xla/service/float8_fnuz_ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,19 +28,19 @@ namespace float8_fnuz_ir_emitter { // Convert the given floating point input to the output type. input_type must // be one of BF16, F16, F32, and F64. output_type must be one of F8E4M3FNUZ and // F8E5M2FNUZ. -StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, - llvm::Value* input_value, - PrimitiveType output_type, - llvm::IRBuilder<>* b); +absl::StatusOr EmitFloatingToF8fnuz(PrimitiveType input_type, + llvm::Value* input_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b); // Convert the given floating point input to the output type. input_type must // be one of F8E4M3FNUZ and F8E5M2FNUZ. output_type must be one of BF16, F16, // F32, and F64. -StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, - llvm::Value* f8_value, - PrimitiveType output_type, - llvm::IRBuilder<>* b, - llvm::Module* module); +absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, + llvm::Value* f8_value, + PrimitiveType output_type, + llvm::IRBuilder<>* b, + llvm::Module* module); } // namespace float8_fnuz_ir_emitter } // namespace xla diff --git a/xla/service/float_normalization.cc b/xla/service/float_normalization.cc index 84774a3b4884e..97fbc5bbc2688 100644 --- a/xla/service/float_normalization.cc +++ b/xla/service/float_normalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_dce.h" #include "xla/service/tuple_simplifier.h" #include "xla/shape_util.h" @@ -57,9 +58,10 @@ class FloatNormalizationVisitor : public DfsHloVisitorWithDefault { // Creates a copy of `hlo` with subshapes matching `from` type converted to // `to` type. If no matching subshapes are found, returns the original `hlo`. - StatusOr ConvertType(HloInstruction* hlo, PrimitiveType from, - PrimitiveType to, - HloComputation* computation); + absl::StatusOr ConvertType(HloInstruction* hlo, + PrimitiveType from, + PrimitiveType to, + HloComputation* computation); // Inserts a conversion HLO that changes the given HLO's output type. If the // output is a tuple, change all elements that match the from type. @@ -122,7 +124,7 @@ int64_t ShapeLeafCount(const Shape& shape) { return count; } -StatusOr FloatNormalizationVisitor::ConvertType( +absl::StatusOr FloatNormalizationVisitor::ConvertType( HloInstruction* hlo, PrimitiveType from, PrimitiveType to, HloComputation* computation) { if (CountSubshapesWithMatchingType(hlo->shape(), from) == 0) { @@ -206,17 +208,37 @@ Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( HloInstruction::CreateConvert(original_subshape, leaf)); })); + std::vector conversions_to_simplify; for (auto* user : materialized_users) { // If the user is a low-precision -> high-precision convert, we can replace // it with `hlo`, which has its input changed to high-precision. + // But we should not replace it immediately, it can lead to type mismatch in + // the below specific case: + // Op + // bf16 / \. + // Convert \ bf16 + // fp32 | / + // Tuple [fp32, bf16] + // If we run the float normalization pass and replace `Convert` at first, + // the result will be: + // Op + // fp32 | + // Convert + // bf16 / \ bf16 + // Tuple [fp32, bf16] + // So we should keep the 'Convert' and replace it after all of the other + // users has been replaced. if (user->opcode() == HloOpcode::kConvert && user->shape().element_type() == to && to == HighPrecisionType() && from == LowPrecisionType()) { - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + conversions_to_simplify.emplace_back(user); } else { TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); } } + for (auto* convert : conversions_to_simplify) { + TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(hlo)); + } if (is_root) { computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true); } @@ -270,6 +292,14 @@ Status FloatNormalizationVisitor::ConvertCalledComputations( return OkStatus(); } +// Returns true if the called computations of the instruction should not +// be touched by float normalization. In particular, we must not introduce +// float conversions into collective reductions. +bool ShouldAvoidNormalizingComputationsForInstruction(HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kAllReduce || + hlo->opcode() == HloOpcode::kReduceScatter; +} + Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector operand_types(hlo->operand_count()); std::vector output_types(hlo->operand_count()); @@ -335,7 +365,7 @@ Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { - if (comp->IsCollectiveCalledComputation()) { + if (ShouldAvoidNormalizingComputationsForInstruction(hlo)) { continue; } bool comp_has_low_precision = false; @@ -414,7 +444,7 @@ Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { - if (comp->IsCollectiveCalledComputation()) { + if (ShouldAvoidNormalizingComputationsForInstruction(hlo)) { continue; } bool comp_has_low_precision = false; @@ -544,22 +574,73 @@ Status FloatNormalizationVisitor::Preprocess(HloInstruction* hlo) { return OkStatus(); } +// We must avoid normalizing computations that have non-normalizing users +// (e.g., all-reduce's computations should not be normalized). If a +// computation is shared between normalizing and non-normalizing users, we will +// clone the computation for the non-normalizing users so that it can be left +// unchanged. This function clones the shared computations and returns the set +// of non-normalizing computations that must be skipped by the visitor. +absl::flat_hash_set +CloneComputationsForNonNormalizingInstructions( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::unique_ptr call_graph = + CallGraph::Build(module, execution_threads); + + absl::flat_hash_set computations_to_skip; + for (const CallGraphNode& node : call_graph->nodes()) { + bool has_normalizing_users = false; + bool has_users_to_skip_normalization = false; + for (const CallSite& site : node.caller_callsites()) { + if (ShouldAvoidNormalizingComputationsForInstruction( + site.instruction())) { + has_users_to_skip_normalization = true; + } else { + has_normalizing_users = true; + } + } + // If the computation is only used by normalizing users or only by + // non-normalizing users, then we do not clone. + if (!has_users_to_skip_normalization) { + continue; + } + if (!has_normalizing_users) { + computations_to_skip.insert(node.computation()); + continue; + } + // Otherwise, we create a clone and replace the normalizing instructions' + // computations with the clone. + HloComputation* clone = module->DeepCloneComputation(node.computation()); + for (const CallSite& site : node.caller_callsites()) { + if (ShouldAvoidNormalizingComputationsForInstruction( + site.instruction())) { + site.instruction()->ReplaceCalledComputations( + [&](HloComputation* called) { + return called == node.computation() ? clone : called; + }); + } + } + computations_to_skip.insert(clone); + } + return computations_to_skip; +} } // namespace -StatusOr FloatNormalization::Run( +absl::StatusOr FloatNormalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "FloatNormalization::Run() for " + primitive_util::LowercasePrimitiveTypeName( float_support_->LowPrecisionType()) + ", before:\n" + module->ToString()); + auto computations_to_visit = + module->MakeComputationPostOrder(execution_threads); + auto computations_to_skip = + CloneComputationsForNonNormalizingInstructions(module, execution_threads); + FloatNormalizationVisitor visitor(float_support_, this); - for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { - if (comp->IsCollectiveCalledComputation()) { - XLA_VLOG_LINES(2, "Skip processing collective called computation: " + - comp->ToString()); - continue; - } + for (auto* comp : computations_to_visit) { + if (computations_to_skip.contains(comp)) continue; TF_RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "FloatNormalization::Run() for " + diff --git a/xla/service/float_normalization.h b/xla/service/float_normalization.h index e23776315c38a..be0e92ea5ee56 100644 --- a/xla/service/float_normalization.h +++ b/xla/service/float_normalization.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class FloatNormalization : public HloModulePass { // Run float normalization on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -71,9 +71,9 @@ class BFloat16MixedPrecisionRemoval : public HloModulePass { // Run mixed precision removal on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { FloatNormalization normalization(&no_mixed_precision_support_); return normalization.Run(module, execution_threads); } diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index fb6c133a99789..973a8ad12669d 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,11 @@ limitations under the License. #include "xla/service/float_normalization.h" +#include #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -25,11 +27,14 @@ limitations under the License. #include "xla/service/float_support.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_verifier.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -120,7 +125,7 @@ class FloatNormalizationTest : public HloTestBase { PrimitiveType high_precision_type = F32) { TestFloatSupport float_support(low_precision_type, high_precision_type); FloatNormalization normalization(&float_support); - StatusOr result = normalization.Run(module); + absl::StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); HloVerifier verifier(/*layout_sensitive=*/false, @@ -528,7 +533,7 @@ class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest { high_precision_type); FloatNormalization normalization(&float_support); - StatusOr result = normalization.Run(module); + absl::StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); HloVerifier verifier(/*layout_sensitive=*/false, @@ -540,7 +545,7 @@ class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest { }; TEST_F(FloatNormalizationNoComputeSupportTest, - NoNormalizationForToApplyMultiOuputAllReduce) { + NoNormalizationForToApplyMultiOutputAllReduce) { auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( @@ -581,6 +586,67 @@ TEST_F(FloatNormalizationNoComputeSupportTest, EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); } +TEST_F(FloatNormalizationNoComputeSupportTest, + NormalizationClonesSharedApplyAllReduceAndReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + Shape bf16_shape_b = ShapeUtil::MakeShape(BF16, {2, 4, 2}); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape_b, "b")); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {}); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(2, bf16_scalar_shape, "init")); + + HloInstruction* all_reduce = builder.AddInstruction( + HloInstruction::CreateAllReduce(bf16_shape_a, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + + HloInstruction* reduce = builder.AddInstruction( + HloInstruction::CreateReduce(bf16_shape_a, b, init, {2}, reduction)); + builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape_a, HloOpcode::kAdd, all_reduce, reduce)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Verify that the shared computation was cloned, the all-reduce instruction + // got the unchanged bf16 add, while the reduction was promoted to f32 + // together with its called computation. + EXPECT_TRUE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(all_reduce->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(all_reduce->to_apply()->root_instruction()->opcode(), + HloOpcode::kAdd); + EXPECT_EQ(all_reduce->to_apply()->root_instruction()->shape().element_type(), + BF16); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0]->root_instruction()->opcode(), + HloOpcode::kConvert); + EXPECT_EQ(reduce->shape().element_type(), F32); +} + TEST_F(FloatNormalizationNoComputeSupportTest, NoNormalizationForToApplyAllReduce) { auto module = CreateNewVerifiedModule(); @@ -653,4 +719,37 @@ TEST_F(FloatNormalizationNoComputeSupportTest, EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); } +TEST_F(FloatNormalizationTest, ConvertBeforeTuple) { + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* convert = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, add)); + + builder.AddInstruction(HloInstruction::CreateVariadic( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), HloOpcode::kTuple, + {convert, add})); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get(), BF16)); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); + EXPECT_EQ(computation->root_instruction()->operand(0)->shape().element_type(), + F32); + EXPECT_EQ( + computation->root_instruction()->shape().tuple_shapes(0).element_type(), + F32); +} + } // namespace xla diff --git a/xla/service/float_support.cc b/xla/service/float_support.cc index 75fafbaa8b8cd..3bcbfdd7dcb14 100644 --- a/xla/service/float_support.cc +++ b/xla/service/float_support.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -87,6 +87,7 @@ bool FloatSupport::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kConcatenate: case HloOpcode::kConvert: diff --git a/xla/service/float_support.h b/xla/service/float_support.h index 4ee4e7157fa4c..9e4e35cabdb61 100644 --- a/xla/service/float_support.h +++ b/xla/service/float_support.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/fusion_constant_sinking.cc b/xla/service/fusion_constant_sinking.cc new file mode 100644 index 0000000000000..eca4745b993b4 --- /dev/null +++ b/xla/service/fusion_constant_sinking.cc @@ -0,0 +1,109 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/fusion_constant_sinking.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_dce.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Given the fusion instruction and the operand to the fusion, checks: +// 1. the operand is scalar and constant +// 2. the parameter instruction representing the operand is not used in any +// fusion instructions with a single operand. +// if the checks hold, it returns the parameter instruction representing the +// operand in the fusion computation, otherwise nullopt. +bool CanSink(HloInstruction* fusion, const HloInstruction* operand) { + if (!fusion->IsLoopFusion() && !fusion->IsOutputFusion()) { + return false; + } + + if (fusion->operand_count() == 1) { + return false; + } + + if (!ShapeUtil::IsScalar(operand->shape()) || !operand->IsConstant()) { + return false; + } + + int64_t operand_idx = fusion->operand_index(operand); + HloInstruction* fused_param = fusion->fused_parameter(operand_idx); + for (HloInstruction* user : fused_param->users()) { + // Fusions with single operands are not considered because the nested + // computation will be left without any parameters + if (user->opcode() == HloOpcode::kFusion && user->operand_count() == 1) { + return false; + } + } + return true; +} + +bool ProcessScalar(HloInstruction* scalar) { + if (!ShapeUtil::IsScalar(scalar->shape()) || !scalar->IsConstant()) { + return false; + } + bool processed = false; + std::vector sinkable_users; + for (HloInstruction* use : scalar->users()) { + if (CanSink(use, scalar)) { + sinkable_users.push_back(use); + } + } + for (HloInstruction* use : sinkable_users) { + HloInstruction* fused_scalar = use->FuseInstruction(scalar); + processed = true; + ProcessScalar(fused_scalar); + } + return processed; +} + +absl::StatusOr FusionConstantSinking::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(3) << "HLO module before FusionConstantSinking:"; + XLA_VLOG_LINES(3, module->ToString()); + + bool changed = false; + for (HloComputation* c : module->MakeNonfusionComputations()) { + for (HloInstruction* i : c->MakeInstructionPostOrder()) { + changed |= ProcessScalar(i); + } + } + + if (changed) { + TF_ASSIGN_OR_RETURN(bool dce, HloDCE{}.Run(module, execution_threads)); + changed |= dce; + } + + VLOG(3) << "HLO module after FusionConstantSinking:"; + XLA_VLOG_LINES(3, module->ToString()); + return changed; +} + +} // namespace xla diff --git a/xla/service/fusion_constant_sinking.h b/xla/service/fusion_constant_sinking.h new file mode 100644 index 0000000000000..636c785696477 --- /dev/null +++ b/xla/service/fusion_constant_sinking.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ +#define XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which sinks constants into fusion computations. +class FusionConstantSinking : public HloModulePass { + public: + absl::string_view name() const override { return "fusion_constant_sinking"; } + + // Run fusion constant sinking operations on the given module. Returns whether + // the module was changed (constant expressions folded). + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ diff --git a/xla/service/fusion_constant_sinking_test.cc b/xla/service/fusion_constant_sinking_test.cc new file mode 100644 index 0000000000000..d822f03bd46b9 --- /dev/null +++ b/xla/service/fusion_constant_sinking_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/fusion_constant_sinking.h" + +#include +#include + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using FusionConstantSinkingTest = HloTestBase; + +TEST_F(FusionConstantSinkingTest, SinkConstant) { + std::string hlo_string = R"( + HloModule SimpleLoop + + %fused_computation.slice (param_0.51117: s8[56,4096,4096], param_1: s32[]) -> s8[1,4096,4096] { + %param_0.51117 = s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + p1 = s32[]{:T(128)} parameter(1) + %constant.85694 = s32[]{:T(128)} constant(0) + ROOT %dynamic-slice.22040 = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} dynamic-slice(s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} %param_0.51117, s32[]{:T(128)} p1, s32[]{:T(128)} %constant.85694, s32[]{:T(128)} %constant.85694), dynamic_slice_sizes={1,4096,4096} + } + + ENTRY main { + p0 = s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + c = s32[]{:T(128)} constant(10) + ROOT out = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} fusion(s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} p0, s32[]{:T(128)} c), kind=kLoop, calls=%fused_computation.slice + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + FusionConstantSinking constant_sinking; + + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_sinking, module.get())); + + EXPECT_TRUE(result); + EXPECT_THAT( + module->GetComputationWithName("fused_computation.slice") + ->root_instruction(), + GmockMatch(match::DynamicSlice(match::Parameter(0), match::Constant(), + match::Constant(), match::Constant()))); +} + +TEST_F(FusionConstantSinkingTest, SingleOperandFusionNoSink) { + std::string hlo_string = R"( + HloModule SimpleLoop + + %fused_computation (param_1: s8[]) -> s8[1,4096,4096] { + param0 = s8[] parameter(0) + ROOT out = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} broadcast(param0), dimensions={} + } + + ENTRY main { + c = s8[]{:T(128)} constant(10) + ROOT out = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} fusion(s8[]{:T(128)} c), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + FusionConstantSinking constant_sinking; + + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_sinking, module.get())); + + EXPECT_FALSE(result); +} + +// Fusions with single operands are not considered because the nested +// computation will be left without any parameters +TEST_F(FusionConstantSinkingTest, SingleOperandUserNoSink) { + std::string hlo_string = R"( + HloModule SimpleLoop + + %fused_computation.inner (param_1: s32[]) -> s32[] { + p1 = s32[]{:T(128)} parameter(0) + %constant.85694 = s32[]{:T(128)} constant(10) + ROOT out = s32[] add(p1, %constant.85694) + } + + %fused_computation (param_0.51117: s32[4096,4096], param_1: + s32[]) -> s32[4096,4096] { + %param_0.51117 = s32[4096,4096]{1,0:T(8,128)(4,1)} parameter(0) + p1 = s32[]{:T(128)} parameter(1) + %inner.fusion = s32[] fusion(s32[]{:T(128)} p1), kind=kLoop, calls=%fused_computation.inner + %broadcast = s32[4096,4096]{1,0:T(8,128)(4,1)} broadcast(%inner.fusion), dimensions={} + ROOT out = s32[4096,4096] add(%broadcast, %param_0.51117) + } + + ENTRY main { + p0 = s32[4096,4096]{1,0:T(8,128)(4,1)} parameter(0) + c = s32[]{:T(128)} constant(10) + ROOT out = s32[4096,4096]{1,0:T(8,128)(4,1)} + fusion(s32[4096,4096]{1,0:T(8,128)(4,1)} p0, s32[]{:T(128)} c), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + FusionConstantSinking constant_sinking; + + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_sinking, module.get())); + + EXPECT_FALSE(result); +} + +TEST_F(FusionConstantSinkingTest, NonScalarNoSink) { + std::string hlo_string = R"( + HloModule SimpleLoop + + %fused_computation (param_1: s8[2], p1: s8[2,4096,4096]) -> s8[2,4096,4096] { + param0 = s8[2] parameter(0) + param1 = s8[2,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(1) + bcast = s8[2,4096,4096]{2,1,0:T(8,128)(4,1)} broadcast(param0), dimensions={0} + ROOT out = s8[2,4096,4096]{2,1,0:T(8,128)(4,1)} add(param1, bcast) + } + + ENTRY main { + p = s8[2,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + c = s8[2]{0:T(128)} constant({10,20}) + ROOT out = s8[2,4096,4096]{2,1,0:T(8,128)(4,1)} fusion(s8[2]{0:T(128)} c, p), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + FusionConstantSinking constant_sinking; + + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_sinking, module.get())); + + EXPECT_FALSE(result); +} + +TEST_F(FusionConstantSinkingTest, SinkConstantNested) { + std::string hlo_string = R"( + HloModule SimpleLoop + + %fused_computation.inner (param_0.51117: s8[56,4096,4096], param_1: + s32[]) -> s8[1,4096,4096] { + %param_0.51117 = s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + p1 = s32[]{:T(128)} parameter(1) + %constant.85694 = s32[]{:T(128)} constant(0) + + ROOT %dynamic-slice.22040 = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} + dynamic-slice(s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} %param_0.51117, + s32[]{:T(128)} p1, s32[]{:T(128)} %constant.85694, s32[]{:T(128)} + %constant.85694), dynamic_slice_sizes={1,4096,4096} + } + + %fused_computation (param_0.51117: s8[56,4096,4096], param_1: + s32[]) -> s8[4096,4096] { + %param_0.51117 = s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + p1 = s32[]{:T(128)} parameter(1) + + %inner.fusion = s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} fusion(s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} %param_0.51117, s32[]{:T(128)} p1), kind=kLoop, calls=%fused_computation.inner + + ROOT %bitcast = s8[4096,4096]{1,0:T(8,128)(4,1)} bitcast(s8[1,4096,4096]{2,1,0:T(8,128)(4,1)} %inner.fusion) + } + + ENTRY main { + p0 = s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} parameter(0) + c = s32[]{:T(128)} constant(10) + ROOT out = s8[4096,4096]{1,0:T(8,128)(4,1)} + fusion(s8[56,4096,4096]{2,1,0:T(8,128)(4,1)} p0, s32[]{:T(128)} c), + kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + FusionConstantSinking constant_sinking; + + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_sinking, module.get())); + + EXPECT_TRUE(result); + EXPECT_THAT( + module->GetComputationWithName("fused_computation")->num_parameters(), 1); + EXPECT_THAT(module->GetComputationWithName("fused_computation.inner") + ->num_parameters(), + 1); +} + +} // namespace +} // namespace xla diff --git a/xla/service/fusion_node_indexing_evaluation.cc b/xla/service/fusion_node_indexing_evaluation.cc index 7af17aba46d18..777094a505c3c 100644 --- a/xla/service/fusion_node_indexing_evaluation.cc +++ b/xla/service/fusion_node_indexing_evaluation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/fusion_node_indexing_evaluation.h b/xla/service/fusion_node_indexing_evaluation.h index 80e376b991a78..3132bbc575208 100644 --- a/xla/service/fusion_node_indexing_evaluation.h +++ b/xla/service/fusion_node_indexing_evaluation.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/fusion_node_indexing_evaluation_test.cc b/xla/service/fusion_node_indexing_evaluation_test.cc index 92dda5938a118..5c3790e60c41e 100644 --- a/xla/service/fusion_node_indexing_evaluation_test.cc +++ b/xla/service/fusion_node_indexing_evaluation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/fusion_queue.h b/xla/service/fusion_queue.h index fe458cc05c287..d9c9c1edb7cf6 100644 --- a/xla/service/fusion_queue.h +++ b/xla/service/fusion_queue.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gather_expander.cc b/xla/service/gather_expander.cc index 4dd3b69e17e15..6d9281f5410bc 100644 --- a/xla/service/gather_expander.cc +++ b/xla/service/gather_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ limitations under the License. namespace xla { namespace { -StatusOr TransposeIndexVectorDimToLast( +absl::StatusOr TransposeIndexVectorDimToLast( HloInstruction* start_indices, int64_t index_vector_dim) { const Shape& start_indices_shape = start_indices->shape(); @@ -55,7 +55,7 @@ StatusOr TransposeIndexVectorDimToLast( // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. -StatusOr CanonicalizeGatherIndices( +absl::StatusOr CanonicalizeGatherIndices( HloInstruction* start_indices, int64_t index_vector_dim) { // Transpose the non-index-vector dimensions to the front. TF_ASSIGN_OR_RETURN( @@ -85,7 +85,7 @@ StatusOr CanonicalizeGatherIndices( // Expands out or contracts away the gather dimensions in the accumulator // produced by the while loop. -StatusOr AdjustBatchDimsInAccumulator( +absl::StatusOr AdjustBatchDimsInAccumulator( const Shape& start_indices_shape, HloInstruction* accumulator, int64_t index_vector_dim) { std::vector batch_dim_bounds; @@ -109,7 +109,7 @@ StatusOr AdjustBatchDimsInAccumulator( // Expand an index vector from the start_indices tensor into a vector that can // be used to dynamic-slice out of the gather operand. -StatusOr ExpandIndexVectorIntoOperandSpace( +absl::StatusOr ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, int64_t operand_rank) { HloComputation* computation = index_vector->parent(); @@ -150,7 +150,7 @@ StatusOr ExpandIndexVectorIntoOperandSpace( // This generates the body of the while that implements the main data movement // behavior of gather using dynamic-slice and dynamic-update-slice. -StatusOr> GatherLoopBody( +absl::StatusOr> GatherLoopBody( const HloInstruction& gather, HloInstruction* induction_var, const std::vector& incoming_loop_state) { const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); @@ -227,7 +227,7 @@ StatusOr> GatherLoopBody( // New loop state -- only the accumulator has changed. The // WhileUtil::MakeCountedLoop functions takes care of the induction variable // and the while loop exit condition. - return StatusOr>{ + return absl::StatusOr>{ {operand, start_indices, updated_accumulator}}; } @@ -251,7 +251,7 @@ HloInstruction* CreateGatherLoopAccumulatorInitValue( // except that it has the dimensions in the wrong order -- the batch dimensions // are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. -StatusOr PermuteBatchAndOffsetDims( +absl::StatusOr PermuteBatchAndOffsetDims( HloInstruction* accumulator, absl::Span offset_dims, int64_t output_rank) { std::vector permutation; @@ -327,7 +327,7 @@ int64_t GatherIsBroadcast(HloInstruction* gather_instr) { // [3,1] out of operand into an accumulator of shape [4,3,1]. We then // reshape this result to [2,2,3] and finally transpose it to [2,3,2]. -StatusOr GatherExpander::ExpandInstruction( +absl::StatusOr GatherExpander::ExpandInstruction( HloInstruction* gather_instr) { CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); @@ -379,7 +379,7 @@ StatusOr GatherExpander::ExpandInstruction( gather_instr->gather_slice_sizes(), gather_loop_trip_count, gather_instr->gather_dimension_numbers()); - StatusOr> gather_loop_result_or_error = + absl::StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( computation, gather_loop_trip_count, {operand, canonical_start_indices, accumulator_init}, diff --git a/xla/service/gather_expander.h b/xla/service/gather_expander.h index 6c76cf330adec..8f43141c3119e 100644 --- a/xla/service/gather_expander.h +++ b/xla/service/gather_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class GatherExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* gather_inst) override; private: diff --git a/xla/service/gather_expander_test.cc b/xla/service/gather_expander_test.cc index 7eba9bb42174f..2a215559e281d 100644 --- a/xla/service/gather_expander_test.cc +++ b/xla/service/gather_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gather_scatter_utils.cc b/xla/service/gather_scatter_utils.cc index 1d188e7e53d80..223879d4936fe 100644 --- a/xla/service/gather_scatter_utils.cc +++ b/xla/service/gather_scatter_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ limitations under the License. namespace xla { -StatusOr TransformStartIndices(HloInstruction* indices, - int64_t index_vector_dim) { +absl::StatusOr TransformStartIndices( + HloInstruction* indices, int64_t index_vector_dim) { int64_t rank = indices->shape().rank(); if (index_vector_dim == rank) { // Add a size 1 dimension to the indices if the index_vector_dim is @@ -62,7 +62,7 @@ MakeOperandStartIndexPermutations(absl::Span dim_map, return {perm, InversePermutation(perm)}; } -StatusOr MaybeTranspose( +absl::StatusOr MaybeTranspose( HloInstruction* operand, absl::Span permutation) { if (IsIdentityPermutation(permutation)) { return operand; @@ -71,7 +71,7 @@ StatusOr MaybeTranspose( return result; } -StatusOr> MaybeTranspose( +absl::StatusOr> MaybeTranspose( absl::Span operands, const std::vector& operand_permutation) { std::vector result; @@ -83,8 +83,9 @@ StatusOr> MaybeTranspose( return result; } -StatusOr MoveDimensionToEnd(HloInstruction* operand, - size_t dimension, size_t rank) { +absl::StatusOr MoveDimensionToEnd(HloInstruction* operand, + size_t dimension, + size_t rank) { std::vector permutation; for (size_t i = 0; i < rank; ++i) { if (i != dimension) permutation.push_back(i); diff --git a/xla/service/gather_scatter_utils.h b/xla/service/gather_scatter_utils.h index fd14709327f80..3ce7eb43701b0 100644 --- a/xla/service/gather_scatter_utils.h +++ b/xla/service/gather_scatter_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,8 @@ namespace xla { // Example: // input: indices = tensor<4x2x3xi32>, index_vector_dim = 1 // output: tensor<12x2xi32> -StatusOr TransformStartIndices(HloInstruction* indices, - int64_t index_vector_dim); +absl::StatusOr TransformStartIndices(HloInstruction* indices, + int64_t index_vector_dim); // Given a map from index vector positions to dimension numbers, returns a pair // of permutations that when applied to the operand, let you replace the map @@ -40,17 +40,18 @@ StatusOr TransformStartIndices(HloInstruction* indices, std::pair, std::vector> MakeOperandStartIndexPermutations(absl::Span, int operand_rank); -StatusOr MaybeTranspose(HloInstruction* operand, - absl::Span permutation); +absl::StatusOr MaybeTranspose( + HloInstruction* operand, absl::Span permutation); -StatusOr> MaybeTranspose( +absl::StatusOr> MaybeTranspose( absl::Span operands, const std::vector& operand_permutation); // Moves the given dimension to the last dimension. // Example: MoveDimensionToEnd(tensor<1x2x3xi1>, 0): tensor<2x3x1xi1>. -StatusOr MoveDimensionToEnd(HloInstruction* operand, - size_t dimension, size_t rank); +absl::StatusOr MoveDimensionToEnd(HloInstruction* operand, + size_t dimension, + size_t rank); } // namespace xla diff --git a/xla/service/gather_simplifier.cc b/xla/service/gather_simplifier.cc index 56f5f170d9bc1..354d26b4026a6 100644 --- a/xla/service/gather_simplifier.cc +++ b/xla/service/gather_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ limitations under the License. namespace xla { -StatusOr GatherSimplifier::ExpandInstruction( +absl::StatusOr GatherSimplifier::ExpandInstruction( HloInstruction* inst) { auto* gather = DynCast(inst); @@ -125,19 +125,19 @@ StatusOr GatherSimplifier::ExpandInstruction( return MaybeTranspose(result, output_perm); } -bool GatherSimplifier::InstructionMatchesPattern(HloInstruction* inst) { - auto* gather = DynCast(inst); - if (!gather) { - return false; - } - +bool GatherSimplifier::IsSimplifiedGather(const HloGatherInstruction* gather) { auto* start_indices = gather->operands()[1]; const auto& dims = gather->gather_dimension_numbers(); - return start_indices->shape().rank() != 2 || dims.index_vector_dim() != 1 || - !IsIdentityPermutation(dims.start_index_map()) || - !dims.collapsed_slice_dims().empty() || - *dims.offset_dims().begin() != 1 || - *dims.offset_dims().rbegin() != dims.offset_dims().size(); + return start_indices->shape().rank() == 2 && dims.index_vector_dim() == 1 && + IsIdentityPermutation(dims.start_index_map()) && + dims.collapsed_slice_dims().empty() && + *dims.offset_dims().begin() == 1 && + *dims.offset_dims().rbegin() == dims.offset_dims().size(); +} + +bool GatherSimplifier::InstructionMatchesPattern(HloInstruction* inst) { + auto* gather = DynCast(inst); + return gather && !IsSimplifiedGather(gather); } } // namespace xla diff --git a/xla/service/gather_simplifier.h b/xla/service/gather_simplifier.h index f4cbd25d2143b..6d2b37502cd4c 100644 --- a/xla/service/gather_simplifier.h +++ b/xla/service/gather_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GATHER_SIMPLIFIER_H_ #define XLA_SERVICE_GATHER_SIMPLIFIER_H_ +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/op_expander_pass.h" namespace xla { @@ -36,10 +37,13 @@ class GatherSimplifier : public OpExpanderPass { public: absl::string_view name() const override { return "gather_simplifier"; } + static bool IsSimplifiedGather(const HloGatherInstruction* gather); + protected: bool InstructionMatchesPattern(HloInstruction* inst) override; - StatusOr ExpandInstruction(HloInstruction* inst) override; + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; }; } // namespace xla diff --git a/xla/service/gather_simplifier_test.cc b/xla/service/gather_simplifier_test.cc index 7087cc879c0d6..5f00fe83a01e2 100644 --- a/xla/service/gather_simplifier_test.cc +++ b/xla/service/gather_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/generate_test_hlo_checks.py b/xla/service/generate_test_hlo_checks.py index 5a78e0b805427..21527dd0aec0e 100755 --- a/xla/service/generate_test_hlo_checks.py +++ b/xla/service/generate_test_hlo_checks.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/service/generate_test_hlo_checks_test.py b/xla/service/generate_test_hlo_checks_test.py index 4174dd3e195eb..90181650d6a9b 100644 --- a/xla/service/generate_test_hlo_checks_test.py +++ b/xla/service/generate_test_hlo_checks_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/service/generic_transfer_manager.cc b/xla/service/generic_transfer_manager.cc index 52d9420aa34bb..6a0ab999eb9b1 100644 --- a/xla/service/generic_transfer_manager.cc +++ b/xla/service/generic_transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,108 +15,35 @@ limitations under the License. #include "xla/service/generic_transfer_manager.h" +#include #include +#include +#include #include -#include #include #include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { -namespace { - -// Transfer a memory block of the given size from the device source into the -// 'destination' buffer. -// -// size is the size to transfer to destination in bytes. -Status TransferBufferFromDevice(se::Stream* stream, - const se::DeviceMemoryBase& source, - int64_t size, void* destination) { - if (source.size() < size) { - return absl::FailedPreconditionError(absl::StrFormat( - "Source allocation on device not large enough for data transfer: " - "%d < %d", - source.size(), size)); - } - stream->ThenMemcpy(destination, source, size); - return OkStatus(); -} - -// Transfer a memory block of the given size from 'source' buffer to the given -// destination of the device. -// -// size is the size to transfer from source in bytes. -Status TransferBufferToDevice(se::Stream* stream, int64_t size, - const void* source, - se::DeviceMemoryBase* destination) { - if (destination->size() < size) { - return absl::FailedPreconditionError(absl::StrFormat( - "Destination allocation on device not large enough for data transfer: " - "%d < %d", - destination->size(), size)); - } - stream->ThenMemcpy(destination, source, size); - return OkStatus(); -} - -// Transfers a buffer of packed int4 values from the device to the host, then -// unpacks them on the host. 'source' is a buffer with (num_elements+1)/2 bytes -// where each byte stores two int4 values. 'destination' is a buffer with -// num_elements bytes, where a single int4 value will be written to each byte -// in the lower 4 bits. -Status TransferInt4ArrayFromDevice(se::Stream* stream, - const se::DeviceMemoryBase& source, - int64_t num_elements, void* destination) { - int64_t packed_size = (num_elements + 1) / 2; - auto packed_dst_data = std::make_unique>(packed_size); - TF_RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, - packed_dst_data->data())); - stream->ThenDoHostCallback([destination, num_elements, - moved_dst_data = std::move(packed_dst_data)]() { - UnpackInt4(*moved_dst_data, - absl::MakeSpan(static_cast(destination), num_elements)); - }); - return OkStatus(); -} - -// Packs an array of int4 values then transfers the packed buffer from the host -// to the device. 'source' is a buffer with num_elements bytes, where the lower -// 4 bits of each byte stores an int4 value. 'destination' is a buffer with -// (num_elements+1)/2 bytes, where two int4 values will be written into each -// byte. -Status TransferInt4ArrayToDevice(se::Stream* stream, int64_t num_elements, - const void* source, - se::DeviceMemoryBase* destination) { - auto packed_src_data = std::make_unique>( - CeilOfRatio(num_elements, int64_t{2})); - PackInt4(absl::MakeSpan(static_cast(source), num_elements), - absl::MakeSpan(*packed_src_data)); - TF_RETURN_IF_ERROR(TransferBufferToDevice( - stream, packed_src_data->size(), packed_src_data->data(), destination)); - // Ensure the buffer is transferred before we destroy it - stream->ThenDoHostCallback([keep_alive = std::move(packed_src_data)] {}); - return OkStatus(); -} - -} // namespace - GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) : platform_id_(platform_id), pointer_size_(pointer_size) {} @@ -138,9 +65,10 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, GetByteSizeRequirement(shape), element_pointers->data(), region)); // Ensure the buffer is transferred before we destroy element_pointers. - stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() { - /* holds reference to element_pointers in closure */ - }); + TF_RETURN_IF_ERROR( + stream->DoHostCallback([element_pointers{std::move(element_pointers)}]() { + /* holds reference to element_pointers in closure */ + })); return OkStatus(); } @@ -201,10 +129,13 @@ void GenericTransferManager::TransferLiteralFromDevice( if ((transfer_metadata != nullptr) && tensorflow::down_cast(transfer_metadata) ->callback_is_host_callback_safe) { - stream->ThenDoHostCallback([done = std::move(done), stream] { + auto status = stream->DoHostCallback([done = std::move(done), stream] { done(stream->ok() ? OkStatus() - : InternalError("`TransferLiteralFromDevice` failed")); + : Internal("`TransferLiteralFromDevice` failed")); }); + if (!status.ok()) { + LOG(ERROR) << "`DoHostCallback` failed: " << status; + } } else { done(stream->BlockHostUntilDone()); } @@ -262,7 +193,8 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( subliteral.Relayout(device_subshape.layout())); TF_RETURN_IF_ERROR(TransferBuffer(relaid_out->untyped_data())); // Ensure the buffer is transferred before we destroy it. - stream->ThenDoHostCallback([keep_alive = std::move(relaid_out)] {}); + TF_RETURN_IF_ERROR(stream->DoHostCallback( + [keep_alive = std::move(relaid_out)] {})); } } return OkStatus(); @@ -286,9 +218,61 @@ Status GenericTransferManager::ResetDevices( "Device reset is not yet supported on this platform (b/30481585)"); } +Status GenericTransferManager::TransferBufferFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size, + void* destination) { + if (source.size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Source allocation on device not large enough for data transfer: " + "%d < %d", + source.size(), size)); + } + return stream->Memcpy(destination, source, size); +} + +Status GenericTransferManager::TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) { + if (destination->size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Destination allocation on device not large enough for data transfer: " + "%d < %d", + destination->size(), size)); + } + return stream->Memcpy(destination, source, size); +} + +Status GenericTransferManager::TransferInt4ArrayFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, + int64_t num_elements, void* destination) { + int64_t packed_size = (num_elements + 1) / 2; + auto packed_dst_data = std::make_unique>(packed_size); + TF_RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, + packed_dst_data->data())); + TF_RETURN_IF_ERROR(stream->DoHostCallback([destination, num_elements, + packed_dst_data = + std::move(packed_dst_data)]() { + UnpackInt4(*packed_dst_data, + absl::MakeSpan(static_cast(destination), num_elements)); + })); + return OkStatus(); +} + +Status GenericTransferManager::TransferInt4ArrayToDevice( + se::Stream* stream, int64_t num_elements, const void* source, + se::DeviceMemoryBase* destination) { + auto packed_src_data = std::make_unique>( + CeilOfRatio(num_elements, int64_t{2})); + PackInt4(absl::MakeSpan(static_cast(source), num_elements), + absl::MakeSpan(*packed_src_data)); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, packed_src_data->size(), packed_src_data->data(), destination)); + return stream->DoHostCallback([keep_alive = std::move(packed_src_data)] {}); +} + int64_t GenericTransferManager::GetByteSizeRequirement( const Shape& shape) const { - if (shape.is_static() || shape.IsTuple()) { + if (shape.IsTuple() || shape.is_static()) { return ShapeUtil::ByteSizeOf(shape, pointer_size_); } int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size(); @@ -304,4 +288,5 @@ Shape GenericTransferManager::HostShapeToDeviceShape( } return device_shape; } + } // namespace xla diff --git a/xla/service/generic_transfer_manager.h b/xla/service/generic_transfer_manager.h index 9fefa33924fbd..c80d89187073e 100644 --- a/xla/service/generic_transfer_manager.h +++ b/xla/service/generic_transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,24 @@ limitations under the License. #ifndef XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ #define XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_ -#include - +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -54,6 +68,7 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, MutableBorrowingLiteral literal) override; @@ -68,11 +83,41 @@ class GenericTransferManager : public TransferManager { Shape HostShapeToDeviceShape(const Shape& host_shape) const override; private: - // Returns whether subbyte types (types less than 1 byte, e.g. U4) should - // have multiple values packed into a single byte on the device. Subbyte - // bytes are never packed on the host. By default, returns false, so a byte - // can only hold one value, but subclasses can override this. - virtual bool PackSubbyteTypes() const { return false; } + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t size, void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size, + const void* source, + se::DeviceMemoryBase* destination); + + // Transfers a buffer of packed int4 values from the device to the host, then + // unpacks them on the host. 'source' is a buffer with (num_elements+1)/2 + // bytes where each byte stores two int4 values. 'destination' is a buffer + // with num_elements bytes, where a single int4 value will be written to each + // byte in the lower 4 bits. + virtual Status TransferInt4ArrayFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t num_elements, + void* destination); + + // Packs an array of int4 values then transfers the packed buffer from the + // host to the device. 'source' is a buffer with num_elements bytes, where the + // lower 4 bits of each byte stores an int4 value. 'destination' is a buffer + // with (num_elements+1)/2 bytes, where two int4 values will be written into + // each byte. + virtual Status TransferInt4ArrayToDevice(se::Stream* stream, + int64_t num_elements, + const void* source, + se::DeviceMemoryBase* destination); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/xla/service/generic_transfer_manager_test.cc b/xla/service/generic_transfer_manager_test.cc index 107ad62c83935..05eda50a0cdcd 100644 --- a/xla/service/generic_transfer_manager_test.cc +++ b/xla/service/generic_transfer_manager_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/literal_test_util.h" #include "xla/types.h" @@ -58,11 +59,9 @@ class GenericTransferManagerTest : public ::testing::Test { void SetUp() override { TF_ASSERT_OK_AND_ASSIGN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId)); TF_ASSERT_OK_AND_ASSIGN(stream_executor_, platform->ExecutorForDevice(0)); - stream_.emplace(stream_executor_); - stream_->Init(); - ASSERT_TRUE(stream_->ok()); + TF_ASSERT_OK_AND_ASSIGN(stream_, stream_executor_->CreateStream()); } ScopedShapedBuffer AllocateBuffer(const Shape& shape) { @@ -74,14 +73,14 @@ class GenericTransferManagerTest : public ::testing::Test { PackingTransferManager transfer_manager_; se::StreamExecutor* stream_executor_; - std::optional stream_; + std::unique_ptr stream_; }; TEST_F(GenericTransferManagerTest, TransferLiteralToDevice) { ScopedShapedBuffer buffer = AllocateBuffer(ShapeUtil::MakeShape(U16, {2, 2})); Literal literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}}); - TF_ASSERT_OK(transfer_manager_.TransferLiteralToDevice(&stream_.value(), - literal, buffer)); + TF_ASSERT_OK(transfer_manager_.TransferLiteralToDevice(stream_.get(), literal, + buffer)); se::DeviceMemoryBase device_mem = buffer.buffers().element({}); uint16_t* device_ptr = static_cast(device_mem.opaque()); @@ -114,7 +113,7 @@ TEST_F(GenericTransferManagerTest, TransferLiteralToDeviceInt4) { transfer_manager_.pack_subbyte_types_ = pack; ScopedShapedBuffer buffer = AllocateBuffer(ShapeUtil::MakeShape(S4, {2, 2})); - TF_ASSERT_OK(transfer_manager_.TransferLiteralToDevice(&stream_.value(), + TF_ASSERT_OK(transfer_manager_.TransferLiteralToDevice(stream_.get(), literal, buffer)); se::DeviceMemoryBase device_mem = buffer.buffers().element({}); ASSERT_EQ(device_mem.size(), pack ? 2 : 4); @@ -141,7 +140,7 @@ TEST_F(GenericTransferManagerTest, TransferLiteralFromDevice) { TF_ASSERT_OK_AND_ASSIGN( Literal literal, transfer_manager_.TransferManager::TransferLiteralFromDevice( - &stream_.value(), buffer)); + stream_.get(), buffer)); EXPECT_TRUE(LiteralTestUtil::Equal( literal, LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); } @@ -170,7 +169,7 @@ TEST_F(GenericTransferManagerTest, TransferLiteralFromDeviceInt4) { TF_ASSERT_OK_AND_ASSIGN( Literal literal, transfer_manager_.TransferManager::TransferLiteralFromDevice( - &stream_.value(), buffer)); + stream_.get(), buffer)); EXPECT_TRUE(LiteralTestUtil::Equal( literal, LiteralUtil::CreateR2({{s4{1}, s4{-2}}, {s4{-3}, s4{4}}}))); diff --git a/xla/service/global_device_id.cc b/xla/service/global_device_id.cc index 8698912f8abbc..81632784860bf 100644 --- a/xla/service/global_device_id.cc +++ b/xla/service/global_device_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/global_device_id.h b/xla/service/global_device_id.h index cc843a5aee327..78f4c0a3dc914 100644 --- a/xla/service/global_device_id.h +++ b/xla/service/global_device_id.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GLOBAL_DEVICE_ID_H_ #define XLA_SERVICE_GLOBAL_DEVICE_ID_H_ +#include #include #include "absl/types/span.h" -#include "xla/types.h" #include "tsl/lib/gtl/int_type.h" namespace xla { diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index cbd8daa3b59dd..4f9a6096c2bc8 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1,26 +1,22 @@ # Description: # GPU-specific components in XLA service implementation. -load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -load("//xla:xla.bzl", "xla_cc_test", "xla_cub_deps", "xla_export_hlo_deps") -load( - "//xla/service/gpu:build_defs.bzl", - "build_cub_sort_kernels", - "get_cub_sort_kernel_types", -) -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", "if_rocm_is_configured", "rocm_copts", ) -load("@tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library") +load( + "@tsl//tsl:tsl.bzl", + "if_google", + "if_nccl", + "internal_visibility", + "tsl_copts", + "tsl_gpu_library", +) load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@tsl//tsl/platform:build_config.bzl", @@ -36,10 +32,22 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:xla.bzl", "xla_cc_test", "xla_cub_deps", "xla_export_hlo_deps") +load( + "//xla/service/gpu:build_defs.bzl", + "build_cub_sort_kernels", + "get_cub_sort_kernel_types", + "gpu_kernel_library", +) +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -76,9 +84,13 @@ xla_cc_test( srcs = ["backend_configs_test.cc"], deps = [ ":backend_configs_cc", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -89,12 +101,19 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ + ":nccl_clique_key", + "//xla:executable_run_options", "//xla:status_macros", "//xla:statusor", "//xla/service:executable", "//xla/service:global_device_id", "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -103,6 +122,19 @@ cc_library( hdrs = ["gpu_constants.h"], ) +cc_library( + name = "gpu_memory_space_assignment", + hdrs = ["gpu_memory_space_assignment.h"], + deps = [ + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_ordering", + "//xla/service:hlo_value", + ], +) + cc_library( name = "launch_dimensions", srcs = [ @@ -117,8 +149,11 @@ cc_library( "//xla:statusor", "//xla:util", "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -129,6 +164,7 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ "//xla:debug_options_flags", + "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:test_helpers", @@ -136,25 +172,23 @@ xla_cc_test( "//xla/client/lib:constants", "//xla/ffi", "//xla/ffi:ffi_api", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/runtime:module", - "//xla/runtime:module_registry", - "//xla/runtime/ffi:ffi_api", + "//xla/hlo/ir:hlo", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:gpu_plugin", - "//xla/service/gpu/runtime:custom_call_registry", - "//xla/service/gpu/runtime:support", + "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -165,16 +199,20 @@ xla_cc_test( xla_cc_test( name = "gpu_copy_insertion_test", - srcs = if_gpu_is_configured(["gpu_copy_insertion_test.cc"]), + srcs = ["gpu_copy_insertion_test.cc"], tags = tf_cuda_tests_tags(), - deps = if_gpu_is_configured([ + deps = [ ":buffer_sharing", "//xla:test", + "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/service:copy_insertion", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - ]), + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:statusor", + ], ) cc_library( @@ -184,17 +222,22 @@ cc_library( deps = [ ":buffer_allocations", ":ir_emission_utils", + "//xla:shape_tree", + "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:buffer_assignment_util", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:tuple_ops", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:logging", ], ) @@ -207,13 +250,16 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", "@tsl//tsl/platform:logging", ], @@ -226,6 +272,7 @@ xla_cc_test( ":target_util", "//xla/tests:xla_internal_test_main", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:test", ], ) @@ -249,11 +296,19 @@ cc_library( ":gpu_constants", ":gpu_executable", ":ir_emission_utils", + ":kernel_reuse_cache", + "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:name_uniquer", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:nccl_collective_thunk", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", ], @@ -270,16 +325,13 @@ cc_library( ]), deps = [ ":backend_configs_cc", - ":gemm_thunk", + ":cublas_cudnn", ":gpu_asm_opts_util", - ":gpu_constants", ":gpu_conv_runner", - ":gpu_executable", + ":gpu_flash_attn", ":gpu_fused_mha_runner", - ":gpu_fusible", ":gpu_norm_runner", ":hlo_fusion_analysis", - ":hlo_to_ir_bindings", ":ir_emission_utils", ":ir_emitter", ":ir_emitter_context", @@ -287,14 +339,10 @@ cc_library( ":kernel_reuse_cache", ":launch_dimensions", ":matmul_utils", - ":nccl_collective_thunks", ":parallel_loop_emitter", - ":reduction_utils", - ":target_util", - ":thunk", + ":triton_call", "//xla:autotuning_proto_cc", "//xla:literal", - "//xla:permutation_util", "//xla:shape_util", "//xla:status", "//xla:status_macros", @@ -304,44 +352,59 @@ cc_library( "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", "//xla/mlir_hlo:transforms_gpu_passes", "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", + "//xla/service:global_device_id", "//xla/service:name_uniquer", "//xla/service/gpu/fusions", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions:input_slices", - "//xla/service/gpu/fusions:loop", - "//xla/service/gpu/fusions:reduction", "//xla/service/gpu/fusions:thunk_util", - "//xla/service/gpu/fusions:tiling_util", - "//xla/service/gpu/fusions:transpose", - "//xla/service/gpu/kernels:custom_fusion", "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/runtime3:command_buffer_cmd", - "//xla/service/gpu/runtime3:command_buffer_cmd_emitter", - "//xla/service/gpu/runtime3:command_buffer_thunk", - "//xla/service/gpu/runtime3:custom_call_thunk", - "//xla/service/gpu/runtime3:fft_thunk", + "//xla/service/gpu/kernels:topk_custom_kernel", + "//xla/service/gpu/runtime:command_buffer_cmd", + "//xla/service/gpu/runtime:command_buffer_cmd_emitter", + "//xla/service/gpu/runtime:command_buffer_thunk", + "//xla/service/gpu/runtime:conditional_thunk", + "//xla/service/gpu/runtime:convolution_thunk", + "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:custom_call_thunk", + "//xla/service/gpu/runtime:fft_thunk", + "//xla/service/gpu/runtime:flash_attn_thunk", + "//xla/service/gpu/runtime:fused_mha_thunk", + "//xla/service/gpu/runtime:gemm_thunk", + "//xla/service/gpu/runtime:infeed_thunk", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:nccl_all_gather_thunk", + "//xla/service/gpu/runtime:nccl_all_reduce_thunk", + "//xla/service/gpu/runtime:nccl_all_to_all_thunk", + "//xla/service/gpu/runtime:nccl_api", + "//xla/service/gpu/runtime:nccl_collective_broadcast_thunk", + "//xla/service/gpu/runtime:nccl_collective_permute_thunk", + "//xla/service/gpu/runtime:nccl_collective_thunk", + "//xla/service/gpu/runtime:nccl_recv_thunk", + "//xla/service/gpu/runtime:nccl_send_thunk", + "//xla/service/gpu/runtime:norm_thunk", + "//xla/service/gpu/runtime:outfeed_thunk", + "//xla/service/gpu/runtime:replica_id_thunk", + "//xla/service/gpu/runtime:send_recv_thunk", + "//xla/service/gpu/runtime:sequential_thunk", + "//xla/service/gpu/runtime:thunk", + "//xla/service/gpu/runtime:wait_for_streams_thunk", + "//xla/service/gpu/runtime:while_thunk", "//xla/service/llvm_ir:buffer_assignment_util", - "//xla/service/llvm_ir:dynamic_update_slice_util", - "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", + "//xla/service/llvm_ir:loop_emitter", "//xla/service/llvm_ir:sort_util", "//xla/stream_executor:device_description", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/translate/mhlo_to_lhlo_with_xla", - "@com_google_absl//absl/algorithm:container", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/stream_executor/integrations:device_mem_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -354,30 +417,30 @@ cc_library( "@llvm-project//llvm:Linker", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", + "@triton//:TritonDialects", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:human_readable_json", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/protobuf:dnn_proto_cc", ] + if_gpu_is_configured([ - ":cub_sort_thunk", - ":gpublas_lt_matmul_thunk", ":ir_emitter_triton", - "//xla/service/gpu/runtime3:cholesky_thunk", - "//xla/service/gpu/runtime3:triangular_solve_thunk", + "//xla/service/gpu/runtime:cholesky_thunk", + "//xla/service/gpu/runtime:cub_sort_thunk", + "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk", + "//xla/service/gpu/runtime:triangular_solve_thunk", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -392,12 +455,10 @@ cc_library( "elemental_ir_emitter.h", "ir_emitter.h", "ir_emitter_nested.h", - "kernel_mapping_scheme.h", ], copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", - ":hlo_fusion_analysis", ":hlo_to_ir_bindings", ":ir_emission_utils", ":ir_emitter_context", @@ -405,6 +466,7 @@ cc_library( ":target_util", "//xla:literal", "//xla:shape_util", + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", @@ -416,16 +478,25 @@ cc_library( "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:ir_builder_mixin", "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/service/llvm_ir:math_ops", "//xla/service/llvm_ir:tuple_ops", - "@com_google_absl//absl/container:inlined_vector", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -433,6 +504,8 @@ cc_library( name = "ir_emitter_triton", srcs = if_cuda_is_configured(["ir_emitter_triton.cc"]) + if_rocm_hipblaslt([ "ir_emitter_triton.cc", + ]) + if_cuda_is_configured(["ir_emitter_triton_cuda.cc"]) + if_rocm_is_configured([ + "ir_emitter_triton_rocm.cc", ]), hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]), deps = [ @@ -445,6 +518,7 @@ cc_library( ":triton_tiling_propagation", "//xla:autotuning_proto_cc", "//xla:comparison_util", + "//xla:debug_options_flags", "//xla:literal", "//xla:shape_util", "//xla:status", @@ -456,10 +530,19 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service:algorithm_util", "//xla/service:dump", + "//xla/service:hlo_module_config", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", "//xla/translate/hlo_to_mhlo:hlo_module_importer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -467,6 +550,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", @@ -475,9 +559,12 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -488,6 +575,7 @@ cc_library( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:Support", @@ -495,6 +583,7 @@ cc_library( "@llvm-project//mlir:Transforms", "@triton//:TritonDialects", "@triton//:TritonTransforms", + "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", @@ -502,36 +591,36 @@ cc_library( "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:tensor_float_32_utils", ] + if_cuda_is_configured([ - "@triton//:NVGPUToLLVM", - "@triton//:TritonGPUToLLVM", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", "@triton//:TritonGPUTransforms", "@triton//:TritonNvidiaGPUTransforms", "@triton//:TritonLLVMIR", "@triton//:TritonToTritonGPU", + "@triton//:TritonGPUToLLVM", ]), ) xla_test( name = "ir_emitter_triton_test", srcs = if_cuda_is_configured(["ir_emitter_triton_test.cc"]), - backend_tags = {"gpu": [ - "requires-gpu-sm70", - ]}, backends = [ - "gpu", + "gpu_a100", + "gpu_h100", ], shard_count = 20, tags = ["nomac"], deps = [ ":backend_configs_cc", ":gpu_device_info_for_tests", - ":ir_emission_utils", ":ir_emitter_triton", ":matmul_utils", + ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:status_macros", - "//xla:statusor", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", @@ -542,13 +631,15 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", @@ -564,10 +655,11 @@ xla_test( name = "ir_emitter_triton_large_test", srcs = if_cuda_is_configured(["ir_emitter_triton_large_test.cc"]), backend_tags = {"gpu": [ - "requires-gpu-sm70", + "requires-gpu-sm80", ]}, backends = [ - "gpu", + "gpu_a100", + "gpu_h100", ], tags = [ "large", @@ -581,6 +673,7 @@ xla_test( "//xla/service/gpu/tests:gpu_codegen_test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest", ], ) @@ -588,11 +681,8 @@ xla_test( xla_test( name = "ir_emitter_triton_parametrized_test", srcs = if_cuda_is_configured(["ir_emitter_triton_parametrized_test.cc"]), - backend_tags = {"gpu": [ - "requires-gpu-sm70", - ]}, backends = [ - "gpu", + "gpu_a100", ], shard_count = 10, tags = ["nomac"], @@ -614,9 +704,9 @@ xla_test( ) cc_library( - name = "triton_autotuner", - srcs = if_cuda_is_configured(["triton_autotuner.cc"]), - hdrs = if_cuda_is_configured(["triton_autotuner.h"]), + name = "gemm_fusion_autotuner", + srcs = if_cuda_is_configured(["gemm_fusion_autotuner.cc"]), + hdrs = if_cuda_is_configured(["gemm_fusion_autotuner.h"]), deps = if_cuda_is_configured([ ":autotuner_compile_util", ":autotuner_util", @@ -630,6 +720,7 @@ cc_library( ":matmul_utils", ":split_k_gemm_rewriter", ":stream_executor_util", + ":cudnn_fusion_compiler", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -638,17 +729,23 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", + "//xla/tools:hlo_decomposer_lib", "//xla:status", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:algorithm_util", "//xla/service:dump", "//xla/service:executable", "//xla/service:float_normalization", @@ -665,15 +762,15 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - "@tsl//tsl/util/proto:proto_utils", + "//xla/tsl/util/proto:proto_utils", ]), ) xla_test( - name = "triton_autotuner_test", - srcs = if_cuda_is_configured(["triton_autotuner_test.cc"]), + name = "gemm_fusion_autotuner_test", + srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]), backend_tags = {"gpu": [ - "requires-gpu-sm70", + "requires-gpu-sm80", ]}, backends = [ "gpu", @@ -684,12 +781,12 @@ xla_test( deps = [ ":autotuner_util", ":backend_configs_cc", - ":gemm_rewriter_triton", + ":gemm_fusion", + ":gemm_fusion_autotuner", + ":ir_emission_utils", ":matmul_utils", - ":triton_autotuner", "//xla:autotuning_proto_cc", "//xla:error_spec", - "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", @@ -704,6 +801,7 @@ xla_test( "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", @@ -718,6 +816,21 @@ xla_test( ], ) +cc_library( + name = "triton_call", + srcs = if_gpu_is_configured(["triton_call.cc"]), + hdrs = ["triton_call.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "parallel_loop_emitter", srcs = ["parallel_loop_emitter.cc"], @@ -733,6 +846,10 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -744,205 +861,225 @@ cc_library( srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ - ":gpu_constants", "//xla:status", - "//xla:status_macros", "//xla:statusor", - "//xla:types", "//xla:util", "//xla/service:buffer_assignment", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/gtl:map_util", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", ], ) cc_library( - name = "thunk", - srcs = ["thunk.cc"], - hdrs = ["thunk.h"], + name = "nccl_clique_key", + srcs = ["nccl_clique_key.cc"], + hdrs = ["nccl_clique_key.h"], + compatible_with = get_compatible_with_portable(), deps = [ - ":buffer_allocations", - ":gpu_executable_run_options", - "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/stream_executor", - "//xla/translate/mhlo_to_hlo:location_exporter", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@tsl//tsl/platform:status", ], ) -tsl_gpu_library( - name = "nccl_collective_thunks", - srcs = [ - "nccl_all_gather_thunk.cc", - "nccl_all_reduce_thunk.cc", - "nccl_all_to_all_thunk.cc", - "nccl_collective_permute_thunk.cc", - "nccl_collective_thunk.cc", - "nccl_p2p_thunk_common.cc", - "nccl_recv_thunk.cc", - "nccl_send_thunk.cc", - ], - hdrs = [ - "nccl_all_gather_thunk.h", - "nccl_all_reduce_thunk.h", - "nccl_all_to_all_thunk.h", - "nccl_collective_permute_thunk.h", - "nccl_collective_thunk.h", - "nccl_p2p_thunk_common.h", - "nccl_recv_thunk.h", - "nccl_send_thunk.h", - ], - # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. - compatible_with = [], +xla_cc_test( + name = "nccl_clique_key_test", + srcs = ["nccl_clique_key_test.cc"], deps = [ - ":buffer_allocations", - ":ir_emission_utils", - ":nccl_utils", - ":thunk", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/service:buffer_assignment", - "//xla/service:collective_ops_utils", + ":nccl_clique_key", "//xla/service:global_device_id", - "//xla/service:hlo_parser", - "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor/gpu:gpu_activation", - "//xla/stream_executor/gpu:gpu_activation_header", - "//xla/stream_executor/gpu:gpu_stream", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "//xla/translate/mhlo_to_hlo:type_to_shape", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@llvm-project//mlir:IR", - "@tsl//tsl/platform:logging", + "@com_google_absl//absl/container:btree", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", ], ) -# Empty library to implement nested dependency conditions. -cc_library(name = "empty") - -# If NCCL/RCCL is supported, this target '#defines XLA_ENABLE_XCCL' and -# provides a header which #includes NCCL/RCCL. -alias( - name = "nccl_utils", - actual = if_nccl(":_nccl_utils", ":empty"), -) - -# Do not depend on this target, but rather depend on :nccl_utils. -tsl_gpu_library( - name = "_nccl_utils", - srcs = if_gpu_is_configured(["nccl_utils.cc"]), - hdrs = if_gpu_is_configured(["nccl_utils.h"]), - # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. - compatible_with = [], - defines = if_gpu_is_configured(["XLA_ENABLE_XCCL"]), - tags = ["manual"], # Only builds with if_nccl(). - deps = if_gpu_is_configured([ - ":gpu_executable_run_options", - ":thunk", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "//xla:debug_options_flags", - "//xla:status", - "//xla:status_macros", - "//xla:statusor", - "//xla:xla_data_proto_cc", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "//xla/service:rendezvous", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:env", - ]) + if_cuda_is_configured([ - "@local_config_nccl//:nccl", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rccl", - ]), -) - -# TODO(b/244780257): Remove this config. -bool_flag( - name = "enable_xlir", - build_setting_default = if_google(True, False), +cuda_library( + name = "sleep_kernel", + srcs = if_cuda_is_configured(["sleep_kernel.cu.cc"]), + hdrs = if_cuda_is_configured(["sleep_kernel.h"]), + deps = ["@local_config_cuda//cuda:cuda_headers"], ) cc_library( - name = "non_atomically_upgradeable_rw_lock", - srcs = [], - hdrs = [ - "non_atomically_upgradeable_rw_lock.h", - ], + name = "mock_nccl_xml_google", + srcs = ["mock_nccl_xml.cc"], + hdrs = ["mock_nccl_xml.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = ["manual"], + visibility = ["//visibility:private"], deps = [ - "@com_google_absl//absl/synchronization", - ], + "//xla:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:regexp", + ] + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + ]), ) xla_cc_test( - name = "non_atomically_upgradeable_rw_lock_test", - srcs = ["non_atomically_upgradeable_rw_lock_test.cc"], + name = "mock_nccl_xml_test", + size = "small", + srcs = if_google(["mock_nccl_xml_test.cc"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_cuda_tests_tags(), deps = [ - ":non_atomically_upgradeable_rw_lock", + "//xla:status", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:test", - ], + ] + if_google([ + ":mock_nccl_xml_google", + ]) + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + ]), +) + +# Empty library to implement nested dependency conditions. +cc_library( + name = "empty", + compatible_with = get_compatible_with_portable(), +) + +alias( + name = "mock_nccl_utils", + actual = if_nccl( + if_google(":_mock_nccl_utils_google", ":_mock_nccl_utils_default"), + ":empty", + ), +) + +# Do not build mock_nccl_utils.cc in OSS. It uses the nccl internal cost model to estimate the +# communication time of nccl collective calls. Only build it in Google builds. +# TODO(b/306073484): Stub out the cost model api used in mock nccl functions. +cc_library( + name = "_mock_nccl_utils_google", + srcs = if_cuda_is_configured(["mock_nccl_utils.cc"]), + hdrs = if_cuda_is_configured([ + "mock_nccl_utils.h", + "mock_nccl_topo_config.h", + ]), + # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. + compatible_with = [], + tags = ["manual"], + visibility = ["//visibility:private"], + deps = if_cuda_is_configured([ + ":gpu_executable_run_options", + ":mock_nccl_xml_google", + ":nccl_clique_key", + ":sleep_kernel", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_nccl//:nccl", + "//xla:debug_options_flags", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:util", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service:rendezvous", + "//xla/service:lockable", + "//xla/service/gpu/runtime:nccl_api", + "//xla/service/gpu/runtime:nccl_clique", + "//xla/service/gpu/runtime:nccl_collective_thunk", + "//xla/service/gpu/runtime:nccl_p2p_thunk_common", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_activation", + "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_types_header", + "@tsl//tsl/platform:env", + "@tsl//tsl/lib/gtl:int_type", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ]), +) + +cc_library( + name = "_mock_nccl_utils_default", + srcs = if_gpu_is_configured(["mock_nccl_utils_default.cc"]), + hdrs = if_gpu_is_configured(["mock_nccl_utils.h"]), + # Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. + compatible_with = [], + tags = ["manual"], + visibility = ["//visibility:private"], + deps = if_gpu_is_configured([ + ":gpu_executable_run_options", + ":nccl_clique_key", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "//xla:executable_run_options", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/service/gpu/runtime:nccl_api", + "//xla/service/gpu/runtime:nccl_clique", + "//xla/service/gpu/runtime:nccl_collective_thunk", + "//xla/service/gpu/runtime:nccl_p2p_thunk_common", + "//xla/service/gpu/runtime:thunk", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service:lockable", + "//xla/stream_executor", + "@tsl//tsl/lib/gtl:int_type", + ]) + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + ]), +) + +# TODO(b/244780257): Remove this config. +bool_flag( + name = "enable_xlir", + build_setting_default = if_google(True, False), ) cc_library( name = "gpu_executable", srcs = [ - "conditional_thunk.cc", - "convolution_thunk.cc", - "copy_thunk.cc", - "for_thunk.cc", - "fused_mha_thunk.cc", "gpu_executable.cc", - "infeed_thunk.cc", - "kernel_thunk.cc", - "memset_thunk.cc", - "norm_thunk.cc", - "outfeed_thunk.cc", - "replica_id_thunk.cc", - "sequential_thunk.cc", - "while_thunk.cc", ], hdrs = [ - "conditional_thunk.h", - "convolution_thunk.h", - "copy_thunk.h", - "for_thunk.h", - "fused_mha_thunk.h", - "gemm_thunk.h", "gpu_executable.h", - "infeed_thunk.h", - "kernel_thunk.h", - "memset_thunk.h", - "norm_thunk.h", - "outfeed_thunk.h", - "replica_id_thunk.h", - "sequential_thunk.h", - "while_thunk.h", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", @@ -950,95 +1087,71 @@ cc_library( deps = [ ":backend_configs_cc", ":buffer_allocations", - ":cusolver_context", - ":gemm_thunk", - ":gpu_asm_opts_util", ":gpu_constants", - ":gpu_conv_runner", ":gpu_executable_run_options", - ":gpu_fused_mha_runner", - ":gpu_norm_runner", - ":io_feed_manager", ":ir_emission_utils", - ":kernel_arguments", - ":launch_dimensions", - ":matmul_utils", - ":nccl_collective_thunks", - ":non_atomically_upgradeable_rw_lock", + ":nccl_clique_key", ":stream_executor_util", - ":thunk", - "//xla:array2d", - "//xla:literal", - "//xla:refcounting_hash_map", + "//xla:executable_run_options", "//xla:shape_tree", "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", - "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/mlir/runtime/ir:rt", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/mlir/runtime/transforms:type_converter", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/runtime:executable", "//xla/service:buffer_assignment", - "//xla/service:custom_call_status_internal", "//xla/service:executable", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_execution_profile", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", + "//xla/service:hlo_value", + "//xla/service:maybe_owning_device_memory", + "//xla/service:rendezvous", "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:xla_debug_info_manager", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/runtime:executable", - "//xla/service/gpu/runtime:support", - "//xla/service/gpu/runtime3:custom_call_thunk", - "//xla/service/gpu/runtime3:fft_thunk", + "//xla/service/gpu/runtime:annotation", + "//xla/service/gpu/runtime:nccl_clique", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", - "//xla/stream_executor:blas", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/gpu:asm_compiler", "//xla/stream_executor/gpu:gpu_activation", - "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/rocm:rocm_platform_id", - "//xla/translate/mhlo_to_hlo:location_exporter", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", + "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:random", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:scoped_annotation", "@tsl//tsl/profiler/lib:traceme", ] + if_gpu_is_configured([ ":make_batch_pointers", - "//xla/service/gpu/runtime3:cholesky_thunk", - "//xla/service/gpu/runtime3:triangular_solve_thunk", ]) + if_cuda_is_configured([ "//xla/stream_executor/cuda:cublas_plugin", "//xla/stream_executor/cuda:cuda_stream", @@ -1070,7 +1183,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", "//xla/service:hlo_parser", "//xla/service/llvm_ir:buffer_assignment_util", @@ -1080,8 +1192,11 @@ cc_library( "//xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -1092,8 +1207,11 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:statusor", ], ) @@ -1101,6 +1219,7 @@ xla_cc_test( name = "ir_emission_utils_test", srcs = ["ir_emission_utils_test.cc"], deps = [ + ":hlo_traversal", ":ir_emission_utils", "//xla:literal", "//xla:literal_util", @@ -1108,7 +1227,6 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/translate/hlo_to_mhlo:hlo_utils", @@ -1128,6 +1246,7 @@ cc_library( hdrs = ["reduction_utils.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ + ":ir_emission_utils", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -1141,60 +1260,52 @@ cc_library( ]), ) +xla_cc_test( + name = "reduction_utils_test", + srcs = ["reduction_utils_test.cc"], + deps = [ + ":reduction_utils", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "cublas_cudnn", srcs = ["cublas_cudnn.cc"], hdrs = ["cublas_cudnn.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla:status", + "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:statusor", ], ) -cuda_library( - name = "gpu_prim_cuda", - hdrs = ["gpu_prim_cuda.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - "@eigen_archive//:eigen3", - "@tsl//tsl/platform:bfloat16", - ] + if_cuda_is_configured(xla_cub_deps()), -) - -cc_library( - name = "gpu_prim_rocm", - hdrs = ["gpu_prim_rocm.h"], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), +gpu_kernel_library( + name = "gpu_prim", + hdrs = ["gpu_prim.h"], deps = [ "@eigen_archive//:eigen3", + "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:bfloat16", - ] + if_rocm_is_configured([ + ] + if_cuda_is_configured(xla_cub_deps()) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocprim", ]), ) cc_library( - name = "cub_sort_thunk", - srcs = if_gpu_is_configured(["cub_sort_thunk.cc"]), - hdrs = if_gpu_is_configured(["cub_sort_thunk.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([ - ":buffer_allocations", - ":thunk", - "@com_google_absl//absl/log:check", - "//xla/service:buffer_assignment", - "//xla:shape_util", - "//xla/stream_executor:device_memory", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "@tsl//tsl/platform:errors", - ] + [":cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]), + name = "variant_visitor", + hdrs = ["variant_visitor.h"], ) build_cub_sort_kernels( @@ -1205,10 +1316,8 @@ build_cub_sort_kernels( "TENSORFLOW_USE_ROCM=1", ]), types = get_cub_sort_kernel_types(), - deps = if_cuda_is_configured([ - ":gpu_prim_cuda", - ]) + if_rocm_is_configured([ - ":gpu_prim_rocm", + deps = if_gpu_is_configured([ + ":gpu_prim", ]), ) @@ -1216,19 +1325,26 @@ cc_library( name = "gemm_rewriter", srcs = ["gemm_rewriter.cc"], hdrs = ["gemm_rewriter.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":backend_configs_cc", ":cublas_cudnn", ":ir_emission_utils", ":matmul_utils", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:status", "//xla:status_macros", "//xla:statusor", + "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/service:algorithm_util", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", @@ -1236,10 +1352,14 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", "@tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ @@ -1252,10 +1372,52 @@ cc_library( srcs = ["triton_support.cc"], hdrs = ["triton_support.h"], deps = [ + ":variant_visitor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:tensor_float_32_utils", + ], +) + +xla_test( + name = "triton_support_test", + srcs = if_cuda_is_configured(["triton_support_test.cc"]), + backends = [ + "gpu_a100", + ], + shard_count = 10, + tags = ["nomac"], + deps = [ + ":gpu_device_info_for_tests", + ":gpu_float_support", + ":ir_emission_utils", + ":ir_emitter_triton", + ":matmul_utils", + ":triton_fusion_analysis", + ":triton_support", + "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:float_normalization", + "//xla/service:hlo_pass_pipeline", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -1281,6 +1443,17 @@ cc_library( ], ) +xla_cc_test( + name = "triton_tiling_propagation_test", + srcs = ["triton_tiling_propagation_test.cc"], + deps = [ + ":triton_tiling_propagation", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "triton_fusion_analysis", srcs = ["triton_fusion_analysis.cc"], @@ -1292,16 +1465,19 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:instruction_fusion", + "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -1309,13 +1485,15 @@ xla_cc_test( name = "triton_fusion_analysis_test", srcs = ["triton_fusion_analysis_test.cc"], deps = [ - ":gemm_rewriter_triton", + ":gemm_fusion", ":triton_fusion_analysis", + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/platform:statusor", @@ -1323,9 +1501,9 @@ xla_cc_test( ) cc_library( - name = "gemm_rewriter_triton", - srcs = ["gemm_rewriter_triton.cc"], - hdrs = ["gemm_rewriter_triton.h"], + name = "gemm_fusion", + srcs = ["gemm_fusion.cc"], + hdrs = ["gemm_fusion.h"], deps = [ ":backend_configs_cc", ":cublas_padding_requirements", @@ -1336,9 +1514,9 @@ cc_library( ":triton_tiling_propagation", "//xla:shape_util", "//xla:status", - "//xla:status_macros", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", @@ -1348,6 +1526,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -1356,11 +1536,11 @@ cc_library( ) xla_cc_test( - name = "gemm_rewriter_triton_test", - srcs = ["gemm_rewriter_triton_test.cc"], + name = "gemm_fusion_test", + srcs = ["gemm_fusion_test.cc"], deps = [ ":cublas_padding_requirements", - ":gemm_rewriter_triton", + ":gemm_fusion", ":triton_fusion_analysis", "//xla:autotuning_proto_cc", "//xla:statusor", @@ -1374,8 +1554,45 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gemv_rewriter", + srcs = ["gemv_rewriter.cc"], + hdrs = ["gemv_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemv_rewriter_test", + srcs = ["gemv_rewriter_test.cc"], + deps = [ + ":gemv_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:statusor", ], ) @@ -1441,52 +1658,6 @@ xla_cc_test( ], ) -cc_library( - name = "fusion_merger_triton", - srcs = ["fusion_merger_triton.cc"], - hdrs = ["fusion_merger_triton.h"], - deps = [ - ":backend_configs_cc", - ":ir_emission_utils", - ":triton_fusion_analysis", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "fusion_merger_triton_test", - srcs = ["fusion_merger_triton_test.cc"], - backend_tags = {"gpu": [ - "requires-gpu-sm70", - ]}, - backends = [ - "gpu", - ], - deps = [ - ":fusion_merger_triton", - "//xla:autotune_results_proto_cc", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest", - "@tsl//tsl/platform:status_matchers", - ], -) - cc_library( name = "softmax_rewriter_triton", srcs = ["softmax_rewriter_triton.cc"], @@ -1498,18 +1669,19 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -1517,40 +1689,6 @@ cc_library( ], ) -cc_library( - name = "gemm_thunk", - srcs = ["gemm_thunk.cc"], - hdrs = ["gemm_thunk.h"], - deps = [ - ":matmul_utils", - ":thunk", - "//xla:status", - "//xla/service:buffer_assignment", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "@tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "gpublas_lt_matmul_thunk", - srcs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.cc"]), - hdrs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([ - ":matmul_utils", - ":thunk", - "//xla/service:buffer_assignment", - "//xla:status", - "//xla/stream_executor:device_memory", - "//xla/stream_executor", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", - ]), -) - cc_library( name = "gemm_algorithm_picker", srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]), @@ -1561,16 +1699,24 @@ cc_library( deps = if_gpu_is_configured([ ":backend_configs_cc", ":buffer_comparator", - ":gemm_thunk", + ":cublas_cudnn", ":gpu_asm_opts_util", ":gpu_conv_runner", ":ir_emission_utils", ":matmul_utils", ":stream_executor_util", + ":variant_visitor", ":autotuner_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "//xla:autotune_results_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla:status_macros", "//xla/stream_executor", @@ -1579,14 +1725,14 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/gpu:redzone_allocator", + "//xla/tsl/util/proto:proto_utils", "//xla:util", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:statusor", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logger", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", - "//xla:autotuning_proto_cc", - "@tsl//tsl/util/proto:proto_utils", - "//xla:statusor", ]), ) @@ -1598,23 +1744,30 @@ cc_library( ":gpu_asm_opts_util", ":stream_executor_util", "@com_google_absl//absl/base", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "//xla/hlo/ir:hlo", + "//xla/service:compilation_environments", + "//xla/stream_executor", + "//xla/stream_executor/gpu:redzone_allocator", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", - "//xla:status_macros", + "//xla:shape_util", "//xla:status", + "//xla:status_macros", + "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/stream_executor", - "//xla/stream_executor/gpu:redzone_allocator", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", @@ -1634,20 +1787,23 @@ cc_library( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "//xla:shape_util", - "//xla:statusor", - "//xla:util", - "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:compiler", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/service:maybe_owning_device_memory", + "//xla/service:shaped_buffer", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer_header", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla:xla_proto_cc", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ]), @@ -1669,13 +1825,19 @@ xla_test( "gpu_v100", ], deps = [ + ":autotuner_util", ":backend_configs_cc", ":gemm_algorithm_picker", ":gemm_rewriter", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/service:platform_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_headers", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -1705,13 +1867,17 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo_gpu", + "//xla/service:algorithm_util", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ @@ -1721,7 +1887,7 @@ cc_library( "//xla/stream_executor:host_or_device_scalar", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:hipblas_lt_header", - "//xla/stream_executor/rocm:hipblaslt_plugin", + "//xla/stream_executor/rocm:amdhipblaslt_plugin", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/platform:dso_loader", ]) + if_static([ @@ -1734,12 +1900,15 @@ xla_cc_test( srcs = ["matmul_utils_test.cc"], deps = [ ":matmul_utils", + "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -1750,12 +1919,17 @@ cc_library( deps = [ "//xla:permutation_util", "//xla:shape_util", + "//xla:status", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", ], ) @@ -1766,9 +1940,13 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":dot_dimension_sorter", + "//xla:error_spec", + "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) @@ -1778,8 +1956,15 @@ cc_library( hdrs = ["gpu_async_collective_annotator.h"], deps = [ ":backend_configs_cc", + "//xla:util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -1789,11 +1974,16 @@ xla_cc_test( deps = [ ":backend_configs_cc", ":gpu_async_collective_annotator", + "//xla:util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -1803,8 +1993,14 @@ cc_library( hdrs = ["gpu_convert_async_collectives_to_sync.h"], deps = [ ":backend_configs_cc", + "//xla/hlo/ir:hlo", "//xla/service:convert_async_collectives_to_sync", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -1814,10 +2010,15 @@ xla_cc_test( deps = [ ":backend_configs_cc", ":gpu_convert_async_collectives_to_sync", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -1832,33 +2033,46 @@ cc_library( ":autotuner_util", ":backend_configs_cc", ":buffer_comparator", + ":cublas_cudnn", ":gpu_asm_opts_util", ":gpu_autotuning_proto_cc", ":gpu_conv_runner", ":hlo_algorithm_denylist", ":stream_executor_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cudnn_header", "//xla:autotune_results_proto_cc", "//xla:autotuning_proto_cc", - "//xla/hlo/ir:hlo", + "//xla:debug_options_flags", "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:executable", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:slow_operation_alarm", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/gpu:redzone_allocator", "//xla/stream_executor/rocm:rocm_platform_id", - "//xla:util", - "//xla:xla_data_proto_cc", - "@tsl//tsl/platform:logger", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", - "@tsl//tsl/util/proto:proto_utils", + "@tsl//tsl/platform:statusor", + "//xla/tsl/util/proto:proto_utils", ]), ) @@ -1878,12 +2092,18 @@ xla_test( "gpu_v100", ], deps = [ + ":autotuner_util", ":conv_algorithm_picker", ":gpu_conv_rewriter", + "//xla:debug_options_flags", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/service:platform_util", "//xla/service:tuple_simplifier", + "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -1910,7 +2130,15 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:statusor", ], ) @@ -1933,19 +2161,21 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]), ) cc_library( - name = "gpu_fused_mha_runner", - srcs = ["gpu_fused_mha_runner.cc"], - hdrs = ["gpu_fused_mha_runner.h"], + name = "gpu_flash_attn", + srcs = if_cuda_is_configured([ + "gpu_flash_attn.cc" + ]), + hdrs = ["gpu_flash_attn.h"], deps = [ - ":backend_configs_cc", - ":cublas_cudnn", ":stream_executor_util", "//xla:shape_util", "//xla:status", @@ -1953,39 +2183,78 @@ cc_library( "//xla:statusor", "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/stream_executor", "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", + "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - ], + ] + if_cuda_is_configured([ + "@flash_attn//:flash_attn", + ]) ) cc_library( - name = "gpu_conv_rewriter", - srcs = ["gpu_conv_rewriter.cc"], - hdrs = ["gpu_conv_rewriter.h"], + name = "gpu_fused_mha_runner", + srcs = ["gpu_fused_mha_runner.cc"], + hdrs = ["gpu_fused_mha_runner.h"], deps = [ ":backend_configs_cc", ":cublas_cudnn", - "//xla:permutation_util", - "//xla:util", - "//xla:window_util", + ":stream_executor_util", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor:lazy_op_runner", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@eigen_archive//:eigen3", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gpu_conv_rewriter", + srcs = ["gpu_conv_rewriter.cc"], + hdrs = ["gpu_conv_rewriter.h"], + deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) cc_library( name = "gpu_sort_rewriter", - srcs = if_cuda_is_configured(["gpu_sort_rewriter.cc"]), - hdrs = if_cuda_is_configured(["gpu_sort_rewriter.h"]), + srcs = if_gpu_is_configured(["gpu_sort_rewriter.cc"]), + hdrs = if_gpu_is_configured(["gpu_sort_rewriter.h"]), deps = [ - ":cub_sort_thunk", ":cublas_cudnn", "//xla:comparison_util", "//xla:shape_util", @@ -1994,7 +2263,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:cub_sort_thunk", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -2015,7 +2286,10 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", @@ -2027,7 +2301,9 @@ xla_cc_test( srcs = ["move_copy_to_users_test.cc"], deps = [ ":move_copy_to_users", + "//xla/service:layout_assignment", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -2039,7 +2315,10 @@ xla_cc_test( deps = [ ":cublas_cudnn", ":gpu_conv_rewriter", + "//xla:array4d", + "//xla:literal_util", "//xla:protobuf_util", + "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla/hlo/ir:hlo", @@ -2048,6 +2327,9 @@ xla_cc_test( "//xla/service:shape_inference", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -2077,8 +2359,12 @@ cc_library( name = "cusolver_context", srcs = if_gpu_is_configured(["cusolver_context.cc"]), hdrs = if_gpu_is_configured(["cusolver_context.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ "//xla:comparison_util", + "//xla:status", "//xla:statusor", "//xla:types", "//xla:util", @@ -2086,11 +2372,14 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", - "@tsl//tsl/cuda:cusolver", + "//xla/tsl/cuda:cusolver", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/rocm:rocblas_wrapper", @@ -2106,18 +2395,27 @@ cc_library( deps = if_gpu_is_configured([ ":cusolver_context", ":ir_emission_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "//xla:comparison_util", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "//xla/stream_executor", - "@tsl//tsl/platform:logging", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory_allocator", - "@com_google_absl//absl/algorithm:container", - ]) + ["@tsl//tsl/platform:status"], + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + ]), ) cc_library( @@ -2138,6 +2436,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) @@ -2146,19 +2445,25 @@ xla_cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], tags = [ + "no_aarch64", "nomsan", ], deps = [ ":gpu_device_info_for_tests", ":gpu_fusible", ":instruction_fusion", + "//xla:literal_util", + "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -2166,7 +2471,54 @@ tf_proto_library( name = "fusion_process_dump_proto", srcs = ["fusion_process_dump.proto"], cc_api_version = 2, - protodeps = [], + protodeps = [ + "//xla/stream_executor:device_description_proto", + ], +) + +cc_library( + name = "fusion_process_dump", + srcs = ["fusion_process_dump.cc"], + hdrs = ["fusion_process_dump.h"], + deps = [ + ":fusion_process_dump_proto_cc", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_graph_dumper", + "//xla/stream_executor:stream_executor_headers", + "//xla/tools:hlo_module_loader", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_process_dump_test", + srcs = ["fusion_process_dump_test.cc"], + deps = [ + ":fusion_process_dump", + ":fusion_process_dump_proto_cc", + ":gpu_device_info_for_tests", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + ], ) cc_library( @@ -2176,32 +2528,37 @@ cc_library( deps = [ ":fusion_process_dump_proto_cc", ":gpu_fusible", + ":hlo_fusion_analysis", ":hlo_traversal", + "//xla:debug_options_flags", "//xla:shape_util", "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:dump", - "//xla/service:fusion_node_indexing_evaluation", "//xla/service:fusion_queue", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", ], @@ -2224,6 +2581,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@tsl//tsl/platform:status_matchers", ], @@ -2239,17 +2597,26 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_dfs_reachability", "//xla/hlo/ir:hlo_reachability", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_graph_dumper", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -2263,6 +2630,9 @@ xla_cc_test( ":gpu_device_info_for_tests", ":gpu_fusible", ":multi_output_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", @@ -2272,6 +2642,34 @@ xla_cc_test( ], ) +cc_library( + name = "rename_fusions", + srcs = ["rename_fusions.cc"], + hdrs = ["rename_fusions.h"], + deps = [ + ":hlo_traversal", + ":ir_emission_utils", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "rename_fusions_test", + srcs = ["rename_fusions_test.cc"], + deps = [ + ":rename_fusions", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + ], +) + xla_cc_test( name = "softmax_rewriter_triton_test", srcs = ["softmax_rewriter_triton_test.cc"], @@ -2280,6 +2678,7 @@ xla_cc_test( "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", @@ -2296,6 +2695,24 @@ xla_cc_test( ], ) +cc_library( + name = "gpu_flash_attn_normalization", + srcs = ["gpu_flash_attn_normalization.cc"], + hdrs = ["gpu_flash_attn_normalization.h"], + deps = [ + ":backend_configs_cc", + ":gpu_flash_attn", + "//xla/hlo/ir:hlo", + "//xla:literal_util", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + ], +) + cc_library( name = "gpu_sanitize_constant_names", srcs = ["gpu_sanitize_constant_names.cc"], @@ -2303,7 +2720,11 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service:name_uniquer", "//xla/service/llvm_ir:buffer_assignment_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", ], @@ -2314,12 +2735,13 @@ xla_cc_test( srcs = ["gpu_sanitize_constant_names_test.cc"], deps = [ ":gpu_sanitize_constant_names", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla:literal_util", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -2333,13 +2755,22 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_graph_dumper", "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", ], ) @@ -2353,12 +2784,16 @@ xla_cc_test( ":fusion_merger", ":gpu_device_info_for_tests", ":gpu_fusible", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", ], ) @@ -2370,6 +2805,7 @@ cc_library( ":cublas_cudnn", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", @@ -2377,6 +2813,13 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:shape_inference", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -2405,10 +2848,15 @@ cc_library( deps = [ ":cublas_cudnn", "//xla:shape_util", + "//xla:statusor", + "//xla:util", "//xla:window_util", "//xla/hlo/ir:hlo", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -2417,23 +2865,23 @@ xla_cc_test( srcs = ["cudnn_support_utils_test.cc"], deps = [ ":cudnn_support_utils", - "//xla:status_macros", + "//xla:shape_util", + "//xla:statusor", "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -2442,17 +2890,27 @@ cc_library( srcs = ["cudnn_pad_for_convolutions.cc"], hdrs = ["cudnn_pad_for_convolutions.h"], deps = [ + ":cublas_cudnn", ":cudnn_support_utils", ":ir_emission_utils", ":stream_executor_util", "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", "//xla:util", "//xla:window_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", - "@tsl//tsl/platform:status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -2462,13 +2920,12 @@ xla_cc_test( deps = [ ":cublas_cudnn", ":cudnn_pad_for_convolutions", - "//xla:status_macros", - "//xla:util", "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_googletest//:gtest", ], ) @@ -2482,11 +2939,25 @@ cc_library( ":cudnn_support_utils", ":stream_executor_util", "//xla:shape_util", + "//xla:status", "//xla:statusor", + "//xla:util", "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -2497,13 +2968,20 @@ xla_cc_test( ":backend_configs_cc", ":cublas_cudnn", ":cudnn_vectorize_convolutions", + "//xla:statusor", "//xla:util", "//xla/service:call_inliner", "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_headers", + "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) @@ -2515,11 +2993,22 @@ cc_library( deps = [ ":backend_configs_cc", ":cublas_cudnn", + "//xla:literal", "//xla:statusor", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -2527,11 +3016,11 @@ xla_cc_test( name = "cudnn_simplify_padding_test", srcs = ["cudnn_simplify_padding_test.cc"], deps = [ - ":cublas_cudnn", ":cudnn_pad_for_convolutions", ":cudnn_simplify_padding", ":cudnn_vectorize_convolutions", - "//xla:status_macros", + "//xla:literal", + "//xla:statusor", "//xla:util", "//xla/service:algebraic_simplifier", "//xla/service:call_inliner", @@ -2541,9 +3030,17 @@ xla_cc_test( "//xla/service:pattern_matcher_gmock", "//xla/service:reshape_mover", "//xla/service:tuple_simplifier", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_headers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], ) @@ -2553,13 +3050,23 @@ cc_library( srcs = ["cublas_pad_for_gemms.cc"], hdrs = ["cublas_pad_for_gemms.h"], deps = [ - ":gemm_rewriter_triton", + ":gemm_fusion", ":ir_emission_utils", + ":triton_support", "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", "//xla:util", - "//xla:window_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -2568,6 +3075,8 @@ cc_library( srcs = ["cublas_padding_requirements.cc"], hdrs = ["cublas_padding_requirements.h"], deps = [ + ":variant_visitor", + "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", @@ -2582,19 +3091,55 @@ xla_cc_test( ], deps = [ ":cublas_pad_for_gemms", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_googletest//:gtest", ], ) +cc_library( + name = "cudnn_fusion_compiler", + srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]), + deps = if_cuda_is_configured([ + ":backend_configs_cc", + ":ir_emission_utils", + ":kernel_reuse_cache", + ":matmul_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_config_cuda//cuda:cudnn_header", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service:hlo_pass", + "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor/cuda:cudnn_frontend_helpers", + "//xla/stream_executor/cuda:cudnn_plugin", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_set", + ]), +) + tf_proto_library( name = "executable_proto", srcs = ["executable.proto"], cc_api_version = 2, protodeps = [ "//xla/service:hlo_proto", + "//xla:xla_proto", ], ) @@ -2621,15 +3166,26 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service:compiler", "//xla/service:generic_transfer_manager", + "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Core", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:statusor", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -2639,10 +3195,18 @@ cc_library( srcs = ["gpu_reduce_scatter_creator.cc"], hdrs = ["gpu_reduce_scatter_creator.h"], deps = [ + "//xla:shape_util", + "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:collective_opt_utils", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", ], ) @@ -2657,6 +3221,7 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -2671,6 +3236,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:float_support", + "//xla/stream_executor:device_description", "@com_google_absl//absl/log:check", ], ) @@ -2684,51 +3250,46 @@ cc_library( "compile_module_to_llvm_ir.h", ], deps = [ - ":buffer_sharing", ":executable_proto_cc", ":gpu_constants", ":gpu_executable", + ":gpu_memory_space_assignment", ":ir_emitter_context", ":ir_emitter_unnested", ":metrics", ":runtime_intrinsics", - ":thunk", "//xla:shape_util", "//xla:status", - "//xla:status_macros", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/mlir/backends/gpu/transforms:passes", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/mlir_hlo:transforms_gpu_passes", "//xla/service:buffer_assignment", "//xla/service:buffer_value", "//xla/service:dump", "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", - "//xla/service/llvm_ir:llvm_util", + "//xla/service:logical_buffer", + "//xla/service/gpu/runtime:conditional_thunk", + "//xla/service/gpu/runtime:sequential_thunk", + "//xla/service/gpu/runtime:thunk", + "//xla/service/gpu/runtime:while_thunk", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/rocm:rocm_platform_id", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_lhlo_with_xla", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:AsmParser", - "@llvm-project//llvm:Support", "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], ) @@ -2738,16 +3299,29 @@ cc_library( srcs = ["command_buffer_scheduling.cc"], hdrs = ["command_buffer_scheduling.h"], deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":hlo_fusion_analysis", + ":hlo_traversal", + ":ir_emission_utils", + ":variant_visitor", "//xla:shape_util", + "//xla:status", "//xla:statusor", "//xla:util", + "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], ) @@ -2759,30 +3333,34 @@ xla_cc_test( ":command_buffer_scheduling", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "custom_fusion_rewriter", - srcs = ["custom_fusion_rewriter.cc"], - hdrs = ["custom_fusion_rewriter.h"], + name = "custom_kernel_fusion_rewriter", + srcs = ["custom_kernel_fusion_rewriter.cc"], + hdrs = ["custom_kernel_fusion_rewriter.h"], deps = [ + "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "//xla/service/gpu/kernels:custom_fusion_library", - "//xla/service/gpu/kernels:custom_fusion_pattern", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -2792,13 +3370,76 @@ cc_library( ) xla_cc_test( - name = "custom_fusion_rewriter_test", - srcs = ["custom_fusion_rewriter_test.cc"], + name = "custom_kernel_fusion_rewriter_test", + srcs = ["custom_kernel_fusion_rewriter_test.cc"], + deps = [ + ":custom_kernel_fusion_rewriter", + ":gpu_device_info_for_tests", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "address_computation_fusion_rewriter", + srcs = ["address_computation_fusion_rewriter.cc"], + hdrs = ["address_computation_fusion_rewriter.h"], + deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":gpu_constants", + ":hlo_traversal", + ":ir_emission_utils", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/hlo/ir:hlo", + "//xla/service:custom_call_target_registry", + "//xla/service:hlo_pass", + "//xla/service/gpu/kernels:custom_fusion_library", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "address_computation_fusion_rewriter_test", + srcs = if_cuda_is_configured(["address_computation_fusion_rewriter_test.cc"]), deps = [ - ":custom_fusion_rewriter", + ":address_computation_fusion_rewriter", + ":gpu_device_info_for_tests", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", - "//xla/service/gpu/kernels:custom_fusion_pattern", + "//xla/service:buffer_value", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_module_config", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -2815,6 +3456,7 @@ cc_library( ":instruction_fusion", ":multi_output_fusion", ":priority_fusion", + ":rename_fusions", ":variadic_op_splitter", "//xla:xla_proto_cc", "//xla/service:cpu_gpu_shape_verifier", @@ -2838,6 +3480,7 @@ cc_library( deps = [ ":alias_passthrough_params", ":copy_fusion", + ":gpu_flash_attn_normalization", ":gpu_sanitize_constant_names", ":horizontal_loop_fusion", "//xla:xla_proto_cc", @@ -2865,20 +3508,26 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), deps = if_gpu_is_configured([ + ":gpu_p2p_pipeliner", + ":collective_permute_cycle_decomposer", + ":address_computation_fusion_rewriter", + ":algorithm_checker", ":alias_passthrough_params", ":all_reduce_blueconnect", + ":stream_attribute_async_wrapper", ":autotuner_util", ":buffer_sharing", ":compile_module_to_llvm_ir", ":conv_layout_normalization", ":copy_fusion", - ":custom_fusion_rewriter", + ":custom_kernel_fusion_rewriter", ":dot_dimension_sorter", + ":dot_operand_converter", ":executable_proto_cc", ":fusion_merger", ":fusion_wrapper", ":gemm_broadcast_folding_rewriter", - ":gemm_rewriter_triton", + ":gemm_fusion", ":gemm_rewriter", ":gpu_all_gather_optimizer", ":gpu_async_collective_annotator", @@ -2892,6 +3541,7 @@ cc_library( ":gpu_reduce_scatter_creator", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", + ":gpu_windowed_einsum_handler", ":hlo_fusion_stats", ":horizontal_input_fusion", ":horizontal_loop_fusion", @@ -2912,6 +3562,7 @@ cc_library( ":runtime_intrinsics", ":scatter_slice_simplifier", ":softmax_rewriter_triton", + ":stream_attribute_annotator", ":topk_specializer", ":topk_splitter", ":tree_reduction_rewriter", @@ -2920,7 +3571,9 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:variant", + "@local_config_cuda//cuda:cuda_headers", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:BitReader", "@llvm-project//llvm:BitWriter", @@ -2939,9 +3592,6 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/mlir/backends/gpu/transforms:passes", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/runtime:jit_executable", "//xla/service:algebraic_simplifier", "//xla/service:all_gather_broadcast_reorder", "//xla/service:all_gather_combiner", @@ -2963,6 +3613,7 @@ cc_library( "//xla/service:conditional_canonicalizer", "//xla/service:conditional_simplifier", "//xla/service:convert_async_collectives_to_sync", + "//xla/service:convert_memory_placement_to_internal_annotations", "//xla/service:convert_mover", "//xla/service:convolution_4d_expander", "//xla/service:convolution_pred_expander", @@ -2993,12 +3644,16 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:hlo_rematerialization", "//xla/service:hlo_verifier", + "//xla/service:host_memory_transfer_asyncifier", + "//xla/service:host_offload_legalize", + "//xla/service:host_offloader", "//xla/service:layout_normalization", "//xla/service:llvm_compiler", "//xla/service:logistic_expander", "//xla/service:loop_schedule_linearizer", "//xla/service:operand_upcaster", "//xla/service:optimization_barrier_expander", + "//xla/service:optimize_input_output_buffer_alias", "//xla/service:qr_expander", "//xla/service:real_imag_expander", "//xla/service:reduce_decomposer", @@ -3035,10 +3690,11 @@ cc_library( "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:device_description", "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/integrations:device_mem_allocator", "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_lhlo_with_xla", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", @@ -3059,39 +3715,82 @@ cc_library( "@com_google_absl//absl/types:span", "//xla:shape_util", "//xla/hlo/ir:hlo_module_group", - "//xla/mlir/runtime/transforms:compilation_pipeline_options", - "//xla/runtime:compiler", - "//xla/runtime:executable", "//xla/service:buffer_value", "//xla/service:dynamic_dimension_inference", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_ordering", "//xla/service:layout_assignment", "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:executable", + "//xla/stream_executor/rocm:rocm_platform_id", "@tsl//tsl/platform:numbers", ]) + xla_export_hlo_deps() + [ ":command_buffer_scheduling", - ":fusion_merger_triton", ":fusion_pipeline", ":ir_emitter_context", ":ir_emitter_unnested", ":prepare_hlo_for_ir_emitting_pipeline", - ":thunk", + ":rename_fusions", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:platform_manager", "@llvm-project//mlir:FuncDialect", "@tsl//tsl/lib/monitoring:counter", ], ) -xla_cc_test( +xla_test( name = "gpu_compiler_test", - srcs = ["gpu_compiler_test.cc"], + srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]), + backends = ["gpu"], + data = ["gpu_compiler_test_autotune_db.textproto"], + deps = [ + ":gpu_compiler", + ":gpu_hlo_schedule", + ":metrics", + "//xla:autotune_results_proto_cc", + "//xla:error_spec", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:xla_debug_info_manager", + "//xla/service/gpu:autotuner_util", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +xla_cc_test( + name = "gpu_offloading_test", + srcs = ["gpu_offloading_test.cc"], tags = tf_cuda_tests_tags(), deps = [ ":horizontal_loop_fusion", ":metrics", "//xla:autotune_results_proto_cc", + "//xla:error_spec", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:buffer_assignment", + "//xla/service:buffer_value", "//xla/service:gpu_plugin", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_rematerialization", + "//xla/service:hlo_rematerialization_test_utils", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:xla_debug_info_manager", @@ -3099,9 +3798,11 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) @@ -3110,12 +3811,16 @@ xla_cc_test( srcs = ["auto_sharding_gpu_compiler_test.cc"], tags = tf_cuda_tests_tags() + ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS deps = [ - "//xla/hlo/utils:hlo_matchers", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", + "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:logging", ], ) @@ -3125,8 +3830,9 @@ cc_library( "nvptx_compiler_registration.cc", ]), deps = if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_platform_id", ":nvptx_compiler_impl", + "//xla/service:compiler", + "//xla/stream_executor/cuda:cuda_platform_id", "@tsl//tsl/platform:path", ]), alwayslink = True, # Contains compiler registration @@ -3140,24 +3846,32 @@ cc_library( hdrs = if_cuda_is_configured([ "nvptx_compiler.h", ]), + local_defines = select({ + "//xla/stream_executor/cuda:libnvptxcompiler_support_enabled": [ + "ENABLE_LIBNVPTXCOMPILER_SUPPORT=1", + ], + "//conditions:default": [], + }), deps = if_cuda_is_configured([ ":autotuner_util", ":buffer_sharing", + ":conv_algorithm_picker", ":cublas_cudnn", ":cublas_pad_for_gemms", ":cublas_padding_requirements", ":cudnn_fused_conv_rewriter", ":cudnn_fused_mha_rewriter", ":cudnn_fused_mha_transpose_fusion", + ":cudnn_fusion_compiler", ":cudnn_norm_rewriter", ":cudnn_pad_for_convolutions", ":cudnn_simplify_padding", ":cudnn_vectorize_convolutions", ":cusolver_rewriter", ":gemm_algorithm_picker", + ":gemm_fusion_autotuner", ":gpu_asm_opts_util", ":gpu_compiler", - ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_executable", @@ -3168,81 +3882,102 @@ cc_library( ":move_copy_to_users", ":target_constants", ":triangular_solve_rewriter", - ":triton_autotuner", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:ir_headers", "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", "//xla:autotune_results_proto_cc", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", "//xla:util", "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", "//xla/service:call_inliner", "//xla/service:convert_mover", - "//xla/service:dump", - "//xla/hlo/ir:hlo", "//xla/service:dot_dimension_merger", + "//xla/service:dump", "//xla/service:float_normalization", "//xla/service:float_support", "//xla/service:hlo_constant_folding", "//xla/service:hlo_cse", + "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_dce", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:hlo_pass_pipeline", "//xla/service:hlo_proto_cc", "//xla/service:hlo_verifier", + "//xla/service:layout_normalization", "//xla/service:llvm_compiler", + "//xla/service:reshape_decomposer", "//xla/service:reshape_mover", "//xla/service:tuple_simplifier", "//xla/service/gpu/llvm_gpu_backend", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/cuda:cuda_diagnostics", "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/cuda:ptx_compiler", + "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/stream_executor/gpu:asm_compiler", + "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", "@tsl//tsl/platform:cuda_libdevice_path", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/util:env_var", - "//xla/service:layout_normalization", - "//xla/service:reshape_decomposer", + "//xla/tsl/util:env_var", ]), ) -xla_cc_test( +xla_test( name = "nvptx_compiler_test", srcs = if_gpu_is_configured([ "nvptx_compiler_test.cc", ]), + backends = [ + "gpu_v100", + "gpu_a100", + ], tags = [ "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. - "requires-gpu-sm70", ], deps = [ - ":gpu_compiler", + ":gpu_constants", + ":gpu_hlo_schedule", ":nvptx_compiler_impl", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:buffer_assignment", - "//xla/service:gpu_plugin", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", @@ -3255,14 +3990,12 @@ xla_cc_test( srcs = if_gpu_is_configured([ "gpu_aot_compilation_test.cc", ]), - env = { - "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", - }, local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), tags = [ "gpu", + "ignore_for_dep=third_party/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h", "no_oss", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. "requires-gpu-nvidia", @@ -3279,9 +4012,9 @@ xla_cc_test( "//xla/service:executable", "//xla/service:gpu_plugin", "//xla/service:platform_util", - "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor", "//xla/stream_executor:platform", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor:platform_manager", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", @@ -3292,34 +4025,40 @@ xla_cc_test( cc_library( name = "amdgpu_compiler", - srcs = if_rocm_is_configured([ + srcs = [ "amdgpu_compiler_registration.cc", - ]), - deps = if_rocm_is_configured([ + ], + tags = ["manual"], + deps = [ ":amdgpu_compiler_impl", - "//xla/stream_executor/host:host_platform_id", - ]), + "//xla/service:compiler", + "//xla/stream_executor/rocm:rocm_platform_id", + ], alwayslink = True, # Contains compiler registration ) cc_library( name = "amdgpu_compiler_impl", - srcs = if_rocm_is_configured([ + srcs = [ "amdgpu_compiler.cc", - ]), - hdrs = if_rocm_is_configured([ + ], + hdrs = [ "amdgpu_compiler.h", - ]), - deps = if_rocm_is_configured([ + ], + tags = ["manual"], + deps = [ ":autotuner_util", + ":conv_algorithm_picker", + ":cublas_pad_for_gemms", + ":cublas_padding_requirements", ":cusolver_rewriter", - ":gemm_rewriter", ":gemm_algorithm_picker", + ":gemm_rewriter", ":gpu_compiler", - ":conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", + ":gpu_sort_rewriter", ":reduction_degenerate_dim_remover", ":reduction_dimension_grouper", ":reduction_layout_normalizer", @@ -3327,21 +4066,32 @@ cc_library( ":tree_reduction_rewriter", ":triangular_solve_rewriter", "//xla:statusor", + "//xla:util", "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", "//xla/service:call_inliner", - "//xla/hlo/ir:hlo", + "//xla/service:dot_dimension_merger", "//xla/service:hlo_constant_folding", "//xla/service:hlo_cse", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:hlo_pass_pipeline", "//xla/service:hlo_verifier", "//xla/service:tuple_simplifier", "//xla/service/gpu/llvm_gpu_backend", "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/rocm:rocm_platform_id", - "@tsl//tsl/platform:rocm_rocdl_path", - ]), + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], ) cc_library( @@ -3354,11 +4104,18 @@ cc_library( "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:computation_placer_hdr", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) @@ -3369,11 +4126,15 @@ xla_cc_test( ":all_reduce_blueconnect", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/service:computation_placer_hdr", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -3408,10 +4169,18 @@ cc_library( "//xla:types", "//xla:util", "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:gpu_executor_header", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:notification", + "@tsl//tsl/platform:statusor", ], ) @@ -3421,22 +4190,29 @@ cc_library( hdrs = ["gpu_layout_assignment.h"], deps = [ ":backend_configs_cc", - ":ir_emission_utils", + ":cublas_cudnn", + ":gpu_flash_attn", ":matmul_utils", + ":reduction_utils", ":stream_executor_util", + "//xla:shape_layout", "//xla:shape_util", - "//xla:status_macros", + "//xla:status", + "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:computation_layout", "//xla/service:layout_assignment", + "//xla/service:logical_buffer", "//xla/stream_executor", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -3445,9 +4221,8 @@ xla_cc_test( srcs = ["gpu_layout_assignment_test.cc"], tags = tf_cuda_tests_tags(), deps = [ - ":cublas_cudnn", - ":gemm_rewriter", ":gpu_layout_assignment", + ":stream_executor_util", "//xla:shape_layout", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -3456,10 +4231,14 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", ], ) @@ -3476,6 +4255,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -3486,6 +4266,7 @@ xla_cc_test( name = "gpu_schedule_postprocessing_test", srcs = ["gpu_schedule_postprocessing_test.cc"], deps = [ + ":backend_configs_cc", ":gpu_schedule_postprocessing", "//xla:util", "//xla/hlo/ir:hlo", @@ -3510,9 +4291,11 @@ cc_library( "//xla:status", "//xla:statusor", "//xla:util", + "//xla/hlo/experimental/auto_reorder:auto_reorder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:buffer_value", + "//xla/service:collective_ops_utils", "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_pass_pipeline", "//xla/service:latency_hiding_scheduler", @@ -3526,11 +4309,14 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:statusor", ], ) @@ -3556,6 +4342,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -3563,6 +4350,44 @@ xla_cc_test( ], ) +cc_library( + name = "gpu_p2p_pipeliner", + srcs = ["gpu_p2p_pipeliner.cc"], + hdrs = ["gpu_p2p_pipeliner.h"], + deps = [ + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:collective_pipeliner", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass_pipeline", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "gpu_p2p_pipeliner_test", + srcs = [ + "gpu_p2p_pipeliner_test.cc", + ], + deps = [ + ":gpu_p2p_pipeliner", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) + xla_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], @@ -3570,16 +4395,15 @@ xla_cc_test( "nomsan", ], deps = [ - ":instruction_fusion", + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", "//xla:test", - "//xla:test_helpers", - "//xla/service:copy_insertion", - "//xla/service:hlo_verifier", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:while_loop_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -3600,13 +4424,38 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/stream_executor", + "//xla/stream_executor:launch_dim", + "//xla/tsl/util:env_var", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:regexp", - "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/util:env_var", - "@tsl//tsl/util/proto:proto_utils", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_executor_util_test", + srcs = ["stream_executor_util_test.cc"], + deps = [ + ":stream_executor_util", + "//xla:autotuning_proto_cc", + "//xla/service:hlo_module_config", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", ], ) @@ -3626,35 +4475,25 @@ cc_library( cc_library( name = "hlo_fusion_analysis", srcs = ["hlo_fusion_analysis.cc"], - hdrs = [ - "hlo_fusion_analysis.h", - "kernel_mapping_scheme.h", - ], + hdrs = ["hlo_fusion_analysis.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", - ":gpu_fusible", ":hlo_traversal", ":ir_emission_utils", ":launch_dimensions", ":reduction_utils", "//xla:shape_util", "//xla:statusor", - "//xla:union_find", - "//xla:util", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service/llvm_ir:loop_emitter", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:ir_headers", - "@tsl//tsl/platform:macros", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", ], ) @@ -3666,8 +4505,11 @@ xla_cc_test( ":gpu_device_info_for_tests", ":hlo_fusion_analysis", ":hlo_traversal", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:statusor", ], ) @@ -3680,20 +4522,24 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), deps = if_gpu_is_configured([ - ":launch_dimensions", ":buffer_comparator_kernel", ":gpu_asm_opts_util", + ":launch_dimensions", "@com_google_absl//absl/base", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@eigen_archive//:eigen3", "//xla:shape_util", "//xla:status_macros", + "//xla:statusor", "//xla:util", "//xla/service:hlo_module_config", "//xla/stream_executor", - "//xla:statusor", "//xla/stream_executor/gpu:asm_compiler", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:statusor", ]), ) @@ -3722,6 +4568,12 @@ xla_cc_test( ":stream_executor_util", "//xla:shape_util", "//xla:types", + "//xla/service:hlo_module_config", + "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform_manager", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ] + if_gpu_is_configured([ @@ -3754,14 +4606,22 @@ cc_library( srcs = ["gpu_fusible.cc"], hdrs = ["gpu_fusible.h"], deps = [ + ":backend_configs_cc", + ":hlo_traversal", ":ir_emission_utils", ":reduction_utils", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:hlo_dataflow_analysis", "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", ], ) @@ -3778,6 +4638,8 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -3790,13 +4652,29 @@ cc_library( ":backend_configs_cc", ":cublas_cudnn", "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -3813,6 +4691,7 @@ xla_test( "requires-gpu-sm80-only", "noasan", "nomsan", + "no_rocm", ], }, backends = [ @@ -3825,20 +4704,30 @@ xla_test( ":cublas_cudnn", ":cudnn_fused_conv_rewriter", ":gpu_conv_rewriter", + "//xla:comparison_util", + "//xla:error_spec", + "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", "//xla/service:convert_mover", "//xla/service:hlo_constant_folding", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:hlo_pass_pipeline", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:reshape_mover", "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -3856,17 +4745,31 @@ cc_library( ":cublas_cudnn", "//xla:shape_util", "//xla:status", + "//xla:statusor", + "//xla:types", + "//xla:util", "//xla:window_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "//xla/stream_executor", - "@com_google_absl//absl/strings", - "@tsl//tsl/protobuf:dnn_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", + ]) + if_static([ + "@com_google_protobuf//:protobuf", ]), ) @@ -3878,7 +4781,9 @@ xla_cc_test( deps = [ ":cublas_cudnn", ":cudnn_norm_rewriter", + "//xla:error_spec", "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", @@ -3892,32 +4797,38 @@ cc_library( name = "cudnn_fused_mha_rewriter", srcs = ["cudnn_fused_mha_rewriter.cc"], hdrs = ["cudnn_fused_mha_rewriter.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":cublas_cudnn", - ":ir_emission_utils", ":matmul_utils", + ":stream_executor_util", "//xla:permutation_util", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", - "@tsl//tsl/protobuf:dnn_proto_cc", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) cc_library( @@ -3928,48 +4839,56 @@ cc_library( ":backend_configs_cc", ":cublas_cudnn", ":matmul_utils", - "//xla:comparison_util", - "//xla:literal_util", "//xla:permutation_util", "//xla:shape_util", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", - "@tsl//tsl/protobuf:dnn_proto_cc", ], ) -xla_cc_test( +xla_test( name = "cudnn_fused_mha_rewriter_test", srcs = ["cudnn_fused_mha_rewriter_test.cc"], - tags = tf_cuda_tests_tags(), + backend_tags = {"gpu": [ + "requires-gpu-nvidia", + "no_rocm", + ]}, + backends = [ + "gpu", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":cublas_cudnn", ":cudnn_fused_mha_rewriter", ":cudnn_fused_mha_transpose_fusion", - "//xla:status_macros", + "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", "//xla/service:computation_layout", "//xla/service:hlo_cse", "//xla/service:hlo_dce", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", + "//xla/service:hlo_verifier", "//xla/service:layout_normalization", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:reshape_decomposer", + "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", @@ -3978,7 +4897,10 @@ xla_cc_test( "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]), ) xla_test( @@ -3991,6 +4913,7 @@ xla_test( "gpu", ], deps = [ + ":variant_visitor", "//xla:error_spec", "//xla:xla_proto_cc", "//xla/stream_executor:device_description", @@ -4006,7 +4929,9 @@ xla_cc_test( srcs = ["conv_layout_normalization_test.cc"], tags = tf_cuda_tests_tags(), deps = [ - "//xla/service/gpu/tests:gpu_codegen_test", + "//xla:error_spec", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/tests:gpu_codegen_test", # fixdeps: keep "//xla/tests:hlo_test_base", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -4018,13 +4943,17 @@ cc_library( srcs = ["variadic_op_splitter.cc"], hdrs = ["variadic_op_splitter.h"], deps = [ + "//xla:shape_util", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4033,10 +4962,12 @@ cc_library( srcs = ["gpu_scatter_expander.cc"], hdrs = ["gpu_scatter_expander.h"], deps = [ + "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:scatter_expander", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", ], ) @@ -4047,7 +4978,6 @@ xla_cc_test( "nomsan", ], deps = [ - ":ir_emission_utils", ":variadic_op_splitter", "//xla:literal_util", "//xla:shape_util", @@ -4058,6 +4988,7 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -4082,6 +5013,11 @@ cc_library( "//xla:debug_options_flags", "//xla/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status", ], ) @@ -4095,6 +5031,7 @@ xla_cc_test( deps = [ ":hlo_algorithm_denylist", "//xla/stream_executor:dnn", + "@com_google_absl//absl/strings", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:resource_loader", @@ -4109,8 +5046,14 @@ cc_library( hdrs = ["alias_passthrough_params.h"], deps = [ "//xla:shape_util", + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", ], ) @@ -4123,7 +5066,6 @@ xla_cc_test( deps = [ ":alias_passthrough_params", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", @@ -4137,13 +5079,24 @@ cc_library( deps = [ ":gpu_fusible", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", + "//xla/service:sub_byte_normalization", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4155,10 +5108,12 @@ xla_cc_test( ":gpu_device_info_for_tests", ":horizontal_loop_fusion", ":instruction_fusion", + "//xla:error_spec", "//xla:literal", "//xla:shape_util", "//xla:test", "//xla:test_helpers", + "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/service:hlo_dce", "//xla/service:hlo_parser", @@ -4167,8 +5122,11 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:tuple_simplifier", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@tsl//tsl/lib/core:status_test_util", ], ) @@ -4179,12 +5137,18 @@ cc_library( hdrs = ["horizontal_input_fusion.h"], deps = [ ":gpu_fusible", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", ], ) @@ -4195,11 +5159,15 @@ xla_cc_test( deps = [ ":gpu_device_info_for_tests", ":horizontal_input_fusion", + "//xla:error_spec", + "//xla:literal_util", "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", ], ) @@ -4217,8 +5185,13 @@ cc_library( "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4232,6 +5205,14 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -4244,6 +5225,13 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -4255,6 +5243,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", + "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -4272,12 +5261,20 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4296,8 +5293,16 @@ cc_library( "//xla/service:hlo_pass", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", ], ) @@ -4314,180 +5319,16 @@ cc_library( "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) -# See tap/tensorflow.xla_gpu_jitrt. -test_suite( - name = "jitrt_executable_tests", - tests = [ - ":cudnn_fused_conv_rewriter_test", - ":cudnn_fused_mha_rewriter_test", - ":custom_call_test", - # TODO(anlunx): Re-enable when AOT is available in Thunk-based runtime. - # copybara:uncomment # ":gpu_aot_compilation_test", - # copybara:uncomment "//platforms/xla/tests/internal:xfeed_test_gpu", - # TODO(anlunx): Re-enable when the FFI mechanism is available in Thunk-based runtime. - # copybara:uncomment # "//third_party/py/jax/experimental/jax2tf/tests:primitives_test_gpu", - # copybara:uncomment "//third_party/py/jax/tests:pmap_test_gpu", - # copybara:uncomment "//tensorflow/compiler/tests:fft_test_gpu", - "//xla/python:xla_client_test_gpu", - "//xla/service/gpu/tests:add_preds.hlo.test", - "//xla/service/gpu/tests:concat.hlo.test", - "//xla/service/gpu/tests:constant.hlo.test", - "//xla/service/gpu/tests:copy.hlo.test", - "//xla/service/gpu/tests:copy_nested.hlo.test", - "//xla/service/gpu/tests:dynamic_update_slice_inplace.hlo.test", - "//xla/service/gpu/tests:element_wise_row_vectorization.hlo.test", - "//xla/service/gpu/tests:element_wise_row_vectorization_test", - "//xla/service/gpu/tests:fused_scatter.hlo.test", - "//xla/service/gpu/tests:fused_slice.hlo.test", - "//xla/service/gpu/tests:fused_slice_different_operands.hlo.test", - "//xla/service/gpu/tests:fusion.hlo.test", - "//xla/service/gpu/tests:gemm_broadcast_folding_rewrite_test", - "//xla/service/gpu/tests:gemm_rewrite_test", - "//xla/service/gpu/tests:gpu_alignment_test", - "//xla/service/gpu/tests:gpu_all_gather_optimizer_test", - "//xla/service/gpu/tests:gpu_atomic_test", - "//xla/service/gpu/tests:gpu_compilation_parallelism_test", - "//xla/service/gpu/tests:gpu_convolution_regression_test", - "//xla/service/gpu/tests:gpu_copy_alone_test", - "//xla/service/gpu/tests:gpu_copy_test", - "//xla/service/gpu/tests:gpu_dyn_shape_test", - "//xla/service/gpu/tests:gpu_ftz_test", - "//xla/service/gpu/tests:gpu_fusion_test", - "//xla/service/gpu/tests:gpu_index_test", - "//xla/service/gpu/tests:gpu_infeed_test", - "//xla/service/gpu/tests:gpu_input_fusible_slice_test", - "//xla/service/gpu/tests:gpu_kernel_tiling_test", - "//xla/service/gpu/tests:gpu_ldg_test", - "//xla/service/gpu/tests:gpu_noalias_test", - "//xla/service/gpu/tests:gpu_reduce_scatter_creator_test", - "//xla/service/gpu/tests:gpu_spmd_e2e_compile_test", - "//xla/service/gpu/tests:gpu_too_many_blocks_test", - "//xla/service/gpu/tests:gpu_unrolling_test", - "//xla/service/gpu/tests:in_place_op_test", - "//xla/service/gpu/tests:kernel_launch_test", - "//xla/service/gpu/tests:kernel_reuse.hlo.test", - "//xla/service/gpu/tests:launch_dimensions.hlo.test", - "//xla/service/gpu/tests:pad_to_static.hlo.test", - "//xla/service/gpu/tests:parallel_reduction_test", - "//xla/service/gpu/tests:pred_arithmetic_test", - "//xla/service/gpu/tests:reduce_unnested.hlo.test", - "//xla/service/gpu/tests:reduction_degenerate_dim_remover_test", - "//xla/service/gpu/tests:reduction_dimension_grouper_test", - "//xla/service/gpu/tests:reduction_layout_normalizer_test", - "//xla/service/gpu/tests:reduction_vectorization_sm_all.hlo.test", - "//xla/service/gpu/tests:reduction_vectorization_test", - "//xla/service/gpu/tests:rng_get_and_update_state.hlo.test", - "//xla/service/gpu/tests:scatter.hlo.test", - "//xla/service/gpu/tests:select_and_scatter.hlo.test", - "//xla/service/gpu/tests:select_and_scatter_test", - "//xla/service/gpu/tests:single_instruction.hlo.test", - "//xla/service/gpu/tests:slice_to_dynamic.hlo.test", - "//xla/service/gpu/tests:sorting.hlo.test", - "//xla/service/gpu/tests:sorting_test", - "//xla/service/gpu/tests:swap_conv_operands_test", - "//xla/service/gpu/tests:tree_reduction_rewriter_test", - "//xla/tests:all_reduce_test_gpu", - "//xla/tests:array_elementwise_ops_test_gpu", - "//xla/tests:axpy_simple_test_gpu", - "//xla/tests:bad_rng_shape_validation_test_gpu", - "//xla/tests:batch_normalization_test_gpu", - "//xla/tests:bfloat16_test_gpu", - "//xla/tests:binop_scaling_test_gpu", - "//xla/tests:bitcast_convert_test_gpu", - "//xla/tests:broadcast_simple_test_gpu", - "//xla/tests:broadcast_test_gpu", - "//xla/tests:buffer_donation_test_gpu", - "//xla/tests:call_test_gpu", - "//xla/tests:check_execution_arity_test_gpu", - "//xla/tests:cholesky_test_gpu", - "//xla/tests:client_test_gpu", - "//xla/tests:compilation_cache_test_gpu", - "//xla/tests:compute_constant_test_gpu", - "//xla/tests:concat_test_gpu", - "//xla/tests:conditional_test_gpu", - "//xla/tests:constant_reduction_function_test_gpu", - "//xla/tests:constants_test_gpu", - "//xla/tests:conv_depthwise_backprop_filter_test_gpu", - "//xla/tests:conv_depthwise_test_gpu", - "//xla/tests:convert_test_gpu", - "//xla/tests:convolution_dimension_numbers_test_gpu", - "//xla/tests:convolution_test_1d_autotune_disabled_gpu", - "//xla/tests:convolution_test_1d_gpu_alternative_layout_gpu", - "//xla/tests:convolution_test_1d_no_vmodule_gpu", - "//xla/tests:convolution_test_autotune_disabled_gpu", - "//xla/tests:convolution_test_cudnn_frontend_disabled_gpu", - "//xla/tests:convolution_test_gpu", - "//xla/tests:convolution_test_gpu_alternative_layout_gpu", - "//xla/tests:convolution_variants_test_gpu", - "//xla/tests:copy_test_gpu", - "//xla/tests:cpu_gpu_fusion_test_gpu", - "//xla/tests:deallocation_test_gpu", - "//xla/tests:deconstruct_tuple_test_gpu", - "//xla/tests:deep_graph_test_gpu", - "//xla/tests:dot_operation_single_threaded_runtime_test_gpu", - "//xla/tests:dot_operation_test_autotune_disabled_gpu", - "//xla/tests:dot_operation_test_gpu", - "//xla/tests:dynamic_ops_test_gpu", - "//xla/tests:float8_test_gpu", - "//xla/tests:floor_ceil_test_gpu", - "//xla/tests:fmax_fmin_test_gpu", - "//xla/tests:gather_operation_test_gpu", - "//xla/tests:get_dimension_size_test_gpu", - "//xla/tests:grouped_convolution_test_gpu", - "//xla/tests:half_test_gpu", - "//xla/tests:iota_test_gpu", - "//xla/tests:local_client_allocation_test_gpu", - "//xla/tests:local_client_execute_test_gpu", - "//xla/tests:log_test_gpu", - "//xla/tests:map_test_gpu", - "//xla/tests:matmul_test_gpu", - "//xla/tests:matrix_ops_simple_test_gpu", - "//xla/tests:multidimensional_slice_test_gpu", - "//xla/tests:multioutput_fusion_test_gpu", - "//xla/tests:outfeed_in_nested_computation_test_gpu", - "//xla/tests:pad_test_gpu", - "//xla/tests:params_test_gpu", - "//xla/tests:pred_test_gpu", - "//xla/tests:prng_test_gpu", - "//xla/tests:ptxas_bug_120501638_gpu", - "//xla/tests:query_inferred_shape_test_gpu", - "//xla/tests:reduce_hlo_test_gpu", - "//xla/tests:reduce_precision_test_gpu", - "//xla/tests:reduce_test_gpu", - "//xla/tests:reduce_window_test_gpu", - "//xla/tests:replay_test_gpu", - "//xla/tests:reshape_motion_test_gpu", - "//xla/tests:reshape_test_gpu", - "//xla/tests:reverse_test_gpu", - "//xla/tests:round_trip_packed_literal_test_gpu", - "//xla/tests:round_trip_transfer_test_gpu", - "//xla/tests:sample_text_test_gpu", - "//xla/tests:scalar_computations_test_gpu", - "//xla/tests:scatter_test_gpu", - "//xla/tests:select_and_scatter_test_gpu", - "//xla/tests:select_test_gpu", - "//xla/tests:slice_test_gpu", - "//xla/tests:token_hlo_test_gpu", - "//xla/tests:transfer_manager_test_gpu", - "//xla/tests:transpose_test_gpu", - "//xla/tests:triangular_solve_test_gpu", - "//xla/tests:tuple_test_gpu", - "//xla/tests:unary_op_test_gpu", - "//xla/tests:value_inference_test_gpu", - "//xla/tests:vector_ops_reduce_test_gpu", - "//xla/tests:vector_ops_simple_test_gpu", - "//xla/tests:while_test_gpu", - "//xla/tests:xla_hlo_profile_test_gpu", - ] + if_google([ - # Currently fails in OSS. - "//tensorflow/python/kernel_tests/signal:fft_ops_test_xla_gpu", - ]), -) - cc_library( name = "metrics", srcs = ["metrics.cc"], @@ -4499,6 +5340,44 @@ cc_library( ], ) +cc_library( + name = "dot_operand_converter", + srcs = ["dot_operand_converter.cc"], + hdrs = ["dot_operand_converter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:op_expander_pass", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + ], +) + +xla_test( + name = "dot_operand_converter_test", + srcs = if_cuda_is_configured(["dot_operand_converter_test.cc"]), + backends = [ + "gpu_a100", + "gpu_p100", + "gpu_v100", + ], + deps = if_cuda_is_configured([ + ":dot_operand_converter", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:statusor", + ]), +) + cc_library( name = "make_batch_pointers", srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), @@ -4510,12 +5389,13 @@ cc_library( "//xla:util", "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor/gpu:gpu_stream_header", + "@com_google_absl//absl/status", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ ":make_batch_pointers_kernel", ]) + if_rocm_is_configured([ - "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/rocm:rocm_helpers", ]), ) @@ -4523,7 +5403,9 @@ cc_library( cuda_library( name = "make_batch_pointers_kernel", srcs = if_cuda_is_configured(["make_batch_pointers.cu.cc"]), - deps = ["@local_config_cuda//cuda:cuda_headers"], + deps = [ + "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep + ], ) cc_library( @@ -4532,11 +5414,17 @@ cc_library( hdrs = ["triangular_solve_rewriter.h"], deps = [ ":cublas_cudnn", + "//xla:shape_util", "//xla:statusor", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4555,8 +5443,9 @@ tsl_gpu_library( "//xla/service:custom_call_target_registry", "//xla/service:platform_util", "//xla/stream_executor", - "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -4589,6 +5478,7 @@ cc_library( "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -4606,6 +5496,8 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", ], ) @@ -4616,10 +5508,20 @@ cc_library( hdrs = ["scatter_slice_simplifier.h"], deps = [ "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4633,6 +5535,7 @@ xla_cc_test( "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -4643,10 +5546,15 @@ cc_library( deps = [ ":cublas_cudnn", "//xla:shape_util", + "//xla:status_macros", "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:statusor", ], ) @@ -4661,6 +5569,7 @@ cc_library( "//xla:executable_run_options", "//xla:shape_util", "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", @@ -4672,6 +5581,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:statusor", ], @@ -4682,7 +5593,6 @@ cc_library( srcs = ["topk_splitter.cc"], hdrs = ["topk_splitter.h"], deps = [ - "//xla:literal_util", "//xla:shape_util", "//xla:status", "//xla:statusor", @@ -4693,6 +5603,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:statusor", @@ -4704,7 +5616,6 @@ xla_cc_test( srcs = ["topk_splitter_test.cc"], deps = [ ":topk_splitter", - "//xla:error_spec", "//xla/hlo/ir:hlo", "//xla/service:hlo_dce", "//xla/service:pattern_matcher", @@ -4713,13 +5624,40 @@ xla_cc_test( "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) +xla_cc_test( + name = "topk_test", + srcs = ["topk_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), + deps = [ + ":topk_specializer", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/service:gpu_plugin", + "//xla/service:hlo_pass", + "//xla/service:platform_util", + "//xla/service:topk_rewriter", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "copy_fusion", srcs = ["copy_fusion.cc"], @@ -4731,7 +5669,57 @@ cc_library( "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "algorithm_checker", + srcs = ["algorithm_checker.cc"], + hdrs = ["algorithm_checker.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algorithm_util", + "//xla/service:hlo_pass", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + ], +) + +xla_test( + name = "dot_algorithm_support_test", + srcs = if_cuda_is_configured(["dot_algorithm_support_test.cc"]), + backends = [ + "gpu_v100", + "gpu_a100", + ], + tags = [ + "nomac", + ], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", ], ) @@ -4744,7 +5732,11 @@ cc_library( ":launch_dimensions", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", ], ) @@ -4760,12 +5752,14 @@ cc_library( "//xla:status", "//xla:statusor", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -4775,9 +5769,14 @@ cc_library( hdrs = ["hlo_traversal.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) @@ -4799,11 +5798,14 @@ cc_library( hdrs = ["fusion_wrapper.h"], deps = [ ":gpu_fusible", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", ], @@ -4824,11 +5826,13 @@ xla_cc_test( srcs = ["copy_fusion_test.cc"], deps = [ ":copy_fusion", + "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], ) @@ -4837,15 +5841,27 @@ xla_cc_test( srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), deps = if_cuda_is_configured([ ":autotuner_util", - "//xla:autotune_results_proto_cc", "@com_google_googletest//:gtest", "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "//xla:autotune_results_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor/host:host_platform", "//xla/tests:hlo_test_base", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:protobuf", - ]) + ["//xla/tests:xla_internal_test_main"], + ]) + [ + "//xla/tests:xla_internal_test_main", # Keep outside GPU guard + ], ) cc_library( @@ -4867,6 +5883,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -4875,33 +5893,31 @@ cc_library( xla_cc_test( name = "loop_double_buffer_transformer_test", - srcs = if_gpu_is_configured(["loop_double_buffer_transformer_test.cc"]), - tags = tf_cuda_tests_tags(), - deps = if_gpu_is_configured([ - ":gpu_compiler", + srcs = ["loop_double_buffer_transformer_test.cc"], + deps = [ ":loop_double_buffer_transformer", "//xla:test", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:tuple_simplifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - ]) + [ - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:hlo_dce", "@com_google_absl//absl/container:flat_hash_set", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], ) -xla_cc_test( +xla_test( name = "determinism_test", srcs = ["determinism_test.cc"], + backends = [ + "gpu_a100", + ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - tags = tf_gpu_tests_tags(), deps = [ ":autotuner_util", "//xla:literal", @@ -4927,3 +5943,196 @@ cc_library( "//xla/service:symbol_repository", ], ) + +cc_library( + name = "collective_permute_cycle_decomposer", + srcs = ["collective_permute_cycle_decomposer.cc"], + hdrs = ["collective_permute_cycle_decomposer.h"], + deps = [ + ":backend_configs_cc", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "collective_permute_cycle_decomposer_test", + srcs = ["collective_permute_cycle_decomposer_test.cc"], + deps = [ + ":collective_permute_cycle_decomposer", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "stream_attribute_annotator", + srcs = ["stream_attribute_annotator.cc"], + hdrs = ["stream_attribute_annotator.h"], + deps = [ + ":backend_configs_cc", + "//xla:comparison_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_annotator_test", + srcs = ["stream_attribute_annotator_test.cc"], + deps = [ + ":backend_configs_cc", + ":stream_attribute_annotator", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "stream_attribute_async_wrapper", + srcs = ["stream_attribute_async_wrapper.cc"], + hdrs = ["stream_attribute_async_wrapper.h"], + deps = [ + ":backend_configs_cc", + "//xla:comparison_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_async_wrapper_test", + srcs = ["stream_attribute_async_wrapper_test.cc"], + deps = [ + ":backend_configs_cc", + ":stream_attribute_async_wrapper", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gpu_windowed_einsum_handler", + srcs = ["gpu_windowed_einsum_handler.cc"], + hdrs = ["gpu_windowed_einsum_handler.h"], + deps = [ + ":backend_configs_cc", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gpu_windowed_einsum_handler_test", + srcs = ["gpu_windowed_einsum_handler_test.cc"], + deps = [ + ":backend_configs_cc", + ":gpu_windowed_einsum_handler", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/service/gpu/address_computation_fusion_rewriter.cc b/xla/service/gpu/address_computation_fusion_rewriter.cc new file mode 100644 index 0000000000000..5b92eb3423ebb --- /dev/null +++ b/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -0,0 +1,515 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/address_computation_fusion_rewriter.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +// A dataflow path flowing from a definition to a user. +using DefUseDataflowPath = absl::InlinedVector; + +// All dataflow paths flowing from a definition to all users. Each user will +// have a separate entry in the vector. +using DefUseDataflowPaths = absl::InlinedVector; + +// A dataflow path flowing from a user to a definition. +using UseDefDataflowPath = absl::InlinedVector; + +// All dataflow paths flowing from a user to all definitions of its operands. +using UseDefDataflowPaths = absl::InlinedVector; + +using DataflowPathView = absl::Span; +using DataflowPathsView = absl::Span; + +using InstructionSet = absl::flat_hash_set; + +bool IsNoOp(const HloInstruction* hlo) { + return HloPredicateIsOp(hlo); +} + +bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) { + auto* custom_call = DynCast(hlo); + if (custom_call == nullptr) return false; + + // TODO(vuson): properly handle token by following + // `LhloDialectEmitter::EmitCustomCallOp`'s `CreateOperands` logic for + // `LhloDialectEmitter::EmitFusionOp`'s `RewriteFusionOperand` + if (custom_call->shape().IsTuple() && + absl::c_any_of( + custom_call->shape().tuple_shapes(), + [&](const Shape& sub_shape) { return sub_shape.IsToken(); })) + return false; + + const std::string call_target_name = custom_call->custom_call_target(); + + bool is_ffi_custom_call = + custom_call->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; + + void* call_target = CustomCallTargetRegistry::Global()->Lookup( + call_target_name, std::string(platform_name)); + + absl::StatusOr handler_registration = + ffi::FindHandler(call_target_name, platform_name); + + // At least one implementation should be available at run time. + bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; + bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok(); + + return found_custom_call || found_ffi_handler; +} + +// Returns true if the slice is 128-byte-aligned. The slice starting +// address is determined by the product of all non-sliced dimensions and an +// offset defined by `slice_starts` of the slice op. +// +// For dynamic cases, we don't have info about the start indices, so we have to +// be conservative by only accepting sliced shapes that have the product of all +// non-sliced dimensions being a multiple of `kXlaAllocatedBufferAlignBytes`. +bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape, + const HloSliceInstruction* slice) { + if (!IsContiguousSlice(src_shape, dst_shape)) return false; + + auto strides = ShapeUtil::ByteStrides(dst_shape); + if (!strides.has_value()) return false; + + for (auto dim : dst_shape.layout().minor_to_major()) { + if ((strides.value()[dim] % kXlaAllocatedBufferAlignBytes) == 0) + return true; + if (dst_shape.dimensions(dim) < src_shape.dimensions(dim)) { + return (slice != nullptr && + ((strides.value()[dim] * slice->slice_starts(dim)) % + kXlaAllocatedBufferAlignBytes == + 0)); + } + } + return true; +} + +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { + UseDefDataflowPaths sliced_operand_paths; + + auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); + // This set is used to avoid duplicates in the matched results. It contains + // the matched instructions that we have seen so far. + InstructionSet processed_instrs; + + const auto& aliasing_pairs = + Cast(instr)->output_to_operand_aliasing(); + absl::flat_hash_set aliased_operands; + for (const auto& pair : aliasing_pairs) { + aliased_operands.insert(pair.second.first); + } + + for (auto* operand : instr->operands()) { + // output_to_operand_aliasing means the operand is to be materialized, which + // is against the whole idea of address computation fusion. Skip this + // operand. + if (aliased_operands.contains(instr->operand_index(operand))) continue; + UseDefDataflowPath maybe_sliced_operand_path; + bool slice_found = false; + // TODO: currently HloFindIf exits upon encountering the first node that + // matches. This works well if each operand only has 1 data flow (i.e. only + // flows through unary op). We might want to keep finding until the queue is + // empty: if the operand is a tuple, it might have different data flows + // (i.e. 1 for each element). + auto maybe_slice_adaptor = + HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) { + const HloInstruction* cur = &node.instruction(); + + // If the node is a match that has been processed, stop the traversal. + if (processed_instrs.contains(cur)) return true; + + maybe_sliced_operand_path.push_back(const_cast(cur)); + + if (IsOpcodeAnyOf( + node)) { + if (IsAlignedSlice(cur->operand(0)->shape(), cur->shape(), + DynCast(cur))) { + slice_found = true; + return slice_found; + } + } + + // TODO(vuson): lift the first restriction by considering fusing other + // uses of the operand to reuse the address computation. Only worth it + // if other uses are also custom calls though. + return cur->user_count() > 1 || !IsNoOp(cur); + }); + + if (maybe_slice_adaptor == std::nullopt) continue; + + const auto& maybe_slice_instr = maybe_slice_adaptor->instruction(); + + if (slice_found || processed_instrs.contains(&maybe_slice_instr)) { + // Even in the case of stopping at a match that has been processed, we + // still need to add instructions encountered in the sliced operand path + // during the latest traversal. + sliced_operand_paths.insert(sliced_operand_paths.end(), + maybe_sliced_operand_path.rbegin(), + maybe_sliced_operand_path.rend()); + processed_instrs.insert(maybe_sliced_operand_path.begin(), + maybe_sliced_operand_path.end()); + } + } + + sliced_operand_paths.push_back(const_cast(instr)); + return sliced_operand_paths; +} + +// Each user of `instr` that goes into a DUS will have an entry in the returned +// vector. +// Each entry contains the sliced paths for that user, i.e. the sequence of ops +// following the dataflow from the user itself to the DUS (included). +DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { + DefUseDataflowPaths sliced_user_paths; + auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); + // This set is used to avoid duplicates in the matched results. It contains + // the matched instructions that we have seen so far. + InstructionSet processed_instrs; + + auto traverse_hlo_and_collect = [&](HloInstruction* start) { + DefUseDataflowPath maybe_sliced_user_path; + bool dus_found = false; + auto maybe_dus_adaptor = HloFindIf( + {HloInstructionAdaptor(*start)}, *fusion, + [&](auto node) { + const HloInstruction* cur = &node.instruction(); + // If the node is a match that has been processed, stop the + // traversal. + if (processed_instrs.contains(cur)) return true; + maybe_sliced_user_path.push_back(const_cast(cur)); + if (const auto slice_instr = + DynCast(cur)) { + if (IsAlignedSlice(slice_instr->shape(), + slice_instr->update()->shape(), nullptr)) { + dus_found = true; + return true; + } + } + return cur->user_count() > 1 || !IsNoOp(cur); + }, + /*visit_operands=*/false); + if (maybe_dus_adaptor == std::nullopt) return; + const auto& maybe_dus_instr = maybe_dus_adaptor->instruction(); + if (dus_found || processed_instrs.contains(&maybe_dus_instr)) { + // Even in the case of stopping at a match that has been processed, we + // still need to add instructions encountered in the sliced user path + // during the latest traversal. + processed_instrs.insert(maybe_sliced_user_path.begin(), + maybe_sliced_user_path.end()); + sliced_user_paths.push_back(std::move(maybe_sliced_user_path)); + } + }; + + if (instr->shape().IsTuple()) { + for (auto* user : instr->users()) { + if (DynCast(user)) { + traverse_hlo_and_collect(user); + } + } + } else { + if (instr->user_count() == 1) { + traverse_hlo_and_collect(instr->users().front()); + } + } + + return sliced_user_paths; +} + +absl::InlinedVector GetPatternCaptures( + DataflowPathView matches) { + absl::InlinedVector captures; + + InstructionSet matched_instrs(matches.begin(), matches.end()); + + for (HloInstruction* instr : matches) { + for (HloInstruction* operand : instr->operands()) { + if (!matched_instrs.contains(operand) && + absl::c_find(captures, operand) == captures.end()) { + captures.emplace_back(operand); + } + } + } + + return captures; +} + +Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder, + DataflowPathsView sliced_user_paths, + absl::flat_hash_map& instr_mapping) { + unsigned tuple_size = hero->shape().tuple_shapes_size(); + + std::vector sliced_elems(tuple_size, nullptr); + for (auto& sliced_user_path : sliced_user_paths) { + auto gte = Cast(sliced_user_path.front()); + sliced_elems[gte->tuple_index()] = sliced_user_path.back(); + } + + std::vector elements; + for (size_t i = 0; i < tuple_size; ++i) { + if (sliced_elems[i] != nullptr) { + elements.push_back(instr_mapping[sliced_elems[i]]); + continue; + } + auto* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(instr_mapping[hero], i)); + if (hero->shape().tuple_shapes(i).IsTuple()) { + instr_mapping[gte] = gte; + TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping)); + elements.push_back(builder.last_added_instruction()); + } else { + elements.push_back(gte); + } + } + if (elements.size() > 1) + builder.AddInstruction(HloInstruction::CreateTuple(elements)); + + return absl::OkStatus(); +} + +absl::StatusOr CreateFusionBody( + HloModule* module, DataflowPathView sliced_operand_paths, + DataflowPathsView sliced_user_paths, DataflowPathView captures) { + HloComputation::Builder builder("address-computation"); + + // A mapping from original instructions to instructions in the fusion body. + absl::flat_hash_map instr_mapping; + + auto mapped_operands = [&](HloInstruction* instr) { + absl::InlinedVector operands; + for (HloInstruction* operand : instr->operands()) { + operands.push_back(instr_mapping.at(operand)); + } + return operands; + }; + + // For every captured value create a parameter instruction in the computation + // body and set up instruction mapping. + for (const HloInstruction* capture : captures) { + int64_t index = instr_mapping.size(); + instr_mapping[capture] = + builder.AddInstruction(HloInstruction::CreateParameter( + index, capture->shape(), absl::StrCat("p", index))); + } + + // Instructions in the pattern are already topologically sorted, as we visited + // them following use-def path, then reverse the list. + HloInstruction* hero; + for (HloInstruction* instr : sliced_operand_paths) { + instr_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); + hero = instr; + } + + for (auto& sliced_user_path : sliced_user_paths) { + for (HloInstruction* instr : sliced_user_path) { + instr_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); + } + } + + // Create a tuple if the hero is a tuple to make sure there's a buffer + // assigned for each of the elements. Make sure the tuple is not nil first. + if (hero->shape().IsTuple() && hero->shape().tuple_shapes_size() > 0) { + TF_RETURN_IF_ERROR( + CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping)); + } + + return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); +} + +absl::StatusOr CreateFusionInstruction( + HloModule* module, HloInstruction* orig, DataflowPathView captures, + HloComputation* body, bool dynamic) { + HloComputation* parent = orig->parent(); + + // Add a fusion operation calling outlined fusion computation. + HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( + body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom, + captures, body)); + module->SetAndUniquifyInstrName(fusion, "address_computation"); + + // We don't need to set/update output_to_operand_aliasing for the new fusion + // instruction because all buffers are already assigned at this point. + + // Set backends config to a matched custom fusion config. + GpuBackendConfig gpu_config; + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind("__custom_fusion"); + CustomFusionConfig config; + config.set_name(dynamic ? "dynamic_address_computation" + : "address_computation"); + *backend_config.mutable_custom_fusion_config() = config; + TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); + + return fusion; +} + +} // namespace + +absl::StatusOr AddressComputationFusionRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + absl::flat_hash_map> + matches; + + // Collect all potential custom call matches in the non-fusion computations. + for (HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) continue; + for (HloInstruction* instr : computation->instructions()) { + if (IsLegacyCublasMatmul(*instr) || + (IsCustomCall(instr, platform_name_))) { + UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); + bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; + + DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); + bool has_sliced_user_paths = absl::c_any_of( + sliced_user_paths, + [&](auto& sliced_user_path) { return !sliced_user_path.empty(); }); + + if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { + return DynCast( + sliced_user_path.back()) == nullptr; + })) { + return absl::InternalError( + "Expect sliced user path to end with a DUS."); + } + + if (has_sliced_operand_paths || has_sliced_user_paths) { + matches[instr] = std::make_pair(std::move(sliced_operand_paths), + std::move(sliced_user_paths)); + } + } + } + } + + if (matches.empty()) return false; + + for (auto& [hero, paths] : matches) { + auto& [sliced_operand_paths, sliced_user_paths] = paths; + std::vector matched_instrs; + absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); + + std::vector sliced_user_paths_view; + for (auto& sliced_user_path : sliced_user_paths) { + absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs)); + DataflowPathView sliced_user_path_view{&sliced_user_path.front(), + sliced_user_path.size()}; + sliced_user_paths_view.push_back(std::move(sliced_user_path_view)); + } + + auto captures = GetPatternCaptures(matched_instrs); + + TF_ASSIGN_OR_RETURN( + HloComputation * fusion_body, + CreateFusionBody(module, sliced_operand_paths, + DataflowPathsView(sliced_user_paths_view), captures)); + + bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) { + return DynCast(instr) != nullptr; + }); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, hero, captures, fusion_body, + has_dynamic_slices)); + + HloComputation* parent = hero->parent(); + if (fusion->shape().IsTuple()) { + TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( + const_cast(hero), fusion)); + for (auto& sliced_user_path : sliced_user_paths) { + auto old_gte = + Cast(sliced_user_path.front()); + HloInstruction* gte = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + fusion, old_gte->tuple_index())); + TF_RETURN_IF_ERROR( + parent->ReplaceInstruction(sliced_user_path.back(), gte)); + } + } else { + auto* instr_to_be_replaced = const_cast(hero); + if (sliced_user_paths.empty()) { + // The only case where a tuple-shaped original hero op is fused into a + // non-tuple-shaped fusion is there's only one element of the original + // tuple being used. In that case, we need to replace that single + // get-tuple-element (instead of the hero op) with the fusion + // instruction. + if (hero->shape().IsTuple()) { + if (hero->user_count() != 1 || + !DynCast(hero->users().front())) { + return absl::InternalError( + "Expect a single get-tuple-element user of the original " + "tuple-shaped hero op when address computation fusion does " + "not return a tuple"); + } + instr_to_be_replaced = hero->users().front(); + } + } else { + instr_to_be_replaced = sliced_user_paths.front().back(); + } + TF_RETURN_IF_ERROR( + parent->ReplaceInstruction(instr_to_be_replaced, fusion)); + } + } + + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/address_computation_fusion_rewriter.h b/xla/service/gpu/address_computation_fusion_rewriter.h new file mode 100644 index 0000000000000..9a00a7c11f889 --- /dev/null +++ b/xla/service/gpu/address_computation_fusion_rewriter.h @@ -0,0 +1,91 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_ADDRESS_COMPUTATION_FUSION_REWRITER_H_ +#define XLA_SERVICE_GPU_ADDRESS_COMPUTATION_FUSION_REWRITER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Pattern matches (slice(s) + custom call) to custom address computation +// fusions and rewrites them into fusion instructions and fusion computations. +// +// Example: +// +// ENTRY %main { +// %p0 = bf16[2,8,8]{2,1,0} parameter(0) +// %p1 = bf16[2,8,8]{2,1,0} parameter(1) +// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} +// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs) +// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} +// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs) +// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs), +// custom_call_target="__cublas$gemm" +// } +// +// After the pass: +// +// %address_computation { +// %p0 = bf16[2,8,8]{2,1,0} parameter(0) +// %p1 = bf16[2,8,8]{2,1,0} parameter(1) +// %slice_lhs = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} +// %bitcast_lhs = bf16[8,8]{1,0} bitcast(%slice_lhs) +// %slice_rhs = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} +// %bitcast_rhs = bf16[8,8]{1,0} bitcast(%slice_rhs) +// ROOT %dot = bf16[8,8]{1,0} custom-call(%bitcast_lhs, %bitcast_rhs), +// custom_call_target="__cublas$gemm" +// } +// +// ENTRY %main { +// %p0 = bf16[2,8,8]{2,1,0} parameter(0) +// %p1 = bf16[2,8,8]{2,1,0} parameter(1) +// ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1), +// kind=kCustom, calls=%address_computation, +// backend_config={"fusion_backend_config":{ +// "kind":"__custom_fusion", +// "custom_fusion_config":{"name":"address_computation"} +// }} +// } +// +class AddressComputationFusionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "address-computation-fusion-rewriter"; + } + + explicit AddressComputationFusionRewriter(std::string platform_name) + : platform_name_(std::move(platform_name)) {} + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + std::string platform_name_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_ADDRESS_COMPUTATION_FUSION_REWRITER_H_ diff --git a/xla/service/gpu/address_computation_fusion_rewriter_test.cc b/xla/service/gpu/address_computation_fusion_rewriter_test.cc new file mode 100644 index 0000000000000..4d14024115a62 --- /dev/null +++ b/xla/service/gpu/address_computation_fusion_rewriter_test.cc @@ -0,0 +1,1769 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/address_computation_fusion_rewriter.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "xla/client/lib/constants.h" +#include "xla/client/xla_builder.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/buffer_value.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/hlo_memory_scheduler.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/stream.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +#define PLATFORM "GPU" +namespace xla::gpu { + +class AddressComputationFusionRewriterTest : public HloTestBase {}; + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemm) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWithWorkspace) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) + ; CHECK: tuple([[DOT]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWorkspaceIgnored) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) + ; CHECK: tuple([[DOT]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0 + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotRoot) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1) + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, + SimpleGemmOperandHasMultipleUsers) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[4,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %bitcast.41) + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(1) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1) + ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[B1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B1]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, + SimpleGemmOperandsHaveMultipleUsers) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + std::nullopt); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmSlicingNotParameter) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[4,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.12 = f16[2,8,8]{2,1,0} slice(%p0), slice={[0:2], [0:8], [0:8]} + %slice.13 = f16[1,8,8]{2,1,0} slice(%slice.12), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT %res = f16[8,8]{1,0} add(%custom-call.1, %custom-call.1) + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[2,8,8]{2,1,0} slice([[P0]]), slice={[0:2], [0:8], [0:8]} + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[S0]], [[P1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotContiguousSlice) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,4,6]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:6]} + %bitcast.41 = f16[4,6]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,6,4]{2,1,0} slice(%p1), slice={[1:2], [0:6], [0:4]} + %bitcast.42 = f16[6,4]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[4,4]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + std::nullopt); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %slice.14 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %add.0 = f16[1,8,8]{2,1,0} add(%slice.13, %slice.14) + %bitcast.41 = f16[8,8]{1,0} bitcast(%add.0) + %slice.15 = f16[1,8,8]{2,1,0} slice(%p1), slice={[0:1], [0:8], [0:8]} + %slice.16 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %add.1 = f16[1,8,8]{2,1,0} add(%slice.15, %slice.16) + %bitcast.42 = f16[8,8]{1,0} bitcast(%add.1) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + std::nullopt); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmDuplicateOperand) { + const char* hlo = R"( + HloModule test + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[0:100], [0:100]} + ; CHECK-NOT: slice + ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) custom-call([[S0]], [[S0]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[80000]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(1) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(0) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder2) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: [[A1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]], [[A1]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[16:116], [0:100]} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34), + custom_call_target="__cublas$gemm", + output_to_operand_aliasing={{0}: (2, {})}, + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2) + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1) + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]} + ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[P:%[^ ]+]] = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + ; CHECK: [[GTE0:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=0 + ; CHECK: [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1 + ; CHECK: [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0} + ; CHECK: [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]} + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsFromSameSlice) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %bitcast.42 = f16[8,8]{0,1} bitcast(%slice.13) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[0:1], [0:8], [0:8]} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{0,1} bitcast([[S0]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: [[A0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[A0]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, + ffi::BufferBase dst) { + return stream->MemcpyD2D( + &dst.data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Arg() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + +TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCall) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]} + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]), + ; CHECK: custom_call_target="__xla_test$$memcpy", + ; CHECK: api_version=API_VERSION_TYPED_FFI + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42) + ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]]) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); +} + +void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) {} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, PLATFORM); + +TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCallLegacy) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Void", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[256]{0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[128]{0} slice([[P0]]), slice={[0:128]} + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[S0]]), + ; CHECK: custom_call_target="Callback_Void" + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: [[C0:%[^ ]+]] = f32[] constant(42) + ; CHECK: [[BC:%[^ ]+]] = f32[256]{0} broadcast([[C0]]) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion([[BC]]) + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); +} + +TEST_F(AddressComputationFusionRewriterTest, TupleSliceCustomCallLegacy) { + XlaBuilder b(TestName()); + CustomCall( + &b, "Callback_Void", + /*operands=*/ + { + Tuple(&b, + { + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), + {0, 0}, {4, 8}, {1, 1}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + }, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]} + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1) + ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]]) + ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2) + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]), + ; CHECK: custom_call_target="Callback_Void" + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[128]{0} fusion( + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); +} + +TEST_F(AddressComputationFusionRewriterTest, TupledOutputCustomCallLegacy) { + XlaBuilder b(TestName()); + auto custom_call = CustomCall( + &b, "Callback_Void", + /*operands=*/ + { + Tuple(&b, + { + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), + {0, 0}, {4, 8}, {1, 1}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + }), + /*opaque=*/""); + Tuple(&b, {GetTupleElement(GetTupleElement(custom_call, 1), 0), + GetTupleElement(custom_call, 2)}); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + const char* expected = R"( + ; CHECK: %address-computation {{.*}} { + ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2) + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1) + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]} + ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]]) + ; CHECK: [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]), + ; CHECK: custom_call_target="Callback_Void" + ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0 + ; CHECK-DAG: [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1 + ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE1]]), index=0 + ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[256]{0} get-tuple-element([[GTE1]]), index=1 + ; CHECK-DAG: [[T1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) tuple([[GTE2]], [[GTE3]]) + ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1024]{0} get-tuple-element([[CC]]), index=2 + ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,8]{1,0} get-tuple-element([[CC]]), index=3 + ; CHECK: ROOT {{.*}} = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) tuple([[GTE0]], [[T1]], [[GTE4]], [[GTE5]]) + ; CHECK: } + + ; CHECK: ENTRY %{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: } + ; CHECK-DAG: [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2 + ; CHECK-DAG: [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1 + ; CHECK-DAG: [[GTE8:%[^ ]+]] = f32[128]{0} get-tuple-element([[GTE7]]), index=0 + ; CHECK: ROOT {{.*}} = (f32[128]{0}, f32[1024]{0}) tuple([[GTE8]], [[GTE6]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); +} + +TEST_F(AddressComputationFusionRewriterTest, UnalignedSlice) { + XlaBuilder b(TestName()); + CustomCall( + &b, "Callback_Void", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, S32, 42), {17}), {1}, {17}, {1})}, + ShapeUtil::MakeShape(S32, {16}), /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo->ToString(), + AddressComputationFusionRewriter(PLATFORM), + std::nullopt); +} + +TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemm) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + ROOT custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmWithWorkspace) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + ROOT custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) + ; CHECK: tuple([[DOT]], [[WORKSPACE]]) + ; CHECK: } + + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, + DynamicSimpleGemmWorkspaceIgnored) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0 + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) + ; CHECK: tuple([[DOT]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0 + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmNotRoot) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + ROOT res = f16[8,8]{1,0} add(custom-call.1, custom-call.1) + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemm) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[1,8,8]{2,1,0} parameter(0) + p1 = f16[1,8,8]{2,1,0} parameter(1) + p2 = f16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + bitcast.41 = f16[8,8]{1,0} bitcast(p0) + bitcast.42 = f16[8,8]{1,0} bitcast(p1) + + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32) + } + )"; + + const char* expected = R"( + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1) + ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4) + ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]), + ; CHECK-DAG: custom_call_target="__cublas$gemm" + ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]]) + ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmNotRoot) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + p2 = f16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32) + ROOT res = f16[4,8,8]{2,1,0} log(dus) + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]), + ; CHECK-DAG: custom_call_target="__cublas$gemm" + ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]]) + ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWithWorkspace) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + p2 = f16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = f16[8,8]{1,0} bitcast(slice.13) + slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = f16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + + get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(custom-call.1), index=0 + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(get-tuple-element.0) + dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32) + get-tuple-element.1 = s8[256]{0} get-tuple-element(custom-call.1), index=1 + ROOT tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(dus, get-tuple-element.1) + } + )"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3) + ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(4) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2) + ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[B0]], [[B1]]), + ; CHECK: custom_call_target="__cublas$gemm" + ; CHECK: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]]) + ; CHECK: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]]) + ; CHECK: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) + ; CHECK: tuple([[DUS]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0 + ; CHECK: [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1 + ; CHECK: ROOT {{.*}} = (f16[4,8,8]{2,1,0}, s8[256]{0}) + ; CHECK: tuple([[DUS_MAIN]], [[WORKSPACE_MAIN]]) + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) { + const char* hlo = R"( + HloModule test + + ENTRY %main.9 { + %p0 = f16[8,8]{1,0} parameter(0) + %p1 = f16[8,8]{1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + })"; + + const char* expected = R"( + ; CHECK: address-computation {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(1) + ; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(2) + ; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(3) + ; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(4) + ; CHECK-DAG: [[CC:%[^ ]+]] = (f16[8,8]{1,0}, s8[256]{0}) custom-call([[P0]], [[P1]]), + ; CHECK-DAG: custom_call_target="__cublas$gemm" + ; CHECK-DAG: [[DOT:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[CC]]), index=0 + ; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[DOT]]) + ; CHECK-DAG: [[DUS:%[^ ]+]] = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]]) + ; CHECK-DAG: [[WORKSPACE:%[^ ]+]] = s8[256]{0} get-tuple-element([[CC]]), index=1 + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) + ; CHECK: tuple([[DUS]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion + ; CHECK: kind=kCustom, calls=%address-computation, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: } + ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0 + ; CHECK: } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), + expected); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/algorithm_checker.cc b/xla/service/gpu/algorithm_checker.cc new file mode 100644 index 0000000000000..3104293f8d255 --- /dev/null +++ b/xla/service/gpu/algorithm_checker.cc @@ -0,0 +1,117 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/algorithm_checker.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/algorithm_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace gpu { + +namespace { + +bool HasNonDefaultOperandPrecision(const PrecisionConfig& config) { + return absl::c_any_of(config.operand_precision(), [](int precision) { + return static_cast(precision) != + PrecisionConfig::DEFAULT; + }); +} + +class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit AlgorithmCheckerVisitor( + se::GpuComputeCapability gpu_compute_capability) + : gpu_compute_capability_(std::move(gpu_compute_capability)) {} + + absl::Status RunOnModule( + const HloModule* module, + const absl::flat_hash_set& execution_threads = {}) { + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_RETURN_IF_ERROR(computation->Accept(this)); + } + return absl::OkStatus(); + } + + absl::Status HandleDot(const HloInstruction* hlo) override { + VLOG(1) << "Handling dot: " << hlo->ToString(); + const PrecisionConfig& config = hlo->precision_config(); + + if (config.algorithm() != PrecisionConfig::ALG_UNSET && + HasNonDefaultOperandPrecision(config)) { + LOG(WARNING) + << "There is no need to set precisions when we set the algorithm: " + << hlo->ToString(); + } + + if (config.algorithm() == PrecisionConfig::ALG_UNSET) { + return absl::OkStatus(); + } + + PrimitiveType lhs_storage_type = hlo->operand(0)->shape().element_type(); + PrimitiveType rhs_storage_type = hlo->operand(1)->shape().element_type(); + PrimitiveType output_storage_type = hlo->shape().element_type(); + + if (lhs_storage_type != rhs_storage_type) { + return absl::UnimplementedError(absl::StrFormat( + "Dot operands must have the same type when using an algorithm: %s", + hlo->ToString())); + } + + return algorithm_util::IsSupportedDotAlgorithmOnGpu( + config.algorithm(), gpu_compute_capability_, lhs_storage_type, + output_storage_type) + ? absl::OkStatus() + : absl::UnimplementedError(absl::StrFormat( + "Unsupported algorithm on the current device(s): %s", + PrecisionConfig::Algorithm_Name(config.algorithm()))); + } + + absl::Status DefaultAction(const HloInstruction* hlo) override { + return absl::OkStatus(); + } + + private: + se::GpuComputeCapability gpu_compute_capability_; +}; + +} // namespace + +absl::StatusOr AlgorithmChecker::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_) + .RunOnModule(module, execution_threads)); + // No change was made. + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/algorithm_checker.h b/xla/service/gpu/algorithm_checker.h new file mode 100644 index 0000000000000..f3b30c1c61f5f --- /dev/null +++ b/xla/service/gpu/algorithm_checker.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ +#define XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +// This checks if the requested algorithms are supported. This can give an early +// and specific error if an unsupported algorithm is requested. +// +// Note: Maybe we can make this more generic and move it outside of GPU. +class AlgorithmChecker : public HloModulePass { + public: + explicit AlgorithmChecker(se::GpuComputeCapability gpu_compute_capability) + : gpu_compute_capability_(std::move(gpu_compute_capability)){}; + + absl::string_view name() const override { return "algorithm-checker"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + se::GpuComputeCapability gpu_compute_capability_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ diff --git a/xla/service/gpu/alias_passthrough_params.cc b/xla/service/gpu/alias_passthrough_params.cc index e4657be40bf21..5dea5bc548374 100644 --- a/xla/service/gpu/alias_passthrough_params.cc +++ b/xla/service/gpu/alias_passthrough_params.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,14 +14,20 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/alias_passthrough_params.h" +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" namespace xla { namespace gpu { -StatusOr AliasPassthroughParams::Run( +absl::StatusOr AliasPassthroughParams::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { const HloInstruction* root = module->entry_computation()->root_instruction(); diff --git a/xla/service/gpu/alias_passthrough_params.h b/xla/service/gpu/alias_passthrough_params.h index 4483691dfe1b6..029068a6b5b5c 100644 --- a/xla/service/gpu/alias_passthrough_params.h +++ b/xla/service/gpu/alias_passthrough_params.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_ #define XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_ -#include - +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_pass_interface.h" @@ -38,7 +39,7 @@ class AliasPassthroughParams : public HloModulePass { absl::string_view name() const override { return "alias_passthrough_params"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/alias_passthrough_params_test.cc b/xla/service/gpu/alias_passthrough_params_test.cc index 27044e4b36a8b..d8141232ebbd3 100644 --- a/xla/service/gpu/alias_passthrough_params_test.cc +++ b/xla/service/gpu/alias_passthrough_params_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/alias_passthrough_params.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/all_reduce_blueconnect.cc b/xla/service/gpu/all_reduce_blueconnect.cc index 017edf054cc8c..033f255d7e568 100644 --- a/xla/service/gpu/all_reduce_blueconnect.cc +++ b/xla/service/gpu/all_reduce_blueconnect.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,21 +16,31 @@ limitations under the License. #include "xla/service/gpu/all_reduce_blueconnect.h" #include +#include +#include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -56,7 +66,7 @@ struct DecomposedReplicaGroups { std::vector new_all_reduce_groups; }; -StatusOr> TryDecomposeReplicaGroup( +absl::StatusOr> TryDecomposeReplicaGroup( const ReplicaGroup& replica_group, const DeviceAssignment& device_assignment, size_t num_devices_per_host) { int group_size = replica_group.replica_ids_size(); @@ -107,8 +117,9 @@ StatusOr> TryDecomposeReplicaGroup( std::move(new_all_reduce_groups)}}; } -StatusOr> TryDecomposeReplicaGroups( - const HloAllReduceInstruction& all_reduce, size_t num_devices_per_host) { +absl::StatusOr> +TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce, + size_t num_devices_per_host) { const DeviceAssignment& device_assignment = all_reduce.GetModule()->config().static_device_assignment(); @@ -183,8 +194,8 @@ StatusOr> TryDecomposeReplicaGroups( // // When applied repeatedly, this transformation will reproduce the same pattern // as described in the BlueConnect paper. -StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, - size_t num_devices_per_host) { +absl::StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, + size_t num_devices_per_host) { TF_RET_CHECK(all_reduce); TF_RET_CHECK(!all_reduce->has_sharding()); @@ -251,9 +262,13 @@ StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, outputs[i] = computation.AddInstruction(HloInstruction::CreateBitcast( all_reduce->operand(i)->shape(), outputs[i])); } + HloInstruction* replacement = MaybeMakeTuple(outputs); TF_RETURN_IF_ERROR( - computation.ReplaceInstruction(all_reduce, MaybeMakeTuple(outputs))); + all_reduce->CopyAllControlDepsTo(reduce_scatter, replacement)); + + TF_RETURN_IF_ERROR(all_reduce->DropAllControlDeps()); + TF_RETURN_IF_ERROR(computation.ReplaceInstruction(all_reduce, replacement)); // Try to apply decomposition recursively. TF_RETURN_IF_ERROR( @@ -265,7 +280,7 @@ StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, } // namespace -StatusOr AllReduceBlueConnect::Run( +absl::StatusOr AllReduceBlueConnect::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running AllReduceBlueConnect"; diff --git a/xla/service/gpu/all_reduce_blueconnect.h b/xla/service/gpu/all_reduce_blueconnect.h index 72faa7b5b7e13..8633c77b0eba4 100644 --- a/xla/service/gpu/all_reduce_blueconnect.h +++ b/xla/service/gpu/all_reduce_blueconnect.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_ #define XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { @@ -39,7 +43,7 @@ class AllReduceBlueConnect : public HloModulePass { absl::string_view name() const override { return "all-reduce-blueconnect"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/all_reduce_blueconnect_test.cc b/xla/service/gpu/all_reduce_blueconnect_test.cc index 009636ab28738..b04fa92733747 100644 --- a/xla/service/gpu/all_reduce_blueconnect_test.cc +++ b/xla/service/gpu/all_reduce_blueconnect_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,26 @@ limitations under the License. #include "xla/service/gpu/all_reduce_blueconnect.h" +#include +#include #include +#include +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/computation_placer.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -259,5 +268,68 @@ ENTRY %comp { EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false)); } +TEST_F(AllReduceBlueConnectTest, ControlDeps) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %comp { + p0 = f32[4,4] parameter(0) + p1 = f32[4,4] parameter(1) + add = f32[4,4] add(p0, p1) + crs = f32[4,4] all-reduce(p0), to_apply=add, control-predecessors={add} + ROOT add1 = f32[4,4] add(crs, add), control-predecessors={crs} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + SetModuleConfig(*module, /*replica_count=*/8); + + // Remember all-reduce's control succ and preds. + const HloInstruction* ar = + module->entry_computation()->root_instruction()->operand(0); + auto expected_preds = ar->control_predecessors(); + auto expected_succs = ar->control_successors(); + + AllReduceBlueConnect pass(/*num_devices_per_host=*/4); + EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(true)); + + // clang-format off + std::vector> scatter_gather_groups = { + {0, 1, 2, 3}, {4, 5, 6, 7}}; + std::vector> new_all_reduce_groups = { + {0, 4}, {1, 5}, {2, 6}, {3, 7}}; + // clang-format on + + const HloInstruction *matched_rs, *matched_bitcast; + auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16}); + auto reduce_scatter = m::ReduceScatter(&matched_rs, bitcast) + .WithShape(F32, {4}) + .WithReplicaGroups(scatter_gather_groups); + auto all_reduce = m::AllReduce(reduce_scatter) + .WithShape(F32, {4}) + .WithReplicaGroups(new_all_reduce_groups); + auto all_gather = m::AllGather(all_reduce) + .WithShape(F32, {16}) + .WithReplicaGroups(scatter_gather_groups); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::Add())); + + EXPECT_THAT( + root->operand(0), + GmockMatch( + m::Bitcast(&matched_bitcast, all_gather).WithShape(F32, {4, 4}))); + + // Verify that control dependencies are transferred correctly. + EXPECT_THAT(matched_rs, GmockMatch(m::Op().WithControlDeps( + absl::MakeSpan(expected_preds), {}))); + EXPECT_THAT(matched_bitcast, GmockMatch(m::Op().WithControlDeps( + {}, absl::MakeSpan(expected_succs)))); +} + } // namespace } // namespace xla diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index b6280e62a86bf..f429e20f27d1a 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,71 +14,83 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/amdgpu_compiler.h" -#include -#include +#include +#include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/IR/Module.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" +#include "xla/service/convert_mover.h" +#include "xla/service/dot_dimension_merger.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/conv_algorithm_picker.h" +#include "xla/service/gpu/cublas_pad_for_gemms.h" +#include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/cusolver_rewriter.h" #include "xla/service/gpu/gemm_algorithm_picker.h" -#include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/gpu_sort_rewriter.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" -#include "xla/service/gpu/reduction_dimension_grouper.h" -#include "xla/service/gpu/reduction_layout_normalizer.h" #include "xla/service/gpu/target_constants.h" -#include "xla/service/gpu/tree_reduction_rewriter.h" #include "xla/service/gpu/triangular_solve_rewriter.h" #include "xla/service/hlo_constant_folding.h" -#include "xla/service/hlo_cse.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/reshape_mover.h" #include "xla/service/tuple_simplifier.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" -#include "tsl/platform/rocm_rocdl_path.h" +#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { namespace { -// Returns the directory containing ROCm-Device-Libs files. This function is -// called in AMDGPUCompiler's constructor, so can't return an error. But -// AMDGPUCompiler::Compile will return an error when the wanted rocdl file -// doesn't exist in the folder this function returns. -std::string GetROCDLDir(const HloModuleConfig& config) { - std::vector potential_rocdl_dirs; - const std::string datadir = config.debug_options().xla_gpu_cuda_data_dir(); - if (!datadir.empty()) { - potential_rocdl_dirs.push_back(datadir); +struct ConvBfloat16Support : public FloatSupport { + explicit ConvBfloat16Support(const se::RocmComputeCapability& rocm) + : FloatSupport(BF16), + // TODO: MIOpen does not support bf16 convolutions yet + is_conv_bf16_supported_(rocm.has_bf16_dtype_support()) {} + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; } - potential_rocdl_dirs.push_back(tsl::RocdlRoot()); - - // Tries all potential ROCDL directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const std::string& potential_rocdl_dir : potential_rocdl_dirs) { - if (tsl::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) { - VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir; - return potential_rocdl_dir; - } - VLOG(2) << "Unable to find potential ROCm-Device-Libs dir " - << potential_rocdl_dir; + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; } - // Last resort: maybe in the current folder. - return "."; -} + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + // Skip all HLOs other than convolutions. + return (hlo.opcode() != HloOpcode::kConvolution); + } + + private: + bool is_conv_bf16_supported_; +}; } // namespace -Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( +absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) { @@ -88,6 +100,12 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + // Convert unsupported bf16 convolutions to f32. + ConvBfloat16Support conv_bf16_support( + std::get(gpu_version)); + pipeline.AddPass(&conv_bf16_support); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -101,25 +119,67 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter // introduces reshapes and transposes that can be eliminated using // AlgebraicSimplifier We run algsimp to a fixed point. - AlgebraicSimplifierOptions options; + AlgebraicSimplifierOptions options = + GetAlgebraicSimplifierOptions(hlo_module->config()); options.set_enable_conv_operand_swap(false); options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(options); + // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and + // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover + // to a fixed point. Include algsimp because ReshapeMover relies on it. + [&, &pipeline = pipeline.AddPass>( + "reshape_mover_after_conv_canonicalization")] { + ReshapeMoverOptions reshape_mover_options; + reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true; + pipeline.AddPass>(reshape_mover_options); + pipeline.AddPass(options); + }(); + + // The reshapes and transposes can possibly be eliminated using + // AlgebraicSimplifier. ConvertMover and ReshapeMover fight with each other. + // ConvertMover wants to move some converts down the graph, but ReshapeMover + // wants to move them up the graph. We run ConvertMover and algsimp to a fixed + // point. + [&, &pipeline = pipeline.AddPass>( + "simplify_after_conv_canonicalization")] { + pipeline.AddPass(); + pipeline.AddPass(options); + }(); + + // GpuConvRewriter, GpuConvPaddingLegalization and + // CudnnConvPadForTensorCores may add instructions which can be simplified + // by constant folding. pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - return OkStatus(); + return absl::OkStatus(); } -Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment( +absl::Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool) { + HloPassPipeline pre_pipeline("AMDGPU post-layout_assignment part 1"); + + auto rocm_compute_capability = std::get( + gpu_target_config.device_description.gpu_compute_capability()); + + pre_pipeline.AddPass(); + + for (const auto& req : HipblasPaddingRequirements) { + pre_pipeline.AddPass(rocm_compute_capability, + req.data_type, req.multiple_of); + } + // Padding a gemm operand that's a constant results in pad(constant). Run + // constant-folding to simplify this into a new constant. + pre_pipeline.AddPass(); + TF_RETURN_IF_ERROR(pre_pipeline.Run(hlo_module).status()); + TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment( hlo_module, stream_exec, options, gpu_target_config, thread_pool)); - HloPassPipeline post_pipeline("AMDGPU post-layout_assignment"); + HloPassPipeline post_pipeline("AMDGPU post-layout_assignment part 2"); // Transform TriangularSolve ops into custom-calls, so we can add temp // memory. @@ -127,7 +187,7 @@ Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment( TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); - return OkStatus(); + return absl::OkStatus(); } // Linearize collective schedule under if online autotuning of convolutions is @@ -148,29 +208,35 @@ bool AMDGPUCompiler::RequiresCollectiveScheduleLinearizer( return false; } -Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( +absl::Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) { if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) { pipeline->AddPass(autotune_config); } pipeline->AddPass(autotune_config); - return OkStatus(); + return absl::OkStatus(); +} + +absl::Status AMDGPUCompiler::AddCustomKernelReplacementPasses( + HloPassPipeline* pipeline, const DebugOptions& debug_options) { + if (debug_options.xla_gpu_enable_cub_radix_sort()) { + pipeline->AddPass(); + } + return absl::OkStatus(); } AMDGPUCompiler::AMDGPUCompiler() : GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::TargetTriple(), amdgpu::DataLayout()) {} -StatusOr AMDGPUCompiler::CompileTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, bool relocatable, - const HloModule* debug_module, const CompileOptions& options) { - if (rocdl_dir_.empty()) { - // Compute rocdl_dir_ just once and cache it in this member. - rocdl_dir_ = GetROCDLDir(module_config); - } - +absl::StatusOr +AMDGPUCompiler::CompileTargetBinary(const HloModuleConfig& module_config, + llvm::Module* llvm_module, + se::GpuComputeCapability gpu_version, + bool relocatable, + const HloModule* debug_module, + const CompileOptions& options) { if (relocatable) { return Unimplemented("relocatable target binary is not implemented"); } @@ -184,7 +250,7 @@ StatusOr AMDGPUCompiler::CompileTargetBinary( !options.is_autotuning_compilation); TF_ASSIGN_OR_RETURN( hsaco, amdgpu::CompileToHsaco(llvm_module, gpu_version, - module_config.debug_options(), rocdl_dir_, + module_config.debug_options(), module_config.compilation_cache_key())); } diff --git a/xla/service/gpu/amdgpu_compiler.h b/xla/service/gpu/amdgpu_compiler.h index d070c5d9cd840..9f4ad9f656256 100644 --- a/xla/service/gpu/amdgpu_compiler.h +++ b/xla/service/gpu/amdgpu_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_AMDGPU_COMPILER_H_ #define XLA_SERVICE_GPU_AMDGPU_COMPILER_H_ -#include -#include -#include - +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/statusor.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/xla.pb.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { @@ -33,12 +39,12 @@ class AMDGPUCompiler : public GpuCompiler { public: AMDGPUCompiler(); - Status OptimizeHloConvolutionCanonicalization( + absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) override; - Status OptimizeHloPostLayoutAssignment( + absl::Status OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool) override; @@ -46,20 +52,20 @@ class AMDGPUCompiler : public GpuCompiler { bool RequiresCollectiveScheduleLinearizer( const HloModule* module, se::StreamExecutor* stream_exec) override; - Status AddConvAndGemmAutotuningPasses( + absl::Status AddConvAndGemmAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) override; - StatusOr CompileTargetBinary( + absl::Status AddCustomKernelReplacementPasses( + HloPassPipeline* pipeline, const DebugOptions& debug_options) override; + + absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override; private: - // The parent directory of ROCm-Device-Libs IR libraries. - std::string rocdl_dir_; - AMDGPUCompiler(const AMDGPUCompiler&) = delete; AMDGPUCompiler& operator=(const AMDGPUCompiler&) = delete; }; diff --git a/xla/service/gpu/amdgpu_compiler_registration.cc b/xla/service/gpu/amdgpu_compiler_registration.cc index 878c5cfe22e86..4d016bc650c54 100644 --- a/xla/service/gpu/amdgpu_compiler_registration.cc +++ b/xla/service/gpu/amdgpu_compiler_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "xla/service/compiler.h" #include "xla/service/gpu/amdgpu_compiler.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" diff --git a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc index d125ea202631c..eab4b0d48e5db 100644 --- a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc +++ b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,16 @@ limitations under the License. #include -#include "xla/hlo/utils/hlo_matchers.h" +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/logging.h" namespace xla { namespace gpu { @@ -26,6 +32,8 @@ namespace { namespace m = ::xla::match; +using ::testing::Conditional; + class AutoShardingTest : public HloTestBase { protected: const char* const dot_hlo_string_ = R"( @@ -56,16 +64,41 @@ ENTRY matmul { }; TEST_F(AutoShardingTest, MatMulWithAutosharding) { - auto compiled_module = CompileMatMul(true, 4); - auto* instruction = FindInstruction(compiled_module.get(), "param"); - VLOG(2) << instruction->ToString(); - EXPECT_THAT(instruction, - GmockMatch(m::Op().WithSharding("{devices=[4,1]0,1,2,3}"))); + std::unique_ptr compiled_module = CompileMatMul(true, 4); + const HloInstruction* parameter1 = + compiled_module->entry_computation()->parameter_instruction(0); + const HloInstruction* parameter2 = + compiled_module->entry_computation()->parameter_instruction(1); + bool is_parameter1_replicated = ShapeUtil::Equal( + parameter1->shape(), ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64})); + bool is_parameter2_replicated = ShapeUtil::Equal( + parameter2->shape(), ShapeUtil::MakeShape(PrimitiveType::F32, {64, 128})); + + // Check that at least one of the parameters is sharded, thereby telling us + // that the dot is as well. + VLOG(2) << parameter1->ToString(); + EXPECT_THAT( + parameter1, + Conditional( + is_parameter2_replicated, + AnyOf(GmockMatch(m::Op().WithShape(PrimitiveType::F32, {8, 64})), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {32, 16}))), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {32, 64})))); + + VLOG(2) << parameter2->ToString(); + EXPECT_THAT( + parameter2, + Conditional( + is_parameter1_replicated, + AnyOf(GmockMatch(m::Op().WithShape(PrimitiveType::F32, {16, 128})), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {64, 32}))), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {64, 128})))); } TEST_F(AutoShardingTest, MatMulWithoutAutosharding) { auto compiled_module = CompileMatMul(false, 4); - auto* instruction = FindInstruction(compiled_module.get(), "param"); + auto* instruction = + compiled_module->entry_computation()->parameter_instruction(0); VLOG(2) << instruction->ToString(); EXPECT_THAT(instruction, GmockMatch(m::Op().WithSharding("{replicated}"))); } diff --git a/xla/service/gpu/autotuner_compile_util.cc b/xla/service/gpu/autotuner_compile_util.cc index 4350dc235423a..e88f893a48036 100644 --- a/xla/service/gpu/autotuner_compile_util.cc +++ b/xla/service/gpu/autotuner_compile_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -34,12 +35,10 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/hlo_module_config.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/service_executable_run_options.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -87,15 +86,13 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, opts_.set_xla_gpu_dump_llvmir(false); // Avoid using another thread pool. opts_.set_xla_gpu_force_compilation_parallelism(1); + opts_.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); // Avoid using GPU graphs as we don't want to measure graph construction time. opts_.clear_xla_gpu_enable_command_buffer(); - // Disable experimental XLA:GPU runtime. - opts_.set_xla_gpu_enable_gpu2_runtime(false); opts_.set_xla_embed_ir_in_executable(false); - opts_.set_xla_gpu_enable_persistent_temp_buffers(false); } -StatusOr> +absl::StatusOr> AutotunerCompileUtil::ProfileExecutable( Executable* executable, se::Stream* stream, absl::Span input_buffers, @@ -105,7 +102,7 @@ AutotunerCompileUtil::ProfileExecutable( ExecutionInputsFromBuffers(input_buffers, input_shapes); // Warmup: in and out buffers are reused while probing different configs, // so GPU caches should be in some comparable states during measurements. - StatusOr execution_output = + absl::StatusOr execution_output = Execute(*executable, std::move(execution_inputs)); if (!execution_output.ok()) { // Treat register allocation error gracefully. If the compilation happens @@ -122,19 +119,21 @@ AutotunerCompileUtil::ProfileExecutable( } std::vector execution_inputs = ExecutionInputsFromBuffers(input_buffers, input_shapes); - TF_ASSIGN_OR_RETURN(auto timer, - se::gpu::GpuTimer::Create(se::gpu::AsGpuStream(stream))); - TF_ASSIGN_OR_RETURN(ExecutionOutput execution_output, - Execute(*executable, std::move(execution_inputs))); - TF_ASSIGN_OR_RETURN(absl::Duration timer_duration, - timer.GetElapsedDuration()); + ExecutionProfile profile; + // Flag that a warm-up run was executed so that GpuTimer can use the, more + // accurate, delay kernel implementation. + profile.set_warmup_run_executed(true); + TF_ASSIGN_OR_RETURN( + ExecutionOutput execution_output, + Execute(*executable, std::move(execution_inputs), &profile)); return std::make_optional( - timer_duration, execution_output.Commit().ConsumeResult()); + absl::Nanoseconds(profile.compute_time_ns()), + execution_output.Commit().ConsumeResult()); } -StatusOr> AutotunerCompileUtil::Compile( +absl::StatusOr> AutotunerCompileUtil::Compile( GenerateModuleFn extractor) { - StatusOr> new_hlo_module = extractor(opts_); + absl::StatusOr> new_hlo_module = extractor(opts_); if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) { // Incompatible value of split-k is an example of an expected failure. return std::unique_ptr(); @@ -142,7 +141,7 @@ StatusOr> AutotunerCompileUtil::Compile( return new_hlo_module.status(); } - StatusOr> out = compiler_->RunBackend( + absl::StatusOr> out = compiler_->RunBackend( std::move(*new_hlo_module), &stream_executor_, Compiler::CompileOptions{&allocator_, /*thread_pool=*/nullptr, /*layout_canonicalization_callback=*/{}, @@ -156,12 +155,12 @@ StatusOr> AutotunerCompileUtil::Compile( return out; } -StatusOr> AutotunerCompileUtil::ExtractModule( +absl::StatusOr> AutotunerCompileUtil::ExtractModule( GenerateModuleFn extractor) { return extractor(opts_); } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> AutotunerCompileUtil::Create(const AutotuneConfig& config, const DebugOptions& opts) { if (config.IsDeviceless()) { @@ -176,8 +175,9 @@ AutotunerCompileUtil::Create(const AutotuneConfig& config, *allocator, opts); } -StatusOr AutotunerCompileUtil::Execute( - Executable& executable, std::vector arguments) { +absl::StatusOr AutotunerCompileUtil::Execute( + Executable& executable, std::vector arguments, + ExecutionProfile* profile) { // Require exclusive GPU lock to prevent other runs during autotuning. GpuExecutableRunOptions gpu_opts; gpu_opts.set_requires_exclusive_lock_on_gpu(); @@ -187,6 +187,7 @@ StatusOr AutotunerCompileUtil::Execute( run_options.set_stream(&stream_); run_options.set_allocator(&allocator_); run_options.set_gpu_executable_run_options(&gpu_opts); + run_options.set_execution_profile(profile); ServiceExecutableRunOptions service_run_options(run_options); TF_ASSIGN_OR_RETURN(ExecutionOutput output, executable.ExecuteAsyncOnStreamWrapper( diff --git a/xla/service/gpu/autotuner_compile_util.h b/xla/service/gpu/autotuner_compile_util.h index b5f5a763f4ed8..a5d018c568b93 100644 --- a/xla/service/gpu/autotuner_compile_util.h +++ b/xla/service/gpu/autotuner_compile_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,12 @@ limitations under the License. #include #include +#include #include #include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -29,8 +32,9 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuner_util.h" +#include "xla/service/shaped_buffer.h" #include "xla/shape.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -48,14 +52,14 @@ class AutotunerCompileUtil { // the debug options. In justified cases, it may override some of the provided // debug options. using GenerateModuleFn = - absl::AnyInvocable>( + absl::AnyInvocable>( const DebugOptions&)>; // Generates a compile util for a platform associated with the `stream`. // // Returns an empty optional if the AutotuneConfig is deviceless, as // autotuning is impossible in that case. - static StatusOr> Create( + static absl::StatusOr> Create( const AutotuneConfig& config, const DebugOptions& opts); struct ProfilingOutput { @@ -72,7 +76,7 @@ class AutotunerCompileUtil { // Runs the resulting executable with the given extractor, cached with // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad // `Status` otherwise. - StatusOr> ProfileExecutable( + absl::StatusOr> ProfileExecutable( Executable* executable, se::Stream* stream, absl::Span input_buffers, absl::Span input_shapes); @@ -83,13 +87,14 @@ class AutotunerCompileUtil { // - `nullptr` on *expected* failure // - `Executable` if everything goes fine. // - `Status` on *unexpected* failure. - StatusOr> Compile(GenerateModuleFn extractor); + absl::StatusOr> Compile( + GenerateModuleFn extractor); // Generic method to extract an HLO using the debug options of the // AutotunerCompileUtil. // // Typically we can use Compile directly. - StatusOr> ExtractModule( + absl::StatusOr> ExtractModule( GenerateModuleFn extractor); private: @@ -98,8 +103,9 @@ class AutotunerCompileUtil { se::DeviceMemoryAllocator& allocator, const DebugOptions& opts); - StatusOr Execute(Executable& executable, - std::vector arguments); + absl::StatusOr Execute(Executable& executable, + std::vector arguments, + ExecutionProfile* profile = nullptr); AutotuneConfig config_; Compiler* compiler_; diff --git a/xla/service/gpu/autotuner_util.cc b/xla/service/gpu/autotuner_util.cc index 1b9ee7297bee1..b424aed57f87e 100644 --- a/xla/service/gpu/autotuner_util.cc +++ b/xla/service/gpu/autotuner_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,27 +16,39 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include -#include +#include +#include #include -#include #include -#include -#include "absl/log/log.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/path.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" namespace xla { @@ -48,17 +60,8 @@ static absl::Mutex autotune_cache_mu(absl::kConstInit); static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = *new AutotuneCacheMap(); -/*static*/ Status AutotunerUtil::SerializeAutotuneResults( - AutotuneResults* results) { - absl::MutexLock lock(&autotune_cache_mu); - for (const auto& [k, result] : autotune_cache) { - auto& entry = *results->add_results(); - entry.set_device(std::string(k.GetModelStr())); - entry.set_hlo(std::string(k.GetHlo())); - *entry.mutable_result() = result; - } - - // Sort the results so that they're deterministic. +// Sort the results so that they're deterministic. +static void SortAutotuneResults(AutotuneResults* results) { std::sort(results->mutable_results()->pointer_begin(), results->mutable_results()->pointer_end(), [](const auto* a, const auto* b) { @@ -67,18 +70,52 @@ static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = std::make_pair(absl::string_view(b->device()), absl::string_view(b->hlo())); }); +} + +// Serialize `results` to string as a proto. +static absl::StatusOr AutotuneResultsToString( + const AutotuneResults& results, bool as_textproto) { + if (as_textproto) { + std::string textproto; + if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) { + return textproto; + } else { + return Internal("Failed to serialize autotune results."); + } + } + return results.SerializeAsString(); +} - return OkStatus(); +// Serialize a single entry to `results`. +static void SerializeAutotuneEntry(AutotuneResults* results, + const AutotuneCacheKey& k, + const AutotuneResult* res) { + auto& entry = *results->add_results(); + entry.set_device(std::string(k.GetModelStr())); + entry.set_hlo(std::string(k.GetHlo())); + *entry.mutable_result() = *res; } -/*static*/ Status AutotunerUtil::LoadAutotuneResults( +/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResults( + AutotuneResults* results) { + absl::MutexLock lock(&autotune_cache_mu); + for (const auto& [k, result] : autotune_cache) { + SerializeAutotuneEntry(results, k, &result); + } + + SortAutotuneResults(results); + + return absl::OkStatus(); +} + +/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults( const AutotuneResults& results) { absl::MutexLock lock(&autotune_cache_mu); for (const auto& result : results.results()) { autotune_cache[AutotuneCacheKey(result.device(), result.hlo())] = result.result(); } - return OkStatus(); + return absl::OkStatus(); } /*static*/ void AutotunerUtil::ClearAutotuneResults() { @@ -86,7 +123,7 @@ static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) = autotune_cache.clear(); } -/* static*/ StatusOr AutotunerUtil::CreateBuffer( +/* static*/ absl::StatusOr AutotunerUtil::CreateBuffer( se::RedzoneAllocator& allocator, const Shape& shape, const AutotuneConfig& config, int64_t& rng_state) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, @@ -125,9 +162,20 @@ static AutotuneResult* TryFindInCache(const AutotuneCacheKey& key) { absl::MutexLock lock(&autotune_cache_mu); auto it = autotune_cache.find(key); if (it != autotune_cache.end()) { - VLOG(1) << "Autotune cache hit"; + // Cache hit. + if (VLOG_IS_ON(1)) { + LOG(INFO) << "Autotune cache hit"; + } else if (VLOG_IS_ON(2)) { + LOG(INFO) << "Autotune cache hit: key = " << key.ToString(); + } return &it->second; } + + if (VLOG_IS_ON(1)) { + LOG(INFO) << "Autotune cache miss"; + } else if (VLOG_IS_ON(2)) { + LOG(INFO) << "Autotune cache miss: key = " << key.ToString(); + } return nullptr; } @@ -147,14 +195,22 @@ static AutotuneResult* TryFindInCache(const AutotuneCacheKey& key) { return inserted; } -/*static*/ StatusOr AutotunerUtil::Autotune( +/*static*/ absl::StatusOr AutotunerUtil::Autotune( const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn) { - AutotuneCacheKey key = GetKey(instr, config); + const AutotuneCacheKey key = GetKey(instr, config); if (AutotuneResult* res = TryFindInCache(key)) { return *res; } + // Cache miss. + if (config.should_require_complete_aot_autotune_results()) { + return NotFound( + "Complete XLA AOT autotuning results are required, but no AOT result " + "was found for key: %s", + key.ToString()); + } + TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn()); absl::MutexLock lock(&autotune_cache_mu); @@ -166,19 +222,20 @@ namespace { // Bump this version whenever you change the structure of the results. // LINT.IfChange(version) -constexpr int kVersion = 2; +constexpr int kVersion = 3; // LINT.ThenChange() bool IsTextProtoPath(absl::string_view file_path) { return absl::EndsWith(file_path, ".txt") || absl::EndsWith(file_path, ".textproto") || - absl::EndsWith(file_path, ".prototxt"); + absl::EndsWith(file_path, ".prototxt") || + absl::EndsWith(file_path, ".pbtxt"); } } // anonymous namespace -/*static*/ Status AutotunerUtil::LoadAutotuneResults(absl::string_view data, - bool as_textproto) { +/*static*/ absl::Status AutotunerUtil::LoadAutotuneResults( + absl::string_view data, bool as_textproto) { AutotuneResults results; // The cast here is necessary for MacOS builds. bool parse_success = @@ -196,27 +253,18 @@ bool IsTextProtoPath(absl::string_view file_path) { } TF_RETURN_IF_ERROR(LoadAutotuneResults(results)); - return OkStatus(); + return absl::OkStatus(); } -/*static*/ StatusOr AutotunerUtil::SerializeAutotuneResults( +/*static*/ absl::StatusOr AutotunerUtil::SerializeAutotuneResults( bool as_textproto) { AutotuneResults results; results.set_version(kVersion); TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results)); - if (as_textproto) { - std::string textproto; - if (tsl::protobuf::TextFormat::PrintToString(results, &textproto)) { - return textproto; - } else { - return Status(absl::StatusCode::kInternal, - "Failed to serialize autotune results."); - } - } - return results.SerializeAsString(); + return AutotuneResultsToString(results, as_textproto); } -/*static*/ Status AutotunerUtil::SerializeAutotuneResultsToFile( +/*static*/ absl::Status AutotunerUtil::SerializeAutotuneResultsToFile( absl::string_view file_path) { TF_RET_CHECK(!file_path.empty()); @@ -231,10 +279,10 @@ bool IsTextProtoPath(absl::string_view file_path) { autotune_results_str)); LOG(INFO) << "Autotune results serialized to file: " << resolved_path; - return OkStatus(); + return absl::OkStatus(); } -/*static*/ Status AutotunerUtil::LoadAutotuneResultsFromFile( +/*static*/ absl::Status AutotunerUtil::LoadAutotuneResultsFromFile( absl::string_view file_path) { TF_RET_CHECK(!file_path.empty()); @@ -256,48 +304,13 @@ bool IsTextProtoPath(absl::string_view file_path) { LOG(INFO) << "Autotune results loaded from file: " << resolved_path; - return OkStatus(); + return absl::OkStatus(); } -/*static*/ std::unique_ptr -AutotunerUtil::ExtractInstructionIntoNewModule(const HloInstruction& hlo) { - auto new_hlo_module = std::make_unique( - "extracted", HloModuleConfig{}, - std::make_unique(hlo.GetModule()->comp_envs())); - int parameter_number = 0; - HloComputation::Builder builder("entry_computation"); - HloCloneContext clone_context(new_hlo_module.get()); - std::vector new_operands; - for (const HloInstruction* operand : hlo.operands()) { - std::unique_ptr new_parameter = - HloInstruction::CreateParameter(parameter_number, operand->shape(), - operand->name()); - ++parameter_number; - new_operands.push_back(builder.AddInstruction(std::move(new_parameter))); - } - std::unique_ptr new_instruction = - hlo.CloneWithNewOperands(hlo.shape(), new_operands, &clone_context); - builder.AddInstruction(std::move(new_instruction)); - new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); - return new_hlo_module; -} - -/*static*/ std::unique_ptr -AutotunerUtil::ExtractComputationIntoNewModule( - const HloComputation& computation) { - auto new_hlo_module = - std::make_unique("extracted", HloModuleConfig{}, - std::make_unique( - computation.parent()->comp_envs())); - HloCloneContext clone_context(new_hlo_module.get()); - new_hlo_module->AddEntryComputationWithLayouts( - computation.CloneInContext(clone_context)); - return new_hlo_module; -} - -/*static*/ StatusOr AutotunerUtil::CreateRedzoneAllocator( - const AutotuneConfig& config, const DebugOptions& opts, - se::Stream* force_stream) { +/*static*/ absl::StatusOr +AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, + const DebugOptions& opts, + se::Stream* force_stream) { se::Stream* stream = force_stream; if (stream == nullptr) { TF_ASSIGN_OR_RETURN(stream, config.GetStream()); @@ -310,5 +323,24 @@ AutotunerUtil::ExtractComputationIntoNewModule( : 0); } +/*static*/ absl::StatusOr +AutotunerUtil::SerializeAutotuneResultsForModule( + const HloModule& module, const AutotuneConfig& autotune_config, + bool as_textproto) { + AutotuneResults results; + results.set_version(kVersion); + + for (const HloInstruction* instr : + module.entry_computation()->instructions()) { + AutotuneCacheKey k(autotune_config.GetModelStr(), *instr); + if (const AutotuneResult* res = TryFindInCache(k)) { + SerializeAutotuneEntry(&results, k, res); + } + } + + SortAutotuneResults(&results); + return AutotuneResultsToString(results, as_textproto); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/autotuner_util.h b/xla/service/gpu/autotuner_util.h index d25f881274389..6569615fe07b4 100644 --- a/xla/service/gpu/autotuner_util.h +++ b/xla/service/gpu/autotuner_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,22 +16,28 @@ limitations under the License. #define XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ #include +#include #include -#include #include -#include #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla.pb.h" namespace xla { @@ -54,7 +60,8 @@ struct DevicelessConfig { // A field to determine the architecture of the device. We only pick an // algorithm for non-Ampere architectures. - se::CudaComputeCapability cuda_compute_capability{0, 0}; + se::GpuComputeCapability gpu_compute_capability{ + se::CudaComputeCapability{0, 0}}; }; class AutotuneCacheKey { @@ -97,6 +104,9 @@ class AutotuneConfig { bool should_crash_on_check_failure() const { return should_crash_on_check_failure_; } + bool should_require_complete_aot_autotune_results() const { + return require_complete_aot_autotune_results_; + } AutotuneConfig(const std::variant& config, const DebugOptions& debug_options) @@ -105,7 +115,9 @@ class AutotuneConfig { should_crash_on_check_failure_( debug_options.xla_gpu_crash_on_verification_failures()), exhaustive_tiling_search_( - debug_options.xla_gpu_exhaustive_tiling_search()) {} + debug_options.xla_gpu_exhaustive_tiling_search()), + require_complete_aot_autotune_results_( + debug_options.xla_gpu_require_complete_aot_autotune_results()) {} absl::string_view GetModelStr() const { if (auto deviceless_config = std::get_if(&config_)) { @@ -127,16 +139,16 @@ class AutotuneConfig { return cf.allocator ? cf.allocator : GetExecutor()->GetAllocator(); } - StatusOr GetStream() const { + absl::StatusOr GetStream() const { CHECK(std::holds_alternative(config_)); return GetAllocator()->GetStream(GetExecutor()->device_ordinal()); } - se::CudaComputeCapability GetCudaComputeCapability() const { + const se::GpuComputeCapability& GetGpuComputeCapability() const { if (auto c = std::get_if(&config_)) { - return c->stream_exec->GetDeviceDescription().cuda_compute_capability(); + return c->stream_exec->GetDeviceDescription().gpu_compute_capability(); } - return std::get(config_).cuda_compute_capability; + return std::get(config_).gpu_compute_capability; } bool IsDeviceless() const { @@ -150,18 +162,19 @@ class AutotuneConfig { int32_t autotune_level_; bool should_crash_on_check_failure_; bool exhaustive_tiling_search_; + bool require_complete_aot_autotune_results_; }; -using AutotuneNoCacheFn = std::function()>; +using AutotuneNoCacheFn = std::function()>; struct AutotunerUtil { // Create a buffer for a given operation using redzone checker, initialize // based on a given rng state. - static StatusOr CreateBuffer( + static absl::StatusOr CreateBuffer( se::RedzoneAllocator& allocator, const Shape& shape, const AutotuneConfig& config, int64_t& rng_state); - static StatusOr Autotune( + static absl::StatusOr Autotune( const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn); @@ -185,7 +198,7 @@ struct AutotunerUtil { // Creates a RedzoneAllocator from a given config. If `force_stream` is // provided, than it is used for checking redzones. - static StatusOr CreateRedzoneAllocator( + static absl::StatusOr CreateRedzoneAllocator( const AutotuneConfig& config, const DebugOptions& opts, se::Stream* force_stream = nullptr); @@ -228,39 +241,40 @@ struct AutotunerUtil { // dots/convs it wants to run can also change. For example, XLA might change // the conv padding heuristics it uses, and we don't want that to mean that // all users of ahead-of-time autotuning are broken. - static StatusOr SerializeAutotuneResults( + static absl::StatusOr SerializeAutotuneResults( + bool as_textproto = false); + + // As above, but only performs serialization for instructions found in the + // module. + // + // Only serializes autotuning results for instructions found in the module: + // while this is more expensive than serializing all cache, this avoids + // quadratic blow-up when serializing cache for a large number of modules. + static absl::StatusOr SerializeAutotuneResultsForModule( + const HloModule& module, const AutotuneConfig& autotune_config, bool as_textproto = false); - static Status SerializeAutotuneResults(AutotuneResults* results); - static Status LoadAutotuneResults(absl::string_view data, - bool as_textproto = false); + static absl::Status SerializeAutotuneResults(AutotuneResults* results); + static absl::Status LoadAutotuneResults(absl::string_view data, + bool as_textproto = false); - static Status LoadAutotuneResults(const AutotuneResults& results); + static absl::Status LoadAutotuneResults(const AutotuneResults& results); // Serializes autotune results into a file. // // If `file_path` ends with ".txt" or ".textproto", then the textproto format // is used, otherwise the binary protobuf format. - static Status SerializeAutotuneResultsToFile(absl::string_view file_path); + static absl::Status SerializeAutotuneResultsToFile( + absl::string_view file_path); // Loads autotune results from a file. // // If `file_path` ends with ".txt" or ".textproto", then the file is // considered to be in the textproto format, otherwise the binary protobuf // format. - static Status LoadAutotuneResultsFromFile(absl::string_view file_path); + static absl::Status LoadAutotuneResultsFromFile(absl::string_view file_path); static void ClearAutotuneResults(); - - // Extracts an HLO instruction into a new HLO module replacing its operands - // with parameter instructions. - static std::unique_ptr ExtractInstructionIntoNewModule( - const HloInstruction& hlo); - - // Extracts an HLO computation into a new HLO module, using its clone as the - // root computation. - static std::unique_ptr ExtractComputationIntoNewModule( - const HloComputation& computation); }; } // namespace gpu diff --git a/xla/service/gpu/autotuner_util_test.cc b/xla/service/gpu/autotuner_util_test.cc index 09887c01be7be..28ec27c64e8da 100644 --- a/xla/service/gpu/autotuner_util_test.cc +++ b/xla/service/gpu/autotuner_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,23 +17,36 @@ limitations under the License. #include #include +#include #include #include -#include "absl/base/log_severity.h" -#include "absl/log/scoped_mock_log.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/autotune_results.pb.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" namespace xla { namespace gpu { namespace { +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Not; using ::testing::TempDir; +using ::tsl::testing::StatusIs; class AutotunerUtilTest : public HloTestBase { protected: @@ -62,6 +75,13 @@ ENTRY e { EXPECT_THAT(str, Not(IsEmpty())); return str; } + + std::unique_ptr NewStreamExecutor() { + stream_executor::Platform* platform = + stream_executor::PlatformManager::PlatformWithName("Host").value(); + stream_executor::StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).value(); + } }; TEST_F(AutotunerUtilTest, SerializeAutotuneResultsToFile_TextProto1) { @@ -123,6 +143,63 @@ TEST_F(AutotunerUtilTest, LoadAutotuneResultsFromFile_Protobuf) { TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(kFilePath)); } +// Test that when complete AOT autotuning is required, and there is cache miss, +// a `NotFound` error will be raised. +TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) { + std::string kFilePath = GetUniqueTempFilePath(".txt"); + auto hlo_module = GetOptimizedModule(kHloText); + TF_EXPECT_OK(hlo_module.status()); + std::vector computations = + (*hlo_module) + ->MakeNonfusionComputations(absl::flat_hash_set()); + EXPECT_THAT(computations, Not(IsEmpty())); + const HloInstruction* instruction = *computations[0]->instructions().begin(); + std::unique_ptr executor = + NewStreamExecutor(); + auto options = DebugOptions(); + options.set_xla_gpu_require_complete_aot_autotune_results(true); + AutotuneConfig config(DeviceConfig{executor.get()}, options); + EXPECT_THAT( + AutotunerUtil::Autotune(instruction, config, + [&] { return AutotuneResult(); }), + StatusIs( + absl::StatusCode::kNotFound, + HasSubstr("Complete XLA AOT autotuning results are required, but " + "no AOT result was found for key: computations = + (*hlo_module) + ->MakeNonfusionComputations(absl::flat_hash_set()); + EXPECT_THAT(computations, Not(IsEmpty())); + const HloInstruction* instruction = *computations[0]->instructions().begin(); + std::unique_ptr executor = + NewStreamExecutor(); + + { + // By default, JIT autotuning is OK. + AutotuneConfig config(DeviceConfig{executor.get()}, DebugOptions()); + TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] { + return AutotuneResult(); + }).status()); + } + + // Now require complete AOT autotuning results. + auto options = DebugOptions(); + options.set_xla_gpu_require_complete_aot_autotune_results(true); + + AutotuneConfig config(DeviceConfig{executor.get()}, options); + // Even though JIT autotuning is disabled, there is no cache miss when running + // autotuning for the same entry, so no error should be raised either. + TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] { + return AutotuneResult(); + }).status()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index 46f51ebe5857b..4e80ab7489735 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -11,6 +11,10 @@ import "tsl/protobuf/dnn.proto"; // These are metadata that the GPU backend attaches to HloInstructions and later // uses during e.g. codegen. // +// GpuBackendConfig serves as a parent config for all backend configs so +// configs won't overwrite each other. Any new backend config proto +// should be added to and used in GpuBackendConfig. +// // Remember that proto3 doesn't give clients a way to tell the difference // between a field not being present and a field having the default value. // Choose your defaults carefully. @@ -121,8 +125,22 @@ message CollectiveBackendConfig { bool no_parallel_custom_call = 2; } +// Backend config for cost model estimates. message ReificationCost { - double end_to_end_cycles = 1; // Total execution time of the reified op. + // Total execution time of the reified op. + double end_to_end_cycles = 1; + + // Estimated overall kernel execution in microseconds. + // + // GPU Cost Model estimates compute and memory access time separately. Exec + // time is a combined metric of the two. + double exec_time_us = 2; + + // Estimate for compute time in microseconds. + double compute_time_us = 3; + + // Estimate for memory access (read+write) time in microseconds. + double memory_access_time_us = 4; } // Backend config for a custom fusion (pre-compiled device kernel implementing a @@ -131,6 +149,10 @@ message CustomFusionConfig { string name = 1; } +message CuDnnFusionConfig { + int64 plan_id = 1; +} + message FusionBackendConfig { // kLoop, kInput, or kOutput (from HloInstruction::FusionKind), or your own // custom string. @@ -152,6 +174,8 @@ message FusionBackendConfig { // Cost model prediction. ReificationCost reification_cost = 3; + + CuDnnFusionConfig cudnn_fusion_config = 5; } // Backed config for norm executed by cuDNN. @@ -161,6 +185,14 @@ message CudnnNormBackendConfig { // Opaque algorithm number. stream_executor.dnn.AlgorithmProto algorithm = 2; + + // Norm type. + enum Kind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; + } + Kind kind = 3; } // Backend config for a fused Multi-Headed Attention (fMHA) that runs through @@ -197,3 +229,70 @@ message CudnnfMHABackendConfig { // Is causal mask bool is_causal_mask = 21; } + +message FlashAttnBackendConfig { + float dropout_rate = 2; + + float scale = 3; + + bool is_causal = 4; + + bool deterministic = 5; + + // Whether there is 'alibi_slopes' in the input arguments + bool has_alibi_slopes = 6; + + // Max sequence length of query in variable-length flash-attention + optional int32 max_seqlen_q = 7; + + // Max sequence length of key in variable-length flash-attention + optional int32 max_seqlen_k = 8; + + // Whether to return softmax in forward flash-attention + optional bool return_softmax = 9; +} + +// Generic backend config for XLA:GPU +message GpuBackendConfig { + // Specifies which operation queue the current instruction will run on. + // A backend may have multiple operation queues to run instructions + // concurrently, use this to signal the backend which queue to dispatch to. + // The backend should keep a mapping of + // operation_queue_id->actual_hardware_queue_id if runtime will create + // different IDs. + int64 operation_queue_id = 1; + + // Specifies which operation queues to await for data when running with + // multiple operation queues. + repeated int64 wait_on_operation_queues = 2; + + oneof backend_config { + CudnnConvBackendConfig cudnn_conv_backend_config = 3; + + GemmBackendConfig gemm_backend_config = 4; + + BitcastBackendConfig bitcast_backend_config = 5; + + CollectiveBackendConfig collective_backend_config = 6; + + FusionBackendConfig fusion_backend_config = 7; + + CudnnNormBackendConfig cudnn_norm_backend_config = 8; + + CudnnfMHABackendConfig cudnn_fmha_backend_config = 9; + + FlashAttnBackendConfig flash_attn_backend_config = 10; + } + + // This attribute instructs the latency-hiding scheduler to + // schedule this particular instruction to the earliest position. + // Note that setting this to true will make this instruction scheduled + // at the very beginning of the parent computation before + // every other nodes. + // An example use case would be deciding to schedule between collective + // or an async compute. LHS might put either one at the first place + // depending on the cost, but it'd be more beneficial if the collective + // is always scheduled first as it's not SM-heavy. + // In this case we can use this flag to enforce the ordering. + bool force_earliest_schedule = 11; +} diff --git a/xla/service/gpu/backend_configs_test.cc b/xla/service/gpu/backend_configs_test.cc index 99c86bd4e08d5..7a02490907055 100644 --- a/xla/service/gpu/backend_configs_test.cc +++ b/xla/service/gpu/backend_configs_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -44,13 +52,118 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) { std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2)); - const HloInstruction *ags = FindInstruction(module.get(), "agf32-start"); + const HloInstruction* ags = FindInstruction(module.get(), "agf32-start"); EXPECT_THAT(ags->has_backend_config(), IsFalse()); - auto collective_backend_config = - ags->backend_config(); - EXPECT_THAT(collective_backend_config.status(), IsOk()); - EXPECT_THAT(collective_backend_config->is_sync(), IsFalse()); - EXPECT_THAT(collective_backend_config->no_parallel_custom_call(), IsFalse()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + ags->backend_config()); + const auto& collective_backend_config = + gpu_config.collective_backend_config(); + EXPECT_THAT(collective_backend_config.is_sync(), IsFalse()); + EXPECT_THAT(collective_backend_config.no_parallel_custom_call(), IsFalse()); +} + +TEST_F(BackendConfigsTest, DefaultGpuBackendConfigParseOpQueue) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p0f32 = f32[4, 4] parameter(0) + p1f32 = f32[4, 4] parameter(1) + + ROOT addf32 = f32[4, 4] add(p0f32, p1f32), backend_config={"operation_queue_id":"2"} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + HloInstruction* add = module->entry_computation()->root_instruction(); + EXPECT_TRUE(add->has_backend_config()); + auto real_gpu_backend_config = add->backend_config(); + EXPECT_THAT(real_gpu_backend_config.status(), IsOk()); + EXPECT_EQ(real_gpu_backend_config->operation_queue_id(), 2); +} + +TEST_F(BackendConfigsTest, DefaultGpuBackendConfigParseWaitOnQueue) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p0f32 = f32[4, 4] parameter(0) + p1f32 = f32[4, 4] parameter(1) + + ROOT addf32 = f32[4, 4] add(p0f32, p1f32), backend_config={"wait_on_operation_queues":[0, 1]} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + HloInstruction* add = module->entry_computation()->root_instruction(); + EXPECT_TRUE(add->has_backend_config()); + auto real_gpu_backend_config = add->backend_config(); + EXPECT_THAT(real_gpu_backend_config.status(), IsOk()); + std::vector expected_ids = {0, 1}; + EXPECT_EQ(real_gpu_backend_config->wait_on_operation_queues().size(), + expected_ids.size()); + for (int64_t i = 0; i < expected_ids.size(); i++) { + EXPECT_EQ(expected_ids[i], + real_gpu_backend_config->wait_on_operation_queues()[i]); + } +} + +TEST_F(BackendConfigsTest, DefaultGpuBackendConfigSetOpQueue) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p0f32 = f32[4, 4] parameter(0) + p1f32 = f32[4, 4] parameter(1) + + ROOT addf32 = f32[4, 4] add(p0f32, p1f32) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + HloInstruction* add = module->entry_computation()->root_instruction(); + EXPECT_FALSE(add->has_backend_config()); + GpuBackendConfig gpu_backend_config; + gpu_backend_config.set_operation_queue_id(2); + EXPECT_THAT(add->set_backend_config(gpu_backend_config), IsOk()); + EXPECT_EQ(add->raw_backend_config_string(), + "{\"operation_queue_id\":\"2\",\"wait_on_operation_queues\":[]," + "\"force_earliest_schedule\":false}"); +} + +TEST_F(BackendConfigsTest, DefaultGpuBackendConfigSetWaitOnQueue) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p0f32 = f32[4, 4] parameter(0) + p1f32 = f32[4, 4] parameter(1) + + ROOT addf32 = f32[4, 4] add(p0f32, p1f32) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + HloInstruction* add = module->entry_computation()->root_instruction(); + EXPECT_FALSE(add->has_backend_config()); + GpuBackendConfig gpu_backend_config; + // Wait on queues {0, 1} + gpu_backend_config.mutable_wait_on_operation_queues()->Add(0); + gpu_backend_config.mutable_wait_on_operation_queues()->Add(1); + EXPECT_THAT(add->set_backend_config(gpu_backend_config), IsOk()); + EXPECT_EQ(add->raw_backend_config_string(), + "{\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[\"0\"," + "\"1\"],\"force_earliest_schedule\":false}"); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig config, + add->backend_config()); } } // namespace diff --git a/xla/service/gpu/buffer_allocations.cc b/xla/service/gpu/buffer_allocations.cc index 193586bc9a480..df7472d8e6b16 100644 --- a/xla/service/gpu/buffer_allocations.cc +++ b/xla/service/gpu/buffer_allocations.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,19 +18,22 @@ limitations under the License. #include #include +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" namespace xla { namespace gpu { -Status BufferAllocations::TearDown( +absl::Status BufferAllocations::TearDown( const std::set& live_addresses, absl::Span allocations) { // Deallocate temporary buffers, taking care to try to deallocate all of them // even if one of the deallocations fails. - Status status; + absl::Status status; const int64_t num_buffers = allocations.size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { const BufferAllocation& allocation = allocations[i]; @@ -54,7 +57,23 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( BufferAllocation::Index buffer_index) const { CHECK_GE(buffer_index, 0); CHECK_LT(buffer_index, buffers_.size()); - return buffers_[buffer_index]; + se::DeviceMemoryBase base = buffers_[buffer_index]; + if (reinterpret_cast(base.opaque()) == kExternalAllocationMarker) { + if (!external_allocations_) { + LOG(ERROR) << "Does not have external allocations for buffer " + << buffer_index; + return se::DeviceMemoryBase(); + } + auto external_address = + external_allocations_->GetDeviceAddress(buffer_index); + if (external_address.ok()) { + return external_address.value(); + } + LOG(ERROR) << "External address for allocation" << buffer_index + << " is not allocated yet"; + return se::DeviceMemoryBase(); + } + return base; } se::DeviceMemoryBase& BufferAllocations::GetMutableDeviceAddress( @@ -66,22 +85,42 @@ se::DeviceMemoryBase& BufferAllocations::GetMutableDeviceAddress( se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const { - se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index()); - CHECK_LE(buffer_slice.offset(), base.size()); - CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size()); - return se::DeviceMemoryBase( - static_cast(base.opaque()) + buffer_slice.offset(), - buffer_slice.size()); + int64_t index = buffer_slice.index(); + se::DeviceMemoryBase base = GetDeviceAddress(index); + + int64_t offset = buffer_slice.offset(); + CHECK_LE(buffer_slice.offset(), base.size()) + << "slice offset " << offset << " must be smaller than buffer #" << index + << " size " << base.size(); + + int64_t extent = offset + buffer_slice.size(); + CHECK_LE(extent, base.size()) + << "slice extent " << extent << " must be smaller than buffer #" << index + << " size " << base.size(); + + return base.GetByteSlice(buffer_slice.offset(), buffer_slice.size()); } -StatusOr BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice, - const ExternalAllocations& external_allocations) const { - // Check if base memory address is an external allocation. - se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index()); - return reinterpret_cast(base.opaque()) == kExternalAllocationMarker - ? external_allocations.GetDeviceAddress(buffer_slice) - : GetDeviceAddress(buffer_slice); +absl::Status BufferAllocations::AddExternalAllocation( + BufferAllocation::Index index, se::DeviceMemoryBase memory) const { + if (external_allocations_ == nullptr) { + return Internal( + "Calling external allocations, but no allocation tracker is provided" + "for allocation %d", + index); + } + return external_allocations_->AddAllocation(index, memory); +} + +absl::Status BufferAllocations::EraseExternalAllocation( + BufferAllocation::Index index) const { + if (external_allocations_ == nullptr) { + return Internal( + "Calling external allocations, but no allocation tracker is provided" + "for allocation %d", + index); + } + return external_allocations_->EraseAllocation(index); } } // namespace gpu diff --git a/xla/service/gpu/buffer_allocations.h b/xla/service/gpu/buffer_allocations.h index 37d2eb8c2d2f5..5f5883b8500a3 100644 --- a/xla/service/gpu/buffer_allocations.h +++ b/xla/service/gpu/buffer_allocations.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,10 +26,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" namespace xla { namespace gpu { @@ -54,18 +52,29 @@ class BufferAllocations { public: virtual ~ExternalAllocations() = default; - // Return a device address for a given buffer slice. Returns error if + // Return a device address for a given buffer allocation. Returns error if // corresponding allocation is not yet allocated. - virtual StatusOr GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const = 0; + virtual absl::StatusOr GetDeviceAddress( + BufferAllocation::Index index) const = 0; + + // Adds an external allocation for a given buffer index. Returns error if + // allocation already exists. + virtual absl::Status AddAllocation(BufferAllocation::Index index, + se::DeviceMemoryBase memory) = 0; + + // Erases an external allocation for a given buffer index. Returns error if + // allocation does not exists. + virtual absl::Status EraseAllocation(BufferAllocation::Index index) = 0; }; BufferAllocations(absl::Span buffers, int device_ordinal, - se::DeviceMemoryAllocator* memory_allocator) + se::DeviceMemoryAllocator* memory_allocator, + ExternalAllocations* external_allocations = nullptr) : buffers_(buffers.begin(), buffers.end()), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + external_allocations_(external_allocations) {} BufferAllocations(BufferAllocations&& other) = default; BufferAllocations& operator=(BufferAllocations&& other) = default; @@ -75,6 +84,9 @@ class BufferAllocations { se::DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_; } + ExternalAllocations* external_allocations() const { + return external_allocations_; + } int device_ordinal() const { return device_ordinal_; } // Returns the device address of buffer `buffer_index`. `buffer_index` must be @@ -92,17 +104,17 @@ class BufferAllocations { se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; - // Finds an allocation for a given buffer slice, and if it happens to be an - // external allocation resolves it using user-provided external allocations. - // Returns error if external allocations do not have an address for a slice. - StatusOr GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice, - const ExternalAllocations& external_allocations) const; + // Add new allocation allocated by external allocator. + absl::Status AddExternalAllocation(BufferAllocation::Index index, + se::DeviceMemoryBase memory) const; + + // Remove allocation freed by external allocator. + absl::Status EraseExternalAllocation(BufferAllocation::Index index) const; // Tears down all buffers allocated by this object that are not in // `live_addresses`. - Status TearDown(const std::set& live_addresses, - absl::Span allocations); + absl::Status TearDown(const std::set& live_addresses, + absl::Span allocations); std::string ToString() const { std::string out; @@ -121,12 +133,16 @@ class BufferAllocations { // indexed by Index. Each element can point to a temporary buffer, an // input buffer, or nullptr if no buffer is needed for that Index. - // a nullptr buffer with non-zero size buffer is assumed to be lazily - // allocated buffer, and will be allocated through command buffer Allocate - // command during runtime. + // a special address (se::kExternalAllocationMarker) with non-zero size buffer + // is assumed to be lazily allocated buffer, and will be allocated through + // command buffer Allocate command during runtime. std::vector buffers_; int device_ordinal_; se::DeviceMemoryAllocator* memory_allocator_; + + // For buffer address that marked as ExternalAllocations, tracks its real + // address here. + ExternalAllocations* external_allocations_; }; } // namespace gpu diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index c5dbe883f2891..679b5375d0694 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,16 +18,23 @@ limitations under the License. #include #include #include -#include +#include +#include +#include +#include "Eigen/Core" // from @eigen_archive #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -44,22 +51,22 @@ using ComparisonKernelT = // // Returns `true` if two buffers are equal, `false` otherwise. template -static StatusOr DeviceCompare(se::Stream* stream, - se::DeviceMemoryBase current, - se::DeviceMemoryBase expected, - const Shape& buffer_shape, - const HloModuleConfig& config, - std::string_view kernel_name, - void* kernel_symbol) { +static absl::StatusOr DeviceCompare(se::Stream* stream, + se::DeviceMemoryBase current, + se::DeviceMemoryBase expected, + const Shape& buffer_shape, + const HloModuleConfig& config, + std::string_view kernel_name, + void* kernel_symbol) { se::StreamExecutor* executor = stream->parent(); se::ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); - stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); + TF_RETURN_IF_ERROR(stream->MemZero(out_param.ptr(), sizeof(uint64_t))); if (current.size() != expected.size()) { - return InternalError("Mismatched buffer size: %d bytes vs. %d bytes", - current.size(), expected.size()); + return Internal("Mismatched buffer size: %d bytes vs. %d bytes", + current.size(), expected.size()); } se::DeviceMemory current_typed(current); @@ -67,29 +74,27 @@ static StatusOr DeviceCompare(se::Stream* stream, uint64_t buffer_size = current_typed.ElementCount(); TF_ASSIGN_OR_RETURN( - std::unique_ptr> comparison_kernel, - (executor->CreateTypedKernel, - se::DeviceMemory, float, uint64_t, - se::DeviceMemory>(kernel_name, - kernel_symbol))); + ComparisonKernelT comparison_kernel, + (se::TypedKernel, se::DeviceMemory, + float, uint64_t, + se::DeviceMemory>::Create(executor, + kernel_name, + kernel_symbol))); const se::DeviceDescription& gpu_device_info = executor->GetDeviceDescription(); - TF_ASSIGN_OR_RETURN(LaunchDimensions dim, - CalculateLaunchDimensions(buffer_shape, gpu_device_info)); + LaunchDimensions dim = + CalculateLaunchDimensions(buffer_shape, gpu_device_info); - LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); - LaunchDimensions::Dim3D block_counts = dim.block_counts(); TF_RETURN_IF_ERROR(stream->ThenLaunch( - se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), - se::BlockDim(block_counts.x, block_counts.y, block_counts.z), - *comparison_kernel, current_typed, expected_typed, - static_cast(kTolerance), buffer_size, out_param.cref())); + dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, + current_typed, expected_typed, static_cast(kTolerance), + buffer_size, out_param.cref())); uint64_t result = -1; CHECK_EQ(out_param->size(), sizeof(result)); - stream->ThenMemcpy(&result, *out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->Memcpy(&result, *out_param, sizeof(result))); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); return result == 0; } @@ -99,12 +104,15 @@ static StatusOr DeviceCompare(se::Stream* stream, // // Returns true if no differences were seen, false otherwise. template -StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected) { +absl::StatusOr HostCompare(se::Stream* stream, + se::DeviceMemoryBase current, + se::DeviceMemoryBase expected) { int64_t n = current.size() / sizeof(ElementType); std::vector host_current(n), host_expected(n); - stream->ThenMemcpy(host_current.data(), current, current.size()); - stream->ThenMemcpy(host_expected.data(), expected, expected.size()); + TF_RETURN_IF_ERROR( + stream->Memcpy(host_current.data(), current, current.size())); + TF_RETURN_IF_ERROR( + stream->Memcpy(host_expected.data(), expected, expected.size())); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { @@ -148,13 +156,11 @@ StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase current, } template -static StatusOr CompareEqualParameterized(se::Stream* stream, - se::DeviceMemoryBase current, - se::DeviceMemoryBase expected, - const Shape& shape, - const HloModuleConfig& config, - std::string_view kernel_name, - void* kernel_symbol) { +static absl::StatusOr CompareEqualParameterized( + se::Stream* stream, se::DeviceMemoryBase current, + se::DeviceMemoryBase expected, const Shape& shape, + const HloModuleConfig& config, std::string_view kernel_name, + void* kernel_symbol) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); TF_ASSIGN_OR_RETURN( bool result, DeviceCompare(stream, current, expected, shape, @@ -171,7 +177,7 @@ static StatusOr CompareEqualParameterized(se::Stream* stream, return false; } -StatusOr BufferComparator::CompareEqual( +absl::StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, se::DeviceMemoryBase expected) const { switch (shape_.element_type()) { diff --git a/xla/service/gpu/buffer_comparator.cu.cc b/xla/service/gpu/buffer_comparator.cu.cc index 08d99184f096b..bbe3395345a05 100644 --- a/xla/service/gpu/buffer_comparator.cu.cc +++ b/xla/service/gpu/buffer_comparator.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/buffer_comparator.h b/xla/service/gpu/buffer_comparator.h index 113bbf9a326fb..395a05225574b 100644 --- a/xla/service/gpu/buffer_comparator.h +++ b/xla/service/gpu/buffer_comparator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ #define XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#include "absl/status/statusor.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" namespace xla::gpu { @@ -40,8 +41,9 @@ class BufferComparator { // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance // // See the implementation for the tolerance value. - StatusOr CompareEqual(se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected) const; + absl::StatusOr CompareEqual(se::Stream* stream, + se::DeviceMemoryBase current, + se::DeviceMemoryBase expected) const; private: Shape shape_; diff --git a/xla/service/gpu/buffer_comparator_test.cc b/xla/service/gpu/buffer_comparator_test.cc index 839e5038419c0..8b0f9e72d325c 100644 --- a/xla/service/gpu/buffer_comparator_test.cc +++ b/xla/service/gpu/buffer_comparator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,24 @@ limitations under the License. #include "xla/service/gpu/buffer_comparator.h" +#include #include #include #include -#include +#include #include "xla/primitive_util.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" #include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/status.h" #include "tsl/platform/test.h" namespace xla { @@ -35,9 +43,9 @@ class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() #if GOOGLE_CUDA - : platform_(se::MultiPlatformManager::PlatformWithName("CUDA").value()), + : platform_(se::PlatformManager::PlatformWithName("CUDA").value()), #elif TENSORFLOW_USE_ROCM - : platform_(se::MultiPlatformManager::PlatformWithName("ROCM").value()), + : platform_(se::PlatformManager::PlatformWithName("ROCM").value()), #endif stream_exec_(platform_->ExecutorForDevice(0).value()) { } @@ -46,26 +54,26 @@ class BufferComparatorTest : public testing::Test { template bool CompareEqualBuffers(const std::vector& current, const std::vector& expected) { - se::Stream stream(stream_exec_); - stream.Init(); + auto stream = stream_exec_->CreateStream().value(); se::ScopedDeviceMemory current_buffer = stream_exec_->AllocateOwnedArray(current.size()); se::ScopedDeviceMemory expected_buffer = stream_exec_->AllocateOwnedArray(expected.size()); - stream.ThenMemcpy(current_buffer.ptr(), current.data(), - current_buffer->size()); - stream.ThenMemcpy(expected_buffer.ptr(), expected.data(), - expected_buffer->size()); - TF_CHECK_OK(stream.BlockHostUntilDone()); + TF_CHECK_OK(stream->Memcpy(current_buffer.ptr(), current.data(), + current_buffer->size())); + TF_CHECK_OK(stream->Memcpy(expected_buffer.ptr(), expected.data(), + expected_buffer->size())); + TF_CHECK_OK(stream->BlockHostUntilDone()); BufferComparator comparator( ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {static_cast(current_buffer->ElementCount())}), HloModuleConfig()); - return comparator.CompareEqual(&stream, *current_buffer, *expected_buffer) + return comparator + .CompareEqual(stream.get(), *current_buffer, *expected_buffer) .value(); } @@ -339,21 +347,20 @@ TEST_F(BufferComparatorTest, BF16) { const int element_count = 3123; int64_t rng_state = 0; - se::Stream stream(stream_exec_); - stream.Init(); + auto stream = stream_exec_->CreateStream().value(); se::ScopedDeviceMemory lhs = stream_exec_->AllocateOwnedArray(element_count); - InitializeBuffer(&stream, BF16, &rng_state, *lhs.ptr()); + InitializeBuffer(stream.get(), BF16, &rng_state, *lhs.ptr()); se::ScopedDeviceMemory rhs = stream_exec_->AllocateOwnedArray(element_count); - InitializeBuffer(&stream, BF16, &rng_state, *rhs.ptr()); + InitializeBuffer(stream.get(), BF16, &rng_state, *rhs.ptr()); BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}), HloModuleConfig()); EXPECT_FALSE( - comparator.CompareEqual(&stream, *lhs.ptr(), *rhs.ptr()).value()); + comparator.CompareEqual(stream.get(), *lhs.ptr(), *rhs.ptr()).value()); } } // namespace diff --git a/xla/service/gpu/buffer_sharing.cc b/xla/service/gpu/buffer_sharing.cc index 260f533c2149a..4bdd18ae63ac5 100644 --- a/xla/service/gpu/buffer_sharing.cc +++ b/xla/service/gpu/buffer_sharing.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -79,10 +79,10 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, stream_executor::GpuDeviceInfoProto device_info; stream_executor::DeviceDescription device_description(device_info); auto analysis = HloFusionAnalysis::Create(fusion, &device_description); - bool is_reduction_emitter = analysis->GetEmitterFusionKind() == + bool is_reduction_emitter = analysis.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kReduction; const HloInstruction* reduction_hero = - is_reduction_emitter ? reduction_hero = analysis->FindHeroReduction() + is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction() : nullptr; // We need to make sure that the fusion parameter is accessed in the same @@ -203,7 +203,8 @@ std::optional CanShareBufferHint(const HloInstruction* user, const ShapeIndex& user_index) { switch (user->opcode()) { case HloOpcode::kAllReduce: - // NCCL all-reduce can be performed in-place. + case HloOpcode::kCollectiveBroadcast: + // NCCL all-reduce and collective-broadcast can be performed in-place. return user->operand_count() == 1 || (user_index.size() == 1 && user->operand(user_index[0]) == operand); @@ -211,7 +212,8 @@ std::optional CanShareBufferHint(const HloInstruction* user, // The matrix bias operand can be overwritten in-place. if (user->custom_call_target() == kCublasLtMatmulCallTarget) { GemmBackendConfig config = - std::move(user->backend_config()).value(); + std::move(user->backend_config()) + ->gemm_backend_config(); return (config.beta() != 0.) && user->operand(2) == operand; } // The operand of cholesky can be shared with the first output. diff --git a/xla/service/gpu/buffer_sharing.h b/xla/service/gpu/buffer_sharing.h index 5867dd8d03193..7fdf4af78c11c 100644 --- a/xla/service/gpu/buffer_sharing.h +++ b/xla/service/gpu/buffer_sharing.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/build_defs.bzl b/xla/service/gpu/build_defs.bzl index 81ddac91c1a26..c4d663007fdb1 100644 --- a/xla/service/gpu/build_defs.bzl +++ b/xla/service/gpu/build_defs.bzl @@ -5,6 +5,11 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", "rocm_copts") load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +# buildifier: disable=out-of-order-load +# Internally this loads a macro, but in OSS this is a function +def register_extension_info(**_kwargs): + pass + def get_cub_sort_kernel_types(name = ""): """ List of supported types for CUB sort kernels. """ @@ -41,6 +46,8 @@ def build_cub_sort_kernels(name, types, local_defines = [], **kwargs): **kwargs ) +register_extension_info(extension = build_cub_sort_kernels, label_regex_for_dep = "{extension_name}_.*") + def gpu_kernel_library(name, copts = [], local_defines = [], **kwargs): cuda_library( name = name, @@ -49,3 +56,5 @@ def gpu_kernel_library(name, copts = [], local_defines = [], **kwargs): copts = copts + rocm_copts(), **kwargs ) + +register_extension_info(extension = gpu_kernel_library, label_regex_for_dep = "{extension_name}") diff --git a/xla/service/gpu/collective_permute_cycle_decomposer.cc b/xla/service/gpu/collective_permute_cycle_decomposer.cc new file mode 100644 index 0000000000000..8edf682481baf --- /dev/null +++ b/xla/service/gpu/collective_permute_cycle_decomposer.cc @@ -0,0 +1,273 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/collective_permute_cycle_decomposer.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/literal_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { + +namespace { +using SourceTargetPair = std::pair; +using SourceTargetPairs = std::vector; +enum class CycleType { kUnknown, kForward, kBackward }; + +// Returns true if the (source, target) pairs form a forward cycle with all +// participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume that +// the (source, target) pairs are ordered as they are generated by SPMD +// partitioning. +// +bool IsForwardCycle(const SourceTargetPairs& pairs) { + int64_t num_pairs = pairs.size(); + const SourceTargetPair& last_pair = pairs[num_pairs - 1]; + if (last_pair.first != num_pairs - 1 || last_pair.second != 0) { + return false; + } + for (int64_t i = 0; i < num_pairs - 1; ++i) { + const SourceTargetPair& pair = pairs[i]; + if (pair.first != i || pair.second != i + 1) { + return false; + } + } + return true; +} + +// Returns true if the (source, target) pairs form a backward cycle with all +// participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume that +// the (source, target) pairs are ordered as they are generated by SPMD +// partitioning. +// +bool IsBackwardCycle(const SourceTargetPairs& pairs) { + int64_t num_pairs = pairs.size(); + const SourceTargetPair& first_pair = pairs[0]; + if (first_pair.first != 0 || first_pair.second != num_pairs - 1) { + return false; + } + for (int64_t i = 1; i < num_pairs; ++i) { + const SourceTargetPair& pair = pairs[i]; + if (pair.first != i || pair.second != i - 1) { + return false; + } + } + return true; +} + +// Returns true if the CollectivePermute instruction has a cycle in its +// source-target pairs and should be decomposed. +CycleType ShouldDecomposeWithCycleType( + const HloCollectivePermuteInstruction& collective_permute, + int64_t threshold_in_bytes) { + if (!collective_permute.channel_id().has_value()) { + return CycleType::kUnknown; + } + + if (collective_permute.operand_count() != 1) { + return CycleType::kUnknown; + } + + const Shape& result_shape = collective_permute.shape(); + // Skip the transformation if there is any context data. + if (result_shape.IsTuple()) { + return CycleType::kUnknown; + } + + CHECK(result_shape.IsArray()); + if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) { + return CycleType::kUnknown; + } + + const SourceTargetPairs& pairs = collective_permute.source_target_pairs(); + if (pairs.size() == 1) { + return CycleType::kUnknown; + } + + return IsForwardCycle(pairs) ? CycleType::kForward + : IsBackwardCycle(pairs) ? CycleType::kBackward + : CycleType::kUnknown; +} + +// Constructs the frontend attributes for the two decomposed CollectivePermute +// instructions. +Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp, + CycleType cycle_type, + xla::FrontendAttributes& cp1_attr, + xla::FrontendAttributes& cp2_attr) { + cp1_attr = cp->frontend_attributes(); + cp2_attr = cp->frontend_attributes(); + auto validation_it = + cp->frontend_attributes().map().find(kSendRecvValidationAttr); + if (validation_it == cp->frontend_attributes().map().end() || + validation_it->second == "invalid") { + return OkStatus(); + } + + auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second); + if (!statusor_bounds.ok()) { + return statusor_bounds.status(); + } + const std::vector& bounds = statusor_bounds.value(); + if (bounds.size() < 2) { + return Internal("Invalid number of replica groups"); + } + + int64_t num_pairs = bounds.size(); + // A forward cycle has its backedge at the end while a backward cycle has its + // backedge at the beginning. + auto backedge_start = cycle_type == CycleType::kBackward + ? bounds.begin() + : bounds.begin() + num_pairs - 1; + auto other_edges_start = + cycle_type == CycleType::kBackward ? bounds.begin() + 1 : bounds.begin(); + std::vector cp1_bounds(backedge_start, backedge_start + 1); + std::vector cp2_bounds(other_edges_start, + other_edges_start + num_pairs - 1); + auto bounds_to_string = [](const std::vector groups) { + return "{" + + absl::StrJoin(groups, ",", + [](std::string* out, const ReplicaGroup& value) { + absl::StrAppend(out, "{", value.replica_ids(0), ",", + value.replica_ids(1), "}"); + }) + + "}"; + }; + std::string cp1_validation_str = bounds_to_string(cp1_bounds); + std::string cp2_validation_str = bounds_to_string(cp2_bounds); + (*cp1_attr.mutable_map())[kSendRecvValidationAttr] = cp1_validation_str; + (*cp2_attr.mutable_map())[kSendRecvValidationAttr] = cp2_validation_str; + return OkStatus(); +} + +// Decomposes a CollectivePermute instruction with a cycle in its source-target +// pairs into two CollectivePermute instructions. +Status DecomposeCollectivePermuteCycle(HloCollectivePermuteInstruction* cp, + HloComputation* computation, + HloModule* module, + int64_t next_channel_id, + CycleType cycle_type) { + const SourceTargetPairs& pairs = cp->source_target_pairs(); + int64_t num_pairs = pairs.size(); + // A forward cycle has its backedge at the end as in + // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the + // beginning as in {{0,3},{1,0},{2,1},{3,2}}. + auto backedge_start = cycle_type == CycleType::kBackward + ? pairs.begin() + : pairs.begin() + num_pairs - 1; + auto other_edges_start = + cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin(); + SourceTargetPairs backedge(backedge_start, backedge_start + 1); + SourceTargetPairs other_edges(other_edges_start, + other_edges_start + num_pairs - 1); + const OpMetadata& metadata = cp->metadata(); + xla::FrontendAttributes cp1_attr, cp2_attr; + TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr)); + + // Create the CollectivePermute instruction for the communication represented + // by the backedge. + HloInstruction* cp1 = + computation->AddInstruction(HloInstruction::CreateCollectivePermute( + cp->shape(), cp->mutable_operand(0), backedge, + cp->channel_id().value())); + cp1->set_metadata(metadata); + cp1->set_frontend_attributes(cp1_attr); + int64_t cp1_receiver = backedge.back().second; + + // Create the CollectivePermute instruction for the communication represented + // byt other edges. + HloInstruction* cp2 = + computation->AddInstruction(HloInstruction::CreateCollectivePermute( + cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id)); + cp2->set_metadata(metadata); + cp2->set_frontend_attributes(cp2_attr); + + // Calculate the received data as follows: + // partition = u32[] partition-id() + // constant = u32[] constant(cp1_receiver) + // compare0 = pred[] compare(partition, cp1_received), direction=EQ + // compare = pred[?] broadcast(compare0), dimensions={} + // recv-data = type[?] select(compare, cp1_done, cp2_done) + HloInstruction* partition = + computation->AddInstruction(HloInstruction::CreatePartitionId()); + HloInstruction* constant = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver))); + HloInstruction* compare0 = computation->AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition, + constant, Comparison::Direction::kEq)); + HloInstruction* compare = + computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {})); + HloInstruction* recv_data = + computation->AddInstruction(HloInstruction::CreateTernary( + cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2)); + + TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data)); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp)); + + return OkStatus(); +} +} // namespace + +absl::StatusOr CollectivePermuteCycleDecomposer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + int64_t next_channel_id; + for (auto comp : module->computations(execution_threads)) { + for (auto hlo : comp->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kCollectivePermute) { + continue; + } + auto collective_permute = Cast(hlo); + CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute, + threshold_in_bytes_); + if (cycle_type != CycleType::kUnknown) { + if (changed == false) { + next_channel_id = hlo_query::NextChannelId(*module); + changed = true; + } + TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle( + collective_permute, comp, module, next_channel_id++, cycle_type)); + } + } + } + return changed; +} + +} // namespace xla diff --git a/xla/service/gpu/collective_permute_cycle_decomposer.h b/xla/service/gpu/collective_permute_cycle_decomposer.h new file mode 100644 index 0000000000000..508a8597ee42f --- /dev/null +++ b/xla/service/gpu/collective_permute_cycle_decomposer.h @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ +#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute +// instructions with all participants forming either a forward cycle (such as +// {{0,1},{1,2},{2,3},{3,0}) or a backward cycle (such as {{3,2},{2,1},{1,0}, +// {0,3}}) into two CollectivePermute instructions. We currently restrict +// this transformation to CollectivePermute using partition mode, with one +// input, without any context data. Here is an example. +// +// before transformation: +// start = (, ) collective-permute(data), +// source_target_pairs={{0,1},{1,2},{2,3},{3,0}} +// +// after transformation: +// partition-id = u32[] partition-id() +// constant = u32[] constant(0) +// compare = pred[] compare(u32[] partition-id, u32[] constant), +// direction=EQ +// pred = pred[] broadcast(pred[] compare), dimensions={} +// cp1 = (, ) collective-permute(data), source_target_pairs={{3,0}} +// cp2 = (, ) collective-permute(data), +// source_target_pairs={{0,1},{1,2},{2,3}} +// data = select(pred, cp1, cp2) +// +class CollectivePermuteCycleDecomposer : public HloModulePass { + public: + explicit CollectivePermuteCycleDecomposer(int64_t threshold_in_bytes) + : threshold_in_bytes_(threshold_in_bytes) {} + absl::string_view name() const override { + return "collective-permute-cycle-decomposer"; + } + + using HloPassInterface::Run; + // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Transform only if the size of the CollectivePermute data >= threshold. + int64_t threshold_in_bytes_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ diff --git a/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc new file mode 100644 index 0000000000000..da687711f8ff7 --- /dev/null +++ b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/collective_permute_cycle_decomposer.h" + +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_parser.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; +using CollectivePermuteCycleDecomposerTest = HloTestBase; + +using ::testing::HasSubstr; +using CollectivePermuteDecomposerTest = HloTestBase; + +TEST_F(CollectivePermuteDecomposerTest, DefaultChannelNotTransformed) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT start = u32[] collective-permute(p), + source_target_pairs={{0,1},{1,0}} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + p = u32[] partition-id() + ROOT start = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,0}} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + p = u32[] partition-id() + ROOT start = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,1},{1,2},{2,3},{3,0}} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + p = u32[] partition-id() + ROOT start = u32[3,2] collective-permute(p), channel_id=1, + source_target_pairs={{0,1},{1,2},{2,3},{3,0}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + + auto check_metadata = [](const HloInstruction* inst) { + EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); + EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py"); + EXPECT_EQ(inst->metadata().source_line(), 35); + }; + + HloCollectivePermuteInstruction* cp1 = + DynCast( + FindInstruction(module.get(), "collective-permute")); + HloCollectivePermuteInstruction* cp2 = + DynCast( + FindInstruction(module.get(), "collective-permute.1")); + EXPECT_NE(cp1, nullptr); + EXPECT_NE(cp2, nullptr); + EXPECT_EQ(cp1->operand(0), cp2->operand(0)); + EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value()); + EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}")); + EXPECT_THAT(cp1->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{3,10}}\"")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\"")); + check_metadata(cp1); + check_metadata(cp2); +} + +TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + p = u32[] partition-id() + ROOT start = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto check_metadata = [](const HloInstruction* inst) { + EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); + EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py"); + EXPECT_EQ(inst->metadata().source_line(), 35); + }; + + HloCollectivePermuteInstruction* cp1 = + DynCast( + FindInstruction(module.get(), "collective-permute")); + HloCollectivePermuteInstruction* cp2 = + DynCast( + FindInstruction(module.get(), "collective-permute.1")); + EXPECT_NE(cp1, nullptr); + EXPECT_NE(cp2, nullptr); + EXPECT_EQ(cp1->operand(0), cp2->operand(0)); + EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value()); + EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}")); + EXPECT_THAT(cp1->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{0,7}}\"")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{1,8},{2,9},{3,10}}\"")); + check_metadata(cp1); + check_metadata(cp2); +} + +} // namespace +} // namespace xla diff --git a/xla/service/gpu/command_buffer_scheduling.cc b/xla/service/gpu/command_buffer_scheduling.cc index 9d660dd54f7cb..a9ed89cbf91ea 100644 --- a/xla/service/gpu/command_buffer_scheduling.cc +++ b/xla/service/gpu/command_buffer_scheduling.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,50 +15,261 @@ limitations under the License. #include "xla/service/gpu/command_buffer_scheduling.h" +#include +#include #include -#include +#include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/shape.h" -#include "xla/statusor.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::gpu { -namespace { - -// We categorize HLO instructions into two types. -// 1. Commands: Instructions that correspond to a command that will be -// submitted to a GPU. Fused computations and library calls fall into this -// category. -// 2. Intermediates: Instructions that produce intermediate values that are -// used by commands. -bool IsIntermediate(const HloInstruction* inst) { - switch (inst->opcode()) { - case HloOpcode::kGetTupleElement: - return true; - default: +using CommandBuffer = CommandBufferScheduling::CommandBuffer; +using CommandBufferConfig = CommandBufferScheduling::CommandBufferConfig; + +// Returns true if HLO computation can be executed as a command buffer. +static bool IsCommand(const HloComputation* computation, + const CommandBufferConfig& config); + +//===----------------------------------------------------------------------===// +// No-op HLO operations. +//===----------------------------------------------------------------------===// + +// Some of the HLO operations do not have corresponding operations at run time +// and they can be safely wrapped into command buffers together with load +// bearing commands. + +static bool IsConstant(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kConstant; +} + +static bool IsParameter(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kParameter; +} + +// Returns true if instruction is no-op at run time and doesn't have a +// corresponding Thunk or Command (metadata only operation). +static bool IsNoOp(const HloInstruction* hlo) { + return HloPredicateIsOp(hlo); +}; + +//===----------------------------------------------------------------------===// +// Synchronous HLO operations mapped to commands. +//===----------------------------------------------------------------------===// + +// Synchronous HLO operations can be wrapped into command buffers when they have +// a corresponding commands. + +// This is a template to define pattern matching functions for HLO instructions +// that do not have a corresponding class for them. +template +static bool IsCommand(const HloInstruction*, const CommandBufferConfig&); + +// While loops can be executed inside command buffers only if condition and body +// regions can be executed as command buffers. +template <> +bool IsCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + IsCommand(hlo->while_body(), config) && + IsCommand(hlo->while_condition(), config); +} + +// Conditional can be executed inside command buffers only if all regions of its +// branches can be executed as command buffers. +template <> +bool IsCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + absl::c_all_of(hlo->branch_computations(), + [&](const HloComputation* comp) { + return IsCommand(comp, config); + }); +} + +static bool IsCommand(const HloCustomCallInstruction* hlo, + const CommandBufferConfig& config) { + // cuBLAS gemms represented in the HLO as custom call instructions. + if (config.enabled_commands.contains(DebugOptions::CUBLAS) && + IsLegacyCublasMatmul(*hlo)) { + return true; + } + + if (!config.enabled_commands.contains(DebugOptions::CUSTOM_CALL)) { + return false; + } + + // A special case for jax-triton kernel while it is not ported to FFI. + if (hlo->custom_call_target() == "triton_kernel_call" && + // TODO(b/327718087): This is an ugly hack to prevent capturing triton + // custom calls that might do autotuning at run time. + !absl::StrContains(hlo->metadata().op_name(), "Autotuner")) { + return true; + } + + // Check if FFI handler is compatible with command buffers. + auto registration = ffi::FindHandler(hlo->custom_call_target(), "gpu"); + return registration.ok() + ? ffi::IsCommandBufferCompatible(registration->traits) + : false; +} + +static bool IsCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + if (auto* fusion = DynCast(hlo)) { + auto gpu_config = fusion->backend_config(); + const FusionBackendConfig& backend_config = + gpu_config->fusion_backend_config(); + if (backend_config.kind() == kCuDnnFusionKind) { + return config.enabled_commands.contains(DebugOptions::CUDNN); + } + const auto& custom_config = backend_config.custom_fusion_config(); + if (custom_config.name() == "address_computation") { + auto fusion_analysis = + HloFusionAnalysis::Create(fusion, &config.device_description); + const HloFusionAdaptor& adaptor = fusion_analysis.fusion(); + auto custom_call_adaptor = HloFindIf( + adaptor.GetRoots(), adaptor, + [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); + const auto* custom_call = static_cast( + &custom_call_adaptor->instruction()); + return IsCommand(custom_call, config); + } + if (custom_config.name() == "dynamic_address_computation") { return false; + } + return config.enabled_commands.contains(DebugOptions::FUSION); + } + + if (auto* sort = DynCast(hlo)) + return config.enabled_commands.contains(DebugOptions::FUSION); + + if (hlo->opcode() == HloOpcode::kPartitionId || + hlo->opcode() == HloOpcode::kReplicaId) { + return config.enabled_commands.contains(DebugOptions::FUSION); + } + + if (auto* custom_call = DynCast(hlo)) + return IsCommand(custom_call, config); + + if (hlo->opcode() == HloOpcode::kWhile) + return IsCommand(hlo, config); + + if (hlo->opcode() == HloOpcode::kConditional) + return IsCommand(hlo, config); + + return false; +} + +//===----------------------------------------------------------------------===// +// Asynchronous HLO operations mapped to commands. +//===----------------------------------------------------------------------===// + +// Asynchronous HLO operations can be wrapped into command buffers only when +// both start and done operations can be put into the same command buffer. +// Command buffer semantics implies that when command buffer execution +// completes, all recorded commands are also completed, which means that if +// done operation is not part of the same command buffer, we would change the +// execution semantics and create additional synchronization point. + +static bool IsAsyncStartCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + if (hlo->opcode() == HloOpcode::kAllReduceStart || + hlo->opcode() == HloOpcode::kAllGatherStart) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + + if (hlo->opcode() == HloOpcode::kAsyncStart) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + } + + return false; +} + +static bool IsAsyncDoneCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + if (hlo->opcode() == HloOpcode::kAllReduceDone || + hlo->opcode() == HloOpcode::kAllGatherDone) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + + if (hlo->opcode() == HloOpcode::kAsyncDone) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + } + + return false; +} + +// Finds an async-done HLO operation corresponding on an async-start one. +static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) { + if (start->opcode() == HloOpcode::kAllReduceStart || + start->opcode() == HloOpcode::kAllGatherStart) { + CHECK(start->users().size() == 1); // NOLINT, checked by HLO verifier + return start->users().front(); + } else if (start->opcode() == HloOpcode::kAsyncStart) { + return start->async_chain_done(); } + + return nullptr; } -void RemoveTrailingIntermediates(HloInstructionSequence& seq) { +//===----------------------------------------------------------------------===// +// HLO computations mapped to command buffers. +//===----------------------------------------------------------------------===// + +// Returns true if HLO computation can be executed as a command buffer. +static bool IsCommand(const HloComputation* computation, + const CommandBufferConfig& config) { + return absl::c_all_of( + computation->instructions(), [&](const HloInstruction* inst) { + return IsNoOp(inst) || IsConstant(inst) || IsParameter(inst) || + IsCommand(inst, config) || IsAsyncStartCommand(inst, config) || + IsAsyncDoneCommand(inst, config); + }); +} + +//===----------------------------------------------------------------------===// + +static void RemoveTrailingNoOps(HloInstructionSequence& seq) { std::vector instructions = seq.instructions(); for (int i = instructions.size() - 1; i >= 0; i--) { - HloInstruction* inst = instructions[i]; - if (IsIntermediate(inst)) { + if (HloInstruction* inst = instructions[i]; IsNoOp(inst)) { seq.remove_instruction(inst); } else { break; @@ -66,270 +277,494 @@ void RemoveTrailingIntermediates(HloInstructionSequence& seq) { } } -constexpr int kMinNumCommands = 2; - -} // namespace +//===----------------------------------------------------------------------===// +// Discovering sequences of compatible Hlo instructions +//===----------------------------------------------------------------------===// // The input is a scheduled sequence of instructions. This function collects // subsequences that will be extracted as command buffers. std::vector CommandBufferScheduling::CollectCommandBufferSequences( - const HloInstructionSequence inst_sequence, - std::function is_command) { - struct Accumulator { - std::vector sequences; - HloInstructionSequence current_seq; - int num_commands_in_current_seq = 0; + const HloInstructionSequence schedule, const CommandBufferConfig& config, + int32_t min_num_commands) { + std::vector sequences; + + HloInstructionSequence current_seq; + int64_t num_commands_in_current_seq = 0; + + // Adds `current_seq` to `sequences` if it has enough commands in it. + auto collect_current_seq = [&]() { + if (num_commands_in_current_seq >= std::max(1, min_num_commands)) { + RemoveTrailingNoOps(current_seq); + sequences.push_back(std::move(current_seq)); + } + current_seq = HloInstructionSequence(); + num_commands_in_current_seq = 0; }; - auto start_new_sequence = [](Accumulator* acc) -> Accumulator* { - if (acc->num_commands_in_current_seq >= kMinNumCommands) { - RemoveTrailingIntermediates(acc->current_seq); - acc->sequences.push_back(acc->current_seq); + auto& instructions = schedule.instructions(); + + // Collect the sequence of instructions that contains the async start and its + // corresponding done instruction. If there is another start instruction + // between the original start and done, we may potentially extend the sequence + // to include its corresponding done instruction. For example, if we call this + // function on async-start_a in the following sequence: + // + // async_start_a + // async_start_b + // async_done_a + // async_done_b + // + // The returned sequence will contain async_done_b. So that all async pairs + // are captured by the same command buffer. + auto collect_async_region = [&](const HloInstruction* start) { + auto get_index = [&](const HloInstruction* inst) -> size_t { + auto it = std::find(instructions.begin(), instructions.end(), inst); + return std::distance(instructions.begin(), it); + }; + + HloInstructionSequence seq; + size_t done_index = get_index(FindAsyncDoneCommand(start)); + for (size_t i = get_index(start); i <= done_index; i++) { + HloInstruction* inst = instructions.at(i); + if (IsAsyncStartCommand(inst, config)) { + const HloInstruction* done = FindAsyncDoneCommand(inst); + done_index = std::max(done_index, get_index(done)); + } + seq.push_back(inst); } - acc->current_seq = HloInstructionSequence(); - acc->num_commands_in_current_seq = 0; - return acc; + return seq; }; - auto process_instruction = [&start_new_sequence, &is_command]( - Accumulator* acc, - HloInstruction* inst) -> Accumulator* { - if (is_command(inst)) { - acc->current_seq.push_back(inst); - acc->num_commands_in_current_seq += 1; - return acc; - } else if (IsIntermediate(inst)) { - if (acc->current_seq.size() > 0) { - acc->current_seq.push_back(inst); + // Check that instructions are safe to be captured by command buffer, and that + // we do not capture unmatched async done instruction. + auto check_async_region = [&](const HloInstructionSequence& seq) { + if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) { + return IsNoOp(inst) || IsCommand(inst, config) || + IsAsyncStartCommand(inst, config) || + IsAsyncDoneCommand(inst, config); + })) { + return false; + } + + absl::flat_hash_set done_instructions; + for (const HloInstruction* inst : seq.instructions()) { + if (IsAsyncStartCommand(inst, config)) { + done_instructions.insert(FindAsyncDoneCommand(inst)); + } + if (IsAsyncDoneCommand(inst, config)) { + if (!done_instructions.contains(inst)) { + return false; + } } - return acc; } - return start_new_sequence(acc); + return true; }; - std::vector instructions = inst_sequence.instructions(); - Accumulator acc; - absl::c_accumulate(instructions, &acc, process_instruction); - return start_new_sequence(&acc)->sequences; + for (size_t i = 0; i < instructions.size(); i++) { + HloInstruction* inst = instructions.at(i); + + // We add no-op instructions to current sequence only if they act as a glue + // between commands. We do not create command sequences consisting only from + // no-op instruction. First and last instruction in the command buffer is + // always a load-bearing command. + if (IsNoOp(inst) && num_commands_in_current_seq) { + current_seq.push_back(inst); + continue; + } + + // Synchronous commands always can be added to instruction sequence. + if (IsCommand(inst, config)) { + num_commands_in_current_seq++; + current_seq.push_back(inst); + continue; + } + + // We capture async commands if all instruction between start and done can + // be outlined into a command buffer. + if (IsAsyncStartCommand(inst, config)) { + HloInstructionSequence seq = collect_async_region(inst); + if (check_async_region(seq)) { + num_commands_in_current_seq += seq.instructions().size(); + for (HloInstruction* inst : seq.instructions()) { + current_seq.push_back(inst); + } + i += seq.instructions().size() - 1; + continue; + } + } + + // If we didn't find the next command, collect the current sequence and + // start a new one. + collect_current_seq(); + } + + // Don't forget to collect the final command sequence. + collect_current_seq(); + return sequences; } -// This function moves kParameter instructions in a computation to the beginning -// of the computation. This simplifies the construction of command buffer -// computations because we don't need to consider kParameter's as intermediates. -void CommandBufferScheduling::MoveParametersToFront( +// This function moves kParameter and kConstant instructions in a computation to +// the beginning of the computation. This simplifies the construction of command +// buffer computations because we don't need to deal with parameters and +// constants that have users outside of a command buffer. +absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { + HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); HloInstructionSequence& sequence = schedule.GetOrCreateSequence(computation); - std::vector new_sequence; + for (HloInstruction* inst : sequence.instructions()) { - if (inst->opcode() == HloOpcode::kParameter) { + if (IsParameter(inst) || IsConstant(inst)) { new_sequence.push_back(inst); + + // Because we move instruction to the front of the computation we can't + // have any control predecessors, however silently dropping them is unsafe + // as we can have transitive dependencies that define schedule order, so + // we forward control predecessors to all users. + for (HloInstruction* control_predecessor : inst->control_predecessors()) { + for (HloInstruction* user : inst->users()) { + TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user)); + } + } + TF_RETURN_IF_ERROR(inst->DropAllControlDeps()); } } for (HloInstruction* inst : sequence.instructions()) { - if (inst->opcode() != HloOpcode::kParameter) { + if (!IsParameter(inst) && !IsConstant(inst)) { new_sequence.push_back(inst); } } schedule.set_sequence(computation, new_sequence); + return absl::OkStatus(); } -StatusOr -CommandBufferScheduling::BuildCommandBuffer(HloInstructionSequence seq) { +//===----------------------------------------------------------------------===// +// Prepares command buffer from sequence of instructions +//===----------------------------------------------------------------------===// + +absl::StatusOr CommandBufferScheduling::PrepareCommandBuffer( + const HloInstructionSequence& seq) { auto builder = HloComputation::Builder("command_buffer"); - const std::vector& instructions = seq.instructions(); + + absl::Span instructions = + absl::MakeSpan(seq.instructions()); + + // A set of instructions that will be moved into command buffer computation. + absl::flat_hash_set in_command_buffer(instructions.begin(), + instructions.end()); // The sequence might use results of instructions that are not captured by the // sequence. We pass those results as parameters and map the producers of the // results to their corresponding parameter instructions. - absl::flat_hash_map parameters_map; - int64_t parameter_number = 0; + absl::flat_hash_map parameters; + + // Mapping from command buffer instructions to their clones in the command + // buffer computation body. + absl::flat_hash_map inst_mapping; + + // Maps HLO instructions in the original computation to instructions in the + // command buffer: (a) a parameter corresponding to captured value (b) cloned + // instruction corresponding to a command. + auto mapped_operands = [&](HloInstruction* instr) { + absl::InlinedVector operands; + for (HloInstruction* operand : instr->operands()) { + if (auto it = inst_mapping.find(operand); it != inst_mapping.end()) + operands.push_back(it->second); + } + return operands; + }; + + // Create parameters in the command buffer computation for captured values. for (HloInstruction* inst : instructions) { for (HloInstruction* operand : inst->operands()) { - if (absl::c_find(instructions, operand) != instructions.end()) { - continue; - } - - if (!parameters_map.contains(operand)) { - TF_ASSIGN_OR_RETURN( - HloInstruction * parameter, - builder.AddParameter(HloInstruction::CreateParameter( - parameter_number, operand->shape(), "param"))); - parameter_number++; - parameters_map[operand] = - static_cast(parameter); - } + // We already mapped instruction to a parameter. + if (parameters.contains(operand)) continue; + + // Operand instruction is a part of the command buffer. + if (in_command_buffer.contains(operand)) continue; + + // Create a new parameter for value defined outside of a command buffer. + int64_t parameter_id = parameters.size(); + auto* parameter = Cast(builder.AddInstruction( + HloInstruction::CreateParameter(parameter_id, operand->shape(), + absl::StrCat("p", parameter_id)))); + inst_mapping[operand] = parameters[operand] = parameter; } } - // We copy instructions from the sequence to the computation and map the - // original instruction to its clone. - absl::flat_hash_map instructions_map; + // Clone commands into the command buffer body with mapped operands. for (HloInstruction* inst : seq.instructions()) { - switch (inst->opcode()) { - case HloOpcode::kFusion: { - std::vector operands; - for (HloInstruction* operand : inst->operands()) { - auto it = parameters_map.find(operand); - if (it != parameters_map.end()) { - operands.push_back(it->second); - } else { - operands.push_back(instructions_map[operand]); - } - } - instructions_map[inst] = - builder.AddInstruction(HloInstruction::CreateFusion( - inst->shape(), inst->fusion_kind(), operands, - inst->fused_instructions_computation())); - break; + HloCloneContext ctx(inst->GetModule()); + + // Cloned instructions should call the same computations as original + // instructions will be dead code eliminated. + for (HloComputation* called_computation : inst->called_computations()) { + // Async computations can only be referenced by a single async chain at + // a time. Detach the current chain to let its copy bind to the + // computation. + if (called_computation->IsAsyncComputation()) { + called_computation->RemoveAsyncStart(); } - case HloOpcode::kConstant: - instructions_map[inst] = builder.AddInstruction( - HloInstruction::CreateConstant(inst->literal().Clone())); - break; - case HloOpcode::kGetTupleElement: { - HloGetTupleElementInstruction* get_tuple_index = - static_cast(inst); - HloInstruction* original_operand = get_tuple_index->mutable_operand(0); - auto it = parameters_map.find(original_operand); - HloInstruction* operand; - if (it != parameters_map.end()) { - operand = it->second; - } else { - operand = instructions_map[original_operand]; - } - instructions_map[inst] = - builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inst->shape(), operand, get_tuple_index->tuple_index())); - break; - } - default: - return InternalError("HLO opcode unsupported by command buffers"); + ctx.MapComputation(called_computation, called_computation); } + + inst_mapping[inst] = builder.AddInstruction( + inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx)); } - // Build result tuple. - std::vector new_instructions; - absl::flat_hash_map inst_to_tuple_index_map; - int64_t index = 0; - for (HloInstruction* inst : seq.instructions()) { - new_instructions.push_back(instructions_map[inst]); - inst_to_tuple_index_map[inst] = index; - index++; + // Convert parameters to command buffer arguments. + std::vector arguments(parameters.size()); + for (auto& [argument, parameter] : parameters) { + arguments[parameter->parameter_number()] = argument; } - builder.AddInstruction(HloInstruction::CreateTuple(new_instructions)); - BuildCommandBufferResult result = {builder.Build(), parameters_map, - inst_to_tuple_index_map, instructions_map}; - return result; -} + // Collect command buffer `results` (instructions replaced in the original + // computation) and `results` (instructions in the command buffer). + std::vector results; + std::vector returned; -StatusOr CommandBufferScheduling::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - if (!module->has_schedule()) { - return InternalError("module is not scheduled"); + auto has_external_users = [&](HloInstruction* inst) { + return inst->IsRoot() || absl::c_any_of(inst->users(), [&](auto* user) { + return !in_command_buffer.contains(user); + }); + }; + + for (HloInstruction* inst : instructions) { + if (has_external_users(inst)) { + results.push_back(inst); + returned.push_back(inst_mapping[inst]); + } } - HloComputation* entry = module->entry_computation(); - MoveParametersToFront(entry); - - absl::flat_hash_set command_types; - for (auto cmd_type_num : - module->config().debug_options().xla_gpu_enable_command_buffer()) { - DebugOptions::CommandBufferCmdType cmd_type = - static_cast(cmd_type_num); - command_types.insert(cmd_type); + + // If we return multiple results wrap them into tuple. + if (returned.size() > 1) { + builder.AddInstruction(HloInstruction::CreateTuple(returned)); } - std::function is_command = - [&command_types = - std::as_const(command_types)](const HloInstruction* inst) { - if (inst->opcode() == HloOpcode::kFusion) { - if (command_types.contains(DebugOptions::FUSION)) return true; - } - return false; - }; + return CommandBuffer{std::move(arguments), std::move(results), + builder.Build(), std::move(inst_mapping)}; +} - std::vector sequences = CollectCommandBufferSequences( - module->schedule().sequence(entry), is_command); +//===----------------------------------------------------------------------===// +// Rewrites original computation into command buffer call +//===----------------------------------------------------------------------===// + +absl::StatusOr CommandBufferScheduling::RewriteCommandBuffer( + HloComputation* parent, const HloInstructionSequence& seq, + CommandBuffer command_buffer) { + if (command_buffer.results.empty()) + return absl::InternalError("command buffer results must not be empty"); + + // If we have more than one result we return them as tuple, and get individual + // values using `get-tuple-element` instructions. Otherwise we simply return + // a result from a command buffer computation. + Shape cmd_buffer_result_shape; + bool has_single_result = command_buffer.results.size() == 1; + + if (has_single_result) { + cmd_buffer_result_shape = command_buffer.results[0]->shape(); + } else { + absl::InlinedVector shapes; + shapes.reserve(command_buffer.results.size()); + for (auto* res : command_buffer.results) shapes.push_back(res->shape()); + cmd_buffer_result_shape = ShapeUtil::MakeTupleShape(shapes); + } - for (const HloInstructionSequence& seq : sequences) { - TF_ASSIGN_OR_RETURN(BuildCommandBufferResult result, - BuildCommandBuffer(seq)); + HloComputation* computation = + parent->parent()->AddComputationAndUnifyNamesAndIds( + std::move(command_buffer.computation), + /*is_entry=*/false); + + HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall( + cmd_buffer_result_shape, command_buffer.arguments, computation)); + + // Replace all users or original results with a command buffer results. + if (has_single_result) { + TF_RETURN_IF_ERROR(command_buffer.results[0]->ReplaceAllUsesWith(call)); + } else { + for (int i = 0; i < command_buffer.results.size(); i++) { + TF_RETURN_IF_ERROR( + command_buffer.results[i]->ReplaceAllUsesWith(parent->AddInstruction( + HloInstruction::CreateGetTupleElement(call, i)))); + } + } - Shape shape; - shape.set_element_type(TUPLE); - shape.mutable_tuple_shapes()->resize(result.inst_to_tuple_index_map.size()); - for (const auto [inst, index] : result.inst_to_tuple_index_map) { - shape.mutable_tuple_shapes()->at(index) = inst->shape(); + // As we are running after scheduling we have to keep it valid. + HloSchedule& schedule = parent->parent()->schedule(); + + // Update schedule to replace the last instruction with a command buffer call. + // Removal of the rest of the instructions in the sequence is handled by + // schedule update below. + HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent); + sequence.replace_instruction(seq.instructions().back(), call); + + // Rebuild original instruction sequence schedule in a newly created + // command buffer computation to guarantee that we'll get exactly the same + // buffer assignment result as if we were running without command buffers. + HloInstructionSequence cmd_buffer_schedule; + for (auto* argument : command_buffer.arguments) { + cmd_buffer_schedule.push_back(command_buffer.inst_mapping[argument]); + } + for (auto* inst : seq.instructions()) { + cmd_buffer_schedule.push_back(command_buffer.inst_mapping[inst]); + } + if (!has_single_result) { + cmd_buffer_schedule.push_back(computation->root_instruction()); + } + schedule.set_sequence(computation, cmd_buffer_schedule); + + // Forward control dependencies between original instructions to instruction + // in the command buffer computation. + auto& inst_mapping = command_buffer.inst_mapping; + for (HloInstruction* inst : seq.instructions()) { + HloInstruction* cmd_inst = inst_mapping[inst]; + + // Forward control dependencies to the new instruction inside command + // buffer. If the dependent instruction is not captured by the command + // buffer, forward the dependency to the command buffer call instead. + for (HloInstruction* predecessor : inst->control_predecessors()) { + if (auto it = inst_mapping.find(predecessor); it != inst_mapping.end()) { + // If predecessor mapped to a parameter instruction it means that we + // need to forward control dependency to a call operation, otherwise + // we add control dependency between commands in the command buffer. + HloInstruction* cmd_predecessor = it->second; + if (IsParameter(cmd_predecessor)) { + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call)); + } else { + TF_RETURN_IF_ERROR(cmd_predecessor->AddControlDependencyTo(cmd_inst)); + } + } else { + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(call)); + } } - std::vector operands(result.parameters_map.size()); - for (const auto [inst, parameter] : result.parameters_map) { - operands[parameter->parameter_number()] = inst; + for (HloInstruction* successor : inst->control_successors()) { + if (auto it = inst_mapping.find(successor); it != inst_mapping.end()) { + HloInstruction* cmd_successor = it->second; + TF_RETURN_IF_ERROR(cmd_inst->AddControlDependencyTo(cmd_successor)); + } else { + TF_RETURN_IF_ERROR(call->AddControlDependencyTo(successor)); + } } - HloComputation* command_buffer = - module->AddComputationAndUnifyNamesAndIds(std::move(result.computation), - /*is_entry=*/false); - HloInstruction* call_command_buffer = entry->AddInstruction( - HloInstruction::CreateCall(shape, operands, command_buffer)); + TF_RETURN_IF_ERROR(inst->DropAllControlDeps()); + } - std::vector results(result.inst_to_tuple_index_map.size()); - for (int i = 0; i < result.inst_to_tuple_index_map.size(); i++) { - results[i] = entry->AddInstruction( - HloInstruction::CreateGetTupleElement(call_command_buffer, i)); - } + // Traverse in reverse order as original sequence was topologically sorted and + // we can't remove instructions with users. + for (int32_t i = seq.instructions().size() - 1; i >= 0; i--) { + TF_RETURN_IF_ERROR(parent->RemoveInstruction(seq.instructions()[i])); + } - // Remove instructions in the command buffer sequence. - bool first_inst = true; - for (HloInstruction* inst : seq.instructions()) { - // Replace the first instruction in the sequence by command buffer call. - // Removal of the rest of the instructions in the sequence is handled by - // HloSchedule::Update(). - if (first_inst) { - first_inst = false; - HloInstructionSequence& sequence = - module->schedule().GetOrCreateSequence(entry); - sequence.replace_instruction(inst, call_command_buffer); - } + return computation; +} - // Forward control dependencies to the new instruction inside command - // buffer. If the dependent instruction is not captured by the command - // buffer, forward the dependency to the command buffer call instead. - HloInstruction* new_inst = result.instructions_map[inst]; - for (HloInstruction* predecessor : inst->control_predecessors()) { - if (auto it = result.instructions_map.find(predecessor); - it != result.instructions_map.end()) { - HloInstruction* new_predecessor = it->second; - TF_RETURN_IF_ERROR(new_predecessor->AddControlDependencyTo(new_inst)); - } else { - TF_RETURN_IF_ERROR( - predecessor->AddControlDependencyTo(call_command_buffer)); - } - } - for (HloInstruction* successor : inst->control_successors()) { - if (auto it = result.instructions_map.find(successor); - it != result.instructions_map.end()) { - HloInstruction* new_successor = it->second; - TF_RETURN_IF_ERROR(new_inst->AddControlDependencyTo(new_successor)); - } else { - TF_RETURN_IF_ERROR( - call_command_buffer->AddControlDependencyTo(successor)); - } +//===----------------------------------------------------------------------===// + +CommandBufferScheduling::CommandBufferScheduling( + const se::DeviceDescription& device_description, + int32_t gpu_toolkit_version, int32_t gpu_driver_version) + : device_description_(device_description), + gpu_toolkit_version_(gpu_toolkit_version), + gpu_driver_version_(gpu_driver_version) {} + +absl::StatusOr CommandBufferScheduling::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // We run command buffer scheduling after a regular scheduling to guarantee + // that command buffers will not change execution order and buffer assignment + // compared to a regular execution. Some operations (i.e. async collectives) + // can't be captured into command buffers, and forming too large command + // buffers too early can impact async operations scheduling. + if (!module->has_schedule()) return Internal("module is not scheduled"); + + const DebugOptions& debug_options = module->config().debug_options(); + + absl::flat_hash_set commands; + for (auto cmd_type : debug_options.xla_gpu_enable_command_buffer()) { + commands.insert(static_cast(cmd_type)); + } + CommandBufferConfig config{std::move(commands), device_description_}; + + // Erase command buffer cmd types that are not supported by the gpu runtime. + static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS}; + static constexpr auto kRequireTracing = { + DebugOptions::CUBLAS, DebugOptions::CUDNN, DebugOptions::CUSTOM_CALL}; + + auto erase = [&](absl::Span cmds) { + for (auto cmd : cmds) { + if (config.enabled_commands.erase(cmd)) { + VLOG(1) << "Removed command buffer support for " + << DebugOptions::CommandBufferCmdType_Name(cmd) + << " as it's not supported with gpu toolkit version " + << gpu_toolkit_version_ << " and driver version " + << gpu_driver_version_ + << ". This might negatively impact peformance. To enable " + << DebugOptions::CommandBufferCmdType_Name(cmd) + << " support in command buffers use cuda-compat package: " +#if defined(PLATFORM_GOOGLE) + << "set CUDA_COMPAT_LOAD=1 env variable."; +#else + << "https://docs.nvidia.com/deploy/cuda-compatibility/."; +#endif } - TF_RETURN_IF_ERROR(inst->DropAllControlDeps()); + } + }; - int64_t tuple_index = result.inst_to_tuple_index_map[inst]; - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(results[tuple_index])); - TF_RETURN_IF_ERROR(entry->RemoveInstruction(inst)); + // Check if CUDA/ROCM driver supports required features. + auto erase_cuda = [&](const se::CudaComputeCapability& cuda_comp) { + if (std::min(gpu_toolkit_version_, gpu_driver_version_) < 12030) { + erase(kRequireTracing); // cuStreamBeginCaptureToGraph + erase(kRequireConditionals); // on-device control flow } - } + }; + auto erase_rocm = [&](const se::RocmComputeCapability& rocm_comp) { + erase(kRequireConditionals); // on-device control flow + }; + std::visit(VariantVisitor{erase_cuda, erase_rocm}, + device_description_.gpu_compute_capability()); + + auto order = module->MakeComputationPostOrder(); + std::reverse(order.begin(), order.end()); + absl::flat_hash_set processed_command_buffers; + + for (HloComputation* comp : order) { + // Skip special computations that do not have lowering to thunks. + if (comp->IsFusionComputation() || comp->IsAsyncComputation() || + comp->IsCustomCallComputation()) + continue; + + // Skip computations that already part of command buffers. + if (processed_command_buffers.contains(comp)) continue; + + TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + + std::vector sequences = + CollectCommandBufferSequences( + module->schedule().sequence(comp), config, + debug_options.xla_gpu_graph_min_graph_size()); + + for (const HloInstructionSequence& seq : sequences) { + TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer, + PrepareCommandBuffer(seq)); + TF_ASSIGN_OR_RETURN( + HloComputation * command_buffer_computation, + RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + + // All computations reachable from a command buffer computation are nested + // command buffers (i.e. body computations attached to a while operation). + for (HloComputation* called : + command_buffer_computation->MakeEmbeddedComputationsList()) { + processed_command_buffers.insert(called); + } + } + } TF_RETURN_IF_ERROR(module->schedule().Update()); + return true; } diff --git a/xla/service/gpu/command_buffer_scheduling.h b/xla/service/gpu/command_buffer_scheduling.h index ad0844a207c3c..79855f307d600 100644 --- a/xla/service/gpu/command_buffer_scheduling.h +++ b/xla/service/gpu/command_buffer_scheduling.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ limitations under the License. #define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ #include -#include #include #include @@ -25,11 +24,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -70,45 +69,72 @@ namespace xla::gpu { // custom call to a first class operation later. class CommandBufferScheduling : public HloModulePass { public: + struct CommandBufferConfig { + // DebugOptions control which commands are enabled. Long term we want to + // remove that flag and enable all supported commands by default. + absl::flat_hash_set enabled_commands; + const se::DeviceDescription& device_description; + }; + + CommandBufferScheduling(const se::DeviceDescription& device_description, + int32_t gpu_toolkit_version, + int32_t gpu_driver_version); + absl::string_view name() const override { return "command-buffer-scheduling"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; static std::vector CollectCommandBufferSequences( - HloInstructionSequence inst_sequence, - std::function is_command); - static void MoveParametersToFront(HloComputation* computation); + HloInstructionSequence schedule, const CommandBufferConfig& config, + int32_t min_num_commands = 1); + + // Moves kParameter and kConstant instructions in a computation to + // the beginning of the computation. This simplifies the construction of + // command buffer computations because we don't need to deal with parameters + // and constants that have users outside of a command buffer. + static absl::Status MoveParametersAndConstantsToFront( + HloComputation* computation); + + struct CommandBuffer { + // Command buffer arguments (call instruction arguments). + std::vector arguments; + + // Command buffer result (call instruction result tuple). + std::vector results; - struct BuildCommandBufferResult { + // Hlo computation corresponding to a command buffer body. std::unique_ptr computation; - // Maps external instructions used by the command buffer to a parameter - // of the command buffer computation. The command buffer uses parameters - // to access the results of external instructions. - absl::flat_hash_map - parameters_map; - - // We move some instructions to the command buffer computation and return - // the results back to the original computation by tuple. This field maps - // the original instruction to the tuple index of the result that replaces - // the original instruction. - absl::flat_hash_map inst_to_tuple_index_map; - - // Map original instructions to their clones in the command buffer - // computation. - absl::flat_hash_map instructions_map; + // Mapping from original instruction to their clones in the command buffer. + absl::flat_hash_map inst_mapping; }; - // Builds a computation from the instruction sequence. Used values constructed - // by instructions outside of the sequence are passed in as parameters. - // Results of instructions in the sequence are returned in a tuple. - static StatusOr BuildCommandBuffer( - HloInstructionSequence seq); + // Prepares a command buffer from the instruction sequence. Used values + // constructed by instructions outside of the sequence are passed in as + // parameters. Results of instructions in the sequence are returned in a tuple + // (if command buffer has a single result we don't wrap it into tuple). + static absl::StatusOr PrepareCommandBuffer( + const HloInstructionSequence& seq); + + // Rewrites prepared command buffer computation into Hlo operations in the + // parent computation (calls command buffer and replaced all users). + static absl::StatusOr RewriteCommandBuffer( + HloComputation* parent, const HloInstructionSequence& seq, + CommandBuffer command_buffer); + + private: + se::DeviceDescription device_description_; + // For NVIDIA gpus XLA can be compiled with a CUDA version that is larger than + // the version supported by the driver, e.g. we can compile for CUDA 12.3 but + // have 12.1 driver installed. When deciding what command buffer features we + // can use we have to consider both versions. + int32_t gpu_toolkit_version_; + int32_t gpu_driver_version_; }; } // namespace xla::gpu diff --git a/xla/service/gpu/command_buffer_scheduling_test.cc b/xla/service/gpu/command_buffer_scheduling_test.cc index aa63b7e40d2c2..56d5860ce09b5 100644 --- a/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/command_buffer_scheduling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,27 +16,47 @@ limitations under the License. #include #include +#include #include #include #include -#include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/hlo_parser.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - namespace { -class CommandBufferSchedulingTest : public HloTestBase {}; +class CommandBufferSchedulingTest : public HloTestBase { + public: + // Use CUDA 12.3 version for testing as it has all the features we rely on. + static constexpr int32_t kCudaVersion = 12030; + + const se::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } + + DebugOptions GetDebugOptionsForTest() override { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); + debug_options.set_xla_gpu_graph_min_graph_size(2); + return debug_options; + } +}; + +using CommandBuffer = CommandBufferScheduling::CommandBuffer; TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) { const char* hlo = R"( @@ -63,11 +83,11 @@ TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) { })"; const char* expected = R"( -// CHECK: %command_buffer (param: s32[], param.1: s32[]) -> (s32[], s32[]) { -// CHECK: %param = s32[] parameter(0) -// CHECK: %param.1 = s32[] parameter(1) -// CHECK: %fusion.2 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation -// CHECK: %fusion.3 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation.1 +// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) { +// CHECK: %[[P0]] = s32[] parameter(0) +// CHECK: %[[P1]] = s32[] parameter(1) +// CHECK: %fusion.2 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation +// CHECK: %fusion.3 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 // CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion.2, %fusion.3) // CHECK: } // @@ -80,11 +100,12 @@ TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) { // CHECK: ROOT %custom-call = s32[] custom-call(%get-tuple-element, %get-tuple-element.1), custom_call_target="some target" // CHECK: })"; - RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(), expected, - [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) { @@ -130,45 +151,305 @@ TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) { })"; const char* expected = R"( -// CHECK: %command_buffer (param: s32[], param.1: s32[], param.2: (s32[], s32[])) -> (s32[], s32[], s32[]) { -// CHECK: %param = s32[] parameter(0) -// CHECK: %param.1 = s32[] parameter(1) -// CHECK: %param.2 = (s32[], s32[]) parameter(2) -// CHECK: %fusion.4 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation -// CHECK: %get-tuple-element = s32[] get-tuple-element(%param.2), index=0 -// CHECK: %fusion.5 = s32[] fusion(%fusion.4, %get-tuple-element), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT %tuple = (s32[], s32[], s32[]) tuple(%fusion.4, %get-tuple-element, %fusion.5) +// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[], [[P2:.+]]: (s32[], s32[])) -> s32[] { +// CHECK: %[[P0]] = s32[] parameter(0) +// CHECK: %[[P1]] = s32[] parameter(1) +// CHECK: %[[P2]] = (s32[], s32[]) parameter(2) +// CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation +// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%[[P2]]), index=0 +// CHECK: ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1 // CHECK: } -// CHECK: %command_buffer.1 (param.3: s32[], param.4: s32[]) -> (s32[], s32[]) { -// CHECK: %param.3 = s32[] parameter(0) -// CHECK: %param.4 = s32[] parameter(1) -// CHECK: %fusion.6 = s32[] fusion(%param.3, %param.4), kind=kLoop, calls=%fused_computation.2 -// CHECK: %fusion.7 = s32[] fusion(%param.3, %fusion.6), kind=kLoop, calls=%fused_computation.3 -// CHECK: ROOT %tuple.1 = (s32[], s32[]) tuple(%fusion.6, %fusion.7) +// CHECK: %command_buffer.1 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { +// CHECK: %[[P0]] = s32[] parameter(0) +// CHECK: %[[P1]] = s32[] parameter(1) +// CHECK: %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2 +// CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[F2]]), kind=kLoop, calls=%fused_computation.3 // CHECK: } // CHECK: ENTRY %main (a: s32[], b: s32[], c: (s32[], s32[])) -> s32[] { // CHECK: %a = s32[] parameter(0) // CHECK: %b = s32[] parameter(1) // CHECK: %c = (s32[], s32[]) parameter(2) -// CHECK: %call = (s32[], s32[], s32[]) call(%a, %b, %c), to_apply=%command_buffer -// CHECK: %get-tuple-element.1 = s32[] get-tuple-element(%call), index=0 -// CHECK: %get-tuple-element.2 = s32[] get-tuple-element(%call), index=1 -// CHECK: %get-tuple-element.3 = s32[] get-tuple-element(%call), index=2 +// CHECK: %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer // CHECK: %e = s32[] get-tuple-element(%c), index=1 -// CHECK: %custom-call = s32[] custom-call(%get-tuple-element.3, %e), custom_call_target="some target" -// CHECK: %call.1 = (s32[], s32[]) call(%custom-call, %a), to_apply=%command_buffer.1 -// CHECK: %get-tuple-element.4 = s32[] get-tuple-element(%call.1), index=0 -// CHECK: %get-tuple-element.5 = s32[] get-tuple-element(%call.1), index=1 -// CHECK: ROOT %custom-call.1 = s32[] custom-call(%get-tuple-element.5), custom_call_target="some target" +// CHECK: %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target" +// CHECK: %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.1 +// CHECK: ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target" // CHECK: })"; - RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(), expected, - [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByDone) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %add (p0: s32[4], p1: s32[4]) -> s32[4] { + %p0 = s32[4] parameter(0) + %p1 = s32[4] parameter(1) + ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1) + } + + ENTRY %main (a: s32[4]) -> s32[4] { + %a = s32[4] parameter(0) + %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] { + CHECK: %[[P0]] = s32[4]{0} parameter(0) + CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]]) + CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]]) + CHECK: } + + CHECK: ENTRY %main (a: s32[4]) -> s32[4] { + CHECK: %[[A:.+]] = s32[4]{0} parameter(0) + CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]), + CHECK: to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, AllGatherStartFollowedByDone) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + ENTRY %main (a: s32[2]) -> s32[4] { + %a = s32[2] parameter(0) + + %start = (s32[2]{0}, s32[4]{0}) all-gather-start(%a), + channel_id=555, replica_groups={{0,1}}, dimensions={0}, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + + ROOT %done = s32[4]{0} all-gather-done(%start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[2]) -> s32[4] { + CHECK: %[[P0]] = s32[2]{0} parameter(0) + CHECK: %[[START:.+]] = {{.*}} all-gather-start(%[[P0]]) + CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-gather-done(%[[START]]) + CHECK: } + + CHECK: ENTRY %main (a: s32[2]) -> s32[4] { + CHECK: %[[A:.+]] = s32[2]{0} parameter(0) + CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]), + CHECK: to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, ReduceScatterStartFollowedByDone) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %add (p0: s32[], p1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } + + ENTRY %main (a: s32[4]) -> s32[2] { + %a = s32[4] parameter(0) + + %start = ((s32[4]{0}), s32[2]{0}) reduce-scatter-start(%a), + channel_id=555, replica_groups={{0,1}}, dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + + ROOT %done = s32[2]{0} reduce-scatter-done(%start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[2] { + CHECK: %[[P0]] = s32[4]{0} parameter(0) + CHECK: %[[START:.+]] = {{.*}} reduce-scatter-start(%[[P0]]) + CHECK: ROOT %[[DONE:.+]] = s32[2]{0} reduce-scatter-done(%[[START]]) + CHECK: } + + CHECK: ENTRY %main (a: s32[4]) -> s32[2] { + CHECK: %[[A:.+]] = s32[4]{0} parameter(0) + CHECK: ROOT %[[CALL:.+]] = s32[2]{0} call(%[[A]]), + CHECK: to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedByBitcast) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %add (p0: s32[4], p1: s32[4]) -> s32[4] { + %p0 = s32[4] parameter(0) + %p1 = s32[4] parameter(1) + ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1) + } + + ENTRY %main (a: s32[4]) -> s32[4] { + %a = s32[4] parameter(0) + %start = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + %bitcast = s32[4] bitcast(s32[4]{0} %a) + ROOT %done = s32[4]{0} all-reduce-done(s32[4]{0} %start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] { + CHECK: %[[P0]] = s32[4]{0} parameter(0) + CHECK: %[[START:.+]] = s32[4]{0} all-reduce-start(%[[P0]]) + CHECK: %[[BITCAST:.+]] = s32[4]{0} bitcast(%[[P0]]) + CHECK: ROOT %[[DONE:.+]] = s32[4]{0} all-reduce-done(%[[START]]) + CHECK: } + + CHECK: ENTRY %main (a: s32[4]) -> s32[4] { + CHECK: %[[A:.+]] = s32[4]{0} parameter(0) + CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]), + CHECK: to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, AllReduceStartFollowedAllReduceStart) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %add (p0: s32[4], p1: s32[4]) -> s32[4] { + %p0 = s32[4] parameter(0) + %p1 = s32[4] parameter(1) + ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1) + } + + ENTRY %main (a: s32[4]) -> s32[4] { + %a = s32[4] parameter(0) + %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) + ROOT %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[4]) -> s32[4] { + CHECK: %[[P0]] = s32[4]{0} parameter(0) + CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[P0]]) + CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[P0]]) + CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]]) + CHECK: ROOT %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]]) + CHECK: } + + CHECK: ENTRY %main (a: s32[4]) -> s32[4] { + CHECK: %[[A:.+]] = s32[4]{0} parameter(0) + CHECK: ROOT %[[CALL:.+]] = s32[4]{0} call(%[[A]]), + CHECK: to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %fused_computation(param_0: s32[], param_1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } + + %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } + + %add (p0: s32[4], p1: s32[4]) -> s32[4] { + %p0 = s32[4] parameter(0) + %p1 = s32[4] parameter(1) + ROOT %add = s32[4] add(s32[4] %p0, s32[4] %p1) + } + + ENTRY %main (a: s32[4], b:s32[]) -> s32[] { + %a = s32[4] parameter(0) + %b = s32[] parameter(1) + %start1 = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + %c = s32[] custom-call(), custom_call_target="target" + %start2 = s32[4]{0} all-reduce-start(s32[4]{0} %a), + replica_groups={{0,1}}, to_apply=%add, + backend_config={"collective_backend_config": {"is_sync":true,"no_parallel_custom_call":false}} + %done1 = s32[4]{0} all-reduce-done(s32[4]{0} %start1) + %done2 = s32[4]{0} all-reduce-done(s32[4]{0} %start2) + %fusion = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation + ROOT %fusion.1 = s32[] fusion(s32[] %b, s32[] %c), kind=kLoop, calls=%fused_computation.1 + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { + CHECK: %[[P0]] = s32[] parameter(0) + CHECK: %[[P1]] = s32[] parameter(1) + CHECK: %fusion.2 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation + CHECK: ROOT %fusion.3 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 + CHECK: } + + CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] { + CHECK: %[[A:.+]] = s32[4]{0} parameter(0) + CHECK: %[[B:.+]] = s32[] parameter(1) + CHECK: %[[START1:.+]] = s32[4]{0} all-reduce-start(%[[A]]) + CHECK: %[[C:.+]] = s32[] custom-call() + CHECK: %[[START2:.+]] = s32[4]{0} all-reduce-start(%[[A]]) + CHECK: %[[DONE1:.+]] = s32[4]{0} all-reduce-done(%[[START1]]) + CHECK: %[[DONE2:.+]] = s32[4]{0} all-reduce-done(%[[START2]]) + CHECK: %call = s32[] call(%b, %c), to_apply=%command_buffer + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { @@ -221,11 +502,11 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { } EXPECT_EQ(seq.size(), 10); + CommandBufferScheduling::CommandBufferConfig config{{DebugOptions::FUSION}, + device_desc()}; + std::vector command_buffer_sequences = - CommandBufferScheduling::CollectCommandBufferSequences( - seq, [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kFusion; - }); + CommandBufferScheduling::CollectCommandBufferSequences(seq, config); EXPECT_EQ(command_buffer_sequences.size(), 2); std::vector seq_0 = @@ -277,7 +558,8 @@ TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - CommandBufferScheduling::MoveParametersToFront(module->entry_computation()); + TF_ASSERT_OK(CommandBufferScheduling::MoveParametersAndConstantsToFront( + module->entry_computation())); TF_ASSERT_OK_AND_ASSIGN( bool filecheck_matches, RunFileCheck( @@ -286,7 +568,7 @@ TEST_F(CommandBufferSchedulingTest, MoveParametersToFront) { EXPECT_TRUE(filecheck_matches); } -TEST_F(CommandBufferSchedulingTest, BuildComputation) { +TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) { const char* hlo = R"( HloModule TestModule, is_scheduled=true @@ -325,20 +607,19 @@ TEST_F(CommandBufferSchedulingTest, BuildComputation) { instructions.push_back(inst); } - TF_ASSERT_OK_AND_ASSIGN( - CommandBufferScheduling::BuildCommandBufferResult result, - CommandBufferScheduling::BuildCommandBuffer(seq)); + TF_ASSERT_OK_AND_ASSIGN(CommandBuffer command_buffer, + CommandBufferScheduling::PrepareCommandBuffer(seq)); HloComputation* computation = module->AddComputationAndUnifyNamesAndIds( - std::move(result.computation), false); + std::move(command_buffer.computation), false); const char* expected = R"( -// CHECK: %command_buffer (param: s32[], param.1: s32[]) -> ((s32[], s32[]), s32[], s32[]) { -// CHECK: %param = s32[] parameter(0) -// CHECK: %param.1 = s32[] parameter(1) -// CHECK: %fusion.2 = (s32[], s32[]) fusion(%param, %param.1), kind=kLoop, calls=%fused_computation -// CHECK: %get-tuple-element = s32[] get-tuple-element(%fusion.2), index=0 -// CHECK: %fusion.3 = s32[] fusion(%param, %get-tuple-element), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT %tuple.1 = ((s32[], s32[]), s32[], s32[]) tuple(%fusion.2, %get-tuple-element, %fusion.3) +// CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) { +// CHECK: %[[P0]] = s32[] parameter(0) +// CHECK: %[[P1]] = s32[] parameter(1) +// CHECK: %fusion.2 = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation +// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%fusion.2), index=0 +// CHECK: %fusion.3 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1 +// CHECK: ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.3) // CHECK:})"; TF_ASSERT_OK_AND_ASSIGN( @@ -348,77 +629,385 @@ TEST_F(CommandBufferSchedulingTest, BuildComputation) { expected)); EXPECT_TRUE(filecheck_matches); - absl::flat_hash_map& - parameters_map = result.parameters_map; - EXPECT_EQ(parameters_map[instructions[0]]->parameter_number(), 0); - EXPECT_EQ(parameters_map[instructions[1]]->parameter_number(), 1); + auto& arguments = command_buffer.arguments; + ASSERT_EQ(arguments.size(), 2); + EXPECT_EQ(arguments[0], instructions[0]); + EXPECT_EQ(arguments[1], instructions[1]); - absl::flat_hash_map& inst_to_tuple_index_map = - result.inst_to_tuple_index_map; - EXPECT_EQ(inst_to_tuple_index_map[instructions[2]], 0); - EXPECT_EQ(inst_to_tuple_index_map[instructions[3]], 1); - EXPECT_EQ(inst_to_tuple_index_map[instructions[4]], 2); + auto& results = command_buffer.results; + ASSERT_EQ(results.size(), 2); + EXPECT_EQ(results[0], instructions[3]); + EXPECT_EQ(results[1], instructions[4]); } -TEST_F(CommandBufferSchedulingTest, RelayControlDependencies) { +TEST_F(CommandBufferSchedulingTest, ForwardControlDependencies) { const char* hlo = R"( - HloModule TestModule, is_scheduled=true + HloModule TestModule, is_scheduled=true - %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] { - %p0 = s32[] parameter(0) - %p1 = s32[] parameter(1) - ROOT %add = s32[] add(s32[] %p0, s32[] %p1) - } + %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } - %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] { - %p0 = s32[] parameter(0) - %p1 = s32[] parameter(1) - ROOT %add = s32[] add(s32[] %p0, s32[] %p1) - } + %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } - %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] { - %p0 = s32[] parameter(0) - %p1 = s32[] parameter(1) - ROOT %add = s32[] add(s32[] %p0, s32[] %p1) - } + %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } - ENTRY %main (a: s32[], b: s32[]) -> s32[] { - %a = s32[] parameter(0) - %b = s32[] parameter(1) - %custom-call = s32[] custom-call(), custom_call_target="some target" - %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call} - %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion} - %custom-call.1 = s32[] custom-call(), custom_call_target="some target" - %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1} - ROOT %custom-call.2 = s32[] custom-call(), custom_call_target="some target" - })"; + ENTRY %main (a: s32[], b: s32[]) -> s32[] { + %a = s32[] parameter(0) + %b = s32[] parameter(1) + %custom-call = s32[] custom-call(), custom_call_target="some target" + %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call} + %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion} + %custom-call.1 = s32[] custom-call(), custom_call_target="some target" + %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1} + ROOT %custom-call.2 = s32[] custom-call(s32[] %fusion.1, s32[] %fusion.2), custom_call_target="some target" + })"; const char* expected = R"( -// CHECK: %command_buffer (param: s32[], param.1: s32[]) -> (s32[], s32[]) { -// CHECK: %param = s32[] parameter(0) -// CHECK: %param.1 = s32[] parameter(1) -// CHECK: %fusion.3 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation -// CHECK: %fusion.4 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion.3} -// CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion.3, %fusion.4) -// CHECK: } -// -// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] { -// CHECK: %a = s32[] parameter(0) -// CHECK: %b = s32[] parameter(1) -// CHECK: %custom-call = s32[] custom-call(), custom_call_target="some target" -// CHECK: %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call} -// CHECK: %get-tuple-element = s32[] get-tuple-element(%call), index=0 -// CHECK: %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1 -// CHECK: %custom-call.1 = s32[] custom-call(), custom_call_target="some target" -// CHECK: %fusion.2 = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call} -// CHECK: ROOT %custom-call.2 = s32[] custom-call(), custom_call_target="some target" -// CHECK: })"; + CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { + CHECK: %[[P0]] = s32[] parameter(0) + CHECK: %[[P1]] = s32[] parameter(1) + CHECK: %[[F0:.+]] = s32[] fusion(%[[P0]], %[[P1]]) + CHECK: ROOT {{.*}} = s32[] fusion(%[[P0]], %[[P1]]), {{.*}} control-predecessors={%[[F0]]} + CHECK: } + + CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] { + CHECK: %a = s32[] parameter(0) + CHECK: %b = s32[] parameter(1) + CHECK: %custom-call = s32[] custom-call(), custom_call_target="some target" + CHECK: %call = s32[] call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call} + CHECK: %custom-call.1 = s32[] custom-call(), custom_call_target="some target" + CHECK: %[[F3:.+]] = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call} + CHECK: ROOT %custom-call.2 = s32[] custom-call(%call, %[[F3]]), custom_call_target="some target" + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, ForwardControlDependenciesToParams) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %fused_computation.0 (p0: s32[], p1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } + + %fused_computation.1 (p0: s32[], p1: s32[]) -> s32[] { + %p0 = s32[] parameter(0) + %p1 = s32[] parameter(1) + ROOT %add = s32[] add(s32[] %p0, s32[] %p1) + } + + ENTRY %main (a: s32[], b: s32[]) -> s32[] { + %a = s32[] parameter(0) + %b = s32[] parameter(1) + %custom-call = s32[] custom-call(), custom_call_target="some target" + %fusion = s32[] fusion(s32[] %custom-call, s32[] %a), kind=kLoop, calls=%fused_computation.0, control-predecessors={%custom-call} + ROOT %fusion.1 = s32[] fusion(s32[] %fusion, s32[] %b), kind=kLoop, calls=%fused_computation.1 + })"; + + const char* expected = R"( + CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] { + CHECK: %a = s32[] parameter(0) + CHECK: %b = s32[] parameter(1) + CHECK: %[[CUSTOM_CALL:.+]] = s32[] custom-call(), custom_call_target="some target" + CHECK: ROOT {{.*}} call(%[[CUSTOM_CALL]], %a, %b), to_apply=%command_buffer, control-predecessors={%[[CUSTOM_CALL]]} + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, WhileNotCommand) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %fused_computation (param_0: f32[1]) -> f32[1] { + %param_0 = f32[1]{0} parameter(0) + ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0) + } + + %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] { + %param_0.1 = f32[1]{0} parameter(0) + %param_1 = f32[1]{0} parameter(1) + ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1) + } + + %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] { + %param_0.2 = f32[1]{0} parameter(0) + %param_1.1 = f32[1]{0} parameter(1) + ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT + } + + %fused_computation.3 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] { + %param_0.1 = f32[1]{0} parameter(0) + %param_1 = f32[1]{0} parameter(1) + ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1) + } + + %body (Arg_.3: f32[1]) -> f32[1] { + %constant_4 = f32[1]{0} constant({1}) + %Arg_.3 = f32[1]{0} parameter(0) + %custom-call = s32[] custom-call(), custom_call_target="some target" + %add = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1, control-predecessors={%custom-call} + ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %add, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.3, control-predecessors={%custom-call} + } + + %cond (Arg_.11: f32[1]) -> pred[] { + %constant = f32[1]{0} constant({100}) + %Arg_.11 = f32[1]{0} parameter(0) + %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2 + ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2) + } + + ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] { + %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated} + %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation + %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body + ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: f32[1], [[P1:.+]]: f32[1]) -> f32[1] { + CHECK: %[[P0]] = f32[1]{0} parameter(0) + CHECK: %[[P1]] = f32[1]{0} parameter(1) + CHECK: %[[ADD:.*]] = f32[1]{0} fusion(%[[P0]], %[[P1]]), kind=kLoop + CHECK: ROOT {{.*}} = f32[1]{0} fusion(%[[ADD]], %[[P1]]), kind=kLoop + CHECK: } + + CHECK: %[[BODY:[a-z_0-9.]+]] ([[P0:.+]]: f32[1]) -> f32[1] { + CHECK: %[[C1:.*]] = f32[1]{0} constant({1}) + CHECK: %[[P0]] = f32[1]{0} parameter(0) + CHECK: %[[CC:.*]] = s32[] custom-call(), custom_call_target="some target" + CHECK: ROOT %call = f32[1]{0} call(%[[P0]], %[[C1]]), to_apply=%command_buffer, control-predecessors={%[[CC]]} + CHECK: } + + CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] { + CHECK: %[[ARG0]] = f32[1]{0} parameter(0) + CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[ARG0]]), kind=kLoop + CHECK: %[[WHILE:.*]] = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY]] + CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%[[WHILE]]) + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, While) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %fused_computation (param_0: f32[1]) -> f32[1] { + %param_0 = f32[1]{0} parameter(0) + ROOT %copy.5 = f32[1]{0} copy(f32[1]{0} %param_0) + } + + %fused_computation.1 (param_0.1: f32[1], param_1: f32[1]) -> f32[1] { + %param_0.1 = f32[1]{0} parameter(0) + %param_1 = f32[1]{0} parameter(1) + ROOT %add.2 = f32[1]{0} add(f32[1]{0} %param_0.1, f32[1]{0} %param_1) + } + + %fused_computation.2 (param_0.2: f32[1], param_1.1: f32[1]) -> pred[1] { + %param_0.2 = f32[1]{0} parameter(0) + %param_1.1 = f32[1]{0} parameter(1) + ROOT %compare.3 = pred[1]{0} compare(f32[1]{0} %param_0.2, f32[1]{0} %param_1.1), direction=LT + } + + %body (Arg_.3: f32[1]) -> f32[1] { + %constant_4 = f32[1]{0} constant({1}) + %Arg_.3 = f32[1]{0} parameter(0) + ROOT %wrapped_add.1 = f32[1]{0} fusion(f32[1]{0} %Arg_.3, f32[1]{0} %constant_4), kind=kLoop, calls=%fused_computation.1 + } + + %cond (Arg_.11: f32[1]) -> pred[] { + %constant = f32[1]{0} constant({100}) + %Arg_.11 = f32[1]{0} parameter(0) + %wrapped_compare.2 = pred[1]{0} fusion(f32[1]{0} %Arg_.11, f32[1]{0} %constant), kind=kLoop, calls=%fused_computation.2 + ROOT %bitcast = pred[] bitcast(pred[1]{0} %wrapped_compare.2) + } + + ENTRY %main.18 (Arg_0.1: f32[1]) -> f32[] { + %Arg_0.1 = f32[1]{0} parameter(0), sharding={replicated} + %wrapped_copy.4 = f32[1]{0} fusion(f32[1]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation + %while.16 = f32[1]{0} while(f32[1]{0} %wrapped_copy.4), condition=%cond, body=%body + ROOT %bitcast.1 = f32[] bitcast(f32[1]{0} %while.16) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: f32[1]) -> f32[1] { + CHECK: %[[P0]] = f32[1]{0} parameter(0) + CHECK: %[[COPY:.*]] = f32[1]{0} fusion(%[[P0]]), kind=kLoop + CHECK: ROOT {{.*}} = f32[1]{0} while(%[[COPY]]), condition=%[[COND:[a-z_0-9.]+]], body=%[[BODY:[a-z_0-9.]+]] + CHECK: } + + CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: f32[1]) -> f32[] { + CHECK: %[[ARG0]] = f32[1]{0} parameter(0) + CHECK: %call = f32[1]{0} call(%[[ARG0]]), to_apply=%command_buffer + CHECK: ROOT %[[BC:.+]] = f32[] bitcast(%call) + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, Conditional) { + const char* hlo = R"( + HloModule TestModule, is_scheduled=true + + %fused_computation.1 (param_0.2: s32[5]) -> s32[5] { + %param_0.2 = s32[5]{0} parameter(0) + ROOT %negate.2 = s32[5]{0} negate(s32[5]{0} %param_0.2) + } + + %region_0.7 (Arg_.8: s32[5]) -> (s32[5]) { + %Arg_.8 = s32[5]{0} parameter(0) + %wrapped_negate.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.8), kind=kLoop, calls=%fused_computation.1 + ROOT %tuple.3 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_negate.1) + } + + %fused_computation.2 (param_0.3: s32[5]) -> s32[5] { + %param_0.3 = s32[5]{0} parameter(0) + ROOT %not.2 = s32[5]{0} not(s32[5]{0} %param_0.3) + } + + %region_1.10 (Arg_.11: s32[5]) -> (s32[5]) { + %Arg_.11 = s32[5]{0} parameter(0) + %wrapped_not.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.11), kind=kLoop, calls=%fused_computation.2 + ROOT %tuple.4 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_not.1) + } + + %fused_computation.3 (param_0.4: s32[5]) -> s32[5] { + %param_0.4 = s32[5]{0} parameter(0) + ROOT %multiply.2 = s32[5]{0} multiply(s32[5]{0} %param_0.4, s32[5]{0} %param_0.4) + } + + %region_2.13 (Arg_.14: s32[5]) -> (s32[5]) { + %Arg_.14 = s32[5]{0} parameter(0) + %wrapped_multiply.1 = s32[5]{0} fusion(s32[5]{0} %Arg_.14), kind=kLoop, calls=%fused_computation.3 + ROOT %tuple.5 = (s32[5]{0}) tuple(s32[5]{0} %wrapped_multiply.1) + } + + %fused_computation (param_0.1: s64[]) -> s32[] { + %constant_1 = s32[] constant(0) + %param_0.1 = s64[] parameter(0) + %convert.2 = s32[] convert(s64[] %param_0.1) + %constant_0 = s32[] constant(2) + ROOT %clamp.2 = s32[] clamp(s32[] %constant_1, s32[] %convert.2, s32[] %constant_0) + } + + ENTRY %main.17 (Arg_0.1: s64[], Arg_1.2: s32[5]) -> s32[5] { + %Arg_0.1 = s64[] parameter(0), sharding={replicated} + %fusion = s32[] fusion(s64[] %Arg_0.1), kind=kLoop, calls=%fused_computation + %Arg_1.2 = s32[5]{0} parameter(1), sharding={replicated} + %conditional.16.clone = (s32[5]{0}) conditional(s32[] %fusion, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2, s32[5]{0} %Arg_1.2), branch_computations={%region_0.7, %region_1.10, %region_2.13} + ROOT %get-tuple-element = s32[5]{0} get-tuple-element((s32[5]{0}) %conditional.16.clone), index=0 + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P0:.+]]: s64[], [[P1:.+]]: s32[5]) -> (s32[5]) { + CHECK: %[[P0]] = s64[] parameter(0) + CHECK: %[[P1]] = s32[5]{0} parameter(1) + CHECK: %[[FUSION:.*]] = s32[] fusion(%[[P0]]), kind=kLoop + CHECK: ROOT {{.*}} = (s32[5]{0}) conditional(%[[FUSION]], %[[P1]], %[[P1]], %[[P1]]), branch_computations={%[[B1:[a-z_0-9.]+]], %[[B2:[a-z_0-9.]+]], %[[B3:[a-z_0-9.]+]]} + CHECK: } + + CHECK: ENTRY %[[MAIN:.+]] ([[ARG0:.+]]: s64[], [[ARG1:.+]]: s32[5]) -> s32[5] { + CHECK: %[[ARG0]] = s64[] parameter(0) + CHECK: %[[ARG1]] = s32[5]{0} parameter(1) + CHECK: %call = (s32[5]{0}) call(%[[ARG0]], %[[ARG1]]), to_apply=%command_buffer + CHECK: ROOT %[[GEP:.+]] = s32[5]{0} get-tuple-element(%call) + CHECK: })"; + + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +TEST_F(CommandBufferSchedulingTest, CuDnnFusionGraphCaptureWorks) { + const std::string kHloText = R"( +HloModule m, is_scheduled=true + +fusion0 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + ROOT d = f32[64,64] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +fusion1 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + ROOT d = f32[64,64] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +} + +fusion_a { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + ROOT a = f32[64,64] add(p0, p1) +} - RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(), expected, - [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); +ENTRY e { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + d0 = f32[64,64] fusion(p0, p1), kind=kCustom, + calls=fusion0, + backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}} + a = f32[64,64] fusion(d0, d0), kind=kLoop, calls=fusion_a + ROOT d1 = f32[64,64] fusion(a, p1), kind=kCustom, + calls=fusion1, + backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}} +})"; + + const std::string kExpected = R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: ROOT +; CHECK-SAME: call( +; CHECK-SAME: to_apply=%command_buffer +})"; + + RunAndFilecheckHloRewrite( + kHloText, + CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + kExpected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); } } // namespace diff --git a/xla/service/gpu/compile_module_to_llvm_ir.cc b/xla/service/gpu/compile_module_to_llvm_ir.cc index 2a64ef43f41d5..0eed8983af406 100644 --- a/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -27,61 +26,52 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/SplitModule.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir_hlo/transforms/gpu_passes.h" #include "xla/service/buffer_assignment.h" #include "xla/service/buffer_value.h" #include "xla/service/dump.h" -#include "xla/service/gpu/buffer_sharing.h" -#include "xla/service/gpu/conditional_thunk.h" -#include "xla/service/gpu/for_thunk.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/gpu_memory_space_assignment.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/while_thunk.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/hlo_ordering.h" +#include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/status.h" -#include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -namespace xla { -namespace gpu { +namespace xla::gpu { namespace { @@ -91,90 +81,23 @@ static mlir::LogicalResult DiagnosticHandler(mlir::Diagnostic& diag) { return mlir::failure(); } -static bool HasFp8(const HloModule& hlo_module) { - for (const HloComputation* computation : hlo_module.computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - if (ShapeUtil::HasPrimitiveType(instruction->shape(), F8E5M2) || - ShapeUtil::HasPrimitiveType(instruction->shape(), F8E5M2FNUZ) || - ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3FN) || - ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3B11FNUZ) || - ShapeUtil::HasPrimitiveType(instruction->shape(), F8E4M3FNUZ)) { - return true; +// Removes all globals from the given module that are both uninitialized and +// have no uses within that module. +void RemoveUnusedAndUninitializedGlobals( + llvm::Module* llvm_module, + const std::vector& constants) { + for (const auto& info : constants) { + // Empty content means the constant is initialized in the LLVM IR, so we + // must not remove it. + if (!info.content.span().empty()) { + llvm::GlobalVariable* global = + llvm_module->getGlobalVariable(info.symbol_name); + CHECK(global != nullptr); + if (global->use_empty()) { + global->eraseFromParent(); } } } - return false; -} - -class DumpAfterPassIfEnabled : public mlir::PassInstrumentation { - public: - DumpAfterPassIfEnabled(const HloModule* hlo_module, - const mlir::ModuleOp* mlir_module) - : hlo_module_{hlo_module}, mlir_module_{mlir_module} {} - void runAfterPass(mlir::Pass* pass, mlir::Operation* op) override { - std::string pass_name = pass->getName().str(); - bool should_dump_pass = DumpingEnabledForHloPass( - pass_name, hlo_module_->config().debug_options()); - if (!should_dump_pass) return; - std::string module_str = llvm_ir::DumpToString(*mlir_module_); - auto prefix = "lower_to_xla_gpu_runtime"; - auto suffix = - absl::StrCat("pass_", absl::StrFormat("%02d", pass_counter_++), ".", - "after", ".", pass_name, ".mlir"); - DumpToFileInDirOrStdout(*hlo_module_, prefix, suffix, module_str); - } - - private: - const HloModule* hlo_module_; - const mlir::ModuleOp* mlir_module_; - int pass_counter_ = 0; -}; - -// Lowers MLIR module to the XLA Gpu runtime custom calls. -static Status LowerToXlaGpuRuntime( - mlir::ModuleOp module, llvm::StringRef entry_function_name, - llvm::ArrayRef buffer_sizes, ThunkSequence* thunk_sequence, - const HloModule* hlo_module, se::GpuComputeCapability compute_capability) { - if (!module) { - return InternalError("No MLIR module to lower."); - } - - const DebugOptions& debug_options = hlo_module->config().debug_options(); - bool should_verify = debug_options.xla_gpu_llvm_verification_level() >= 1; -#ifndef NDEBUG - should_verify = true; -#endif - - mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit); - pm.enableVerifier(should_verify); - if (hlo_module != nullptr && DumpingEnabledForHloModule(*hlo_module)) { - pm.addInstrumentation( - std::make_unique(hlo_module, &module)); - } - - absl::flat_hash_set command_types; - for (int command_type_num : debug_options.xla_gpu_enable_command_buffer()) { - if (!DebugOptions::CommandBufferCmdType_IsValid(command_type_num)) { - return InternalError("Invalid command buffer command type"); - } - DebugOptions::CommandBufferCmdType command_type = - static_cast(command_type_num); - command_types.insert(command_type); - } - - GpuPipelineOpts opts; - opts.command_types = command_types; - opts.min_graph_size = debug_options.xla_gpu_graph_min_graph_size(); - opts.enable_concurrent_region = - debug_options.xla_gpu_graph_enable_concurrent_region(); - opts.compute_capability = compute_capability; - populateXlaGpuRuntimePasses(pm, thunk_sequence, opts); - - if (pm.run(module).failed()) { - return InternalError("Failed to lower LMHLO to Gpu runtime custom calls."); - } - - return OkStatus(); } } // namespace @@ -188,9 +111,6 @@ void ForAllThunks(const std::function& fn, cond_thunk->branch_thunks()) { ForAllThunks(fn, &branch_thunks->thunks()); } - } else if (thunk->kind() == Thunk::kFor) { - auto* for_thunk = tensorflow::down_cast(thunk.get()); - ForAllThunks(fn, &for_thunk->body_thunk_sequence()->thunks()); } else if (thunk->kind() == Thunk::kSequential) { auto* sequential_thunk = tensorflow::down_cast(thunk.get()); @@ -205,88 +125,7 @@ void ForAllThunks(const std::function& fn, } } -static void ForwardCollectiveAttrs(mlir::ModuleOp module, - llvm::StringRef entry_function_name, - const HloModuleConfig& config) { - mlir::OpBuilder b(module.getContext()); - auto func = module.lookupSymbol(entry_function_name); - func->setAttr("replica_count", b.getI64IntegerAttr(config.replica_count())); - func->setAttr("num_partitions", b.getI64IntegerAttr(config.num_partitions())); -} - -StatusOr LowerToJitRt( - mlir::ModuleOp mlir_module, llvm::StringRef entry_function_name, - llvm::ArrayRef buffer_sizes, - std::unique_ptr thunk_sequence, const HloModule* hlo_module, - se::GpuComputeCapability compute_capability) { - const auto& module_config = hlo_module->config(); - // Forward collective (NCCL) attributes for use by the lowering pipeline. - ForwardCollectiveAttrs(mlir_module, entry_function_name, module_config); - - // Lower LMHLO operations to the XLA:GPU runtime custom calls. - TF_RETURN_IF_ERROR(LowerToXlaGpuRuntime( - mlir_module, {entry_function_name.data(), entry_function_name.size()}, - buffer_sizes, thunk_sequence.get(), hlo_module, compute_capability)); - - // TODO(b/232033540): Pass MLIR module directly to Gpu runtime executable - // without forcing serialization. - std::string module_str = llvm_ir::DumpToString(mlir_module); - - if (hlo_module != nullptr) { - DumpToFileInDirOrStdout(*hlo_module, "gpu_rt_host", "mlir", module_str); - } - - // Collect allocation indices for handling graph capture functions. - auto allocation_indices = GetAllocationIndices(mlir_module); - - return std::make_unique( - entry_function_name.str(), std::move(module_str), buffer_sizes.vec(), - std::move(allocation_indices), module_config.debug_options()); -} - -// Analyze the function signature to reconstruct a vector of BufferAllocation -// objects, as well as other output information. -// -// This function also serves as a half-baked verifier for function arg -// attributes, since a full verifier doesn't exist yet. -static Status GetMlirAllocationInfo( - mlir::func::FuncOp func, std::vector* allocations, - absl::flat_hash_map* output_info, - Shape* output_shape) { - CHECK(allocations->empty()); - allocations->reserve(func.getNumArguments()); - - std::vector buffer_sizes; - for (int i = 0; i < func.getNumArguments(); i++) { - mlir::BlockArgument arg = func.getArgument(i); - - TF_RET_CHECK(arg.getType().isa()); - mlir::ShapedType type = arg.getType().cast(); - TF_ASSIGN_OR_RETURN(auto element_type_bytes, - GetElementTypeBytes(type.getElementType())); - size_t size = type.getNumElements() * element_type_bytes; - buffer_sizes.push_back(size); - } - - for (int i = 0; i < func.getNumArguments(); i++) { - llvm::ArrayRef attrs = - mlir::function_interface_impl::getArgAttrs(func, i); - for (const mlir::NamedAttribute& attr : attrs) { - TF_RET_CHECK(attr.getName() == "lmhlo.params" || - attr.getName() == "lmhlo.param_shape_index" || - attr.getName() == "lmhlo.constant_name" || - attr.getName() == "lmhlo.must_alias" || - attr.getName() == "lmhlo.output_index"); - } - } - - return GpuExecutable::SetUpMlirAllocation(func, buffer_sizes, allocations, - output_info, output_shape); -} - -// The order of `thunk_sequence` corresponds to -// `hlo_schedule->ThunkLaunchOrder()`. -StatusOr CompileModuleToLlvmIr( +absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, const std::string& platform_name, se::Platform::Id platform_id, @@ -307,7 +146,12 @@ StatusOr CompileModuleToLlvmIr( /*color_alignment=*/ [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferAssigner::DefaultColorer(), + /*colorer=*/ + hlo_module->config() + .debug_options() + .xla_gpu_enable_nccl_user_buffers() + ? CollectiveColorer() + : BufferAssigner::DefaultColorer(), /*must_not_live_out=*/{}, can_share_buffer_function)); VLOG(1) << "Buffer Assignment Stats for " << hlo_module->name() << "\n" @@ -330,70 +174,26 @@ StatusOr CompileModuleToLlvmIr( << ": " << hlo_module->GetFingerprint128(); uint64_t start_usecs = tsl::Env::Default()->NowMicros(); - mlir::DialectRegistry registry; - IrEmitterUnnested::GetDependentDialects(registry); + mlir::DialectRegistry registry; // Disable MLIR multi-threading to prevent creating too many threads when // compiling XLA executables concurrently (e.g. during auto-tuning). auto mlir_context = std::make_unique( registry, mlir::MLIRContext::Threading::DISABLED); - mlir_context->getDiagEngine().registerHandler(DiagnosticHandler); - mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( - mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); - - absl::flat_hash_map - operation_map; - - // Store the allocations in the order of the LMHLO buffer arguments. - std::vector ordered_allocations; - TF_RETURN_IF_ERROR(HloToLhloModule(*results.buffer_assignment, *hlo_module, - *mlir_module, &ordered_allocations, - &operation_map)); - - results.module_name = - mlir::mhlo::GetDebugNameFromLocation(mlir_module->getLoc()); - - if (DumpingEnabledForHloModule(*hlo_module)) { - DumpToFileInDirOrStdout(*hlo_module, "lmhlo", mlir_module.get()); - } - - auto entry_function = mlir::cast( - mlir_module->lookupSymbol(hlo_module->entry_computation()->name())); - - bool emit_from_hlo = !IsXlaRuntimeExecutableEnabled(hlo_module->config()); - std::vector mlir_allocations; - absl::flat_hash_map mlir_output_info; - Shape mlir_output_shape; - TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &mlir_allocations, - &mlir_output_info, - &mlir_output_shape)); + results.module_name = hlo_module->name(); IrEmitterContext ir_emitter_context( hlo_module, results.buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), results.llvm_module.get(), - emit_from_hlo); + /*emit_kernels=*/true); std::vector allocations; - if (emit_from_hlo) { - results.output_shape = hlo_module->result_shape(); - TF_ASSIGN_OR_RETURN(results.output_info, - GetOutputInfo(*hlo_module, *results.buffer_assignment)); - TF_RET_CHECK(mlir_allocations.size() == ordered_allocations.size()); - ir_emitter_context.set_allocations(ordered_allocations); - results.use_original_allocations = true; - } else { - results.allocations = std::move(mlir_allocations); - results.output_shape = mlir_output_shape; - results.output_info = mlir_output_info; - allocations.reserve(results.allocations.size()); - for (auto& allocation : results.allocations) { - allocations.push_back(&allocation); - } - ir_emitter_context.set_allocations(allocations); - results.use_original_allocations = false; - } + results.output_shape = hlo_module->result_shape(); + TF_ASSIGN_OR_RETURN(results.output_info, + GetOutputInfo(*hlo_module, *results.buffer_assignment)); + results.use_original_allocations = true; auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); @@ -402,7 +202,7 @@ StatusOr CompileModuleToLlvmIr( "GpuCompiler::RunBackend - IR emission for ", hlo_module->name())); TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(&entry_function.getBody(), operation_map)); + ir_emitter->EmitHloComputation(hlo_module->entry_computation())); bool supports_runtime_managed_constants = // TODO(b/218907125): Implement this feature for ROCm as well. @@ -423,48 +223,12 @@ StatusOr CompileModuleToLlvmIr( RecordHloToLlvmDuration(end_usecs - start_usecs); } - // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088 - // is submitted. Currently we can't emit LLVM IR with fp8 types. - if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) && - !HasFp8(*hlo_module)) { - // Sizes of all buffers required for running XLA module. - std::vector buffer_sizes; - llvm::transform( - results.allocations, std::back_inserter(buffer_sizes), - [](const BufferAllocation& allocation) { return allocation.size(); }); + auto thunk_sequence = ir_emitter->ConsumeThunkSequence(); + ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, + thunk_sequence.get()); + results.executable = std::move(thunk_sequence); - TF_ASSIGN_OR_RETURN( - results.executable, - LowerToJitRt(*mlir_module, entry_function.getName(), buffer_sizes, - ir_emitter->ConsumeThunkSequence(), hlo_module, - gpu_device_info.gpu_compute_capability())); - } else { - auto thunk_sequence = ir_emitter->ConsumeThunkSequence(); - ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, - thunk_sequence.get()); - results.executable = std::move(thunk_sequence); - } return results; } -// Removes all globals from the given module that are both uninitialized and -// have no uses within that module. -void RemoveUnusedAndUninitializedGlobals( - llvm::Module* llvm_module, - const std::vector& constants) { - for (const auto& info : constants) { - // Empty content means the constant is initialized in the LLVM IR, so we - // must not remove it. - if (!info.content.span().empty()) { - llvm::GlobalVariable* global = - llvm_module->getGlobalVariable(info.symbol_name); - CHECK(global != nullptr); - if (global->use_empty()) { - global->eraseFromParent(); - } - } - } -} - -} // namespace gpu -} // namespace xla +} // namespace xla::gpu diff --git a/xla/service/gpu/compile_module_to_llvm_ir.h b/xla/service/gpu/compile_module_to_llvm_ir.h index 03f93a3482975..4dee30414094a 100644 --- a/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/xla/service/gpu/compile_module_to_llvm_ir.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,21 +18,25 @@ limitations under the License. #include #include -#include #include -#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/buffer_value.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_executable.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_dataflow_analysis.h" -#include "xla/statusor.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/platform.h" #include "xla/util.h" namespace xla { @@ -42,9 +46,7 @@ struct CompileModuleResults { std::unique_ptr llvm_module; std::unique_ptr buffer_assignment; std::vector allocations; - std::variant - executable; + GpuExecutable::OwnedThunkSequence executable; std::vector constants; absl::flat_hash_map output_info; Shape output_shape; @@ -59,13 +61,7 @@ struct CompileModuleResults { void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence); -// Removes all globals from the given module that are both uninitialized and -// have no uses within that module. -void RemoveUnusedAndUninitializedGlobals( - llvm::Module* llvm_module, - const std::vector& constants); - -StatusOr CompileModuleToLlvmIr( +absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, const std::string& platform_name, se::Platform::Id platform_id, diff --git a/xla/service/gpu/conditional_thunk.cc b/xla/service/gpu/conditional_thunk.cc deleted file mode 100644 index 6867544addc03..0000000000000 --- a/xla/service/gpu/conditional_thunk.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/conditional_thunk.h" - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace gpu { - -ConditionalThunk::ConditionalThunk( - ThunkInfo thunk_info, ConditionalThunkConfig config, - const BufferAllocation::Slice& branch_index_buffer_index) - : Thunk(Kind::kConditional, thunk_info), - config_(std::move(config)), - branch_index_buffer_index_(branch_index_buffer_index) {} - -Status ConditionalThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - if (config_.branch_index_is_bool) { - TF_RET_CHECK(config_.branch_thunks.size() == 2); - } else { - TF_RET_CHECK(!config_.branch_thunks.empty()); - } - for (auto& branch_thunk : config_.branch_thunks) { - TF_RETURN_IF_ERROR(branch_thunk->Initialize(executor, src)); - } - return OkStatus(); -} - -Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { - auto& stream = *params.stream; - - // Copy the predicate value from device. - int32_t branch_index = -1; - bool pred = false; - se::DeviceMemoryBase branch_index_address = - params.buffer_allocations->GetDeviceAddress(branch_index_buffer_index_); - if (config_.branch_index_is_bool) { - stream.ThenMemcpy(&pred, branch_index_address, sizeof(bool)); - } else { - stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32_t)); - } - - Status block_status = stream.BlockHostUntilDone(); - if (!block_status.ok()) { - return InternalError( - "Failed to retrieve branch_index value on stream %p: %s.", &stream, - block_status.message()); - } - if (config_.branch_index_is_bool) { - branch_index = pred ? 0 : 1; - } else { - // Handle default scenario for branch_index not in [0, num_branches). - if (branch_index < 0 || branch_index >= config_.branch_count) { - branch_index = config_.branch_count - 1; - } - } - - // Execute the branch computation corresponding to the value of branch_index. - TF_RETURN_IF_ERROR( - config_.branch_thunks[branch_index]->ExecuteOnStream(params)); - - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/conditional_thunk.h b/xla/service/gpu/conditional_thunk.h deleted file mode 100644 index 357d09e663605..0000000000000 --- a/xla/service/gpu/conditional_thunk.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ -#define XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ - -#include -#include - -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -struct ConditionalThunkConfig { - bool branch_index_is_bool; - int64_t branch_count; - std::vector> branch_thunks; -}; - -// ConditionalThunk implements the conditional instruction on GPU by reading the -// predicate of the conditional and executing the true or the false computation -// depending on the value of the predicate. -// -// ConditionalThunk assumes that the buffers of the conditional result and the -// result of the true and false computations share the same allocation. Also, -// the buffers of the true operand of the conditional and that of the parameter -// instruction of the true computation share the same allocation. Similarly, the -// buffers of the false operand and that of the parameter instruction of the -// false computation share the same allocation. -class ConditionalThunk : public Thunk { - public: - ConditionalThunk(ThunkInfo thunk_info, ConditionalThunkConfig config, - const BufferAllocation::Slice& branch_index_buffer_index); - - ConditionalThunk(const ConditionalThunk&) = delete; - ConditionalThunk& operator=(const ConditionalThunk&) = delete; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - absl::Span> branch_thunks() { - return config_.branch_thunks; - } - - private: - const ConditionalThunkConfig config_; - BufferAllocation::Slice branch_index_buffer_index_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ diff --git a/xla/service/gpu/conv_algorithm_picker.cc b/xla/service/gpu/conv_algorithm_picker.cc index ee767e2b6c68f..9a9b690bae167 100644 --- a/xla/service/gpu/conv_algorithm_picker.cc +++ b/xla/service/gpu/conv_algorithm_picker.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,42 +16,70 @@ limitations under the License. #include "xla/service/gpu/conv_algorithm_picker.h" #include +#include +#include +#include #include #include #include #include #include -#include #include -#include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_autotuning.pb.h" +#include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/gpu/hlo_algorithm_denylist.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/slow_operation_alarm.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/lazy_op_runner.h" +#include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logger.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" -#include "tsl/util/proto/proto_utils.h" +#include "tsl/platform/statusor.h" #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "third_party/gpus/cudnn/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep +#include "third_party/gpus/cudnn/cudnn_version.h" +#if CUDNN_VERSION >= 90000 +#include "third_party/gpus/cudnn/cudnn_ops.h" +#else +#include "third_party/gpus/cudnn/cudnn_ops_infer.h" +#endif // CUDNN_VERSION >= 90000 #include "xla/service/gpu/buffer_comparator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #endif @@ -75,10 +103,11 @@ class ScratchAllocator : public se::ScratchAllocator { } int64_t TotalAllocatedBytes() { return total_allocated_bytes_; } - StatusOr> AllocateBytes(int64_t byte_size) override; + absl::StatusOr> AllocateBytes( + int64_t byte_size) override; template - StatusOr> Allocate(int64_t num_elements) { + absl::StatusOr> Allocate(int64_t num_elements) { TF_ASSIGN_OR_RETURN(se::DeviceMemory bytes, AllocateBytes(num_elements * sizeof(T))); return se::DeviceMemory(bytes); @@ -91,15 +120,13 @@ class ScratchAllocator : public se::ScratchAllocator { int64_t total_allocated_bytes_ = 0; }; -StatusOr> ScratchAllocator::AllocateBytes( +absl::StatusOr> ScratchAllocator::AllocateBytes( int64_t byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes()) { - return Status( - absl::StatusCode::kResourceExhausted, - absl::StrFormat( - "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes())); + return absl::ResourceExhaustedError(absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, + GetMemoryLimitInBytes())); } TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, @@ -112,7 +139,7 @@ StatusOr> ScratchAllocator::AllocateBytes( return se::DeviceMemory(buffer_addr); } -StatusOr> GetAlgorithms( +absl::StatusOr> GetAlgorithms( const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend, bool use_fallback, const se::NumericOptions& numeric_options) { TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, @@ -127,16 +154,20 @@ StatusOr> GetAlgorithms( se::StreamExecutor* stream_exec = stream->parent(); std::vector result; + auto dnn = stream_exec->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("No DNN in stream executor."); + } switch (kind) { default: - return InternalError("Unknown ConvolutionKind %d", kind); + return Internal("Unknown ConvolutionKind %d", kind); case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: { if (!config.fusion) { - return InternalError( + return Internal( "GpuConvConfig had fusion ConvolutionKind but no FusionConfig."); } std::vector> runners; - TF_RETURN_IF_ERROR(stream_exec->GetFusedConvolveRunners( + TF_RETURN_IF_ERROR(dnn->GetFusedConvolveRunners( use_cudnn_frontend, // This refers to the kind of convolution op inside the fusion, not // the whole fused graph. @@ -162,7 +193,7 @@ StatusOr> GetAlgorithms( std::vector> runners; // This path is cuDNN-only, where the DeviceMemoryBase arguments and the // allocator are unused; so, they're all provided as nullptr. - TF_RETURN_IF_ERROR(stream_exec->GetGraphConvolveRunners( + TF_RETURN_IF_ERROR(dnn->GetGraphConvolveRunners( kind, input_type, output_type, stream, config.input_descriptor, config.filter_descriptor, config.output_descriptor, config.conv_desc, use_fallback, numeric_options, &runners, config.serialized_graph)); @@ -182,7 +213,7 @@ StatusOr> GetAlgorithms( std::vector> runners; // This path is cuDNN-only, where the DeviceMemoryBase arguments and the // allocator are unused; so, they're all provided as nullptr. - TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners( + TF_RETURN_IF_ERROR(dnn->GetConvolveRunners( use_cudnn_frontend, kind, input_type, output_type, stream, config.input_descriptor, /* input_data = */ DeviceMemoryBase(nullptr), @@ -206,7 +237,7 @@ StatusOr> GetAlgorithms( return result; } -StatusOr>> +absl::StatusOr>> GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, absl::Span operand_buffers, absl::Span result_buffers, @@ -226,7 +257,11 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, GetGpuConvParams(config, operand_buffers, result_buffers)); std::vector> runners; - TF_RETURN_IF_ERROR(stream_exec->GetConvolveRunners( + auto dnn = stream_exec->AsDnn(); + if (dnn == nullptr) { + return absl::InvalidArgumentError("No DNN in stream executor."); + } + TF_RETURN_IF_ERROR(dnn->GetConvolveRunners( /* use_cudnn_frontend = */ false, kind, dtype, dtype, stream, params.config->input_descriptor, params.input_buf, params.config->filter_descriptor, params.filter_buf, @@ -246,7 +281,7 @@ std::string NumBytesToString(int64_t bytes) { CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { CudnnVersion cudnn_version; if (auto* dnn = stream_executor->AsDnn()) { - StatusOr version_or = dnn->GetVersion(); + absl::StatusOr version_or = dnn->GetVersion(); if (version_or.ok()) { const auto& version = version_or.value(); cudnn_version.set_major(version.major_version()); @@ -290,15 +325,15 @@ void PrintPlatformInfo(const se::Stream* stream) { // If the redzones are modified, logs an error, sets the appropriate failure // bits on `result`, and returns false. // -// Returns a status if an unexpected error has occurred, and the stream +// Returns a absl::Status if an unexpected error has occurred, and the stream // has been poisoned. // // `name` is a user-friendly name for the set of redzones being checked, e.g. // "input/output" or "scratch". -StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, - se::Stream* stream, absl::string_view name, - std::string_view instr_str, - AutotuneResult* result) { +absl::StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, + se::Stream* stream, absl::string_view name, + std::string_view instr_str, + AutotuneResult* result) { XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", 2); using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus; @@ -343,13 +378,13 @@ bool ShouldCheckConv(const HloModuleConfig& hlo_module_config) { return conv_autotune_level >= 4; } -StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( +absl::StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( const HloCustomCallInstruction* instr) { return AutotunerUtil::Autotune( instr, config_, [&] { return PickBestAlgorithmNoCache(instr); }); } -StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( +absl::StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( const HloCustomCallInstruction* instr) { AutotuneCacheKey key(config_.GetModelStr(), *instr); if (config_.IsDeviceless()) { @@ -375,7 +410,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( // Make sure any previous activity on this executor is done. We don't want // other work still running on the GPU to interfere with autotuning. if (!stream_exec->SynchronizeAllActivity()) { - return InternalError( + return Internal( "Failed to synchronize GPU for autotuning conv instruction"); } @@ -384,7 +419,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( se::DeviceMemoryAllocator* allocator = config_.GetAllocator(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); - StatusOr result_or(InternalError("Unknown platform.")); + absl::StatusOr result_or(Internal("Unknown platform.")); // Check StreamExecutor on which platform it is. ROCm and Cuda implementation // have diverged. Specifically, we need to make sure redzone allocator related // utilities are not used in ROCm routine @@ -412,7 +447,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCache( #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -StatusOr +absl::StatusOr GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::StreamExecutor* stream_exec, @@ -485,7 +520,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( // failure code other than DISQUALIFIED means autotuning fails if // crash_on_checking_failure is set; and returning a DISQUALIFIED AutotuneResult // simply skips the engine/algorithm while recording a reason for skipping it. -StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( +absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( se::Stream* stream, GenericConvRunner* const runner, std::optional* reference_result, absl::Span disabled_algos, @@ -577,12 +612,11 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Use assignment instead of brace-list to make GCC 4.9 happy. RunConvOptions options; options.runner_cache = runner; - options.profile_result = &profile_result; // The following plan timing code is based on // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h float max_time = 0; float min_time = std::numeric_limits::max(); - Status launch_status; + absl::Status launch_status; std::vector operand_buffers = runtime_arguments.operand_buffers; std::vector result_buffers = @@ -590,20 +624,29 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Dry-run to warmup the plan. launch_status = RunGpuConv(config, operand_buffers, result_buffers, scratch_memory, stream, options); - constexpr float kThreshold = 0.95f; + // Flag that a warm-up run has been executed; this allows the GpuTimer for + // the main measurement to safely use the delay kernel pattern, even if lazy + // module loading is enabled. + options.profile_result = &profile_result; + profile_result.set_warmup_run_executed(true); constexpr int kMaxIter = 10; - // Iterate until new measurement is less than - // kThreshold * min(prev measurements). + // Iterate until the new measurement is within kThreshold of the current + // minimum. int num_iters = 0; - for (; - num_iters < kMaxIter && launch_status.ok() && profile_result.is_valid(); - num_iters++) { + for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) { launch_status = RunGpuConv(config, operand_buffers, result_buffers, scratch_memory, stream, options); + if (!profile_result.is_valid()) { + break; + } float old_min_time = min_time; min_time = std::min(min_time, profile_result.elapsed_time_in_ms()); max_time = std::max(max_time, profile_result.elapsed_time_in_ms()); - if (profile_result.elapsed_time_in_ms() / old_min_time >= kThreshold) { + + constexpr float kThreshold = 0.05f; + if (std::abs(profile_result.elapsed_time_in_ms() - old_min_time) / + old_min_time < + kThreshold) { break; } } @@ -688,7 +731,7 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( BufferComparator comparator(runtime_arguments.result_shape, runtime_arguments.hlo_module_config); for (int i = 0; i < result_buffers.size(); ++i) { - StatusOr compare_result = comparator.CompareEqual( + absl::StatusOr compare_result = comparator.CompareEqual( stream, (*reference_result)->buffers[i], result_buffers[i]); if (!compare_result.ok()) { LOG(ERROR) << "Unable to compare " @@ -731,8 +774,9 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( reference_result_buffers[i], runtime_arguments.input_output_allocator->AllocateBytes( result_buffers[i].size())); - stream->ThenMemcpy(&reference_result_buffers[i], result_buffers[i], - result_buffers[i].size()); + TF_RETURN_IF_ERROR(stream->Memcpy(&reference_result_buffers[i], + result_buffers[i], + result_buffers[i].size())); } (*reference_result) = {alg, reference_result_buffers}; } @@ -740,7 +784,8 @@ StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( return result; } -StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( +absl::StatusOr +GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const HloCustomCallInstruction* instr, se::Stream* stream, std::optional instruction_info, const AutotuneRuntimeArguments& runtime_arguments) { @@ -855,17 +900,14 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id()); log.set_blas_version(blas_version); VLOG(2) << "Autotuning result: " << log.ShortDebugString(); - // If we crash on checking failure, we are in a testing/benchmark mode, thus - // omitting logging through the logger. - if (!crash_on_checking_failure) { - tsl::Logger::GetSingleton()->LogProto(log); - } else { + // If we crash on checking failure, we are in a testing/benchmark mode. + if (crash_on_checking_failure) { // Crash on miscompares and redzone violations if desired. for (const auto& profile : profile_results) { if (profile.has_failure() && profile.failure().kind() != AutotuneResult::DISQUALIFIED) { LOG(FATAL) << "crash_on_checking_failure encountered errors:\n\n" - << log.DebugString(); + << log.DebugString(); // NOLINT } } } @@ -878,7 +920,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( } #endif -StatusOr +absl::StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmWithAllocatedBuffer( const AutotuneConfig& config, const GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, @@ -903,11 +945,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmWithAllocatedBuffer( /*instr=*/nullptr, stream, /*instruction_info=*/std::nullopt, autotune_runtime_arguments); #else - return InternalError("CUDA is not enabled"); + return Internal("CUDA is not enabled"); #endif } -StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( +absl::StatusOr +GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream) { XLA_SCOPED_LOGGING_TIMER(absl::StrCat( @@ -931,7 +974,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( // before autotuning. It's conceivable that using uninitialized memory as // the inputs might affect performance if e.g. the inputs contain // denormals, and this is easy enough. - stream->ThenMemZero(&buffer, buffer.size()); + return stream->MemZero(&buffer, buffer.size()); }; // Allocate space for the input, filter, and output of the convolution. We @@ -941,7 +984,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); - initialize_buffer(buffer); + TF_RETURN_IF_ERROR(initialize_buffer(buffer)); operand_buffers.push_back(buffer); } @@ -953,14 +996,14 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( result_buffers[i], input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(i)))); - initialize_buffer(result_buffers[i]); + TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[i])); } } else { TF_ASSIGN_OR_RETURN( result_buffers[0], input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); - initialize_buffer(result_buffers[0]); + TF_RETURN_IF_ERROR(initialize_buffer(result_buffers[0])); } ScratchAllocator scratch_allocator(device_ordinal, allocator); @@ -1014,7 +1057,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( RunConvOptions options; options.profile_result = &profile_result; options.runner_cache = &runner_cache; - Status launch_status = + absl::Status launch_status = RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffers, scratch_memory, stream, options); @@ -1043,7 +1086,8 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( return selected_algorithm; } -StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { +absl::StatusOr GpuConvAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); const bool strict = instr->parent() @@ -1052,7 +1096,7 @@ StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { .debug_options() .xla_gpu_strict_conv_algorithm_picker(); - StatusOr best_algo_or = + absl::StatusOr best_algo_or = PickBestAlgorithm(Cast(instr)); if (!best_algo_or.ok()) { auto msg = absl::StrFormat( @@ -1094,8 +1138,10 @@ StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})); Shape new_call_shape = ShapeUtil::MakeTupleShape(new_call_element_shapes); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - instr->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + instr->backend_config()); + CudnnConvBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_conv_backend_config(); *backend_config.mutable_algorithm() = best_algo.algorithm(); backend_config.mutable_algorithm()->mutable_workspace_size()->set_value( best_algo.scratch_bytes()); @@ -1111,7 +1157,7 @@ StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { VLOG(3) << "Replacing convolution " << instr->ToString() << " with " << new_call->ToString(); - TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config)); std::vector new_tuple_elements; new_tuple_elements.reserve(new_call->shape().tuple_shapes_size() - 1); @@ -1132,7 +1178,7 @@ StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { return true; } -StatusOr GpuConvAlgorithmPicker::RunOnComputation( +absl::StatusOr GpuConvAlgorithmPicker::RunOnComputation( HloComputation* computation) { std::vector convs; for (HloInstruction* instr : computation->instructions()) { @@ -1149,7 +1195,7 @@ StatusOr GpuConvAlgorithmPicker::RunOnComputation( return changed; } -StatusOr GpuConvAlgorithmPicker::Run( +absl::StatusOr GpuConvAlgorithmPicker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_SCOPED_LOGGING_TIMER( diff --git a/xla/service/gpu/conv_algorithm_picker.h b/xla/service/gpu/conv_algorithm_picker.h index 2850c0609e70a..046b11ca045e9 100644 --- a/xla/service/gpu/conv_algorithm_picker.h +++ b/xla/service/gpu/conv_algorithm_picker.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,18 +18,28 @@ limitations under the License. #include #include -#include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_conv_runner.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) @@ -90,12 +100,12 @@ class GpuConvAlgorithmPicker : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; // Run autotuning on allocated buffers and pick the best algorithm. - StatusOr PickBestAlgorithmWithAllocatedBuffer( + absl::StatusOr PickBestAlgorithmWithAllocatedBuffer( const AutotuneConfig& config, GpuConvConfig conv_config, const ServiceExecutableRunOptions* run_options, const DebugOptions& debug_options, @@ -103,12 +113,12 @@ class GpuConvAlgorithmPicker : public HloModulePass { std::vector result_buffers); private: - StatusOr RunOnComputation(HloComputation* computation); - StatusOr RunOnInstruction(HloInstruction* instr); + absl::StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm( + absl::StatusOr PickBestAlgorithm( const HloCustomCallInstruction* instr); - StatusOr PickBestAlgorithmNoCache( + absl::StatusOr PickBestAlgorithmNoCache( const HloCustomCallInstruction* instr); #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) @@ -131,13 +141,13 @@ class GpuConvAlgorithmPicker : public HloModulePass { const GpuConvConfig gpu_conv_config; std::optional canonical_hlo; - static StatusOr FromInstruction( + static absl::StatusOr FromInstruction( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::StreamExecutor* stream, se::RedzoneAllocator* input_output_allocator); }; - StatusOr AutotuneOneConvRunner( + absl::StatusOr AutotuneOneConvRunner( se::Stream* stream, GenericConvRunner* runner, std::optional* reference_result, absl::Span disabled_algos, @@ -145,13 +155,13 @@ class GpuConvAlgorithmPicker : public HloModulePass { const AutotuneRuntimeArguments& runtime_arguments); // Pick the best algorithm for CUDA platform. - StatusOr PickBestAlgorithmNoCacheCuda( + absl::StatusOr PickBestAlgorithmNoCacheCuda( const HloCustomCallInstruction* instr, se::Stream* stream, std::optional instruction_info, const AutotuneRuntimeArguments& runtime_arguments); #endif - StatusOr PickBestAlgorithmNoCacheRocm( + absl::StatusOr PickBestAlgorithmNoCacheRocm( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream); diff --git a/xla/service/gpu/conv_algorithm_picker_test.cc b/xla/service/gpu/conv_algorithm_picker_test.cc index ae20168f68d62..bd36e0b439d1d 100644 --- a/xla/service/gpu/conv_algorithm_picker_test.cc +++ b/xla/service/gpu/conv_algorithm_picker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,19 @@ limitations under the License. #include "xla/service/gpu/conv_algorithm_picker.h" +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/gpu_conv_rewriter.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/service/platform_util.h" #include "xla/service/tuple_simplifier.h" +#include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/conv_layout_normalization.cc b/xla/service/gpu/conv_layout_normalization.cc index 3e744937f3a18..68c5d18effd56 100644 --- a/xla/service/gpu/conv_layout_normalization.cc +++ b/xla/service/gpu/conv_layout_normalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,24 +15,29 @@ limitations under the License. #include "xla/service/gpu/conv_layout_normalization.h" +#include #include -#include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/layout_util.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { -StatusOr> UpdateLayoutForCudnnConvolution( +absl::StatusOr> UpdateLayoutForCudnnConvolution( HloCustomCallInstruction* hlo) { HloInstruction* lhs = hlo->mutable_operand(0); HloInstruction* rhs = hlo->mutable_operand(1); @@ -186,7 +191,7 @@ StatusOr> UpdateLayoutForCudnnConvolution( } // namespace -StatusOr> NormalizeLayoutForGpuCustomCalls( +absl::StatusOr> NormalizeLayoutForGpuCustomCalls( HloCustomCallInstruction* hlo) { if (IsCustomCallToDnnConvolution(*hlo)) { TF_ASSIGN_OR_RETURN(std::optional bc_to_orig, diff --git a/xla/service/gpu/conv_layout_normalization.h b/xla/service/gpu/conv_layout_normalization.h index 52fd26c1fbf95..b8723d9234782 100644 --- a/xla/service/gpu/conv_layout_normalization.h +++ b/xla/service/gpu/conv_layout_normalization.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,16 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CONV_LAYOUT_NORMALIZATION_H_ #define XLA_SERVICE_GPU_CONV_LAYOUT_NORMALIZATION_H_ -#include #include -#include -#include "absl/strings/string_view.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/statusor.h" namespace xla { namespace gpu { -StatusOr> NormalizeLayoutForGpuCustomCalls( +absl::StatusOr> NormalizeLayoutForGpuCustomCalls( HloCustomCallInstruction*); } // end namespace gpu diff --git a/xla/service/gpu/conv_layout_normalization_test.cc b/xla/service/gpu/conv_layout_normalization_test.cc index ce5095743f832..56629fe39bee3 100644 --- a/xla/service/gpu/conv_layout_normalization_test.cc +++ b/xla/service/gpu/conv_layout_normalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -59,7 +58,7 @@ HloModule TestModule %copy.3 = f32[1,20,257]{1,2,0} copy(f32[1,20,257]{2,1,0} %param_0) %param_1 = f32[31,257,136]{2,1,0} parameter(1) %copy.4 = f32[31,257,136]{0,2,1} copy(f32[31,257,136]{2,1,0} %param_1) - %custom-call.1 = (f32[1,23,136]{1,2,0}, u8[0]{0}) custom-call(f32[1,20,257]{1,2,0} %copy.3, f32[31,257,136]{0,2,1} %copy.4), window={size=31 stride=2 pad=23_23}, dim_labels=b0f_0oi->b0f, custom_call_target="__cudnn$convBackwardInput", backend_config="{conv_result_scale:1}" + %custom-call.1 = (f32[1,23,136]{1,2,0}, u8[0]{0}) custom-call(f32[1,20,257]{1,2,0} %copy.3, f32[31,257,136]{0,2,1} %copy.4), window={size=31 stride=2 pad=23_23}, dim_labels=b0f_0oi->b0f, custom_call_target="__cudnn$convBackwardInput", backend_config={"cudnn_conv_backend_config":{conv_result_scale:1}} %get-tuple-element.2 = f32[1,23,136]{1,2,0} get-tuple-element((f32[1,23,136]{1,2,0}, u8[0]{0}) %custom-call.1), index=0 %copy.5 = f32[1,23,136]{2,1,0} copy(f32[1,23,136]{1,2,0} %get-tuple-element.2) %get-tuple-element.3 = u8[0]{0} get-tuple-element((f32[1,23,136]{1,2,0}, u8[0]{0}) %custom-call.1), index=1 @@ -81,7 +80,7 @@ HloModule TestModule ENTRY %TestComputation { %param_0 = f32[2,128,1,378]{3,2,1,0} parameter(0) %param_1 = f32[1,5,128,128]{1,0,2,3} parameter(1) - ROOT %custom-call.1 = (f32[2,128,1,378]{3,2,1,0}, u8[0]{0}) custom-call(%param_0, %param_1), window={size=1x5 pad=0_0x2_2}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}" + ROOT %custom-call.1 = (f32[2,128,1,378]{3,2,1,0}, u8[0]{0}) custom-call(%param_0, %param_1), window={size=1x5 pad=0_0x2_2}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{conv_result_scale:1}} } )"; @@ -91,7 +90,8 @@ ENTRY %TestComputation { )"); } -TEST_F(ConvolutionLayoutNormalizationTest, FusedConv3D) { +// TODO(rocm): No Conv3D +TEST_F(ConvolutionLayoutNormalizationTest, DISABLED_ON_GPU_ROCM(FusedConv3D)) { const char* hlo = R"( HloModule TestModule diff --git a/xla/service/gpu/convolution_thunk.cc b/xla/service/gpu/convolution_thunk.cc deleted file mode 100644 index 17096e890cf77..0000000000000 --- a/xla/service/gpu/convolution_thunk.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/convolution_thunk.h" - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -ConvolutionThunk::ConvolutionThunk( - ThunkInfo thunk_info, GpuConvConfig config, - std::vector operand_slices, - std::vector result_slices, - BufferAllocation::Slice scratch_slice) - : Thunk(Kind::kConvolution, thunk_info), - operand_buffers_(std::move(operand_slices)), - result_buffers_(std::move(result_slices)), - scratch_buffer_(scratch_slice), - config_(std::move(config)) {} - -GenericConvRunner& ConvolutionThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique(config_)}) - .first; - } - return *it->second; -} - -Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - - std::vector operand_se_buffers, result_se_buffers; - operand_se_buffers.reserve(operand_buffers_.size()); - for (BufferAllocation::Slice buffer : operand_buffers_) { - operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); - } - - result_se_buffers.reserve(result_buffers_.size()); - for (BufferAllocation::Slice buffer : result_buffers_) { - result_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); - } - - se::DeviceMemoryBase scratch = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - RunConvOptions opts; - opts.runner_cache = &GetOrCreateRunner(params.stream); - - TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), - absl::MakeSpan(result_se_buffers), scratch, - params.stream, opts)); - - // Note: Convolution has a tuple buffer as an output, but we don't need to - // populate it as no one should be reading from the tuple directly. - if (!params.stream->ok()) { - return InternalError("ConvolutionThunk::ExecuteOnStream failed."); - } - return OkStatus(); -} - -ConvolutionReorderThunk::ConvolutionReorderThunk( - ThunkInfo thunk_info, absl::Span filter_nchw, - std::vector operand_slices, - std::vector result_slices) - : Thunk(Kind::kConvolutionReorder, thunk_info), - filter_descriptor_(CreateFilterDescriptor(filter_nchw)), - operand_buffers_(std::move(operand_slices)), - result_buffers_(std::move(result_slices)) {} - -Status ConvolutionReorderThunk::ExecuteOnStream(const ExecuteParams& params) { - bool has_bias = operand_buffers_.size() > 1; - CHECK_EQ(operand_buffers_.size(), result_buffers_.size()); - - const auto& buffer_allocations = *params.buffer_allocations; - - auto filter_input = se::DeviceMemory( - buffer_allocations.GetDeviceAddress(operand_buffers_[0])); - auto filter_output = se::DeviceMemory( - buffer_allocations.GetDeviceAddress(result_buffers_[0])); - auto bias_input = - has_bias ? std::make_optional(se::DeviceMemory( - buffer_allocations.GetDeviceAddress(operand_buffers_[1]))) - : std::nullopt; - auto bias_output = - has_bias ? std::make_optional(se::DeviceMemory( - buffer_allocations.GetDeviceAddress(result_buffers_[1]))) - : std::nullopt; - - TF_RETURN_IF_ERROR(params.stream->CudnnReorderConvolutionFilterAndBias( - filter_descriptor_, filter_input, &filter_output, std::move(bias_input), - std::move(bias_output))); - - if (!params.stream->ok()) { - return InternalError("ConvolutionReorderThunk::ExecuteOnStream failed."); - } - return OkStatus(); -} - -se::dnn::FilterDescriptor ConvolutionReorderThunk::CreateFilterDescriptor( - absl::Span filter_nchw) { - CHECK_EQ(filter_nchw.size(), 4); - se::dnn::FilterDescriptor filter_desc(2); - filter_desc.set_layout(se::dnn::FilterLayout::kOutputInputYX32); - filter_desc.set_output_feature_map_count(filter_nchw[0]); - filter_desc.set_input_feature_map_count(filter_nchw[1]); - filter_desc.set_input_filter_height(filter_nchw[2]); - filter_desc.set_input_filter_width(filter_nchw[3]); - return filter_desc; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/convolution_thunk.h b/xla/service/gpu/convolution_thunk.h deleted file mode 100644 index d7f731302c1a9..0000000000000 --- a/xla/service/gpu/convolution_thunk.h +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ -#define XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_executable.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" - -namespace xla { -namespace gpu { - -// This class stores everything that StreamExecutor needs to launch a DNN -// convolution. It is generated by IrEmitter. -// -// This is thread-compatible. -class ConvolutionThunk : public Thunk { - public: - // Constructs a thunk for launching a DNN convolution. - // - // operand_slices should be in the same order as cudnn_call->operands(). - ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig config, - std::vector operand_slices, - std::vector result_slices, - BufferAllocation::Slice scratch_slice); - - ConvolutionThunk(const ConvolutionThunk&) = delete; - ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - std::vector operand_buffers_; - std::vector result_buffers_; - BufferAllocation::Slice scratch_buffer_; - GenericConvRunner& GetOrCreateRunner(const stream_executor::Stream* stream); - - // Convolution config - const GpuConvConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; - -// Launches the kernel that reorders input data for int8x32 convolutions. -class ConvolutionReorderThunk : public Thunk { - public: - ConvolutionReorderThunk(ThunkInfo thunk_info, absl::Span filter_nchw, - std::vector operand_slices, - std::vector result_slices); - - ConvolutionReorderThunk(const ConvolutionReorderThunk&) = delete; - ConvolutionReorderThunk& operator=(const ConvolutionReorderThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - static se::dnn::FilterDescriptor CreateFilterDescriptor( - absl::Span filter_nchw); - - const se::dnn::FilterDescriptor filter_descriptor_; - std::vector operand_buffers_; - std::vector result_buffers_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ diff --git a/xla/service/gpu/copy_fusion.cc b/xla/service/gpu/copy_fusion.cc index 9db958199b9cc..4f06ef04ba5ae 100644 --- a/xla/service/gpu/copy_fusion.cc +++ b/xla/service/gpu/copy_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/gpu/copy_fusion.h" +#include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/reduction_utils.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" namespace xla { namespace gpu { @@ -53,7 +58,7 @@ bool OnlyElementwiseOpsReachableFromParams(HloComputation* fused_computation) { return true; } -StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { +absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { bool changed = false; std::vector defs_before_uses = computation->MakeInstructionPostOrder(); @@ -174,7 +179,7 @@ StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { return changed; } -StatusOr CopyFusion::Run( +absl::StatusOr CopyFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Only for the entry computation we can be sure that the copies do not share diff --git a/xla/service/gpu/copy_fusion.h b/xla/service/gpu/copy_fusion.h index fa9306ce7e3fd..973b671f5978e 100644 --- a/xla/service/gpu/copy_fusion.h +++ b/xla/service/gpu/copy_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_COPY_FUSION_H_ #define XLA_SERVICE_GPU_COPY_FUSION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -32,12 +35,12 @@ class CopyFusion : public HloModulePass { absl::string_view name() const override { return "copy_fusion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr DoCopyFusion(HloComputation* computation); + absl::StatusOr DoCopyFusion(HloComputation* computation); }; } // namespace gpu diff --git a/xla/service/gpu/copy_fusion_test.cc b/xla/service/gpu/copy_fusion_test.cc index fdb57fad3e376..d2116eb68b0c2 100644 --- a/xla/service/gpu/copy_fusion_test.cc +++ b/xla/service/gpu/copy_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,10 @@ limitations under the License. #include "xla/service/gpu/copy_fusion.h" +#include +#include #include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/gpu/copy_thunk.cc b/xla/service/gpu/copy_thunk.cc deleted file mode 100644 index 8b049452d9770..0000000000000 --- a/xla/service/gpu/copy_thunk.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/copy_thunk.h" - -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, uint64_t mem_size, - mlir::Value source_value, mlir::Value destination_value) - : Thunk(Kind::kCopy, thunk_info), - source_buffer_(source_buffer), - destination_buffer_(destination_buffer), - mem_size_(mem_size), - source_value_(source_value), - destination_value_(destination_value) {} - -Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) { - se::DeviceMemoryBase destination_data = - params.buffer_allocations->GetDeviceAddress(destination_buffer_); - se::DeviceMemoryBase source_data = - params.buffer_allocations->GetDeviceAddress(source_buffer_); - params.stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return OkStatus(); -} -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/copy_thunk.h b/xla/service/gpu/copy_thunk.h deleted file mode 100644 index 7576c81ca98ec..0000000000000 --- a/xla/service/gpu/copy_thunk.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_COPY_THUNK_H_ -#define XLA_SERVICE_GPU_COPY_THUNK_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// A thunk that copies data from a device buffer to another device buffer. -class DeviceToDeviceCopyThunk : public Thunk { - public: - // Constructs a CopyThunk that copies host data from `source_buffer` to the - // device buffer `destination_buffer`. `mem_size` is the size of the data in - // bytes. - DeviceToDeviceCopyThunk(ThunkInfo thunk_info, - const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, - uint64_t mem_size, mlir::Value source_value, - mlir::Value destination_value); - - DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; - DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - source_value_ = nullptr; - destination_value_ = nullptr; - } - - const BufferAllocation::Slice& source() const { return source_buffer_; } - const BufferAllocation::Slice& destination() const { - return destination_buffer_; - } - uint64_t size_bytes() const { return mem_size_; } - mlir::Value source_value() const { return source_value_; } - mlir::Value destination_value() const { return destination_value_; } - - private: - const BufferAllocation::Slice source_buffer_; - const BufferAllocation::Slice destination_buffer_; - const uint64_t mem_size_; - mlir::Value source_value_; - mlir::Value destination_value_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_COPY_THUNK_H_ diff --git a/xla/service/gpu/cub_sort_kernel.cu.cc b/xla/service/gpu/cub_sort_kernel.cu.cc index 99ab2627212b2..2efd6d09bffa9 100644 --- a/xla/service/gpu/cub_sort_kernel.cu.cc +++ b/xla/service/gpu/cub_sort_kernel.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,7 @@ limitations under the License. #include #include -#if GOOGLE_CUDA -#include "xla/service/gpu/gpu_prim_cuda.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/service/gpu/gpu_prim_rocm.h" -#endif // TENSORFLOW_USE_ROCM +#include "xla/service/gpu/gpu_prim.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/cub_sort_kernel.h b/xla/service/gpu/cub_sort_kernel.h index 621489b6e071c..8d1efaa2bc66d 100644 --- a/xla/service/gpu/cub_sort_kernel.h +++ b/xla/service/gpu/cub_sort_kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/cub_sort_thunk.cc b/xla/service/gpu/cub_sort_thunk.cc deleted file mode 100644 index 762f822f2ee4a..0000000000000 --- a/xla/service/gpu/cub_sort_thunk.cc +++ /dev/null @@ -1,270 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/cub_sort_thunk.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "xla/primitive_util.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/cub_sort_kernel.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace gpu { -namespace { - -// Template class for sorting a single tensor. -class CubSortKeysImpl : public CubSortRunnerInterface { - public: - using SortKeysFn = std::function; - - explicit CubSortKeysImpl(SortKeysFn sort_keys_fn, PrimitiveType type) - : sort_keys_fn_(sort_keys_fn), type_(type) {} - - Status Run(se::DeviceMemoryBase input_keys, se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, se::DeviceMemoryBase scratch, - bool descending) override; - Status Run(const Thunk::ExecuteParams& params, - const CubSortThunk* thunk) override; - StatusOr GetScratchSize(int64_t num_items) override; - - private: - SortKeysFn sort_keys_fn_; - PrimitiveType type_; -}; - -Status CubSortKeysImpl::Run(se::DeviceMemoryBase input_keys, - se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, - se::DeviceMemoryBase scratch, bool descending) { - size_t temp_bytes = scratch.size(); - size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_); - CHECK(input_values.is_null()); - CHECK(output_values.is_null()); - const char* error = - sort_keys_fn_(scratch.opaque(), temp_bytes, input_keys.opaque(), - output_keys.opaque(), num_items, descending); - if (error != nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("CubSortKeys error: ", error)); - } - return absl::OkStatus(); -} - -Status CubSortKeysImpl::Run(const Thunk::ExecuteParams& params, - const CubSortThunk* thunk) { - const BufferAllocations& allocs = *params.buffer_allocations; - return Run(allocs.GetDeviceAddress(thunk->operand(0)), se::DeviceMemoryBase(), - allocs.GetDeviceAddress(thunk->result(0)), se::DeviceMemoryBase(), - allocs.GetDeviceAddress(thunk->scratch()), thunk->descending()); -} - -StatusOr CubSortKeysImpl::GetScratchSize(int64_t num_items) { - size_t temp_bytes = 0; - const char* error = - sort_keys_fn_(nullptr, temp_bytes, nullptr, nullptr, num_items, false); - if (error != nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("CubSortKeys error: ", error)); - } - return temp_bytes; -} - -// Template class for sorting a pair of tensors. -class CubSortPairsImpl : public CubSortRunnerInterface { - public: - using SortPairsFn = std::function; - - explicit CubSortPairsImpl(SortPairsFn sort_pairs_fn, PrimitiveType type) - : sort_pairs_fn_(sort_pairs_fn), type_(type) {} - - Status Run(se::DeviceMemoryBase input_keys, se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, se::DeviceMemoryBase scratch, - bool descending) override; - Status Run(const Thunk::ExecuteParams& params, - const CubSortThunk* thunk) override; - StatusOr GetScratchSize(int64_t num_items) override; - - private: - SortPairsFn sort_pairs_fn_; - PrimitiveType type_; -}; - -Status CubSortPairsImpl::Run(se::DeviceMemoryBase input_keys, - se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, - se::DeviceMemoryBase scratch, bool descending) { - size_t temp_bytes = scratch.size(); - size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_); - const char* error = sort_pairs_fn_( - scratch.opaque(), temp_bytes, input_keys.opaque(), output_keys.opaque(), - input_values.opaque(), output_values.opaque(), num_items, descending); - if (error != nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("CubSortPairs error: ", error)); - } - return absl::OkStatus(); -} - -Status CubSortPairsImpl::Run(const Thunk::ExecuteParams& params, - const CubSortThunk* thunk) { - const BufferAllocations& allocs = *params.buffer_allocations; - return Run(allocs.GetDeviceAddress(thunk->operand(0)), - allocs.GetDeviceAddress(thunk->operand(1)), - allocs.GetDeviceAddress(thunk->result(0)), - allocs.GetDeviceAddress(thunk->result(1)), - allocs.GetDeviceAddress(thunk->scratch()), thunk->descending()); -} - -StatusOr CubSortPairsImpl::GetScratchSize(int64_t num_items) { - size_t temp_bytes = 0; - const char* error = sort_pairs_fn_(nullptr, temp_bytes, nullptr, nullptr, - nullptr, nullptr, num_items, false); - if (error != nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("CubSortPairs error: ", error)); - } - return temp_bytes; -} - -StatusOr> CreateCubSortRunner( - PrimitiveType type) { - switch (type) { - case F16: - return std::make_unique(CubSortKeys_f16, F16); - case F32: - return std::make_unique(CubSortKeys_f32, F32); - case F64: - return std::make_unique(CubSortKeys_f64, F64); - case S8: - return std::make_unique(CubSortKeys_s8, S8); - case S16: - return std::make_unique(CubSortKeys_s16, S16); - case S32: - return std::make_unique(CubSortKeys_s32, S32); - case S64: - return std::make_unique(CubSortKeys_s64, S64); - case U8: - return std::make_unique(CubSortKeys_u8, U8); - case U16: - return std::make_unique(CubSortKeys_u16, U16); - case U32: - return std::make_unique(CubSortKeys_u32, U32); - case U64: - return std::make_unique(CubSortKeys_u64, U64); - default: - return InvalidArgument("Unsupported type of the sort kernel: %s", - primitive_util::LowercasePrimitiveTypeName(type)); - } -} - -StatusOr> CreateCubSortRunner( - PrimitiveType key_type, PrimitiveType value_type) { - // Values can be of any type of 16/32/64 bit width. - int valueWidth = primitive_util::BitWidth(value_type); - if (valueWidth != 16 && valueWidth != 32 && valueWidth != 64) { - return InvalidArgument( - "Unsupported value type of the sort kernel: %s", - primitive_util::LowercasePrimitiveTypeName(value_type)); - } - - // Only unsigned integer types could be used for keys. - switch (key_type) { - case U16: - if (valueWidth == 16) { - return std::make_unique(CubSortPairs_u16_b16, U16); - } - if (valueWidth == 32) { - return std::make_unique(CubSortPairs_u16_b32, U16); - } - return std::make_unique(CubSortPairs_u16_b64, U16); - case U32: - if (valueWidth == 16) { - return std::make_unique(CubSortPairs_u32_b16, U32); - } - if (valueWidth == 32) { - return std::make_unique(CubSortPairs_u32_b32, U32); - } - return std::make_unique(CubSortPairs_u32_b64, U32); - case U64: - if (valueWidth == 16) { - return std::make_unique(CubSortPairs_u64_b16, U64); - } - if (valueWidth == 32) { - return std::make_unique(CubSortPairs_u64_b32, U64); - } - return std::make_unique(CubSortPairs_u64_b64, U64); - default: - return InvalidArgument( - "Unsupported key type of the sort kernel: %s", - primitive_util::LowercasePrimitiveTypeName(key_type)); - } -} - -} // namespace - -StatusOr> -CubSortRunnerInterface::Create(PrimitiveType type, - std::optional value_type) { - return value_type.has_value() ? CreateCubSortRunner(type, *value_type) - : CreateCubSortRunner(type); -} - -CubSortThunk::CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, - std::optional value_type, - std::vector operands, - std::vector results, - BufferAllocation::Slice scratch, bool descending) - : Thunk(Thunk::kCubSort, thunk_info), - runner_(CubSortRunnerInterface::Create(type, value_type).value()), - operands_(std::move(operands)), - results_(std::move(results)), - scratch_(scratch), - descending_(descending) {} - -Status RunCubSort(PrimitiveType type, std::optional value_type, - se::DeviceMemoryBase input_keys, - se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, - se::DeviceMemoryBase scratch, bool descending) { - auto runner = CubSortRunnerInterface::Create(type, value_type).value(); - return runner->Run(input_keys, input_values, output_keys, output_values, - scratch, descending); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/cub_sort_thunk.h b/xla/service/gpu/cub_sort_thunk.h deleted file mode 100644 index d79fb0e3892d0..0000000000000 --- a/xla/service/gpu/cub_sort_thunk.h +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_CUB_SORT_THUNK_H_ -#define XLA_SERVICE_GPU_CUB_SORT_THUNK_H_ - -#include -#include -#include -#include - -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -class CubSortRunnerInterface { - public: - virtual ~CubSortRunnerInterface() = default; - virtual Status Run(se::DeviceMemoryBase input_keys, - se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, - se::DeviceMemoryBase scratch, bool descending) = 0; - virtual Status Run(const Thunk::ExecuteParams& params, - const class CubSortThunk* thunk) = 0; - virtual StatusOr GetScratchSize(int64_t num_items) = 0; - - static StatusOr> Create( - PrimitiveType type, std::optional value_type); -}; - -class CubSortThunk : public Thunk { - public: - CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, - std::optional value_type, - std::vector operands, - std::vector results, - BufferAllocation::Slice scratch, bool descending); - - Status ExecuteOnStream(const ExecuteParams& params) override { - return runner_->Run(params, this); - } - - BufferAllocation::Slice operand(int i) const { return operands_[i]; } - BufferAllocation::Slice result(int i) const { return results_[i]; } - BufferAllocation::Slice scratch() const { return scratch_; } - bool descending() const { return descending_; } - - private: - std::unique_ptr runner_; - std::vector operands_; - std::vector results_; - BufferAllocation::Slice scratch_; - bool descending_; -}; - -Status RunCubSort(PrimitiveType type, std::optional value_type, - se::DeviceMemoryBase input_keys, - se::DeviceMemoryBase input_values, - se::DeviceMemoryBase output_keys, - se::DeviceMemoryBase output_values, - se::DeviceMemoryBase scratch, bool descending); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_CUB_SORT_THUNK_H_ diff --git a/xla/service/gpu/cublas_cudnn.cc b/xla/service/gpu/cublas_cudnn.cc index 18c705a8f5ed0..cab9e2c54a3f1 100644 --- a/xla/service/gpu/cublas_cudnn.cc +++ b/xla/service/gpu/cublas_cudnn.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,21 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { bool IsCublasGemm(const HloInstruction& hlo) { - return IsLegacyCublasMatmul(hlo) || IsCublasLtMatmul(hlo); + return IsLegacyCublasMatmul(hlo) || IsCublasLtMatmul(hlo) || + IsCublasLtMatmulF8(hlo); } bool IsLegacyCublasMatmul(const HloInstruction& hlo) { @@ -41,6 +49,11 @@ bool IsCublasLtMatmulF8(const HloInstruction& hlo) { hlo.custom_call_target() == kCublasLtMatmulF8CallTarget; } +bool IsTriangularSolve(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kTriangularSolveCallTarget; +} + const absl::string_view kGemmCallTarget = "__cublas$gemm"; const absl::string_view kCublasLtMatmulCallTarget = "__cublas$lt$matmul"; const absl::string_view kCublasLtMatmulF8CallTarget = "__cublas$lt$matmul$f8"; @@ -63,43 +76,43 @@ const absl::string_view kCudnnConvReorderFilterAndBiasCallTarget = const absl::string_view kCudnnNormCallTarget = "__cudnn$norm"; // fMHA forward call targets. -const absl::string_view kCudnnfMHABmmBmmCallTarget = "__cudnn$fhmaBmmBmm"; -const absl::string_view kCudnnfMHASoftmaxCallTarget = "__cudnn$fhmaSoftmax"; +const absl::string_view kCudnnfMHABmmBmmCallTarget = "__cudnn$fmhaBmmBmm"; +const absl::string_view kCudnnfMHASoftmaxCallTarget = "__cudnn$fmhaSoftmax"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmax"; + "__cudnn$fmhaScaleBiasMaskSoftmax"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxDropout"; + "__cudnn$fmhaScaleBiasMaskSoftmaxDropout"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxDropout"; + "__cudnn$fmhaScaleBiasSoftmaxDropout"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget = - "__cudnn$fhmaScaleBiasSoftmax"; + "__cudnn$fmhaScaleBiasSoftmax"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxCallTarget = - "__cudnn$fhmaScaleMaskSoftmax"; + "__cudnn$fmhaScaleMaskSoftmax"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxDropoutCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxDropout"; + "__cudnn$fmhaScaleMaskSoftmaxDropout"; const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget = - "__cudnn$fhmaSoftmaxDropout"; + "__cudnn$fmhaSoftmaxDropout"; // fMHA backward call targets. const absl::string_view kCudnnfMHABmmBmmBackwardCallTarget = - "__cudnn$fhmaBmmBmmBackward"; + "__cudnn$fmhaBmmBmmBackward"; const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget = - "__cudnn$fhmaSoftmaxBackward"; + "__cudnn$fmhaSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxBackward"; + "__cudnn$fmhaScaleBiasMaskSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleBiasMaskSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleBiasMaskSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleBiasSoftmaxBackward"; + "__cudnn$fmhaScaleBiasSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxBackward"; + "__cudnn$fmhaScaleMaskSoftmaxBackward"; const absl::string_view kCudnnfMHAScaleMaskSoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward"; + "__cudnn$fmhaScaleMaskSoftmaxDropoutBackward"; const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget = - "__cudnn$fhmaSoftmaxDropoutBackward"; + "__cudnn$fmhaSoftmaxDropoutBackward"; const absl::string_view kCubDeviceRadixSortTarget = "__cub$DeviceRadixSort"; @@ -184,7 +197,7 @@ bool IsCubDeviceRadixSort(const HloInstruction& hlo) { hlo.custom_call_target() == kCubDeviceRadixSortTarget; } -StatusOr GetCudnnConvKind( +absl::StatusOr GetCudnnConvKind( const HloCustomCallInstruction* instr) { absl::string_view target = instr->custom_call_target(); if (target == kCudnnConvForwardCallTarget) { @@ -202,7 +215,7 @@ StatusOr GetCudnnConvKind( if (target == kCudnnConvBiasActivationForwardCallTarget) { return CudnnConvKind::kForwardActivation; } - return InternalError("Unexpected call target: %s", target); + return Internal("Unexpected call target: %s", target); } std::string CudnnConvKindToString(CudnnConvKind kind) { @@ -220,7 +233,7 @@ std::string CudnnConvKindToString(CudnnConvKind kind) { } } -StatusOr GetCudnnfMHAKind( +absl::StatusOr GetCudnnfMHAKind( const HloCustomCallInstruction* instr) { absl::string_view target = instr->custom_call_target(); if (target == kCudnnfMHABmmBmmCallTarget) return CudnnfMHAKind::kBmmBmm; @@ -258,7 +271,7 @@ StatusOr GetCudnnfMHAKind( return CudnnfMHAKind::kBackwardScaleBiasSoftmax; if (target == kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget) return CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout; - return InternalError("Unexpected call target: %s", target); + return Internal("Unexpected call target: %s", target); } std::string CudnnfMHAKindToString(CudnnfMHAKind kind) { @@ -303,7 +316,7 @@ std::string CudnnfMHAKindToString(CudnnfMHAKind kind) { } } -StatusOr GetFMHAInstructionPrefix( +absl::StatusOr GetFMHAInstructionPrefix( const std::string& custom_call_target) { if (custom_call_target == kCudnnfMHABmmBmmCallTarget) { return "fmha-bmm-bmm"; @@ -364,15 +377,15 @@ StatusOr GetFMHAInstructionPrefix( kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget) { return "fmha-bmm-scale-bias-softmax-dropout-bmm-backward"; } - return InternalError("Unexpected call target: %s", custom_call_target); + return Internal("Unexpected call target: %s", custom_call_target); } // Give fmha instruction a more useful name than "custom-call.42". -Status SetFMHAInstructionName(HloModule* module, HloInstruction* fmha) { +absl::Status SetFMHAInstructionName(HloModule* module, HloInstruction* fmha) { TF_ASSIGN_OR_RETURN(std::string fmha_prefix, GetFMHAInstructionPrefix(fmha->custom_call_target())); module->SetAndUniquifyInstrName(fmha, fmha_prefix); - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cublas_cudnn.h b/xla/service/gpu/cublas_cudnn.h index ed751bafe3bb2..c79e76b9a72b3 100644 --- a/xla/service/gpu/cublas_cudnn.h +++ b/xla/service/gpu/cublas_cudnn.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUBLAS_CUDNN_H_ #define XLA_SERVICE_GPU_CUBLAS_CUDNN_H_ +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -48,6 +52,12 @@ enum class CudnnConvKind { // => output }; +enum class CudnnNormKind { + kLayerForwardInfer, + kLayerForwardTrain, + kLayerBackward, +}; + enum class CudnnfMHAKind { kBmmBmm, kScaleBiasMaskSoftmax, @@ -69,7 +79,8 @@ enum class CudnnfMHAKind { kBackwardScaleBiasSoftmaxDropout, }; -StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); +absl::StatusOr GetCudnnConvKind( + const HloCustomCallInstruction* instr); // Converts a CudnnConvKind value to a string. std::string CudnnConvKindToString(CudnnConvKind kind); @@ -88,6 +99,9 @@ bool IsCublasLtMatmul(const HloInstruction& hlo); // Scaled matrix multiplication in FP8. Calls into cublasLt. bool IsCublasLtMatmulF8(const HloInstruction& hlo); +// Triangular solve that calls into legacy cublas. +bool IsTriangularSolve(const HloInstruction& hlo); + // A call to cuBLAS general matrix multiplication API. extern const absl::string_view kGemmCallTarget; @@ -199,10 +213,11 @@ bool IsFwdCustomCallTofMHA(const HloInstruction& hlo); bool IsBwdCustomCallTofMHA(const HloInstruction& hlo); bool IsCustomCallTofMHA(const HloInstruction& hlo); -StatusOr GetCudnnfMHAKind(const HloCustomCallInstruction* instr); +absl::StatusOr GetCudnnfMHAKind( + const HloCustomCallInstruction* instr); std::string CudnnfMHAKindToString(CudnnfMHAKind kind); -Status SetFMHAInstructionName(HloModule* module, HloInstruction* fmha); +absl::Status SetFMHAInstructionName(HloModule* module, HloInstruction* fmha); bool MHACallHasDropout(absl::string_view fmha_call_name); diff --git a/xla/service/gpu/cublas_pad_for_gemms.cc b/xla/service/gpu/cublas_pad_for_gemms.cc index c54473cf86fd7..050f219d12b6c 100644 --- a/xla/service/gpu/cublas_pad_for_gemms.cc +++ b/xla/service/gpu/cublas_pad_for_gemms.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,20 +15,32 @@ limitations under the License. #include "xla/service/gpu/cublas_pad_for_gemms.h" +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" +#include "xla/service/gpu/gemm_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" -#include "xla/window_util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -static StatusOr PadForGemm(HloDotInstruction* dot, PrimitiveType datatype, - int pad_to_multiple_of) { +static absl::StatusOr PadForGemm(HloDotInstruction* dot, + PrimitiveType datatype, + int pad_to_multiple_of) { auto* lhs = dot->mutable_operand(0); auto* rhs = dot->mutable_operand(1); @@ -155,7 +167,7 @@ bool CheckCanonical(HloDotInstruction* dot) { } // namespace static std::vector GetRelevantDots( - const se::CudaComputeCapability cuda_compute_capability, + const se::GpuComputeCapability& gpu_compute_capability, HloComputation* comp, PrimitiveType datatype) { std::vector gemms; @@ -168,8 +180,8 @@ static std::vector GetRelevantDots( ->config() .debug_options() .xla_gpu_enable_triton_gemm() && - CanTritonHandleGEMM(*dot, cuda_compute_capability) && - ShouldTritonHandleGEMM(*dot, cuda_compute_capability))) { + IsTritonSupportedInstruction(*dot, gpu_compute_capability) && + ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) { gemms.push_back(dot); } } @@ -177,14 +189,14 @@ static std::vector GetRelevantDots( return gemms; } -StatusOr CublasPadForGemms::Run( +absl::StatusOr CublasPadForGemms::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloDotInstruction* dot : - GetRelevantDots(cuda_compute_capability_, comp, datatype_)) { + GetRelevantDots(gpu_compute_capability_, comp, datatype_)) { TF_ASSIGN_OR_RETURN(bool result, PadForGemm(dot, datatype_, pad_to_multiple_of_)); changed |= result; diff --git a/xla/service/gpu/cublas_pad_for_gemms.h b/xla/service/gpu/cublas_pad_for_gemms.h index 758949c91eb9f..2a1f9c6f161cd 100644 --- a/xla/service/gpu/cublas_pad_for_gemms.h +++ b/xla/service/gpu/cublas_pad_for_gemms.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_ #define XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -31,21 +38,21 @@ namespace gpu { // so it should go strictly later. class CublasPadForGemms : public HloModulePass { public: - CublasPadForGemms(const se::CudaComputeCapability cuda_compute_capability, + CublasPadForGemms(const se::GpuComputeCapability gpu_compute_capability, PrimitiveType datatype, int32_t pad_to_multiple_of) - : cuda_compute_capability_(cuda_compute_capability), + : gpu_compute_capability_(gpu_compute_capability), datatype_(datatype), pad_to_multiple_of_(pad_to_multiple_of) {} absl::string_view name() const override { return "cublas-pad-for-gemms"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - const se::CudaComputeCapability cuda_compute_capability_; + const se::GpuComputeCapability gpu_compute_capability_; PrimitiveType datatype_; int32_t pad_to_multiple_of_; }; diff --git a/xla/service/gpu/cublas_pad_for_gemms_test.cc b/xla/service/gpu/cublas_pad_for_gemms_test.cc index 59aa5b572a1e8..d20dd94a06e7a 100644 --- a/xla/service/gpu/cublas_pad_for_gemms_test.cc +++ b/xla/service/gpu/cublas_pad_for_gemms_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,13 @@ limitations under the License. #include "xla/service/gpu/cublas_pad_for_gemms.h" +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" namespace m = ::xla::match; diff --git a/xla/service/gpu/cublas_padding_requirements.cc b/xla/service/gpu/cublas_padding_requirements.cc index 4c99defa5ff19..d9634a9145660 100644 --- a/xla/service/gpu/cublas_padding_requirements.cc +++ b/xla/service/gpu/cublas_padding_requirements.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,14 @@ limitations under the License. #include "xla/service/gpu/cublas_padding_requirements.h" -#include +#include +#include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/variant_visitor.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" namespace xla { @@ -27,34 +31,52 @@ namespace gpu { namespace { bool DimensionRequiresPadding(const int64_t size, const PrimitiveType data_type, - const se::CudaComputeCapability cc) { - for (const CublasPaddingRequirement& requirement : - CublasPaddingRequirements) { - if (cc.IsAtLeast(requirement.min_compute_capability) && - data_type == requirement.data_type && - size % requirement.multiple_of != 0) { + const se::GpuComputeCapability& gpu_cc) { + return std::visit( + VariantVisitor{ + [&](const se::CudaComputeCapability& cc) { + for (const auto& req : CublasPaddingRequirements) { + if (cc.IsAtLeast(req.min_compute_capability) && + data_type == req.data_type && size % req.multiple_of != 0) { + return true; + } + } + return false; + }, + [&](const se::RocmComputeCapability& cc) { + for (const auto& req : HipblasPaddingRequirements) { + if (data_type == req.data_type && size % req.multiple_of != 0) { + return true; + } + } + return false; + }}, + gpu_cc); +} + +bool ShapeRequiresPadding(const Shape& shape, int batch_dimensions_size, + const se::GpuComputeCapability& cc) { + // Non-batch dimensions requiring potential padding are placed at higher + // indices than batch dimensions. This is because dots are canonicalized prior + // to padding. + for (int i = batch_dimensions_size; i < shape.rank(); i++) { + if (DimensionRequiresPadding(shape.dimensions(i), shape.element_type(), + cc)) { return true; } } return false; } -bool ShapeRequiresPadding(const Shape& shape, - const se::CudaComputeCapability cc) { - // Since dots are canonicalized before padding only the last two dimensions - // of each operand represent non-batch dimensions and may need padding. - return DimensionRequiresPadding(shape.dimensions(shape.rank() - 1), - shape.element_type(), cc) || - DimensionRequiresPadding(shape.dimensions(shape.rank() - 2), - shape.element_type(), cc); -} - } // namespace bool CublasRequiresPadding(const HloDotInstruction& dot, - const se::CudaComputeCapability cc) { - return ShapeRequiresPadding(dot.operand(0)->shape(), cc) || - ShapeRequiresPadding(dot.operand(1)->shape(), cc); + const se::GpuComputeCapability& cc) { + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + return ShapeRequiresPadding(dot.operand(0)->shape(), + dim_numbers.lhs_batch_dimensions_size(), cc) || + ShapeRequiresPadding(dot.operand(1)->shape(), + dim_numbers.rhs_batch_dimensions_size(), cc); } } // namespace gpu diff --git a/xla/service/gpu/cublas_padding_requirements.h b/xla/service/gpu/cublas_padding_requirements.h index b0b44e61b57bb..6bfd23c9d6840 100644 --- a/xla/service/gpu/cublas_padding_requirements.h +++ b/xla/service/gpu/cublas_padding_requirements.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ limitations under the License. #define XLA_SERVICE_GPU_CUBLAS_PADDING_REQUIREMENTS_H_ #include -#include #include "xla/hlo/ir/hlo_instructions.h" #include "xla/stream_executor/device_description.h" @@ -31,15 +30,23 @@ struct CublasPaddingRequirement { int multiple_of; }; +struct HipblasPaddingRequirement { + PrimitiveType data_type; + int multiple_of; +}; + // List of padding requirements per compute capability and data type. constexpr std::array CublasPaddingRequirements{ {{se::CudaComputeCapability::VOLTA, S8, 4}, {se::CudaComputeCapability::VOLTA, F16, 8}, {se::CudaComputeCapability::AMPERE, BF16, 8}}}; +constexpr std::array HipblasPaddingRequirements{ + {{/*rocm gpu arch,*/ F16, 8}, {/*rocm gpu arch,*/ BF16, 8}}}; + // Tell if either of the operands of the dot requires padding. bool CublasRequiresPadding(const HloDotInstruction& dot, - se::CudaComputeCapability cc); + const se::GpuComputeCapability& cc); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 2cd279b680c5a..a1f9bc04a0d29 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,38 @@ limitations under the License. #include "xla/service/gpu/cudnn_fused_conv_rewriter.h" +#include #include +#include #include #include +#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/ml_dtypes.h" + #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cudnn/cudnn.h" @@ -142,11 +166,11 @@ bool IsLosslesslyConvertibleTo(const HloInstruction* instr, // The only reason Convert() should fail is if we don't support converting // from x to y, which indeed means it's not losslessly-convertible. - StatusOr converted1 = instr->literal().Convert(dst_ty); + absl::StatusOr converted1 = instr->literal().Convert(dst_ty); if (!converted1.ok()) { return false; } - StatusOr converted2 = converted1->Convert(orig_ty); + absl::StatusOr converted2 = converted1->Convert(orig_ty); if (!converted2.ok()) { return false; } @@ -175,7 +199,8 @@ bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) { // conv-bias-activation. If it's already a conv-bias-activation, does nothing. // // If `conv` is anything else, returns an error. -StatusOr EnsureIsConvBiasActivation(HloInstruction* conv) { +absl::StatusOr EnsureIsConvBiasActivation( + HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall); if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) { @@ -217,9 +242,9 @@ StatusOr EnsureIsConvBiasActivation(HloInstruction* conv) { // convert(gte(custom-call(int8_x, int8_w))) -> // gte(custom-call(int8_x, int8_w)) -StatusOr FuseConvertTypeIntoConv(HloComputation* comp, - PrimitiveType conv_type, - PrimitiveType cvt_type) { +absl::StatusOr FuseConvertTypeIntoConv(HloComputation* comp, + PrimitiveType conv_type, + PrimitiveType cvt_type) { bool changed = false; for (auto instr : comp->MakeInstructionPostOrder()) { HloInstruction* conv = nullptr; @@ -261,7 +286,7 @@ struct ConvConvertTypes { // (custom call) to be the same as the conversion result. // For example: convert(gte(custom-call(int8_x, int8_w))) -> // gte(custom-call(int8_x, int8_w)) -StatusOr FuseRemoveConvertInConv(HloComputation* comp) { +absl::StatusOr FuseRemoveConvertInConv(HloComputation* comp) { bool changed = false; // Note: We are eliminating F16->F32 because it fails on internal tests. std::array types{{ @@ -279,7 +304,7 @@ StatusOr FuseRemoveConvertInConv(HloComputation* comp) { // alpha * gte(custom-call(...)) -> // gte(custom-call(..., backend_config={alpha})). -StatusOr FuseConvAlpha(HloComputation* comp) { +absl::StatusOr FuseConvAlpha(HloComputation* comp) { bool changed = false; for (auto instr : comp->MakeInstructionPostOrder()) { HloInstruction* conv = nullptr; @@ -302,8 +327,11 @@ StatusOr FuseConvAlpha(HloComputation* comp) { continue; } - TF_ASSIGN_OR_RETURN(auto config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); + if (config.conv_result_scale() != 1) { continue; } @@ -319,8 +347,7 @@ StatusOr FuseConvAlpha(HloComputation* comp) { TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_conv_result_scale(alpha_f64.GetFirstElement()); - - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte)); changed = true; @@ -575,8 +602,9 @@ void CaptureConvGraphRecursive(HloInstruction* instr, // Captures in a GraphString the subgraph of pointwise operations operating on // the convolution that will be fused into the cuDNN convolution Custom Call. -StatusOr, std::vector, - GraphString, HloInstruction*>> +absl::StatusOr< + std::tuple, std::vector, + GraphString, HloInstruction*>> CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution, HloInstruction* wide_input, HloInstruction* wide_filter, HloInstruction* input_scale, HloInstruction* filter_scale, @@ -629,7 +657,8 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution, // multiplying or dividing by a broadcast scalar. // 5. Optionally calculate the maximum of the absolute of the result. // 6. Optionally cast the output back to FP8. -StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { +absl::StatusOr F8GraphConv(HloComputation* comp, + se::CudaComputeCapability cc) { bool changed = false; #if CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900 @@ -692,8 +721,11 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { filter_scale_op ? filter_scale_op->opcode() == HloOpcode::kMultiply : false)); - TF_ASSIGN_OR_RETURN( - auto config, convolution->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + convolution->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); + config.set_serialized_graph(graph_string.Graph()); operands.insert(operands.begin(), input); operands.insert(operands.begin() + 1, filter); @@ -713,7 +745,7 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { ShapeUtil::MakeTupleShape(output_shapes), operands)); new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget); - TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); + TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config)); TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, MakeGetTupleElementHlo(new_convolution, 0)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); @@ -731,7 +763,7 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { return changed; } -StatusOr FuseBiasOrSideInput(HloComputation* comp) { +absl::StatusOr FuseBiasOrSideInput(HloComputation* comp) { bool changed = false; for (auto instr : comp->MakeInstructionPostOrder()) { HloInstruction* conv = nullptr; @@ -762,8 +794,10 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { // Can't fuse bias or side-input if the conv already has a relu (or other // activation), because bias and side-input are added before the activation // is applied. - TF_ASSIGN_OR_RETURN(auto config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { continue; } @@ -823,7 +857,7 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { HloInstruction* new_conv = comp->AddInstruction( conv->CloneWithNewOperands(conv->shape(), new_operands)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); TF_ASSIGN_OR_RETURN(HloInstruction * new_instr, MakeGetTupleElementHlo(new_conv, 0)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); @@ -842,7 +876,7 @@ StatusOr FuseBiasOrSideInput(HloComputation* comp) { // // where `reshape` can be an arbitrary chain of reshapes+transposes. This idiom // is created by the ReshapeMover pass. -StatusOr FuseSideInputAlpha(HloComputation* comp) { +absl::StatusOr FuseSideInputAlpha(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* conv; @@ -853,8 +887,10 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { if (!Match(instr, pattern)) { continue; } - TF_ASSIGN_OR_RETURN(auto config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.side_input_scale() != 1) { continue; } @@ -939,7 +975,7 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_side_input_scale(alpha_f64.GetFirstElement()); - TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv)); changed = true; @@ -947,7 +983,8 @@ StatusOr FuseSideInputAlpha(HloComputation* comp) { return changed; } -StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { +absl::StatusOr FuseElu(HloComputation* comp, + se::CudaComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -988,8 +1025,10 @@ StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { continue; } - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { continue; } @@ -1001,14 +1040,14 @@ StatusOr FuseElu(HloComputation* comp, se::CudaComputeCapability cc) { } TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kElu); - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; } -StatusOr FuseRelu(HloComputation* comp) { +absl::StatusOr FuseRelu(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte; @@ -1023,8 +1062,10 @@ StatusOr FuseRelu(HloComputation* comp) { .WithOneUse()))) { continue; } - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { continue; } @@ -1036,14 +1077,15 @@ StatusOr FuseRelu(HloComputation* comp) { } TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kRelu); - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); changed = true; } return changed; } -StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { +absl::StatusOr FuseRelu6(HloComputation* comp, + se::CudaComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -1065,8 +1107,10 @@ StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { m::Broadcast(m::ConstantEffectiveScalar(6))))) { continue; } - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { continue; } @@ -1082,15 +1126,15 @@ StatusOr FuseRelu6(HloComputation* comp, se::CudaComputeCapability cc) { } TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kRelu6); - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); changed = true; } return changed; } -StatusOr FuseLeakyRelu(HloComputation* comp, - se::CudaComputeCapability cc) { +absl::StatusOr FuseLeakyRelu(HloComputation* comp, + se::CudaComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -1122,8 +1166,10 @@ StatusOr FuseLeakyRelu(HloComputation* comp, continue; } - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { continue; } @@ -1141,14 +1187,14 @@ StatusOr FuseLeakyRelu(HloComputation* comp, config.set_activation_mode(se::dnn::kLeakyRelu); TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_leakyrelu_alpha(alpha_f64.GetFirstElement()); - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; } -StatusOr FuseConvertToF16(HloComputation* comp) { +absl::StatusOr FuseConvertToF16(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte = nullptr; @@ -1208,7 +1254,7 @@ StatusOr FuseConvertToF16(HloComputation* comp) { return changed; } -StatusOr FuseConvertToS8(HloComputation* comp) { +absl::StatusOr FuseConvertToS8(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte = nullptr; @@ -1289,7 +1335,7 @@ StatusOr FuseConvertToS8(HloComputation* comp) { return changed; } -Status CheckNoIllegalIntegerConvs(HloComputation* comp) { +absl::Status CheckNoIllegalIntegerConvs(HloComputation* comp) { auto is_integral_not_s8 = [](const Shape& s) { return primitive_util::IsIntegralType(s.element_type()) && s.element_type() != S8; @@ -1310,7 +1356,7 @@ Status CheckNoIllegalIntegerConvs(HloComputation* comp) { } if (bad_convs.empty()) { - return OkStatus(); + return absl::OkStatus(); } return Unimplemented( @@ -1394,21 +1440,22 @@ void VlogStats(HloModule* module) { ++stats["22 convs with side-input"]; } - auto config = instr->backend_config(); - if (!config.ok()) { + auto gpu_config = instr->backend_config(); + if (!gpu_config.ok()) { LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString(); continue; } - - if (config->conv_result_scale() != 1) { + const CudnnConvBackendConfig& config = + gpu_config->cudnn_conv_backend_config(); + if (config.conv_result_scale() != 1) { ++stats["30 convs with result scale"]; } - if (config->side_input_scale() != 0 && config->side_input_scale() != 1) { + if (config.side_input_scale() != 0 && config.side_input_scale() != 1) { ++stats["31 convs with side-input scale"]; } ++stats[absl::StrCat( "32 convs with activation mode ", - se::dnn::ActivationMode_Name(config->activation_mode()))]; + se::dnn::ActivationMode_Name(config.activation_mode()))]; } } @@ -1423,7 +1470,7 @@ void VlogStats(HloModule* module) { } // namespace -StatusOr CudnnFusedConvRewriter::Run( +absl::StatusOr CudnnFusedConvRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool any_changed = false; diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.h b/xla/service/gpu/cudnn_fused_conv_rewriter.h index 67c40792df08a..bc7291d262a61 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.h +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ #define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -102,7 +106,7 @@ class CudnnFusedConvRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 9a88bbbda2e85..4a55b9dc9eb40 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,29 @@ limitations under the License. #include "xla/service/gpu/cudnn_fused_conv_rewriter.h" #include +#include #include #include +#include +#include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cudnn/cudnn.h" #endif #include "xla/service/algebraic_simplifier.h" @@ -135,7 +150,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01})) << pre_hlo_string; - StatusOr filecheck_result = + absl::StatusOr filecheck_result = RunFileCheck(optimized_hlo_string, post_hlo_string); ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); EXPECT_TRUE(*filecheck_result); @@ -166,7 +181,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.15, 0.15})) << pre_hlo_string; - StatusOr filecheck_result = + absl::StatusOr filecheck_result = RunFileCheck(optimized_hlo_string, custom_call_string); ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); EXPECT_TRUE(*filecheck_result); @@ -1264,8 +1279,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { m::ConstantEffectiveScalar(0).WithElementType(F32))), 0) .WithShape(S8, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu); } @@ -1362,8 +1378,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})), 0) .WithShape(S8, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.side_input_scale(), 0.25); } @@ -1402,8 +1419,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { m::CustomCall(&conv, {kCudnnConvBiasActivationForwardCallTarget}), 0) .WithShape(F32, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 42); } @@ -1441,8 +1459,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { m::Parameter(0), m::Parameter(1), m::Parameter(2)), 0) .WithShape(F32, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu); } @@ -1484,8 +1503,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { 0) .WithShape(F32, {1, 32, 9, 9})), m::Minimum()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1529,8 +1549,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { m::Parameter(0), m::Parameter(1), m::Parameter(2)), 0) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kElu); } @@ -1553,7 +1574,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { expm1 = exponential-minus-one(sum) elu = select(cmp, sum, expm1) not_elu = minimum(sum, zeros) - ROOT root = tuple(elu, not_elu) + ROOT root = tuple(elu, not_elu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); DebugOptions debug_opts = m->config().debug_options(); @@ -1584,8 +1605,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { .WithPredicate(HloPredicateIsOp) .WithOperand(0, gte_pattern)), m::Minimum()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1626,8 +1648,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { m::Parameter(0), m::Parameter(1), m::Parameter(2)), 0) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu6); } @@ -1672,8 +1695,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { .WithShape(F16, {1, 32, 9, 9}), m::Broadcast(m::ConstantEffectiveScalar(6))), m::Minimum()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1715,8 +1739,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { m::Parameter(0), m::Parameter(1), m::Parameter(2)), 0) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kLeakyRelu); } @@ -1737,7 +1762,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { mul = multiply(sum, alphas) leaky_relu = select(cmp, sum, mul) not_leaky_relu = minimum(sum, zeros) - ROOT root = tuple(leaky_relu, not_leaky_relu) + ROOT root = tuple(leaky_relu, not_leaky_relu) })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); DebugOptions debug_opts = m->config().debug_options(); @@ -1768,8 +1793,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { m::Multiply(gte_pattern, m::Broadcast(m::ConstantEffectiveScalar()))), m::Minimum()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1807,8 +1833,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { m::Broadcast(m::Parameter(3)), m::GetTupleElement(m::CustomCall(&conv2), 0)))))); EXPECT_EQ(conv1, conv2); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv1->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv1->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1843,8 +1870,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { m::AddAnyOrder(m::Broadcast(m::Parameter(2)), m::GetTupleElement(m::CustomCall(&conv2), 0))))); EXPECT_EQ(conv1, conv2); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv1->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv1->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -1880,8 +1908,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { m::CustomCall(&conv, m::Parameter(0), m::Parameter(1), m::Broadcast(m::ConstantEffectiveScalar(0))), 0)))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu); } @@ -1915,8 +1944,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { m::GetTupleElement(m::CustomCall( &conv, m::Parameter(0), m::Parameter(1), m::Broadcast(m::ConstantEffectiveScalar(0))))))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu); } @@ -1951,8 +1981,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { m::AddAnyOrder(m::Parameter(2), m::GetTupleElement(m::CustomCall(&conv2), 0))))); EXPECT_EQ(conv1, conv2); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv1->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv1->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.conv_result_scale(), 1); EXPECT_EQ(config.activation_mode(), se::dnn::kNone); } @@ -2190,8 +2221,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { m::Parameter(2)), 0) .WithShape(F32, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2231,8 +2263,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { m::Parameter(2)), 0) .WithShape(F32, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 42); } @@ -2270,8 +2303,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { m::Parameter(3)), 0) .WithShape(F32, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2352,8 +2386,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { m::Parameter(3)), 0) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2400,8 +2435,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { m::Parameter(2), m::Reshape(m::Parameter(3))), 0) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2453,8 +2489,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { m::Convert(m::Parameter(3)).WithElementType(F32)), 0)) .WithShape(F16, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2505,8 +2542,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { m::Parameter(1), m::Constant().WithElementType(F16)), 0) .WithShape(F16, {1, 2, 2, 2}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2563,8 +2601,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { 0) .WithShape(F32, {1, 2, 2, 2})) .WithElementType(F16))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.side_input_scale(), 1); } @@ -2617,8 +2656,9 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { .WithShape(F32, {32})), 0) .WithShape(S8, {1, 32, 9, 9}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - conv->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + conv->backend_config()); + const CudnnConvBackendConfig& config = gpu_config.cudnn_conv_backend_config(); EXPECT_EQ(config.activation_mode(), se::dnn::kRelu); } diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc index c5c1f77995a53..632eb42f2b5bb 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,19 @@ limitations under the License. #include #include +#include #include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -36,19 +40,26 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/types.h" +#include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace xla { namespace gpu { namespace { @@ -63,10 +74,18 @@ struct MatchFwdResult { HloInstruction* matched_mask = nullptr; HloInstruction* matched_scale = nullptr; HloInstruction* matched_softmax_input = nullptr; + HloInstruction* matched_reduce_sum = nullptr; double matched_dropout_rate = 0.0; bool need_canonicalization = false; bool is_training = false; + // We use this to keep track of whether the bias or the mask that is being + // applied to the bmm1 is a causal mask, cuDNN can generate causal mask inside + // the attention kernel to save I/O. + bool is_causal_mask = false; + // We use this to keep track of whether the attention block should be lowered + // to flash attention or regular fused attention in cuDNN. + bool is_flash_attention = false; bool has_match = false; std::string matched_custom_call_name; }; @@ -79,7 +98,7 @@ struct MatchBwdResult { HloInstruction* matched_bmm_2_grad_1 = nullptr; HloInstruction* matched_bmm_2_grad_2 = nullptr; - HloInstruction* matched_d_intermediate = nullptr; + HloInstruction* matched_dbias = nullptr; // We use this to keep track of all gradient bmms that need // canonicalization. bool bmm_1_grad_1_need_canonicalization = false; @@ -212,22 +231,33 @@ auto GetUnfusedReduceMaxSumSoftmaxPattern( HloInstruction** softmax_reduce_sum = nullptr, HloInstruction** softmax_reduce_sum_bcast = nullptr) { // The reduce-max part of the softmax - auto unfused_softmax_max_subpattern = m::SharedSubpattern(m::Subtract( - m::Op(), m::Broadcast(OptionalConvert(OptionalConvert( - m::Op() - .WithPredicate(IsReduceMax) - .WithOperand(0, OptionalBitcast(OptionalConvert( - m::Op(softmax_input))))))))); + // reduce_max and subtract will always have exactly 1 user + // in both training and inference + // softmax_input should always have exactly 2 users + auto unfused_softmax_max_subpattern = m::SharedSubpattern( + m::Subtract( + m::Op(), + m::Broadcast(OptionalConvert( + m::Op() + .WithPredicate(IsReduceMax) + .WithOneUse() + .WithOperand(0, OptionalBitcast(OptionalConvert( + m::Op(softmax_input).WithNumUser(2))))))) + .WithOneUse()); // The reduce-add part of the softmax + // reduce_sum and reduce_sum_broadcast should have 2 users in training + // and 1 user in inference auto unfused_softmax_sum_subpattern = m::SharedSubpattern(m::Divide( OptionalBitcast(m::Exp(unfused_softmax_max_subpattern)), m::Broadcast( softmax_reduce_sum_bcast, - OptionalConvert(OptionalConvert( + OptionalConvert( m::Op(softmax_reduce_sum) .WithOperand(0, OptionalBitcast(OptionalConvert( m::Exp(unfused_softmax_max_subpattern)))) - .WithPredicate(IsReduceSum)))))); + .WithPredicate(IsReduceSum) + .WithAtMostNumUser(2))) + .WithAtMostNumUser(2))); return unfused_softmax_sum_subpattern; } @@ -264,29 +294,20 @@ double GetDropoutRateFromHlo(HloInstruction* dropout) { bool IsComputeCapabilityAndCudnnSupported( stream_executor::CudaComputeCapability cc, stream_executor::dnn::VersionInfo cudnn_version, - stream_executor::StreamExecutor* stream_exec, stream_executor::dnn::VersionInfo supported_cudnn_version) { - se::dnn::VersionInfo real_cudnn_version; - if (stream_exec) { - stream_executor::dnn::DnnSupport* dnn = stream_exec->AsDnn(); - StatusOr se_cudnn_version = dnn->GetVersion(); - if (se_cudnn_version.ok()) { - real_cudnn_version = (*se_cudnn_version); - } - } else { - real_cudnn_version = cudnn_version; - } - if (!((cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0) && - (real_cudnn_version >= supported_cudnn_version))) { - VLOG(2) << absl::StrFormat( - "CudnnFusedMHARewriter did not run. Unsupported compute " - "capability(==8.0) or cudnn version(>=%d.%d.%d)", - supported_cudnn_version.major_version(), - supported_cudnn_version.minor_version(), - supported_cudnn_version.patch()); - return false; + // Enforce capability minor == 0 because hardware with a non-zero minor + // number typically has insufficient shared memory for cuDNN FMHA. + if (cc.IsAtLeastAmpere() && cc.minor == 0 && + cudnn_version >= supported_cudnn_version) { + return true; } - return true; + VLOG(2) << absl::StrFormat( + "CudnnFusedMHARewriter did not run. Unsupported compute " + "capability(%s; major should be >= 8, minor should be 0) or cudnn version" + "(%s; should be >= %s)", + cc.ToString(), cudnn_version.ToString(), + supported_cudnn_version.ToString()); + return false; } bool IsSupportedPrimitiveType(const HloInstruction* bmm) { @@ -294,17 +315,6 @@ bool IsSupportedPrimitiveType(const HloInstruction* bmm) { return dtype == BF16 || dtype == F16; } -bool IsContractingDimSupported(absl::Span contracting_dims) { - return absl::c_all_of(contracting_dims, - [](int64_t dim) { return dim == 64; }); -} - -bool IsNonContractingDimSupported( - const std::vector& non_contracting_dims) { - return absl::c_all_of(non_contracting_dims, - [](int64_t dim) { return dim <= 512; }); -} - std::vector GetDimensionVector(absl::Span dimensions, absl::Span dim_nums) { std::vector vec(dim_nums.size()); @@ -314,92 +324,200 @@ std::vector GetDimensionVector(absl::Span dimensions, return vec; } -StatusOr IsSupportedBMM1(const HloInstruction* bmm_1) { - const DotDimensionNumbers& dot_dims_bmm1 = bmm_1->dot_dimension_numbers(); +struct QKVLayout { + int64_t batch; + int64_t num_heads; + int64_t seqlen_q; + int64_t seqlen_kv; + int64_t hidden_dim; +}; + +absl::StatusOr> GetQKVLayout( + HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization) { + // get layout from bmm1 + const DotDimensionNumbers& bmm1_dnums = bmm_1->dot_dimension_numbers(); TF_ASSIGN_OR_RETURN( - std::vector lhs_non_contracting_dim_nums_bmm1, + std::vector bmm1_s_q_dims, GetNonContractingDims(bmm_1->operand(0)->shape(), - dot_dims_bmm1.lhs_batch_dimensions(), - dot_dims_bmm1.lhs_contracting_dimensions())); + bmm1_dnums.lhs_batch_dimensions(), + bmm1_dnums.lhs_contracting_dimensions())); + TF_ASSIGN_OR_RETURN( - std::vector rhs_non_contracting_dim_nums_bmm1, + std::vector bmm1_s_kv_dims, GetNonContractingDims(bmm_1->operand(1)->shape(), - dot_dims_bmm1.rhs_batch_dimensions(), - dot_dims_bmm1.rhs_contracting_dimensions())); - std::vector lhs_non_contracting_dims_bmm1 = + bmm1_dnums.rhs_batch_dimensions(), + bmm1_dnums.rhs_contracting_dimensions())); + + std::vector bmm1_bh = GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), - lhs_non_contracting_dim_nums_bmm1); - std::vector rhs_non_contracting_dims_bmm1 = - GetDimensionVector(bmm_1->operand(1)->shape().dimensions(), - rhs_non_contracting_dim_nums_bmm1); - // The non contracting dimensions for BMM1 need to be less than or equal to - // 512. - if (!IsNonContractingDimSupported(lhs_non_contracting_dims_bmm1) || - !IsNonContractingDimSupported(rhs_non_contracting_dims_bmm1)) { - if (VLOG_IS_ON(2)) { - VLOG(2) << "BMM1 lhs_non_contracting_dims: " - << absl::StrJoin(lhs_non_contracting_dims_bmm1, ",") - << " BMM1 rhs_non_contracting_dims: " - << absl::StrJoin(rhs_non_contracting_dims_bmm1, ",") - << " are not supported. The non-contracting dims should be less " - "than 512. This is a criteria for current cuDNN 8.8 support."; - } + bmm1_dnums.lhs_batch_dimensions()); + + std::vector bmm1_s_q = GetDimensionVector( + bmm_1->operand(0)->shape().dimensions(), bmm1_s_q_dims); + + std::vector bmm1_s_kv = GetDimensionVector( + bmm_1->operand(1)->shape().dimensions(), bmm1_s_kv_dims); + + std::vector bmm1_d = + GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), + bmm1_dnums.lhs_contracting_dimensions()); + + TF_RET_CHECK(bmm1_bh.size() == 2); + TF_RET_CHECK(bmm1_s_q.size() == 1); + TF_RET_CHECK(bmm1_s_kv.size() == 1); + TF_RET_CHECK(bmm1_d.size() == 1); + + // get layout from bmm2 + const DotDimensionNumbers& bmm2_dnums = bmm_2->dot_dimension_numbers(); + TF_ASSIGN_OR_RETURN( + std::vector bmm2_lhs_non_contracting_dims, + GetNonContractingDims(bmm_2->operand(0)->shape(), + bmm2_dnums.lhs_batch_dimensions(), + bmm2_dnums.lhs_contracting_dimensions())); + + TF_ASSIGN_OR_RETURN( + std::vector bmm2_rhs_non_contracting_dims, + GetNonContractingDims(bmm_2->operand(1)->shape(), + bmm2_dnums.rhs_batch_dimensions(), + bmm2_dnums.rhs_contracting_dimensions())); + + std::vector bmm2_bh = + GetDimensionVector(bmm_2->operand(0)->shape().dimensions(), + bmm2_dnums.lhs_batch_dimensions()); + + std::vector bmm2_s_kv = + GetDimensionVector(bmm_2->operand(0)->shape().dimensions(), + bmm2_dnums.lhs_contracting_dimensions()); + + std::vector bmm2_s_q = + need_canonicalization + ? GetDimensionVector(bmm_2->operand(1)->shape().dimensions(), + bmm2_rhs_non_contracting_dims) + : GetDimensionVector(bmm_2->operand(0)->shape().dimensions(), + bmm2_lhs_non_contracting_dims); + + std::vector bmm2_d = + need_canonicalization + ? GetDimensionVector(bmm_2->operand(0)->shape().dimensions(), + bmm2_lhs_non_contracting_dims) + : GetDimensionVector(bmm_2->operand(1)->shape().dimensions(), + bmm2_rhs_non_contracting_dims); + + TF_RET_CHECK(bmm2_bh.size() == 2); + TF_RET_CHECK(bmm2_s_q.size() == 1); + TF_RET_CHECK(bmm2_s_kv.size() == 1); + TF_RET_CHECK(bmm2_d.size() == 1); + + // check if bhsd is correct between bmm1 and bmm2 + if (bmm1_bh[0] != bmm2_bh[0] || bmm1_bh[1] != bmm2_bh[1] || + bmm1_s_q[0] != bmm2_s_q[0] || bmm1_s_kv[0] != bmm2_s_kv[0] || + bmm1_d[0] != bmm2_d[0]) { + return std::nullopt; + } + + QKVLayout qkv_layout; + qkv_layout.batch = bmm1_bh[0]; + qkv_layout.num_heads = bmm1_bh[1]; + qkv_layout.seqlen_q = bmm1_s_q[0]; + qkv_layout.seqlen_kv = bmm1_s_kv[0]; + qkv_layout.hidden_dim = bmm1_d[0]; + return qkv_layout; +} + +absl::StatusOr IsFusedAttention( + QKVLayout qkv_layout, bool is_training, + stream_executor::CudaComputeCapability cc, + stream_executor::dnn::VersionInfo cudnn_version) { + // otherwise check if it is supported by regular attention + int64_t s_q = qkv_layout.seqlen_q; + int64_t s_kv = qkv_layout.seqlen_kv; + int64_t hidden_dim = qkv_layout.hidden_dim; + bool is_seqlen_supported = + (s_q <= 512 && s_kv <= 512) && + (!is_training || (s_q % 64 == 0 && s_kv % 64 == 0)); + bool is_hidden_dim_supported = hidden_dim == 64; + bool is_fused_attention = is_seqlen_supported && is_hidden_dim_supported; + return is_fused_attention; +} + +absl::StatusOr IsFlashAttention( + QKVLayout qkv_layout, bool is_training, + stream_executor::CudaComputeCapability cc, + stream_executor::dnn::VersionInfo cudnn_version) { + int64_t s_q = qkv_layout.seqlen_q; + int64_t s_kv = qkv_layout.seqlen_kv; + int64_t hidden_dim = qkv_layout.hidden_dim; + // start with most relaxed constraint + bool is_seqlen_supported = (s_q > 512 || s_kv > 512) && + (!is_training || (s_q % 2 == 0 && s_kv % 2 == 0)); + bool is_hidden_dim_supported = hidden_dim <= 128 && hidden_dim % 8 == 0; + bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported; + if (!is_flash_attention) return false; + // going backwards to check compatibility + if ((is_training && (s_q < 64 || s_kv < 64)) && + !IsComputeCapabilityAndCudnnSupported( + cc, cudnn_version, stream_executor::dnn::VersionInfo(9, 0, 0))) { + VLOG(2) << "Flash attention training with seq < 64 not supported cuDNN < " + "9.0.0."; return false; } - std::vector lhs_contracting_dims_bmm1 = - GetDimensionVector(bmm_1->operand(0)->shape().dimensions(), - dot_dims_bmm1.lhs_contracting_dimensions()); - std::vector rhs_contracting_dims_bmm1 = - GetDimensionVector(bmm_1->operand(1)->shape().dimensions(), - dot_dims_bmm1.rhs_contracting_dimensions()); - - // The contracting dimensions for BMM1 need to be 64. - if (!IsContractingDimSupported(lhs_contracting_dims_bmm1) || - !IsContractingDimSupported(rhs_contracting_dims_bmm1)) { - if (VLOG_IS_ON(2)) { - VLOG(2) << "BMM1 lhs_contracting_dims: " - << absl::StrJoin(lhs_contracting_dims_bmm1, ",") - << " BMM1 rhs_contracting_dims: " - << absl::StrJoin(rhs_contracting_dims_bmm1, ",") - << " are not supported."; - } + if ((hidden_dim != 64 && hidden_dim != 128) && + !IsComputeCapabilityAndCudnnSupported( + cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 6))) { + VLOG(2) << "Flash attention head dim != 64 or 128 not supported with cuDNN " + "< 8.9.6."; return false; } - return true; -} -StatusOr IsSupportedBMM2(const HloInstruction* bmm_2, - bool need_canonicalization) { - const DotDimensionNumbers& dot_dims_bmm2 = bmm_2->dot_dimension_numbers(); - // need swap lhs and rhs for bmm2 if canonicalization is needed - int operand_index = need_canonicalization ? 0 : 1; - auto batch_dim = need_canonicalization ? dot_dims_bmm2.lhs_batch_dimensions() - : dot_dims_bmm2.rhs_batch_dimensions(); - auto contracting_dim = need_canonicalization - ? dot_dims_bmm2.lhs_contracting_dimensions() - : dot_dims_bmm2.rhs_contracting_dimensions(); + if ((is_training && s_kv % 64 != 0) && + !IsComputeCapabilityAndCudnnSupported( + cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5))) { + VLOG(2) << "Flash attention training with seq kv % 64 != 0 not supported " + "with cuDNN < 8.9.5."; + return false; + } - TF_ASSIGN_OR_RETURN( - std::vector non_contracting_dim_nums_bmm2, - GetNonContractingDims(bmm_2->operand(operand_index)->shape(), batch_dim, - contracting_dim)); - - std::vector non_contracting_dims_bmm2 = - GetDimensionVector(bmm_2->operand(operand_index)->shape().dimensions(), - non_contracting_dim_nums_bmm2); - // The non contracting dimension for BMM2 needs to be 64 for the input matrix. - // The input matrix is the second argument to BMM2 i.e, rhs. - if (!absl::c_all_of(non_contracting_dims_bmm2, - [](int64_t dim) { return dim == 64; })) { - if (VLOG_IS_ON(2)) { - VLOG(2) << " BMM2 rhs_non_contracting_dims: " - << absl::StrJoin(non_contracting_dims_bmm2, ",") - << " are not supported."; - } + if (!IsComputeCapabilityAndCudnnSupported( + cc, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 4))) { + VLOG(2) << "Require cuDNN 8.9.4 to run flash attention."; return false; } - return true; + return is_flash_attention; +} + +bool IsCausalMaskPattern(HloInstruction* mask) { + auto causal_mask = + m::Select(m::Compare(m::Iota(), m::Iota()), m::Broadcast(m::Constant()), + m::Broadcast(m::Constant())); + auto causal_mask_pattern_fwd_remat = + m::Broadcast(OptionalBitcast(causal_mask)); + auto causal_mask_pattern_bwd = m::Broadcast(m::Convert(OptionalBitcast( + m::Minimum(m::Op(), m::Broadcast(OptionalBitcast(causal_mask)))))); + HloInstruction* param = nullptr; + HloInstruction* gte = nullptr; + auto causal_mask_pattern_fwd = m::Broadcast( + OptionalBitcast(m::GetTupleElement(>e, m::Parameter(¶m)))); + auto causal_mask_pattern = m::AnyOf( + causal_mask_pattern_fwd_remat, causal_mask_pattern_fwd, + causal_mask_pattern_bwd); + if (Match(mask, causal_mask_pattern)) { + if (param != nullptr && param->parent()->IsWhileBodyComputation()) { + // need to track to outside of the while loop body to find the real mask. + auto while_instr = param->parent()->WhileCallInstruction(); + auto mask_index = gte->tuple_index(); + auto actual_mask = + while_instr->mutable_operand(0)->mutable_operand(mask_index); + auto causal_mask_pattern_fwd = + OptionalBitcast(m::Convert(m::MinimumAnyOrder( + m::Op(), + OptionalBitcast(m::MinimumAnyOrder( + m::Op(), m::Broadcast(OptionalBitcast(causal_mask))))))); + return Match(actual_mask, causal_mask_pattern_fwd); + } + return true; + } + return false; } MatchFwdResult MatchDefaultFwdBmmBmm(MatchFwdResult previous_result, @@ -409,12 +527,16 @@ MatchFwdResult MatchDefaultFwdBmmBmm(MatchFwdResult previous_result, // Try matching default bmm1-bmm2 pattern HloInstruction* bmm_1; HloInstruction* bmm_2; - + // bmm1 should have at most 2 users at this case + // 1. 1 user(bmm2) in case of inference + // 2. 2 users(bmm2 and backward bmm) in case of training auto default_bmm_bmm_pattern = m::Op(&bmm_2) .WithPredicate(IsBatchedMatmul) .WithOperand(bmm2_operand_position, - m::Op(&bmm_1).WithPredicate(IsBatchedMatmul)); + m::Op(&bmm_1) + .WithPredicate(IsBatchedMatmul) + .WithAtMostNumUser(2)); // If any of bmm1's operands is coming from a forward fMHA call, then return // false @@ -472,6 +594,14 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar)), m::Op()))))))))); + // Form3 -> softmax - mul(dropout) - mul(scale) - BMM2 + auto dropout_softmax_pattern_form_3 = m::MultiplyAnyOrder( + m::MultiplyAnyOrder( + OptionalConvert(GetUnfusedReduceMaxSumSoftmaxPattern( + &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast)), + m::Op()), + m::Broadcast(m::Constant(&dropout).WithPredicate(IsScalar))); + // Try matching BMM1 - (Scale) - (Bias) - (Mask) - Softmax - (Dropout) - // BMM2 Dropout with non-zero drop rate has select(divide(softmax_output, // broadcast(1-dropout_rate))) @@ -485,7 +615,8 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, &softmax_input, &softmax_reduce_sum, &softmax_reduce_sum_bcast))), dropout_softmax_pattern_form_1, - dropout_softmax_pattern_form_2)); + dropout_softmax_pattern_form_2, + dropout_softmax_pattern_form_3)); if (!Match(instr, softmax_dropout_bmm2_pattern) || !IsSupportedPrimitiveType(bmm_2)) { @@ -502,6 +633,7 @@ MatchFwdResult MatchSoftmaxDropoutBmm(MatchFwdResult previous_result, match_result.matched_dropout_rate = GetDropoutRateFromHlo(dropout); } match_result.matched_softmax_input = softmax_input; + match_result.matched_reduce_sum = softmax_reduce_sum; match_result.has_match = true; return match_result; } @@ -513,17 +645,20 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result, HloInstruction* bmm_1; HloInstruction* bias = nullptr; HloInstruction* scale = nullptr; - + // bmm1/scale/bias add should have 2 users if being connected to softmax + // otherwise should have exactly 1 user auto first_bmm_pattern = m::SharedSubpattern(m::Op(&bmm_1).WithPredicate(IsBatchedMatmul)); auto unfused_scaled_bmm_subpattern = m::MultiplyAnyOrder( - OptionalConvert(first_bmm_pattern), + OptionalConvert(first_bmm_pattern.WithOneUse()), OptionalConvert( m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar)))); - if (Match(softmax_input, - OptionalConvert(OptionalBitcast(first_bmm_pattern)))) { + OptionalConvert(OptionalBitcast(m::AnyOf( + first_bmm_pattern, unfused_scaled_bmm_subpattern))))) { + // bmm1 - (scale) - softmax match_result.matched_bmm_1 = bmm_1; + match_result.matched_scale = scale; match_result.matched_custom_call_name = has_dropout ? kCudnnfMHASoftmaxDropoutCallTarget : kCudnnfMHASoftmaxCallTarget; @@ -531,14 +666,17 @@ MatchFwdResult MatchBmm1UnfusedBiasSoftmaxBmm2(MatchFwdResult previous_result, } else if (Match(softmax_input, OptionalBitcast(m::AddAnyOrder( OptionalConvert(OptionalBitcast(m::AnyOf( - unfused_scaled_bmm_subpattern, first_bmm_pattern))), + unfused_scaled_bmm_subpattern.WithOneUse(), + first_bmm_pattern.WithOneUse()))), m::Op(&bias))))) { + // bmm1 - (scale) - bias - softmax match_result.matched_bmm_1 = bmm_1; match_result.matched_scale = scale; match_result.matched_bias = bias; match_result.matched_custom_call_name = has_dropout ? kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget : kCudnnfMHAScaleBiasSoftmaxCallTarget; + match_result.is_causal_mask |= IsCausalMaskPattern(bias); match_result.has_match = true; } else { match_result.has_match = false; @@ -561,36 +699,36 @@ MatchFwdResult MatchBmm1ScaleBiasMaskSoftmaxDropoutBmm2( OptionalConvert( m::Op(&bmm_1).WithPredicate(IsBatchedMatmul).WithOneUse()), m::Broadcast(m::Constant(&scale).WithPredicate(IsScalar)))); - - if (Match( - softmax_input, - OptionalConvert(m::Select( - m::Op(&mask).WithPredicate([](const HloInstruction* instr) { - return instr->shape().element_type() == PRED; - }), - // Match bmm1-scale-bias-mask - m::AnyOf( - // Scale and bias might or might not be fused - // with gemm - m::Op(&bmm_1).WithPredicate(IsBatchedMatmul).WithOneUse(), - OptionalConvert(m::AnyOf( - // Try to match unfused bias - m::AddAnyOrder(m::Op(&bias), - m::AnyOf( - OptionalConvert( - m::Op(&bmm_1) - .WithPredicate(IsBatchedMatmul) - .WithOneUse()), - unfused_scaled_bmm_subpattern)), - unfused_scaled_bmm_subpattern))), - m::Op())))) { + // bmm1/scale/bias add/mask should have 2 users if being connected to softmax + // otherwise should have exactly 1 user + if (Match(softmax_input, + OptionalConvert(m::Select( + m::Op(&mask).WithPredicate([](const HloInstruction* instr) { + return instr->shape().element_type() == PRED; + }), + // Match bmm1-scale-bias-mask + m::AnyOf( + // Scale and bias might or might not be fused + // with gemm + m::Op(&bmm_1).WithPredicate(IsBatchedMatmul).WithOneUse(), + OptionalConvert(m::AnyOf( + // Try to match unfused bias + m::AddAnyOrder( + m::Op(&bias), + m::AnyOf( + OptionalConvert( + m::Op(&bmm_1) + .WithPredicate(IsBatchedMatmul) + .WithOneUse()), + unfused_scaled_bmm_subpattern.WithOneUse())), + unfused_scaled_bmm_subpattern.WithOneUse()))), + m::Op())))) { if (!IsSupportedPrimitiveType(bmm_1)) { matched_result.has_match = false; return matched_result; } - if (has_dropout) { - // Found BMM1 - Scale - (bias) - Mask - Softmax - dropout - BMM2 + // Found BMM1 - (Scale) - (bias) - Mask - Softmax - dropout - BMM2 matched_result.matched_custom_call_name = bias == nullptr ? kCudnnfMHAScaleMaskSoftmaxDropoutCallTarget : kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget; @@ -694,40 +832,31 @@ MatchBwdResult MatchBmm1GradGemm2(MatchBwdResult previous_result, HloInstruction* bmm_1_grad_2 = nullptr; MatchBwdResult match_result = previous_result; match_result.has_match = false; - // bmm1 gradient gemm2 shares the same input as bmm1 gradient gemm1. + // bmm1 gradient gemm2 shares the same input d_s as bmm1 gradient gemm1. // Check to see if bmm1 grad gemm1 needs canonicalization or not, if not, // then the shared input is the first operand. - int64_t parent_nodex_index = - match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0; + int64_t d_s_index = match_result.bmm_1_grad_1_need_canonicalization ? 1 : 0; HloInstruction* d_s_user_0 = match_result.matched_bmm_1_grad_1; - HloInstruction* parent_node = d_s_user_0->mutable_operand(parent_nodex_index); - if (parent_node->opcode() == HloOpcode::kBitcast && - parent_node->user_count() == 1) { - d_s_user_0 = parent_node; - parent_node = parent_node->mutable_operand(0); + HloInstruction* d_s = d_s_user_0->mutable_operand(d_s_index); + if (d_s->opcode() == HloOpcode::kBitcast && d_s->user_count() == 1) { + d_s = d_s->mutable_operand(0); } - auto bmm_1_grad_2_it = - std::find_if(parent_node->users().begin(), parent_node->users().end(), - [&](HloInstruction* instr) { - return instr != match_result.matched_bmm_1_grad_1 && - instr->opcode() != HloOpcode::kReduce; - }); - if (bmm_1_grad_2_it != parent_node->users().end()) { + auto bmm_1_grad_2_it = std::find_if( + d_s->users().begin(), d_s->users().end(), [&](HloInstruction* instr) { + return instr != match_result.matched_bmm_1_grad_1 && + instr->opcode() == HloOpcode::kDot; + }); + if (bmm_1_grad_2_it != d_s->users().end()) { bmm_1_grad_2 = *bmm_1_grad_2_it; } else { return match_result; } - if (bmm_1_grad_2->opcode() == HloOpcode::kBitcast && - bmm_1_grad_2->user_count() == 1) { - parent_node = bmm_1_grad_2; - bmm_1_grad_2 = bmm_1_grad_2->users()[0]; - } match_result.matched_bmm_1_grad_2 = bmm_1_grad_2; - if (match_result.matched_bmm_1_grad_2->operand_index(parent_node) != 0) { + if (match_result.matched_bmm_1_grad_2->operand_index(d_s) != 0) { match_result.bmm_1_grad_2_need_canonicalization = true; } match_result.has_match = true; @@ -786,6 +915,33 @@ MatchBwdResult MatchBmm2GradGemm2(MatchBwdResult previous_result, return match_result; } +MatchBwdResult MatchDbias(MatchBwdResult previous_result, + HloInstruction* d_intermediate, + const absl::flat_hash_set users) { + MatchBwdResult match_result = previous_result; + auto user_count = d_intermediate->user_count(); + HloInstruction* dbias_user = nullptr; + HloInstruction* dbias = nullptr; + for (auto user : d_intermediate->users()) { + if (users.contains(user)) { + user_count -= 1; + } else { + dbias_user = user; + } + } + auto ConsumeExtraConvert = [](HloInstruction* instr) { + Match(instr->users()[0], m::Convert(&instr, m::Op()).WithOneUse()); + return true; + }; + // user_count == 1 && (reduce-> {convert} ->bitcast) + match_result.has_match = + user_count == 1 && + Match(dbias_user, m::Reduce(&dbias, m::Op(), m::Op()).WithOneUse()) && + dbias->shape().rank() == 3 && ConsumeExtraConvert(dbias); + match_result.matched_dbias = dbias; + return match_result; +} + MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result, HloInstruction* fwd_fmha_call, HloInstruction* mask) { @@ -793,6 +949,7 @@ MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result, bool is_bmm1_grad1_canonicalized = match_result.bmm_1_grad_1_need_canonicalization; match_result.has_match = false; + bool has_scale = false; bool has_dropout = false; bool has_mask = false; // Backward dropout pattern @@ -817,54 +974,76 @@ MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result, m::Broadcast(OptionalConvert( m::Constant().WithPredicate(IsScalar))), m::Op())))))))); + auto bwd_dropout_pattern_form_3 = OptionalConvert(m::MultiplyAnyOrder( + m::MultiplyAnyOrder( + m::Op().WithPredicate([&](const HloInstruction* instr) { + return instr == match_result.matched_bmm_2_grad_2; + }), + m::Broadcast(m::Constant().WithPredicate(IsScalar))), + m::Op())); auto bwd_dropout_pattern = m::AnyOf( - bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2); + bwd_dropout_pattern_form_1, bwd_dropout_pattern_form_2, + bwd_dropout_pattern_form_3); // Backward softmax pattern HloInstruction* bwd_softmax_input = nullptr; HloInstruction* exp_1; HloInstruction* exp_2; HloInstruction* d_softmax; - auto bwd_softmax_pattern = - OptionalBitcast(OptionalConvert(m::MultiplyAnyOrder( + // d_softmax = exp * (dy / s_b - sum(dy * exp * 1 / s^2)) + // there could be at most 3 users of d_softmax: bmm1grad1 bmm1grad2 and dbias + auto bwd_softmax_pattern = OptionalBitcast(OptionalConvert( + m::MultiplyAnyOrder( &d_softmax, m::AddAnyOrder( - m::Divide(), - m::Broadcast(OptionalBitcast( - OptionalConvert(OptionalConvert(m::Negate(OptionalBitcast( - m::Op() - .WithPredicate(IsReduceSum) - .WithOperand(0, OptionalBitcast(m::MultiplyAnyOrder( - m::MultiplyAnyOrder( - m::Op(&bwd_softmax_input), - m::Broadcast()), - m::Exp(&exp_2, m::Op()))))))))))), - m::Exp(&exp_1, m::Op())))); + m::Divide().WithOneUse(), + m::Broadcast(OptionalBitcast(OptionalConvert( + m::Negate( + OptionalBitcast( + m::Op() + .WithPredicate(IsReduceSum) + .WithOneUse() + .WithOperand( + 0, OptionalBitcast( + m::MultiplyAnyOrder( + m::MultiplyAnyOrder( + m::Op(&bwd_softmax_input), + m::Broadcast()) + .WithOneUse(), + m::Exp(&exp_2, m::Op())) + .WithOneUse())))) + .WithOneUse())))), + m::Exp(&exp_1, m::Op())) + .WithAtMostNumUser(3))); // Backward mask input pattern // we already matched this in the fwd. Just make sure the same mask is used in // the bwd HloInstruction* bwd_mask_input = nullptr; HloInstruction* bwd_mask = nullptr; - auto bwd_mask_pattern = OptionalConvert( - m::Select(m::Op(&bwd_mask).WithPredicate([](const HloInstruction* instr) { + HloInstruction* d_mask = nullptr; + auto bwd_mask_pattern = OptionalConvert(m::Select( + &d_mask, m::Op(&bwd_mask).WithPredicate([](const HloInstruction* instr) { return instr->shape().element_type() == PRED; }), - m::Op(&bwd_mask_input), m::Op())); + m::Op(&bwd_mask_input), m::Op())); // Backward scale input pattern HloInstruction* bwd_scale_input = nullptr; auto bwd_scale_pattern = m::MultiplyAnyOrder(m::Op(&bwd_scale_input), - m::Broadcast(m::Constant().WithPredicate(IsScalar))); + m::Broadcast(m::Constant().WithPredicate(IsScalar))) + .WithNumUser(2); int intermediate_input_pos = is_bmm1_grad1_canonicalized ? 1 : 0; HloInstruction* intermediate_input = match_result.matched_bmm_1_grad_1->mutable_operand( intermediate_input_pos); - if (Match(intermediate_input, bwd_scale_pattern)) { + has_scale = Match(intermediate_input, bwd_scale_pattern); + + if (has_scale) { intermediate_input = bwd_scale_input; } @@ -932,22 +1111,36 @@ MatchBwdResult MatchBwdBmmSoftmaxDropoutBmm(MatchBwdResult previous_result, match_result.matched_custom_call_name = kCudnnfMHASoftmaxBackwardCallTarget; } - - // If d_softmax tensor has 3 consumers, then we need to output the - // intermediate tensor. - bool need_d_intermediate = d_softmax->user_count() == 3; - if ((match_result.matched_custom_call_name == - kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget || - match_result.matched_custom_call_name == - kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget || - match_result.matched_custom_call_name == - kCudnnfMHAScaleBiasMaskSoftmaxDropoutBackwardCallTarget || - match_result.matched_custom_call_name == - kCudnnfMHAScaleBiasMaskSoftmaxBackwardCallTarget) && - need_d_intermediate) { - match_result.matched_d_intermediate = d_softmax; + // try to pattern match dbias + HloInstruction* dS = has_mask ? d_mask : d_softmax; + if (dS->users()[0]->opcode() == HloOpcode::kConvert) { + dS = dS->users()[0]; + } + if (has_scale) { + // bmm1-(scale)-(bias)-(mask)-softmax pattern + // users could be dbias besides mask bwd or scale bwd + if (dS->user_count() == 1) { + // no dbias + match_result.has_match = true; + } else if (dS->user_count() == 2) { + match_result = + MatchDbias(match_result, dS, {bwd_scale_input, bwd_mask_input}); + } else { + match_result.has_match = false; + } + } else { + // bmm1-(bias)-softmax pattern + // users could be dbias besides bmm1grad1 bmm1grad2 + if (dS->user_count() == 2) { + match_result.has_match = true; + } else if (dS->user_count() == 3) { + match_result = MatchDbias(match_result, dS, + {match_result.matched_bmm_1_grad_1, + match_result.matched_bmm_1_grad_2}); + } else { + match_result.has_match = false; + } } - match_result.has_match = true; return match_result; } // First, we look for the bmm2 gradient gemm 1 which takes the activation @@ -992,7 +1185,6 @@ MatchBwdResult MatchBwdMHAPatternsForCanonicalization( if (!match_result.has_match) { return match_result; } - // Found default bmm-bmm backward graph. if (match_result.matched_bmm_2_grad_2->users().size() == 2 && (match_result.matched_bmm_1_grad_1->IsUserOf( @@ -1008,21 +1200,20 @@ MatchBwdResult MatchBwdMHAPatternsForCanonicalization( return match_result; } -StatusOr IsMHABlockSupported(HloInstruction* bmm_1, HloInstruction* bmm_2, - bool need_canonicalization, bool is_training, - std::string& custom_call_name, - const DebugOptions& debug_options) { +absl::StatusOr IsMHABlockSupported( + HloInstruction* bmm_1, HloInstruction* bmm_2, bool need_canonicalization, + bool is_training, bool is_causal_mask, bool& is_flash_attention, + std::string& custom_call_name, const DebugOptions& debug_options, + stream_executor::CudaComputeCapability cc, + stream_executor::dnn::VersionInfo cudnn_version) { if (MHACallHasDropout(custom_call_name) && !debug_options.xla_gpu_fused_attention_use_cudnn_rng()) { VLOG(3) << "Using CUDNN RNG for fused attention dropout is not enabled.\n"; return false; } - if (is_training && - (custom_call_name != kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget && - custom_call_name != kCudnnfMHAScaleBiasSoftmaxCallTarget && - custom_call_name != kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget && - custom_call_name != kCudnnfMHAScaleBiasMaskSoftmaxCallTarget)) { + // cuDNN FMHA requires softmax for backward + if (is_training && custom_call_name == kCudnnfMHABmmBmmCallTarget) { VLOG(3) << "Unsupported fused MHA training pattern.\n"; return false; } @@ -1038,15 +1229,53 @@ StatusOr IsMHABlockSupported(HloInstruction* bmm_1, HloInstruction* bmm_2, return false; } - TF_ASSIGN_OR_RETURN(bool is_bmm1_supported, IsSupportedBMM1(bmm_1)); - if (!is_bmm1_supported) return false; - TF_ASSIGN_OR_RETURN(bool is_bmm2_supported, - IsSupportedBMM2(bmm_2, need_canonicalization)); - if (!is_bmm2_supported) return false; - return true; + if (bmm_1->shape().rank() != 4 || bmm_2->shape().rank() != 4) { + if (VLOG_IS_ON(2)) { + VLOG(2) << "Unsupported bmm rank for cuDNN MHA fusion:\n" + << bmm_1->ToString() << "\nOR\n" + << bmm_2->ToString() << "\n" + << "Only bmm with rank 4 is supported."; + } + return false; + } + + // get batch/num heads/sequence length/hidden dim from bmm1 and bmm2 + // also make sure they are the same between bmm1 and bmm2 + TF_ASSIGN_OR_RETURN(std::optional qkv_layout, + GetQKVLayout(bmm_1, bmm_2, need_canonicalization)); + if (!qkv_layout.has_value()) { + VLOG(2) << "bmm1 and bmm2 have different qkv layout."; + return false; + } + + // check if matched attention block is supported by cuDNN flash attention. + TF_ASSIGN_OR_RETURN( + is_flash_attention, + IsFlashAttention(qkv_layout.value(), is_training, cc, cudnn_version)); + if (is_flash_attention) { + if (is_causal_mask) { + // if bias is causal mask, needs to remove bias from name + if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget) { + custom_call_name = kCudnnfMHASoftmaxDropoutCallTarget; + } else if (custom_call_name == kCudnnfMHAScaleBiasSoftmaxCallTarget) { + custom_call_name = kCudnnfMHASoftmaxCallTarget; + } else if (custom_call_name == + kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget) { + custom_call_name = kCudnnfMHAScaleMaskSoftmaxDropoutCallTarget; + } else if (custom_call_name == kCudnnfMHAScaleBiasMaskSoftmaxCallTarget) { + custom_call_name = kCudnnfMHAScaleMaskSoftmaxCallTarget; + } + } + return true; + } + // check if matched attention block is supported by cuDNN fused attention. + TF_ASSIGN_OR_RETURN( + bool is_fused_attention, + IsFusedAttention(qkv_layout.value(), is_training, cc, cudnn_version)); + return is_fused_attention; } -StatusOr CanonicalizeBatchedGemmForcuDNNFMHA( +absl::StatusOr CanonicalizeBatchedGemmForcuDNNFMHA( HloInstruction* bmm, HloComputation* comp) { if (VLOG_IS_ON(3)) { VLOG(3) << "Before FMHA Dot Cannonicalization: \n" @@ -1084,7 +1313,7 @@ StatusOr CanonicalizeBatchedGemmForcuDNNFMHA( return new_dot; } -StatusOr ChangeCheckedDimToFastest( +absl::StatusOr ChangeCheckedDimToFastest( HloComputation* comp, HloInstruction* bmm, bool is_lhs, bool should_contracting_be_fastest) { const DotDimensionNumbers& dot_dims_bmm = bmm->dot_dimension_numbers(); @@ -1105,28 +1334,27 @@ StatusOr ChangeCheckedDimToFastest( is_lhs ? lhs_minor_to_major_bmm : rhs_minor_to_major_bmm; CHECK_EQ(contracting_dims.size(), 1); - TF_ASSIGN_OR_RETURN(std::vector non_contracting_dim_nums_bmm, + TF_ASSIGN_OR_RETURN(std::vector non_contracting_dims, GetNonContractingDims(bmm->operand(bmm_operand)->shape(), batch_dims, contracting_dims)); - CHECK_EQ(non_contracting_dim_nums_bmm.size(), 1); + CHECK_EQ(non_contracting_dims.size(), 1); HloInstruction* operand_bmm = bmm->mutable_operand(bmm_operand); - std::vector contracting_dims_to_check{contracting_dims[0]}; - std::vector dims_to_set = should_contracting_be_fastest - ? contracting_dims_to_check - : non_contracting_dim_nums_bmm; - // If the dimension being checked(contracting or non-contracting) of the - // target operand is not the fastest moving dimension, make it so. - if (minor_to_major_to_check[0] != dims_to_set[0]) { + int64_t hidden_dim = should_contracting_be_fastest ? contracting_dims[0] + : non_contracting_dims[0]; + int64_t minor_dim = minor_to_major_to_check[0]; + // If the hidden dim of the target operand is not the fastest moving + // dimension, make it so. + if (minor_dim != hidden_dim) { std::vector perm(bmm->shape().dimensions_size()); std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[dims_to_set[0]], perm[minor_to_major_to_check[0]]); + std::swap(perm[hidden_dim], perm[minor_dim]); if (is_lhs) { - new_dot_dims_bmm.set_lhs_contracting_dimensions( - 0, non_contracting_dim_nums_bmm[0]); + new_dot_dims_bmm.set_lhs_contracting_dimensions(0, + non_contracting_dims[0]); } else { - new_dot_dims_bmm.set_rhs_contracting_dimensions( - 0, non_contracting_dim_nums_bmm[0]); + new_dot_dims_bmm.set_rhs_contracting_dimensions(0, + non_contracting_dims[0]); } operand_bmm = comp->AddInstruction( @@ -1134,7 +1362,7 @@ StatusOr ChangeCheckedDimToFastest( ShapeUtil::MakeShapeWithDenseLayout( bmm->shape().element_type(), Permute(operand_bmm->shape().dimensions(), perm), - rhs_minor_to_major_bmm), + minor_to_major_to_check), operand_bmm, perm), &operand_bmm->metadata()); *((DynCast(bmm))->mutable_dot_dimension_numbers()) = @@ -1143,12 +1371,13 @@ StatusOr ChangeCheckedDimToFastest( return operand_bmm; } -StatusOr FuseFwdMultiHeadedAttentionBlock( +absl::StatusOr FuseFwdMultiHeadedAttentionBlock( HloComputation* comp, HloInstruction* bmm_1, HloInstruction* bmm_2, HloInstruction* bias, HloInstruction* mask, HloInstruction* scale, + HloInstruction* reduce_sum, HloInstruction* softmax_input, double dropout_rate, std::string& custom_call_name, stream_executor::CudaComputeCapability cc, bool is_training, bool& changed, - bool& v_transposed) { + bool& v_transposed, bool is_causal_mask, bool is_flash_attention) { double scale_value = 1.0; HloInstruction* lhs_bmm1; HloInstruction* rhs_bmm1; @@ -1171,7 +1400,10 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( v_transposed = true; } - CudnnfMHABackendConfig fmha_config; + GpuBackendConfig gpu_config; + CudnnfMHABackendConfig& fmha_config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + *fmha_config.mutable_bmm1_dot_dimension_numbers() = bmm_1->dot_dimension_numbers(); *fmha_config.mutable_bmm2_dot_dimension_numbers() = @@ -1212,38 +1444,66 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( algorithm->set_is_cudnn_frontend(true); algorithm->mutable_workspace_size()->set_value(0); } + + // set is flash attention here + // choose to use flash attention or non-fa attention based on this flag. + fmha_config.set_is_flash_attention(is_flash_attention); + // set is_causal_mask here + // choose to generate causal mask inside cuDNN attention or not + fmha_config.set_is_causal_mask(is_causal_mask); + + // Output Order: {O, scratch, Fwd act*} const Shape& output_shape = bmm_2->shape(); Shape call_shape; // Activation output is used by backward gemm. HloInstruction* activation_output = nullptr; - std::vector output_shapes = {output_shape, - ShapeUtil::MakeShape(U8, {0})}; + std::vector output_shapes = { + output_shape, + ShapeUtil::MakeShape( + U8, {is_flash_attention + ? 16 + : 0})}; // reserved 2 int64 for dropout seed and offset if (is_training) { - // TODO Flush attention will have a different shape in training. activation_output = bmm_2->mutable_operand(0); // Sometimes activation output is bitcast, the actual activation is the - // second user of the producer of bmm_2's first operand. + // other user of the producer of bmm_2's first operand. if (activation_output->user_count() < 2 && activation_output->opcode() == HloOpcode::kBitcast) { HloInstruction* producer = activation_output->mutable_operand(0); TF_RET_CHECK(producer->user_count() == 2); - activation_output = producer->UserId(activation_output) == 0 - ? producer->users()[1] - : producer->users()[0]; + HloInstruction* bmm2_grad2_user = + producer->users()[0] == activation_output ? producer->users()[1] + : producer->users()[0]; + // might be (transpose) - bmm2_grad2 + if (IsBatchedMatmul(bmm2_grad2_user)) { + activation_output = producer; + } else if (bmm2_grad2_user->opcode() == HloOpcode::kTranspose) { + activation_output = bmm2_grad2_user; + } else { + return Internal("Unexpected activation patterns"); + } + } + // if it is flash attention, should output softmax stats to the bwd + if (is_flash_attention) { + TF_RET_CHECK(reduce_sum != nullptr); + output_shapes.push_back( + ShapeUtil::MakeShape(F32, reduce_sum->shape().dimensions())); + } else { + output_shapes.push_back(activation_output->shape()); } - output_shapes.push_back(activation_output->shape()); } call_shape = ShapeUtil::MakeTupleShape(output_shapes); + // Input Order: {Q, K, V, mask*, bias*} std::vector operands = {lhs_bmm1, rhs_bmm1, rhs_bmm2}; if (mask != nullptr) { HloInstruction* converted_mask = comp->AddInstruction( HloInstruction::CreateConvert(bmm_1->shape(), mask)); operands.push_back(converted_mask); } - if (bias != nullptr) { + if ((!is_flash_attention || !is_causal_mask) && bias != nullptr) { HloInstruction* original_bias; HloInstruction* original_broadcast; // There will be cases where the bias is up-casted to wider float type, @@ -1287,7 +1547,7 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( HloInstruction* fmha_call = comp->AddInstruction(HloInstruction::CreateCustomCall( call_shape, operands, absl::string_view(custom_call_name))); - TF_RETURN_IF_ERROR(fmha_call->set_backend_config(fmha_config)); + TF_RETURN_IF_ERROR(fmha_call->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(SetFMHAInstructionName(bmm_1->GetModule(), fmha_call)); TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( @@ -1313,43 +1573,18 @@ StatusOr FuseFwdMultiHeadedAttentionBlock( return fmha_call; } -bool IsDbiasOnlyUserBesidesGradGemm(HloInstruction* d_intermediate, - HloInstruction* bmm_1_grad_1, - HloInstruction* bmm_1_grad_2, - HloInstruction** dbias) { - auto user_count = d_intermediate->user_count(); - HloInstruction* dbias_user = nullptr; - for (auto user : d_intermediate->users()) { - if (user == bmm_1_grad_1) { - user_count -= 1; - } else if (user == bmm_1_grad_2) { - user_count -= 1; - } else { - dbias_user = user; - } - } - auto ConsumeExtraConvert = [](HloInstruction** instr) { - Match((*instr)->users()[0], m::Convert(instr, m::Op()).WithOneUse()); - return true; - }; - // user_count == 1 && (reduce-> {convert} ->bitcast) - return user_count == 1 && - Match(dbias_user, m::Reduce(dbias, m::Op(), m::Op()).WithOneUse()) && - (*dbias)->shape().rank() == 3 && ConsumeExtraConvert(dbias); -} - -StatusOr FuseBwdMultiHeadedAttentionBlock( +absl::StatusOr FuseBwdMultiHeadedAttentionBlock( HloComputation* comp, HloInstruction* bmm_1_grad_1, HloInstruction* bmm_1_grad_2, HloInstruction* bmm_2_grad_1, HloInstruction* bmm_2_grad_2, HloInstruction* fwd_fmha_call, - HloInstruction* d_intermediate, HloInstruction* mask, - std::string& bwd_custom_call_name, bool fwd_bmm_2_canonicalized, - bool is_bmm2_grad1_canonicalized) { + HloInstruction* dbias, HloInstruction* mask, HloInstruction* bias, + std::string& bwd_custom_call_name) { HloInstruction* rhs_bmm1_grad_gemm1; HloInstruction* lhs_bmm1_grad_gemm2; HloInstruction* lhs_bmm2_grad_gemm1; HloInstruction* rhs_bmm2_grad_gemm2; HloInstruction* d_output_grad; + DotDimensionNumbers orig_bmm1_grad1_config = bmm_1_grad_1->dot_dimension_numbers(); DotDimensionNumbers orig_bmm1_grad2_config = @@ -1359,6 +1594,12 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( DotDimensionNumbers orig_bmm2_grad2_config = bmm_2_grad_2->dot_dimension_numbers(); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + fwd_fmha_call->backend_config()); + CudnnfMHABackendConfig fwd_config = gpu_config.cudnn_fmha_backend_config(); + bool is_flash_attention = fwd_config.is_flash_attention(); + bool is_causal_mask = fwd_config.is_causal_mask(); + CudnnfMHABackendConfig bwd_fmha_config; // Q tensor TF_ASSIGN_OR_RETURN( rhs_bmm1_grad_gemm1, @@ -1369,67 +1610,74 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( lhs_bmm1_grad_gemm2, ChangeCheckedDimToFastest(comp, bmm_1_grad_2, false /*is_lhs*/, false /*should_contracting_be_fastest*/)); - // Forward activation + // P tensor TF_ASSIGN_OR_RETURN( lhs_bmm2_grad_gemm1, ChangeCheckedDimToFastest(comp, bmm_2_grad_1, true /*is_lhs*/, false /*should_contracting_be_fastest*/)); + + // Forward activation + // if it is not flash attention, fwd activation is the P tensor + // else it is the softmax_stats + HloInstruction* fwd_act; + if (fwd_config.is_flash_attention()) { + auto fwd_act_index = 2; + fwd_act = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + fwd_fmha_call->shape().tuple_shapes(fwd_act_index), fwd_fmha_call, + fwd_act_index)); + } else { + fwd_act = lhs_bmm2_grad_gemm1; + } + // V tensor TF_ASSIGN_OR_RETURN( rhs_bmm2_grad_gemm2, ChangeCheckedDimToFastest(comp, bmm_2_grad_2, false /*is_lhs*/, true /*should_contracting_be_fastest*/)); - // d output + // d output to bmm2_grad2 // Since d_o is the input of 2 bmms, we set the dim number using the // constraint // -> the contracting dimension of the lhs of bmm_2_grad_2 needs to be the // fastest moving dimension. - TF_ASSIGN_OR_RETURN(d_output_grad, ChangeCheckedDimToFastest( - comp, bmm_2_grad_2, true /*is_lhs*/, - true /*check_contracting_dim*/)); - // Operand order {Q, K, V, Fwd act, d_o, mask*} + TF_ASSIGN_OR_RETURN( + d_output_grad, + ChangeCheckedDimToFastest(comp, bmm_2_grad_2, true /*is_lhs*/, + true /*should_contracting_be_fastest*/)); + // d output to bmm2_grad1 + // we don't use this value but we call this to make sure dot number is being + // set correctly + TF_ASSIGN_OR_RETURN( + HloInstruction * bmm_2_grad_1_rhs, + ChangeCheckedDimToFastest(comp, bmm_2_grad_1, false /*is_lhs*/, + false /*should_contracting_be_fastest*/)); + (void)bmm_2_grad_1_rhs; + // Operand order: {Q, K, V, Fwd act, d_o, mask*, bias*, O*} std::vector operands = { - rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, - lhs_bmm2_grad_gemm1, d_output_grad}; + rhs_bmm1_grad_gemm1, lhs_bmm1_grad_gemm2, rhs_bmm2_grad_gemm2, fwd_act, + d_output_grad}; if (mask) { HloInstruction* converted_mask = comp->AddInstruction( HloInstruction::CreateConvert(bmm_2_grad_2->shape(), mask)); operands.push_back(converted_mask); } - TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig fwd_config, - fwd_fmha_call->backend_config()); - CudnnfMHABackendConfig bwd_fmha_config; - - // If forward bmm_2 is canonicalized, the contracting dimension of lhs - // of bmm_2_grad_1 needs to be changed to the non-contracting dimension. - - if (fwd_bmm_2_canonicalized) { - TF_ASSIGN_OR_RETURN( - std::vector bmm_2_grad_1_lhs_non_contracting_dims, - GetNonContractingDims( - bmm_2_grad_1->shape(), - bmm_2_grad_1->dot_dimension_numbers().lhs_batch_dimensions(), - bmm_2_grad_1->dot_dimension_numbers() - .lhs_contracting_dimensions())); - CHECK_EQ(bmm_2_grad_1_lhs_non_contracting_dims.size(), 1); - (DynCast(bmm_2_grad_1)) - ->mutable_dot_dimension_numbers() - ->set_lhs_contracting_dimensions( - 0, bmm_2_grad_1_lhs_non_contracting_dims[0]); - } - TF_ASSIGN_OR_RETURN( - std::vector bmm_2_grad_1_new_contracting_dims, - GetNonContractingDims( - bmm_2_grad_1->shape(), - bmm_2_grad_1->dot_dimension_numbers().rhs_batch_dimensions(), - bmm_2_grad_1->dot_dimension_numbers().rhs_contracting_dimensions())); - - if (is_bmm2_grad1_canonicalized) { - (DynCast(bmm_2_grad_1)) - ->mutable_dot_dimension_numbers() - ->set_rhs_contracting_dimensions(0, - bmm_2_grad_1_new_contracting_dims[0]); + // if is flash attention, add fwd output to input list + if (is_flash_attention) { + if (!is_causal_mask && bias) { + operands.push_back(bias); + } + HloInstruction* fwd_output; + for (auto user : fwd_fmha_call->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 0) { + fwd_output = user; + } + } + // should be able to find the instruction + TF_RET_CHECK(fwd_output != nullptr); + // check dO and O have the same layout as it is required by cuDNN + TF_RET_CHECK(fwd_output->shape() == d_output_grad->shape()); + operands.push_back(fwd_output); } *bwd_fmha_config.mutable_bmm1_grad_gemm1_dot_dimension_numbers() = @@ -1458,6 +1706,10 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( // TODO Find a way to compute original seed from dropout keys. bwd_fmha_config.set_seed(fwd_config.seed()); + // Set is flash attention + bwd_fmha_config.set_is_flash_attention(is_flash_attention); + bwd_fmha_config.set_is_causal_mask(is_causal_mask); + *bwd_fmha_config.mutable_intermediate_tensor_shape() = fwd_config.intermediate_tensor_shape(); { @@ -1474,40 +1726,44 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( } // Output order: - // dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), - // d_intermediate_tensor, d_bias_tensor + // {dQ(bmm_1_grad_2), dK(bmm_1_grad_1), dV(bmm_2_grad_1), + // d_intermediate_tensor*, softmax_sum*, d_Q_accum*, scratch, dbias*} std::vector output_shapes = { bmm_1_grad_2->shape(), bmm_1_grad_1->shape(), bmm_2_grad_1->shape()}; - // d_intermediate is required to be output - output_shapes.push_back(lhs_bmm2_grad_gemm1->shape()); - + if (!fwd_config.is_flash_attention()) { + output_shapes.push_back(lhs_bmm2_grad_gemm1->shape()); + } else { + // softmax_sum, d_Q_accum + // add softmax sum here and change the data type + // softmax sum and d_Q_accum should both be fp32 datatype + output_shapes.push_back( + ShapeUtil::MakeShape(F32, fwd_act->shape().dimensions())); + output_shapes.push_back( + ShapeUtil::MakeShape(F32, bmm_1_grad_2->shape().dimensions())); + } // Reserved placeholder for workspace - output_shapes.push_back(ShapeUtil::MakeShape(U8, {0})); + output_shapes.push_back(ShapeUtil::MakeShape( + U8, {is_flash_attention + ? 16 + : 0})); // reserved 2 int64 for dropout seed and offset - HloInstruction* dbias = nullptr; - if (d_intermediate) { - if (IsDbiasOnlyUserBesidesGradGemm(d_intermediate, bmm_1_grad_1, - bmm_1_grad_2, &dbias)) { - // Cudnn kernel only outputs dbias in this shape [1, num_heads, seq, seq], - // so we add a dimension of 1 to existing dbias' shape. - std::vector dbias_shape_vector = - SpanToVector(dbias->shape().dimensions()); - dbias_shape_vector.insert(dbias_shape_vector.begin(), 1); - Shape cudnn_dbias_shape = ShapeUtil::MakeShape( - dbias->shape().element_type(), dbias_shape_vector); - output_shapes.push_back(cudnn_dbias_shape); - } else { - VLOG(2) << "Intermediate gradient has other users outside of gradient " - "gemms and dbias" - << " which is not supported by CUDNN for now. Skipping."; - return false; - } + if (dbias) { + // Cudnn kernel only outputs dbias in this shape [1, num_heads, seq, seq], + // so we add a dimension of 1 to existing dbias' shape. + std::vector dbias_shape_vector = + SpanToVector(dbias->shape().dimensions()); + dbias_shape_vector.insert(dbias_shape_vector.begin(), 1); + Shape cudnn_dbias_shape = + ShapeUtil::MakeShape(dbias->shape().element_type(), dbias_shape_vector); + output_shapes.push_back(cudnn_dbias_shape); } Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes); HloInstruction* fmha_bwd_call = comp->AddInstruction(HloInstruction::CreateCustomCall( call_shape, operands, absl::string_view(bwd_custom_call_name))); - TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_fmha_config)); + GpuBackendConfig bwd_gpu_config; + *bwd_gpu_config.mutable_cudnn_fmha_backend_config() = bwd_fmha_config; + TF_RETURN_IF_ERROR(fmha_bwd_call->set_backend_config(bwd_gpu_config)); TF_RETURN_IF_ERROR( SetFMHAInstructionName(bmm_1_grad_1->GetModule(), fmha_bwd_call)); @@ -1543,24 +1799,74 @@ StatusOr FuseBwdMultiHeadedAttentionBlock( } return true; } + +Status RestoreFwdGraph( + HloComputation* comp, HloInstruction* fwd_fmha_call, HloInstruction* bmm2, + HloInstruction* activation, HloInstruction* original_bmm2_producer0, + HloInstruction* original_bmm2_producer1, + std::vector& original_activation_producers, + bool bmm_2_need_canonicalization) { + // If backward pattern is not matched, we need to restore the + // original graph structure. + // Replacing new GTEs added by forward FMHA call with cloned old + // activations and bmm2. + HloInstruction* output_gte = fwd_fmha_call->users()[0]; + HloInstruction* activation_gte = fwd_fmha_call->users()[1]; + std::string suffix = "fmha_no_match_clone"; + HloInstruction* cloned_activation = + comp->AddInstruction(activation->CloneWithNewOperands( + activation->shape(), original_activation_producers, suffix)); + + // Since old activation is detached by forward FMHA rewrite, we need + // to use the newly cloned activation. + HloInstruction* lhs = activation == original_bmm2_producer0 + ? cloned_activation + : original_bmm2_producer0; + HloInstruction* rhs = activation == original_bmm2_producer0 + ? original_bmm2_producer1 + : cloned_activation; + HloInstruction* cloned_bmm2 = comp->AddInstruction( + bmm2->CloneWithNewOperands(bmm2->shape(), {lhs, rhs}, suffix)); + if (bmm_2_need_canonicalization) { + TF_RET_CHECK(output_gte->users()[0]->opcode() == HloOpcode::kTranspose); + TF_RETURN_IF_ERROR( + comp->ReplaceInstruction(output_gte->users()[0], cloned_bmm2)); + } else { + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2)); + } + TF_RETURN_IF_ERROR( + comp->ReplaceInstruction(activation_gte, cloned_activation)); + return OkStatus(); +} } // namespace -StatusOr CudnnFusedMHARewriter::Run( +absl::StatusOr CudnnFusedMHARewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool any_changed = false; + // we use this set to keep track of all already matched attention block + absl::flat_hash_set matched_bmm1; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { const DebugOptions& debug_options = comp->parent()->config().debug_options(); + const se::dnn::VersionInfo cudnn_version = + GetDnnVersionInfo(stream_executor_, cudnn_version_); +#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000 + // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware. + // Some cuDNN versions work with CUDA 11, but it is impractical for us to + // test those combinations so just disable them. + return false; +#endif if (!debug_options.xla_gpu_enable_cudnn_fmha() || !IsComputeCapabilityAndCudnnSupported( - compute_capability_, cudnn_version_, stream_executor_, + compute_capability_, cudnn_version, stream_executor::dnn::VersionInfo(8, 8, 0))) { return false; } for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { bool v_transposed = false; + bool changed = false; MatchFwdResult matched_result = MatchFwdMHAPatternsForCanonicalization(instr); if (!matched_result.has_match) { @@ -1568,14 +1874,17 @@ StatusOr CudnnFusedMHARewriter::Run( } // We check the validity of bmms here before canonicalization so we don't // modify the graph if mha fusion is not possible + // Relax 512 constraint if it is flash attention TF_ASSIGN_OR_RETURN( bool is_mha_module_supported, IsMHABlockSupported( matched_result.matched_bmm_1, matched_result.matched_bmm_2, matched_result.need_canonicalization, matched_result.is_training, - matched_result.matched_custom_call_name, debug_options)); - if (!is_mha_module_supported) continue; + matched_result.is_causal_mask, matched_result.is_flash_attention, + matched_result.matched_custom_call_name, debug_options, + compute_capability_, cudnn_version)); + if (!is_mha_module_supported) continue; // If we have an activation with more than 1 users in non-training mode, // we cannot rewrite the graph. So skip processing the rest. HloInstruction* activation = @@ -1593,10 +1902,15 @@ StatusOr CudnnFusedMHARewriter::Run( HloInstruction* original_bmm2_producer1 = matched_result.matched_bmm_2->mutable_operand(1); + HloInstruction* original_bmm2 = matched_result.matched_bmm_2; std::vector original_activation_producers; for (HloInstruction* operand : activation->mutable_operands()) { original_activation_producers.push_back(operand); } + // We make sure no attention block is matched and replaced twice here + if (!matched_bmm1.insert(matched_result.matched_bmm_1).second) { + continue; + } // If we need to canonicalize the bmm, we will assign the newly // canonicalized bmm to bmm_2. if (matched_result.need_canonicalization) { @@ -1604,7 +1918,7 @@ StatusOr CudnnFusedMHARewriter::Run( CanonicalizeBatchedGemmForcuDNNFMHA( matched_result.matched_bmm_2, comp)); } - bool changed = false; + // Fuse the bmms and intermediate nodes into fMHA call, the fused call // will replace bmm_2. TF_ASSIGN_OR_RETURN( @@ -1612,67 +1926,56 @@ StatusOr CudnnFusedMHARewriter::Run( FuseFwdMultiHeadedAttentionBlock( comp, matched_result.matched_bmm_1, matched_result.matched_bmm_2, matched_result.matched_bias, matched_result.matched_mask, - matched_result.matched_scale, matched_result.matched_dropout_rate, + matched_result.matched_scale, matched_result.matched_reduce_sum, + matched_result.matched_softmax_input, + matched_result.matched_dropout_rate, matched_result.matched_custom_call_name, compute_capability_, - matched_result.is_training, changed, v_transposed)); + matched_result.is_training, changed, v_transposed, + matched_result.is_causal_mask, + matched_result.is_flash_attention)); any_changed |= changed; - if (matched_result.is_training) { - // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask - // input if cudnn version < 8.9.1 we won't lower the bwd pass - if (matched_result.matched_mask != nullptr && - !IsComputeCapabilityAndCudnnSupported( - compute_capability_, cudnn_version_, stream_executor_, - stream_executor::dnn::VersionInfo(8, 9, 1))) { - continue; - } MatchBwdResult matched_bwd_result = MatchBwdMHAPatternsForCanonicalization( fwd_fmha_call, matched_result.matched_bmm_1, matched_result.matched_mask, v_transposed); if (!matched_bwd_result.has_match) { VLOG(2) << "Backward pattern not matching, skipping."; - // If backward pattern is not matched, we need to restore the - // original graph structure. - // Replacing new GTEs added by forward FMHA call with cloned old - // activations and bmm2. - HloInstruction* output_gte = fwd_fmha_call->users()[0]; - HloInstruction* activation_gte = fwd_fmha_call->users()[1]; - std::string suffix = "fmha_no_match_clone"; - HloInstruction* cloned_activation = - comp->AddInstruction(activation->CloneWithNewOperands( - activation->shape(), original_activation_producers, suffix)); - - // Since old activation is detached by forward FMHA rewrite, we need - // to use the newly cloned activation. - HloInstruction* lhs = activation == original_bmm2_producer0 - ? cloned_activation - : original_bmm2_producer1; - HloInstruction* rhs = activation == original_bmm2_producer0 - ? original_bmm2_producer1 - : cloned_activation; - HloInstruction* cloned_bmm2 = comp->AddInstruction( - matched_result.matched_bmm_2->CloneWithNewOperands( - matched_result.matched_bmm_2->shape(), {lhs, rhs}, suffix)); - - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(output_gte, cloned_bmm2)); + // restore fwd graph if bwd pattern match failed + TF_RETURN_IF_ERROR( + RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation, + original_bmm2_producer0, original_bmm2_producer1, + original_activation_producers, + matched_result.need_canonicalization)); + continue; + } + // if fwd uses mask input, then bwd needs cudnn 8.9.1 to take in a mask + // input if cudnn version < 8.9.1 we won't lower the bwd pass + if (matched_result.matched_mask != nullptr && + !IsComputeCapabilityAndCudnnSupported( + compute_capability_, cudnn_version, + stream_executor::dnn::VersionInfo(8, 9, 1))) { + // restore fwd graph if bwd pattern match failed TF_RETURN_IF_ERROR( - comp->ReplaceInstruction(activation_gte, cloned_activation)); + RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation, + original_bmm2_producer0, original_bmm2_producer1, + original_activation_producers, + matched_result.need_canonicalization)); continue; } - // check if dbias is the only user of d_intermediate besides - // bmm_1_grad_1 and bmm_1_grad_2 and the cudnn version is > 8.9.1. We + // check if dbias exist and the cudnn version is > 8.9.1. We // won't lower bwd if this condition is not met as we won't deal with // unswizzling now - HloInstruction* dbias = nullptr; - if (matched_bwd_result.matched_d_intermediate && - !IsDbiasOnlyUserBesidesGradGemm( - matched_bwd_result.matched_d_intermediate, - matched_bwd_result.matched_bmm_1_grad_1, - matched_bwd_result.matched_bmm_1_grad_2, &dbias) && + if (matched_bwd_result.matched_dbias && !IsComputeCapabilityAndCudnnSupported( - compute_capability_, cudnn_version_, stream_executor_, + compute_capability_, cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 1))) { + // restore fwd graph if bwd pattern match failed + TF_RETURN_IF_ERROR( + RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation, + original_bmm2_producer0, original_bmm2_producer1, + original_activation_producers, + matched_result.need_canonicalization)); continue; } // Canonicalize gemms @@ -1709,11 +2012,9 @@ StatusOr CudnnFusedMHARewriter::Run( matched_bwd_result.matched_bmm_1_grad_2, matched_bwd_result.matched_bmm_2_grad_1, matched_bwd_result.matched_bmm_2_grad_2, fwd_fmha_call, - matched_bwd_result.matched_d_intermediate, - matched_result.matched_mask, - matched_bwd_result.matched_custom_call_name, - matched_result.need_canonicalization, - matched_bwd_result.bmm_2_grad_1_need_canonicalization)); + matched_bwd_result.matched_dbias, matched_result.matched_mask, + matched_result.matched_bias, + matched_bwd_result.matched_custom_call_name)); any_changed |= changed; } } diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.h b/xla/service/gpu/cudnn_fused_mha_rewriter.h index 76704ff06ae1c..f0aa6871caf90 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter.h +++ b/xla/service/gpu/cudnn_fused_mha_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_ #define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_ -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" namespace xla { @@ -38,7 +43,7 @@ class CudnnFusedMHARewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index 944acedd15319..b52fa5be2c598 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ limitations under the License. #include "xla/service/gpu/cudnn_fused_mha_rewriter.h" #include +#include +#include +#include #include #include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/computation_layout.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -28,19 +32,27 @@ limitations under the License. #include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" +#include "xla/service/hlo_verifier.h" #include "xla/service/layout_normalization.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/reshape_decomposer.h" -#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep +#endif + namespace xla { namespace gpu { namespace { @@ -56,6 +68,13 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase { return se::CudaComputeCapability(8, 0); } + se::CudaComputeCapability GetRealCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + se::dnn::VersionInfo GetCudnnVersion() { // Fake a supported compute capability to run tests, // we don't run any kernels in these tests so they should be safe @@ -70,10 +89,22 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase { return se::dnn::VersionInfo(8, 9, 1); } + se::dnn::VersionInfo GetCudnnVersionWithFlashAttentionSupport() { + // Fake a supported compute capability to run tests, + // we don't run any kernels in these tests so they should be safe + // to run anywhere. + return se::dnn::VersionInfo(8, 9, 4); + } + CudnnFusedMhaRewriterTestHloTest() : HloTestBase(/*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/false, - /*instruction_can_change_layout_func=*/{}) {} + /*instruction_can_change_layout_func=*/{}) { +#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000 + skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later."; + return; +#endif + } protected: size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) { @@ -101,10 +132,38 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase { config_with_fmha.set_debug_options(debug_options); return config_with_fmha; } + + // Centralize skip checks in the constructor. Unfortunately we cannot call + // GTEST_SKIP from the constructor. Instead, we set (if needed) `skip_reason`, + // and then check it from all test fixtures. + // An alternative would be to use the SetUp() override, but for this to be + // correct we'd have to ensure that all the parents' SetUp() methods are + // called, which is error prone. + std::optional skip_reason_; }; -TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2Pattern) { - const char* module_str = R"( +class CudnnFusedMhaRewriterPipelineTest + : public CudnnFusedMhaRewriterTestHloTest { + public: + CudnnFusedMhaRewriterPipelineTest() { + if (skip_reason_) return; // the parent might have set it. +#if !defined(GOOGLE_CUDA) || CUDNN_VERSION < 8800 // NOLINT + skip_reason_ = "Pipeline test requires cuDNN 8.8.0 or later."; + return; +#endif + stream_executor::CudaComputeCapability cc = GetRealCudaComputeCapability(); + // Enforce capability minor == 0 because hardware with a non-zero minor + // number typically has insufficient shared memory for cuDNN FMHA. + if (!cc.IsAtLeastAmpere() || cc.minor != 0) { + skip_reason_ = + "Pipeline test requires Nvidia AMPERE+ GPUs with minor " + "compute capability == 0."; + return; + } + } +}; + +constexpr absl::string_view hlo_BF16Bmm1Bmm2Pattern = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2) @@ -112,13 +171,13 @@ ENTRY main.6 { Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1) dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} -} - - -)"; +})"; +TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - auto m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule(hlo_BF16Bmm1Bmm2Pattern, GetModuleConfig())); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -130,31 +189,36 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.fmha_scale(), 1.0); EXPECT_EQ(config.dropout_rate(), 0.0); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, BF16Bmm1Bmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule(hlo_BF16Bmm1Bmm2Pattern, GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; + SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( optimized_module->entry_computation()->root_instruction(), GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.fmha_scale(), 1.0); EXPECT_EQ(config.dropout_rate(), 0.0); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2UncanonicalizedPattern) { - const char* module_str = R"( +constexpr absl::string_view hlo_BF16Bmm1Bmm2UncanonicalizedPattern = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}} ENTRY main.6 { @@ -163,12 +227,12 @@ ENTRY main.6 { Arg_1.2 = bf16[16,16,256,64]{3,2,1,0} parameter(1) dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(Arg_2.3, dot.0), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} -} - - -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2UncanonicalizedPattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2UncanonicalizedPattern)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -180,16 +244,21 @@ ENTRY main.6 { m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64})))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.fmha_scale(), 1.0); EXPECT_EQ(config.dropout_rate(), 0.0); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, BF16Bmm1Bmm2UncanonicalizedPattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2UncanonicalizedPattern, GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT(optimized_module->entry_computation()->root_instruction(), @@ -197,16 +266,15 @@ ENTRY main.6 { m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64})))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.fmha_scale(), 1.0); EXPECT_EQ(config.dropout_rate(), 0.0); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) { - const char* module_str = R"( +constexpr absl::string_view + hlo_BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { @@ -215,10 +283,15 @@ ENTRY main.6 { Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1) dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} -} -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get())); @@ -231,16 +304,24 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0], 2); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, + BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm1_rhs_contracting_dim_not_most_minor, + GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -248,16 +329,15 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0], 2); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) { - const char* module_str = R"( +constexpr absl::string_view + hlo_BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { @@ -266,10 +346,15 @@ ENTRY main.6 { Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1) dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} -} -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get())); @@ -282,18 +367,26 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0], 2); EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0], 2); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, + BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm1_lhs_contracting_dim_not_most_minor, + GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -301,18 +394,17 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm1_dot_dimension_numbers().lhs_contracting_dimensions()[0], 2); EXPECT_EQ(config.bmm1_dot_dimension_numbers().rhs_contracting_dimensions()[0], 2); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor) { - const char* module_str = R"( +constexpr absl::string_view + hlo_BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { @@ -321,10 +413,15 @@ ENTRY main.6 { Arg_1.2 = bf16[16,16,256,64]{2,3,1,0} parameter(1) dot.0 = bf16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = bf16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} -} -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&fusedMhaRewriter, m.get())); @@ -337,18 +434,26 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0], 3); EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0], 3); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, + BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1Bmm2Pattern_bmm2_non_contracting_dim_not_most_minor, + GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -356,17 +461,16 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.bmm2_dot_dimension_numbers().lhs_contracting_dimensions()[0], 3); EXPECT_EQ(config.bmm2_dot_dimension_numbers().rhs_contracting_dimensions()[0], 3); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1Bmm2Pattern) { - const char* module_str = R"( +absl::string_view F16Bmm1Bmm2Pattern_str = R"( HloModule fmha_test, entry_computation_layout={(f16[16,16,256,64]{3,2,1,0},f16[16,16,256,64]{3,2,1,0},f16[16,16,256,64]{3,2,1,0})->f16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { Arg_2.3 = f16[16,16,256,64]{3,2,1,0} parameter(2) @@ -374,13 +478,13 @@ ENTRY main.6 { Arg_1.2 = f16[16,16,256,64]{3,2,1,0} parameter(1) dot.0 = f16[16,16,256,256]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} ROOT dot.1 = f16[16,16,256,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} -} - - -)"; +})"; +TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1Bmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - auto m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule(F16Bmm1Bmm2Pattern_str, GetModuleConfig())); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -392,16 +496,21 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(F16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(config.fmha_scale(), 1.0); EXPECT_EQ(config.dropout_rate(), 0.0); -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +} + +TEST_F(CudnnFusedMhaRewriterPipelineTest, F16Bmm1Bmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule(F16Bmm1Bmm2Pattern_str, GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -409,15 +518,14 @@ ENTRY main.6 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) .WithShape(F16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1ScaleMaskSoftmaxBmm2Pattern) { - const char* module_str = R"( +constexpr absl::string_view hlo_BF16Bmm1ScaleMaskSoftmaxBmm2Pattern = R"( HloModule jit_bmm_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} region_0.14.clone { @@ -458,11 +566,12 @@ ENTRY main.38 { convert.49 = bf16[16,16,256,256]{3,2,1,0} convert(divide.36) Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2) ROOT dot.37 = bf16[16,16,256,64]{3,2,1,0} dot(convert.49, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -} - -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1ScaleMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1ScaleMaskSoftmaxBmm2Pattern)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -475,18 +584,22 @@ ENTRY main.38 { m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHAScaleMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 2.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 4); +} -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +TEST_F(CudnnFusedMhaRewriterPipelineTest, BF16Bmm1ScaleMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1ScaleMaskSoftmaxBmm2Pattern, GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -495,17 +608,15 @@ ENTRY main.38 { m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHAScaleMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 2.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 4); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern) { - const char* module_str = R"( +constexpr absl::string_view hlo_BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern = R"( HloModule jit_bmm_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} region_0.17.clone { @@ -549,11 +660,14 @@ ENTRY main.41 { convert.49 = bf16[16,16,256,256]{3,2,1,0} convert(divide.36) Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2) ROOT dot.37 = bf16[16,16,256,64]{3,2,1,0} dot(convert.49, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -} - -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule( + hlo_BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -567,18 +681,24 @@ ENTRY main.41 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 3.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 5); +} -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +TEST_F(CudnnFusedMhaRewriterPipelineTest, + BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, + ParseAndReturnVerifiedModule(hlo_BF16Bmm1ScaleBiasMaskSoftmaxBmm2Pattern, + GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -588,17 +708,16 @@ ENTRY main.41 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 3.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 5); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern) { - const char* module_str = R"( +constexpr absl::string_view + hlo_BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern = R"( HloModule jit_bmm_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} region_0.17.clone { @@ -628,7 +747,7 @@ ENTRY main.41 { convert.40 = bf16[16,16,256,256]{3,2,1,0} convert(add.15) constant.4 = bf16[] constant(0) broadcast.5 = bf16[16,16,256,256]{3,2,1,0} broadcast(constant.4), dimensions={} - compare = pred[16,16,256,256]{3,2,1,0} compare(convert.40, broadcast.5), direction=GT + compare = pred[16,16,256,256]{3,2,1,0} compare(convert.40, broadcast.5), direction=GT select.13 = bf16[16,16,256,256]{3,2,1,0} select(compare, convert.40, broadcast.5) convert.36 = f32[16,16,256,256]{3,2,1,0} convert(select.13) constant.9 = f32[] constant(-inf) @@ -643,11 +762,14 @@ ENTRY main.41 { convert.49 = bf16[16,16,256,256]{3,2,1,0} convert(divide.36) Arg_2.3 = bf16[16,16,256,64]{3,2,1,0} parameter(2) ROOT dot.37 = bf16[16,16,256,64]{3,2,1,0} dot(convert.49, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -} - -)"; +})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern)); CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), GetCudnnVersion()}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); @@ -663,18 +785,24 @@ ENTRY main.41 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 3.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 5); +} -#if GOOGLE_CUDA && CUDNN_VERSION >= 8800 - // run whole pipeline +TEST_F(CudnnFusedMhaRewriterPipelineTest, + BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; TF_ASSERT_OK_AND_ASSIGN( - m, ParseAndReturnVerifiedModule(module_str, GetModuleConfig())); + auto m, ParseAndReturnVerifiedModule( + hlo_BF16Bmm1ScaleBiasNonConstantMaskSoftmaxBmm2Pattern, + GetModuleConfig())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(std::move(m))); + const HloInstruction* fmha; SCOPED_TRACE(optimized_module->ToString()); EXPECT_THAT( @@ -684,15 +812,16 @@ ENTRY main.41 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasMaskSoftmaxCallTarget}), 0) .WithShape(BF16, {16, 16, 256, 64}))); - TF_ASSERT_OK_AND_ASSIGN(config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 3.1); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 5); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 8800 } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1CombinedMaskBiasSoftmaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}} @@ -767,13 +896,14 @@ ENTRY main.61 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0))) .WithShape(BF16, {16, 256, 16, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); EXPECT_EQ(fmha->operands().size(), 4); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1ScaleBiasMaskSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} @@ -873,14 +1003,16 @@ ENTRY main.83 { {kCudnnfMHAScaleBiasMaskSoftmaxDropoutCallTarget}), 0) .WithShape(F16, {2, 6, 40, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 2); EXPECT_NEAR(config.dropout_rate(), 0.2, 1e-2); EXPECT_EQ(fmha->operands().size(), 5); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->f16[2,6,40,64]{3,2,1,0}} @@ -927,8 +1059,9 @@ ENTRY main.31 { GmockMatch(m::GetTupleElement( m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0) .WithShape(F16, {2, 6, 40, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 1.0); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 3); @@ -936,6 +1069,7 @@ ENTRY main.31 { TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1UnfusedSoftmaxWithConvertF32ToReduceMaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[128,6,400,64]{3,2,1,0},f16[128,6,64,400]{3,2,1,0},f16[128,6,400,64]{3,2,1,0})->f16[128,6,400,64]{3,2,1,0}} @@ -995,8 +1129,9 @@ ENTRY main.41 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasMaskSoftmaxCallTarget}), 0) .WithShape(F16, {128, 6, 400, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_FLOAT_EQ(config.fmha_scale(), 2.0); EXPECT_FLOAT_EQ(config.dropout_rate(), 0.0); EXPECT_EQ(fmha->operands().size(), 5); @@ -1004,6 +1139,7 @@ ENTRY main.41 { TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1UnfusedScaleMaskBiasSoftmaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}} @@ -1080,14 +1216,16 @@ ENTRY main.61 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0))) .WithShape(BF16, {16, 256, 16, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 4); EXPECT_FLOAT_EQ(config.fmha_scale(), 2.0); } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1ConvertedMaskAddedAfterFirstGemmSoftmaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},pred[16,1,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}} @@ -1156,14 +1294,15 @@ ENTRY main.56 { m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0))) .WithShape(BF16, {16, 256, 16, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); EXPECT_EQ(fmha->operands().size(), 4); } // negative test TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2Pattern_bmm1_contracting_dim_not_equal_64) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,256,64]{3,2,1,0}} ENTRY main.6 { @@ -1191,6 +1330,7 @@ ENTRY main.6 { TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2Pattern_bmm1_non_contracting_dim_larger_than_512) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,1024,64]{3,2,1,0},bf16[16,16,1024,64]{3,2,1,0},bf16[16,16,1024,64]{3,2,1,0})->bf16[16,16,1024,64]{3,2,1,0}} ENTRY main.6 { @@ -1217,6 +1357,7 @@ ENTRY main.6 { TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2Pattern_bmm2_rhs_non_contracting_dim_not_equal_64) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0})->bf16[16,16,256,32]{3,2,1,0}} ENTRY main.6 { @@ -1244,6 +1385,7 @@ ENTRY main.6 { // check if MHA is unsupported, canonicalization will not kick in TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1Bmm2PatternUncanonicalized_bmm1_contracting_dim_not_equal_64) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule fmha_test, entry_computation_layout={(bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,32]{3,2,1,0},bf16[16,16,256,64]{3,2,1,0})->bf16[16,16,64,256]{3,2,1,0}} @@ -1271,6 +1413,7 @@ ENTRY main.6 { } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0})->bf16[16,256,16,64]{3,2,1,0}} @@ -1374,14 +1517,16 @@ ENTRY main.82 { &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}), 0)))) .WithShape(BF16, {16, 256, 16, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 4); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1ScaleBiasSoftmaxDropoutForm2Bmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0},bf16[32,40,60,64]{3,2,1,0})->bf16[32,40,60,64]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} @@ -1483,13 +1628,15 @@ ENTRY main.79 { &fmha, {kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget}), 0)))) .WithShape(BF16, {32, 40, 60, 64}))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); EXPECT_EQ(fmha->operands().size(), 4); } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1Bmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0})} @@ -1543,6 +1690,7 @@ ENTRY main.17 { TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[1,16,256,256]{3,2,1,0})} @@ -1727,14 +1875,16 @@ ENTRY main.146 { .WithShape(BF16, {1, 16, 256, 256})), 0)), m::Op(), m::Op(), m::Op(), m::Op()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 5); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[16,256,16,64]{3,2,1,0},f16[16,256,16,64]{3,2,1,0},f16[16,256,16,64]{3,2,1,0},f16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0},f16[16,256,16,64]{3,2,1,0})->(f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[1,16,256,256]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true} @@ -1917,14 +2067,16 @@ ENTRY main.146 { .WithShape(F16, {1, 16, 256, 256})), 0)), m::Op(), m::Op(), m::Op(), m::Op()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 5); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2WithTransposeFusion) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[16,256,16,64]{3,2,1,0},f16[16,256,16,64]{3,2,1,0},f16[16,256,16,64]{3,2,1,0},f16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0},f16[16,256,16,64]{3,2,1,0})->(f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[16,256,16,64]{3,2,1,0}, f16[1,16,256,256]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true} @@ -2129,13 +2281,15 @@ ENTRY main.146 { m::CustomCall( {kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget}), dbias_index)))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 5); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16MiniT5xTest) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__lambda_, entry_computation_layout={(bf16[12,512,32,64]{3,2,1,0},bf16[12,512,2,32,64]{4,3,2,1,0},f32[12,512]{1,0},f32[12,512]{1,0})->(bf16[], bf16[12,512,32,64]{3,2,1,0}, bf16[12,512,2,32,64]{4,3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true} @@ -2285,6 +2439,7 @@ ENTRY main.129 { TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,128,64]{3,2,1,0},bf16[2,6,64,128]{3,2,1,0},bf16[2,6,128,64]{3,2,1,0},bf16[2,6,128,64]{3,2,1,0})->(bf16[2,6,128,64]{3,2,1,0}, bf16[2,6,128,64]{3,2,1,0}, bf16[2,6,64,128]{3,2,1,0}, bf16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -2427,14 +2582,16 @@ ENTRY main.126 { .WithShape(BF16, {2, 6, 64, 128}), m::GetTupleElement(m::CustomCall({backward_target}), 2) .WithShape(BF16, {2, 6, 128, 64})))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 6); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -2578,14 +2735,16 @@ ENTRY main.126 { .WithShape(F16, {2, 6, 64, 128}), m::GetTupleElement(m::CustomCall({backward_target}), 2) .WithShape(F16, {2, 6, 128, 64})))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 6); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm1ScaleBiasMaskSoftmaxBmm2) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -2693,14 +2852,16 @@ ENTRY main.82 { m::CustomCall({kCudnnfMHAScaleBiasMaskSoftmaxBackwardCallTarget}), 2) .WithShape(F16, {2, 6, 128, 64})))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 6); EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2DbiasShouldHaveUserShape) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[1,16,256,256]{3,2,1,0})} @@ -2887,14 +3048,16 @@ ENTRY main.146 { .WithShape(BF16, {16, 256, 256})))), 0)), m::Op(), m::Op(), m::Op(), m::Op()))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 5); EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2); } TEST_F(CudnnFusedMhaRewriterTestHloTest, ActivationHasMoreThan1UserShouldNotLower) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule test @@ -2953,6 +3116,7 @@ ENTRY main { TEST_F(CudnnFusedMhaRewriterTestHloTest, F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxBmm2ShouldNotBeLowered) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -3046,6 +3210,7 @@ ENTRY main.82 { TEST_F(CudnnFusedMhaRewriterTestHloTest, F16InvalidTrainingBmm1ScaleBiasMaskSoftmaxDropoutBmm2ShouldNotLower) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,128,64]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -3182,6 +3347,7 @@ ENTRY main.126 { TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm1ScaleBiasSoftmaxBmm2QTranspose) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} @@ -3282,12 +3448,1790 @@ ENTRY main.82 { m::GetTupleElement( m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2) .WithShape(F16, {2, 6, 128, 64})))); - TF_ASSERT_OK_AND_ASSIGN(auto config, - fmha->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); EXPECT_EQ(fmha->operands().size(), 5); EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); } +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16Bmm1UnfusedSoftmaxBmm2IncorrectBmm1NumUsers) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})} + +region_0.7 { + Arg_0.8 = f16[] parameter(0) + Arg_1.9 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9) +} + +region_1.19 { + Arg_0.20 = f32[] parameter(0) + Arg_1.21 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.20, Arg_1.21) +} + +ENTRY main.31 { + Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated} + dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} + // extra user of bmm1 + neg.1 = f16[2,6,40,40]{3,2,1,0} negate(dot) + constant = f16[] constant(-inf) + reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7 + broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2} + subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3) + exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1) + convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1) + constant.1 = f32[] constant(0) + reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19 + convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23) + broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2} + divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4) + Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} + ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), + GetCudnnVersion()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Negate()))); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16Bmm1UnfusedSoftmaxBmm2IncorrectSoftmaxNumUsers) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,40,64]{3,2,1,0},f16[2,6,64,40]{3,2,1,0},f16[2,6,40,64]{3,2,1,0})->(f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0})} + +region_0.7 { + Arg_0.8 = f16[] parameter(0) + Arg_1.9 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.8, Arg_1.9) +} + +region_1.19 { + Arg_0.20 = f32[] parameter(0) + Arg_1.21 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.20, Arg_1.21) +} + +ENTRY main.31 { + Arg_0.1 = f16[2,6,40,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,40]{3,2,1,0} parameter(1), sharding={replicated} + dot = f16[2,6,40,40]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} + constant = f16[] constant(-inf) + reduce.11 = f16[2,6,40]{2,1,0} reduce(dot, constant), dimensions={3}, to_apply=region_0.7 + broadcast.3 = f16[2,6,40,40]{3,2,1,0} broadcast(reduce.11), dimensions={0,1,2} + subtract.1 = f16[2,6,40,40]{3,2,1,0} subtract(dot, broadcast.3) + // extra user of softmax sub node + neg.1 = f16[2,6,40,40]{3,2,1,0} negate(subtract.1) + exponential.1 = f16[2,6,40,40]{3,2,1,0} exponential(subtract.1) + convert.1 = f32[2,6,40,40]{3,2,1,0} convert(exponential.1) + constant.1 = f32[] constant(0) + reduce.23 = f32[2,6,40]{2,1,0} reduce(convert.1, constant.1), dimensions={3}, to_apply=region_1.19 + convert.2 = f16[2,6,40]{2,1,0} convert(reduce.23) + broadcast.4 = f16[2,6,40,40]{3,2,1,0} broadcast(convert.2), dimensions={0,1,2} + divide = f16[2,6,40,40]{3,2,1,0} divide(exponential.1, broadcast.4) + Arg_2.3 = f16[2,6,40,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.1 = f16[2,6,40,64]{3,2,1,0} dot(divide, Arg_2.3), lhs_contracting_dims={3}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} + ROOT tuple.81 = (f16[2,6,40,64]{3,2,1,0}, f16[2,6,40,40]{3,2,1,0}) tuple(dot.1, neg.1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), + GetCudnnVersion()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Negate()))); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16TrainingBmm1ScaleBiasSoftmaxBmm2IncorrectSoftmaxBwdNumUsers) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = f16[] parameter(0) + Arg_1.23 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = f16[] parameter(0) + Arg_1.57 = f16[] parameter(1) + ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = f16[] constant(2) + broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = f16[] constant(1) + broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={} + add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13) + constant.21 = f16[] constant(0) + constant.15 = f16[] constant(-inf) + reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17) + exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37) + broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26) + // extra user of softmax bwd divide node + neg.1 = f16[2,6,128,128]{3,2,1,0} negate(divide.4) + broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59) + broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24) + dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(), + m::Negate()))); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, F16Bmm1SoftmaxBmm2IncorrectRank) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule reproducer, entry_computation_layout={(f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f16[1,8,16,5,128]{4,3,2,1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, /*index=5*/f32[128,2,64]{2,1,0}, f32[2,64]{1,0}, f32[128,2,64]{2,1,0}, f32[2,64]{1,0})->f16[8,16,2,5,64]{4,3,2,1,0}} + +region_0.36 { + Arg_0.37 = f16[] parameter(0) + Arg_1.38 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.37, Arg_1.38) +} + +region_1.48 { + Arg_0.49 = f32[] parameter(0) + Arg_1.50 = f32[] parameter(1) + ROOT add.1 = f32[] add(Arg_0.49, Arg_1.50) +} + +ENTRY main { + arg2.3 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(2), parameter_replication={false} + bitcast.31 = f16[640,128]{1,0} bitcast(arg2.3) + arg5.6 = f32[128,2,64]{2,1,0} parameter(5), parameter_replication={false} + convert.3 = f16[128,2,64]{2,1,0} convert(arg5.6) + bitcast.36 = f16[128,128]{1,0} bitcast(convert.3) + dot = f16[640,128]{1,0} dot(bitcast.31, bitcast.36), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} + bitcast.39 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot) + transpose.27 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.39), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"} + arg6.7 = f32[2,64]{1,0} parameter(6), parameter_replication={false} + convert.4 = f16[2,64]{1,0} convert(arg6.7) + broadcast.9 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.4), dimensions={3,5} + add.2 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.27, broadcast.9) + bitcast.49 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.2) + arg0.1 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(0), parameter_replication={false} + bitcast.53 = f16[640,128]{1,0} bitcast(arg0.1) + arg3.4 = f32[128,2,64]{2,1,0} parameter(3), parameter_replication={false} + convert.5 = f16[128,2,64]{2,1,0} convert(arg3.4) + bitcast.58 = f16[128,128]{1,0} bitcast(convert.5) + dot.1 = f16[640,128]{1,0} dot(bitcast.53, bitcast.58), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} + bitcast.61 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.1) + transpose.28 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} transpose(bitcast.61), dimensions={0,1,2,4,5,3}, frontend_attributes={grad_x="false",grad_y="false"} + arg4.5 = f32[2,64]{1,0} parameter(4), parameter_replication={false} + convert.6 = f16[2,64]{1,0} convert(arg4.5) + broadcast.10 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(convert.6), dimensions={3,4} + add.3 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} add(transpose.28, broadcast.10) + constant.29 = f16[] constant(0.125) + broadcast.11 = f16[1,8,16,2,64,5]{5,4,3,2,1,0} broadcast(constant.29), dimensions={} + multiply = f16[1,8,16,2,64,5]{5,4,3,2,1,0} multiply(add.3, broadcast.11) + bitcast.74 = f16[8,16,2,64,5]{4,3,2,1,0} bitcast(multiply) + dot.6 = f16[8,16,2,5,5]{4,3,2,1,0} dot(bitcast.49, bitcast.74), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"} + constant.35 = f16[] constant(-inf) + reduce.1 = f16[8,16,2,5]{3,2,1,0} reduce(dot.6, constant.35), dimensions={3}, to_apply=region_0.36 + broadcast.12 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(reduce.1), dimensions={0,1,2,4} + subtract.2 = f16[8,16,2,5,5]{4,3,2,1,0} subtract(dot.6, broadcast.12) + exponential.2 = f16[8,16,2,5,5]{4,3,2,1,0} exponential(subtract.2) + convert.7 = f32[8,16,2,5,5]{4,3,2,1,0} convert(exponential.2) + constant.34 = f32[] constant(0) + reduce.3 = f32[8,16,2,5]{3,2,1,0} reduce(convert.7, constant.34), dimensions={3}, to_apply=region_1.48 + convert.8 = f16[8,16,2,5]{3,2,1,0} convert(reduce.3) + broadcast.13 = f16[8,16,2,5,5]{4,3,2,1,0} broadcast(convert.8), dimensions={0,1,2,4} + divide.2 = f16[8,16,2,5,5]{4,3,2,1,0} divide(exponential.2, broadcast.13) + bitcast.98 = f16[8,16,2,5,5]{3,4,2,1,0} bitcast(divide.2) + arg1.2 = f16[1,8,16,5,128]{4,3,2,1,0} parameter(1), parameter_replication={false} + bitcast.102 = f16[640,128]{1,0} bitcast(arg1.2) + arg7.8 = f32[128,2,64]{2,1,0} parameter(7), parameter_replication={false} + convert.9 = f16[128,2,64]{2,1,0} convert(arg7.8) + bitcast.107 = f16[128,128]{1,0} bitcast(convert.9) + dot.3 = f16[640,128]{1,0} dot(bitcast.102, bitcast.107), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} + bitcast.110 = f16[1,8,16,5,2,64]{5,4,3,2,1,0} bitcast(dot.3) + transpose.30 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} transpose(bitcast.110), dimensions={0,1,2,4,3,5}, frontend_attributes={grad_x="false",grad_y="false"} + arg8.9 = f32[2,64]{1,0} parameter(8), parameter_replication={false} + convert.10 = f16[2,64]{1,0} convert(arg8.9) + broadcast.14 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} broadcast(convert.10), dimensions={3,5} + add.4 = f16[1,8,16,2,5,64]{5,4,3,2,1,0} add(transpose.30, broadcast.14) + bitcast.120 = f16[8,16,2,5,64]{4,3,2,1,0} bitcast(add.4) + ROOT dot.7 = f16[8,16,2,5,64]{4,3,2,1,0} dot(bitcast.98, bitcast.120), lhs_batch_dims={0,1,2}, lhs_contracting_dims={4}, rhs_batch_dims={0,1,2}, rhs_contracting_dims={3}, frontend_attributes={grad_x="false",grad_y="false"} +} // main +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + const auto status_or = RunHloPass(&fusedMhaRewriter, m.get()); + TF_ASSERT_OK(status_or.status()); + EXPECT_FALSE(status_or.value()); + + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot())); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + F16TrainingBmm1ScaleBiasSoftmaxBmm2NonContractingDimNotDivisibleBy64) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,100]{3,2,1,0},f16[2,6,64,100]{3,2,1,0},f16[2,6,100,64]{3,2,1,0},f16[2,6,100,64]{3,2,1,0})->(f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = f16[] parameter(0) + Arg_1.23 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = f16[] parameter(0) + Arg_1.57 = f16[] parameter(1) + ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + Arg_0.1 = f16[2,6,64,100]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,100]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = f16[2,6,100,100]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = f16[] constant(2) + broadcast.24 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = f16[2,6,100,100]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = f16[] constant(1) + broadcast.13 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.19), dimensions={} + add.3 = f16[2,6,100,100]{3,2,1,0} add(multiply.2, broadcast.13) + constant.21 = f16[] constant(0) + constant.15 = f16[] constant(-inf) + reduce.25 = f16[2,6,100]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = f16[2,6,100,100]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = f16[2,6,100,100]{3,2,1,0} subtract(add.3, broadcast.17) + exponential.1 = f16[2,6,100,100]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,100,100]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,100]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = f16[2,6,100]{2,1,0} convert(reduce.37) + broadcast.26 = f16[2,6,100,100]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,100,100]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = f16[2,6,100,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = f16[2,6,100,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = f16[2,6,100,100]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = f16[2,6,100,100]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = f16[2,6,100]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = f16[2,6,100]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,100]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = f16[2,6,100,100]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = f16[2,6,100,100]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = f16[2,6,100]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = f16[2,6,100]{2,1,0} negate(reduce.59) + broadcast.25 = f16[2,6,100,100]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = f16[2,6,100,100]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = f16[2,6,100,100]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.8, broadcast.24) + dot.80 = f16[2,6,100,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,100]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + const auto status_or = RunHloPass(&fusedMhaRewriter, m.get()); + TF_ASSERT_OK(status_or.status()); + EXPECT_FALSE(status_or.value()); + + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot()))); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,128]{3,2,1,0},f16[2,6,64,128]{3,2,1,0},f16[2,6,128,64]{3,2,1,0},f16[2,6,128,64]{3,2,1,0})->(f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = f16[] parameter(0) + Arg_1.23 = f16[] parameter(1) + ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = f16[] parameter(0) + Arg_1.57 = f16[] parameter(1) + ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + Arg_0.1 = f16[2,6,64,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = f16[2,6,64,128]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = f16[2,6,128,128]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = f16[] constant(2) + broadcast.24 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = f16[2,6,128,128]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = f16[] constant(1) + broadcast.13 = f16[2,6,128,128]{3,2,1,0} broadcast(constant.19), dimensions={} + add.3 = f16[2,6,128,128]{3,2,1,0} add(multiply.2, broadcast.13) + constant.21 = f16[] constant(0) + constant.15 = f16[] constant(-inf) + reduce.25 = f16[2,6,128]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = f16[2,6,128,128]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = f16[2,6,128,128]{3,2,1,0} subtract(add.3, broadcast.17) + exponential.1 = f16[2,6,128,128]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,128,128]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,128]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = f16[2,6,128]{2,1,0} convert(reduce.37) + broadcast.26 = f16[2,6,128,128]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = f16[2,6,128,128]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = f16[2,6,128,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = f16[2,6,128,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = f16[2,6,128,128]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = f16[2,6,128,128]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = f16[2,6,128]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = f16[2,6,128]{2,1,0} multiply(convert.9, convert.9) + divide.3 = f16[2,6,128]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = f16[2,6,128,128]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = f16[2,6,128,128]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = f16[2,6,128]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = f16[2,6,128]{2,1,0} negate(reduce.59) + broadcast.25 = f16[2,6,128,128]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = f16[2,6,128,128]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = f16[2,6,128,128]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = f16[2,6,128,128]{3,2,1,0} multiply(multiply.8, broadcast.24) + dot.80 = f16[2,6,128,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = f16[2,6,64,128]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + // add another user of ds multiply.9 here, neg.1 should not be pattern matched as bmm2grad1 + neg.1 = f16[2,6,128,128]{3,2,1,0} negate(multiply.9) + dot.1 = f16[2,6,128,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,64,128]{3,2,1,0}, f16[2,6,128,64]{3,2,1,0}, f16[2,6,128,128]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1, neg.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot(), + m::Negate()))); +} + +// flash attention +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1CausalMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + iota.2 = s32[2048,2048]{1,0} iota(), iota_dimension=0 + iota.5 = s32[2048,2048]{1,0} iota(), iota_dimension=1 + compare.1 = pred[2048,2048]{1,0} compare(iota.2, iota.5), direction=LT + constant.6 = bf16[] constant(-2.366e+38) + broadcast.16 = bf16[2048,2048]{1,0} broadcast(constant.6), dimensions={} + constant.16 = bf16[] constant(0) + broadcast.17 = bf16[2048,2048]{1,0} broadcast(constant.16), dimensions={} + select.2 = bf16[2048,2048]{1,0} select(compare.1, broadcast.16, broadcast.17) + broadcast.19 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(select.2), dimensions={2,3} + add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, broadcast.19) + constant.10 = bf16[] constant(-inf) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32 + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fwd_fmha; + const HloInstruction* bwd_fmha; + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fwd_fmha, {kCudnnfMHASoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&bwd_fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), + 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fwd_fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fwd_fmha->operands().size(), 3); + EXPECT_EQ(bwd_fmha->operands().size(), 6); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), true); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1BiasSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + // bias + Arg_4.5 = bf16[2,6,2048,2048]{3,2,1,0} parameter(4), sharding={replicated} + add.3 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.2, Arg_4.5) + constant.10 = bf16[] constant(-inf) + constant.16 = bf16[] constant(0) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(add.3, constant.10), dimensions={3}, to_apply=region_0.32 + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.3, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleBiasSoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&fmha, + {kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), + 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fmha->operands().size(), 7); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1SoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->(bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} +region_0.32 { + Arg_0.33 = bf16[] parameter(0) + Arg_1.34 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.33, Arg_1.34) +} +region_1.44 { + Arg_0.45 = f32[] parameter(0) + Arg_1.46 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.45, Arg_1.46) +} +region_2.66 { + Arg_0.67 = bf16[] parameter(0) + Arg_1.68 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.67, Arg_1.68) +} +ENTRY main.92 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.14 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(2) + broadcast.29 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.14, broadcast.29) + constant.10 = bf16[] constant(-inf) + constant.16 = bf16[] constant(0) + reduce.36 = bf16[2,6,2048]{2,1,0} reduce(multiply.2, constant.10), dimensions={3}, to_apply=region_0.32 + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.36), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(multiply.2, broadcast.21) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.14 = f32[] constant(0) + reduce.48 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.14), dimensions={3}, to_apply=region_1.44 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.48) + broadcast.32 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.32) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + dot.57 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,128]{3,2,1,0} parameter(3), sharding={replicated} + dot.60 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.60, broadcast.32) + constant.15 = bf16[] constant(1) + broadcast.25 = bf16[2,6,2048]{2,1,0} broadcast(constant.15), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.25, multiply.3) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.60, broadcast.26) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.70 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.16), dimensions={3}, to_apply=region_2.66 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.70) + broadcast.31 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.31) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.8, broadcast.29) + dot.90 = bf16[2,6,2048,128]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,128,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.91 = (bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}, bf16[2,6,128,2048]{3,2,1,0}, bf16[2,6,2048,128]{3,2,1,0}) tuple(dot.57, dot.90, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHASoftmaxBackwardCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 128}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 1)) + .WithShape(BF16, {2, 6, 128, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHASoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 128})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fmha->operands().size(), 6); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_FLOAT_EQ(config.fmha_scale(), 2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + FlashAttentionBF16TrainingBmm1ScaleMaskSoftmaxBmm2Pattern) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,64]{3,2,1,0},bf16[2,6,64,2048]{3,2,1,0},bf16[2,6,2048,64]{3,2,1,0},bf16[2,6,2048,64]{3,2,1,0})->(bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,64,2048]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + +region_0.21 { + Arg_0.22 = bf16[] parameter(0) + Arg_1.23 = bf16[] parameter(1) + ROOT maximum = bf16[] maximum(Arg_0.22, Arg_1.23) +} + +region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.34, Arg_1.35) +} + +region_2.55 { + Arg_0.56 = bf16[] parameter(0) + Arg_1.57 = bf16[] parameter(1) + ROOT add.1 = bf16[] add(Arg_0.56, Arg_1.57) +} + +ENTRY main.82 { + constant.18 = pred[2,6,2048,2048]{3,2,1,0} constant({...}) + Arg_0.1 = bf16[2,6,2048,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.17 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.22 = bf16[] constant(2) + broadcast.24 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.22), dimensions={} + multiply.2 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.17, broadcast.24) + constant.19 = bf16[] constant(1) + constant.21 = bf16[] constant(0) + broadcast.23 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.21), dimensions={} + select.1 = bf16[2,6,2048,2048]{3,2,1,0} select(constant.18, multiply.2, broadcast.23) + constant.15 = bf16[] constant(-inf) + reduce.25 = bf16[2,6,2048]{2,1,0} reduce(select.1, constant.15), dimensions={3}, to_apply=region_0.21 + broadcast.17 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} + subtract.1 = bf16[2,6,2048,2048]{3,2,1,0} subtract(select.1, broadcast.17) + exponential.1 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.1) + convert.5 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.1) + constant.17 = f32[] constant(0) + reduce.37 = f32[2,6,2048]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 + convert.9 = bf16[2,6,2048]{2,1,0} convert(reduce.37) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} + divide.5 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.1, broadcast.26) + Arg_2.3 = bf16[2,6,2048,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.46 = bf16[2,6,2048,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,2048,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.49 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + divide.4 = bf16[2,6,2048,2048]{3,2,1,0} divide(dot.49, broadcast.26) + broadcast.20 = bf16[2,6,2048]{2,1,0} broadcast(constant.19), dimensions={} + multiply.3 = bf16[2,6,2048]{2,1,0} multiply(convert.9, convert.9) + divide.3 = bf16[2,6,2048]{2,1,0} divide(broadcast.20, multiply.3) + broadcast.21 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} + multiply.4 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.49, broadcast.21) + multiply.5 = bf16[2,6,2048,2048]{3,2,1,0} multiply(multiply.4, exponential.1) + reduce.59 = bf16[2,6,2048]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 + negate.2 = bf16[2,6,2048]{2,1,0} negate(reduce.59) + broadcast.25 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} + add.5 = bf16[2,6,2048,2048]{3,2,1,0} add(divide.4, broadcast.25) + multiply.8 = bf16[2,6,2048,2048]{3,2,1,0} multiply(add.5, exponential.1) + select.3 = bf16[2,6,2048,2048]{3,2,1,0} select(constant.18, multiply.8, broadcast.23) + multiply.9 = bf16[2,6,2048,2048]{3,2,1,0} multiply(select.3, broadcast.24) + dot.80 = bf16[2,6,2048,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot = bf16[2,6,64,2048]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + dot.1 = bf16[2,6,2048,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT tuple.81 = (bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}, bf16[2,6,64,2048]{3,2,1,0}, bf16[2,6,2048,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + HloDCE dce; + TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleMaskSoftmaxCallTarget}), 0) + .WithShape(BF16, {2, 6, 2048, 64}), + m::GetTupleElement( + m::CustomCall(&fmha, + {kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), + 0) + .WithShape(BF16, {2, 6, 2048, 64}), + m::Transpose( + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), + 1)) + .WithShape(BF16, {2, 6, 64, 2048}), + m::GetTupleElement( + m::CustomCall({kCudnnfMHAScaleMaskSoftmaxBackwardCallTarget}), 2) + .WithShape(BF16, {2, 6, 2048, 64})))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(fmha->operands().size(), 7); + EXPECT_NEAR(config.dropout_rate(), 0, 1e-2); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), false); +} + +// GPT3 pattern +TEST_F(CudnnFusedMhaRewriterTestHloTest, FlashAttentionBF16TrainingGPT3_5B) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule jit__unnamed_wrapped_function_, entry_computation_layout={((s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}))->(s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0})} +add { + x = bf16[] parameter(0) + y = bf16[] parameter(1) + ROOT add.580 = bf16[] add(x, y) +} + +region_20.962 { + Arg_0.963 = f32[] parameter(0) + Arg_1.964 = f32[] parameter(1) + ROOT add.579 = f32[] add(Arg_0.963, Arg_1.964) +} + +region_39.1120 { + Arg_0.1121 = f32[] parameter(0) + Arg_1.1122 = f32[] parameter(1) + ROOT maximum.21 = f32[] maximum(Arg_0.1121, Arg_1.1122) +} + +main { + param.3 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) parameter(0) + get-tuple-element.31 = s32[] get-tuple-element(param.3), index=0 + constant.1961 = s32[] constant(1) + add.581 = s32[] add(get-tuple-element.31, constant.1961) + get-tuple-element.32 = bf16[24,32,2048,2048]{2,1,3,0} get-tuple-element(param.3), index=25 + bitcast.187 = bf16[24,2048,32,2048]{3,2,1,0} bitcast(get-tuple-element.32) + constant.1977 = s32[] constant(23) + subtract.221 = s32[] subtract(constant.1977, get-tuple-element.31) + constant.1980 = s32[] constant(0) + compare.210 = pred[] compare(subtract.221, constant.1980), direction=LT + constant.1979 = s32[] constant(47) + subtract.222 = s32[] subtract(constant.1979, get-tuple-element.31) + select.372 = s32[] select(compare.210, subtract.222, subtract.221) + dynamic-slice.324 = bf16[1,2048,32,2048]{3,2,1,0} dynamic-slice(bitcast.187, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,2048,32,2048} + bitcast.756 = bf16[2048,32,2048]{2,1,0} bitcast(dynamic-slice.324) + convert.282 = f32[2048,32,2048]{2,1,0} convert(bitcast.756) + constant.1991 = bf16[] constant(1) + broadcast.1270 = bf16[32,2048]{1,0} broadcast(constant.1991), dimensions={} + get-tuple-element.33 = bf16[32,2048]{1,0} get-tuple-element(param.3), index=27 + subtract.229 = bf16[32,2048]{1,0} subtract(broadcast.1270, get-tuple-element.33) + convert.285 = f32[32,2048]{1,0} convert(subtract.229) + broadcast.1228 = f32[2048,32,2048]{2,1,0} broadcast(convert.285), dimensions={1,2} + multiply.656 = f32[2048,32,2048]{2,1,0} multiply(convert.282, broadcast.1228) + bitcast.367 = f32[32,2048,2048]{1,0,2} bitcast(multiply.656) + constant.1968 = f32[] constant(0) + reduce.84 = f32[] reduce(bitcast.367, constant.1968), dimensions={0,1,2}, to_apply=region_20.962 + all-reduce.230 = f32[] all-reduce(reduce.84), channel_id=278, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962 + broadcast.1221 = f32[32,2048,4096]{2,1,0} broadcast(convert.285), dimensions={0,1} + reduce.85 = f32[] reduce(broadcast.1221, constant.1968), dimensions={0,1,2}, to_apply=region_20.962 + all-reduce.14 = f32[] all-reduce(reduce.85), channel_id=49, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=region_20.962 + constant.2005 = f32[] constant(1) + maximum.24 = f32[] maximum(all-reduce.14, constant.2005) + divide.96 = f32[] divide(all-reduce.230, maximum.24) + broadcast.1223 = f32[2048,32,2048]{2,1,0} broadcast(divide.96), dimensions={} + subtract.219 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1223) + multiply.644 = f32[2048,32,2048]{2,1,0} multiply(subtract.219, broadcast.1228) + multiply.645 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, multiply.644) + bitcast.271 = f32[32,2048,2048]{1,0,2} bitcast(multiply.645) + reduce.86 = f32[] reduce(bitcast.271, constant.1968), dimensions={0,1,2}, to_apply=region_20.962 + all-reduce.231 = f32[] all-reduce(reduce.86), channel_id=279, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962 + divide.99 = f32[] divide(all-reduce.231, maximum.24) + rsqrt.16 = f32[] rsqrt(divide.99) + multiply.650 = f32[] multiply(rsqrt.16, constant.1968) + divide.100 = f32[] divide(multiply.650, maximum.24) + constant.1974 = f32[] constant(2) + multiply.652 = f32[] multiply(divide.100, constant.1974) + broadcast.1227 = f32[2048,32,2048]{2,1,0} broadcast(multiply.652), dimensions={} + multiply.653 = f32[2048,32,2048]{2,1,0} multiply(multiply.644, broadcast.1227) + multiply.654 = f32[2048,32,2048]{2,1,0} multiply(multiply.653, broadcast.1228) + negate.56 = f32[2048,32,2048]{2,1,0} negate(multiply.654) + bitcast.321 = f32[32,2048,2048]{1,0,2} bitcast(negate.56) + reduce.87 = f32[] reduce(bitcast.321, constant.1968), dimensions={0,1,2}, to_apply=region_20.962 + all-reduce.232 = f32[] all-reduce(reduce.87), channel_id=280, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962 + divide.101 = f32[] divide(all-reduce.232, maximum.24) + broadcast.1229 = f32[32,2048]{1,0} broadcast(divide.101), dimensions={} + multiply.655 = f32[32,2048]{1,0} multiply(broadcast.1229, convert.285) + broadcast.1230 = f32[2048,32,2048]{2,1,0} broadcast(multiply.655), dimensions={1,2} + add.582 = f32[2048,32,2048]{2,1,0} add(multiply.654, broadcast.1230) + broadcast.1236 = f32[2048,32,2048]{2,1,0} broadcast(constant.1968), dimensions={} + compare.208 = pred[2048,32,2048]{2,1,0} compare(multiply.656, broadcast.1236), direction=GE + abs.22 = f32[2048,32,2048]{2,1,0} abs(multiply.656) + bitcast.373 = f32[32,2048,2048]{1,0,2} bitcast(abs.22) + constant.1989 = f32[] constant(-inf) + reduce.88 = f32[] reduce(bitcast.373, constant.1989), dimensions={0,1,2}, to_apply=region_39.1120 + all-reduce.233 = f32[] all-reduce(reduce.88), channel_id=281, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_39.1120 + broadcast.1233 = f32[2048,32,2048]{2,1,0} broadcast(all-reduce.233), dimensions={} + compare.207 = pred[2048,32,2048]{2,1,0} compare(abs.22, broadcast.1233), direction=EQ + convert.286 = f32[2048,32,2048]{2,1,0} convert(compare.207) + bitcast.393 = f32[32,2048,2048]{1,0,2} bitcast(convert.286) + reduce.89 = f32[] reduce(bitcast.393, constant.1968), dimensions={0,1,2}, to_apply=region_20.962 + all-reduce.234 = f32[] all-reduce(reduce.89), channel_id=282, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=region_20.962 + divide.103 = f32[] divide(constant.1968, all-reduce.234) + broadcast.1238 = f32[2048,32,2048]{2,1,0} broadcast(divide.103), dimensions={} + select.370 = f32[2048,32,2048]{2,1,0} select(compare.207, broadcast.1238, broadcast.1236) + select.369 = f32[2048,32,2048]{2,1,0} select(compare.208, select.370, broadcast.1236) + constant.1976 = pred[] constant(false) + broadcast.1237 = pred[2048,32,2048]{2,1,0} broadcast(constant.1976), dimensions={} + compare.209 = pred[2048,32,2048]{2,1,0} compare(compare.208, broadcast.1237), direction=EQ + select.371 = f32[2048,32,2048]{2,1,0} select(compare.209, select.370, broadcast.1236) + negate.57 = f32[2048,32,2048]{2,1,0} negate(select.371) + add.583 = f32[2048,32,2048]{2,1,0} add(select.369, negate.57) + multiply.658 = f32[2048,32,2048]{2,1,0} multiply(add.583, broadcast.1228) + add.585 = f32[2048,32,2048]{2,1,0} add(add.582, multiply.658) + convert.287 = bf16[2048,32,2048]{2,1,0} convert(add.585) + get-tuple-element.34 = bf16[32,2048,2048]{1,0,2} get-tuple-element(param.3), index=1 + bitcast.1652 = bf16[2048,32,2048]{2,1,0} bitcast(get-tuple-element.34) + get-tuple-element.35 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=22 + bitcast.461 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.35) + dynamic-slice.325 = bf16[1,1024,3,16,128]{4,3,2,1,0} dynamic-slice(bitcast.461, select.372, constant.1980, constant.1980, constant.1980, /*index=5*/constant.1980), dynamic_slice_sizes={1,1024,3,16,128} + bitcast.485 = bf16[3,1024,16,128]{3,2,0,1} bitcast(dynamic-slice.325) + all-gather.7 = bf16[3,4096,16,128]{3,2,0,1} all-gather(bitcast.485), channel_id=60, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1420 = bf16[6144,4096]{0,1} bitcast(all-gather.7) + bitcast.500 = f32[32,2048,2048]{1,0,2} bitcast(convert.282) + reduce.90 = f32[32,2048]{1,0} reduce(bitcast.500, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.23 = f32[32,2048]{1,0} all-reduce(reduce.90), channel_id=58, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + constant.1983 = f32[] constant(0.000244140625) + broadcast.1243 = f32[32,2048]{1,0} broadcast(constant.1983), dimensions={} + multiply.660 = f32[32,2048]{1,0} multiply(all-reduce.23, broadcast.1243) + broadcast.1242 = f32[2048,32,2048]{2,1,0} broadcast(multiply.660), dimensions={1,2} + subtract.224 = f32[2048,32,2048]{2,1,0} subtract(convert.282, broadcast.1242) + multiply.661 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, subtract.224) + bitcast.527 = f32[32,2048,2048]{1,0,2} bitcast(multiply.661) + reduce.91 = f32[32,2048]{1,0} reduce(bitcast.527, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.24 = f32[32,2048]{1,0} all-reduce(reduce.91), channel_id=59, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + multiply.662 = f32[32,2048]{1,0} multiply(all-reduce.24, broadcast.1243) + constant.1990 = f32[] constant(1e-05) + broadcast.1264 = f32[32,2048]{1,0} broadcast(constant.1990), dimensions={} + add.587 = f32[32,2048]{1,0} add(multiply.662, broadcast.1264) + bitcast.1447 = f32[1,32,2048]{2,1,0} bitcast(add.587) + rsqrt.20 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1447) + bitcast.1892 = f32[32,2048]{1,0} bitcast(rsqrt.20) + broadcast.1337 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1892), dimensions={1,2} + multiply.754 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1337) + convert.314 = bf16[2048,32,2048]{2,1,0} convert(multiply.754) + get-tuple-element.36 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=20 + dynamic-slice.326 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.36, select.372, constant.1980), dynamic_slice_sizes={1,2048} + broadcast.1266 = bf16[1,2048]{1,0} broadcast(constant.1991), dimensions={} + add.588 = bf16[1,2048]{1,0} add(dynamic-slice.326, broadcast.1266) + bitcast.1992 = bf16[2048]{0} bitcast(add.588) + broadcast.1338 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1992), dimensions={0} + multiply.755 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, broadcast.1338) + get-tuple-element.37 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=19 + dynamic-slice.327 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.37, select.372, constant.1980), dynamic_slice_sizes={1,2048} + bitcast.1998 = bf16[2048]{0} bitcast(dynamic-slice.327) + broadcast.1339 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1998), dimensions={0} + add.640 = bf16[2048,32,2048]{2,1,0} add(multiply.755, broadcast.1339) + bitcast.2003 = bf16[32,2048,2048]{1,0,2} bitcast(add.640) + all-gather.8 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=61, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.597 = bf16[4096,65536]{1,0} bitcast(all-gather.8) + dot.42 = bf16[6144,65536]{1,0} dot(bitcast.1420, bitcast.597), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.623 = bf16[3,16,128,32,2048]{4,3,2,1,0} bitcast(dot.42) + transpose.112 = bf16[3,32,16,128,2048]{4,3,2,1,0} transpose(bitcast.623), dimensions={0,3,1,2,4} + get-tuple-element.38 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=21 + dynamic-slice.328 = bf16[1,3,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.38, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,3,16,128} + bitcast.626 = bf16[3,16,128]{2,1,0} bitcast(dynamic-slice.328) + broadcast.1250 = bf16[3,32,16,128,2048]{4,3,2,1,0} broadcast(bitcast.626), dimensions={0,2,3} + add.591 = bf16[3,32,16,128,2048]{4,3,2,1,0} add(transpose.112, broadcast.1250) + slice.87 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[2:3], [0:32], [0:16], [0:128], [0:2048]} + bitcast.1280 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.87) + slice.88 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[0:1], [0:32], [0:16], [0:128], [0:2048]} + constant.2007 = bf16[] constant(0.08838) + broadcast.1251 = bf16[1,32,16,128,2048]{4,3,2,1,0} broadcast(constant.2007), dimensions={} + multiply.666 = bf16[1,32,16,128,2048]{4,3,2,1,0} multiply(slice.88, broadcast.1251) + bitcast.1330 = bf16[32,16,128,2048]{3,2,1,0} bitcast(multiply.666) + transpose.113 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1330), dimensions={0,1,3,2} + slice.89 = bf16[1,32,16,128,2048]{4,3,2,1,0} slice(add.591), slice={[1:2], [0:32], [0:16], [0:128], [0:2048]} + bitcast.647 = bf16[32,16,128,2048]{3,2,1,0} bitcast(slice.89) + dot.43 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.113, bitcast.647), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + convert.291 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.43) + get-tuple-element.39 = bf16[32,1,2048,2048]{3,2,0,1} get-tuple-element(param.3), index=26 + bitcast.651 = bf16[1,32,2048,2048]{3,2,1,0} bitcast(get-tuple-element.39) + iota.38 = s32[2048,2048]{1,0} iota(), iota_dimension=0 + iota.39 = s32[2048,2048]{1,0} iota(), iota_dimension=1 + compare.211 = pred[2048,2048]{1,0} compare(iota.38, iota.39), direction=LT + constant.1987 = bf16[] constant(-2.366e+38) + broadcast.1252 = bf16[2048,2048]{1,0} broadcast(constant.1987), dimensions={} + constant.2006 = bf16[] constant(0) + broadcast.1253 = bf16[2048,2048]{1,0} broadcast(constant.2006), dimensions={} + select.373 = bf16[2048,2048]{1,0} select(compare.211, broadcast.1252, broadcast.1253) + broadcast.1254 = bf16[1,32,2048,2048]{3,2,1,0} broadcast(select.373), dimensions={2,3} + minimum.5 = bf16[1,32,2048,2048]{3,2,1,0} minimum(bitcast.651, broadcast.1254) + bitcast.673 = bf16[32,2048,2048]{2,1,0} bitcast(minimum.5) + convert.292 = f32[32,2048,2048]{2,1,0} convert(bitcast.673) + broadcast.1256 = f32[32,16,2048,2048]{3,2,1,0} broadcast(convert.292), dimensions={0,2,3} + add.593 = f32[32,16,2048,2048]{3,2,1,0} add(convert.291, broadcast.1256) + reduce.92 = f32[32,16,2048]{2,1,0} reduce(add.593, constant.1989), dimensions={3}, to_apply=region_39.1120 + broadcast.1258 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.92), dimensions={0,1,2} + subtract.226 = f32[32,16,2048,2048]{3,2,1,0} subtract(add.593, broadcast.1258) + exponential.8 = f32[32,16,2048,2048]{3,2,1,0} exponential(subtract.226) + reduce.93 = f32[32,16,2048]{2,1,0} reduce(exponential.8, constant.1968), dimensions={3}, to_apply=region_20.962 + broadcast.1309 = f32[32,16,2048,2048]{3,2,1,0} broadcast(reduce.93), dimensions={0,1,2} + divide.109 = f32[32,16,2048,2048]{3,2,1,0} divide(exponential.8, broadcast.1309) + convert.306 = bf16[32,16,2048,2048]{3,2,1,0} convert(divide.109) + dot.44 = bf16[32,16,128,2048]{3,2,1,0} dot(bitcast.1280, convert.306), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + transpose.116 = bf16[32,2048,16,128]{3,2,1,0} transpose(dot.44), dimensions={0,3,1,2} + bitcast.711 = bf16[65536,2048]{1,0} bitcast(transpose.116) + get-tuple-element.40 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=24 + dynamic-slice.329 = bf16[1,1024,16,128]{3,2,1,0} dynamic-slice(get-tuple-element.40, select.372, constant.1980, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,16,128} + bitcast.724 = bf16[1024,16,128]{2,1,0} bitcast(dynamic-slice.329) + all-gather.9 = bf16[4096,16,128]{2,1,0} all-gather(bitcast.724), channel_id=62, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true + bitcast.729 = bf16[2048,4096]{0,1} bitcast(all-gather.9) + dot.57 = bf16[65536,4096]{0,1} dot(bitcast.711, bitcast.729), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.733 = bf16[32,2048,4096]{1,0,2} bitcast(dot.57) + reduce-scatter = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.733), channel_id=322, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add + bitcast.763 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter) + get-tuple-element.41 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=23 + dynamic-slice.330 = bf16[1,1024]{1,0} dynamic-slice(get-tuple-element.41, select.372, constant.1980), dynamic_slice_sizes={1,1024} + bitcast.748 = bf16[1024]{0} bitcast(dynamic-slice.330) + collective-permute.1 = bf16[1024]{0} collective-permute(bitcast.748), channel_id=64, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.10 = bf16[2048]{0} all-gather(collective-permute.1), channel_id=65, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={0}, use_global_device_ids=true + broadcast.1261 = bf16[2048,32,2048]{2,1,0} broadcast(all-gather.10), dimensions={0} + add.596 = bf16[2048,32,2048]{2,1,0} add(bitcast.763, broadcast.1261) + add.597 = bf16[2048,32,2048]{2,1,0} add(add.596, bitcast.756) + convert.295 = f32[2048,32,2048]{2,1,0} convert(add.597) + bitcast.774 = f32[32,2048,2048]{1,0,2} bitcast(convert.295) + reduce.94 = f32[32,2048]{1,0} reduce(bitcast.774, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.26 = f32[32,2048]{1,0} all-reduce(reduce.94), channel_id=66, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + multiply.668 = f32[32,2048]{1,0} multiply(all-reduce.26, broadcast.1243) + broadcast.1263 = f32[2048,32,2048]{2,1,0} broadcast(multiply.668), dimensions={1,2} + subtract.228 = f32[2048,32,2048]{2,1,0} subtract(convert.295, broadcast.1263) + multiply.669 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, subtract.228) + bitcast.809 = f32[32,2048,2048]{1,0,2} bitcast(multiply.669) + reduce.95 = f32[32,2048]{1,0} reduce(bitcast.809, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.27 = f32[32,2048]{1,0} all-reduce(reduce.95), channel_id=67, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + multiply.670 = f32[32,2048]{1,0} multiply(all-reduce.27, broadcast.1243) + add.598 = f32[32,2048]{1,0} add(multiply.670, broadcast.1264) + bitcast.1148 = f32[1,32,2048]{2,1,0} bitcast(add.598) + rsqrt.19 = f32[1,32,2048]{2,1,0} rsqrt(bitcast.1148) + bitcast.1602 = f32[32,2048]{1,0} bitcast(rsqrt.19) + broadcast.1329 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1602), dimensions={1,2} + multiply.750 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1329) + convert.312 = bf16[2048,32,2048]{2,1,0} convert(multiply.750) + get-tuple-element.42 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=18 + dynamic-slice.331 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.42, select.372, constant.1980), dynamic_slice_sizes={1,2048} + add.599 = bf16[1,2048]{1,0} add(dynamic-slice.331, broadcast.1266) + bitcast.1609 = bf16[2048]{0} bitcast(add.599) + broadcast.1330 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1609), dimensions={0} + multiply.745 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, broadcast.1330) + get-tuple-element.43 = bf16[24,2048]{1,0} get-tuple-element(param.3), index=17 + dynamic-slice.332 = bf16[1,2048]{1,0} dynamic-slice(get-tuple-element.43, select.372, constant.1980), dynamic_slice_sizes={1,2048} + bitcast.1615 = bf16[2048]{0} bitcast(dynamic-slice.332) + broadcast.1331 = bf16[2048,32,2048]{2,1,0} broadcast(bitcast.1615), dimensions={0} + add.636 = bf16[2048,32,2048]{2,1,0} add(multiply.745, broadcast.1331) + bitcast.1620 = bf16[32,2048,2048]{1,0,2} bitcast(add.636) + all-gather.12 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=69, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.877 = bf16[65536,4096]{0,1} bitcast(all-gather.12) + get-tuple-element.44 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=15 + dynamic-slice.333 = bf16[1,1024,8192]{2,1,0} dynamic-slice(get-tuple-element.44, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192} + bitcast.890 = bf16[1024,8192]{1,0} bitcast(dynamic-slice.333) + all-gather.11 = bf16[4096,8192]{1,0} all-gather(bitcast.890), channel_id=68, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true + dot.45 = bf16[65536,8192]{1,0} dot(bitcast.877, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={0} + get-tuple-element.45 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=14 + dynamic-slice.334 = bf16[1,8192]{1,0} dynamic-slice(get-tuple-element.45, select.372, constant.1980), dynamic_slice_sizes={1,8192} + bitcast.906 = bf16[8192]{0} bitcast(dynamic-slice.334) + broadcast.1269 = bf16[65536,8192]{1,0} broadcast(bitcast.906), dimensions={1} + add.601 = bf16[65536,8192]{1,0} add(dot.45, broadcast.1269) + bitcast.997 = bf16[32,2048,8192]{2,1,0} bitcast(add.601) + broadcast.1333 = bf16[2048,32,2048]{2,1,0} broadcast(subtract.229), dimensions={1,2} + multiply.746 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1652, broadcast.1333) + bitcast.1739 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.746) + all-gather.14 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=71, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.934 = bf16[65536,4096]{0,1} bitcast(all-gather.14) + get-tuple-element.46 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=16 + bitcast.935 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.46) + dynamic-slice.335 = bf16[1,1024,8192]{2,1,0} dynamic-slice(bitcast.935, select.372, constant.1980, constant.1980), dynamic_slice_sizes={1,1024,8192} + bitcast.947 = bf16[8192,1024]{0,1} bitcast(dynamic-slice.335) + all-gather.13 = bf16[8192,4096]{0,1} all-gather(bitcast.947), channel_id=70, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={1}, use_global_device_ids=true + dot.46 = bf16[65536,8192]{1,0} dot(bitcast.934, all-gather.13), lhs_contracting_dims={1}, rhs_contracting_dims={1} + bitcast.1092 = bf16[32,2048,8192]{2,1,0} bitcast(dot.46) + broadcast.1335 = bf16[32,2048,8192]{2,1,0} broadcast(subtract.229), dimensions={0,1} + multiply.703 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.1092, broadcast.1335) + multiply.685 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.703) + constant.2002 = bf16[] constant(0.5) + broadcast.1288 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2002), dimensions={} + multiply.686 = bf16[32,2048,8192]{2,1,0} multiply(multiply.685, broadcast.1288) + broadcast.1287 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1991), dimensions={} + multiply.700 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, bitcast.997) + multiply.693 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.700) + constant.1998 = bf16[] constant(0.04468) + broadcast.1282 = bf16[32,2048,8192]{2,1,0} broadcast(constant.1998), dimensions={} + multiply.694 = bf16[32,2048,8192]{2,1,0} multiply(multiply.693, broadcast.1282) + add.605 = bf16[32,2048,8192]{2,1,0} add(bitcast.997, multiply.694) + constant.2010 = bf16[] constant(0.7969) + broadcast.1324 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2010), dimensions={} + multiply.695 = bf16[32,2048,8192]{2,1,0} multiply(add.605, broadcast.1324) + tanh.7 = bf16[32,2048,8192]{2,1,0} tanh(multiply.695) + subtract.231 = bf16[32,2048,8192]{2,1,0} subtract(broadcast.1287, tanh.7) + multiply.691 = bf16[32,2048,8192]{2,1,0} multiply(multiply.686, subtract.231) + multiply.737 = bf16[32,2048,8192]{2,1,0} multiply(multiply.691, tanh.7) + add.630 = bf16[32,2048,8192]{2,1,0} add(multiply.691, multiply.737) + multiply.738 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1324) + constant.2011 = bf16[] constant(0.03564) + broadcast.1326 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2011), dimensions={} + multiply.739 = bf16[32,2048,8192]{2,1,0} multiply(add.630, broadcast.1326) + constant.2012 = bf16[] constant(3) + broadcast.1327 = bf16[32,2048,8192]{2,1,0} broadcast(constant.2012), dimensions={} + multiply.740 = bf16[32,2048,8192]{2,1,0} multiply(multiply.700, broadcast.1327) + multiply.741 = bf16[32,2048,8192]{2,1,0} multiply(multiply.739, multiply.740) + add.632 = bf16[32,2048,8192]{2,1,0} add(multiply.738, multiply.741) + add.637 = bf16[32,2048,8192]{2,1,0} add(tanh.7, broadcast.1287) + multiply.747 = bf16[32,2048,8192]{2,1,0} multiply(add.637, broadcast.1288) + multiply.743 = bf16[32,2048,8192]{2,1,0} multiply(multiply.703, multiply.747) + add.635 = bf16[32,2048,8192]{2,1,0} add(add.632, multiply.743) + bitcast.1629 = bf16[65536,8192]{1,0} bitcast(add.635) + dot.47 = bf16[65536,4096]{0,1} dot(bitcast.1629, all-gather.11), lhs_contracting_dims={1}, rhs_contracting_dims={1} + bitcast.1130 = bf16[32,2048,4096]{1,0,2} bitcast(dot.47) + reduce-scatter.1 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1130), channel_id=323, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add + bitcast.1766 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.1) + multiply.712 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1766, broadcast.1330) + convert.299 = f32[2048,32,2048]{2,1,0} convert(multiply.712) + multiply.707 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, convert.299) + bitcast.1135 = f32[32,2048,2048]{1,0,2} bitcast(multiply.707) + reduce.96 = f32[32,2048]{1,0} reduce(bitcast.1135, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.29 = f32[32,2048]{1,0} all-reduce(reduce.96), channel_id=73, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + bitcast.1140 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.29) + divide.105 = f32[1,32,2048]{2,1,0} divide(rsqrt.19, bitcast.1148) + constant.2008 = f32[] constant(-0.5) + broadcast.1313 = f32[1,32,2048]{2,1,0} broadcast(constant.2008), dimensions={} + multiply.708 = f32[1,32,2048]{2,1,0} multiply(divide.105, broadcast.1313) + multiply.709 = f32[1,32,2048]{2,1,0} multiply(bitcast.1140, multiply.708) + constant.2009 = f32[] constant(0.00048828125) + broadcast.1315 = f32[1,32,2048]{2,1,0} broadcast(constant.2009), dimensions={} + multiply.710 = f32[1,32,2048]{2,1,0} multiply(multiply.709, broadcast.1315) + bitcast.1235 = f32[32,2048]{1,0} bitcast(multiply.710) + broadcast.1296 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1235), dimensions={1,2} + multiply.717 = f32[2048,32,2048]{2,1,0} multiply(subtract.228, broadcast.1296) + multiply.718 = f32[2048,32,2048]{2,1,0} multiply(convert.299, broadcast.1329) + add.617 = f32[2048,32,2048]{2,1,0} add(multiply.717, multiply.718) + negate.58 = f32[2048,32,2048]{2,1,0} negate(multiply.717) + bitcast.1189 = f32[32,2048,2048]{1,0,2} bitcast(negate.58) + reduce.97 = f32[32,2048]{1,0} reduce(bitcast.1189, constant.1968), dimensions={2}, to_apply=region_20.962 + negate.59 = f32[2048,32,2048]{2,1,0} negate(multiply.718) + bitcast.1203 = f32[32,2048,2048]{1,0,2} bitcast(negate.59) + reduce.98 = f32[32,2048]{1,0} reduce(bitcast.1203, constant.1968), dimensions={2}, to_apply=region_20.962 + add.613 = f32[32,2048]{1,0} add(reduce.97, reduce.98) + all-reduce.274 = f32[32,2048]{1,0} all-reduce(add.613), channel_id=335, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + multiply.719 = f32[32,2048]{1,0} multiply(all-reduce.274, broadcast.1243) + broadcast.1297 = f32[2048,32,2048]{2,1,0} broadcast(multiply.719), dimensions={1,2} + add.618 = f32[2048,32,2048]{2,1,0} add(add.617, broadcast.1297) + convert.301 = bf16[2048,32,2048]{2,1,0} convert(add.618) + add.619 = bf16[2048,32,2048]{2,1,0} add(bitcast.1652, convert.301) + add.616 = bf16[2048,32,2048]{2,1,0} add(convert.287, add.619) + bitcast.2063 = bf16[32,2048,2048]{1,0,2} bitcast(add.619) + all-gather.15 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2063), channel_id=76, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.1263 = bf16[65536,4096]{0,1} bitcast(all-gather.15) + bitcast.1269 = bf16[4096,2048]{1,0} bitcast(all-gather.9) + dot.48 = bf16[65536,2048]{1,0} dot(bitcast.1263, bitcast.1269), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.1381 = bf16[32,2048,16,128]{3,2,1,0} bitcast(dot.48) + transpose.122 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,1,3} + dot.49 = bf16[32,16,2048,2048]{3,2,1,0} dot(transpose.122, bitcast.1280), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + convert.303 = f32[32,16,2048,2048]{3,2,1,0} convert(dot.49) + broadcast.1298 = f32[32,16,2048]{2,1,0} broadcast(constant.2005), dimensions={} + multiply.720 = f32[32,16,2048]{2,1,0} multiply(reduce.93, reduce.93) + divide.106 = f32[32,16,2048]{2,1,0} divide(broadcast.1298, multiply.720) + broadcast.1299 = f32[32,16,2048,2048]{3,2,1,0} broadcast(divide.106), dimensions={0,1,2} + multiply.721 = f32[32,16,2048,2048]{3,2,1,0} multiply(convert.303, broadcast.1299) + multiply.722 = f32[32,16,2048,2048]{3,2,1,0} multiply(multiply.721, exponential.8) + reduce.99 = f32[32,16,2048]{2,1,0} reduce(multiply.722, constant.1968), dimensions={3}, to_apply=region_20.962 + negate.61 = f32[32,16,2048]{2,1,0} negate(reduce.99) + broadcast.1305 = f32[32,16,2048,2048]{3,2,1,0} broadcast(negate.61), dimensions={0,1,2} + divide.108 = f32[32,16,2048,2048]{3,2,1,0} divide(convert.303, broadcast.1309) + add.622 = f32[32,16,2048,2048]{3,2,1,0} add(broadcast.1305, divide.108) + multiply.724 = f32[32,16,2048,2048]{3,2,1,0} multiply(add.622, exponential.8) + convert.305 = bf16[32,16,2048,2048]{3,2,1,0} convert(multiply.724) + dot.50 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.113), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + bitcast.1934 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.50) + pad.6 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1934, constant.2006), padding=1_1x0_0x0_0x0_0x0_0 + transpose.120 = bf16[32,16,2048,128]{3,2,1,0} transpose(bitcast.647), dimensions={0,1,3,2} + dot.51 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.305, transpose.120), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broadcast.1307 = bf16[32,16,2048,128]{3,2,1,0} broadcast(constant.2007), dimensions={} + multiply.725 = bf16[32,16,2048,128]{3,2,1,0} multiply(dot.51, broadcast.1307) + bitcast.1941 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(multiply.725) + pad.7 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1941, constant.2006), padding=0_2x0_0x0_0x0_0x0_0 + add.638 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(pad.6, pad.7) + transpose.123 = bf16[32,16,128,2048]{3,2,1,0} transpose(bitcast.1381), dimensions={0,2,3,1} + dot.89 = bf16[32,16,2048,128]{3,2,1,0} dot(convert.306, transpose.123), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + bitcast.1949 = bf16[1,32,16,2048,128]{4,3,2,1,0} bitcast(dot.89) + pad.8 = bf16[3,32,16,2048,128]{4,3,2,1,0} pad(bitcast.1949, constant.2006), padding=2_0x0_0x0_0x0_0x0_0 + add.639 = bf16[3,32,16,2048,128]{4,3,2,1,0} add(add.638, pad.8) + transpose.127 = bf16[32,2048,3,16,128]{4,3,2,1,0} transpose(add.639), dimensions={1,3,0,2,4} + bitcast.1416 = bf16[65536,6144]{1,0} bitcast(transpose.127) + dot.52 = bf16[65536,4096]{0,1} dot(bitcast.1416, bitcast.1420), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.1424 = bf16[32,2048,4096]{1,0,2} bitcast(dot.52) + reduce-scatter.2 = bf16[32,2048,2048]{1,0,2} reduce-scatter(bitcast.1424), channel_id=324, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, dimensions={2}, to_apply=add + bitcast.1851 = bf16[2048,32,2048]{2,1,0} bitcast(reduce-scatter.2) + multiply.732 = bf16[2048,32,2048]{2,1,0} multiply(bitcast.1851, broadcast.1338) + convert.308 = f32[2048,32,2048]{2,1,0} convert(multiply.732) + multiply.727 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, convert.308) + bitcast.1434 = f32[32,2048,2048]{1,0,2} bitcast(multiply.727) + reduce.100 = f32[32,2048]{1,0} reduce(bitcast.1434, constant.1968), dimensions={2}, to_apply=region_20.962 + all-reduce.33 = f32[32,2048]{1,0} all-reduce(reduce.100), channel_id=78, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + bitcast.1439 = f32[1,32,2048]{2,1,0} bitcast(all-reduce.33) + divide.110 = f32[1,32,2048]{2,1,0} divide(rsqrt.20, bitcast.1447) + multiply.728 = f32[1,32,2048]{2,1,0} multiply(divide.110, broadcast.1313) + multiply.729 = f32[1,32,2048]{2,1,0} multiply(bitcast.1439, multiply.728) + multiply.730 = f32[1,32,2048]{2,1,0} multiply(multiply.729, broadcast.1315) + bitcast.1485 = f32[32,2048]{1,0} bitcast(multiply.730) + broadcast.1321 = f32[2048,32,2048]{2,1,0} broadcast(bitcast.1485), dimensions={1,2} + multiply.734 = f32[2048,32,2048]{2,1,0} multiply(subtract.224, broadcast.1321) + multiply.735 = f32[2048,32,2048]{2,1,0} multiply(convert.308, broadcast.1337) + add.625 = f32[2048,32,2048]{2,1,0} add(multiply.734, multiply.735) + negate.62 = f32[2048,32,2048]{2,1,0} negate(multiply.734) + bitcast.1491 = f32[32,2048,2048]{1,0,2} bitcast(negate.62) + reduce.101 = f32[32,2048]{1,0} reduce(bitcast.1491, constant.1968), dimensions={2}, to_apply=region_20.962 + negate.63 = f32[2048,32,2048]{2,1,0} negate(multiply.735) + bitcast.1505 = f32[32,2048,2048]{1,0,2} bitcast(negate.63) + reduce.102 = f32[32,2048]{1,0} reduce(bitcast.1505, constant.1968), dimensions={2}, to_apply=region_20.962 + add.626 = f32[32,2048]{1,0} add(reduce.101, reduce.102) + all-reduce.275 = f32[32,2048]{1,0} all-reduce(add.626), channel_id=336, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=region_20.962 + multiply.736 = f32[32,2048]{1,0} multiply(all-reduce.275, broadcast.1243) + broadcast.1323 = f32[2048,32,2048]{2,1,0} broadcast(multiply.736), dimensions={1,2} + add.628 = f32[2048,32,2048]{2,1,0} add(add.625, broadcast.1323) + convert.309 = bf16[2048,32,2048]{2,1,0} convert(add.628) + add.629 = bf16[2048,32,2048]{2,1,0} add(add.616, convert.309) + bitcast.1525 = bf16[32,2048,2048]{1,0,2} bitcast(add.629) + get-tuple-element.47 = bf16[24,8192]{1,0} get-tuple-element(param.3), index=2 + reduce.103 = bf16[8192]{0} reduce(add.635, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.36 = bf16[8192]{0} all-reduce(reduce.103), channel_id=81, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1583 = bf16[1,8192]{1,0} bitcast(all-reduce.36) + dynamic-update-slice.28 = bf16[24,8192]{1,0} dynamic-update-slice(get-tuple-element.47, bitcast.1583, select.372, constant.1980) + get-tuple-element.48 = bf16[24,1024,8192]{2,1,0} get-tuple-element(param.3), index=3 + all-gather.16 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1620), channel_id=82, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.1625 = bf16[4096,65536]{1,0} bitcast(all-gather.16) + dot.53 = bf16[4096,8192]{1,0} dot(bitcast.1625, bitcast.1629), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reduce-scatter.3 = bf16[1024,8192]{1,0} reduce-scatter(dot.53), channel_id=325, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={0}, to_apply=add + bitcast.1634 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.3) + dynamic-update-slice.29 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(get-tuple-element.48, bitcast.1634, select.372, constant.1980, constant.1980) + get-tuple-element.49 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=4 + collective-permute.2 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.49), channel_id=85, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.17 = bf16[24,2048]{0,1} all-gather(collective-permute.2), channel_id=86, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1649 = bf16[2048,24]{1,0} bitcast(all-gather.17) + reduce.104 = bf16[2048]{0} reduce(bitcast.1739, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.38 = bf16[2048]{0} all-reduce(reduce.104), channel_id=84, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1671 = bf16[2048,1]{1,0} bitcast(all-reduce.38) + dynamic-update-slice.30 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1649, bitcast.1671, constant.1980, select.372) + constant.2013 = s32[8]{0} constant({0, 2048, 0, 2048, 1024, 3072, 1024, 3072}) + partition-id.3 = u32[] partition-id() + dynamic-slice.336 = s32[1]{0} dynamic-slice(constant.2013, partition-id.3), dynamic_slice_sizes={1} + constant.2014 = s32[8]{0} constant({0, 2048, 0, 2048, 0, 2048, 0, 2048}) + dynamic-slice.337 = s32[1]{0} dynamic-slice(constant.2014, partition-id.3), dynamic_slice_sizes={1} + subtract.232 = s32[1]{0} subtract(dynamic-slice.336, dynamic-slice.337) + bitcast.2087 = s32[] bitcast(subtract.232) + dynamic-slice.338 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.30, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24} + bitcast.1695 = bf16[24,1024]{0,1} bitcast(dynamic-slice.338) + collective-permute.9 = bf16[24,1024]{0,1} collective-permute(bitcast.1695), channel_id=109, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + get-tuple-element.50 = bf16[24,8192,1024]{1,2,0} get-tuple-element(param.3), index=5 + bitcast.1698 = bf16[24,1024,8192]{2,1,0} bitcast(get-tuple-element.50) + multiply.748 = bf16[32,2048,8192]{2,1,0} multiply(bitcast.997, multiply.747) + multiply.749 = bf16[32,2048,8192]{2,1,0} multiply(multiply.748, broadcast.1335) + bitcast.1735 = bf16[8192,65536]{0,1} bitcast(multiply.749) + all-gather.18 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.1739), channel_id=87, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.1743 = bf16[65536,4096]{0,1} bitcast(all-gather.18) + dot.54 = bf16[8192,4096]{0,1} dot(bitcast.1735, bitcast.1743), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reduce-scatter.4 = bf16[8192,1024]{0,1} reduce-scatter(dot.54), channel_id=326, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={1}, to_apply=add + bitcast.1748 = bf16[1,1024,8192]{2,1,0} bitcast(reduce-scatter.4) + dynamic-update-slice.31 = bf16[24,1024,8192]{2,1,0} dynamic-update-slice(bitcast.1698, bitcast.1748, select.372, constant.1980, constant.1980) + bitcast.1758 = bf16[24,8192,1024]{1,2,0} bitcast(dynamic-update-slice.31) + get-tuple-element.51 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=6 + collective-permute.3 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.51), channel_id=90, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.19 = bf16[24,2048]{0,1} all-gather(collective-permute.3), channel_id=91, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1763 = bf16[2048,24]{1,0} bitcast(all-gather.19) + reduce.105 = bf16[2048]{0} reduce(reduce-scatter.1, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.40 = bf16[2048]{0} all-reduce(reduce.105), channel_id=89, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1779 = bf16[2048,1]{1,0} bitcast(all-reduce.40) + dynamic-update-slice.32 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1763, bitcast.1779, constant.1980, select.372) + dynamic-slice.339 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.32, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24} + bitcast.1794 = bf16[24,1024]{0,1} bitcast(dynamic-slice.339) + collective-permute.10 = bf16[24,1024]{0,1} collective-permute(bitcast.1794), channel_id=110, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + get-tuple-element.52 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=7 + collective-permute.4 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.52), channel_id=93, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.20 = bf16[24,2048]{0,1} all-gather(collective-permute.4), channel_id=94, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1801 = bf16[2048,24]{1,0} bitcast(all-gather.20) + multiply.751 = bf16[2048,32,2048]{2,1,0} multiply(convert.312, bitcast.1766) + bitcast.1817 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.751) + reduce.106 = bf16[2048]{0} reduce(bitcast.1817, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.41 = bf16[2048]{0} all-reduce(reduce.106), channel_id=92, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1826 = bf16[2048,1]{1,0} bitcast(all-reduce.41) + dynamic-update-slice.33 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1801, bitcast.1826, constant.1980, select.372) + dynamic-slice.340 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.33, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24} + bitcast.1841 = bf16[24,1024]{0,1} bitcast(dynamic-slice.340) + collective-permute.11 = bf16[24,1024]{0,1} collective-permute(bitcast.1841), channel_id=111, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + get-tuple-element.53 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=8 + collective-permute.5 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.53), channel_id=96, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.21 = bf16[24,2048]{0,1} all-gather(collective-permute.5), channel_id=97, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1848 = bf16[2048,24]{1,0} bitcast(all-gather.21) + reduce.107 = bf16[2048]{0} reduce(reduce-scatter.2, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.42 = bf16[2048]{0} all-reduce(reduce.107), channel_id=95, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1864 = bf16[2048,1]{1,0} bitcast(all-reduce.42) + dynamic-update-slice.34 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1848, bitcast.1864, constant.1980, select.372) + dynamic-slice.341 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.34, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24} + bitcast.1879 = bf16[24,1024]{0,1} bitcast(dynamic-slice.341) + collective-permute.12 = bf16[24,1024]{0,1} collective-permute(bitcast.1879), channel_id=112, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + get-tuple-element.54 = bf16[24,1024]{0,1} get-tuple-element(param.3), index=9 + collective-permute.6 = bf16[24,1024]{0,1} collective-permute(get-tuple-element.54), channel_id=99, source_target_pairs={{0,0},{1,2},{2,4},{3,6},{4,1},{5,3},{6,5},{7,7}} + all-gather.22 = bf16[24,2048]{0,1} all-gather(collective-permute.6), channel_id=100, replica_groups={{0,4},{2,6},{1,5},{3,7}}, dimensions={1}, use_global_device_ids=true + bitcast.1886 = bf16[2048,24]{1,0} bitcast(all-gather.22) + multiply.753 = bf16[2048,32,2048]{2,1,0} multiply(convert.314, bitcast.1851) + bitcast.1905 = bf16[32,2048,2048]{1,0,2} bitcast(multiply.753) + reduce.108 = bf16[2048]{0} reduce(bitcast.1905, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.43 = bf16[2048]{0} all-reduce(reduce.108), channel_id=98, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1914 = bf16[2048,1]{1,0} bitcast(all-reduce.43) + dynamic-update-slice.35 = bf16[2048,24]{1,0} dynamic-update-slice(bitcast.1886, bitcast.1914, constant.1980, select.372) + dynamic-slice.342 = bf16[1024,24]{1,0} dynamic-slice(dynamic-update-slice.35, bitcast.2087, constant.1980), dynamic_slice_sizes={1024,24} + bitcast.1929 = bf16[24,1024]{0,1} bitcast(dynamic-slice.342) + collective-permute.13 = bf16[24,1024]{0,1} collective-permute(bitcast.1929), channel_id=113, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + get-tuple-element.55 = bf16[24,3,16,128]{3,2,1,0} get-tuple-element(param.3), index=10 + bitcast.1979 = bf16[3,32,2048,16,128]{4,2,3,1,0} bitcast(add.639) + reduce.109 = bf16[3,16,128]{2,1,0} reduce(bitcast.1979, constant.2006), dimensions={1,2}, to_apply=add + all-reduce.44 = bf16[3,16,128]{2,1,0} all-reduce(reduce.109), channel_id=101, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + bitcast.1963 = bf16[1,3,16,128]{3,2,1,0} bitcast(all-reduce.44) + dynamic-update-slice.36 = bf16[24,3,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.55, bitcast.1963, select.372, constant.1980, constant.1980, /*index=5*/constant.1980) + get-tuple-element.56 = bf16[24,3,1024,16,128]{4,3,1,2,0} get-tuple-element(param.3), index=11 + bitcast.1974 = bf16[24,1024,3,16,128]{4,3,2,1,0} bitcast(get-tuple-element.56) + transpose.130 = bf16[3,16,128,32,2048]{4,3,2,1,0} transpose(add.639), dimensions={0,2,4,1,3} + bitcast.1983 = bf16[6144,65536]{1,0} bitcast(transpose.130) + all-gather.23 = bf16[32,2048,4096]{1,0,2} all-gather(bitcast.2003), channel_id=102, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.2007 = bf16[65536,4096]{0,1} bitcast(all-gather.23) + dot.55 = bf16[6144,4096]{0,1} dot(bitcast.1983, bitcast.2007), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.2011 = bf16[3,16,128,4096]{2,1,0,3} bitcast(dot.55) + reduce-scatter.5 = bf16[3,16,128,1024]{2,1,0,3} reduce-scatter(bitcast.2011), channel_id=327, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, dimensions={3}, to_apply=add + bitcast.2015 = bf16[1,1024,3,16,128]{4,3,2,1,0} bitcast(reduce-scatter.5) + dynamic-update-slice.37 = bf16[24,1024,3,16,128]{4,3,2,1,0} dynamic-update-slice(bitcast.1974, bitcast.2015, select.372, constant.1980, constant.1980, /*index=5*/constant.1980, constant.1980) + bitcast.2025 = bf16[24,3,1024,16,128]{4,3,1,2,0} bitcast(dynamic-update-slice.37) + get-tuple-element.57 = bf16[24,1024]{1,0} get-tuple-element(param.3), index=12 + reduce.110 = bf16[2048]{0} reduce(bitcast.2063, constant.2006), dimensions={0,1}, to_apply=add + all-reduce.46 = bf16[2048]{0} all-reduce(reduce.110), channel_id=104, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + dynamic-slice.343 = bf16[1024]{0} dynamic-slice(all-reduce.46, bitcast.2087), dynamic_slice_sizes={1024} + bitcast.2046 = bf16[1,1024]{1,0} bitcast(dynamic-slice.343) + collective-permute.7 = bf16[1,1024]{1,0} collective-permute(bitcast.2046), channel_id=105, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + dynamic-update-slice.38 = bf16[24,1024]{1,0} dynamic-update-slice(get-tuple-element.57, collective-permute.7, select.372, constant.1980) + get-tuple-element.58 = bf16[24,1024,16,128]{3,2,1,0} get-tuple-element(param.3), index=13 + bitcast.2066 = bf16[2048,65536]{1,0} bitcast(add.619) + transpose.133 = bf16[16,32,2048,128]{3,2,1,0} transpose(dot.44), dimensions={1,0,3,2} + bitcast.2072 = bf16[32,2048,16,128]{3,1,0,2} bitcast(transpose.133) + all-gather.24 = bf16[32,2048,32,128]{3,1,0,2} all-gather(bitcast.2072), channel_id=106, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true + bitcast.2073 = bf16[32,32,2048,128]{3,2,1,0} bitcast(all-gather.24) + transpose.134 = bf16[32,2048,32,128]{3,2,1,0} transpose(bitcast.2073), dimensions={1,2,0,3} + bitcast.2077 = bf16[65536,4096]{1,0} bitcast(transpose.134) + dot.56 = bf16[2048,4096]{1,0} dot(bitcast.2066, bitcast.2077), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast.2081 = bf16[2048,32,128]{2,1,0} bitcast(dot.56) + all-reduce.47 = bf16[2048,32,128]{2,1,0} all-reduce(bitcast.2081), channel_id=107, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add + constant.2015 = s32[8]{0} constant({0, 0, 16, 16, 0, 0, 16, 16}) + dynamic-slice.344 = s32[1]{0} dynamic-slice(constant.2015, partition-id.3), dynamic_slice_sizes={1} + bitcast.2095 = s32[] bitcast(dynamic-slice.344) + dynamic-slice.345 = bf16[1024,16,128]{2,1,0} dynamic-slice(all-reduce.47, bitcast.2087, bitcast.2095, constant.1980), dynamic_slice_sizes={1024,16,128} + bitcast.2102 = bf16[1,1024,16,128]{3,2,1,0} bitcast(dynamic-slice.345) + collective-permute.8 = bf16[1,1024,16,128]{3,2,1,0} collective-permute(bitcast.2102), channel_id=108, source_target_pairs={{0,0},{2,1},{4,2},{6,3},{1,4},{3,5},{5,6},{7,7}} + dynamic-update-slice.39 = bf16[24,1024,16,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.58, collective-permute.8, select.372, constant.1980, constant.1980, /*index=5*/constant.1980) + ROOT tuple.2 = (s32[], bf16[32,2048,2048]{1,0,2}, bf16[24,8192]{1,0}, bf16[24,1024,8192]{2,1,0}, bf16[24,1024]{0,1}, /*index=5*/bf16[24,8192,1024]{1,2,0}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, bf16[24,1024]{0,1}, /*index=10*/bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, bf16[24,8192]{1,0}, /*index=15*/bf16[24,1024,8192]{2,1,0}, bf16[24,8192,1024]{1,2,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, bf16[24,2048]{1,0}, /*index=20*/bf16[24,2048]{1,0}, bf16[24,3,16,128]{3,2,1,0}, bf16[24,3,1024,16,128]{4,3,1,2,0}, bf16[24,1024]{1,0}, bf16[24,1024,16,128]{3,2,1,0}, /*index=25*/bf16[24,32,2048,2048]{2,1,3,0}, bf16[32,1,2048,2048]{3,2,0,1}, bf16[32,2048]{1,0}) tuple(add.581, bitcast.1525, dynamic-update-slice.28, dynamic-update-slice.29, collective-permute.9, /*index=5*/bitcast.1758, collective-permute.10, collective-permute.11, collective-permute.12, collective-permute.13, /*index=10*/dynamic-update-slice.36, bitcast.2025, dynamic-update-slice.38, dynamic-update-slice.39, get-tuple-element.45, /*index=15*/get-tuple-element.44, get-tuple-element.46, get-tuple-element.43, get-tuple-element.42, get-tuple-element.37, /*index=20*/get-tuple-element.36, get-tuple-element.38, get-tuple-element.35, get-tuple-element.41, get-tuple-element.40, /*index=25*/get-tuple-element.32, get-tuple-element.39, get-tuple-element.33) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + HloInstruction* fwd_instruction = nullptr; + HloInstruction* bwd_instruction = nullptr; + SCOPED_TRACE(m->ToString()); + for (HloInstruction* instr : + m->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kCustomCall && + instr->custom_call_target() == kCudnnfMHASoftmaxCallTarget) { + fwd_instruction = instr; + } + if (instr->opcode() == HloOpcode::kCustomCall && + instr->custom_call_target() == kCudnnfMHASoftmaxBackwardCallTarget) { + bwd_instruction = instr; + } + } + EXPECT_NE(fwd_instruction, nullptr); + EXPECT_NE(bwd_instruction, nullptr); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fwd_instruction->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(config.is_flash_attention(), true); + EXPECT_EQ(config.is_causal_mask(), true); +} + +TEST_F(CudnnFusedMhaRewriterTestHloTest, + BF16TrainingBmm2CanonicalizationRestoreFwdGraph) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + const char* module_str = R"( +HloModule pjit__unnamed_function_, entry_computation_layout={(bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,4,256,256]{3,2,1,0})->(bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false,false}, num_partitions=4 + +region_0.6 { + Arg_0.7 = bf16[] parameter(0) + Arg_1.8 = bf16[] parameter(1) + ROOT maximum.5 = bf16[] maximum(Arg_0.7, Arg_1.8) +} + +region_1.10 { + Arg_0.11 = f32[] parameter(0) + Arg_1.12 = f32[] parameter(1) + ROOT add.14 = f32[] add(Arg_0.11, Arg_1.12) +} + +add.clone { + x.1 = u32[] parameter(0) + y.1 = u32[] parameter(1) + ROOT add.15 = u32[] add(x.1, y.1) +} + +region_2.65 { + Arg_0.66 = bf16[] parameter(0) + Arg_1.67 = bf16[] parameter(1) + ROOT add.16 = bf16[] add(Arg_0.66, Arg_1.67) +} + +ENTRY main.164_spmd { + param = bf16[2,256,4,64]{3,2,1,0} parameter(2), sharding={devices=[2,1,2,1]<=[4]} + transpose.26 = bf16[2,4,64,256]{3,2,1,0} transpose(param), dimensions={0,2,3,1} + param.1 = bf16[2,256,4,64]{3,2,1,0} parameter(0), sharding={devices=[2,1,2,1]<=[4]} + transpose.27 = bf16[2,4,256,64]{3,2,1,0} transpose(param.1), dimensions={0,2,1,3} + constant.46 = bf16[] constant(0.5) + broadcast.126 = bf16[2,4,256,64]{3,2,1,0} broadcast(constant.46), dimensions={} + multiply.34 = bf16[2,4,256,64]{3,2,1,0} multiply(transpose.27, broadcast.126) + param.2 = bf16[2,256,4,64]{3,2,1,0} parameter(1), sharding={devices=[2,1,2,1]<=[4]} + transpose.29 = bf16[2,4,64,256]{3,2,1,0} transpose(param.2), dimensions={0,2,3,1} + dot.12 = bf16[2,4,256,256]{3,2,1,0} dot(multiply.34, transpose.29), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + param.3 = bf16[2,4,256,256]{3,2,1,0} parameter(4), sharding={devices=[2,2,1,1]<=[4]} + add.17 = bf16[2,4,256,256]{3,2,1,0} add(dot.12, param.3) + constant.47 = bf16[] constant(-inf) + reduce.4 = bf16[2,4,256]{2,1,0} reduce(add.17, constant.47), dimensions={3}, to_apply=region_0.6 + broadcast.127 = bf16[2,4,256,256]{3,2,1,0} broadcast(reduce.4), dimensions={0,1,2} + subtract.14 = bf16[2,4,256,256]{3,2,1,0} subtract(add.17, broadcast.127) + exponential.2 = bf16[2,4,256,256]{3,2,1,0} exponential(subtract.14) + convert.46 = f32[2,4,256,256]{3,2,1,0} convert(exponential.2) + constant.48 = f32[] constant(0) + reduce.5 = f32[2,4,256]{2,1,0} reduce(convert.46, constant.48), dimensions={3}, to_apply=region_1.10 + convert.47 = bf16[2,4,256]{2,1,0} convert(reduce.5) + broadcast.128 = bf16[2,4,256,256]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2} + divide.7 = bf16[2,4,256,256]{3,2,1,0} divide(exponential.2, broadcast.128) + broadcast.129 = f32[4096]{0} broadcast(constant.48), dimensions={} + constant.50 = u32[] constant(0) + broadcast.131 = u32[8192]{0} broadcast(constant.50), dimensions={} + broadcast.133 = u32[4096]{0} broadcast(constant.50), dimensions={} + iota.3 = u32[8192]{0} iota(), iota_dimension=0 + slice.14 = u32[4096]{0} slice(iota.3), slice={[0:4096]} + slice.15 = u32[4096]{0} slice(iota.3), slice={[4096:8192]} + custom-call.3 = (u32[4096]{0}, u32[4096]{0}) custom-call(broadcast.133, broadcast.133, slice.14, slice.15), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[4096]{0}, u32[4096]{0}, u32[4096]{0}, u32[4096]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\020\000\000\000\000\000\000" + get-tuple-element.6 = u32[4096]{0} get-tuple-element(custom-call.3), index=0 + constant.115 = u32[1]{0} constant({0}) + constant.52 = u32[4]{0} constant({0, 0, 1, 1}) + partition-id = u32[] partition-id() + dynamic-slice.21 = u32[1]{0} dynamic-slice(constant.52, partition-id), dynamic_slice_sizes={1} + constant.116 = u32[1]{0} constant({1}) + clamp.3 = u32[1]{0} clamp(constant.115, dynamic-slice.21, constant.116) + convert.48 = s32[1]{0} convert(clamp.3) + constant.117 = s32[1]{0} constant({2048}) + multiply.35 = s32[1]{0} multiply(convert.48, constant.117) + bitcast.105 = s32[] bitcast(multiply.35) + dynamic-slice.22 = u32[2048]{0} dynamic-slice(get-tuple-element.6, bitcast.105), dynamic_slice_sizes={2048} + constant.58 = s32[4]{0} constant({0, 0, 1, 1}) + dynamic-slice.23 = s32[1]{0} dynamic-slice(constant.58, partition-id), dynamic_slice_sizes={1} + multiply.36 = s32[1]{0} multiply(dynamic-slice.23, constant.117) + bitcast.108 = s32[] bitcast(multiply.36) + dynamic-update-slice.2 = u32[8192]{0} dynamic-update-slice(broadcast.131, dynamic-slice.22, bitcast.108) + get-tuple-element.7 = u32[4096]{0} get-tuple-element(custom-call.3), index=1 + dynamic-slice.24 = u32[2048]{0} dynamic-slice(get-tuple-element.7, bitcast.105), dynamic_slice_sizes={2048} + constant.65 = s32[] constant(4096) + add.18 = s32[] add(bitcast.108, constant.65) + dynamic-update-slice.3 = u32[8192]{0} dynamic-update-slice(dynamic-update-slice.2, dynamic-slice.24, add.18) + all-reduce = u32[8192]{0} all-reduce(dynamic-update-slice.3), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=add.clone + constant.118 = s32[1]{0} constant({4096}) + multiply.37 = s32[1]{0} multiply(dynamic-slice.23, constant.118) + bitcast.119 = s32[] bitcast(multiply.37) + dynamic-slice.25 = u32[4096]{0} dynamic-slice(all-reduce, bitcast.119), dynamic_slice_sizes={4096} + constant.69 = u32[] constant(9) + broadcast.134 = u32[4096]{0} broadcast(constant.69), dimensions={} + shift-right-logical.6 = u32[4096]{0} shift-right-logical(dynamic-slice.25, broadcast.134) + constant.70 = u32[] constant(1065353216) + broadcast.135 = u32[4096]{0} broadcast(constant.70), dimensions={} + or.5 = u32[4096]{0} or(shift-right-logical.6, broadcast.135) + bitcast-convert.5 = f32[4096]{0} bitcast-convert(or.5) + constant.71 = f32[] constant(-1) + broadcast.136 = f32[4096]{0} broadcast(constant.71), dimensions={} + add.19 = f32[4096]{0} add(bitcast-convert.5, broadcast.136) + maximum.6 = f32[4096]{0} maximum(broadcast.129, add.19) + constant.72 = f32[] constant(0.5) + broadcast.137 = f32[4096]{0} broadcast(constant.72), dimensions={} + compare.4 = pred[4096]{0} compare(maximum.6, broadcast.137), direction=LT + bitcast.135 = pred[2,8,256]{2,1,0} bitcast(compare.4) + convert.49 = bf16[2,8,256]{2,1,0} convert(bitcast.135) + constant.80 = s32[] constant(0) + constant.78 = s32[4]{0} constant({0, 4, 0, 4}) + dynamic-slice.26 = s32[1]{0} dynamic-slice(constant.78, partition-id), dynamic_slice_sizes={1} + bitcast.181 = s32[] bitcast(dynamic-slice.26) + dynamic-slice.27 = bf16[2,4,256]{2,1,0} dynamic-slice(convert.49, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256} + broadcast.139 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.27), dimensions={0,1,3} + multiply.38 = bf16[2,4,256,256]{3,2,1,0} multiply(divide.7, broadcast.139) + constant.93 = bf16[] constant(2) + broadcast.141 = bf16[2,4,256,256]{3,2,1,0} broadcast(constant.93), dimensions={} + multiply.39 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.38, broadcast.141) + dot.13 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.26, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + transpose.31 = bf16[4,2,64,256]{3,2,1,0} transpose(dot.13), dimensions={1,0,2,3} + bitcast.154 = bf16[2,256,4,64]{1,3,0,2} bitcast(transpose.31) + all-gather = bf16[2,256,8,64]{1,3,0,2} all-gather(bitcast.154), channel_id=2, replica_groups={{0,1},{2,3}}, dimensions={2}, use_global_device_ids=true + bitcast.155 = bf16[8,2,64,256]{3,2,1,0} bitcast(all-gather) + transpose.32 = bf16[2,8,64,256]{3,2,1,0} transpose(bitcast.155), dimensions={1,0,2,3} + bitcast.157 = bf16[2,256,8,64]{1,3,2,0} bitcast(transpose.32) + all-gather.1 = bf16[4,256,8,64]{1,3,2,0} all-gather(bitcast.157), channel_id=3, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true + bitcast.236 = bf16[4,8,64,256]{3,2,1,0} bitcast(all-gather.1) + transpose.38 = bf16[4,256,8,64]{3,2,1,0} transpose(bitcast.236), dimensions={0,3,1,2} + param.4 = bf16[2,256,4,64]{3,2,1,0} parameter(3), sharding={devices=[2,1,2,1]<=[4]} + transpose.33 = bf16[2,4,256,64]{3,2,1,0} transpose(param.4), dimensions={0,2,1,3} + dot.14 = bf16[2,4,256,256]{3,2,1,0} dot(transpose.33, transpose.26), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broadcast.142 = bf16[4096]{0} broadcast(constant.93), dimensions={} + constant.95 = bf16[] constant(0) + broadcast.143 = bf16[4096]{0} broadcast(constant.95), dimensions={} + select.4 = bf16[4096]{0} select(compare.4, broadcast.142, broadcast.143) + bitcast.176 = bf16[2,8,256]{2,1,0} bitcast(select.4) + dynamic-slice.28 = bf16[2,4,256]{2,1,0} dynamic-slice(bitcast.176, constant.80, bitcast.181, constant.80), dynamic_slice_sizes={2,4,256} + broadcast.145 = bf16[2,4,256,256]{3,2,1,0} broadcast(dynamic-slice.28), dimensions={0,1,3} + multiply.40 = bf16[2,4,256,256]{3,2,1,0} multiply(dot.14, broadcast.145) + divide.8 = bf16[2,4,256,256]{3,2,1,0} divide(multiply.40, broadcast.128) + constant.106 = bf16[] constant(1) + broadcast.146 = bf16[2,4,256]{2,1,0} broadcast(constant.106), dimensions={} + multiply.41 = bf16[2,4,256]{2,1,0} multiply(convert.47, convert.47) + divide.9 = bf16[2,4,256]{2,1,0} divide(broadcast.146, multiply.41) + broadcast.147 = bf16[2,4,256,256]{3,2,1,0} broadcast(divide.9), dimensions={0,1,2} + multiply.42 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.40, broadcast.147) + multiply.43 = bf16[2,4,256,256]{3,2,1,0} multiply(multiply.42, exponential.2) + reduce.6 = bf16[2,4,256]{2,1,0} reduce(multiply.43, constant.95), dimensions={3}, to_apply=region_2.65 + negate.4 = bf16[2,4,256]{2,1,0} negate(reduce.6) + broadcast.148 = bf16[2,4,256,256]{3,2,1,0} broadcast(negate.4), dimensions={0,1,2} + add.20 = bf16[2,4,256,256]{3,2,1,0} add(divide.8, broadcast.148) + multiply.44 = bf16[2,4,256,256]{3,2,1,0} multiply(add.20, exponential.2) + transpose.34 = bf16[2,4,256,64]{3,2,1,0} transpose(param.2), dimensions={0,2,1,3} + dot.15 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, transpose.34), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + multiply.45 = bf16[2,4,256,64]{3,2,1,0} multiply(dot.15, broadcast.126) + transpose.39 = bf16[2,256,4,64]{3,2,1,0} transpose(multiply.45), dimensions={0,2,1,3} + dot.16 = bf16[2,4,256,64]{3,2,1,0} dot(multiply.44, multiply.34), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.40 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.16), dimensions={0,2,1,3} + transpose.36 = bf16[2,4,64,256]{3,2,1,0} transpose(param.4), dimensions={0,2,3,1} + dot.11 = bf16[2,4,64,256]{3,2,1,0} dot(transpose.36, multiply.39), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.41 = bf16[2,256,4,64]{3,2,1,0} transpose(dot.11), dimensions={0,3,1,2} + ROOT tuple.2 = (bf16[4,256,8,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}, bf16[2,256,4,64]{3,2,1,0}) tuple(transpose.38, transpose.39, transpose.40, transpose.41) +} // main.164_spmd +)"; + // Dropout bwd pattern not supported, should not lower fwd as well + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), + GetCudnnVersionWithDbiasAndMaskBwdInputSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + SCOPED_TRACE(m->ToString()); + // check if fwd graph has been restored with cloned activation + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::Transpose(), m::Transpose(), m::Transpose(), + m::Transpose(m::Dot( + m::Op(), m::Op().WithPredicate([](const HloInstruction* instr) { + return instr->name() == "multiply.39.fmha_no_match_clone"; + })))))); +} + +constexpr absl::string_view hlo_should_lower_to_flash_attention = R"( +HloModule fmha_test, entry_computation_layout={(bf16[16,16,128,64]{3,2,1,0},bf16[16,16,1024,64]{3,2,1,0},bf16[16,16,1024,64]{3,2,1,0})->bf16[16,16,128,64]{3,2,1,0}} +ENTRY main.6 { + Arg_0.1 = bf16[16,16,128,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[16,16,1024,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[16,16,1024,64]{3,2,1,0} parameter(2) + dot.0 = bf16[16,16,128,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, metadata={} + ROOT dot.1 = bf16[16,16,128,64]{3,2,1,0} dot(dot.0, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, metadata={} +})"; + +TEST_F(CudnnFusedMhaRewriterTestHloTest, ShouldLowerToFlashAttention) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, ParseAndReturnVerifiedModule(hlo_should_lower_to_flash_attention, + GetModuleConfig())); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHABmmBmmCallTarget}), 0) + .WithShape(BF16, {16, 16, 128, 64}))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(config.fmha_scale(), 1.0); + EXPECT_EQ(config.dropout_rate(), 0.0); + EXPECT_EQ(config.is_flash_attention(), true); +} + +constexpr absl::string_view hlo_head_dim_not_multiple_of_64 = R"( +HloModule jit__reference, entry_computation_layout={(f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0})->f16[4,48,1024,16]{3,2,1,0}} + +region_0.26 { + Arg_0.27 = f32[] parameter(0) + Arg_1.28 = f32[] parameter(1) + ROOT maximum = f32[] maximum(Arg_0.27, Arg_1.28) +} + +region_1.37 { + Arg_0.38 = f32[] parameter(0) + Arg_1.39 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0.38, Arg_1.39) +} + +ENTRY main.49 { + iota.2 = s32[1024,1024]{1,0} iota(), iota_dimension=0 + iota.3 = s32[1024,1024]{1,0} iota(), iota_dimension=1 + compare = pred[1024,1024]{1,0} compare(iota.2, iota.3), direction=GE + broadcast.4 = pred[4,48,1024,1024]{3,2,1,0} broadcast(compare), dimensions={2,3} + Arg_0.1 = f16[4,48,1024,16]{3,2,1,0} parameter(0) + Arg_1.2 = f16[4,48,1024,16]{3,2,1,0} parameter(1) + dot.9 = f16[4,48,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + constant.4 = f16[] constant(0.5) + broadcast.6 = f16[4,48,1024,1024]{3,2,1,0} broadcast(constant.4), dimensions={} + multiply = f16[4,48,1024,1024]{3,2,1,0} multiply(dot.9, broadcast.6) + constant = f16[] constant(-inf) + broadcast.7 = f16[4,48,1024,1024]{3,2,1,0} broadcast(constant), dimensions={} + select.1 = f16[4,48,1024,1024]{3,2,1,0} select(broadcast.4, multiply, broadcast.7) + convert.1 = f32[4,48,1024,1024]{3,2,1,0} convert(select.1) + constant.7 = f32[] constant(-inf) + reduce.30 = f32[4,48,1024]{2,1,0} reduce(convert.1, constant.7), dimensions={3}, to_apply=region_0.26 + broadcast.8 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.30), dimensions={0,1,2} + subtract = f32[4,48,1024,1024]{3,2,1,0} subtract(convert.1, broadcast.8) + exponential = f32[4,48,1024,1024]{3,2,1,0} exponential(subtract) + constant.6 = f32[] constant(0) + reduce.41 = f32[4,48,1024]{2,1,0} reduce(exponential, constant.6), dimensions={3}, to_apply=region_1.37 + broadcast.9 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.41), dimensions={0,1,2} + divide = f32[4,48,1024,1024]{3,2,1,0} divide(exponential, broadcast.9) + convert.2 = f16[4,48,1024,1024]{3,2,1,0} convert(divide) + Arg_2.3 = f16[4,48,1024,16]{3,2,1,0} parameter(2) + ROOT dot.48 = f16[4,48,1024,16]{3,2,1,0} dot(convert.2, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} +} // main.49 +)"; + +TEST_F(CudnnFusedMhaRewriterTestHloTest, HeadDimNotMultipleOf64) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + TF_ASSERT_OK_AND_ASSIGN( + auto m, ParseAndReturnVerifiedModule(hlo_head_dim_not_multiple_of_64, + GetModuleConfig())); + CudnnFusedMHARewriter fusedMhaRewriter{ + GetCudaComputeCapability(), GetCudnnVersionWithFlashAttentionSupport()}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); + + // head dim not a multiple of 64 should not be lowered with cuDNN < 8.9.6 + SCOPED_TRACE(m->ToString()); + EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot())); + + // should be lowered with cuDNN >= 8.9.6 + CudnnFusedMHARewriter fusedMhaRewriterWithcuDNN8907{ + GetCudaComputeCapability(), se::dnn::VersionInfo(8, 9, 7)}; + TF_ASSERT_OK(RunHloPass(&fusedMhaRewriterWithcuDNN8907, m.get()).status()); + const HloInstruction* fmha; + + SCOPED_TRACE(m->ToString()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::GetTupleElement( + m::CustomCall(&fmha, {kCudnnfMHAScaleMaskSoftmaxCallTarget}), 0) + .WithShape(F16, {4, 48, 1024, 16}))); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); + EXPECT_EQ(config.fmha_scale(), 0.5); + EXPECT_EQ(config.dropout_rate(), 0.0); + EXPECT_EQ(config.is_flash_attention(), true); +} } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc index 6d30eac7bf3fd..4abd0940e0795 100644 --- a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc +++ b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" @@ -33,7 +34,6 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -55,14 +55,16 @@ bool IsBwdFMHACustomCall(const HloInstruction* instr) { return IsBwdCustomCallTofMHA(*instr); } -StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( +absl::StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( HloInstruction* fmha, int64_t operand_index, bool is_lhs, bool should_contracting_be_fastest) { HloInstruction* transpose_arg = fmha->mutable_operand(operand_index); HloInstruction* transpose_arg_operand = transpose_arg->mutable_operand(0); - CudnnfMHABackendConfig config; - TF_ASSIGN_OR_RETURN(config, fmha->backend_config()); - CudnnfMHABackendConfig new_fmha_config = config; + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + fmha->backend_config()); + CudnnfMHABackendConfig config = gpu_config.cudnn_fmha_backend_config(); + CudnnfMHABackendConfig& new_fmha_config = + *gpu_config.mutable_cudnn_fmha_backend_config(); std::vector inverse_perm = InversePermutation(transpose_arg->dimensions()); @@ -96,12 +98,15 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( new_bmm_dot_dims = config.bmm2_grad_gemm2_dot_dimension_numbers(); break; default: - return InternalError("Invalid operand index."); + return Internal("Invalid operand index."); } } absl::Span checked_dims; std::vector checked_dims_vec; + // `should_contracting_be_fastest` means if contracting dim is the head + // dim. cuDNN requires head dim to be the fastest dim. fwd bmm1 and bwd + // bmm2grad1 should set this value to true. if (should_contracting_be_fastest) { checked_dims = is_lhs ? new_bmm_dot_dims.lhs_contracting_dimensions() : new_bmm_dot_dims.rhs_contracting_dimensions(); @@ -125,21 +130,19 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( auto itr = std::find(inverse_perm.begin(), inverse_perm.end(), checked_dims[i]); if (itr == inverse_perm.end()) { - return InternalError("Invalid inverse perm"); + return Internal("Invalid inverse perm"); } new_bmm_checked_dims[i] = std::distance(inverse_perm.begin(), itr); } // We want to make sure that making the argument to transpose, an input to - // fmha, doesn't break cuDNN constraint that the checked dimensions of - // corresponding operand of BMM has the fastest moving dimension. + // fmha, doesn't break cuDNN constraint that the head dim of + // corresponding operand of BMM is the fastest moving dimension. // One exception is the forward activation which doesn't have the constraint - // that the fastest dim has to be 64. + // since it does not have head dim. absl::Span minor_to_major_bmm = transpose_arg_operand->shape().layout().minor_to_major(); if ((minor_to_major_bmm[0] != new_bmm_checked_dims[0]) && - ((transpose_arg_operand->shape().dimensions().at( - new_bmm_checked_dims[0]) == 64) || - (IsBwdCustomCallTofMHA(*fmha) && operand_index == 3))) { + !(IsBwdCustomCallTofMHA(*fmha) && operand_index == 3)) { return false; } if (should_contracting_be_fastest) { @@ -161,7 +164,7 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( auto itr = std::find(inverse_perm.begin(), inverse_perm.end(), batch_dims[i]); if (itr == inverse_perm.end()) { - return InternalError("Invalid inverse perm"); + return Internal("Invalid inverse perm"); } new_bmm_batch_dims[i] = std::distance(inverse_perm.begin(), itr); } @@ -198,8 +201,10 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( } if (IsFwdCustomCallTofMHA(*fmha)) { if (operand_index == 0 || operand_index == 1) { + // Q or K *new_fmha_config.mutable_bmm1_dot_dimension_numbers() = new_bmm_dot_dims; } else { + // V *new_fmha_config.mutable_bmm2_dot_dimension_numbers() = new_bmm_dot_dims; } } else { @@ -239,7 +244,7 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( transpose_permutation.end(), bmm2_grad_gemm1_contracting_dims[0]); if (itr == transpose_permutation.end()) { - return InternalError( + return Internal( "bmm2 gradident gemm1 contracting dimension not found."); } int64_t index = std::distance(transpose_permutation.begin(), itr); @@ -267,12 +272,11 @@ StatusOr FuseArgPrologueTransposeWithcuDNNFMHA( break; } default: - return InternalError("Invalid operand index."); + return Internal("Invalid operand index."); } } - fmha->clear_backend_config(); - TF_RETURN_IF_ERROR(fmha->set_backend_config(new_fmha_config)); + TF_RETURN_IF_ERROR(fmha->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(fmha->ReplaceOperandWithDifferentShape( operand_index, transpose_arg_operand)); @@ -308,7 +312,7 @@ the new lhs_contracting dim ,if A were to be the new lhs, would be 2. Similarly, we need to find corresponding batch dimensions as well. */ -StatusOr FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) { +absl::StatusOr FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) { bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction *transpose_arg0, *transpose_arg0_operand; @@ -455,9 +459,18 @@ StatusOr FusePrologueTransposeWithcuDNNFMHA(HloComputation* comp) { } // D_output tensor in backward graph is lhs with constraint on // contracting dim. - TF_ASSIGN_OR_RETURN(changed, FuseArgPrologueTransposeWithcuDNNFMHA( - fmha, 4, true /*is_lhs*/, - true /*should_contracting_be_fastest*/)); + // make sure we dont change layout of dO in flash attention case as dO + // should have the same layout of O + TF_ASSIGN_OR_RETURN(auto gpu_config, + fmha->backend_config()); + const CudnnfMHABackendConfig config = + gpu_config.cudnn_fmha_backend_config(); + if (!config.is_flash_attention()) { + TF_ASSIGN_OR_RETURN(changed, + FuseArgPrologueTransposeWithcuDNNFMHA( + fmha, 4, true /*is_lhs=*/, + true /*should_contracting_be_fastest=*/)); + } if (changed && VLOG_IS_ON(2)) { VLOG(2) << "After CudnnFusedMHATransposeFusion Arg 4: \n" @@ -489,16 +502,42 @@ Calling this function with 'result' shape as the input shape and the inverse perm as the permutation will generate an output shape whose dimensions match 'FMHA_out' dimensions but the physical layout is equivalent to 'result'. This is exactly what we want. + +FMHA output should have exactly one gte instruction for a tuple index +so we can safely fuse the transpose following that gte to FMHA + +FMHA_out = gte(FMHA, index=0) +FMHA_out_t = transpose(FMHA_out) +use(FMHA_out_t) + +after fusion: + +FMHA_out_t = gte(FMHA, index=0) +use(FMHA_out_t) */ -StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { +absl::StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { bool changed = false; + + auto only_one_gte_with_spec_index = [](const HloInstruction* instr, + int64_t index) { + int count = 0; + for (auto user : instr->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == index) { + count += 1; + } + } + return count == 1; + }; + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* fmha; HloInstruction* transpose; HloInstruction* gte; auto fwd_tuple_elem = - m::GetTupleElement(m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0) + m::GetTupleElement(>e, + m::Op(&fmha).WithPredicate(IsFwdFMHACustomCall), 0) .WithOneUser(); // Note that we don't match any specific tuple index in matcher for // backward. @@ -510,6 +549,10 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { auto bwd_pattern = m::Transpose(&transpose, bwd_tuple_elem); if (Match(instr, fwd_pattern)) { + // check if only one gte with such index exist + int64_t tuple_index = gte->tuple_index(); + if (!only_one_gte_with_spec_index(fmha, tuple_index)) continue; + std::vector inverse_perm = InversePermutation(transpose->dimensions()); @@ -538,8 +581,8 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { call_shape, fmha->operands(), absl::string_view(fmha->custom_call_target()))); - TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig config, - fmha->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig config, + fmha->backend_config()); TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config)); TF_RETURN_IF_ERROR( SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call)); @@ -558,9 +601,12 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { } changed |= true; } else if (Match(instr, bwd_pattern)) { + // check if only one gte with such index exist + int64_t operand_tuple_idx = gte->tuple_index(); + if (!only_one_gte_with_spec_index(fmha, operand_tuple_idx)) continue; + std::vector inverse_perm = InversePermutation(transpose->dimensions()); - int64_t operand_tuple_idx = gte->tuple_index(); auto expected_fmha_shape = ShapeUtil::PermuteDimensions(inverse_perm, transpose->shape()); @@ -587,8 +633,8 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { call_shape, fmha->operands(), absl::string_view(fmha->custom_call_target()))); - TF_ASSIGN_OR_RETURN(CudnnfMHABackendConfig config, - fmha->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig config, + fmha->backend_config()); TF_RETURN_IF_ERROR(new_fmha_custom_call->set_backend_config(config)); TF_RETURN_IF_ERROR( SetFMHAInstructionName(fmha->GetModule(), new_fmha_custom_call)); @@ -612,7 +658,7 @@ StatusOr FuseEpilogueTransposeWithcuDNNFMHA(HloComputation* comp) { } } // namespace -StatusOr CudnnFusedMHATransposeFusion::Run( +absl::StatusOr CudnnFusedMHATransposeFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool any_changed = false; diff --git a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h index 2cbdc041e82a6..94ec229d7709d 100644 --- a/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h +++ b/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ #define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -31,7 +34,7 @@ class CudnnFusedMHATransposeFusion : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/cudnn_fusion_compiler.cc b/xla/service/gpu/cudnn_fusion_compiler.cc new file mode 100644 index 0000000000000..2c3f4ec7221b3 --- /dev/null +++ b/xla/service/gpu/cudnn_fusion_compiler.cc @@ -0,0 +1,577 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/cudnn_fusion_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cudnn/cudnn_version.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +namespace fe = cudnn_frontend; +namespace graph = fe::graph; + +inline std::optional GetElementwiseMode( + const HloInstruction& instruction) { + const HloOpcode opcode = instruction.opcode(); + using m = fe::PointwiseMode_t; + switch (opcode) { + case HloOpcode::kAbs: + return m::ABS; + case HloOpcode::kAdd: + return m::ADD; + case HloOpcode::kCompare: + switch (instruction.comparison_direction()) { + case Comparison::Direction::kEq: + return m::CMP_EQ; + case Comparison::Direction::kNe: + return m::CMP_NEQ; + case Comparison::Direction::kGe: + return m::CMP_GE; + case Comparison::Direction::kGt: + return m::CMP_GT; + case Comparison::Direction::kLe: + return m::CMP_LE; + case Comparison::Direction::kLt: + return m::CMP_LT; + } + break; + case HloOpcode::kConvert: + return m::IDENTITY; + case HloOpcode::kCos: + return m::COS; + case HloOpcode::kDivide: + return m::DIV; + case HloOpcode::kExp: + return m::EXP; + case HloOpcode::kLog: + return m::LOG; + case HloOpcode::kMaximum: + return m::MAX; + case HloOpcode::kMinimum: + return m::MIN; + case HloOpcode::kMultiply: + return m::MUL; + case HloOpcode::kNegate: + return m::NEG; + case HloOpcode::kPower: + return m::POW; + case HloOpcode::kRsqrt: + return m::RSQRT; +#if CUDNN_VERSION >= 90100 + case HloOpcode::kSelect: + return m::BINARY_SELECT; +#endif // CUDNN_VERSION + case HloOpcode::kSin: + return m::SIN; + case HloOpcode::kSqrt: + return m::SQRT; + case HloOpcode::kSubtract: + return m::SUB; + case HloOpcode::kTan: + return m::TAN; + case HloOpcode::kTanh: + return m::TANH_FWD; + default: + return std::nullopt; + } +} + +inline std::optional ToCudnnDataType(const PrimitiveType type) { + using t = fe::DataType_t; + switch (type) { + case PrimitiveType::F32: + return t::FLOAT; + case PrimitiveType::F16: + return t::HALF; + case PrimitiveType::BF16: + return t::BFLOAT16; + case PrimitiveType::S32: + return t::INT32; + case PrimitiveType::S8: + return t::INT8; + case PrimitiveType::PRED: + return t::INT8; + default: + return std::nullopt; + } +} + +inline std::optional GetComputeDataType( + const PrimitiveType type) { + fe::DataType_t compute_dtype = fe::DataType_t::FLOAT; + if (primitive_util::IsIntegralType(type)) { +#if CUDNN_VERSION >= 90100 + compute_dtype = fe::DataType_t::INT32; +#else + VLOG(3) << "Integer math requires cuDNN 9.1+."; + return std::nullopt; +#endif // CUDNN_VERSION + } + return compute_dtype; +} + +int FusionLevel(const HloInstruction& hlo) { + return hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_cudnn_gemm_fusion_level(); +}; + +// Extracts dimensions and strides from HLO tensors in the format expected by +// cuDNN. +class GemmDimensionAdapter { + explicit GemmDimensionAdapter(const HloDotInstruction& dot, + TritonFusionAnalysis analysis) + : analysis_(std::move(analysis)), dot_(dot) {}; + + public: + const TritonFusionAnalysis analysis_; + + static absl::StatusOr> Create( + const HloComputation& computation) { + const HloInstruction* maybe_dot = + hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot); + if (maybe_dot == nullptr) { + VLOG(3) << "Not a GEMM fusion."; + return std::nullopt; + } + const HloDotInstruction* dot = DynCast( + hlo_query::GetFirstInstructionWithOpcode(computation, HloOpcode::kDot)); + if (absl::c_any_of(dot->precision_config().operand_precision(), + [](int x) { return x != PrecisionConfig::DEFAULT; })) { + VLOG(3) << "Non-default precision is not supported."; + return std::nullopt; + } + TF_ASSIGN_OR_RETURN(auto analysis, + TritonFusionAnalysis::Execute(computation)); + return GemmDimensionAdapter{*dot, std::move(analysis)}; + } + + bool DimensionsAndStrides(const HloInstruction& hlo, + const TritonFusionAnalysis::Scope scope, + std::vector& dimensions, + std::vector& strides) { + const DotDimensionNumbers& dims = dot_.dot_dimension_numbers(); + // GEMM fusions require a specific canonical order of dimensions. + constexpr int kBatchDimensionIndex = 0; + constexpr int kOutputLHSNonContractingDimensionIndex = 1; + std::vector dim_indices; + int lhs_noncontracting_index = -1; + switch (scope) { + case TritonFusionAnalysis::Scope::LHS: + lhs_noncontracting_index = + GetNonContractingDims(dot_.operand(0)->shape(), + dims.lhs_batch_dimensions(), + dims.lhs_contracting_dimensions()) + .value()[0]; + dim_indices = { + dims.lhs_batch_dimensions().empty() ? -1 + : dims.lhs_batch_dimensions(0), + lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)}; + break; + case TritonFusionAnalysis::Scope::RHS: + dim_indices = {dims.rhs_batch_dimensions().empty() + ? -1 + : dims.rhs_batch_dimensions(0), + dims.rhs_contracting_dimensions(0), + GetNonContractingDims(dot_.operand(1)->shape(), + dims.rhs_batch_dimensions(), + dims.rhs_contracting_dimensions()) + .value()[0]}; + break; + case TritonFusionAnalysis::Scope::OUTPUT: + lhs_noncontracting_index = dot_.shape().rank() - 2; + dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0, + lhs_noncontracting_index, dot_.shape().rank() - 1}; + break; + case TritonFusionAnalysis::Scope::META: + LOG(FATAL) << "Unsupported scope."; + } + dimensions.reserve(dim_indices.size()); + strides.reserve(dim_indices.size()); + for (const int index : dim_indices) { + const auto* spec = analysis_.IterSpec(scope, &hlo, index); + if (spec == nullptr) { + dimensions.push_back(1); + strides.push_back(strides.empty() ? 1 : strides.back()); + continue; + } else { + if (spec->size() == 1) { + // The dimension is not split, nothing to do. + } else if (spec->size() == 2) { + if (FusionLevel(hlo) < 3) { + return false; + } + if (!dims.lhs_batch_dimensions().empty()) { + VLOG(8) << "Noncontracting dimension split is not compatible with " + "batch dimensions."; + return false; + } + if (index != lhs_noncontracting_index) { + VLOG(8) << "Only LHS noncontracting dimension can be split."; + return false; + } + switch (scope) { + case TritonFusionAnalysis::Scope::LHS: + lhs_noncontracting_split = spec->back().count; + break; + case TritonFusionAnalysis::Scope::OUTPUT: + if (lhs_noncontracting_split != spec->back().count) { + VLOG(8) << "Output non-contracting dimension has to be split " + "the same way as the LHS input one if it is split."; + return false; + } + break; + default: + VLOG(8) << "Only LHS noncontracting dimension can be split."; + return false; + } + // Assign the major part of the noncontracting dimension to the + // unused batch one. + CHECK_EQ(dimensions[kBatchDimensionIndex], 1); + dimensions[kBatchDimensionIndex] = spec->back().count; + strides[kBatchDimensionIndex] = spec->back().stride; + } else { + VLOG(8) << "The dimension is split multiple times."; + return false; + } + dimensions.push_back(spec->front().count); + strides.push_back(spec->front().stride); + } + } + if (lhs_noncontracting_split > 1 && + scope == TritonFusionAnalysis::Scope::OUTPUT && + dimensions[kBatchDimensionIndex] == 1) { + // LHS input noncontracting dimension is split but the corresponding + // output one is not. Assign part of the output one to the unused batch + // dimension. + dimensions[kBatchDimensionIndex] = lhs_noncontracting_split; + dimensions[kOutputLHSNonContractingDimensionIndex] /= + lhs_noncontracting_split; + strides[kBatchDimensionIndex] = + strides[kOutputLHSNonContractingDimensionIndex] * + dimensions[kOutputLHSNonContractingDimensionIndex]; + } + return true; + } + + private: + int64_t lhs_noncontracting_split = 1; + const HloDotInstruction& dot_; +}; + +// Traverses fusion computations and creates cuDNN graphs out of them. +absl::StatusOr> HloFusionToCuDnnGraph( + const HloFusionInstruction& fusion) { + const HloComputation& computation = *fusion.fused_instructions_computation(); + VLOG(5) << fusion.ToString(); + VLOG(5) << computation.ToString(); + graph::Graph graph; + std::vector instructions = + computation.MakeInstructionPostOrder(); + absl::flat_hash_map> + hlo_to_cudnn; + TF_ASSIGN_OR_RETURN(std::optional adapter, + GemmDimensionAdapter::Create(computation)); + if (!adapter.has_value()) { + return std::nullopt; + } + auto add_parameter = [&](const HloInstruction& parameter, + std::vector& dimensions, + std::vector strides) { + const std::optional data_type = + ToCudnnDataType(parameter.shape().element_type()); + if (!data_type.has_value()) { + VLOG(3) << "Unsupported data type."; + return false; + } + hlo_to_cudnn[¶meter] = graph.tensor( + graph::Tensor_attributes() + .set_dim(dimensions) + .set_stride(strides) + .set_data_type(*data_type) + .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number()))); + return true; + }; + for (const TritonFusionAnalysis::Scope scope : + {TritonFusionAnalysis::Scope::LHS, TritonFusionAnalysis::Scope::RHS, + TritonFusionAnalysis::Scope::OUTPUT}) { + for (const HloInstruction* parameter : + adapter->analysis_.ScopeParameters(scope)) { + std::vector dimensions; + std::vector strides; + if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions, + strides)) { + VLOG(3) << "Unsupported dimensions."; + return std::nullopt; + } + if (!add_parameter(*parameter, dimensions, strides)) { + return std::nullopt; + } + } + } + + for (const HloInstruction* hlo : instructions) { + VLOG(5) << hlo->ToShortString(); + auto operand = [&hlo_to_cudnn, &hlo](int i) { + return hlo_to_cudnn[hlo->operand(i)]; + }; + if (hlo->opcode() == HloOpcode::kParameter) { + CHECK(hlo_to_cudnn.contains(hlo)); + continue; + } else if (hlo->opcode() == HloOpcode::kReshape || + hlo->opcode() == HloOpcode::kBitcast || + hlo->opcode() == HloOpcode::kTranspose || + hlo->opcode() == HloOpcode::kCopy || + (FusionLevel(fusion) >= 2 && + hlo->opcode() == HloOpcode::kBroadcast)) { + // All these are accounted for separately as transformations of strides. + hlo_to_cudnn[hlo] = operand(0); + } else if (hlo->IsElementwise()) { + const auto mode = GetElementwiseMode(*hlo); + if (!mode.has_value()) { + VLOG(3) << "Unsupported elementwise operation."; + return std::nullopt; + } + const auto compute_dtype = + GetComputeDataType(hlo->shape().element_type()); + if (!compute_dtype.has_value()) { + return std::nullopt; + } + const auto attrs = graph::Pointwise_attributes() + .set_mode(mode.value()) + .set_compute_data_type(compute_dtype.value()); + if (hlo->operand_count() == 1) { + hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs); + } else if (hlo->operand_count() == 2) { + hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs); + } else if (hlo->operand_count() == 3) { + if (hlo->opcode() != HloOpcode::kSelect) { + VLOG(3) << "Unexpected ternary operation: " << hlo->ToString(); + return std::nullopt; + } + // Operand order for select differs between HLO and cuDNN. + hlo_to_cudnn[hlo] = + graph.pointwise(operand(1), operand(2), operand(0), attrs); + } else { + VLOG(3) << "Unimplemented elementwise operation."; + return std::nullopt; + } + } else if (hlo->opcode() == HloOpcode::kDot) { + const auto compute_dtype = + GetComputeDataType(hlo->shape().element_type()); + if (!compute_dtype.has_value()) { + return std::nullopt; + } + hlo_to_cudnn[hlo] = + graph.matmul(operand(0), operand(1), + graph::Matmul_attributes().set_compute_data_type( + compute_dtype.value())); + } else { + VLOG(3) << "Unimplemented operation."; + return std::nullopt; + } + if (hlo_to_cudnn[hlo] == nullptr) { + VLOG(3) << "Creation of the operation failed."; + return std::nullopt; + } + const auto data_type = ToCudnnDataType(hlo->shape().element_type()); + if (!data_type.has_value()) { + VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type(); + return std::nullopt; + } + hlo_to_cudnn[hlo]->set_data_type(data_type.value()); + } + const HloInstruction* output = instructions.back(); + if (instructions.back()->shape().IsTuple()) { + output = instructions.back()->operand(0); + } + std::vector dimensions; + std::vector strides; + if (!adapter->DimensionsAndStrides( + *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) { + VLOG(3) << "Unsupported dimensions."; + return std::nullopt; + } + hlo_to_cudnn[output] + ->set_output(true) + .set_dim(dimensions) + .set_stride(strides) + .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count())); + if (cudnn_frontend::error_t result = graph.validate(); result.is_bad()) { + VLOG(3) << result.get_message(); + return std::nullopt; + } + + return se::gpu::CudnnGraph(std::move(graph)); +} + +// Creates a cuDNN graph, queries cuDNN whether it is supported. +absl::StatusOr PrepareGraph( + se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) { + TF_ASSIGN_OR_RETURN(std::optional graph, + HloFusionToCuDnnGraph(hlo)); + if (!graph.has_value()) { + return absl::InternalError("Construction of cuDNN graph failed."); + } + TF_ASSIGN_OR_RETURN(bool supported, graph->Prepare(dnn_support)); + if (!supported) { + return absl::InternalError("cuDNN graph is not supported."); + } + return *graph; +} + +class CuDnnFusionVisitor : public DfsHloRewriteVisitor { + public: + explicit CuDnnFusionVisitor( + se::dnn::DnnSupport& dnn_support, + CuDnnFusionCompiler::BinaryMap& compilation_results) + : dnn_support_(dnn_support), compilation_results_(compilation_results) {} + + absl::Status HandleFusion(HloInstruction* hlo) override { + TF_ASSIGN_OR_RETURN(auto gpu_config, + hlo->backend_config()); + const auto& fusion_backend_config = gpu_config.fusion_backend_config(); + if (fusion_backend_config.kind() != kCuDnnFusionKind) { + return absl::OkStatus(); + } + int64_t plan_id = -1; + if (fusion_backend_config.has_cudnn_fusion_config()) { + plan_id = fusion_backend_config.cudnn_fusion_config().plan_id(); + } + + VLOG(4) << "Processing " << hlo->ToString(); + VLOG(4) << "Plan ID: " << plan_id; + + const std::string cache_key = + GetComputationFingerprint(hlo->fused_instructions_computation(), {}); + std::string& cache_entry = compilation_results_[cache_key]; + if (cache_entry.empty()) { + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + PrepareGraph(dnn_support_, *DynCast(hlo))); + + if (plan_id >= 0) { + // Build single plan with given ID. + if (plan_id >= graph.Graph().get_execution_plan_count()) { + return absl::InternalError("cuDNN graph plan does not exist."); + } + TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id)); + } else { + // Build plans one by one till first successful when no plan_id was + // provided. + for (plan_id = 0; plan_id < graph.Graph().get_execution_plan_count(); + ++plan_id) { + VLOG(7) << "Trying plan ID " << plan_id; + if (graph.Build(dnn_support_, plan_id).ok()) { + VLOG(7) << "Successfully built plan ID " << plan_id; + break; + } + } + if (plan_id == graph.Graph().get_execution_plan_count()) { + return absl::InternalError("No cuDNN plans can be built."); + } + } + + if (graph.Graph().get_workspace_size() != 0) { + return absl::UnimplementedError( + "Support of workspace allocation is not added yet."); + } + + std::vector serialized_graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); + cache_entry = + std::string(reinterpret_cast(serialized_graph.data()), + serialized_graph.size()); + } else { + VLOG(4) << "Cache hit."; + } + auto cudnn_config = gpu_config.mutable_fusion_backend_config() + ->mutable_cudnn_fusion_config(); + cudnn_config->set_plan_id(plan_id); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + se::dnn::DnnSupport& dnn_support_; + // . + CuDnnFusionCompiler::BinaryMap& compilation_results_; +}; + +} // namespace + +absl::StatusOr CuDnnFusionCompiler::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + XLA_SCOPED_LOGGING_TIMER("cuDNN fusion compiler"); + return CuDnnFusionVisitor(dnn_support_, compilation_results_) + .RunOnModule(module, execution_threads); +} + +int CuDnnFusionCompiler::GetAvailablePlanCount( + se::StreamExecutor& stream_exec, const HloFusionInstruction& hlo) { + auto graph = PrepareGraph(*stream_exec.AsDnn(), hlo); + if (!graph.ok()) { + return 0; + } + constexpr int64_t kMaxPlans = 10; + return std::min(graph->Graph().get_execution_plan_count(), kMaxPlans); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/cudnn_fusion_compiler.h b/xla/service/gpu/cudnn_fusion_compiler.h new file mode 100644 index 0000000000000..e5ce4ddefa7b6 --- /dev/null +++ b/xla/service/gpu/cudnn_fusion_compiler.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ +#define XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// Converts HLO fusions with cuDNN backend config to cuDNN graphs, +// compiles them using a cuDNN handle and serializes them. +class CuDnnFusionCompiler : public HloModulePass { + public: + // . + using BinaryMap = absl::flat_hash_map; + + explicit CuDnnFusionCompiler(se::StreamExecutor& stream_exec, + BinaryMap& compilation_results) + : dnn_support_(*stream_exec.AsDnn()), + compilation_results_(compilation_results) {} + + absl::string_view name() const override { return "cudnn-fusion-compiler"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + static int GetAvailablePlanCount(se::StreamExecutor& stream_exec, + const HloFusionInstruction& hlo); + + private: + se::dnn::DnnSupport& dnn_support_; + BinaryMap& compilation_results_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ diff --git a/xla/service/gpu/cudnn_norm_rewriter.cc b/xla/service/gpu/cudnn_norm_rewriter.cc index ba225fe265506..ec966e18085b1 100644 --- a/xla/service/gpu/cudnn_norm_rewriter.cc +++ b/xla/service/gpu/cudnn_norm_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,46 @@ limitations under the License. #include "xla/service/gpu/cudnn_norm_rewriter.h" -#include - +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/wrappers.pb.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" -#include "xla/stream_executor/dnn.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/types.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cudnn/cudnn.h" +#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep +#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep +#include "third_party/gpus/cudnn/cudnn_version.h" #endif namespace xla { @@ -37,17 +64,105 @@ namespace { namespace m = match; +// Traverses the graph upward starting at instr and returns the +// first instruction that is not a convert, bitcast or reshape. +const HloInstruction* SkipUnaryOps(const HloInstruction* instr) { + while (instr->opcode() == HloOpcode::kConvert || + instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kReshape) { + instr = instr->operand(0); + } + return instr; +} + +// Recursively traverses the graph downward starting at instr and stores in +// instrs the users that are not a convert, bitcast or reshape. +void SkipUnaryOpsTopDownRecursive(HloInstruction* instr, + std::vector& instrs) { + if (instr->opcode() == HloOpcode::kConvert || + instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kReshape) { + for (HloInstruction* user : instr->users()) { + SkipUnaryOpsTopDownRecursive(user, instrs); + } + } else { + instrs.emplace_back(instr); + } +} + +// Holds auxiliary information about individual layer norm patterns rewritten +// into a cuDNN Custom Call. +struct NormMetadata { + // Transposes applied to the input and output of the forward layer norm to + // order the normalization and non-normalization dimensions as required by + // cuDNN. Nullptr if no transposes were inserted. + HloInstruction *x_transpose, *y_transpose; + // The reduction and non-reduction dimensions of the input into the forward + // layer norm before the potential application of transposes and adjusted for + // the removal of any degenerate dimensions in the input to the norm. + std::vector norm_dims_adjusted, non_norm_dims_adjusted; +}; + +// Map from the instruction pointer of a layer norm Custom Call to its metadata. +using NormMetadataMap = absl::flat_hash_map; + +// Captures multiple HloInstruction pointers and verifies that their target +// is identical. +// +// Example: +// Pattern cos(x) / sin(x) with cos and sin intended to operate on the same +// HloInstruction: +// UniqueHloInstruction x; +// bool m = Match( +// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)), +// m::Sin(m::Op().WithPredicate(x.capture_and_verify)))); +// m is true and x.Instr() returns an HloInstruction pointer to the operand of +// cosine and sine iff HloInstruction *instr points to a division of a cosine by +// a sine that operate on the same instruction. +class UniqueHloInstruction { + public: + UniqueHloInstruction() : is_set_(false), instr_(nullptr) {} + HloInstruction* Instr() const { return instr_; } + void SetInstr(HloInstruction* instr) { + is_set_ = true; + instr_ = instr; + } + + // Stores instr when invoked the first time. Otherwise, compares instr to the + // stored value and sets the stored value to nullptr if the comparison fails. + bool CaptureOrVerify(HloInstruction* instr) { + if (is_set_ && instr != instr_) { + instr_ = nullptr; + } + if (!is_set_) { + is_set_ = true; + instr_ = instr; + } + return instr_; + } + + // Lambda for capturing or verifying an instruction using WithPredicate. + const std::function capture_or_verify = + [this](const HloInstruction* instr) -> bool { + return CaptureOrVerify(const_cast(instr)); + }; + + private: + bool is_set_; + HloInstruction* instr_; +}; + // Returns an architecture-specific constant for the calculation of an upper // bound for the size of the scratch space for layer norm kernels. -StatusOr CConstant(se::CudaComputeCapability cuda_compute_capability) { +absl::StatusOr CConstant( + se::CudaComputeCapability cuda_compute_capability) { if (cuda_compute_capability.major == se::CudaComputeCapability::AMPERE) { return 32 * 128; } else if (cuda_compute_capability.major == se::CudaComputeCapability::HOPPER) { return 32 * 144; } - return xla::InternalError( - "Norm kernels require Ampere or Hopper architecture."); + return xla::Internal("Norm kernels require Ampere or Hopper architecture."); } // Returns whether the element type of instr is compatible with layer norm @@ -57,12 +172,59 @@ bool CompatibleElementType(const HloInstruction* instr) { return element_type == BF16 || element_type == F16 || element_type == F32; } +// Returns the dimensions associated with shape, adjusted for the removal of any +// degenerate dimensions in shape. Specifically, for each dimension d in +// dimensions, returns the new index of d if all dimensions of size 1 are +// removed from shape. If d has size 1, it is not included in the returned +// vector. +std::vector AdjustedDimensions(const Shape& shape, + absl::Span dimensions) { + absl::flat_hash_map dimension_map; + for (int64_t dimension = 0, non_degen_dimension = 0; dimension < shape.rank(); + ++dimension) { + if (shape.dimensions(dimension) > 1) { + dimension_map.insert({dimension, non_degen_dimension}); + non_degen_dimension++; + } + } + std::vector adjusted_dimensions; + for (int64_t dimension : dimensions) { + auto non_degenerate_dimension = dimension_map.find(dimension); + if (non_degenerate_dimension != dimension_map.end()) { + adjusted_dimensions.emplace_back(non_degenerate_dimension->second); + } + } + return adjusted_dimensions; +} + +// Returns the dimensions of broadcast or reduction instructions, adjusted for +// the removal of any degenerate dimensions in the output or input. +std::vector AdjustedDimensions(const HloInstruction* instr) { + Shape shape; + if (instr->opcode() == HloOpcode::kBroadcast) { + shape = instr->shape(); + } else if (instr->opcode() == HloOpcode::kReduce) { + shape = instr->operand(0)->shape(); + } else { + return {}; + } + return AdjustedDimensions(shape, instr->dimensions()); +} + // Returns whether the HLO Computation applied by instr calculates the sum of -// the elements. -bool AppliesAddReduce(const HloInstruction* instr) { +// the elements. When provided, compares reduce_dims to the dimensions of the +// reduction. +bool AppliesAddReduce(const HloInstruction* instr, + absl::Span reduce_dims = {}) { if (instr->opcode() != HloOpcode::kReduce) { return false; } + + // Verify the dimensions of the reduction. + if (!reduce_dims.empty() && AdjustedDimensions(instr) != reduce_dims) { + return false; + } + HloComputation* reduce_comp = instr->to_apply(); HloInstruction* reduce_comp_root = reduce_comp->root_instruction(); return instr->operand_count() == 2 && @@ -77,23 +239,13 @@ bool AppliesAddReduce(const HloInstruction* instr) { // Returns whether instr multiplies the result of a reduction by one over the // number of reduced elements. bool CalculatesExpectation(const HloInstruction* instr) { - auto skip_convert_and_reshape = - [](const HloInstruction* instr) -> const HloInstruction* { - while (instr->opcode() == HloOpcode::kConvert || - instr->opcode() == HloOpcode::kReshape) { - instr = instr->operand(0); - } - return instr; - }; - - instr = skip_convert_and_reshape(instr); + instr = SkipUnaryOps(instr); if (instr->opcode() != HloOpcode::kMultiply) { return false; } bool bcast_operand = instr->operand(0)->opcode() != HloOpcode::kBroadcast; const HloInstruction *broadcast = instr->operand(bcast_operand), - *reduce = instr->operand(!bcast_operand); - reduce = skip_convert_and_reshape(reduce); + *reduce = SkipUnaryOps(instr->operand(!bcast_operand)); if (reduce->opcode() != HloOpcode::kReduce || broadcast->opcode() != HloOpcode::kBroadcast || broadcast->operand(0)->opcode() != HloOpcode::kConstant) { @@ -114,6 +266,162 @@ bool CalculatesExpectation(const HloInstruction* instr) { ((actual_r_nelems + r_nelems) * numerical_epsilon); } +// Returns whether target can be reached from instr by recursively traversing +// the graph across converts, bitcasts and reshapes. +bool FindTargetRecursive( + const HloInstruction* instr, const HloInstruction* target, + absl::flat_hash_set& visited_instrs, + const HloInstruction* transpose) { + visited_instrs.emplace(instr); + const absl::flat_hash_set supported_ops = { + HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape}; + if (instr == target) { + return true; + } + // Look for target among the users of instr. + for (HloInstruction* user : instr->users()) { + if ((supported_ops.contains(user->opcode()) || user == transpose) && + !visited_instrs.contains(user)) { + return FindTargetRecursive(user, target, visited_instrs, transpose); + } + } + // Ascend the graph if target is not found and instr is a convert, bitcast + // or reshape. + if (supported_ops.contains(instr->opcode())) { + return FindTargetRecursive(instr->operand(0), target, visited_instrs, + transpose); + } + return false; +} + +bool FindTarget(const HloInstruction* custom_call, const HloInstruction* instr, + const HloInstruction* target, + const NormMetadataMap& norm_metadata) { + absl::flat_hash_set visited_instrs; + auto custom_call_metadata = norm_metadata.find(custom_call); + if (custom_call_metadata == norm_metadata.end()) { + return false; + } + return FindTargetRecursive(instr, target, visited_instrs, + custom_call_metadata->second.x_transpose); +} + +// Maps the dimension numbers in dimensions from shape original_shape to shape +// reshaped_shape, assuming that the shapes are related through a strict +// reshape. Returns an empty vector if a dimension mapping is not found. +std::vector MapDimensions(const Shape& original_shape, + const Shape& reshaped_shape, + const absl::Span dimensions) { + auto dimension_product = + [](const Shape& shape, + absl::Span product_dimensions) -> int64_t { + int64_t product = 1; + for (int64_t product_dimension : product_dimensions) { + product *= shape.dimensions(product_dimension); + } + return product; + }; + // Construct the dimension mapping. + absl::flat_hash_map> dimensions_map; + std::vector original_dimensions, reshaped_dimensions; + for (int64_t original_dimension = 0, reshaped_dimension = 0; + original_dimension < original_shape.rank(); ++original_dimension) { + original_dimensions.emplace_back(original_dimension); + while ((reshaped_dimensions.empty() || + dimension_product(reshaped_shape, reshaped_dimensions) < + dimension_product(original_shape, original_dimensions)) && + reshaped_dimension < reshaped_shape.rank()) { + reshaped_dimensions.emplace_back(reshaped_dimension++); + } + + // Many-to-many dimension mappings are not supported. + if (original_dimensions.size() > 1 && reshaped_dimensions.size() > 1) { + return {}; + } + + if (dimension_product(original_shape, original_dimensions) == + dimension_product(reshaped_shape, reshaped_dimensions)) { + std::vector original_dimensions_in_dimensions; + std::set_intersection( + original_dimensions.begin(), original_dimensions.end(), + dimensions.begin(), dimensions.end(), + std::back_inserter(original_dimensions_in_dimensions)); + // The unique mapping of dimensions requires either all or none of the + // entries of original_dimensions to be an element of dimensions. + if (original_dimensions_in_dimensions.size() != 0 && + original_dimensions_in_dimensions.size() != + original_dimensions.size()) { + return {}; + } + for (int64_t dimension : original_dimensions) { + dimensions_map.insert({dimension, reshaped_dimensions}); + } + original_dimensions.clear(); + reshaped_dimensions.clear(); + } + } + + // Map the dimensions numbers to the reshaped shape. + std::vector mapped_dimensions; + for (int64_t dimension : dimensions) { + auto mapped_dimension = dimensions_map.find(dimension); + if (mapped_dimension == dimensions_map.end()) { + return {}; + } + mapped_dimensions.insert(mapped_dimensions.end(), + mapped_dimension->second.begin(), + mapped_dimension->second.end()); + } + + // Eliminate duplicates in the mapped dimension numbers. + mapped_dimensions.erase( + std::unique(mapped_dimensions.begin(), mapped_dimensions.end()), + mapped_dimensions.end()); + return mapped_dimensions; +} + +// Recursively traverses the graph across converts, bitcasts and reshapes, +// starting from instr, and returns the first addition-reduction identified. +// Returns nullptr if no addition-reduction is found. +HloInstruction* FindAddReduceRecursive( + HloInstruction* instr, const Shape& orig_instr_shape, + const absl::Span reduce_dims, + absl::flat_hash_set& visited_instrs) { + visited_instrs.emplace(instr); + const absl::flat_hash_set supported_ops = { + HloOpcode::kConvert, HloOpcode::kBitcast, HloOpcode::kReshape}; + // Look for a reduction among the users of instr. + for (HloInstruction* user : instr->users()) { + if (user->opcode() == HloOpcode::kReduce) { + std::vector mapped_reduce_dims = + MapDimensions(orig_instr_shape, instr->shape(), reduce_dims); + if (!mapped_reduce_dims.empty() && + AppliesAddReduce(user, mapped_reduce_dims)) { + return user; + } + } + if (supported_ops.contains(user->opcode()) && + !visited_instrs.contains(user)) { + return FindAddReduceRecursive(user, orig_instr_shape, reduce_dims, + visited_instrs); + } + } + // Ascend the graph if the addition-reduction is not found and instr is a + // convert, bitcast or reshape. + if (supported_ops.contains(instr->opcode())) { + return FindAddReduceRecursive(instr->mutable_operand(0), orig_instr_shape, + reduce_dims, visited_instrs); + } + return nullptr; +} + +HloInstruction* FindAddReduce(HloInstruction* instr, + const absl::Span reduce_dims) { + absl::flat_hash_set visited_instrs; + return FindAddReduceRecursive(instr, instr->shape(), reduce_dims, + visited_instrs); +} + // Type conversion from and to any of BF16, FP16 and FP32. template auto SupportedConvert(Pattern pattern) { @@ -124,69 +432,94 @@ auto SupportedConvert(Pattern pattern) { return m::Convert(pattern).WithPredicate(supported_convert); } -// Reshape adding or removing degenerate dimensions. +// Bitcast or reshape adding or removing degenerate dimensions. template -auto SupportedReshape(Pattern pattern) { - auto supported_reshape = [](const HloInstruction* instr) -> bool { +auto SupportedBitcastOrReshape(Pattern pattern) { + auto supported_bitcast_or_reshape = [](const HloInstruction* instr) -> bool { return ShapeUtil::Equal( ShapeUtil::DropDegenerateDimensions(instr->shape()), ShapeUtil::DropDegenerateDimensions(instr->operand(0)->shape())); }; - return m::Reshape(pattern).WithPredicate(supported_reshape); + return m::AnyOf( + m::Bitcast(pattern).WithPredicate(supported_bitcast_or_reshape), + m::Reshape(pattern).WithPredicate(supported_bitcast_or_reshape)); } -// Matches pattern, SupportedConvert(pattern), SupportedReshape(pattern), -// SupportedConvert(SupportedReshape(pattern)) and -// SupportedReshape(SupportedConvert(pattern)). +// Matches pattern, SupportedConvert(pattern), +// SupportedBitcastOrReshape(pattern), +// SupportedConvert(SupportedBitcastOrReshape(pattern)) and +// SupportedBitcastOrReshape(SupportedConvert(pattern)). template -auto OptionalConvertAndOrReshape(Pattern pattern) { +auto OptionalSupportedTransform(Pattern pattern) { auto shared_subpattern = m::SharedSubpattern(pattern); return m::AnyOf( - SupportedConvert(SupportedReshape(shared_subpattern)), - SupportedReshape(SupportedConvert(shared_subpattern)), - SupportedConvert(shared_subpattern), SupportedReshape(shared_subpattern), - shared_subpattern); + SupportedConvert(SupportedBitcastOrReshape(shared_subpattern)), + SupportedBitcastOrReshape(SupportedConvert(shared_subpattern)), + SupportedConvert(shared_subpattern), + SupportedBitcastOrReshape(shared_subpattern), shared_subpattern); +} + +// Bitcast or reshape with optional supported type conversion and/or addition or +// removal of degenerate dimensions. +template +auto BitcastOrReshape(Pattern pattern) { + return OptionalSupportedTransform( + m::AnyOf(m::Bitcast(pattern), m::Reshape(pattern))); +} + +// Transpose with optional supported type conversion and/or addition or removal +// of degenerate dimensions. +template +auto Transpose(Pattern pattern) { + return OptionalSupportedTransform(m::Transpose(pattern)); } -// Rsqrt with optional convert and/or reshape. +// Rsqrt with optional supported type conversion and/or addition or removal of +// degenerate dimensions. template auto Rsqrt(HloInstruction** rsqrt, Pattern pattern) { - return OptionalConvertAndOrReshape(m::Rsqrt(rsqrt, pattern)); + return OptionalSupportedTransform(m::Rsqrt(rsqrt, pattern)); } -// AddAnyOrder with optional convert and/or reshape. +// AddAnyOrder with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto AddAnyOrder(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::AddAnyOrder(pattern0, pattern1)); + return OptionalSupportedTransform(m::AddAnyOrder(pattern0, pattern1)); } -// Subtract with optional convert and/or reshape. +// Subtract with optional supported type conversion and/or addition or removal +// of degenerate dimensions. template auto Subtract(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::Subtract(pattern0, pattern1)); + return OptionalSupportedTransform(m::Subtract(pattern0, pattern1)); } -// Capturing subtract with optional convert and/or reshape. +// Capturing subtract with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto Subtract(HloInstruction** subtract, Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::Subtract(subtract, pattern0, pattern1)); + return OptionalSupportedTransform(m::Subtract(subtract, pattern0, pattern1)); } -// Multiply with optional convert and/or reshape. +// Multiply with optional supported type conversion and/or addition or removal +// of degenerate dimensions. template auto MultiplyAnyOrder(Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape(m::MultiplyAnyOrder(pattern0, pattern1)); + return OptionalSupportedTransform(m::MultiplyAnyOrder(pattern0, pattern1)); } -// Capturing multiply with optional convert and/or reshape. +// Capturing multiply with optional supported type conversion and/or addition or +// removal of degenerate dimensions. template auto MultiplyAnyOrder(HloInstruction** multiply, Pattern0 pattern0, Pattern1 pattern1) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::MultiplyAnyOrder(multiply, pattern0, pattern1)); } -// Multiplication of pattern by itself with optional convert and/or reshape. +// Multiplication of pattern by itself with optional supported type conversion +// and/or addition or removal of degenerate dimensions. template auto Square(Pattern pattern) { return MultiplyAnyOrder(pattern, pattern) @@ -195,28 +528,49 @@ auto Square(Pattern pattern) { }); } -// Addition-reduction of pattern with optional convert and/or reshape and -// constant 0 scalar. +// Multiplication of the square of pattern by pattern with optional supported +// type conversion and/or addition or removal of degenerate dimensions. The root +// instruction of pattern cannot be a multiplication. +template +auto Cube(Pattern pattern) { + auto unique_cube = [](const HloInstruction* instr) -> bool { + bool square_operand = instr->operand(0)->opcode() != HloOpcode::kMultiply; + return instr->operand(!square_operand)->opcode() != HloOpcode::kMultiply && + instr->operand(square_operand)->operand(0) == + instr->operand(!square_operand); + }; + return MultiplyAnyOrder(Square(pattern), pattern).WithPredicate(unique_cube); +} + +// Addition-reduction of pattern with optional supported type conversion and/or +// addition or removal of degenerate dimensions. template auto AddReduce(Pattern pattern) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::Reduce(pattern, m::Op()) .WithPredicate([](const HloInstruction* instr) { return AppliesAddReduce(instr); })); } -// Capturing addition-reduction of pattern with optional convert and/or reshape -// and constant 0 scalar. +// Capturing addition-reduction of pattern with optional supported type +// conversion and/or addition or removal of degenerate dimensions. template auto AddReduce(HloInstruction** reduction, Pattern pattern) { - return OptionalConvertAndOrReshape( + return OptionalSupportedTransform( m::Reduce(reduction, pattern, m::Op()) .WithPredicate([](const HloInstruction* instr) { return AppliesAddReduce(instr); })); } +// Negated addition-reduction. +template +auto NegateAddReduce(HloInstruction** reduction, Pattern pattern) { + return m::AnyOf(AddReduce(reduction, m::Negate(pattern)), + m::Negate(AddReduce(reduction, pattern))); +} + // Expected value, or mean, with optional broadcast. template auto Expectation(Pattern pattern) { @@ -231,65 +585,62 @@ auto Expectation(Pattern pattern) { // Expected value, or mean, with optional broadcast. template -auto Expectation(HloInstruction** expectation, Pattern pattern) { - auto shared_subpattern = - MultiplyAnyOrder(expectation, m::Broadcast(m::ConstantScalar()), - AddReduce(pattern)) +auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) { + auto shared_subpattern = OptionalSupportedTransform( + m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), AddReduce(pattern)) .WithPredicate([](const HloInstruction* instr) { return CalculatesExpectation(instr); - }); + }) + .WithPredicate(expectation->capture_or_verify)); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } // Expected value, or mean, with optional broadcast. template -auto Expectation(HloInstruction** expectation, HloInstruction** reduce, +auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce, Pattern pattern) { - auto shared_subpattern = - MultiplyAnyOrder(expectation, m::Broadcast(m::ConstantScalar()), - AddReduce(reduce, pattern)) + auto shared_subpattern = OptionalSupportedTransform( + m::MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), + AddReduce(reduce, pattern)) .WithPredicate([](const HloInstruction* instr) { return CalculatesExpectation(instr); - }); + }) + .WithPredicate(expectation->capture_or_verify)); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } // Variance, expressed as expectation(X^2) - expectation(X)^2 or -// expectation((X - expectation(X))^2). The simultaneous capture of input0 and -// input1 allows the caller to verify that they are identical. -auto Variance(HloInstruction** expectation, HloInstruction** input0, - HloInstruction** input1) { +// expectation((X - expectation(X))^2). +auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation, + UniqueHloInstruction* x) { return m::AnyOf( - Subtract(Expectation(Square(m::Op(input0))), - Square(Expectation(expectation, m::Op(input1)))), - Expectation(Square( - Subtract(m::Op(input0), Expectation(expectation, m::Op(input1)))))); -} - -// Variance, expressed as expectation(X^2) - expectation(X)^2 or -// expectation((X - expectation(X))^2). The simultaneous capture of input0 and -// input1 allows the caller to verify that they are identical. -auto Variance(HloInstruction** variance, HloInstruction** expectation, - HloInstruction** input0, HloInstruction** input1) { - return m::AnyOf( - Subtract(variance, Expectation(Square(m::Op(input0))), - Square(Expectation(expectation, m::Op(input1)))), - Expectation(variance, - Square(Subtract(m::Op(input0), - Expectation(expectation, m::Op(input1)))))); + Subtract(Expectation(Square(OptionalSupportedTransform( + m::Op().WithPredicate(x->capture_or_verify)))), + Square(Expectation(expectation, OptionalSupportedTransform( + m::Op().WithPredicate( + x->capture_or_verify))))) + .WithPredicate(variance->capture_or_verify), + Expectation( + Square(Subtract( + OptionalSupportedTransform( + m::Op().WithPredicate(x->capture_or_verify)), + Expectation(expectation, + OptionalSupportedTransform( + m::Op().WithPredicate(x->capture_or_verify)))))) + .WithPredicate(variance->capture_or_verify)); } // Reciprocal of the square root of variance + epsilon with optional broadcast. -// The simultaneous capture of input0 and input1 allows the caller to verify -// that they are identical. -auto NormFactor(HloInstruction** norm_factor, HloInstruction** input0, - HloInstruction** input1, HloInstruction** variance, - HloInstruction** expectation, HloInstruction** epsilon) { +auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x, + UniqueHloInstruction* variance, + UniqueHloInstruction* expectation, + UniqueHloInstruction* epsilon) { auto shared_subpattern = m::SharedSubpattern(Rsqrt( - norm_factor, AddAnyOrder(Variance(variance, expectation, input0, input1), - m::Broadcast(m::ConstantScalar(epsilon))))); + norm_factor, AddAnyOrder(Variance(variance, expectation, x), + m::Broadcast(m::ConstantScalar().WithPredicate( + epsilon->capture_or_verify))))); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } @@ -303,6 +654,22 @@ auto MultiplyMultiplyAnyOrder(P0 p0, P1 p1, P2 p2) { MultiplyAnyOrder(p2, MultiplyAnyOrder(p0, p1))); } +// Any order of p0 + p1 + p2. +template +auto AddAddAnyOrder(P0 p0, P1 p1, P2 p2) { + return m::AnyOf(AddAnyOrder(p0, AddAnyOrder(p1, p2)), + AddAnyOrder(p1, AddAnyOrder(p0, p2)), + AddAnyOrder(p2, AddAnyOrder(p0, p1))); +} + +// Any order of p0 * (p1 + p2). +template +auto MultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2) { + return m::AnyOf( + MultiplyAnyOrder(p0, AddAnyOrder(p1, p2)), + AddAnyOrder(MultiplyAnyOrder(p0, p1), MultiplyAnyOrder(p0, p2))); +} + // Any order of p0 - p1 + p2. template auto SubtractAddAnyOrder(P0 p0, P1 p1, P2 p2) { @@ -320,38 +687,223 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) { AddAnyOrder(MultiplyMultiplyAnyOrder(Subtract(p0, p1), p2, p3), p4)); } +// Expectation fused into a layer norm Custom Call. +auto FusedExpectation(UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 1)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Expectation fused into a layer norm Custom Call. +auto FusedExpectation(UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 1) + .WithPredicate(fused_expectation->capture_or_verify)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Norm factor fused into a layer norm Custom Call. +auto FusedNormFactor(UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 2)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Norm factor fused into a layer norm Custom Call. +auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor, + UniqueHloInstruction* custom_call) { + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->capture_or_verify), + 2) + .WithPredicate(fused_norm_factor->capture_or_verify)); + return m::AnyOf(shared_subpattern, + BitcastOrReshape(shared_subpattern)); +} + +// Derivative of the norm factor w.r.t. variance + epsilon, +// d(norm_factor)/d(variance + epsilon) +// = d((variance + epsilon)^-1/2)/d(variance + epsilon) +// = -1/2 * norm_factor^3. +// Forwards custom_call to FusedNormFactor for verification. +auto DNormFactor(UniqueHloInstruction* custom_call) { + return MultiplyAnyOrder(m::Broadcast(m::ConstantScalar(-0.5)), + Cube(FusedNormFactor(custom_call))); +} + +// Zero-centered input of the layer norm, X - expectation(X). Verifies that +// custom_call is a forward layer norm fusing X. Forwards custom_call to +// FusedExpectation for verification. +auto XCenter(UniqueHloInstruction* x, UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_x = + [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool { + return x->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(0), norm_metadata) + ? custom_call->Instr()->mutable_operand(0) + : nullptr); + }; + return Subtract(m::Op(), m::Broadcast(FusedExpectation(custom_call))) + .WithPredicate(capture_or_verify_x); +} + +// Zero-centered input of the layer norm, X - expectation(X). Captures X in x if +// custom_call is a forward layer norm fusing X. Forwards custom_call to +// FusedExpectation for comparison. +auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x, + UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_x = + [x, custom_call, &norm_metadata](const HloInstruction* instr) -> bool { + return x->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(0), norm_metadata) + ? custom_call->Instr()->mutable_operand(0) + : nullptr); + }; + return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation, + custom_call))) + .WithPredicate(x_center->capture_or_verify) + .WithPredicate(capture_or_verify_x); +} + +// Addition-reduction of the product of XCenter, the broadcasted scale and DY, +// XCenter * scale * DY. Captures the scale in scale if custom_call is a forward +// layer norm fusing the scale. Forwards custom_call to XCenter for comparison. +auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, UniqueHloInstruction* x, + HloInstruction** reduce, const NormMetadataMap& norm_metadata) { + auto capture_or_verify_scale = [scale, custom_call, &norm_metadata]( + const HloInstruction* instr) -> bool { + return scale->CaptureOrVerify(FindTarget(custom_call->Instr(), instr, + custom_call->Instr()->operand(1), + norm_metadata) + ? custom_call->Instr()->mutable_operand(1) + : nullptr); + }; + return AddReduce( + reduce, MultiplyMultiplyAnyOrder( + XCenter(x, custom_call, norm_metadata), + m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)), + m::Op().WithPredicate(dy->capture_or_verify))); +} + +// Product of XCenter and the scaled and broadcasted product of F0 and +// d(norm_factor)/d(variance + epsilon), XCenter * F0 * DNormFactor * 2 / +// nelems. Forwards custom_call to XCenter, F0 and DNormFactor for capture or +// verification. +auto F1(UniqueHloInstruction* x, UniqueHloInstruction* x_center, + UniqueHloInstruction* fused_expectation, + UniqueHloInstruction* custom_call, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, HloInstruction** reduce, + const NormMetadataMap& norm_metadata) { + auto broadcasts_two_over_nelems = [](const HloInstruction* instr) -> bool { + const HloInstruction* multiply = SkipUnaryOps(instr->operand(0)); + bool bcast_operand = + multiply->operand(0)->opcode() != HloOpcode::kBroadcast; + + // The captured scalar must be two over the number of elements in the + // broadcasted dimensions. + float actual_two_over_nelems = multiply->operand(bcast_operand) + ->operand(0) + ->literal() + .GetAsDouble({}) + .value(); + int64_t nelems = 1; + for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + if (!c_linear_search(instr->dimensions(), i)) { + nelems *= instr->shape().dimensions()[i]; + } + } + // The absolute of the difference between the actual scaling factor and the + // reference value must not exceed a prescribed threshold. + float two_over_nelems = 2. / static_cast(nelems); + float numerical_epsilon = std::numeric_limits::epsilon(); + return abs(actual_two_over_nelems - two_over_nelems) < + ((actual_two_over_nelems + two_over_nelems) * numerical_epsilon); + }; + + return MultiplyAnyOrder( + XCenter(x_center, x, fused_expectation, custom_call, norm_metadata), + m::Broadcast( + MultiplyAnyOrder(m::Broadcast(m::ConstantScalar()), + MultiplyAnyOrder(DNormFactor(custom_call), + F0(custom_call, scale, dy, x, + reduce, norm_metadata)))) + .WithPredicate(broadcasts_two_over_nelems)); +} + +// Product of the norm factor, scale and DY, NormFactor * scale * DY. Captures +// the scale in scale if custom_call is a forward layer norm fusing the scale. +// Forwards custom_call to FusedNormFactor for comparison. +auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale, + UniqueHloInstruction* dy, UniqueHloInstruction* custom_call, + const NormMetadataMap& norm_metadata) { + auto capture_or_verify_scale = [scale, custom_call, &norm_metadata]( + const HloInstruction* instr) -> bool { + return scale->CaptureOrVerify( + FindTarget(custom_call->Instr(), instr->operand(0), + custom_call->Instr()->operand(1), norm_metadata) + ? custom_call->Instr()->mutable_operand(1) + : nullptr); + }; + return MultiplyAnyOrder( + m::Broadcast( + BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))), + MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale), + m::Op().WithPredicate(dy->capture_or_verify))); +} + class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { public: explicit CudnnNormRewriterVisitor( const se::CudaComputeCapability cuda_compute_capability) : cuda_compute_capability_(cuda_compute_capability) {} - Status HandleAdd(HloInstruction* instr) override { - return MatchLayerNorm(instr); + absl::Status HandleAdd(HloInstruction* instr) override { + TF_RETURN_IF_ERROR(MatchLayerNorm(instr)); + TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr)); + return absl::OkStatus(); } - Status HandleSubtract(HloInstruction* instr) override { + absl::Status HandleSubtract(HloInstruction* instr) override { return MatchLayerNorm(instr); } // Matches and rewrites layer norm patterns, - // (X - expectation(X))/(variance(X) + epsilon)^1/2 * scale + bias, + // Y = (X - expectation(X))/sqrt(variance(X) + epsilon) * scale + bias, // into Custom Calls to cuDNN. - Status MatchLayerNorm(HloInstruction* instr) { - HloInstruction *input, *input0, *input1, *input2, *scale, *bias, *epsilon, - *expectation, *expectation0, *reduce, *norm_factor, *variance, - *broadcast_scale, *broadcast_bias; - if (Match(instr, SubtractMultiplyAddAnyOrder( - m::Op(&input), - Expectation(&expectation, &reduce, m::Op(&input0)), - NormFactor(&norm_factor, &input1, &input2, &variance, - &expectation0, &epsilon), - m::Broadcast(&broadcast_scale, m::Op(&scale)), - m::Broadcast(&broadcast_bias, m::Op(&bias))))) { + absl::Status MatchLayerNorm(HloInstruction* instr) { + UniqueHloInstruction x, expectation, variance, epsilon; + HloInstruction *scale, *bias, *reduce, *norm_factor, *broadcast_scale, + *broadcast_bias; + if (Match( + instr, + SubtractMultiplyAddAnyOrder( + OptionalSupportedTransform( + m::Op().WithPredicate(x.capture_or_verify)), + Expectation(&expectation, &reduce, + OptionalSupportedTransform( + m::Op().WithPredicate(x.capture_or_verify))), + NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon), + m::Broadcast(&broadcast_scale, m::Op(&scale)), + m::Broadcast(&broadcast_bias, m::Op(&bias))))) { #if CUDNN_VERSION < 8905 // Layer norm kernels are available with cuDNN 8.9.5 and above. VLOG(1) << "Layer norm Custom Calls require cuDNN 8.9.5."; - return OkStatus(); + return absl::OkStatus(); #endif // CUDNN_VERSION < 8905 if (!instr->GetModule() @@ -359,7 +911,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { .debug_options() .xla_gpu_enable_cudnn_layer_norm()) { VLOG(1) << "Layer norm Custom Calls disabled."; - return OkStatus(); + return absl::OkStatus(); } // Layer norm kernels require Ampere or Hopper architectures. @@ -367,123 +919,124 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { cuda_compute_capability_.major != se::CudaComputeCapability::HOPPER) { VLOG(1) << "Layer norm Custom Calls require Ampere or Hopper " "architectures."; - return OkStatus(); + return absl::OkStatus(); } // Verify the uniqueness of the inputs. - auto is_input = [input](HloInstruction* inputx) -> bool { - return inputx->unique_id() == input->unique_id() || - (inputx->opcode() == HloOpcode::kConvert && - inputx->operand(0)->unique_id() == input->unique_id()); - }; - if (!is_input(input0) || !is_input(input1) || !is_input(input2) || - expectation->unique_id() != expectation0->unique_id()) { + if (!x.Instr() || !expectation.Instr() || !variance.Instr() || + !epsilon.Instr()) { VLOG(1) << "Layer norm operands not unique."; - return OkStatus(); - } - - // Skip initial convert, if present. - if (input->opcode() == HloOpcode::kConvert) { - input = input->mutable_operand(0); + return absl::OkStatus(); } // Verify the input and output layouts. // TODO(philipphack): Consider supporting more general cases. - if (!LayoutUtil::IsMonotonicWithDim0Major(input->shape().layout()) || + if (!LayoutUtil::IsMonotonicWithDim0Major(x.Instr()->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(scale->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(bias->shape().layout()) || !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) { VLOG(1) << "Layer norm input and/or output layouts nor supported."; - return OkStatus(); + return absl::OkStatus(); } // Verify the element types. The types and shapes of the scale and bias // must match. - if (!CompatibleElementType(input) || !CompatibleElementType(instr) || + if (!CompatibleElementType(x.Instr()) || !CompatibleElementType(instr) || !CompatibleElementType(scale) || !CompatibleElementType(bias) || !ShapeUtil::Equal(scale->shape(), bias->shape())) { VLOG(1) << "Layer norm input types or shapes not supported."; - return OkStatus(); + return absl::OkStatus(); } // Verify that the shapes of scale and bias are compatible with the - // operation. + // operation. The adjusted norm dimensions are the dimensions of the + // reduction after removing any degenerate dimensions from the input of + // the reduction. std::vector norm_dims(reduce->dimensions().begin(), reduce->dimensions().end()); - if (norm_dims.size() != scale->shape().dimensions_size()) { + std::vector norm_dims_adjusted = AdjustedDimensions(reduce); + if (norm_dims_adjusted.size() != + ShapeUtil::DropDegenerateDimensions(scale->shape()) + .dimensions_size()) { VLOG(1) << "Layer norm input dimensions not supported."; - return OkStatus(); + return absl::OkStatus(); } for (int i = 0; i < norm_dims.size(); ++i) { - if (input->shape().dimensions(norm_dims[i]) != + if (x.Instr()->shape().dimensions(norm_dims[i]) != scale->shape().dimensions(i)) { VLOG(1) << "Layer norm input dimensions not supported."; - return OkStatus(); + return absl::OkStatus(); } } // Verify the broadcasts of scale and bias. - if (!ShapeUtil::EqualIgnoringElementType(reduce->operand(0)->shape(), - broadcast_scale->shape()) || - !ShapeUtil::EqualIgnoringElementType(reduce->operand(0)->shape(), - broadcast_bias->shape()) || - reduce->dimensions() != broadcast_scale->dimensions() || - reduce->dimensions() != broadcast_bias->dimensions()) { + if (!ShapeUtil::EqualIgnoringElementType( + ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()), + ShapeUtil::DropDegenerateDimensions(broadcast_scale->shape())) || + !ShapeUtil::EqualIgnoringElementType( + ShapeUtil::DropDegenerateDimensions(reduce->operand(0)->shape()), + ShapeUtil::DropDegenerateDimensions(broadcast_bias->shape())) || + norm_dims_adjusted != AdjustedDimensions(broadcast_scale) || + norm_dims_adjusted != AdjustedDimensions(broadcast_bias)) { VLOG(1) << "Layer norm operand broadcast not supported."; - return OkStatus(); + return absl::OkStatus(); } // If necessary, transpose the input so that the dimensions not being // normalized are the leading dimensions. std::vector non_norm_dims; - for (int64_t input_dim = 0; input_dim < input->shape().rank(); - ++input_dim) { - if (std::find(norm_dims.begin(), norm_dims.end(), input_dim) == + for (int64_t x_dim = 0; x_dim < x.Instr()->shape().rank(); ++x_dim) { + if (std::find(norm_dims.begin(), norm_dims.end(), x_dim) == norm_dims.end()) { - non_norm_dims.emplace_back(input_dim); + non_norm_dims.emplace_back(x_dim); } } - std::vector transpose_order = non_norm_dims; - transpose_order.insert(transpose_order.end(), norm_dims.begin(), - norm_dims.end()); + std::vector non_norm_dims_adjusted = + AdjustedDimensions(x.Instr()->shape(), non_norm_dims); + + std::vector x_transpose_order = non_norm_dims; + x_transpose_order.insert(x_transpose_order.end(), norm_dims.begin(), + norm_dims.end()); bool apply_transpose = false; - for (int i = 0; i < transpose_order.size(); ++i) { - if (transpose_order[i] != i) { + for (int i = 0; i < x_transpose_order.size(); ++i) { + if (x_transpose_order[i] != i) { apply_transpose = true; break; } } - std::optional transpose; - std::vector inverse_transpose_order(transpose_order.size()); + std::optional x_transpose; + // The transpose applied to the output is the inverse of the transpose + // applied to the input. + std::vector y_transpose_order(x_transpose_order.size()); if (apply_transpose) { - for (int k = 0; k < transpose_order.size(); ++k) { - inverse_transpose_order[transpose_order[k]] = k; + for (int k = 0; k < x_transpose_order.size(); ++k) { + y_transpose_order[x_transpose_order[k]] = k; } - TF_ASSIGN_OR_RETURN(transpose, - MakeTransposeHlo(input, transpose_order)); + TF_ASSIGN_OR_RETURN(x_transpose, + MakeTransposeHlo(x.Instr(), x_transpose_order)); } // Combine the dimensions not normalized into the first dimension of the // input as required by cuDNN. std::vector reshaped_dims = {1}; for (auto non_norm_dim : non_norm_dims) { - reshaped_dims[0] *= input->shape().dimensions(non_norm_dim); + reshaped_dims[0] *= x.Instr()->shape().dimensions(non_norm_dim); } for (auto norm_dim : norm_dims) { - reshaped_dims.emplace_back(input->shape().dimensions(norm_dim)); + reshaped_dims.emplace_back(x.Instr()->shape().dimensions(norm_dim)); } // cuDNN requires tensors to have at least four dimensions. while (reshaped_dims.size() < 4) { reshaped_dims.emplace_back(1); } - Shape reshaped_shape = - ShapeUtil::MakeShape(input->shape().element_type(), reshaped_dims); + Shape reshaped_shape = ShapeUtil::MakeShape( + x.Instr()->shape().element_type(), reshaped_dims); TF_ASSIGN_OR_RETURN( - HloInstruction * reshape, - MakeReshapeHlo(reshaped_shape, transpose.value_or(input))); + HloInstruction * x_reshape, + MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.Instr()))); // Reshape the scale and bias. std::vector reshaped_scale_dims(reshaped_dims.begin() + 1, @@ -494,13 +1047,16 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } Shape scale_bias_shape = ShapeUtil::MakeShape( scale->shape().element_type(), reshaped_scale_dims); - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_scale, + TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape, MakeReshapeHlo(scale_bias_shape, scale)); - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_bias, + TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape, MakeReshapeHlo(scale_bias_shape, bias)); - - CudnnNormBackendConfig backend_config; - backend_config.set_epsilon(epsilon->literal().GetAsDouble({}).value()); + GpuBackendConfig gpu_backend_config; + CudnnNormBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_epsilon( + epsilon.Instr()->literal().GetAsDouble({}).value()); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_INFER); auto* algorithm = backend_config.mutable_algorithm(); algorithm->set_algo_id(0); algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH); @@ -517,28 +1073,34 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // The output of the Custom Call is a tuple, the second element of which // describes the scratch space. Shape custom_call_shape = ShapeUtil::MakeTupleShape( - {reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})}); + {x_reshape->shape(), ShapeUtil::MakeShape(U8, {workspace_size})}); HloInstruction* custom_call = instr->AddInstruction(HloInstruction::CreateCustomCall( - custom_call_shape, {reshape, reshaped_scale, reshaped_bias}, + custom_call_shape, {x_reshape, scale_reshape, bias_reshape}, kCudnnNormCallTarget)); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); TF_ASSIGN_OR_RETURN(HloInstruction * gte, MakeGetTupleElementHlo(custom_call, 0)); TF_ASSIGN_OR_RETURN( - HloInstruction * inverse_reshape, - MakeReshapeHlo(transpose.value_or(instr)->shape(), gte)); + HloInstruction * y_reshape, + MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte)); - if (!apply_transpose) { - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, inverse_reshape)); - } else { - TF_ASSIGN_OR_RETURN( - HloInstruction * inverse_transpose, - MakeTransposeHlo(inverse_reshape, inverse_transpose_order)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, inverse_transpose)); + std::optional y_transpose; + if (apply_transpose) { + TF_ASSIGN_OR_RETURN(y_transpose, + MakeTransposeHlo(y_reshape, y_transpose_order)); } + TF_RETURN_IF_ERROR( + ReplaceInstruction(instr, y_transpose.value_or(y_reshape))); + + // Store metadata for potential use in the backward graph. + norm_metadata_.insert( + {custom_call, + NormMetadata({x_transpose.value_or(nullptr), + y_transpose.value_or(nullptr), norm_dims_adjusted, + non_norm_dims_adjusted})}); VLOG(1) << "Layer norm rewritten into Custom Call."; @@ -553,33 +1115,43 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } } - return OkStatus(); + return absl::OkStatus(); } // The layer norm training graph separately contains the expectation as well // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance + // epsilon)^-3/2. When identified in the graph, these quantities are fused // into the layer norm Custom Call. - Status MatchNormFactor(HloInstruction* instr, HloInstruction* custom_call, - HloInstruction* variance, HloInstruction* expectation, - HloInstruction* epsilon) { - HloInstruction *variance0, *epsilon0, *gte = custom_call->users()[0]; + absl::Status MatchNormFactor(HloInstruction* instr, + HloInstruction* custom_call, + UniqueHloInstruction& variance, + UniqueHloInstruction& expectation, + UniqueHloInstruction& epsilon) { + HloInstruction* gte = custom_call->users()[0]; if (Match(instr, - m::Divide(m::Op(), AddAnyOrder(m::Op(&variance0), - m::Broadcast(m::ConstantScalar( - &epsilon0)))))) { + m::Divide( + m::Op(), + AddAnyOrder(m::Op().WithPredicate(variance.capture_or_verify), + m::Broadcast(m::ConstantScalar().WithPredicate( + epsilon.capture_or_verify)))))) { // Verify the uniqueness of the operands. - if (variance->unique_id() != variance0->unique_id() || - epsilon->unique_id() != epsilon0->unique_id()) { + if (!variance.Instr() || !epsilon.Instr()) { VLOG(1) << "Layer norm operands not unique."; - return OkStatus(); + return absl::OkStatus(); } // Verify the element types. if (!CompatibleElementType(instr) || - !CompatibleElementType(expectation)) { + !CompatibleElementType(expectation.Instr())) { VLOG(1) << "Layer norm input types not compatible."; - return OkStatus(); + return absl::OkStatus(); + } + + // Retrieve metadata of the forward layer norm. + auto norm_metadata = norm_metadata_.extract(custom_call); + if (!norm_metadata) { + VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call."; + return absl::OkStatus(); } // The shape of the expectation and norm factor return values of the @@ -590,7 +1162,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { {ShapeUtil::ElementsIn(shape), 1, 1, 1}); }; - Shape expectation_shape = make_compatible_shape(expectation->shape()); + Shape expectation_shape = + make_compatible_shape(expectation.Instr()->shape()); Shape norm_factor_shape = make_compatible_shape(instr->shape()); // The augmented Custom Call additionally returns the expectation and the @@ -604,20 +1177,25 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* new_custom_call = instr->AddInstruction( custom_call->CloneWithNewShape(custom_call_shape)); + TF_ASSIGN_OR_RETURN( + GpuBackendConfig gpu_backend_config, + custom_call->backend_config()); + CudnnNormBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN); + // Update the workspace size. TF_ASSIGN_OR_RETURN(const int64_t c_constant, CConstant(cuda_compute_capability_)); const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32; - TF_ASSIGN_OR_RETURN( - CudnnNormBackendConfig backend_config, - custom_call->backend_config()); backend_config.mutable_algorithm()->mutable_workspace_size()->set_value( workspace_size); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR( + new_custom_call->set_backend_config(gpu_backend_config)); auto replace_with_new_cc = [new_custom_call, this]( HloInstruction* old_instr, - int tuple_index) -> Status { + int tuple_index) -> absl::Status { TF_ASSIGN_OR_RETURN( HloInstruction * new_gte, MakeGetTupleElementHlo(new_custom_call, tuple_index)); @@ -644,26 +1222,301 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr)); TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1)); } - return OkStatus(); + return absl::OkStatus(); }; // Replace the result of the original Custom Call as well as the // expectation and the norm factor with the augmented Custom Call. TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0)); - TF_RETURN_IF_ERROR(replace_with_new_cc(expectation, 1)); + TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.Instr(), 1)); TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2)); + // Update the Custom Call associated with the metadata of the forward + // norm. + norm_metadata.key() = new_custom_call; + norm_metadata_.insert(std::move(norm_metadata)); + VLOG(1) << "Expectation and norm factor fused into layer norm Custom Call."; } - return OkStatus(); + + return absl::OkStatus(); + } + + // Matches and rewrites the backward graph of layer norm patterns into Custom + // Calls to cuDNN when the associated forward graph has been rewritten into a + // cuDNN Custom Call. The gradients are + // DX = F1 + F2 - AddReduce(F1 + F2) / nelems, + // Dscale = AddReduce(DY * XCenter * NormFactor), + // Dbias = AddReduce(DY), + // with + // F0 = XCenter * scale * DY, + // F1 = XCenter * F0 * DNormFactor * 2 / nelems, + // F2 = NormFactor * scale * DY, + // XCenter = X - expectation(X), + // NormFactor = (variance(X) + epsilon)^-1/2 and + // DNormFactor = -1/2 * NormFactor^3. + absl::Status MatchLayerNormGradient(HloInstruction* instr) { + UniqueHloInstruction fwd_custom_call, x, x_center, scale, dy, + fused_expectation, fused_norm_factor; + HloInstruction *broadcast, *scalar, *dscale, *dbias, *reduce0, *reduce1, + *reduce2, *reduce3; + if (Match(instr, + AddAddAnyOrder( + m::Broadcast( + &broadcast, + MultiplyAddAnyOrder( + m::Broadcast(m::ConstantScalar(&scalar)), + NegateAddReduce(&reduce0, + F1(&x, &x_center, &fused_expectation, + &fwd_custom_call, &scale, &dy, + &reduce2, norm_metadata_)), + NegateAddReduce( + &reduce1, F2(&fused_norm_factor, &scale, &dy, + &fwd_custom_call, norm_metadata_)))), + F2(&fused_norm_factor, &scale, &dy, &fwd_custom_call, + norm_metadata_), + F1(&x, &x_center, &fused_expectation, &fwd_custom_call, + &scale, &dy, &reduce3, norm_metadata_)))) { + // Skip initial convert, if present. + if (instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kConvert && + CompatibleElementType(instr->users()[0])) { + instr = instr->users()[0]; + } + + // Verify the uniqueness of the captured Custom Call and inputs. + if (!fwd_custom_call.Instr() || !x.Instr() || !dy.Instr() || + !x_center.Instr() || !scale.Instr() || !fused_expectation.Instr() || + !fused_norm_factor.Instr()) { + VLOG(1) << "Layer norm gradient inputs not unique."; + return absl::OkStatus(); + } + + // Retrieve metadata of the forward layer norm. + auto norm_metadata = norm_metadata_.find(fwd_custom_call.Instr()); + if (norm_metadata == norm_metadata_.end()) { + VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call."; + return absl::OkStatus(); + } + + // Verify the dimensions of reductions in the backward graph. + if (AdjustedDimensions(reduce0) != + norm_metadata->second.norm_dims_adjusted || + AdjustedDimensions(reduce1) != + norm_metadata->second.norm_dims_adjusted || + AdjustedDimensions(reduce2) != + norm_metadata->second.norm_dims_adjusted || + AdjustedDimensions(reduce3) != + norm_metadata->second.norm_dims_adjusted) { + VLOG(1) << "Unexpected reductions dimensions in layer norm gradient."; + return absl::OkStatus(); + } + + // The captured scalar must be one over the number of elements in the + // broadcasted dimensions. + float actual_r_nelems = scalar->literal().GetAsDouble({}).value(); + int64_t nelems = 1; + for (int i = 0; i < broadcast->shape().dimensions_size(); ++i) { + if (!c_linear_search(broadcast->dimensions(), i)) { + nelems *= broadcast->shape().dimensions()[i]; + } + } + // The absolute of the difference between the actual scaling factor and + // the reference value must not exceed a prescribed threshold. + float r_nelems = 1. / static_cast(nelems); + float numerical_epsilon = std::numeric_limits::epsilon(); + if (!(abs(actual_r_nelems - r_nelems) < + ((actual_r_nelems + r_nelems) * numerical_epsilon))) { + VLOG(1) + << "Layer norm backward broadcast operand outside expected range."; + return absl::OkStatus(); + } + + // Identify Dscale = AddReduce(DY * XCenter * norm factor) with factor0 + // and factor1 intended to be XCenter and DY or DY and XCenter. + auto find_dscale = + [&fused_norm_factor, &norm_metadata]( + const UniqueHloInstruction& factor0, + const UniqueHloInstruction& factor1) -> HloInstruction* { + for (HloInstruction* factor0_user : factor0.Instr()->users()) { + std::vector users; + SkipUnaryOpsTopDownRecursive(factor0_user, users); + // One of the users of factor0 must be a chained multiplication by the + // fused norm factor and factor1. + for (HloInstruction* user : users) { + if (Match(user, + MultiplyAnyOrder( + m::Op(), MultiplyAnyOrder( + m::Broadcast(BitcastOrReshape(m::Op().Is( + fused_norm_factor.Instr()))), + m::Op().Is(factor1.Instr()))))) { + // Dscale is an addition-reduction of the product. + for (HloInstruction* multiply_user : user->users()) { + if (AppliesAddReduce( + multiply_user, + norm_metadata->second.non_norm_dims_adjusted)) { + return multiply_user; + } + } + } + } + } + return nullptr; + }; + if (!(dscale = find_dscale(x_center, dy)) && + !(dscale = find_dscale(dy, x_center))) { + VLOG(1) << "Unable to identify Dscale in graph."; + return absl::OkStatus(); + } + + // Find Dbias, i.e. an addition-reduction of DY, starting from DY. + // Rewriting proceeds without fusing Dbias if unsuccessful. + dbias = FindAddReduce(dy.Instr(), + norm_metadata->second.non_norm_dims_adjusted); + + // Verify the input and output layouts. + // TODO(philipphack): Consider supporting more general cases. + if (!LayoutUtil::IsMonotonicWithDim0Major(dy.Instr()->shape().layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(dscale->shape().layout()) || + (dbias && + !LayoutUtil::IsMonotonicWithDim0Major(dbias->shape().layout()))) { + VLOG(1) << "Layer norm input and/or output layouts nor supported."; + return absl::OkStatus(); + } + + // The types of X and DX must match. + if (x.Instr()->shape().element_type() != instr->shape().element_type()) { + VLOG(1) << "The types of X and DX must match."; + return absl::OkStatus(); + } + + // The types and shapes of scale, Dscale and Dbias (if present) must + // match. + if (!ShapeUtil::Equal( + ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()), + ShapeUtil::DropDegenerateDimensions(dscale->shape())) || + (dbias && + !ShapeUtil::Equal( + ShapeUtil::DropDegenerateDimensions(scale.Instr()->shape()), + ShapeUtil::DropDegenerateDimensions(dbias->shape())))) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // Verify the element types. + if (!CompatibleElementType(dy.Instr())) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // cuDNN requires the byte size of the element type of X to be at least + // that of DY and scale. + if (ShapeUtil::ByteSizeOfPrimitiveType( + x.Instr()->shape().element_type()) < + ShapeUtil::ByteSizeOfPrimitiveType( + dy.Instr()->shape().element_type()) || + ShapeUtil::ByteSizeOfPrimitiveType( + x.Instr()->shape().element_type()) < + ShapeUtil::ByteSizeOfPrimitiveType( + scale.Instr()->shape().element_type())) { + VLOG(1) << "Backward layer norm types not supported."; + return absl::OkStatus(); + } + + // Transpose DY applying the stored transpose order of X from the forward + // graph. + HloInstruction* transposed_dy = dy.Instr(); + if (norm_metadata->second.x_transpose) { + TF_ASSIGN_OR_RETURN( + transposed_dy, + MakeTransposeHlo(dy.Instr(), + norm_metadata->second.x_transpose->dimensions())); + } + TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy, + MakeReshapeHlo(x.Instr()->shape(), transposed_dy)); + + Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(), + x.Instr()->shape().dimensions()); + + Shape dscale_dbias_shape = ShapeUtil::MakeShape( + dscale->shape().element_type(), scale.Instr()->shape().dimensions()); + + GpuBackendConfig gpu_backend_config; + CudnnNormBackendConfig& backend_config = + *gpu_backend_config.mutable_cudnn_norm_backend_config(); + backend_config.set_kind(CudnnNormBackendConfig::LAYER_BWD); + auto* algorithm = backend_config.mutable_algorithm(); + algorithm->set_algo_id(0); + algorithm->set_math_type(se::dnn::AlgorithmProto::TENSOR_OP_MATH); + algorithm->set_is_cudnn_frontend(true); + + // Set the workspace size to its upper bound. + // TODO(philipphack): Consider autotuning the norm kernels. + TF_ASSIGN_OR_RETURN(const int64_t c_constant, + CConstant(cuda_compute_capability_)); + const int64_t workspace_size = + (2 * c_constant * (4 + 256)) + + (2 * x.Instr()->shape().dimensions(0) * 4) + 64; + algorithm->mutable_workspace_size()->set_value(workspace_size); + + // The output of the Custom Call is a tuple. The output shape of Dscale + // and Dbias is that of scale. + Shape custom_call_shape = ShapeUtil::MakeTupleShape( + {dx_shape, dscale_dbias_shape, dscale_dbias_shape, + ShapeUtil::MakeShape(U8, {workspace_size})}); + + HloInstruction* custom_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_shape, + {x.Instr(), scale.Instr(), reshaped_dy, fused_expectation.Instr(), + fused_norm_factor.Instr()}, + kCudnnNormCallTarget)); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); + + auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this]( + HloInstruction* old_instr, + int tuple_index) -> absl::Status { + TF_ASSIGN_OR_RETURN(HloInstruction * gte, + MakeGetTupleElementHlo(custom_call, tuple_index)); + HloInstruction* new_instr; + // Transpose DX applying the stored transpose order of Y from the + // forward graph. + if (tuple_index == 0 && norm_metadata->second.y_transpose) { + TF_ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(transposed_dy->shape(), gte)); + TF_ASSIGN_OR_RETURN( + new_instr, + MakeTransposeHlo( + new_instr, norm_metadata->second.y_transpose->dimensions())); + } else { + TF_ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(old_instr->shape(), gte)); + } + TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR(replace_with_cc(instr, 0)); + TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1)); + if (dbias) { + TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2)); + } + VLOG(1) << "Gradients w.r.t. x" + << (dbias ? ", scale and bias" : " and scale") + << " rewritten into layer norm backward Custom Call."; + } + + return absl::OkStatus(); } private: se::CudaComputeCapability cuda_compute_capability_; + NormMetadataMap norm_metadata_; }; -StatusOr RunOnComputation( +absl::StatusOr RunOnComputation( HloComputation* computation, se::CudaComputeCapability cuda_compute_capability) { CudnnNormRewriterVisitor visitor(cuda_compute_capability); @@ -677,7 +1530,7 @@ CudnnNormRewriter::CudnnNormRewriter( se::CudaComputeCapability cuda_compute_capability) : cuda_compute_capability_(cuda_compute_capability) {} -StatusOr CudnnNormRewriter::Run( +absl::StatusOr CudnnNormRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/cudnn_norm_rewriter.h b/xla/service/gpu/cudnn_norm_rewriter.h index 6353178aee2c9..7b3ef8d66e15f 100644 --- a/xla/service/gpu/cudnn_norm_rewriter.h +++ b/xla/service/gpu/cudnn_norm_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,23 +16,25 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_ #define XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { // Rewrites norm patterns into Custom Calls to the cuDNN library. Currently, the -// forward pass of layer norm patterns is implemented. +// forward and backward passes of layer norm patterns are implemented. class CudnnNormRewriter : public HloModulePass { public: - explicit CudnnNormRewriter( - const se::CudaComputeCapability cuda_compute_capability); + explicit CudnnNormRewriter(se::CudaComputeCapability cuda_compute_capability); absl::string_view name() const override { return "norm-rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/cudnn_norm_rewriter_test.cc b/xla/service/gpu/cudnn_norm_rewriter_test.cc index e55b37fd33764..f598e46da9acc 100644 --- a/xla/service/gpu/cudnn_norm_rewriter_test.cc +++ b/xla/service/gpu/cudnn_norm_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_norm_rewriter.h" +#include + +#include +#include "xla/error_spec.h" +#include "xla/stream_executor/device_description.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cudnn/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep +#include "third_party/gpus/cudnn/cudnn_version.h" #endif #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/filecheck.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { @@ -102,7 +105,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm2D1) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4], scale: f32[4], bias: f32[4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) @@ -171,7 +174,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[8], bias: f32[8]) -> f32[2,4,6,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1) @@ -190,6 +193,75 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[1,4,6,8] parameter(0) + input_square = f32[1,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[1,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply + r_nelems = f32[] constant(0.125) + r_nelems_bcast = f32[1,4,6] broadcast(r_nelems), dimensions={} + input_square_mean = f32[1,4,6] multiply(input_square_sum, r_nelems_bcast) + input_sum = f32[1,4,6] reduce(input, c0), dimensions={3}, to_apply=apply + input_mean = f32[1,4,6] multiply(input_sum, r_nelems_bcast) + input_mean_square = f32[1,4,6] multiply(input_mean, input_mean) + variance = f32[1,4,6] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[1,4,6] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[1,4,6] add(variance, epsilon_bcast) + norm_factor = f32[1,4,6] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[1,4,6,8] broadcast(norm_factor), dimensions={0,1,2} + input_mean_bcast = f32[1,4,6,8] broadcast(input_mean), dimensions={0,1,2} + input_center = f32[1,4,6,8] subtract(input, input_mean_bcast) + norm = f32[1,4,6,8] multiply(norm_factor_bcast, input_center) + scale = f32[8] parameter(1) + scale_bcast = f32[1,4,6,8] broadcast(scale), dimensions={3} + norm_scale = f32[1,4,6,8] multiply(norm, scale_bcast) + bias = f32[8] parameter(2) + bias_bcast = f32[1,4,6,8] broadcast(bias), dimensions={3} + ROOT out = f32[1,4,6,8] add(norm_scale, bias_bcast) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[1,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[1,4,6,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[24,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK: } +; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[24,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 +; CHECK-NEXT: ROOT [[GTE_BITCAST:%[^ ]+]] = f32[1,4,6,8]{3,2,1,0} bitcast([[GTE]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { #if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; @@ -240,7 +312,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[6], bias: f32[6]) -> f32[2,4,6,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) ; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) @@ -260,6 +332,76 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,1,6,8] parameter(0) + input_square = f32[2,1,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,1,8] reduce(input_square, c0), dimensions={2}, to_apply=apply + r_nelems = f32[] constant(0.166667) + r_nelems_bcast = f32[2,1,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,1,8] multiply(input_square_sum, r_nelems_bcast) + reduce = f32[2,1,8] reduce(input, c0), dimensions={2}, to_apply=apply + input_mean = f32[2,1,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,1,8] multiply(input_mean, input_mean) + variance = f32[2,1,8] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,1,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,1,8] add(variance, epsilon_bcast) + norm_factor = f32[2,1,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,1,6,8] broadcast(norm_factor), dimensions={0,1,3} + input_mean_bcast = f32[2,1,6,8] broadcast(input_mean), dimensions={0,1,3} + input_center = f32[2,1,6,8] subtract(input, input_mean_bcast) + norm = f32[2,1,6,8] multiply(norm_factor_bcast, input_center) + scale = f32[6] parameter(1) + scale_bcast = f32[2,1,6,8] broadcast(scale), dimensions={2} + norm_scale = f32[2,1,6,8] multiply(norm, scale_bcast) + bias = f32[6] parameter(2) + bias_broadcast = f32[2,1,6,8] broadcast(bias), dimensions={2} + ROOT out = f32[2,1,6,8] add(norm_scale, bias_broadcast) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK: } +; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 +; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] + )"; + + TestNorm(hlo_text, optimized_hlo); +} + TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { #if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; @@ -282,9 +424,9 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,8] reduce(multiply3, c0), dimensions={1,2}, to_apply=apply + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply r_nelems = f32[] constant(0.041667) r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) @@ -310,7 +452,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[4,6], bias: f32[4,6]) -> f32[2,4,6,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) ; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) @@ -330,6 +472,76 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,1,8] parameter(0) + input_square = f32[2,4,1,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) + reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply + input_mean = f32[2,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,8] multiply(input_mean, input_mean) + variance = f32[2,8] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast) + norm_factor = f32[2,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3} + input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3} + input_center = f32[2,4,1,8] subtract(input, input_mean_bcast) + norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center) + scale = f32[4,1] parameter(1) + scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2} + norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast) + bias = f32[4,1] parameter(2) + bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2} + ROOT out = f32[2,4,1,8] add(norm_scale, bias_broadcast) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK: } +; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 +; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] + )"; + + TestNorm(hlo_text, optimized_hlo); +} + TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { #if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; @@ -380,7 +592,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,2,2,2], scale: f32[2], bias: f32[2]) -> f32[2,2,2,2] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,2,2,2], {{.*}}: f32[2], {{.*}}: f32[2]) -> f32[2,2,2,2] { ; CHECK-NOT: custom_call_target="__cudnn$norm" )"; @@ -409,9 +621,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { ENTRY test { input = f32[2,4] parameter(0) - multiply3 = f32[2,4] multiply(input, input) + input_square = f32[2,4] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2] reduce(multiply3, c0), dimensions={1}, to_apply=apply + input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply r_nelems = f32[] constant(0.25) r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={} input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast) @@ -439,7 +651,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4], scale: f32[4], bias: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4]) -> (f32[2,4], f32[2], f32[2], f32[2]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) @@ -486,9 +698,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,4,6] reduce(multiply3, c0), dimensions={3}, to_apply=apply + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply r_nelems = f32[] constant(0.125) r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast) @@ -516,7 +728,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[8], bias: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> (f32[2,4,6,8], f32[2,4,6], f32[2,4,6], f32[2,4,6]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1) @@ -563,9 +775,9 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ENTRY test { input = f32[2,4,6,8] parameter(0) - multiply3 = f32[2,4,6,8] multiply(input, input) + input_square = f32[2,4,6,8] multiply(input, input) c0 = f32[] constant(0) - input_square_sum = f32[2,8] reduce(multiply3, c0), dimensions={1,2}, to_apply=apply + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply r_nelems = f32[] constant(0.041667) r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) @@ -593,7 +805,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { const char* optimized_hlo = R"( -; CHECK-LABEL: ENTRY %test (input: f32[2,4,6,8], scale: f32[4,6], bias: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) ; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) @@ -619,6 +831,906 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,1,8] parameter(0) + input_square = f32[2,4,1,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,8] multiply(input_square_sum, r_nelems_bcast) + reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply + input_mean = f32[2,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,8] multiply(input_mean, input_mean) + variance = f32[2,8] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast) + norm_factor = f32[2,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3} + input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3} + input_center = f32[2,4,1,8] subtract(input, input_mean_bcast) + norm = f32[2,4,1,8] multiply(norm_factor_bcast, input_center) + scale = f32[4,1] parameter(1) + scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2} + norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast) + bias = f32[4,1] parameter(2) + bias_broadcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2} + norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_broadcast) + norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon) + ROOT out = (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) tuple(norm_scale_bias, input_mean, norm_factor, norm_factor_cube) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK: } +; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 +; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1 +; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]]) +; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2 +; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]]) +; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4] parameter(0) + input_square = f32[2,4] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2] multiply(input_mean,input_mean) + variance = f32[2] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2] add(variance, epsilon_bcast) + norm_factor = f32[2] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4] broadcast(norm_factor), dimensions={0} + input_mean_bcast = f32[2,4] broadcast(input_mean), dimensions={0} + input_center = f32[2,4] subtract(input, input_mean_bcast) + norm = f32[2,4] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4] broadcast(scale), dimensions={1} + norm_scale = f32[2,4] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4] add(norm_scale, bias_bcast) + doutput = f32[2,4] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0}, to_apply=apply + norm_doutput = f32[2,4] multiply(norm, doutput) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0}, to_apply=apply + scale_doutput = f32[2,4] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4] multiply(input_center, scale_doutput) + f0 = f32[2] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2] broadcast(c1), dimensions={} + dnorm_factor = f32[2] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4] broadcast(f0_dnorm_factor_scaled), dimensions={0} + f1 = f32[2,4] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4] negate(f1) + minus_f1_sum = f32[2] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4] negate(f2) + minus_f2_sum = f32[2] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4] broadcast(minus_f1_f2_sum_scaled), dimensions={0} + f1_f2 = f32[2,4] add(f1, f2) + dinput = f32[2,4] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4], f32[2,4], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> (f32[2,4], f32[2,4], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4]{1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, f32[2,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE0]]) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3) +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} bitcast([[P3]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[2,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[2,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[2,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4]{1,0} bitcast([[GTE3]]) +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply + reduce = f32[2,4,6] reduce(input, c0), dimensions={3}, to_apply=apply + r_nelems = f32[] constant(0.125) + r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,6] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,4,6] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,4,6] multiply(input_mean,input_mean) + variance = f32[2,4,6] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast) + norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[8] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[8] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[8] reduce(doutput, c0), dimensions={0,1,2}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[8] reduce(norm_doutput, c0), dimensions={0,1,2}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,4,6] reduce(input_center_scale_doutput, c0), dimensions={3}, to_apply=apply + norm_factor_cube = f32[2,4,6] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,4,6] broadcast(c1), dimensions={} + dnorm_factor = f32[2,4,6] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,4,6] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.25) + c2_bcast = f32[2,4,6] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,4,6] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,2} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,4,6] reduce(minus_f1, c0), dimensions={3}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,4,6] reduce(minus_f2, c0), dimensions={3}, to_apply=apply + minus_f1_f2_sum = f32[2,4,6] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,4,6] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,2} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[8], f32[8]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[8]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, f32[48,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[GTE0_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE0]]) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} bitcast([[P3]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[48,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[48,8,1,1]{3,2,1,0}, f32[8,1,1,1]{3,2,1,0}, f32[8,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[48,8,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[GTE3_BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[GTE3]]) +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[8,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[8]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[8]{0}, f32[8]{0}) tuple([[GTE0_BITCAST]], [[GTE3_BITCAST]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,8] reduce(input_square, c0), dimensions={2}, to_apply=apply + reduce = f32[2,4,8] reduce(input, c0), dimensions={2}, to_apply=apply + r_nelems = f32[] constant(0.166667) + r_nelems_bcast = f32[2,4,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,4,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,4,8] multiply(input_mean,input_mean) + variance = f32[2,4,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,8] add(variance, epsilon_bcast) + norm_factor = f32[2,4,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[6] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={2} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[6] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={2} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[6] reduce(doutput, c0), dimensions={0,1,3}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[6] reduce(norm_doutput, c0), dimensions={0,1,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,4,8] reduce(input_center_scale_doutput, c0), dimensions={2}, to_apply=apply + norm_factor_cube = f32[2,4,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,4,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,4,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,4,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.333333) + c2_bcast = f32[2,4,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,4,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,1,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,4,8] reduce(minus_f1, c0), dimensions={2}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,4,8] reduce(minus_f2, c0), dimensions={2}, to_apply=apply + minus_f1_f2_sum = f32[2,4,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,4,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,1,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, f32[64,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[6,1,1,1]{3,2,1,0}, f32[6,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[6,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply + reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply + r_nelems = f32[] constant(0.041667) + r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,8] multiply(input_mean,input_mean) + variance = f32[2,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast) + norm_factor = f32[2,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4,6] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1,2} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4,6] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1,2} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,8] parameter(3) + dbias = f32[4,6] reduce(doutput, c0), dimensions={0,3}, to_apply=apply + norm_doutput = f32[2,4,6,8] multiply(norm, doutput) + dscale = f32[4,6] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply + norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.083333) + c2_bcast = f32[2,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply + minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,6]{1,0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) +; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[4,6,1,1]{3,2,1,0}, f32[4,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,1,8] parameter(0) + input_square = f32[2,4,1,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,8] reduce(input_square, c0), dimensions={1,2}, to_apply=apply + reduce = f32[2,8] reduce(input, c0), dimensions={1,2}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,8] multiply(input_mean,input_mean) + variance = f32[2,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,8] add(variance, epsilon_bcast) + norm_factor = f32[2,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,1,8] broadcast(norm_factor), dimensions={0,3} + input_mean_bcast = f32[2,4,1,8] broadcast(input_mean), dimensions={0,3} + input_center = f32[2,4,1,8] subtract(input, input_mean_bcast) + norm = f32[2,4,1,8] multiply(input_center, norm_factor_bcast) + scale = f32[4,1] parameter(1) + scale_bcast = f32[2,4,1,8] broadcast(scale), dimensions={1,2} + norm_scale = f32[2,4,1,8] multiply(norm, scale_bcast) + bias = f32[4,1] parameter(2) + bias_bcast = f32[2,4,1,8] broadcast(bias), dimensions={1,2} + norm_scale_bias = f32[2,4,1,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,1,8] parameter(3) + dbias = f32[4,1] reduce(doutput, c0), dimensions={0,3}, to_apply=apply + norm_doutput = f32[2,4,1,8] multiply(norm, doutput) + dscale = f32[4,1] reduce(norm_doutput, c0), dimensions={0,3}, to_apply=apply + scale_doutput = f32[2,4,1,8] multiply(scale_bcast, doutput) + input_center_scale_doutput = f32[2,4,1,8] multiply(input_center, scale_doutput) + f0 = f32[2,8] reduce(input_center_scale_doutput, c0), dimensions={1,2}, to_apply=apply + norm_factor_cube = f32[2,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,1,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,3} + f1 = f32[2,4,1,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,1,8] negate(f1) + minus_f1_sum = f32[2,8] reduce(minus_f1, c0), dimensions={1,2}, to_apply=apply + f2 = f32[2,4,1,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,1,8] negate(f2) + minus_f2_sum = f32[2,8] reduce(minus_f2, c0), dimensions={1,2}, to_apply=apply + minus_f1_f2_sum = f32[2,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,1,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,3} + f1_f2 = f32[2,4,1,8] add(f1, f2) + dinput = f32[2,4,1,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,1]{1,0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, f32[16,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3) +; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0 +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D1DoutputReshapeSplit) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,6,8] multiply(input_mean,input_mean) + variance = f32[2,6,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast) + norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,48] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0,2}, to_apply=apply + doutput_bitcast = f32[2,4,6,8] reshape(doutput) + norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,6,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2,6,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,48]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,48]{2,1,0} parameter(3) +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0 +; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + +TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D1DoutputReshapeCombine) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f32[2,4,6,8] parameter(0) + input_square = f32[2,4,6,8] multiply(input, input) + c0 = f32[] constant(0) + input_square_sum = f32[2,6,8] reduce(input_square, c0), dimensions={1}, to_apply=apply + reduce = f32[2,6,8] reduce(input, c0), dimensions={1}, to_apply=apply + r_nelems = f32[] constant(0.25) + r_nelems_bcast = f32[2,6,8] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,6,8] multiply(input_square_sum,r_nelems_bcast) + input_mean = f32[2,6,8] multiply(reduce, r_nelems_bcast) + input_mean_square = f32[2,6,8] multiply(input_mean,input_mean) + variance = f32[2,6,8] subtract(input_square_mean,input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,6,8] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,6,8] add(variance, epsilon_bcast) + norm_factor = f32[2,6,8] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,2,3} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,2,3} + input_center = f32[2,4,6,8] subtract(input, input_mean_bcast) + norm = f32[2,4,6,8] multiply(input_center, norm_factor_bcast) + scale = f32[4] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={1} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[4] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={1} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + doutput = f32[2,4,6,2,2,2] parameter(3) + dbias = f32[4] reduce(doutput, c0), dimensions={0,2,3,4,5}, to_apply=apply + doutput_bitcast = f32[2,4,6,8] reshape(doutput) + norm_doutput = f32[2,4,6,8] multiply(norm, doutput_bitcast) + dscale = f32[4] reduce(norm_doutput, c0), dimensions={0,2,3}, to_apply=apply + scale_doutput = f32[2,4,6,8] multiply(scale_bcast, doutput_bitcast) + input_center_scale_doutput = f32[2,4,6,8] multiply(input_center, scale_doutput) + f0 = f32[2,6,8] reduce(input_center_scale_doutput, c0), dimensions={1}, to_apply=apply + norm_factor_cube = f32[2,6,8] divide(norm_factor, variance_plus_epsilon) + c1 = f32[] constant(-0.5) + c1_bcast = f32[2,6,8] broadcast(c1), dimensions={} + dnorm_factor = f32[2,6,8] multiply(norm_factor_cube, c1_bcast) + f0_dnorm_factor = f32[2,6,8] multiply(f0, dnorm_factor) + c2 = f32[] constant(0.5) + c2_bcast = f32[2,6,8] broadcast(c2), dimensions={} + f0_dnorm_factor_scaled = f32[2,6,8] multiply(f0_dnorm_factor, c2_bcast) + f0_dnorm_factor_scaled_bcast = f32[2,4,6,8] broadcast(f0_dnorm_factor_scaled), dimensions={0,2,3} + f1 = f32[2,4,6,8] multiply(input_center, f0_dnorm_factor_scaled_bcast) + minus_f1 = f32[2,4,6,8] negate(f1) + minus_f1_sum = f32[2,6,8] reduce(minus_f1, c0), dimensions={1}, to_apply=apply + f2 = f32[2,4,6,8] multiply(norm_factor_bcast, scale_doutput) + minus_f2 = f32[2,4,6,8] negate(f2) + minus_f2_sum = f32[2,6,8] reduce(minus_f2, c0), dimensions={1}, to_apply=apply + minus_f1_f2_sum = f32[2,6,8] add(minus_f1_sum, minus_f2_sum) + minus_f1_f2_sum_scaled = f32[2,6,8] multiply(minus_f1_f2_sum, r_nelems_bcast) + minus_f1_f2_sum_scaled_bcast = f32[2,4,6,8] broadcast(minus_f1_f2_sum_scaled), dimensions={0,2,3} + f1_f2 = f32[2,4,6,8] add(f1, f2) + dinput = f32[2,4,6,8] add(f1_f2, minus_f1_f2_sum_scaled_bcast) + ROOT out = (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) tuple(norm_scale_bias, dinput, dscale, dbias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4], {{.*}}: f32[4], {{.*}}: f32[2,4,6,2,2,2]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4], f32[4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} transpose([[P0]]), dimensions={0,2,3,1} +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC0:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, f32[96,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" +; CHECK: } +; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,2,2,2]{5,4,3,2,1,0} parameter(3) +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,6,8,4]{3,2,1,0} fusion([[P3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[FUSION0_BITCAST:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} bitcast([[FUSION0]]) +; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 +; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[96,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 +; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[96,4,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, f32[4,1,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION0_BITCAST]], [[GTE1]], [[GTE2]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0 +; CHECK-DAG: "kind":"LAYER_BWD" +; CHECK: } +; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[96,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 +; CHECK-DAG: [[FUSION1:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] +; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=0 +; CHECK-DAG: [[GTEF2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION1]]), index=1 +; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 +; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE4]]) +; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[4,1,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 +; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4]{0} bitcast([[GTE5]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4]{0}, f32[4]{0}) tuple([[GTEF1]], [[GTEF2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.cc b/xla/service/gpu/cudnn_pad_for_convolutions.cc index 71a375e920c6b..866324cad8586 100644 --- a/xla/service/gpu/cudnn_pad_for_convolutions.cc +++ b/xla/service/gpu/cudnn_pad_for_convolutions.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,34 @@ limitations under the License. #include "xla/service/gpu/cudnn_pad_for_convolutions.h" +#include +#include +#include +#include +#include #include +#include +#include "absl/container/flat_hash_set.h" #include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_support_utils.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/dnn.h" #include "xla/util.h" -#include "xla/window_util.h" -#include "tsl/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -68,9 +83,9 @@ static HloInstruction* PadInstruction(HloInstruction* instr, } // Modifies the given convolution to have the given input and result shapes. -static Status PadConv(HloCustomCallInstruction* conv, - absl::Span new_input_shapes, - const Shape& new_result_shape) { +static absl::Status PadConv(HloCustomCallInstruction* conv, + absl::Span new_input_shapes, + const Shape& new_result_shape) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " "before CudnnConvAlgorithmPicker."; @@ -147,11 +162,11 @@ static std::vector GetRelevantConvs( // new_input_shapes. Notice that new_input_shapes is a vector for multiple // input tensors. This function shall return true if padding is necessary or // false otherwise in addition to status. -static StatusOr ResolveAndPad( +static absl::StatusOr ResolveAndPad( HloCustomCallInstruction* conv, - std::function(HloCustomCallInstruction* conv, - std::vector* new_input_shapes, - Shape* new_result_shape)> + std::function(HloCustomCallInstruction* conv, + std::vector* new_input_shapes, + Shape* new_result_shape)> resolve_pad_shapes) { std::vector new_input_shapes; Shape new_result_shape; @@ -177,7 +192,7 @@ static StatusOr ResolveAndPad( // Don't run this pass on GPUs without tensor cores -- it will make them slower! // // TODO(jlebar): Also pad dots. -static StatusOr TryResolvePaddedShapesForTensorCore( +static absl::StatusOr TryResolvePaddedShapesForTensorCore( HloCustomCallInstruction* conv, std::vector* new_input_shapes_ptr, Shape* new_result_shape_ptr) { TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); @@ -299,7 +314,7 @@ static StatusOr TryResolvePaddedShapesForTensorCore( // Adds padding to cudnn integer convolutions to make input and output feature // maps multiples of pad_to (usually 4 or 32). -StatusOr TryResolvePaddedShapesForIntegerConvolution( +absl::StatusOr TryResolvePaddedShapesForIntegerConvolution( int pad_to, const se::CudaComputeCapability& compute_capability, HloCustomCallInstruction* conv, std::vector* new_input_shapes_ptr, Shape* new_result_shape_ptr) { @@ -373,7 +388,7 @@ StatusOr TryResolvePaddedShapesForIntegerConvolution( case CudnnConvKind::kForward: CHECK_EQ(new_input_shapes.size(), 2); // Input feature maps - pad_dim(&new_input_shapes[0], dnums.input_feature_dimension(), + pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(), input_vect_size); // Kernel for the input feature maps pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(), @@ -389,7 +404,7 @@ StatusOr TryResolvePaddedShapesForIntegerConvolution( case CudnnConvKind::kForwardActivation: CHECK(new_input_shapes.size() == 3 || new_input_shapes.size() == 4); // Input feature maps - pad_dim(&new_input_shapes[0], dnums.input_feature_dimension(), + pad_dim(new_input_shapes.data(), dnums.input_feature_dimension(), input_vect_size); // Kernel for the input feature maps pad_dim(&new_input_shapes[1], dnums.kernel_input_feature_dimension(), @@ -471,7 +486,7 @@ StatusOr TryResolvePaddedShapesForIntegerConvolution( return changed; } -StatusOr CudnnPadForConvolutions::Run( +absl::StatusOr CudnnPadForConvolutions::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/cudnn_pad_for_convolutions.h b/xla/service/gpu/cudnn_pad_for_convolutions.h index dc403112632ee..be7fae26d6cd0 100644 --- a/xla/service/gpu/cudnn_pad_for_convolutions.h +++ b/xla/service/gpu/cudnn_pad_for_convolutions.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ #define XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/literal_util.h" -#include "xla/service/gpu/ir_emission_utils.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" -#include "xla/window_util.h" namespace xla { namespace gpu { @@ -39,7 +39,7 @@ class CudnnPadForConvolutions : public HloModulePass { } // Run PadForConvolutions on the given module and return if any change is made using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/cudnn_pad_for_convolutions_test.cc b/xla/service/gpu/cudnn_pad_for_convolutions_test.cc index d81d98a598780..2bae239358182 100644 --- a/xla/service/gpu/cudnn_pad_for_convolutions_test.cc +++ b/xla/service/gpu/cudnn_pad_for_convolutions_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,13 @@ limitations under the License. #include "xla/service/gpu/cudnn_pad_for_convolutions.h" +#include +#include #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" -#include "xla/util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/cudnn_simplify_padding.cc b/xla/service/gpu/cudnn_simplify_padding.cc index 5c8405b958a1e..c8f87f7103e3a 100644 --- a/xla/service/gpu/cudnn_simplify_padding.cc +++ b/xla/service/gpu/cudnn_simplify_padding.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,17 +16,29 @@ limitations under the License. #include "xla/service/gpu/cudnn_simplify_padding.h" #include -#include +#include #include #include -#include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/literal.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { @@ -96,8 +108,9 @@ std::optional NumTrailingZeroOutputFeatures(HloInstruction* conv) { // If the filter is reordered for an int8x32 NCHW_VECT_C convolution, find the // original, un-reordered filter and check *it* for trailing zero output // features. - auto backend_config = conv->backend_config(); - if (backend_config.ok() && backend_config->reordered_int8_nchw_vect()) { + auto backend_config = conv->backend_config(); + if (backend_config.ok() && + backend_config->cudnn_conv_backend_config().reordered_int8_nchw_vect()) { VLOG(2) << "Matched int8x32 convolution with filter reordering"; // Try to set weights to the original, un-reordered value. @@ -269,7 +282,7 @@ std::optional NumTrailingZeroOutputFeatures(HloInstruction* conv) { return std::nullopt; } -StatusOr TrySimplifyPadding(HloInstruction* instr) { +absl::StatusOr TrySimplifyPadding(HloInstruction* instr) { // Match one of the following patterns. // conv -> slice -> pad // conv -> reshape -> slice-> pad @@ -452,7 +465,7 @@ StatusOr TrySimplifyPadding(HloInstruction* instr) { } // anonymous namespace -StatusOr CudnnSimplifyPadding::Run( +absl::StatusOr CudnnSimplifyPadding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/cudnn_simplify_padding.h b/xla/service/gpu/cudnn_simplify_padding.h index 67b5d37316ec7..5811d26144c4f 100644 --- a/xla/service/gpu/cudnn_simplify_padding.h +++ b/xla/service/gpu/cudnn_simplify_padding.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_ #define XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla::gpu { @@ -54,7 +57,7 @@ class CudnnSimplifyPadding : public HloModulePass { absl::string_view name() const override { return "cudnn_simplify_padding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/cudnn_simplify_padding_test.cc b/xla/service/gpu/cudnn_simplify_padding_test.cc index a722b4d93f995..4cd9b72ef8ea6 100644 --- a/xla/service/gpu/cudnn_simplify_padding_test.cc +++ b/xla/service/gpu/cudnn_simplify_padding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,18 @@ limitations under the License. #include "xla/service/gpu/cudnn_simplify_padding.h" +#include #include #include +#include +#include +#include "absl/functional/function_ref.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/literal.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" -#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_pad_for_convolutions.h" #include "xla/service/gpu/cudnn_vectorize_convolutions.h" #include "xla/service/hlo_pass_fix.h" @@ -29,10 +35,13 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/reshape_mover.h" #include "xla/service/tuple_simplifier.h" -#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -45,8 +54,8 @@ class CudnnSimplifyPaddingTest : public HloTestBase { // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions. // This lets us test that we're matching the patterns that actually get // generated by padding+vectorization. - StatusOr RunEndToEnd(std::pair compute_capability, - HloModule* module) { + absl::StatusOr RunEndToEnd(std::pair compute_capability, + HloModule* module) { se::CudaComputeCapability cc{compute_capability.first, compute_capability.second}; @@ -82,7 +91,7 @@ class CudnnSimplifyPaddingTest : public HloTestBase { return changed; } - StatusOr RunJustThisPass(HloModule* module) { + absl::StatusOr RunJustThisPass(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(CudnnSimplifyPadding(), module)); VLOG(1) << "after simplify_padding:\n" << module->ToString(); diff --git a/xla/service/gpu/cudnn_support_utils.cc b/xla/service/gpu/cudnn_support_utils.cc index 5188265af7076..864943884a56f 100644 --- a/xla/service/gpu/cudnn_support_utils.cc +++ b/xla/service/gpu/cudnn_support_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,24 @@ limitations under the License. #include "xla/service/gpu/cudnn_support_utils.h" -#include +#include #include +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "xla/window_util.h" -#include "tsl/platform/status.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -StatusOr CudnnSupportsOptimizedIntegerConvolution( +absl::StatusOr CudnnSupportsOptimizedIntegerConvolution( const se::CudaComputeCapability& compute_capability, HloCustomCallInstruction& conv, int vector_size) { TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(&conv)); @@ -119,12 +124,13 @@ StatusOr CudnnSupportsOptimizedIntegerConvolution( return true; } -StatusOr CudnnInferTransposeForFilterReordering( +absl::StatusOr +CudnnInferTransposeForFilterReordering( const Shape& shape, const ConvolutionDimensionNumbers& dimension_numbers) { // A normal filter should have four dimensions: [O, I, H, W] // An already vectorized filter will have five: [O, I/k, H, W, k]; k=4|32 if (shape.rank() != 4 && shape.rank() != 5) { - return InternalError("Filter shape has unexpected rank."); + return Internal("Filter shape has unexpected rank."); } // Get convolution dimension numbers. @@ -142,7 +148,7 @@ StatusOr CudnnInferTransposeForFilterReordering( if (shape.dimensions(dO) % 32 != 0 || shape.dimensions(dI) % (32 / vsize) != 0 || (revectorize && vsize != 4 && vsize != 32)) { - return InternalError("Filter shape is not vectorizable."); + return Internal("Filter shape is not vectorizable."); } // Build the resulting shape: [O, I/32, H, W, 32] @@ -187,14 +193,14 @@ StatusOr CudnnInferTransposeForFilterReordering( return CudnnReorderTransposeConfig{split_shape, output_shape, permutation}; } -StatusOr CudnnInferTransposeForBiasReordering( - const Shape& shape) { +absl::StatusOr +CudnnInferTransposeForBiasReordering(const Shape& shape) { // Expected bias has one dimension: [O] if (shape.rank() != 1) { - return InternalError("Bias shape has unexpected rank."); + return Internal("Bias shape has unexpected rank."); } if (shape.dimensions(0) % 32 != 0) { - return InternalError("Bias shape is not vectorizable."); + return Internal("Bias shape is not vectorizable."); } // Build the transposable shape: [O/32, 4, 2, 4] diff --git a/xla/service/gpu/cudnn_support_utils.h b/xla/service/gpu/cudnn_support_utils.h index 0ed4a88bdd21f..780e0593f9b36 100644 --- a/xla/service/gpu/cudnn_support_utils.h +++ b/xla/service/gpu/cudnn_support_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_SUPPORT_UTILS_H_ #define XLA_SERVICE_GPU_CUDNN_SUPPORT_UTILS_H_ +#include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" -#include "tsl/platform/status.h" namespace xla { namespace gpu { @@ -31,7 +32,7 @@ namespace gpu { // // This function does not guarantee that a convolution will be padded and/or // vectorized. It only checks that it is a valid candiate for such optimization. -StatusOr CudnnSupportsOptimizedIntegerConvolution( +absl::StatusOr CudnnSupportsOptimizedIntegerConvolution( const se::CudaComputeCapability& compute_capability, HloCustomCallInstruction& conv, int vector_size); @@ -59,14 +60,15 @@ struct CudnnReorderTransposeConfig { // Create a transposition for an int8x32 convolution filter that effectively // does the same thing as cudnnReorderFilterAndBias, but could also be constant // folded or fused. -StatusOr CudnnInferTransposeForFilterReordering( +absl::StatusOr +CudnnInferTransposeForFilterReordering( const Shape& shape, const ConvolutionDimensionNumbers& dimension_numbers); // Create a transposition for an int8x32 convolution bias that effectively // does the same thing as cudnnReorderFilterAndBias, but could also be constant // folded or fused. -StatusOr CudnnInferTransposeForBiasReordering( - const Shape& shape); +absl::StatusOr +CudnnInferTransposeForBiasReordering(const Shape& shape); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cudnn_support_utils_test.cc b/xla/service/gpu/cudnn_support_utils_test.cc index bcfe30a4aec38..0cc170fb1a32f 100644 --- a/xla/service/gpu/cudnn_support_utils_test.cc +++ b/xla/service/gpu/cudnn_support_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,29 +16,31 @@ limitations under the License. #include "xla/service/gpu/cudnn_support_utils.h" #include +#include +#include #include #include #include #include -#include "absl/status/status.h" +#include #include "absl/strings/string_view.h" -#include "xla/hlo/ir/dynamic_parameter_binding.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/hlo_parser.h" -#include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/status_macros.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/dnn.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "tsl/platform/logging.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -50,7 +52,7 @@ class CudnnSupportUtilsTest : public HloTestBase { public: // Gets the custom call with `target` from the `module`. Expects that there is // one and only one matching call. - StatusOr GetCustomCall( + absl::StatusOr GetCustomCall( xla::VerifiedHloModule* module, absl::string_view target) { HloCustomCallInstruction* call = nullptr; for (HloComputation* comp : module->MakeNonfusionComputations()) { diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.cc b/xla/service/gpu/cudnn_vectorize_convolutions.cc index 74b01b900e9bd..9ab4630f2c2e7 100644 --- a/xla/service/gpu/cudnn_vectorize_convolutions.cc +++ b/xla/service/gpu/cudnn_vectorize_convolutions.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,41 @@ limitations under the License. #include "xla/service/gpu/cudnn_vectorize_convolutions.h" +#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -79,7 +101,7 @@ static std::vector GetRelevantConvs( // `sibling_computation`. // // Yes, we serialize/deserialize as a proto. :) -static StatusOr BuilderToHloComputation( +static absl::StatusOr BuilderToHloComputation( XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) { TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); @@ -260,7 +282,8 @@ static ConvolutionDimensionNumbers VectorizeDnums( // Reorders the convolution's filter and bias (if present) according to // cudnnReorderFilterAndBias. Also marks that the filter + bias are reordered // in the conv's backend-config. -Status ReorderInt8NchwVect(HloCustomCallInstruction* conv, XlaOp* operands) { +absl::Status ReorderInt8NchwVect(HloCustomCallInstruction* conv, + XlaOp* operands) { bool has_bias = conv->operand_count() > 2; VLOG(1) << "Reordering filter" << (has_bias ? " and bias" : "") << " (replacement for cudnnReorderFilterAndBias)"; @@ -269,10 +292,12 @@ Status ReorderInt8NchwVect(HloCustomCallInstruction* conv, XlaOp* operands) { ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); // Update convolution backend config. - TF_ASSIGN_OR_RETURN(auto config, - conv->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); + CudnnConvBackendConfig& config = + *gpu_config.mutable_cudnn_conv_backend_config(); config.set_reordered_int8_nchw_vect(true); - TF_RETURN_IF_ERROR(conv->set_backend_config(config)); + TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); // Reorder the filter. TF_ASSIGN_OR_RETURN(Shape filter_shape, builder->GetShape(operands[1])); @@ -299,7 +324,7 @@ Status ReorderInt8NchwVect(HloCustomCallInstruction* conv, XlaOp* operands) { transpose = Transpose(reshape, reorder.permutation); operands[2] = Reshape(reorder.result_shape, transpose); } - return OkStatus(); + return absl::OkStatus(); } // Tries to vectorize an already-vectorized convolution. @@ -310,7 +335,7 @@ Status ReorderInt8NchwVect(HloCustomCallInstruction* conv, XlaOp* operands) { // // (The dimensions can appear in any order; which is N/C/etc is determined by // the convolutions' dnums.) -static StatusOr TryRevectorizeConv( +static absl::StatusOr TryRevectorizeConv( const se::CudaComputeCapability& compute_capability, const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv, int vect_size) { @@ -471,7 +496,7 @@ static StatusOr TryRevectorizeConv( // // This requires that C be a multiple of vect_size. CudnnPadForConvolutions can // add padding to make this true. -static StatusOr TryVectorizeConv( +static absl::StatusOr TryVectorizeConv( const se::CudaComputeCapability& compute_capability, const se::dnn::VersionInfo& cudnn_version, HloCustomCallInstruction* conv, int64_t vect_size) { @@ -591,7 +616,7 @@ static StatusOr TryVectorizeConv( } // namespace -StatusOr CudnnVectorizeConvolutions::Run( +absl::StatusOr CudnnVectorizeConvolutions::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/cudnn_vectorize_convolutions.h b/xla/service/gpu/cudnn_vectorize_convolutions.h index 9721125cf2375..43165f24c2562 100644 --- a/xla/service/gpu/cudnn_vectorize_convolutions.h +++ b/xla/service/gpu/cudnn_vectorize_convolutions.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_ #define XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_ -#include - +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" namespace xla { @@ -56,7 +58,7 @@ class CudnnVectorizeConvolutions : public HloModulePass { return "cudnn_vectorize_convolutions"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/cudnn_vectorize_convolutions_test.cc b/xla/service/gpu/cudnn_vectorize_convolutions_test.cc index f806ec0309bd2..d448621fbc6bc 100644 --- a/xla/service/gpu/cudnn_vectorize_convolutions_test.cc +++ b/xla/service/gpu/cudnn_vectorize_convolutions_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,26 @@ limitations under the License. #include "xla/service/gpu/cudnn_vectorize_convolutions.h" +#include +#include #include +#include +#include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "xla/service/call_inliner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -36,8 +46,8 @@ namespace m = ::xla::match; class CudnnVectorizeConvolutionsTest : public HloTestBase { protected: // Runs this pass and some cleanup to make pattern-matching easier. - StatusOr Run(std::pair compute_capability, - HloModule* module) { + absl::StatusOr Run(std::pair compute_capability, + HloModule* module) { CudnnVectorizeConvolutions pass( se::CudaComputeCapability{compute_capability.first, compute_capability.second}, @@ -308,8 +318,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) { .WithShape(S8, {10, 20, 30, 4, 32})), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) { @@ -360,8 +371,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) { .WithShape(S8, {10, 20, 30, 4, 32})), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) { @@ -412,8 +424,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, InputNHWC_OutputNCHW) { .WithShape(S8, {10, 4, 32, 20, 30})), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) { @@ -447,8 +460,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) { .WithShape(S8, {10, 20, 30, 32, 4})), m::Op()))); - EXPECT_FALSE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_FALSE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) { @@ -504,8 +518,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) { .WithShape(S8, {10, 20, 30, 48, 4}), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) { @@ -561,8 +576,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) { .WithShape(S8, {10, 32, 20, 30, 4}), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) { @@ -618,8 +634,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) { .WithShape(S8, {4, 10, 20, 30, 48}), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) { @@ -687,8 +704,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, Vectorize16To32) { .WithShape(S8, {10, 20, 30, 6, 2, 16})) .WithShape(S8, {10, 20, 30, 12, 16}), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) { @@ -731,8 +749,9 @@ TEST_F(CudnnVectorizeConvolutionsTest, VectorizeMixedTo32) { .WithShape(S8, {10, 20, 30, 6, 16, 2})) .WithShape(S8, {10, 20, 30, 96, 2}), m::Op()))); - EXPECT_TRUE(conv->backend_config() - ->reordered_int8_nchw_vect()); + EXPECT_TRUE(conv->backend_config() + ->cudnn_conv_backend_config() + .reordered_int8_nchw_vect()); } } // namespace diff --git a/xla/service/gpu/cusolver_context.cc b/xla/service/gpu/cusolver_context.cc index 282bfa2953af4..8f9642ba4ea05 100644 --- a/xla/service/gpu/cusolver_context.cc +++ b/xla/service/gpu/cusolver_context.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,25 @@ limitations under the License. #include #include -#include - +#include +#include + +#include "absl/status/status.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" +#include "third_party/gpus/cuda/include/library_types.h" +#endif #include "xla/primitive_util.h" +#include "xla/status.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" namespace xla { namespace gpu { @@ -37,12 +51,12 @@ struct GpuComplexT { // For ROCm, use hipsolver if the ROCm version >= 4.5 and // rocblas/rocsolver if the ROCm version < 4.5. -#if !TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA #define GPU_SOLVER_CONTEXT_PREFIX cusolverDn #define GPU_SOLVER_PREFIX cusolverDn -using gpuStream_t = cudaStream_t; +using gpuDataType_t = cudaDataType_t; template <> struct GpuComplexT> { @@ -62,9 +76,9 @@ struct GpuComplexT*> { typedef cuDoubleComplex* type; }; -#else +#elif TENSORFLOW_USE_ROCM -using gpuStream_t = hipStream_t; +using gpuDataType_t = hipDataType; #if TF_ROCM_VERSION >= 40500 #define GPU_SOLVER_CONTEXT_PREFIX se::wrap::hipsolver @@ -110,14 +124,14 @@ struct GpuComplexT*> { }; #endif // TF_ROCM_VERSION >= 40500 -#endif // !TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_USE_ROCM template inline typename GpuComplexT::type* ToDevicePointer(se::DeviceMemory p) { return static_cast::type*>(p.opaque()); } -#if !TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { case se::blas::UpperLower::kUpper: @@ -129,11 +143,11 @@ cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { } } -// Converts a cuSolver status to a Status. -Status ConvertStatus(cusolverStatus_t status) { +// Converts a cuSolver absl::Status to a Status. +absl::Status ConvertStatus(cusolverStatus_t status) { switch (status) { case CUSOLVER_STATUS_SUCCESS: - return OkStatus(); + return absl::OkStatus(); case CUSOLVER_STATUS_NOT_INITIALIZED: return FailedPrecondition("cuSolver has not been initialized"); case CUSOLVER_STATUS_ALLOC_FAILED: @@ -160,7 +174,8 @@ Status ConvertStatus(cusolverStatus_t status) { return Unknown("Unknown cuSolver error"); } } -#else +#elif TENSORFLOW_USE_ROCM + #if TF_ROCM_VERSION >= 40500 hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { @@ -169,14 +184,14 @@ hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) { case se::blas::UpperLower::kLower: return HIPSOLVER_FILL_MODE_LOWER; default: - LOG(FATAL) << "Invalid value of blas::UpperLower."; + LOG(FATAL) << "Invalid value of blas::UpperLower"; } } -Status ConvertStatus(hipsolverStatus_t status) { +absl::Status ConvertStatus(hipsolverStatus_t status) { switch (status) { case HIPSOLVER_STATUS_SUCCESS: - return OkStatus(); + return absl::OkStatus(); case HIPSOLVER_STATUS_NOT_INITIALIZED: return FailedPrecondition("hipsolver has not been initialized"); case HIPSOLVER_STATUS_ALLOC_FAILED: @@ -203,7 +218,7 @@ Status ConvertStatus(hipsolverStatus_t status) { return Unknown("Unknown hipsolver error"); } } -#else +#else // TF_ROCM_VERSION < 40500 rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) { switch (uplo) { case se::blas::UpperLower::kUpper: @@ -215,10 +230,10 @@ rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) { } } -Status ConvertStatus(rocblas_status status) { +absl::Status ConvertStatus(rocblas_status status) { switch (status) { case rocblas_status_success: - return OkStatus(); + return absl::OkStatus(); case rocblas_status_invalid_handle: return FailedPrecondition("handle not initialized, invalid or null"); case rocblas_status_not_implemented: @@ -271,6 +286,8 @@ Status ConvertStatus(rocblas_status status) { GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf_bufferSize) #define GpuSolverZpotrf_bufferSize \ GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf_bufferSize) +#define GpuSolverDnXpotrf_bufferSize \ + GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Xpotrf_bufferSize) #if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER #define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf) #define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf) @@ -280,6 +297,7 @@ Status ConvertStatus(rocblas_status status) { #define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, DpotrfBatched) #define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, CpotrfBatched) #define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, ZpotrfBatched) +#define GpuSolverXpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Xpotrf) #else // TENSORFLOW_USE_ROCSOLVER #define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf) #define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf) @@ -293,13 +311,13 @@ Status ConvertStatus(rocblas_status status) { } // namespace -StatusOr GpuSolverContext::Create() { +absl::StatusOr GpuSolverContext::Create() { gpusolverHandle_t handle; TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCreate(&handle))); return GpuSolverContext(handle); } -Status GpuSolverContext::SetStream(se::Stream* stream) { +absl::Status GpuSolverContext::SetStream(se::Stream* stream) { return ConvertStatus( GpuSolverSetStream(handle_.get(), se::gpu::AsGpuStreamValue(stream))); } @@ -309,7 +327,7 @@ GpuSolverContext::GpuSolverContext(gpusolverHandle_t handle) void GpuSolverContext::Deleter::operator()(gpusolverHandle_t handle) { if (handle) { - Status status = ConvertStatus(GpuSolverDestroy(handle)); + absl::Status status = ConvertStatus(GpuSolverDestroy(handle)); if (!status.ok()) { LOG(ERROR) << "GpuSolverDestroy failed: " << status; } @@ -319,42 +337,76 @@ void GpuSolverContext::Deleter::operator()(gpusolverHandle_t handle) { // Note: NVidia have promised that it is safe to pass 'nullptr' as the argument // buffers to cuSolver buffer size methods and this will be a documented // behavior in a future cuSolver release. -StatusOr GpuSolverContext::PotrfBufferSize(PrimitiveType type, - se::blas::UpperLower uplo, - int n, int lda, - int batch_size) { -#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER +absl::StatusOr GpuSolverContext::PotrfBufferSize( + PrimitiveType type, se::blas::UpperLower uplo, int n, int lda, + int batch_size) { int size = -1; + auto gpu_uplo = GpuBlasUpperLower(uplo); +#if GOOGLE_CUDA + size_t d_lwork = 0; /* size of workspace */ + size_t h_lwork = 0; /* size of workspace */ + + gpuDataType_t cuda_data_type; switch (type) { case F32: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverSpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_R_32F; break; } case F64: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverDpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_R_64F; break; } case C64: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverCpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_C_32F; break; } case C128: { - TF_RETURN_IF_ERROR(ConvertStatus( - GpuSolverZpotrf_bufferSize(handle_.get(), GpuBlasUpperLower(uplo), n, - /*A=*/nullptr, lda, &size))); + cuda_data_type = CUDA_C_64F; break; } default: return InvalidArgument("Invalid type for cholesky decomposition: %s", PrimitiveType_Name(type)); } - // CUDA's potrfBatched needs space for the `as` array, which contains + TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverDnXpotrf_bufferSize( + handle_.get(), nullptr, gpu_uplo, n, cuda_data_type, nullptr, lda, + cuda_data_type, &d_lwork, &h_lwork))); + size = static_cast(d_lwork); + +#elif TENSORFLOW_USE_HIPSOLVER + switch (type) { + case F32: { + TF_RETURN_IF_ERROR( + ConvertStatus(GpuSolverSpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + case F64: { + TF_RETURN_IF_ERROR( + ConvertStatus(GpuSolverDpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + case C64: { + TF_RETURN_IF_ERROR( + ConvertStatus(GpuSolverCpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + case C128: { + TF_RETURN_IF_ERROR( + ConvertStatus(GpuSolverZpotrf_bufferSize(handle_.get(), gpu_uplo, n, + /*A=*/nullptr, lda, &size))); + break; + } + default: + return InvalidArgument("Invalid type for cholesky decomposition: %s", + PrimitiveType_Name(type)); + } +#endif // TENSORFLOW_USE_HIPSOLVER + +#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER + // CUDA/HIP's potrfBatched needs space for the `as` array, which contains // batch_size pointers. Divide by sizeof(type) because this function returns // not bytes but a number of elements of `type`. int64_t potrf_batched_scratch = CeilOfRatio( @@ -366,10 +418,11 @@ StatusOr GpuSolverContext::PotrfBufferSize(PrimitiveType type, #endif } -Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory as, int lda, - se::DeviceMemory lapack_info, - int batch_size) { +absl::Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory as, + int lda, + se::DeviceMemory lapack_info, + int batch_size) { return ConvertStatus(GpuSolverSpotrfBatched( handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda, #if TENSORFLOW_USE_HIPSOLVER @@ -378,10 +431,11 @@ Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, ToDevicePointer(lapack_info), batch_size)); } -Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory as, int lda, - se::DeviceMemory lapack_info, - int batch_size) { +absl::Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory as, + int lda, + se::DeviceMemory lapack_info, + int batch_size) { return ConvertStatus(GpuSolverDpotrfBatched( handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda, #if TENSORFLOW_USE_HIPSOLVER @@ -390,11 +444,9 @@ Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, ToDevicePointer(lapack_info), batch_size)); } -Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory*> as, - int lda, - se::DeviceMemory lapack_info, - int batch_size) { +absl::Status GpuSolverContext::PotrfBatched( + se::blas::UpperLower uplo, int n, se::DeviceMemory*> as, + int lda, se::DeviceMemory lapack_info, int batch_size) { return ConvertStatus(GpuSolverCpotrfBatched( handle_.get(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda, #if TENSORFLOW_USE_HIPSOLVER @@ -403,7 +455,7 @@ Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n, ToDevicePointer(lapack_info), batch_size)); } -Status GpuSolverContext::PotrfBatched( +absl::Status GpuSolverContext::PotrfBatched( se::blas::UpperLower uplo, int n, se::DeviceMemory*> as, int lda, se::DeviceMemory lapack_info, int batch_size) { @@ -415,5 +467,87 @@ Status GpuSolverContext::PotrfBatched( ToDevicePointer(lapack_info), batch_size)); } +#if GOOGLE_CUDA +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_R_64F, + ToDevicePointer(a), lda, CUDA_R_64F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_R_32F, + ToDevicePointer(a), lda, CUDA_R_32F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_C_32F, + ToDevicePointer(a), lda, CUDA_C_32F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + absl::Status status = ConvertStatus(GpuSolverXpotrf( + handle_.get(), nullptr, GpuBlasUpperLower(uplo), n, CUDA_C_64F, + ToDevicePointer(a), lda, CUDA_C_64F, ToDevicePointer(workspace), + workspace.ElementCount(), nullptr, 0, ToDevicePointer(lapack_info))); + return status; +} +#elif TENSORFLOW_USE_HIPSOLVER +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + return ConvertStatus(GpuSolverDpotrf(handle_.get(), GpuBlasUpperLower(uplo), + n, ToDevicePointer(a), lda, nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace) { + return ConvertStatus(GpuSolverSpotrf(handle_.get(), GpuBlasUpperLower(uplo), + n, ToDevicePointer(a), lda, nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + return ConvertStatus(GpuSolverCpotrf(handle_.get(), GpuBlasUpperLower(uplo), + n, ToDevicePointer(a), lda, nullptr, 0, + ToDevicePointer(lapack_info))); +} + +absl::Status GpuSolverContext::Potrf( + se::blas::UpperLower uplo, int n, se::DeviceMemory> a, + int lda, se::DeviceMemory lapack_info, + se::DeviceMemory> workspace) { + return ConvertStatus(GpuSolverZpotrf(handle_.get(), GpuBlasUpperLower(uplo), + n, ToDevicePointer(a), lda, nullptr, 0, + ToDevicePointer(lapack_info))); +} +#endif // TENSORFLOW_USE_HIPSOLVER + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/cusolver_context.h b/xla/service/gpu/cusolver_context.h index 63138f5fd6015..74f287254fa4a 100644 --- a/xla/service/gpu/cusolver_context.h +++ b/xla/service/gpu/cusolver_context.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,12 @@ limitations under the License. #define XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_ #include +#include #include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #define TENSORFLOW_USE_HIPSOLVER \ (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION >= 40500)) @@ -42,7 +47,6 @@ using gpusolverHandle_t = rocblas_handle; #endif // TF_ROCM_VERSION >= 40500 #endif // TENSORFLOW_USE_ROCM -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -54,27 +58,44 @@ namespace se = ::stream_executor; class GpuSolverContext { public: - static StatusOr Create(); + static absl::StatusOr Create(); - Status SetStream(se::Stream* stream); + absl::Status SetStream(se::Stream* stream); // Computes the Cholesky factorization of multiple matrices. See // https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-batchpotrf // // `as` is a list of pointers to the batch_size individual n x n matricies // that make up the input array. - Status PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory as, int lda, - se::DeviceMemory lapack_info, int batch_size); - Status PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory as, int lda, - se::DeviceMemory lapack_info, int batch_size); - Status PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory*> as, int lda, - se::DeviceMemory lapack_info, int batch_size); - Status PotrfBatched(se::blas::UpperLower uplo, int n, - se::DeviceMemory*> as, int lda, - se::DeviceMemory lapack_info, int batch_size); + absl::Status PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory as, int lda, + se::DeviceMemory lapack_info, int batch_size); + absl::Status PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory as, int lda, + se::DeviceMemory lapack_info, int batch_size); + absl::Status PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory*> as, int lda, + se::DeviceMemory lapack_info, int batch_size); + absl::Status PotrfBatched(se::blas::UpperLower uplo, int n, + se::DeviceMemory*> as, int lda, + se::DeviceMemory lapack_info, int batch_size); + + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory> workspace); + absl::Status Potrf(se::blas::UpperLower uplo, int n, + se::DeviceMemory> a, int lda, + se::DeviceMemory lapack_info, + se::DeviceMemory> workspace); // Returns the max size of the `workspace` required by Potrf and PotrfBatched, // in number of elements of `type`. @@ -90,9 +111,9 @@ class GpuSolverContext { // // In practice, this does not result in a notable increase in scratch space // needed, because both cases require a relatively small amount of scratch. - StatusOr PotrfBufferSize(PrimitiveType type, - se::blas::UpperLower uplo, int n, int lda, - int batch_size); + absl::StatusOr PotrfBufferSize(PrimitiveType type, + se::blas::UpperLower uplo, int n, + int lda, int batch_size); private: explicit GpuSolverContext(gpusolverHandle_t handle); diff --git a/xla/service/gpu/cusolver_rewriter.cc b/xla/service/gpu/cusolver_rewriter.cc index 904673038edfd..ddfda663821eb 100644 --- a/xla/service/gpu/cusolver_rewriter.cc +++ b/xla/service/gpu/cusolver_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,31 @@ limitations under the License. #include "xla/service/gpu/cusolver_rewriter.h" -#include +#include #include -#include -#include +#include #include #include "absl/algorithm/container.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/gpu/cusolver_context.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/blas.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -48,10 +54,10 @@ void SetFortranLayout(Shape* shape) { shape->mutable_layout()->mutable_minor_to_major()->at(1)); } -StatusOr CreateCholesky(GpuSolverContext* context, - HloInstruction* operand, - const CholeskyOptions& options, - const OpMetadata& metadata) { +absl::StatusOr CreateCholesky(GpuSolverContext* context, + HloInstruction* operand, + const CholeskyOptions& options, + const OpMetadata& metadata) { HloComputation* computation = operand->parent(); Shape a_shape = operand->shape(); @@ -131,8 +137,8 @@ StatusOr CreateCholesky(GpuSolverContext* context, } // Tries to rewrite a single convolution into a call to cudnn. -StatusOr RunOnInstruction(GpuSolverContext* context, - HloInstruction* instruction) { +absl::StatusOr RunOnInstruction(GpuSolverContext* context, + HloInstruction* instruction) { if (instruction->opcode() != HloOpcode::kCholesky) { return false; } @@ -154,7 +160,7 @@ StatusOr RunOnInstruction(GpuSolverContext* context, // Rewrites the convolutions in the given computation into calls to cudnn. // Returns true if it made any changes. -StatusOr GpusolverRewriter::RunOnComputation( +absl::StatusOr GpusolverRewriter::RunOnComputation( HloComputation* computation) { std::vector cusolver_calls; for (auto* hlo : computation->instructions()) { @@ -179,7 +185,7 @@ StatusOr GpusolverRewriter::RunOnComputation( GpusolverRewriter::GpusolverRewriter() = default; -StatusOr GpusolverRewriter::Run( +absl::StatusOr GpusolverRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/cusolver_rewriter.h b/xla/service/gpu/cusolver_rewriter.h index 9d68cfa15bae1..fd1d84dfa9936 100644 --- a/xla/service/gpu/cusolver_rewriter.h +++ b/xla/service/gpu/cusolver_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ #define XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/cusolver_context.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" namespace xla { namespace gpu { @@ -33,12 +33,12 @@ class GpusolverRewriter : public HloModulePass { absl::string_view name() const override { return "gpusolver-rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnComputation(HloComputation* computation); }; } // namespace gpu diff --git a/xla/service/gpu/custom_call_test.cc b/xla/service/gpu/custom_call_test.cc index e6a29c01b9b59..feb3d53c79c26 100644 --- a/xla/service/gpu/custom_call_test.cc +++ b/xla/service/gpu/custom_call_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include #include - -#include "absl/strings/str_cat.h" +#include #if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" #define PLATFORM "CUDA" @@ -29,27 +31,31 @@ limitations under the License. #define PLATFORM "ROCM" #endif +#include +#include +#include "absl/algorithm/container.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/runtime/module.h" -#include "xla/runtime/module_registry.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/service/gpu/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #if GOOGLE_CUDA #define gpuSuccess cudaSuccess @@ -344,119 +350,368 @@ TEST_F(CustomCallTest, WithStatusFailed) { // XLA runtime custom calls provides type-safe custom call API //===----------------------------------------------------------------------===// -// WARNING: We currently rely on a magic custom call prefix `__gpu$` to detect -// "internal" custom calls that linked statically into the binary. Without this -// prefix custom calls expected to be registered as XLA:FFI custom calls, and -// this is not yet fully supported. -// -// TODO(ezhulenev): Unify runtime custom calls and XLA:FFI. +static absl::Status AlwaysFail(ffi::Result, int32_t value) { + return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value)); +} -// (1) Declare custom call implementations as static functions. +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, + ffi::Ffi::Bind() + .Ret() // + .Attr("value") // value +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", + PLATFORM, kAlwaysFail); -static absl::Status AlwaysFailImpl(runtime::MemrefView arg, int32_t value) { - return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value)); +TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$always_fail", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"{value = 42 : i32}", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto status = Execute(&b, {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42")); } -static absl::Status MemcpyImpl(const ServiceExecutableRunOptions* run_options, - runtime::MemrefView src, - runtime::MemrefView dst) { - auto src_mem = gpu::GetDeviceAddress(src); - auto dst_mem = gpu::GetDeviceAddress(dst); - run_options->stream()->ThenMemcpyD2D(&dst_mem, src_mem, src_mem.size()); - return absl::OkStatus(); +static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, + ffi::Result dst) { + return stream->MemcpyD2D( + &dst->data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); } -// (2) Declare custom call binding signature. At compile time we check that -// declared signature matches function handlers, and at run time we check that -// passed arguments match the signature (number of arguments and their types). +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Ret() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); -// TODO(ezhulenev): Remove these custom calls once we switch to thunks runtime. +TEST_F(CustomCallTest, ExportedFfiMemcpy) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$memcpy", + /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); + EXPECT_THAT(result.data(), ::testing::Each(42)); +} -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AlwaysFail, AlwaysFailImpl, runtime::CustomCall::RuntimeChecks::kDefault, - runtime::CustomCall::Bind("__gpu$xla.gpu.ext.always_fail") - .Arg() // arg - .Attr("value") // value -); +static absl::Status HandleUserPointer(ffi::Result, + const std::string* str) { + return absl::InternalError(*str); +} -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Memcpy, MemcpyImpl, runtime::CustomCall::RuntimeChecks::kDefault, - runtime::CustomCall::Bind("__gpu$xla.gpu.ext.memcpy") - .UserData() - .Arg() // src - .Arg() // dst -); +XLA_FFI_DEFINE_HANDLER(kHandleUserPointer, HandleUserPointer, + ffi::Ffi::Bind() + .Ret() // buffer for result + .Attr>("message")); -// (3) Declare FFI handlers as adaptors for legacy XLA runtime custom calls. -// -// TODO(ezhulenev): This is a long term replacement for "legacy" custom calls -// (custom calls with void** arguments) and a type safe xla runtime custom -// calls (see above). XLA FFI unifies internal custom calls (static linking) -// with external custom calls (dynamically loaded libraries). Make this the only -// example, once it's fully supported. +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$user_data", PLATFORM, + kHandleUserPointer); -namespace impl { -static Status AlwaysFail(ffi::Buffer arg, int32_t value) { - return AlwaysFailImpl(arg, value); +TEST_F(CustomCallTest, PassUserPointerWithAttrs) { + std::string message = "User-defined message"; + auto ptr = reinterpret_cast(&message); + + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$user_data", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), + /*opaque=*/absl::StrFormat("{message = %d : i64}", ptr), + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto status = Execute(&b, {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), ::testing::HasSubstr("User-defined message")); } -static Status Memcpy(const ServiceExecutableRunOptions* run_options, - ffi::Buffer src, ffi::Buffer dst) { - return MemcpyImpl(run_options, src, dst); +bool is_ffi_invoked = false; +static absl::Status IsInvoked(ffi::Result) { + is_ffi_invoked = true; + return absl::OkStatus(); } -} // namespace impl -XLA_FFI_DEFINE_HANDLER(kAlwaysFail, impl::AlwaysFail, - ffi::Ffi::Bind() - .Arg() // arg - .Attr("value") // value -); +XLA_FFI_DEFINE_HANDLER( + kIsInvoked, IsInvoked, + ffi::Ffi::Bind().Ret()); // Buffer for result (unused). -XLA_FFI_DEFINE_HANDLER(kMemcpy, impl::Memcpy, +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$isinvoked", PLATFORM, + kIsInvoked); + +TEST_F(CustomCallTest, ExportedFfiIsInvoked) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$isinvoked", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); + EXPECT_TRUE(is_ffi_invoked); +} + +TEST_F(CustomCallTest, ExportedFfiUnknownTarget) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$unknown_target", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto status = Execute(&b, {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnimplemented); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("No registered implementation")); +} + +// Memcpy and SubBuffers tests are already ported in +// fusions/address_computation_fusion_test.cc + +// Reusing kExpectedOpaque from the original test. +static absl::Status Opaque(ffi::Result, + const std::string* str) { + std::string opaque(*str); + if (opaque != kExpectedOpaque) + return absl::InternalError(absl::StrFormat( + "Opaque string does not match. Expected `%s` but got `%s`", + kExpectedOpaque, opaque)); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kOpaque, Opaque, ffi::Ffi::Bind() - .Ctx() - .Arg() // src - .Arg() // dst -); + .Ret() // Dummy result buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$opaque", PLATFORM, + kOpaque); + +TEST_F(CustomCallTest, ExportedFfiOpaque) { + XlaBuilder b(TestName()); + const std::string opaque = absl::StrFormat( + "{opaque = %d : i64}", reinterpret_cast(&kExpectedOpaque)); + CustomCall(&b, "__xla_test$$opaque", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), + /*opaque=*/opaque, + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK(Execute(&b, {}).status()); +} -// (4) Register custom calls handlers with XLA runtime. +static absl::Status TokensChecker(std::vector inputs, + const std::string* opaque) { + // TODO(penporn): Actually check the inputs when FFI handlers support tokens. + return absl::OkStatus(); +} -static void RegisterCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("__gpu$xla.gpu.ext.always_fail", AlwaysFail); - registry.Register("__gpu$xla.gpu.ext.memcpy", Memcpy); +static absl::Status Tokens1Input(ffi::BufferBase input1, + ffi::Result, + const std::string* opaque) { + return TokensChecker({input1}, opaque); } -XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL(RegisterCustomCalls); +static absl::Status Tokens2Inputs(ffi::BufferBase input1, + ffi::BufferBase input2, + ffi::Result, + const std::string* opaque) { + return TokensChecker({input1, input2}, opaque); +} -// (5) Register XLA FFI handlers with XLA runtime. +static absl::Status Tokens3Inputs(ffi::BufferBase input1, + ffi::BufferBase input2, + ffi::BufferBase input3, + ffi::Result, + const std::string* opaque) { + return TokensChecker({input1, input2, input3}, opaque); +} -XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.always_fail", - kAlwaysFail); -XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.memcpy", - kMemcpy); +XLA_FFI_DEFINE_HANDLER(kTokens1Input, Tokens1Input, + ffi::Ffi::Bind() + .Arg() // 1 input buffer. + .Ret() // Output buffer. + .Attr>("opaque")); -TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) { +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_1input", + PLATFORM, kTokens1Input); + +XLA_FFI_DEFINE_HANDLER(kTokens2Inputs, Tokens2Inputs, + ffi::Ffi::Bind() + .Arg() // 1st input buffer. + .Arg() // 2nd input buffer. + .Ret() // Output buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_2inputs", + PLATFORM, kTokens2Inputs); + +XLA_FFI_DEFINE_HANDLER(kTokens3Inputs, Tokens3Inputs, + ffi::Ffi::Bind() + .Arg() // 1st input buffer. + .Arg() // 2nd input buffer. + .Arg() // 3rd input buffer. + .Ret() // Output buffer. + .Attr>("opaque")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_3inputs", + PLATFORM, kTokens3Inputs); + +TEST_P(CustomCallTokensTest, ExportedFfiTokensTest) { + const TokenTestCase& tc = GetParam(); XlaBuilder b(TestName()); - CustomCall(&b, "__gpu$xla.gpu.ext.always_fail", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {}), /*opaque=*/"{value = 42 : i32}", + std::istringstream input(tc.input); + std::istringstream output(tc.output); + std::vector call_inputs = BuildInputs(b, input); + std::vector call_output = BuildOutputType(output); + ASSERT_GE(call_inputs.size(), 1); + ASSERT_LE(call_inputs.size(), 3); + ASSERT_EQ(call_output.size(), 1); + + const std::string custom_call_name = + absl::StrFormat("__xla_test$$tokens_%dinput%s", call_inputs.size(), + call_inputs.size() == 1 ? "" : "s"); + const std::string opaque = absl::StrFormat( + "{opaque = %d : i64}", reinterpret_cast(&tc.opaque)); + CustomCall(&b, custom_call_name, /*operands=*/call_inputs, + call_output.front(), + /*opaque=*/opaque, /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + + // TODO(penporn): Expect an OK status when FFI handlers support tokens. auto status = Execute(&b, {}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); - EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42")); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("FFI handlers do not support tokens")); } -TEST_F(CustomCallTest, ExportedFfiMemcpy) { +INSTANTIATE_TEST_SUITE_P(CustomCallTokensTest, CustomCallTokensTest, + ::testing::ValuesIn(GetTokenTestCases())); + +static absl::Status AlwaysSucceed(ffi::Result) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, + ffi::Ffi::Bind().Ret()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_succeed", + PLATFORM, kAlwaysSucceed); + +TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) { XlaBuilder b(TestName()); - CustomCall(&b, "__gpu$xla.gpu.ext.memcpy", - /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, - ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + CustomCall(&b, "__xla_test$$always_succeed", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), /*opaque=*/"", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK(Execute(&b, {}).status()); +} + +//===----------------------------------------------------------------------===// +// XLA:FFI handler for testing attributes decoding +//===----------------------------------------------------------------------===// + +static absl::Status FfiAttributes(ffi::Result, + absl::Span i32_arr) { + if (i32_arr.size() != 4) + return absl::InternalError("i32_arr size does not match"); + + if (i32_arr[0] != 1 || i32_arr[1] != 2 || i32_arr[2] != 3 || i32_arr[3] != 4) + return absl::InternalError("i32_arr values do not match"); + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiAttributes, FfiAttributes, + ffi::Ffi::Bind() + .Ret() + .Attr>("i32_arr")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_attributes", + PLATFORM, kFfiAttributes); + +TEST_F(CustomCallTest, FfiAttributes) { + XlaBuilder b(TestName()); + CustomCall(&b, "xla.gpu.ffi_attributes", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), + /*opaque=*/"{ i32_arr = array }", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + TF_ASSERT_OK(Execute(&b, {}).status()); +} + +//===----------------------------------------------------------------------===// +// XLA:FFI handler with attached HloComputation +//===----------------------------------------------------------------------===// + +static absl::Status MemcpyWithCalledComputation( + se::Stream* stream, se::OwningScratchAllocator<> scratch_allocator, + ffi::BufferBase src, ffi::Result dst, + const HloComputation* called_computation) { + if (called_computation == nullptr) + return absl::InternalError("Called computation is not defined"); + + if (called_computation->instruction_count() != 1) + return absl::InternalError("Unexpected number of instructions"); + + if (!DynCast(called_computation->root_instruction())) + return absl::InternalError("ROOT must be a paremeter"); + + // Check that scratch allocator is working. + auto scratch = scratch_allocator.AllocateBytes(1024); + if (!scratch.ok()) return scratch.status(); + + return Memcpy(stream, src, dst); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation, + MemcpyWithCalledComputation, + ffi::Ffi::Bind() + .Ctx() + .Ctx() // scratch + .Arg() // src + .Ret() // dst + .Ctx()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "xla.gpu.ext.memcpy_with_called_computation", PLATFORM, + kMemcpyWithCalledComputation); + +TEST_F(CustomCallTest, WithCalledComputation) { + auto shape = ShapeUtil::MakeShape(F32, {128}); + + // Build a called computation which is just a copy instruction. + XlaBuilder copy("copy"); + auto p0 = Parameter(©, 0, shape, "l_val"); + Copy(p0); + auto copy_computation = copy.Build().value(); + + XlaBuilder b(TestName()); + CustomCallWithComputation( + &b, "xla.gpu.ext.memcpy_with_called_computation", + /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, + copy_computation, shape, /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); EXPECT_THAT(result.data(), ::testing::Each(42)); } diff --git a/xla/service/gpu/custom_fusion_rewriter.cc b/xla/service/gpu/custom_fusion_rewriter.cc deleted file mode 100644 index 666b816650c58..0000000000000 --- a/xla/service/gpu/custom_fusion_rewriter.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/custom_fusion_rewriter.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" -#include "xla/statusor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -namespace xla::gpu { - -CustomFusionRewriter::CustomFusionRewriter( - const CustomFusionPatternRegistry* patterns) - : patterns_(patterns) {} - -// Returns instructions that have to become custom fusion parameters. Returns an -// error if matched pattern can't be outlined as a fusion. -static StatusOr> GetPatternCaptures( - const CustomFusionPattern::Match& match) { - HloInstruction* root = match.instructions.back(); - absl::InlinedVector captures; - - // Instruction that will go into the fusion body. - absl::flat_hash_set instructions_set( - match.instructions.begin(), match.instructions.end()); - - // Check that intermediate instructions do not have users outside of the - // matched pattern. Only root instruction can have external users. - for (HloInstruction* instr : match.instructions) { - for (HloInstruction* user : instr->users()) { - if (instr != root && !instructions_set.contains(user)) { - return absl::InvalidArgumentError(absl::StrCat( - "Custom fusion intermediate result ", instr->name(), - " has users outside of a matched pattern: ", user->name())); - } - } - } - - // Collect instructions captured by a matched pattern. - for (HloInstruction* instr : match.instructions) { - for (HloInstruction* operand : instr->operands()) { - if (!instructions_set.contains(operand)) captures.push_back(operand); - } - } - - return captures; -} - -// Creates custom fusion computation and moves all matched instructions into it. -static StatusOr CreateFusionBody( - HloModule* module, const CustomFusionPattern::Match& match, - absl::Span captures) { - HloComputation::Builder builder(match.config.name()); - - // A mapping from original instructions to instructions in the fusion body. - absl::flat_hash_map instr_mapping; - - auto mapped_operands = [&](HloInstruction* instr) { - absl::InlinedVector operands; - for (HloInstruction* operand : instr->operands()) { - operands.push_back(instr_mapping.at(operand)); - } - return operands; - }; - - // For every parameter create a parameter instruction in the computation body - // and set up instruction mapping. - for (const HloInstruction* capture : captures) { - int64_t index = instr_mapping.size(); - instr_mapping[capture] = - builder.AddInstruction(HloInstruction::CreateParameter( - index, capture->shape(), absl::StrCat("p", index))); - } - - // TODO(ezhulenev): Instructions in the pattern must be topologically sorted, - // otherwise we'll get a crash! Figure out how to do it! - for (HloInstruction* instr : match.instructions) { - instr_mapping[instr] = builder.AddInstruction( - instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); - } - - return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); -} - -static StatusOr CreateFusionInstruction( - HloModule* module, const CustomFusionPattern::Match& match, - absl::Span captures, HloComputation* body) { - // We'll be replacing the root operation of a custom fusion with a fusion - // instruction calling fusion computation. - HloInstruction* root = match.instructions.back(); - HloComputation* parent = root->parent(); - - // Add a fusion operation calling outlined fusion computation. - HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( - root->shape(), HloInstruction::FusionKind::kCustom, captures, body)); - module->SetAndUniquifyInstrName(fusion, match.config.name()); - - // Set backends config to a matched custom fusion config. - FusionBackendConfig backend_config; - backend_config.set_kind("__custom_fusion"); - *backend_config.mutable_custom_fusion_config() = match.config; - TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(backend_config))); - - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(root, fusion)); - return fusion; -} - -StatusOr CustomFusionRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - std::vector matches; - - // Collect all potential custom fusion matches in the module. - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instr : computation->instructions()) { - auto matched = patterns_->Match(instr); - matches.insert(matches.end(), matched.begin(), matched.end()); - } - } - - if (matches.empty()) return false; - - for (const CustomFusionPattern::Match& match : matches) { - // Check if pattern can be outlined as a fusion and collect captured - // parameters (instructions defined outside of a fusion). - auto captures = GetPatternCaptures(match); - if (!captures.ok()) { - VLOG(2) << "Skip custom fusion " << match.config.name() << ": " - << captures.status(); - continue; - } - - TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, match, *captures)); - - TF_ASSIGN_OR_RETURN( - HloInstruction * fusion, - CreateFusionInstruction(module, match, *captures, fusion_body)); - - VLOG(2) << "Added a fusion instruction: " << fusion->name() - << " for custom fusion " << match.config.name() - << " (instruction count = " << match.instructions.size() << ")"; - } - - return true; -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/custom_fusion_rewriter.h b/xla/service/gpu/custom_fusion_rewriter.h deleted file mode 100644 index 1e5e643a4a4c1..0000000000000 --- a/xla/service/gpu/custom_fusion_rewriter.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ -#define XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" - -namespace xla::gpu { - -// Pattern matches HLO instruction to custom fusions (hand written CUDA C++ -// kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them into -// fusion instructions and fusion computations. -// -// Example: pattern matching dot operation into CUTLASS gemm -// -// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { -// %p0 = f16[15,19]{1,0} parameter(0) -// %p1 = f16[19,17]{1,0} parameter(1) -// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), -// lhs_contracting_dims={1}, rhs_contracting_dims={0} -// } -// -// After the pass: -// -// %cutlass_gemm (p0: f16[19,17], p1: f16[15,19]) -> f16[15,17] { -// %p0 = f16[15,19]{1,0} parameter(0) -// %p1 = f16[19,17]{1,0} parameter(1) -// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), -// lhs_contracting_dims={1}, rhs_contracting_dims={0} -// } -// -// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { -// %p0 = f16[15,19]{1,0} parameter(0) -// %p1 = f16[19,17]{1,0} parameter(1) -// ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom, -// calls==cutlass_gemm, -// backend_config={kind: "__custom_fusion", -// custom_fusion_config: {"name":"cutlass_gemm"}} -// } -// -class CustomFusionRewriter : public HloModulePass { - public: - explicit CustomFusionRewriter(const CustomFusionPatternRegistry* patterns = - CustomFusionPatternRegistry::Default()); - - absl::string_view name() const override { return "custom-fusion-rewriter"; } - - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const CustomFusionPatternRegistry* patterns_; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ diff --git a/xla/service/gpu/custom_fusion_rewriter_test.cc b/xla/service/gpu/custom_fusion_rewriter_test.cc deleted file mode 100644 index 55268442fc48a..0000000000000 --- a/xla/service/gpu/custom_fusion_rewriter_test.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/custom_fusion_rewriter.h" - -#include -#include - -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/test.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// Simple pattern matchers for testing custom fusion rewriter. -//===----------------------------------------------------------------------===// - -class SimpleGemmPattern : public CustomFusionPattern { - public: - std::optional TryMatch(HloInstruction* instr) const override { - if (auto* dot = DynCast(instr)) { - CustomFusionConfig config; - config.set_name("simple_gemm"); - return Match{config, {instr}}; - } - return std::nullopt; - } -}; - -//===----------------------------------------------------------------------===// - -class CustomFusionRewriterTest : public HloTestBase {}; - -TEST_F(CustomFusionRewriterTest, SimpleGemm) { - const char* hlo = R"( - HloModule test - - ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { - %p0 = f16[15,19]{1,0} parameter(0) - %p1 = f16[19,17]{1,0} parameter(1) - ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - )"; - - const char* expected = R"( - ; CHECK: %simple_gemm {{.*}} { - ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0) - ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1) - ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]), - ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} - ; CHECK: } - - ; CHECK: ENTRY %main {{.*}} { - ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion - ; CHECK: kind=kCustom, calls=%simple_gemm, - ; CHECK: backend_config={ - ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} - ; CHECK: } - ; CHECK: } - )"; - - CustomFusionPatternRegistry patterns; - patterns.Emplace(); - - CustomFusionRewriter pass(&patterns); - RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/custom_kernel_fusion_rewriter.cc b/xla/service/gpu/custom_kernel_fusion_rewriter.cc new file mode 100644 index 0000000000000..cf80ab5fcb3e4 --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_rewriter.cc @@ -0,0 +1,239 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +CustomKernelFusionRewriter::CustomKernelFusionRewriter( + const se::DeviceDescription* device, + const CustomKernelFusionPatternRegistry* patterns) + : device_(device), patterns_(patterns) {} + +// Returns a set of instruction that have users outside of a matched pattern +// and have a replacement that must be applied after building a new custom +// fusion instruction. Only root instruction can have external users and does +// not require a replacement, as the fusion itself is a replacement. If +// instruction has external users and does not have a replacement returns empty +// optional. +static std::optional> +GetPatternReplacements(const CustomKernelFusionPattern::Match& match) { + absl::flat_hash_set requires_replacement; + absl::flat_hash_set instructions_set( + match.instructions().begin(), match.instructions().end()); + + for (HloInstruction* instr : match.instructions()) { + for (HloInstruction* user : instr->users()) { + if (instr == match.root() || instructions_set.contains(user)) continue; + + if (match.HasReplacement(instr)) { + requires_replacement.insert(instr); + continue; + } + + VLOG(3) << "Custom kernel fusion intermediate result " << instr->name() + << " has users outside of a matched pattern: " << user->name(); + return std::nullopt; + } + } + + return requires_replacement; +} + +// Returns instructions that have to become custom kernel fusion parameters. +// Returns an error if matched pattern can't be outlined as a fusion. +static absl::InlinedVector GetPatternCaptures( + const CustomKernelFusionPattern::Match& match) { + absl::InlinedVector captures; + + absl::flat_hash_set instructions_set( + match.instructions().begin(), match.instructions().end()); + + for (HloInstruction* instr : match.instructions()) { + for (HloInstruction* operand : instr->operands()) { + if (!instructions_set.contains(operand) && + absl::c_find(captures, operand) == captures.end()) { + captures.emplace_back(operand); + } + } + } + + return captures; +} + +// Creates custom kernel fusion computation and moves all matched instructions +// into it. +static absl::StatusOr CreateFusionBody( + HloModule* module, const CustomKernelFusionPattern::Match& match, + absl::Span captures) { + HloComputation::Builder builder(match.config().name()); + + // A mapping from original instructions to instructions in the fusion body. + absl::flat_hash_map instr_mapping; + + auto mapped_operands = [&](HloInstruction* instr) { + absl::InlinedVector operands; + for (HloInstruction* operand : instr->operands()) { + operands.push_back(instr_mapping.at(operand)); + } + return operands; + }; + + // For every captured value create a parameter instruction in the computation + // body and set up instruction mapping. + for (const HloInstruction* capture : captures) { + int64_t index = instr_mapping.size(); + instr_mapping[capture] = + builder.AddInstruction(HloInstruction::CreateParameter( + index, capture->shape(), absl::StrCat("p", index))); + } + + // TODO(ezhulenev): Instructions in the pattern must be topologically sorted, + // otherwise we'll get a crash! Figure out how to do it! + for (HloInstruction* instr : match.instructions()) { + instr_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); + } + + HloInstruction* root = builder.last_added_instruction(); + + // If custom kernel fusion requires a workspace we add a custom call that + // allocates workspace and return a tuple of "real" result and a workspace. + if (match.workspace_size_bytes() > 0) { + auto workspace_shape = + ShapeUtil::MakeShape(PrimitiveType::U8, {match.workspace_size_bytes()}); + HloInstruction* workspace = + builder.AddInstruction(HloInstruction::CreateCustomCall( + workspace_shape, {}, CustomKernelFusionPattern::kWorkspace, "", + CustomCallApiVersion::API_VERSION_TYPED_FFI)); + builder.AddInstruction(HloInstruction::CreateTuple({root, workspace})); + } + + return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); +} + +static absl::StatusOr CreateFusionInstruction( + HloModule* module, const CustomKernelFusionPattern::Match& match, + absl::Span captures, HloComputation* body) { + // We'll be replacing the root operation of a custom kernel fusion with a + // fusion instruction calling fusion computation. + HloInstruction* root = match.root(); + HloComputation* parent = root->parent(); + + // Add a fusion operation calling outlined fusion computation. + HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( + body->root_instruction()->shape(), HloInstruction::FusionKind::kCustom, + captures, body)); + module->SetAndUniquifyInstrName(fusion, match.config().name()); + + // Set backends config to a matched custom fusion config. + GpuBackendConfig gpu_config; + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind("__custom_fusion"); + *backend_config.mutable_custom_fusion_config() = match.config(); + TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); + + // If we don't have workspace we can return constructed fusion instruction. + if (match.workspace_size_bytes() == 0) return fusion; + + // Otherwise have to get result corresponding to the original value; + return parent->AddInstruction( + HloInstruction::CreateGetTupleElement(fusion, 0)); +} + +absl::StatusOr CustomKernelFusionRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::vector matches; + + // Collect all potential custom fusion matches in the module. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + auto matched = patterns_->Match(*device_, instr); + matches.insert(matches.end(), matched.begin(), matched.end()); + } + } + + if (matches.empty()) return false; + + for (const CustomKernelFusionPattern::Match& match : matches) { + VLOG(2) << "Matched custom kernel fusion " << match.config().name() + << "; root instruction: " << match.instructions().back()->name(); + + auto replacememts = GetPatternReplacements(match); + if (!replacememts.has_value()) continue; + + auto captures = GetPatternCaptures(match); + + TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, + CreateFusionBody(module, match, captures)); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, match, captures, fusion_body)); + + VLOG(2) << "Added a fusion instruction: " << fusion->name() + << " for custom kernel fusion " << match.config().name() + << " (instruction count = " << match.instructions().size() << ")"; + + for (HloInstruction* instr : *replacememts) { + VLOG(2) << "Replace matched instruction: " << instr->name() + << " with a pattern replacement"; + + TF_ASSIGN_OR_RETURN( + HloInstruction * replacement, + match.BuildReplacement(instr, Cast(fusion))); + + TF_RETURN_IF_ERROR( + instr->ReplaceAllUsesWith(replacement, match.config().name())); + + VLOG(2) << "Replaced instruction: " << instr->name() + << " with: " << replacement->name(); + } + + VLOG(2) << "Replace custom kernel fusion root instruction " + << match.root()->name() << "with " << fusion->name(); + HloComputation* parent = match.root()->parent(); + TF_RETURN_IF_ERROR(parent->ReplaceInstruction(match.root(), fusion)); + } + + return true; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/custom_kernel_fusion_rewriter.h b/xla/service/gpu/custom_kernel_fusion_rewriter.h new file mode 100644 index 0000000000000..cb19d91a3dd57 --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_rewriter.h @@ -0,0 +1,86 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ +#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Pattern matches HLO instruction to custom kernel fusions (hand written CUDA +// C++ kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them +// into fusion instructions and fusion computations. +// +// Example: pattern matching dot operation into CUTLASS gemm +// +// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), +// lhs_contracting_dims={1}, rhs_contracting_dims={0} +// } +// +// After the pass: +// +// %cutlass_gemm (p0: f16[19,17], p1: f16[15,19]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), +// lhs_contracting_dims={1}, rhs_contracting_dims={0} +// } +// +// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom, +// calls==cutlass_gemm, +// backend_config={kind: "__custom_fusion", +// custom_fusion_config: {"name":"cutlass_gemm"}} +// } +// +class CustomKernelFusionRewriter : public HloModulePass { + public: + explicit CustomKernelFusionRewriter( + const se::DeviceDescription* device, + const CustomKernelFusionPatternRegistry* patterns = + CustomKernelFusionPatternRegistry::Default()); + + absl::string_view name() const override { + return "custom-kernel-fusion-rewriter"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription* device_; + const CustomKernelFusionPatternRegistry* patterns_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ diff --git a/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc b/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc new file mode 100644 index 0000000000000..ac0d1464f3612 --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// Simple pattern matchers for testing custom kernel_fusion rewriter. +//===----------------------------------------------------------------------===// + +struct SimpleGemmPattern : public CustomKernelFusionPattern { + explicit SimpleGemmPattern(int64_t workspace = 0) : workspace(workspace) {} + + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override { + if (auto* dot = DynCast(instr)) { + CustomFusionConfig config; + config.set_name("simple_gemm"); + return Match{config, {instr}, workspace}; + } + return std::nullopt; + } + + int64_t workspace; +}; + +//===----------------------------------------------------------------------===// + +class CustomKernelFusionRewriterTest : public HloTestBase {}; + +TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { + %p0 = f16[15,19]{1,0} parameter(0) + %p1 = f16[19,17]{1,0} parameter(1) + ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %simple_gemm {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0) + ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1) + ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]), + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion + ; CHECK: kind=kCustom, calls=%simple_gemm, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} + ; CHECK: } + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { + %p0 = f16[15,19]{1,0} parameter(0) + %p1 = f16[19,17]{1,0} parameter(1) + ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %simple_gemm {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0) + ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1) + ; CHECK: [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]), + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: [[WORKSPACE:%[^ ]+]] = u8[1024]{0} custom-call(), + ; CHECK: custom_call_target="__custom_kernel_fusion$workspace" + ; CHECK: ROOT [[TUPLE:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) + ; CHECK: tuple([[DOT]], [[WORKSPACE]]) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: [[FUSION:%[^ ]+]] = (f16[15,17]{1,0}, u8[1024]{0}) fusion + ; CHECK: kind=kCustom, calls=%simple_gemm, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} + ; CHECK: } + ; CHECK: ROOT {{.*}} get-tuple-element([[FUSION]]), index=0 + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(1024); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/determinism_test.cc b/xla/service/gpu/determinism_test.cc index ab362dba4ad42..43799901f2a0e 100644 --- a/xla/service/gpu/determinism_test.cc +++ b/xla/service/gpu/determinism_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/dot_algorithm_support_test.cc b/xla/service/gpu/dot_algorithm_support_test.cc new file mode 100644 index 0000000000000..78b39e88f5503 --- /dev/null +++ b/xla/service/gpu/dot_algorithm_support_test.cc @@ -0,0 +1,232 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/primitive_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::Combine; +using ::testing::ConvertGenerator; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::Values; +using ::testing::WithParamInterface; + +enum class BackendRestriction { + kNoRestriction = 0, + kTritonOnly = 1, +}; + +std::string BackendRestrictionToString(BackendRestriction backend_restriction) { + switch (backend_restriction) { + case BackendRestriction::kNoRestriction: + return "no_restriction"; + case BackendRestriction::kTritonOnly: + return "triton_only"; + } +} + +struct Sizes { + int contracting_size; + int non_contracting_size; +}; + +struct TestParams { + using TupleType = + std::tuple; + + PrecisionConfig::Algorithm algorithm; + PrimitiveType input_storage_type; + PrimitiveType output_storage_type; + se::CudaComputeCapability min_cuda_capability; + BackendRestriction backend_restriction; + Sizes sizes; + + explicit TestParams(TupleType t) + : algorithm(std::get<0>(t)), + input_storage_type(std::get<1>(t)), + output_storage_type(std::get<2>(t)), + min_cuda_capability(std::get<3>(t)), + backend_restriction(std::get<4>(t)), + sizes(std::get<5>(t)) {} +}; + +std::string TestParamsToString(const TestParamInfo& info) { + const TestParams& params = info.param; + return absl::StrFormat( + "%s_with_input_%s_output_%s_from_cc_%d_%d_%s_c_%d_nc_%d", + AlgorithmToString(params.algorithm), + primitive_util::LowercasePrimitiveTypeName(params.input_storage_type), + primitive_util::LowercasePrimitiveTypeName(params.output_storage_type), + params.min_cuda_capability.major, params.min_cuda_capability.minor, + BackendRestrictionToString(params.backend_restriction), + params.sizes.contracting_size, params.sizes.non_contracting_size); +} + +// These are integration tests. +// TODO(tdanyluk): Consider checking somehow directly if the correct algorithms +// are called / emitted. Currently the emitters should decline unsupported +// algorithms, but maybe we could check this directly. + +class DotAlgorithmSupportTest : public HloTestBase, + public WithParamInterface { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + // Setting this explicitly to make sure that we also test the case when the + // dot's dimensions are under the rewrite size threshold: + // (2 * non_contracting_size * contracting_size < threshold). + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100); + return debug_options; + } +}; + +// A parametrized test that checks if an algorithm is supported, with the given +// input and output storage types, from a given cuda capability. +TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) { + const TestParams& params = GetParam(); + const std::string hlo_text = absl::Substitute( + R"( + HloModule test + + ENTRY test { + x = $1[$4,$3] parameter(0) + y = $1[$3,$4] parameter(1) + + ROOT out = $2[$4,$4] dot(x, y), + lhs_contracting_dims={1}, + rhs_contracting_dims={0}, + algorithm=$0 + } + )", + AlgorithmToString(params.algorithm), + primitive_util::LowercasePrimitiveTypeName(params.input_storage_type), + primitive_util::LowercasePrimitiveTypeName(params.output_storage_type), + params.sizes.contracting_size, params.sizes.non_contracting_size); + + if (GetCudaComputeCapability().IsAtLeast(params.min_cuda_capability.major, + params.min_cuda_capability.minor)) { + EXPECT_TRUE(Run(hlo_text)); + + if (params.backend_restriction == BackendRestriction::kTritonOnly) { + MatchOptimizedHlo(hlo_text, R"( + ;CHECK: ENTRY + ;CHECK: ROOT + ;CHECK-SAME: kCustom + ;CHECK-SAME: "triton_gemm_config" + )"); + } + } else { + EXPECT_THAT(Run(hlo_text).message(), HasSubstr("Unsupported algorithm")); + } +} + +using PC = PrecisionConfig; +using CC = se::CudaComputeCapability; + +INSTANTIATE_TEST_SUITE_P(F8E5M2Tests, DotAlgorithmSupportTest, + ConvertGenerator(Combine( + Values(PC::ALG_DOT_ANY_F8_ANY_F8_F32, + PC::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM), + Values(F8E5M2), Values(F8E5M2, F16, BF16, F32), + Values(CC(8, 9)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(F8E4M3FNTests, DotAlgorithmSupportTest, + ConvertGenerator(Combine( + Values(PC::ALG_DOT_ANY_F8_ANY_F8_F32, + PC::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM), + Values(F8E4M3FN), Values(F8E4M3FN, F16, BF16, F32), + Values(CC(8, 9)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest, + ConvertGenerator(Combine( + Values(PC::ALG_DOT_F16_F16_F32), Values(F16), + Values(F16, F32), Values(CC(0, 0)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, + ConvertGenerator(Combine( + Values(PC::ALG_DOT_BF16_BF16_F32), Values(BF16), + Values(BF16, F32), Values(CC(8, 0)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, + ConvertGenerator( + Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3, + PC::ALG_DOT_BF16_BF16_F32_X6), + Values(F32), Values(F32), Values(CC(8, 0)), + Values(BackendRestriction::kTritonOnly), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotTf32Tf32F32Tests, DotAlgorithmSupportTest, + ConvertGenerator( + Combine(Values(PC::ALG_DOT_TF32_TF32_F32), + Values(F32), Values(F32), Values(CC(8, 0)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotF32F32F32Tests, DotAlgorithmSupportTest, + ConvertGenerator( + Combine(Values(PC::ALG_DOT_F32_F32_F32), + Values(F32), Values(F32), Values(CC(0, 0)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(DotF64F64F64Tests, DotAlgorithmSupportTest, + ConvertGenerator( + Combine(Values(PC::ALG_DOT_F64_F64_F64), + Values(F64), Values(F64), Values(CC(0, 0)), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2}))), + TestParamsToString); + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/dot_dimension_sorter.cc b/xla/service/gpu/dot_dimension_sorter.cc index b44eb7f811f19..0609581981c01 100644 --- a/xla/service/gpu/dot_dimension_sorter.cc +++ b/xla/service/gpu/dot_dimension_sorter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,21 +15,26 @@ limitations under the License. #include "xla/service/gpu/dot_dimension_sorter.h" -#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/permutation_util.h" +#include "xla/status.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { @@ -38,7 +43,7 @@ namespace gpu { namespace { // Sort contracting dimensions of a dot() instruction preserving lhs-rhs pairs. -Status SortDotDimensions(HloInstruction* dot) { +absl::Status SortDotDimensions(HloDotInstruction* dot) { const DotDimensionNumbers& dims = dot->dot_dimension_numbers(); DotDimensionNumbers new_dims(dims); new_dims.clear_lhs_contracting_dimensions(); @@ -64,7 +69,8 @@ Status SortDotDimensions(HloInstruction* dot) { sorted_rhs.end()}; std::unique_ptr new_dot = HloInstruction::CreateDot( dot->shape(), dot->mutable_operand(0), dot->mutable_operand(1), new_dims, - dot->precision_config()); + dot->precision_config(), {dot->sparsity().begin(), dot->sparsity().end()}, + absl::MakeSpan(dot->operands()).subspan(HloDotInstruction::kOperands)); dot->SetupDerivedInstruction(new_dot.get()); VLOG(3) << "Sorted dot() dimensions:\n" @@ -75,7 +81,7 @@ Status SortDotDimensions(HloInstruction* dot) { } // namespace -StatusOr DotDimensionSorter::Run( +absl::StatusOr DotDimensionSorter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::vector dots_to_process; @@ -121,7 +127,7 @@ StatusOr DotDimensionSorter::Run( return false; } for (HloInstruction* dot : dots_to_process) { - TF_RETURN_IF_ERROR(SortDotDimensions(dot)); + TF_RETURN_IF_ERROR(SortDotDimensions(Cast(dot))); } return true; } diff --git a/xla/service/gpu/dot_dimension_sorter.h b/xla/service/gpu/dot_dimension_sorter.h index 0a2e19542f0b7..5eadeb14ceb50 100644 --- a/xla/service/gpu/dot_dimension_sorter.h +++ b/xla/service/gpu/dot_dimension_sorter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_ #define XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -38,7 +41,7 @@ class DotDimensionSorter : public HloModulePass { // Run the pass on computations in 'module'. // Returns whether the 'module' was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/dot_dimension_sorter_test.cc b/xla/service/gpu/dot_dimension_sorter_test.cc index f0a72b4cffe76..fedd1eae6b65c 100644 --- a/xla/service/gpu/dot_dimension_sorter_test.cc +++ b/xla/service/gpu/dot_dimension_sorter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,15 @@ limitations under the License. #include "xla/service/gpu/dot_dimension_sorter.h" +#include + +#include +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -155,6 +163,29 @@ ENTRY e { EXPECT_FALSE(modified); } +TEST_F(DotDimensionSorterTest, SparseDotSortContractingDims) { + const char* module_string = R"( +HloModule m + +ENTRY e { + p0 = f16[1,144,96,16] parameter(0) + p1 = f16[122,96,32] parameter(1) + meta = u16[1,144,96,2] parameter(2) + ROOT _ = f16[1,144,122] dot(p0, p1, meta), sparsity=L.3@2:4, + lhs_contracting_dims={3,2}, rhs_contracting_dims={2,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool modified, + DotDimensionSorter().Run(module.get())); + EXPECT_TRUE(modified); + HloDotInstruction* dot = DynCast( + module->entry_computation()->root_instruction()); + EXPECT_TRUE(dot != nullptr && dot->sparse_operands() == 1); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/dot_operand_converter.cc b/xla/service/gpu/dot_operand_converter.cc new file mode 100644 index 0000000000000..2a298e67eaf70 --- /dev/null +++ b/xla/service/gpu/dot_operand_converter.cc @@ -0,0 +1,74 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/dot_operand_converter.h" + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +bool DotOperandConverter::InstructionMatchesPattern( + HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kDot) { + return false; + } + HloInstruction* lhs = instruction->mutable_operand(0); + HloInstruction* rhs = instruction->mutable_operand(1); + + PrimitiveType lhs_type = lhs->shape().element_type(); + PrimitiveType rhs_type = rhs->shape().element_type(); + + if (lhs_type == rhs_type) { + return false; + } + + // Exclude conversions between FP8 types. + absl::flat_hash_set non_converting = {F8E4M3FN, F8E5M2}; + if (non_converting.contains(lhs_type) && non_converting.contains(rhs_type)) { + return false; + } + + PrimitiveType desired_type = + ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape()); + + return desired_type == lhs_type || desired_type == rhs_type; +} + +absl::StatusOr DotOperandConverter::ExpandInstruction( + HloInstruction* instruction) { + HloInstruction* lhs = instruction->mutable_operand(0); + HloInstruction* rhs = instruction->mutable_operand(1); + + // Find the higher precision type among the two operands, and add a convert + // instruction to convert the lesser-precise operand to that type. + PrimitiveType desired_type = + ShapeUtil::HigherPrecisionElementType(lhs->shape(), rhs->shape()); + int operand_index = desired_type == lhs->shape().element_type() ? 1 : 0; + HloInstruction* inst_to_replace = + desired_type == lhs->shape().element_type() ? rhs : lhs; + auto upcast_shape = inst_to_replace->shape(); + upcast_shape.set_element_type(desired_type); + auto* convert_inst = instruction->AddInstruction( + HloInstruction::CreateConvert(upcast_shape, inst_to_replace)); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape( + operand_index, convert_inst)); + return nullptr; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/dot_operand_converter.h b/xla/service/gpu/dot_operand_converter.h new file mode 100644 index 0000000000000..d277a24100c0b --- /dev/null +++ b/xla/service/gpu/dot_operand_converter.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ +#define XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/op_expander_pass.h" +#include "xla/util.h" + +namespace xla::gpu { + +// Converts both operands to the highest precision operand type. +class DotOperandConverter : public OpExpanderPass { + public: + explicit DotOperandConverter(HloPredicate extra_filter = nullptr) + : OpExpanderPass(std::move(extra_filter)) {} + + absl::string_view name() const override { return "operand_converter"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ diff --git a/xla/service/gpu/dot_operand_converter_test.cc b/xla/service/gpu/dot_operand_converter_test.cc new file mode 100644 index 0000000000000..9a36a288e01b6 --- /dev/null +++ b/xla/service/gpu/dot_operand_converter_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/dot_operand_converter.h" + +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/primitive_util.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class DotOperandConverterTest : public HloTestBase { + public: + void TestConvert(bool left_less_precise, PrimitiveType lhs_type, + PrimitiveType rhs_type, PrimitiveType result_type) { + absl::string_view module_tmpl = R"( + HloModule module + + ENTRY main { + p0 = $0[2,3]{1,0} parameter(0) + p1 = $1[3,2]{1,0} parameter(1) + ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + auto module_string = absl::Substitute( + module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type), + primitive_util::LowercasePrimitiveTypeName(rhs_type), + primitive_util::LowercasePrimitiveTypeName(result_type)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, + DotOperandConverter().Run(module.get())); + EXPECT_TRUE(upcasted); + if (left_less_precise) { + auto original_lhs = op::Parameter(0); + auto upcasted_lhs = + AllOf(op::Convert(original_lhs), + op::Shape(absl::Substitute( + "$0[2,3]{1,0}", + primitive_util::LowercasePrimitiveTypeName(rhs_type)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::Dot(upcasted_lhs, op::Parameter(1)), + op::Shape(absl::Substitute( + "$0[2,2]{1,0}", + primitive_util::LowercasePrimitiveTypeName(result_type))))); + } else { + auto original_rhs = op::Parameter(1); + auto upcasted_rhs = + AllOf(op::Convert(original_rhs), + op::Shape(absl::Substitute( + "$0[3,2]{1,0}", + primitive_util::LowercasePrimitiveTypeName(lhs_type)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::Dot(op::Parameter(0), upcasted_rhs), + op::Shape(absl::Substitute( + "$0[2,2]{1,0}", + primitive_util::LowercasePrimitiveTypeName(result_type))))); + } + } +}; + +TEST_F(DotOperandConverterTest, ConvertsLeftAndRight) { + TestConvert(/*left_less_precise=*/true, S8, BF16, F32); + TestConvert(/*left_less_precise=*/false, BF16, S8, F32); +} + +TEST_F(DotOperandConverterTest, NoConvertHappensWithSameTypes) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = s8[2,3]{1,0} parameter(0) + p1 = s8[3,2]{1,0} parameter(1) + ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, + DotOperandConverter().Run(module.get())); + EXPECT_FALSE(upcasted); +} + +TEST_F(DotOperandConverterTest, NoConvertFromF8toF8) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f8e4m3fn[2,3]{1,0} parameter(0) + p1 = f8e5m2[3,2]{1,0} parameter(1) + ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, + DotOperandConverter().Run(module.get())); + EXPECT_FALSE(upcasted); +} + +TEST_F(DotOperandConverterTest, CompilerOptimizesUsingDotOperandConverter) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = s8[2,3]{1,0} parameter(0) + p1 = bf16[3,2]{1,0} parameter(1) + ROOT dot = bf16[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(module_string)); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/elemental_ir_emitter.cc b/xla/service/gpu/elemental_ir_emitter.cc index 260287cc25952..6d6675f21e46e 100644 --- a/xla/service/gpu/elemental_ir_emitter.cc +++ b/xla/service/gpu/elemental_ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,32 @@ limitations under the License. #include "xla/service/gpu/elemental_ir_emitter.h" +#include #include #include // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/ModRef.h" +#include "llvm/TargetParser/Triple.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" -#include "xla/literal.h" +#include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" @@ -39,7 +48,7 @@ limitations under the License. #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/math_ops.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -48,24 +57,12 @@ namespace gpu { using absl::StrAppend; -namespace { -// Returns whether operand is a floating-point literal with the given value. -bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - if (operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value)) { - return true; - } - return operand->opcode() == HloOpcode::kBroadcast && - IsFPLiteralWithValue(operand->operand(0), value); -} -} // namespace - GpuElementalIrEmitter::GpuElementalIrEmitter( IrEmitterContext& ir_emitter_context, llvm::IRBuilder<>* b) : ElementalIrEmitter(ir_emitter_context.llvm_module(), b), ir_emitter_context_(ir_emitter_context) {} -StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( +absl::StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( TargetDeviceFunctionID funcid, absl::Span operands, absl::Span input_types, PrimitiveType output_type, absl::string_view name) { @@ -108,30 +105,7 @@ StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( return result; } -StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const std::string& callee_name, absl::Span operands, - absl::Span input_types, PrimitiveType output_type) { - // llvm intrinsics differentiate between half/float/double functions via - // the suffixes ".f16", ".f32" and ".f64". - std::string munged_callee = callee_name; - switch (output_type) { - case F16: - StrAppend(&munged_callee, ".f16"); - break; - case F32: - StrAppend(&munged_callee, ".f32"); - break; - case F64: - StrAppend(&munged_callee, ".f64"); - break; - default: - return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type)); - } - return EmitMathCall(munged_callee, operands, input_types, output_type); -} - -StatusOr GpuElementalIrEmitter::EmitMathCall( +absl::StatusOr GpuElementalIrEmitter::EmitMathCall( const std::string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type, absl::string_view name) { @@ -158,8 +132,11 @@ llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast( // Decode the layout of the shape from the Protobufs attached to // backend_config_. - BitcastBackendConfig bitcast_config; - CHECK(bitcast_config.ParseFromString(hlo->raw_backend_config_string())); + auto gpu_config = hlo->backend_config(); + CHECK(gpu_config.ok()); + + const BitcastBackendConfig& bitcast_config = + gpu_config.value().bitcast_backend_config(); // If there is no layout in the protobuf, do not override it. if (!bitcast_config.result_layout().minor_to_major().empty()) { @@ -173,7 +150,7 @@ llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast( return index.SourceIndexOfBitcast(shape, operand_shape, b()); } -StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( +absl::StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -188,6 +165,17 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( {lhs_value, rhs_value}, {lhs_value->getType()}, b()); } + // sm_80 and up has min.NaN and max.NaN instructions. + if (output_type == F32 && + ir_emitter_context_.cuda_compute_capability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maximum + : llvm::Intrinsic::minimum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b()); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod, @@ -202,7 +190,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( } } -StatusOr GpuElementalIrEmitter::EmitPowerOp( +absl::StatusOr GpuElementalIrEmitter::EmitPowerOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); @@ -213,77 +201,76 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( {lhs_input_type, rhs_input_type}, output_type); } -StatusOr GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitTan(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitTan( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kTan, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExp( +absl::StatusOr GpuElementalIrEmitter::EmitExp( PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) { return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs, - absl::string_view name) { +absl::StatusOr GpuElementalIrEmitter::EmitPow( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs, + absl::string_view name) { return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs}, {prim_type, prim_type}, prim_type, name); } -StatusOr GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitSqrt( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitRsqrt( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitAtan2( +absl::StatusOr GpuElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs, absl::string_view name) { return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs}, {prim_type, prim_type}, prim_type, name); } -StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) { // When F64 is being requested, assume performance is less important and use // the more numerically precise tanh function. if (prim_type == F64) { @@ -316,33 +303,37 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, value->getType(), "tanh"); } -StatusOr GpuElementalIrEmitter::EmitComplexAbs( +absl::StatusOr GpuElementalIrEmitter::EmitErf( + PrimitiveType prim_type, llvm::Value* value) { + if (prim_type == F64) { + return EmitDeviceMathCall(TargetDeviceFunctionID::kErf, {value}, + {prim_type}, prim_type); + } + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); + if (type == b()->getFloatTy()) { + llvm::Value* x = FPCast(value, type); + auto* result = llvm_ir::EmitErfF32(b(), x); + return FPCast(result, value->getType()); + } + return Unimplemented("erf"); +} + +absl::StatusOr GpuElementalIrEmitter::EmitComplexAbs( PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot, {EmitExtractReal(value), EmitExtractImag(value)}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, - llvm::Value* value) { +absl::StatusOr GpuElementalIrEmitter::EmitCbrt( + PrimitiveType prim_type, llvm::Value* value) { return EmitDeviceMathCall(TargetDeviceFunctionID::kCbrt, {value}, {prim_type}, prim_type); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() { - llvm::Value* block_id = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()), - b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); -} - -StatusOr> GpuElementalIrEmitter::EmitThreadLocalCall( +absl::StatusOr> +GpuElementalIrEmitter::EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view, bool /*is_reducer*/) { return CallNestedComputationWithScalars(b(), ir_emitter_context_, callee, diff --git a/xla/service/gpu/elemental_ir_emitter.h b/xla/service/gpu/elemental_ir_emitter.h index f97861ba2e7af..83f82818f2d64 100644 --- a/xla/service/gpu/elemental_ir_emitter.h +++ b/xla/service/gpu/elemental_ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,10 @@ limitations under the License. #define XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_ #include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -26,8 +29,7 @@ limitations under the License. #include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/target_util.h" -#include "xla/service/hlo_module_config.h" -#include "xla/statusor.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/xla_data.pb.h" namespace xla { @@ -42,88 +44,83 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::IrArray::Index GetSourceIndexOfBitcast( const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) override; - StatusOr EmitFloatBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value) override; - - StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitFloatBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) override; - StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitTan(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value, - absl::string_view name) override; + absl::StatusOr EmitTan(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value, + absl::string_view name) override; - StatusOr EmitSqrt(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitRsqrt(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitSqrt(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs, - absl::string_view name) override; + absl::StatusOr EmitRsqrt(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs, - absl::string_view name) override; + absl::StatusOr EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, llvm::Value* rhs, + absl::string_view name) override; - StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, llvm::Value* rhs, + absl::string_view name) override; - StatusOr EmitComplexAbs(PrimitiveType prim_type, + absl::StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; - StatusOr EmitCbrt(PrimitiveType prim_type, - llvm::Value* value) override; + absl::StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value) override; + + absl::StatusOr EmitComplexAbs(PrimitiveType prim_type, + llvm::Value* value) override; + + absl::StatusOr EmitCbrt(PrimitiveType prim_type, + llvm::Value* value) override; - StatusOr> EmitThreadLocalCall( + absl::StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view, bool /*is_reducer*/) override; - llvm::Value* EmitThreadId() override; - bool fast_min_max() override { return ir_emitter_context_.debug_options().xla_gpu_enable_fast_min_max(); } private: // Emits IR for op, which must have opcode kPower. - StatusOr EmitPowerOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); - - // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts - // callee_name according to T. Returns the IR value that represents the - // return value of the function. - StatusOr EmitLlvmIntrinsicMathCall( - const std::string& callee_name, absl::Span operands, - absl::Span input_types, PrimitiveType output_type); + absl::StatusOr EmitPowerOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); // Emits IR to call a device function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. - StatusOr EmitDeviceMathCall( + absl::StatusOr EmitDeviceMathCall( TargetDeviceFunctionID funcid, absl::Span operands, absl::Span input_types, PrimitiveType output_type, absl::string_view name = ""); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. - StatusOr EmitMathCall( + absl::StatusOr EmitMathCall( const std::string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type, absl::string_view name = ""); diff --git a/xla/service/gpu/executable.proto b/xla/service/gpu/executable.proto index 3a9e08a866511..e66c48c4762b2 100644 --- a/xla/service/gpu/executable.proto +++ b/xla/service/gpu/executable.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ syntax = "proto3"; package xla.gpu; import "xla/service/hlo.proto"; +import "xla/xla.proto"; message XlaRuntimeGpuExecutableProto { message ConstantInfoProto { @@ -39,8 +40,9 @@ message XlaRuntimeGpuExecutableProto { } message CompilationResultProto { - HloModuleProto hlo_module = 1; + HloModuleProtoWithConfig hlo_module_with_config = 1; BufferAssignmentProto buffer_assignment = 2; string asm_text = 3; bytes binary = 4; + map dnn_compiled_graphs = 5; } diff --git a/xla/service/gpu/float_support_test.cc b/xla/service/gpu/float_support_test.cc index 06b5176a21866..5822d10a8deed 100644 --- a/xla/service/gpu/float_support_test.cc +++ b/xla/service/gpu/float_support_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include "absl/strings/string_view.h" #include "xla/error_spec.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" @@ -26,11 +29,11 @@ namespace { class FloatSupportTest : public HloTestBase { public: - se::CudaComputeCapability GetCudaComputeCapability() { + const se::GpuComputeCapability& GetGpuComputeCapability() { return backend() .default_stream_executor() ->GetDeviceDescription() - .cuda_compute_capability(); + .gpu_compute_capability(); } }; @@ -72,9 +75,15 @@ ENTRY e { } TEST_F(FloatSupportTestWithTriton, MixedTypeDotWithBF16IsNotUpcasted) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; + bool skip_test = std::visit( + VariantVisitor{[](const se::CudaComputeCapability& cc) { + return !cc.IsAtLeast(se::CudaComputeCapability::AMPERE); + }, + [](const se::RocmComputeCapability&) { return true; }}, + GetGpuComputeCapability()); + + if (skip_test) { + GTEST_SKIP() << "Not supported on this GPU architecture"; } constexpr absl::string_view kHloText = R"( diff --git a/xla/service/gpu/for_thunk.cc b/xla/service/gpu/for_thunk.cc deleted file mode 100644 index 2d4a8b4176a52..0000000000000 --- a/xla/service/gpu/for_thunk.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/for_thunk.h" - -#include -#include - -#include "xla/util.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace gpu { - -ForThunk::ForThunk(ThunkInfo thunk_info, const int64_t loop_limit, - std::unique_ptr body_thunk_sequence) - : Thunk(Kind::kFor, thunk_info), - loop_limit_(loop_limit), - body_thunk_sequence_(std::make_unique( - // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ - // constructor because this SequentialThunk is logically "part of" - // this ForThunk, and shouldn't be profiled separately from it. - ThunkInfo(thunk_info.op), std::move(*body_thunk_sequence))) {} - -Status ForThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executor, src)); - return OkStatus(); -} - -Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters"; - for (int64_t i = 0; i < loop_limit_; ++i) { - VLOG(3) << "Executing iteration # " << i; - // Invoke loop body thunk sequence. - TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/for_thunk.h b/xla/service/gpu/for_thunk.h deleted file mode 100644 index caa4db61b306f..0000000000000 --- a/xla/service/gpu/for_thunk.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_FOR_THUNK_H_ -#define XLA_SERVICE_GPU_FOR_THUNK_H_ - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'. -class ForThunk : public Thunk { - public: - ForThunk(ThunkInfo thunk_info, int64_t loop_limit, - std::unique_ptr body_thunk_sequence); - ForThunk(const ForThunk&) = delete; - ForThunk& operator=(const ForThunk&) = delete; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - SequentialThunk* body_thunk_sequence() { return body_thunk_sequence_.get(); } - - private: - const int64_t loop_limit_; - std::unique_ptr body_thunk_sequence_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FOR_THUNK_H_ diff --git a/xla/service/gpu/fusion_merger.cc b/xla/service/gpu/fusion_merger.cc index 3e44bc4c84d7d..0c0c9ee15d80f 100644 --- a/xla/service/gpu/fusion_merger.cc +++ b/xla/service/gpu/fusion_merger.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,22 +15,32 @@ limitations under the License. #include "xla/service/gpu/fusion_merger.h" -#include -#include -#include #include #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_graph_dumper.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace xla { namespace gpu { @@ -50,13 +60,13 @@ class FusionInstructionMerger { .debug_options() .xla_dump_fusion_visualization()) {} - Status Run(); + absl::Status Run(); bool changed() const { return changed_; } private: FusionDecision ShouldFuse(HloInstruction* producer); - Status FuseIntoAllUsers(HloInstruction* producer); + absl::Status FuseIntoAllUsers(HloInstruction* producer); HloComputation* computation_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; @@ -83,7 +93,8 @@ class FusionInstructionMerger { FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete; }; -Status FusionInstructionMerger::FuseIntoAllUsers(HloInstruction* producer) { +absl::Status FusionInstructionMerger::FuseIntoAllUsers( + HloInstruction* producer) { // Merge fused instructions from 'fusion' into each user. std::vector users = producer->users(); for (HloInstruction* user : users) { @@ -131,10 +142,10 @@ Status FusionInstructionMerger::FuseIntoAllUsers(HloInstruction* producer) { absl::StrAppend(out, user->name()); }) << " }"; - return OkStatus(); + return absl::OkStatus(); } -Status FusionInstructionMerger::Run() { +absl::Status FusionInstructionMerger::Run() { for (HloInstruction* producer : computation_->MakeInstructionPostOrder()) { if (producer->opcode() != HloOpcode::kFusion) { continue; @@ -171,7 +182,7 @@ Status FusionInstructionMerger::Run() { << num_fail_inefficient_fusion_emitter_ << " slower_if_fused: " << num_fail_slower_if_fused_ << " fusion_too_large: " << num_fail_fusion_too_large_ << " }"; - return OkStatus(); + return absl::OkStatus(); } bool TransposesMostData(const HloInstruction& fusion) { @@ -220,6 +231,10 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { ++num_fail_merge_all_users_; return "not fusing bitcast ops"; } + if (user->IsCustomFusion()) { + ++num_fail_merge_all_users_; + return "not fusing custom fusions"; + } auto consumer_hero = GetRealHeroForMultiOutputFusion(*user); if (auto compatible = FusionHeroesAreCompatible(producer_hero, consumer_hero); @@ -275,8 +290,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { } GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &*cost_analysis_, - GpuPerformanceModelOptions::ForModule(producer->GetModule()), + producer, &*cost_analysis_, GpuPerformanceModelOptions::Default(), producer->users()); if (t.time_fused > t.time_unfused) { ++num_fail_slower_if_fused_; @@ -286,7 +300,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { return {}; } -StatusOr FusionMerger::Run( +absl::StatusOr FusionMerger::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/fusion_merger.h b/xla/service/gpu/fusion_merger.h index f332ee141372f..acbc93e7781fb 100644 --- a/xla/service/gpu/fusion_merger.h +++ b/xla/service/gpu/fusion_merger.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSION_MERGER_H_ #define XLA_SERVICE_GPU_FUSION_MERGER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" @@ -67,7 +70,7 @@ class FusionMerger : public HloModulePass { absl::string_view name() const override { return "fusion_merger"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/fusion_merger_test.cc b/xla/service/gpu/fusion_merger_test.cc index 243c2f3944c56..447ffce64998c 100644 --- a/xla/service/gpu/fusion_merger_test.cc +++ b/xla/service/gpu/fusion_merger_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,22 @@ limitations under the License. #include "xla/service/gpu/fusion_merger.h" +#include #include +#include +#include #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/stream_executor/device_description.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/xla/service/gpu/fusion_merger_triton.cc b/xla/service/gpu/fusion_merger_triton.cc deleted file mode 100644 index ad9d76d5cd6c8..0000000000000 --- a/xla/service/gpu/fusion_merger_triton.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/fusion_merger_triton.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/triton_fusion_analysis.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" - -namespace xla::gpu { -namespace { - -// Taking in a producer HloFusionInstruction, tries to merge into consumer -// triton softmax fusion. -// The following is assumed: -// * The producer is an HloFusionInstruction -// * The consumer is a triton softmax fusion -// -// Returns true if the producer is merged into the consumer and replaced -// in the original computation. Returns false otherwise. -StatusOr TryMergeFusionProducerIntoTritonSoftmaxConsumer( - HloFusionInstruction* producer) { - HloComputation* computation = producer->parent(); - HloInstruction* original_softmax_instruction = producer->users().front(); - CHECK_EQ(original_softmax_instruction->opcode(), HloOpcode::kFusion); - - std::unique_ptr candidate = - original_softmax_instruction->Clone(); - HloInstruction* candidate_fusion = - static_cast(candidate.get()); - - // Try to merge the producer into candidate fusion - candidate_fusion->MergeFusionInstruction(producer); - - HloComputation* fused_computation = - candidate_fusion->called_computations().front(); - - TF_ASSIGN_OR_RETURN(const auto analysis, - TritonFusionAnalysis::Execute(*fused_computation)); - - computation->AddInstruction(std::move(candidate)); - - if (original_softmax_instruction->IsRoot()) { - computation->set_root_instruction(candidate_fusion); - } - - TF_CHECK_OK( - original_softmax_instruction->ReplaceAllUsesWith(candidate_fusion)); - TF_RETURN_IF_ERROR( - computation->RemoveInstruction(original_softmax_instruction)); - - CHECK_EQ(0, producer->user_count()) << producer->ToString(); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); - return true; -} - -} // anonymous namespace - -StatusOr FusionMergerTriton::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - int fused_comps = 0; - for (HloComputation* comp : - module->MakeNonfusionComputations(execution_threads)) { - if (comp->IsCustomCallComputation()) { - continue; - } - - for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kCustom && - instr->backend_config().ok() && - instr->backend_config()->kind() == - kTritonSoftmaxFusionKind) { - // TODO(b/313026024): Add support for multiple users - if (instr->operand(0)->opcode() != HloOpcode::kFusion || - instr->operand(0)->user_count() != 1) { - continue; - } - - HloFusionInstruction* producer = - Cast(instr->mutable_operand(0)); - - VLOG(6) << "Matched triton_softmax kernel, Fusing producer " - << producer->ToShortString() << " into " - << instr->ToShortString(); - - absl::StatusOr result = - TryMergeFusionProducerIntoTritonSoftmaxConsumer(producer); - - if (!result.ok()) { - VLOG(6) << "Did not fuse producer into " << instr->ToShortString(); - } - - if (result.ok() && *result) ++fused_comps; - } - } - } - return fused_comps > 0; -} -} // namespace xla::gpu diff --git a/xla/service/gpu/fusion_merger_triton.h b/xla/service/gpu/fusion_merger_triton.h deleted file mode 100644 index 56fb5e4667bbb..0000000000000 --- a/xla/service/gpu/fusion_merger_triton.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ -#define XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" - -namespace xla { -namespace gpu { - -// An HLO pass that attempts to merge producer fusions into triton softmax -// fusions. -// -// Producer kernels are only merged if the resulting fusion can be correctly -// tiled. If the result can be tiled, all operations from the auxiliary -// producer fusion will be merged into the triton softmax computation, and this -// computation will replace both the auxiliary and original triton softmax -// fusion. -// -// Auxiliary fusions are not merged into consumer triton fusions if: -// * The auxiliary fusion has multiple users -// * The resulting merged fusion is not tilable -class FusionMergerTriton : public HloModulePass { - public: - explicit FusionMergerTriton() = default; - absl::string_view name() const override { return "fusion-merger-triton"; } - - using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSION_MERGER_TRITON_H_ diff --git a/xla/service/gpu/fusion_merger_triton_test.cc b/xla/service/gpu/fusion_merger_triton_test.cc deleted file mode 100644 index 86de08d942864..0000000000000 --- a/xla/service/gpu/fusion_merger_triton_test.cc +++ /dev/null @@ -1,323 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/fusion_merger_triton.h" - -#include -#include - -#include -#include -#include "absl/log/log.h" -#include "xla/autotune_results.pb.h" -#include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/status_matchers.h" - -using ::tsl::testing::IsOk; -using ::tsl::testing::IsOkAndHolds; - -namespace xla { -namespace gpu { -namespace { - -namespace m = ::xla::match; -using FusionMergerTritonTest = HloTestBase; - -TEST_F(FusionMergerTritonTest, - CanMergeTritonFusionWithSingleParameterProducer) { - const std::string kHloText = R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - parameter_0 = f32[125]{0} parameter(0) - ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125]{0} parameter(0) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -})"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_THAT(fusion_merger.Run(module.get()), IsOkAndHolds(true)); - EXPECT_THAT(verifier().Run(module.get()), IsOk()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter()))); -} - -TEST_F( - FusionMergerTritonTest, - CanMergeProducerFusionIntoTritonSoftmaxConsumerWhenTheConsumerIsNotRoot) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - parameter_0 = f32[125]{0} parameter(0) - ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125]{0} parameter(0) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation - triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} - ROOT broadcast = f32[10,125,127]{2,1,0} broadcast(triton_softmax), dimensions={1,2} -})"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_THAT(fusion_merger.Run(module.get()), IsOkAndHolds(true)); - EXPECT_THAT(verifier().Run(module.get()), IsOk()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Broadcast(m::Fusion(m::Parameter())))); -} - -TEST_F(FusionMergerTritonTest, - CanMergeTritonFusionWithMultipleParameterProducer) { - const std::string kHloText = R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - parameter_0 = f32[125]{0} parameter(0) - parameter_1 = f32[125,127]{1,0} parameter(1) - broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(parameter_1, broadcast) -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125]{0} parameter(0) - param_1 = f32[125,127]{1,0} parameter(1) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -})"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_TRUE(fusion_merger.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(FusionMergerTritonTest, CanMergeTritonFusionWithTransposeProducer) { - const std::string kHloText = R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - parameter_0 = f32[125]{0} parameter(0) - parameter_1 = f32[127,125]{1,0} parameter(1) - transpose = f32[125,127]{1,0} transpose(parameter_1), dimensions={1,0} - broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(transpose, broadcast) -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125]{0} parameter(0) - param_1 = f32[127,125]{1,0} parameter(1) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -})"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_TRUE(fusion_merger.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(FusionMergerTritonTest, - DoesNotMergeTritonFusionWithProducerContainingUntileableOp) { - // Right now, concatenate is not tileable. - const std::string kHloText = R"( -HloModule t -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - parameter_0 = f32[125,63]{1,0} parameter(0) - parameter_1 = f32[125,64]{1,0} parameter(1) - ROOT concatenate = f32[125,127]{1,0} concatenate(parameter_0, parameter_1), dimensions={1} -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) - constant_0 = f32[] constant(0) - reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add - broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) -} - -ENTRY main { - param_0 = f32[125,63]{1,0} parameter(0) - param_1 = f32[125,64]{1,0} parameter(1) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kLoop, calls=auxiliary_computation - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -})"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_FALSE(fusion_merger.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Fusion(m::Parameter(), m::Parameter())))); -} - -TEST_F(FusionMergerTritonTest, CanMergeTritonFusionWithElementwiseProducer) { - const std::string kHloText = R"( -HloModule layernorm - -add_f32 { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add_6 = f32[] add(Arg_0, Arg_1) -} - -auxiliary_fusion { - parameter_0 = f32[125,127]{1,0} parameter(0) - parameter_1 = f32[125,127]{1,0} parameter(1) - ROOT multiply_1 = f32[125,127]{1,0} multiply(parameter_0, parameter_1) -} - -triton_softmax_computation { - parameter_0 = f32[125,127]{1,0} parameter(0) - constant_0 = f32[] constant(0) - reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=add_f32 - broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} - ROOT multiply_result = f32[125,127]{1,0} multiply(parameter_0, broadcast) -} - -ENTRY main { - param_0 = f32[125,127]{1,0} parameter(0) - param_1 = f32[125,127]{1,0} parameter(1) - auxiliary_fusion = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=auxiliary_fusion - ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -} - -)"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_TRUE(fusion_merger.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(FusionMergerTritonTest, - DoesNotMergeSoftmaxWithParamBroadcastedAlongBatchAndReduceDimensions) { - const std::string kHloText = R"( -HloModule t - -add { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} - -auxiliary_computation { - param_0 = f32[10,125,127]{2,1,0} parameter(0) - param_1 = f32[10]{0} parameter(1) - broadcast_0 = f32[10,125,127]{2,1,0} broadcast(param_1), dimensions={0} - ROOT multiply_0 = f32[10,125,127]{2,1,0} multiply(param_0, broadcast_0) -} - -triton_softmax_computation { - param_0 = f32[10,125,127]{2,1,0} parameter(0) - multiply = f32[10,125,127]{2,1,0} multiply(param_0, param_0) - constant = f32[] constant(0) - reduce = f32[10,125]{1,0} reduce(multiply, constant), dimensions={2}, to_apply=add - broadcast = f32[10,125,127]{2,1,0} broadcast(reduce), dimensions={0,1} - ROOT multiply_out = f32[10,125,127]{2,1,0} multiply(param_0, broadcast) -} - -ENTRY main { - param_0 = f32[10,125,127]{2,1,0} parameter(0) - param_1 = f32[10]{0} parameter(1) - auxiliary_fusion = f32[10,125,127]{2,1,0} fusion(param_0, param_1), kind=kCustom, calls=auxiliary_computation - ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} -} -)"; - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - FusionMergerTriton fusion_merger; - EXPECT_FALSE(fusion_merger.Run(module.get()).value()); - VLOG(2) << module->ToString(); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Fusion()))); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/fusion_pipeline.cc b/xla/service/gpu/fusion_pipeline.cc index 6554ca858bd91..8d309bce85c56 100644 --- a/xla/service/gpu/fusion_pipeline.cc +++ b/xla/service/gpu/fusion_pipeline.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/layout_assignment.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/fusion_pipeline.h b/xla/service/gpu/fusion_pipeline.h index bd2d7bb1f7ed2..994dd56802518 100644 --- a/xla/service/gpu/fusion_pipeline.h +++ b/xla/service/gpu/fusion_pipeline.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/fusion_process_dump.cc b/xla/service/gpu/fusion_process_dump.cc new file mode 100644 index 0000000000000..9863a3a7b63ef --- /dev/null +++ b/xla/service/gpu/fusion_process_dump.cc @@ -0,0 +1,219 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_process_dump.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tools/hlo_module_loader.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer, + HloComputation* computation, + std::string_view fusion_name) { + if (consumer->opcode() == HloOpcode::kFusion) { + return consumer; + } + + // This is not true for all fusions, but the fusion kind isn't used in the + // cost model and fusion pipeline, so it doesn't matter here. Set kLoop for + // everything. + auto kind = HloInstruction::FusionKind::kLoop; + + auto fusion_instruction = computation->AddInstruction( + HloInstruction::CreateFusion(consumer->shape(), kind, consumer), + /*new_name=*/fusion_name); + TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction)); + + return fusion_instruction; +} + +HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, + HloComputation* computation, + std::string_view fusion_name) { + HloInstruction* fusion_instruction = + AddFusionInstruction(producer, consumer, computation, fusion_name); + if (producer->opcode() == HloOpcode::kFusion) { + fusion_instruction->MergeFusionInstruction(producer); + } else { + fusion_instruction->FuseInstruction(producer); + } + + if (producer->user_count() == 0) { + TF_CHECK_OK(computation->RemoveInstruction(producer)); + } + + return fusion_instruction; +} + +absl::string_view GetProducerName(const FusionStep& step) { + if (step.has_fusion()) { + return step.fusion().producer_name(); + } + + if (step.has_update_priority()) { + return step.update_priority().producer_name(); + } + + if (step.has_producer_ineligible()) { + return step.producer_ineligible().producer_name(); + } + + LOG(FATAL) << "Producer name not found in the current step."; +} + +} // namespace + +absl::StatusOr FusionProcessDump::LoadFromFile( + const std::string& path) { + std::string format = std::string(tsl::io::Extension(path)); + std::string data; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); + return FusionProcessDump::LoadFromData(data, format); +} + +absl::StatusOr FusionProcessDump::LoadFromData( + const std::string& data, absl::string_view format) { + FusionProcessDumpProto fusion_process_dump_proto; + if (format == "txt" || format == "pbtxt") { + if (!tsl::protobuf::TextFormat::ParseFromString( + data, &fusion_process_dump_proto)) { + return InvalidArgument("Failed to parse input as HLO protobuf text"); + } + } else if (format == "pb") { + if (!fusion_process_dump_proto.ParseFromString(data)) { + return InvalidArgument("Failed to parse input as HLO protobuf binary"); + } + } else { + return InvalidArgument( + "Invalid format from file extension: '%s'. Expected: txt, pb, or pbtxt", + format); + } + + return FusionProcessDump::LoadFromProto(fusion_process_dump_proto); +} + +absl::StatusOr FusionProcessDump::LoadFromProto( + const FusionProcessDumpProto& fusion_process_dump_proto) { + TF_ASSIGN_OR_RETURN( + auto module, + LoadModuleFromData(fusion_process_dump_proto.hlo_module_before_fusion(), + /*format=*/"txt")); + + se::DeviceDescription gpu_device_info( + fusion_process_dump_proto.gpu_device_info()); + + absl::flat_hash_map + instruction_name_to_computation_map; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + instruction_name_to_computation_map[instr->name()] = computation; + } + } + + return FusionProcessDump(std::move(fusion_process_dump_proto), + std::move(module), std::move(gpu_device_info), + std::move(instruction_name_to_computation_map)); +} + +HloComputation* FusionProcessDump::GetCurrentComputation() { + return instruction_name_to_computation_map_.at( + GetProducerName(CurrentStep())); +} + +HloInstruction* FusionProcessDump::GetInstructionWithName( + absl::string_view name) { + return instruction_name_to_computation_map_[name]->GetInstructionWithName( + name); +} + +HloInstruction* FusionProcessDump::GetProducer() { + return GetInstructionWithName(GetProducerName(CurrentStep())); +} + +absl::InlinedVector FusionProcessDump::GetConsumers() { + auto& step = CurrentStep(); + + if (step.has_fusion()) { + return {GetInstructionWithName(step.fusion().consumer_name())}; + } + + if (step.has_update_priority()) { + absl::InlinedVector consumers; + for (const auto& consumer_name : step.update_priority().consumer_names()) { + consumers.push_back(GetInstructionWithName(consumer_name)); + } + return consumers; + } + + return {}; +} + +const FusionStep& FusionProcessDump::CurrentStep() { + CHECK(HasNext()); + return fusion_process_dump_proto_.fusion_steps(current_step_idx_); +} + +bool FusionProcessDump::HasNext() { + return current_step_idx_ < fusion_process_dump_proto_.fusion_steps_size(); +} + +void FusionProcessDump::Advance() { + auto step = CurrentStep(); + if (step.has_fusion()) { + const auto& fusion_step = step.fusion(); + + auto* computation = GetCurrentComputation(); + + HloInstruction* producer = + computation->GetInstructionWithName(fusion_step.producer_name()); + HloInstruction* consumer = + computation->GetInstructionWithName(fusion_step.consumer_name()); + + HloInstruction* fusion = + Fuse(producer, consumer, computation, fusion_step.fusion_name()); + + instruction_name_to_computation_map_[fusion->name()] = computation; + last_fusion_ = fusion; + } + ++current_step_idx_; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusion_process_dump.h b/xla/service/gpu/fusion_process_dump.h new file mode 100644 index 0000000000000..4f06eadadf48f --- /dev/null +++ b/xla/service/gpu/fusion_process_dump.h @@ -0,0 +1,118 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ +#define XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +// Helper class to work with fusion process dump. +class FusionProcessDump { + public: + static absl::StatusOr LoadFromFile( + const std::string& path); + static absl::StatusOr LoadFromData( + const std::string& data, absl::string_view format); + static absl::StatusOr LoadFromProto( + const FusionProcessDumpProto& fusion_process_dump_proto); + + const FusionProcessDumpProto& proto() { return fusion_process_dump_proto_; } + + HloModule* module() { return hlo_module_.get(); } + + const se::DeviceDescription& device_info() { return device_info_; } + + int64_t current_step_idx() { return current_step_idx_; } + + // Returns computation that contains producer (and other instructions) of the + // current step. + HloComputation* GetCurrentComputation(); + + // Returns the instruction with `name`. + HloInstruction* GetInstructionWithName(absl::string_view name); + + // Returns producer of the current step. Should not be null, since all step + // types have a producer. + HloInstruction* GetProducer(); + + // Returns a list of consumers of the current step. The list contains one + // instruction is the current step is fusion. The list is empty if the current + // step is `producer_ineligible`. + absl::InlinedVector GetConsumers(); + + // Returns result instruction of the last fusion step. Returns nullptr before + // the first fusion. + HloInstruction* GetLastFusion() { return last_fusion_; } + + // Returns current step. If current step is `fusion`, the `module` is in the + // state *before* the fusion. Next call to `FusionProcessDump::Advance` will + // actualy perform the fusion. + const FusionStep& CurrentStep(); + + // Returns true if there are fusion steps. + bool HasNext(); + + // Advances to the next fusion step. If current step is `fusion`, modifies the + // `module` accordingly. + void Advance(); + + private: + FusionProcessDump(FusionProcessDumpProto fusion_process_dump_proto, + std::unique_ptr hlo_module, + se::DeviceDescription device_info, + absl::flat_hash_map + instruction_name_to_computation_map) + : fusion_process_dump_proto_(std::move(fusion_process_dump_proto)), + hlo_module_(std::move(hlo_module)), + device_info_(std::move(device_info)), + instruction_name_to_computation_map_( + std::move(instruction_name_to_computation_map)) {} + + FusionProcessDumpProto fusion_process_dump_proto_; + std::unique_ptr hlo_module_; + se::DeviceDescription device_info_; + + // A map from instructions to computations. HLO module doesn't have a + // convenient way to get an instruction by name. This map saves the need to + // iterator over all computations in the module. + absl::flat_hash_map + instruction_name_to_computation_map_; + + // Index of the current step. + int64_t current_step_idx_ = 0; + + // Tracks result of the last fusion step. + HloInstruction* last_fusion_ = nullptr; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSION_PROCESS_DUMP_H_ diff --git a/xla/service/gpu/fusion_process_dump.proto b/xla/service/gpu/fusion_process_dump.proto index 0fc379441358b..0c52edb46c09e 100644 --- a/xla/service/gpu/fusion_process_dump.proto +++ b/xla/service/gpu/fusion_process_dump.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package xla.gpu; +import "xla/stream_executor/device_description.proto"; + message FusionStep { message Fusion { // Name of the resulting fusion. Can be the same as producer or consumer. @@ -46,4 +48,13 @@ message FusionStep { message FusionProcessDumpProto { repeated FusionStep fusion_steps = 1; + + stream_executor.GpuDeviceInfoProto gpu_device_info = 2; + + // HLO module before fusion in short parsable string format. The string + // represantation is compacter than HloModuleProto in this case, especially + // when the fusion process dump is stored as text proto. + // + // TODO: Consider using base64 or gzip to decrease the size of the string. + string hlo_module_before_fusion = 3; } diff --git a/xla/service/gpu/fusion_process_dump_test.cc b/xla/service/gpu/fusion_process_dump_test.cc new file mode 100644 index 0000000000000..37eb3bee29c83 --- /dev/null +++ b/xla/service/gpu/fusion_process_dump_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_process_dump.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace gpu { +namespace { + +using FusionProcessDumpTest = HloTestBase; + +void AddFusion(FusionProcessDumpProto& dump_proto, + const std::string& fusion_name, const std::string& producer_name, + const std::string& consumer_name) { + auto step = dump_proto.add_fusion_steps(); + auto fusion_step = step->mutable_fusion(); + fusion_step->set_fusion_name(fusion_name); + fusion_step->set_producer_name(producer_name); + fusion_step->set_consumer_name(consumer_name); +} + +TEST_F(FusionProcessDumpTest, MultipleFusionSteps) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule test_module + + ENTRY main { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + add = f32[] add(p0, p1) + subtract = f32[] subtract(p0, p1) + abs = f32[] abs(subtract) + ROOT multiply = f32[] multiply(add, abs) + })")); + + FusionProcessDumpProto dump_proto; + *dump_proto.mutable_gpu_device_info() = + TestGpuDeviceInfo::RTXA6000DeviceInfo().ToGpuProto(); + dump_proto.set_hlo_module_before_fusion( + module->ToString(HloPrintOptions::ShortParsable())); + + AddFusion(dump_proto, "fusion.1", "subtract", "abs"); + AddFusion(dump_proto, "fusion.2", "fusion.1", "multiply"); + AddFusion(dump_proto, "fusion.2", "add", "fusion.2"); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion_process_dump, + FusionProcessDump::LoadFromProto(dump_proto)); + + fusion_process_dump.Advance(); + fusion_process_dump.Advance(); + fusion_process_dump.Advance(); + + EXPECT_FALSE(fusion_process_dump.HasNext()); + + auto root = + fusion_process_dump.module()->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "fusion.2"); + ASSERT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); + EXPECT_THAT(root->fused_expression_root(), + GmockMatch(m::Multiply( + m::Add(m::Parameter(), m::Parameter()), + m::Abs(m::Subtract(m::Parameter(), m::Parameter()))))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusion_wrapper.cc b/xla/service/gpu/fusion_wrapper.cc index ba80f5ae33120..5b5120397a266 100644 --- a/xla/service/gpu/fusion_wrapper.cc +++ b/xla/service/gpu/fusion_wrapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,25 +17,27 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_fusible.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" +#include "xla/status.h" #include "tsl/platform/errors.h" namespace xla { namespace gpu { -StatusOr FusionWrapper::Run( +absl::StatusOr FusionWrapper::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { auto instructions = module->entry_computation()->MakeInstructionPostOrder(); bool changed = false; std::function handle_instruction; - handle_instruction = [&](HloInstruction* instruction) -> Status { + handle_instruction = [&](HloInstruction* instruction) -> absl::Status { switch (instruction->opcode()) { case HloOpcode::kConditional: case HloOpcode::kWhile: @@ -66,6 +68,7 @@ StatusOr FusionWrapper::Run( case HloOpcode::kDot: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -114,11 +117,14 @@ StatusOr FusionWrapper::Run( auto* fusion_instruction = computation->AddInstruction(HloInstruction::CreateFusion( instruction->shape(), - ChooseFusionKind(*instruction /*unused but required*/, - *instruction), - instruction)); - instruction->GetModule()->SetAndUniquifyInstrName( - fusion_instruction, absl::StrCat("wrapped_", instruction->name())); + ChooseFusionKind(*instruction, *instruction), instruction)); + const absl::string_view wrapped_opcode = + HloOpcodeString(instruction->opcode()); + module->SetAndUniquifyInstrName( + fusion_instruction, absl::StrCat("wrapped_", wrapped_opcode)); + module->SetAndUniquifyComputationName( + fusion_instruction->fused_instructions_computation(), + absl::StrCat("wrapped_", wrapped_opcode, "_computation")); if (module->has_schedule()) { module->schedule().replace_instruction(computation, instruction, fusion_instruction); @@ -134,7 +140,7 @@ StatusOr FusionWrapper::Run( default: break; } - return OkStatus(); + return absl::OkStatus(); }; for (auto* instruction : instructions) { diff --git a/xla/service/gpu/fusion_wrapper.h b/xla/service/gpu/fusion_wrapper.h index 3dd05047ce342..fc466925ce086 100644 --- a/xla/service/gpu/fusion_wrapper.h +++ b/xla/service/gpu/fusion_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSION_WRAPPER_H_ #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -31,7 +31,7 @@ class FusionWrapper : public HloModulePass { absl::string_view name() const override { return "fusion-wrapper"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/fusion_wrapper_test.cc b/xla/service/gpu/fusion_wrapper_test.cc index 6a5deaeb0075f..397fe754843b6 100644 --- a/xla/service/gpu/fusion_wrapper_test.cc +++ b/xla/service/gpu/fusion_wrapper_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusion_wrapper.h" +#include + #include #include "xla/tests/hlo_test_base.h" @@ -33,7 +35,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0} })", FusionWrapper(), R"( -// CHECK: %fused_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { +// CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { // CHECK: %param_0 = f16[30,41]{1,0} parameter(0) // CHECK: %param_1 = f16[30,41]{1,0} parameter(1) // CHECK: ROOT %result.1 = f16[60,41]{1,0} concatenate(%param_0, %param_1), dimensions={0} @@ -42,7 +44,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { // CHECK: ENTRY %TestComputation (p0: f16[30,41], p1: f16[30,41]) -> f16[60,41] { // CHECK: %p0 = f16[30,41]{1,0} parameter(0) // CHECK: %p1 = f16[30,41]{1,0} parameter(1) -// CHECK: ROOT %wrapped_result = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%fused_computation +// CHECK: ROOT %wrapped_concatenate = f16[60,41]{1,0} fusion(%p0, %p1), kind=kLoop, calls=%wrapped_concatenate_computation // CHECK: })"); } @@ -67,7 +69,7 @@ TEST_F(FusionWrapperTest, Scatter) { to_apply=update_s32 })", FusionWrapper(), R"( -// CHECK: fused_computation +// CHECK: wrapped_scatter_computation // CHECK: %[[param_0:.*]] = s32[] parameter(0) // CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1) // CHECK: %[[param_2:.*]] = s32[] parameter(2) @@ -77,7 +79,7 @@ TEST_F(FusionWrapperTest, Scatter) { // CHECK: %[[p0:.*]] = s32[] parameter(0) // CHECK: %[[p1:.*]] = s32[0]{0} parameter(1) // CHECK: %[[p2:.*]] = s32[] parameter(2) -// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%fused_computation +// CHECK: ROOT %{{.*}} = s32[] fusion(%[[p0]], %[[p1]], %[[p2]]), kind=kInput, calls=%wrapped_scatter_computation // CHECK: })"); } @@ -123,28 +125,28 @@ TEST_F(FusionWrapperTest, While) { ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body })", FusionWrapper(), R"( -// CHECK: %fused_computation.1 {{.*}} { +// CHECK: %wrapped_broadcast_computation {{.*}} { // CHECK: %param_0.1 = f32[] parameter(0) // CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={} // CHECK: } // CHECK: %body {{.*}} { // CHECK: %parameter.5 = (f32[5]{0}) parameter(0) // CHECK: %constant_8 = f32[] constant(0) -// CHECK: %wrapped_broadcast.9 = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast.9) +// CHECK: %wrapped_broadcast = f32[5]{0} fusion(%constant_8), kind=kLoop, calls=%wrapped_broadcast_computation +// CHECK: ROOT %tuple.2 = (f32[5]{0}) tuple(%wrapped_broadcast) // CHECK: } // CHECK: %cond {{.*}} { // CHECK: %parameter.12 = (f32[5]{0}) parameter(0) // CHECK: ROOT %constant_1 = pred[] constant(false) // CHECK: } -// CHECK: %fused_computation {{.*}} { +// CHECK: %wrapped_copy_computation {{.*}} { // CHECK: %param_0 = f32[5]{0} parameter(0) // CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0) // CHECK: } // CHECK: ENTRY %main {{.*}} { // CHECK: %parameter.1 = f32[5]{0} parameter(0) -// CHECK: %wrapped_copy.3 = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%fused_computation -// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy.3) +// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation +// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy) // CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body // CHECK: })"); } diff --git a/xla/service/gpu/fusions/BUILD b/xla/service/gpu/fusions/BUILD index 4c64b7309ae77..c542823f935fd 100644 --- a/xla/service/gpu/fusions/BUILD +++ b/xla/service/gpu/fusions/BUILD @@ -1,53 +1,225 @@ +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags") +load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tests:build_defs.bzl", "xla_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + cc_library( name = "in_place_dynamic_update_slice", srcs = ["in_place_dynamic_update_slice.cc"], hdrs = ["in_place_dynamic_update_slice.h"], deps = [ ":fusion_emitter", + "//xla:status", + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:launch_dimensions", "//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", ], ) +xla_cc_test( + name = "in_place_dynamic_update_slice_test", + srcs = ["in_place_dynamic_update_slice_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":fusions", + ":in_place_dynamic_update_slice", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "in_place_dynamic_update_slice_mlir", + srcs = ["in_place_dynamic_update_slice_mlir.cc"], + hdrs = ["in_place_dynamic_update_slice_mlir.h"], + deps = [ + "//xla:shape_util", + "//xla:status", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + ], +) + +xla_cc_test( + name = "in_place_dynamic_update_slice_mlir_test", + srcs = ["in_place_dynamic_update_slice_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":in_place_dynamic_update_slice_mlir", + ":mlir_emitter_test_base", + "//xla:error_spec", + "//xla/service:gpu_plugin", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + ], +) + cc_library( name = "copy", srcs = ["copy.cc"], hdrs = ["copy.h"], deps = [ ":fusion_emitter", - "//xla/service/gpu:gpu_executable", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:thunk", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "custom", + srcs = ["custom.cc"], + hdrs = ["custom.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":fusion_emitter", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_target_registry", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service/gpu/runtime:address_computation_thunk", + "//xla/service/gpu/runtime:custom_call_thunk", + "//xla/service/gpu/runtime:gemm_thunk", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) +xla_test( + name = "address_computation_fusion_test", + srcs = if_cuda_is_configured(["address_computation_fusion_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//xla:error_spec", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service/gpu:address_computation_fusion_rewriter", + "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + cc_library( name = "fusion_emitter", srcs = ["fusion_emitter.cc"], hdrs = ["fusion_emitter.h"], visibility = ["//xla/service/gpu:__subpackages__"], deps = [ + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", - "//xla/service:elemental_ir_emitter", - "//xla/service/gpu:gpu_executable", - "//xla/service/gpu:hlo_to_ir_bindings", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:kernel_arguments", "//xla/service/gpu:kernel_reuse_cache", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", - "//xla/service/gpu:thunk", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:thunk", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@tsl//tsl/platform:errors", @@ -61,23 +233,78 @@ cc_library( hdrs = ["fusions.h"], visibility = ["//xla/service/gpu:__subpackages__"], deps = [ + ":concatenate", + ":concatenate_mlir", ":copy", + ":cudnn", + ":custom", ":fusion_emitter", ":in_place_dynamic_update_slice", + ":in_place_dynamic_update_slice_mlir", ":input_slices", + ":input_slices_mlir", ":loop", + ":loop_mlir", ":reduction", + ":reduction_mlir", + ":scatter", + ":scatter_mlir", ":transpose", + ":transpose_mlir", + ":triton", "//xla:shape_util", + "//xla:status", + "//xla:statusor", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", - "//xla/service:elemental_ir_emitter", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "mlir_emitter_test_base", + testonly = True, + srcs = ["mlir_emitter_test_base.cc"], + hdrs = ["mlir_emitter_test_base.h"], + deps = [ + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:gpu_plugin", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:affine_map_printer", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:statusor", ], ) @@ -85,35 +312,353 @@ cc_library( name = "loop", srcs = ["loop.cc"], hdrs = ["loop.h"], - visibility = ["//xla/service/gpu:__subpackages__"], deps = [ ":fusion_emitter", + "//xla:shape_util", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "@com_google_absl//absl/numeric:bits", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "loop_mlir", + srcs = ["loop_mlir.cc"], + hdrs = ["loop_mlir.h"], + deps = [ + ":loop", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "loop_mlir_test", + srcs = ["loop_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":loop_mlir", + ":mlir_emitter_test_base", + "//xla:error_spec", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "scatter_mlir", + srcs = ["scatter_mlir.cc"], + hdrs = ["scatter_mlir.h"], + deps = [ + ":loop", + "//xla:shape_util", + "//xla:status", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:scatter_simplifier", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + ], +) + +xla_cc_test( + name = "scatter_mlir_test", + srcs = ["scatter_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":mlir_emitter_test_base", + ":scatter_mlir", + "//xla:error_spec", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "transpose_mlir", + srcs = ["transpose_mlir.cc"], + hdrs = ["transpose_mlir.h"], + deps = [ + ":tiling_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir/utils:type_util", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "transpose_mlir_test", + srcs = ["transpose_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":mlir_emitter_test_base", + ":transpose_mlir", + "//xla:error_spec", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "loop_test", + srcs = ["loop_test.cc"], + deps = [ + ":fusions", + ":loop", + "//xla:status_macros", + "//xla:statusor", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "scatter", + srcs = ["scatter.cc"], + hdrs = ["scatter.h"], + deps = [ + ":fusion_emitter", + ":loop", + "//xla:shape_util", + "//xla:status", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", - "//xla/service:elemental_ir_emitter", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", ], ) +xla_cc_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":fusions", + ":scatter", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "tiling_util", srcs = ["tiling_util.cc"], hdrs = ["tiling_util.h"], visibility = ["//xla/service/gpu:__subpackages__"], deps = [ - "//xla/service/gpu:hlo_fusion_analysis", + "//xla:shape_util", + "//xla:util", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:target_util", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "triton", + srcs = ["triton.cc"], + hdrs = ["triton.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":fusion_emitter", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:ir_emitter_triton", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:kernel_reuse_cache", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "triton_test", + srcs = ["triton_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":fusions", + ":triton", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "cudnn", + srcs = ["cudnn.cc"], + hdrs = ["cudnn.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":fusion_emitter", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/runtime:cudnn_thunk", + ], +) + +xla_test( + name = "cudnn_test", + srcs = if_cuda_is_configured(["cudnn_test.cc"]), + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + backends = [ + "gpu", + ], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/tests:filecheck", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", ], ) @@ -125,17 +670,14 @@ cc_library( deps = [ "//xla:literal", "//xla:shape_util", - "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service/gpu:gpu_executable", - "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:thunk", - "//xla/translate/hlo_to_mhlo:hlo_utils", + "//xla/service/gpu/runtime:memset_thunk", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", ], ) @@ -143,39 +685,49 @@ cc_library( name = "reduction", srcs = ["reduction.cc"], hdrs = ["reduction.h"], - visibility = ["//xla/service/gpu:__subpackages__"], deps = [ ":fusion_emitter", + ":reduction_base", ":thunk_util", ":tiling_util", "//xla:shape_util", "//xla:status", "//xla:status_macros", - "//xla:statusor", + "//xla:union_find", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", + "//xla/hlo/utils:hlo_query", "//xla/service:buffer_assignment", - "//xla/service:elemental_ir_emitter", - "//xla/service/gpu:gpu_executable", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:ir_emitter", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:kernel_arguments", "//xla/service/gpu:kernel_reuse_cache", + "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", "//xla/service/gpu:reduction_utils", "//xla/service/gpu:target_util", - "//xla/service/gpu:thunk", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:thunk", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", - "//xla/translate/mhlo_to_hlo:location_exporter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -189,27 +741,261 @@ cc_library( ], ) +cc_library( + name = "reduction_base", + srcs = ["reduction_base.cc"], + hdrs = ["reduction_base.h"], + deps = [ + ":fusion_emitter", + ":tiling_util", + "//xla:shape_util", + "//xla:union_find", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:reduction_utils", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "reduction_base_test", + srcs = ["reduction_base_test.cc"], + deps = [ + ":fusion_emitter", + ":reduction_base", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "reduction_mlir", + srcs = ["reduction_mlir.cc"], + hdrs = ["reduction_mlir.h"], + deps = [ + ":reduction_base", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:reduction_utils", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir:type_util", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationInterfaces", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_mlir_test", + srcs = ["reduction_mlir_test.cc"], + shard_count = 11, + tags = tf_cuda_tests_tags(), + deps = [ + ":mlir_emitter_test_base", + ":reduction_mlir", + "//xla:error_spec", + "//xla/service:gpu_plugin", + "//xla/tests:filecheck", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "concatenate", + srcs = ["concatenate.cc"], + hdrs = ["concatenate.h"], + deps = [ + ":fusion_emitter", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/model:indexing_map", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:loop_emitter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "concatenate_test", + srcs = ["concatenate_test.cc"], + deps = [ + ":concatenate", + ":fusions", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "concatenate_mlir", + srcs = ["concatenate_mlir.cc"], + hdrs = ["concatenate_mlir.h"], + deps = [ + ":concatenate", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "concatenate_mlir_test", + srcs = ["concatenate_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":concatenate", + ":concatenate_mlir", + ":mlir_emitter_test_base", + "//xla:error_spec", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + ], +) + cc_library( name = "transpose", srcs = ["transpose.cc"], hdrs = ["transpose.h"], - visibility = ["//xla/service/gpu:__subpackages__"], deps = [ ":fusion_emitter", ":tiling_util", "//xla:permutation_util", + "//xla:status", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:lhlo", - "//xla/service:elemental_ir_emitter", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "transpose_test", + srcs = ["transpose_test.cc"], + deps = [ + ":fusions", + ":transpose", + "//xla:status_macros", + "//xla:statusor", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", ], ) @@ -217,17 +1003,92 @@ cc_library( name = "input_slices", srcs = ["input_slices.cc"], hdrs = ["input_slices.h"], - visibility = ["//xla/service/gpu:__subpackages__"], deps = [ ":fusion_emitter", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:elemental_ir_emitter", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_util", + "//xla/service/llvm_ir:llvm_loop", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "input_slices_mlir", + srcs = ["input_slices_mlir.cc"], + hdrs = ["input_slices_mlir.h"], + deps = [ + "//xla:status", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/mlir:computation_partitioner", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "input_slices_mlir_test", + srcs = ["input_slices_mlir_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":input_slices_mlir", + ":mlir_emitter_test_base", + "//xla:error_spec", + "//xla/service:gpu_plugin", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + ], +) + +xla_cc_test( + name = "input_slices_test", + srcs = ["input_slices_test.cc"], + deps = [ + ":fusions", + ":input_slices", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/fusions/address_computation_fusion_test.cc b/xla/service/gpu/fusions/address_computation_fusion_test.cc new file mode 100644 index 0000000000000..323b95fbe8d63 --- /dev/null +++ b/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -0,0 +1,2741 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "xla/client/lib/constants.h" +#include "xla/client/xla_builder.h" +#include "xla/error_spec.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/address_computation_fusion_rewriter.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/stream.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +#define PLATFORM "CUDA" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#define PLATFORM "ROCM" +#endif + +#if GOOGLE_CUDA +#define gpuSuccess cudaSuccess +#define gpuMemcpyAsync cudaMemcpyAsync +#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice +#define gpuMemcpy cudaMemcpy +#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost +#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#elif TENSORFLOW_USE_ROCM +#define gpuSuccess hipSuccess +#define gpuMemcpyAsync hipMemcpyAsync +#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define gpuMemcpy hipMemcpy +#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost +#define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#endif + +namespace xla { +namespace gpu { +namespace { + +class AddressComputationFusionTest : public HloTestBase { + public: + HloModuleConfig GetModuleConfigWithoutCommandBuffer() { + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.clear_xla_gpu_enable_command_buffer(); + HloModuleConfig config; + config.set_debug_options(debug_options); + return config; + } +}; + +TEST_F(AddressComputationFusionTest, CublasGemmSimple) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %slice.13 = bf16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = bf16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %param_0_0 = bf16[2,8,8]{2,1,0} parameter(0) + %slice.13 = bf16[1,8,8]{2,1,0} slice(%param_0_0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = bf16[8,8]{1,0} bitcast(%slice.13) + %param_1_0 = bf16[2,8,8]{2,1,0} parameter(1) + %slice.14 = bf16[1,8,8]{2,1,0} slice(%param_1_0), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmWithWorkspace) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %param_0_0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%param_0_0), slice={[1:2], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %param_1_0 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} slice(%param_1_0), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[8,8]{1,0}, s8[256]{0}) tuple(%get-tuple-element.0, %get-tuple-element.1) + } + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + ROOT %fusion.2 = (f16[8,8]{1,0}, s8[256]{0}) fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, ContiguousSlice) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1), sharding={replicated} + %slice.13 = bf16[1,4,8]{2,1,0} slice(%p0), slice={[1:2], [0:4], [0:8]} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,1,8,8]{3,2,1,0} slice(%p1), slice={[0:1], [5:6], [2:10], [0:8]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %param_0_0 = bf16[2,8,8]{2,1,0} parameter(0) + %slice.13 = bf16[1,4,8]{2,1,0} slice(%param_0_0), slice={[1:2], [0:4], [0:8]} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %param_1_0 = bf16[8,8,10,8]{3,2,1,0} parameter(1) + %slice.14 = bf16[1,1,8,8]{3,2,1,0} slice(%param_1_0), slice={[0:1], [5:6], [2:10], [0:8]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1), sharding={replicated} + ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, ContiguousSliceNonDefaultLayout) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{1,2,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1), sharding={replicated} + %slice.13 = bf16[1,8,4]{1,2,0} slice(%p0), slice={[1:2], [0:8], [0:4]} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,8,8,1]{1,2,3,0} slice(%p1), slice={[0:1], [0:8], [2:10], [5:6]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %param_0_0 = bf16[2,8,8]{1,2,0} parameter(0) + %slice.13 = bf16[1,8,4]{1,2,0} slice(%param_0_0), slice={[1:2], [0:8], [0:4]} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %param_1_0 = bf16[8,8,10,8]{1,2,3,0} parameter(1) + %slice.14 = bf16[1,8,8,1]{1,2,3,0} slice(%param_1_0), slice={[0:1], [0:8], [2:10], [5:6]} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{1,2,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1), sharding={replicated} + ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, OperandIsSlicedGetTupleElement) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %p1.3 = f32[100,100]{1,0} parameter(1) + %slice.56 = f32[100,100]{1,0} slice(%p0.3), slice={[0:100], [0:100]} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %p1.3), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97, %get-tuple-element.240), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, ReversedOperandOrder) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f16[2,8,8]{2,1,0} parameter(0) + %slice.1 = f16[1,8,8]{2,1,0} slice(%p0.1), slice={[1:2], [0:8], [0:8]} + %bitcast.1 = f16[8,8]{1,0} bitcast(%slice.1) + %p1.1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.0 = f16[1,8,8]{2,1,0} slice(%p1.1), slice={[0:1], [0:8], [0:8]} + %bitcast.0 = f16[8,8]{1,0} bitcast(%slice.0) + ROOT %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.1, %bitcast.0), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + } + } + } + + ENTRY %main { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + ROOT %address_computation.6 = f16[8,8]{1,0} fusion(%p1, %p0), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, SingleOperandComputation) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} slice(%get-tuple-element.97), slice={[0:100], [0:100]} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %slice.56 = f32[100,100]{1,0} slice(%p0.3), slice={[0:100], [0:100]} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %slice.56), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, SlicedOperandAliasingOutput) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.30 = f32[100,100]{1,0} slice(%concatenate.12), slice={[20:120], [0:100]} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34), + custom_call_target="__cublas$gemm", + output_to_operand_aliasing={{0}: (2, {})}, + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f32[100,100]{1,0} parameter(0) + %p2 = f32[200,100]{1,0} parameter(2) + %slice.0 = f32[100,100]{1,0} slice(f32[200,100]{1,0} %p2), slice={[20:120], [0:100]} + %p1 = f32[100,100]{1,0} parameter(1) + %cublas-gemm.0 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%p0.1, %slice.0, %p1), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.0), index=0 + %get-tuple-element.1 = s8[120000]{0} get-tuple-element(%cublas-gemm.0), index=1 + ROOT %tuple = (f32[100,100]{1,0}, s8[120000]{0}) tuple(%get-tuple-element, %get-tuple-element.1) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.34 = f32[100,100]{1,0} slice(%concatenate.12), slice={[99:199], [0:100]} + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[120000]{0}) fusion(%get-tuple-element.287, %slice.34, %concatenate.12), + kind=kCustom, + calls=%address-computation, + output_to_operand_aliasing={{0}: (1, {})}, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, + ffi::Result dst) { + return stream->MemcpyD2D( + &dst->data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Ret() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + +TEST_F(AddressComputationFusionTest, CustomCallSimple) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +static absl::Status SubBuffers( + se::Stream* stream, ffi::BufferBase src0, ffi::BufferBase src1, + ffi::BufferBase src2, ffi::BufferBase src3, ffi::BufferBase src4, + ffi::BufferBase src5, ffi::BufferBase src6, ffi::BufferBase src7, + ffi::Result dst0, ffi::Result dst1, + ffi::Result dst2, ffi::Result dst3, + ffi::Result dst4, ffi::Result dst5, + ffi::Result dst6) { + // src0: param 0 at tuple index {0}, shape f32[128] + // src1: param 0 at tuple index {1}, shape f32[256] + // src2: param 1 at tuple index {0}, shape f32[1024] + // src3: param 1 at tuple index {1}, shape f32[8] + // src4: param 2, shape f32[4,8] + // src5: param 3 at tuple index {0, 0}, shape f32[32] + // src6: param 3 at tuple index {0, 1}, shape f32[64] + // src7: param 3 at tuple index {1}, shape f32[3,128] + // + // dst0: result at tuple index {0}, shape f32[8] + // dst1: result at tuple index {1, 0}, shape f32[128] + // dst2: result at tuple index {1, 1}, shape f32[256] + // dst3: result at tuple index {2}, shape f32[1024] + // dst4: result at tuple index {3}, shape f32[4,8] + // dst5: result at tuple index {4}, shape f32[3,128] + // dst6: result at tuple index {5}, shape f32[96] + + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst0->data, src3.data, 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst1->data, src0.data, 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst2->data, src1.data, 256 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst3->data, src2.data, 1024 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst4->data, src4.data, 4 * 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst5->data, src7.data, 3 * 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst6->data, src6.data, 64 * sizeof(float))); + stream_executor::DeviceMemoryBase slice = + dst6->data.GetByteSlice(64 * sizeof(float), 32 * sizeof(float)); + TF_RETURN_IF_ERROR(stream->MemcpyD2D(&slice, src6.data, 32 * sizeof(float))); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSubBuffers, SubBuffers, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src0 + .Arg() // src1 + .Arg() // src2 + .Arg() // src3 + .Arg() // src4 + .Arg() // src5 + .Arg() // src6 + .Arg() // src7 + .Ret() // dst0 + .Ret() // dst1 + .Ret() // dst2 + .Ret() // dst3 + .Ret() // dst4 + .Ret() // dst5 + .Ret() // dst6 +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers", + PLATFORM, kSubBuffers); + +TEST_F(AddressComputationFusionTest, CustomCallWithTuple) { + XlaBuilder b(TestName()); + CustomCall( + &b, "__xla_test$$subbuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple(&b, + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {32}), + Broadcast(ConstantR0WithType(&b, F32, 7), {64}), + }), + Slice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), + "p0"), + {1, 0}, {4, 128}, {1, 1}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeShape(F32, {3, 128}), + ShapeUtil::MakeShape(F32, {32 + 64}), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/true); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +static absl::Status NoOp(se::Stream* stream, ffi::BufferBase operand) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kNoOp, NoOp, + ffi::Ffi::Bind() + .Ctx() // stream + .Arg() // operand +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$noop", PLATFORM, + kNoOp); + +TEST_F(AddressComputationFusionTest, NilTuple) { + XlaBuilder b(TestName()); + CustomCall(&b, "__xla_test$$noop", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeNil(), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) { + void* src = buffers[0]; + void* dst = buffers[1]; + auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 3 * 128, + gpuMemcpyDeviceToDevice, stream); + ASSERT_EQ(err, gpuSuccess); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM); + +TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {512}), {128}, + {4 * 128}, {1})}, + ShapeUtil::MakeShape(F32, {3 * 128}), /*opaque=*/""); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +void Callback_Void(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/) {} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Void, PLATFORM); + +TEST_F(AddressComputationFusionTest, NilTupleLegacyAPI) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Void", /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, + {128}, {1})}, + ShapeUtil::MakeNil(), + /*opaque=*/""); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDynamic) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY main.9 { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = bf16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = bf16[8,8]{1,0} bitcast(slice.13) + slice.14 = bf16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = bf16[8,8]{1,0} bitcast(slice.14) + + ROOT custom-call.1 = bf16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + fused_computation { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] parameter(2) + c0_s32 = s32[] parameter(3) + slice.13 = bf16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = bf16[8,8]{1,0} bitcast(slice.13) + slice.14 = bf16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = bf16[8,8]{1,0} bitcast(slice.14) + + ROOT custom-call.1 = bf16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY main.9 { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + ROOT fusion.2 = bf16[8,8]{1,0} fusion(p0, p1, c1_s32, c0_s32), kind=kCustom, calls=fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDynamicWithWorkspace) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %c1_s32 = s32[] parameter(2) + %c0_s32 = s32[] parameter(3) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[8,8]{1,0}, s8[256]{0}) tuple(%get-tuple-element.0, %get-tuple-element.1) + } + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + ROOT %fusion.2 = (f16[8,8]{1,0}, s8[256]{0}) fusion(%p0, %p1, %c1_s32, %c0_s32), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicContiguousSlice) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %c2_s32 = s32[] constant(2) + %c5_s32 = s32[] constant(5) + %slice.13 = bf16[1,4,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,4,8} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,1,8,8]{3,2,1,0} dynamic-slice(%p1, %c1_s32, %c5_s32, %c2_s32, %c0_s32), dynamic_slice_sizes={1,1,8,8} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = bf16[2,8,8]{2,1,0} parameter(0) + %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1) + %c1_s32 = s32[] parameter(2) + %c0_s32 = s32[] parameter(3) + %c2_s32 = s32[] parameter(4) + %c5_s32 = s32[] parameter(5) + %slice.13 = bf16[1,4,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,4,8} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,1,8,8]{3,2,1,0} dynamic-slice(%p1, %c1_s32, %c5_s32, %c2_s32, %c0_s32), dynamic_slice_sizes={1,1,8,8} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %c2_s32 = s32[] constant(2) + %c5_s32 = s32[] constant(5) + ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1, %c1_s32, %c0_s32, %c2_s32, %c5_s32), kind=kCustom, + calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicContiguousSliceNonDefaultLayout) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{1,2,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %c2_s32 = s32[] constant(2) + %c5_s32 = s32[] constant(5) + %slice.13 = bf16[1,8,4]{1,2,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,4} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,8,8,1]{1,2,3,0} dynamic-slice(%p1, %c0_s32, %c0_s32, %c2_s32, %c5_s32), dynamic_slice_sizes={1,8,8,1} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = bf16[2,8,8]{1,2,0} parameter(0) + %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1) + %c1_s32 = s32[] parameter(2) + %c0_s32 = s32[] parameter(3) + %c2_s32 = s32[] parameter(4) + %c5_s32 = s32[] parameter(5) + %slice.13 = bf16[1,8,4]{1,2,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,4} + %bitcast.41 = bf16[4,8]{1,0} bitcast(%slice.13) + %slice.14 = bf16[1,8,8,1]{1,2,3,0} dynamic-slice(%p1, %c0_s32, %c0_s32, %c2_s32, %c5_s32), dynamic_slice_sizes={1,8,8,1} + %bitcast.42 = bf16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = bf16[4,8]{1,0} custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + } + + ENTRY %main.9 { + %p0 = bf16[2,8,8]{1,2,0} parameter(0), sharding={replicated} + %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1), sharding={replicated} + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %c2_s32 = s32[] constant(2) + %c5_s32 = s32[] constant(5) + ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1, %c1_s32, %c0_s32, %c2_s32, %c5_s32), kind=kCustom, + calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicOperandIsSlicedGetTupleElement) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c0_s32 = s32[] constant(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} dynamic-slice(%get-tuple-element.97, %c0_s32, %c0_s32), dynamic_slice_sizes={100,100} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %p1.3 = f32[100,100]{1,0} parameter(1) + %c0_s32 = s32[] parameter(2) + %slice.56 = f32[100,100]{1,0} dynamic-slice(%p0.3, %c0_s32, %c0_s32), dynamic_slice_sizes={100,100} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %p1.3), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c0_s32 = s32[] constant(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97, %get-tuple-element.240, %c0_s32), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicReversedOperandOrder) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %c0_s32 = s32[] constant(0) + %c1_s32 = s32[] constant(1) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c0_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + ROOT %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.42, %bitcast.41), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f16[2,8,8]{2,1,0} parameter(0) + %p1.1 = f16[2,8,8]{2,1,0} parameter(1) + %c0_s32 = s32[] parameter(2) + %c1_s32 = s32[] parameter(3) + %slice.1 = f16[1,8,8]{2,1,0} dynamic-slice(%p0.1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.1 = f16[8,8]{1,0} bitcast(%slice.1) + %slice.0 = f16[1,8,8]{2,1,0} dynamic-slice(%p1.1, %c0_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.0 = f16[8,8]{1,0} bitcast(%slice.0) + ROOT %custom-call.0 = f16[8,8]{1,0} custom-call(%bitcast.1, %bitcast.0), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + } + } + } + + ENTRY %main { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %c0_s32 = s32[] constant(0) + %c1_s32 = s32[] constant(1) + ROOT %address_computation.6 = f16[8,8]{1,0} fusion(%p1, %p0, %c0_s32, %c1_s32), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicSingleOperandComputation) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c0_s32 = s32[] constant(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + %slice.26 = f32[100,100]{1,0} dynamic-slice(%get-tuple-element.97, %c0_s32, %c0_s32), dynamic_slice_sizes={100,100} + ROOT %custom-call.17 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.26, %slice.26), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.3 = f32[200,100]{1,0} parameter(0) + %c0_s32 = s32[] parameter(1) + %slice.56 = f32[100,100]{1,0} dynamic-slice(%p0.3, %c0_s32, %c0_s32), dynamic_slice_sizes={100,100} + %cublas-gemm.23 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%slice.56, %slice.56), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.221 = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.23), index=0 + %get-tuple-element.222 = s8[80000]{0} get-tuple-element(%cublas-gemm.23), index=1 + ROOT %tuple.58 = (f32[100,100]{1,0}, s8[80000]{0}) tuple(%get-tuple-element.221, %get-tuple-element.222) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c0_s32 = s32[] constant(0) + %get-tuple-element.240 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.241 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.10 = f32[200,100]{1,0} concatenate(%get-tuple-element.240, %get-tuple-element.241), dimensions={0} + %custom-call.16 = (f32[200,100]{1,0}, s8[120000]{0}) custom-call(%concatenate.10, %get-tuple-element.240), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"20000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element.97 = f32[200,100]{1,0} get-tuple-element(%custom-call.16), index=0 + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[80000]{0}) fusion(%get-tuple-element.97, %c0_s32), + kind=kCustom, + calls=%address-computation, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicSlicedOperandAliasingOutput) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c20_s32 = s32[] constant(20) + %c99_s32 = s32[] constant(99) + %c0_s32 = s32[] constant(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.30 = f32[100,100]{1,0} dynamic-slice(%concatenate.12, %c20_s32, %c0_s32), dynamic_slice_sizes={100,100} + %slice.34 = f32[100,100]{1,0} dynamic-slice(%concatenate.12, %c99_s32, %c0_s32), dynamic_slice_sizes={100,100} + ROOT %cublas-gemm.15 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%get-tuple-element.287, %slice.30, %slice.34), + custom_call_target="__cublas$gemm", + output_to_operand_aliasing={{0}: (2, {})}, + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + }} + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %address-computation { + %p0.1 = f32[100,100]{1,0} parameter(0) + %p1 = f32[100,100]{1,0} parameter(1) + %p2 = f32[200,100]{1,0} parameter(2) + %c0_s32 = s32[] parameter(3) + %c20_s32 = s32[] parameter(4) + %slice.0 = f32[100,100]{1,0} dynamic-slice(f32[200,100]{1,0} %p2, %c20_s32, %c0_s32), dynamic_slice_sizes={100,100} + %cublas-gemm.0 = (f32[100,100]{1,0}, s8[120000]{0}) custom-call(%p0.1, %slice.0, %p1), + custom_call_target="__cublas$gemm", + backend_config={ + "gemm_backend_config":{ + "alpha_real":1, + "beta":1, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["HIGHEST","HIGHEST"]}, + "epilogue":"DEFAULT", + "lhs_stride":"10000", + "rhs_stride":"10000", + "grad_x":false, + "grad_y":false + } + } + %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%cublas-gemm.0), index=0 + %get-tuple-element.1 = s8[120000]{0} get-tuple-element(%cublas-gemm.0), index=1 + ROOT %tuple = (f32[100,100]{1,0}, s8[120000]{0}) tuple(%get-tuple-element, %get-tuple-element.1) + } + + ENTRY %main { + %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) + %c20_s32 = s32[] constant(20) + %c99_s32 = s32[] constant(99) + %c0_s32 = s32[] constant(0) + %get-tuple-element.287 = f32[100,100]{1,0} get-tuple-element(%p0), index=0 + %get-tuple-element.288 = f32[100,100]{1,0} get-tuple-element(%p0), index=1 + %concatenate.12 = f32[200,100]{1,0} concatenate(%get-tuple-element.287, %get-tuple-element.288), dimensions={0} + %slice.34 = f32[100,100]{1,0} dynamic-slice(%concatenate.12, %c99_s32, %c0_s32), dynamic_slice_sizes={100,100} + ROOT %address_computation.6 = (f32[100,100]{1,0}, s8[120000]{0}) fusion(%get-tuple-element.287, %slice.34, %concatenate.12, %c0_s32, %c20_s32), + kind=kCustom, + calls=%address-computation, + output_to_operand_aliasing={{0}: (1, {})}, + backend_config={ + "fusion_backend_config":{ + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} + } + } + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDUS) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY main.9 { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + p2 = bf16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + slice.13 = bf16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = bf16[8,8]{1,0} bitcast(slice.13) + slice.14 = bf16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = bf16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = bf16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = bf16[1,8,8]{2,1,0} bitcast(custom-call.1) + ROOT dus = bf16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + fused_computation { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + p2 = bf16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] parameter(3) + c0_s32 = s32[] parameter(4) + slice.13 = bf16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.41 = bf16[8,8]{1,0} bitcast(slice.13) + slice.14 = bf16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8} + bitcast.42 = bf16[8,8]{1,0} bitcast(slice.14) + + custom-call.1 = bf16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = bf16[1,8,8]{2,1,0} bitcast(custom-call.1) + ROOT dus = bf16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32) + } + + ENTRY main.9 { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[2,8,8]{2,1,0} parameter(1) + p2 = bf16[4,8,8]{2,1,0} parameter(2) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + ROOT fusion.2 = bf16[4,8,8]{2,1,0} fusion(p0, p1, p2, c1_s32, c0_s32), kind=kCustom, calls=fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + // The GEMM custom call does not have a workspace, shouldn't be run in command + // buffer. + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithoutCommandBuffer(), + GetModuleConfigWithoutCommandBuffer(), error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDUSWithWorkspace) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] parameter(3) + %c0_s32 = s32[] parameter(4) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + } + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + ROOT %fusion.2 = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion(%p0, %p1, %p2, %c1_s32, %c0_s32), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDUSWorkspaceIgnored) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[8,8]{1,0} parameter(0) + %p1 = f16[8,8]{1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = f16[8,8]{1,0} parameter(0) + %p1 = f16[8,8]{1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] parameter(3) + %c0_s32 = s32[] parameter(4) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%p0, %p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + ROOT %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + } + + ENTRY %main.9 { + %p0 = f16[8,8]{1,0} parameter(0) + %p1 = f16[8,8]{1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] constant(1) + %c0_s32 = s32[] constant(0) + ROOT %fusion.2 = f16[4,8,8]{2,1,0} fusion(%p0, %p1, %p2, %c1_s32, %c0_s32), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDUSOffsetS32NotConstant) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] parameter(3) + %c0_s32 = s32[] parameter(4) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] parameter(3) + %c0_s32 = s32[] parameter(4) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + } + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s32[] parameter(3) + %c0_s32 = s32[] parameter(4) + ROOT %fusion.2 = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion(%p0, %p1, %p2, %c1_s32, %c0_s32), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CublasGemmDUSOffsetOOB) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule jit_slice + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s64[] constant(10) + %c0_s32 = s64[] constant(-1) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + %fused_computation { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s64[] parameter(3) + %c0_s32 = s64[] parameter(4) + %slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(%p0, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) + %slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(%p1, %c1_s32, %c0_s32, %c0_s32), dynamic_slice_sizes={1,8,8} + %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) + + %custom-call.1 = (f16[8,8]{1,0}, s8[256]{0}) custom-call(%bitcast.41, %bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + %get-tuple-element.0 = f16[8,8]{1,0} get-tuple-element(%custom-call.1), index=0 + %bitcast.43 = f16[1,8,8]{2,1,0} bitcast(%get-tuple-element.0) + %dus = f16[4,8,8]{2,1,0} dynamic-update-slice(%p2, %bitcast.43, %c1_s32, %c0_s32, %c0_s32) + %get-tuple-element.1 = s8[256]{0} get-tuple-element(%custom-call.1), index=1 + ROOT %tuple = (f16[4,8,8]{2,1,0}, s8[256]{0}) tuple(%dus, %get-tuple-element.1) + } + + ENTRY %main.9 { + %p0 = f16[2,8,8]{2,1,0} parameter(0) + %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p2 = f16[4,8,8]{2,1,0} parameter(2) + %c1_s32 = s64[] constant(10) + %c0_s32 = s64[] constant(-1) + ROOT %fusion.2 = (f16[4,8,8]{2,1,0}, s8[256]{0}) fusion(%p0, %p1, %p2, %c1_s32, %c0_s32), kind=kCustom, calls=%fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicCustomCallSimple) { + XlaBuilder b(TestName()); + CustomCall( + &b, "__xla_test$$memcpy", + /*operands=*/ + {DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), + {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"), + Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")}, + {2, 128})}, + ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicCustomCallWithTuple) { + XlaBuilder b(TestName()); + CustomCall( + &b, "__xla_test$$subbuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple(&b, + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {32}), + Broadcast(ConstantR0WithType(&b, F32, 7), {64}), + }), + DynamicSlice( + Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), + "p0"), + {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), + "start0"), + Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), + "start1")}, + {3, 128}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeShape(F32, {3, 128}), + ShapeUtil::MakeShape(F32, {32 + 64}), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/true); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +static absl::Status SubBuffers2( + se::Stream* stream, ffi::BufferBase src0, ffi::BufferBase src1, + ffi::BufferBase src2, ffi::BufferBase src3, ffi::BufferBase src4, + ffi::BufferBase src5, ffi::BufferBase src6, + ffi::Result dst0, ffi::Result dst1, + ffi::Result dst2, ffi::Result dst3, + ffi::Result dst4, ffi::Result dst5, + ffi::Result dst6) { + // src0: param 0 at tuple index {0}, shape f32[128] + // src1: param 0 at tuple index {1}, shape f32[256] + // src2: param 1 at tuple index {0}, shape f32[1024] + // src3: param 1 at tuple index {1}, shape f32[8] + // src4: param 2, shape f32[4,8] + // src5: param 3 at tuple index {0, 0}, shape f32[3,128] + // src6: param 3 at tuple index {0, 1}, shape f32[5,128] + // + // dst0: result at tuple index {0}, shape f32[8] + // dst1: result at tuple index {1, 0}, shape f32[128] + // dst2: result at tuple index {1, 1}, shape f32[256] + // dst3: result at tuple index {2}, shape f32[1024] + // dst4: result at tuple index {3}, shape f32[4,8] + // dst5: result at tuple index {4, 0}, shape f32[5,128] + // dst6: result at tuple index {4, 1}, shape f32[3,128] + + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst0->data, src3.data, 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst1->data, src0.data, 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst2->data, src1.data, 256 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst3->data, src2.data, 1024 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst4->data, src4.data, 4 * 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst5->data, src6.data, 5 * 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst6->data, src5.data, 3 * 128 * sizeof(float))); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSubBuffers2, SubBuffers2, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src0 + .Arg() // src1 + .Arg() // src2 + .Arg() // src3 + .Arg() // src4 + .Arg() // src5 + .Arg() // src6 + .Ret() // dst0 + .Ret() // dst1 + .Ret() // dst2 + .Ret() // dst3 + .Ret() // dst4 + .Ret() // dst5 + .Ret() // dst6 +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers2", + PLATFORM, kSubBuffers2); + +TEST_F(AddressComputationFusionTest, CustomCallDUS) { + XlaBuilder b(TestName()); + auto custom_call = + CustomCall(&b, "Callback_Memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {10, 128}), + {2, 0}, {5, 128}, {1, 1})}, + ShapeUtil::MakeShape(F32, {3, 128}), /*opaque=*/""); + + DynamicUpdateSlice( + Broadcast(ConstantR0WithType(&b, F32, 92.0), {10, 128}), custom_call, + {ConstantR0WithType(&b, S32, 4), ConstantR0WithType(&b, S32, 0)}); + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CustomCallDUSTuple) { + XlaBuilder b(TestName()); + auto big_buffer1 = + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 128}), "p0"); + auto big_buffer2 = + Parameter(&b, 1, ShapeUtil::MakeShape(F32, {10, 256}), "p1"); + auto custom_call = CustomCall( + &b, "__xla_test$$subbuffers2", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple( + &b, + { + Tuple( + &b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {3, 128}), + DynamicSlice(Broadcast(ConstantR0WithType(&b, F32, 7), + {8, 128}), + {ConstantR0WithType(&b, S32, 2), + ConstantR0WithType(&b, S32, 0)}, + {5, 128}), + }), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {5, 128}), + ShapeUtil::MakeShape(F32, {3, 128}), + }), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto tuple_gte = GetTupleElement(custom_call, 4); + auto dus1 = DynamicUpdateSlice( + big_buffer1, GetTupleElement(tuple_gte, 0), + {ConstantR0WithType(&b, S32, 2), ConstantR0WithType(&b, S32, 0)}); + auto dus2 = DynamicUpdateSlice( + big_buffer1, GetTupleElement(tuple_gte, 1), + {ConstantR0WithType(&b, S32, 7), ConstantR0WithType(&b, S32, 0)}); + auto dus3 = DynamicUpdateSlice( + big_buffer2, + xla::internal::XlaBuilderFriend::BuildBitcast( + &b, GetTupleElement(custom_call, 2), + ShapeUtil::MakeShape(F32, {4, 256})), + {Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start0"), + Parameter(&b, 3, ShapeUtil::MakeShape(S32, {}), "start1")}); + Tuple(&b, {dus1, dus2, dus3}); + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/concatenate.cc b/xla/service/gpu/fusions/concatenate.cc new file mode 100644 index 0000000000000..f0f2eee6e16f1 --- /dev/null +++ b/xla/service/gpu/fusions/concatenate.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/concatenate.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/parallel_loop_emitter.h" +#include "xla/service/llvm_ir/fused_ir_emitter.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) { + const HloInstruction* concat = analysis.fusion_heroes().front(); + int64_t dim = concat->concatenate_dimension(); + auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim); + }; + HloInstruction* operand = *absl::c_max_element(concat->operands(), less); + return operand->shape(); +} + +ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis) {} + +std::optional ConcatenateFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + return std::nullopt; +} + +std::optional ConcatenateFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1, + GetLargestConcatOperandShape(analysis_), + ctx); +} + +absl::Status ConcatenateFusion::EmitKernel( + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, std::vector inputs, + std::vector outputs, llvm::IRBuilder<>* builder) const { + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); + FusedIrEmitter fused_emitter(elemental_emitter); + for (int i = 0; i < fusion.fused_parameters().size(); i++) { + fused_emitter.BindGenerator( + *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) { + return inputs[i].EmitReadArrayElement(index, builder); + }); + } + + llvm::Type* index_type = + GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); + + const HloInstruction* concat = analysis_.fusion_heroes().front(); + int64_t concat_dim = concat->concatenate_dimension(); + int64_t operand_offset = 0; + + // Emit the slices that correspond to the operands of the concat hero. + for (const HloInstruction* operand : concat->operands()) { + llvm_ir::BodyEmitter body_emitter = + [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status { + // Bind concat to generate the current operand. + TF_ASSIGN_OR_RETURN(auto operand_generator, + fused_emitter.GetGenerator(*operand)); + fused_emitter.BindGenerator(*concat, [&](llvm_ir::IrArray::Index) { + return operand_generator(operand_index); + }); + + // Create the index of the slice corresponding to the current operand. + llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim( + llvm::ConstantInt::get(index_type, operand_offset), concat_dim, + builder); + operand_offset += operand->shape().dimensions(concat_dim); + + // Generate and write out the slice for each root. + for (const auto& [output, root] : + llvm::zip_equal(outputs, analysis_.fusion_roots())) { + llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast( + concat->shape(), root->shape(), builder); + TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(*root)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index)); + output.EmitWriteArrayElement(root_index, value, builder); + } + return absl::OkStatus(); + }; + + ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims, + builder); + TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type)); + } + + return absl::OkStatus(); +} + +LaunchDimensions ConcatenateFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_), + analysis_.device_info()); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/concatenate.h b/xla/service/gpu/fusions/concatenate.h new file mode 100644 index 0000000000000..ec9349c958941 --- /dev/null +++ b/xla/service/gpu/fusions/concatenate.h @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ +#define XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" + +namespace xla { +namespace gpu { + +const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis); + +// Emits a kernel for the given hlo instruction where each thread produces +// one element of each concat operand. +class ConcatenateFusion : public KernelFusionEmitterBase { + public: + explicit ConcatenateFusion(const HloFusionAnalysis& analysis); + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; + + private: + const HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ diff --git a/xla/service/gpu/fusions/concatenate_mlir.cc b/xla/service/gpu/fusions/concatenate_mlir.cc new file mode 100644 index 0000000000000..e6091085c2e2c --- /dev/null +++ b/xla/service/gpu/fusions/concatenate_mlir.cc @@ -0,0 +1,157 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/concatenate_mlir.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { + +using llvm::SmallVector; +using mlir::Value; +using mlir::ValueRange; + +LaunchDimensions MlirConcatenateFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_), + analysis_.device_info()); +} + +std::optional +MlirConcatenateFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + return std::nullopt; +} + +std::optional +MlirConcatenateFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + // TODO(b/331356433): Add constraints depending on the `hero_operand_index`. + return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1, + GetLargestConcatOperandShape(analysis_), + ctx); +} + +std::vector +MlirConcatenateFusion::GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return analysis_.fusion_heroes(); +} + +absl::Status MlirConcatenateFusion::EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + const auto& root_computation = computations.FindPartitionedComputation( + fusion.fused_instructions_computation()); + const auto* concat = analysis_.fusion_heroes()[0]; + mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + builder.setInsertionPointToStart(entry_function.addEntryBlock()); + auto* ctx = entry_function.getContext(); + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + SmallVector input_tensors( + entry_function.getArguments().take_front(num_inputs)); + auto output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + + SmallVector result_tensors{output_tensor_args.begin(), + output_tensor_args.end()}; + + auto thread_id_to_input_map = + ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, ctx) + .value(); + auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing(concat, ctx); + + for (auto [operand_index, operand] : llvm::enumerate(concat->operands())) { + auto input_to_output_map = + *ComputeInputToOutputIndexing(concat, /*input_id=*/operand_index, ctx) + .indexing_maps.front() + .begin(); + auto thread_id_to_output_map = ComposeIndexingMaps( + ComposeIndexingMaps(thread_id_to_input_map, input_to_output_map), + epilogue_indexing); + + auto loop_nest_body_builder = + [&, operand_index = operand_index]( + ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto input_indices = + mlir_converter::ApplyAffineMap(thread_id_to_input_map.GetAffineMap(), + dim_values, symbol_values, builder); + + auto result_scalars = mlir_converter::ProvideParameter( + root_computation.FindSubgraph(concat), concat, operand_index, + input_indices, call_targets, entry_function, builder); + auto output_indices = + mlir_converter::ApplyAffineMap(thread_id_to_output_map.GetAffineMap(), + dim_values, symbol_values, builder); + result_scalars = EmitEpilogue(computations, entry_function, + result_scalars, output_indices, builder); + + SmallVector result_tensors; + result_tensors.reserve(output_tensor_args.size()); + for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { + result_tensors.push_back( + builder + .create(value, tensor, output_indices) + .getResult()); + } + + return result_tensors; + }; + + result_tensors = + EmitThreadLoopNest(builder, result_tensors, thread_id_to_output_map, + loop_nest_body_builder); + } + + builder.create(result_tensors); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/concatenate_mlir.h b/xla/service/gpu/fusions/concatenate_mlir.h new file mode 100644 index 0000000000000..f14606a1073f4 --- /dev/null +++ b/xla/service/gpu/fusions/concatenate_mlir.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_CONCATENATE_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_CONCATENATE_MLIR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +class MlirConcatenateFusion : public MlirFusionEmitterBase { + public: + explicit MlirConcatenateFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis) {} + + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_CONCATENATE_MLIR_H_ diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc new file mode 100644 index 0000000000000..515d385408018 --- /dev/null +++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -0,0 +1,246 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/concatenate_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/lib/core/status_test_util.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirConcatenateFusionTest = MlirEmitterTestBase; + +TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} + } + ENTRY main { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT fusion = f32[900] fusion(param0, param1, param2), + calls=fused_computation, kind=kLoop + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirConcatenateFusion fusion(analysis); + + constexpr auto kIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + (th_x + bl_x * 128) mod 400) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 3] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 399] + )"; + auto thread_id_to_output_indexing_0 = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing_0->ToString(thread_id_printer_), + MatchIndexingString(kIndexing)); + auto thread_id_to_output_indexing_1 = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing_1->ToString(thread_id_printer_), + MatchIndexingString(kIndexing)); + auto thread_id_to_output_indexing_2 = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing_2->ToString(thread_id_printer_), + MatchIndexingString(kIndexing)); +} + +TEST_F(MlirConcatenateFusionTest, StandAloneConcatenate) { + auto kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} + } + ENTRY main { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT fusion = f32[900] fusion(param0, param1, param2), + calls=fused_computation, kind=kLoop + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 128)> + // CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 128) mod 400)> + // CHECK-DAG: #[[MAP_3:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 128) mod 400 + 200)> + // CHECK-DAG: #[[MAP_4:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 128) mod 400 + 600)> + + // CHECK-LABEL: fused_computation + // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, + // CHECK-SAME: %[[ARG_1:[a-zA-Z0-9]*]]: {{[^,]*}}, + // CHECK-SAME: %[[ARG_2:[a-zA-Z0-9]*]]: {{[^,]*}}, + // CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: {{[^,]*}} + + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x + // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x + + // CHECK: %[[INPUT_INDEX_1:.*]] = affine.apply #[[MAP_2]]()[%[[THREAD_ID]], %[[BLOCK_ID]]] + // CHECK: %[[IF_1:.*]] = scf.if + // CHECK: %[[VAL_1:.*]] = xla_gpu.pure_call @fused_computation_param0 + // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[VAL_1:.*]] into %[[OUTPUT]][%[[INPUT_INDEX_1]]] + + // CHECK: %[[IF_2:.*]] = scf.if + // CHECK: %[[VAL_2:.*]] = xla_gpu.pure_call @fused_computation_param1 + // CHECK: %[[OUTPUT_INDEX_2:.*]] = affine.apply #[[MAP_3]]()[%[[THREAD_ID]], %[[BLOCK_ID]]] + // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[VAL_2:.*]] into {{.*}}[%[[OUTPUT_INDEX_2]]] + + // CHECK: %[[IF_3:.*]] = scf.if + // CHECK: %[[VAL_3:.*]] = xla_gpu.pure_call @fused_computation_param2 + // CHECK: %[[OUTPUT_INDEX_3:.*]] = affine.apply #[[MAP_4]]()[%[[THREAD_ID]], %[[BLOCK_ID]]] + // CHECK: %[[INSERTED_3:.*]] = tensor.insert %[[VAL_3:.*]] into {{.*}}[%[[OUTPUT_INDEX_3]]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirConcatenateFusionTest, PrologueEpilogue) { + auto kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[64] parameter(0) + param1 = f32[128] parameter(1) + log = f32[64] log(param0) + exp = f32[128] exponential(param1) + concat = f32[192] concatenate(log, exp), dimensions={0} + ROOT neg = f32[192] negate(concat) + } + ENTRY main { + param0 = f32[64] parameter(0) + param1 = f32[128] parameter(1) + ROOT fusion = f32[192] fusion(param0, param1), calls=fused_computation, kind=kLoop + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 64)> + + // CHECK-LABEL: fused_computation + // CHECK-DAG: %[[C_63:.*]] = arith.constant 63 + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x + + // CHECK: %[[IN_BOUND_1:.*]] = arith.cmpi sle, %[[THREAD_ID:.*]], %[[C_63]] + // CHECK: %[[IF_1:.*]] = scf.if %[[IN_BOUND_1]] + // CHECK: %[[VAL_1_1:.*]] = xla_gpu.pure_call @fused_computation_log({{.*}}, %[[THREAD_ID]]) + // CHECK: %[[VAL_1_2:.*]] = xla_gpu.pure_call @fused_computation__epilogue__({{.*}}, %[[THREAD_ID]], %[[VAL_1_1]]) + // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[VAL_1_2:.*]] into {{.*}}[%[[THREAD_ID]]] + // CHECK: scf.yield %[[INSERTED_1]] + + // CHECK: %[[VAL_2_1:.*]] = xla_gpu.pure_call @fused_computation_exp({{.*}}, %[[THREAD_ID]]) + // CHECK: %[[INDEX_2:.*]] = affine.apply #[[MAP]]()[%[[THREAD_ID]]] + // CHECK: %[[VAL_2_2:.*]] = xla_gpu.pure_call @fused_computation__epilogue__({{.*}}, %[[INDEX_2]], %[[VAL_2_1]]) + // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[VAL_2_2:.*]] into {{.*}}[%[[INDEX_2]]] + + // CHECK: return %[[INSERTED_2]] + + // CHECK: func.func private @fused_computation_log + // CHECK: func.func private @fused_computation_exp + // CHECK: func.func private @fused_computation_neg + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirConcatenateFusionTest, EpilogueSideParameter) { + auto kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[64] parameter(0) + param1 = f32[192] parameter(1) + neg = f32[64] negate(param0) + slice = f32[128] slice(param1), slice={[32:160]} + exp = f32[128] exponential(slice) + concat = f32[192] concatenate(neg, exp), dimensions={0} + ROOT add = f32[192] add(concat, param1) + } + ENTRY main { + param0 = f32[64] parameter(0) + param1 = f32[192] parameter(1) + ROOT fusion = f32[192] fusion(param0, param1), calls=fused_computation, kind=kLoop + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirConcatenateFusionTest, MajorDimension) { + auto kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[16,16] parameter(0) + param1 = f32[16,16] parameter(1) + ROOT concat = f32[32,16] concatenate(param0, param1), dimensions={0} + } + ENTRY main { + param0 = f32[16,16] parameter(0) + param1 = f32[16,16] parameter(1) + ROOT %fusion = f32[32,16] fusion(param0, param1), kind=kInput, calls=fused_computation + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirConcatenateFusionTest, EpilogueBitcast) { + auto kHloString = R"( + HloModule Test + + fused_computation { + p0 = pred[1] parameter(0) + p1 = pred[1] parameter(1) + p2 = pred[1] parameter(2) + %concatenate.3.3 = pred[3] concatenate(p0, p1, p2), dimensions={0} + %bitcast.57.1 = pred[1,1,3]{2,1,0} bitcast(pred[3]{0} %concatenate.3.3) + ROOT %convert.36.1 = u32[1,1,3] convert(pred[1,1,3]{2,1,0} %bitcast.57.1) + } + + ENTRY main { + p0 = pred[1] parameter(0) + p1 = pred[1] parameter(1) + p2 = pred[1] parameter(2) + ROOT fusion = u32[1,1,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/concatenate_test.cc b/xla/service/gpu/fusions/concatenate_test.cc new file mode 100644 index 0000000000000..b617fd6513d10 --- /dev/null +++ b/xla/service/gpu/fusions/concatenate_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/concatenate.h" + +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class ConcatenateTest : public HloTestBase { + public: + void SetUp() override { + HloTestBase::SetUp(); + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); + } + + protected: + AffineMapPrinter printer_; + mlir::MLIRContext mlir_context_; +}; + +TEST_F(ConcatenateTest, ThreadIndexing) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} + } + ENTRY main { + param0 = f32[200] parameter(0) + param1 = f32[400] parameter(1) + param2 = f32[300] parameter(2) + ROOT fusion = f32[900] fusion(param0, param1, param2), + calls=fused_computation, kind=kLoop + } + )") + .value(); + + stream_executor::DeviceDescription device_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto fusion = dynamic_cast(emitter.get()); + ASSERT_NE(fusion, nullptr); + + constexpr auto kIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + (th_x + bl_x * 128) mod 400) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 3] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 399] + )"; + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kIndexing)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/copy.cc b/xla/service/gpu/fusions/copy.cc index a04885f9f0a25..37cb5ab0dba0e 100644 --- a/xla/service/gpu/fusions/copy.cc +++ b/xla/service/gpu/fusions/copy.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,26 +16,28 @@ limitations under the License. #include -#include "xla/service/gpu/copy_thunk.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/runtime/copy_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/statusor.h" namespace xla { namespace gpu { -StatusOr MemcpyFusion::Emit( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, llvm::IRBuilder<>*) const { - auto src_buffer = *GetAllocationSlice(src_, ir_emitter_context.allocations()); - auto dst_buffer = *GetAllocationSlice(dst_, ir_emitter_context.allocations()); +absl::StatusOr MemcpyFusion::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { FusionEmissionResult result; - if (src_buffer != dst_buffer) { - result.thunks.emplace_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(fusion_op), - /*source_buffer=*/src_buffer, - /*destination_buffer=*/dst_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(src_)), - /*source_value=*/src_, - /*destination_value=*/dst_)); + for (int i = 0; i < src_buffers_.size(); ++i) { + if (src_buffers_[i] != dst_buffers_[i]) { + result.thunks.emplace_back(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(&fusion), + /*source_buffer=*/src_buffers_[i], + /*destination_buffer=*/dst_buffers_[i], + /*mem_size=*/src_buffers_[i].size())); + } } return result; } diff --git a/xla/service/gpu/fusions/copy.h b/xla/service/gpu/fusions/copy.h index 173517d73384c..574f1eb454271 100644 --- a/xla/service/gpu/fusions/copy.h +++ b/xla/service/gpu/fusions/copy.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,28 +15,40 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_COPY_H_ #define XLA_SERVICE_GPU_FUSIONS_COPY_H_ +#include + +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" namespace xla { namespace gpu { -// Special case of a fusion consisting only of a kCopy instruction that can be -// implemented using a memcpy. +// Special case of a fusion consisting only of `kCopy` instructions that can be +// implemented using `memcpy`s. class MemcpyFusion : public FusionInterface { public: - MemcpyFusion(mlir::Value src, mlir::Value dst) : src_(src), dst_(dst) {} - - StatusOr Emit(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, - llvm::IRBuilder<>*) const final; + MemcpyFusion(std::vector src_buffers, + std::vector dst_buffers, + std::vector srcs, std::vector dsts) + : src_buffers_(std::move(src_buffers)), + dst_buffers_(std::move(dst_buffers)), + srcs_(std::move(srcs)), + dsts_(std::move(dsts)) {} + + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; private: - mlir::Value src_; - mlir::Value dst_; + std::vector src_buffers_; + std::vector dst_buffers_; + + // These are only used by the LMHLO code path and are empty if emitting from + // HLO. + std::vector srcs_; + std::vector dsts_; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/cudnn.cc b/xla/service/gpu/fusions/cudnn.cc new file mode 100644 index 0000000000000..60f8af52222eb --- /dev/null +++ b/xla/service/gpu/fusions/cudnn.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/cudnn.h" + +#include "xla/hlo/ir/hlo_instructions.h" +#if GOOGLE_CUDA +#include "xla/service/gpu/runtime/cudnn_thunk.h" +#endif + +namespace xla { +namespace gpu { + +absl::StatusOr CuDnnFusion::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { +#if GOOGLE_CUDA + VLOG(3) << fusion.ToString(); + + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); + FusionEmissionResult result; + result.thunks.emplace_back(std::make_unique( + GetComputationFingerprint(fusion.fused_instructions_computation(), {}), + Thunk::ThunkInfo::WithProfileAnnotation(&fusion), + kernel_arguments.args())); + return result; +#else + return absl::UnimplementedError("cuDNN support requires CUDA"); +#endif +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/cudnn.h b/xla/service/gpu/fusions/cudnn.h new file mode 100644 index 0000000000000..ba5eecc88a481 --- /dev/null +++ b/xla/service/gpu/fusions/cudnn.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_CUDNN_H_ +#define XLA_SERVICE_GPU_FUSIONS_CUDNN_H_ + +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" + +namespace xla { +namespace gpu { + +// Creates thunks from compiled cuDNN graphs serialized in backend +// configs of corresponding fusions. +class CuDnnFusion : public FusionInterface { + public: + explicit CuDnnFusion(const HloFusionAnalysis&) {} + + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; +}; +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_CUDNN_H_ diff --git a/xla/service/gpu/fusions/cudnn_test.cc b/xla/service/gpu/fusions/cudnn_test.cc new file mode 100644 index 0000000000000..c69caa882fde0 --- /dev/null +++ b/xla/service/gpu/fusions/cudnn_test.cc @@ -0,0 +1,787 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/tests/filecheck.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class CuDnnFusionTest : public GpuCodegenTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // Let this group of tests just use first available plan skipping + // autotuning. + debug_options.set_xla_gpu_autotune_level(0); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(1); + return debug_options; + } + bool IsAtLeastHopperWithCuDnn9() { + se::StreamExecutor* executor = backend().default_stream_executor(); + return executor->GetDeviceDescription() + .cuda_compute_capability() + .IsAtLeastHopper() && + GetDnnVersionInfo(executor).major_version() >= 9; + } + bool IsAtLeastCuDnn91() { + se::StreamExecutor* executor = backend().default_stream_executor(); + const se::dnn::VersionInfo version = GetDnnVersionInfo(executor); + return (version.major_version() == 9 && version.minor_version() >= 1) || + version.major_version() > 9; + } + + protected: + void SetUp() override { + if (!IsAtLeastHopperWithCuDnn9()) { + GTEST_SKIP() + << "cuDNN GEMM fusion is not enabled before Hopper / cuDNN 9."; + } + } +}; + +using CuDnnFusionExecutionTest = CuDnnFusionTest; + +TEST_F(CuDnnFusionExecutionTest, + NoTritonConfigIsAssignedAtZeroAutotuningLevel) { + EXPECT_EQ(GetDebugOptionsForTest().xla_gpu_autotune_level(), 0); + MatchOptimizedHlo(R"( +fusion1 { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT r = f32[32,64] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT _ = f32[32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + R"( +CHECK-NOT: triton_gemm_config + )"); +} + +TEST_F(CuDnnFusionExecutionTest, DotF32ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT r = f32[32,64] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT _ = f32[32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionExecutionTest, DotBF16WithCopyExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[96,512,64]{1,2,0} parameter(0) + cp = bf16[96,512,64]{2,1,0} copy(p0) + p1 = bf16[96,64,512]{2,1,0} parameter(1) + ROOT d = bf16[96,512,512]{2,1,0} dot(cp, p1), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[96,512,64]{1,2,0} parameter(0) + p1 = bf16[96,64,512]{2,1,0} parameter(1) + ROOT r = bf16[96,512,512]{2,1,0} fusion(p0, p1), kind=kCustom, + calls=fusion1, + backend_config={"fusion_backend_config": {kind :"__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionExecutionTest, DotBF16BF16F32ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[16,32,128] parameter(0) + p1 = bf16[16,128,64] parameter(1) + ROOT r = f32[16,32,64] dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[16,32,128] parameter(0) + p1 = bf16[16,128,64] parameter(1) + ROOT _ = f32[16,32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); +} + +TEST_F(CuDnnFusionExecutionTest, DotF32WithOutputSubtractionExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f32[9,32,96] parameter(0) + p1 = f32[9,96,64] parameter(1) + d = f32[9,32,64] dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + p2 = f32[9,32,64] parameter(2) + ROOT s = f32[9,32,64] subtract(p2, d) +} + +ENTRY e { + p0 = f32[9,32,96] parameter(0) + p1 = f32[9,96,64] parameter(1) + p2 = f32[9,32,64] parameter(2) + ROOT _ = f32[9,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionExecutionTest, DotWithNonDefaultLayoutsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[32,32]{0,1} parameter(0) + p1 = bf16[32,32]{1,0} parameter(1) + ROOT r = bf16[32,32]{0,1} dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[32,32]{0,1} parameter(0) + p1 = bf16[32,32]{1,0} parameter(1) + ROOT _ = bf16[32,32]{0,1} fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +TEST_F(CuDnnFusionExecutionTest, RHSFusionExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[5,32,96] parameter(0) + p1 = s8[5,96,16] parameter(1) + p1c = bf16[5,96,16] convert(p1) + ROOT r = bf16[5,32,16] dot(p0, p1c), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[5,32,96] parameter(0) + p1 = s8[5,96,16] parameter(1) + ROOT _ = bf16[5,32,16] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionExecutionTest, SkipNonDefaultPrecision) { + EXPECT_FALSE(Run(R"( +t { + p0 = f32[27,23] parameter(0) + p0c = s8[27,23] convert(p0) + p0cc = f32[27,23] convert(p0c) + p1 = f32[23,21] parameter(1) + ROOT r = f32[27,21] dot(p0cc, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + p0 = f32[27,23] parameter(0) + p1 = f32[23,21] parameter(1) + ROOT r = f32[27,21] fusion(p0, p1), kind=kCustom, calls=t, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})")); +} + +TEST_F(CuDnnFusionExecutionTest, + DotF16NegateNonDefaultDimensionsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,96] parameter(0) + p0n = f16[16,32,96] negate(p0) + p1 = f16[16,64,96] parameter(1) + ROOT r = f16[16,32,64] dot(p0n, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} +} + +ENTRY e { + p0 = f16[16,32,96] parameter(0) + p1 = f16[16,64,96] parameter(1) + ROOT _ = f16[16,32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionExecutionTest, DotS8BF16ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = s8[5,32,96] parameter(0) + p0c = bf16[5,32,96] convert(p0) + p1 = bf16[5,96,16] parameter(1) + ROOT r = bf16[5,32,16] dot(p0c, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = s8[5,32,96] parameter(0) + p1 = bf16[5,96,16] parameter(1) + ROOT _ = bf16[5,32,16] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})); +} + +TEST_F(CuDnnFusionExecutionTest, IntegerMathExecutesCorrectly) { + if (!IsAtLeastCuDnn91()) { + GTEST_SKIP() << "Integer math requires cuDNN 9.1+."; + } + const std::string kHloText = + R"( +fusion1 { + p0 = s8[16,16] parameter(0) + p1 = s8[16,16] parameter(1) + d = s32[16,16] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = s32[16,16] parameter(2) + ROOT a = s32[16,16] add(d, p2) +} + +ENTRY e { + p0 = s8[16,16] parameter(0) + p1 = s8[16,16] parameter(1) + p2 = s32[16,16] parameter(2) + ROOT r = s32[16,16] fusion(p0, p1, p2), kind=kCustom, + calls=fusion1, + backend_config={"fusion_backend_config": {"kind":"__cudnn$fusion"}} +})"; + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +class CuDnnFusionCommandBufferTest : public CuDnnFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = CuDnnFusionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_graph_min_graph_size(1); + return debug_options; + } +}; + +TEST_F(CuDnnFusionCommandBufferTest, CommandBuffersAreSupported) { + const std::string kHloText = R"( +fd0 { + p0 = f32[64,64]{1,0} parameter(0) + p1 = f32[64,64]{1,0} parameter(1) + ROOT d = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +fd1 { + p0 = f32[64,64]{1,0} parameter(0) + p1 = f32[64,64]{1,0} parameter(1) + ROOT d = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[64,64]{1,0} parameter(0) + p1 = f32[64,64]{1,0} parameter(1) + d0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=fd0, + backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} + a = f32[64,64]{1,0} add(d0, d0) + ROOT d1 = f32[64,64]{1,0} fusion(a, d0), kind=kCustom, calls=fd1, + backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} +})"; + + // Verify that a command buffer is applied. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + backend().compiler()->RunBackend( + GetOptimizedModule(kHloText).value(), + backend().default_stream_executor(), + backend().default_stream_executor()->GetAllocator())); + absl::StatusOr filecheck_result = + RunFileCheck(executable->module().ToString(), R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: ROOT +; CHECK-SAME: command_buffer +)"); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(filecheck_result.value()); + + // Verify that the command buffer executes correctly. + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +class CuDnnFusionLevel2Test : public CuDnnFusionExecutionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + CuDnnFusionExecutionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(2); + return debug_options; + } +}; + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim2ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,32] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={0,1} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,32] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim1ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,128] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={0,2} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,128] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim0ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[32,128] parameter(0) + p0b = bf16[5,32,128] broadcast(p0), dimensions={1,2} + p1 = bf16[5,128,64] parameter(1) + ROOT r = f32[5,32,64] dot(p0b, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[32,128] parameter(0) + p1 = bf16[5,128,64] parameter(1) + ROOT _ = f32[5,32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastTo2DimsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[128] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={2} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[128] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastTo3DimsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +class CuDnnFusionLevel3Test : public CuDnnFusionExecutionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + CuDnnFusionExecutionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(3); + return debug_options; + } +}; + +TEST_F(CuDnnFusionLevel3Test, + DotWithSplitNonContractingInputExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + cp0 = s8[4,3,16,400]{3,2,1,0} copy(p0) + bc0 = s8[192,400]{1,0} bitcast(cp0) + cvt0 = bf16[192,400]{1,0} convert(bc0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + bc1 = bf16[128,400]{1,0} reshape(p1) + ROOT d = bf16[192,128]{1,0} dot(cvt0, bc1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY r { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + ROOT r = bf16[192,128]{1,0} fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel3Test, + DotWithSplitNonContractingInOutExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + cp0 = s8[4,3,16,400]{3,2,1,0} copy(p0) + bc0 = s8[192,400]{1,0} bitcast(cp0) + cvt0 = bf16[192,400]{1,0} convert(bc0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + bc1 = bf16[128,400]{1,0} reshape(p1) + d = bf16[192,128]{1,0} dot(cvt0, bc1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + bc = bf16[4,3,16,128]{3,2,1,0} bitcast(d) + ROOT cp = bf16[4,3,16,128]{2,1,3,0} copy(bc) +} + +ENTRY r { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + ROOT r = bf16[4,3,16,128]{2,1,3,0} fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +class ElementwiseTest : public CuDnnFusionExecutionTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string ElementwiseTestParamsToString( + const ::testing::TestParamInfo>& + data) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = data.param; + return absl::StrCat( + primitive_util::LowercasePrimitiveTypeName(data_type), "_", + absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); +} + +using UnaryElementwiseTest = ElementwiseTest; + +TEST_P(UnaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTemplate = R"( +fusion_computation { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + f1.1 = $0[32,32] $1(p1) + c.1 = f32[32,32] convert(f1.1) + ROOT _ = f32[32,32] dot(p0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p1 = $0[32,32] parameter(1) + p0 = f32[32,32] parameter(0) + ROOT r = f32[32,32] fusion(p0, p1), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config":{"kind":"__cudnn$$fusion"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompare(hlo_test, + ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); +} + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF32, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(F32), + ::testing::ValuesIn({HloOpcode::kAbs, HloOpcode::kCos, + HloOpcode::kExp, HloOpcode::kLog, + HloOpcode::kNegate, + HloOpcode::kRsqrt, HloOpcode::kSin, + HloOpcode::kSqrt, HloOpcode::kTan, + HloOpcode::kTanh}), + ::testing::Values(5e-4)), + ElementwiseTestParamsToString); + +using BinaryElementwiseTest = ElementwiseTest; + +TEST_P(BinaryElementwiseTest, ElementwiseFusionExecutesCorrectly) { + PrimitiveType data_type; + HloOpcode opcode; + float tolerance; + std::tie(data_type, opcode, tolerance) = GetParam(); + + const std::string kHloTemplate = R"( +fusion_computation { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + f1.1 = $0[32,32] $1(p1, p2) + c.1 = f32[32,32] convert(f1.1) + ROOT _ = f32[32,32] dot(p0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + +ENTRY e { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + ROOT r = f32[32,32] fusion(p0, p1, p2), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config":{"kind":"__cudnn$$fusion"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + + EXPECT_TRUE(RunAndCompare(hlo_test, + ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); +} + +INSTANTIATE_TEST_SUITE_P( + ElementwiseTestSuiteF32, BinaryElementwiseTest, + ::testing::Combine( + ::testing::Values(F32), + ::testing::ValuesIn({HloOpcode::kAdd, HloOpcode::kDivide, + HloOpcode::kMaximum, HloOpcode::kMinimum, + HloOpcode::kMultiply, HloOpcode::kPower, + HloOpcode::kSubtract}), + ::testing::Values(3e-3)), + ElementwiseTestParamsToString); + +class CompareTest : public CuDnnFusionExecutionTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string CompareTestParamsToString( + const ::testing::TestParamInfo< + std::tuple>& data) { + PrimitiveType data_type; + Comparison::Direction direction; + std::tie(data_type, direction) = data.param; + return absl::StrCat(primitive_util::LowercasePrimitiveTypeName(data_type), + "_", ComparisonDirectionToString(direction)); +} + +TEST_P(CompareTest, FusedComparisonExecutesCorrectly) { + PrimitiveType data_type; + Comparison::Direction direction; + std::tie(data_type, direction) = GetParam(); + + const std::string kHloTemplate = R"( +fusion_computation { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + f1.1 = pred[32,32] compare(p1, p2), direction=$1 + c.1 = f32[32,32] convert(f1.1) + ROOT _ = f32[32,32] dot(p0, c.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + +ENTRY e { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + ROOT r = f32[32,32] fusion(p0, p1, p2), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config":{"kind":"__cudnn$$fusion"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + ComparisonDirectionToString(direction)); + + EXPECT_TRUE(RunAndCompare(hlo_test, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +using cd = Comparison::Direction; + +INSTANTIATE_TEST_SUITE_P( + CompareTestSuite, CompareTest, + ::testing::Combine(::testing::Values(PRED, S8, S32, F16, F32), + ::testing::Values(cd::kEq, cd::kNe, cd::kGe, cd::kGt, + cd::kLe, cd::kLt)), + CompareTestParamsToString); + +class SelectTest : public CuDnnFusionExecutionTest, + public ::testing::WithParamInterface {}; + +TEST_P(SelectTest, SelectFusionExecutesCorrectly) { + if (!IsAtLeastCuDnn91()) { + GTEST_SKIP() << "Select operation requires cuDNN 9.1+."; + } + const std::string kHloTemplate = R"( +fusion_computation { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + p3 = pred[32,32] parameter(3) + s = $0[32,32] select(p3, p1, p2) + c = f32[32,32] convert(s) + ROOT r = f32[32,32] dot(p0, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[32,32] parameter(0) + p1 = $0[32,32] parameter(1) + p2 = $0[32,32] parameter(2) + p3 = pred[32,32] parameter(3) + ROOT r = f32[32,32] fusion(p0, p1, p2, p3), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config":{"kind":"__cudnn$$fusion"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTemplate, primitive_util::LowercasePrimitiveTypeName(GetParam())); + + EXPECT_TRUE(RunAndCompare(hlo_test, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +constexpr std::array kSupportedDataTypes{F16, F32, BF16}; + +INSTANTIATE_TEST_SUITE_P(SelectTestSuite, SelectTest, + ::testing::ValuesIn(kSupportedDataTypes)); + +class CuDnnFusionRewriteTest : public CuDnnFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = CuDnnFusionTest::GetDebugOptionsForTest(); + // Reset autotuning level to default. + debug_options.set_xla_gpu_autotune_level( + GetDebugOptionsFromFlags().xla_gpu_autotune_level()); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(1); + return debug_options; + } +}; + +TEST_F(CuDnnFusionRewriteTest, + DoNotExecuteGemmFusionWithCuDnnWhenNotSupported) { + // Dimension size 61 does not satisfy the requirement on alignment + // (multiple of 2). + MatchOptimizedHlo(R"( +ENTRY e { + p0 = f16[20,40,61] parameter(0) + p2 = f16[20,40,61] parameter(2) + p0n = f16[20,40,61] negate(p2) + p1 = f16[20,80,61] parameter(1) + ROOT r = f16[20,40,80] dot(p0n, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} +})", + R"( +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: ROOT +; CHECK-SAME: fusion +; CHECK-NOT: cudnn +)"); +} + +TEST_F(CuDnnFusionRewriteTest, AutotuningPicksCuDnnForS8BF16OnHopper) { + // The test case relies on measurements by the autotuner and current + // performance comparison of the backends. May need to be updated if + // the situation changes. + MatchOptimizedHlo(R"( +e { + p0 = bf16[720,720,720] parameter(0) + p1 = s8[720,720,720] parameter(1) + c = bf16[720,720,720] convert(p1) + ROOT d = bf16[720,720,720] dot(p0, c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +})", + R"( +; CHECK: __cudnn$fusion +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc new file mode 100644 index 0000000000000..2c69785d9f618 --- /dev/null +++ b/xla/service/gpu/fusions/custom.cc @@ -0,0 +1,751 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/custom.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/address_computation_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +constexpr unsigned kGEMMOutputBufferIndex = 0; +constexpr unsigned kGEMMWorkspaceBufferIndex = 1; + +absl::StatusOr> BuildCustomKernelThunkForFusion( + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, + CustomKernel custom_kernel) { + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); + + return std::make_unique( + &fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); +} + +absl::StatusOr GetOperandSlice( + const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, + const HloInstruction& fusion_instr, const HloInstruction& start_instr, + std::vector& slice_instrs, const ShapeIndex& shape_idx, + unsigned arg_idx) { + if (const auto* param = DynCast(&start_instr)) { + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + shape_idx); + } + + // Walk through ShapeIndex to find the real starting point. + auto* start = const_cast(&start_instr); + for (auto idx : shape_idx) { + CHECK(start->shape().IsTuple()); + start = const_cast(start->operand(idx)); + } + + if (const auto* param = DynCast(start)) { + // At this point we've walked through all `shape_idx`, `index` should be + // empty. + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + /*index*/ {}); + } + + auto slice_adaptor = + HloFindIf({HloInstructionAdaptor(*start)}, adaptor, [](auto node) { + return IsOpcodeAnyOf(node); + }); + if (slice_adaptor.has_value()) { + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); + + if (!IsContiguousSlice(slice_instr->operand(0)->shape(), + slice_instr->shape())) { + return absl::InternalError( + "AddressComputationFusion only handles contiguous slices " + "currently"); + } + + slice_instrs[arg_idx] = slice_instr; + + const auto* param = Cast(slice_instr->operand(0)); + // At this point we've walked through all `shape_idx`, `index` should be + // empty. + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice orig_slice, + GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + /*index*/ {})); + + if (auto* static_slice = DynCast(slice_instr)) { + // Update static slices. + const Shape& src_shape = static_slice->operand(0)->shape(); + const Shape& dst_shape = static_slice->shape(); + int64_t size = ShapeUtil::ByteSizeOf(dst_shape); + + // Given this slice + // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), + // slice={[1:2], [4:8], [0:8]} + // + // The offset of the slice should be: + // slice_starts(0) * 8 * 8 * sizeof(f16) + + // slice_starts(1) * 8 * sizeof(f16) + int64_t offset = orig_slice.offset(); + for (auto [start, stride] : + llvm::zip(static_slice->slice_starts(), + *ShapeUtil::ByteStrides(src_shape))) { + offset += start * stride; + } + + return BufferAllocation::Slice(orig_slice.allocation(), offset, size); + } + + return orig_slice; + } + + return absl::InternalError("WTF"); +} + +absl::Status CollectSliceInfo( + const BufferAssignment& buffer_assignment, + const HloInstruction& fusion_instr, + absl::Span slice_instrs, + std::vector>>& + offset_buffer_indices, + std::vector>& orig_shapes, + std::vector>& sliced_shapes, + std::vector>& offset_byte_sizes, unsigned arg_idx) { + auto* slice_instr = + DynCastOrNull(slice_instrs[arg_idx]); + if (slice_instr == nullptr) { + return absl::OkStatus(); + } + + std::vector offset_slices; + for (auto idx_op : slice_instr->index_operands()) { + const auto* param = Cast(idx_op); + TF_ASSIGN_OR_RETURN( + auto offset_slice, + GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + /*index=*/{})); + offset_slices.push_back(offset_slice); + } + offset_buffer_indices[arg_idx] = std::move(offset_slices); + orig_shapes[arg_idx] = slice_instr->operand(0)->shape(); + sliced_shapes[arg_idx] = DynCast(slice_instr) + ? slice_instr->shape() + : slice_instr->operand(1)->shape(); + offset_byte_sizes[arg_idx] = ShapeUtil::ByteSizeOfPrimitiveType( + slice_instr->index_operands().front()->shape().element_type()); + + return absl::OkStatus(); +} + +absl::StatusOr GetResultSlice( + const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, + const HloInstruction& fusion_instr, const HloInstruction& start_instr, + std::vector& slice_instrs, const ShapeIndex& shape_idx, + unsigned arg_idx) { + auto* start = const_cast(&start_instr); + // Walk through ShapeIndex to find the real "user" (i.e. not get-tuple-element + // user). Otherwise one sliced element will mark all buffers of all other + // elements "sliced" too. + if (start->shape().IsTuple()) { + for (auto idx : shape_idx) { + std::vector gte_users( + start->shape().tuple_shapes_size(), nullptr); + for (auto* user : start->users()) + if (auto* gte = DynCast(user)) + gte_users[gte->tuple_index()] = gte; + + start = static_cast(gte_users[idx]); + if (start == nullptr) + return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); + } + } + + auto slice_adaptor = HloFindIf( + {HloInstructionAdaptor(*start)}, adaptor, + [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; }, + /*visit_operands=*/false); + if (slice_adaptor.has_value()) { + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); + slice_instrs[arg_idx] = slice_instr; + + if (!IsContiguousSlice(slice_instr->shape(), + Cast(slice_instr) + ->update() + ->shape())) { + return absl::InternalError( + "AddressComputationFusion only handles contiguous slices " + "currently"); + } + } + + return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); +} + +absl::StatusOr EmitGemm( + IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, + const HloFusionInstruction& fusion, + const HloCustomCallInstruction& custom_call) { + const BufferAssignment& buffer_assignment = + ir_emitter_context.buffer_assignment(); + + std::vector>> + offset_buffer_indices(4, std::nullopt); + std::vector> orig_shapes(4, std::nullopt); + std::vector> sliced_shapes(4, std::nullopt); + std::vector> offset_byte_sizes(4, std::nullopt); + + std::vector slice_instrs(4, nullptr); + + unsigned arg_idx = 0; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, + GetOperandSlice(buffer_assignment, adaptor, fusion, + *custom_call.operand(arg_idx), + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, + GetOperandSlice(buffer_assignment, adaptor, fusion, + *custom_call.operand(arg_idx), + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); + + BufferAllocation::Slice output; + std::optional workspace = std::nullopt; + std::optional slice_workspace_fake = std::nullopt; + + // Handling cases where multiple operands share the same buffer, with + // different offset by creating new fake allocations so each operand will have + // a different buffer index. The slices can thus always start at offset 0. + // AddressComputationThunk will take care of the offset adjustment. + std::vector> fake_allocations(4); + if (fusion.shape().IsArray()) { + TF_ASSIGN_OR_RETURN( + output, GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx)); + } else { + TF_ASSIGN_OR_RETURN( + output, + GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/{kGEMMOutputBufferIndex}, + arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); + + // TODO(vuson): If we want to support slices of workspace, we'd need to + // start `HloFindIf` with `get-tuple-element` with the right index. + TF_ASSIGN_OR_RETURN( + workspace, GetAllocationSlice(buffer_assignment, &fusion, + /*index=*/{kGEMMWorkspaceBufferIndex})); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx)); + fake_allocations[arg_idx] = std::make_unique( + /*index=*/arg_idx, workspace->size(), /*color=*/0); + slice_workspace_fake = BufferAllocation::Slice( + fake_allocations[arg_idx].get(), 0, workspace->size()); + } + + if (absl::c_all_of(slice_instrs, [&](auto slice_instr) { + return slice_instr == nullptr; + })) { + return absl::InternalError( + "AddressComputationFusion expects at least one sliced " + "operand/result"); + } + + bool deterministic_ops = + ir_emitter_context.debug_options().xla_gpu_deterministic_ops(); + + TF_ASSIGN_OR_RETURN( + GemmConfig config, + GemmConfig::For(static_cast(&custom_call))); + + std::unique_ptr thunk; + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + + if (absl::c_any_of(slice_instrs, [&](auto slice_instr) { + return DynCastOrNull(slice_instr) != + nullptr; + })) { + // Creating embedded GEMM thunk. + unsigned fake_arg_idx = 0; + int64_t lhs_byte_size = + ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape()); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, lhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake(fake_allocations[fake_arg_idx].get(), + 0, lhs_byte_size); + + fake_arg_idx++; + int64_t rhs_byte_size = + ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape()); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, rhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake(fake_allocations[fake_arg_idx].get(), + 0, rhs_byte_size); + + fake_arg_idx++; + int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf( + custom_call.shape().IsArray() ? custom_call.shape() + : custom_call.shape().tuple_shapes(0)); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0); + BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(), + 0, out_fake_byte_size); + ThunkSequence seq; + seq.emplace_back(std::make_unique( + thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, deterministic_ops)); + + std::vector> arguments{ + lhs_slice, rhs_slice, output, workspace}; + + thunk = std::make_unique( + thunk_info, std::make_unique(std::move(seq)), + std::move(arguments), std::move(fake_allocations), + std::move(offset_buffer_indices), std::move(orig_shapes), + std::move(sliced_shapes), std::move(offset_byte_sizes)); + } else { + thunk = std::make_unique(thunk_info, std::move(config), + lhs_slice, rhs_slice, output, workspace, + deterministic_ops); + } + + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; +} + +absl::StatusOr EmitCustomCall( + IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, + const HloFusionInstruction& fusion, + const HloCustomCallInstruction& custom_call) { + const BufferAssignment& buffer_assignment = + ir_emitter_context.buffer_assignment(); + + const std::string& call_target_name = custom_call.custom_call_target(); + + // Typed FFI custom calls is a replacement for legacy custom calls with + // a rich type safe API. It's under construction and not fully supported. + bool is_ffi_custom_call = + custom_call.api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; + + void* call_target = CustomCallTargetRegistry::Global()->Lookup( + call_target_name, std::string(ir_emitter_context.platform_name())); + + absl::StatusOr registration = + ffi::FindHandler(call_target_name, ir_emitter_context.platform_name()); + + // At least one implementation should be available at run time. + bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; + bool found_ffi_handler = is_ffi_custom_call && registration.ok(); + + if (!found_custom_call && !found_ffi_handler) { + return absl::InternalError( + "AddressComputationFusion expects custom calls that are emittable as " + "thunks"); + } + + using Slices = std::vector>; + + int64_t num_args = ShapeUtil::GetLeafCount(custom_call.shape()); + absl::c_for_each(custom_call.operands(), [&](auto* operand) { + num_args += ShapeUtil::GetLeafCount(operand->shape()); + }); + + std::vector>> + offset_buffer_indices(num_args, std::nullopt); + std::vector> orig_shapes(num_args, std::nullopt); + std::vector> sliced_shapes(num_args, std::nullopt); + std::vector> offset_byte_sizes(num_args, + std::nullopt); + + std::vector slice_instrs(num_args, nullptr); + std::vector> arguments; + + unsigned arg_idx = 0; + // TODO(vuson): add test for custom call with token-typed operands + Slices operands; + for (auto* operand : custom_call.operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + arg_idx++; + operands.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + auto slice, + GetOperandSlice(buffer_assignment, adaptor, fusion, *operand, + slice_instrs, /*shape_idx=*/index, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++)); + + operands.push_back(CustomCallThunk::Slice{slice, subshape}); + arguments.push_back(slice); + return absl::OkStatus(); + })); + } + + Slices results; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + custom_call.shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + arg_idx++; + results.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + auto slice, + GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/index, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++)); + + results.push_back(CustomCallThunk::Slice{slice, subshape}); + arguments.push_back(slice); + return absl::OkStatus(); + })); + + if (absl::c_all_of(slice_instrs, [&](auto slice_instr) { + return slice_instr == nullptr; + })) { + return absl::InternalError( + "AddressComputationFusion expects at least one sliced " + "operand/result"); + } + + // For legacy custom calls we convert all API versions into the latest + // status-returning one and pass backend config as an opaque string. + CustomCallThunk::CustomCallTarget custom_call_target; + std::string opaque; + + // For XLA FFI handlers we decode opaque backend config into attributes map + // at IR emission time, so that we do not need to parse MLIR at run time. For + // FFI handlers backend config must be a compatible MLIR dictionary. + CustomCallThunk::AttributesMap attributes; + + // For information about this calling convention, see + // xla/g3doc/custom_call.md. + switch (custom_call.api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + using original_call_type = + void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/); + custom_call_target = [call_target](CustomCallThunk::Stream stream, + void** buffers, const char* opaque, + size_t opaque_len, + XlaCustomCallStatus*) { + auto typed_call_target = + reinterpret_cast(call_target); + typed_call_target(stream, buffers, opaque, opaque_len); + }; + break; + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + using status_returning_call_type = + void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/, + XlaCustomCallStatus* /*status*/); + custom_call_target = + reinterpret_cast(call_target); + break; + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + // We already checked `handler` above. + break; + default: + return Internal("Unknown custom-call API version enum value: %d", + custom_call.api_version()); + } + + auto& backend_config_str = custom_call.raw_backend_config_string(); + switch (custom_call.api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + if (!backend_config_str.empty()) { + opaque = backend_config_str; + } + break; + + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + if (!backend_config_str.empty()) { + mlir::Attribute attr = mlir::parseAttribute( + backend_config_str, ir_emitter_context.mlir_context()); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + break; + } + return absl::InternalError( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); + } + break; + + default: + return Internal("Unknown custom-call API version enum value: %d", + custom_call.api_version()); + } + + std::unique_ptr thunk; + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + + auto ffi_thunk = [&](Slices ops, Slices res) { + auto& called_computations = custom_call.called_computations(); + return std::make_unique( + thunk_info, registration->handler, std::move(ops), std::move(res), + std::move(attributes), + called_computations.empty() ? nullptr : called_computations[0]); + }; + + auto legacy_thunk = [&](Slices ops, Slices res) { + return std::make_unique( + thunk_info, std::move(custom_call_target), std::move(ops), + std::move(res), std::move(opaque)); + }; + + std::vector> fake_allocations(num_args); + if (absl::c_any_of(slice_instrs, [&](auto slice_instr) { + return DynCastOrNull(slice_instr) != + nullptr; + })) { + // Creating embedded custom call thunk. + unsigned fake_arg_idx = 0; + + Slices fake_operands; + for (auto* operand : custom_call.operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + fake_arg_idx++; + fake_operands.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + + int64_t operand_byte_size = ShapeUtil::ByteSizeOf(subshape); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0); + BufferAllocation::Slice fake_slice( + fake_allocations[fake_arg_idx].get(), 0, operand_byte_size); + + fake_arg_idx++; + fake_operands.push_back( + CustomCallThunk::Slice{fake_slice, subshape}); + return absl::OkStatus(); + })); + } + + Slices fake_results; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + custom_call.shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + fake_arg_idx++; + fake_results.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + + int64_t result_byte_size = ShapeUtil::ByteSizeOf(subshape); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, result_byte_size, /*color=*/0); + BufferAllocation::Slice fake_slice( + fake_allocations[fake_arg_idx].get(), 0, result_byte_size); + + fake_arg_idx++; + fake_results.push_back(CustomCallThunk::Slice{fake_slice, subshape}); + return absl::OkStatus(); + })); + + ThunkSequence seq; + seq.emplace_back( + found_ffi_handler + ? ffi_thunk(std::move(fake_operands), std::move(fake_results)) + : legacy_thunk(std::move(fake_operands), std::move(fake_results))); + + thunk = std::make_unique( + thunk_info, std::make_unique(std::move(seq)), + std::move(arguments), std::move(fake_allocations), + std::move(offset_buffer_indices), std::move(orig_shapes), + std::move(sliced_shapes), std::move(offset_byte_sizes)); + } else { + thunk = found_ffi_handler + ? ffi_thunk(std::move(operands), std::move(results)) + : legacy_thunk(std::move(operands), std::move(results)); + } + + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; +} + +} // namespace + +absl::StatusOr CustomFusion::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion.backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + const auto& config = backend_config.custom_fusion_config(); + + VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); + + auto* registry = CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(config.name()); + + // If custom fusion is not found it means that some of the build targets might + // not be statically linked into the binary. + if (custom_kernel_fusion == nullptr) { + return absl::InternalError( + absl::StrCat("Custom kernel fusion ", config.name(), + " not found in a default registry.")); + } + + // Load custom kernels that can implement a fusion computation. + TF_ASSIGN_OR_RETURN(std::vector kernels, + custom_kernel_fusion->LoadKernels( + ir_emitter_context.gpu_device_info(), + fusion.fused_instructions_computation())); + + // This should never happen, it means that compilation pipeline created a + // fusion operation that is not supported by a given custom fusion. + if (kernels.empty()) { + return absl::InternalError( + absl::StrCat("Custom kernel fusion ", config.name(), + " returned empty custom kernels for a fused computation")); + } + + // TODO(ezhulenev): Add support for auto tuning to select the best kernel. + if (kernels.size() != 1) { + return absl::InternalError("Expected exactly one custom kernel"); + } + + TF_ASSIGN_OR_RETURN( + auto thunk, BuildCustomKernelThunkForFusion(ir_emitter_context, fusion, + std::move(kernels[0]))); + + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; +} + +absl::StatusOr AddressComputationFusion::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + const HloFusionAdaptor& adaptor = analysis_.fusion(); + auto maybe_custom_call_adaptor = HloFindIf( + adaptor.GetRoots(), adaptor, + [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); + if (maybe_custom_call_adaptor == std::nullopt) { + return absl::InternalError( + "AddressComputationFusion requires a CustomCall hero"); + } + + const auto& custom_call = *static_cast( + &maybe_custom_call_adaptor->instruction()); + if (IsLegacyCublasMatmul(custom_call)) { + return EmitGemm(ir_emitter_context, adaptor, fusion, custom_call); + } + + return EmitCustomCall(ir_emitter_context, adaptor, fusion, custom_call); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/custom.h b/xla/service/gpu/fusions/custom.h new file mode 100644 index 0000000000000..edd8d6d72b9af --- /dev/null +++ b/xla/service/gpu/fusions/custom.h @@ -0,0 +1,69 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_CUSTOM_H_ +#define XLA_SERVICE_GPU_FUSIONS_CUSTOM_H_ + +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/statusor.h" + +namespace xla { +namespace gpu { + +// A wrapper for fusions implemented using the mechanism in +// xla/service/gpu/kernels. See custom_kernel_fusion.h in that folder for +// details. +class CustomFusion : public FusionInterface { + public: + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; +}; + +// Emitter for custom fusions implementing address computation. An address +// computation contains a custom call hero, with at least one of its operands +// coming from a static contiguous slice. E.g. operand `%cast` of `%gemm` coming +// from `%slice`: +// %address_computation { +// %p0 = f32[2, 1024, 1024] +// %p1 = f32[1024, 1024] +// %slice = f32[1, 1024, 1024] slice(%p0) +// %cast = f32[1024, 1024] bitcast(%slice) +// ROOT %gemm = custom_call(%cast, %p1) __cublas$Gemm +// } +// +// The goal is to compute the buffer addresses for such operands (`%cast`) at +// compile-time instead of allocating a new buffer for it at runtime by +// translating the static slice into offset + size of the original buffer passed +// into the custom call `%gemm`. +class AddressComputationFusion : public FusionInterface { + public: + explicit AddressComputationFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis) {} + + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; + + private: + const HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_CUSTOM_H_ diff --git a/xla/service/gpu/fusions/fusion_emitter.cc b/xla/service/gpu/fusions/fusion_emitter.cc index 4d505db798186..e652532fd0464 100644 --- a/xla/service/gpu/fusions/fusion_emitter.cc +++ b/xla/service/gpu/fusions/fusion_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,25 +14,54 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/fusion_emitter.h" +#include #include +#include #include +#include #include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "xla/service/gpu/hlo_to_ir_bindings.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/TargetParser/Triple.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" -#include "xla/service/gpu/kernel_thunk.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -56,11 +85,19 @@ void AnnotateWithInt32Value(std::string name, int64_t value, llvm::IntegerType::get(llvm_context, /*NumBits=*/32), value))})); } +} // namespace + // Annotates the launch dimensions of the corresponding IR kernel in // `llvm_module`. -void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, - const std::string& kernel_name, - llvm::Module* llvm_module) { +absl::Status AnnotateKernelLaunchDimensions( + const se::DeviceDescription& device_info, + const LaunchDimensions& launch_dims, const std::string& kernel_name, + llvm::Module* llvm_module) { + TF_RET_CHECK(device_info.block_dim_limit().x == 0 || + launch_dims.block_counts().x < device_info.block_dim_limit().x) + << "Kernel '" << kernel_name << "' launch needs more blocks (" + << launch_dims.block_counts().x << ") than allowed by hardware (" + << device_info.block_dim_limit().x << ")."; // Add __launch_bounds__ to metadata. This limits registers per thread to // avoid out-of-resources launching errors. @@ -76,12 +113,96 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, AnnotateWithInt32Value("reqntidz", launch_dims.thread_counts_per_block().z, kernel_name, llvm_module); } + // Maybe we want to set "reqnctapercluster" here, but not sure if needed or if + // LLVM supports that yet. Let's do that later when needed. + return absl::OkStatus(); } -} // namespace +IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( + const LaunchDimensions& launch_dims, int unroll_factor, const Shape& shape, + mlir::MLIRContext* ctx) { + std::vector output_dims(shape.rank()); + + std::array thread_counts{ + launch_dims.thread_counts_per_block().x, + launch_dims.thread_counts_per_block().y, + launch_dims.thread_counts_per_block().z, + }; + + std::array total_sizes{ + launch_dims.thread_counts_per_block().x * launch_dims.block_counts().x, + launch_dims.thread_counts_per_block().y * launch_dims.block_counts().y, + launch_dims.thread_counts_per_block().z * launch_dims.block_counts().z, + }; + + // ParallelLoopEmitter makes some assumptions about launch dimensions and + // computes the linear index using only the x and y components. + // + // We implement the general formula instead and rely on the simplifier to + // fix it. + // + // This means that this code supports some launch grids that the parallel + // loop emitter doesn't support. This is safe, since the latter CHECK fails + // if its assumptions are not fulfilled. + mlir::AffineExpr c0 = mlir::getAffineConstantExpr(0, ctx); + mlir::AffineExpr linear_index = c0; + uint64_t stride = 1; + for (int i = 0; i < 3; ++i) { + auto coord = mlir::getAffineDimExpr(kIndexingMapThreadIdxDims[i], ctx) + + mlir::getAffineDimExpr(kIndexingMapBlockIdxDims[i], ctx) * + thread_counts[i]; + auto linear_component = coord * stride; + linear_index = linear_index + linear_component; + stride *= total_sizes[i]; + } + mlir::AffineExpr chunk_id = mlir::getAffineSymbolExpr(0, ctx); + mlir::AffineExpr unroll_elem_id = mlir::getAffineSymbolExpr(1, ctx); + + linear_index = linear_index * unroll_factor + + chunk_id * unroll_factor * launch_dims.launch_bound() + + unroll_elem_id; + + // See IndexUtil::LinearIndexToMultidimensionalIndex. + uint64_t divisor = 1; + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { + output_dims[dimension] = (linear_index.floorDiv(divisor)) % + static_cast(shape.dimensions(dimension)); + divisor *= shape.dimensions(dimension); + } -std::tuple, - std::vector> + std::vector dim_vars = { + {{0, static_cast(launch_dims.thread_counts_per_block().x) - 1}}, + {{0, static_cast(launch_dims.thread_counts_per_block().y) - 1}}, + {{0, static_cast(launch_dims.thread_counts_per_block().z) - 1}}, + {{0, static_cast(launch_dims.block_counts().x) - 1}}, + {{0, static_cast(launch_dims.block_counts().y) - 1}}, + {{0, static_cast(launch_dims.block_counts().z) - 1}}, + }; + std::vector range_vars; + int64_t num_elements = ShapeUtil::ElementsIn(shape); + range_vars.push_back( + {{0, CeilOfRatio(num_elements, + static_cast(launch_dims.launch_bound()) * + unroll_factor) - + 1}}); + range_vars.push_back({0, unroll_factor - 1}); + IndexingMap indexing_map( + mlir::AffineMap::get(/*dimCount=*/6, + /*symbolCount=*/2, output_dims, ctx), + dim_vars, range_vars, /*rt_vars=*/{}); + // Remove the unroll_elem_id symbol if unrolling divides num_elements. + if (num_elements % unroll_factor == 0) { + indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}), + Interval{0, num_elements - unroll_factor}); + } else { + indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1}); + } + indexing_map.Simplify(GetIndexingMapForInstruction); + return indexing_map; +} + +absl::StatusOr, + std::vector>> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, const std::string& suggested_name, absl::Span arguments, @@ -112,16 +233,20 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, // Create the kernel and add it to the module. auto* llvm_module = ir_emitter_context.llvm_module(); llvm::LLVMContext& context = llvm_module->getContext(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = llvm::Triple(llvm_module->getTargetTriple()).isSPIR() ? 1 : 0; llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(kNumLlvmArgs, builder->getPtrTy()), + std::vector(kNumLlvmArgs, builder->getPtrTy(addrspace)), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, kernel_name, llvm_module); AnnotateFunctionAsGpuKernel(llvm_module, kernel, builder); - AnnotateKernelLaunchDimensions(launch_dimensions, kernel_name, llvm_module); + TF_RETURN_IF_ERROR(AnnotateKernelLaunchDimensions( + ir_emitter_context.gpu_device_info(), launch_dimensions, kernel_name, + llvm_module)); // TODO(b/65380986): Investigate if adding fast math flags for generated // kernels makes sense. @@ -172,63 +297,61 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, (arg_no < num_inputs ? inputs : outputs).push_back(ir_array); } - return {kernel, std::move(inputs), std::move(outputs)}; + return {{kernel, std::move(inputs), std::move(outputs)}}; } -StatusOr KernelFusionEmitterBase::Emit( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { +absl::StatusOr KernelFusionEmitterBase::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext()); std::string suggested_kernel_name = std::string(fusion.name()); - TF_ASSIGN_OR_RETURN(KernelArguments kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), &fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), fusion_op)); + TF_ASSIGN_OR_RETURN( + KernelArguments kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); auto* fused_computation = fusion.fused_instructions_computation(); - FusionEmissionResult result; - for (int i = 0, n = num_kernels(); i < n; ++i) { - TF_ASSIGN_OR_RETURN(auto launch_dims, - launch_dimensions(ir_emitter_context, i)); - std::vector inputs, outputs; - auto [entry, cached] = kernel_cache.GetWithStatus( - fused_computation, kernel_arguments.args(), absl::StrCat(i), - [&]() -> StatusOr { - llvm::Function* kernel; - std::tie(kernel, inputs, outputs) = BuildKernelPrototype( - ir_emitter_context, suggested_kernel_name, - kernel_arguments.args(), fusion.operand_count(), launch_dims, - builder); - TF_RETURN_IF_ERROR(EmitKernel(ir_emitter_context, elemental_emitter, - fusion, launch_dims, std::move(inputs), - std::move(outputs), builder, i)); - // TODO(jreiffers): Return shmem_bytes from EmitKernel when - // converting the Triton emitters to this infrastructure. - return KernelReuseCache::Entry{kernel->getName().str(), launch_dims, - /*shmem_bytes=*/0}; - }); - TF_RETURN_IF_ERROR(entry.status()); - - if (cached) { - VLOG(3) << "Reuse: " << suggested_kernel_name << " -> " - << entry->kernel_name; - } - - if (ir_emitter_context.emit_ir_from_hlo()) { - result.thunks.emplace_back(std::make_unique( - &fusion, entry->kernel_name, kernel_arguments.args(), launch_dims, - entry->shmem_bytes)); - } else { - result.thunks.emplace_back(std::make_unique( - fusion_op, entry->kernel_name, kernel_arguments.args(), launch_dims, - entry->shmem_bytes)); - } + TF_ASSIGN_OR_RETURN(auto result, + EmitInitializers(ir_emitter_context, fusion)); + auto launch_dims = launch_dimensions(); + std::vector inputs, outputs; + auto [status_or_entry, cached] = + ir_emitter_context.kernel_cache().GetWithStatus( + fused_computation, kernel_arguments.args(), /*discriminator=*/"", + [&]() -> absl::StatusOr { + llvm::Function* kernel; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototype(ir_emitter_context, suggested_kernel_name, + kernel_arguments.args(), + fusion.operand_count(), launch_dims, + &builder)); + if (ir_emitter_context.emit_kernels()) { + TF_RETURN_IF_ERROR(EmitKernel(ir_emitter_context, fusion, + launch_dims, std::move(inputs), + std::move(outputs), &builder)); + } else { + VLOG(3) << "Skipped kernel compilation: " + << suggested_kernel_name; + } + // TODO(jreiffers): Return shmem_bytes from EmitKernel when + // converting the Triton emitters to this infrastructure. + return KernelReuseCache::Entry{kernel->getName().str(), launch_dims, + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0}; + }); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); + + if (cached) { + VLOG(3) << "Reuse: " << suggested_kernel_name << " -> " + << entry->kernel_name; } + result.thunks.emplace_back(std::make_unique( + &fusion, entry->kernel_name, kernel_arguments.args(), launch_dims, + entry->cluster_dim, entry->shmem_bytes)); + return result; } diff --git a/xla/service/gpu/fusions/fusion_emitter.h b/xla/service/gpu/fusions/fusion_emitter.h index f62b2da0276e1..dc8c399cbf42f 100644 --- a/xla/service/gpu/fusions/fusion_emitter.h +++ b/xla/service/gpu/fusions/fusion_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,31 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_FUSION_EMITTER_H_ #define XLA_SERVICE_GPU_FUSIONS_FUSION_EMITTER_H_ +#include +#include +#include #include #include +#include +#include #include +#include "absl/types/span.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/statusor.h" namespace xla { namespace gpu { @@ -39,41 +52,81 @@ class FusionInterface { public: virtual ~FusionInterface() = default; - virtual StatusOr Emit( + virtual absl::StatusOr Emit( IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, KernelReuseCache& kernel_cache, - llvm::IRBuilder<>* builder) const = 0; + const HloFusionInstruction& fusion) const = 0; }; -class KernelFusionEmitterBase : public FusionInterface { +// Interface for fusions that are implemented using cuda kernels. +class KernelFusionInterface : public FusionInterface { public: - // The downstream code that is used by this emitter operates on a mix of MLIR - // and HLO classes. Ideally this would not be the case, but it's hard to - // change. - StatusOr Emit(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, - llvm::IRBuilder<>* builder) const final; - virtual StatusOr launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const = 0; + virtual ~KernelFusionInterface() = default; + + // Returns the fusion's launch dimensions. + virtual LaunchDimensions launch_dimensions() const = 0; + + // Computes an indexing map from thread to output element(s) of the **hero**. + // + // The dimensions in the resulting map are + // d0, d1, d2: threadIdx.{x,y,z} + // d3, d4, d5: blockIdx.{x,y,z} + // If one thread computes multiple elements, this will be represented using a + // symbol. + // + // Cases where the exact element cannot be statically determined are currently + // unsupported (scatter, in-place DUS). Implementations will return nullopt. + // Note: Work in progress, not implemented for all emitters. + virtual std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const = 0; + + // Computes an indexing map from thread to input element(s) of the root's + // **hero**. Note that in many cases this is not computable from the output + // indexing. The indexing may only be known for some operands of the hero. + virtual std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const = 0; + + static constexpr std::array kIndexingMapThreadIdxDims = {0, 1, 2}; + static constexpr std::array kIndexingMapBlockIdxDims = {3, 4, 5}; protected: - virtual Status EmitKernel(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const = 0; - virtual int num_kernels() const { return 1; } + // Returns the default mapping for the given launch dimensions: linearizes + // the thread index and then reshapes it into the given layout. + // Populates the ranges for d0, d1, d2, d3, d4, d5 from the thread counts and + // block sizes in the given launch dimensions. + static IndexingMap GetDefaultThreadIdIndexingMap( + const LaunchDimensions& launch_dims, int unroll_factor, + const Shape& shape, mlir::MLIRContext* ctx); }; -std::tuple, - std::vector /*outputs*/> +// Base class for fusions that are implemented using a single kernel, which is +// generated using LLVM. +class KernelFusionEmitterBase : public KernelFusionInterface { + public: + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; + + protected: + // Creates initializer thunks that need to run before the main kernel. + virtual absl::StatusOr EmitInitializers( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + // No initializers by default. + return FusionEmissionResult{}; + } + + virtual absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const = 0; +}; + +absl::StatusOr< + std::tuple, + std::vector /*outputs*/>> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, const std::string& suggested_name, absl::Span arguments, @@ -81,6 +134,11 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* builder); +absl::Status AnnotateKernelLaunchDimensions( + const se::DeviceDescription& device_info, + const LaunchDimensions& launch_dims, const std::string& kernel_name, + llvm::Module* llvm_module); + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/fusions.cc b/xla/service/gpu/fusions/fusions.cc index a24a5d55c2825..f8120956810b2 100644 --- a/xla/service/gpu/fusions/fusions.cc +++ b/xla/service/gpu/fusions/fusions.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,79 +14,226 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/fusions.h" +#include #include #include +#include #include -#include "absl/types/span.h" +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/strings/match.h" #include "mlir/IR/Value.h" // from @llvm-project -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/concatenate_mlir.h" #include "xla/service/gpu/fusions/copy.h" +#include "xla/service/gpu/fusions/cudnn.h" +#include "xla/service/gpu/fusions/custom.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" +#include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h" #include "xla/service/gpu/fusions/input_slices.h" +#include "xla/service/gpu/fusions/input_slices_mlir.h" #include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/loop_mlir.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/reduction.h" +#include "xla/service/gpu/fusions/reduction_mlir.h" +#include "xla/service/gpu/fusions/scatter.h" +#include "xla/service/gpu/fusions/scatter_mlir.h" #include "xla/service/gpu/fusions/transpose.h" +#include "xla/service/gpu/fusions/transpose_mlir.h" +#include "xla/service/gpu/fusions/triton.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { +namespace { -bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { - bool seen_instruction = false; - for (mlir::Operation& instr : fusion.getRegion().front()) { - if (mlir::isa(&instr)) { - continue; +bool IsParameterOrGteOfParameter(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kParameter) { + return true; + } + if (instr->opcode() == HloOpcode::kGetTupleElement) { + return IsParameterOrGteOfParameter(instr->operand(0)); + } + return false; +} + +bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) { + return absl::c_all_of( + analysis.fusion_roots(), [](const HloInstruction* root) { + return root->opcode() == HloOpcode::kDynamicUpdateSlice || + (root->opcode() == HloOpcode::kBitcast && + root->operand(0)->opcode() == HloOpcode::kDynamicUpdateSlice); + }); +} + +} // namespace + +std::optional>> +HloFusionInfo::GetCopyFusion() const { + std::vector src_buffers; + for (auto* root : analysis().fusion_roots()) { + if (root->opcode() != HloOpcode::kCopy || + root->operand(0)->opcode() != HloOpcode::kParameter || + !LayoutUtil::Equal(root->operand(0)->shape().layout(), + root->shape().layout())) { + return std::nullopt; } - if (seen_instruction) return false; - seen_instruction = true; + + const HloInstruction* src_instr = + instr_->operands()[root->operand(0)->parameter_number()]; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + buffer_assignment_->GetUniqueSlice(src_instr, {})); + src_buffers.push_back(slice); } - return seen_instruction; + + std::vector dst_buffers; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instr_->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + buffer_assignment_->GetUniqueSlice(instr_, index)); + dst_buffers.push_back(slice); + return absl::OkStatus(); + })); + + DCHECK(src_buffers.size() == dst_buffers.size()); + std::vector srcs; + std::vector dsts; + return std::make_unique(std::move(src_buffers), + std::move(dst_buffers), + /*srcs=*/std::vector(), + /*dsts=*/std::vector()); +} + +bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { + auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + instr_, buffer_assignment_, analysis().fusion_roots()); + return ret.ok() && *ret; } -std::optional> GetFusionEmitter( - HloFusionAnalysis& analysis, - absl::Span allocations, - mlir::lmhlo::FusionOp fusion_op) { +absl::StatusOr> GetFusionEmitter( + const FusionInfo& fusion_info) { + const auto& analysis = fusion_info.analysis(); + const FusionBackendConfig& backend_config = analysis.fusion_backend_config(); + + const auto& opts = + analysis.fusion_roots().front()->GetModule()->config().debug_options(); + auto check_mlir_emitters = [&](std::function + support_check) { + if (!opts.xla_gpu_enable_mlir_emitters()) { + return false; + } + if (!mlir_converter::IsHloConversionSupported( + analysis.fusion(), + fusion_info.analysis().device_info().gpu_compute_capability())) { + VLOG(5) << "Skipping MLIR emission because the fusion contains " + "unsupported instructions."; + return false; + } + if (support_check && !support_check(analysis)) { + VLOG(5) << "Skipping MLIR emission because the fusion emitter does not " + "support " + "the fusion."; + return false; + } + + static int num_mlir_emitters = 0; + // This kernel can be emitted with MLIR, but we need to check if there are + // limits to how many kernels can be emitted. + ++num_mlir_emitters; + if (num_mlir_emitters <= opts.xla_gpu_skip_mlir_kernels()) { + VLOG(5) << "Skipping MLIR emission because initial skips were requested."; + return false; + } + + int n_emitted = num_mlir_emitters - opts.xla_gpu_skip_mlir_kernels(); + if (opts.xla_gpu_max_mlir_kernels() > 0 && + n_emitted > opts.xla_gpu_max_mlir_kernels()) { + VLOG(5) << "Skipping MLIR emission because max_mlir_emitters was set."; + return false; + } + VLOG(5) << "Emitting with MLIR."; + return true; + }; + switch (analysis.GetEmitterFusionKind()) { + case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { + const auto& config = backend_config.custom_fusion_config(); + if (absl::StrContains(config.name(), "address_computation")) { + return std::make_unique(analysis); + } + return std::make_unique(); + } case HloFusionAnalysis::EmitterFusionKind::kInputSlices: + if (check_mlir_emitters(nullptr)) { + return std::make_unique(analysis); + } return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { - if (!allocations.empty() && fusion_op != nullptr) { - bool is_single = IsSingleInstructionFusion(fusion_op); - if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - fusion_op, allocations)) { - return std::make_unique(analysis); - } - if (is_single && analysis.fusion_roots().size() == 1 && - analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { - mlir::Value operand = GetHloOperands(fusion_op).front(); - mlir::Value output = GetHloOutputs(fusion_op).front(); - Shape operand_shape = GetShape(operand); - Shape output_shape = GetShape(output); - if (LayoutUtil::Equal(operand_shape.layout(), - output_shape.layout()) && - GetAllocationSlice(operand, allocations).ok()) { - return std::make_unique(operand, output); - } + if (IsDynamicUpdateSliceFusion(analysis) && + fusion_info.CanEmitDynamicUpdateSliceInPlace()) { + if (check_mlir_emitters( + MlirInPlaceDynamicUpdateSliceFusion::IsSupported)) { + return std::make_unique( + analysis); } + return std::make_unique(analysis); + } + + if (auto copy_fusion = fusion_info.GetCopyFusion()) { + return *std::move(copy_fusion); + } + + if (check_mlir_emitters(nullptr)) { + return std::make_unique(analysis); } return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kReduction: + if (check_mlir_emitters(MlirReductionFusion::IsSupported)) { + return std::make_unique(analysis); + } return std::make_unique(analysis); - case HloFusionAnalysis::EmitterFusionKind::kTranspose: + case HloFusionAnalysis::EmitterFusionKind::kScatter: { + if (check_mlir_emitters(MlirScatterFusion::IsSupported)) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); + } + case HloFusionAnalysis::EmitterFusionKind::kTranspose: { + if (check_mlir_emitters(nullptr)) { + return std::make_unique(analysis); + } return std::make_unique(analysis); - default: - break; + } + case HloFusionAnalysis::EmitterFusionKind::kConcatenate: { + if (check_mlir_emitters(nullptr)) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); + } + case HloFusionAnalysis::EmitterFusionKind::kTriton: + return std::make_unique(analysis); + case HloFusionAnalysis::EmitterFusionKind::kCuDnn: + return std::make_unique(analysis); } - return std::nullopt; } } // namespace gpu diff --git a/xla/service/gpu/fusions/fusions.h b/xla/service/gpu/fusions/fusions.h index 82fc63400b411..f91ac049acb4a 100644 --- a/xla/service/gpu/fusions/fusions.h +++ b/xla/service/gpu/fusions/fusions.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ limitations under the License. #include #include -#include "absl/types/span.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -27,14 +26,71 @@ limitations under the License. namespace xla { namespace gpu { +class FusionInfo { + public: + explicit FusionInfo(const HloFusionAnalysis& analysis) + : analysis_(analysis) {} + virtual ~FusionInfo() = default; + + const HloFusionAnalysis& analysis() const { return analysis_; } + + // If the fusion is a DUS fusion, returns whether it can be emitted in place. + // Undefined if the fusion is not a DUS fusion. + virtual bool CanEmitDynamicUpdateSliceInPlace() const = 0; + + // Attempts to create a memcpy fusion, if possible. Returns nullopt if the + // fusion failed to pattern match. Returns an error if the fusion successfully + // pattern matched, but buffer assignment failed. + // TODO(b/204548848): Find a proper abstraction for this once LMHLO is gone. + virtual std::optional>> + GetCopyFusion() const = 0; + + private: + const HloFusionAnalysis& analysis_; +}; + +class HloFusionInfo : public FusionInfo { + public: + HloFusionInfo(const HloFusionAnalysis& analysis, + const HloFusionInstruction* instr, + const BufferAssignment* buffer_assignment) + : FusionInfo(analysis), + instr_(instr), + buffer_assignment_(buffer_assignment) {} + + bool CanEmitDynamicUpdateSliceInPlace() const override; + std::optional>> + GetCopyFusion() const override; + + private: + const HloFusionInstruction* instr_; + const BufferAssignment* buffer_assignment_; +}; + +class PreBufferAssignmentFusionInfo : public FusionInfo { + public: + explicit PreBufferAssignmentFusionInfo(const HloFusionAnalysis& analysis) + : FusionInfo(analysis) {} + + bool CanEmitDynamicUpdateSliceInPlace() const override { + // Optimistically assume all DUS fusions are in-place. + return true; + } + + std::optional>> + GetCopyFusion() const override { + // Copy fusions can't be created without buffer assignment. Note: + // technically, this is only needed to generate the chunk, the validation + // itself could be done without a buffer assignment. However, we currently + // have no use for this, so it's OK to always fall back to the loop fusion. + return std::nullopt; + } +}; + // Returns the emitter for the given fusion. Returns nullopt if the fusion // type is not yet supported. -// `allocations` may be empty and `fusion_op` may be nullptr if buffer -// assignment didn't run yet. -std::optional> GetFusionEmitter( - HloFusionAnalysis& analysis, - absl::Span allocations, - mlir::lmhlo::FusionOp fusion_op); +absl::StatusOr> GetFusionEmitter( + const FusionInfo& fusion_info); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc index 5056affd6e5be..ea4c56fa576b1 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,32 +14,56 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" +#include #include #include +#include "absl/status/status.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/IRBuilder.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/llvm_ir/dynamic_update_slice_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/status.h" +#include "xla/statusor.h" namespace xla { namespace gpu { +namespace { -StatusOr InPlaceDynamicUpdateSliceEmitter::launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const { +constexpr int kDUSUpdateIndex = 1; + +} // namespace + +LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const { const auto& update_shape = dus_ops_.front()->operand(1)->shape(); - return CalculateLaunchDimensions(update_shape, - ir_emitter_context.gpu_device_info()); + return CalculateLaunchDimensions(update_shape, analysis_.device_info()); +} + +std::optional +InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* mlir_context) const { + if (hero_operand_index != kDUSUpdateIndex) { + return std::nullopt; + } + auto launch_dims = launch_dimensions(); + // It is guaranteed that all DUS ops have the same output shape at this point. + const auto& update_shape = + dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, + update_shape, mlir_context); } -Status InPlaceDynamicUpdateSliceEmitter::EmitKernel( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims, - std::vector inputs, std::vector outputs, - llvm::IRBuilder<>* builder, int kernel_index) const { +absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, std::vector inputs, + std::vector outputs, llvm::IRBuilder<>* builder) const { // In case a dynamic slice update's output is bitcasted, we need to ensure we // write to the output array using the shape and layout of the dynamic slice // update. This cast is known to be safe to do iff, in the case the output of @@ -52,6 +76,7 @@ Status InPlaceDynamicUpdateSliceEmitter::EmitKernel( } auto* fused_computation = fusion.fused_instructions_computation(); + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); FusedIrEmitter fused_emitter(elemental_emitter); for (auto [index, input] : llvm::enumerate(inputs)) { auto fused_operand = fused_computation->parameter_instruction(index); diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/xla/service/gpu/fusions/in_place_dynamic_update_slice.h index 222d182bf51d0..213f7e7ecbdea 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ #define XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ +#include #include +#include "llvm/IR/IRBuilder.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/status.h" +#include "xla/statusor.h" namespace xla { namespace gpu { @@ -49,24 +58,34 @@ namespace gpu { // modifies the output in place without touching the un-updated elements. The // update slice is assumed to be the exact same for all the // dynamic-update-slice ops. -class InPlaceDynamicUpdateSliceEmitter : public KernelFusionEmitterBase { +class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { public: - explicit InPlaceDynamicUpdateSliceEmitter(const HloFusionAnalysis& analysis) - : dus_ops_( + explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + dus_ops_( GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {} - StatusOr launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const override; + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override { + // The mapping cannot be statically computed in general, since the offsets + // are unknown. + return std::nullopt; + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: - Status EmitKernel(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; + const HloFusionAnalysis& analysis_; std::vector dus_ops_; }; diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc new file mode 100644 index 0000000000000..eccdcfceee8a8 --- /dev/null +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -0,0 +1,169 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h" + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using mlir::ImplicitLocOpBuilder; +using mlir::MLIRContext; +using mlir::Value; +using mlir::ValueRange; +using mlir::arith::AddIOp; +using mlir::func::ReturnOp; +using mlir::tensor::InsertOp; +using mlir_converter::ApplyAffineMap; +using mlir_converter::CallTargetProvider; +using mlir_converter::ClampIndex; +using mlir_converter::PartitionedComputations; +using mlir_converter::ProvideParameter; + +constexpr int kDUSUpdateIndex = 1; + +} // namespace + +/*static*/ bool MlirInPlaceDynamicUpdateSliceFusion::IsSupported( + const HloFusionAnalysis& analysis) { + return analysis.fusion_roots().size() == 1; +} + +LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions() + const { + const auto& update_shape = + dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + return CalculateLaunchDimensions(update_shape, analysis_.device_info()); +} + +std::optional +MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* mlir_context) const { + // TODO(b/331355203): Implement thread ID -> operand indexing. + if (hero_operand_index != kDUSUpdateIndex) { + return std::nullopt; + } + auto launch_dims = launch_dimensions(); + // It is guaranteed that all DUS ops have the same output shape at this point. + const auto& update_shape = + dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, + update_shape, mlir_context); +} + +std::vector +MlirInPlaceDynamicUpdateSliceFusion::GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return dus_ops_; +} + +absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( + const PartitionedComputations& computations, + const CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); + b.setInsertionPointToStart(entry_function.addEntryBlock()); + + mlir::MLIRContext* mlir_context = entry_function.getContext(); + + auto indexing = *ComputeThreadIdToInputIndexing( + /*root_index=*/0, + /*hero_operand_index=*/kDUSUpdateIndex, mlir_context); + indexing.Simplify(GetIndexingMapForInstruction); + indexing.RemoveUnusedSymbols(); + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + auto output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + + const auto& root_computation = computations.FindPartitionedComputation( + fusion.fused_instructions_computation()); + const auto& dus_subgraph = root_computation.FindSubgraph(dus_ops_.front()); + + const auto* dus_instr = + Cast(dus_ops_.front()); + const auto& update_shape = dus_instr->update()->shape(); + auto result_tensors = EmitThreadLoopNest( + b, output_tensor_args, indexing, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> llvm::SmallVector { + auto input_indices = ApplyAffineMap(indexing.GetAffineMap(), dim_values, + symbol_values, b); + SmallVector update_indices; + for (int i = 0; i < update_shape.rank(); ++i) { + int64_t update_size = update_shape.dimensions(i); + auto start_index = + ProvideParameter(dus_subgraph, dus_instr, + i + dus_instr->first_index_operand_number(), {}, + call_targets, entry_function, b)[0]; + start_index = ClampIndex( + start_index, + primitive_util::IsUnsignedIntegralType( + dus_instr + ->operand(i + dus_instr->first_index_operand_number()) + ->shape() + .element_type()), + dus_instr->shape().dimensions(i) - update_size, b); + + update_indices.push_back( + b.create(input_indices[i], start_index)); + } + + auto updated_value = + ProvideParameter(dus_subgraph, dus_instr, kDUSUpdateIndex, + input_indices, call_targets, entry_function, b)[0]; + auto insert = b.create(updated_value, output_tensors[0], + update_indices); + + return {insert.getResult()}; + }); + + b.create(result_tensors); + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h new file mode 100644 index 0000000000000..bac44f13144cd --- /dev/null +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -0,0 +1,81 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_MLIR_H_ + +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// Fusion node where the root is either: +// 1. a dynamic-update-slice op +// 2. a bitcast of a dynamic-update-slice op +// 3. a tuple op returning the result of several dynamic-update-slice ops +// 4. a tuple op returning the result of several bitcast +// dynamic-update-slice ops +// +// Lowers to LLVM via MLIR. +class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { + public: + explicit MlirInPlaceDynamicUpdateSliceFusion( + const HloFusionAnalysis& analysis) + : analysis_(analysis), + dus_ops_( + GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {} + + static bool IsSupported(const HloFusionAnalysis& analysis); + + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* indexing_context) const override { + // The mapping cannot be statically computed in general, since the offsets + // are unknown. + return std::nullopt; + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* indexing_context) const override; + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; + std::vector dus_ops_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_MLIR_H_ diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc new file mode 100644 index 0000000000000..3aabb901b498c --- /dev/null +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -0,0 +1,181 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/lib/core/status_test_util.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirInPlaceDynamicUpdateSliceFusionTest = + MlirEmitterTestBase; + +TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1) + } + ENTRY entry { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] constant(2) + i1 = s32[] constant(3) + ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + + auto* root = module->entry_computation()->root_instruction(); + + auto analysis = AnalyzeFusion(*root, device_info_); + MlirInPlaceDynamicUpdateSliceFusion fusion(analysis); + + auto thread_id_update_indexing = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); + EXPECT_THAT(thread_id_update_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + th_x floordiv 6, th_x mod 6) + domain: + th_x in [0, 29] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); + auto thread_id_dst_indexing = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt)); +} + +TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { + auto kHloString = R"( + HloModule module + + fused_computation { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1) + } + ENTRY entry { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] constant(2) + i1 = s32[] constant(3) + ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0] -> (s0 floordiv 6)> + // CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (s0 mod 6)> + // CHECK: func.func @fused_computation + // CHECK-SAME: %arg0: tensor<20x30xf32> + // CHECK-SAME: %arg1: tensor<5x6xf32> + // CHECK-SAME: %arg2: tensor + // CHECK-SAME: %arg3: tensor + // CHECK-SAME: %arg4: tensor<20x30xf32> + // CHECK-DAG: %[[C_24:.*]] = arith.constant 24 + // CHECK-DAG: %[[C_15:.*]] = arith.constant 15 + // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x + // CHECK: %[[INPUT_INDEX_0:.*]] = affine.apply #[[MAP_1]]()[%[[THREAD_ID]]] + // CHECK: %[[INPUT_INDEX_1:.*]] = affine.apply #[[MAP_2]]()[%[[THREAD_ID]]] + // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 + // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] + // CHECK: %[[MIN0:.*]] = arith.minsi %[[IDX0]], %[[C_15]] + // CHECK: %[[MAX0:.*]] = arith.maxsi %[[MIN0]], %[[C_0]] + // CHECK: %[[ADD0:.*]] = arith.addi %[[INPUT_INDEX_0]], %[[MAX0]] + // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 + // CHECK: %[[IDX1:.*]] = arith.index_cast %[[I1]] + // CHECK: %[[MIN1:.*]] = arith.minsi %[[IDX1]], %[[C_24]] + // CHECK: %[[MAX1:.*]] = arith.maxsi %[[MIN1]], %[[C_0]] + // CHECK: %[[ADD1:.*]] = arith.addi %[[INPUT_INDEX_1]], %[[MAX1]] + // CHECK: %[[UPDATE:.*]] = xla_gpu.pure_call @fused_computation_updates + // CHECK: %[[INSERT:.*]] = tensor.insert %[[UPDATE:.*]] into %arg4[%[[ADD0]], %[[ADD1]]] + // CHECK: return %[[INSERT]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { + auto kHloString = R"( + HloModule module + + fused_computation { + in = f32[7,8] parameter(0) + updates = f32[2,3] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = f32[7,8] dynamic-update-slice(in, updates, i0, i1) + } + ENTRY entry { + in = f32[7,8] parameter(0) + updates = f32[2,3] parameter(1) + i0 = s32[] constant(-20) + i1 = s32[] constant(30) + ROOT fusion = f32[7,8] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0] -> (s0 floordiv 3)> + // CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (s0 mod 3)> + // CHECK: func.func @fused_computation + // CHECK-SAME: %arg0: tensor<7x8xf32> + // CHECK-SAME: %arg1: tensor<2x3xf32> + // CHECK-SAME: %arg2: tensor + // CHECK-SAME: %arg3: tensor + // CHECK-SAME: %arg4: tensor<7x8xf32> + // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 + // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x + // CHECK: %[[INPUT_INDEX_0:.*]] = affine.apply #[[MAP_1]]()[%[[THREAD_ID]]] + // CHECK: %[[INPUT_INDEX_1:.*]] = affine.apply #[[MAP_2]]()[%[[THREAD_ID]]] + // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 + // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] + // CHECK: %[[MIN0:.*]] = arith.minsi %[[IDX0]], %[[C_5]] + // CHECK: %[[MAX0:.*]] = arith.maxsi %[[MIN0]], %[[C_0]] + // CHECK: %[[ADD0:.*]] = arith.addi %[[INPUT_INDEX_0]], %[[MAX0]] + // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 + // CHECK: %[[IDX1:.*]] = arith.index_cast %[[I1]] + // CHECK: %[[MIN1:.*]] = arith.minsi %[[IDX1]], %[[C_5]] + // CHECK: %[[MAX1:.*]] = arith.maxsi %[[MIN1]], %[[C_0]] + // CHECK: %[[ADD1:.*]] = arith.addi %[[INPUT_INDEX_1]], %[[MAX1]] + // CHECK: %[[UPDATE:.*]] = xla_gpu.pure_call @fused_computation_updates + // CHECK: %[[INSERT:.*]] = tensor.insert %[[UPDATE:.*]] into %arg4[%[[ADD0]], %[[ADD1]]] + // CHECK: return %[[INSERT]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc new file mode 100644 index 0000000000000..c0382560399c1 --- /dev/null +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" + +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase { + public: + void SetUp() override { + HloTestBase::SetUp(); + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); + } + + protected: + AffineMapPrinter printer_; + mlir::MLIRContext mlir_context_; +}; + +TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1) + } + ENTRY entry { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] constant(2) + i1 = s32[] constant(3) + ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation + } + )") + .value(); + + stream_executor::DeviceDescription device_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto fusion = dynamic_cast(emitter.get()); + ASSERT_NE(fusion, nullptr); + + auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); + EXPECT_THAT(thread_id_update_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + th_x floordiv 6, th_x mod 6) + domain: + th_x in [0, 29] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); + auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/input_slices.cc b/xla/service/gpu/fusions/input_slices.cc index 8b2463140f7f2..225de1da8be49 100644 --- a/xla/service/gpu/fusions/input_slices.cc +++ b/xla/service/gpu/fusions/input_slices.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,14 +14,37 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/input_slices.h" +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/elemental_ir_emitter.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" -#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -44,7 +67,7 @@ namespace { // Write to output of slice1 // } // -Status EmitElementForInputFusibleSlices( +absl::Status EmitElementForInputFusibleSlices( ElementalIrEmitter& elemental_emitter, const HloComputation* fused_computation, const std::vector& inputs, @@ -115,7 +138,7 @@ Status EmitElementForInputFusibleSlices( ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func); } - return OkStatus(); + return absl::OkStatus(); } // Gets the input shape of the ROOT slices, which will be used as the kernel @@ -125,7 +148,7 @@ Status EmitElementForInputFusibleSlices( // Returns the input shape of the ROOT slices if all the input shapes of ROOT // slices are the same and the slices are non-strided. Otherwise, returns // FailedPrecondition. -StatusOr GetConsistentInputShapeForRootSlices( +absl::StatusOr GetConsistentInputShapeForRootSlices( const HloComputation* fused_computation) { const HloInstruction& root = *fused_computation->root_instruction(); if (root.opcode() == HloOpcode::kSlice) { @@ -152,26 +175,41 @@ StatusOr GetConsistentInputShapeForRootSlices( } // namespace -StatusOr InputSlicesFusion::launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const { - return analysis_.GetLaunchDimensions(); +LaunchDimensions InputSlicesFusion::launch_dimensions() const { + auto* root = analysis_.fusion_roots().front(); + const auto& shape = root->operands()[0]->shape(); + return CalculateLaunchDimensions(shape, analysis_.device_info(), + {unroll_factor_}); +} + +std::optional InputSlicesFusion::ComputeThreadIdToOutputIndexing( + int64_t output_id, mlir::MLIRContext* ctx) const { + // The mapping here is trivial and the same for all outputs - slice offsets + // are applied in the indexing from slice outputs to slice inputs. + auto launch_dims = launch_dimensions(); + // The implementation requires the shapes and layouts to be the same, but we + // still use the requested output's shape for clarity. + const auto& shape = analysis_.fusion_roots()[output_id]->shape(); + return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx); } -Status InputSlicesFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims, - std::vector inputs, std::vector outputs, - llvm::IRBuilder<>* builder, int kernel_index) const { +absl::Status InputSlicesFusion::EmitKernel( + IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, std::vector inputs, + std::vector outputs, llvm::IRBuilder<>* builder) const { TF_ASSIGN_OR_RETURN(Shape element_shape, GetConsistentInputShapeForRootSlices( fusion.fused_instructions_computation())); + LaunchDimensionsConfig launch_config; + launch_config.unroll_factor = unroll_factor_; + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); return ParallelLoopEmitter( - [&](const llvm_ir::IrArray::Index index) -> Status { + [&](const llvm_ir::IrArray::Index index) -> absl::Status { return EmitElementForInputFusibleSlices( elemental_emitter, fusion.fused_instructions_computation(), inputs, outputs, index, builder); }, - element_shape, launch_dims, builder) + element_shape, launch_dims, builder, launch_config) .EmitLoop( fusion.name(), GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder)); diff --git a/xla/service/gpu/fusions/input_slices.h b/xla/service/gpu/fusions/input_slices.h index c037a4df5e19c..90f4f4e4a24d0 100644 --- a/xla/service/gpu/fusions/input_slices.h +++ b/xla/service/gpu/fusions/input_slices.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ #define XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ +#include +#include #include +#include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/status.h" namespace xla { namespace gpu { @@ -32,23 +42,32 @@ namespace gpu { // in the future. class InputSlicesFusion : public KernelFusionEmitterBase { public: - explicit InputSlicesFusion(HloFusionAnalysis& analysis) - : analysis_(analysis) {} - StatusOr launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const override; + explicit InputSlicesFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + unroll_factor_(analysis.input_output_info().has_4_bit_output ? 2 : 1) {} + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t output_id, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } protected: - Status EmitKernel(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; private: - HloFusionAnalysis& analysis_; + const HloFusionAnalysis& analysis_; + const int unroll_factor_; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/input_slices_mlir.cc b/xla/service/gpu/fusions/input_slices_mlir.cc new file mode 100644 index 0000000000000..16c4fbf506234 --- /dev/null +++ b/xla/service/gpu/fusions/input_slices_mlir.cc @@ -0,0 +1,122 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/input_slices_mlir.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +using llvm::SmallVector; +using mlir::Value; +using mlir::ValueRange; + +std::optional +MlirInputSlicesFusion::ComputeThreadIdToOutputIndexing( + int64_t output_id, mlir::MLIRContext* ctx) const { + // The mapping here is trivial and the same for all outputs - slice offsets + // are applied in the indexing from slice outputs to slice inputs. + auto launch_dims = launch_dimensions(); + // The implementation requires the shapes and layouts to be the same, but we + // still use the requested output's shape for clarity. + const auto& shape = analysis_.fusion_roots()[output_id]->shape(); + return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx); +} + +LaunchDimensions MlirInputSlicesFusion::launch_dimensions() const { + auto* root = analysis_.fusion_roots().front(); + const auto& shape = root->operands()[0]->shape(); + return CalculateLaunchDimensions(shape, analysis_.device_info(), + {unroll_factor_}); +} + +absl::Status MlirInputSlicesFusion::EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + builder.setInsertionPointToStart(entry_function.addEntryBlock()); + + // We enforce that all the root shapes have identical dimensions in + // IsHloOpSupported. + auto indexing = + ComputeThreadIdToOutputIndexing(0, entry_function.getContext()); + TF_RET_CHECK(indexing) << "Indexing is never nullopt"; + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + auto output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + + auto result_tensors = EmitThreadLoopNest( + builder, output_tensor_args, *indexing, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto output_indices = mlir_converter::ApplyAffineMap( + indexing->GetAffineMap(), dim_values, symbol_values, builder); + auto root_fn = call_targets( + fusion.fused_instructions_computation()->root_instruction()); + + SmallVector operands( + entry_function.getArguments().take_front(num_inputs)); + absl::c_copy(output_indices, std::back_inserter(operands)); + + auto result_scalars = + builder.create(root_fn, operands).getResults(); + + SmallVector result_tensors; + result_tensors.reserve(output_tensor_args.size()); + for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { + result_tensors.push_back( + builder + .create(value, tensor, output_indices) + .getResult()); + } + return result_tensors; + }); + builder.create(result_tensors); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/input_slices_mlir.h b/xla/service/gpu/fusions/input_slices_mlir.h new file mode 100644 index 0000000000000..1de06b963d9e5 --- /dev/null +++ b/xla/service/gpu/fusions/input_slices_mlir.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_MLIR_H_ + +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// Generates code for input-fusible slices. Lowers to LLVM via MLIR. +class MlirInputSlicesFusion : public MlirFusionEmitterBase { + public: + explicit MlirInputSlicesFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + unroll_factor_(analysis.input_output_info().has_4_bit_output ? 2 : 1) {} + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t output_id, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + // TODO(b/319081342): Implement this. + return std::nullopt; + } + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; + const int unroll_factor_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_MLIR_H_ diff --git a/xla/service/gpu/fusions/input_slices_mlir_test.cc b/xla/service/gpu/fusions/input_slices_mlir_test.cc new file mode 100644 index 0000000000000..b37dcfa7e0c23 --- /dev/null +++ b/xla/service/gpu/fusions/input_slices_mlir_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/input_slices_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "tsl/lib/core/status_test_util.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirInputSlicesFusionTest = MlirEmitterTestBase; + +TEST_F(MlirInputSlicesFusionTest, SimpleInputSlices) { + auto kHloString = R"( + HloModule module + + fused_computation { + %input = f32[2,3,5,7]{2,1,0,3} parameter(0) + slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]} + slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]} + ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1) + } + ENTRY entry { + %input = f32[2,3,5,7]{2,1,0,3} parameter(0) + ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: arith.cmpi sge + // CHECK: arith.cmpi sle + // CHECK: arith.andi + // CHECK: scf.if + // CHECK: func.func private @fused_computation_input + // CHECK: tensor.extract + // CHECK: func.func private @fused_computation_tuple + // CHECK-COUNT-2: xla_gpu.pure_call + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/input_slices_test.cc b/xla/service/gpu/fusions/input_slices_test.cc new file mode 100644 index 0000000000000..094bbfac7a27a --- /dev/null +++ b/xla/service/gpu/fusions/input_slices_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/input_slices.h" + +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class InputSlicesTest : public HloTestBase { + public: + void SetUp() override { + HloTestBase::SetUp(); + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); + } + + protected: + AffineMapPrinter printer_; + mlir::MLIRContext mlir_context_; +}; + +TEST_F(InputSlicesTest, ThreadIndexing) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + %input = f32[2,3,5,7]{2,1,0,3} parameter(0) + slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]} + slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]} + ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1) + } + + ENTRY entry { + %input = f32[2,3,5,7]{2,1,0,3} parameter(0) + ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation + })") + .value(); + + stream_executor::DeviceDescription device_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto fusion = dynamic_cast(emitter.get()); + ASSERT_NE(fusion, nullptr); + + auto thread_id_to_output_indexing = + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, + ((th_x + bl_x * 128) floordiv 3) mod 2, + (th_x + bl_x * 128) mod 3, + ((bl_x * 64 + th_x floordiv 2) floordiv 3) mod 5) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 29] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/loop.cc b/xla/service/gpu/fusions/loop.cc index 9e5b6aa713b36..e417f96923f4e 100644 --- a/xla/service/gpu/fusions/loop.cc +++ b/xla/service/gpu/fusions/loop.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,24 +14,242 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/loop.h" +#include +#include +#include #include +#include "absl/numeric/bits.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" +#include "xla/status.h" namespace xla { namespace gpu { +namespace { -Status LoopFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims, - std::vector inputs, std::vector outputs, - llvm::IRBuilder<>* builder, int kernel_index) const { +const Shape& GetElementShape(const HloFusionAnalysis& analysis) { + const Shape* shape = &analysis.fusion_roots().front()->shape(); + while (shape->IsTuple()) { + shape = &shape->tuple_shapes(0); + } + return *shape; +} + +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(int64_t num_elements) { + constexpr int kMaxUnrollFactor = 4; + for (int i = kMaxUnrollFactor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + return 1; +} + +// Determines if we enable the row optimized codegen. When we have a fusion with +// only pointwise operations, scalar broadcasting and row broadcasting, we can +// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in +// particular on A100. The int is the number of inputs with rank `out_rank`. Its +// value is only defined if row vectorization is enabled. +std::pair RowVectorizationEnabled( + const HloFusionAdaptor& fusion, int64_t out_rank) { + auto roots = fusion.GetRoots(); + const auto is_row_major = [](auto instr) { + // Only tested when the inputs are row-major. So only enable that case. + // Maybe it would work if only the inner dimensions is contiguous. + return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout()); + }; + bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && + is_row_major(roots[0]); + if (!row_vectorized) { + return {false, 0}; + } + + // Check that the operations in the fusion are supported. Each + // supported operation (or category) must be manually vetted as XLA + // only unrolls and relies on LLVM to vectorize. But this is brittle. + // Currently tested and supported operations: + // Elementwise, scalar and row broadcasting. + // + // We also detect at the same time if there is a row broadcasting + // operation. + int num_big_inputs = 0; + bool some_row_broadcasting = false; + HloBfsConsumersFirstTraversal( + roots, fusion, + [&](auto node) -> TraversalResult { + if (!row_vectorized) { + return TraversalResult::kInterrupt; + } + + if (node.instruction().IsElementwise()) { + return TraversalResult::kAdvance; + } + + switch (node.opcode()) { + case HloOpcode::kConstant: + return TraversalResult::kSkip; + case HloOpcode::kParameter: + return TraversalResult::kAdvance; + case HloOpcode::kBroadcast: { + auto dims = node.instruction().dimensions(); + if (dims.empty()) { + return TraversalResult::kAdvance; + } + + if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) { + some_row_broadcasting = true; + return TraversalResult::kAdvance; + } + TF_FALLTHROUGH_INTENDED; + } + default: + VLOG(2) << "Row vectorization not enabled due to: " + << node.ToString(); + row_vectorized = false; + return TraversalResult::kInterrupt; + } + }, + [&](auto argument) { + if (argument.shape().rank() == out_rank) { + ++num_big_inputs; + } + if (!is_row_major(argument)) { + row_vectorized = false; + } + }); + // Trigger only when there is a row broadcasting. + return std::make_pair(row_vectorized && some_row_broadcasting, + num_big_inputs); +} + +} // namespace + +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis) { + int unroll_factor = 1; + // Unrolling is good to read large inputs with small elements + // due to vector loads, but increases the register pressure when one + // thread has to produce multiple output elements. + // Therefore for fusions with small outputs prefer to use one thread + // per output element = no unroll. + // Call 'small' fusions that use less threads than the GPU has. + const auto& element_shape = GetElementShape(analysis); + int64_t num_elements = ShapeUtil::ElementsIn(element_shape); + int64_t n_threads_max = analysis.device_info().threads_per_core_limit() * + analysis.device_info().core_count(); + if (num_elements >= n_threads_max && + !MayPreventVectorization(analysis.fusion())) { + unroll_factor = ComputeMaxUnrollFactor(num_elements); + } + // CHECK that unroll_factor is a power-of-2, as needed by the logic below. + CHECK(absl::has_single_bit(static_cast(unroll_factor))); + if (analysis.input_output_info().has_4_bit_output && unroll_factor == 1) { + // Ensure a single thread writes to a byte containing two int4 values by + // setting unroll_factor to 2. unroll_factor is always a power of 2, so + // setting it to 2 here ensures unroll_factor is even when there are 4-bit + // outputs. Setting unroll_factor is safe even if there are an odd number of + // elements, as the parallel loop emitter will insert a bounds check in this + // case to ensure the out-of-bounds element is not computed and written. + // Setting unroll_factor is safe even if MayPreventVectorization returns + // false, as the MayPreventVectorization check is an optimization, not a + // correctness requirement. + unroll_factor = 2; + } + VLOG(2) << "Unroll factor: " << unroll_factor; + + bool row_vectorized; + int num_big_inputs; + std::tie(row_vectorized, num_big_inputs) = + RowVectorizationEnabled(analysis.fusion(), element_shape.rank()); + bool few_waves = !HloAnyOf( + analysis.fusion().GetRoots(), analysis.fusion(), [&](auto instr) { + if (instr.opcode() == HloOpcode::kParameter || + instr.opcode() == HloOpcode::kConstant || + HloInstruction::IsOpElementwise(instr.opcode())) { + return false; + } + if (auto broadcast = + DynCast(&instr.instruction())) { + if (broadcast->dimensions().empty() || + // More than 3 big inputs cause a speed regression. + (row_vectorized && num_big_inputs <= 3)) { + return false; + } + } + VLOG(2) << "few_waves not enabled due to: " + << instr.instruction().ToString(); + return true; + }); + + LaunchDimensionsConfig launch_config{unroll_factor, few_waves, + row_vectorized}; + // Check that the shapes is supported. + if (launch_config.row_vectorized && + ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(), + launch_config) <= 0) { + VLOG(2) << "Cancelling row_vectorization as the shape isn't supported."; + launch_config.row_vectorized = false; + launch_config.few_waves = false; + } + return launch_config; +} + +LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} + +std::optional LoopFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + auto launch_dims = launch_dimensions(); + return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, + GetElementShape(analysis_), ctx); +} + +std::optional LoopFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + std::optional thread_id_to_output_indexing = + ComputeThreadIdToOutputIndexing(root_index, ctx); + if (!thread_id_to_output_indexing.has_value()) { + return std::nullopt; + } + const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; + auto output_to_input_indexing = + ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + IndexingMapSet output_to_input_indexing_set = + output_to_input_indexing.indexing_maps[hero_operand_index]; + // Since we are computing the indexing for a non-fusion op, there is only one + // indexing map per operand. + CHECK_EQ(output_to_input_indexing_set.size(), 1); + IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( + *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); + thread_id_to_input_indexing_map.Simplify(GetIndexingMapForInstruction); + return thread_id_to_input_indexing_map; +} + +absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const { + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); FusedIrEmitter fused_emitter(elemental_emitter); for (int i = 0; i < fusion.fused_parameters().size(); i++) { fused_emitter.BindGenerator( @@ -47,13 +265,13 @@ Status LoopFusion::EmitKernel( GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder, - *analysis_.GetLoopFusionConfig()) + config_) .EmitLoop(fusion.name(), index_type); } -StatusOr LoopFusion::launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const { - return analysis_.GetLaunchDimensions(); +LaunchDimensions LoopFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetElementShape(analysis_), + analysis_.device_info(), config_); } } // namespace gpu diff --git a/xla/service/gpu/fusions/loop.h b/xla/service/gpu/fusions/loop.h index 8a87641e30fee..e466abe66a843 100644 --- a/xla/service/gpu/fusions/loop.h +++ b/xla/service/gpu/fusions/loop.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_H_ #define XLA_SERVICE_GPU_FUSIONS_LOOP_H_ +#include +#include #include +#include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/status.h" namespace xla { namespace gpu { @@ -30,24 +36,32 @@ namespace gpu { // Generic loop fusion. class LoopFusion : public KernelFusionEmitterBase { public: - explicit LoopFusion(HloFusionAnalysis& analysis) : analysis_(analysis) {} - StatusOr launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const override; + explicit LoopFusion(const HloFusionAnalysis& analysis); + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: - Status EmitKernel(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; private: - HloFusionAnalysis& analysis_; + const HloFusionAnalysis& analysis_; + LaunchDimensionsConfig config_; }; +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis); + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/loop_mlir.cc b/xla/service/gpu/fusions/loop_mlir.cc new file mode 100644 index 0000000000000..bf41d50930ea9 --- /dev/null +++ b/xla/service/gpu/fusions/loop_mlir.cc @@ -0,0 +1,149 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/loop_mlir.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using mlir::Value; +using mlir::ValueRange; + +const Shape& GetFusionResultShape(const HloFusionAnalysis& analysis) { + const Shape* shape = &analysis.fusion_roots().front()->shape(); + while (shape->IsTuple()) { + shape = &shape->tuple_shapes(0); + } + return *shape; +} + +} // namespace + +std::optional MlirLoopFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + auto launch_dims = launch_dimensions(); + return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, + GetFusionResultShape(analysis_), ctx); +} + +std::optional MlirLoopFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + std::optional thread_id_to_output_indexing = + ComputeThreadIdToOutputIndexing(root_index, ctx); + if (!thread_id_to_output_indexing.has_value()) { + return std::nullopt; + } + const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; + auto output_to_input_indexing = + ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + IndexingMapSet output_to_input_indexing_set = + output_to_input_indexing.indexing_maps[hero_operand_index]; + // Since we are computing the indexing for a non-fusion op, there is only one + // indexing map per operand. + CHECK_EQ(output_to_input_indexing_set.size(), 1); + IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( + *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); + thread_id_to_input_indexing_map.Simplify(GetIndexingMapForInstruction); + return thread_id_to_input_indexing_map; +} + +LaunchDimensions MlirLoopFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetFusionResultShape(analysis_), + analysis_.device_info(), config_); +} + +absl::Status MlirLoopFusion::EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + builder.setInsertionPointToStart(entry_function.addEntryBlock()); + + // We enforce that all the root shapes have identical dimensions in + // IsHloOpSupported. + auto indexing = + ComputeThreadIdToOutputIndexing(0, entry_function.getContext()); + TF_RET_CHECK(indexing) << "Indexing is never nullopt"; + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + auto output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + + auto body_builder = [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto output_indices = mlir_converter::ApplyAffineMap( + indexing->GetAffineMap(), dim_values, symbol_values, builder); + auto root_fn = call_targets( + fusion.fused_instructions_computation()->root_instruction()); + + // Generate the operands for the root function: input tensors + + // output indices. + SmallVector operands( + entry_function.getArguments().take_front(num_inputs)); + absl::c_copy(output_indices, std::back_inserter(operands)); + auto result_scalars = + builder.create(root_fn, operands).getResults(); + + SmallVector result_tensors; + result_tensors.reserve(output_tensor_args.size()); + for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { + result_tensors.push_back(builder.create( + value, tensor, output_indices)); + } + return result_tensors; + }; + + builder.create( + EmitThreadLoopNest(builder, output_tensor_args, *indexing, body_builder)); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/loop_mlir.h b/xla/service/gpu/fusions/loop_mlir.h new file mode 100644 index 0000000000000..228c8c87b5ff2 --- /dev/null +++ b/xla/service/gpu/fusions/loop_mlir.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ + +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// Generic loop fusion. Lowers to LLVM via MLIR. +class MlirLoopFusion : public MlirFusionEmitterBase { + public: + explicit MlirLoopFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; + LaunchDimensionsConfig config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_LOOP_MLIR_H_ diff --git a/xla/service/gpu/fusions/loop_mlir_test.cc b/xla/service/gpu/fusions/loop_mlir_test.cc new file mode 100644 index 0000000000000..1f3d41bddc46a --- /dev/null +++ b/xla/service/gpu/fusions/loop_mlir_test.cc @@ -0,0 +1,392 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/loop_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirLoopFusionTest = MlirEmitterTestBase; + +TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + neg { + %input = f32[100,200,300] parameter(0) + ROOT neg = f32[100,200,300] negate(%input) + } + ENTRY entry { + %input = f32[100,200,300] parameter(0) + ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirLoopFusion fusion(analysis); + auto thread_id_to_output_indexing = + fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + + EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + (((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100, + (((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200, + (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1007] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 11] + unroll_id in [0, 3] + (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999996] +)")); +} + +TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + neg { + %input = f32[20] parameter(0) + ROOT neg = f32[20] negate(%input) + } + ENTRY entry { + %input = f32[20] parameter(0) + ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + auto thread_id_to_output_indexing = + fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); + auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); +} + +TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + bcast { + %input = f32[20] parameter(0) + ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1} + } + ENTRY entry { + %input = f32[20] parameter(0) + ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirLoopFusion fusion(analysis); + auto thread_id_to_output_indexing = + fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10, + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20, + (th_x + bl_x * 128) mod 30) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); + auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); +} + +TEST_F(MlirLoopFusionTest, NoCodeDuplication) { + // This test HLO is copied from + // xla/service/fusion_node_indexing_evaluation_test.cc. + auto kHloString = R"( + HloModule test_module + + %fused_computation (param: f32[6]) -> f32[2] { + %param = f32[6]{0} parameter(0) + %slice0.1 = f32[5]{0} slice(f32[6]{0} %param), slice={[0:5]} + %slice0.2 = f32[5]{0} slice(f32[6]{0} %param), slice={[1:6]} + %add0 = f32[5]{0} add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2) + %slice1.1 = f32[4]{0} slice(f32[5]{0} %add0), slice={[0:4]} + %slice1.2 = f32[4]{0} slice(f32[5]{0} %add0), slice={[1:5]} + %add1 = f32[4]{0} add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2) + %slice2.1 = f32[3]{0} slice(f32[4]{0} %add1), slice={[0:3]} + %slice2.2 = f32[3]{0} slice(f32[4]{0} %add1), slice={[1:4]} + %add2 = f32[3]{0} add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2) + %slice3.1 = f32[2]{0} slice(f32[3]{0} %add2), slice={[0:2]} + %slice3.2 = f32[2]{0} slice(f32[3]{0} %add2), slice={[1:3]} + ROOT %add3 = f32[2]{0} add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2) + } + ENTRY entry_computation { + p0 = f32[] parameter(0) + add = f32[] add(p0, p0) + broadcast = f32[6]{0} broadcast(add), dimensions={} + ROOT %fusion = f32[2]{0} fusion(broadcast), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-COUNT-4: arith.add + // CHECK-NOT: arith.add + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirLoopFusionTest, TwoUsersConsistentIndexing) { + auto kHloString = R"( + HloModule test_module + + %fused_computation (param: f32[6]) -> f32[2] { + %p0 = f32[2]{0} parameter(0) + %p1 = f32[2]{0} parameter(1) + %add = f32[2] add(%p0, %p1) + %sub = f32[2] subtract(%p0, %p1) + %mul = f32[2] multiply(%add, %sub) + %div = f32[2] divide(%add, %sub) + ROOT %atan2 = f32[2] atan2(%mul, %div) + } + ENTRY entry_computation { + p0 = f32[2] parameter(0) + p1 = f32[2] parameter(1) + ROOT %fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: func.func @fused_computation + // CHECK-NEXT: gpu.thread_id + // CHECK-NEXT: pure_call @fused_computation_atan2 + // CHECK-NEXT: tensor.insert + // CHECK-NEXT: return + + // CHECK: func.func private @fused_computation_atan2 + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: addf + // CHECK-NEXT: subf + // CHECK-NEXT: mulf + // CHECK-NEXT: divf + // CHECK-NEXT: atan2 + // CHECK-NEXT: return + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirLoopFusionTest, ComplexOps) { + auto kHloString = R"( + HloModule test_module + + %fused_computation { + %p0 = f32[2]{0} parameter(0) + %p1 = f32[2]{0} parameter(1) + %p2 = c64[2]{0} parameter(2) + %complex = c64[2] complex(%p0, %p1) + ROOT %add = c64[2] add(%complex, %p2) + } + ENTRY entry_computation { + p0 = f32[2] parameter(0) + p1 = f32[2] parameter(1) + p2 = c64[2] parameter(2) + ROOT %fusion = c64[2] fusion(p0, p1, p2), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: func.func @fused_computation + // CHECK-NEXT: gpu.thread_id + // CHECK-NEXT: pure_call @fused_computation_add + // CHECK-NEXT: tensor.insert + // CHECK-NEXT: return + + // CHECK: func.func private @fused_computation_add + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: complex.create + // CHECK-NEXT: tensor.extract + // CHECK-NEXT: complex.add + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { + auto kHloString = R"( + HloModule test_module + + %fused_computation { + %iota = f32[10,20,30] iota(), iota_dimension=2 + %copy = f32[10,20,30] copy(%iota) + %bitcast = s32[10,20,30] bitcast-convert(%copy) + %broadcast = s32[2,10,3,20,5,30,7] broadcast(%bitcast), + dimensions={1,3,5} + %reshape = s32[20,60,150,7] reshape(%broadcast) + %reverse = s32[20,60,150,7] reverse(%reshape), dimensions={2,3} + ROOT %transpose = s32[60,20,7,150] transpose(%reverse), + dimensions={1,0,3,2} + } + ENTRY entry_computation { + ROOT %fusion = s32[60,20,7,150] fusion(), + kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-COUNT-2: func.func + // CHECK-NOT: func.func + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirLoopFusionTest, VariadicReduce) { + auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + scalar_lhs.0 = f32[] parameter(0) + scalar_lhs.1 = f32[] parameter(1) + scalar_rhs.0 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add = f32[] add(scalar_lhs.0, scalar_rhs.0) + mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add, mul) + } + fused_computation { + param_0 = f32[3,4,5]{2,1,0} parameter(0) + param_1 = f32[3,4,5]{2,1,0} parameter(1) + param_2 = f32[] parameter(2) + ROOT d.1 = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, + f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), + dimensions={0,2}, to_apply=Add + } + ENTRY main { + a = f32[3,4,5]{2,1,0} parameter(0) + b = f32[3,4,5]{2,1,0} parameter(1) + c = f32[] constant(0) + ROOT fusion = (f32[4]{0}, f32[4]{0}) fusion(a, b, c), + kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: func @fused_computation( + // CHECK: %[[TID_X:.*]] = gpu.thread_id x + // CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fused_computation_d_1 + // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[TID_X]]] + // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[TID_X]]] + // CHECK: return %[[INSERTED_1]], %[[INSERTED_2]] + + // CHECK: func private @fused_computation_d_1 + // CHECK: %[[RET:.*]]:2 = func.call @Add_t + // CHECK: yield %[[RET]]#0, %[[RET]]#1 + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirLoopFusionTest, MinimumMaximum) { + auto kHloString = R"( + HloModule Test + + fused_computation { + param0 = f64[] parameter(0) + param1 = f64[] parameter(1) + + minimum = f64[] minimum(f64[] param0, f64[] param1) + maximum = f64[] maximum(f64[] param0, f64[] param1) + ROOT tuple = (f64[], f64[]) tuple(minimum, maximum) + } + + ENTRY main { + param0 = f64[] parameter(0) + param1 = f64[] parameter(1) + ROOT fusion = (f64[], f64[]) fusion(f64[] param0, f64[] param1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: fused_computation_tuple + // CHECK: arith.minimumf + // CHECK: arith.maximumf + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/loop_test.cc b/xla/service/gpu/fusions/loop_test.cc new file mode 100644 index 0000000000000..1bb5fdb8705d3 --- /dev/null +++ b/xla/service/gpu/fusions/loop_test.cc @@ -0,0 +1,223 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/loop.h" + +#include +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class LoopTest : public HloTestBase { + public: + void SetUp() override { + HloTestBase::SetUp(); + + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}); + } + + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + AffineMapPrinter printer_; + mlir::MLIRContext mlir_context_; +}; + +absl::StatusOr> GetFusion( + const HloFusionAnalysis& analysis) { + TF_ASSIGN_OR_RETURN( + auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis})); + auto fusion = dynamic_cast(emitter.get()); + TF_RET_CHECK(fusion != nullptr); + + emitter.release(); + return std::unique_ptr{fusion}; +} + +TEST_F(LoopTest, ThreadIndexingUnrolled) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + neg { + %input = f32[100,200,300] parameter(0) + ROOT neg = f32[100,200,300] negate(%input) + } + + ENTRY entry { + %input = f32[100,200,300] parameter(0) + ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); + auto thread_id_to_output_indexing = + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + (((bl_x * 16 + th_x floordiv 8) floordiv 3 + chunk_id * 5376) floordiv 625) mod 100, + (((th_x + bl_x * 128) floordiv 3 + chunk_id * 43008) floordiv 25) mod 200, + (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1007] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 11] + unroll_id in [0, 3] + (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999996] +)")); +} + +TEST_F(LoopTest, ThreadIndexingNotUnrolled) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + neg { + %input = f32[20] parameter(0) + ROOT neg = f32[20] negate(%input) + } + + ENTRY entry { + %input = f32[20] parameter(0) + ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); + auto thread_id_to_output_indexing = + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); + auto thread_id_to_input_indexing = + loop_fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) + domain: + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + )")); +} + +TEST_F(LoopTest, Broadcast) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + bcast { + %input = f32[20] parameter(0) + ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1} + } + + ENTRY entry { + %input = f32[20] parameter(0) + ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); + auto thread_id_to_output_indexing = + loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, + &mlir_context_); + EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 75) mod 10, + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20, + (th_x + bl_x * 128) mod 30) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); + auto thread_id_to_input_indexing = + loop_fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), + MatchIndexingString(R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 64 + th_x floordiv 2) floordiv 15) mod 20) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 5999] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/BUILD b/xla/service/gpu/fusions/mlir/BUILD new file mode 100644 index 0000000000000..f88c0d464b244 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/BUILD @@ -0,0 +1,353 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//xla:xla.bzl", "xla_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "computation_partitioner", + srcs = ["computation_partitioner.cc"], + hdrs = ["computation_partitioner.h"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:union_find", + "//xla/hlo/ir:hlo", + "//xla/service/llvm_ir:llvm_util", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + ], +) + +xla_cc_test( + name = "computation_partitioner_test", + srcs = ["computation_partitioner_test.cc"], + deps = [ + ":computation_partitioner", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "elemental_hlo_to_mlir", + srcs = ["elemental_hlo_to_mlir.cc"], + hdrs = ["elemental_hlo_to_mlir.h"], + deps = [ + ":computation_partitioner", + "//xla:comparison_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/mlir_hlo:type_conversion", + "//xla/service:algorithm_util", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "elemental_hlo_to_mlir_test", + srcs = ["elemental_hlo_to_mlir_test.cc"], + deps = [ + ":computation_partitioner", + ":elemental_hlo_to_mlir", + "//xla:status_macros", + "//xla/mlir_hlo", + "//xla/service:hlo_parser", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "mlir_fusion_emitter", + srcs = ["mlir_fusion_emitter.cc"], + hdrs = ["mlir_fusion_emitter.h"], + deps = [ + ":computation_partitioner", + ":elemental_hlo_to_mlir", + ":passes", + ":type_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:buffer_assignment", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:kernel_reuse_cache", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:target_util", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_map", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationInterfaces", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToStandard", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "mlir_fusion_emitter_test", + srcs = ["mlir_fusion_emitter_test.cc"], + deps = [ + ":computation_partitioner", + ":mlir_fusion_emitter", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/model:indexing_map", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationInterfaces", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@tsl//tsl/platform:statusor", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=GpuFusionTransforms", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = [ + "expand_float_ops.cc", + "lower_func.cc", + "lower_tensors.cc", + "lower_to_llvm.cc", + "lower_xla_gpu_to_scf.cc", + "merge_pointers_to_same_slice.cc", + "propagate_slice_indices.cc", + "simplify_affine.cc", + "simplify_arith.cc", + ], + hdrs = ["passes.h"], + deps = [ + ":passes_inc_gen", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MathTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:VectorTransforms", + ], +) + +cc_library( + name = "type_util", + srcs = ["type_util.cc"], + hdrs = ["type_util.h"], + deps = [ + "//xla:shape_util", + "//xla/mlir/utils:type_util", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "type_util_test", + srcs = ["type_util_test.cc"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/xla/service/gpu/fusions/mlir/README.md b/xla/service/gpu/fusions/mlir/README.md new file mode 100644 index 0000000000000..d692bd279bce9 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/README.md @@ -0,0 +1,157 @@ +# XLA MLIR fusion emitters + +This is a prototype of a new loop emitter. The main goals are: + +- Fixing exponential code size issues with the current emitter. We should be + able to generate reasonable code for any fusion (note that execution time may + still be bad, but that's a problem for priority fusion). +- Fixing compile time (as a result of the above). +- Make the code easier to understand thanks to gradual lowering. +- Eventually extend the concepts here to the other emitters (transpose, reduce + in particular) + +## High-level overview + +The code consists of the following big building blocks: + +- Computation partitioning - splitting an HLO computation into functions +- Elemental emission of XLA instructions +- Based on the above two: emission of functions +- The actual emitter +- Lowerings to LLVM + +## Partitioning + +See `computation_partitioner.h`. + +Non-elementwise HLO instructions cannot always be emitted together. Consider the +following HLO graph: + +``` + param + | + log + | \ + | transpose + | / + add +``` + +If we emit this in a single function, the `log` will be accessed at two +different indices for each element of the `add`. The old emitters solve this +problem by generating the `log` twice. For this particular graph, this is not +a problem, but when there are multiple splits, the code size grows +exponentially. + +Here, we solve this problem by partitioning the graph into pieces that can be +safely emitted as one function. The criteria are: + +- Instructions that have only one user are safe to emit together with their + user. +- Instructions that have multiple users are safe to emit together with their + users if they are accessed through the same indices by all users. + +In the example above, the `add` and `tranpose` access different indices of the +`log`, so it is not safe to emit it together with them. + +The graph is therefore partitioned into three functions (each containing just +one instruction). + +## Elemental emission + +See `elemental_hlo_to_mlir.h`. + +Elemental emission is based on `mlir_hlo` and reuses it for all element-wise +instructions. For the most part, this is straightforward, but there are some +interesting things going on here. + +### Indexing transformations + +Some instructions (`transpose`, `broadcast`, `reshape`, `slice`, `reverse` and +a few more) are purely transformations on indices: to produce an element of the +result, we need to produce some other element of the input. For this, we can +reuse XLA's `indexing_analysis`, which has functions to produce the output to +input mapping for an instruction. + +For example, for a `transpose` from `[20,40]` to `[40,20]`, it will produce the +following indexing map (one affine expression per input dimension; d0 and d1 are +the output dimensions): + +``` + (d0, d1) -> d1 + (d0, d1) -> d0 +``` + +So for these pure index transformation instructions, we can simply get the map, +apply it to the output indices, and produce the input at the resulting index. + +Similarly, the `pad` op uses indexing maps and constraints for most of the +implementation. `pad` is also an indexing transformation with some added checks +to see if we return an element of the input or the padding value. + +### Tuples + +We do not support internal `tuple`s. We also do not support nested tuple +outputs. All XLA graphs that use these features can be converted to graphs that +do not. + +### Gather + +We only support canonical gathers as produced by [`gather_simplifier`]( +https://github.com/openxla/xla/blob/main/xla/service/gather_simplifier.h). + +## Emission of functions + +For a subgraph of a computation with parameters `%p0` to `%p_n`, and subgraph +roots with rank `r` and element types (`e0` to `e_m`), we use the following MLIR +function signature: + +`````` +(%p0: tensor<...>, %p1: tensor<...>, ..., %pn: tensor<...>, + %i0: index, %i1: index, ..., %i_r-1: index) -> (e0, ..., e_m) +`````` + +That is, we have one tensor input per computation parameter, one index input per +dimension of the output, and one result per output. + +To emit a function, we simply use the elemental emitter above, and recursively +emit its operands until we reach the edge of the subgraph. Then, we: + +- emit a `tensor.extract` for parameters +- or emit a `func.call` for other subgraphs + +## Putting it together: the loop emitter + +The loop emitter first partitions its fusion computation and emits code for each +subgraph. Then, it has to generate an entry function. The entry function is +different from the functions above, since it has no indices as inputs (just the +thread and block IDs) and actually needs to write the output somewhere. For the +loop emitter, this is fairly straightforward, but the transpose and reduction +emitters have non-trivial write logic. + +The signature of the entry computation is: + +``` +(%p0: tensor<...>, ..., %pn: tensor<...>, + %r0: tensor<...>, ..., %rn: tensor<...>) -> (tensor<...>, ..., tensor<...>) +``` + +Where like before, the `%pn`s are the parameters of the computation, and the +`%rn`s are the results of the computation. The entry computation takes the +results as tensors, `tensor.insert`s updates into them, and then returns them. +No other uses of the output tensors are allowed. + +## Lowerings to LLVM + +We mostly use the standard LLVM lowerings, but there are a few special passes. +We cannot use the `memref` lowerings for tensors, since we don't bufferize the +IR and our ABI is not compatible with the `memref` ABI. Instead, we have a +custom lowering directly from tensors to `LLVM`. + +- The lowering of tensors is done in `lower_tensors.cc`. `tensor.extract` is + lowered to `llvm.load`, `tensor.insert` to `llvm.store`, in the obvious way. +- `propagate_slice_indices` and `merge_pointers_to_same_slice` together + implement a detail of buffer assignment and XLA's ABI: if two tensors share + the same buffer slice, they are only passed once. These passes deduplicate the + function arguments. + diff --git a/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/xla/service/gpu/fusions/mlir/computation_partitioner.cc new file mode 100644 index 0000000000000..ad46b914a9e48 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -0,0 +1,403 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/union_find.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +absl::flat_hash_map PartitionGraphByIndexing( + const HloComputation& computation) { + constexpr int kRootIndexing = 0; + int next_indexing = 1; + absl::flat_hash_map indexing; + + std::function indexing_for_instr; + indexing_for_instr = [&](const HloInstruction* instr) -> int { + auto it = indexing.find(instr); + if (it != indexing.end()) return it->second; + + if (instr->opcode() != HloOpcode::kTuple && + !HloInstruction::IsOpElementwise(instr->opcode())) { + return indexing[instr] = next_indexing++; + } + if (instr->user_count() == 0) { + return indexing[instr] = kRootIndexing; + } + // If all users have the same indexing, we can reuse it. + std::optional instr_indexing = std::nullopt; + for (auto* user : instr->users()) { + auto user_indexing = indexing_for_instr(user); + if (user->opcode() == HloOpcode::kConcatenate || + (instr_indexing && user_indexing != *instr_indexing)) { + instr_indexing = std::nullopt; + break; + } + instr_indexing = user_indexing; + } + return indexing[instr] = instr_indexing ? *instr_indexing : next_indexing++; + }; + for (auto* instr : computation.instructions()) { + indexing_for_instr(instr); + } + return indexing; +} + +} // namespace + +std::string PartitionedComputation::Subgraph::ToString() const { + std::ostringstream ss; + ss << "SUBGRAPH " << name << " {\n"; + for (auto instr : instructions_post_order) { + ss << " "; + if (absl::c_linear_search(roots, instr)) { + ss << "ROOT "; + } + ss << instr->ToString() << "\n"; + } + ss << "}"; + return ss.str(); +} + +std::string PartitionedComputation::ToString() const { + std::ostringstream ss; + ss << "PartitionedComputation " << computation_->name() << ":"; + for (const Subgraph& subgraph : subgraphs_) { + ss << "\n" << subgraph.ToString(); + } + return ss.str(); +} + +std::string PartitionedComputations::ToString() const { + std::ostringstream ss; + ss << "PartitionedComputations:"; + for (const auto& partitioned_computation : partitioned_computations_) { + ss << "\n" << partitioned_computation.ToString(); + } + return ss.str(); +} + +PartitionedComputation::PartitionedComputation( + const HloComputation* computation, + std::function is_subgraph_root) + : computation_(computation) { + CHECK_NE(computation, nullptr); + // For each instruction, figure out what function it goes in. Parameters don't + // count. + absl::node_hash_map> + disjoint_sets; + auto indexing = PartitionGraphByIndexing(*computation); + for (auto* instruction : computation->instructions()) { + disjoint_sets[instruction].Get() = instruction; + } + for (auto* instruction : computation->instructions()) { + // If the instruction has to become a subgraph root, then we do not merge. + bool can_merge = !is_subgraph_root(instruction); + if (instruction->user_count() > 0) { + // If all users have the same indexing, we can merge. + int64_t one_user_indexing = indexing.at(instruction->users().front()); + can_merge &= + absl::c_all_of(instruction->users(), [&](const HloInstruction* user) { + return indexing.at(user) == one_user_indexing; + }); + } + auto is_bad_gather = [&](const HloInstruction* user) { + // Don't merge into a gather that would evaluate the index more than once. + return user->opcode() == HloOpcode::kGather && + user->operand_index(instruction) == 1 && + instruction->shape().dimensions(1) > 1; + }; + auto is_concat = [&](const HloInstruction* user) { + // Concat codegen doesn't work if any of a concat's transitive inputs is + // reused. Instead of checking, we just cut the function at the concat, + // which has the benefit of leading to slightly easier to read IR. + return user->opcode() == HloOpcode::kConcatenate; + }; + can_merge &= absl::c_none_of(instruction->users(), is_bad_gather); + can_merge &= absl::c_none_of(instruction->users(), is_concat); + if (can_merge) { + auto& set = disjoint_sets[instruction]; + for (auto* user : instruction->users()) { + set.Merge(&disjoint_sets[user]); + } + } + } + + ConstHloInstructionMap> functions; + for (auto* instruction : computation->MakeInstructionPostOrder()) { + functions[disjoint_sets[instruction].Get()].push_back(instruction); + } + + subgraphs_.reserve(functions.size()); + for (auto& [cluster_id, instructions] : functions) { + auto is_different_cluster = [cluster_id = cluster_id, + &disjoint_sets](auto* user) { + auto it = disjoint_sets.find(user); + if (it == disjoint_sets.end()) { + return true; + } + return it->second.Get() != cluster_id; + }; + + std::vector roots; + for (auto* instruction : instructions) { + if (instruction->user_count() == 0 || + absl::c_any_of(instruction->users(), is_different_cluster)) { + roots.push_back(instruction); + } + } + CHECK(!roots.empty()) << "No roots found"; + std::string name = llvm_ir::SanitizeFunctionName(absl::StrCat( + roots.front()->parent()->name(), "_", + absl::StrJoin(roots, "_", [](std::string* out, const auto* root) { + absl::StrAppend(out, root->name()); + }))); + subgraphs_.push_back( + Subgraph{.name = std::move(name), + .instructions = {instructions.begin(), instructions.end()}, + .instructions_post_order = std::move(instructions), + .roots = std::move(roots)}); + } + + for (const auto& subgraph : subgraphs_) { + for (const auto* instruction : subgraph.instructions_post_order) { + instructions_to_subgraphs_[instruction] = &subgraph; + } + } +} + +std::optional +PartitionedComputation::Subgraph::ForEpilogue( + const HloComputation* computation, + absl::Span heroes) { + if (heroes.empty() || + (heroes.size() == 1 && heroes[0] == computation->root_instruction())) { + return std::nullopt; + } + + PartitionedComputation::Subgraph subgraph{ + .name = llvm_ir::SanitizeFunctionName( + absl::StrCat(computation->name(), "__epilogue__")), + .roots = {computation->root_instruction()}, + }; + for (auto [index, hero] : llvm::enumerate(heroes)) { + subgraph.injected_values[hero] = index; + } + + std::vector instructions_pre_order; + absl::flat_hash_set seen; + std::function visit; + visit = [&](const HloInstruction* instruction) { + if (!seen.insert(instruction).second) return; + instructions_pre_order.push_back(instruction); + for (auto [index, operand] : llvm::enumerate(instruction->operands())) { + if (!subgraph.injected_values.contains(operand)) { + visit(operand); + } + } + }; + + visit(computation->root_instruction()); + subgraph.instructions = std::move(seen); + subgraph.instructions_post_order = {instructions_pre_order.rbegin(), + instructions_pre_order.rend()}; + return subgraph; +} + +PartitionedComputations::PartitionedComputations( + const HloComputation* fusion, + absl::Span heroes) + : fusion_(fusion), + epilogue_(PartitionedComputation::Subgraph::ForEpilogue(fusion, heroes)) { + // Collect all transitively called computations (including the fusion itself). + absl::flat_hash_set seen; + std::vector computations; + std::function visit; + visit = [&](const HloComputation* computation) { + if (!seen.insert(computation).second) return; + computations.push_back(computation); + for (auto* instr : computation->instructions()) { + absl::c_for_each(instr->called_computations(), visit); + } + }; + visit(fusion); + + absl::flat_hash_set roots{heroes.begin(), + heroes.end()}; + for (auto* instruction : heroes) { + roots.insert(instruction->operands().begin(), + instruction->operands().end()); + } + auto is_root = [&](const HloInstruction* instruction) { + return roots.contains(instruction); + }; + + partitioned_computations_.reserve(computations.size()); + for (auto* computation : computations) { + computation_to_partitioning_[computation] = + &partitioned_computations_.emplace_back( + PartitionedComputation{computation, is_root}); + } +} + +absl::flat_hash_map +PartitionedComputations::DeclareFunctions(mlir::ModuleOp module) const { + absl::flat_hash_map + mapping; + mlir::ImplicitLocOpBuilder builder(module.getLoc(), module->getContext()); + builder.setInsertionPointToEnd(module.getBody()); + for (const auto& computation : partitioned_computations_) { + for (const auto& subgraph : computation.subgraphs()) { + auto func_op = CreateSubgraphMlirFunction(subgraph, builder); + func_op->setAttr("llvm.linkage", mlir::LLVM::LinkageAttr::get( + module->getContext(), + mlir::LLVM::Linkage::Internal)); + mapping[&subgraph] = func_op; + } + } + if (epilogue_) { + auto func_op = CreateSubgraphMlirFunction(*epilogue_, builder); + func_op->setAttr("llvm.linkage", + mlir::LLVM::LinkageAttr::get( + module->getContext(), mlir::LLVM::Linkage::Internal)); + mapping[&*epilogue_] = func_op; + } + return mapping; +} + +const PartitionedComputation::Subgraph& PartitionedComputations::FindSubgraph( + const HloInstruction* instr) const { + return FindPartitionedComputation(instr->parent()).FindSubgraph(instr); +} + +CallTargetProvider PartitionedComputations::CreateCallTargetProvider( + const absl::flat_hash_map& subgraph_to_func) const { + return [&, this](const HloInstruction* instr) { + const auto& subgraph = FindSubgraph(instr); + CHECK(subgraph_to_func.contains(&subgraph)); + return subgraph_to_func.at(&subgraph); + }; +} + +mlir::func::FuncOp CreateSubgraphMlirFunction( + const PartitionedComputation::Subgraph& subgraph, + mlir::ImplicitLocOpBuilder& b) { + auto* computation = subgraph.roots.front()->parent(); + llvm::SmallVector parameter_types; + llvm::SmallVector result_types; + + auto element_type = [&](const auto& shape) { + return *ConvertPrimitiveTypeToMlirType(shape.element_type(), b); + }; + + const xla::Shape* first_root_shape = nullptr; + for (auto* root : subgraph.roots) { + if (root->shape().IsTuple()) { + for (auto& shape : root->shape().tuple_shapes()) { + if (!first_root_shape) { + first_root_shape = &shape; + } + result_types.push_back(element_type(shape)); + } + } else { + if (!first_root_shape) { + first_root_shape = &root->shape(); + } + result_types.push_back(element_type(root->shape())); + } + } + + llvm::SmallVector arg_attrs; + // We support the entry computation here for convenience of testing. The entry + // computation is never code generated here. + if (computation->IsFusionComputation() || computation->IsEntryComputation()) { + for (auto* param : computation->parameter_instructions()) { + parameter_types.push_back(TensorShapeToMlirType(param->shape(), b)); + arg_attrs.emplace_back(); + } + for (int dim = 0; dim < first_root_shape->rank(); ++dim) { + parameter_types.push_back(b.getIndexType()); + arg_attrs.emplace_back(mlir::DictionaryAttr::get( + b.getContext(), + {b.getNamedAttr("xla.range", + b.getIndexArrayAttr( + {0, first_root_shape->dimensions(dim) - 1}))})); + } + + // Populate arguments for injected parameters (values that are computed + // outside the function and are passed into it). + int operand_offset = parameter_types.size(); + parameter_types.resize(operand_offset + subgraph.injected_values.size()); + arg_attrs.resize(parameter_types.size()); + + for (auto [value, index] : subgraph.injected_values) { + parameter_types[operand_offset + index] = element_type(value->shape()); + } + } else { + for (auto* param : computation->parameter_instructions()) { + parameter_types.push_back(element_type(param->shape())); + } + } + auto ty = b.getFunctionType(parameter_types, result_types); + auto func_op = b.create( + subgraph.name, ty, + /*attrs=*/llvm::ArrayRef{}, arg_attrs); + // Needed so that the function can potentially be inlined in-place. + func_op.setPrivate(); + return func_op; +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/computation_partitioner.h b/xla/service/gpu/fusions/mlir/computation_partitioner.h new file mode 100644 index 0000000000000..5c6f61a2dff48 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/computation_partitioner.h @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +// Partitions an HLO computation into subgraphs so that all users of a node have +// consistent indexing, i. e. when we compute a node `a` with users `b` and `c`, +// all three nodes will have the same indexing - neither of `b` or `c` will be a +// transpose, reshape, reduce, etc. +// +// Consider the following example, where we assume all nodes affect indexing: +// +// a b Here we create four subgraphs: `a,d,c,e`, `b`, `f` and `g`. If +// \ /| `f` and `g` didn't change the indexing, they would be included +// d c | in the `a,d,c,e` subgraph, so we'd have `b` and the rest. +// \ | | +// e | Note that if some users have the same indexing as a node (e.g. +// / \| `e` and `g` in the graph to the left), we still have to create +// f g separate subgraphs for `f` and `g`. +// +// The purpose of this partitioning is to allow us to generate code without ever +// having to duplicate instructions: users with incompatible indexing will be in +// different subgraphs, each of which will emit a call to the producer graph. +// +// Note that this partitioning will sometimes create silly subgraphs that should +// (and will) be inlined, e. g. containing only a constant or only a broadcast. +// +// There is a hooks to customize this partitioning: +// is_subgraph_root: forces the clusterer to start a new subgraph at a given +// instruction. The instruction is guaranteed to be a in a different subgraph +// than its users. +class PartitionedComputation { + public: + explicit PartitionedComputation( + const HloComputation* computation, + std::function is_subgraph_root = + [](const HloInstruction*) { return false; }); + + struct Subgraph { + // A unique name of the subgraph. Used for function names. + std::string name; + + // The instructions that make up this subgraph. + absl::flat_hash_set instructions; + std::vector instructions_post_order; + + // The roots. These are guaranteed not to have users inside the subgraph. + std::vector roots; + + // For values that are function arguments (not function calls), stores the + // mapping from value to the argument index. The arguments always come + // after the tensor parameters and output indices; the indices are relative + // to the argument after the last index argument. + absl::flat_hash_map injected_values; + + std::string ToString() const; + + // Creates a subgraph for the given heroes' epilogue. The heroes values will + // be injected into the subgraph. + // If there is no epilogue (the root is the hero), returns nullopt. + static std::optional ForEpilogue( + const HloComputation* computation, + absl::Span heroes); + }; + + absl::Span subgraphs() const { return subgraphs_; } + + const HloComputation& computation() const { return *computation_; } + + const Subgraph& GetRootSubgraph() const { + return FindSubgraph(computation_->root_instruction()); + } + + // Returns the subgraph containing the given instruction. + const Subgraph& FindSubgraph(const HloInstruction* instr) const { + return *instructions_to_subgraphs_.at(instr); + } + + std::string ToString() const; + + private: + const HloComputation* computation_; + std::vector subgraphs_; + absl::flat_hash_map + instructions_to_subgraphs_; +}; + +// Given a root of a subgraph, returns the corresponding function. +using CallTargetProvider = + std::function; + +// A collection of PartitionedComputations, starting at a fusion computation and +// including all transitively called computations. +class PartitionedComputations { + public: + explicit PartitionedComputations( + const HloComputation* fusion, + absl::Span heroes = {}); + + const PartitionedComputation& FindPartitionedComputation( + const HloComputation* computation) const { + return *computation_to_partitioning_.at(computation); + } + + const PartitionedComputation::Subgraph& FindSubgraph( + const HloInstruction* instr) const; + + absl::Span partitioned_computations() const { + return partitioned_computations_; + } + + // If the fusion has an epilogue (i.e., the heroes are inside the fusion), + // returns it. + const std::optional& epilogue() const { + return epilogue_; + } + + const HloComputation* fusion() const { return fusion_; } + + // Creates a call target lookup function for use with SubgraphToMlir. + CallTargetProvider CreateCallTargetProvider( + const absl::flat_hash_map& subgraph_to_func) const; + + // Declares func.func ops for each subgraph in each computation and returns a + // mapping from subgraph to declared function. + absl::flat_hash_map + DeclareFunctions(mlir::ModuleOp module) const; + + std::string ToString() const; + + private: + std::vector partitioned_computations_; + absl::flat_hash_map + computation_to_partitioning_; + const HloComputation* fusion_; + std::optional epilogue_; +}; + +// Returns an MLIR function declaration for the given subgraph. For subgraphs of +// fusions, the signature is: +// (ptr, ptr, ..., index, index, ...) -> element type(s) +// For subgraphs of called computations, the signature is: +// (elemen type, ...) -> element type(s) +// +// Subgraphs of fusions will also have range (xla.range = [lower_bound, +// upper_bound], both bounds are inclusive) annotations on their index +// arguments. +mlir::func::FuncOp CreateSubgraphMlirFunction( + const PartitionedComputation::Subgraph& subgraph, + mlir::ImplicitLocOpBuilder& b); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ diff --git a/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc b/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc new file mode 100644 index 0000000000000..606e4f4182a92 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc @@ -0,0 +1,305 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" + +#include +#include + +#include +#include +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using ::testing::ElementsAre; +using ::testing::SizeIs; + +using ComputationPartitionerTest = HloTestBase; + +std::string PrintAndErase(mlir::func::FuncOp func) { + std::string out; + llvm::raw_string_ostream os(out); + os << func; + // Erase the function so we don't leak memory. + func.erase(); + return out; +} + +TEST_F(ComputationPartitionerTest, PartitionDiamonds) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %param = f32[6] parameter(0) + %slice0.1 = f32[5] slice(f32[6]{0} %param), slice={[0:5]} + %slice0.2 = f32[5] slice(f32[6]{0} %param), slice={[1:6]} + %add0 = f32[5] add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2) + %slice1.1 = f32[4] slice(f32[5]{0} %add0), slice={[0:4]} + %slice1.2 = f32[4] slice(f32[5]{0} %add0), slice={[1:5]} + %add1 = f32[4] add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2) + %slice2.1 = f32[3] slice(f32[4]{0} %add1), slice={[0:3]} + %slice2.2 = f32[3] slice(f32[4]{0} %add1), slice={[1:4]} + %add2 = f32[3] add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2) + %slice3.1 = f32[2] slice(f32[3]{0} %add2), slice={[0:2]} + %slice3.2 = f32[2] slice(f32[3]{0} %add2), slice={[1:3]} + ROOT %add3 = f32[2] add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + auto param = fusion->GetInstructionWithName("param"); + auto slice01 = fusion->GetInstructionWithName("slice0.1"); + auto slice02 = fusion->GetInstructionWithName("slice0.2"); + auto add0 = fusion->GetInstructionWithName("add0"); + auto slice11 = fusion->GetInstructionWithName("slice1.1"); + auto slice12 = fusion->GetInstructionWithName("slice1.2"); + auto add1 = fusion->GetInstructionWithName("add1"); + auto slice21 = fusion->GetInstructionWithName("slice2.1"); + auto slice22 = fusion->GetInstructionWithName("slice2.2"); + auto add2 = fusion->GetInstructionWithName("add2"); + auto slice31 = fusion->GetInstructionWithName("slice3.1"); + auto slice32 = fusion->GetInstructionWithName("slice3.2"); + auto add3 = fusion->GetInstructionWithName("add3"); + + const auto& graphs = computation.subgraphs(); + ASSERT_THAT(graphs, SizeIs(5)); + EXPECT_THAT(graphs[0].instructions_post_order, ElementsAre(param)); + EXPECT_THAT(graphs[1].instructions_post_order, + ElementsAre(slice01, slice02, add0)); + EXPECT_THAT(graphs[2].instructions_post_order, + ElementsAre(slice11, slice12, add1)); + EXPECT_THAT(graphs[3].instructions_post_order, + ElementsAre(slice21, slice22, add2)); + EXPECT_THAT(graphs[4].instructions_post_order, + ElementsAre(slice31, slice32, add3)); + + EXPECT_THAT(graphs[1].roots, ElementsAre(add0)); + EXPECT_THAT(graphs[2].roots, ElementsAre(add1)); + EXPECT_THAT(graphs[3].roots, ElementsAre(add2)); + EXPECT_THAT(graphs[4].roots, ElementsAre(add3)); + + EXPECT_EQ(&computation.GetRootSubgraph(), &graphs[4]); + EXPECT_EQ(&computation.FindSubgraph(slice21), &graphs[3]); +} + +TEST_F(ComputationPartitionerTest, TupleRoot) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %p0 = f32[6] parameter(0) + %p1 = f32[6] parameter(1) + %add = f32[6] add(p0, p1) + %sub = f32[6] subtract(p0, p1) + ROOT %root = (f32[6], f32[6]) tuple(%add, %sub) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + + ASSERT_THAT(computation.subgraphs(), SizeIs(1)) << computation.ToString(); +} + +TEST_F(ComputationPartitionerTest, TupleRootWithInjectedValues) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fused_computation { + p0 = f32[4] parameter(0) + c0 = f32[] constant(0) + reduce = f32[] reduce(p0, c0), dimensions={0}, to_apply=add + bitcast = f32[1] bitcast(reduce) + abs = f32[1] abs(bitcast) + log = f32[1] log(abs) + sign = f32[1] sign(bitcast) + ROOT tuple = (f32[1], f32[1]) tuple(log, sign) + })") + .value(); + + auto* fused_computation = module->GetComputationWithName("fused_computation"); + PartitionedComputations fusion( + fused_computation, + /*heroes=*/ + {fused_computation->GetInstructionWithName("reduce")}); + + // The epilogue should be one subgraph. + EXPECT_EQ( + &fusion.FindSubgraph( + fused_computation->GetInstructionWithName("bitcast")), + &fusion.FindSubgraph(fused_computation->GetInstructionWithName("tuple"))); +} + +TEST_F(ComputationPartitionerTest, EnforcePartitioning) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %p0 = f32[64, 32] parameter(0) + %p1 = f32[64, 32] parameter(1) + %add = f32[64, 32] add(p0, p1) + %transpose = f32[32, 64] transpose(%add), dimensions={1, 0} + %exp = f32[32, 64] exponential(%transpose) + ROOT %root = f32[32, 64] tanh(%exp) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kTranspose; + }); + ASSERT_THAT(computation.subgraphs(), SizeIs(2)); + EXPECT_THAT(computation.GetRootSubgraph().roots, SizeIs(1)); + EXPECT_THAT(computation.GetRootSubgraph().instructions_post_order, SizeIs(2)); +} + +TEST_F(ComputationPartitionerTest, PartiallyMergable) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + fused_computation { + %p0 = f32[10,10] parameter(0) + %p1 = f32[10,10] parameter(1) + %add = f32[10,10] add(%p0, %p1) + %transpose = f32[10,10] transpose(%add), dimensions={1,0} + ROOT %sub = f32[10,10] subtract(%add, %transpose) + })") + .value(); + + auto* fusion = module->GetComputationWithName("fused_computation"); + ASSERT_NE(fusion, nullptr); + PartitionedComputation computation(fusion); + + auto transpose = fusion->GetInstructionWithName("transpose"); + auto sub = fusion->GetInstructionWithName("sub"); + + ASSERT_THAT(computation.subgraphs(), SizeIs(2)); + EXPECT_THAT(computation.GetRootSubgraph().instructions_post_order, + ElementsAre(transpose, sub)); +} + +TEST_F(ComputationPartitionerTest, SubgraphSignatures) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %add = f32[] add(%p0, %p1) + } + + fusion { + %p0 = f32[10,10]{0,1} parameter(0) + %p1 = f32[10,10]{1,0} parameter(1) + %c0 = f32[] constant(2) + %bc = f32[10,10]{0,1} bitcast(%p1) + %add = f32[10,10] add(%p0, %bc) + ROOT %reduce = f32[10] reduce(%add, %c0), dimensions={1}, to_apply=add + } + + ENTRY main { + %p0 = f32[10,10] parameter(0) + %p1 = f32[10,10] parameter(1) + ROOT %fusion = f32[10] fusion(%p0, %p1), kind=kLoop, calls=fusion + })") + .value(); + + mlir::MLIRContext context; + context.loadDialect(); + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); + + PartitionedComputation fusion(module->GetComputationWithName("fusion")); + EXPECT_EQ( + PrintAndErase( + CreateSubgraphMlirFunction(fusion.GetRootSubgraph(), builder)), + "func.func private @fusion_reduce(tensor<10x10xf32, dense<[0, 1]> : " + "tensor<2xi64>>, tensor<10x10xf32>, index {xla.range = [0 : index, 9 : " + "index]}) -> f32"); + + PartitionedComputation add(module->GetComputationWithName("add")); + EXPECT_EQ( + PrintAndErase(CreateSubgraphMlirFunction(add.GetRootSubgraph(), builder)), + "func.func private @add_add(f32, f32) -> f32"); +} + +TEST_F(ComputationPartitionerTest, SubgraphSignaturesWithInjectedValues) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + %fused_computation { + %p0 = f32[2,16,17] parameter(0) + %log = f32[2,16,17] log(%p0) + %transpose = f32[2,17,16] transpose(%log), dimensions={0,2,1} + %p1 = f32[] parameter(1) + %bc = f32[2,17,16] broadcast(%p1), dimensions={} + ROOT %add = f32[2,17,16] add(%transpose, %bc) + } + + ENTRY main { + %p0 = f32[2,16,17] parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = f32[2,17,16] fusion(%p0, %p1), kind=kInput, + calls=%fused_computation + } + )") + .value(); + + mlir::MLIRContext context; + context.loadDialect(); + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); + + // We force a split at the transpose (like the transpose emitter would do) and + // enforce that the transpose is injected as a parameter into the epilogue. + auto* fused_computation = module->GetComputationWithName("fused_computation"); + auto* transpose = fused_computation->GetInstructionWithName("transpose"); + PartitionedComputations fusion(fused_computation, + /*heroes=*/ + {transpose}); + auto& epilogue_graph = fusion.epilogue(); + auto& injected_values = epilogue_graph->injected_values; + EXPECT_EQ(injected_values.size(), 1); + std::pair injected_operand_key( + fused_computation->root_instruction(), 0); + ASSERT_TRUE(injected_values.contains(transpose)); + EXPECT_EQ(injected_values.at(transpose), 0); + EXPECT_EQ( + PrintAndErase(CreateSubgraphMlirFunction(*epilogue_graph, builder)), + "func.func private @fused_computation__epilogue__(tensor<2x16x17xf32>, " + "tensor, index {xla.range = [0 : index, 1 : index]}, index " + "{xla.range = [0 : index, 16 : index]}, index {xla.range = [0 : " + "index, 15 : index]}, f32) -> f32"); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc new file mode 100644 index 0000000000000..4c62612341502 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -0,0 +1,1405 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" +#include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using llvm::SmallVector; +using llvm::SmallVectorImpl; +using mlir::Block; +using mlir::FloatType; +using mlir::ImplicitLocOpBuilder; +using mlir::IRMapping; +using mlir::Location; +using mlir::MLIRContext; +using mlir::OpBuilder; +using mlir::Value; +using mlir::ValueRange; +using mlir::arith::AndIOp; +using mlir::arith::CmpFOp; +using mlir::arith::CmpFPredicate; +using mlir::arith::CmpIOp; +using mlir::arith::CmpIPredicate; +using mlir::arith::ConstantIndexOp; +using mlir::arith::ConstantOp; +using mlir::arith::SelectOp; +using mlir::scf::ForOp; +using mlir::scf::IfOp; +using mlir::scf::YieldOp; + +namespace arith = ::mlir::arith; +namespace mhlo = ::mlir::mhlo; +namespace scf = ::mlir::scf; + +// HLO opcodes that we never support. +static auto& kUnsupportedOps = + *new absl::flat_hash_set{HloOpcode::kAddDependency, + HloOpcode::kAfterAll, + HloOpcode::kAllGather, + HloOpcode::kAllGatherDone, + HloOpcode::kAllGatherStart, + HloOpcode::kAllReduce, + HloOpcode::kAllReduceDone, + HloOpcode::kAllReduceStart, + HloOpcode::kAllToAll, + HloOpcode::kAsyncDone, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncUpdate, + HloOpcode::kBatchNormGrad, + HloOpcode::kBatchNormInference, + HloOpcode::kBatchNormTraining, + HloOpcode::kCholesky, + HloOpcode::kCollectivePermute, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kCopyDone, + HloOpcode::kCopyStart, + HloOpcode::kCustomCall, + HloOpcode::kDomain, + HloOpcode::kDynamicReshape, + HloOpcode::kFft, + HloOpcode::kFusion, + HloOpcode::kGetDimensionSize, + HloOpcode::kOptimizationBarrier, + HloOpcode::kInfeed, + HloOpcode::kOutfeed, + HloOpcode::kPartitionId, + HloOpcode::kRecv, + HloOpcode::kRecvDone, + HloOpcode::kReduceScatter, + HloOpcode::kReplicaId, + HloOpcode::kRng, + HloOpcode::kRngBitGenerator, + HloOpcode::kRngGetAndUpdateState, + HloOpcode::kScatter, + HloOpcode::kSelectAndScatter, + HloOpcode::kSend, + HloOpcode::kSendDone, + HloOpcode::kSetDimensionSize, + HloOpcode::kSort, + HloOpcode::kTopK, + HloOpcode::kTriangularSolve, + HloOpcode::kWhile, + HloOpcode::kConditional, + HloOpcode::kStochasticConvert, + HloOpcode::kCall}; + +static auto& kUnimplementedOps = + *new absl::flat_hash_set{HloOpcode::kMap}; + +bool IsUnsupportedConstant(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kConstant && + (!ShapeUtil::IsEffectiveScalar(instr->shape()) || + primitive_util::IsUnsignedIntegralType( + instr->shape().element_type()) || + primitive_util::IsComplexType(instr->shape().element_type())); +} + +bool IsUnsupportedTuple(const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kTuple) { + return false; + } + + if (instr->user_count() > 0) { + // Internal tuples are unsupported. + return true; + } + + // Nested tuples and tokens are unsupported. + if (absl::c_any_of(instr->operands(), + [&](auto* op) { return !op->shape().IsArray(); })) { + return true; + } + + // All tuple elements must have bitcast-compatible dimensions (element types + // may differ). + auto first_shape = instr->shape().tuple_shapes(0); + for (int i = 1; i < instr->operand_count(); ++i) { + const auto& tuple_shape = instr->shape().tuple_shapes(i); + if (!ShapeUtil::EqualIgnoringElementType(tuple_shape, first_shape) && + !ShapeUtil::IsReshapeOrTransposeBitcast(tuple_shape, first_shape, + /*ignore_element_type=*/true)) { + return true; + } + } + return false; +} + +bool IsUnsupportedGather(const HloInstruction* instr) { + // We assume gather simplifier ran, so we don't need to support all gather + // forms. + if (instr->opcode() != HloOpcode::kGather) return false; + + auto* gather = Cast(instr); + const auto& dims = gather->gather_dimension_numbers(); + if (dims.index_vector_dim() != 1 || !dims.collapsed_slice_dims().empty() || + gather->operand(1)->shape().rank() != 2) { + return true; + } + + for (auto [index, val] : llvm::enumerate(dims.start_index_map())) { + if (index != val) return true; + } + for (auto [index, val] : llvm::enumerate(dims.offset_dims())) { + if (index + 1 != val) return true; + } + return false; +} + +absl::StatusOr GetSingleOperandValue( + const OperandProvider& operand_provider, const HloInstruction* instr, + int operand_index, ValueRange indices) { + TF_ASSIGN_OR_RETURN(auto operand, + operand_provider(instr, operand_index, indices)); + TF_RET_CHECK(operand.size() == 1) << "Expected operand to be a single value."; + return operand.front(); +} + +SmallVector ConvertToSignless(const SmallVector& values, + ImplicitLocOpBuilder& b) { + mlir::mhlo::RemoveSignTypeConverter sign_converter; + SmallVector results; + results.reserve(values.size()); + for (auto& value : values) { + auto signless_type = sign_converter.convertType(value.getType()); + results.push_back( + b.create(signless_type, value) + .getResult(0)); + } + return results; +} + +absl::StatusOr> EmitReduce( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, ImplicitLocOpBuilder& b) { + auto* mlir_context = b.getContext(); + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, 0, mlir_context); + const auto& indexing_map = *indexing.indexing_maps[0].begin(); + + SmallVector init_values; + for (int i = instr->operand_count() / 2; i < instr->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(init_values.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, {})); + // Convert back to signed type. + TF_ASSIGN_OR_RETURN(auto element_mlir_type, + ConvertPrimitiveTypeToMlirType( + instr->operand(i)->shape().element_type(), b)); + init_values.back() = b.create( + element_mlir_type, init_values.back()) + .getResult(0); + } + + auto body = + [&](ValueRange iter_args, ValueRange dim_values, + ValueRange symbol_values) -> absl::StatusOr> { + auto indices = ApplyAffineMap(indexing_map.GetAffineMap(), dim_values, + symbol_values, b); + + SmallVector args{iter_args}; + for (int i = 0; i < instr->operand_count() / 2; ++i) { + TF_ASSIGN_OR_RETURN( + args.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, indices)); + // Convert back to signed type. + TF_ASSIGN_OR_RETURN(auto element_mlir_type, + ConvertPrimitiveTypeToMlirType( + instr->operand(i)->shape().element_type(), b)); + args.back() = b.create( + element_mlir_type, args.back()) + .getResult(0); + } + auto reducer = call_target_provider( + instr->called_computations().front()->root_instruction()); + return b.create(reducer, args).getResults(); + }; + + TF_ASSIGN_OR_RETURN( + auto result, + EmitLoopNestWithStatus(b, indices, init_values, indexing_map, body)); + + return ConvertToSignless(result, b); +} + +absl::StatusOr> EmitReduceWindow( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, ImplicitLocOpBuilder& b) { + MLIRContext* mlir_context = b.getContext(); + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, 0, mlir_context); + auto indexing_map = *indexing.indexing_maps[0].begin(); + indexing_map.RescaleSymbols(); + + auto reduce_window = DynCast(instr); + CHECK(reduce_window != nullptr); + + SmallVector init_values; + for (auto [index, init_value] : + llvm::enumerate(reduce_window->init_values())) { + TF_ASSIGN_OR_RETURN( + init_values.emplace_back(), + GetSingleOperandValue(operand_provider, instr, + reduce_window->input_count() + index, {})); + // Convert back to signed type. + TF_ASSIGN_OR_RETURN( + auto element_mlir_type, + ConvertPrimitiveTypeToMlirType(init_value->shape().element_type(), b)); + init_values.back() = b.create( + element_mlir_type, init_values.back()) + .getResult(0); + } + + auto body = + [&](ValueRange iter_args, ValueRange dim_values, + ValueRange symbol_values) -> absl::StatusOr> { + auto indices = ApplyAffineMap(indexing_map.GetAffineMap(), dim_values, + symbol_values, b); + + SmallVector args{iter_args}; + for (auto [index, input] : llvm::enumerate(reduce_window->inputs())) { + TF_ASSIGN_OR_RETURN( + args.emplace_back(), + GetSingleOperandValue(operand_provider, instr, index, indices)); + + // Convert back to signed type. + TF_ASSIGN_OR_RETURN( + auto element_mlir_type, + ConvertPrimitiveTypeToMlirType(input->shape().element_type(), b)); + args.back() = b.create( + element_mlir_type, args.back()) + .getResult(0); + } + + auto reducer = call_target_provider( + instr->called_computations().front()->root_instruction()); + return b.create(reducer, args).getResults(); + }; + + TF_ASSIGN_OR_RETURN( + auto result, + EmitLoopNestWithStatus(b, indices, init_values, indexing_map, body)); + + return ConvertToSignless(result, b); +} + +absl::StatusOr> EmitConcat( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + int concat_dim = + Cast(instr)->concatenate_dimension(); + int64_t offset = 0; + IfOp outermost_if = nullptr; + SmallVector operand_indices = indices; + for (auto [index, operand] : llvm::enumerate(instr->operands())) { + int64_t limit = offset + operand->shape().dimensions(concat_dim); + auto ins = b.create(CmpIPredicate::ult, indices[concat_dim], + b.create(limit)); + + auto generate_operand = [&, index = index]() { + operand_indices[concat_dim] = b.create( + indices[concat_dim], b.create(offset)); + TF_ASSIGN_OR_RETURN(auto operand, + operand_provider(instr, index, operand_indices)); + b.create(operand); + return absl::OkStatus(); + }; + + if (index < instr->operand_count() - 1) { + auto if_op = + b.create(mlir::TypeRange{result_element_type}, ins, true, true); + if (outermost_if == nullptr) { + outermost_if = if_op; + } else { + b.create(if_op.getResults()); + } + + b.setInsertionPointToStart(if_op.getBody(0)); + TF_RETURN_IF_ERROR(generate_operand()); + b.setInsertionPointToStart(if_op.getBody(1)); + } else { + TF_RETURN_IF_ERROR(generate_operand()); + } + offset = limit; + } + + b.setInsertionPointAfter(outermost_if); + return outermost_if.getResults(); +} + +absl::StatusOr> EmitDynamicSlice( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, ImplicitLocOpBuilder& b) { + llvm::SmallVector input_indices(indices); + + const auto& input_shape = instr->operand(0)->shape(); + for (int i = 0; i < input_shape.rank(); ++i) { + TF_ASSIGN_OR_RETURN( + auto offset, GetSingleOperandValue(operand_provider, instr, i + 1, {})); + offset = + ClampIndex(offset, + primitive_util::IsUnsignedIntegralType( + instr->operand(i + 1)->shape().element_type()), + input_shape.dimensions(i) - instr->shape().dimensions(i), b); + input_indices[i] = b.create(input_indices[i], offset); + } + + return operand_provider(instr, 0, input_indices); +} + +absl::StatusOr> EmitDynamicUpdateSlice( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + mlir::Value is_in_bounds = + b.create(b.getIntegerAttr(b.getI1Type(), 1)); + mlir::SmallVector update_indices; + const auto& updates_shape = instr->operand(1)->shape(); + for (int i = 0; i < instr->shape().rank(); ++i) { + int64_t update_size = updates_shape.dimensions(i); + TF_ASSIGN_OR_RETURN( + auto start_index, + GetSingleOperandValue(operand_provider, instr, i + 2, {})); + start_index = ClampIndex(start_index, + primitive_util::IsUnsignedIntegralType( + instr->operand(i + 2)->shape().element_type()), + instr->shape().dimensions(i) - update_size, b); + + auto end_index = b.create( + start_index, b.create(b.getIndexAttr(update_size))); + + is_in_bounds = b.create( + is_in_bounds, + b.create(CmpIPredicate::sge, indices[i], start_index)); + is_in_bounds = b.create( + is_in_bounds, + b.create(CmpIPredicate::slt, indices[i], end_index)); + + update_indices.push_back(b.create(indices[i], start_index)); + } + + auto if_op = b.create(mlir::TypeRange{result_element_type}, + is_in_bounds, true, true); + b.setInsertionPointToStart(if_op.getBody(0)); + TF_ASSIGN_OR_RETURN( + auto updated_value, + GetSingleOperandValue(operand_provider, instr, 1, update_indices)); + b.create(updated_value); + + b.setInsertionPointToStart(if_op.getBody(1)); + TF_ASSIGN_OR_RETURN( + auto original_value, + GetSingleOperandValue(operand_provider, instr, 0, indices)); + b.create(original_value); + + b.setInsertionPointAfter(if_op); + return if_op.getResults(); +} + +absl::StatusOr> EmitGather( + const HloInstruction* instr, ValueRange indices, + const OperandProvider& operand_provider, ImplicitLocOpBuilder& b) { + auto row = indices[0]; + auto zero = b.create(0); + // Gather allows the index vector to contain fewer elements than the rank + // of the input. In that case, the remaining indices are 0. + SmallVector operand_indices(instr->operand(0)->shape().rank(), zero); + + // Produce start indices. + int num_indices = instr->operand(1)->shape().dimensions(1); + for (int i = 0; i < num_indices; ++i) { + auto i_val = i == 0 ? zero : b.create(i); + int64_t slice_size = instr->gather_slice_sizes()[i]; + int64_t input_size = instr->operand(0)->shape().dimensions()[i]; + // Read and clamp index. + TF_ASSIGN_OR_RETURN(auto input_index, + operand_provider(instr, 1, {row, i_val})); + TF_RET_CHECK(input_index.size() == 1) + << "Expected operand to be a single value."; + operand_indices[i] = + ClampIndex(input_index.front(), + primitive_util::IsUnsignedIntegralType( + instr->operand(1)->shape().element_type()), + input_size - slice_size, b); + } + + // Add offsets. + for (int i = 0; i < operand_indices.size(); ++i) { + operand_indices[i] = + b.createOrFold(operand_indices[i], indices[i + 1]); + } + + return operand_provider(instr, 0, operand_indices); +} + +// For a given instruction, deduces the indices of each parameter that are +// needed for a given output index. +SmallVector> GetInputIndices( + const HloInstructionIndexing& indexing, ValueRange output_indices, + ImplicitLocOpBuilder& b) { + SmallVector> indices; + for (auto& maps : indexing.indexing_maps) { + CHECK_EQ(maps.size(), 1); + auto map = maps.begin()->GetAffineMap(); + CHECK(!maps.begin()->IsUndefined()); + indices.emplace_back() = ApplyAffineMap(map, output_indices, {}, b); + } + return indices; +} + +absl::StatusOr> EmitPad( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + auto indexing = ComputeOutputToInputIndexing(instr, 0, b.getContext()); + const auto& indexing_map = *indexing.indexing_maps[0].begin(); + mlir::Value is_in_bounds = CheckConstraints(indexing_map, indices, {}, b); + + auto if_op = b.create(mlir::TypeRange{result_element_type}, + is_in_bounds, true, true); + b.setInsertionPointToStart(if_op.getBody(0)); + TF_ASSIGN_OR_RETURN(auto input_value, + GetSingleOperandValue( + operand_provider, instr, 0, + GetInputIndices(indexing, indices, + b)[0 /* indexing for operand 0 */])); + b.create(input_value); + + b.setInsertionPointToStart(if_op.getBody(1)); + TF_ASSIGN_OR_RETURN(auto padding_value, + GetSingleOperandValue(operand_provider, instr, 1, {})); + b.create(padding_value); + + b.setInsertionPointAfter(if_op); + return if_op.getResults(); +} + +absl::StatusOr EmitFloatCast(Value value, mlir::Type target_type, + ImplicitLocOpBuilder& b) { + if (value.getType().getIntOrFloatBitWidth() < + target_type.getIntOrFloatBitWidth()) { + return b.create(target_type, value); + } + if (value.getType().getIntOrFloatBitWidth() > + target_type.getIntOrFloatBitWidth()) { + return b.create(target_type, value); + } + return value; +} + +absl::StatusOr EmitMulAdd(Value lhs, Value rhs, Value accumulator, + mlir::Type result_element_type, + mlir::Type accumulator_type, + ImplicitLocOpBuilder& b) { + if (result_element_type.isa()) { + if (result_element_type.isBF16()) { + lhs = b.create(b.getF32Type(), lhs); + rhs = b.create(b.getF32Type(), rhs); + } + TF_ASSIGN_OR_RETURN( + Value casted, + EmitFloatCast(b.create(lhs, rhs), accumulator_type, b)); + return b.create(accumulator, casted); + } + if (result_element_type.isInteger(1)) { + return b.create(accumulator, + b.create(lhs, rhs)); + } + return b.create(accumulator, + b.create(lhs, rhs)); +} + +absl::StatusOr> EmitDotLoop( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, /*output_id=*/0, b.getContext()); + const IndexingMap& lhs_indexing_map = *indexing.indexing_maps.at(0).begin(); + const IndexingMap& rhs_indexing_map = *indexing.indexing_maps.at(1).begin(); + + const mlir::Type accumulator_type = + result_element_type.isBF16() ? b.getF32Type() : result_element_type; + Value accum_init_value = + b.create(b.getZeroAttr(accumulator_type)).getResult(); + + // For convolutions with `batch_group_count` > 1, there is an additional + // symbol for LHS (group id) - ignore it for RHS. + size_t rhs_symbol_count = rhs_indexing_map.GetSymbolCount(); + + auto body = + [&](ValueRange iter_args, ValueRange dim_values, + ValueRange symbol_values) -> absl::StatusOr> { + llvm::SmallVector lhs_indices = ApplyAffineMap( + lhs_indexing_map.GetAffineMap(), dim_values, symbol_values, b); + llvm::SmallVector rhs_indices = + ApplyAffineMap(rhs_indexing_map.GetAffineMap(), dim_values, + symbol_values.take_front(rhs_symbol_count), b); + + TF_ASSIGN_OR_RETURN(Value lhs_value, GetSingleOperandValue( + operand_provider, instr, + /*operand_index=*/0, lhs_indices)); + TF_ASSIGN_OR_RETURN(Value rhs_value, GetSingleOperandValue( + operand_provider, instr, + /*operand_index=*/1, rhs_indices)); + Value accum = iter_args[0]; + + TF_ASSIGN_OR_RETURN( + accum, EmitMulAdd(lhs_value, rhs_value, accum, result_element_type, + accumulator_type, b)); + return {{accum}}; + }; + + TF_ASSIGN_OR_RETURN(SmallVector results, + EmitLoopNestWithStatus(b, indices, {accum_init_value}, + lhs_indexing_map, body)); + if (result_element_type.isBF16()) { + results[0] = b.create(b.getBF16Type(), results[0]); + } + return results; +} + +absl::StatusOr> EmitDot( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + VLOG(1) << "EmitDot: " << instr->ToString() << " " + << llvm_ir::DumpToString(result_element_type); + + if (!algorithm_util::IsSupportedByElementalIrEmitter( + instr->precision_config().algorithm())) { + return absl::InvalidArgumentError( + absl::StrFormat("Algorithm not supported by the ElementalIrEmitter: %s", + PrecisionConfig::Algorithm_Name( + instr->precision_config().algorithm()))); + } + auto* dot = DynCast(instr); + TF_RET_CHECK(dot != nullptr); + if (dot->sparse_operands()) { + return absl::UnimplementedError( + "Sparse dot is supported by Triton emitter only."); + } + + return EmitDotLoop(instr, result_element_type, indices, operand_provider, b); +} + +absl::StatusOr> EmitConvolution( + const HloInstruction* instr, mlir::Type result_element_type, + ValueRange indices, const OperandProvider& operand_provider, + ImplicitLocOpBuilder& b) { + VLOG(1) << "EmitConvolution: " << instr->ToString() << " " + << llvm_ir::DumpToString(result_element_type); + + return EmitDotLoop(instr, result_element_type, indices, operand_provider, b); +} + +absl::StatusOr> EmitParameter(const HloInstruction* instr, + mlir::func::FuncOp this_fn, + ValueRange indices, + ImplicitLocOpBuilder& b) { + Value value = this_fn.getArgument(instr->parameter_number()); + if (value.getType().isa()) { + value = b.create(value, indices); + } else { + TF_RET_CHECK(indices.empty()); + } + return {{value}}; +} + +template +SmallVector MapHloOp(mlir::Type result_type, + llvm::ArrayRef arg_types, + llvm::ArrayRef args, ImplicitLocOpBuilder& b, + ExtraArgs&&... extra_args) { + return {mhlo::MhloOpToStdScalarOp::mapOpOfType( + b.getLoc(), result_type, arg_types, + typename MhloOp::Adaptor(args, std::forward(extra_args)...), + &b)}; +} + +template +SmallVector MapElementwiseOp(llvm::ArrayRef arg_types, + llvm::ArrayRef args, + ImplicitLocOpBuilder& b) { + // We use the last argument's type because of select. + return MapHloOp(args.back().getType(), arg_types, args, b); +} + +} // namespace + +Value ApplyAffineExpr(mlir::AffineExpr expr, ValueRange dims, + ValueRange symbols, ImplicitLocOpBuilder& b) { + // For unknown (but undoubtedly good) reasons, affine.apply removes unused + // trailing dimensions, but only in the expression. + while (!dims.empty() && !expr.isFunctionOfDim(dims.size() - 1)) { + dims = dims.drop_back(); + } + while (!symbols.empty() && !expr.isFunctionOfSymbol(symbols.size() - 1)) { + symbols = symbols.drop_back(); + } + SmallVector args(dims); + absl::c_copy(symbols, std::back_inserter(args)); + return b.createOrFold(expr, args); +} + +SmallVector ApplyAffineMap(mlir::AffineMap map, ValueRange dims, + ValueRange symbols, ImplicitLocOpBuilder& b) { + CHECK_EQ(map.getNumDims(), dims.size()); + CHECK_EQ(map.getNumSymbols(), symbols.size()); + SmallVector result; + result.reserve(map.getNumResults()); + for (auto expr : map.getResults()) { + result.push_back(ApplyAffineExpr(expr, dims, symbols, b)); + } + return result; +} + +Value CheckConstraint(mlir::Value constrained_value, Interval range, + ImplicitLocOpBuilder& b) { + auto lb = b.create(b.getIndexAttr(range.lower)); + if (range.IsPoint()) { + return b.create(CmpIPredicate::eq, constrained_value, lb); + } + auto ub = b.create(b.getIndexAttr(range.upper)); + return b.create( + b.create(CmpIPredicate::sge, constrained_value, lb), + b.create(CmpIPredicate::sle, constrained_value, ub)); +} + +Value CheckConstraints(const IndexingMap& map, ValueRange dims, + ValueRange symbols, ImplicitLocOpBuilder& b) { + Value ret = b.create(b.getIntegerAttr(b.getI1Type(), 1)); + for (auto&& [expression, range] : map.GetConstraints()) { + ret = b.create( + ret, CheckConstraint(ApplyAffineExpr(expression, dims, symbols, b), + range, b)); + } + for (auto&& [index, bound] : llvm::enumerate(map.GetDimensionBounds())) { + ret = b.create(ret, CheckConstraint(dims[index], bound, b)); + } + return ret; +} + +namespace { + +absl::StatusOr> HloToMlir( + const HloInstruction* instr, mlir::func::FuncOp this_fn, ValueRange indices, + const OperandProvider& operand_provider, + const CallTargetProvider& call_target_provider, + ImplicitLocOpBuilder& builder) { + CHECK(!kUnsupportedOps.contains(instr->opcode())) << instr->ToShortString(); + CHECK(!kUnimplementedOps.contains(instr->opcode())) << instr->ToShortString(); + + auto element_type = instr->shape().element_type(); + mlir::Type element_mlir_type; + mlir::Type result_element_type; + if (!instr->shape().IsTuple()) { + TF_ASSIGN_OR_RETURN(element_mlir_type, + ConvertPrimitiveTypeToMlirType(element_type, builder)); + + // During mapping to the arith dialect, we need to convert from signed + // integer types to signless integer types. Most mappings can infer the + // signless integer type from the already converted operand, but e.g. for + // Convert this is not possible, so we need to have the signless result + // element type as well. But we also still need to pass the signed integer + // element type, as that is needed to select the correct arith ops for + // unsigned element types. + mlir::mhlo::RemoveSignTypeConverter sign_converter; + result_element_type = sign_converter.convertType(element_mlir_type); + } + + auto* mlir_context = builder.getContext(); + // Handle ops that aren't elementwise and aren't just indexing + // transformations. + switch (instr->opcode()) { + case HloOpcode::kConcatenate: + return EmitConcat(instr, result_element_type, indices, operand_provider, + builder); + case HloOpcode::kConstant: + if (ShapeUtil::IsEffectiveScalar(instr->shape())) { + auto val = mlir::cast( + CreateDenseElementsAttrFromLiteral(instr->literal(), builder) + ->getValues()[0]); + return {{builder.create(val).getResult()}}; + } + return absl::UnimplementedError( + absl::StrCat("Unimplemented: ", instr->ToShortString())); + case HloOpcode::kConvolution: + return EmitConvolution(instr, result_element_type, indices, + operand_provider, builder); + case HloOpcode::kDynamicSlice: + return EmitDynamicSlice(instr, indices, operand_provider, builder); + case HloOpcode::kDynamicUpdateSlice: + return EmitDynamicUpdateSlice(instr, result_element_type, indices, + operand_provider, builder); + case HloOpcode::kGather: + return EmitGather(instr, indices, operand_provider, builder); + case HloOpcode::kIota: { + auto index = indices[Cast(instr)->iota_dimension()]; + auto index_type = builder.getIntegerType( + mlir::DataLayout::closest(builder.getInsertionBlock()->getParentOp()) + .getTypeSizeInBits(index.getType())); + index = builder.create(index_type, index); + return {{mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( + builder.getLoc(), element_mlir_type, result_element_type, + {index_type}, {index}, &builder)}}; + } + case HloOpcode::kPad: + return EmitPad(instr, result_element_type, indices, operand_provider, + builder); + case HloOpcode::kDot: + return EmitDot(instr, result_element_type, indices, operand_provider, + builder); + case HloOpcode::kParameter: + return EmitParameter(instr, this_fn, indices, builder); + case HloOpcode::kReduce: + return EmitReduce(instr, indices, operand_provider, call_target_provider, + builder); + case HloOpcode::kReduceWindow: + return EmitReduceWindow(instr, result_element_type, indices, + operand_provider, call_target_provider, builder); + case HloOpcode::kTuple: { + CHECK(!IsUnsupportedTuple(instr)); + const auto& first_shape = instr->shape().tuple_shapes(0); + CHECK_EQ(first_shape.rank(), indices.size()) + << "Indices for tuple must be for the first tuple element"; + SmallVector operands; + for (int i = 0; i < instr->operand_count(); ++i) { + llvm::SmallVector operand_indices; + // The tuple shapes only need to be bitcast compatible, so insert + // bitcasts where necessary. + if (i > 0 && !ShapeUtil::EqualIgnoringElementType( + first_shape, instr->operand(i)->shape())) { + auto operand_map = GetBitcastMap( + first_shape, instr->operand(i)->shape(), mlir_context); + operand_indices = + ApplyAffineMap(operand_map.GetAffineMap(), indices, {}, builder); + } else { + operand_indices = indices; + } + TF_ASSIGN_OR_RETURN( + operands.emplace_back(), + GetSingleOperandValue(operand_provider, instr, i, operand_indices)); + } + return operands; + } + case HloOpcode::kGetTupleElement: { + // We have to generate the entire tuple, but since we don't support + // internal tuple operations (only root tuples), this will always be + // cached and computed together anyway (e.g. it'll be a variadic reduce). + TF_ASSIGN_OR_RETURN(auto tuple, operand_provider(instr, 0, indices)); + return {{tuple[instr->tuple_index()]}}; + } + default: + break; + } + + llvm::SmallVector arg_types; + arg_types.reserve(instr->operands().size()); + for (auto operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto operand_element_type, + ConvertPrimitiveTypeToMlirType( + operand->shape().element_type(), builder)); + arg_types.push_back(operand_element_type); + } + auto input_indices = GetInputIndices( + ComputeOutputToInputIndexing(instr, 0, mlir_context), indices, builder); + SmallVector operands; + for (auto&& [operand_number, operand_indices] : + llvm::enumerate(input_indices)) { + TF_ASSIGN_OR_RETURN(operands.emplace_back(), + GetSingleOperandValue(operand_provider, instr, + operand_number, operand_indices)); + // Nulls can be pretty hard to debug, so guard against them here. The MHLO + // conversion functions like to return nullptr for errors. + TF_RET_CHECK(operands.back() != nullptr) + << "null operand at index " << operand_number << " for " + << instr->ToShortString(); + } + CHECK_NE(operands.size(), 0); + + switch (instr->opcode()) { + case HloOpcode::kAbs: + return {MapHloOp(element_mlir_type, arg_types, operands, + builder)}; + case HloOpcode::kAdd: + if (element_type == PRED) { + return MapElementwiseOp(arg_types, operands, builder); + } else { + return MapElementwiseOp(arg_types, operands, builder); + } + case HloOpcode::kAnd: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kAtan2: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kCbrt: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kCeil: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kClamp: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kClz: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kCompare: { + auto* context = builder.getContext(); + auto direction = mhlo::symbolizeComparisonDirection( + ComparisonDirectionToString(instr->comparison_direction())); + mhlo::CompareOp::Properties properties; + properties.comparison_direction = + mhlo::ComparisonDirectionAttr::get(context, direction.value()); + auto result_types = llvm::to_vector(mlir::TypeRange{builder.getI1Type()}); + return {{mhlo::MhloOpToStdScalarOp::mapOpOfType( + builder.getLoc(), result_types, arg_types, + mhlo::CompareOp::Adaptor(operands, nullptr, properties), &builder)}}; + } + case HloOpcode::kComplex: + return MapHloOp(element_mlir_type, arg_types, operands, + builder); + case HloOpcode::kCos: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kDivide: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kErf: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kExp: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kExpm1: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kFloor: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kIsFinite: + return MapHloOp(builder.getI1Type(), arg_types, + operands, builder); + case HloOpcode::kImag: + return MapHloOp(element_mlir_type, arg_types, operands, + builder); + case HloOpcode::kLog: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kLog1p: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kLogistic: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kMaximum: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kMinimum: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kMultiply: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kNegate: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kNot: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kOr: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kPopulationCount: + return MapHloOp(result_element_type, arg_types, + operands, builder); + case HloOpcode::kPower: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kReal: + return MapHloOp(element_mlir_type, arg_types, operands, + builder); + case HloOpcode::kReducePrecision: { + mlir::NamedAttribute exponent_bits( + builder.getStringAttr("exponent_bits"), + builder.getI32IntegerAttr(instr->exponent_bits())); + mlir::NamedAttribute mantissa_bits( + builder.getStringAttr("mantissa_bits"), + builder.getI32IntegerAttr(instr->mantissa_bits())); + return MapHloOp( + operands.front().getType(), arg_types, operands, builder, + mlir::DictionaryAttr::get(builder.getContext(), + {exponent_bits, mantissa_bits})); + } + case HloOpcode::kRemainder: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kRoundNearestAfz: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kRoundNearestEven: + return MapElementwiseOp(arg_types, operands, + builder); + case HloOpcode::kRsqrt: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kSelect: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kShiftLeft: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kShiftRightArithmetic: + return MapElementwiseOp(arg_types, operands, + builder); + case HloOpcode::kShiftRightLogical: + return MapElementwiseOp(arg_types, operands, + builder); + case HloOpcode::kSign: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kSin: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kSqrt: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kSubtract: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kTan: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kTanh: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kXor: + return MapElementwiseOp(arg_types, operands, builder); + case HloOpcode::kBitcastConvert: + return MapHloOp(result_element_type, arg_types, + operands, builder); + case HloOpcode::kConvert: { + return {{mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( + builder.getLoc(), element_mlir_type, result_element_type, arg_types, + operands, &builder)}}; + } + case HloOpcode::kBitcast: + case HloOpcode::kCopy: + case HloOpcode::kSlice: + case HloOpcode::kBroadcast: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kTranspose: + return operands; + default: + break; + } + + return absl::UnimplementedError(absl::StrCat("Unsupported: ", instr->name())); +} + +} // namespace + +bool IsHloOpSupported(const HloInstruction* instr, + se::CudaComputeCapability compute_capability) { + auto is_unsupported_type = [](const HloInstruction* instr) { + auto e = instr->shape().element_type(); + // TODO(akuegel): Fix remaining issues with complex. + // TODO(jreiffers): Support fp8. + // TODO(jreiffers): Support int4. + return (primitive_util::IsIntegralType(e) && + primitive_util::BitWidth(e) > 1 && + primitive_util::BitWidth(e) < 8) || + primitive_util::IsComplexType(e) || + (primitive_util::IsFloatingPointType(e) && + primitive_util::BitWidth(e) < 16); + }; + if (is_unsupported_type(instr) || + absl::c_any_of(instr->operands(), is_unsupported_type)) { + return false; + } + + return !(kUnsupportedOps.contains(instr->opcode()) || + kUnimplementedOps.contains(instr->opcode()) || + IsUnsupportedConstant(instr) || IsUnsupportedTuple(instr) || + IsUnsupportedGather(instr)); +} + +bool IsHloConversionSupported(const HloComputation* computation, + se::GpuComputeCapability compute_capability) { + if (!std::holds_alternative(compute_capability)) { + // ROCM is not tested. + return false; + } + auto cuda_compute_capability = + std::get(compute_capability); + + return absl::c_all_of( + computation->instructions(), + [=](const HloInstruction* instr) { + return absl::c_all_of(instr->called_computations(), + [&](const HloComputation* called) { + return IsHloConversionSupported( + called, compute_capability); + }) && + IsHloOpSupported(instr, cuda_compute_capability); + }) && + (computation->IsFusionComputation() || + (absl::c_all_of( + computation->parameter_instructions(), [](auto* param) { + return param->shape().IsArray() && param->shape().rank() == 0; + }))); +} + +bool IsHloConversionSupported(const HloFusionAdaptor& fusion, + se::GpuComputeCapability compute_capability) { + if (!std::holds_alternative(compute_capability)) { + // ROCM is not tested. + return false; + } + auto cuda_compute_capability = + std::get(compute_capability); + + if (fusion.GetRoots().size() > 1) { + auto first_shape = fusion.GetRoots()[0].instruction().shape(); + for (int i = 1; i < fusion.GetRoots().size(); ++i) { + if (fusion.GetRoots()[i].instruction().shape().dimensions() != + first_shape.dimensions()) { + return false; + } + } + } + + return !HloFindIf( + fusion.GetRoots(), fusion, [=](HloInstructionAdaptor instr) { + return !absl::c_all_of(instr.instruction().called_computations(), + [&](const HloComputation* called) { + return IsHloConversionSupported( + called, compute_capability); + }) || + !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); + }); +} + +SmallVector ProvideParameter( + const PartitionedComputation::Subgraph& caller, const HloInstruction* instr, + int operand_index, ValueRange indices, + const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, + ImplicitLocOpBuilder& builder) { + auto* operand = instr->operand(operand_index); + + const auto& injected_values = caller.injected_values; + if (auto it = injected_values.find(operand); it != injected_values.end()) { + auto injected_param_values = + this_fn.getArguments().take_back(caller.injected_values.size()); + return {{injected_param_values[it->second]}}; + } + + auto callee = call_target_provider(operand); + SmallVector operands( + this_fn.getArguments().take_front(instr->parent()->num_parameters())); + absl::c_copy(indices, std::back_inserter(operands)); + return builder.create(callee, operands).getResults(); +} + +SmallVector ProvideParameterRange( + const PartitionedComputation::Subgraph& caller, const HloInstruction* instr, + int start, int num, ValueRange indices, + const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, + ImplicitLocOpBuilder& builder) { + SmallVector scalars; + for (int i = 0; i < num; ++i) { + auto scalar = ProvideParameter(caller, instr, i + start, indices, + call_target_provider, this_fn, builder); + CHECK_EQ(scalar.size(), 1); + scalars.push_back(scalar.front()); + } + return scalars; +} + +namespace { + +absl::StatusOr> SubgraphToMlir( + const PartitionedComputation::Subgraph& subgraph, + mlir::func::FuncOp this_fn, const CallTargetProvider& call_target_provider, + ValueRange parameters, ValueRange indices, ImplicitLocOpBuilder& builder) { + SmallVector results; + absl::node_hash_map>, + SmallVector> + cached_instructions; + + std::function>(const HloInstruction* instr, + ValueRange indices)> + emit_instr; + + auto provide_operand = + [&](const HloInstruction* instr, int index, + ValueRange indices) -> absl::StatusOr> { + auto* operand = instr->operand(index); + if (subgraph.instructions.contains(operand)) { + return emit_instr(operand, indices); + } + return ConvertToSignless( + ProvideParameter(subgraph, instr, index, indices, call_target_provider, + this_fn, builder), + builder); + return results; + }; + + emit_instr = [&](const HloInstruction* instr, + ValueRange indices) -> absl::StatusOr> { + // TODO(jreiffers): Check dominance, e.g.: + // + // padding_value = log(param) + // pad = pad(bar, padding_value) + // broadcast = broadcast(padding_value) + // pad + broadcast + // + // If padding_value was first emitted in the context of pad, it'll be + // inside an scf.if. For now this doesn't matter, because the indexing + // is considered to be different, but once the partitioner is smarter, + // it will matter. + // + // Also, this caching should be combined with parameter caching. + std::vector indices_ptrs; + indices_ptrs.reserve(indices.size()); + for (auto index : indices) { + indices_ptrs.push_back(index.getAsOpaquePointer()); + } + auto& entry = cached_instructions[std::make_pair(instr, indices_ptrs)]; + if (!entry.empty()) { + return entry; + } + + TF_ASSIGN_OR_RETURN(auto lowered_instr, + HloToMlir(instr, this_fn, indices, provide_operand, + call_target_provider, builder)); + + entry = ConvertToSignless(lowered_instr, builder); + TF_RET_CHECK(!absl::c_any_of( + entry, [](const auto& entry) { return entry == nullptr; })) + << "null result for " << instr->ToShortString(); + return entry; + }; + + for (const auto* root : subgraph.roots) { + TF_ASSIGN_OR_RETURN(auto root_results, emit_instr(root, indices)); + results.append(root_results.begin(), root_results.end()); + } + return results; +} + +void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b, + const IndexingMap& indexing_map, + SmallVectorImpl* lbs, + SmallVectorImpl* ubs, + SmallVectorImpl* steps) { + Value c1 = b.create(1); + + for (const Interval& bound : indexing_map.GetSymbolBounds()) { + lbs->push_back(b.create(bound.lower)); + ubs->push_back(b.create(bound.upper + 1)); + steps->push_back(c1); + } +} + +} // namespace + +absl::Status SubgraphToMlirFunction( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func, + const CallTargetProvider& call_target_provider) { + TF_RET_CHECK(func != nullptr); + ImplicitLocOpBuilder builder(func.getLoc(), func->getContext()); + builder.setInsertionPointToStart(func.addEntryBlock()); + auto parameters = func.getArguments().take_front( + computation.computation().num_parameters()); + auto indices_and_injected_values = func.getArguments().drop_front( + computation.computation().num_parameters()); + int num_injected_values = subgraph.injected_values.size(); + auto indices = indices_and_injected_values.drop_back(num_injected_values); + TF_ASSIGN_OR_RETURN(auto results, + SubgraphToMlir(subgraph, func, call_target_provider, + parameters, indices, builder)); + + // We have been converting signed types to signless types. To match the + // function signature, we have to convert back to signed types. + auto function = mlir::cast( + results.front().getDefiningOp()->getParentOp()); + const auto& function_results = function.getFunctionType().getResults(); + for (auto [index, function_result] : llvm::enumerate(function_results)) { + results[index] = + builder + .create( + results[index].getLoc(), function_result, results[index]) + .getResult(0); + } + + builder.create(results); + return absl::OkStatus(); +} + +SmallVector EmitLoopNest( + ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits, + const IndexingMap& indexing_map, + mlir::function_ref(ValueRange /*iter_args*/, + ValueRange /*dim_values*/, + ValueRange /*symbol_values*/)> + create_body) { + SmallVector lbs, ubs, steps; + GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, &steps); + + scf::LoopNest loop_nest = scf::buildLoopNest( + b, b.getLoc(), lbs, ubs, steps, iter_args_inits, + [&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values, + ValueRange iter_args) -> scf::ValueVector { + ImplicitLocOpBuilder nested_b(loc, nested_builder); + auto is_in_bounds = mlir_converter::CheckConstraints( + indexing_map, dim_values, symbol_values, nested_b); + auto if_op = nested_b.create( + is_in_bounds, + [&](OpBuilder& then_builder, Location then_loc) -> void { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(then_builder.getInsertionBlock()); + auto results = create_body(iter_args, dim_values, symbol_values); + b.create(results); + }, + [&](OpBuilder& else_b, Location else_loc) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(else_b.getInsertionBlock()); + b.create(iter_args); + }); + + return if_op.getResults(); + }); + return loop_nest.results; +} + +absl::StatusOr> EmitLoopNestWithStatus( + ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits, + const IndexingMap& indexing_map, + mlir::function_ref>( + ValueRange /*iter_args*/, ValueRange /*dim_values*/, + ValueRange /*symbol_values*/)> + create_body) { + absl::Status status = absl::OkStatus(); + + auto result = EmitLoopNest( + b, dim_values, iter_args_inits, indexing_map, + [&](ValueRange iter_args, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto body_result = create_body(iter_args, dim_values, symbol_values); + if (!body_result.ok()) { + status = std::move(body_result.status()); + return SmallVector{}; + } + + return std::move(body_result.value()); + }); + + if (!status.ok()) { + return status; + } + return result; +} + +mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high, + ImplicitLocOpBuilder& b) { + auto zero = b.create(b.getIndexAttr(0)); + if (high <= 0) { + return zero; + } + + if (is_unsigned) { + if (index.getType() != b.getIndexType()) { + index = b.create(b.getIndexType(), index); + } + index = b.create( + index, b.create(b.getIndexAttr(high))); + } else { + if (index.getType() != b.getIndexType()) { + index = b.create(b.getIndexType(), index); + } + index = b.create( + index, b.create(b.getIndexAttr(high))); + index = b.create(index, zero); + } + return index; +} + +SmallVector InlineBlock(OpBuilder& builder, Block& src_block, + ValueRange mapped_args) { + IRMapping mapping; + for (auto [from, to] : llvm::zip(src_block.getArguments(), mapped_args)) { + mapping.map(from, to); + } + for (auto& op : src_block.without_terminator()) { + builder.clone(op, mapping); + } + auto* terminator = src_block.getTerminator(); + SmallVector mapped_results; + + mapped_results.reserve(terminator->getResults().size()); + for (mlir::Value result : src_block.getTerminator()->getOperands()) { + mapped_results.push_back(mapping.lookup(result)); + } + return mapped_results; +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h new file mode 100644 index 0000000000000..85d48087a0e98 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -0,0 +1,131 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +using OperandProvider = + std::function>( + const HloInstruction* instr, int index, mlir::ValueRange indices)>; + +// Emits MLIR to produce the value(s) of a parameter. The parameter must be +// located outside the subgraph. +llvm::SmallVector ProvideParameter( + const PartitionedComputation::Subgraph& caller, const HloInstruction* instr, + int operand_index, mlir::ValueRange indices, + const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, + mlir::ImplicitLocOpBuilder& builder); + +// Emits MLIR to produce the values of a range of parameters. The parameters +// must all be scalars. The parameters are all evaluated at the same indices. +llvm::SmallVector ProvideParameterRange( + const PartitionedComputation::Subgraph& caller, const HloInstruction* instr, + int start, int num, mlir::ValueRange indices, + const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, + mlir::ImplicitLocOpBuilder& builder); + +// Checks whether the given HLO instruction can be converted to MLIR. +bool IsHloOpSupported(const HloInstruction* instr, + se::CudaComputeCapability compute_capability); + +// Checks whether the given HLO computation is supported by the MLIR converter: +// - all instructions in it are supported +// - the signature is supported: if the computation is not a fusion computation, +// all arguments have rank 0. +bool IsHloConversionSupported(const HloComputation* computation, + se::GpuComputeCapability compute_capability); +bool IsHloConversionSupported(const HloFusionAdaptor& fusion, + se::GpuComputeCapability compute_capability); + +// Converts a function (subgraph) to an MLIR function producing one element of +// the result. The function must have the correct interface. +absl::Status SubgraphToMlirFunction( + const PartitionedComputation& computation, + const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func, + const CallTargetProvider& call_target_provider); + +// Creates an affine.apply op for the given expression and values. +mlir::Value ApplyAffineExpr(mlir::AffineExpr expr, mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +// Creates affine.apply ops for each result of the given map. +llvm::SmallVector ApplyAffineMap(mlir::AffineMap map, + mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +// Checks all the constraints and dimension ranges in the map. +mlir::Value CheckConstraints(const IndexingMap& map, mlir::ValueRange dims, + mlir::ValueRange symbols, + mlir::ImplicitLocOpBuilder& b); + +// Emits a loop nest over the entire domain of the indexing_map at a point +// `dim_values`. +llvm::SmallVector EmitLoopNest( + mlir::ImplicitLocOpBuilder& b, mlir::ValueRange dim_values, + mlir::ValueRange iter_args_inits, const IndexingMap& indexing_map, + mlir::function_ref( + mlir::ValueRange iter_args, mlir::ValueRange dim_values, + mlir::ValueRange symbol_values)> + create_body); + +// Same as EmitLoopNest, but the body building function can return an error +// which gets returned from EmitLoopNestWithStatus. +absl::StatusOr> EmitLoopNestWithStatus( + mlir::ImplicitLocOpBuilder& b, mlir::ValueRange dim_values, + mlir::ValueRange iter_args_inits, const IndexingMap& indexing_map, + mlir::function_ref>( + mlir::ValueRange iter_args, mlir::ValueRange dim_values, + mlir::ValueRange symbol_values)> + create_body); + +// Clamps `index` to [0, high] boundaries. +mlir::Value ClampIndex(mlir::Value index, bool is_unsigned, int64_t high, + mlir::ImplicitLocOpBuilder& b); + +// Inlines `src_block` using `mapped_args` to initialize IRMapping from the +// block arguments of `src_block` to `mapped_args`. Return remapped values of +// the terminator. +mlir::SmallVector InlineBlock(mlir::OpBuilder& builder, + mlir::Block& src_block, + mlir::ValueRange mapped_args); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc new file mode 100644 index 0000000000000..cae011f1df4e2 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -0,0 +1,1439 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" + +#include +#include + +#include +#include "absl/status/status.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/DLTI/DLTI.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/status_macros.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +class ElementalHloToMlirTest : public HloTestBase { + public: + ElementalHloToMlirTest() { + context_.loadDialect(); + } + + // Converts the root subgraph of the entry function of the given hlo module to + // MLIR. + absl::Status Run( + const std::string& hlo, const std::string& filecheck_str, + std::function is_subgraph_root = nullptr) { + auto hlo_module = ParseAndReturnVerifiedModule(hlo).value(); + + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context_), + &context_); + auto module = llvm_ir::CreateMlirModuleOp(builder.getLoc()); + (*module)->setAttr( + mlir::DLTIDialect::kDataLayoutAttrName, + mlir::parseAttribute("#dlti.dl_spec<#dlti.dl_entry>", + builder.getContext())); + builder.setInsertionPointToStart(module->getBody()); + auto* entry_computation = hlo_module->entry_computation(); + std::vector roots; + if (is_subgraph_root) { + for (auto* instr : entry_computation->instructions()) { + if (is_subgraph_root(instr)) { + roots.push_back(instr); + } + } + } + PartitionedComputations partitioned_computations(entry_computation, roots); + auto fns = partitioned_computations.DeclareFunctions(module.get()); + auto entry_func = fns[&partitioned_computations + .FindPartitionedComputation(entry_computation) + .GetRootSubgraph()]; + auto& entry_pc = + partitioned_computations.FindPartitionedComputation(entry_computation); + auto call_targets = partitioned_computations.CreateCallTargetProvider(fns); + TF_RETURN_IF_ERROR(SubgraphToMlirFunction( + entry_pc, entry_pc.GetRootSubgraph(), entry_func, call_targets)); + + if (const auto& epilogue = partitioned_computations.epilogue()) { + TF_RETURN_IF_ERROR(SubgraphToMlirFunction(entry_pc, *epilogue, + fns[&*epilogue], call_targets)); + } + + // Canonicalize and CSE for better readability of check tests. + mlir::PassManager pm(&context_); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + TF_RET_CHECK(pm.run(module.get()).succeeded()); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << module.get(); + + TF_ASSIGN_OR_RETURN(auto filecheck_result, + RunFileCheck(out, filecheck_str)); + TF_RET_CHECK(filecheck_result); + return absl::OkStatus(); + } + + mlir::MLIRContext context_; +}; + +TEST_F(ElementalHloToMlirTest, Reduce) { + TF_EXPECT_OK(Run(R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT sum = f32[] add(p0, p1) + } + + ENTRY main { + p0 = f32[10,20,30,40] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[10,30] reduce(p0, p1), dimensions={1,3}, + to_apply=add + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30x40xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:.*]]: index {{.*}}, %[[Y:.*]]: index {{.*}} -> f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK-DAG: %[[C40:.*]] = arith.constant 40 + // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C20]] + // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) + // CHECK: %[[RET_INNER:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C40]] + // CHECK-SAME: iter_args(%[[ACC_INNER:.*]] = %[[ACC]]) + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[X]], %[[I]], %[[Y]], %[[J]]] + // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC_INNER]], + // CHECK-SAME: %[[VAL]]) + // CHECK: scf.yield %[[UPD]] + // CHECK: } + // CHECK: scf.yield %[[RET_INNER]] + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, ReduceUnsigned) { + TF_EXPECT_OK(Run(R"( + add { + p0 = u32[] parameter(0) + p1 = u32[] parameter(1) + ROOT sum = u32[] add(p0, p1) + } + + ENTRY main { + p0 = u32[10,20,30,40] parameter(0) + p1 = u32[] parameter(1) + ROOT r = u32[10,30] reduce(p0, p1), dimensions={1,3}, + to_apply=add + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30x40xui32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:.*]]: index {{.*}}, %[[Y:.*]]: index {{.*}} -> ui32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK-DAG: %[[C40:.*]] = arith.constant 40 + // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C20]] + // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) + // CHECK: %[[RET_INNER:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C40]] + // CHECK-SAME: iter_args(%[[ACC_INNER:.*]] = %[[ACC]]) + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[X]], %[[I]], %[[Y]], %[[J]]] + // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC_INNER]], + // CHECK-SAME: %[[VAL]]) + // CHECK: scf.yield %[[UPD]] + // CHECK: } + // CHECK: scf.yield %[[RET_INNER]] + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, ReduceWindow) { + TF_EXPECT_OK(Run(R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT sum = f32[] add(p0, p1) + } + + ENTRY main { + p0 = f32[42,12,8] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[42,3,8] reduce-window(p0, p1), window={ + size=1x1x7 + stride=1x4x1 + pad=0_0x0_0x3_3 + }, + to_apply=add + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<42x12x8xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Y:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Z:arg[0-9]*]]: index {{[^}]*}}}) -> f32 + // CHECK-DAG: %[[C10:.*]] = arith.constant 10 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 + // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] + // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) + // CHECK: %[[J:.*]] = affine.apply affine_map<()[s0] -> + // CHECK-SAME: (s0 * 4)>()[%[[Y]]] + // CHECK: %[[K:.*]] = affine.apply affine_map<()[s0, s1] -> + // CHECK-SAME: (s0 + s1 - 3)>()[%[[I]], %[[Z]]] + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[X]], %[[J]], %[[K]]] + // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], + // CHECK-SAME: %[[VAL]]) + // CHECK: scf.yield %[[UPD]] + // CHECK: } + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { + TF_EXPECT_OK(Run(R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT sum = f32[] add(p0, p1) + } + + ENTRY main { + p0 = f32[42,12,8] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[19,12,8] reduce-window(p0, p1), window={ + size=8x1x1 + stride=4x1x1 + pad=0_0x0_0x0_0 + lhs_dilate=2x1x1 + }, + to_apply=add + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<42x12x8xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Y:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Z:arg[0-9]*]]: index {{[^}]*}}}) -> f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + + // We have a window size of 8, but expect a loop from 0 to 4 + // due to the base dilation of 2 and the applied symbol rescaling: + // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK: %[[K:.*]] = affine.apply affine_map<()[s0, s1] -> + // If symbol rescaling wasn't working we would have a + // `s0 floordiv ` in the map: + // CHECK-SAME: (s0 + s1 * 2)>()[%[[I]], %[[X]]] + // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Concatenate) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[10,20,30] parameter(0) + p1 = f32[10,15,30] parameter(1) + p2 = f32[10,3,30] parameter(2) + ROOT r = f32[10,38,30] concatenate(p0, p1, p2), dimensions={1} + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<10x15x30xf32>, + // CHECK-SAME: %[[ARG2:.*]]: tensor<10x3x30xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C35:.*]] = arith.constant 35 + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ult, %[[Y]], %[[C20]] + // CHECK: %[[CONCAT:.*]] = scf.if %[[IN_BOUNDS]] + // CHECK: %[[P0_VAL:.*]] = xla_gpu.pure_call @main_p0 + // CHECK-SAME: %[[X]], %[[Y]], %[[Z]] + // CHECK: scf.yield %[[P0_VAL]] + // CHECK: } else { + // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ult, %[[Y]], %[[C35]] + // CHECK: %[[CONCAT2:.*]] = scf.if %[[IN_BOUNDS]] + // CHECK: %[[OFFSET:.*]] = arith.subi %[[Y]], %[[C20]] + // CHECK: %[[P1_VAL:.*]] = xla_gpu.pure_call @main_p1 + // CHECK-SAME: %[[X]], %[[OFFSET]], %[[Z]] + // CHECK: scf.yield %[[P1_VAL]] + // CHECK: } else { + // CHECK: %[[OFFSET:.*]] = arith.subi %[[Y]], %[[C35]] + // CHECK: %[[P2_VAL:.*]] = xla_gpu.pure_call @main_p2 + // CHECK-SAME: %[[X]], %[[OFFSET]], %[[Z]] + // CHECK: scf.yield %[[P2_VAL]] + // CHECK: } + // CHECK: scf.yield %[[CONCAT2]] + // CHECK: } + // CHECK: return %[[CONCAT]] + )")); +} + +TEST_F(ElementalHloToMlirTest, ConcatenateUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = u32[10,20,30] parameter(0) + p1 = u32[10,15,30] parameter(1) + ROOT r = u32[10,35,30] concatenate(p0, p1), dimensions={1} + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xui32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<10x15x30xui32> + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 + // CHECK: %[[IN_BOUNDS:.*]] = arith.cmpi ult, %[[Y]], %[[C20]] + // CHECK: %[[CONCAT:.*]] = scf.if %[[IN_BOUNDS]] + // CHECK: %[[P0_VAL:.*]] = xla_gpu.pure_call @main_p0 + // CHECK-SAME: %[[X]], %[[Y]], %[[Z]] + // CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[P0_VAL]] + // CHECK: scf.yield %[[CAST0]] + // CHECK: } else { + // CHECK: %[[OFFSET:.*]] = arith.subi %[[Y]], %[[C20]] + // CHECK: %[[P1_VAL:.*]] = xla_gpu.pure_call @main_p1 + // CHECK-SAME: %[[X]], %[[OFFSET]], %[[Z]] + // CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[P1_VAL]] + // CHECK: scf.yield %[[CAST1]] + // CHECK: } + // CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[CONCAT]] + // CHECK: return %[[CAST2]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Gather) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + operand = f32[33,34] parameter(0) + indices = s32[1806,1] parameter(1) + ROOT r = f32[1806,7,8] gather(operand, indices), offset_dims={1,2}, + collapsed_slice_dims={}, start_index_map={0}, + index_vector_dim=1, slice_sizes={7,8} + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<33x34xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<1806x1xi32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C26:.*]] = arith.constant 26 + // CHECK: %[[IDX_I32:.*]] = tensor.extract %[[ARG1]][%[[X]], %[[C0]]] + // CHECK: %[[IDX:.*]] = arith.index_cast %[[IDX_I32]] : i32 to index + // CHECK: %[[CLAMP_HIGH:.*]] = arith.minsi %[[IDX]], %[[C26]] + // CHECK: %[[CLAMPED:.*]] = arith.maxsi %[[CLAMP_HIGH]], %[[C0]] + // CHECK: %[[X_IN:.*]] = arith.addi %[[CLAMPED]], %[[Y]] + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Z]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Pad) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 + })", + R"( + // CHECK: @main_pad( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 + // CHECK: %[[CONSTRAINT_VAL:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - ((s0 - 1) floordiv 2) * 2 - 1)> + // CHECK-SAME: ()[%[[X]]] + // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] + // CHECK: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] + // CHECK: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] + // CHECK: %[[X_BOUNDS:.*]] = arith.andi %[[X_L]], %[[X_H]] + // CHECK: %[[X_AND_CONSTRAINT:.*]] = arith.andi %[[CONSTRAINT]], %[[X_BOUNDS]] + // CHECK: %[[Y_L:.*]] = arith.cmpi sge, %[[Y]], %[[C4]] + // CHECK: %[[Y_H:.*]] = arith.cmpi sle, %[[Y]], %[[C7]] + // CHECK: %[[Y_BOUNDS:.*]] = arith.andi %[[Y_L]], %[[Y_H]] + // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] + // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] + // CHECK: %[[X_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> ((s0 - 1) floordiv 2)>()[%[[X]]] + // CHECK: %[[Y_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - 4)>()[%[[Y]]] + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Y_IN]]] + // CHECK: scf.yield %[[VAL]] + // CHECK: } else { + // CHECK: %[[PAD_VAL:.*]] = tensor.extract %[[ARG1]][] + // CHECK: scf.yield %[[PAD_VAL]] + // CHECK: } + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, PadUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = u32[4, 4] parameter(0) + p1 = u32[] parameter(1) + ROOT pad = u32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 + })", + R"( + // CHECK: @main_pad( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xui32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 + // CHECK: %[[CONSTRAINT_VAL:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - ((s0 - 1) floordiv 2) * 2 - 1)> + // CHECK-SAME: ()[%[[X]]] + // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] + // CHECK: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] + // CHECK: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] + // CHECK: %[[X_BOUNDS:.*]] = arith.andi %[[X_L]], %[[X_H]] + // CHECK: %[[X_AND_CONSTRAINT:.*]] = arith.andi %[[CONSTRAINT]], %[[X_BOUNDS]] + // CHECK: %[[Y_L:.*]] = arith.cmpi sge, %[[Y]], %[[C4]] + // CHECK: %[[Y_H:.*]] = arith.cmpi sle, %[[Y]], %[[C7]] + // CHECK: %[[Y_BOUNDS:.*]] = arith.andi %[[Y_L]], %[[Y_H]] + // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] + // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] + // CHECK: %[[X_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> ((s0 - 1) floordiv 2)>()[%[[X]]] + // CHECK: %[[Y_IN:.*]] = affine.apply + // CHECK-SAME: <()[s0] -> (s0 - 4)>()[%[[Y]]] + // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Y_IN]]] + // CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[VAL]] + // CHECK: scf.yield %[[CAST0]] + // CHECK: } else { + // CHECK: %[[PAD_VAL:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[PAD_VAL]] + // CHECK: scf.yield %[[CAST1]] + // CHECK: } + // CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RET]] + // CHECK: return %[[CAST2]] + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithF32Type) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[3, 4] parameter(0) + p1 = f32[4, 5] parameter(1) + ROOT dot = f32[3, 5] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<3x4xf32>, %[[B:.*]]: tensor<4x5xf32>, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 4 : index]}) + // CHECK-SAME: -> f32 + // CHECK-SAME: { + // CHECK-DAG: %[[ACCUM_INIT:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM:.*]] = %[[ACCUM_INIT]]) -> (f32) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[J]], %[[C4]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[I_J_IN_RANGE:.*]] = arith.andi %[[I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[I_J_IN_RANGE]] -> (f32) { + // CHECK-DAG: %[[A_I_K:.*]] = tensor.extract %[[A]][%[[I]], %[[K]]] : tensor<3x4xf32> + // CHECK-DAG: %[[B_K_J:.*]] = tensor.extract %[[B]][%[[K]], %[[J]]] : tensor<4x5xf32> + // CHECK-DAG: %[[MULF0:.*]] = arith.mulf %[[A_I_K]], %[[B_K_J]] : f32 + // CHECK-DAG: %[[ADDF0:.*]] = arith.addf %[[ACCUM]], %[[MULF0]] : f32 + // CHECK-DAG: scf.yield %[[ADDF0]] : f32 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM]] : f32 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : f32 + // CHECK: } + // CHECK: return %[[FOR0]] : f32 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithBF16Type) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = bf16[3, 4] parameter(0) + p1 = bf16[4, 5] parameter(1) + ROOT dot = bf16[3, 5] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<3x4xbf16>, %[[B:.*]]: tensor<4x5xbf16>, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 4 : index]}) + // CHECK-SAME: -> bf16 + // CHECK-SAME: { + // CHECK-DAG: %[[ACCUM_INIT:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM:.*]] = %[[ACCUM_INIT]]) -> (f32) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[J]], %[[C4]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[I_J_IN_RANGE:.*]] = arith.andi %[[I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[I_J_IN_RANGE]] -> (f32) { + // CHECK-DAG: %[[A_I_K:.*]] = tensor.extract %[[A]][%[[I]], %[[K]]] : tensor<3x4xbf16> + // CHECK-DAG: %[[B_K_J:.*]] = tensor.extract %[[B]][%[[K]], %[[J]]] : tensor<4x5xbf16> + // CHECK-DAG: %[[A_I_K_F32:.*]] = arith.extf %[[A_I_K]] : bf16 to f32 + // CHECK-DAG: %[[B_K_J_F32:.*]] = arith.extf %[[B_K_J]] : bf16 to f32 + // CHECK-DAG: %[[MULF0:.*]] = arith.mulf %[[A_I_K_F32]], %[[B_K_J_F32]] : f32 + // CHECK-DAG: %[[ADDF0:.*]] = arith.addf %[[ACCUM]], %[[MULF0]] : f32 + // CHECK-DAG: scf.yield %[[ADDF0]] : f32 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM]] : f32 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : f32 + // CHECK: } + // CHECK: %[[FOR0_BF16:.*]] = arith.truncf %[[FOR0]] : f32 to bf16 + // CHECK: return %[[FOR0_BF16]] : bf16 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithS32Type) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = s32[3, 4] parameter(0) + p1 = s32[4, 5] parameter(1) + ROOT dot = s32[3, 5] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<3x4xi32>, %[[B:.*]]: tensor<4x5xi32>, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 4 : index]}) + // CHECK-SAME: -> i32 + // CHECK-SAME: { + // CHECK-DAG: %[[ACCUM_INIT:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM:.*]] = %[[ACCUM_INIT]]) -> (i32) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[J]], %[[C4]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[I_J_IN_RANGE:.*]] = arith.andi %[[I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[I_J_IN_RANGE]] -> (i32) { + // CHECK-DAG: %[[A_I_K:.*]] = tensor.extract %[[A]][%[[I]], %[[K]]] : tensor<3x4xi32> + // CHECK-DAG: %[[B_K_J:.*]] = tensor.extract %[[B]][%[[K]], %[[J]]] : tensor<4x5xi32> + // CHECK-DAG: %[[MUL0:.*]] = arith.muli %[[A_I_K]], %[[B_K_J]] : i32 + // CHECK-DAG: %[[ADD0:.*]] = arith.addi %[[ACCUM]], %[[MUL0]] : i32 + // CHECK-DAG: scf.yield %[[ADD0]] : i32 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM]] : i32 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : i32 + // CHECK: } + // CHECK: return %[[FOR0]] : i32 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithU32Type) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = u32[3, 4] parameter(0) + p1 = u32[4, 5] parameter(1) + ROOT dot = u32[3, 5] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<3x4xui32>, %[[B:.*]]: tensor<4x5xui32>, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 4 : index]}) + // CHECK-SAME: -> ui32 + // CHECK-SAME: { + // CHECK-DAG: %[[ACCUM_INIT:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM:.*]] = %[[ACCUM_INIT]]) -> (i32) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[J]], %[[C4]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[I_J_IN_RANGE:.*]] = arith.andi %[[I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[I_J_IN_RANGE]] -> (i32) { + // CHECK-DAG: %[[A_I_K:.*]] = tensor.extract %[[A]][%[[I]], %[[K]]] : tensor<3x4xui32> + // CHECK-DAG: %[[A_I_K_I32:.*]] = builtin.unrealized_conversion_cast %[[A_I_K]] : ui32 to i32 + // CHECK-DAG: %[[B_K_J:.*]] = tensor.extract %[[B]][%[[K]], %[[J]]] : tensor<4x5xui32> + // CHECK-DAG: %[[B_K_J_I32:.*]] = builtin.unrealized_conversion_cast %[[B_K_J]] : ui32 to i32 + // CHECK-DAG: %[[MUL0:.*]] = arith.muli %[[A_I_K_I32]], %[[B_K_J_I32]] : i32 + // CHECK-DAG: %[[ADD0:.*]] = arith.addi %[[ACCUM]], %[[MUL0]] : i32 + // CHECK-DAG: scf.yield %[[ADD0]] : i32 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM]] : i32 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : i32 + // CHECK: } + // CHECK: %[[FOR0_UI32:.*]] = builtin.unrealized_conversion_cast %[[FOR0]] : i32 to ui32 + // CHECK: return %[[FOR0_UI32]] : ui32 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithPredType) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = pred[3, 4] parameter(0) + p1 = pred[4, 5] parameter(1) + ROOT dot = pred[3, 5] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<3x4xi1>, %[[B:.*]]: tensor<4x5xi1>, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 4 : index]}) + // CHECK-SAME: -> i1 + // CHECK-SAME: { + // CHECK-DAG: %[[ACCUM_INIT:.*]] = arith.constant false + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM:.*]] = %[[ACCUM_INIT]]) -> (i1) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[J]], %[[C4]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[I_J_IN_RANGE:.*]] = arith.andi %[[I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[I_J_IN_RANGE]] -> (i1) { + // CHECK-DAG: %[[A_I_K:.*]] = tensor.extract %[[A]][%[[I]], %[[K]]] : tensor<3x4xi1> + // CHECK-DAG: %[[B_K_J:.*]] = tensor.extract %[[B]][%[[K]], %[[J]]] : tensor<4x5xi1> + // CHECK-DAG: %[[AND0:.*]] = arith.andi %[[A_I_K]], %[[B_K_J]] : i1 + // CHECK-DAG: %[[OR0:.*]] = arith.ori %[[ACCUM]], %[[AND0]] : i1 + // CHECK-DAG: scf.yield %[[OR0]] : i1 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM]] : i1 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : i1 + // CHECK: } + // CHECK: return %[[FOR0]] : i1 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, DotWithBatchAnd2ContractingDims) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[7, 3, 4, 5] parameter(0) + p1 = f32[5, 6, 4, 7] parameter(1) + ROOT dot = f32[7, 3, 6] dot(p0, p1), + lhs_contracting_dims={2, 3}, rhs_contracting_dims={2, 0}, + lhs_batch_dims={0}, rhs_batch_dims={3} + })", + R"( + // CHECK: @main_dot( + // CHECK-SAME: %[[A:.*]]: tensor<7x3x4x5xf32>, %[[B:.*]]: tensor<5x6x4x7xf32>, + // CHECK-SAME: %[[N:.*]]: index {xla.range = [0 : index, 6 : index]}, + // CHECK-SAME: %[[I:.*]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[J:.*]]: index {xla.range = [0 : index, 5 : index]}) + // CHECK-SAME: -> f32 + // CHECK-SAME: { + // CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index + // CHECK: %[[FOR0:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM0:.*]] = %[[C0F]]) -> (f32) { + // CHECK: %[[FOR1:.*]] = scf.for %[[L:.*]] = %[[C0]] to %[[C5]] step %[[C1]] + // CHECK-SAME: iter_args(%[[ACCUM1:.*]] = %[[ACCUM0]]) -> (f32) { + // CHECK-DAG: %[[CMPI0:.*]] = arith.cmpi sge, %[[N]], %[[C0]] : index + // CHECK-DAG: %[[CMPI1:.*]] = arith.cmpi sle, %[[N]], %[[C6]] : index + // CHECK-DAG: %[[N_IN_RANGE:.*]] = arith.andi %[[CMPI0]], %[[CMPI1]] : i1 + // CHECK-DAG: %[[CMPI2:.*]] = arith.cmpi sge, %[[I]], %[[C0]] : index + // CHECK-DAG: %[[CMPI3:.*]] = arith.cmpi sle, %[[I]], %[[C2]] : index + // CHECK-DAG: %[[I_IN_RANGE:.*]] = arith.andi %[[CMPI2]], %[[CMPI3]] : i1 + // CHECK-DAG: %[[N_I_IN_RANGE:.*]] = arith.andi %[[N_IN_RANGE]], %[[I_IN_RANGE]] : i1 + // CHECK-DAG: %[[CMPI4:.*]] = arith.cmpi sge, %[[J]], %[[C0]] : index + // CHECK-DAG: %[[CMPI5:.*]] = arith.cmpi sle, %[[J]], %[[C5]] : index + // CHECK-DAG: %[[J_IN_RANGE:.*]] = arith.andi %[[CMPI4]], %[[CMPI5]] : i1 + // CHECK-DAG: %[[N_I_J_IN_RANGE:.*]] = arith.andi %[[N_I_IN_RANGE]], %[[J_IN_RANGE]] : i1 + // CHECK: %[[IF0:.*]] = scf.if %[[N_I_J_IN_RANGE]] -> (f32) { + // CHECK-DAG: %[[A_N_I_K_L:.*]] = tensor.extract %[[A]][%[[N]], %[[I]], %[[K]], %[[L]]] : tensor<7x3x4x5xf32> + // CHECK-DAG: %[[B_L_J_K_N:.*]] = tensor.extract %[[B]][%[[L]], %[[J]], %[[K]], %[[N]]] : tensor<5x6x4x7xf32> + // CHECK-DAG: %[[MULF0:.*]] = arith.mulf %[[A_N_I_K_L]], %[[B_L_J_K_N]] : f32 + // CHECK-DAG: %[[ADDF0:.*]] = arith.addf %[[ACCUM1]], %[[MULF0]] : f32 + // CHECK-DAG: scf.yield %[[ADDF0]] : f32 + // CHECK: } else { + // CHECK: scf.yield %[[ACCUM1]] : f32 + // CHECK: } + // CHECK: scf.yield %[[IF0]] : f32 + // CHECK: } + // CHECK: scf.yield %[[FOR1]] : f32 + // CHECK: } + // CHECK: return %[[FOR0]] : f32 + // CHECK: } + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[2,6,8,16] convolution(p0, p1), window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 5 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 7 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[2,3,4,16] convolution(p0, p1), window={size=3x5 stride=2x2 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 2 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 3 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 2)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 2)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[2,8,12,16] convolution(p0, p1), window={size=3x5 pad=1_1x2_2}, dim_labels=b01f_i01o->b01f + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 7 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 11 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index + // CHECK-DAG: %[[C13:.+]] = arith.constant 13 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK-DAG: %[[TESTX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index + // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index + // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 + // CHECK-DAG: %[[TESTY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index + // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index + // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1 - 1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1 - 2)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[2,13,19,16] convolution(p0, p1), window={size=3x5 pad=0_0x0_0 lhs_dilate=2x2}, dim_labels=b01f_i01o->b01f + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 12 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 18 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK-DAG: %[[TESTX:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 + s1) mod 2)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index + // CHECK-DAG: %[[TESTY:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 + s1) mod 2)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 + s1) floordiv 2)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 + s1) floordiv 2)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[2,4,4,16] convolution(p0, p1), window={size=3x5 pad=0_0x0_0 rhs_dilate=2x2}, dim_labels=b01f_i01o->b01f + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:[^ ]+]]: index {xla.range = [0 : index, 3 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 3 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * 2 + s1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * 2 + s1)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[2,3,5,16] parameter(1) + ROOT conv = f32[2,6,8,16] convolution(p0, p1), window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f, feature_group_count=2 + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<2x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 1 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 5 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 7 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { + // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[II:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + (s1 floordiv 8) * 2)>()[%[[I]], %[[O]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX]], %[[YY]], %[[II]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[2,8,12,4] parameter(0) + p1 = f32[4,3,5,16] parameter(1) + ROOT conv = f32[1,6,8,16] convolution(p0, p1), window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f, batch_group_count=2 + })", + R"( + // CHECK: @main_conv( + // CHECK-SAME: %[[LHS:.+]]: tensor<2x8x12x4xf32>, %[[RHS:.*]]: tensor<4x3x5x16xf32>, + // CHECK-SAME: %[[B:.+]]: index {xla.range = [0 : index, 0 : index]}, + // CHECK-SAME: %[[W:.+]]: index {xla.range = [0 : index, 5 : index]}, + // CHECK-SAME: %[[H:.+]]: index {xla.range = [0 : index, 7 : index]}, + // CHECK-SAME: %[[O:.+]]: index {xla.range = [0 : index, 15 : index]}) + // CHECK-SAME: -> f32 + // CHECK-DAG: %[[INIT:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index + // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { + // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { + // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[A2:.+]] = %[[A1]]) -> (f32) { + // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { + // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { + // CHECK-DAG: %[[BB:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[G]], %[[B]]] + // CHECK-DAG: %[[XX:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[X]], %[[W]]] + // CHECK-DAG: %[[YY:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[Y]], %[[H]]] + // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[BB]], %[[XX]], %[[YY]], %[[I]]] : tensor<2x8x12x4xf32> + // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> + // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 + // CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[ACC]], %[[MUL]] : f32 + // CHECK-NEXT: scf.yield %[[ADD]] : f32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.yield %[[ACC]] : f32 + // CHECK-NEXT: } + // CHECK-NEXT: scf.yield %[[R4]] : f32 + // CHECK: scf.yield %[[R3]] : f32 + // CHECK: scf.yield %[[R2]] : f32 + // CHECK: scf.yield %[[R1]] : f32 + // CHECK: return %[[R0]] : f32 + )")); +} + +TEST_F(ElementalHloToMlirTest, Transpose) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4,5,6] parameter(0) + ROOT transpose = f32[6,5,4] transpose(p0), dimensions={2,1,0} + })", + R"( + // CHECK: @main_transpose( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x5x6xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[Z]], %[[Y]], %[[X]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Broadcast) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4,5] parameter(0) + ROOT broadcast = f32[6,4,5] broadcast(p0), dimensions={1,2} + })", + R"( + // CHECK: @main_broadcast( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]] + // CHECK-SAME: [%[[Y]], %[[Z]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Add) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + ROOT add = f32[4] add(p0, p1) + })", + R"( + // CHECK: @main_add( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>, + // CHECK-SAME: %[[X:.*]]: index {{.*}} + // CHECK: %[[A:.*]] = tensor.extract %[[ARG0]][%[[X]]] + // CHECK: %[[B:.*]] = tensor.extract %[[ARG1]][%[[X]]] + // CHECK: %[[RET:.*]] = arith.addf %[[A]], %[[B]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, Complex) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + ROOT add = c64[4] complex(p0, p1) + })", + R"( + // CHECK: @main_add( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>, + // CHECK-SAME: %[[X:.*]]: index {{.*}} + // CHECK: %[[A:.*]] = tensor.extract %[[ARG0]][%[[X]]] + // CHECK: %[[B:.*]] = tensor.extract %[[ARG1]][%[[X]]] + // CHECK: %[[RET:.*]] = complex.create %[[A]], %[[B]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, ComplexAbs) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = c64[4] parameter(0) + ROOT abs = f32[4] abs(p0) + })", + R"( + // CHECK: @main_abs( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4xcomplex> + // CHECK-SAME: %[[X:.*]]: index {{.*}} + // CHECK: %[[A:.*]] = tensor.extract %[[ARG0]][%[[X]]] + // CHECK: %[[RET:.*]] = complex.abs %[[A]] : complex + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, UnsignedDiv) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = u32[4] parameter(0) + p1 = u32[4] parameter(1) + ROOT div = u32[4] divide(p0, p1) + })", + R"( + // CHECK: @main_div( + // CHECK-SAME: %[[ARG0:.*]]: tensor<4xui32>, %[[ARG1:.*]]: tensor<4xui32>, + // CHECK-SAME: %[[X:.*]]: index {{.*}} + // CHECK: %[[DIV:.*]] = arith.divui %{{.*}}, %{{.*}} : i32 + )")); +} + +TEST_F(ElementalHloToMlirTest, ConvertToUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[4] parameter(0) + ROOT convert = u32[4] convert(p0) + })", + R"( + // CHECK: @main_convert( + // CHECK: arith.fptoui %{{.*}} : f32 to i32 + )")); +} + +TEST_F(ElementalHloToMlirTest, PopulationCountUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main{ + p0 = u32[10,1,4]{2,1,0} parameter(0) + ROOT popcnt = u32[10,1,4]{2,1,0} popcnt(p0) + })", + R"( + // CHECK: @main_popcnt( + // CHECK: builtin.unrealized_conversion_cast %{{.*}} : ui32 to i32 + // CHECK: math.ctpop %{{.*}} : i32 + // CHECK: builtin.unrealized_conversion_cast %{{.*}} : i32 to ui32 + )")); +} + +TEST_F(ElementalHloToMlirTest, Epilogue) { + TF_EXPECT_OK(Run( + R"( + ENTRY main { + %p0 = f32[2,16,17] parameter(0) + %log = f32[2,16,17] log(%p0) + %transpose = f32[2,17,16] transpose(%log), dimensions={0,2,1} + %p1 = f32[] parameter(1) + %bc = f32[2,17,16] broadcast(%p1), dimensions={} + ROOT %add = f32[2,17,16] add(%transpose, %bc) + })", + R"( + // CHECK: @main__epilogue__( + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x17xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:.*]]: index {xla.range = [0 : index, 1 : + // CHECK-SAME: %[[Y:.*]]: index {xla.range = [0 : index, 16 : + // CHECK-SAME: %[[Z:.*]]: index {xla.range = [0 : index, 15 : + // CHECK-SAME: %[[TRANSPOSE:.*]]: f32) -> f32 + // CHECK: %[[B:.*]] = tensor.extract %[[ARG1]][] + // CHECK: %[[RET:.*]] = arith.addf %[[TRANSPOSE]], %[[B]] + // CHECK: return %[[RET]])", + [](const HloInstruction* instr) { + // Make the transpose a new root. + return instr->opcode() == HloOpcode::kTranspose; + })); +} + +TEST_F(ElementalHloToMlirTest, ScalarConstant) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[1,1] parameter(0) + c1 = f32[1,1] constant({{1.0}}) + ROOT add = f32[1,1] add(p0, c1) + })", + R"( + // CHECK: @main_add( + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32> + // CHECK-SAME: %[[X:.*]]: index {{.*}}, %[[Y:.*]]: index {{.*}} + // CHECK: %[[C_1:.*]] = arith.constant 1 + // CHECK: %[[A:.*]] = tensor.extract %[[ARG0]][%[[X]], %[[Y]]] + // CHECK: %[[RET:.*]] = arith.addf %[[A]], %[[C_1]] + // CHECK: return %[[RET]] + })")); +} + +TEST_F(ElementalHloToMlirTest, DynamicSlice) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + in = f32[20,30] parameter(0) + i0 = s32[] parameter(1) + i1 = s32[] parameter(2) + ROOT slice = f32[4,5] dynamic-slice(in, i0, i1), dynamic_slice_sizes={4,5} + })", + R"( + // CHECK: @main_slice( + // CHECK-SAME: %[[ARG0:.*]]: tensor<20x30xf32>, + // CHECK-SAME: %[[I0_T:.*]]: tensor, %[[I1_T:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 + // CHECK-DAG: %[[C25:.*]] = arith.constant 25 + // CHECK: %[[I0:.*]] = tensor.extract %[[I0_T]] + // CHECK: %[[I0_1:.*]] = arith.index_cast %[[I0]] + // CHECK: %[[I0_2:.*]] = arith.minsi %[[I0_1]], %[[C16]] + // CHECK: %[[I0_3:.*]] = arith.maxsi %[[I0_2]], %[[C0]] + // CHECK: %[[X_IN:.*]] = arith.addi %[[X]], %[[I0_3]] + // CHECK: %[[I1:.*]] = tensor.extract %[[I1_T]] + // CHECK: %[[I1_1:.*]] = arith.index_cast %[[I1]] + // CHECK: %[[I1_2:.*]] = arith.minsi %[[I1_1]], %[[C25]] + // CHECK: %[[I1_3:.*]] = arith.maxsi %[[I1_2]], %[[C0]] + // CHECK: %[[Y_IN:.*]] = arith.addi %[[Y]], %[[I1_3]] + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Y_IN]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, DynamicSliceUnsignedIndices) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + in = f32[20,30] parameter(0) + i0 = u32[] parameter(1) + i1 = u32[] parameter(2) + ROOT slice = f32[4,5] dynamic-slice(in, i0, i1), dynamic_slice_sizes={4,5} + })", + R"( + // CHECK: @main_slice( + // CHECK-SAME: %[[ARG0:.*]]: tensor<20x30xf32>, + // CHECK-SAME: %[[I0_T:.*]]: tensor, %[[I1_T:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index { + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 + // CHECK-DAG: %[[C25:.*]] = arith.constant 25 + // CHECK: %[[I0:.*]] = tensor.extract %[[I0_T]] + // CHECK: %[[I0_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[I0]] : ui32 to i32 + // CHECK: %[[I0_1:.*]] = arith.index_castui %[[I0_SIGNLESS]] + // CHECK: %[[I0_2:.*]] = arith.minui %[[I0_1]], %[[C16]] + // CHECK: %[[X_IN:.*]] = arith.addi %[[X]], %[[I0_2]] + // CHECK: %[[I1:.*]] = tensor.extract %[[I1_T]] + // CHECK: %[[I1_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[I1]] : ui32 to i32 + // CHECK: %[[I1_1:.*]] = arith.index_castui %[[I1_SIGNLESS]] + // CHECK: %[[I1_2:.*]] = arith.minui %[[I1_1]], %[[C25]] + // CHECK: %[[Y_IN:.*]] = arith.addi %[[Y]], %[[I1_2]] + // CHECK: %[[RET:.*]] = tensor.extract %[[ARG0]][%[[X_IN]], %[[Y_IN]]] + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, DynamicUpdateSlice) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + in = f32[20,30] parameter(0) + updates = f32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = f32[20,30] dynamic-update-slice(in, updates, i0, i1) + })", + R"( + // CHECK: @main_updated( + // CHECK-SAME: %[[ARG0:.*]]: tensor<20x30xf32>, %[[ARG1:.*]]: tensor<5x6xf32> + // CHECK-SAME: %[[I0_T:.*]]: tensor, %[[I1_T:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 + // CHECK-DAG: %[[C15:.*]] = arith.constant 15 + // CHECK-DAG: %[[C24:.*]] = arith.constant 24 + // CHECK: %[[I0:.*]] = tensor.extract %[[I0_T]] + // CHECK: %[[I0_1:.*]] = arith.index_cast %[[I0]] + // CHECK: %[[I0_2:.*]] = arith.minsi %[[I0_1]], %[[C15]] + // CHECK: %[[START_X:.*]] = arith.maxsi %[[I0_2]], %[[C0]] + // CHECK: %[[END_X:.*]] = arith.addi %[[START_X]], %[[C5]] + // CHECK: %[[LOW_X:.*]] = arith.cmpi sge, %[[X]], %[[START_X]] + // CHECK: %[[HIGH_X:.*]] = arith.cmpi slt, %[[X]], %[[END_X]] + // CHECK: %[[BOUNDS_X:.*]] = arith.andi %[[LOW_X]], %[[HIGH_X]] + // CHECK: %[[UPDATES_X:.*]] = arith.subi %[[X]], %[[START_X]] + // CHECK: arith.andi + // CHECK: %[[BOUNDS:.*]] = arith.andi + // CHECK: scf.if %[[BOUNDS]] + // CHECK: tensor.extract %[[ARG1]][%[[UPDATES_X]] + // CHECK: } else { + // CHECK: tensor.extract %[[ARG0]][%[[X]] + )")); +} + +TEST_F(ElementalHloToMlirTest, DynamicUpdateSliceUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + in = u32[20,30] parameter(0) + updates = u32[5,6] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + ROOT updated = u32[20,30] dynamic-update-slice(in, updates, i0, i1) + })", + R"( + // CHECK: @main_updated( + // CHECK-SAME: %[[ARG0:.*]]: tensor<20x30xui32>, %[[ARG1:.*]]: tensor<5x6xui32> + // CHECK-SAME: %[[I0_T:.*]]: tensor, %[[I1_T:.*]]: tensor, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 + // CHECK-DAG: %[[C15:.*]] = arith.constant 15 + // CHECK-DAG: %[[C24:.*]] = arith.constant 24 + // CHECK: %[[I0:.*]] = tensor.extract %[[I0_T]] + // CHECK: %[[I0_1:.*]] = arith.index_cast %[[I0]] + // CHECK: %[[I0_2:.*]] = arith.minsi %[[I0_1]], %[[C15]] + // CHECK: %[[START_X:.*]] = arith.maxsi %[[I0_2]], %[[C0]] + // CHECK: %[[END_X:.*]] = arith.addi %[[START_X]], %[[C5]] + // CHECK: %[[LOW_X:.*]] = arith.cmpi sge, %[[X]], %[[START_X]] + // CHECK: %[[HIGH_X:.*]] = arith.cmpi slt, %[[X]], %[[END_X]] + // CHECK: %[[BOUNDS_X:.*]] = arith.andi %[[LOW_X]], %[[HIGH_X]] + // CHECK: %[[UPDATES_X:.*]] = arith.subi %[[X]], %[[START_X]] + // CHECK: arith.andi + // CHECK: %[[BOUNDS:.*]] = arith.andi + // CHECK: scf.if %[[BOUNDS]] + // CHECK: %[[VAL0:.*]] = tensor.extract %[[ARG1]][%[[UPDATES_X]] + // CHECK: builtin.unrealized_conversion_cast %[[VAL0]] + // CHECK: } else { + // CHECK: %[[VAL1:.*]] = tensor.extract %[[ARG0]][%[[X]] + // CHECK: builtin.unrealized_conversion_cast %[[VAL1]] + )")); +} + +TEST_F(ElementalHloToMlirTest, IotaUnsigned) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + ROOT iota = u32[10,20] iota(), iota_dimension=0 + })", + R"( + // CHECK: @main_iota( + // CHECK-SAME: %[[I0:.*]]: index {{.*}}, %[[I1:.*]]: index {{.*}} { + // CHECK: %[[VAL:.*]] = arith.index_castui %[[I0]] : index to i32 + // CHECK: builtin.unrealized_conversion_cast %[[VAL]] : i32 to ui32 + )")); +} + +TEST_F(ElementalHloToMlirTest, IotaComplex) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + ROOT iota = c64[6,4,5] iota(), iota_dimension=1 + })", + R"( + // CHECK: @main_iota( + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}, + // CHECK-SAME: %[[Z:.*]]: index {{{.*}}} + // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[I:.*]] = arith.index_castui %[[Y]] : index to i32 + // CHECK: %[[F:.*]] = arith.sitofp %[[I]] : i32 to f32 + // CHECK: %[[RET:.*]] = complex.create %[[F]], %[[ZERO]] : complex + // CHECK: return %[[RET]] + )")); +} + +TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + %p0 = f32[10,10] parameter(0) + %p1 = f32[100] parameter(1) + ROOT tuple = (f32[10,10], f32[100]) tuple(%p0, %p1) + })", + R"( + // CHECK: @main_tuple( + // CHECK-SAME: %[[P0:.*]]: tensor<10x10xf32>, + // CHECK-SAME: %[[P1:.*]]: tensor<100xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] + // CHECK: %[[IDX:.*]] = affine.apply + // CHECK-SAME: affine_map<()[s0, s1] -> (s0 * 10 + s1)>() + // CHECK-SAME: [%[[X]], %[[Y]]] + // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] + // CHECK: return %[[A]], %[[B]] + )")); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/expand_float_ops.cc b/xla/service/gpu/fusions/mlir/expand_float_ops.cc new file mode 100644 index 0000000000000..6ad33a9152e9d --- /dev/null +++ b/xla/service/gpu/fusions/mlir/expand_float_ops.cc @@ -0,0 +1,160 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Math/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_EXPANDFLOATOPSPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +template +struct RewriteToCmpSelect : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + RewriteToCmpSelect(mlir::MLIRContext* context, bool include_f32) + : mlir::OpRewritePattern(context), include_f32(include_f32) {} + + mlir::LogicalResult matchAndRewrite( + OpTy op, mlir::PatternRewriter& rewriter) const override { + if (op.getType().isF32() && !include_f32) { + return rewriter.notifyMatchFailure(op, "not rewriting f32 min/max"); + } + + auto lhs_is_nan = rewriter.create( + op.getLoc(), mlir::arith::CmpFPredicate::UNE, op.getLhs(), op.getLhs()); + auto rhs_is_not_nan = rewriter.create( + op.getLoc(), mlir::arith::CmpFPredicate::OEQ, op.getRhs(), op.getRhs()); + + auto return_lhs = rewriter + .create(op.getLoc(), pred, + op.getLhs(), op.getRhs()) + .getResult(); + + // logic: isNaN(lhs) || (!isNan(rhs) && return_lhs) ? lhs : rhs + return_lhs = rewriter.create( + op.getLoc(), lhs_is_nan, + rewriter.create(op.getLoc(), rhs_is_not_nan, + return_lhs)); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), return_lhs, op.getLhs(), op.getRhs()); + return mlir::success(); + } + + bool include_f32; +}; + +struct RewriteErf32Pattern : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::math::ErfOp op, mlir::PatternRewriter& rewriter) const override { + namespace ma = mlir::arith; + if (!op.getType().isF32()) { + return rewriter.notifyMatchFailure(op, "not an f32 erf"); + } + + static const std::array kAlpha{ + 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f, + 0.18520832239976145f, 1.128379143519084f}; + + static const std::array kBeta{-1.1791602954361697e-7, + 0.000023547966471313185f, + 0.0010179625278914885f, + 0.014070470171167667f, + 0.11098505178285362f, + 0.49746925110067538f, + 1.0f}; + + // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of + // which x should be +/-1. + constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f; + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto c = [&](float v) -> mlir::Value { + return b.create(llvm::APFloat(v), + rewriter.getF32Type()); + }; + + auto poly = [&](auto x, auto coefficients) -> mlir::Value { + auto r = c(coefficients[0]); + for (int i = 1; i < coefficients.size(); ++i) { + r = b.create(r, x, c(coefficients[i])); + } + return r; + }; + + mlir::Value x = op.getOperand(); + x = b.create(x, c(-kErfInvOneMinusHalfULP)); + x = b.create(x, c(kErfInvOneMinusHalfULP)); + mlir::Value x2 = b.create(x, x); + + rewriter.replaceOpWithNewOp( + op, b.create(x, poly(x2, kAlpha)), poly(x2, kBeta)); + + return mlir::success(); + } +}; + +class ExpandFloatOpsPass + : public impl::ExpandFloatOpsPassBase { + public: + using ExpandFloatOpsPassBase::ExpandFloatOpsPassBase; + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add>( + &getContext(), /*include_f32=*/pre_ampere_); + patterns.add>( + &getContext(), /*include_f32=*/pre_ampere_); + mlir::populatePolynomialApproximateTanhPattern(patterns); + patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateExpandFloatOpsPass(bool pre_ampere) { + return createExpandFloatOpsPass(ExpandFloatOpsPassOptions{pre_ampere}); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/BUILD b/xla/service/gpu/fusions/mlir/ir/BUILD new file mode 100644 index 0000000000000..e82df33e96521 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/BUILD @@ -0,0 +1,70 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +td_library( + name = "xla_gpu_ops_td_files", + srcs = glob(["*.td"]), + includes = ["."], + deps = [ + "@llvm-project//mlir:CallInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_gpu_ops_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-op-decls"], + "xla_gpu_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "xla_gpu_ops.cc.inc", + ), + ( + ["-gen-dialect-decls"], + "xla_gpu_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_gpu_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_gpu_ops.td", + deps = [":xla_gpu_ops_td_files"], +) + +cc_library( + name = "xla_gpu", + srcs = ["xla_gpu_ops.cc"], + hdrs = ["xla_gpu_ops.h"], + deps = [ + ":xla_gpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:CallOpInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc new file mode 100644 index 0000000000000..45fe249097fdc --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -0,0 +1,176 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" + +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/DialectImplementation.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/PatternMatch.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc" + +namespace xla { +namespace gpu { +namespace { +struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + // Returns true if the given operation 'callable', that implements the + // 'CallableOpInterface', can be inlined into the position given call + // operation 'call', that is registered to the current dialect and implements + // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the + // given 'callable' is set to be cloned during the inlining process, or false + // if the region is set to be moved in-place (i.e. no duplicates would be + // created). + bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable, + bool wouldBeCloned) const final { + if (!wouldBeCloned) { + // If no duplicate would be created, 'call' is likely the only caller of + // 'callable'. + return true; + } + // Otherwise, inline only if the called function is small. We could + // theoretically also inline if there is no other caller in the function + // that contains the callee that has a call path to the callable, but that + // is more expensive to check. + auto func_op = mlir::dyn_cast(callable); + if (!func_op) { + return false; + } + auto region = func_op.getCallableRegion(); + if (!region) { + return false; + } + const int kMaxOperationsToInline = 8; + return region->front().getOperations().size() <= kMaxOperationsToInline; + } + // Returns true if the given operation 'op', that is registered to this + // dialect, can be inlined into the given region, false otherwise. + // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned + // during the inlining process, or false if the operation is set to be moved + // in-place(i.e. no duplicates would be created). 'valueMapping' contains any + // remapped values from within the 'src' region. This can be used to examine + // what values may potentially replace the operands to 'op'. + bool isLegalToInline(mlir::Operation *op, mlir::Region *dest, + bool wouldBeCloned, + mlir::IRMapping &valueMapping) const final { + // We allow any op from the xla_gpu dialect to be inlined. + return true; + } +}; +} // namespace + +void XlaGpuDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#undef GET_OP_LIST + >(); + addInterfaces(); +} + +mlir::LogicalResult PureCallOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTable) { + auto callee = getCalleeAttr(); + auto function = + symbolTable.lookupNearestSymbolFrom(*this, callee); + if (!function) { + return emitError("'f' attribute refers to an undefined function: ") + << callee; + } + + int func_arg_count = function.getFunctionType().getNumInputs(); + int arg_count = getOperands().size(); + + if (arg_count != func_arg_count) { + return emitError() << "argument count mismatch: 'operands' has " + << arg_count << " arguments, but '" << callee + << "' expects " << func_arg_count; + } + + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AllocateSharedOp +//===----------------------------------------------------------------------===// + +void AllocateSharedOp::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "shmem"); +} + +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +void AtomicRMWOp::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "atomic_rmw"); +} + +using mlir::OpBuilder; +using mlir::OperationState; +using mlir::RankedTensorType; +using mlir::Region; +using mlir::Type; +using mlir::Value; +using mlir::ValueRange; + +void AtomicRMWOp::build(OpBuilder &builder, OperationState &result, + Value tensor, ValueRange ivs) { + OpBuilder::InsertionGuard g(builder); + result.addOperands(tensor); + result.addOperands(ivs); + result.addTypes(tensor.getType()); + + auto tensor_type = llvm::cast(tensor.getType()); + Region *body = result.addRegion(); + builder.createBlock(body); + body->addArgument(tensor_type.getElementType(), tensor.getLoc()); +} + +//===----------------------------------------------------------------------===// +// PureCallOp +//===----------------------------------------------------------------------===// + +void PureCallOp::getAsmResultNames( + llvm::function_ref setNameFn) { + for (auto result : getResults()) { + setNameFn(result, "pure_call"); + } +} + +//===----------------------------------------------------------------------===// +// SyncThreadsOp +//===----------------------------------------------------------------------===// + +void SyncThreadsOp::getAsmResultNames( + llvm::function_ref setNameFn) { + for (auto result : getResults()) { + setNameFn(result, "synced_tensor"); + } +} + +} // namespace gpu +} // namespace xla + +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h new file mode 100644 index 0000000000000..c4b11a03e66ca --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Attributes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Dialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/OpDefinition.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project // IWYU pragma : keep +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project // IWYU pragma : keep + +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" +#undef GET_OP_CLASSES + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td new file mode 100644 index 0000000000000..1cbc4626eac85 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td @@ -0,0 +1,236 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_OPS +#define XLA_SERVICE_GPU_FUSIONS_MLIR_OPS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def XlaGpuDialect : Dialect { + let name = "xla_gpu"; + + let description = [{ + This dialect contains ops required for lowering HLO to LLVM. + }]; + + let cppNamespace = "::xla::gpu"; +} + +class XLAGPU_Op traits = []> : + Op { +} + +def XLAGPU_AllocateSharedOp : XLAGPU_Op<"allocate_shared", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Allocates a shared memory tile."; + + let description = [{ + Allocates a shared memory tensor. The tensor is shared among all threads in + a block. + + ```mlir + %shared = xla_gpu.allocate_shared : tensor<32x32xf32> + ``` + }]; + + let results = (outs AnyStaticShapeTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def XLAGPU_SyncThreadsOp : XLAGPU_Op<"sync_threads", [ + TypesMatchWith<"result type matches type of dest", + "operands", "results", "$_self">, + DeclareOpInterfaceMethods + ]> { + let summary = "Synchronizes threads."; + + let description = [{ + Synchronizes threads, taking any number of distributed tensors and returning + the synchronized state. + }]; + + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); + + let assemblyFormat = "operands attr-dict `:` type($operands)"; +} + +def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw", + [Pure, + TypesMatchWith<"result type matches type of dest", + "input", "result", "$_self">, + DeclareOpInterfaceMethods + ]> { + let summary = "Atomically updates an element of a tensor."; + + let description = [{ + Reads an element from a tensor, computes the updated value for it, and + writes back the result. + }]; + + let arguments = (ins AnyRankedTensor:$input, Variadic:$indices); + let results = (outs AnyRankedTensor:$result); + // The region takes the current value in the tensor as an argument and yields + // the updated value. + let regions = (region SizedRegion<1>:$computation); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder<(ins "mlir::Value":$memref, "mlir::ValueRange":$ivs)>]; + + let extraClassDeclaration = [{ + mlir::Block* getBody() { return &getComputation().front(); } + mlir::OpBuilder getBodyBuilder() { + return mlir::OpBuilder(getBody(), std::prev(getBody()->end())); + } + // The value stored in tensor[ivs]. + mlir::Value getCurrentValue() { + return getRegion().getArgument(0); + } + }]; + + let assemblyFormat = [{ + $input `[` $indices `]` `:` type($input) $computation attr-dict + }]; +} + +def XLAGPU_YieldOp : XLAGPU_Op<"yield", + [HasParent<"::xla::gpu::AtomicRMWOp">, Terminator]> { + let summary = "Terminator for atomic_rmw ops."; + let arguments = (ins AnyType:$result); + + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +def XLAGPU_PureCallOp : XLAGPU_Op<"pure_call", + [Pure, CallOpInterface, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let summary = "Function call without side effects."; + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + let builders = [ + OpBuilder<(ins "mlir::func::FuncOp":$callee, CArg<"mlir::ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", mlir::SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>]; + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; + + let extraClassDeclaration = [{ + operand_range getArgOperands() { + return getOperands(); + } + + mlir::MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + }]; +} + +def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce", + [Pure, + TypesMatchWith<"result type matches type of operands", + "operands", "results", "$_self">]> { + let summary = "Performs a full warp shuffle and reduces the values"; + let description = [{ + This op performs a full warp shuffle and reduces the results using the given + function. The function is invoked with the operands from the low lanes, + followed by the operands from the high lanes. For example: + + ``` + shuffle_reduce @argmax(%value, %idx) : (f32, index) + ``` + + Will perform shuffles with distance 16, 8, 4, 2 and 1, and will invoke + @argmax five times. The first invocations will be + + ``` + @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i]) + ``` + }]; + let builders = [ + OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{ + $_state.addOperands(operands); + $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer)); + $_state.addAttribute("max_distance", + mlir::IntegerAttr::get( + mlir::IntegerType::get(reducer.getContext(), 64), + max_distance)); + $_state.addTypes(reducer.getFunctionType().getResults()); + }]>]; + let arguments = (ins FlatSymbolRefAttr:$reducer, + Variadic:$operands, + I64Attr:$max_distance); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands) + }]; +} + +def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert", + [Pure, + TypesMatchWith<"result type matches type of operands", + "dest", "result", "$_self">, + TypesMatchWith<"value type matches element type of dest", + "dest", "value", + "::llvm::cast($_self).getElementType()">]> { + let summary = "Inserts a value into a tensor if a condition holds"; + let arguments = (ins I1:$condition, AnyType:$value, + AnyStaticShapeTensor:$dest, Variadic:$indices); + let results = (outs AnyStaticShapeTensor:$result); + + let assemblyFormat = [{ + $value `into` $dest `[` $indices `]` `if` $condition attr-dict `:` type($dest) + }]; +} + +def XLAGPU_PredicatedExtractOp : XLAGPU_Op<"predicated_extract", + [Pure, + TypesMatchWith<"fallback type matches element type of src", + "src", "fallback", + "::llvm::cast($_self).getElementType()">, + TypesMatchWith<"result type matches element type of src", + "src", "result", + "::llvm::cast($_self).getElementType()">]> { + let summary = "Inserts a value into a tensor if a condition holds"; + let arguments = (ins I1:$condition, AnyType:$fallback, + AnyStaticShapeTensor:$src, Variadic:$indices); + let results = (outs AnyType:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` `if` $condition `else` $fallback attr-dict `:` type($src) + }]; +} + +#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS diff --git a/xla/service/gpu/fusions/mlir/lower_func.cc b/xla/service/gpu/fusions/mlir/lower_func.cc new file mode 100644 index 0000000000000..42ffb6dc1cbc0 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/lower_func.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERFUNCPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +using mlir::failure; +using mlir::success; + +struct RewriteCall : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + PureCallOp op, mlir::PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), op.getOperands(), op->getAttrs()); + return success(); + } +}; + +class LowerFuncPass : public impl::LowerFuncPassBase { + public: + void runOnOperation() override; +}; + +void LowerFuncPass::runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr<::mlir::Pass> CreateLowerFuncPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/lower_tensors.cc b/xla/service/gpu/fusions/mlir/lower_tensors.cc new file mode 100644 index 0000000000000..11467a37effa5 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/lower_tensors.cc @@ -0,0 +1,485 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/layout_util.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERTENSORSPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +using mlir::failure; +using mlir::success; +using mlir::Value; +using mlir::ValueRange; + +struct RewriteFunctionSignatures : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override { + auto is_tensor = [](mlir::Type ty) { + return ty.isa(); + }; + if (!llvm::any_of(op.getFunctionType().getInputs(), is_tensor)) { + return rewriter.notifyMatchFailure(op, + "the function has no input tensors"); + } + + bool some_tensor_result = + llvm::any_of(op.getFunctionType().getResults(), is_tensor); + bool all_tensor_results = + llvm::all_of(op.getFunctionType().getResults(), is_tensor); + if (some_tensor_result && !all_tensor_results) { + op->emitOpError("function has a mix of tensor and non-tensor results"); + return failure(); + } + + mlir::TypeRange new_results = op.getFunctionType().getResults(); + if (some_tensor_result) { + new_results = {}; + auto terminator = op.getFunctionBody().front().getTerminator(); + rewriter.setInsertionPoint(terminator); + rewriter.replaceOpWithNewOp(terminator); + } + + llvm::SmallVector new_operands( + op.getFunctionType().getInputs()); + for (auto&& [index, operand] : llvm::enumerate(new_operands)) { + if (is_tensor(operand)) { + rewriter.setInsertionPointToStart(&op.getBody().front()); + auto cast = rewriter.create( + op.getLoc(), operand, op.getArgument(index)); + op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast); + operand = mlir::LLVM::LLVMPointerType::get(op.getContext()); + } + } + + op.setFunctionType(rewriter.getFunctionType(new_operands, new_results)); + auto& entry = op->getRegion(0).front(); + for (auto [arg, arg_type] : llvm::zip(entry.getArguments(), new_operands)) { + arg.setType(arg_type); + } + + return success(); + } +}; + +mlir::LLVM::GEPOp CreateGep(mlir::Operation* op, + mlir::TypedValue tensor, + ValueRange indices, + mlir::PatternRewriter& rewriter) { + auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); + if (auto encoding = tensor.getType().getEncoding()) { + *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( + encoding.cast().getValues())); + } + auto linearize_map = mlir::getAffineConstantExpr(0, rewriter.getContext()); + for (auto [dim, stride] : + llvm::enumerate(*ShapeUtil::ByteStrides(byte_shape))) { + linearize_map = linearize_map + + mlir::getAffineDimExpr(dim, rewriter.getContext()) * stride; + } + + rewriter.setInsertionPoint(op); + Value index = rewriter.create( + tensor.getLoc(), linearize_map, indices); + auto index_ty = + ShapeUtil::ElementsIn(byte_shape) < std::numeric_limits::max() + ? rewriter.getI32Type() + : rewriter.getI64Type(); + index = rewriter.create(tensor.getLoc(), index_ty, + index); + + auto tensor_ptr = rewriter + .create( + tensor.getLoc(), ptr, tensor) + .getResult(0); + mlir::LLVMTypeConverter converter(rewriter.getContext()); + auto llvm_element_type = + converter.convertType(tensor.getType().getElementType()); + auto gep = rewriter.create( + tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, index); + gep.setInbounds(true); + return gep; +} + +struct RewriteTensorExtract : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::tensor::ExtractOp op, + mlir::PatternRewriter& rewriter) const override { + auto gep = CreateGep(op, op.getTensor(), op.getIndices(), rewriter); + auto load = + rewriter + .create(gep.getLoc(), gep.getElemType(), gep) + .getResult(); + rewriter.replaceOpWithNewOp( + op, op.getType(), load); + return success(); + } +}; + +struct RewriteTensorInsert : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::tensor::InsertOp op, + mlir::PatternRewriter& rewriter) const override { + Value dest = op.getDest(); + while (dest.getDefiningOp()) { + int result_number = dest.cast().getResultNumber(); + if (auto insert = dest.getDefiningOp()) { + dest = insert.getDest(); + } else if (auto scf_if = dest.getDefiningOp()) { + // Pick one of the branches, they're required to yield the same buffers. + dest = scf_if.getThenRegion().front().getTerminator()->getOperand( + result_number); + } else if (auto scf_for = dest.getDefiningOp()) { + dest = scf_for.getInitArgs()[result_number]; + } else if (dest.getDefiningOp() || + dest.getDefiningOp()) { + break; + } else { + return op.emitOpError("unsupported dest type"); + } + } + + auto gep = + CreateGep(op, dest.cast>(), + op.getIndices(), rewriter); + auto scalar_value = op.getScalar(); + mlir::LLVMTypeConverter converter(getContext()); + auto llvm_type = converter.convertType(scalar_value.getType()); + scalar_value = rewriter + .create( + gep.getLoc(), llvm_type, scalar_value) + .getResult(0); + rewriter.create(gep.getLoc(), scalar_value, gep); + + op.replaceAllUsesWith(op.getDest()); + op.erase(); + return success(); + } +}; + +struct RewriteCall : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::func::CallOp op, mlir::PatternRewriter& rewriter) const override { + if (!llvm::any_of(op->getOperandTypes(), [](mlir::Type ty) { + return ty.isa(); + })) { + return rewriter.notifyMatchFailure(op, "the call has no input tensors"); + } + + for (const auto&& [index, arg] : llvm::enumerate(op.getOperands())) { + if (arg.getType().isa()) { + op.setOperand( + index, + rewriter + .create( + op.getLoc(), + mlir::LLVM::LLVMPointerType::get(op.getContext()), arg) + .getResult(0)); + } + } + return success(); + } +}; + +struct RewriteAllocateShared : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + AllocateSharedOp op, mlir::PatternRewriter& rewriter) const override { + auto module = op->getParentOfType(); + auto shaped_ty = op.getResult().getType().cast(); + constexpr int kGPUSharedMemoryAddrSpace = 3; + auto array_ty = mlir::LLVM::LLVMArrayType::get(shaped_ty.getElementType(), + shaped_ty.getNumElements()); + + std::string name; + int index = 0; + do { + name = absl::StrCat("shared_", index); + ++index; + } while (module.lookupSymbol(name)); + + rewriter.setInsertionPointToStart(module.getBody()); + auto global = rewriter.create( + op.getLoc(), array_ty, /*isConstant=*/false, + /*linkage=*/mlir::LLVM::Linkage::Private, name, + /*value=*/mlir::Attribute{}, + /*alignment=*/0, kGPUSharedMemoryAddrSpace); + + rewriter.setInsertionPoint(op); + auto addr = rewriter.create(op.getLoc(), global); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), + rewriter + .create( + op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()), + addr) + .getResult()); + return success(); + } +}; + +struct RewriteSyncThreads : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + SyncThreadsOp op, mlir::PatternRewriter& rewriter) const override { + rewriter.create(op.getLoc()); + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operand to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation. In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32_t is used for the +// atomicCAS operation. In this case, atomicCAS reads and writes 32 bit values +// from the memory, which is larger than the memory size required by the +// original atomic binary operation. We mask off the last two bits of the +// output_address and use the result as an address to read the 32 bit values +// from the memory. This can avoid out of bound memory accesses if tensor +// buffers are 4 byte aligned and have a size of 4N, an assumption that the +// runtime can guarantee. +struct RewriteAtomicRMW : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override { + namespace ml = mlir::LLVM; + mlir::Location loc = op.getLoc(); + auto input = op.getInput(); + + // Use 32-bit atomic type for small input types. + mlir::Type result_ty = op.getResult().getType().getElementType(); + unsigned int result_size = result_ty.getIntOrFloatBitWidth(); + bool small_type = result_size < 32; + mlir::Type atomic_ty = + mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size); + + // Calculate load address for the input. + Value addr = CreateGep(op, input, op.getIndices(), rewriter); + Value shift, mask; + if (small_type) { + // Update input pointer by discarding the last two bits - i.e. align to + // 32-bit boundary for small input types (will not result in OOB, as the + // input alignment is at least 32 bits). + mlir::Type addr_int_ty = rewriter.getI64Type(); + Value addr_int = rewriter.create(loc, addr_int_ty, addr); + Value addr_offset = rewriter.create( + loc, addr_int, rewriter.create(loc, addr_int_ty, 3)); + Value index = rewriter.create( + loc, addr_offset, + rewriter.create(loc, addr_int_ty, -1)); + addr = + rewriter.create(loc, addr.getType(), rewriter.getI8Type(), + addr, index, /*inbounds=*/true); + + // Calculate the bit shift (assume little-endianness). + Value offset = rewriter.create(loc, atomic_ty, addr_offset); + shift = rewriter.create( + loc, offset, + rewriter.create(loc, offset.getType(), 8)); + + // Compose the update mask. + Value bits_long = rewriter.create(loc, atomic_ty, -1); + Value bits_short = rewriter.create( + loc, atomic_ty, + rewriter.create( + loc, rewriter.getIntegerType(result_size), -1)); + mask = rewriter.create( + loc, bits_long, rewriter.create(loc, bits_short, shift)); + } + + // Load initial atomic value and create the loop. + Value initial = rewriter.create(loc, atomic_ty, addr); + rewriter.create( + loc, mlir::TypeRange{atomic_ty}, ValueRange{initial}, + [&](mlir::OpBuilder& b, mlir::Location loc, ValueRange values) { + Value old_value = values[0]; + + // Convert atomic value to input value. + Value input_value; + if (small_type) { + Value short_value = b.create( + loc, b.getIntegerType(result_size), + b.create(loc, old_value, shift)); + input_value = b.create(loc, result_ty, short_value); + } else { + input_value = b.create(loc, result_ty, old_value); + } + + // Perform computation on the loaded input value. + rewriter.mergeBlocks(&op.getComputation().front(), b.getBlock(), + {input_value}); + auto yield_op = b.getBlock()->getTerminator(); + Value result = yield_op->getOperand(0); + rewriter.eraseOp(yield_op); + + // Convert resulting value to atomic value. + Value new_value; + if (small_type) { + Value cast_value = rewriter.create( + loc, atomic_ty, + rewriter.create( + loc, rewriter.getIntegerType(result_size), result)); + new_value = rewriter.create( + loc, rewriter.create(loc, old_value, mask), + rewriter.create(loc, cast_value, shift)); + } else { + new_value = b.create(loc, atomic_ty, result); + } + + // Try saving the result atomically, retry if failed. + Value cmpxchg = b.create( + loc, addr, old_value, new_value, + /*success_ordering=*/ml::AtomicOrdering::seq_cst, + /*failure_ordering=*/ml::AtomicOrdering::seq_cst); + Value next = b.create(loc, cmpxchg, 0); + Value ok = b.create(loc, cmpxchg, 1); + Value low_bit = + b.create(loc, b.getOneAttr(b.getI1Type())); + Value not_ok = b.create(loc, ok, low_bit); + b.create(loc, not_ok, ValueRange{next}); + }, + [&](mlir::OpBuilder& b, mlir::Location loc, ValueRange values) { + b.create(loc, values); + }); + rewriter.replaceOp(op, input); + return success(); + } +}; + +class LowerTensorsPass : public impl::LowerTensorsPassBase { + public: + void runOnOperation() override { + mlir::RewritePatternSet tensor_patterns(&getContext()); + tensor_patterns + .add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(tensor_patterns)))) { + signalPassFailure(); + } + + mlir::RewritePatternSet function_patterns(&getContext()); + function_patterns.add( + &getContext()); + mlir::scf::ForOp::getCanonicalizationPatterns(function_patterns, + &getContext()); + mlir::scf::IfOp::getCanonicalizationPatterns(function_patterns, + &getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(function_patterns)))) { + signalPassFailure(); + } + + getOperation()->walk([this](mlir::LLVM::LoadOp load) { + Value addr = load.getAddr(); + while (auto gep = addr.getDefiningOp()) { + addr = gep.getBase(); + } + if (addr.getDefiningOp()) { + // Shared memory - no need to annotate anything. + return; + } + if (auto base = mlir::dyn_cast(addr)) { + if (auto func = mlir::dyn_cast( + base.getOwner()->getParentOp())) { + if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) { + load.setInvariant(true); + } + return; + } + } + load.emitOpError("load op address is not (a GEP of) a function argument"); + signalPassFailure(); + }); + } +}; + +} // namespace + +std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/xla/service/gpu/fusions/mlir/lower_to_llvm.cc new file mode 100644 index 0000000000000..625c9f8794b48 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/lower_to_llvm.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" // from @llvm-project +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" // from @llvm-project +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERTOLLVMPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class LowerToLLVMPass : public impl::LowerToLLVMPassBase { + public: + using LowerToLLVMPassBase::LowerToLLVMPassBase; + + void runOnOperation() override { + // Populate type conversions. + mlir::LLVMTypeConverter type_converter(getOperation().getContext()); + mlir::LLVMConversionTarget target(*getOperation().getContext()); + + // Populate patterns. + mlir::RewritePatternSet patterns(&getContext()); + mlir::populateAffineToStdConversionPatterns(patterns); + mlir::populateSCFToControlFlowConversionPatterns(patterns); + mlir::arith::populateArithExpandOpsPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(type_converter, + patterns); + mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns); + mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter, + patterns); + mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns); + mlir::populateMathToLLVMConversionPatterns(type_converter, patterns); + + // Setup target. + mlir::configureGpuToNVVMConversionLegality(target); + target.addIllegalDialect(); + target.addLegalOp(); + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateLowerToLLVMPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc new file mode 100644 index 0000000000000..aab493ec184e0 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc @@ -0,0 +1,172 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +using mlir::success; + +struct RewritePredicatedInsert : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getCondition(), + [&](mlir::OpBuilder& b, mlir::Location loc) { + b.create( + loc, b.create( + loc, op.getValue(), op.getDest(), op.getIndices()) + .getResult()); + }, + [&](mlir::OpBuilder& b, mlir::Location loc) { + b.create(loc, op.getDest()); + }); + return success(); + } +}; + +struct RewritePredicatedExtract : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getCondition(), + [&](mlir::OpBuilder& b, mlir::Location loc) { + b.create( + loc, b.create(loc, op.getSrc(), + op.getIndices()) + .getResult()); + }, + [&](mlir::OpBuilder& b, mlir::Location loc) { + b.create(loc, op.getFallback()); + }); + return success(); + } +}; + +struct RewriteShuffleReduce : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override { + int max_distance = + op->getAttr("max_distance").cast().getInt(); + // TODO(jreiffers): Do this in a verifier. + if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) { + return op->emitOpError("max_distance must be a power of 2 < WarpSize()"); + } + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + mlir::ValueRange values = op.getOperands(); + for (int distance = max_distance; distance > 0; distance /= 2) { + namespace ml = mlir::LLVM; + auto shuffle = [&](mlir::Value v) { + return b + .create(v, distance, WarpSize(), + mlir::gpu::ShuffleMode::DOWN) + .getShuffleResult(); + }; + + llvm::SmallVector args = values; + for (auto value : values) { + // Shuffle within the warps. + auto ty = value.getType(); + int bit_width = ty.getIntOrFloatBitWidth(); + + if (bit_width == 32) { + value = shuffle(value); + } else { + int n_shuffles = CeilOfRatio(bit_width, 32); + auto int_ty = b.getIntegerType(bit_width); + auto padded_int_ty = b.getIntegerType(n_shuffles * 32); + value = b.create(int_ty, value); + value = b.create(padded_int_ty, value); + auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); + value = b.create(vector_type, value); + mlir::Value result_vec = b.create(vector_type); + for (int i = 0; i < n_shuffles; ++i) { + auto idx = b.create(i, 32); + result_vec = b.create( + result_vec, shuffle(b.create(value, idx)), + idx); + } + value = b.create(padded_int_ty, result_vec); + value = b.create(int_ty, value); + value = b.create(ty, value); + } + args.push_back(value); + } + values = b.create(op.getReducerAttr().getAttr(), + op.getResultTypes(), args) + .getResults(); + } + rewriter.replaceOp(op, values); + return success(); + } +}; + +class LowerXlaGpuToScfPass + : public impl::LowerXlaGpuToScfPassBase { + public: + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc b/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc new file mode 100644 index 0000000000000..ce9b73648ef55 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc @@ -0,0 +1,117 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class MergePointersToSameSlicePass + : public impl::MergePointersToSameSlicePassBase< + MergePointersToSameSlicePass> { + public: + void runOnOperation() override; +}; + +struct PackedArgs { + llvm::BitVector args_to_erase; + // replacement_args[i] == i iff !args_to_erase[i]. + llvm::SmallVector replacement_args; + + PackedArgs() = default; + explicit PackedArgs(mlir::func::FuncOp func) { + absl::flat_hash_map> slice_to_operand; + args_to_erase.resize(func.getNumArguments()); + replacement_args.reserve(func.getNumArguments()); + for (int i = 0; i < func.getNumArguments(); ++i) { + replacement_args.push_back(i); + } + + for (auto [idx, operand] : llvm::enumerate(func.getArguments())) { + auto slice_index = func.getArgAttr(idx, "xla.slice_index"); + if (!slice_index) { + continue; + } + + auto& target_index = slice_to_operand[static_cast( + slice_index.cast().getInt())]; + if (target_index) { + replacement_args[idx] = *target_index; + args_to_erase[idx] = true; + } else { + target_index = idx; + } + } + } + + void Pack(mlir::func::FuncOp op) { + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + if (replacement_args[idx] != idx) { + arg.replaceAllUsesWith(op.getArgument(replacement_args[idx])); + } + } + op.eraseArguments(args_to_erase); + for (int i = 0; i < op.getNumArguments(); ++i) { + if (op.getArgAttr(i, "xla.slice_index")) { + op.removeArgAttr(i, "xla.slice_index"); + op.setArgAttr(i, mlir::LLVM::LLVMDialect::getNoAliasAttrName(), + mlir::UnitAttr::get(op->getContext())); + } + } + } + + void Pack(mlir::func::CallOp op) { op->eraseOperands(args_to_erase); } +}; + +void MergePointersToSameSlicePass::runOnOperation() { + mlir::func::FuncOp entry; + + absl::flat_hash_map args_to_pack; + getOperation()->walk([&](mlir::func::FuncOp func) { + args_to_pack[func.getName()] = PackedArgs(func); + }); + getOperation()->walk([&](mlir::func::CallOp call) { + args_to_pack[call.getCallee()].Pack(call); + }); + getOperation()->walk([&](mlir::func::FuncOp func) { + args_to_pack[func.getName()].Pack(func); + }); +} + +} // namespace + +std::unique_ptr> +CreateMergePointersToSameSlicePass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc new file mode 100644 index 0000000000000..f12af822e51ba --- /dev/null +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -0,0 +1,472 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/target_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using mlir::Value; +using mlir::ValueRange; + +void AddRanges(llvm::Function* func, const LaunchDimensions& launch_dims, + llvm::Module* module) { + for (auto& block : *func) { + for (auto& instr : block) { + if (auto* call = llvm::dyn_cast(&instr)) { + if (auto* callee = call->getCalledFunction()) { + switch (callee->getIntrinsicID()) { + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().x, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().y, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z: + llvm_ir::AddRangeMetadata( + 0, launch_dims.thread_counts_per_block().z, call, module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().x, call, + module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().y, call, + module); + break; + case llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z: + llvm_ir::AddRangeMetadata(0, launch_dims.block_counts().z, call, + module); + break; + } + } + } + } + } +} + +} // namespace + +Value MlirFusionEmitterBase::EmitBlockId(mlir::ImplicitLocOpBuilder& builder, + int dim) const { + const auto& counts = launch_dimensions().block_counts(); + int64_t count = dim == 0 ? counts.x : dim == 1 ? counts.y : counts.z; + auto block_id = builder.create( + static_cast(dim)); + block_id->setAttr("xla.range", builder.getIndexArrayAttr({0, count - 1})); + return block_id; +} + +Value MlirFusionEmitterBase::EmitThreadId(mlir::ImplicitLocOpBuilder& builder, + int dim) const { + const auto& counts = launch_dimensions().thread_counts_per_block(); + int64_t count = dim == 0 ? counts.x : dim == 1 ? counts.y : counts.z; + auto thread_id = builder.create( + static_cast(dim)); + thread_id->setAttr("xla.range", builder.getIndexArrayAttr({0, count - 1})); + return thread_id; +} + +absl::StatusOr MlirFusionEmitterBase::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + TF_ASSIGN_OR_RETURN( + auto args, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); + auto launch_dims = launch_dimensions(); + auto [status_or_entry, cached] = + ir_emitter_context.kernel_cache().GetWithStatus( + fusion.fused_instructions_computation(), args.args(), + /*discriminator=*/"", + [&]() -> absl::StatusOr { + std::string kernel_name = + ir_emitter_context.name_uniquer()->GetUniqueName( + llvm_ir::SanitizeFunctionName(std::string(fusion.name()))); + if (ir_emitter_context.emit_kernels()) { + TF_ASSIGN_OR_RETURN( + auto module, + CreateLLVMModule( + *ir_emitter_context.mlir_context(), + ir_emitter_context.llvm_module()->getContext(), + ir_emitter_context.gpu_device_info(), fusion, kernel_name, + &ir_emitter_context.buffer_assignment())); + auto* kernel_func = module->getFunction(kernel_name); + AddRanges(kernel_func, launch_dims, module.get()); + + auto* target = ir_emitter_context.llvm_module(); + module->setDataLayout(target->getDataLayout()); + module->setTargetTriple(target->getTargetTriple()); + + llvm::IRBuilder<> builder(module->getContext()); + AnnotateFunctionAsGpuKernel(module.get(), kernel_func, &builder); + TF_RETURN_IF_ERROR(AnnotateKernelLaunchDimensions( + ir_emitter_context.gpu_device_info(), launch_dims, + kernel_name, module.get())); + + // Use override flag because libdevice functions can be present in + // both. + CHECK(!llvm::Linker::linkModules( + *target, std::move(module), + llvm::Linker::Flags::OverrideFromSrc)); + } else { + VLOG(3) << "Skipped kernel compilation."; + } + + return KernelReuseCache::Entry{kernel_name, launch_dims, + std::nullopt, + /*shmem_bytes=*/0}; + }); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); + + if (cached) { + VLOG(3) << "Reuse: " << fusion.name() << " -> " << entry->kernel_name; + } + + FusionEmissionResult result; + result.thunks.emplace_back(std::make_unique( + &fusion, entry->kernel_name, args.args(), launch_dims, entry->cluster_dim, + entry->shmem_bytes)); + return result; +} + +absl::StatusOr> +MlirFusionEmitterBase::CreateLLVMModule( + mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, + const se::DeviceDescription& device, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const { + TF_RET_CHECK(device.cuda_compute_capability().major >= 1) + << "Unsupported device type: " << device.name(); + TF_ASSIGN_OR_RETURN( + auto module, CreateMLIRModule(mlir_context, fusion, entry_function_name, + buffer_assignment)); + + mlir::PassManager pm(&mlir_context); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::mhlo::createConvertToSignlessPass()); + pm.addPass(CreatePropagateSliceIndicesPass()); + pm.addPass(CreateLowerFuncPass()); + pm.addPass(CreateLowerXlaGpuToScfPass()); + pm.addPass(CreateLowerTensorsPass()); + pm.addPass(mlir::createConvertComplexToStandardPass()); + pm.addPass(CreateMergePointersToSameSlicePass()); + + // LowerTensors creates new affine.apply ops. Fold and CSE them so + // simplify-affine has maximally folded expressions to work with. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(CreateSimplifyAffinePass()); + // Replace comparisons that result in constant values (e.g. due to ranges not + // overlapping). This pass must run after SimplifyAffinePass, since that + // generates the range information. + pm.addPass(CreateSimplifyArithPass()); + + // simplify-affine lowers most affine.apply ops, but if it can't prove a + // division or modulo is unsigned, affine.apply ops will remain. + pm.addPass(mlir::createLowerAffinePass()); + + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(CreateLowerTensorsPass()); + pm.addPass(CreateExpandFloatOpsPass( + !device.cuda_compute_capability().IsAtLeastAmpere())); + pm.addPass(CreateLowerToLLVMPass()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + + if (pm.run(module.get()).failed()) { + std::string module_dump; + llvm::raw_string_ostream os(module_dump); + module->print(os); + return absl::InternalError(absl::StrFormat( + "Failed create LLVM module.\nHloFusionInstruction " + "computation:\n%s\nMLIR module:\n%s", + fusion.fused_instructions_computation()->ToString(), module_dump)); + } + + auto llvm_module = mlir::translateModuleToLLVMIR(module.get(), llvm_context); + TF_RET_CHECK(llvm_module != nullptr) + << "Failed to translate module to LLVM IR."; + + return llvm_module; +} + +absl::StatusOr> +MlirFusionEmitterBase::CreateMLIRModule( + mlir::MLIRContext& context, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const { + context.loadDialect(); + mlir::DialectRegistry registry; + mlir::func::registerInlinerExtension(registry); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + + mlir::OpBuilder builder(&context); + auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); + mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); + + // Create the entry function. + SmallVector param_types; + std::optional args; + if (buffer_assignment != nullptr) { + TF_ASSIGN_OR_RETURN(args, + KernelArguments::Create(*buffer_assignment, &fusion)); + } + // Annotate tensors with the buffer indices. This way, the buffer propagation + // pass can clean them up later. + int next_slice_index = 0; + absl::flat_hash_map> + slice_indices; + auto get_arg_attrs = [&](int index) -> absl::StatusOr { + if (!args) { + return builder.getDictionaryAttr({builder.getNamedAttr( + "xla.slice_index", builder.getIndexAttr(next_slice_index++))}); + } + + const auto& arg = args->args()[index]; + SmallVector attrs; + attrs.push_back(builder.getNamedAttr( + "xla.slice_index", builder.getIndexAttr(arg.llvm_arg_index()))); + attrs.push_back( + builder.getNamedAttr(mlir::LLVM::LLVMDialect::getAlignAttrName(), + builder.getIndexAttr(arg.alignment()))); + attrs.push_back(builder.getNamedAttr( + mlir::LLVM::LLVMDialect::getDereferenceableAttrName(), + builder.getIndexAttr(arg.slice().size()))); + if (!arg.written()) { + attrs.push_back( + builder.getNamedAttr("xla.invariant", builder.getUnitAttr())); + } + return builder.getDictionaryAttr(attrs); + }; + + SmallVector arg_attrs; + int arg_index = 0; + for (auto* param : fusion.operands()) { + param_types.push_back( + mlir_converter::TensorShapeToMlirType(param->shape(), builder)); + TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), get_arg_attrs(arg_index++)); + } + + auto result_types = mlir_converter::ShapeToMlirTypes(fusion.shape(), builder); + param_types.append(result_types.begin(), result_types.end()); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion.shape(), [&](const auto& shape, const ShapeIndex& index) { + if (shape.IsArray()) { + TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), + get_arg_attrs(arg_index++)); + } + return absl::OkStatus(); + })); + + builder.setInsertionPointToStart(module->getBody()); + auto entry_func = builder.create( + loc, entry_function_name, + mlir::FunctionType::get(&context, param_types, result_types), + /*sym_visibility=*/mlir::StringAttr{}, + mlir::ArrayAttr::get(&context, arg_attrs), + /*res_attrs=*/mlir::ArrayAttr{}); + entry_func->setAttr("xla.entry", mlir::UnitAttr::get(&context)); + + TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion)); + + // Run a minimal simplification pipeline. + mlir::PassManager pm(&context); + pm.addPass(CreateSimplifyArithPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + if (pm.run(module.get()).failed()) { + std::string module_dump; + llvm::raw_string_ostream os(module_dump); + module->print(os); + return absl::InternalError(absl::StrFormat( + "Failed to simplify module.\nHloFusionInstruction " + "computation:\n%s\nMLIR module:\n%s", + fusion.fused_instructions_computation()->ToString(), module_dump)); + } + + return module; +} + +SmallVector MlirFusionEmitterBase::EmitThreadLoopNest( + mlir::ImplicitLocOpBuilder& b, ValueRange outputs, + const IndexingMap& indexing_map, + const std::function< + SmallVector(ValueRange outputs_tensors, ValueRange dim_values, + ValueRange symbol_values)>& create_body) const { + SmallVector dim_values{EmitThreadId(b, 0), EmitThreadId(b, 1), + EmitThreadId(b, 2), EmitBlockId(b, 0), + EmitBlockId(b, 1), EmitBlockId(b, 2)}; + return mlir_converter::EmitLoopNest(b, dim_values, outputs, indexing_map, + create_body); +} + +absl::Status MlirFusionEmitterBase::EmitMlir( + mlir::ModuleOp module, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + auto customized = GetInstructionsWithCustomCodegen(fusion); + mlir_converter::PartitionedComputations computations( + fusion.fused_instructions_computation(), customized); + auto subgraph_to_mlir_fn = computations.DeclareFunctions(module); + + // Erase subgraphs for all customized instructions that aren't used anywhere + // else. This is necessary because the instructions may not have elemental + // implementations (scatter). + for (auto* custom : customized) { + if (custom->user_count() == 0) { + subgraph_to_mlir_fn.extract(&computations.FindSubgraph(custom)) + .mapped() + .erase(); + } + } + + auto call_targets = + computations.CreateCallTargetProvider(subgraph_to_mlir_fn); + for (const auto& comp : computations.partitioned_computations()) { + for (const auto& subgraph : comp.subgraphs()) { + if (subgraph_to_mlir_fn.contains(&subgraph)) { + TF_RETURN_IF_ERROR(mlir_converter::SubgraphToMlirFunction( + comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_targets)); + } + } + } + if (const auto& epilogue = computations.epilogue()) { + TF_RETURN_IF_ERROR(mlir_converter::SubgraphToMlirFunction( + computations.FindPartitionedComputation( + fusion.fused_instructions_computation()), + *epilogue, subgraph_to_mlir_fn[&*epilogue], call_targets)); + } + + return EmitEntryFunction(computations, call_targets, entry_function, fusion); +} + +mlir::ValueRange MlirFusionEmitterBase::EmitEpilogue( + const mlir_converter::PartitionedComputations& computations, + mlir::func::FuncOp entry_fn, mlir::ValueRange hero_values, + mlir::ValueRange output_indices, + mlir::ImplicitLocOpBuilder& builder) const { + const auto& epilogue = computations.epilogue(); + if (!epilogue) { + return hero_values; + } + + auto epilogue_fn = mlir::cast( + entry_fn->getParentOfType().lookupSymbol(epilogue->name)); + SmallVector operands = + mlir::ValueRange(entry_fn.getArguments().take_front( + computations.fusion()->num_parameters())); + absl::c_copy(output_indices, std::back_inserter(operands)); + absl::c_copy(hero_values, std::back_inserter(operands)); + + return builder.create(epilogue_fn, operands).getResults(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h new file mode 100644 index 0000000000000..6baf86372613e --- /dev/null +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -0,0 +1,119 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +class MlirFusionEmitterBase : public KernelFusionInterface { + public: + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; + + // Visible for testing. `buffer_assignment` is optional for testing (assigns + // a different buffer to each tensor). + absl::StatusOr> CreateLLVMModule( + mlir::MLIRContext& mlir_context, llvm::LLVMContext& llvm_context, + const se::DeviceDescription& device, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const; + + // Visible for testing. `buffer_assignment` is optional for testing (assigns + // a different buffer to each tensor). + absl::StatusOr> CreateMLIRModule( + mlir::MLIRContext& context, const HloFusionInstruction& fusion, + const std::string& entry_function_name, + const BufferAssignment* buffer_assignment) const; + + protected: + // Returns the set of instructions that will be isolated in the partitioned, + // i.e., they will get their own subgraph. We won't automatically emit + // functions for these instructions. + virtual std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return {}; + } + + virtual absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const = 0; + + // Evaluates the epilogue of the fusion. Returns `hero_values` if there is no + // epilogue. + mlir::ValueRange EmitEpilogue( + const mlir_converter::PartitionedComputations& computations, + mlir::func::FuncOp entry_fn, mlir::ValueRange hero_values, + mlir::ValueRange output_indices, + mlir::ImplicitLocOpBuilder& builder) const; + + // Emit a loop nest for the symbols in the output map. The map should have + // the dimensions specified in KernelFusionInterface. Loops are nested with + // the symbol 0 as the outermost loop. The indices of the map's dimensions and + // symbols are passed to the lambda separately. The return values of the + // function are the updated outputs. + llvm::SmallVector EmitThreadLoopNest( + mlir::ImplicitLocOpBuilder& b, mlir::ValueRange outputs, + const IndexingMap& indexing_map, + const std::function( + mlir::ValueRange outputs, mlir::ValueRange dim_values, + mlir::ValueRange symbol_values)>& create_body) const; + + mlir::Value EmitBlockId(mlir::ImplicitLocOpBuilder& builder, int dim) const; + mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const; + + private: + // Emits MLIR for the given fusion. The entry function has one tensor argument + // per fusion parameter and output and one tensor result per fusion output. + // The fuson outputs may only be used with `tensor.insert` ops.a + absl::Status EmitMlir(mlir::ModuleOp module, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc new file mode 100644 index 0000000000000..f3f699a8b854e --- /dev/null +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class DummyCopyFusionEmitter : public MlirFusionEmitterBase { + public: + LaunchDimensions launch_dimensions() const final { return {1, 100}; } + + std::optional ComputeThreadIdToOutputIndexing( + int64_t, mlir::MLIRContext*) const final { + return std::nullopt; + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t, int64_t, mlir::MLIRContext*) const final { + return std::nullopt; + } + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); + b.setInsertionPointToStart(entry_function.addEntryBlock()); + auto thread_id = EmitThreadId(b, 0); + auto value = b.create( + entry_function.getArgument(0), mlir::ValueRange{thread_id}); + auto result = b.create( + value, entry_function.getArgument(1), mlir::ValueRange{thread_id}); + b.create(result->getResults()); + return absl::OkStatus(); + } +}; + +class MlirFusionEmitterTest : public HloTestBase { + protected: + MlirFusionEmitterTest() { + context_.loadDialect(); + mlir::DialectRegistry registry; + mlir::func::registerInlinerExtension(registry); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + context_.appendDialectRegistry(registry); + } + + mlir::MLIRContext context_; + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +constexpr absl::string_view kModule = R"( + fused_computation { + ROOT %p0 = f32[100] parameter(0) + } + + ENTRY main { + %p0 = f32[100] parameter(0) + ROOT fusion = f32[100] fusion(%p0), kind=kLoop, calls=fused_computation + })"; + +TEST_F(MlirFusionEmitterTest, CreateMlirModule) { + auto module = ParseAndReturnVerifiedModule(kModule).value(); + DummyCopyFusionEmitter emitter; + TF_ASSERT_OK_AND_ASSIGN( + auto mlir_module, + emitter.CreateMLIRModule( + context_, + *Cast( + module->entry_computation()->root_instruction()), + "fusion", + /*buffer_assignment=*/nullptr)); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << *mlir_module; + + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"( + // CHECK: func.func @fusion( + // CHECK-SAME: %[[IN:.*]]: tensor<100xf32> {xla.slice_index = 0 + // CHECK-SAME: %[[OUT:.*]]: tensor<100xf32> {xla.slice_index = 1 + // CHECK: %[[TID:.*]] = gpu.thread_id x + // CHECK: %[[VAL:.*]] = tensor.extract %[[IN]][%[[TID]]] + // CHECK: %[[RET:.*]] = tensor.insert %[[VAL]] + // CHECK-SAME: into %[[OUT]][%[[TID]]] + // CHECK: return %[[RET]] + )")); + EXPECT_TRUE(filecheck_result); +} + +TEST_F(MlirFusionEmitterTest, CreateLLVMModule) { + llvm::LLVMContext llvm_context; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + DummyCopyFusionEmitter emitter; + TF_ASSERT_OK_AND_ASSIGN( + auto llvm_module, + emitter.CreateLLVMModule( + context_, llvm_context, device_info_, + *Cast( + module->entry_computation()->root_instruction()), + "fusion", + /*buffer_assignment=*/nullptr)); + + std::string out; + llvm::raw_string_ostream stream(out); + stream << *llvm_module; + + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"( + // CHECK: define void @fusion(ptr noalias %[[IN:.*]], ptr noalias %[[OUT:.*]]) + // CHECK: %[[TID:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: %[[EXT:.*]] = sext i32 %[[TID]] to i64 + // CHECK: %[[TRUNC:.*]] = trunc i64 %[[EXT]] to i32 + // CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TRUNC]] + // CHECK: %[[VAL:.*]] = load float, ptr %[[IN_PTR]], align 4 + // CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TRUNC]] + // CHECK: store float %[[VAL]], ptr %[[OUT_PTR]], align 4 + // CHECK: ret void + )")); + EXPECT_TRUE(filecheck_result); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/passes.h b/xla/service/gpu/fusions/mlir/passes.h new file mode 100644 index 0000000000000..b91fc53a54763 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/passes.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DECL +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +// Returns the range of a given value, if it can be statically determined. +std::optional GetRange(mlir::Value value); + +std::unique_ptr CreateExpandFloatOpsPass(bool pre_ampere); +std::unique_ptr CreateLowerFuncPass(); +std::unique_ptr CreateLowerTensorsPass(); +std::unique_ptr CreateLowerToLLVMPass(); +std::unique_ptr CreateLowerXlaGpuToScfPass(); +std::unique_ptr CreateMergePointersToSameSlicePass(); +std::unique_ptr CreatePropagateSliceIndicesPass(); +std::unique_ptr CreateSimplifyAffinePass(); +std::unique_ptr CreateSimplifyArithPass(); + +#define GEN_PASS_REGISTRATION +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ diff --git a/xla/service/gpu/fusions/mlir/passes.td b/xla/service/gpu/fusions/mlir/passes.td new file mode 100644 index 0000000000000..72dbf9ab500f5 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/passes.td @@ -0,0 +1,167 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def PropagateSliceIndicesPass : + Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> { + let summary = "Propagates slice indices from the entry function to all callees."; + + let description = [{ + Propagates xla.slice_index attributes from the function with the xla.entry + attribute to all other functions. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; + + let constructor = "CreatePropagateSliceIndicesPass()"; +} + +def LowerFuncPass : Pass<"xla-gpu-lower-func", "mlir::ModuleOp"> { + let summary = "Lowers function calls to func.calls."; + + let description = [{ + We use xla_gpu.pure_call ops for calls to enable CSE and other + transformations (e.g. LICM). This pass rewrites our custom ops to standard + ops. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", "xla::gpu::XlaGpuDialect" + ]; + + let constructor = "CreateLowerFuncPass()"; +} + +def LowerTensorsPass : + Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> { + let summary = "Lowers tensors to llvm pointers and loads/stores."; + + let description = [{ + Lowers tensors to LLVM. We cannot use the memref lowerings because they + are not compatible with XLA's ABI. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", + "mlir::tensor::TensorDialect", "mlir::scf::SCFDialect", + ]; + + let constructor = "CreateLowerTensorsPass()"; +} + +def MergePointersToSameSlicePass : + Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> { + let summary = "Merges pointers that share slices."; + + let description = [{ + When a function has multiple pointer arguments with the same slice index, + merges them. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect" + ]; + + let constructor = "CreateMergePointersToSameSlicePass()"; +} + +def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::ModuleOp"> { + let summary = "Simplifies arith using XLA's range-aware simplifier."; + + let description = [{ + We often emit bounds checks that are statically known to be satisfied. + This pass removes them. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect" + ]; + + let constructor = "CreateSimplifyArithPass()"; +} + +def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> { + let summary = "Simplifies affine.apply using XLA's range-aware simplifier."; + + let description = [{ + The standard affine canonicalizer cannot simplify all expressions, since + it is unaware of range information. This pass uses `xla.range` attributes + on arguments and ops for simplification. It also lowers floordiv and mod + to simpler expressions than lower-affine. This pass only works for + expressions for which we can prove the LHS of mod and div is nonnegative. + }]; + + let dependentDialects = [ + "mlir::affine::AffineDialect", "mlir::func::FuncDialect", + "mlir::scf::SCFDialect", + ]; + + let constructor = "CreateSimplifyAffinePass()"; +} + +def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> { + let summary = "Expands float ops that are not natively supported."; + + let description = [{ + Not all float ops are natively supported, either because they don't exist + in hardware or they are too inaccurate. + + This pass replaces these ops with alternative implementations. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::mhlo::MhloDialect" + ]; + + let options = [ + Option<"pre_ampere_", "pre-ampere", "bool", /*default=*/"false", + "Rewrite ops that are not supported on architectures before Ampere">, + ]; +} + +def LowerXlaGpuToScfPass : + Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::ModuleOp"> { + let summary = "Lowers xla_gpu to SCF."; + + let dependentDialects = [ + "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + ]; + + let constructor = "CreateLowerXlaGpuToScfPass()"; +} + +def LowerToLLVMPass : + Pass<"xla-gpu-lower-to-llvm", "mlir::ModuleOp"> { + let summary = "Lowers to LLVM."; + + let description = [{ + Lowers the rest to LLVM + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect" + ]; + + let constructor = "CreateLowerToLLVMPass()"; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ diff --git a/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc b/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc new file mode 100644 index 0000000000000..bc4e74d914a99 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/passes.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +class PropagateSliceIndicesPass + : public impl::PropagateSliceIndicesPassBase { + public: + void runOnOperation() override; +}; + +void PropagateSliceIndicesPass::runOnOperation() { + mlir::func::FuncOp entry; + for (auto func : getOperation().getOps()) { + if (func->getAttr("xla.entry")) { + entry = func; + break; + } + } + + if (!entry) { + getOperation()->emitOpError("No entry function found."); + signalPassFailure(); + return; + } + + for (auto func : getOperation().getOps()) { + if (func.getNumArguments() == 0 || func == entry) { + continue; + } + + for (int i = 0; i < func.getNumArguments(); ++i) { + if (mlir::isa(func.getArgument(i).getType())) { + if (auto index = entry.getArgAttr(i, "xla.slice_index")) { + func.setArgAttr(i, "xla.slice_index", index); + } + if (auto invariant = entry.getArgAttr(i, "xla.invariant")) { + func.setArgAttr(i, "xla.invariant", invariant); + } + } else { + break; + } + } + } +} + +} // namespace + +std::unique_ptr CreatePropagateSliceIndicesPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/simplify_affine.cc b/xla/service/gpu/fusions/mlir/simplify_affine.cc new file mode 100644 index 0000000000000..e4020c97a57e7 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/simplify_affine.cc @@ -0,0 +1,232 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_SIMPLIFYAFFINEPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +std::optional GetRange(mlir::Value value) { + auto attr_to_range = [](mlir::Attribute attr) -> std::optional { + if (!attr) { + return std::nullopt; + } + auto values = llvm::to_vector( + attr.cast().getAsValueRange()); + return {{values[0].getSExtValue(), values[1].getSExtValue()}}; + }; + + if (value.getDefiningOp()) { + return attr_to_range(value.getDefiningOp()->getAttr("xla.range")); + } + + auto bbarg = value.dyn_cast(); + if (!bbarg) { + return std::nullopt; + } + + auto parent = bbarg.getParentBlock()->getParentOp(); + if (auto func_op = mlir::dyn_cast(parent)) { + return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range")); + } + + if (auto for_op = mlir::dyn_cast(parent)) { + llvm::APInt lb, ub; + if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) && + mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) { + return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; + } + } + + return std::nullopt; +} + +namespace { + +class SimplifyAffinePass + : public impl::SimplifyAffinePassBase { + public: + void runOnOperation() override; +}; + +struct RewriteAffineApply + : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::affine::AffineApplyOp op, + mlir::PatternRewriter& rewriter) const override { + auto affine_map = op.getAffineMap(); + std::vector dim_ranges(affine_map.getNumDims()); + std::vector symbol_ranges(affine_map.getNumSymbols()); + + for (int i = 0; i < affine_map.getNumInputs(); ++i) { + if (auto range = GetRange(op->getOperand(i))) { + if (i >= dim_ranges.size()) { + symbol_ranges[i - dim_ranges.size()] = RangeVar{*range}; + } else { + dim_ranges[i] = DimVar{*range}; + } + } else { + return rewriter.notifyMatchFailure(op, "failed to deduce range"); + } + } + + IndexingMap map(op.getAffineMap(), dim_ranges, symbol_ranges, + /*rt_vars=*/{}); + map.Simplify(GetIndexingMapForInstruction); + auto expr = map.GetAffineMap().getResult(0); + + RangeEvaluator range_evaluator(map.GetDimensionBounds(), + map.GetSymbolBounds(), op->getContext()); + std::function can_be_lowered; + bool fits_32_bits = true; + can_be_lowered = [&](mlir::AffineExpr expr) { + auto range = range_evaluator.ComputeExpressionRange(expr); + fits_32_bits &= range.upper < std::numeric_limits::max(); + + auto bin_op = llvm::dyn_cast(expr); + if (!bin_op) { + return true; + } + + // Mod and div can be lowered if their LHS is >= 0 and their RHS is a + // constant. + if (expr.getKind() == mlir::AffineExprKind::Mod || + expr.getKind() == mlir::AffineExprKind::FloorDiv) { + if (!range_evaluator.IsAlwaysPositiveOrZero(bin_op.getLHS()) || + !range_evaluator.ComputeExpressionRange(bin_op.getRHS()) + .IsPoint()) { + return false; + } + } + if (expr.getKind() == mlir::AffineExprKind::CeilDiv) { + return false; + } + + return can_be_lowered(bin_op.getLHS()) && can_be_lowered(bin_op.getRHS()); + }; + + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + if (!can_be_lowered(expr)) { + auto range = range_evaluator.ComputeExpressionRange(expr); + op->setAttr("xla.range", b.getIndexArrayAttr({range.lower, range.upper})); + return rewriter.notifyMatchFailure(op, + "unable to lower the affine apply"); + } + + std::function lower; + + auto int_ty = fits_32_bits ? b.getI32Type() : b.getI64Type(); + b.setInsertionPoint(op); + lower = [&](mlir::AffineExpr expr) -> mlir::Value { + if (auto bin_op = mlir::dyn_cast(expr)) { + auto lhs = lower(bin_op.getLHS()); + auto rhs = lower(bin_op.getRHS()); + switch (expr.getKind()) { + case mlir::AffineExprKind::Add: + return b.create(lhs, rhs); + case mlir::AffineExprKind::Mul: + return b.create(lhs, rhs); + case mlir::AffineExprKind::Mod: + return b.create(lhs, rhs); + case mlir::AffineExprKind::FloorDiv: + return b.create(lhs, rhs); + default: + ABSL_UNREACHABLE(); + } + } + + switch (expr.getKind()) { + case mlir::AffineExprKind::Constant: + return b.create( + mlir::cast(expr).getValue(), int_ty); + case mlir::AffineExprKind::DimId: + return b.create( + int_ty, op.getDimOperands()[mlir::cast(expr) + .getPosition()]); + case mlir::AffineExprKind::SymbolId: + return b.create( + int_ty, + op.getSymbolOperands()[mlir::cast(expr) + .getPosition()]); + default: + ABSL_UNREACHABLE(); + } + }; + + auto result = lower(map.GetAffineMap().getResult(0)); + auto result_range = + range_evaluator.ComputeExpressionRange(map.GetAffineMap().getResult(0)); + rewriter + .replaceOpWithNewOp(op, b.getIndexType(), + result) + ->setAttr("xla.range", b.getIndexArrayAttr( + {result_range.lower, result_range.upper})); + return mlir::success(); + } +}; + +void SimplifyAffinePass::runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + mlir::GreedyRewriteConfig config; + // There's no point simplifying more than once. + config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + getOperation(), std::move(patterns), config))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr CreateSimplifyAffinePass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/simplify_arith.cc b/xla/service/gpu/fusions/mlir/simplify_arith.cc new file mode 100644 index 0000000000000..22f2d0f06e750 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/simplify_arith.cc @@ -0,0 +1,106 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_SIMPLIFYARITHPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +namespace { + +Interval::ComparisonResult EvaluateCmpI(mlir::arith::CmpIPredicate pred, + Interval lhs, int64_t rhs) { + switch (pred) { + case mlir::arith::CmpIPredicate::eq: + return lhs == rhs; + case mlir::arith::CmpIPredicate::ne: + return lhs != rhs; + case mlir::arith::CmpIPredicate::slt: + case mlir::arith::CmpIPredicate::ult: + return lhs < rhs; + case mlir::arith::CmpIPredicate::sle: + case mlir::arith::CmpIPredicate::ule: + return lhs <= rhs; + case mlir::arith::CmpIPredicate::sgt: + case mlir::arith::CmpIPredicate::ugt: + return lhs > rhs; + case mlir::arith::CmpIPredicate::sge: + case mlir::arith::CmpIPredicate::uge: + return lhs >= rhs; + } +} + +struct RewriteCmpI : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::arith::CmpIOp op, mlir::PatternRewriter& rewriter) const override { + // We don't need to support constants on the LHS, since comparisons are + // canonicalized to have them on the RHS. + auto rhs = mlir::getConstantIntValue(op.getRhs()); + auto lhs = GetRange(op.getLhs()); + if (lhs && rhs) { + Interval::ComparisonResult result = + EvaluateCmpI(op.getPredicate(), *lhs, *rhs); + if (result != std::nullopt) { + rewriter.replaceOpWithNewOp( + op, *result, rewriter.getI1Type()); + return mlir::success(); + } + } + // TODO(jreiffers): Consider supporting ranges on the RHS as well. + return rewriter.notifyMatchFailure(op, "not a constant result"); + } +}; + +class SimplifyArithPass + : public impl::SimplifyArithPassBase { + public: + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateSimplifyArithPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/tests/BUILD b/xla/service/gpu/fusions/mlir/tests/BUILD new file mode 100644 index 0000000000000..c4a8230a2966b --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/BUILD @@ -0,0 +1,40 @@ +load("//xla:lit.bzl", "lit_test_suite") +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "mlir_fusions_opt", + srcs = ["mlir_fusions_opt.cc"], + deps = [ + "//xla/mlir_hlo", + "//xla/service/gpu/fusions/mlir:passes", + "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + ":mlir_fusions_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir b/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir new file mode 100644 index 0000000000000..ec891b0a4393d --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir @@ -0,0 +1,66 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-expand-float-ops="pre-ampere=true" -canonicalize | FileCheck %s -check-prefixes=CHECK,CHECK-PRE-AMPERE +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-expand-float-ops="pre-ampere=false" -canonicalize | FileCheck %s -check-prefixes=CHECK,CHECK-AMPERE + +module { + func.func @tanh(%arg0: f32) -> f32 { + %ret = math.tanh %arg0 : f32 + return %ret : f32 + } +} + +// CHECK-LABEL: @tanh +// CHECK-NOT: tanh + +// ----- + +module { + func.func @erf(%arg0: f32) -> f32 { + %ret = math.erf %arg0 : f32 + return %ret : f32 + } +} + +// CHECK-LABEL: @erf +// CHECK-NOT: erf + +// ----- + +module { + func.func @maximumf(%arg0: f32, %arg1: f32) -> f32 { + %ret = arith.maximumf %arg0, %arg1 : f32 + return %ret : f32 + } +} + +// CHECK-LABEL: @maximumf +// CHECK-AMPERE: arith.maximumf +// CHECK-PRE-AMPERE: arith.cmpf +// CHECK-PRE-AMPERE: arith.select + +// ----- + +module { + func.func @minimumf(%arg0: f32, %arg1: f32) -> f32 { + %ret = arith.minimumf %arg0, %arg1 : f32 + return %ret : f32 + } +} + +// CHECK-LABEL: @minimumf +// CHECK-AMPERE: arith.minimumf +// CHECK-PRE-AMPERE: arith.cmpf +// CHECK-PRE-AMPERE: arith.select + +// ----- + +module { + func.func @minimumf64(%arg0: f64, %arg1: f64) -> f64 { + %ret = arith.minimumf %arg0, %arg1 : f64 + return %ret : f64 + } +} + +// CHECK-LABEL: @minimumf64 +// CHECK-NOT: minimumf +// CHECK: arith.cmpf +// CHECK: arith.select \ No newline at end of file diff --git a/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/xla/service/gpu/fusions/mlir/tests/inlining.mlir new file mode 100644 index 0000000000000..b2a0b2db94279 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/inlining.mlir @@ -0,0 +1,195 @@ +// RUN: mlir_fusions_opt %s -split-input-file -inline | FileCheck %s + +module { + func.func private @mul(%a: f32, %b: f32) -> f32 { + %ret = arith.mulf %a, %b : f32 + return %ret : f32 + } + + func.func private @add(%a: f32, %b: f32) -> f32 { + %add = arith.addf %a, %b : f32 + %ret = xla_gpu.pure_call @mul(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %ret = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + return %ret : f32 + } +} +// CHECK: @caller +// CHECK-NOT: xla_gpu.pure_call @add +// CHECK: arith.addf +// CHECK-NOT: xla_gpu.pure_call @mul +// CHECK: arith.mulf + +// ----- + +module { + func.func @fused_computation(%arg0: tensor<2xf32> {xla.slice_index = 0 : index}, %arg1: tensor<2xf32> {xla.slice_index = 1 : index}, %arg2: tensor<2xf32> {xla.slice_index = 2 : index}) -> tensor<2xf32> attributes {xla.entry} { + %0 = gpu.thread_id x {xla.range = [0 : index, 1 : index]} + %1 = xla_gpu.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 + %inserted = tensor.insert %1 into %arg2[%0] : tensor<2xf32> + return %inserted : tensor<2xf32> + } + func.func private @fused_computation_atan2(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: index {xla.range = [0 : index, 1 : index]}) -> f32 attributes {llvm.linkage = #llvm.linkage} { + %extracted = tensor.extract %arg0[%arg2] : tensor<2xf32> + %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2xf32> + %0 = arith.addf %extracted, %extracted_0 : f32 + %1 = arith.subf %extracted, %extracted_0 : f32 + %2 = arith.mulf %0, %1 : f32 + %3 = arith.divf %0, %1 : f32 + %4 = math.atan2 %2, %3 : f32 + return %4 : f32 + } +} +// CHECK: @fused_computation +// CHECK-NOT: xla_gpu.pure_call @add +// CHECK: gpu.thread_id +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: arith.addf +// CHECK-NEXT: arith.subf +// CHECK-NEXT: arith.mulf +// CHECK-NEXT: arith.divf +// CHECK-NEXT: math.atan2 +// CHECK-NEXT: tensor.insert + +// ----- + +module { + // Do not inline this function as it has two callers. Even if the callers are + // in different functions at the start, after inlining the two callers are in + // the same function. + func.func private @large(%a: f32, %b: f32) -> f32 { + %mul = arith.mulf %a, %b : f32 + %add = arith.addf %a, %mul : f32 + %div = arith.divf %add, %b : f32 + %sub = arith.subf %div, %a : f32 + %atan2 = math.atan2 %b, %sub : f32 + %neg = arith.negf %atan2 : f32 + %zero = arith.constant 0.0 : f32 + %comp = arith.cmpf olt, %neg, %zero : f32 + %ret = arith.select %comp, %zero, %neg : f32 + return %ret : f32 + } + + func.func private @add(%a: f32, %b: f32) -> f32 { + %add = arith.addf %a, %b : f32 + %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } +} +// CHECK: @caller +// CHECK: arith.addf +// CHECK: xla_gpu.pure_call @large +// CHECK: xla_gpu.pure_call @large + +// ----- + +module { + func.func private @add(%a: f32, %b: f32) -> f32 { + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla_gpu.pure_call @add(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } +} +// CHECK: @caller +// CHECK-NOT: xla_gpu.pure_call +// CHECK: arith.addf +// CHECK: arith.addf + +// ----- + +module { + func.func private @fib0(%start : f32) -> f32 { + %zero = arith.constant 0.0 : f32 + return %zero : f32 + } + func.func private @fib1(%start : f32) -> f32 { + return %start : f32 + } + func.func private @fib2(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib0(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib3(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib4(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + // When inlining the other functions into @fib5, this function exceeds the + // threshold for inlining. + func.func private @fib5(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + // As we do not inline @fib5 into @fib6, this function stays below the + // threshold for inlining. + func.func private @fib6(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib7(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib6(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + + func.func @caller(%a: f32) -> f32 { + %ret = xla_gpu.pure_call @fib7(%a) : (f32) -> (f32) + return %ret : f32 + } +} +// CHECK: @caller +// CHECK: arith.constant 0.000000e+00 +// CHECK: xla_gpu.pure_call @fib5 +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: xla_gpu.pure_call @fib5 +// CHECK: arith.addf +// CHECK: arith.addf + +// ----- + +module { + func.func private @complex(%a: f32, %b: f32) -> complex { + %ret = complex.create %a, %b : complex + return %ret : complex + } + + func.func @caller(%a: f32, %b: f32) -> complex { + %ret = xla_gpu.pure_call @complex(%a, %b) : (f32, f32) -> (complex) + return %ret : complex + } +} + +// CHECK: @caller +// CHECK-NEXT: complex.create diff --git a/xla/service/gpu/fusions/mlir/tests/lower_func.mlir b/xla/service/gpu/fusions/mlir/tests/lower_func.mlir new file mode 100644 index 0000000000000..29acbd5e284fa --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/lower_func.mlir @@ -0,0 +1,53 @@ +// RUN: mlir_fusions_opt %s -xla-gpu-lower-func | FileCheck %s +// RUN: mlir_fusions_opt %s -cse -xla-gpu-lower-func | FileCheck %s -check-prefixes=CHECK-CSE + +module { + func.func private @callee() -> f32 { + %ret = arith.constant 0.0 : f32 + return %ret : f32 + } + + func.func @caller() -> f32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %call0 = xla_gpu.pure_call @callee() : () -> (f32) + %v = scf.for %i = %c0 to %c10 step %c1 iter_args(%r = %call0) -> f32 { + %call1 = xla_gpu.pure_call @callee() : () -> (f32) + %new_v = arith.addf %call1, %r : f32 + scf.yield %new_v : f32 + } + return %v : f32 + } +} + +// CHECK: @caller +// CHECK: call @callee +// CHECK: call @callee + +// CHECK-CSE: @caller +// CHECK-CSE: %[[CALL:.*]] = call @callee +// CHECK-CSE: scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[CALL]]) +// CHECK-CSE: arith.addf %[[CALL]], %[[ITER_ARG]] + +// ----- + +module { + func.func private @arg_callee(%arg0: f32, %arg1: f32) -> f32 { + %ret = arith.addf %arg0, %arg1 : f32 + return %ret : f32 + } + + func.func @arg_caller() -> f32 { + %cst0 = arith.constant 0.0 : f32 + %cst1 = arith.constant 1.0 : f32 + %call = xla_gpu.pure_call @arg_callee(%cst0, %cst1) : (f32, f32) -> (f32) + return %call : f32 + } +} + +// CHECK: @arg_caller +// CHECK: %[[CST0:.*]] = arith.constant 0 +// CHECK: %[[CST1:.*]] = arith.constant 1 +// CHECK: %[[RET:.*]] = call @arg_callee(%[[CST0]], %[[CST1]]) +// CHECK: return %[[RET]] diff --git a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir new file mode 100644 index 0000000000000..b508d2cb84dd3 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -0,0 +1,311 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors | FileCheck %s + +module { + func.func private @add(%arg0: f32, %arg1: f32) -> f32 { + %sum = arith.addf %arg0, %arg1 : f32 + func.return %sum : f32 + } + + func.func private @tensorarg(%arg0: tensor<43xf32> {xla.invariant, xla.slice_index = 0}, %arg1: index) -> f32 { + %v1 = arith.constant 2.0 : f32 + %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32> + %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32 + func.return %sum : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, %arg1: index) -> f32 { + %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32 + func.return %call : f32 + } + + func.func @stores(%arg0: tensor<17xf32> {xla.slice_index = 0}, %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { + %c17 = arith.constant 17 : index + %c23 = arith.constant 23 : index + %cst = arith.constant 3.0 : f32 + %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32> + %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32> + func.return %out2 : tensor<43xf32> + } +} + +// CHECK: func.func private @add(%{{.*}}: f32, %{{.*}}: f32) -> f32 { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: return + +// CHECK: func.func private @tensorarg(%[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 { +// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 +// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32 +// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]] +// CHECK-DAG: %[[V2:.*]] = llvm.load %[[PTR]] invariant +// CHECK: %[[RET:.*]] = call @add(%[[C2]], %[[V2]]) +// CHECK: return %[[RET]] + +// CHECK: func.func @tensorcall(%[[ARG0:.*]]: !llvm.ptr +// CHECK-SAME: {xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) +// CHECK: %[[RET:.*]] = call @tensorarg(%[[ARG0]], %[[ARG1]]) +// CHECK: return %[[RET]] + +// CHECK: func.func @stores( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {xla.slice_index = 0 : i64}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {xla.slice_index = 1 : i64}) +// CHECK-NEXT: %[[CST:.*]] = arith.constant 3.000000e+00 : f32 +// CHECK-NEXT: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[ARG1]][17] +// CHECK-NEXT: llvm.store %[[CST]], %[[PTR1]] +// CHECK-NEXT: %[[PTR2:.*]] = llvm.getelementptr inbounds %[[ARG1]][23] +// CHECK-NEXT: llvm.store %[[CST]], %[[PTR2]] +// CHECK-NEXT: return + +// ----- + +module { + func.func @layout( + %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>, + %arg1: index, %arg2: index) -> f32 { + %v = tensor.extract %arg0[%arg1, %arg2] + : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> + func.return %v : f32 + } +} + +// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2)> +// CHECK: @layout(%[[ARG0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]](%[[X]], %[[Y]]) +// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i32 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] +// CHECK: llvm.load %[[PTR]] + +// ----- + +module { + func.func @store_control_flow( + %arg0: tensor<2xf32>, + %arg1: index + ) -> tensor<2xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 1.0 : f32 + + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { + %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> + scf.yield %new_out : tensor<2xf32> + } + + %inbounds = arith.cmpi sle, %arg1, %c1 : index + %result = scf.if %inbounds -> tensor<2xf32> { + %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> + scf.yield %if : tensor<2xf32> + } else { + scf.yield %for : tensor<2xf32> + } + func.return %result : tensor<2xf32> + } +} + +// CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] { +// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i32 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]] +// CHECK: llvm.store {{.*}}, %[[PTR]] +// CHECK: %[[INBOUNDS:.*]] = arith.cmpi +// CHECK: scf.if %[[INBOUNDS]] { +// CHECK: llvm.store +// CHECK-NEXT: } +// CHECK-NEXT: return + +// ----- + +module { + func.func @large_tensor( + %arg0: tensor<1024x1024x1024x6xf32>, + %arg1: index) -> f32 { + %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32> + func.return %v : f32 + } +} + +// CHECK: @large_tensor +// CHECK: arith.index_castui {{.*}} : index to i64 + +// ----- + +module { + func.func @complex_tensor_insert( + %arg0: tensor<10xcomplex>) -> tensor<10xcomplex> { + %c1 = arith.constant 1 : index + %real = arith.constant 3.0 : f32 + %imag = arith.constant 2.0 : f32 + %complex = complex.create %real, %imag : complex + %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> + func.return %out : tensor<10xcomplex> + } +} + +// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr +// CHECK: %[[C:.*]] = complex.create +// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex to !llvm.struct<(f32, f32)> +// CHECK: llvm.store %[[CAST]], %[[GEP]] : !llvm.struct<(f32, f32)>, !llvm.ptr + +// ----- + +module { + func.func @complex_tensor_extract( + %arg0: tensor<10xcomplex>) -> complex { + %c1 = arith.constant 1 : index + %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> + func.return %v2 : complex + } +} + +// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr +// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> +// CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)> +// CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex + +// ----- + +module { + // This example is a bit silly, in real life there wouldn't be a loop (the + // loop body would be executed by different threads). We're just doing it this + // way so control flow with shared memory is tested as well. + func.func @transpose_shared(%in: tensor<32x32xf32>, + %out: tensor<32x32xf32>) -> tensor<32x32xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %shared = xla_gpu.allocate_shared : tensor<32x32xf32> + %loaded_tile = scf.for %i = %c0 to %c32 step %c1 + iter_args(%tile = %shared) -> tensor<32x32xf32> { + %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1 + iter_args(%inner_tile = %tile) -> tensor<32x32xf32> { + %v = tensor.extract %in[%i, %j] : tensor<32x32xf32> + %inserted = tensor.insert %v into %inner_tile[%i, %j] + : tensor<32x32xf32> + scf.yield %inserted : tensor<32x32xf32> + } + scf.yield %inner_loaded_tile : tensor<32x32xf32> + } + + %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32> + %written_tile = scf.for %i = %c0 to %c32 step %c1 + iter_args(%written = %out) -> tensor<32x32xf32> { + %inner_written_tile = scf.for %j = %c0 to %c32 step %c1 + iter_args(%inner_written = %written) -> tensor<32x32xf32> { + %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32> + %inserted = tensor.insert %v into %inner_written[%i, %j] + : tensor<32x32xf32> + scf.yield %inserted : tensor<32x32xf32> + } + scf.yield %inner_written_tile : tensor<32x32xf32> + } + + return %written_tile : tensor<32x32xf32> + } +} + +// CHECK: llvm.mlir.global private @[[SHARED:shared_.*]]() +// CHECK-SAME: {addr_space = 3 : i32} : !llvm.array<1024 x f32> +// CHECK: @transpose_shared +// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[SHARED]] : !llvm.ptr<3> +// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]] +// CHECK-SAME: : !llvm.ptr<3> to !llvm.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]] +// CHECK: gpu.barrier +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.load %[[ELEM_ADDR]] + +// ----- + +module { + func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index) + -> (tensor<2x4xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 1.0 : f32 + %add = arith.addf %current, %c42 : f32 + xla_gpu.yield %add : f32 + } + return %ret : tensor<2x4xf32> + } +} + +// CHECK: @atomic_rmw_f32 +// CHECK: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]] +// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]]) +// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f32 to i32 +// CHECK-NEXT: llvm.cmpxchg %[[ADDR]], %[[VAR]], %[[RES]] + +// ----- + +module { + func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index) + -> (tensor<2x4xf16>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { + ^bb0(%current : f16): + %c1 = arith.constant 1.0 : f16 + %add = arith.addf %current, %c1 : f16 + xla_gpu.yield %add : f16 + } + return %ret : tensor<2x4xf16> + } +} + +// CHECK: @atomic_rmw_f16 +// CHECK: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] +// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}} +// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}} +// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]] +// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]] +// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]]) +// CHECK-NEXT: %[[VAR_SHIFT:.*]] = llvm.lshr %[[VAR]], %{{.*}} +// CHECK-NEXT: %[[VAR_TRUNC:.*]] = llvm.trunc %[[VAR_SHIFT]] +// CHECK-NEXT: llvm.bitcast %[[VAR_TRUNC]] : i16 to f16 +// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16 +// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]] +// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} +// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]] +// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]] + +// ----- + +module { + func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index) + -> (tensor<2x4xf16>) { + %c1 = arith.constant 1.0 : f16 + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { + ^bb0(%current : f16): + xla_gpu.yield %c1 : f16 + } + return %ret : tensor<2x4xf16> + } +} +// CHECK: @atomic_rmw_overwrite +// CHECK: %[[ADDR:.*]] = llvm.getelementptr +// CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] +// CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}} +// CHECK-NEXT: %[[INDEX:.*]] = llvm.mul %[[OFFSET]], %{{.*}} +// CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr inbounds %[[ADDR]][%[[INDEX]]] +// CHECK: %[[INIT:.*]] = llvm.load %[[BASE]] +// CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]]) +// CHECK: %[[RES:.*]] = llvm.bitcast %{{.*}} : f16 to i16 +// CHECK-NEXT: %[[RES_WIDE:.*]] = llvm.zext %[[RES]] +// CHECK-NEXT: %[[NEW_MASKED:.*]] = llvm.and %[[VAR]], %{{.*}} +// CHECK-NEXT: %[[RES_SHIFT:.*]] = llvm.shl %[[RES_WIDE]], %{{.*}} +// CHECK-NEXT: %[[NEW:.*]] = llvm.or %[[NEW_MASKED]], %[[RES_SHIFT]] +// CHECK-NEXT: llvm.cmpxchg %[[BASE]], %[[VAR]], %[[NEW]] diff --git a/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir new file mode 100644 index 0000000000000..9b1d0b20fe894 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir @@ -0,0 +1,91 @@ +// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf | FileCheck %s + +module { + func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { + return %a, %b : f32, i32 + } + + func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { + %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32 + return %ret#0, %ret#1 : f32, i32 + } +} + +// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32) +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 +// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]] +// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]] +// CHECK: %[[AB4:.*]]:2 = call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]]) +// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4]]#0, %[[C2]], %[[C32]] +// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4]]#1, %[[C2]], %[[C32]] +// CHECK: %[[AB2:.*]]:2 = call @reducer(%[[AB4]]#0, %[[AB4]]#1, %[[A2H]], %[[B2H]]) +// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2]]#0, %[[C1]], %[[C32]] +// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2]]#1, %[[C1]], %[[C32]] +// CHECK: %[[AB1:.*]]:2 = call @reducer(%[[AB2]]#0, %[[AB2]]#1, %[[A1H]], %[[B1H]]) +// CHECK: return %[[AB1]]#0, %[[AB1]]#1 + +// ----- + +module { + func.func @reducer(%a: f64, %b: f64) -> f64 { + return %a : f64 + } + + func.func @shuffler(%a: f64) -> f64 { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64 + return %ret : f64 + } +} + +// CHECK: @shuffler(%[[A:.*]]: f64 +// CHECK: gpu.shuffle down {{.*}}, %[[C1]] +// CHECK: gpu.shuffle down {{.*}}, %[[C1]] + +// ----- + +module { + func.func @predicated_insert( + %v: i32, %tensor: tensor<2xi32>, %index: index, + %cond: i1) -> tensor<2xi32> { + %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond + : tensor<2xi32> + return %ret : tensor<2xi32> + } +} + +// CHECK: @predicated_insert +// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1 +// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]] +// CHECK-NEXT: %[[UPD:.*]] = tensor.insert %[[V]] into %[[TENSOR]][%[[INDEX]]] +// CHECK-NEXT: yield %[[UPD]] +// CHECK-NEXT: else +// CHECK-NEXT: yield %[[TENSOR]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[RET]] + +// ----- + +module { + func.func @predicated_extract( + %v: i32, %tensor: tensor<2xi32>, %index: index, + %cond: i1) -> i32 { + %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v + : tensor<2xi32> + return %ret : i32 + } +} + +// CHECK: @predicated_extract +// CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1 +// CHECK-NEXT: %[[RET:.*]] = scf.if %[[COND]] +// CHECK-NEXT: %[[VAL:.*]] = tensor.extract %[[TENSOR]][%[[INDEX]]] +// CHECK-NEXT: yield %[[VAL]] +// CHECK-NEXT: else +// CHECK-NEXT: yield %[[V]] +// CHECK-NEXT: } +// CHECK-NEXT: return %[[RET]] diff --git a/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir b/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir new file mode 100644 index 0000000000000..f3670a42d3a3b --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir @@ -0,0 +1,40 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors -xla-gpu-merge-pointers | FileCheck %s + +module { + func.func private @tensorargs(%arg0: tensor<43xf32> {xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1, xla.invariant}, + %arg2: tensor<43xf32> {xla.slice_index = 0}, + %arg3: index) -> f32 { + %v0 = tensor.extract %arg0[%arg3] : tensor<43xf32> + %v1 = tensor.extract %arg1[%arg3] : tensor<43xf32> + %v2 = tensor.extract %arg2[%arg3] : tensor<43xf32> + %sum = arith.addf %v0, %v1 : f32 + %sum2 = arith.addf %sum, %v2 : f32 + func.return %sum2 : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32> {xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1, xla.invariant}, + %arg2: tensor<43xf32> {xla.slice_index = 0}, + %arg3: index) -> f32 { + %call = func.call @tensorargs(%arg0, %arg1, %arg2, %arg3) : + (tensor<43xf32>, tensor<43xf32>, tensor<43xf32>, index) -> f32 + func.return %call : f32 + } +} + +// CHECK: func.func private @tensorargs( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {llvm.noalias}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {llvm.noalias, xla.invariant}, +// CHECK-SAME: %[[ARG2:.*]]: index) -> f32 { +// CHECK: %[[GEP0:.*]] = llvm.getelementptr inbounds %[[ARG0]] +// CHECK: llvm.load %[[GEP0]] : !llvm.ptr +// CHECK: %[[GEP1:.*]] = llvm.getelementptr inbounds %[[ARG1]] +// CHECK: llvm.load %[[GEP1]] invariant : !llvm.ptr +// CHECK: %[[GEP2:.*]] = llvm.getelementptr inbounds %[[ARG0]] + +// CHECK: func.func @tensorcall +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {llvm.noalias}, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {llvm.noalias, xla.invariant}, +// CHECK-SAME: %[[ARG2:.*]]: index) -> f32 { +// CHECK: call @tensorargs(%[[ARG0]], %[[ARG1]], %[[ARG2]]) diff --git a/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc new file mode 100644 index 0000000000000..3f77b3336b2b7 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/mlir/passes.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerAllExtensions(registry); + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerInliner(); + xla::gpu::registerGpuFusionTransformsPasses(); + + return mlir::failed( + MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry)); +} diff --git a/xla/service/gpu/fusions/mlir/tests/ops.mlir b/xla/service/gpu/fusions/mlir/tests/ops.mlir new file mode 100644 index 0000000000000..96391210473c2 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/ops.mlir @@ -0,0 +1,61 @@ +// RUN: mlir_fusions_opt %s -canonicalize | FileCheck %s +// RUN: mlir_fusions_opt %s -cse | FileCheck %s --check-prefixes=CHECK-CSE + +module { + func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) { + %shared1 = xla_gpu.allocate_shared : tensor<2xf32> + %shared2 = xla_gpu.allocate_shared : tensor<2xf32> + %sync:2 = xla_gpu.sync_threads %shared1, %shared2 + : tensor<2xf32>, tensor<2xf32> + return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32> + } +} + +// CHECK: @shared_and_sync +// CHECK-NEXT: allocate_shared +// CHECK-NEXT: allocate_shared +// CHECK-NEXT: sync_threads +// CHECK-NEXT: return + +// ----- + +module { + func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index) + -> (tensor<2x3xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 42.0 : f32 + %add = arith.addf %current, %c42 : f32 + xla_gpu.yield %add : f32 + } + return %ret : tensor<2x3xf32> + } +} + +// CHECK: @atomic_rmw +// CHECK: xla_gpu.atomic_rmw + +// ----- + +module { + func.func private @add(%a: f32, %b: f32) -> f32 { + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = arith.addf %c, %d : f32 + return %ret : f32 + } +} + +// CHECK: @caller +// CHECK: %[[C:.*]] = xla_gpu.pure_call @add +// CHECK: %[[D:.*]] = xla_gpu.pure_call @add +// CHECK: arith.addf %[[C]], %[[D]] + +// CHECK-CSE: @caller +// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add +// CHECK-CSE: arith.addf %[[C]], %[[C]] diff --git a/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir b/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir new file mode 100644 index 0000000000000..b776337bda065 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir @@ -0,0 +1,36 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-propagate-slice-indices | FileCheck %s + +module { + func.func private @add(%arg0: f32, %arg1: f32) -> f32 { + %sum = arith.addf %arg0, %arg1 : f32 + func.return %sum : f32 + } + + func.func private @tensorarg(%arg0: tensor<43xf32>, %arg1: index) -> f32 { + %v1 = arith.constant 2.0 : f32 + %v2 = tensor.extract %arg0[%arg1] : tensor<43xf32> + %sum = func.call @add(%v1, %v2) : (f32, f32) -> f32 + func.return %sum : f32 + } + + func.func @tensorcall(%arg0: tensor<43xf32>, %arg1: index) -> f32 { + %call = func.call @tensorarg(%arg0, %arg1) : (tensor<43xf32>, index) -> f32 + func.return %call : f32 + } + + func.func @stores(%arg0: tensor<17xf32> {xla.invariant, xla.slice_index = 0}, + %arg1: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> + attributes { xla.entry } { + %c17 = arith.constant 17 : index + %c23 = arith.constant 23 : index + %cst = arith.constant 3.0 : f32 + %out = tensor.insert %cst into %arg1[%c17] : tensor<43xf32> + %out2 = tensor.insert %cst into %out[%c23] : tensor<43xf32> + func.return %out2 : tensor<43xf32> + } +} + +// CHECK-DAG: @add(%{{.*}}: f32, %{{.*}}: f32) +// CHECK-DAG: @tensorarg(%{{.*}}: tensor<43xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: index) +// CHECK-DAG: @tensorcall(%{{.*}}: tensor<43xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: index) +// CHECK-DAG: @stores(%{{.*}}: tensor<17xf32> {xla.invariant, xla.slice_index = 0 : i64}, %{{.*}}: tensor<43xf32> {xla.slice_index = 1 : i64}) diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir new file mode 100644 index 0000000000000..6411232f8980b --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -0,0 +1,84 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s + +module { + func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %1 = gpu.block_id x {xla.range = [0 : index, 3071 : index]} + scf.for %arg3 = %c0 to %c4 step %c1 { + %2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s1 * 4 + s2) floordiv 256) * 256 + (s1 floordiv 64) * 256 - ((s0 * 2 + s1 floordiv 64) floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768 - (((s0 * 128 + s1) floordiv 192) floordiv 1024) * 786432 + (s0 floordiv 1536) * 786432)>()[%1, %0, %arg3] + %3 = arith.index_castui %2 : index to i64 + %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %5 = llvm.load %4 invariant : !llvm.ptr -> f32 + %8 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %9 = llvm.load %8 invariant : !llvm.ptr -> f32 + %10 = arith.cmpf oge, %5, %9 : f32 + %11 = llvm.getelementptr %arg2[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i1 + llvm.store %10, %11 : i1, !llvm.ptr + } + return + } +} + +// CHECK: @op_and_for_ranges +// CHECK-DAG: %[[C512:.*]] = arith.constant 512 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x +// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x +// CHECK: scf.for %[[I:.*]] = +// CHECK: %[[BID_32:.*]] = arith.index_castui %[[BID_X]] : index to i32 +// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_32]], %[[C512]] +// CHECK: %[[TID_32:.*]] = arith.index_castui %[[TID_X]] : index to i32 +// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_32]], %[[C4]] +// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]] +// CHECK: %[[I_32:.*]] = arith.index_castui %[[I]] : index to i32 +// CHECK: %[[UNROLL_OFFSET:.*]] = arith.addi %[[OFFSET]], %[[I_32]] +// CHECK: %[[UNROLL_INDEX:.*]] = arith.index_castui %[[UNROLL_OFFSET]] +// CHECK-SAME: {xla.range = [0 : index, 1572863 : index]} : i32 to index +// CHECK: arith.index_castui %[[UNROLL_INDEX]] : index to i64 + +// ----- + +module { + func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @arg_ranges +// CHECK-NEXT: %[[C100:.*]] = arith.constant 100 +// CHECK-NEXT: %[[ARG0_32:.*]] = arith.index_castui {{.*}} : index to i32 +// CHECK-NEXT: %[[RET_32:.*]] = arith.divui %[[ARG0_32]], %[[C100]] +// CHECK-NEXT: %[[RET:.*]] = arith.index_castui %[[RET_32]] +// CHECK-SAME: {xla.range = [0 : index, 10 : index]} : i32 to index +// CHECK-NEXT: return %[[RET]] + +// ----- + +module { + func.func @needs_i64(%arg0: index {xla.range = [0 : index, 1000000000000 : index]}, %arg1: index {xla.range = [0 : index, 10 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @needs_i64 +// CHECK: arith.index_castui {{.*}} : index to i64 +// CHECK: arith.index_castui {{.*}} : index to i64 +// CHECK: arith.index_castui {{.*}} : i64 to index + +// ----- + +module { + func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index { + %0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1] + return %0 : index + } +} + +// CHECK: @cant_lower +// CHECK: affine.apply +// CHECK-SAME: {xla.range = [-1 : index, 10 : index]} diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir new file mode 100644 index 0000000000000..f9518002cbff9 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir @@ -0,0 +1,70 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-simplify-arith -canonicalize | FileCheck %s + +module { + func.func @unknown(%arg0: index {xla.range = [0 : index, 42 : index]}) -> i1 { + %c12 = arith.constant 12 : index + %eq = arith.cmpi eq, %arg0, %c12 : index + return %eq : i1 + } +} + +// CHECK: @unknown +// CHECK: cmpi + +// ----- + + +module { + func.func @true(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 { + %c5 = arith.constant 5 : index + %eq = arith.cmpi sge, %arg0, %c5 : index + return %eq : i1 + } +} + +// CHECK: @true +// CHECK-NEXT: constant true +// CHECK-NEXT: return + +// ----- + +module { + func.func @false(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 { + %c5 = arith.constant 5 : index + %eq = arith.cmpi slt, %arg0, %c5 : index + return %eq : i1 + } +} + +// CHECK: @false +// CHECK-NEXT: constant false +// CHECK-NEXT: return + +// ----- + +module { + func.func @rhs_range(%arg0: index {xla.range = [12 : index, 42 : index]}) -> i1 { + %c42 = arith.constant 64 : index + %eq = arith.cmpi slt, %c42, %arg0 : index + return %eq : i1 + } +} + +// CHECK: @rhs_range +// CHECK-NEXT: constant false +// CHECK-NEXT: return + +// ----- + +module { + func.func @both_range(%arg0: index {xla.range = [12 : index, 42 : index]}, + %arg1: index {xla.range = [63 : index, 100 : index]}) -> i1 { + // This is true, but we don't support it yet. + %eq = arith.cmpi slt, %arg0, %arg1 : index + return %eq : i1 + } +} + +// CHECK: @both_range +// CHECK-NEXT: cmpi +// CHECK-NEXT: return \ No newline at end of file diff --git a/xla/service/gpu/fusions/mlir/type_util.cc b/xla/service/gpu/fusions/mlir/type_util.cc new file mode 100644 index 0000000000000..b311af50e72bb --- /dev/null +++ b/xla/service/gpu/fusions/mlir/type_util.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/type_util.h" + +#include "absl/log/check.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/layout_util.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/shape.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +mlir::Type TensorShapeToMlirType(const Shape& shape, mlir::OpBuilder& b) { + CHECK(shape.IsArray()); + + // Default layouts create a lot of clutter in the IR, so only add an + // encoding when needed. + mlir::Attribute layout = {}; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { + layout = CreateDenseIntElementsAttrFromVector( + llvm::to_vector(shape.layout().minor_to_major()), b); + } + return mlir::RankedTensorType::get( + llvm::to_vector(shape.dimensions()), + *ConvertPrimitiveTypeToMlirType(shape.element_type(), b), layout); +} + +llvm::SmallVector ShapeToMlirTypes(const Shape& shape, + mlir::OpBuilder& b) { + llvm::SmallVector types; + types.reserve(shape.IsTuple() ? shape.tuple_shapes_size() : 1); + if (shape.IsTuple()) { + types.reserve(shape.tuple_shapes_size()); + for (auto& tuple_shape : shape.tuple_shapes()) { + types.push_back(TensorShapeToMlirType(tuple_shape, b)); + } + } else { + types.push_back(TensorShapeToMlirType(shape, b)); + } + return types; +} + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/type_util.h b/xla/service/gpu/fusions/mlir/type_util.h new file mode 100644 index 0000000000000..d9d394ffa6c9b --- /dev/null +++ b/xla/service/gpu/fusions/mlir/type_util.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/shape.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { + +// Converts an XLA tensor to an MLIR ranked tensor. The layout is stored in the +// encoding attribute, if it is not the default layout. `shape` must be an +// array. +mlir::Type TensorShapeToMlirType(const Shape& shape, mlir::OpBuilder& b); + +// If `shape` is a tuple, returns the converted tuple shapes. Otherwise returns +// just the converted shape. Nested tuples are not supported. +llvm::SmallVector ShapeToMlirTypes(const Shape& shape, + mlir::OpBuilder& b); + +} // namespace mlir_converter +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ diff --git a/xla/service/gpu/fusions/mlir/type_util_test.cc b/xla/service/gpu/fusions/mlir/type_util_test.cc new file mode 100644 index 0000000000000..0d35bbc452f53 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/type_util_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/mlir/type_util.h" + +#include + +#include +#include +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace mlir_converter { +namespace { + +using ::testing::ElementsAre; + +std::string TypeToString(mlir::Type type) { + std::string out; + llvm::raw_string_ostream stream(out); + stream << type; + return out; +} + +llvm::SmallVector TypesToString( + const llvm::SmallVector& types) { + return llvm::map_to_vector(types, TypeToString); +} + +TEST(TensorShapeTest, ConvertsShape) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_EQ(TypeToString( + TensorShapeToMlirType(ShapeUtil::MakeShape(S32, {4, 5, 6}), b)), + "tensor<4x5x6xi32>"); +} + +TEST(TensorShapeTest, ConvertsLayout) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_EQ( + TypeToString(TensorShapeToMlirType( + ShapeUtil::MakeShapeWithDenseLayout(S32, {4, 5, 6}, {0, 2, 1}), b)), + "tensor<4x5x6xi32, dense<[0, 2, 1]> : tensor<3xi64>>"); +} + +TEST(ShapeTest, ConvertsArray) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + EXPECT_THAT( + TypesToString(ShapeToMlirTypes(ShapeUtil::MakeShape(S32, {4, 5, 6}), b)), + ElementsAre("tensor<4x5x6xi32>")); +} + +TEST(ShapeTest, ConvertsTuple) { + mlir::MLIRContext ctx; + mlir::OpBuilder b(&ctx); + + EXPECT_THAT( + TypesToString(ShapeToMlirTypes( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {})}), + b)), + ElementsAre("tensor<4x5x6xi32>", "tensor")); +} + +} // namespace +} // namespace mlir_converter +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/xla/service/gpu/fusions/mlir_emitter_test_base.cc new file mode 100644 index 0000000000000..2dfc06b9e747a --- /dev/null +++ b/xla/service/gpu/fusions/mlir_emitter_test_base.cc @@ -0,0 +1,110 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +MlirEmitterTestBaseImpl::MlirEmitterTestBaseImpl() { + // clang-format off + mlir_context_.loadDialect< + mlir::affine::AffineDialect, + mlir::arith::ArithDialect, + mlir::complex::ComplexDialect, + mlir::func::FuncDialect, + mlir::gpu::GPUDialect, + mlir::math::MathDialect, + mlir::mhlo::MhloDialect, + mlir::scf::SCFDialect, + mlir::tensor::TensorDialect, + xla::gpu::XlaGpuDialect + >(); + // clang-format on + mlir::DialectRegistry registry; + mlir::func::registerInlinerExtension(registry); + mlir_context_.appendDialectRegistry(registry); + thread_id_printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, {}); +} + +DebugOptions MlirEmitterTestBaseImpl::GetDebugOptionsForTest() { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_mlir_emitters(true); + return debug_options; +} + +absl::StatusOr MlirEmitterTestBaseImpl::EmitIR( + std::string_view hlo_string) { + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + auto fusion_emitter = GetEmitter(analysis); + + TF_ASSIGN_OR_RETURN(auto mlir_module, + fusion_emitter->CreateMLIRModule( + mlir_context_, *Cast(root), + "fused_computation", nullptr)); + + std::string out; + llvm::raw_string_ostream os(out); + mlir_module->print(os); + + return out; +} + +absl::Status MlirEmitterTestBaseImpl::EmitAndCheckIR( + std::string_view hlo_string, std::string_view pattern) { + TF_ASSIGN_OR_RETURN(auto ir, EmitIR(hlo_string)); + TF_ASSIGN_OR_RETURN(auto check_result, RunFileCheck(ir, pattern)); + return check_result ? absl::Status() + : absl::FailedPreconditionError("match failure"); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir_emitter_test_base.h b/xla/service/gpu/fusions/mlir_emitter_test_base.h new file mode 100644 index 0000000000000..a299c2ea4007b --- /dev/null +++ b/xla/service/gpu/fusions/mlir_emitter_test_base.h @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_EMITTER_TEST_BASE_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_EMITTER_TEST_BASE_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +class MlirEmitterTestBaseImpl : public HloTestBase { + public: + MlirEmitterTestBaseImpl(); + + virtual std::unique_ptr GetEmitter( + const HloFusionAnalysis& analysis) = 0; + + DebugOptions GetDebugOptionsForTest() override; + + absl::StatusOr EmitIR(std::string_view hlo_string); + absl::Status EmitAndCheckIR(std::string_view hlo_string, + std::string_view pattern); + + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + mlir::MLIRContext mlir_context_; + AffineMapPrinter thread_id_printer_; +}; + +template +class MlirEmitterTestBase : public MlirEmitterTestBaseImpl { + public: + std::unique_ptr GetEmitter( + const HloFusionAnalysis& analysis) override { + return std::make_unique(analysis); + } +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_EMITTER_TEST_BASE_H_ diff --git a/xla/service/gpu/fusions/reduction.cc b/xla/service/gpu/fusions/reduction.cc index 5ccc9f88d509e..193fc36d20266 100644 --- a/xla/service/gpu/fusions/reduction.cc +++ b/xla/service/gpu/fusions/reduction.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,13 +19,18 @@ limitations under the License. #include #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" @@ -36,40 +41,41 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/elemental_ir_emitter.h" +#include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/thunk_util.h" #include "xla/service/gpu/fusions/tiling_util.h" -#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/kernel_arguments.h" -#include "xla/service/gpu/kernel_mapping_scheme.h" #include "xla/service/gpu/kernel_reuse_cache.h" -#include "xla/service/gpu/kernel_thunk.h" +#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/reduction_utils.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/target_util.h" -#include "xla/service/gpu/thunk.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" +#include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -94,51 +100,83 @@ int GetNumOutputs(const Shape& shape) { return 1; } +const Shape& OutputShape(const Shape& output_shape, int output_index) { + CHECK(output_index == 0 || output_shape.IsTuple()); + return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index) + : output_shape; +} + llvm::Type* GetIndexType(const HloFusionInstruction& fusion, - const TilingScheme& tiling_scheme, - llvm::IRBuilder<>* builder) { - return GetIndexTypeForKernel(&fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumberOfBlocksPhysical(), - builder); + const Tiling& tiling, llvm::IRBuilder<>* builder) { + return GetIndexTypeForKernel( + &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder); +} + +// For a row reduction, returns the number of rows we can process in parallel +// per warp. +int RowReductionGetRowsPerWarp(int reduced_dimension_size) { + if (WarpSize() % reduced_dimension_size != 0 || + reduced_dimension_size >= WarpSize()) { + return 1; + } + return WarpSize() / reduced_dimension_size; +} + +llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, + llvm::Type* element_type, llvm::Twine name) { + return builder->CreateAddrSpaceCast( + input, + llvm::PointerType::get(element_type, + /*AddressSpace=*/0), + name); } class ReductionEmitter { public: - ReductionEmitter(HloFusionAnalysis& analysis, + ReductionEmitter(const HloFusionAnalysis& analysis, + const ReductionInfo& reduction_codegen_info, IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) - : analysis_(analysis), + llvm::IRBuilder<>* builder) + : builder_(builder), + elemental_emitter_(ir_emitter_context, builder_), + analysis_(analysis), + reduction_codegen_info_(reduction_codegen_info), ir_emitter_context_(ir_emitter_context), - elemental_emitter_(elemental_emitter), - fusion_op_(fusion_op), fusion_(fusion), - kernel_cache_(kernel_cache), - builder_(builder), - index_ty_(GetIndexType( - fusion, analysis.GetReductionCodegenInfo()->GetTilingScheme(), - builder)) {} + index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(), + elemental_emitter_.builder())) { + for (auto hero : analysis.fusion_heroes()) { + if (hero->opcode() == HloOpcode::kReduce) { + for (int i = 0; i < hero->operand_count() / 2; ++i) { + CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hero->operand(i)->shape().layout())) + << "reduction-layout-normalizer must run before code generation"; + } + } + } + } - StatusOr Emit(); + absl::StatusOr EmitInitializers(); + absl::Status EmitKernel(const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs); private: friend class ReductionGroupEmitter; - StatusOr> BuildKernelThunkForFusion( + absl::StatusOr> BuildKernelThunkForFusion( const LaunchDimensions& launch_dimensions, absl::string_view discriminator, std::function, std::vector)> kernel_builder_fn); - StatusOr> BuildFusedInitializerThunk( - const HloInstruction* fusion_root, mlir::Value dest, - BufferAllocation::Slice dest_slice, int output_index); + absl::StatusOr> BuildFusedInitializerThunk( + const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice, + int output_index); - Status EmitIRForReduction( + absl::Status EmitIRForReduction( absl::Span instr_index_group, FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, const Shape& input_shape); @@ -146,37 +184,25 @@ class ReductionEmitter { void MaybeEmitFenceForAMDGPU(); void EmitSyncThreads(); - // For a row reduction, returns the number of rows we can process in parallel - // per warp. - int RowReductionGetRowsPerWarp() const { - int reduced_dimension_size = ReducedDimensionSize(); - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; - } - int ReducedDimensionSize() const { - return analysis_.GetReductionCodegenInfo() - ->GetTilingScheme() - .GetDimsInElems()[2]; + return reduction_codegen_info_.GetTiling().GetShape()[2]; } - HloFusionAnalysis& analysis_; + llvm::IRBuilder<>* builder_; + GpuElementalIrEmitter elemental_emitter_; + const HloFusionAnalysis& analysis_; + const ReductionInfo& reduction_codegen_info_; IrEmitterContext& ir_emitter_context_; - ElementalIrEmitter& elemental_emitter_; - mlir::lmhlo::FusionOp fusion_op_; const HloFusionInstruction& fusion_; - KernelReuseCache& kernel_cache_; - llvm::IRBuilder<>* builder_; llvm::Type* index_ty_; }; +class ReductionEmitter; + class ReductionGroupEmitter { public: struct ReductionCalculationState { - llvm::GlobalVariable* shared_cache; + std::optional shared_cache; llvm::Value* initial_value; llvm::AllocaInst* partial_result_address; llvm::AllocaInst* input_address; @@ -206,13 +232,13 @@ class ReductionGroupEmitter { void EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) const; + const HloReduceInstruction* reduction, + const std::vector& roots) const; void EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) const; + const HloReduceInstruction* reduction, + const std::vector& roots) const; void EmitFullWarpShuffleDownLoopForReduce( const HloComputation* reducer, @@ -221,24 +247,21 @@ class ReductionGroupEmitter { void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, - const HloInstruction* root, int partial_result_idx, + const std::vector& roots, absl::Span values) const; llvm_ir::IrArray::Index GetOutputIndexForReduction( - int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, + const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, const HloInstruction* root, int output_idx) const; - void GenerateElementForReducer( - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, - int num_partial_results) const; + void GenerateElementForReducer(const HloReduceInstruction* reduction, + const llvm_ir::IrArray::Index& index) const; - Status EmitExtraOutputsForReduce( + absl::Status EmitExtraOutputsForReduce( const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index, - const ExtraOutputGensMap& extra_output_gens) const; + const ExtraOutputGensMap& extra_output_gens); private: ReductionEmitter& reduction_emitter_; @@ -251,23 +274,6 @@ class ReductionGroupEmitter { absl::flat_hash_map state_; }; -// Allocates a shared tile of given dimensions, applying scaling specified in -// tilng_scheme as a major-most dimension to avoid collisions. -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); -} - // Creates accumulator alloca's, populates them with initial values, generates // __shared__ caches and returns the populated object. ReductionGroupEmitter::ReductionGroupEmitter( @@ -276,19 +282,16 @@ ReductionGroupEmitter::ReductionGroupEmitter( const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter) : reduction_emitter_(reduction_emitter), result_ir_arrays_(result_ir_arrays) { - const ReductionCodegenInfo& reduction_info = - *reduction_emitter_.analysis_.GetReductionCodegenInfo(); + const ReductionInfo& reduction_info = + reduction_emitter_.reduction_codegen_info_; VLOG(10) << "Emit prologue for reduction: " << reduction_emitter_.fusion_.ToString(); auto* builder = reduction_emitter_.builder_; for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - int num_partial_results = reduction_info.GetNumPartialResults(); for (int op_result_idx = 0; op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { - Shape result_shape = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes(op_result_idx) - : reduce_hlo->shape(); + Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx); llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( result_shape.element_type(), builder->GetInsertBlock()->getModule()); @@ -296,11 +299,8 @@ ReductionGroupEmitter::ReductionGroupEmitter( llvm_ir::EmitAllocaAtFunctionEntry( element_type, "reduction_input_address", builder); - llvm::AllocaInst* partial_result_address = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - element_type, - /*element_count=*/builder->getInt32(num_partial_results), - "partial_reduction_result", builder); + llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "partial_reduction_result", builder); const HloInstruction* init_value = reduce_hlo->init_values()[op_result_idx]; @@ -310,51 +310,41 @@ ReductionGroupEmitter::ReductionGroupEmitter( *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) .value(); - for (int i = 0; i < num_partial_results; ++i) { - builder->CreateStore( - init_ir_value, builder->CreateInBoundsGEP( - partial_result_address->getAllocatedType(), - partial_result_address, {builder->getInt32(i)})); - } - - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { + builder->CreateStore(init_ir_value, result_address); + const Tiling& tiling = reduction_info.GetTiling(); + auto shared_cache = [&]() -> std::optional { + auto* module = reduction_emitter.ir_emitter_context_.llvm_module(); if (reduction_info.IsRowReduction()) { // Multi-row reductions do not use shared memory. - if (reduction_emitter_.RowReductionGetRowsPerWarp() > 1) { - return nullptr; + if (RowReductionGetRowsPerWarp( + reduction_emitter_.ReducedDimensionSize()) > 1) { + return std::nullopt; } - // Allocate __shared__ - // cache[num_partial_results][num_warps][scaling_factor]. - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); - int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(builder, tiling_scheme, element_type, - {num_partial_results, num_warps}, - "shared_cache"); - } else { - int64_t num_threads_x = - tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); - // Allocate __shared__ - // cache[num_threads][num_threads + 1], where - // num_threads == num_threads_x == num_threads_y. The "+1" is used to - // avoid bank conflicts. - // - // (Although each thread produces num_partial_results results, we - // don't need that much cache: Only one result is live at a time.) - CHECK_EQ(num_threads_x, - tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); - return AllocateShared(builder, tiling_scheme, element_type, - {num_threads_x, num_threads_x + 1}, - "shared_cache"); + // Allocate one shared memory element per warp. + auto block_size = tiling.GetThreadsPerBlock(); + CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] % + WarpSize(), + 0); + return llvm_ir::AllocateSharedMemoryTile( + module, element_type, + {block_size[ReductionDimensions::kRowKeptDimension], + block_size[ReductionDimensions::kRowMinorReducedDimension] / + WarpSize()}, + "shared_cache"); } + const auto& num_threads = tiling.GetThreadsPerBlock(); + int n = num_threads[ReductionDimensions::kColReducedDimension]; + CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]); + // The "+1" is used to avoid bank conflicts. + return llvm_ir::AllocateSharedMemoryTile(module, element_type, + {n, n + 1}, "shared_cache"); }(); llvm_ir::ElementGenerator input_gen = *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - SetCalculationStateFor( - {shared_cache, init_ir_value, partial_result_address, - reduction_input_address, input_gen}, - reduce_hlo, op_result_idx); + SetCalculationStateFor({shared_cache, init_ir_value, result_address, + reduction_input_address, input_gen}, + reduce_hlo, op_result_idx); } } } @@ -362,8 +352,7 @@ ReductionGroupEmitter::ReductionGroupEmitter( void ReductionEmitter::MaybeEmitFenceForAMDGPU() { auto* module = builder_->GetInsertBlock()->getModule(); if (IsAMDGPU(module) && - ir_emitter_context_.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { + ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) { builder_->CreateFence( llvm::AtomicOrdering::SequentiallyConsistent, builder_->getContext().getOrInsertSyncScopeID("workgroup")); @@ -392,11 +381,11 @@ void ReductionEmitter::EmitSyncThreads() { // std::vector outputs) { ... }; // TF_ASSIGN_OR_RETURN( // auto thunk, -// BuildKernelThunkForFusion(..., fusion_op, launch_dimensions, builder_fn, -// ...)); +// BuildKernelThunkForFusion(..., launch_dimensions, builder_fn)); // AddThunkToThunkSequence(std::move(thunk)) // ``` -StatusOr> ReductionEmitter::BuildKernelThunkForFusion( +absl::StatusOr> +ReductionEmitter::BuildKernelThunkForFusion( const LaunchDimensions& launch_dimensions, absl::string_view discriminator, std::function, std::vector)> @@ -405,48 +394,45 @@ StatusOr> ReductionEmitter::BuildKernelThunkForFusion( fusion_.fused_instructions_computation(); std::string suggested_kernel_name = std::string(fusion_.name()); - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - ir_emitter_context_.emit_ir_from_hlo() - ? KernelArguments::Create(ir_emitter_context_.buffer_assignment(), - &fusion_) - : KernelArguments::Create(ir_emitter_context_.allocations(), - fusion_op_)); - - auto kernel_builder_status = OkStatus(); - auto [entry, cached] = kernel_cache_.Get( - fused_computation, kernel_arguments.args(), discriminator, - [&]() -> KernelReuseCache::Entry { - auto [kernel, input_arrays, output_arrays] = BuildKernelPrototype( - ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), - fusion_.operand_count(), launch_dimensions, builder_); - kernel_builder_status = kernel_builder_fn(input_arrays, output_arrays); - return {kernel->getName().str(), launch_dimensions}; - }); - TF_RETURN_IF_ERROR(kernel_builder_status); + TF_ASSIGN_OR_RETURN(auto kernel_arguments, + KernelArguments::Create( + ir_emitter_context_.buffer_assignment(), &fusion_)); + + auto [status_or_entry, cached] = + ir_emitter_context_.kernel_cache().GetWithStatus( + fused_computation, kernel_arguments.args(), discriminator, + [&]() -> absl::StatusOr { + llvm::Function* kernel; + std::vector input_arrays; + std::vector output_arrays; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, input_arrays, output_arrays), + BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name, + kernel_arguments.args(), + fusion_.operand_count(), launch_dimensions, + builder_)); + TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays)); + // Shared memory is allocated statically. + return {{kernel->getName().str(), launch_dimensions, + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0}}; + }); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); if (cached) { VLOG(3) << "Reuse: " << suggested_kernel_name << " -> " - << entry.kernel_name; - } - - if (ir_emitter_context_.emit_ir_from_hlo()) { - return std::make_unique( - &fusion_, entry.kernel_name, kernel_arguments.args(), launch_dimensions, - // Shared memory is allocated statically. - /*shmem_bytes=*/0); + << entry->kernel_name; } return std::make_unique( - fusion_op_, entry.kernel_name, kernel_arguments.args(), launch_dimensions, - // Shared memory is allocated statically. - /*shmem_bytes=*/0); + &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions, + entry->cluster_dim, entry->shmem_bytes); } -Status ReductionGroupEmitter::EmitExtraOutputsForReduce( +absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce( const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index, - const ExtraOutputGensMap& extra_output_gens) const { + const ExtraOutputGensMap& extra_output_gens) { if (extra_output_gens.empty()) { - return OkStatus(); + return absl::OkStatus(); } auto* builder = reduction_emitter_.builder_; @@ -473,17 +459,15 @@ Status ReductionGroupEmitter::EmitExtraOutputsForReduce( for (const auto& [instr, generator] : extra_output_ir_values) { absl::Span result_ir = result_ir_arrays_.at(instr); CHECK_EQ(result_ir.size(), 1); - result_ir[0].EmitWriteArrayElement( - get_index(instr), generator, builder, /*use_linear_index=*/ - reduction_emitter_.analysis_.GetReductionCodegenInfo() - ->GetNumPartialResults() == 1); + result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr> ReductionEmitter::BuildFusedInitializerThunk( - const HloInstruction* fusion_root, mlir::Value dest, - BufferAllocation::Slice dest_slice, int output_index) { +absl::StatusOr> +ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root, + BufferAllocation::Slice dest_slice, + int output_index) { const HloReduceInstruction* reduce = DynCast(fusion_root); TF_RET_CHECK(reduce); @@ -491,22 +475,21 @@ StatusOr> ReductionEmitter::BuildFusedInitializerThunk( const HloInstruction* init_value = reduce->init_values()[0]; TF_ASSIGN_OR_RETURN( std::optional> constant_init_thunk, - BuildConstantInitializerThunk(ir_emitter_context_, fusion_op_, - fusion_root, init_value, dest, dest_slice)); + BuildConstantInitializerThunk(ir_emitter_context_, fusion_root, + init_value, dest_slice)); if (constant_init_thunk) { return *std::move(constant_init_thunk); } const Shape dest_shape = fusion_root->shape(); - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - dest_shape, ir_emitter_context_.gpu_device_info())); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + dest_shape, ir_emitter_context_.gpu_device_info()); const HloComputation* fused_computation = fusion_.fused_instructions_computation(); auto builder_fn = [&](std::vector inputs, - std::vector outputs) -> Status { + std::vector outputs) -> absl::Status { FusedIrEmitter fused_emitter(elemental_emitter_); for (int i = 0; i < fused_computation->num_parameters(); i++) { fused_emitter.BindGenerator( @@ -528,7 +511,7 @@ StatusOr> ReductionEmitter::BuildFusedInitializerThunk( TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]}, launch_dimensions, builder_) .EmitLoop(fusion_.name())); - return OkStatus(); + return absl::OkStatus(); }; return BuildKernelThunkForFusion(launch_dimensions, @@ -537,20 +520,6 @@ StatusOr> ReductionEmitter::BuildFusedInitializerThunk( builder_fn); } -// Gets the output offset as calculated from thread_id.x (to be applied to the -// offset calculated from block_id and thread_id.y). -static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, - llvm::Value* thread_id_x, - llvm::Type* index_ty, - llvm::IRBuilder<>* b) { - int64_t multiplier = - tiling_scheme.GetIndexingOrder() == TilingScheme::StridedIndexingX - ? tiling_scheme.GetVectorSize() - : tiling_scheme.GetTileSizeFor(TilingScheme::DimX); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, multiplier)); -} - // Emits shuffle-down reduction for the `partial_result_address` using the // reduction computation `reducer`, writes output into // `partial_result_address`. @@ -588,19 +557,17 @@ void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( llvm::Type* shuffled_value_type = element_type->isStructTy() ? builder->getIntNTy(bit_width) : element_type; - auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { return ptr; }; - llvm::Value* partial_result = builder->CreateLoad( - shuffled_value_type, - convert_pointer_for_shuffle(partial_result_address), - "partial_reduction_result"); + llvm::Value* partial_result = + builder->CreateLoad(shuffled_value_type, partial_result_address, + "partial_reduction_result"); builder->CreateStore( EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), builder), - convert_pointer_for_shuffle(result_from_other_lane)); + result_from_other_lane); } - StatusOr> returned_scalars = + absl::StatusOr> returned_scalars = CallNestedComputationWithScalarAddrs( builder, reduction_emitter_.ir_emitter_context_, *reducer, reduction_params); @@ -614,134 +581,92 @@ void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( } llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction( - int partial_result_idx, const TilingKernelInfo& tiling_kernel_info, + const TilingKernelInfo& tiling_kernel_info, const HloReduceInstruction* reduction, const HloInstruction* root, int output_idx) const { - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); - }; - auto* builder = reduction_emitter_.builder_; - const auto& reduction_info = - *reduction_emitter_.analysis_.GetReductionCodegenInfo(); - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; - - llvm_ir::IrArray::Index start_offset = [&] { - llvm::Value* x_loc = thread_id_info.thread_id_x; - llvm::Value* y_loc = thread_id_info.thread_id_y; - if (!reduction_info.IsRowReduction()) { - std::swap(x_loc, y_loc); - } - llvm::Value* start_offset_x = GetStartOffsetX( - tiling_scheme, x_loc, reduction_emitter_.index_ty_, builder); - return tiling_kernel_info.tile_origin - .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); - }(); - - const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); - Shape reduction_kept_element_shape = - ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); - - // Given the llvm_ir::IrArray index of a reduction input, returns the linear - // address of the reduction output as if the reduction were going to keep - // the input shape with the dimensions being reduced moved. - llvm::Value* untransposed_output_linear_address = [&] { - const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( - constant(partial_result_idx), TilingScheme::DimX, builder); + auto* index_ty = reduction_emitter_.index_ty_; + + // 1d or 2d output index (for row/column reduction). + auto projected_index = [&]() -> llvm_ir::IrArray::Index { + const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; + const auto& offset = tiling_kernel_info.tile_origin; + const auto& shape = reduction_info.GetTiling().GetXlaShape(); + const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids; if (reduction_info.IsRowReduction()) { - // For row-reduction, y-coordinate determines which row we write into. - return index[TilingScheme::DimY]; + constexpr int kDim = ReductionDimensions::kRowKeptDimension; + return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])}, + {shape.dimensions(kDim)}, + index_ty}; } - // For column reduction, we get the transposed address. - absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); - llvm::Value* x_dim_size = - index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); - llvm::Value* x_block_offset = - builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); + auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension]; + auto* minor_idx = builder->CreateAdd( + offset[ReductionDimensions::kColMinorKeptDimension], + thread_ids[ReductionDimensions::kColReducedDimension]); + return {{major_idx, minor_idx}, + ShapeUtil::DeleteDimension( + ReductionDimensions::kColReducedDimension, shape), + index_ty}; }(); - // A reduction is allowed to transpose its output. For example, suppose - // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are - // allowed to produce as output either f32[10,30]{1,0} (no transpose) or - // f32[10,30]{0,1} (transposing the two output dims). - // - // At this point in the function we have a "partial sum" of input elements - // (stored in partial_result_addresses), and we need to accumulate it into - // the correct output element. - llvm_ir::IrArray::Index element_index( - /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder); - const Shape& output_shape = !reduction->shape().IsTuple() - ? reduction->shape() - : reduction->shape().tuple_shapes(output_idx); - llvm_ir::IrArray::Index output_index(element_index.multidim(), output_shape, - element_index.GetType()); - // We need to check for root == reduction separately, because for variadic - // reduce the root shape would be a tuple, while 'output_shape' is the - // subshape. - return (root == reduction || - ShapeUtil::EqualIgnoringElementType(output_shape, root->shape())) - ? output_index - : output_index.SourceIndexOfBitcast(output_shape, root->shape(), - builder); -} - -llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, - llvm::Type* element_type, llvm::Twine name) { - return builder->CreateAddrSpaceCast( - input, - llvm::PointerType::get(element_type, - /*AddressSpace=*/0), - name); + auto physical_shape = ShapeUtil::DeleteDimensions( + reduction->dimensions(), reduction->operand(output_idx)->shape()); + auto physical_index = + projected_index.SourceIndexOfBitcast(physical_shape, builder); + return llvm_ir::IrArray::Index(physical_index.multidim(), + OutputShape(reduction->shape(), output_idx), + index_ty) + .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder); } void ReductionGroupEmitter::WriteReductionOutput( const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx, const absl::Span values) const { + const HloReduceInstruction* reduction, + const std::vector& roots, + const absl::Span values) const { auto* builder = reduction_emitter_.builder_; - const auto& reduction_info = - *reduction_emitter_.analysis_.GetReductionCodegenInfo(); + const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; const HloComputation* reducer = reduction->to_apply(); for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { auto [output_ptr, type] = typed_ptr; - llvm_ir::IrArray::Index output_index = GetOutputIndexForReduction( - partial_result_idx, tiling_kernel_info, reduction, root, oidx); - - llvm::Value* output_address = - result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress( - output_index, builder, "output_element_address"); - if (reduction_info.IsRaceFree()) { - FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_); - llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output"); - fused_emitter.BindGenerator( - *reduction, - [&](const llvm_ir::IrArray::Index& index) { return loaded; }); - llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); - llvm::Value* generated = *gen(output_index); - builder->CreateStore(generated, output_address); - } else { - CHECK_EQ(values.size(), 1); - CHECK_EQ(reduction, root) - << "output fusion is not allowed for racing reductions"; - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder, reduction_emitter_.ir_emitter_context_, *reducer, - output_address, output_ptr, type)); + for (auto root : roots) { + llvm_ir::IrArray::Index output_index = + GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx); + + llvm::Value* output_address = + result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress( + output_index, builder, "output_element_address"); + if (reduction_info.IsRaceFree()) { + FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_); + llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output"); + fused_emitter.BindGenerator( + *reduction, + [&](const llvm_ir::IrArray::Index& index) { return loaded; }); + llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); + llvm::Value* generated = *gen(output_index); + builder->CreateStore(generated, output_address); + } else { + CHECK_EQ(values.size(), 1); + CHECK_EQ(roots.size(), 1); + CHECK_EQ(reduction, root) + << "output fusion is not allowed for racing reductions"; + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + builder, reduction_emitter_.ir_emitter_context_, *reducer, + output_address, output_ptr, type)); + } } } } -// `current_output`: the value the tile has calculated. -// `output_address`: address where the output value has to be written. void ReductionGroupEmitter::EmitReductionOutputForRowReduction( const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) const { + const HloReduceInstruction* reduction, + const std::vector& roots) const { const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; + const auto& thread_ids = thread_id_info.thread_ids; + auto* thread_id_x = + thread_ids[ReductionDimensions::kRowMinorReducedDimension]; auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); }; @@ -754,112 +679,121 @@ void ReductionGroupEmitter::EmitReductionOutputForRowReduction( int num_outputs = reducer->num_parameters() / 2; absl::InlinedVector current_outputs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionGroupEmitter::ReductionCalculationState& state = - GetCalculationStateFor(reduction, output_idx); + const auto& state = GetCalculationStateFor(reduction, output_idx); current_outputs.push_back( - {builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"), + {state.partial_result_address, state.partial_result_address->getAllocatedType()}); } - const auto& reduction_info = - *reduction_emitter_.analysis_.GetReductionCodegenInfo(); - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - int num_rows_per_warp = reduction_emitter_.RowReductionGetRowsPerWarp(); - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; + const Tiling& tiling = reduction_info.GetTiling(); + int num_rows_per_warp = + RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize()); + EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs), + tiling.GetNumThreadsPerBlock(), + num_rows_per_warp); KernelSupportLibrary ksl(builder); - llvm::Value* warp_id = - builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize())); auto emit_write_output = [&](llvm::Value* write_condition, const absl::Span values) { ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(tiling_kernel_info, reduction, root, - partial_result_idx, values); + WriteReductionOutput(tiling_kernel_info, reduction, roots, values); }); }; - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_info.thread_id_x, - constant(reduction_emitter_.ReducedDimensionSize() - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } - - ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - const auto& state = GetCalculationStateFor(reduction, oidx); - llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, {constant(partial_result_idx), warp_id}); - builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - shmem_output_addr); + // The major kept dimension and vector dimension are not tiled, so they're + // always in bounds. + llvm::Value* is_in_bounds_y = builder->CreateICmpULT( + thread_ids[ReductionDimensions::kRowKeptDimension], + tiling_kernel_info + .output_tile_bounds[ReductionDimensions::kRowKeptDimension]); + + ksl.If("thread_in_bounds", is_in_bounds_y, [&] { + if (num_rows_per_warp > 1) { + llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( + thread_id_x, + constant(reduction_emitter_.ReducedDimensionSize() - 1))); + emit_write_output(is_writing_thread, current_outputs); + return; } - }); - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? - reduction_emitter_.EmitSyncThreads(); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - const auto& state = GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {constant(partial_result_idx), thread_id_info.lane_id}); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - // Ensure initial value address is in generic, not scratch. - llvm::Value* initial_value_addr = - CastSharedToGlobal(builder, - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", builder), - element_type, /*name=*/""); - builder->CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = builder->CreateICmpULT( - thread_id_info.thread_id_x, - constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / - WarpSize())); - - llvm::Value* selected_value = builder->CreateSelect( - warp_exists, block_accum_addr, initial_value_addr); - - selected_values.push_back({selected_value, element_type}); - } + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { + for (int oidx = 0; oidx < num_outputs; oidx++) { + auto& state = GetCalculationStateFor(reduction, oidx); + state.shared_cache->Store( + builder->CreateLoad(current_outputs[oidx].second, + current_outputs[oidx].first), + {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension], + warp_id}, + builder); + } + }); - // If only one warp is present in the block, then we don't need inter-warp - // reduction. - // TODO(b/241414088) If only warp is present, then inter-warp - // communication using shared memory and synchronization using barrier is - // also unnecessary and should be removed. - if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(selected_values), - tiling_scheme.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); - } + // TODO(cheshire): Don't we want to sync it once for everything in the + // output? Not once per each? + reduction_emitter_.EmitSyncThreads(); + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { + absl::InlinedVector selected_values; + for (int oidx = 0; oidx < num_outputs; oidx++) { + auto& state = GetCalculationStateFor(reduction, oidx); + llvm::Value* block_accum_addr = state.shared_cache->Address( + {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension], + thread_id_info.lane_id}, + builder); + + llvm::Type* element_type = + state.partial_result_address->getAllocatedType(); + + // Ensure initial value address is in generic, not scratch. + llvm::Value* initial_value_addr = + CastSharedToGlobal(builder, + llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "initial_value_addr", builder), + element_type, /*name=*/""); + builder->CreateStore(state.initial_value, initial_value_addr); + + llvm::Value* warp_exists = builder->CreateICmpULT( + thread_id_x, + constant(tiling.GetThreadsPerBlock() + [ReductionDimensions::kRowMinorReducedDimension] / + WarpSize())); + + llvm::Value* selected_value = builder->CreateSelect( + warp_exists, block_accum_addr, initial_value_addr); + + selected_values.push_back({selected_value, element_type}); + } + + // If only one warp produces the output element, we don't need to emit + // an inter warp reduce. In our tiling, DimX is the minor reduced + // dimension. The major reduced dimension is always emitted as a loop. + // TODO(b/241414088) If only warp is present, then inter-warp + // communication using shared memory and synchronization using barrier is + // also unnecessary and should be removed. + if (tiling.GetThreadsPerBlock() + [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) { + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(selected_values), + tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); + } - emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); + emit_write_output(is_zero(thread_id_x), selected_values); + }); }); } // Same arguments as EmitReductionOutputForRowReduction. void ReductionGroupEmitter::EmitReductionOutputForColumnReduction( const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int partial_result_idx) const { + const HloReduceInstruction* reduction, + const std::vector& roots) const { auto* builder = reduction_emitter_.builder_; KernelSupportLibrary ksl(builder); const HloComputation* reducer = reduction->to_apply(); const auto& thread_id_info = tiling_kernel_info.thread_id_info; + const auto& thread_ids = thread_id_info.thread_ids; auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); @@ -867,35 +801,21 @@ void ReductionGroupEmitter::EmitReductionOutputForColumnReduction( auto is_zero = [&](llvm::Value* value) { return builder->CreateICmpEQ(value, constant(0)); }; - const auto& reduction_info = - *reduction_emitter_.analysis_.GetReductionCodegenInfo(); - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; + const Tiling& tiling = reduction_info.GetTiling(); int num_outputs = reducer->num_parameters() / 2; - // Wait for reads from shmem in the last iteration to complete. (If this is - // slow, we could "double-buffer" by having two shmem buffers and switching - // between them.) - if (partial_result_idx > 0) { - reduction_emitter_.EmitSyncThreads(); - } + auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension]; + auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension]; // Store the transpose in shared memory. for (int output_idx = 0; output_idx < num_outputs; output_idx++) { const auto& state = GetCalculationStateFor(reduction, output_idx); - llvm::GlobalVariable* shared_cache = state.shared_cache; - llvm::AddrSpaceCastInst* shmem_output_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, shared_cache, - {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, - "shmem_output_address")); - llvm::Value* current_output = builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"); - - llvm::Value* current_output_value = builder->CreateLoad( - state.partial_result_address->getAllocatedType(), current_output); - builder->CreateStore(current_output_value, shmem_output_addr); + auto* current_output_value = + builder->CreateLoad(state.partial_result_address->getAllocatedType(), + state.partial_result_address); + state.shared_cache->Store(current_output_value, {kept_index, reduced_index}, + builder); } reduction_emitter_.EmitSyncThreads(); @@ -904,45 +824,41 @@ void ReductionGroupEmitter::EmitReductionOutputForColumnReduction( absl::InlinedVector shmem_transposed_addrs; for (int output_idx = 0; output_idx < num_outputs; output_idx++) { const auto& state = GetCalculationStateFor(reduction, output_idx); - llvm::AddrSpaceCastInst* shmem_transposed_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, - "shmem_transposed_addr")); + auto* shmem_transposed_addr = + state.shared_cache->Address({reduced_index, kept_index}, builder); shmem_transposed_addrs.push_back( - {shmem_transposed_addr, llvm::cast( - shmem_transposed_addr->getPointerOperand()) - ->getResultElementType()}); + {shmem_transposed_addr, state.shared_cache->GetElementType()}); } EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(shmem_transposed_addrs), - tiling_scheme.GetNumThreadsPerBlock(), + tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); // Some warps in the block are completely outside of the bound of the // tensor, so they should not write any output at all. llvm::Value* has_output = builder->CreateAnd( builder->CreateICmpULT( - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, - reduction_emitter_.index_ty_, builder), - tiling_kernel_info.output_tile_bounds[1]), - builder->CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); + reduced_index, + tiling_kernel_info + .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]), + builder->CreateICmpULT( + kept_index, + tiling_kernel_info + .output_tile_bounds[ReductionDimensions::kColReducedDimension])); ksl.If("reduction_write_output", builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(tiling_kernel_info, reduction, root, - partial_result_idx, shmem_transposed_addrs); + WriteReductionOutput(tiling_kernel_info, reduction, roots, + shmem_transposed_addrs); }); } // Generate a single element of the tile (update the accumulator state) for a -// given reducer of index `i`. +// given reducer. void ReductionGroupEmitter::GenerateElementForReducer( - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results) const { + const HloReduceInstruction* reduction, + const llvm_ir::IrArray::Index& index) const { HloComputation* reducer = reduction->to_apply(); auto* builder = reduction_emitter_.builder_; CHECK_EQ(reducer->num_parameters() % 2, 0); @@ -953,15 +869,11 @@ void ReductionGroupEmitter::GenerateElementForReducer( const auto& state = GetCalculationStateFor(reduction, red_idx); llvm::AllocaInst* input_address = state.input_address; - llvm::AllocaInst* partial_reduction_result_address = - state.partial_result_address; - llvm::Value* const input_ir_value = *state.input_gen( - num_partial_results > 1 ? index_without_linear : input_index); + auto input_index = + index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder); + llvm::Value* const input_ir_value = *state.input_gen(input_index); builder->CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = builder->CreateInBoundsGEP( - partial_reduction_result_address->getAllocatedType(), - partial_reduction_result_address, {partial_result_index}); - reduction_accumulators.push_back(partial_result_address); + reduction_accumulators.push_back(state.partial_result_address); reduction_input_value.push_back(input_address); } @@ -980,7 +892,7 @@ void ReductionGroupEmitter::GenerateElementForReducer( // pointers as last parameters, the called computation writes into // those pointers, and we have returned values on the stack (as well // as pointers to them). - StatusOr> returned_scalars = + absl::StatusOr> returned_scalars = CallNestedComputationWithScalarAddrs( builder, reduction_emitter_.ir_emitter_context_, *reducer, reduction_params); @@ -992,234 +904,204 @@ void ReductionGroupEmitter::GenerateElementForReducer( } // Emits code for reductions in the output_instructions. -Status ReductionEmitter::EmitIRForReduction( +absl::Status ReductionEmitter::EmitIRForReduction( absl::Span instr_index_group, FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, const Shape& input_shape) { - const auto& reduction_info = *analysis_.GetReductionCodegenInfo(); - std::vector roots; - std::vector heroes; ExtraOutputGensMap extra_output_gens; + absl::flat_hash_map> + heroes_to_roots; + // Keep a list of deduplicated heroes separate from heroes_to_roots to make + // the CodeGen deterministic. + std::vector heroes; for (const HloInstruction* hlo : instr_index_group) { auto& hero = FindNonTrivialHero(*hlo); if (IsRealReductionHero(*hlo, hero)) { auto reduction = Cast(&hero); - roots.push_back(hlo); - heroes.push_back(reduction); + if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) { + heroes.push_back(reduction); + } + heroes_to_roots[reduction].push_back(hlo); } else { extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); } } CHECK(!heroes.empty()) << " expect at least one reduce instructions."; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); + const Tiling& tiling = reduction_codegen_info_.GetTiling(); + CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0); ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays, fused_emitter); - EmitTileElementFunction emit_reduction_element = - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( - index, input_shape, builder_, - reduction_info.GetTilingScheme().GetDimsInElems()); - llvm::Value* partial_result_index = - reduction_info.IsRowReduction() - ? builder_->getInt32(0) - : builder_->CreateSub( - x_loc, - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, - index_ty_, builder_)); - - // Clear the linear index field of the llvm_ir::IrArray::Index to enable - // the use of GetElementPointer with array types. This enables the - // vectorization of the computation for different partial results. Use - // this index if 'num_partial_results > 1'. - int num_partial_results = reduction_info.GetNumPartialResults(); - llvm_ir::IrArray::Index index_without_linear{ - input_index.multidim(), input_shape, input_index.GetType()}; - - // Emit code to generate the input and perform the reduction computation - // for each reduction instruction. - for (const HloReduceInstruction* reduce : heroes) { - group_emitter.GenerateElementForReducer( - reduce, partial_result_index, index_without_linear, input_index, - num_partial_results); - } - - // Emit code to generate the output for the non-reduction instructions - // in the fusion, if any. - TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( - input_shape, input_index, extra_output_gens)); - }; - TF_ASSIGN_OR_RETURN( TilingKernelInfo tiling_kernel_info, - EmitTilingKernel(builder_, tiling_scheme, index_ty_, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, - std::array tile_dimensions) { - EmitTile(builder_, reduction_info.GetTilingScheme(), - index, thread_id_info, tile_dimensions, - emit_reduction_element); - })); + EmitTilingKernel( + builder_, tiling, index_ty_, + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& tile_index, + absl::Span tile_dimensions) { + auto emit_element = + [&](absl::Span index_in_tile) { + auto index = tile_index.AddOffset(index_in_tile, builder_); + + // Emit code to generate the input and perform the reduction + // computation for each reduction instruction. + for (const HloReduceInstruction* reduce : heroes) { + group_emitter.GenerateElementForReducer(reduce, index); + } + + // Emit code to generate the output for the non-reduction + // instructions in the fusion, if any. + TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( + ShapeUtil::MakeShape( + F32, reduction_codegen_info_.GetTiling().GetShape()), + index, extra_output_gens)); + }; + EmitTile(builder_, reduction_codegen_info_.GetTiling(), + thread_id_info, tile_dimensions, emit_element); + })); KernelSupportLibrary ksl(builder_); - for (auto [reduce, root] : llvm::zip(heroes, roots)) { - for (int partial_result_idx = 0; - partial_result_idx < reduction_info.GetNumPartialResults(); - ++partial_result_idx) { - if (reduction_info.IsRowReduction()) { - group_emitter.EmitReductionOutputForRowReduction( - tiling_kernel_info, reduce, root, partial_result_idx); - } else { - group_emitter.EmitReductionOutputForColumnReduction( - tiling_kernel_info, reduce, root, partial_result_idx); - } + for (auto reduce : heroes) { + if (reduction_codegen_info_.IsRowReduction()) { + group_emitter.EmitReductionOutputForRowReduction( + tiling_kernel_info, reduce, heroes_to_roots[reduce]); + } else { + group_emitter.EmitReductionOutputForColumnReduction( + tiling_kernel_info, reduce, heroes_to_roots[reduce]); } } - return OkStatus(); + return absl::OkStatus(); } -StatusOr ReductionEmitter::Emit() { - auto* reduction_codegen_info = analysis_.GetReductionCodegenInfo(); - TF_ASSIGN_OR_RETURN(auto launch_dimensions, analysis_.GetLaunchDimensions()); - +absl::StatusOr ReductionEmitter::EmitInitializers() { FusionEmissionResult result; - VLOG(3) << "Launch dimensions of " << fusion_.name() << ": " - << launch_dimensions.ToString(); - const HloComputation* fused_computation = - fusion_.fused_instructions_computation(); - if (!reduction_codegen_info->IsRaceFree()) { - // We need to get the dest slice by traversing the slice assigned to - // fusion, because instructions inside fusion don't have buffer assignment. - // - // The order of fusion roots is determined by its position in the result - // tuple. For example, in the following fused computation - // - // %fused_computation { - // %a = ... - // &b = ... - // ROOT %root = tuple(%a, %b) - // } - // - // The fusion root with index = 0 is %a, and the fusion root %b has index 1. - // Therefore we can get the ordered slices by calling ForEachSubshape on the - // result shape. - std::vector slices; - if (ir_emitter_context_.emit_ir_from_hlo()) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { - if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { - return OkStatus(); - } - - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice slice, - ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, - index)); - slices.push_back(slice); - return OkStatus(); - })); - } - - absl::Span fusion_roots = - analysis_.fusion_roots(); - for (int i = 0; i < fusion_roots.size(); ++i) { - const HloInstruction* fusion_root = fusion_roots[i]; - - mlir::Value dest = ir_emitter_context_.emit_ir_from_hlo() - ? nullptr - : fusion_op_.getOutputBuffers()[i]; - - BufferAllocation::Slice dest_slice; - if (ir_emitter_context_.emit_ir_from_hlo()) { - dest_slice = slices[i]; - } else { - TF_ASSIGN_OR_RETURN( - dest_slice, - GetAllocationSlice(dest, ir_emitter_context_.allocations())); - } + if (reduction_codegen_info_.IsRaceFree()) { + return result; + } + // We need to get the dest slice by traversing the slice assigned to + // fusion, because instructions inside fusion don't have buffer assignment. + // + // The order of fusion roots is determined by its position in the result + // tuple. For example, in the following fused computation + // + // %fused_computation { + // %a = ... + // &b = ... + // ROOT %root = tuple(%a, %b) + // } + // + // The fusion root with index = 0 is %a, and the fusion root %b has index 1. + // Therefore we can get the ordered slices by calling ForEachSubshape on the + // result shape. + std::vector slices; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { + if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { + return absl::OkStatus(); + } - if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { TF_ASSIGN_OR_RETURN( - result.thunks.emplace_back(), - BuildFusedInitializerThunk(fusion_root, dest, dest_slice, i)); - } + BufferAllocation::Slice slice, + ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, + index)); + slices.push_back(slice); + return absl::OkStatus(); + })); + + absl::Span fusion_roots = + analysis_.fusion_roots(); + for (int i = 0; i < fusion_roots.size(); ++i) { + const HloInstruction* fusion_root = fusion_roots[i]; + + if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { + TF_ASSIGN_OR_RETURN( + result.thunks.emplace_back(), + BuildFusedInitializerThunk(fusion_root, slices[i], i)); } } + return result; +} - auto builder_fn = [&, this](std::vector inputs, - std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - HloInstruction* fused_operand = - fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, - [builder = builder_, input = inputs[i], - fused_operand](const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - // Get outputs. - ReductionOutputMap result_ir_arrays; +absl::Status ReductionEmitter::EmitKernel( + const LaunchDimensions& launch_dims, std::vector inputs, + std::vector outputs) { + const HloComputation* fused_computation = + fusion_.fused_instructions_computation(); + FusedIrEmitter fused_emitter(elemental_emitter_); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + HloInstruction* fused_operand = fused_computation->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, [builder = builder_, input = inputs[i], + fused_operand](const llvm_ir::IrArray::Index& index) { + return input.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); + } - int ir_arrays_idx = 0; - for (const HloInstruction* root : analysis_.fusion_roots()) { - int get_num_results = GetNumOutputs(root->shape()); - result_ir_arrays[root] = - absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results); - ir_arrays_idx += get_num_results; - } + // Get outputs. + ReductionOutputMap result_ir_arrays; - KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll); - - // Use raw block_id_y to select the i-th parallel reduction to run. Using - // block_id_y instead of block_id_x simplifies the index calculation - // for reduction code generation as the block_id_y is orthogonal to - // the indices used within the reductions. - const std::vector>& instr_index_groups = - reduction_codegen_info->GetIndexGroups(); - Shape reduce_operand_shape = - reduction_codegen_info->GetReduceOperandShape(); - - llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_); - llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(raw_block_id_y)); - for (int i = 0; i < instr_index_groups.size(); ++i) { - TF_RETURN_IF_ERROR(ksl.IfWithStatus( - absl::StrCat("reduce-group-", i), - builder_->CreateICmpEQ(raw_block_id_y, builder_->getInt32(i)), [&] { - return EmitIRForReduction(instr_index_groups[i], fused_emitter, - result_ir_arrays, reduce_operand_shape); - })); - } + int ir_arrays_idx = 0; + for (const HloInstruction* root : analysis_.fusion_roots()) { + int get_num_results = GetNumOutputs(root->shape()); + result_ir_arrays[root] = + absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results); + ir_arrays_idx += get_num_results; + } - return OkStatus(); - }; + KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + const auto& instr_index_groups = + reduction_codegen_info_.GetGroups().grouped_roots; + Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape(); + + llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_); + llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), + llvm::cast(block_id_y), + builder_->GetInsertBlock()->getModule()); + block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty()); + block_id_y->setName("block.id.y"); + for (int i = 0; i < instr_index_groups.size(); ++i) { + TF_RETURN_IF_ERROR(ksl.IfWithStatus( + absl::StrCat("reduce-group-", i), + builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] { + return EmitIRForReduction(instr_index_groups[i], fused_emitter, + result_ir_arrays, reduce_operand_shape); + })); + } - TF_ASSIGN_OR_RETURN( - result.thunks.emplace_back(), - BuildKernelThunkForFusion(launch_dimensions, "", builder_fn)); - return result; + return absl::OkStatus(); } } // namespace -StatusOr ReductionFusion::Emit( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion, - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { - return ReductionEmitter(analysis_, ir_emitter_context, elemental_emitter, - fusion_op, fusion, kernel_cache, builder) - .Emit(); +absl::StatusOr ReductionFusion::EmitInitializers( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext()); + return ReductionEmitter(analysis(), reduction_info(), ir_emitter_context, + fusion, &builder) + .EmitInitializers(); +} + +absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const { + return ReductionEmitter(analysis(), reduction_info(), ir_emitter_context, + fusion, builder) + .EmitKernel(launch_dims, inputs, outputs); } } // namespace gpu diff --git a/xla/service/gpu/fusions/reduction.h b/xla/service/gpu/fusions/reduction.h index 3b3e7897b2cc6..1304e36c27d37 100644 --- a/xla/service/gpu/fusions/reduction.h +++ b/xla/service/gpu/fusions/reduction.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,23 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "llvm/IR/IRBuilder.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/reduction_base.h" +#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" namespace xla { namespace gpu { @@ -87,18 +102,21 @@ namespace gpu { // complicating the index calculation in the code generation of the reduce // instructions. In other words, a block_id_y is assigned to a group and so // different groups can be run in parallel. -class ReductionFusion : public FusionInterface { +class ReductionFusion : public ReductionFusionBase { public: - explicit ReductionFusion(HloFusionAnalysis& analysis) : analysis_(analysis) {} + using ReductionFusionBase::ReductionFusionBase; - StatusOr Emit( + protected: + absl::StatusOr EmitInitializers( IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, KernelReuseCache& kernel_cache, - llvm::IRBuilder<>* builder) const override; + const HloFusionInstruction& fusion) const override; - private: - HloFusionAnalysis& analysis_; + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/reduction_base.cc b/xla/service/gpu/fusions/reduction_base.cc new file mode 100644 index 0000000000000..86bb721129d00 --- /dev/null +++ b/xla/service/gpu/fusions/reduction_base.cc @@ -0,0 +1,438 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/reduction_base.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/reduction_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/union_find.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +const Shape& FirstShape(const Shape& in) { + return in.IsTuple() ? in.tuple_shapes(0) : in; +} + +int RowReductionGetRowsPerWarp(int reduced_dimension_size) { + if (WarpSize() % reduced_dimension_size != 0 || + reduced_dimension_size >= WarpSize()) { + return 1; + } + return WarpSize() / reduced_dimension_size; +} + +int GetVectorSize(const HloFusionAnalysis& analysis, + const ReductionDimensions& reduction_dimensions, + int num_threads, Vector3 reduction_tiling) { + if (!reduction_dimensions.is_row_reduction) { + return 1; + } + + constexpr int kRowMinorReduced = + ReductionDimensions::kRowMinorReducedDimension; + if (reduction_dimensions.dimensions[kRowMinorReduced] % 2 != 0 || + MayPreventVectorization(analysis.fusion())) { + return 1; + } + + // Enabling vectorization if number of threads is <= warpsize leads to half or + // more of the threads not doing any work. + if (num_threads <= WarpSize()) { + return 1; + } + + const auto* cuda_cc = std::get_if( + &analysis.device_info().gpu_compute_capability()); + if (cuda_cc == nullptr) return 1; + if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return 2; + if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) { + return analysis.input_output_info().smallest_input_dtype_bits <= 32 && + reduction_dimensions.dimensions[kRowMinorReduced] % + (reduction_tiling[kRowMinorReduced] * num_threads) == + 0 + ? 2 + : 1; + } + return 1; +} + +ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis) { + const int num_fusion_outputs = analysis.fusion_roots().size(); + + CHECK_NE(0, num_fusion_outputs); + if (num_fusion_outputs == 1) { + return {{{analysis.fusion_roots()[0]}}, {0}, {true}}; + } + + absl::node_hash_map> + disjoint_sets; + + // TODO(b/249976438): we currently do not treat properly + // aliasing between inputs and outputs of the fusion, so for now put all + // non-reduction roots into one group to avoid read-after-write conflicts. + std::optional first_non_reduction_root = std::nullopt; + + absl::node_hash_map> + reachable_outputs; + absl::flat_hash_set roots_with_reduction; + absl::flat_hash_map root_indices; + const auto& roots = analysis.fusion().GetRoots(); + ReductionGroups result; + result.group_id_per_root.resize(roots.size()); + result.is_reduction_root.reserve(roots.size()); + for (auto [root, hero] : llvm::zip(roots, analysis.fusion_heroes())) { + int index = root_indices.size(); + root_indices[&root.instruction()] = index; + disjoint_sets[root].Get() = root; + reachable_outputs[root].insert(root); + result.is_reduction_root.push_back( + IsRealReductionHero(root.instruction(), *hero)); + if (result.is_reduction_root.back()) { + roots_with_reduction.insert(root); + } else if (first_non_reduction_root) { + disjoint_sets[*first_non_reduction_root].Merge(&disjoint_sets[root]); + } else { + first_non_reduction_root = root; + } + } + + std::vector instructions; + HloBfsConsumersFirstTraversal( + roots, analysis.fusion(), + [&](HloInstructionAdaptor consumer) { + auto& consumer_reachable = reachable_outputs[consumer]; + for (auto producer : consumer.GetOperands()) { + reachable_outputs[producer].insert(consumer_reachable.begin(), + consumer_reachable.end()); + } + instructions.push_back(consumer); + return TraversalResult::kAdvance; + }, + [&](HloInstructionAdaptor argument) { + instructions.push_back(argument); + }); + + for (auto instr : instructions) { + const auto& reachable = reachable_outputs[instr]; + std::vector reached_output_ids; + bool added_to_reduce = false; + for (auto output : roots) { + bool has_real_hero = roots_with_reduction.contains(output); + if (has_real_hero && + (hlo_query::IsBroadcastedConstantOrScalar(instr.instruction()))) { + if (added_to_reduce) { + // Do not group more than one output reduce instructions through + // broadcasted constants or scalars, as the recomputation should be + // acceptable. + VLOG(3) << "Skip broadcasted constant or scalar " << instr.ToString(); + continue; + } + } + // Now group output instructions if they have common predecessors. + if (reachable.contains(output)) { + VLOG(3) << "Reaching " << output.ToString() << " from " + << instr.ToString(); + reached_output_ids.push_back(output); + if (has_real_hero) { + added_to_reduce = true; + } + } + } + for (size_t j = 1; j < reached_output_ids.size(); ++j) { + disjoint_sets[reached_output_ids[0]].Merge( + &disjoint_sets[reached_output_ids[j]]); + } + } + + // Place output instructions in the same set into the same group. + ConstHloInstructionMap> group_map; + for (auto root : roots) { + group_map[&disjoint_sets[root].Get().instruction()].push_back( + &root.instruction()); + } + + result.grouped_roots.reserve(group_map.size()); + absl::c_for_each(group_map, [&](auto& it) { + for (auto* root : it.second) { + result.group_id_per_root[root_indices[root]] = + result.grouped_roots.size(); + } + result.grouped_roots.emplace_back(std::move(it.second)); + }); + return result; +} + +} // namespace + +int ReductionInfo::GetRowsPerWarp() const { + if (!is_row_reduction_) return 1; + return RowReductionGetRowsPerWarp( + tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]); +} + +LaunchDimensions ReductionInfo::launch_dimensions() const { + size_t blocks_y = groups_.grouped_roots.size(); + return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(), + /*y=*/static_cast(blocks_y), /*z=*/1), + se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(), + /*y=*/1, /*z=*/1)}; +} + +ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + Shape input_shape = hero_reduction->operand(0)->shape(); + ReductionDimensions reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + auto shape = reduction_dimensions.dimensions; + VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction + << " " << shape[0] << " " << shape[1] << " " << shape[2]; + Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); + + int64_t num_threads_y = + reduction_dimensions.is_row_reduction ? 1 : WarpSize(); + int64_t rows_per_warp = + reduction_dimensions.is_row_reduction + ? RowReductionGetRowsPerWarp( + shape[ReductionDimensions::kRowMinorReducedDimension]) + : 1; + int64_t num_threads_x = [&] { + if (reduction_dimensions.is_row_reduction) { + if (rows_per_warp > 1) { + return shape[ReductionDimensions::kRowMinorReducedDimension]; + } + int64_t max_block_size = + MinThreadsXRowReduction(hero_reduction->GetModule()->config()); + return std::min( + max_block_size, + RoundUpTo( + CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension], + reduction_tiling + [ReductionDimensions::kRowMinorReducedDimension]), + WarpSize())); + } + return WarpSize(); + }(); + + // If we're limited by the size of the x dimension, add additional parallelism + // in the y dimension. The code generator doesn't currently support + // parallelizing the z dimension (major reduced dimensions). The general + // recommendation is to use between 128 and 512 threads, so we just go for + // 256. See https://forums.developer.nvidia.com/t/55529 + constexpr int64_t kThreadsPerBlockTarget = 256; + if (reduction_dimensions.is_row_reduction && + num_threads_x * 2 <= kThreadsPerBlockTarget) { + int64_t kept_size = + reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension]; + // Increase the size of the y dimension as long as there's remaining + // parallelism. + if (kept_size * num_threads_x <= kThreadsPerBlockTarget) { + num_threads_y = kept_size; + // num_threads_x is a power of two, but it may be less than 32. If dim_y + // is also small, we may have to increase the bound so the total number of + // threads is a multiple of 32. + while ((num_threads_x * num_threads_y) % 32) ++num_threads_y; + } else { + num_threads_y = kThreadsPerBlockTarget / num_threads_x; + } + } + + int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x, + reduction_tiling); + + absl::InlinedVector num_threads{1, num_threads_y, num_threads_x}; + absl::InlinedVector tiled_shape{shape[0], shape[1], + shape[2] / vector_size}; + absl::InlinedVector tile_per_thread{ + reduction_tiling[0], reduction_tiling[1], + reduction_tiling[2] / vector_size}; + if (rows_per_warp > 1) { + // If we produce more than one element per thread, that means the reduced + // dimension is small and it can't be tiled - we already have more threads + // in a warp than the size of the reduced dimension. The code generator + // doesn't currently support tiling the kept dimension, because it just + // uses the thread ID as the coordinate. + tile_per_thread[2] = 1; + } + if (vector_size != 1) { + num_threads.push_back(1); // The vector dimension is a loop. + tiled_shape.push_back(vector_size); + tile_per_thread.push_back(vector_size); + } + + Tiling tiling(tiled_shape, tile_per_thread, num_threads, + /*loops_to_unroll=*/{false, false, true, false}); + bool reduction_is_race_free = ReductionIsRaceFree( + hero_reduction->GetModule()->config(), reduction_dimensions); + return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction, + reduction_is_race_free, + GroupDisjointReductions(analysis), hero_reduction); +} + +std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + if (!groups_.is_reduction_root[root_index]) { + // Non-transpose roots are elementwise by definition. + return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + } + auto* root = analysis_.fusion_roots()[root_index]; + auto* hero = analysis_.fusion_heroes()[root_index]; + + auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx); + auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), + tiling_.GetThreadsPerBlock(), + tiling_.GetThreadStrides()); + + auto physical_shape = ShapeUtil::DeleteDimensions(hero->dimensions(), + hero->operand(0)->shape()); + std::vector dimension_ranges{ + {{0, tiling_.GetNumThreadsPerBlock() - 1}}, + {}, + {}, + {{0, tiling_.GetNumBlocks() - 1}}, + {{0, static_cast(groups_.grouped_roots.size() - 1)}}, + {}, + }; + + constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; + constexpr int kRowMinorReduced = + ReductionDimensions::kRowMinorReducedDimension; + + constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension; + constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension; + constexpr int kColReduced = ReductionDimensions::kColReducedDimension; + + auto physical_index = [&]() { + if (is_row_reduction_) { + IndexingMap linear_index( + mlir::AffineMap::get( + 6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept], + ctx), + dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{}); + int rows_per_warp = GetRowsPerWarp(); + if (rows_per_warp > 1) { + linear_index.AddConstraint( + thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp), + {0, 0}); + } else { + linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0}); + } + return ComposeIndexingMaps( + linear_index, GetBitcastMap(ShapeUtil::MakeShape( + PRED, {tiling_.GetShape()[kRowKept]}), + physical_shape, ctx)); + } + + IndexingMap projected_index( + mlir::AffineMap::get( + 6, 0, + {block_offsets.getResult(kColMajorKept), + block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]}, + ctx), + dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{}); + + projected_index.AddConstraint( + mlir::getAffineDimExpr( + KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) % + WarpSize(), + {0, 0}); + if (!is_row_reduction_) { + projected_index.AddConstraint( + projected_index.GetAffineMap().getResult(1), + {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] - + 1}); + } + + return ComposeIndexingMaps( + projected_index, + GetBitcastMap(ShapeUtil::DeleteDimension( + ReductionDimensions::kColReducedDimension, + tiling_.GetXlaShape()), + physical_shape, ctx)); + }(); + + auto map = ComposeIndexingMaps( + physical_index, + GetBitcastMap(FirstShape(hero->shape()), FirstShape(root->shape()), ctx)); + + int group_index = groups_.group_id_per_root[root_index]; + map.AddConstraint( + mlir::getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[1], + ctx), + {group_index, group_index}); + return map; +} + +std::optional ReductionInfo::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + auto* hero = analysis_.fusion_heroes()[root_index]; + if (groups_.is_reduction_root[root_index] && + hero_operand_index >= hero->operand_count() / 2) { + // We don't have indexing for the init values. + return std::nullopt; + } + + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling(tiling_, ctx), + GetBitcastMap(tiling_.GetXlaShape(), + hero->operand(hero_operand_index)->shape(), ctx)); + // Only threads with the right y block index actually do anything for this + // root. + int group_index = groups_.group_id_per_root[root_index]; + map.AddConstraint( + mlir::getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[1], + ctx), + {group_index, group_index}); + return map; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/reduction_base.h b/xla/service/gpu/fusions/reduction_base.h new file mode 100644 index 0000000000000..93c2ecc2681f8 --- /dev/null +++ b/xla/service/gpu/fusions/reduction_base.h @@ -0,0 +1,123 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_BASE_H_ +#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_BASE_H_ + +#include +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" + +namespace xla { +namespace gpu { + +struct ReductionGroups { + std::vector> grouped_roots; + + // For each root of the fusion, returns the index of the group it was placed + // in. + std::vector group_id_per_root; + + // For each root of the fusion, returns whether it is a reduction root, or + // an additional output. + std::vector is_reduction_root; +}; + +class ReductionInfo { + public: + static ReductionInfo Create(const HloFusionAnalysis& analysis); + + const Tiling& GetTiling() const { return tiling_; } + const ReductionGroups& GetGroups() const { return groups_; } + Shape GetReduceOperandShape() const { + return first_reduce_->operand(0)->shape(); + } + + bool IsRowReduction() const { return is_row_reduction_; } + bool IsRaceFree() const { return is_race_free_; } + int GetRowsPerWarp() const; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const; + + LaunchDimensions launch_dimensions() const; + + private: + ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling, + bool is_row_reduction, bool is_race_free, + ReductionGroups groups, const HloInstruction* first_reduce) + : analysis_(analysis), + tiling_(tiling), + is_row_reduction_(is_row_reduction), + is_race_free_(is_race_free), + groups_(std::move(groups)), + first_reduce_(first_reduce) {} + + const HloFusionAnalysis& analysis_; + Tiling tiling_; + bool is_row_reduction_; + bool is_race_free_; + ReductionGroups groups_; + const HloInstruction* first_reduce_; +}; + +// Base class for reduction fusions. Computes shared information (reduction +// grouping) and provides implementations of thread->input/output indexing. +template +class ReductionFusionBase : public Base { + public: + explicit ReductionFusionBase(const HloFusionAnalysis& analysis) + : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {} + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override { + return reduction_info().ComputeThreadIdToOutputIndexing(root_index, ctx); + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override { + return reduction_info().ComputeThreadIdToInputIndexing( + root_index, hero_operand_index, ctx); + } + + LaunchDimensions launch_dimensions() const override { + return reduction_info().launch_dimensions(); + } + + const ReductionInfo& reduction_info() const { return reduction_info_; } + + const HloFusionAnalysis& analysis() const { return analysis_; } + + private: + const HloFusionAnalysis& analysis_; + ReductionInfo reduction_info_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_BASE_H_ diff --git a/xla/service/gpu/fusions/reduction_base_test.cc b/xla/service/gpu/fusions/reduction_base_test.cc new file mode 100644 index 0000000000000..2c4ffa0e9ce07 --- /dev/null +++ b/xla/service/gpu/fusions/reduction_base_test.cc @@ -0,0 +1,405 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/reduction_base.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +class ReductionTest : public HloTestBase { + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +class FakeReductionFusion : public ReductionFusionBase { + using ReductionFusionBase::ReductionFusionBase; + absl::StatusOr Emit( + IrEmitterContext&, const HloFusionInstruction&) const override { + return absl::UnimplementedError("Unimplemented"); + } +}; + +std::unique_ptr GetReductionFusion( + const HloFusionAnalysis& analysis) { + return std::make_unique(analysis); +} + +TEST_F(ReductionTest, ThreadIndexingRowReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64, + d0 mod 32 + s2 * 32 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 15] + 0 in [0, 0] + d0 mod 32 + s2 * 32 in [0, 511] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + (d3 * 8 + d0 floordiv 32) mod 64 in [0, 63] + d0 mod 32 in [0, 0] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,4] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,4] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 + (d0 floordiv 4) floordiv 64, + (d0 floordiv 4) mod 64, + d0 mod 4 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 99] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 0] + 0 in [0, 0] + d0 mod 4 in [0, 3] + d3 * 64 + d0 floordiv 4 in [0, 6399] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + d3 + (d0 floordiv 4) floordiv 64, + (d0 floordiv 4) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 99] + d4 in [0, 0] + d5 in [0, 0] + (d0 floordiv 4) mod 64 in [0, 63] + d0 mod 4 in [0, 0] + d3 * 64 + d0 floordiv 4 in [0, 6399] + d3 + (d0 floordiv 4) floordiv 64 in [0, 99] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingColumnReduction) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,32] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,32] reduce(%input, %c0), dimensions={1}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[100,32] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3, + d0 floordiv 32 + s1 * 32, + d0 mod 32 + ) + domain: + d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] + d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] + s0 in [0, 0] s1 in [0, 127] s2 in [0, 0] + d0 floordiv 32 + s1 * 32 in [0, 63] + d0 mod 32 in [0, 31] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + d3, + d0 floordiv 32 + ) + domain: + d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0] + d3 in [0, 99] d4 in [0, 0] d5 in [0, 0] + d0 mod 32 in [0, 0] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingOutputLayout) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64]{0,1} fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5) -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + (d3 * 8 + d0 floordiv 32) mod 64 in [0, 63] + d0 mod 32 in [0, 0] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )")); +} + +TEST_F(ReductionTest, ThreadIndexingSideOutput) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + %log = f32[100,64,512] log(%input) + %reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + ROOT tuple = (f32[100,64], f32[100,64,512]) tuple(%reduce, %log) + } + + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = (f32[100,64], f32[100,64,512]) fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + constexpr char kExpectedIndexing[] = R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 * 8 + d0 floordiv 32) floordiv 64, + (d3 * 8 + d0 floordiv 32) mod 64, + d0 mod 32 + s2 * 32 + ) + domain: + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 15] + 0 in [0, 0] + d0 mod 32 + s2 * 32 in [0, 511] + d3 * 8 + d0 floordiv 32 in [0, 6399] + )"; + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), + MatchIndexingString(kExpectedIndexing)); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), + MatchIndexingString(kExpectedIndexing)); +} + +TEST_F(ReductionTest, bla) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[1024, 8192] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(f32[1024, 8192] %input, f32[] %c0), + dimensions={1}, to_apply=add + } + ENTRY entry { + %input = f32[1024, 8192] parameter(0) + ROOT %fusion = f32[1024] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + FakeReductionFusion fusion(analysis); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( + d3, + (d0 + s2 * 512) * 2 + s3 + ) + domain: + d0 in [0, 511] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1023] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 7] + s3 in [0, 1] + 0 in [0, 0] + d0 + s2 * 512 in [0, 4095] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc new file mode 100644 index 0000000000000..1ff23dddcf51b --- /dev/null +++ b/xla/service/gpu/fusions/reduction_mlir.cc @@ -0,0 +1,364 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/reduction_mlir.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/gpu/fusions/reduction_base.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/reduction_utils.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +using llvm::SmallVector; +using mlir::Value; +using mlir::ValueRange; +using mlir_converter::PartitionedComputation; +using mlir_converter::PartitionedComputations; + +struct MlirReductionFusion::EmitterState { + // Uses the given indexing map to reduce a subset of the inputs in a single + // thread. The subset may be a single element. + absl::StatusOr> EmitPerThreadReducedElements( + const IndexingMap& input_indexing, const HloInstruction* hero, + ValueRange inits); + + mlir::func::FuncOp GetReducer(const HloInstruction* hero) const { + return call_target(hero->called_computations()[0]->root_instruction()); + } + + SmallVector AllocateSharedTiles(const HloInstruction* hero, + absl::Span shape); + + SmallVector FusionParams() { + return ValueRange(entry_function.getArguments().take_front( + fusion.fused_parameters().size())); + } + + const MlirReductionFusion& owner; + mlir::func::FuncOp entry_function; + const HloFusionInstruction& fusion; + const PartitionedComputations& computations; + const mlir_converter::CallTargetProvider& call_target; + mlir::ImplicitLocOpBuilder builder; +}; + +MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) + : ReductionFusionBase(analysis) { + for (auto [index, hero] : llvm::enumerate(analysis.fusion_heroes())) { + if (reduction_info().GetGroups().is_reduction_root[index]) { + reduction_roots_[hero].push_back(index); + } + } + + for (const auto& [hero, _] : reduction_roots_) { + reduction_heroes_.push_back(hero); + } +} + +bool MlirReductionFusion::IsSupported(const HloFusionAnalysis& analysis) { + auto info = ReductionInfo::Create(analysis); + return info.GetGroups().grouped_roots.size() == 1 && + !absl::c_linear_search(info.GetGroups().is_reduction_root, false) && + info.IsRaceFree(); +} + +std::vector +MlirReductionFusion::GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return reduction_heroes_; +} + +absl::Status MlirReductionFusion::EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + // Reduction groups will probably be implemented in a separate pass, since + // they share nothing by definition. + TF_RET_CHECK(reduction_info().GetGroups().grouped_roots.size() == 1) + << "Only one reduction group is supported."; + EmitterState state{*this, entry_function, + fusion, computations, + call_targets, {entry_function.getLoc(), entry_function}}; + state.builder.setInsertionPointToStart(entry_function.addEntryBlock()); + return EmitReduction(state); +} + +absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const { + CHECK(IsSupported(analysis())) + << "Attempting to output code for an unsupported reduction"; + auto& builder = state.builder; + const auto& tiling = reduction_info().GetTiling(); + + // The number of warps working on one element in a row reduction. + int num_warps_row = tiling.GetThreadsPerBlock() + [ReductionDimensions::kRowMinorReducedDimension] / + WarpSize(); + auto ctx = state.entry_function.getContext(); + + auto zero = builder.create(0); + auto lane_id = builder.create(); + auto is_first_lane = builder.create( + mlir::arith::CmpIPredicate::eq, lane_id, zero); + auto thread_id = EmitThreadId(builder, 0); + auto block_id = EmitBlockId(builder, 0); + Value cstTrue = builder.create( + builder.getIntegerAttr(builder.getI1Type(), 1)); + + auto thread_ids = mlir_converter::ApplyAffineMap( + mlir::AffineMap::get( + /*dimCount=*/1, /*symbolCount=*/0, + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), + tiling.GetThreadsPerBlock(), + tiling.GetThreadStrides()), + ctx), + {thread_id}, {}, builder); + SmallVector thread_and_block_indices{thread_id, zero, zero, + block_id, zero, zero}; + + auto warp_id = builder.create( + reduction_info().IsRowReduction() + ? thread_ids[ReductionDimensions::kRowMinorReducedDimension] + : thread_id, + builder.create(WarpSize())); + + auto output_args = state.entry_function.getArguments().drop_front( + state.fusion.fused_parameters().size()); + + std::vector shared_tile_size; + SmallVector shared_write_indices; + SmallVector shared_read_indices; + Value shared_write_condition = cstTrue; + Value shared_read_condition = cstTrue; + if (!reduction_info().IsRowReduction()) { + shared_tile_size = {WarpSize(), WarpSize() + 1}; + shared_write_indices = {lane_id, warp_id}; + shared_read_indices = {warp_id, lane_id}; + } else if (reduction_info().GetRowsPerWarp() == 1 && num_warps_row > 1) { + auto kKept = ReductionDimensions::kRowKeptDimension; + shared_tile_size = {tiling.GetThreadsPerBlock()[kKept], num_warps_row}; + shared_write_condition = is_first_lane; + shared_read_condition = builder.create( + mlir::arith::CmpIPredicate::ult, + thread_ids[ReductionDimensions::kRowMinorReducedDimension], + builder.create(num_warps_row)); + shared_write_indices = {thread_ids[kKept], warp_id}; + shared_read_indices = {thread_ids[kKept], lane_id}; + } + bool use_shared = !shared_tile_size.empty(); + + auto output_indexing = ComputeThreadIdToOutputIndexing(0, ctx); + auto output_indices = mlir_converter::ApplyAffineMap( + output_indexing->GetAffineMap(), thread_and_block_indices, {}, builder); + auto thread_has_output = mlir_converter::CheckConstraints( + *output_indexing, thread_and_block_indices, {}, builder); + + llvm::DenseMap> inits; + for (auto [index, hero] : llvm::enumerate(reduction_heroes_)) { + int num_inputs = hero->operand_count() / 2; + const auto& computation = + state.computations.FindPartitionedComputation(hero->parent()); + inits[hero] = ProvideParameterRange( + computation.FindSubgraph(hero), hero, num_inputs, num_inputs, {}, + state.call_target, state.entry_function, builder); + } + + auto evaluate_epilogue = + [&](SmallVector> results) -> mlir::ValueRange { + if (!state.computations.epilogue()) { + return results.front(); + } + + llvm::SmallVector hero_values; + for (const auto& result : results) { + CHECK(result.size() == 1) + << "Epilogue fusions are not supported with variadic reduce."; + hero_values.push_back(result.front()); + } + return EmitEpilogue(state.computations, state.entry_function, hero_values, + output_indices, builder); + }; + + SmallVector updated_outputs; + SmallVector> results; + for (auto* hero : reduction_heroes_) { + auto input_indexing = ComputeThreadIdToInputIndexing( + reduction_roots_.at(hero).front(), 0, ctx); + TF_ASSIGN_OR_RETURN( + auto accumulated, + state.EmitPerThreadReducedElements(*input_indexing, hero, inits[hero])); + + // In row reductions, we can do a warp shuffle before writing to shared + // memory. In column reductions, the members of the warp process different + // output elements, so we need to transpose first. + if (reduction_info().IsRowReduction()) { + auto reducer = state.GetReducer(hero); + int max_dist = WarpSize() / 2 / reduction_info().GetRowsPerWarp(); + accumulated = + builder.create(reducer, accumulated, max_dist) + .getResults(); + } + + results.push_back(accumulated); + } + + if (use_shared) { + // Write results to shared memory. + for (auto [hero, result] : llvm::zip(reduction_heroes_, results)) { + auto dest = state.AllocateSharedTiles(hero, shared_tile_size); + for (auto [value, output] : llvm::zip(result, dest)) { + updated_outputs.push_back(builder.create( + shared_write_condition, value, output, shared_write_indices)); + } + } + } else { + // Evaluate the epilogue, if there is one. + auto result_scalars = evaluate_epilogue(results); + for (auto [value, output] : llvm::zip(result_scalars, output_args)) { + updated_outputs.push_back(builder.create( + thread_has_output, value, output, output_indices)); + } + builder.create(updated_outputs); + return absl::OkStatus(); + } + + // Wait for the entire tile to be written. + auto shared_tiles = builder + .create( + mlir::TypeRange(updated_outputs), updated_outputs) + .getResults(); + auto write_outputs = [&](mlir::OpBuilder then_builder, mlir::Location loc) { + results.clear(); + mlir::ImplicitLocOpBuilder b(loc, then_builder); + int tile_index = 0; + llvm::SmallVector updated_outputs; + for (auto* hero : reduction_heroes_) { + // Load from shared memory. + SmallVector reduced; + for (auto init : inits[hero]) { + // If a warp didn't write anything, use the init values instead. + reduced.push_back(b.create( + shared_read_condition, init, + shared_tiles[tile_index++], shared_read_indices) + .getResult()); + } + + reduced = builder + .create(state.GetReducer(hero), reduced, + WarpSize() / 2) + .getResults(); + results.push_back(reduced); + } + + auto result_scalars = evaluate_epilogue(results); + + for (auto [output_value, dest] : llvm::zip(result_scalars, output_args)) { + updated_outputs.push_back(b.create( + thread_has_output, output_value, dest, output_indices)); + } + b.create(loc, updated_outputs); + }; + + auto warp_writes = reduction_info().IsRowReduction() + ? builder.create( + mlir::arith::CmpIPredicate::eq, warp_id, zero) + : cstTrue; + auto written = builder.create( + warp_writes, write_outputs, [&](mlir::OpBuilder b, mlir::Location loc) { + b.create(loc, output_args); + }); + builder.create(written.getResults()); + + return absl::OkStatus(); +} + +absl::StatusOr> +MlirReductionFusion::EmitterState::EmitPerThreadReducedElements( + const IndexingMap& input_indexing, const HloInstruction* hero, + ValueRange inits) { + auto body_builder = [&](ValueRange outputs, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto indices = mlir_converter::ApplyAffineMap( + input_indexing.GetAffineMap(), dim_values, symbol_values, builder); + auto operands = FusionParams(); + absl::c_copy(indices, std::back_inserter(operands)); + auto values = ProvideParameterRange(computations.FindSubgraph(hero), hero, + 0, hero->operand_count() / 2, indices, + call_target, entry_function, builder); + + SmallVector reduce_args = outputs; + reduce_args.append(values.begin(), values.end()); + return builder.create(GetReducer(hero), reduce_args) + .getResults(); + }; + return owner.EmitThreadLoopNest(builder, inits, input_indexing, body_builder); +} + +SmallVector MlirReductionFusion::EmitterState::AllocateSharedTiles( + const HloInstruction* hero, absl::Span shape) { + SmallVector tiles; + for (int i = 0; i < hero->operand_count() / 2; ++i) { + tiles.push_back( + builder.create(mlir_converter::TensorShapeToMlirType( + ShapeUtil::MakeShapeWithDescendingLayout( + hero->operand(i)->shape().element_type(), shape), + builder))); + } + return tiles; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/reduction_mlir.h b/xla/service/gpu/fusions/reduction_mlir.h new file mode 100644 index 0000000000000..e321d4da97bbc --- /dev/null +++ b/xla/service/gpu/fusions/reduction_mlir.h @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/fusions/reduction_base.h" + +namespace xla { +namespace gpu { + +// Reduction fusion. Lowers to LLVM via MLIR. Currently not fully +// implemented: only single reduction groups, no side outputs, only row +// reductions. +class MlirReductionFusion : public ReductionFusionBase { + public: + explicit MlirReductionFusion(const HloFusionAnalysis& analysis); + + static bool IsSupported(const HloFusionAnalysis& analysis); + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const override; + + private: + struct EmitterState; + friend struct EmitterState; + + absl::Status EmitReduction(EmitterState& state) const; + + std::vector reduction_heroes_; + // The root indices for each reduction hero. + absl::flat_hash_map> reduction_roots_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc new file mode 100644 index 0000000000000..23edf7ed08f64 --- /dev/null +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -0,0 +1,345 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/reduction_mlir.h" + +#include + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "tsl/lib/core/status_test_util.h" + +namespace xla { +namespace gpu { +namespace { + +using ReductionTest = MlirEmitterTestBase; + +TEST_F(ReductionTest, VariadicRowReduce) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + scalar_lhs.1 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) + add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) + } + fused_computation { + param_0 = f32[5,200,2048] parameter(0) + param_1 = f32[5,200,2048] parameter(1) + param_2 = f32[] parameter(2) + ROOT d.1 = (f32[5,200], f32[5,200]) + reduce(param_0, param_1, param_2, param_2), dimensions={2}, to_apply=Add + } + ENTRY main { + a = f32[5, 200, 2048] parameter(0) + b = f32[5, 200, 2048] parameter(1) + c = f32[] constant(0) + ROOT fusion = (f32[5,200], f32[5,200]) fusion(a, b, c), + kind=kInput, calls=fused_computation + })"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( +// CHECK: @fused_computation +// CHECK-SAME: %[[ARG0:.*]]: tensor<5x200x2048xf32> {xla.slice_index = 0 +// CHECK-SAME: %[[ARG1:.*]]: tensor<5x200x2048xf32> {xla.slice_index = 1 +// CHECK-SAME: %[[INIT_TENSOR:.*]]: tensor {xla.slice_index = 2 +// CHECK-SAME: %[[OUT0:.*]]: tensor<5x200xf32> {xla.slice_index = 3 +// CHECK-SAME: %[[OUT1:.*]]: tensor<5x200xf32> {xla.slice_index = 4 +// CHECK: %[[INIT:.*]] = xla_gpu.pure_call @fused_computation_param_2 +// CHECK: %[[PER_THREAD:.*]]:2 = scf.for +// CHECK-SAME: iter_args(%[[A:.*]] = %[[INIT]], %[[B:.*]] = %[[INIT]]) +// CHECK: %[[A2:.*]] = xla_gpu.pure_call @fused_computation_param_0 +// CHECK: %[[B2:.*]] = xla_gpu.pure_call @fused_computation_param_1 +// CHECK: xla_gpu.pure_call @Add_t(%[[A]], %[[B]], %[[A2]], %[[B2]]) +// CHECK: %[[SHUFFLED:.*]]:2 = xla_gpu.shuffle_reduce +// CHECK-SAME: @Add_t(%[[PER_THREAD]]#0, %[[PER_THREAD]]#1) to 16 +// CHECK: %[[A_SHARED:.*]] = xla_gpu.allocate_shared : tensor<2x4xf32> +// CHECK: %[[B_SHARED:.*]] = xla_gpu.allocate_shared : tensor<2x4xf32> +// CHECK: predicated_insert %[[SHUFFLED]]#0 into %[[A_SHARED]] +// CHECK: predicated_insert %[[SHUFFLED]]#1 into %[[B_SHARED]] +// CHECK: sync_threads +// CHECK-NOT: shuffle_reduce) + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, RowReduceEpilogue) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add + ROOT log = f32[8] log(reduce) + } + ENTRY main { + a = f32[8,2048] parameter(0) + c = f32[] constant(0) + ROOT fusion = f32[8] fusion(a, c), kind=kInput, calls=fused_computation + })"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: pure_call @Add_add + // CHECK: shuffle_reduce + // CHECK: allocate_shared + // CHECK: sync_threads + // CHECK: shuffle_reduce + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, RowReduceMOFEpilogue) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + Mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) + } + fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + reduce1 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add + reduce2 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Mul + log = f32[8] log(reduce1) + abs = f32[8] abs(reduce1) + neg = f32[8] negate(reduce2) + ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) + } + ENTRY main { + a = f32[8,2048] parameter(0) + c = f32[] constant(0) + ROOT fusion = (f32[8], f32[8], f32[8]) fusion(a, c), kind=kInput, + calls=fused_computation + })"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-DAG: pure_call @Add_add + // CHECK-DAG: shuffle_reduce @Add_add + // CHECK-DAG: pure_call @Mul_mul + // CHECK-DAG: shuffle_reduce @Mul_mul + // CHECK: allocate_shared + // CHECK: allocate_shared + // CHECK: sync_threads + // CHECK-DAG: shuffle_reduce @Add_add + // CHECK-DAG: shuffle_reduce @Mul_mul + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, ColumnReduction) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + fused_computation { + param_0 = f32[123,2051,321] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[123,321] reduce(param_0, param_1), dimensions={1}, to_apply=Add + } + ENTRY main { + a = f32[123,2051,321] parameter(0) + c = f32[] constant(0) + ROOT fusion = f32[123,321] fusion(a, c), kind=kInput, calls=fused_computation + })"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: xla_gpu.pure_call @Add_add + // CHECK: allocate_shared + // CHECK: predicated_insert + // CHECK: sync_threads + // CHECK: predicated_extract + // CHECK: shuffle_reduce + // CHECK: predicated_insert + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, SmallColumnReduction) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + fused_computation { + param_0 = f32[3,128,4] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[3,4] reduce(param_0, param_1), dimensions={1}, to_apply=Add + } + ENTRY main { + a = f32[3,128,4] parameter(0) + c = f32[] constant(0) + ROOT fusion = f32[3,4] fusion(a, c), kind=kInput, calls=fused_computation + })"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, F64RowReduction) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) + } + fused_computation { + param_0 = f64[100,128] parameter(0) + param_1 = f64[] parameter(1) + ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add + } + ENTRY main { + a = f64[100,128] parameter(0) + c = f64[] constant(0) + ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation + })"; + // This reduction is small enough not to require shared memory. + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-NOT: allocate_shared + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, MultiRowReduction) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + fused_computation { + param_0 = f32[1024,4] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[1024] reduce(param_0, param_1), dimensions={1}, to_apply=Add + } + ENTRY main { + a = f32[1024,4] parameter(0) + c = f32[] constant(0) + ROOT fusion = f32[1024] fusion(a, c), kind=kInput, calls=fused_computation + })"; + // Multi-row reductions don't use shared memory. + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: shuffle_reduce {{.*}} to 2 + // CHECK-NOT: allocate_shared + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, NonPowerOfTwoRowReduction) { + constexpr auto kHloString = R"( + HloModule Test, is_scheduled=true + + Add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) + } + fused_computation { + param_0 = f64[100,568] parameter(0) + param_1 = f64[] parameter(1) + ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add + } + ENTRY main { + a = f64[100,568] parameter(0) + c = f64[] constant(0) + ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation + })"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: allocate_shared + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, MixedIndexing) { + constexpr auto kHloString = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %param_0 = f32[64,128] parameter(0) + %constant_0 = f32[] constant(0) + %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add + %neg = f32[64,128] negate(f32[64,128] %param_0) + %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg) + %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add + ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2) + } + ENTRY entry { + %param_0 = f32[64,128] parameter(0) + ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion + })"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ReductionTest, NonTrivialEpilogue) { + constexpr auto kHloString = R"( + HloModule module + add { + p0 = f64[] parameter(0) + p1 = f64[] parameter(1) + ROOT add = f64[] add(p0, p1) + } + fusion { + %p0 = f64[4] parameter(0) + %p1 = f64[4] parameter(1) + %c0 = f64[] constant(-inf) + %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add + %bc0 = f64[4] broadcast(reduce0), dimensions={} + %compare0 = pred[4] compare(p1, bc0), direction=EQ + %c1 = f64[] constant(0) + %bc1 = f64[4] broadcast(c1), dimensions={} + %select.3.1 = f64[4] select(compare0, p0, bc1) + %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add + %convert0 = f64[4] convert(compare0) + %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add + ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2) + } + ENTRY main { + %p0 = f64[4] parameter(0) + %p1 = f64[4] parameter(1) + ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput, + calls=fusion + })"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/scatter.cc b/xla/service/gpu/fusions/scatter.cc new file mode 100644 index 0000000000000..0625f9efd4653 --- /dev/null +++ b/xla/service/gpu/fusions/scatter.cc @@ -0,0 +1,286 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/scatter.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/ir_emitter_nested.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/parallel_loop_emitter.h" +#include "xla/service/llvm_ir/fused_ir_emitter.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) { + CHECK_EQ(analysis.fusion_roots().size(), 1); + CHECK_EQ(analysis.fusion_roots()[0]->opcode(), HloOpcode::kScatter); +} + +LaunchDimensions ScatterFusion::launch_dimensions() const { + const auto& updates_shape = + analysis_.fusion_roots().front()->operands().back()->shape(); + return CalculateLaunchDimensions(updates_shape, analysis_.device_info()); +} + +absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const { + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); + // Spin up a new fused emitter for the scatter kernel and emit it. + FusedIrEmitter scatter_fused_emitter(elemental_emitter); + auto* fused_computation = fusion.fused_instructions_computation(); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + auto fused_operand = fused_computation->parameter_instruction(i); + scatter_fused_emitter.BindGenerator( + *fused_operand, [builder, &input = inputs[i], + fused_operand](llvm_ir::IrArray::Index index) { + return input.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); + } + + auto* root = fused_computation->root_instruction(); + const xla::ScatterDimensionNumbers& scatter_dims = + Cast(root)->scatter_dimension_numbers(); + + std::string name = llvm_ir::IrName(root); + const Shape& operand_shape = root->operand(0)->shape(); + const Shape& scatter_indices_shape = root->operand(1)->shape(); + const Shape& updates_shape = root->operand(2)->shape(); + const HloComputation& update_computation = *root->called_computations()[0]; + + TF_ASSIGN_OR_RETURN(auto scatter_indices_gen, + scatter_fused_emitter.GetGenerator(*root->operand(1))); + TF_ASSIGN_OR_RETURN(auto updates_gen, + scatter_fused_emitter.GetGenerator(*root->operand(2))); + + auto loop_body_emitter = + [&](const llvm_ir::IrArray::Index& index) -> absl::Status { + std::vector raw_window_multidim; + std::vector input_scatter_multidim; + std::vector raw_window_bounds; + + auto get_i64_array = [](absl::Span container) { + return llvm::ArrayRef{container.data(), + static_cast(container.size())}; + }; + + llvm::ArrayRef update_window_dims = + get_i64_array(scatter_dims.update_window_dims()); + // Partition the index into window indices and scatter indices. + for (int64_t i = 0, e = index.size(); i != e; ++i) { + // For window indices also remember the window size, this comes in handy + // later. + if (llvm::is_contained(update_window_dims, i)) { + raw_window_multidim.push_back(index[i]); + raw_window_bounds.push_back(updates_shape.dimensions(i)); + } else { + input_scatter_multidim.push_back(index[i]); + } + } + DCHECK_EQ(raw_window_multidim.size(), + scatter_dims.update_window_dims_size()); + + // Apply inserted_window_dims to the window dimensions. + int64_t raw_window_multidim_idx = 0; + llvm::SmallVector input_window_multidim; + llvm::SmallVector input_window_bounds; + const int64_t rank = operand_shape.rank(); + input_window_bounds.reserve(rank); + input_window_multidim.reserve(rank); + + llvm::ArrayRef inserted_window_dims = + get_i64_array(scatter_dims.inserted_window_dims()); + for (int64_t i = 0; i != rank; ++i) { + if (llvm::is_contained(inserted_window_dims, i)) { + input_window_bounds.push_back(1); // Trivial dimension. + input_window_multidim.push_back(index.GetConstantWithIndexType(0)); + } else { + input_window_bounds.push_back( + raw_window_bounds[raw_window_multidim_idx]); + input_window_multidim.push_back( + raw_window_multidim[raw_window_multidim_idx]); + ++raw_window_multidim_idx; + } + } + DCHECK_EQ(input_window_multidim.size(), operand_shape.rank()); + + // Insert a 1 dimension at the end if index_vector_dim requests one. + Shape scatter_indices_shape_fixed = scatter_indices_shape; + if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) { + scatter_indices_shape_fixed.add_dimensions(1); + scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major( + scatter_dims.index_vector_dim()); + } + + // Now load the indices corresponding to the current window from + // scatter_indices. + std::vector raw_scatter_index_multidim = + input_scatter_multidim; + raw_scatter_index_multidim.insert( + raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(), + nullptr); + + llvm::ArrayRef scatter_dims_to_operand_dims = + get_i64_array(scatter_dims.scatter_dims_to_operand_dims()); + llvm::Value* is_in_bounds = builder->getTrue(); + for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) { + // Our index is stored along index_vector_dim, insert that into the lookup + // index into scatter_indices. + raw_scatter_index_multidim[scatter_dims.index_vector_dim()] = + index.GetConstantWithIndexType(i); + llvm_ir::IrArray::Index raw_scatter_index_index( + raw_scatter_index_multidim, scatter_indices_shape_fixed, + index.GetType()); + + int64_t operand_dim = scatter_dims_to_operand_dims[i]; + if (operand_dim > rank) { + return absl::OutOfRangeError( + "The provided scatter_dims_to_operand_dims was out of range."); + } + TF_ASSIGN_OR_RETURN( + llvm::Value* const loaded_scatter_index, + scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( + scatter_indices_shape_fixed, scatter_indices_shape, builder))); + // And add the index to our window index. This yields the output index. + llvm::Value* casted_scatter_index = builder->CreateIntCast( + loaded_scatter_index, index.GetType(), + /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape)); + llvm::Value* dim_offset = builder->CreateAdd( + input_window_multidim[operand_dim], casted_scatter_index); + input_window_multidim[operand_dim] = dim_offset; + + // Also do the bounds check now. + int64_t max_index = operand_shape.dimensions(operand_dim) - + input_window_bounds[operand_dim] + 1; + // is_in_bounds = index >= 0 && index < dim_size-window_size+1 + // --> index u< dim_size-window_size+1 + is_in_bounds = builder->CreateAnd( + is_in_bounds, + builder->CreateICmpULT(casted_scatter_index, + index.GetConstantWithIndexType(max_index))); + } + + llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( + is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false); + llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, + builder); + // All done, now just read from the calculated input from the window, and do + // an atomic store to the calculated location in the output. + llvm_ir::IrArray::Index input_window_index( + input_window_multidim, outputs.back().GetShape(), index.GetType()); + llvm::Value* output_address = + outputs.back().EmitArrayElementAddress(input_window_index, builder); + llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(), + ir_emitter_context.llvm_module()), + "input_address", builder); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); + builder->CreateStore(input_ir_value, input_address); + + if (root->unique_indices()) { + return CallNestedComputation( + builder, ir_emitter_context, update_computation, + {output_address, input_address}, output_address); + } + return EmitAtomicOperationForNestedComputation( + builder, ir_emitter_context, update_computation, output_address, + input_address, outputs.back().GetElementLlvmType()); + }; + + // Launch a kernel that reads every element in the updates tensor. We could + // also do one kernel per window instead if bounds checks turn out to be a + // bottleneck. + auto index_type = + GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder); + return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims, + builder) + .EmitLoop(name, index_type); +} + +std::optional ScatterFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + auto* scatter = + DynCast(analysis_.fusion_heroes().front()); + int64_t scatter_operand_count = scatter->scatter_operand_count(); + // Scatter operands a packed in the following way: + // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`. + // Operand ID scatter_operand_count for `scatter indices`. + // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for + // `scatter updates`. + + // For scatter operands we do not know the thread ID indexing. + if (hero_operand_index < scatter_operand_count) { + return std::nullopt; + } + // Compute thread id mapping based on the first update operand. + Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); + IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap( + launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx); + + // For scatter indices we project indexing for scatter updates and take the + // first result of the affine map only, because they coincide. + if (hero_operand_index == scatter_operand_count) { + Shape scatter_indices_shape = scatter->scatter_indices()->shape(); + CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString(); + // Create a map from scatter update to scatter indices. + IndexingMap updates_to_indices_map{ + mlir::AffineMap::get( + /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1, + {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)}, + ctx), + DimVarsFromTensorSizes(scatter_update_shape.dimensions()), + RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), + /*rt_vars=*/{}}; + auto scatter_indices_map = scatter_update_map * updates_to_indices_map; + scatter_indices_map.Simplify(GetIndexingMapForInstruction); + return scatter_indices_map; + } + return scatter_update_map; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/scatter.h b/xla/service/gpu/fusions/scatter.h new file mode 100644 index 0000000000000..289328fd7a7c1 --- /dev/null +++ b/xla/service/gpu/fusions/scatter.h @@ -0,0 +1,69 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ +#define XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ + +#include +#include + +#include "absl/log/check.h" +#include "llvm/IR/IRBuilder.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// A scatter, implemented as a loop over the updates. All scatters are in-place. +class ScatterFusion : public KernelFusionEmitterBase { + public: + explicit ScatterFusion(const HloFusionAnalysis& analysis); + + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override { + // The kernel iterates over updates, whose correspondence to output + // elements cannot be computed statically. + return std::nullopt; + } + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; + + private: + const HloFusionAnalysis& analysis_; + LaunchDimensionsConfig config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ diff --git a/xla/service/gpu/fusions/scatter_mlir.cc b/xla/service/gpu/fusions/scatter_mlir.cc new file mode 100644 index 0000000000000..3c9276342fba0 --- /dev/null +++ b/xla/service/gpu/fusions/scatter_mlir.cc @@ -0,0 +1,275 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/scatter_mlir.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/scatter_simplifier.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using mlir::Location; +using mlir::OpBuilder; +using mlir::Value; +using mlir::ValueRange; +using mlir::arith::AddIOp; +using mlir::arith::AndIOp; +using mlir::arith::CmpIOp; +using mlir::arith::CmpIPredicate; +using mlir::arith::ConstantIndexOp; +using mlir::func::ReturnOp; +using mlir::tensor::InsertOp; +using mlir_converter::ApplyAffineMap; +using mlir_converter::CallTargetProvider; +using mlir_converter::PartitionedComputations; +using mlir_converter::ProvideParameter; + +namespace scf = ::mlir::scf; + +} // namespace + +bool MlirScatterFusion::IsSupported(const HloFusionAnalysis& analysis) { + auto* scatter = Cast(analysis.fusion_heroes().front()); + if (scatter->scatter_operand_count() != 1) { + LOG(ERROR) << "Variadic scatter is not supported like in the legacy " + "emitter, although it is possible to make it work when the " + "indices are unique."; + return false; + } + return true; +} + +std::optional MlirScatterFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + return std::nullopt; +} + +std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + auto* scatter = + DynCast(analysis_.fusion_heroes().front()); + CHECK(ScatterSimplifier::IsSimplifiedScatter(scatter)) + << "Non-simplified HLO Scatter is not supported."; + int64_t scatter_operand_count = scatter->scatter_operand_count(); + // Scatter operands a packed in the following way: + // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`. + // Operand ID scatter_operand_count for `scatter indices`. + // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for + // `scatter updates`. + + // For scatter operands we do not know the thread ID indexing. + if (hero_operand_index < scatter_operand_count) { + return std::nullopt; + } + // Compute thread id mapping based on the first update operand. + Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); + IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap( + launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx); + + // For scatter indices we project indexing for scatter updates and take the + // first result of the affine map only, because they coincide. + if (hero_operand_index == scatter_operand_count) { + Shape scatter_indices_shape = scatter->scatter_indices()->shape(); + CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString(); + // Create a map from scatter update to scatter indices. + IndexingMap updates_to_indices_map{ + mlir::AffineMap::get( + /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1, + {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)}, + ctx), + DimVarsFromTensorSizes(scatter_update_shape.dimensions()), + RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), + /*rt_vars=*/{}}; + auto scatter_indices_map = scatter_update_map * updates_to_indices_map; + scatter_indices_map.Simplify(GetIndexingMapForInstruction); + return scatter_indices_map; + } + return scatter_update_map; +} + +LaunchDimensions MlirScatterFusion::launch_dimensions() const { + auto* scatter = analysis_.fusion_heroes().front(); + // Compute thread id mapping based on the shape of update operand. + auto& shape = scatter->operands().back()->shape(); + return CalculateLaunchDimensions(shape, analysis_.device_info()); +} + +std::vector +MlirScatterFusion::GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return analysis_.fusion_heroes(); +} + +mlir::Value EmitScatterComputation( + const HloInstruction* scatter, ValueRange indices, Value update_elem, + Value output_tensor, + const mlir_converter::PartitionedComputation& root_computation, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, mlir::ImplicitLocOpBuilder& b) { + constexpr int kScatterOperandIndex = 0; + auto reducer = + call_targets(scatter->called_computations()[0]->root_instruction()); + if (scatter->unique_indices()) { + auto operand_elem = ProvideParameter(root_computation.FindSubgraph(scatter), + scatter, kScatterOperandIndex, indices, + call_targets, entry_function, b)[0]; + auto reduced_val = mlir_converter::InlineBlock( + b, reducer.getBody().front(), {operand_elem, update_elem})[0]; + + return b.create(reduced_val, output_tensor, indices); + } + auto atomic_rmw = b.create(output_tensor, indices); + mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder(); + auto reduced_val = mlir_converter::InlineBlock( + body_builder, reducer.getBody().front(), + {atomic_rmw.getCurrentValue(), update_elem})[0]; + body_builder.create(reducer->getLoc(), reduced_val); + return atomic_rmw->getResult(0); +} + +// The scatter has to be canonicalized with `scatter_simplifier` pass. +absl::Status MlirScatterFusion::EmitEntryFunction( + const PartitionedComputations& computations, + const CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + constexpr int kScatterOperandIndex = 0; + constexpr int kScatterIndicesIndex = 1; + constexpr int kScatterUpdateIndex = 2; + const auto* scatter = analysis_.fusion_heroes()[0]; + const HloInstruction* scatter_operand = + scatter->operand(kScatterOperandIndex); + const HloInstruction* scatter_indices = + scatter->operand(kScatterIndicesIndex); + const HloInstruction* scatter_update = scatter->operand(kScatterUpdateIndex); + + mlir::MLIRContext* mlir_context = entry_function.getContext(); + auto thread_id_to_update_map = + ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/kScatterUpdateIndex, + mlir_context) + .value(); + thread_id_to_update_map.Simplify(GetIndexingMapForInstruction); + thread_id_to_update_map.RemoveUnusedSymbols(); + + const auto& root_computation = computations.FindPartitionedComputation( + fusion.fused_instructions_computation()); + const auto& scatter_subgraph = root_computation.FindSubgraph(scatter); + mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); + b.setInsertionPointToStart(entry_function.addEntryBlock()); + + SmallVector result_tensors{entry_function.getArguments().back()}; + auto c0 = b.create(0); + + auto scatter_result = EmitThreadLoopNest( + b, result_tensors, thread_id_to_update_map, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + // Extract input element. + auto update_tensor_indices = + ApplyAffineMap(thread_id_to_update_map.GetAffineMap(), dim_values, + symbol_values, b); + auto update_elem = + ProvideParameter(scatter_subgraph, scatter, kScatterUpdateIndex, + update_tensor_indices, call_targets, + entry_function, b) + .front(); + + // Extract slice offsets from scatter_indices operand, compute if the + // whole slice of scatter_update operand will fit into the output. + mlir::Value is_in_bounds = + b.create(1, b.getI1Type()); + SmallVector indices{ + llvm::ArrayRef(update_tensor_indices).drop_front()}; + for (int i = 0; i < scatter_operand->shape().rank(); ++i) { + Value extracted_index = c0; + if (i < scatter_indices->shape().dimensions(1)) { + SmallVector indices_tensor_indices = { + update_tensor_indices.front(), b.create(i)}; + extracted_index = ProvideParameter( + scatter_subgraph, scatter, kScatterIndicesIndex, + indices_tensor_indices, call_targets, entry_function, b)[0]; + if (extracted_index.getType() != b.getIndexType()) { + extracted_index = b.create( + b.getIndexType(), extracted_index); + } + } + is_in_bounds = b.create( + is_in_bounds, + b.create(CmpIPredicate::sge, extracted_index, c0)); + Value ub = b.create( + scatter_operand->shape().dimensions(i) - + scatter_update->shape().dimensions(i + 1)); + is_in_bounds = b.create( + is_in_bounds, + b.create(CmpIPredicate::sle, extracted_index, ub)); + indices[i] = b.create(extracted_index, indices[i]); + } + // Call scatter's computation if is_in_bounds. + Value output_tensor = output_tensors.front(); + Value predicated_update = + b.create( + is_in_bounds, + [&](OpBuilder& then_builder, Location then_loc) -> void { + Value updated_output = EmitScatterComputation( + scatter, indices, update_elem, output_tensor, + root_computation, call_targets, entry_function, b); + b.create(updated_output); + }, + [&](OpBuilder& else_b, Location else_loc) { + b.create(output_tensor); + }) + .getResult(0); + return {predicated_update}; + }); + b.create(scatter_result); + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/scatter_mlir.h b/xla/service/gpu/fusions/scatter_mlir.h new file mode 100644 index 0000000000000..e66e2c6a4f5a7 --- /dev/null +++ b/xla/service/gpu/fusions/scatter_mlir.h @@ -0,0 +1,69 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_SCATTER_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_SCATTER_MLIR_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// Generic loop fusion. Lowers to LLVM via MLIR. +class MlirScatterFusion : public MlirFusionEmitterBase { + public: + explicit MlirScatterFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} + LaunchDimensions launch_dimensions() const override; + + static bool IsSupported(const HloFusionAnalysis& analysis); + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; + + protected: + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const override; + + private: + const HloFusionAnalysis& analysis_; + LaunchDimensionsConfig config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_SCATTER_MLIR_H_ diff --git a/xla/service/gpu/fusions/scatter_mlir_test.cc b/xla/service/gpu/fusions/scatter_mlir_test.cc new file mode 100644 index 0000000000000..12fca854ae5fb --- /dev/null +++ b/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/scatter_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirScatterFusionTest = MlirEmitterTestBase; + +TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + computation { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %p2 = f32[] parameter(2) + %p3 = f32[] parameter(3) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3) + } + scatter { + %operand0 = f32[300,200] parameter(0) + %operand1 = f32[300,200] parameter(1) + %indices = s32[42,1] parameter(2) + %update.1 = f32[42,10,20] parameter(3) + %update.2 = f32[42,10,20]parameter(4) + + ROOT %scatter = (f32[300,200], f32[300,200]) scatter( + f32[300,200] %operand0, + f32[300,200] %operand1, + s32[42,1] %indices, + f32[42,10,20] %update.1, + f32[42,10,20] %update.2 + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + to_apply=computation + } + ENTRY entry { + %operand0 = f32[300,200] parameter(0) + %operand1 = f32[300,200] parameter(1) + %indices = s32[42,1] parameter(2) + %update.1 = f32[42,10,20] parameter(3) + %update.2 = f32[42,10,20]parameter(4) + ROOT %fusion = (f32[300,200], f32[300,200]) fusion( + %operand0, %operand1, %indices, %update.1, %update.2), + kind=kLoop, calls=scatter + } + )")); + thread_id_printer_.SetSymbolName(0, "chunk_id"); + thread_id_printer_.SetSymbolName(1, "unroll_id"); + thread_id_printer_.SetSymbolName(2, "index_id"); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirScatterFusion fusion(analysis); + + constexpr auto kUpdatesIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, + ((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10, + (th_x + bl_x * 128) mod 20) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 8399] + )"; + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kUpdatesIndexing)); + + constexpr auto kIndicesIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, 0) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + index_id in [0, 0] + th_x + bl_x * 128 in [0, 8399] + )"; + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kIndicesIndexing)); + EXPECT_THAT( + fusion + .ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) + ->ToString(thread_id_printer_), + MatchIndexingString(kIndicesIndexing)); +} + +TEST_F(MlirScatterFusionTest, Scatter_UniqueIndices) { + auto kHloString = R"( + HloModule module + + add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) + } + scatter { + %operand = f32[10,5] parameter(0) + %indices = s32[8,1] parameter(1) + %update = f32[8,1,2] parameter(2) + + ROOT %scatter = f32[10,5] scatter( + f32[10,5] %operand, + s32[8,1] %indices, + f32[8,1,2] %update + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + unique_indices=true, + to_apply=add + } + ENTRY entry { + %c1 = f32[] constant(1) + %c1_tensor = f32[10,5] broadcast(%c1), dimensions={} + %indices = s32[8,1] constant({{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}) + %update = f32[8, 1, 2] parameter(0) + ROOT %fusion = f32[10, 5] fusion( + %c1_tensor, %indices, %update), kind=kLoop, calls=scatter + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> + // CHECK: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 2)> + + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> + // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]*]]: tensor<8x1xi32> + // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]*]]: tensor<8x1x2xf32> + // CHECK-SAME: %[[OUT:[a-zA-Z0-9]*]]: tensor<10x5xf32> + + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index + + // CHECK: %[[TH_X:.*]] = gpu.thread_id x + // CHECK: %[[SLICE_ID:.*]] = affine.apply #[[$MAP0]]()[%[[TH_X]]] + // CHECK: %[[SLICE_X:.*]] = affine.apply #[[$MAP1]]()[%[[TH_X]]] + + // CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update( + // CHECK-SAME: %[[OPERAND]], %[[INDICES]], %[[UPDATES]], + // CHECK-SAME: %[[SLICE_ID]], %[[C0]], %[[SLICE_X]]) + + // CHECK: xla_gpu.pure_call @scatter_indices(%[[OPERAND]], %[[INDICES]] + // CHECK-SAME: %[[UPDATES]], %[[SLICE_ID]], %[[C0]]) + + // CHECK: %[[IN_BOUNDS:.*]] = arith.andi + // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { + // CHECK: %[[CURRENT:.*]] = xla_gpu.pure_call @scatter_operand( + // CHECK-SAME: %[[OPERAND]], %[[INDICES]], %[[UPDATES]], + // CHECK-SAME: %[[SLICE_X]]) + // CHECK: %[[COMBINED:.*]] = arith.addf %[[CURRENT]], %[[UPD_ELEM]] + // CHECK: %[[UPDATED:.*]] = tensor.insert %[[COMBINED]] + // CHECK-SAME: into %[[OUT]][%{{.*}}, %[[SLICE_X]]] : tensor<10x5xf32> + // CHECK: scf.yield %[[UPDATED]] : tensor<10x5xf32> + // CHECK: } else { + // CHECK: scf.yield %[[OUT]] : tensor<10x5xf32> + // CHECK: } + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirScatterFusionTest, Scatter_Add) { + auto kHloString = R"( + HloModule module + + add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) + } + scatter { + %operand = f32[10,5] parameter(0) + %indices = s32[24,1] parameter(1) + %update = f32[24,2,3] parameter(2) + + ROOT %scatter = f32[10,5] scatter( + f32[10,5] %operand, + s32[24,1] %indices, + f32[24,2,3] %update + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + unique_indices=false, + to_apply=add + } + ENTRY entry { + %c1 = f32[] constant(1) + %c1_tensor = f32[10,5] broadcast(%c1), dimensions={} + %indices = s32[24,1] parameter(0) + %update = f32[24, 2, 3] parameter(1) + ROOT %fusion = f32[10, 5] fusion( + %c1_tensor, %indices, %update), kind=kLoop, calls=scatter + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> + // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]*]]: tensor<24x1xi32> + // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]*]]: tensor<24x2x3xf32> + // CHECK-SAME: %[[OUT:[a-zA-Z0-9]*]]: tensor<10x5xf32> + + // CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update + // CHECK: %[[IN_BOUNDS:.*]] = arith.andi + // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { + // CHECK: %[[RMW:.*]] = xla_gpu.atomic_rmw %[[OUT]] + // CHECK: ^bb0(%[[CUR_VALUE:.*]]: f32): + // CHECK: %[[SUM:.*]] = arith.addf %[[CUR_VALUE]], %[[UPD_ELEM]] + // CHECK: xla_gpu.yield %[[SUM]] : f32 + // CHECK: } + // CHECK: scf.yield %[[RMW]] : tensor<10x5xf32> + // CHECK: } else { + // CHECK: scf.yield %[[OUT]] : tensor<10x5xf32> + // CHECK: } + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirScatterFusionTest, Scatter_Overwrite) { + auto kHloString = R"( + HloModule module + + overwrite { + %p0 = f32[] parameter(0) + ROOT %p1 = f32[] parameter(1) + } + scatter { + %operand = f32[10,5] parameter(0) + %indices = s32[3,1] parameter(1) + %update = f32[3,2,3] parameter(2) + + ROOT %scatter = f32[10,5] scatter( + f32[10,5] %operand, + s32[3,1] %indices, + f32[3,2,3] %update + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + unique_indices=false, + to_apply=overwrite + } + ENTRY entry { + %c1 = f32[] constant(1) + %c1_tensor = f32[10,5] broadcast(%c1), dimensions={} + %indices = s32[3,1] constant({ {0}, {3}, {6}}) + %update = f32[3, 2, 3] parameter(0) + ROOT %fusion = f32[10, 5] fusion( + %c1_tensor, %indices, %update), kind=kLoop, calls=scatter + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> + // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]*]]: tensor<3x1xi32> + // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]*]]: tensor<3x2x3xf32> + // CHECK-SAME: %[[OUT:[a-zA-Z0-9]*]]: tensor<10x5xf32> + + // CHECK: %[[UPD_ELEM:.*]] = xla_gpu.pure_call @scatter_update + // CHECK: %[[IN_BOUNDS:.*]] = arith.andi + // CHECK: scf.if %[[IN_BOUNDS]] -> (tensor<10x5xf32>) { + // CHECK: %[[RMW:.*]] = xla_gpu.atomic_rmw %[[OUT]] + // CHECK: ^bb0(%[[CUR_VALUE:.*]]: f32): + // CHECK: xla_gpu.yield %[[UPD_ELEM]] : f32 + // CHECK: } + // CHECK: scf.yield %[[RMW]] : tensor<10x5xf32> + // CHECK: } else { + // CHECK: scf.yield %[[OUT]] : tensor<10x5xf32> + // CHECK: } + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/scatter_test.cc b/xla/service/gpu/fusions/scatter_test.cc new file mode 100644 index 0000000000000..2be8dc86d7554 --- /dev/null +++ b/xla/service/gpu/fusions/scatter_test.cc @@ -0,0 +1,219 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/scatter.h" + +#include + +#include +#include +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class ScatterFusionTest : public HloTestBase { + public: + void SetUp() override { + HloTestBase::SetUp(); + printer_ = + AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id", "index_id"}); + } + + protected: + AffineMapPrinter printer_; + mlir::MLIRContext mlir_context_; +}; + +TEST_F(ScatterFusionTest, ScatterFusion) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) + } + + fused_computation { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1) + %updates = f32[3,9] parameter(2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + + ENTRY entry { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1) + %updates = f32[3,9] parameter(2) + ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation + })") + .value(); + + stream_executor::DeviceDescription device_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto scatter_fusion = dynamic_cast(emitter.get()); + ASSERT_NE(scatter_fusion, nullptr); + EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(), + 3 * 9 /* updates size */); +} + +TEST_F(ScatterFusionTest, ThreadIdIndexing) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + computation { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %p2 = f32[] parameter(2) + %p3 = f32[] parameter(3) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3) + } + scatter { + %operand0 = f32[300,200] parameter(0) + %operand1 = f32[300,200] parameter(1) + %indices = s32[42,1] parameter(2) + %update.1 = f32[42,10,20] parameter(3) + %update.2 = f32[42,10,20]parameter(4) + + ROOT %scatter = (f32[300,200], f32[300,200]) scatter( + f32[300,200] %operand0, + f32[300,200] %operand1, + s32[42,1] %indices, + f32[42,10,20] %update.1, + f32[42,10,20] %update.2 + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + to_apply=computation + } + ENTRY entry { + %operand0 = f32[300,200] parameter(0) + %operand1 = f32[300,200] parameter(1) + %indices = s32[42,1] parameter(2) + %update.1 = f32[42,10,20] parameter(3) + %update.2 = f32[42,10,20]parameter(4) + ROOT %fusion = (f32[300,200], f32[300,200]) fusion( + %operand0, %operand1, %indices, %update.1, %update.2), + kind=kLoop, calls=scatter + } + )")); + stream_executor::DeviceDescription device_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto fusion = dynamic_cast(emitter.get()); + ASSERT_NE(fusion, nullptr); + + constexpr auto kUpdatesIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, + ((bl_x * 32 + th_x floordiv 4) floordiv 5) mod 10, + (th_x + bl_x * 128) mod 20) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + th_x + bl_x * 128 in [0, 8399] + )"; + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kUpdatesIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kUpdatesIndexing)); + + constexpr auto kIndicesIndexing = R"( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ( + ((bl_x * 16 + th_x floordiv 8) floordiv 25) mod 42, 0) + domain: + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + index_id in [0, 0] + th_x + bl_x * 128 in [0, 8399] + )"; + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kIndicesIndexing)); + EXPECT_THAT( + fusion + ->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) + ->ToString(printer_), + MatchIndexingString(kIndicesIndexing)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/thunk_util.cc b/xla/service/gpu/fusions/thunk_util.cc index d9d729ab06649..b356f76d083f8 100644 --- a/xla/service/gpu/fusions/thunk_util.cc +++ b/xla/service/gpu/fusions/thunk_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,89 +15,75 @@ limitations under the License. #include "xla/service/gpu/fusions/thunk_util.h" #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/memset_thunk.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/memset_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { -namespace { -// TODO(b/291536641): Clean this up. What's the difference between this and the -// caller? -std::optional> BuildConstantInitializerThunk( - mlir::Operation* op, absl::Span init_value, mlir::Value dest, - const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { - int64_t num_bytes = init_value.size(); - if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { - return {{std::make_unique(Thunk::ThunkInfo(op), dest_slice, - dest)}}; - } - - // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the literal 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. - if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { - uint16_t pattern16; - if (num_bytes == 1) { - uint8_t b = init_value.front(); - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, init_value.data(), sizeof(pattern16)); - } - uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); - return {{std::make_unique( - Thunk::ThunkInfo(op), pattern32, dest_slice, dest)}}; - } - - // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit - // memset so long as all 32-bit words of the scalar are equal to each other. - if (num_bytes >= 4 && num_bytes % 4 == 0 && - memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == - 0) { - uint32_t word; - memcpy(&word, init_value.data(), sizeof(word)); - return {{std::make_unique(Thunk::ThunkInfo(op), word, - dest_slice, dest)}}; - } - - return std::nullopt; -} - -} // namespace - -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - const HloInstruction* instr, const HloInstruction* init_value, - mlir::Value dest, BufferAllocation::Slice dest_slice) { +absl::StatusOr>> +BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, + const HloInstruction* instr, + const HloInstruction* init_value, + BufferAllocation::Slice dest_slice) { if (const HloConstantInstruction* constant = DynCast(init_value)) { const Literal& literal = constant->literal(); absl::Span literal_bytes( static_cast(literal.untyped_data()), literal.size_bytes()); + int64_t num_bytes = literal_bytes.size(); const Shape dest_shape = instr->shape(); - return BuildConstantInitializerThunk(op, literal_bytes, dest, dest_slice, - dest_shape); + + Thunk::ThunkInfo thunk_info = + Thunk::ThunkInfo::WithProfileAnnotation(instr); + if (absl::c_all_of(literal_bytes, [](uint8_t byte) { return byte == 0; })) { + return {{std::make_unique(thunk_info, dest_slice)}}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(dest_shape) % 4 == 0) { + uint16_t pattern16; + if (num_bytes == 1) { + uint8_t b = literal_bytes.front(); + pattern16 = uint16_t{b} | (uint16_t{b} << 8); + } else { + memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); + } + uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); + return {{std::make_unique(thunk_info, pattern32, + dest_slice)}}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(literal_bytes.data(), literal_bytes.data() + 4, + literal_bytes.size() - 4) == 0) { + uint32_t word; + memcpy(&word, literal_bytes.data(), sizeof(word)); + return {{std::make_unique(thunk_info, word, + dest_slice)}}; + } } return std::nullopt; } diff --git a/xla/service/gpu/fusions/thunk_util.h b/xla/service/gpu/fusions/thunk_util.h index 29a9715209dd9..a78bb76f3cdd2 100644 --- a/xla/service/gpu/fusions/thunk_util.h +++ b/xla/service/gpu/fusions/thunk_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,22 +18,22 @@ limitations under the License. #include #include -#include "mlir/IR/Value.h" // from @llvm-project +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/thunk.h" -#include "xla/statusor.h" +#include "xla/service/gpu/runtime/thunk.h" namespace xla { namespace gpu { // Attempts to build an initializer constant for the given value. Returns an // empty optional if the value is not a constant. -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - const HloInstruction* instr, const HloInstruction* init_value, - mlir::Value dest, BufferAllocation::Slice dest_slice); +absl::StatusOr>> +BuildConstantInitializerThunk(IrEmitterContext& ir_emitter_context, + const HloInstruction* instr, + const HloInstruction* init_value, + BufferAllocation::Slice dest_slice); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/tiling_util.cc b/xla/service/gpu/fusions/tiling_util.cc index a1bb666fec7b3..24456209e521f 100644 --- a/xla/service/gpu/fusions/tiling_util.cc +++ b/xla/service/gpu/fusions/tiling_util.cc @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,136 +15,111 @@ limitations under the License. #include "xla/service/gpu/fusions/tiling_util.h" +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" +#include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { -// Gets the output offset as calculated from thread_id.x (to be applied to the -// offset calculated from block_id and thread_id.y). -llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, - llvm::Value* thread_id_x, llvm::Type* index_ty, - llvm::IRBuilder<>* b) { - int64_t multiplier = - tiling_scheme.GetIndexingOrder() == TilingScheme::StridedIndexingX - ? tiling_scheme.GetVectorSize() - : tiling_scheme.GetTileSizeFor(TilingScheme::DimX); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, multiplier)); -} - -// Emits loop through the minor (X) dimension of a tile, starting at a given -// offset. -// -// Rough pseudocode: -// -// Given: offset, callback -// -// for (int x = 0; x < x_tile_size / vector_size; ++x) { -// for (int i = 0; i < vector_size; ++i) { -// callback(offset + x * stride * vector_size + i); -// } -// } -void EmitXTileLoop(const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& tile_origin_index, - const TilingScheme& tiling_scheme, bool check_x_tile_bounds, - llvm::Value* y_loc, - std::array tile_dimensions, - llvm::IRBuilder<>* b, - const EmitTileElementFunction* emit_elem_function) { - llvm::Type* index_ty = tile_dimensions[1]->getType(); - KernelSupportLibrary ksl(b, llvm_ir::UnrollMode::kDefaultUnroll); +void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling, + int dim, absl::InlinedVector tile_idx, + absl::Span tile_dimensions, + llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) { + llvm::Type* index_ty = thread_id_info.thread_id->getType(); auto constant = [&](int64_t val) { return llvm::ConstantInt::get(index_ty, val); }; - llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, index_ty, b); - int64_t vector_size = tiling_scheme.GetVectorSize(); - int64_t stride_x = - tiling_scheme.GetIndexingOrder() == TilingScheme::LinearIndexingX - ? 1 - : tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); - KernelSupportLibrary unrolled_ksl(b, llvm_ir::UnrollMode::kFullyUnroll); - unrolled_ksl.For( - "tile_loop", - /*start=*/constant(0), - /*end=*/ - constant(tiling_scheme.GetTileSizeFor(TilingScheme::DimX) / vector_size), - /*step=*/1, [&](llvm::Value* x) { - for (int64_t i = 0; i < vector_size; i++) { - llvm::Value* x_offset = b->CreateAdd( - b->CreateMul(x, constant(stride_x * vector_size)), constant(i)); - llvm::Value* x_loc = b->CreateAdd(x_offset, start_offset_x, "x_loc"); - llvm_ir::IrArray::Index source_idx_x = - tile_origin_index - .AddOffsetToDim(y_loc, tiling_scheme.GetTilingDimension(0), b) - .AddOffsetToDim(x_loc, tiling_scheme.GetTilingDimension(1), - b); - auto emit_element = [&] { - return (*emit_elem_function)(thread_id_info, source_idx_x, y_loc, - x_loc); - }; - if (check_x_tile_bounds) { - ksl.If("x_in_tile", b->CreateICmpULT(x_loc, tile_dimensions[1]), - emit_element); - } else { - emit_element(); - } + auto recurse = [&] { + if (dim == tile_idx.size() - 1) { + emit_elem(tile_idx); + } else { + EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b, + emit_elem); + } + }; + + bool unroll = tiling.GetLoopsToUnroll()[dim]; + KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll + : llvm_ir::UnrollMode::kDefaultUnroll); + + if (tiling.GetBlockTileSize()[dim] == 1) { + tile_idx[dim] = constant(0); + recurse(); + } else if (unroll) { + // TODO(jreiffers): Check if this unrolling does anything useful. + int64_t stride = tiling.GetThreadsPerBlock()[dim]; + int64_t dim_size = tiling.GetThreadTileSize()[dim]; + + auto make_loop = [&](bool emit_bounds_checks) { + auto body = [&, emit_bounds_checks](llvm::Value* i) { + tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]); + if (emit_bounds_checks) { + auto* in_bounds = + b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]); + ksl.If("x_in_tile", in_bounds, recurse); + } else { + recurse(); } - }); + }; + return [&, body] { + ksl.For(absl::StrCat("loop", dim), constant(0), + constant(dim_size * stride), constant(stride), body); + }; + }; + if (stride > 1 && dim_size > 1) { + // Most tiles will be full, so we emit a single bounds check for those. + auto* is_full_tile = b->CreateICmpEQ( + constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]); + ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true)); + } else { + make_loop(true)(); + } + } else { + // All dimensions are strided (thread 0 processes elements 0, num_threads, + // num_threads+2, ...; thread 1 processes elements 1, num_threads + 1 and so + // on). + ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim], + /*end=*/tile_dimensions[dim], + /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) { + tile_idx[dim] = i; + recurse(); + }); + } } } // namespace -void EmitTile(llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - const llvm_ir::IrArray::Index& tile_origin_index, +void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, const TilingThreadIdInfo& thread_id_info, - std::array tile_dimensions, - const EmitTileElementFunction& emit_elem_function) { - llvm::Type* index_ty = tile_dimensions[0]->getType(); - auto constant = [&](int64_t val) { - return llvm::ConstantInt::get(index_ty, val); - }; - llvm::Value* num_threads_y = constant( - tiling_scheme.GetNumThreadsFor(tiling_scheme.GetTilingDimension(0))); - - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - - ksl.For( - "y_in_tile", - /*start=*/thread_id_info.thread_id_y, - /*end=*/ - tile_dimensions[0], - /*step=*/num_threads_y, [&](llvm::Value* y_loc) { - auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) { - return EmitXTileLoop(thread_id_info, tile_origin_index, tiling_scheme, - check_x_tile_bounds, y_loc, tile_dimensions, - builder, &emit_elem_function); - }; - - // Only take this path when we unroll in a way vectorizable by - // LLVM. Special case when the tile doesn't fit completely for even - // row size. For odd row size every other row isn't aligned to the - // vectorized size, so it can't be vectorized by LLVM. - if (tiling_scheme.GetIndexingOrder() == - TilingScheme::StridedIndexingX) { - ksl.If( - "is_full_tile", - builder->CreateICmpEQ(constant(tiling_scheme.GetBlockTileSizeFor( - TilingScheme::DimX)), - tile_dimensions[1]), - [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); }, - [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); }); - } else { - unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); - } - }); + absl::Span tile_dimensions, + const TileElementGenerator& emit_elem_function) { + absl::InlinedVector tile_idx(tiling.GetShape().size()); + EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder, + emit_elem_function); } namespace { @@ -156,10 +131,12 @@ llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks, EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder); if (num_blocks != 0) { llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); + llvm::cast(block_id), + builder->GetInsertBlock()->getModule()); } - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); + auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true); + ret->setName("block.id.x"); + return ret; } // Emits current thread id with the given type. @@ -169,228 +146,111 @@ llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block, llvm::Type* index_ty) { // Calculate (y, x) coordinates respectively in the 2D view of thread block, // defined by (num_thread_y, num_thread_x) from thread_id. - llvm::CallInst* thread_id_raw = + llvm::CallInst* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw); - return builder->CreateIntCast(thread_id_raw, index_ty, - /*isSigned=*/true, "thread.id.x"); + llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id, + builder->GetInsertBlock()->getModule()); + auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true); + ret->setName("thread.id.x"); + return ret; } -// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane -// id. -// -// Returns a struct containting these values. -// -// In the presence of thread scaling in tiling scheme may return early if the -// combination of thread_id/block_id does not correspond to a real block. -// Assumes the current function returns void. -StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, - const TilingScheme& tiling_scheme, - llvm::Type* index_ty) { +// Emits the LLVM values for thread_id, block_id, coordinates of the current +// tile and strides of the loops to iterate over the current tile. +absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, + const Tiling& tiling, + llvm::Type* index_ty) { auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - llvm::Value* thread_id_physical = EmitThreadId( - builder, tiling_scheme.GetNumThreadsPerBlockPhysical(), index_ty); - int64_t num_blocks = tiling_scheme.GetNumberOfBlocksPhysical(); + int64_t num_blocks = tiling.GetNumBlocks(); if (num_blocks > (int64_t)std::numeric_limits::max()) { return FailedPrecondition( "Number of physical blocks (%d) does not fit in an i32 in tiling " "scheme: %s", - num_blocks, tiling_scheme.ToString()); + num_blocks, tiling.ToString()); } - llvm::Value* block_id_physical = EmitBlockId(builder, num_blocks, index_ty); - - // Wait this will break coalescing. - llvm::Value* thread_id_logical = builder->CreateURem( - thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock())); - llvm::Value* scaling = builder->CreateUDiv( - thread_id_physical, constant(tiling_scheme.GetNumThreadsPerBlock())); - llvm::Value* block_id_logical = builder->CreateAdd( - builder->CreateMul(block_id_physical, - constant(tiling_scheme.GetThreadIdScalingFactor())), - scaling); - llvm::Value* num_threads_x_v = - constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX)); + TilingThreadIdInfo info; + info.thread_id = + EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty); + info.block_id = EmitBlockId(builder, num_blocks, index_ty); + + for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) { + int64_t size = tiling.GetThreadsPerBlock()[dim]; + if (size == 1) { + info.thread_ids.emplace_back(constant(0)); + } else { + auto& dim_id = info.thread_ids.emplace_back(info.thread_id); + if (stride > 1) { + dim_id = builder->CreateUDiv(dim_id, constant(stride)); + } + if (dim) { + dim_id = builder->CreateURem(dim_id, constant(size)); + } + dim_id->setName(absl::StrCat("thread.id.", dim)); + } + } - llvm::Value* block_exists = builder->CreateICmpULT( - block_id_logical, constant(tiling_scheme.GetNumberOfBlocks())); - llvm_ir::EmitEarlyReturn(block_exists, builder); - return { - {thread_id_logical, - /*thread_id_x=*/ - builder->CreateURem(thread_id_logical, num_threads_x_v, "thread_id.x"), - /*thread_id_y=*/ - builder->CreateUDiv(thread_id_logical, num_threads_x_v, "thread_id.y"), - /*lane_id=*/ - builder->CreateURem(thread_id_logical, constant(WarpSize()), "lane_id"), - /*block_id=*/block_id_logical, - /*scaling=*/scaling}}; + info.lane_id = + builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id"); + return info; } } // namespace -StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* index_ty, const TileElementGenerator& tile_element_generator) { - absl::Span dims_in_elems = tiling_scheme.GetDimsInElems(); - Vector3 dims_in_blocks = tiling_scheme.GetDimsInBlocks(); +absl::StatusOr EmitTilingKernel( + llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, + const TileGenerator& tile_generator) { + absl::Span dims_in_elems = tiling.GetShape(); + const auto& block_counts = tiling.GetBlockCounts(); auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info, - EmitThreadIdInfo(builder, tiling_scheme, index_ty)); + EmitThreadIdInfo(builder, tiling, index_ty)); KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - int64_t non_tiling_dimension = tiling_scheme.GetTilingDimension(0) == 1 - ? TilingScheme::DimZ - : TilingScheme::DimY; const llvm_ir::IrArray::Index block_coords( thread_id_info.block_id, - ShapeUtil::MakeShapeWithDenseLayout( - PRED /*arbitrary*/, dims_in_blocks, - // This layout determines the iteration order. We want the - // non-tiling dimension to be the slowest varying dimension. - {2, 1 - non_tiling_dimension, non_tiling_dimension}), - builder); - - std::array tile_dimensions; - // Coordinate access is shifted: 0 corresponds to the first non-tiling - // dimension and 1 corresponds to DimX. - std::array tiling_coords{1 - non_tiling_dimension, - TilingScheme::DimX}; - for (int i = 0; i < 2; ++i) { - int64_t tile_size_for_dim = - tiling_scheme.GetBlockTileSizeFor(tiling_coords[i]); - // Only last row or column may not have full size. - llvm::Value* is_last = - builder->CreateICmpEQ(block_coords[tiling_coords[i]], - constant(dims_in_blocks[tiling_coords[i]] - 1)); - int64_t partial_row = - dims_in_elems[tiling_coords[i]] - - (dims_in_blocks[tiling_coords[i]] - 1) * tile_size_for_dim; - tile_dimensions[i] = - builder->CreateSelect(is_last, constant(partial_row), - constant(tile_size_for_dim), "tile_bound"); + ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder); + + absl::InlinedVector tile_dimensions; + for (int i = 0; i < block_counts.size(); ++i) { + int64_t block_tile_size = tiling.GetBlockTileSize()[i]; + if (dims_in_elems[i] % block_tile_size == 0) { + // The block tile size evenly divides the tiled shape -> no need to emit + // the bounds check. + tile_dimensions.push_back(constant(block_tile_size)); + } else { + // Only the last tile in each dimension may not have full size. + llvm::Value* is_last = + builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1)); + int64_t partial_row = + dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size; + tile_dimensions.push_back(builder->CreateSelect( + is_last, constant(partial_row), constant(block_tile_size), + absl::StrCat("tile_bound.", i))); + } } - llvm_ir::IrArray::Index tile_origin = [&] { + llvm_ir::IrArray::Index tile_offset = [&] { std::vector elem_multi_index = block_coords.multidim(); llvm::Type* index_ty = block_coords.GetType(); - for (int i = 0; i < TilingScheme::DimTot; ++i) { + for (int i = 0; i < block_counts.size(); ++i) { elem_multi_index[i] = builder->CreateMul( block_coords[i], - llvm::ConstantInt::get(index_ty, - tiling_scheme.GetBlockTileSizeFor(i)), - "tile_origin." + std::to_string(i)); + llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]), + absl::StrCat("tile_origin.", i)); } - return llvm_ir::IrArray::Index(elem_multi_index, - tiling_scheme.GetDimsInElems(), index_ty); + return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(), + index_ty); }(); - auto emit_tile = [&](const llvm_ir::IrArray::Index& tile) { - tile_element_generator(thread_id_info, tile, tile_dimensions); - }; - - if (tiling_scheme.GetBlockTileSizeFor(non_tiling_dimension) == 1) { - emit_tile(tile_origin); - } else { - llvm::Value* starting_tile_index_for_dim = - tile_origin[non_tiling_dimension]; - llvm::Value* block_size_for_dim = - constant(tiling_scheme.GetBlockTileSizeFor(non_tiling_dimension)); - llvm::Value* block_id_for_dim = - builder->CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); - llvm::Value* last_block_for_dim = - constant(dims_in_blocks[non_tiling_dimension] - 1); - llvm::Value* last_block_size_for_dim = - constant(dims_in_elems[non_tiling_dimension] - - (dims_in_blocks[non_tiling_dimension] - 1) * - tiling_scheme.GetBlockTileSizeFor(non_tiling_dimension)); - - llvm::Value* num_tiles_in_block = builder->CreateSelect( - builder->CreateICmpEQ(last_block_for_dim, block_id_for_dim), - last_block_size_for_dim, block_size_for_dim); - ksl.For("loop_z", - /*start=*/constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - llvm_ir::IrArray::Index tile_index = tile_origin.AddOffsetToDim( - block_dim_induction_var, non_tiling_dimension, builder); - emit_tile(tile_index); - }); - } - - return {{tile_dimensions, tile_origin, thread_id_info}}; -} - -llvm::Type* TilingThreadIdInfo::GEPIntoSharedMemoryType( - llvm::GlobalVariable* shared, - absl::Span idx_major_to_minor) const { - std::vector idxs_scaled; - idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0)); - idxs_scaled.push_back(scaling); - idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(), - idx_major_to_minor.end()); - return llvm::GetElementPtrInst::getIndexedType(shared->getValueType(), - idxs_scaled); -} - -llvm::Value* TilingThreadIdInfo::GEPIntoSharedMemory( - llvm::IRBuilder<>* b, llvm::GlobalVariable* shared, - absl::Span idx_major_to_minor, - const llvm::Twine& name) const { - std::vector idxs_scaled; - idxs_scaled.push_back(llvm::ConstantInt::get(scaling->getType(), 0)); - idxs_scaled.push_back(scaling); - idxs_scaled.insert(idxs_scaled.end(), idx_major_to_minor.begin(), - idx_major_to_minor.end()); - llvm::Value* gep = - b->CreateInBoundsGEP(shared->getValueType(), shared, idxs_scaled, name); - - llvm::PointerType* pointer_in_addressspace = - llvm::PointerType::getWithSamePointeeType( - llvm::cast(gep->getType()), /*AddressSpace=*/0); - - // __shared__ memory uses a different address space, so we cast it to - // global address space before writing or reading. - return b->CreateAddrSpaceCast(gep, pointer_in_addressspace); -} - -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems) { - CHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[1], multidim[2]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && - unnormalized_shape.layout().minor_to_major(1) == 1) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[2], multidim[1]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - return normalized_shape_index.SourceIndexOfBitcast( - ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, builder); + tile_generator(thread_id_info, tile_offset, tile_dimensions); + return {{tile_dimensions, tile_offset, thread_id_info}}; } } // namespace gpu diff --git a/xla/service/gpu/fusions/tiling_util.h b/xla/service/gpu/fusions/tiling_util.h index 7765d02de9a6f..f06ae8ccab428 100644 --- a/xla/service/gpu/fusions/tiling_util.h +++ b/xla/service/gpu/fusions/tiling_util.h @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,66 +15,128 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_ #define XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_ +#include #include +#include -#include "xla/service/gpu/kernel_mapping_scheme.h" +#include "absl/log/check.h" +#include "absl/types/span.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape_util.h" +#include "xla/util.h" namespace xla { namespace gpu { -// Contains threading information. Note that for performance we might apply -// thread id "scaling" where the physical thread id (to achieve good SM -// occupancy) will differ from logical thread id. This struct contains -// logical thread ids, along with meta-information about the scaling applied. -struct TilingThreadIdInfo { - TilingThreadIdInfo(llvm::Value* thread_id, llvm::Value* thread_id_x, - llvm::Value* thread_id_y, llvm::Value* lane_id, - llvm::Value* block_id, llvm::Value* scaling) - : thread_id(thread_id), - thread_id_x(thread_id_x), - thread_id_y(thread_id_y), - lane_id(lane_id), - block_id(block_id), - scaling(scaling) {} +// Describes tiling used by the kernel. +// +// Used by reduction and transpose emitters. +class Tiling { + public: + Tiling(absl::InlinedVector shape, + absl::InlinedVector tile_sizes, + absl::InlinedVector num_threads, + // By default, don't unroll anything. + absl::InlinedVector loops_to_unroll = {}) + : shape_(shape), + tile_sizes_per_thread_(tile_sizes), + tile_sizes_per_block_(shape.size()), + num_threads_(num_threads), + num_blocks_(shape.size()), + loops_to_unroll_(loops_to_unroll) { + for (int64_t i = 0; i < shape.size(); ++i) { + tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i]; + CHECK_NE(tile_sizes_per_block_[i], 0); + num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]); + CHECK_NE(num_blocks_[i], 0); + } + if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size()); + } + + std::string ToString() const { + return absl::StrJoin( + {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")), + absl::StrFormat("tile_sizes = {%s}", + absl::StrJoin(tile_sizes_per_thread_, ", ")), + absl::StrFormat("num_threads = {%s}", + absl::StrJoin(num_threads_, ", "))}, + ", "); + } + + // Number of elements in each dimension. + const absl::InlinedVector& GetShape() const { return shape_; } + xla::Shape GetXlaShape(PrimitiveType element_type = F32) const { + return ShapeUtil::MakeShape(element_type, shape_); + } + + const absl::InlinedVector& GetBlockCounts() const { + return num_blocks_; + } + + // Tile size for each thread. + // + // Equals to the number of iterations in the loop each tile will make. + const absl::InlinedVector& GetThreadTileSize() const { + return tile_sizes_per_thread_; + } - llvm::Value* thread_id; + // Tile size for an entire thread block. + const absl::InlinedVector& GetBlockTileSize() const { + return tile_sizes_per_block_; + } + + const absl::InlinedVector& GetThreadsPerBlock() const { + return num_threads_; + } + + // Returns the strides of the thread index dimensions wrt. the linear thread + // id. + absl::InlinedVector GetThreadStrides() const { + return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_)); + } + + // Returns the strides of the block index dimensions wrt. the linear block id. + absl::InlinedVector GetBlockStrides() const { + return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_blocks_)); + } + + int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); } + + int64_t GetNumBlocks() const { return Product(num_blocks_); } + + const absl::InlinedVector& GetLoopsToUnroll() const { + return loops_to_unroll_; + } + + private: + // The number of elements in each dimension. + absl::InlinedVector shape_; + + // The number of elements for each dimension of a tile. + absl::InlinedVector tile_sizes_per_thread_; + absl::InlinedVector tile_sizes_per_block_; + + absl::InlinedVector num_threads_; + absl::InlinedVector num_blocks_; + + absl::InlinedVector loops_to_unroll_; +}; - // X-coordinate calculated from thread id: `thread_id % num_threads_x` - llvm::Value* thread_id_x; +struct TilingThreadIdInfo { + llvm::Value* thread_id; - // Y-coordinate calculated from thread id: `thread_id / num_threads_x` - llvm::Value* thread_id_y; + absl::InlinedVector thread_ids; // Lane id: `thread_id % WarpSize` llvm::Value* lane_id; // Block id. llvm::Value* block_id; - - // Emits GEP into a shared memory, taking virtual thread scaling into - // account. Automatically inserts the first zero required by LLVM GEP. - // Defined on ThreadIdInfo to keep `scaling` private. - // - // Same semantics as CreateInBoundsGEP. - llvm::Value* GEPIntoSharedMemory( - llvm::IRBuilder<>* b, llvm::GlobalVariable* shared, - absl::Span idx_major_to_minor, - const llvm::Twine& name = "") const; - - // Calculate the pointee type of the llvm::Value returned by - // GEPIntoSharedMemory - llvm::Type* GEPIntoSharedMemoryType( - llvm::GlobalVariable* shared, - absl::Span idx_major_to_minor) const; - - private: - llvm::Value* scaling; }; struct TilingKernelInfo { // Tiling bounds. - std::array output_tile_bounds; + absl::InlinedVector output_tile_bounds; // Starting tile, as calculated from block id only. llvm_ir::IrArray::Index tile_origin; @@ -87,61 +149,30 @@ struct TilingKernelInfo { // // index: Absolute coordinate of the start of the tile in input. // tile_dimensions: Size of the tile -using TileElementGenerator = +using TileGenerator = std::function tile_dimensions)>; + const llvm_ir::IrArray::Index& tile_start_index, + absl::Span tile_dimensions)>; // A function object to generate code to process one element in a tile. // -// index: the index for the first output element of the current thread. -// y_loc: The y coordinate within a tile. -// x_loc: The x coordinate within a tile. -using EmitTileElementFunction = - std::function; +// index_in_tile: the current coordinates within the tile. To get the global +// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`. +using TileElementGenerator = + std::function index_in_tile)>; -// Emits code to iterate through a 2-dimensional tile with a given tile -// dimensions and given strides, and call the callback at each iteration., -// -// thread_id_y` and `thread_id_x` are the intra-tile coordinates for -// the first element to process, and `index` is the index for the origin of -// the tile. Emits bounds check to ensure that each processed element -// is within the boundary defined by `tile_dimensions`. -// -// Rough pseudocode: -// -// Given: tile_dimensions, x_offset, y_offset -// -// for (y = 0; y < tile_dimensions[0]; y += num_threads_y) { -// for (x = 0; x < tile_dimensions[1]; x++) { -// -// y_pos = y_offset + y -// x_pos = x_offset + x * stride -// -// if (x_loc < tile_width) { -// emit_elem_function(y_offset + y, x_loc); -// } -// } -// } -// -void EmitTile(llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - const llvm_ir::IrArray::Index& tile_origin_index, +// Emits code to iterate through a tile with given tile dimensions and generate +// elements using the callback. +void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, const TilingThreadIdInfo& thread_id_info, - std::array tile_dimensions, - const EmitTileElementFunction& emit_elem_function); + absl::Span tile_dimensions, + const TileElementGenerator& emit_elem_function); // Emits a kernel for the hlo instruction using the given kernel mapping // scheme. -StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* index_ty, const TileElementGenerator& tile_element_generator); - -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems); +absl::StatusOr EmitTilingKernel( + llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, + const TileGenerator& tile_element_generator); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/transpose.cc b/xla/service/gpu/fusions/transpose.cc index 758743734d751..ca7b3f7ff7922 100644 --- a/xla/service/gpu/fusions/transpose.cc +++ b/xla/service/gpu/fusions/transpose.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,44 +14,77 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/transpose.h" +#include +#include +#include +#include +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AtomicOrdering.h" +#include "mlir/IR/AffineMap.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/permutation_util.h" +#include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/status.h" +#include "xla/util.h" namespace xla { namespace gpu { namespace { -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); +Tiling ComputeTransposeTiling(const TransposeDescription& tiled_transpose) { + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the output shape. + Vector3 transposed_dims = tiled_transpose.dimensions; + Vector3 permutation = tiled_transpose.permutation; + + // Note: the supported permutations are their own inverses. Therefore we + // always use the permutation, even when we want the inverse. + CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0})); + + absl::InlinedVector input_dims{transposed_dims[permutation[0]], + transposed_dims[permutation[1]], + transposed_dims[permutation[2]]}; + + // We tile along the minor dimensions pre- and post-transpose. + absl::InlinedVector tile_sizes{1, 1, 1}; + tile_sizes[permutation[2]] = WarpSize() / kNumRows; + absl::InlinedVector num_threads{1, 1, WarpSize()}; + num_threads[permutation[2]] = kNumRows; + + return Tiling(input_dims, tile_sizes, num_threads); } void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context) { auto* module = builder->GetInsertBlock()->getModule(); if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { + ir_emitter_context.rocm_compute_capability().fence_before_barrier()) { builder->CreateFence( llvm::AtomicOrdering::SequentiallyConsistent, builder->getContext().getOrInsertSyncScopeID("workgroup")); @@ -73,13 +106,26 @@ llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index, } // namespace -Status TransposeFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims, - std::vector inputs, std::vector outputs, - llvm::IRBuilder<>* builder, int kernel_index) const { - const auto& tiling_scheme = *analysis_.GetTransposeTilingScheme(); +TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + tiling_(ComputeTransposeTiling(analysis.tiled_transpose())) { + for (auto [root, hero] : + llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) { + if (auto transpose = GetDescriptionForTiledTransposeEmitter(*root, *hero)) { + permutation_ = transpose->permutation; + break; + } + } +} + +absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const { const auto& hlo_roots = analysis_.fusion_roots(); + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); FusedIrEmitter fused_emitter(elemental_emitter); for (auto [i, input] : llvm::enumerate(inputs)) { HloInstruction* fused_operand = fusion.fused_parameter(i); @@ -91,153 +137,190 @@ Status TransposeFusion::EmitKernel( }); } - std::vector heroes; - std::vector> transposes; - heroes.reserve(hlo_roots.size()); - for (const auto& root : hlo_roots) { - heroes.push_back(&FindNonTrivialHero(*root)); - transposes.push_back( - GetDescriptionForTiledTransposeEmitter(*root, *heroes.back())); + absl::flat_hash_map>> + transposes_to_roots; + // Keep a list of deduplicated transpose heroes separate from + // transposes_to_roots to make the CodeGen deterministic. + std::vector transposes; + transposes.reserve(hlo_roots.size()); + std::vector> extra_outputs; + + for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { + const auto& hero = *analysis_.fusion_heroes()[output_idx]; + auto transpose_descr = GetDescriptionForTiledTransposeEmitter(*root, hero); + if (transpose_descr.has_value()) { + auto iterator_inserted = transposes_to_roots.insert(std::make_pair( + &hero, std::vector>{ + {output_idx, root}})); + if (iterator_inserted.second) { + transposes.push_back(*transpose_descr); + } else { + iterator_inserted.first->second.push_back({output_idx, root}); + } + } else { + extra_outputs.push_back({output_idx, root}); + } } - absl::flat_hash_map tiles; + absl::flat_hash_map tiles; Vector3 permutation; - for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { - if (const auto& tr = transposes[tile_idx]) { - const auto& hero = *heroes[tile_idx]; - permutation = tr->permutation; - tiles[&hero] = AllocateShared( - builder, tiling_scheme, - llvm_ir::PrimitiveTypeToIrType( - hero.operand(0)->shape().element_type(), - ir_emitter_context.llvm_module()), - {tiling_scheme.GetBlockTileSizeFor(permutation[TilingScheme::DimX]), - tiling_scheme.GetBlockTileSizeFor(TilingScheme::DimX) + 1}, - absl::StrCat("tr_tile_", tile_idx)); - } + for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) { + permutation = tr.permutation; + auto tile_size = tiling_.GetBlockTileSize(); + ++tile_size.back(); // Prevent bank conflicts. + auto* module = ir_emitter_context.llvm_module(); + tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile( + module, + llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(), + module), + tile_size, absl::StrCat("tr_tile_", tile_idx)); } - TileElementGenerator tile_generator = - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, - std::array tile_dimensions) { - // Copy input parameter values to shared memory buffers: - // tile[thread_id_y, thread_id_x] = input[index] - // Note that tile_width and tile_height are flipped here because we - // are reading a transposed tile. - EmitTile( - builder, tiling_scheme, index, thread_id_info, tile_dimensions, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - // Compute all extra output values before writing them. This - // avoids overwriting aliased input/output values before all reads - // occurred. - std::vector> - scheduled_writes; - - for (const auto& [output_idx, root] : - llvm::enumerate(hlo_roots)) { - if (transposes[output_idx].has_value()) { - const HloInstruction& hero = *heroes[output_idx]; - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*hero.operand(0)); - llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( - index, hero.operand(0)->shape(), builder, - tiling_scheme.GetDimsInElems()); - llvm::Value* value = *input_gen(untiled_index); - llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( - builder, tiles[&hero], {y_loc, x_loc}); - - builder->CreateStore(value, addr); - } else { - llvm_ir::IrArray::Index untiled_index = - GetUnnormalizedIndex(index, root->shape(), builder, - tiling_scheme.GetDimsInElems()); - llvm_ir::ElementGenerator output_gen = - *fused_emitter.GetGenerator(*root); - llvm::Value* output_value = *output_gen(untiled_index); - scheduled_writes.emplace_back(outputs[output_idx], - untiled_index, output_value); - } - } - - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, builder); - } - }); - - EmitSyncThreads(builder, ir_emitter_context); - - llvm_ir::IrArray::Index output_tile_index = - PermuteIndex(index, permutation); - std::array transposed_tile_dimensions = { - tile_dimensions[1], tile_dimensions[0]}; - - EmitTile( - builder, tiling_scheme, output_tile_index, thread_id_info, - transposed_tile_dimensions, - /*emit_elem_function=*/ - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (const auto& [output_idx, root] : - llvm::enumerate(hlo_roots)) { - if (transposes[output_idx].has_value()) { - const HloInstruction& hero = *heroes[output_idx]; - - std::vector idx = {x_loc, y_loc}; - llvm::Value* gep = thread_id_info.GEPIntoSharedMemory( - builder, tiles[&hero], idx); - llvm::Type* type = - thread_id_info.GEPIntoSharedMemoryType(tiles[&hero], idx); - llvm::Value* loaded = - builder->CreateLoad(type, gep, "tiled_buffer"); - - FusedIrEmitter fused_emitter(elemental_emitter); - fused_emitter.BindGenerator( - hero, [&](const llvm_ir::IrArray::Index& index) { - return loaded; - }); - for (int64_t i = 0; - i < fusion.fused_instructions_computation() - ->num_parameters(); - ++i) { - llvm_ir::IrArray ir_array = inputs[i]; - HloInstruction* fused_operand = fusion.fused_parameter(i); - fused_emitter.BindGenerator( - *fused_operand, - [=](const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement( - index, builder, fused_operand->name()); - }); - } - - // Apply codegeneration for the code after the real hero. - TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, - fused_emitter.GetGenerator(*root)); - - // Both for emission and writing it should be - // index-as-transformed by the computation. - llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( - index, root->shape(), builder, - Permute(tiling_scheme.GetDimsInElems(), permutation)); - TF_ASSIGN_OR_RETURN(llvm::Value * generated, - gen(untiled_index)); - outputs[output_idx].EmitWriteArrayElement(untiled_index, - generated, builder); - } - } - return OkStatus(); - }); - }; + auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& tile_start_index, + absl::Span tile_dimensions) { + // Copy input parameter values to shared memory buffers: + // tile[thread_id_y, thread_id_x] = input[index] + EmitTile(builder, tiling_, thread_id_info, tile_dimensions, + [&](absl::Span index_in_tile) { + auto index = tile_start_index.AddOffset(index_in_tile, builder); + for (const auto& tr : transposes) { + auto input_gen = + *fused_emitter.GetGenerator(*tr.instr->operand(0)); + auto input_index = index.SourceIndexOfBitcast( + tr.instr->operand(0)->shape(), builder); + llvm::Value* value = *input_gen(input_index); + tiles[tr.instr].Store(value, index_in_tile, builder); + } + + // Compute all extra output values before writing them. This + // avoids overwriting aliased input/output values before all + // reads occurred. + std::vector> + scheduled_writes; + for (const auto& [output_idx, root] : extra_outputs) { + auto extra_output_index = + index.SourceIndexOfBitcast(root->shape(), builder); + auto output_gen = *fused_emitter.GetGenerator(*root); + llvm::Value* output_value = *output_gen(extra_output_index); + scheduled_writes.emplace_back( + outputs[output_idx], extra_output_index, output_value); + } + + for (const auto& [output, idx, value] : scheduled_writes) { + output.EmitWriteArrayElement(idx, value, builder); + } + }); + + EmitSyncThreads(builder, ir_emitter_context); + + auto output_tile_index = PermuteIndex(tile_start_index, permutation); + auto transposed_tile_dimensions = Permute(tile_dimensions, permutation); + + EmitTile( + builder, tiling_, thread_id_info, transposed_tile_dimensions, + /*emit_elem_function=*/ + [&](absl::Span index_in_tile) { + auto index = output_tile_index.AddOffset(index_in_tile, builder); + for (const auto& tr : transposes) { + llvm::Value* loaded = tiles[tr.instr].Load( + Permute(index_in_tile, permutation), builder); + + FusedIrEmitter fused_emitter(elemental_emitter); + fused_emitter.BindGenerator( + *tr.instr, + [&](const llvm_ir::IrArray::Index&) { return loaded; }); + for (int64_t i = 0; + i < fusion.fused_instructions_computation()->num_parameters(); + ++i) { + llvm_ir::IrArray ir_array = inputs[i]; + HloInstruction* fused_operand = fusion.fused_parameter(i); + fused_emitter.BindGenerator( + *fused_operand, [=](const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); + } + + // Apply code generation for the code after the real hero. + // Compute all output values before writing them. This avoids + // overwriting aliased input/output values before all reads + // occurred. + std::vector> + scheduled_writes; + for (const auto& [output_idx, root] : + transposes_to_roots[tr.instr]) { + TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, + fused_emitter.GetGenerator(*root)); + + // Both for emission and writing it should be + // index-as-transformed by the computation. + auto untiled_index = + index.SourceIndexOfBitcast(root->shape(), builder); + TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index)); + scheduled_writes.emplace_back(outputs[output_idx], untiled_index, + generated); + } + for (const auto& [output, idx, value] : scheduled_writes) { + output.EmitWriteArrayElement(idx, value, builder); + } + } + return absl::OkStatus(); + }); + }; llvm::Type* index_type = GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - return EmitTilingKernel(builder, tiling_scheme, index_type, tile_generator) + return EmitTilingKernel(builder, tiling_, index_type, tile_generator) .status(); } +LaunchDimensions TransposeFusion::launch_dimensions() const { + return LaunchDimensions(tiling_.GetNumBlocks(), + tiling_.GetNumThreadsPerBlock()); +} + +std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + const auto& root = *analysis_.fusion_roots()[root_index]; + if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { + // Non-transpose roots are elementwise by definition. + return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + } + + // The block offsets are permuted, but the thread offsets remain the same. + auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) + .getSubMap(std::vector{permutation_.begin(), + permutation_.end()}); + auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx); + auto permuted_tiled_shape = + ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); + + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling( + block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(), + tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), + permuted_tiled_shape.dimensions()), + GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); + map.Simplify(GetIndexingMapForInstruction); + return map; +} + +std::optional TransposeFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling(tiling_, ctx), + GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); + map.Simplify(GetIndexingMapForInstruction); + return map; +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/transpose.h b/xla/service/gpu/fusions/transpose.h index 93965cfaef684..899b1cb94390a 100644 --- a/xla/service/gpu/fusions/transpose.h +++ b/xla/service/gpu/fusions/transpose.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,22 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ #define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ +#include +#include #include +#include "absl/status/status.h" +#include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -52,24 +60,28 @@ namespace gpu { // efficient to launch fewer blocks so each transposes many tiles. class TransposeFusion : public KernelFusionEmitterBase { public: - explicit TransposeFusion(HloFusionAnalysis& analysis) : analysis_(analysis) {} - StatusOr launch_dimensions( - IrEmitterContext& ir_emitter_context, int kernel_index) const override { - return analysis_.GetLaunchDimensions(); - } + explicit TransposeFusion(const HloFusionAnalysis& analysis); + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const override; protected: - Status EmitKernel(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; + absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const override; private: - HloFusionAnalysis& analysis_; + const HloFusionAnalysis& analysis_; + Tiling tiling_; + Vector3 permutation_; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/transpose_mlir.cc b/xla/service/gpu/fusions/transpose_mlir.cc new file mode 100644 index 0000000000000..9e2e5be2ea564 --- /dev/null +++ b/xla/service/gpu/fusions/transpose_mlir.cc @@ -0,0 +1,368 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/transpose_mlir.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/permutation_util.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using absl::StatusOr; +using llvm::SmallPtrSet; +using llvm::SmallVector; +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::RankedTensorType; +using mlir::Value; +using mlir::ValueRange; +using mlir::func::FuncOp; +using mlir::func::ReturnOp; +using mlir::tensor::ExtractOp; +using mlir::tensor::InsertOp; +using mlir_converter::ApplyAffineMap; +using mlir_converter::CallTargetProvider; +using mlir_converter::PartitionedComputation; + +Tiling ComputeTransposeTiling(const TransposeDescription& tiled_transpose) { + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the output shape. + Vector3 transposed_dims = tiled_transpose.dimensions; + Vector3 permutation = tiled_transpose.permutation; + + // Note: the supported permutations are their own inverses. Therefore we + // always use the permutation, even when we want the inverse. + CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0})); + + absl::InlinedVector input_dims{transposed_dims[permutation[0]], + transposed_dims[permutation[1]], + transposed_dims[permutation[2]]}; + + // We tile along the minor dimensions pre- and post-transpose. + absl::InlinedVector tile_sizes{1, 1, 1}; + tile_sizes[permutation[2]] = WarpSize() / kNumRows; + absl::InlinedVector num_threads{1, 1, WarpSize()}; + num_threads[permutation[2]] = kNumRows; + + return Tiling(input_dims, tile_sizes, num_threads); +} + +// Returns transpose heroes that should be codegened via shmem. +std::vector GetShMemTransposes( + const HloFusionAnalysis& analysis) { + ConstHloInstructionSet transposes_to_tile; + for (const auto [hero, root] : + llvm::zip(analysis.fusion_heroes(), analysis.fusion_roots())) { + if (GetDescriptionForTiledTransposeEmitter(*root, *hero)) { + transposes_to_tile.insert(hero); + } + } + return {transposes_to_tile.begin(), transposes_to_tile.end()}; +} + +} // namespace + +MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + tiling_(ComputeTransposeTiling(analysis.tiled_transpose())), + shmem_transposes_(GetShMemTransposes(analysis)) { + for (auto [root, hero] : + llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) { + if (auto transpose = GetDescriptionForTiledTransposeEmitter(*root, *hero)) { + permutation_ = transpose->permutation; + break; + } + } +} + +std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, MLIRContext* mlir_context) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + // The block offsets are permuted, but the thread offsets remain the same. + auto block_offset = GetBlockOffsetsForTiling(tiling_, mlir_context) + .getSubMap(std::vector{permutation_.begin(), + permutation_.end()}); + auto thread_offset = GetThreadOffsetsForTiling(tiling_, mlir_context); + auto permuted_tiled_shape = + ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); + + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling( + block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(), + tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), + permuted_tiled_shape.dimensions()), + GetBitcastMap(permuted_tiled_shape, hero.shape(), mlir_context)); + map.Simplify(GetIndexingMapForInstruction); + return map; +} + +IndexingMap MlirTransposeFusion::ComputeThreadIdToInputIndexing( + const HloInstruction& hero, MLIRContext* mlir_context) const { + auto map = ComposeIndexingMaps( + GetIndexingMapForTiling(tiling_, mlir_context), + GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), + mlir_context)); + map.Simplify(GetIndexingMapForInstruction); + return map; +} + +std::optional MlirTransposeFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + MLIRContext* mlir_context) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + const auto& root = *analysis_.fusion_roots()[root_index]; + if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { + // Non-transpose roots are elementwise by definition. + return ComputeThreadIdToOutputIndexing(root_index, mlir_context); + } + return ComputeThreadIdToInputIndexing(*analysis_.fusion_heroes()[root_index], + mlir_context); +} + +LaunchDimensions MlirTransposeFusion::launch_dimensions() const { + return LaunchDimensions(tiling_.GetNumBlocks(), + tiling_.GetNumThreadsPerBlock()); +} + +// Returns an indexing map with block_x, block_y, block_z set to 0. +IndexingMap GetSharedMemoryWriteIndexingMap( + const IndexingMap& thread_id_indexing, int loop_dim) { + auto* mlir_context = thread_id_indexing.GetMLIRContext(); + + AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); + AffineExpr th_x = mlir::getAffineDimExpr(0, mlir_context); + SmallVector tile_sizes(3); + mlir::bindSymbolsList(mlir_context, llvm::MutableArrayRef(tile_sizes)); + + IndexingMap shmem_write_indexing{ + AffineMap::get( + thread_id_indexing.GetDimensionCount(), + thread_id_indexing.GetSymbolCount(), + {c0, th_x.floorDiv(32) + 4 * tile_sizes[loop_dim], th_x % 32}, + mlir_context), + thread_id_indexing.GetDimVars(), + thread_id_indexing.GetRangeVars(), + thread_id_indexing.GetRTVars(), + thread_id_indexing.GetConstraints()}; + shmem_write_indexing.Simplify(GetIndexingMapForInstruction); + return shmem_write_indexing; +} + +// Returns an indexing map with block_x, block_y, block_z set to 0 and swapped +// 2nd and 3rd results. +IndexingMap GetSharedMemoryReadIndexingMap( + const IndexingMap& thread_id_indexing, int loop_dim) { + IndexingMap write_indexing = + GetSharedMemoryWriteIndexingMap(thread_id_indexing, loop_dim); + return IndexingMap{write_indexing.GetAffineMap().getSubMap({0, 2, 1}), + write_indexing.GetDimVars(), write_indexing.GetRangeVars(), + write_indexing.GetRTVars(), + write_indexing.GetConstraints()}; +} + +absl::StatusOr> MlirTransposeFusion::EmitWriteToShMemMlir( + mlir::ImplicitLocOpBuilder& builder, FuncOp entry_function, + const HloFusionInstruction& fusion, + const PartitionedComputation& root_computation, + const CallTargetProvider& call_target_provider) const { + std::vector shmem_tensor_size(tiling_.GetBlockTileSize().begin(), + tiling_.GetBlockTileSize().end()); + + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + int num_outputs = entry_function.getArguments().size() - num_inputs; + + MLIRContext* mlir_context = builder.getContext(); + SmallVector shmem_intermediate_result; + for (auto* transpose : shmem_transposes_) { + auto input_indexing = + ComputeThreadIdToInputIndexing(*transpose, mlir_context); + IndexingMap shmem_input_indexing = + GetSharedMemoryWriteIndexingMap(input_indexing, permutation_[2]); + + // Allocate shared memory. + const HloInstruction* transpose_operand = transpose->operand(0); + auto elem_type = *ConvertPrimitiveTypeToMlirType( + transpose_operand->shape().element_type(), builder); + auto shmem = builder.create( + RankedTensorType::get(shmem_tensor_size, elem_type)); + + // Emit loop that writes subgraphs of transpose operands to shmem. + auto shmem_result = EmitThreadLoopNest( + builder, {shmem}, input_indexing, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto input_indices = + ApplyAffineMap(input_indexing.GetAffineMap(), dim_values, + symbol_values, builder); + auto shmem_indices = + ApplyAffineMap(shmem_input_indexing.GetAffineMap(), dim_values, + symbol_values, builder); + + auto result_scalars = mlir_converter::ProvideParameter( + root_computation.FindSubgraph(transpose), transpose, + /*operand_index=*/0, input_indices, call_target_provider, + entry_function, builder); + + SmallVector result_tensors; + result_tensors.reserve(num_outputs); + for (auto [tensor, value] : + llvm::zip(output_tensors, result_scalars)) { + result_tensors.push_back( + builder.create(value, tensor, shmem_indices)); + } + return result_tensors; + }); + shmem_intermediate_result.append(shmem_result.begin(), shmem_result.end()); + } + + return shmem_intermediate_result; +} + +absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( + mlir::ImplicitLocOpBuilder& builder, FuncOp entry_function, + const HloFusionInstruction& fusion, + const mlir_converter::PartitionedComputations& computations, + const CallTargetProvider& call_targets, ValueRange shmem_tensors) const { + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + auto* mlir_context = builder.getContext(); + ValueRange output_tensor_args = + entry_function.getArguments().drop_front(num_inputs); + auto output_indexing = *ComputeThreadIdToOutputIndexing(0, mlir_context); + auto shmem_output_indexing = + GetSharedMemoryReadIndexingMap(output_indexing, permutation_[2]); + auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( + analysis_.fusion_heroes()[0], mlir_context); + auto root_indexing = ComposeIndexingMaps(output_indexing, epilogue_indexing); + auto result_tensors = EmitThreadLoopNest( + builder, output_tensor_args, output_indexing, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto shmem_indices = + ApplyAffineMap(shmem_output_indexing.GetAffineMap(), dim_values, + symbol_values, builder); + llvm::SmallVector transpose_values; + for (auto shmem : shmem_tensors) { + transpose_values.push_back( + builder.create(shmem, shmem_indices)); + } + auto root_indices = ApplyAffineMap(root_indexing.GetAffineMap(), + dim_values, symbol_values, builder); + auto result_scalars = + EmitEpilogue(computations, entry_function, transpose_values, + root_indices, builder); + SmallVector results; + results.reserve(output_tensor_args.size()); + const auto& first_shape = analysis_.fusion_roots().front()->shape(); + for (auto [tensor, value, root] : llvm::zip( + output_tensors, result_scalars, analysis_.fusion_roots())) { + llvm::SmallVector indices; + if (ShapeUtil::EqualIgnoringElementType(first_shape, root->shape())) { + indices = root_indices; + } else { + auto bitcast_map = + GetBitcastMap(first_shape, root->shape(), mlir_context); + indices = ApplyAffineMap(bitcast_map.GetAffineMap(), root_indices, + {}, builder); + } + results.push_back(builder.create(value, tensor, indices)); + } + return results; + }); + + builder.create(result_tensors); + return absl::OkStatus(); +} + +std::vector +MlirTransposeFusion::GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const { + return GetShMemTransposes(analysis_); +} + +absl::Status MlirTransposeFusion::EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + const auto& root_computation = computations.FindPartitionedComputation( + fusion.fused_instructions_computation()); + // Write intermediate results to shmem. + mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + builder.setInsertionPointToStart(entry_function.addEntryBlock()); + TF_ASSIGN_OR_RETURN(auto shmem_tensors, + EmitWriteToShMemMlir(builder, entry_function, fusion, + root_computation, call_targets)); + // Sync GPU threads before reading from shmem. + auto sync_threads = builder.create( + mlir::TypeRange(shmem_tensors), shmem_tensors); + + // Read intermediate results from shmem and compute epilogues. + return EmitReadFromShMemMlir(builder, entry_function, fusion, computations, + call_targets, sync_threads.getResults()); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/transpose_mlir.h b/xla/service/gpu/fusions/transpose_mlir.h new file mode 100644 index 0000000000000..8329cdd852ae6 --- /dev/null +++ b/xla/service/gpu/fusions/transpose_mlir.h @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_MLIR_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_MLIR_H_ + +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/status.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +// Lowers kTranspose fusion to LLVM via MLIR using GPU's shared memory. + +// Each thread block of `kWarpSize` x `kNumRows` threads +// transposes one tile: each thread copies kWarpSize/kNumRows elements from +// the input to a shared memory tile. + +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +class MlirTransposeFusion : public MlirFusionEmitterBase { + public: + explicit MlirTransposeFusion(const HloFusionAnalysis& analysis); + LaunchDimensions launch_dimensions() const override; + + std::optional ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* mlir_context) const override; + + std::optional ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* mlir_context) const override; + + protected: + IndexingMap ComputeThreadIdToInputIndexing( + const HloInstruction& hero, mlir::MLIRContext* mlir_context) const; + + absl::Status EmitEntryFunction( + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; + + std::vector GetInstructionsWithCustomCodegen( + const HloFusionInstruction& fusion) const override; + + absl::StatusOr> EmitWriteToShMemMlir( + mlir::ImplicitLocOpBuilder& builder, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion, + const mlir_converter::PartitionedComputation& root_computation, + const mlir_converter::CallTargetProvider& call_target_provider) const; + absl::Status EmitReadFromShMemMlir( + mlir::ImplicitLocOpBuilder& builder, mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion, + const mlir_converter::PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_target_provider, + mlir::ValueRange shmem_tensors) const; + + private: + const HloFusionAnalysis& analysis_; + Tiling tiling_; + Vector3 permutation_; + std::vector shmem_transposes_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_MLIR_H_ diff --git a/xla/service/gpu/fusions/transpose_mlir_test.cc b/xla/service/gpu/fusions/transpose_mlir_test.cc new file mode 100644 index 0000000000000..2b3ddc04e5a3e --- /dev/null +++ b/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -0,0 +1,434 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/transpose_mlir.h" + +#include +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using MlirTransposeFusionTest = MlirEmitterTestBase; + +TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,32,64] parameter(0) + ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1} + } + ENTRY entry { + %input = f32[100,32,64] parameter(0) + ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion + } + )")); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + MlirTransposeFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + s1 * 4, + (d3 mod 2) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + (d3 mod 2) * 32 + s1 * 4, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); +} + +TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,64,32] parameter(0) + ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + } + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + })")); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + MlirTransposeFusion fusion(analysis); + + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d0 floordiv 32 + s1 * 4, + d3 floordiv 2, + (d3 mod 2) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); +} + +TEST_F(MlirTransposeFusionTest, FusedTranspose021) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %p0 = f32[20,160,170] parameter(0) + %exp = f32[20,160,170] exponential(%p0) + %transpose = f32[20,170,160] transpose(%exp), dimensions={0,2,1} + ROOT %abs = f32[20,170,160] abs(%transpose) + } + ENTRY main { + %param = f32[20,160,170] parameter(0) + ROOT %fusion = f32[20,170,160] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> + // + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + + // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x32xf32> + // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) + // CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fused_computation_exp + // CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] + + // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] + + // CHECK: scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) + // CHECK: %[[ABS:.*]] = xla_gpu.pure_call @fused_computation__epilogue__ + // CHECK: tensor.insert %[[ABS]] into %[[OUT_]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose021_Parameter) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %p0 = f32[20,160,170] parameter(0) + %transpose = f32[20,170,160] transpose(%p0), dimensions={0,2,1} + ROOT %abs = f32[20,170,160] abs(%transpose) + } + ENTRY main { + %param = f32[20,160,170] parameter(0) + ROOT %fusion = f32[20,170,160] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> + // + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + + // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x32xf32> + // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) + // CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fused_computation_p0 + // CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] + + // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] + + // CHECK: scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) + // CHECK: %[[ABS:.*]] = xla_gpu.pure_call @fused_computation__epilogue__ + // CHECK: tensor.insert %[[ABS]] into %[[OUT_]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose021_NoEpilogue) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %p0 = f32[20,160,170] parameter(0) + ROOT %transpose = f32[20,170,160] transpose(%p0), dimensions={0,2,1} + } + ENTRY main { + %param = f32[20,160,170] parameter(0) + ROOT %fusion = f32[20,170,160] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> + // + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x32xf32> + // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) + // CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fused_computation_p0 + // CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] + + // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] + + // CHECK: scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) + // CHECK: %[[SHMEM_ELEM:.*]] = tensor.extract %[[SYNC]] + // CHECK: tensor.insert %[[SHMEM_ELEM]] into %[[OUT_]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose_4D) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f64[2,24,6,4] parameter(0) + ROOT %transpose= f64[6,4,2,24] transpose(f64[2,24,6,4] %param_0), + dimensions={2,3,0,1} + } + ENTRY main { + %param = f64[2,24,6,4] parameter(0) + ROOT %fusion = f64[6,4,2,24] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose_2D) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f64[100, 200] parameter(0) + ROOT %transpose= f64[200,100] transpose(f64[100, 200] %param_0), + dimensions={1,0} + } + ENTRY main { + %param = f64[100, 200] parameter(0) + ROOT %fusion = f64[200,100] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose_2D_2) { + auto kHloString = R"( + HloModule m + + %fused_computation { + %p0 = f32[17,2820]{0,1} parameter(0) + %p1 = f32[30,17,94] parameter(1) + + %bitcast0 = f32[2,3,5,17,94] bitcast(f32[30,17,94] %p1) + %transpose = f32[2,3,5,94,17] transpose(f32[2,3,5,17,94] %bitcast0), dimensions={0,1,2,4,3} + %bitcast1 = f32[2820,17]{1,0} bitcast(f32[2,3,5,94,17] %transpose) + %bitcast2 = f32[2820,17]{1,0} bitcast(f32[17,2820]{0,1} %p0) + %neg = f32[2820,17]{1,0} negate(f32[2820,17] %bitcast2) + ROOT %add = f32[2820,17]{1,0} add(f32[2820,17] %bitcast1, f32[2820,17]{1,0} %neg) + } + + ENTRY main { + %p1 = f32[30,17,94]{2,1,0} parameter(1) + %p0 = f32[17,2820]{0,1} parameter(0) + ROOT %fusion = f32[2820,17]{1,0} fusion(%p0, %p1), kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, MultipleRootsForTranspose) { + auto kHloString = R"( + HloModule m + + %fused_computation { + %iota.0 = s32[200,200] iota(), iota_dimension=1 + %iota.1 = s32[200,200] iota(), iota_dimension=0 + %compare = pred[200,200] compare(%iota.0, %iota.1), direction=GE + %transpose = pred[200,200] transpose(%compare), dimensions={1,0} + %copy = pred[200,200] copy(%transpose) + %copy.1 = pred[200,200] copy(%transpose) + ROOT %tuple = (pred[200,200], pred[200,200], pred[200,200]{1,0}) + tuple(%transpose, %copy, %copy.1) + } + + ENTRY main { + ROOT %fusion = + (pred[200,200]{1,0}, pred[200,200]{1,0}, pred[200,200]{1,0}) + fusion(), kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, PartialTile) { + auto kHloString = R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,6,4] parameter(0) + ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + } + + ENTRY main { + %p0 = f64[24,2,6,4] parameter(0) + ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, MixedIndexing) { + auto kHloString = R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,6,4] parameter(0) + %bc = f64[24,2,24] bitcast(%p0) + %t1 = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %t2 = f64[24,2,24] transpose(%bc), dimensions={2,1,0} + %p1 = f64[] parameter(1) + %bc1 = f64[6,4,2,24] broadcast(%p1), dimensions={} + %bc2 = f64[24,2,24] broadcast(%p1), dimensions={} + %a1 = f64[6,4,2,24] add(%t1, %bc1) + %a2 = f64[24,2,24] add(%t2, %bc2) + ROOT %t = (f64[6,4,2,24], f64[24,2,24]) tuple(%a1, %a2) + } + + ENTRY main { + %p0 = f64[24,2,6,4] parameter(0) + %p1 = f64[] parameter(1) + ROOT %fusion = (f64[6,4,2,24], f64[24,2,24]) fusion(%p0, %p1), + kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, SideOutputs) { + auto kHloString = R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,36] parameter(0) + %p1 = f64[36,2,24] parameter(1) + %tr = f64[36,2,24] transpose(%p0), dimensions={2,1,0} + %neg = f64[36,2,24] negate(%p1) + %log = f64[24,2,36] log(%p0) + ROOT %t = (f64[36,2,24], f64[36,2,24], f64[24,2,36]) + tuple(%neg, %tr, %log) + } + + ENTRY main { + %p0 = f64[24,2,36] parameter(0) + %p1 = f64[36,2,24] parameter(1) + ROOT %fusion = (f64[36,2,24], f64[36,2,24], f64[24,2,36]) + fusion(%p0, %p1), kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/transpose_test.cc b/xla/service/gpu/fusions/transpose_test.cc new file mode 100644 index 0000000000000..d7363bbd39f38 --- /dev/null +++ b/xla/service/gpu/fusions/transpose_test.cc @@ -0,0 +1,251 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/transpose.h" + +#include +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::HasSubstr; + +class TransposeTest : public HloTestBase { + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); +}; + +absl::StatusOr> GetTransposeFusion( + const HloFusionAnalysis& analysis) { + TF_ASSIGN_OR_RETURN( + auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis})); + auto fusion = dynamic_cast(emitter.get()); + TF_RET_CHECK(fusion != nullptr); + + emitter.release(); + return std::unique_ptr{fusion}; +} + +TEST_F(TransposeTest, ThreadIndexing021) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,32,64] parameter(0) + ROOT transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1} + } + + ENTRY entry { + %input = f32[100,32,64] parameter(0) + ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + s1 * 4, + (d3 mod 2) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + (d3 mod 2) * 32 + s1 * 4, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); +} + +TEST_F(TransposeTest, ThreadIndexing201) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule module + + fusion { + %input = f32[100,64,32] parameter(0) + ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + } + + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + })") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); + mlir::MLIRContext mlir_context; + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d3 floordiv 2, + d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d0 floordiv 32 + s1 * 4, + d3 floordiv 2, + (d3 mod 2) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] + )")); +} + +TEST_F(TransposeTest, ThreadIndexingPartialBlock) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,6,4] parameter(0) + ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + } + + ENTRY main { + %p0 = f64[24,2,6,4] parameter(0) + ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, + calls=%fused_computation + } + )") + .value(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + + TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); + mlir::MLIRContext mlir_context; + EXPECT_THAT( + fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + d0 floordiv 32 + s0 * 4, + d3, + (d0 floordiv 4) mod 8, + d0 mod 4 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 7] + s1 in [0, 0] + s2 in [0, 0] + d0 floordiv 32 + s0 * 4 in [0, 23] + d0 mod 32 in [0, 23] + )")); + EXPECT_THAT( + fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + s0, + d0 floordiv 32, + d3, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 7] + s1 in [0, 0] + s2 in [0, 0] + d0 floordiv 32 + s0 * 4 in [0, 23] + d0 mod 32 in [0, 23] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/triton.cc b/xla/service/gpu/fusions/triton.cc new file mode 100644 index 0000000000000..ebbaccdb0bd74 --- /dev/null +++ b/xla/service/gpu/fusions/triton.cc @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/triton.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/gpu/ir_emitter_triton.h" +#else +#include "absl/status/status.h" +#endif + +namespace xla { +namespace gpu { +namespace { + +// Derives the number of blocks and threads to use for processing a Triton +// Softmax fusion. +LaunchDimensions CalculateSoftMaxLaunchDimensions( + const HloFusionAdaptor& fusion) { + auto reduce = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { + return node.opcode() == HloOpcode::kReduce; + }); + + CHECK(reduce.has_value()); + const Shape& reduce_input_shape = reduce->GetOperand(0).instruction().shape(); + + CHECK_EQ(reduce->instruction().dimensions().size(), 1); + CHECK_EQ(reduce->instruction().dimensions()[0], + reduce_input_shape.rank() - 1); + + int reduction_dim = reduce_input_shape.dimensions_minor(0); + + unsigned num_rows = 1; + for (unsigned minor_axis = 1; minor_axis < reduce_input_shape.rank(); + ++minor_axis) { + num_rows *= reduce_input_shape.dimensions_minor(minor_axis); + } + + unsigned num_warps = 32; + + if (reduction_dim <= 512) { + num_warps = 1; + } else if (reduction_dim <= 1024) { + num_warps = 2; + } else if (reduction_dim <= 16384) { + num_warps = 4; + } else if (reduction_dim <= 32768) { + num_warps = 8; + } else if (reduction_dim <= 65536) { + num_warps = 16; + } + + return {num_rows, static_cast(num_warps * WarpSize())}; +} + +} // namespace + +absl::StatusOr TritonFusion::Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const { + llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext()); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + VLOG(3) << fusion.ToString(); + std::string suggested_kernel_name = std::string(fusion.name()); + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); + + const HloComputation* hlo_computation = + fusion.fused_instructions_computation(); + + auto generate = [&]() -> absl::StatusOr { + VLOG(3) << "Generating: " << suggested_kernel_name; + + const std::string impl_fn_name = + ir_emitter_context.name_uniquer()->GetUniqueName( + llvm_ir::SanitizeFunctionName( + absl::StrCat(suggested_kernel_name, "_impl"))); + + auto backend_config = analysis_.fusion_backend_config(); + absl::string_view fusion_kind = backend_config.kind(); + + TritonWrapperResult triton_wrapper_result; + LaunchDimensions launch_dimensions; + if (fusion_kind == kTritonSoftmaxFusionKind) { + launch_dimensions = *this->launch_dimensions(); + + // This is a hack, we use TritonGemmConfig for Softmax too, but we ignore + // most parameters. + TritonGemmConfig config; + config.num_stages = 1; + // Thread count per block is always a multiple of WarpSize. + config.num_warps = launch_dimensions.num_threads_per_block() / WarpSize(); + config.num_ctas = 1; + + TF_ASSIGN_OR_RETURN(auto analysis, + TritonFusionAnalysis::Execute(*hlo_computation)); + TF_ASSIGN_OR_RETURN( + triton_wrapper_result, + TritonWrapper(analysis, impl_fn_name, hlo_computation, + ir_emitter_context.gpu_compute_capability(), + ir_emitter_context.gpu_device_info(), config, + ir_emitter_context.llvm_module(), &EmitSoftMax, + *ir_emitter_context.mlir_context())); + } else { // Must be a MatMul + CHECK_EQ(fusion_kind, kTritonGemmFusionKind); + if (!backend_config.has_triton_gemm_config()) { + LOG(WARNING) << "Using fallback triton GEMM config for op " + << fusion.name(); + auto& triton_config = *backend_config.mutable_triton_gemm_config(); + triton_config.set_block_m(64); + triton_config.set_block_k(64); + triton_config.set_block_n(64); + triton_config.set_split_k(1); + triton_config.set_num_stages(1); + triton_config.set_num_warps(2); + triton_config.set_num_ctas(1); + } + TF_ASSIGN_OR_RETURN( + TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + + TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute( + *hlo_computation, config.split_k)); + TF_ASSIGN_OR_RETURN( + triton_wrapper_result, + TritonWrapper(analysis, impl_fn_name, hlo_computation, + ir_emitter_context.gpu_compute_capability(), + ir_emitter_context.gpu_device_info(), config, + ir_emitter_context.llvm_module(), &EmitMatMul, + *ir_emitter_context.mlir_context())); + TF_ASSIGN_OR_RETURN( + launch_dimensions, + GetMatMulLaunchDimensions(analysis, analysis_.fusion(), config)); + } + + llvm::Function* impl_fn = + ir_emitter_context.llvm_module()->getFunction(impl_fn_name); + TF_RET_CHECK(impl_fn); + + llvm::Function* kernel; + std::vector inputs; + std::vector outputs; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototype(ir_emitter_context, suggested_kernel_name, + kernel_arguments.args(), impl_fn->arg_size(), + launch_dimensions, &builder)); + + // Move function body into kernel prototype. + llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); + prototype_func->splice(prototype_func->begin(), impl_fn); + for (const auto& [arg, ir_array] : llvm::zip(impl_fn->args(), inputs)) { + arg.replaceAllUsesWith(ir_array.GetBasePointer()); + } + impl_fn->eraseFromParent(); + + return {{kernel->getName().str(), launch_dimensions, + triton_wrapper_result.cluster_dim, + triton_wrapper_result.shmem_bytes}}; + }; + + auto [status_or_entry, was_cached] = + ir_emitter_context.kernel_cache().GetWithStatus( + hlo_computation, kernel_arguments.args(), + /*discriminator=*/"", generate); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); + + FusionEmissionResult result; + result.thunks.emplace_back(std::make_unique( + &fusion, entry->kernel_name, kernel_arguments.args(), + entry->launch_dimensions, entry->cluster_dim, entry->shmem_bytes)); + + return result; +#else + return absl::UnimplementedError("Triton support requires CUDA or ROCm"); +#endif +} + +std::optional TritonFusion::launch_dimensions() const { + if (analysis_.fusion_backend_config().kind() == kTritonSoftmaxFusionKind) { + return CalculateSoftMaxLaunchDimensions(analysis_.fusion()); + } + + // MatMul is not yet supported. + return std::nullopt; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/triton.h b/xla/service/gpu/fusions/triton.h new file mode 100644 index 0000000000000..c584ca8c56fb7 --- /dev/null +++ b/xla/service/gpu/fusions/triton.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_H_ + +#include + +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/statusor.h" + +namespace xla { +namespace gpu { + +class TritonFusion : public FusionInterface { + public: + explicit TritonFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis) {} + + absl::StatusOr Emit( + IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion) const final; + + // Returns the launch dimensions for softmax fusions. Not supported for + // MatMul fusions. + std::optional launch_dimensions() const; + + private: + const HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_H_ diff --git a/xla/service/gpu/fusions/triton_test.cc b/xla/service/gpu/fusions/triton_test.cc new file mode 100644 index 0000000000000..49b8f6d966444 --- /dev/null +++ b/xla/service/gpu/fusions/triton_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/triton.h" + +#include + +#include +#include +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class TritonFusionTest : public HloTestBase {}; + +TEST_F(TritonFusionTest, TritonSoftmaxFusion) { +#ifndef GOOGLE_CUDA + GTEST_SKIP() << "Triton fusion only enable for CUDA devices."; +#endif + + auto module = ParseAndReturnVerifiedModule(R"( + HloModule t + + add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) + } + + auxiliary_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0} + } + + triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) + } + + ENTRY main { + param_0 = f32[125]{0} parameter(0) + auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation + ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} + })") + .value(); + + stream_executor::GpuDeviceInfoProto device_info_proto; + stream_executor::DeviceDescription device_info(device_info_proto); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis_fused = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter_fused, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused})); + auto triton_fusion = dynamic_cast(emitter_fused.get()); + ASSERT_NE(triton_fusion, nullptr); + auto launch_dims = triton_fusion->launch_dimensions(); + ASSERT_NE(launch_dims, std::nullopt); + EXPECT_EQ(launch_dims->num_blocks(), 125); + EXPECT_EQ(launch_dims->num_threads_per_block(), 32); + + auto analysis_consumer = AnalyzeFusion(*root, device_info); + + TF_ASSERT_OK_AND_ASSIGN( + auto emitter_consumer, + GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_consumer})); + ASSERT_NE(dynamic_cast(emitter_consumer.get()), nullptr); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index 2823e16fcd037..bafe75abddace 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,37 +15,44 @@ limitations under the License. #include "xla/service/gpu/gemm_algorithm_picker.h" -#include +#include #include -#include -#include +#include #include #include -#include -#include #include #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/gpu/variant_visitor.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/logger.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/util/proto/proto_utils.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/buffer_comparator.h" @@ -53,150 +60,11 @@ limitations under the License. namespace xla { namespace gpu { - -// Returns the index (into `algorithms`) of the fastest algorithm. -template -StatusOr GetBestAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, const Shape& output_shape, - const HloModuleConfig& hlo_module_config, double beta, - const std::function(const AlgoT&)>& - run_benchmark) { - if (!stream->parent()->SynchronizeAllActivity()) { - return InternalError("Failed to synchronize GPU for autotuning."); - } - - se::DeviceMemoryBase reference_buffer; - if (autotune_config.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - reference_buffer, - allocator.AllocateBytes(ShapeUtil::ByteSizeOf(output_shape))); - } - - BufferComparator comparator(output_shape, hlo_module_config); - - std::vector results; - std::optional reference_algorithm; - - for (const AlgoT& algorithm : algorithms) { - // Make sure the output buffer always has the same value if we use - // the bias parameter. - if (autotune_config.should_reinit_output_buffer() && beta != 0) { - int64_t rng_state = 0; - InitializeBuffer(stream, output_shape.element_type(), &rng_state, - output_buffer); - } - - TF_ASSIGN_OR_RETURN(se::blas::ProfileResult profile_result, - run_benchmark(algorithm)); - - results.emplace_back(); - AutotuneResult& result = results.back(); - result.mutable_gemm()->set_algorithm(profile_result.algorithm()); - - if (!profile_result.is_valid()) { // Unsupported algorithm. - result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); - continue; - } - - VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " - << profile_result.elapsed_time_in_ms() << "ms"; - - *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( - absl::Milliseconds(profile_result.elapsed_time_in_ms())); - - if (!autotune_config.should_check_correctness()) { - continue; - } -#if GOOGLE_CUDA // redzone check is not yet available on ROCm - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones()); - - if (!rz_check_status.ok()) { - result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); - *result.mutable_failure()->mutable_msg() = - rz_check_status.RedzoneFailureMsg(); - LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; - CHECK(!autotune_config.should_crash_on_check_failure()); - continue; - } -#endif // GOOGLE_CUDA - - if (!reference_algorithm) { - stream->ThenMemcpy(&reference_buffer, output_buffer, - output_buffer.size()); - reference_algorithm = profile_result.algorithm(); - } else { - // Perform the comparison. - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual(stream, /*current=*/output_buffer, - /*expected=*/reference_buffer)); - if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " - << "This is likely a bug/unexpected loss of precision."; - CHECK(!autotune_config.should_crash_on_check_failure()); - - result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT); - result.mutable_failure()->mutable_reference_gemm()->set_algorithm( - *reference_algorithm); - } - } - } - - if (!autotune_config.should_crash_on_check_failure()) { - AutotuningLog log; - for (const AutotuneResult& result : results) { - *log.add_results() = result; - } - tsl::Logger::GetSingleton()->LogProto(log); - } - - StatusOr best = - PickBestResult(results, gemm_str, hlo_module_config); - if (best.ok()) { - for (size_t i = 0; i < results.size(); ++i) { - if (best->gemm().algorithm() == results[i].gemm().algorithm()) { - best->mutable_gemm()->set_algorithm(i); - return best; - } - } - return InternalError("unknown best algorithm"); - } - - LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " - "might be suboptimal: " - << best.status(); - return AutotuneResult{}; -} - -// Select the best algorithm using information from a Blas instruction. -// Returns the index (into `algorithms`) of the fastest algorithm. -StatusOr GetBestBlasAlgorithm( - se::Stream* stream, se::RedzoneAllocator& allocator, - std::optional gemm_str, - const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - absl::Span algorithms, - const Shape& output_shape, const HloModuleConfig& hlo_module_config, - double beta, - const std::function( - const se::blas::AlgorithmType&)>& run_benchmark) { - return GetBestAlgorithm( - stream, allocator, gemm_str, autotune_config, lhs_buffer, rhs_buffer, - output_buffer, algorithms, output_shape, hlo_module_config, beta, - run_benchmark); -} - namespace { using se::gpu::BlasLt; -StatusOr AsBlasLtEpilogue( +absl::StatusOr AsBlasLtEpilogue( GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: @@ -216,211 +84,376 @@ StatusOr AsBlasLtEpilogue( case GemmBackendConfig::BIAS_GELU_AUX: return BlasLt::Epilogue::kBiasThenGELUWithAux; default: - return InternalError("Unsupported Epilogue."); + return Internal("Unsupported Epilogue."); } } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -StatusOr DoGemmAutotuneNoCache( - const HloInstruction* gemm, const AutotuneCacheKey& key, - const AutotuneConfig& autotune_config) { - if (autotune_config.IsDeviceless()) { - // Return empty result, will tune at runtime. - return AutotuneResult{}; +class GemmAutotuner { + const AutotuneConfig& autotune_config_; + se::DeviceMemoryBase lhs_buffer_, rhs_buffer_, output_buffer_; + std::unique_ptr redzone_allocator_; + se::Stream* stream_ = nullptr; + bool deterministic_ops_ = false; + int64_t rng_state_ = 0; + + public: + explicit GemmAutotuner(const AutotuneConfig& autotune_config) + : autotune_config_(autotune_config) {} + + absl::StatusOr operator()(const HloInstruction* gemm, + const AutotuneCacheKey& key) { + if (autotune_config_.IsDeviceless()) { + // Return empty result, will tune at runtime. + return AutotuneResult{}; + } + VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); + + TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream()); + const DebugOptions& debug_options = + gemm->GetModule()->config().debug_options(); + deterministic_ops_ = debug_options.xla_gpu_deterministic_ops(); + + TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm)); + + // Don't run autotuning concurrently on the same GPU. + absl::MutexLock gpu_lock(&GetGpuMutex(stream_->parent())); + + TF_ASSIGN_OR_RETURN(auto buf_alloc, AutotunerUtil::CreateRedzoneAllocator( + autotune_config_, debug_options)); + redzone_allocator_ = + std::make_unique(std::move(buf_alloc)); + + TF_ASSIGN_OR_RETURN(lhs_buffer_, CreateBuffer(gemm->operand(0)->shape())); + TF_ASSIGN_OR_RETURN(rhs_buffer_, CreateBuffer(gemm->operand(1)->shape())); + TF_ASSIGN_OR_RETURN(output_buffer_, CreateBuffer(GetOutputShape(gemm))); + + return IsCublasLtMatmul(*gemm) || IsCublasLtMatmulF8(*gemm) + ? TuneGpuBlasLt(gemm, gemm_config) + : TuneGpuBlas(gemm, gemm_config); } - VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); - se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); - GemmBackendConfig gemm_config = - gemm->backend_config().value(); - const DebugOptions& debug_options = - gemm->GetModule()->config().debug_options(); - const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); - - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); - // Don't run autotuning concurrently on the same GPU. - absl::MutexLock gpu_lock(&GetGpuMutex(stream->parent())); - - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator buffer_allocator, - AutotunerUtil::CreateRedzoneAllocator(autotune_config, debug_options)); - - int64_t rng_state = 0; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase lhs_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, gemm->operand(0)->shape(), - autotune_config, rng_state)); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase rhs_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, gemm->operand(1)->shape(), - autotune_config, rng_state)); - - const Shape& output_shape = - gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase output_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, output_shape, - autotune_config, rng_state)); - - int64_t workspace_size = - autotune_config.GetCudaComputeCapability().IsAtLeastHopper() - ? GemmConfig::kHopperWorkspace - : GemmConfig::kDefaultWorkspace; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase workspace_buffer, - AutotunerUtil::CreateBuffer(buffer_allocator, - ShapeUtil::MakeShape(S8, {workspace_size}), - autotune_config, rng_state)); - - HloModuleConfig& hlo_module_config = gemm->GetModule()->mutable_config(); - AutotuneResult best_algorithm; - if (IsCublasLtMatmul(*gemm)) { - bool has_matrix_bias = config.beta != 0.; + private: + const Shape& GetOutputShape(const HloInstruction* gemm) { + return gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) + : gemm->shape(); + } + + absl::StatusOr CreateBuffer(const Shape& shape) { + return AutotunerUtil::CreateBuffer(*redzone_allocator_, shape, + autotune_config_, rng_state_); + } + + absl::StatusOr TuneGpuBlasLt(const HloInstruction* gemm, + const GemmConfig& gemm_config) { + GpuBackendConfig gpu_config = + gemm->backend_config().value(); + const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config(); + + bool has_matrix_bias = gemm_config.beta != 0.; TF_ASSIGN_OR_RETURN( bool has_vector_bias, - xla::gpu::gpublas_lt::EpilogueAddsVectorBias(gemm_config.epilogue())); + gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue())); - TF_ASSIGN_OR_RETURN(bool has_aux_output, - xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput( - gemm_config.epilogue())); + TF_ASSIGN_OR_RETURN( + bool has_aux_output, + gpublas_lt::EpilogueHasAuxiliaryOutput(backend_config.epilogue())); TF_ASSIGN_OR_RETURN(auto epilogue, - AsBlasLtEpilogue(gemm_config.epilogue())); + AsBlasLtEpilogue(backend_config.epilogue())); + + se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, + d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer; - se::DeviceMemoryBase bias_buffer; if (has_vector_bias) { TF_ASSIGN_OR_RETURN( bias_buffer, - AutotunerUtil::CreateBuffer( - buffer_allocator, gemm->operand(has_matrix_bias ? 3 : 2)->shape(), - autotune_config, rng_state)); + CreateBuffer(gemm->operand(has_matrix_bias ? 3 : 2)->shape())); } - se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer, - d_scale_buffer, d_amax_buffer; - - se::DeviceMemoryBase aux_buffer; if (has_aux_output) { - TF_ASSIGN_OR_RETURN( - aux_buffer, AutotunerUtil::CreateBuffer(buffer_allocator, - gemm->shape().tuple_shapes(1), - autotune_config, rng_state)); + TF_ASSIGN_OR_RETURN(aux_buffer, + CreateBuffer(gemm->shape().tuple_shapes(1))); } TF_ASSIGN_OR_RETURN(auto plan, - BlasLt::GetMatmulPlan(stream, config, epilogue)); + BlasLt::GetMatmulPlan(stream_, gemm_config, epilogue)); TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms()); + auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm) + -> absl::StatusOr { + se::OwningScratchAllocator<> scratch_allocator( + stream_->parent()->device_ordinal(), autotune_config_.GetAllocator()); + // Run a warmup iteration without the profiler active. + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_, + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, + scratch_allocator)); + se::blas::ProfileResult profile_result; + profile_result.set_warmup_run_executed(true); + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_, + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, + scratch_allocator, &profile_result)); + return std::move(profile_result); + }; + + return GetBestAlgorithm( + gemm, algorithms, gemm_config.beta, tuned_func); + } + + absl::StatusOr TuneGpuBlas(const HloInstruction* gemm, + const GemmConfig& gemm_config) { + int64_t workspace_size = + std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastHopper() + ? GemmConfig::kHopperWorkspace + : GemmConfig::kDefaultWorkspace; + }, + [](const se::RocmComputeCapability&) { + return GemmConfig::kDefaultWorkspace; + }}, + autotune_config_.GetGpuComputeCapability()); + TF_ASSIGN_OR_RETURN( - best_algorithm, - GetBestAlgorithm( - stream, buffer_allocator, gemm->ToString(), autotune_config, - lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const BlasLt::MatmulAlgorithm& algorithm) - -> StatusOr { - se::OwningScratchAllocator<> scratch_allocator( - stream->parent()->device_ordinal(), allocator); - se::blas::ProfileResult profile_result; - TF_RETURN_IF_ERROR(plan->ExecuteOnStream( - stream, lhs_buffer, rhs_buffer, output_buffer, output_buffer, - bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, - c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, - scratch_allocator, &profile_result)); - return std::move(profile_result); - })); - } else { + auto workspace_buffer, + CreateBuffer(ShapeUtil::MakeShape(S8, {workspace_size}))); + std::vector algorithms; - TF_RET_CHECK(stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)); + TF_ASSIGN_OR_RETURN(GemmConfig::DescriptorsTuple desc, + gemm_config.GetMatrixDescriptors( + lhs_buffer_, rhs_buffer_, output_buffer_)); + + auto blas = stream_->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No BLAS support for stream"); + } + blas->GetBlasGemmAlgorithms(stream_, desc.lhs, desc.rhs, &desc.output, + &gemm_config.alpha, &gemm_config.beta, + &algorithms); -#if TENSORFLOW_USE_ROCM // Blas gemm algorithms are not yet supported + AutotuneResult best_algorithm; +#if TENSORFLOW_USE_ROCM // Blas gemm algorithms can be empty for ROCM if (algorithms.empty()) { // nothing to autotune - VLOG(1) << "Skipping autotuning for ROCm.."; + LOG(WARNING) << "No solutions found: skipping autotuning for ROCM.."; best_algorithm.mutable_gemm()->set_algorithm(se::blas::kDefaultAlgorithm); return best_algorithm; } #endif - TF_ASSIGN_OR_RETURN( - best_algorithm, - GetBestBlasAlgorithm( - stream, buffer_allocator, gemm->ToString(), autotune_config, - lhs_buffer, rhs_buffer, output_buffer, algorithms, output_shape, - hlo_module_config, gemm_config.beta(), - [&](const se::blas::AlgorithmType& algorithm) - -> StatusOr { - se::blas::ProfileResult profile_result; - // We expect GemmWithAlgorithm to fail sometimes - // -- in fact, it will fail for all algorithms if - // we're targeting < sm_50. But because we pass a - // non-null ProfileResult, DoGemmWithAlgorithm - // should always return true, and the actual - // success-ness is returned in - // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer, - output_buffer, workspace_buffer, - deterministic_ops, stream, algorithm, - &profile_result)); - return std::move(profile_result); - })); + auto tuned_func = [&](const se::blas::AlgorithmType& algorithm) + -> absl::StatusOr { + // Do a warm-up run first, without a profile result. RunGemm swallows + // error codes when profile_result is passed, as it is in the measurement + // below, but not otherwise. It is, therefore, consistent to ignore the + // error code here. + static_cast(RunGemm(gemm_config, lhs_buffer_, rhs_buffer_, + output_buffer_, workspace_buffer, + deterministic_ops_, stream_, algorithm)); + se::blas::ProfileResult profile_result; + // Allow GpuTimer to use its delay kernel implementation to improve + // accuracy. + profile_result.set_warmup_run_executed(true); + // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail + // for all algorithms if we're targeting < sm_50. But because we pass a + // non-null ProfileResult, DoGemmWithAlgorithm should always return true, + // and the actual success-ness is returned in ProfileResult::is_valid. + TF_RETURN_IF_ERROR(RunGemm(gemm_config, lhs_buffer_, rhs_buffer_, + output_buffer_, workspace_buffer, + deterministic_ops_, stream_, algorithm, + &profile_result)); + return std::move(profile_result); + }; + + TF_ASSIGN_OR_RETURN(best_algorithm, + GetBestAlgorithm( + gemm, algorithms, gemm_config.beta, tuned_func)); if (best_algorithm.has_gemm()) { int alg_idx = best_algorithm.gemm().algorithm(); best_algorithm.mutable_gemm()->set_algorithm(algorithms[alg_idx]); } + return best_algorithm; } - return best_algorithm; -} + + // Returns the index (into `algorithms`) of the fastest algorithm. + template + absl::StatusOr GetBestAlgorithm( + const HloInstruction* gemm, absl::Span algorithms, + double beta, TunedFunc&& run_benchmark) { + static_assert(std::is_invocable_r_v, + TunedFunc, const AlgoT&>, + "Tuned function has incorrect prototype!"); + + if (!stream_->parent()->SynchronizeAllActivity()) { + return Internal("Failed to synchronize GPU for autotuning."); + } + + auto& hlo_module_config = gemm->GetModule()->mutable_config(); + const auto& output_shape = GetOutputShape(gemm); + + se::DeviceMemoryBase reference_buffer; + if (autotune_config_.should_check_correctness()) { + TF_ASSIGN_OR_RETURN(reference_buffer, + redzone_allocator_->AllocateBytes( + ShapeUtil::ByteSizeOf(output_shape))); + } + + BufferComparator comparator(output_shape, hlo_module_config); + std::vector results; + results.reserve(algorithms.size()); + std::optional reference_algorithm; + + for (const AlgoT& algorithm : algorithms) { + // Make sure the output buffer always has the same value if we use + // the bias parameter. + if (autotune_config_.should_reinit_output_buffer() && beta != 0) { + int64_t rng_state = 0; + InitializeBuffer(stream_, output_shape.element_type(), &rng_state, + output_buffer_); + } + TF_ASSIGN_OR_RETURN(auto profile_result, run_benchmark(algorithm)); + + AutotuneResult& result = results.emplace_back(); + result.mutable_gemm()->set_algorithm(profile_result.algorithm()); + + if (!profile_result.is_valid()) { // Unsupported algorithm. + result.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); + continue; + } + + VLOG(2) << "gemm algorithm " << profile_result.algorithm() << " took " + << profile_result.elapsed_time_in_ms() << "ms"; + + *result.mutable_run_time() = tsl::proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + if (!autotune_config_.should_check_correctness()) { + continue; + } + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, + redzone_allocator_->CheckRedzones()); + + if (!rz_check_status.ok()) { + result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); + *result.mutable_failure()->mutable_msg() = + rz_check_status.RedzoneFailureMsg(); + LOG(ERROR) << "Detected out-of-bounds write in gemm buffer"; + CHECK(!autotune_config_.should_crash_on_check_failure()); + continue; + } + + if (!reference_algorithm) { + TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, output_buffer_, + output_buffer_.size())); + reference_algorithm = profile_result.algorithm(); + } else { + // Perform the comparison. + TF_ASSIGN_OR_RETURN( + bool outputs_match, + comparator.CompareEqual(stream_, /*current=*/output_buffer_, + /*expected=*/reference_buffer)); + if (!outputs_match) { + LOG(ERROR) << "Results mismatch between different GEMM algorithms. " + << "This is likely a bug/unexpected loss of precision."; + CHECK(!autotune_config_.should_crash_on_check_failure()); + + result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT); + result.mutable_failure()->mutable_reference_gemm()->set_algorithm( + *reference_algorithm); + } + } + } // for algorithms + + absl::StatusOr best = + PickBestResult(results, gemm->ToString(), hlo_module_config); + if (best.ok()) { + for (size_t i = 0; i < results.size(); ++i) { + if (best->gemm().algorithm() == results[i].gemm().algorithm()) { + best->mutable_gemm()->set_algorithm(i); + return best; + } + } + return Internal("unknown best algorithm"); + } + LOG(WARNING) << "Failed to find best cuBLAS algorithm, GEMM performance " + "might be suboptimal: " + << best.status(); + return AutotuneResult{}; + } // GetBestAlgorithm +}; // GemmAutotuner #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Do Gemm Autotune without stream executor. Use results from autotune cache // only. -StatusOr RunOnInstruction(HloInstruction* gemm, - const AutotuneConfig& config) { +absl::StatusOr RunOnInstruction(HloInstruction* gemm, + const AutotuneConfig& config) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); - GemmBackendConfig gemm_config = - gemm->backend_config().value(); + GpuBackendConfig gpu_config = + gemm->backend_config().value(); + GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); + // Degenerate gemms replaced with memzero operation, no need to auto tune it. - if (gemm_config.alpha_real() == 0.0 && gemm_config.alpha_imag() == 0.0 && - gemm_config.beta() == 0.0) { + if (backend_config.alpha_real() == 0.0 && + backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { VLOG(3) << "Skip degenerate gemm instruction auto tuning"; return false; } AutotuneCacheKey key(config.GetModelStr(), *gemm); - + GemmAutotuner autotuner(config); TF_ASSIGN_OR_RETURN(AutotuneResult algorithm, - AutotunerUtil::Autotune(gemm, config, [&] { - return DoGemmAutotuneNoCache(gemm, key, config); - })); - - GemmBackendConfig updated_config = gemm_config; - - // We only set the 'algorithm' field on non-Ampere architectures, as for - // Ampere it's ignored in any case. - bool update_algorithm = true; -#if GOOGLE_CUDA - auto capability = config.GetCudaComputeCapability(); - update_algorithm = !capability.IsAtLeast(se::CudaComputeCapability::AMPERE); -#endif + AutotunerUtil::Autotune( + gemm, config, [&] { return autotuner(gemm, key); })); + + auto old_algorithm = backend_config.selected_algorithm(); + bool update_algorithm = + IsCublasLtMatmulF8(*gemm) || + std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + // We only set the 'algorithm' field on + // non-Ampere architectures, as for Ampere + // it's ignored in any case. + return !cc.IsAtLeast( + se::CudaComputeCapability::AMPERE); + }, + [](const se::RocmComputeCapability&) { + return true; // TODO: not decided yet + }}, + config.GetGpuComputeCapability()); + if (update_algorithm) { + int64_t new_algorithm{}; if (algorithm.has_gemm()) { - updated_config.set_selected_algorithm(algorithm.gemm().algorithm()); + new_algorithm = algorithm.gemm().algorithm(); } else { - updated_config.set_selected_algorithm(se::blas::kRuntimeAutotuning); + // NOTE: runtime autotuning is no longer available => set to default + new_algorithm = se::blas::kDefaultAlgorithm; + } + + if (new_algorithm == old_algorithm && + backend_config.has_selected_algorithm()) { + // We don't need to update the backend config if + // the algorithm hasn't changed unless previously + // the algorithm wasn't set explicitly. + return false; } + + backend_config.set_selected_algorithm(new_algorithm); + TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config)); + return true; // We changed `gemm` } - TF_RETURN_IF_ERROR(gemm->set_backend_config(updated_config)); - return updated_config.SerializeAsString() != gemm_config.SerializeAsString(); + + return false; // No change to `gemm` } -StatusOr RunOnComputation(HloComputation* computation, - AutotuneConfig config) { +absl::StatusOr RunOnComputation(HloComputation* computation, + AutotuneConfig config) { bool changed = false; for (HloInstruction* instr : computation->instructions()) { if (IsCublasGemm(*instr)) { @@ -433,7 +466,7 @@ StatusOr RunOnComputation(HloComputation* computation, } // namespace -StatusOr GemmAlgorithmPicker::Run( +absl::StatusOr GemmAlgorithmPicker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_SCOPED_LOGGING_TIMER( diff --git a/xla/service/gpu/gemm_algorithm_picker.h b/xla/service/gpu/gemm_algorithm_picker.h index 99cae9f796f49..2a58205e81b3d 100644 --- a/xla/service/gpu/gemm_algorithm_picker.h +++ b/xla/service/gpu/gemm_algorithm_picker.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,30 +17,32 @@ limitations under the License. #include #include -#include #include -#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/autotuner_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/stream_executor/device_description.h" +#include "xla/shape.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/gpu_conv_runner.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #endif namespace xla { namespace gpu { -StatusOr GetBestBlasAlgorithm( +absl::StatusOr GetBestBlasAlgorithm( se::Stream* stream, se::RedzoneAllocator& allocator, std::optional gemm_str, const AutotuneConfig& autotune_config, se::DeviceMemoryBase lhs_buffer, @@ -48,7 +50,7 @@ StatusOr GetBestBlasAlgorithm( absl::Span algorithms, const Shape& output_shape, const HloModuleConfig& hlo_module_config, double beta, - const std::function( + const std::function( const se::blas::AlgorithmType&)>& run_benchmark); // GemmAlgorithmPicker supports two modes: device and deviceless. @@ -63,7 +65,7 @@ class GemmAlgorithmPicker : public HloModulePass { absl::string_view name() const override { return "gemm-algorithm-picker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index 77bcb13661d7b..d531d4df1ef2b 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,21 @@ limitations under the License. #include "xla/service/gpu/gemm_algorithm_picker.h" -#include +#include +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -44,9 +52,33 @@ class GemmAlgorithmPickerTest : public HloTestBase, debug_options.set_xla_gpu_enable_triton_gemm(false); return debug_options; } + + void SetUp() override { + const auto& gpu_cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + + if (auto* procm = std::get_if(&gpu_cc)) { + if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() && + !procm->has_hipblaslt()) { + GTEST_SKIP() << "No gpublas-lt support on this architecture!"; + } + } + } }; TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) { + auto comp = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (comp.IsAtLeast(se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Skipping this test for Ampere+ as it is supported and " + "recommended with " + "the Nvidia Volta+ GPUs."; + } + constexpr absl::string_view kHlo = R"( HloModule module @@ -111,12 +143,23 @@ ENTRY main { GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); } - TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config, - dot->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + dot->backend_config()); + const GemmBackendConfig& config = gpu_config.gemm_backend_config(); EXPECT_EQ(config.selected_algorithm(), new_algo_id); } TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) { + auto comp = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (comp.IsAtLeast(se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Skipping this test for Ampere+ as it is supported and " + "recommended with " + "the Nvidia Volta+ GPUs."; + } + constexpr absl::string_view kHlo = R"( HloModule module @@ -189,8 +232,10 @@ ENTRY main { GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0))); } - TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config, - dot->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + dot->backend_config()); + const GemmBackendConfig& config = gpu_config.gemm_backend_config(); + EXPECT_EQ(config.selected_algorithm(), new_algo_id); } diff --git a/xla/service/gpu/gemm_broadcast_folding_rewriter.cc b/xla/service/gpu/gemm_broadcast_folding_rewriter.cc index 0e4fb8485650c..a6cbbf11c94fd 100644 --- a/xla/service/gpu/gemm_broadcast_folding_rewriter.cc +++ b/xla/service/gpu/gemm_broadcast_folding_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,22 @@ limitations under the License. #include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" +#include + #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/pattern_matcher.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -36,7 +39,7 @@ namespace m = match; class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { public: - Status HandleCustomCall(HloInstruction *instr) override { + absl::Status HandleCustomCall(HloInstruction *instr) override { HloInstruction *existing_gemm; HloInstruction *bcast; if (Match(instr, m::CustomCall(&existing_gemm, @@ -45,8 +48,9 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget, kCublasLtMatmulCallTarget}) .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) { - TF_ASSIGN_OR_RETURN(auto config, - existing_gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers(); int bcast_operand_index = instr->operand_index(bcast); int num_bcast_dims = (bcast->shape().dimensions_size() - @@ -63,11 +67,11 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { // dimensions are >= num_bcast_dims. for (int64_t bcast_dim : bcast->dimensions()) { if (bcast_dim < num_bcast_dims) { - return OkStatus(); + return absl::OkStatus(); } // bcast_dim should not be in batch_dimensions. if (absl::c_linear_search(batch_dimensions, bcast_dim)) { - return OkStatus(); + return absl::OkStatus(); } } @@ -75,7 +79,7 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { // there is at least one batch dimension. CHECK_GT(num_bcast_dims, 0); if (num_bcast_dims != num_batch_dims) { - return OkStatus(); + return absl::OkStatus(); } if (bcast_operand_index == 1) { @@ -91,20 +95,20 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { } TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape( bcast_operand_index, bcast->mutable_operand(0))); - TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config)); + TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); MarkAsChanged(); } - return OkStatus(); + return absl::OkStatus(); } }; -static StatusOr RunOnComputation(HloComputation *computation) { +static absl::StatusOr RunOnComputation(HloComputation *computation) { GemmBroadcastFoldingVisitor visitor; TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.changed(); } -StatusOr GemmBroadcastFoldingRewriter::Run( +absl::StatusOr GemmBroadcastFoldingRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { bool changed = false; diff --git a/xla/service/gpu/gemm_broadcast_folding_rewriter.h b/xla/service/gpu/gemm_broadcast_folding_rewriter.h index 46fdc83973876..bac14bc971138 100644 --- a/xla/service/gpu/gemm_broadcast_folding_rewriter.h +++ b/xla/service/gpu/gemm_broadcast_folding_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_ #define XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_ -#include - -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -40,7 +40,7 @@ class GemmBroadcastFoldingRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gemm_fusion.cc b/xla/service/gpu/gemm_fusion.cc new file mode 100644 index 0000000000000..05e758a73f3d4 --- /dev/null +++ b/xla/service/gpu/gemm_fusion.cc @@ -0,0 +1,822 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemm_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_padding_requirements.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/service/instruction_fusion.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace xla { +namespace gpu { + +namespace { + +using triton_fusion::CombineRequirements; +using triton_fusion::DimensionOrder; +using triton_fusion::DimOrderMap; +using triton_fusion::DimOrdersAndReqs; +using triton_fusion::DimOrdersAndReqsOrError; +using triton_fusion::DotRequirements; +using triton_fusion::FusionContext; +using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible; +using triton_fusion::HeroProperties; +using triton_fusion::Requirements; +using triton_fusion::RequirementsOrError; +using triton_fusion::TransformDirection; + +// This represents a directed graph. +class AdjacencyList { + public: + using NodeId = int64_t; + + NodeId AddNode() { + adj_.emplace_back(); + return adj_.size() - 1; + } + + const std::vector& GetOutNeighbors(NodeId node_id) const { + return adj_.at(node_id); + } + + void ReserveSpaceForOutNeighbors(NodeId node_id, size_t count) { + adj_.at(node_id).reserve(count); + } + + void AddArc(NodeId from, NodeId to) { adj_.at(from).push_back(to); } + + // Currently the Root node is the node which was added first. + NodeId GetRoot() const { + CHECK(!adj_.empty()); + return 0; + } + + private: + // Adjacency list: A vector of out-neighbors for each node. + std::vector> adj_; +}; + +struct HloAndDimOrder { + const HloInstruction* original_hlo = nullptr; + DimensionOrder dim_order; +}; + +struct HloAndIterSpec { + const HloInstruction* original_hlo; + TensorIterationSpec iter_spec; + + auto ToTuple() const { return std::make_tuple(original_hlo, iter_spec); } + bool operator==(const HloAndIterSpec& other) const { + return ToTuple() == other.ToTuple(); + } + template + friend H AbslHashValue(H h, const HloAndIterSpec& key) { + return H::combine(std::move(h), key.ToTuple()); + } +}; + +struct NodeFusionPlan { + const HloInstruction* original_hlo = nullptr; + bool should_fuse = false; +}; + +struct FusionPlan { + // The graph describing the structure of the fusion that we build - nodes + // corresponding to the instructions and arcs pointing from users to operands. + AdjacencyList graph; + // The fusion plan for each node. + absl::flat_hash_map map; +}; + +struct FusionPlanAndRequirements { + FusionPlan fusion_plan; + Requirements requirements; +}; + +struct HlosAndRequirements { + // The original HLO (which is outside the fusion computation). + const HloInstruction* original_hlo = nullptr; + // The fused HLO inside the new fusion computation, built by the builder. + // + // This can have the same opcode as `original_hlo` or it can be a parameter if + // the original HLO can't be fused. + const HloInstruction* fused_hlo = nullptr; + // The requirements imposed by the fused operations. + // + // If we fuse further operations they may have to conform to these + // requirements. + Requirements requirements; +}; + +// Clones the hero kDot operation into the fusion. +HloInstruction& FuseDot(const HloDotInstruction& dot, + const HloInstruction& fused_lhs, + const HloInstruction& fused_rhs, + std::optional fused_meta, + HloComputation::Builder& builder // append +) { + VLOG(3) << "Fusing " << dot.ToString(); + + std::vector hlo_new_operands = { + const_cast(&fused_lhs), + const_cast(&fused_rhs)}; + if (fused_meta.has_value()) { + hlo_new_operands.push_back(const_cast(fused_meta.value())); + } + return *builder.AddInstruction( + dot.CloneWithNewOperands(dot.shape(), hlo_new_operands)); +} + +// Tells how many new parameters does a fusion gain by fusing the operation as +// an input. +int64_t NumAddedParameters(const HloInstruction& hlo) { + // Non-scalar constant is equivalent to a parameter: one input, one output. + if (hlo.opcode() == HloOpcode::kParameter || + (hlo.opcode() == HloOpcode::kConstant && + !ShapeUtil::IsScalar(hlo.shape()))) { + return 0; + } + // All other instructions add all own inputs and remove own single output. + return hlo.operand_count() - 1; +} + +// Just a helper to reduce "unwrapping" code where we use this. +std::optional GetOperandDimOrdersAndCombinedReqs( + const HloInstruction& hlo, const DimensionOrder& dim_order, + const HeroProperties& properties, + const se::GpuComputeCapability& gpu_version, + const Requirements& requirements) { + DimOrdersAndReqsOrError dim_orders_and_new_reqs = + GetPropagatedDimOrdersAndRequirements( + hlo, dim_order, TransformDirection::kOutputToInput, properties); + if (!std::holds_alternative(dim_orders_and_new_reqs)) { + return std::nullopt; + } + RequirementsOrError combined_reqs = CombineRequirements( + requirements, + std::get(dim_orders_and_new_reqs).requirements); + if (!std::holds_alternative(combined_reqs)) { + return std::nullopt; + } + return DimOrdersAndReqs{ + std::get(dim_orders_and_new_reqs).dim_orders, + std::get(combined_reqs)}; +} + +// Just a helper to reduce "unwrapping" code where we use this. +std::optional GetOperandDimOrdersAndCombinedReqsIfProfitable( + const HloInstruction& hlo, const DimensionOrder& dim_order, + const HeroProperties& properties, + const se::GpuComputeCapability& gpu_version, + const Requirements& requirements) { + DimOrdersAndReqsOrError dim_orders_and_new_reqs = + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + hlo, TransformDirection::kOutputToInput, + /*src_operand_index=*/std::nullopt, dim_order, gpu_version, + properties); + if (!std::holds_alternative(dim_orders_and_new_reqs)) { + return std::nullopt; + } + RequirementsOrError combined_reqs = CombineRequirements( + requirements, + std::get(dim_orders_and_new_reqs).requirements); + if (!std::holds_alternative(combined_reqs)) { + return std::nullopt; + } + return DimOrdersAndReqs{ + std::get(dim_orders_and_new_reqs).dim_orders, + std::get(combined_reqs)}; +} + +// Just a helper to reduce "unwrapping" code where we use this. +std::optional GetUserDimOrdersAndCombinedReqsIfProfitable( + const HloInstruction& hlo, const DimensionOrder& hlo_dim_order, + const HloInstruction& user, const HeroProperties& properties, + const se::GpuComputeCapability& gpu_version, + const Requirements& requirements) { + DimOrdersAndReqsOrError dim_orders_and_new_reqs = + GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( + user, TransformDirection::kInputToOutput, user.operand_index(&hlo), + hlo_dim_order, gpu_version, properties); + if (!std::holds_alternative(dim_orders_and_new_reqs)) { + return std::nullopt; + } + RequirementsOrError combined_reqs = CombineRequirements( + requirements, + std::get(dim_orders_and_new_reqs).requirements); + if (!std::holds_alternative(combined_reqs)) { + return std::nullopt; + } + return DimOrdersAndReqs{ + std::get(dim_orders_and_new_reqs).dim_orders, + std::get(combined_reqs)}; +} + +// Builds the fusion map and the requirements which can later be used to +// actually fuse that subgraph. +FusionPlanAndRequirements BuildFusionPlanTowardOperands( + const HloInstruction& root_hlo, const DimensionOrder& root_dim_order, + const std::optional& max_params, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties, const Requirements& requirements_so_far) { + CHECK(!max_params.has_value() || max_params.value() >= 1); + + // The graph describing the structure of the fusion that we build - nodes + // corresponding to the instructions and arcs pointing from users to operands. + // We can build and modify this graph easily without the need to create + // HloInstructions at this point. + AdjacencyList graph; + // Stores the original HLO and the dimension order for each node. This is a + // temporary map which is used when processing the nodes in this function. + absl::flat_hash_map + hlo_and_dim_order_map; + // Stores the information needed to build the fused HLO for each node (what + // was the original HLO and whether we should fuse it or create a parameter). + // This is one of the outputs of this function. + absl::flat_hash_map fusion_plan_map; + // Allows reusing nodes when multiple instructions iterate over the same HLO + // using the same iteration spec. In that case we don't duplicate the + // instruction in the fusion. + absl::flat_hash_map node_reuse_map; + // The requirements imposed by the fusion choices made in this function, + // combined with the existing requirements. This is one of the outputs of this + // function. + Requirements combined_reqs = requirements_so_far; + + auto get_or_create_fusion_node = + [&](const HloInstruction& hlo, const DimensionOrder& dim_order, + bool* is_new_node = nullptr) -> AdjacencyList::NodeId { + HloAndIterSpec reuse_key = {&hlo, dim_order.ToTensorIterationSpec()}; + if (auto it = node_reuse_map.find(reuse_key); it != node_reuse_map.end()) { + if (is_new_node != nullptr) { + *is_new_node = false; + } + return it->second; + } + AdjacencyList::NodeId node_id = graph.AddNode(); + CHECK(hlo_and_dim_order_map.insert({node_id, {&hlo, dim_order}}).second); + CHECK(node_reuse_map.insert({reuse_key, node_id}).second); + if (is_new_node != nullptr) { + *is_new_node = true; + } + return node_id; + }; + AdjacencyList::NodeId root = + get_or_create_fusion_node(root_hlo, root_dim_order); + + // Nodes at the fusion edge that can either get fused too or become parameters + // of the fusion. Used to track the number of parameters. + absl::flat_hash_set inputs({root}); + std::queue queue({root}); + int64_t num_requeued = 0; + // BFS + while (queue.size() > num_requeued) { + AdjacencyList::NodeId node_id = queue.front(); + queue.pop(); + const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id); + const HloInstruction& original_hlo = *hlo_and_dim_order.original_hlo; + const DimensionOrder& dim_order = hlo_and_dim_order.dim_order; + + // Watch the total number of fusion parameters. + if (max_params.has_value() && + inputs.size() + NumAddedParameters(original_hlo) > max_params.value()) { + // Re-queue: the number of parameters may go down when other instructions + // are processed. + queue.push(node_id); + // Prevent infinite loops. + ++num_requeued; + continue; + } + num_requeued = 0; + if (original_hlo.opcode() == HloOpcode::kParameter) { + CHECK(fusion_plan_map + .insert({node_id, {&original_hlo, /*should_fuse=*/false}}) + .second); + continue; + } + auto opt_result = GetOperandDimOrdersAndCombinedReqsIfProfitable( + original_hlo, dim_order, properties, gpu_version, combined_reqs); + if (!opt_result.has_value()) { + CHECK(fusion_plan_map + .insert({node_id, {&original_hlo, /*should_fuse=*/false}}) + .second); + continue; + } + const DimOrderMap operand_dim_orders = std::move(opt_result->dim_orders); + combined_reqs = std::move(opt_result->requirements); + inputs.erase(node_id); + graph.ReserveSpaceForOutNeighbors(node_id, original_hlo.operand_count()); + for (int64_t i = 0; i < original_hlo.operand_count(); ++i) { + const HloInstruction& operand = *original_hlo.operand(i); + const DimensionOrder& operand_dim_order = operand_dim_orders.at(&operand); + bool is_new_node = false; + AdjacencyList::NodeId operand_node_id = + get_or_create_fusion_node(operand, operand_dim_order, &is_new_node); + graph.AddArc(node_id, operand_node_id); + if (is_new_node) { + VLOG(6) << "Enqueueing " << operand.ToString() << ":" + << operand_dim_order.ToString(); + inputs.insert(operand_node_id); + queue.push(operand_node_id); + } + } + CHECK( + fusion_plan_map.insert({node_id, {&original_hlo, /*should_fuse=*/true}}) + .second); + } + // Handle the remaining requeued items. + while (!queue.empty()) { + AdjacencyList::NodeId node_id = queue.front(); + queue.pop(); + + const HloAndDimOrder& hlo_and_dim_order = hlo_and_dim_order_map.at(node_id); + CHECK(fusion_plan_map + .insert({node_id, + {hlo_and_dim_order.original_hlo, /*should_fuse=*/false}}) + .second); + } + return {{std::move(graph), std::move(fusion_plan_map)}, + std::move(combined_reqs)}; +} + +// Builds the HLO instructions for the fusion represented by `fusion_plan`, +// starting from `node_id`. +HloInstruction& BuildFusionTowardOperandsImpl( + AdjacencyList::NodeId node_id, const FusionPlan& fusion_plan, + absl::flat_hash_map& + fused_hlo_map, // read/append + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + if (auto it = fused_hlo_map.find(node_id); it != fused_hlo_map.end()) { + return *it->second; + } + + const NodeFusionPlan& node_fusion_plan = fusion_plan.map.at(node_id); + const bool should_fuse = node_fusion_plan.should_fuse; + const HloInstruction& original_hlo = *node_fusion_plan.original_hlo; + + HloInstruction* fused_hlo = nullptr; + if (should_fuse) { + HloInstruction::InstructionVector new_operands; + for (AdjacencyList::NodeId operand_id : + fusion_plan.graph.GetOutNeighbors(node_id)) { + new_operands.push_back(&BuildFusionTowardOperandsImpl( + operand_id, fusion_plan, fused_hlo_map, builder, fusion_params)); + } + fused_hlo = builder.AddInstruction( + original_hlo.CloneWithNewOperands(original_hlo.shape(), new_operands)); + } else { + fusion_params.push_back(const_cast(&original_hlo)); + fused_hlo = builder.AddInstruction(HloInstruction::CreateParameter( + fusion_params.size() - 1, original_hlo.shape(), + absl::StrCat("parameter_", fusion_params.size() - 1))); + } + + CHECK(fused_hlo_map.insert({node_id, fused_hlo}).second); + return *fused_hlo; +} + +// Builds the HLO instructions for the fusion represented by `fusion_plan`. +HloInstruction& BuildFusionTowardOperands( + const FusionPlan& fusion_plan, + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + absl::flat_hash_map fused_hlo_map; + return BuildFusionTowardOperandsImpl(fusion_plan.graph.GetRoot(), fusion_plan, + fused_hlo_map, builder, fusion_params); +} + +// Grows the fusion toward the operands. +// +// This always succeeds. +// +// If it's not possible to fuse something, it fuses a parameter instead. +// +// The fusion can grow until it has `max_params` params and it can only grow +// with operations for which the DimOrder propagation works and they don't +// impose requirements contradicting the existing requirements. +// +// The return value contains the HLOs corresponding to `root_hlo` and the +// requirements corresponding to the whole fusion so far. +HlosAndRequirements FuseTowardOperands( + const HloInstruction& root_hlo, const DimensionOrder& root_dim_order, + const std::optional& max_params, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties, const Requirements& requirements_so_far, + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + FusionPlanAndRequirements fusion_plan_and_reqs = + BuildFusionPlanTowardOperands(root_hlo, root_dim_order, max_params, + gpu_version, properties, + requirements_so_far); + HloInstruction& fused_hlo_or_param = BuildFusionTowardOperands( + fusion_plan_and_reqs.fusion_plan, builder, fusion_params); + return HlosAndRequirements{&root_hlo, &fused_hlo_or_param, + fusion_plan_and_reqs.requirements}; +} + +// Grows the fusion toward the given dot operand. +// +// This always succeeds. +// +// If it's not possible to fuse something, it fuses a parameter instead. +// +// The fusion can grow until it has `max_params` params and it can only grow +// with operations for which the DimOrder propagation works and they don't +// impose requirements contradicting the existing requirements. +// +// The return value contains the HLOs corresponding to the given dot operand and +// the requirements corresponding to the whole fusion so far. +absl::StatusOr FuseDotOperand( + const HloInstruction& dot, int operand_index, + const se::GpuComputeCapability& gpu_version, + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + // Direct dot inputs have well defined dimension orders. + TF_ASSIGN_OR_RETURN(const FusionContext context, + FusionContext::FromDotOperand(dot, operand_index)); + const HloInstruction& operand = *dot.operand(operand_index); + return FuseTowardOperands(operand, context.dim_orders().at(&operand), + TritonFusionAnalysis::kMaxParameterPerDotOperand, + gpu_version, context.hero_properties(), + context.requirements(), builder, fusion_params); +} + +// Grows the fusion toward the users. +// +// This always succeeds. +// +// The fusion can grow as long as the DimOrder propagation works and the users +// don't impose requirements contradicting the existing requirements. +// +// The return value contains the HLOs corresponding to the "lowest" fused user +// or `hlo` if no users can be fused. +// +// It also grows the fusion upward, toward the "other" operands of the users, +// but currently only in special cases, such as binary elementwise operation +// with broadcast of scalar constant. +HlosAndRequirements FuseTowardUsers( + const HloInstruction& hlo, const HloInstruction& fused_hlo, + const DimensionOrder& hlo_dim_order, + const se::GpuComputeCapability& gpu_version, + const HeroProperties& properties, const Requirements& requirements, + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + const HlosAndRequirements existing_hlos_and_requirements = {&hlo, &fused_hlo, + requirements}; + if (hlo.user_count() != 1) { + return existing_hlos_and_requirements; + } + const HloInstruction& user = *hlo.users()[0]; + if (!IsDistributiveOverAddition(user)) { + return existing_hlos_and_requirements; + } + + // Get the dim orders for the user. + auto opt_user_result = GetUserDimOrdersAndCombinedReqsIfProfitable( + hlo, hlo_dim_order, user, properties, gpu_version, requirements); + if (!opt_user_result.has_value()) { + return existing_hlos_and_requirements; + } + DimensionOrder user_dim_order = opt_user_result->dim_orders.at(&user); + Requirements combined_requirements = opt_user_result->requirements; + + HloInstruction::InstructionVector new_operands; + if (user.operand_count() == 1) { + new_operands.push_back(const_cast(&fused_hlo)); + } else { + // Get the dim orders for the operands of the user. + // We shouldn't do a profitability check here, we made that decision in + // GetUserDimOrdersAndCombinedReqsIfProfitable. + auto opt_operand_result = GetOperandDimOrdersAndCombinedReqs( + user, user_dim_order, properties, gpu_version, combined_requirements); + // This shouldn't fail, because currently we only encounter this when we + // have just propagated down the DimOrders on a binary elementwise + // operation (user). In that case propagating up the DimOrders should always + // work. + if (!opt_operand_result.has_value()) { + return existing_hlos_and_requirements; + } + DimOrderMap operand_dim_orders = opt_operand_result->dim_orders; + combined_requirements = opt_operand_result->requirements; + + // Fuse the other operands of the user. + for (int i = 0; i < user.operand_count(); ++i) { + const HloInstruction& operand = *user.operand(i); + if (&operand == &hlo) { + new_operands.push_back(const_cast(&fused_hlo)); + } else { + HlosAndRequirements hlos_and_requirements = FuseTowardOperands( + operand, operand_dim_orders.at(&operand), + /*max_params=*/std::nullopt, gpu_version, properties, + combined_requirements, builder, fusion_params); + new_operands.push_back( + const_cast(hlos_and_requirements.fused_hlo)); + combined_requirements = hlos_and_requirements.requirements; + } + } + } + + const HloInstruction& fused_user = *builder.AddInstruction( + user.CloneWithNewOperands(user.shape(), new_operands)); + return FuseTowardUsers(user, fused_user, user_dim_order, gpu_version, + properties, combined_requirements, builder, + fusion_params); +} + +// Grows the fusion toward the users of the dot. +// +// This always succeeds. +// +// The fusion can grow as long as the DimOrder propagation works and the users +// don't impose requirements contradicting the existing requirements. +// +// The return value contains the HLOs corresponding to the "lowest" fused user +// or `dot` if no users can be fused. +// +// It also grows the fusion towards the "other" operands of the users, but +// currently only in special cases, such as binary elementwise operation with +// broadcast of scalar constant. +HlosAndRequirements FuseDotOutput( + const HloInstruction& dot, const HloInstruction& fused_dot, + const se::GpuComputeCapability& gpu_version, + const DotRequirements& requirements, + HloComputation::Builder& builder, // append + std::vector& fusion_params // append +) { + const auto context = + FusionContext::FromDotOutput(dot, /*split_k=*/1, requirements); + return FuseTowardUsers(dot, fused_dot, context.dim_orders().at(&dot), + gpu_version, context.hero_properties(), + context.requirements(), builder, fusion_params); +} + +// Fuses dot and the compatible and profitable to fuse operations around it +// into a new fusion computation constructed using the builder. fusion_inputs +// get populated with the non-fused instructions that become operands of the +// call to this fusion. fusion_output_ptr (if not nullptr) gets assigned the +// original instruction that has to be replaced by the call to the fusion. +absl::StatusOr CreateDotFusion( + const HloDotInstruction& dot, const se::GpuComputeCapability gpu_version, + HloComputation::Builder& builder, + std::vector& fusion_inputs, + HloInstruction** fusion_output_ptr) { + VLOG(5) << dot.ToString(); + if (CodegenDecision is_supported = + IsTritonSupportedInstruction(dot, gpu_version); + !is_supported) { + VLOG(3) << is_supported.Explain(); + return is_supported; + } + + // Verify sparse dot constraints. + if (dot.sparse_operands()) { + const SparsityDescriptor& descriptor = dot.sparsity().front(); + if (dot.sparse_operands() != 1 || descriptor.index() != 0) { + return InvalidArgument("Sparsity is only supported on left operand"); + } + if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M || + descriptor.n() != 2 || descriptor.m() != 4) { + return InvalidArgument("Only 2:4 structured sparsity is supported"); + } + // DotDimensionSorter pass makes sure the sparse dimension is minor. + CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1); + } + + TF_ASSIGN_OR_RETURN(HlosAndRequirements lhs_hlos_and_reqs, + FuseDotOperand(dot, /*operand_index=*/0, gpu_version, + builder, fusion_inputs)); + TF_ASSIGN_OR_RETURN(HlosAndRequirements rhs_hlos_and_reqs, + FuseDotOperand(dot, /*operand_index=*/1, gpu_version, + builder, fusion_inputs)); + std::optional meta_hlo; + if (dot.sparse_operands()) { + TF_ASSIGN_OR_RETURN(HlosAndRequirements meta_hlos_and_reqs, + FuseDotOperand(dot, /*operand_index=*/2, gpu_version, + builder, fusion_inputs)); + meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo); + } + HloInstruction& fused_dot = + FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo, + meta_hlo, builder); + // For now the RHS doesn't support splits, so it also doesn't impose any + // requirements. + HlosAndRequirements fused_output_and_reqs = + FuseDotOutput(dot, fused_dot, gpu_version, + std::get(lhs_hlos_and_reqs.requirements), + builder, fusion_inputs); + + if (fusion_output_ptr != nullptr) { + *fusion_output_ptr = + const_cast(fused_output_and_reqs.original_hlo); + } + + const PrecisionConfig::Algorithm algorithm = + dot.precision_config().algorithm(); + if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || + algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || + dot.sparse_operands()) { + return FusionDecision{}; + } + + bool is_pure_matmul = true; + (void)builder.ForEachInstruction([&](const HloInstruction* fused_hlo) { + static constexpr std::array kPureOpcodes = { + HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter, + HloOpcode::kReshape}; + if (absl::c_find(kPureOpcodes, fused_hlo->opcode()) == kPureOpcodes.end()) { + is_pure_matmul = false; + // Stop iterating. + return absl::CancelledError(); + } + return absl::OkStatus(); + }); + if (!is_pure_matmul) { + return FusionDecision{}; + } + + return "No profitable operations to fuse."; +} + +// Extracts into fused computations parts of HLO graph including dot() +// operations that can target the triton GEMM emitter. +class GemmFusionVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version) + : gpu_version_(gpu_version) {} + // Checks that a dot() should be targeting the triton GEMM emitter; + // if so - fuses all its compatible inputs and outputs as a new computation + // and replaces the original dot() with a call to the computation. + absl::Status HandleDot(HloInstruction* dot) override { + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + int64_t gemm_rewrite_size_threshold = + dot->GetModule() + ->config() + .debug_options() + .xla_gpu_gemm_rewrite_size_threshold(); + TF_ASSIGN_OR_RETURN(bool is_matmul_tiny, + IsMatrixMultiplicationTooSmallForRewriting( + *dot, gemm_rewrite_size_threshold)); + if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*dot)) { + return absl::OkStatus(); + } + + std::string fusion_name = absl::StrCat("gemm_fusion_", dot->name()); + HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation")); + std::vector fusion_inputs; + HloInstruction* fusion_output = nullptr; + TF_ASSIGN_OR_RETURN( + const FusionDecision should_fuse, + CreateDotFusion(*Cast(dot), gpu_version_, builder, + fusion_inputs, &fusion_output)); + if (builder.last_added_instruction() == nullptr) { + return absl::OkStatus(); + } + // If a GEMM requiring padding for cuBLAS is encountered here this + // happened because earlier ShouldTritonHandleGEMM() accepted it and padding + // was skipped. Accept it ignoring profitability checks. + // TODO(rocm): check ROCM padding requirements. + if (std::holds_alternative(gpu_version_)) { + if (!CublasRequiresPadding( + *Cast(dot), + std::get(gpu_version_)) && + !should_fuse) { + return OkStatus(); + } + } + + HloComputation* computation = + dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), + /*is_entry=*/false); + HloInstruction* dot_fusion = + dot->parent()->AddInstruction(HloInstruction::CreateFusion( + computation->root_instruction()->shape(), + HloInstruction::FusionKind::kCustom, fusion_inputs, computation)); + // Copy the metadata of the `dot` to the newly created `fusion` op. This + // is convenient for handling metadata in split-k rewriting subsequently. + dot_fusion->set_metadata(dot->metadata()); + dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + dot_fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kTritonGemmFusionKind)); + TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(gpu_config)); + + if (fusion_output->IsRoot()) { + fusion_output->parent()->set_root_instruction(dot_fusion); + TF_RETURN_IF_ERROR( + fusion_output->parent()->RemoveInstructionAndUnusedOperands( + fusion_output)); + MarkAsChanged(); + } else { + TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion)); + } + XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable())); + return absl::OkStatus(); + } + + private: + se::GpuComputeCapability gpu_version_; +}; + +absl::StatusOr RunOnComputation( + HloComputation* computation, const se::GpuComputeCapability& gpu_version) { + GemmFusionVisitor visitor(gpu_version); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.changed(); +} + + +} // namespace + +bool ShouldTritonHandleGEMM(HloDotInstruction& dot, + const se::GpuComputeCapability& gpu_version) { + std::vector fusion_inputs; + HloComputation::Builder builder("disposable"); + return CreateDotFusion(dot, gpu_version, builder, fusion_inputs, + /*fusion_output_ptr=*/nullptr) + ->CanFuse(); +} + +absl::StatusOr GemmFusion::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + auto cuda_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_ASSIGN_OR_RETURN(bool result, + RunOnComputation(computation, gpu_version_)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemm_fusion.h b/xla/service/gpu/gemm_fusion.h new file mode 100644 index 0000000000000..1138ad28a36a5 --- /dev/null +++ b/xla/service/gpu/gemm_fusion.h @@ -0,0 +1,57 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GEMM_FUSION_H_ +#define XLA_SERVICE_GPU_GEMM_FUSION_H_ + +// This file contains the code for fusing dots and other operations into Triton +// GEMM fusions. + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/service/instruction_fusion.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +// Filters GEMMs which are better to handle using Triton. +bool ShouldTritonHandleGEMM(HloDotInstruction&, + const se::GpuComputeCapability&); + +// Rewrite compatible dot() calls into custom calls with fused computations +// that target Triton-based matmul emitter. +class GemmFusion : public HloModulePass { + public: + explicit GemmFusion(const se::GpuComputeCapability& gpu_version) + : gpu_version_(gpu_version) {} + absl::string_view name() const override { return "triton-gemm-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + se::GpuComputeCapability gpu_version_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GEMM_FUSION_H_ diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc new file mode 100644 index 0000000000000..58615453d7ea6 --- /dev/null +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -0,0 +1,1230 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemm_fusion_autotuner.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/algorithm_util.h" +#include "xla/service/dump.h" +#include "xla/service/executable.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/autotuner_compile_util.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/buffer_comparator.h" +#include "xla/service/gpu/cudnn_fusion_compiler.h" +#include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/split_k_gemm_rewriter.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/util/proto/proto_utils.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/bits.h" +#include "tsl/platform/blocking_counter.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +// Log levels used in this file: +// VLOG(1): Overview +// VLOG(2): Autotuning progress +// VLOG(3): Autotuning progress - more frequent +// VLOG(4): Print all fusions +// VLOG(5): Profiling information for every tiling + +// TODO(b/317016172): Update usages of TritonGemmConfig to use newly exposed +// parameters. + +namespace xla { +namespace gpu { + +using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput; + +namespace { + +// Currently supported minimum tile size. +constexpr int kMinTileSize = 16; +// Not a hard limit, just an assumption that should stay valid. +constexpr int kMaxTileSize = 512; + +// Default tiling when autotuning is disabled. +constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4}; + +class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* hlo) override { + TF_ASSIGN_OR_RETURN(auto gpu_config, + hlo->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << hlo->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + hlo, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + hlo->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = hlo->parent(); + HloInstruction* const call = computation->AddInstruction( + HloInstruction::CreateCall(hlo->shape(), hlo->operands(), + hlo->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); + hlo = call; + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + AutotuneConfig config_; +}; + +// This contains all alternative Triton GEMM configs related to one fusion. +struct GemmConfigSet { + std::vector configs; + // Setting this to true disallows verification and fallback to cuBLAS, and + // the usage of cuDNN. + bool has_sparsity = false; +}; + +using CuDnnPlanId = int64_t; + +struct ExecutableCandidate { + std::variant config; + // Not nullptr. + std::unique_ptr executable; +}; + +// This contains all alternative executables related to one fusion. +struct ExecutableSet { + std::vector candidates; + // Not nullptr. + std::unique_ptr reference; +}; + +class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { + public: + explicit GemmConfigSetCollector(const AutotuneConfig& config) + : config_(config) {} + + absl::StatusOr< + absl::flat_hash_map> + CollectGemmConfigSets( + const HloModule* module, + const absl::flat_hash_set& execution_threads = {}) { + gemm_config_sets_.clear(); + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_RETURN_IF_ERROR(computation->Accept(this)); + } + return std::move(gemm_config_sets_); + } + + absl::Status HandleFusion(const HloInstruction* hlo) override { + const HloFusionInstruction* fusion = Cast(hlo); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + hlo->backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + + AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, config_); + if (AutotunerUtil::IsInCache(key) || handled_fusions_.contains(key)) { + return absl::OkStatus(); + } + + if (backend_config.kind() == kTritonGemmFusionKind && + !backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN(GemmConfigSet gemm_config_set, + GetGemmConfigSet(fusion)); + TF_RET_CHECK( + gemm_config_sets_.insert({fusion, std::move(gemm_config_set)}) + .second); + } else if (backend_config.kind() == kCuDnnFusionKind && + !backend_config.has_cudnn_fusion_config()) { + TF_RET_CHECK(gemm_config_sets_.insert({fusion, {}}).second); + } + + handled_fusions_.insert(key); + return absl::OkStatus(); + } + + absl::Status DefaultAction(const HloInstruction* hlo) override { + return absl::OkStatus(); + } + + private: + absl::StatusOr GetGemmConfigSet( + const HloFusionInstruction* fusion) { + const DebugOptions& debug_options = + fusion->GetModule()->config().debug_options(); + auto cuda_comp = + std::get(config_.GetGpuComputeCapability()); + const HloDotInstruction* dot_instr = + Cast(hlo_query::GetFirstInstructionWithOpcode( + *fusion->called_computations().at(0), HloOpcode::kDot)); + TF_ASSIGN_OR_RETURN(auto configs, GetPossibleMatmulAutotuneConfigs( + *dot_instr, cuda_comp, debug_options, + config_.ExhaustiveTilingSearch())); + return GemmConfigSet{std::move(configs), + /*has_sparsity=*/dot_instr->sparse_operands() > 0}; + } + + AutotuneConfig config_; + absl::flat_hash_map + gemm_config_sets_; + absl::flat_hash_set handled_fusions_; +}; + +struct TileSizeLimit { + int64_t block_m = 0; + int64_t block_n = 0; + int64_t block_k = 0; +}; + +absl::StatusOr GetUpperLimit(const HloDotInstruction& dot) { + TF_ASSIGN_OR_RETURN(int64_t non_contracting_index0, + NonContractingDimensionIndex(dot, /*operand_number=*/0)); + TF_ASSIGN_OR_RETURN(int64_t non_contracting_index1, + NonContractingDimensionIndex(dot, /*operand_number=*/1)); + TF_ASSIGN_OR_RETURN(int64_t contracting_index0, + ContractingDimensionIndex(dot, /*operand_number=*/0)); + // This is not a sharp upper limit, the actual m value can be much smaller + // based on how much of the m dimension is physically contiguous. + // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis. + const int64_t m = dot.operand(0)->shape().dimensions(non_contracting_index0); + // Theoretically the same is true as for m, but that is not possible in + // practice with the current implementation. + const int64_t n = dot.operand(1)->shape().dimensions(non_contracting_index1); + // This is before doing the split-k transform. + const int64_t k = dot.operand(0)->shape().dimensions(contracting_index0); + const int64_t block_m_limit = + std::max(tsl::NextPowerOfTwoS64(m), kMinTileSize); + const int64_t block_n_limit = + std::max(tsl::NextPowerOfTwoS64(n), kMinTileSize); + // Increase minimum tile size for the contracting dimension proportionally + // to the sparsity multiplier (assume 2:4 structured sparsity). + const int64_t block_k_limit = + std::max(tsl::NextPowerOfTwoS64(k), + kMinTileSize * (dot.sparse_operands() ? 2 : 1)); + return TileSizeLimit{block_m_limit, block_n_limit, block_k_limit}; +} + +int64_t GetSplitKLimit(int64_t block_k, int64_t block_k_limit) { + return std::max(block_k_limit / block_k, 1); +} + +// Search space for exhaustive matmul autotuning. +constexpr std::array BLOCK_SIZES = {16, 32, 64, 128, 256, 512}; +constexpr std::array NUM_STAGES = {1, 2, 3, 4}; +constexpr std::array NUM_WARPS = {2, 4, 8, 16}; +constexpr std::array SPLIT_K = {1, 2, 4, 8, 16}; +// This is the number of blocks per cluster. +// +// Clusters have 3 dimensions (x,y,z) and only 1 <= x*y*z <= 16 are supported. +// Triton doesn't support (3,3,1) and possibly other non-"power of 2" values. +// It's possible that some other values may be(come) supported. +constexpr std::array NUM_CTAS = {1, 2, 4, 8, 16}; + +absl::StatusOr> +GetExhaustiveMatmulAutotuneConfigs( + const HloDotInstruction& dot, + const se::CudaComputeCapability compute_capability, const int max_split_k, + const DebugOptions& debug_options) { + TF_ASSIGN_OR_RETURN(const TileSizeLimit limit, GetUpperLimit(dot)); + std::vector configs; + bool mma_layout_v2 = + compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE); + bool enable_hopper_optimizations = + debug_options.xla_gpu_enable_triton_hopper() && + compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER); + + for (int num_warps : NUM_WARPS) { + for (int num_stages : NUM_STAGES) { + // Volta doesn't support num_stages > 2. + if (!mma_layout_v2 && num_stages > 2) { + continue; + } + for (int block_m : BLOCK_SIZES) { + if (block_m > limit.block_m) { + continue; + } + for (int block_n : BLOCK_SIZES) { + if (block_n > limit.block_n) { + continue; + } + for (int block_k : BLOCK_SIZES) { + if (block_k > limit.block_k) { + continue; + } + // Sparse meta should have at least one element per thread. + // Note: only 2:4 structured sparsity is currently supported. + if (dot.sparse_operands() && + block_m * block_k / 16 < num_warps * WarpSize()) { + continue; + } + for (int split_k : SPLIT_K) { + if (split_k > + std::min(max_split_k, + GetSplitKLimit(block_k, limit.block_k))) { + continue; + } + if (!enable_hopper_optimizations) { + configs.push_back(TritonGemmConfig( + block_m, block_n, block_k, split_k, num_stages, num_warps)); + continue; + } + // Arch >= Hopper autotuning. + // We only want to autotune this if it provides any speedup. So + // please think about that before adding it to the default + // autotuning parameters. + for (int num_ctas : NUM_CTAS) { + configs.push_back(TritonGemmConfig(block_m, block_n, block_k, + split_k, num_stages, + num_warps, num_ctas)); + } + } + } + } + } + } + } + return configs; +} + +std::vector GetFixedMatmulAutotuneConfigs( + const se::CudaComputeCapability compute_capability, const int max_split_k) { + // Shorter name for better formatting. + using Config = TritonGemmConfig; + std::vector configs = { + Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), + Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), + Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), + Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), + Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), + Config(64, 32, 64, 1, 2, 8)}; + if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { + absl::c_copy( + std::vector{ + Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8), + Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4), + Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4), + Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4), + Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4), + Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4), + Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4), + Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4), + Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8), + Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4), + Config(256, 256, 128, 1, 3, 8)}, + std::back_inserter(configs)); + } + if (compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER)) { + absl::c_copy( + std::vector{ + Config(16, 32, 32, 8, 1, 2), + Config(16, 64, 128, 8, 1, 4), + Config(16, 64, 128, 16, 3, 4), + }, + std::back_inserter(configs)); + } + configs.erase(std::remove_if(configs.begin(), configs.end(), + [&](const Config& config) { + return config.split_k > max_split_k; + }), + configs.end()); + return configs; +} + +// This prefers to take the parameter by moving it. +absl::StatusOr> ReduceTileSizes( + const HloDotInstruction& dot, std::vector configs) { + TF_ASSIGN_OR_RETURN(const TileSizeLimit limit, GetUpperLimit(dot)); + // Decrease the block sizes and split_k if they are unnecessarily big. + for (TritonGemmConfig& config : configs) { + config.block_m = std::min(config.block_m, limit.block_m); + config.block_n = std::min(config.block_n, limit.block_n); + config.block_k = std::min(config.block_k, limit.block_k); + config.split_k = std::min( + config.split_k, GetSplitKLimit(config.block_k, limit.block_k)); + // Sparse meta should have at least one element per thread. + // Note: only 2:4 structured sparsity is currently supported. + if (dot.sparse_operands()) { + int meta_elements = config.block_m * config.block_k / 16; + config.num_warps = + std::min(config.num_warps, meta_elements / WarpSize()); + } + } + + // Remove duplicates. + absl::flat_hash_set configs_so_far; + configs.erase(std::remove_if(configs.begin(), configs.end(), + [&](const TritonGemmConfig& config) { + return !configs_so_far.insert(config).second; + }), + configs.end()); + TF_RET_CHECK(!configs.empty()); + return configs; +} + +int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; } + +absl::StatusOr> TritonGemmAutotuneExtractor( + const TritonGemmConfig& config, + const se::DeviceDescription& gpu_device_info, + const HloFusionInstruction* fusion, DebugOptions debug_opts, + bool allow_filtering_kernels_spilling_registers) { + std::unique_ptr new_module = + ExtractInstructionIntoNewModule(*fusion); + // Reduce memory usage during compilation by disabling GPU runtime. + debug_opts.set_xla_gpu_enable_xla_runtime_executable(false); + // TODO(anlunx): Disable command buffers for now because it breaks triton + // autotuner test. Enable this when the function of command buffers is stable. + debug_opts.clear_xla_gpu_enable_command_buffer(); + if (!allow_filtering_kernels_spilling_registers) { + debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( + false); + } + new_module->mutable_config().set_debug_options(debug_opts); + + HloComputation* entry_computation = new_module->entry_computation(); + HloInstruction* cloned_dot_fusion = entry_computation->root_instruction(); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + cloned_dot_fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + + *backend_config.mutable_triton_gemm_config() = config.ToProto(); + TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(gpu_config)); + + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config)); + GpuFloatSupport bf16_support(gpu_device_info.cuda_compute_capability(), + BF16); + FloatNormalization float_normalization(&bf16_support); + TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); + GpuInstructionFusion instruction_fusion(/*may_duplicate=*/false, + gpu_device_info); + TF_RETURN_IF_ERROR(instruction_fusion.Run(new_module.get()).status()); + HloInstruction* root = entry_computation->root_instruction(); + // If the instruction fusion pass above skipped the reduction, turn it + // into a fusion for a universal set of arguments for execution. + if (root->opcode() == HloOpcode::kReduce) { + HloInstruction* fusion_instruction = + entry_computation->AddInstruction(HloInstruction::CreateFusion( + root->shape(), ChooseFusionKind(*root, *root), root)); + HloInstruction* init_value = root->mutable_operand(1); + TF_CHECK_OK( + entry_computation->ReplaceInstruction(root, fusion_instruction)); + fusion_instruction->FuseInstruction(init_value); + TF_CHECK_OK(entry_computation->RemoveInstruction(init_value)); + } + } + return new_module; +} + +absl::StatusOr> CublasGemmAutotuneExtractor( + const AutotuneConfig& config, const HloFusionInstruction* fusion, + const DebugOptions& debug_opts) { + const HloComputation* fusion_computation = + fusion->called_computations().at(0); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + new_module->mutable_config().set_debug_options(debug_opts); + + auto* dot = hlo_query::GetFirstInstructionWithOpcode( + *new_module->entry_computation(), HloOpcode::kDot); + // Substitute algorithms, which are not supported by cuBLAS for the check, but + // don't use cuBlas in the end. This assumes that the substituting algorithm + // has result which are close enough for the check in this file. + if (dot->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + dot->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) { + dot->mutable_precision_config()->set_algorithm( + PrecisionConfig::ALG_DOT_F32_F32_F32); + } + + GemmRewriter rewriter(config.GetGpuComputeCapability()); + GpuInstructionFusion fusion_pass( + /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS + // performance. It is probably not needed on Ampere and later because cuBLAS + // ignores the algorithm parameter for those targets. If we run + // GemmAlgorithmPicker, we probably should not run this in parallel with other + // compilations. + return new_module; +} + +absl::StatusOr> CudnnGemmAutotuneExtractor( + const AutotuneConfig& autotune_config, const HloFusionInstruction* fusion, + const DebugOptions& debug_opts, const int plan_id) { + std::unique_ptr new_module = + ExtractInstructionIntoNewModule(*fusion); + new_module->mutable_config().set_debug_options(debug_opts); + + GpuBackendConfig gpu_config; + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + // Provided a plan ID the autotuner just compiles one plan. + backend_config.mutable_cudnn_fusion_config()->set_plan_id(plan_id); + TF_RETURN_IF_ERROR( + new_module->entry_computation()->root_instruction()->set_backend_config( + gpu_config)); + + return new_module; +} + +bool ShouldAllowFilteringKernelsSpillingRegisters( + const GemmConfigSet& gemm_config_set) { + return gemm_config_set.configs.size() > 1; +} + +bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) { + auto gpu_config = hlo.backend_config(); + if (!gpu_config.ok()) { + return false; + } + return gpu_config->fusion_backend_config().kind() == kind; +} + +int GetCuDnnPlanCount(const HloInstruction& hlo, + const AutotuneConfig& autotune_config) { + if (auto gpu_config = hlo.backend_config(); + !gpu_config.ok() || + gpu_config->fusion_backend_config().has_cudnn_fusion_config()) { + return {}; + } + return CuDnnFusionCompiler::GetAvailablePlanCount( + *autotune_config.GetExecutor(), *DynCast(&hlo)); +} + +bool IsCuDnnEnabled(const AutotuneConfig& config, + const DebugOptions& debug_opts) { + return !config.IsDeviceless() && + std::get(config.GetGpuComputeCapability()) + .IsAtLeastHopper() && + debug_opts.xla_gpu_cudnn_gemm_fusion_level() > 0 && + GetDnnVersionInfo(config.GetExecutor()).major_version() >= 9; +} + +bool HasAlgorithmSupportedByCublasOrCublasLt( + const HloFusionInstruction& fusion) { + const PrecisionConfig::Algorithm algorithm = + hlo_query::GetFirstInstructionWithOpcode(*fusion.called_computation(), + HloOpcode::kDot) + ->precision_config() + .algorithm(); + return algorithm_util::IsSupportedByCublasOrCublasLt(algorithm); +} + +bool HasAlgorithmSupportedByCudnn(const HloFusionInstruction& fusion) { + const PrecisionConfig::Algorithm algorithm = + hlo_query::GetFirstInstructionWithOpcode(*fusion.called_computation(), + HloOpcode::kDot) + ->precision_config() + .algorithm(); + return algorithm_util::IsSupportedByCudnn(algorithm); +} + +absl::StatusOr> +CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, + tsl::thread::ThreadPool* thread_pool, + const DebugOptions& debug_opts, + const absl::flat_hash_map& gemm_config_sets) { + absl::Mutex executable_sets_mu; + absl::flat_hash_map + executable_sets; + + if (gemm_config_sets.empty()) { + return executable_sets; + } + + const se::DeviceDescription& gpu_device_info = + config.GetExecutor()->GetDeviceDescription(); + + const int log_every_n = GetLogEveryN(); + int64_t config_count = 0; + for (const auto& key_value : gemm_config_sets) { + const HloFusionInstruction& hlo = *key_value.first; + const GemmConfigSet& gemm_config_set = key_value.second; + + if (IsFusionKind(hlo, kTritonGemmFusionKind)) { + config_count += gemm_config_set.configs.size(); + if (!gemm_config_set.has_sparsity && IsCuDnnEnabled(config, debug_opts) && + HasAlgorithmSupportedByCudnn(hlo)) { + config_count += GetCuDnnPlanCount(hlo, config); + } + } else if (IsFusionKind(hlo, kCuDnnFusionKind)) { + config_count += GetCuDnnPlanCount(hlo, config); + } + // Reference config for verification (uses cuBLAS). + config_count += !gemm_config_set.has_sparsity; + } + + std::atomic done_count = 0; + std::atomic good_count = 0; + auto log = [&](bool success) { + const int done_so_far = done_count.fetch_add(1) + 1; + const int good_so_far = + success ? good_count.fetch_add(1) + 1 : good_count.load(); + if (done_so_far % log_every_n == 0) { + VLOG(2) << "Compiled " << done_so_far << " of " << config_count + << " configs (successful: " << good_so_far << ")"; + } + }; + + // Returns true on success. + auto compile = [&](const HloFusionInstruction* fusion, + const TritonGemmConfig& conf, + bool allow_filtering_kernels_spilling_registers) + -> absl::StatusOr { + CHECK_LE(conf.block_m, kMaxTileSize); + CHECK_LE(conf.block_n, kMaxTileSize); + CHECK_LE(conf.block_k, kMaxTileSize); + // TODO(b/296884861): Reenable GPU runtime, when it will have much smaller + // memory overhead (regarding the size of the executables). + // We can also remove the force_disable_gpu_runtime argument at that + // point. + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + util.Compile([&](const DebugOptions& opts) { + return TritonGemmAutotuneExtractor( + conf, gpu_device_info, fusion, opts, + allow_filtering_kernels_spilling_registers); + })); + + if (executable != nullptr) { + absl::MutexLock lock(&executable_sets_mu); + ExecutableSet& executable_set = executable_sets[fusion]; + executable_set.candidates.push_back( + ExecutableCandidate{conf, std::move(executable)}); + return true; + } + + return false; + }; + + // Returns true on success. + auto compile_reference_executable = + [&](const HloFusionInstruction* fusion) -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + util.Compile([&](const DebugOptions& opts) { + return CublasGemmAutotuneExtractor(config, fusion, + opts); + })); + + if (executable != nullptr) { + absl::MutexLock lock(&executable_sets_mu); + ExecutableSet& executable_set = executable_sets[fusion]; + TF_RET_CHECK(executable_set.reference == nullptr); + executable_set.reference = std::move(executable); + return true; + } + + return false; + }; + + auto compile_cudnn_executable = [&](const HloFusionInstruction* fusion, + const int plan_id) { + std::unique_ptr executable = + util.Compile([&](const DebugOptions& opts) { + return CudnnGemmAutotuneExtractor(config, fusion, opts, plan_id); + }) + .value_or(nullptr); + if (executable != nullptr) { + absl::MutexLock lock(&executable_sets_mu); + ExecutableSet& executable_set = executable_sets[fusion]; + executable_set.candidates.push_back( + ExecutableCandidate{plan_id, std::move(executable)}); + return true; + } + return false; + }; + + // If the thread pool has only one thread, then it is actually slower to + // offload the tasks there. + if (thread_pool && thread_pool->NumThreads() > 1 && + debug_opts.xla_gpu_force_compilation_parallelism() != 1) { + if (gemm_config_sets.size() == 1) { + absl::string_view fusion_name = gemm_config_sets.begin()->first->name(); + VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name + << " on " << thread_pool->NumThreads() << " threads."; + } else { + VLOG(1) << "Compiling " << config_count << " configs for " + << gemm_config_sets.size() << " fusions on " + << thread_pool->NumThreads() << " threads."; + } + + tsl::BlockingCounter counter(config_count); + for (const auto& key_value : gemm_config_sets) { + const HloFusionInstruction* fusion = key_value.first; + const GemmConfigSet& gemm_config_set = key_value.second; + + for (const TritonGemmConfig& conf : gemm_config_set.configs) { + thread_pool->Schedule([&, fusion] { + absl::StatusOr has_executable = compile( + fusion, conf, + ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set)); + TF_CHECK_OK(has_executable.status()) + << "Failure occured when compiling fusion " << fusion->name() + << " with config '" << conf.ToString() + << "'\nFused HLO computation:\n" + << fusion->fused_instructions_computation()->ToString(); + log(has_executable.value()); + counter.DecrementCount(); + }); + } + + if (!gemm_config_set.has_sparsity) { + thread_pool->Schedule([&, fusion] { + absl::StatusOr has_executable = + compile_reference_executable(fusion); + TF_CHECK_OK(has_executable.status()); + log(has_executable.value()); + counter.DecrementCount(); + }); + } + + if (IsFusionKind(*fusion, kCuDnnFusionKind) || + (IsFusionKind(*fusion, kTritonGemmFusionKind) && + !gemm_config_set.has_sparsity && + IsCuDnnEnabled(config, debug_opts) && + HasAlgorithmSupportedByCudnn(*fusion))) { + const int plan_count = GetCuDnnPlanCount(*fusion, config); + for (int plan_id = 0; plan_id < plan_count; ++plan_id) { + thread_pool->Schedule([&, fusion, plan_id] { + log(compile_cudnn_executable(fusion, plan_id)); + counter.DecrementCount(); + }); + } + } + } + counter.Wait(); + } else { + if (gemm_config_sets.size() == 1) { + absl::string_view fusion_name = gemm_config_sets.begin()->first->name(); + LOG(WARNING) << "Compiling " << config_count << " configs for " + << fusion_name << " on a single thread."; + + } else { + LOG(WARNING) << "Compiling " << config_count << " configs for " + << gemm_config_sets.size() << " fusions on a single thread."; + } + + for (const auto& key_value : gemm_config_sets) { + const HloFusionInstruction* fusion = key_value.first; + const GemmConfigSet& gemm_config_set = key_value.second; + + for (const TritonGemmConfig& gemm_config : gemm_config_set.configs) { + VLOG(5) << "Compiling " << gemm_config.ToString(); + TF_ASSIGN_OR_RETURN( + bool has_executable, + compile( + fusion, gemm_config, + ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set))); + log(has_executable); + } + + if (!gemm_config_set.has_sparsity) { + TF_ASSIGN_OR_RETURN(bool has_executable, + compile_reference_executable(fusion)); + log(has_executable); + } + + if (IsFusionKind(*fusion, kCuDnnFusionKind) || + (IsFusionKind(*fusion, kTritonGemmFusionKind) && + !gemm_config_set.has_sparsity && + IsCuDnnEnabled(config, debug_opts) && + HasAlgorithmSupportedByCudnn(*fusion))) { + const int plan_count = GetCuDnnPlanCount(*fusion, config); + for (int plan_id = 0; plan_id < plan_count; ++plan_id) { + log(compile_cudnn_executable(fusion, plan_id)); + } + } + } + } + + VLOG(1) << "Done compiling (successful: " << good_count.load() << ")."; + + return executable_sets; +} + +absl::StatusOr Execute(const AutotuneConfig& config, + AutotunerCompileUtil& util, + const DebugOptions& debug_opts, + const HloFusionInstruction* fusion, + const ExecutableSet& executable_set) { + const HloComputation* fusion_computation = + fusion->called_computations().at(0); + + se::StreamExecutor* stream_exec = config.GetExecutor(); + if (!stream_exec->SynchronizeAllActivity()) { + return Internal("Failed to synchronize GPU for autotuning."); + } + se::DeviceMemoryAllocator* allocator = config.GetAllocator(); + if (allocator == nullptr) { + allocator = stream_exec->GetAllocator(); + } + TF_ASSIGN_OR_RETURN(se::Stream* const stream, + allocator->GetStream(stream_exec->device_ordinal())); + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator rz_allocator, + AutotunerUtil::CreateRedzoneAllocator(config, debug_opts)); + + const HloInstruction& root = *fusion_computation->root_instruction(); + BufferComparator comparator(root.shape(), + fusion_computation->parent()->config()); + + std::vector inputs; + inputs.reserve(fusion_computation->parameter_instructions().size()); + std::vector input_shapes; + input_shapes.reserve(fusion_computation->parameter_instructions().size()); + int64_t rng_state = 0; + for (const HloInstruction* param : + fusion_computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase param_buffer, + AutotunerUtil::CreateBuffer( + rz_allocator, param->shape(), config, rng_state)); + inputs.push_back(param_buffer); + input_shapes.push_back(param->shape()); + } + + // Run with cuBLAS (optional). + std::optional reference_buffer; + absl::Duration cublas_duration = absl::InfiniteDuration(); + if (executable_set.reference != nullptr) { + TF_ASSIGN_OR_RETURN(std::optional output, + util.ProfileExecutable(&*executable_set.reference, + stream, inputs, input_shapes)); + TF_RET_CHECK(output.has_value()); + if (config.should_check_correctness()) { + reference_buffer = std::move(output->output); + } + cublas_duration = output->duration; + } + + const int log_every_n = GetLogEveryN(); + const int64_t executable_count = executable_set.candidates.size(); + int ran_so_far = 0; + std::vector triton_results, cudnn_results; + VLOG(2) << "Running " << executable_count << " configs for " << fusion->name() + << "."; + for (const ExecutableCandidate& candidate : executable_set.candidates) { + AutotuneResult res; + + std::string candidate_description; + if (std::holds_alternative(candidate.config)) { + candidate_description = absl::StrFormat( + "triton tiling %s", + std::get(candidate.config).ToString()); + *res.mutable_triton() = + std::get(candidate.config).ToProto(); + } else { + const int64_t plan_id = std::get(candidate.config); + candidate_description = absl::StrFormat("cuDNN plan %d", plan_id); + res.mutable_algorithm()->set_algo_id(plan_id); + } + VLOG(5) << "Trying : " << candidate_description; + + TF_ASSIGN_OR_RETURN(std::optional profiling_output, + util.ProfileExecutable(candidate.executable.get(), + stream, inputs, input_shapes)); + ran_so_far += 1; + if (ran_so_far % log_every_n == 0) { + VLOG(2) << "Ran " << ran_so_far << " configs of " << executable_count + << "."; + } + + if (!profiling_output) { + VLOG(5) << "Skipping this tiling."; + continue; + } + + VLOG(5) << "Running the kernel took: " << profiling_output->duration; + if (profiling_output->duration >= absl::Seconds(1)) { + LOG(WARNING) << "Slow kernel for " << fusion->name() + << " took: " << profiling_output->duration << ". " + << candidate_description; + } + *res.mutable_run_time() = + tsl::proto_utils::ToDurationProto(profiling_output->duration); + + // Reference buffer is available when `config.should_check_correctness()` + // is set and reference executable was compiled. + if (reference_buffer.has_value()) { + TF_ASSIGN_OR_RETURN( + se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, + rz_allocator.CheckRedzones()); + if (!rz_check_status.ok()) { + LOG(ERROR) << "Red zone modified"; + res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); + res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg()); + CHECK(!config.should_crash_on_check_failure()); + continue; + } + + TF_ASSIGN_OR_RETURN( + bool outputs_match, + comparator.CompareEqual( + stream, /*current=*/profiling_output->output.root_buffer(), + /*expected=*/reference_buffer->root_buffer())); + if (!outputs_match) { + const char kMessage[] = + "Results do not match the reference. This is likely a " + "bug/unexpected loss of precision."; + LOG(ERROR) << kMessage; + CHECK(!config.should_crash_on_check_failure()); + // WRONG_RESULT is not taken seriously by PickBestResult(), so + // use DISQUALIFIED. + res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); + res.mutable_failure()->set_msg(kMessage); + } + } + if (std::holds_alternative(candidate.config)) { + triton_results.push_back(res); + } else { + cudnn_results.push_back(res); + } + } + + VLOG(2) << "Done running."; + + VLOG(2) << fusion->name() << ": time with cuBLAS: " << cublas_duration; + AutotuneResult best; + best.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); + if (!triton_results.empty()) { + TF_ASSIGN_OR_RETURN(const AutotuneResult triton_best, + PickBestResult(triton_results, root.ToString(), + root.GetModule()->config())); + VLOG(2) << "Best time with Triton: " + << tsl::proto_utils::FromDurationProto(triton_best.run_time()); + best = triton_best; + } + if (!cudnn_results.empty()) { + TF_ASSIGN_OR_RETURN(const AutotuneResult cudnn_best, + PickBestResult(cudnn_results, root.ToString(), + root.GetModule()->config())); + VLOG(2) << "Best time with cuDNN: " + << tsl::proto_utils::FromDurationProto(cudnn_best.run_time()); + TF_ASSIGN_OR_RETURN(best, + PickBestResult({best, cudnn_best}, root.ToString(), + root.GetModule()->config())); + } + + if (debug_opts.xla_gpu_cublas_fallback() && + !debug_opts.xla_gpu_deterministic_ops() && + HasAlgorithmSupportedByCublasOrCublasLt(*fusion)) { + if (cublas_duration < + tsl::proto_utils::FromDurationProto(best.run_time())) { + VLOG(2) << "Falling back to cuBLAS for " << fusion->name(); + + AutotuneResult cublas; + *cublas.mutable_run_time() = + tsl::proto_utils::ToDurationProto(cublas_duration); + // We will ignore this value anyway. + cublas.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); + + return cublas; + } + } + if (!best.has_triton()) { + VLOG(2) << "Using cuDNN plan " << best.algorithm().algo_id() << " for " + << fusion->name(); + } + return best; +} + +absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, + AutotunerCompileUtil& util, + const AutotuneResult result, + const HloFusionInstruction* fusion, + int fusion_id) { + TritonGemmConfig triton_gemm_config; + if (!result.has_triton()) { + LOG(WARNING) << "Using empty triton GEMM config for op " << fusion->name(); + // Empty TritonGemmConfig has all zero values which is good enough to keep + // fused computation in the dump but illustrate that Triton is not used for + // it after autotuning. + } else { + TF_ASSIGN_OR_RETURN(triton_gemm_config, + TritonGemmConfig::FromProto(result.triton())); + } + const se::DeviceDescription& device_desc = + autotune_config.GetExecutor()->GetDeviceDescription(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + util.ExtractModule([&](const DebugOptions& debug_opts) { + if (result.has_algorithm()) { + return CudnnGemmAutotuneExtractor(autotune_config, fusion, debug_opts, + result.algorithm().algo_id()); + } else { + return TritonGemmAutotuneExtractor( + triton_gemm_config, device_desc, fusion, debug_opts, + /*allow_filtering_kernels_spilling_registers=*/true); + } + })); + module->set_name(std::string(fusion->name())); + // Using the original module for its debug info and name in the first + // parameter. It's better to include the name of both the original module + // and the extracted module, to avoid name clashes. + DumpToFileInDirOrStdout( + /*module=*/*fusion->GetModule(), + /*file_prefix=*/"", + /*file_suffix=*/ + absl::StrCat("triton_fusion_", fusion_id, ".", module->name(), + ".optimized.txt"), + /*contents=*/module->ToString()); + return absl::OkStatus(); +} + +absl::Status Autotune( + const AutotuneConfig& config, AutotunerCompileUtil& util, + tsl::thread::ThreadPool* thread_pool, const DebugOptions& debug_opts, + const absl::flat_hash_map& + gemm_config_sets) { + absl::flat_hash_map + executable_sets; + TF_ASSIGN_OR_RETURN( + executable_sets, + CompileMany(config, util, thread_pool, debug_opts, gemm_config_sets)); + + // Sort the candidates to make their execution order well-defined for each + // fusion. + for (auto& key_value : executable_sets) { + ExecutableSet& executable_set = key_value.second; + std::vector& candidates = executable_set.candidates; + absl::c_sort(candidates, [](const ExecutableCandidate& a, + const ExecutableCandidate& b) { + return a.config < b.config; + }); + } + + int fusion_id = 0; + for (const auto& key_value : executable_sets) { + const HloFusionInstruction* fusion = key_value.first; + const ExecutableSet& executable_set = key_value.second; + + TF_ASSIGN_OR_RETURN(AutotuneResult result, Execute(config, util, debug_opts, + fusion, executable_set)); + + if (debug_opts.xla_gpu_dump_autotuned_triton_fusions()) { + TF_RETURN_IF_ERROR( + DumpAutotunedFusion(config, util, result, fusion, fusion_id++)); + } + + const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config); + if (!AutotunerUtil::AddResult(key, std::move(result))) { + // In the context of model server, concurrent autotuning is expected and + // insertion of identical autotuning keys is accepted. + LOG(WARNING) << "AutotunerUtil::AddResult already existed: " + << key.ToString(); + } + } + + return absl::OkStatus(); +} + +} // anonymous namespace + +absl::StatusOr> GetPossibleMatmulAutotuneConfigs( + const HloDotInstruction& dot, + const se::CudaComputeCapability compute_capability, + const DebugOptions& debug_options, bool exhaustive_tiling_search) { + // Avoid autotuning tiny fusions. + constexpr int kMinGemmElements = 32 * 32; + if (ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements && + ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements) { + return ReduceTileSizes(dot, {kDefaultGemmTiling}); + } + // Split-K optimization enables more even utilization of a GPU in cases + // where tiling just the non-contracting dimensions of a GEMM does not create + // a sufficient number of thread block programs to occupy all available cores. + // Given the typical ~100 cores per GPU 500 tiles make around 5 full + // waves that completely avoid the need for split-K. The formula below is + // n_tiles = split_k * (M * N) / (block_m * block_n) + // with pessimistically assumed maximum block_m and block_n. + // Most likely there is no need for split-K already at much smaller output + // tensor sizes. + constexpr int kSufficientNumberOfTiles = 500; + const int max_split_k = + debug_options.xla_gpu_enable_split_k_autotuning() + ? std::max(1L, kSufficientNumberOfTiles * kMaxTileSize * + kMaxTileSize / + ShapeUtil::ElementsIn(dot.shape())) + : 1; + return exhaustive_tiling_search + ? GetExhaustiveMatmulAutotuneConfigs(dot, compute_capability, + max_split_k, debug_options) + : ReduceTileSizes(dot, GetFixedMatmulAutotuneConfigs( + compute_capability, max_split_k)); +} + +absl::StatusOr GemmFusionAutotuner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + XLA_SCOPED_LOGGING_TIMER("GEMM fusion autotuner"); + const DebugOptions& debug_options = module->config().debug_options(); + TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, + AutotunerCompileUtil::Create(config_, debug_options)); + + GemmConfigSetCollector gemm_config_set_collector(config_); + absl::flat_hash_map + gemm_config_sets; + TF_ASSIGN_OR_RETURN(gemm_config_sets, + gemm_config_set_collector.CollectGemmConfigSets( + module, execution_threads)); + + if (debug_options.xla_gpu_autotune_level() == 0 || + debug_options.xla_gpu_deterministic_ops()) { + // Pick the first option for each gemm instead of autotuning. + for (const auto& [fusion, tilings] : gemm_config_sets) { + const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_); + AutotuneResult res; + if (IsFusionKind(*fusion, kCuDnnFusionKind)) { + res.mutable_algorithm()->set_algo_id(-1); + } else { + const HloDotInstruction* dot_instr = + Cast(hlo_query::GetFirstInstructionWithOpcode( + *fusion->called_computations().at(0), HloOpcode::kDot)); + TF_ASSIGN_OR_RETURN(auto configs, + ReduceTileSizes(*dot_instr, {kDefaultGemmTiling})); + auto config = configs.front(); + *res.mutable_triton() = config.ToProto(); + } + *res.mutable_run_time() = + tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); + AutotunerUtil::AddResult(key, res); + } + } else if (!config_.IsDeviceless()) { + TF_RET_CHECK(opt_compile_util.has_value()); + if (!gemm_config_sets.empty()) { + std::string correctness_check_str = config_.should_check_correctness() + ? "(with correctness check)" + : "(without correctness check)"; + + VLOG(1) << "Autotuning " << gemm_config_sets.size() << " fusions " + << correctness_check_str << "."; + TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, + debug_options, gemm_config_sets)); + VLOG(1) << "Done autotuning."; + } + } + + return GemmFusionAutotunerVisitor(config_).RunOnModule(module, + execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemm_fusion_autotuner.h b/xla/service/gpu/gemm_fusion_autotuner.h new file mode 100644 index 0000000000000..18a6e1b76caa4 --- /dev/null +++ b/xla/service/gpu/gemm_fusion_autotuner.h @@ -0,0 +1,65 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ +#define XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace gpu { + +// Find best tiling configuration for each triton fusion outlined. +class GemmFusionAutotuner : public HloModulePass { + public: + explicit GemmFusionAutotuner(const AutotuneConfig& config, + tsl::thread::ThreadPool* thread_pool) + : config_(config), thread_pool_(thread_pool) {} + + absl::string_view name() const override { return "triton-autotuner"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + AutotuneConfig config_; + tsl::thread::ThreadPool* thread_pool_; +}; + +// TODO(b/266210099): have a way to generate/load these dynamically. +// Returns a list of possible tilings for a GEMM performed in Triton. +absl::StatusOr> GetPossibleMatmulAutotuneConfigs( + const HloDotInstruction& dot, se::CudaComputeCapability compute_capability, + const DebugOptions& debug_options, bool exhaustive_tiling_search = false); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ diff --git a/xla/service/gpu/gemm_fusion_autotuner_test.cc b/xla/service/gpu/gemm_fusion_autotuner_test.cc new file mode 100644 index 0000000000000..f8def9455b2c1 --- /dev/null +++ b/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -0,0 +1,772 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/gemm_fusion_autotuner.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_utils.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/tools/hlo_decomposer.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/cpu_info.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace gpu { +namespace { + +namespace m = ::xla::match; + +using HloExtractionTest = HloTestBase; + +TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module + +triton_gemm_dot { + p0 = s8[10,10] parameter(0) + p1 = f32[10,10] parameter(1) + c0 = f32[10,10] convert(p0) + ROOT dot.0 = f32[10,10] dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = s8[10,10] parameter(0) + p1 = f32[10,10] parameter(1) + s = f32[10,10] sqrt(p1) + d = f32[10,10] fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot + ROOT r = f32[10,10] add(d, s) +})") + .value(); + + std::unique_ptr extracted_module = ExtractInstructionIntoNewModule( + *module->entry_computation()->root_instruction()->operand(0)); + + // Destroy the original module to be sure that the extracted one has no + // dependency on it. + module.release(); + + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); + EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3); + TF_EXPECT_OK(VerifyHloModule(extracted_module.get(), + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false)); +} + +TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module + +triton_gemm_dot { + p0 = s8[10,10] parameter(0) + p1 = f32[10,10] parameter(1) + c0 = f32[10,10] convert(p0) + ROOT dot.0 = f32[10,10] dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = s8[10,10] parameter(0) + p1 = f32[10,10] parameter(1) + s = f32[10,10] sqrt(p1) + d = f32[10,10] fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot + ROOT r = f32[10,10] add(d, s) +})") + .value(); + + std::unique_ptr extracted_module = + ExtractComputationIntoNewModule(*module->entry_computation() + ->root_instruction() + ->operand(0) + ->fused_instructions_computation()); + + // Destroy the original module to be sure that the extracted one has no + // dependency on it. + module.release(); + + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter()))); + EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4); + TF_EXPECT_OK(VerifyHloModule(extracted_module.get(), + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false)); +} + +class StatelessAutotunerTest : public HloTestBase { + public: + StatelessAutotunerTest() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} + + void SetUp() override { + AutotunerUtil::ClearAutotuneResults(); + HloTestBase::SetUp(); + } + + void TearDown() override { + AutotunerUtil::ClearAutotuneResults(); + HloTestBase::TearDown(); + } +}; + +class GemmFusionAutotunerTest : public StatelessAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + StatelessAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_triton_gemm(true); + debug_options.set_xla_gpu_cublas_fallback(false); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0); + return debug_options; + } + + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + void CheckTritonAutotuning(absl::string_view hlo, + absl::string_view expected) { + HloPassPipeline pipeline("gemm_rewrite"); + pipeline.AddPass(backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability()); + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", + tsl::port::MaxParallelism()); + DebugOptions opts; + pipeline.AddPass( + AutotuneConfig{DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}, + &thread_pool); + + RunAndFilecheckHloRewrite( + hlo, std::move(pipeline), expected, [](const HloModule* m) { + VLOG(5) << m->ToString(); + const HloInstruction* dot_fusion = + m->entry_computation()->root_instruction(); + if (dot_fusion->opcode() == HloOpcode::kReduce) { + dot_fusion = dot_fusion->operand(0); + } + CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion); + if (!dot_fusion->backend_config() + ->fusion_backend_config() + .has_cudnn_fusion_config()) { + CHECK_GT(dot_fusion->backend_config() + .value() + .fusion_backend_config() + .triton_gemm_config() + .block_m(), + 0); + } + }); + } +}; + +class GemmFusionAutotunerTestWithMorePreciseReduction + : public GemmFusionAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + GemmFusionAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( + true); + return debug_options; + } +}; + +TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.num_stages > 2; })); +} + +TEST_F(GemmFusionAutotunerTest, SmallOutputCanUseLargeSplitK) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.split_k >= 16; })); +} + +TEST_F(GemmFusionAutotunerTest, LargeOutputDoesNotUseLargeSplitK) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[20480,20480] parameter(0) + p1 = f32[20480,20480] parameter(1) + ROOT r = f32[20480,20480] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetDebugOptionsForTest())); + EXPECT_FALSE(std::any_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.split_k > 1; })); +} + +TEST_F(GemmFusionAutotunerTest, Int8FusedGemm) { + const std::string hlo = R"( +HloModule module + +ENTRY e { + x = s8[128,64] parameter(0) + c = f16[128,64] convert(x) + + y = f16[64,6144] parameter(1) + + ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + CheckTritonAutotuning(hlo, R"( +// CHECK: ENTRY +// CHECK: ROOT +// CHECK-SAME: kCustom +// CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3})); +} + +TEST_F(GemmFusionAutotunerTest, Int8FusedGemm256) { + const std::string hlo = R"( +HloModule module + +ENTRY e { + x = s8[128,256] parameter(0) + c = f16[128,256] convert(x) + + y = f16[256,6144] parameter(1) + + ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + CheckTritonAutotuning(hlo, R"( +// CHECK: ENTRY +// CHECK: ROOT +// CHECK-SAME: kCustom +// CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(GemmFusionAutotunerTest, SelectsSplitK) { + // Shapes with K >> M, N have to force split-K configurations. + const std::string kHloText = R"( +HloModule t + +ENTRY e { + p0 = s8[7,8192] parameter(0) + p0c = bf16[7,8192] convert(p0) + p1 = bf16[8192,18] parameter(1) + ROOT dot.0 = bf16[7,18] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: reduce +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kLoop +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/4, /*arel=*/1e-1})); +} + +TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) { + // Shapes with K >> M, N have to force split-K configurations. + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = s8[7,8192] parameter(0) + p0c = bf16[7,8192] convert(p0) + p1 = bf16[8192,18] parameter(1) + ROOT dot.0 = bf16[7,18] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: reduce +; CHECK: ENTRY +; CHECK-NEXT: parameter +; CHECK-NEXT: parameter +; CHECK-NEXT: kCustom +; CHECK-NEXT: kLoop +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(GemmFusionAutotunerTest, ApplySplitKWithoutAlteringTiling) { + const std::string kHloText = R"( +triton_dot { + p0 = f16[55,120] parameter(0) + p1 = f16[120,20] parameter(1) + ROOT dot = f16[55,20] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[55,120]{1,0} parameter(0) + p1 = f16[120,20]{1,0} parameter(1) + ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config":{kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1}}} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: f16[3,55,20] +; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2,"num_ctas":1} +; CHECK: f16[55,20]{1,0} {{(reduce|fusion)}} +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) { + const std::string kHloText = R"( +HloModule m + +%triton_gemm_dot { + %p1 = s8[4,12288]{1,0} parameter(1) + %p0 = s8[12288,1536]{1,0} parameter(0) + %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) + %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1) + %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot) +} + +ENTRY %e { + %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) + %convert = s8[4,12288]{1,0} parameter(1) + ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} +})"; + + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + EXPECT_THAT( + backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/true}), + tsl::testing::StatusIs( + tsl::error::CANCELLED, + absl::StrFormat( + "Compilation result discarded due to register spilling"))); +} + +TEST_F(GemmFusionAutotunerTest, + DoNotFilterOutAutotuningKernelSpillingRegisters) { + const std::string kHloText = R"( +HloModule m + +%triton_gemm_dot { + %p1 = s8[4,12288]{1,0} parameter(1) + %p0 = s8[12288,1536]{1,0} parameter(0) + %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) + %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1) + %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot) +} + +ENTRY %e { + %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) + %convert = s8[4,12288]{1,0} parameter(1) + ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} +})"; + + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( + false); + config.set_debug_options(debug_options); + module->set_config(config); + + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/true}) + .value(); + EXPECT_NE(executable, nullptr); +} + +TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) { + const std::string kHloText = R"( +HloModule m + +%triton_gemm_dot { + %p1 = f16[4,12288]{1,0} parameter(1) + %p0 = s8[12288,1536]{1,0} parameter(0) + %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) + ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY %e { + %p0 = s8[12288,1536]{1,0} parameter(0) + %p1 = f16[4,12288]{1,0} parameter(1) + ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} +})"; + + auto module = ParseAndReturnVerifiedModule(kHloText).value(); + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/true}) + .value(); + EXPECT_NE(executable, nullptr); +} + +class GemmFusionAutotunerDumpTest : public GemmFusionAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + GemmFusionAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cublas_fallback(true); + debug_options.set_xla_gpu_dump_autotuned_triton_fusions(true); + return debug_options; + } +}; + +TEST_F(GemmFusionAutotunerDumpTest, DumpingFusionsWorksWithFallback) { + // Computation is chosen such that relatively heavy math operations before the + // GEMM are not worth fusing because they would get duplicated many times and + // slow down execution. Therefore autotuning picks cuBLAS here. + const std::string kHloText = R"( +ENTRY e { + p0 = f32[3333,3333] parameter(0) + s = f32[3333,3333] sine(p0) + p1 = f32[3333,3333] parameter(1) + c = f32[3333,3333] cosine(p1) + ROOT dot = f32[3333,3333] dot(s, c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: cublas +; CHECK-NOT: triton +)"); +} + +TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) { + const std::string kHlo = R"( +fusion1 { + p0 = f32[3,28,32] parameter(0) + p1 = f32[3,28,32] parameter(1) + ROOT d = f32[3,32,32] dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[3,28,32] parameter(0) + p1 = f32[3,28,32] parameter(1) + ROOT _ = f32[3,32,32] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})"; + + CheckTritonAutotuning(kHlo, R"( +// CHECK: "plan_id": +)"); +} + +// TODO(b/281489442): Write a testcase called +// `SkipConfigsProducingDeviantResults` or similar. + +class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest, + public ::testing::WithParamInterface { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + StatelessAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_autotune_level(GetParam()); + debug_options.set_xla_gpu_cublas_fallback(false); + return debug_options; + } +}; + +TEST_P(GemmFusionAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) { + const std::string kHloText = R"( +HloModule m + +ENTRY e { + p0 = pred[64,10] parameter(0) + p0c = f32[64,10] convert(p0) + p1 = f32[10,128] parameter(1) + ROOT r = f32[64,128] dot(p0c, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: kind=kCustom +; CHECK-SAME: block_m + )"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_P(GemmFusionAutotunerLevelTest, Deviceless) { + const std::string hlo = R"( +HloModule module + +ENTRY e { + x = s8[16,16] parameter(0) + c = f16[16,16] convert(x) + y = f16[16,16] parameter(1) + ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + HloPassPipeline pipeline("gemm_rewrite_deviceless"); + pipeline.AddPass(backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability()); + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", + tsl::port::MaxParallelism()); + DebugOptions opts; + pipeline.AddPass( + AutotuneConfig{DevicelessConfig{backend() + .default_stream_executor() + ->GetDeviceDescription() + .model_str(), + backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability()}, + opts}, + &thread_pool); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloTestBase::RunHloPass(&pipeline, module.get())); + EXPECT_TRUE(changed); + + // Check default configuration. + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), + R"( +// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"16","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false} + )")); + EXPECT_TRUE(filecheck_matches); + } else { + EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()), + tsl::testing::StatusIs( + tsl::error::INTERNAL, + ::testing::HasSubstr( + "Expect autotune result cache hit for deviceless"))); + } +} + +INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep, + GemmFusionAutotunerLevelTest, ::testing::Range(0, 5)); + +class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + GemmFusionAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_exhaustive_tiling_search(true); + return debug_options; + } +}; + +TEST_F(GemmFusionAutotunerExhaustiveTest, DISABLED_CompileOnly) { + const std::string hlo = R"( +HloModule module + +ENTRY e { + x = s8[16,16] parameter(0) + c = f16[16,16] convert(x) + y = f16[16,16] parameter(1) + ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + CheckTritonAutotuning(hlo, R"( +// CHECK: %triton_gemm_out_computation ( +// CHECK: ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +// CHECK: ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation +// CHECK-SAME: "block_m": +)"); +} + +class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + GemmFusionAutotunerTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +TEST_F(GemmFusionAutotunerDisableSplitK, SplitKIsDisabled) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetDebugOptionsForTest())); + EXPECT_TRUE(std::all_of( + configs.begin(), configs.end(), + [](const TritonGemmConfig& config) { return config.split_k == 1; })); +} + +class GemmFusionAutotunerConfigTest + : public StatelessAutotunerTest, + public ::testing::WithParamInterface {}; + +TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) { + const std::string kHloText = R"( +HloModule test +ENTRY wais { + lhs = f16[5,1600] parameter(0) + rhs = f16[3200,10] parameter(1) + meta = u16[5,200] parameter(2) + ROOT dot = f32[5,10] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + auto dot = + Cast(module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN( + auto configs, + GetPossibleMatmulAutotuneConfigs( + *dot, se::CudaComputeCapability{8, 0}, GetDebugOptionsForTest(), + /*exhaustive_tiling_search=*/GetParam())); + for (const auto& config : configs) { + int metadata_size = config.block_m * config.block_k / 16; + EXPECT_LE(config.num_warps * WarpSize(), metadata_size); + EXPECT_GT(config.block_k, 16); // kMinTileSize + } +} + +INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep, + GemmFusionAutotunerConfigTest, ::testing::Bool()); + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemm_fusion_test.cc b/xla/service/gpu/gemm_fusion_test.cc new file mode 100644 index 0000000000000..e5986c9968b5e --- /dev/null +++ b/xla/service/gpu/gemm_fusion_test.cc @@ -0,0 +1,1206 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemm_fusion.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/cublas_padding_requirements.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; +using ::testing::FieldsAre; + +namespace m = ::xla::match; + +class GemmFusionTest : public HloTestBase { + public: + GemmFusionTest() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_gemm_any(false); + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } + + se::GpuComputeCapability gpu_version_{ + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}}; + + void MatchHloModule(HloModule& module, absl::string_view pattern) { + TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, + RunFileCheck(module.ToString(), pattern)); + EXPECT_TRUE(filecheck_result); + } +}; + +TEST_F(GemmFusionTest, TransposeSubdimensionGroup) { + // This HLO is artificial because unnecessary reshapes get optimized + // out during compilation. It tests the ability of GemmFusion + // to handle transposes of groups of subdimensions. + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f32[32,3] parameter(0) + t1 = f32[3,32] transpose(p0), dimensions={1,0} + r1 = f32[3,8,4] reshape(t1) + r0 = f32[3,32] reshape(r1) + p1 = f16[32,7] parameter(1) + c1 = f32[32,7] convert(p1) + ROOT d = f32[3,7] dot(r0, c1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, UnsupportedTransposeIsNotFused) { + auto module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0) + c = f16[1,512,8,1024]{3,2,1,0} copy(p0) + b = f16[4096,1024]{1,0} bitcast(c) + p1 = f16[128,1024]{1,0} parameter(1) + ROOT d = f16[4096,128]{1,0} dot(b, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})") + .value(); + EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, BitcastChain) { + // This HLO is artificial because unnecessary reshapes get optimized + // out during compilation. It tests the ability of GemmFusion + // to handle various kinds of bitcasts. + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = s8[60,5] parameter(0) + r0 = s8[3,20,5] reshape(p0) + c0 = f16[3,20,5] convert(r0) + p1 = f16[3,200] parameter(1) + r12 = f16[600] reshape(p1) + r11 = f16[30,20] reshape(r12) + r1 = f16[3,10,20] reshape(r11) + ROOT d = f16[3,5,10] dot(c0, r1), + lhs_contracting_dims={1}, rhs_contracting_dims={2}, + lhs_batch_dims={0}, rhs_batch_dims={0} +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, SplitDimensionTwice) { + auto module = ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = s8[4,2,32,4,2] parameter(0) + r1 = s8[8,32,8] reshape(p0) + t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2} + r0 = s8[32,64] reshape(t1) + p1 = s8[32,32] parameter(1) + c0 = f16[32,32] convert(p1) + ROOT d = f16[64,32] dot(r0, c0), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, DoNotTriggerOnUnsupportedOutputConversions) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f16[128,256] parameter(0) + p1 = f16[256,512] parameter(1) + r = f16[128,512] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT c = u8[128,512] convert(r) +})")); + EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, FuseDotWithTrivialNoncontractingDim) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = s8[60,5] parameter(0) + r0 = s8[3,20,5] reshape(p0) + c0 = f16[3,20,5] convert(r0) + p1 = f16[3,1,20] parameter(1) + ROOT d = f16[3,5,1] dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={2}, + lhs_batch_dims={0}, rhs_batch_dims={0} +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, HandleDotIfCublasRequiresPadding) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[5,3] parameter(0) + p1 = f16[5,7] parameter(1) + ROOT d = f16[3,7] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_TRUE(CublasRequiresPadding( + *xla::Cast( + module->entry_computation()->root_instruction()), + cc)); + EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, FuseSliceOfParameterWithOtherUsers) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[97,121] parameter(0) + s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]} + p1 = f32[101,16] parameter(1) + d = f32[16,7] dot(p1, s0), + lhs_contracting_dims={0}, rhs_contracting_dims={1} + s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]} + ROOT t = tuple(d, s1) +})")); + + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, DoNotFuseSliceOfMixedDimensions) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = bf16[768,64] parameter(0) + s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]} + b0 = bf16[256,3,32] reshape(s0) + b1 = bf16[256,96] reshape(b0) + p1 = bf16[256,96] parameter(1) + ROOT d = bf16[96,96] dot(b1, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, DoNotFuseSlicesOfNonMajorFragments) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[2,2,256,256] parameter(0) + s0 = f32[1,1,256,256] slice(p0), + slice={[0:1], [0:1], [0:256], [0:256]} + r0 = f32[256,256] reshape(s0) + p1 = f16[2,2,256,256] parameter(1) + s1 = f16[1,1,256,256] slice(p1), + slice={[0:1], [0:1], [0:256], [0:256]} + r1 = f16[256,256] reshape(s1) + ROOT d = f32[256,256] dot(r0, r1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value()); +} + +TEST_F(GemmFusionTest, SliceToDegenerateIsSkipped) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p = f32[3] parameter(0) + s = f32[1] slice(p), slice={[2:3]} + r = f32[] reshape(s) + b = f32[3,3] broadcast(r), dimensions={} + ROOT d = f32[3,3] dot(b, b), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)")); + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + + ASSERT_TRUE(GemmFusion(cc).Run(module.get()).value()); + + // Slice is not fused. + MatchHloModule(*module, R"( +; CHECK-NOT: slice +; CHECK: ENTRY +; CHECK: slice +)"); +} + +TEST_F(GemmFusionTest, MultipleUsesAreHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + c = f32[] constant(1) + b = f32[6,8] broadcast(c), dimensions={} + p0 = f32[6,8] parameter(0) + a1 = f32[6,8] add(p0, b) + e = f32[6,8] exponential(a1) + a2 = f32[6,8] add(e, b) + d = f32[6,8] divide(b, a2) + p2 = f16[8,6] parameter(1) + cv = f32[8,6] convert(p2) + ROOT r = f32[6,6] dot(d, cv), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, BinaryElementwiseOfBroadcastIsFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p2 = f32[3072] parameter(2) + b = f32[8192,3072] broadcast(p2), dimensions={1} + p0 = f16[8192,3072] parameter(0) + p0c = f32[8192,3072] convert(p0) + a = f32[8192,3072] add(p0c, b) + p1 = f32[3072,768] parameter(1) + ROOT r = f32[8192,768] dot(a, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionTest, BinaryElementwiseOfUnsupportedBroadcastIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p2 = f32[768] parameter(2) + b = f32[8192,768,4] broadcast(p2), dimensions={1} + s = f32[8192,3072] bitcast(b) + p0 = f16[8192,3072] parameter(0) + p0c = f32[8192,3072] convert(p0) + a = f32[8192,3072] add(p0c, s) + p1 = f32[3072,768] parameter(1) + ROOT r = f32[8192,768] dot(a, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; + EXPECT_FALSE(GemmFusion(cc).Run(module.get()).value()); +} + +class GemmFusionLevel2Test : public GemmFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_fusion_level(2); + return debug_options; + } +}; + +TEST_F(GemmFusionLevel2Test, ReshapeToScalarIsHandled) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = s8[5,3] parameter(0) + c = f16[5,3] convert(p0) + p1 = f16[1] parameter(1) + r = f16[] reshape(p1) + b = f16[5,7] broadcast(r) + ROOT d = f16[3,7] dot(c, b), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionLevel2Test, DoNotFuseIncompatibleDimensionSplits) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p1 = s8[5,7,2,3]{3,2,1,0} parameter(1) + t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3} + r1 = s8[7,30]{1,0} reshape(t1) + cvt = f16[7,30]{1,0} convert(r1) + p2 = f16[2,7,5,3]{3,2,1,0} parameter(2) + t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3} + r2 = f16[7,30]{1,0} reshape(t2) + a = f16[7,30]{1,0} add(cvt, r2) + p0 = f16[7,79]{1,0} parameter(0) + ROOT dot = f16[30,79]{1,0} dot(a, p0), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter()))); +} + +TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParameters) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + tmp_0 = f32[] constant(1) + tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_2 = f32[3,49]{1,0} parameter(6) + tmp_3 = f32[] constant(0) + tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} + tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT + tmp_6 = f32[3,49]{1,0} convert(tmp_5) + tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) + tmp_8 = s32[] parameter(13) + tmp_9 = f32[] convert(tmp_8) + tmp_10 = f32[] maximum(tmp_9, tmp_0) + tmp_11 = f32[] divide(tmp_3, tmp_10) + tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} + tmp_13 = pred[3,49]{1,0} parameter(7) + tmp_14 = pred[3,49]{1,0} parameter(10) + tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) + tmp_16 = f32[3,49]{1,0} convert(tmp_15) + tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) + tmp_18 = f32[3,49]{1,0} negate(tmp_17) + tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) + tmp_20 = f32[3,49]{1,0} parameter(19) + tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) + tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) + tmp_23 = f32[3,49]{1,0} negate(tmp_22) + tmp_24 = f32[3,49]{1,0} negate(tmp_6) + tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) + tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) + tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) + tmp_28 = f32[3,49]{1,0} parameter(18) + tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) + tmp_30 = f32[3,49]{1,0} parameter(17) + tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) + tmp_32 = f32[3,49]{1,0} parameter(16) + tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) + tmp_34 = f32[3,49]{1,0} parameter(15) + tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) + tmp_36 = f32[3,49]{1,0} parameter(14) + tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) + tmp_38 = f32[1,1]{1,0} constant({ {0} }) + tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} + tmp_40 = f32[] reshape(tmp_39) + tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} + tmp_42 = u32[48]{0} parameter(11) + tmp_43 = u32[48]{0} parameter(5) + tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} + tmp_45 = u32[3,32]{1,0} reshape(tmp_44) + tmp_46 = u32[96]{0} reshape(tmp_45) + tmp_47 = u32[] constant(1) + tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} + tmp_49 = u32[96]{0} reshape(tmp_48) + tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) + tmp_51 = u32[3,32]{1,0} reshape(tmp_50) + tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) + tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) + tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} + tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) + tmp_56 = f32[1,1]{1,0} constant({ {1} }) + tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} + tmp_58 = f32[] reshape(tmp_57) + tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} + tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) + tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) + tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) + tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} + tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT + tmp_65 = f32[3,32]{1,0} convert(tmp_64) + tmp_66 = f32[3,49]{1,0} parameter(9) + tmp_67 = f32[49]{0} parameter(4) + tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} + tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) + tmp_70 = f32[1,49]{1,0} parameter(12) + tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} + tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) + tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} + tmp_74 = f32[49]{0} reshape(tmp_73) + tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} + tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) + tmp_77 = f32[1,49]{1,0} parameter(3) + tmp_78 = f32[1,49]{1,0} parameter(8) + tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) + tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) + tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) + tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) + tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) + tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) + tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} + tmp_86 = f32[49]{0} reshape(tmp_85) + tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} + tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) + tmp_89 = f32[1,49]{1,0} parameter(2) + tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} + tmp_91 = f32[49]{0} reshape(tmp_90) + tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} + tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) + tmp_94 = f32[49,32]{1,0} parameter(1) + tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} + tmp_96 = f32[32]{0} parameter(0) + tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} + tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) + tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) + tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) + tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) + ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + TritonFusionAnalysis::kMaxParameterPerDotOperand * 2); +} + +TEST_F(GemmFusionLevel2Test, + DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) { + static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4, + "We have to update this test."); + // If we fuse the select, it adds 2 additional parameters at once (not 3, + // because the select instruction itself is removed from the parameters). + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[3,49]{1,0} parameter(0) + b = f32[3,49]{1,0} parameter(1) + c = pred[3,49]{1,0} parameter(2) + d = f32[3,49]{1,0} parameter(3) + e = f32[3,49]{1,0} parameter(4) + add0 = f32[3,49]{1,0} add(a, b) + select = f32[3,49]{1,0} select(c, d, e) + add1 = f32[3,49]{1,0} add(add0, select) + f = f32[3,32]{1,0} parameter(5) + ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + TritonFusionAnalysis::kMaxParameterPerDotOperand + 1); +} + +TEST_F(GemmFusionLevel2Test, DoNotFuseTooManyParametersForConcat) { + static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4, + "We have to update this test."); + // The concat shouldn't overgo the allowed parameter limit. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[3,3]{1,0} parameter(0) + b = f32[3,3]{1,0} parameter(1) + c = f32[3,3]{1,0} parameter(2) + d = f32[3,3]{1,0} parameter(3) + e = f32[3,3]{1,0} parameter(4) + f = f16[3,3]{1,0} parameter(5) + concat = f32[15,3]{1,0} concatenate(a, b, c, d, e), dimensions={0} + convert = f32[3,3]{1,0} convert(f) + ROOT dot = f32[15,3]{1,0} dot(concat, convert), lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + TritonFusionAnalysis::kMaxParameterPerDotOperand + 1); +} + +TEST_F(GemmFusionLevel2Test, + InstructionsReachableFromMultipleOperandsAreHandledCorrectly) { + static_assert(TritonFusionAnalysis::kMaxParameterPerDotOperand == 4, + "We have to update this test."); + // There was a bug that some dead code was generated into some fusions in a + // specific edge case. When some instructions were reachable both through the + // LHS and the RHS operands, the BFS (Breadth-first search) through the LHS1 + // operand "marked" one operation as non-fusible because it would exceed the + // limit on fusion parameters per operand. But the BFS through the RHS operand + // went through that node and fused some more operands. So the resulting + // fusion was not connected and caused errors. This test case checks that such + // configurations generate a correct HLO now. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[2,4]{1,0} parameter(0) + b = f32[2,4]{1,0} parameter(1) + c = f32[2,4]{1,0} parameter(2) + d = f32[2,4]{1,0} parameter(3) + e = f32[2,4]{1,0} parameter(4) + add0 = f32[2,4]{1,0} add(a, b) + add1 = f32[2,4]{1,0} add(add0, c) + add2 = f32[2,4]{1,0} add(add1, d) + add3 = f32[2,4]{1,0} add(add2, e) + ROOT r = f32[2,2]{1,0} dot(add3, add0), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + // ~VerifiedHloModule() will verify the module. +} + +TEST_F(GemmFusionLevel2Test, EachScopeIsFusedToASeparateSubgraph) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[2,4]{1,0} parameter(0) + b = f32[2,4]{1,0} parameter(1) + add = f32[2,4]{1,0} add(a, b) + ROOT r = f32[2,2]{1,0} dot(add, add), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) +CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]) +CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2) +CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3) +CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]]) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]]) +CHECK: ENTRY +CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} +CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]), +CHECK-SAME: kind=kCustom +CHECK-SAME: __triton_gemm +})"); +} + +// The 2 inputs of the add operation are the same and they are iterated the same +// way, so the same parameter node is reused for them. +// The reuse happens per "operand fusion", so the add of the LHS and RHS still +// use different nodes. +TEST_F(GemmFusionLevel2Test, ParamNodesAreReusedIfTheyHaveTheSameIterSpec) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[2,4]{1,0} parameter(0) + add = f32[2,4]{1,0} add(a, a) + ROOT r = f32[2,2]{1,0} dot(add, add), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) +CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]]) +CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) +CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]]) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]]) +CHECK: ENTRY +CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} +CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]]) +CHECK-SAME: kind=kCustom +CHECK-SAME: __triton_gemm +})"); +} + +// NEGATE has the same iteration spec at both usages, so the node is reused +// (implying that P0 is also reused). +TEST_F(GemmFusionLevel2Test, NonParamNodesAreReusedIfTheyHaveTheSameIterSpec) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[4,4]{1,0} parameter(0) + b = f32[4,4]{1,0} parameter(1) + negate = f32[4,4]{1,0} negate(a) + sine = f32[4,4]{1,0} sine(negate) + add = f32[4,4]{1,0} add(negate, sine) + ROOT r = f32[4,4]{1,0} dot(add, b), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) +CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]]) +CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]]) +CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]]) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]]) +CHECK: ENTRY +CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} +CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]]) +CHECK-SAME: kind=kCustom +CHECK-SAME: __triton_gemm +})"); +} + +// The direct read of the input and the transposed read of the input have +// different iteration specs, so we don't reuse the node. +TEST_F(GemmFusionLevel2Test, NodesAreNotReusedIfTheyHaveDifferentIterSpecs) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[4,4]{1,0} parameter(0) + b = f32[4,4]{1,0} parameter(1) + tr_a = f32[4,4]{1,0} transpose(a), dimensions={1,0} + add = f32[4,4]{1,0} add(a, tr_a) + ROOT r = f32[4,4]{1,0} dot(add, b), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) +CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2) +CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]]) +CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]]) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]]) +CHECK: ENTRY +CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) +CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} +CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]]) +CHECK-SAME: kind=kCustom +CHECK-SAME: __triton_gemm +})"); +} + +TEST_F(GemmFusionLevel2Test, OperationsAddingMoreParametersGetMultipleTries) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = f32[2,2] parameter(0) + c0 = f32[] constant(12345) + b0 = f32[2,2] broadcast(c0), dimensions={} + m0 = f32[2,2] multiply(p0, b0) + c1 = f32[] constant(34567) + b1 = f32[2,2] broadcast(c1), dimensions={} + a0 = f32[2,2] add(m0, b1) + b3 = f32[2,2,2] broadcast(a0), dimensions={0,1} + p2 = f32[2,2,2] parameter(2) + m2 = f32[2,2,2] multiply(p2, b3) + p1 = f32[2]{0} parameter(1) + c2 = f32[] constant(5678) + b2 = f32[2] broadcast(c2), dimensions={} + a1 = f32[2]{0} add(p1, b2) + b4 = f32[2,2,2] broadcast(a1), dimensions={2} + m1 = f32[2,2,2] multiply(m2, b4) + b = f32[4,2] bitcast(m1) + p3 = f16[2,2] parameter(3) + p3c = f32[2,2] convert(p3) + ROOT r = f32[4,2] dot(b, p3c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter())))); +} + +TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[2,53] parameter(0) + p0e = f32[2,53] exponential(p0) + p1 = s16[53,2] parameter(1) + p1c = f32[53,2] convert(p1) + ROOT dot = f32[2,2] dot(p0e, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_THAT( + GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); +} + +TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[2,35] parameter(0) + p0n = f32[2,35] negate(p0) + p0e = f32[2,35] exponential(p0) + a = f32[2,35] add(p0e, p0n) + p1 = f16[35,2] parameter(1) + p1c = f32[35,2] convert(p1) + ROOT dot = f32[2,2] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter())))); + TF_ASSERT_OK_AND_ASSIGN( + const auto analysis, + TritonFusionAnalysis::Execute(*module->entry_computation() + ->root_instruction() + ->called_computations()[0])); + EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(), + 1); + EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(), + 1); +} + +TEST_F(GemmFusionLevel2Test, + ParameterUsedNonElementwiseTwiceIsFusedOnBothPaths) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +ENTRY e { + p0 = f32[4,4] parameter(0) + p0t = f32[4,4] transpose(p0), dimensions={1,0} + a = f32[4,4] add(p0, p0t) + p1 = f16[4,5] parameter(1) + p1c = f32[4,5] convert(p1) + ROOT dot = f32[4,5] dot(a, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())))); +} + +TEST_F(GemmFusionLevel2Test, + ComputationParameterWithMultipleUsersIsNotTrivialToFuse) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = f32[400,400] parameter(0) + + c0 = f16[400,400] convert(p0) + p1 = f16[400,400] parameter(1) + dot0 = f16[400,400] dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + + c1 = f16[400,400] convert(p0) + p2 = f16[400,400] parameter(2) + dot1 = f16[400,400] dot(c1, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + + ROOT a = f16[400,400] add(dot0, dot1) +})")); + EXPECT_FALSE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); +} + +TEST_F(GemmFusionLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + p0 = s8[512,512] parameter(0) + c0 = f16[512,512] convert(p0) + p1 = f16[512,512] parameter(1) + dot0 = f16[512,512] dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + + n = f16[512,512] negate(c0) + ROOT a = f16[512,512] add(dot0, n) +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()), + m::Negate())))); +} + +TEST_F(GemmFusionLevel2Test, NestedSlicingIsAnalyzedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm_d_computation { + p0 = f32[6,24]{1,0} parameter(0) + s1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]} + n1 = f32[5,20]{1,0} negate(s1) + s2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]} + p1 = f32[7,37]{1,0} parameter(1) + ROOT d = f32[3,37]{1,0} dot(s2, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[7,37]{1,0} parameter(0) + p1 = f32[6,24]{1,0} parameter(1) + ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom, + calls=triton_gemm_d_computation +})")); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, + computation->parameter_instruction(0), 0), + ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6, + /*slice_start=*/2, /*sliced_count=*/3, + /*subfragments=*/ElementsAre(3)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, + computation->parameter_instruction(0), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, + /*slice_start=*/16, /*sliced_count=*/7, + /*subfragments=*/ElementsAre(7)))); +} + +TEST_F(GemmFusionLevel2Test, FusedConcatenationIsAnalyzedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[153,1536] parameter(0) + p1 = s8[153,128] parameter(1) + p2 = s8[153,256] parameter(2) + cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1} + cvt = bf16[153,1920] convert(cat) + p3 = bf16[16,153] parameter(3) + ROOT d = bf16[16,1920] dot(p3, cvt), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter())))); + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(1), 0), + ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153, + /*slice_start=*/0, /*sliced_count=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(1), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536, + /*slice_start=*/0, /*sliced_count=*/1536, + /*subfragments=*/ElementsAre(1536)))); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(2), 0), + ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153, + /*slice_start=*/0, /*sliced_count=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(2), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, + /*slice_start=*/-1536, /*sliced_count=*/128, + /*subfragments=*/ElementsAre(128)))); + + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(3), 0), + ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153, + /*slice_start=*/0, /*sliced_count=*/153, + /*subfragments=*/ElementsAre(153)))); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, + computation->parameter_instruction(3), 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, + /*slice_start=*/-1536 - 128, + /*sliced_count=*/256, + /*subfragments=*/ElementsAre(256)))); +} + +TEST_F(GemmFusionLevel2Test, IndivisibleConcatenationIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024] parameter(0) + p1 = s8[124,1001] parameter(1) + cat = s8[124,2025] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2025] convert(cat) + p2 = f16[123,124] parameter(2) + ROOT d = f16[2025,123] dot(cvt, p2), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmFusionLevel2Test, ConcatenationOfContractingIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024] parameter(0) + p1 = s8[124,1024] parameter(1) + cat = s8[124,2048] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2048] convert(cat) + p2 = f16[123,2048] parameter(2) + ROOT d = f16[124,123] dot(cvt, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmFusionLevel2Test, ConcatenationOfBatchIsNotFused) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[124,1024,50] parameter(0) + p1 = s8[124,1024,50] parameter(1) + cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1} + cvt = f16[124,2048,50] convert(cat) + p2 = f16[123,2048,50] parameter(2) + ROOT d = f16[2048,124,123] dot(cvt, p2), + lhs_batch_dims={1}, rhs_batch_dims={1}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} +})")); + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); +} + +TEST_F(GemmFusionLevel2Test, + DifferentConcatenationOfSameParametersIsFusedViaNodeDuplication) { + // It means that the same input is passed to the fusion multiple times and + // it's read differently for each. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +e { + p0 = s8[128,2] parameter(0) + p1 = s8[128,2] parameter(1) + cat0 = s8[256,2] concatenate(p0, p1), dimensions={0} + cvt0 = f16[256,2] convert(cat0) + cat1 = s8[256,2] concatenate(p1, p0), dimensions={0} + n1 = s8[256,2] negate(cat1) + cvt1 = f16[256,2] convert(n1) + a = f16[256,2] add(cvt1, cvt0) + p2 = f16[2,18] parameter(2) + ROOT d = f16[18,256] dot(p2, a), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})")); + + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) + .Run(module.get()) + .value()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter())))); +} + +TEST_F(GemmFusionTest, CopiesDotMetadataToFusionOp) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[2,18] parameter(0) + p1 = f16[256,2] parameter(1) + ROOT d = f16[18,256] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="foo"} +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_EQ( + module->entry_computation()->root_instruction()->metadata().op_name(), + "foo"); +} + +// A test fixture class for testing the threshold for small matrices. +class SmallDotGemmFusionTest : public GemmFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100); + return debug_options; + } +}; + +TEST_F(SmallDotGemmFusionTest, SkipSmallMatrixMultiplicationRewrite) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[2,10] parameter(0) + p1 = f16[10,2] parameter(1) + ROOT d = f16[10,10] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})") + .value(); + + EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1) +; CHECK: ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]]) +})"); +} + +TEST_F(SmallDotGemmFusionTest, LargeMatrixMultiplicationIsRewritten) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY e { + p0 = f16[2,18] parameter(0) + p1 = f16[50,2] parameter(1) + ROOT d = f16[18,50] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={1} +})") + .value(); + + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,18], {{.*}}: f16[50,2]) -> f16[18,50] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1) +; CHECK: ROOT {{.*}} = f16[18,50]{1,0} +; CHECK: fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]), +; CHECK: kind=kCustom +; CHECK: __triton_gemm +})"); +} + +class SparseDotTest : public GemmFusionTest {}; + +TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,16] parameter(0) + rhs = f16[32,2] parameter(1) + meta = u16[2,2] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1) +; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2) +; CHECK: ROOT {{.*}} = f32[2,2]{1,0} +; CHECK-SAME: fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]), +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_gemm +})"); +} + +TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,32] parameter(0) + rhs = f16[16,2] parameter(1) + meta = u16[2,2] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4 +})") + .value(); + auto result = GemmFusion(gpu_version_).Run(module.get()); + EXPECT_FALSE(result.ok()); +} + +TEST_F(SparseDotTest, UnsupportedSparsityType) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,8] parameter(0) + rhs = f16[32,2] parameter(1) + meta = u16[2,1] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4 +})") + .value(); + auto result = GemmFusion(gpu_version_).Run(module.get()); + EXPECT_FALSE(result.ok()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemm_rewriter.cc b/xla/service/gpu/gemm_rewriter.cc index 919573f08cd02..0aa610fc92f33 100644 --- a/xla/service/gpu/gemm_rewriter.cc +++ b/xla/service/gpu/gemm_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,25 +19,33 @@ limitations under the License. #include #include #include +#include #include #include #include #include +#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -52,8 +60,11 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/dnn.pb.h" @@ -68,21 +79,22 @@ namespace { namespace m = match; // Give this instruction a more useful name than "custom-call.42". -Status SetName(HloModule *module, HloInstruction *gemm) { +absl::Status SetName(HloModule *module, HloInstruction *gemm) { if (IsCublasLtMatmul(*gemm)) { module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul"); - return OkStatus(); + return absl::OkStatus(); } - GemmBackendConfig config; - TF_ASSIGN_OR_RETURN(config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + gemm->backend_config()); + const GemmBackendConfig &config = gpu_config.gemm_backend_config(); const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() || !dot_dims.rhs_batch_dimensions().empty(); module->SetAndUniquifyInstrName( gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm"); - return OkStatus(); + return absl::OkStatus(); } // Returns whether a given PrimitiveType is supported by cuBLASLt Epilogue @@ -216,7 +228,6 @@ bool IsSupportedF8Pattern( } std::reverse(subgraph.begin(), subgraph.end()); - // When not operating directly on an FP8 operand, the second and // third instructions in the subgraph must describe a dequantization, i.e. a // convert instruction followed by a multiply/divide instruction. @@ -243,19 +254,21 @@ bool IsSupportedF8Pattern( auto use_spmd_partitioning = [](const HloInstruction *instr) -> bool { return instr->GetModule()->config().use_spmd_partitioning(); }; + for (int i = 3; i < subgraph.size(); ++i) { // The remaining instructions must be commutative with dequantization. // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, select, slice, - // all-gather, all-to-all and collective-permute instructions are supported. - // Specifically, the all-gather, all-to-all and collective-permute - // operations are permitted only in SPMD cases since the optimization cannot - // be guaranteed to be applied to all replicas in the MPMD scenario. + // transpose, all-gather, all-to-all and collective-permute instructions are + // supported. Specifically, the all-gather, all-to-all and + // collective-permute operations are permitted only in SPMD cases since the + // optimization cannot be guaranteed to be applied to all replicas in the + // MPMD scenario. if (!Match( subgraph[i].first, m::AnyOf( m::Bitcast().WithPredicate(preserves_element_type), m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(), - m::Reshape(), m::Select(), m::Slice(), + m::Reshape(), m::Select(), m::Slice(), m::Transpose(), m::AllGather().WithPredicate(use_spmd_partitioning), m::AllToAll().WithPredicate(use_spmd_partitioning), m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) { @@ -471,20 +484,40 @@ auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) { // when the output of the GEMM is requested in FP8 format. class GemmRewriterVisitor : public DfsHloRewriteVisitor { public: - explicit GemmRewriterVisitor(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version, + const bool f8_rewrite) + : gpu_version_(gpu_version), f8_rewrite_(f8_rewrite) {} - Status HandleDot(HloInstruction *instr) override { - if (!IsMatrixMultiplication(*instr)) { - return OkStatus(); + absl::Status HandleDot(HloInstruction *instr) override { + if (!IsMatrixMultiplication(*instr) && + !IsMatrixVectorMultiplication(*instr)) { + return absl::OkStatus(); + } + // Sparse dot is not supported. + if (Cast(instr)->sparse_operands()) { + return absl::OkStatus(); + } + + int64_t gemm_rewrite_size_threshold = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_gemm_rewrite_size_threshold(); + TF_ASSIGN_OR_RETURN(bool is_matmul_tiny, + IsMatrixMultiplicationTooSmallForRewriting( + *instr, gemm_rewrite_size_threshold)); + if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) { + return absl::OkStatus(); } CHECK(!instr->IsRank2Transpose()); CHECK(!instr->mutable_operand(0)->IsRank2Transpose()); CHECK(!instr->mutable_operand(1)->IsRank2Transpose()); - // Create a GemmBackendConfig based on the instruction. - GemmBackendConfig gemm_backend_config; + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + instr->backend_config()); + GemmBackendConfig &gemm_backend_config = + *gpu_backend_config.mutable_gemm_backend_config(); gemm_backend_config.set_alpha_real(1.0); gemm_backend_config.set_alpha_imag(0.0); gemm_backend_config.set_beta(0.0); @@ -500,78 +533,97 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { int64_t lhs_batch_dims_size = instr->dot_dimension_numbers().lhs_batch_dimensions_size(); - int64_t lhs_stride = lhs->shape().dimensions(lhs_batch_dims_size) * - lhs->shape().dimensions(lhs_batch_dims_size + 1); - int64_t rhs_stride = rhs->shape().dimensions(lhs_batch_dims_size) * - rhs->shape().dimensions(lhs_batch_dims_size + 1); + bool is_lhs_vector = + lhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + bool is_rhs_vector = + rhs->shape().dimensions_size() == lhs_batch_dims_size + 1; + int64_t lhs_stride = + is_lhs_vector ? lhs->shape().dimensions(lhs_batch_dims_size) + : lhs->shape().dimensions(lhs_batch_dims_size) * + lhs->shape().dimensions(lhs_batch_dims_size + 1); + int64_t rhs_stride = + is_rhs_vector ? rhs->shape().dimensions(lhs_batch_dims_size) + : rhs->shape().dimensions(lhs_batch_dims_size) * + rhs->shape().dimensions(lhs_batch_dims_size + 1); gemm_backend_config.set_lhs_stride(lhs_stride); gemm_backend_config.set_rhs_stride(rhs_stride); - // First try to match the fp8 gemm pattern. - TF_ASSIGN_OR_RETURN(bool supported_by_cublaslt, - GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); - HloInstruction *a, *b, *a_scale = nullptr, *b_scale = nullptr; - // Sequence of ops between dequantization and GEMM which are mathematically - // commutative with dequantization. The second element of the pair gives the - // index of the operand identifying the next op in the sequence. - std::vector> a_ops, b_ops; - bool a_mult_scale, b_mult_scale; - if (supported_by_cublaslt && - Match(instr, - m::Dot(m::Op().WithPredicate([&](const HloInstruction *instr) { - return IsSupportedF8Pattern(const_cast(instr), - a, a_scale, a_mult_scale, a_ops); - }), - m::Op().WithPredicate([&](const HloInstruction *instr) { - return IsSupportedF8Pattern( - const_cast(instr), b, b_scale, - b_mult_scale, b_ops); - })))) { + if (f8_rewrite_) { + // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. TF_ASSIGN_OR_RETURN( - bool created_call, - CreateF8CustomCall(instr, gemm_backend_config, a, b, a_scale, b_scale, - a_mult_scale, b_mult_scale, a_ops, b_ops)); - if (created_call) { - return OkStatus(); + bool supported_by_cublaslt, + GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); + HloInstruction *a, *b, *a_scale = nullptr, *b_scale = nullptr; + // Sequence of ops between dequantization and GEMM which are + // mathematically commutative with dequantization. The second element of + // the pair gives the index of the operand identifying the next op in the + // sequence. + std::vector> a_ops, b_ops; + bool a_mult_scale{}, b_mult_scale{}; + if (supported_by_cublaslt && + Match(instr, + m::Dot(m::Op().WithPredicate([&](const HloInstruction *instr) { + return IsSupportedF8Pattern( + const_cast(instr), a, a_scale, + a_mult_scale, a_ops); + }), + m::Op().WithPredicate([&](const HloInstruction *instr) { + return IsSupportedF8Pattern( + const_cast(instr), b, b_scale, + b_mult_scale, b_ops); + })))) { +#if TENSORFLOW_USE_ROCM + if (instr->shape().element_type() != F16 && + instr->shape().element_type() != F32) { + TF_ASSIGN_OR_RETURN(instr, + TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + } +#endif // TENSORFLOW_USE_ROCM + TF_ASSIGN_OR_RETURN(bool created_call, + CreateF8CustomCall(instr, gpu_backend_config, a, b, + a_scale, b_scale, a_mult_scale, + b_mult_scale, a_ops, b_ops)); + if (created_call) { + return absl::OkStatus(); + } } - } - - if (IsF8Type(instr->operand(0))) { - // Couldn't rewrite as an FP8 cublasLt custom call, so turn into an FP16 - // dot and below it will be rewritten as an FP16 cublas or cublasLt call. - TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); - } - - // Couldn't rewrite as an FP8 cublasLt custom call, rewrite as a cublas or - // cublasLt call. - TF_ASSIGN_OR_RETURN( - absl::string_view gemm_custom_call_target, - GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); - const Shape &output_shape = instr->shape(); - HloInstruction *gemm_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); - return OkStatus(); + if (IsF8Type(instr->operand(0))) { + // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt + // custom call, so turn into an FP16 dot which may be rewritten as an + // FP16 Triton, cublas or cublasLt call. + TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); + } + } else { + // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {instr->mutable_operand(0), instr->mutable_operand(1)}, + gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + } + return absl::OkStatus(); } - Status HandleMultiply(HloInstruction *instr) override { + absl::Status HandleMultiply(HloInstruction *instr) override { HloInstruction *alpha, *existing_gemm; if (Match(instr, m::MultiplyAnyOrder( GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) { - TF_ASSIGN_OR_RETURN(auto config, - existing_gemm->backend_config()); - + TF_ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); // Do not fuse alpha into S32 GEMM, as they only support fixed values for // alpha/beta. if (existing_gemm->shape().element_type() == S32) { - return OkStatus(); + return absl::OkStatus(); } if (config.beta() == 0.0 && existing_gemm->user_count() == 1) { @@ -580,7 +632,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { *alpha->literal().GetAsComplex128({}) * prev_alpha; config.set_alpha_real(new_alpha.real()); config.set_alpha_imag(new_alpha.imag()); - TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config)); + TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); return ReplaceInstruction(instr, existing_gemm); } } @@ -593,10 +645,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (Match(instr, m::MultiplyAnyOrder( m::AnyOf( m::Slice(&slice_or_bitcast, - CublasLtMatmul(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), m::Bitcast(&slice_or_bitcast, - CublasLtMatmul(&existing_gemm)), - CublasLtMatmul(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), m::Op(&cdf).WithOneUser())) && Match(cdf, m::MultiplyAnyOrder( @@ -630,10 +682,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { .WithOneUser())))) { return FuseGeluActivation(instr, existing_gemm, slice_or_bitcast); } - return OkStatus(); + return absl::OkStatus(); } - Status HandleAdd(HloInstruction *instr) override { + absl::Status HandleAdd(HloInstruction *instr) override { HloInstruction *bias, *existing_gemm = nullptr; HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; @@ -657,7 +709,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { optional_convert, optional_bitcast)); if (was_fused) { - return OkStatus(); + return absl::OkStatus(); } } // Attempt to elide broadcast and fuse addition of a vector bias into @@ -729,9 +781,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { GemmOrCublasLtMatmul(&existing_gemm).WithOneUser()) .WithOneUser()), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - TF_ASSIGN_OR_RETURN(GemmBackendConfig gemm_backend_config, - existing_gemm->backend_config()); - + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + existing_gemm->backend_config()); + const GemmBackendConfig &gemm_backend_config = + gpu_backend_config.gemm_backend_config(); // check if type combination is supported here TF_ASSIGN_OR_RETURN( bool types_are_supported, @@ -768,14 +821,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { .WithOneUser())) .WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - return FuseMatrixBiasAdd(instr, bias, existing_gemm, - optional_bitcast_matrix, optional_slice_matrix); + // The matrix bias must not be FP8, see + // https://docs.nvidia.com/cuda/cublas/index.html. + if (!IsF8Type(bias)) { + return FuseMatrixBiasAdd(instr, bias, existing_gemm, + optional_bitcast_matrix, + optional_slice_matrix); + } } - return OkStatus(); + return absl::OkStatus(); } - Status HandleMaximum(HloInstruction *instr) override { + absl::Status HandleMaximum(HloInstruction *instr) override { HloInstruction *existing_gemm, *zeros; HloInstruction *optional_slice_or_bitcast = nullptr; // Attempt to elide maximum and fuse ReLU activation into GEMM, including @@ -795,10 +853,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm, optional_slice_or_bitcast)); } - return OkStatus(); + return absl::OkStatus(); } - Status HandleConvert(HloInstruction *instr) override { + absl::Status HandleConvert(HloInstruction *instr) override { HloInstruction *clamp_lower, *clamp_upper, *d_scale, *existing_gemm, *binary; // Attempt to elide the scaling and conversion of the result of an FP8 @@ -824,15 +882,17 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr, existing_gemm, d_scale, clamp_lower, clamp_upper, /*mult_scale=*/binary->opcode() == HloOpcode::kMultiply); } - return OkStatus(); + return absl::OkStatus(); } - StatusOr CreateF8CustomCall( - HloInstruction *instr, GemmBackendConfig &gemm_backend_config, + absl::StatusOr CreateF8CustomCall( + HloInstruction *instr, GpuBackendConfig &gpu_backend_config, HloInstruction *a, HloInstruction *b, HloInstruction *a_scale, HloInstruction *b_scale, bool a_mult_scale, bool b_mult_scale, std::vector> a_ops, std::vector> b_ops) { + GemmBackendConfig &gemm_backend_config = + *gpu_backend_config.mutable_gemm_backend_config(); #if GOOGLE_CUDA auto cuda_compute_capability_ = std::get(gpu_version_); @@ -843,17 +903,38 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { << "FP8 Custom Calls require Ada, Hopper, or later architectures."; return false; } + #if CUDA_VERSION < 12000 // FP8 GEMM kernels are only available with CUDA 12.0 and above VLOG(1) << "FP8 Custom Calls require CUDA 12.0 or newer."; return false; #endif // CUDA_VERSION < 12000 +#endif // GOOGLE_CUDA + +#if TENSORFLOW_USE_ROCM + auto isrocm = std::get_if(&gpu_version_); + if (!isrocm->has_fp8_support()) { + VLOG(1) << "FP8 Custom Calls require MI300, or later architectures."; + return false; + } + +#if TF_ROCM_VERSION < 60000 + // FP8 GEMM kernels are only available with ROCm 6.0 and above + VLOG(1) << "FP8 Custom Calls require ROCm 6.0 or newer."; + return false; +#endif // TF_ROCM_VERSION < 60000 + +#endif // TENSORFLOW_USE_ROCM + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM PrimitiveType a_type = a->shape().element_type(); PrimitiveType b_type = b->shape().element_type(); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // cuBLASLt FP8 GEMM kernels require one of the two operands to be in // F8E4M3FN format. +#if GOOGLE_CUDA if (a_type == F8E5M2 && b_type == F8E5M2) { VLOG(1) << "Failed to rewrite " << instr->ToShortString() @@ -870,6 +951,26 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { << PrimitiveType_Name(b_type); return false; } +#endif // GOOGLE_CUDA + +#if TENSORFLOW_USE_ROCM + if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) { + VLOG(1) + << "Failed to rewrite " << instr->ToShortString() + << " into FP8 Custom Call. The element type of one of the operands " + "must be F8E4M3FNUZ."; + return false; + } + if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) || + (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) { + VLOG(1) << "Failed to rewrite " << instr->ToShortString() + << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or " + "F8E4M3FNUZ, but got " + << PrimitiveType_Name(a_type) << " and " + << PrimitiveType_Name(b_type); + return false; + } +#endif // TENSORFLOW_USE_ROCM absl::Span batch_dims = gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions(); @@ -912,6 +1013,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { case F32: break; default: + VLOG(1) << "Failed to rewrite " << instr->ToShortString() << " into FP8 Custom Call. Output element type must be " "F8E4M3FN, F8E5M2, BF16, F16 or F32. Actual element type is " @@ -1043,9 +1145,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->shape().element_type(), new_output_shape.dimensions(), instr->shape().layout().minor_to_major()), operands_list, kCublasLtMatmulF8CallTarget)); - - TF_RETURN_IF_ERROR( - new_custom_call->set_backend_config(gemm_backend_config)); + TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config)); TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call)); // Slice the result of the GEMM if the operands were padded. @@ -1062,35 +1162,33 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ReplaceInstruction(instr, slice ? slice : new_custom_call)); VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call."; return true; -#else // TENSORFLOW_USE_ROCM - return false; -#endif } - Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm, - HloInstruction *d_scale, HloInstruction *clamp_lower, - HloInstruction *clamp_upper, bool mult_scale = false) { + absl::Status F8ConvertD(HloInstruction *instr, HloInstruction *existing_gemm, + HloInstruction *d_scale, HloInstruction *clamp_lower, + HloInstruction *clamp_upper, + bool mult_scale = false) { // Verify the data types and the operands of clamp. if (instr->shape().element_type() == F8E4M3FN) { if (!clamp_lower->literal().IsAllFloat(static_cast( std::numeric_limits::lowest())) || !clamp_upper->literal().IsAllFloat(static_cast( std::numeric_limits::max()))) { - return OkStatus(); + return absl::OkStatus(); } } else if (instr->shape().element_type() == F8E5M2) { if (!clamp_lower->literal().IsAllFloat(static_cast( std::numeric_limits::lowest())) || !clamp_upper->literal().IsAllFloat(static_cast( std::numeric_limits::max()))) { - return OkStatus(); + return absl::OkStatus(); } } else { - return OkStatus(); + return absl::OkStatus(); } if (!ShapeUtil::IsScalar(d_scale->shape())) { - return OkStatus(); + return absl::OkStatus(); } // The possible second user of the GEMM must be the calculation of the @@ -1102,8 +1200,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (gemm_users.size() == 2) { // In the presence of a ReLU activation, the abs instruction is elided // since abs(ReLU(x)) = ReLU(x). - TF_ASSIGN_OR_RETURN(auto config, - existing_gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); + const GemmBackendConfig &config = gpu_config.gemm_backend_config(); for (int i = 0; i < gemm_users.size(); ++i) { HloInstruction *maybe_reduce = nullptr; if (gemm_users[i]->opcode() == HloOpcode::kAbs) { @@ -1134,14 +1233,17 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } if (!reduce_damax) { - return OkStatus(); + return absl::OkStatus(); } } else if (gemm_users.size() > 2) { - return OkStatus(); + return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto gemm_backend_config, - existing_gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_backend_config, + existing_gemm->backend_config()); + const GemmBackendConfig &gemm_backend_config = + gpu_backend_config.gemm_backend_config(); + if (gemm_backend_config.beta() != 0.0 && existing_gemm->operand(2)->shape().element_type() != BF16 && existing_gemm->operand(2)->shape().element_type() != F16) { @@ -1150,7 +1252,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { << " is not fused into the FP8 Custom Call because it " "conflicts with the existing fusion of the addition of a " "matrix bias with element type other than BF16 or F16."; - return OkStatus(); + return absl::OkStatus(); } // If necessary, invert the scaling factor of D and convert to F32. @@ -1178,12 +1280,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { existing_gemm->CloneWithNewShape(instr->shape()); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm))); - return OkStatus(); + return absl::OkStatus(); } // Adds a scalar DAmax return value to an FP8 GEMM. - Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm, - HloInstruction *reduce_damax) { + absl::Status F8AddDAmax(HloInstruction *instr, HloInstruction *existing_gemm, + HloInstruction *reduce_damax) { // Change the output shape of the Custom Call to tuple(D, DAmax). Shape damax_shape = ShapeUtil::MakeScalarShape(F32); Shape tuple_shape = @@ -1204,7 +1306,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted)); TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d)); - return OkStatus(); + return absl::OkStatus(); } // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add @@ -1213,10 +1315,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced // when the inputs of the gemm are possibly padded. Bitcast is introduced to // handle high rank input. - Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, - const HloInstruction *gemm, - HloInstruction *bitcast = nullptr, - HloInstruction *slice = nullptr) { + absl::Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, + const HloInstruction *gemm, + HloInstruction *bitcast = nullptr, + HloInstruction *slice = nullptr) { TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(), bitcast ? bitcast->shape() : slice ? slice->shape() @@ -1225,7 +1327,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. if (gemm->shape().element_type() == S32) { - return OkStatus(); + return absl::OkStatus(); } // To ensure correctness, only slices that chop off the ends of dimensions @@ -1234,7 +1336,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { int slice_op_dim = slice->operand(0)->shape().rank(); if (slice->slice_starts() != std::vector(slice_op_dim, 0) || slice->slice_strides() != std::vector(slice_op_dim, 1)) { - return OkStatus(); + return absl::OkStatus(); } } // Cublas gemm overwrites the bias matrix, so fusion is only possible if the @@ -1269,8 +1371,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || IsCublasLtMatmul(*gemm) || can_overwrite_bias; - auto config = gemm->backend_config().value(); - + auto gpu_config = gemm->backend_config().value(); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); // It is possible to fuse into a cublasLt matmul that already has a vector // bias, but no other epilogue will commute with the matrix bias add. bool supported_epilogue = @@ -1279,7 +1381,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { - return OkStatus(); + return absl::OkStatus(); } config.set_beta(1.0); @@ -1302,8 +1404,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm->CloneWithNewOperands(gemm->shape(), operands); // set output shape to bias shape if mix type fused_op->mutable_shape()->set_element_type(bias->shape().element_type()); - - TF_RETURN_IF_ERROR(fused_op->set_backend_config(config)); + TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config)); // Choose whether the bias must alias the output. Legacy cublas GEMMs must // operate in place and alias the bias with the output, whereas with @@ -1348,12 +1449,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // convert is only used for F8 matmuls as cublasLt has specific constraints // on the vector bias type for such matmuls. The optional bitcast is // necessary to handle high rank input cases. - StatusOr FuseVectorBiasAdd(HloInstruction *instr, - HloInstruction *broadcast, - HloInstruction *gemm, - HloInstruction *slice = nullptr, - HloInstruction *convert = nullptr, - HloInstruction *bitcast = nullptr) { + absl::StatusOr FuseVectorBiasAdd(HloInstruction *instr, + HloInstruction *broadcast, + HloInstruction *gemm, + HloInstruction *slice = nullptr, + HloInstruction *convert = nullptr, + HloInstruction *bitcast = nullptr) { if (!bitcast) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); @@ -1365,8 +1466,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *bias = broadcast->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); - + TF_ASSIGN_OR_RETURN(auto gpu_config, + gemm->backend_config()); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); // # output column dims == # non-contracting rhs operand dims. const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers(); size_t num_col_dims = gemm->operand(1)->shape().rank() - @@ -1454,7 +1556,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.set_epilogue(GemmBackendConfig::BIAS); std::unique_ptr result = gemm->CloneWithNewOperands(gemm->shape(), operands); - TF_RETURN_IF_ERROR(result->set_backend_config(config)); + TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); if (slice) { result = slice->CloneWithNewOperands( @@ -1470,32 +1572,35 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return true; } - Status FuseReluActivation(HloInstruction *instr, HloInstruction *broadcast, - HloInstruction *gemm, - HloInstruction *slice_or_bitcast = nullptr) { + absl::Status FuseReluActivation(HloInstruction *instr, + HloInstruction *broadcast, + HloInstruction *gemm, + HloInstruction *slice_or_bitcast = nullptr) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice_or_bitcast ? slice_or_bitcast->shape() : gemm->shape()))); if (!SupportsEpilogueFusion(gemm->shape().element_type())) { - return OkStatus(); + return absl::OkStatus(); } if (gemm->user_count() != 1) { - return OkStatus(); + return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + gemm->backend_config()); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); if (config.epilogue() == GemmBackendConfig::DEFAULT) { config.set_epilogue(GemmBackendConfig::RELU); } else if (config.epilogue() == GemmBackendConfig::BIAS) { config.set_epilogue(GemmBackendConfig::BIAS_RELU); } else { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr result = gemm->Clone(); - TF_RETURN_IF_ERROR(result->set_backend_config(config)); + TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); if (slice_or_bitcast) { @@ -1507,16 +1612,30 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(instr, std::move(result)); } - Status FuseGeluActivation(HloInstruction *multiply, HloInstruction *gemm, - HloInstruction *slice_or_bitcast = nullptr) { + absl::Status FuseGeluActivation(HloInstruction *multiply, + HloInstruction *gemm, + HloInstruction *slice_or_bitcast = nullptr) { if (!SupportsEpilogueFusion(gemm->shape().element_type())) { - return OkStatus(); + return absl::OkStatus(); + } + +#if CUDA_VERSION < 12040 + // For CUDA versions less than 12.3.2, cuBLAS LT returns + // CUBLAS_STATUS_NOT_SUPPORTED in some cases when fusing gelu into an FP8 + // matmul. We cannot check the patch version, so disable this fusion with + // CUDA versions less than 12.4. + if (IsCublasLtMatmulF8(*gemm)) { + return absl::OkStatus(); } +#endif // There are four users of the gemm output within the GELU calculation. bool has_aux = gemm->user_count() > 4; - TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + gemm->backend_config()); + GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); + if (config.epilogue() == GemmBackendConfig::DEFAULT) { config.set_epilogue(has_aux ? GemmBackendConfig::GELU_AUX : GemmBackendConfig::GELU); @@ -1524,13 +1643,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.set_epilogue(has_aux ? GemmBackendConfig::BIAS_GELU_AUX : GemmBackendConfig::BIAS_GELU); } else { - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr output = gemm->CloneWithNewShape( has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()}) : gemm->shape()); - TF_RETURN_IF_ERROR(output->set_backend_config(config)); + TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config)); TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); if (slice_or_bitcast) { @@ -1552,10 +1671,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { private: se::GpuComputeCapability gpu_version_; + bool f8_rewrite_; // Choose cublas or cublasLt for the target of the custom call that instr will // be rewritten into. - StatusOr GetNonFp8GemmCustomCallTarget( + absl::StatusOr GetNonFp8GemmCustomCallTarget( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config) const { if (!instr.GetModule() @@ -1588,7 +1708,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::string_view(kGemmCallTarget); } - StatusOr TypesAreSupportedByLegacyCublas( + absl::StatusOr TypesAreSupportedByLegacyCublas( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, const HloInstruction *bias = nullptr) const { // Figure out the Atype/Btype. @@ -1605,9 +1725,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (!absl::c_linear_search(supported_type, output_type)) return false; TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, se::gpu::AsBlasDataType(output_type)); + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, se::gpu::GetBlasComputationType( - a_dtype, output_type, + PrecisionConfig::ALG_UNSET, a_dtype, output_type, stream_executor::blas::kDefaultComputePrecision)); se::blas::DataType scale_type = se::gpu::GetScaleType(output_dtype, compute_type); @@ -1675,7 +1797,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { output_dtype)); } - StatusOr TypesAreSupportedByCublasLt( + absl::StatusOr TypesAreSupportedByCublasLt( const HloInstruction &instr, const GemmBackendConfig &backend_config, const HloInstruction *bias = nullptr) const { // Figure out the Atype/Btype. @@ -1684,21 +1806,27 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const PrimitiveType output_type = bias ? bias->shape().element_type() : instr.shape().element_type(); const std::array supported_type = { - PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, PrimitiveType::S8, - PrimitiveType::F16, PrimitiveType::BF16, PrimitiveType::F32, - PrimitiveType::S32, PrimitiveType::F64, PrimitiveType::C64, - PrimitiveType::C128}; + PrimitiveType::F8E5M2FNUZ, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, + PrimitiveType::S8, PrimitiveType::F16, + PrimitiveType::BF16, PrimitiveType::F32, + PrimitiveType::S32, PrimitiveType::F64, + PrimitiveType::C64, PrimitiveType::C128}; if (!absl::c_linear_search(supported_type, output_type)) return false; // cublasLt has a defined set of combinations of types that it supports. // Figure out the computeType and scaleType. TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, se::gpu::AsBlasDataType(output_type)); - int max_precision = *absl::c_max_element( + const int max_precision = *absl::c_max_element( backend_config.precision_config().operand_precision()); + const PrecisionConfig::Algorithm algorithm = + backend_config.precision_config().algorithm(); + if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false; + TF_ASSIGN_OR_RETURN( const se::blas::ComputationType compute_type, - se::gpu::GetBlasComputationType(a_dtype, instr.shape().element_type(), - max_precision)); + se::gpu::GetBlasComputationType( + algorithm, a_dtype, instr.shape().element_type(), max_precision)); se::blas::DataType scale_type = se::gpu::GetScaleType(output_dtype, compute_type); @@ -1746,6 +1874,39 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2, PrimitiveType::F8E4M3FN, DataType::kFloat}, #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM + // FP8 types: + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2FNUZ, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2FNUZ, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ, + PrimitiveType::F8E5M2FNUZ, DataType::kFloat}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ, + PrimitiveType::F8E4M3FNUZ, DataType::kFloat}, +#endif // TENSORFLOW_USE_ROCM // Other data types: {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16, PrimitiveType::F16, DataType::kHalf}, @@ -1803,7 +1964,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { output_dtype)); } - StatusOr MatrixIsColumnMajor( + absl::StatusOr MatrixIsColumnMajor( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, const std::string matrix_name = "output") const { const HloInstruction *lhs = instr.operand(0); @@ -1811,6 +1972,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const DotDimensionNumbers &dot_dims = gemm_backend_config.dot_dimension_numbers(); + // We use ALG_UNSET and kDefaultComputePrecision because we don't care about + // the precision, just the layout, since we're just checking if the matrix + // is column-major. TF_ASSIGN_OR_RETURN( GemmConfig gemm_config, GemmConfig::For( @@ -1820,6 +1984,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { dot_dims.rhs_contracting_dimensions(), /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(), gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), + /*precision_algorithm=*/PrecisionConfig::ALG_UNSET, /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision, gemm_backend_config.grad_x(), gemm_backend_config.grad_y())); @@ -1831,11 +1996,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor; } else { - return InternalError("Invalid matrix name."); + return Internal("Invalid matrix name."); } } - StatusOr GemmIsSupportedByCublasLt( + absl::StatusOr GemmIsSupportedByCublasLt( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config) const { const HloInstruction *lhs = instr.operand(0); @@ -1869,12 +2034,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN(bool output_is_column_major, MatrixIsColumnMajor(instr, gemm_backend_config)); - if (std::holds_alternative(gpu_version_)) { - auto rocm_compute_capability_ = - std::get(gpu_version_); - - // as of ROCm 5.5, hipblaslt only supports MI200. - if (rocm_compute_capability_.gcn_arch_name().substr(0, 6) != "gfx90a") { + if (auto isrocm = std::get_if(&gpu_version_); + isrocm) { + if (!isrocm->has_hipblaslt()) { return false; } } @@ -1931,10 +2093,25 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return lhs_non_contracting_dimension_size <= kMaxDimensionSize; } +#if TENSORFLOW_USE_ROCM + // Turns an F8 dot with unsupported output type into an F8 dot with F32 + // output, and converting the F32 output to unsupported output types. + absl::StatusOr TurnF8DotWithUnsupportedOutputTypeIntoF32( + HloInstruction *instr) { + Shape output_f32_shape = instr->shape(); + output_f32_shape.set_element_type(F32); + HloInstruction *f32_dot = + instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape)); + HloInstruction *convert = instr->AddInstruction( + HloInstruction::CreateConvert(instr->shape(), f32_dot)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert)); + return f32_dot; + } +#endif // TENSORFLOW_USE_ROCM + // Turns an F8 dot into an F16 dot, converting operands to F16 and // converting the output back to F8. - StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { - DCHECK(IsF8Type(instr)); + absl::StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { DCHECK(IsF8Type(instr->operand(0))); DCHECK(IsF8Type(instr->operand(1))); @@ -1948,15 +2125,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert)); } - // Clone instruction and convert output to F8 - Shape output_f16_shape = instr->shape(); - output_f16_shape.set_element_type(F16); - HloInstruction *f16_dot = - instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape)); - HloInstruction *convert_to_f8 = instr->AddInstruction( - HloInstruction::CreateConvert(instr->shape(), f16_dot)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8)); - return f16_dot; + // If output is F8, change output to F16 and then convert it back to F8 + if (IsF8Type(instr)) { + Shape output_f16_shape = instr->shape(); + output_f16_shape.set_element_type(F16); + HloInstruction *f16_dot = + instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape)); + HloInstruction *convert_to_f8 = instr->AddInstruction( + HloInstruction::CreateConvert(instr->shape(), f16_dot)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8)); + return f16_dot; + } else { + return instr; + } } }; @@ -1965,13 +2146,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // having to match output tuples. class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { public: - explicit GemmWorkspaceRewriteVisitor(se::GpuComputeCapability gpu_version) + explicit GemmWorkspaceRewriteVisitor( + const se::GpuComputeCapability &gpu_version) : gpu_version_(gpu_version) {} - Status HandleCustomCall(HloInstruction *instr) override { + absl::Status HandleCustomCall(HloInstruction *instr) override { if (instr->custom_call_target() != kGemmCallTarget || !instr->shape().IsArray()) { - return OkStatus(); + return absl::OkStatus(); } auto *cuda_cc = std::get_if(&gpu_version_); @@ -1997,18 +2179,6 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { } workspace = std::min(workspace, operands_byte_size); - // If CUDA graphs are disabled (command buffer implementation detail), - // then we reset the workspace size to 0 and rely on cuBlas to allocate - // workspace from its own pool. - // - // TODO(ezhulenev): Remove this work around, allocating workspace - // explicitly should always be better than relying on cuBlas. - bool cuda_graphs_disabled = instr->GetModule() - ->config() - .debug_options() - .xla_gpu_enable_command_buffer_size() == 0; - if (cuda_graphs_disabled) workspace = 0; - // Append workspace buffer to instruction outputs. std::vector output_shapes = {instr->shape()}; output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace})); @@ -2033,9 +2203,10 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { se::GpuComputeCapability gpu_version_; }; -StatusOr RunOnComputation(HloComputation *computation, - se::GpuComputeCapability gpu_version) { - GemmRewriterVisitor visitor(gpu_version); +absl::StatusOr RunOnComputation(HloComputation *computation, + se::GpuComputeCapability gpu_version, + bool f8_rewrite) { + GemmRewriterVisitor visitor(gpu_version, f8_rewrite); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version); TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); @@ -2044,17 +2215,18 @@ StatusOr RunOnComputation(HloComputation *computation, } // anonymous namespace -GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} +GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version, + bool f8_rewrite) + : gpu_version_(gpu_version), f8_rewrite_(f8_rewrite) {} -StatusOr GemmRewriter::Run( +absl::StatusOr GemmRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { bool changed = false; for (HloComputation *computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, gpu_version_)); + TF_ASSIGN_OR_RETURN( + bool result, RunOnComputation(computation, gpu_version_, f8_rewrite_)); changed |= result; } return changed; diff --git a/xla/service/gpu/gemm_rewriter.h b/xla/service/gpu/gemm_rewriter.h index 65efd940450e4..576f75efaeb34 100644 --- a/xla/service/gpu/gemm_rewriter.h +++ b/xla/service/gpu/gemm_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GEMM_REWRITER_H_ #define XLA_SERVICE_GPU_GEMM_REWRITER_H_ -#include - -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" @@ -31,9 +31,10 @@ namespace gpu { // (kMultiply (kDot A B) alpha) // (kMultiply C beta)) // -// where A, B, C are matrixes and `alpha` and `beta` are host constants. -// The additional requirement is that C has no other users (otherwise, -// it does not make sense to fuse it inside the custom call). +// where A, B, C are matrices or vectors and `alpha` and `beta` are host +// constants. In matrix-vector multiplication, one operand must be a matrix and +// the other must be a vector. The additional requirement is that C has no other +// users (otherwise, it does not make sense to fuse it inside the custom call). // // Both multiplication and addition can be avoided (equivalent to setting // `alpha` to one and `beta` to zero). @@ -44,16 +45,20 @@ namespace gpu { // stored in the backend config. class GemmRewriter : public HloModulePass { public: - explicit GemmRewriter(se::GpuComputeCapability gpu_version); + // When f8_rewrite is true, only FP8 GEMMs are rewritten. Otherwise, non-FP8 + // GEMMs are rewritten. + explicit GemmRewriter(se::GpuComputeCapability gpu_version, + bool f8_rewrite = false); absl::string_view name() const override { return "cublas-gemm-rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: se::GpuComputeCapability gpu_version_; + bool f8_rewrite_; }; } // namespace gpu diff --git a/xla/service/gpu/gemm_rewriter_triton.cc b/xla/service/gpu/gemm_rewriter_triton.cc deleted file mode 100644 index f618411e7f49d..0000000000000 --- a/xla/service/gpu/gemm_rewriter_triton.cc +++ /dev/null @@ -1,498 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gemm_rewriter_triton.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/triton_fusion_analysis.h" -#include "xla/service/gpu/triton_support.h" -#include "xla/service/gpu/triton_tiling_propagation.h" -#include "xla/service/instruction_fusion.h" -#include "xla/shape_util.h" -#include "xla/status.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/tensor_float_32_utils.h" - -namespace xla { -namespace gpu { - -namespace { - -using triton_fusion::DimOrdersAndReqs; -using triton_fusion::DimOrdersAndReqsOrError; -using triton_fusion::FusionContext; -using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible; -using triton_fusion::TransformDirection; - -using OldToNewHloMap = - absl::flat_hash_map; - -// Gets the fused HLO corresponding to `hlo` or adds a new parameter if not -// found. -HloInstruction* GetFusedHloOrAddParameter( - HloInstruction& hlo, OldToNewHloMap& old_to_new_map, - std::vector& fusion_inputs, - HloComputation::Builder& builder) { - if (auto it = old_to_new_map.find(&hlo); it != old_to_new_map.end()) { - return it->second; - } - fusion_inputs.push_back(&hlo); - return old_to_new_map - .insert( - {&hlo, builder.AddInstruction(HloInstruction::CreateParameter( - fusion_inputs.size() - 1, hlo.shape(), - absl::StrCat("parameter_", fusion_inputs.size() - 1)))}) - .first->second; -} - -// Clone an instruction into the fusion. -// -// For the hero dot operation in the dot fusion, please use FuseDotOnly. -void Fuse(HloInstruction& hlo, OldToNewHloMap& old_to_new_map, - std::vector& fusion_inputs, - HloComputation::Builder& builder) { - if (old_to_new_map.contains(&hlo)) { - return; - } - VLOG(3) << "Fusing " << hlo.ToString(); - if (hlo.opcode() == HloOpcode::kParameter || - hlo.opcode() == HloOpcode::kGetTupleElement) { - GetFusedHloOrAddParameter(hlo, old_to_new_map, fusion_inputs, builder); - } else { - std::vector hlo_new_operands; - for (HloInstruction* operand : hlo.operands()) { - hlo_new_operands.push_back(GetFusedHloOrAddParameter( - *operand, old_to_new_map, fusion_inputs, builder)); - } - old_to_new_map[&hlo] = builder.AddInstruction( - hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); - } -} - -// Clones the hero kDot operation into the fusion. -void FuseDotOnly(HloInstruction& hlo, OldToNewHloMap& output_old_to_new_map, - OldToNewHloMap& lhs_old_to_new_map, - OldToNewHloMap& rhs_old_to_new_map, - std::vector& fusion_inputs, - HloComputation::Builder& builder) { - CHECK_EQ(hlo.opcode(), HloOpcode::kDot); - CHECK_EQ(hlo.operand_count(), 2); - VLOG(3) << "Fusing " << hlo.ToString(); - - std::array hlo_new_operands = { - GetFusedHloOrAddParameter(*hlo.mutable_operand(0), lhs_old_to_new_map, - fusion_inputs, builder), - GetFusedHloOrAddParameter(*hlo.mutable_operand(1), rhs_old_to_new_map, - fusion_inputs, builder)}; - output_old_to_new_map[&hlo] = builder.AddInstruction( - hlo.CloneWithNewOperands(hlo.shape(), hlo_new_operands)); -} - -// Tells how many new parameters does a fusion gain by fusing the operation as -// an input. -int64_t NumAddedParameters(const HloInstruction& hlo) { - // Non-scalar constant is equivalent to a parameter: one input, one output. - if (hlo.opcode() == HloOpcode::kConstant && - !ShapeUtil::IsScalar(hlo.shape())) { - return 0; - } - // All other instructions add all own inputs and remove own single output. - return hlo.operand_count() - 1; -} - -// Fuse an instruction with all its fusible inputs. -// If an input is not fusible stop there and make a parameter of the new -// fusion, otherwise put it onto stack and check its own inputs first. -void TryToFuseWithInputsRecursively(HloInstruction& root, - se::GpuComputeCapability gpu_version, - triton_fusion::FusionContext& context, - OldToNewHloMap& old_to_new_map, - std::vector& fusion_inputs, - HloComputation::Builder& builder) { - // Instructions at the fusion edge that can either get fused too or - // become parameters of the fusion. Used to track the number of parameters. - absl::flat_hash_set inputs; - // Traverse all connected instructions that could be fused, analyze them and - // collect ones that will be fused. - absl::flat_hash_set to_fuse_set; - std::list to_fuse_list; - absl::flat_hash_set enqueued; - std::queue to_visit; - to_visit.push(&root); - int num_requeued = 0; - while (to_visit.size() > num_requeued) { - HloInstruction* hlo = to_visit.front(); - to_visit.pop(); - // Watch the total number of fusion parameters. - if (inputs.size() + NumAddedParameters(*hlo) > - TritonFusionAnalysis::kMaxParameterPerDotScope) { - // Re-queue: the number of parameters may go down when other instructions - // are processed. - to_visit.push(hlo); - // Prevent infinite loops. - ++num_requeued; - continue; - } - num_requeued = 0; - const DimOrdersAndReqsOrError result = - GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *hlo, TransformDirection::kOutputToInput, - /*src_operand_index=*/std::nullopt, context.dim_orders().at(hlo), - gpu_version, context.hero_properties()); - if (!std::holds_alternative(result) || - !context.CombineDimOrdersAndReqs(std::get(result))) { - continue; - } - if (hlo->opcode() != HloOpcode::kParameter) { - inputs.erase(hlo); - } - inputs.insert(hlo->operands().cbegin(), hlo->operands().cend()); - to_fuse_set.insert(hlo); - to_fuse_list.push_back(hlo); - for (HloInstruction* operand : hlo->operands()) { - if (enqueued.insert(operand).second) { - VLOG(6) << "Enqueueing " << operand->ToString(); - to_visit.push(operand); - } - } - } - // Find one by one instructions that have no operands queued to be fused and - // fuse them. - while (!to_fuse_list.empty()) { - for (auto it = to_fuse_list.begin(); it != to_fuse_list.end();) { - bool ready_to_fuse = true; - for (const HloInstruction* operand : (*it)->operands()) { - if (to_fuse_set.contains(operand)) { - ready_to_fuse = false; - break; - } - } - if (ready_to_fuse) { - Fuse(**it, old_to_new_map, fusion_inputs, builder); - to_fuse_set.erase(*it); - it = to_fuse_list.erase(it); - } else { - ++it; - } - } - } -} - -// Fuses dot and the compatible and profitable to fuse operations around it -// into a new fusion computation constructed using the builder. fusion_inputs -// get populated with the non-fused instructions that become operands of the -// call to this fusion. fusion_output_ptr (if not nullptr) gets assigned the -// original instruction that has to be replaced by the call to the fusion. -StatusOr FuseDot(HloInstruction& dot, - const se::GpuComputeCapability gpu_version, - HloComputation::Builder& builder, - std::vector& fusion_inputs, - HloInstruction** fusion_output_ptr) { - VLOG(5) << dot.ToString(); - if (FusionDecision can_handle = CanTritonHandleGEMM(dot, gpu_version); - !can_handle) { - VLOG(3) << can_handle.Explain(); - return can_handle; - } - - // Separate traversal from LHS and RHS inputs of the dot: they use - // differently shaped tiles but may go through same HLO graph nodes. - // Direct dot inputs have well defined dimension orders. - - auto fuse_inputs = - [&](int operand_number, - OldToNewHloMap& old_to_new_map) -> StatusOr { - const int operand_count_before = fusion_inputs.size(); - // Direct dot inputs have well defined dimension orders. - auto context = FusionContext::FromDotOperand(dot, operand_number); - TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), - gpu_version, context, old_to_new_map, - fusion_inputs, builder); - const int new_parameters = fusion_inputs.size() - operand_count_before; - TF_RET_CHECK(new_parameters <= - TritonFusionAnalysis::kMaxParameterPerDotScope) - << "Too many new parameters: " << new_parameters << " > " - << TritonFusionAnalysis::kMaxParameterPerDotScope; - return context; - }; - - // Original instruction -> fused one. Separate for each scope. - OldToNewHloMap lhs_old_to_new_map; - TF_ASSIGN_OR_RETURN(const FusionContext lhs_context, - fuse_inputs(0, lhs_old_to_new_map)); - - OldToNewHloMap rhs_old_to_new_map; - if (auto result = fuse_inputs(1, rhs_old_to_new_map); !result.ok()) { - return result.status(); - } - - OldToNewHloMap output_old_to_new_map; - // Fuse the dot into output_old_to_new_map and use lhs_old_to_new_map and - // rhs_old_to_new_map to generate / determine its operands. - FuseDotOnly(dot, output_old_to_new_map, lhs_old_to_new_map, - rhs_old_to_new_map, fusion_inputs, builder); - - // Fusion at dot's output. - - // These describe _outputs_ of corresponding HLOs. - auto context = FusionContext::FromDotOutput( - dot, /*split_k=*/1, lhs_context.splittable_dimension_major_part_size()); - HloInstruction* fusion_output = ˙ - bool output_changed = true; - while (output_changed) { - output_changed = false; - if (fusion_output->user_count() != 1) { - break; - } - HloInstruction* user = fusion_output->users()[0]; - if (!IsDistributiveOverAddition(*user)) { - break; - } - DimOrdersAndReqsOrError result = - GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( - *user, TransformDirection::kInputToOutput, - user->operand_index(fusion_output), - context.dim_orders().at(fusion_output), gpu_version, - context.hero_properties()); - if (!std::holds_alternative(result) || - !context.CombineDimOrdersAndReqs(std::get(result))) { - break; - } - for (HloInstruction* operand : user->operands()) { - if (!output_old_to_new_map.contains(operand)) { - TryToFuseWithInputsRecursively(*operand, gpu_version, context, - output_old_to_new_map, fusion_inputs, - builder); - } - } - Fuse(*user, output_old_to_new_map, fusion_inputs, builder); - fusion_output = user; - output_changed = true; - } - if (fusion_output_ptr != nullptr) { - *fusion_output_ptr = fusion_output; - } - if (dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { - return FusionDecision{}; - } - - for (auto* old_to_new_map : std::array{ - &lhs_old_to_new_map, &rhs_old_to_new_map, &output_old_to_new_map}) { - for (auto [_, new_hlo] : *old_to_new_map) { - static constexpr std::array kPureOpcodes = { - HloOpcode::kBitcast, HloOpcode::kDot, HloOpcode::kParameter, - HloOpcode::kReshape}; - // Fuse if this is not a "pure" matmul. - if (absl::c_find(kPureOpcodes, new_hlo->opcode()) == kPureOpcodes.end()) { - return FusionDecision{}; - } - } - } - return "No profitable operations to fuse."; -} - -// Extracts into fused computations parts of HLO graph including dot() -// operations that can target the triton GEMM emitter. -class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmRewriterTritonVisitor(const se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} - // Checks that a dot() should be targeting the triton GEMM emitter; - // if so - fuses all its compatible inputs and outputs as a new computation - // and replaces the original dot() with a call to the computation. - Status HandleDot(HloInstruction* dot) override { - std::string fusion_name = absl::StrCat("triton_gemm_", dot->name()); - HloComputation::Builder builder(absl::StrCat(fusion_name, "_computation")); - std::vector fusion_inputs; - HloInstruction* fusion_output = nullptr; - TF_ASSIGN_OR_RETURN( - const FusionDecision should_fuse, - FuseDot(*dot, gpu_version_, builder, fusion_inputs, &fusion_output)); - if (builder.last_added_instruction() == nullptr) { - return OkStatus(); - } - // If a GEMM requiring padding for cuBLAS is encountered here this - // happened because earlier ShouldTritonHandleGEMM() accepted it and padding - // was skipped. Accept it ignoring profitability checks. - if (!CublasRequiresPadding( - *Cast(dot), - std::get(gpu_version_)) && - !should_fuse) { - return OkStatus(); - } - - HloComputation* computation = - dot->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), - /*is_entry=*/false); - HloInstruction* dot_fusion = - dot->parent()->AddInstruction(HloInstruction::CreateFusion( - computation->root_instruction()->shape(), - HloInstruction::FusionKind::kCustom, fusion_inputs, computation)); - dot_fusion->GetModule()->SetAndUniquifyInstrName(dot_fusion, fusion_name); - - TF_ASSIGN_OR_RETURN(auto backend_config, - dot_fusion->backend_config()); - backend_config.set_kind(std::string(kTritonGemmFusionKind)); - TF_RETURN_IF_ERROR(dot_fusion->set_backend_config(backend_config)); - - if (fusion_output->IsRoot()) { - fusion_output->parent()->set_root_instruction(dot_fusion); - TF_RETURN_IF_ERROR( - fusion_output->parent()->RemoveInstructionAndUnusedOperands( - fusion_output)); - MarkAsChanged(); - } else { - TF_RETURN_IF_ERROR(ReplaceInstruction(fusion_output, dot_fusion)); - } - XLA_VLOG_LINES(5, computation->ToString(HloPrintOptions::ShortParsable())); - return OkStatus(); - } - - private: - se::GpuComputeCapability gpu_version_; -}; - -StatusOr RunOnComputation(HloComputation* computation, - se::GpuComputeCapability gpu_version) { - GemmRewriterTritonVisitor visitor(gpu_version); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - return visitor.changed(); -} - -} // namespace - -FusionDecision CanTritonHandleGEMM(const HloInstruction& dot, - const se::GpuComputeCapability gpu_version) { - if (dot.opcode() != HloOpcode::kDot || - !tsl::tensor_float_32_execution_enabled() || - absl::c_any_of(dot.precision_config().operand_precision(), - [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Non-default precision."; - } - - auto supported_output_type = [&](const PrimitiveType t) { - const auto cuda_compute_capability = - std::get(gpu_version); - switch (t) { - case F16: - case F32: - return true; - case BF16: - return cuda_compute_capability.IsAtLeast( - stream_executor::CudaComputeCapability::AMPERE); - default: - return false; - } - }; - - // TODO(b/266862493): Support more output types. - if (!supported_output_type(dot.shape().element_type())) { - return "Unsupported output data type."; - } - - if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), - gpu_version) || - !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), - gpu_version)) { - return "Unsupported input data type."; - } - - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - - // TODO(b/269580541): support multiple batch dimensions. - if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; - } - - // Cases where lhs or rhs have no non-contracting dims are not handled. - if (dim_numbers.lhs_batch_dimensions().size() + - dim_numbers.lhs_contracting_dimensions().size() == - dot.operand(0)->shape().rank() || - dim_numbers.rhs_batch_dimensions().size() + - dim_numbers.rhs_contracting_dimensions().size() == - dot.operand(1)->shape().rank()) { - return "No non-contracting dimensions."; - } - - for (int operand_number = 0; operand_number <= 1; ++operand_number) { - // This pass relies on dot decomposer which ensures that all non-contracting - // dimensions are merged into one. Using NonContractingDimensionIndex is - // sufficient. - const int64_t nc_size = - dot.operand(operand_number) - ->shape() - .dimensions(NonContractingDimensionIndex(dot, operand_number)); - if (nc_size <= 1) { - return "Trivial non-contracting dimensions."; - } - } - - return FusionDecision{}; -} - -bool ShouldTritonHandleGEMM(HloInstruction& dot, - const se::GpuComputeCapability gpu_version) { - std::vector fusion_inputs; - HloComputation::Builder builder("disposable"); - return FuseDot(dot, gpu_version, builder, fusion_inputs, - /*fusion_output_ptr=*/nullptr) - ->CanFuse(); -} - -StatusOr GemmRewriterTriton::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - bool changed = false; - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, gpu_version_)); - changed |= result; - } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/gemm_rewriter_triton.h b/xla/service/gpu/gemm_rewriter_triton.h deleted file mode 100644 index 34a77bf905bbb..0000000000000 --- a/xla/service/gpu/gemm_rewriter_triton.h +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ -#define XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ - -// This file contains the code for fusing dots and other operations into Triton -// GEMM fusions. - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/service/instruction_fusion.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" - -namespace xla { -namespace gpu { - -// Filters GEMMs which can be handled using Triton. -FusionDecision CanTritonHandleGEMM(const HloInstruction&, - se::GpuComputeCapability gpu_version); - -// Filters GEMMs which are better to handle using Triton. -bool ShouldTritonHandleGEMM(HloInstruction&, - se::GpuComputeCapability gpu_version); - -// Rewrite compatible dot() calls into custom calls with fused computations -// that target Triton-based matmul emitter. -class GemmRewriterTriton : public HloModulePass { - public: - explicit GemmRewriterTriton(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} - absl::string_view name() const override { return "triton-gemm-rewriter"; } - - using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - se::GpuComputeCapability gpu_version_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_GEMM_REWRITER_TRITON_H_ diff --git a/xla/service/gpu/gemm_rewriter_triton_test.cc b/xla/service/gpu/gemm_rewriter_triton_test.cc deleted file mode 100644 index 8913782398cca..0000000000000 --- a/xla/service/gpu/gemm_rewriter_triton_test.cc +++ /dev/null @@ -1,954 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gemm_rewriter_triton.h" - -#include -#include - -#include -#include -#include "absl/strings/string_view.h" -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/triton_fusion_analysis.h" -#include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -using ::testing::ElementsAre; -using ::testing::FieldsAre; - -namespace m = ::xla::match; - -class GemmRewriterTritonTest : public HloTestBase { - public: - GemmRewriterTritonTest() - : HloTestBase(/*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/false) {} - - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_gemm_any(false); - return debug_options; - } - - se::GpuComputeCapability gpu_version_{ - se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}}; - - void MatchHloModule(HloModule& module, absl::string_view pattern) { - TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, - RunFileCheck(module.ToString(), pattern)); - EXPECT_TRUE(filecheck_result); - } -}; - -TEST_F(GemmRewriterTritonTest, TransposeSubdimensionGroup) { - // This HLO is artificial because unnecessary reshapes get optimized - // out during compilation. It tests the ability of GemmRewriterTriton - // to handle transposes of groups of subdimensions. - auto module = ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - p0 = f32[32,3] parameter(0) - t1 = f32[3,32] transpose(p0), dimensions={1,0} - r1 = f32[3,8,4] reshape(t1) - r0 = f32[3,32] reshape(r1) - p1 = f16[32,7] parameter(1) - c1 = f32[32,7] convert(p1) - ROOT d = f32[3,7] dot(r0, c1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonTest, UnsupportedTransposeIsNotFused) { - auto module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f16[1,512,8,1024]{3,1,0,2} parameter(0) - c = f16[1,512,8,1024]{3,2,1,0} copy(p0) - b = f16[4096,1024]{1,0} bitcast(c) - p1 = f16[128,1024]{1,0} parameter(1) - ROOT d = f16[4096,128]{1,0} dot(b, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})") - .value(); - EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, BitcastChain) { - // This HLO is artificial because unnecessary reshapes get optimized - // out during compilation. It tests the ability of GemmRewriterTriton - // to handle various kinds of bitcasts. - auto module = ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - p0 = s8[60,5] parameter(0) - r0 = s8[3,20,5] reshape(p0) - c0 = f16[3,20,5] convert(r0) - p1 = f16[3,200] parameter(1) - r12 = f16[600] reshape(p1) - r11 = f16[30,20] reshape(r12) - r1 = f16[3,10,20] reshape(r11) - ROOT d = f16[3,5,10] dot(c0, r1), - lhs_contracting_dims={1}, rhs_contracting_dims={2}, - lhs_batch_dims={0}, rhs_batch_dims={0} -})") - .value(); - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonTest, SplitDimensionTwice) { - auto module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = s8[4,2,32,4,2] parameter(0) - r1 = s8[8,32,8] reshape(p0) - t1 = s8[32,8,8] transpose(r1), dimensions={1,0,2} - r0 = s8[32,64] reshape(t1) - p1 = s8[32,32] parameter(1) - c0 = f16[32,32] convert(p1) - ROOT d = f16[64,32] dot(r0, c0), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -})") - .value(); - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonTest, DoNotTriggerOnUnsupportedOutputConversions) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f16[128,256] parameter(0) - p1 = f16[256,512] parameter(1) - r = f16[128,512] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT c = u8[128,512] convert(r) -})")); - EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, DoNotTriggerWhenTheLhsNoncontractingDimIs1) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = s8[1,256] parameter(0) - p0c = f16[1,256] convert(p0) - p1 = f16[256,512] parameter(1) - ROOT r = f16[1,512] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, DoNotTriggerWhenTheRhsNoncontractingDimIs1) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = s8[128,256] parameter(0) - p0c = f16[128,256] convert(p0) - p1 = f16[256,1] parameter(1) - ROOT r = f16[128,1] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_FALSE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, HandleDotIfCublasRequiresPadding) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - p0 = f16[5,3] parameter(0) - p1 = f16[5,7] parameter(1) - ROOT d = f16[3,7] dot(p0, p1), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - const se::CudaComputeCapability cc{se::CudaComputeCapability::VOLTA, 0}; - EXPECT_TRUE(CublasRequiresPadding( - *xla::Cast( - module->entry_computation()->root_instruction()), - cc)); - EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, FuseSliceOfParameterWithOtherUsers) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[97,121] parameter(0) - s0 = f32[7,101] slice(p0), slice={[3:10], [10:111]} - p1 = f32[101,16] parameter(1) - d = f32[16,7] dot(p1, s0), - lhs_contracting_dims={0}, rhs_contracting_dims={1} - s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]} - ROOT t = tuple(d, s1) -})")); - - const se::CudaComputeCapability cc{se::CudaComputeCapability::VOLTA, 0}; - EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, DoNotFuseSliceOfMixedDimensions) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = bf16[768,64] parameter(0) - s0 = bf16[768,32] slice(p0), slice={[0:768], [0:32]} - b0 = bf16[256,3,32] reshape(s0) - b1 = bf16[256,96] reshape(b0) - p1 = bf16[256,96] parameter(1) - ROOT d = bf16[96,96] dot(b1, p1), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - EXPECT_FALSE(GemmRewriterTriton(cc).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, DoNotFuseSlicesOfNonMajorFragments) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[2,2,256,256] parameter(0) - s0 = f32[1,1,256,256] slice(p0), - slice={[0:1], [0:1], [0:256], [0:256]} - r0 = f32[256,256] reshape(s0) - p1 = f16[2,2,256,256] parameter(1) - s1 = f16[1,1,256,256] slice(p1), - slice={[0:1], [0:1], [0:256], [0:256]} - r1 = f16[256,256] reshape(s1) - ROOT d = f32[256,256] dot(r0, r1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - EXPECT_FALSE(GemmRewriterTriton(cc).Run(module.get()).value()); -} - -TEST_F(GemmRewriterTritonTest, SliceToDegenerateIsSkipped) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p = f32[3] parameter(0) - s = f32[1] slice(p), slice={[2:3]} - r = f32[] reshape(s) - b = f32[3,3] broadcast(r), dimensions={} - ROOT d = f32[3,3] dot(b, b), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - - ASSERT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); - - // Slice is not fused. - MatchHloModule(*module, R"( -; CHECK-NOT: slice -; CHECK: ENTRY -; CHECK: slice -)"); -} - -TEST_F(GemmRewriterTritonTest, MultipleUsesAreHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - c = f32[] constant(1) - b = f32[6,8] broadcast(c), dimensions={} - p0 = f32[6,8] parameter(0) - a1 = f32[6,8] add(p0, b) - e = f32[6,8] exponential(a1) - a2 = f32[6,8] add(e, b) - d = f32[6,8] divide(b, a2) - p2 = f16[8,6] parameter(1) - cv = f32[8,6] convert(p2) - ROOT r = f32[6,6] dot(d, cv), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonTest, BinaryElementwiseOfBroadcastIsFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p2 = f32[3072] parameter(2) - b = f32[8192,3072] broadcast(p2), dimensions={1} - p0 = f16[8192,3072] parameter(0) - p0c = f32[8192,3072] convert(p0) - a = f32[8192,3072] add(p0c, b) - p1 = f32[3072,768] parameter(1) - ROOT r = f32[8192,768] dot(a, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonTest, - BinaryElementwiseOfUnsupportedBroadcastIsNotFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p2 = f32[768] parameter(2) - b = f32[8192,768,4] broadcast(p2), dimensions={1} - s = f32[8192,3072] bitcast(b) - p0 = f16[8192,3072] parameter(0) - p0c = f32[8192,3072] convert(p0) - a = f32[8192,3072] add(p0c, s) - p1 = f32[3072,768] parameter(1) - ROOT r = f32[8192,768] dot(a, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; - EXPECT_FALSE(GemmRewriterTriton(cc).Run(module.get()).value()); -} - -class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = - GemmRewriterTritonTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_fusion_level(2); - return debug_options; - } -}; - -TEST_F(GemmRewriterTritonLevel2Test, ReshapeToScalarIsHandled) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = s8[5,3] parameter(0) - c = f16[5,3] convert(p0) - p1 = f16[1] parameter(1) - r = f16[] reshape(p1) - b = f16[5,7] broadcast(r) - ROOT d = f16[3,7] dot(c, b), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseIncompatibleDimensionSplits) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p1 = s8[5,7,2,3]{3,2,1,0} parameter(1) - t1 = s8[7,5,2,3]{3,2,1,0} transpose(p1), dimensions={1,0,2,3} - r1 = s8[7,30]{1,0} reshape(t1) - cvt = f16[7,30]{1,0} convert(r1) - p2 = f16[2,7,5,3]{3,2,1,0} parameter(2) - t2 = f16[7,2,5,3]{3,2,1,0} transpose(p2), dimensions={1,0,2,3} - r2 = f16[7,30]{1,0} reshape(t2) - a = f16[7,30]{1,0} add(cvt, r2) - p0 = f16[7,79]{1,0} parameter(0) - ROOT dot = f16[30,79]{1,0} dot(a, p0), - lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Transpose(), m::Parameter(), m::Parameter()))); -} - -TEST_F(GemmRewriterTritonLevel2Test, DoNotFuseTooManyParameters) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - tmp_0 = f32[] constant(1) - tmp_1 = f32[3,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_2 = f32[3,49]{1,0} parameter(6) - tmp_3 = f32[] constant(0) - tmp_4 = f32[3,49]{1,0} broadcast(tmp_3), dimensions={} - tmp_5 = pred[3,49]{1,0} compare(tmp_2, tmp_4), direction=GT - tmp_6 = f32[3,49]{1,0} convert(tmp_5) - tmp_7 = f32[3,49]{1,0} subtract(tmp_1, tmp_6) - tmp_8 = s32[] parameter(13) - tmp_9 = f32[] convert(tmp_8) - tmp_10 = f32[] maximum(tmp_9, tmp_0) - tmp_11 = f32[] divide(tmp_3, tmp_10) - tmp_12 = f32[3,49]{1,0} broadcast(tmp_11), dimensions={} - tmp_13 = pred[3,49]{1,0} parameter(7) - tmp_14 = pred[3,49]{1,0} parameter(10) - tmp_15 = pred[3,49]{1,0} and(tmp_13, tmp_14) - tmp_16 = f32[3,49]{1,0} convert(tmp_15) - tmp_17 = f32[3,49]{1,0} multiply(tmp_12, tmp_16) - tmp_18 = f32[3,49]{1,0} negate(tmp_17) - tmp_19 = f32[3,49]{1,0} multiply(tmp_7, tmp_18) - tmp_20 = f32[3,49]{1,0} parameter(19) - tmp_21 = f32[3,49]{1,0} subtract(tmp_1, tmp_20) - tmp_22 = f32[3,49]{1,0} divide(tmp_19, tmp_21) - tmp_23 = f32[3,49]{1,0} negate(tmp_22) - tmp_24 = f32[3,49]{1,0} negate(tmp_6) - tmp_25 = f32[3,49]{1,0} multiply(tmp_24, tmp_17) - tmp_26 = f32[3,49]{1,0} divide(tmp_25, tmp_20) - tmp_27 = f32[3,49]{1,0} add(tmp_23, tmp_26) - tmp_28 = f32[3,49]{1,0} parameter(18) - tmp_29 = f32[3,49]{1,0} multiply(tmp_27, tmp_28) - tmp_30 = f32[3,49]{1,0} parameter(17) - tmp_31 = f32[3,49]{1,0} multiply(tmp_29, tmp_30) - tmp_32 = f32[3,49]{1,0} parameter(16) - tmp_33 = f32[3,49]{1,0} multiply(tmp_31, tmp_32) - tmp_34 = f32[3,49]{1,0} parameter(15) - tmp_35 = f32[3,49]{1,0} add(tmp_33, tmp_34) - tmp_36 = f32[3,49]{1,0} parameter(14) - tmp_37 = f32[3,49]{1,0} add(tmp_35, tmp_36) - tmp_38 = f32[1,1]{1,0} constant({ {0} }) - tmp_39 = f32[1,1]{1,0} broadcast(tmp_38), dimensions={0,1} - tmp_40 = f32[] reshape(tmp_39) - tmp_41 = f32[3,32]{1,0} broadcast(tmp_40), dimensions={} - tmp_42 = u32[48]{0} parameter(11) - tmp_43 = u32[48]{0} parameter(5) - tmp_44 = u32[96]{0} concatenate(tmp_42, tmp_43), dimensions={0} - tmp_45 = u32[3,32]{1,0} reshape(tmp_44) - tmp_46 = u32[96]{0} reshape(tmp_45) - tmp_47 = u32[] constant(1) - tmp_48 = u32[3,32]{1,0} broadcast(tmp_47), dimensions={} - tmp_49 = u32[96]{0} reshape(tmp_48) - tmp_50 = u32[96]{0} shift-right-logical(tmp_46, tmp_49) - tmp_51 = u32[3,32]{1,0} reshape(tmp_50) - tmp_52 = u32[3,32]{1,0} or(tmp_51, tmp_48) - tmp_53 = f32[3,32]{1,0} bitcast-convert(tmp_52) - tmp_54 = f32[3,32]{1,0} broadcast(tmp_0), dimensions={} - tmp_55 = f32[3,32]{1,0} subtract(tmp_53, tmp_54) - tmp_56 = f32[1,1]{1,0} constant({ {1} }) - tmp_57 = f32[1,1]{1,0} broadcast(tmp_56), dimensions={0,1} - tmp_58 = f32[] reshape(tmp_57) - tmp_59 = f32[3,32]{1,0} broadcast(tmp_58), dimensions={} - tmp_60 = f32[3,32]{1,0} multiply(tmp_55, tmp_59) - tmp_61 = f32[3,32]{1,0} add(tmp_60, tmp_41) - tmp_62 = f32[3,32]{1,0} maximum(tmp_41, tmp_61) - tmp_63 = f32[3,32]{1,0} broadcast(tmp_3), dimensions={} - tmp_64 = pred[3,32]{1,0} compare(tmp_62, tmp_63), direction=LT - tmp_65 = f32[3,32]{1,0} convert(tmp_64) - tmp_66 = f32[3,49]{1,0} parameter(9) - tmp_67 = f32[49]{0} parameter(4) - tmp_68 = f32[3,49]{1,0} broadcast(tmp_67), dimensions={1} - tmp_69 = f32[3,49]{1,0} add(tmp_66, tmp_68) - tmp_70 = f32[1,49]{1,0} parameter(12) - tmp_71 = f32[1,49]{1,0} broadcast(tmp_0), dimensions={} - tmp_72 = f32[1,49]{1,0} divide(tmp_70, tmp_71) - tmp_73 = f32[1,49]{1,0} broadcast(tmp_72), dimensions={0,1} - tmp_74 = f32[49]{0} reshape(tmp_73) - tmp_75 = f32[3,49]{1,0} broadcast(tmp_74), dimensions={1} - tmp_76 = f32[3,49]{1,0} subtract(tmp_69, tmp_75) - tmp_77 = f32[1,49]{1,0} parameter(3) - tmp_78 = f32[1,49]{1,0} parameter(8) - tmp_79 = f32[1,49]{1,0} divide(tmp_78, tmp_71) - tmp_80 = f32[1,49]{1,0} multiply(tmp_72, tmp_72) - tmp_81 = f32[1,49]{1,0} subtract(tmp_79, tmp_80) - tmp_82 = f32[1,49]{1,0} add(tmp_81, tmp_71) - tmp_83 = f32[1,49]{1,0} rsqrt(tmp_82) - tmp_84 = f32[1,49]{1,0} multiply(tmp_77, tmp_83) - tmp_85 = f32[1,49]{1,0} broadcast(tmp_84), dimensions={0,1} - tmp_86 = f32[49]{0} reshape(tmp_85) - tmp_87 = f32[3,49]{1,0} broadcast(tmp_86), dimensions={1} - tmp_88 = f32[3,49]{1,0} multiply(tmp_76, tmp_87) - tmp_89 = f32[1,49]{1,0} parameter(2) - tmp_90 = f32[1,49]{1,0} broadcast(tmp_89), dimensions={0,1} - tmp_91 = f32[49]{0} reshape(tmp_90) - tmp_92 = f32[3,49]{1,0} broadcast(tmp_91), dimensions={1} - tmp_93 = f32[3,49]{1,0} add(tmp_88, tmp_92) - tmp_94 = f32[49,32]{1,0} parameter(1) - tmp_95 = f32[3,32]{1,0} dot(tmp_93, tmp_94), lhs_contracting_dims={1}, rhs_contracting_dims={0} - tmp_96 = f32[32]{0} parameter(0) - tmp_97 = f32[3,32]{1,0} broadcast(tmp_96), dimensions={1} - tmp_98 = f32[3,32]{1,0} add(tmp_95, tmp_97) - tmp_99 = f32[3,32]{1,0} multiply(tmp_65, tmp_98) - tmp_100 = f32[3,32]{1,0} divide(tmp_99, tmp_63) - tmp_101 = f32[3,32]{1,0} maximum(tmp_100, tmp_63) - ROOT tmp_102 = f32[49,32]{1,0} dot(tmp_37, tmp_101), lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kFusion); - EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kCustom); - EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - TritonFusionAnalysis::kMaxParameterPerDotScope * 2); -} - -TEST_F(GemmRewriterTritonLevel2Test, - DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) { - static_assert(TritonFusionAnalysis::kMaxParameterPerDotScope == 4, - "We have to update this test."); - // If we fuse the select, it adds 2 additional parameters at once (not 3, - // because the select instruction itself is removed from the parameters). - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - a = f32[3,49]{1,0} parameter(0) - b = f32[3,49]{1,0} parameter(1) - c = pred[3,49]{1,0} parameter(2) - d = f32[3,49]{1,0} parameter(3) - e = f32[3,49]{1,0} parameter(4) - add0 = f32[3,49]{1,0} add(a, b) - select = f32[3,49]{1,0} select(c, d, e) - add1 = f32[3,49]{1,0} add(add0, select) - f = f32[3,32]{1,0} parameter(5) - ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kFusion); - EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kCustom); - EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - TritonFusionAnalysis::kMaxParameterPerDotScope + 1); -} - -TEST_F(GemmRewriterTritonLevel2Test, - InstructionsReachableFromMultipleOperandsAreHandledCorrectly) { - static_assert(TritonFusionAnalysis::kMaxParameterPerDotScope == 4, - "We have to update this test."); - // There was a bug that some dead code was generated into some fusions in a - // specific edge case. When some instructions were reachable both through the - // LHS and the RHS operands, the BFS (Breadth-first search) through the LHS1 - // operand "marked" one operation as non-fusible because it would exceed the - // limit on fusion parameters per operand. But the BFS through the RHS operand - // went through that node and fused some more operands. So the resulting - // fusion was not connected and caused errors. This test case checks that such - // configurations generate a correct HLO now. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - a = f32[2,4]{1,0} parameter(0) - b = f32[2,4]{1,0} parameter(1) - c = f32[2,4]{1,0} parameter(2) - d = f32[2,4]{1,0} parameter(3) - e = f32[2,4]{1,0} parameter(4) - add0 = f32[2,4]{1,0} add(a, b) - add1 = f32[2,4]{1,0} add(add0, c) - add2 = f32[2,4]{1,0} add(add1, d) - add3 = f32[2,4]{1,0} add(add2, e) - ROOT r = f32[2,2]{1,0} dot(add3, add0), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - // ~VerifiedHloModule() will verify the module. -} - -TEST_F(GemmRewriterTritonLevel2Test, EachScopeIsFusedToASeparateSubgraph) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - a = f32[2,4]{1,0} parameter(0) - b = f32[2,4]{1,0} parameter(1) - add = f32[2,4]{1,0} add(a, b) - ROOT r = f32[2,2]{1,0} dot(add, add), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - - MatchHloModule(*module, R"( -CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) -CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) -CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]) -CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2) -CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3) -CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]]) -CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]]) -CHECK: ENTRY -CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) -CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) -CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} -CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]), -CHECK-SAME: kind=kCustom -CHECK-SAME: __triton_gemm -})"); -} - -TEST_F(GemmRewriterTritonLevel2Test, - OperationsAddingMoreParametersGetMultipleTries) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = f32[2,2] parameter(0) - c0 = f32[] constant(12345) - b0 = f32[2,2] broadcast(c0), dimensions={} - m0 = f32[2,2] multiply(p0, b0) - c1 = f32[] constant(34567) - b1 = f32[2,2] broadcast(c1), dimensions={} - a0 = f32[2,2] add(m0, b1) - b3 = f32[2,2,2] broadcast(a0), dimensions={0,1} - p2 = f32[2,2,2] parameter(2) - m2 = f32[2,2,2] multiply(p2, b3) - p1 = f32[2]{0} parameter(1) - c2 = f32[] constant(5678) - b2 = f32[2] broadcast(c2), dimensions={} - a1 = f32[2]{0} add(p1, b2) - b4 = f32[2,2,2] broadcast(a1), dimensions={2} - m1 = f32[2,2,2] multiply(m2, b4) - b = f32[4,2] bitcast(m1) - p3 = f16[2,2] parameter(3) - p3c = f32[2,2] convert(p3) - ROOT r = f32[4,2] dot(b, p3c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - - EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), - m::Parameter(), m::Parameter())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, FusionLevelIsLimitedOnVolta) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[2,53] parameter(0) - p0e = f32[2,53] exponential(p0) - p1 = s16[53,2] parameter(1) - p1c = f32[53,2] convert(p1) - ROOT dot = f32[2,2] dot(p0e, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::VOLTA, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Parameter(), m::Exp())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, ParameterUsedElementwiseTwiceIsFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -ENTRY e { - p0 = f32[2,35] parameter(0) - p0n = f32[2,35] negate(p0) - p0e = f32[2,35] exponential(p0) - a = f32[2,35] add(p0e, p0n) - p1 = f16[35,2] parameter(1) - p1c = f32[35,2] convert(p1) - ROOT dot = f32[2,2] dot(a, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Parameter(), m::Parameter())))); - TF_ASSERT_OK_AND_ASSIGN( - const auto analysis, - TritonFusionAnalysis::Execute(*module->entry_computation() - ->root_instruction() - ->called_computations()[0])); - EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(), - 1); - EXPECT_EQ(analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size(), - 1); -} - -TEST_F(GemmRewriterTritonLevel2Test, - ParameterUsedNonElementwiseTwiceIsFusedOnlyOnOnePath) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -HloModule t - -ENTRY e { - p0 = f32[4,4] parameter(0) - p0t = f32[4,4] transpose(p0), dimensions={1,0} - a = f32[4,4] add(p0, p0t) - p1 = f16[4,5] parameter(1) - p1c = f32[4,5] convert(p1) - ROOT dot = f32[4,5] dot(a, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Parameter(), m::Transpose(), m::Parameter())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, - ComputationParameterWithMultipleUsersIsNotTrivialToFuse) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[400,400] parameter(0) - - c0 = f16[400,400] convert(p0) - p1 = f16[400,400] parameter(1) - dot0 = f16[400,400] dot(c0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - - c1 = f16[400,400] convert(p0) - p2 = f16[400,400] parameter(2) - dot1 = f16[400,400] dot(c1, p2), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - - ROOT a = f16[400,400] add(dot0, dot1) -})")); - EXPECT_FALSE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); -} - -TEST_F(GemmRewriterTritonLevel2Test, NarrowingConversionIsAlwaysBetterToFuse) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = s8[512,512] parameter(0) - c0 = f16[512,512] convert(p0) - p1 = f16[512,512] parameter(1) - dot0 = f16[512,512] dot(c0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - - n = f16[512,512] negate(c0) - ROOT a = f16[512,512] add(dot0, n) -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Add(m::Fusion(m::Parameter(), m::Parameter()), - m::Negate())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, NestedSlicingIsAnalyzedCorrectly) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -triton_gemm_d_computation { - p0 = f32[6,24]{1,0} parameter(0) - s1 = f32[5,20]{1,0} slice(p0), slice={[1:6], [3:23]} - n1 = f32[5,20]{1,0} negate(s1) - s2 = f32[3,7]{1,0} slice(n1), slice={[1:4], [13:20]} - p1 = f32[7,37]{1,0} parameter(1) - ROOT d = f32[3,37]{1,0} dot(s2, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f32[7,37]{1,0} parameter(0) - p1 = f32[6,24]{1,0} parameter(1) - ROOT triton_gemm_d = f32[3,37]{1,0} fusion(p1, p0), kind=kCustom, - calls=triton_gemm_d_computation -})")); - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, - computation->parameter_instruction(0), 0), - ElementsAre(FieldsAre(/*stride=*/24, /*count=*/6, - /*slice_start=*/2, /*sliced_count=*/3, - /*subfragments=*/ElementsAre(3)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, - computation->parameter_instruction(0), 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/24, - /*slice_start=*/16, /*sliced_count=*/7, - /*subfragments=*/ElementsAre(7)))); -} - -TEST_F(GemmRewriterTritonLevel2Test, FusedConcatenationIsAnalyzedCorrectly) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = s8[153,1536] parameter(0) - p1 = s8[153,128] parameter(1) - p2 = s8[153,256] parameter(2) - cat = s8[153,1920] concatenate(p0, p1, p2), dimensions={1} - cvt = bf16[153,1920] convert(cat) - p3 = bf16[16,153] parameter(3) - ROOT d = bf16[16,1920] dot(p3, cvt), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), - m::Parameter(), m::Parameter())))); - const HloComputation* computation = - module->entry_computation()->root_instruction()->called_computations()[0]; - TF_ASSERT_OK_AND_ASSIGN(const auto analysis, - TritonFusionAnalysis::Execute(*computation)); - - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(0), 0), - ElementsAre(FieldsAre(/*stride=*/1536, /*count=*/153, - /*slice_start=*/0, /*sliced_count=*/153, - /*subfragments=*/ElementsAre(153)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(0), 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/1536, - /*slice_start=*/0, /*sliced_count=*/1536, - /*subfragments=*/ElementsAre(1536)))); - - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(1), 0), - ElementsAre(FieldsAre(/*stride=*/128, /*count=*/153, - /*slice_start=*/0, /*sliced_count=*/153, - /*subfragments=*/ElementsAre(153)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(1), 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/128, - /*slice_start=*/0, /*sliced_count=*/128, - /*subfragments=*/ElementsAre(128)))); - - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(2), 0), - ElementsAre(FieldsAre(/*stride=*/256, /*count=*/153, - /*slice_start=*/0, /*sliced_count=*/153, - /*subfragments=*/ElementsAre(153)))); - EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, - computation->parameter_instruction(2), 1), - ElementsAre(FieldsAre(/*stride=*/1, /*count=*/256, - /*slice_start=*/0, - /*sliced_count=*/256, - /*subfragments=*/ElementsAre(256)))); -} - -TEST_F(GemmRewriterTritonLevel2Test, IndivisibleConcatenationIsNotFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = s8[124,1024] parameter(0) - p1 = s8[124,1001] parameter(1) - cat = s8[124,2025] concatenate(p0, p1), dimensions={1} - cvt = f16[124,2025] convert(cat) - p2 = f16[123,124] parameter(2) - ROOT d = f16[2025,123] dot(cvt, p2), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, ConcatenationOfContractingIsNotFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = s8[124,1024] parameter(0) - p1 = s8[124,1024] parameter(1) - cat = s8[124,2048] concatenate(p0, p1), dimensions={1} - cvt = f16[124,2048] convert(cat) - p2 = f16[123,2048] parameter(2) - ROOT d = f16[124,123] dot(cvt, p2), - lhs_contracting_dims={1}, rhs_contracting_dims={1} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, ConcatenationOfBatchIsNotFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = s8[124,1024,50] parameter(0) - p1 = s8[124,1024,50] parameter(1) - cat = s8[124,2048,50] concatenate(p0, p1), dimensions={1} - cvt = f16[124,2048,50] convert(cat) - p2 = f16[123,2048,50] parameter(2) - ROOT d = f16[2048,124,123] dot(cvt, p2), - lhs_batch_dims={1}, rhs_batch_dims={1}, - lhs_contracting_dims={2}, rhs_contracting_dims={2} -})")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Parameter())))); -} - -TEST_F(GemmRewriterTritonLevel2Test, - TwoConcatenationsOfSameParametersAreNotFused) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( -e { - p0 = s8[128,2] parameter(0) - p1 = s8[128,2] parameter(1) - cat0 = s8[256,2] concatenate(p0, p1), dimensions={0} - cvt0 = f16[256,2] convert(cat0) - cat1 = s8[256,2] concatenate(p1, p0), dimensions={0} - n1 = s8[256,2] negate(cat1) - cvt1 = f16[256,2] convert(n1) - a = f16[256,2] add(cvt1, cvt0) - p2 = f16[2,18] parameter(2) - ROOT d = f16[18,256] dot(p2, a), - lhs_contracting_dims={0}, rhs_contracting_dims={1} -})")); - - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Concatenate(), m::Concatenate(), - m::Parameter())))); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/gemm_thunk.cc b/xla/service/gpu/gemm_thunk.cc deleted file mode 100644 index b774280a2b3ec..0000000000000 --- a/xla/service/gpu/gemm_thunk.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gemm_thunk.h" - -#include - -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -GemmThunk::GemmThunk(ThunkInfo thunk_info, GemmConfig config, - const BufferAllocation::Slice& lhs_buffer, - const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer, - bool deterministic) - : Thunk(Kind::kGemm, thunk_info), - config_(std::move(config)), - lhs_buffer_(lhs_buffer), - rhs_buffer_(rhs_buffer), - output_buffer_(output_buffer), - deterministic_(deterministic) {} - -Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(3) << "Running GEMM thunk"; - const BufferAllocations& allocs = *params.buffer_allocations; - // TODO(ezhulenev): Pass a correct workspace. For now we ignore it as Thunks - // are disabled by default, and they do not interact with CUDA graphs. - se::DeviceMemoryBase workspace(nullptr, 0); - return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), - allocs.GetDeviceAddress(rhs_buffer_), - allocs.GetDeviceAddress(output_buffer_), workspace, - deterministic_, params.stream); -} - -Status GemmThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support"); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/gemm_thunk.h b/xla/service/gpu/gemm_thunk.h deleted file mode 100644 index 96272317697ef..0000000000000 --- a/xla/service/gpu/gemm_thunk.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_GEMM_THUNK_H_ -#define XLA_SERVICE_GPU_GEMM_THUNK_H_ - -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// This is thread-compatible. -class GemmThunk : public Thunk { - public: - // Constructs a thunk that computes "output = (lhs rhs) * alpha" using - // BLAS gemm (alpha is stored in the instruction GemmBackendConfig). - GemmThunk(ThunkInfo thunk_info, GemmConfig config, - const BufferAllocation::Slice& lhs_buffer, - const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer, bool deterministic); - - GemmThunk(const GemmThunk&) = delete; - GemmThunk& operator=(const GemmThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - - private: - const GemmConfig config_; - const BufferAllocation::Slice lhs_buffer_; - const BufferAllocation::Slice rhs_buffer_; - const BufferAllocation::Slice output_buffer_; - // Whether to run deterministically. - const bool deterministic_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_GEMM_THUNK_H_ diff --git a/xla/service/gpu/gemv_rewriter.cc b/xla/service/gpu/gemv_rewriter.cc new file mode 100644 index 0000000000000..21e5f477e4b05 --- /dev/null +++ b/xla/service/gpu/gemv_rewriter.cc @@ -0,0 +1,183 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +// Construct a new layout by adding a new minor-most dimension to the input +// layout. For example, {3, 2, 1, 0} is extended to {4, 3, 2, 1, 0}. +// We expect that the input layout is normalized by LayoutNormalizer, so that +// the input layout has a descending ordering. +absl::StatusOr GetLayoutWithNewMinorMostDimension( + const Layout& layout) { + // Check that the layout is normalized. + if (!LayoutUtil::IsMonotonicWithDim0Major(layout)) { + return absl::InvalidArgumentError("Layout is not normalized."); + } + return LayoutUtil::MakeDescendingLayout(layout.minor_to_major_size() + 1); +} + +class GemvRewriterVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleDot(HloInstruction* instr) override { + HloDotInstruction* dot = Cast(instr); + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + // This pass relies on dot decomposer which ensures that all non-batch + // dimensions are merged into one. + bool lhs_has_non_contracting_dim = + lhs->shape().rank() == + dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.lhs_contracting_dimensions_size() + 1; + bool rhs_has_non_contracting_dim = + rhs->shape().rank() == + dim_numbers.rhs_batch_dimensions_size() + + dim_numbers.rhs_contracting_dimensions_size() + 1; + + // Skip matrix-matrix multiplication. + if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) { + return absl::OkStatus(); + } + + // Skip vector-vector multiplication. + if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) { + return absl::OkStatus(); + } + + if (dot->shape().is_dynamic()) { + return absl::OkStatus(); + } + + changed_ = true; + + HloComputation* computation = dot->parent(); + HloInstruction* new_lhs = lhs; + if (!lhs_has_non_contracting_dim) { + const Shape& lhs_shape = lhs->shape(); + absl::Span lhs_dimensions = lhs_shape.dimensions(); + std::vector new_lhs_dimensions(lhs_dimensions.begin(), + lhs_dimensions.end()); + new_lhs_dimensions.push_back(1); + Shape new_lhs_shape( + lhs_shape.element_type(), new_lhs_dimensions, + absl::InlinedVector(new_lhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_lhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(lhs_shape.layout())); + new_lhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_lhs_shape, lhs)); + } + + HloInstruction* new_rhs = rhs; + if (!rhs_has_non_contracting_dim) { + const Shape& rhs_shape = rhs->shape(); + absl::Span rhs_dimensions = rhs_shape.dimensions(); + std::vector new_rhs_dimensions(rhs_dimensions.begin(), + rhs_dimensions.end()); + new_rhs_dimensions.push_back(1); + Shape new_rhs_shape( + rhs_shape.element_type(), new_rhs_dimensions, + absl::InlinedVector(new_rhs_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_rhs_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(rhs_shape.layout())); + new_rhs = computation->AddInstruction( + HloInstruction::CreateBitcast(new_rhs_shape, rhs)); + } + + std::vector new_out_dimensions; + new_out_dimensions.reserve(dot->shape().dimensions().size() + 1); + for (int64_t dim_size : dot->shape().dimensions()) { + new_out_dimensions.push_back(dim_size); + } + if (!lhs_has_non_contracting_dim) { + // Insert the trivial dimension before the non-contracting dimension from + // rhs. + int non_contracting_dim_size = new_out_dimensions.back(); + new_out_dimensions[new_out_dimensions.size() - 1] = 1; + new_out_dimensions.push_back(non_contracting_dim_size); + } else { + new_out_dimensions.push_back(1); + } + + Shape new_out_shape( + dot->shape().element_type(), new_out_dimensions, + absl::InlinedVector(new_out_dimensions.size(), false), + /*tuple_shapes=*/{}); + TF_ASSIGN_OR_RETURN( + *new_out_shape.mutable_layout(), + GetLayoutWithNewMinorMostDimension(dot->shape().layout())); + + HloInstruction* new_dot = + computation->AddInstruction(HloInstruction::CreateDot( + new_out_shape, new_lhs, new_rhs, dot->dot_dimension_numbers(), + dot->precision_config())); + HloInstruction* bitcast = computation->AddInstruction( + HloInstruction::CreateBitcast(dot->shape(), new_dot)); + return computation->ReplaceInstruction(dot, bitcast); + } + + bool changed() const { return changed_; } + + private: + bool changed_ = false; +}; + +} // namespace + +absl::StatusOr GemvRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + GemvRewriterVisitor gemv_rewriter; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + TF_RETURN_IF_ERROR(computation->Accept(&gemv_rewriter)); + } + return gemv_rewriter.changed(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gemv_rewriter.h b/xla/service/gpu/gemv_rewriter.h new file mode 100644 index 0000000000000..a041138b8af5c --- /dev/null +++ b/xla/service/gpu/gemv_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_ +#define XLA_SERVICE_GPU_GEMV_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrite a matrix-vector or a vector-matrix multiplication into a +// matrix-matrix multiplication with a trivial dimension. For example, +// [m x n] @ [n] is rewritten to [m x n] @ [n x 1], and [n] @ [m x n] is +// rewritten to [n x 1] @ [m x n]. +class GemvRewriter : public HloModulePass { + public: + absl::string_view name() const override { return "gemv-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GEMV_REWRITER_H_ diff --git a/xla/service/gpu/gemv_rewriter_test.cc b/xla/service/gpu/gemv_rewriter_test.cc new file mode 100644 index 0000000000000..2a8b8103e0a94 --- /dev/null +++ b/xla/service/gpu/gemv_rewriter_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gemv_rewriter.h" + +#include +#include + +#include +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +class GemvRewriterTest : public HloTestBase {}; + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[32,7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[32,7]{1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[7]{0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[32,1]{1,0} dot(%[[P0]], %[[BITCAST]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]] = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteVectorMatrixMultiplicationToGemm) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7,32] parameter(1) + ROOT d = f32[32] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[7]{0} parameter(0) +// CHECK: %[[BITCAST:.*]] = f32[7,1]{1,0} bitcast(%[[P0]]) +// CHECK: %[[P1:.*]] = f32[7,32]{1,0} parameter(1) +// CHECK: %[[DOT:.*]] = f32[1,32]{1,0} dot(%[[BITCAST]], %[[P1]]), lhs_contracting_dims={0}, rhs_contracting_dims={0} +// CHECK: ROOT %[[ROOT:.*]].1 = f32[32]{0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, RewriteMatrixVectorMultiplicationWithBatch) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[2,5,32,7] parameter(0) + p1 = f32[2,5,7] parameter(1) + ROOT d = f32[2,5,32] dot(p0, p1), + lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={2} + })"; + + const char* expected = R"() +// CHECK: %[[P0:.*]] = f32[2,5,32,7]{3,2,1,0} parameter(0) +// CHECK: %[[P1:.*]] = f32[2,5,7]{2,1,0} parameter(1) +// CHECK: %[[BITCAST:.*]] = f32[2,5,7,1]{3,2,1,0} bitcast(%[[P1]]) +// CHECK: %[[DOT:.*]] = f32[2,5,32,1]{3,2,1,0} dot(%[[P0]], %[[BITCAST]]), +// CHECK-SAME: lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} +// CHECK: ROOT %[[ROOT:.*]] = f32[2,5,32]{2,1,0} bitcast(%[[DOT]]) +})"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), expected); +} + +TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[7] parameter(1) + ROOT d = f32[] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + })"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); +} + +TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[5,7] parameter(0) + p1 = f32[7,32] parameter(1) + ROOT d = f32[5,32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); +} + +TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[5,32,7]{2,1,0} parameter(0) + p1 = f32[5,7]{0,1} parameter(1) + ROOT d = f32[5,32]{0,1} dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + GemvRewriter rewriter; + absl::StatusOr result = this->RunHloPass(&rewriter, module.get()); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.status().message(), "Layout is not normalized."); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/gpu_all_gather_optimizer.cc b/xla/service/gpu/gpu_all_gather_optimizer.cc index ce25a019ab85f..66afc65bfa701 100644 --- a/xla/service/gpu/gpu_all_gather_optimizer.cc +++ b/xla/service/gpu/gpu_all_gather_optimizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,14 +27,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { namespace gpu { -StatusOr AllGatherOptimizer::Run( +absl::StatusOr AllGatherOptimizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/gpu_all_gather_optimizer.h b/xla/service/gpu/gpu_all_gather_optimizer.h index 717c6088f06e4..e28e42246910f 100644 --- a/xla/service/gpu/gpu_all_gather_optimizer.h +++ b/xla/service/gpu/gpu_all_gather_optimizer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ #define XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -31,7 +35,7 @@ class AllGatherOptimizer : public HloModulePass { absl::string_view name() const override { return "all-gather-optimizer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gpu_aot_compilation_test.cc b/xla/service/gpu/gpu_aot_compilation_test.cc index db5e3d9436d65..aad47e75728e2 100644 --- a/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/xla/service/gpu/gpu_aot_compilation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,24 +26,18 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/platform_util.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor_pimpl.h" -#include "tsl/platform/statusor.h" - -#if GOOGLE_CUDA -#include "xla/service/gpu/nvptx_compiler.h" -#elif TF_USE_ROCM -#include "xla/service/gpu/amdgpu_compiler.h" -#endif +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { using GpuAotCompilationTest = HloTestBase; -TEST_F(GpuAotCompilationTest, LoadExecutableFromAotCompilation) { +TEST_F(GpuAotCompilationTest, ExportAndLoadExecutable) { const absl::string_view hlo_string = R"( HloModule Test @@ -59,62 +53,15 @@ ENTRY main { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); + se::PlatformManager::PlatformWithName(name)); TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, platform->ExecutorForDevice(0)); // Compile AOT. auto module_group = std::make_unique(std::move(module)); AotCompilationOptions aot_options(compiler->PlatformId()); - // ToDo: Remove after unification of AOT compiler - if (!aot_options.debug_options().xla_gpu_enable_xla_runtime_executable()) { - return; - } - aot_options.set_executor(stream_exec); - TF_ASSERT_OK_AND_ASSIGN( - std::vector> aot_results, - compiler->CompileAheadOfTime(std::move(module_group), aot_options)); - - // Serialize-deserialize AOT compilation result. - TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, - aot_results[0]->SerializeAsString()); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr aot_result, - compiler->LoadAotCompilationResult(serialized_aot_result)); - - // Load Executable from AOT compilation result. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - aot_result->LoadExecutable(compiler, stream_exec)); -} - -TEST_F(GpuAotCompilationTest, LoadExecutableForThunkRuntime) { - const absl::string_view hlo_string = R"( -HloModule Test -ENTRY main { - a = f32[100, 200]{1,0} parameter(0) - ROOT b = f32[100, 200]{0,1} copy(a) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - DebugOptions debug_options; - debug_options.set_xla_gpu_enable_xla_runtime_executable(false); - module->mutable_config().set_debug_options(debug_options); - - auto compiler = backend().compiler(); - auto name = - absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, - platform->ExecutorForDevice(0)); - - // Compile AOT. - auto module_group = std::make_unique(std::move(module)); - AotCompilationOptions aot_options(compiler->PlatformId()); - aot_options.set_executor(stream_exec); TF_ASSERT_OK_AND_ASSIGN( std::vector> aot_results, compiler->CompileAheadOfTime(std::move(module_group), aot_options)); @@ -147,7 +94,7 @@ ENTRY main { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(name)); + se::PlatformManager::PlatformWithName(name)); TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, platform->ExecutorForDevice(0)); @@ -156,11 +103,6 @@ ENTRY main { // Stream executor is not passed as an option. Compiler::TargetConfig gpu_target_config(stream_exec); AotCompilationOptions aot_options(compiler->PlatformId()); - // ToDo: Remove after unification of AOT compiler - if (!aot_options.debug_options().xla_gpu_enable_xla_runtime_executable()) { - return; - } - aot_options.set_target_config(gpu_target_config); TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/service/gpu/gpu_asm_opts_util.cc b/xla/service/gpu/gpu_asm_opts_util.cc index 7bf5262491179..8028f36e261d8 100644 --- a/xla/service/gpu/gpu_asm_opts_util.cc +++ b/xla/service/gpu/gpu_asm_opts_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_asm_opts_util.h b/xla/service/gpu/gpu_asm_opts_util.h index 22a08f8657a99..a37ee7094e404 100644 --- a/xla/service/gpu/gpu_asm_opts_util.h +++ b/xla/service/gpu/gpu_asm_opts_util.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_async_collective_annotator.cc b/xla/service/gpu/gpu_async_collective_annotator.cc index 5df96baf19ae0..c2f6c04e5c274 100644 --- a/xla/service/gpu/gpu_async_collective_annotator.cc +++ b/xla/service/gpu/gpu_async_collective_annotator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,36 @@ limitations under the License. #include "xla/service/gpu/gpu_async_collective_annotator.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -StatusOr GpuAsyncCollectiveAnnotator::Run( +absl::StatusOr GpuAsyncCollectiveAnnotator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - if (!hlo_query::IsAsyncCollectiveStartOp(instruction->opcode())) { + if (!hlo_query::IsAsyncCollectiveStartOp(instruction)) { continue; } CollectiveBackendConfig config; config.set_is_sync(!is_collective_async_(instruction)); - TF_RETURN_IF_ERROR(instruction->set_backend_config(config)); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + instruction->backend_config()); + *gpu_config.mutable_collective_backend_config() = config; + TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config)); changed = true; } } diff --git a/xla/service/gpu/gpu_async_collective_annotator.h b/xla/service/gpu/gpu_async_collective_annotator.h index 94955ff6395b6..4000fbcbdd499 100644 --- a/xla/service/gpu/gpu_async_collective_annotator.h +++ b/xla/service/gpu/gpu_async_collective_annotator.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,12 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -33,7 +38,7 @@ class GpuAsyncCollectiveAnnotator : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/gpu_async_collective_annotator_test.cc b/xla/service/gpu/gpu_async_collective_annotator_test.cc index faf053ab5f5a6..f874a7e565ea7 100644 --- a/xla/service/gpu/gpu_async_collective_annotator_test.cc +++ b/xla/service/gpu/gpu_async_collective_annotator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,11 +19,18 @@ limitations under the License. #include #include +#include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -107,18 +114,20 @@ XLA_TEST_P(GpuAsyncCollectiveAnnotatorTest, Test) { // Assert that all async collectives are annotated with the backend config. for (const HloInstruction* hlo : module->entry_computation()->instructions()) { - if (!hlo_query::IsAsyncCollectiveStartOp(hlo->opcode())) { + if (!hlo_query::IsAsyncCollectiveStartOp(hlo)) { continue; } - StatusOr backend_config = - hlo->backend_config(); - ASSERT_TRUE(backend_config.ok()); + auto gpu_config = hlo->backend_config(); + ASSERT_TRUE(gpu_config.ok()); + + const CollectiveBackendConfig& backend_config = + gpu_config.value().collective_backend_config(); if (test_case.expected_async.contains(hlo->name())) { - EXPECT_FALSE(backend_config->is_sync()); + EXPECT_FALSE(backend_config.is_sync()); } if (test_case.expected_sync.contains(hlo->name())) { - EXPECT_TRUE(backend_config->is_sync()); + EXPECT_TRUE(backend_config.is_sync()); } } } diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 8c5f4b03674b1..d0c20aa1c8e5e 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "absl/types/variant.h" @@ -70,12 +71,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/transforms/hlo_constant_splitter.h" -#include "xla/mlir/backends/gpu/transforms/passes.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" -#include "xla/runtime/compiler.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/all_gather_broadcast_reorder.h" #include "xla/service/all_gather_combiner.h" @@ -98,6 +93,7 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/conditional_canonicalizer.h" #include "xla/service/conditional_simplifier.h" +#include "xla/service/convert_memory_placement_to_internal_annotations.h" #include "xla/service/convert_mover.h" #include "xla/service/convolution_4d_expander.h" #include "xla/service/convolution_pred_expander.h" @@ -118,21 +114,24 @@ limitations under the License. #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" #include "xla/service/gather_simplifier.h" +#include "xla/service/gpu/address_computation_fusion_rewriter.h" +#include "xla/service/gpu/algorithm_checker.h" #include "xla/service/gpu/alias_passthrough_params.h" #include "xla/service/gpu/all_reduce_blueconnect.h" #include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/collective_permute_cycle_decomposer.h" #include "xla/service/gpu/command_buffer_scheduling.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" #include "xla/service/gpu/copy_fusion.h" -#include "xla/service/gpu/custom_fusion_rewriter.h" +#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/dot_dimension_sorter.h" -#include "xla/service/gpu/fusion_merger_triton.h" +#include "xla/service/gpu/dot_operand_converter.h" #include "xla/service/gpu/fusion_pipeline.h" #include "xla/service/gpu/fusion_wrapper.h" #include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" +#include "xla/service/gpu/gemm_fusion.h" #include "xla/service/gpu/gemm_rewriter.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" #include "xla/service/gpu/gpu_all_gather_optimizer.h" #include "xla/service/gpu/gpu_async_collective_annotator.h" #include "xla/service/gpu/gpu_constants.h" @@ -142,9 +141,11 @@ limitations under the License. #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/gpu_p2p_pipeliner.h" #include "xla/service/gpu/gpu_reduce_scatter_creator.h" #include "xla/service/gpu/gpu_sanitize_constant_names.h" #include "xla/service/gpu/gpu_scatter_expander.h" +#include "xla/service/gpu/gpu_windowed_einsum_handler.h" #include "xla/service/gpu/hlo_fusion_stats.h" #include "xla/service/gpu/horizontal_loop_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -162,11 +163,13 @@ limitations under the License. #include "xla/service/gpu/reduction_layout_normalizer.h" #include "xla/service/gpu/reduction_splitter.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/runtime/executable.h" +#include "xla/service/gpu/rename_fusions.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/runtime_intrinsics.h" #include "xla/service/gpu/scatter_slice_simplifier.h" #include "xla/service/gpu/softmax_rewriter_triton.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/stream_attribute_async_wrapper.h" #include "xla/service/gpu/topk_specializer.h" #include "xla/service/gpu/topk_splitter.h" #include "xla/service/gpu/tree_reduction_rewriter.h" @@ -183,6 +186,9 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_rematerialization.h" #include "xla/service/hlo_verifier.h" +#include "xla/service/host_memory_transfer_asyncifier.h" +#include "xla/service/host_offload_legalize.h" +#include "xla/service/host_offloader.h" #include "xla/service/layout_assignment.h" #include "xla/service/layout_normalization.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -191,6 +197,7 @@ limitations under the License. #include "xla/service/loop_schedule_linearizer.h" #include "xla/service/operand_upcaster.h" #include "xla/service/optimization_barrier_expander.h" +#include "xla/service/optimize_input_output_buffer_alias.h" #include "xla/service/qr_expander.h" #include "xla/service/real_imag_expander.h" #include "xla/service/reduce_decomposer.h" @@ -227,16 +234,13 @@ limitations under the License. #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_platform_id.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#endif #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -245,11 +249,21 @@ limitations under the License. #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" +#include "tsl/platform/path.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#elif TENSORFLOW_USE_ROCM +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#endif + #ifdef PLATFORM_GOOGLE #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" #endif // PLATFORM_GOOGLE @@ -257,11 +271,103 @@ limitations under the License. namespace xla { namespace gpu { namespace { -bool ConvIsLowerable(HloInstruction* conv) { - return GpuConvRewriter::ConvIsLowerable(conv); +// A class for storing either an owned thread pool or a non-owning pointer to an +// external thread pool. +class MaybeOwningThreadPool { + public: + // Gets or creates a thread pool. + // + // See the code for the logic. + static MaybeOwningThreadPool GetOrCreate( + int parallelism, tsl::thread::ThreadPool* default_thread_pool, + int default_parallelism); + + // Not owning (nullptr). + MaybeOwningThreadPool(); + // Not owning. + explicit MaybeOwningThreadPool(tsl::thread::ThreadPool* thread_pool); + // Owning. + explicit MaybeOwningThreadPool( + std::unique_ptr thread_pool); + tsl::thread::ThreadPool* get(); + const tsl::thread::ThreadPool* get() const; + tsl::thread::ThreadPool* operator->(); + const tsl::thread::ThreadPool* operator->() const; + explicit operator bool() const; + bool operator!() const; + + private: + std::variant> + thread_pool_; +}; + +/*static*/ MaybeOwningThreadPool MaybeOwningThreadPool::GetOrCreate( + int parallelism, tsl::thread::ThreadPool* default_thread_pool, + int default_parallelism) { + CHECK_GE(parallelism, 0); + CHECK_GE(default_parallelism, 1); + // CurrentThreadId() returns -1 if the current thread does not belong to the + // thread pool. If the current thread belongs to the thread pool, we should + // not be using it, because it can potentially cause deadlocks. + CHECK(default_thread_pool == nullptr || + default_thread_pool->CurrentThreadId() == -1); + + auto create_thread_pool = [&](int num_threads) { + CHECK_GE(num_threads, 1); + return std::make_unique(tsl::Env::Default(), "", + num_threads); + }; + + switch (parallelism) { + case 0: + if (default_thread_pool == nullptr && default_parallelism > 1) { + return MaybeOwningThreadPool(create_thread_pool(default_parallelism)); + } + return MaybeOwningThreadPool(default_thread_pool); + case 1: + return MaybeOwningThreadPool(nullptr); + default: + return MaybeOwningThreadPool(create_thread_pool(parallelism)); + } } -StatusOr GetAutotuneConfig( +MaybeOwningThreadPool::MaybeOwningThreadPool() : thread_pool_(nullptr) {} + +MaybeOwningThreadPool::MaybeOwningThreadPool( + tsl::thread::ThreadPool* thread_pool) + : thread_pool_(thread_pool) {} + +MaybeOwningThreadPool::MaybeOwningThreadPool( + std::unique_ptr thread_pool) + : thread_pool_(std::move(thread_pool)) {} + +tsl::thread::ThreadPool* MaybeOwningThreadPool::get() { + if (std::holds_alternative(thread_pool_)) { + return std::get(thread_pool_); + } + return std::get>(thread_pool_).get(); +} + +const tsl::thread::ThreadPool* MaybeOwningThreadPool::get() const { + return const_cast(this)->get(); +} + +tsl::thread::ThreadPool* MaybeOwningThreadPool::operator->() { + tsl::thread::ThreadPool* thread_pool = get(); + CHECK_NE(thread_pool, nullptr); + return thread_pool; +} + +const tsl::thread::ThreadPool* MaybeOwningThreadPool::operator->() const { + return const_cast(this)->operator->(); +} + +MaybeOwningThreadPool::operator bool() const { return get() != nullptr; } + +bool MaybeOwningThreadPool::operator!() const { return get() == nullptr; } + +absl::StatusOr GetAutotuneConfig( se::StreamExecutor* stream_exec, const DebugOptions& debug_options, const GpuCompiler::CompileOptions& options, const Compiler::TargetConfig& gpu_target_config) { @@ -275,140 +381,75 @@ StatusOr GetAutotuneConfig( return deviceless_config; } -se::GpuComputeCapability GetGpuVersion(se::StreamExecutor* stream_exec) { +se::GpuComputeCapability GetGpuVersion(const se::StreamExecutor* stream_exec) { return stream_exec->GetDeviceDescription().gpu_compute_capability(); } -// TODO(b/232263665): It should be shared between GPU and CPU. -class GpuAotCompilationResult : public AotCompilationResult { +class GpuThunkAotCompilationResult : public AotCompilationResult { public: - GpuAotCompilationResult( - HloModuleProto hlo, std::string_view obj_file, - std::string_view mlir_module, std::string_view gpu_asm_text, - absl::Span gpu_binary, - absl::Span constants = {}) { - XlaRuntimeExecutableProto xla_runtime_executable; - *xla_runtime_executable.mutable_hlo_module_proto() = hlo; - xla_runtime_executable.set_obj_file(std::string(obj_file)); - xla_runtime_executable.set_mlir_module(std::string(mlir_module)); - *xla_runtime_gpu_executable_.mutable_xla_runtime_executable() = - xla_runtime_executable; - - xla_runtime_gpu_executable_.set_gpu_asm_text(std::string(gpu_asm_text)); - xla_runtime_gpu_executable_.set_gpu_binary(gpu_binary.data(), - gpu_binary.size()); - - for (const GpuExecutable::ConstantInfo& cst : constants) { - auto* cst_proto = xla_runtime_gpu_executable_.add_constants(); - cst_proto->set_symbol_name(cst.symbol_name); - cst_proto->set_allocation_index(cst.allocation_index); - cst_proto->set_content(cst.content.span().data(), - cst.content.span().size()); - } - } - - explicit GpuAotCompilationResult(XlaRuntimeGpuExecutableProto executable) - : xla_runtime_gpu_executable_(executable) {} - - StatusOr SerializeAsString() const override { - return xla_runtime_gpu_executable_.SerializeAsString(); + static absl::StatusOr> + FromModule(const HloModule* hlo_module, + const BufferAssignment* buffer_assignment, + std::string_view asm_text, absl::Span binary, + const Thunk::BinaryMap& dnn_compiled_graphs) { + CompilationResultProto proto; + TF_ASSIGN_OR_RETURN(*proto.mutable_hlo_module_with_config(), + hlo_module->ToProtoWithConfig()); + *proto.mutable_buffer_assignment() = buffer_assignment->ToProto(); + proto.set_asm_text(std::string(asm_text)); + proto.set_binary(binary.data(), binary.size()); + proto.mutable_dnn_compiled_graphs()->insert(dnn_compiled_graphs.cbegin(), + dnn_compiled_graphs.cend()); + return std::unique_ptr( + new GpuThunkAotCompilationResult(hlo_module->Clone(), + std::move(proto))); } - static StatusOr> FromString( - const std::string& serialized) { - XlaRuntimeGpuExecutableProto xla_runtime_gpu_executable; - if (!xla_runtime_gpu_executable.ParseFromString(serialized)) { - return InternalError("Failed to parse serialized JitRtExecutableProto."); + static absl::StatusOr> + FromString(const std::string& serialized) { + CompilationResultProto proto; + if (!proto.ParseFromString(serialized)) { + return Internal( + "Failed to parse serialized GpuThunkAotCompilationResult."); } - return std::make_unique( - xla_runtime_gpu_executable); - } - StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) override; - - private: - XlaRuntimeGpuExecutableProto xla_runtime_gpu_executable_; -}; - -class GpuThunkAotCompilationResult : public AotCompilationResult { - public: - GpuThunkAotCompilationResult(HloModule* hlo_module, - BufferAssignment* buffer_assignment, - std::string_view asm_text, - absl::Span binary) { - *proto_.mutable_hlo_module() = hlo_module->ToProto(); - *proto_.mutable_buffer_assignment() = buffer_assignment->ToProto(); - proto_.set_asm_text(std::string(asm_text)); - proto_.set_binary(binary.data(), binary.size()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProtoWithConfig(proto.hlo_module_with_config())); + return std::unique_ptr( + new GpuThunkAotCompilationResult(std::move(module), std::move(proto))); } - explicit GpuThunkAotCompilationResult(CompilationResultProto proto) - : proto_(proto) {} - - StatusOr SerializeAsString() const override { + absl::StatusOr SerializeAsString() const override { return proto_.SerializeAsString(); } - static StatusOr> FromString( - const std::string& serialized) { - CompilationResultProto proto; - if (!proto.ParseFromString(serialized)) { - return InternalError( - "Failed to parse serialized GpuThunkAotCompilationResult."); - } - return std::make_unique(proto); - } + absl::StatusOr> LoadExecutable( + Compiler* compiler, const se::StreamExecutor* stream_exec) const override; - StatusOr> LoadExecutable( - Compiler* compiler, se::StreamExecutor* stream_exec) override; + const HloModule* optimized_module() const override { return module_.get(); } + std::unique_ptr consume_optimized_module() override { + return std::move(module_); + } private: + GpuThunkAotCompilationResult(std::unique_ptr module, + CompilationResultProto proto) + : module_(std::move(module)), proto_(std::move(proto)) {} + + std::unique_ptr module_; CompilationResultProto proto_; }; } // end anonymous namespace -StatusOr> GpuAotCompilationResult::LoadExecutable( - Compiler* compiler, se::StreamExecutor* executor) { - XlaRuntimeExecutableProto xla_runtime_executable = - xla_runtime_gpu_executable_.xla_runtime_executable(); - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - xla_runtime_executable.hlo_module_proto(), - GetDebugOptionsFromFlags())); +absl::StatusOr> +GpuThunkAotCompilationResult::LoadExecutable( + Compiler* compiler, const se::StreamExecutor* stream_exec) const { + // Recreate HloModule+HloModuleConfig from proto. TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, - HloModule::CreateFromProto(xla_runtime_executable.hlo_module_proto(), - hlo_module_config)); - std::vector constants; - for (auto& cst : xla_runtime_gpu_executable_.constants()) { - GpuExecutable::ConstantInfo constant = { - cst.symbol_name(), - DenseDataIntermediate::Own( - std::vector{cst.content().begin(), cst.content().end()}), - cst.allocation_index()}; - constants.push_back(std::move(constant)); - } - - return GpuExecutable::LoadFromObjFile( - std::move(hlo_module), xla_runtime_executable.obj_file(), - xla_runtime_executable.mlir_module(), GetDebugOptionsFromFlags(), - xla_runtime_gpu_executable_.gpu_asm_text(), - xla_runtime_gpu_executable_.gpu_binary(), std::move(constants), - GetGpuVersion(executor), executor); -} - -StatusOr> -GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, - se::StreamExecutor* stream_exec) { - // Recreate HloModule from proto. - TF_ASSIGN_OR_RETURN(HloModuleConfig hlo_module_config, - HloModule::CreateModuleConfigFromProto( - proto_.hlo_module(), GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(proto_.hlo_module(), hlo_module_config)); + HloModule::CreateFromProtoWithConfig(proto_.hlo_module_with_config())); // Recreate BufferAssignment from proto. TF_ASSIGN_OR_RETURN( @@ -422,41 +463,26 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, // Build the executable, which should be a thunk sequence. TF_ASSIGN_OR_RETURN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(compiler->PlatformId())); + se::PlatformManager::PlatformWithId(compiler->PlatformId())); std::string platform_name = platform->Name(); se::DeviceDescription gpu_device_info = stream_exec->GetDeviceDescription(); mlir::DialectRegistry registry; - IrEmitterUnnested::GetDependentDialects(registry); auto mlir_context = std::make_unique(registry); llvm::LLVMContext llvm_context; auto llvm_module = std::make_unique("", llvm_context); auto* gpu_compiler = dynamic_cast(compiler); if (gpu_compiler == nullptr) { - return InternalError("Compiler is not a GpuCompiler."); + return Internal("Compiler is not a GpuCompiler."); } llvm_module->setTargetTriple(gpu_compiler->target_triple()); llvm_module->setDataLayout(gpu_compiler->data_layout()); IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(), platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), - /*emit_ir_from_hlo=*/true); - mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp( - mlir::Builder(mlir_context.get()).getUnknownLoc(), hlo_module->name()); - std::vector ordered_allocations; - absl::flat_hash_map - operation_map; - TF_RETURN_IF_ERROR(HloToLhloModule(*buffer_assignment, *hlo_module, - *mlir_module, &ordered_allocations, - &operation_map)); - ir_emitter_context.set_allocations(ordered_allocations); + /*emit_kernels=*/false); auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); - auto entry_function = mlir::cast( - mlir_module->lookupSymbol(hlo_module->entry_computation()->name())); - // TODO(anlunx): EmitLmhloRegion emits fusion kernels. We need to make sure - // ptx and cubin already contain emission results and disable kernel emission - // here. TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(&entry_function.getBody(), operation_map)); + ir_emitter->EmitHloComputation(hlo_module->entry_computation())); std::unique_ptr thunk_sequence = ir_emitter->ConsumeThunkSequence(); ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, @@ -471,10 +497,7 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, std::function buffer_assignment_dumper = [] { return std::string(); }; - bool enable_persistent_temp_buffers = - hlo_module->config() - .debug_options() - .xla_gpu_enable_persistent_temp_buffers(); + int64_t debug_buffer_assignment_show_max = hlo_module->config() .debug_options() @@ -485,6 +508,9 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, GpuExecutable::Create(GpuExecutable::Params{ /*asm_text=*/proto_.asm_text(), /*binary=*/binary, + /*dnn_compiled_graphs=*/ + Thunk::BinaryMap(proto_.dnn_compiled_graphs().cbegin(), + proto_.dnn_compiled_graphs().cend()), /*gpu_version=*/gpu_device_info.gpu_compute_capability(), /*executable=*/std::move(thunk_sequence), /*constants=*/std::move(constants), @@ -493,9 +519,8 @@ GpuThunkAotCompilationResult::LoadExecutable(Compiler* compiler, /*output_shape=*/std::move(output_shape), /*mlir_allocations=*/std::nullopt, /*buffer_assignment=*/std::move(buffer_assignment), - /*enable_persistent_temp_buffers=*/enable_persistent_temp_buffers, /*debug_buffer_assignment_show_max=*/debug_buffer_assignment_show_max, - /*debug_module=*/std::unique_ptr(), + /*debug_module=*/std::move(hlo_module), /*enable_debug_info_manager=*/true})); return executable; } @@ -523,39 +548,43 @@ void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, } } -void SetInstructionMetadata(HloModule* module) { - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - instruction->set_creation_pass_id(-1); - instruction->set_logical_creation_pass_id(-1); - } - } +void CheckNotScheduled(HloModule* hlo_module) { + CHECK(!hlo_module->has_schedule()) + << "\nThe current HLO module " << hlo_module->name() + << " is scheduled and optimized. \n" + << "It is not expected to run optimization passes again.\nPlease use " + << "RunAndCompareNoHloPasses() or RunAndCompareTwoModules() instead of " + << "RunAndCompare()\nif running unit tests as they set" + << " run_hlo_passes=false."; } -} // namespace - -// Runs optimization passes on the given HLO module. -Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - const CompileOptions& options, - const TargetConfig& gpu_target_config) { - const DebugOptions& debug_options = hlo_module->config().debug_options(); - MaybeOwningThreadPool thread_pool = MaybeOwningThreadPool::GetOrCreate( - /*parallelism=*/hlo_module->config() - .debug_options() - .xla_gpu_force_compilation_parallelism(), - /*default_thread_pool=*/options.thread_pool, - /*default_parallelism=*/tsl::port::MaxParallelism()); +void LogDebugOptions(HloModule* hlo_module) { + // LOG_LINES is used instead of LOG since the message can exceed the + // maximum line length, which results in the message being truncated. + XLA_VLOG_LINES( + 1, absl::StrFormat("GpuCompilationEnvironment of hlo_module %s:\n%s", + hlo_module->name(), + hlo_module->config().debug_options().DebugString())); +} - AlgebraicSimplifierOptions layout_insensitive_algsimp_opts({}, - ConvIsLowerable); +AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions( + const HloModuleConfig& hlo_module_config, + const Compiler::TargetConfig& gpu_target_config, + AlgebraicSimplifierOptions opts_from_compiler) { + AlgebraicSimplifierOptions layout_insensitive_algsimp_opts = + opts_from_compiler; + layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback( + GpuConvRewriter::ConvIsLowerable); + layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction( + hlo_module_config.debug_options() + .xla_gpu_enable_dot_strength_reduction()); // GPU only supports canonical convolutions. layout_insensitive_algsimp_opts.set_supports_non_canonical_dots(false); // "slow" minmax means we propagate nan. layout_insensitive_algsimp_opts.set_minmax_propagate_nan( - !debug_options.xla_gpu_enable_fast_min_max()); + !hlo_module_config.debug_options().xla_gpu_enable_fast_min_max()); // Always simplify reduce(transpose(x)) and reduce(reshape(x)), even when // the transpose/reshape has multiple users. This helps int8 models, which @@ -570,16 +599,20 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, } layout_insensitive_algsimp_opts .set_enable_unconditional_reduce_of_concat_replacement(false); + return layout_insensitive_algsimp_opts; +} - SetInstructionMetadata(hlo_module); - +absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) { HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner"); // Run some IR cleanup passes before running the SPMD partitioning // passes. + pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. pre_spmd_pipeline.AddPass([&](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kTopK; }); @@ -589,8 +622,12 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, pre_spmd_pipeline.AddPass( [](const HloSortInstruction*, int64_t) { return true; }); - TF_RETURN_IF_ERROR(pre_spmd_pipeline.Run(hlo_module).status()); + return pre_spmd_pipeline.Run(hlo_module).status(); +} +absl::Status RunSPMDPasses( + HloModule* hlo_module, const Compiler::TargetConfig& gpu_target_config, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts) { const int64_t num_partitions = hlo_module->config().num_partitions(); bool auto_sharding = hlo_module->config().use_auto_spmd_partitioning(); @@ -666,333 +703,347 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, hlo_module->config().allow_spmd_sharding_propagation_to_output()); spmd_pipeline.AddPass( num_partitions, hlo_module->config().replica_count(), - debug_options.xla_gpu_threshold_for_windowed_einsum_mib()); + hlo_module->config() + .debug_options() + .xla_gpu_threshold_for_windowed_einsum_mib(), + hlo_module->config() + .debug_options() + .xla_gpu_multi_streamed_windowed_einsum()); spmd_pipeline.AddPass(); - TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status()); + return spmd_pipeline.Run(hlo_module).status(); } else { HloPassPipeline sharding_removal_pipeline("sharding-removal"); // Remove redundant sharding ops when partition_count == 1. sharding_removal_pipeline.AddPass(); sharding_removal_pipeline.AddPass(); - TF_RETURN_IF_ERROR(sharding_removal_pipeline.Run(hlo_module).status()); + return sharding_removal_pipeline.Run(hlo_module).status(); } +} - { - HloPassPipeline pipeline("optimization"); - AddHloVerifier(&pipeline); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - - HloPredicate upcaster_filter = [&](const HloInstruction* instr) { - const auto* cuda_cc = std::get_if( - &gpu_target_config.device_description.gpu_compute_capability()); - if (cuda_cc != nullptr && - !cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { - return true; - } - return !gpu::IsMatrixMultiplication(*instr); - }; - pipeline.AddPass(); - pipeline.AddPass(); +absl::Status RunOptimizationPasses( + HloModule* hlo_module, const Compiler::TargetConfig& gpu_target_config, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts) { + HloPassPipeline pipeline("optimization"); + AddHloVerifier(&pipeline); + if (hlo_module->config() + .debug_options() + .xla_gpu_multi_streamed_windowed_einsum()) { + pipeline.AddPass(); + } + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + + HloPredicate upcaster_filter = [&](const HloInstruction* instr) { + const auto* cuda_cc = std::get_if( + &gpu_target_config.device_description.gpu_compute_capability()); + if (cuda_cc != nullptr && + !cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { + return true; + } + return !gpu::IsMatrixMultiplication(*instr); + }; + pipeline.AddPass(); + pipeline.AddPass(); - pipeline.AddPass(upcaster_filter); - pipeline.AddPass(upcaster_filter); + pipeline.AddPass(upcaster_filter); + pipeline.AddPass(upcaster_filter); - pipeline.AddPass( - SubByteNormalization::SET_ELEMENT_SIZE); + // Add the DotOperandConverter after any potential upcasts done as part of + // the OperandUpcaster, so that the DotOperandConverter becomes a no-op. + pipeline.AddPass(); - // Expand random number generation. - pipeline.AddPass(); - pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); + pipeline.AddPass( + SubByteNormalization::SET_ELEMENT_SIZE); - // Comparison total order expander - pipeline.AddPass(); + // Expand random number generation. + pipeline.AddPass(); + pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); - // Remove zero-sized HLO from the input so that other passes don't have to - // handle it. - pipeline.AddPass(); + // Comparison total order expander + pipeline.AddPass(std::array{std::make_pair(BF16, F32)}); - if (debug_options.xla_gpu_deterministic_ops()) { - // Scatter can be indeterministic if indices are not unique or a non - // associative combiner function is used. Eliminate these Scatter ops. - pipeline.AddPass( - ScatterExpander::kEliminateIndeterminisitcScatters); - } - // Scatters unsupported on XLA:GPU are eliminated. - pipeline.AddPass(); - - // TODO(phawkins): replace QR and Eigh decompositions with calls to - // cuSOLVER. - pipeline.AddPass(); - pipeline.AddPass(); - - pipeline.AddPass(); - - // TODO(b/64094172): make Call work on GPU instead of inlining. - pipeline.AddPass(); - - pipeline.AddPass(); - - pipeline.AddPass(); - - // Replace PRED convolutions with F16. - pipeline.AddPass(); - - // Expand the sort op to support stable sorting if required. - pipeline.AddPass(); - - pipeline.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - - DynamicPadderOptions dynamic_padder_options; - - switch (hlo_module->config().debug_options().xla_gpu_shape_checks()) { - case DebugOptions::IGNORE: - dynamic_padder_options.shape_check_mode = - DynamicDimensionInference::ShapeCheckMode::kIgnore; - break; - case DebugOptions::RUNTIME: { - dynamic_padder_options.shape_check_mode = - DynamicDimensionInference::ShapeCheckMode::kRuntime; - dynamic_padder_options.assertion_generator = [&](HloInstruction* inst) { - auto created = Cast( - inst->parent()->AddInstruction(HloInstruction::CreateCustomCall( - ShapeUtil::MakeTokenShape(), {inst}, - kXlaGpuAssertCustomCallTag, - "Buffers have different size at runtime", - API_VERSION_STATUS_RETURNING))); - created->set_custom_call_has_side_effect(true); - }; - break; - } - case DebugOptions::COMPILE_TIME: - dynamic_padder_options.shape_check_mode = - DynamicDimensionInference::ShapeCheckMode::kCompileTime; - break; - default: - LOG(FATAL) << "Unreachable"; - } + // Remove zero-sized HLO from the input so that other passes don't have to + // handle it. + pipeline.AddPass(); - pipeline.AddPass(dynamic_padder_options); - - // Build simplification pipeline. The passes in here are run to a fixed - // point. - [&, &pipeline = - pipeline.AddPass>("simplification")] { - AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true); - - // BatchNormExpander can create zero-sized ops, so zero-sized HLO - // elimination has to come after that pass. - pipeline.AddPass(); - - pipeline.AddPass(); - pipeline.AddPass(GatherExpander::kEliminateSimpleGathers); - pipeline.AddPass(); - pipeline.AddPass( - ScatterExpander::kEliminateSimpleScatters); - pipeline.AddPass(); - pipeline.AddPass(layout_insensitive_algsimp_opts); - pipeline.AddPass(); - // AlgebraicSimplifier may add contracting dimensions to a dot. - pipeline.AddPass(); - pipeline.AddPass(); - // Only merge "smallish" dots. This threshold was not set carefully, but - // so far we know that 1mb is too small. - pipeline.AddPass(/*max_size_to_merge=*/int64_t{16} << 20); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - - ReshapeMoverOptions reshape_mover_options; - reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true; - pipeline.AddPass(reshape_mover_options); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(CanFoldTransposeOperandIntoDot); - pipeline.AddPass(/*is_layout_sensitive=*/false); - pipeline.AddPass(); - }(); - - // ConvertMover and ReshapeMover fight with each other: ConvertMover wants - // to move some converts down the graph, but ReshapeMover wants to move them - // up the graph. As a compromise, let ReshapeMover run to a fixed point, - // and then run ConvertMover + algsimp to a fixed point. - [&, &pipeline = - pipeline.AddPass>("simplification-2")] { - pipeline.AddPass(); - pipeline.AddPass(layout_insensitive_algsimp_opts); - }(); - - pipeline.AddPass( - /*mark_fusion_duplications=*/false); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + if (hlo_module->config().debug_options().xla_gpu_deterministic_ops()) { + // Scatter can be indeterministic if indices are not unique or a non + // associative combiner function is used. Eliminate these Scatter ops. + pipeline.AddPass( + ScatterExpander::kEliminateIndeterminisitcScatters); } + // Scatters unsupported on XLA:GPU are eliminated. + pipeline.AddPass(); - const bool enable_all_pipelined = - debug_options.xla_gpu_enable_pipelined_collectives(); + // TODO(phawkins): replace QR and Eigh decompositions with calls to + // cuSOLVER. + pipeline.AddPass(); + pipeline.AddPass(); - // Optimize collectives generated by SPMD partitioning. Enable these passes - // otherwise as well so that all collectives can get these optimizations. - { - HloPassPipeline collectives_pipeline("collective-optimizations"); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass( - debug_options.xla_gpu_enable_reassociation_for_converted_ar()); - collectives_pipeline.AddPass(); - const DebugOptions& debug_options = hlo_module->config().debug_options(); - collectives_pipeline.AddPass( - /*enable_reduce_scatter=*/debug_options - .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); - - if (enable_all_pipelined || - debug_options.xla_gpu_enable_pipelined_all_reduce()) { - CollectivePipeliner::Config config{ - /*level_to_operate_on=*/0, - /*max_pipelining_per_loop=*/INT64_MAX, - /*last_run=*/true, - /*pipeline_use_tree=*/false, - /*process_different_sized_ops=*/true, - /*pipelining_direction=*/ - CollectivePipeliner::PipeliningDirection::kForward, - /*should_process=*/HloPredicateIsOp, - /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, - /*reuse_pipelined_op_buffer=*/ - [](const HloInstruction*) { return false; }}; - collectives_pipeline.AddPass(config); - } - if (enable_all_pipelined || - debug_options.xla_gpu_enable_pipelined_all_gather()) { - CollectivePipeliner::Config config{ - /*level_to_operate_on=*/0, - /*max_pipelining_per_loop=*/INT64_MAX, - /*last_run=*/true, - /*pipeline_use_tree=*/false, - /*process_different_sized_ops=*/true, - /*pipelining_direction=*/ - CollectivePipeliner::PipeliningDirection::kBackward, - /*should_process=*/HloPredicateIsOp, - /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, - /*reuse_pipelined_op_buffer=*/ - [](const HloInstruction*) { return false; }}; - collectives_pipeline.AddPass(config); - } - if (enable_all_pipelined || - debug_options.xla_gpu_enable_pipelined_reduce_scatter()) { - CollectivePipeliner::Config config{ - /*level_to_operate_on=*/0, - /*max_pipelining_per_loop=*/INT64_MAX, - /*last_run=*/true, - /*pipeline_use_tree=*/false, - /*process_different_sized_ops=*/true, - /*pipelining_direction=*/ - CollectivePipeliner::PipeliningDirection::kForward, - /*should_process=*/HloPredicateIsOp, - /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, - /*reuse_pipelined_op_buffer=*/ - [](const HloInstruction*) { return false; }}; - collectives_pipeline.AddPass(config); + pipeline.AddPass(); + + // TODO(b/64094172): make Call work on GPU instead of inlining. + pipeline.AddPass(); + + pipeline.AddPass(); + + pipeline.AddPass(); + + // Replace PRED convolutions with F16. + pipeline.AddPass(); + + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); + + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + + DynamicPadderOptions dynamic_padder_options; + + switch (hlo_module->config().debug_options().xla_gpu_shape_checks()) { + case DebugOptions::IGNORE: + dynamic_padder_options.shape_check_mode = + DynamicDimensionInference::ShapeCheckMode::kIgnore; + break; + case DebugOptions::RUNTIME: { + dynamic_padder_options.shape_check_mode = + DynamicDimensionInference::ShapeCheckMode::kRuntime; + dynamic_padder_options.assertion_generator = [&](HloInstruction* inst) { + auto created = Cast( + inst->parent()->AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTokenShape(), {inst}, kXlaGpuAssertCustomCallTag, + "Buffers have different size at runtime", + API_VERSION_STATUS_RETURNING))); + created->set_custom_call_has_side_effect(true); + }; + break; } + case DebugOptions::COMPILE_TIME: + dynamic_padder_options.shape_check_mode = + DynamicDimensionInference::ShapeCheckMode::kCompileTime; + break; + default: + LOG(FATAL) << "Unreachable"; + } - // Run algebraic simplifier to reshape(broadcast) into a broadcast when - // the reshape is just adding a unit dimension. This will help with the - // AllGatherBroadcastReorder pass. - collectives_pipeline.AddPass( - layout_insensitive_algsimp_opts); + pipeline.AddPass(dynamic_padder_options); - collectives_pipeline.AddPass(); + // Build simplification pipeline. The passes in here are run to a fixed + // point. + [&, &pipeline = + pipeline.AddPass>("simplification")] { + AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true); - // promote 16 bit integer all-reduce and reduce-scatter to 32-bit. - const std::pair ar_promoted_types[] = { - {U16, U32}, {S16, S32}}; - collectives_pipeline.AddPass(ar_promoted_types); - // Remove dead computations left over after ar/rs promotion. - collectives_pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); - // Run WhileLoopTripCountAnnotator after collective pipelining and before - // layout assignment and fusion.This pass does some pattern-matching on - // while bodies/conditions, and this is where the HLO is "nicest". - // - // It's important that we don't make semantic changes (e.g. unrolling) to - // any `while` loops after this point, because otherwise the trip-count - // annotations added by this pass may not be correct after the - // modifications. - collectives_pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(GatherExpander::kEliminateSimpleGathers); + pipeline.AddPass(); + pipeline.AddPass( + ScatterExpander::kEliminateSimpleScatters); + pipeline.AddPass(); + pipeline.AddPass(layout_insensitive_algsimp_opts); + pipeline.AddPass(); + // AlgebraicSimplifier may add contracting dimensions to a dot. + pipeline.AddPass(); + pipeline.AddPass(); + // Only merge "smallish" dots. This threshold was not set carefully, but + // so far we know that 1mb is too small. + pipeline.AddPass(/*max_size_to_merge=*/int64_t{16} << 20); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); - TF_RETURN_IF_ERROR(collectives_pipeline.Run(hlo_module).status()); - } + ReshapeMoverOptions reshape_mover_options; + reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true; + pipeline.AddPass(reshape_mover_options); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(CanFoldTransposeOperandIntoDot); + pipeline.AddPass(/*is_layout_sensitive=*/false); + pipeline.AddPass(); + }(); - // Run target-specific HLO optimization passes for convolution - // canonicalization. - se::GpuComputeCapability gpu_version = - gpu_target_config.device_description.gpu_compute_capability(); - se::dnn::VersionInfo dnn_version = gpu_target_config.dnn_version_info; - if (stream_exec != nullptr) { - gpu_version = GetGpuVersion(stream_exec); - se::dnn::DnnSupport* dnn = stream_exec->AsDnn(); - if (dnn == nullptr) { - return tsl::errors::FailedPrecondition( - "DNN library initialization failed." - " Look at the errors above for more details."); - } - TF_ASSIGN_OR_RETURN(dnn_version, dnn->GetVersion()); + // ConvertMover and ReshapeMover fight with each other: ConvertMover wants + // to move some converts down the graph, but ReshapeMover wants to move them + // up the graph. As a compromise, let ReshapeMover run to a fixed point, + // and then run ConvertMover + algsimp to a fixed point. + [&, &pipeline = + pipeline.AddPass>("simplification-2")] { + pipeline.AddPass(); + pipeline.AddPass(layout_insensitive_algsimp_opts); + }(); + + pipeline.AddPass( + /*mark_fusion_duplications=*/false); + return pipeline.Run(hlo_module).status(); +} + +absl::Status RunCollectiveOptimizationPasses( + HloModule* hlo_module, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts) { + // Optimize collectives generated by SPMD partitioning. Enable these passes + // otherwise as well so that all collectives can get these optimizations. + const DebugOptions& debug_options = hlo_module->config().debug_options(); + + HloPassPipeline collectives_pipeline("collective-optimizations"); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass( + debug_options.xla_gpu_enable_reassociation_for_converted_ar()); + collectives_pipeline.AddPass(); + + collectives_pipeline.AddPass( + /*enable_reduce_scatter=*/debug_options + .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); + + if (debug_options.xla_gpu_enable_pipelined_collectives() || + debug_options.xla_gpu_enable_pipelined_all_reduce()) { + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kForward, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, + /*reuse_pipelined_op_buffer=*/ + [](const HloInstruction*) { return false; }}; + collectives_pipeline.AddPass(config); + } + if (debug_options.xla_gpu_enable_pipelined_collectives() || + debug_options.xla_gpu_enable_pipelined_all_gather()) { + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kBackward, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, + /*reuse_pipelined_op_buffer=*/ + [](const HloInstruction*) { return false; }}; + collectives_pipeline.AddPass(config); + } + if (debug_options.xla_gpu_enable_pipelined_collectives() || + debug_options.xla_gpu_enable_pipelined_reduce_scatter()) { + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kForward, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/[](const HloInstruction*) { return true; }, + /*reuse_pipelined_op_buffer=*/ + [](const HloInstruction*) { return false; }}; + collectives_pipeline.AddPass(config); } - TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization( - hlo_module, gpu_version, dnn_version, options.device_allocator)); + collectives_pipeline.AddPass( + hlo_module->config() + .debug_options() + .xla_gpu_collective_permute_decomposer_threshold()); - { - // Run layout assignment in a separate pipeline from - // "post-layout-assignment" because we want everything after layout - // assignment to have a layout-sensitive invariant-checker, but - // HloPassPipeline also runs its invariant checker before any passes are - // run, meaning, the pipeline that contains layout assignment cannot contain - // a layout-sensitive verifier! - HloPassPipeline pipeline("layout assignment"); - // Layout assignment uses alias analysis, which requires the call graph to - // be flattened. - pipeline.AddPass(); - ChannelLayoutConstraints layout_constraints; - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout(), stream_exec, - &layout_constraints); - // Run SubByteNormalization because GpuLayoutAssignment may modify a - // Layout's element_size_in_bits field. - pipeline.AddPass( - SubByteNormalization::SET_ELEMENT_SIZE); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + collectives_pipeline.AddPass( + hlo_module->config() + .debug_options() + .xla_gpu_collective_permute_decomposer_threshold()); + + if (hlo_module->config() + .debug_options() + .xla_gpu_enable_pipelined_collectives() || + hlo_module->config().debug_options().xla_gpu_enable_pipelined_p2p()) { + AddP2PPipeliner(collectives_pipeline); } - // Run target-specific HLO optimization passes after layout assignment. - TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment( - hlo_module, stream_exec, options, gpu_target_config, thread_pool.get())); + // Run algebraic simplifier to reshape(broadcast) into a broadcast when + // the reshape is just adding a unit dimension. This will help with the + // AllGatherBroadcastReorder pass. + collectives_pipeline.AddPass( + layout_insensitive_algsimp_opts); + + collectives_pipeline.AddPass(); + + // promote 16 bit integer all-reduce and reduce-scatter to 32-bit. + const std::pair ar_promoted_types[] = { + {U16, U32}, {S16, S32}}; + collectives_pipeline.AddPass(ar_promoted_types); + // Remove dead computations left over after ar/rs promotion. + collectives_pipeline.AddPass(); + + // Run WhileLoopTripCountAnnotator after collective pipelining and before + // layout assignment and fusion.This pass does some pattern-matching on + // while bodies/conditions, and this is where the HLO is "nicest". + // + // It's important that we don't make semantic changes (e.g. unrolling) to + // any `while` loops after this point, because otherwise the trip-count + // annotations added by this pass may not be correct after the + // modifications. + collectives_pipeline.AddPass(); + + return collectives_pipeline.Run(hlo_module).status(); +} +absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, + se::GpuComputeCapability gpu_version, + se::dnn::VersionInfo dnn_version) { + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); + // Layout assignment uses alias analysis, which requires the call graph to + // be flattened. + pipeline.AddPass(); + ChannelLayoutConstraints layout_constraints; + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout(), gpu_version, dnn_version, + &layout_constraints); + // Run SubByteNormalization because GpuLayoutAssignment may modify a + // Layout's element_size_in_bits field. + pipeline.AddPass( + SubByteNormalization::SET_ELEMENT_SIZE); + pipeline.AddPass(true); + return pipeline.Run(hlo_module).status(); +} + +absl::Status RunFusionPasses(HloModule* hlo_module, + const Compiler::TargetConfig& gpu_target_config, + tsl::thread::ThreadPool* thread_pool, + HloCostAnalysis::ShapeSizeFunction shape_size_fn) { const se::DeviceDescription& gpu_device_info = gpu_target_config.device_description; - TF_RETURN_IF_ERROR(FusionPipeline(debug_options, ShapeSizeBytesFunction(), - thread_pool.get(), gpu_device_info) + TF_RETURN_IF_ERROR(FusionPipeline(hlo_module->config().debug_options(), + shape_size_fn, thread_pool, gpu_device_info) .Run(hlo_module) .status()); - if (debug_options.xla_gpu_enable_triton_softmax_fusion()) { - TF_RETURN_IF_ERROR(FusionMergerTriton().Run(hlo_module).status()); - } - - if (debug_options.xla_gpu_collect_cost_model_stats()) { + if (hlo_module->config().debug_options().xla_gpu_collect_cost_model_stats()) { GpuHloCostAnalysis::Options cost_analysis_options{ - ShapeSizeBytesFunction(), + shape_size_fn, /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}; @@ -1011,137 +1062,273 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, VLOG(2) << stats.ToString(); } - { - HloPassPipeline pipeline("post-fusion optimization"); - pipeline.AddPass( - debug_options.xla_gpu_all_gather_combine_threshold_bytes(), - /*combine_threshold_count=*/256, - debug_options.xla_gpu_enable_all_gather_combine_by_dim()); - pipeline.AddPass( - debug_options.xla_gpu_all_reduce_combine_threshold_bytes(), - /*combine_threshold_count=*/256); - pipeline.AddPass( - debug_options.xla_gpu_reduce_scatter_combine_threshold_bytes(), - /*combine_threshold_count=*/256, - debug_options.xla_gpu_enable_reduce_scatter_combine_by_dim()); - - if (debug_options.xla_gpu_all_reduce_contiguous()) { - pipeline.AddPass(); - } + return absl::OkStatus(); +} - TF_RETURN_IF_ERROR( - AddCustomKernelReplacementPasses(&pipeline, debug_options)); +absl::Status RunPostFusionPasses( + HloModule* hlo_module, + std::function + add_custom_kernel_replacement_passes) { + HloPassPipeline pipeline("post-fusion optimization"); + pipeline.AddPass(); + pipeline.AddPass( + hlo_module->config() + .debug_options() + .xla_gpu_all_gather_combine_threshold_bytes(), + /*combine_threshold_count=*/256, + hlo_module->config() + .debug_options() + .xla_gpu_enable_all_gather_combine_by_dim()); + pipeline.AddPass( + hlo_module->config() + .debug_options() + .xla_gpu_all_reduce_combine_threshold_bytes(), + /*combine_threshold_count=*/256); + pipeline.AddPass( + hlo_module->config() + .debug_options() + .xla_gpu_reduce_scatter_combine_threshold_bytes(), + /*combine_threshold_count=*/256, + hlo_module->config() + .debug_options() + .xla_gpu_enable_reduce_scatter_combine_by_dim()); - int32_t blueconnect_num_devices_per_host = - debug_options.xla_gpu_all_reduce_blueconnect_num_devices_per_host(); - if (blueconnect_num_devices_per_host > 0) { - pipeline.AddPass(blueconnect_num_devices_per_host); - } + if (hlo_module->config().debug_options().xla_gpu_all_reduce_contiguous()) { + pipeline.AddPass(); + } - if (debug_options.xla_gpu_enable_while_loop_double_buffering()) { - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - } + TF_RETURN_IF_ERROR(add_custom_kernel_replacement_passes( + &pipeline, hlo_module->config().debug_options())); - { - // Convert all collectives to their async form, and then annotate the ones - // that actually need to run asynchronously with a GPU specific backend - // config. - AsyncCollectiveCreator::CollectiveCreatorConfig config; - config.convert_all_reduce = HloPredicateTrue; - config.convert_collective_permute = HloPredicateTrue; - config.convert_all_gather = HloPredicateTrue; - config.convert_reduce_scatter = HloPredicateTrue; - config.convert_all_to_all = HloPredicateTrue; - pipeline.AddPass(std::move(config)); - - auto convert_to_async = [&debug_options](const HloInstruction* inst) { - const bool enable_all_async = - debug_options.xla_gpu_enable_async_collectives(); - switch (inst->opcode()) { - case HloOpcode::kAllReduceStart: - return enable_all_async || - debug_options.xla_gpu_enable_async_all_reduce(); - case HloOpcode::kAllGatherStart: + int32_t blueconnect_num_devices_per_host = + hlo_module->config() + .debug_options() + .xla_gpu_all_reduce_blueconnect_num_devices_per_host(); + if (blueconnect_num_devices_per_host > 0) { + pipeline.AddPass(blueconnect_num_devices_per_host); + } + + if (hlo_module->config() + .debug_options() + .xla_gpu_enable_while_loop_double_buffering()) { + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + } + + return pipeline.Run(hlo_module).status(); +} + +absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { + HloPassPipeline pipeline("post-fusion-collectives optimization"); + + // Convert all collectives to their async form, and then annotate the ones + // that actually need to run asynchronously with a GPU specific backend + // config. + AsyncCollectiveCreator::CollectiveCreatorConfig config; + config.convert_all_reduce = HloPredicateTrue; + config.convert_collective_broadcast = HloPredicateTrue; + config.convert_collective_permute = HloPredicateTrue; + config.convert_all_gather = HloPredicateTrue; + config.convert_reduce_scatter = HloPredicateTrue; + config.convert_all_to_all = HloPredicateTrue; + pipeline.AddPass(std::move(config)); + + auto convert_to_async = [&hlo_module](const HloInstruction* inst) { + const bool enable_all_async = + hlo_module->config().debug_options().xla_gpu_enable_async_collectives(); + switch (inst->opcode()) { + case HloOpcode::kAllReduceStart: + return enable_all_async || hlo_module->config() + .debug_options() + .xla_gpu_enable_async_all_reduce(); + case HloOpcode::kAllGatherStart: + return enable_all_async || hlo_module->config() + .debug_options() + .xla_gpu_enable_async_all_gather(); + case HloOpcode::kCollectivePermuteStart: + return enable_all_async || + hlo_module->config() + .debug_options() + .xla_gpu_enable_async_collective_permute(); + case HloOpcode::kAsyncStart: { + auto async_inst = Cast(inst); + switch (async_inst->async_wrapped_opcode()) { + case HloOpcode::kCollectiveBroadcast: return enable_all_async || - debug_options.xla_gpu_enable_async_all_gather(); - case HloOpcode::kCollectivePermuteStart: + hlo_module->config() + .debug_options() + .xla_gpu_enable_async_collective_broadcast(); + case HloOpcode::kReduceScatter: return enable_all_async || - debug_options.xla_gpu_enable_async_collective_permute(); - case HloOpcode::kAsyncStart: { - auto async_inst = Cast(inst); - switch (async_inst->async_wrapped_opcode()) { - case HloOpcode::kReduceScatter: - return enable_all_async || - debug_options.xla_gpu_enable_async_reduce_scatter(); - case HloOpcode::kAllToAll: - return enable_all_async || - debug_options.xla_gpu_enable_async_all_to_all(); - default: - return false; - } - } + hlo_module->config() + .debug_options() + .xla_gpu_enable_async_reduce_scatter(); + case HloOpcode::kAllToAll: + return enable_all_async || hlo_module->config() + .debug_options() + .xla_gpu_enable_async_all_to_all(); default: return false; } - }; - pipeline.AddPass(convert_to_async); + } + default: + return false; } - pipeline.AddPass( - debug_options.xla_gpu_collective_permute_decomposer_threshold()); - - if (enable_all_pipelined || debug_options.xla_gpu_enable_pipelined_p2p()) { - auto may_pipeline_p2p = [](const HloInstruction* instruction) { - const HloRecvDoneInstruction* recv_done = - DynCast(instruction); - if (!recv_done || recv_done->is_host_transfer()) return false; - // Check that the recv-done is used for non-trivial computation, which - // can also help avoid repeatedly pipelining a loop. - return recv_done->user_count() == 1 && recv_done->parent() != nullptr && - recv_done->users()[0] != recv_done->parent()->root_instruction(); - }; - // We curretly use one asynchronous stream to execute P2P operations, - // as such, can only support pipelining at most one P2P chain in each - // loop. - CollectivePipeliner::Config config{ - /*level_to_operate_on=*/0, - /*max_pipelining_per_loop=*/1, - /*last_run=*/true, - /*pipeline_use_tree=*/false, - /*process_different_sized_ops=*/true, - /*pipelining_direction=*/ - CollectivePipeliner::PipeliningDirection::kBackward, - /*should_process=*/may_pipeline_p2p, - /*acceptable_formatting=*/[](const HloInstruction*) { return true; }}; - pipeline.AddPass(config); + }; + pipeline.AddPass(convert_to_async); + + return pipeline.Run(hlo_module).status(); +} + +absl::Status RunPostFusionSimplificationPasses( + HloModule* hlo_module, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts) { + HloPassPipeline pipeline("post-fusion-simplification-pipeline optimization"); + AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts; + options.set_is_layout_sensitive(true); + pipeline.AddPass(options); + + // This invocation is used to populate deduplicated_name for fusions that + // are considered duplicates according to the comparator in this pass. + // Currently, the pass doesn't actually deduplicate the fusions. + pipeline.AddPass( + /*mark_fusion_duplications=*/true); + + if (hlo_module->config() + .debug_options() + .xla_gpu_multi_streamed_windowed_einsum()) { + pipeline.AddPass(); + pipeline.AddPass(); + } + + return pipeline.Run(hlo_module).status(); +} + +} // namespace + +// Runs optimization passes on the given HLO module. +absl::Status GpuCompiler::OptimizeHloModule( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + const CompileOptions& options, const TargetConfig& gpu_target_config) { + CheckNotScheduled(hlo_module); + LogDebugOptions(hlo_module); + + MaybeOwningThreadPool thread_pool = MaybeOwningThreadPool::GetOrCreate( + /*parallelism=*/hlo_module->config() + .debug_options() + .xla_gpu_force_compilation_parallelism(), + /*default_thread_pool=*/options.thread_pool, + /*default_parallelism=*/tsl::port::MaxParallelism()); + + AlgebraicSimplifierOptions layout_insensitive_algsimp_opts = + LayoutInsensitiveAlgebraicSimplifierOptions( + hlo_module->config(), gpu_target_config, + GetAlgebraicSimplifierOptions(hlo_module->config())); + + TF_RETURN_IF_ERROR(RunPreSPMDPartitionerPasses(hlo_module)); + TF_RETURN_IF_ERROR(RunSPMDPasses(hlo_module, gpu_target_config, + layout_insensitive_algsimp_opts)); + TF_RETURN_IF_ERROR(RunOptimizationPasses(hlo_module, gpu_target_config, + layout_insensitive_algsimp_opts)); + TF_RETURN_IF_ERROR(RunCollectiveOptimizationPasses( + hlo_module, layout_insensitive_algsimp_opts)); + + // Run target-specific HLO optimization passes for convolution + // canonicalization. + se::GpuComputeCapability gpu_version = + gpu_target_config.device_description.gpu_compute_capability(); + se::dnn::VersionInfo dnn_version = gpu_target_config.dnn_version_info; + if (stream_exec != nullptr) { + gpu_version = GetGpuVersion(stream_exec); + se::dnn::DnnSupport* dnn = stream_exec->AsDnn(); + if (dnn == nullptr) { + return tsl::errors::FailedPrecondition( + "DNN library initialization failed." + " Look at the errors above for more details."); } + TF_ASSIGN_OR_RETURN(dnn_version, dnn->GetVersion()); + } - AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts; - options.set_is_layout_sensitive(true); - pipeline.AddPass(options); + TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization( + hlo_module, gpu_version, dnn_version, options.device_allocator)); - // This invocation is used to populate deduplicated_name for fusions that - // are considered duplicates according to the comparator in this pass. - // Currently, the pass doesn't actually deduplicate the fusions. - pipeline.AddPass( - /*mark_fusion_duplications=*/true); + TF_RETURN_IF_ERROR( + RunLayoutAssignmentPasses(hlo_module, gpu_version, dnn_version)); + // TODO(b/328264715): Add tests to ensure that layout normalization pass is + // run before any fusion pass. + HloPassPipeline layout_normalization_pipeline("layout normalization"); + const DebugOptions& debug_options = hlo_module->config().debug_options(); + const AlgebraicSimplifierOptions simplifier_options = [&] { + AlgebraicSimplifierOptions opts = + GetAlgebraicSimplifierOptions(hlo_module->config()); + opts.set_supports_non_canonical_dots(false); + opts.set_is_layout_sensitive(true); + opts.set_enable_conv_operand_swap(false); + // "slow" minmax means we propagate nan. + opts.set_minmax_propagate_nan(!debug_options.xla_gpu_enable_fast_min_max()); + opts.set_enable_unconditional_reduce_of_concat_replacement(false); + return opts; + }(); + if (debug_options.xla_gpu_normalize_layouts()) { + layout_normalization_pipeline.AddPass(); + layout_normalization_pipeline.AddPass>(); + layout_normalization_pipeline.AddPass( + &NormalizeLayoutForGpuCustomCalls); + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + layout_normalization_pipeline.AddPass>( + simplifier_options); + } + TF_RETURN_IF_ERROR(layout_normalization_pipeline.Run(hlo_module).status()); + // Run target-specific HLO optimization passes after layout assignment. + TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment( + hlo_module, stream_exec, options, gpu_target_config, thread_pool.get())); + + // This is a "low effort, high impact" fusion that should be run first. + if (hlo_module->config() + .debug_options() + .xla_gpu_enable_address_computation_fusion()) { + HloPassPipeline pipeline("address-computation"); + TF_ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithId(PlatformId())); + pipeline.AddPass(platform->Name()); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } - return OkStatus(); + TF_RETURN_IF_ERROR(RunFusionPasses(hlo_module, gpu_target_config, + thread_pool.get(), + ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(RunPostFusionPasses( + hlo_module, + [this](HloPassPipeline* pipeline, const DebugOptions& debug_options) { + return AddCustomKernelReplacementPasses(pipeline, debug_options); + })); + TF_RETURN_IF_ERROR(RunPostFusionCollectiveOptimizationPasses(hlo_module)); + TF_RETURN_IF_ERROR(RunPostFusionSimplificationPasses( + hlo_module, layout_insensitive_algsimp_opts)); + + return absl::OkStatus(); +} // NOLINT(readability/fn_size) + +AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions( + const HloModuleConfig& config) { + AlgebraicSimplifierOptions opts; + opts.set_enable_dot_strength_reduction( + config.debug_options().xla_gpu_enable_dot_strength_reduction()); + return opts; } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { return PrepareHloModuleForIrEmittingPipeline(*hlo_module, GetCanShareBuffer()) .Run(hlo_module) .status(); } -Status GpuCompiler::OptimizeHloPostLayoutAssignment( +absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool) { @@ -1150,7 +1337,8 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( const se::GpuComputeCapability gpu_version = gpu_target_config.device_description.gpu_compute_capability(); const AlgebraicSimplifierOptions simplifier_options = [&] { - AlgebraicSimplifierOptions opts; + AlgebraicSimplifierOptions opts = + GetAlgebraicSimplifierOptions(hlo_module->config()); opts.set_supports_non_canonical_dots(false); opts.set_is_layout_sensitive(true); opts.set_enable_conv_operand_swap(false); @@ -1163,12 +1351,12 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( GetAutotuneConfig(stream_exec, debug_options, options, gpu_target_config)); // Lambdas and related constants: - const GpuFloatSupport bf16_support(BF16); - const GpuFloatSupport f8e5m2_support(F8E5M2, F16); - const GpuFloatSupport f8e4m3fn_support(F8E4M3FN, F16); + const GpuFloatSupport bf16_support(gpu_version, BF16); + const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); + const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); - const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16); - const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); + const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); + const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1179,9 +1367,9 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. - if (debug_options.xla_gpu_simplify_all_fp_conversions()) { - sub_pipeline.AddPass( - SimplifyFPConversions::Scope::kSimplifyAllConversions); + if (debug_options.xla_allow_excess_precision() && + debug_options.xla_gpu_simplify_all_fp_conversions()) { + sub_pipeline.AddPass(); } }; @@ -1202,35 +1390,40 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass([&](const HloInstruction* r) { return IsReductionFromOrToContiguousDimensions(*r); }); - pipeline.AddPass>(); - // Greedy pattern matching for custom fusions. We run it before Triton - // rewriter or a regular Gemm rewriter to be able to match compatible GEMMs - // before they matched into Triton gemm or a cuBLAS custom call. + // Greedy pattern matching for custom kernel fusions. We run it before + // Triton rewriter or a regular Gemm rewriter to be able to match compatible + // GEMMs before they matched into Triton gemm or a cuBLAS custom call. // // TODO(ezhulenev): This should be plugged into the cost model and fusion // heuristic, so we can mix and match various Gemm implementations based // on projected (measured) performance. if (debug_options.xla_gpu_enable_custom_fusions()) { - pipeline.AddPass(); + pipeline.AddPass( + &gpu_target_config.device_description); } // Rewrite GEMMs into custom calls. se::GpuComputeCapability gpu_version = gpu_target_config.device_description.gpu_compute_capability(); + pipeline.AddPass(gpu_version); const auto* cuda_cc = std::get_if(&gpu_version); + + // Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8 + // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs. + pipeline.AddPass(gpu_version, /*f8_rewrite=*/true); if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { - pipeline.AddPass(gpu_version); + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { + pipeline.AddPass(gpu_version); } - pipeline.AddPass(gpu_version); + // Rewrite non-FP8 GEMMs. + pipeline.AddPass(gpu_version, /*f8_rewrite=*/false); // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); if (debug_options.xla_gpu_normalize_layouts()) { pipeline.AddPass(&NormalizeLayoutForGpuCustomCalls); - pipeline.AddPass>(simplifier_options); } pipeline.AddPass(); @@ -1242,13 +1435,17 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // harder. if (debug_options.xla_gpu_enable_triton_softmax_fusion() && cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { pipeline.AddPass>(simplifier_options); pipeline.AddPass(gpu_version); } pipeline.AddPass(); - pipeline.AddPass>(); + // Do not split small reduction dimensions unless priority fusion is + // enabled, which handles such cases well. + bool ignore_small_reduce_dims = + !debug_options.xla_gpu_enable_priority_fusion(); + pipeline.AddPass>(ignore_small_reduce_dims); pipeline.AddPass>(gpu_version); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -1274,7 +1471,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // f32). add_float_normalization(pipeline); - TF_RETURN_IF_ERROR(AddTritonGemmAutotuningPasses( + TF_RETURN_IF_ERROR(AddGemmFusionAutotuningPasses( &pipeline, hlo_module, autotune_config, thread_pool)); // Inline back the calls which have better performance with cuBLAS. pipeline.AddPass(); @@ -1284,10 +1481,16 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost), + /* after_layout= */ true); + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost)); + TF_RETURN_IF_ERROR(AddConvAndGemmAutotuningPasses( &pipeline, hlo_module, autotune_config, thread_pool)); - // The Triton autotuner can insert new bf16 reductions that need to be + // The GEMM fusion autotuner can insert new bf16 reductions that need to be // normalized again. add_float_normalization(pipeline); @@ -1298,22 +1501,58 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>(simplifier_options); + if (debug_options.xla_allow_excess_precision() && + debug_options.xla_gpu_simplify_all_fp_conversions()) { + // This pass cleans up chains of compiler-generated converts + // (i.e. f32 -> bf16 -> f32) that have been produced by the algebraic + // simplifier by rearranging ops (i.e. by pushing broadcasts towards the + // root). + pipeline.AddPass(); + } + // Since this CSE runs after collective schedule linearizer which inserts // control dependencies, ignore these control deps when replacing instructions // with equivalent ones here. pipeline.AddPass(/*is_layout_sensitive=*/true, - /*only_fusion_computations*/ false, + /*only_fusion_computations=*/false, /*ignore_control_dependencies=*/true); + + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost)); + +#ifdef NDEBUG + // Verify the module in non-debug builds. For debug builds, the verifier + // already runs after every pass. + pipeline.AddPass( + std::make_unique( + HloVerifierOpts{} + .MakeLayoutSensitive() + .WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout) + .VerifyBroadcastDimensionsOrder() + .VerifyReshapeIsBitcast()), + "end-of-post-layout_assignment"); +#endif // NDEBUG + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - return OkStatus(); + if (DumpingEnabledForHloModule(*hlo_module)) { + TF_ASSIGN_OR_RETURN( + std::string autotune_results, + AutotunerUtil::SerializeAutotuneResultsForModule( + *hlo_module, autotune_config, /*as_textproto=*/true)); + DumpToFileInDirOrStdout(*hlo_module, "", "autotune_results.pbtxt", + autotune_results); + } + + return absl::OkStatus(); } -// Get the target config for compilation. Returns std::nullopt if no deviceless -// target config is specified: in this case, device is used. -static StatusOr> -GetDevicelessTargetConfig(const Compiler::CompileOptions& options, - const DebugOptions& debug_opts) { +// Returns the TargetConfig, either from the module debug options, or from the +// CompilationOptions, or if both of those are absent, from the attached GPU. +/*static*/ absl::StatusOr GpuCompiler::GetTargetConfig( + const Compiler::CompileOptions& options, const DebugOptions& debug_opts, + se::StreamExecutor* executor) { if (options.target_config.has_value()) { return *options.target_config; } @@ -1325,27 +1564,30 @@ GetDevicelessTargetConfig(const Compiler::CompileOptions& options, stream_executor::GpuTargetConfigProto gpu_target_config_proto; if (!tsl::protobuf::TextFormat::ParseFromString(gpu_target_config_string, &gpu_target_config_proto)) { - return FailedPrecondition("Failed to parse GpuTargetConfigProto"); + return absl::FailedPreconditionError( + "Failed to parse GpuTargetConfigProto"); } return Compiler::TargetConfig{gpu_target_config_proto}; } - return std::nullopt; + if (executor) { + return Compiler::TargetConfig{executor}; + } + return absl::InternalError( + "Either GPU has to be attached, or --xla_gpu_target_config_filename " + "has to be specified to specify the target to compile for."); } -StatusOr> GpuCompiler::RunHloPasses( +absl::StatusOr> GpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) { - TF_RETURN_IF_ERROR( - LoadAutotuneResultsFromFile(module->config().debug_options())); - - TF_ASSIGN_OR_RETURN( - std::optional forced_target_config, - GetDevicelessTargetConfig(options, module->config().debug_options())); + const DebugOptions& debug_opts = module->config().debug_options(); + TF_RETURN_IF_ERROR(LoadAutotuneResultsFromFile(debug_opts)); + bool is_deviceless = options.target_config.has_value() || + !debug_opts.xla_gpu_target_config_filename().empty(); - bool is_deviceless = forced_target_config.has_value(); - TargetConfig gpu_target_config = - is_deviceless ? *forced_target_config : TargetConfig{stream_exec}; + TF_ASSIGN_OR_RETURN(TargetConfig gpu_target_config, + GetTargetConfig(options, debug_opts, stream_exec)); const std::optional unoptimized_fingerprint = MaybeUploadUnoptimizedGpuSymbols(module.get(), gpu_target_config.ToProto()); @@ -1387,7 +1629,7 @@ StatusOr> GpuCompiler::RunHloPasses( } namespace { -Status RunPostSchedulingCopyInsertion( +absl::Status RunPostSchedulingCopyInsertion( HloModule* module, const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) { // We run a separate pass of copy elision here because the sequential ordering @@ -1427,50 +1669,19 @@ Status RunPostSchedulingCopyInsertion( TF_RETURN_IF_ERROR(saved_schedule.Update()); TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule))); - return OkStatus(); + return absl::OkStatus(); } } // namespace -StatusOr> GpuCompiler::AssignBuffers( - HloModule* hlo_module, se::StreamExecutor* stream_exec) { - const se::DeviceDescription& gpu_device_info = - stream_exec->GetDeviceDescription(); - const int64_t scheduler_mem_limit = - GetSchedulerMemoryLimit(hlo_module, gpu_device_info, pointer_size_); - TF_RETURN_IF_ERROR(ScheduleGpuModule(hlo_module, pointer_size_, - scheduler_mem_limit, gpu_device_info)); - TF_RETURN_IF_ERROR( - RunPostSchedulingCopyInsertion(hlo_module, GetCanShareBuffer())); - - auto buffer_size_bytes_function = - [this](const BufferValue& buffer_value) -> int64_t { - return GetSizeOfShape(buffer_value.shape(), pointer_size_); - }; - - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run( - hlo_module, - std::make_unique(hlo_module->schedule()), - buffer_size_bytes_function, - /*color_alignment=*/ - [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, - /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferAssigner::DefaultColorer(), - /*must_not_live_out=*/{}, GetCanShareBuffer())); - - return std::move(assignment); -} - using OutputInfoMap = absl::flat_hash_map; -static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info, +static void NullDiagnosticHandler(const llvm::DiagnosticInfo* diag_info, void* context) { std::string error_string; llvm::raw_string_ostream string_printer(error_string); llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer); - diag_info.print(diagnostic_printer); + diag_info->print(diagnostic_printer); VLOG(5) << error_string; } @@ -1500,10 +1711,13 @@ std::unique_ptr CopyToContext(const llvm::Module& module, } // namespace -StatusOr GpuCompiler::CompileSingleModule( - const HloModuleConfig& module_config, se::GpuComputeCapability gpu_version, - const HloModule* debug_module, llvm::Module* llvm_module, bool relocatable, - const CompileOptions& options, std::optional shard_number) { +absl::StatusOr +GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config, + se::GpuComputeCapability gpu_version, + const HloModule* debug_module, + llvm::Module* llvm_module, bool relocatable, + const CompileOptions& options, + std::optional shard_number) { // This may print multiple lines per HLO compilation because of the // parallelized compilation of LLVM modules. XLA_SCOPED_LOGGING_TIMER_IF( @@ -1573,15 +1787,22 @@ StatusOr GpuCompiler::CompileSingleModule( return result; } -StatusOr GpuCompiler::CompileToTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, se::StreamExecutor* stream_exec, - const CompileOptions& options, const HloModule* debug_module) { - MaybeOwningThreadPool thread_pool = MaybeOwningThreadPool::GetOrCreate( - /*parallelism=*/module_config.debug_options() - .xla_gpu_force_compilation_parallelism(), - /*default_thread_pool=*/options.thread_pool, - /*default_parallelism=*/1); +absl::StatusOr +GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config, + llvm::Module* llvm_module, + se::GpuComputeCapability gpu_version, + se::StreamExecutor* stream_exec, + const CompileOptions& options, + const HloModule* debug_module) { + MaybeOwningThreadPool thread_pool = + module_config.debug_options() + .xla_gpu_enable_llvm_module_compilation_parallelism() + ? MaybeOwningThreadPool::GetOrCreate( + /*parallelism=*/module_config.debug_options() + .xla_gpu_force_compilation_parallelism(), + /*default_thread_pool=*/options.thread_pool, + /*default_parallelism=*/1) + : MaybeOwningThreadPool(nullptr); // Test whether LinkModules is supported. TF_ASSIGN_OR_RETURN(bool can_use_link_modules, @@ -1645,7 +1866,7 @@ StatusOr GpuCompiler::CompileToTargetBinary( }, /*PreserveLocals=*/true); - std::vector> compile_results( + std::vector> compile_results( llvm_modules.size()); tsl::BlockingCounter counter(llvm_modules.size()); for (int i = 0; i < llvm_modules.size(); i++) { @@ -1681,36 +1902,29 @@ StatusOr GpuCompiler::CompileToTargetBinary( this->LinkModules(stream_exec, std::move(submodule_compile_results), module_config.debug_options()); if (!maybe_backend_result.ok()) { - LOG(ERROR) << "The CUDA linking API did not work. Please use " - "XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to " - "bypass it, but expect to get longer compilation time due to " - "the lack of multi-threading. Original error: " + LOG(ERROR) << "The CUDA linking API did not work. Please use XLA_FLAGS=" + "--xla_gpu_enable_llvm_module_compilation_parallelism=false " + "to bypass it, but expect to get longer compilation time due " + "to the lack of multi-threading. Original error: " << maybe_backend_result.status(); return maybe_backend_result.status(); } return BackendCompileResult{ptx_snippets, std::move(*maybe_backend_result)}; } -StatusOr +absl::StatusOr GpuCompiler::CompileToBackendResult( HloModule* module, llvm::LLVMContext* llvm_context, se::StreamExecutor* executor, const CompileOptions& options, const se::DeviceDescription& gpu_device_info) { - const int64_t scheduler_mem_limit = - GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size_); - TF_RETURN_IF_ERROR(ScheduleGpuModule(module, pointer_size_, - scheduler_mem_limit, gpu_device_info)); - - if (!IsXlaRuntimeExecutableEnabled(module->config())) { - HloPassPipeline pipeline("command-buffer-scheduling"); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); - } - - TF_RETURN_IF_ERROR(RunPostSchedulingPipelines(module, scheduler_mem_limit)); + TF_ASSIGN_OR_RETURN( + ScheduleMetadata schedule_metadata, + ScheduleGpuModule(module, pointer_size_, gpu_device_info)); + TF_RETURN_IF_ERROR(RunPostSchedulingPipelines( + module, schedule_metadata.scheduler_mem_limit, gpu_device_info)); TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(PlatformId())); + se::PlatformManager::PlatformWithId(PlatformId())); // Compile the module TF_ASSIGN_OR_RETURN( @@ -1732,29 +1946,34 @@ GpuCompiler::CompileToBackendResult( module->config(), compile_module_results.llvm_module.get(), gpu_device_info.gpu_compute_capability(), executor, options, module)); RecordXlaDeviceBinarySize(backend_result.binary.size()); - if (DumpingEnabledForHloModule(*module) && - std::holds_alternative( - compile_module_results.executable)) { - const ThunkSequence& thunk_sequence = - *std::get( - compile_module_results.executable); + if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - thunk_sequence.ToString()); + compile_module_results.executable->ToString()); } return CompileResultWithMetadata{std::move(backend_result), std::move(compile_module_results)}; } -StatusOr> GpuCompiler::RunBackend( +absl::StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) { - TF_ASSIGN_OR_RETURN( - std::optional forced_target_config, - GetDevicelessTargetConfig(options, module->config().debug_options())); - bool is_deviceless = forced_target_config.has_value(); - TargetConfig gpu_target_config = - is_deviceless ? *forced_target_config : TargetConfig{stream_exec}; + Thunk::BinaryMap dnn_compiled_graphs; + if (stream_exec) { + TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec, + &dnn_compiled_graphs)); + } + + const DebugOptions& debug_opts = module->config().debug_options(); + TF_ASSIGN_OR_RETURN(TargetConfig gpu_target_config, + GetTargetConfig(options, debug_opts, stream_exec)); + + if (DumpingEnabledForHloModule(*module)) { + std::string textproto; + tsl::protobuf::TextFormat::PrintToString(gpu_target_config.ToProto(), + &textproto); + DumpToFileInDirOrStdout(*module, "", "gpu_target_config.pbtxt", textproto); + } if (!options.is_autotuning_compilation) { VLOG(1) << "Starting to compile HLO module " << module->name(); @@ -1768,12 +1987,6 @@ StatusOr> GpuCompiler::RunBackend( auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg); if (options.is_autotuning_compilation) { - if (module->config() - .debug_options() - .xla_gpu_enable_persistent_temp_buffers()) { - LOG(WARNING) << "Doing autotuning compilations with " - "xla_gpu_enable_persistent_temp_buffers wastes memory!"; - } if (module->config().debug_options().xla_embed_ir_in_executable()) { LOG(WARNING) << "Doing autotuning compilations with " "xla_embed_ir_in_executable wastes memory!"; @@ -1805,11 +2018,9 @@ StatusOr> GpuCompiler::RunBackend( CompileToBackendResult(module.get(), &llvm_context, stream_exec, options, gpu_device_info)); - if (auto thunk_sequence = std::get_if( - &res.compile_module_results.executable); - DumpingEnabledForHloModule(*module) && thunk_sequence) { + if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_sequence.txt", - (*thunk_sequence)->ToString()); + res.compile_module_results.executable->ToString()); } // The module is being moved into the GpuExecutable below and we need to @@ -1818,8 +2029,6 @@ StatusOr> GpuCompiler::RunBackend( module->config().debug_options().xla_embed_ir_in_executable(); int64_t debug_buffer_assignment_show_max = module->config().debug_options().xla_debug_buffer_assignment_show_max(); - bool enable_persistent_temp_buffers = - module->config().debug_options().xla_gpu_enable_persistent_temp_buffers(); TF_ASSIGN_OR_RETURN( auto gpu_executable, @@ -1829,6 +2038,8 @@ StatusOr> GpuCompiler::RunBackend( ? std::string() : std::move(res.backend_result.asm_text), /*binary=*/std::move(res.backend_result.binary), + /*dnn_compiled_graphs=*/ + std::move(dnn_compiled_graphs), /*gpu_version=*/gpu_device_info.gpu_compute_capability(), /*executable=*/std::move(res.compile_module_results.executable), /*constants=*/std::move(res.compile_module_results.constants), @@ -1841,7 +2052,6 @@ StatusOr> GpuCompiler::RunBackend( : std::move(res.compile_module_results.allocations)), /*buffer_assignment=*/ std::move(res.compile_module_results.buffer_assignment), - /*enable_persistent_temp_buffers=*/enable_persistent_temp_buffers, /*debug_buffer_assignment_show_max=*/debug_buffer_assignment_show_max, /*debug_module=*/options.is_autotuning_compilation ? std::unique_ptr() @@ -1871,7 +2081,7 @@ StatusOr> GpuCompiler::RunBackend( return static_cast>(std::move(gpu_executable)); } -StatusOr>> +absl::StatusOr>> GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) { #if GOOGLE_CUDA @@ -1882,6 +2092,26 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, std::vector> modules = module_group->ConsumeModules(); + + std::vector> optimized_modules; + optimized_modules.reserve(modules.size()); + + for (std::unique_ptr& module : modules) { + if (!module->has_schedule()) { + CompileOptions compile_options; + compile_options.device_allocator = options.device_allocator(); + compile_options.target_config = options.target_config(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimized_module, + RunHloPasses(std::move(module), options.executor(), compile_options)); + optimized_modules.push_back(std::move(optimized_module)); + } else { + optimized_modules.push_back(std::move(module)); + } + } + + modules = std::move(optimized_modules); + std::vector> results; const std::optional& target_config = @@ -1897,70 +2127,15 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, CompileToBackendResult(module.get(), &llvm_context, options.executor(), {options.device_allocator()}, gpu_device_info)); - if (!IsXlaRuntimeExecutableEnabled(module->config())) { - // Create GpuThunkAotCompilationResult if thunk runtime is enabled. - results.emplace_back(std::make_unique( - module.get(), res.compile_module_results.buffer_assignment.get(), - res.backend_result.asm_text, res.backend_result.binary)); - continue; - } - - const auto* program = std::get_if( - &res.compile_module_results.executable); - if (!program) { - return InternalError("Gpu runtime program was not provided"); - } - - // TODO(ezhulenev): Unify AOT compilation with GpuRuntimeExecutable::Create - // (see `gpu/runtime/executable.h`). - - // Options for the default XLA runtime compilation pipeline. - runtime::CompilationPipelineOptions copts; - - // Populate mapping from XLA (SE) enums/structs type id to symbol names. - copts.populate_type_id_names = RegisterXlaGpuTypeIdNames; - - // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls. - copts.populate_attr_encodings = RegisterXlaGpuAttrEncoding; - - // Options for constructing XLA runtime JitExecutable. - runtime::JitExecutable::Options opts; - opts.specialization = runtime::JitExecutable::Specialization::kDisabled; - opts.compiler.register_dialects = - runtime::RegisterDefaultXlaGpuRuntimeDialects; - - // Register XLA Gpu runtime custom calls with the linker. - opts.compiler.symbols_binding = runtime::ToSymbolsBinding( - RegisterXlaGpuRuntimeCustomCalls, RegisterXlaGpuTypeIdNames); - - opts.compiler.create_compilation_pipeline = - [copts](xla::runtime::PassManager& passes) { - runtime::CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); - }; - - // Instantiate new JitExecutable from the MLIR source. - auto jit_executable = runtime::JitExecutable::Instantiate( - (*program)->module, (*program)->entry_point, opts); - if (!jit_executable.ok()) - return InternalError("Failed to compile XLA program: %s", - jit_executable.status().message()); - - // For static shapes we can always serialize only the default executable. - runtime::Executable& executable = jit_executable->DefaultExecutable().get(); - - // Check if XLA runtime executable saved the compilation result. - std::unique_ptr obj_file = executable.obj_file(); - if (!obj_file) - return InternalError("XLA runtime executable didn't save the obj file"); - - std::string data(obj_file->getBuffer().data(), - obj_file->getBuffer().size()); - - results.emplace_back(std::make_unique( - module->ToProto(), data, (*program)->module, - res.backend_result.asm_text, res.backend_result.binary, - res.compile_module_results.constants)); + // Create GpuThunkAotCompilationResult if thunk runtime is enabled. + TF_ASSIGN_OR_RETURN( + results.emplace_back(), + GpuThunkAotCompilationResult::FromModule( + module.get(), res.compile_module_results.buffer_assignment.get(), + res.backend_result.asm_text, res.backend_result.binary, + res.backend_result.dnn_compiled_graphs)); } + return std::move(results); } @@ -1971,23 +2146,20 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const { }; } -StatusOr> GpuCompiler::Export( +absl::StatusOr> GpuCompiler::Export( Executable* executable) const { auto* gpu_executable = tensorflow::down_cast(executable); if (!gpu_executable) return Internal("GpuExecutable is null"); - HloModuleProto module_proto = gpu_executable->module().ToProto(); - auto obj_file = gpu_executable->GetObjFile().value_or(""); - auto mlir_module = gpu_executable->GetMlirModule().value_or(""); - auto text = gpu_executable->text(); - auto binary = gpu_executable->binary(); - - return std::make_unique( - module_proto, obj_file, mlir_module, text, binary, - gpu_executable->constants()); + + return GpuThunkAotCompilationResult::FromModule( + &gpu_executable->module(), gpu_executable->buffer_assignment(), + gpu_executable->text(), gpu_executable->binary(), + gpu_executable->dnn_compiled_graphs()); } -Status GpuCompiler::RunPostSchedulingPipelines( - HloModule* module, int64_t scheduler_mem_limit) const { +absl::Status GpuCompiler::RunPostSchedulingPipelines( + HloModule* module, int64_t scheduler_mem_limit, + const se::DeviceDescription& gpu_device_info) const { TF_RETURN_IF_ERROR( RunPostSchedulingCopyInsertion(module, GetCanShareBuffer())); { @@ -2017,6 +2189,7 @@ Status GpuCompiler::RunPostSchedulingPipelines( /*host_memory_offload_config=*/std::nullopt); HloRematerialization::RematerializationSizes sizes; pipeline.AddPass(options, sizes); + pipeline.AddPass(); pipeline.AddPass(); TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module)); @@ -2034,26 +2207,44 @@ Status GpuCompiler::RunPostSchedulingPipelines( // insert additional copies. TF_RETURN_IF_ERROR(pipeline.Run(module).status()); } - return OkStatus(); + + // After we have a scheduled module and all operations wrapped into fusions we + // can decide how to wrap them into command buffers. + if (!IsXlaRuntimeExecutableEnabled(module->config())) { + HloPassPipeline pipeline("command-buffer-scheduling"); + auto driver_version = se::gpu::GpuDriver::GetDriverVersion(); +#if GOOGLE_CUDA + constexpr int toolkit_version = CUDA_VERSION; +#else + constexpr int toolkit_version = TF_ROCM_VERSION; +#endif + pipeline.AddPass( + gpu_device_info, toolkit_version, + driver_version.value_or(toolkit_version)); + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(module).status()); + } + + return absl::OkStatus(); } -Status GpuCompiler::LoadAutotuneResultsFromFile( +absl::Status GpuCompiler::LoadAutotuneResultsFromFile( const DebugOptions& debug_options) { // We are doing this before the timer is started. if (absl::string_view file_path = debug_options.xla_gpu_load_autotune_results_from(); !file_path.empty()) { static absl::once_flag once; - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); absl::call_once(once, [&file_path, &status] { status = AutotunerUtil::LoadAutotuneResultsFromFile(file_path); }); TF_RETURN_IF_ERROR(status); } - return OkStatus(); + return absl::OkStatus(); } -Status GpuCompiler::SerializeAutotuneResultsToFile( +absl::Status GpuCompiler::SerializeAutotuneResultsToFile( const DebugOptions& debug_options) { // We are doing this after the timer is finished. if (absl::string_view file_path = @@ -2064,22 +2255,18 @@ Status GpuCompiler::SerializeAutotuneResultsToFile( TF_RETURN_IF_ERROR( AutotunerUtil::SerializeAutotuneResultsToFile(file_path)); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr> +absl::StatusOr> GpuCompiler::LoadAotCompilationResult( const std::string& serialized_aot_result) { return LoadAotCompilationResultStatic(serialized_aot_result); } -StatusOr> +absl::StatusOr> GpuCompiler::LoadAotCompilationResultStatic( const std::string& serialized_aot_result) { - // TODO(anlunx): Remove the code that loads a GpuAotCompilationResult when we - // convert to thunk runtime. - auto result = GpuAotCompilationResult::FromString(serialized_aot_result); - if (result.ok()) return result; return GpuThunkAotCompilationResult::FromString(serialized_aot_result); } diff --git a/xla/service/gpu/gpu_compiler.h b/xla/service/gpu/gpu_compiler.h index 0232b6b27ef8e..9d30a471deda8 100644 --- a/xla/service/gpu/gpu_compiler.h +++ b/xla/service/gpu/gpu_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,16 +18,16 @@ limitations under the License. #include #include +#include #include -#include -#include #include -#include "absl/types/span.h" +#include "absl/status/status.h" #include "llvm/IR/Module.h" #include "xla/autotune_results.pb.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/service/algebraic_simplifier.h" #include "xla/service/buffer_assignment.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" @@ -35,7 +35,6 @@ limitations under the License. #include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/executable.pb.h" -#include "xla/service/gpu/gpu_executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -43,10 +42,11 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/llvm_compiler.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -55,7 +55,6 @@ limitations under the License. namespace xla { namespace gpu { - // The GPU compiler generates efficient GPU executables. class GpuCompiler : public LLVMCompiler { public: @@ -68,18 +67,15 @@ class GpuCompiler : public LLVMCompiler { // from the attached device OR from the `options` struct (in which case the // attached device is ignored during the compilation). // If you call this directly, follow it with RunBackend rather than Compile. - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr> AssignBuffers( - HloModule* hlo_module, se::StreamExecutor* stream_exec) override; - - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) override; - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, AotCompilationOptions const& options) override; @@ -89,33 +85,49 @@ class GpuCompiler : public LLVMCompiler { // Returns a (deserialized) AotCompilationResult from a serialized // AotCompilationResult. - StatusOr> LoadAotCompilationResult( - const std::string& serialized_aot_result) override; + absl::StatusOr> + LoadAotCompilationResult(const std::string& serialized_aot_result) override; // Stateless version of the same function. - static StatusOr> + static absl::StatusOr> LoadAotCompilationResultStatic(const std::string& serialized_aot_result); - StatusOr> Export( + absl::StatusOr> Export( Executable* executable) const override; - Status RunPostSchedulingPipelines(HloModule* module, - int64_t scheduler_mem_limit) const; + absl::Status RunPostSchedulingPipelines( + HloModule* module, int64_t scheduler_mem_limit, + const se::DeviceDescription& gpu_device_info) const; std::string target_triple() const { return target_triple_; } std::string data_layout() const { return data_layout_; } + const char* GetDataLayout() const { return data_layout_; } + + const char* GetTargetTriple() const { return target_triple_; } + + int64_t GetPointerSize() const { return pointer_size_; } + + static absl::StatusOr GetTargetConfig( + const Compiler::CompileOptions& options, const DebugOptions& debug_opts, + se::StreamExecutor* executor); + + virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { + return &FusionCanShareBufferHint; + } + protected: struct BackendCompileResult { std::string asm_text; std::vector binary; + Thunk::BinaryMap dnn_compiled_graphs; }; // During compilation with device, stream_exec != null and autotune_results // == null. During deviceless AOT compilation, stream_exec == null and // autotune_results != null. // thread_pool is used to speed up compilation during autotuning. - virtual Status OptimizeHloPostLayoutAssignment( + virtual absl::Status OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool); @@ -132,25 +144,35 @@ class GpuCompiler : public LLVMCompiler { } // Add autotuning passes for convolution and gemm (except triton). - virtual Status AddConvAndGemmAutotuningPasses( + virtual absl::Status AddConvAndGemmAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) { - return OkStatus(); + return absl::OkStatus(); } - // Add autotuning passes for triton gemm. - virtual Status AddTritonGemmAutotuningPasses( + // Add autotuning passes for GEMM fusions. + virtual absl::Status AddGemmFusionAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) { - return OkStatus(); + return absl::OkStatus(); } // Add passes that convert HLO operations to custom kernels. - virtual Status AddCustomKernelReplacementPasses( + virtual absl::Status AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { - return OkStatus(); + return absl::OkStatus(); + } + + // Runs CUDNN fusion compiler pass. + virtual absl::Status RunCudnnFusionCompilerPass( + HloModule* module, se::StreamExecutor* stream_exec, + Thunk::BinaryMap* dnn_compiled_graphs) { + return absl::OkStatus(); } + AlgebraicSimplifierOptions GetAlgebraicSimplifierOptions( + const HloModuleConfig& config); + private: struct CompileResultWithMetadata { BackendCompileResult backend_result; @@ -158,56 +180,54 @@ class GpuCompiler : public LLVMCompiler { }; // Schedule and compile the module. - StatusOr CompileToBackendResult( + absl::StatusOr CompileToBackendResult( HloModule* module, llvm::LLVMContext* llvm_context, se::StreamExecutor* executor, const CompileOptions& options, const se::DeviceDescription& gpu_device_info); - StatusOr CompileToTargetBinary( + absl::StatusOr CompileToTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, se::StreamExecutor* stream_exec, const CompileOptions& options, const HloModule* debug_module); - StatusOr CompileSingleModule( + absl::StatusOr CompileSingleModule( const HloModuleConfig& module_config, se::GpuComputeCapability gpu_version, const HloModule* debug_module, llvm::Module* llvm_module, bool relocatable, const CompileOptions& options, std::optional shard_number); - Status LoadAutotuneResultsFromFile(const DebugOptions& debug_options); - Status SerializeAutotuneResultsToFile(const DebugOptions& debug_options); + absl::Status LoadAutotuneResultsFromFile(const DebugOptions& debug_options); + absl::Status SerializeAutotuneResultsToFile( + const DebugOptions& debug_options); // During compilation with device, stream_exec != null and autotune_results // == null. During deviceless AOT compilation, stream_exec == null and // autotune_results != null. - Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - const CompileOptions& options, - const TargetConfig& gpu_target_config); + absl::Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + const CompileOptions& options, + const TargetConfig& gpu_target_config); - virtual Status OptimizeHloConvolutionCanonicalization( + virtual absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) = 0; - virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { - return &FusionCanShareBufferHint; - } - // TODO(timshen): Replace `debug_module` with some portable debug information // that accommodates both HLO and MLIR. - virtual StatusOr CompileTargetBinary( + virtual absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) = 0; - Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); + absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); - virtual StatusOr CanUseLinkModules(const HloModuleConfig& config) { + virtual absl::StatusOr CanUseLinkModules( + const HloModuleConfig& config) { return false; } - virtual StatusOr> LinkModules( + virtual absl::StatusOr> LinkModules( se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) { diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 5424f31a29dd2..4e886b46e5408 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1,4 +1,5 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +#include "xla/service/gpu/gpu_compiler.h" +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include #include #include -#include "absl/base/log_severity.h" -#include "absl/log/scoped_mock_log.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/autotune_results.pb.h" -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/metrics.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/path.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace xla { namespace gpu { @@ -43,9 +60,13 @@ using ::testing::TempDir; class GpuCompilerTest : public HloTestBase { public: - StatusOr> AssignBuffers(HloModule* module) { + absl::Status Schedule(HloModule* module) { auto compiler = backend().compiler(); - return compiler->AssignBuffers(module, backend().default_stream_executor()); + const se::DeviceDescription& gpu_device_info = + backend().default_stream_executor()->GetDeviceDescription(); + TF_RETURN_IF_ERROR(ScheduleGpuModule(module, 4, gpu_device_info).status()); + return tensorflow::down_cast(compiler) + ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } }; @@ -298,7 +319,7 @@ ENTRY main { EXPECT_EQ(while_op->while_body()->root_instruction()->operand(1)->opcode(), HloOpcode::kCopy); - TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, AssignBuffers(module.get())); + TF_ASSERT_OK(Schedule(module.get())); EXPECT_EQ(CountCopies(*module), 4); module->entry_computation()->root_instruction(); while_op = root->operand(0)->operand(0); @@ -307,6 +328,73 @@ ENTRY main { HloOpcode::kAllGatherDone); } +TEST_F(GpuCompilerTest, + GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) { + const absl::string_view hlo_string = R"( +HloModule test + +ENTRY main { + param_0 = bf16[3,32,1024,4,1024]{4,3,2,1,0} parameter(0) + param_1 = bf16[4,3,32,1024]{3,2,1,0} parameter(1) + param_2 = s32[] parameter(2) + constant_0 = s32[] constant(0) + dynamic-slice_0 = bf16[1,3,32,1024]{3,2,1,0} dynamic-slice(param_1, param_2, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,3,32,1024} + reshape_0 = bf16[3,32,1024]{2,1,0} reshape(dynamic-slice_0) + broadcast_0 = bf16[3,32,1024,4,1024]{2,1,4,3,0} broadcast(reshape_0), dimensions={0,1,2} + add_0 = bf16[3,32,1024,4,1024]{4,3,2,1,0} add(param_0, broadcast_0) + transpose_0 = bf16[3,4,1024,32,1024]{2,1,4,3,0} transpose(add_0), dimensions={0,3,4,1,2} + slice_0 = bf16[1,4,1024,32,1024]{4,3,2,1,0} slice(transpose_0), slice={[0:1], [0:4], [0:1024], [0:32], [0:1024]} + reshape_1 = bf16[4,1024,32,1024]{3,2,1,0} reshape(slice_0) + copy_0 = bf16[4,1024,32,1024]{3,2,1,0} copy(reshape_1) + constant_1 = bf16[] constant(0.08838) + broadcast_1 = bf16[4,1024,32,1024]{3,2,1,0} broadcast(constant_1), dimensions={} + multiply_0 = bf16[4,1024,32,1024]{3,2,1,0} multiply(copy_0, broadcast_1) + slice_1 = bf16[1,4,1024,32,1024]{4,3,2,1,0} slice(transpose_0), slice={[1:2], [0:4], [0:1024], [0:32], [0:1024]} + reshape_2 = bf16[4,1024,32,1024]{3,2,1,0} reshape(slice_1) + copy_1 = bf16[4,1024,32,1024]{3,2,1,0} copy(reshape_2) + ROOT dot_0 = bf16[4,32,1024,1024]{3,2,1,0} dot(multiply_0, copy_1), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3} +} +)"; + + HloModuleConfig config; + DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); + triton_enabled_debug_options.set_xla_gpu_enable_address_computation_fusion( + false); + config.set_debug_options(triton_enabled_debug_options); + config.set_replica_count(1); + config.set_num_partitions(1); + + // Load autotuning DB. We shouldn't depend on actual execution times in a unit + // test. + std::string path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"); + TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_enabled_module, + GetOptimizedModule(std::move(module))); + AutotunerUtil::ClearAutotuneResults(); + DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); + triton_disabled_debug_options.set_xla_gpu_enable_address_computation_fusion( + false); + triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); + config.set_debug_options(triton_disabled_debug_options); + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_disabled_module, + GetOptimizedModule(std::move(module))); + // Make sure autotuner falls back to cuBLAS when enabling triton gemm + const HloInstruction* root = + triton_enabled_module->entry_computation()->root_instruction(); + const HloInstruction* custom_op = root->operand(0)->operand(0); + EXPECT_TRUE(custom_op->IsCustomCall("__cublas$gemm")); + // Make sure that the module has the same number of computations with/without + // enabling triton gemm + EXPECT_EQ(triton_enabled_module->computation_count(), + triton_disabled_module->computation_count()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto new file mode 100644 index 0000000000000..0f81bf9ae8690 --- /dev/null +++ b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto @@ -0,0 +1,25 @@ +version: 3 +results { + device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$" + hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}" + result { + gemm { + algorithm: -1 + } + run_time { + nanos: 657376 + } + } +} +results { + device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$" + hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = bf16[] constant({...})\n tmp_2 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_1), dimensions={}\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} multiply(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0, bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_2)\n tmp_4 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_4), dimensions={0,1,3,2}\n tmp_6 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_5)\n tmp_7 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_8 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_7)\n tmp_9 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_6, bf16[128,1024,1024]{2,1,0} tmp_8), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_10 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_9)\n}" + result { + gemm { + algorithm: -1 + } + run_time { + nanos: 854688 + } + } +} \ No newline at end of file diff --git a/xla/service/gpu/gpu_constants.h b/xla/service/gpu/gpu_constants.h index 68eda89b715fa..e9bb204eec513 100644 --- a/xla/service/gpu/gpu_constants.h +++ b/xla/service/gpu/gpu_constants.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_conv_padding_legalization.cc b/xla/service/gpu/gpu_conv_padding_legalization.cc index b64051fc845c9..bbe037280494b 100644 --- a/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/xla/service/gpu/gpu_conv_padding_legalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,31 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_padding_legalization.h" -#include - +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -402,7 +416,7 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( return true; } -StatusOr GpuConvPaddingLegalization::RunOnComputation( +absl::StatusOr GpuConvPaddingLegalization::RunOnComputation( HloComputation* computation) { bool changed = false; std::vector convs; @@ -429,7 +443,7 @@ StatusOr GpuConvPaddingLegalization::RunOnComputation( return changed; } -StatusOr GpuConvPaddingLegalization::Run( +absl::StatusOr GpuConvPaddingLegalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/gpu_conv_padding_legalization.h b/xla/service/gpu/gpu_conv_padding_legalization.h index 15e5de46e5a93..32e0238bed1b3 100644 --- a/xla/service/gpu/gpu_conv_padding_legalization.h +++ b/xla/service/gpu/gpu_conv_padding_legalization.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ #define XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -31,12 +37,12 @@ class GpuConvPaddingLegalization : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/xla/service/gpu/gpu_conv_padding_legalization_test.cc index 147b3187c25e0..edaf9d053d77c 100644 --- a/xla/service/gpu/gpu_conv_padding_legalization_test.cc +++ b/xla/service/gpu/gpu_conv_padding_legalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,10 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/gpu/gpu_conv_rewriter.cc b/xla/service/gpu/gpu_conv_rewriter.cc index 838de80674423..8ba1af1633887 100644 --- a/xla/service/gpu/gpu_conv_rewriter.cc +++ b/xla/service/gpu/gpu_conv_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_rewriter.h" +#include #include #include #include @@ -24,17 +25,24 @@ limitations under the License. #include #include -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -43,6 +51,30 @@ namespace { using ConvolutionMatch = std::optional< std::tuple>; +// Determine whether conv2d is equal to conv1d. +bool MaybeConv1dToConv2d(HloInstruction* conv) { + if (conv->window().dimensions().size() != 2) { + return false; + } + if (conv->operand(1)->opcode() != HloOpcode::kReshape) { + return false; + } + auto filter = conv->operand(1); + std::optional reshape_degenerate = + filter->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); + if (reshape_degenerate.has_value() && + reshape_degenerate->deleted_dimensions.empty() && + reshape_degenerate->inserted_dimensions.size() == 1) { + auto dnums = conv->convolution_dimension_numbers(); + for (auto dim : dnums.kernel_spatial_dimensions()) { + if (dim == reshape_degenerate->inserted_dimensions[0]) { + return true; + } + } + } + return false; +} + bool CanImplementAsGpuForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -145,10 +177,18 @@ ConvolutionMatch MatchBackwardFilter(HloInstruction* conv) { // convolutions have very small kernel dimensions, while in the backward pass // "kernel dimensions" are large. If kernel dimensions are smaller than the // output dimensions, return foward conv; otherwise proceed with backward - // filter conv. - if ((kernel_spatial_dims.empty() || - conv->operand(1)->shape().dimensions(kernel_spatial_dims[0]) <= - conv->shape().dimensions(output_spatial_dims[0])) && + // filter conv. But for conv1d, it is not same. Due to conv1d always reshape + // 1D-filter to 2D-filter, even backward or forward will exist one small + // kernel dimension. We should handle this special case. + int small_kernel_dimension_num = 0; + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + if (conv->operand(1)->shape().dimensions(kernel_spatial_dims[i]) <= + conv->shape().dimensions(output_spatial_dims[i])) { + small_kernel_dimension_num += 1; + } + } + if ((kernel_spatial_dims.empty() || small_kernel_dimension_num > 1 || + (!MaybeConv1dToConv2d(conv) && small_kernel_dimension_num == 1)) && !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " @@ -309,10 +349,18 @@ ConvolutionMatch MatchBackwardInput(HloInstruction* conv) { reverse_filter->opcode() == HloOpcode::kReverse && absl::c_is_permutation(dnums.kernel_spatial_dimensions(), reverse_filter->dimensions()); + // For conv1d which reshape to conv2d, filter reverse pattern is + // reshape(reverse(filter)). It seems we can reuse conv2d backward input + // pattern matcher, but after algsimp pass, this pattern will change to + // reverse(reshape(filter)) and fail to match. So matching conv1d backward + // input need different processing logic. + bool is_reversed_conv1d_filter = + MaybeConv1dToConv2d(conv) && + reverse_filter->operand(0)->opcode() == HloOpcode::kReverse; bool is_1x1_filter = absl::c_all_of(conv->window().dimensions(), [](const WindowDimension& d) { return d.size() == 1; }); - if (!is_reversed_filter && + if (!is_reversed_filter && !is_reversed_conv1d_filter && !(window_util::HasBaseDilation(conv->window()) && (reverse_filter->IsConstant() || is_1x1_filter))) { VLOG(1) << "Can't match to backwards convolution. Either filter is not " @@ -484,6 +532,10 @@ ConvolutionMatch MatchBackwardInput(HloInstruction* conv) { // One reverse is subsumed by the cuDNN call. if (rhs->opcode() == HloOpcode::kReverse) { rhs = rhs->mutable_operand(0); + } else if (is_reversed_conv1d_filter) { + auto src = rhs->mutable_operand(0)->mutable_operand(0); + rhs = conv->parent()->AddInstruction( + HloInstruction::CreateReshape(rhs->shape(), src)); } if (conv->feature_group_count() == 1) { return std::make_tuple(new_window, dnums, rhs); @@ -662,7 +714,8 @@ CudnnConvBackendConfig GetDefaultBackendConfig() { // Helper function to create a custom_call instruction to replace the given // conv instruction -static StatusOr CreateCustomCallHelper(HloInstruction* conv) { +static absl::StatusOr CreateCustomCallHelper( + HloInstruction* conv) { if (ConvolutionMatch m = MatchBackwardInput(conv)) { auto& [window, dnums, rhs] = *m; return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(), @@ -696,7 +749,7 @@ static StatusOr CreateCustomCallHelper(HloInstruction* conv) { } // Tries to rewrite a single convolution into a call to cudnn/miopen. -StatusOr RunOnInstruction(HloInstruction* conv) { +absl::StatusOr RunOnInstruction(HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, @@ -705,8 +758,10 @@ StatusOr RunOnInstruction(HloInstruction* conv) { return false; } - TF_RETURN_IF_ERROR( - custom_call->set_backend_config(GetDefaultBackendConfig())); + GpuBackendConfig gpu_backend_config; + *gpu_backend_config.mutable_cudnn_conv_backend_config() = + GetDefaultBackendConfig(); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << custom_call->ToString(); @@ -722,7 +777,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { // Rewrites the convolutions in the given computation into calls to // cudnn/miopen. // Returns true if it made any changes. -StatusOr RunOnComputation(HloComputation* computation) { +absl::StatusOr RunOnComputation(HloComputation* computation) { std::vector convs; for (auto* hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kConvolution) { @@ -739,7 +794,7 @@ StatusOr RunOnComputation(HloComputation* computation) { } } // namespace -StatusOr GpuConvRewriter::Run( +absl::StatusOr GpuConvRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString()); diff --git a/xla/service/gpu/gpu_conv_rewriter.h b/xla/service/gpu/gpu_conv_rewriter.h index 4e36a69ceaa11..526271e349cf4 100644 --- a/xla/service/gpu/gpu_conv_rewriter.h +++ b/xla/service/gpu/gpu_conv_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ #define XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ -#include -#include - +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -43,7 +44,7 @@ class GpuConvRewriter : public HloModulePass { static bool ConvIsLowerable(HloInstruction* conv); using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gpu_conv_rewriter_test.cc b/xla/service/gpu/gpu_conv_rewriter_test.cc index 4f7a59b8c5114..add6e0f53ea50 100644 --- a/xla/service/gpu/gpu_conv_rewriter_test.cc +++ b/xla/service/gpu/gpu_conv_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,26 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_rewriter.h" +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "xla/array4d.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/protobuf_util.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/shape_inference.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -642,7 +650,9 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) { 0))); } -TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) { +TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) { + // All filter dimensions are larger than the corresponding output dimensions. + // This must be a backward filter convolution. const std::string module_str = absl::StrFormat(R"( HloModule Test @@ -662,6 +672,74 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) { 0))); } +TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { + // At least one filter dimension is smaller than the corresponding output + // dimension. This must be a forward convolution. + const std::string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = f32[8,128,2,32] parameter(0) + filter = f32[3,3,128,128] parameter(1) + + ROOT conv = f32[8,128,2,32] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01 + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({kCudnnConvForwardCallTarget}, m::Parameter(0), + m::Parameter(1)), + 0))); +} + +TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { + // There exist one kernel dimension equal to output dimension, regard + // it as backward filter if conv is 1d. + const std::string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = f32[8,256,128] parameter(0) + filter = f32[8,254,128] parameter(1) + reshape.1 = f32[8,1,256,128] reshape(input) + reshape.2 = f32[8,1,254,128] reshape(filter) + ROOT conv = f32[1,3,128,128] convolution(reshape.1, reshape.2), window={size=1x254}, dim_labels=f01b_i01o->01bf + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({kCudnnConvBackwardFilterCallTarget}, + m::Reshape(), m::Reshape()), + 0))); +} + +TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { + // For conv1d backward input, filter may reverse first and then reshape. + const std::string module_str = absl::StrFormat(R"( + HloModule Test + + ENTRY Test { + input = f32[8,254,128] parameter(0) + filter = f32[3,128,128] parameter(1) + reverse = f32[3,128,128] reverse(filter), dimensions={0} + reshape.1 = f32[8,1,254,128] reshape(input) + reshape.2 = f32[1,3,128,128] reshape(reverse) + ROOT conv = f32[8,1,256,128] convolution(reshape.1, reshape.2), window={size=1x3 pad=0_0x2_2}, dim_labels=b01f_01oi->b01f + })"); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement( + m::CustomCall({kCudnnConvBackwardInputCallTarget}, + m::Reshape(), m::Reshape()), + 0))); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_conv_runner.cc b/xla/service/gpu/gpu_conv_runner.cc index 5ca19f64c8713..63d71a7ee0c3c 100644 --- a/xla/service/gpu/gpu_conv_runner.cc +++ b/xla/service/gpu/gpu_conv_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,32 @@ limitations under the License. #include "xla/service/gpu/gpu_conv_runner.h" +#include +#include +#include +#include +#include #include -#include "absl/strings/str_cat.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/layout_util.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/lazy_op_runner.h" #include "xla/util.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -41,16 +56,15 @@ using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; template -Status RunGpuConvUnfused(const GpuConvParams& params, se::Stream* stream, - RunConvOptions options, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_memory) { +absl::Status RunGpuConvUnfused(const GpuConvParams& params, se::Stream* stream, + RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_memory) { if (params.config->conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.config->conv_result_scale); + return Internal("StreamExecutor doesn't support scaled convolution: %lf.", + params.config->conv_result_scale); } TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, @@ -87,16 +101,15 @@ Status RunGpuConvUnfused(const GpuConvParams& params, se::Stream* stream, } template -Status RunGpuConvGraph(const GpuConvParams& params, se::Stream* stream, - RunConvOptions options, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_memory) { +absl::Status RunGpuConvGraph(const GpuConvParams& params, se::Stream* stream, + RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_memory) { if (params.config->conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.config->conv_result_scale); + return Internal("StreamExecutor doesn't support scaled convolution: %lf.", + params.config->conv_result_scale); } TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, @@ -142,17 +155,15 @@ Status RunGpuConvGraph(const GpuConvParams& params, se::Stream* stream, } template -Status RunGpuConvForwardActivation(const GpuConvParams& params, - se::Stream* stream, RunConvOptions options, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_memory) { +absl::Status RunGpuConvForwardActivation( + const GpuConvParams& params, se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, DeviceMemory filter_buf, + DeviceMemory output_buf, DeviceMemoryBase scratch_memory) { se::DeviceMemory side_input(params.fusion->side_input_buf); // If there is no side input, use output as the side input. if (side_input.is_null()) { if (params.config->fusion->side_input_scale != 0) { - return InternalError( + return Internal( "Side input scale is not 0, yet no side input buffer is " "provided"); } @@ -212,12 +223,12 @@ Status RunGpuConvForwardActivation(const GpuConvParams& params, template ::value>::type* = nullptr> -Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream, - RunConvOptions options, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_memory) { +absl::Status RunGpuConvInternalImpl(const GpuConvParams& params, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_memory) { switch (params.config->kind) { case CudnnConvKind::kForward: case CudnnConvKind::kBackwardInput: @@ -233,19 +244,19 @@ Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream, output_buf, scratch_memory); } } - return OkStatus(); + return absl::OkStatus(); } // Specialization for integer types. Only two forward convolutions are allowed. template ::value>::type* = nullptr> -Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream, - RunConvOptions options, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_memory) { +absl::Status RunGpuConvInternalImpl(const GpuConvParams& params, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_memory) { switch (params.config->kind) { case CudnnConvKind::kForward: return RunGpuConvUnfused(params, stream, options, input_buf, filter_buf, @@ -255,26 +266,27 @@ Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream, params, stream, options, input_buf, filter_buf, output_buf, scratch_memory); default: - return InternalError( + return Internal( "Only convolution kinds kForward and kForwardActivation are " "supported for integer types"); } - return OkStatus(); + return absl::OkStatus(); } template -Status RunGpuConvImpl(const GpuConvParams& params, se::Stream* stream, - se::DeviceMemoryBase scratch_memory, - RunConvOptions options) { +absl::Status RunGpuConvImpl(const GpuConvParams& params, se::Stream* stream, + se::DeviceMemoryBase scratch_memory, + RunConvOptions options) { auto input_buf = se::DeviceMemory(params.input_buf); auto filter_buf = se::DeviceMemory(params.filter_buf); auto output_buf = se::DeviceMemory(params.output_buf); - Status run_status = RunGpuConvInternalImpl( - params, stream, options, input_buf, filter_buf, output_buf, - scratch_memory); + absl::Status run_status = + RunGpuConvInternalImpl( + params, stream, options, input_buf, filter_buf, output_buf, + scratch_memory); - if (run_status != OkStatus()) { + if (!run_status.ok()) { return run_status; } @@ -283,11 +295,11 @@ Status RunGpuConvImpl(const GpuConvParams& params, se::Stream* stream, if (options.runner_cache) { algorithm = options.runner_cache->ToAlgorithmDesc(); } - return InternalError( + return Internal( "Unable to launch convolution with type %s and algorithm %s", CudnnConvKindToString(params.config->kind), algorithm.ToString()); } - return OkStatus(); + return absl::OkStatus(); } int64_t GetVectCSize(DataLayout layout) { @@ -315,7 +327,7 @@ int64_t GetVectCSize(FilterLayout layout) { } // anonymous namespace -StatusOr GetGpuConvConfig( +absl::StatusOr GetGpuConvConfig( const GpuConvDescriptor& desc, const absl::string_view inst_as_string) { GpuConvConfig config; @@ -350,13 +362,13 @@ StatusOr GetGpuConvConfig( config.output_shape = operand1_shape; break; default: - return InternalError("Unknown convolution kind"); + return Internal("Unknown convolution kind"); } if (config.kind == CudnnConvKind::kForwardActivation) { if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); + return Internal("Bad activation mode: %s", + backend_config.ShortDebugString()); } GpuConvConfig::FusionConfig fusion; @@ -526,13 +538,14 @@ StatusOr GetGpuConvConfig( return config; } -StatusOr GetGpuConvConfig( +absl::StatusOr GetGpuConvConfig( const HloCustomCallInstruction* cudnn_call) { GpuConvDescriptor descriptor; TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call)); - TF_ASSIGN_OR_RETURN(descriptor.backend_config, - cudnn_call->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + cudnn_call->backend_config()); + descriptor.backend_config = gpu_backend_config.cudnn_conv_backend_config(); descriptor.operand0_shape = cudnn_call->operand(0)->shape(); descriptor.operand1_shape = cudnn_call->operand(1)->shape(); descriptor.result_shape = cudnn_call->shape().tuple_shapes(0); @@ -545,7 +558,7 @@ StatusOr GetGpuConvConfig( return GetGpuConvConfig(descriptor, cudnn_call->ToString()); } -StatusOr GetGpuConvParams( +absl::StatusOr GetGpuConvParams( const GpuConvConfig& config, absl::Span operand_buffers, absl::Span result_buffers) { @@ -589,11 +602,11 @@ StatusOr GetGpuConvParams( return params; } -Status RunGpuConv(const gpu::GpuConvConfig& config, - absl::Span operand_buffers, - absl::Span result_buffers, - se::DeviceMemoryBase scratch_memory, se::Stream* stream, - RunConvOptions options) { +absl::Status RunGpuConv(const gpu::GpuConvConfig& config, + absl::Span operand_buffers, + absl::Span result_buffers, + se::DeviceMemoryBase scratch_memory, se::Stream* stream, + RunConvOptions options) { TF_ASSIGN_OR_RETURN( GpuConvParams params, GetGpuConvParams(config, operand_buffers, result_buffers)); @@ -602,14 +615,14 @@ Status RunGpuConv(const gpu::GpuConvConfig& config, switch (input_primitive_type) { case F8E4M3FN: if (config.kind != CudnnConvKind::kForwardGraph) { - return InternalError("FP8 convolution requires graph mode."); + return Internal("FP8 convolution requires graph mode."); } return RunGpuConvImpl(params, stream, scratch_memory, options); case F8E5M2: if (config.kind != CudnnConvKind::kForwardGraph) { - return InternalError("FP8 convolution requires graph mode."); + return Internal("FP8 convolution requires graph mode."); } return RunGpuConvImpl(params, stream, scratch_memory, diff --git a/xla/service/gpu/gpu_conv_runner.h b/xla/service/gpu/gpu_conv_runner.h index 748db93fabab7..31efbd2fecd7b 100644 --- a/xla/service/gpu/gpu_conv_runner.h +++ b/xla/service/gpu/gpu_conv_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,26 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_ #define XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_ +#include +#include +#include #include #include +#include +#include #include -#include "xla/hlo/ir/hlo_instruction.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/lazy_op_runner.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { @@ -212,11 +218,11 @@ struct RunConvOptions { // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunGpuConv(const GpuConvConfig& conv_config, - absl::Span operand_buffers, - absl::Span result_buffers, - se::DeviceMemoryBase scratch_memory, se::Stream* stream, - RunConvOptions = {}); +absl::Status RunGpuConv(const GpuConvConfig& conv_config, + absl::Span operand_buffers, + absl::Span result_buffers, + se::DeviceMemoryBase scratch_memory, se::Stream* stream, + RunConvOptions = {}); // Struct to describe properties of a convolution without being tied to specific // IR. Will be used to help build Convolution thunks from either XLA HLO or @@ -234,17 +240,17 @@ struct GpuConvDescriptor { }; // Returns the convolution configuration given a XLA HLO instruction. -StatusOr GetGpuConvConfig( +absl::StatusOr GetGpuConvConfig( const HloCustomCallInstruction* cudnn_call); // Returns the convolution configuration given a convolution descriptor `desc` // and a string representation of the convolution instruction `inst_as_string` // (for error reporting). -StatusOr GetGpuConvConfig(const GpuConvDescriptor& desc, - absl::string_view inst_as_string); +absl::StatusOr GetGpuConvConfig( + const GpuConvDescriptor& desc, absl::string_view inst_as_string); // Implementation details exposed for debugging and log analysis. -StatusOr GetGpuConvParams( +absl::StatusOr GetGpuConvParams( const GpuConvConfig& conv_config, absl::Span operand_buffers, absl::Span result_buffers); diff --git a/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc b/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc index 0e3169ce27901..b8c87e2e1978a 100644 --- a/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc +++ b/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,12 +19,20 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync( +absl::Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync( HloComputation* computation, absl::Span> async_pairs) const { @@ -33,7 +41,10 @@ Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync( sync_config.set_is_sync(true); for (auto& [async_start, async_done] : async_pairs) { // Tag the async start with is_sync = true. - TF_RETURN_IF_ERROR(async_start->set_backend_config(sync_config)); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + async_start->backend_config()); + *gpu_config.mutable_collective_backend_config() = sync_config; + TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config)); replaced_ops[async_start] = nullptr; replaced_ops[async_done] = async_start; } @@ -62,7 +73,7 @@ Status GpuConvertAsyncCollectivesToSync::ConvertAsyncInstructionsToSync( new_sequence.push_back(instr); } module->schedule().set_sequence(computation, new_sequence); - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/gpu_convert_async_collectives_to_sync.h b/xla/service/gpu/gpu_convert_async_collectives_to_sync.h index 01713f8dd3cef..ea56f7a91914c 100644 --- a/xla/service/gpu/gpu_convert_async_collectives_to_sync.h +++ b/xla/service/gpu/gpu_convert_async_collectives_to_sync.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,11 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/convert_async_collectives_to_sync.h" namespace xla { @@ -30,7 +35,7 @@ class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync { return "gpu-convert-async-collectives-to-sync"; } - Status ConvertAsyncInstructionsToSync( + absl::Status ConvertAsyncInstructionsToSync( HloComputation* computation, absl::Span> async_pairs) const override; diff --git a/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc index 6f504e8d877b2..4daeb62905f8a 100644 --- a/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc +++ b/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,17 @@ limitations under the License. #include +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -30,16 +37,16 @@ using ::testing::IsFalse; using ::testing::IsTrue; // Note: The pass only processes modules that are already scheduled. If the test -// does not work as epxected, make sure to check if "is_scheduled=true" is added +// does not work as expected, make sure to check if "is_scheduled=true" is added // to the HLO module string. class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase { public: - Status RunPass(HloModule *module, bool expect_change, - HloPredicate is_nop = {}) { + absl::Status RunPass(HloModule *module, bool expect_change, + HloPredicate is_nop = {}) { TF_ASSIGN_OR_RETURN(bool changed, GpuConvertAsyncCollectivesToSync{is_nop}.Run(module)); EXPECT_EQ(changed, expect_change); - return OkStatus(); + return absl::OkStatus(); } // Returns true if the instruction with the given name is synchronous. @@ -48,8 +55,9 @@ class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase { if (inst == nullptr) { return false; } - auto backend_config = - inst->backend_config().value(); + auto backend_config = inst->backend_config() + .value() + .collective_backend_config(); return backend_config.is_sync(); } @@ -102,6 +110,26 @@ TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNop) { TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true, is_nop_simple_)); EXPECT_THAT(IsSync(module.get(), "start"), IsTrue()); } +TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleCollectiveBroadcast) { + const absl::string_view hlo_string = R"( + HloModule test, is_scheduled=true + + collective_broadcast { + p0 = u32[8] parameter(0) + ROOT result = u32[8] collective-broadcast(p0), replica_groups={{0,1}, {2,3}} + } + + ENTRY main { + data = u32[8] parameter(0) + cb-start = ((u32[8]{0}), u32[8]{0}) async-start(u32[8]{0} %data), calls=collective_broadcast + ROOT %ars = u32[8]{0} async-done(((u32[8]{0}), u32[8]{0}) %cb-start), calls=collective_broadcast + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK(RunPass(module.get(), /*expect_change=*/true)); + EXPECT_THAT(IsSync(module.get(), "cb-start"), IsTrue()); +} TEST_F(GpuConvertAsyncCollectivesToSyncTest, SimpleAllReduceWithNonNop) { const absl::string_view hlo_string = R"( diff --git a/xla/service/gpu/gpu_copy_insertion_test.cc b/xla/service/gpu/gpu_copy_insertion_test.cc index ae8123389a7f9..83df4285a889d 100644 --- a/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/xla/service/gpu/gpu_copy_insertion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -22,7 +26,9 @@ limitations under the License. #include "xla/service/copy_insertion.h" #include "xla/service/gpu/buffer_sharing.h" #include "xla/test.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/gpu_device_info_for_tests.cc b/xla/service/gpu/gpu_device_info_for_tests.cc index 8eb44d45aa56e..a34ce4b1685bc 100644 --- a/xla/service/gpu/gpu_device_info_for_tests.cc +++ b/xla/service/gpu/gpu_device_info_for_tests.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_device_info_for_tests.h b/xla/service/gpu/gpu_device_info_for_tests.h index 2dab6d5a90c07..633148a5c156e 100644 --- a/xla/service/gpu/gpu_device_info_for_tests.h +++ b/xla/service/gpu/gpu_device_info_for_tests.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_device_info_test.cc b/xla/service/gpu/gpu_device_info_test.cc deleted file mode 100644 index c9e2ae245a753..0000000000000 --- a/xla/service/gpu/gpu_device_info_test.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "absl/strings/string_view.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "tsl/platform/test.h" - -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - -namespace xla { -namespace gpu { -namespace { - -namespace se = stream_executor; - -TEST(DeviceInfoTest, DeviceInfoIsCorrect) { - std::string test_platform = "cuda"; -#if TENSORFLOW_USE_ROCM - test_platform = "rocm"; -#endif - - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName(test_platform).value(); - se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - const GpuDeviceInfo dev_info = GetGpuDeviceInfo(executor); - absl::string_view name(executor->GetDeviceDescription().name()); - if (name == "NVIDIA RTX A6000") { - GpuDeviceInfo test_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - EXPECT_THAT( - dev_info, - ::testing::FieldsAre( - test_info.compute_capability, test_info.threads_per_block_limit, - test_info.threads_per_warp, test_info.shared_memory_per_block, - test_info.shared_memory_per_block_optin, - test_info.shared_memory_per_core, test_info.threads_per_core_limit, - test_info.core_count, test_info.fpus_per_core, - test_info.block_dim_limit_x, test_info.block_dim_limit_y, - test_info.block_dim_limit_z, test_info.memory_bandwidth, - test_info.l2_cache_size, - // Clock rate can vary between base and boost values. - ::testing::Ge(test_info.clock_rate_ghz), - test_info.device_memory_size)); - } else if (name == "Quadro P1000") { - EXPECT_THAT( - dev_info, - ::testing::FieldsAre( - se::GpuComputeCapability(se::CudaComputeCapability(6, 1)), - /*threads_per_block_limit=*/1024, - /*threads_per_warp=*/32, /*shared_memory_per_block=*/48 * 1024, - /*shared_memory_per_block_optin=*/48 * 1024, - /*shared_memory_per_core=*/96 * 1024, - /*threads_per_core_limit=*/2048, /*core_count=*/5, - /*fpus_per_core=*/128, - /*block_dim_limit_x=*/2'147'483'647, - /*block_dim_limit_y=*/65535, - /*block_dim_limit_z=*/65535, - /*memory_bandwidth=*/80'160'000'000, /*l2_cache_size=*/1024 * 1024, - /*clock_rate_ghz=*/::testing::Ge(1.4), - /*device_memory_size=*/4'234'346'496)); - } else if (name == "Tesla P100-SXM2-16GB") { - EXPECT_THAT(dev_info, - ::testing::FieldsAre( - se::GpuComputeCapability(se::CudaComputeCapability(6, 0)), - /*threads_per_block_limit=*/1024, - /*threads_per_warp=*/32, - /*shared_memory_per_block=*/48 * 1024, - /*shared_memory_per_block_optin=*/48 * 1024, - /*shared_memory_per_core=*/64 * 1024, - /*threads_per_core_limit=*/2048, /*core_count=*/56, - /*fpus_per_core=*/64, - /*block_dim_limit_x=*/2'147'483'647, - /*block_dim_limit_y=*/65535, - /*block_dim_limit_z=*/65535, - /*memory_bandwidth=*/732'160'000'000, - /*l2_cache_size=*/4 * 1024 * 1024, - /*clock_rate_ghz=*/::testing::Ge(1.4), - /*device_memory_size=*/17'066'622'976)); - } -#if TF_ROCM_VERSION >= 50500 - else if (name == "AMD Instinct MI210") { // NOLINT - GpuDeviceInfo test_info = TestGpuDeviceInfo::AMDMI210DeviceInfo(); - EXPECT_THAT( - dev_info, - ::testing::FieldsAre( - test_info.compute_capability, test_info.threads_per_block_limit, - test_info.threads_per_warp, test_info.shared_memory_per_block, - test_info.shared_memory_per_block_optin, - test_info.shared_memory_per_core, test_info.threads_per_core_limit, - test_info.core_count, test_info.fpus_per_core, - test_info.block_dim_limit_x, test_info.block_dim_limit_y, - test_info.block_dim_limit_z, test_info.memory_bandwidth, - test_info.l2_cache_size, ::testing::Ge(test_info.clock_rate_ghz), - dev_info.device_memory_size)); - } else if (name == "AMD Instinct MI100") { - EXPECT_THAT( - dev_info, - ::testing::FieldsAre( - se::GpuComputeCapability(se::RocmComputeCapability("gfx908")), - /*threads_per_block_limit=*/1024, - /*threads_per_warp=*/64, /*shared_memory_per_block=*/64 * 1024, - /*shared_memory_per_block_optin=*/0, - /*shared_memory_per_core=*/64 * 1024, - /*threads_per_core_limit=*/2560, /*core_count=*/120, - /*fpus_per_core=*/128, /*block_dim_limit_x=*/2'147'483'647, - /*block_dim_limit_y=*/2'147'483'647, - /*block_dim_limit_z=*/2'147'483'647, - /*memory_bandwidth=*/1228800000000, - /*l2_cache_size=*/8 * 1024 * 1024, - /*clock_rate_ghz=*/::testing::Ge(1.5), - /*device_memory_size=*/33'806'090'240)); - } else if (name == "AMD Instinct MI50/MI60") { - EXPECT_THAT( - dev_info, - ::testing::FieldsAre( - se::GpuComputeCapability(se::RocmComputeCapability("gfx906")), - /*threads_per_block_limit=*/1024, - /*threads_per_warp=*/64, /*shared_memory_per_block=*/64 * 1024, - /*shared_memory_per_block_optin=*/0, - /*shared_memory_per_core=*/64 * 1024, - /*threads_per_core_limit=*/2560, /*core_count=*/60, - /*fpus_per_core=*/64, /*block_dim_limit_x=*/2'147'483'647, - /*block_dim_limit_y=*/2'147'483'647, - /*block_dim_limit_z=*/2'147'483'647, - /*memory_bandwidth=*/256000000000, - /*l2_cache_size=*/8 * 1024 * 1024, - /*clock_rate_ghz=*/::testing::Ge(1.7), - /*device_memory_size=*/17'163'091'968)); - } -#endif // TF_ROCM_VERSION >= 50500 - else { // NOLINT - VLOG(1) << "Not tested for " << name; - } -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 05283bbdc2b8a..0a9cbe65d98e1 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,31 +21,48 @@ limitations under the License. #include #include #include -#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "mlir/Parser/Parser.h" // from @llvm-project +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/map_util.h" -#include "xla/mlir/runtime/ir/rt_ops.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/mlir/runtime/transforms/type_converter.h" -#include "xla/runtime/executable.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/gpu_constants.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/executable.h" +#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/hlo_execution_profile.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" +#include "xla/service/hlo_value.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/rendezvous.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" #include "xla/service/xla_debug_info_manager.h" @@ -53,16 +70,21 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/scoped_annotation.h" #include "tsl/profiler/lib/traceme.h" @@ -73,21 +95,33 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_timer.h" +#else +namespace stream_executor::gpu { +class GpuTimer {}; +} // namespace stream_executor::gpu #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { +using ::tsl::profiler::ScopedAnnotation; + bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config) { - return config.debug_options().xla_gpu_enable_xla_runtime_executable(); + bool enabled = config.debug_options().xla_gpu_enable_xla_runtime_executable(); + if (enabled) { + LOG(ERROR) + << "XLA:GPU tried to use deprecated xla runtime by setting " + "--xla_gpu_enable_xla_runtime_executable flag to `true` but the " + "flag value was ignored as XLA:GPU uses default runtime. This flag " + "together with the deprecated code will be removed soon. Please " + "report bugs to XLA team if this breaks your workloads."; + } + return false; } -namespace { - -using ::tsl::profiler::ScopedAnnotation; -using ::tsl::profiler::ScopedAnnotationAlways; - -bool NeedsAsyncCommsStream(Thunk& thunk) { +static bool NeedsAsyncCommsStream(Thunk& thunk) { switch (thunk.kind()) { case Thunk::Kind::kNcclAllReduceStart: case Thunk::Kind::kNcclAllReduceDone: @@ -97,26 +131,34 @@ bool NeedsAsyncCommsStream(Thunk& thunk) { } } -} // namespace - -StatusOr> GpuExecutable::Create(Params params) { - auto executable = std::move(params.executable); - std::unique_ptr result(new GpuExecutable(std::move(params))); - - if (std::holds_alternative(executable)) { - result->thunks_ = std::move(std::get(executable)); - return result; - } - - if (std::holds_alternative(executable)) { - auto& program = std::get(executable); - TF_ASSIGN_OR_RETURN( - result->gpu_runtime_executable_, - GpuRuntimeExecutable::Create(result->module_name_, std::move(program))); - return result; +// Traverses operations in HLO module and collects execution stream ids +// requested by HLO operations. At run time thunks may use additional streams to +// launch compute operations in addition to a main one. +// +// TODO(ezhulenev): Execution stream requirements should be queried from thunks +// directly and not from HLO module that might be missing. +static absl::flat_hash_set GetExecutionStreamIds( + const HloModule& module) { + absl::flat_hash_set stream_ids; + for (const HloComputation* comp : module.computations()) { + for (const HloInstruction* hlo : comp->instructions()) { + if (hlo->has_backend_config() && + hlo->backend_config().ok()) { + int64_t op_queue_id = hlo->backend_config() + .value() + .operation_queue_id(); + if (op_queue_id > 0) { + stream_ids.insert(ExecutionStreamId(op_queue_id)); + } + } + } } + return stream_ids; +} - return InternalError("No XLA gpu executable was provided"); +absl::StatusOr> GpuExecutable::Create( + Params params) { + return std::unique_ptr(new GpuExecutable(std::move(params))); } // Implementation note: HLO profiling is always enabled for GPU executables, @@ -125,12 +167,16 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) : Executable(std::move(params.debug_module)), text_(std::move(params.asm_text)), binary_(std::move(params.binary)), + dnn_compiled_graphs_(std::move(params.dnn_compiled_graphs)), gpu_version_(params.gpu_version), + thunks_(std::move(params.executable)), + execution_stream_ids_(has_module() + ? GetExecutionStreamIds(module()) + : absl::flat_hash_set()), module_name_(params.module_name), output_shape_(params.output_shape), allocations_(std::move(params.mlir_allocations)), buffer_assignment_(std::move(params.buffer_assignment)), - enable_persistent_temp_buffers_(params.enable_persistent_temp_buffers), debug_buffer_assignment_show_max_( params.debug_buffer_assignment_show_max), constants_(std::move(params.constants)), @@ -153,17 +199,9 @@ GpuExecutable::~GpuExecutable() { if (has_module() && enable_debug_info_manager_) { XlaDebugInfoManager::Get()->UnregisterModule(module().unique_id()); } - - // Deallocate all persistent buffers. - for (auto& [executor, map] : persistent_temp_buffers_) { - for (const auto& alloc_buffer : map) { - se::DeviceMemoryBase buffer = alloc_buffer.second; - executor->UnifiedMemoryDeallocate(buffer.opaque()); - } - } } -Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( +absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( const ServiceExecutableRunOptions* run_options) { se::Stream* main_stream = run_options->stream(); @@ -186,23 +224,168 @@ Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( << "}, but was {" << std::get(cc).ToString() << "}"; } else { - return InternalError("Unknown platform"); + return Internal("Unknown platform"); } - return OkStatus(); + return absl::OkStatus(); } namespace { -Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, - uint64_t start_nanos, se::Stream* stream_to_sync); +// Shared resources required for thunk initialization and execution. +class ResourceRequests : public Thunk::ResourceRequests { + public: + absl::Status AddClique(const NcclCliqueKey& clique_key, + int32_t num_local_participants) final { + VLOG(5) << "Add collective clique request: " << clique_key.ToString() + << "; num_local_participants: " << num_local_participants; + + // Check if there is already a clique request for this clique key. + if (auto it = cliques_.find(clique_key); it != cliques_.end()) { + // We can't have multiple requests for a same clique key with different + // number of local participants as we can acquire a clique only once and + // we have to know how many executables will join the rendezvous. + if (it->second.num_local_participants != num_local_participants) { + return absl::InternalError(absl::StrFormat( + "Clique request for a clique key %s has number of local " + "participants %d different from previously requested value of %d. " + "This will lead to deadlock at run time and is an XLA compiler " + "bug. Please report it to XLA team.", + clique_key.ToString(), num_local_participants, + it->second.num_local_participants)); + } + return absl::OkStatus(); + } -Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, - const ThunkSequence& thunk_sequence, - const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, - bool block_host_until_done, - bool use_highest_priority_for_async_stream) { + // XLA compiler guarantees that all collective operations have the same + // order on all replicas. We rely on this property to assign unique id to + // clique requests simply based on the number of already recored requests. + int64_t id = cliques_.size(); + cliques_.try_emplace(clique_key, + CliqueRequest{clique_key, num_local_participants, id}); + return absl::OkStatus(); + } + + absl::StatusOr AcquireCollectiveCliques( + const Thunk::CollectiveExecuteParams& params) { + if (cliques_.empty()) return Thunk::CollectiveCliques(); + + VLOG(2) << "Acquire " << cliques_.size() + << " collective cliques for global device id " + << params.global_device_id.value() + << "; run_id=" << params.run_id.ToInt() + << "; max number of channels for collectives " + << params.collective_max_nchannels + << "; max number of channels for p2p " << params.p2p_max_nchannels; + + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode("AcquireCollectiveCliques", + {{"num_cliques", cliques_.size()}}); + }); + + auto start_micros = tsl::Env::Default()->NowMicros(); + + NcclClique::AcquiredCliquesMap cliques_map; + + for (const CliqueRequest& r : GetOrderedCliqueRequests()) { + std::optional rank = r.key.rank(params.global_device_id); + + if (!rank.has_value()) { + return absl::InternalError(absl::StrCat( + "Can't find global device id ", params.global_device_id.value(), + " in clique key ", r.key.ToString())); + } + + bool is_local = r.key.devices().size() == r.num_local_participants; + TF_ASSIGN_OR_RETURN( + const NcclCliqueIdCallback* clique_id_callback, + GetNcclCliqueIdCallback(params.nccl_clique_id_callback, is_local)); + + int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective + ? params.collective_max_nchannels + : params.p2p_max_nchannels; + TF_ASSIGN_OR_RETURN(std::shared_ptr clique, + AcquireNcclClique(params.executor, params.run_id, + r.key, *clique_id_callback, *rank, + r.num_local_participants, + cliques_map, max_channels)); + + cliques_map[r.key] = std::move(clique); + } + + auto end_micros = tsl::Env::Default()->NowMicros(); + VLOG(2) << "Acquired " << cliques_map.size() + << " collective cliques for global device id " + << params.global_device_id.value() << " in " + << (end_micros - start_micros) << " μs" + << "; run_id=" << params.run_id.ToInt(); + + return Thunk::CollectiveCliques(std::move(cliques_map)); + } + + private: + struct CliqueRequest { + NcclCliqueKey key; + int64_t num_local_participants; + int64_t id; + }; + + // Return clique requests deterministically ordered using a comparison + // function that produces identical ordering for all participating ranks. + // + // Example: 8 ranks splitted in different groups of communicators + // + // Group #0: [0,1], [2,3], [4,5], [6,7] + // Group #1: [0,4], [1,5], [2,6], [3,7] + // + // Both groups #0 and #1 can be acqured by splitting [0...7] clique. To avoid + // deadlocks all participants should acquire all cliques in a group #0 before + // acquiring any cliques in a group #1. + // + // We rely on clique request id to guarantee that the order is identical + // on all participating ranks (including ranks running on different hosts). + std::vector GetOrderedCliqueRequests() { + std::vector cliques; + cliques.reserve(cliques_.size()); + for (const auto& [_, request] : cliques_) cliques.push_back(request); + + absl::c_sort(cliques, [](const CliqueRequest& a, const CliqueRequest& b) { + // Acquire larger cliques first to be able to split them later. + if (a.key.devices().size() > b.key.devices().size()) return true; + if (b.key.devices().size() > a.key.devices().size()) return false; + + // If cliques have the same size prefer cliques with smaller stream id. + if (a.key.stream_id() < b.key.stream_id()) return true; + if (b.key.stream_id() < a.key.stream_id()) return false; + + // Prefer cliques with smaller id (comes earlier in execution order). + return a.id < b.id; + }); + + return cliques; + } + + absl::flat_hash_map cliques_; +}; + +absl::Status MaybeSyncAndProfile( + const ServiceExecutableRunOptions* run_options, + std::optional execution_timer, + se::Stream* stream_to_sync); + +absl::Status RendezvousAfterInitialization( + const ServiceExecutableRunOptions* run_options); + +absl::Status ExecuteThunks( + const std::string& module_name, ModuleIdentifier module_id, + const ThunkSequence& thunk_sequence, + Thunk::ExecutableSource executable_source, + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, bool block_host_until_done, + bool use_highest_priority_for_async_stream, + const absl::flat_hash_set& execution_stream_ids, + int64_t collective_max_nchannels, int64_t p2p_max_nchannels, + const ModuleAnnotations& module_annotations) { se::Stream* main_stream = run_options->stream(); se::StreamExecutor* executor = main_stream->parent(); stream_executor::StreamPriority stream_priority = @@ -211,37 +394,111 @@ Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, stream_priority = stream_executor::StreamPriority::Highest; } - // Create the needed streams to support NcclCollectiveThunk. + // Borrow streams required for NcclCollectiveThunk. absl::InlinedVector async_comms_streams( kAsyncStreamTotal, nullptr); - StatusOr> streams = run_options->BorrowStreams( - executor->device_ordinal(), kAsyncStreamTotal, stream_priority); + absl::StatusOr> streams = + run_options->BorrowStreams(executor->device_ordinal(), kAsyncStreamTotal, + stream_priority); if (streams.ok()) { for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { async_comms_streams[i] = streams->at(i).get(); } } - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); + + // Borrow stream for tracing command buffers. + se::Stream* command_buffer_trace_stream = nullptr; + absl::StatusOr borrowed_command_buffer_trace_stream = + run_options->BorrowStream(executor->device_ordinal()); + if (borrowed_command_buffer_trace_stream.ok()) { + command_buffer_trace_stream = borrowed_command_buffer_trace_stream->get(); + } + + // Borrow stream for additional compute streams + Thunk::ExecutionStreamIdMap additional_execution_streams; + std::vector additional_streams; + if (!execution_stream_ids.empty()) { + TF_ASSIGN_OR_RETURN(additional_streams, run_options->BorrowStreams( + executor->device_ordinal(), + execution_stream_ids.size())); + int64_t i = 0; + for (ExecutionStreamId stream_id : execution_stream_ids) { + additional_execution_streams[stream_id] = additional_streams.at(i).get(); + i++; + } + VLOG(2) << "Using " << additional_execution_streams.size() + << " additional compute streams."; + } tsl::profiler::TraceMe hlo_module_activity( [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, tsl::profiler::TraceMeLevel::kInfo); - ScopedAnnotationAlways annotation([&] { - std::string module_id_str; - if (module_id >= 0) { - module_id_str = absl::StrFormat(",program_id=%d", module_id); + std::optional execution_timer; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (ExecutionProfile* profile = + run_options->run_options().execution_profile(); + profile) { + TF_ASSIGN_OR_RETURN( + execution_timer, + se::gpu::GpuTimer::Create(main_stream, profile->warmup_run_executed())); + } +#endif + + // Parameters for executing collective operations. + TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams collective_params, + Thunk::CollectiveExecuteParams::Create( + *run_options, main_stream->parent()->device_ordinal(), + collective_max_nchannels, p2p_max_nchannels)); + + ResourceRequests resource_requests; + + { // Collect resource requirements from thunks. + Thunk::PrepareParams prepare_params{&collective_params}; + + tsl::profiler::TraceMe trace([&] { return "Thunks::Prepare"; }); + for (const std::unique_ptr& thunk : thunk_sequence) { + TF_RETURN_IF_ERROR(thunk->Prepare(prepare_params, resource_requests)); } - return absl::StrFormat("XlaModule:#hlo_module=%s%s#", module_name, - module_id_str); - }); + } + + // Acquire collective cliques requested by thunks. + TF_ASSIGN_OR_RETURN( + Thunk::CollectiveCliques collective_cliques, + resource_requests.AcquireCollectiveCliques(collective_params)); + + { // Initialize thunks using prepared resources before execution. + Thunk::InitializeParams initialize_params{ + executor, executable_source, &buffer_allocations, + main_stream, command_buffer_trace_stream, &collective_params, + &collective_cliques}; + + tsl::profiler::TraceMe trace([&] { return "Thunks::Initialize"; }); + for (const std::unique_ptr& thunk : thunk_sequence) { + TF_RETURN_IF_ERROR(thunk->Initialize(initialize_params)); + } + } + + // Maybe join a round of rendezvous after thunk initialization. We do this + // only in presence of collective cliques which means that we have collective + // operations in the XLA operations that tend to cause deadlocks. + if (!collective_cliques.empty()) { + TF_RETURN_IF_ERROR(RendezvousAfterInitialization(run_options)); + } + + // Prepare parameters for thunks execution. + Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( + *run_options, buffer_allocations, main_stream, + command_buffer_trace_stream, async_comms_streams, &collective_params, + &collective_cliques, additional_execution_streams); for (const std::unique_ptr& thunk : thunk_sequence) { // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the // module, we won't get any data, but that's probably an OK trade-off. - ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); - VLOG(2) << "Executing the thunk for " << thunk->profile_annotation(); + auto scoped_annotation = + GetKernelAnnotation(&module_annotations, thunk->profile_annotation()); + VLOG(3) << "Executing the thunk for " << thunk->profile_annotation(); if (NeedsAsyncCommsStream(*thunk)) { for (se::Stream* async_stream : async_comms_streams) { TF_RET_CHECK(async_stream != nullptr) @@ -249,47 +506,125 @@ Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, } } - Thunk::ExecuteParams thunk_params{*run_options, buffer_allocations, - main_stream, async_comms_streams}; - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(execute_params)); } - return MaybeSyncAndProfile(run_options, start_nanos, + return MaybeSyncAndProfile(run_options, std::move(execution_timer), block_host_until_done ? main_stream : nullptr); } -Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, - uint64_t start_nanos, - se::Stream* stream_to_sync = nullptr) { +namespace { +// Wrap RunId into a unique struct to guarantee we do not accidentally try to +// run multiple unrelated rendezvous for a same key. +struct InitializationKey { + RunId run_id; + + template + friend H AbslHashValue(H h, const InitializationKey& key) { + return H::combine(std::move(h), key.run_id); + } +}; + +bool operator==(const InitializationKey& a, const InitializationKey& b) { + return a.run_id == b.run_id; +} +} // namespace + +absl::Status RendezvousAfterInitialization( + const ServiceExecutableRunOptions* run_options) { + // Thunk initialization can allocate new control data structures on device + // that can lead to deadlocks if other replicas are executing concurrently + // (i.e. this happens if we try to instantiate CUDA graph when other replica + // is executing NCCL kernels). If we detect that we are running in multi-gpu + // setup we synchronize after first initialization to make sure that all + // replicas completed initialization process before we start execution. + auto* gpu_opts = run_options->run_options().gpu_executable_run_options(); + auto* device_assn = run_options->run_options().device_assignment(); + + // If we don't have Gpu executable options or device assignment it means we + // are running in a single Gpu config and don't need a rendezvous. + if (!gpu_opts || !device_assn) return absl::OkStatus(); + + // Assume that all participants execute locally first, if we have a local + // device id to global device id map we will use it to get the real number of + // participating local devices. + int64_t num_local_participants = + device_assn->replica_count() * device_assn->computation_count(); + + // Find what local devices are part of the device assignment. + if (gpu_opts->gpu_global_device_ids()) { + auto d2l_map = device_assn->GetDeviceToLogicalIdMap(); + + num_local_participants = 0; + for (auto& [local_id, global_id] : *gpu_opts->gpu_global_device_ids()) { + num_local_participants += d2l_map.contains(global_id); + } + + if (num_local_participants == 0) { + return absl::InternalError( + "Cound't find the number of local participants"); + } + } + + VLOG(1) << "Join thunks initialization rendezvous with " + << num_local_participants << " local participants" + << "; device_ordinal=" << run_options->device_ordinal(); + + tsl::profiler::TraceMe trace([&] { + return tsl::profiler::TraceMeEncode( + "RendezvousAfterInitialization", + {{"run_id", run_options->run_options().run_id().ToInt()}, + {"num_local_participants", num_local_participants}}); + }); + + auto rendezvous_key = InitializationKey{run_options->run_options().run_id()}; + auto rendezvous_name = absl::StrFormat( + "thunk initialization completion for device ordinal %d; run_id=%d", + run_options->device_ordinal(), + run_options->run_options().run_id().ToInt()); + + RendezvousSingle(rendezvous_name, rendezvous_key, num_local_participants, + absl::Seconds(10), absl::Seconds(30)); + + return absl::OkStatus(); +} + +absl::Status MaybeSyncAndProfile( + const ServiceExecutableRunOptions* run_options, + std::optional execution_timer, + se::Stream* stream_to_sync = nullptr) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // If we're measuring the execution time then it's important to queue the + // stop event before triggering any synchronization. + if (ExecutionProfile* profile = + run_options->run_options().execution_profile(); + profile) { + CHECK(execution_timer.has_value()); + TF_ASSIGN_OR_RETURN(absl::Duration elapsed, + execution_timer->GetElapsedDuration()); + profile->set_compute_time_ns( + std::max(absl::ToDoubleNanoseconds(elapsed), 1.0)); + } +#endif + // Make sure kernels are completed before deallocating temporary buffers or // the profiler state. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. if (stream_to_sync) { - Status block_status = stream_to_sync->BlockHostUntilDone(); + absl::Status block_status = stream_to_sync->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError( + return Internal( "Failed to complete all kernels launched on stream %p: %s", stream_to_sync, block_status.message()); } } - // FinishExecution() blocks until main_stream has completed if profiling is - // enabled; we therefore do not need to defer profile collection onto a - // stream. - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - - if (run_options->run_options().execution_profile()) { - ExecutionProfile* profile = run_options->run_options().execution_profile(); - const double nanoseconds = end_nanos - start_nanos; - profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); - } - - return OkStatus(); + return absl::OkStatus(); } } // namespace -StatusOr +absl::StatusOr GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { se::StreamExecutor* executor = stream->parent(); @@ -320,7 +655,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { int submitted_mem_copies = 0; for (const ConstantInfo& info : constants_) { - StatusOr global_status; + absl::StatusOr global_status; if (static_cast(module_handle)) { global_status = executor->GetUntypedSymbol(info.symbol_name, module_handle); @@ -337,8 +672,8 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { if (!info.content.span().empty()) { // This means the constant did not have an initializer in the PTX and // therefore must be initialized by XLA here. - stream->ThenMemcpy(&global, info.content.span().data(), - info.content.span().size()); + TF_RETURN_IF_ERROR(stream->Memcpy(&global, info.content.span().data(), + info.content.span().size())); submitted_mem_copies = true; } } else { @@ -375,7 +710,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { .first->second.get(); } -StatusOr GpuExecutable::BufferForAllocation( +absl::StatusOr GpuExecutable::BufferForAllocation( VariantArguments arguments, const GpuExecutable::BufferAllocToDeviceMemoryMap* globals, const BufferAllocation& allocation, @@ -417,8 +752,10 @@ StatusOr GpuExecutable::BufferForAllocation( const int64_t buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - StatusOr buffer = - memory_allocator->Allocate(device_ordinal, buffer_size); + absl::StatusOr buffer = + memory_allocator->Allocate(device_ordinal, buffer_size, + /*retry_on_failure=*/true, + /*memory_space=*/allocation.color()); if (!buffer.ok()) { return ResourceExhausted("%s\n%s\n", buffer.status().message(), buffer_assignment_->ToVerboseString( @@ -430,8 +767,8 @@ StatusOr GpuExecutable::BufferForAllocation( } } -static Status CheckAlignment(const BufferAllocation& allocation, - se::DeviceMemoryBase buffer, int arg_idx) { +static absl::Status CheckAlignment(const BufferAllocation& allocation, + se::DeviceMemoryBase buffer, int arg_idx) { const int64_t expected_alignment = [&] { if (allocation.is_entry_computation_parameter()) { return kEntryParameterAlignBytes; @@ -443,19 +780,18 @@ static Status CheckAlignment(const BufferAllocation& allocation, }(); if (!buffer.is_null() && reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { - return InternalError( + return Internal( "Address of buffer %d must be a multiple of %x, but " "was %p", arg_idx, expected_alignment, buffer.opaque()); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr GpuExecutable::GenerateBufferAllocations( +absl::StatusOr GpuExecutable::GenerateBufferAllocations( VariantArguments arguments, const GpuExecutable::BufferAllocToDeviceMemoryMap* globals, - se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal, - const BufferAllocToDeviceMemoryMap& buffer_alloc_to_persistent_memory_map) { + se::DeviceMemoryAllocator* const memory_allocator, int device_ordinal) { tsl::profiler::TraceMe hlo_module_activity( [&] { return std::string("Build buffer allocations"); }, tsl::profiler::TraceMeLevel::kInfo); @@ -466,30 +802,23 @@ StatusOr GpuExecutable::GenerateBufferAllocations( buffers.reserve(num_buffers); for (int64_t i = 0; i < num_buffers; ++i) { const BufferAllocation& allocation = allocations[i]; - // Check if the buffer is already stored as a persistent buffer. - se::DeviceMemoryBase buffer; - if (buffer_alloc_to_persistent_memory_map.contains(allocation.index())) { - buffer = buffer_alloc_to_persistent_memory_map.at(allocation.index()); - } else { - TF_ASSIGN_OR_RETURN( - buffer, BufferForAllocation(arguments, globals, allocation, - memory_allocator, device_ordinal, i)); - } - - buffers.push_back(buffer); - TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffer, i)); + TF_ASSIGN_OR_RETURN( + buffers.emplace_back(), + BufferForAllocation(arguments, globals, allocations[i], + memory_allocator, device_ordinal, i)); + TF_RETURN_IF_ERROR(CheckAlignment(allocation, buffers.back(), i)); } return {{buffers, device_ordinal, memory_allocator}}; } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +absl::StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { return ExecuteAsyncOnStreamImpl(run_options, absl::MakeSpan(arguments)); } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +absl::StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { @@ -498,67 +827,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( return out.ConsumeResult(); } -static Status ExecuteXlaRuntime(const std::string& module_name, - ModuleIdentifier module_id, - GpuRuntimeExecutable& gpu_runtime_executable, - const ServiceExecutableRunOptions* run_options, - const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - const BufferAllocation* temp_buffer, - bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock) { - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - - tsl::profiler::TraceMe hlo_module_activity( - [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, - tsl::profiler::TraceMeLevel::kInfo); - - ScopedAnnotationAlways annotation([&] { - std::string module_id_str; - if (module_id >= 0) { - module_id_str = absl::StrFormat(",program_id=%d", module_id); - } - return absl::StrFormat("XlaModule:#hlo_module=%s%s#", module_name, - module_id_str); - }); - - auto executed = gpu_runtime_executable.Execute( - run_options, asm_text, binary, buffer_allocations, gpu_lock, temp_buffer); - if (!executed.ok()) return executed; - - return MaybeSyncAndProfile( - run_options, start_nanos, - block_host_until_done ? run_options->stream() : nullptr); -} - -Status GpuExecutable::PopulatePersistentTempBuffers( - se::StreamExecutor* executor) { - auto search = persistent_temp_buffers_.find(executor); - if (search != persistent_temp_buffers_.end()) { - return OkStatus(); - } - - // Allocate persistent temp buffers. - BufferAllocToDeviceMemoryMap buffer_alloc_to_device_memory_map; - for (const BufferAllocation& allocation : GetAllocations()) { - if (!allocation.IsPreallocatedTempBuffer()) { - continue; - } - - const int64_t buffer_size = allocation.size(); - void* ptr = executor->UnifiedMemoryAllocate(buffer_size); - if (ptr) { - se::DeviceMemoryBase buffer(ptr, buffer_size); - buffer_alloc_to_device_memory_map[allocation.index()] = buffer; - } - } - - persistent_temp_buffers_[executor] = buffer_alloc_to_device_memory_map; - return OkStatus(); -} - -StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( +absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, VariantArguments arguments) { XLA_SCOPED_LOGGING_TIMER(absl::StrCat( @@ -573,21 +842,6 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( se::gpu::ScopedActivateExecutorContext activation(gpu_executor); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // If persistent buffers are enabled, the executable cannot execute - // concurrently, therefore performance can suffer under contention. - absl::MutexLockMaybe lock( - enable_persistent_temp_buffers_ ? &persistent_temp_buffers_mu_ : nullptr); - - // Map from buffer allocation to persistent temp buffers. It is empty if - // persistent temp buffer is not enabled. - BufferAllocToDeviceMemoryMap persistent_buffers_map = {}; - - if (enable_persistent_temp_buffers_) { - persistent_temp_buffers_mu_.AssertHeld(); - TF_RETURN_IF_ERROR(PopulatePersistentTempBuffers(executor)); - persistent_buffers_map = persistent_temp_buffers_[executor]; - } - // Force synchronous execution if the allocator requires it. const bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -596,12 +850,13 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( // that may be running during JIT compilation while allowing multiple XLA // computations to use the same GPU simultaneously. We do not add locking for // "recursive" invocations, which are done when holding a lock already. - NonAtomicallyUpgradeableRWLock gpu_lock(&GetGpuMutex(executor)); - std::optional exclusive_gpu_lock; - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->requires_exclusive_lock_on_gpu()) { - exclusive_gpu_lock.emplace(&gpu_lock); + std::variant gpu_lock( + std::in_place_index_t<0>{}, &GetGpuMutex(executor)); + + // Maybe update to a writer lock to get exlcusive acess to underlying GPU. + if (auto* gpu_opts = run_options->run_options().gpu_executable_run_options(); + gpu_opts && gpu_opts->requires_exclusive_lock_on_gpu()) { + gpu_lock.emplace<1>(&GetGpuMutex(executor)); } const GpuExecutable::BufferAllocToDeviceMemoryMap* globals; @@ -620,8 +875,8 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( TF_ASSIGN_OR_RETURN( BufferAllocations buffer_allocations, GenerateBufferAllocations(arguments, globals, memory_allocator, - device_ordinal, persistent_buffers_map)); - VLOG(2) << buffer_allocations.ToString(); + device_ordinal)); + VLOG(3) << buffer_allocations.ToString(); std::set buffers_in_result; const bool is_entire_tuple_contents_aliased = [&] { @@ -701,8 +956,10 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( "buffer is not donated; allocating a fresh buffer"; int64_t allocation_size = ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(output_shape_, index)); - StatusOr allocated_buffer = - memory_allocator->Allocate(device_ordinal, allocation_size); + absl::StatusOr allocated_buffer = + memory_allocator->Allocate(device_ordinal, allocation_size, + /*retry_on_failure=*/true, + /*memory_space=*/allocation->color()); if (!allocated_buffer.ok()) { return ResourceExhausted("%s\n%s\n", allocated_buffer.status().message(), @@ -714,8 +971,8 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( buffer_allocations.GetMutableDeviceAddress( output_info.allocation_index); CHECK_EQ(aliased_buffer.size(), result_buffer.size()); - run_options->stream()->ThenMemcpyD2D(&result_buffer, aliased_buffer, - aliased_buffer.size()); + TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D( + &result_buffer, aliased_buffer, aliased_buffer.size())); aliased_buffer = result_buffer; } } @@ -735,18 +992,11 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( buffers_in_result.insert(result_buffer); } - TF_RETURN_IF_ERROR(ExecuteThunksOrXlaRuntime( - run_options, buffer_allocations, block_host_until_done, gpu_lock)); + TF_RETURN_IF_ERROR(ExecuteThunksOrXlaRuntime(run_options, buffer_allocations, + block_host_until_done)); - // Free all temporary allocations. - std::vector non_persistent_allocations; - for (const BufferAllocation& allocation : GetAllocations()) { - if (!persistent_buffers_map.contains(allocation.index())) { - non_persistent_allocations.push_back(allocation); - } - } - TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result, - non_persistent_allocations)); + TF_RETURN_IF_ERROR( + buffer_allocations.TearDown(buffers_in_result, GetAllocations())); // Free allocations for arguments. if (auto args = std::get_if>(&arguments)) { @@ -755,50 +1005,40 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( return std::move(result); } -Status GpuExecutable::ExecuteThunksOrXlaRuntime( +absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime( const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock) { + const BufferAllocations& buffer_allocations, bool block_host_until_done) { TF_RETURN_IF_ERROR( CheckCompatibilityWithServiceExecutableRunOptions(run_options)); - // There isn't always an HLO module. - ModuleIdentifier unique_id = -1; - if (has_module()) { - unique_id = module().unique_id(); - } + ScopedAnnotation annotation([&] { return module_annotations_.top_level; }); + ScopedModuleAnnotations module_annotations(&module_annotations_); + + ModuleIdentifier unique_id = has_module() ? module().unique_id() : -1; if (thunks_) { - se::StreamExecutor* executor = run_options->stream()->parent(); - Thunk::ExecutableSource executable_source = {text_, binary_}; - for (const std::unique_ptr& thunk : *thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executor, executable_source)); - } + Thunk::ExecutableSource executable_source = {text_, binary_, + dnn_compiled_graphs_}; + int64_t collective_max_nchannels = + has_module() ? module_config() + .debug_options() + .xla_gpu_nccl_collective_max_nchannels() + : 0; + int64_t p2p_max_nchannels = + has_module() + ? module_config().debug_options().xla_gpu_nccl_p2p_max_nchannels() + : 0; return ExecuteThunks( - module_name_, unique_id, *thunks_, run_options, buffer_allocations, - block_host_until_done, + module_name_, unique_id, *thunks_, executable_source, run_options, + buffer_allocations, block_host_until_done, /*use_highest_priority_for_async_stream*/ has_module() ? module_config() .debug_options() .xla_gpu_enable_highest_priority_async_stream() - : false); - } - - // Match IrEmitter's temp buffer allocation for kernel launches. See - // IrEmitterUnnested::BuildKernelThunkImpl(). - const BufferAllocation* temp_buffer = nullptr; - for (const BufferAllocation& alloc : GetAllocations()) { - if (alloc.IsPreallocatedTempBuffer()) { - // Retrieve the first seen temp buffer. - if (temp_buffer == nullptr) temp_buffer = &alloc; - } - } - - if (gpu_runtime_executable_) { - return ExecuteXlaRuntime(module_name_, unique_id, *gpu_runtime_executable_, - run_options, text_, binary_, buffer_allocations, - temp_buffer, block_host_until_done, gpu_lock); + : false, + execution_stream_ids_, collective_max_nchannels, p2p_max_nchannels, + module_annotations_); } return FailedPrecondition("Expected XLA gpu executable is not supplied."); @@ -819,13 +1059,16 @@ int64_t GpuExecutable::SizeOfGeneratedCodeInBytes() const { return size; } -Status GpuExecutable::SetUpMlirAllocation( +absl::Status GpuExecutable::SetUpMlirAllocation( mlir::func::FuncOp func, llvm::ArrayRef buffer_sizes, std::vector* allocations, absl::flat_hash_map* output_info, Shape* output_shape) { for (int i = 0; i < buffer_sizes.size(); i++) { - allocations->emplace_back(i, buffer_sizes[i], 0); + // This code path is taken when using the non-thunk based runtime. Memory + // space is being set to 0 for all allocations. We need to copy over the + // value from BufferAssignment instead. + allocations->emplace_back(i, buffer_sizes[i], /*memory_space=*/0); } for (int i = 0; i < func.getNumArguments(); i++) { @@ -889,10 +1132,10 @@ Status GpuExecutable::SetUpMlirAllocation( .getValue() .str())); - return OkStatus(); + return absl::OkStatus(); } -StatusOr> +absl::StatusOr> GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) { const HloInstruction* root = hlo_module.entry_computation()->root_instruction(); @@ -909,7 +1152,7 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) { OutputInfoMap output; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( root->shape(), - [&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> Status { + [&](const Shape& /*sub_shape*/, const ShapeIndex& index) -> absl::Status { const auto& sources = root_value_set.element(index); // The points-to set is unambiguous so the set should be a // singleton. That is, we know exactly which instruction @@ -928,215 +1171,10 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) { output[index].alias_config = hlo_module.input_output_alias_config().GetAliasedParameter(index); - return OkStatus(); + return absl::OkStatus(); })); return output; } -GpuExecutable::GpuExecutable( - std::shared_ptr hlo_module, std::string asm_text, - std::vector binary, std::vector constants, - se::GpuComputeCapability gpu_version, absl::string_view module_name, - Shape xla_output_shape, std::vector allocations, - absl::flat_hash_map output_info, - std::unique_ptr gpu_runtime_executable) - : Executable(std::move(hlo_module)), - text_(std::move(asm_text)), - binary_(std::move(binary)), - gpu_version_(gpu_version), - gpu_runtime_executable_(std::move(gpu_runtime_executable)), - module_name_(module_name), - output_shape_(xla_output_shape), - allocations_(std::move(allocations)), - constants_(std::move(constants)), - output_info_(std::move(output_info)), - enable_debug_info_manager_(true) { - if (has_module()) { - XlaDebugInfoManager::Get()->RegisterModule(shared_module(), - BufferAssignmentProto()); - } -} - -// Returns a list of functions exported from the `module` that should be loaded -// from the object file. Entrypoint functions always loaded with ordinal 0. -static StatusOr> -GetFunctionsToLoad(mlir::ModuleOp module, std::string_view entry) { - std::vector functions; - - // Use canonical type converter because we currently do not support any - // user-defined types in XLA:GPU executables. - runtime::TypeConverter type_converter; - - // Converts function type and adds load function metadata. In XLA:GPU exported - // function runtime signature is the same as regular signature with an extra - // execution context argument at index 0. - auto convert = [&](mlir::func::FuncOp func) -> Status { - auto signature = type_converter.Convert(func.getFunctionType()); - if (!signature.ok()) - return InternalError("Failed to convert entry function type: %s", - signature.status().message()); - - // TODO(ezhulenev): Copy `signature` once FunctionType is copyable. - auto rt_signature = type_converter.Convert(func.getFunctionType()); - rt_signature->insert_operand( - 0, std::make_unique()); - - functions.push_back({func.getName().str(), std::move(*signature), - std::move(*rt_signature)}); - - return OkStatus(); - }; - - mlir::SymbolTable sym_table(module); - - // Load entrypoint function first at ordinal 0. - TF_CHECK_OK(convert(module.lookupSymbol(entry))); - - // Load all functions explicitly exported from the module (in XLA:GPU it's - // always CUDA graph capture functions). We explicitly sort them by ordinal, - // to make sure they are loaded in correct order. - auto export_ops = llvm::to_vector(module.getOps()); - llvm::sort(export_ops, [](runtime::ExportOp a, runtime::ExportOp b) { - return a.getOrdinal()->getSExtValue() < b.getOrdinal()->getSExtValue(); - }); - for (runtime::ExportOp exported : export_ops) { - TF_CHECK_OK(convert( - sym_table.lookup(exported.getFunctionRef()))); - } - - return functions; -} - -// Get arguments buffer sizes from the entry function signature. -static StatusOr> GetBufferSizes(runtime::FunctionType& f) { - std::vector buffer_sizes; - for (unsigned i = 0; i < f.num_operands(); ++i) { - auto* memref = llvm::dyn_cast(f.operand(i)); - - // Entry function argument must be a statically shaped 1d I8 memref. - if (memref == nullptr || memref->element_type() != PrimitiveType::S8 || - memref->rank() != 1 || runtime::MemrefType::IsDynamic(memref->size(0))) - return InternalError("Illegal buffer argument type: %s", - f.operand(0)->ToString()); - - buffer_sizes.push_back(memref->size(0)); - } - return buffer_sizes; -} - -// TODO(ezhulenev): This is a copy of `GetAllocationIndices` from -// `mlir/backends/gpu/transforms/passes.h`. We can't depend on that file because -// of a dependency cycle, and this is a short term work around the cuda graph -// capture bug. This code should not survive beyond Q1 2024. -static std::vector> GetAllocationIndices( - mlir::ModuleOp module) { - std::vector> res; - - mlir::SymbolTable sym_table(module); - for (auto op : module.getOps()) { - unsigned ordinal = *op.ordinal(); - if (ordinal >= res.size()) res.resize(ordinal + 1); - - auto func = sym_table.lookup(op.getFunctionRef()); - res[ordinal].resize(func.getNumArguments(), -1); - - for (unsigned i = 0; i < func.getNumArguments(); ++i) { - auto idx = - func.getArgAttrOfType(i, "rt.allocation_index"); - if (idx) res[ordinal][i] = idx.getInt(); - } - } - - return res; -} - -StatusOr> GpuExecutable::LoadFromObjFile( - std::shared_ptr hlo_module, absl::string_view obj_file, - absl::string_view mlir_module, DebugOptions debug_options, - absl::string_view asm_text, absl::string_view binary, - std::vector constants, se::GpuComputeCapability gpu_version, - se::StreamExecutor* executor) { - VLOG(1) << "Load serialized Gpu executable from object file: module=" - << hlo_module->name(); - - std::string_view entry = hlo_module->entry_computation()->name(); - - // Load MLIR module behind the compiled object file to recover XLA allocations - // and output info details. Also recover buffer sizes from the entrypoint - // function signature. - mlir::MLIRContext context; - runtime::AppendXlaGpuDialectRegistry(context); - - auto module = mlir::parseSourceString(mlir_module, &context); - if (!module) return InternalError("Failed to parse AOT compiled module"); - - // Get the list of functions to be loaded from the object file. - TF_ASSIGN_OR_RETURN(std::vector functions, - GetFunctionsToLoad(*module, entry)); - VLOG(2) << "Found " << functions.size() << " functions to load"; - - // Get the buffer sizes from the entry function signature. - TF_ASSIGN_OR_RETURN(std::vector buffer_sizes, - GetBufferSizes(functions[0].signature)); - - // Get allocation indices from graph capture functions. - auto allocation_indices = GetAllocationIndices(*module); - - // Get the XLA module entrypoint function. - auto func = mlir::cast(module->lookupSymbol(entry)); - - // Infer XLA allocations and output info from the MLIR module. - std::vector allocations; - absl::flat_hash_map output_info; - Shape result_xla_shape; - TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations, - &output_info, &result_xla_shape)); - - // Create a named buffer from compiled object file. - llvm::StringRef data(obj_file.data(), obj_file.size()); - auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name()); - - auto symbol_map = runtime::ToSymbolsBinding(RegisterXlaGpuRuntimeCustomCalls, - RegisterXlaGpuTypeIdNames); - - // Load XLA Runtime executable from an object file, and link it with Gpu - // runtime intrinsics implementing Gpu custom calls. - auto executable = runtime::Executable::LoadFromObjFile( - hlo_module->name(), std::move(buffer), std::move(functions), symbol_map); - - if (!executable.ok()) - return InternalError("Failed to load XLA Runtime executable: %s", - executable.status().message()); - - // Move runtime::Executable ownership to the GpuRuntimeExecutable. - TF_ASSIGN_OR_RETURN(auto gpu_runtime_executable, - GpuRuntimeExecutable::Create( - hlo_module->name(), std::move(buffer_sizes), - std::move(allocation_indices), std::move(*executable), - std::move(debug_options))); - - // Construct GpuExecutable for the loaded XLA Runtime executable. - std::string name = hlo_module->name(); - std::string asm_text_string = std::string(asm_text); - std::vector binary_vector(binary.begin(), binary.end()); - return std::unique_ptr(new GpuExecutable( - std::move(hlo_module), std::move(asm_text_string), - std::move(binary_vector), std::move(constants), gpu_version, name, - result_xla_shape, std::move(allocations), std::move(output_info), - std::move(gpu_runtime_executable))); -} - -StatusOr GpuExecutable::GetObjFile() const { - if (!gpu_runtime_executable_) - return Internal("gpu_runtime_executable is null"); - return gpu_runtime_executable_->GetObjFile(); -} - -StatusOr GpuExecutable::GetMlirModule() const { - if (!gpu_runtime_executable_) - return Internal("gpu_runtime_executable is null"); - return gpu_runtime_executable_->GetMlirModule(); -} - } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_executable.h b/xla/service/gpu/gpu_executable.h index 060583867ed96..e2e0daafe7554 100644 --- a/xla/service/gpu/gpu_executable.h +++ b/xla/service/gpu/gpu_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,33 +17,38 @@ limitations under the License. #define XLA_SERVICE_GPU_GPU_EXECUTABLE_H_ #include -#include -#include #include #include #include -#include -#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/executable.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/executable.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo_execution_profile.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" -#include "xla/statusor.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" @@ -61,7 +66,6 @@ bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config); class GpuExecutable : public Executable { public: using OwnedThunkSequence = std::unique_ptr; - using OwnedGpuRuntimeProgram = std::unique_ptr; struct ConstantInfo { std::string symbol_name; @@ -84,18 +88,15 @@ class GpuExecutable : public Executable { struct Params { std::string asm_text; std::vector binary; + Thunk::BinaryMap dnn_compiled_graphs; se::GpuComputeCapability gpu_version; - // The GpuExecutable will either execute Thunks, XLA runtime executable - // (native function) or experimental XLA runtime executable (IREE VM - // function) depending on which is supplied. - std::variant executable; + OwnedThunkSequence executable; std::vector constants; absl::flat_hash_map output_info; std::string module_name; xla::Shape output_shape; std::optional> mlir_allocations; std::unique_ptr buffer_assignment; - bool enable_persistent_temp_buffers; int64_t debug_buffer_assignment_show_max; std::unique_ptr debug_module = nullptr; bool enable_debug_info_manager = true; @@ -103,37 +104,13 @@ class GpuExecutable : public Executable { // Analyze the entry function to construct buffer allocation and other output // information. - // - // TODO(ezhulenev): Once Xla runtime enabled by default, hide this method as - // an implementation detail of GpuExecutable. - static Status SetUpMlirAllocation( + static absl::Status SetUpMlirAllocation( mlir::func::FuncOp func, llvm::ArrayRef buffer_sizes, std::vector* allocations, absl::flat_hash_map* output_info, Shape* output_shape); - // Returns an Executable that is loaded from an object file (XLA program - // compiled to a native function using the XLA Runtime stack). - static StatusOr> LoadFromObjFile( - std::shared_ptr hlo_module, absl::string_view obj_file, - absl::string_view mlir_module, DebugOptions debug_options, - absl::string_view asm_text, absl::string_view binary, - std::vector constants, se::GpuComputeCapability gpu_version, - stream_executor::StreamExecutor* executor); - - // Constructor to use when loading a GpuExecutable from an object file (native - // function compiled for XLA Runtime). Omits setting class members that aren't - // used in XLA Runtime execution mode. - GpuExecutable(std::shared_ptr hlo_module, std::string asm_text, - std::vector binary, - std::vector constants, - se::GpuComputeCapability gpu_version, - absl::string_view module_name, Shape xla_output_shape, - std::vector allocations, - absl::flat_hash_map output_info, - std::unique_ptr runtime_executable); - - static StatusOr> Create(Params params); + static absl::StatusOr> Create(Params params); ~GpuExecutable() override; int64_t SizeOfGeneratedCodeInBytes() const override; @@ -162,21 +139,25 @@ class GpuExecutable : public Executable { // compiled. const std::vector& binary() const { return binary_; } + const Thunk::BinaryMap& dnn_compiled_graphs() const { + return dnn_compiled_graphs_; + } + // ExecuteAsyncOnStream will fail if the compute capability of the stream // doesn't match the compute capability passed to this object's constructor. - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; using VariantArguments = std::variant, absl::Span>; - StatusOr ExecuteAsyncOnStreamImpl( + absl::StatusOr ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, VariantArguments arguments); @@ -198,9 +179,6 @@ class GpuExecutable : public Executable { const std::vector& constants() const { return constants_; } - StatusOr GetObjFile() const; - StatusOr GetMlirModule() const; - const BufferAssignment* buffer_assignment() const { return buffer_assignment_.get(); } @@ -214,10 +192,9 @@ class GpuExecutable : public Executable { // clients, such as Tensorflow, that use a single stream of execution for // computations, and allow host-side deallocation from the allocator before // GPU execution completes. - Status ExecuteThunksOrXlaRuntime( + absl::Status ExecuteThunksOrXlaRuntime( const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, bool block_host_until_done, - NonAtomicallyUpgradeableRWLock& gpu_lock); + const BufferAllocations& buffer_allocations, bool block_host_until_done); using BufferAllocToDeviceMemoryMap = absl::flat_hash_map; @@ -232,27 +209,20 @@ class GpuExecutable : public Executable { // The returned map is cached. If the above process has already been run for // the given stream, it is skipped and the cached map is immediately returned // instead. - StatusOr ResolveConstantGlobals( + absl::StatusOr ResolveConstantGlobals( stream_executor::Stream* stream); - // Allocate the temp buffers and store them with the GpuExecutable. This - // function only allocates buffers on the first run for each executor. - Status PopulatePersistentTempBuffers(se::StreamExecutor* executor) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(persistent_temp_buffers_mu_); - // GpuExecutable check with either AMD's ISA version, or Nvidia's major minor // version for compute capability, depending on the hardware. - Status CheckCompatibilityWithServiceExecutableRunOptions( + absl::Status CheckCompatibilityWithServiceExecutableRunOptions( const ServiceExecutableRunOptions* run_options); - StatusOr GenerateBufferAllocations( + absl::StatusOr GenerateBufferAllocations( VariantArguments arguments, const GpuExecutable::BufferAllocToDeviceMemoryMap* globals, - se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, - const BufferAllocToDeviceMemoryMap& - buffer_alloc_to_persistent_memory_map); + se::DeviceMemoryAllocator* memory_allocator, int device_ordinal); - StatusOr BufferForAllocation( + absl::StatusOr BufferForAllocation( VariantArguments arguments, const GpuExecutable::BufferAllocToDeviceMemoryMap* globals, const BufferAllocation& allocation, @@ -274,11 +244,10 @@ class GpuExecutable : public Executable { // compute_capability_. // // May be empty, in which case we leave compilation up to the GPU driver. -#ifdef TENSORFLOW_USE_ROCM std::vector binary_; -#else - const std::vector binary_; -#endif + + Thunk::BinaryMap dnn_compiled_graphs_; + // The GPU version for compute compatibility check. se::GpuComputeCapability gpu_version_; @@ -286,10 +255,8 @@ class GpuExecutable : public Executable { // IrEmitter (null if XLA:GPU runtime is enabled). OwnedThunkSequence thunks_; - // Gpu runtime executable that encapsulates all the state for running Gpu - // runtime custom calls implementing gpu abstraction layer (available only if - // Xla runtime is enabled). - std::unique_ptr gpu_runtime_executable_; + // Additional execution streams requested by `thunks_`. + absl::flat_hash_set execution_stream_ids_; std::string module_name_; @@ -307,15 +274,12 @@ class GpuExecutable : public Executable { // This object is also used for dumping debug info. std::unique_ptr buffer_assignment_; - bool enable_persistent_temp_buffers_ = false; - - absl::Mutex persistent_temp_buffers_mu_; - // Temp buffers can be allocated once and be reused whenever the GpuExecutable - // is executed. The persistent temp buffer is stored in a map that maps from - // a BufferAllocation to the temp buffer. - absl::flat_hash_map - persistent_temp_buffers_ ABSL_GUARDED_BY(persistent_temp_buffers_mu_); + ModuleAnnotations module_annotations_ = [this] { + if (has_module()) { + return ModuleAnnotations(module()); + } + return ModuleAnnotations(module_name_); + }(); int64_t debug_buffer_assignment_show_max_; @@ -340,7 +304,7 @@ class GpuExecutable : public Executable { GpuExecutable& operator=(const GpuExecutable&) = delete; }; -StatusOr> +absl::StatusOr> GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment); } // namespace gpu diff --git a/xla/service/gpu/gpu_executable_run_options.cc b/xla/service/gpu/gpu_executable_run_options.cc index 62c1ddea06350..da269d10f8fa2 100644 --- a/xla/service/gpu/gpu_executable_run_options.cc +++ b/xla/service/gpu/gpu_executable_run_options.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ limitations under the License. #include #include +#include -#include "absl/algorithm/container.h" -#include "xla/status_macros.h" +#include "xla/executable_run_options.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/nccl_clique_key.h" namespace xla { namespace gpu { @@ -35,44 +37,15 @@ GpuExecutableRunOptions::gpu_global_device_ids() const { return gpu_global_device_ids_; } -GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_unique_id_callback( - NcclUniqueIdCallback nccl_unique_id_callback) { - nccl_unique_id_callback_ = std::move(nccl_unique_id_callback); +GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_clique_id_callback( + NcclCliqueIdCallback nccl_clique_id_callback) { + nccl_clique_id_callback_ = std::move(nccl_clique_id_callback); return *this; } -const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback() +const NcclCliqueIdCallback& GpuExecutableRunOptions::nccl_clique_id_callback() const { - return nccl_unique_id_callback_; -} - -NcclExecuteParams::NcclExecuteParams( - const ServiceExecutableRunOptions& run_options, - se::StreamExecutor* stream_executor) - : stream_executor(stream_executor), - run_id(run_options.run_options().run_id()), - device_assn(run_options.run_options().device_assignment()) { - const GpuExecutableRunOptions* gpu_options = - run_options.run_options().gpu_executable_run_options(); - gpu_global_device_ids = gpu_options && gpu_options->gpu_global_device_ids() - ? &*gpu_options->gpu_global_device_ids() - : nullptr; - nccl_unique_id_callback = - gpu_options && gpu_options->nccl_unique_id_callback() - ? &gpu_options->nccl_unique_id_callback() - : nullptr; -} - -StatusOr NcclExecuteParams::GetGlobalDeviceId() const { - int64_t local_device_ordinal = stream_executor->device_ordinal(); - if (gpu_global_device_ids) { - auto it = gpu_global_device_ids->find(local_device_ordinal); - TF_RET_CHECK(it != gpu_global_device_ids->end()) << local_device_ordinal; - return it->second; - } else { - // No local -> global mapping was provided; assume the identity mapping. - return GlobalDeviceId(local_device_ordinal); - } + return nccl_clique_id_callback_; } } // namespace gpu diff --git a/xla/service/gpu/gpu_executable_run_options.h b/xla/service/gpu/gpu_executable_run_options.h index ffb8785bf7cce..0348152549929 100644 --- a/xla/service/gpu/gpu_executable_run_options.h +++ b/xla/service/gpu/gpu_executable_run_options.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,55 +16,16 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_ #define XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_ -#include #include #include -#include -#include -#include +#include "xla/executable_run_options.h" #include "xla/service/global_device_id.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/statusor.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/service/gpu/nccl_clique_key.h" namespace xla { namespace gpu { -// Key for naming up a particular NCCL clique. This is just a set of unique -// device IDs (i.e. GPU IDs) and a stream_id. The device IDs must be global -// within a cluster. The stream_id is used to create different NCCL clique and -// communicators for collectives executed on different streams within an -// executable. -class NcclCliqueKey { - public: - explicit NcclCliqueKey(std::vector devices, - int64_t stream_id = 0) - : devices_(std::move(devices)), stream_id_(stream_id) {} - - template - friend H AbslHashValue(H h, const NcclCliqueKey& k) { - return H::combine(std::move(h), k.devices_, k.stream_id_); - } - friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { - return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_; - } - - const std::vector& devices() const { return devices_; } - - std::string ToString() const { - return absl::StrCat("stream[", stream_id_, "]", - GlobalDeviceIdsToString(devices_)); - } - - private: - const std::vector devices_; - const int64_t stream_id_; -}; - -using NcclUniqueIdCallback = - std::function(const NcclCliqueKey&)>; - // GPU-specific executable options. // We keep these separate from ExecutableRunOptions to avoid adding // dependencies to ExecutableRunOptions. @@ -81,9 +42,9 @@ class GpuExecutableRunOptions { // Callback that returns a ncclUniqueId encoded as a string for a group of // communicating GPU devices. Used only on NVidia GPUs. - GpuExecutableRunOptions& set_nccl_unique_id_callback( - NcclUniqueIdCallback nccl_unique_id_callback); - const NcclUniqueIdCallback& nccl_unique_id_callback() const; + GpuExecutableRunOptions& set_nccl_clique_id_callback( + NcclCliqueIdCallback nccl_clique_id_callback); + const NcclCliqueIdCallback& nccl_clique_id_callback() const; // Whether the run requires an exclusive lock on the GPU. bool requires_exclusive_lock_on_gpu() const { @@ -100,31 +61,29 @@ class GpuExecutableRunOptions { return enable_mock_nccl_collectives_; } - // Enable mocking nccl collective operations on the GPU + // Enables mocking nccl collective operations on the GPU. GpuExecutableRunOptions& set_enable_mock_nccl_collectives() { enable_mock_nccl_collectives_ = true; return *this; } + enum class MockNcclTopoModel { kGCPA3, kNvidia }; + // Gets the nccl network topology used in mocking calls. + MockNcclTopoModel mock_nccl_topo_model() const { + return mock_nccl_topo_model_; + } + GpuExecutableRunOptions& set_mock_nccl_topo_model( + MockNcclTopoModel mock_nccl_topo_model) { + mock_nccl_topo_model_ = mock_nccl_topo_model; + return *this; + } + private: bool requires_exclusive_lock_on_gpu_ = false; bool enable_mock_nccl_collectives_ = false; + MockNcclTopoModel mock_nccl_topo_model_ = MockNcclTopoModel::kGCPA3; std::optional> gpu_global_device_ids_; - NcclUniqueIdCallback nccl_unique_id_callback_; -}; - -// NCCL-related execution parameters. -struct NcclExecuteParams { - NcclExecuteParams(const ServiceExecutableRunOptions& run_options, - se::StreamExecutor* stream_executor); - - se::StreamExecutor* stream_executor; - RunId run_id; - const DeviceAssignment* device_assn; // never null - const std::map* gpu_global_device_ids; // may be null - const NcclUniqueIdCallback* nccl_unique_id_callback; // may be null - - StatusOr GetGlobalDeviceId() const; + NcclCliqueIdCallback nccl_clique_id_callback_; }; } // namespace gpu diff --git a/xla/service/gpu/gpu_flash_attn.cc b/xla/service/gpu/gpu_flash_attn.cc new file mode 100644 index 0000000000000..91e6152282712 --- /dev/null +++ b/xla/service/gpu/gpu_flash_attn.cc @@ -0,0 +1,901 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_flash_attn.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "cutlass/numeric_types.h" +#include "flash_attn/flash.h" +#include "flash_attn/static_switch.h" +#include "flash_attn/utils.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/gpu/gpu_stream.h" + +namespace xla { +namespace gpu { + +const absl::string_view kGpuFlashAttnFwdCallTarget = "__gpu$flash_attn_fwd"; +const absl::string_view kGpuFlashAttnBwdCallTarget = "__gpu$flash_attn_bwd"; +const absl::string_view kGpuFlashAttnVarLenFwdCallTarget = + "__gpu$flash_attn_varlen_fwd"; +const absl::string_view kGpuFlashAttnVarLenBwdCallTarget = + "__gpu$flash_attn_varlen_bwd"; + +bool IsFwdCustomCallToFlashAttn(const HloInstruction &hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const std::string &target = hlo.custom_call_target(); + return target == kGpuFlashAttnFwdCallTarget || + target == kGpuFlashAttnVarLenFwdCallTarget; +} + +bool IsBwdCustomCallToFlashAttn(const HloInstruction &hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const std::string &target = hlo.custom_call_target(); + return target == kGpuFlashAttnBwdCallTarget || + target == kGpuFlashAttnVarLenBwdCallTarget; +} + +bool IsCustomCallToFlashAttn(const HloInstruction &hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const std::string &target = hlo.custom_call_target(); + return target == kGpuFlashAttnFwdCallTarget || + target == kGpuFlashAttnVarLenFwdCallTarget || + target == kGpuFlashAttnBwdCallTarget || + target == kGpuFlashAttnVarLenBwdCallTarget; +} + +absl::StatusOr GetFlashAttnKind( + const HloCustomCallInstruction *instr) { + absl::string_view target = instr->custom_call_target(); + if (target == kGpuFlashAttnFwdCallTarget) { + return FlashAttnKind::kForward; + } + if (target == kGpuFlashAttnVarLenFwdCallTarget) { + return FlashAttnKind::kVarLenForward; + } + if (target == kGpuFlashAttnBwdCallTarget) { + return FlashAttnKind::kBackward; + } + if (target == kGpuFlashAttnVarLenBwdCallTarget) { + return FlashAttnKind::kVarLenBackward; + } + return Internal("Unexpected call target: %s", target); +} + +absl::StatusOr FlashAttnConfig::For( + const Shape &query_shape, const Shape &key_shape, const Shape &value_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, const Shape &output_shape, + const Shape &softmax_lse_shape, float dropout_rate, float scale, + bool is_causal, const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k) { + PrimitiveType type = query_shape.element_type(); + + CHECK(type == PrimitiveType::F16 || type == PrimitiveType::BF16); + CHECK(type == key_shape.element_type() && + type == value_shape.element_type() && + type == output_shape.element_type()); + + FlashAttnConfig config; + config.type = type; + + TF_ASSIGN_OR_RETURN(se::dnn::DataType elem_type, + GetDNNDataTypeFromPrimitiveType(type)); + TF_ASSIGN_OR_RETURN(se::dnn::DataType f32_type, + GetDNNDataTypeFromPrimitiveType(PrimitiveType::F32)); + + config.query_desc = + se::dnn::TensorDescriptor::For(elem_type, query_shape.dimensions(), + query_shape.layout().minor_to_major()); + config.key_desc = se::dnn::TensorDescriptor::For( + elem_type, key_shape.dimensions(), key_shape.layout().minor_to_major()); + config.value_desc = + se::dnn::TensorDescriptor::For(elem_type, value_shape.dimensions(), + value_shape.layout().minor_to_major()); + bool is_varlen = cu_seqlens_query_shape.has_value(); + CHECK(is_varlen == cu_seqlens_key_shape.has_value() && + is_varlen == max_seqlen_q.has_value() && + is_varlen == max_seqlen_k.has_value()); + if (is_varlen) { + CHECK(cu_seqlens_query_shape->element_type() == PrimitiveType::S32); + CHECK(cu_seqlens_key_shape->element_type() == PrimitiveType::S32); + TF_ASSIGN_OR_RETURN(se::dnn::DataType cu_type, + GetDNNDataTypeFromPrimitiveType(PrimitiveType::S32)); + config.cu_seqlens_query_desc = se::dnn::TensorDescriptor::For( + cu_type, cu_seqlens_query_shape->dimensions(), + cu_seqlens_query_shape->layout().minor_to_major()); + config.cu_seqlens_key_desc = se::dnn::TensorDescriptor::For( + cu_type, cu_seqlens_key_shape->dimensions(), + cu_seqlens_key_shape->layout().minor_to_major()); + config.max_seqlen_q = max_seqlen_q; + config.max_seqlen_k = max_seqlen_k; + } + + if (alibi_slopes_shape.has_value()) { + config.alibi_slopes_desc = se::dnn::TensorDescriptor::For( + f32_type, alibi_slopes_shape.value().dimensions(), + alibi_slopes_shape.value().layout().minor_to_major()); + } + + config.output_desc = + se::dnn::TensorDescriptor::For(elem_type, output_shape.dimensions(), + output_shape.layout().minor_to_major()); + config.softmax_lse_desc = se::dnn::TensorDescriptor::For( + f32_type, softmax_lse_shape.dimensions(), + softmax_lse_shape.layout().minor_to_major()); + + config.dropout_rate = dropout_rate; + config.scale = scale; + config.is_causal = is_causal; + return config; +} + +absl::StatusOr FlashAttnFwdConfig::For( + const Shape &query_shape, const Shape &key_shape, const Shape &value_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, const Shape &output_shape, + const Shape &softmax_lse_shape, const std::optional &s_dmask_shape, + float dropout_rate, float scale, bool is_causal, + const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k) { + TF_ASSIGN_OR_RETURN( + FlashAttnFwdConfig config, + FlashAttnConfig::For(query_shape, key_shape, value_shape, + cu_seqlens_query_shape, cu_seqlens_key_shape, + alibi_slopes_shape, output_shape, softmax_lse_shape, + dropout_rate, scale, is_causal, max_seqlen_q, + max_seqlen_k)); + + if (s_dmask_shape.has_value()) { + TF_ASSIGN_OR_RETURN(se::dnn::DataType elem_type, + GetDNNDataTypeFromPrimitiveType(config.type)); + + config.s_dmask_desc = se::dnn::TensorDescriptor::For( + elem_type, s_dmask_shape->dimensions(), + s_dmask_shape->layout().minor_to_major()); + } + + return config; +} + +absl::StatusOr FlashAttnBwdConfig::For( + const Shape &grad_output_shape, const Shape &query_shape, + const Shape &key_shape, const Shape &value_shape, const Shape &output_shape, + const Shape &softmax_lse_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, + const Shape &grad_query_shape, const Shape &grad_key_shape, + const Shape &grad_value_shape, const Shape &grad_softmax_shape, + float dropout_rate, float scale, bool is_causal, bool deterministic, + const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k) { + TF_ASSIGN_OR_RETURN( + FlashAttnBwdConfig config, + FlashAttnConfig::For(query_shape, key_shape, value_shape, + cu_seqlens_query_shape, cu_seqlens_key_shape, + alibi_slopes_shape, output_shape, softmax_lse_shape, + dropout_rate, scale, is_causal, max_seqlen_q, + max_seqlen_k)); + + TF_ASSIGN_OR_RETURN(se::dnn::DataType elem_type, + GetDNNDataTypeFromPrimitiveType(config.type)); + + config.grad_output_desc = se::dnn::TensorDescriptor::For( + elem_type, grad_output_shape.dimensions(), + grad_output_shape.layout().minor_to_major()); + + config.grad_query_desc = se::dnn::TensorDescriptor::For( + elem_type, grad_query_shape.dimensions(), + grad_query_shape.layout().minor_to_major()); + config.grad_key_desc = + se::dnn::TensorDescriptor::For(elem_type, grad_key_shape.dimensions(), + grad_key_shape.layout().minor_to_major()); + config.grad_value_desc = se::dnn::TensorDescriptor::For( + elem_type, grad_value_shape.dimensions(), + grad_value_shape.layout().minor_to_major()); + config.grad_softmax_desc = se::dnn::TensorDescriptor::For( + se::dnn::ToDataType::value, grad_softmax_shape.dimensions(), + grad_softmax_shape.layout().minor_to_major()); + + config.deterministic = deterministic; + + return config; +} + +static void set_params_fprop( + Flash_fwd_params ¶ms, const FlashAttnConfig &config, + se::DeviceMemoryBase query_buffer, se::DeviceMemoryBase key_buffer, + se::DeviceMemoryBase value_buffer, se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase softmax_lse_buffer, + std::optional s_dmask_buffer, + // sizes + const size_t batch_size, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t num_heads, const size_t num_heads_k, const size_t head_size, + const size_t head_size_rounded, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_k, float p_dropout, float softmax_scale, int window_size_left, + int window_size_right, bool seqlenq_ngroups_swapped = false) { + // Reset the parameters + params = {}; + + params.is_bf16 = config.type == PrimitiveType::BF16; + + // Set the pointers and strides. + params.q_ptr = query_buffer.opaque(); + params.k_ptr = key_buffer.opaque(); + params.v_ptr = value_buffer.opaque(); + params.o_ptr = output_buffer.opaque(); + + // All stride are in elements, not bytes. + const auto &q_strides = config.query_desc.GetLogicalStrides(); + const auto &k_strides = config.key_desc.GetLogicalStrides(); + const auto &v_strides = config.value_desc.GetLogicalStrides(); + const auto &o_strides = config.output_desc.GetLogicalStrides(); + + // sequence length + params.q_row_stride = q_strides[q_strides.size() - 3]; + params.k_row_stride = k_strides[k_strides.size() - 3]; + params.v_row_stride = v_strides[v_strides.size() - 3]; + params.o_row_stride = o_strides[o_strides.size() - 3]; + + // head number + params.q_head_stride = q_strides[q_strides.size() - 2]; + params.k_head_stride = k_strides[k_strides.size() - 2]; + params.v_head_stride = v_strides[v_strides.size() - 2]; + params.o_head_stride = o_strides[o_strides.size() - 2]; + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q_strides[0]; + params.k_batch_stride = k_strides[0]; + params.v_batch_stride = v_strides[0]; + params.o_batch_stride = o_strides[0]; + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = + s_dmask_buffer.has_value() ? s_dmask_buffer->opaque() : nullptr; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_buffer.opaque(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to + // float to compare. [Minor] We want to round down since when we do the + // comparison we use <= instead of < params.p_dropout_in_uint = + // uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * + // 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + CHECK(p_dropout < 1.f); +#ifdef FLASHATTENTION_DISABLE_DROPOUT + CHECK(p_dropout == 0.0f) + << "This flash attention build does not support dropout."; +#endif + + // Causal is the special case where window_size_right == 0 and + // window_size_left < 0. Local is the more general case where + // window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + +#ifdef FLASHATTENTION_DISABLE_LOCAL + CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0)) + << "This flash attention build does not support local attention."; +#endif + + params.is_seqlens_k_cumulative = true; + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + CHECK(head_size == head_size_rounded) + << "This flash attention build does not support " + "headdim not being a multiple of 32."; +#endif +} + +static void set_params_dgrad( + Flash_bwd_params ¶ms, const FlashAttnBwdConfig &config, + se::DeviceMemoryBase grad_output_buffer, se::DeviceMemoryBase query_buffer, + se::DeviceMemoryBase key_buffer, se::DeviceMemoryBase value_buffer, + se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase softmax_lse_buffer, + se::DeviceMemoryBase grad_query_buffer, + se::DeviceMemoryBase grad_key_buffer, + se::DeviceMemoryBase grad_value_buffer, + se::DeviceMemoryBase grad_softmax_buffer, + // sizes + const size_t batch_size, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t num_heads, const size_t num_heads_k, const size_t head_size, + const size_t head_size_rounded, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, float p_dropout, + float softmax_scale, int window_size_left, int window_size_right, + bool deterministic) { + set_params_fprop(params, config, query_buffer, key_buffer, value_buffer, + output_buffer, softmax_lse_buffer, std::nullopt, batch_size, + seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, head_size, head_size_rounded, + cu_seqlens_q_d, cu_seqlens_k_d, + /*seqused_k=*/nullptr, p_dropout, softmax_scale, + window_size_left, window_size_right); + + // Set the pointers and strides. + params.do_ptr = grad_output_buffer.opaque(); + params.dq_ptr = grad_query_buffer.opaque(); + params.dk_ptr = grad_key_buffer.opaque(); + params.dv_ptr = grad_value_buffer.opaque(); + + // All stride are in elements, not bytes. + const auto &grad_output_strides = config.grad_output_desc.GetLogicalStrides(); + const auto &grad_query_strides = config.grad_query_desc.GetLogicalStrides(); + const auto &grad_key_strides = config.grad_key_desc.GetLogicalStrides(); + const auto &grad_value_strides = config.grad_value_desc.GetLogicalStrides(); + + // sequence length + params.do_row_stride = grad_output_strides[grad_output_strides.size() - 3]; + params.dq_row_stride = grad_query_strides[grad_query_strides.size() - 3]; + params.dk_row_stride = grad_key_strides[grad_key_strides.size() - 3]; + params.dv_row_stride = grad_value_strides[grad_value_strides.size() - 3]; + + // head number + params.do_head_stride = grad_output_strides[grad_output_strides.size() - 2]; + params.dq_head_stride = grad_query_strides[grad_query_strides.size() - 2]; + params.dk_head_stride = grad_key_strides[grad_key_strides.size() - 2]; + params.dv_head_stride = grad_value_strides[grad_value_strides.size() - 2]; + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = grad_output_strides[0]; + params.dq_batch_stride = grad_query_strides[0]; + params.dk_batch_stride = grad_key_strides[0]; + params.dv_batch_stride = grad_value_strides[0]; + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = grad_softmax_buffer.opaque(); + + params.deterministic = deterministic; +} + +// Find the number of splits that maximizes the occupancy. For example, if we +// have batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = +// 0.89) is better than having 3 splits (efficiency = 0.67). However, we also +// don't want too many splits as that would incur more HBM reads/writes. So we +// find the best efficiency, then find the smallest number of splits that gets +// 85% of the best efficiency. +int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, + int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose + // 11 splits, we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have + // 6 * 11 + (-2) blocks (i.e. it's 11 splits anyway). So we check if the + // number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != + ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +static void set_params_splitkv( + Flash_fwd_params ¶ms, const int batch_size, const int num_heads, + const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const float p_dropout, const int num_splits, + std::optional output_accum_buffer, + std::optional softmax_lse_accum_buffer, + const cudaDeviceProp *dprops) { + params.num_splits = num_splits; + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard + // kernel. In any case we don't expect seqlen_q to be larger than 64 for + // inference. + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits_heuristic( + batch_size * num_heads * num_m_blocks, + dprops->multiProcessorCount * 2, num_n_blocks, 128); + } + bool splitkv = params.num_splits > 1; + CHECK(splitkv == output_accum_buffer.has_value() && + splitkv == softmax_lse_accum_buffer.has_value()); + if (splitkv) { + params.oaccum_ptr = output_accum_buffer->opaque(); + params.softmax_lseaccum_ptr = softmax_lse_accum_buffer->opaque(); + } + CHECK(params.num_splits <= 128) << "num_splits > 128 not supported"; + } +} + +static void set_params_alibi( + Flash_fwd_params ¶ms, + std::optional &alibi_slopes_buffer, + std::optional alibi_slopes_desc, int batch_size, + int num_heads) { +#ifdef FLASHATTENTION_DISABLE_ALIBI + TORCH_CHECK(!alibi_slopes_buffer.has_value(), + "This flash attention build does not support alibi."); + params.alibi_slopes_ptr = nullptr; +#else + if (alibi_slopes_buffer.has_value()) { + CHECK(alibi_slopes_desc->type() == se::dnn::ToDataType::value) + << "ALiBi slopes must have dtype fp32"; + const auto &alibi_slopes_strides = alibi_slopes_desc->GetLogicalStrides(); + CHECK(alibi_slopes_strides.back() == 1) + << "ALiBi slopes tensor must have contiguous last dimension"; + CHECK((alibi_slopes_desc->dimensions() == std::vector{num_heads} || + alibi_slopes_desc->dimensions() == + std::vector{batch_size, num_heads})); + params.alibi_slopes_ptr = alibi_slopes_buffer->opaque(); + params.alibi_slopes_batch_stride = + alibi_slopes_desc->ndims() == 2 ? alibi_slopes_strides.front() : 0; + } else { + params.alibi_slopes_ptr = nullptr; + } +#endif +} + +static void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, + bool force_split_kernel = false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && + !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); +} + +static void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, + [&] { run_mha_bwd_(params, stream); }); + }); +} + +static int64_t GetNumElements(const std::vector &dims) { + return std::accumulate(dims.begin(), dims.end(), 1, + std::multiplies()); +} + +absl::Status RunFlashAttnFwd( + se::Stream *stream, const FlashAttnFwdConfig &config, + se::DeviceMemoryBase query_buffer, se::DeviceMemoryBase key_buffer, + se::DeviceMemoryBase value_buffer, + std::optional cu_seqlens_query_buffer, + std::optional cu_seqlens_key_buffer, + std::optional alibi_slopes_buffer, + std::optional output_accum_buffer, + std::optional softmax_lse_accum_buffer, + se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase softmax_lse_buffer, + se::DeviceMemoryBase rng_state_buffer, + std::optional s_dmask_buffer, int window_size_left, + int window_size_right) { + const float p_dropout = config.dropout_rate; + const float softmax_scale = config.scale; + bool is_causal = config.is_causal; + + const cudaDeviceProp *dprops = flash::cuda::getCurrentDeviceProperties(); + const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + const bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + CHECK(is_sm8x || is_sm90) + << "FlashAttention only supports Ampere GPUs or newer."; + + bool is_varlen = cu_seqlens_query_buffer.has_value(); + CHECK(is_varlen == cu_seqlens_key_buffer.has_value() && + is_varlen == config.cu_seqlens_query_desc.has_value() && + is_varlen == config.cu_seqlens_key_desc.has_value() && + is_varlen == config.max_seqlen_q.has_value() && + is_varlen == config.max_seqlen_k.has_value()) + << "cu_seqlens_query_buffer, cu_seqlens_key_buffer, max_seqlen_q, and " + "max_seqlen_k must be all set or all unset."; + + const auto &q_strides = config.query_desc.GetLogicalStrides(); + const auto &k_strides = config.key_desc.GetLogicalStrides(); + const auto &v_strides = config.value_desc.GetLogicalStrides(); + + CHECK(q_strides.back() == 1 && k_strides.back() == 1 && v_strides.back() == 1) + << "Input tensor must have contiguous last dimension in FlashAttention."; + + if (is_varlen) { + CHECK(config.cu_seqlens_query_desc->ndims() == 1 && + config.cu_seqlens_key_desc->ndims() == 1); + } + + const auto &o_strides = config.output_desc.GetLogicalStrides(); + CHECK(o_strides.back() == 1) + << "Output tensor must have contiguous last dimension in FlashAttention."; + + const auto &q_sizes = config.query_desc.dimensions(); + const auto &k_sizes = config.key_desc.dimensions(); + + int batch_size, num_heads, head_size_og, num_heads_k; + int seqlen_q, seqlen_k; + + if (is_varlen) { + batch_size = GetNumElements(config.cu_seqlens_query_desc->dimensions()) - 1; + num_heads = q_sizes[1]; + head_size_og = q_sizes[2]; + num_heads_k = k_sizes[1]; + + seqlen_q = config.max_seqlen_q.value(); + seqlen_k = config.max_seqlen_k.value(); + } else { + batch_size = q_sizes[0]; + num_heads = q_sizes[2]; + head_size_og = q_sizes[3]; + num_heads_k = k_sizes[2]; + + seqlen_q = q_sizes[1]; + seqlen_k = k_sizes[1]; + } + + CHECK(batch_size > 0) << "Batch size must be positive in FlashAttention."; + // TODO: more loose check for head_size? + CHECK(head_size_og % 8 == 0) + << "Head size must be a multiple of 8 in FlashAttention."; + CHECK(head_size_og <= 256) + << "FlashAttention forward only supports head dimension at most 256"; + // TODO: num_heads % num_heads_k == 0 + CHECK(num_heads == num_heads_k) << "Number of heads in key/value must be " + "equal to number of heads in query"; + + if (window_size_left >= seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= seqlen_k) { + window_size_right = -1; + } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_buffer.has_value()) { + is_causal = false; + } + if (is_causal) { + window_size_right = 0; + } + + // TODO: seqlenq_ngroups_swapped + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, config, query_buffer, key_buffer, value_buffer, + output_buffer, softmax_lse_buffer, s_dmask_buffer, + batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, + seqlen_k_rounded, num_heads, num_heads_k, head_size, + head_size_rounded, + is_varlen ? cu_seqlens_query_buffer->opaque() : nullptr, + is_varlen ? cu_seqlens_key_buffer->opaque() : nullptr, + /*seqused_k=*/nullptr, p_dropout, softmax_scale, + window_size_left, window_size_right); + + // TODO: seqlenq_ngroups_swapped + if (!is_varlen) { + set_params_splitkv(params, batch_size, num_heads, head_size, seqlen_k, + seqlen_q, head_size_rounded, p_dropout, + /*num_splits*/ 0, output_accum_buffer, + softmax_lse_accum_buffer, dprops); + } + + // number of times random will be generated per thread, to offset philox + // counter in thc random state We use a custom RNG that increases the offset + // by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state_buffer.opaque()); + + if (p_dropout > 0.0) { + int cur_stream_device = stream->parent()->device_ordinal(); + auto &gen = flash::cuda::getDefaultCUDAGenerator(cur_stream_device); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex_); + params.philox_args = gen.philox_cuda_state(counter_offset); + } + + set_params_alibi(params, alibi_slopes_buffer, config.alibi_slopes_desc, + batch_size, num_heads); + + if (seqlen_k > 0) { + run_mha_fwd(params, se::gpu::AsGpuStreamValue(stream)); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output + // to 0. + TF_RETURN_IF_ERROR(stream->MemZero(&output_buffer, output_buffer.size())); + static uint32_t inf_pattern = []() { + float value = std::numeric_limits::infinity(); + uint32_t pattern; + std::memcpy(&pattern, &value, sizeof(pattern)); + return pattern; + }(); + TF_RETURN_IF_ERROR(stream->Memset32(&softmax_lse_buffer, inf_pattern, + softmax_lse_buffer.size())); + } + + return absl::OkStatus(); +} + +absl::Status RunFlashAttnBwd( + se::Stream *stream, const FlashAttnBwdConfig &config, + se::DeviceMemoryBase grad_output_buffer, se::DeviceMemoryBase query_buffer, + se::DeviceMemoryBase key_buffer, se::DeviceMemoryBase value_buffer, + se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase softmax_lse_buffer, + se::DeviceMemoryBase rng_state_buffer, + std::optional cu_seqlens_query_buffer, + std::optional cu_seqlens_key_buffer, + std::optional alibi_slopes_buffer, + se::DeviceMemoryBase grad_query_accum_buffer, + se::DeviceMemoryBase grad_query_buffer, + se::DeviceMemoryBase grad_key_buffer, + se::DeviceMemoryBase grad_value_buffer, + se::DeviceMemoryBase grad_softmax_buffer, int window_size_left, + int window_size_right) { +#ifdef FLASHATTENTION_DISABLE_BACKWARD + CHECK(false) << "This flash attention build does not support backward."; +#endif + + const float p_dropout = config.dropout_rate; + bool is_dropout = p_dropout > 0.0; + bool is_causal = config.is_causal; + if (is_causal) { + window_size_right = 0; + } + + const cudaDeviceProp *dprops = flash::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + CHECK(is_sm8x || is_sm90) + << "FlashAttention only supports Ampere GPUs or newer."; + + bool is_varlen = cu_seqlens_query_buffer.has_value(); + CHECK(is_varlen == cu_seqlens_key_buffer.has_value() && + is_varlen == config.cu_seqlens_query_desc.has_value() && + is_varlen == config.cu_seqlens_key_desc.has_value() && + is_varlen == config.max_seqlen_q.has_value() && + is_varlen == config.max_seqlen_k.has_value()) + << "cu_seqlens_query_buffer, cu_seqlens_key_buffer, max_seqlen_q, and " + "max_seqlen_k must be all set or all unset."; + + const auto &query_strides = config.query_desc.GetLogicalStrides(); + const auto &key_strides = config.key_desc.GetLogicalStrides(); + const auto &value_strides = config.value_desc.GetLogicalStrides(); + const auto &output_strides = config.output_desc.GetLogicalStrides(); + const auto &grad_output_strides = config.grad_output_desc.GetLogicalStrides(); + CHECK(query_strides.back() == 1 && key_strides.back() == 1 && + value_strides.back() == 1) + << "Input tensor must have contiguous last dimension in FlashAttention."; + CHECK(output_strides.back() == 1) + << "Output tensor must have contiguous last dimension in FlashAttention."; + CHECK(grad_output_strides.back() == 1) + << "Gradient output tensor must have contiguous last dimension in " + "FlashAttention."; + + if (is_varlen) { + CHECK(config.cu_seqlens_query_desc->ndims() == 1 && + config.cu_seqlens_key_desc->ndims() == 1); + } + + const auto &q_sizes = config.query_desc.dimensions(); + const auto &k_sizes = config.key_desc.dimensions(); + const auto &dout_sizes = config.grad_output_desc.dimensions(); + + int batch_size, num_heads, head_size_og, head_size, num_heads_k; + int seqlen_q, seqlen_k; + int total_q, total_k; + + if (is_varlen) { + batch_size = GetNumElements(config.cu_seqlens_query_desc->dimensions()) - 1; + num_heads = q_sizes[1]; + head_size_og = dout_sizes[2]; + head_size = q_sizes[2]; + num_heads_k = k_sizes[1]; + + seqlen_q = config.max_seqlen_q.value(); + seqlen_k = config.max_seqlen_k.value(); + + total_q = q_sizes[0]; + total_k = k_sizes[0]; + } else { + batch_size = q_sizes[0]; + num_heads = q_sizes[2]; + head_size_og = dout_sizes[3]; + head_size = q_sizes[3]; + num_heads_k = k_sizes[2]; + + seqlen_q = q_sizes[1]; + seqlen_k = k_sizes[1]; + } + + CHECK(batch_size > 0) << "Batch size must be positive in FlashAttention."; + CHECK(head_size % 8 == 0) + << "Head size must be a multiple of 8 in FlashAttention."; + CHECK(head_size <= 256) + << "FlashAttention forward only supports head dimension at most 256"; + if (head_size > 192 && (head_size <= 224 || is_dropout)) { + CHECK(is_sm80 || is_sm90) + << "FlashAttention backward for head dim 256 with dropout, or head dim " + "224 with/without dropout requires A100/A800 or H100/H800"; + } + CHECK(num_heads % num_heads_k == 0) + << "Number of heads in key/value must divide number of heads in query"; + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + CHECK(head_size == round_multiple(head_size_og, 8)) + << "head_size must be head_size_og rounded to a multiple of 8"; + + if (window_size_left >= seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= seqlen_k) { + window_size_right = -1; + } + + bool deterministic = config.deterministic; + + bool loop = true; + + // at::Tensor dk_expanded, dv_expanded; + // if (num_heads_k != num_heads) { // MQA / GQA + // dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, + // head_size}, opts); dv_expanded = torch::empty({batch_size, seqlen_k, + // num_heads, head_size}, opts); + // } else { + // dk_expanded = dk; + // dv_expanded = dv; + // } + + Flash_bwd_params params; + set_params_dgrad(params, config, grad_output_buffer, query_buffer, key_buffer, + value_buffer, output_buffer, softmax_lse_buffer, + grad_query_buffer, grad_key_buffer, grad_value_buffer, + grad_softmax_buffer, batch_size, seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, + head_size, head_size_rounded, + is_varlen ? cu_seqlens_query_buffer->opaque() : nullptr, + is_varlen ? cu_seqlens_key_buffer->opaque() : nullptr, + loop ? grad_query_accum_buffer.opaque() : nullptr, + /*dk_accum_d=*/nullptr, + /*dv_accum_d=*/nullptr, p_dropout, config.scale, + window_size_left, window_size_right, deterministic); + + if (deterministic) { + int64_t hidden_size = num_heads * head_size_rounded; + if (is_varlen) { + params.dq_accum_split_stride = (total_q + 128 * batch_size) * hidden_size; + } else { + params.dq_accum_split_stride = + batch_size * seqlen_q_rounded * hidden_size; + } + } else { + params.dq_accum_split_stride = 0; + } + + int64_t counter_offset = params.b * params.h * 32; + + params.rng_state = reinterpret_cast(rng_state_buffer.opaque()); + + set_params_alibi(params, alibi_slopes_buffer, config.alibi_slopes_desc, + batch_size, num_heads); + + if (seqlen_q > 0) { + run_mha_bwd(params, se::gpu::AsGpuStreamValue(stream)); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output + // to 0. + TF_RETURN_IF_ERROR( + stream->MemZero(&grad_key_buffer, grad_key_buffer.size())); + TF_RETURN_IF_ERROR( + stream->MemZero(&grad_value_buffer, grad_value_buffer.size())); + TF_RETURN_IF_ERROR( + stream->MemZero(&grad_softmax_buffer, grad_softmax_buffer.size())); + } + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_flash_attn.h b/xla/service/gpu/gpu_flash_attn.h new file mode 100644 index 0000000000000..f430b468cb75c --- /dev/null +++ b/xla/service/gpu/gpu_flash_attn.h @@ -0,0 +1,157 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_FLASH_ATTN_H_ +#define XLA_SERVICE_GPU_GPU_FLASH_ATTN_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/types.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +extern const absl::string_view kGpuFlashAttnFwdCallTarget; +extern const absl::string_view kGpuFlashAttnBwdCallTarget; +extern const absl::string_view kGpuFlashAttnVarLenFwdCallTarget; +extern const absl::string_view kGpuFlashAttnVarLenBwdCallTarget; + +bool IsFwdCustomCallToFlashAttn(const HloInstruction &hlo); +bool IsBwdCustomCallToFlashAttn(const HloInstruction &hlo); +bool IsCustomCallToFlashAttn(const HloInstruction &hlo); + +enum class FlashAttnKind { + kForward, + kVarLenForward, + kBackward, + kVarLenBackward, +}; + +absl::StatusOr GetFlashAttnKind( + const HloCustomCallInstruction *instr); + +struct FlashAttnConfig { + static absl::StatusOr For( + const Shape &query_shape, const Shape &key_shape, + const Shape &value_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, const Shape &output_shape, + const Shape &softmax_lse_shape, float dropout_rate, float scale, + bool is_causal, const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k); + PrimitiveType type; + + se::dnn::TensorDescriptor query_desc; // input + se::dnn::TensorDescriptor key_desc; // input + se::dnn::TensorDescriptor value_desc; // input + std::optional cu_seqlens_query_desc; // input + std::optional cu_seqlens_key_desc; // input + std::optional alibi_slopes_desc; // input + + se::dnn::TensorDescriptor output_desc; // output(fwd), input(bwd) + se::dnn::TensorDescriptor softmax_lse_desc; // output(fwd), input(bwd) + + std::optional max_seqlen_q; + std::optional max_seqlen_k; + + float dropout_rate; + float scale; + bool is_causal; +}; + +struct FlashAttnFwdConfig : public FlashAttnConfig { + static absl::StatusOr For( + const Shape &query_shape, const Shape &key_shape, + const Shape &value_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, const Shape &output_shape, + const Shape &softmax_lse_shape, const std::optional &s_dmask_shape, + float dropout_rate, float scale, bool is_causal, + const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k); + + FlashAttnFwdConfig(const FlashAttnConfig &config) : FlashAttnConfig(config) {} + + std::optional s_dmask_desc; // output +}; + +struct FlashAttnBwdConfig : public FlashAttnConfig { + static absl::StatusOr For( + const Shape &grad_output_shape, const Shape &query_shape, + const Shape &key_shape, const Shape &value_shape, + const Shape &output_shape, const Shape &softmax_lse_shape, + const std::optional &cu_seqlens_query_shape, + const std::optional &cu_seqlens_key_shape, + const std::optional &alibi_slopes_shape, + const Shape &grad_query_shape, const Shape &grad_key_shape, + const Shape &grad_value_shape, const Shape &grad_softmax_shape, + float dropout_rate, float scale, bool is_causal, bool deterministic, + const std::optional &max_seqlen_q, + const std::optional &max_seqlen_k); + + FlashAttnBwdConfig(const FlashAttnConfig &config) : FlashAttnConfig(config) {} + + se::dnn::TensorDescriptor grad_output_desc; // input + se::dnn::TensorDescriptor grad_query_desc; // output + se::dnn::TensorDescriptor grad_key_desc; // output + se::dnn::TensorDescriptor grad_value_desc; // output + se::dnn::TensorDescriptor grad_softmax_desc; // output + + bool deterministic; +}; + +int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, + int num_n_blocks, int max_splits); + +absl::Status RunFlashAttnFwd( + se::Stream *stream, const FlashAttnFwdConfig &config, + se::DeviceMemoryBase query_buffer, se::DeviceMemoryBase key_buffer, + se::DeviceMemoryBase value_buffer, + std::optional cu_seqlens_query_buffer, + std::optional cu_seqlens_key_buffer, + std::optional alibi_slopes_buffer, + std::optional output_accum_buffer, + std::optional softmax_lse_accum_buffer, + se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase softmax_lse_buffer, + se::DeviceMemoryBase rng_state_buffer, + std::optional s_dmask_buffer, + int window_size_left = -1, int window_size_right = -1); + +absl::Status RunFlashAttnBwd( + se::Stream *stream, const FlashAttnBwdConfig &config, + se::DeviceMemoryBase grad_output_buffer, se::DeviceMemoryBase query_buffer, + se::DeviceMemoryBase key_buffer, se::DeviceMemoryBase value_buffer, + se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase softmax_lse_buffer, + se::DeviceMemoryBase rng_state_buffer, + std::optional cu_seqlens_query_buffer, + std::optional cu_seqlens_key_buffer, + std::optional alibi_slopes_buffer, + se::DeviceMemoryBase grad_query_accum_buffer, + se::DeviceMemoryBase grad_query_buffer, + se::DeviceMemoryBase grad_key_buffer, + se::DeviceMemoryBase grad_value_buffer, + se::DeviceMemoryBase grad_softmax_buffer, int window_size_left = -1, + int window_size_right = -1); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GPU_FLASH_ATTN_H_ diff --git a/xla/service/gpu/gpu_flash_attn_normalization.cc b/xla/service/gpu/gpu_flash_attn_normalization.cc new file mode 100644 index 0000000000000..3844027bc65a8 --- /dev/null +++ b/xla/service/gpu/gpu_flash_attn_normalization.cc @@ -0,0 +1,245 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_flash_attn_normalization.h" + +#include +#include + +#include "flash_attn/utils.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_flash_attn.h" +#include "xla/shape_util.h" + +namespace xla { +namespace gpu { + +StatusOr GpuFlashAttnNormalization::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + VLOG(2) << "Before flash attention normalization:"; + XLA_VLOG_LINES(2, module->ToString()); + + for (HloComputation* computation : module->computations(execution_threads)) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() != HloOpcode::kCustomCall) { + continue; + } + auto kind_status = + GetFlashAttnKind(Cast(instr)); + if (!kind_status.ok()) { + continue; + } + FlashAttnKind kind = kind_status.value(); + bool attn_changed = false; + switch (kind) { + case FlashAttnKind::kForward: { + TF_ASSIGN_OR_RETURN( + attn_changed, + RunOnFwdFlashAttn(computation, instr, /*is_varlen=*/false)); + break; + } + case FlashAttnKind::kVarLenForward: { + TF_ASSIGN_OR_RETURN( + attn_changed, + RunOnFwdFlashAttn(computation, instr, /*is_varlen=*/true)); + break; + } + case FlashAttnKind::kBackward: { + TF_ASSIGN_OR_RETURN( + attn_changed, + RunOnBwdFlashAttn(computation, instr, /*is_varlen=*/false)); + break; + } + case FlashAttnKind::kVarLenBackward: { + TF_ASSIGN_OR_RETURN( + attn_changed, + RunOnBwdFlashAttn(computation, instr, /*is_varlen=*/true)); + break; + } + } + + changed |= attn_changed; + } + } + + VLOG(2) << "After flash attention normalization:"; + XLA_VLOG_LINES(2, module->ToString()); + + return changed; +} + +static int RoundMultiple(int x, int m) { return (x + m - 1) / m * m; } + +static HloInstruction* CreateZeroTensor(HloComputation* computation, + HloInstruction* const_zero, + PrimitiveType element_type, + absl::Span dimensions) { + const Shape& shape = ShapeUtil::MakeShape(element_type, dimensions); + return computation->AddInstruction( + HloInstruction::CreateBroadcast(shape, const_zero, {})); +} + +static HloInstruction* CreateZeroTensor(HloComputation* computation, + PrimitiveType element_type, + absl::Span dimensions) { + HloInstruction* const_zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + return CreateZeroTensor(computation, const_zero, element_type, dimensions); +} + +absl::StatusOr GpuFlashAttnNormalization::RunOnFwdFlashAttn( + HloComputation* computation, HloInstruction* instr, bool is_varlen) { + // flash_attn_varlen does not support splitkv + if (is_varlen) { + return false; + } + + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const auto& config = gpu_config.flash_attn_backend_config(); + + float p_dropout = config.dropout_rate(); + if (p_dropout == 0.0f) { + const HloInstruction* query = instr->operand(0); + const HloInstruction* key = instr->operand(1); + + const Shape& q_shape = query->shape(); + const Shape& k_shape = key->shape(); + + int batch_size = q_shape.dimensions(0); + int num_heads = q_shape.dimensions(2); + int head_size = RoundMultiple(q_shape.dimensions(3), 8); + int head_size_rounded = RoundMultiple(head_size, 32); + + int max_seqlen_q = q_shape.dimensions(1); + int max_seqlen_k = k_shape.dimensions(1); + + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + const cudaDeviceProp* dprops = flash::cuda::getCurrentDeviceProperties(); + int num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, + dprops->multiProcessorCount * 2, + num_n_blocks, 128); + if (num_splits > 1) { + PrimitiveType accum_type = PrimitiveType::F32; + HloInstruction* const_zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(accum_type))); + + HloInstruction* output_accum = CreateZeroTensor( + computation, const_zero, accum_type, + {num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}); + instr->AppendOperand(output_accum); + + HloInstruction* softmax_lse_accum = + CreateZeroTensor(computation, const_zero, accum_type, + {num_splits, batch_size, num_heads, max_seqlen_q}); + instr->AppendOperand(softmax_lse_accum); + return true; + } + } + + return false; +} + +absl::StatusOr GpuFlashAttnNormalization::RunOnBwdFlashAttn( + HloComputation* computation, HloInstruction* instr, bool is_varlen) { + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const auto& config = gpu_config.flash_attn_backend_config(); + + bool deterministic = config.deterministic(); + + const HloInstruction* query = instr->operand(1); + const Shape& q_shape = query->shape(); + + int batch_size, num_heads, head_size; + int max_seqlen_q; + int total_q; + + if (is_varlen) { + const HloInstruction* cu_seqlens_query = instr->operand(7); + const auto cu_seqlens_query_dims = cu_seqlens_query->shape().dimensions(); + batch_size = std::accumulate(cu_seqlens_query_dims.begin(), + cu_seqlens_query_dims.end(), 1, + std::multiplies()) - + 1; + total_q = q_shape.dimensions(0); + num_heads = q_shape.dimensions(1); + head_size = q_shape.dimensions(2); + max_seqlen_q = config.max_seqlen_q(); + } else { + batch_size = q_shape.dimensions(0); + max_seqlen_q = q_shape.dimensions(1); + num_heads = q_shape.dimensions(2); + head_size = q_shape.dimensions(3); + } + + int seqlen_q_rounded = RoundMultiple(max_seqlen_q, 128); + int head_size_rounded = RoundMultiple(head_size, 32); + + std::vector dq_accum_dims; + + if (!deterministic) { + if (is_varlen) { + dq_accum_dims = { + total_q + 128 * batch_size, + num_heads, + head_size_rounded, + }; + } else { + dq_accum_dims = { + batch_size, + seqlen_q_rounded, + num_heads, + head_size_rounded, + }; + } + } else { + const cudaDeviceProp* dprops = flash::cuda::getCurrentDeviceProperties(); + const int nsplits = + (dprops->multiProcessorCount + batch_size * num_heads - 1) / + (batch_size * num_heads); + if (is_varlen) { + dq_accum_dims = { + nsplits, + total_q + 128 * batch_size, + num_heads, + head_size_rounded, + }; + } else { + dq_accum_dims = { + nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded, + }; + } + } + + HloInstruction* grad_query_accum = + CreateZeroTensor(computation, PrimitiveType::F32, dq_accum_dims); + + instr->AppendOperand(grad_query_accum); + + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_flash_attn_normalization.h b/xla/service/gpu/gpu_flash_attn_normalization.h new file mode 100644 index 0000000000000..1276b64c60410 --- /dev/null +++ b/xla/service/gpu/gpu_flash_attn_normalization.h @@ -0,0 +1,56 @@ + +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_FLASH_ATTN_NORMALIZATION_H_ +#define XLA_SERVICE_GPU_GPU_FLASH_ATTN_NORMALIZATION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +class GpuFlashAttnNormalization : public HloModulePass { + public: + GpuFlashAttnNormalization() = default; + + absl::string_view name() const override { + return "gpu_flash_attn_normalization"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + static absl::StatusOr RunOnFwdFlashAttn(HloComputation* computation, + HloInstruction* instr, + bool is_varlen); + static absl::StatusOr RunOnBwdFlashAttn(HloComputation* computation, + HloInstruction* instr, + bool is_varlen); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GPU_FLASH_ATTN_NORMALIZATION_H_ diff --git a/xla/service/gpu/gpu_float_support.cc b/xla/service/gpu/gpu_float_support.cc index 7652bed27c61f..3bae5e6b8e7e7 100644 --- a/xla/service/gpu/gpu_float_support.cc +++ b/xla/service/gpu/gpu_float_support.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,14 @@ limitations under the License. #include "xla/service/gpu/gpu_float_support.h" +#include + #include "absl/log/check.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_support.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { @@ -30,7 +34,7 @@ bool GpuFloatSupport::SupportsMixedPrecisions(const HloInstruction& hlo) const { switch (hlo.opcode()) { // Handled by Triton GEMM or cuBLAS. case HloOpcode::kDot: { - CHECK_EQ(hlo.operand_count(), 2); + CHECK_GE(hlo.operand_count(), HloDotInstruction::kOperands); const PrimitiveType lhs_type = hlo.operand(0)->shape().element_type(); const PrimitiveType rhs_type = hlo.operand(1)->shape().element_type(); const PrimitiveType result_type = hlo.shape().element_type(); @@ -73,6 +77,18 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { // Other special ops. case HloOpcode::kBitcast: return true; + // Elementwise ops. + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: { + if (LowPrecisionType() == BF16) { + auto* cuda_compute_capability = + std::get_if(&compute_capability_); + return cuda_compute_capability != nullptr && + cuda_compute_capability->IsAtLeastHopper(); + } + return false; + } default: return false; } diff --git a/xla/service/gpu/gpu_float_support.h b/xla/service/gpu/gpu_float_support.h index 2bdc64a739f85..ee63dd3d2daab 100644 --- a/xla/service/gpu/gpu_float_support.h +++ b/xla/service/gpu/gpu_float_support.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/float_support.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { @@ -27,9 +28,11 @@ namespace gpu { class GpuFloatSupport : public FloatSupport { public: - explicit GpuFloatSupport(PrimitiveType low_precision_type, + explicit GpuFloatSupport(se::GpuComputeCapability cc, + PrimitiveType low_precision_type, PrimitiveType high_precision_type = F32) - : FloatSupport(low_precision_type, high_precision_type) {} + : FloatSupport(low_precision_type, high_precision_type), + compute_capability_(cc) {} bool SupportsLowPrecisionOperand(const HloInstruction& hlo, int64_t operand_index) const override { @@ -45,6 +48,8 @@ class GpuFloatSupport : public FloatSupport { private: bool IsSupported(const HloInstruction& hlo) const; + + const se::GpuComputeCapability compute_capability_; }; } // namespace gpu diff --git a/xla/service/gpu/gpu_fused_mha_runner.cc b/xla/service/gpu/gpu_fused_mha_runner.cc index a796f0b5652ca..3ea68861a23a2 100644 --- a/xla/service/gpu/gpu_fused_mha_runner.cc +++ b/xla/service/gpu/gpu_fused_mha_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,25 @@ limitations under the License. #include "xla/service/gpu/gpu_fused_mha_runner.h" +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/layout_util.h" +#include "absl/strings/str_format.h" +#include "Eigen/Core" // from @eigen_archive #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/lazy_op_runner.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -38,15 +46,17 @@ using se::dnn::MatmulTensorDescriptor; using se::dnn::TensorDescriptor; template -Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, - RunFusedMHAOptions options, - DeviceMemory lhs_bmm1_buffer, - DeviceMemory rhs_bmm1_buffer, - DeviceMemory rhs_bmm2_buffer, - DeviceMemory output_buffer, - DeviceMemoryBase mask_buffer, DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase activation_output) { +absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, + RunFusedMHAOptions options, + DeviceMemory lhs_bmm1_buffer, + DeviceMemory rhs_bmm1_buffer, + DeviceMemory rhs_bmm2_buffer, + DeviceMemory output_buffer, + DeviceMemoryBase mask_buffer, + DeviceMemoryBase bias_buffer, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase activation_output, + DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { se::dnn::LazyOpRunner *lazy_runner = options.runner_cache->AsFusedMHARunner(); std::optional> local_runner; @@ -89,13 +99,14 @@ Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, lazy_runner->GetOrCreateRunner(config, stream)); return (*runner)(stream, options.profile_result, scratch_memory, lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, - output_buffer, mask_buffer, bias_buffer, activation_output); + output_buffer, mask_buffer, bias_buffer, activation_output, + seqlen_q, seqlen_k); } template -Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHAOptions options) { +absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, + se::DeviceMemoryBase scratch_memory, + RunFusedMHAOptions options) { auto lhs_bmm1_buffer = se::DeviceMemory(params.lhs_bmm1_buffer); auto rhs_bmm1_buffer = se::DeviceMemory(params.rhs_bmm1_buffer); auto rhs_bmm2_buffer = se::DeviceMemory(params.rhs_bmm2_buffer); @@ -110,12 +121,20 @@ Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, auto bias_buffer = params.bias_buffer.has_value() ? se::DeviceMemory(*params.bias_buffer) : se::DeviceMemoryBase(); + auto seqlen_q_buffer = + params.seqlen_q_buffer.has_value() + ? se::DeviceMemory(*params.seqlen_q_buffer) + : se::DeviceMemoryBase(); + auto seqlen_k_buffer = + params.seqlen_k_buffer.has_value() + ? se::DeviceMemory(*params.seqlen_k_buffer) + : se::DeviceMemoryBase(); se::dnn::AlgorithmDesc algorithm = params.config->algorithm; if (options.runner_cache) { algorithm = options.runner_cache->ToAlgorithmDesc(); } - Status run_status = OkStatus(); + absl::Status run_status = absl::OkStatus(); switch (params.config->kind) { case CudnnfMHAKind::kBmmBmm: case CudnnfMHAKind::kSoftmaxDropout: @@ -129,79 +148,27 @@ Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, run_status = RunFusedMHA( params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, output_buffer, mask_buffer, bias_buffer, - scratch_memory, activation_buffer); + scratch_memory, activation_buffer, seqlen_q_buffer, seqlen_k_buffer); break; default: - return InternalError("Invalid cuDNN fMHA kind"); + return Internal("Invalid cuDNN fMHA kind"); } - if (run_status != OkStatus()) { + if (!run_status.ok()) { return run_status; } if (!stream->ok()) { - return InternalError("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); + return Internal("Unable to launch FMHA with type %s and algorithm %s", + CudnnfMHAKindToString(params.config->kind), + algorithm.ToString()); } - return OkStatus(); -} - -void AssignScale(GpufMHAConfig &config, - const CudnnfMHABackendConfig &backend_config) { - double fmha_scale = 0.0; - - switch (config.kind) { - case CudnnfMHAKind::kScaleBiasMaskSoftmax: - case CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: - case CudnnfMHAKind::kScaleMaskSoftmax: - case CudnnfMHAKind::kScaleMaskSoftmaxDropout: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - case CudnnfMHAKind::kScaleBiasSoftmax: - fmha_scale = backend_config.fmha_scale(); - config.fmha_scale.emplace(fmha_scale); - break; - default: - break; - } -} - -void AssignDropoutRate(GpufMHAConfig &config, - const CudnnfMHABackendConfig &backend_config) { - double dropout_rate = 0.0; - switch (config.kind) { - case CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: - case CudnnfMHAKind::kScaleMaskSoftmaxDropout: - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - dropout_rate = backend_config.dropout_rate(); - config.dropout_rate.emplace(dropout_rate); - break; - default: - break; - } -} - -void AssignSeed(GpufMHAConfig &config, - const CudnnfMHABackendConfig &backend_config) { - int64_t seed_value = 0; - - switch (config.kind) { - case CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: - case CudnnfMHAKind::kScaleMaskSoftmaxDropout: - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - seed_value = backend_config.seed(); - config.seed.emplace(seed_value); - break; - default: - break; - } + return absl::OkStatus(); } template -Status RunFusedMHABackward( +absl::Status RunFusedMHABackward( GpufMHABackwardParams params, se::Stream *stream, RunFusedMHABackwardOptions options, DeviceMemory bmm1_grad_gemm1_rhs_buffer, @@ -215,7 +182,8 @@ Status RunFusedMHABackward( DeviceMemoryBase softmax_buffer, DeviceMemoryBase d_Q_accum_buffer, DeviceMemoryBase mask_buffer, DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer, DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory) { + DeviceMemoryBase scratch_memory, DeviceMemoryBase seqlen_q, + DeviceMemoryBase seqlen_k) { se::dnn::LazyOpRunner *lazy_runner = options.runner_cache->AsFusedMHABackwardRunner(); std::optional> @@ -269,15 +237,15 @@ Status RunFusedMHABackward( d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_buffer, d_Q_accum_buffer, mask_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer); - return OkStatus(); + fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k); + return absl::OkStatus(); } template -Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, - se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHABackwardOptions options) { +absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, + se::Stream *stream, + se::DeviceMemoryBase scratch_memory, + RunFusedMHABackwardOptions options) { auto bmm1_grad_gemm1_rhs_buffer = se::DeviceMemory(params.bmm1_grad_gemm1_rhs_buffer); auto bmm1_grad_gemm2_rhs_buffer = @@ -325,12 +293,22 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, ? se::DeviceMemory(*params.bias_buffer) : se::DeviceMemoryBase(); + auto seqlen_q_buffer = + params.seqlen_q_buffer.has_value() + ? se::DeviceMemory(*params.seqlen_q_buffer) + : se::DeviceMemoryBase(); + + auto seqlen_k_buffer = + params.seqlen_k_buffer.has_value() + ? se::DeviceMemory(*params.seqlen_k_buffer) + : se::DeviceMemoryBase(); + se::dnn::AlgorithmDesc algorithm = params.config->algorithm; if (options.runner_cache) { algorithm = options.runner_cache->ToAlgorithmDesc(); } - Status run_status = OkStatus(); + absl::Status run_status = absl::OkStatus(); switch (params.config->kind) { case CudnnfMHAKind::kBackwardBmmBmm: case CudnnfMHAKind::kBackwardSoftmaxDropout: @@ -347,27 +325,27 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, d_bias_buffer, fwd_output_buffer, - bias_buffer, scratch_memory); + bias_buffer, scratch_memory, seqlen_q_buffer, seqlen_k_buffer); break; default: - return InternalError("Invalid cuDNN fMHA kind"); + return Internal("Invalid cuDNN fMHA kind"); } - if (run_status != OkStatus()) { + if (!run_status.ok()) { return run_status; } if (!stream->ok()) { - return InternalError("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); + return Internal("Unable to launch FMHA with type %s and algorithm %s", + CudnnfMHAKindToString(params.config->kind), + algorithm.ToString()); } return run_status; } } // namespace -/*static*/ StatusOr GpufMHAConfig::For( +/*static*/ absl::StatusOr GpufMHAConfig::For( const GpufMHADescriptor &desc) { // Get shapes from desc. const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; @@ -456,14 +434,13 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - - AssignScale(config, backend_config); - AssignDropoutRate(config, backend_config); - AssignSeed(config, backend_config); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); + config.seed.emplace(backend_config.seed()); return config; } -/*static*/ StatusOr GpufMHABackwardConfig::For( +/*static*/ absl::StatusOr GpufMHABackwardConfig::For( const GpufMHABackwardDescriptor &desc) { // Get shapes from desc. @@ -600,29 +577,21 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - - auto assign_scale = [&]() { - config.fmha_scale.emplace(backend_config.fmha_scale()); - }; - - auto assign_dropout_rate = [&]() { - config.dropout_rate.emplace(backend_config.dropout_rate()); - }; - - auto assign_seed = [&]() { config.seed.emplace(backend_config.seed()); }; - assign_scale(); - assign_dropout_rate(); - assign_seed(); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); + config.seed.emplace(backend_config.seed()); return config; } -/*static*/ StatusOr GpufMHAParams::For( +/*static*/ absl::StatusOr GpufMHAParams::For( const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, std::optional mask_buffer, std::optional bias_buffer, - std::optional activation_buffer) { + std::optional activation_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer) { GpufMHAParams params; params.config = &config; params.lhs_bmm1_buffer = lhs_bmm1_buffer; @@ -632,11 +601,12 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params.activation_buffer = activation_buffer; params.mask_buffer = mask_buffer; params.bias_buffer = bias_buffer; - + params.seqlen_q_buffer = seqlen_q_buffer; + params.seqlen_k_buffer = seqlen_k_buffer; return params; } -/*static*/ StatusOr GpufMHABackwardParams::For( +/*static*/ absl::StatusOr GpufMHABackwardParams::For( const GpufMHABackwardConfig &config, se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, @@ -652,7 +622,9 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, std::optional mask_buffer, std::optional d_bias_buffer, std::optional fwd_output_buffer, - std::optional bias_buffer) { + std::optional bias_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer) { GpufMHABackwardParams params; params.config = &config; params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; @@ -670,24 +642,29 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params.d_bias_buffer = d_bias_buffer; params.fwd_output_buffer = fwd_output_buffer; params.bias_buffer = bias_buffer; + params.seqlen_q_buffer = seqlen_q_buffer; + params.seqlen_k_buffer = seqlen_k_buffer; return params; } -Status RunGpuFMHA(const GpufMHAConfig &fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional mask_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - se::Stream *stream, RunFusedMHAOptions options) { +absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config, + se::DeviceMemoryBase lhs_bmm1_buffer, + se::DeviceMemoryBase rhs_bmm1_buffer, + se::DeviceMemoryBase rhs_bmm2_buffer, + se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase scratch_buffer, + std::optional mask_buffer, + std::optional bias_buffer, + std::optional activation_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer, + se::Stream *stream, RunFusedMHAOptions options) { TF_ASSIGN_OR_RETURN( GpufMHAParams params, GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, output_buffer, mask_buffer, - bias_buffer, activation_buffer)); + bias_buffer, activation_buffer, seqlen_q_buffer, + seqlen_k_buffer)); PrimitiveType input_primitive_type = fmha_config.input_type; switch (input_primitive_type) { case F16: @@ -700,10 +677,10 @@ Status RunGpuFMHA(const GpufMHAConfig &fmha_config, return absl::UnimplementedError(absl::StrFormat( "Unimplemented fused MHA with %s", ToString(fmha_config))); } - return OkStatus(); + return absl::OkStatus(); } -Status RunGpuFMHABackward( +absl::Status RunGpuFMHABackward( const GpufMHABackwardConfig &fmha_config, se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, @@ -719,7 +696,9 @@ Status RunGpuFMHABackward( std::optional mask_buffer, std::optional d_bias_buffer, std::optional fwd_output_buffer, - std::optional bias_buffer, se::Stream *stream, + std::optional bias_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer, se::Stream *stream, RunFusedMHABackwardOptions options) { TF_ASSIGN_OR_RETURN( GpufMHABackwardParams params, @@ -728,7 +707,8 @@ Status RunGpuFMHABackward( bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, - mask_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer)); + mask_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer, + seqlen_q_buffer, seqlen_k_buffer)); PrimitiveType input_primitive_type = fmha_config.input_type; switch (input_primitive_type) { case F16: @@ -741,7 +721,7 @@ Status RunGpuFMHABackward( default: return Unimplemented("Unimplemented fused MHA backward"); } - return OkStatus(); + return absl::OkStatus(); } std::string ToString(const GpufMHAConfig &config) { diff --git a/xla/service/gpu/gpu_fused_mha_runner.h b/xla/service/gpu/gpu_fused_mha_runner.h index 041993431030c..108cf6a1d4b1b 100644 --- a/xla/service/gpu/gpu_fused_mha_runner.h +++ b/xla/service/gpu/gpu_fused_mha_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,23 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ #define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ +#include #include #include #include #include #include -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/lazy_op_runner.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { @@ -53,7 +53,7 @@ struct GpufMHADescriptor { Shape rhs_bmm2_shape; Shape intermediate_lhs_bmm2_shape; // This will contain both output shape and activation shape - std::vector output_shapes; + absl::InlinedVector output_shapes; DotDimensionNumbers bmm1_dnums; DotDimensionNumbers bmm2_dnums; @@ -88,7 +88,7 @@ struct GpufMHABackwardDescriptor { // Structure to describe static properties of a GPU fused Multi-Headed // Attention. struct GpufMHAConfig { - static StatusOr For(const GpufMHADescriptor& fmha_desc); + static absl::StatusOr For(const GpufMHADescriptor& fmha_desc); PrimitiveType input_type; // Capture the primitive type of one of the inputs of BMM1 PrimitiveType output_type; @@ -116,7 +116,7 @@ struct GpufMHAConfig { // Structure to describe static properties of a GPU fused Multi-Headed // Attention backward. struct GpufMHABackwardConfig { - static StatusOr For( + static absl::StatusOr For( const GpufMHABackwardDescriptor& fmha_desc); PrimitiveType input_type; // Capture the primitive type of one of the inputs of BMM1 @@ -148,13 +148,15 @@ struct GpufMHABackwardConfig { // Implementation struct exposed for debugging and log analysis. struct GpufMHAParams { - static StatusOr For( + static absl::StatusOr For( const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, std::optional mask_buffer, std::optional bias_buffer, - std::optional activation_buffer); + std::optional activation_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer); const GpufMHAConfig* config; // Not owned se::DeviceMemoryBase lhs_bmm1_buffer; @@ -164,10 +166,12 @@ struct GpufMHAParams { std::optional activation_buffer; std::optional mask_buffer; std::optional bias_buffer; + std::optional seqlen_q_buffer; + std::optional seqlen_k_buffer; }; struct GpufMHABackwardParams { - static StatusOr For( + static absl::StatusOr For( const GpufMHABackwardConfig& config, se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, @@ -183,7 +187,9 @@ struct GpufMHABackwardParams { std::optional mask_buffer, std::optional d_bias_buffer, std::optional fwd_output_buffer, - std::optional bias_buffer); + std::optional bias_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer); const GpufMHABackwardConfig* config; // Not owned se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; @@ -201,6 +207,8 @@ struct GpufMHABackwardParams { std::optional d_bias_buffer; std::optional fwd_output_buffer; std::optional bias_buffer; + std::optional seqlen_q_buffer; + std::optional seqlen_k_buffer; }; class FusedMultiHeadedAttentionRunner { @@ -380,18 +388,20 @@ struct RunFusedMHABackwardOptions { FusedMultiHeadedAttentionBackwardRunner* runner_cache; }; -Status RunGpuFMHA(const GpufMHAConfig& fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional mask_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - se::Stream* stream, RunFusedMHAOptions = {}); - -Status RunGpuFMHABackward( +absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config, + se::DeviceMemoryBase lhs_bmm1_buffer, + se::DeviceMemoryBase rhs_bmm1_buffer, + se::DeviceMemoryBase rhs_bmm2_buffer, + se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase scratch_buffer, + std::optional mask_buffer, + std::optional bias_buffer, + std::optional activation_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer, + se::Stream* stream, RunFusedMHAOptions = {}); + +absl::Status RunGpuFMHABackward( const GpufMHABackwardConfig& fmha_config, se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, @@ -407,7 +417,9 @@ Status RunGpuFMHABackward( std::optional mask_buffer, std::optional d_bias_buffer, std::optional fwd_output_buffer, - std::optional bias_buffer, se::Stream* stream, + std::optional bias_buffer, + std::optional seqlen_q_buffer, + std::optional seqlen_k_buffer, se::Stream* stream, RunFusedMHABackwardOptions = {}); std::string ToString(const GpufMHAConfig& config); diff --git a/xla/service/gpu/gpu_fusible.cc b/xla/service/gpu/gpu_fusible.cc index c01182280ed11..02344a7fb1c27 100644 --- a/xla/service/gpu/gpu_fusible.cc +++ b/xla/service/gpu/gpu_fusible.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,21 +15,29 @@ limitations under the License. #include "xla/service/gpu/gpu_fusible.h" -#include -#include +#include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/reduction_utils.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" namespace xla { @@ -152,15 +160,13 @@ bool IsNestableVariadicReduction(const HloInstruction& instr) { } bool IsInputFusibleTranspose(const HloInstruction& instr) { - if (instr.opcode() == HloOpcode::kBitcast) { + if (instr.opcode() == HloOpcode::kBitcast || instr.IsCustomFusion()) { return false; } - auto& hero = FindNonTrivialHero(instr); - if (GetDescriptionForTiledTransposeEmitter(instr, hero).has_value()) { - return true; + if (instr.opcode() == HloOpcode::kFusion) { + return HasAnyTiledTransposeRoot(*instr.fused_instructions_computation()); } - return !instr.IsCustomFusion() && instr.opcode() == HloOpcode::kFusion && - HasAnyTiledTransposeRoot(*instr.called_computations()[0]); + return GetDescriptionForTiledTransposeEmitter(instr, instr).has_value(); } const HloInstruction* GetRealHeroForMultiOutputFusion( @@ -192,28 +198,6 @@ const HloInstruction* GetRealHeroForMultiOutputFusion( return fused_expression_root->operands()[0]; } -// Returns whether the output of a fusion with reduction are consistent with -// `first_reduce`. -static bool IsFusedReductionOutputConsistent( - const HloInstruction* inst, const HloInstruction* first_reduce) { - const auto& hero = FindNonTrivialHero(*inst); - if (IsRealReductionHero(*inst, hero)) { - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. - return ShapeUtil::EqualIgnoringElementType(first_reduce->shape(), - inst->shape()) && - ShapeUtil::EqualIgnoringElementType( - first_reduce->operand(0)->shape(), inst->operand(0)->shape()) && - ShapeUtil::EqualIgnoringElementType( - first_reduce->operand(1)->shape(), inst->operand(1)->shape()) && - first_reduce->dimensions() == inst->dimensions(); - } - return ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape()) && - LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout()); -} - FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, const HloInstruction* hero2) { auto hero1_is_unnested_reduce = @@ -228,7 +212,7 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, bool hero2_is_unnested_transpose = tiled_transpose_hero2.has_value(); if (hero1_is_unnested_reduce && hero2_is_unnested_reduce && - !IsFusedReductionOutputConsistent(hero2, hero1)) { + !AreReductionsMultiOutputFusionCompatible(hero2, hero1)) { return "tiled reductions with different shapes"; } else if (hero1_is_unnested_transpose && hero2_is_unnested_transpose && // After normalization to rank 3, the transposes should have the @@ -342,8 +326,7 @@ bool IsInputFusible(const HloInstruction& instr) { // Returns true if `instr` can be fused as a producer or as a consumer into a // kLoop fusion. -bool IsUniversallyLoopFusible(const HloInstruction& instr, - const HloInstruction& hero) { +bool IsUniversallyLoopFusible(const HloInstruction& instr) { // NOTE: this check is done before the switch below, because a fusion instr // can also be elementwise, even if it's not a kLoop. if (instr.IsElementwise() && instr.operand_count() > 0 && @@ -353,7 +336,7 @@ bool IsUniversallyLoopFusible(const HloInstruction& instr, switch (instr.opcode()) { case HloOpcode::kCopy: - return !GetDescriptionForTiledTransposeEmitter(instr, hero).has_value(); + return !GetDescriptionForTiledTransposeEmitter(instr, instr).has_value(); case HloOpcode::kFusion: return instr.fusion_kind() == HloInstruction::FusionKind::kLoop; @@ -377,8 +360,7 @@ bool IsUniversallyLoopFusible(const HloInstruction& instr, } // Returns true if `instr` can be fused as a consumer into a kLoop fusion. -bool IsLoopFusibleAsConsumer(const HloInstruction& instr, - const HloInstruction& hero) { +bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { // Instr should be fusible. if (!instr.IsFusible()) return false; @@ -390,12 +372,19 @@ bool IsLoopFusibleAsConsumer(const HloInstruction& instr, // Any reduction can be fused as a consumer. if (instr.opcode() == HloOpcode::kReduce) return true; - return IsUniversallyLoopFusible(instr, hero); + // We may have input fusions which effectively have turned into loop + // fusions. Those should still be considered as loop fusible consumers, + // but they are not universally loop fusible. + if (!IsInputFusible(instr) && instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kInput) { + return true; + } + + return IsUniversallyLoopFusible(instr); } // Returns true if `instr` can be fused as a producer into a kLoop fusion. -bool IsLoopFusibleAsProducer(const HloInstruction& instr, - const HloInstruction& hero) { +bool IsLoopFusibleAsProducer(const HloInstruction& instr) { // Instr should be fusible. if (!instr.IsFusible()) return false; @@ -407,7 +396,7 @@ bool IsLoopFusibleAsProducer(const HloInstruction& instr, // Non-variadic reductions can be fused as producers. return !instr.shape().IsTuple(); default: - return IsUniversallyLoopFusible(instr, hero); + return IsUniversallyLoopFusible(instr); } } @@ -452,12 +441,8 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, const HloInstruction& consumer) { - const auto& producer_hero = FindNonTrivialHero(producer); - const auto& consumer_hero = FindNonTrivialHero(consumer); - if (!IsLoopFusibleAsProducer(producer, producer_hero) && - !(GetDescriptionForTiledTransposeEmitter(producer, producer_hero) - .has_value() && - &consumer_hero == &producer)) { + if (!IsLoopFusibleAsProducer(producer) && + !IsInputFusibleTranspose(producer)) { return "the producer is not loop-fusible"; } @@ -468,12 +453,10 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, .xla_gpu_enable_reduction_epilogue_fusion()) { return "Reduction epilogue fusion is not enabled."; } - // TODO(akuegel): Remove workaround when producer_hero is computed - // correctly. const HloInstruction& reduce_hero = - producer_hero.opcode() == HloOpcode::kFusion - ? FindNonTrivialHero(*producer_hero.fused_expression_root()) - : producer_hero; + producer.opcode() == HloOpcode::kFusion + ? FindNonTrivialHero(*producer.fused_expression_root()) + : producer; if (!ReductionIsRaceFree( reduce_hero.GetModule()->config(), GetReductionKindAndContiguousComponents(reduce_hero))) { @@ -494,8 +477,7 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, return can_fuse; } - if (!IsInputFusible(consumer) && - !IsLoopFusibleAsConsumer(consumer, consumer_hero)) { + if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { return "the consumer is not input-fusible and not loop-fusible"; } @@ -556,7 +538,7 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { return "In-place operations are present"; } - if (!IsLoopFusibleAsProducer(producer, FindNonTrivialHero(producer))) { + if (!IsLoopFusibleAsProducer(producer)) { return "producer is not loop-fusible"; } @@ -567,11 +549,18 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { return {}; } -// Returns shared memory usage for a given instruction in bytes. +// Returns an estimate of the shared memory usage for a given instruction in +// bytes. static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { - // For now we are only fusing reductions. - if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + if (instr.opcode() == HloOpcode::kFusion) { + int64_t sum = 0; + for (const HloInstruction* hlo : + instr.fused_instructions_computation()->instructions()) { + sum += SharedMemoryUsageNoCache(*hlo); + } + return sum; + } else if (instr.opcode() == HloOpcode::kReduce && + IsReductionFromOrToContiguousDimensions(instr)) { ReductionDimensions reduction_info = GetReductionKindAndContiguousComponents(instr); int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType( @@ -586,20 +575,11 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { // from potential x-tiling). return 2 * 32 * 33 * primitive_size * num_variadic; } - } else if (GetDescriptionForTiledTransposeEmitter(instr, - FindNonTrivialHero(instr)) - .has_value()) { + } else if (GetDescriptionForTiledTransposeEmitter(instr, instr).has_value()) { // Tile size for transposition. int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type()); return 32 * 33 * primitive_size; - } else if (instr.opcode() == HloOpcode::kFusion) { - int64_t sum = 0; - for (const HloInstruction* hlo : - instr.fused_instructions_computation()->instructions()) { - sum += SharedMemoryUsageNoCache(*hlo); - } - return sum; } // Other fused expressions for now don't need the shared memory budget. return 0; @@ -624,25 +604,6 @@ int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { return it->second; } -int64_t ReductionProjectedShmemUsageBytes( - const ReductionDimensions& reduction_dimensions, - const std::vector>& instr_index_groups) { - int64_t out = 0; - // Different groups are computed in parallel on different blocks, so they are - // not sharing the shmem budget. The overall usage is given by the largest - // one. - for (const auto& group : instr_index_groups) { - int64_t sum = 0; - for (const HloInstruction* root : group) { - if (IsReductionFromOrToContiguousDimensions(*root)) { - sum += SharedMemoryUsage(*root); - } - } - out = std::max(out, sum); - } - return out; -} - // Codegen'ing unnested reductions requires a lot of registers, so a MOF // combining many of those runs a high risk of spilling. constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8; @@ -867,10 +828,11 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { instr.IsElementwise()); } -HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/, +HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer) { - return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput - : HloInstruction::FusionKind::kLoop; + return (IsInputFusible(consumer) || IsInputFusible(producer)) + ? HloInstruction::FusionKind::kInput + : HloInstruction::FusionKind::kLoop; } bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, @@ -940,14 +902,39 @@ std::vector GetFusionRoots( return out; } -bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero) { - if (!IsReductionFromOrToContiguousDimensions(hero)) { - return false; - } - return &root == &hero || - ReductionIsRaceFree(hero.GetModule()->config(), - GetReductionKindAndContiguousComponents(hero)); +bool IsTritonSoftmaxFusion(const HloInstruction& instr) { + return instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kCustom && + instr.backend_config().ok() && + instr.backend_config() + ->fusion_backend_config() + .kind() == kTritonSoftmaxFusionKind; +} + +bool MayPreventVectorization(const HloFusionAdaptor& fusion) { + // An empirically chosen constant: unrolling concat with a large amount of + // arguments causes excessive register spilling. + static constexpr int kMaxConcatArgumentsForUnrolling = 10; + return HloAnyOf(fusion.GetRoots(), fusion, [&](auto node) { + switch (node.opcode()) { + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + case HloOpcode::kSin: + case HloOpcode::kCos: + case HloOpcode::kTan: + case HloOpcode::kPower: + case HloOpcode::kAtan2: + return true; + case HloOpcode::kConcatenate: + return node.instruction().operand_count() > + kMaxConcatArgumentsForUnrolling; + case HloOpcode::kReduce: + return node.instruction().shape().tuple_shapes_size() > 1; + default: + return false; + } + }); } } // namespace gpu diff --git a/xla/service/gpu/gpu_fusible.h b/xla/service/gpu/gpu_fusible.h index 897b44bb301fc..79fef3b70ba31 100644 --- a/xla/service/gpu/gpu_fusible.h +++ b/xla/service/gpu/gpu_fusible.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_FUSIBLE_H_ #define XLA_SERVICE_GPU_GPU_FUSIBLE_H_ +#include +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/reduction_utils.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" @@ -63,12 +67,7 @@ struct FusionInfoCache { int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache = nullptr); -// Returns projected shared memory usage of a reduction fusion. -int64_t ReductionProjectedShmemUsageBytes( - const ReductionDimensions& reduction_dimensions, - const std::vector>& instr_index_groups); - -inline constexpr int64_t MaxOperandsAndOutputsPerFusion() { return 64; } +inline constexpr int64_t MaxOperandsAndOutputsPerFusion() { return 96; } // Whether the op transposes the physical data layout. Fusing such ops may lead // to uncoalesced data access and may thus not be beneficial. @@ -168,7 +167,7 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer); // a producer-consumer multi-output fusion. bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); -// Determines the fusion kind to be used when fusing `producer` and `consumer`. +// Determines the fusion kind to be used when fusing into `consumer`. HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer); @@ -207,9 +206,12 @@ size_t GetOutputSizeOfFusible(const HloInstruction& instr); std::vector GetFusionRoots( const HloComputation& computation); -// Whether the instruction is a reduction hero for the given root. -bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero); +// Whether the instruction is a Triton Softmax fusion. +bool IsTritonSoftmaxFusion(const HloInstruction& instr); + +// Whether the fusion will likely behave poorly with vectorization due to the +// instructions it contains. +bool MayPreventVectorization(const HloFusionAdaptor& fusion); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_fusible_test.cc b/xla/service/gpu/gpu_fusible_test.cc index 172d18122c33d..acddac32f94a2 100644 --- a/xla/service/gpu/gpu_fusible_test.cc +++ b/xla/service/gpu/gpu_fusible_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,15 @@ limitations under the License. #include "xla/service/gpu/gpu_fusible.h" +#include + +#include #include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -1001,6 +1006,51 @@ TEST_F(GpuFusibleTest, ProducerConsumerFusionElementwiseAndReduce) { EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*producer, *consumer)); } +TEST_F(GpuFusibleTest, ProducerConsumerFusionTransposeAndLoopFusion) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[32,31,30]{2,1,0} parameter(0) + p1.1 = f32[32,31,30]{2,1,0} parameter(1) + neg = f32[32,31,30]{2,1,0} negate(p0.1) + ROOT add = f32[32,31,30]{2,1,0} add(neg, p1.1) + } + + ENTRY reduce { + p0 = f32[32,31,30]{2,1,0} parameter(0) + p1 = f32[32,30,31]{2,1,0} parameter(1) + transpose = f32[32,31,30]{2,1,0} transpose(p1), dimensions={0,2,1} + ROOT add = f32[32,31,30]{2,1,0} fusion(p0, transpose), kind=kLoop, calls=fused_add + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* consumer = root; + const HloInstruction* producer = root->operand(1); + EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer)); +} + +TEST_F(GpuFusibleTest, ProducerConsumerFusionReduceAndLoopFusion) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[32,31,30]{2,1,0} parameter(0) + p1.1 = f32[32,31,30]{2,1,0} parameter(1) + neg = f32[32,31,30]{2,1,0} negate(p0.1) + ROOT add = f32[32,31,30]{2,1,0} add(neg, p1.1) + } + + ENTRY reduce { + p0 = f32[32,31,30]{2,1,0} parameter(0) + p1 = f32[32,31,30,29]{3,2,1,0} parameter(1) + c0 = f32[] constant(0.0) + reduce = f32[32,31,30]{2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add + ROOT add = f32[32,31,30]{2,1,0} fusion(p0, reduce), kind=kLoop, calls=fused_add + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* consumer = root; + const HloInstruction* producer = root->operand(1); + EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer)); +} + TEST_F(GpuFusibleTest, ProducerConsumerFusionLoopFusionAndReduce) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( fused_add { @@ -1403,5 +1453,22 @@ TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_FusionInstr) { EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); } +TEST_F(GpuFusibleTest, ChooseFusionKind) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + +ENTRY computation { + p = f32[5000,6000]{1,0} parameter(0) + c = f32[6000,5000] transpose(p), dimensions={1,0} + ROOT r = f32[300,20,5000] reshape(c) +} +)") + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* producer = root->operand(0); + EXPECT_EQ(ChooseFusionKind(*producer, *root), + HloInstruction::FusionKind::kInput); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 3023e7d6db5a7..0f5d9809ebae7 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,63 +15,11 @@ limitations under the License. #include "xla/service/gpu/gpu_hlo_schedule.h" -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/service/buffer_value.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_schedule_postprocessing.h" -#include "xla/service/gpu/model/analytical_latency_estimator.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_pass_pipeline.h" -#include "xla/service/latency_hiding_scheduler.h" -#include "xla/service/p2p_schedule_preparation.h" -#include "xla/service/profile_guided_latency_estimator.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" -#include "xla/util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" - namespace xla { namespace gpu { namespace { -bool IsSyncCollective(const HloInstruction& instr) { - auto backend_config = instr.backend_config().value(); - return backend_config.is_sync(); -} - bool IsNopInstruction(const HloInstruction& hlo) { HloOpcode op = hlo.opcode(); return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || @@ -83,7 +31,7 @@ bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) { switch (instr.opcode()) { case HloOpcode::kAllReduceStart: case HloOpcode::kCollectivePermuteStart: - return !IsSyncCollective(instr); + return !IsSyncCollective(&instr); case HloOpcode::kCustomCall: return static_cast(instr) .custom_call_schedule() == @@ -139,21 +87,24 @@ HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible( earliest_scheduled.push_back(instr); scheduled.insert(instr); }; + for (HloInstruction* instr : input.instructions()) { - if (is_scheduled(instr)) { - continue; - } + if (is_scheduled(instr)) continue; add_to_schedule(instr); // Schedule any successor that should be scheduled as early as possible if // all of its producers and control_predecessors have been scheduled. for (HloInstruction* user : instr->users()) { + if (is_scheduled(user)) continue; + if (ShouldScheduleSuccessor(*user, is_scheduled)) { add_to_schedule(user); } } for (HloInstruction* successor : instr->control_successors()) { + if (is_scheduled(successor)) continue; + if (ShouldScheduleSuccessor(*successor, is_scheduled)) { add_to_schedule(successor); } @@ -173,20 +124,22 @@ HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible( }; for (auto it = earliest_scheduled.rbegin(); it != earliest_scheduled.rend(); it++) { - if (is_scheduled(*it)) { - continue; - } + if (is_scheduled(*it)) continue; add_to_schedule(*it); // Schedule any predecessor that should be scheduled as late as possible // if all of its users and control_successors have been scheduled. for (HloInstruction* operand : (*it)->operands()) { + if (is_scheduled(operand)) continue; + if (ShouldSchedulePredecessor(*operand, is_scheduled)) { add_to_schedule(operand); } } for (HloInstruction* predecessor : (*it)->control_predecessors()) { + if (is_scheduled(predecessor)) continue; + if (ShouldSchedulePredecessor(*predecessor, is_scheduled)) { add_to_schedule(predecessor); } @@ -197,6 +150,12 @@ HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible( HloInstructionSequence result; absl::c_for_each(latest_scheduled, [&](HloInstruction* i) { result.push_back(i); }); + + // Schedule post-processing can't introduce new instructions. + CHECK(input.instructions().size() == result.size()) + << "schedule as early or late post-processing changed schedule size from " + << input.instructions().size() << " to " << result.size(); + return result; } @@ -205,29 +164,39 @@ HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible( HloInstructionSequence PostprocessorToScheduleSyncCollectives( const HloInstructionSequence& input) { HloInstructionSequence result; - auto is_synchronous_op = [](const HloInstruction* instr) { - return hlo_query::IsAsyncCollectiveStartOp(instr->opcode(), + + // Returns true if `inst` is a synchronous version of async collective start + // operation (marked with `is_sync` attribute). + auto is_sync_start = [](const HloInstruction* instr) { + return hlo_query::IsAsyncCollectiveStartOp(instr, /*include_send_recv=*/true) && - IsSyncCollective(*instr); + IsSyncCollective(instr); }; + for (HloInstruction* instr : input.instructions()) { - if (is_synchronous_op(instr)) { - continue; - } - if (hlo_query::IsAsyncCollectiveDoneOp(instr->opcode(), - /*include_send_recv=*/true)) { - // Place the start op just before the done op if its synchronous. + // Skip synchronous start instruction as it will be scheduled later when + // we'll process corresponding done instruction. + if (is_sync_start(instr)) continue; + + // Find a start instruction corresponding to done and schedule it right + // before a done if it's a synchronous version. + if (hlo_query::IsAsyncCollectiveDoneOp(instr, true)) { HloInstruction* start = instr->mutable_operand(0); - if (is_synchronous_op(start)) { - result.push_back(start); - } + if (is_sync_start(start)) result.push_back(start); } + result.push_back(instr); } + + // Schedule post-processing can't introduce new instructions. + CHECK(input.instructions().size() == result.size()) + << "sync collectives post-processing changed schedule size from " + << input.instructions().size() << " to " << result.size(); + return result; } -StatusOr ScheduleGpuModuleWithMemoryScheduler( +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( const HloModule* module, int64_t pointer_size) { return ScheduleModule( module, @@ -238,26 +207,10 @@ StatusOr ScheduleGpuModuleWithMemoryScheduler( PostProcessSchedule)); } -// Latency hiding scheduler support. - -CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { - switch (hlo.opcode()) { - case HloOpcode::kSend: - return {HloOpcode::kAsyncStart, HloOpcode::kSend}; - case HloOpcode::kSendDone: - return {HloOpcode::kAsyncDone, HloOpcode::kSend}; - case HloOpcode::kRecv: - return {HloOpcode::kAsyncStart, HloOpcode::kRecv}; - case HloOpcode::kRecvDone: - return {HloOpcode::kAsyncDone, HloOpcode::kRecv}; - default: - return DefaultGetCanonicalAsyncOp(hlo); - } -} - SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { SchedulerConfig config; config.all_reduce_overlap_limit = 1; + config.collective_broadcast_overlap_limit = 1; config.collective_permute_overlap_limit = 1; config.use_real_cost_model = false; config.aggressive_scheduling_policies = true; @@ -271,14 +224,48 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { // We use two different set of resources to model the scheduling of asynchronous // collective operations and P2P Send and Recv operations. This corresponds to // the fact that the runtime use a stream to run asynchronous collective -// operations and another stream to run P2P Send and Recv operations. +// operations and two other streams to run P2P Send and Recv operations. enum class GpuResourceType { - kGpuAsyncStreamSend = 0, // The resource for P2P Send operation. - kGpuAsyncStreamRecv = 1, // The resource for P2P Recv operation. - kGpuAsyncStreamCollectives = 2, // The resource for collective operations. - kNumTargetResources = 3, + kGpuAsyncStreamSend0 = 0, // A resource for P2P Send operation. + kGpuAsyncStreamSend1 = 1, // Another resource for P2P Send operation. + kGpuAsyncStreamRecv0 = 2, // A resource for P2P Recv operation. + kGpuAsyncStreamRecv1 = 3, // Another resource for P2P Recv operation. + kGpuAsyncStreamCollectives = 4, // The resource for collective operations. + kGpuAsyncStreamComputes = 5, // The resource for async compute operations. + kNumTargetResources = 6, }; +// Returns the pipeline stream for a P2P instruction recorded in a frontend +// attribute. +int64_t GetPipelineStream(const HloInstruction& start) { + auto it = start.frontend_attributes().map().find(kSendRecvPipelineAttr); + if (it != start.frontend_attributes().map().end() && it->second == "1") { + return 1; + } + return 0; +} + +// Returns the resource type and resource usage for a P2P instruction. +std::pair GetP2PResourceAndUsage( + const HloInstruction& instr, const CanonicalAsyncOp& op) { + ResourceUsageType usage = op.outer == HloOpcode::kAsyncStart + ? ResourceUsageType::kResourceRelease + : ResourceUsageType::kResourceOccupy; + int64_t pipeline = GetPipelineStream(instr); + HloOpcode opcode = op.inner; + GpuResourceType resource; + if (pipeline == 0) { + resource = opcode == HloOpcode::kSend + ? GpuResourceType::kGpuAsyncStreamSend0 + : GpuResourceType::kGpuAsyncStreamRecv0; + } else { + resource = opcode == HloOpcode::kSend + ? GpuResourceType::kGpuAsyncStreamSend1 + : GpuResourceType::kGpuAsyncStreamRecv1; + } + return {resource, usage}; +} + // Base GPU async tracker that enables async tracking only for async collectives // that are marked for async execution. class GpuAsyncTrackerBase : public AsyncTracker { @@ -290,17 +277,54 @@ class GpuAsyncTrackerBase : public AsyncTracker { GetCanonicalAsyncOpFunc func = GpuGetCanonicalAsyncOp) : AsyncTracker(config, func) {} + bool IsAsyncComputeOp(const HloInstruction& hlo) const { + return (hlo.opcode() == HloOpcode::kAsyncStart || + hlo.opcode() == HloOpcode::kAsyncDone) && + !hlo_query::IsCollectiveCommunicationOp( + hlo.async_wrapped_opcode()) && + hlo.async_execution_thread() != hlo.parent()->execution_thread(); + } + bool IsSupportedAsyncDone(const HloInstruction& hlo) const override { - return hlo_query::IsAsyncCollectiveDoneOp(hlo.opcode(), - /*include_send_recv=*/true) && - !IsSyncCollective(*hlo.operand(0)); + return (hlo_query::IsAsyncCollectiveDoneOp(&hlo, + /*include_send_recv=*/true) && + !IsSyncCollective(hlo.operand(0))) || + IsAsyncComputeOp(hlo); } // Returns if this is an Async op start that the scheduler supports. bool IsSupportedAsyncStart(const HloInstruction& hlo) const override { - return hlo_query::IsAsyncCollectiveStartOp(hlo.opcode(), - /*include_send_recv=*/true) && - !IsSyncCollective(hlo); + return (hlo_query::IsAsyncCollectiveStartOp(&hlo, + /*include_send_recv=*/true) && + !IsSyncCollective(&hlo)) || + IsAsyncComputeOp(hlo); + } + + void PostProcessScheduleGraph( + HloScheduleGraph* schedule_graph, + const LatencyEstimator* latency_estimator) const override { + for (auto inst : schedule_graph->GetOriginalInstrList()) { + // Force pipelined Recv to be closed to Recvdone so that copies inserted + // for RecvDone can be eliminated. + if (inst->opcode() == HloOpcode::kRecv) { + if (inst->frontend_attributes().map().count(kSendRecvPipelineAttr) > + 0) { + HloGraphNode& node = schedule_graph->GetNode(inst); + node.SetForceEarly(true); + VLOG(5) << "Setting force early for instruction: " + << inst->ToString(); + } + } + if (inst->has_backend_config()) { + auto gpu_config = inst->backend_config(); + if (gpu_config.ok()) { + HloGraphNode& node = schedule_graph->GetNode(inst); + node.SetForceDelay(gpu_config->force_earliest_schedule()); + VLOG(5) << "Setting force delay for instruction: " + << inst->ToString(); + } + } + } } }; @@ -315,24 +339,21 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { CanonicalAsyncOp op = GetCanonicalAsyncOp(instr); if (op.outer == HloOpcode::kAsyncStart || op.outer == HloOpcode::kAsyncDone) { - ResourceUsageType usage = op.outer == HloOpcode::kAsyncStart - ? ResourceUsageType::kResourceRelease - : ResourceUsageType::kResourceOccupy; - ResourcesVector resources; - auto add_resource = [&](GpuResourceType resource_type) { - const int64_t gpu_stream_resource = GetFirstTargetDefinedResource() + - static_cast(resource_type); - resources.push_back(std::make_pair(gpu_stream_resource, usage)); - }; - - if (op.inner == HloOpcode::kSend) { - add_resource(GpuResourceType::kGpuAsyncStreamSend); - } else if (op.inner == HloOpcode::kRecv) { - add_resource(GpuResourceType::kGpuAsyncStreamRecv); + ResourceUsageType usage; + GpuResourceType resource; + if (op.inner == HloOpcode::kSend || op.inner == HloOpcode::kRecv) { + std::tie(resource, usage) = GetP2PResourceAndUsage(instr, op); } else { - add_resource(GpuResourceType::kGpuAsyncStreamCollectives); + usage = op.outer == HloOpcode::kAsyncStart + ? ResourceUsageType::kResourceRelease + : ResourceUsageType::kResourceOccupy; + resource = hlo_query::IsCollectiveCommunicationOp(op.inner) + ? GpuResourceType::kGpuAsyncStreamCollectives + : GpuResourceType::kGpuAsyncStreamComputes; } - return resources; + return {std::make_pair( + GetFirstTargetDefinedResource() + static_cast(resource), + usage)}; } return GpuAsyncTrackerBase::GetResourcesFromInstruction(instr); } @@ -361,6 +382,15 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { // async stream, we can increase this number and then do a post-pass on the // scheduled code to assign async stream-id to collectives (and actually // support > 1 async stream in the runtime). + // The only case we'd allow 2 for now is when the current resource is + // for an async computation operation which will be allocated with + // a dedicated compute stream. It can run concurrently with + // another collective. + if ((resource_type - first_target_resource) == + static_cast(GpuResourceType::kGpuAsyncStreamComputes)) { + return 2; + } + return 1; } @@ -373,12 +403,18 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { first_target_resource + GetNumTargetDefinedResources()); switch ( static_cast(resource_type - first_target_resource)) { - case GpuResourceType::kGpuAsyncStreamSend: - return "kGpuAsyncStreamSend"; - case GpuResourceType::kGpuAsyncStreamRecv: - return "kGpuAsyncStreamRecv"; + case GpuResourceType::kGpuAsyncStreamSend0: + return "kGpuAsyncStreamSend0"; + case GpuResourceType::kGpuAsyncStreamSend1: + return "kGpuAsyncStreamSend1"; + case GpuResourceType::kGpuAsyncStreamRecv0: + return "kGpuAsyncStreamRecv0"; + case GpuResourceType::kGpuAsyncStreamRecv1: + return "kGpuAsyncStreamRecv1"; case GpuResourceType::kGpuAsyncStreamCollectives: return "kGpuAsyncStreamCollectives"; + case GpuResourceType::kGpuAsyncStreamComputes: + return "kGpuAsyncStreamComputes"; default: return "kUnsupportedResource"; } @@ -394,6 +430,53 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { first_target_resource + GetNumTargetDefinedResources()); return ResourceHazardType::kUnshareable; } + + int64_t GetNumResourcesPerInstruction( + int64_t resource_type, const HloInstruction& instr) const override { + int64_t num_resources = GpuAsyncTrackerBase::GetNumResourcesPerInstruction( + resource_type, instr); + if (num_resources <= 0 || instr.opcode() != HloOpcode::kWhile) { + return num_resources; + } + // For while-loop with pipelined Send/Recv, the while-body first releases + // the Send/Recv resource and then uses the resource. Therefore, subtract 1 + // from num_resources for the relevant resource type. + int64_t first_p2p_resource = + GetFirstTargetDefinedResource() + + static_cast(GpuResourceType::kGpuAsyncStreamSend0); + if (resource_type < first_p2p_resource || + resource_type > first_p2p_resource + 4) { + return num_resources; + } + auto find_instruction_for_pipeline = [&](HloOpcode opcode, + int64_t pipeline) { + for (auto user1 : instr.users()) { + if (user1->opcode() == HloOpcode::kGetTupleElement) { + for (auto user2 : user1->users()) { + if (user2->opcode() == opcode) { + if (GetPipelineStream(*user2) == pipeline) { + return true; + } + } + } + } + } + return false; + }; + bool found; + // Look into the users of the while-result to find pipelined Send-done or + // Recv-done. + if (resource_type == first_p2p_resource) { + found = find_instruction_for_pipeline(HloOpcode::kSendDone, 0); + } else if (resource_type == first_p2p_resource + 1) { + found = find_instruction_for_pipeline(HloOpcode::kSendDone, 1); + } else if (resource_type == first_p2p_resource + 2) { + found = find_instruction_for_pipeline(HloOpcode::kRecvDone, 0); + } else { + found = find_instruction_for_pipeline(HloOpcode::kRecvDone, 1); + } + return num_resources - (found ? 1 : 0); + } }; class GpuLatencyEstimator : public ApproximateLatencyEstimator { @@ -543,16 +626,27 @@ std::optional ReadPGLEProfile( const std::string& text_path, const std::string& binary_path) -> std::optional { - Status s = tsl::ReadTextProto(env, text_path, &profile); - if (s.ok()) { - LOG(INFO) << "Using PGLE profile from " << text_path; - return GetProfileForFingerprint(profile, fingerprint); + if (env->FileExists(text_path).ok()) { + absl::Status s = tsl::ReadTextProto(env, text_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << text_path; + return GetProfileForFingerprint(profile, fingerprint); + } else { + LOG(ERROR) << "Unable to read PGLE text proto from " << text_path + << ": " << s.message(); + } + profile.Clear(); } - profile.Clear(); - s = tsl::ReadBinaryProto(env, binary_path, &profile); - if (s.ok()) { - LOG(INFO) << "Using PGLE profile from " << binary_path; - return GetProfileForFingerprint(profile, fingerprint); + if (env->FileExists(binary_path).ok()) { + absl::Status s = tsl::ReadBinaryProto(env, binary_path, &profile); + if (s.ok()) { + LOG(INFO) << "Using PGLE profile from " << binary_path; + return GetProfileForFingerprint(profile, fingerprint); + } else { + LOG(ERROR) << "Unable to read PGLE binary proto from " << binary_path + << ": " << s.message(); + } + profile.Clear(); } return std::nullopt; }; @@ -567,14 +661,22 @@ std::optional ReadPGLEProfile( } // The pgle_profile_file_or_dir is a file. Attempt to read the profile as text - // proto or binary proto. - return read_text_or_binary_profile(pgle_profile_file_or_dir_path, - pgle_profile_file_or_dir_path); + // proto or binary proto. Attempt to infer the file type based on the + // extension. + auto extension = tsl::io::Extension(pgle_profile_file_or_dir_path); + if (extension == "pbtxt") { + return read_text_or_binary_profile(pgle_profile_file_or_dir_path, ""); + } else if (extension == "pb") { + return read_text_or_binary_profile("", pgle_profile_file_or_dir_path); + } else { + return read_text_or_binary_profile(pgle_profile_file_or_dir_path, + pgle_profile_file_or_dir_path); + } } // Return true if the profile is applicable to the module. That is true if every // instruction in the profile is present in the module. -Status IsProfileApplicable( +absl::Status IsProfileApplicable( const HloModule* module, const tensorflow::profiler::ProfiledInstructionsProto& profile) { absl::flat_hash_set instruction_names; @@ -583,11 +685,15 @@ Status IsProfileApplicable( instruction_names.insert(instr->name()); } } - + int64_t total_instruction_count = instruction_names.size(); + int64_t cost_miss_count; + int64_t cost_hit_count; for (const auto& cost : profile.costs()) { if (!instruction_names.contains(cost.name())) { + cost_miss_count++; // profile inst name not in this module return absl::InvalidArgumentError(absl::StrFormat( "cost name %s not in module %s", cost.name(), module->name())); + } else { } } for (const auto& latency : profile.latencies()) { @@ -601,14 +707,14 @@ Status IsProfileApplicable( "cost name %s not in module %s", latency.target(), module->name())); } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { int64_t size = ShapeUtil::ByteSizeOf(shape, pointer_size); - if (shape.is_static() || shape.IsTuple()) { + if (shape.IsTuple() || shape.is_static()) { return size; } // Each dynamic dimension size is represented as a S32. @@ -616,12 +722,15 @@ int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { return size + metadata_size; } -Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, - int64_t memory_limit, - const se::DeviceDescription& gpu_device_info) { +absl::StatusOr ScheduleGpuModule( + HloModule* module, int64_t pointer_size, + const se::DeviceDescription& gpu_device_info) { + int64_t memory_limit = + GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size); if (module->has_schedule()) { - return OkStatus(); + return ScheduleMetadata{memory_limit}; } + HloPassPipeline prepare_pipeline("p2p-schedule-preparation"); prepare_pipeline.AddPass(); TF_RETURN_IF_ERROR(prepare_pipeline.Run(module).status()); @@ -635,10 +744,9 @@ Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, // instruction name with ids. std::string fingerprint = module->GetFingerprint128( HloPrintOptions::Canonical().set_print_backend_config(true)); - HloInstruction* root = module->entry_computation()->root_instruction(); FrontendAttributes attributes; (*attributes.mutable_map())[std::string(kFingerprintBeforeLHS)] = fingerprint; - root->add_frontend_attributes(attributes); + module->add_frontend_attributes(attributes); VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; @@ -648,7 +756,7 @@ Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, .xla_gpu_enable_latency_hiding_scheduler(); if (!enable_latency_hiding_scheduler) { - return OkStatus(); + return ScheduleMetadata{memory_limit}; } SchedulerConfig config = GetSchedulerConfig(memory_limit); @@ -663,10 +771,25 @@ Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, .debug_options() .xla_gpu_enable_analytical_latency_estimator(); if (profile.has_value()) { - latency_estimator = std::make_unique( - config, std::move(gpu_latency_estimator), profile.value()); + if (enable_analytical_latency_estimator) { + auto backup_latency_estimator = + std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + module->entry_computation()); + LOG(INFO) << "Using analytical latency estimator as PGLE backup"; + latency_estimator = std::make_unique( + config, std::move(backup_latency_estimator), profile.value()); + } else { + LOG(INFO) << "Using gpu_latency_estimator as PGLE backup"; + latency_estimator = std::make_unique( + config, std::move(gpu_latency_estimator), profile.value()); + } + LOG(INFO) << "Found profile, using profile guided latency estimator"; - Status s = IsProfileApplicable(module, profile.value()); + absl::Status s = IsProfileApplicable(module, profile.value()); if (!s.ok()) { LOG(INFO) << "PGLE profile may not applicable to the module, but will " "still be used : " @@ -695,21 +818,33 @@ Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, auto shape_size_in_bytes = [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }; + const bool enable_linear_program_scheduler = + module->config() + .debug_options() + .xla_gpu_enable_linear_program_scheduler(); + HloPassPipeline pipeline("latency-hiding-scheduler"); + auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config); - pipeline.AddPass( - std::move(latency_estimator), std::move(async_tracker), - std::move(scheduler_core), shape_size_in_bytes); + if (enable_linear_program_scheduler) { + pipeline.AddPass( + std::move(latency_estimator), std::move(async_tracker), + std::move(scheduler_core), shape_size_in_bytes); + } else { + pipeline.AddPass( + std::move(latency_estimator), std::move(async_tracker), + std::move(scheduler_core), shape_size_in_bytes); + } TF_RETURN_IF_ERROR(pipeline.Run(module).status()); HloPassPipeline postprocessing_pipeline("gpu-schedule-postprocessing"); postprocessing_pipeline.AddPass(); TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status()); - return OkStatus(); + return ScheduleMetadata{memory_limit}; } HloInstructionSequence PostProcessSchedule( diff --git a/xla/service/gpu/gpu_hlo_schedule.h b/xla/service/gpu/gpu_hlo_schedule.h index 792dfbd2633c1..74262cbf8c6a5 100644 --- a/xla/service/gpu/gpu_hlo_schedule.h +++ b/xla/service/gpu/gpu_hlo_schedule.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,27 +16,87 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ #define XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/stream_executor/device_description.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/buffer_value.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/hlo/experimental/auto_reorder/auto_reorder.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/gpu_schedule_postprocessing.h" +#include "xla/service/gpu/model/analytical_latency_estimator.h" +#include "xla/service/hlo_memory_scheduler.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/service/p2p_schedule_preparation.h" +#include "xla/service/profile_guided_latency_estimator.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/path.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { int64_t GetSizeOfShape(const Shape& shape, int pointer_size); +struct ScheduleMetadata { + int64_t scheduler_mem_limit; +}; + // Determines the schedule of HLO instructions for a module run on the GPU. -Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, - int64_t memory_limit, - const se::DeviceDescription& gpu_device_info); +absl::StatusOr ScheduleGpuModule( + HloModule* module, int64_t pointer_size, + const se::DeviceDescription& gpu_device_info); +int64_t GetSchedulerMemoryLimit( + const HloModule* module, const se::DeviceDescription& gpu_device_info, + int pointer_size); HloInstructionSequence PostProcessSchedule(const HloInstructionSequence& input); -int64_t GetSchedulerMemoryLimit(const HloModule* module, - const se::DeviceDescription& gpu_device_info, - int pointer_size); - constexpr absl::string_view kFingerprintBeforeLHS = "fingerprint_before_lhs"; } // namespace gpu + + } // namespace xla #endif // XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ diff --git a/xla/service/gpu/gpu_hlo_schedule_test.cc b/xla/service/gpu/gpu_hlo_schedule_test.cc index e952c0c9d8b07..4edab874aaf51 100644 --- a/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -60,10 +61,8 @@ class GpuHloScheduleTest : public HloTestBase { Backend& test_backend = backend(); const se::DeviceDescription& gpu_device_info = test_backend.default_stream_executor()->GetDeviceDescription(); - TF_CHECK_OK(ScheduleGpuModule( - module, /*pointer_size=*/8, - /*memory_limit=*/gpu_device_info.device_memory_size() * 8 / 10, - gpu_device_info)); + TF_CHECK_OK(ScheduleGpuModule(module, /*pointer_size=*/8, gpu_device_info) + .status()); return SequentialHloOrdering{module->schedule()}; } @@ -89,9 +88,7 @@ class GpuHloScheduleTest : public HloTestBase { static bool HasValidFingerprint(HloModule* module) { // Verify that the fingerprint of HLO prior to LHS is present. - const HloInstruction* root = - module->entry_computation()->root_instruction(); - const FrontendAttributes& attrs = root->frontend_attributes(); + const FrontendAttributes& attrs = module->frontend_attributes(); auto it = attrs.map().find(kFingerprintBeforeLHS); // The fingerprint is 128 bits stored as a hex string (128/4 hex digits). @@ -579,7 +576,7 @@ TEST_F(GpuHloScheduleTest, LHSSendRecv) { EXPECT_LT(get_index("recv"), get_index("send")); EXPECT_LT(get_index("send"), get_index("recv-done")); - EXPECT_GE(get_index("send-done") - get_index("recv-done"), 9); + EXPECT_GE(get_index("send-done") - get_index("recv-done"), 8); EXPECT_LT(abs(get_index("send-done") - get_index("result")), 2); EXPECT_TRUE(HasValidFingerprint(module.get())); } @@ -680,8 +677,8 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPairs2) { EXPECT_LT(abs(get_index("send-done-0") - get_index("result")), 2); } -// Checks that asynchronous AllReduce is scheduled to interleave with the Send -// and Recv sequence. +// Checks that asynchronous AllReduce is scheduled to not interleave with the +// Send and Recv sequence. TEST_F(GpuHloScheduleTest, LHSSendRecvAllReduce) { const char* hlo_text = R"( HloModule test @@ -731,7 +728,7 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvAllReduce) { lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} all-reduce-start = f32[1, 1024, 1024] all-reduce-start(f32[1, 1024, 1024] p), - replica_groups={{0,1}}, to_apply=add, backend_config={"is_sync":false} + replica_groups={{0,1}}, to_apply=add, backend_config={"collective_backend_config":{"is_sync":false}} all-reduce-done = f32[1, 1024, 1024] all-reduce-done(f32[1, 1024, 1024] all-reduce-start) new-data = f32[1, 1024, 1024] add(s, all-reduce-done) ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, new-data) @@ -768,43 +765,178 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvAllReduce) { EXPECT_LT(get_index("recv"), get_index("send")); EXPECT_LT(get_index("send"), get_index("recv-done")); EXPECT_GE(get_index("send-done") - get_index("recv-done"), 3); - EXPECT_LT(get_index("send-done"), get_index("all-reduce-start")); + EXPECT_TRUE(get_index("send-done") < get_index("all-reduce-start") || + get_index("recv") > get_index("all-reduce-start")); EXPECT_TRUE(HasValidFingerprint(module.get())); } // Checks that with the dependence added by the gpu-hlo-scheduler, the -// pipelined Send and Recv instructions are scheduled correctly. -TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined) { +// pipelined one Send-Recv group is scheduled correctly. +TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { const char* hlo_text = R"( HloModule test while_cond { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 ub = u32[] constant(25) ROOT cond-result = pred[] compare(count, ub), direction=LT } -while_body { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) + while_body { + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - send-data = get-tuple-element(param), index=1 - recv-data = get-tuple-element(param), index=2 + + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 + + c1 = u32[] constant(1) + new-count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + p = f32[1, 1024, 1024] broadcast(conv), dimensions={} + b = f32[1, 1024, 1024] add(p, recv-data) + c = f32[1, 1024, 1024] multiply(b, b) + d = f32[1, 1024, 1024] tan(c) + s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + send-data = f32[1, 1024, 1024] add(c, s) after-all.1 = token[] after-all() send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.1 = token[] send-done(send.1), channel_id=1 + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) + } + + ENTRY main { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all.2 = token[] after-all() + recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + while-init = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(c0, recv-done.2, send-done.2) + while-result = (u32[], (f32[1,1024,1024], token[]), token[]) + while(while-init), + body=while_body, condition=while_cond, + backend_config={"known_trip_count":{"n":"25"}} + + recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1 + + ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + const std::vector& while_body = + order.SequentialOrder(*module->GetComputationWithName("while_body")) + ->instructions(); + const std::vector& main = + order.SequentialOrder(*module->GetComputationWithName("main")) + ->instructions(); + auto get_index = + [](absl::string_view hlo_name, + const std::vector& instruction_sequence) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); + }; + EXPECT_TRUE(HasValidFingerprint(module.get())); + + // The pipelined Send-Recv in the main. A pipelined Recv is scheduled right + // after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main)); + EXPECT_LT(get_index("send.2", main), get_index("recv-done.2", main)); + EXPECT_LT(get_index("recv-done.2", main), get_index("send-done.2", main)); + EXPECT_LT(get_index("send-done.2", main), get_index("while-result", main)); + + // The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled + // right after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.1", while_body) + 1, + get_index("send.1", while_body)); + EXPECT_LT(get_index("send.1", while_body), + get_index("recv-done.1", while_body)); + EXPECT_LT(get_index("recv-done.1", while_body), + get_index("send-done.1", while_body)); +} + +// Checks that with the dependence added by the gpu-hlo-scheduler, the +// pipelined two Send-Recv groups are scheduled correctly. +TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { + const char* hlo_text = R"( + HloModule test + + while_cond { + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(25) + ROOT cond-result = pred[] compare(count, ub), direction=LT + } + + while_body { + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + + recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0 + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3 + recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={} + recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1) c1 = u32[] constant(1) new-count = u32[] add(count, c1) - replica = u32[] replica-id() c10 = u32[] constant(10) sum = u32[] add(replica, c10) sum2 = u32[] add(sum, count) @@ -815,26 +947,51 @@ while_body { d = f32[1, 1024, 1024] tan(c) s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} - new-data-0 = f32[1, 1024, 1024] add(c, s) + send-data = f32[1, 1024, 1024] add(c, s) - recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1 - new-recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + after-all.0 = token[] after-all() + send.0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.0), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv.0 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (f32[1,1024,1024], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } - after-all.4 = token[] after-all() - send.4 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.4), - channel_id=4, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.4 = token[] send-done(send.4), channel_id=4 - recv.4 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.4), channel_id=4, + after-all.1 = token[] after-all() + send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), + channel_id=2, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - recv-done.4 = (f32[1, 1024, 1024], token[]) recv-done(recv.4), channel_id=4 - recv-data-4 = f32[1, 1024, 1024] get-tuple-element(recv-done.4), index=0 - new-data = f32[1, 1024, 1024] add(new-data-0, recv-data-4) + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } - ROOT body-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(new-count, new-data, new-recv-data) + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done.0, send-done.0, recv-done.1, send-done.1) } ENTRY main { @@ -845,24 +1002,60 @@ while_body { after-all.2 = token[] after-all() recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.3 = token[] after-all() + recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" } - recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2), channel_id=1 - recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0 + recv-done.3 = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.3 = token[] send-done(send.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } - while-init = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(c0, init, recv-data) - while-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) while(while-init), + while-init = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2, send-done.2, recv-done.3, send-done.3) + while-result = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) while(while-init), body=while_body, condition=while_cond, backend_config={"known_trip_count":{"n":"25"}} - send-data = f32[1, 1024, 1024] get-tuple-element(while-result), index=2 - send.2 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.2), - channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.2 = token[] send-done(send.2), channel_id=1 + recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1 + recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 + recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=3 + recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0 - ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={} + ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3) } )"; @@ -889,20 +1082,32 @@ while_body { }; EXPECT_TRUE(HasValidFingerprint(module.get())); - - // The pipelined Send-Recv in the main. - EXPECT_LT(get_index("recv-done.2", main), get_index("while-result", main)); - EXPECT_LT(get_index("while-result", main), get_index("send.2", main)); - - // The pipelined Send-Recv in the while-body. - EXPECT_LT(get_index("send.1", while_body), get_index("recv.1", while_body)); - - // The unpipelined Send-Recv in the while-body is scheduled after the - // pipelined Send-Done and before the pipelined Recv. - EXPECT_LT(get_index("send-done.1", while_body), - get_index("recv.4", while_body)); - EXPECT_LT(get_index("recv-done.4", while_body), - get_index("recv.1", while_body)); + // The pipelined Send-Recv in the main. A pipelined Recv is scheduled right + // after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main)); + EXPECT_LT(get_index("send.2", main), get_index("recv.3", main)); + EXPECT_EQ(get_index("recv.3", main) + 1, get_index("send.3", main)); + EXPECT_LT(get_index("send.3", main), get_index("recv-done.2", main)); + EXPECT_LT(get_index("recv-done.2", main), get_index("recv-done.3", main)); + EXPECT_LT(get_index("recv-done.3", main), get_index("send-done.2", main)); + EXPECT_LT(get_index("send-done.2", main), get_index("send-done.3", main)); + EXPECT_LT(get_index("send-done.3", main), get_index("while-result", main)); + + // The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled + // right after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.0", while_body) + 1, + get_index("send.0", while_body)); + EXPECT_LT(get_index("send.0", while_body), get_index("recv.1", while_body)); + EXPECT_EQ(get_index("recv.1", while_body) + 1, + get_index("send.1", while_body)); + EXPECT_LT(get_index("send.1", while_body), + get_index("recv-done.0", while_body)); + EXPECT_LT(get_index("recv-done.0", while_body), + get_index("recv-done.1", while_body)); + EXPECT_LT(get_index("recv-done.1", while_body), + get_index("send-done.0", while_body)); + EXPECT_LT(get_index("send-done.0", while_body), + get_index("send-done.1", while_body)); } TEST_F(GpuHloScheduleTest, SkipAlreadyScheduled) { @@ -927,9 +1132,9 @@ ENTRY e { })") .value(); TF_CHECK_OK(ScheduleGpuModule( - module.get(), /*pointer_size=*/8, - /*memory_limit=*/1024 * 1024 * 1024, - backend().default_stream_executor()->GetDeviceDescription())); + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status()); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: ENTRY // CHECK: wrapped_negate = f32[1024,1024]{1,0} @@ -937,6 +1142,77 @@ ENTRY e { )")); } +TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithForceEarliestSchedule) { + const char* hlo_text = R"( + HloModule AsyncAR + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY main { + p0 = f32[32] parameter(0) + p1 = f32[32, 32] parameter(1) + p2 = f32[32, 32] parameter(2) + p3 = f32[32] parameter(3) + + // Independent compute + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm", backend_config={"force_earliest_schedule":true} + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + // 2 Independent collectives. + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ROOT t = (f32[32], f32[32,32]) tuple(ar-done, add0) + })"; + + const std::string ar_long_latency_proto_text = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "dot1" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1000.0 } + )pb"; + + tensorflow::profiler::ProfiledInstructionsProto profile; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + ar_long_latency_proto_text, &profile)); + std::string ar_long_latency_proto_binary = profile.SerializeAsString(); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, + GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + // Post processing should work even with + // GpuAsyncTrackerBase. + /*enable_gpu_async_tracker=*/false, + /*fdo_profile=*/ar_long_latency_proto_binary))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + + const std::vector& main = + order.SequentialOrder(*module->GetComputationWithName("main")) + ->instructions(); + auto get_index = + [](absl::string_view hlo_name, + const std::vector& instruction_sequence) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); + }; + // Using the profile, LHS should schedule all computes between ar pair, + // but since dot0 is marked as force delay, it should be scheduled + // before ar-start now. + EXPECT_LT(get_index("dot0", main), get_index("ar-start", main)); + // Also verify that dot1 is scheduled between ar start and ar done. + EXPECT_GT(get_index("dot1", main), get_index("ar-start", main)); + EXPECT_LT(get_index("dot1", main), get_index("ar-done", main)); +} + class GpuHloScheduleParameterizedTest : public GpuHloScheduleTest, public ::testing::WithParamInterface {}; @@ -1071,11 +1347,10 @@ TEST_P(GpuHloScheduleParameterizedTest, LHSResourceModel) { uint32_t max_in_flight = 0; for (const HloInstruction* inst : order.SequentialOrder(*module->entry_computation())->instructions()) { - HloOpcode op = inst->opcode(); - if (hlo_query::IsAsyncCollectiveStartOp(op)) { + if (hlo_query::IsAsyncCollectiveStartOp(inst)) { in_flight++; max_in_flight = std::max(max_in_flight, in_flight); - } else if (hlo_query::IsAsyncCollectiveDoneOp(op)) { + } else if (hlo_query::IsAsyncCollectiveDoneOp(inst)) { in_flight--; } } @@ -1110,7 +1385,7 @@ TEST_F(GpuHloSchedulePostProcessTest, PostProcessAsyncCollectives) { // This will be sync, so we expect the start/done to be moved next to each // other. - ag-start = (f32[32], f32[64]) all-gather-start(p1), dimensions={0}, backend_config="{\"is_sync\":true}" + ag-start = (f32[32], f32[64]) all-gather-start(p1), dimensions={0}, backend_config="{\"collective_backend_config\":{\"is_sync\":true}}" add1 = f32[32] add(p1, p1) ag-done = f32[64] all-gather-done(ag-start) diff --git a/xla/service/gpu/gpu_layout_assignment.cc b/xla/service/gpu/gpu_layout_assignment.cc index ad2d7ffa62f00..0926ca82cf5e5 100644 --- a/xla/service/gpu/gpu_layout_assignment.cc +++ b/xla/service/gpu/gpu_layout_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,28 +16,45 @@ limitations under the License. #include "xla/service/gpu/gpu_layout_assignment.h" #include +#include #include #include #include #include +#include #include -#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/gpu_flash_attn.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/reduction_utils.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/status_macros.h" +#include "xla/service/logical_buffer.h" +#include "xla/shape.h" +#include "xla/shape_layout.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" +#include "xla/util.h" #include "xla/window_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -48,7 +65,8 @@ using se::dnn::FilterLayout; // Returns (input, filter, output) layouts. static std::tuple HeuristicLayoutAssignment(const HloInstruction* instr, - se::StreamExecutor* stream_executor) { + const se::GpuComputeCapability& gpu_version, + const se::dnn::VersionInfo& dnn_version) { // DataLayout and FilterLayout uses weird enum names. Translations: // N <=> Batch or Output // C <=> Depth or Input @@ -111,10 +129,12 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // If we're not Volta or not fp16/bfloat16, or not conv2D, the decision is // easy: Use NCHW. const bool isFloat16 = (input_ty == F16) || (input_ty == BF16); - if (!isFloat16 || - !stream_executor->GetDeviceDescription() - .cuda_compute_capability() - .IsAtLeast(se::CudaComputeCapability::VOLTA) || + const auto* cuda_compute_capability = + std::get_if(&gpu_version); + bool is_volta = + cuda_compute_capability && + cuda_compute_capability->IsAtLeast(se::CudaComputeCapability::VOLTA); + if (!isFloat16 || !is_volta || instr->shape().tuple_shapes(0).dimensions_size() != 4) { return kAllNCHW; } @@ -131,17 +151,11 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // * we've also observed that for mixed layouts, cuDNN transposes data back // and forth from a different layout combination. If we end up with // transposes anyway, we prefer to have them in XLA, as they can be fused. - if (auto* dnn = stream_executor->AsDnn()) { - auto version_status = dnn->GetVersion(); - if (version_status.ok()) { - auto version = std::move(version_status).value(); - if (std::make_tuple(version.major_version(), version.minor_version()) <= - std::make_tuple(7, 3) && - instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && - window_util::HasStride(instr->window())) { - return kAllNCHW; - } - } + if (std::make_tuple(dnn_version.major_version(), + dnn_version.minor_version()) <= std::make_tuple(7, 3) && + instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return kAllNCHW; } // For other Volta f16 convolutions, use NHWC. @@ -152,7 +166,7 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // constraints are represented in terms of minor_to_major fields of both // operands and the output shape. Depending on the underlying algorithm, one of // { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. -Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( +absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloCustomCallInstruction* instr, LayoutConstraints* constraints) { Shape lhs_shape = instr->operand(0)->shape(); Shape rhs_shape = instr->operand(1)->shape(); @@ -188,7 +202,7 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( FilterLayout filter; DataLayout output; std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); + HeuristicLayoutAssignment(instr, gpu_version_, dnn_version_); TF_ASSIGN_OR_RETURN( std::tie(*input_shape->mutable_layout(), @@ -230,13 +244,25 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation && kind != CudnnConvKind::kForwardGraph) { - return InternalError( + return Internal( "Invalid convolution. Conv has a side input, but kind is not fused " "conv forward or graph conv foward: %s", instr->ToString()); } - return OkStatus(); + return absl::OkStatus(); +} + +absl::Status GpuLayoutAssignment::AddBackendConstraintsToFlashAttnCustomCall( + HloCustomCallInstruction* instr, LayoutConstraints* constraints) { + // Make sure flash attn's operands and output are all contiguous + for (int64_t i = 0; i < instr->operand_count(); ++i) { + Shape op_shape = instr->operand(i)->shape(); + LayoutUtil::SetToDefaultLayout(&op_shape); + TF_RETURN_IF_ERROR(SetOperandLayout(op_shape, instr, i)); + } + TF_RETURN_IF_ERROR(SetInstructionLayout(instr->shape(), instr)); + return absl::OkStatus(); } namespace { @@ -269,7 +295,7 @@ bool DotCanSupportShapeWithLayout(const HloInstruction* dot, } // namespace -Status GpuLayoutAssignment::AddBackendConstraints( +absl::Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { // Add convolution constraints in reverse postorder that the earliest // convolution layout propagates first. This reduces the likelihood of fusion @@ -282,6 +308,10 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( Cast(instruction), constraints)); } + if (IsCustomCallToFlashAttn(*instruction)) { + TF_RETURN_IF_ERROR(AddBackendConstraintsToFlashAttnCustomCall( + Cast(instruction), constraints)); + } CHECK(!IsCublasGemm(*instruction)) << "Gemm rewriting should run after layout assignment"; @@ -446,12 +476,25 @@ Status GpuLayoutAssignment::AddBackendConstraints( ShapeUtil::MoveDimToMajor(all_to_all->shape(), *all_to_all->split_dimension()), all_to_all)); + } else if (instruction->opcode() == HloOpcode::kSend) { + Shape s = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&s); + TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0))); + TF_RETURN_IF_ERROR( + SetArrayOperandLayout(s.layout(), instruction->operand(0), 0)); + } else if (instruction->opcode() == HloOpcode::kRecv) { + Shape s = instruction->shape(); + ShapeUtil::ForEachMutableSubshape( + &s, [&](Shape* subshape, const ShapeIndex& index) { + LayoutUtil::SetToDefaultLayout(subshape); + }); + TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction)); } } - return OkStatus(); + return absl::OkStatus(); } -Status GpuLayoutAssignment::SetDotOperandLayout( +absl::Status GpuLayoutAssignment::SetDotOperandLayout( const HloInstruction* instruction, int64_t operand, absl::Span batch_dims, absl::Span row_dims, absl::Span col_dims) { @@ -474,7 +517,7 @@ Status GpuLayoutAssignment::SetDotOperandLayout( /*dim_groups=*/{batch_dims, row_dims, col_dims}); } -Status GpuLayoutAssignment::SetOperandMajorToMinorLayout( +absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout( const HloInstruction* instruction, int64_t operand, std::initializer_list> dim_groups) { size_t size = 0; @@ -491,8 +534,8 @@ Status GpuLayoutAssignment::SetOperandMajorToMinorLayout( return SetOperandLayout(shape, instruction, operand); } -Status GpuLayoutAssignment::SetDotLayout(const HloInstruction* instruction, - LayoutConstraints* constraints) { +absl::Status GpuLayoutAssignment::SetDotLayout( + const HloInstruction* instruction, LayoutConstraints* constraints) { // If a user has requested a layout that we can support, use that. for (const HloInstruction* user : instruction->users()) { for (int64_t i = 0; i < user->operand_count(); ++i) { @@ -515,13 +558,16 @@ Status GpuLayoutAssignment::SetDotLayout(const HloInstruction* instruction, bool GpuLayoutAssignment::PropagateReductionLayoutToOperand( const HloInstruction* user) { - // Propagating the layout is only beneficial if the total size of reduction - // dims is large enough. + // We try to propagate a layout to make the reduction a row reduction. But + // propagating the layout is only beneficial if the reduction emitter would be + // used for the row reduction. int64_t reduction_size = 1; for (int64_t reduction_dim : user->dimensions()) { reduction_size *= user->operand(0)->shape().dimensions(reduction_dim); } - return reduction_size >= 32; + int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape()); + return IsUnnestedReductionFasterThanElemental( + {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}}); } } // namespace gpu diff --git a/xla/service/gpu/gpu_layout_assignment.h b/xla/service/gpu/gpu_layout_assignment.h index 27ca9bcfb72be..5ac233e47afb6 100644 --- a/xla/service/gpu/gpu_layout_assignment.h +++ b/xla/service/gpu/gpu_layout_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,17 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #define XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/computation_layout.h" #include "xla/service/layout_assignment.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/status.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" namespace xla { namespace gpu { @@ -31,35 +37,41 @@ class GpuLayoutAssignment : public LayoutAssignment { public: explicit GpuLayoutAssignment( ComputationLayout* entry_computation_layout, - se::StreamExecutor* stream_executor, + const se::GpuComputeCapability& gpu_version, + const se::dnn::VersionInfo& dnn_version, ChannelLayoutConstraints* channel_constraints = nullptr) : LayoutAssignment(entry_computation_layout, channel_constraints), - stream_executor_(stream_executor) {} + gpu_version_(gpu_version), + dnn_version_(dnn_version) {} ~GpuLayoutAssignment() override = default; protected: - Status AddBackendConstraints(LayoutConstraints* constraints) override; + absl::Status AddBackendConstraints(LayoutConstraints* constraints) override; private: - Status AddBackendConstraintsToDnnConvCustomCall( + absl::Status AddBackendConstraintsToDnnConvCustomCall( + HloCustomCallInstruction* instr, LayoutConstraints* constraints); + absl::Status AddBackendConstraintsToFlashAttnCustomCall( HloCustomCallInstruction* instr, LayoutConstraints* constraints); // dim_groups are ordered from major to minor dimensions. - Status SetOperandMajorToMinorLayout( + absl::Status SetOperandMajorToMinorLayout( const HloInstruction* instruction, int64_t operand, std::initializer_list> dim_groups); - Status SetDotOperandLayout(const HloInstruction* instruction, int64_t operand, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); + absl::Status SetDotOperandLayout(const HloInstruction* instruction, + int64_t operand, + absl::Span batch_dims, + absl::Span row_dims, + absl::Span col_dims); - Status SetDotLayout(const HloInstruction* instruction, - LayoutConstraints* constraints); + absl::Status SetDotLayout(const HloInstruction* instruction, + LayoutConstraints* constraints); bool PropagateReductionLayoutToOperand(const HloInstruction* user) override; - se::StreamExecutor* stream_executor_; + const se::GpuComputeCapability gpu_version_; + const se::dnn::VersionInfo dnn_version_; }; } // namespace gpu diff --git a/xla/service/gpu/gpu_layout_assignment_test.cc b/xla/service/gpu/gpu_layout_assignment_test.cc index c07c5257da344..0a92f45cacc8e 100644 --- a/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/xla/service/gpu/gpu_layout_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,32 @@ limitations under the License. #include "xla/service/gpu/gpu_layout_assignment.h" +#include #include -#include "absl/strings/str_cat.h" +#include +#include +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/computation_layout.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -50,6 +57,20 @@ class LayoutAssignmentTest : public HloTestBase { ->GetDeviceDescription() .cuda_compute_capability(); } + + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + se::dnn::VersionInfo GetDnnVersion() { + // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but + // none of the tests trigger this heuristic. + return GetDnnVersionInfo(backend().default_stream_executor(), + se::dnn::VersionInfo{8, 3, 0}); + } }; TEST_F(LayoutAssignmentTest, Elementwise) { @@ -89,7 +110,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); for (const HloInstruction* operand : add->operands()) { @@ -118,8 +139,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}), @@ -144,8 +165,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -171,8 +192,8 @@ TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -197,8 +218,8 @@ TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ -225,8 +246,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ -258,8 +279,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); // The transpose layout is not supported by dot.2. Also, we need a copy @@ -294,8 +315,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutS8) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -329,8 +350,8 @@ TEST_F(LayoutAssignmentTest, SortLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); @@ -355,8 +376,8 @@ TEST_F(LayoutAssignmentTest, FftLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -383,8 +404,8 @@ ENTRY entry { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -495,8 +516,8 @@ ENTRY main { ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -504,6 +525,97 @@ ENTRY main { LayoutUtil::MakeLayout({3, 1, 4, 2, 0}).minor_to_major()); } +TEST_F(LayoutAssignmentTest, ReduceOperandLayoutDivisorOfWarpSize) { + // Same as ReduceOperandLayout, but with a small reduction dimension that + // is a divisor of the warp size. + const char* module_str = R"( +scalar_add_computation { + scalar_lhs = c64[] parameter(0) + scalar_rhs = c64[] parameter(1) + ROOT add.1 = c64[] add(scalar_lhs, scalar_rhs) +} + +ENTRY main { + param_0 = c64[512,16,1024,128]{3,2,1,0} parameter(0) + negate = c64[512,16,1024,128]{3,2,1,0} negate(param_0) + constant_7 = c64[] constant((0, 0)) + ROOT reduce.2 = c64[512,1024,128]{2,1,0} reduce(negate, constant_7), dimensions={1}, to_apply=scalar_add_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + + EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); + auto reduce = m->entry_computation()->root_instruction(); + EXPECT_EQ(reduce->operand(0)->shape().layout().minor_to_major(), + LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major()); +} + +TEST_F(LayoutAssignmentTest, SendRcvLayout) { + const char* hlo = R"( +HloModule Module + +condition { + p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0) + ROOT lt = pred[] constant(1) +} + +body { + p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0) + + t1 = f32[100,100] get-tuple-element(p), index=0 + t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1 + sdone = token[] send-done(t), channel_id=3, frontend_attributes={ + _xla_send_recv_pipeline="0" + } + tk = token[] after-all() + + + rcvd = (f32[100,100]{0,1}, u32[], token[]) recv(tk), channel_id=2 + zz = (f32[100,100]{0,1}, token[]) recv-done(rcvd), channel_id=2 + + rcvd_d = get-tuple-element(zz), index=0 + + snd = (f32[100,100]{0,1}, u32[], token[]) send(t1, tk), channel_id=3, frontend_attributes={ + _xla_send_recv_pipeline="0" + } + a = add(t1, t1) + + b = add(rcvd_d, a) + + ROOT tup = tuple(b, snd) +} + +ENTRY %main { + p0 = f32[100,100] parameter(0) + tk = token[] after-all() + snd = (f32[100,100]{0,1}, u32[], token[]) send(p0, tk), channel_id=1, frontend_attributes={ + _xla_send_recv_pipeline="0" + } + t = tuple(p0, snd) + ROOT loop = while(t), condition=condition, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + + RunAndFilecheckHloRewrite( + hlo, + GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(), + GetDnnVersion()}, + R"( +// CHECK: (f32[100,100]{1,0}, u32[], token[]) recv +// CHECK: (f32[100,100]{1,0}, token[]) recv-done +// CHECK: (f32[100,100]{1,0}, u32[], token[]) send + )"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_memory_space_assignment.h b/xla/service/gpu/gpu_memory_space_assignment.h new file mode 100644 index 0000000000000..faa9195bc37fc --- /dev/null +++ b/xla/service/gpu/gpu_memory_space_assignment.h @@ -0,0 +1,74 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_MEMORY_SPACE_ASSIGNMENT_H_ +#define XLA_SERVICE_GPU_GPU_MEMORY_SPACE_ASSIGNMENT_H_ + +#include + +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_ordering.h" +#include "xla/service/hlo_value.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +inline constexpr int64_t kCollectiveMemorySpaceColor = 1; + +// Set memory space to kCollectiveMemorySpaceColor for all allocations used by +// all-reduce, all-gather, and reduce-scatter. This memory space maps to +// collective memory using ncclMemAlloc in the runtime. +inline BufferAssigner::Colorer CollectiveColorer() { + return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { + static const auto* kSupportedOpcodes = new absl::flat_hash_set{ + HloOpcode::kAllReduce, + HloOpcode::kAllReduceStart, + HloOpcode::kAllReduceDone, + HloOpcode::kAllGather, + HloOpcode::kAllGatherStart, + HloOpcode::kAllGatherDone, + HloOpcode::kReduceScatter, + HloOpcode::kCollectivePermute, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kAllToAll, + }; + for (HloValue* value : alias_analysis->dataflow_analysis().values()) { + auto& buffer = alias_analysis->GetBufferContainingValue(*value); + for (const auto& alias : buffer.values()) { + // opcode or async wrapped opcode is in kSupportedOpcodes. + if (kSupportedOpcodes->contains(alias->instruction()->opcode()) || + ((alias->instruction()->opcode() == HloOpcode::kAsyncStart || + alias->instruction()->opcode() == HloOpcode::kAsyncDone) && + kSupportedOpcodes->contains( + alias->instruction()->async_wrapped_opcode()))) { + value->set_color(kCollectiveMemorySpaceColor); + } + } + if (!value->has_color()) { + value->set_color(0); + } + } + return OkStatus(); + }; +} + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GPU_MEMORY_SPACE_ASSIGNMENT_H_ diff --git a/xla/service/gpu/gpu_norm_runner.cc b/xla/service/gpu/gpu_norm_runner.cc index fdc6ae4e2cc2d..5abb58af4e019 100644 --- a/xla/service/gpu/gpu_norm_runner.cc +++ b/xla/service/gpu/gpu_norm_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,49 +15,76 @@ limitations under the License. #include "xla/service/gpu/gpu_norm_runner.h" +#include #include -#include "xla/layout_util.h" +#include "absl/status/status.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/status_macros.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/lazy_op_runner.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -Status RunGpuNorm(const gpu::GpuNormConfig& config, - const se::DeviceMemoryBase& input_buffer, - const se::DeviceMemoryBase& scale_buffer, - const se::DeviceMemoryBase& bias_buffer, - const se::DeviceMemoryBase& output_buffer, - std::optional expectation_buffer, - std::optional norm_factor_buffer, - const se::DeviceMemoryBase& scratch_memory, - se::Stream* stream, RunNormOptions options) { +absl::Status RunGpuNorm(const gpu::GpuNormConfig& config, + const se::DeviceMemoryBase& x_buffer, + const se::DeviceMemoryBase& scale_buffer, + const se::DeviceMemoryBase& y_or_dx_buffer, + std::optional bias_buffer, + std::optional dy_buffer, + std::optional expectation_buffer, + std::optional norm_factor_buffer, + std::optional dscale_buffer, + std::optional dbias_buffer, + const se::DeviceMemoryBase& scratch_memory, + se::Stream* stream, RunNormOptions options) { se::dnn::LazyOpRunner* lazy_runner = options.norm_runner->AsNormRunner(); std::optional> local_runner; - se::dnn::NormOp::Config ln_config{config.epsilon, - config.input_descriptor, + TF_ASSIGN_OR_RETURN(se::dnn::NormKind kind, + GetDNNNormKindFromCudnnNormKind(config.kind)); + + se::dnn::NormOp::Config ln_config{kind, + config.epsilon, + config.x_descriptor, config.scale_descriptor, + config.y_or_dx_descriptor, config.bias_descriptor, - config.output_descriptor, + config.dy_descriptor, config.expectation_descriptor, - config.norm_factor_descriptor}; + config.norm_factor_descriptor, + config.dscale_descriptor, + config.dbias_descriptor}; TF_ASSIGN_OR_RETURN(auto* runner, lazy_runner->GetOrCreateRunner(ln_config, stream)); std::vector operands; - operands.emplace_back(input_buffer); + operands.emplace_back(x_buffer); operands.emplace_back(scale_buffer); - operands.emplace_back(bias_buffer); - operands.emplace_back(output_buffer); - if (expectation_buffer) { + operands.emplace_back(y_or_dx_buffer); + + // The remaining operands are composed of inputs followed by outputs of the + // library call. The expectation and norm factor are outputs of the forward + // training layer norm, and inputs of the backward layer norm. + if (config.kind == CudnnNormKind::kLayerForwardInfer || + config.kind == CudnnNormKind::kLayerForwardTrain) { + operands.emplace_back(bias_buffer.value()); + } + if (config.kind == CudnnNormKind::kLayerForwardTrain) { operands.emplace_back(expectation_buffer.value()); + operands.emplace_back(norm_factor_buffer.value()); } - if (norm_factor_buffer) { + if (config.kind == CudnnNormKind::kLayerBackward) { + operands.emplace_back(dy_buffer.value()); + operands.emplace_back(expectation_buffer.value()); operands.emplace_back(norm_factor_buffer.value()); + operands.emplace_back(dscale_buffer.value()); + operands.emplace_back(dbias_buffer.value()); } return (*runner)(stream, options.profile_result, scratch_memory, operands); diff --git a/xla/service/gpu/gpu_norm_runner.h b/xla/service/gpu/gpu_norm_runner.h index 8371ecbe03d8b..854e3c0892050 100644 --- a/xla/service/gpu/gpu_norm_runner.h +++ b/xla/service/gpu/gpu_norm_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,49 +16,70 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_NORM_RUNNER_H_ #define XLA_SERVICE_GPU_GPU_NORM_RUNNER_H_ +#include +#include #include -#include +#include #include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/lazy_op_runner.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { +inline absl::StatusOr AsCudnnNormKind( + xla::gpu::CudnnNormBackendConfig_Kind kind) { + switch (kind) { + case xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER: + return xla::gpu::CudnnNormKind::kLayerForwardInfer; + case xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN: + return xla::gpu::CudnnNormKind::kLayerForwardTrain; + case xla::gpu::CudnnNormBackendConfig::LAYER_BWD: + return xla::gpu::CudnnNormKind::kLayerBackward; + default: + return xla::Internal("Unknown norm kind."); + } +} + // Intermediate structure used as input to construct GpuNormConfig. struct GpuNormDescriptor { CudnnNormBackendConfig backend_config; - Shape input_shape; + Shape x_shape; Shape scale_shape; - Shape bias_shape; - Shape output_shape; + std::optional bias_shape; + Shape y_or_dx_shape; std::optional expectation_shape; std::optional norm_factor_shape; + std::optional dy_shape; + std::optional dscale_shape; + std::optional dbias_shape; size_t scratch_size; }; // Structure to describe static properties of a fused norm op. struct GpuNormConfig { - static StatusOr For(const GpuNormDescriptor& desc) { - std::vector output_types; + static absl::StatusOr For(const GpuNormDescriptor& desc) { + std::vector y_or_dx_types; GpuNormConfig config; config.epsilon = desc.backend_config.epsilon(); config.algorithm = se::dnn::AlgorithmDesc(desc.backend_config.algorithm()); + TF_ASSIGN_OR_RETURN(config.kind, + AsCudnnNormKind(desc.backend_config.kind())); auto tensor_descriptor_from_shape = - [](Shape shape) -> StatusOr { + [](Shape shape) -> absl::StatusOr { TF_ASSIGN_OR_RETURN( se::dnn::DataType data_type, GetDNNDataTypeFromPrimitiveType(shape.element_type())); @@ -66,35 +87,49 @@ struct GpuNormConfig { shape.layout().minor_to_major()); }; - TF_ASSIGN_OR_RETURN(config.input_descriptor, - tensor_descriptor_from_shape(desc.input_shape)); + TF_ASSIGN_OR_RETURN(config.x_descriptor, + tensor_descriptor_from_shape(desc.x_shape)); TF_ASSIGN_OR_RETURN(config.scale_descriptor, tensor_descriptor_from_shape(desc.scale_shape)); - TF_ASSIGN_OR_RETURN(config.bias_descriptor, - tensor_descriptor_from_shape(desc.bias_shape)); - TF_ASSIGN_OR_RETURN(config.output_descriptor, - tensor_descriptor_from_shape(desc.output_shape)); + TF_ASSIGN_OR_RETURN(config.y_or_dx_descriptor, + tensor_descriptor_from_shape(desc.y_or_dx_shape)); + if (desc.bias_shape) { + TF_ASSIGN_OR_RETURN(config.bias_descriptor, tensor_descriptor_from_shape( + desc.bias_shape.value())); + } if (desc.expectation_shape) { TF_ASSIGN_OR_RETURN( config.expectation_descriptor, tensor_descriptor_from_shape(desc.expectation_shape.value())); - } - if (desc.norm_factor_shape) { TF_ASSIGN_OR_RETURN( config.norm_factor_descriptor, tensor_descriptor_from_shape(desc.norm_factor_shape.value())); } + if (desc.dscale_shape) { + TF_ASSIGN_OR_RETURN(config.dy_descriptor, + tensor_descriptor_from_shape(desc.dy_shape.value())); + TF_ASSIGN_OR_RETURN( + config.dscale_descriptor, + tensor_descriptor_from_shape(desc.dscale_shape.value())); + TF_ASSIGN_OR_RETURN( + config.dbias_descriptor, + tensor_descriptor_from_shape(desc.dbias_shape.value())); + } return config; } double epsilon; + CudnnNormKind kind; se::dnn::AlgorithmDesc algorithm; - se::dnn::TensorDescriptor input_descriptor; + se::dnn::TensorDescriptor x_descriptor; se::dnn::TensorDescriptor scale_descriptor; - se::dnn::TensorDescriptor bias_descriptor; - se::dnn::TensorDescriptor output_descriptor; + std::optional bias_descriptor; + se::dnn::TensorDescriptor y_or_dx_descriptor; std::optional expectation_descriptor; std::optional norm_factor_descriptor; + std::optional dy_descriptor; + std::optional dscale_descriptor; + std::optional dbias_descriptor; }; class NormRunner { @@ -127,15 +162,18 @@ struct RunNormOptions { NormRunner* norm_runner; }; -Status RunGpuNorm(const GpuNormConfig& conv_config, - const se::DeviceMemoryBase& input_buffer, - const se::DeviceMemoryBase& scale_buffer, - const se::DeviceMemoryBase& bias_buffer, - const se::DeviceMemoryBase& output_buffer, - std::optional exepctation_buffer, - std::optional norm_factor_buffer, - const se::DeviceMemoryBase& scratch_memory, - se::Stream* stream, RunNormOptions options = {}); +absl::Status RunGpuNorm(const GpuNormConfig& conv_config, + const se::DeviceMemoryBase& x_buffer, + const se::DeviceMemoryBase& scale_buffer, + const se::DeviceMemoryBase& y_or_dx_buffer, + std::optional bias_buffer, + std::optional dy_buffer, + std::optional exepctation_buffer, + std::optional norm_factor_buffer, + std::optional dscale_buffer, + std::optional dbias_buffer, + const se::DeviceMemoryBase& scratch_memory, + se::Stream* stream, RunNormOptions options = {}); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_offloading_test.cc b/xla/service/gpu/gpu_offloading_test.cc new file mode 100644 index 0000000000000..56abbe6f91140 --- /dev/null +++ b/xla/service/gpu/gpu_offloading_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/autotune_results.pb.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/layout.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_memory_scheduler.h" +#include "xla/service/hlo_rematerialization.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +namespace m = ::xla::match; +namespace op = xla::testing::opcode_matchers; + +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::TempDir; + +class GpuCompilerTest : public HloTestBase { + protected: + absl::StatusOr RunHloRematerialization(int64_t memory_limit_bytes, + HloModule* module, + int64_t min_remat_size = 0) { + TF_EXPECT_OK(verifier().Run(module).status()); + if (!module->has_schedule()) { + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { + return ::xla::ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler)); + TF_EXPECT_OK(scheduler.Run(module).status()); + } + // Create a configuration where any compute is much much slower than any + // number of number of copies. + HloCostAnalysis::Options hlo_cost_analysis_options; + hlo_cost_analysis_options.shape_size = [](const Shape& shape) { + return ::xla::ShapeUtil::ByteSizeOf(shape); + }; + hlo_cost_analysis_options.set_flops_per_second(flops_per_second_); + hlo_cost_analysis_options.set_transcendentals_per_second( + transcendentals_per_second_); + HloCostAnalysis cost_analysis(hlo_cost_analysis_options); + HloRematerialization::RematerializationModeConfig config( + /*recompute=*/false, /*compress=*/false, /*host_offload=*/true); + HloRematerialization::HostMemoryOffloadConfig host_memory_offload_config( + kHostMemorySpaceColor, copy_to_host_speed_, copy_from_host_speed_); + HloRematerialization::Options options( + cost_analysis, config, memory_limit_bytes, + /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, + min_remat_size, /*compact_shape_function=*/nullptr, + host_memory_offload_config); + HloRematerialization::RematerializationSizes sizes; + HloRematerialization remat(options, sizes); + return remat.Run(module); + } + void SetCopyToHostSpeed(float val) { copy_to_host_speed_ = val; } + void SetCopyFromHostSpeed(float val) { copy_from_host_speed_ = val; } + void SetFlopsPerSecond(float val) { flops_per_second_ = val; } + void SetTranscendentalsPerSecond(float val) { + transcendentals_per_second_ = val; + } + + static constexpr const int64_t kHostMemorySpaceColor{5}; + + private: + float copy_to_host_speed_{1.0f}; + float copy_from_host_speed_{1.0f}; + float flops_per_second_{1.0f}; + float transcendentals_per_second_{1.0f}; +}; + +TEST_F(GpuCompilerTest, OriginalTest) { + const char* hlo_text = R"( + HloModule test + +ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] { + %param_1 = f32[1024]{0} parameter(1) + %param_0 = f32[1024]{0} parameter(0) + %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1) + %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3) + %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3) + %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4) + %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4) + %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start) + %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5) + %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2) + %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2) + %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6) + %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done) + %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5) + %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3) + %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3) + %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1) + %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1) + ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); +} + +TEST_F(GpuCompilerTest, CompiledProgramsCount) { + const char* hlo_text = R"( + HloModule test + + ENTRY main { + param_0 = f32[1024]{0} parameter(0) + param_1 = f32[1024]{0} parameter(1) + res_3 = f32[1024]{0} add(param_0, param_1) + res_4 = f32[1024]{0} tanh(res_3) + res_5 = f32[1024]{0} tanh(res_4) + res_6 = f32[1024]{0} tanh(res_5) + res_7 = f32[1024]{0} add(res_6, res_6) + res_8 = f32[1024]{0} add(res_7, res_5) + res_9 = f32[1024]{0} add(res_8, res_4) + res_10 = f32[1024]{0} add(res_9, res_3) + ROOT res_11 = f32[1024]{0} tanh(res_10) + } +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + auto module_ref = ParseAndReturnVerifiedModule(hlo_text).value(); + + // Set some "hardware" constants so that we can test that instructions are + // placed in the places we expect. + SetCopyToHostSpeed(4.0 * 1024); + SetCopyFromHostSpeed(4.0 * 1024); + SetFlopsPerSecond(2 * 1024); + SetTranscendentalsPerSecond(2 * 1024); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/10 * 1024, module.get())); + ASSERT_TRUE(changed); + + // The module should still have a schedule. + ASSERT_TRUE(module->has_schedule()); + + // Verify that exactly two instructions are rematerialized. + auto res_3_matcher = op::Add(op::Parameter(), op::Parameter()); + auto res_3_rematted_matcher = op::AsyncCopy( + xla::Layout::kDefaultMemorySpace, kHostMemorySpaceColor, + op::AsyncCopy(kHostMemorySpaceColor, xla::Layout::kDefaultMemorySpace, + res_3_matcher)); + auto res_4_matcher = op::Tanh(res_3_matcher); + auto res_4_rematted_matcher = op::AsyncCopy( + xla::Layout::kDefaultMemorySpace, kHostMemorySpaceColor, + op::AsyncCopy(kHostMemorySpaceColor, xla::Layout::kDefaultMemorySpace, + res_4_matcher)); + auto res_5_matcher = op::Tanh(res_4_matcher); + auto res_6_matcher = op::Tanh(res_5_matcher); + auto res_7_matcher = op::Add(res_6_matcher, res_6_matcher); + auto res_8_matcher = op::Add(res_7_matcher, res_5_matcher); + auto res_9_matcher = op::Add(res_8_matcher, res_4_rematted_matcher); + auto res_10_matcher = op::Add(res_9_matcher, res_3_rematted_matcher); + + const auto instruction_sequence = + module->schedule().sequence(module->entry_computation()); + ASSERT_THAT(instruction_sequence.instructions().back(), + op::Tanh(res_10_matcher)); + // module has the graph optimized by rematerialization and schedule + // module_ref has the original graph without rematerialization + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module_ref), + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6}, + /*run_hlo_passes=*/false)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_p2p_pipeliner.cc b/xla/service/gpu/gpu_p2p_pipeliner.cc new file mode 100644 index 0000000000000..f8cba55030c9f --- /dev/null +++ b/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -0,0 +1,211 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_p2p_pipeliner.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/collective_pipeliner.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/status.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +bool ShouldPipeline(const HloInstruction* instr) { + if (!HloPredicateIsOp(instr)) { + return false; + } + + // Not annotated for pipelining. + auto it = instr->frontend_attributes().map().find(kSendRecvPipelineAttr); + if (it == instr->frontend_attributes().map().end()) { + return false; + } + + // Checks that the SendDone or RecvDone is used for non-trivial computation. + // This avoids repeatedly pipelining a loop. + bool is_pipelined = + (instr->user_count() == 1 && instr->parent() != nullptr && + instr->users()[0] == instr->parent()->root_instruction()); + return !is_pipelined; +} + +bool ShouldAllowLoopVariantParameterInChain(const HloInstruction* instr) { + // Allow any loop parameter needed for pipelining the Send/Recv instructions + // that have been decided to pipeline. + CHECK(instr->opcode() == HloOpcode::kGetTupleElement && + instr->operand(0)->opcode() == HloOpcode::kParameter); + return true; +} + +Status PostprocessP2PImpl( + HloInstruction* instr, + std::function&)> transformer) { + // The input instruction is a Done instruction. + if (!HloPredicateIsOp(instr)) { + return Internal("Expected SendDone/RecvDone as the pipelined collective"); + } + instr = instr->mutable_operand(0); + if (!HloPredicateIsOp(instr)) { + return Internal("Expected Send/Recv as the SendDone/RecvDone operand"); + } + auto validation_it = + instr->frontend_attributes().map().find(kSendRecvValidationAttr); + if (validation_it == instr->frontend_attributes().map().end() || + validation_it->second == "invalid") { + return OkStatus(); + } + auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second); + if (!statusor_bounds.ok()) { + return statusor_bounds.status(); + } + std::string validation_attr = transformer(statusor_bounds.value()); + xla::FrontendAttributes attributes = instr->frontend_attributes(); + (*attributes.mutable_map())[kSendRecvValidationAttr] = validation_attr; + instr->set_frontend_attributes(attributes); + return OkStatus(); +} + +// Modifies the loop iteration frontend attribute for the peeled off Send and +// Recv for the first iteration of a loop. +Status PostprocessPeeledP2P(HloInstruction* instr) { + auto transform_bounds = [&](std::vector& replica_groups) { + std::vector> bounds; + bounds.reserve(replica_groups.size()); + bool all_invalid = true; + for (const auto& replica_group : replica_groups) { + // The peeled off instruction is for executing the first iteration of + // the loop. + int64_t lower_bound = replica_group.replica_ids(0); + int64_t upper_bound = replica_group.replica_ids(1); + if (lower_bound <= 0 && upper_bound >= 0) { + all_invalid = false; + bounds.push_back({0, 0}); + } else { + bounds.push_back({1, 0}); + } + } + std::string validation_attr; + if (all_invalid) { + // An optimized way to represent that all source-target pairs are + // communicating invalid data, to avoid the overhead related to the use + // of execution counters. + validation_attr = "invalid"; + } else { + validation_attr = "{" + + absl::StrJoin(bounds, ",", + absl::PairFormatter( + [](std::string* out, int64_t value) { + absl::StrAppend(out, "{", value); + }, + ",", + [](std::string* out, int64_t value) { + absl::StrAppend(out, value, "}"); + })) + + "}"; + } + return validation_attr; + }; + return PostprocessP2PImpl(instr, transform_bounds); +}; + +// Modifies the loop iteration frontend attribute for the rotated Send and Recv +// for the remaining iterations in a loop. +Status PostprocessRotatedP2P(HloInstruction* instr) { + auto transform_bounds = [&](std::vector& replica_groups) { + std::vector> bounds; + bounds.reserve(replica_groups.size()); + bool all_invalid = true; + for (const auto& replica_group : replica_groups) { + int64_t lower_bound = replica_group.replica_ids(0); + int64_t upper_bound = replica_group.replica_ids(1); + if (lower_bound <= upper_bound) { + if (lower_bound >= 1) { + --lower_bound; + } + if (upper_bound >= 1) { + --upper_bound; + } + if (lower_bound <= upper_bound) { + all_invalid = false; + bounds.push_back({lower_bound, upper_bound}); + } else { + bounds.push_back({1, 0}); + } + } else { + bounds.push_back({lower_bound, upper_bound}); + } + } + + std::string validation_attr; + if (all_invalid) { + // An optimized way to represent that all source-target pairs are + // communicating invalid data, to avoid the overhead related to the use + // of execution counters. + validation_attr = "invalid"; + } else { + validation_attr = "{" + + absl::StrJoin(bounds, ",", + absl::PairFormatter( + [](std::string* out, int64_t value) { + absl::StrAppend(out, "{", value); + }, + ",", + [](std::string* out, int64_t value) { + absl::StrAppend(out, value, "}"); + })) + + "}"; + } + return validation_attr; + }; + + return PostprocessP2PImpl(instr, transform_bounds); +} + +} // anonymous namespace + +void AddP2PPipeliner(HloPassPipeline& pipeline) { + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + // Pipeline everything annotated for pipelining. + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kBackward, ShouldPipeline, + /*acceptable_formatting=*/HloPredicateTrue, + /*reuse_pipelined_op_buffer=*/HloPredicateTrue, + ShouldAllowLoopVariantParameterInChain, PostprocessPeeledP2P, + PostprocessRotatedP2P}; + pipeline.AddPass(config); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_p2p_pipeliner.h b/xla/service/gpu/gpu_p2p_pipeliner.h new file mode 100644 index 0000000000000..6bd3b588eed8c --- /dev/null +++ b/xla/service/gpu/gpu_p2p_pipeliner.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_P2P_PIPELINER_H_ +#define XLA_SERVICE_GPU_GPU_P2P_PIPELINER_H_ + +#include "xla/service/hlo_pass_pipeline.h" + +namespace xla { +namespace gpu { +// Adds a collective-pipeliner pass for pipelining P2P Send-Recv chains. +void AddP2PPipeliner(HloPassPipeline& pipeline); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GPU_P2P_PIPELINER_H_ diff --git a/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/xla/service/gpu/gpu_p2p_pipeliner_test.cc new file mode 100644 index 0000000000000..7e71820ed87a1 --- /dev/null +++ b/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_p2p_pipeliner.h" + +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/hlo_verifier.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuP2PPipelinerTest : public HloTestBase { + public: + GpuP2PPipelinerTest() { + const int64_t kNumReplicas = 1; + const int64_t kNumComputations = 4; + config_ = GetModuleConfigForTest(/*replica_count=*/kNumReplicas, + /*num_partitions=*/kNumComputations); + } + + absl::StatusOr RunOptimizer(HloModule* module) { + HloPassPipeline pipeline("optimizer"); + pipeline.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + AddP2PPipeliner(pipeline); + pipeline.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + return pipeline.Run(module); + } + + protected: + HloModuleConfig config_; +}; + +TEST_F(GpuP2PPipelinerTest, + TransformRecvSendBackwardsWithMetaDataPostProcessing) { + const char* kHloStr = R"( + HloModule module + cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(10) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + after-all.0 = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0", + _xla_send_recv_validation="{{1,7}}" + } + after-all.0.s = token[] after-all() + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.s), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0", + _xla_send_recv_validation="{{1,7}}" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv-data = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + auto module = ParseAndReturnUnverifiedModule(kHloStr, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get()).value()); + XLA_VLOG_LINES(10, module->ToString()); + auto while_op = FindInstruction(module.get(), "while"); + EXPECT_EQ(while_op->opcode(), HloOpcode::kWhile); + EXPECT_EQ(while_op->shape().tuple_shapes().size(), 5); + auto recv1 = + DynCast(FindInstruction(module.get(), "recv.1")); + EXPECT_NE(recv1, nullptr); + auto recv2 = + DynCast(FindInstruction(module.get(), "recv.2")); + EXPECT_NE(recv2, nullptr); + EXPECT_EQ(recv1->channel_id(), recv2->channel_id()); + + auto send1 = + DynCast(FindInstruction(module.get(), "send.1")); + EXPECT_NE(send1, nullptr); + auto send2 = + DynCast(FindInstruction(module.get(), "send.2")); + EXPECT_NE(send2, nullptr); + EXPECT_EQ(send1->channel_id(), send2->channel_id()); + + const char* kPeeledAttr = "_xla_send_recv_validation=\"invalid\""; + const char* kRotatedAttr = "_xla_send_recv_validation=\"{{0,6}}\""; + EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kPeeledAttr)); + EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kPeeledAttr)); + EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kRotatedAttr)); + EXPECT_THAT(recv2->ToString(), ::testing::HasSubstr(kRotatedAttr)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_prim.h b/xla/service/gpu/gpu_prim.h new file mode 100644 index 0000000000000..5e7daa3d86e02 --- /dev/null +++ b/xla/service/gpu/gpu_prim.h @@ -0,0 +1,118 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +To in writing unless required by applicable law or agreed, +distributed on an, software distributed under the license is "AS IS" +BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express +or implied. For the specific language governing permissions and +limitations under the license, the license you must see. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_GPU_PRIM_H_ +#define XLA_SERVICE_GPU_GPU_PRIM_H_ + +#include "tsl/platform/bfloat16.h" + +#if GOOGLE_CUDA +#include "cub/block/block_load.cuh" +#include "cub/block/block_scan.cuh" +#include "cub/block/block_store.cuh" +#include "cub/device/device_histogram.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_reduce.cuh" +#include "cub/device/device_scan.cuh" +#include "cub/device/device_segmented_radix_sort.cuh" +#include "cub/device/device_segmented_reduce.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/counting_input_iterator.cuh" +#include "cub/iterator/transform_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" +#include "cub/warp/warp_reduce.cuh" +#include "third_party/gpus/cuda/include/cusparse.h" + +namespace gpuprim = ::cub; + +// Required for sorting Eigen::half and bfloat16. +namespace cub { +template <> +__device__ __forceinline__ void ThreadStoreVolatilePtr( + Eigen::half *ptr, Eigen::half val, Int2Type /*is_primitive*/) { + *reinterpret_cast(ptr) = + Eigen::numext::bit_cast(val); +} + +template <> +__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( + Eigen::half *ptr, Int2Type /*is_primitive*/) { + uint16_t result = *reinterpret_cast(ptr); + return Eigen::numext::bit_cast(result); +} + +template <> +__device__ __forceinline__ void ThreadStoreVolatilePtr( + tsl::bfloat16 *ptr, tsl::bfloat16 val, Int2Type /*is_primitive*/) { + *reinterpret_cast(ptr) = + Eigen::numext::bit_cast(val); +} + +template <> +__device__ __forceinline__ tsl::bfloat16 +ThreadLoadVolatilePointer(tsl::bfloat16 *ptr, + Int2Type /*is_primitive*/) { + uint16_t result = *reinterpret_cast(ptr); + return Eigen::numext::bit_cast(result); +} + +template <> +struct NumericTraits + : BaseTraits {}; +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub +#elif TENSORFLOW_USE_ROCM + +#include "rocm/include/hipcub/hipcub.hpp" +#include "rocm/rocm_config.h" +namespace gpuprim = ::hipcub; + +// Required for sorting Eigen::half and bfloat16. +namespace rocprim { +namespace detail { + +#if (TF_ROCM_VERSION >= 50200) +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; +#endif // TF_ROCM_VERSION >= 50200 +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +}; // namespace detail +}; // namespace rocprim + +#endif // TENSORFLOW_USE_ROCM + +#endif // XLA_SERVICE_GPU_GPU_PRIM_H_ diff --git a/xla/service/gpu/gpu_prim_cuda.h b/xla/service/gpu/gpu_prim_cuda.h deleted file mode 100644 index e4ee313cacc95..0000000000000 --- a/xla/service/gpu/gpu_prim_cuda.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -To in writing unless required by applicable law or agreed, -distributed on an, software distributed under the license is "AS IS" -BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express -or implied. For the specific language governing permissions and -limitations under the license, the license you must see. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ -#define XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ - -#include "tsl/platform/bfloat16.h" - -#if GOOGLE_CUDA -#include "cub/block/block_load.cuh" -#include "cub/block/block_scan.cuh" -#include "cub/block/block_store.cuh" -#include "cub/device/device_histogram.cuh" -#include "cub/device/device_radix_sort.cuh" -#include "cub/device/device_reduce.cuh" -#include "cub/device/device_scan.cuh" -#include "cub/device/device_segmented_radix_sort.cuh" -#include "cub/device/device_segmented_reduce.cuh" -#include "cub/device/device_select.cuh" -#include "cub/iterator/counting_input_iterator.cuh" -#include "cub/iterator/transform_input_iterator.cuh" -#include "cub/thread/thread_operators.cuh" -#include "cub/warp/warp_reduce.cuh" -#include "third_party/gpus/cuda/include/cusparse.h" - -namespace gpuprim = ::cub; - -// Required for sorting Eigen::half and bfloat16. -namespace cub { -template <> -__device__ __forceinline__ void ThreadStoreVolatilePtr( - Eigen::half *ptr, Eigen::half val, Int2Type /*is_primitive*/) { - *reinterpret_cast(ptr) = - Eigen::numext::bit_cast(val); -} - -template <> -__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( - Eigen::half *ptr, Int2Type /*is_primitive*/) { - uint16_t result = *reinterpret_cast(ptr); - return Eigen::numext::bit_cast(result); -} - -template <> -__device__ __forceinline__ void ThreadStoreVolatilePtr( - tsl::bfloat16 *ptr, tsl::bfloat16 val, Int2Type /*is_primitive*/) { - *reinterpret_cast(ptr) = - Eigen::numext::bit_cast(val); -} - -template <> -__device__ __forceinline__ tsl::bfloat16 -ThreadLoadVolatilePointer(tsl::bfloat16 *ptr, - Int2Type /*is_primitive*/) { - uint16_t result = *reinterpret_cast(ptr); - return Eigen::numext::bit_cast(result); -} - -template <> -struct NumericTraits - : BaseTraits {}; -template <> -struct NumericTraits - : BaseTraits {}; -} // namespace cub -#endif // GOOGLE_CUDA - -#endif // XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_ diff --git a/xla/service/gpu/gpu_prim_rocm.h b/xla/service/gpu/gpu_prim_rocm.h deleted file mode 100644 index 773e53468878d..0000000000000 --- a/xla/service/gpu/gpu_prim_rocm.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -To in writing unless required by applicable law or agreed, -distributed on an, software distributed under the license is "AS IS" -BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express -or implied. For the specific language governing permissions and -limitations under the license, the license you must see. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ -#define XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ - -#include "tsl/platform/bfloat16.h" - -#if TENSORFLOW_USE_ROCM - -#include "rocm/include/hipcub/hipcub.hpp" -#include "rocm/rocm_config.h" -namespace gpuprim = ::hipcub; - -// Required for sorting Eigen::half and bfloat16. -namespace rocprim { -namespace detail { - -#if (TF_ROCM_VERSION >= 50200) -template <> -struct float_bit_mask { - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7C00; - static constexpr uint16_t mantissa = 0x03FF; - using bit_type = uint16_t; -}; - -template <> -struct float_bit_mask { - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7F80; - static constexpr uint16_t mantissa = 0x007F; - using bit_type = uint16_t; -}; -#endif // TF_ROCM_VERSION >= 50200 -template <> -struct radix_key_codec_base - : radix_key_codec_floating {}; -template <> -struct radix_key_codec_base - : radix_key_codec_floating {}; -}; // namespace detail -}; // namespace rocprim - -#endif // TENSORFLOW_USE_ROCM -#endif // XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_ diff --git a/xla/service/gpu/gpu_reduce_scatter_creator.cc b/xla/service/gpu/gpu_reduce_scatter_creator.cc index 9950581a10c18..2f3f0176da2a3 100644 --- a/xla/service/gpu/gpu_reduce_scatter_creator.cc +++ b/xla/service/gpu/gpu_reduce_scatter_creator.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,14 @@ limitations under the License. #include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -22,11 +30,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_opt_utils.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" namespace xla { namespace gpu { -StatusOr ReduceScatterCreator::Run( +absl::StatusOr ReduceScatterCreator::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { const HloModuleConfig &config = module->config(); diff --git a/xla/service/gpu/gpu_reduce_scatter_creator.h b/xla/service/gpu/gpu_reduce_scatter_creator.h index 4bc1da7d003fc..fcecb460747cc 100644 --- a/xla/service/gpu/gpu_reduce_scatter_creator.h +++ b/xla/service/gpu/gpu_reduce_scatter_creator.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ #define XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -28,7 +32,7 @@ class ReduceScatterCreator : public HloModulePass { absl::string_view name() const override { return "reduce-scatter-creator"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gpu_sanitize_constant_names.cc b/xla/service/gpu/gpu_sanitize_constant_names.cc index 395e755e34ae8..d948882dd0781 100644 --- a/xla/service/gpu/gpu_sanitize_constant_names.cc +++ b/xla/service/gpu/gpu_sanitize_constant_names.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,17 +17,19 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" +#include "xla/service/name_uniquer.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace xla { namespace gpu { -StatusOr GpuSanitizeConstantNames::Run( +absl::StatusOr GpuSanitizeConstantNames::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -57,6 +59,9 @@ StatusOr GpuSanitizeConstantNames::Run( std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr); instr->SetAndSanitizeName(sanitized_name); instr->UniquifyName(&instr_name_uniquer); + // Register this new name with the module's instruction_name_uniquer to + // avoid name collision that might happen in future. + module->instruction_name_uniquer().GetUniqueName(instr->name()); changed = true; } } diff --git a/xla/service/gpu/gpu_sanitize_constant_names.h b/xla/service/gpu/gpu_sanitize_constant_names.h index 70bbebd426ca5..08701a4fe3432 100644 --- a/xla/service/gpu/gpu_sanitize_constant_names.h +++ b/xla/service/gpu/gpu_sanitize_constant_names.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ #define XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -30,7 +33,7 @@ class GpuSanitizeConstantNames : public HloModulePass { absl::string_view name() const override { return "sanitize-constant-names"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/xla/service/gpu/gpu_sanitize_constant_names_test.cc index 80afbb648bd64..17f45dc100f68 100644 --- a/xla/service/gpu/gpu_sanitize_constant_names_test.cc +++ b/xla/service/gpu/gpu_sanitize_constant_names_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,16 @@ limitations under the License. #include "xla/service/gpu/gpu_sanitize_constant_names.h" +#include +#include #include -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -61,6 +64,29 @@ TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { EXPECT_EQ(root->name(), "equal_to"); } +TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal.to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); + + auto constant_instr = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1)); + constant_instr->SetAndSanitizeName("equal_to"); + module->entry_computation()->AddInstruction(std::move(constant_instr)); + + EXPECT_THAT(FindInstruction(module.get(), "equal_to.1"), + GmockMatch(m::Constant())); +} + TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { const char *const kHloString = R"( HloModule BufferSanitizedName diff --git a/xla/service/gpu/gpu_scatter_expander.cc b/xla/service/gpu/gpu_scatter_expander.cc index 8dfe3a78d7c9b..b03b340cb8bbd 100644 --- a/xla/service/gpu/gpu_scatter_expander.cc +++ b/xla/service/gpu/gpu_scatter_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,9 @@ limitations under the License. #include "xla/service/gpu/gpu_scatter_expander.h" -#include "absl/algorithm/container.h" -#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/statusor.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" namespace xla { diff --git a/xla/service/gpu/gpu_scatter_expander.h b/xla/service/gpu/gpu_scatter_expander.h index 7a91c309b1728..100350cb67ac0 100644 --- a/xla/service/gpu/gpu_scatter_expander.h +++ b/xla/service/gpu/gpu_scatter_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ #define XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/scatter_expander.h" namespace xla { diff --git a/xla/service/gpu/gpu_schedule_postprocessing.cc b/xla/service/gpu/gpu_schedule_postprocessing.cc index 8c66874b378ee..e71232c0a2dda 100644 --- a/xla/service/gpu/gpu_schedule_postprocessing.cc +++ b/xla/service/gpu/gpu_schedule_postprocessing.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,12 @@ limitations under the License. #include "xla/service/gpu/gpu_schedule_postprocessing.h" -#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -60,22 +59,22 @@ bool MayInvokeCustomCall( // Returns true if this is an asynchronous collective start operation, excluding // P2P operations. -StatusOr IsRelevantAsynchronousStart(const HloInstruction* hlo) { - HloOpcode opcode = hlo->opcode(); - if (!hlo_query::IsAsyncCollectiveStartOp(opcode, +absl::StatusOr IsRelevantAsynchronousStart(const HloInstruction* hlo) { + if (!hlo_query::IsAsyncCollectiveStartOp(hlo, /*include_send_recv=*/false)) { return false; } - TF_ASSIGN_OR_RETURN(CollectiveBackendConfig collective_backend_config, - hlo->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + hlo->backend_config()); + CollectiveBackendConfig collective_backend_config = + gpu_config.collective_backend_config(); return !collective_backend_config.is_sync(); } // Returns true if this is a collective done operation, excluding P2P // operations. -StatusOr IsRelevantAsynchronousDone(const HloInstruction* hlo) { - HloOpcode opcode = hlo->opcode(); - return hlo_query::IsAsyncCollectiveDoneOp(opcode, +absl::StatusOr IsRelevantAsynchronousDone(const HloInstruction* hlo) { + return hlo_query::IsAsyncCollectiveDoneOp(hlo, /*include_send_recv=*/false); } @@ -83,7 +82,7 @@ StatusOr IsRelevantAsynchronousDone(const HloInstruction* hlo) { // that aren't parallel with custom-calls and sets its no_parallel_custom_call // attribute to true. Also records whether the given computation may invoke // custom-calls. -StatusOr ProcessComputation( +absl::StatusOr ProcessComputation( const HloSchedule& schedule, HloComputation* computation, CustomCallInComputation& custom_call_in_computation) { bool changed = false; @@ -115,12 +114,12 @@ StatusOr ProcessComputation( HloInstruction* async_start = hlo->mutable_operand(0); if (async_starts.contains(async_start)) { changed = true; - TF_ASSIGN_OR_RETURN( - CollectiveBackendConfig collective_backend_config, - async_start->backend_config()); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + async_start->backend_config()); + CollectiveBackendConfig& collective_backend_config = + *gpu_config.mutable_collective_backend_config(); collective_backend_config.set_no_parallel_custom_call(true); - TF_RETURN_IF_ERROR( - async_start->set_backend_config(collective_backend_config)); + TF_RETURN_IF_ERROR(async_start->set_backend_config(gpu_config)); async_starts.erase(async_start); } } @@ -132,7 +131,7 @@ StatusOr ProcessComputation( } // anonymous namespace -StatusOr GpuSchedulePostprocessing::Run( +absl::StatusOr GpuSchedulePostprocessing::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!module->has_schedule()) return false; diff --git a/xla/service/gpu/gpu_schedule_postprocessing.h b/xla/service/gpu/gpu_schedule_postprocessing.h index 521d74e617d10..d8eda81f25780 100644 --- a/xla/service/gpu/gpu_schedule_postprocessing.h +++ b/xla/service/gpu/gpu_schedule_postprocessing.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ #define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -37,7 +41,7 @@ class GpuSchedulePostprocessing : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/xla/service/gpu/gpu_schedule_postprocessing_test.cc index b9ef17de14c82..9d4956bdd5b4d 100644 --- a/xla/service/gpu/gpu_schedule_postprocessing_test.cc +++ b/xla/service/gpu/gpu_schedule_postprocessing_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { ENTRY entry { pf32 = f32[1] parameter(0) - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":true}" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}} ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) } )"; @@ -83,7 +83,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { ENTRY entry { pf32 = f32[1] parameter(0) pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config="{\"is_sync\":false}" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) } )"; @@ -94,8 +94,10 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { EXPECT_TRUE(changed); HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, - start->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); } @@ -105,7 +107,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { ENTRY entry { pf32 = f32[1] parameter(0) - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":false}" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" all-gather-done = f32[2] all-gather-done(all-gather-start) ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) @@ -118,8 +120,10 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { EXPECT_FALSE(changed); HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, - start->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } @@ -134,7 +138,7 @@ TEST_F(GpuSchedulePostprocessingTest, ENTRY entry { pf32 = f32[1] parameter(0) - all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":false}" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config={"collective_backend_config":{"is_sync":false}} pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo all-gather-done = f32[2] all-gather-done(all-gather-start) ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) @@ -147,8 +151,10 @@ TEST_F(GpuSchedulePostprocessingTest, EXPECT_FALSE(changed); HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); - TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, - start->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + start->backend_config()); + const CollectiveBackendConfig& collective_backend_config = + gpu_config.collective_backend_config(); EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } diff --git a/xla/service/gpu/gpu_sort_rewriter.cc b/xla/service/gpu/gpu_sort_rewriter.cc index 77d182ccc49a6..0f209b143e74d 100644 --- a/xla/service/gpu/gpu_sort_rewriter.cc +++ b/xla/service/gpu/gpu_sort_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,11 +31,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/cub_sort_thunk.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/runtime/cub_sort_thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -87,7 +86,7 @@ std::optional AnalyzeSortComputation( } // Create runner for CUB sort operation. -StatusOr> CreateRunner( +absl::StatusOr> CreateRunner( HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) { int value_index = 1 - sort_config.key_operand; return CubSortRunnerInterface::Create( @@ -145,7 +144,8 @@ HloInstruction* UnpackResultPair(HloSortInstruction* sort_op, } // namespace // Rewrites a single sort instruction with a custom call. -StatusOr GpuSortRewriter::RunOnInstruction(HloSortInstruction* sort_op) { +absl::StatusOr GpuSortRewriter::RunOnInstruction( + HloSortInstruction* sort_op) { // Get the sort tensor index and direction. SortComputationAnalysis sort_config = AnalyzeSortComputation(sort_op->called_computations().front()).value(); @@ -203,7 +203,8 @@ StatusOr GpuSortRewriter::RunOnInstruction(HloSortInstruction* sort_op) { } // Rewrites the sorts in the given computation into calls to CUB. -StatusOr GpuSortRewriter::RunOnComputation(HloComputation* computation) { +absl::StatusOr GpuSortRewriter::RunOnComputation( + HloComputation* computation) { std::vector sort_ops; for (auto* inst : computation->instructions()) { HloSortInstruction* sort = DynCast(inst); @@ -220,7 +221,7 @@ StatusOr GpuSortRewriter::RunOnComputation(HloComputation* computation) { } // Replace compatible sort operations with custom calls. -StatusOr GpuSortRewriter::Run( +absl::StatusOr GpuSortRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString()); diff --git a/xla/service/gpu/gpu_sort_rewriter.h b/xla/service/gpu/gpu_sort_rewriter.h index 094cbc48bd089..bfea9c87ba63a 100644 --- a/xla/service/gpu/gpu_sort_rewriter.h +++ b/xla/service/gpu/gpu_sort_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,12 @@ limitations under the License. #define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -40,13 +40,13 @@ class GpuSortRewriter : public HloModulePass { static constexpr int kSortSizeThreshold = 100000; using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnInstruction(HloSortInstruction* sort_op); - StatusOr RunOnComputation(HloComputation* computation); + absl::StatusOr RunOnInstruction(HloSortInstruction* sort_op); + absl::StatusOr RunOnComputation(HloComputation* computation); }; } // namespace gpu diff --git a/xla/service/gpu/gpu_sort_rewriter_test.cc b/xla/service/gpu/gpu_sort_rewriter_test.cc index 80540770cadf7..3df91cc98e89e 100644 --- a/xla/service/gpu/gpu_sort_rewriter_test.cc +++ b/xla/service/gpu/gpu_sort_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_symbol_repository.h b/xla/service/gpu/gpu_symbol_repository.h index 231abdf1ad9bb..61044214e7913 100644 --- a/xla/service/gpu/gpu_symbol_repository.h +++ b/xla/service/gpu/gpu_symbol_repository.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/gpu_transfer_manager.cc b/xla/service/gpu/gpu_transfer_manager.cc index 5bd1c557e952a..8bd97486d3435 100644 --- a/xla/service/gpu/gpu_transfer_manager.cc +++ b/xla/service/gpu/gpu_transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,30 +15,43 @@ limitations under the License. #include "xla/service/gpu/gpu_transfer_manager.h" +#include +#include +#include +#include #include -#include #include #include #include "absl/cleanup/cleanup.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "llvm/IR/DataLayout.h" #include "xla/literal.h" -#include "xla/literal_util.h" #include "xla/service/compiler.h" +#include "xla/service/generic_transfer_manager.h" +#include "xla/service/gpu/infeed_manager.h" #include "xla/service/gpu/outfeed_manager.h" #include "xla/service/gpu/target_constants.h" +#include "xla/service/shaped_buffer.h" +#include "xla/service/transfer_manager.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -50,44 +63,42 @@ GpuTransferManager::GpuTransferManager(se::Platform::Id id, unsigned pointer_size) : GenericTransferManager(id, pointer_size) {} -GpuTransferManager::~GpuTransferManager() { - if (pinned_chunk_se_) { - pinned_chunk_se_->HostMemoryDeallocate(pinned_chunk_); - } -} - -Status GpuTransferManager::TransferLiteralToInfeed( +absl::Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { return gpu::GetOrCreateInfeedManager(executor)->TransferLiteralToInfeed( executor, literal); } -Status GpuTransferManager::TransferLiteralFromOutfeed( +absl::Status GpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, MutableBorrowingLiteral literal) { return gpu::GetOrCreateOutfeedManager(executor)->TransferLiteralFromOutfeed( executor, literal); } -void GpuTransferManager::EnsurePinnedBuffersAllocated( +absl::Status GpuTransferManager::EnsurePinnedBuffersAllocated( se::StreamExecutor* executor) { if (pinned_chunk_ != nullptr) { - return; + return absl::OkStatus(); } + TF_ASSIGN_OR_RETURN(pinned_chunk_, + executor->HostMemoryAllocate(kPinnedChunkBytes)); pinned_chunk_se_ = executor; - pinned_chunk_ = - reinterpret_cast(executor->HostMemoryAllocate(kPinnedChunkBytes)); + static_assert(kPinnedChunkBytes % kPinnedBufferBytes == 0, "assumption of loop below"); - for (char* buf = pinned_chunk_; buf < pinned_chunk_ + kPinnedChunkBytes; + char* base = reinterpret_cast(pinned_chunk_->opaque()); + for (char* buf = base; buf < base + kPinnedChunkBytes; buf += kPinnedBufferBytes) { pinned_buffers_.push_back(buf); } + + return absl::OkStatus(); } -Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, - const ShapedBuffer* device_buffer, - Shape* device_shape) { +absl::Status GpuTransferManager::ReadDynamicShapes( + se::Stream* stream, const ShapedBuffer* device_buffer, + Shape* device_shape) { DCHECK(device_shape->is_dynamic()); Shape original_device_shape = *device_shape; @@ -101,16 +112,17 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, std::vector> copies; TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachElementWithStatus( - [&](const ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + [&](const ShapeIndex& index, + const se::DeviceMemoryBase& buffer) -> absl::Status { const Shape& buffer_shape = ShapeUtil::GetSubshape(*device_shape, index); if (buffer_shape.IsTuple()) { - return OkStatus(); + return absl::OkStatus(); } Shape& device_sub_shape = *ShapeUtil::GetMutableSubshape(device_shape, index); if (device_sub_shape.is_static()) { - return OkStatus(); + return absl::OkStatus(); } // Read the dynamic shape metadata from the device stream. The dynamic @@ -123,11 +135,10 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, } auto buffer_8 = se::DeviceMemory(buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + auto metadata_buffer = buffer_8.GetSlice(offset, metadata_size); copies.push_back(std::make_pair(metadata_buffer, &device_sub_shape)); - return OkStatus(); + return absl::OkStatus(); })); // Check out pinned memory for each buffer we want to copy. If there aren't @@ -146,7 +157,7 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, { absl::MutexLock lock(&mu_); - EnsurePinnedBuffersAllocated(stream->parent()); + TF_RETURN_IF_ERROR(EnsurePinnedBuffersAllocated(stream->parent())); for (const auto& src_dst : copies) { se::DeviceMemoryBase src = src_dst.first; @@ -171,7 +182,7 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, for (int i = 0; i < copies.size(); i++) { se::DeviceMemoryBase src = copies[i].first; void* dst = h2d_memcpy_dsts[i]; - stream->ThenMemcpy(dst, src, src.size()); + TF_RETURN_IF_ERROR(stream->Memcpy(dst, src, src.size())); } // Wait for all the async copies to complete, then write into device_shape. @@ -187,7 +198,134 @@ Status GpuTransferManager::ReadDynamicShapes(se::Stream* stream, device_shape->clear_dynamic_dimensions(); TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape, original_device_shape)); - return OkStatus(); + return absl::OkStatus(); +} + +// Chunks `size` into chunks of `chunk_size` and calls `callback` for each. +static absl::Status ForEachChunk( + size_t size, size_t chunk_size, + absl::FunctionRef + callback) { + int64_t num_chunks = CeilOfRatio(size, chunk_size); + + for (int64_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) { + TF_RETURN_IF_ERROR(callback( + /*chunk_offset=*/chunk_index * chunk_size, + /*chunk_size=*/std::min(chunk_size, size - chunk_index * chunk_size))); + } + return absl::OkStatus(); +} + +absl::Status GpuTransferManager::TransferBufferFromDevice( + se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size, + void* destination) { + if (source.size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Source allocation on device not large enough for data transfer: " + "%d < %d", + source.size(), size)); + } + + VLOG(5) << "Transfer buffer from device: size=" + << tsl::strings::HumanReadableNumBytes(size); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + GetOrCreateStagingBuffer(stream->parent())); + + absl::MutexLock lock(&staging_buffer->mutex); + void* staging = staging_buffer->allocation->opaque(); + + // Transfer chunk of data from device to destination via staging buffer. + auto transfer_chunk = [&](size_t chunk_offset, + size_t chunk_size) -> absl::Status { + VLOG(5) << "Transfer buffer chunk from device: offset=" << chunk_offset + << " size=" << tsl::strings::HumanReadableNumBytes(chunk_size); + + se::DeviceMemoryBase chunk = source.GetByteSlice(chunk_offset, chunk_size); + TF_RETURN_IF_ERROR(stream->Memcpy(staging, chunk, chunk_size)); + + void* dst = reinterpret_cast(destination) + chunk_offset; + return stream->DoHostCallback( + [=] { std::memcpy(dst, staging, chunk_size); }); + }; + + TF_RETURN_IF_ERROR(stream->WaitFor(staging_buffer->transfer_completed.get())); + TF_RETURN_IF_ERROR(ForEachChunk(size, kStagingBufferSize, transfer_chunk)); + TF_RETURN_IF_ERROR( + stream->RecordEvent(staging_buffer->transfer_completed.get())); + + return absl::OkStatus(); +} + +absl::Status GpuTransferManager::TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) { + if (destination->size() < size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Destination allocation on device not large enough for data transfer: " + "%d < %d", + destination->size(), size)); + } + + VLOG(5) << "Transfer buffer to device: size=" + << tsl::strings::HumanReadableNumBytes(size); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + GetOrCreateStagingBuffer(stream->parent())); + + absl::MutexLock lock(&staging_buffer->mutex); + void* staging = staging_buffer->allocation->opaque(); + + // Transfer chunk of data from device to destination. + auto transfer_chunk = [&](size_t chunk_offset, size_t chunk_size) { + VLOG(5) << "Transfer buffer chunk to device: offset=" << chunk_offset + << " size=" << tsl::strings::HumanReadableNumBytes(chunk_size); + + const void* src = reinterpret_cast(source) + chunk_offset; + TF_RETURN_IF_ERROR( + stream->DoHostCallback([=] { std::memcpy(staging, src, chunk_size); })); + + auto chunk = destination->GetByteSlice(chunk_offset, chunk_size); + return stream->Memcpy(&chunk, staging, chunk_size); + }; + + TF_RETURN_IF_ERROR(stream->WaitFor(staging_buffer->transfer_completed.get())); + TF_RETURN_IF_ERROR(ForEachChunk(size, kStagingBufferSize, transfer_chunk)); + TF_RETURN_IF_ERROR( + stream->RecordEvent(staging_buffer->transfer_completed.get())); + + return absl::OkStatus(); +} + +GpuTransferManager::StagingBuffer::StagingBuffer( + std::unique_ptr allocation, + std::unique_ptr transfer_completed) + : allocation(std::move(allocation)), + transfer_completed(std::move(transfer_completed)) {} + +absl::StatusOr +GpuTransferManager::GetOrCreateStagingBuffer(se::StreamExecutor* executor) { + absl::MutexLock lock(&mutex_); + if (auto it = staging_buffers_.find(executor); it != staging_buffers_.end()) { + return &it->second; + } + + VLOG(3) << absl::StreamFormat( + "Allocate staging buffer of %s for executor %p (device_ordinal=%d)", + tsl::strings::HumanReadableNumBytes(kStagingBufferSize), executor, + executor->device_ordinal()); + + TF_ASSIGN_OR_RETURN(auto staging_buffer, + executor->HostMemoryAllocate(kStagingBufferSize)); + + auto transfer_completed = std::make_unique(executor); + if (!transfer_completed->Init()) { + return absl::InternalError("Failed to initialize transfer completed event"); + } + + auto emplaced = staging_buffers_.try_emplace( + executor, std::move(staging_buffer), std::move(transfer_completed)); + return &emplaced.first->second; } } // namespace gpu @@ -214,4 +352,5 @@ static bool InitModule() { stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager); return true; } + static bool module_initialized = InitModule(); diff --git a/xla/service/gpu/gpu_transfer_manager.h b/xla/service/gpu/gpu_transfer_manager.h index 500695527e482..3ec0e6da5f816 100644 --- a/xla/service/gpu/gpu_transfer_manager.h +++ b/xla/service/gpu/gpu_transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ #define XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" #include "xla/service/generic_transfer_manager.h" -#include "xla/service/gpu/infeed_manager.h" -#include "xla/service/transfer_manager.h" -#include "xla/shape_tree.h" -#include "xla/statusor.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -34,22 +44,53 @@ namespace gpu { class GpuTransferManager : public GenericTransferManager { public: GpuTransferManager(se::Platform::Id id, unsigned pointer_size); - ~GpuTransferManager() override; - Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const LiteralSlice& literal) override; - Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - MutableBorrowingLiteral literal) override; - Status ReadDynamicShapes(se::Stream* stream, - const ShapedBuffer* device_buffer, - Shape* device_shape) override; + absl::Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + absl::Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, MutableBorrowingLiteral literal) override; + absl::Status ReadDynamicShapes(se::Stream* stream, + const ShapedBuffer* device_buffer, + Shape* device_shape) override; private: + // We use a fixed-size staging buffers and split transfer into multiple + // operations if literal does not fit into it. + static constexpr int64_t kStagingBufferSize = 128 * 1024 * 1024; + + // We use host memory allocation (pinned host memory) as a staging buffer for + // transfering literals to and from device. We keep a separate staging + // allocation per device so we don't need to do cross-device synchronization. + // All transfers to and from a device are ordered via stream dependencies. + struct StagingBuffer { + StagingBuffer(std::unique_ptr allocation, + std::unique_ptr transfer_completed); + + absl::Mutex mutex; + std::unique_ptr allocation ABSL_GUARDED_BY(mutex); + std::unique_ptr transfer_completed ABSL_GUARDED_BY(mutex); + }; + GpuTransferManager(const GpuTransferManager&) = delete; GpuTransferManager& operator=(const GpuTransferManager&) = delete; bool PackSubbyteTypes() const override { return true; } + // Returns or creates the staging buffer for the given executor. + absl::StatusOr GetOrCreateStagingBuffer( + se::StreamExecutor* executor); + + absl::Status TransferBufferFromDevice(se::Stream* stream, + const se::DeviceMemoryBase& source, + int64_t size, + void* destination) override; + + absl::Status TransferBufferToDevice( + se::Stream* stream, int64_t size, const void* source, + se::DeviceMemoryBase* destination) override; + + // TODO(ezhulenev): Unify this with staged buffers for transfering literals. + // This class keeps a pool of pinned memory // (StreamExecutor::HostMemoryAllocate()) that serves ReadDynamicShapes(). // This is a bit of a hack: Callers like TensorFlow already have a full pinned @@ -82,7 +123,7 @@ class GpuTransferManager : public GenericTransferManager { // // Lazy initialization works around this, because at that point we have a // stream, and therefore we have an already-initialized StreamExecutor. - void EnsurePinnedBuffersAllocated(se::StreamExecutor* executor) + absl::Status EnsurePinnedBuffersAllocated(se::StreamExecutor* executor) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); static constexpr int64_t kPinnedChunkBytes = 128 * 1024; @@ -96,11 +137,16 @@ class GpuTransferManager : public GenericTransferManager { // Chunk of pinned memory of size kPinnedChunkBytes. The pointers in // pinned_buffers_ point into this chunk. Lazily initialized. - char* pinned_chunk_ ABSL_GUARDED_BY(mu_) = nullptr; + std::unique_ptr pinned_chunk_ ABSL_GUARDED_BY(mu_); // Host buffers for reading dynamic shapes. Each buffer has size // kPinnedBufferBytes. Lazily initialized. std::vector pinned_buffers_ ABSL_GUARDED_BY(mu_); + + // Staging buffers allocated for transfers to and from device. + absl::Mutex mutex_; + absl::node_hash_map staging_buffers_ + ABSL_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/xla/service/gpu/gpu_windowed_einsum_handler.cc b/xla/service/gpu/gpu_windowed_einsum_handler.cc new file mode 100644 index 0000000000000..fa5339fc007f4 --- /dev/null +++ b/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -0,0 +1,174 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_windowed_einsum_handler.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/pattern_matcher.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +namespace m = match; + +int64_t NumberOfInstructionsInComp(const HloComputation* comp, HloOpcode op) { + int64_t total_count = 0; + for (const HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == op) { + ++total_count; + } + } + return total_count; +} + +absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot, + int64_t stream_id) { + auto dot_gpu_config = dot->backend_config(); + + HloInstruction* updater = dot->users()[0]; + auto updater_gpu_config = updater->backend_config(); + dot_gpu_config->set_operation_queue_id(stream_id); + updater_gpu_config->mutable_wait_on_operation_queues()->Add(stream_id); + + TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value())); + TF_RETURN_IF_ERROR(updater->set_backend_config(updater_gpu_config.value())); + return absl::OkStatus(); +} + +absl::Status SetForceDelayForInstruction(HloInstruction* instr, + bool force_delay) { + auto gpu_config = instr->backend_config(); + + gpu_config->set_force_earliest_schedule(force_delay); + + TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value())); + return absl::OkStatus(); +} + +absl::StatusOr HandleRsWindowedEinsumLoop(HloComputation* comp, + int64_t stream_id) { + bool changed = false; + // If we have a einsum loop with only 1 dot, this means either + // the loop is not unrolled or only 1 partition is available. + // It's a no-op in either case. + if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) { + return changed; + } + for (auto inst : comp->MakeInstructionPostOrder()) { + HloInstruction* matched_dot; + // The dot we'd like to parallelize is consuming the second loop input + // as RHS. + if (Match(inst, m::Dot(&matched_dot, m::DynamicSlice(), + m::GetTupleElement(m::Parameter(), 1)))) { + // Dispatch the dot to additional compute stream. + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + ++stream_id; + changed = true; + } + + // We need to enforce the first collective-permute to be always scheduled + // at the beginning of the loop. + HloInstruction* matched_cp; + if (Match(inst, m::CollectivePermute( + &matched_cp, m::GetTupleElement(m::Parameter(), 2)))) { + TF_RETURN_IF_ERROR( + SetForceDelayForInstruction(matched_cp, /*force_delay=*/true)); + changed = true; + } + } + return changed; +} + +absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, + int64_t stream_id) { + bool changed = false; + // If we have a einsum loop with only 1 dot, this means either + // the loop is not unrolled or only 1 partition is available. + // It's a no-op in either case. + if (NumberOfInstructionsInComp(comp, HloOpcode::kDot) <= 1) { + return changed; + } + for (auto inst : comp->MakeInstructionPostOrder()) { + HloInstruction* matched_dot; + // The dot we'd like to parallelize is consuming the second loop input + // as RHS and first loop input as LHS. + if (Match(inst, m::Dot(&matched_dot, m::GetTupleElement(m::Parameter(), 0), + m::GetTupleElement(m::Parameter(), 1)))) { + // Dispatch the dot to additional compute stream. + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + ++stream_id; + TF_RETURN_IF_ERROR( + SetForceDelayForInstruction(matched_dot, /*force_delay=*/true)); + changed = true; + } + + // We need to enforce the first collective-permute to be always scheduled + // at the beginning of the loop. + HloInstruction* matched_cp; + if (Match(inst, m::CollectivePermute( + &matched_cp, m::GetTupleElement(m::Parameter(), 0)))) { + TF_RETURN_IF_ERROR( + SetForceDelayForInstruction(matched_cp, /*force_delay=*/true)); + changed = true; + } + } + return changed; +} + +} // namespace + +absl::StatusOr GpuWindowedEinsumHandler::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + XLA_VLOG_LINES( + 5, "GpuWindowedEinsumHandler::Run(), before:\n" + module->ToString()); + bool changed = false; + int64_t stream_id = hlo_query::NextChannelId(*module); + + for (HloComputation* comp : + module->MakeNonfusionComputations(execution_threads)) { + if (comp->name().find(kWindowedEinsumRsLoopName) == 0) { + VLOG(5) << "Processing computation: " << comp->name(); + TF_ASSIGN_OR_RETURN(bool comp_result, + HandleRsWindowedEinsumLoop(comp, stream_id)); + changed = comp_result; + } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) { + VLOG(5) << "Processing computation: " << comp->name(); + TF_ASSIGN_OR_RETURN(bool comp_result, + HandleAgWindowedEinsumLoop(comp, stream_id)); + changed = comp_result; + } + } + XLA_VLOG_LINES( + 5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/gpu_windowed_einsum_handler.h b/xla/service/gpu/gpu_windowed_einsum_handler.h new file mode 100644 index 0000000000000..87ec1474d576f --- /dev/null +++ b/xla/service/gpu/gpu_windowed_einsum_handler.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ +#define XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla::gpu { + +// This pass is targeting the windowed einsum optimization +// in the SPMD pipeline. It rewrites all-gather+gemm or +// gemm+reduce-scatter into sharded loops to achieve overlap +// between sharded gemms and communication. This pass will +// optimize it on GPU by annotating independent gemms with +// stream ids in the backend config. By running them in different +// streams, we can practically achieve overlap between gemms too. +class GpuWindowedEinsumHandler : public HloModulePass { + public: + absl::string_view name() const override { + return "gpu-windowed-einsum-handler"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + constexpr static const char* kWindowedEinsumRsLoopName = + "windowed_dot_general_body_rs"; + constexpr static const char* kWindowedEinsumAgLoopName = + "windowed_dot_general_body_ag"; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ diff --git a/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/xla/service/gpu/gpu_windowed_einsum_handler_test.cc new file mode 100644 index 0000000000000..c70fbf2b08d12 --- /dev/null +++ b/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_windowed_einsum_handler.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using GpuWindowedEinsumHanlderTest = HloTestBase; + +HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) { + for (auto inst : comp->instructions()) { + if (inst->name() == name) { + return inst; + } + } + return nullptr; +} + +TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsHaveStreamIds) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4 + +windowed_dot_general_body_ag.1 { + param = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0) + get-tuple-element = bf16[512,24576]{1,0} get-tuple-element(param), index=0 + collective-permute = bf16[512,24576]{1,0} collective-permute(get-tuple-element), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + get-tuple-element.1 = bf16[24576,24576]{1,0} get-tuple-element(param), index=1 + get-tuple-element.2 = bf16[2048,24576]{1,0} get-tuple-element(param), index=2 + dot.2 = bf16[512,24576]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + constant.1 = s32[4]{0} constant({0, 512, 1024, 1536}) + get-tuple-element.4 = u32[] get-tuple-element(param), index=4 + partition-id = u32[] partition-id() + add = u32[] add(get-tuple-element.4, partition-id) + constant = u32[] constant(4) + remainder = u32[] remainder(add, constant) + dynamic-slice = s32[1]{0} dynamic-slice(constant.1, remainder), dynamic_slice_sizes={1} + reshape.4 = s32[] reshape(dynamic-slice) + constant.2 = s32[] constant(0) + dynamic-update-slice = bf16[2048,24576]{1,0} dynamic-update-slice(get-tuple-element.2, dot.2, reshape.4, constant.2), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + dot.3 = bf16[512,24576]{1,0} dot(collective-permute, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + constant.3 = u32[] constant(1) + add.1 = u32[] add(get-tuple-element.4, constant.3) + add.2 = u32[] add(add.1, partition-id) + remainder.1 = u32[] remainder(add.2, constant) + dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, remainder.1), dynamic_slice_sizes={1} + reshape.5 = s32[] reshape(dynamic-slice.1) + dynamic-update-slice.1 = bf16[2048,24576]{1,0} dynamic-update-slice(dynamic-update-slice, dot.3, reshape.5, constant.2) + get-tuple-element.3 = bf16[2048,24576]{1,0} get-tuple-element(param), index=3 + add.3 = u32[] add(add.1, constant.3) + ROOT tuple = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(collective-permute, get-tuple-element.1, dynamic-update-slice.1, get-tuple-element.3, add.3) +} // windowed_dot_general_body_ag.1 + +windowed_dot_general_cond_ag { + param.1 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) parameter(0) + get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4 + constant.8 = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element.5, constant.8), direction=LT +} + +ENTRY test_main { + param.4 = bf16[1,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + reshape.8 = bf16[512,24576]{1,0} reshape(param.4) + param.5 = bf16[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + constant.18 = bf16[] constant(0) + broadcast = bf16[2048,24576]{1,0} broadcast(constant.18), dimensions={} + constant.20 = u32[] constant(0) + tuple.2 = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) tuple(reshape.8, param.5, broadcast, broadcast, constant.20) + while = (bf16[512,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[2048,24576]{1,0}, bf16[2048,24576]{1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag.1 + ROOT get-tuple-element.13 = bf16[2048,24576]{1,0} get-tuple-element(while), index=2 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloComputation* ag_loop_body = ag_loop->while_body(); + HloInstruction* inst = FindInstructionByName(ag_loop_body, "dot.2"); + EXPECT_GT(inst->backend_config()->operation_queue_id(), 0); + EXPECT_TRUE( + inst->backend_config()->force_earliest_schedule()); + + HloInstruction* cp1 = + FindInstructionByName(ag_loop_body, "collective-permute"); + EXPECT_TRUE( + cp1->backend_config()->force_earliest_schedule()); +} + +TEST_F(GpuWindowedEinsumHanlderTest, RsLoopsHaveStreamIds) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4 + +windowed_dot_general_body_rs_clone.1 { + param.2 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0) + get-tuple-element.6 = bf16[2048,24576]{1,0} get-tuple-element(param.2), index=0 + get-tuple-element.7 = bf16[24576,24576]{1,0} get-tuple-element(param.2), index=1 + get-tuple-element.9 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=2 + collective-permute.1 = bf16[512,24576]{1,0} collective-permute(get-tuple-element.9), channel_id=4, source_target_pairs={{0,2},{1,3},{2,0},{3,1}}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + constant.10 = s32[4]{0} constant({0, 512, 1024, 1536}) + get-tuple-element.11 = u32[] get-tuple-element(param.2), index=4 + constant.12 = u32[] constant(2) + add.8 = u32[] add(get-tuple-element.11, constant.12) + constant.13 = u32[] constant(1) + add.9 = u32[] add(add.8, constant.13) + partition-id.3 = u32[] partition-id() + add.10 = u32[] add(add.9, partition-id.3) + constant.9 = u32[] constant(4) + remainder.3 = u32[] remainder(add.10, constant.9) + dynamic-slice.4 = s32[1]{0} dynamic-slice(constant.10, remainder.3), dynamic_slice_sizes={1} + reshape.7 = s32[] reshape(dynamic-slice.4) + constant.11 = s32[] constant(0) + dynamic-slice.5 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.7, constant.11), dynamic_slice_sizes={512,24576} + dot.7 = bf16[512,24576]{1,0} dot(dynamic-slice.5, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + add.11 = bf16[512,24576]{1,0} add(collective-permute.1, dot.7), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + get-tuple-element.10 = bf16[512,24576]{1,0} get-tuple-element(param.2), index=3 + add.6 = u32[] add(get-tuple-element.11, partition-id.3) + remainder.2 = u32[] remainder(add.6, constant.9) + dynamic-slice.2 = s32[1]{0} dynamic-slice(constant.10, remainder.2), dynamic_slice_sizes={1} + reshape.6 = s32[] reshape(dynamic-slice.2) + dynamic-slice.3 = bf16[512,24576]{1,0} dynamic-slice(get-tuple-element.6, reshape.6, constant.11), dynamic_slice_sizes={512,24576} + dot.5 = bf16[512,24576]{1,0} dot(dynamic-slice.3, get-tuple-element.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + add.7 = bf16[512,24576]{1,0} add(get-tuple-element.10, dot.5), backend_config={"operation_queue_id":"0","wait_on_operation_queues":[]} + collective-permute.2 = bf16[512,24576]{1,0} collective-permute(add.7), channel_id=5, source_target_pairs={{0,2},{1,3},{2,0},{3,1}} + ROOT tuple.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(get-tuple-element.6, get-tuple-element.7, add.11, collective-permute.2, add.8) +} + +windowed_dot_general_cond_rs { + param.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) parameter(0) + get-tuple-element.12 = u32[] get-tuple-element(param.3), index=4 + constant.17 = u32[] constant(4) + ROOT compare.1 = pred[] compare(get-tuple-element.12, constant.17), direction=LT +} + +ENTRY main.9_spmd { + param.6 = bf16[24576,24576]{1,0} parameter(0), sharding={devices=[4,1]<=[4]} + param.7 = bf16[512,24576]{1,0} parameter(1) + param.8 = bf16[2048,24576]{1,0} parameter(2) + constant.20 = u32[] constant(0) + tuple.3 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) tuple(param.8, param.6, param.7, param.7, constant.20) + while.1 = (bf16[2048,24576]{1,0}, bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[512,24576]{1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs_clone.1 + ROOT get-tuple-element.14 = bf16[512,24576]{1,0} get-tuple-element(while.1), index=2 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* rs_loop = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloComputation* rs_loop_body = rs_loop->while_body(); + HloInstruction* inst = FindInstructionByName(rs_loop_body, "dot.7"); + EXPECT_TRUE(inst->backend_config()->operation_queue_id() > + 0); + + HloInstruction* cp1 = + FindInstructionByName(rs_loop_body, "collective-permute.1"); + EXPECT_TRUE( + cp1->backend_config()->force_earliest_schedule()); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/hlo_algorithm_denylist.cc b/xla/service/gpu/hlo_algorithm_denylist.cc index c1b8b3998f127..1679f3d95e91f 100644 --- a/xla/service/gpu/hlo_algorithm_denylist.cc +++ b/xla/service/gpu/hlo_algorithm_denylist.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,19 @@ limitations under the License. #include "xla/service/gpu/hlo_algorithm_denylist.h" +#include #include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/service/gpu/gpu_autotuning.pb.h" +#include "xla/stream_executor/dnn.h" +#include "tsl/platform/env.h" +#include "tsl/platform/status.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/hlo_algorithm_denylist.h b/xla/service/gpu/hlo_algorithm_denylist.h index 5dfcb189a954e..828903ce99458 100644 --- a/xla/service/gpu/hlo_algorithm_denylist.h +++ b/xla/service/gpu/hlo_algorithm_denylist.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_ #define XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_ -#include +#include +#include "absl/types/span.h" #include "xla/autotuning.pb.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/dnn.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/hlo_algorithm_denylist_test.cc b/xla/service/gpu/hlo_algorithm_denylist_test.cc index 18447ca4974fb..c98ac430b3e47 100644 --- a/xla/service/gpu/hlo_algorithm_denylist_test.cc +++ b/xla/service/gpu/hlo_algorithm_denylist_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "xla/stream_executor/dnn.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" -#include "tsl/platform/resource_loader.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/hlo_fusion_analysis.cc b/xla/service/gpu/hlo_fusion_analysis.cc index 05182f8e660f1..529ebf9008d5d 100644 --- a/xla/service/gpu/hlo_fusion_analysis.cc +++ b/xla/service/gpu/hlo_fusion_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,85 +16,34 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include -#include #include #include #include -#include #include -#include #include #include "absl/algorithm/container.h" -#include "absl/container/node_hash_map.h" #include "absl/log/check.h" -#include "absl/numeric/bits.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_computation.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/kernel_mapping_scheme.h" -#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/stream_executor/device_description.h" -#include "xla/union_find.h" -#include "tsl/platform/macros.h" namespace xla { namespace gpu { namespace { -const auto kDimX = TilingScheme::DimX; -const auto kLinearIndexingX = TilingScheme::LinearIndexingX; -const auto kStridedIndexingX = TilingScheme::StridedIndexingX; - -std::optional ComputeTransposeTilingScheme( - const std::optional& tiled_transpose) { - if (!tiled_transpose) { - return std::nullopt; - } - - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the input shape. - Vector3 dims = tiled_transpose->dimensions; - Vector3 order = tiled_transpose->permutation; - - Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; - Vector3 tile_sizes{1, 1, 1}; - tile_sizes[order[2]] = WarpSize() / kNumRows; - Vector3 num_threads{1, 1, WarpSize()}; - num_threads[order[2]] = kNumRows; - - return TilingScheme( - /*permuted_dims*/ permuted_dims, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1, - /*scaling_factor=*/1, - /*tiling_dimensions=*/{order[2], 2}); -} - -// Returns true if `instr` is a non-strided slice. -bool IsSliceWithUnitStrides(const HloInstruction* instr) { - auto slice = DynCast(instr); - return slice && absl::c_all_of(slice->slice_strides(), - [](int64_t stride) { return stride == 1; }); -} - // Returns true if the fusion output contains non-strided slices only. bool IsInputFusibleNonStridedSlices( const std::vector& fusion_roots) { @@ -111,139 +60,6 @@ bool AllSliceInputsAreCompatible( }); } -bool MayPreventVectorization(const HloFusionAdaptor& fusion) { - // An empirically chosen constant: unrolling concat with a large amount of - // arguments causes excessive register spilling. - static constexpr int kMaxConcatArgumentsForUnrolling = 10; - return HloAnyOf(fusion.GetRoots(), fusion, [&](auto node) { - switch (node.opcode()) { - case HloOpcode::kReduceWindow: - case HloOpcode::kSort: - case HloOpcode::kDot: - case HloOpcode::kSin: - case HloOpcode::kCos: - case HloOpcode::kTan: - case HloOpcode::kPower: - case HloOpcode::kAtan2: - return true; - case HloOpcode::kConcatenate: - return node.instruction().operand_count() > - kMaxConcatArgumentsForUnrolling; - case HloOpcode::kReduce: - return node.instruction().shape().tuple_shapes_size() > 1; - default: - return false; - } - }); -} - -// Determines if we enable the row optimized codegen. When we have a fusion with -// only point-wise operations, scalar broadcasting and row broadcasting, we can -// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in -// particular on A100. The int is the number of inputs with rank `out_rank`. Its -// value is only defined if row vectorization is enabled. -std::pair RowVectorizationEnabled( - const HloFusionAdaptor& fusion, int64_t out_rank) { - auto roots = fusion.GetRoots(); - const auto is_row_major = [](auto instr) { - // Only tested when the inputs are row-major. So only enable that case. - // Maybe it would work if only the inner dimensions is contiguous. - return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout()); - }; - bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && - is_row_major(roots[0]); - if (!row_vectorized) { - return {false, 0}; - } - - // Check that the operations in the fusion are supported. Each - // supported operation (or category) must be manually vetted as XLA - // only unrolls and relies on LLVM to vectorize. But this is brittle. - // Currently tested and supported operations: - // Elementwise, scalar and row broadcasting. - // - // We also detect at the same time if there is a row broadcasting - // operation. - int num_big_inputs = 0; - bool some_row_broadcasting = false; - HloBfsConsumersFirstTraversal( - roots, fusion, - [&](auto node) -> TraversalResult { - if (!row_vectorized) { - return TraversalResult::kAbortTraversal; - } - - if (node.instruction().IsElementwise()) { - return TraversalResult::kVisitOperands; - } - - switch (node.opcode()) { - case HloOpcode::kConstant: - return TraversalResult::kDoNotVisitOperands; - case HloOpcode::kParameter: - return TraversalResult::kVisitOperands; - case HloOpcode::kBroadcast: { - auto dims = node.instruction().dimensions(); - if (dims.empty()) { - return TraversalResult::kVisitOperands; - } - - if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) { - some_row_broadcasting = true; - return TraversalResult::kVisitOperands; - } - TF_FALLTHROUGH_INTENDED; - } - default: - VLOG(2) << "Row vectorization not enabled due to: " - << node.ToString(); - row_vectorized = false; - return TraversalResult::kAbortTraversal; - } - }, - [&](auto argument) { - if (argument.shape().rank() == out_rank) { - ++num_big_inputs; - } - if (!is_row_major(argument)) { - row_vectorized = false; - } - }); - // Trigger only when there is a row broadcasting. - return std::make_pair(row_vectorized && some_row_broadcasting, - num_big_inputs); -} - -// Computes the maximum valid unroll factor for a given instruction. -int ComputeMaxUnrollFactor(int64_t num_elements) { - constexpr int kMaxUnrollFactor = 4; - for (int i = kMaxUnrollFactor; i > 1; i /= 2) { - if (num_elements % i == 0) { - return i; - } - } - return 1; -} - -// For a row reduction, returns the number of rows we can process in parallel -// per warp. -int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; -} - -int64_t NearestPowerOfTwo(int64_t v) { - if (v < 0) { - return 0; - } - int64_t upper = absl::bit_ceil(v); - int64_t lower = upper >> 1; - return upper - v < v - lower ? upper : lower; -} - // Returns a description of a transpose hero, that is compatible with all roots. // // A root is compatible with the transpose hero if: @@ -310,13 +126,10 @@ HloFusionAnalysis::HloFusionAnalysis( fusion_heroes_(std::move(fusion_heroes)), device_info_(device_info), tiled_transpose_(tiled_transpose), - input_output_info_(std::move(input_output_info)), - reduction_codegen_info_(ComputeReductionCodegenInfo(FindHeroReduction())), - transpose_tiling_scheme_(ComputeTransposeTilingScheme(tiled_transpose_)), - loop_fusion_config_(ComputeLoopFusionConfig()) {} + input_output_info_(std::move(input_output_info)) {} // static -StatusOr HloFusionAnalysis::Create( +HloFusionAnalysis HloFusionAnalysis::Create( FusionBackendConfig backend_config, std::unique_ptr fusion, const se::DeviceDescription* device_info) { @@ -351,12 +164,14 @@ StatusOr HloFusionAnalysis::Create( } // static -StatusOr HloFusionAnalysis::Create( +HloFusionAnalysis HloFusionAnalysis::Create( const HloFusionInstruction* fusion, const se::DeviceDescription* device_info) { CHECK(device_info != nullptr); - TF_ASSIGN_OR_RETURN(auto backend_config, - fusion->backend_config()); + FusionBackendConfig backend_config = + fusion->has_backend_config() + ? fusion->backend_config()->fusion_backend_config() + : FusionBackendConfig::default_instance(); return Create(std::move(backend_config), HloFusionAdaptor::ForInstruction(fusion), device_info); } @@ -366,29 +181,80 @@ bool HloFusionAnalysis::HasConsistentTransposeHeros() const { return tiled_transpose_.has_value(); } +static bool UseConcatenateFusion( + const std::vector& roots, + const std::vector& heroes) { + if (heroes.size() != 1) return false; + if (heroes.front()->opcode() != HloOpcode::kConcatenate) return false; + // The concat emitter does not support multiple outputs yet. TODO(csigg): fix. + if (roots.front()->shape().IsTuple()) return false; + // Limit the number of operands because the concat emitter produces code for + // each operand, hurting occupancy. + if (heroes.front()->operand_count() > 4) return false; + // The loop emitter is faster when warp divergence and occupancy are both low. + // TODO(csigg): exclude this case. + return true; +} + HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() const { if (fusion_backend_config_.kind() == kCustomFusionKind) { return EmitterFusionKind::kCustomFusion; } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (fusion_backend_config_.kind() == kTritonGemmFusionKind || fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) { return EmitterFusionKind::kTriton; } + + if (fusion_backend_config_.kind() == kCuDnnFusionKind) { + return EmitterFusionKind::kCuDnn; + } #endif if (input_output_info_.has_4_bit_input || input_output_info_.has_4_bit_output) { - // Only loop fusions currently can handle int4 inputs/outputs, due to the - // special handling with IrArray needed to deal with two values occupying a - // single byte. + // Only loop and input slice fusions currently can handle int4 + // inputs/outputs, due to the special handling with IrArray needed to deal + // with two values occupying a single byte. + if (fusion_roots_.size() > 1 && + IsInputFusibleNonStridedSlices(fusion_roots_) && + AllSliceInputsAreCompatible(fusion_roots_)) { + return EmitterFusionKind::kInputSlices; + } return EmitterFusionKind::kLoop; } + const HloInstruction* first_reduce_hero = nullptr; for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) { if (IsRealReductionHero(*root, *hero)) { + first_reduce_hero = hero; + break; + } + } + if (first_reduce_hero != nullptr) { + bool valid_shapes = true; + Shape hero_operand_shape = first_reduce_hero->operand(0)->shape(); + for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) { + if (root == first_reduce_hero) { + continue; + } + if (!IsRealReductionHero(*root, *hero)) { + // Needs to have a compatible shape to the reduce operand (compatible + // meaning same number of elements). + if (ShapeUtil::ElementsIn(root->shape()) != + ShapeUtil::ElementsIn(hero_operand_shape)) { + valid_shapes = false; + break; + } + } else if (!AreReductionsMultiOutputFusionCompatible(hero, + first_reduce_hero)) { + valid_shapes = false; + break; + } + } + if (valid_shapes) { return EmitterFusionKind::kReduction; } } @@ -410,56 +276,11 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kScatter; } - return EmitterFusionKind::kLoop; -} - -StatusOr HloFusionAnalysis::GetLaunchDimensions() const { - auto emitter_fusion_kind = GetEmitterFusionKind(); - switch (emitter_fusion_kind) { - case EmitterFusionKind::kLoop: { - // Disable experimental block size if few_waves or row_vectorized enabled. - auto loop_fusion_config = GetLoopFusionConfig(); - return CalculateLaunchDimensions(GetElementShape(), *device_info_, - *loop_fusion_config); - } - case EmitterFusionKind::kReduction: { - auto* reduction_codegen_info = GetReductionCodegenInfo(); - const TilingScheme& tiling_scheme = - reduction_codegen_info->GetTilingScheme(); - size_t blocks_y = reduction_codegen_info->GetIndexGroups().size(); - return LaunchDimensions( - {/*x=*/tiling_scheme.GetNumberOfBlocksPhysical(), - /*y=*/static_cast(blocks_y), /*z=*/1}, - {/*x=*/tiling_scheme.GetNumThreadsPerBlockPhysical(), - /*y=*/1, /*z=*/1}); - } - case EmitterFusionKind::kTranspose: { - auto* tiling_scheme = GetTransposeTilingScheme(); - return LaunchDimensions(tiling_scheme->GetNumberOfBlocksPhysical(), - tiling_scheme->GetNumThreadsPerBlockPhysical()); - } - case EmitterFusionKind::kInputSlices: { - auto* root = fusion_roots().front(); - const auto& shape = root->operands()[0]->shape(); - constexpr int kUnrollFactor = 1; - return CalculateLaunchDimensions(shape, *device_info_, {kUnrollFactor}); - } - case EmitterFusionKind::kScatter: { - const auto& root_shape = fusion_roots().front()->shape(); - int64_t num_elements = ShapeUtil::ElementsIn(root_shape); - int unroll_factor = num_elements % 4 == 0 ? 4 - : num_elements % 2 == 0 ? 2 - : 1; - return CalculateLaunchDimensions(root_shape, *device_info_, - {unroll_factor, /*few_waves=*/false}); - } - case EmitterFusionKind::kCustomFusion: - return absl::UnimplementedError( - "GetLaunchDimensions is not implemented for custom fusions"); - case EmitterFusionKind::kTriton: - return absl::UnimplementedError( - "GetLaunchDimensions is not implemented for Triton fusions"); + if (UseConcatenateFusion(fusion_roots_, fusion_heroes_)) { + return EmitterFusionKind::kConcatenate; } + + return EmitterFusionKind::kLoop; } const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { @@ -480,464 +301,25 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -std::optional -HloFusionAnalysis::ComputeLoopFusionConfig() const { - int unroll_factor = 1; - // Unrolling is good to read large inputs with small elements - // due to vector loads, but increases the register pressure when one - // thread has to produce multiple output elements. - // Therefore for fusions with small outputs prefer to use one thread - // per output element = no unroll. - // Call 'small' fusions that use less threads than the GPU has. - int64_t num_elements = ShapeUtil::ElementsIn(GetElementShape()); - int64_t n_threads_max = - device_info_->threads_per_core_limit() * device_info_->core_count(); - if (num_elements >= n_threads_max && !MayPreventVectorization(*fusion_)) { - unroll_factor = ComputeMaxUnrollFactor(num_elements); - } - // CHECK that unroll_factor is a power-of-2, as needed by the logic below. - CHECK(absl::has_single_bit(static_cast(unroll_factor))); - if (input_output_info_.has_4_bit_output && unroll_factor == 1) { - // Ensure a single thread writes to a byte containing two int4 values by - // setting unroll_factor to 2. unroll_factor is always a power of 2, so - // setting it to 2 here ensures unroll_factor is even when there are 4-bit - // outputs. Setting unroll_factor is safe even if there are an odd number of - // elements, as the parallel loop emitter will insert a bounds check in this - // case to ensure the out-of-bounds element is not computed and written. - // Setting unroll_factor is safe even if MayPreventVectorization returns - // false, as the MayPreventVectorization check is an optimization, not a - // correctness requirement. - unroll_factor = 2; - } - VLOG(2) << "Unroll factor: " << unroll_factor; - - if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) { - // Only the unroll factor is used for scatter. - return LaunchDimensionsConfig{unroll_factor}; - } - - bool row_vectorized; - int num_big_inputs; - std::tie(row_vectorized, num_big_inputs) = - RowVectorizationEnabled(*fusion_, GetElementShape().rank()); - bool few_waves = !HloAnyOf(fusion_->GetRoots(), *fusion_, [&](auto instr) { - if (instr.opcode() == HloOpcode::kParameter || - instr.opcode() == HloOpcode::kConstant || - HloInstruction::IsOpElementwise(instr.opcode())) { - return false; - } - if (auto broadcast = - DynCast(&instr.instruction())) { - if (broadcast->dimensions().empty() || - // More than 3 big inputs cause a speed regression. - (row_vectorized && num_big_inputs <= 3)) { - return false; - } - } - VLOG(2) << "few_waves not enabled due to: " - << instr.instruction().ToString(); - return true; - }); - - LaunchDimensionsConfig launch_config{unroll_factor, few_waves, - row_vectorized}; - // Check that the shapes is supported. - if (launch_config.row_vectorized && - ThreadsPerBlockRowVectorized(GetElementShape(), *device_info_, - launch_config) <= 0) { - VLOG(2) << "Cancelling row_vectorization as the shape isn't supported."; - launch_config.row_vectorized = false; - launch_config.few_waves = false; - } - return launch_config; -} - -const Shape& HloFusionAnalysis::GetElementShape() const { - const Shape* shape = &fusion_roots_.front()->shape(); - while (shape->IsTuple()) { - shape = &shape->tuple_shapes(0); - } - return *shape; -} - -int64_t HloFusionAnalysis::MaxBeneficialColumnReductionUnrollBasedOnBlockSize() - const { - // Some callers use this analysis with an invalid device info. - // TODO(jreiffers): Fix that. - if (device_info_->core_count() == 0) return 1; - - int64_t num_reduce_output_elems = 0; - for (const HloInstruction* root : fusion_roots()) { - if (!IsReductionFromOrToContiguousDimensions(*root)) { - continue; - } - const Shape* output_shape = &root->shape(); - // Unwrap multi-output reduction. All outputs should be the same shape. - if (output_shape->IsTuple()) { - output_shape = &output_shape->tuple_shapes()[0]; - } - num_reduce_output_elems = - std::max(num_reduce_output_elems, ShapeUtil::ElementsIn(*output_shape)); - } - - // A column reduction that's unrolled N times uses one warp to generate N - // output elements. The block size is always 32 warps = 1024 threads. - int64_t num_blocks = CeilOfRatio(num_reduce_output_elems, int64_t{32}); - int64_t num_threads = num_blocks * 1024; - // Number of SMs we can saturate with this work. - int num_cores = - CeilOfRatio(num_threads, device_info_->threads_per_core_limit()); - return static_cast(CeilOfRatio(num_cores, device_info_->core_count())); -} - -// Divides `num_reduces` reduces into groups. Different groups will be executed -// in parallel. Generally speaking, we'd like to run the reduce instructions -// in parallel without incurring too much recomputation overhead. The current -// heuristic is to place reduce instructions who share nothing or only -// (broadcasted) scalars/constants into different groups; otherwise, they are -// placed in the same group. Non-reduce instructions always go with the reduce -// instructions into the same group so long as they share any predecessors. -std::vector> -HloFusionAnalysis::GroupDisjointReductions() const { - const int num_fusion_outputs = fusion_roots().size(); - - CHECK_NE(0, num_fusion_outputs); - if (num_fusion_outputs == 1) { - return {{fusion_roots()[0]}}; - } - - absl::node_hash_map> - disjoint_sets; - - // TODO(b/249976438): we currently do not treat properly - // aliasing between inputs and outputs of the fusion, so for now put all - // non-reduction roots into one group to avoid read-after-write conflicts. - std::optional first_non_reduction_root = std::nullopt; - - absl::node_hash_map> - reachable_outputs; - absl::flat_hash_set roots_with_reduction; - auto roots = fusion_->GetRoots(); - for (auto [root, hero] : llvm::zip(roots, fusion_heroes_)) { - disjoint_sets[root].Get() = root; - reachable_outputs[root].insert(root); - if (IsRealReductionHero(root.instruction(), *hero)) { - roots_with_reduction.insert(root); - } else if (first_non_reduction_root) { - disjoint_sets[*first_non_reduction_root].Merge(&disjoint_sets[root]); - } else { - first_non_reduction_root = root; - } - } - - std::vector instructions; - HloBfsConsumersFirstTraversal( - roots, *fusion_, - [&](HloInstructionAdaptor consumer) { - auto& consumer_reachable = reachable_outputs[consumer]; - for (auto producer : consumer.GetOperands()) { - reachable_outputs[producer].insert(consumer_reachable.begin(), - consumer_reachable.end()); - } - instructions.push_back(consumer); - return TraversalResult::kVisitOperands; - }, - [&](HloInstructionAdaptor argument) { - instructions.push_back(argument); - }); - - for (auto instr : instructions) { - const auto& reachable = reachable_outputs[instr]; - std::vector reached_output_ids; - bool added_to_reduce = false; - for (auto output : roots) { - bool has_real_hero = roots_with_reduction.contains(output); - if (has_real_hero && - (hlo_query::IsBroadcastedConstantOrScalar(instr.instruction()))) { - if (added_to_reduce) { - // Do not group more than one output reduce instructions through - // broadcasted constants or scalars, as the recomputation should be - // acceptable. - VLOG(3) << "Skip broadcasted constant or scalar " << instr.ToString(); - continue; - } - } - // Now group output instructions if they have common predecessors. - if (reachable.contains(output)) { - VLOG(3) << "Reaching " << output.ToString() << " from " - << instr.ToString(); - reached_output_ids.push_back(output); - if (has_real_hero) { - added_to_reduce = true; - } - } - } - for (size_t j = 1; j < reached_output_ids.size(); ++j) { - disjoint_sets[reached_output_ids[0]].Merge( - &disjoint_sets[reached_output_ids[j]]); - } - } - - // Place output instructions in the same set into the same group. - ConstHloInstructionMap> groups; - for (auto root : roots) { - groups[&disjoint_sets[root].Get().instruction()].push_back( - &root.instruction()); - } - - std::vector> ret; - ret.reserve(groups.size()); - absl::c_for_each( - groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); - return ret; -} - -bool HloFusionAnalysis::IsUnrollingColumnReductionBeneficial( - const Shape& input_shape, int64_t num_kept_minor, - bool reduction_is_race_free) const { - if (num_kept_minor % (WarpSize() * 2) != 0) { - return false; - } - if (input_shape.dimensions(input_shape.rank() - 1) < 64) { - return false; - } - - int64_t can_be_vectorized = 0; - int64_t cannot_be_vectorized = 0; - absl::flat_hash_set use_chain_endings; - - for (const HloInstruction* fusion_root : fusion_roots()) { - if (!reduction_is_race_free && - IsReductionFromOrToContiguousDimensions(*fusion_root)) { - // Atomics cannot be vectorized. - cannot_be_vectorized++; - } else { - can_be_vectorized++; - } - use_chain_endings.insert(fusion_root); - } - - // Fusion inputs that have the same dimension as the reduce input and - // only involve in element-wise operations can be vectorized. - absl::flat_hash_set reachable_through_non_elementwise; - HloBfsConsumersFirstTraversal( - fusion_->GetRoots(), *fusion_, [&](auto consumer) { - // We check if the consumer is elementwise, unless this edge is a - // virtual edge that only exists in partially fused HLO. There are two - // types of such edges: - // 1. Edges from producers outside a fusion to a parameter instruction - // within a fusion. Here, the producer is a parameter of the fusion - // instruction. - // 2. Edges from fusion roots to fusion nodes. - if (reachable_through_non_elementwise.contains(consumer) || - (!consumer.instruction().IsElementwise() && - !use_chain_endings.contains(&consumer.instruction()))) { - for (auto producer : consumer.GetOperands()) { - reachable_through_non_elementwise.insert(producer); - } - } - return TraversalResult::kVisitOperands; - }); - - int64_t num_elements = ShapeUtil::ElementsIn(input_shape); - FindFusionArguments(*fusion_, [&](auto arg) { - if (!reachable_through_non_elementwise.contains(arg) && - ShapeUtil::SameDimensions(input_shape, arg.shape())) { - ++can_be_vectorized; - } - - // Fusion inputs with more elements than the reduce op input must - // participate in non-elementwise operations and we assume that they are - // not vectorizable for the purpose of estimating the benefit of - // unrolling. If the kernel is unrolled even with such an assumption, - // and the accesses to those inputs turn out to be vectorizable, the - // compiler will still vectorize them. - if (ShapeUtil::ElementsIn(arg.shape()) > num_elements) { - ++cannot_be_vectorized; - } - }); - - if (can_be_vectorized < cannot_be_vectorized) { - return false; - } - - return MaxBeneficialColumnReductionUnrollBasedOnBlockSize() > 1; -} - -bool HloFusionAnalysis::CanVectorizeReduction( - const ReductionDimensions& reduction_dimensions, int num_threads_x, - Vector3 reduction_tiling, const Shape& input_shape, - bool reduction_is_race_free) const { - if (!reduction_dimensions.is_row_reduction) { - return IsUnrollingColumnReductionBeneficial( - input_shape, reduction_dimensions.dimensions[kDimX], - reduction_is_race_free); - } - - if (reduction_dimensions.dimensions[kDimX] % 2 != 0 || - MayPreventVectorization(*fusion_)) { - return false; - } - - // Enabling vectorization if number of threads is <= warpsize leads to half or - // more of the threads not doing any work. - if (reduction_dimensions.is_row_reduction && num_threads_x <= WarpSize()) { - return false; - } - - const auto* cuda_cc = std::get_if( - &device_info_->gpu_compute_capability()); - if (cuda_cc == nullptr) return false; - if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return true; - if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) { - return input_output_info_.smallest_input_dtype_bits <= 32 && - reduction_dimensions.dimensions[kDimX] % - (reduction_tiling[2] * num_threads_x) == - 0; - } - return false; -} - -int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( - const ReductionDimensions& reduction_dimensions) const { - int64_t dimx = reduction_dimensions.dimensions[kDimX]; - if (reduction_dimensions.is_row_reduction && dimx <= 128) { - int rows_per_warp = RowReductionGetRowsPerWarp(dimx); - const auto* cuda_cc = std::get_if( - &device_info_->gpu_compute_capability()); - if (cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { - return rows_per_warp * 3; - } - return rows_per_warp * 5; - } - return 1; -} - -std::optional -HloFusionAnalysis::ComputeReductionCodegenInfo( - const HloInstruction* hero_reduction) const { - if (!hero_reduction) { - return std::nullopt; - } - - Shape input_shape = hero_reduction->operand(0)->shape(); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction - << " " << reduction_dimensions.dimensions[0] << " " - << reduction_dimensions.dimensions[1] << " " - << reduction_dimensions.dimensions[2]; - Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); - - int64_t fan_out = fusion_roots().size(); - int64_t num_threads_y = - reduction_dimensions.is_row_reduction ? 1 : WarpSize(); - int64_t num_threads_x = [&] { - if (reduction_dimensions.is_row_reduction) { - if (RowReductionGetRowsPerWarp(reduction_dimensions.dimensions[2]) > 1) { - return reduction_dimensions.dimensions[2]; - } - // Use 512 as default block size (threads per block) for row reductions. - // For multi-output fusions, reduce the block size further to decrease - // register pressure when multiple outputs are computed by each thread. - int64_t max_block_size = std::max( - MinThreadsXRowReduction(hero_reduction->GetModule()->config()), - static_cast(512LL / NearestPowerOfTwo(fan_out))); - return std::min(max_block_size, - RoundUpTo(CeilOfRatio(reduction_dimensions.dimensions[2], - reduction_tiling[2]), - WarpSize())); - } - return WarpSize(); - }(); - - TilingScheme::IndexingOrder indexing_order = - reduction_dimensions.is_row_reduction ? kStridedIndexingX - : kLinearIndexingX; - auto instr_index_groups = GroupDisjointReductions(); - int64_t shmem_usage = ReductionProjectedShmemUsageBytes(reduction_dimensions, - instr_index_groups); - const int64_t shmem_budget = device_info_->shared_memory_per_block(); - bool reduction_is_race_free = ReductionIsRaceFree( - hero_reduction->GetModule()->config(), reduction_dimensions); - bool vectorize = - // Vectorization might cause us to run out of budget. - (shmem_usage * 2 <= shmem_budget) && - CanVectorizeReduction(reduction_dimensions, num_threads_x, - reduction_tiling, input_shape, - reduction_is_race_free); - int vector_size = vectorize ? 2 : 1; - - // TODO(b/283542954): Autotune num_partial_results? This can make a big - // difference, e.g. by affecting register spilling. - int num_partial_results = 1; - if (!reduction_dimensions.is_row_reduction && vectorize) { - int smallest_input_dtype_bits = - input_output_info_.smallest_input_dtype_bits; - if (smallest_input_dtype_bits <= 32) { - // Make sure to use all the data read at once. - // Instead of hardcoding the granularity, we can query the granularity we - // need like this: - // size_t granularity = 0; - // CUresult res = cuCtxGetLimit(&granularity, - // CU_LIMIT_MAX_L2_FETCH_GRANULARITY); // 0x05 - // But we need a context to be active. Which isn't the case here. - num_partial_results = std::min(64 / smallest_input_dtype_bits, 8); - - // Limit register pressure for MOF, but still use a minimum of 2. - num_partial_results /= fan_out; - // We can't go below 2 for the unroll factor -- if we wanted to use 1 as - // the unroll factor, we should have set this reduction as unvectorized. - num_partial_results = std::max(num_partial_results, 2); - } else { - num_partial_results = 2; - } - - while (num_partial_results != 1 && - shmem_usage * num_partial_results > shmem_budget) { - num_partial_results /= 2; - } - reduction_tiling[kDimX] *= num_partial_results; - } - - VLOG(3) << "Each thread will produce " << num_partial_results << " output(s)"; - - Vector3 num_threads = {1, num_threads_y, num_threads_x}; - int virtual_thread_scaling_factor = - CalculateVirtualThreadScalingFactorForReduction(reduction_dimensions); - VLOG(2) << "Using virtual thread scaling: " << virtual_thread_scaling_factor; - - TilingScheme tiling_scheme(reduction_dimensions.dimensions, reduction_tiling, - num_threads, indexing_order, vector_size, - virtual_thread_scaling_factor); - return ReductionCodegenInfo( - tiling_scheme, num_partial_results, reduction_dimensions.is_row_reduction, - reduction_is_race_free, std::move(instr_index_groups), hero_reduction); -} - -std::optional AnalyzeProducerConsumerFusion( +HloFusionAnalysis AnalyzeProducerConsumerFusion( const HloInstruction& producer, const HloInstruction& consumer, const se::DeviceDescription& device_info) { - auto ret = HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), + return HloFusionAnalysis::Create( + consumer.has_backend_config() + ? consumer.backend_config()->fusion_backend_config() + : producer.backend_config() + ->fusion_backend_config(), std::make_unique( HloFusionAdaptor::ForInstruction(&producer), HloFusionAdaptor::ForInstruction(&consumer)), &device_info); - if (!ret.ok()) return std::nullopt; - return {std::move(*ret)}; } -std::optional AnalyzeFusion( - const HloInstruction& consumer, const se::DeviceDescription& device_info) { - auto ret = HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), +HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer, + const se::DeviceDescription& device_info) { + return HloFusionAnalysis::Create( + consumer.backend_config()->fusion_backend_config(), HloFusionAdaptor::ForInstruction(&consumer), &device_info); - if (!ret.ok()) return std::nullopt; - return {std::move(*ret)}; } } // namespace gpu diff --git a/xla/service/gpu/hlo_fusion_analysis.h b/xla/service/gpu/hlo_fusion_analysis.h index 786fe3d8a86d9..26d54247bff47 100644 --- a/xla/service/gpu/hlo_fusion_analysis.h +++ b/xla/service/gpu/hlo_fusion_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,16 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_HLO_FUSION_ANALYSIS_H_ #define XLA_SERVICE_GPU_HLO_FUSION_ANALYSIS_H_ +#include #include -#include #include -#include "xla/hlo/ir/hlo_computation.h" +#include "absl/log/check.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/kernel_mapping_scheme.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/reduction_utils.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -44,62 +40,58 @@ class HloFusionAnalysis { kTriton, kReduction, kTranspose, + kConcatenate, kInputSlices, kScatter, + kCuDnn, }; - static StatusOr Create( - FusionBackendConfig backend_config, - std::unique_ptr fusion, - const se::DeviceDescription* device_info); - static StatusOr Create( - const HloFusionInstruction* fusion, - const se::DeviceDescription* device_info); + // Precomputed information about inputs (arguments) and outputs (roots) of the + // fusion. + struct InputOutputInfo { + bool has_4_bit_input; + bool has_4_bit_output; + int smallest_input_dtype_bits; + }; + + static HloFusionAnalysis Create(FusionBackendConfig backend_config, + std::unique_ptr fusion, + const se::DeviceDescription* device_info); + static HloFusionAnalysis Create(const HloFusionInstruction* fusion, + const se::DeviceDescription* device_info); const std::vector& fusion_roots() const { return fusion_roots_; } + const std::vector& fusion_heroes() const { + return fusion_heroes_; + } const HloFusionAdaptor& fusion() const { return *fusion_; } // Determines the fusion type for the emitter. EmitterFusionKind GetEmitterFusionKind() const; - // Determines the launch dimensions for the fusion. The fusion kind must not - // be `kTriton`. - StatusOr GetLaunchDimensions() const; + // Returns the hero reduction of the computation. + const HloInstruction* FindHeroReduction() const; - // Calculates the reduction information. Returns `nullptr` if the fusion is - // not a reduction. - const ReductionCodegenInfo* GetReductionCodegenInfo() const { - return reduction_codegen_info_.has_value() ? &*reduction_codegen_info_ - : nullptr; - } + const se::DeviceDescription& device_info() const { return *device_info_; } - // Calculates the transpose tiling information. Returns `nullptr` if the - // fusion is not a transpose. - const TilingScheme* GetTransposeTilingScheme() const { - return transpose_tiling_scheme_.has_value() ? &*transpose_tiling_scheme_ - : nullptr; + const FusionBackendConfig& fusion_backend_config() const { + return fusion_backend_config_; } - // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a - // loop. - const LaunchDimensionsConfig* GetLoopFusionConfig() const { - return loop_fusion_config_.has_value() ? &*loop_fusion_config_ : nullptr; + // Returns the tiled transpose description. Requires that GetEmitterFusionKind + // returns kTranspose. + const TransposeDescription& tiled_transpose() const { + CHECK(tiled_transpose_.has_value()); + return *tiled_transpose_; } - // Returns the hero reduction of the computation. - const HloInstruction* FindHeroReduction() const; + const InputOutputInfo& input_output_info() const { + return input_output_info_; + } private: - // Precomputed information about inputs (arguments) and outputs (roots) of the - // fusion. - struct InputOutputInfo { - bool has_4_bit_input; - bool has_4_bit_output; - int smallest_input_dtype_bits; - }; - HloFusionAnalysis(FusionBackendConfig fusion_backend_config, std::vector fusion_roots, std::unique_ptr fusion, @@ -108,22 +100,6 @@ class HloFusionAnalysis { std::optional tiled_transpose, InputOutputInfo input_output_info); - const Shape& GetElementShape() const; - int64_t MaxBeneficialColumnReductionUnrollBasedOnBlockSize() const; - std::vector> GroupDisjointReductions() - const; - bool IsUnrollingColumnReductionBeneficial(const Shape& input_shape, - int64_t num_kept_minor, - bool reduction_is_race_free) const; - bool CanVectorizeReduction(const ReductionDimensions& reduction_dimensions, - int num_threads_x, Vector3 reduction_tiling, - const Shape& input_shape, - bool reduction_is_race_free) const; - int CalculateVirtualThreadScalingFactorForReduction( - const ReductionDimensions& reduction_dimensions) const; - std::optional ComputeReductionCodegenInfo( - const HloInstruction* hero_reduction) const; - std::optional ComputeLoopFusionConfig() const; bool HasConsistentTransposeHeros() const; FusionBackendConfig fusion_backend_config_; @@ -133,21 +109,18 @@ class HloFusionAnalysis { const se::DeviceDescription* device_info_; std::optional tiled_transpose_; InputOutputInfo input_output_info_; - - std::optional reduction_codegen_info_; - std::optional transpose_tiling_scheme_; - std::optional loop_fusion_config_; }; // Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer // into consumer. -std::optional AnalyzeProducerConsumerFusion( +HloFusionAnalysis AnalyzeProducerConsumerFusion( const HloInstruction& producer, const HloInstruction& consumer, const se::DeviceDescription& device_info); + // Creates a HloFusionAnalysis that analyzes just consumer as a standalone // fusion. -std::optional AnalyzeFusion( - const HloInstruction& consumer, const se::DeviceDescription& device_info); +HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer, + const se::DeviceDescription& device_info); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/hlo_fusion_analysis_test.cc b/xla/service/gpu/hlo_fusion_analysis_test.cc index 0a02760077547..eb5add834a553 100644 --- a/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/hlo_fusion_analysis.h" + +#include #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -27,7 +30,7 @@ namespace { class HloFusionAnalysisTest : public HloTestBase {}; TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -41,26 +44,23 @@ TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) { %p1 = f32[] parameter(1) %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add ROOT %bitcast = s32[] bitcast(%reduce) - })") - .value(); + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info); - ASSERT_NE(analysis, std::nullopt); - EXPECT_EQ(analysis->GetEmitterFusionKind(), + EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kLoop); auto analysis_fused = AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); - ASSERT_NE(analysis_fused, std::nullopt); - EXPECT_EQ(analysis_fused->GetEmitterFusionKind(), + EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } TEST_F(HloFusionAnalysisTest, ReductionWithMultipleUsers) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -69,32 +69,34 @@ TEST_F(HloFusionAnalysisTest, ReductionWithMultipleUsers) { ROOT add = f32[] add(p0, p1) } - ENTRY main { + fused_computation { %p0 = f32[1024] parameter(0) %p1 = f32[] parameter(1) %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add %negate = f32[] negate(%reduce) %log = f32[] log(%reduce) ROOT %tuple = (f32[], f32[]) tuple(%negate, %log) - })") - .value(); + } + + ENTRY main { + %p0 = f32[1024] parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kLoop, calls=fused_computation + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, - HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), - HloFusionAdaptor::ForComputation(module->entry_computation()), - &device_info)); - // This fusion cannot use the reduction emitter because the reduce has two - // users. + auto analysis = HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()), + &device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kLoop); + HloFusionAnalysis::EmitterFusionKind::kReduction); } TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusion) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -114,22 +116,20 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusion) { %p0 = f32[1024] parameter(0) %p1 = f32[] parameter(1) ROOT %fusion = f32[] fusion(%p0, %p1), kind=kInput, calls=fused_computation - })") - .value(); + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN( - auto analysis, HloFusionAnalysis::Create( - FusionBackendConfig::default_instance(), - HloFusionAdaptor::ForInstruction(root), &device_info)); + auto analysis = HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction(root), &device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -149,8 +149,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) { %p1 = f32[] parameter(1) %fusion = f32[] fusion(%p0, %p1), kind=kInput, calls=fusion ROOT %negate = f32[] negate(%fusion) - })") - .value(); + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); @@ -158,13 +157,12 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) { auto analysis = AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); - ASSERT_NE(analysis, std::nullopt); - EXPECT_EQ(analysis->GetEmitterFusionKind(), + EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInConsumer) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -183,21 +181,19 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInConsumer) { %p1 = f32[] parameter(1) %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add ROOT %fusion = f32[] fusion(%reduce), kind=kInput, calls=fusion - })") - .value(); + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); - ASSERT_NE(analysis, std::nullopt); - EXPECT_EQ(analysis->GetEmitterFusionKind(), + EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInBoth) { - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -222,24 +218,90 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInBoth) { %p1 = f32[] parameter(1) %fusion.1 = f32[] fusion(%p0, %p1), kind=kInput, calls=fusion.1 ROOT %fusion.2 = f32[] fusion(%fusion.1), kind=kInput, calls=fusion.2 - })") - .value(); + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + EXPECT_EQ(analysis.GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kReduction); +} + +TEST_F(HloFusionAnalysisTest, ReduceMultiOutputFusionWithTransposeBitcast) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %p0 = f32[1024, 512]{1,0} parameter(0) + %p1 = f32[] parameter(1) + %reduce = f32[1024]{0} reduce(%p0, %p1), dimensions={1}, to_apply=add + %bitcast = f32[512, 1024]{0,1} bitcast(%p0) + ROOT res = (f32[1024]{0}, f32[512, 1024]{0,1}) tuple(%reduce, %bitcast) + } + + ENTRY main { + %p0 = f32[1024, 512]{1,0} parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = (f32[1024]{0}, f32[512, 1024]{0,1}) fusion(%p0, %p1), kind=kInput, calls=fusion + })")); auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); - ASSERT_NE(analysis, std::nullopt); - EXPECT_EQ(analysis->GetEmitterFusionKind(), + EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } +TEST_F(HloFusionAnalysisTest, InvalidReduceMultiOutputFusion) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + fusion { + %p0 = f32[1024, 1024]{1,0} parameter(0) + %p1 = f32[] parameter(1) + %reduce = f32[1024]{0} reduce(%p0, %p1), dimensions={0}, to_apply=add + %reduce2 = f32[1024]{0} reduce(%p0, %p1), dimensions={1}, to_apply=add + ROOT res = (f32[1024]{0}, f32[1024]{0}) tuple(reduce, reduce2) + } + + ENTRY main { + %p0 = f32[1024, 1024]{1,0} parameter(0) + %p1 = f32[] parameter(1) + ROOT %fusion = (f32[1024]{0}, f32[1024]{0}) fusion(%p0, %p1), kind=kInput, calls=fusion + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = + AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + // We expect to fallback to the loop emitter, because the two reductions are + // not compatible as they reduce over different dimensions. + EXPECT_EQ(analysis.GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kLoop); +} + TEST_F(HloFusionAnalysisTest, InvalidDevice) { // Verifies that an analysis can be created even with an invalid/empty device // info, and that the emitter type is determined correctly. // Don't rely on this behavior. - auto module = ParseAndReturnVerifiedModule(R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule test_module add { @@ -253,8 +315,7 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { %p1 = f32[] parameter(1) %reduce = f32[128] reduce(%p0, %p1), dimensions={0}, to_apply=add ROOT %bitcast = s32[128] bitcast(%reduce) - })") - .value(); + })")); stream_executor::GpuDeviceInfoProto device_info_proto; stream_executor::DeviceDescription device_info(device_info_proto); @@ -262,10 +323,37 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { auto* root = module->entry_computation()->root_instruction(); auto analysis_fused = AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); - ASSERT_NE(analysis_fused, std::nullopt); - EXPECT_EQ(analysis_fused->GetEmitterFusionKind(), + EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } +TEST_F(HloFusionAnalysisTest, ConcatFusion) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule test_module + + fused_computation { + %p0 = f32[128] parameter(0) + %p1 = f32[128] parameter(1) + %add = f32[128] add(p0, p0) + %concat = f32[256] concatenate(%add, %p1), dimensions={0} + ROOT %negate = f32[256] negate(%concat) + } + + ENTRY main { + %p0 = f32[128] parameter(0) + %p1 = f32[128] parameter(1) + ROOT %fusion = f32[256] fusion(%p0, %p1), kind=kInput, calls=fused_computation + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + + auto* root = module->entry_computation()->root_instruction(); + auto analysis = HloFusionAnalysis::Create( + FusionBackendConfig::default_instance(), + HloFusionAdaptor::ForInstruction(root), &device_info); + EXPECT_EQ(analysis.GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kConcatenate); +} + } // namespace } // namespace xla::gpu diff --git a/xla/service/gpu/hlo_fusion_stats.cc b/xla/service/gpu/hlo_fusion_stats.cc index 572428c2e8ed5..5fd9a0da24cfd 100644 --- a/xla/service/gpu/hlo_fusion_stats.cc +++ b/xla/service/gpu/hlo_fusion_stats.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,18 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_stats.h" +#include #include -#include "absl/strings/match.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/status.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -36,7 +38,7 @@ class OpcodeCollector : public ConstDfsHloVisitorWithDefault { std::set GetUniqueOpcodes() { return opcodes_; } protected: - Status DefaultAction(const xla::HloInstruction* instr) final { + absl::Status DefaultAction(const xla::HloInstruction* instr) final { switch (instr->opcode()) { case HloOpcode::kConstant: break; @@ -47,6 +49,7 @@ class OpcodeCollector : public ConstDfsHloVisitorWithDefault { case HloOpcode::kCbrt: case HloOpcode::kCeil: case HloOpcode::kCos: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -73,7 +76,7 @@ class OpcodeCollector : public ConstDfsHloVisitorWithDefault { default: opcodes_.insert(std::string(HloOpcodeString(instr->opcode()))); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -82,7 +85,7 @@ class OpcodeCollector : public ConstDfsHloVisitorWithDefault { std::set GetUniqueOpcodes(HloComputation* computation) { OpcodeCollector collector; - if (computation->Accept(&collector) != OkStatus()) { + if (!computation->Accept(&collector).ok()) { return {}; } return collector.GetUniqueOpcodes(); @@ -99,9 +102,9 @@ std::string HloOpcodeHistogram::ToString() { return result; } -Status HloFusionStatsVisitor::RunOnModule(HloModule* module) { +absl::Status HloFusionStatsVisitor::RunOnModule(HloModule* module) { TF_RETURN_IF_ERROR(module->entry_computation()->Accept(this)); - return OkStatus(); + return absl::OkStatus(); } std::string HloFusionStatsVisitor::ToString() { @@ -113,11 +116,12 @@ std::string HloFusionStatsVisitor::ToString() { input_fusion_opcode_histogram_.ToString()); } -Status HloFusionStatsVisitor::DefaultAction(const xla::HloInstruction* instr) { - return OkStatus(); +absl::Status HloFusionStatsVisitor::DefaultAction( + const xla::HloInstruction* instr) { + return absl::OkStatus(); } -Status HloFusionStatsVisitor::HandleFusion(const HloInstruction* fusion) { +absl::Status HloFusionStatsVisitor::HandleFusion(const HloInstruction* fusion) { num_fusions_++; std::set opcodes = GetUniqueOpcodes(fusion->fused_instructions_computation()); @@ -128,7 +132,7 @@ Status HloFusionStatsVisitor::HandleFusion(const HloInstruction* fusion) { num_input_fusions_++; input_fusion_opcode_histogram_[opcodes]++; } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/hlo_fusion_stats.h b/xla/service/gpu/hlo_fusion_stats.h index 22e9348200d5b..0138ae515e167 100644 --- a/xla/service/gpu/hlo_fusion_stats.h +++ b/xla/service/gpu/hlo_fusion_stats.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,14 @@ limitations under the License. #define XLA_SERVICE_GPU_HLO_FUSION_STATS_H_ #include +#include +#include #include -#include "absl/strings/string_view.h" +#include "absl/status/status.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" // Read-only pass logging statistics about HLO fusion ops in the module. Enabled // at VLOG level 1 only. @@ -37,14 +38,14 @@ class HloOpcodeHistogram : public std::map, int64_t> { class HloFusionStatsVisitor : public ConstDfsHloVisitorWithDefault { public: - Status RunOnModule(HloModule* module); + absl::Status RunOnModule(HloModule* module); std::string ToString(); protected: - Status DefaultAction(const xla::HloInstruction* instr) final; + absl::Status DefaultAction(const xla::HloInstruction* instr) final; - Status HandleFusion(const HloInstruction* fusion) override; + absl::Status HandleFusion(const HloInstruction* fusion) override; private: int64_t num_fusions_ = 0; diff --git a/xla/service/gpu/hlo_fusion_stats_test.cc b/xla/service/gpu/hlo_fusion_stats_test.cc index cc944d8beac55..0a19b213922b4 100644 --- a/xla/service/gpu/hlo_fusion_stats_test.cc +++ b/xla/service/gpu/hlo_fusion_stats_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ limitations under the License. #include +#include +#include "absl/strings/match.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" diff --git a/xla/service/gpu/hlo_to_ir_bindings.cc b/xla/service/gpu/hlo_to_ir_bindings.cc index 475830d7a4868..34c1a84fb81a3 100644 --- a/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/xla/service/gpu/hlo_to_ir_bindings.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,27 @@ limitations under the License. #include "xla/service/gpu/hlo_to_ir_bindings.h" +#include + +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/Instructions.h" +#include "llvm/Support/Casting.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/map_util.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/llvm_ir/tuple_ops.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" #include "tsl/platform/logging.h" namespace xla { @@ -97,20 +107,6 @@ void HloToIrBindings::EmitBasePointersForHlos( } } -llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, - llvm::Value* base_ptr) { - // TODO(b/26344050): tighten the alignment based on the real element type. - if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { - return llvm_ir::EmitGetTupleElement( - gte->shape(), gte->tuple_index(), /*alignment=*/1, base_ptr, - llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_); - } - return llvm_ir::EmitGetTupleElement( - gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), - llvm_ir::ShapeToIrType(gte->operand(0)->shape(), module_), b_); -} - // Returns true if `value` has a name that should not be changed. static bool HasMeaningfulName(llvm::Value* value) { if (auto* global = llvm::dyn_cast(value)) { @@ -149,26 +145,9 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, return ir_array; } -void HloToIrBindings::UnbindAllLocalIrValues() { - std::vector hlos_to_unbind; - for (auto& key_value : base_ptrs_) { - if (!llvm::isa( - (key_value.second.element({}))->stripPointerCasts())) { - hlos_to_unbind.push_back(key_value.first); - } - } - for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) { - VLOG(2) << "Unbinding " << hlo_to_unbind->ToString(); - base_ptrs_.erase(hlo_to_unbind); - } -} - std::string HloToIrBindings::ToString() const { std::string s = StrCat("** HloToIrBindings **\n"); StrAppend(&s, " is_nested_=", is_nested_, "\n"); - StrAppend(&s, - " temp_buffer_base_=", llvm_ir::DumpToString(temp_buffer_base_), - "\n"); if (base_ptrs_.empty()) { return s; diff --git a/xla/service/gpu/hlo_to_ir_bindings.h b/xla/service/gpu/hlo_to_ir_bindings.h index 6ccc40246b7c4..1aa84c1a698fb 100644 --- a/xla/service/gpu/hlo_to_ir_bindings.h +++ b/xla/service/gpu/hlo_to_ir_bindings.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,17 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_ #define XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_ +#include + #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/map_util.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { @@ -43,21 +47,11 @@ class HloToIrBindings { void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, ShapeIndexView shape_index = {}); - // Unbinds all IR values that's defined in an LLVM function, e.g., function - // arguments and stack variables. Global variables will be kept in bindings_. - // - // This method is called after emitting code for each top-level HLO. The local - // IR values are out of scope at that point and should not be used. - void UnbindAllLocalIrValues(); - // Returns whether `hlo` is bound to an LLVM IR value. bool BoundToIrValue(const HloInstruction& hlo) const { return base_ptrs_.contains(&hlo); } - llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } - void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; } - // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, @@ -81,10 +75,6 @@ class HloToIrBindings { std::string ToString() const; private: - // Emits IR to resolve (possibly) recursive GetTupleElement instructions. - llvm::Value* EmitGetTupleElement(const HloInstruction* gte, - llvm::Value* base_ptr); - const bool is_nested_; llvm::IRBuilder<>* b_; @@ -96,9 +86,6 @@ class HloToIrBindings { // in the ShapeTree. absl::flat_hash_map> base_ptrs_; - - // The address of the memory block that contains all temporary buffers. - llvm::Value* temp_buffer_base_ = nullptr; }; } // namespace gpu diff --git a/xla/service/gpu/hlo_traversal.cc b/xla/service/gpu/hlo_traversal.cc index dc0b1dcb57a10..5ba3a1e6c1293 100644 --- a/xla/service/gpu/hlo_traversal.cc +++ b/xla/service/gpu/hlo_traversal.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,13 @@ limitations under the License. #include #include +#include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -37,6 +40,20 @@ void ResolveUsers(const HloInstruction* value, const HloInstruction* user, for (const auto* param_user : param->users()) { fn(param_user); } + } else if (user->opcode() == HloOpcode::kTuple && user->IsRoot()) { + if (auto* fusion = user->parent()->FusionInstruction()) { + // Skip through the tuple -> get-tuple-element ops and directly go to the + // "real" users. + for (const auto* gte : fusion->users()) { + if (gte->opcode() != HloOpcode::kGetTupleElement) { + fn(gte); + continue; + } + for (const auto* gte_user : gte->users()) { + ResolveUsers(gte, gte_user, fn); + } + } + } } else { fn(user); } @@ -46,6 +63,13 @@ const HloInstruction* ResolveOperand(const HloInstruction* operand) { if (operand->opcode() == HloOpcode::kFusion) { return operand->fused_expression_root(); } + // Deal with multi-output fusion operands, which are reached via a + // get-tuple-element op. + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->operand(0)->opcode() == HloOpcode::kFusion) { + return operand->operand(0)->fused_expression_root()->operand( + operand->tuple_index()); + } if (operand->opcode() == HloOpcode::kParameter) { if (auto* fusion = operand->parent()->FusionInstruction()) { return ResolveOperand(fusion->operand(operand->parameter_number())); @@ -70,6 +94,13 @@ class SingleInstructionFusion : public HloFusionAdaptor { return {instruction_}; } + absl::InlinedVector MakeInstructionPostOrder() + const override { + return {instruction_}; + } + + std::string ToString() const override { return instruction_.ToString(); } + private: HloInstructionAdaptor instruction_; }; @@ -78,6 +109,27 @@ class HloComputationFusion : public HloFusionAdaptor { public: explicit HloComputationFusion(const HloComputation* computation) : computation_(computation) { + // HloFusionAdaptor should only be created for fusion computations, that + // usually have only a few roots, but there is a case when we can it for + // non-fusion computations with thousands of roots. It happens inside + // `FindNonTrivialHero` and it gets very expensive. Calling + // `FindNonTrivialHero` also doesn't make sense on non-fusion computation, + // but `InstructionFusion` and `FusionMerger` depend on this behavoiur in + // `IsProducerConsumerFusible`. + // + // `FindNonTrivialHero` only call `ContainsInstruction` and doesn't use + // information about roots, so we can skip looking for roots as performance + // optimization. + // TODO(shyshkov): Clean this up once priority fusion is fully launched. + if (computation->IsFusionComputation()) { + roots_ = FindRoots(computation); + } + } + + static absl::InlinedVector FindRoots( + const HloComputation* computation) { + absl::InlinedVector roots; + std::function get_roots; absl::flat_hash_set roots_set; get_roots = [&](const HloInstruction* instr) { @@ -88,11 +140,13 @@ class HloComputationFusion : public HloFusionAdaptor { } else { HloInstructionAdaptor wrapped{*instr}; if (roots_set.insert(wrapped).second) { - roots_.push_back(wrapped); + roots.push_back(wrapped); } } }; get_roots(computation->root_instruction()); + + return roots; } bool ContainsInstruction(HloInstructionAdaptor instruction) const override { @@ -100,9 +154,38 @@ class HloComputationFusion : public HloFusionAdaptor { } absl::InlinedVector GetRoots() const override { + CHECK(!roots_.empty()) + << "No roots found in the computation. HloFusionAdaptor was likely " + "created for a non-fusion computation: " + << computation_->ToString(); + return roots_; } + absl::InlinedVector MakeInstructionPostOrder() + const override { + auto post_order = computation_->MakeInstructionPostOrder(); + + absl::InlinedVector result; + result.reserve(post_order.size() - computation_->num_parameters()); + + for (auto* instr : post_order) { + // Skip parameter and root tuple as FusionAdaptor hides their existence. + // HloInstructionAdaptor will look through them and return operands + // outside of the computation if necessary. We don't expect to see any + // internal tuples, but the other logic only handles root tuples + // explicitly. + if (instr->opcode() == HloOpcode::kParameter || + (instr->opcode() == HloOpcode::kTuple && instr->IsRoot())) { + continue; + } + result.emplace_back(*instr); + } + return result; + } + + std::string ToString() const override { return computation_->ToString(); } + private: const HloComputation* computation_; absl::InlinedVector roots_; @@ -175,69 +258,101 @@ bool operator==(const HloInstructionAdaptor& lhs, lhs.instruction_->unique_id() == rhs.instruction_->unique_id(); } -void HloBfsConsumersFirstTraversal( +namespace { +void HloBfsTraversal( absl::Span roots, const HloFusionAdaptor& fusion, - const std::function& visit, - const std::function& visit_arg) { + const std::function& + visit_node, + const std::function& visit_arg, + bool visit_operands) { absl::flat_hash_set visited; std::queue q; - auto enqueue_operands = [&](const HloInstructionAdaptor& node) { - for (auto operand : node.GetOperands()) { - if (visited.insert(operand).second) { - if (fusion.ContainsInstruction(operand)) { - q.push(operand); + auto enqueue = [&](const HloInstructionAdaptor& node) { + const auto& adjacent_nodes = + visit_operands ? node.GetOperands() : node.GetUsers(); + for (const auto& node : adjacent_nodes) { + if (visited.insert(node).second) { + if (fusion.ContainsInstruction(node)) { + q.push(node); } else { - visit_arg(operand); + visit_arg(node); } } } }; for (auto root : roots) { - q.push(root); + if (visited.insert(root).second) { + q.push(root); + } } while (!q.empty()) { HloInstructionAdaptor node = q.front(); q.pop(); - switch (visit(node)) { - case TraversalResult::kVisitOperands: - enqueue_operands(node); + switch (visit_node(node)) { + case TraversalResult::kAdvance: + enqueue(node); break; - case TraversalResult::kAbortTraversal: + case TraversalResult::kInterrupt: return; - case TraversalResult::kDoNotVisitOperands: + case TraversalResult::kSkip: break; } } } +} // namespace + +void HloBfsConsumersFirstTraversal( + absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& + visit_node, + const std::function& visit_arg) { + HloBfsTraversal(roots, fusion, visit_node, visit_arg, + /*visit_operands=*/true); +} + +void HloBfsProducersFirstTraversal( + absl::Span producers, + const HloFusionAdaptor& fusion, + const std::function& + visit_node) { + HloBfsTraversal( + producers, fusion, visit_node, [](HloInstructionAdaptor) {}, + /*visit_operands=*/false); +} void FindFusionArguments( const HloFusionAdaptor& fusion, const std::function& visit) { HloBfsConsumersFirstTraversal( fusion.GetRoots(), fusion, - [&](HloInstructionAdaptor) { return TraversalResult::kVisitOperands; }, - visit); + [&](HloInstructionAdaptor) { return TraversalResult::kAdvance; }, visit); } bool HloAnyOf(absl::Span roots, const HloFusionAdaptor& fusion, - const std::function& visit) { - return HloFindIf(roots, fusion, visit).has_value(); + const std::function& visit, + bool visit_operands) { + return HloFindIf(roots, fusion, visit, visit_operands).has_value(); } std::optional HloFindIf( absl::Span roots, const HloFusionAdaptor& fusion, - const std::function& visit) { + const std::function& visit, + bool visit_operands) { std::optional result = std::nullopt; - HloBfsConsumersFirstTraversal(roots, fusion, [&](HloInstructionAdaptor node) { - if (visit(node)) { - result = node; - return TraversalResult::kAbortTraversal; - } - return TraversalResult::kVisitOperands; - }); + HloBfsTraversal( + roots, fusion, + [&](HloInstructionAdaptor node) { + if (visit(node)) { + result = node; + return TraversalResult::kInterrupt; + } + return TraversalResult::kAdvance; + }, + [](HloInstructionAdaptor) {}, visit_operands); return result; } diff --git a/xla/service/gpu/hlo_traversal.h b/xla/service/gpu/hlo_traversal.h index 9a4f29d621f92..d77e669b48f16 100644 --- a/xla/service/gpu/hlo_traversal.h +++ b/xla/service/gpu/hlo_traversal.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,20 @@ limitations under the License. #define XLA_SERVICE_GPU_HLO_TRAVERSAL_H_ #include +#include +#include +#include +#include +#include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" namespace xla { namespace gpu { @@ -59,11 +68,19 @@ H AbslHashValue(H h, const HloInstructionAdaptor& m) { m.instruction_->unique_id()); } +template +bool IsOpcodeAnyOf(const HloInstructionAdaptor& adaptor) { + return (adaptor.opcode() == op) || ((adaptor.opcode() == rest) || ...); +} + class HloFusionAdaptor { public: virtual ~HloFusionAdaptor() = default; virtual bool ContainsInstruction(HloInstructionAdaptor instruction) const = 0; virtual absl::InlinedVector GetRoots() const = 0; + virtual absl::InlinedVector + MakeInstructionPostOrder() const = 0; + virtual std::string ToString() const = 0; static std::unique_ptr ForInstruction( const HloInstruction* instruction); @@ -77,6 +94,11 @@ class ProducerConsumerFusion : public HloFusionAdaptor { std::unique_ptr consumer) : producer_(std::move(producer)), consumer_(std::move(consumer)) {} + ProducerConsumerFusion(const HloInstruction* producer, + const HloInstruction* consumer) + : ProducerConsumerFusion(HloFusionAdaptor::ForInstruction(producer), + HloFusionAdaptor::ForInstruction(consumer)) {} + bool ContainsInstruction(HloInstructionAdaptor instruction) const override { return producer_->ContainsInstruction(instruction) || consumer_->ContainsInstruction(instruction); @@ -86,20 +108,41 @@ class ProducerConsumerFusion : public HloFusionAdaptor { return consumer_->GetRoots(); } + absl::InlinedVector MakeInstructionPostOrder() + const override { + auto producer_post_order = producer_->MakeInstructionPostOrder(); + auto consumer_post_order = consumer_->MakeInstructionPostOrder(); + + producer_post_order.reserve(consumer_post_order.size() + + producer_post_order.size()); + + absl::c_move(consumer_post_order, std::back_inserter(producer_post_order)); + + return producer_post_order; + } + + std::string ToString() const override { + // TODO: Add a parameter to indent output on nested adaptor for better + // visual representation. Nested producer-consumers fusion are not used in + // practice yet. + return absl::StrJoin({std::string("producer-consumer fusion:"), + producer_->ToString(), consumer_->ToString()}, + "\n"); + } + private: std::unique_ptr producer_; std::unique_ptr consumer_; }; enum class TraversalResult { - // Visit the operands of this node. - kVisitOperands, + // Visit the operands/users of this node. + kAdvance, // Do not visit any more nodes. - kAbortTraversal, - // Do not visit the operands of this node (but continue the traversal - // otherwise). If the node visitation function returns this, the `boundary` - // condition will not be evaluated. - kDoNotVisitOperands, + kInterrupt, + // Do not visit the operands/users of this node (but continue the traversal + // otherwise). + kSkip, }; // Visit the HLO nodes starting from `roots` in BFS order (consumers before @@ -112,20 +155,32 @@ void HloBfsConsumersFirstTraversal( const std::function& visit_arg = [](HloInstructionAdaptor) {}); +// Visit the HLO nodes starting from `producers` in BFS order following the +// `user` edges. Each node will be visited exactly once. +void HloBfsProducersFirstTraversal( + absl::Span producers, + const HloFusionAdaptor& fusion, + const std::function& + visit_node); + // Visit the HLO nodes starting from `roots`, returning true if the return value // of `visit` for any of nodes is true. Uses the same order as -// `HloBfsConsumersFirstTraversal`. +// `HloBfsConsumersFirstTraversal` if `visit_operands` is true. Otherwise the +// same order as `HloBfsProducersFirstTraversal` is used. bool HloAnyOf(absl::Span roots, const HloFusionAdaptor& fusion, - const std::function& visit); + const std::function& visit, + bool visit_operands = true); // Visit the HLO nodes stating from `roots`, returning the first // node for which `visit` returns true, or `nullptr` if no node matches. Uses -// the same order as `HloBfsConsumersFirstTraversal`. +// the same order as `HloBfsConsumersFirstTraversal` if `visit_operands` is +// true. Otherwise the same order as `HloBfsProducersFirstTraversal` is used. std::optional HloFindIf( absl::Span roots, const HloFusionAdaptor& fusion, - const std::function& visit); + const std::function& visit, + bool visit_operands = true); // Visit the producers of all parameters that are needed by the fusion. void FindFusionArguments( diff --git a/xla/service/gpu/hlo_traversal_test.cc b/xla/service/gpu/hlo_traversal_test.cc index d6f18c17a30f5..c7e3f0db3b7b4 100644 --- a/xla/service/gpu/hlo_traversal_test.cc +++ b/xla/service/gpu/hlo_traversal_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/hlo_traversal.h" +#include #include +#include #include #include @@ -29,6 +31,8 @@ namespace { using ::testing::ElementsAre; +MATCHER_P(InstructionAdaptorName, name, "") { return arg.name() == name; } + class HloTraversalTest : public HloTestBase {}; const char kTestModule[] = R"( @@ -47,6 +51,14 @@ const char kTestModule[] = R"( ROOT reduce.1 = f32[] reduce(mul, p0.1), dimensions={0}, to_apply=scalar_add_computation } + fused_computation_1 { + p0.2 = f32[] parameter(0) + zero = f32[] constant(0.0) + is_positive = pred[] compare(p0.2, zero), direction=GE + not = pred[] not(is_positive) + ROOT tuple = (pred[], pred[]) tuple(is_positive, not) + } + ENTRY entry { p0 = f32[] parameter(0) p1 = f32[128] parameter(1) @@ -54,19 +66,21 @@ const char kTestModule[] = R"( log = f32[128] log(sum) negate = f32[128] negate(log) fusion = f32[] fusion(p0, negate), kind=kLoop, calls=fused_computation - ROOT difference = f32[] subtract(fusion, p0) + fusion2 = (pred[], pred[]) fusion(fusion), kind=kLoop, calls=fused_computation_1 + gte = pred[] get-tuple-element(fusion2), index=0 + ROOT select = f32[] select(gte, fusion, p0) })"; TEST_F(HloTraversalTest, AdaptorOperands) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); HloInstructionAdaptor instr{ - *module->entry_computation()->GetInstructionWithName("difference")}; + *module->entry_computation()->GetInstructionWithName("select")}; - auto operands = instr.GetOperands(); - ASSERT_EQ(operands.size(), 2); - EXPECT_EQ(operands[0].name(), "reduce.1"); - EXPECT_EQ(operands[1].name(), "p0"); + EXPECT_THAT(instr.GetOperands(), + ElementsAre(InstructionAdaptorName("is_positive"), + InstructionAdaptorName("reduce.1"), + InstructionAdaptorName("p0"))); } TEST_F(HloTraversalTest, AdaptorUsers) { @@ -80,27 +94,35 @@ TEST_F(HloTraversalTest, AdaptorUsers) { ROOT t = (f32[], f32[]) tuple(neg, add) } + fused_computation_1 { + p0.0 = f32[] parameter(0) + mul = f32[] multiply(p0.0, p0.0) + ROOT neg.1 = f32[] negate(mul) + } + ENTRY entry { p0 = f32[] parameter(0) fusion = (f32[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation - ROOT gte = f32[] get-tuple-element(fusion), index=0 + gte = f32[] get-tuple-element(fusion), index=0 + add.1 = f32[] add(p0, gte) + fusion2 = f32[] fusion(gte), kind=kLoop, calls=fused_computation_1 + ROOT res = (f32[], (f32[], f32[]), f32[]) tuple(add.1, fusion, fusion2) } )") .value(); - auto get_single_user = [](auto instr) { - auto users = instr.GetUsers(); - EXPECT_EQ(users.size(), 1); - return users[0]; - }; - HloInstructionAdaptor add{*module->GetComputationWithName("fused_computation") ->GetInstructionWithName("add")}; - EXPECT_EQ(get_single_user(add).name(), "t"); - EXPECT_EQ(get_single_user(get_single_user(add)).name(), "gte"); + EXPECT_THAT(add.GetUsers(), ElementsAre(InstructionAdaptorName("add.1"), + InstructionAdaptorName("mul"), + InstructionAdaptorName("res"))); + HloInstructionAdaptor mul{ + *module->GetComputationWithName("fused_computation_1") + ->GetInstructionWithName("mul")}; + EXPECT_THAT(mul.GetUsers(), ElementsAre(InstructionAdaptorName("neg.1"))); } -TEST_F(HloTraversalTest, TraverseFusion) { +TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; std::vector visited_args; @@ -110,7 +132,29 @@ TEST_F(HloTraversalTest, TraverseFusion) { fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { visited_nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; + }, + [&](HloInstructionAdaptor arg) { + visited_args.emplace_back(arg.name()); + }); + + EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); + EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); +} + +TEST_F(HloTraversalTest, + TraverseFusionConsumerFirstFromFusionRootAndInnerNode) { + auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + std::vector visited_nodes; + std::vector visited_args; + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); + auto root = fusion->GetRoots()[0]; + HloBfsConsumersFirstTraversal( + {root, root.GetOperand(0)}, *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; }, [&](HloInstructionAdaptor arg) { visited_args.emplace_back(arg.name()); @@ -120,6 +164,21 @@ TEST_F(HloTraversalTest, TraverseFusion) { EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } +TEST_F(HloTraversalTest, TraverseFusionProducerFirst) { + auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + std::vector visited_nodes; + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); + auto root = fusion->GetRoots()[0]; + HloBfsProducersFirstTraversal({root.GetOperand(0)}, *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); + + EXPECT_THAT(visited_nodes, ElementsAre("mul", "reduce.1")); +} + TEST_F(HloTraversalTest, AbortTraversal) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); auto fusion = HloFusionAdaptor::ForInstruction( @@ -129,8 +188,8 @@ TEST_F(HloTraversalTest, AbortTraversal) { [&](HloInstructionAdaptor node) { visited_nodes.emplace_back(node.name()); return node.opcode() == HloOpcode::kReduce - ? TraversalResult::kVisitOperands - : TraversalResult::kAbortTraversal; + ? TraversalResult::kAdvance + : TraversalResult::kInterrupt; }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); @@ -167,7 +226,6 @@ TEST_F(HloTraversalTest, FindIf) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - std::vector visited_nodes; auto result = HloFindIf(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { return node.opcode() == HloOpcode::kMultiply; @@ -180,7 +238,6 @@ TEST_F(HloTraversalTest, NotFound) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - std::vector visited_nodes; auto result = HloFindIf(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { return false; }); ASSERT_EQ(result, std::nullopt); @@ -233,7 +290,7 @@ TEST_F(HloTraversalTest, FuseFusionConsumer) { fusion.GetRoots(), fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }, [&](HloInstructionAdaptor param) { params.emplace_back(param.name()); }); @@ -256,7 +313,7 @@ TEST_F(HloTraversalTest, FuseFusionProducer) { fusion.GetRoots(), fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }, [&](HloInstructionAdaptor arg) { params.emplace_back(arg.name()); }); @@ -276,7 +333,7 @@ TEST_F(HloTraversalTest, FuseFusionConsumerAndProducer) { HloBfsConsumersFirstTraversal(fusion.GetRoots(), fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }); std::vector params; FindFusionArguments(fusion, [&](const HloInstructionAdaptor& param) { @@ -300,7 +357,7 @@ TEST_F(HloTraversalTest, FuseNonFusionConsumerAndProducer) { HloBfsConsumersFirstTraversal(fusion.GetRoots(), fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }); EXPECT_THAT(nodes, ElementsAre("negate", "log")); @@ -315,7 +372,7 @@ TEST_F(HloTraversalTest, SingleInstructionFusionOfFusion) { HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }); EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul")); @@ -330,12 +387,85 @@ TEST_F(HloTraversalTest, SingleInstructionFusionOfInstruction) { HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { nodes.emplace_back(node.name()); - return TraversalResult::kVisitOperands; + return TraversalResult::kAdvance; }); EXPECT_THAT(nodes, ElementsAre("negate")); } +TEST_F(HloTraversalTest, MakeInstructionsPostOrder_SingleInstruction) { + auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("negate")); + + auto nodes = fusion->MakeInstructionPostOrder(); + EXPECT_THAT(nodes, ElementsAre(InstructionAdaptorName("negate"))); +} + +TEST_F(HloTraversalTest, MakeInstructionsPostOrder_TwoFusions) { + auto module = ParseAndReturnVerifiedModule(kTwoFusions).value(); + auto fusion = ProducerConsumerFusion( + module->entry_computation()->GetInstructionWithName("fusion.1"), + module->entry_computation()->GetInstructionWithName("fusion.2")); + + auto nodes = fusion.MakeInstructionPostOrder(); + EXPECT_THAT(nodes, ElementsAre(InstructionAdaptorName("mul"), + InstructionAdaptorName("reduce.1"), + InstructionAdaptorName("reduce.2"))); +} + +TEST_F(HloTraversalTest, MakeInstructionsPostOrder_TwoMultiOutputFusions) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + + fused_computation_1 { + p0.1 = f32[] parameter(0) + p1.1 = f32[128] parameter(1) + mul = f32[128] multiply(p1.1, p1.1) + reduce.1 = f32[] reduce(mul, p0.1), dimensions={0}, to_apply=scalar_add_computation + ROOT t = (f32[128], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p0.2 = f32[] parameter(0) + p1.2 = f32[128] parameter(1) + neg = f32[128] negate(p1.2) + reduce.2 = f32[] reduce(neg, p0.2), dimensions={0}, to_apply=scalar_add_computation + ROOT t2 = (f32[], f32[128]) tuple(reduce.2, neg) + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128] parameter(1) + sum = f32[128] add(p1, p1) + negate = f32[128] negate(sum) + fusion.1 = (f32[128], f32[]) fusion(p0, negate), kind=kLoop, calls=fused_computation_1 + gte1 = f32[128] get-tuple-element(fusion.1), index=0 + gte2 = f32[] get-tuple-element(fusion.1), index=1 + fusion.2 = (f32[], f32[128]) fusion(p0, gte1), kind=kLoop, calls=fused_computation_2 + gte3 = f32[] get-tuple-element(fusion.2), index=0 + gte4 = f32[128] get-tuple-element(fusion.2), index=1 + difference = f32[] subtract(gte3, p0) + ROOT res = (f32[], f32[128]) tuple(difference, gte4) + })") + .value(); + auto fusion = ProducerConsumerFusion( + module->entry_computation()->GetInstructionWithName("fusion.1"), + module->entry_computation()->GetInstructionWithName("fusion.2")); + + auto nodes = fusion.MakeInstructionPostOrder(); + EXPECT_THAT(nodes, ElementsAre(InstructionAdaptorName("mul"), + InstructionAdaptorName("reduce.1"), + InstructionAdaptorName("neg"), + InstructionAdaptorName("reduce.2"))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/horizontal_input_fusion.cc b/xla/service/gpu/horizontal_input_fusion.cc index 0e01d4964149d..c693856968661 100644 --- a/xla/service/gpu/horizontal_input_fusion.cc +++ b/xla/service/gpu/horizontal_input_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,25 @@ limitations under the License. #include "xla/service/gpu/horizontal_input_fusion.h" #include +#include +#include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -48,7 +62,7 @@ class HorizontalInputFusionImpl { ~HorizontalInputFusionImpl() = default; - StatusOr Run(); + absl::StatusOr Run(); private: HloComputation* computation_; @@ -106,7 +120,7 @@ std::vector FindAndSortFusionCandidates( return fusion_instrs; } -StatusOr HorizontalInputFusionImpl::Run() { +absl::StatusOr HorizontalInputFusionImpl::Run() { bool changed = false; XLA_VLOG_LINES(3, computation_->ToString()); @@ -155,13 +169,13 @@ StatusOr HorizontalInputFusionImpl::Run() { } // namespace -StatusOr GpuHorizontalInputFusion::RunOnComputation( +absl::StatusOr GpuHorizontalInputFusion::RunOnComputation( HloComputation* computation) { HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_); return horizontal_fusion_impl.Run(); } -StatusOr GpuHorizontalInputFusion::Run( +absl::StatusOr GpuHorizontalInputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/horizontal_input_fusion.h b/xla/service/gpu/horizontal_input_fusion.h index 0dc45c90adb5f..370ce7bd0509a 100644 --- a/xla/service/gpu/horizontal_input_fusion.h +++ b/xla/service/gpu/horizontal_input_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ #define XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -46,12 +49,12 @@ class GpuHorizontalInputFusion : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnComputation(HloComputation*); + absl::StatusOr RunOnComputation(HloComputation*); const se::DeviceDescription& device_info_; }; diff --git a/xla/service/gpu/horizontal_input_fusion_test.cc b/xla/service/gpu/horizontal_input_fusion_test.cc index 2337f192c4bfc..2839e2ff40c0e 100644 --- a/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/xla/service/gpu/horizontal_input_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,21 @@ limitations under the License. #include "xla/service/gpu/horizontal_input_fusion.h" +#include +#include +#include + +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/test.h" namespace xla { @@ -30,8 +40,9 @@ namespace m = ::xla::match; class HorizontalInputFusionTest : public GpuCodegenTest { public: - GpuHorizontalInputFusion horizontal_input_fusion_{ + se::DeviceDescription device_description_{ TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuHorizontalInputFusion horizontal_input_fusion_{device_description_}; }; TEST_F(HorizontalInputFusionTest, BasicTest) { diff --git a/xla/service/gpu/horizontal_loop_fusion.cc b/xla/service/gpu/horizontal_loop_fusion.cc index 8f8e8eee454a0..e6074d43254d3 100644 --- a/xla/service/gpu/horizontal_loop_fusion.cc +++ b/xla/service/gpu/horizontal_loop_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,37 @@ limitations under the License. #include "xla/service/gpu/horizontal_loop_fusion.h" #include +#include +#include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/service/sub_byte_normalization.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -59,12 +76,12 @@ class HorizontalLoopFusionImpl { ~HorizontalLoopFusionImpl() = default; - StatusOr Run(); + absl::StatusOr Run(); private: - Status Fuse(absl::Span fused_fusion_instrs, - bool sliced_input_fusion, - std::vector& to_fuse_candidates); + absl::Status Fuse(absl::Span fused_fusion_instrs, + bool sliced_input_fusion, + std::vector& to_fuse_candidates); // If `sliced_input_fusion` is true, Horizontally fuses `fused_fusion_instrs` // into kInput computation, else fuses `fused_fusion_instrs` into kLoop @@ -78,7 +95,7 @@ class HorizontalLoopFusionImpl { // // Returns the fused computation in `uniq_computation` and the operands that // are used by `uniq_computation`. - Status CreateFusedComputation( + absl::Status CreateFusedComputation( absl::Span fused_fusion_instrs, std::unique_ptr* uniq_computation, std::vector* bound_operands, bool sliced_input_fusion); @@ -89,7 +106,7 @@ class HorizontalLoopFusionImpl { // stack that we want to try horizontally fuse its operands, when we create a // new fusion instruction, we push it to the stack in hope to further fuse its // operands. - StatusOr FuseConsumerOperands( + absl::StatusOr FuseConsumerOperands( HloInstruction* consumer, bool sliced_input_fusion, std::vector& to_fuse_candidates); @@ -341,76 +358,59 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { return 32; } else { if (fusible_instrs_[pos_]->opcode() == HloOpcode::kFusion) { - auto fused_instruction_count = - fusible_instrs_[pos_]->fused_instruction_count(); - if (fused_instruction_count < 8) { - return 32; - } else if (fused_instruction_count < 16) { - return 16; - } else if (fused_instruction_count < 32) { - return 8; - } else if (fused_instruction_count < 64) { - return 4; - } else { - return 2; - } + return 32; } else { return 64; } } }(); - // CUDA has a parameter size limit of ~4k bytes. - constexpr int64_t kMaxCudaParamSize = 4000; - size_t accum_io_size = 0; - auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool { - if (right - left >= kMaxFusionBatchSize) { - return true; - } - - accum_io_size += fusible_instrs_.at(right)->operand_count() + - GetOutputSizeOfFusible(*fusible_instrs_.at(right)); - - if (accum_io_size * 8 >= kMaxCudaParamSize) { - return true; - } - - return false; - }; - size_t left = pos_; size_t right = pos_ + 1; size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]); PrimitiveType first_output_type = GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]); + // CUDA has a parameter size limit of ~4k bytes. + constexpr int64_t kMaxCudaParamSize = 4000; + size_t accum_io_size = 0; + size_t accum_num_outputs = 0; for (; right < fusible_instrs_.size(); ++right) { PrimitiveType cur_output_type = GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]); if (first_output_type != cur_output_type) { // Cannot fuse computations who have multiple output types. break; - } else if (first_output_size != - GetOutputSizeOfFusible(*fusible_instrs_[right])) { + } + if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) { // Cannot fuse computations who have different numbers of outputs. break; - } else if (GetInstrCountOfFusible(*fusible_instrs_[left]) != - GetInstrCountOfFusible(*fusible_instrs_[right])) { + } + if (GetInstrCountOfFusible(*fusible_instrs_[left]) != + GetInstrCountOfFusible(*fusible_instrs_[right])) { // Do not fuse computations of different instruction counts as it may // introduce control divergence. This is a very simple heuristic to avoid // fusing computations with too much discrepancy and we may improve it // when the needs arise. break; - } else if (!sliced_input_fusion_ && - !ShapeUtil::EqualIgnoringElementType( - GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(), - GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) { + } + if (!sliced_input_fusion_ && + !ShapeUtil::EqualIgnoringElementType( + GetOutputsOfFusible(*fusible_instrs_[left])[0]->shape(), + GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) { // This is for fusing into kLoop type kernel, so we requires that each // fusion operand have the same shape break; - } else if (reach_max_fusion_batch_size(left, right)) { + } + size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]); + accum_num_outputs += num_outputs; + if (accum_num_outputs >= kMaxFusionBatchSize) { // Hit max fusion batch size. break; } + accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs; + if (accum_io_size * 8 >= kMaxCudaParamSize) { + break; + } } VLOG(2) << "horizontal fuse get instruction span with " << (right - left) << " instructions for sliced_input_fusion=" << sliced_input_fusion_ @@ -419,7 +419,7 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { return absl::MakeSpan(fusible_instrs_).subspan(left, right - left); } -StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( +absl::StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( HloInstruction* consumer, bool sliced_input_fusion, std::vector& to_fuse_candidates) { bool changed = false; @@ -454,7 +454,7 @@ StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( return changed; } -Status HorizontalLoopFusionImpl::CreateFusedComputation( +absl::Status HorizontalLoopFusionImpl::CreateFusedComputation( absl::Span fused_fusion_instrs, std::unique_ptr* uniq_computation, std::vector* bound_operands, bool sliced_input_fusion) { @@ -600,10 +600,10 @@ Status HorizontalLoopFusionImpl::CreateFusedComputation( TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root)); } - return OkStatus(); + return absl::OkStatus(); } -Status HorizontalLoopFusionImpl::Fuse( +absl::Status HorizontalLoopFusionImpl::Fuse( absl::Span fused_fusion_instrs, bool sliced_input_fusion, std::vector& to_fuse_candidates) { // Fuse fused_fusion_instrs and replace them with the new fused computation. @@ -670,10 +670,10 @@ Status HorizontalLoopFusionImpl::Fuse( VLOG(1) << "Fused " << fused_fusion_instrs.size() << " instructions into: " << hori_fusion_instr->ToString(); - return OkStatus(); + return absl::OkStatus(); } -StatusOr HorizontalLoopFusionImpl::Run() { +absl::StatusOr HorizontalLoopFusionImpl::Run() { bool changed = false; XLA_VLOG_LINES(3, computation_->ToString()); @@ -713,13 +713,13 @@ StatusOr HorizontalLoopFusionImpl::Run() { } // namespace -StatusOr GpuHorizontalLoopFusion::RunOnComputation( +absl::StatusOr GpuHorizontalLoopFusion::RunOnComputation( HloComputation* computation) { HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_); return horizontal_fusion_impl.Run(); } -StatusOr GpuHorizontalLoopFusion::Run( +absl::StatusOr GpuHorizontalLoopFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Run horizontal fusion."; @@ -728,6 +728,15 @@ StatusOr GpuHorizontalLoopFusion::Run( TF_ASSIGN_OR_RETURN(bool changed, RunOnComputation(module->entry_computation())); + if (changed) { + // Correctly set element_size_in_bits for any int4 added slice and + // concatenate instructions + TF_ASSIGN_OR_RETURN( + [[maybe_unused]] bool unused, + SubByteNormalization{SubByteNormalization::SET_ELEMENT_SIZE}.Run( + module)); + } + return changed; } diff --git a/xla/service/gpu/horizontal_loop_fusion.h b/xla/service/gpu/horizontal_loop_fusion.h index 4d81ba65412f4..5daed0378aa90 100644 --- a/xla/service/gpu/horizontal_loop_fusion.h +++ b/xla/service/gpu/horizontal_loop_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ #define XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -63,15 +68,16 @@ namespace gpu { // fused operations have the same shape or not. // // case 1: if Mul and Add's output shape and type are the same, then we fuse -// them into the below pattern: i0 i1 i2 i3 +// them into the below pattern: +// i0 i1 i2 i3 // | | | | // v v v v // Mul Add // | | // v v // (ROOT) tuple -// the fused kernel will be kLoop type, i.e, GPU code is emitted through -// IrEmitterUnnested::EmitLoopFusion +// the fused kernel will be kLoop type, and GPU code is emitted through +// the LoopFusion class. // // case 2: if Mul and Add's output shape are diffent, then we fuse them into // the below pattern that adds extra indexing: @@ -96,7 +102,7 @@ namespace gpu { // (ROOT) tuple // // the fused kernel will be kInput type, and, the GPU code is emitted through -// IrEmitterUnnested::EmitInputFusibleNonStridedSlices +// the InputSlicesFusion class. // // In theory, the pattern in case 1 could also be fused into the case2 target // graph, but we prefer to fuse into kLoop type, because the codegen for it does @@ -127,12 +133,12 @@ class GpuHorizontalLoopFusion : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnComputation(HloComputation*); + absl::StatusOr RunOnComputation(HloComputation*); std::string prefix_; }; diff --git a/xla/service/gpu/horizontal_loop_fusion_test.cc b/xla/service/gpu/horizontal_loop_fusion_test.cc index 3087958396949..935c21c6e23fe 100644 --- a/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/xla/service/gpu/horizontal_loop_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,17 @@ limitations under the License. #include "xla/service/gpu/horizontal_loop_fusion.h" +#include +#include +#include #include -#include "xla/literal.h" +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/instruction_fusion.h" #include "xla/service/hlo_dce.h" @@ -26,10 +34,10 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" diff --git a/xla/service/gpu/infeed_manager.cc b/xla/service/gpu/infeed_manager.cc index 29e8dbc8c1937..945e686aaf572 100644 --- a/xla/service/gpu/infeed_manager.cc +++ b/xla/service/gpu/infeed_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,24 @@ limitations under the License. #include "xla/service/gpu/infeed_manager.h" +#include +#include #include - +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/literal.h" +#include "xla/service/gpu/xfeed_queue.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/xla_executor_state.h" @@ -31,11 +46,9 @@ constexpr int kMaxInfeedsInFlight = 8; InfeedManager::InfeedManager(se::StreamExecutor* executor) : BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight), - stream_(std::make_unique(executor)) { - stream_->Init(); -} + stream_(executor->CreateStream().value()) {} -static StatusOr> CopyBufferToDevice( +static absl::StatusOr> CopyBufferToDevice( se::Stream* stream, int64_t size, const void* source) { if (size > std::numeric_limits::max()) { return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes", @@ -49,13 +62,13 @@ static StatusOr> CopyBufferToDevice( se::StreamExecutor* executor = stream->parent(); se::ScopedDeviceMemory buffer( executor, executor->AllocateArray(size)); - stream->ThenMemcpy(buffer.ptr(), source, size); + TF_RETURN_IF_ERROR(stream->Memcpy(buffer.ptr(), source, size)); return std::move(buffer); } -Status InfeedManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const LiteralSlice& literal) { +absl::Status InfeedManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& literal_shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(literal_shape); @@ -77,14 +90,14 @@ Status InfeedManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // TODO(b/30467474): Since this stream is shared across different infeed // requests, blocking on the stream might be heavy-handed. Figure out if // finer-grained acknowledgement is possible. - Status block_status = stream()->BlockHostUntilDone(); + absl::Status block_status = stream()->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("Failed to complete data transfer on stream %p: %s", - stream(), block_status.message()); + return Internal("Failed to complete data transfer on stream %p: %s", + stream(), block_status.message()); } EnqueueDestination(std::move(buffer_tree)); - return OkStatus(); + return absl::OkStatus(); } InfeedManager* GetOrCreateInfeedManager(se::StreamExecutor* executor) { diff --git a/xla/service/gpu/infeed_manager.h b/xla/service/gpu/infeed_manager.h index 055e3139affa7..6eff97721d807 100644 --- a/xla/service/gpu/infeed_manager.h +++ b/xla/service/gpu/infeed_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,12 +20,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_INFEED_MANAGER_H_ #define XLA_SERVICE_GPU_INFEED_MANAGER_H_ -#include "absl/base/thread_annotations.h" +#include +#include + +#include "absl/status/status.h" #include "xla/literal.h" #include "xla/service/gpu/xfeed_queue.h" #include "xla/shape_tree.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" namespace xla { namespace gpu { @@ -48,8 +51,8 @@ class InfeedManager public: explicit InfeedManager(se::StreamExecutor* executor); - Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const LiteralSlice& literal); + absl::Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal); private: se::Stream* stream() const { return stream_.get(); } diff --git a/xla/service/gpu/infeed_thunk.h b/xla/service/gpu/infeed_thunk.h deleted file mode 100644 index c2d1cd325e2f1..0000000000000 --- a/xla/service/gpu/infeed_thunk.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_INFEED_THUNK_H_ -#define XLA_SERVICE_GPU_INFEED_THUNK_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// A thunk that infeeds data. Data must be already resident on the -// device. This thunk performs an intra-device copy from that location -// to the buffer allocated for the infeed op. -class InfeedThunk : public Thunk { - public: - // Constructs a InfeedThunk that copies data from the on-device - // infeed queue into the buffers in the given shape tree. - InfeedThunk(ThunkInfo thunk_info, std::vector dest_slices); - - InfeedThunk(const InfeedThunk&) = delete; - InfeedThunk& operator=(const InfeedThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - const std::vector dest_slices_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_INFEED_THUNK_H_ diff --git a/xla/service/gpu/instruction_fusion.cc b/xla/service/gpu/instruction_fusion.cc index 60a8b5b7a1f0e..a507e086e5025 100644 --- a/xla/service/gpu/instruction_fusion.cc +++ b/xla/service/gpu/instruction_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -74,7 +74,7 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). - if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && + if (is_expensive(*producer) && ReusesOperandElements(consumer, operand_index)) { return "the producer is expensive, and the consumer reuses inputs"; } @@ -86,10 +86,7 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( return "fusing the producer would break read coalescing"; } - if (auto fusible = IsProducerConsumerFusible(*producer, *consumer); - !fusible) { - return fusible; - } + RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer)); if (CreatesHeavyComputation(*producer, *consumer)) { return "the fusion would create a heavy computation"; @@ -100,20 +97,14 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64_t operand_index) { - if (auto fusible = ShouldFuseInexpensiveChecks(consumer, operand_index); - !fusible) { - return fusible; - } + RETURN_IF_NOT_FUSIBLE(ShouldFuseInexpensiveChecks(consumer, operand_index)); auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. - if (auto fits_budget = - FusionFitsInBudget(*consumer, *producer, device_info_, - /*is_consumer_producer_fusion=*/true); - !fits_budget) { - return fits_budget; - } + RETURN_IF_NOT_FUSIBLE( + FusionFitsInBudget(*consumer, *producer, device_info_, + /*is_consumer_producer_fusion=*/true)); if (consumer->opcode() != HloOpcode::kFusion) { return {}; diff --git a/xla/service/gpu/instruction_fusion.h b/xla/service/gpu/instruction_fusion.h index bf036cbdd4501..db57690ce9571 100644 --- a/xla/service/gpu/instruction_fusion.h +++ b/xla/service/gpu/instruction_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -31,7 +31,6 @@ limitations under the License. #include "xla/service/fusion_queue.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/instruction_fusion.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -47,9 +46,9 @@ class GpuInstructionFusion : public InstructionFusion { static bool IsExpensive(const HloInstruction& instruction); using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { fusion_node_evaluations_.clear(); return InstructionFusion::Run(module, execution_threads); } diff --git a/xla/service/gpu/instruction_fusion_test.cc b/xla/service/gpu/instruction_fusion_test.cc index 9bad7d92a5859..d0d3d387bb1f9 100644 --- a/xla/service/gpu/instruction_fusion_test.cc +++ b/xla/service/gpu/instruction_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,25 @@ limitations under the License. #include "xla/service/gpu/instruction_fusion.h" +#include #include +#include +#include #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace m = ::xla::match; diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index be7a9438d2f3c..6265c5845d036 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,8 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" -#include #include -#include #include -#include #include #include #include @@ -29,12 +26,14 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/FPEnv.h" #include "llvm/IR/IRBuilder.h" @@ -45,28 +44,22 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" @@ -78,14 +71,13 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -97,6 +89,11 @@ bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 2; } +// Return whether the given shape is rank 1 excluding the batch dimensions. +bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { + return shape.rank() == batch_dimensions_size + 1; +} + Shape GetShapeFromTensorType(mlir::Value value) { constexpr char kDefaultLayoutAttrName[] = "xla_shape"; @@ -126,9 +123,11 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || - output_primitive_type == F16 || output_primitive_type == BF16 || - output_primitive_type == F32 || output_primitive_type == F64 || - output_primitive_type == C64 || output_primitive_type == C128) || + output_primitive_type == F8E4M3FNUZ || + output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 || + output_primitive_type == BF16 || output_primitive_type == F32 || + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128) || (output_primitive_type == S32 && lhs_shape.element_type() == S8 && rhs_shape.element_type() == S8); bool shapes_are_valid = @@ -139,17 +138,37 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); - if (!shapes_are_valid) { + return shapes_are_valid; +} + +bool IsMatrixVectorMultiplication(const HloInstruction& dot) { + if (dot.opcode() != HloOpcode::kDot) { return false; } + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + PrimitiveType output_primitive_type = dot.shape().element_type(); + bool type_is_allowed = + (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + output_primitive_type == F16 || output_primitive_type == BF16 || + output_primitive_type == F32 || output_primitive_type == F64 || + output_primitive_type == C64 || output_primitive_type == C128) || + (output_primitive_type == S32 && lhs_shape.element_type() == S8 && + rhs_shape.element_type() == S8); - return true; + bool shapes_are_valid = + type_is_allowed && + ((IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank1(rhs_shape, dim_numbers.lhs_batch_dimensions_size())) || + (IsRank1(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) && + IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()))) && + IsRank1(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) && + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); + + return shapes_are_valid; } const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky"; @@ -161,85 +180,37 @@ bool IsCustomCallToCusolver(const HloInstruction& hlo) { return hlo.custom_call_target() == kCusolverCholeskyCallTarget; } -bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, - bool verify_no_strides) { - auto fusion = mlir::dyn_cast(unnested_hlo); - if (!fusion) { - return false; - } - - auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool { - return absl::c_all_of( - strides, [](const llvm::APInt& stride) { return stride == 1; }); - }; +bool IsCustomCallToTopK(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kTopKCustomCallTarget; +} - for (mlir::Value value : fusion.getFusionResults()) { - auto slice = - mlir::dyn_cast_or_null(value.getDefiningOp()); - if (!slice) { - return false; - } - if (verify_no_strides && !is_non_strided(slice.getStrides())) { - return false; - } - } - return true; +bool IsSliceWithUnitStrides(const HloInstruction* instr) { + auto slice = DynCast(instr); + return slice && absl::c_all_of(slice->slice_strides(), + [](int64_t stride) { return stride == 1; }); } -// This emits a device-side call to -// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see -// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(absl::string_view fmt, - absl::Span arguments, - llvm::IRBuilder<>* builder) { - std::vector argument_types; - - // Variadic arguments implicit promotion [1] converts float to double, - // and bool/char/short are converted to int. - // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments - auto requires_int32_promotion = [](llvm::Type* type) { - return type->isIntegerTy(/*BitWidth=*/1) || - type->isIntegerTy(/*BitWidth=*/8) || - type->isIntegerTy(/*BitWidth=*/16); - }; - auto requires_double_promotion = [](llvm::Type* type) { - return type->isFloatingPointTy(); - }; +bool IsContiguousSlice(const HloInstruction& instr) { + auto slice = DynCast(&instr); + if (!slice) return false; + // No need to check for strides because if stride != 1 there's no way + // src and dst dimensions match. + const Shape& src_shape = slice->operand(0)->shape(); + const Shape& dst_shape = slice->shape(); + return IsContiguousSlice(src_shape, dst_shape); +} - for (auto argument : arguments) { - llvm::Type* type = argument->getType(); - if (requires_double_promotion(type)) { - argument_types.push_back(builder->getDoubleTy()); - } else if (requires_int32_promotion(type)) { - argument_types.push_back(builder->getInt32Ty()); - } else { - argument_types.push_back(type); - } - } - auto* arguments_type = llvm::StructType::create(argument_types); - llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type); - for (size_t i = 0; i < arguments.size(); ++i) { - llvm::Value* value = arguments[i]; - llvm::Type* type = value->getType(); - if (requires_double_promotion(type)) { - value = builder->CreateFPCast(value, builder->getDoubleTy()); - } else if (requires_int32_promotion(type)) { - value = builder->CreateIntCast(value, builder->getInt32Ty(), - /*isSigned=*/true); +bool IsContiguousSlice(const Shape& orig, const Shape& sliced) { + bool sliced_dim_found = false; + for (auto dim : orig.layout().minor_to_major()) { + if (!sliced_dim_found) { + sliced_dim_found = sliced.dimensions(dim) < orig.dimensions(dim); + continue; } - builder->CreateStore( - value, - builder->CreateGEP(arguments_type, arguments_ptr, - {builder->getInt64(0), builder->getInt32(i)})); + if (sliced.dimensions(dim) != 1) return false; } - llvm::Type* ptr_ty = builder->getPtrTy(); - return builder->CreateCall( - builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( - "vprintf", - llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty}, - /*isVarArg=*/false)), - {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)), - builder->CreatePointerCast(arguments_ptr, ptr_ty)}); + return true; } // Helper function to emit call to AMDGPU shfl_down function. @@ -276,6 +247,29 @@ llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)}); } +// Helper function to emit call to SPIR shfl_down intrinsic. +llvm::Value* EmitSPIRShflDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* b) { + CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32); + if (value->getType()->isFloatTy()) { + return EmitDeviceFunctionCall( + "_Z34__spirv_GroupNonUniformShuffleDownffj", + {b->getInt32(3), value, offset}, {U32, F32, U32}, F32, + llvm::AttrBuilder(b->getContext()) + .addAttribute(llvm::Attribute::NoUnwind) + .addAttribute(llvm::Attribute::Convergent), + b); + } else { + return EmitDeviceFunctionCall( + "_Z34__spirv_GroupNonUniformShuffleDownjjj", + {b->getInt32(3), value, offset}, {U32, U32, U32}, U32, + llvm::AttrBuilder(b->getContext()) + .addAttribute(llvm::Attribute::NoUnwind) + .addAttribute(llvm::Attribute::Convergent), + b); + } +} + llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder) { int bit_width = value->getType()->getPrimitiveSizeInBits(); @@ -288,6 +282,8 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, return EmitNVPTXShflDown(value, offset, builder); } else if (target_triple.getArch() == llvm::Triple::amdgcn) { return EmitAMDGPUShflDown(value, offset, builder); + } else if (target_triple.isSPIR()) { + return EmitSPIRShflDown(value, offset, builder); } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } @@ -309,6 +305,9 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, } else if (target_triple.getArch() == llvm::Triple::amdgcn) { insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i), offset, builder); + } else if (target_triple.isSPIR()) { + insert_val = EmitSPIRShflDown(builder->CreateExtractElement(x, i), offset, + builder); } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } @@ -332,65 +331,6 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { return b->CreateAnd(is_thread0, is_block0); } -// Given an LMHLO op, returns the operand index of the first output operand. -// -// Notice that an operand alised to an output isn't an output, even though in -// that case WritesMlirBuffer() returns true on that operand. -// -// An operand is !WritesMlirBuffer() || equals (aliases) to a later operand. An -// output is the opposite, being both WritesMlirBuffer() and does not equal to -// any later operand. -int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) { - CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")); - - int i; - for (i = op->getOperands().size() - 1; i >= 0; i--) { - const bool aliased = - std::find(op->getOperands().begin() + i + 1, op->getOperands().end(), - op->getOperand(i)) != op->getOperands().end(); - if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) { - break; - } - } - return i + 1; -} - -llvm::SmallVector GetHloOperands(mlir::Operation* op) { - if (auto fusion = mlir::dyn_cast(op)) { - return fusion.getInputBuffers(); - } - if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) { - int output_start = PartitionLmhloOperandsAndOutputs(op); - llvm::SmallVector operands; - for (int i = 0; i < output_start; i++) { - operands.push_back(op->getOperand(i)); - } - return operands; - } - if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) { - return op->getOperands(); - } - LOG(FATAL) << "Unexpected op: " << llvm_ir::DumpToString(op); -} - -llvm::SmallVector GetHloOutputs(mlir::Operation* op) { - if (auto fusion = mlir::dyn_cast(op)) { - return fusion.getOutputBuffers(); - } - if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) { - int output_start = PartitionLmhloOperandsAndOutputs(op); - llvm::SmallVector outputs; - for (int i = output_start; i < op->getNumOperands(); i++) { - outputs.push_back(op->getOperand(i)); - } - return outputs; - } - if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) { - return op->getResults(); - } - LOG(FATAL) << "Unexpected op: " << llvm_ir::DumpToString(op); -} - bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) { llvm::SmallVector effects; mlir::cast(op).getEffectsOnValue(operand, @@ -415,85 +355,14 @@ static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) { } } -static int64_t GetAllocationIndex(mlir::BlockArgument func_arg, - std::string* constant_name) { - auto func_op = - mlir::cast(func_arg.getParentRegion()->getParentOp()); - if (constant_name) { - if (auto constant_name_attr = func_op.getArgAttrOfType( - func_arg.getArgNumber(), "lmhlo.constant_name")) { - *constant_name = constant_name_attr.getValue().str(); - } - } - return func_arg.getArgNumber(); -} - -StatusOr GetAllocationSlice( - mlir::Value v, absl::Span allocations, - std::string* constant_name) { - if (constant_name) { - constant_name->clear(); - } - - int64_t size = GetMemRefSizeInBytes(v.getType().cast()); - - // We match the following patterns here: - // base := ViewOp(arg) | get_global_memref (global_memref) | arg - // root := base | MemRefReinterpretCastOp(base) | CollapseShapeOp(base) - - if (auto cast = mlir::dyn_cast_or_null( - v.getDefiningOp())) { - v = cast.getViewSource(); - } - if (auto collapse_shape = - mlir::dyn_cast_or_null( - v.getDefiningOp())) { - v = collapse_shape.getSrc(); - } - - if (auto view = - mlir::dyn_cast_or_null(v.getDefiningOp())) { - TF_RET_CHECK(view.getSource().isa()); - - const BufferAllocation* allocation = allocations[GetAllocationIndex( - view.getSource().cast(), constant_name)]; - return BufferAllocation::Slice( - allocation, - mlir::cast(view.getByteShift().getDefiningOp()) - .getValue() - .cast() - .getValue() - .getSExtValue(), - size); - } - if (auto get_global = mlir::dyn_cast_or_null( - v.getDefiningOp())) { - auto module = get_global->getParentOfType(); - if (constant_name) { - *constant_name = get_global.getName().str(); - } - auto global = mlir::cast( - module.lookupSymbol(get_global.getName())); - int64_t index = - global->getAttrOfType("lmhlo.alloc").getInt(); - - return BufferAllocation::Slice(allocations[index], 0, - allocations[index]->size()); - } - if (auto arg = v.dyn_cast()) { - return BufferAllocation::Slice( - allocations[GetAllocationIndex(arg, constant_name)], 0, size); - } - - return Unimplemented( - "Operand has to be in the form of ViewOp(arg) or " - "StaticMemRefCastOp(ViewOp(arg)) or arg"); +absl::StatusOr GetAllocationSlice( + const BufferAssignment& buffer_assignment, const HloInstruction* instr, + const ShapeIndex& index) { + return buffer_assignment.GetUniqueSlice(instr, index); } std::vector GetOutputDefiningDynamicUpdateSlices( const std::vector& roots) { - // Same as GetOutputDefiningDynamicUpdateSliceOps but on a HLO fusion - // computation instead of a LMHLO FusionOp. std::vector dus_ops; for (const HloInstruction* root : roots) { while (root->opcode() == HloOpcode::kBitcast) { @@ -504,136 +373,139 @@ std::vector GetOutputDefiningDynamicUpdateSlices( dus_ops.push_back(root); } } - return dus_ops; } -std::vector -GetOutputDefiningDynamicUpdateSliceOps(mlir::lmhlo::FusionOp fusion) { - std::vector dus_ops; - - auto fusion_results = fusion.getFusionResults(); - for (const auto& fusion_result : fusion_results) { - // A dynamic slice update is said to be "defining" of a result if that - // result is the output of a dynamic slice update, or if that result is - // the output of a bitcast of a dynamic slice update---since a bitcast may - // be handled here as a no-op. - if (auto dus = mlir::dyn_cast( - fusion_result.getDefiningOp())) { - dus_ops.push_back(dus); - } - - if (auto bitcast = mlir::dyn_cast( - fusion_result.getDefiningOp())) { - if (auto dus = mlir::dyn_cast( - bitcast.getOperand().getDefiningOp())) { - dus_ops.push_back(dus); - } - } +template +absl::InlinedVector GetStartIndices(T instr) { + absl::InlinedVector result; + for (int i = instr->first_index_operand_number(); i < instr->operand_count(); + i++) { + const HloInstruction* index = instr->operand(i); + result.push_back(index); } - return dus_ops; + return result; } -bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - mlir::lmhlo::FusionOp fusion, - absl::Span allocations) { - std::vector dus_ops = - GetOutputDefiningDynamicUpdateSliceOps(fusion); +absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + const HloFusionInstruction* fusion, + const BufferAssignment* buffer_assignment, + const std::vector& roots) { + std::vector dus_instrs = + GetOutputDefiningDynamicUpdateSlices(roots); + + // Get output buffers for fusion. + std::vector output_buffers; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + fusion->shape(), [&](const Shape& shape, const ShapeIndex index) { + if (shape.IsArray()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, + buffer_assignment->GetUniqueSlice(fusion, index)); + output_buffers.push_back(buffer); + } + return absl::OkStatus(); + })); // This check could probably be relaxed: if code generation is made to use a // separate parallel loop for each dynamic slice update, then it shouldn't be // necessary for every output to be a dynamic slice update, nor to have the // same shape. - if (dus_ops.size() != fusion.getFusionResults().size()) { + if (dus_instrs.size() != output_buffers.size()) { return false; } - auto output_buffers = fusion.getOutputBuffers(); - CHECK_GE(output_buffers.size(), 1); - CHECK_EQ(dus_ops.size(), output_buffers.size()); + if (output_buffers.empty()) { + return Internal("Output buffers should not be empty"); + } + + Shape update_shape = dus_instrs[0]->operand(1)->shape(); - auto update_shape = - dus_ops[0].getUpdate().getType().cast().getShape(); + for (int i = 0; i < dus_instrs.size(); ++i) { + auto* dus = Cast(dus_instrs[i]); - // We can safely assume here that the slices being updated do not overlap, as - // constructing a fusion with them would not be safe otherwise. - for (auto [dus, output_buffer] : llvm::zip(dus_ops, output_buffers)) { - // Dynamic slice updates should have a single path to the root---this to - // avoid allowing a dynamic slice update to depend on another, as this would - // not be guaranteed to work with the current codegen. - if (!dus->hasOneUse()) { - return false; - } + // Dynamic slice updates should have a single path to the root to avoid + // allowing a dynamic slice update to depend on another, as this would not + // be guaranteed to work with the current codegen. + if (!dus->IsRoot() && dus->user_count() != 1) return false; + + // We follow DUS users until we find a root instruction. We support only + // few patterns: + // + // (1) ROOT dynamic-update-slice + // (2) ROOT tuple(dynamic-update-slice) + // (3) ROOT bitcast(dynamic-update-slice) + // (4) ROOT tuple(bitcast(dynamic-update-slice)) + HloInstruction* dus_user = dus->IsRoot() ? nullptr : dus->users().front(); // Since the direct consumer of an output dynamic slice update may be a // bitcast, we also check that this bitcast is used a single time. // This property is also important because reads and writes on the parameter // to be updated are done using the shape and layout of the dynamic slice // update. This is a valid approach only if a subsequent bitcast is not read - // by any other op within the fusion---as this may result in codegen + // by any other op within the fusion as this may result in codegen // accessing elements using the wrong physical layout. - auto dus_user = *dus->user_begin(); - if (auto bitcast = mlir::dyn_cast(dus_user)) { - if (!bitcast->hasOneUse()) { - return false; - } - dus_user = *bitcast->user_begin(); - } - if (!mlir::isa(dus_user)) { - return false; + if (dus_user && dus_user->opcode() == HloOpcode::kBitcast) { + if (!dus_user->IsRoot() && dus_user->user_count() != 1) return false; + + // Stop following DUS users if we found a root. + dus_user = dus_user->IsRoot() ? nullptr : dus_user->users().front(); } - auto operand = dus.getOperand(); - // A bitcast separating a fusion input from a dynamic slice update can be - // treated as a no-op. - if (auto bitcast = - mlir::dyn_cast(operand.getDefiningOp())) { - operand = bitcast.getOperand(); + + // Check that last DUS user is a tuple operation at ROOT position. + if (dus_user && dus_user->opcode() == HloOpcode::kTuple) { + if (!dus_user->IsRoot()) return false; + + // Stop following DUS users if we found a root. + dus_user = nullptr; } - auto parameter = mlir::dyn_cast( - operand.getDefiningOp()); + // We can't emit DUS fusion if we have unsupported DUS users. + if (dus_user != nullptr) return false; - if (!parameter) { - return false; + // Find "real" DUS operand by skipping bitcasted operands. + const HloInstruction* operand = dus->operand(0); + if (operand->opcode() == HloOpcode::kBitcast) { + operand = operand->operand(0); } + // Operand to a DUS (or Bitcast) must be a fusion parameter. + auto* parameter = DynCast(operand); + if (!parameter) return false; + // We require that the parameter being updated is only read at the same // index positions by all users, since we otherwise risk a race condition // when updating the parameter inplace. - std::queue q; - absl::flat_hash_set visited; + std::queue q; + absl::flat_hash_set visited; q.push(parameter); visited.insert(parameter); - // We have already checked above that the DUS only has one user: a - // (possibly bitcasted) MaterializeInDestinationOp. So we don't need to - // visit it during the breadth-first search. + // We have already checked above that the DUS only has one user. So we don't + // need to visit it during the breadth-first search. visited.insert(dus); while (!q.empty()) { - auto op = q.front(); + const HloInstruction* instr = q.front(); q.pop(); - for (auto user : op->getUsers()) { - if (mlir::isa(user) && - dus->getOperand(0) == user->getOperand(0) && - update_shape == user->getResult(0) - .getType() - .cast() - .getShape()) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() == HloOpcode::kDynamicSlice && + dus->operand(0) == user->operand(0) && + update_shape == user->shape()) { // We can still emit in-place in this case if the same slice is // accessed by the DUS and the DS. If they don't access the same // slice, the two slices might partially overlap and read/write the // same index at different times, and then we cannot guarantee that we // read before it is overwritten. However if both access only a single // element, there also can be no race condition. - if (mlir::ShapedType::getNumElements(update_shape) != 1 && - dus.getStartIndices() != - mlir::dyn_cast(user) - .getStartIndices()) { + absl::InlinedVector user_start_indices = + GetStartIndices(Cast(user)); + absl::InlinedVector dus_start_indices = + GetStartIndices(dus); + if (ShapeUtil::ElementsIn(update_shape) != 1 && + user_start_indices != dus_start_indices) { return false; } - } else if (user != dus && - !user->hasTrait() && - !mlir::isa( - user)) { + } else if (user != dus && !user->IsElementwise() && + user->opcode() != HloOpcode::kBitcast && + user->opcode() != HloOpcode::kTuple) { return false; } if (visited.insert(user).second) { @@ -647,15 +519,15 @@ bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( // be necessary for the shape to be the same for all the dynamic slice // updates. Note that this equality check purposefully ignores the element // type. - if (dus.getUpdate().getType().cast().getShape() != - update_shape) { + if (dus->update()->shape() != update_shape) { return false; } - auto maybe_lhs = GetAllocationSlice(parameter.getMemref(), allocations); - auto maybe_rhs = GetAllocationSlice(output_buffer, allocations); - - if (!(maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs)) { + const HloInstruction* lhs = fusion->operand(parameter->parameter_number()); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, + buffer_assignment->GetUniqueSlice(lhs, {})); + BufferAllocation::Slice rhs_buffer = output_buffers[i]; + if (lhs_buffer != rhs_buffer) { return false; } } @@ -681,7 +553,7 @@ Shape GetShape(mlir::Value value) { return shape; } -std::optional FindTiledTranspose( +static std::optional FindTiledTranspose( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kCopy) { return std::nullopt; @@ -713,7 +585,7 @@ std::optional FindTiledTranspose( } // Find 021 or 210 transpose in logical + physical transposition. -std::optional FindTiledLogicalTranspose( +static std::optional FindTiledLogicalTranspose( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kTranspose) { return std::nullopt; @@ -765,24 +637,28 @@ std::optional GetDescriptionForTiledTransposeEmitter( } bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, - const HloFusionAdaptor* fusion) { + const HloFusionAdaptor* fusion, + bool add_single_user_check) { // Number of operands should be in range [1, allowed_operand_count]. if (instr->operand_count() == 0 || instr->operand_count() > allowed_operand_count) { return false; } - // Intermediate `instr` can't have multiple users. - // If we have a boundary function, only consider users within the - // boundary. - // TODO(jreiffers): Figure out the point of this check. - int64_t num_users = - fusion ? absl::c_count_if( - HloInstructionAdaptor{*instr}.GetUsers(), - [&](auto user) { return fusion->ContainsInstruction(user); }) - : instr->user_count(); - if (num_users > 1) { - return false; + if (add_single_user_check) { + // Check that intermediate `instr` doesn't have multiple users. If we have a + // fusion, only consider users within the fusion. + // TODO(akuegel): Figure out why we still need this check for transpose + // fusions. + int64_t num_users = + fusion ? absl::c_count_if(HloInstructionAdaptor{*instr}.GetUsers(), + [&](auto user) { + return fusion->ContainsInstruction(user); + }) + : instr->user_count(); + if (num_users > 1) { + return false; + } } if (instr->IsElementwise()) { @@ -810,71 +686,89 @@ bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, } } -static bool IsParameter(const HloInstruction& instr) { - return instr.opcode() == HloOpcode::kParameter; +static std::optional FindNonTrivialHero( + HloInstructionAdaptor root, const HloFusionAdaptor& fusion, + const std::function& predicate) { + std::optional hero = std::nullopt; + auto visitor = [&](HloInstructionAdaptor node) { + if (predicate(node.instruction())) { + if (hero) { // Bail out if we found multiple potential heros. + hero = std::nullopt; + return TraversalResult::kInterrupt; + } + hero = node; + return TraversalResult::kSkip; + } + + // We set add_single_user_check to true because it could be that it causes + // problems if we have more than one user in a transpose fusion. + // TODO(akuegel): Verify and possibly fix this. + if (!IsIntermediate(&node.instruction(), /*allowed_operand_count=*/3, + /*fusion=*/nullptr, /*add_single_user_check=*/true)) { + return TraversalResult::kSkip; + } + return TraversalResult::kAdvance; + }; + HloBfsConsumersFirstTraversal({root}, fusion, visitor); + if (!hero) { + return std::nullopt; + } + + // Make sure that no non-elementwise op is reachable from the transpose. + auto is_nontrivial = [](HloInstructionAdaptor node) { + // We set add_single_user_check to true because it could be that it causes + // problems if we have more than one user in a transpose fusion. + // TODO(akuegel): Verify and possibly fix this. + return node.instruction().opcode() != HloOpcode::kTuple && + node.instruction().opcode() != HloOpcode::kParameter && + !IsIntermediate(&node.instruction(), + /*allowed_operand_count=*/3, /*fusion=*/nullptr, + /*add_single_user_check=*/true); + }; + bool visit_operands = false; + if (HloAnyOf(hero->GetUsers(), fusion, is_nontrivial, visit_operands)) { + return std::nullopt; + } + + return hero; } const HloInstruction& FindNonTrivialHero(const HloInstruction& instr, const HloFusionAdaptor& fusion) { - HloInstructionAdaptor idx{instr}; + HloInstructionAdaptor hero{instr}; - // Go up the chain of trivial element-wise(+bitcast, -copy) operations. Such - // chains are bound to be quite small, as we restrict the number of users as - // well. Note that no memoization is needed due to user number constraints: we + // Go up the chain of trivial element-wise(+bitcast, -copy) operations. Note + // that no memoization is needed due to number of operands constraints: we // never have to revisit same nodes. - auto get_intermediate_arg = - [&](HloInstructionAdaptor node) -> std::optional { - if (IsIntermediate(&node.instruction(), 1, &fusion) && - fusion.ContainsInstruction(node.GetOperand(0))) { - return node.GetOperand(0); - } - return std::nullopt; - }; - while (auto arg = get_intermediate_arg(idx)) { - idx = *arg; + while (IsIntermediate(&hero.instruction(), /*allowed_operand_count=*/1, + &fusion) && + fusion.ContainsInstruction(hero.GetOperand(0))) { + hero = hero.GetOperand(0); } - // The reduction emitter can't handle multiple users. - if (idx.opcode() == HloOpcode::kReduce && - absl::c_count_if(idx.GetUsers(), [&](auto user) { - return fusion.ContainsInstruction(user); - }) > 1) { - return instr; + // Try a bit harder to find a transpose or concat hero. The shared memory + // transpose and concat emitters also work if there are elementwise ops with + // more than 1 operand on the path between root and the root op. + auto is_transpose = [](const HloInstruction& node) { + return FindTiledLogicalTranspose(node).has_value(); + }; + if (auto transpose = FindNonTrivialHero(hero, fusion, is_transpose)) { + return transpose->instruction(); } - - std::optional transpose = std::nullopt; - // Try a bit harder to find a transpose hero. The shared memory transpose - // emitter also works if there are ops with more than 1 operand on the path - // between root and the transpose op, we still want the restriction though - // that each op on the path is elementwise and has only 1 user. - auto visit = [&transpose](HloInstructionAdaptor node) { - if (FindTiledLogicalTranspose(node.instruction())) { - // If we do not find a unique transpose op, use the original non-trivial - // hero. - if (transpose) { - transpose = std::nullopt; - return TraversalResult::kAbortTraversal; - } - transpose = node; - return TraversalResult::kDoNotVisitOperands; - } - - if (!IsIntermediate(&node.instruction(), /*allowed_operand_count=*/3)) { - return TraversalResult::kDoNotVisitOperands; - } - return TraversalResult::kVisitOperands; + auto is_concatenate = [](const HloInstruction& node) { + return node.opcode() == HloOpcode::kConcatenate; }; - HloBfsConsumersFirstTraversal({idx}, fusion, visit); - - return transpose ? transpose->instruction() : idx.instruction(); + if (auto concatenate = FindNonTrivialHero(hero, fusion, is_concatenate)) { + return concatenate->instruction(); + } + if (hero.opcode() != HloOpcode::kReduce) { + return instr; + } + return hero.instruction(); } const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) { - // It doesn't really make sense to call this function with a fusion, but it - // happens. Return the fusion itself for historical reasons. - // TODO(jreiffers): Clean this up. - if (instr.opcode() == HloOpcode::kFusion) return instr; - + CHECK_NE(instr.opcode(), HloOpcode::kFusion); return FindNonTrivialHero(instr, *HloFusionAdaptor::ForComputation(instr.parent())); } @@ -943,60 +837,6 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, return b->getInt32Ty(); } -llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size, - llvm::IRBuilder<>* b) { - auto shape_in_range = [&](const Shape& s) { - bool in_range = true; - ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, - const ShapeIndex& /*index*/) { - if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { - in_range = false; - } - }); - - return in_range; - }; - - llvm::Type* i64_ty = b->getInt64Ty(); - // Check launch dimension - if (!IsInt32(launch_size)) { - return i64_ty; - } - - // Check the size of result tensors - for (auto result : GetHloOutputs(op)) { - if (!shape_in_range(GetShape(result))) { - return i64_ty; - } - } - - auto hlo_shape_in_range = [&](mlir::Value operand) -> bool { - return shape_in_range(GetShape(operand)); - }; - - // Check the size of input tensors - if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) { - return i64_ty; - } - - // Check the size of the internal result tensors - if (auto fusion = mlir::dyn_cast(op)) { - auto result = fusion.getRegion().walk([&](mlir::Operation* op) { - for (mlir::Value result : op->getResults()) { - if (!hlo_shape_in_range(result)) { - return mlir::WalkResult::interrupt(); - } - } - return mlir::WalkResult::advance(); - }); - if (result.wasInterrupted()) { - return i64_ty; - } - } - - return b->getInt32Ty(); -} - std::string GetIrNameFromLoc(mlir::Location loc) { return llvm_ir::SanitizeConstantName( mlir::mhlo::GetDebugNameFromLocation(loc)); @@ -1006,7 +846,12 @@ bool IsAMDGPU(const llvm::Module* module) { return llvm::Triple(module->getTargetTriple()).isAMDGPU(); } -StatusOr LiteralToXlaFormat(const Literal& literal) { +bool IsSPIR(const llvm::Module* module) { + return llvm::Triple(module->getTargetTriple()).isSPIR(); +} + +absl::StatusOr LiteralToXlaFormat( + const Literal& literal) { PrimitiveType element_type = literal.shape().element_type(); if (!primitive_util::IsArrayType(element_type)) { return Internal("Unsupported type in LiteralToXlaFormat"); diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h index f3349762868c9..896237855503e 100644 --- a/xla/service/gpu/ir_emission_utils.h +++ b/xla/service/gpu/ir_emission_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,17 +23,22 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/statusor.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -49,10 +54,8 @@ inline constexpr int64_t kMinDimensionToTransposeTiled2 = 8; inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128; // Matrix multiplication before the rewrite. -// -// This function should never return "true" on instructions after -// GemmRewriter pass has finished. bool IsMatrixMultiplication(const HloInstruction& dot); +bool IsMatrixVectorMultiplication(const HloInstruction& dot); inline constexpr int64_t WarpSize() { return 32; } @@ -68,9 +71,13 @@ inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm"; inline constexpr absl::string_view kTritonSoftmaxFusionKind = "__triton_softmax"; +inline constexpr absl::string_view kCuDnnFusionKind = "__cudnn$fusion"; + inline constexpr absl::string_view kUncompilableFusion = "__uncompilable_fusion"; +inline constexpr absl::string_view kTopKCustomCallTarget = "__gpu$TopK"; + // Returns true if `hlo` will be implemented as a call to a cuSolver routine. // // This returns true if `hlo` is a CustomCall HLO with a call target equal to @@ -78,22 +85,24 @@ inline constexpr absl::string_view kUncompilableFusion = // say, a kCholesky opcode. bool IsCustomCallToCusolver(const HloInstruction& hlo); +// Returns true if `hlo` will be implemented as a call to a TopK routine. +bool IsCustomCallToTopK(const HloInstruction& hlo); + // Cholesky decomposition. Takes a (batched) matrix as input, and returns a // tuple of (result, workspace, info), where result is the result of the // Cholesky decomposition, workspace is scratch space for cuSolver, and info // is a success/failure code per batch element. extern const char* const kCusolverCholeskyCallTarget; -// Returns whether unnested_hlo is an input fusion whose root is either a slice -// or a tuple of slices. If verify_no_strides is true, returns false unless all -// ROOT slices have no strides. -bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, - bool verify_no_strides); +// Returns true if `instr` is a non-strided slice. +bool IsSliceWithUnitStrides(const HloInstruction* instr); -// Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(absl::string_view fmt, - absl::Span arguments, - llvm::IRBuilder<>* builder); +// Returns true if `instr` is a slice instruction and produces a contiguous +// slice. +bool IsContiguousSlice(const HloInstruction& instr); + +// Returns true if `sliced` is a contiguous slice of `orig`. +bool IsContiguousSlice(const Shape& orig, const Shape& sliced); // Emits code to shuffle data between threads of a warp. This has the same // semantics as the PTX "shfl.sync.down" instruction but works for values that @@ -112,26 +121,23 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, // block 0 of the kernel. llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); -int PartitionLmhloOperandsAndOutputs(mlir::Operation* op); llvm::SmallVector GetHloOperands(mlir::Operation* op); llvm::SmallVector GetHloOutputs(mlir::Operation* op); bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand); -template -std::vector ToStdVector(const llvm::SmallVectorImpl& v) { - return std::vector(v.begin(), v.end()); -} - -StatusOr GetAllocationSlice( +absl::StatusOr GetAllocationSlice( mlir::Value v, absl::Span allocations, std::string* constant_name = nullptr); -bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion); +absl::StatusOr GetAllocationSlice( + const BufferAssignment& buffer_assignment, const HloInstruction* instr, + const ShapeIndex& index); -bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - mlir::lmhlo::FusionOp fusion, - absl::Span allocations); +absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + const HloFusionInstruction* fusion, + const BufferAssignment* buffer_assignment, + const std::vector& roots); // Returns the dynamic-update-slice instructions defining the results of a // fusion node. A dynamic slice update is said to be "defining" of a result if @@ -141,14 +147,6 @@ bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( std::vector GetOutputDefiningDynamicUpdateSlices( const std::vector& roots); -// Returns the DynamicUpdateSliceOp(s) defining the results of a fusion node. -// A dynamic slice update is said to be "defining" of a result if that result is -// the output of a dynamic slice update, or if that result is the output of a -// bitcast of a dynamic slice update---since such bitcast may be handled as a -// no-op. -std::vector -GetOutputDefiningDynamicUpdateSliceOps(mlir::lmhlo::FusionOp fusion); - Shape GetShape(mlir::Value value); // `is_boundary` returns `true` for edges that are on the boundary of the @@ -181,11 +179,6 @@ struct TransposeDescription { Vector3 permutation) : instr(instr), dimensions(dimensions), permutation(permutation) {} - std::string ToString() const { - return absl::StrCat("dimensions=", VectorString(dimensions), - ", permutation=", VectorString(permutation)); - } - // Transpose instruction input shape. const Shape& input_shape() const { return instr->operand(0)->shape(); } @@ -196,19 +189,16 @@ struct TransposeDescription { } }; -std::optional FindTiledTranspose( - const HloInstruction& instr); - -std::optional FindTiledLogicalTranspose( - const HloInstruction& instr); - std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& root, const HloInstruction& hero); // Checks if the instruction is elementwise and only has a single user. If -// a fusion adaptor is provided, only checks for users within the fusion. +// a fusion adaptor is provided, only checks for users within the fusion. If +// `add_single_user_check` is true, then it is also checked whether `instr` has +// at most 1 user. bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1, - const HloFusionAdaptor* fusion = nullptr); + const HloFusionAdaptor* fusion = nullptr, + bool add_single_user_check = false); // Log the given module if the VLOG level is >= level. void VLogModule(int level, const llvm::Module& module); @@ -237,6 +227,9 @@ std::string GetIrNameFromLoc(mlir::Location loc); // Whether the module's target is an AMD GPU. bool IsAMDGPU(const llvm::Module* module); +// Whether the module's target is a SPIR. +bool IsSPIR(const llvm::Module* module); + // This class stores either a non-owning reference or owns data that represents // a dense array in XLA format. It is used for intermediate storage during IR // constant emission. @@ -266,7 +259,8 @@ class DenseDataIntermediate { std::variant, absl::Span> data_; }; -StatusOr LiteralToXlaFormat(const Literal& literal); +absl::StatusOr LiteralToXlaFormat( + const Literal& literal); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/ir_emission_utils_test.cc b/xla/service/gpu/ir_emission_utils_test.cc index 1b3e323a58db7..3c776d60a3e24 100644 --- a/xla/service/gpu/ir_emission_utils_test.cc +++ b/xla/service/gpu/ir_emission_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,29 +16,15 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include -#include #include #include -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/tests/hlo_test_base.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/types.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -47,63 +33,6 @@ namespace gpu { class IrEmissionUtilsTest : public HloTestBase {}; -TEST_F(IrEmissionUtilsTest, TestOperandPartitionNoAlias) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg2) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - -TEST_F(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg0) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - -TEST_F(IrEmissionUtilsTest, TestOperandPartitionWithAlias1) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg1) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) { const char* hlo = R"( HloModule module @@ -222,6 +151,115 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) { EXPECT_EQ(result.name(), "reduce.0"); } +TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionTwoRootUsers) { + const char* hlo = R"( + HloModule module + + Add { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) + } + fused_computation { + param_0 = f32[4,2]{1,0} parameter(0) + neg = f32[4,2]{1,0} negate(param_0) + constant_0 = f32[] constant(0) + reduce.1 = f32[4]{0} reduce(param_0, constant_0), dimensions={1}, to_apply=Add + bitcast.1 = f32[1,1,4]{2,1,0} bitcast(reduce.1) + sign.1 = f32[1,1,4]{2,1,0} sign(bitcast.1) + ROOT tuple.12 = (f32[4,2]{1,0}, f32[1,1,4]{2,1,0}, f32[1,1,4]{2,1,0}) tuple(neg, bitcast.1, sign.1) + } + + ENTRY main.7749 { + Arg_2.1 = f32[4,2]{1,0} parameter(0) + ROOT fusion = (f32[4,2]{1,0}, f32[1,1,4]{2,1,0}, f32[1,1,4]{2,1,0}) fusion(Arg_2.1), kind=kInput, calls=fused_computation + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* r = module->entry_computation()->root_instruction(); + auto fusion = HloFusionAdaptor::ForInstruction(r); + const auto& result = + FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion); + EXPECT_EQ(result.name(), "reduce.1"); + const auto& result2 = + FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion); + EXPECT_EQ(result2.name(), "reduce.1"); +} + +TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionHeroAlsoUsedAsNonHero) { + const char* hlo = R"( + HloModule module + + Add { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + p0 = f32[4]{0} parameter(0) + zero = f32[] constant(0.0) + reduce.0 = f32[] reduce(f32[4]{0} p0, f32[] zero), dimensions={0}, to_apply=Add + broadcast = f32[4]{0} broadcast(f32[] reduce.0), dimensions={} + reduce.1 = f32[] reduce(f32[4]{0} broadcast, f32[] zero), dimensions={0}, to_apply=Add + bitcast = f32[1]{0} bitcast(f32[] reduce.0) + ROOT tuple.1 = (f32[], f32[4]{0}, f32[1]{0}) tuple(f32[] reduce.1, f32[4]{0} broadcast, f32[1]{0} bitcast) + } + + ENTRY main { + Arg0 = f32[4]{0} parameter(0) + ROOT fusion = (f32[], f32[4]{0}, f32[1]{0}) fusion(Arg0), kind=kInput, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* r = module->entry_computation()->root_instruction(); + auto fusion = HloFusionAdaptor::ForInstruction(r); + const auto& result = + FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion); + // reduce.0 is also an operand of broadcast, but it is not a hero for that + // root. + EXPECT_EQ(result.name(), "broadcast"); + const auto& result2 = + FindNonTrivialHero(fusion->GetRoots()[2].instruction(), *fusion); + EXPECT_EQ(result2.name(), "reduce.0"); +} + +TEST_F(IrEmissionUtilsTest, DoNotFindTransposeHeroEpilogueFusionTwoRootUsers) { + const char* hlo = R"( + HloModule module + + fused_computation { + param_0 = f32[64,32]{1,0} parameter(0) + transpose = f32[32,64]{1,0} transpose(param_0), dimensions={1,0} + bitcast.1 = f32[1,32,64]{2,1,0} bitcast(transpose) + sign.1 = f32[1,32,64]{2,1,0} sign(bitcast.1) + ROOT tuple.12 = (f32[1,32,64]{2,1,0}, f32[1,32,64]{2,1,0}) tuple(bitcast.1, sign.1) + } + + ENTRY main.7749 { + Arg_2.1 = f32[64,32]{1,0} parameter(0) + ROOT fusion = (f32[1,32,64]{2,1,0}, f32[1,32,64]{2,1,0}) fusion(Arg_2.1), kind=kInput, calls=fused_computation + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* r = module->entry_computation()->root_instruction(); + auto fusion = HloFusionAdaptor::ForInstruction(r); + const auto& result = + FindNonTrivialHero(fusion->GetRoots()[0].instruction(), *fusion); + EXPECT_EQ(result.name(), "bitcast.1"); + const auto& result2 = + FindNonTrivialHero(fusion->GetRoots()[1].instruction(), *fusion); + EXPECT_EQ(result2.name(), "sign.1"); +} + TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateBinaryOp) { const char* hlo = R"( HloModule module @@ -360,6 +398,29 @@ ENTRY entry { transpose); } +TEST_F(IrEmissionUtilsTest, TransposeReachableViaTrivialAndNontrivialOps) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f64[16,16]{1,0} parameter(0) + trans = f64[16,16]{1,0} transpose(p), dimensions={1,0} + rev = f64[16,16]{1,0} reverse(trans), dimensions={0,1} + sub = f64[16,16]{1,0} subtract(trans, trans) + ROOT add = f64[16,16]{1,0} add(rev, sub) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* r = module->entry_computation()->root_instruction(); + EXPECT_FALSE( + GetDescriptionForTiledTransposeEmitter(*r, FindNonTrivialHero(*r)) + .has_value()); + EXPECT_EQ(&FindNonTrivialHero(*r), r); +} + TEST_F(IrEmissionUtilsTest, FindTiledTransposeOneSwapDimIsSmall) { const char* hlo = R"( HloModule module @@ -444,6 +505,71 @@ ENTRY entry { EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); } +TEST_F(IrEmissionUtilsTest, IsContiguousSlice) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[8,12,100,11]{3,2,1,0} parameter(0) + slice.1 = f32[2,12,100,11]{3,2,1,0} slice(p), slice={[1:3], [0:12], [0:100], [0:11]} + slice.2 = f32[1,1,1,11]{3,2,1,0} slice(p), slice={[1:2], [0:1], [0:1], [0:11]} + slice.3 = f32[1,1,10,11]{3,2,1,0} slice(p), slice={[1:2], [0:1], [0:10], [0:11]} + slice.4 = f32[1,2,10,11]{3,2,1,0} slice(p), slice={[1:2], [0:2], [0:10], [0:11]} + slice.5 = f32[8,2,100,11]{3,2,1,0} slice(p), slice={[0:8], [10:12], [0:100], [0:11]} + c = f32[8,12,100,11]{0,1,3,2} copy(p) + slice.6 = f32[8,12,40,11]{0,1,3,2} slice(c), slice={[0:8], [0:12], [10:50], [0:11]} + slice.7 = f32[8,12,1,2]{0,1,3,2} slice(c), slice={[0:8], [0:12], [0:1], [0:2]} + slice.8 = f32[8,2,100,11]{0,1,3,2} slice(c), slice={[0:8], [0:2], [0:100], [0:11]} + slice.9 = f32[8,2,40,11]{0,1,3,2} slice(c), slice={[0:8], [10:12], [10:50], [0:11]} + slice.10 = f32[8,2,50,11]{3,2,1,0} slice(p), slice={[0:8:1], [10:12:1], [0:100:2], [0:11:1]} + ROOT t = (f32[2,12,100,11]{3,2,1,0}, + f32[1,1,1,11]{3,2,1,0}, + f32[1,1,10,11]{3,2,1,0}, + f32[1,2,10,11]{3,2,1,0}, + f32[8,2,100,11]{3,2,1,0}, + f32[8,12,40,11]{0,1,3,2}, + f32[8,12,1,2]{0,1,3,2}, + f32[8,2,100,11]{0,1,3,2}, + f32[8,2,40,11]{0,1,3,2}, + f32[8,2,50,11]{3,2,1,0}) tuple(slice.1, slice.2, slice.3, slice.4, slice.5, slice.6, slice.7, slice.8, slice.9, slice.10) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + HloInstruction* slice1 = + module->entry_computation()->GetInstructionWithName("slice.1"); + HloInstruction* slice2 = + module->entry_computation()->GetInstructionWithName("slice.2"); + HloInstruction* slice3 = + module->entry_computation()->GetInstructionWithName("slice.3"); + HloInstruction* slice4 = + module->entry_computation()->GetInstructionWithName("slice.4"); + HloInstruction* slice5 = + module->entry_computation()->GetInstructionWithName("slice.5"); + HloInstruction* slice6 = + module->entry_computation()->GetInstructionWithName("slice.6"); + HloInstruction* slice7 = + module->entry_computation()->GetInstructionWithName("slice.7"); + HloInstruction* slice8 = + module->entry_computation()->GetInstructionWithName("slice.8"); + HloInstruction* slice9 = + module->entry_computation()->GetInstructionWithName("slice.9"); + HloInstruction* slice10 = + module->entry_computation()->GetInstructionWithName("slice.10"); + EXPECT_TRUE(IsContiguousSlice(*slice1)); + EXPECT_TRUE(IsContiguousSlice(*slice2)); + EXPECT_TRUE(IsContiguousSlice(*slice3)); + EXPECT_TRUE(!IsContiguousSlice(*slice4)); + EXPECT_TRUE(!IsContiguousSlice(*slice5)); + EXPECT_TRUE(IsContiguousSlice(*slice6)); + EXPECT_TRUE(IsContiguousSlice(*slice7)); + EXPECT_TRUE(!IsContiguousSlice(*slice8)); + EXPECT_TRUE(!IsContiguousSlice(*slice9)); + EXPECT_TRUE(!IsContiguousSlice(*slice10)); +} + TEST_F(IrEmissionUtilsTest, LiteralToAttrToXlaFormat) { // int16, should be aliased. { diff --git a/xla/service/gpu/ir_emitter.cc b/xla/service/gpu/ir_emitter.cc index 8290787c69db0..4a6e05a10da06 100644 --- a/xla/service/gpu/ir_emitter.cc +++ b/xla/service/gpu/ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,24 @@ limitations under the License. #include "xla/service/gpu/ir_emitter.h" +#include #include +#include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/primitive_util.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/TargetParser/Triple.h" #include "xla/service/elemental_ir_emitter.h" #include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -35,7 +41,7 @@ limitations under the License. #include "xla/service/llvm_ir/tuple_ops.h" #include "xla/shape_util.h" #include "xla/util.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -47,7 +53,7 @@ IrEmitter::IrEmitter(IrEmitterContext* ir_emitter_context, bool is_nested) b_(module_->getContext()), bindings_(&b_, module_, is_nested) {} -Status IrEmitter::DefaultAction(HloInstruction* hlo) { +absl::Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { @@ -60,11 +66,11 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { .MakeElementGenerator(hlo, operand_to_generator)); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - return OkStatus(); +absl::Status IrEmitter::HandleConstant(HloInstruction* constant) { + return absl::OkStatus(); } -Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { +absl::Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { VLOG(2) << "HandleAddDependency: " << add_dependency->ToString(); const HloInstruction* operand = add_dependency->operand(0); // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value @@ -73,10 +79,11 @@ Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { if (bindings_.BoundToIrValue(*operand)) { bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand)); } - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { +absl::Status IrEmitter::HandleGetTupleElement( + HloInstruction* get_tuple_element) { auto operand = get_tuple_element->operand(0); CHECK(bindings_.BoundToIrValue(*operand)); bindings_.BindHloToIrValue( @@ -87,36 +94,36 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { // based on the real element type. /*alignment=*/1, GetBasePointer(*operand), llvm_ir::ShapeToIrType(operand->shape(), module_), &b_)); - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitter::HandleSend(HloInstruction*) { +absl::Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } -Status IrEmitter::HandleSendDone(HloInstruction*) { +absl::Status IrEmitter::HandleSendDone(HloInstruction*) { return Unimplemented("Send-Done is not implemented on GPU"); } -Status IrEmitter::HandleRecv(HloInstruction*) { +absl::Status IrEmitter::HandleRecv(HloInstruction*) { return Unimplemented("Recv is not implemented on GPU"); } -Status IrEmitter::HandleRecvDone(HloInstruction*) { +absl::Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } -Status IrEmitter::HandleScatter(HloInstruction*) { +absl::Status IrEmitter::HandleScatter(HloInstruction*) { return Unimplemented("Scatter is not implemented on GPUs."); } -Status IrEmitter::HandleTuple(HloInstruction* tuple) { +absl::Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_); - return OkStatus(); + return absl::OkStatus(); } bool IrEmitter::IsEmittingForAMDGPU() const { @@ -150,34 +157,34 @@ std::pair MultiplyComplex(llvm::Value* lhs_value, } } // namespace -Status IrEmitter::HandleConvolution(HloInstruction* convolution) { +absl::Status IrEmitter::HandleConvolution(HloInstruction* convolution) { if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. - return OkStatus(); + return absl::OkStatus(); } // TODO(b/31409998): Support convolution with dilation. return Unimplemented( "Hit a case for convolution that is not implemented on GPU."); } -Status IrEmitter::HandleFft(HloInstruction* fft) { +absl::Status IrEmitter::HandleFft(HloInstruction* fft) { if (ShapeUtil::IsZeroElementArray(fft->shape())) { // Emit no code for an empty output. - return OkStatus(); + return absl::OkStatus(); } return Unimplemented("Hit a case for fft that is not implemented on GPU."); } -Status IrEmitter::HandleAllReduce(HloInstruction* crs) { +absl::Status IrEmitter::HandleAllReduce(HloInstruction* crs) { return Unimplemented( "AllReduce cannot be nested inside of fusion, map, etc."); } -Status IrEmitter::HandleParameter(HloInstruction* parameter) { - return OkStatus(); +absl::Status IrEmitter::HandleParameter(HloInstruction* parameter) { + return absl::OkStatus(); } -Status IrEmitter::HandleFusion(HloInstruction* fusion) { +absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) { // kFusion for library calls should be handled by // IrEmitterUnnested::HandleFusion. CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind()); @@ -189,7 +196,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { return EmitTargetElementLoop(*fusion, generator); } -Status IrEmitter::HandleCall(HloInstruction* call) { +absl::Status IrEmitter::HandleCall(HloInstruction* call) { std::vector operand_addresses; for (HloInstruction* operand : call->operands()) { operand_addresses.push_back(GetBasePointer(*operand)); @@ -198,35 +205,35 @@ Status IrEmitter::HandleCall(HloInstruction* call) { operand_addresses, GetBasePointer(*call)); } -Status IrEmitter::HandleCustomCall(HloInstruction*) { +absl::Status IrEmitter::HandleCustomCall(HloInstruction*) { return Unimplemented("custom-call"); } -Status IrEmitter::HandleInfeed(HloInstruction*) { +absl::Status IrEmitter::HandleInfeed(HloInstruction*) { // TODO(b/30467474): Implement infeed on GPU. return Unimplemented("Infeed is not supported on GPU."); } -Status IrEmitter::HandleOutfeed(HloInstruction*) { +absl::Status IrEmitter::HandleOutfeed(HloInstruction*) { // TODO(b/34359662): Implement outfeed on GPU. return Unimplemented("Outfeed is not supported on GPU."); } -Status IrEmitter::HandleBatchNormInference(HloInstruction*) { +absl::Status IrEmitter::HandleBatchNormInference(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormInference directly. It " "should be lowered before IR emission to HLO-soup using " "BatchNormRewriter."); } -Status IrEmitter::HandleBatchNormTraining(HloInstruction*) { +absl::Status IrEmitter::HandleBatchNormTraining(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormTraining directly. It " "should be lowered before IR emission to HLO-soup using " "BatchNormRewriter."); } -Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { +absl::Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormGrad directly. It should " "be lowered before IR emission to HLO-soup using BatchNormRewriter."); @@ -263,8 +270,7 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion, void IrEmitter::MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering atomic_ordering, const char* sync_scope_id) { if (IsEmittingForAMDGPU() && - ir_emitter_context_->rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { + ir_emitter_context_->rocm_compute_capability().fence_before_barrier()) { b_.CreateFence(atomic_ordering, b_.getContext().getOrInsertSyncScopeID(sync_scope_id)); } diff --git a/xla/service/gpu/ir_emitter.h b/xla/service/gpu/ir_emitter.h index 8ac0b2075d3cd..106d05018479e 100644 --- a/xla/service/gpu/ir_emitter.h +++ b/xla/service/gpu/ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,11 @@ limitations under the License. #include +#include "absl/status/status.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "llvm/Support/AtomicOrdering.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_to_ir_bindings.h" @@ -29,6 +31,7 @@ limitations under the License. #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/ir_builder_mixin.h" #include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { @@ -56,30 +59,33 @@ class IrEmitter : public DfsHloVisitorWithDefault, IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; - Status DefaultAction(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleFft(HloInstruction* fft) override; - Status HandleAllReduce(HloInstruction* crs) override; - Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSend(HloInstruction* send) override; - Status HandleSendDone(HloInstruction* send_done) override; - Status HandleRecv(HloInstruction* recv) override; - Status HandleRecvDone(HloInstruction* recv_done) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleScatter(HloInstruction* scatter) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleBatchNormInference(HloInstruction* batch_norm) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleAddDependency(HloInstruction* add_dependency) override; - - Status FinishVisit(HloInstruction* root) override { return OkStatus(); } + absl::Status DefaultAction(HloInstruction* hlo) override; + absl::Status HandleConstant(HloInstruction* constant) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleFft(HloInstruction* fft) override; + absl::Status HandleAllReduce(HloInstruction* crs) override; + absl::Status HandleInfeed(HloInstruction* infeed) override; + absl::Status HandleOutfeed(HloInstruction* outfeed) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleSendDone(HloInstruction* send_done) override; + absl::Status HandleRecv(HloInstruction* recv) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandleParameter(HloInstruction* parameter) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleScatter(HloInstruction* scatter) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleCall(HloInstruction* call) override; + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + absl::Status HandleBatchNormInference(HloInstruction* batch_norm) override; + absl::Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + absl::Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + absl::Status HandleAddDependency(HloInstruction* add_dependency) override; + + absl::Status FinishVisit(HloInstruction* root) override { + return absl::OkStatus(); + } llvm::IRBuilder<>* builder() { return &b_; } @@ -115,7 +121,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // in the result of the given HLO instruction. This produces a series of // nested loops (e.g. one for each dimension of the `hlo`'s shape). The body // of the inner-most loop is provided by the body_emitter function. - virtual Status EmitTargetElementLoop( + virtual absl::Status EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) = 0; @@ -138,13 +144,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, const char* sync_scope_id); private: - // A helper method for HandleSort(). It adds the inner comparison loop where - // we compare elements pointed to by 'keys_index' and 'compare_keys_index'. - void EmitCompareLoop(int64_t dimension_to_sort, - const llvm_ir::IrArray::Index& keys_index, - const llvm_ir::IrArray::Index& compare_keys_index, - const llvm_ir::IrArray& keys_array); - // A convenience method to determine whether or not IR is emitted for AMDGPU. bool IsEmittingForAMDGPU() const; }; diff --git a/xla/service/gpu/ir_emitter_context.cc b/xla/service/gpu/ir_emitter_context.cc index 9af81cb90fd45..38a5e71306990 100644 --- a/xla/service/gpu/ir_emitter_context.cc +++ b/xla/service/gpu/ir_emitter_context.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,23 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include -#include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/Alignment.h" +#include "llvm/TargetParser/Triple.h" #include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/ir_emission_utils.h" namespace xla { @@ -66,6 +76,9 @@ void IrEmitterContext::emit_constant(int64_t num_elements, content.span().size())); }(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = + llvm::Triple(llvm_module_->getTargetTriple()).isSPIR() ? 1 : 0; // These globals will be looked up by name by GpuExecutable so we need to // give them an external linkage. Not all of their uses are visible in // the LLVM IR so we can't give then a linkage that merely preserves their @@ -79,7 +92,7 @@ void IrEmitterContext::emit_constant(int64_t num_elements, llvm::GlobalValue::ExternalLinkage, /*Initializer=*/initializer, symbol_name, /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/0, + /*AddressSpace=*/addrspace, /*isExternallyInitialized=*/false); global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes)); llvm_module_->insertGlobalVariable(global_for_const); diff --git a/xla/service/gpu/ir_emitter_context.h b/xla/service/gpu/ir_emitter_context.h index 90614c715ee82..afbf212bad036 100644 --- a/xla/service/gpu/ir_emitter_context.h +++ b/xla/service/gpu/ir_emitter_context.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,22 +16,42 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #define XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ +#include +#include #include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/name_uniquer.h" #include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { +// Maps async start ops to their async events so we can emit done thunk +// sharing events with corresponding start thunk. Async events may be null if +// the start op is degenerate (so not emitted). For Send and Recv, this maps +// to the asyn events, as multiple Recv and Recv-done or +// multiple Send and Send-done may map to the same async events and a Recv-done +// or Send-done operand may not be its corresponding Recv or Send, when a +// Send-Recv chain inside a loop is pipelined. +using CollectivesAsyncEvents = + absl::flat_hash_map>, + std::shared_ptr>; // IrEmitterContext encapsulates common (mutable and immutable) data structures // used by both IrEmitterNested and IrEmitterUnnested, such as the buffer @@ -43,14 +63,14 @@ class IrEmitterContext { std::string platform_name, const se::DeviceDescription& gpu_device_info, mlir::MLIRContext* mlir_context, llvm::Module* llvm_module, - bool emit_ir_from_hlo) + bool emit_kernels) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), mlir_context_(mlir_context), llvm_module_(llvm_module), - emit_ir_from_hlo_(emit_ir_from_hlo) {} + emit_kernels_(emit_kernels) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; IrEmitterContext& operator=(const IrEmitterContext&) = delete; @@ -83,14 +103,6 @@ class IrEmitterContext { std::vector& constants() { return constants_; } - absl::Span allocations() const { - return allocations_; - } - - void set_allocations(absl::Span allocations) { - allocations_ = allocations; - } - // Emit a constant with a given number of element, given byte size of the // element, given symbol name and content. void emit_constant(int64_t num_elements, int64_t bytes_per_element, @@ -101,24 +113,28 @@ class IrEmitterContext { return hlo_module_->config().debug_options(); } - bool emit_ir_from_hlo() const { return emit_ir_from_hlo_; } + KernelReuseCache& kernel_cache() { return kernel_cache_; } + CollectivesAsyncEvents& collectives_async_events() { + return collectives_async_events_; + } + + bool emit_kernels() const { return emit_kernels_; } private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; - - // Stores pointer to buffer allocations in the order of the LMHLO entry args. - // LMHLO-based emitters need the ordering to locate the buffer allocation. - // This should be removed once LMHLO-based emitters are removed. - absl::Span allocations_; - std::string platform_name_; const se::DeviceDescription& gpu_device_info_; mlir::MLIRContext* mlir_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; std::vector constants_; - const bool emit_ir_from_hlo_; + KernelReuseCache kernel_cache_; + + CollectivesAsyncEvents collectives_async_events_; + + // We should not emit kernels when loading thunks from a compilation result. + const bool emit_kernels_; }; } // namespace gpu diff --git a/xla/service/gpu/ir_emitter_nested.cc b/xla/service/gpu/ir_emitter_nested.cc index b531a5655fb06..7403531a8285f 100644 --- a/xla/service/gpu/ir_emitter_nested.cc +++ b/xla/service/gpu/ir_emitter_nested.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,26 +14,55 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/ir_emitter_nested.h" +#include +#include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" +#include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/loop_emitter.h" #include "xla/service/llvm_ir/tuple_ops.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -52,7 +81,7 @@ class IrEmitterNested : public IrEmitter { // Overrides the default empty implementation. Binds the given instruction // "parameter" with the parameter of the IR function. - Status HandleParameter(HloInstruction* parameter) override; + absl::Status HandleParameter(HloInstruction* parameter) override; // Generate the code for the computation passed in the constructor, if it // wasn't already generated previously. @@ -65,17 +94,17 @@ class IrEmitterNested : public IrEmitter { // // The allocation index for these constants will always be -1 (i.e. doesn't // correspond to any allocation) - StatusOr CodegenNestedComputation(); + absl::StatusOr CodegenNestedComputation(); protected: - Status EmitTargetElementLoop( + absl::Status EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) override; private: // Emits constants to generated LLVM IR, and also populates related // information to 'ir_emitter_context_' for large-constant initializations. - Status EmitConstants(const HloComputation& computation); + absl::Status EmitConstants(const HloComputation& computation); const HloComputation& nested_computation_; }; @@ -88,7 +117,7 @@ IrEmitterNested::IrEmitterNested(const HloComputation& nested_computation, // Nested function serves the same purpose on GPU as a thread-local function on // a CPU. -StatusOr IrEmitterNested::CodegenNestedComputation() { +absl::StatusOr IrEmitterNested::CodegenNestedComputation() { // Include a fingerprint of the HLO in the function name. Currently, codegen // is invoked on temporary HLO objects, which means the address of the // computation is not necessarily unique. @@ -209,11 +238,11 @@ StatusOr IrEmitterNested::CodegenNestedComputation() { return function; } -Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { - return OkStatus(); +absl::Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { + return absl::OkStatus(); } -Status IrEmitterNested::EmitTargetElementLoop( +absl::Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { // For MOF we give the loop emitter an array for every output it should @@ -224,13 +253,13 @@ Status IrEmitterNested::EmitTargetElementLoop( TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_); - return OkStatus(); + return absl::OkStatus(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) .EmitLoop(); } -Status IrEmitterNested::EmitConstants(const HloComputation& computation) { +absl::Status IrEmitterNested::EmitConstants(const HloComputation& computation) { for (HloInstruction* instr : computation.instructions()) { if (instr->opcode() != HloOpcode::kConstant) { continue; @@ -258,7 +287,7 @@ Status IrEmitterNested::EmitConstants(const HloComputation& computation) { absl::MakeSpan(base, base + literal.size_bytes())), &b_); } - return OkStatus(); + return absl::OkStatus(); } // Casts the provided llvm::Value* to the default address space. This is useful @@ -268,8 +297,8 @@ llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) { llvm::Type* arg_type = arg->getType(); CHECK(arg_type->isPointerTy()); if (arg_type->getPointerAddressSpace() != 0) { - llvm::Type* generic_arg_type = llvm::PointerType::getWithSamePointeeType( - llvm::cast(arg_type), 0); + llvm::Type* generic_arg_type = llvm::PointerType::get( + llvm::cast(arg_type)->getContext(), 0); llvm::Value* addrspacecast_arg = b.CreateAddrSpaceCast(arg, generic_arg_type); return addrspacecast_arg; @@ -292,8 +321,8 @@ void EmitAMDGPUAtomicAdd(llvm::IRBuilder<>* builder, // is in global addrspace (1) : builder->CreateAddrSpaceCast( output_address, - llvm::PointerType::getWithSamePointeeType(output_address_type, - /*AddressSpace=*/1)); + llvm::PointerType::get(output_address_type->getContext(), + /*AddressSpace=*/1)); builder->CreateAtomicRMW( llvm::AtomicRMWInst::FAdd, output_ptr, source, llvm::MaybeAlign(), @@ -359,8 +388,12 @@ bool MaybeEmitDirectAtomicOperation(llvm::IRBuilder<>* builder, bool f64_atomic_add_supported = ir_emitter_context.cuda_compute_capability().IsAtLeast( se::CudaComputeCapability::PASCAL_); + bool f16_atomic_add_supported = + ir_emitter_context.cuda_compute_capability().IsAtLeast( + se::CudaComputeCapability::VOLTA); bool atomic_add_supported = element_type == F32 || + (f16_atomic_add_supported && element_type == F16) || (f64_atomic_add_supported && element_type == F64); if (atomic_add_supported) { builder->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, @@ -538,16 +571,15 @@ bool MaybeEmitDirectAtomicOperation(llvm::IRBuilder<>* builder, // *cas_new_output_address); // } while (!success); // -Status EmitAtomicOperationUsingCAS(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address, - llvm::Type* element_type) { +absl::Status EmitAtomicOperationUsingCAS(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context, + const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address, + llvm::Type* element_type) { llvm::PointerType* output_address_type = llvm::dyn_cast(output_address->getType()); CHECK_NE(output_address_type, nullptr); - CHECK(output_address_type->isOpaqueOrPointeeTypeMatches(element_type)); int element_size = llvm_ir::GetSizeInBits(element_type); @@ -663,16 +695,16 @@ Status EmitAtomicOperationUsingCAS(llvm::IRBuilder<>* builder, // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. llvm_ir::SetToFirstInsertPoint(loop_exit_bb, builder); - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status CallNestedComputation(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - const HloComputation& computation, - absl::Span operands, - llvm::Value* output) { +absl::Status CallNestedComputation(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context, + const HloComputation& computation, + absl::Span operands, + llvm::Value* output) { TF_RET_CHECK(computation.num_parameters() > 0); TF_ASSIGN_OR_RETURN(llvm::Function * emitted_function, @@ -692,10 +724,10 @@ Status CallNestedComputation(llvm::IRBuilder<>* builder, builder->CreateCall(emitted_function, arguments); - return OkStatus(); + return absl::OkStatus(); } -StatusOr> CallNestedComputationWithScalars( +absl::StatusOr> CallNestedComputationWithScalars( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, absl::Span parameter_elements) { @@ -710,7 +742,7 @@ StatusOr> CallNestedComputationWithScalars( computation, parameter_buffers); } -StatusOr> CallNestedComputationWithScalarAddrs( +absl::StatusOr> CallNestedComputationWithScalarAddrs( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, absl::Span parameter_elements_addrs) { @@ -746,7 +778,7 @@ StatusOr> CallNestedComputationWithScalarAddrs( return returned_scalars; } -Status EmitAtomicOperationForNestedComputation( +absl::Status EmitAtomicOperationForNestedComputation( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address, llvm::Type* element_type) { @@ -760,7 +792,7 @@ Status EmitAtomicOperationForNestedComputation( if (MaybeEmitDirectAtomicOperation(builder, ir_emitter_context, computation, output_address, source_address)) { - return OkStatus(); + return absl::OkStatus(); } return EmitAtomicOperationUsingCAS(builder, ir_emitter_context, computation, diff --git a/xla/service/gpu/ir_emitter_nested.h b/xla/service/gpu/ir_emitter_nested.h index 9f1b6e1baf983..ee5204b6a94cd 100644 --- a/xla/service/gpu/ir_emitter_nested.h +++ b/xla/service/gpu/ir_emitter_nested.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ #define XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/hlo_module_config.h" namespace xla { namespace gpu { @@ -40,20 +43,20 @@ namespace gpu { // - N pointers to the buffers of each of the N parameters to the computation, // - a pointer to the output buffer of the computation, and // - a pointer to the top-level temp buffer. -Status CallNestedComputation(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - const HloComputation& computation, - absl::Span operands, - llvm::Value* output); +absl::Status CallNestedComputation(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context, + const HloComputation& computation, + absl::Span operands, + llvm::Value* output); // Like CallNestedComputation, but parameters and results are scalars. -StatusOr> CallNestedComputationWithScalars( +absl::StatusOr> CallNestedComputationWithScalars( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, absl::Span parameter_elements); // Like CallNestedComputationWithScalars, but parameters are scalar addresses. -StatusOr> CallNestedComputationWithScalarAddrs( +absl::StatusOr> CallNestedComputationWithScalarAddrs( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, absl::Span parameter_elements_addrs); @@ -67,7 +70,7 @@ StatusOr> CallNestedComputationWithScalarAddrs( // will, otherwise it will be emitted as a compare-and-swap and a loop. // // The computation must have exactly two parameters. -Status EmitAtomicOperationForNestedComputation( +absl::Status EmitAtomicOperationForNestedComputation( llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address, llvm::Type* element_type); diff --git a/xla/service/gpu/ir_emitter_triton.cc b/xla/service/gpu/ir_emitter_triton.cc index 93f511506eea3..f1bbceb65c676 100644 --- a/xla/service/gpu/ir_emitter_triton.cc +++ b/xla/service/gpu/ir_emitter_triton.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,17 @@ limitations under the License. #include #include +#include #include #include +#include #include #include #include #include #include // NOLINT(build/c++11): required to interface with LLVM #include +#include #include #include "absl/algorithm/container.h" @@ -44,11 +47,15 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project @@ -56,6 +63,7 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -65,6 +73,7 @@ limitations under the License. #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -80,49 +89,56 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/service/dump.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" -#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" namespace xla { @@ -137,6 +153,7 @@ namespace mt = ::mlir::triton; using ::llvm::SmallVector; using mlir::ArrayRef; using mlir::ImplicitLocOpBuilder; +using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; using mlir::ValueRange; @@ -181,7 +198,7 @@ Type StorageType(mlir::OpBuilder b, Type t) { template T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) { CHECK(hlo_query::IsScalarConstant(&instr)); - StatusOr converted = instr.literal().Convert(dst_type); + absl::StatusOr converted = instr.literal().Convert(dst_type); TF_CHECK_OK(converted.status()); return converted.value().GetFirstElement(); } @@ -216,7 +233,7 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value, } Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = x.getType().dyn_cast()) { + if (auto src_shaped_ty = x.getType().dyn_cast()) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 0, src_shaped_ty.getShape()); } @@ -224,7 +241,7 @@ Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { } Value OnesLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = x.getType().dyn_cast()) { + if (auto src_shaped_ty = x.getType().dyn_cast()) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 1, src_shaped_ty.getShape()); } @@ -237,7 +254,7 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { Type src_element_ty = src_ty; Type fp32_ty = b.getF32Type(); Type dst_ty = dst_element_ty; - if (auto src_shaped_ty = src_ty.dyn_cast()) { + if (auto src_shaped_ty = src_ty.dyn_cast()) { src_element_ty = src_shaped_ty.getElementType(); dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty); fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type()); @@ -251,7 +268,10 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { return Cast(b, b.create(fp32_ty, value), dst_element_ty); } if (dst_element_ty.isBF16()) { - return b.create(dst_ty, Cast(b, value, b.getF32Type())); + // S8 -> BF16 is directly supported and doesn't need to go through f32. + if (!src_element_ty.isInteger(8)) { + return b.create(dst_ty, Cast(b, value, b.getF32Type())); + } } // float => float @@ -326,9 +346,11 @@ Value Compare(ImplicitLocOpBuilder& b, ValueRange values, values[0], values[1]); } -Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) { - // ma::MaximumFOp seems to think that max(NaN, x) = x, so we don't use that. - // +Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::getElementTypeOrSelf(values[0]).isa()) { + return b.create(values); + } // logic: isNaN(lhs) || (!isNan(rhs) && lhs >= rhs) ? lhs : rhs // See also: IEEE Std 754-2008 5.11. // @@ -345,9 +367,11 @@ Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) { values[0], values[1]); } -Value Minimum(ImplicitLocOpBuilder& b, ValueRange values) { - // ma::MinimumFOp seems to think that min(NaN, x) = x, so we don't use that. - // +Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::getElementTypeOrSelf(values[0]).isa()) { + return b.create(values); + } // logic: isNaN(lhs) || (!isNan(rhs) && lhs <= rhs) ? lhs : rhs // See also: IEEE Std 754-2008 5.11. // @@ -388,17 +412,24 @@ Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { return b.create(ptr.getType(), ptr, offset); } -Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - const HloInstruction& hlo, ValueRange inputs) { +absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloInstruction& hlo, + ValueRange inputs) { if (mlir::getElementTypeOrSelf(inputs[0]).isF32() || mlir::getElementTypeOrSelf(inputs[0]).isF64()) { auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); if (dev_fn_id.ok()) { + llvm::Triple triple("nvptx64-unknown-unknown"); + if (std::holds_alternative( + device_info.gpu_compute_capability())) { + triple.setTriple("amdgcn-unknown-unknown"); + } return b.create( inputs[0].getType(), inputs, "libdevice", libdevice_path, ObtainDeviceFunctionName(dev_fn_id.value(), - hlo.shape().element_type(), - llvm::Triple("nvptx64-unknown-unknown")), + hlo.shape().element_type(), triple), /*pure=*/true); } } @@ -434,9 +465,9 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path, } return b.create(inputs[0], inputs[1]); case HloOpcode::kMaximum: - return Maximum(b, inputs); + return Maximum(b, device_info, inputs); case HloOpcode::kMinimum: - return Minimum(b, inputs); + return Minimum(b, device_info, inputs); case HloOpcode::kAnd: return b.create(inputs[0], inputs[1]); case HloOpcode::kOr: @@ -461,7 +492,8 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path, mlir::mhlo::ComparisonDirection::NE), inputs[1], inputs[2]); default: - LOG(FATAL) << "Unsupported operation " << hlo.ToString(); + return absl::InvalidArgumentError( + absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); } } @@ -514,12 +546,12 @@ struct DimProperties { int split_value; }; -Value EmitBroadcast(ImplicitLocOpBuilder& b, - const TritonFusionAnalysis* analysis, - TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, - const HloInstruction& broadcast, Value input) { - CHECK(analysis != nullptr); +absl::StatusOr EmitBroadcast( + ImplicitLocOpBuilder& b, const TritonFusionAnalysis* analysis, + TritonFusionAnalysis::Scope scope, + absl::Span tiled_dimensions, + const HloInstruction& broadcast, Value input) { + TF_RET_CHECK(analysis != nullptr); std::vector out_shape; for (const DimProperties& dim : tiled_dimensions) { const TensorIterationSpec::DimIterationSpec* spec = @@ -554,27 +586,31 @@ Value EmitBroadcast(ImplicitLocOpBuilder& b, return Broadcast(b, expanded_input.cast(), out_shape); } -StatusOr EmitScope( +absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, absl::Span tiled_dimensions, absl::Span instructions, absl::flat_hash_map& values); -StatusOr EmitReduce(ImplicitLocOpBuilder& b, - absl::string_view libdevice_path, - const HloInstruction& hlo_reduce, Value input) { +absl::StatusOr EmitReduce(ImplicitLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloInstruction& hlo_reduce, + Value input) { llvm::ArrayRef input_shape = input.cast().getType().getShape(); // At the moment, we should only emit a full reduction over the last axis of // a single input. - CHECK_EQ(hlo_reduce.operand_count(), 2); - CHECK_EQ(hlo_reduce.dimensions().size(), 1); - CHECK_EQ(hlo_reduce.dimensions(0), hlo_reduce.operand(0)->shape().rank() - 1); + TF_RET_CHECK(hlo_reduce.operand_count() == 2); + TF_RET_CHECK(hlo_reduce.dimensions().size() == 1); + TF_RET_CHECK(hlo_reduce.dimensions(0) == + hlo_reduce.operand(0)->shape().rank() - 1); const int block_row = input_shape.back(); const int row_len = hlo_reduce.operand(0)->shape().dimensions_minor(0); - CHECK_GE(block_row, row_len); + TF_RET_CHECK(block_row >= row_len); const HloInstruction* operand = hlo_reduce.operand(1); Value neutral; @@ -582,14 +618,14 @@ StatusOr EmitReduce(ImplicitLocOpBuilder& b, // We assume that the reduction value was input as a constant, or in the case // of a data type affected by float normalization, a convert of a constant. if (operand->opcode() == HloOpcode::kConvert) { - CHECK_EQ(operand->operand(0)->opcode(), HloOpcode::kConstant); - CHECK_EQ(operand->operand(0)->shape().element_type(), BF16); + TF_RET_CHECK(operand->operand(0)->opcode() == HloOpcode::kConstant); + TF_RET_CHECK(operand->operand(0)->shape().element_type() == BF16); PrimitiveType dest_ty = operand->shape().element_type(); - CHECK_EQ(dest_ty, F32); + TF_RET_CHECK(dest_ty == F32); neutral = EmitConstant(b, *operand->operand(0)); neutral = Cast(b, neutral, TritonType(b, dest_ty)); } else { - CHECK_EQ(operand->opcode(), HloOpcode::kConstant); + TF_RET_CHECK(operand->opcode() == HloOpcode::kConstant); neutral = EmitConstant(b, *operand); } @@ -628,22 +664,24 @@ StatusOr EmitReduce(ImplicitLocOpBuilder& b, reduction_computation->MakeInstructionPostOrder()) { if (instr->opcode() == HloOpcode::kParameter) { int parameter_number = instr->parameter_number(); - CHECK_LT(parameter_number, 2); - CHECK(region_values - .insert({instr, reducer->getArgument(parameter_number)}) - .second); + TF_RET_CHECK(parameter_number < 2); + TF_RET_CHECK( + region_values + .insert({instr, reducer->getArgument(parameter_number)}) + .second); } else { to_emit.push_back(instr); } } - CHECK(!to_emit.empty()); + TF_RET_CHECK(!to_emit.empty()); b.setInsertionPointToStart(reducer); - TF_ASSIGN_OR_RETURN(Value result, - EmitScope(b, libdevice_path, /*analysis=*/nullptr, - TritonFusionAnalysis::Scope::OUTPUT, {}, - to_emit, region_values)); + TF_ASSIGN_OR_RETURN( + Value result, + EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, + TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + region_values)); b.create(SmallVector({result})); b.setInsertionPointAfter(reduction); } @@ -660,10 +698,182 @@ StatusOr EmitReduce(ImplicitLocOpBuilder& b, return Cast(b, result, TritonType(b, hlo_reduce.shape().element_type())); } +// Emit code corresponding to a fusion instruction somehow nested within the +// initial Triton fusion. This can happen when we carry around auxiliary +// computations, e.g. with reduces. Since we are emitting a single Triton +// fusion, we simply flatten the fusion inside the computation. +// +// TODO(b/331413981): get rid of this special handling once this is solved. +absl::StatusOr EmitNestedFusion( + ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction& fusion_instruction, + absl::flat_hash_map& values) { + // TODO(b/331402498): revisit the order of scope once we completely deprecate + // Triton fusion analysis. + const HloComputation* fusion_computation = + fusion_instruction.fused_instructions_computation(); + + absl::flat_hash_map region_values; + + std::vector to_emit; + for (const HloInstruction* instr : + fusion_computation->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kParameter) { + int64_t parameter_number = instr->parameter_number(); + auto it = values.find(fusion_instruction.operand(parameter_number)); + TF_RET_CHECK(it != values.end()); + TF_RET_CHECK(region_values.insert({instr, it->second}).second); + } else { + to_emit.push_back(instr); + } + } + + TF_RET_CHECK(to_emit.back() == fusion_computation->root_instruction()); + + return EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, + TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + region_values); +} + +// TODO(b/331332678): Add unit tests to target this function specifically. +Value EmitTiledBroadcast( + ImplicitLocOpBuilder& b, const SymbolicTileAnalysis& analysis, + const SymbolicTiledHloInstruction& tiled_broadcast, + absl::flat_hash_map& values) { + auto input_tile_shape = analysis.TileSizes(*tiled_broadcast.operand(0)); + auto output_tile_shape = analysis.TileSizes(tiled_broadcast); + + Value expanded_input = values[tiled_broadcast.operand(0)]; + + // Returns true if `dim_id` is broadcasted. + auto is_broadcasted_dim = [&](int64_t dim_id) { + return !llvm::is_contained(tiled_broadcast.hlo()->dimensions(), dim_id); + }; + + // The loop below iterates over output dimensions and tracks matching dims in + // input_tile_shape and expended_input value. + // `input_dim_id != expanded_input_dim_id`, because size-1 dims are present in + // the input tile shape, but not in the MLIR value. Triton doesn't like size-1 + // dims, so they are inserted only for dimensions that will be broadcasted. + int64_t input_dim_id = 0; + int64_t expanded_input_dim_id = 0; + for (size_t output_dim_id = 0; output_dim_id < output_tile_shape.size(); + ++output_dim_id) { + if (is_broadcasted_dim(output_dim_id)) { + // The dim is broadcasted in the original instruction, but tiled to 1 in + // this case. Nothing to broadcast. + if (output_tile_shape[output_dim_id] == 1) continue; + + // Expand dim for broadcast. + expanded_input = + b.create(expanded_input, expanded_input_dim_id); + ++expanded_input_dim_id; + } else { + // The dim is not broadcasted. Validate that it's equal in the input and + // output tile. + CHECK_EQ(input_tile_shape[input_dim_id], + output_tile_shape[output_dim_id]); + ++input_dim_id; + + // Size-1 dims are not present in the tensor type. + if (output_tile_shape[output_dim_id] != 1) { + ++expanded_input_dim_id; + } + } + } + + SmallVector padded_output_tile_shape; + padded_output_tile_shape.reserve(output_tile_shape.size()); + + for (int64_t tile_dim : output_tile_shape) { + if (tile_dim != 1) { + padded_output_tile_shape.push_back(llvm::PowerOf2Ceil(tile_dim)); + } + } + + return Broadcast(b, expanded_input.cast(), + padded_output_tile_shape); +} + +absl::StatusOr EmitTiledHloInstruction( + ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const SymbolicTileAnalysis& analysis, + const SymbolicTiledHloInstruction& tiled_hlo, + std::function(const SymbolicTiledHloInstruction&)> + emit_param_load_fn, + absl::flat_hash_map& values) { + const HloInstruction* hlo = tiled_hlo.hlo(); + + if (hlo->opcode() == HloOpcode::kParameter) { + return emit_param_load_fn(tiled_hlo); + } + + if (hlo->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(hlo->shape())) { + // Splat makes it a tensor to avoid type mismatches. + return Splat(b, EmitConstant(b, *hlo), {}); + } + + if (hlo->opcode() == HloOpcode::kBroadcast) { + return EmitTiledBroadcast(b, analysis, tiled_hlo, values); + } + + if (hlo->opcode() == HloOpcode::kReduce) { + return EmitReduce(b, libdevice_path, device_info, *hlo, + values[tiled_hlo.operand(0)]); + } + + if (hlo->IsElementwise()) { + std::vector operands; + operands.reserve(hlo->operands().size()); + + for (const SymbolicTiledHloInstruction* operand : tiled_hlo.operands()) { + operands.push_back(values[operand]); + } + return EmitElementwise(b, libdevice_path, device_info, *hlo, operands); + } + + if (hlo->opcode() == HloOpcode::kTranspose || + hlo->opcode() == HloOpcode::kSlice || hlo->opcode() == HloOpcode::kPad) { + // All these are currently supported only as operations on indices + // which are pushed to loads and stores. No operations on tiles are + // performed here. + return values[tiled_hlo.operand(0)]; + } + + return absl::UnimplementedError( + absl::StrCat("Unsupported opcode: ", hlo->opcode())); +} + // Emit sequence of instructions using compatible tiling ordered producers // before consumers. -StatusOr EmitScope( +absl::StatusOr EmitTiledScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const SymbolicTileAnalysis& analysis, + std::function(const SymbolicTiledHloInstruction&)> + emit_param_load_fn, + absl::flat_hash_map& values) { + for (const auto& tiled_hlo : analysis.GetTiledHloInstructions()) { + TF_ASSIGN_OR_RETURN( + Value result, + EmitTiledHloInstruction(b, libdevice_path, device_info, analysis, + *tiled_hlo, emit_param_load_fn, values)); + TF_RET_CHECK(values.insert({tiled_hlo.get(), result}).second) + << tiled_hlo->hlo()->ToString(); + VLOG(8) << "Emitted " + << tiled_hlo->hlo()->ToString(HloPrintOptions::ShortParsable()); + } + return values[analysis.GetRoot()]; +} + +// Emit sequence of instructions using compatible tiling ordered producers +// before consumers. +absl::StatusOr EmitScope( + ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, absl::Span tiled_dimensions, absl::Span instructions, @@ -684,18 +894,20 @@ StatusOr EmitScope( // Splat makes it a tensor to avoid type mismatches. result = Splat(b, EmitConstant(b, *hlo), {}); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - result = EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, - values[hlo->operand(0)]); - } else if (hlo->opcode() == HloOpcode::kReduce) { TF_ASSIGN_OR_RETURN( - result, EmitReduce(b, libdevice_path, *hlo, values[hlo->operand(0)])); - } else if (hlo->IsElementwise()) { + result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, + values[hlo->operand(0)])); + } else if (hlo->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN(result, EmitReduce(b, libdevice_path, device_info, + *hlo, values[hlo->operand(0)])); + } else if (HloInstruction::IsOpElementwise(hlo->opcode())) { std::vector operands; operands.reserve(hlo->operands().size()); for (const HloInstruction* operand : hlo->operands()) { operands.push_back(values[operand]); } - result = EmitElementwise(b, libdevice_path, *hlo, operands); + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); } else if (hlo->opcode() == HloOpcode::kTuple) { TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); } else if (hlo->opcode() == HloOpcode::kBitcast || @@ -707,8 +919,14 @@ StatusOr EmitScope( // which are pushed to loads and stores. No operations on tiles are // performed here. result = values[hlo->operand(0)]; + } else if (hlo->opcode() == HloOpcode::kFusion) { + const auto* fusion_instruction = ::xla::Cast(hlo); + TF_ASSIGN_OR_RETURN(result, + EmitNestedFusion(b, libdevice_path, device_info, + *fusion_instruction, values)); } else { - LOG(FATAL) << hlo->ToString(); + return absl::InvalidArgumentError( + absl::StrCat("Unsupported operation ", hlo->ToString())); } TF_RET_CHECK(values.insert({hlo, result}).second) << hlo->ToString(); VLOG(8) << "Emitted " << hlo->ToString(HloPrintOptions::ShortParsable()); @@ -716,72 +934,6 @@ StatusOr EmitScope( return values[instructions.back()]; } -void CreateTritonPipeline(mlir::OpPassManager& pm, - const se::CudaComputeCapability& cc, int num_warps, - int num_stages) { - const int ccAsInt = cc.major * 10 + cc.minor; - const int threadsPerWarp = 32; - const int numCTAs = 1; - // Based on optimize_ttir() in - // @triton//:python/triton/compiler/compiler.py - pm.addPass(mt::createRewriteTensorPointerPass(ccAsInt)); - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mt::createCombineOpsPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mt::createReorderBroadcastPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createSymbolDCEPass()); - // Based on ttir_to_ttgir() in - // @triton//:python/triton/compiler/compiler.py - pm.addPass(mt::createConvertTritonToTritonGPUPass(num_warps, threadsPerWarp, - numCTAs, ccAsInt)); - // Based on optimize_ttgir() in - // @triton//:python/triton/compiler/compiler.py - pm.addPass(mlir::createTritonGPUCoalescePass()); - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(/*clusterInfo=*/)); - pm.addPass(mlir::createTritonGPURewriteTensorPointerPass(ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(/*clusterInfo=*/)); - pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); - pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(ccAsInt)); - pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); - pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createTritonGPUPipelinePass(num_stages, num_warps, numCTAs, - ccAsInt)); - pm.addPass( - mlir::createTritonNvidiaGPUMaterializeLoadStorePass(num_warps, ccAsInt)); - if (ccAsInt <= 80) { - pm.addPass(mlir::createTritonGPUPrefetchPass()); - } - pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); - pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); - pm.addPass(mlir::createTritonGPUDecomposeConversionsPass()); - pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); - pm.addPass(mlir::createTritonGPUReorderInstructionsPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); - if (ccAsInt >= 90) { - pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); - } - pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); - pm.addPass(mlir::createTritonGPUOptimizeThreadLocalityPass()); - pm.addPass(mlir::createCanonicalizerPass()); - // Based on translateTritonGPUToLLVMIR() in - // @triton//:lib/Target/LLVMIR/LLVMIRTranslation.cpp - pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createConvertIndexToLLVMPass()); - pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt, - /*target=*/mt::Default, - /*tmaMetadata=*/nullptr)); - pm.addPass(mt::createConvertNVGPUToLLVMPass()); - pm.addPass(mlir::createArithToLLVMConversionPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); - // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. -} - // Extract additional attributes from an LLVM function that are not passed // to the builder directly. SmallVector GetExtraAttrs(ml::LLVMFuncOp func) { @@ -825,7 +977,7 @@ void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, // Convert generic address spaces back to original ones within the function // body. - mlir::Block* entry = generic_func.addEntryBlock(); + mlir::Block* entry = generic_func.addEntryBlock(rewriter); rewriter.setInsertionPointToEnd(entry); SmallVector converted_args; for (auto [arg, type] : @@ -889,8 +1041,9 @@ const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( // split-K, batch, non-contracting LHS, non-contracting RHS, // where split-K and batch are optional. struct MatMulDims { - MatMulDims(const TritonGemmConfig& config, const HloDotInstruction& dot, - const TritonFusionAnalysis& analysis); + static absl::StatusOr Create( + const TritonGemmConfig& config, const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis); std::optional out_split_k_dim_idx = std::nullopt; @@ -913,6 +1066,9 @@ struct MatMulDims { int64_t m; int64_t n; int64_t k; + + private: + MatMulDims() = default; }; // Structure for parameters relating to the MatMul launch grid. @@ -928,85 +1084,90 @@ struct MatMulLaunchConfig { mt::ProgramIDDim noncontracting_program_id_dim; }; -MatMulDims::MatMulDims(const TritonGemmConfig& config, - const HloDotInstruction& dot, - const TritonFusionAnalysis& analysis) { +/*static*/ absl::StatusOr MatMulDims::Create( + const TritonGemmConfig& config, const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis) { + MatMulDims matmul_dims; if (config.split_k > 1) { // split-k is always the first logical dimension. - out_split_k_dim_idx = 0; + matmul_dims.out_split_k_dim_idx = 0; } int64_t num_split_k_dims = config.split_k > 1 ? 1 : 0; const auto& dims = dot.dot_dimension_numbers(); - lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0); - lhs_noncontracting_dim_idx = + matmul_dims.lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0); + matmul_dims.lhs_noncontracting_dim_idx = GetNonContractingDims(dot.operand(0)->shape(), dims.lhs_batch_dimensions(), dims.lhs_contracting_dimensions()) .value()[0]; - rhs_contracting_dim_idx = dims.rhs_contracting_dimensions(0); - rhs_noncontracting_dim_idx = + matmul_dims.rhs_contracting_dim_idx = dims.rhs_contracting_dimensions(0); + matmul_dims.rhs_noncontracting_dim_idx = GetNonContractingDims(dot.operand(1)->shape(), dims.rhs_batch_dimensions(), dims.rhs_contracting_dimensions()) .value()[0]; if (dims.lhs_batch_dimensions_size() > num_split_k_dims) { - lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); - rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); + matmul_dims.lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); + matmul_dims.rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); // The batch dimension (if present) comes after the split-k dimension (if // present, otherwise it's the first dimension). - out_batch_dim_idx = num_split_k_dims; + matmul_dims.out_batch_dim_idx = num_split_k_dims; } // Logical output dimensions are always ordered as: // split-K, batch, non-contracting LHS, non-contracting RHS, // where split-K and batch are optional. - out_rhs_noncontracting_dim_idx = dot.shape().rank() - 1; - out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; + matmul_dims.out_rhs_noncontracting_dim_idx = dot.shape().rank() - 1; + matmul_dims.out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; auto* root = dot.parent()->root_instruction(); - n = analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, - out_rhs_noncontracting_dim_idx) - ->at(0) - .count; + auto iter_spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + matmul_dims.out_rhs_noncontracting_dim_idx); + TF_RET_CHECK(iter_spec != nullptr); + matmul_dims.n = iter_spec->at(0).count; // Contracting dimension length. if (config.split_k > 1 && dot.operand(0)->operand(0)->opcode() == HloOpcode::kPad) { // Unpadded LHS shape: [..., k, ...] // Padded LHS shape: [..., padded_k, ...] // Bitcasted LHS shape: [..., split_k, padded_k / split_k, ...] - CHECK_EQ(dot.operand(0)->opcode(), HloOpcode::kBitcast); + TF_RET_CHECK(dot.operand(0)->opcode() == HloOpcode::kBitcast); const Shape& unpadded_lhs_shape = dot.operand(0)->operand(0)->operand(0)->shape(); - k = unpadded_lhs_shape.dimensions(dims.lhs_contracting_dimensions(0) - 1); + matmul_dims.k = + unpadded_lhs_shape.dimensions(dims.lhs_contracting_dimensions(0) - 1); } else { - k = dot.operand(0)->shape().dimensions(dims.lhs_contracting_dimensions(0)) * + matmul_dims.k = + dot.operand(0)->shape().dimensions(dims.lhs_contracting_dimensions(0)) * config.split_k; } - auto* lhs_noncontracting_split_spec = - GetLhsNoncontractingSplitSpec(analysis, lhs_noncontracting_dim_idx); + auto* lhs_noncontracting_split_spec = GetLhsNoncontractingSplitSpec( + analysis, matmul_dims.lhs_noncontracting_dim_idx); if (lhs_noncontracting_split_spec != nullptr) { // Just the fastest-varying part of it if the dimension is split. - m = lhs_noncontracting_split_spec->at(0).count; - lhs_noncontracting_split = lhs_noncontracting_split_spec->at(1).count; + matmul_dims.m = lhs_noncontracting_split_spec->at(0).count; + matmul_dims.lhs_noncontracting_split = + lhs_noncontracting_split_spec->at(1).count; } else { - m = analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, - out_lhs_noncontracting_dim_idx) - ->at(0) - .count; + matmul_dims.m = analysis + .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + matmul_dims.out_lhs_noncontracting_dim_idx) + ->at(0) + .count; } // For now split non-contracting and batch are not supported // simultaneously because they are implemented via same mechanism. - CHECK( - !(out_batch_dim_idx.has_value() && lhs_noncontracting_split.has_value())); + TF_RET_CHECK(!(matmul_dims.out_batch_dim_idx.has_value() && + matmul_dims.lhs_noncontracting_split.has_value())); - CHECK_GE(m, 1); - CHECK_GE(n, 1); + TF_RET_CHECK(matmul_dims.m >= 1); + TF_RET_CHECK(matmul_dims.n >= 1); + return std::move(matmul_dims); } MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, @@ -1031,49 +1192,51 @@ MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, if (large_batch) { batch_program_id_dim = mt::ProgramIDDim::X; noncontracting_program_id_dim = mt::ProgramIDDim::Y; - launch_dims = {{batch_size, grid_m * grid_n, config.split_k}, - {config.num_warps * WarpSize(), 1, 1}}; + launch_dims = LaunchDimensions( + se::BlockDim(batch_size, grid_m * grid_n, config.split_k), + se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); } else { batch_program_id_dim = mt::ProgramIDDim::Y; noncontracting_program_id_dim = mt::ProgramIDDim::X; - launch_dims = - LaunchDimensions{{grid_m * grid_n, batch_size, config.split_k}, - {config.num_warps * WarpSize(), 1, 1}}; + launch_dims = LaunchDimensions( + se::BlockDim(grid_m * grid_n, batch_size, config.split_k), + se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); } } -void ValidateMatMulConfig(const TritonGemmConfig& config, - const HloDotInstruction& dot) { - CHECK_GE(config.split_k, 1); - CHECK_GE(config.block_m, 16); - CHECK_GE(config.block_k, 16); - CHECK_GE(config.block_n, 16); +absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, + const HloDotInstruction& dot) { + TF_RET_CHECK(config.split_k >= 1); + TF_RET_CHECK(config.block_m >= 16); + TF_RET_CHECK(config.block_k >= 16); + TF_RET_CHECK(config.block_n >= 16); const auto& dims = dot.dot_dimension_numbers(); int num_batch_dims = dims.lhs_batch_dimensions_size() - (config.split_k > 1 ? 1 : 0); - CHECK_LE(num_batch_dims, 1); + TF_RET_CHECK(num_batch_dims <= 1); if (config.split_k > 1) { // Split-K dimension has to be the first batch one and have an index // just before the contracting one. const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1; const int rhs_split_k_dim_idx = dims.rhs_contracting_dimensions(0) - 1; // Size of this dimension has to match the split_k value. - CHECK_EQ(dims.lhs_batch_dimensions(0), lhs_split_k_dim_idx); - CHECK_EQ(dims.rhs_batch_dimensions(0), rhs_split_k_dim_idx); - CHECK_EQ(config.split_k, - dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx)); - CHECK_EQ(config.split_k, - dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx)); + TF_RET_CHECK(dims.lhs_batch_dimensions(0) == lhs_split_k_dim_idx); + TF_RET_CHECK(dims.rhs_batch_dimensions(0) == rhs_split_k_dim_idx); + TF_RET_CHECK(config.split_k == + dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx)); + TF_RET_CHECK(config.split_k == + dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx)); } // Rely on dot decomposer: there is just one contracting and one // non-contracting dimension on each side + batch ones optionally. - CHECK_EQ(dims.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dims.rhs_contracting_dimensions_size(), 1); + TF_RET_CHECK(dims.lhs_contracting_dimensions_size() == 1); + TF_RET_CHECK(dims.rhs_contracting_dimensions_size() == 1); - CHECK_EQ(dot.operand(0)->shape().rank(), - 2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims); + TF_RET_CHECK(dot.operand(0)->shape().rank() == + 2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims); + return absl::OkStatus(); } struct Side { @@ -1091,9 +1254,9 @@ struct Side { // } else { // return choices.back(); // } -Value EmitMultiSelect(ImplicitLocOpBuilder b, Value index, ValueRange limits, - ValueRange choices) { - CHECK_EQ(choices.size() - 1, limits.size()); +absl::StatusOr EmitMultiSelect(ImplicitLocOpBuilder b, Value index, + ValueRange limits, ValueRange choices) { + TF_RET_CHECK(choices.size() - 1 == limits.size()); Value result = choices[0]; for (int i = 0; i < choices.size() - 1; ++i) { result = b.create( @@ -1103,8 +1266,8 @@ Value EmitMultiSelect(ImplicitLocOpBuilder b, Value index, ValueRange limits, return result; } -Status UncompilableMatmul(absl::string_view explanation) { - Status s = absl::CancelledError(explanation); +absl::Status UncompilableMatmul(absl::string_view explanation) { + absl::Status s = absl::CancelledError(explanation); s.SetPayload(kUncompilableFusion, absl::Cord(explanation)); return s; } @@ -1112,12 +1275,14 @@ Status UncompilableMatmul(absl::string_view explanation) { class MatMulEmitterHelper { public: MatMulEmitterHelper(absl::string_view libdevice_path, + const se::DeviceDescription& device_info, const HloDotInstruction* dot_instr, ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims, const MatMulLaunchConfig& launch_config, const TritonFusionAnalysis& analysis) : b_(b), libdevice_path_(libdevice_path), + device_info_(device_info), dot_instr_(dot_instr), index_ty_(index_ty), analysis_(analysis), @@ -1127,20 +1292,37 @@ class MatMulEmitterHelper { // TODO(b/266862493): Accumulator can be integer too. // Otherwise only f64 x f64 -> f64 uses f64 accumulator. mlir::FloatType GetDotAccumulatorType() { - Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type()); - // Data type of dot() immediate inputs. - Type dot_input_ty = [&] { - const Type lhs_ty = - TritonType(b_, dot_instr_->operand(0)->shape().element_type()); - const Type rhs_ty = - TritonType(b_, dot_instr_->operand(1)->shape().element_type()); - CHECK(lhs_ty == rhs_ty); - return lhs_ty; - }(); - // TODO(b/266862493): Accumulator can be integer too. - // Otherwise only f64 x f64 -> f64 uses f64 accumulator. - return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() - : b_.getF32Type(); + const PrecisionConfig::Algorithm algorithm = + dot_instr_->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type()); + // Data type of dot() immediate inputs. + Type dot_input_ty = [&] { + const Type lhs_ty = + TritonType(b_, dot_instr_->operand(0)->shape().element_type()); + const Type rhs_ty = + TritonType(b_, dot_instr_->operand(1)->shape().element_type()); + CHECK(lhs_ty == rhs_ty); + return lhs_ty; + }(); + // TODO(b/266862493): Accumulator can be integer too. + // Otherwise only f64 x f64 -> f64 uses f64 accumulator. + return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() + : b_.getF32Type(); + } + + absl::StatusOr accum_type = + algorithm_util::GetDotAccumulatorType(algorithm); + CHECK(accum_type.ok()) << "Unexpected algorithm: " + << PrecisionConfig::Algorithm_Name(algorithm); + Type mlir_accum_type = TritonType(b_, accum_type.value()); + if (auto float_accum_type = mlir_accum_type.dyn_cast()) { + return float_accum_type; + } + LOG(FATAL) << "Only floating point accumulator types are supported for " + "now, but we got: " + << llvm_ir::DumpToString(mlir_accum_type); } std::vector EpiloguePostOrderTransitiveOperands( @@ -1179,15 +1361,16 @@ class MatMulEmitterHelper { Value MakeInput(Side& side, int64_t operand_index, absl::flat_hash_map& values) { return *EmitScope( - b_, libdevice_path_, &analysis_, side.scope, side.tiled_dims, + b_, libdevice_path_, device_info_, &analysis_, side.scope, + side.tiled_dims, dot_instr_->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr_->operand(operand_index))), values); } - StatusOr EmitTensorPointer(const HloInstruction* hlo, const Side& side, - ValueRange bases, Value pid_k, - std::vector& boundary_checks) { + absl::StatusOr EmitTensorPointer( + const HloInstruction* hlo, const Side& side, ValueRange bases, + Value pid_k, std::vector& boundary_checks) { // Parameters of MakeTensorPtrOp to be generated by this function. Value base; std::vector bounds; @@ -1231,35 +1414,33 @@ class MatMulEmitterHelper { } LOG(FATAL) << "Missing dimension."; }(); - CHECK_EQ(bases.size(), hlo->operand_count()); + TF_RET_CHECK(bases.size() == hlo->operand_count()); concat_boundaries.reserve(hlo->operand_count() - 1); - int64_t accumulated_size = 0; for (int i = 0; i < hlo->operand_count() - 1; ++i) { - const int64_t operand_size = + const TensorIterationSpec::IterationSpecFragment& fragment = analysis_.IterSpec(side.scope, hlo->operand(i), concat_dim_idx) - ->at(0) - .count; - if (operand_size % properties.block_size != 0) { + ->at(0); + if (fragment.sliced_count % properties.block_size != 0) { return UncompilableMatmul( "Operand is not divisible by the block size."); } - accumulated_size += operand_size; - concat_boundaries.push_back(Cst32(accumulated_size)); + concat_boundaries.push_back( + Cst32(-fragment.slice_start + fragment.sliced_count)); } concat_dim_pid_offset = b_.create(properties.pid, Cst32(properties.block_size)); - base = - EmitMultiSelect(b_, concat_dim_pid_offset, concat_boundaries, bases); + TF_ASSIGN_OR_RETURN(base, EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, bases)); } else { concat_dim_idx = -1; base = bases[0]; } - auto add_dim = [&](const DimProperties& properties) { + auto add_dim = [&](const DimProperties& properties) -> absl::Status { if (analysis_.IterSpec(side.scope, hlo, properties.index) == nullptr) { - return; + return absl::OkStatus(); } Value pid_offset = (properties.pid == nullptr) @@ -1285,25 +1466,29 @@ class MatMulEmitterHelper { specs.push_back( analysis_.IterSpec(side.scope, input, properties.index)); input_strides.push_back(Cst64(specs.back()->at(0).stride)); - input_offsets.push_back(b_.create( - pid_offset, input_offsets.empty() - ? Cst32(0) - : concat_boundaries[input_offsets.size() - 1])); + input_offsets.push_back(b_.create( + pid_offset, Cst32(specs.back()->at(0).slice_start))); input_bounds.push_back(Cst64(specs.back()->at(0).count)); } - strides.push_back(EmitMultiSelect(b_, concat_dim_pid_offset, - concat_boundaries, input_strides)); + TF_ASSIGN_OR_RETURN(Value select_value, + EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, input_strides)); + strides.push_back(select_value); if (properties.index == concat_dim_idx) { - block_offsets.push_back( + TF_ASSIGN_OR_RETURN( + select_value, EmitMultiSelect(b_, pid_offset, concat_boundaries, input_offsets)); - bounds.push_back( + block_offsets.push_back(select_value); + TF_ASSIGN_OR_RETURN( + select_value, EmitMultiSelect(b_, pid_offset, concat_boundaries, input_bounds)); + bounds.push_back(select_value); } else { block_offsets.push_back(pid_offset); - int64_t count = specs.back()->at(0).count; + int64_t count = specs.front()->at(0).count; if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && properties.index == dims_.out_lhs_noncontracting_dim_idx && - specs.back()->size() == 1 && + specs.front()->size() == 1 && dims_.lhs_noncontracting_split.has_value()) { // Dimension of the output produced by the non-contracting LHS one // is logically split, major part is addressed using pid_batch. @@ -1314,50 +1499,81 @@ class MatMulEmitterHelper { boundary_checks.push_back(bounds.size() - 1); } } - tensor_offsets.push_back(Cst32(specs.back()->at(0).slice_start)); + tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); block_dims.push_back(properties.block_size); dim_order.emplace(dim_order.begin(), dim_order.size()); + return absl::OkStatus(); }; for (const DimProperties& dim : side.tiled_dims) { - add_dim(dim); + TF_RETURN_IF_ERROR(add_dim(dim)); } - int64_t stride_batch = 0; int64_t offset_batch = 0; - if (side.scope != TritonFusionAnalysis::Scope::RHS && - dims_.lhs_noncontracting_split) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis_.IterSpec(side.scope, hlo, side.tiled_dims[0].index); - if (spec != nullptr) { - if (spec->size() > 1) { - // Support one specific kind of output transpose that splits the - // dimension originating from the split LHS non-contracting one. - stride_batch = spec->at(1).stride; - } else { - // Because the major part of the split is implemented using the - // batch logic stride_batch is populated here as the stride of - // the minor part times its size. - stride_batch = spec->at(0).stride * - (spec->at(0).count / *dims_.lhs_noncontracting_split); + bool has_batch_offset = false; + Value batch_stride; + + // Return the batch stride of the HLO passed as a parameter. If the + // parameter HLO has no batch dimension, a zero stride is returned. + // Also sets offset_batch and updates has_batch_offset as a side effect. + auto get_batch_stride = + [this, &side, &offset_batch, &has_batch_offset]( + const HloInstruction* hlo_param) -> absl::StatusOr { + int64_t stride_batch = 0; + if (side.scope != TritonFusionAnalysis::Scope::RHS && + dims_.lhs_noncontracting_split) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis_.IterSpec(side.scope, hlo_param, side.tiled_dims[0].index); + if (spec != nullptr) { + if (spec->size() > 1) { + // Support one specific kind of output transpose that splits the + // dimension originating from the split LHS non-contracting one. + stride_batch = spec->at(1).stride; + } else { + // Because the major part of the split is implemented using the + // batch logic stride_batch is populated here as the stride of + // the minor part times its size. + stride_batch = + spec->at(0).stride * + (spec->at(0).count / *dims_.lhs_noncontracting_split); + } + TF_RET_CHECK(stride_batch != 0); + } + } else if (side.batch_dim_idx.has_value()) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis_.IterSpec(side.scope, hlo_param, *side.batch_dim_idx); + if (spec != nullptr) { + stride_batch = spec->at(0).stride; + offset_batch = spec->at(0).slice_start; + TF_RET_CHECK(stride_batch != 0); } - CHECK_NE(stride_batch, 0); } - } else if (side.batch_dim_idx.has_value()) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis_.IterSpec(side.scope, hlo, *side.batch_dim_idx); - if (spec != nullptr) { - stride_batch = spec->at(0).stride; - offset_batch = spec->at(0).slice_start; - CHECK_NE(stride_batch, 0); + + has_batch_offset |= stride_batch != 0; + return Cst(stride_batch); + }; + + if (hlo->opcode() == HloOpcode::kConcatenate) { + std::vector batch_strides; + batch_strides.reserve(hlo->operands().size()); + for (const HloInstruction* operand : hlo->operands()) { + TF_ASSIGN_OR_RETURN(Value op_stride, get_batch_stride(operand)); + batch_strides.push_back(op_stride); } + TF_ASSIGN_OR_RETURN(batch_stride, + EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, batch_strides)); + } else { + TF_ASSIGN_OR_RETURN(batch_stride, get_batch_stride(hlo)); } - if (stride_batch != 0) { + + // Avoid generating logic to compute batch offset if unnecessary. + if (has_batch_offset) { Value pid_batch = b_.create(launch_config_.batch_program_id_dim); Value pid_offset_batch = b_.create( b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), - Cst(stride_batch)); + batch_stride); base = AddPtr(b_, base, pid_offset_batch); } @@ -1365,7 +1581,7 @@ class MatMulEmitterHelper { const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec( TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx); if (spec != nullptr) { - CHECK(pid_k != nullptr); + TF_RET_CHECK(pid_k != nullptr); base = AddPtr(b_, base, b_.create(ConvertScalar(pid_k), Cst(spec->at(0).stride))); @@ -1396,11 +1612,12 @@ class MatMulEmitterHelper { } Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); } - Value Cst32(int64_t v) { return CreateConst(b_, i32_ty_, v); } + Value Cst32(int32_t v) { return CreateConst(b_, i32_ty_, v); } Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); } ImplicitLocOpBuilder& b_; absl::string_view libdevice_path_; + const se::DeviceDescription& device_info_; const HloDotInstruction* dot_instr_; Type index_ty_; TritonFusionAnalysis analysis_; @@ -1412,16 +1629,17 @@ class MatMulEmitterHelper { } // namespace -LaunchDimensions GetMatMulLaunchDimensions(const TritonFusionAnalysis& analysis, - const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { +absl::StatusOr GetMatMulLaunchDimensions( + const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, + const TritonGemmConfig& config) { auto dot = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { return node.opcode() == HloOpcode::kDot; }); - CHECK(dot != std::nullopt); + TF_RET_CHECK(dot != std::nullopt); const auto& dot_instr = *static_cast(&dot->instruction()); - MatMulDims dims(config, dot_instr, analysis); + TF_ASSIGN_OR_RETURN(MatMulDims dims, + MatMulDims::Create(config, dot_instr, analysis)); MatMulLaunchConfig launch_config(config, dot_instr, dims); return launch_config.launch_dims; } @@ -1466,13 +1684,241 @@ ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, return result; } +// Truncates |input| of F32 type to the number representable in Bf16 toward +// zero. +// It is used for Emit6xBfloat16MatMul. +Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { + ShapedType input_type = input.getType().dyn_cast(); + Type input_type_as_i32 = input_type.clone(b.getI32Type()); + Value input_as_i32 = b.create(input_type_as_i32, input); + Value mask = CreateConst(b, b.getI32Type(), 0xFFFF0000u, + input_type.getShape()); + Value high_bits = b.create(input_type_as_i32, input_as_i32, mask); + + return b.create(input_type, high_bits); +} + +// Finds the middle 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftMiddleEight(ImplicitLocOpBuilder& b, Value input) { + Value high = TruncateToBF16TowardsZero(b, input); + return b.create(input, high); +} + +// Finds the low 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { + // Find the middle bits of the middle bits, and these are the low eight + // bits. + return SoftMiddleEight(b, SoftMiddleEight(b, input)); +} + +// Rounds |input| to BF16 type. +// It is used for Emit6xBfloat16MatMul. +Value RoundToBF16(ImplicitLocOpBuilder& b, Value input) { + return Cast(b, input, b.getBF16Type()); +} + +// Checks |input| is finite f32 (not Nan and not infinite). +// It is used for Emit6xBfloat16MatMul and Emit3xBfloat16MatMul. +Value CheckFiniteF32(ImplicitLocOpBuilder& b, Value input) { + Value positive_inf = CreateConst( + b, b.getF32Type(), std::numeric_limits::infinity(), + input.getType().cast().getShape()); + Value abs_input = b.create(input); + return b.create(ma::CmpFPredicate::OGT, positive_inf, abs_input); +} + +// Leverages BF16 datatype for F32 matmul computation. It follows the guidance +// from https://arxiv.org/pdf/1904.06376.pdf. +absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, + Value rhs, Value acc) { + Type f32 = b.getF32Type(); + TF_RET_CHECK(lhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(rhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(acc.getType().cast().getElementType() == f32); + + Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); + Value lhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, lhs))); + Value lhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, lhs))); + + Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); + Value rhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, rhs))); + Value rhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, rhs))); + + auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, + Value accumulator) -> Value { + return b.create(lhs_bf16, rhs_bf16, accumulator, + /*allowTF32=*/false, + /*maxNumImpreciseAcc=*/0); + }; + + Value local_acc = ZerosLike(b, acc); + Value result = bf16_dot(lhs_middle, rhs_middle, local_acc); + result = bf16_dot(lhs_low, rhs_high, result); + result = bf16_dot(lhs_high, rhs_low, result); + result = bf16_dot(lhs_middle, rhs_high, result); + result = bf16_dot(lhs_high, rhs_middle, result); + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, we + // must override any accumulated result if the last partial product is + // non-finite. See b/115844437. + Value is_finite = CheckFiniteF32(b, result); + result = b.create(is_finite, result, ZerosLike(b, result)); + result = bf16_dot(lhs_high, rhs_high, result); + result = b.create(acc, result); + return result; +} + +// Compute F32 matmul with 3 BF16 dots. It is less accurate than +// Emit6xBfloat16MatMul. +absl::StatusOr Emit3xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, + Value rhs, Value acc) { + Type f32 = b.getF32Type(); + TF_RET_CHECK(lhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(rhs.getType().cast().getElementType() == f32); + TF_RET_CHECK(acc.getType().cast().getElementType() == f32); + + Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); + Value lhs_low = RoundToBF16(b, SoftMiddleEight(b, lhs)); + + Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); + Value rhs_low = RoundToBF16(b, SoftMiddleEight(b, rhs)); + + auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, + Value accumulator) -> Value { + return b.create(lhs_bf16, rhs_bf16, accumulator, + /*allowTF32=*/false, + /*maxNumImpreciseAcc=*/0); + }; + + Value local_acc = ZerosLike(b, acc); + Value result = bf16_dot(lhs_low, rhs_high, local_acc); + result = bf16_dot(lhs_high, rhs_low, result); + Value is_finite = CheckFiniteF32(b, result); + result = b.create(is_finite, result, ZerosLike(b, result)); + result = bf16_dot(lhs_high, rhs_high, result); + result = b.create(acc, result); + return result; +} + +namespace { + +bool IsTf32Allowed(const HloDotInstruction* dot_instr) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + return tsl::tensor_float_32_execution_enabled() && + absl::c_none_of(dot_instr->precision_config().operand_precision(), + [](const int precision) { + return precision != PrecisionConfig::DEFAULT; + }); + } + + return algorithm_util::HasTf32InputType(algorithm); +} + +bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, + mlir::OpBuilder& builder, Value dot_input_lhs, + Value dot_input_rhs, + const se::DeviceDescription& device_info) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + const HloModule* hlo_module = dot_instr->GetModule(); + Type f32 = builder.getF32Type(); + return hlo_module->config() + .debug_options() + .xla_gpu_enable_bf16_6way_gemm() && + dot_input_lhs.getType().cast().getElementType() == f32 && + dot_input_rhs.getType().cast().getElementType() == f32; + } + + return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6; +} + +bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, + mlir::OpBuilder& builder, Value dot_input_lhs, + Value dot_input_rhs, + const se::DeviceDescription& device_info) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + const HloModule* hlo_module = dot_instr->GetModule(); + Type f32 = builder.getF32Type(); + return hlo_module->config() + .debug_options() + .xla_gpu_enable_bf16_3way_gemm() && + dot_input_lhs.getType().cast().getElementType() == f32 && + dot_input_rhs.getType().cast().getElementType() == f32; + } + + return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3; +} + +// This is a heuristic that serves as a proxy for register usage and code size. +// +// We have noticed that tilings with very long LLVM IR code are both slow to +// compile and slow to run. This can be for example due to register spills. So +// we should skip these tilings to save time. But it's better to skip them +// before the LLVM IR is generated. To do that, we came up with a formula that +// strongly correlates with the LLVM IR size. The formula is the size of the two +// input and the output thread block tiles divided by the number of warps. We +// read https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ as a +// reference, and found the formula by trial and error. +// +// To regenerate the limit, we have to run an exhaustive search on all tilings +// for a few different HLOs, printing the runtimes and the heuristic values. +// +// From that, we can find a limit, such that all tilings within alpha * +// optimal_runtime have a heuristic value less than or equal to the limit. +// +// In our measurements, all tilings which were within 1.13 * optimal_runtime had +// a complexity_heuristic_value <= kComplexityHeuristicLimit. +// +// See go/tiling-heuristic for more details. +absl::Status CheckGemmTilingComplexityHeuristic( + const TritonGemmConfig& config) { + constexpr int64_t kComplexityHeuristicLimit = 9000; + int64_t complexity_heuristic_value = + (config.block_m * config.block_n + + (config.block_m + config.block_n) * config.block_k) / + config.num_warps; + VLOG(2) << "Complexity heuristic: " << complexity_heuristic_value; + if (complexity_heuristic_value > kComplexityHeuristicLimit) { + return ResourceExhausted("Tiling complexity heuristic exceeded: %d > %d", + complexity_heuristic_value, + kComplexityHeuristicLimit); + } + return absl::OkStatus(); +} + +} // namespace + // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, - const TritonFusionAnalysis& analysis, - const HloComputation* computation, mlir::triton::FuncOp fn, - const TritonGemmConfig& config, int shmem_budget) { - const HloDotInstruction* dot_instr = DynCast( - hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot)); +absl::Status EmitMatMul(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const TritonFusionAnalysis& analysis, + const HloComputation* computation, + mlir::triton::FuncOp fn, + const TritonGemmConfig& config) { + TF_RETURN_IF_ERROR(CheckGemmTilingComplexityHeuristic(config)); + + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const HloDotInstruction* dot_instr = DynCast(instr); + TF_RET_CHECK(!dot_instr->sparse_operands()); // Use 32-bit indexing if addressing any of the inputs or the output (which // could grow if split_k is set) does not cross the INT_MAX boundary. // Otherwise, fall back to 64-bit indexing, which is slower. @@ -1483,7 +1929,22 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32); const HloInstruction* root = dot_instr->parent()->root_instruction(); - CHECK(!root->shape().IsTuple()); + TF_RET_CHECK(!root->shape().IsTuple()); + + HloInstructionAdaptor instr_adaptor{*instr}; + auto fusion_adaptor = HloFusionAdaptor::ForComputation(computation); + // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. + bool is_8_bit_or_less_dot_with_F32 = HloAnyOf( + instr_adaptor.GetOperands(), *fusion_adaptor, + [&](HloInstructionAdaptor node) { + if (node.opcode() != HloOpcode::kConvert) { + return false; + } + Type in_type = + TritonType(builder, node.GetOperand(0).shape().element_type()); + Type out_type = TritonType(builder, node.shape().element_type()); + return in_type.getIntOrFloatBitWidth() <= 8 && out_type.isF32(); + }); // We'll be creating a lot of instructions from a single dot, use an // implicit loc builder so we don't have to pass around the location all the @@ -1492,18 +1953,19 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, ImplicitLocOpBuilder b(loc, builder); Type i32_ty = b.getI32Type(); - ValidateMatMulConfig(config, *dot_instr); + TF_RETURN_IF_ERROR(ValidateMatMulConfig(config, *dot_instr)); const int split_k = config.split_k; const int block_m = config.block_m; const int block_k = config.block_k; const int block_n = config.block_n; - const MatMulDims dims(config, *dot_instr, analysis); + TF_ASSIGN_OR_RETURN(const MatMulDims dims, + MatMulDims::Create(config, *dot_instr, analysis)); const MatMulLaunchConfig launch_config(config, *dot_instr, dims); VLOG(6) << analysis.ToString(); - MatMulEmitterHelper emitter(libdevice_path, dot_instr, b, index_ty, dims, - launch_config, analysis); + MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, + index_ty, dims, launch_config, analysis); constexpr int group_m = 8; const int64_t width = group_m * launch_config.grid_n; @@ -1567,6 +2029,7 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, iter_args_next.reserve(iter_args.size()); absl::flat_hash_map values_lhs; absl::flat_hash_map values_rhs; + // Load tiles of all parameters of LHS and RHS scopes and advance pointers. for (int i = 0; i < iter_args.size() - 1; ++i) { const bool is_lhs = @@ -1645,21 +2108,38 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, dot_input_rhs = apply_mask(1, dot_input_rhs); } - const bool allow_tf32 = - tsl::tensor_float_32_execution_enabled() && - absl::c_none_of(dot_instr->precision_config().operand_precision(), - [](const int precision) { - return precision != PrecisionConfig::DEFAULT; - }); - - // Execute matrix multiplication of input tiles and pass the accumulator. - // TODO(manany): Should be looked into once we enable Hopper workloads. - // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a - // lower precision than the output type. The change was introduced here: - // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a - Value accumulator_next = b.create(dot_input_lhs, dot_input_rhs, - iter_args.back(), allow_tf32, - /*maxNumImpreciseAcc=*/0); + const HloModule* hlo_module = dot_instr->GetModule(); + if (hlo_module->config().debug_options().xla_gpu_enable_bf16_3way_gemm() && + hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm()) { + LOG(WARNING) << "Both BF16 6way gemm and 3way gemm are enabled." + << " Fallback to BF16 6way gemm."; + } + + Value accumulator_next; + if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, + device_info)) { + absl::StatusOr accumulator_next_or = Emit6xBfloat16MatMul( + b, dot_input_lhs, dot_input_rhs, iter_args.back()); + TF_CHECK_OK(accumulator_next_or.status()); + accumulator_next = accumulator_next_or.value(); + } else if (Is3xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, + device_info)) { + absl::StatusOr accumulator_next_or = Emit3xBfloat16MatMul( + b, dot_input_lhs, dot_input_rhs, iter_args.back()); + TF_CHECK_OK(accumulator_next_or.status()); + accumulator_next = accumulator_next_or.value(); + } else { + // Execute matrix multiplication of input tiles and pass the accumulator. + // TODO(manany): Should be looked into once we enable Hopper workloads. + // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a + // lower precision than the output type. The change was introduced here: + // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a + accumulator_next = + b.create(dot_input_lhs, dot_input_rhs, iter_args.back(), + /*allowTF32=*/IsTf32Allowed(dot_instr) && + !is_8_bit_or_less_dot_with_F32, + /*maxNumImpreciseAcc=*/0); + } iter_args_next.push_back(accumulator_next); b.create(iter_args_next); @@ -1674,7 +2154,8 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, for (const Side& side : {lhs, rhs}) { for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { - CHECK(iter_args_to_inputs.insert({iter_args.size(), input}).second); + TF_RET_CHECK( + iter_args_to_inputs.insert({iter_args.size(), input}).second); TF_ASSIGN_OR_RETURN(Value tensor_ptr, emitter.EmitTensorPointer( input, side, GetArguments(fn, *input), pid_k, @@ -1705,12 +2186,12 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, Value tensor_pointer, emitter.EmitTensorPointer(input, out, GetArguments(fn, *input), pid_k, boundary_checks)); - CHECK(values_out - .insert({input, - EmitParameterLoad(b, tensor_pointer, boundary_checks)}) - .second); + TF_RET_CHECK(values_out + .insert({input, EmitParameterLoad(b, tensor_pointer, + boundary_checks)}) + .second); } - TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, &analysis, + TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, device_info, &analysis, TritonFusionAnalysis::Scope::OUTPUT, out.tiled_dims, to_emit, values_out) .status()); @@ -1731,29 +2212,180 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } - return OkStatus(); + return absl::OkStatus(); } -LaunchDimensions GetSoftMaxLaunchDimensions(const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { - auto reduce = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { - return node.opcode() == HloOpcode::kReduce; - }); - CHECK(reduce != std::nullopt); - const Shape& reduce_input_shape = reduce->instruction().operand(0)->shape(); - int num_rows = 1; - for (int minor_axis = 1; minor_axis < reduce_input_shape.rank(); - ++minor_axis) { - num_rows *= reduce_input_shape.dimensions_minor(minor_axis); +// Computes the base pointer offset for the given pid and shape. +// `tile_offset_indexing` is a mapping from +// (program_id) -> [tile_offset0, ..., tile_offsetN] +Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, Value pid, + const Shape& shape, + const IndexingMap& tile_offset_indexing) { + ArrayRef dimension_exprs = + tile_offset_indexing.GetAffineMap().getResults(); + + mlir::AffineExpr linear_index = + mlir::getAffineConstantExpr(0, b.getContext()); + int64_t stride = 1; + for (int i : shape.layout().minor_to_major()) { + linear_index = linear_index + dimension_exprs[i] * stride; + stride *= shape.dimensions(i); + } + + return b.create( + b.getI64Type(), + mlir_converter::ApplyAffineExpr(linear_index, /*dims=*/pid, + /*symbols=*/{}, b)); +} + +absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + SymbolicTileAnalysis* analysis, + const HloComputation* computation, + mlir::triton::FuncOp fn) { + const HloInstruction* root = computation->root_instruction(); + auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name())); + ImplicitLocOpBuilder b(loc, builder); + + // Assumptions we make about the matcher: + // * matches Softmax "diamonds" on the last axis, along with any number of + // elementwise operations/bitcasts on any edge + // * within a given fusion, every argument to a Softmax diamond has the same + // shape + // * every reduction is on the last axis + // * the last axis of every reduction parameter has the same length + // * reductions only reduce a single operand + // * all the shapes have canonical layout (logical layout = physical layout) + // * the computation has a single output + // * we tile along a single dimension + + const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + + if (reduce == nullptr) { + return absl::InvalidArgumentError("No reduce instruction found."); + } + + const Shape& reduce_input_shape = reduce->operand(0)->shape(); + + if (reduce->dimensions().size() != 1 || + reduce->dimensions(0) != reduce_input_shape.rank() - 1) { + return absl::InvalidArgumentError( + absl::StrCat("Reduce instruction must reduce inner-most dimension. ", + reduce->ToString())); + } + + const Shape& root_shape = computation->root_instruction()->shape(); + if (!root_shape.IsArray() || + LayoutUtil::IsMonotonicWithDim0Minor(root_shape.layout())) { + return absl::InvalidArgumentError( + absl::StrCat("Root shape is not supported. ", root_shape.ToString())); + } + + int row_len = reduce_input_shape.dimensions_minor(0); + + Value pid = b.create( + b.getIndexType(), b.create(mt::ProgramIDDim::X)); + + std::vector output_tile_sizes( + computation->root_instruction()->shape().rank(), 1); + output_tile_sizes.back() = row_len; + + analysis->SetTileSizes(output_tile_sizes); + + // block_size must be a power of two. + int result_block_size = llvm::PowerOf2Ceil(row_len); + + std::vector boundary_checks; + if (result_block_size != row_len) { + boundary_checks.push_back(0); } - return {{num_rows, 1, 1}, {config.num_warps * WarpSize(), 1, 1}}; + // Emits load instructions + auto emit_param_load = [&](const SymbolicTiledHloInstruction& tiled_hlo) + -> absl::StatusOr { + std::vector tile_sizes, tile_strides, tile_offsets; + for (auto [size, stride, offset] : llvm::zip( + analysis->TileSizes(tiled_hlo), analysis->TileStrides(tiled_hlo), + analysis->TileOffsets(tiled_hlo))) { + if (size == 1) continue; + + tile_sizes.push_back(CreateConst(b, b.getI64Type(), size)); + tile_strides.push_back(CreateConst(b, b.getI64Type(), stride)); + tile_offsets.push_back(CreateConst(b, b.getI32Type(), offset)); + } + + TF_ASSIGN_OR_RETURN( + IndexingMap program_id_to_input_tile_indexing, + analysis->ComputeBlockIdToTileOffsetIndexing(tiled_hlo)); + + // Manually compute pointer offset to avoid materialized fully parallel + // dimensions in the tile. Current codegen tried to avoid size-1 dims. + Value ptr_offset = ComputeBasePtrOffset(b, pid, tiled_hlo.hlo()->shape(), + program_id_to_input_tile_indexing); + + auto fn_arg = fn.getArgument(tiled_hlo.hlo()->parameter_number()); + auto tile_ptr = AddPtr(b, fn_arg, ptr_offset); + + if (tile_sizes.empty()) { + return EmitParameterLoad(b, tile_ptr, boundary_checks); + } + + Value emitted_tensor = b.create( + /*base=*/tile_ptr, + /*shape=*/tile_sizes, + /*strides=*/tile_strides, + /*offsets=*/tile_offsets, + /*tensorShape=*/std::vector{result_block_size}, + /*order=*/std::vector{0}); + + return EmitParameterLoad(b, emitted_tensor, boundary_checks); + }; + + absl::flat_hash_map values_out; + TF_ASSIGN_OR_RETURN(Value result, + EmitTiledScope(b, libdevice_path, device_info, *analysis, + emit_param_load, values_out)); + + TF_ASSIGN_OR_RETURN( + IndexingMap program_id_to_output_tile_indexing, + analysis->ComputeBlockIdToTileOffsetIndexing(*analysis->GetRoot())); + + Value ptr_offset = ComputeBasePtrOffset(b, pid, root_shape, + program_id_to_output_tile_indexing); + + Value store_tensor = b.create( + /*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()), + ptr_offset), + /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), row_len)}, + /*strides=*/ValueRange{CreateConst(b, b.getI64Type(), 1)}, + /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), 0)}, + /*tensorShape=*/std::vector{result_block_size}, + /*order=*/std::vector{0}); + + b.create(store_tensor, result, std::vector{0}, + mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + + return absl::OkStatus(); } -Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, - const TritonFusionAnalysis& analysis, - const HloComputation* computation, mlir::triton::FuncOp fn, - const TritonGemmConfig& config, int) { +absl::Status EmitSoftMax(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const TritonFusionAnalysis& analysis, + const HloComputation* computation, + mlir::triton::FuncOp fn, + const TritonGemmConfig& config) { + SymbolicTileAnalysisOrError symbolic_tile_analysis_or = + SymbolicTileAnalysis::AnalyzeComputation(*computation, + builder.getContext()); + if (auto* symbolic_tile_analysis = + std::get_if(&symbolic_tile_analysis_or)) { + return EmitTiledSoftMax(builder, libdevice_path, device_info, + symbolic_tile_analysis, computation, fn); + } + const HloInstruction* root = computation->root_instruction(); auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name())); ImplicitLocOpBuilder b(loc, builder); @@ -1776,13 +2408,13 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode( *computation, HloOpcode::kReduce); - CHECK_NE(reduce, nullptr); + TF_RET_CHECK(reduce != nullptr); Shape reduce_input_shape = reduce->operand(0)->shape(); - CHECK_EQ(reduce->opcode(), HloOpcode::kReduce); - CHECK_EQ(reduce->dimensions().size(), 1); - CHECK_EQ(reduce->dimensions()[0], reduce_input_shape.rank() - 1); + TF_RET_CHECK(reduce->opcode() == HloOpcode::kReduce); + TF_RET_CHECK(reduce->dimensions().size() == 1); + TF_RET_CHECK(reduce->dimensions()[0] == reduce_input_shape.rank() - 1); int row_len = reduce_input_shape.dimensions_minor(0); @@ -1816,9 +2448,8 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, param, /*dimension=*/1); // Make sure only batch and reduce dims are present in tiling - CHECK_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, param, - /*dimension=*/2), - nullptr); + TF_RET_CHECK(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, param, + /*dimension=*/2) == nullptr); if (!reduce_iterspec) { // This parameter's broadcast is along the reduce dimension, and so @@ -1834,8 +2465,8 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, continue; } - CHECK_NE(reduce_iterspec, nullptr); - CHECK_EQ(reduce_iterspec->size(), 1); + TF_RET_CHECK(reduce_iterspec != nullptr); + TF_RET_CHECK(reduce_iterspec->size() == 1); // TODO(b/310721908): The below assumes that we tile along a single dim. int reduce_dim_len = reduce_iterspec->front().count; @@ -1848,13 +2479,14 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, Value base_offset = batch_iterspec ? row_offset : zero_offset; // We assume that the reduced axis of this parameter has length row_len. - CHECK_EQ(reduce_dim_len, row_len); + // TODO(b/316637896): Relax assumption that param reduce_dim_len == row_len. + TF_RET_CHECK(reduce_dim_len == row_len); // block_size must be a power of two. int block_size = pow(2, ceil(log(reduce_dim_len) / log(2))); // Verify that this param contains a single contiguous fragment. - CHECK_EQ(reduce_iterspec->front().subfragments.size(), 1); + TF_RET_CHECK(reduce_iterspec->front().subfragments.size() == 1); Value emitted_tensor = b.create( /*base=*/AddPtr(b, fn.getArgument(param_idx), base_offset), @@ -1874,7 +2506,7 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, /*index=*/0, pid, result_block_size, /*split_value=*/1)}; TF_ASSIGN_OR_RETURN( Value result, - EmitScope(b, libdevice_path, &analysis, + EmitScope(b, libdevice_path, device_info, &analysis, TritonFusionAnalysis::Scope::OUTPUT, tiled_dims, computation->MakeInstructionPostOrder(), values_out)); @@ -1889,24 +2521,25 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, b.create(store_tensor, result, std::vector{0}, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); - return OkStatus(); + return absl::OkStatus(); } // Simplified copy of translateLLVMToLLVMIR which in addition takes // path to libdevice directly as an argument. -StatusOr> TranslateLLVMToLLVMIR( +absl::StatusOr> TranslateLLVMToLLVMIR( llvm::LLVMContext* llvmContext, mlir::ModuleOp module, absl::string_view libdevice_path) { mlir::DialectRegistry registry; mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); + mlir::registerROCDLDialectTranslation(registry); module->getContext()->appendDialectRegistry(registry); std::unique_ptr llvmModule = mlir::translateModuleToLLVMIR(module, *llvmContext); if (!llvmModule) { - return InternalError("Failed to emit LLVM IR."); + return Internal("Failed to emit LLVM IR."); } // Link external libraries before performing optimizations. @@ -1919,29 +2552,20 @@ StatusOr> TranslateLLVMToLLVMIR( if (auto err = optPipeline(llvmModule.get())) { llvm::errs() << err; - return InternalError("Failed to optimize LLVM IR."); + return Internal("Failed to optimize LLVM IR."); } return llvmModule; } -namespace { - -std::string GetLibdevicePath(const HloComputation* hlo_computation) { - return nvptx::LibDevicePath(hlo_computation->parent() - ->config() - .debug_options() - .xla_gpu_cuda_data_dir()); -} - -} // namespace - -StatusOr> CreateTritonModule( +absl::StatusOr> CreateTritonModule( const TritonFusionAnalysis& analysis, absl::string_view fn_name, const HloComputation* hlo_computation, const se::DeviceDescription& device_info, const TritonGemmConfig& config, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) { - mlir_context.loadDialect(); + mlir_context.loadDialect(); + mlir::OpBuilder b(&mlir_context); auto loc = mlir::NameLoc::get(b.getStringAttr(hlo_computation->name())); mlir::OwningOpRef triton_module = @@ -1971,66 +2595,48 @@ StatusOr> CreateTritonModule( fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); - TF_RETURN_IF_ERROR(ir_emitter(b, GetLibdevicePath(hlo_computation), analysis, - hlo_computation, fn, config, - device_info.shared_memory_per_block_optin())); + TF_RETURN_IF_ERROR(ir_emitter( + b, GetLibdevicePath(hlo_computation->parent()->config(), device_info), + device_info, analysis, hlo_computation, fn, config)); b.create(loc); + mlir::PassManager pm(&mlir_context); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + TF_RET_CHECK(pm.run(triton_module.get()).succeeded()); + VLOG(6) << llvm_ir::DumpToString(*triton_module); if (DumpingEnabledForHloModule(*hlo_computation->parent())) { DumpToFileInDirOrStdout(*hlo_computation->parent(), "triton_ir", "ttir", llvm_ir::DumpToString(*triton_module)); } - CHECK(mlir::succeeded(mlir::verify(*triton_module))); + TF_RET_CHECK(mlir::succeeded(mlir::verify(*triton_module))); return std::move(triton_module); } -StatusOr TritonWrapper( +absl::StatusOr TritonWrapper( const TritonFusionAnalysis& analysis, absl::string_view fn_name, - const HloComputation* hlo_computation, absl::string_view fusion_kind, - const se::CudaComputeCapability& cc, + const HloComputation* hlo_computation, const se::GpuComputeCapability& cc, const se::DeviceDescription& device_info, const TritonGemmConfig& config, llvm::Module* llvm_module, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) { - if (fusion_kind == kTritonGemmFusionKind) { - // This is a heuristic that serves as a proxy for register usage and code - // size. - // - // We have noticed that tilings with very long LLVM IR code are both slow to - // compile and slow to run. This can be for example due to register spills. - // So we should skip these tilings to save time. But it's better to skip - // them before the LLVM IR is generated. To do that, we came up with a - // formula that strongly correlates with the LLVM IR size. The formula is - // the size of the two input and the output thread block tiles divided by - // the number of warps. We read - // https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ as a - // reference, and found the formula by trial and error. - // - // To regenerate the limit, we have to run an exhaustive search on all - // tilings for a few different HLOs, printing the runtimes and the heuristic - // values. - // From that, we can find a limit, such that all tilings within alpha * - // optimal_runtime have a heuristic value less than or equal to the limit. - // - // In our measurements, all tilings which were within 1.13 * optimal_runtime - // had a complexity_heuristic_value <= kComplexityHeuristicLimit. - // - // See go/tiling-heuristic for more details. - constexpr int64_t kComplexityHeuristicLimit = 9000; - int64_t complexity_heuristic_value = - (config.block_m * config.block_n + - (config.block_m + config.block_n) * config.block_k) / - config.num_warps; - VLOG(2) << "Complexity heuristic: " << complexity_heuristic_value; - if (complexity_heuristic_value > kComplexityHeuristicLimit) { - return ResourceExhausted("Tiling complexity heuristic exceeded: %d > %d", - complexity_heuristic_value, - kComplexityHeuristicLimit); + if (std::holds_alternative(cc)) { + auto ccCuda = std::get(cc); + if (!ccCuda.IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); } } + auto debug_options = GetDebugOptionsFromFlags(); + if (debug_options.xla_gpu_enable_triton_hopper()) { + // Set environment variables for consumption by Triton. + tsl::setenv("ENABLE_MMA_V3", "true", true /*overwrite*/); + tsl::setenv("ENABLE_PIPELINING", "true", true /*overwrite*/); + } + TF_ASSIGN_OR_RETURN( auto triton_module, CreateTritonModule(analysis, fn_name, hlo_computation, device_info, @@ -2040,12 +2646,29 @@ StatusOr TritonWrapper( VLOG(2) << config.ToString(); // Compile Triton kernel to LLVM. - std::optional log_stream; const HloModule* hlo_module = hlo_computation->parent(); + return CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), cc, + device_info, config, triton_module.get(), + llvm_module, mlir_context); +} + +// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction. +absl::StatusOr CompileTritonToLLVM( + const HloModuleConfig& hlo_config, absl::string_view hlo_module_name, + const se::GpuComputeCapability& cc, + const se::DeviceDescription& device_info, const TritonGemmConfig& config, + mlir::ModuleOp triton_module, llvm::Module* llvm_module, + mlir::MLIRContext& mlir_context) { + if (std::holds_alternative(cc)) { + auto ccCuda = std::get(cc); + if (!ccCuda.IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + } bool should_verify = - (hlo_module->config().debug_options().xla_gpu_llvm_verification_level() >= - 1); + (hlo_config.debug_options().xla_gpu_llvm_verification_level() >= 1); #ifndef NDEBUG should_verify = true; #endif @@ -2053,13 +2676,14 @@ StatusOr TritonWrapper( mlir::PassManager pm(&mlir_context); pm.enableVerifier(should_verify); - if (hlo_module->config().debug_options().xla_gpu_dump_llvmir()) { + std::optional log_stream; + if (hlo_config.debug_options().xla_gpu_dump_llvmir()) { const std::string basename = - absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module->name())), + absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module_name)), ".triton-passes.log"); std::string outputs_dir; if (!tsl::io::GetTestUndeclaredOutputsDir(&outputs_dir)) { - outputs_dir = hlo_module->config().debug_options().xla_dump_to(); + outputs_dir = hlo_config.debug_options().xla_dump_to(); } if (!outputs_dir.empty()) { std::string path = tsl::io::JoinPath(outputs_dir, basename); @@ -2083,7 +2707,13 @@ StatusOr TritonWrapper( } } - CreateTritonPipeline(pm, cc, config.num_warps, config.num_stages); + // Lower affine expressions into arithmetic ops. + pm.addPass(mlir::createLowerAffinePass()); + + mlir::triton::nvidia_gpu::ClusterInfo cluster_info; + if (!CreateTritonPipeline(pm, cc, config, /*out*/ cluster_info).ok()) { + return Internal("Failed to create Triton pipeline."); + } if (log_stream.has_value()) { pm.printAsTextualPipeline(log_stream.value()); log_stream->write("\n\n", 2); @@ -2094,22 +2724,22 @@ StatusOr TritonWrapper( // llvm::Linker::linkModules() segfaults if we don't strip locations. pm.addPass(mlir::createStripDebugInfoPass()); - bool succeeded = mlir::succeeded(pm.run(*triton_module)); + bool succeeded = mlir::succeeded(pm.run(triton_module)); if (log_stream.has_value()) { log_stream->flush(); } if (!succeeded) { - return InternalError("Failed to compile Triton kernel."); + return Internal("Failed to compile Triton kernel."); } const int shared_mem_bytes = - (*triton_module) - ->getAttrOfType("triton_gpu.shared") + triton_module->getAttrOfType("triton_gpu.shared") .getInt(); VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B"; - if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) { + if (std::holds_alternative(cc) && + shared_mem_bytes > device_info.shared_memory_per_block_optin()) { return absl::ResourceExhaustedError(absl::StrFormat( "Shared memory size limit exceeded: requested %d, available: %d", shared_mem_bytes, device_info.shared_memory_per_block_optin())); @@ -2117,8 +2747,8 @@ StatusOr TritonWrapper( TF_ASSIGN_OR_RETURN( std::unique_ptr ll_triton_module, - TranslateLLVMToLLVMIR(&llvm_module->getContext(), *triton_module, - GetLibdevicePath(hlo_computation))); + TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module, + GetLibdevicePath(hlo_config, device_info))); VLogModule(5, *ll_triton_module); if (should_verify) { VerifyModule(*ll_triton_module); @@ -2130,14 +2760,32 @@ StatusOr TritonWrapper( ll_triton_module->setDataLayout(llvm_module->getDataLayout()); ll_triton_module->setTargetTriple(llvm_module->getTargetTriple()); // Use override flag because libdevice functions can be present in both. - CHECK(!llvm::Linker::linkModules(*llvm_module, std::move(ll_triton_module), - llvm::Linker::Flags::OverrideFromSrc)); + TF_RET_CHECK( + !llvm::Linker::linkModules(*llvm_module, std::move(ll_triton_module), + llvm::Linker::Flags::OverrideFromSrc)); VLogModule(5, *llvm_module); if (should_verify) { VerifyModule(*llvm_module); } - return {{shared_mem_bytes}}; + // `cluster_info` must be read after pm.run(). + std::optional cluster_dim; + if (config.num_ctas > 1) { + VLOG(3) << "num_ctas: " << config.num_ctas + << ", cluster_info: " << cluster_info.clusterDimX << "," + << cluster_info.clusterDimY << "," << cluster_info.clusterDimZ; + if (cluster_info.clusterDimX > 1 || cluster_info.clusterDimY > 1 || + cluster_info.clusterDimZ > 1) { + cluster_dim = + se::ClusterDim(cluster_info.clusterDimX, cluster_info.clusterDimY, + cluster_info.clusterDimZ); + } + } else { + TF_RET_CHECK(cluster_info.clusterDimX == 1 && + cluster_info.clusterDimY == 1 && + cluster_info.clusterDimZ == 1); + } + return {{shared_mem_bytes, cluster_dim}}; } } // namespace gpu diff --git a/xla/service/gpu/ir_emitter_triton.h b/xla/service/gpu/ir_emitter_triton.h index 8a6e233c17f02..91cd73e70da7c 100644 --- a/xla/service/gpu/ir_emitter_triton.h +++ b/xla/service/gpu/ir_emitter_triton.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,74 +16,114 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_IR_EMITTER_TRITON_H_ #define XLA_SERVICE_GPU_IR_EMITTER_TRITON_H_ +#include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" -#include "xla/statusor.h" +#include "xla/service/hlo_module_config.h" +#include "xla/status.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" namespace xla { namespace gpu { +namespace mt = ::mlir::triton; + struct TritonWrapperResult { - int64_t shmem_bytes; + int64_t shmem_bytes = 0; + std::optional cluster_dim; }; // Compute the launch dimensions for the given Triton MatMul. -LaunchDimensions GetMatMulLaunchDimensions(const TritonFusionAnalysis& analysis, - const HloFusionAdaptor& fusion, - const TritonGemmConfig& config); +absl::StatusOr GetMatMulLaunchDimensions( + const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, + const TritonGemmConfig& config); // Use tiling and execution parameters from 'config'. -Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, - const TritonFusionAnalysis& analysis, - const HloComputation* computation, mlir::triton::FuncOp fn, - const TritonGemmConfig& config, int shmem_budget); +absl::Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const TritonFusionAnalysis& analysis, + const HloComputation* computation, + mlir::triton::FuncOp fn, + const TritonGemmConfig& config); // Compute the launch dimensions for the given Triton SoftMax. LaunchDimensions GetSoftMaxLaunchDimensions(const HloFusionAdaptor& fusion, const TritonGemmConfig& config); // Generate Softmax in Triton IR inside 'fn'. // Use execution parameters from 'config'. -Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, - const TritonFusionAnalysis& analysis, - const HloComputation* computation, mlir::triton::FuncOp fn, - const TritonGemmConfig& config, int shmem_budget); +absl::Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const TritonFusionAnalysis& analysis, + const HloComputation* computation, + mlir::triton::FuncOp fn, + const TritonGemmConfig& config); using TritonIrEmitter = std::function; + mlir::OpBuilder, absl::string_view, const se::DeviceDescription&, + const TritonFusionAnalysis& analysis, const HloComputation*, + mlir::triton::FuncOp, const TritonGemmConfig&)>; // Generate Triton IR by running the provided generator and compile it into LLVM // IR. // MatMul and SoftMax above are some such IR generators. -StatusOr TritonWrapper( +absl::StatusOr TritonWrapper( const TritonFusionAnalysis& analysis, absl::string_view fn_name, - const HloComputation* hlo_computation, absl::string_view fusion_kind, - const se::CudaComputeCapability& cc, + const HloComputation* hlo_computation, const se::GpuComputeCapability& cc, const se::DeviceDescription& device_info, const TritonGemmConfig& config, llvm::Module* llvm_module, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context); // Creates the initial Triton module for the given fusion. Visible for testing, // use TritonWrapper instead. -StatusOr> CreateTritonModule( +absl::StatusOr> CreateTritonModule( const TritonFusionAnalysis& analysis, absl::string_view fn_name, const HloComputation* hlo_computation, const se::DeviceDescription& device_info, const TritonGemmConfig& config, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context); +// Compiles a given Triton module to LLVM IR. +absl::StatusOr CompileTritonToLLVM( + const HloModuleConfig& hlo_config, absl::string_view hlo_module_name, + const se::GpuComputeCapability& cc, + const se::DeviceDescription& device_info, const TritonGemmConfig& config, + mlir::ModuleOp triton_module, llvm::Module* llvm_module, + mlir::MLIRContext& mlir_context); + +// Create Triton pipeline. +// +// `out_cluster_info` must be kept alive at least until pm.run() is called. +// It should be read after that. We have to pass the cluster dims to +// LaunchDimensions. Triton currently uses this as an out-parameter to return +// the cluster dims determined based on `config.num_ctas` and a heuristic. There +// are some signs that show that this was intended to be used as an in-out +// parameter which would give a hint to Triton which cluster dims we prefer to +// use, but that's not the case currently. +absl::Status CreateTritonPipeline( + mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + const TritonGemmConfig& config, + mt::nvidia_gpu::ClusterInfo& out_cluster_info); + +std::string GetLibdevicePath(const HloModuleConfig& hlo_config, + const se::DeviceDescription& device_info); + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/ir_emitter_triton_cuda.cc b/xla/service/gpu/ir_emitter_triton_cuda.cc new file mode 100644 index 0000000000000..30f5aeb5cfc0c --- /dev/null +++ b/xla/service/gpu/ir_emitter_triton_cuda.cc @@ -0,0 +1,115 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "absl/status/status.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace xla { +namespace gpu { + +namespace mt = ::mlir::triton; + +absl::Status CreateTritonPipeline( + mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + const TritonGemmConfig& config, + mt::nvidia_gpu::ClusterInfo& out_cluster_info) { + auto ccCuda = std::get(cc); + const int ccAsInt = ccCuda.major * 10 + ccCuda.minor; + const int threadsPerWarp = 32; + + // Based on make_ttir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mt::createRewriteTensorPointerPass()); + pm.addPass(mt::createCombineOpsPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mt::createReorderBroadcastPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(mlir::createSymbolDCEPass()); + + // Based on make_ttgir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mt::createConvertTritonToTritonGPUPass( + config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt)); + pm.addPass(mt::gpu::createCoalescePass()); + pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createOptimizeThreadLocalityPass()); + pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt)); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); + pm.addPass(mlir::createCSEPass()); + + pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps, + config.num_ctas, ccAsInt)); + + if (!ccCuda.IsAtLeastHopper()) { + pm.addPass(mt::gpu::createPrefetchPass()); + } + + pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createReduceDataDuplicationPass()); + pm.addPass(mt::gpu::createReorderInstructionsPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createSymbolDCEPass()); + if (ccCuda.IsAtLeastHopper()) { + pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); + } + pm.addPass(mlir::createCanonicalizerPass()); + + // Based on make_llir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createConvertIndexToLLVMPass()); + pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); + pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt)); + pm.addPass(mt::createConvertNVGPUToLLVMPass()); + pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createSymbolDCEPass()); + // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. + + return absl::OkStatus(); +} + +std::string GetLibdevicePath(const HloModuleConfig& hlo_config, + const se::DeviceDescription& device_info) { + return nvptx::LibDevicePath( + hlo_config.debug_options().xla_gpu_cuda_data_dir()); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/ir_emitter_triton_large_test.cc b/xla/service/gpu/ir_emitter_triton_large_test.cc index c5d98bb87f1d5..05a0b78a6d876 100644 --- a/xla/service/gpu/ir_emitter_triton_large_test.cc +++ b/xla/service/gpu/ir_emitter_triton_large_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/tests/hlo_test_base.h" @@ -43,7 +44,7 @@ ENTRY e { arg1 = f16[32800,32] parameter(1) gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + backend_config="{\"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" ROOT get-tuple-element = f16[65536,32] get-tuple-element((f16[65536,32], s8[0]) gemm), index=0 } )"; @@ -62,7 +63,7 @@ ENTRY e { p0 = f16[65536,32800] parameter(0) p1 = f16[32800,32] parameter(1) ROOT _ = f16[65536,32] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config="{kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"32\",\"block_n\":\"32\",\"block_k\":\"32\",\"split_k\":\"1\",\"num_stages\":\"1\",\"num_warps\":\"1\"}}" + backend_config="{\"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"32\",\"block_n\":\"32\",\"block_k\":\"32\",\"split_k\":\"1\",\"num_stages\":\"1\",\"num_warps\":\"1\",\"num_ctas\":\"1\"}}}" } )"; diff --git a/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 12b3f5205f9e5..3be360ae55437 100644 --- a/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,15 +58,19 @@ class MixedTypeTest : public GpuCodegenTest, ->GetDeviceDescription() .cuda_compute_capability(); } + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // We are testing Triton, remove cuBLAS fallback for these tests. + debug_options.set_xla_gpu_cublas_fallback(false); + // Always rewrite Gemms with Triton regardless of size. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } }; TEST_P(MixedTypeTest, MixedTypeDotProducesCorrectResult) { MixTypeParams params = GetParam(); - if ((params.lhs_ty == BF16 || params.rhs_ty == BF16) && - !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string hlo_string_template = R"( HloModule m @@ -127,7 +131,7 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, // TritonRewriteTest2Params{F32, F16}, // TritonRewriteTest2Params{F32, BF16}, MixTypeParams{S8, BF16, 24, 40, 8}, - MixTypeParams{S8, F16, 80, 16, 32}, + MixTypeParams{S8, F16, 80, 16, 32, 1e-3, 1e-6}, MixTypeParams{F16, F32, 127, 3, 300, 1e-2, 1e-2}, MixTypeParams{F16, BF16, 544, 96, 16, 1e-3, 1e-3}, MixTypeParams{BF16, F32, 77, 500, 333, 3e-3, 3e-3}, @@ -140,6 +144,8 @@ class TritonTest : public GpuCodegenTest { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(true); debug_options.set_xla_gpu_cublas_fallback(false); + // Always rewrite Gemms with Triton regardless of size. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); return debug_options; } @@ -191,10 +197,15 @@ ENTRY e { p0 = f32[15,33]{1,0} parameter(0) ROOT triton_gemm__ = f32[15,68]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"32","block_n":"32", - "block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4"}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"32", + "block_n":"32", + "block_k":"32", + "split_k":"1", + "num_stages":"1", + "num_warps":"4", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -213,11 +224,11 @@ ENTRY e { fusion = f32[33,68]{1,0} fusion(p1), kind=kLoop, calls=fused_computation gemm = (f32[15,68]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"} + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[15,68]{1,0} get-tuple-element((f32[15,68]{1,0}, s8[0]{0}) gemm), index=0 })"; const std::string hlo_ref = absl::Substitute( @@ -302,10 +313,15 @@ ENTRY e { p2 = $0[11,63]{1,0} parameter(2) ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"64","block_n":"32", - "block_k":"64","split_k":"1", - "num_stages":"2","num_warps":"2"}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"64", + "block_n":"32", + "block_k":"64", + "split_k":"1", + "num_stages":"2", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -326,11 +342,11 @@ ENTRY e { fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"} + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0 })"; const std::string hlo_ref = absl::Substitute( @@ -339,7 +355,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompareTwoModules( hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, - /*run_hlo_passes=*/false)); + /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/6)); } std::vector TestedBinaryElementwise(PrimitiveType element_type) { @@ -354,28 +370,28 @@ INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuitePRED, BinaryElementwiseTest, ::testing::Combine(::testing::Values(PRED), ::testing::ValuesIn(TestedBinaryElementwise(PRED)), - ::testing::Values(3e-2)), + ::testing::Values(0)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS8, BinaryElementwiseTest, ::testing::Combine(::testing::Values(S8), ::testing::ValuesIn(TestedBinaryElementwise(S8)), - ::testing::Values(3e-2)), + ::testing::Values(0)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS16, BinaryElementwiseTest, ::testing::Combine(::testing::Values(S16), ::testing::ValuesIn(TestedBinaryElementwise(S16)), - ::testing::Values(1e-3)), + ::testing::Values(0)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( ElementwiseTestSuiteS32, BinaryElementwiseTest, ::testing::Combine(::testing::Values(S32), ::testing::ValuesIn(TestedBinaryElementwise(S32)), - ::testing::Values(1e-5)), + ::testing::Values(0)), ElementwiseTestParamsToString); INSTANTIATE_TEST_SUITE_P( @@ -429,10 +445,15 @@ ENTRY e { p2 = $0[11,63]{1,0} parameter(2) ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), @@ -453,11 +474,11 @@ ENTRY e { fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"} + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0 })"; const std::string hlo_ref = absl::Substitute( @@ -516,28 +537,33 @@ TEST_P(SelectTest, SelectFusionExecutesCorrectly) { const std::string kHloTestTemplate = R"( triton_gemm___computation { - parameter_0 = $1[92,11]{1,0} parameter(0) - parameter_1 = $0[11,63]{1,0} parameter(1) - parameter_2 = $0[11,63]{1,0} parameter(2) - parameter_3 = pred[11,63]{1,0} parameter(3) - f1.1 = $0[11,63]{1,0} select(parameter_3, parameter_1, parameter_2) - c.1 = $1[11,63]{1,0} convert(f1.1) + parameter_0 = $1[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + f1.1 = $0[13,63]{1,0} select(parameter_3, parameter_1, parameter_2) + c.1 = $1[13,63]{1,0} convert(f1.1) ROOT _.1 = $1[92,63]{1,0} dot(parameter_0, c.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={HIGH, HIGH} } ENTRY e { - p0 = $1[92,11]{1,0} parameter(0) - p1 = $0[11,63]{1,0} parameter(1) - p2 = $0[11,63]{1,0} parameter(2) - p3 = pred[11,63]{1,0} parameter(3) + p0 = $1[92,13]{1,0} parameter(0) + p1 = $0[13,63]{1,0} parameter(1) + p2 = $0[13,63]{1,0} parameter(2) + p3 = pred[13,63]{1,0} parameter(3) ROOT triton_gemm__ = $1[92,63]{1,0} fusion(p0, p1, p2, p3), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1), @@ -545,60 +571,36 @@ ENTRY e { const std::string kHloRefTemplate = R"( fused_computation { - p0 = $0[11,63]{1,0} parameter(0) - p1 = $0[11,63]{1,0} parameter(1) - p2 = pred[11,63]{1,0} parameter(2) - f.1 = $0[11,63]{1,0} select(p2, p0, p1) - ROOT convert.1 = $1[11,63]{1,0} convert(f.1) + p0 = $0[13,63]{1,0} parameter(0) + p1 = $0[13,63]{1,0} parameter(1) + p2 = pred[13,63]{1,0} parameter(2) + f.1 = $0[13,63]{1,0} select(p2, p0, p1) + ROOT convert.1 = $1[13,63]{1,0} convert(f.1) } ENTRY e { - p3 = pred[11,63]{1,0} parameter(3) - p2 = $0[11,63]{1,0} parameter(2) - p1 = $0[11,63]{1,0} parameter(1) - p0 = $1[92,11]{1,0} parameter(0) - fusion = $1[11,63]{1,0} fusion(p1, p2, p3), kind=kLoop, + p3 = pred[13,63]{1,0} parameter(3) + p2 = $0[13,63]{1,0} parameter(2) + p1 = $0[13,63]{1,0} parameter(1) + p0 = $1[92,13]{1,0} parameter(0) + fusion = $1[13,63]{1,0} fusion(p1, p2, p3), kind=kLoop, calls=fused_computation gemm = ($1[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"} + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = $1[92,63]{1,0} get-tuple-element(($1[92,63]{1,0}, s8[0]{0}) gemm), index=0 })"; const std::string hlo_ref = absl::Substitute( kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1), primitive_util::LowercasePrimitiveTypeName(data_type2)); - float tolerance; - switch (data_type1) { - case F32: - tolerance = 1e-6; - break; - case BF16: - tolerance = 1e-6; - break; - case F16: - tolerance = 2e-4; - break; - case PRED: - case S8: - tolerance = 3e-2; - break; - case S16: - tolerance = 1e-3; - break; - case S32: - tolerance = 1e-5; - break; - default: - ABSL_UNREACHABLE(); - } EXPECT_TRUE(RunAndCompareTwoModules( - hlo_ref, hlo_test, ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}, - /*run_hlo_passes=*/false)); + hlo_ref, hlo_test, ErrorSpec{/*aabs=*/0, /*arel=*/0}, + /*run_hlo_passes=*/false, /*args_max_bits_of_precision=*/9)); } std::string TwoPrimitiveTypesToString( @@ -651,10 +653,15 @@ ENTRY e { p1 = f32[11,63]{1,0} parameter(1) ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", - "triton_gemm_config":{"block_m":"16","block_n":"64", - "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm", + "triton_gemm_config": + {"block_m":"16", + "block_n":"64", + "block_k":"16", + "split_k":"1", + "num_stages":"3", + "num_warps":"2", + "num_ctas":"1"}}} })"; const std::string hlo_test = absl::Substitute( kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); @@ -675,11 +682,11 @@ ENTRY e { calls=fused_computation gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"} + {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[92,63]{1, 0} get-tuple-element((f32[92,63]{1, 0}, s8[0]{0}) gemm), index=0 })"; const std::string hlo_ref = absl::Substitute( @@ -749,7 +756,7 @@ ENTRY e { p0 = $0[2,2] parameter(0) p1 = f32[2,2] parameter(1) ROOT r = f32[2,2] fusion(p0, p1), kind=kCustom, calls=t, - backend_config={"kind":"__triton_gemm"} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} })", primitive_util::LowercasePrimitiveTypeName(data_type1), primitive_util::LowercasePrimitiveTypeName(data_type2)); @@ -787,10 +794,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) { if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -875,18 +878,7 @@ ENTRY main { const std::string hlo_text = absl::Substitute( hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - std::string hlo_ref_template; - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - hlo_ref_template = R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) -; CHECK: %[[FUSED_REDUCE:.*]] = f32[127]{0} fusion(%[[P0]]) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]], %[[FUSED_REDUCE]]) -)"; - } else { - hlo_ref_template = R"( + std::string hlo_ref_template = R"( ; CHECK: ENTRY ; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) ; CHECK: ROOT @@ -894,7 +886,6 @@ ENTRY main { ; CHECK-SAME: kind=kCustom ; CHECK-SAME: __triton_softmax )"; - } const std::string hlo_ref = absl::Substitute( hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); @@ -919,12 +910,6 @@ ENTRY main { TEST_P(TritonSoftmaxTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - constexpr absl::string_view kHloTextTemplate = R"( HloModule softmax min_computation { @@ -961,12 +946,6 @@ ENTRY main { } TEST_F(TritonSoftmaxTest, CanFuseAndEmitDiamondWithBF16Converts) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text = R"( HloModule softmax max_computation { @@ -1008,10 +987,6 @@ TEST_P( if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -1081,12 +1056,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1143,12 +1112,7 @@ TEST_P(TritonSoftmaxTest, if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1212,12 +1176,6 @@ TEST_P( CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1281,12 +1239,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1341,12 +1293,6 @@ ENTRY main { TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1402,12 +1348,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1461,17 +1401,10 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P( - TritonSoftmaxTest, - CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectlyForAmpereAndVoltaComputeCapability) { // NOLINT(whitespace/line_length) +TEST_P(TritonSoftmaxTest, + CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1491,8 +1424,7 @@ ENTRY main { const std::string hlo_text = absl::Substitute( hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - const std::string hlo_ref = R"( + const std::string hlo_ref = R"( ; CHECK: ENTRY ; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) ; CHECK: ROOT @@ -1501,23 +1433,7 @@ ENTRY main { ; CHECK-SAME: __triton_softmax )"; - MatchOptimizedHlo(hlo_text, hlo_ref); - } else { - const std::string hlo_ref_template = R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) -; CHECK: %[[CONVERT:.*]] = $0[127,125]{1,0} convert(%[[P0]]) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[CONVERT]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax -)"; - - const std::string hlo_ref = - absl::Substitute(hlo_ref_template, - primitive_util::LowercasePrimitiveTypeName(data_type)); - MatchOptimizedHlo(hlo_text, hlo_ref); - } + MatchOptimizedHlo(hlo_text, hlo_ref); float tolerance; switch (data_type) { @@ -1542,12 +1458,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1604,12 +1514,6 @@ TEST_P( CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1666,12 +1570,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds max_computation { @@ -1737,12 +1635,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1805,12 +1697,6 @@ TEST_P( CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds max_computation { @@ -1911,10 +1797,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitRMSNormDiamond) { if (data_type == F16) { GTEST_SKIP() << "rsqrt op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -1981,12 +1863,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds add_computation { @@ -2050,12 +1926,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds add_computation { @@ -2119,12 +1989,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -2184,12 +2048,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond add_computation { @@ -2248,12 +2106,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond add_computation { @@ -2313,12 +2165,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule nonfusible_diamond max_computation { @@ -2385,17 +2231,16 @@ ENTRY main { } )"; - // Param order is arbitrary. We test that only param_1 is in the fused root - // instruction below. const std::string hlo_ref = R"( ; CHECK: ENTRY ; CHECK-DAG: %[[param_0:.*]] = f32[125,127]{1,0} parameter(0) ; CHECK-DAG: %[[param_1:.*]] = f32[127]{0} parameter(1) ; CHECK: ROOT ; CHECK-SAME: f32[125,127]{1,0} fusion -; CHECK-SAME: %[[param_1]] -; CHECK-SAME: kind=kCustom -; CHECK-SAME: triton_softmax +; CHECK-SAME: %[[param_0]] +; CHECK-SAME: %[[param_1]] +; CHECK-SAME: kind=kCustom +; CHECK-SAME: triton_softmax )"; MatchOptimizedHlo(hlo_text, hlo_ref); diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc new file mode 100644 index 0000000000000..c8147aa6c0bfd --- /dev/null +++ b/xla/service/gpu/ir_emitter_triton_rocm.cc @@ -0,0 +1,122 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/hlo_module_config.h" +#include "tsl/platform/rocm_rocdl_path.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace xla { +namespace gpu { + +namespace ma = ::mlir::arith; +namespace mm = ::mlir::math; +namespace ml = ::mlir::LLVM; +namespace mt = ::mlir::triton; + +using ::llvm::SmallVector; +using mlir::ArrayRef; +using ::mlir::ShapedType; +using ::mlir::Type; +using ::mlir::Value; +using mlir::ValueRange; + +absl::Status CreateTritonPipeline( + mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + const TritonGemmConfig& config, + mt::nvidia_gpu::ClusterInfo& out_cluster_info) { + // TODO(ROCm): Check whether value different than 0 can be used. + const int ccAsInt = 0; + // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. + const int threadsPerWarp = 32; + + // Based on make_ttir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mt::createRewriteTensorPointerPass()); + pm.addPass(mt::createCombineOpsPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mt::createReorderBroadcastPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(mlir::createSymbolDCEPass()); + + // Based on make_ttgir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mt::createConvertTritonToTritonGPUPass( + config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt)); + pm.addPass(mt::gpu::createCoalescePass()); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createOptimizeThreadLocalityPass()); + pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt)); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps, + config.num_ctas, ccAsInt)); + pm.addPass(mt::gpu::createPrefetchPass()); + + pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); + pm.addPass(mt::gpu::createRemoveLayoutConversionsPass()); + pm.addPass(mt::gpu::createReduceDataDuplicationPass()); + pm.addPass(mt::gpu::createReorderInstructionsPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createCanonicalizerPass()); + + // Based on make_llir() in + // @triton//:third_party/nvidia/backend/compiler.py + pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass()); + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createConvertIndexToLLVMPass()); + pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); + pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass()); + pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createSymbolDCEPass()); + // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createConvertControlFlowToLLVMPass()); + + // There is no clusters in ROCm for now. + out_cluster_info.clusterDimX = 1; + out_cluster_info.clusterDimY = 1; + out_cluster_info.clusterDimZ = 1; + + return absl::OkStatus(); +} + +std::string GetLibdevicePath(const HloModuleConfig& hlo_config, + const se::DeviceDescription& device_info) { + std::string libdevice_dir = tsl::RocdlRoot(); + auto compute_capability = device_info.rocm_compute_capability(); + const std::string libdevice_path = + amdgpu::LibDevicePath(compute_capability.gcn_arch_name(), libdevice_dir); + return libdevice_path; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/ir_emitter_triton_test.cc b/xla/service/gpu/ir_emitter_triton_test.cc index e0990e5ba2b14..a8fe6df7a8b09 100644 --- a/xla/service/gpu/ir_emitter_triton_test.cc +++ b/xla/service/gpu/ir_emitter_triton_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_triton.h" +#include +#include #include #include #include @@ -22,25 +24,28 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" @@ -54,6 +59,9 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +// TODO(b/317016172): Inspect usages of TritonGemmConfig and potentially update +// them to to use newly exposed parameters. + namespace xla { namespace gpu { namespace { @@ -74,7 +82,22 @@ class TritonGemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Always rewrite Gemms with Triton regardless of size. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } +}; + +class TritonGemmTestWithSplitK : public TritonGemmTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_split_k_autotuning(true); return debug_options; } }; @@ -90,13 +113,14 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { class TritonFilecheckTest : public TritonTest { public: - StatusOr CreateTritonIrAndFileCheck( - absl::string_view hlo_text, const TritonGemmConfig& config, - TritonIrEmitter emitter, absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern); + absl::Status CreateTritonIrAndFileCheck(absl::string_view hlo_text, + const TritonGemmConfig& config, + TritonIrEmitter emitter, + absl::string_view triton_fusion_name, + absl::string_view filecheck_pattern); }; -StatusOr TritonFilecheckTest::CreateTritonIrAndFileCheck( +absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck( absl::string_view hlo_text, const TritonGemmConfig& config, TritonIrEmitter emitter, absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { @@ -105,6 +129,7 @@ StatusOr TritonFilecheckTest::CreateTritonIrAndFileCheck( auto* computation = verified_module->GetComputationWithName(triton_fusion_name); + TF_RET_CHECK(computation != nullptr); TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute(*computation)); @@ -113,14 +138,15 @@ StatusOr TritonFilecheckTest::CreateTritonIrAndFileCheck( auto module, CreateTritonModule(analysis, "triton_fn", computation, TestGpuDeviceInfo::RTXA6000DeviceInfo(), config, emitter, context)); - mlir::PassManager pm(&context); - pm.addPass(mlir::createCanonicalizerPass()); - TF_RET_CHECK(pm.run(module.get()).succeeded()); std::string out; llvm::raw_string_ostream os(out); module->print(os); - return RunFileCheck(out, filecheck_pattern); + TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); + if (!succeeded) { + return absl::InternalError("FileCheck failed."); + } + return absl::OkStatus(); } TEST_F(TritonFilecheckTest, TestGemm) { @@ -140,11 +166,14 @@ ENTRY e { p0 = s8[80,115]{1,0} parameter(0) ROOT triton_gemm_r = f32[80,137]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, - "triton_gemm_r", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, + "triton_gemm_r", R"( CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x64xf32> CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32> @@ -185,28 +214,109 @@ CHECK: %[[RHS_ITER_PTR_NEXT:.*]] = tt.advance %[[RHS_ITER_PTR]], [%[[TILE CHECK: %[[CONVERTED:.*]] = arith.sitofp %[[LHS_TILE]] : tensor<16x32xi8> to tensor<16x32xf32> CHECK: %[[TILE_K_LIMIT:.*]] = arith.subi %[[SIZE_K]], %[[BLOCK_K]] : i32 CHECK: %[[K_TILE_IOTA:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> -CHECK: %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32> -CHECK: %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<1x32xi32> +CHECK: %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> +CHECK: %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<1x32xi32> CHECK: %[[LHS_INBOUNDS_1K:.*]] = arith.cmpi slt, %[[K_OFFSETS_1K]], %[[TILE_K_LIMIT_1K]] : tensor<1x32xi32> -CHECK: %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : (tensor<1x32xi1>) -> tensor<16x32xi1> +CHECK: %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : tensor<1x32xi1> -> tensor<16x32xi1> CHECK: %[[LHS_MASKED:.*]] = arith.select %[[LHS_INBOUNDS_MK]], %[[CONVERTED]], %[[ZERO_MK]] -CHECK: %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32> -CHECK: %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<32x1xi32> +CHECK: %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> +CHECK: %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<32x1xi32> CHECK: %[[RHS_INBOUNDS_K1:.*]] = arith.cmpi slt, %[[K_OFFSETS_K1]], %[[TILE_K_LIMIT_K1]] : tensor<32x1xi32> -CHECK: %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : (tensor<32x1xi1>) -> tensor<32x64xi1> +CHECK: %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : tensor<32x1xi1> -> tensor<32x64xi1> CHECK: %[[RHS_MASKED:.*]] = arith.select %[[RHS_INBOUNDS_KN]], %[[RHS_TILE]], %[[ZERO_KN]] : tensor<32x64xi1>, tensor<32x64xf32> CHECK: %[[ACC_NEXT:.*]] = tt.dot %[[LHS_MASKED]], %[[RHS_MASKED]], %[[ACC]] CHECK: scf.yield %[[LHS_ITER_PTR_NEXT]], %[[RHS_ITER_PTR_NEXT]], %[[ACC_NEXT]] : !tt.ptr, 1>, !tt.ptr, 1>, tensor<16x64xf32> CHECK: } -CHECK: %[[TILE_OFFSET_M_OUT:.*]] = arith.muli %[[TILE_INDEX_M]], %[[TILE_SIZE_M]] -CHECK: %[[TILE_OFFSET_N_OUT:.*]] = arith.muli %[[TILE_INDEX_N]], %[[TILE_SIZE_N]] CHECK: %[[OUT_PTR:.*]] = tt.make_tensor_ptr %[[OUT]], [%[[C80]], %[[SIZE_M]]], [%[[SIZE_M]], %[[C1]]], [%[[C0]], %[[C0]]] {order = array} : , 1> -CHECK: %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_OUT]], %[[TILE_OFFSET_N_OUT]]] : , 1> +CHECK: %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[TILE_OFFSET_N_RHS]]] : , 1> CHECK: tt.store %[[OUT_OFFSET]], %[[FOR]]#2 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<16x64xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); +} + +TEST_F(TritonFilecheckTest, TestGemmWithTrivialNonContractingDimension) { + const std::string kHloText = R"( +HloModule t, is_scheduled=true + +triton_dot { + param_0.1 = f32[137,115]{1,0} parameter(0) + param_1.1 = f32[1,115]{1,0} parameter(1) + ROOT dot = f32[137,1]{1,0} dot(param_0.1, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[137,115]{1,0} parameter(0) + p1 = f32[1,115]{1,0} parameter(1) + ROOT custom-call = f32[137,1]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + + TritonGemmConfig config(16, 16, 32, 1, 1, 1); + EXPECT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[ZERO_KN:.*]] = arith.constant dense<0.000000e+00> : tensor<32x16xf32> +CHECK-DAG: %[[ZERO_MK:.*]] = arith.constant dense<0.000000e+00> : tensor<16x32xf32> +CHECK-DAG: %[[ZERO_MN:.*]] = arith.constant dense<0.000000e+00> : tensor<16x16xf32> +CHECK-DAG: %[[SIZE_K:.*]] = arith.constant 115 : i32 +CHECK-DAG: %[[SIZE_M:.*]] = arith.constant 137 : i64 +CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i64 +CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 +CHECK-DAG: %[[C115:.*]] = arith.constant 115 : i64 +CHECK-DAG: %[[TILE_SIZE_K:.*]] = arith.constant 32 : i32 +CHECK-DAG: %[[TILE_SIZE_M:.*]] = arith.constant 16 : i32 +CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 +CHECK-DAG: %[[NUM_TILES_M:.*]] = arith.constant 9 : i32 +CHECK: %[[PID_NC:.*]] = tt.get_program_id x : i32 +CHECK: %[[GROUP_ID:.*]] = arith.divsi %[[PID_NC]], %[[C8]] +CHECK: %[[FIRST_PID_M:.*]] = arith.muli %[[GROUP_ID]], %[[C8]] +CHECK: %[[MAX_M:.*]] = arith.subi %[[NUM_TILES_M]], %[[FIRST_PID_M]] +CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[MAX_M]], %[[C8]] +CHECK: %[[GROUP_SIZE:.*]] = arith.select %[[CMP]], %[[MAX_M]], %[[C8]] +CHECK: %[[PID_M:.*]] = arith.remsi %[[PID_NC]], %[[GROUP_SIZE]] +CHECK: %[[TILE_INDEX_M:.*]] = arith.addi %[[FIRST_PID_M]], %[[PID_M]] +CHECK: %[[TMP:.*]] = arith.remsi %[[PID_NC]], %[[C8]] +CHECK: %[[TILE_INDEX_N:.*]] = arith.divsi %[[TMP]], %[[GROUP_SIZE]] +CHECK: %[[TILE_OFFSET_M_LHS:.*]] = arith.muli %[[TILE_INDEX_M]], %[[TILE_SIZE_M]] +CHECK: %[[LHS_PTR:.*]] = tt.make_tensor_ptr %[[LHS]] +CHECK: %[[LHS_TILE_PTR:.*]] = tt.advance %[[LHS_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[C0]]] +CHECK: %[[TILE_OFFSET_N_RHS:.*]] = arith.muli %[[TILE_INDEX_N]], %[[TILE_SIZE_M]] +CHECK: %[[RHS_PTR:.*]] = tt.make_tensor_ptr %[[RHS]] +CHECK: %[[RHS_TILE_PTR:.*]] = tt.advance %[[RHS_PTR]], [%[[C0]], %[[TILE_OFFSET_N_RHS]]] +CHECK: %[[FOR:.*]]:3 = scf.for %[[BLOCK_K:.*]] = %[[C0]] to %[[SIZE_K]] step %[[TILE_SIZE_K]] +CHECK-SAME: iter_args(%[[LHS_ITER_PTR:.*]] = %[[LHS_TILE_PTR]], %[[RHS_ITER_PTR:.*]] = %[[RHS_TILE_PTR]], %[[ACC:.*]] = %[[ZERO_MN]]) +CHECK: %[[LHS_TILE:.*]] = tt.load %[[LHS_ITER_PTR]] {boundaryCheck = array +CHECK: %[[LHS_ITER_PTR_NEXT:.*]] = tt.advance %[[LHS_ITER_PTR]], [%[[C0]], %[[TILE_SIZE_K]]] +CHECK: %[[RHS_TILE:.*]] = tt.load %[[RHS_ITER_PTR]] {boundaryCheck = array +CHECK: %[[RHS_ITER_PTR_NEXT:.*]] = tt.advance %[[RHS_ITER_PTR]], [%[[TILE_SIZE_K]], %[[C0]]] +CHECK: %[[TILE_K_LIMIT:.*]] = arith.subi %[[SIZE_K]], %[[BLOCK_K]] : i32 +CHECK: %[[K_TILE_IOTA:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> +CHECK: %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> +CHECK: %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<1x32xi32> +CHECK: %[[LHS_INBOUNDS_1K:.*]] = arith.cmpi slt, %[[K_OFFSETS_1K]], %[[TILE_K_LIMIT_1K]] : tensor<1x32xi32> +CHECK: %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : tensor<1x32xi1> -> tensor<16x32xi1> +CHECK: %[[LHS_MASKED:.*]] = arith.select %[[LHS_INBOUNDS_MK]], %[[LHS_TILE]], %[[ZERO_MK]] +CHECK: %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> +CHECK: %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : i32 -> tensor<32x1xi32> +CHECK: %[[RHS_INBOUNDS_K1:.*]] = arith.cmpi slt, %[[K_OFFSETS_K1]], %[[TILE_K_LIMIT_K1]] : tensor<32x1xi32> +CHECK: %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : tensor<32x1xi1> -> tensor<32x16xi1> +CHECK: %[[RHS_MASKED:.*]] = arith.select %[[RHS_INBOUNDS_KN]], %[[RHS_TILE]], %[[ZERO_KN]] : tensor<32x16xi1>, tensor<32x16xf32> +CHECK: %[[ACC_NEXT:.*]] = tt.dot %[[LHS_MASKED]], %[[RHS_MASKED]], %[[ACC]] +CHECK: scf.yield %[[LHS_ITER_PTR_NEXT]], %[[RHS_ITER_PTR_NEXT]], %[[ACC_NEXT]] : !tt.ptr, 1>, !tt.ptr, 1>, tensor<16x16xf32> +CHECK: } + +CHECK: %[[OUT_PTR:.*]] = tt.make_tensor_ptr %[[OUT]], [%[[SIZE_M]], %[[C1]]], [%[[C1]], %[[C1]]], [%[[C0]], %[[C0]]] {order = array} : , 1> +CHECK: %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[TILE_OFFSET_N_RHS]]] : , 1> +CHECK: tt.store %[[OUT_OFFSET]], %[[FOR]]#2 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<16x16xf32> +CHECK: tt.return +CHECK: } +)")); } TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleParameter) { @@ -229,14 +339,15 @@ triton_softmax_computation { ENTRY main { param_0 = f32[125,127]{1,0} parameter(0) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} })"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK: arith.extsi %[[PID]] : i32 to i64 +CHECK: arith.index_castui %[[PID]] : i32 to index CHECK: tt.addptr %[[P0]] CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> @@ -257,8 +368,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleScalarParameter) { @@ -282,14 +392,15 @@ triton_softmax_computation { ENTRY main { param_0 = f32[] constant(42) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} })"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 CHECK-DAG: %[[ARG_0:.*]] = tt.addptr %[[P0]], %[[ZERO_OFFSET]] : !tt.ptr, i64 CHECK: tt.load %[[ARG_0]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 @@ -309,8 +420,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleParameters) { @@ -337,18 +447,20 @@ triton_softmax_computation { ENTRY main { param_0 = f32[125,127]{1,0} parameter(0) param_1 = f32[127]{0} parameter(1) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} } )"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 -CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> @@ -371,8 +483,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, @@ -402,24 +513,26 @@ triton_softmax_computation { ENTRY main { param_0 = f32[125,127]{1,0} parameter(1) param_1 = f32[127]{0} parameter(0) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_1, param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_1, param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} } )"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 -CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 -CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ROW_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> CHECK-NEXT: tt.load CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> -CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ZERO_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> CHECK-NEXT: tt.load @@ -437,8 +550,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, @@ -465,33 +577,35 @@ triton_softmax_computation { ENTRY main { param_0 = f32[125,127]{1,0} parameter(0) param_1 = f32[127]{0} parameter(1) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} } )"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 -CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> CHECK-NEXT: tt.load CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> -CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 -CHECK-NEXT: tt.make_tensor_ptr -CHECK-SAME: , 1> -CHECK-NEXT: tt.load -CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> CHECK: tt.reduce CHECK-NEXT: ^bb0(%[[ARG3:[^:]*]]: f32, %[[ARG4:[^:]*]]: f32): CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> CHECK: tt.addptr %[[P2]] CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> @@ -499,8 +613,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, @@ -532,18 +645,20 @@ ENTRY main { param_0 = f32[125,127]{1,0} parameter(1) param_1 = f32[127]{0} parameter(0) param_2 = f32[125]{0} parameter(2) - ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} } )"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { -CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 -CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> @@ -554,6 +669,7 @@ CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> CHECK-NEXT: tt.load CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[PID_i64:.*]] = arith.index_castui %[[PID_INDEX]] : index to i64 CHECK: %[[ARG2:.*]] = tt.addptr %[[P2]], %[[PID_i64]] : !tt.ptr, i64 CHECK-NEXT: tt.load %[[ARG2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 CHECK: tt.reduce @@ -571,8 +687,7 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); } TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleTiledDimensions) { @@ -603,18 +718,20 @@ ENTRY main { param_0 = f32[10,125,127]{2,1,0} parameter(0) param_1 = f32[127]{0} parameter(1) param_2 = f32[10,125]{1,0} parameter(2) - ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} + ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} } )"; TritonGemmConfig config(16, 64, 32, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, - "triton_softmax_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 127)> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 -CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 -CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK-DAG: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> @@ -625,6 +742,7 @@ CHECK-NEXT: tt.make_tensor_ptr CHECK-SAME: , 1> CHECK-NEXT: tt.load CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[PID_i64:.*]] = arith.index_castui %[[PID_INDEX]] : index to i64 CHECK: %[[ARG2:.*]] = tt.addptr %[[P2]], %[[PID_i64]] : !tt.ptr, i64 CHECK-NEXT: tt.load %[[ARG2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 CHECK: tt.reduce @@ -642,8 +760,338 @@ CHECK-NEXT: tt.store CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> CHECK: tt.return CHECK: } -)"), - tsl::testing::IsOkAndHolds(true)); +)")); +} + +TEST_F( + TritonFilecheckTest, + DiamondWithAdditionalDiamondParameterBroadcastedAlongReductionDimProducesAccurateResults) { // NOLINT(whitespace/line_length) + const std::string kHloText = R"( +HloModule h1 + +max_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] maximum(x, y) +} + +triton_softmax_computation { + parameter_1 = f32[32]{0} parameter(1) + broadcast_1 = f32[32,16]{1,0} broadcast(parameter_1), dimensions={0} + parameter_0 = f32[32,16]{1,0} parameter(0) + add_0 = f32[32,16]{1,0} add(broadcast_1, parameter_0) + c = f32[] constant(0) + reduce_0 = f32[32]{0} reduce(parameter_0, c), dimensions={1}, to_apply=max_computation + broadcast_0 = f32[32,16]{1,0} broadcast(reduce_0), dimensions={0} + ROOT _ = f32[32,16]{1,0} add(add_0, broadcast_0) +} + +ENTRY main { + parameter_1 = f32[32]{0} parameter(1) + parameter_0 = f32[32,16]{1,0} parameter(0) + ROOT _ = f32[32,16]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 16)> +CHECK-LABEL: tt.func @triton_fn( +CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i32 +CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +CHECK-DAG: %[[C16_i64:.*]] = arith.constant 16 : i64 +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index +CHECK: %[[PID_i64:.*]] = arith.index_castui %[[PID_INDEX]] : index to i64 +CHECK: tt.addptr %[[P1]], %[[PID_i64]] : !tt.ptr, i64 +CHECK: tt.splat +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 +CHECK: tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK: tt.load +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): +CHECK: %[[MAX:.*]] = arith.maximumf %[[ARG3]], %[[ARG4]] : f32 +CHECK: tt.reduce.return %[[MAX]] : f32 +CHECK: }) : (tensor<16xf32>) -> f32 +CHECK: tt.addptr %[[P2]] +CHECK: tt.make_tensor_ptr +CHECK-SAME: tensor<16xf32> +CHECK: tt.store +CHECK-SAME: !tt.ptr, 1>, tensor<16xf32> +)")); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, + /*arel=*/0})); +} + +TEST_F(TritonFilecheckTest, NestedReducerFusionGetsCodegenedCorrectly) { + // TODO(b/327336797): remove filter once V100 codegen in Triton is removed. + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Doesn't pass on pre-Ampere GPUs."; + } + + const std::string kHloText = R"( +HloModule softmax + +fused_convert { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + convert0 = bf16[] convert(p0) + convert1 = bf16[] convert(p1) + add = bf16[] add(convert0, convert1) + ROOT output = f32[] convert(add) +} + +add_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT fusion = f32[] fusion(p0, p1), kind=kLoop, calls=fused_convert +} + +triton_softmax_computation { + p0 = pred[10,128]{1,0} parameter(0) + p0_f32 = f32[10,128]{1,0} convert(p0) + zero = f32[] constant(0) + reduce = f32[10]{0} reduce(p0_f32, zero), dimensions={1}, to_apply=add_computation + broadcast = f32[10,128]{1,0} broadcast(reduce), dimensions={0} + ROOT add = f32[10,128]{1,0} add(p0_f32, broadcast) +} + +ENTRY main { + p0 = pred[10,128]{1,0} parameter(0) + ROOT softmax = f32[10,128] fusion(p0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +})"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, + /*arel=*/0})); +} + +TEST_F( + TritonFilecheckTest, + DiamondWithAdditionalDiamondParameterBroadcastedAlongBatchDimProducesAccurateResults) { // NOLINT(whitespace/line_length) + const std::string kHloText = R"( +HloModule h1 + +max_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] maximum(x, y) +} + +triton_softmax_computation { + parameter_1 = f32[32]{0} parameter(1) + broadcast_1 = f32[16,32]{1,0} broadcast(parameter_1), dimensions={1} + parameter_0 = f32[16,32]{1,0} parameter(0) + add_0 = f32[16,32]{1,0} add(broadcast_1, parameter_0) + c = f32[] constant(0) + reduce_0 = f32[16]{0} reduce(parameter_0, c), dimensions={1}, to_apply=max_computation + broadcast_0 = f32[16,32]{1,0} broadcast(reduce_0), dimensions={0} + ROOT _ = f32[16,32]{1,0} add(add_0, broadcast_0) +} + +ENTRY main { + parameter_0 = f32[16,32]{1,0} parameter(0) + parameter_1 = f32[32]{0} parameter(1) + ROOT _ = f32[16,32]{1,0} fusion(parameter_0,parameter_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton_softmax"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 32)> +CHECK-LABEL: tt.func @triton_fn( +CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i32 +CHECK-DAG: %[[C0_i64:.*]] = arith.constant 0 : i64 +CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index +CHECK: tt.addptr %[[P1]], %[[C0_i64]] : !tt.ptr, i64 +CHECK: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK: tt.load +CHECK-SAME: !tt.ptr, 1> -> tensor<32xf32> +CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 +CHECK: tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK: tt.load +CHECK-SAME: !tt.ptr, 1> -> tensor<32xf32> +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): +CHECK: %[[MAX:.*]] = arith.maximumf %[[ARG3]], %[[ARG4]] : f32 +CHECK: tt.reduce.return %[[MAX]] : f32 +CHECK: }) : (tensor<32xf32>) -> f32 +CHECK: tt.addptr %[[P2]] +CHECK: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK: tt.store +CHECK-SAME: !tt.ptr, 1>, tensor<32xf32> +)")); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F( + TritonFilecheckTest, + DiamondWithAdditionalSplatDiamondScalarParameterProducesAccurateResults) { // NOLINT(whitespace/line_length) + const std::string kHloText = R"( +HloModule h1 + +max_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] maximum(x,y) +} + +triton_softmax_computation { + parameter_1 = f32[] parameter(1) + broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_1), dimensions={} + parameter_0 = f32[64,32,16]{2,1,0} parameter(0) + add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_0) + c = f32[] constant(0) + reduce_0 = f32[64,32]{1,0} reduce(parameter_0, c), dimensions={2}, to_apply=max_computation + broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1} + ROOT _ = f32[64,32,16]{2,1,0} add(add_0, broadcast_0) +} + +ENTRY main { + parameter_1 = f32[64,32,16]{2,1,0} parameter(1) + parameter_0 = f32[] parameter(0) + ROOT _ = f32[64,32,16]{2,1,0} fusion(parameter_1, parameter_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_softmax"}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + TF_ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 16)> +// CHECK-LABEL: tt.func @triton_fn( +// CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK-DAG: %[[ZERO_OFFSET_i32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C16_i64:.*]] = arith.constant 16 : i64 +// CHECK-DAG: %[[ZERO_OFFSET_i64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +// CHECK: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index +// CHECK: tt.addptr %[[P1]], %[[ZERO_OFFSET_i64]] : !tt.ptr, i64 +// CHECK-NEXT: tt.load +// CHECK-SAME: {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 +// CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +// CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 +// CHECK: tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +// CHECK: tt.make_tensor_ptr +// CHECK-SAME: , 1> +// CHECK: tt.load +// CHECK-SAME: !tt.ptr, 1> -> tensor<16xf32> +// CHECK: tt.reduce +// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): +// CHECK: %[[MAX:.*]] = arith.maximumf %[[ARG3]], %[[ARG4]] : f32 +// CHECK: tt.reduce.return %[[MAX]] : f32 +// CHECK: }) : (tensor<16xf32>) -> f32 +// CHECK: tt.addptr %[[P2]], %[[ROW_OFFSET]] : !tt.ptr, i64 +// CHECK-NEXT: tt.make_tensor_ptr +// CHECK-SAME: , 1> +// CHECK: tt.store +// CHECK-SAME: !tt.ptr, 1>, tensor<16xf32> +)")); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F( + TritonFilecheckTest, + DiamondWithAdditionalBroadcastOf1DParameterAlongNonReductionDimensionsProducesAccurateResults) { // NOLINT(whitespace/line_length) + const std::string kHloText = R"( +HloModule h1 + +max_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] maximum(x,y) +} + +triton_softmax_computation { + parameter_1 = f32[16]{0} parameter(1) + broadcast_1 = f32[64,32,16]{2,1,0} broadcast(f32[16]{0} parameter_1), dimensions={2} + parameter_0 = f32[64,32,16]{2,1,0} parameter(0) + add_0 = f32[64,32,16]{2,1,0} add(f32[64,32,16]{2,1,0} broadcast_1, f32[64,32,16]{2,1,0} parameter_0) + c = f32[] constant(0) + reduce_0 = f32[64,32]{1,0} reduce(f32[64,32,16]{2,1,0} parameter_0, f32[] c), dimensions={2}, to_apply=max_computation + broadcast_0 = f32[64,32,16]{2,1,0} broadcast(f32[64,32]{1,0} reduce_0), dimensions={0,1} + ROOT _ = f32[64,32,16]{2,1,0} add(f32[64,32,16]{2,1,0} add_0, f32[64,32,16]{2,1,0} broadcast_0) +} + +ENTRY main { + parameter_1 = f32[64,32,16]{2,1,0} parameter(1) + parameter_0 = f32[16]{0} parameter(0) + ROOT _ = f32[64,32,16]{2,1,0} fusion(f32[64,32,16]{2,1,0} parameter_1, f32[16]{0} parameter_0), kind=kCustom, calls=%triton_softmax_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_softmax"}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 16)> +// CHECK-LABEL: tt.func @triton_fn( +// CHECK-SAME: %[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +// CHECK-DAG: %[[ZERO_OFFSET_i32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C0_i64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C16_i64:.*]] = arith.constant 16 : i64 +// CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +// CHECK: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index +// CHECK: tt.addptr %[[P1]], %[[C0_i64]] : !tt.ptr, i64 +// CHECK-NEXT: tt.make_tensor_ptr +// CHECK-SAME: , 1> +// CHECK: tt.load +// CHECK-SAME: !tt.ptr, 1> -> tensor<16xf32> +// CHECK: %[[ROW_OFFSET_INDEX:.*]] = affine.apply #[[MAP]]()[%[[PID_INDEX]]] +// CHECK: %[[ROW_OFFSET:.*]] = arith.index_castui %[[ROW_OFFSET_INDEX]] : index to i64 +// CHECK: tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +// CHECK-NEXT: tt.make_tensor_ptr +// CHECK-SAME: , 1> +// CHECK-NEXT: tt.load +// CHECK-SAME: !tt.ptr, 1> -> tensor<16xf32> +// CHECK: tt.reduce +// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): +// CHECK: %[[MAX:.*]] = arith.maximumf %[[ARG3]], %[[ARG4]] : f32 +// CHECK: tt.reduce.return %[[MAX]] : f32 +// CHECK: }) : (tensor<16xf32>) -> f32 +// CHECK: tt.addptr %[[P2]], %[[ROW_OFFSET]] : !tt.ptr, i64 +// CHECK-NEXT: tt.make_tensor_ptr +// CHECK-SAME: , 1> +// CHECK: tt.store +// CHECK-SAME: !tt.ptr, 1>, tensor<16xf32> +)")); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } TEST_F(TritonFilecheckTest, PredParametersAreTruncatedToI1) { @@ -669,22 +1117,66 @@ ENTRY e { c = f32[2,2]{1,0} parameter(3) ROOT triton_gemm = f32[2,2]{1,0} fusion(p, a, b, c), kind=kCustom, calls=triton_gemm_computation, - backend_config={kind: "__triton_gemm", - triton_gemm_config: { - "block_m":16,"block_n":16,"block_k":16, - "split_k":1,"num_stages":1,"num_warps":1 + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: { + "block_m":16,"block_n":16,"block_k":16, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1 + } } } } )"; TritonGemmConfig config(16, 16, 16, 1, 1, 1); - ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, - "triton_gemm_computation", R"( + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, + "triton_gemm_computation", R"( CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr, 1> -> tensor<16x16xi8> CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1> CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> -)"), - tsl::testing::IsOkAndHolds(true)); +)")); +} + +TEST_F(TritonFilecheckTest, + CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { + constexpr absl::string_view kHloText = R"( +HloModule t, is_scheduled=true + +triton_gemm { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2} + ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +} + +ENTRY e { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + ROOT dot = f32[2,3,384]{2,1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + + TritonGemmConfig config(16, 64, 32, 1, 1, 2); + TF_EXPECT_OK(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, + "triton_gemm", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr +CHECK-SAME: %[[P1:[^:]*]]: !tt.ptr +CHECK-SAME: %[[P2:[^:]*]]: !tt.ptr +CHECK-DAG: %[[ARG_PTR:.*]] = arith.select %[[CONCAT_COND:.*]], %[[P1]], %[[P2]] +CHECK-DAG: %[[BATCH_STRIDE_P1:.*]] = arith.constant 1280 +CHECK-DAG: %[[BATCH_STRIDE_P2:.*]] = arith.constant 2560 +CHECK-DAG: %[[BATCH_STRIDE:.*]] = arith.select %[[CONCAT_COND_2:.*]], %[[BATCH_STRIDE_P1]], %[[BATCH_STRIDE_P2]] +CHECK-DAG: %[[PID_BATCH:.*]] = tt.get_program_id y +CHECK-DAG: %[[OFFSET:.*]] = arith.muli %[[PID_BATCH]], %[[BATCH_STRIDE]] +CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] +)")); } TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { @@ -703,8 +1195,10 @@ ENTRY e { p0 = s8[80,15]{1,0} parameter(0) ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); @@ -747,7 +1241,7 @@ ENTRY e { TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { const std::string kHloText = R"( triton_gemm_r { - parameter_0 = s8[80,15]{1,0} parameter(0) + parameter_0 = f16[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) parameter_1 = f32[16,15]{1,0} parameter(1) ROOT r.1 = f32[80,16]{1,0} dot(convert.3, parameter_1), @@ -756,26 +1250,21 @@ triton_gemm_r { ENTRY e { p1 = f32[16,15]{1,0} parameter(1) - p0 = s8[80,15]{1,0} parameter(0) + p0 = f16[80,15]{1,0} parameter(0) ROOT triton_gemm_r = f32[80,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( CHECK: mma )"); - } else { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( -CHECK: fma -)"); - } } TEST_F(TritonGemmTest, FailIfTooMuchShmem) { @@ -811,7 +1300,7 @@ ENTRY entry { TritonGemmConfig config(16, 32, 512, 1, 4, 8); EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, kTritonGemmFusionKind, + "test_fn", triton_dot_computation, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0}, dev_info, config, &llvm_module, &EmitMatMul, mlir_context), @@ -826,7 +1315,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, kTritonGemmFusionKind, + "test_fn", triton_dot_computation, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0}, dev_info, config, &llvm_module, &EmitMatMul, mlir_context)); @@ -834,7 +1323,8 @@ ENTRY entry { EXPECT_GT(result.shmem_bytes, dev_info.shared_memory_per_block()); } -TEST_F(TritonGemmTest, WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) { +TEST_F(TritonGemmTestWithSplitK, + WorksWhenKIsDivisibleByBlockKButNotByBlockKTimesSplitK) { // The condition mentioned in the test name is fulfilled by // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for // Ampere at the time of the addition of this test case. @@ -849,7 +1339,7 @@ ENTRY e { } )"; - // This check tests if Triton is used at all plus it runs TritonAutotuner, + // This check tests if Triton is used at all plus it runs GemmFusionAutotuner, // which verifies if the generated kernels can run without errors such as // CUDA_ERROR_ILLEGAL_ADDRESS. MatchOptimizedHlo(kHloText, R"( @@ -922,10 +1412,11 @@ TEST_F(TritonGemmTest, SplitLhsNoncontractingTransposeRhs) { HloModule t ENTRY e { - p0 = s8[3,122,96,12]{3,2,1,0} parameter(0) + p0 = pred[3,122,96,12]{3,2,1,0} parameter(0) cp0 = f16[3,122,96,12]{3,2,1,0} convert(p0) - p1 = f16[1,5,122]{2,1,0} parameter(1) - ROOT _ = f16[3,96,12,1,5]{4,3,2,1,0} dot(cp0, p1), + p1 = pred[1,5,122]{2,1,0} parameter(1) + cp1 = f16[1,5,122]{2,1,0} convert(p1) + ROOT _ = f16[3,96,12,1,5]{4,3,2,1,0} dot(cp0, cp1), lhs_contracting_dims={1}, rhs_contracting_dims={2} })"; @@ -938,7 +1429,7 @@ ENTRY e { ; CHECK-SAME: "block_m": )"); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } TEST_F(TritonGemmTest, SplitLhsNoncontracting) { @@ -1162,7 +1653,7 @@ ENTRY e { MatchOptimizedHlo(hlo_text, R"( ; CHECK: ENTRY -; CHECK: f32[5,3,4]{2,1,0} bitcast(%p1) +; CHECK: f32[5,3,4]{2,1,0} bitcast ; CHECK: fusion ; CHECK-SAME: kind=kCustom ; CHECK-SAME: "block_m": @@ -1185,7 +1676,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: ENTRY +; CHECK: ROOT ; CHECK: transpose( ; CHECK: bitcast( ; CHECK: kCustom @@ -1194,6 +1685,54 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); } +TEST_F(TritonGemmTest, CanCodegenNonBatchedDotWithConcatenationCorrectly) { + constexpr absl::string_view kHloText = R"( +ENTRY e { + parameter_0 = f32[3,10]{1,0} parameter(0) + parameter_1 = f32[10,128]{1,0} parameter(1) + parameter_2 = f32[10,256]{1,0} parameter(2) + concatenate = f32[10,384]{1,0} concatenate(parameter_1, parameter_2), dimensions={1} + ROOT dot = f32[3,384]{1,0} dot(parameter_0, concatenate), + lhs_batch_dims={}, lhs_contracting_dims={1}, + rhs_batch_dims={}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NOT: concatenate +; CHECK: fusion +; CHECK-SAME: kind=kCustom +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, CanCodegenBatchedDotWithConcatenationCorrectly) { + constexpr absl::string_view kHloText = R"( +ENTRY e { + parameter_0 = f32[2,3,10]{2,1,0} parameter(0) + parameter_1 = f32[2,10,128]{2,1,0} parameter(1) + parameter_2 = f32[2,10,256]{2,1,0} parameter(2) + concatenate = f32[2,10,384]{2,1,0} concatenate(parameter_1, parameter_2), dimensions={2} + ROOT dot = f32[2,3,384]{2,1,0} dot(parameter_0, concatenate), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: ENTRY +; CHECK-NOT: concatenate +; CHECK: fusion +; CHECK-SAME: kind=kCustom +)"); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipU8) { const std::string hlo_text = R"( HloModule t @@ -1264,7 +1803,7 @@ ENTRY entry { TritonGemmConfig config(512, 512, 32, 1, 1, 2); EXPECT_THAT( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, kTritonGemmFusionKind, + "test_fn", triton_dot_computation, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0}, dev_info, config, &llvm_module, &EmitMatMul, mlir_context), @@ -1278,7 +1817,7 @@ ENTRY entry { config.block_k = 32; TF_CHECK_OK( TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, kTritonGemmFusionKind, + "test_fn", triton_dot_computation, se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0}, dev_info, config, &llvm_module, &EmitMatMul, mlir_context) @@ -1289,7 +1828,7 @@ ENTRY entry { // https://github.com/openai/triton/issues/1864 TEST_F(TritonGemmTest, TritonCompilerDoesNotFailOnConstants) { TF_CHECK_OK(GetOptimizedModule(R"( -HloModule m, is_scheduled=true +HloModule m triton_gemm___computation { parameter_0 = f32[92,11]{1,0} parameter(0) @@ -1303,10 +1842,11 @@ ENTRY e { p0 = f32[92,11]{1,0} parameter(0) ROOT triton_gemm__ = f32[92,63]{1,0} fusion(p0), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })") .status()); } @@ -1330,7 +1870,7 @@ ENTRY e { p0 = f16[55,77,111]{2,1,0} parameter(0) p1 = f16[111,77,99]{2,1,0} parameter(1) ROOT r = f16[77,55,99]{2,1,0} fusion(p0, p1), kind=kCustom, - calls=t, backend_config={"kind":"__triton_gemm"} + calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} })", // This partially optimized HLO will go through the // autotuner which will run the fusion through the emitter @@ -1362,7 +1902,7 @@ ENTRY e { p0 = f32[2,7,3]{2,1,0} parameter(0) p1 = s32[2,1]{1,0} parameter(1) ROOT r = f32[2,7,3]{2,1,0} fusion(p0, p1), kind=kCustom, - calls=t, backend_config={"kind":"__triton_gemm"} + calls=t, backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} })", // This partially optimized HLO will go through the // autotuner which will run the fusion through the emitter @@ -1393,7 +1933,7 @@ ENTRY e { })"; MatchOptimizedHlo(hlo_text, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1416,9 +1956,7 @@ ENTRY e { })"; MatchOptimizedHlo(hlo_text, R"( -; CHECK: fusion -; CHECK-SAME: kind=kCustom -; CHECK-SAME: block_m +; CHECK: fusion({{.*}} kind=kCustom, {{.*}}block_m )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); @@ -1437,21 +1975,42 @@ ENTRY e { // The fusion has separate parameters for each scope. MatchOptimizedHlo(hlo_text, R"( -; CHECK: fusion(%p0, %p0), kind=kCustom +; CHECK: ENTRY +; CHECK: %[[p0:.*]] = pred[5,5]{1,0} parameter(0) +; CHECK: fusion(%[[p0]], %[[p0]]), kind=kCustom ; CHECK-SAME: "block_m": )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } +TEST_F(TritonGemmTestAny, + DoNotFuseConcatenationOfSplitNonContractingDimension) { + const std::string hlo_text = R"( +HloModule m + +ENTRY e { + x = bf16[2,128,10] parameter(0) + y = bf16[2,256,10] parameter(1) + concat = bf16[2,384,10] concatenate(x, y), dimensions={1} + z = bf16[10,20] parameter(2) + ROOT d = bf16[2,384,20] dot(concat, z), lhs_contracting_dims={2}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: ENTRY +; CHECK: concatenate +; CHECK: ROOT +; CHECK-SAME: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: "block_m" +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + class TritonGemmLevel2Test : public TritonGemmTest { public: - void SetUp() override { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Triton fusion on pre-Ampere GPUs is limited."; - } - } DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_fusion_level(2); @@ -1511,7 +2070,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fused_computation +; CHECK: fused_subtract ; CHECK: negate ; CHECK: negate ; CHECK: ROOT @@ -1709,7 +2268,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: triton_gemm_dot +; CHECK: gemm_fusion_dot ; CHECK: dot( ; CHECK: bf16[] constant(0.123) ; CHECK: ROOT @@ -1734,8 +2293,10 @@ ENTRY e { p0 = f16[75] parameter(0) p1 = f16[92,75] parameter(1) ROOT _ = f16[92,67] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: - {"block_m":32,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", triton_gemm_config: + {"block_m":32,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1785,7 +2346,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1808,7 +2369,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1831,7 +2392,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1854,7 +2415,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1878,7 +2439,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1903,7 +2464,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1928,7 +2489,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -1953,7 +2514,7 @@ ENTRY e { })"; MatchOptimizedHlo(kHloText, R"( -; CHECK: fusion +; CHECK: fusion( ; CHECK-SAME: kind=kCustom ; CHECK-SAME: block_m )"); @@ -2074,7 +2635,8 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(TritonGemmTest, SplitKDoesNotBreakSlicedFragmentedContractingDimension) { +TEST_F(TritonGemmTestWithSplitK, + SplitKDoesNotBreakSlicedFragmentedContractingDimension) { const std::string kHloText = R"( ENTRY e { p0 = f16[16,8,128]{2,1,0} parameter(0) @@ -2255,7 +2817,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-3})); } TEST_F(TritonGemmTest, Naming) { @@ -2271,9 +2833,9 @@ ENTRY e { })"; MatchOptimizedHlo(hlo_text, R"( -; CHECK: %triton_gemm_r_computation ( -; CHECK: %triton_gemm_r = -; CHECK-SAME: fusion +; CHECK: %gemm_fusion_r_computation ( +; CHECK: ROOT %gemm_fusion_r +; CHECK-SAME: kCustom )"); } @@ -2329,7 +2891,11 @@ ENTRY e { p0 = s8[101,202]{1,0} parameter(0) p1 = f32[202,303]{1,0} parameter(1) ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":3,"num_warps":8}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":3,"num_warps":8, + "num_ctas":1}}} })"; const char* hlo_text_triton = R"( @@ -2347,7 +2913,10 @@ ENTRY e { p0 = s8[101,202]{1,0} parameter(0) p1 = f32[202,303]{1,0} parameter(1) ROOT _ = f32[101,303] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32,"split_k":1,"num_stages":2,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":32, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, @@ -2364,7 +2933,7 @@ ENTRY e { arg1 = f16[7,33] parameter(1) gemm = (f16[5,33], s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f16[5,33]{1,0} get-tuple-element((f16[5,33]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2383,7 +2952,10 @@ ENTRY e { p0 = f16[5,7]{1,0} parameter(0) p1 = f16[7,33]{1,0} parameter(1) ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} } )"; @@ -2401,7 +2973,7 @@ ENTRY e { arg1 = f32[7,33] parameter(1) gemm = (f32[5,33], s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[5,33]{1,0} get-tuple-element((f32[5,33]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2420,7 +2992,10 @@ ENTRY e { p0 = f32[5,7]{1,0} parameter(0) p1 = f32[7,33]{1,0} parameter(1) ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} } )"; @@ -2429,12 +3004,47 @@ ENTRY e { /*run_hlo_passes=*/false)); } -TEST_F(CompareTest, BF16TransposedLHS) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } +TEST_F(CompareTest, F32WithTrivialNonContractingDimension) { + const char* hlo_text_ref = R"( +HloModule r + +ENTRY e { + arg0 = f32[5,7] parameter(0) + arg1 = f32[1,7] parameter(1) + gemm = (f32[5,1], s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[5,1]{1,0} get-tuple-element((f32[5,1]{1,0}, s8[0]{0}) gemm), index=0 +} +)"; + + const char* hlo_text_triton = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[1,7] parameter(1) + ROOT dot = f32[5,1] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[1,7]{1,0} parameter(1) + ROOT _ = f32[5,1] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":1, + "num_ctas":1}}} +} +)"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_ref, hlo_text_triton, + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}, + /*run_hlo_passes=*/false)); +} +TEST_F(CompareTest, BF16TransposedLHS) { const char* hlo_text_ref = R"( HloModule r @@ -2443,7 +3053,7 @@ ENTRY e { arg1 = bf16[512,256]{1,0} parameter(1) gemm = (bf16[16,256]{1,0}, s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = bf16[16,256]{1,0} get-tuple-element((bf16[16,256]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2462,7 +3072,10 @@ ENTRY e { arg0 = bf16[512,16]{1,0} parameter(0) arg1 = bf16[512,256]{1,0} parameter(1) ROOT _ = bf16[16,256]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16,"split_k":1,"num_stages":2,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":16, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2472,11 +3085,6 @@ ENTRY e { } TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) { - // On pre-Ampere GPUs the test would use a different amount of shared memory. - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "This test is for Ampere+ GPUs."; - } const se::DeviceDescription dev_info = backend().default_stream_executor()->GetDeviceDescription(); constexpr int kBytesOfSharedMemoryTested = 64 * 1024; @@ -2498,7 +3106,10 @@ ENTRY e { p0 = s8[332,441]{1,0} parameter(0) p1 = f16[441,39]{1,0} parameter(1) ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128,"split_k":1,"num_stages":2,"num_warps":32}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":128,"block_k":128, + "split_k":1,"num_stages":2,"num_warps":32, + "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, @@ -2511,16 +3122,19 @@ ENTRY e { llvm::Module llvm_module("module", llvm_ctx); mlir::MLIRContext mlir_context; - TF_ASSERT_OK_AND_ASSIGN(auto config, + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, hlo_module->entry_computation() ->root_instruction() - ->backend_config()); + ->backend_config()); + const FusionBackendConfig& config = gpu_config.fusion_backend_config(); + TF_ASSERT_OK_AND_ASSIGN( + TritonGemmConfig triton_gemm_config, + TritonGemmConfig::FromProto(config.triton_gemm_config())); TF_ASSERT_OK_AND_ASSIGN( const auto result, TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), - "test_fn", triton_dot_computation, kTritonGemmFusionKind, - GetCudaComputeCapability(), dev_info, - TritonGemmConfig::FromProto(config.triton_gemm_config()), + "test_fn", triton_dot_computation, + GetCudaComputeCapability(), dev_info, triton_gemm_config, &llvm_module, &EmitMatMul, mlir_context)); // The config is chosen so that the used memory size is slightly above the // 48 kB boundary of standard / optin shared memory so that any GPU that @@ -2544,7 +3158,10 @@ ENTRY e { p0 = s8[332,441]{1,0} parameter(0) p1 = f16[441,39]{1,0} parameter(1) ROOT _ = f16[332,39]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(kHloTextLowShmem, kHloTextOptinShmem, @@ -2561,7 +3178,7 @@ ENTRY e { arg1 = f16[64,32]{1,0} parameter(1) gemm = (f16[128,64]{1,0}, s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f16[128,64]{1,0} get-tuple-element((f16[128,64]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2580,7 +3197,10 @@ ENTRY e { arg0 = f16[128,32]{1,0} parameter(0) arg1 = f16[64,32]{1,0} parameter(1) ROOT _ = f16[128,64]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":128,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2598,7 +3218,7 @@ ENTRY e { arg1 = f32[1024,64]{1,0} parameter(1) gemm = (f32[128,1024]{1,0}, s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[128,1024]{1,0} get-tuple-element((f32[128,1024]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2617,7 +3237,10 @@ ENTRY e { arg0 = f32[64,128]{1,0} parameter(0) arg1 = f32[1024,64]{1,0} parameter(1) ROOT _ = f32[128,1024]{1,0} fusion(arg0, arg1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} } )"; @@ -2627,10 +3250,6 @@ ENTRY e { } TEST_F(CompareTest, S8BF16) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* hlo_text_ref = R"( HloModule r @@ -2645,7 +3264,7 @@ ENTRY e { p1 = bf16[256,122]{1,0} parameter(1) gemm = (bf16[144,122]{1,0}, s8[0]{0}) custom-call(fusion, p1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = bf16[144,122]{1,0} get-tuple-element((bf16[144,122]{1,0}, s8[0]{0}) gemm), index=0 } )"; @@ -2665,7 +3284,10 @@ ENTRY e { p0 = s8[144,256]{1,0} parameter(0) p1 = bf16[256,122]{1,0} parameter(1) ROOT _ = bf16[144,122]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64,"split_k":1,"num_stages":1,"num_warps":2}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":64,"block_k":64, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} } )"; @@ -2675,10 +3297,6 @@ ENTRY e { } TEST_F(CompareTest, SplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string hlo_text_ref = R"( HloModule t, is_scheduled=true @@ -2696,7 +3314,10 @@ ENTRY e { bitcast.4 = s8[480,120]{1,0} bitcast(p0) ROOT triton_gemm_r = bf16[480,16]{1,0} fusion(bitcast.4, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":1,"num_stages":4,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":4,"num_warps":4, + "num_ctas":1}}} })"; const std::string hlo_text_splitk = R"( @@ -2735,7 +3356,10 @@ ENTRY e { bitcast.4 = s8[480,120]{1,0} bitcast(p0) triton_gemm_r = bf16[4,480,16]{2,1,0} fusion(bitcast.4, p1), kind=kCustom, calls=triton_gemm_r, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128,"split_k":4,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":128, + "split_k":4,"num_stages":1,"num_warps":4, + "num_ctas":1}}} ROOT fusion.1 = bf16[480,16]{1,0} fusion(triton_gemm_r), kind=kLoop, calls=fused_computation })"; @@ -2746,10 +3370,6 @@ ENTRY e { } TEST_F(CompareTest, SplitKBatch) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextRef = R"( HloModule m, is_scheduled=true @@ -2767,7 +3387,10 @@ ENTRY e { tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) ROOT triton_gemm_dot.24 = f32[5,128,700]{2,1,0} fusion(tmp_3, tmp_0), kind=kCustom, calls=triton_gemm_dot.24, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":1,"num_stages":2,"num_warps":8}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":8, + "num_ctas":1}}} })"; const std::string kHloTextSplitK = R"( @@ -2795,7 +3418,10 @@ ENTRY e { tmp_0 = bf16[1,1,800,5,128]{4,3,2,1,0} parameter(1) triton_gemm_dot.24 = f32[8,5,128,700]{3,2,1,0} fusion(tmp_3, tmp_0), kind=kCustom, calls=triton_gemm_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64,"split_k":8,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":64, + "split_k":8,"num_stages":1,"num_warps":4, + "num_ctas":1}}} constant = f32[] constant(0) ROOT reduce = f32[5,128,700]{2,1,0} reduce(triton_gemm_dot.24, constant), dimensions={0}, to_apply=add })"; @@ -2806,10 +3432,6 @@ ENTRY e { } TEST_F(CompareTest, SplitKNontrivialBitcast) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextRef = R"( HloModule module, is_scheduled=true @@ -2828,7 +3450,10 @@ ENTRY entry { parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) ROOT triton_gemm_dot.5316 = bf16[16,96]{1,0} fusion(bitcast.6, parameter_1.1), kind=kCustom, calls=triton_gemm_dot.5316, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256,"split_k":1,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":256, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; const std::string kHloTextSplitK = R"( @@ -2868,7 +3493,10 @@ ENTRY entry { parameter_1.1 = bf16[16,4,128]{2,1,0} parameter(1) triton_gemm_dot.5316 = bf16[16,16,96]{2,1,0} fusion(bitcast.6, parameter_1.1), kind=kCustom, calls=triton_gemm_dot.5316, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32,"split_k":16,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":32,"block_k":32, + "split_k":16,"num_stages":1,"num_warps":4, + "num_ctas":1}}} ROOT fusion.1 = bf16[16,96]{1,0} fusion(triton_gemm_dot.5316), kind=kLoop, calls=fused_computation })"; @@ -2878,7 +3506,7 @@ ENTRY entry { /*run_hlo_passes=*/false)); } -// This is based on gemm_rewriter_triton_test.cc/SplitKTest.SupportsIndivisible. +// This is based on gemm_fusion_test.cc/SplitKTest.SupportsIndivisible. // // There were relatively large numeric errors with an f16 temporary buffer, so I // ended up using --xla_gpu_triton_gemm_disable_reduced_precision_reduction=true @@ -2900,7 +3528,11 @@ triton_gemm_dot.clone { ENTRY entry_computation { p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) p1 = f16[16,129]{1,0} parameter(1) - ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256","split_k":"1","num_stages":"1","num_warps":"4"}} + ROOT fusion = f16[480,16]{1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -2941,7 +3573,11 @@ fused_computation { ENTRY entry_computation { p0 = s8[3,129,5,32]{3,2,1,0} parameter(0) p1 = f16[16,129]{1,0} parameter(1) - fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64","split_k":"2","num_stages":"1","num_warps":"8"}} + fusion = f32[2,480,16]{2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"64", + "split_k":"2","num_stages":"1","num_warps":"8", + "num_ctas":"1"}}} ROOT fusion.1 = f16[480,16]{1,0} fusion(fusion), kind=kLoop, calls=fused_computation } )"; @@ -2969,7 +3605,10 @@ ENTRY entry_computation { p1 = f16[1,1023,128]{2,1,0} parameter(1) ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, - backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4"}} + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"128","block_n":"128","block_k":"32", + "split_k":"1","num_stages":"4","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -3017,7 +3656,12 @@ fused_computation.1 { ENTRY entry_computation { p0 = f16[1,8,4,1023]{3,2,1,0} parameter(0) p1 = f16[1,1023,128]{2,1,0} parameter(1) - triton_gemm_dot.7103 = f16[8,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"128","block_k":"32","split_k":"8","num_stages":"1","num_warps":"4"}} + triton_gemm_dot.7103 = f16[8,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"16","block_n":"128","block_k":"32", + "split_k":"8","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 } )"; @@ -3043,7 +3687,12 @@ triton_gemm_dot.7103_computation.clone { ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256","split_k":"1","num_stages":"1","num_warps":"4"}} + ROOT triton_gemm_dot.7103 = f16[1,8,4,128]{3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"256", + "split_k":"1","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} } )"; @@ -3091,7 +3740,12 @@ fused_computation.1 { ENTRY entry_computation { p0 = f16[1,8,4,1019]{3,2,1,0} parameter(0) p1 = f16[1,1019,128]{2,1,0} parameter(1) - triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.7103_computation.clone, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32","split_k":"16","num_stages":"1","num_warps":"4"}} + triton_gemm_dot.7103 = f16[16,1,8,4,128]{4,3,2,1,0} fusion(p0, p1), kind=kCustom, + calls=triton_gemm_dot.7103_computation.clone, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", + "triton_gemm_config":{"block_m":"64","block_n":"32","block_k":"32", + "split_k":"16","num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} ROOT fusion.1 = f16[1,8,4,128]{3,2,1,0} fusion(triton_gemm_dot.7103), kind=kLoop, calls=fused_computation.1 } )"; @@ -3119,7 +3773,10 @@ ENTRY e { p1 = f32[32,50,104]{2,1,0} parameter(1) ROOT triton_gemm_dot.6 = f32[32,50,26]{2,0,1} fusion(p0, p1), kind=kCustom, calls=triton_gemm_dot.6, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":64,"block_n":16,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} })"; const std::string kHloTextRef = R"( @@ -3145,7 +3802,10 @@ ENTRY e { %parameter_1 = f32[32,50,104]{2,1,0} parameter(1) %triton_gemm_dot.127 = f32[32,50,26]{2,1,0} fusion(%parameter_0, %parameter_1), kind=kCustom, calls=%triton_gemm_dot.127, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64,"split_k":1,"num_stages":2,"num_warps":4}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":128,"block_k":64, + "split_k":1,"num_stages":2,"num_warps":4, + "num_ctas":1}}} ROOT %fusion.1 = f32[32,50,26]{2,0,1} fusion(%triton_gemm_dot.127), kind=kLoop, calls=%fused_computation })"; @@ -3171,10 +3831,11 @@ ENTRY e { p0 = f32[92,11]{1,0} parameter(0) ROOT triton_gemm__ = f32[63,92]{1,0} fusion(p0), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3186,7 +3847,7 @@ ENTRY e { broadcast.2 = f32[11,63]{1,0} broadcast(constant_2), dimensions={} gemm = (f32[63,92]{1,0}, s8[0]{0}) custom-call(broadcast.2, parameter_0), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[63,92]{1,0} get-tuple-element((f32[63,92]{1,0}, s8[0]{0}) gemm), index=0 })"; @@ -3211,10 +3872,11 @@ triton_gemm___computation { ENTRY e { ROOT triton_gemm__ = f32[11,45]{1,0} fusion(), kind=kCustom, calls=triton_gemm___computation, - backend_config={"kind":"__triton_gemm", + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"16","block_n":"64", "block_k":"16","split_k":"1", - "num_stages":"3","num_warps":"2"}} + "num_stages":"3","num_warps":"2", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3227,7 +3889,7 @@ ENTRY triton_gemm___computation { broadcast.1 = f32[63,45]{1,0} broadcast(constant_1), dimensions={} gemm = (f32[11,45]{1,0}, s8[0]{0}) custom-call(broadcast, broadcast.1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[11,45]{1,0} get-tuple-element((f32[11,45]{1,0}, s8[0]{0}) gemm), index=0 })"; @@ -3280,10 +3942,11 @@ ENTRY e { tmp_12 = f32[3,57]{1,0} parameter(11) ROOT r = f32[32,57]{0,1} fusion(tmp_1, tmp_2, tmp_3, tmp_5, tmp_7, tmp_14, tmp_15, tmp_16, tmp_18, tmp_9, tmp_10, tmp_12), kind=kCustom, calls=triton_gemm_dot_computation, - backend_config={"kind":"__triton_gemm", + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"64", "block_k":"64","split_k":"1", - "num_stages":"1","num_warps":"4"}} + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3333,7 +3996,7 @@ ENTRY e { fusion = f32[3,57]{1,0} fusion(tmp_18, tmp_14, tmp_15, tmp_16, tmp_12, /*index=5*/tmp_9, tmp_10), kind=kLoop, calls=fused_computation gemm = (f32[32,57]{0,1}, s8[0]{0}) custom-call(fusion.1, fusion), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[32,57]{0,1} get-tuple-element((f32[32,57]{0,1}, s8[0]{0}) gemm), index=0 })"; @@ -3343,10 +4006,6 @@ ENTRY e { } TEST_F(CompareTest, PredToBF16ConversionWorks) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -3366,10 +4025,11 @@ ENTRY e { p2 = s32[11,63]{1,0} parameter(2) ROOT triton_gemm__ = bf16[92,63]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_gemm_computation, - backend_config={"kind":"__triton_gemm", + backend_config={"fusion_backend_config": {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"32","block_n":"16", "block_k":"32","split_k":"1", - "num_stages":"1","num_warps":"4"}} + "num_stages":"1","num_warps":"4", + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3389,11 +4049,11 @@ ENTRY e { fusion = bf16[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation gemm = (bf16[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers": + backend_config={"gemm_backend_config": {"alpha_real":1,"beta":0,"dot_dimension_numbers": {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"], "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}, "alpha_imag":0,"precision_config": - {"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + {"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = bf16[92,63]{1,0} get-tuple-element((bf16[92,63]{1,0}, s8[0]{0}) gemm), index=0 })"; @@ -3426,8 +4086,10 @@ ENTRY e { p2 = f16[32,32]{1,0} parameter(2) ROOT r = f16[9,32]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", - triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}} + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":"1"}}} })"; const std::string kHloTextRef = R"( @@ -3464,10 +4126,6 @@ class TritonGemmContractionDims : public TritonGemmTest { }; TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -3490,10 +4148,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -3516,10 +4170,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -3543,10 +4193,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -3563,10 +4209,760 @@ ENTRY e { ->fused_instructions_computation() ->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}), - m::Op().WithShape(BF16, {40, 32}, {1, 0})) + m::Op().WithShape(BF16, {32, 40}, {1, 0})) .WithShape(BF16, {16, 40}, {1, 0}))); } +// In these tests, we depend on "algorithm" annotations for selecting the 6XBF16 +// algorithm. +class Triton6xBF16GemmTest : public TritonFilecheckTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + // These 2 flags are not strictly necessary now, but we're adding them to be + // on the safe side against future flakiness. + // + // Enable triton fusion for all supported gemms. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +// In these tests, we depend on debug option flags for selecting the 6XBF16 +// algorithm. +// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_6way_gemm +// flag after we will support the algorithm values through the entire stack. +class Triton6xBF16GemmTestWithFlag : public TritonFilecheckTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + // Enable triton fusion for all supported gemms. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Enable bf16_6way gemm to compute F32 matmul. + debug_options.set_xla_gpu_enable_bf16_6way_gemm(true); + return debug_options; + } +}; + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,2048] parameter(0) + p1 = f32[2048,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[5,2048]{1,0} parameter(0) + p1 = f32[2048,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(64, 32, 32, 1, 1, 4); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +// Test case shows that why we truncate the middle term instead of rounding. +// If we round the middle term, the splitted terms may disagree in sign. This +// could result in wrong results for extreme values. +// For example, consider: +// x = -3.40282347e+38 +// If we round the middle term, its decomposition would be: +// x_hi: -3.38953139e+38 +// x_mid: -1.3240357e+36 +// x_lo: 5.17201445e+33 +// The result of x*x would be NaN instead of positive infinity. +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForInputsWithLargeExponent) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + constexpr float kLargeExponentFloat = 0x1.0103p72f; + arguments[0] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { + const char* kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,32] parameter(0) + p1 = f32[32,7] parameter(1) + ROOT dot = f32[5,7] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 +CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 +)"); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +// In these tests, we depend on "algorithm" annotations for selecting the 3XBF16 +// algorithm. +class Triton3xBF16GemmTest : public TritonFilecheckTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + // These 2 flags are not strictly necessary now, but we're adding them the + // to be on the safe side against future flakiness. + // + // Enable triton fusion for all supported gemms. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +// In these tests, we depend on debug option flags for selecting the 3XBF16 +// algorithm. +// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_3way_gemm +// flag after we will support the algorithm values through the entire stack. +class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonFilecheckTest::GetDebugOptionsForTest(); + // Enable triton fusion for all supported gemms. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Enable bf16_3way gemm to compute F32 matmul. + debug_options.set_xla_gpu_enable_bf16_3way_gemm(true); + return debug_options; + } +}; + +TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f16[5,7] parameter(0) + p1 = f16[7,33] parameter(1) + ROOT dot = f16[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[5,7]{1,0} parameter(0) + p1 = f16[7,33]{1,0} parameter(1) + ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK: tt.dot +CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32> +CHECK-NOT: tt.dot + )")); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[5,2048] parameter(0) + p1 = f32[2048,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[5,2048]{1,0} parameter(0) + p1 = f32[2048,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(64, 32, 32, 1, 1, 4); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-4, + /*arel=*/1e-4})); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleInfinity) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleNaN) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForInputsWithLargeExponent) { + const char* kHloText = R"( +HloModule t + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TritonGemmConfig config(32, 32, 32, 1, 1, 1); + TF_ASSERT_OK( + CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + constexpr float kLargeExponentFloat = 0x1.0103p72f; + arguments[0] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +// This test could be modified to allow TF32 once this bug is fixed. +// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. +TEST_F(TritonFilecheckTest, NoTF32For8BitOrLessWithF32) { + const std::string hlo_text = R"( +HloModule t + +triton_dot { + parameter_0 = s32[11,24]{1,0} parameter(0) + broadcast.1747 = s32[11,24,128]{2,1,0} broadcast(parameter_0), + dimensions={0,1} parameter_1 = s32[11,24,128]{2,1,0} parameter(1) + compare.49 = pred[11,24,128]{2,1,0} compare(broadcast.1747, parameter_1), + direction=EQ bitcast.4717 = pred[264,128]{1,0} bitcast(compare.49) + convert.142 = f32[264,128]{1,0} convert(bitcast.4717) + parameter_2 = f32[128,8]{1,0} parameter(2) + ROOT dot.381 = f32[264,8]{1,0} dot(convert.142, parameter_2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s32[11,24]{1,0} parameter(0) + p1 = s32[11,24,128]{2,1,0} parameter(1) + p2 = f32[128,8]{1,0} parameter(2) + ROOT _ = f32[264,8] fusion(p0, p1, p2), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":16,"block_k":128, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + TritonGemmConfig config(32, 16, 128, 1, 1, 4); + ASSERT_OK( + CreateTritonIrAndFileCheck(hlo_text, config, EmitMatMul, "triton_dot", R"( +CHECK: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false + )")); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { + const char* kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,32] parameter(0) + p1 = f32[32,7] parameter(1) + ROOT dot = f32[5,7] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 +CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 +)"); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +using TritonEmitterTest = TritonGemmTest; + +TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { + const std::string kHloText = R"( +HloModule module, is_scheduled=true + +triton_gemm_dot { + p0 = f32[10,20] parameter(0) + p1 = f32[20,30] parameter(1) + ROOT dot = f32[10,30] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = f32[10,20] parameter(0) + p1 = f32[20,30] parameter(1) + ROOT r = f32[10,30] fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloText)); + const HloComputation* triton_dot_computation = + hlo_module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + llvm::LLVMContext llvm_ctx; + llvm::Module llvm_module("module", llvm_ctx); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), + "test_fn", triton_dot_computation, + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, + /*minor=*/0}, + dev_info, TritonGemmConfig{}, &llvm_module, &EmitMatMul, + mlir_context), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 9e25ad8bddd53..d18037b3be94e 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -1,4 +1,4 @@ -/*Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -30,9 +29,9 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -47,6 +46,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -55,20 +55,17 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project @@ -81,302 +78,121 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_query.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/transforms/gpu_passes.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/service/global_device_id.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/conditional_thunk.h" -#include "xla/service/gpu/convolution_thunk.h" -#include "xla/service/gpu/copy_thunk.h" -#include "xla/service/gpu/for_thunk.h" -#include "xla/service/gpu/fused_mha_thunk.h" +#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/fusions/input_slices.h" -#include "xla/service/gpu/fusions/loop.h" -#include "xla/service/gpu/fusions/reduction.h" #include "xla/service/gpu/fusions/thunk_util.h" -#include "xla/service/gpu/fusions/transpose.h" -#include "xla/service/gpu/gemm_thunk.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/gpu_flash_attn.h" #include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/gpu_norm_runner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/infeed_thunk.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/kernel_arguments.h" -#include "xla/service/gpu/kernel_thunk.h" -#include "xla/service/gpu/kernels/custom_fusion.h" +#include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/topk_custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" -#include "xla/service/gpu/nccl_collective_permute_thunk.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/norm_thunk.h" -#include "xla/service/gpu/outfeed_thunk.h" #include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/gpu/replica_id_thunk.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" -#include "xla/service/gpu/runtime3/custom_call_thunk.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/gpu/while_thunk.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/command_buffer_cmd_emitter.h" +#include "xla/service/gpu/runtime/command_buffer_thunk.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/convolution_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/fft_thunk.h" +#include "xla/service/gpu/runtime/flash_attn_thunk.h" +#include "xla/service/gpu/runtime/fused_mha_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/infeed_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" +#include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_recv_thunk.h" +#include "xla/service/gpu/runtime/nccl_send_thunk.h" +#include "xla/service/gpu/runtime/norm_thunk.h" +#include "xla/service/gpu/runtime/outfeed_thunk.h" +#include "xla/service/gpu/runtime/replica_id_thunk.h" +#include "xla/service/gpu/runtime/send_recv_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/runtime/wait_for_streams_thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" +#include "xla/service/gpu/triton_call.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" +#include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/loop_emitter.h" #include "xla/service/llvm_ir/sort_util.h" #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/human_readable_json.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/dnn.pb.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #if GOOGLE_CUDA || TF_HIPBLASLT -#include "xla/service/gpu/gpublas_lt_matmul_thunk.h" +#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/cub_sort_thunk.h" #include "xla/service/gpu/ir_emitter_triton.h" -#include "xla/service/gpu/runtime3/cholesky_thunk.h" -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" +#include "xla/service/gpu/runtime/cholesky_thunk.h" +#include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { namespace { -// Some HLO operations are not implemented as Thunks, and only available when -// XLA:GPU compiled for XLA runtime. However we still depend on emitting thunk -// sequence during compilation, and for unsupported operations we emit -// unreachable thunk, which is not supposed to be executed, and exists only -// during compilation as we transition from thunks to XLA runtime. -// -// Examples: Point-to-point communication operations (Send and Recv) are only -// available as XLA runtime custom calls. API_VERSION_TYPED_FFI custom calls -// are only implemented when executing with XLA runtime. -class UnreachableThunk : public Thunk { - public: - UnreachableThunk(mlir::Operation* op, std::string error_message) - : Thunk(Kind::kKernel, ThunkInfo(op)), - error_message_(std::move(error_message)) {} - - UnreachableThunk(const UnreachableThunk&) = delete; - UnreachableThunk& operator=(const UnreachableThunk&) = delete; - - Status Initialize(se::StreamExecutor*, ExecutableSource) final { - return tsl::errors::Internal(error_message_); - } - - Status ExecuteOnStream(const ExecuteParams& params) final { - return tsl::errors::Internal(error_message_); - } - - private: - std::string error_message_; -}; - -StatusOr AsCudnnfMHAKind( - mlir::lmhlo_gpu::FusedMhaDagSignature signature) { - switch (signature) { - case mlir::lmhlo_gpu::FusedMhaDagSignature::Default: - return xla::gpu::CudnnfMHAKind::kBmmBmm; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::Softmax: - return xla::gpu::CudnnfMHAKind::kSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout; - default: - return xla::InternalError("Unsupported fused_mha_dag_signature"); - } -} - -StatusOr AsCudnnBackwardfMHAKind( - mlir::lmhlo_gpu::FusedMhaBackwardDagSignature signature) { - switch (signature) { - // backward - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; - break; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; - break; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; - break; - default: - return xla::InternalError("Unsupported fused_mha_backward_dag_signature"); - } -} - -// Builds a thunk that calls a new or reused kernel for a fusion operation. -// -// The caller must specify the same launch dimensions for fusions which have -// the same computation. -// -// If a given fusion is implemented using multiple kernels, then for each -// kernel we should provide a discriminator, such as "init" and "impl". -// -// The builder_fn is only invoked if the kernel couldn't be reused. -// -// This is the typical usage pattern of this method: -// -// ``` -// auto builder_fn = [](std::vector inputs, -// std::vector outputs) { ... }; -// TF_ASSIGN_OR_RETURN( -// auto thunk, -// BuildKernelThunkForFusion(..., fusion_op, launch_dimensions, builder_fn, -// ...)); -// AddThunkToThunkSequence(std::move(thunk)) -// ``` -StatusOr> BuildKernelThunkForFusion( - IrEmitterContext& ir_emitter_context, KernelReuseCache& kernel_cache, - const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, - const HloComputation* fused_computation, - const LaunchDimensions& launch_dimensions, absl::string_view discriminator, - std::function, - std::vector)> - kernel_builder_fn, - llvm::IRBuilder<>* builder) { - std::string suggested_kernel_name = std::string(fusion->name()); - - TF_ASSIGN_OR_RETURN(auto kernel_arguments, - ir_emitter_context.emit_ir_from_hlo() - ? KernelArguments::Create( - ir_emitter_context.buffer_assignment(), fusion) - : KernelArguments::Create( - ir_emitter_context.allocations(), fusion_op)); - - auto kernel_builder_status = OkStatus(); - auto [entry, cached] = kernel_cache.Get( - fused_computation, kernel_arguments.args(), discriminator, - [&]() -> KernelReuseCache::Entry { - auto [kernel, input_arrays, output_arrays] = BuildKernelPrototype( - ir_emitter_context, suggested_kernel_name, kernel_arguments.args(), - fusion->operand_count(), launch_dimensions, builder); - kernel_builder_status = kernel_builder_fn(input_arrays, output_arrays); - return {kernel->getName().str(), launch_dimensions}; - }); - TF_RETURN_IF_ERROR(kernel_builder_status); - if (cached) { - VLOG(3) << "Reuse: " << suggested_kernel_name << " -> " - << entry.kernel_name; - } - - std::variant op; - if (ir_emitter_context.emit_ir_from_hlo()) { - op = fusion; - } else { - op = fusion_op; - } - - return std::make_unique( - op, entry.kernel_name, kernel_arguments.args(), launch_dimensions, - /*shmem_bytes=*/0); -} - -StatusOr> BuildCustomKernelThunkForFusion( - IrEmitterContext& ir_emitter_context, const HloFusionInstruction* fusion, - CustomKernel custom_kernel) { - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - KernelArguments::Create(ir_emitter_context.buffer_assignment(), fusion)); - - return std::make_unique( - fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); -} - -// Derives the number of warps to use for processing a Triton Softmax fusion. -int DeriveNumWarpsFromTritonSoftmaxComputation( - const HloComputation* computation) { - const HloInstruction* reduce = hlo_query::GetFirstInstructionWithOpcode( - *computation, HloOpcode::kReduce); - - CHECK_NE(reduce, nullptr); - Shape reduce_input_shape = reduce->operand(0)->shape(); - - CHECK_EQ(reduce->dimensions().size(), 1); - CHECK_EQ(reduce->dimensions()[0], reduce_input_shape.rank() - 1); - - int reduction_dim = reduce_input_shape.dimensions_minor(0); - - int num_warps = 32; - - if (reduction_dim <= 512) { - num_warps = 1; - } else if (reduction_dim <= 1024) { - num_warps = 2; - } else if (reduction_dim <= 16384) { - num_warps = 4; - } else if (reduction_dim <= 32768) { - num_warps = 8; - } else if (reduction_dim <= 65536) { - num_warps = 16; - } - - return num_warps; +// Construct the key for looking up the AsyncEvents for Send and Recv. Input +// kind is the thunk kind for the corresponding done thunk. +inline std::pair GetSendRecvAsyncEventsKey(Thunk::Kind kind, + int64_t channel_id) { + return std::make_pair(kind == Thunk::Kind::kNcclRecvDone, channel_id); } } // namespace IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context) : IrEmitter(ir_emitter_context, /*is_nested=*/false), + send_recv_events_(std::make_shared()), elemental_emitter_(*ir_emitter_context, &b_) {} std::unique_ptr IrEmitterUnnested::Create( @@ -385,100 +201,60 @@ std::unique_ptr IrEmitterUnnested::Create( new IrEmitterUnnested(ir_emitter_context)); } -StatusOr IrEmitterUnnested::GetAllocationSlice( - mlir::Value v) { - return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(), - nullptr); -} - -StatusOr> -IrEmitterUnnested::GetAllocationSlices(mlir::OperandRange operands) { - std::vector slices; - slices.reserve(operands.size()); - for (mlir::Value operand : operands) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); - slices.push_back(slice); - } - return slices; -} - -Status IrEmitterUnnested::EmitUnreachable(mlir::Operation* op, - std::string error_message) { - AddThunkToThunkSequence(std::unique_ptr( - new UnreachableThunk(op, std::move(error_message)))); - return OkStatus(); -} - -Status IrEmitterUnnested::EmitConstant(mlir::Operation* op, - const Literal& literal) { - auto get_global = mlir::cast(op); - auto module = get_global->getParentOfType(); - auto global = mlir::cast( - module.lookupSymbol(get_global.getName())); +absl::Status IrEmitterUnnested::EmitConstant( + const HloConstantInstruction* instr) { TF_ASSIGN_OR_RETURN(DenseDataIntermediate content, - LiteralToXlaFormat(literal)); + LiteralToXlaFormat(instr->literal())); - int element_bytes = primitive_util::ByteWidth(literal.shape().element_type()); + int element_bytes = + primitive_util::ByteWidth(instr->literal().shape().element_type()); TF_RET_CHECK(content.span().size() % element_bytes == 0); // Treat int4 constant as int8 constant with half the number of elements. int num_elements = content.span().size() / element_bytes; - int64_t arg_index = - global->getAttrOfType("lmhlo.alloc").getInt(); - int allocation_index = ir_emitter_context_->allocations()[arg_index]->index(); + std::string global_name = llvm_ir::ConstantHloToGlobalName(*instr); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + GetAllocationSliceForHlo(instr, {})); - ir_emitter_context_->emit_constant(num_elements, element_bytes, - global.getSymName(), allocation_index, - std::move(content), &b_); - return OkStatus(); + ir_emitter_context_->emit_constant(num_elements, element_bytes, global_name, + slice.index(), std::move(content), &b_); + return absl::OkStatus(); } static ConditionalThunkConfig GetConditionalThunkConfig( - mlir::lmhlo::CaseOp op, std::vector branch_thunk_sequences) { + const HloInstruction* instr, + std::vector branch_thunk_sequences) { ConditionalThunkConfig config; - config.branch_index_is_bool = op.getIndex() - .getType() - .cast() - .getElementType() - .isInteger( - /*width=*/1); - config.branch_count = op.getBranches().size(); - // Pass nullptr as the HloInstruction* to the branch_thunks - // constructors because these SequentialThunks are logically "part of" - // this ConditionalThunk, and shouldn't be profiled separately from it. - config.branch_thunks.reserve(branch_thunk_sequences.size()); + config.branch_index_is_bool = + instr->operand(0)->shape().element_type() == PRED; + config.branch_count = instr->branch_count(); + config.branch_thunks.reserve(config.branch_count); for (auto& branch_thunk_sequence : branch_thunk_sequences) { - config.branch_thunks.emplace_back(new SequentialThunk( - Thunk::ThunkInfo(op), std::move(branch_thunk_sequence))); + config.branch_thunks.emplace_back( + new SequentialThunk(Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(branch_thunk_sequence))); } return config; } -Status IrEmitterUnnested::EmitConditional( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto conditional = mlir::cast(op); - +Status IrEmitterUnnested::EmitConditional(const HloInstruction* instr) { std::vector branch_thunks; + branch_thunks.reserve(instr->branch_count()); - int branch_count = conditional.getBranches().size(); - branch_thunks.reserve(branch_count); - - for (int j = 0; j < branch_count; ++j) { - mlir::Region* branch_computation = &conditional.getBranches()[j]; + for (auto comp : instr->branch_computations()) { auto ir_emitter = IrEmitterUnnested::Create(ir_emitter_context_); - TF_RETURN_IF_ERROR( - ir_emitter->EmitLmhloRegion(branch_computation, hlo_for_lmhlo)); + TF_RETURN_IF_ERROR(ir_emitter->EmitHloComputation(comp)); branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); } ConditionalThunkConfig config = - GetConditionalThunkConfig(conditional, std::move(branch_thunks)); + GetConditionalThunkConfig(instr, std::move(branch_thunks)); - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.getIndex())); - AddThunkToThunkSequence(std::unique_ptr(new ConditionalThunk( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), slice))); + TF_ASSIGN_OR_RETURN(auto slice, + GetAllocationSliceForHlo(instr->operand(0), {})); + AddThunkToThunkSequence(std::unique_ptr( + new ConditionalThunk(Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(config), slice))); return OkStatus(); } @@ -538,23 +314,20 @@ void IrEmitterUnnested::CreateStore(llvm::Value* data, llvm::Value* address, // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} -Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { - // TODO(jurahul): Create an op to represent PadToStatic. - auto pad_to_static = mlir::cast(op); +absl::Status IrEmitterUnnested::EmitPadToStatic( + const HloCustomCallInstruction* instr) { int unroll_factor = 1; - std::string ir_name = GetIrNameFromLoc(pad_to_static.getLoc()); + std::string ir_name = std::string(instr->name()); - const Shape& input_shape = GetShape(pad_to_static.getArgs().front()); + const Shape& input_shape = instr->operand(0)->shape(); - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - input_shape, ir_emitter_context_->gpu_device_info(), - {unroll_factor})); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + input_shape, ir_emitter_context_->gpu_device_info(), {unroll_factor}); std::vector input_arrays; std::vector output_arrays; - TF_ASSIGN_OR_RETURN( - std::tie(input_arrays, output_arrays), - BuildKernelThunkForNonFusionOp(pad_to_static, launch_dimensions)); + TF_ASSIGN_OR_RETURN(std::tie(input_arrays, output_arrays), + BuildKernelThunkForNonFusionOp(instr, instr->operands(), + launch_dimensions)); CHECK_EQ(output_arrays.size(), 0); const llvm_ir::IrArray source_array = input_arrays[0]; @@ -562,8 +335,8 @@ Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { auto output_dim_arrays = absl::Span(input_arrays).subspan(2); - llvm::Type* index_ty = GetIndexTypeForKernel( - pad_to_static, launch_dimensions.launch_bound(), &b_); + llvm::Type* index_ty = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); // pseudo code for PadToStatic on a 2d array // int* source_array = input[0]; @@ -581,10 +354,13 @@ Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int); std::vector dynamic_dims; int alignment = raw_data_size % sizeof(int32_t); - for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) { + std::vector output_shapes = + ShapeUtil::GetLeafShapes(instr->shape()); + + for (int64_t i = 1; i < output_shapes.size(); ++i) { // Dynamic size of each dimension is attached at the end of the source // array(operand(0)). We need to extract these value. - const Shape& dim_shape = GetShape(pad_to_static.getOutput()[i]); + const Shape& dim_shape = output_shapes[i].shape; TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); const int64_t dim_index = i - 1; @@ -604,7 +380,7 @@ Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // *output[2] = *dyn_dim1_size; // } KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] { - for (int64_t i = 1; i < pad_to_static.getOutput().size(); ++i) { + for (int64_t i = 1; i < output_shapes.size(); ++i) { const int64_t dim_index = i - 1; llvm::Value* dest_dim_size_address = output_dim_arrays[dim_index].GetBasePointer(); @@ -637,7 +413,7 @@ Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { // } // } llvm_ir::BodyEmitter body_generator = - [&](const llvm_ir::IrArray::Index& array_index) -> Status { + [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status { llvm::Value* linearIndex = array_index.Linearize(input_shape.dimensions(), &b_); auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse( @@ -651,40 +427,38 @@ Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) { dyn_index, source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_, /*use_linear_index=*/false); - return OkStatus(); + return absl::OkStatus(); }; - const Shape& data_shape = GetShape(pad_to_static.getOutput().front()); + const Shape& data_shape = instr->shape().tuple_shapes(0); TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_, {unroll_factor}) .EmitLoop(ir_name, index_ty)); - return OkStatus(); + return absl::OkStatus(); } // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} -Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { +absl::Status IrEmitterUnnested::EmitSliceToDynamic( + const HloCustomCallInstruction* instr) { // TODO(jurahul): Create an op to represent SliceToDynamic. - auto slice_to_dynamic = mlir::cast(op); int unroll_factor = 1; - std::string ir_name = GetIrNameFromLoc(slice_to_dynamic.getLoc()); + std::string ir_name = std::string(instr->name()); - const Shape& input_shape = GetShape(slice_to_dynamic.getArgs().front()); + const Shape& input_shape = instr->operand(0)->shape(); - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - input_shape, ir_emitter_context_->gpu_device_info(), - {unroll_factor})); - llvm::Type* index_ty = GetIndexTypeForKernel( - slice_to_dynamic, launch_dimensions.launch_bound(), &b_); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + input_shape, ir_emitter_context_->gpu_device_info(), {unroll_factor}); + llvm::Type* index_ty = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); std::vector input_arrays, output_arrays; - TF_ASSIGN_OR_RETURN( - std::tie(input_arrays, output_arrays), - BuildKernelThunkForNonFusionOp(slice_to_dynamic, launch_dimensions)); + TF_ASSIGN_OR_RETURN(std::tie(input_arrays, output_arrays), + BuildKernelThunkForNonFusionOp(instr, instr->operands(), + launch_dimensions)); - TF_RET_CHECK(slice_to_dynamic.getOutput().size() == 1); - const Shape& data_shape = GetShape(slice_to_dynamic.getOutput().front()); + const Shape& data_shape = ShapeUtil::MakeStaticShape(instr->shape()); + TF_RET_CHECK(data_shape.IsArray()); // TODO(jurahul): data_shape here is the static shape of the output (which has // a dynamic shape in XLA). Currently, we are mapping that to a static shaped @@ -706,7 +480,7 @@ Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { // Load dynamic dimensions from memory. std::vector dynamic_dims; int alignment = raw_data_size % sizeof(int32_t); - for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) { + for (int64_t i = 1; i < instr->operand_count(); ++i) { llvm::Value* source_buffer = input_arrays[i].GetBasePointer(); llvm::Type* source_buffer_pointee_type = input_arrays[i].GetBasePointeeType(); @@ -723,7 +497,7 @@ Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { // *dyn_dim1_size = *output[2]; // } KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] { - for (int64_t i = 1; i < slice_to_dynamic.getArgs().size(); ++i) { + for (int64_t i = 1; i < instr->operand_count(); ++i) { const int64_t dim_index = i - 1; llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( b_.getInt8Ty(), dest_buffer, @@ -757,7 +531,7 @@ Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { // } // } llvm_ir::BodyEmitter body_generator = - [&](const llvm_ir::IrArray::Index& array_index) -> Status { + [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status { llvm::Value* linearIndex = array_index.Linearize(input_shape.dimensions(), &b_); auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse( @@ -773,17 +547,18 @@ Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) { input_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"", /*use_linear_index=*/false), &b_); - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_, {unroll_factor}) .EmitLoop(ir_name, index_ty)); - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitCommandBufferThunk(const HloInstruction* instr) { +absl::Status IrEmitterUnnested::EmitCommandBufferThunk( + const HloInstruction* instr) { // Spawn a new IrEmitterUnnested to emit thunks for the command buffer // computation. Then convert emitted thunks to a sequence of CommandBufferCmd. // The resulting thunk added to the thunk sequence is a CommandBufferThunk. @@ -794,786 +569,901 @@ Status IrEmitterUnnested::EmitCommandBufferThunk(const HloInstruction* instr) { TF_RETURN_IF_ERROR(ir_emitter->EmitHloComputation(command_buffer)); std::unique_ptr thunk_sequence = ir_emitter->ConsumeThunkSequence(); + + // Maybe serialize all commands in a sequence by forcing barriers between all + // recorded commands. This guarantees that we execute all device operations + // in the exact same order as a thunk sequence. + CommandBufferCmdSequence::SynchronizationMode synchronization_mode = + ir_emitter_context_->debug_options() + .xla_gpu_graph_enable_concurrent_region() + ? CommandBufferCmdSequence::SynchronizationMode::kAutomatic + : CommandBufferCmdSequence::SynchronizationMode::kSerialize; + TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence cmd_sequence, - ConvertToCommands(*thunk_sequence)); + ConvertToCommands(*thunk_sequence, synchronization_mode)); + AddThunkToThunkSequence(std::make_unique( - std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr))); - return OkStatus(); + std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(*thunk_sequence))); + + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::Activation; - using mlir::lmhlo_gpu::ConvBackwardFilterOp; - using mlir::lmhlo_gpu::ConvBackwardInputOp; - using mlir::lmhlo_gpu::ConvForwardFusedOp; - using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; - using mlir::lmhlo_gpu::ConvForwardGraphOp; - using mlir::lmhlo_gpu::ConvForwardOp; - - std::vector operand_slices, result_slices; - int32_t n_aux_outputs = 0; - if (auto conv = dyn_cast(op)) { - n_aux_outputs = conv.getNAuxOutputs(); - } - int64_t num_operands = op->getNumOperands(); - operand_slices.reserve(num_operands - n_aux_outputs - 2); - - // The operands describe inputs, the main result of the convolution, the - // scratch workspace and n_aux_outputs return values of ops fused into the - // convolution. - for (mlir::Value operand : op->getOperands().drop_back(2 + n_aux_outputs)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand)); +absl::Status IrEmitterUnnested::EmitConvolutionThunk( + const HloCustomCallInstruction* instr) { + std::vector operand_slices; + operand_slices.reserve(instr->operand_count()); + for (const HloInstruction* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + GetAllocationSliceForHlo(operand, {})); operand_slices.push_back(slice); } - result_slices.reserve(1 + n_aux_outputs); - for (mlir::Value result : op->getOperands() - .drop_front(num_operands - n_aux_outputs - 2) - .drop_back(1)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(result)); - result_slices.push_back(slice); + // The first and the last element in the result tuple for a convolution are + // always the result and the scratch buffer. It may have auxiliary results in + // addition to the main result. + std::vector result_slices; + for (int i = 0; i < instr->shape().tuple_shapes_size() - 1; i++) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, + GetAllocationSliceForHlo(instr, {i})); + result_slices.push_back(result_slice); } - mlir::Value scratch_result = op->getOperand(num_operands - 1); - TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result)); - - auto apply_layout = [](const Shape& shape, - mlir::ArrayRef minor_to_major) { - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - GpuConvDescriptor descriptor; - - auto fill_conv_descriptor = [&](auto op) { - descriptor.operand0_shape = - apply_layout(GetShape(op->getOperand(0)), - op.getBackendConfig().getOperand_0Layout()); - descriptor.operand1_shape = - apply_layout(GetShape(op->getOperand(1)), - op.getBackendConfig().getOperand_1Layout()); - descriptor.result_shape = - apply_layout(GetShape(op->getOperand(num_operands - n_aux_outputs - 2)), - op.getBackendConfig().getResultLayout()); - descriptor.dnums = ConvertConvDimensionNumbers(op.getDimensionNumbers()); - descriptor.scratch_size = scratch_slice.size(); - mlir::DenseIntElementsAttr window_strides = op.getWindowStrides().value(); - mlir::DenseIntElementsAttr padding = op.getPadding().value(); - mlir::DenseIntElementsAttr lhs_dilation = op.getLhsDilation().value(); - mlir::DenseIntElementsAttr rhs_dilation = op.getRhsDilation().value(); - mlir::DenseElementsAttr window_reversal = op.getWindowReversal().value(); - for (auto index : llvm::seq(0, window_strides.getNumElements())) { - WindowDimension* dim = descriptor.window.add_dimensions(); - // Window size for a convolution is the same as the kernel size. - // Kernel size of the convolution is operand1_shape. We need to look at - // the convolution dimension numbers kernel spatial dimensions to get - // the window size. - int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); - dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); - dim->set_stride(window_strides.getValues()[index]); - dim->set_padding_low(padding.getValues()[index]); - dim->set_padding_high(padding.getValues()[index]); - dim->set_base_dilation(lhs_dilation.getValues()[index]); - dim->set_window_dilation(rhs_dilation.getValues()[index]); - dim->set_window_reversal(window_reversal.getValues()[index]); - } - descriptor.feature_group_count = op.getFeatureGroupCount(); - { - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(op.getBackendConfig().getAlgorithm()); - algorithm->set_math_type(op.getBackendConfig().getTensorOpsEnabled() - ? se::dnn::AlgorithmProto::TENSOR_OP_MATH - : se::dnn::AlgorithmProto::DEFAULT_MATH); - for (int i = 0; i < op.getBackendConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm - ->mutable_tuning_knobs())[op.getBackendConfig().getKnobIds()[i]] = - op.getBackendConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend( - op.getBackendConfig().getIsCudnnFrontend()); - auto workspace_size = op.getBackendConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - } - descriptor.backend_config.set_conv_result_scale( - op.getResultScale().convertToDouble()); - descriptor.backend_config.set_reordered_int8_nchw_vect( - op.getBackendConfig().getIsCudnnReorderedInt8()); - }; - auto set_activation_mode = [&](auto op) -> Status { - TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode, - ConvertConvActivationMode(op.getActivationMode())); - descriptor.backend_config.set_activation_mode(activation_mode); - return OkStatus(); - }; + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + TF_ASSIGN_OR_RETURN(auto gpu_config, + instr->backend_config()); + const CudnnConvBackendConfig& backend_config = + gpu_config.cudnn_conv_backend_config(); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + GetAllocationSliceForHlo( + instr, {instr->shape().tuple_shapes_size() - 1})); + GpuConvDescriptor descriptor = {kind, + backend_config, + instr->operand(0)->shape(), + instr->operand(1)->shape(), + instr->shape().tuple_shapes(0), + static_cast(scratch_slice.size()), + instr->window(), + instr->convolution_dimension_numbers(), + instr->feature_group_count()}; - if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForward; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kBackwardInput; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kBackwardFilter; - fill_conv_descriptor(conv); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardGraph; - fill_conv_descriptor(conv); - descriptor.backend_config.set_serialized_graph( - conv.getSerializedGraph().data()); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardActivation; - fill_conv_descriptor(conv); - TF_RETURN_IF_ERROR(set_activation_mode(conv)); - descriptor.backend_config.set_leakyrelu_alpha( - conv.getLeakyreluAlpha().convertToDouble()); - } else if (auto conv = dyn_cast(op)) { - descriptor.kind = CudnnConvKind::kForwardActivation; - fill_conv_descriptor(conv); - TF_RETURN_IF_ERROR(set_activation_mode(conv)); - descriptor.backend_config.set_side_input_scale( - conv.getSideInputScale().convertToDouble()); - } else { - return InternalError("EmitConvolutionThunk: Unexpected operation"); - } TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, "")); AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(config), std::move(operand_slices), std::move(result_slices), scratch_slice)); return OkStatus(); } -Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { - auto gemm = mlir::dyn_cast(op); - TF_RET_CHECK(gemm != nullptr); +absl::Status IrEmitterUnnested::EmitGemmThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a, + GetAllocationSliceForHlo(instr->operand(0), {})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b, + GetAllocationSliceForHlo(instr->operand(1), {})); + + // Result of a legacy cuBLAS custom call can be a tuple if we explicitly + // allocate workspace buffer in HLO. If result is an array, it means that + // workspace is not available, and cuBLAS will allocate its own workspace. + BufferAllocation::Slice c; + std::optional workspace; + + if (instr->shape().IsArray()) { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr, {})); + } else { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr, {0})); + TF_ASSIGN_OR_RETURN(workspace, GetAllocationSliceForHlo(instr, {1})); + } - TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA())); - TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB())); - TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC())); bool deterministic_ops = ir_emitter_context_->debug_options().xla_gpu_deterministic_ops(); - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); + TF_ASSIGN_OR_RETURN( + GemmConfig config, + GemmConfig::For(static_cast(instr))); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), a, b, c, - deterministic_ops); - + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(config), a, b, + c, workspace, deterministic_ops); AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } #if GOOGLE_CUDA || TF_HIPBLASLT -Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) { - auto matmul = mlir::dyn_cast(op); - TF_RET_CHECK(matmul != nullptr); +absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config(); + xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue(); - TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(matmul.getA())); - TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(matmul.getB())); - TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(matmul.getC())); - TF_ASSIGN_OR_RETURN(auto d, GetAllocationSlice(matmul.getD())); + TF_ASSIGN_OR_RETURN(bool has_vector_bias, + xla::gpu::gpublas_lt::EpilogueAddsVectorBias(epilogue)); + bool has_matrix_bias = config.beta() != 0; + + TF_RET_CHECK(instr->operand_count() == + 2 + int{has_matrix_bias} + int{has_vector_bias}); + + TF_ASSIGN_OR_RETURN( + bool has_aux_output, + xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue)); + xla::ShapeIndex output_index = + has_aux_output ? xla::ShapeIndex{0} : xla::ShapeIndex{}; + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a, + GetAllocationSliceForHlo(instr->operand(0))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b, + GetAllocationSliceForHlo(instr->operand(1))); + BufferAllocation::Slice c; + if (has_matrix_bias) { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr->operand(2))); + } else { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr, output_index)); + } + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d, + GetAllocationSliceForHlo(instr, output_index)); - BufferAllocation::Slice bias, a_scale, b_scale, c_scale, d_scale, d_amax; - if (matmul.getBias() != nullptr) { - TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); + BufferAllocation::Slice bias; + if (has_vector_bias) { + TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForHlo( + instr->operand(has_matrix_bias ? 3 : 2))); } BufferAllocation::Slice aux; - if (matmul.getAux() != nullptr) { - TF_ASSIGN_OR_RETURN(aux, GetAllocationSlice(matmul.getAux())); + if (has_aux_output) { + TF_ASSIGN_OR_RETURN(aux, GetAllocationSliceForHlo(instr, {1})); } - TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(auto epilogue, - gpublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); + TF_ASSIGN_OR_RETURN( + auto gemm_config, + GemmConfig::For(static_cast(instr))); + + // Use the first algorithm by default (i.e. fastest according to heuristics). + int64_t algorithm = + config.algorithm_case() == GemmBackendConfig::kSelectedAlgorithm + ? config.selected_algorithm() + : 0; + + BufferAllocation::Slice a_scale, b_scale, c_scale, d_scale, d_amax; + TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, + gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config), - epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale, + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), + blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); - AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } -#endif // GOOGLE_CUDA || TF_HIPBLASLT -#if GOOGLE_CUDA -Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(mlir::Operation* op) { - auto matmul = mlir::dyn_cast(op); - TF_RET_CHECK(matmul != nullptr); +absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( + const HloCustomCallInstruction* instr) { + TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 || + instr->operand_count() == 8); + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config(); + xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue(); + + TF_ASSIGN_OR_RETURN(bool has_vector_bias, + xla::gpu::gpublas_lt::EpilogueAddsVectorBias(epilogue)); + bool has_damax = instr->shape().IsTuple(); + xla::ShapeIndex output_index = + has_damax ? xla::ShapeIndex{0} : xla::ShapeIndex{}; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a, - GetAllocationSlice(matmul.getA())); + GetAllocationSliceForHlo(instr->operand(0))); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b, - GetAllocationSlice(matmul.getB())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice c, - GetAllocationSlice(matmul.getC())); + GetAllocationSliceForHlo(instr->operand(1))); + BufferAllocation::Slice c; + bool has_matrix_bias = config.beta() != 0; + if (has_matrix_bias) { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr->operand(2))); + } else { + TF_ASSIGN_OR_RETURN(c, GetAllocationSliceForHlo(instr, output_index)); + } TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d, - GetAllocationSlice(matmul.getD())); + GetAllocationSliceForHlo(instr, output_index)); + + int a_scale_index = has_matrix_bias ? 3 : 2; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_scale, - GetAllocationSlice(matmul.getAScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_scale, - GetAllocationSlice(matmul.getBScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice c_scale, - GetAllocationSlice(matmul.getCScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_scale, - GetAllocationSlice(matmul.getDScale())); - BufferAllocation::Slice d_amax, bias; - if (matmul.getDAmax() != nullptr) { - TF_ASSIGN_OR_RETURN(d_amax, GetAllocationSlice(matmul.getDAmax())); + GetAllocationSliceForHlo(instr->operand(a_scale_index))); + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice b_scale, + GetAllocationSliceForHlo(instr->operand(a_scale_index + 1))); +#if GOOGLE_CUDA + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice c_scale, + GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice d_scale, + GetAllocationSliceForHlo(instr->operand(a_scale_index + 3))); +#else // TENSORFLOW_USE_ROCM + BufferAllocation::Slice c_scale; + BufferAllocation::Slice d_scale; +#endif + + BufferAllocation::Slice bias; + if (has_vector_bias) { + TF_ASSIGN_OR_RETURN( + bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4))); } - if (matmul.getBias() != nullptr) { - TF_ASSIGN_OR_RETURN(bias, GetAllocationSlice(matmul.getBias())); + + BufferAllocation::Slice d_amax; + if (has_damax) { + TF_ASSIGN_OR_RETURN(d_amax, GetAllocationSliceForHlo(instr, {1})); } + TF_ASSIGN_OR_RETURN( + auto gemm_config, + GemmConfig::For(static_cast(instr))); + + // Use the first algorithm by default (i.e. fastest according to heuristics). + int64_t algorithm = + config.algorithm_case() == GemmBackendConfig::kSelectedAlgorithm + ? config.selected_algorithm() + : 0; + BufferAllocation::Slice aux; // Not used. - TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul)); - TF_ASSIGN_OR_RETURN(auto epilogue, - gpublas_lt::AsBlasLtEpilogue(matmul.getEpilogue())); + TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, + gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config), - epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale, + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config), + blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax); - AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } +#endif // GOOGLE_CUDA || TF_HIPBLASLT -Status IrEmitterUnnested::EmitConvolutionReorderThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::CudnnConvReorderFilterAndBiasOp; - using mlir::lmhlo_gpu::CudnnConvReorderFilterOp; - - std::vector operand_slices; - std::vector result_slices; - std::vector filter_dims; - - auto set_filter_data = [&](auto op) -> Status { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_input, - GetAllocationSlice(op.getFilterInput())); - operand_slices.push_back(filter_input); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, - GetAllocationSlice(op.getFilterOutput())); - result_slices.push_back(filter_output); - - auto filter_dims_values = op.getFilterDims().template getValues(); - filter_dims.assign(filter_dims_values.begin(), filter_dims_values.end()); - return OkStatus(); - }; - - if (auto reorder = dyn_cast(op)) { - TF_RETURN_IF_ERROR(set_filter_data(reorder)); - +#if GOOGLE_CUDA +absl::Status IrEmitterUnnested::EmitConvolutionReorderThunk( + const HloCustomCallInstruction* instr) { + bool has_bias = instr->operand_count() > 1; + Shape shape = has_bias ? instr->shape().tuple_shapes(0) : instr->shape(); + if (shape.rank() != 5 || shape.dimensions(4) != 32) { + return Internal("Unexpected shape for convolution reorder: %s", + instr->ToString()); + } + absl::InlinedVector filter_dims = { + shape.dimensions(0), shape.dimensions(1) * 32, shape.dimensions(2), + shape.dimensions(3)}; + + absl::InlinedVector operand_slices; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_input, + GetAllocationSliceForHlo(instr->operand(0))); + operand_slices.push_back(filter_input); + if (has_bias) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_input, - GetAllocationSlice(reorder.getBiasInput())); + GetAllocationSliceForHlo(instr->operand(1))); operand_slices.push_back(bias_input); + } + absl::InlinedVector result_slices; + if (has_bias) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, + GetAllocationSliceForHlo(instr, {0})); + result_slices.push_back(filter_output); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_output, - GetAllocationSlice(reorder.getBiasOutput())); + GetAllocationSliceForHlo(instr, {1})); result_slices.push_back(bias_output); - } else if (auto reorder = dyn_cast(op)) { - TF_RETURN_IF_ERROR(set_filter_data(reorder)); } else { - return InternalError("Unexpected operation"); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice filter_output, + GetAllocationSliceForHlo(instr)); + result_slices.push_back(filter_output); } auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), absl::MakeSpan(filter_dims), - std::move(operand_slices), std::move(result_slices)); - + Thunk::ThunkInfo::WithProfileAnnotation(instr), + absl::MakeSpan(filter_dims), operand_slices, result_slices); AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitNormThunk(mlir::Operation* op) { - auto norm = mlir::dyn_cast(op); - TF_RET_CHECK(norm != nullptr); +absl::Status IrEmitterUnnested::EmitNormThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto const gpu_backend_config, + instr->backend_config()); + const xla::gpu::CudnnNormBackendConfig& backend_config = + gpu_backend_config.cudnn_norm_backend_config(); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, - GetAllocationSlice(norm.getInput())); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice x_slice, + GetAllocationSliceForHlo(instr->operand(0))); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scale_slice, - GetAllocationSlice(norm.getScale())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bias_slice, - GetAllocationSlice(norm.getBias())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSlice(norm.getOutput())); + GetAllocationSliceForHlo(instr->operand(1))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice y_or_dx_slice, + GetAllocationSliceForHlo(instr, {0})); + + std::optional bias_slice, expectation_slice, + norm_factor_slice, dy_slice, dscale_slice, dbias_slice; - int64_t num_operands = op->getNumOperands(); - std::optional expectation_slice, norm_factor_slice; - if (num_operands == 7) { + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER || + backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(bias_slice, + GetAllocationSliceForHlo(instr->operand(2))); + } + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { TF_ASSIGN_OR_RETURN(expectation_slice, - GetAllocationSlice(norm.getExpectation())); + GetAllocationSliceForHlo(instr, {1})); TF_ASSIGN_OR_RETURN(norm_factor_slice, - GetAllocationSlice(norm.getNormFactor())); + GetAllocationSliceForHlo(instr, {2})); + } + if (backend_config.kind() == xla::gpu::CudnnNormBackendConfig::LAYER_BWD) { + TF_ASSIGN_OR_RETURN(dy_slice, GetAllocationSliceForHlo(instr->operand(2))); + TF_ASSIGN_OR_RETURN(expectation_slice, + GetAllocationSliceForHlo(instr->operand(3))); + TF_ASSIGN_OR_RETURN(norm_factor_slice, + GetAllocationSliceForHlo(instr->operand(4))); + TF_ASSIGN_OR_RETURN(dscale_slice, GetAllocationSliceForHlo(instr, {1})); + TF_ASSIGN_OR_RETURN(dbias_slice, GetAllocationSliceForHlo(instr, {2})); } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSlice(norm.getScratch())); + GetAllocationSliceForHlo( + instr, {instr->shape().tuple_shapes_size() - 1})); GpuNormDescriptor descriptor; - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(norm.getAlgorithmConfig().getAlgorithm()); - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = norm.getAlgorithmConfig().getWorkspaceSize(); - algorithm->mutable_workspace_size()->set_value(workspace_size); - - descriptor.input_shape = GetShape(norm->getOperand(0)); - descriptor.scale_shape = GetShape(norm->getOperand(1)); - descriptor.bias_shape = GetShape(norm->getOperand(2)); - descriptor.output_shape = GetShape(norm->getOperand(3)); - if (num_operands == 7) { - descriptor.expectation_shape = GetShape(norm->getOperand(4)); - descriptor.norm_factor_shape = GetShape(norm->getOperand(5)); - } - descriptor.backend_config.set_epsilon(norm.getEpsilon().convertToDouble()); + descriptor.backend_config = backend_config; + + descriptor.x_shape = instr->operand(0)->shape(); + descriptor.scale_shape = instr->operand(1)->shape(); + descriptor.y_or_dx_shape = ShapeUtil::GetSubshape(instr->shape(), {0}); + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_INFER || + backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + descriptor.bias_shape = instr->operand(2)->shape(); + } + if (backend_config.kind() == + xla::gpu::CudnnNormBackendConfig::LAYER_FWD_TRAIN) { + descriptor.expectation_shape = ShapeUtil::GetSubshape(instr->shape(), {1}); + descriptor.norm_factor_shape = ShapeUtil::GetSubshape(instr->shape(), {2}); + } + if (backend_config.kind() == xla::gpu::CudnnNormBackendConfig::LAYER_BWD) { + descriptor.dy_shape = instr->operand(2)->shape(); + descriptor.expectation_shape = instr->operand(3)->shape(); + descriptor.norm_factor_shape = instr->operand(4)->shape(); + descriptor.dscale_shape = ShapeUtil::GetSubshape(instr->shape(), {1}); + descriptor.dbias_shape = ShapeUtil::GetSubshape(instr->shape(), {2}); + } TF_ASSIGN_OR_RETURN(GpuNormConfig config, GpuNormConfig::For(descriptor)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), - input_slice, scale_slice, bias_slice, output_slice, expectation_slice, - norm_factor_slice, scratch_slice); - + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(config), + x_slice, scale_slice, y_or_dx_slice, bias_slice, expectation_slice, + norm_factor_slice, dy_slice, dscale_slice, dbias_slice, scratch_slice); AddThunkToThunkSequence(std::move(thunk)); - - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitFusedMHAThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::fusedMHAOp; - GpufMHADescriptor descriptor; - BufferAllocation::Slice lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, - output_slice, scratch_slice, activation_slice, mask_slice, bias_slice; - - auto populate_common = [&](auto fmha) -> Status { - descriptor.backend_config.set_fmha_scale( - fmha.getFmhaScale().convertToDouble()); - - if (fmha.getDropoutRate()) { - descriptor.backend_config.set_dropout_rate( - (*fmha.getDropoutRate()).convertToDouble()); - } - - if (fmha.getSeed()) { - descriptor.backend_config.set_seed((*fmha.getSeed())); - } - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(fmha.getAlgorithmConfig().getAlgorithm()); - for (int i = 0; i < fmha.getAlgorithmConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm->mutable_tuning_knobs())[fmha.getAlgorithmConfig() - .getKnobIds()[i]] = - fmha.getAlgorithmConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = fmha.getAlgorithmConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - - descriptor.bmm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1DotDimensionNumbers()); - descriptor.bmm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2DotDimensionNumbers()); - - descriptor.lhs_bmm1_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getLhsBmm1()).element_type(), - GetShape(fmha.getLhsBmm1()).dimensions(), - GetShape(fmha.getLhsBmm1()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(lhs_bmm1_slice, GetAllocationSlice(fmha.getLhsBmm1())); - - descriptor.rhs_bmm1_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getRhsBmm1()).element_type(), - GetShape(fmha.getRhsBmm1()).dimensions(), - GetShape(fmha.getRhsBmm1()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(rhs_bmm1_slice, GetAllocationSlice(fmha.getRhsBmm1())); - - descriptor.rhs_bmm2_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getRhsBmm2()).element_type(), - GetShape(fmha.getRhsBmm2()).dimensions(), - GetShape(fmha.getRhsBmm2()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(rhs_bmm2_slice, GetAllocationSlice(fmha.getRhsBmm2())); - - descriptor.output_shapes.push_back(ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getOutput()).element_type(), - GetShape(fmha.getOutput()).dimensions(), - GetShape(fmha.getOutput()).layout().minor_to_major())); - TF_ASSIGN_OR_RETURN(output_slice, GetAllocationSlice(fmha.getOutput())); - - TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - - TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, - ConvertMlirArrayAttrToInt64Array( - fmha.getIntermediateTensorDimensions())); - if (fmha.getActivation() != nullptr) { - descriptor.output_shapes.push_back(ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getActivation()).element_type(), - GetShape(fmha.getActivation()).dimensions(), - GetShape(fmha.getActivation()).layout().minor_to_major())); - TF_ASSIGN_OR_RETURN(activation_slice, - GetAllocationSlice(fmha.getActivation())); - } - - if (fmha.getBias() != nullptr) { - descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBias()).element_type(), - GetShape(fmha.getBias()).dimensions(), - GetShape(fmha.getBias()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); - } - - if (fmha.getMask() != nullptr) { - descriptor.mask_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getMask()).element_type(), - GetShape(fmha.getMask()).dimensions(), - GetShape(fmha.getMask()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); - } - TF_ASSIGN_OR_RETURN( - auto intermediate_tensor_layout_array, - ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); - - descriptor.intermediate_lhs_bmm2_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getOutput()).element_type(), - intermediate_tensor_dims_array, intermediate_tensor_layout_array); - - // set if flash attention here - descriptor.is_flash_attention = fmha.getIsFlashAttention(); - // set if causal mask here - descriptor.is_causal_mask = fmha.getIsCausalMask(); - return OkStatus(); - }; +absl::Status IrEmitterUnnested::EmitFusedMHAThunk( + const HloCustomCallInstruction* instr) { + const HloInstruction* lhs_bmm1 = instr->operand(0); + const HloInstruction* rhs_bmm1 = instr->operand(1); + const HloInstruction* rhs_bmm2 = instr->operand(2); - if (auto fmha_op = dyn_cast(op)) { - TF_RET_CHECK(fmha_op != nullptr); - TF_ASSIGN_OR_RETURN(CudnnfMHAKind kind, - AsCudnnfMHAKind(fmha_op.getFusedMhaDag())); - descriptor.kind = kind; - TF_RETURN_IF_ERROR(populate_common(fmha_op)); - } else { - return InternalError("Unexpected operation"); - } - TF_ASSIGN_OR_RETURN(GpufMHAConfig config, GpufMHAConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, + GetAllocationSliceForHlo(lhs_bmm1)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, + GetAllocationSliceForHlo(rhs_bmm1)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, + GetAllocationSliceForHlo(rhs_bmm2)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + GetAllocationSliceForHlo(instr, {0})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + GetAllocationSliceForHlo(instr, {1})); + BufferAllocation::Slice activation_slice; + bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; + if (has_activation) { + TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {2})); + } + + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(instr)); + BufferAllocation::Slice mask_slice, bias_slice; + BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; + std::optional mask_shape, bias_shape; + { + bool has_mask = kind == CudnnfMHAKind::kScaleMaskSoftmax || + kind == CudnnfMHAKind::kScaleMaskSoftmaxDropout || + kind == CudnnfMHAKind::kScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout; + bool has_bias = kind == CudnnfMHAKind::kScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout || + kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; + + if (has_mask) { + const HloInstruction* mask = instr->operand(3); + TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSliceForHlo(mask)); + mask_shape = mask->shape(); + if (has_bias) { + const HloInstruction* bias = instr->operand(4); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); + bias_shape = bias->shape(); + } + } else if (has_bias) { + const HloInstruction* bias = instr->operand(3); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); + bias_shape = bias->shape(); + } + int64_t seqlen_qk_operand_index = 3 + has_mask + has_bias; + bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2; + if (has_seqlen_qk) { + const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index); + TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); + const HloInstruction* seqlen_k = + instr->operand(seqlen_qk_operand_index + 1); + TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); + } + } + + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const xla::gpu::CudnnfMHABackendConfig& config = + gpu_config.cudnn_fmha_backend_config(); + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + absl::InlinedVector output_shapes = { + ShapeUtil::GetSubshape(instr->shape(), {0})}; + if (has_activation) { + output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {2})); + } + + GpufMHADescriptor descriptor = {kind, + config, + config.is_flash_attention(), + config.is_causal_mask(), + lhs_bmm1->shape(), + rhs_bmm1->shape(), + rhs_bmm2->shape(), + intermediate_tensor_shape, + output_shapes, + config.bmm1_dot_dimension_numbers(), + config.bmm2_dot_dimension_numbers(), + mask_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, + GpufMHAConfig::For(descriptor)); AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, - scratch_slice, mask_slice, bias_slice, activation_slice)); - return OkStatus(); -} - -Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { - using mlir::dyn_cast; - using mlir::lmhlo_gpu::fusedMHABackwardOp; - - GpufMHABackwardDescriptor descriptor; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, - bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, mask_slice, fwd_output_slice, bias_slice; - BufferAllocation::Slice d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_s_slice, softmax_sum_slice, d_Q_accum_slice, d_bias_slice; - - auto populate_common = [&](auto fmha) -> Status { - descriptor.backend_config.set_fmha_scale( - fmha.getFmhaScale().convertToDouble()); - - if (fmha.getDropoutRate()) { - descriptor.backend_config.set_dropout_rate( - (*fmha.getDropoutRate()).convertToDouble()); - } - - if (fmha.getSeed()) { - descriptor.backend_config.set_seed((*fmha.getSeed())); - } - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(fmha.getAlgorithmConfig().getAlgorithm()); - for (int i = 0; i < fmha.getAlgorithmConfig().getKnobIds().size(); ++i) { - // N.B. tuning_knobs is a map rather than a repeated field, so this - // doesn't require reserving space up front. - (*algorithm->mutable_tuning_knobs())[fmha.getAlgorithmConfig() - .getKnobIds()[i]] = - fmha.getAlgorithmConfig().getKnobValues()[i]; - } - algorithm->set_is_cudnn_frontend(true); - auto workspace_size = fmha.getAlgorithmConfig().getWorkspaceSize(); - if (workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(workspace_size); - } - - // set if flash attention here - descriptor.is_flash_attention = fmha.getIsFlashAttention(); - // set if causal mask here - descriptor.is_causal_mask = fmha.getIsCausalMask(); - descriptor.bmm1_grad_gemm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1GradGemm1DotDimensionNumbers()); - descriptor.bmm1_grad_gemm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm1GradGemm2DotDimensionNumbers()); - descriptor.bmm2_grad_gemm1_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2GradGemm1DotDimensionNumbers()); - descriptor.bmm2_grad_gemm2_dnums = - ConvertDotDimensionNumbers(fmha.getBmm2GradGemm2DotDimensionNumbers()); - - descriptor.bmm1_grad_gemm1_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm1GradGemm1Rhs()).element_type(), - GetShape(fmha.getBmm1GradGemm1Rhs()).dimensions(), - GetShape(fmha.getBmm1GradGemm1Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm1_grad_gemm1_rhs_slice, - GetAllocationSlice(fmha.getBmm1GradGemm1Rhs())); - - descriptor.bmm1_grad_gemm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm1GradGemm2Rhs()).element_type(), - GetShape(fmha.getBmm1GradGemm2Rhs()).dimensions(), - GetShape(fmha.getBmm1GradGemm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm1_grad_gemm2_rhs_slice, - GetAllocationSlice(fmha.getBmm1GradGemm2Rhs())); - - // fwd activation - // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular - // attention or softmax stats for flash attention here we set the shape to - // be bmm2_grad_gemm1_lhs even it is flash attention - if (descriptor.is_flash_attention) { - // flash attention TODO: make sure the layout is correct for - // bmm2_grad_gemm1_lhs - TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, - ConvertMlirArrayAttrToInt64Array( - fmha.getIntermediateTensorDimensions())); - TF_ASSIGN_OR_RETURN( - auto intermediate_tensor_layout_array, - ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); - - descriptor.bmm2_grad_gemm1_lhs_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDOutput()).element_type(), - intermediate_tensor_dims_array, intermediate_tensor_layout_array); - } else { - descriptor.bmm2_grad_gemm1_lhs_shape = - ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), - GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); - } - TF_ASSIGN_OR_RETURN(bmm2_grad_gemm1_lhs_slice, - GetAllocationSlice(fmha.getBmm2GradGemm1Lhs())); - - descriptor.bmm2_grad_gemm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm2Rhs()).element_type(), - GetShape(fmha.getBmm2GradGemm2Rhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bmm2_grad_gemm2_rhs_slice, - GetAllocationSlice(fmha.getBmm2GradGemm2Rhs())); - - descriptor.d_output_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDOutput()).element_type(), - GetShape(fmha.getDOutput()).dimensions(), - GetShape(fmha.getDOutput()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_output_slice, GetAllocationSlice(fmha.getDOutput())); - descriptor.d_bmm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm1Lhs()).element_type(), - GetShape(fmha.getDBmm1Lhs()).dimensions(), - GetShape(fmha.getDBmm1Lhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm1_lhs_slice, - GetAllocationSlice(fmha.getDBmm1Lhs())); - - descriptor.d_bmm1_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm1Rhs()).element_type(), - GetShape(fmha.getDBmm1Rhs()).dimensions(), - GetShape(fmha.getDBmm1Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm1_rhs_slice, - GetAllocationSlice(fmha.getDBmm1Rhs())); - - descriptor.d_bmm2_rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBmm2Rhs()).element_type(), - GetShape(fmha.getDBmm2Rhs()).dimensions(), - GetShape(fmha.getDBmm2Rhs()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bmm2_rhs_slice, - GetAllocationSlice(fmha.getDBmm2Rhs())); - - TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - - if (fmha.getD_S() != nullptr) { - descriptor.d_s_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getD_S()).element_type(), - GetShape(fmha.getD_S()).dimensions(), - GetShape(fmha.getD_S()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_s_slice, GetAllocationSlice(fmha.getD_S())); - } - - if (fmha.getDBias() != nullptr) { - descriptor.d_bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getDBias()).element_type(), - GetShape(fmha.getDBias()).dimensions(), - GetShape(fmha.getDBias()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(d_bias_slice, GetAllocationSlice(fmha.getDBias())); - } - - if (fmha.getMask() != nullptr) { - // has mask input - TF_RET_CHECK( - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardBmmBmm && - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout && - descriptor.kind != xla::gpu::CudnnfMHAKind::kBackwardSoftmax); - - descriptor.mask_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getMask()).element_type(), - GetShape(fmha.getMask()).dimensions(), - GetShape(fmha.getMask()).layout().minor_to_major()); - - TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); - } - // add flash attention backward related slice here - if (fmha.getBias() != nullptr) { - descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBias()).element_type(), - GetShape(fmha.getBias()).dimensions(), - GetShape(fmha.getBias()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); - } - - if (fmha.getSoftmaxSum() != nullptr) { - TF_ASSIGN_OR_RETURN(softmax_sum_slice, - GetAllocationSlice(fmha.getSoftmaxSum())); - } - - if (fmha.getD_QAccum() != nullptr) { - TF_ASSIGN_OR_RETURN(d_Q_accum_slice, - GetAllocationSlice(fmha.getD_QAccum())); - } - - if (fmha.getFwdOutput() != nullptr) { - descriptor.fwd_output_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getFwdOutput()).element_type(), - GetShape(fmha.getFwdOutput()).dimensions(), - GetShape(fmha.getFwdOutput()).layout().minor_to_major()); - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSlice(fmha.getFwdOutput())); - } - return OkStatus(); - }; - - if (auto fmha_backward_op = dyn_cast(op)) { - TF_RET_CHECK(fmha_backward_op != nullptr); - TF_ASSIGN_OR_RETURN( - CudnnfMHAKind kind, - AsCudnnBackwardfMHAKind(fmha_backward_op.getFusedMhaDag())); - descriptor.kind = kind; - TF_RETURN_IF_ERROR(populate_common(fmha_backward_op)); + scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice, + seqlen_k_slice)); + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( + const HloCustomCallInstruction* instr) { + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const xla::gpu::CudnnfMHABackendConfig& config = + gpu_config.cudnn_fmha_backend_config(); + bool is_flash_attention = config.is_flash_attention(); + + int input_index = 0; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape bmm2_grad_gemm1_lhs_shape; + + // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular + // attention or softmax stats for flash attention here we set the shape to + // be bmm2_grad_gemm1_lhs even it is flash attention + if (is_flash_attention) { + // flash attention TODO: make sure the layout is correct for + // bmm2_grad_gemm1_lhs + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; + input_index++; } else { - return InternalError("Unexpected operation"); + bmm2_grad_gemm1_lhs_shape = instr->operand(input_index++)->shape(); + } + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + Shape d_output_shape = instr->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); + bool has_mask = kind == CudnnfMHAKind::kBackwardScaleMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; + BufferAllocation::Slice mask_slice; + std::optional mask_shape; + if (has_mask) { + TF_ASSIGN_OR_RETURN(mask_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + mask_shape = instr->operand(input_index++)->shape(); + } + + bool has_bias = is_flash_attention && + (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout); + BufferAllocation::Slice bias_slice; + std::optional bias_shape; + if (has_bias) { + TF_ASSIGN_OR_RETURN(bias_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + bias_shape = instr->operand(input_index++)->shape(); + } + + BufferAllocation::Slice fwd_output_slice; + std::optional fwd_output_shape; + if (is_flash_attention) { + TF_ASSIGN_OR_RETURN(fwd_output_slice, + GetAllocationSliceForHlo(instr->operand(input_index))); + fwd_output_shape = instr->operand(input_index++)->shape(); + } + + BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; + bool has_seqlen_qk = input_index == instr->operand_count() - 2; + if (has_seqlen_qk) { + const HloInstruction* seqlen_q = instr->operand(input_index); + TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); + const HloInstruction* seqlen_k = instr->operand(input_index + 1); + TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); + input_index += 2; + } + TF_RET_CHECK(input_index == instr->operand_count()); + + int output_index = 0; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, + GetAllocationSliceForHlo(instr, {output_index})); + Shape d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + + BufferAllocation::Slice d_s_slice, softmax_sum_slice, d_Q_accum_slice; + std::optional d_s_shape; + if (!is_flash_attention) { + TF_ASSIGN_OR_RETURN(d_s_slice, + GetAllocationSliceForHlo(instr, {output_index})); + d_s_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + } else { + TF_ASSIGN_OR_RETURN(softmax_sum_slice, + GetAllocationSliceForHlo(instr, {output_index++})); + TF_ASSIGN_OR_RETURN(d_Q_accum_slice, + GetAllocationSliceForHlo(instr, {output_index++})); } - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig config, + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + GetAllocationSliceForHlo(instr, {output_index++})); + + bool has_dbias = + instr->shape().tuple_shapes().size() == 6 && !is_flash_attention; + BufferAllocation::Slice d_bias_slice; + std::optional d_bias_shape; + if (has_dbias) { + TF_ASSIGN_OR_RETURN(d_bias_slice, + GetAllocationSliceForHlo(instr, {output_index})); + d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); + } + + TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); + + GpufMHABackwardDescriptor descriptor = { + kind, + config, + is_flash_attention, + config.is_causal_mask(), + bmm1_grad_gemm1_rhs_shape, + bmm1_grad_gemm2_rhs_shape, + bmm2_grad_gemm1_lhs_shape, + bmm2_grad_gemm2_rhs_shape, + d_output_shape, + d_bmm1_lhs_shape, + d_bmm1_rhs_shape, + d_bmm2_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), + config.bmm1_grad_gemm2_dot_dimension_numbers(), + config.bmm2_grad_gemm1_dot_dimension_numbers(), + config.bmm2_grad_gemm2_dot_dimension_numbers(), + d_s_shape, + fwd_output_shape, + mask_shape, + d_bias_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, GpufMHABackwardConfig::For(descriptor)); AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(config), - bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, - bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_s_slice, softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, - fwd_output_slice, bias_slice)); - - return OkStatus(); -} -#endif // GOOGLE_CUDA + Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, + bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, + bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, + d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, + softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, + fwd_output_slice, bias_slice, seqlen_q_slice, seqlen_k_slice)); + + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitFlashAttnFwdThunk( + const HloCustomCallInstruction* instr) { + int64_t num_operands = instr->operand_count(); + CHECK(num_operands >= 3); + + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const auto& config = gpu_config.flash_attn_backend_config(); + + const HloInstruction* query = instr->operand(0); + const HloInstruction* key = instr->operand(1); + const HloInstruction* value = instr->operand(2); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice query_slice, + GetAllocationSliceForHlo(query)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice key_slice, + GetAllocationSliceForHlo(key)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice value_slice, + GetAllocationSliceForHlo(value)); + + int64_t cur_arg_idx = 3; + + BufferAllocation::Slice cu_seqlens_query_slice, cu_seqlens_key_slice; + std::optional cu_seqlens_query_shape, cu_seqlens_key_shape; + std::optional max_seqlen_q, max_seqlen_k; + if (instr->custom_call_target() == kGpuFlashAttnVarLenFwdCallTarget) { + CHECK(num_operands >= cur_arg_idx + 2); + const HloInstruction* cu_seqlens_query = instr->operand(cur_arg_idx++); + const HloInstruction* cu_seqlens_key = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(cu_seqlens_query_slice, + GetAllocationSliceForHlo(cu_seqlens_query)); + TF_ASSIGN_OR_RETURN(cu_seqlens_key_slice, + GetAllocationSliceForHlo(cu_seqlens_key)); + cu_seqlens_query_shape = cu_seqlens_query->shape(); + cu_seqlens_key_shape = cu_seqlens_key->shape(); + CHECK(config.has_max_seqlen_q() && config.has_max_seqlen_k()); + max_seqlen_q = config.max_seqlen_q(); + max_seqlen_k = config.max_seqlen_k(); + } + + BufferAllocation::Slice alibi_slopes_slice; + std::optional alibi_slopes_shape; + if (config.has_alibi_slopes()) { + CHECK(num_operands >= cur_arg_idx + 1); + const HloInstruction* alibi_slopes = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(alibi_slopes_slice, + GetAllocationSliceForHlo(alibi_slopes)); + alibi_slopes_shape = alibi_slopes->shape(); + } + + // These two parameters are inserted by FlashAttnNormalization pass. + BufferAllocation::Slice output_accum_slice; + BufferAllocation::Slice softmax_lse_accum_slice; + if (num_operands == cur_arg_idx + 2) { + const HloInstruction* output_accum = instr->operand(cur_arg_idx++); + const HloInstruction* softmax_lse_accum = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(output_accum_slice, + GetAllocationSliceForHlo(output_accum)); + TF_ASSIGN_OR_RETURN(softmax_lse_accum_slice, + GetAllocationSliceForHlo(softmax_lse_accum)); + } else { + CHECK(num_operands == cur_arg_idx); + } -StatusOr IrEmitterUnnested::GetAllocationSliceForHlo( - const HloInstruction* instr, const ShapeIndex& index) const { - const BufferAssignment& buffer_assignment = - ir_emitter_context_->buffer_assignment(); - return buffer_assignment.GetUniqueSlice(instr, index); -} + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + GetAllocationSliceForHlo(instr, {0})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice softmax_lse_slice, + GetAllocationSliceForHlo(instr, {1})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rng_state_slice, + GetAllocationSliceForHlo(instr, {2})); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + const Shape& output_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 0); + const Shape& softmax_lse_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 1); -Status IrEmitterUnnested::EmitCubDeviceRadixSort(mlir::Operation* op) { - auto radix_sort_op = mlir::cast(op); - if (radix_sort_op.getInputs().size() != 1 && - radix_sort_op.getInputs().size() != 2) { - return InternalError("Invalid number of operands for radix sort"); + BufferAllocation::Slice s_dmask_slice; + std::optional s_dmask_shape; + CHECK(config.has_return_softmax()); + bool return_softmax = config.return_softmax(); + if (return_softmax) { + CHECK(xla::ShapeUtil::TupleElementCount(instr->shape()) == 4); + TF_ASSIGN_OR_RETURN(s_dmask_slice, GetAllocationSliceForHlo(instr, {3})); + s_dmask_shape = ShapeUtil::GetTupleElementShape(instr->shape(), 3); } - TF_ASSIGN_OR_RETURN(std::vector operands, - GetAllocationSlices(radix_sort_op.getInputs())); - TF_ASSIGN_OR_RETURN(std::vector results, - GetAllocationSlices(radix_sort_op.getOutput())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch, - GetAllocationSlice(radix_sort_op.getScratch())); - - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - GetShape(op->getOperand(0)).element_type(), - radix_sort_op.getInputs().size() == 2 - ? std::optional(GetShape(op->getOperand(1)).element_type()) - : std::nullopt, - operands, results, scratch, radix_sort_op.getDescending()); + TF_ASSIGN_OR_RETURN( + FlashAttnFwdConfig flash_attn_fwd_config, + FlashAttnFwdConfig::For( + query->shape(), key->shape(), value->shape(), cu_seqlens_query_shape, + cu_seqlens_key_shape, alibi_slopes_shape, output_shape, + softmax_lse_shape, s_dmask_shape, config.dropout_rate(), + config.scale(), config.is_causal(), max_seqlen_q, max_seqlen_k)); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(flash_attn_fwd_config), query_slice, key_slice, value_slice, + cu_seqlens_query_slice, cu_seqlens_key_slice, alibi_slopes_slice, + output_accum_slice, softmax_lse_accum_slice, output_slice, + softmax_lse_slice, rng_state_slice, s_dmask_slice)); + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitFlashAttnBwdThunk( + const HloCustomCallInstruction* instr) { + int64_t num_operands = instr->operand_count(); + CHECK(num_operands >= 7); + + TF_ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); + const auto& config = gpu_config.flash_attn_backend_config(); + + const HloInstruction* grad_output = instr->operand(0); + const HloInstruction* query = instr->operand(1); + const HloInstruction* key = instr->operand(2); + const HloInstruction* value = instr->operand(3); + const HloInstruction* output = instr->operand(4); + const HloInstruction* softmax_lse = instr->operand(5); + const HloInstruction* rng_state = instr->operand(6); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_output_slice, + GetAllocationSliceForHlo(grad_output)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice query_slice, + GetAllocationSliceForHlo(query)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice key_slice, + GetAllocationSliceForHlo(key)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice value_slice, + GetAllocationSliceForHlo(value)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + GetAllocationSliceForHlo(output)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice softmax_lse_slice, + GetAllocationSliceForHlo(softmax_lse)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rng_state_slice, + GetAllocationSliceForHlo(rng_state)); + + int64_t cur_arg_idx = 7; + + BufferAllocation::Slice cu_seqlens_query_slice, cu_seqlens_key_slice; + std::optional cu_seqlens_query_shape, cu_seqlens_key_shape; + std::optional max_seqlen_q, max_seqlen_k; + if (instr->custom_call_target() == kGpuFlashAttnVarLenBwdCallTarget) { + CHECK(num_operands >= cur_arg_idx + 2); + const HloInstruction* cu_seqlens_query = instr->operand(cur_arg_idx++); + const HloInstruction* cu_seqlens_key = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(cu_seqlens_query_slice, + GetAllocationSliceForHlo(cu_seqlens_query)); + TF_ASSIGN_OR_RETURN(cu_seqlens_key_slice, + GetAllocationSliceForHlo(cu_seqlens_key)); + cu_seqlens_query_shape = cu_seqlens_query->shape(); + cu_seqlens_key_shape = cu_seqlens_key->shape(); + CHECK(config.has_max_seqlen_q() && config.has_max_seqlen_k()); + max_seqlen_q = config.max_seqlen_q(); + max_seqlen_k = config.max_seqlen_k(); + } + + BufferAllocation::Slice alibi_slopes_slice; + std::optional alibi_slopes_shape; + if (config.has_alibi_slopes()) { + CHECK(num_operands >= cur_arg_idx + 1); + const HloInstruction* alibi_slopes = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(alibi_slopes_slice, + GetAllocationSliceForHlo(alibi_slopes)); + alibi_slopes_shape = alibi_slopes->shape(); + } + + CHECK(num_operands == cur_arg_idx + 1); + // This parameter is inserted by FlashAttnNormalization pass. + const HloInstruction* grad_query_accum = instr->operand(cur_arg_idx++); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_query_accum_slice, + GetAllocationSliceForHlo(grad_query_accum)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_query_slice, + GetAllocationSliceForHlo(instr, {0})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_key_slice, + GetAllocationSliceForHlo(instr, {1})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_value_slice, + GetAllocationSliceForHlo(instr, {2})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice grad_softmax_slice, + GetAllocationSliceForHlo(instr, {3})); + + const Shape& grad_query_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 0); + const Shape& grad_key_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 1); + const Shape& grad_value_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 2); + const Shape& grad_softmax_shape = + ShapeUtil::GetTupleElementShape(instr->shape(), 3); - AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + TF_ASSIGN_OR_RETURN( + FlashAttnBwdConfig flash_attn_bwd_config, + FlashAttnBwdConfig::For( + grad_output->shape(), query->shape(), key->shape(), value->shape(), + output->shape(), softmax_lse->shape(), cu_seqlens_query_shape, + cu_seqlens_key_shape, alibi_slopes_shape, grad_query_shape, + grad_key_shape, grad_value_shape, grad_softmax_shape, + config.dropout_rate(), config.scale(), config.is_causal(), + config.deterministic(), max_seqlen_q, max_seqlen_k)); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + std::move(flash_attn_bwd_config), grad_output_slice, query_slice, + key_slice, value_slice, output_slice, softmax_lse_slice, rng_state_slice, + cu_seqlens_query_slice, cu_seqlens_key_slice, alibi_slopes_slice, + grad_query_accum_slice, grad_query_slice, grad_key_slice, + grad_value_slice, grad_softmax_slice)); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) { - auto cholesky_op = mlir::cast(op); - - const Shape shape = GetShape(cholesky_op.getInput()); - int ndim = shape.dimensions_size(); - CHECK_GE(ndim, 2); - int64_t n = shape.dimensions(ndim - 1); +#endif // GOOGLE_CUDA - const auto& dims = shape.dimensions(); - int64_t batch_size = - std::accumulate(dims.begin(), dims.end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); +absl::StatusOr +IrEmitterUnnested::GetAllocationSliceForHlo(const HloInstruction* instr, + const ShapeIndex& index) const { + return xla::gpu::GetAllocationSlice(ir_emitter_context_->buffer_assignment(), + instr, index); +} - TF_ASSIGN_OR_RETURN(auto operand_buffer, - GetAllocationSlice(cholesky_op.getInput())); - TF_ASSIGN_OR_RETURN(auto a_buffer, - GetAllocationSlice(cholesky_op.getOutput())); - TF_ASSIGN_OR_RETURN(auto workspace_buffer, - GetAllocationSlice(cholesky_op.getScratch())); - TF_ASSIGN_OR_RETURN(auto info_buffer, - GetAllocationSlice(cholesky_op.getInfo())); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - ThunkSequence thunks; +absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort( + const HloCustomCallInstruction* instr) { + if (instr->operand_count() != 1 && instr->operand_count() != 2) { + return Internal("Invalid number of operands for radix sort"); + } - if (operand_buffer != a_buffer) { - thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), - /*source_buffer=*/operand_buffer, - /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/cholesky_op.getInput(), - /*destination_value=*/cholesky_op.getOutput())); + absl::InlinedVector operands; + for (int i = 0; i < instr->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice operand, + GetAllocationSliceForHlo(instr->operand(i), {})); + operands.push_back(operand); } - CholeskyOptions options; - options.set_lower(cholesky_op.getIsLower()); - thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), options, - PtxOptsFromDebugOptions(ir_emitter_context_->debug_options()), a_buffer, - workspace_buffer, info_buffer, shape.element_type(), batch_size, n)); + absl::InlinedVector results; + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result, + GetAllocationSliceForHlo(instr, {0})); + results.push_back(result); - // Elide the sequential thunk if there's no copy. - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); + BufferAllocation::Slice scratch; + if (instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN(scratch, GetAllocationSliceForHlo(instr, {1})); } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result, + GetAllocationSliceForHlo(instr, {1})); + results.push_back(result); + TF_ASSIGN_OR_RETURN(scratch, GetAllocationSliceForHlo(instr, {2})); } - return OkStatus(); + TF_ASSIGN_OR_RETURN(xla::SortOptions options, + instr->backend_config()); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + instr->operand(0)->shape().element_type(), + instr->operand_count() == 2 + ? std::optional(instr->operand(1)->shape().element_type()) + : std::nullopt, + operands, results, scratch, options.descending()); + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { +absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { TF_ASSIGN_OR_RETURN(CholeskyOptions options, instr->backend_config()); const Shape& shape = instr->operand(0)->shape(); @@ -1602,9 +1492,7 @@ Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { Thunk::ThunkInfo::WithProfileAnnotation(instr), /*source_buffer=*/operand_buffer, /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/nullptr, - /*destination_value=*/nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); } thunks.push_back(std::make_unique( @@ -1620,79 +1508,28 @@ Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(thunks))); } - return OkStatus(); + return absl::OkStatus(); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// Converts MLIR dictionary attribute attached to a custom call operation to a -// custom call thunk attributes that are forwarded to the FFI handler. -static StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - CustomCallThunk::AttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); - - auto integer = [&](mlir::IntegerAttr integer) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(integer.getInt()); - return OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", name)); - } - }; - - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; - - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return OkStatus(); - }; - - TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(integer) - .Case(fp) - .Case(str) - .Default([&](mlir::Attribute) { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute type for attribute: ", name)); - })); - } - return attributes; -} - -Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { - auto custom_call = mlir::cast(op); - const std::string call_target_name = custom_call.getCallTargetName().str(); +absl::Status IrEmitterUnnested::EmitCustomCallThunk( + const HloCustomCallInstruction* instr) { + const std::string call_target_name = instr->custom_call_target(); // Typed FFI custom calls is a replacement for legacy custom calls with // a rich type safe API. It's under construction and not fully supported. bool is_ffi_custom_call = - custom_call.getApiVersion() == - mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI; + instr->api_version() == CustomCallApiVersion::API_VERSION_TYPED_FFI; void* call_target = CustomCallTargetRegistry::Global()->Lookup( call_target_name, std::string(platform_name())); - StatusOr handler = ffi::FindHandler(call_target_name); + absl::StatusOr registration = + ffi::FindHandler(call_target_name, platform_name()); // At least one implementation should be available at run time. bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; - bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + bool found_ffi_handler = is_ffi_custom_call && registration.ok(); if (!found_custom_call && !found_ffi_handler) { auto& debug_options = ir_emitter_context_->debug_options(); @@ -1700,13 +1537,13 @@ Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { // If true, then all custom calls that are not found in custom call or FFI // registries will become no-op (we don't emit any thunks for them). if (debug_options.xla_gpu_mock_custom_calls()) { - return OkStatus(); + return absl::OkStatus(); } // TODO(ezhulenev): Custom calls registered with an XLA runtime are not part // of a legacy registry, or an FFI registry. For now we simply ignore them. if (debug_options.xla_gpu_enable_xla_runtime_executable()) { - return OkStatus(); + return absl::OkStatus(); } return absl::UnimplementedError( @@ -1716,49 +1553,40 @@ Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { using Slices = std::vector>; - // Initialize slices and shapes from the value range. - auto init_from_values = [&](mlir::ValueRange values, Slices* slices) { - for (mlir::Value value : values) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value)); - slices->push_back(CustomCallThunk::Slice{slice, GetShape(value)}); - } - return OkStatus(); - }; - - // Initialize slices and shapes from the value range with token holes. - auto init_from_mapped_values = [&](mlir::ValueRange values, - absl::Span target_mapping, - int64_t target_size, Slices* slices) { - slices->resize(target_size); - for (auto [index, value] : llvm::zip(target_mapping, values)) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value)); - (*slices)[index] = CustomCallThunk::Slice{slice, GetShape(value)}; - } - return OkStatus(); - }; - - Slices operands, results; - - // If we have a target mapping, than the number of operands and results of a - // custom call handler can be more than a number of operands and results in - // the IR. These holes are coming from the HLO token operands and results. - if (auto target_mapping = custom_call.getTargetArgMapping()) { - auto arg_mapping = target_mapping->getArgsToTargetArgs(); - auto res_mapping = target_mapping->getResultsToTargetResults(); - - TF_RETURN_IF_ERROR( - init_from_mapped_values(custom_call.getArgs(), arg_mapping, - target_mapping->getNumArgs(), &operands)); - TF_RETURN_IF_ERROR( - init_from_mapped_values(custom_call.getOutput(), res_mapping, - target_mapping->getNumResults(), &results)); - - } else { - TF_RETURN_IF_ERROR(init_from_values(custom_call.getArgs(), &operands)); - TF_RETURN_IF_ERROR(init_from_values(custom_call.getOutput(), &results)); + Slices operands; + for (auto* operand : instr->operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + operands.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(auto slice, + GetAllocationSliceForHlo(operand, index)); + operands.push_back(CustomCallThunk::Slice{slice, subshape}); + return absl::OkStatus(); + })); } - // For legacy custom calls we convert all API versions into the the latest + Slices results; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + results.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForHlo(instr, index)); + results.push_back(CustomCallThunk::Slice{slice, subshape}); + return absl::OkStatus(); + })); + + // For legacy custom calls we convert all API versions into the latest // status-returning one and pass backend config as an opaque string. CustomCallThunk::CustomCallTarget custom_call_target; std::string opaque; @@ -1770,8 +1598,8 @@ Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { // For information about this calling convention, see // xla/g3doc/custom_call.md. - switch (custom_call.getApiVersion()) { - case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL: + switch (instr->api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: using original_call_type = void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, const char* /*opaque*/, size_t /*opaque_len*/); @@ -1784,8 +1612,8 @@ Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { typed_call_target(stream, buffers, opaque, opaque_len); }; break; - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: using status_returning_call_type = void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/, const char* /*opaque*/, size_t /*opaque_len*/, @@ -1793,138 +1621,118 @@ Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) { custom_call_target = reinterpret_cast(call_target); break; - case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI: + case CustomCallApiVersion::API_VERSION_TYPED_FFI: // We already checked `handler` above. break; default: - return InternalError("Unknown custom-call API version enum value: %d", - custom_call.getApiVersion()); + return Internal("Unknown custom-call API version enum value: %d", + instr->api_version()); } - auto backend_config = - custom_call.getBackendConfig().value_or(mlir::Attribute()); - - switch (custom_call.getApiVersion()) { - case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: - if (auto str = backend_config.dyn_cast_or_null()) { - opaque = str.str(); - break; + auto& backend_config_str = instr->raw_backend_config_string(); + switch (instr->api_version()) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + if (!backend_config_str.empty()) { + opaque = backend_config_str; } - return absl::InternalError( - "Unsupported backend config. Expected a string attribute"); + break; - case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI: - if (auto dict = backend_config.dyn_cast_or_null()) { - TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); - break; + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + if (!backend_config_str.empty()) { + mlir::Attribute attr = mlir::parseAttribute( + backend_config_str, ir_emitter_context_->mlir_context()); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + break; + } + return absl::InternalError( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); } - return absl::InternalError( - "Unsupported backend config. Expected a dictionary attribute"); + break; default: - return InternalError("Unknown custom-call API version enum value: %d", - custom_call.getApiVersion()); + return Internal("Unknown custom-call API version enum value: %d", + instr->api_version()); } auto ffi_thunk = [&] { + auto& called_computations = instr->called_computations(); return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), *handler, - std::move(operands), std::move(results), std::move(attributes)); + Thunk::ThunkInfo::WithProfileAnnotation(instr), registration->handler, + std::move(operands), std::move(results), std::move(attributes), + called_computations.empty() ? nullptr : called_computations[0]); }; auto legacy_thunk = [&] { return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(custom_call_target), std::move(operands), std::move(results), std::move(opaque)); }; AddThunkToThunkSequence(found_ffi_handler ? ffi_thunk() : legacy_thunk()); - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) { - auto fft_op = mlir::cast(op); - const Shape operand_shape = GetShape(fft_op.getOperand()); - const Shape output_shape = GetShape(fft_op.getOutput()); - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout())); - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())); - +absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice, - GetAllocationSlice(fft_op.getOperand())); + GetAllocationSliceForHlo(instr->operand(0))); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice, - GetAllocationSlice(fft_op.getOutput())); - TF_ASSIGN_OR_RETURN( - xla::FftType fft_type, - ConvertFftType(mlir::mhlo::stringifyFftType(fft_op.getFftType()))); - auto fft_length_values = fft_op.getFftLength().getValues(); - std::vector fft_length(fft_length_values.begin(), - fft_length_values.end()); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), fft_type, fft_length, - /*input_buffer=*/arg_slice, - /*output_buffer=*/dest_slice, - /*input_shape=*/operand_shape, - /*output_shape=*/output_shape)); - return OkStatus(); + GetAllocationSliceForHlo(instr)); + AddThunkToThunkSequence( + std::make_unique(Thunk::ThunkInfo::WithProfileAnnotation(instr), + instr->fft_type(), instr->fft_length(), + /*input_buffer=*/arg_slice, + /*output_buffer=*/dest_slice, + /*input_shape=*/instr->operand(0)->shape(), + /*output_shape=*/instr->shape())); + return absl::OkStatus(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) { - auto custom_call = mlir::cast(op); - auto operands = op->getOperands(); - TF_RET_CHECK(operands.size() == 4); +absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall( + const HloInstruction* instr) { + TF_RET_CHECK(instr->operand_count() == 2); + auto operands = instr->operands(); + TF_RET_CHECK(instr->shape().IsTuple() && + instr->shape().tuple_shapes_size() == 2); // We expect Fortran layout for everything other than the temp buffer (the // last operand). Fortran layout is not XLA default layout with elements 0 // and 1 swapped. For example instead of default layout {3,2,1,0} we'd have // Fortran layout {2,3,1,0}. - TF_RET_CHECK(absl::c_all_of(operands.drop_back(1), [&](mlir::Value v) { - const Shape& shape = GetShape(v); - const Layout& layout = shape.layout(); + auto has_fortran_layout = [](const Layout& layout) { int n = layout.minor_to_major_size(); - if (n < 2) { - return false; - } - // Unfortunately the HLO -> LMHLO -> HLO conversion loses layout information - // if the shape has any dimensions of size 1: In that case, the new HLO - // (which we see here) will have an arbitrary value for the location of the - // size-1 dimension. Just skip this assertion if the shape has any - // degenerate dimensions. - if (absl::c_any_of(shape.dimensions(), - [](int64_t dim) { return dim == 1; })) { - return true; - } return layout.minor_to_major(0) == n - 2 && - layout.minor_to_major(1) == n - 1 && - std::is_sorted(layout.minor_to_major().begin() + 2, - layout.minor_to_major().end(), - std::greater()); - })); + layout.minor_to_major(1) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(operands[0]->shape().layout())); + TF_RET_CHECK(has_fortran_layout(operands[1]->shape().layout())); + TF_RET_CHECK(has_fortran_layout(instr->shape().tuple_shapes(0).layout())); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice, - GetAllocationSlice(operands[0])); + GetAllocationSliceForHlo(operands[0])); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice, - GetAllocationSlice(operands[1])); + GetAllocationSliceForHlo(operands[1])); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(operands[2])); + GetAllocationSliceForHlo(instr, {0})); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice temp_slice, - GetAllocationSlice(operands[3])); + GetAllocationSliceForHlo(instr, {1})); - const Shape b_shape = GetShape(operands[1]); + const Shape b_shape = operands[1]->shape(); const PrimitiveType elem_ty = b_shape.element_type(); TriangularSolveOptions backend_config; - if (auto str = custom_call.getBackendConfig() - .value_or(mlir::Attribute()) - .dyn_cast_or_null()) + auto& backend_config_str = instr->raw_backend_config_string(); + if (!backend_config_str.empty()) { TF_RETURN_IF_ERROR( - tsl::HumanReadableJsonToProto(str.str(), &backend_config)); + tsl::HumanReadableJsonToProto(backend_config_str, &backend_config)); + } ThunkSequence thunks; @@ -1932,12 +1740,10 @@ Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) { // aren't the same buffer. if (b_slice != result_slice) { thunks.push_back(std::make_unique( - Thunk::ThunkInfo(op), + Thunk::ThunkInfo::WithProfileAnnotation(instr), /*source_buffer=*/b_slice, /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape), - /*source_value=*/operands[1], - /*destination_value=*/operands[2])); + /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape))); } int64_t m = b_shape.dimensions(b_shape.rank() - 2); @@ -1950,7 +1756,7 @@ Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) { backend_config.left_side() ? m * m * elem_size : n * n * elem_size; int64_t b_batch_stride = m * n * elem_size; thunks.push_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), backend_config, + Thunk::ThunkInfo::WithProfileAnnotation(instr), backend_config, PtxOptsFromDebugOptions(ir_emitter_context_->debug_options()), /*a_buffer=*/a_slice, /*b_buffer=*/result_slice, temp_slice, elem_ty, batch_size, m, n, a_batch_stride, b_batch_stride)); @@ -1959,385 +1765,167 @@ Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) { if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr); + // Don't repeat the annotation from inside thunks + thunk_info.profile_annotation = {}; + AddThunkToThunkSequence( + std::make_unique(thunk_info, std::move(thunks))); } - return OkStatus(); + return absl::OkStatus(); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// Convert the following form of fusion region: -// fusion() { -// %0 = tensor_load %external_memref0 -// %1 = tensor_load %external_memref1 -// ... -// materialize_in_destination %ret, %external_memref2 -// } -// to -// fusion(%external_memref0, %external_memref1) (^bb(%0, %1) { -// ... -// mhlo.return %ret -// }) -// -// So that it's suitable for MHLO -> XLA HLO conversion. -// This function won't be needed once ElementalIrEmitter migrates to take MHLO -// instead. -static Status ProcessFusionForConversion(mlir::Region* region, - std::vector* operand_shapes, - std::vector* output_shapes) { - std::vector loads; - std::vector stores; - - region->walk([&](mlir::bufferization::ToTensorOp load) { - if (load.getMemref().getParentRegion() != region) { - loads.push_back(load); - } - }); - - region->walk([&](mlir::bufferization::MaterializeInDestinationOp store) { - if (!llvm::isa(store.getDest().getType())) return; - if (store.getDest().getParentRegion() != region) { - stores.push_back(store); - } - }); - - for (auto& load : loads) { - auto arg = region->addArgument(load.getType(), region->getLoc()); - load.replaceAllUsesWith(arg); - Shape shape = GetShape(load.getResult()); - operand_shapes->push_back(std::move(shape)); - load.erase(); - } - - std::vector returned_values; - for (auto store : stores) { - Shape shape = GetShape(store.getDest()); - output_shapes->push_back(shape); - - returned_values.push_back(store.getSource()); - store.erase(); - } - - region->back().back().erase(); - auto b = mlir::OpBuilder::atBlockEnd(®ion->back()); - auto loc = returned_values[0].getLoc(); - b.create(loc, returned_values); - return OkStatus(); -} - -#if GOOGLE_CUDA -StatusOr IrEmitterUnnested::EmitTritonFusion( - const HloFusionAnalysis& hlo_fusion_analysis, - const HloFusionInstruction* fusion, mlir::Operation* op) { - // Note: In this method we can't use `BuildKernelThunk` as usual, - // because we only get the launch dimensions after code generation. So we - // implement kernel reuse using lower level APIs, such as - // `BuildKernelThunkImpl`. - CHECK_NE(fusion, nullptr); - if (!ir_emitter_context_->emit_ir_from_hlo()) { - CHECK_NE(op, nullptr); - } - if (ir_emitter_context_->emit_ir_from_hlo()) { - VLOG(3) << fusion->ToString(); - } else { - VLOG(3) << llvm_ir::DumpToString(op); - } - std::string suggested_kernel_name = std::string(fusion->name()); +absl::Status IrEmitterUnnested::EmitTopKCustomCall( + const HloCustomCallInstruction* instr) { + auto operands = instr->operands(); + auto shape = instr->shape(); + TF_RET_CHECK(operands.size() == 1) + << "Expect only 1 operand for TopK custom call."; + TF_RET_CHECK(shape.IsTuple()) + << "Expect TopK custom call to have tuple shape."; + TF_RET_CHECK(shape.tuple_shapes_size() == 2) + << "Expect TopK custom call shape to have exactly 2 sub-shapes."; + + auto data_shape = operands[0]->shape(); + auto top_elements_shape = shape.tuple_shapes()[0]; + auto indices_shape = shape.tuple_shapes()[1]; + + TF_RET_CHECK(data_shape.rank() <= 2) << "Invalid input shape."; + TF_RET_CHECK(indices_shape.element_type() == PrimitiveType::S32) + << "Indices should be S32."; + + bool has_batch = data_shape.rank() == 2; + auto [batch_size, n, k] = + has_batch + ? std::tuple{data_shape.dimensions(0), + data_shape.dimensions(1), + top_elements_shape.dimensions(1)} + : std::tuple{ + 1, data_shape.dimensions(0), top_elements_shape.dimensions(0)}; + + // Load TopK custom kernel. + TF_ASSIGN_OR_RETURN(CustomKernel kernel, + kernel::topk::GetTopKKernel( + "topk", data_shape.element_type(), n, k, batch_size)); + + // Prepare kernel arguments. TF_ASSIGN_OR_RETURN( auto kernel_arguments, - ir_emitter_context_->emit_ir_from_hlo() - ? KernelArguments::Create(ir_emitter_context_->buffer_assignment(), - fusion) - : KernelArguments::Create(ir_emitter_context_->allocations(), - mlir::cast(op))); - - const HloComputation* hlo_computation = - fusion->fused_instructions_computation(); - - auto generate = [&]() -> StatusOr { - VLOG(3) << "Generating: " << suggested_kernel_name; - - const std::string impl_fn_name = - ir_emitter_context_->name_uniquer()->GetUniqueName( - llvm_ir::SanitizeFunctionName( - absl::StrCat(suggested_kernel_name, "_impl"))); - - TF_ASSIGN_OR_RETURN(auto backend_config, - fusion->backend_config()); - absl::string_view fusion_kind = backend_config.kind(); - - TritonWrapperResult triton_wrapper_result; - LaunchDimensions launch_dimensions; - if (fusion_kind == kTritonSoftmaxFusionKind) { - auto& triton_config = *backend_config.mutable_triton_gemm_config(); - triton_config.set_num_stages(1); - triton_config.set_num_warps(DeriveNumWarpsFromTritonSoftmaxComputation( - fusion->fused_instructions_computation())); - TritonGemmConfig config = TritonGemmConfig::FromProto(triton_config); - - TF_ASSIGN_OR_RETURN(auto analysis, - TritonFusionAnalysis::Execute(*hlo_computation)); - TF_ASSIGN_OR_RETURN( - triton_wrapper_result, - TritonWrapper(analysis, impl_fn_name, hlo_computation, - kTritonSoftmaxFusionKind, - ir_emitter_context_->cuda_compute_capability(), - ir_emitter_context_->gpu_device_info(), config, module_, - &EmitSoftMax, *ir_emitter_context_->mlir_context())); - launch_dimensions = - GetSoftMaxLaunchDimensions(hlo_fusion_analysis.fusion(), config); - } else { // Must be a MatMul - CHECK_EQ(fusion_kind, kTritonGemmFusionKind); - if (!backend_config.has_triton_gemm_config()) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - LOG(WARNING) << "Using fallback triton GEMM config for op " - << fusion->name(); - } else { - LOG(WARNING) << "Using fallback triton GEMM config for op " - << GetIrNameFromLoc(op->getLoc()); - } - auto& triton_config = *backend_config.mutable_triton_gemm_config(); - triton_config.set_block_m(64); - triton_config.set_block_k(64); - triton_config.set_block_n(64); - triton_config.set_split_k(1); - triton_config.set_num_stages(1); - triton_config.set_num_warps(2); - } - TritonGemmConfig config = - TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); - - TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute( - *hlo_computation, config.split_k)); - TF_ASSIGN_OR_RETURN( - triton_wrapper_result, - TritonWrapper(analysis, impl_fn_name, hlo_computation, - kTritonGemmFusionKind, - ir_emitter_context_->cuda_compute_capability(), - ir_emitter_context_->gpu_device_info(), config, module_, - &EmitMatMul, *ir_emitter_context_->mlir_context())); - launch_dimensions = GetMatMulLaunchDimensions( - analysis, hlo_fusion_analysis.fusion(), config); - } + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + operands)); - llvm::Function* impl_fn = module_->getFunction(impl_fn_name); + auto thunk = std::make_unique( + instr, std::move(kernel), std::move(kernel_arguments.args())); + AddThunkToThunkSequence(std::move(thunk)); + + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitTritonCustomCall( + const HloCustomCallInstruction* instr) { +#if !GOOGLE_CUDA + return absl::UnimplementedError("Triton support requires CUDA"); +#else + auto generate = [this, &instr]() -> absl::StatusOr { + mlir::MLIRContext& mlir_context = *ir_emitter_context_->mlir_context(); + mlir_context.loadDialect(); + auto call = + TritonCall::Parse(instr->raw_backend_config_string(), &mlir_context); + auto kernel_name = + ir_emitter_context_->name_uniquer()->GetUniqueName(call.name); + VLOG(3) << "Generating: " << kernel_name; + + auto triton_module = + mlir::parseSourceString(call.ir, &mlir_context); + auto triton_fn = + triton_module->lookupSymbol(call.name); + triton_fn.setName(kernel_name); + + HloModule* hlo_module = instr->GetModule(); + auto gemm_config = TritonGemmConfig( + /*block_m=*/-1, /*block_n=*/-1, /*block_k=*/-1, /*split_k=*/-1, + call.num_stages, call.num_warps); + TF_ASSIGN_OR_RETURN( + auto result, + CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), + ir_emitter_context_->cuda_compute_capability(), + ir_emitter_context_->gpu_device_info(), gemm_config, + triton_module.get(), + ir_emitter_context_->llvm_module(), mlir_context)); + + llvm::Function* impl_fn = + ir_emitter_context_->llvm_module()->getFunction(kernel_name); TF_RET_CHECK(impl_fn); + impl_fn->setName(ir_emitter_context_->name_uniquer()->GetUniqueName( + kernel_name + "_impl")); - auto [kernel, inputs, outputs] = BuildKernelPrototype( - *ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), - impl_fn->arg_size(), launch_dimensions, &b_); + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands(), + /*dedup=*/false)); + auto launch_dimensions = + LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z), + se::ThreadDim(call.num_warps * 32)); + + llvm::IRBuilder builder(ir_emitter_context_->llvm_module()->getContext()); + + llvm::Function* kernel; + std::vector inputs; + std::vector outputs; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototype(*ir_emitter_context_, kernel_name, + kernel_arguments.args(), impl_fn->arg_size(), + launch_dimensions, &builder)); // Move function body into kernel prototype. - llvm::Function* prototype_func = b_.GetInsertBlock()->getParent(); + llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); prototype_func->splice(prototype_func->begin(), impl_fn); - for (const auto& [arg, ir_array] : llvm::zip(impl_fn->args(), inputs)) { - arg.replaceAllUsesWith(ir_array.GetBasePointer()); + for (const auto& [arg, input] : llvm::zip(impl_fn->args(), inputs)) { + arg.replaceAllUsesWith(input.GetBasePointer()); } impl_fn->eraseFromParent(); - return {{kernel->getName().str(), launch_dimensions, - triton_wrapper_result.shmem_bytes}}; - }; - - auto [kernel, was_cached] = kernel_reuse_cache_.GetWithStatus( - hlo_computation, kernel_arguments.args(), - /*discriminator=*/"", generate); - TF_RETURN_IF_ERROR(kernel.status()); + for (auto& arg : prototype_func->args()) { + // Remove the alignment and aliasing attributes to avoid recompiling the + // kernel for each alignment/aliasing combination. + arg.removeAttr(llvm::Attribute::Alignment); + arg.removeAttr(llvm::Attribute::NoAlias); + } - std::variant fusion_op; - if (ir_emitter_context_->emit_ir_from_hlo()) { - fusion_op = fusion; - } else { - fusion_op = op; - } + return {{kernel->getName().str(), launch_dimensions, result.cluster_dim, + result.shmem_bytes}}; + }; - FusionEmissionResult result; - result.thunks.emplace_back(std::make_unique( - fusion_op, kernel->kernel_name, kernel_arguments.args(), - kernel->launch_dimensions, kernel->shmem_bytes)); + auto [status_or_entry, was_cached] = + ir_emitter_context_->kernel_cache().GetWithStatus( + instr->raw_backend_config_string(), generate); + TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); - return result; -} + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands(), + /*dedup=*/false)); + AddThunkToThunkSequence(std::make_unique( + instr, entry->kernel_name, kernel_arguments.args(), + entry->launch_dimensions, entry->cluster_dim, entry->shmem_bytes)); + return absl::OkStatus(); #endif // GOOGLE_CUDA - -// Check if the fusion instruction should be emitted as an in place dynamic -// update slice or a memcpy fusion. The logic is copied from GetFusionEmitter. -bool IsSpecializedLoopFusion( - mlir::Operation* op, absl::Span allocations, - HloFusionAnalysis& analysis) { - auto fusion_op = mlir::cast(op); - if (!allocations.empty() && fusion_op != nullptr) { - bool is_single = IsSingleInstructionFusion(fusion_op); - if (!is_single && - CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op, allocations)) { - return true; - } - if (is_single && analysis.fusion_roots().size() == 1 && - analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) { - mlir::Value operand = GetHloOperands(fusion_op).front(); - mlir::Value output = GetHloOutputs(fusion_op).front(); - Shape operand_shape = GetShape(operand); - Shape output_shape = GetShape(output); - if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && - GetAllocationSlice(operand, allocations).ok()) { - return true; - } - } - } - return false; -} - -StatusOr IrEmitterUnnested::GetFusionEmissionResult( - const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis) { - FusionEmissionResult emission_result; - switch (fusion_analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: { - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kLoop: { - // TODO(anlunx): Support MemcpyFusion and InPlaceDymaicUpdateSlice. - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kReduction: { - auto emitter = std::make_unique(fusion_analysis); - TF_ASSIGN_OR_RETURN( - emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, - *instr, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kTriton: { - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); -#if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN(emission_result, - EmitTritonFusion(fusion_analysis, instr, nullptr)); - break; -#endif - LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); - } - case HloFusionAnalysis::EmitterFusionKind::kScatter: { - TF_ASSIGN_OR_RETURN(emission_result, - EmitScatter(instr, nullptr, fusion_analysis)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); - TF_ASSIGN_OR_RETURN( - emission_result, - EmitCustomFusion(instr, backend_config.custom_fusion_config())); - break; - } - default: - return FailedPrecondition( - "Fusion type not supported by the HLO emitter."); - break; - } - - return emission_result; -} - -Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, - HloFusionAnalysis& fusion_analysis) { - TF_ASSIGN_OR_RETURN(FusionEmissionResult emission_result, - GetFusionEmissionResult(instr, fusion_analysis)); - for (auto& thunk : emission_result.thunks) { - AddThunkToThunkSequence(std::move(thunk)); - } - return OkStatus(); } - -Status IrEmitterUnnested::EmitFusion( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto fusion_op = mlir::cast(op); - auto* fusion = Cast(hlo_for_lmhlo.at(fusion_op)); - - // Parse backend config. - FusionBackendConfig backend_config; - if (auto backend_config_str = fusion_op.getBackendConfig() - .value_or(mlir::Attribute()) - .dyn_cast_or_null()) { - auto status = tsl::HumanReadableJsonToProto(backend_config_str.str(), - &backend_config); - if (!status.ok()) { - LOG(ERROR) << "Ignoring invalid backend config on " - << GetIrNameFromLoc(op->getLoc()) << ": " - << backend_config_str.str(); - } - } - - // Create HloFusionAnalysis instance. - const se::DeviceDescription& device_info = - ir_emitter_context_->gpu_device_info(); - TF_ASSIGN_OR_RETURN(auto fusion_analysis, - HloFusionAnalysis::Create(fusion, &device_info)); - - FusionEmissionResult emission_result; - auto emitter_fusion_kind = fusion_analysis.GetEmitterFusionKind(); - switch (emitter_fusion_kind) { - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - case HloFusionAnalysis::EmitterFusionKind::kLoop: - case HloFusionAnalysis::EmitterFusionKind::kReduction: - case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - std::optional> emitter = - GetFusionEmitter(fusion_analysis, ir_emitter_context_->allocations(), - fusion_op); - if (emitter == std::nullopt) { - return FailedPrecondition( - "Fusion should have been handled by GetFusionEmitter."); - } - TF_ASSIGN_OR_RETURN( - emission_result, - (*emitter)->Emit(*ir_emitter_context_, elemental_emitter_, fusion_op, - *fusion, kernel_reuse_cache_, &b_)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kTriton: { -#if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN(emission_result, - EmitTritonFusion(fusion_analysis, fusion, fusion_op)); - break; -#endif - LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); - } - case HloFusionAnalysis::EmitterFusionKind::kScatter: { - TF_ASSIGN_OR_RETURN(emission_result, - EmitScatter(fusion, fusion_op, fusion_analysis)); - break; - } - case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: - LOG(FATAL) << "kCustomFusion is not supported by JitRt runtime"; - } - - for (auto& thunk : emission_result.thunks) { - AddThunkToThunkSequence(std::move(thunk)); - } - return OkStatus(); + +absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, + HloFusionAnalysis& fusion_analysis) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr emitter, + GetFusionEmitter(HloFusionInfo( + fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()))); + return AddThunksToThunkSequence(emitter->Emit(*ir_emitter_context_, *instr)); } -Status IrEmitterUnnested::AssertNonDeterminismIsOkay( +absl::Status IrEmitterUnnested::AssertNonDeterminismIsOkay( const std::string& op_name) { if (ir_emitter_context_->debug_options().xla_gpu_deterministic_ops()) { return Unimplemented( @@ -2346,52 +1934,40 @@ Status IrEmitterUnnested::AssertNonDeterminismIsOkay( "--xla_gpu_deterministic_ops.", op_name); } - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitSelectAndScatter( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto select_and_scatter_op = mlir::cast(op); - auto* select_and_scatter = - Cast(hlo_for_lmhlo.at(op)); - - const Shape source_shape = GetShape(select_and_scatter_op.getSource()); - const Shape operand_shape = GetShape(select_and_scatter_op.getOperand()); +absl::Status IrEmitterUnnested::EmitSelectAndScatter( + const HloSelectAndScatterInstruction* instr) { + const HloInstruction* operand = instr->operand(0); + const HloInstruction* source = instr->operand(1); + const Shape source_shape = source->shape(); + const Shape operand_shape = operand->shape(); const int64_t rank = operand_shape.rank(); + Window window = instr->window(); + CHECK_EQ(rank, source_shape.rank()); - if (select_and_scatter_op.getWindowDimensions()) { - CHECK_EQ(rank, select_and_scatter_op.getWindowDimensions()->size()); - } + CHECK_EQ(rank, window.dimensions_size()); - TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay( - mlir::mhlo::GetDebugNameFromLocation(select_and_scatter_op.getLoc()))); + std::string name = llvm_ir::IrName(instr); - std::string name = GetIrNameFromLoc(select_and_scatter_op.getLoc()); + TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(name)); - const HloInstruction* init_value = select_and_scatter->operand(2); + const HloInstruction* init_value = instr->operand(2); // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk // consisting of two thunks, an initializer KernelThunk that initializes // the output and another KernelThunk that accumulates the scattered // elements. - TF_RETURN_IF_ERROR(BuildInitializerThunk(op, select_and_scatter, init_value, - select_and_scatter_op.getInitValue(), - select_and_scatter_op.getOut())); + TF_RETURN_IF_ERROR(BuildInitializerThunk(instr, init_value)); - TF_ASSIGN_OR_RETURN( - LaunchDimensions launch_dimensions, - CalculateLaunchDimensions(source_shape, - ir_emitter_context_->gpu_device_info())); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source_shape, ir_emitter_context_->gpu_device_info()); // Init value is not needed in IR emission. - TF_ASSIGN_OR_RETURN(auto ir_arrays, BuildKernelThunkForNonFusionOp( - select_and_scatter_op, - {select_and_scatter_op.getOperand(), - select_and_scatter_op.getSource(), - select_and_scatter_op.getOut()}, - launch_dimensions)); + TF_ASSIGN_OR_RETURN(auto ir_arrays, + BuildKernelThunkForNonFusionOp(instr, {operand, source}, + launch_dimensions)); auto& [inputs, outputs] = ir_arrays; CHECK_EQ(inputs.size(), 3); @@ -2400,8 +1976,8 @@ Status IrEmitterUnnested::EmitSelectAndScatter( const llvm_ir::IrArray& source_array = inputs[1]; const llvm_ir::IrArray& out_array = inputs[2]; - llvm::Type* index_type = GetIndexTypeForKernel( - select_and_scatter_op, launch_dimensions.launch_bound(), &b_); + llvm::Type* index_type = + GetIndexTypeForKernel(instr, launch_dimensions.launch_bound(), &b_); auto index_typed_constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; @@ -2427,7 +2003,7 @@ Status IrEmitterUnnested::EmitSelectAndScatter( // if initialized_flag: // output(selected_index) = scatter(output(selected_index), source(S)) auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& source_index) -> Status { + [&](const llvm_ir::IrArray::Index& source_index) -> absl::Status { // Allocate space to keep the currently selected value, its index, and a // boolean flag if the value is initialized. The initialized_flag is set // false. @@ -2450,11 +2026,10 @@ Status IrEmitterUnnested::EmitSelectAndScatter( index_type); DimensionVector window_size; - mlir::DenseIntElementsAttr window_dimensions = - select_and_scatter_op.getWindowDimensions().value(); - for (const auto& dim : window_dimensions) { - window_size.push_back(dim.getSExtValue()); - CHECK_GT(dim.getSExtValue(), 0); + for (const WindowDimension& dim : window.dimensions()) { + auto size = static_cast(dim.size()); + window_size.push_back(size); + CHECK_GT(size, 0); } const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( @@ -2469,14 +2044,9 @@ Status IrEmitterUnnested::EmitSelectAndScatter( std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); - auto strides = *select_and_scatter_op.getWindowStrides(); - auto paddings = *select_and_scatter_op.getPadding(); - - for (const auto& stride_and_padding : - llvm::enumerate(llvm::zip(strides, paddings))) { - const int i = stride_and_padding.index(); - int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue(); - int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue(); + for (const auto [i, value] : llvm::enumerate(window.dimensions())) { + auto stride = static_cast(value.stride()); + auto padding = static_cast(value.padding_low()); llvm::Value* strided_index = NSWMul(source_index[i], index_typed_constant(stride)); @@ -2528,7 +2098,7 @@ Status IrEmitterUnnested::EmitSelectAndScatter( llvm_ir::PrimitiveTypeToIrType(PRED, module_), "select_return_buffer", &b_); - const HloComputation* select_computation = select_and_scatter->select(); + const HloComputation* select_computation = instr->select(); TF_RETURN_IF_ERROR(CallNestedComputation( &b_, *ir_emitter_context_, *select_computation, {selected_value_address, operand_address}, select_return_buffer)); @@ -2573,7 +2143,7 @@ Status IrEmitterUnnested::EmitSelectAndScatter( Load(selected_index_address->getAllocatedType(), selected_index_address_slot)); } - const Shape output_shape = GetShape(select_and_scatter_op.getOut()); + const Shape output_shape = instr->shape(); llvm::Value* source_value_address = source_array.EmitArrayElementAddress(source_index, &b_); llvm_ir::IrArray::Index selected_index(selected_multi_index, output_shape, @@ -2581,7 +2151,7 @@ Status IrEmitterUnnested::EmitSelectAndScatter( llvm::Value* output_value_address = out_array.EmitArrayElementAddress(selected_index, &b_); - const HloComputation* scatter_computation = select_and_scatter->scatter(); + const HloComputation* scatter_computation = instr->scatter(); return EmitAtomicOperationForNestedComputation( &b_, *ir_emitter_context_, *scatter_computation, output_value_address, source_value_address, source_array.GetElementLlvmType()); @@ -2592,248 +2162,71 @@ Status IrEmitterUnnested::EmitSelectAndScatter( .EmitLoop(name, index_type); } -Status IrEmitterUnnested::EmitWhile( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto while_op = mlir::cast(op); - - auto cond_result = GetHloOutputs(while_op); - TF_RET_CHECK(cond_result.size() == 1); - TF_RET_CHECK(cond_result[0] - .getType() - .cast() - .getElementType() - .isInteger(/*width=*/1)) - << "While condition computation must return bool"; - - // Build ForThunk for conformant while loops, otherwise build WhileThunk. - // - // If Xla runtime is enabled we always lower to `lmhlo.while` operation and - // rely on `lmhlo-to-gpu-runtime` to lower while loops with known trip counts - // to `scf.for` loops. - if (while_op.getTripCount() && - !IsXlaRuntimeExecutableEnabled( - ir_emitter_context_->hlo_module().config())) { - TF_ASSIGN_OR_RETURN( - auto thunk, - BuildForThunk(while_op, Thunk::ThunkInfo::WithProfileAnnotation(op), - *while_op.getTripCount(), hlo_for_lmhlo)); - AddThunkToThunkSequence(std::move(thunk)); - } else { - TF_ASSIGN_OR_RETURN( - auto thunk, - BuildWhileThunk(while_op, Thunk::ThunkInfo::WithProfileAnnotation(op), - hlo_for_lmhlo)); - AddThunkToThunkSequence(std::move(thunk)); - } - return OkStatus(); -} +absl::Status IrEmitterUnnested::EmitWhile(const HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto config, + instr->backend_config()); + + std::optional trip_count = std::nullopt; + if (config.has_known_trip_count()) trip_count = config.known_trip_count().n(); -Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) { - auto rng_op = mlir::dyn_cast(op); + TF_ASSIGN_OR_RETURN( + auto thunk, + BuildWhileThunk(instr, Thunk::ThunkInfo::WithProfileAnnotation(instr), + trip_count)); + + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); +} +absl::Status IrEmitterUnnested::EmitRngGetAndUpdateState( + const HloRngGetAndUpdateStateInstruction* instr) { // Emit a kernel to increment the global state for Philox RNG algorithm. - TF_ASSIGN_OR_RETURN(auto ir_arrays, - BuildKernelThunkForNonFusionOp( - rng_op /*, rng_op.getState(),*/, LaunchDimensions())); + TF_ASSIGN_OR_RETURN(auto ir_arrays, BuildKernelThunkForNonFusionOp( + instr, {}, LaunchDimensions())); auto& [inputs, outputs] = ir_arrays; - llvm::Value* old_state = - llvm_ir::RngGetAndUpdateState(rng_op.getDelta(), module_, &b_); - - const Shape shape = GetShape(rng_op.getState()); - + llvm_ir::RngGetAndUpdateState(instr->delta(), module_, &b_); llvm::Value* output_address = inputs[0].EmitArrayElementAddress( llvm_ir::IrArray::Index( - /*linear=*/b_.getInt64(0), shape, &b_), + /*linear=*/b_.getInt64(0), instr->shape(), &b_), &b_, "rng_state_address"); Store(old_state, output_address); - - return OkStatus(); -} - -Status IrEmitterUnnested::EmitScatter( - const ScatterDescriptor& desc, const LaunchDimensions& launch_dimensions) { - auto loop_body_emitter = [&](const llvm_ir::IrArray::Index& index) -> Status { - std::vector raw_window_multidim; - std::vector input_scatter_multidim; - std::vector raw_window_bounds; - - auto get_i64_array = [](absl::Span container) { - return llvm::ArrayRef{container.data(), - static_cast(container.size())}; - }; - - llvm::ArrayRef update_window_dims = - get_i64_array(desc.dim_numbers.update_window_dims()); - // Partition the index into window indices and scatter indices. - for (int64_t i = 0, e = index.size(); i != e; ++i) { - // For window indices also remember the window size, this comes in handy - // later. - if (llvm::is_contained(update_window_dims, i)) { - raw_window_multidim.push_back(index[i]); - raw_window_bounds.push_back(desc.updates_shape.dimensions(i)); - } else { - input_scatter_multidim.push_back(index[i]); - } - } - DCHECK_EQ(raw_window_multidim.size(), - desc.dim_numbers.update_window_dims_size()); - - // Apply inserted_window_dims to the window dimensions. - int64_t raw_window_multidim_idx = 0; - llvm::SmallVector input_window_multidim; - llvm::SmallVector input_window_bounds; - const int64_t rank = desc.operand_shape.rank(); - input_window_bounds.reserve(rank); - input_window_multidim.reserve(rank); - - llvm::ArrayRef inserted_window_dims = - get_i64_array(desc.dim_numbers.inserted_window_dims()); - for (int64_t i = 0; i != rank; ++i) { - if (llvm::is_contained(inserted_window_dims, i)) { - input_window_bounds.push_back(1); // Trivial dimension. - input_window_multidim.push_back(index.GetConstantWithIndexType(0)); - } else { - input_window_bounds.push_back( - raw_window_bounds[raw_window_multidim_idx]); - input_window_multidim.push_back( - raw_window_multidim[raw_window_multidim_idx]); - ++raw_window_multidim_idx; - } - } - DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank()); - - // Insert a 1 dimension at the end if index_vector_dim requests one. - Shape scatter_indices_shape_fixed = desc.scatter_indices_shape; - if (desc.dim_numbers.index_vector_dim() == - desc.scatter_indices_shape.rank()) { - scatter_indices_shape_fixed.add_dimensions(1); - scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major( - desc.dim_numbers.index_vector_dim()); - } - - // Now load the indices corresponding to the current window from - // scatter_indices. - std::vector raw_scatter_index_multidim = - input_scatter_multidim; - raw_scatter_index_multidim.insert(raw_scatter_index_multidim.begin() + - desc.dim_numbers.index_vector_dim(), - nullptr); - - llvm::ArrayRef scatter_dims_to_operand_dims = - get_i64_array(desc.dim_numbers.scatter_dims_to_operand_dims()); - llvm::Value* is_in_bounds = b_.getTrue(); - for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) { - // Our index is stored along index_vector_dim, insert that into the lookup - // index into scatter_indices. - raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim()] = - index.GetConstantWithIndexType(i); - llvm_ir::IrArray::Index raw_scatter_index_index( - raw_scatter_index_multidim, scatter_indices_shape_fixed, - index.GetType()); - - int64_t operand_dim = scatter_dims_to_operand_dims[i]; - if (operand_dim > rank) { - return absl::OutOfRangeError( - "The provided scatter_dims_to_operand_dims was out of range."); - } - TF_ASSIGN_OR_RETURN( - llvm::Value* const loaded_scatter_index, - desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( - scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_))); - // And add the index to our window index. This yields the output index. - llvm::Value* casted_scatter_index = IntCast( - loaded_scatter_index, index.GetType(), - /*isSigned=*/ShapeUtil::ElementIsSigned(desc.scatter_indices_shape)); - llvm::Value* dim_offset = - Add(input_window_multidim[operand_dim], casted_scatter_index); - input_window_multidim[operand_dim] = dim_offset; - - // Also do the bounds check now. - int64_t max_index = desc.operand_shape.dimensions(operand_dim) - - input_window_bounds[operand_dim] + 1; - // is_in_bounds = index >= 0 && index < dim_size-window_size+1 - // --> index u< dim_size-window_size+1 - is_in_bounds = - And(is_in_bounds, ICmpULT(casted_scatter_index, - index.GetConstantWithIndexType(max_index))); - } - - llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( - is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false); - llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); - // All done, now just read from the calculated input from the window, and do - // an atomic store to the calculated location in the output. - llvm_ir::IrArray::Index input_window_index( - input_window_multidim, desc.output.GetShape(), index.GetType()); - llvm::Value* output_address = - desc.output.EmitArrayElementAddress(input_window_index, &b_); - llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(), - module_), - "input_address", &b_); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - desc.updates_gen(index)); - Store(input_ir_value, input_address); - - if (!desc.unique_indices) { - return EmitAtomicOperationForNestedComputation( - &b_, *ir_emitter_context_, *desc.update_computation, output_address, - input_address, desc.output.GetElementLlvmType()); - } else { - return CallNestedComputation( - &b_, *ir_emitter_context_, *desc.update_computation, - {output_address, input_address}, output_address); - } - }; - - // Launch a kernel that reads every element in the updates tensor. We could - // also do one kernel per window instead if bounds checks turn out to be a - // bottleneck. - return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape, - launch_dimensions, &b_) - .EmitLoop(desc.name, - desc.get_index_type(launch_dimensions.launch_bound())); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitSort( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - auto sort_op = mlir::cast(op); - auto* sort = hlo_for_lmhlo.at(op); - - std::string op_name = GetIrNameFromLoc(sort_op.getLoc()); - llvm::SmallVector operands = GetHloOperands(sort_op); - const Shape& keys_shape = GetShape(operands[0]); - int64_t dimension_to_sort = sort_op.getDimension(); - for (int64_t i = 0; i < operands.size(); ++i) { +absl::Status IrEmitterUnnested::EmitSort(const HloSortInstruction* sort) { + std::string op_name(sort->name()); + const Shape& keys_shape = sort->operand(0)->shape(); + int64_t dimension_to_sort = sort->sort_dimension(); + for (int64_t i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the // same. - TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, GetShape(operands[i]))); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, + sort->operand(i)->shape())); TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - keys_shape, GetShape(GetHloOutputs(sort_op)[i]))); - - // If possible, we share buffers. If that is not possible, we need to copy - // the values, because the emitter does the sorting in-place. - TF_ASSIGN_OR_RETURN(auto destination_buffer, - GetAllocationSlice(sort_op.getOutput()[i])); - TF_ASSIGN_OR_RETURN(auto source_address, - GetAllocationSlice(sort_op.getOperands()[i])); + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + + BufferAllocation::Slice destination_buffer; + BufferAllocation::Slice source_address; + + // If possible, we share buffers. If that is not possible, we need to + // copy the values, because the emitter does the sorting in-place. + TF_ASSIGN_OR_RETURN(destination_buffer, + GetAllocationSliceForHlo(sort, shape_index)); + TF_ASSIGN_OR_RETURN(source_address, + GetAllocationSliceForHlo(sort->operand(i), {})); + if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. VLOG(2) << op_name << " requires initial D2D copy for operand " << i; AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo(op), + Thunk::ThunkInfo::WithProfileAnnotation(sort), /*source_buffer=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(GetShape(operands[i])), - /*source_value=*/sort_op.getOperands()[i], - /*destination_value=*/sort_op.getOutput()[i])); + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); } } @@ -2879,10 +2272,8 @@ Status IrEmitterUnnested::EmitSort( standard_iteration_shape.set_dimensions(dimension_to_sort, standard_num_iterations_in_sort_dim); - TF_ASSIGN_OR_RETURN( - LaunchDimensions standard_launch_dimensions, - CalculateLaunchDimensions(standard_iteration_shape, - ir_emitter_context_->gpu_device_info())); + LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions( + standard_iteration_shape, ir_emitter_context_->gpu_device_info()); // Calculate the launch dimensions for the case where we use tiling. We split // the dimension that should be sorted into tiles of size 'kTileSize'. This @@ -2905,10 +2296,10 @@ Status IrEmitterUnnested::EmitSort( // Check whether we should use any tiling. We might not be able to use it if // we have not enough threads, or not enough shared memory. int64_t total_shared_memory_needed = 0; - for (int64_t i = 0; i < operands.size(); ++i) { + for (int64_t i = 0; i < sort->operand_count(); ++i) { total_shared_memory_needed += kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( - GetShape(operands[i]).element_type()); + sort->operand(i)->shape().element_type()); } bool no_tiling = kThreadsPerBlock > @@ -2937,9 +2328,9 @@ Status IrEmitterUnnested::EmitSort( LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; - TF_ASSIGN_OR_RETURN(auto ir_arrays, - BuildKernelThunkForNonFusionOp( - sort_op, sort_op.getOutput(), launch_dimensions)); + TF_ASSIGN_OR_RETURN(auto ir_arrays, BuildKernelThunkForNonFusionOp( + sort, {}, launch_dimensions)); + auto& [inputs, outputs] = ir_arrays; auto* comparator = sort->called_computations().front(); return llvm_ir::EmitSortInPlace( @@ -2976,108 +2367,145 @@ Status IrEmitterUnnested::EmitSort( if (!xor_masks.empty()) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } - return OkStatus(); + return absl::OkStatus(); } -template -Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) { - auto casted = mlir::cast(op); +template +absl::Status IrEmitterUnnested::EmitReplicaOrPartitionId( + const HloInstruction* instr) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(casted.getOperand())); + GetAllocationSliceForHlo(instr, {})); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), result_slice); + Thunk::ThunkInfo::WithProfileAnnotation(instr), result_slice); AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } -template -Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) { - auto collective_permute_op = mlir::cast(op); - +Status IrEmitterUnnested::EmitCollectivePermute( + const HloCollectivePermuteInstruction* instr) { + TF_RET_CHECK(instr->operand_count() == 1); + auto* operand = instr->operand(0); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice, - GetAllocationSlice(collective_permute_op.getOperand())); + GetAllocationSliceForHlo(operand)); + // First output is aliased. + TF_RET_CHECK( + instr->shape().IsTuple() && instr->shape().tuple_shapes_size() == 2 && + Shape::Equal().IgnoreMemorySpaceInLayout()( + instr->shape().tuple_shapes(0), instr->shape().tuple_shapes(1))); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - GetAllocationSlice(collective_permute_op.getOutput())); + GetAllocationSliceForHlo(instr, {1})); - const Shape shape = GetShape(collective_permute_op.getOperand()); + const Shape shape = operand->shape(); const auto& hlo_config = ir_emitter_context_->hlo_module().config(); const int64_t replica_count = hlo_config.replica_count(); const int64_t partition_count = hlo_config.num_partitions(); + const int64_t src_memory_space = shape.layout().memory_space(); + const int64_t dst_memory_space = + instr->shape().tuple_shapes(1).layout().memory_space(); - NcclCollectiveThunk::AsyncExecutor* async_executor; - if (NcclThunkType::IsDegenerate(collective_permute_op, replica_count, - partition_count)) { + if (NcclCollectivePermuteStartThunk::IsDegenerate(instr, replica_count, + partition_count)) { // For a degenerate collective permute, just generate a copy thunk. AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), + Thunk::ThunkInfo::WithProfileAnnotation(instr), /*source_buffer=*/source_slice, /*destination_buffer=*/result_slice, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/collective_permute_op.getOperand(), - /*destination_value=*/collective_permute_op.getOutput())); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); // Signal that start thunk not created with nullptr. - async_executor = nullptr; + GetCollectivesAsyncEvents().try_emplace(instr, nullptr); } else { const NcclCollectiveThunk::Buffer buffer = { /*element_count=*/ShapeUtil::ElementsIn(shape), /*source_buffer=*/source_slice, - /*destination_buffer=*/result_slice}; - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), collective_permute_op, - replica_count, partition_count, buffer); - async_executor = thunk->async_executor(); + /*destination_buffer=*/result_slice, + /*source_memory_space=*/src_memory_space, + /*destination_memory_space=*/dst_memory_space}; + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(), + instr, replica_count, partition_count, buffer); + GetCollectivesAsyncEvents().try_emplace(instr, thunk->async_events()); AddThunkToThunkSequence(std::move(thunk)); } - async_executors_.insert({op, async_executor}); - return OkStatus(); + return absl::OkStatus(); } -template -Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) { - OpT op = mlir::cast(untyped_op); +template +absl::Status IrEmitterUnnested::EmitNcclThunk( + Thunk::Kind kind, const HloInstruction* async_start, + const HloInstType* inst, std::optional use_global_device_ids) { const auto& hlo_config = ir_emitter_context_->hlo_module().config(); int64_t replica_count = hlo_config.replica_count(); int64_t partition_count = hlo_config.num_partitions(); VLOG(2) << NcclThunkType::GetHloOpName() << "; replica count: " << replica_count << "; partition count: " << partition_count - << "; operand count: " << op.getOperands().size() - << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled(); + << "; operand count: " << inst->operand_count(); // A given collective op can be degenerate if across all groups formed // by it are singleton. In such a case, we don't need to do any communication // and we can just copy the input to the output. - bool is_degenerate = - NcclThunkType::IsDegenerate(op, replica_count, partition_count); - Status implementable_status = - NcclThunkType::CheckImplementable(op, replica_count, partition_count); + bool is_degenerate = GetNcclCollectiveConfig(inst, use_global_device_ids) + .IsDegenerate(replica_count, partition_count); + absl::Status implementable_status = + NcclThunkType::CheckImplementable(inst, replica_count, partition_count); bool should_use_nccl_thunk = !is_degenerate && implementable_status.ok(); // Stash relevant information in NcclCollectiveThunk::Buffer even if we may // not generate an NcclCollectiveThunk. std::vector buffers; - buffers.reserve(op.getInputs().size()); - for (auto it : llvm::zip(op.getInputs(), op.getOutputs())) { - mlir::Value operand = std::get<0>(it); - mlir::Value result = std::get<1>(it); - const Shape shape = GetShape(operand); - TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand)); - TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result)); + + int64_t operand_count = inst->operand_count(); + buffers.reserve(operand_count); + + // Adds a source and destination buffers pair to `buffers`. + auto add_buffer = [&](int64_t element_count, BufferAllocation::Slice src, + int64_t src_memory_space, BufferAllocation::Slice dst, + int64_t dst_memory_space) { buffers.push_back(NcclCollectiveThunk::Buffer{ - /*element_count=*/ShapeUtil::ElementsIn(shape), - /*source_buffer=*/source_slice, - /*destination_buffer=*/dest_slice, - /*source_value=*/operand, - /*destination_value=*/result}); + /*element_count=*/element_count, + /*source_buffer=*/src, + /*destination_buffer=*/dst, + /*source_memory_space=*/src_memory_space, + /*destination_memory_space=*/dst_memory_space, + /*source_value=*/nullptr, + /*destination_value=*/nullptr}); + }; + + if (kind == Thunk::Kind::kNcclAllGatherStart) { + // Start operations return a tuple of (<>, <>) where + // outputs can be a tuple itself (if operation has multiple operands). + for (int64_t i = 0; i < operand_count; i++) { + ShapeIndex idx = operand_count > 1 ? ShapeIndex({1, i}) : ShapeIndex({1}); + const Shape& src_shape = inst->operand(i)->shape(); + const Shape& dst_shape = ShapeUtil::GetSubshape(inst->shape(), idx); + TF_ASSIGN_OR_RETURN(auto src, GetAllocationSliceForHlo(inst->operand(i))); + TF_ASSIGN_OR_RETURN(auto dst, GetAllocationSliceForHlo(inst, idx)); + add_buffer(ShapeUtil::ElementsIn(src_shape), src, + src_shape.layout().memory_space(), dst, + dst_shape.layout().memory_space()); + } + + } else { + // For other operations simply zip operands with results. + for (int64_t i = 0; i < operand_count; i++) { + ShapeIndex idx = operand_count > 1 ? ShapeIndex({i}) : ShapeIndex({}); + const Shape& src_shape = inst->operand(i)->shape(); + const Shape& dst_shape = ShapeUtil::GetSubshape(inst->shape(), idx); + TF_ASSIGN_OR_RETURN(auto src, GetAllocationSliceForHlo(inst->operand(i))); + TF_ASSIGN_OR_RETURN(auto dst, GetAllocationSliceForHlo(inst, idx)); + add_buffer(ShapeUtil::ElementsIn(src_shape), src, + src_shape.layout().memory_space(), dst, + dst_shape.layout().memory_space()); + } } if (should_use_nccl_thunk) { auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), op, + Thunk::ThunkInfo::WithProfileAnnotation(inst), NcclApi::Default(), inst, /*buffers=*/std::move(buffers)); - async_executors_.insert({untyped_op, thunk->async_executor()}); + GetCollectivesAsyncEvents().insert({async_start, thunk->async_events()}); AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + return absl::OkStatus(); } if (!is_degenerate) { @@ -3085,7 +2513,7 @@ Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) { } // Signal that start thunk not created with nullptr. - async_executors_.insert({untyped_op, nullptr}); + GetCollectivesAsyncEvents().insert({async_start, nullptr}); VLOG(1) << "Collective call is degenerate, not doing NCCL call"; @@ -3093,594 +2521,697 @@ Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) { // assignment expects a copy, so that's what we do. ThunkSequence thunks; for (int64_t i = 0; i < buffers.size(); i++) { - const Shape shape = GetShape(op.getOperands()[i]); + const Shape shape = inst->operand(i)->shape(); thunks.push_back(std::make_unique( - buffers.size() == 1 ? Thunk::ThunkInfo::WithProfileAnnotation(op) - : Thunk::ThunkInfo(op), + Thunk::ThunkInfo::WithProfileAnnotation(inst), /*source_buffer=*/buffers[i].source_buffer, /*destination_buffer=*/buffers[i].destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), - /*source_value=*/buffers[i].source_value, - /*destination_value=*/buffers[i].destination_value)); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); } if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(thunks))); + Thunk::ThunkInfo::WithProfileAnnotation(inst), std::move(thunks))); } - return OkStatus(); + return absl::OkStatus(); } -template -Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind, - mlir::Operation* op) { - auto start_op = mlir::cast(op).getToken().getDefiningOp(); - auto async_executor = async_executors_.extract(start_op); - TF_RET_CHECK(async_executor) << "couldn't find async executor for start op"; +absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind, + const HloInstruction* inst) { + CollectivesAsyncEvents& collectives_async_events = + GetCollectivesAsyncEvents(); + if (kind == Thunk::Kind::kNcclRecvDone || + kind == Thunk::Kind::kNcclSendDone) { + const HloChannelInstruction* done = DynCast(inst); + int64_t channel_id = done->channel_id().value(); + // We only pipeline Send/Recv when channel_id > 0, and allows multiple + // and potentially interleaving Send/Recv chains using channel_id = 0. + if (MayPipelineSendRecvChannel(channel_id)) { + auto it = collectives_async_events.find( + GetSendRecvAsyncEventsKey(kind, channel_id)); + TF_RET_CHECK(it != collectives_async_events.end()) + << "couldn't find async events for channel_id " << channel_id; + AddThunkToThunkSequence(std::make_unique( + kind, Thunk::ThunkInfo::WithProfileAnnotation(inst), it->second)); + return absl::OkStatus(); + } + } + + const HloInstruction* start = inst->operand(0); + auto async_events = collectives_async_events.extract(start); + TF_RET_CHECK(async_events) + << "couldn't find async events for start operation"; // Can be null if no start thunk was created (e.g. if the start op is // degenerate), in which case there's nothing to do here. - if (async_executor.mapped() != nullptr) { + if (async_events.mapped()) { AddThunkToThunkSequence(std::make_unique( - kind, Thunk::ThunkInfo::WithProfileAnnotation(op), - *async_executor.mapped())); + kind, Thunk::ThunkInfo::WithProfileAnnotation(inst), + std::move(async_events.mapped()))); + } + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitWaitForStreamsThunk( + const HloInstruction* inst, GpuBackendConfig& gpu_config, + bool is_async_done) { + std::vector wait_on_streams; + ExecutionStreamId source_stream_id = Thunk::kDefaultExecutionStreamId; + // If it's for an async done, then we need to synchronize on the execution + // stream of the instruction from main compute stream + if (is_async_done) { + wait_on_streams.push_back( + ExecutionStreamId(gpu_config.operation_queue_id())); + } else if (gpu_config.wait_on_operation_queues().size() == 0) { + // If wait on queue is empty, we just synchronize on the main compute + // stream from the execution stream. + wait_on_streams.push_back(Thunk::kDefaultExecutionStreamId); + source_stream_id = gpu_config.operation_queue_id(); + } else { + // Else, we synchronize on all specified + // streams from the execution stream. + for (int64_t stream_id : gpu_config.wait_on_operation_queues()) { + wait_on_streams.push_back(ExecutionStreamId(stream_id)); + } + source_stream_id = gpu_config.operation_queue_id(); } - return OkStatus(); + + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(inst), source_stream_id, + wait_on_streams)); + return absl::OkStatus(); } -StatusOr> IrEmitterUnnested::GetShapedSlices( - mlir::Operation::operand_range operands) { +absl::Status IrEmitterUnnested::EmitInfeed(const HloInfeedInstruction* instr) { + // Infeed instruction returns a tuple containing the result data and a token. + // We only need the result data to construct the infeed thunk. std::vector shaped_slices; - shaped_slices.reserve(operands.size()); - for (mlir::Value opnd : operands) { - TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(opnd)); - shaped_slices.push_back(ShapedSlice{slice, GetShape(opnd)}); - } - return shaped_slices; -} + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { + if (subshape.IsTuple() || subshape.IsToken()) return absl::OkStatus(); + if (subshape.IsArray()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data, + GetAllocationSliceForHlo(instr, index)); + ShapedSlice shaped_slice = {data, subshape}; + shaped_slices.push_back(shaped_slice); + return absl::OkStatus(); + } + return Internal("Unexpected shape kind for %s and shape index %s", + instr->ToString(), index.ToString()); + })); -Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) { - mlir::Operation::operand_range operands = - mlir::cast(op).getOutputs(); - TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands)); auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(shaped_slices)); + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(shaped_slices)); AddThunkToThunkSequence(std::move(thunk)); - - return OkStatus(); + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) { - mlir::Operation::operand_range operands = - mlir::cast(op).getInputs(); - TF_ASSIGN_OR_RETURN(auto shaped_slices, GetShapedSlices(operands)); +absl::Status IrEmitterUnnested::EmitOutfeed( + const HloOutfeedInstruction* instr) { + // HLO outfeed instruction has 2 operands, the source and a token, and a + // single token output. + const HloInstruction* source = instr->operand(0); + std::vector shaped_slices; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + source->shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { + if (subshape.IsTuple()) return absl::OkStatus(); + if (subshape.IsArray()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data, + GetAllocationSliceForHlo(source, index)); + ShapedSlice shaped_slice = {data, subshape}; + shaped_slices.push_back(shaped_slice); + return absl::OkStatus(); + } + return Internal("Unexpected shape kind for %s and shape index %s", + source->ToString(), index.ToString()); + })); + auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(shaped_slices)); + Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(shaped_slices)); AddThunkToThunkSequence(std::move(thunk)); - - return OkStatus(); + return absl::OkStatus(); } -StatusOr< - std::pair, std::vector>> +absl::StatusOr /*inputs*/, + std::vector /*outputs*/>> IrEmitterUnnested::BuildKernelThunkForNonFusionOp( - mlir::Operation* op, mlir::ValueRange needed_operands, + const HloInstruction* hlo, + absl::Span needed_operands, const LaunchDimensions& launch_dimensions) { - TF_RET_CHECK(!mlir::isa(op)) - << "Please use BuildKernelThunkForFusion!"; - - std::string suggested_kernel_name = GetIrNameFromLoc(op->getLoc()); + std::string suggested_kernel_name(hlo->name()); TF_ASSIGN_OR_RETURN( auto kernel_arguments, - KernelArguments::Create(ir_emitter_context_->allocations(), op, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), hlo, needed_operands)); VLOG(3) << "Generating (without reuse check): " << suggested_kernel_name; - auto [kernel, inputs, outputs] = BuildKernelPrototype( - *ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), - needed_operands.size(), launch_dimensions, &b_); + llvm::Function* kernel; + std::vector inputs; + std::vector outputs; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototype( + *ir_emitter_context_, suggested_kernel_name, kernel_arguments.args(), + kernel_arguments.args().size(), launch_dimensions, &b_)); AddThunkToThunkSequence(std::make_unique( - op, kernel->getName().str(), kernel_arguments.args(), launch_dimensions, + hlo, kernel->getName().str(), kernel_arguments.args(), launch_dimensions, + /*cluster_dim=*/std::nullopt, /*shmem_bytes=*/0)); return {{inputs, outputs}}; } -StatusOr< - std::pair, std::vector>> -IrEmitterUnnested::BuildKernelThunkForNonFusionOp( - mlir::Operation* op, const LaunchDimensions& launch_dimensions) { - return BuildKernelThunkForNonFusionOp(op, op->getOperands(), - launch_dimensions); -} - -Status IrEmitterUnnested::BuildInitializerThunk( - mlir::Operation* op, const HloInstruction* instr, - const HloInstruction* init_value, mlir::Value init_value_mlir, - mlir::Value dest) { +absl::Status IrEmitterUnnested::BuildInitializerThunk( + const HloInstruction* instr, const HloInstruction* init_value) { // initial value must be a scalar memref. TF_RET_CHECK(init_value->shape().rank() == 0); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice, - GetAllocationSlice(dest)); + auto maybe_dest_slice = GetAllocationSliceForHlo(instr, {}); + if (!maybe_dest_slice.ok()) return maybe_dest_slice.status(); - TF_ASSIGN_OR_RETURN( - std::optional> constant_init_thunk, - BuildConstantInitializerThunk(*ir_emitter_context_, op, instr, init_value, - dest, dest_slice)); + BufferAllocation::Slice dest_slice = *maybe_dest_slice; + + TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, + BuildConstantInitializerThunk(*ir_emitter_context_, instr, + init_value, dest_slice)); if (constant_init_thunk) { AddThunkToThunkSequence(*std::move(constant_init_thunk)); - return OkStatus(); + return absl::OkStatus(); } // Otherwise fall back to our slow initializer code. The thunk in this case // will just need the IR arrays for the initial value and the destination. - const Shape dest_shape = GetShape(dest); + const Shape& dest_shape = instr->shape(); - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - dest_shape, ir_emitter_context_->gpu_device_info())); - TF_ASSIGN_OR_RETURN(auto ir_arrays, - BuildKernelThunkForNonFusionOp( - op, {init_value_mlir, dest}, launch_dimensions)); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + dest_shape, ir_emitter_context_->gpu_device_info()); + TF_ASSIGN_OR_RETURN( + auto ir_arrays, + BuildKernelThunkForNonFusionOp(instr, {init_value}, launch_dimensions)); auto& [inputs, outputs] = ir_arrays; auto init_array = inputs[0]; - std::string name = GetIrNameFromLoc(op->getLoc()); + std::string name = llvm_ir::IrName(instr, "init"); TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const llvm_ir::IrArray::Index& index) { return init_array.EmitReadArrayElement(index, &b_); }, {inputs[1]}, launch_dimensions, &b_) - .EmitLoop(GetIrNameFromLoc(op->getLoc()))); - return OkStatus(); + .EmitLoop(name)); + return absl::OkStatus(); } -StatusOr> IrEmitterUnnested::BuildWhileThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - const absl::flat_hash_map& - hlo_for_lmhlo) { +absl::StatusOr> IrEmitterUnnested::BuildWhileThunk( + const HloInstruction* instr, const Thunk::ThunkInfo& thunk_info, + std::optional trip_count) { + HloComputation* condition = instr->while_condition(); + HloComputation* body = instr->while_body(); + // Generate thunk sequence for while 'condition'. - mlir::Region* condition = &while_op.getCond(); auto ir_emitter_condition = IrEmitterUnnested::Create(ir_emitter_context_); - - TF_RETURN_IF_ERROR( - ir_emitter_condition->EmitLmhloRegion(condition, hlo_for_lmhlo)); + TF_RETURN_IF_ERROR(ir_emitter_condition->EmitHloComputation(condition)); // Generate thunk sequence for while 'body'. - mlir::Region* body = &while_op.getBody(); auto ir_emitter_body = IrEmitterUnnested::Create(ir_emitter_context_); + TF_RETURN_IF_ERROR(ir_emitter_body->EmitHloComputation(body)); - TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body, hlo_for_lmhlo)); - - // Extract the condition value from the last op (excluding the terminator op) - // in the condition region. - auto cond_result = GetHloOutputs(while_op); - TF_RET_CHECK(cond_result.size() == 1); - TF_ASSIGN_OR_RETURN(auto cond_result_slice, - GetAllocationSlice(cond_result[0])); + // Buffer slice holding while loop predicate. + TF_ASSIGN_OR_RETURN( + auto pred, GetAllocationSliceForHlo(condition->root_instruction(), {})); - return std::unique_ptr( - new WhileThunk(thunk_info, cond_result_slice, - ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence())); + return std::unique_ptr(new WhileThunk( + thunk_info, pred, ir_emitter_condition->ConsumeThunkSequence(), + ir_emitter_body->ConsumeThunkSequence(), trip_count)); } -StatusOr> IrEmitterUnnested::BuildForThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - const int64_t loop_limit, - const absl::flat_hash_map& - hlo_for_lmhlo) { - // Generate thunk sequence for while 'body' (will be used a For loop body). - auto ir_emitter_body = IrEmitterUnnested::Create(ir_emitter_context_); - TF_RETURN_IF_ERROR( - ir_emitter_body->EmitLmhloRegion(&while_op.getBody(), hlo_for_lmhlo)); - - return std::unique_ptr(new ForThunk( - thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence())); +absl::Status IrEmitterUnnested::EmitTargetElementLoop( + const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) { + return Internal("This should be unreachable"); } -Status IrEmitterUnnested::EmitTargetElementLoop( - const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) { - return InternalError("This should be unreachable"); +static absl::flat_hash_map ConvertFrontendAttributes( + const FrontendAttributes& attrs) { + absl::flat_hash_map result; + for (auto& [k, v] : attrs.map()) result[k] = v; + return result; } -StatusOr IrEmitterUnnested::EmitScatter( - const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, - HloFusionAnalysis& fusion_analysis) { - auto* fused_computation = fusion->fused_instructions_computation(); - auto* root = fused_computation->root_instruction(); +static std::optional DeviceConstraint( + const HloInstruction* hlo) { + if (hlo->has_sharding() && hlo->sharding().HasUniqueDevice()) { + return GlobalDeviceId(hlo->sharding().GetUniqueDevice()); + } + return std::nullopt; +} + +absl::Status IrEmitterUnnested::EmitCopyStartThunk( + const HloCopyStartInstruction* instr) { + // copy-start has a tuple shape: {host, device, context}, + // or {device, host, context}. + // Only the destination shape is needed to get the output buffer. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer, + GetAllocationSliceForHlo(instr, + /*ShapeIndex=*/{0})); + + const HloInstruction* src = instr->operand(0); + const Shape& input_shape = src->shape(); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer, + GetAllocationSliceForHlo(src, {})); + Shape shape = instr->shape(); + CHECK(shape.IsTuple()); + + if (shape.mutable_tuple_shapes(0)->has_layout() && + shape.mutable_tuple_shapes(0)->mutable_layout()->memory_space() == + static_cast(stream_executor::MemoryType::kHost)) { + VLOG(3) << "Device to Host: host memory space " + << static_cast(stream_executor::MemoryType::kHost); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + /*source_buffer=*/src_buffer, + /*destination_buffer=*/dst_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape)); + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); + } + if (shape.mutable_tuple_shapes(1)->has_layout() && + shape.mutable_tuple_shapes(1)->mutable_layout()->memory_space() == + static_cast(stream_executor::MemoryType::kHost)) { + VLOG(3) << "Host to Device from the host memory space " + << static_cast(stream_executor::MemoryType::kHost); + ; + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + /*source_buffer=*/src_buffer, + /*destination_buffer=*/dst_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape)); + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); + } - // Nothing should have been fused into the first operand of scatter. - CHECK_EQ(root->operand(0)->opcode(), HloOpcode::kParameter); + // Disabled the generation of memcpy D2D as only H2D and D2H are useful + // for memory offload now. - const Shape& updates_shape = root->operand(2)->shape(); + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + /*source_buffer=*/src_buffer, + /*destination_buffer=*/dst_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(input_shape)); + AddThunkToThunkSequence(std::move(thunk)); - TF_ASSIGN_OR_RETURN( - LaunchDimensions launch_dimensions, - CalculateLaunchDimensions(updates_shape, - ir_emitter_context_->gpu_device_info())); - - auto builder_fn = [&, this](std::vector inputs, - std::vector outputs) -> Status { - // Spin up a new fused emitter for the scatter kernel and emit it. - FusedIrEmitter scatter_fused_emitter = FusedIrEmitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - auto fused_operand = fused_computation->parameter_instruction(i); - scatter_fused_emitter.BindGenerator( - *fused_operand, [this, &input = inputs[i], - fused_operand](llvm_ir::IrArray::Index index) { - return input.EmitReadArrayElement(index, &b_, - fused_operand->name()); - }); + return absl::OkStatus(); +} + +absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) { + if (!instr->channel_id().has_value()) + return absl::InternalError("Unknown send instruction channel id"); + + const HloInstruction* src = instr->operand(0); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, + GetAllocationSliceForHlo(src, {})); + if (!instr->is_host_transfer()) { + const auto& hlo_config = ir_emitter_context_->hlo_module().config(); + const int64_t replica_count = hlo_config.replica_count(); + const int64_t partition_count = hlo_config.num_partitions(); + const int64_t memory_space = + instr->shape().IsTuple() + ? instr->shape().tuple_shapes(0).layout().memory_space() + : instr->shape().layout().memory_space(); + const NcclCollectiveThunk::Buffer nccl_buffer = { + /*element_count=*/ShapeUtil::ElementsIn(src->shape()), + /*source_buffer=*/buffer, + /*destination_buffer=*/buffer, + /*source_memory_space=*/memory_space, + /*destination_memory_space=*/memory_space}; + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(), + instr, replica_count, partition_count, nccl_buffer); + CollectivesAsyncEvents& collectives_async_events = + GetCollectivesAsyncEvents(); + int64_t channel_id = instr->channel_id().value(); + if (MayPipelineSendRecvChannel(channel_id)) { + std::pair async_events_key = + GetSendRecvAsyncEventsKey(Thunk::Kind::kNcclSendDone, channel_id); + auto it = collectives_async_events.find(async_events_key); + if (it != collectives_async_events.end()) { + VLOG(0) << "Found async events " << it->second.get(); + thunk->set_async_events(it->second); + } else { + VLOG(0) << "Used Async events create for thunk " + << thunk->async_events().get(); + collectives_async_events.emplace(async_events_key, + thunk->async_events()); + } + } else { + collectives_async_events.try_emplace(instr, thunk->async_events()); } - auto* scatter = Cast(root); - const xla::ScatterDimensionNumbers& xla_scatter_dim = - scatter->scatter_dimension_numbers(); - - ScatterDescriptor desc; - desc.name = llvm_ir::IrName(root); - desc.operand_shape = root->operand(0)->shape(); - desc.scatter_indices_shape = root->operand(1)->shape(); - desc.updates_shape = updates_shape; - desc.dim_numbers = xla_scatter_dim; - desc.unique_indices = root->unique_indices(); - desc.update_computation = root->called_computations()[0]; - desc.output = outputs.back(); - TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen, - scatter_fused_emitter.GetGenerator(*root->operand(1))); - TF_ASSIGN_OR_RETURN(desc.updates_gen, - scatter_fused_emitter.GetGenerator(*root->operand(2))); - desc.get_index_type = [&](int64_t launch_size) { - return GetIndexTypeForKernel(root, launch_size, &b_); - }; - return EmitScatter(desc, launch_dimensions); - }; + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); + } - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel_thunk, - BuildKernelThunkForFusion( - *ir_emitter_context_, kernel_reuse_cache_, fusion, - fusion_op, fused_computation, launch_dimensions, - /*discriminator=*/"scatter", builder_fn, &b_)); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), src->shape(), buffer, + *instr->channel_id(), send_recv_events_, + ConvertFrontendAttributes(instr->frontend_attributes()), + DeviceConstraint(instr))); - FusionEmissionResult result; - result.thunks.push_back(std::move(kernel_thunk)); - return result; + return absl::OkStatus(); } -StatusOr IrEmitterUnnested::EmitCustomFusion( - const HloFusionInstruction* fusion, const CustomFusionConfig& config) { - VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); +absl::Status IrEmitterUnnested::EmitSendDoneThunk( + const HloSendDoneInstruction* instr) { + if (!instr->channel_id().has_value()) + return absl::InternalError("Unknown send done instruction channel id"); - auto* registry = CustomFusionRegistry::Default(); - auto* custom_fusion = registry->Lookup(config.name()); - - // If custom fusion is not found it means that some of the build targets might - // not be statically linked into the binary. - if (custom_fusion == nullptr) { - return absl::InternalError(absl::StrCat( - "Custom fusion ", config.name(), " not found in a default registry.")); + if (!instr->is_host_transfer()) { + return EmitNcclAsyncDone(Thunk::kNcclSendDone, instr); } - // Load custom kernels that can implement a fusion computation. - TF_ASSIGN_OR_RETURN( - std::vector kernels, - custom_fusion->LoadKernels(fusion->fused_instructions_computation())); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(), + send_recv_events_, DeviceConstraint(instr))); - // This should never happen, it means that compilation pipeline created a - // fusion operation that is not supported by a given custom fusion. - if (kernels.empty()) { - return absl::InternalError( - absl::StrCat("Custom fusion ", config.name(), - " returned empty custom kernels for a fused computation")); - } + return absl::OkStatus(); +} - // TODO(ezhulenev): Add support for auto tuning to select the best kernel. - if (kernels.size() != 1) { - return absl::InternalError("Expected exactly one custom kernel"); +absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) { + if (!instr->channel_id().has_value()) + return absl::InternalError("Unknown recv instruction channel id"); + TF_RET_CHECK(instr->shape().IsTuple()); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, + GetAllocationSliceForHlo(instr, {0})); + + if (!instr->is_host_transfer()) { + const auto& hlo_config = ir_emitter_context_->hlo_module().config(); + const int64_t replica_count = hlo_config.replica_count(); + const int64_t partition_count = hlo_config.num_partitions(); + + const int64_t memory_space = + instr->shape().IsTuple() + ? instr->shape().tuple_shapes(0).layout().memory_space() + : instr->shape().layout().memory_space(); + + const NcclCollectiveThunk::Buffer nccl_buffer = { + /*element_count=*/ShapeUtil::ElementsIn(instr->shape().tuple_shapes(0)), + /*source_buffer=*/buffer, + /*destination_buffer=*/buffer, + /*source_memory_space=*/memory_space, + /*destination_memory_space=*/memory_space}; + auto thunk = std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), NcclApi::Default(), + instr, replica_count, partition_count, nccl_buffer); + CollectivesAsyncEvents& collectives_async_events = + GetCollectivesAsyncEvents(); + int64_t channel_id = instr->channel_id().value(); + if (MayPipelineSendRecvChannel(channel_id)) { + std::pair async_events_key = + GetSendRecvAsyncEventsKey(Thunk::Kind::kNcclRecvDone, channel_id); + auto it = collectives_async_events.find(async_events_key); + + if (it != GetCollectivesAsyncEvents().end()) { + thunk->set_async_events(it->second); + } else { + collectives_async_events.emplace(async_events_key, + thunk->async_events()); + } + } else { + collectives_async_events.try_emplace(instr, thunk->async_events()); + } + + AddThunkToThunkSequence(std::move(thunk)); + return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( - auto thunk, BuildCustomKernelThunkForFusion(*ir_emitter_context_, fusion, - std::move(kernels[0]))); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + instr->shape().tuple_shapes()[0], buffer, *instr->channel_id(), + send_recv_events_, + ConvertFrontendAttributes(instr->frontend_attributes()), + DeviceConstraint(instr))); - FusionEmissionResult result; - result.thunks.push_back(std::move(thunk)); - return result; + return absl::OkStatus(); } -Status IrEmitterUnnested::EmitOp( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo) { - if (mlir::isa(op)) { - return OkStatus(); - } +absl::Status IrEmitterUnnested::EmitRecvDoneThunk( + const HloRecvDoneInstruction* instr) { + if (!instr->channel_id().has_value()) + return absl::InternalError("Unknown recv done instruction channel id"); - if (mlir::isa(op)) { - const HloConstantInstruction* hlo_const_instr = - DynCast(hlo_for_lmhlo.at(op)); - TF_RET_CHECK(hlo_const_instr); - return EmitConstant(op, hlo_const_instr->literal()); + if (!instr->is_host_transfer()) { + return EmitNcclAsyncDone(Thunk::kNcclRecvDone, instr); } - if (auto call = mlir::dyn_cast(op)) { - if (call.getCallTargetName() == "PadToStatic") { - return EmitPadToStatic(op); - } - if (call.getCallTargetName() == "SliceToDynamic") { - return EmitSliceToDynamic(op); - } - const llvm::StringRef call_target = call.getCallTargetName(); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (absl::string_view(call_target.data(), call_target.size()) == - kTriangularSolveCallTarget) { - return EmitTriangularSolveCustomCall(op); - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(), + send_recv_events_, DeviceConstraint(instr))); - return EmitCustomCallThunk(op); - } + return absl::OkStatus(); +} - if (mlir::isa(op)) { - return EmitGemmThunk(op); - } +absl::Status IrEmitterUnnested::EmitHloInstruction( + const HloInstruction* instr) { + switch (instr->opcode()) { + case HloOpcode::kAllGatherDone: + return EmitNcclAsyncDone(Thunk::kNcclAllGatherDone, instr); + case HloOpcode::kAllGatherStart: { + auto* all_gather = Cast(instr); + return EmitNcclThunk( + Thunk::kNcclAllGatherStart, all_gather, all_gather, + all_gather->use_global_device_ids()); + } + + case HloOpcode::kAllReduceDone: + return EmitNcclAsyncDone(Thunk::kNcclAllReduceDone, instr); + case HloOpcode::kAllReduceStart: { + auto* all_reduce = Cast(instr); + return EmitNcclThunk( + Thunk::kNcclAllReduceStart, all_reduce, all_reduce, + all_reduce->use_global_device_ids()); + } + case HloOpcode::kAsyncDone: { + const HloInstruction* wrapped = instr->async_wrapped_instruction(); + switch (wrapped->opcode()) { + case HloOpcode::kReduceScatter: + return EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr); + case HloOpcode::kAllToAll: + return EmitNcclAsyncDone(Thunk::kNcclAllToAllDone, instr); + case HloOpcode::kCollectiveBroadcast: + return EmitNcclAsyncDone(Thunk::kNcclCollectiveBroadcastDone, instr); + default: { + if (wrapped->has_backend_config()) { + TF_ASSIGN_OR_RETURN( + xla::gpu::GpuBackendConfig gpu_config, + wrapped->backend_config()); + if (gpu_config.operation_queue_id() != 0) { + // If there an async-done instruction that wraps an instruction + // that runs on a non-default stream, then we will + // just emit syncOnStreamThunk(). + return EmitWaitForStreamsThunk(instr, gpu_config, + /*is_async_done=*/true); + } + } + + return Internal("Unsupported async done wrapped instruction: %s", + HloOpcodeString(wrapped->opcode())); + } + } + } + case HloOpcode::kAsyncStart: { + const HloInstruction* wrapped = instr->async_wrapped_instruction(); + switch (wrapped->opcode()) { + case HloOpcode::kReduceScatter: { + auto* reduce_scatter = Cast(wrapped); + return EmitNcclThunk( + Thunk::kNcclReduceScatter, instr, reduce_scatter, + reduce_scatter->use_global_device_ids()); + } + case HloOpcode::kAllToAll: { + auto* all_to_all = Cast(wrapped); + return EmitNcclThunk( + Thunk::kNcclAllToAll, instr, all_to_all, std::nullopt); + } + case HloOpcode::kCollectiveBroadcast: { + auto* collective_broadcast = + Cast(wrapped); + return EmitNcclThunk( + Thunk::kNcclCollectiveBroadcast, instr, collective_broadcast, + std::nullopt); + } + default: { + if (wrapped->has_backend_config()) { + TF_ASSIGN_OR_RETURN( + xla::gpu::GpuBackendConfig gpu_config, + wrapped->backend_config()); + if (gpu_config.operation_queue_id() != 0) { + // If there an async instruction that wraps an instruction + // that runs on a non-default stream, then we will + // emit syncOnStreamThunk(source=execution_stream, + // wait_on=main_compute_stream) + // then the thunk of wrapped instruction. + TF_RETURN_IF_ERROR( + EmitWaitForStreamsThunk(instr, gpu_config, + /*is_async_done=*/false)); + return EmitHloInstruction(wrapped); + } + } + return Internal("Unsupported async start wrapped instruction: %s", + HloOpcodeString(wrapped->opcode())); + } + } + } + case HloOpcode::kCall: + return EmitCommandBufferThunk(instr); + case HloOpcode::kCollectivePermuteDone: + return EmitNcclAsyncDone(Thunk::kNcclCollectivePermuteDone, instr); + case HloOpcode::kCollectivePermuteStart: + return EmitCollectivePermute( + Cast(instr)); + case HloOpcode::kConditional: + return EmitConditional(instr); + case HloOpcode::kConstant: + return EmitConstant(Cast(instr)); + case HloOpcode::kCustomCall: { + auto* custom_call = Cast(instr); + if (IsLegacyCublasMatmul(*instr)) { + return EmitGemmThunk(custom_call); + } #if GOOGLE_CUDA || TF_HIPBLASLT - if (mlir::isa(op)) { - return EmitCublasLtMatmulThunk(op); - } + if (IsCublasLtMatmul(*instr)) { + return EmitCublasLtMatmulThunk(custom_call); + } + if (IsCublasLtMatmulF8(*instr)) { + return EmitCublasLtMatmulThunkF8(custom_call); + } #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA - if (mlir::isa(op)) { - return EmitCublasLtMatmulThunkF8(op); - } - if (mlir::isa(op)) { - return EmitConvolutionReorderThunk(op); - } - if (mlir::isa(op)) { - return EmitNormThunk(op); - } - if (mlir::isa(op)) { - return EmitFusedMHAThunk(op); - } - if (mlir::isa(op)) { - return EmitFusedMHABackwardThunk(op); - } + if (IsCudnnConvolutionReorder(*instr)) { + return EmitConvolutionReorderThunk(custom_call); + } + if (IsCustomCallToDnnNorm(*instr)) { + return EmitNormThunk(custom_call); + } + if (IsFwdCustomCallTofMHA(*instr)) { + return EmitFusedMHAThunk(custom_call); + } + if (IsBwdCustomCallTofMHA(*instr)) { + return EmitFusedMHABackwardThunk(custom_call); + } + if (IsFwdCustomCallToFlashAttn(*instr)) { + return EmitFlashAttnFwdThunk(custom_call); + } + if (IsBwdCustomCallToFlashAttn(*instr)) { + return EmitFlashAttnBwdThunk(custom_call); + } #endif // GOOGLE_CUDA - - if (mlir::isa(op)) { - return EmitConvolutionThunk(op); - } - + if (IsCustomCallToTopK(*instr)) { + return EmitTopKCustomCall(custom_call); + } + if (IsCustomCallToDnnConvolution(*instr)) { + return EmitConvolutionThunk(custom_call); + } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (mlir::isa(op)) { - return EmitCubDeviceRadixSort(op); - } - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - return EmitCholeskyThunk(hlo_for_lmhlo.at(op)); - } else { - return EmitCholeskyThunk(op); - } - } + if (IsCustomCallToCusolver(*instr)) { + return EmitCholeskyThunk(instr); + } + if (IsTriangularSolve(*instr)) { + return EmitTriangularSolveCustomCall(instr); + } + if (IsCubDeviceRadixSort(*instr)) { + return EmitCubDeviceRadixSort(custom_call); + } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - if (mlir::isa(op)) { - return EmitFftThunk(op); - } - - if (mlir::isa(op)) { - return InternalError( - "TriangularSolve is implemented as a custom-call; we do not expect to " - "lower a true HLO TriangularSolve op."); - } - - if (mlir::isa(op)) { - if (ir_emitter_context_->emit_ir_from_hlo()) { - const HloFusionInstruction* instr = - Cast(hlo_for_lmhlo.at(op)); - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); - const se::DeviceDescription& device_info = - ir_emitter_context_->gpu_device_info(); - TF_ASSIGN_OR_RETURN(auto fusion_analysis, - HloFusionAnalysis::Create(instr, &device_info)); - // TODO(anlunx): Add support for emitting specialized kLoops. - if (!IsSpecializedLoopFusion(op, ir_emitter_context_->allocations(), - fusion_analysis)) { - return EmitFusion(instr, fusion_analysis); + if (custom_call->custom_call_target() == "PadToStatic") { + return EmitPadToStatic(custom_call); + } + if (instr->custom_call_target() == "SliceToDynamic") { + return EmitSliceToDynamic(custom_call); + } + if (instr->custom_call_target() == "__gpu$xla.gpu.triton") { + return EmitTritonCustomCall(custom_call); } + return EmitCustomCallThunk(custom_call); } - - return EmitFusion(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - return EmitSelectAndScatter(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - return EmitRngGetAndUpdateState(op); - } - - if (mlir::isa(op)) { - return EmitSort(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - return EmitReplicaOrPartitionId( - op); - } - - if (mlir::isa(op)) { - return EmitReplicaOrPartitionId(op); - } - - if (mlir::isa(op)) { - return EmitCollectivePermute(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclCollectivePermuteDone, op); - } - - if (mlir::isa(op)) { - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclAllGatherDone, op); - } - - if (mlir::isa(op)) { - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclAllReduceDone, op); - } - - if (mlir::isa(op)) { - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclReduceScatterDone, op); - } - - if (mlir::isa(op)) { - return EmitNcclThunk(op); - } - - if (mlir::isa(op)) { - return EmitNcclAsyncDone( - Thunk::kNcclAllToAllDone, op); - } - - if (mlir::isa(op)) { - return EmitInfeed(op); - } - - if (mlir::isa(op)) { - return EmitOutfeed(op); - } - - if (mlir::isa(op)) { - return EmitConditional(op, hlo_for_lmhlo); - } - - if (mlir::isa(op)) { - return EmitWhile(op, hlo_for_lmhlo); - } - - // Remaining arith.constant ops are the gpu.launch_func dimensions as a result - // of inlining the fusion region after lowering. They can safely be skipped - // because constants have no side effects. - if (mlir::isa(op)) { - return OkStatus(); - } - - if (mlir::isa(op)) { - return EmitCommandBufferThunk(hlo_for_lmhlo.at(op)); - } - - // Point to point communication operations are only implemented as XLA - // GPU runtime custom calls. - bool is_gpu_runtime = ir_emitter_context_->debug_options() - .xla_gpu_enable_xla_runtime_executable(); - if (is_gpu_runtime && - mlir::isa(op)) { - return EmitUnreachable(op, - "Point-to-point communication operations are not " - "implemented as thunks"); - } - - return InternalError("Unrecognized op: %s", llvm_ir::DumpToString(op)); -} - -Status IrEmitterUnnested::EmitLmhloRegion( - mlir::Region* region, - const absl::flat_hash_map& - hlo_for_lmhlo) { - for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) { - TF_RETURN_IF_ERROR(EmitOp(&op, hlo_for_lmhlo)); - } - return OkStatus(); -} - -Status IrEmitterUnnested::EmitHloInstruction(const HloInstruction* instr) { - // TODO(anlunx): Support other instruction opcodes. - switch (instr->opcode()) { case HloOpcode::kFusion: { auto* fusion = Cast(instr); - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); const se::DeviceDescription& device_info = ir_emitter_context_->gpu_device_info(); - TF_ASSIGN_OR_RETURN(auto fusion_analysis, - HloFusionAnalysis::Create(fusion, &device_info)); - TF_RETURN_IF_ERROR(EmitFusion(fusion, fusion_analysis)); - return OkStatus(); - } + auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_info); + return EmitFusion(fusion, fusion_analysis); + } + case HloOpcode::kInfeed: + return EmitInfeed(Cast(instr)); + case HloOpcode::kOutfeed: + return EmitOutfeed(Cast(instr)); + case HloOpcode::kPartitionId: + return EmitReplicaOrPartitionId(instr); + case HloOpcode::kFft: + return EmitFftThunk(Cast(instr)); + + case HloOpcode::kRecv: + return EmitRecvThunk(Cast(instr)); + case HloOpcode::kRecvDone: + return EmitRecvDoneThunk(Cast(instr)); + + case HloOpcode::kReplicaId: + return EmitReplicaOrPartitionId(instr); + case HloOpcode::kRngGetAndUpdateState: + return EmitRngGetAndUpdateState( + Cast(instr)); + case HloOpcode::kSelectAndScatter: + return EmitSelectAndScatter(Cast(instr)); + + case HloOpcode::kSend: + return EmitSendThunk(Cast(instr)); + case HloOpcode::kSendDone: + return EmitSendDoneThunk(Cast(instr)); + + case HloOpcode::kSort: + return EmitSort(Cast(instr)); + case HloOpcode::kWhile: + return EmitWhile(instr); + case HloOpcode::kCopyStart: + return EmitCopyStartThunk(Cast(instr)); + + // HLO module is already scheduled, so instructions for ordering are noops. + case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: // We don't need to emit thunks for these operations because their semantics // are encoded by buffers. + case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kParameter: - case HloOpcode::kTuple: { - return OkStatus(); - } + case HloOpcode::kTuple: + case HloOpcode::kCopyDone: + return absl::OkStatus(); default: - return InternalError("Unsupported instruction opcode"); + return Internal("Unsupported instruction opcode: %s", + HloOpcodeString(instr->opcode())); } - return InternalError("Unhandled HLO instruction"); + return Internal("Unhandled HLO instruction"); } -Status IrEmitterUnnested::EmitHloComputation( +absl::Status IrEmitterUnnested::EmitHloComputation( const HloComputation* computation) { - ThunkSequence thunk_sequence; - for (const HloInstruction* instr : computation->instructions()) { + const HloSchedule& schedule = computation->parent()->schedule(); + if (!schedule.is_computation_scheduled(computation)) + return Internal("Sequence not found for computation: %s", + computation->name()); + + const HloInstructionSequence& sequence = schedule.sequence(computation); + for (HloInstruction* instr : sequence.instructions()) { TF_RETURN_IF_ERROR(EmitHloInstruction(instr)); } - return OkStatus(); -} - -void IrEmitterUnnested::GetDependentDialects(mlir::DialectRegistry& registry) { - registry.insert(); - mlir::registerBuiltinDialectTranslation(registry); - mlir::registerLLVMDialectTranslation(registry); - mlir::registerNVVMDialectTranslation(registry); - mlir::registerROCDLDialectTranslation(registry); - mlir::func::registerAllExtensions(registry); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/ir_emitter_unnested.h b/xla/service/gpu/ir_emitter_unnested.h index 6e9fa5e6f7418..fc8b41bf69268 100644 --- a/xla/service/gpu/ir_emitter_unnested.h +++ b/xla/service/gpu/ir_emitter_unnested.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #include +#include #include #include #include @@ -24,8 +25,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "mlir/IR/Value.h" // from @llvm-project #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" @@ -34,18 +38,24 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter.h" -#include "xla/service/gpu/kernel_mapping_scheme.h" -#include "xla/service/gpu/kernel_reuse_cache.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/runtime/send_recv_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "tsl/platform/errors.h" + +#if TENSORFLOW_USE_ROCM +// for TF_HIPBLASLT +#include "rocm/rocm_config.h" +#endif namespace xla { namespace gpu { @@ -105,112 +115,95 @@ class IrEmitterUnnested : public IrEmitter { return std::make_unique(std::move(thunk_sequence_)); } - // Emits code for the given LMHLO region. + // Emits code for the given HLO computation. // // Also populates related information to 'ir_emitter_context_' for // large-constant initializations. Large constants don't get initializers in // the generated code and so must be initialized by XLA. The value of these // constants will be stored in 'content'. Constants with initializers in the // generated code will have empty 'content'. - Status EmitLmhloRegion( - mlir::Region* region, - const absl::flat_hash_map& - hlo_for_lmhlo); - - // Emits code for the given HLO computation. Right now it is only used to emit - // thunks for constructing command buffer. The plan is to replace - // EmitLmhloRegion by this function altogether, after we support emitting - // all instructions from HLO. - Status EmitHloComputation(const HloComputation* computation); - - static void GetDependentDialects(mlir::DialectRegistry& registry); + absl::Status EmitHloComputation(const HloComputation* computation); private: explicit IrEmitterUnnested(IrEmitterContext* ir_emitter_context); - Status EmitUnreachable(mlir::Operation* op, std::string error_message); - - Status EmitCommandBufferThunk(const HloInstruction* instr); + absl::Status EmitCommandBufferThunk(const HloInstruction* instr); // IrEmitterUnnested handles the following instructions differently from // IrEmitter. It also mixes in some special handling for custom kernels // via the ThunkEmitter. - Status EmitConstant(mlir::Operation* op, const Literal& literal); - - Status EmitConditional( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); - Status EmitConvolutionThunk(mlir::Operation* op); - Status EmitGemmThunk(mlir::Operation* op); + absl::Status EmitConstant(const HloConstantInstruction* instr); + + absl::Status EmitConditional(const HloInstruction* instr); + absl::Status EmitConvolutionThunk(const HloCustomCallInstruction* instr); + absl::Status EmitGemmThunk(const HloCustomCallInstruction* instr); #if GOOGLE_CUDA || TF_HIPBLASLT - Status EmitCublasLtMatmulThunk(mlir::Operation* op); + absl::Status EmitCublasLtMatmulThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCublasLtMatmulThunkF8(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA - Status EmitCublasLtMatmulThunkF8(mlir::Operation* op); - Status EmitConvolutionReorderThunk(mlir::Operation* op); - Status EmitNormThunk(mlir::Operation* op); - StatusOr EmitTritonFusion( - const HloFusionAnalysis& hlo_fusion_analysis, - const HloFusionInstruction* fusion, mlir::Operation* op); - Status EmitFusedMHAThunk(mlir::Operation* op); - Status EmitFusedMHABackwardThunk(mlir::Operation* op); + absl::Status EmitConvolutionReorderThunk( + const HloCustomCallInstruction* instr); + absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFlashAttnFwdThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFlashAttnBwdThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - Status EmitCubDeviceRadixSort(mlir::Operation* op); - Status EmitCholeskyThunk(mlir::Operation* op); - Status EmitCholeskyThunk(const HloInstruction* instr); + absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); + absl::Status EmitCholeskyThunk(const HloInstruction* instr); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - Status EmitCustomCallThunk(mlir::Operation* op); - Status EmitFftThunk(mlir::Operation* op); - StatusOr GetFusionEmissionResult( - const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis); - Status EmitFusion( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); - Status EmitFusion(const HloFusionInstruction* instr, - HloFusionAnalysis& fusion_analysis); - Status EmitSelectAndScatter( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); - Status EmitWhile( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); - Status EmitInfeed(mlir::Operation* op); - Status EmitOutfeed(mlir::Operation* op); - Status EmitRngGetAndUpdateState(mlir::Operation* op); - Status EmitSort( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); + absl::Status EmitCustomCallThunk(const HloCustomCallInstruction* instr); + absl::Status EmitFftThunk(const HloFftInstruction* instr); + absl::Status EmitFusion(const HloFusionInstruction* instr, + HloFusionAnalysis& fusion_analysis); + absl::Status EmitSelectAndScatter( + const HloSelectAndScatterInstruction* instr); + absl::Status EmitWhile(const HloInstruction* instr); + absl::Status EmitInfeed(const HloInfeedInstruction* instr); + absl::Status EmitOutfeed(const HloOutfeedInstruction* instr); + absl::Status EmitRngGetAndUpdateState( + const HloRngGetAndUpdateStateInstruction* instr); + + absl::Status EmitSort(const HloSortInstruction* sort); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - Status EmitTriangularSolveCustomCall(mlir::Operation* op); + absl::Status EmitTriangularSolveCustomCall(const HloInstruction* instr); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + absl::Status EmitTopKCustomCall(const HloCustomCallInstruction* instr); + absl::Status EmitTritonCustomCall(const HloCustomCallInstruction* instr); - template - Status EmitNcclThunk(mlir::Operation* op); - template - Status EmitNcclAsyncDone(Thunk::Kind kind, mlir::Operation* op); + absl::Status EmitSendThunk(const HloSendInstruction* instr); + absl::Status EmitSendDoneThunk(const HloSendDoneInstruction* instr); - template - Status EmitReplicaOrPartitionId(mlir::Operation* op); + absl::Status EmitRecvThunk(const HloRecvInstruction* instr); + absl::Status EmitRecvDoneThunk(const HloRecvDoneInstruction* instr); - template - Status EmitCollectivePermute(mlir::Operation* op); + template + absl::Status EmitNcclThunk(Thunk::Kind kind, + const HloInstruction* async_start, + const HloInstType* inst, + std::optional use_global_device_ids); - Status EmitOp( - mlir::Operation* op, - const absl::flat_hash_map& - hlo_for_lmhlo); + absl::Status EmitNcclAsyncDone(Thunk::Kind kind, const HloInstruction* instr); - Status EmitHloInstruction(const HloInstruction* instr); + absl::Status EmitWaitForStreamsThunk(const HloInstruction* inst, + GpuBackendConfig& gpu_config, + bool is_async_done); + template + absl::Status EmitReplicaOrPartitionId(const HloInstruction* instr); - static Thunk::ThunkInfo GetThunkInfo(mlir::Operation* op); + absl::Status EmitCollectiveBroadcast( + const HloCollectiveBroadcastInstruction* instr); - Status EmitTargetElementLoop( + absl::Status EmitCollectivePermute( + const HloCollectivePermuteInstruction* instr); + + absl::Status EmitCopyStartThunk(const HloCopyStartInstruction* instr); + + absl::Status EmitHloInstruction(const HloInstruction* instr); + + absl::Status EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) override; @@ -219,6 +212,15 @@ class IrEmitterUnnested : public IrEmitter { thunk_sequence_.emplace_back(std::move(thunk)); } + absl::Status AddThunksToThunkSequence( + absl::StatusOr result) { + TF_RETURN_IF_ERROR(result.status()); + for (auto& thunk : result->thunks) { + AddThunkToThunkSequence(std::move(thunk)); + } + return absl::OkStatus(); + } + // Load data from potentially unaligned address. If address is offset by // `alignment_bytes`, data is read in the unit of `alignment_bytes` to avoid // memory read misalignment in CUDA; otherwise, the entire data are loaded @@ -287,7 +289,7 @@ class IrEmitterUnnested : public IrEmitter { // return; // } // ``` - Status EmitPadToStatic(mlir::Operation* op); + absl::Status EmitPadToStatic(const HloCustomCallInstruction* instr); // Input = {dynamic array(with dynamic dimension meta data at the end)} // Output = {static array, dynamic_dim0, dynamic_dim1} @@ -333,119 +335,55 @@ class IrEmitterUnnested : public IrEmitter { // return; // } // ``` - Status EmitSliceToDynamic(mlir::Operation* op); - - StatusOr GetAllocationSlice(mlir::Value v); - StatusOr> GetAllocationSlices( - mlir::OperandRange operands); + absl::Status EmitSliceToDynamic(const HloCustomCallInstruction* instr); int64_t ByteSizeOf(const Shape& shape) const { return llvm_ir::ByteSizeOf( shape, ir_emitter_context_->llvm_module()->getDataLayout()); } - // Structure describing a scatter operation for IR emission. - // TODO(jurahul): Migrate element generators to use MLIR. - // Migrate update_computation to be an MLIR Region. - struct ScatterDescriptor { - std::string name; - Shape operand_shape; - Shape scatter_indices_shape; - Shape updates_shape; - ScatterDimensionNumbers dim_numbers; - bool unique_indices; - const HloComputation* update_computation; - llvm_ir::IrArray output; - llvm_ir::ElementGenerator scatter_indices_gen; - llvm_ir::ElementGenerator updates_gen; - std::function get_index_type; - }; - - // Emits code for an in-place scatter using the provided scatter operation - // description. - Status EmitScatter(const ScatterDescriptor& desc, - const LaunchDimensions& launch_dimensions); - - StatusOr EmitScatter( - const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, - HloFusionAnalysis& fusion_analysis); - - // Emits kernel thunk for a custom fusion implemented with hand written custom - // device kernels. - StatusOr EmitCustomFusion( - const HloFusionInstruction* fusion, const CustomFusionConfig& config); - - // Builds a kernel thunk for a non-fusion operation, without reuse. - // - // All input and output tensors of `op` are passed to the kernel. - // - // TODO(tdanyluk): Consider also reusing non-fusion kernels. - StatusOr /*inputs*/, - std::vector /*outputs*/>> - BuildKernelThunkForNonFusionOp(mlir::Operation* op, - const LaunchDimensions& launch_dimensions); - - // Builds a kernel thunk for a non-fusion operation, without reuse. - // - // Only the tensors specified in `needed_operands` are passed to the kernel. - // - // TODO(tdanyluk): Consider also reusing non-fusion kernels. - StatusOr /*inputs*/, - std::vector /*outputs*/>> - BuildKernelThunkForNonFusionOp(mlir::Operation* op, - mlir::ValueRange needed_operands, - const LaunchDimensions& launch_dimensions); + absl::StatusOr /*inputs*/, + std::vector /*outputs*/>> + BuildKernelThunkForNonFusionOp( + const HloInstruction* hlo, + absl::Span needed_operands, + const LaunchDimensions& launch_dimensions); - Status BuildInitializerThunk(mlir::Operation* op, const HloInstruction* instr, - const HloInstruction* init_value, - mlir::Value init_value_mlir, mlir::Value dest); + absl::Status BuildInitializerThunk(const HloInstruction* instr, + const HloInstruction* init_value); // Returns a WhileThunk that invokes thunk sequences for 'condition' and - // 'body' sub-computations of while instruction 'hlo'. - StatusOr> BuildWhileThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - const absl::flat_hash_map& - hlo_for_lmhlo); - - // Returns a ForThunk which executes 'loop_limit' invocations of a thunk - // sequence from the 'body' sub-computation of the while instruction 'hlo'. - StatusOr> BuildForThunk( - mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info, - int64_t loop_limit, - const absl::flat_hash_map& - hlo_for_lmhlo); + // 'body' sub-computations of while instruction. + absl::StatusOr> BuildWhileThunk( + const HloInstruction* instr, const Thunk::ThunkInfo& thunk_info, + std::optional trip_count); // Returns a ConditionalThunk which executes the thunk sequence for the // 'branch_computation' corresponding to the predicate/branch_index of the // given conditional instruction. - StatusOr> BuildConditionalThunk( + absl::StatusOr> BuildConditionalThunk( const HloInstruction* conditional); - Status AssertNonDeterminismIsOkay(const std::string& op_name); + absl::Status AssertNonDeterminismIsOkay(const std::string& op_name); + + absl::StatusOr GetAllocationSliceForHlo( + const HloInstruction* instr, const ShapeIndex& index = {}) const; - StatusOr GetAllocationSliceForHlo( - const HloInstruction* instr, const ShapeIndex& index) const; + CollectivesAsyncEvents& GetCollectivesAsyncEvents() { + return ir_emitter_context_->collectives_async_events(); + } // The thunk sequence this IrEmitter generates for the input computation. ThunkSequence thunk_sequence_; - // Maps async start ops to their executors so done can access the thunk. - // Executor may be null if the start op is degenerate (so not emitted). - absl::flat_hash_map - async_executors_; - - // Begin optional members for XLA HLO -> LMHLO: - absl::flat_hash_map> - scratch_nested_computations_; - // End optional members for XLA HLO -> LMHLO. + // Container for async send/recv events shared by send/recv thunks. + std::shared_ptr send_recv_events_; // Returns the ShapedSlices for the given operands. - StatusOr> GetShapedSlices( + absl::StatusOr> GetShapedSlices( mlir::Operation::operand_range operands); GpuElementalIrEmitter elemental_emitter_; - - KernelReuseCache kernel_reuse_cache_; }; } // namespace gpu diff --git a/xla/service/gpu/kernel_arguments.cc b/xla/service/gpu/kernel_arguments.cc index 5fcea6d418a28..ebdfbd7946cad 100644 --- a/xla/service/gpu/kernel_arguments.cc +++ b/xla/service/gpu/kernel_arguments.cc @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,85 +14,57 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/kernel_arguments.h" +#include +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/gpu_constants.h" -#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -StatusOr KernelArgument::Create( - absl::Span allocations, mlir::Value value, - bool is_written) { - TF_ASSIGN_OR_RETURN( - auto slice, xla::gpu::GetAllocationSlice(value, allocations, nullptr)); - return KernelArgument(value, GetShape(value), slice, is_written); -} - -StatusOr KernelArguments::Create( - absl::Span allocations, - mlir::lmhlo::FusionOp fusion) { - auto operands = GetHloOperands(fusion); - auto outputs = GetHloOutputs(fusion); - std::vector kernel_arguments; - kernel_arguments.reserve(operands.size() + outputs.size()); - - for (auto value : operands) { - TF_ASSIGN_OR_RETURN(auto arg, KernelArgument::Create(allocations, value, - /*is_written=*/false)); - kernel_arguments.emplace_back(std::move(arg)); - } - for (auto value : outputs) { - TF_ASSIGN_OR_RETURN(auto arg, KernelArgument::Create(allocations, value, - /*is_written=*/true)); - kernel_arguments.emplace_back(std::move(arg)); - } - - return KernelArguments{std::move(kernel_arguments)}; -} - -StatusOr KernelArguments::Create( +absl::StatusOr KernelArguments::Create( const BufferAssignment& buffer_assignment, const HloFusionInstruction* fusion) { std::vector kernel_arguments; for (const HloInstruction* operand : fusion->operands()) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, buffer_assignment.GetUniqueSlice(operand, {})); - kernel_arguments.emplace_back(KernelArgument( - /*value=*/nullptr, operand->shape(), slice, /*written=*/false)); + kernel_arguments.emplace_back( + KernelArgument(operand->shape(), slice, /*written=*/false)); } TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( fusion->shape(), [&](const Shape& subshape, const ShapeIndex& index) { if (!subshape.IsArray()) { - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, buffer_assignment.GetUniqueSlice(fusion, index)); - kernel_arguments.emplace_back(KernelArgument( - /*value=*/nullptr, subshape, slice, /*written=*/true)); - return OkStatus(); + kernel_arguments.emplace_back( + KernelArgument(subshape, slice, /*written=*/true)); + return absl::OkStatus(); })); - return KernelArguments{std::move(kernel_arguments)}; + return KernelArguments{std::move(kernel_arguments), /*dedup=*/true}; } std::vector KernelArguments::ProcessArguments( - std::vector kernel_arguments) { + std::vector kernel_arguments, bool dedup) { absl::flat_hash_set buffers_written; for (const KernelArgument& kernel_argument : kernel_arguments) { if (kernel_argument.written()) { @@ -102,19 +74,22 @@ std::vector KernelArguments::ProcessArguments( absl::flat_hash_map> first_indices_for_slices; + int next_llvm_arg_index = 0; for (int i = 0; i < static_cast(kernel_arguments.size()); ++i) { KernelArgument& kernel_argument = kernel_arguments[i]; auto& first_index = first_indices_for_slices[kernel_argument.slice_]; - if (first_index) { + if (dedup && first_index) { const KernelArgument& same = kernel_arguments[*first_index]; kernel_argument.first_with_same_slice_ = first_index; kernel_argument.alignment_ = same.alignment_; kernel_argument.aliased_ = same.aliased_; kernel_argument.written_ = same.written_; + kernel_argument.llvm_arg_index_ = same.llvm_arg_index_; continue; } else { first_index = i; + kernel_argument.llvm_arg_index_ = next_llvm_arg_index++; } const BufferAllocation* alloc = kernel_argument.slice().allocation(); @@ -150,18 +125,33 @@ std::vector KernelArguments::ProcessArguments( return kernel_arguments; } -StatusOr KernelArguments::Create( - absl::Span allocations, - mlir::Operation* non_fusion_op, mlir::ValueRange needed_operands) { +absl::StatusOr KernelArguments::Create( + const BufferAssignment& buffer_assignment, + const HloInstruction* non_fusion_hlo, + absl::Span needed_operands, bool dedup) { std::vector kernel_arguments; - kernel_arguments.reserve(needed_operands.size()); - for (const auto& [i, value] : llvm::enumerate(needed_operands)) { - bool written = WritesMlirBuffer(non_fusion_op, value); - TF_ASSIGN_OR_RETURN(auto arg, - KernelArgument::Create(allocations, value, written)); - kernel_arguments.emplace_back(std::move(arg)); + for (const HloInstruction* operand : needed_operands) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + buffer_assignment.GetUniqueSlice(operand, {})); + kernel_arguments.emplace_back( + KernelArgument(operand->shape(), slice, /*written=*/false)); } - return KernelArguments{std::move(kernel_arguments)}; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + non_fusion_hlo->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) return absl::OkStatus(); + + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + buffer_assignment.GetUniqueSlice(non_fusion_hlo, index)); + + kernel_arguments.emplace_back( + KernelArgument(subshape, slice, /*written=*/true)); + return absl::OkStatus(); + })); + + return KernelArguments{std::move(kernel_arguments), dedup}; } } // namespace gpu diff --git a/xla/service/gpu/kernel_arguments.h b/xla/service/gpu/kernel_arguments.h index 12926227edec4..eeff9720ec7fc 100644 --- a/xla/service/gpu/kernel_arguments.h +++ b/xla/service/gpu/kernel_arguments.h @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,17 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_KERNEL_ARGUMENTS_H_ #define XLA_SERVICE_GPU_KERNEL_ARGUMENTS_H_ +#include #include #include #include -#include "mlir/IR/Value.h" // from @llvm-project +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -33,11 +34,6 @@ namespace gpu { // Thread-safe. class KernelArgument { public: - static StatusOr Create( - absl::Span allocations, mlir::Value value, - bool is_written); - - mlir::Value value() const { return value_; } const Shape& shape() const { return shape_; } const BufferAllocation::Slice& slice() const { return slice_; } bool written() const { return written_; } @@ -46,18 +42,18 @@ class KernelArgument { return first_with_same_slice_; } bool aliased() const { return aliased_; } + int llvm_arg_index() const { return llvm_arg_index_; } private: - KernelArgument(mlir::Value value, Shape shape, BufferAllocation::Slice slice, - bool written) - : value_(value), shape_(shape), slice_(slice), written_(written) {} + KernelArgument(Shape shape, BufferAllocation::Slice slice, bool written) + : shape_(shape), slice_(slice), written_(written) {} - mlir::Value value_; Shape shape_; BufferAllocation::Slice slice_; bool aliased_ = true; int64_t alignment_ = 1; bool written_ = true; + int llvm_arg_index_; // Holds the index of the first argument which has the same slice as this, // if this is not the first such argument. std::optional first_with_same_slice_; @@ -67,26 +63,24 @@ class KernelArgument { class KernelArguments { public: - static StatusOr Create( - absl::Span allocations, - mlir::lmhlo::FusionOp fusion); - - static StatusOr Create( + static absl::StatusOr Create( const BufferAssignment& buffer_assignment, const HloFusionInstruction* fusion); - static StatusOr Create( - absl::Span allocations, - mlir::Operation* non_fusion_op, mlir::ValueRange needed_operands); + static absl::StatusOr Create( + const BufferAssignment& buffer_assignment, + const HloInstruction* non_fusion_hlo, + absl::Span needed_operands, + bool dedup = true); const std::vector& args() const { return args_; } private: - explicit KernelArguments(std::vector args) - : args_(ProcessArguments(std::move(args))) {} + explicit KernelArguments(std::vector args, bool dedup = true) + : args_(ProcessArguments(std::move(args), dedup)) {} static std::vector ProcessArguments( - std::vector kernel_arguments); + std::vector kernel_arguments, bool dedup); std::vector args_; }; diff --git a/xla/service/gpu/kernel_mapping_scheme.h b/xla/service/gpu/kernel_mapping_scheme.h deleted file mode 100644 index 92ad74b634c03..0000000000000 --- a/xla/service/gpu/kernel_mapping_scheme.h +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ -#define XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ - -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "llvm/IR/Value.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/llvm_ir/loop_emitter.h" -#include "xla/util.h" - -namespace xla { -namespace gpu { - -// Describes tiling used by the kernel. -// -// Used by reductions and 021 transpose algorithm. Both algorithms operate over -// "logical" 3D views over input arrays, hence tiling and number of threads -// information has only 3 dimensions. -// -// In the presence of virtual threadIdx/blockIdx scaling, all accessors are -// "logical", unless otherwise specified. -class TilingScheme { - public: - enum { DimZ = 0, DimY, DimX, DimTot }; - - enum IndexingOrder { - // Thread reads consecutive elements. - LinearIndexingX, - // Thread reads strided elements while keeping memory coalescing. - StridedIndexingX, - }; - - TilingScheme(Vector3 dims_in_elems, Vector3 tile_sizes, Vector3 num_threads, - IndexingOrder indexing_order, int vector_size, - int scaling_factor, Vector2 tiling_dimensions = Vector2{1, 2}) - : dims_in_elems_(dims_in_elems), - tile_sizes_(tile_sizes), - tiling_dimensions_(tiling_dimensions), - num_threads_(num_threads), - indexing_order_(indexing_order), - vector_size_(vector_size), - thread_id_virtual_scaling_(scaling_factor) { - CHECK_EQ(tile_sizes[2] % vector_size_, 0) - << "tile sizes = " << absl::StrJoin(tile_sizes, ", ") - << "; vector size = " << vector_size_; - } - - static std::string IndexingOrderToString(IndexingOrder order) { - switch (order) { - case LinearIndexingX: - return "linear"; - case StridedIndexingX: - return "strided"; - } - } - - std::string ToString() const { - return absl::StrJoin( - {absl::StrFormat("dims_in_elems = {%s}", - absl::StrJoin(dims_in_elems_, ", ")), - absl::StrFormat("tile_sizes = {%s}", absl::StrJoin(tile_sizes_, ", ")), - absl::StrFormat("num_threads = {%s}", - absl::StrJoin(num_threads_, ", ")), - absl::StrFormat("indexing_order = %s", - IndexingOrderToString(indexing_order_)), - absl::StrFormat("vector_size = %d", vector_size_), - absl::StrFormat("thread_id_virtual_scaling = %d", - thread_id_virtual_scaling_), - absl::StrFormat("tiling_dimensions = {%s}", - absl::StrJoin(tiling_dimensions_, ", "))}, - ", "); - } - - // Number of elements in each dimension (Z/Y/X respectively). - absl::Span GetDimsInElems() const { return dims_in_elems_; } - - Vector3 GetDimsInBlocks() const { - return {GetDimInBlock(0), GetDimInBlock(1), GetDimInBlock(2)}; - } - - // Number of blocks required to "cover" the given dimension. - int64_t GetDimInBlock(int d) const { - return CeilOfRatio(dims_in_elems_[d], GetBlockTileSizeFor(d)); - } - - // Tile size for a given dimensions per thread. - // - // Equals to the number of iterations in the loop each tile will make. - int64_t GetTileSizeFor(int d) const { return tile_sizes_.at(d); } - - // The tiling dimension for dimension 'd' of the shared memory tile. - int64_t GetTilingDimension(int d) const { return tiling_dimensions_.at(d); } - - // Tile size for a given dimension per entire thread block. - int64_t GetBlockTileSizeFor(int d) const { - return num_threads_.at(d) * tile_sizes_.at(d); - } - - // Number of threads in given dimension. - int64_t GetNumThreadsFor(int d) const { return num_threads_.at(d); } - - // Number of logical threads per block. - int64_t GetNumThreadsPerBlock() const { - return GetNumThreadsFor(0) * GetNumThreadsFor(1) * GetNumThreadsFor(2); - } - - // Number of logical blocks. - int64_t GetNumberOfBlocks() const { - return GetDimInBlock(0) * GetDimInBlock(1) * GetDimInBlock(2); - } - - // Number of physical blocks launched (with scaling applied). - int64_t GetNumberOfBlocksPhysical() const { - return CeilOfRatio(GetNumberOfBlocks(), thread_id_virtual_scaling_); - } - - // Number of physical threads per block launched (with scaling applied). - int64_t GetNumThreadsPerBlockPhysical() const { - return GetNumThreadsPerBlock() * thread_id_virtual_scaling_; - } - - IndexingOrder GetIndexingOrder() const { return indexing_order_; } - int GetVectorSize() const { return vector_size_; } - - // Scaling factor for transforming physical threadId to logical. - int GetThreadIdScalingFactor() const { return thread_id_virtual_scaling_; } - - private: - // The number of elements in each dimension. - Vector3 dims_in_elems_; - - // The number of elements for each dimension of a tile. - Vector3 tile_sizes_; - - // The dimensions which are used for the shared memory tile. - Vector2 tiling_dimensions_; - - // Number of threads implicitly assigned to each dimension. - Vector3 num_threads_; - - IndexingOrder indexing_order_; - - // Vector size for dimension X. - int vector_size_; - - // Scaling apply to transform physical threadIdx into logical. - int64_t thread_id_virtual_scaling_ = 1; -}; - -class ReductionCodegenInfo { - public: - using IndexGroups = std::vector>; - - ReductionCodegenInfo(TilingScheme mapping_scheme, int num_partial_results, - bool is_row_reduction, bool is_race_free, - IndexGroups index_groups, - const HloInstruction* first_reduce) - : tiling_scheme_(mapping_scheme), - num_partial_results_(num_partial_results), - is_row_reduction_(is_row_reduction), - is_race_free_(is_race_free), - index_groups_(std::move(index_groups)), - first_reduce_(first_reduce) { - if (!is_row_reduction && num_partial_results > 1) { - CHECK_EQ(num_partial_results, - mapping_scheme.GetTileSizeFor(TilingScheme::DimX)); - } - } - - const TilingScheme& GetTilingScheme() const { return tiling_scheme_; } - const IndexGroups& GetIndexGroups() const { return index_groups_; } - Shape GetReduceOperandShape() const { - return first_reduce_->operand(0)->shape(); - } - - int GetNumPartialResults() const { return num_partial_results_; } - bool IsRowReduction() const { return is_row_reduction_; } - bool IsRaceFree() const { return is_race_free_; } - - private: - TilingScheme tiling_scheme_; - int num_partial_results_; - bool is_row_reduction_; - bool is_race_free_; - IndexGroups index_groups_; - const HloInstruction* first_reduce_; -}; - -} // end namespace gpu -} // end namespace xla - -#endif // XLA_SERVICE_GPU_KERNEL_MAPPING_SCHEME_H_ diff --git a/xla/service/gpu/kernel_reuse_cache.cc b/xla/service/gpu/kernel_reuse_cache.cc index c9ada1369dc7d..2d4a9d3356b67 100644 --- a/xla/service/gpu/kernel_reuse_cache.cc +++ b/xla/service/gpu/kernel_reuse_cache.cc @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,15 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/kernel_arguments.h" #include "xla/util.h" #include "tsl/platform/logging.h" @@ -77,34 +84,36 @@ std::string GetComputationFingerprint( fused_computation->ToString(print_options)); } -std::pair KernelReuseCache::Get( - const HloComputation* fused_computation, - absl::Span kernel_arguments, - absl::string_view discriminator, - const std::function& generator) { - auto ret = GetWithStatus(fused_computation, kernel_arguments, discriminator, - [&]() -> StatusOr { return generator(); }); - return {*ret.first, ret.second}; -} - -std::pair, bool> +std::pair, bool> KernelReuseCache::GetWithStatus( const HloComputation* fused_computation, absl::Span kernel_arguments, absl::string_view discriminator, - const std::function()>& generator) { + const std::function()>& generator) { std::string fingerprint = GetComputationFingerprint( fused_computation, kernel_arguments, discriminator); VLOG(4) << "Fingerprint: "; XLA_VLOG_LINES(4, fingerprint); + return GetWithStatus(std::move(fingerprint), generator); +} - auto& entry = cache_[fingerprint]; - if (entry.kernel_name.empty()) { - auto ret = generator(); - if (ret.ok()) entry = *ret; - return {ret, false}; +std::pair, bool> +KernelReuseCache::GetWithStatus( + std::string fingerprint, + const std::function()>& generator) { + auto it = cache_.find(fingerprint); + if (it != cache_.end()) { + return {&it->second, /*was_cached=*/true}; } - return {{entry}, true}; + + absl::StatusOr entry = generator(); + if (entry.ok()) { + it = + cache_.insert({std::move(fingerprint), std::move(entry.value())}).first; + return {&it->second, /*was_cached=*/false}; + } + + return {entry.status(), /*was_cached=*/false}; } } // namespace gpu diff --git a/xla/service/gpu/kernel_reuse_cache.h b/xla/service/gpu/kernel_reuse_cache.h index 33cddb7d5c15c..ea55d97dc4398 100644 --- a/xla/service/gpu/kernel_reuse_cache.h +++ b/xla/service/gpu/kernel_reuse_cache.h @@ -1,4 +1,4 @@ -/*Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/*Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,20 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_KERNEL_REUSE_CACHE_H_ #define XLA_SERVICE_GPU_KERNEL_REUSE_CACHE_H_ +#include #include +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/stream_executor/launch_dim.h" namespace xla { namespace gpu { @@ -34,22 +40,33 @@ class KernelReuseCache { struct Entry { std::string kernel_name; LaunchDimensions launch_dimensions; - int64_t shmem_bytes; + std::optional cluster_dim; + int64_t shmem_bytes = 0; }; // Retrieves the cache entry for the given computation, or generates it using // the given generator function and stores it in the cache. - std::pair Get( - const HloComputation* fused_computation, - absl::Span kernel_arguments, - absl::string_view discriminator, const std::function& generator); - - // Like `Get`, but for generator functions that can fail. - std::pair, bool /*was_cached*/> GetWithStatus( + // + // The returned pointer is never nullptr. + // + // A non-OK status is returned if the entry is not found and the generator + // failed. + std::pair, bool /*was_cached*/> GetWithStatus( const HloComputation* fused_computation, absl::Span kernel_arguments, absl::string_view discriminator, - const std::function()>& generator); + const std::function()>& generator); + + // Retrieves the cache entry for the given fingerprint, or generates it using + // the given generator function and stores it in the cache. + // + // The returned pointer is never nullptr. + // + // A non-OK status is returned if the entry is not found and the generator + // failed. + std::pair, bool /*was_cached*/> GetWithStatus( + std::string fingerprint, + const std::function()>& generator); private: absl::flat_hash_map cache_; diff --git a/xla/service/gpu/kernel_thunk.cc b/xla/service/gpu/kernel_thunk.cc deleted file mode 100644 index 58dfab7b11320..0000000000000 --- a/xla/service/gpu/kernel_thunk.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/kernel_thunk.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/kernel_arguments.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { -namespace { - -//===----------------------------------------------------------------------===// -// KernelThunk -//===----------------------------------------------------------------------===// - -mlir::Value RemoveTransformingOperations(mlir::Value value) { - mlir::Operation* defining_op = value.getDefiningOp(); - if (auto cast_op = llvm::isa(defining_op)) { - return defining_op->getOperand(0); - } - return value; -} - -} // namespace - -KernelThunk::KernelThunk( - std::variant op, - std::string kernel_name, absl::Span kernel_arguments, - LaunchDimensions launch_dimensions, int64_t shmem_bytes) - : Thunk(Kind::kKernel, std::holds_alternative(op) - ? Thunk::ThunkInfo::WithProfileAnnotation( - std::get(op)) - : Thunk::ThunkInfo::WithProfileAnnotation( - std::get(op))), - kernel_name_(std::move(kernel_name)), - launch_dimensions_(std::move(launch_dimensions)), - shmem_bytes_(shmem_bytes) { - args_.reserve(kernel_arguments.size()); - written_.reserve(kernel_arguments.size()); - for (const auto& kernel_argument : kernel_arguments) { - if (!kernel_argument.first_with_same_slice().has_value()) { - args_.push_back(kernel_argument.slice()); - written_.push_back(kernel_argument.written()); - } - } - - if (std::holds_alternative(op)) { - // Skip populating MLIR values_ if emitting from HLO. - return; - } - - values_.reserve(kernel_arguments.size()); - for (const auto& kernel_argument : kernel_arguments) { - if (!kernel_argument.first_with_same_slice().has_value()) { - values_.push_back(RemoveTransformingOperations(kernel_argument.value())); - } - } -} - -std::string KernelThunk::ToStringExtra(int indent) const { - return absl::StrFormat(", kernel = %s, launch dimensions = %s", kernel_name_, - launch_dimensions_.ToString()); -} - -Status KernelThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - absl::MutexLock lock(&mutex_); - - // Load the kernel into the device if necessary. - // - // We could alternatively do this within ExecuteOnStream, but doing it here - // lets the time spent loading the kernel not count towards our execution - // profiles. - auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - CreateKernel(kernel_name_, args_.size(), src.text, - src.binary, executor, shmem_bytes_)); - - kernel_cache_.emplace(executor, std::move(kernel)); - } - - return OkStatus(); -} - -static void PrintBufferContents( - se::Stream* stream, absl::Span buffer_args) { - int input_idx = 0; - for (const se::DeviceMemoryBase& buf : buffer_args) { - auto host_buffer = std::make_unique(buf.size()); - CHECK(stream->ThenMemcpy(host_buffer.get(), buf, buf.size()).ok()); - CHECK_OK(stream->BlockHostUntilDone()); - - std::string buffer_contents; - for (int i = 0; i < buf.size(); i++) { - absl::StrAppendFormat(&buffer_contents, "%x ", - static_cast(host_buffer[i])); - } - VLOG(100) << "BUF(" << input_idx++ << ") = " << buffer_contents; - } -} - -Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { - // Load the kernel. - se::StreamExecutor* executor = params.stream->parent(); - LaunchDimensions launch_dimensions; - const se::Kernel* kernel = nullptr; - - { - absl::MutexLock lock(&mutex_); - auto it = kernel_cache_.find(executor); - CHECK(it != kernel_cache_.end()) - << "Initialize() not called for StreamExecutor " << executor; - launch_dimensions = launch_dimensions_; - kernel = it->second.get(); - } - - VLOG(3) << "Launching " << kernel->name(); - absl::InlinedVector buffer_args; - for (const BufferAllocation::Slice& arg : args_) { - se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); - VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() - << ": " << buf.opaque() << " (" << buf.size() << "B)"; - buffer_args.push_back(buf); - } - - if (VLOG_IS_ON(100)) { - PrintBufferContents(params.stream, buffer_args); - } - - return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, - params.stream); -} - -//===----------------------------------------------------------------------===// -// CustomKernelThunk -//===----------------------------------------------------------------------===// - -CustomKernelThunk::CustomKernelThunk( - const HloInstruction* instr, CustomKernel custom_kernel, - absl::Span kernel_arguments) - : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), - custom_kernel_(std::move(custom_kernel)) { - args_.reserve(kernel_arguments.size()); - written_.reserve(kernel_arguments.size()); - for (const auto& kernel_argument : kernel_arguments) { - if (!kernel_argument.first_with_same_slice().has_value()) { - args_.push_back(kernel_argument.slice()); - written_.push_back(kernel_argument.written()); - } - } -} - -std::string CustomKernelThunk::ToStringExtra(int indent) const { - return custom_kernel_.ToString(); -} - -Status CustomKernelThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - absl::MutexLock lock(&mutex_); - - auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - auto kernel = std::make_unique(executor); - TF_RETURN_IF_ERROR( - executor->GetKernel(custom_kernel_.kernel_spec(), kernel.get())); - kernel_cache_.emplace(executor, std::move(kernel)); - } - - return OkStatus(); -} - -Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { - se::StreamExecutor* executor = params.stream->parent(); - - const se::Kernel* kernel = [&] { - absl::MutexLock lock(&mutex_); - return kernel_cache_[executor].get(); - }(); - - VLOG(3) << "Launching " << custom_kernel_.ToString() << " as device kernel " - << kernel->name(); - - absl::InlinedVector buffer_args; - for (const BufferAllocation::Slice& arg : args_) { - se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); - VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() - << ": " << buf.opaque() << " (" << buf.size() << "B)"; - buffer_args.push_back(buf); - } - - if (VLOG_IS_ON(100)) { - PrintBufferContents(params.stream, buffer_args); - } - - se::KernelArgsDeviceMemoryArray args(buffer_args, - custom_kernel_.shared_memory_bytes()); - return executor->Launch(params.stream, custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *kernel, args); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/kernel_thunk.h b/xla/service/gpu/kernel_thunk.h deleted file mode 100644 index 7f8ff1331324e..0000000000000 --- a/xla/service/gpu/kernel_thunk.h +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNEL_THUNK_H_ -#define XLA_SERVICE_GPU_KERNEL_THUNK_H_ - -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/kernel_arguments.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" // IWYU pragma: keep - -namespace xla { -namespace gpu { - -class GpuExecutable; - -// TODO(ezhulenev): Unify KernelThunk and CustomKernelThunk as they are very -// similar. XLA:GPU should use more of kernel loading APIs provided by -// StreamExecutor out of the box and less custom kernel loading solutions. -// -// Today KernelThunk is required for lowering to XLA runtime, and -// CustomKernelThunk is only supported for thunk execution. - -//===----------------------------------------------------------------------===// -// KernelThunk -//===----------------------------------------------------------------------===// - -// This class stores everything that StreamExecutor needs for launching a -// kernel. It implements the ExecuteOnStream interface for GpuExecutable to -// invoke the corresponding kernel. -// -// This is thread-compatible. -class KernelThunk : public Thunk { - public: - // Constructs a thunk for the given kernel. - // - // KernelThunk takes args as `BufferAllocation::Slice`s (wrapped in - // `KernelArgument`s). Each slice directly corresponds to an argument or - // output of the computation. Also, the values must correspond to each arg - // directly, not to their base allocation (e.g. they can be the result of an - // `mlir::memref::ViewOp`). - KernelThunk(std::variant op, - std::string kernel_name, - absl::Span kernel_arguments, - LaunchDimensions launch_dimensions, int64_t shmem_bytes); - KernelThunk(const KernelThunk&) = delete; - KernelThunk& operator=(const KernelThunk&) = delete; - ~KernelThunk() override = default; - - std::string ToStringExtra(int indent) const override; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - for (auto& value : values_) { - value = nullptr; - } - } - - const std::vector& arguments() const { - return args_; - } - const std::vector& written() const { return written_; } - - const std::string& kernel_name() const { return kernel_name_; } - const LaunchDimensions& launch_dimensions() const { - return launch_dimensions_; - } - // The shared memory required by the kernel. - int64_t shmem_bytes() const { return shmem_bytes_; } - absl::Span values() const { return values_; } - - private: - // Buffer slices passed to the kernel as arguments. - std::vector args_; - - // args_[i] is written iff (written_[i] == true). - std::vector written_; - - // Entry kernel name for the computation. - const std::string kernel_name_; - - // The thread and block dimension used to launch the kernel. - const LaunchDimensions launch_dimensions_; - - int64_t shmem_bytes_; - - // mlir::Value(s) corresponding to the buffer slice arguments. - std::vector values_; - - // Loaded kernels for each `StreamExecutor`. - mutable absl::Mutex mutex_; - absl::flat_hash_map> - kernel_cache_ ABSL_GUARDED_BY(mutex_); -}; - -//===----------------------------------------------------------------------===// -// CustomKernelThunk -//===----------------------------------------------------------------------===// - -// CustomKernelThunk loads and executes kernels defined by a custom kernel -// (which in practice means hand written CUDA C++ kernel), instead of a kernel -// compiled by XLA and loaded from an executable source. -class CustomKernelThunk : public Thunk { - public: - CustomKernelThunk(const HloInstruction* instr, CustomKernel custom_kernel, - absl::Span kernel_arguments); - - std::string ToStringExtra(int indent) const override; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - // Buffer slices passed to the kernel as arguments. - std::vector args_; - - // args_[i] is written iff (written_[i] == true). - std::vector written_; - - CustomKernel custom_kernel_; - - // Loaded kernels for each `StreamExecutor`. - mutable absl::Mutex mutex_; - absl::flat_hash_map> - kernel_cache_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_KERNEL_THUNK_H_ diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index deb840ba8b08d..d754e975dfb52 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -1,12 +1,15 @@ -# copybara:uncomment_begin(google-only-loads) -# load("//xla/tests:build_defs.bzl", "xla_test") -# load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -# load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") -# copybara:uncomment_end(google-only-loads) +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@tsl//tsl/platform:build_config_root.bzl", "tf_gpu_tests_tags") +load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = ["//visibility:private"], licenses = ["notice"], ) @@ -16,33 +19,41 @@ package_group( ) cc_library( - name = "custom_fusion", - srcs = ["custom_fusion.cc"], - hdrs = ["custom_fusion.h"], + name = "custom_kernel_fusion", + srcs = ["custom_kernel_fusion.cc"], + hdrs = ["custom_kernel_fusion.h"], + visibility = [":friends"], deps = [ ":custom_kernel", "//xla:status", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "custom_fusion_pattern", - srcs = ["custom_fusion_pattern.cc"], - hdrs = ["custom_fusion_pattern.h"], + name = "custom_kernel_fusion_pattern", + srcs = ["custom_kernel_fusion_pattern.cc"], + hdrs = ["custom_kernel_fusion_pattern.h"], + visibility = [":friends"], deps = [ + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -50,12 +61,10 @@ cc_library( name = "custom_kernel", srcs = ["custom_kernel.cc"], hdrs = ["custom_kernel.h"], + visibility = [":friends"], deps = [ - "//xla:statusor", "//xla/stream_executor", "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", ], ) @@ -63,102 +72,346 @@ cc_library( # a single dependency. cc_library( name = "custom_fusion_library", - # copybara:uncomment_begin(google-only) - # deps = [":cutlass_gemm_fusion"], - # copybara:uncomment_end(google-only) -) - -# copybara:uncomment_begin(google-only) -# # TODO(ezhulenev): We currently do not have a CUTLASS dependency in open source BUILD. -# -# cc_library( -# name = "cutlass_gemm_fusion", -# srcs = ["cutlass_gemm_fusion.cc"], -# hdrs = ["cutlass_gemm_fusion.h"], -# deps = [ -# ":custom_fusion", -# ":custom_fusion_pattern", -# ":custom_kernel", -# ":cutlass_gemm_kernel", -# "@com_google_absl//absl/status", -# "//xla:shape_util", -# "//xla:status", -# "//xla:statusor", -# "//xla:xla_data_proto_cc", -# "//xla/hlo/ir:hlo", -# "//xla/service:pattern_matcher", -# "@tsl//tsl/platform:errors", -# "@tsl//tsl/platform:logging", -# "@tsl//tsl/platform:statusor", -# ], -# alwayslink = 1, # static fusion registration -# ) -# -# xla_test( -# name = "cutlass_gemm_fusion_test", -# srcs = ["cutlass_gemm_fusion_test.cc"], -# backends = ["gpu"], -# deps = [ -# ":custom_fusion_pattern", -# ":cutlass_gemm_fusion", -# "@com_google_absl//absl/strings", -# "//xla:debug_options_flags", -# "//xla:error_spec", -# "//xla/service/gpu:custom_fusion_rewriter", -# "//xla/tests:hlo_test_base", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) -# -# cuda_library( -# name = "cutlass_gemm_kernel", -# srcs = ["cutlass_gemm_kernel.cu.cc"], -# hdrs = ["cutlass_gemm_kernel.h"], -# visibility = ["//visibility:private"], -# deps = [ -# ":custom_kernel", -# ":cutlass_gemm_universal", -# "@com_google_absl//absl/status", -# "@com_google_absl//absl/strings", -# "//third_party/gpus/cutlass", -# "//xla:statusor", -# "//xla:xla_data_proto_cc", -# "//xla/stream_executor", -# ], -# ) -# -# cuda_library( -# name = "cutlass_gemm_universal", -# hdrs = ["cutlass_gemm_universal.cu.h"], -# visibility = ["//visibility:private"], -# deps = [ -# "@com_google_absl//absl/status", -# "@com_google_absl//absl/strings", -# "//third_party/gpus/cutlass", -# "//xla:statusor", -# "//xla/stream_executor", -# ], -# ) -# -# xla_test( -# name = "cutlass_gemm_test", -# srcs = if_cuda_is_configured(["cutlass_gemm_test.cc"]), -# backends = ["gpu"], -# deps = [ -# ":cutlass_gemm_kernel", -# "//xla:types", -# "//xla:xla_data_proto_cc", -# "//xla/stream_executor", -# "//xla/stream_executor:multi_platform_manager", -# "//xla/stream_executor:platform", -# "//xla/stream_executor/cuda:cuda_platform", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:status", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_benchmark", -# "@tsl//tsl/platform:test_main", -# ], -# ) -# -# copybara:uncomment_end(google-only) + visibility = [":friends"], + deps = [":cutlass_gemm_fusion"], +) + +cc_library( + name = "cutlass_gemm_fusion", + srcs = ["cutlass_gemm_fusion.cc"], + hdrs = ["cutlass_gemm_fusion.h"], + deps = [ + ":custom_kernel", + ":custom_kernel_fusion", + ":custom_kernel_fusion_pattern", + ":cutlass_gemm", + ":cutlass_gemm_custom_kernel", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], + alwayslink = 1, # static fusion registration +) + +xla_test( + name = "cutlass_gemm_fusion_test", + srcs = ["cutlass_gemm_fusion_test.cc"], + backends = ["gpu"], + # TODO(b/332820384): Enable when it passes on H100. + disabled_backends = ["gpu_h100"], + tags = ["no_rocm"], + deps = [ + ":custom_kernel_fusion_pattern", + ":cutlass_gemm_fusion", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal_util", + "//xla:types", + "//xla/service/gpu:custom_kernel_fusion_rewriter", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "topk_kernel", + srcs = if_gpu_is_configured(["topk_kernel.cc"]), + hdrs = if_gpu_is_configured(["topk_kernel.h"]), + compatible_with = [], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", # build_cleaner: keep + "//xla/stream_executor:platform", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ] + if_gpu_is_configured([ + ":topk_kernel_gpu", + ]), +) + +gpu_kernel_library( + name = "topk_kernel_gpu", + srcs = if_gpu_is_configured([ + "topk_kernel_bfloat16.cu.cc", + "topk_kernel_float.cu.cc", + "topk_kernel.cu.h", + ]), + hdrs = if_gpu_is_configured(["topk_kernel_common.h"]), + compatible_with = [], + deps = [ + "//xla:types", + "//xla/stream_executor/gpu:gpu_types_header", + "@tsl//tsl/lib/math:math_util", + ], +) + +xla_cc_test( + name = "topk_kernel_test", + srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), + tags = tf_gpu_tests_tags(), + deps = [ + ":topk_kernel", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", # build_cleaner: keep + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_init", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "topk_custom_kernel", + srcs = ["topk_custom_kernel.cc"], + hdrs = ["topk_custom_kernel.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + visibility = [":friends"], + deps = [ + ":custom_kernel", + "//xla:statusor", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", + ] + if_gpu_is_configured([ + ":topk_kernel_gpu", + ]), +) + +xla_test( + name = "topk_custom_kernel_test", + srcs = if_gpu_is_configured(["topk_custom_kernel_test.cc"]), + backends = ["gpu"], + deps = [ + ":topk_custom_kernel", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:platform_util", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/cuda:cuda_platform", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +#===--------------------------------------------------------------------------------------------===# +# CUTLASS Gemm <-> xla::gpu::kernel::CustomKernel adaptor +#===--------------------------------------------------------------------------------------------===# + +cc_library( + name = "cutlass_gemm_custom_kernel", + srcs = if_cuda_is_configured( + ["cutlass_gemm_custom_kernel.cc"], + ["cutlass_gemm_custom_kernel_stub.cc"], + ), + hdrs = ["cutlass_gemm_custom_kernel.h"], + deps = [ + ":custom_kernel", + ":cutlass_gemm", + ":cutlass_gemm_kernels", # build_cleaner: keep + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "cutlass_gemm_custom_kernel_test", + srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_test.cc"]), + backends = ["gpu"], + data = [":cutlass_gemm_kernel_f32xf32_to_f32.so"], + deps = [ + ":cutlass_gemm_custom_kernel", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/cuda:cuda_platform", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_binary( + name = "cutlass_gemm_custom_kernel_benchmarks", + testonly = 1, + srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_benchmarks.cc"]), + deps = [ + ":cutlass_gemm_custom_kernel", + "//xla:xla_data_proto_cc", + "//xla/service:gpu_plugin", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/cuda:cuda_platform", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + +#===--------------------------------------------------------------------------------------------===# +# CUTLASS GemmUniversal-base kernels <-> StreamExecutor adaptor +#===--------------------------------------------------------------------------------------------===# + +cc_library( + name = "cutlass_gemm", + srcs = ["cutlass_gemm.cc"], + hdrs = ["cutlass_gemm.h"], + deps = ["@tsl//tsl/platform:logging"], +) + +cuda_library( + name = "cutlass_gemm_adaptor", + hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]), + copts = ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang + deps = if_cuda_is_configured([ + ":cutlass_gemm", + "@cutlass_archive//:cutlass", + ]), +) + +cuda_library( + name = "cutlass_gemm_epilogue", + # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. + textual_hdrs = if_cuda_is_configured(["cutlass_gemm_epilogue.cu.h"]), + deps = if_cuda_is_configured(["@cutlass_archive//:cutlass"]), +) + +#===--------------------------------------------------------------------------------------------===# +# CUTLASS Gemm kernels implementation +#===--------------------------------------------------------------------------------------------===# + +# We split each individual kernel into a separate targets to compile them all in parallel. We also +# do not have any dependencies except CUTLASS itself to reduce the number of recompilations. + +cc_library( + name = "cutlass_gemm_kernels", + deps = [ + ":cutlass_gemm_kernel_bf16xbf16_to_bf16", + ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", + ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", + ":cutlass_gemm_kernel_f32xf32_to_f32", + ], +) + +# CUTLASS requires all loops to be unrolled, and in some kernels defined below we force Clang/LLVM +# to unroll them with extra compiler options because by default LLVM is not as aggressive with loop +# unrolling as NVCC. + +# TODO(ezhulenev): Write a build rule to simplify kernel target declarations. + +cuda_library( + name = "cutlass_gemm_kernel_bf16xbf16_to_bf16", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]), + copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + deps = if_cuda_is_configured([ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cuda_library( + name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]), + copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + deps = if_cuda_is_configured([ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cuda_library( + name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]), + copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + deps = if_cuda_is_configured([ + ":cutlass_gemm_adaptor", + ":cutlass_gemm_epilogue", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cuda_library( + name = "cutlass_gemm_kernel_f32xf32_to_f32", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]), + copts = ["-Wno-unknown-attributes"], + deps = if_cuda_is_configured([ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +#===--------------------------------------------------------------------------------------------===# +# CUTLASS Gemm kernel libraries +#===--------------------------------------------------------------------------------------------===# + +cc_binary( + name = "cutlass_gemm_kernel_f32xf32_to_f32.so", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cc"]), + linkshared = True, + linkstatic = False, + deps = [":cutlass_gemm"], +) diff --git a/xla/service/gpu/kernels/custom_fusion.cc b/xla/service/gpu/kernels/custom_fusion.cc deleted file mode 100644 index c35ecfa833fe3..0000000000000 --- a/xla/service/gpu/kernels/custom_fusion.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/kernels/custom_fusion.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "xla/status.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// CustomFusionRegistry -//===----------------------------------------------------------------------===// - -CustomFusionRegistry* CustomFusionRegistry::Default() { - static auto* registry = new CustomFusionRegistry(); - return registry; -} - -Status CustomFusionRegistry::Register(std::string name, - std::unique_ptr fusion) { - absl::MutexLock lock(&mutex_); - if (auto it = registry_.try_emplace(name, std::move(fusion)); it.second) - return OkStatus(); - return absl::InternalError( - absl::StrCat("Custom fusion ", name, " already registered.")); -} - -CustomFusion* CustomFusionRegistry::Lookup(std::string_view name) const { - absl::MutexLock lock(&mutex_); - if (auto it = registry_.find(name); it != registry_.end()) - return it->second.get(); - return nullptr; -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/kernels/custom_fusion.h b/xla/service/gpu/kernels/custom_fusion.h deleted file mode 100644 index 9311e43e63040..0000000000000 --- a/xla/service/gpu/kernels/custom_fusion.h +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_ -#define XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_ - -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/service/gpu/kernels/custom_fusion.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "tsl/platform/logging.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// CustomFusion -//===----------------------------------------------------------------------===// - -// Custom fusion is a mechanism for registering custom kernels corresponding to -// HLO fusions. -// -// Example: row-major mixed dtype gemm with fused bitcast -// -// %gemm (parameter_0: s8[19,17], parameter_1: f16[15,19]) -> f16[15,17] { -// %parameter_1 = f16[15,19]{1,0} parameter(1) -// %parameter_0 = s8[19,17]{1,0} parameter(0) -// %cp1.1 = f16[19,17]{1,0} convert(%parameter_0) -// ROOT %r.1 = f16[15,17]{1,0} dot(%parameter_1, %cp1.1), -// lhs_contracting_dims={1}, -// rhs_contracting_dims={0} -// } -// -// ENTRY %e (p0: f16[15,19], p1: s8[19,17]) -> f16[15,17] { -// %p1 = s8[19,17]{1,0} parameter(1) -// %p0 = f16[15,19]{1,0} parameter(0) -// ROOT %gemm = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, -// -// } -// -// XLA:GPU has multiple strategies for executing this fusion on device: -// -// (1) cuBLAS library call: a lot of simple gemm operations are supported by -// cuBLAS out of the box. However some combinations of paramters casting and -// epilogue fusion are not supported, which means that XLA has to form -// smaller fusions or use code generation to compiled a device kernel. -// -// (2) Triton: XLA:GPU uses Triton to codegen gemm fusion into devie kernels -// (PTX and CUBIN for NVIDIA gpus). -// -// (3) Custom fusion is another mechanism to execute fusion on device, which -// relies on pre-compiled libraries of custom kernels authored by CUDA C++ -// experts. Custom fusion implements one particular fusion pattern (e.g. -// type casting plus a dot operation like in the example above) with custom -// kernels that XLA has to choose from at run time based on auto tuning. -// -// In practice custom fusion almost always implemented with multiple -// kernels, because input shapes are not known at compile time, and custom -// fusion has multiple kernels with different tiling schemes. -// -// What differentiates custom fusions from custom calls, is that custom fusion -// should be implemented with a device kernel, and this allows XLA:GPU to treat -// custom fusion just like any other device kernel: it's launched as a regular -// KernelThunk and automatically captured into command buffers. -// -// Custom calls (registered with XLA:FFI) on the other hand gives much more -// flexibility, and can be implemented as a combination of a non-trivial host -// side code plus multiple kernel launches or library calls. -// -// Also XLA:FFI offers a stable C API that allows registering external functions -// loaded from dynamic libraries compiled with a different toolchain of XLA -// version. Custom fusions integration relies on C++ ABI and static linking. -// -// TODO(ezhulenev): It should be possible to lower `stablehlo.custom_call` -// operations to custom fusions, albeit with a static linking restriction. -class CustomFusion { - public: - virtual ~CustomFusion() = default; - - // Loads kernels implementing `hlo_computation`. - virtual StatusOr> LoadKernels( - const HloComputation* computation) const = 0; -}; - -//===----------------------------------------------------------------------===// -// CustomFusionRegistry -//===----------------------------------------------------------------------===// - -// Custom fusion registry is a mapping from a custom fusion name to the custom -// fusion implementation, and XLA compiler uses this registry to lower fusion -// operations to kernels when emitting thunks. -class CustomFusionRegistry { - public: - // Returns a pointer to a default custom fusion registry, which is a global - // static registry. - static CustomFusionRegistry* Default(); - - // Registers custom fusion in the registry. Returns error if fusion with the - // given name already registered. - Status Register(std::string name, std::unique_ptr fusion); - - // Looks up custom fusion by name. Return nullptr if it's not found. - CustomFusion* Lookup(std::string_view name) const; - - private: - mutable absl::Mutex mutex_; - absl::flat_hash_map> registry_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace xla::gpu - -#define XLA_REGISTER_CUSTOM_FUSION(NAME, FUSION) \ - XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, __COUNTER__) - -#define XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, N) \ - XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) - -#define XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) \ - ABSL_ATTRIBUTE_UNUSED static const bool \ - xla_custom_fusion_##N##_registered_ = [] { \ - ::xla::Status status = \ - ::xla::gpu::CustomFusionRegistry::Default()->Register( \ - NAME, std::make_unique()); \ - if (!status.ok()) LOG(ERROR) << status; \ - return status.ok(); \ - }() - -#endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_ diff --git a/xla/service/gpu/kernels/custom_fusion_pattern.cc b/xla/service/gpu/kernels/custom_fusion_pattern.cc deleted file mode 100644 index db4d5d2140918..0000000000000 --- a/xla/service/gpu/kernels/custom_fusion_pattern.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" - -#include -#include -#include - -#include "xla/hlo/ir/hlo_instruction.h" - -namespace xla::gpu { - -CustomFusionPatternRegistry* CustomFusionPatternRegistry::Default() { - static auto* registry = new CustomFusionPatternRegistry(); - return registry; -} - -std::vector CustomFusionPatternRegistry::Match( - HloInstruction* instr) const { - std::vector matches; - for (auto& pattern : patterns_) { - if (auto matched = pattern->TryMatch(instr); matched.has_value()) - matches.push_back(std::move(*matched)); - } - return matches; -} - -void CustomFusionPatternRegistry::Add( - std::unique_ptr pattern) { - patterns_.push_back(std::move(pattern)); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/kernels/custom_fusion_pattern.h b/xla/service/gpu/kernels/custom_fusion_pattern.h deleted file mode 100644 index 02123388e2600..0000000000000 --- a/xla/service/gpu/kernels/custom_fusion_pattern.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ -#define XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ - -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/backend_configs.pb.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// CustomFusionPattern -//===----------------------------------------------------------------------===// - -// Custom fusion pattern matches HLO instruction to custom kernels. -class CustomFusionPattern { - public: - virtual ~CustomFusionPattern() = default; - - struct Match { - CustomFusionConfig config; - std::vector instructions; - }; - - // Returns custom fusion config and a list of instructions that matched to a - // custom fusion (one or more custom kernels). Custom fusion pass will outline - // matched instructions into a custom fusion operation if possible. - // - // TODO(ezhulenev): Today the last instruction defines custom fusion root - // (results), however we need to add support for custom fusion that can return - // intermediate result, and custom fusions that require an extra workspace. - virtual std::optional TryMatch(HloInstruction *instr) const = 0; -}; - -//===----------------------------------------------------------------------===// -// CustomFusionPatternRegistry -//===----------------------------------------------------------------------===// - -class CustomFusionPatternRegistry { - public: - // Returns a pointer to a default custom fusion pattern registry, which is a - // global static registry. - static CustomFusionPatternRegistry *Default(); - - std::vector Match(HloInstruction *instr) const; - - void Add(std::unique_ptr pattern); - - template > - void Emplace() { - (Add(std::make_unique()), ...); - } - - private: - std::vector> patterns_; -}; - -} // namespace xla::gpu - -#define XLA_REGISTER_CUSTOM_FUSION_PATTERN(PATTERN) \ - XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, __COUNTER__) - -#define XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, N) \ - XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) - -#define XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) \ - ABSL_ATTRIBUTE_UNUSED static const bool \ - xla_custom_fusion_pattern_##N##_registered_ = [] { \ - ::xla::gpu::CustomFusionPatternRegistry::Default() \ - ->Emplace(); \ - return true; \ - }() - -#endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ diff --git a/xla/service/gpu/kernels/custom_kernel.cc b/xla/service/gpu/kernels/custom_kernel.cc index b9451eb6ff015..47cb849c611bc 100644 --- a/xla/service/gpu/kernels/custom_kernel.cc +++ b/xla/service/gpu/kernels/custom_kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_kernel.h" #include +#include #include +#include #include #include "absl/strings/str_format.h" @@ -33,9 +35,23 @@ CustomKernel::CustomKernel(std::string name, kernel_spec_(std::move(kernel_spec)), block_dims_(block_dims), thread_dims_(thread_dims), + cluster_dims_(std::nullopt), + shared_memory_bytes_(shared_memory_bytes) {} +CustomKernel::CustomKernel(std::string name, + se::MultiKernelLoaderSpec kernel_spec, + se::BlockDim block_dims, se::ThreadDim thread_dims, + se::ClusterDim cluster_dims, + size_t shared_memory_bytes) + : name_(std::move(name)), + kernel_spec_(std::move(kernel_spec)), + block_dims_(block_dims), + thread_dims_(thread_dims), + cluster_dims_(cluster_dims), shared_memory_bytes_(shared_memory_bytes) {} +std::string_view CustomKernel::name() const { return name_; } + const se::MultiKernelLoaderSpec& CustomKernel::kernel_spec() const { return kernel_spec_; } @@ -44,6 +60,10 @@ se::BlockDim CustomKernel::block_dims() const { return block_dims_; } se::ThreadDim CustomKernel::thread_dims() const { return thread_dims_; } +std::optional CustomKernel::cluster_dims() const { + return cluster_dims_; +} + size_t CustomKernel::shared_memory_bytes() const { return shared_memory_bytes_; } diff --git a/xla/service/gpu/kernels/custom_kernel.h b/xla/service/gpu/kernels/custom_kernel.h index 32f65b76a7a49..433f43f38ce49 100644 --- a/xla/service/gpu/kernels/custom_kernel.h +++ b/xla/service/gpu/kernels/custom_kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,9 @@ limitations under the License. #define XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_H_ #include +#include #include +#include #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -47,12 +49,20 @@ class CustomKernel { se::BlockDim block_dims, se::ThreadDim thread_dims, size_t shared_memory_bytes); + CustomKernel(std::string name, se::MultiKernelLoaderSpec kernel_spec, + se::BlockDim block_dims, se::ThreadDim thread_dims, + se::ClusterDim cluster_dims, size_t shared_memory_bytes); + + std::string_view name() const; + const se::MultiKernelLoaderSpec& kernel_spec() const; se::BlockDim block_dims() const; se::ThreadDim thread_dims() const; + std::optional cluster_dims() const; + size_t shared_memory_bytes() const; std::string ToString() const; @@ -62,6 +72,7 @@ class CustomKernel { se::MultiKernelLoaderSpec kernel_spec_; se::BlockDim block_dims_; se::ThreadDim thread_dims_; + std::optional cluster_dims_; size_t shared_memory_bytes_; }; diff --git a/xla/service/gpu/kernels/custom_kernel_fusion.cc b/xla/service/gpu/kernels/custom_kernel_fusion.cc new file mode 100644 index 0000000000000..ec65f30510ebb --- /dev/null +++ b/xla/service/gpu/kernels/custom_kernel_fusion.cc @@ -0,0 +1,57 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "xla/status.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// CustomKernelFusionRegistry +//===----------------------------------------------------------------------===// + +CustomKernelFusionRegistry* CustomKernelFusionRegistry::Default() { + static auto* registry = new CustomKernelFusionRegistry(); + return registry; +} + +absl::Status CustomKernelFusionRegistry::Register( + std::string name, std::unique_ptr fusion) { + absl::MutexLock lock(&mutex_); + if (auto it = registry_.try_emplace(name, std::move(fusion)); it.second) + return absl::OkStatus(); + return absl::InternalError( + absl::StrCat("Custom kernel fusion ", name, " already registered.")); +} + +CustomKernelFusion* CustomKernelFusionRegistry::Lookup( + std::string_view name) const { + absl::MutexLock lock(&mutex_); + if (auto it = registry_.find(name); it != registry_.end()) + return it->second.get(); + return nullptr; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/kernels/custom_kernel_fusion.h b/xla/service/gpu/kernels/custom_kernel_fusion.h new file mode 100644 index 0000000000000..e0d8a27ff5768 --- /dev/null +++ b/xla/service/gpu/kernels/custom_kernel_fusion.h @@ -0,0 +1,156 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_H_ +#define XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/logging.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// CustomKernelFusion +//===----------------------------------------------------------------------===// + +// Custom kernel fusion is a mechanism for registering custom kernels +// corresponding to HLO fusions. +// +// Example: row-major mixed dtype gemm with fused bitcast +// +// %gemm (parameter_0: s8[19,17], parameter_1: f16[15,19]) -> f16[15,17] { +// %parameter_1 = f16[15,19]{1,0} parameter(1) +// %parameter_0 = s8[19,17]{1,0} parameter(0) +// %cp1.1 = f16[19,17]{1,0} convert(%parameter_0) +// ROOT %r.1 = f16[15,17]{1,0} dot(%parameter_1, %cp1.1), +// lhs_contracting_dims={1}, +// rhs_contracting_dims={0} +// } +// +// ENTRY %e (p0: f16[15,19], p1: s8[19,17]) -> f16[15,17] { +// %p1 = s8[19,17]{1,0} parameter(1) +// %p0 = f16[15,19]{1,0} parameter(0) +// ROOT %gemm = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, +// +// } +// +// XLA:GPU has multiple strategies for executing this fusion on device: +// +// (1) cuBLAS library call: a lot of simple gemm operations are supported by +// cuBLAS out of the box. However some combinations of paramters casting and +// epilogue fusion are not supported, which means that XLA has to form +// smaller fusions or use code generation to compiled a device kernel. +// +// (2) Triton: XLA:GPU uses Triton to codegen gemm fusion into devie kernels +// (PTX and CUBIN for NVIDIA gpus). +// +// (3) Custom kernel fusion is another mechanism to execute fusion on device, +// which +// relies on pre-compiled libraries of custom kernels authored by CUDA C++ +// experts. Custom kernel fusion implements one particular fusion pattern +// (e.g. type casting plus a dot operation like in the example above) with +// custom kernels that XLA has to choose from at run time based on auto +// tuning. +// +// In practice custom kernel fusion almost always implemented with multiple +// kernels, because input shapes are not known at compile time, and custom +// fusion has multiple kernels with different tiling schemes. +// +// What differentiates custom kernel fusions from custom calls, is that custom +// kernel fusion should be implemented with a device kernel, and this allows +// XLA:GPU to treat custom kernel fusion just like any other device kernel: it's +// launched as a regular KernelThunk and automatically captured into command +// buffers. +// +// Custom calls (registered with XLA:FFI) on the other hand gives much more +// flexibility, and can be implemented as a combination of a non-trivial host +// side code plus multiple kernel launches or library calls. +// +// Also XLA:FFI offers a stable C API that allows registering external functions +// loaded from dynamic libraries compiled with a different toolchain of XLA +// version. Custom kernel fusions integration relies on C++ ABI and static +// linking. +// +// TODO(ezhulenev): It should be possible to lower `stablehlo.custom_call` +// operations to custom kernel fusions, albeit with a static linking +// restriction. +class CustomKernelFusion { + public: + virtual ~CustomKernelFusion() = default; + + // Loads kernels implementing `hlo_computation` optimized for a given device. + virtual absl::StatusOr> LoadKernels( + const se::DeviceDescription& device, + const HloComputation* computation) const = 0; +}; + +//===----------------------------------------------------------------------===// +// CustomKernelFusionRegistry +//===----------------------------------------------------------------------===// + +// Custom fusion registry is a mapping from a custom kernel fusion name to the +// custom fusion implementation, and XLA compiler uses this registry to lower +// fusion operations to kernels when emitting thunks. +class CustomKernelFusionRegistry { + public: + // Returns a pointer to a default custom fusion registry, which is a global + // static registry. + static CustomKernelFusionRegistry* Default(); + + // Registers custom kernel fusion in the registry. Returns error if fusion + // with the given name already registered. + absl::Status Register(std::string name, + std::unique_ptr fusion); + + // Looks up custom kernel fusion by name. Return nullptr if it's not found. + CustomKernelFusion* Lookup(std::string_view name) const; + + private: + mutable absl::Mutex mutex_; + absl::flat_hash_map> + registry_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace xla::gpu + +#define XLA_REGISTER_CUSTOM_FUSION(NAME, FUSION) \ + XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, __COUNTER__) + +#define XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, N) \ + XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) + +#define XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) \ + ABSL_ATTRIBUTE_UNUSED static const bool \ + xla_custom_fusion_##N##_registered_ = [] { \ + absl::Status status = \ + ::xla::gpu::CustomKernelFusionRegistry::Default()->Register( \ + NAME, std::make_unique()); \ + if (!status.ok()) LOG(ERROR) << status; \ + return status.ok(); \ + }() + +#endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_H_ diff --git a/xla/service/gpu/kernels/custom_kernel_fusion_pattern.cc b/xla/service/gpu/kernels/custom_kernel_fusion_pattern.cc new file mode 100644 index 0000000000000..aa5531967b7b2 --- /dev/null +++ b/xla/service/gpu/kernels/custom_kernel_fusion_pattern.cc @@ -0,0 +1,92 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// CustomKernelFusionPattern::Match +//===----------------------------------------------------------------------===// + +CustomKernelFusionPattern::Match::Match( + CustomFusionConfig config, std::vector instructions, + int64_t workspace_size_bytes) + : config_(std::move(config)), + instructions_(std::move(instructions)), + workspace_size_bytes_(workspace_size_bytes) {} + +void CustomKernelFusionPattern::Match::AddReplacement(HloInstruction* instr, + Replacement replacement) { + replacements_[instr] = std::move(replacement); +} + +bool CustomKernelFusionPattern::Match::HasReplacement( + HloInstruction* instr) const { + return replacements_.contains(instr); +} + +absl::StatusOr +CustomKernelFusionPattern::Match::BuildReplacement( + HloInstruction* instr, HloFusionInstruction* fusion) const { + if (auto it = replacements_.find(instr); it != replacements_.end()) { + return it->second(fusion); + } + + return absl::InvalidArgumentError( + absl::StrCat("no replacement for instruction: ", instr->name())); +} + +//===----------------------------------------------------------------------===// +// CustomKernelFusionPatternRegistry +//===----------------------------------------------------------------------===// + +CustomKernelFusionPatternRegistry* +CustomKernelFusionPatternRegistry::Default() { + static auto* registry = new CustomKernelFusionPatternRegistry(); + return registry; +} + +std::vector +CustomKernelFusionPatternRegistry::Match(const se::DeviceDescription& device, + HloInstruction* instr) const { + std::vector matches; + for (auto& pattern : patterns_) { + if (auto matched = pattern->TryMatch(device, instr); matched.has_value()) + matches.push_back(std::move(*matched)); + } + return matches; +} + +void CustomKernelFusionPatternRegistry::Add( + std::unique_ptr pattern) { + patterns_.push_back(std::move(pattern)); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/kernels/custom_kernel_fusion_pattern.h b/xla/service/gpu/kernels/custom_kernel_fusion_pattern.h new file mode 100644 index 0000000000000..a16f667410dd0 --- /dev/null +++ b/xla/service/gpu/kernels/custom_kernel_fusion_pattern.h @@ -0,0 +1,148 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_PATTERN_H_ +#define XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_PATTERN_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// CustomKernelFusionPattern +//===----------------------------------------------------------------------===// + +// Custom kernel fusion pattern matches HLO instruction to custom kernels. +class CustomKernelFusionPattern { + public: + // A name of a custom call that can be added to a custom kernel fusion body to + // allocate a workspace buffer require for the custom kernel fusion + // implementation. + static constexpr const char *kWorkspace = "__custom_kernel_fusion$workspace"; + + virtual ~CustomKernelFusionPattern() = default; + + // Matched sequence of instructions that can be handled by a custom kernel + // fusion. + class Match { + public: + Match(CustomFusionConfig config, std::vector instructions, + int64_t workspace_size_bytes = 0); + + // If some of operations matched by a pattern have users outside of the + // custom kernel fusion, pattern can optionally provide a replacement that + // can be derived from the fusion instruction result, or from other + // instructions in the parent computation. + using Replacement = + std::function(HloFusionInstruction *)>; + + void AddReplacement(HloInstruction *instr, Replacement replacement); + bool HasReplacement(HloInstruction *instr) const; + + // Builds a replacement for `instr` using a `fusion` instruction constructed + // for a pattern match. + absl::StatusOr BuildReplacement( + HloInstruction *instr, HloFusionInstruction *fusion) const; + + const CustomFusionConfig &config() const { return config_; } + absl::Span instructions() const { + return instructions_; + } + + HloInstruction *root() const { return instructions_.back(); } + + int64_t workspace_size_bytes() const { return workspace_size_bytes_; } + + private: + CustomFusionConfig config_; + std::vector instructions_; + absl::flat_hash_map replacements_; + int64_t workspace_size_bytes_; + }; + + // Returns custom fusion config and a list of instructions that matched to a + // custom kernel fusion (one or more custom kernels). Custom kernel fusion + // pass will outline matched instructions into a custom kernel fusion + // operation if possible. + // + // TODO(ezhulenev): Today the last instruction defines custom kernel fusion + // root (results), however we need to add support for custom kernel fusion + // that can return intermediate result, and custom kernel fusions that require + // an extra workspace. + virtual std::optional TryMatch(const se::DeviceDescription &device, + HloInstruction *instr) const = 0; +}; + +//===----------------------------------------------------------------------===// +// CustomKernelFusionPatternRegistry +//===----------------------------------------------------------------------===// + +class CustomKernelFusionPatternRegistry { + public: + // Returns a pointer to a default custom kernel fusion pattern registry, which + // is a global static registry. + static CustomKernelFusionPatternRegistry *Default(); + + std::vector Match( + const se::DeviceDescription &device, HloInstruction *instr) const; + + void Add(std::unique_ptr pattern); + + template > + void Emplace() { + (Add(std::make_unique()), ...); + } + + template > + void Emplace(Arg &&arg) { + (Add(std::make_unique(std::forward(arg))), ...); + } + + private: + std::vector> patterns_; +}; + +} // namespace xla::gpu + +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN(PATTERN) \ + XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, __COUNTER__) + +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, N) \ + XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) + +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) \ + ABSL_ATTRIBUTE_UNUSED static const bool \ + xla_custom_fusion_pattern_##N##_registered_ = [] { \ + ::xla::gpu::CustomKernelFusionPatternRegistry::Default() \ + ->Emplace(); \ + return true; \ + }() + +#endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_FUSION_PATTERN_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm.cc b/xla/service/gpu/kernels/cutlass_gemm.cc new file mode 100644 index 0000000000000..f2859cd05b180 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm.cc @@ -0,0 +1,202 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm.h" + +#include +#include +#include + +#include "tsl/platform/logging.h" + +#if !defined(PLATFORM_WINDOWS) +#include +#endif + +namespace xla::gpu::kernel::gemm_universal { + +// TODO(b/315492043): We should add an XLA PJRT style C API for registering +// libraries of custom CUTLASS kernels compiled into shared libraries. It should +// be possible to bundle multiple custom CUTLASS kernels into a single shared +// library, and then load them optionally by name. For now we assume that there +// is a 1-to-1 mapping from a kernel to shared library, and they exported with +// a simple C API, and we hope that functions exported from a library has ABI +// that matches our expectations. + +using BlockDimFn = void (*)(int32_t m, int32_t n, int32_t k, uint32_t* x, + uint32_t* y, uint32_t* z); +using ThreadDimFn = void (*)(uint32_t* x, uint32_t* y, uint32_t* z); +using SharedMemoryBytesFn = int32_t (*)(); +using CanImplementFn = bool (*)(int32_t m, int32_t n, int32_t k); +using WorkspaceSizeFn = int64_t (*)(int32_t m, int32_t n, int32_t k); +using InitializeFn = void (*)(void* params, int32_t m, int32_t n, int32_t k, + void* lhs, void* rhs, void* out, void* workspace, + int32_t* out_offset, int32_t device_sms, + int32_t sm_occupancy); +using KernelSymboFn = void* (*)(); + +static constexpr const char* kBlockDimFn = "xla_cutlass_kernel_block_dim"; +static constexpr const char* kThreadDimFn = "xla_cutlass_kernel_thread_dim"; +static constexpr const char* kSharedMemoryBytes = + "xla_cutlass_kernel_shared_memory_bytes"; +static constexpr const char* kCanImplement = "xla_cutlass_kernel_can_implement"; +static constexpr const char* kWorkspaceSize = + "xla_cutlass_kernel_workspace_size"; +static constexpr const char* kInitialize = "xla_cutlass_kernel_initialize"; +static constexpr const char* kKernelSymbol = "xla_cutlass_kernel_symbol"; + +static void* Dlopen(const char* path) { +#if defined(PLATFORM_WINDOWS) + return nullptr; +#else + return dlopen(path, RTLD_LAZY); +#endif // defined(PLATFORM_WINDOWS) +} + +static void* Dlsym(void* handle, const char* name) { +#if defined(PLATFORM_WINDOWS) + return nullptr; +#else + return dlsym(handle, name); +#endif // defined(PLATFORM_WINDOWS) +} + +//===----------------------------------------------------------------------===// +// CUTLASS Host Side Adaptor +//===----------------------------------------------------------------------===// + +std::optional> Adaptor::Load( + const std::string& path) { + VLOG(3) << "Load CUTLASS adaptor from a shared library: " << path; + + void* library = Dlopen(path.c_str()); + if (library == nullptr) return std::nullopt; + + auto resolve = [&](const char* name) -> void* { + void* sym = Dlsym(library, name); + if (sym == nullptr) { + LOG(ERROR) << "Failed to resolve CUTLASS adaptor function: " << name + << " in library: " << path; + } + return sym; + }; + + void* block_dim_fn = resolve(kBlockDimFn); + if (block_dim_fn == nullptr) return std::nullopt; + + void* thread_dim_fn = resolve(kThreadDimFn); + if (thread_dim_fn == nullptr) return std::nullopt; + + void* shared_memory_bytes_fn = resolve(kSharedMemoryBytes); + if (shared_memory_bytes_fn == nullptr) return std::nullopt; + + void* can_implement_fn = resolve(kCanImplement); + if (shared_memory_bytes_fn == nullptr) return std::nullopt; + + void* workspace_size_fn = resolve(kWorkspaceSize); + if (workspace_size_fn == nullptr) return std::nullopt; + + void* initialize_fn = resolve(kInitialize); + if (shared_memory_bytes_fn == nullptr) return std::nullopt; + + return Adaptor(library, block_dim_fn, thread_dim_fn, shared_memory_bytes_fn, + can_implement_fn, workspace_size_fn, initialize_fn); +} + +std::optional Adaptor::ClusterDim() const { + return std::nullopt; +} + +Dim3 Adaptor::BlockDim(int32_t m, int32_t n, int32_t k) const { + Dim3 dim; + reinterpret_cast(block_dim_fn_)(m, n, k, &dim.x, &dim.y, &dim.z); + return dim; +} + +Dim3 Adaptor::ThreadDim() const { + Dim3 dim; + reinterpret_cast(thread_dim_fn_)(&dim.x, &dim.y, &dim.z); + return dim; +} + +int32_t Adaptor::SharedMemoryBytes() const { + return reinterpret_cast(shared_memory_bytes_fn_)(); +} + +bool Adaptor::CanImplement(const Arguments& args) const { + return reinterpret_cast(can_implement_fn_)(args.m, args.n, + args.k); +} + +int64_t Adaptor::WorkspaceSize(const Arguments& args) const { + return reinterpret_cast(workspace_size_fn_)(args.m, args.n, + args.k); +} + +void Adaptor::Initialize(void* params, const Arguments& args, + int32_t device_sms, + int32_t sm_occupancy) const { + reinterpret_cast(initialize_fn_)( + params, args.m, args.n, args.k, args.lhs, args.rhs, args.out, + args.workspace, args.slices.out, device_sms, sm_occupancy); +} + +Adaptor::Adaptor(void* handle, void* block_dim_fn, + void* thread_dim_fn, + void* shared_memory_bytes_fn, + void* can_implement_fn, + void* workspace_size_fn, void* initialize_fn) + : handle_(handle), + block_dim_fn_(block_dim_fn), + thread_dim_fn_(thread_dim_fn), + shared_memory_bytes_fn_(shared_memory_bytes_fn), + can_implement_fn_(can_implement_fn), + workspace_size_fn_(workspace_size_fn), + initialize_fn_(initialize_fn) {} + +//===----------------------------------------------------------------------===// +// CUTLASS Device Side Adaptor +//===----------------------------------------------------------------------===// + +std::optional> DeviceKernel::Load( + const std::string& path) { + VLOG(3) << "Load CUTLASS device kernel from a shared library: " << path; + + void* library = Dlopen(path.c_str()); + if (library == nullptr) return std::nullopt; + + auto resolve = [&](const char* name) -> void* { + void* sym = Dlsym(library, name); + if (sym == nullptr) { + LOG(ERROR) << "Failed to resolve CUTLASS kernel function: " << name + << " in library: " << path; + } + return sym; + }; + + void* kernel_symbol_fn = resolve(kKernelSymbol); + if (kernel_symbol_fn == nullptr) return std::nullopt; + + return DeviceKernel(library, kernel_symbol_fn); +} + +void* DeviceKernel::symbol() const { + return reinterpret_cast(symbol_fn_)(); +} + +DeviceKernel::DeviceKernel(void* handle, void* symbol_fn) + : handle_(handle), symbol_fn_(symbol_fn) {} + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm.h b/xla/service/gpu/kernels/cutlass_gemm.h new file mode 100644 index 0000000000000..37fb0ad8486ae --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm.h @@ -0,0 +1,247 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ + +//===-------------------------------------------------------------------------// +// ! ! ! ! ! WARNING ! ! ! ! ! // +//===-------------------------------------------------------------------------// +// // +// Do not add external dependencies to this header. Use only std library. // +// // +//===-------------------------------------------------------------------------// +// ! ! ! ! ! WARNING ! ! ! ! ! // +//===-------------------------------------------------------------------------// + +#include +#include +#include + +namespace xla::gpu::kernel::gemm_universal { + +//===----------------------------------------------------------------------===// +// Tag based GEMM dispatching +//===----------------------------------------------------------------------===// + +// We use tag-based template specializations to carefully avoid including +// CUTLASS headers into regular libraries, and specialize templates in separate +// CUDA build targets that have no dependencies on other parts of XLA or ABSL to +// enable parallel compilation and minimize recompilations on code changes. +// +// Here we re-define some of the enums and types defined in CUTLASS and CUTE to +// break a dependency on them from XLA. + +enum class Arch { kDefault, kSm80, kSm90 }; + +template +struct Bf16xBf16ToBf16 {}; + +template +struct F32xF32ToF32 {}; + +// A tag to specialize CUTLASS kernel adaptors for loading kernels from shared +// libraries using dlopen. +struct DlOpenedKernel {}; + +//===----------------------------------------------------------------------===// +// CUTLASS gemm arguments +//===----------------------------------------------------------------------===// + +// Indices of a custom fusion parameters corresponding to Gemm kernel arguments. +// +// Example: +// se::KernelArgsDeviceMemoryArray args = ... +// void* lhs = args->device_memory_ptr(indices.lhs); +// +// Custom fusion instruction can have parameters in arbitrary order, and we need +// a mapping from a custom kernel argument to the fusion instruction parameter. +struct ArgsIndices { + int64_t lhs; + int64_t rhs; + int64_t out; + + // Workspace parameter is a special case, as it's always passed as a last + // parameter at run time (only if requested). + bool has_workspace; +}; + +// Custom CUTLASS gemm kernels support on-device address arithmetics for input +// and output buffers, so that we can fuse dynamic-slice/dynamic-update-slice +// operations into the GEMM kernel. +// +// Base pointers and memory layout known on the host before kernel launch, but +// offsets are computed on device and available only in device memory. We can't +// load offsets to the host as it would require stream synchronization. +// +// Following structs encode how dynamic offsets passed to custom kernels. +// +// Example: CUTLASS gemm with a dynamic-update-slice +// +// cutlass_gemm { +// p0 = f32[2,2]{1,0} parameter(0) +// p1 = f32[2,2,2]{2,1,0} parameter(1) +// p2 = s32[] parameter(2) <--- major dim offset +// p3 = s32[] parameter(3) <--- minor dims offset +// dot = f32[2,2]{1,0} dot(p0, p0) +// ... +// ROOT r = f32[2,2,2]{2,1,0} dynamic-update-slice(p1, ..., p2, p3, p3) +// } +// +// In this example `p2` parameter defines a dynamic slice offset along the +// major dimension (0-th dimension for a row major layout). In practice +// parameters can be passed to fusions in arbitrary order, and when we pack +// custom kernel arguments into device kernel parameters we need to know +// how to find correct device pointers in the list of fusion arguments. +// +// For this example: +// +// DynamicSliceIndices::out = 2 +// DynamicSliceArguments::out = +// +// `DynamicSliceIndices` used in the host-code to fetch device memory pointers +// from arguments and pass it as `DynamicSliceArguments` to a device kernel. +// +// Kernel arguments packing function can pass dynamic slices as a part of +// CUTLASS kernel parameters, or as a separate argument to a device kernel entry +// function (CUTLASS 3x vs 2x). + +// Indices of a custom fusion parameters corresponding to dynamic slice offsets. +struct DynamicSliceIndices { + // Index of a dynamic slice offset along the major dimension. + std::optional out; +}; + +// Pointers to buffers (s32[] buffers in HLO) holding dynamic slice offsets. +struct DynamicSliceArguments { + int32_t* out = nullptr; +}; + +// Type-erased CUTLASS gemm arguments structure that has all of the details +// required for packing CUTLASS kernel parameters. +struct Arguments { + int32_t m; + int32_t n; + int32_t k; + + void* lhs; + void* rhs; + void* out; + void* workspace; + + DynamicSliceArguments slices; +}; + +//===----------------------------------------------------------------------===// +// CUTLASS Host Side Adaptor +//===----------------------------------------------------------------------===// + +template +struct Traits; + +struct Dim3 { + uint32_t x = 1; + uint32_t y = 1; + uint32_t z = 1; +}; + +// This is a type-erased adaptor that has all details required for launching +// CUTLASS kernel on a device. At run time device kernel parameters is really +// just a bag of bytes that driver sends to a kernel, so we rely on it to hide +// CUTLASS templates inside individual build targets and don't leak them into +// XLA, as they contain device code and can't be parsed by regular clang. +template +class Adaptor { + public: + std::optional ClusterDim() const; + Dim3 BlockDim(int32_t m, int32_t n, int32_t k) const; + Dim3 ThreadDim() const; + + int32_t SharedMemoryBytes() const; + + bool CanImplement(const Arguments& args) const; + int64_t WorkspaceSize(const Arguments& args) const; + + void Initialize(void* params, const Arguments& args, int32_t device_sms, + int32_t sm_occupancy) const; +}; + +// This is a specialization of adaptor that can load CUTLASS kernels from +// pre-compiled shared libraries on disk. Libraries can be compiled ahead of +// time using external toolchain, e.g. NVCC, as long as they export required +// symbols with a plain C calling convention. +template <> +class Adaptor { + public: + static std::optional Load(const std::string& path); + + std::optional ClusterDim() const; + Dim3 BlockDim(int32_t m, int32_t n, int32_t k) const; + Dim3 ThreadDim() const; + + int32_t SharedMemoryBytes() const; + + bool CanImplement(const Arguments& args) const; + int64_t WorkspaceSize(const Arguments& args) const; + + void Initialize(void* params, const Arguments& args, int32_t device_sms, + int32_t sm_occupancy) const; + + private: + Adaptor(void* handle, void* block_dim_fn, void* thread_dim_fn, + void* shared_memory_bytes_fn, void* can_implement_fn, + void* workspace_size_fn, void* initialize_fn); + + void* handle_; + void* block_dim_fn_; + void* thread_dim_fn_; + void* shared_memory_bytes_fn_; + void* can_implement_fn_; + void* workspace_size_fn_; + void* initialize_fn_; +}; + +//===----------------------------------------------------------------------===// +// CUTLASS Device Side Adaptor +//===----------------------------------------------------------------------===// + +// We keep device side adaptor separate from host side adaptor so that we could +// easily split host and device code compilation if needed. + +template +class DeviceKernel { + public: + void* symbol() const; +}; + +// This is a specialization of device kernel for loading CUTLASS kernels from +// shared libraries on disk (see Adaptor specialization above). +template <> +class DeviceKernel { + public: + static std::optional Load(const std::string& path); + + void* symbol() const; + + private: + DeviceKernel(void* handle, void* symbol_fn); + + void* handle_; + void* symbol_fn_; +}; + +} // namespace xla::gpu::kernel::gemm_universal + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h b/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h new file mode 100644 index 0000000000000..f721279a9c5b6 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h @@ -0,0 +1,441 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_ADAPTOR_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_ADAPTOR_CU_H_ + +#include +#include +#include + +#include "cute/layout.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/util/packed_stride.hpp" +#include "xla/service/gpu/kernels/cutlass_gemm.h" + +namespace xla::gpu::kernel::gemm_universal { + +// This is a template library implementing adaptor from a CUTLASS kernel to +// StreamExecutor primitives for kernel arguments packing and kernel launching. +// +// This library is based on `GemmUniversalAdaptor` from CUTLASS itself, but +// instead of targeting CUDA runtime for launching kernels, it targets XLA +// StreamExecutor abstractions, but conceptually it has the same role: wrapping +// device kernels into C++ API to make them launchable on streams. + +//===----------------------------------------------------------------------===// +// CUTLASS 2x vs 3x +//===----------------------------------------------------------------------===// + +// Cutlass 2x and 3x have slightly different APIs, with a little bit of template +// metaprogramming and constexpr ifs we dispatch to the correct version at +// compile time based on a kernel template. +template +static constexpr bool is_cutlass_3x = + cutlass::gemm::detail::IsCutlass3GemmKernel< + typename Traits::Kernel>::value; + +//===----------------------------------------------------------------------===// +// Gemm strides computation +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): CUTLASS already has functions in cute to compute strides for +// a GEMM operations/kernels. Remove custom LdA/B/C functions. + +template +int64_t LdA(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutA = typename Gemm::LayoutA; + + if constexpr (std::is_same_v) { + return problem_size.k(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +template +int64_t LdB(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutB = typename Gemm::LayoutB; + + if constexpr (std::is_same_v) { + return problem_size.n(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +template +int64_t LdC(const cutlass::gemm::GemmCoord &problem_size) { + using LayoutC = typename Gemm::LayoutA; + + if constexpr (std::is_same_v) { + return problem_size.n(); + } else { + static_assert(sizeof(Gemm) == 0, "unsupported layout type"); + } +} + +//===----------------------------------------------------------------------===// +// CUTLASS 2x host side adaptor +//===----------------------------------------------------------------------===// + +namespace adaptor_2x { + +template +static std::optional ClusterDim() { + return std::nullopt; +} + +template +static Dim3 BlockDim(int32_t m, int32_t n, int32_t k) { + using Operation = typename Traits::Operation; + using ThreadblockSwizzle = typename Operation::ThreadblockSwizzle; + using ThreadblockShape = typename Operation::ThreadblockShape; + + cutlass::gemm::GemmCoord problem_size(m, n, k); + cutlass::gemm::GemmCoord tile_size(ThreadblockShape::kM, ThreadblockShape::kN, + ThreadblockShape::kK); + cutlass::gemm::GemmCoord grid_tiled_shape = + ThreadblockSwizzle::get_tiled_shape(problem_size, tile_size, + /*split_k_slices=*/1); + + auto grid = ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); + return Dim3{grid.x, grid.y, grid.z}; +} + +template +static int32_t SharedMemoryBytes() { + return sizeof(typename Traits::Kernel::SharedStorage); +}; + +template +static Dim3 ThreadDim() { + return Dim3{Traits::Kernel::kThreadCount, 1, 1}; +} + +template +static bool CanImplement(const Arguments &args) { + cutlass::gemm::GemmCoord problem_size(args.m, args.n, args.k); + return Traits::Kernel::can_implement(problem_size) == + cutlass::Status::kSuccess; +} + +// Converts type-erased gemm arguments to the underlying CUTLASS operation +// arguments. +template +static typename Traits::Arguments OpArguments(const Arguments &args) { + cutlass::gemm::GemmCoord problem_size(args.m, args.n, args.k); + + // TODO(ezhulenev): Replace with cute::stride instead of custom templates. + auto lda = LdA::Operation>(problem_size); + auto ldb = LdB::Operation>(problem_size); + auto ldc = LdC::Operation>(problem_size); + + auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + + // TODO(ezhulenev): We hardcode parameters for `LinearCombination` + // epilogue, however `Gemm` template can be compiled with arbitrary + // epilogues. We have to support custom epilogues in a way that does not + // leak cutlass types via the public API function signature. + using Accumulator = typename Traits::Operation::ElementAccumulator; + Accumulator alpha{1.0}; + Accumulator beta{0.0}; + + return typename Traits::Arguments( // CUTLASS Operation arguments + mode, problem_size, // + 1, // batch + {alpha, beta}, // epilogue + args.lhs, args.rhs, args.out, args.out, // pointers + 0, 0, 0, 0, // batch strides + lda, ldb, ldc, ldc // strides + ); +} + +template +int64_t WorkspaceSize(const Arguments &args) { + return Traits::Operation::get_workspace_size(OpArguments(args)); +} + +template +void Initialize(void *params, const Arguments &args, int32_t device_sms, + int32_t sm_occupancy) { + // Sanity check that parameters struct is compatible with parameters storage + // defined by custom gemm kernel. + static_assert(sizeof(typename Traits::Params) <= 1024, + "Params struct size is too large"); + static_assert(alignof(typename Traits::Params) <= 32, + "Params struct alignment is too large"); + + // Convert CUTLASS operation arguments to a device kernel parameters. + new (params) typename Traits::Params(OpArguments(args), device_sms, + sm_occupancy); +} + +}; // namespace adaptor_2x + +//===----------------------------------------------------------------------===// +// CUTLASS 3x host side adaptor +//===----------------------------------------------------------------------===// + +namespace adaptor_3x { + +template +static std::optional ClusterDim() { + typename Traits::Kernel::DispatchPolicy::ClusterShape cluster; + return Dim3{cute::get<0>(cluster), cute::get<1>(cluster), + cute::get<2>(cluster)}; +} + +template +static Dim3 BlockDim(int32_t m, int32_t n, int32_t k) { + return adaptor_2x::BlockDim(m, n, k); +} + +template +static Dim3 ThreadDim() { + auto block_shape = Traits::Kernel::get_block_shape(); + return Dim3{block_shape.x, block_shape.y, block_shape.z}; +} + +template +static int32_t SharedMemoryBytes() { + return Traits::Kernel::SharedStorageSize; +}; + +template +static typename Traits::Arguments OpArguments(const Arguments &args) { + using Kernel = typename Traits::Kernel; + using Operation = typename Traits::Operation; + + auto stride_a = cutlass::make_cute_packed_stride( + typename Kernel::StrideA{}, cute::make_shape(args.m, args.k, 1)); + auto stride_b = cutlass::make_cute_packed_stride( + typename Kernel::StrideB{}, cute::make_shape(args.n, args.k, 1)); + auto stride_c = cutlass::make_cute_packed_stride( + typename Kernel::StrideC{}, cute::make_shape(args.m, args.n, 1)); + auto stride_d = cutlass::make_cute_packed_stride( + typename Kernel::StrideD{}, cute::make_shape(args.m, args.n, 1)); + + // TODO(ezhulenev): Pass device id and sm_count in arguments. + cutlass::KernelHardwareInfo hw_info{/*device_id=*/0, /*sm_count=*/128}; + + auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + typename Kernel::ProblemShape problem_shape = {args.m, args.n, args.k, + /*batch=*/1}; + + // TODO(ezhulenev): We hardcode parameters for `LinearCombination` + // epilogue, however `Gemm` template can be compiled with arbitrary + // epilogues. We have to support custom epilogues in a way that does not + // leak cutlass types via the public API function signature. + using Accumulator = typename Traits::Operation::ElementAccumulator; + Accumulator alpha{1.0}; + Accumulator beta{0.0}; + + typename Kernel::MainloopArguments mainloop_args{ + reinterpret_cast(args.lhs), stride_a, + reinterpret_cast(args.rhs), stride_b}; + + typename Kernel::EpilogueArguments epilogue_args{ + {alpha, beta}, + reinterpret_cast(args.out), + stride_c, + reinterpret_cast(args.out), + stride_d, + {{args.slices.out}, {args.m * args.n}}, // dynamic offsets for C + {{args.slices.out}, {args.m * args.n}}, // dynamic offsets for D + }; + + return typename Operation::Arguments{mode, problem_shape, mainloop_args, + epilogue_args, hw_info}; +} + +template +static bool CanImplement(const Arguments &args) { + return Traits::Kernel::can_implement(OpArguments(args)); +} + +template +static int64_t WorkspaceSize(const Arguments &args) { + return Traits::Operation::get_workspace_size(OpArguments(args)); +} + +template +static void Initialize(void *params, const Arguments &args, int32_t device_sms, + int32_t sm_occupancy) { + // Sanity check that parameters struct is compatible with parameters storage + // defined by custom gemm kernel. + static_assert(sizeof(typename Traits::Params) <= 1024, + "Params struct size is too large"); + static_assert(alignof(typename Traits::Params) <= 64, + "Params struct alignment is too large"); + + // Convert CUTLASS operation arguments to a device kernel parameters. + using Kernel = typename Traits::Kernel; + new (params) typename Traits::Params( + Kernel::to_underlying_arguments(OpArguments(args), args.workspace)); +} + +}; // namespace adaptor_3x + +//===----------------------------------------------------------------------===// +// Dispatch between CUTLASS 2x and 3x host adaptors +//===----------------------------------------------------------------------===// + +template +std::optional Adaptor::ClusterDim() const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::ClusterDim(); + } else { + return adaptor_2x::ClusterDim(); + } +} + +template +Dim3 Adaptor::ThreadDim() const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::ThreadDim(); + } else { + return adaptor_2x::ThreadDim(); + } +} + +template +Dim3 Adaptor::BlockDim(int32_t m, int32_t n, int32_t k) const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::BlockDim(m, n, k); + } else { + return adaptor_2x::BlockDim(m, n, k); + } +} + +template +int32_t Adaptor::SharedMemoryBytes() const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::SharedMemoryBytes(); + } else { + return adaptor_2x::SharedMemoryBytes(); + } +}; + +template +bool Adaptor::CanImplement(const Arguments &args) const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::CanImplement(args); + } else { + return adaptor_2x::CanImplement(args); + } +} + +template +int64_t Adaptor::WorkspaceSize(const Arguments &args) const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::WorkspaceSize(args); + } else { + return adaptor_2x::WorkspaceSize(args); + } +} + +template +void Adaptor::Initialize(void *params, const Arguments &args, + int32_t device_sms, int32_t sm_occupancy) const { + if constexpr (is_cutlass_3x) { + return adaptor_3x::Initialize(params, args, device_sms, sm_occupancy); + } else { + return adaptor_2x::Initialize(params, args, device_sms, sm_occupancy); + } +} + +//===----------------------------------------------------------------------===// +// CUTLASS 2x device kernel entry point +//===----------------------------------------------------------------------===// + +// This entry point is based on `cutlass::Kernel2` template with an extra +// parameter to pass dynamic slices. +// +// TODO(ezhulenev): Dynamic slices should be encoded in kernel parameters. +template +__global__ void Kernel2EntryPoint(typename Kernel::Params params, + DynamicSliceArguments dynamic_slices) { + extern __shared__ int SharedStorageBase[]; + typename Kernel::SharedStorage *shared_storage = + reinterpret_cast(SharedStorageBase); + + // Adjust output pointer to account for dynamic offsets. + if (dynamic_slices.out) { + auto m = params.problem_size.m(); + auto n = params.problem_size.n(); + + using ElementC = typename Kernel::ElementC; + int64_t offset = sizeof(ElementC) * *dynamic_slices.out * (m * n); + + char *ptr_c = reinterpret_cast(params.ptr_C); + char *ptr_d = reinterpret_cast(params.ptr_D); + + params.ptr_C = ptr_c + offset; + params.ptr_D = ptr_d + offset; + } + + Kernel::invoke(params, *shared_storage); +} + +//===----------------------------------------------------------------------===// +// CUTLASS 3x device kernel entry point +//===----------------------------------------------------------------------===// + +template +__global__ void Kernel3EntryPoint( + CUTLASS_GRID_CONSTANT const typename Kernel::Params params) { + extern __shared__ char shared_memory[]; + + Kernel kernel; + kernel(params, shared_memory); +} + +//===----------------------------------------------------------------------===// +// Dispatch between CUTLASS 2x and 3x kernel entry points +//===----------------------------------------------------------------------===// + +template +void *DeviceKernel::symbol() const { + using Kernel = typename Traits::Kernel; + + if constexpr (is_cutlass_3x) { + return reinterpret_cast(Kernel3EntryPoint); + } else { + return reinterpret_cast(Kernel2EntryPoint); + } +}; + +//===----------------------------------------------------------------------===// +// CUTLASS kernel traits helper +//===----------------------------------------------------------------------===// + +#define XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(TAG, OPERATION) \ + template <> \ + struct Traits { \ + using Operation = OPERATION; \ + using Arguments = typename Operation::Arguments; \ + using Kernel = typename Operation::GemmKernel; \ + using Params = typename Kernel::Params; \ + } + +} // namespace xla::gpu::kernel::gemm_universal + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_ADAPTOR_CU_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc new file mode 100644 index 0000000000000..63f9cbe683902 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc @@ -0,0 +1,249 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu::kernel::gemm_universal { + +static constexpr auto Default = Arch::kDefault; // NOLINT +static constexpr auto Sm80 = Arch::kSm80; // NOLINT +static constexpr auto Sm90 = Arch::kSm90; // NOLINT + +// Each individual CUTLASS kernel adaptor will be compiled in a separate +// cuda_library and linked into the `cutlass_gemm_custom_kernels` target. We use +// this approach for a few reasons: +// +// - It enables parallel compilation of CUTLASS templates which in practice +// becomes quite expensive for any non-trivial GEMM. +// +// - We do not include any of the CUTLASS headers in our custom kernel +// library which would require converting it to a cuda_library, and we +// want to minimize the number of headers included in .cu.cc files as NVCC +// does not particularly like templates defined in ABSL. +// +extern template struct Adaptor>; +extern template struct DeviceKernel>; + +extern template struct Adaptor>; +extern template struct DeviceKernel>; + +extern template struct Adaptor>; +extern template struct DeviceKernel>; + +extern template struct Adaptor>; +extern template struct DeviceKernel>; + +//===----------------------------------------------------------------------===// +// CUTLASS kernel arguments packing +//===----------------------------------------------------------------------===// + +using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; + +template +static Dim As(Dim3 dim3) { + return Dim(dim3.x, dim3.y, dim3.z); +} + +template +static std::optional As(std::optional dim3) { + if (dim3.has_value()) return Dim(dim3->x, dim3->y, dim3->z); + return std::nullopt; +} + +// Returns a pointer to device memory holding a slice offset. +static int32_t* SlicePtr(const se::KernelArgsDeviceMemoryArray* args, + int64_t index) { + const void* opaque = args->device_memory_ptr(index); + return static_cast(const_cast(opaque)); +} + +template +KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, + const DynamicSliceIndices& slices, + int32_t device_sms, Adaptor adaptor) { + using Packed = absl::StatusOr>; + + // TODO(ezhulenev): CUTLASS kernel Params struct not necessarily trivially + // destructible or even trivially copyable, we have to own the life time of an + // object constructed in the storage. For now we ignore it, and it's textbook + // definition of UB, but for CUTLASS kernels we use today it's perfectly safe. + struct Params { + alignas(64) std::byte storage[1024]; + }; + + return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { + auto* mem_args = se::Cast(&args); + + Arguments arguments = {m, n, k}; + arguments.lhs = const_cast(mem_args->device_memory_ptr(indices.lhs)); + arguments.rhs = const_cast(mem_args->device_memory_ptr(indices.rhs)); + arguments.out = const_cast(mem_args->device_memory_ptr(indices.out)); + + // Workspace argument always passed as the last one (if passed at all). + if (indices.has_workspace) { + size_t num_mem_args = mem_args->device_memory_args().size(); + arguments.workspace = + const_cast(mem_args->device_memory_ptr(num_mem_args - 1)); + } else { + arguments.workspace = nullptr; + } + + // Set up dynamic slices if they are available. + if (slices.out.has_value()) { + arguments.slices.out = SlicePtr(mem_args, *slices.out); + } + + if (!adaptor.CanImplement(arguments)) { + return absl::InternalError(absl::StrCat( + "CUTLASS kernel can not implement gemm for a given problem size", + ": m=", m, ", n=", n, ", k=", k)); + } + + auto threads = As(adaptor.ThreadDim()); + auto shmem_bytes = adaptor.SharedMemoryBytes(); + + // We keep max_occupancy in a static variable as currently for all + // practical purposes all stream executors in the process have identical + // underlying devices, and there is no need to repeatedly query this + // property. + static int32_t sm_occupancy = + kernel.GetMaxOccupiedBlocksPerCore(threads, shmem_bytes).value_or(1); + + // TODO(ezhulenev): In theory when sm_occupancy is 0 we should not be able + // to run kernels, and we could return error here, however in practice + // it's not true, and kernels with 0 occupancy run just fine! Figure out + // where is the problem, and how we can reliably use sm occupancy numbers. + // + // TODO(ezhulenev): We need to set kernel dynamic shmem limit before asking + // for sm occupancy, it's likely why we get 0 today. + if (sm_occupancy == 0) { + LOG_FIRST_N(WARNING, 1) + << "CUTLASS gemm kernel reported 0 occupancy: threads_per_block=" + << (threads.x * threads.y * threads.z) + << ", dynamic_shared_memory_bytes=" << shmem_bytes; + } + + // Initialize parameters storage using adaptor. + Params params; + adaptor.Initialize(¶ms, arguments, device_sms, sm_occupancy); + + // TODO(ezhulenev): We need to support EmplaceKernelArgs with inplace + // construction to avoid copying 1kb of byte storage. + // + // TODO(ezhulenev): Remove `DynamicSliceArguments` once we encode + // dynamic slice offsets in kernel parameters. + return se::PackKernelArgs( + args.number_of_shared_bytes(), params, arguments.slices); + }; +} +//===----------------------------------------------------------------------===// + +template +static absl::StatusOr Load(std::string name, int32_t m, int32_t n, + int32_t k, const ArgsIndices& indices, + const DynamicSliceIndices& slices, + const se::DeviceDescription& device, + Adaptor adaptor = {}, + DeviceKernel kernel = {}) { + // Get the dispatch grid size and shared memory requirements. + auto cluster_dim = As(adaptor.ClusterDim()); + auto block_dim = As(adaptor.BlockDim(m, n, k)); + auto thread_dim = As(adaptor.ThreadDim()); + auto shared_memory_bytes = adaptor.SharedMemoryBytes(); + + auto packing = + ArgsPacking(m, n, k, indices, slices, device.core_count(), adaptor); + + se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); + spec.AddInProcessSymbol(kernel.symbol(), name); + + if (cluster_dim.has_value()) { + return CustomKernel(std::move(name), std::move(spec), block_dim, thread_dim, + *cluster_dim, shared_memory_bytes); + } else { + return CustomKernel(std::move(name), std::move(spec), block_dim, thread_dim, + shared_memory_bytes); + } +} + +absl::StatusOr GetCutlassGemmKernel( + std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device) { + auto& cuda_cc = + std::get(device.gpu_compute_capability()); + + switch (dtype) { + case PrimitiveType::F32: + return Load>(std::move(name), m, n, k, indices, + slices, device); + case PrimitiveType::BF16: + if (cuda_cc.IsAtLeastHopper()) { + return Load>(std::move(name), m, n, k, indices, + slices, device); + } else if (cuda_cc.IsAtLeastAmpere()) { + return Load>(std::move(name), m, n, k, indices, + slices, device); + } + return Load>(std::move(name), m, n, k, indices, + slices, device); + + default: + return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); + } +} + +absl::StatusOr LoadCutlassGemmKernel( + std::string name, const std::string& library_path, PrimitiveType dtype, + int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, + const DynamicSliceIndices& slices, const se::DeviceDescription& device) { + auto adaptor = Adaptor::Load(library_path); + if (!adaptor.has_value()) { + return absl::InternalError( + absl::StrCat("Failed to load CUTLASS adaptor from a shared library: ", + library_path)); + } + + auto kernel = DeviceKernel::Load(library_path); + if (!kernel.has_value()) { + return absl::InternalError(absl::StrCat( + "Failed to load CUTLASS kernel from a shared library: ", library_path)); + } + + return Load(std::move(name), m, n, k, indices, slices, device, + *adaptor, *kernel); +} + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h new file mode 100644 index 0000000000000..e432373092752 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu::kernel::gemm_universal { + +// Returns a pre-compiled custom kernel for a given data type and problem size. +absl::StatusOr GetCutlassGemmKernel( + std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device); + +// Loads custom kernel for a given data type and problem size from a shared +// library. It's up to the caller to guarantee that CUTLASS kernel in the shared +// library is compatible with the data type and problem size. +absl::StatusOr LoadCutlassGemmKernel( + std::string name, const std::string& library_path, PrimitiveType dtype, + int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, + const DynamicSliceIndices& slices, const se::DeviceDescription& device); + +} // namespace xla::gpu::kernel::gemm_universal + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_CUSTOM_KERNEL_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc new file mode 100644 index 0000000000000..e2f7bccea9fe5 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -0,0 +1,83 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::gpu::kernel::gemm_universal { + +static uint32_t BitPattern(float value) { + uint32_t pattern; + std::memcpy(&pattern, &value, sizeof(float)); + return pattern; +} + +static void BM_RowMajorGemm(benchmark::State& state) { + se::Platform* platform = + se::PlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + const se::DeviceDescription& device = executor->GetDeviceDescription(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // GEMM: 8192x4096 * 4096x16384 -> 8192x16384 + int32_t m = 8192; + int32_t n = 16384; + int32_t k = 4096; + + auto custom_kernel = + GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::BF16, m, n, k, + /*indices=*/{0, 1, 2}, /*slices=*/{}, device); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); + + // Prepare arguments: a=1.1, b=1.2, c=0.0 + se::DeviceMemory a = executor->AllocateArray(m * k, 0); + se::DeviceMemory b = executor->AllocateArray(k * n, 0); + se::DeviceMemory c = executor->AllocateArray(m * n, 0); + + TF_CHECK_OK(stream->Memset32(&a, BitPattern(1.1f), a.size())); + TF_CHECK_OK(stream->Memset32(&b, BitPattern(1.2f), b.size())); + TF_CHECK_OK(stream->MemZero(&c, c.size())); + + se::KernelArgsDeviceMemoryArray args( + std::vector({a, b, c}), + custom_kernel->shared_memory_bytes()); + + for (auto s : state) { + TF_CHECK_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, args)); + TF_CHECK_OK(stream->BlockHostUntilDone()); + } +} + +BENCHMARK(BM_RowMajorGemm); + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc new file mode 100644 index 0000000000000..ed658d2e01584 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc @@ -0,0 +1,41 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu::kernel::gemm_universal { + +absl::StatusOr GetCutlassGemmKernel( + std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, const DynamicSliceIndices& slices, + const se::DeviceDescription& device) { + return absl::InternalError("XLA compiled without CUDA support"); +} + +absl::StatusOr LoadCutlassGemmKernel( + std::string name, const std::string& library_path, PrimitiveType dtype, + int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, + const DynamicSliceIndices& slices, const se::DeviceDescription& device) { + return absl::InternalError("XLA compiled without CUDA support"); +} + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc new file mode 100644 index 0000000000000..3748ed5251564 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" + +#include +#include +#include +#include + +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/path.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu::kernel::gemm_universal { + +TEST(CutlassGemmKernelTest, SimpleGemm) { + se::Platform* platform = + se::PlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + auto stream = executor->CreateStream().value(); + + // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. + auto custom_kernel = GetCutlassGemmKernel( + "cutlass_gemm", PrimitiveType::F32, 4, 4, 4, + /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); + + int64_t length = 4 * 4; + int64_t byte_length = sizeof(float) * length; + + // Prepare arguments: a=2, b=2, c=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory c = executor->AllocateArray(length, 0); + + float value = 2.0; + uint32_t pattern; + std::memcpy(&pattern, &value, sizeof(pattern)); + + TF_ASSERT_OK(stream->Memset32(&a, pattern, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, pattern, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Launch gemm kernel with device memory arguments. + se::KernelArgsDeviceMemoryArray arr( + std::vector({a, b, c}), + custom_kernel->shared_memory_bytes()); + TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, arr)); + + // Copy `c` data back to host. + std::vector dst(length, -1.0f); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected(length, 16.0); + ASSERT_EQ(dst, expected); +} + +TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { + std::string kernel_lib_path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", "kernels", + "cutlass_gemm_kernel_f32xf32_to_f32.so"); + + se::Platform* platform = + se::PlatformManager::PlatformWithName("CUDA").value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + auto stream = executor->CreateStream().value(); + + // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. + auto custom_kernel = LoadCutlassGemmKernel( + "cutlass_gemm", kernel_lib_path, PrimitiveType::F32, 4, 4, 4, + /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); + + TF_ASSERT_OK_AND_ASSIGN( + auto gemm, se::Kernel::Create(executor, custom_kernel->kernel_spec())); + + int64_t length = 4 * 4; + int64_t byte_length = sizeof(float) * length; + + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory c = executor->AllocateArray(length, 0); + + float value = 2.0; + uint32_t pattern; + std::memcpy(&pattern, &value, sizeof(pattern)); + + TF_ASSERT_OK(stream->Memset32(&a, pattern, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, pattern, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Launch gemm kernel with device memory arguments. + se::KernelArgsDeviceMemoryArray arr( + std::vector({a, b, c}), + custom_kernel->shared_memory_bytes()); + TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), + custom_kernel->block_dims(), *gemm, arr)); + + // Copy `c` data back to host. + std::vector dst(length, -1.0f); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected(length, 16.0); + ASSERT_EQ(dst, expected); +} + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_epilogue.cu.h b/xla/service/gpu/kernels/cutlass_gemm_epilogue.cu.h new file mode 100644 index 0000000000000..1c48e6f53090a --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_epilogue.cu.h @@ -0,0 +1,309 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_EPILOGUE_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_EPILOGUE_CU_H_ + +#include +#include + +#include "cute/config.hpp" +#include "cute/container/array.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" +#include "cute/numeric/integral_constant.hpp" +#include "cute/tensor.hpp" +#include "cute/underscore.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/numeric_conversion.h" + +namespace xla::gpu::kernel::gemm_universal { + +using cutlass::epilogue::collective::detail::get_epilogue_stride; + +//===----------------------------------------------------------------------===// +// Custom CUTLASS epilogue fusions +//===----------------------------------------------------------------------===// + +template +struct LinearCombinationWithDynamicSlice + : cutlass::epilogue::fusion::ScaledAcc { + static constexpr bool IsSourceSupported = true; // NOLINT +}; + +//===----------------------------------------------------------------------===// +// CUTLASS gemm epilogue with an on-device offset support +//===----------------------------------------------------------------------===// + +// This epilogue is derived from CUTLASS default epilogue with an additional +// support for dynamic slice offsets. +// +// Original: cutlass/epilogue/collective/default_epilogue.hpp + +// Applies an element wise operation to all elements within the fragment +// and writes them out to destination storage. C and D storage can have +// optional dynamic offsets (offsets stored in a device memory). +template +class DynamicSliceEpilogue { + public: + using EpilogueSchedule = EpilogueSchedule_; + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = + typename cute::uint_bit::value * + kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, + "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, + "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage {}; + + // Offset into C and D computed as a dot product of `offset` and `stride`. + struct DynamicOffset { + cute::array offset{}; + cute::array stride{}; + }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_c = nullptr; + StrideC stride_c{}; + ElementD* ptr_d = nullptr; + StrideD stride_d{}; + DynamicOffset offset_c{}; + DynamicOffset offset_d{}; + }; + + // Device side epilogue params_ + using Params = Arguments; + + template + static constexpr Params to_underlying_arguments(ProblemShape const& _, + Arguments const& args, + void* workspace) { + return args; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, + Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + explicit DynamicSliceEpilogue(Params const& params__) + : params_(params__), epilogue_op_(params__.thread) {} + + CUTLASS_DEVICE + bool is_source_needed() { return epilogue_op_.is_source_needed(); } + + template + CUTLASS_HOST_DEVICE void operator()( + ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_mnk, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, + char* smem_buf) { + using cute::_; + using cute::_1; + using cute::local_tile; + using cute::make_coord; + using cute::make_gmem_ptr; + using cute::make_identity_tensor; + using cute::make_shape; + using cute::make_tensor; + using cute::shape; + using cute::Tensor; + using cute::unwrap; + + using X = cute::Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, + "ProblemShapeMNKL must be rank 4"); + static_assert(cute::is_static::value, + "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, + "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, + "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto m = cute::get<0>(problem_shape_mnkl); + auto n = cute::get<1>(problem_shape_mnkl); + auto l = cute::get<3>(problem_shape_mnkl); + + auto stride_c = get_epilogue_stride(params_.stride_c); + auto stride_d = get_epilogue_stride(params_.stride_d); + + ElementC const* ptr_c = params_.ptr_c; + ElementD* ptr_d = params_.ptr_d; + + // Apply dynamic offsets to base pointers. + for (unsigned i = 0; i < dynamic_offset; ++i) { + if (params_.offset_c.offset[i]) + ptr_c += *params_.offset_c.offset[i] * params_.offset_c.stride[i]; + } + for (unsigned i = 0; i < dynamic_offset; ++i) { + if (params_.offset_d.offset[i]) + ptr_d += *params_.offset_d.offset[i] * params_.offset_d.stride[i]; + } + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_c), make_shape(m, n, l), + stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(ptr_d), make_shape(m, n, l), + stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_mnk, make_coord(_, _, _), + cute::Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_mnk, make_coord(_, _, _), + cute::Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator + // partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(cute::is_static::value, + "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V( + size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V( + size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor( + make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + if (epilogue_op_.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(cute::get<0>(residue_mnk), + cute::get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op_(accumulators(i), tCgC(i)); + } + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(cute::get<0>(residue_mnk), + cute::get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op_(accumulators(i)); + } + } + } + } + + private: + Params params_; + ThreadEpilogueOp epilogue_op_; +}; + +} // namespace xla::gpu::kernel::gemm_universal + +namespace cutlass::epilogue::collective { + +//===----------------------------------------------------------------------===// +// Collective builder specialization for LinearCombinationWithDynamicSlice +//===----------------------------------------------------------------------===// + +// Specialization for `NoSmemWarpSpecialized` schedule. +template +struct CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape_MNK, + ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, + ElementC_, GmemLayoutTagC_, AlignmentC, ElementD, GmemLayoutTagD, + AlignmentD, cutlass::epilogue::NoSmemWarpSpecialized, + xla::gpu::kernel::gemm_universal::LinearCombinationWithDynamicSlice< + ElementD, ElementCompute, dynamic_offset, RoundStyle>, + void> { + // Passing void C disables source load + using ElementC = + cute::conditional_t, ElementD, ElementC_>; + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + + static constexpr cutlass::epilogue::thread::ScaleType::Kind ScaleType = + cute::is_void_v + ? cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + : cutlass::epilogue::thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = cutlass::epilogue::thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, ScaleType, + RoundStyle, ElementC>; + + using CollectiveOp = + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + xla::gpu::kernel::gemm_universal::DynamicSliceEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, ThreadOp, + cutlass::gemm::EpilogueDefault, dynamic_offset>>; +}; + +} // namespace cutlass::epilogue::collective + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_EPILOGUE_CU_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index ac58ba9c2d8ca..9fdbcd6790633 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,23 +16,27 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" #include +#include #include #include #include #include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/kernels/custom_fusion.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/service/gpu/kernels/cutlass_gemm.h" +#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/status.h" #include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -46,6 +50,25 @@ namespace xla::gpu { namespace { namespace m = match; +// If custom fusion requires extra workspace at run time, ROOT instruction will +// be a tuple with second operand being a result of workspace allocation custom +// call. +struct RootWithWorkspace { + HloInstruction* root; + HloInstruction* workspace; +}; + +static RootWithWorkspace MatchRootWithWorkspace(HloInstruction* root) { + RootWithWorkspace result; + if (Match(root, + m::Tuple(m::Op(&result.root), + m::CustomCall(&result.workspace, + {CustomKernelFusionPattern::kWorkspace})))) { + return result; + } + return {root, nullptr}; +} + // Pattern for matching mixed precision GEMMs. struct GemmWithUpcast { explicit GemmWithUpcast(HloDotInstruction* dot) : dot(dot) {} @@ -54,10 +77,28 @@ struct GemmWithUpcast { HloInstruction* lhs_upcast = nullptr; // HLO convert instr HloInstruction* rhs_upcast = nullptr; // HLO convert instr }; + +// Pattern for matching GEMM with surrounding dynamic-slice/update-slice. +struct GemmWithDynamicSlice { + explicit GemmWithDynamicSlice(HloDynamicUpdateSliceInstruction* update_slice) + : update_slice(update_slice) {} + + std::vector Instrs() { + // Bitcast could be optional + if (bitcast == nullptr) { + return {dot, update_slice}; + } + return {dot, bitcast, update_slice}; + } + + HloInstruction* dot = nullptr; + HloInstruction* bitcast = nullptr; // result bitcast + HloInstruction* update_slice = nullptr; // update result slice +}; } // namespace // Returns OK if dot instruction is a simple 2D row-major gemm. -static Status MatchRowMajorGemm(HloDotInstruction* dot) { +static absl::Status MatchRowMajorGemm(HloDotInstruction* dot) { if (dot->operand(0)->shape().dimensions_size() != 2 || dot->operand(1)->shape().dimensions_size() != 2) { return absl::InternalError("operands must have rank 2"); @@ -75,55 +116,97 @@ static Status MatchRowMajorGemm(HloDotInstruction* dot) { return absl::InternalError("rhs contracting dimensions must be 0"); } - return OkStatus(); + return absl::OkStatus(); } // Return OK if dot instruction is a simple gemm with all operands and result // having the same data type. -static Status MatchSimpleGemm(HloDotInstruction* dot, PrimitiveType dtype) { +static absl::Status MatchSimpleGemm( + HloDotInstruction* dot, absl::Span support_dtypes) { TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); - if (dot->operand(0)->shape().element_type() != dtype || - dot->operand(1)->shape().element_type() != dtype || - dot->shape().element_type() != dtype) { - return absl::InternalError("operands and result must have the same type"); + for (PrimitiveType dtype : support_dtypes) { + if (dot->operand(0)->shape().element_type() == dtype && + dot->operand(1)->shape().element_type() == dtype && + dot->shape().element_type() == dtype) { + return absl::OkStatus(); + } } - return OkStatus(); + return absl::InternalError("unsupported operands type"); } // Returns matched GEMM with one of the operands upcasted to the accumulator // data type with an HLO convert instruction. -static StatusOr MatchGemmWithUpcast(HloDotInstruction* dot) { +static absl::StatusOr MatchGemmWithUpcast( + HloDotInstruction* dot) { TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); - GemmWithUpcast matched(dot); + GemmWithUpcast match(dot); // C <- convert(A) * B if (Match(const_cast(dot->operand(0)), - m::Convert(&matched.lhs_upcast, m::Op()))) { - return matched; + m::Convert(&match.lhs_upcast, m::Op()))) { + return match; } // C <- A * convert(B) if (Match(const_cast(dot->operand(1)), - m::Convert(&matched.rhs_upcast, m::Op()))) { - return matched; + m::Convert(&match.rhs_upcast, m::Op()))) { + return match; } return absl::InternalError("unsupported gemm with upcasing"); } +template +auto OptionalBitcast(HloInstruction** optional_bitcast, Pattern pattern) { + return m::AnyOf(m::Bitcast(optional_bitcast, pattern), + std::move(pattern)); +} + +// Returns matched GEMM with result used to update a slice. +static absl::StatusOr MatchGemmWithDynamicUpdateSlice( + HloDynamicUpdateSliceInstruction* update_slice) { + GemmWithDynamicSlice match(update_slice); + + if (!Match(const_cast(update_slice->update()), + OptionalBitcast(&match.bitcast, + m::Dot(&match.dot, m::Op(), m::Op())))) { + return absl::InternalError("failed to match update slice instr"); + } + + TF_RETURN_IF_ERROR(MatchRowMajorGemm(Cast(match.dot))); + + return match; +} + +static bool AreInstructionsOnTheSameStream( + absl::Span instructions) { + absl::flat_hash_set stream_set; + for (const HloInstruction* inst : instructions) { + auto gpu_config = inst->backend_config(); + if (!gpu_config.ok()) { + continue; + } + stream_set.insert(gpu_config->operation_queue_id()); + if (stream_set.size() > 1) { + return false; + } + } + return true; +}; + //===----------------------------------------------------------------------===// // Cutlass Gemm Patterns //===----------------------------------------------------------------------===// -std::optional CutlassGemmPattern::TryMatch( - HloInstruction* instr) const { +std::optional CutlassGemmPattern::TryMatch( + const se::DeviceDescription& device, HloInstruction* instr) const { auto* dot = DynCast(instr); if (!dot) return std::nullopt; - auto matched = MatchSimpleGemm(dot, PrimitiveType::F32); + auto matched = MatchSimpleGemm(dot, {PrimitiveType::F32}); if (!matched.ok()) return std::nullopt; CustomFusionConfig config; @@ -131,8 +214,42 @@ std::optional CutlassGemmPattern::TryMatch( return Match{config, {instr}}; } -std::optional -CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { +std::optional +CutlassGemmWithDynamicUpdateSlicePattern::TryMatch( + const se::DeviceDescription& device, HloInstruction* instr) const { + auto* update_slice = DynCast(instr); + if (!update_slice) return std::nullopt; + + auto matched = MatchGemmWithDynamicUpdateSlice(update_slice); + if (!matched.ok() || !AreInstructionsOnTheSameStream(matched->Instrs())) + return std::nullopt; + + CustomFusionConfig config; + config.set_name("cutlass_gemm_with_dynamic_update_slice"); + + Match match(config, matched->Instrs()); + + // Add an optional replacement for intermediate dot instruction as a + // dynamic-slice from the fusion result. + match.AddReplacement(matched->dot, [=](HloFusionInstruction* fusion) { + HloComputation* parent = fusion->parent(); + auto* dus = Cast(matched->update_slice); + bool has_bitcast = matched->bitcast != nullptr; + const Shape dus_shape = + has_bitcast ? matched->bitcast->shape() : matched->dot->shape(); + auto* slice = parent->AddInstruction(HloInstruction::CreateDynamicSlice( + dus_shape, fusion, dus->index_operands(), dus_shape.dimensions())); + + return parent->AddInstruction( + HloInstruction::CreateBitcast(matched->dot->shape(), slice)); + }); + + return match; +} + +std::optional +CutlassGemmWithUpcastPattern::TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const { auto* dot = DynCast(instr); if (!dot) return std::nullopt; @@ -153,9 +270,10 @@ CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { // Cutlass Gemm Fusions //===----------------------------------------------------------------------===// -class CutlassGemmFusion : public CustomFusion { +class CutlassGemmFusion : public CustomKernelFusion { public: - StatusOr> LoadKernels( + absl::StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const final { auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) { @@ -163,26 +281,37 @@ class CutlassGemmFusion : public CustomFusion { "cutlass_gemm requires ROOT operation to be a dot"); } - TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, PrimitiveType::F32)); + TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, {PrimitiveType::F32})); auto dtype = dot->shape().element_type(); - auto& lhs_shape = dot->operand(0)->shape(); - auto& rhs_shape = dot->operand(1)->shape(); + auto* lhs = Cast(dot->operand(0)); + auto* rhs = Cast(dot->operand(1)); + + // Mapping from fusion arguments to gemm kernel arguments. + kernel::gemm_universal::ArgsIndices indices = { + lhs->parameter_number(), rhs->parameter_number(), + computation->num_parameters()}; + + auto& lhs_shape = lhs->shape(); + auto& rhs_shape = rhs->shape(); size_t m = lhs_shape.dimensions(0); size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN(auto kernel, - kernel::GetCutlassGemmKernel(dtype, m, n, k)); + TF_ASSIGN_OR_RETURN( + auto kernel, + kernel::gemm_universal::GetCutlassGemmKernel( + "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device)); return std::vector{std::move(kernel)}; } }; -class CutlassGemmWithUpcastFusion : public CustomFusion { +class CutlassGemmWithUpcastFusion : public CustomKernelFusion { public: - StatusOr> LoadKernels( + absl::StatusOr> LoadKernels( + const se::DeviceDescription& device, const HloComputation* computation) const final { auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) { @@ -207,10 +336,65 @@ class CutlassGemmWithUpcastFusion : public CustomFusion { } }; +class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion { + public: + absl::StatusOr> LoadKernels( + const se::DeviceDescription& device, + const HloComputation* computation) const final { + auto [root, workspace] = + MatchRootWithWorkspace(computation->root_instruction()); + + auto* dus = DynCast(root); + if (dus == nullptr) { + return absl::InternalError( + "cutlass_gemm_with_dynamic_update_slice requires ROOT operation to " + "be a dynamic update slice"); + } + + TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithDynamicUpdateSlice(dus)); + TF_RETURN_IF_ERROR( + MatchSimpleGemm(Cast(matched.dot), + {PrimitiveType::F32, PrimitiveType::BF16})); + + auto dtype = matched.dot->shape().element_type(); + + auto* lhs = Cast(matched.dot->operand(0)); + auto* rhs = Cast(matched.dot->operand(1)); + auto* out = Cast(matched.update_slice->operand(0)); + + // Mapping from fusion arguments to gemm kernel arguments. + kernel::gemm_universal::ArgsIndices args_indices = { + lhs->parameter_number(), rhs->parameter_number(), + out->parameter_number(), /*has_workspace=*/workspace != nullptr}; + + // Mapping to a buffer that holds output slice offset. + auto* offset = + Cast(matched.update_slice->operand(2)); + kernel::gemm_universal::DynamicSliceIndices slices; + slices.out = offset->parameter_number(); + + auto& lhs_shape = lhs->shape(); + auto& rhs_shape = rhs->shape(); + + size_t m = lhs_shape.dimensions(0); + size_t k = lhs_shape.dimensions(1); + size_t n = rhs_shape.dimensions(1); + + TF_ASSIGN_OR_RETURN( + auto kernel, kernel::gemm_universal::GetCutlassGemmKernel( + "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, + k, args_indices, slices, device)); + return std::vector{std::move(kernel)}; + } +}; + } // namespace xla::gpu -XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmPattern); +XLA_REGISTER_CUSTOM_FUSION_PATTERN( + ::xla::gpu::CutlassGemmWithDynamicUpdateSlicePattern); XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", ::xla::gpu::CutlassGemmFusion); XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm_with_upcast", ::xla::gpu::CutlassGemmWithUpcastFusion); +XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm_with_dynamic_update_slice", + ::xla::gpu::CutlassGemmWithDynamicUpdateSliceFusion); diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion.h b/xla/service/gpu/kernels/cutlass_gemm_fusion.h index f448b2d0a4d91..e7027c905d9ef 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion.h +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,21 +19,32 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { // Pattern matches simple row-major gemms to CUTLASS kernels. -class CutlassGemmPattern : public CustomFusionPattern { +class CutlassGemmPattern : public CustomKernelFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override; + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; +}; + +// Pattern matches simple row-major gemms with dynamic-update-slice. +class CutlassGemmWithDynamicUpdateSlicePattern + : public CustomKernelFusionPattern { + public: + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; }; // Pattern matches mixed dtype gemms when one of the operands is upcasted to an // accumulator (output) dtype, i.e. BF16 <= BF16 x S8. -class CutlassGemmWithUpcastPattern : public CustomFusionPattern { +class CutlassGemmWithUpcastPattern : public CustomKernelFusionPattern { public: - std::optional TryMatch(HloInstruction* instr) const override; + std::optional TryMatch(const se::DeviceDescription& device, + HloInstruction* instr) const override; }; } // namespace xla::gpu diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 541ba5c569b08..bf96a264b5dbe 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,24 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" +#include #include -#include "xla/debug_options_flags.h" +#include "xla/array.h" +#include "xla/array2d.h" +#include "xla/array3d.h" #include "xla/error_spec.h" -#include "xla/service/gpu/custom_fusion_rewriter.h" -#include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/tests/hlo_test_base.h" +#include "xla/types.h" #include "tsl/platform/test.h" namespace xla::gpu { -class CutlassFusionTest : public HloTestBase { - // Custom fusions are not supported by XLA runtime. - DebugOptions GetDebugOptionsForTest() override { - auto debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_gpu_enable_xla_runtime_executable(false); - return debug_options; - } -}; +class CutlassFusionTest : public HloTestBase {}; //===----------------------------------------------------------------------===// // Pattern matching tests @@ -56,7 +55,7 @@ TEST_F(CutlassFusionTest, RowMajorGemm) { ; CHECK: [[P0:%[^ ]+]] = f32[15,19]{1,0} parameter(0) ; CHECK: [[P1:%[^ ]+]] = f32[19,17]{1,0} parameter(1) ; CHECK: ROOT [[DOT:%[^ ]+]] = f32[15,17]{1,0} dot([[P0]], [[P1]]), - ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} ; CHECK: } ; CHECK: ENTRY %main {{.*}} { @@ -69,10 +68,11 @@ TEST_F(CutlassFusionTest, RowMajorGemm) { ; CHECK: } )"; - CustomFusionPatternRegistry patterns; + CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -95,7 +95,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { ; CHECK-DAG: [[P1:%[^ ]+]] = s8[19,17]{1,0} parameter ; CHECK: [[C1:%[^ ]+]] = bf16[19,17]{1,0} convert([[P1]]) ; CHECK: ROOT [[DOT:%[^ ]+]] = bf16[15,17]{1,0} dot([[P0]], [[C1]]), - ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} ; CHECK: } ; CHECK: ENTRY %main {{.*}} { @@ -108,10 +108,166 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { ; CHECK: } )"; - CustomFusionPatternRegistry patterns; + CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); - CustomFusionRewriter pass(&patterns); + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f32[2,2,2], p1: f32[2,2], i: s32[]) -> f32[2,2,2] { + %p0 = f32[2,2,2]{2,1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + %bc = f32[1,2,2]{2,1,0} bitcast(%dot) + + ROOT %r = f32[2,2,2]{2,1,0} dynamic-update-slice(%p0, %bc, %i, %i, %i) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2,2]{2,1,0} parameter + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[CAST:%[^ ]+]] = f32[1,2,2]{2,1,0} bitcast([[DOT]]) + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[2,2,2]{2,1,0} dynamic-update-slice( + ; CHECK: [[P1]], [[CAST]], [[P2]], [[P2]], [[P2]] + ; CHECK: ) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[2,2,2]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) { + const char* hlo = R"( + HloModule test + + ENTRY %main { + %p0 = f32[2,2,2]{2,1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + %add = f32[2,2]{1,0} add(%dot, %dot) + + %cast = f32[1,2,2]{2,1,0} bitcast(%dot) + %dus = f32[2,2,2]{2,1,0} dynamic-update-slice(%p0, %cast, %i, %i, %i) + + ROOT %r = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(%add, %dus) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2,2]{2,1,0} parameter + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[CAST:%[^ ]+]] = f32[1,2,2]{2,1,0} bitcast([[DOT]]) + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[2,2,2]{2,1,0} dynamic-update-slice( + ; CHECK: [[P1]], [[CAST]], [[P2]], [[P2]], [[P2]] + ; CHECK: ) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: [[OFFSET:%[^ ]+]] = s32[] parameter(2) + ; CHECK: [[FUSION:%[^ ]+]] = f32[2,2,2]{2,1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: [[SLICE:%[^ ]+]] = f32[1,2,2]{2,1,0} dynamic-slice( + ; CHECK: [[FUSION]], [[OFFSET]], [[OFFSET]], [[OFFSET]]), + ; CHECK: dynamic_slice_sizes={1,2,2} + ; CHECK: [[CAST:%[^. ]+]] = f32[2,2]{1,0} bitcast([[SLICE]]) + ; CHECK: [[ADD:%[^. ]+]] = f32[2,2]{1,0} add([[CAST]], [[CAST]]) + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f32[4,2], p1: f32[2,2], i: s32[]) -> f32[4,2] { + %p0 = f32[4,2]{1,0} parameter(0) + %p1 = f32[2,2]{1,0} parameter(1) + %i = s32[] parameter(2) + + %dot = f32[2,2]{1,0} dot(%p1, %p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + ROOT %r = f32[4,2]{1,0} dynamic-update-slice(%p0, %dot, %i, %i) + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_dynamic_update_slice {{.*}} { + ; CHECK-DAG: [[P1:%[^ ]+]] = f32[4,2]{1,0} parameter + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter + ; CHECK-DAG: [[DOT:%[^ ]+]] = f32[2,2]{1,0} dot([[P0]], [[P0]]) + ; CHECK-DAG: [[P2:%[^ ]+]] = s32[] parameter + ; CHECK: ROOT [[DUS:%[^ ]+]] = f32[4,2]{1,0} dynamic-update-slice([[P1]], [[DOT]], [[P2]], [[P2]]) + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[4,2]{1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_dynamic_update_slice, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{ + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: } + ; CHECK: } + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -130,7 +286,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmKernel) { arg1 = f32[784,10]{1,0} parameter(1) gemm = (f32[100,10]{1,0}, s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} ROOT get-tuple-element = f32[100,10]{1,0} get-tuple-element((f32[100,10]{1,0}, s8[0]{0}) gemm), index=0 })"; @@ -148,7 +304,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmKernel) { arg0 = f32[100,784]{1,0} parameter(0) arg1 = f32[784,10]{1,0} parameter(1) ROOT _ = f32[100,10]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, - backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}} + backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, @@ -169,8 +325,8 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { c1 = bf16[32,8]{1,0} convert(p1) gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1), custom_call_target="__cublas$gemm", - backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} - ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element((bf16[16,8]{1,0}, s8[0]{0}) gemm), index=0 + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element(gemm), index=0 })"; const char* hlo_text_custom_fusion = R"( @@ -188,11 +344,147 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { p0 = bf16[16,32]{1,0} parameter(0) p1 = s8[32,8]{1,0} parameter(1) ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, - backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast"}} + backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, error_spec, /*run_hlo_passes=*/false)); } +TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + + gemm.tuple = (bf16[8,8]{1,0}, s8[0]{0}) custom-call(p1, p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + gemm = bf16[8,8]{1,0} get-tuple-element(gemm.tuple), index=0 + cast = bf16[1,8,8]{2,1,0} bitcast(gemm) + + ROOT r = bf16[2,8,8]{2,1,0} dynamic-update-slice(p0, cast, p2, p3, p3) + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm { + p0.1 = bf16[8,8]{1,0} parameter(0) + p1.1 = bf16[2,8,8]{2,1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + dot.1 = bf16[8,8]{1,0} dot(p0.1, p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bc.1 = bf16[1,8,8]{2,1,0} bitcast(dot.1) + r.1 = bf16[2,8,8]{2,1,0} dynamic-update-slice(p1.1, bc.1, p2, p3, p3) + workspace = u8[1024]{0} custom-call(), + custom_call_target="__custom_kernel_fusion$workspace", + api_version=API_VERSION_TYPED_FFI + ROOT tuple = (bf16[2,8,8]{2,1,0}, u8[1024]{0}) tuple(r.1, workspace) + } + + ENTRY e { + p0 = bf16[2,8,8]{2,1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + r.0 = (bf16[2,8,8]{2,1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom, + calls=%cutlass_gemm, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}} + ROOT %get-tuple-element = bf16[2,8,8]{2,1,0} get-tuple-element(r.0), index=0 + })"; + + Array3D p0_arr(2, 8, 8); // bf16[2,8,8] + Array2D p1_arr(8, 8); // bf16[8,8] + p1_arr.Each([](int64_t i, int64_t j, bfloat16* out) { + *out = bfloat16{1.0f * i * j}; + }); + + Array p2_arr({}, 1); + Array p3_arr({}, 0); + + auto p0 = LiteralUtil::CreateFromArray(p0_arr); + auto p1 = LiteralUtil::CreateFromArray(p1_arr); + auto p2 = LiteralUtil::CreateFromArray(p2_arr); + auto p3 = LiteralUtil::CreateFromArray(p3_arr); + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + {&p0, &p1, &p2, &p3}, error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(CutlassFusionTest, + RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = bf16[16,8]{1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + + gemm.tuple = (bf16[8,8]{1,0}, s8[0]{0}) custom-call(p1, p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + gemm = bf16[8,8]{1,0} get-tuple-element(gemm.tuple), index=0 + + ROOT r = bf16[16,8]{1,0} dynamic-update-slice(p0, gemm, p2, p3) + } + )"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm { + p0.1 = bf16[8,8]{1,0} parameter(0) + p1.1 = bf16[16,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + dot.1 = bf16[8,8]{1,0} dot(p0.1, p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + r.1 = bf16[16,8]{1,0} dynamic-update-slice(p1.1, dot.1, p2, p3) + workspace = u8[1024]{0} custom-call(), + custom_call_target="__custom_kernel_fusion$workspace", + api_version=API_VERSION_TYPED_FFI + ROOT tuple = (bf16[16,8]{1,0}, u8[1024]{0}) tuple(r.1, workspace) + } + + ENTRY e { + p0 = bf16[16,8]{1,0} parameter(0) + p1 = bf16[8,8]{1,0} parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + r.0 = (bf16[16,8]{1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom, + calls=%cutlass_gemm, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}} + ROOT %get-tuple-element = bf16[16,8]{1,0} get-tuple-element(r.0), index=0 + })"; + + Array2D p0_arr(16, 8); // bf16[16,8] + Array2D p1_arr(8, 8); // bf16[8,8] + p1_arr.Each([](int64_t i, int64_t j, bfloat16* out) { + *out = bfloat16{1.0f * i * j}; + }); + + Array p2_arr({}, 0); + Array p3_arr({}, 1); + + auto p0 = LiteralUtil::CreateFromArray(p0_arr); + auto p1 = LiteralUtil::CreateFromArray(p1_arr); + auto p2 = LiteralUtil::CreateFromArray(p2_arr); + auto p3 = LiteralUtil::CreateFromArray(p3_arr); + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + {&p0, &p1, &p2, &p3}, error_spec, + /*run_hlo_passes=*/false)); +} + } // namespace xla::gpu diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc deleted file mode 100644 index a84686a4617b3..0000000000000 --- a/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" - -#include -#include - -#include "absl/status/status.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/cutlass_gemm_universal.cu.h" -#include "xla/statusor.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu::kernel { - -using F32xF32toF32 = - cutlass::gemm::device::GemmUniversal; - -//===----------------------------------------------------------------------===// -// Adaptor from a CUTLASS GemmUniversal to a CustomKernel. -//===----------------------------------------------------------------------===// - -template -StatusOr LoadCutlassGemmUniversal(int32_t m, int32_t n, - int32_t k) { - using Kernel = typename Gemm::GemmKernel; - - cutlass::gemm::GemmCoord problem_size = {m, n, k}; - - // TODO(ezhulenev): We should generate more descriptive names for custom - // kernels, i.e. include tile and dimensions sizes, dtypes, etc. - se::MultiKernelLoaderSpec spec( - /*arity=*/1, gemm_universal::ArgsPacking(problem_size)); - spec.AddInProcessSymbol(reinterpret_cast(cutlass::Kernel2), - "cutlass_universal_gemm"); - - return CustomKernel("cutlass_gemm:f32<-f32xf32", std::move(spec), - gemm_universal::BlockDim(problem_size), - gemm_universal::ThreadDim(), - sizeof(typename Kernel::SharedStorage)); -} - -StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, - int32_t n, int32_t k) { - if (dtype != PrimitiveType::F32) - return absl::InvalidArgumentError( - "Currently cutlass gemm kernel supports only F32 data type"); - - return LoadCutlassGemmUniversal(m, n, k); -} - -} // namespace xla::gpu::kernel diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel.h b/xla/service/gpu/kernels/cutlass_gemm_kernel.h deleted file mode 100644 index 41cf68e5619a3..0000000000000 --- a/xla/service/gpu/kernels/cutlass_gemm_kernel.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ -#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ - -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu::kernel { - -// A reference implementation GEMM kernel written in CUTLASS based on -// `00_basic_gemm` example. -StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, - int32_t n, int32_t k); - -} // namespace xla::gpu::kernel - -#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_KERNEL_H_ diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc new file mode 100644 index 0000000000000..dfb062d7c0f76 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc @@ -0,0 +1,33 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/gemm/device/gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + cutlass::bfloat16_t, cutlass::layout::RowMajor, // A + cutlass::bfloat16_t, cutlass::layout::RowMajor, // B + cutlass::bfloat16_t, cutlass::layout::RowMajor, // C + float>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToBf16, + GemmOperation); + +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc new file mode 100644 index 0000000000000..6a2c74fa1972c --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc @@ -0,0 +1,39 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/gemm/device/gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + cutlass::bfloat16_t, cutlass::layout::RowMajor, // A + cutlass::bfloat16_t, cutlass::layout::RowMajor, // B + cutlass::bfloat16_t, cutlass::layout::RowMajor, // C + float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, // ThreadblockShape + cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape + cutlass::gemm::GemmShape<16, 8, 16>, // InstructionShape + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + /*Stages=*/3, /*AlignmentA=*/8, /*AlignmentB=*/8>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToBf16, GemmOperation); + +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc new file mode 100644 index 0000000000000..1713c6ce6139f --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc @@ -0,0 +1,60 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +// CUTLASS headers must be included after adaptor +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +// Custom epilogue must be included after CUTLASS headers +#include "xla/service/gpu/kernels/cutlass_gemm_epilogue.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +using EpilogueLoop = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, float, float, + cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, cutlass::bfloat16_t, + cutlass::layout::RowMajor, 8, cutlass::epilogue::NoSmemWarpSpecialized, + LinearCombinationWithDynamicSlice>::CollectiveOp; + +using MainLoop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::bfloat16_t, + cutlass::layout::RowMajor, 8, cutlass::bfloat16_t, + cutlass::layout::RowMajor, 8, float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + +using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + MainLoop, EpilogueLoop, + cutlass::gemm::StreamKScheduler>; + +using GemmOperation = cutlass::gemm::device::GemmUniversalAdapter; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToBf16, GemmOperation); + +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc new file mode 100644 index 0000000000000..5aff534c351ae --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc @@ -0,0 +1,79 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/service/gpu/kernels/cutlass_gemm.h" + +namespace xla::gpu::kernel::gemm_universal { + +using CutlassGemm = F32xF32ToF32; + +extern template struct Adaptor; +extern template struct DeviceKernel; + +extern "C" void xla_cutlass_kernel_block_dim(int32_t m, int32_t n, int32_t k, + uint32_t* x, uint32_t* y, + uint32_t* z) { + Adaptor adaptor; + auto dim = adaptor.BlockDim(m, n, k); + *x = dim.x; + *y = dim.y; + *z = dim.z; +} + +extern "C" void xla_cutlass_kernel_thread_dim(uint32_t* x, uint32_t* y, + uint32_t* z) { + Adaptor adaptor; + auto dim = adaptor.ThreadDim(); + *x = dim.x; + *y = dim.y; + *z = dim.z; +} + +extern "C" int32_t xla_cutlass_kernel_shared_memory_bytes() { + Adaptor adaptor; + return adaptor.SharedMemoryBytes(); +} + +extern "C" bool xla_cutlass_kernel_can_implement(int32_t m, int32_t n, + int32_t k) { + Adaptor adaptor; + Arguments arguments = {m, n, k}; + return adaptor.CanImplement(arguments); +} + +extern "C" int64_t xla_cutlass_kernel_workspace_size(int32_t m, int32_t n, + int32_t k) { + Adaptor adaptor; + Arguments arguments = {m, n, k}; + return adaptor.WorkspaceSize(arguments); +} + +extern "C" void xla_cutlass_kernel_initialize( + void* params, int32_t m, int32_t n, int32_t k, void* lhs, void* rhs, + void* out, void* workspace, int32_t* out_offset, int32_t device_sms, + int32_t sm_occupancy) { + Adaptor adaptor; + Arguments arguments = {m, n, k, lhs, rhs, out, workspace, {out_offset}}; + adaptor.Initialize(params, arguments, device_sms, sm_occupancy); +} + +extern "C" void* xla_cutlass_kernel_symbol() { + DeviceKernel kernel; + return kernel.symbol(); +} + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cu.cc b/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cu.cc new file mode 100644 index 0000000000000..0a7cce0c5d8f7 --- /dev/null +++ b/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cu.cc @@ -0,0 +1,31 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/gemm/device/gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +using GemmOperation = + cutlass::gemm::device::GemmUniversal; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(F32xF32ToF32, + GemmOperation); + +template struct Adaptor>; +template struct DeviceKernel>; +} // namespace xla::gpu::kernel::gemm_universal diff --git a/xla/service/gpu/kernels/cutlass_gemm_test.cc b/xla/service/gpu/kernels/cutlass_gemm_test.cc deleted file mode 100644 index 5d73ffd7be2b0..0000000000000 --- a/xla/service/gpu/kernels/cutlass_gemm_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" - -namespace xla::gpu::kernel { - -TEST(CutlassGemmKernelTest, SimpleGemm) { - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); - se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - se::Kernel gemm(executor); - - // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. - auto custom_kernel = GetCutlassGemmKernel(PrimitiveType::F32, 4, 4, 4); - TF_ASSERT_OK(executor->GetKernel(custom_kernel->kernel_spec(), &gemm)); - - int64_t length = 4 * 4; - int64_t byte_length = sizeof(float) * length; - - // Prepare arguments: a=2, b=2, c=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - se::DeviceMemory c = executor->AllocateArray(length, 0); - - float value = 2.0; - uint32_t pattern; - std::memcpy(&pattern, &value, sizeof(pattern)); - - stream.ThenMemset32(&a, pattern, byte_length); - stream.ThenMemset32(&b, pattern, byte_length); - stream.ThenMemZero(&c, byte_length); - - // Launch gemm kernel with device memory arguments. - se::KernelArgsDeviceMemoryArray arr( - std::vector({a, b, c}), - custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(executor->Launch(&stream, custom_kernel->thread_dims(), - custom_kernel->block_dims(), gemm, arr)); - - // Copy `c` data back to host. - std::vector dst(length, -1.0f); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected(length, 16.0); - ASSERT_EQ(dst, expected); -} - -} // namespace xla::gpu::kernel diff --git a/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h b/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h deleted file mode 100644 index d5758dbac810a..0000000000000 --- a/xla/service/gpu/kernels/cutlass_gemm_universal.cu.h +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ -#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "third_party/gpus/cutlass/include/cutlass/cutlass.h" -#include "third_party/gpus/cutlass/include/cutlass/gemm/device/gemm_universal.h" -#include "third_party/gpus/cutlass/include/cutlass/gemm/gemm_enumerated_types.h" -#include "third_party/gpus/cutlass/include/cutlass/gemm_coord.h" -#include "third_party/gpus/cutlass/include/cutlass/layout/matrix.h" -#include "xla/statusor.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" - -namespace xla::gpu::kernel::gemm_universal { - -// This is a template library that implements an adaptor from a CUTLASS -// GemmUniversal kernel to StreamExecutor primitives for kernel arguments -// packing and kernel launching. -// -// In all templates defined below `typename Gemm` should be a -// an instance of `cutlass::gemm::device::GemmUniversal` template. - -namespace se = ::stream_executor; - -//===----------------------------------------------------------------------===// -// Gemm launch dimension computation. -//===----------------------------------------------------------------------===// - -template -se::ThreadDim ThreadDim() { - using Kernel = typename Gemm::GemmKernel; - return se::ThreadDim(Kernel::kThreadCount, 1, 1); -} - -template -se::BlockDim BlockDim(const cutlass::gemm::GemmCoord &problem_size) { - using ThreadblockSwizzle = typename Gemm::ThreadblockSwizzle; - using ThreadblockShape = typename Gemm::ThreadblockShape; - - cutlass::gemm::GemmCoord tile_size = { - ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}; - - cutlass::gemm::GemmCoord grid_tiled_shape = - ThreadblockSwizzle::get_tiled_shape(problem_size, tile_size, - /*split_k_slices=*/1); - - auto grid = ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); - - return se::BlockDim(grid.x, grid.y, grid.z); -} - -//===----------------------------------------------------------------------===// -// Gemm strides computation. -//===----------------------------------------------------------------------===// - -template -int64_t LdA(const cutlass::gemm::GemmCoord &problem_size) { - using LayoutA = typename Gemm::LayoutA; - - if constexpr (std::is_same_v) { - return problem_size.k(); - } else { - static_assert(sizeof(Gemm) == 0, "unsupported layout type"); - } -} - -template -int64_t LdB(const cutlass::gemm::GemmCoord &problem_size) { - using LayoutB = typename Gemm::LayoutB; - - if constexpr (std::is_same_v) { - return problem_size.n(); - } else { - static_assert(sizeof(Gemm) == 0, "unsupported layout type"); - } -} - -template -int64_t LdC(const cutlass::gemm::GemmCoord &problem_size) { - using LayoutC = typename Gemm::LayoutA; - - if constexpr (std::is_same_v) { - return problem_size.n(); - } else { - static_assert(sizeof(Gemm) == 0, "unsupported layout type"); - } -} - -//===----------------------------------------------------------------------===// -// Packing kernel arguments to CUTLASS kernel parameters struct. -//===----------------------------------------------------------------------===// - -using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; - -template -auto *DevicePtr(const se::KernelArgsDeviceMemoryArray *args) { - const void *opaque = args->device_memory_ptr(index); - - if constexpr (index == 0) { - return static_cast(const_cast(opaque)); - } else if constexpr (index == 1) { - return static_cast(const_cast(opaque)); - } else if constexpr (index == 2) { - return static_cast(const_cast(opaque)); - } else { - static_assert(sizeof(Gemm) == 0, "illegal Gemm argument index"); - } -} - -template -KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size) { - using Arguments = typename Gemm::Arguments; - using Kernel = typename Gemm::GemmKernel; - using Params = typename Kernel::Params; - - // Sanity check that we do not accidentally get a giant parameters struct. - static_assert(sizeof(Params) < 512, - "Params struct size is unexpectedly large"); - - using PackedArgs = StatusOr>; - - return [=](const se::KernelArgs &args) -> PackedArgs { - auto *mem_args = Cast(&args); - - cutlass::Status can_implement = Kernel::can_implement(problem_size); - if (can_implement != cutlass::Status::kSuccess) { - return absl::InternalError(absl::StrCat( - "CUTLASS kernel can not implement gemm for a given problem size", - ": m=", problem_size.m(), ", n=", problem_size.n(), - ", k=", problem_size.k())); - } - - auto lda = LdA(problem_size); - auto ldb = LdB(problem_size); - auto ldc = LdC(problem_size); - - auto ptr_a = DevicePtr(mem_args); - auto ptr_b = DevicePtr(mem_args); - auto ptr_c = DevicePtr(mem_args); - - auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - float alpha = 1.0, beta = 0.0; - - // CUTLASS operation arguments. - Arguments arguments(mode, problem_size, - 1, // batch - {alpha, beta}, // epilogue - ptr_a, ptr_b, ptr_c, ptr_c, // pointers - 0, 0, 0, 0, // batch strides - lda, ldb, ldc, ldc // strides - ); - - // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate - // correct kernel occupancy using GpuRuntime. - Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10); - - return se::PackKernelArgs(args.number_of_shared_bytes(), params); - }; -} - -} // namespace xla::gpu::kernel::gemm_universal - -#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_UNIVERSAL_CU_H_ diff --git a/xla/service/gpu/kernels/topk_custom_kernel.cc b/xla/service/gpu/kernels/topk_custom_kernel.cc new file mode 100644 index 0000000000000..2be74bf301ee1 --- /dev/null +++ b/xla/service/gpu/kernels/topk_custom_kernel.cc @@ -0,0 +1,150 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/topk_custom_kernel.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace xla::gpu::kernel::topk { + +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) + +namespace { + +using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; + +// The optimal number of threads is the smaller value between the number of +// threads available per block and the number of slices of data. +size_t EstimateOptimalNumThreads(size_t n, size_t k, size_t batch_size) { + // Estimate number of threads per block that can run concurrently given the + // register footprint (k elements are kept in registers at all times). + constexpr size_t kEstimatedThreadsPerBlock = 512; + constexpr size_t kMaxKValue = 16; + size_t simultaneous_threads_per_block = + kEstimatedThreadsPerBlock * (kMaxKValue / k); + size_t threads_per_block = + std::min(simultaneous_threads_per_block, kTopKMaxThreadsPerBlock); + // Minimum amount of data that each thread needs to receive for the algorithm. + size_t min_slice = absl::bit_floor(n / absl::bit_ceil(k)); + return std::min(threads_per_block, min_slice); +} + +// Gets the right version of TopK kernel based on the value of `k`. +template +absl::StatusOr GetKernel(int n, int k) { + if (k <= 1) return GetTopKKernelForK(n); + if (k <= 2) return GetTopKKernelForK(n); + if (k <= 4) return GetTopKKernelForK(n); + if (k <= 8) return GetTopKKernelForK(n); + if (k <= 16) return GetTopKKernelForK(n); + return absl::UnimplementedError(absl::StrCat("Unsupported K: ", k)); +} + +// Returns the function creating packed arguments for TopK kernel. +template +KernelArgsPacking CreateTopKArgsPacking(size_t num_elements, size_t k) { + using Packed = absl::StatusOr>; + + return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { + auto* mem_args = se::Cast(&args); + + se::DeviceMemory data(mem_args->device_memory_args()[0]); + se::DeviceMemory top_elements(mem_args->device_memory_args()[1]); + se::DeviceMemory top_indices(mem_args->device_memory_args()[2]); + + return se::PackKernelArgs(args.number_of_shared_bytes(), data, num_elements, + top_elements, top_indices, k); + }; +} + +// Implementation for creating a CustomKernel for TopK operation with element +// type `T`. +template +absl::StatusOr GetTypedTopK(std::string name, size_t num_elements, + size_t k, size_t batch_size) { + constexpr size_t kMaxKVSize = sizeof(uint64_t); + // Allocate shmem assuming we have a full reduction. + int shmem_size = absl::bit_ceil(k) * kMaxKVSize * GetTopKWaveFrontSize(); + int num_threads = EstimateOptimalNumThreads(num_elements, k, batch_size); + if (num_threads == 0) { + return absl::FailedPreconditionError( + "Invalid kernel parameters. This is likely a bug in the " + "TopkSpecializer."); + } + + auto packing = CreateTopKArgsPacking(num_elements, k); + + se::MultiKernelLoaderSpec spec(/*arity=*/5, std::move(packing)); + TF_ASSIGN_OR_RETURN(void* kernel_symbol, GetKernel(num_elements, k)); + spec.AddInProcessSymbol(kernel_symbol, name); + + return CustomKernel(std::move(name), std::move(spec), + se::BlockDim(batch_size, 1, 1), + se::ThreadDim(num_threads, 1, 1), shmem_size); +} + +} // namespace + +absl::StatusOr GetTopKKernel(std::string name, + PrimitiveType dtype, + size_t num_elements, size_t k, + size_t batch_size) { + switch (dtype) { + case PrimitiveType::F32: + return GetTypedTopK(std::move(name), num_elements, k, batch_size); + case PrimitiveType::BF16: + return GetTypedTopK(std::move(name), num_elements, k, + batch_size); + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported GpuTopK data type: ", dtype)); + } +} + +#else + +// Fallback implementation of creating a CustomKernel for TopK operation. +absl::StatusOr GetTopKKernel(std::string name, + PrimitiveType dtype, + size_t num_elements, size_t k, + size_t batch_size) { + return absl::InternalError("XLA compiled without CUDA support"); +} + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace xla::gpu::kernel::topk diff --git a/xla/service/gpu/kernels/topk_custom_kernel.h b/xla/service/gpu/kernels/topk_custom_kernel.h new file mode 100644 index 0000000000000..715f92f7701e7 --- /dev/null +++ b/xla/service/gpu/kernels/topk_custom_kernel.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_CUSTOM_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_CUSTOM_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu::kernel::topk { + +// Creates a CustomKernel for TopK operation. +absl::StatusOr GetTopKKernel(std::string name, + PrimitiveType dtype, + size_t num_elements, size_t k, + size_t batch_size); + +} // namespace xla::gpu::kernel::topk + +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_CUSTOM_KERNEL_H_ diff --git a/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/xla/service/gpu/kernels/topk_custom_kernel_test.cc new file mode 100644 index 0000000000000..d50fd054df2ac --- /dev/null +++ b/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/topk_custom_kernel.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "absl/random/random.h" +#include "absl/strings/ascii.h" +#include "absl/strings/substitute.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu::kernel::topk { + +using ::testing::Combine; +using ::testing::Values; + +template +std::vector RandomVecRange(int num_elements, T start, T end) { + std::vector local; + local.reserve(num_elements); + thread_local absl::BitGen gen; + for (int i = 0; i < num_elements; ++i) { + local.push_back(absl::Uniform(gen, start, end)); + } + return local; +} + +template +std::vector RandomVec(int num_elements) { + return RandomVecRange(num_elements, static_cast(0), + static_cast(num_elements)); +} + +template +std::vector RandomVecNegative(int num_elements) { + return RandomVecRange(num_elements, -static_cast(num_elements), + static_cast(0)); +} + +PrimitiveType Get(float) { return PrimitiveType::F32; } + +PrimitiveType Get(bfloat16) { return PrimitiveType::BF16; } + +// Params: +// - n_kb: number of elements in kilobytes. +// - k: number of elements to return. +// - batch_size +// - offset +using TopKKernelTest = ::testing::TestWithParam>; + +// In this test we only check that the TopK logic works with float. For the full +// dtype coverage suite, please add them to topk_test.cc, where we can use XLA +// utilities to simplify the test logic. +TEST_P(TopKKernelTest, TopKFloat) { + using T = float; + + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + se::Platform* platform = se::PlatformManager::PlatformWithName(name).value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + auto stream = executor->CreateStream().value(); + + const auto [n_kb, k, batch_size, offset] = GetParam(); + const size_t n = n_kb * 1024 + offset; + + se::DeviceMemory input_buffer = + executor->AllocateArray(n * batch_size, 0); + se::DeviceMemory output_values = + executor->AllocateArray(k * batch_size, 0); + se::DeviceMemory output_indices = + executor->AllocateArray(k * batch_size, 0); + + auto source = RandomVec(n * batch_size); + TF_ASSERT_OK( + stream->Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream->MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream->MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); + + auto custom_kernel = + GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); + + TF_ASSERT_OK_AND_ASSIGN( + auto kernel, se::Kernel::Create(executor, custom_kernel->kernel_spec())); + + // Launch topk kernel with device memory arguments. + se::KernelArgsDeviceMemoryArray arr( + std::vector( + {input_buffer, output_values, output_indices}), + custom_kernel->shared_memory_bytes()); + TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), + custom_kernel->block_dims(), *kernel, arr)); + + std::vector got(k); + ASSERT_TRUE(stream->BlockHostUntilDone().ok()); + for (int i = 0; i < batch_size; i++) { + TF_ASSERT_OK(stream->Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); + std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); + std::sort(slice.begin(), slice.end(), std::greater()); + slice.resize(k); + EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; + } +} + +TEST_P(TopKKernelTest, TopKPackedNegative) { + using T = float; + + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + se::Platform* platform = se::PlatformManager::PlatformWithName(name).value(); + se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + auto stream = executor->CreateStream().value(); + + const auto [n_kb, k, batch_size, offset] = GetParam(); + const size_t n = n_kb * 1024 + offset; + + se::DeviceMemory input_buffer = + executor->AllocateArray(n * batch_size, 0); + se::DeviceMemory output_values = + executor->AllocateArray(k * batch_size, 0); + se::DeviceMemory output_indices = + executor->AllocateArray(k * batch_size, 0); + + auto source = RandomVecNegative(n * batch_size); + TF_ASSERT_OK( + stream->Memcpy(&input_buffer, source.data(), n * batch_size * sizeof(T))); + TF_ASSERT_OK(stream->MemZero(&output_values, k * batch_size * sizeof(T))); + TF_ASSERT_OK( + stream->MemZero(&output_indices, k * batch_size * sizeof(uint32_t))); + + auto custom_kernel = + GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); + + TF_ASSERT_OK_AND_ASSIGN( + auto kernel, se::Kernel::Create(executor, custom_kernel->kernel_spec())); + + // Launch topk kernel with device memory arguments. + se::KernelArgsDeviceMemoryArray arr( + std::vector( + {input_buffer, output_values, output_indices}), + custom_kernel->shared_memory_bytes()); + TF_ASSERT_OK(executor->Launch(stream.get(), custom_kernel->thread_dims(), + custom_kernel->block_dims(), *kernel, arr)); + + std::vector got(k); + ASSERT_TRUE(stream->BlockHostUntilDone().ok()); + for (int i = 0; i < batch_size; i++) { + TF_ASSERT_OK(stream->Memcpy(got.data(), output_values.GetSlice(k * i, k), + k * sizeof(T))); + std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); + std::sort(slice.begin(), slice.end(), std::greater()); + slice.resize(k); + EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; + } +} + +INSTANTIATE_TEST_SUITE_P(TopKTests, TopKKernelTest, + Combine( + /*n_kb=*/Values(1, 8, 12, 64, 128), + /*k=*/Values(1, 2, 8, 16, 7, 12), + /*batch_size=*/Values(1, 16, 64, 128), + /*offset=*/Values(0, 7, 4)), + [](const auto& info) { + return absl::Substitute( + "n$0KiB_k$1_batch_size$2_offset$3", + std::get<0>(info.param), std::get<1>(info.param), + std::get<2>(info.param), + std::get<3>(info.param)); + }); + +} // namespace xla::gpu::kernel::topk diff --git a/xla/service/gpu/runtime/topk_kernel.cc b/xla/service/gpu/kernels/topk_kernel.cc similarity index 82% rename from xla/service/gpu/runtime/topk_kernel.cc rename to xla/service/gpu/kernels/topk_kernel.cc index 7a65107ea9e6e..1261b5ea23402 100644 --- a/xla/service/gpu/runtime/topk_kernel.cc +++ b/xla/service/gpu/kernels/topk_kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,32 +14,32 @@ limitations under the License. ==============================================================================*/ // This file contains bespoke and optimized implementation for TopK shapes. When -// adding support for new shapes/dtypes, you also need to modify the rewritter +// adding support for new shapes/dtypes, you also need to modify the rewriter // on topk_specializer.cc for these changes to be picked up. -#include "xla/service/gpu/runtime/topk_kernel.h" +#include "xla/service/gpu/kernels/topk_kernel.h" #include +#include #include #include "absl/numeric/bits.h" #include "absl/status/status.h" -#include "Eigen/Core" // from @eigen_archive +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/topk_kernel_common.h" -#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - namespace { -using se::gpu::GpuStreamHandle; - size_t NumThreads(size_t n, size_t k, size_t batch_size) { // Estimate number of threads per block that can run concurrently given the // register footprint. @@ -68,7 +68,7 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data, size_t batch_size) { constexpr size_t max_kv_size = sizeof(uint64_t); // Allocate shmem assuming we have a full reduction. - int shmem_size = absl::bit_ceil(k) * max_kv_size * WAVEFRONT_SIZE; + int shmem_size = absl::bit_ceil(k) * max_kv_size * GetTopKWaveFrontSize(); int num_threads = NumThreads(num_elements, k, batch_size); if (num_threads == 0) { return absl::FailedPreconditionError( @@ -83,14 +83,13 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data, TF_ASSIGN_OR_RETURN(void* kernel_symbol, GetKernel(num_elements, k)); TF_ASSIGN_OR_RETURN( auto kernel, - (executor - ->CreateTypedKernel, size_t, se::DeviceMemory, - se::DeviceMemory, size_t>( - "topk", kernel_symbol))); + (se::TypedKernel, size_t, se::DeviceMemory, + se::DeviceMemory, + size_t>::Create(executor, "topk", kernel_symbol))); TF_RETURN_IF_ERROR(stream->ThenLaunch( se::ThreadDim(num_threads, 1, 1), se::BlockDim(batch_size, 1, 1), - shmem_size, *kernel, data_typed, num_elements, top_elements_typed, + shmem_size, kernel, data_typed, num_elements, top_elements_typed, top_indices_typed, k)); return absl::OkStatus(); @@ -110,8 +109,8 @@ absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, return TypedTopK(stream, data, num_elements, top_elements, top_indices, k, batch_size); case PrimitiveType::BF16: - return TypedTopK( - stream, data, num_elements, top_elements, top_indices, k, batch_size); + return TypedTopK(stream, data, num_elements, top_elements, + top_indices, k, batch_size); default: return absl::UnimplementedError("GpuTopK not implemented for this dtype"); } diff --git a/xla/service/gpu/runtime/topk_kernel.cu.h b/xla/service/gpu/kernels/topk_kernel.cu.h similarity index 80% rename from xla/service/gpu/runtime/topk_kernel.cu.h rename to xla/service/gpu/kernels/topk_kernel.cu.h index d34ec1e870da7..44b11394ece3e 100644 --- a/xla/service/gpu/runtime/topk_kernel.cu.h +++ b/xla/service/gpu/kernels/topk_kernel.cu.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,77 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ // This file contains bespoke and optimized implementation for TopK shapes. When -// adding support for new shapes/dtypes, you also need to modify the rewritter +// adding support for new shapes/dtypes, you also need to modify the rewriter // on topk_specializer.cc for these changes to be picked up. #include #include #include -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/topk_kernel_common.h" +#include "xla/service/gpu/kernels/topk_kernel_common.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "tsl/lib/math/math_util.h" + +#if GOOGLE_CUDA + +#define WAVEFRONT_SIZE 32 +#define FORCEINLINE __forceinline__ + +#elif TENSORFLOW_USE_ROCM // GOOGLE_CUDA + +#ifdef __AMDGCN_WAVEFRONT_SIZE +#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE +#else +#define WAVEFRONT_SIZE 64 +#endif +#define FORCEINLINE __forceinline__ + +#endif // TENSORFLOW_USE_ROCM namespace xla::gpu { +enum class ShflType { kSync, kUp, kDown, kXor }; + +template +__device__ FORCEINLINE NT GpuShuffle(NT val, uint32_t idx, + uint32_t allmsk = 0xffffffffu) { + constexpr uint32_t SZ = + tsl::MathUtil::CeilOfRatio(sizeof(NT), sizeof(uint32_t)); + union S { + NT v; + uint32_t d[SZ]; + }; + S in{val}, res{}; + +#pragma unroll + for (uint32_t i = 0; i < SZ; i++) { +#if GOOGLE_CUDA + if constexpr (Type == ShflType::kSync) + res.d[i] = __shfl_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kUp) + res.d[i] = __shfl_up_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kDown) + res.d[i] = __shfl_down_sync(allmsk, in.d[i], idx); + else if constexpr (Type == ShflType::kXor) + res.d[i] = __shfl_xor_sync(allmsk, in.d[i], idx); +#elif TENSORFLOW_USE_ROCM // ROcm does not support sync shuffle intrinsics + if constexpr (Type == ShflType::kSync) + res.d[i] = __shfl(in.d[i], idx); + else if constexpr (Type == ShflType::kUp) + res.d[i] = __shfl_up(in.d[i], idx); + else if constexpr (Type == ShflType::kDown) + res.d[i] = __shfl_down(in.d[i], idx); + else if constexpr (Type == ShflType::kXor) + res.d[i] = __shfl_xor(in.d[i], idx); +#endif + } + return res.v; +} + // Default implementation for KV holder. Useful for testing while adding support // for a new type, but generally bitpacking those values is more efficient. See // implementations below. @@ -191,7 +246,7 @@ struct TopK { for (int offset = num_lanes / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < K; i++) { - KVT kv = GpuShuffle(tmp[i], offset); + KVT kv = GpuShuffle(tmp[i], offset); if (lane_id >= offset) continue; Push(tmp, kv); } @@ -245,12 +300,17 @@ __launch_bounds__(kTopKMaxThreadsPerBlock, 1) __global__ template void* GetTopKKernelForK(int n) { // TODO(doak): Switch to uint32_t if we don't have an efficient - // implemementation for uint16_t. + // implementation for uint16_t. return n < std::numeric_limits::max() ? reinterpret_cast(&Run) : reinterpret_cast(&Run); } +template +int32_t GetTopKWaveFrontSize() { + return WAVEFRONT_SIZE; +} + } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_CU_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_CU_H_ diff --git a/xla/service/gpu/runtime/topk_kernel.h b/xla/service/gpu/kernels/topk_kernel.h similarity index 80% rename from xla/service/gpu/runtime/topk_kernel.h rename to xla/service/gpu/kernels/topk_kernel.h index 5be0d121d4f31..8e15483f8c066 100644 --- a/xla/service/gpu/runtime/topk_kernel.h +++ b/xla/service/gpu/kernels/topk_kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ #include -#include #include "absl/status/status.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" +#include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -43,4 +40,4 @@ absl::Status RunTopk(se::Stream* stream, PrimitiveType dtype, } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_H_ diff --git a/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc b/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc new file mode 100644 index 0000000000000..c0e47295a18d0 --- /dev/null +++ b/xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc @@ -0,0 +1,29 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/topk_kernel.cu.h" +#include "xla/types.h" + +namespace xla::gpu { + +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); +template void* GetTopKKernelForK(int n); + +template int32_t GetTopKWaveFrontSize(); + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/topk_kernel_common.h b/xla/service/gpu/kernels/topk_kernel_common.h similarity index 80% rename from xla/service/gpu/runtime/topk_kernel_common.h rename to xla/service/gpu/kernels/topk_kernel_common.h index fdb04623d6f28..5ddd9ed513d9a 100644 --- a/xla/service/gpu/runtime/topk_kernel_common.h +++ b/xla/service/gpu/kernels/topk_kernel_common.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ +#ifndef XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ +#define XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ #include @@ -31,6 +31,9 @@ static constexpr size_t kTopKMaxThreadsPerBlock = 1024; template void* GetTopKKernelForK(int n); +template +int32_t GetTopKWaveFrontSize(); + } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_KERNEL_COMMON_H_ +#endif // XLA_SERVICE_GPU_KERNELS_TOPK_KERNEL_COMMON_H_ diff --git a/xla/service/gpu/runtime/topk_kernel_float.cu.cc b/xla/service/gpu/kernels/topk_kernel_float.cu.cc similarity index 86% rename from xla/service/gpu/runtime/topk_kernel_float.cu.cc rename to xla/service/gpu/kernels/topk_kernel_float.cu.cc index caf585a4201e0..b7b7823a4dff2 100644 --- a/xla/service/gpu/runtime/topk_kernel_float.cu.cc +++ b/xla/service/gpu/kernels/topk_kernel_float.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/topk_kernel.cu.h" +#include "xla/service/gpu/kernels/topk_kernel.cu.h" namespace xla::gpu { @@ -23,4 +23,6 @@ template void* GetTopKKernelForK(int n); template void* GetTopKKernelForK(int n); template void* GetTopKKernelForK(int n); +template int32_t GetTopKWaveFrontSize(); + } // namespace xla::gpu diff --git a/xla/service/gpu/kernels/topk_kernel_test.cc b/xla/service/gpu/kernels/topk_kernel_test.cc new file mode 100644 index 0000000000000..29738ef144744 --- /dev/null +++ b/xla/service/gpu/kernels/topk_kernel_test.cc @@ -0,0 +1,233 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/topk_kernel.h" + +#include +#include + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "absl/strings/substitute.h" +#include "absl/time/time.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_timer.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::gpu { +namespace { + +using se::gpu::GpuStreamHandle; +using ::testing::Combine; +using ::testing::Values; + +template +std::vector RandomVecRange(int num_elements, T start, T end) { + std::vector local; + local.reserve(num_elements); + thread_local absl::BitGen gen; + for (int i = 0; i < num_elements; ++i) { + local.push_back(absl::Uniform(gen, start, end)); + } + return local; +} + +template +std::vector RandomVec(int num_elements) { + return RandomVecRange(num_elements, static_cast(0), + static_cast(num_elements)); +} + +template +std::vector RandomVecNegative(int num_elements) { + return RandomVecRange(num_elements, -static_cast(num_elements), + static_cast(0)); +} + +PrimitiveType Get(float) { return PrimitiveType::F32; } +PrimitiveType Get(bfloat16) { return PrimitiveType::BF16; } + +se::StreamExecutor* GetGpuExecutor() { + auto* platform = + se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); + return platform->ExecutorForDevice(0).value(); +} + +// Params: +// - n_kb: number of elements in kilobytes. +// - k: number of elements to return. +// - batch_size +// - offset +using TopkTest = ::testing::TestWithParam>; + +// In this test we only check that the TopK logic works with float. For the full +// dtype coverage suite, please add them to topk_test.cc, where we can use XLA +// utilities to simplify the test logic. +TEST_P(TopkTest, TopKFloat) { + using T = float; + + auto* executor = GetGpuExecutor(); + auto stream = executor->CreateStream().value(); + + const auto [n_kb, k, batch_size, offset] = GetParam(); + const size_t n = n_kb * 1024 + offset; + + auto input_buffer = executor->AllocateOwnedArray(n * batch_size), + output_values = executor->AllocateOwnedArray(k * batch_size); + auto output_indices = executor->AllocateOwnedArray(k * batch_size); + + ASSERT_TRUE(!(input_buffer.is_null() || output_values.is_null() || + output_indices.is_null())); + + auto source = RandomVec(n * batch_size); + CHECK_OK(stream->Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); + + ASSERT_TRUE(RunTopk(stream.get(), Get(T()), *input_buffer, n, *output_values, + *output_indices, k, batch_size) + .ok()); + std::vector got(k); + ASSERT_TRUE(stream->BlockHostUntilDone().ok()); + for (int i = 0; i < batch_size; i++) { + CHECK_OK(stream->Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); + std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); + std::sort(slice.begin(), slice.end(), std::greater()); + slice.resize(k); + EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; + } +} + +TEST_P(TopkTest, TopKPackedNegative) { + using T = float; + + auto* executor = GetGpuExecutor(); + auto stream = executor->CreateStream().value(); + + const auto [n_kb, k, batch_size, offset] = GetParam(); + const size_t n = n_kb * 1024 + offset; + + auto input_buffer = executor->AllocateOwnedArray(n * batch_size), + output_values = executor->AllocateOwnedArray(k * batch_size); + auto output_indices = executor->AllocateOwnedArray(k * batch_size); + + ASSERT_TRUE(!(input_buffer.is_null() || output_values.is_null() || + output_indices.is_null())); + + auto source = RandomVecNegative(n * batch_size); + CHECK_OK(stream->Memcpy(input_buffer.ptr(), source.data(), + n * batch_size * sizeof(T))); + + ASSERT_TRUE(RunTopk(stream.get(), Get(T()), *input_buffer, n, *output_values, + *output_indices, k, batch_size) + .ok()); + std::vector got(k); + ASSERT_TRUE(stream->BlockHostUntilDone().ok()); + for (int i = 0; i < batch_size; i++) { + CHECK_OK(stream->Memcpy(got.data(), output_values->GetSlice(k * i, k), + k * sizeof(T))); + std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); + std::sort(slice.begin(), slice.end(), std::greater()); + slice.resize(k); + EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) + << " k=" << k << ", batch_size=" << batch_size << " i=" << i; + } +} + +INSTANTIATE_TEST_SUITE_P(TopkTests, TopkTest, + Combine( + /*n_kb=*/Values(1, 8, 12, 64, 128), + /*k=*/Values(1, 2, 8, 16, 7, 12), + /*batch_size=*/Values(1, 16, 64, 128), + /*offset=*/Values(0, 7, 4)), + [](const auto& info) { + return absl::Substitute( + "n$0KiB_k$1_batch_size$2_offset$3", + std::get<0>(info.param), std::get<1>(info.param), + std::get<2>(info.param), + std::get<3>(info.param)); + }); + +template +void BM_SmallTopk(benchmark::State& state) { + using T = float; + + size_t k = K; + size_t batch_size = state.range(0); + size_t n = state.range(1) * 1024; + state.SetLabel( + absl::Substitute("n=$0Ki k=$1 batch_size=$2", n / 1024, k, batch_size)); + + auto* executor = GetGpuExecutor(); + auto stream = executor->CreateStream().value(); + + auto input_buffer = executor->AllocateOwnedArray(n * batch_size), + output_values = executor->AllocateOwnedArray(k * batch_size); + auto output_indices = executor->AllocateOwnedArray(k * batch_size); + + if (input_buffer.is_null() || output_values.is_null() || + output_indices.is_null()) { + state.SkipWithError("Unable to allocate GPU memory: aborting benchmark"); + return; + } + + auto source = RandomVec(n); + // use the same random vector for all batches (otherwise it takes too much + // time to generate random data) + for (size_t i = 0; i < batch_size; i++) { + auto slice = input_buffer->GetSlice(i * n, n); + CHECK_OK(stream->Memcpy(&slice, source.data(), n * sizeof(T))); + } + + for (auto _ : state) { + // Warmup execution without GpuTimer active + CHECK_OK(RunTopk(stream.get(), Get(T()), *input_buffer, n, *output_values, + *output_indices, k, batch_size)); + auto timer = se::gpu::GpuTimer::Create(stream.get(), + true /* warmup run was executed */); + CHECK_OK(timer.status()); + CHECK_OK(RunTopk(stream.get(), Get(T()), *input_buffer, n, *output_values, + *output_indices, k, batch_size)); + auto timer_duration = timer.value().GetElapsedDuration(); + CHECK_OK(timer_duration.status()); + state.SetIterationTime(absl::ToDoubleSeconds(timer_duration.value())); + } + size_t items_processed = batch_size * n * state.iterations(); + state.SetItemsProcessed(items_processed); + state.SetBytesProcessed(items_processed * sizeof(T)); +} + +BENCHMARK(BM_SmallTopk<1>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); +BENCHMARK(BM_SmallTopk<2>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); +BENCHMARK(BM_SmallTopk<4>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); +BENCHMARK(BM_SmallTopk<8>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); +BENCHMARK(BM_SmallTopk<16>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/launch_dimensions.cc b/xla/service/gpu/launch_dimensions.cc index 7831c4cd5affb..b31b0e532c199 100644 --- a/xla/service/gpu/launch_dimensions.cc +++ b/xla/service/gpu/launch_dimensions.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,17 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include +#include #include #include #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_format.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" namespace xla { @@ -32,8 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - LaunchDimensions::Dim3D block_counts = launch_dims.block_counts(); - LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block(); + se::BlockDim block_counts = launch_dims.block_counts(); + se::ThreadDim thread_counts = launch_dims.thread_counts_per_block(); out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", block_counts.x, block_counts.y, block_counts.z, thread_counts.x, thread_counts.y, thread_counts.z); @@ -173,7 +175,7 @@ BlockSizes GetBlockSizes(LaunchDimensionsConfig dim_config, } // namespace -StatusOr CalculateLaunchDimensions( +LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& gpu_device_info, LaunchDimensionsConfig dim_config) { int64_t num_elements = ShapeUtil::ElementsIn(shape); @@ -183,17 +185,10 @@ StatusOr CalculateLaunchDimensions( num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); BlockSizes sizes = GetBlockSizes(dim_config, gpu_device_info, shape, num_elements); - if (gpu_device_info.block_dim_limit().x > 0 && - sizes.block_count >= gpu_device_info.block_dim_limit().x) { - return absl::UnimplementedError( - absl::StrCat("Kernel launch needs more blocks (", sizes.block_count, - ") than allowed by hardware (", - gpu_device_info.block_dim_limit().x, ").")); - } return LaunchDimensions( - {sizes.block_count, 1, 1}, - {sizes.threads_per_block_x, sizes.threads_per_block_y, 1}); + se::BlockDim(sizes.block_count, 1, 1), + se::ThreadDim(sizes.threads_per_block_x, sizes.threads_per_block_y, 1)); } } // namespace gpu diff --git a/xla/service/gpu/launch_dimensions.h b/xla/service/gpu/launch_dimensions.h index 575f9466a7400..0d38013657ef5 100644 --- a/xla/service/gpu/launch_dimensions.h +++ b/xla/service/gpu/launch_dimensions.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_LAUNCH_DIMENSIONS_H_ #define XLA_SERVICE_GPU_LAUNCH_DIMENSIONS_H_ +#include #include #include +#include "absl/strings/str_cat.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" namespace xla { namespace gpu { @@ -29,46 +32,39 @@ namespace gpu { // number of threads per block. class LaunchDimensions { public: - struct Dim3D { - int64_t x, y, z; - - bool operator==(const Dim3D& other) const { - return x == other.x && y == other.y && z == other.z; - } - - bool operator!=(const Dim3D& other) const { return !(*this == other); } - }; - // The default constructor creates a launch dimension that indicate // single-threaded execution. LaunchDimensions() - : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {} + : block_counts_(se::BlockDim()), + thread_counts_per_block_(se::ThreadDim()) {} - LaunchDimensions(int64_t block_x_count, int64_t thread_x_count_per_block) - : block_counts_({block_x_count, 1, 1}), - thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {} + LaunchDimensions(uint64_t block_x_count, uint64_t thread_x_count_per_block) + : block_counts_(block_x_count, 1, 1), + thread_counts_per_block_(thread_x_count_per_block, 1, 1) {} - LaunchDimensions(const Dim3D& block_counts, - const Dim3D& thread_counts_per_block) + LaunchDimensions(const se::BlockDim& block_counts, + const se::ThreadDim& thread_counts_per_block) : block_counts_(block_counts), thread_counts_per_block_(thread_counts_per_block) {} - Dim3D block_counts() const { return block_counts_; } + se::BlockDim block_counts() const { return block_counts_; } - Dim3D thread_counts_per_block() const { return thread_counts_per_block_; } + se::ThreadDim thread_counts_per_block() const { + return thread_counts_per_block_; + } // Returns the total number of blocks. - int64_t num_blocks() const { + uint64_t num_blocks() const { return block_counts_.x * block_counts_.y * block_counts_.z; } // Returns the total number of threads in a block. - int64_t num_threads_per_block() const { + uint64_t num_threads_per_block() const { return thread_counts_per_block_.x * thread_counts_per_block_.y * thread_counts_per_block_.z; } - int64_t launch_bound() const { + uint64_t launch_bound() const { return num_blocks() * num_threads_per_block(); } @@ -90,8 +86,8 @@ class LaunchDimensions { } private: - Dim3D block_counts_; - Dim3D thread_counts_per_block_; + se::BlockDim block_counts_; + se::ThreadDim thread_counts_per_block_; }; std::ostream& operator<<(std::ostream& out, @@ -108,7 +104,7 @@ struct LaunchDimensionsConfig { // a block of unroll_factor elements. Otherwise each thread will // handle only unroll_factor. bool few_waves = false; - // If `row_optimized` is true, then the block size will equal to + // If `row_vectorized` is true, then the block size will equal to // `hlo.shape().dimensions().back()/unroll_factor`. // Currently few_waves and row_vectorized do not work together. bool row_vectorized = false; @@ -127,7 +123,7 @@ int64_t ThreadsPerBlockRowVectorized( LaunchDimensionsConfig dim_config); // Calculates the launch dimensions used to invoke `hlo`. -StatusOr CalculateLaunchDimensions( +LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& gpu_device_info, LaunchDimensionsConfig dim_config = {}); diff --git a/xla/service/gpu/llvm_gpu_backend/BUILD b/xla/service/gpu/llvm_gpu_backend/BUILD index 619acb37af3be..72ea7fdc71d29 100644 --- a/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/xla/service/gpu/llvm_gpu_backend/BUILD @@ -1,12 +1,13 @@ -load("//xla:xla.bzl", "xla_cc_test") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -28,6 +29,7 @@ cc_library( "utils.h", ], deps = [ + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", @@ -37,10 +39,16 @@ cc_library( "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/stream_executor:device_description", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:BitReader", "@llvm-project//llvm:BitWriter", @@ -56,13 +64,18 @@ cc_library( "@llvm-project//llvm:Scalar", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", + "@llvm-project//mlir:NVVMDialect", + "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:cuda_libdevice_path", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:random", + "@tsl//tsl/platform:rocm_rocdl_path", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/util:env_var", ] + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", "@llvm-project//llvm:AMDGPUCodeGen", @@ -80,8 +93,6 @@ xla_cc_test( ":llvm_gpu_backend", "//xla/tests:xla_internal_test_main", "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:resource_loader", "@tsl//tsl/platform:test", diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 20c22245c217b..c84dd8f6092be 100644 --- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,20 +15,37 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include #include #include #include #include +#include // NOLINT #include #include +#include // NOLINT #include #include #include #include "absl/base/call_once.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "llvm/ADT/Any.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" @@ -38,19 +55,22 @@ limitations under the License. #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Linker/Linker.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/PassRegistry.h" +#include "llvm/Passes/OptimizationLevel.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Program.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" -#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/Scalar.h" @@ -58,21 +78,29 @@ limitations under the License. #include "xla/service/gpu/metrics.h" #include "xla/service/llvm_ir/llvm_command_line_options.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/platform/cuda_libdevice_path.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/random.h" +#include "tsl/platform/rocm_rocdl_path.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/util/env_var.h" #if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" #endif +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace xla { namespace gpu { namespace { @@ -120,12 +148,14 @@ static std::string GetSmName(se::CudaComputeCapability compute_capability) { return absl::StrCat("sm_", sm_version, extension); } +// NOLINTBEGIN: clang-diagnostic-unused-function // Convenience function for producing a name of a temporary compilation product // from the input filename. std::string MakeNameForTempProduct(absl::string_view input_filename, absl::string_view extension) { return ReplaceFilenameExtension(tsl::io::Basename(input_filename), extension); } +// NOLINTEND: clang-diagnostic-unused-function // Initializes LLVM passes. Uses the PassRegistry mechanism. void InitializePasses(llvm::PassRegistry* pass_registry) { @@ -138,7 +168,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { llvm::initializeTransformUtils(*pass_registry); llvm::initializeInstCombine(*pass_registry); llvm::initializeTarget(*pass_registry); - llvm::initializeCodeGenPreparePass(*pass_registry); + llvm::initializeCodeGenPrepareLegacyPassPass(*pass_registry); } // Returns the TargetMachine, given a triple. @@ -211,7 +241,7 @@ void FeedLLVMWithFlags(const std::vector& cl_opts) { for (const std::string& cl_opt : cl_opts) { fake_argv.push_back(cl_opt.c_str()); } - llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); + llvm::cl::ParseCommandLineOptions(fake_argv.size(), fake_argv.data()); } // Returns whether the module could use any device bitcode library functions. @@ -220,9 +250,9 @@ bool CouldNeedDeviceBitcode(const llvm::Module& module) { // The list of prefixes should be in sync with library functions used in // target_util.cc. if (!function.isIntrinsic() && function.isDeclaration() && - (function.getName().startswith("__nv_") || - function.getName().startswith("__ocml_") || - function.getName().startswith("__ockl_"))) { + (function.getName().starts_with("__nv_") || + function.getName().starts_with("__ocml_") || + function.getName().starts_with("__ockl_"))) { return true; } } @@ -231,7 +261,7 @@ bool CouldNeedDeviceBitcode(const llvm::Module& module) { // Links the module with a vector of path to bitcode modules. // The caller must guarantee that the paths exist. -Status LinkWithBitcodeVector( +absl::Status LinkWithBitcodeVector( llvm::Module* module, const std::vector& bitcode_path_vector) { llvm::Linker linker(*module); @@ -240,7 +270,7 @@ Status LinkWithBitcodeVector( LOG(ERROR) << "bitcode module is required by this HLO module but was " "not found at " << bitcode_path; - return xla::InternalError("bitcode module not found at %s", bitcode_path); + return xla::Internal("bitcode module not found at %s", bitcode_path); } std::unique_ptr bitcode_module = @@ -255,17 +285,17 @@ Status LinkWithBitcodeVector( return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return xla::InternalError("Error linking bitcode module from %s", - bitcode_path); + return xla::Internal("Error linking bitcode module from %s", + bitcode_path); } } - return OkStatus(); + return absl::OkStatus(); } -Status NVPTXTargetModuleLinker(llvm::Module* module, - se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& device_bitcode_path) { +absl::Status NVPTXTargetModuleLinker(llvm::Module* module, + se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_path) { // Link the input module with libdevice, to pull in implementations of some // builtins. TF_RETURN_IF_ERROR( @@ -283,20 +313,21 @@ Status NVPTXTargetModuleLinker(llvm::Module* module, } } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr NVPTXGetTargetMachine( llvm::Triple target_triple, se::CudaComputeCapability compute_capability, const DebugOptions& debug_options) { - // TODO(b/266678775): Make it always PTX 7.1 as soon as TF driver requirements - // are updated. - const std::string ptx_ver = - debug_options.xla_gpu_enable_triton_gemm() ? "+ptx71" : "+ptx60"; // Figure out the exact name of the processor as known to the NVPTX backend // from the gpu_architecture flag. +#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 12010 + // use ptx81 for CUDA >= 12.1 + return GetTargetMachine(target_triple, GetSmName(compute_capability), + debug_options, /*feature_str=*/"+ptx81"); +#endif return GetTargetMachine(target_triple, GetSmName(compute_capability), - debug_options, ptx_ver); + debug_options, /*feature_str=*/"+ptx74"); } using TargetModuleLinker = @@ -352,7 +383,7 @@ auto DumpCallbackForModule(std::string module_identifier, }; } -Status LinkAndOptimizeModule( +absl::Status LinkAndOptimizeModule( llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, const std::string& device_bitcode_path, TargetModuleLinker module_linker, llvm::Triple default_target_triple, @@ -425,7 +456,7 @@ Status LinkAndOptimizeModule( mpm.run(*module, mam); - return OkStatus(); + return absl::OkStatus(); } // One-time module initializer. @@ -535,24 +566,24 @@ std::string LibDevicePath(absl::string_view xla_gpu_cuda_data_dir) { } // Links libdevice into the given module if the module needs libdevice. -Status LinkLibdeviceIfNecessary(llvm::Module* module, - const std::string& libdevice_path) { +absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, + const std::string& libdevice_path) { if (!CouldNeedDeviceBitcode(*module)) { - return OkStatus(); + return absl::OkStatus(); } if (!tsl::Env::Default()->FileExists(libdevice_path).ok()) { LOG(WARNING) << "libdevice is required by this HLO module but was not found at " << libdevice_path; - return xla::InternalError("libdevice not found at %s", libdevice_path); + return xla::Internal("libdevice not found at %s", libdevice_path); } VLOG(1) << "Linking with libdevice from: " << libdevice_path; return LinkWithBitcodeVector(module, {libdevice_path}); } -StatusOr CompileToPtx( +absl::StatusOr CompileToPtx( llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, std::function configure_target) { @@ -578,8 +609,7 @@ StatusOr CompileToPtx( auto compute_capability = std::get_if(&gpu_version); if (!compute_capability) { - return xla::InternalError( - "Incompatible compute capability was specified."); + return xla::Internal("Incompatible compute capability was specified."); } llvm::Triple default_target_triple("nvptx64-unknown-unknown"); @@ -669,7 +699,7 @@ struct HsacoCache { const std::vector& hsaco); }; -static HsacoCache g_hsacoCache; +static HsacoCache g_hsacoCache; // NOLINT: static/global vars forbidden bool HsacoCache::Find(const std::string& ir, uint64_t& hash, const std::string& gfx, std::vector& hsaco) { @@ -705,13 +735,13 @@ void HsacoCache::Add(const std::string& ir, uint64_t hash, // Emits the given module to HSA Code Object. target_machine is an initialized // TargetMachine for the AMDGPU target. -StatusOr> EmitModuleToHsaco( +absl::StatusOr> EmitModuleToHsaco( llvm::Module* module, llvm::TargetMachine* target_machine) { auto* env = tsl::Env::Default(); std::vector tempdir_vector; env->GetLocalTempDirectories(&tempdir_vector); if (tempdir_vector.empty()) { - return xla::InternalError( + return xla::Internal( "Unable to locate a temporary directory for compile-time artifacts."); } std::string tempdir_name = tempdir_vector.front(); @@ -768,13 +798,11 @@ StatusOr> EmitModuleToHsaco( ir_fs->flush(); } // Locate lld. - // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after - // ROCm-Device-Libs PR. - std::string lld_path = tsl::io::JoinPath("/opt/rocm", "llvm/bin"); + std::string lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); if (!lld_program) { - return xla::InternalError("unable to find ld.lld in PATH: %s", - lld_program.getError().message()); + return xla::Internal("unable to find ld.lld in PATH: %s", + lld_program.getError().message()); } std::vector lld_args{ llvm_ir::AsStringRef("ld.lld"), llvm_ir::AsStringRef("-flavor"), @@ -788,8 +816,8 @@ StatusOr> EmitModuleToHsaco( llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), std::nullopt, {}, 0, 0, &error_message); if (lld_result) { - return xla::InternalError("ld.lld execute fail: %s, error code %d", - error_message, lld_result); + return xla::Internal("ld.lld execute fail: %s, error code %d", + error_message, lld_result); } // Read HSACO. @@ -798,7 +826,7 @@ StatusOr> EmitModuleToHsaco( std::vector hsaco(hsaco_file_size); hsaco_file.seekg(0, std::ios::beg); - hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); + hsaco_file.read(reinterpret_cast(hsaco.data()), hsaco_file_size); hsaco_file.close(); if (!keep_tempfiles) { remove(ir_path.c_str()); @@ -809,26 +837,27 @@ StatusOr> EmitModuleToHsaco( } // Links ROCm-Device-Libs into the given module if the module needs it. -Status LinkROCDLIfNecessary(llvm::Module* module, std::string gcn_arch_name, - const std::string& rocdl_dir_path) { +absl::Status LinkROCDLIfNecessary(llvm::Module* module, + std::string gcn_arch_name, + const std::string& rocdl_dir_path) { if (!CouldNeedDeviceBitcode(*module)) { - return OkStatus(); + return absl::OkStatus(); } return LinkWithBitcodeVector(module, GetROCDLPaths(gcn_arch_name, rocdl_dir_path)); } -Status AMDGPUTargetModuleLinker(llvm::Module* module, - se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& device_bitcode_dir_path) { +absl::Status AMDGPUTargetModuleLinker( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_dir_path) { // Link the input module with ROCDL. auto compute_capability = std::get_if(&gpu_version); if (!compute_capability) { - return xla::InternalError("Incompatible compute capability was specified."); + return xla::Internal("Incompatible compute capability was specified."); } std::string gcn_arch_name = compute_capability->gcn_arch_name(); @@ -842,7 +871,7 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, } } - return OkStatus(); + return absl::OkStatus(); } // The following routine maps a feature token extracted from the @@ -855,10 +884,14 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, // related changes which have not yet been upstreamed (to the LLVM repo) // When that upstreaming happens (and TF LLVM pointer moves past the // upstream commit), the following mapping will need to change -std::string MapGCNArchNameTokenToFeatureStr(const std::string& token) { +std::string MapGCNArchNameTokenToFeatureStr(const std::string& token, + const std::string& gfx) { if (token == "sramecc+") { return "+sramecc"; } else if (token == "sramecc-") { + if (gfx == "gfx90a" || gfx == "gfx940" || gfx == "gfx941" || + gfx == "gfx942") + return ""; return "-sramecc"; } else if (token == "xnack+") { return "+xnack"; @@ -883,7 +916,7 @@ std::pair GetFeatureStrFromGCNArchName( // The rest of the tokens are the feature/targetid strings if (it != tokens.begin()) { std::string token(*it); - std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token); + std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token, gfx); mapped_tokens.push_back(mapped_token); } } @@ -904,7 +937,32 @@ std::unique_ptr AMDGPUGetTargetMachine( arch.second); } -void AMDGPUBackendInit(const DebugOptions& debug_options) { +// Returns the directory containing ROCm-Device-Libs files. +std::string GetROCDLDir(const DebugOptions& debug_options) { + std::vector potential_rocdl_dirs; + const std::string& datadir = debug_options.xla_gpu_cuda_data_dir(); + if (!datadir.empty()) { + potential_rocdl_dirs.push_back(datadir); + } + potential_rocdl_dirs.push_back(tsl::RocdlRoot()); + + // Tries all potential ROCDL directories in the order they are inserted. + // Returns the first directory that exists in the file system. + for (const std::string& potential_rocdl_dir : potential_rocdl_dirs) { + if (tsl::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) { + VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir; + return potential_rocdl_dir; + } + VLOG(2) << "Unable to find potential ROCm-Device-Libs dir " + << potential_rocdl_dir; + } + + // Last resort: maybe in the current folder. + return "."; +} + +void AMDGPUBackendInit(const DebugOptions& debug_options, + std::string& rocdl_dir_path) { llvm_ir::InitializeLLVMCommandLineOptions( debug_options.xla_backend_extra_options()); @@ -915,10 +973,11 @@ void AMDGPUBackendInit(const DebugOptions& debug_options) { LLVMInitializeAMDGPUTarget(); LLVMInitializeAMDGPUTargetInfo(); LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmParser(); LLVMInitializeAMDGPUAsmPrinter(); - #endif + rocdl_dir_path = GetROCDLDir(debug_options); llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); InitializePasses(registry); } @@ -926,12 +985,28 @@ void AMDGPUBackendInit(const DebugOptions& debug_options) { } // namespace namespace amdgpu { -StatusOr> CompileToHsaco( + +std::string LibDevicePath(std::string gcn_arch_name, + const std::string& rocdl_dir_path) { + auto libdevice_dir_paths = GetROCDLPaths(gcn_arch_name, rocdl_dir_path); + for (auto libdevice_dir_path : libdevice_dir_paths) { + if (libdevice_dir_path.find("ocml.bc")) { + return libdevice_dir_path; + } + } + return ""; +} + +absl::StatusOr> CompileToHsaco( llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, const std::string& rocdl_dir_path, + const DebugOptions& debug_options, const std::string& module_config_cache_key) { static absl::once_flag backend_init_flag; - absl::call_once(backend_init_flag, AMDGPUBackendInit, debug_options); + // TODO(rocm) Ideally this would be refreshed if xla_gpu_cuda_data_dir + // changes. + static std::string rocdl_dir_path; // NOLINT: static/global vars forbidden + absl::call_once(backend_init_flag, AMDGPUBackendInit, debug_options, + rocdl_dir_path); std::vector hsaco; std::unique_ptr target_machine; @@ -958,8 +1033,7 @@ StatusOr> CompileToHsaco( auto compute_capability = std::get_if(&gpu_version); if (!compute_capability) { - return xla::InternalError( - "Incompatible compute capability was specified."); + return xla::Internal("Incompatible compute capability was specified."); } std::string gcn_arch_name = compute_capability->gcn_arch_name(); diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 9ff362c2df339..3d67bf043e644 100644 --- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,17 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ #define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#include +#include #include -#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" -#include "xla/statusor.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project #include "xla/stream_executor/device_description.h" #include "xla/types.h" #include "xla/xla.pb.h" @@ -40,8 +44,8 @@ std::string CantFindCudaMessage(absl::string_view msg, std::string LibDevicePath(absl::string_view xla_gpu_cuda_data_dir); // Link libdevice if functions using it are detected in the module. -Status LinkLibdeviceIfNecessary(llvm::Module* module, - const std::string& libdevice_path); +absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, + const std::string& libdevice_path); // Compiles the argument module and returns it. libdevice_dir_path is the parent // directory of the libdevice bitcode libraries. The contents of the module may @@ -50,19 +54,22 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, // The Compile.* interfaces each create their own llvm::LLVMContext objects for // thread safety, but note that LLVM's multithreaded support is very // preliminary; multithreaded use is not recommended at this time. -StatusOr CompileToPtx( +absl::StatusOr CompileToPtx( llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, std::function configure_target = nullptr); } // namespace nvptx namespace amdgpu { +// Get path to libdevice file. +std::string LibDevicePath(std::string gcn_arch_name, + const std::string& rocdl_dir_path); // Compiles the argument module and returns it with LLVM AMDGPU backend. // rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries. // The contents of the module may be changed. -StatusOr> CompileToHsaco( +absl::StatusOr> CompileToHsaco( llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, const std::string& rocdl_dir_path, + const DebugOptions& debug_options, const std::string& module_config_cache_key); } // namespace amdgpu diff --git a/xla/service/gpu/llvm_gpu_backend/utils.cc b/xla/service/gpu/llvm_gpu_backend/utils.cc index 4d3269b287a42..84254bff68873 100644 --- a/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/llvm_gpu_backend/utils.h b/xla/service/gpu/llvm_gpu_backend/utils.h index 5c6a07ae3514e..b355852ea4fe0 100644 --- a/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/xla/service/gpu/llvm_gpu_backend/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "xla/types.h" namespace llvm { class LLVMContext; diff --git a/xla/service/gpu/llvm_gpu_backend/utils_test.cc b/xla/service/gpu/llvm_gpu_backend/utils_test.cc index bcf33c9a428d3..4675cdc6a341d 100644 --- a/xla/service/gpu/llvm_gpu_backend/utils_test.cc +++ b/xla/service/gpu/llvm_gpu_backend/utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/loop_double_buffer_transformer.cc b/xla/service/gpu/loop_double_buffer_transformer.cc index af0d8fe40bc0f..dffa62b799a49 100644 --- a/xla/service/gpu/loop_double_buffer_transformer.cc +++ b/xla/service/gpu/loop_double_buffer_transformer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -24,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -35,7 +37,6 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/flatten_call_graph.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -53,11 +54,8 @@ void SetChannelIdForNewCollective(HloInstruction* new_instr, // have the same unique channel id. absl::flat_hash_map old_to_new_channel_id_map; absl::flat_hash_map channel_id_comp_map; - if (HloAsyncInstruction::ClassOf(new_instr) && - hlo_query::IsCollectiveCommunicationOp( - DynCast(new_instr) - ->async_wrapped_instruction() - ->opcode())) { + if (new_instr->IsAsynchronous() && hlo_query::IsCollectiveCommunicationOp( + new_instr->async_wrapped_opcode())) { HloInstruction* wrapped_instr = DynCast(new_instr)->async_wrapped_instruction(); int64_t old_channel_id = *wrapped_instr->channel_id(); @@ -73,18 +71,19 @@ void SetChannelIdForNewCollective(HloInstruction* new_instr, wrapped_instr->set_channel_id(new_channel_id); if (channel_id_comp_map.find(new_channel_id) == channel_id_comp_map.end()) { - channel_id_comp_map[new_channel_id] = new_instr->called_computations()[0]; + channel_id_comp_map[new_channel_id] = + new_instr->async_wrapped_computation(); } else { - channel_id_comp_map[new_channel_id]->AddAsyncInstruction(*new_instr); + channel_id_comp_map[new_channel_id]->AddAsyncStart(new_instr); } } else if (hlo_query::IsCollectiveCommunicationOp(new_instr->opcode()) || - hlo_query::IsAsyncCollectiveStartOp(new_instr->opcode())) { + hlo_query::IsAsyncCollectiveStartOp(new_instr)) { new_instr->set_channel_id(hlo_query::NextChannelId(*module)); } } -Status PeelInstructionsForOddTripCount(HloModule* module, - HloInstruction* while_instr) { +absl::Status PeelInstructionsForOddTripCount(HloModule* module, + HloInstruction* while_instr) { std::string suffix = "peeled_double_buffer"; absl::flat_hash_map old_to_new_map; HloComputation* while_body = while_instr->while_body(); @@ -144,18 +143,19 @@ Status PeelInstructionsForOddTripCount(HloModule* module, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace -StatusOr LoopDoubleBufferTransformer::Run( +absl::StatusOr LoopDoubleBufferTransformer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; std::vector while_instrs; - absl::c_copy_if(module->entry_computation()->instructions(), - std::back_inserter(while_instrs), - HloPredicateIsOp); + for (auto comp : module->MakeNonfusionComputations()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + HloPredicateIsOp); + } VLOG(2) << "Processing " << while_instrs.size() << " while loops."; for (HloInstruction* while_instr : while_instrs) { diff --git a/xla/service/gpu/loop_double_buffer_transformer.h b/xla/service/gpu/loop_double_buffer_transformer.h index fa48499f42e55..a95ca18c04c42 100644 --- a/xla/service/gpu/loop_double_buffer_transformer.h +++ b/xla/service/gpu/loop_double_buffer_transformer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_LOOP_DOUBLE_BUFFER_TRANSFORMER_H_ #define XLA_SERVICE_GPU_LOOP_DOUBLE_BUFFER_TRANSFORMER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -44,7 +45,7 @@ class LoopDoubleBufferTransformer : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/loop_double_buffer_transformer_test.cc b/xla/service/gpu/loop_double_buffer_transformer_test.cc index d890234ad1b71..d319bcd3768cb 100644 --- a/xla/service/gpu/loop_double_buffer_transformer_test.cc +++ b/xla/service/gpu/loop_double_buffer_transformer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,19 +24,20 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" #include "xla/service/tuple_simplifier.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { +using tsl::testing::IsOkAndHolds; + int64_t CountInstructions(const HloComputation& computation, HloOpcode opcode) { int64_t count = 0; for (const auto& instruction : computation.instructions()) { @@ -107,11 +108,12 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, double_buffer.Run(module.get())); + EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(changed, tuple_simp.Run(module.get())); + EXPECT_TRUE(changed); HloInstruction* while_instruction; for (auto instr : module->entry_computation()->instructions()) { @@ -178,11 +180,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true)); // We expect that for the while loop, no further copy needs to be added to the // module. @@ -245,11 +245,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true)); HloInstruction* while_instruction; for (auto instr : module->entry_computation()->instructions()) { @@ -319,11 +317,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); + EXPECT_THAT(tuple_simp.Run(module.get()), IsOkAndHolds(true)); HloInstruction* while_instruction; for (auto instr : module->entry_computation()->instructions()) { @@ -397,11 +393,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; - TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); absl::flat_hash_set while_loops_callees; @@ -460,11 +452,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); LoopDoubleBufferTransformer double_buffer; - HloDCE dce; - TupleSimplifier tuple_simp; - ASSERT_IS_OK(double_buffer.Run(module.get()).status()); - ASSERT_IS_OK(tuple_simp.Run(module.get()).status()); - ASSERT_IS_OK(dce.Run(module.get()).status()); + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); absl::flat_hash_set while_loops_callees; @@ -482,6 +470,60 @@ ENTRY main { // associated computations. EXPECT_EQ(while_loops_callees.size(), 8); } + +TEST_F(GpuLoopDoubleBufferTransformerTest, NestedWhileLoopAreUnrolled) { + const char* const kModuleString = R"( +HloModule loop_unrolling_nested_are_unrolled +condition_nested { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} +body_nested { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + one = s32[] constant(1) + cond_plus_1 = s32[] add(cond, one) + ROOT output = (s32[]) tuple(cond_plus_1) +} +condition { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} +body { + input_tuple = (s32[]) parameter(0) + ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"11"}} +} +ENTRY main { + param_0 = (s32[]) parameter(0) + ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"11"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + LoopDoubleBufferTransformer double_buffer; + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(true)); + + int64_t num_whiles = 0; + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + // All loops in the module should be unrolled now and have trip count + // of 5. + EXPECT_EQ(instr->backend_config() + ->known_trip_count() + .n(), + 5); + ++num_whiles; + } + } + } + // We expect the number of while loops to be 4 in total after unrolling. + EXPECT_EQ(num_whiles, 4); +} } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/make_batch_pointers.cc b/xla/service/gpu/make_batch_pointers.cc index 66f3842072005..e7788d652dcc4 100644 --- a/xla/service/gpu/make_batch_pointers.cc +++ b/xla/service/gpu/make_batch_pointers.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #include "xla/service/gpu/make_batch_pointers.h" #include -#include +#include "absl/status/status.h" #include "xla/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" @@ -43,9 +43,10 @@ namespace make_batch_pointers { void* kernel(); // returns a pointer to a CUDA C++ device function } // namespace make_batch_pointers -Status MakeBatchPointers(se::Stream* stream, se::DeviceMemoryBase base_ptr, - size_t stride_bytes, size_t n, - se::DeviceMemoryBase ptrs_out) { +absl::Status MakeBatchPointers(se::Stream* stream, + se::DeviceMemoryBase base_ptr, + size_t stride_bytes, size_t n, + se::DeviceMemoryBase ptrs_out) { static constexpr size_t kThreads = 128; se::StreamExecutor* executor = stream->parent(); @@ -58,16 +59,18 @@ Status MakeBatchPointers(se::Stream* stream, se::DeviceMemoryBase base_ptr, #else TF_ASSIGN_OR_RETURN( - auto kernel, (executor->CreateTypedKernel( - "make_batch_pointers", make_batch_pointers::kernel()))); + auto kernel, + (se::TypedKernel< + se::DeviceMemoryBase, size_t, size_t, + se::DeviceMemoryBase>::Create(executor, "make_batch_pointers", + make_batch_pointers::kernel()))); TF_RETURN_IF_ERROR( stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1), - se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), *kernel, + se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), kernel, base_ptr, stride_bytes, n, ptrs_out)); #endif - return OkStatus(); + return absl::OkStatus(); } } // namespace xla::gpu diff --git a/xla/service/gpu/make_batch_pointers.cu.cc b/xla/service/gpu/make_batch_pointers.cu.cc index a8caaab316420..344f8ecc214e1 100644 --- a/xla/service/gpu/make_batch_pointers.cu.cc +++ b/xla/service/gpu/make_batch_pointers.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/make_batch_pointers.h b/xla/service/gpu/make_batch_pointers.h index 320090beae386..171f33616d27e 100644 --- a/xla/service/gpu/make_batch_pointers.h +++ b/xla/service/gpu/make_batch_pointers.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ limitations under the License. #define XLA_SERVICE_GPU_MAKE_BATCH_POINTERS_H_ #include -#include #include "xla/status.h" #include "xla/stream_executor/device_memory.h" @@ -50,9 +49,10 @@ namespace xla::gpu { // driver and slow down *all* work on the GPU. So to do this right, we'd // need to allocate the host memory as pinned, one alloc per stream. Then // we'd need to manage this memory without leaks. This becomes complex! -Status MakeBatchPointers(se::Stream* stream, se::DeviceMemoryBase base_ptr, - size_t stride_bytes, size_t n, - se::DeviceMemoryBase ptrs_out); +absl::Status MakeBatchPointers(se::Stream* stream, + se::DeviceMemoryBase base_ptr, + size_t stride_bytes, size_t n, + se::DeviceMemoryBase ptrs_out); } // namespace xla::gpu diff --git a/xla/service/gpu/matmul_utils.cc b/xla/service/gpu/matmul_utils.cc index 3482c7edeefcf..9a26f9357a9b8 100644 --- a/xla/service/gpu/matmul_utils.cc +++ b/xla/service/gpu/matmul_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,26 +16,30 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" #include +#include #include #include #include #include -#include #include #include #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -44,17 +48,17 @@ limitations under the License. #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA -#include "xla/stream_executor/host_or_device_scalar.h" #endif // GOOGLE_CUDA namespace xla { namespace gpu { -StatusOr> GetNonContractingDims( +absl::StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims) { std::vector non_contracting_dims; @@ -81,32 +85,33 @@ const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( return dimension_numbers.rhs_batch_dimensions(); } -int64_t ContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { +absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); if (operand_number == 0) { - CHECK_EQ(dimension_numbers.lhs_contracting_dimensions().size(), 1); + TF_RET_CHECK(dimension_numbers.lhs_contracting_dimensions().size() == 1); return dimension_numbers.lhs_contracting_dimensions(0); } - CHECK_EQ(dimension_numbers.rhs_contracting_dimensions().size(), 1); + TF_RET_CHECK(dimension_numbers.rhs_contracting_dimensions().size() == 1); return dimension_numbers.rhs_contracting_dimensions(0); } -int64_t NonContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { - StatusOr> non_contracting_dims = +absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + TF_ASSIGN_OR_RETURN(int64_t contracting_dim, + ContractingDimensionIndex(dot, operand_number)); + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dims, GetNonContractingDims(dot.operand(operand_number)->shape(), BatchDimensionsForOperand(dot, operand_number), - {ContractingDimensionIndex(dot, operand_number)}); - TF_CHECK_OK(non_contracting_dims.status()); - CHECK_EQ(non_contracting_dims->size(), 1); - return non_contracting_dims->front(); + {contracting_dim})); + TF_RET_CHECK(non_contracting_dims.size() == 1); + return non_contracting_dims.front(); } -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims) { +absl::StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims) { TF_RET_CHECK(shape.has_layout()); std::vector minor_to_major; @@ -114,13 +119,14 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, // The GeMM output always has its layout set such that the batch, row, and // col dim groups are each laid out physically sequentially. GeMM operands // must, therefore, be laid out similarly. - auto check_physically_sequential = [&](absl::Span dims) { + auto check_physically_sequential = + [&](absl::Span dims) -> absl::Status { for (auto it = dims.rbegin(); it != dims.rend(); ++it) { // NOTE: `i` is incremented as we check the dimensions. if (*it != shape.layout().minor_to_major()[i++]) return InvalidArgument("dims not physically_sequential"); } - return OkStatus(); + return absl::OkStatus(); }; int64_t dim = shape.layout().minor_to_major()[i]; @@ -155,7 +161,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, } // Returns the matrix layout for a logical shape (batch, rows, columns). -/*static*/ StatusOr MatrixLayout::For(const Shape& shape) { +/*static*/ absl::StatusOr MatrixLayout::For(const Shape& shape) { TF_RET_CHECK(shape.rank() == 3); TF_RET_CHECK(shape.has_layout()); @@ -163,7 +169,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, int64_t num_rows = shape.dimensions(1); int64_t num_cols = shape.dimensions(2); - MatrixLayout::Order order = MatrixLayout::Order::kRowMajor; + Order order{Order::kRowMajor}; int64_t leading_dim_stride = num_cols; int64_t batch_stride = num_rows * num_cols; @@ -174,7 +180,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, case 012: // (B,R,C) (major-to-minor) break; case 021: // (B,C,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = num_rows; break; case 0102: // (R,B,C) @@ -182,7 +188,7 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, batch_stride = num_cols; break; case 0201: // (C,B,R) - order = MatrixLayout::Order::kColumnMajor; + order = Order::kColumnMajor; leading_dim_stride = batch_size * num_rows; batch_stride = num_rows; break; @@ -190,14 +196,15 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return Unimplemented("batch in most minor dimension"); } - if (batch_size == 1) batch_stride = 0; - return MatrixLayout{ - shape.element_type(), num_rows, num_cols, order, - batch_size, leading_dim_stride, batch_stride, - }; + if (batch_size == 1) { + batch_stride = 0; + } + return MatrixLayout{se::gpu::MatrixLayout{shape.element_type(), num_rows, + num_cols, order, batch_size, + leading_dim_stride, batch_stride}}; } -/*static*/ StatusOr MatrixLayout::For( +/*static*/ absl::StatusOr MatrixLayout::For( const Shape& shape, absl::Span batch_dims, absl::Span row_dims, absl::Span col_dims) { TF_ASSIGN_OR_RETURN( @@ -206,11 +213,9 @@ StatusOr GetBatchRowColumnShape(const Shape& shape, return MatrixLayout::For(batch_row_col_shape); } -/*static*/ StatusOr MatrixLayout::For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims) { +/*static*/ absl::StatusOr MatrixLayout::For( + const Shape& shape, size_t lhs_num_batch_dims, size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, size_t rhs_num_col_dims) { size_t num_batch_dims = std::max(lhs_num_batch_dims, rhs_num_batch_dims); TF_RET_CHECK(shape.rank() == @@ -242,8 +247,8 @@ std::vector NormalizedRelativeOrder(absl::Span dims) { } } // namespace -StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx) { +absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, + int64_t operand_idx) { TF_RET_CHECK(dot.opcode() == HloOpcode::kDot); TF_RET_CHECK(dot.operand_count() > operand_idx); @@ -287,29 +292,33 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, .ok(); } -/*static*/ StatusOr GemmConfig::For( +/*static*/ absl::StatusOr GemmConfig::For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, + PrecisionConfig::Algorithm precision_algorithm, std::optional algorithm, int64_t compute_precision, bool grad_x, bool grad_y) { return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims, rhs_shape, rhs_batch_dims, rhs_contracting_dims, /*c_shape=*/output_shape, /*bias_shape_ptr=*/nullptr, - output_shape, alpha_real, alpha_imag, beta, algorithm, - compute_precision, grad_x, grad_y); + output_shape, alpha_real, alpha_imag, beta, + precision_algorithm, algorithm, compute_precision, + grad_x, grad_y); } -/*static*/ StatusOr GemmConfig::For( +/*static*/ absl::StatusOr GemmConfig::For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& c_shape, const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, - double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision, bool grad_x, bool grad_y) { + double alpha_imag, double beta, + PrecisionConfig::Algorithm precision_algorithm, + std::optional algorithm, int64_t compute_precision, bool grad_x, + bool grad_y) { absl::Span lhs_col_dims = lhs_contracting_dims; TF_ASSIGN_OR_RETURN( std::vector lhs_row_dims, @@ -384,6 +393,8 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, switch (output_shape.element_type()) { case F8E4M3FN: case F8E5M2: + case F8E4M3FNUZ: + case F8E5M2FNUZ: case F16: case BF16: case F32: @@ -397,7 +408,7 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, TF_RET_CHECK(alpha_imag == 0); if (lhs_layout.dtype != PrimitiveType::S8 || rhs_layout.dtype != PrimitiveType::S8) { - return InternalError( + return Internal( "For int32 gemm output only int8 input is supported, got input: " "%s, %s", primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), @@ -405,9 +416,9 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, } break; default: - return InternalError("Unexpected GEMM datatype: %s", - primitive_util::LowercasePrimitiveTypeName( - output_shape.element_type())); + return Internal("Unexpected GEMM datatype: %s", + primitive_util::LowercasePrimitiveTypeName( + output_shape.element_type())); } return GemmConfig{lhs_layout, @@ -417,18 +428,35 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, {alpha_real, alpha_imag}, beta, compute_precision, + precision_algorithm, algorithm, grad_x, grad_y}; } -/*static*/ StatusOr GemmConfig::For(const HloInstruction* gemm) { - TF_ASSIGN_OR_RETURN(GemmBackendConfig config, - gemm->backend_config()); +namespace { + +bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm, + int64_t compute_precision) { + if (algorithm == PrecisionConfig::ALG_UNSET) { + return compute_precision <= 1; + } + return algorithm_util::HasTf32InputType(algorithm); +} + +} // namespace + +/*static*/ absl::StatusOr GemmConfig::For( + const HloInstruction* gemm) { + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + gemm->backend_config()); + const GemmBackendConfig& config = gpu_config.gemm_backend_config(); std::optional algorithm; if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { algorithm = config.selected_algorithm(); + } else { + algorithm = se::blas::kDefaultAlgorithm; } const Shape& lhs_shape = gemm->operand(0)->shape(); @@ -437,210 +465,216 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, const Shape& output_shape = gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape(); + bool has_matrix_bias = config.beta() != 0.; + Shape c_shape = has_matrix_bias ? gemm->operand(2)->shape() : output_shape; + + std::optional vector_bias_shape; + TF_ASSIGN_OR_RETURN( + bool has_vector_bias, + xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); + if (has_vector_bias) { + int vector_bias_index = has_matrix_bias ? 3 : 2; + if (primitive_util::IsF8Type(lhs_shape.element_type())) { + // FP8 gemms have 4 scales as inputs which come before the vector bias. + vector_bias_index += 4; + } + vector_bias_shape = gemm->operand(vector_bias_index)->shape(); + } + auto attributes = gemm->frontend_attributes().map(); bool grad_x = (attributes["grad_x"] == "true"); bool grad_y = (attributes["grad_y"] == "true"); + int64_t precision = se::blas::kDefaultComputePrecision; + for (auto operand_precision : config.precision_config().operand_precision()) { + precision = std::max(precision, static_cast(operand_precision)); + } + const PrecisionConfig::Algorithm precision_algorithm = + config.precision_config().algorithm(); + return GemmConfig::For( lhs_shape, dot_dims.lhs_batch_dimensions(), dot_dims.lhs_contracting_dimensions(), rhs_shape, dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(), - output_shape, config.alpha_real(), config.alpha_imag(), config.beta(), - algorithm, se::blas::kDefaultComputePrecision, grad_x, grad_y); + /*c_shape=*/c_shape, + /*bias_shape_ptr=*/ + vector_bias_shape ? &vector_bias_shape.value() : nullptr, output_shape, + config.alpha_real(), config.alpha_imag(), config.beta(), + precision_algorithm, algorithm, precision, grad_x, grad_y); } -/*static*/ StatusOr GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); +absl::StatusOr GemmConfig::GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const { + auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout, + se::DeviceMemoryBase data) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(se::blas::DataType type, + se::gpu::AsBlasDataType(layout.dtype)); + return se::gpu::MatrixDescriptor{ + data, layout.leading_dim_stride, layout.batch_stride, type, + // BLAS is column-major by default. + (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor + ? se::blas::Transpose::kNoTranspose + : se::blas::Transpose::kTranspose)}; + }; + // TODO: make a local copy to prevent modification of layouts, + // but maybe we can modify them once instead during creation ? + se::gpu::MatrixLayout lhs = lhs_layout, rhs = rhs_layout, out = output_layout; - std::optional algorithm; - if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); - - bool grad_x = false; - bool grad_y = false; - auto attr_grad_x = op.getGradX(); - if (attr_grad_x) grad_x = attr_grad_x.value(); - auto attr_grad_y = op.getGradY(); - if (attr_grad_y) grad_y = attr_grad_y.value(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } + bool must_swap_operands = MakeOutputColumnMajor(lhs, rhs, out); + if (must_swap_operands) { + std::swap(lhs_buf, rhs_buf); } - return GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), - op.getBeta().convertToDouble(), algorithm, compute_precision, grad_x, - grad_y); + TF_ASSIGN_OR_RETURN(se::gpu::OutputMatrixDescriptor out_desc, + create_matrix_desc(out, out_buf)); + out_desc.batch_size = out.batch_size; + out_desc.m = out.num_rows; + out_desc.n = out.num_cols; + out_desc.k = lhs.num_cols; + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? + TF_ASSIGN_OR_RETURN(out_desc.compute_type, + se::gpu::GetBlasComputationType( + PrecisionConfig::ALG_UNSET, lhs.dtype, out.dtype, + se::blas::kDefaultComputePrecision)); + + TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc, + create_matrix_desc(lhs, lhs_buf)); + TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor rhs_desc, + create_matrix_desc(rhs, rhs_buf)); + + return DescriptorsTuple{lhs_desc, rhs_desc, out_desc, must_swap_operands}; } namespace { -// This struct contains the metadata of a matrix, e.g., its base address and -// dimensions. -struct MatrixDescriptor { - se::DeviceMemoryBase data; - int64_t leading_dim_stride; - int64_t batch_stride; - se::blas::Transpose transpose; - - template - se::DeviceMemory cast() const { - return se::DeviceMemory(data); - } -}; - -se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) { - // BLAS is column-major by default. - return (order == MatrixLayout::Order::kColumnMajor) - ? se::blas::Transpose::kNoTranspose - : se::blas::Transpose::kTranspose; -} - -MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout, - se::DeviceMemoryBase data) { - return MatrixDescriptor{ - data, - *layout.leading_dim_stride, - *layout.batch_stride, - AsBlasTranspose(layout.order), - }; -} - template -Status DoGemmWithAlgorithm( - int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, const MatrixDescriptor& rhs, - const MatrixDescriptor& output, se::DeviceMemoryBase workspace, Scale alpha, - Scale beta, se::Stream* stream, se::blas::AlgorithmType algorithm, - se::blas::ComputePrecision compute_precision, - const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result, se::blas::CallContext context) { +absl::Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, + Scale beta, se::Stream* stream, + PrecisionConfig::Algorithm precision_algorithm, + se::blas::AlgorithmType algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType(); PrimitiveType output_type = primitive_util::NativeToPrimitiveType(); - TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type, - se::gpu::GetBlasComputationType(lhs_type, output_type, - compute_precision)); + TF_ASSIGN_OR_RETURN( + se::blas::ComputationType computation_type, + se::gpu::GetBlasComputationType(precision_algorithm, lhs_type, + output_type, compute_precision)); se::DeviceMemory output_data(output.data); // Set a workspace for all Blas operations launched below. - se::blas::BlasSupport::ScopedWorkspace scoped_workspace( - stream->parent()->AsBlas(), &workspace); - - if (batch_size != 1) { - return stream->ThenBlasGemmStridedBatchedWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - computation_type, algorithm, numeric_options, profile_result, context); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + + if (output.batch_size != 1) { + return blas->BlasGemmStridedBatchedWithAlgorithm( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, computation_type, algorithm, numeric_options, + profile_result, context); } else { - return stream->ThenBlasGemmWithAlgorithm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, computation_type, algorithm, - numeric_options, profile_result, context); + return blas->BlasGemmWithAlgorithm( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, + computation_type, algorithm, numeric_options, profile_result, context); } } template -Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, const MatrixDescriptor& rhs, - const MatrixDescriptor& output, se::DeviceMemoryBase workspace, - Scale alpha, Scale beta, se::Stream* stream, - std::optional algorithm, - se::blas::ComputePrecision compute_precision, - const se::NumericOptions& numeric_options, - se::blas::ProfileResult* profile_result, - se::blas::CallContext context) { +absl::Status DoGemm(const se::gpu::MatrixDescriptor& lhs, + const se::gpu::MatrixDescriptor& rhs, + const se::gpu::OutputMatrixDescriptor& output, + se::DeviceMemoryBase workspace, Scale alpha, Scale beta, + se::Stream* stream, + PrecisionConfig::Algorithm precision_algorithm, + std::optional algorithm, + se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, + se::blas::ProfileResult* profile_result, + se::blas::CallContext context) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); se::DeviceMemory output_data(output.data); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } // Set a workspace for all Blas operations launched below. - se::blas::BlasSupport::ScopedWorkspace scoped_workspace( - stream->parent()->AsBlas(), &workspace); + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); -// TODO: enable DoGemmWithAlgorithm for ROCm ! -#if GOOGLE_CUDA if (algorithm) { return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, workspace, alpha, beta, stream, + lhs, rhs, output, workspace, alpha, beta, stream, precision_algorithm, *algorithm, compute_precision, numeric_options, profile_result, context); } -#endif - if (batch_size != 1) { - return stream->ThenBlasGemmStridedBatched( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), - rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, - output.leading_dim_stride, output.batch_stride, batch_size, - numeric_options, context); + if (output.batch_size != 1) { + return blas->BlasGemmStridedBatched( + stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, + alpha, lhs.cast(), lhs.leading_dim_stride, lhs.batch_stride, + rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, + &output_data, output.leading_dim_stride, output.batch_stride, + output.batch_size, numeric_options, context); } - return stream->ThenBlasGemm( - lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), - lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, numeric_options, context); + return blas->BlasGemm(stream, lhs.transpose, rhs.transpose, output.m, + output.n, output.k, alpha, lhs.cast(), + lhs.leading_dim_stride, rhs.cast(), + rhs.leading_dim_stride, beta, &output_data, + output.leading_dim_stride, numeric_options, context); } } // namespace -Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase workspace_buffer, bool deterministic_ops, - se::Stream* stream, - std::optional algorithm, - se::blas::ProfileResult* profile_result) { +absl::Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, + se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, + bool deterministic_ops, se::Stream* stream, + std::optional algorithm, + se::blas::ProfileResult* profile_result) { VLOG(2) << "Executing a GemmThunk"; - auto lhs_layout = MatrixLayout{config.lhs_layout}, - rhs_layout = MatrixLayout{config.rhs_layout}, - output_layout = MatrixLayout{config.output_layout}; - bool must_swap_operands = - se::gpu::MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout); - if (must_swap_operands) { - std::swap(lhs_buffer, rhs_buffer); - } + TF_ASSIGN_OR_RETURN( + GemmConfig::DescriptorsTuple desc, + config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, output_buffer)); - int64_t m = output_layout.num_rows; - int64_t n = output_layout.num_cols; - int64_t k = lhs_layout.num_cols; - MatrixDescriptor lhs = GetMatrixDesc(lhs_layout, lhs_buffer); - MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer); - MatrixDescriptor output = GetMatrixDesc(output_layout, output_buffer); - int64_t batch_size = output_layout.batch_size; se::NumericOptions numeric_options{ deterministic_ops, - /*allow_tf32=*/config.compute_precision <= 1}; + /*allow_tf32=*/IsTf32Allowed(config.precision_algorithm, + config.compute_precision)}; if (!algorithm) algorithm = config.algorithm; se::blas::CallContext context = se::blas::CallContext::kNone; if (config.grad_x) { - context = must_swap_operands ? se::blas::CallContext::kBackpropInput2 - : se::blas::CallContext::kBackpropInput1; + context = desc.operands_swapped ? se::blas::CallContext::kBackpropInput2 + : se::blas::CallContext::kBackpropInput1; } if (config.grad_y) { - context = must_swap_operands ? se::blas::CallContext::kBackpropInput1 - : se::blas::CallContext::kBackpropInput2; + context = desc.operands_swapped ? se::blas::CallContext::kBackpropInput1 + : se::blas::CallContext::kBackpropInput2; } - std::tuple operand_types{ - lhs_layout.dtype, rhs_layout.dtype, output_layout.dtype}; + std::tuple operand_types{config.lhs_layout.dtype, config.rhs_layout.dtype, + config.output_layout.dtype}; // Skip degenerate gemm with memzero. In general this is not safe, because it // will suppress NaN propagation, however cuBLAS internally has exactly the @@ -649,44 +683,47 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, // graphs, so we are making sure we do not trigger it). if (config.alpha.real() == 0.0 && config.alpha.imag() == 0.0 && config.beta == 0.0) { - stream->ThenMemZero(&output_buffer, output_buffer.size()); - return tsl::OkStatus(); - } - -#define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ - using NativeScaleType = \ - primitive_util::PrimitiveTypeToNative::type; \ - using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ - using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ - return DoGemm( \ - batch_size, m, n, k, lhs, rhs, output, workspace_buffer, \ - static_cast(config.alpha.real()), \ - static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result, context); \ - } - -#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ - if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ - using NativeScaleType = \ - primitive_util::PrimitiveTypeToNative::type; \ - using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ - using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ - return DoGemm( \ - batch_size, m, n, k, lhs, rhs, output, workspace_buffer, \ - static_cast(config.alpha), \ - static_cast(config.beta), stream, algorithm, \ - config.compute_precision, numeric_options, profile_result, context); \ - } - - if (output_layout.dtype == S32) { + return stream->MemZero(&output_buffer, output_buffer.size()); + } + +#define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ + if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ + using NativeScaleType = \ + primitive_util::PrimitiveTypeToNative::type; \ + using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ + using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ + return DoGemm( \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ + static_cast(config.alpha.real()), \ + static_cast(config.beta), stream, \ + config.precision_algorithm, algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ + } + +#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ + if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) { \ + using NativeScaleType = \ + primitive_util::PrimitiveTypeToNative::type; \ + using NativeAType = primitive_util::PrimitiveTypeToNative::type; \ + using NativeCType = primitive_util::PrimitiveTypeToNative::type; \ + return DoGemm( \ + desc.lhs, desc.rhs, desc.output, workspace_buffer, \ + static_cast(config.alpha), \ + static_cast(config.beta), stream, \ + config.precision_algorithm, algorithm, config.compute_precision, \ + numeric_options, profile_result, context); \ + } + + if (config.output_layout.dtype == S32) { if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo; + // TODO(tdanyluk): Investigate why don't we use the actual precision (and + // algorithm) here? Why do we use the default? return DoGemmWithAlgorithm( - batch_size, m, n, k, lhs, rhs, output, workspace_buffer, + desc.lhs, desc.rhs, desc.output, workspace_buffer, static_cast(config.alpha.real()), - static_cast(config.beta), stream, *algorithm, - se::blas::kDefaultComputePrecision, numeric_options, profile_result, - context); + static_cast(config.beta), stream, PrecisionConfig::ALG_UNSET, + *algorithm, se::blas::kDefaultComputePrecision, numeric_options, + profile_result, context); } TYPED_GEMM(F32, BF16, BF16, BF16) @@ -701,16 +738,17 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, #undef TYPED_GEMM #undef TYPED_GEMM_COMPLEX - return InternalError( + return Internal( "Unexpected GEMM dtype: %s %s %s", - primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype), - primitive_util::LowercasePrimitiveTypeName(output_layout.dtype)); + primitive_util::LowercasePrimitiveTypeName(config.lhs_layout.dtype), + primitive_util::LowercasePrimitiveTypeName(config.rhs_layout.dtype), + primitive_util::LowercasePrimitiveTypeName(config.output_layout.dtype)); } // namespace gpu namespace gpublas_lt { -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { +absl::StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -723,11 +761,12 @@ StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown Epilogue."); } } -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { +absl::StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { case GemmBackendConfig::DEFAULT: case GemmBackendConfig::RELU: @@ -740,45 +779,50 @@ StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue) { case GemmBackendConfig::BIAS_GELU_AUX: return true; default: - return InternalError("Unknown Epilogue."); + return Internal("Unknown Epilogue."); } } -StatusOr AsBlasLtEpilogue( - mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue) { +absl::StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue) { switch (epilogue) { - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: + case GemmBackendConfig::DEFAULT: return se::gpu::BlasLt::Epilogue::kDefault; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: + case GemmBackendConfig::RELU: return se::gpu::BlasLt::Epilogue::kReLU; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: + case GemmBackendConfig::GELU: return se::gpu::BlasLt::Epilogue::kGELU; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: + case GemmBackendConfig::GELU_AUX: return se::gpu::BlasLt::Epilogue::kGELUWithAux; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: + case GemmBackendConfig::BIAS: return se::gpu::BlasLt::Epilogue::kBias; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: + case GemmBackendConfig::BIAS_RELU: return se::gpu::BlasLt::Epilogue::kBiasThenReLU; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: + case GemmBackendConfig::BIAS_GELU: return se::gpu::BlasLt::Epilogue::kBiasThenGELU; - case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: + case GemmBackendConfig::BIAS_GELU_AUX: return se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux; + default: + return Internal("unexpected epilogue value"); } - return InternalError("unexpected epilogue value"); } } // namespace gpublas_lt -/*static*/ TritonGemmConfig TritonGemmConfig::FromProto( +/*static*/ absl::StatusOr TritonGemmConfig::FromProto( const AutotuneResult::TritonGemmKey& proto) { - TritonGemmConfig config; - config.block_m = proto.block_m(); - config.block_n = proto.block_n(); - config.block_k = proto.block_k(); - config.split_k = proto.split_k(); - config.num_stages = proto.num_stages(); - config.num_warps = proto.num_warps(); - return config; + // Sanity check to avoid loading incomplete data. + TF_RET_CHECK(proto.block_m() > 0); + TF_RET_CHECK(proto.block_n() > 0); + TF_RET_CHECK(proto.block_k() > 0); + TF_RET_CHECK(proto.split_k() > 0); + TF_RET_CHECK(proto.num_stages() > 0); + TF_RET_CHECK(proto.num_warps() > 0); + TF_RET_CHECK(proto.num_ctas() > 0); + + return TritonGemmConfig(proto.block_m(), proto.block_n(), proto.block_k(), + proto.split_k(), proto.num_stages(), + proto.num_warps(), proto.num_ctas()); } AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { @@ -789,6 +833,7 @@ AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { key.set_split_k(split_k); key.set_num_stages(num_stages); key.set_num_warps(num_warps); + key.set_num_ctas(num_ctas); return key; } @@ -796,7 +841,60 @@ std::string TritonGemmConfig::ToString() const { return absl::StrCat("{block_m:", block_m, ",block_n:", block_n, ",block_k:", block_k, ",split_k:", split_k, ",num_stages:", num_stages, ",num_warps:", num_warps, - "}"); + ",num_ctas:", num_ctas, "}"); +} + +absl::StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers(); + + int64_t contracting_size = 1; + for (int64_t dim : dot_dims.lhs_contracting_dimensions()) { + contracting_size *= lhs_shape.dimensions(dim); + } + + TF_ASSIGN_OR_RETURN( + std::vector lhs_non_contracting_dims, + GetNonContractingDims(lhs_shape, dot_dims.lhs_batch_dimensions(), + dot_dims.lhs_contracting_dimensions())); + int64_t lhs_non_contracting_size = 1; + for (int64_t dim : lhs_non_contracting_dims) { + lhs_non_contracting_size *= lhs_shape.dimensions(dim); + } + + TF_ASSIGN_OR_RETURN( + std::vector rhs_non_contracting_dims, + GetNonContractingDims(rhs_shape, dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions())); + int64_t rhs_non_contracting_size = 1; + for (int64_t dim : rhs_non_contracting_dims) { + rhs_non_contracting_size *= rhs_shape.dimensions(dim); + } + + return (rhs_non_contracting_size + lhs_non_contracting_size) * + contracting_size < + threshold; +} + +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot) { + if (!algorithm_util::IsSupportedByElementalIrEmitter( + dot.precision_config().algorithm())) { + return false; + } + + // Let us be conservative and only throw float dots at the emitters. + switch (dot.shape().element_type()) { + case F16: + case F32: + case BF16: + return true; + default: + return false; + } } } // namespace gpu diff --git a/xla/service/gpu/matmul_utils.h b/xla/service/gpu/matmul_utils.h index 4d329fea522a7..22d7f17813383 100644 --- a/xla/service/gpu/matmul_utils.h +++ b/xla/service/gpu/matmul_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MATMUL_UTILS_H_ #define XLA_SERVICE_GPU_MATMUL_UTILS_H_ +#include #include #include #include @@ -23,17 +24,16 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/shape.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #if TENSORFLOW_USE_ROCM @@ -44,7 +44,7 @@ namespace xla { namespace gpu { // Ordered non-contracting dimensions for a dot instruction operand. -StatusOr> GetNonContractingDims( +absl::StatusOr> GetNonContractingDims( const Shape& shape, absl::Span batch_dims, absl::Span contracting_dims); @@ -54,38 +54,47 @@ const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( const HloInstruction& dot, int operand_number); // Index of the only contracting dimension of dot instruction operand. -int64_t ContractingDimensionIndex(const HloInstruction& dot, - int operand_number); +absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, + int operand_number); // Index of the only non-contracting dimension of dot instruction operand. -int64_t NonContractingDimensionIndex(const HloInstruction& dot, - int operand_number); +absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + int operand_number); // Normalize shape to (batch, rows, columns) logical dimensions. -StatusOr GetBatchRowColumnShape(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); +absl::StatusOr GetBatchRowColumnShape( + const Shape& shape, absl::Span batch_dims, + absl::Span row_dims, absl::Span col_dims); // GPU folding rule for the `TransposeFolding` pass. -StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, - int64_t operand_idx); +absl::StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, + int64_t operand_idx); + +// Returns true if the sum of the sizes of the unbatched operand matrices +// for the dot is smaller than the given threshold. +absl::StatusOr IsMatrixMultiplicationTooSmallForRewriting( + const HloInstruction& dot, int64_t threshold); + +// Returns true if the backend can lower the dot. Currently the classical +// emitters cannot handle some dots, e.g., i8[] x i8[] -> i32[] dots, +// so we need to always use cuBLAS or Triton for those. +bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot); // extending plain MatrixLayout struct with creator functions struct MatrixLayout : public se::gpu::MatrixLayout { // Returns the matrix layout for a logical shape (batch, rows, columns). - static StatusOr For(const Shape& shape); + static absl::StatusOr For(const Shape& shape); // Returns the matrix layout with the given batch, row, col dimensions. - static StatusOr For(const Shape& shape, - absl::Span batch_dims, - absl::Span row_dims, - absl::Span col_dims); + static absl::StatusOr For(const Shape& shape, + absl::Span batch_dims, + absl::Span row_dims, + absl::Span col_dims); // Returns the matrix layout for the output. - static StatusOr For(const Shape& shape, - size_t lhs_num_batch_dims, - size_t lhs_num_row_dims, - size_t rhs_num_batch_dims, - size_t rhs_num_col_dims); + static absl::StatusOr For(const Shape& shape, + size_t lhs_num_batch_dims, + size_t lhs_num_row_dims, + size_t rhs_num_batch_dims, + size_t rhs_num_col_dims); }; struct GemmConfig : public se::gpu::GemmConfig { @@ -97,86 +106,63 @@ struct GemmConfig : public se::gpu::GemmConfig { static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024; // 32 MiB static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024; // 4 MiB - static StatusOr For(const HloInstruction* gemm); - static StatusOr For(mlir::lmhlo_gpu::GEMMOp op); + static absl::StatusOr For(const HloInstruction* gemm); - static StatusOr For( + static absl::StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& output_shape, double alpha_real, double alpha_imag, double beta, + PrecisionConfig::Algorithm precision_algorithm, std::optional algorithm, int64_t compute_precision, bool grad_x, bool grad_y); // As above with additional `c_shape` and `bias_shape_ptr` parameter, both // which are only necessarily for F8 gemms. - static StatusOr For( + static absl::StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, absl::Span rhs_batch_dims, absl::Span rhs_contracting_dims, const Shape& c_shape, const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real, - double alpha_imag, double beta, std::optional algorithm, - int64_t compute_precision, bool grad_x, bool grad_y); - - template ::value || - std::is_same::value>> - static StatusOr For(CublasLtMatmulMaybeF8Op op) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().has_value()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.value()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } - } - } - - Shape bias_shape; - if (op.getBias() != nullptr) { - bias_shape = GetShape(op.getBias()); - } - return GemmConfig::For( - GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), - op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()), - op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(), - op.getAlgorithm(), compute_precision, /*grad_x=*/false, - /*grad_y=*/false); - } + double alpha_imag, double beta, + PrecisionConfig::Algorithm precision_algorithm, + std::optional algorithm, int64_t compute_precision, bool grad_x, + bool grad_y); + + struct DescriptorsTuple { + se::gpu::MatrixDescriptor lhs; + se::gpu::MatrixDescriptor rhs; + se::gpu::OutputMatrixDescriptor output; + bool operands_swapped; + }; + absl::StatusOr GetMatrixDescriptors( + se::DeviceMemoryBase lhs_buf, se::DeviceMemoryBase rhs_buf, + se::DeviceMemoryBase out_buf) const; }; // Run the given GEMM instruction `gemm` subject to the configuration // in `gemm_config` and the passed buffers. // // If `algorithm` is provided, it overrides the one specified in `config`. -Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase workspace_buffer, bool deterministic_ops, - se::Stream* stream, - std::optional algorithm = std::nullopt, - se::blas::ProfileResult* profile_result = nullptr); +absl::Status RunGemm( + const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase workspace_buffer, bool deterministic_ops, + se::Stream* stream, + std::optional algorithm = std::nullopt, + se::blas::ProfileResult* profile_result = nullptr); namespace gpublas_lt { -StatusOr EpilogueAddsVectorBias(GemmBackendConfig_Epilogue epilogue); -StatusOr EpilogueHasAuxiliaryOutput(GemmBackendConfig_Epilogue epilogue); +absl::StatusOr EpilogueAddsVectorBias( + GemmBackendConfig_Epilogue epilogue); +absl::StatusOr EpilogueHasAuxiliaryOutput( + GemmBackendConfig_Epilogue epilogue); -StatusOr AsBlasLtEpilogue( - mlir::lmhlo_gpu::CublasLtMatmulEpilogue epilogue); +absl::StatusOr AsBlasLtEpilogue( + GemmBackendConfig_Epilogue epilogue); } // namespace gpublas_lt @@ -185,29 +171,41 @@ StatusOr AsBlasLtEpilogue( struct TritonGemmConfig { constexpr TritonGemmConfig() = default; constexpr TritonGemmConfig(int block_m, int block_n, int block_k, int split_k, - int num_stages, int num_warps) + int num_stages, int num_warps, int num_ctas = 1) : block_m(block_m), block_n(block_n), block_k(block_k), split_k(split_k), num_stages(num_stages), - num_warps(num_warps) {} - + num_warps(num_warps), + num_ctas(num_ctas) {} int block_m = 0; int block_n = 0; int block_k = 0; int split_k = 0; int num_stages = 0; int num_warps = 0; + // Number of blocks in a block cluster. + int num_ctas = 0; + + // When adding new members, please update all methods, such as ToTuple, + // FromProto, ToProto, ToString, etc. Updating ToTuple is not enough. + // Please also add new members to AutotuneResult::TritonGemmKey in + // autotuning.proto. Also kVersion has to be incremented in autotuner_util.cc + // and all the autotuning results stored in tests, repos, etc. will have to + // be updated. private: auto ToTuple() const { return std::make_tuple(block_m, block_n, block_k, split_k, num_stages, - num_warps); + num_warps, num_ctas); } public: - static TritonGemmConfig FromProto(const AutotuneResult::TritonGemmKey& proto); + // Creates a TritonGemmConfig from the supplied proto, doing a simple sanity + // check. + static absl::StatusOr FromProto( + const AutotuneResult::TritonGemmKey& proto); AutotuneResult::TritonGemmKey ToProto() const; std::string ToString() const; diff --git a/xla/service/gpu/matmul_utils_test.cc b/xla/service/gpu/matmul_utils_test.cc index 09b98a1007603..c3ccdb517438b 100644 --- a/xla/service/gpu/matmul_utils_test.cc +++ b/xla/service/gpu/matmul_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,18 @@ limitations under the License. #include "xla/service/gpu/matmul_utils.h" +#include +#include #include #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -225,6 +230,115 @@ TEST(GetMatrixLayoutTest, BatchInMostMinorPhysicalDimension) { EXPECT_FALSE(MatrixLayout::For(shape).ok()); } +using GetMatrixSizeRewriteThresholdTest = HloTestBase; + +TEST_F(GetMatrixSizeRewriteThresholdTest, MatMulTooSmallForRewrite) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = f32[100,30,3] parameter(0) + y = f32[100,3,3] parameter(1) + ROOT dot = f32[100,30,3] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_THAT(IsMatrixMultiplicationTooSmallForRewriting(*dot, 100), + IsOkAndHolds(true)); +} + +TEST_F(GetMatrixSizeRewriteThresholdTest, MatMulSupportedByClassicalEmitters) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = f32[100,30,3] parameter(0) + y = f32[100,3,3] parameter(1) + ROOT dot = f32[100,30,3] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_TRUE(IsDotSupportedByClassicalEmitters(*dot)); +} + +TEST_F(GetMatrixSizeRewriteThresholdTest, + MatMulUnsupportedByClassicalEmitters) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = s8[100,30,3] parameter(0) + y = s8[100,3,3] parameter(1) + ROOT dot = s32[100,30,3] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_FALSE(IsDotSupportedByClassicalEmitters(*dot)); +} + +TEST_F(GetMatrixSizeRewriteThresholdTest, MatMulLeftLargeEnoughForRewrite) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = f32[50,2] parameter(0) + y = f32[2,2] parameter(1) + ROOT dot = f32[50,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_THAT(IsMatrixMultiplicationTooSmallForRewriting(*dot, 100), + IsOkAndHolds(false)); +} + +TEST_F(GetMatrixSizeRewriteThresholdTest, MatMulRightLargeEnoughForRewrite) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = f32[2,2] parameter(0) + y = f32[2,50] parameter(1) + ROOT dot = f32[2,50] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_THAT(IsMatrixMultiplicationTooSmallForRewriting(*dot, 100), + IsOkAndHolds(false)); +} + +TEST_F(GetMatrixSizeRewriteThresholdTest, MatMulTogetherLargeEnoughForRewrite) { + const char* hlo_text = R"( +HloModule DotFuncModule + +ENTRY DotFunc { + x = f32[4,16] parameter(0) + y = f32[16,4] parameter(1) + ROOT dot = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + auto dot = module->entry_computation()->root_instruction(); + EXPECT_THAT(IsMatrixMultiplicationTooSmallForRewriting(*dot, 100), + IsOkAndHolds(false)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/memset_thunk.cc b/xla/service/gpu/memset_thunk.cc deleted file mode 100644 index 5b2380c0a425c..0000000000000 --- a/xla/service/gpu/memset_thunk.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/memset_thunk.h" - -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { - se::DeviceMemoryBase dest_data = - params.buffer_allocations->GetDeviceAddress(dest_); - params.stream->ThenMemZero(&dest_data, dest_data.size()); - return OkStatus(); -} - -Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) { - se::DeviceMemoryBase dest_data = - params.buffer_allocations->GetDeviceAddress(dest_); - params.stream->ThenMemset32(&dest_data, value_, dest_data.size()); - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/memset_thunk.h b/xla/service/gpu/memset_thunk.h deleted file mode 100644 index 4529571b0ebca..0000000000000 --- a/xla/service/gpu/memset_thunk.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MEMSET_THUNK_H_ -#define XLA_SERVICE_GPU_MEMSET_THUNK_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/stream_executor.h" - -// This file contains thunks that set a buffer's elements to a particular value. -// This can be faster than emitting a kernel to set the elements. - -namespace xla { -namespace gpu { - -// Thunk that zeroes out a given chunk of memory. -class MemzeroThunk : public Thunk { - public: - explicit MemzeroThunk(ThunkInfo thunk_info, - const BufferAllocation::Slice& dest, - mlir::Value dest_value) - : Thunk(Kind::kMemzero, thunk_info), - dest_(dest), - dest_value_(dest_value) {} - - Status ExecuteOnStream(const ExecuteParams& params) override; - - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - dest_value_ = nullptr; - } - - const BufferAllocation::Slice& destination() const { return dest_; } - mlir::Value dest_value() const { return dest_value_; } - - private: - const BufferAllocation::Slice dest_; - mlir::Value dest_value_; -}; - -// Thunk that sets a given chunk of memory to a particular 32-bit value. The -// destination chunk must have size divisible by 32 bits. -class Memset32BitValueThunk : public Thunk { - public: - explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32_t value, - const BufferAllocation::Slice& dest, - mlir::Value dest_value) - : Thunk(Kind::kMemset32BitValue, thunk_info), - value_(value), - dest_(dest), - dest_value_(dest_value) {} - - Status ExecuteOnStream(const ExecuteParams& params) override; - - void ClearCompileTimeInfo() override { - Thunk::ClearCompileTimeInfo(); - dest_value_ = nullptr; - } - - const BufferAllocation::Slice& destination() const { return dest_; } - uint32_t value() const { return value_; } - mlir::Value dest_value() const { return dest_value_; } - - private: - const uint32_t value_; - const BufferAllocation::Slice dest_; - mlir::Value dest_value_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_MEMSET_THUNK_H_ diff --git a/xla/service/gpu/metrics.cc b/xla/service/gpu/metrics.cc index 44fff398f23ac..6d42b92d8bccb 100644 --- a/xla/service/gpu/metrics.cc +++ b/xla/service/gpu/metrics.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/metrics.h b/xla/service/gpu/metrics.h index d660743838a26..c3579dba6e4cf 100644 --- a/xla/service/gpu/metrics.h +++ b/xla/service/gpu/metrics.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/mock_nccl_topo_config.h b/xla/service/gpu/mock_nccl_topo_config.h new file mode 100644 index 0000000000000..7125a06475654 --- /dev/null +++ b/xla/service/gpu/mock_nccl_topo_config.h @@ -0,0 +1,296 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MOCK_NCCL_TOPO_CONFIG_H_ +#define XLA_SERVICE_GPU_MOCK_NCCL_TOPO_CONFIG_H_ + +namespace xla { +namespace gpu { +// Nccl device topology info generated by the NCCL_TOPO_DUMP_FILE of the Nccl +// library. See +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-topo-dump-file +// for more details. +// kGCPA3 is for GCP A3 VM. +// kNvidia is for Nvidia A100 VM +const char kGCPA3[] = R"( + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +)"; +const char kNvidia[] = R"( + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +)"; +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MOCK_NCCL_TOPO_CONFIG_H_ diff --git a/xla/service/gpu/mock_nccl_utils.cc b/xla/service/gpu/mock_nccl_utils.cc new file mode 100644 index 0000000000000..c3390da949c02 --- /dev/null +++ b/xla/service/gpu/mock_nccl_utils.cc @@ -0,0 +1,846 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/mock_nccl_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#include "third_party/gpus/cuda/include/vector_types.h" +#include "third_party/gpus/nccl/graph/topo.h" +#include "third_party/gpus/nccl/graph/xml.h" +#include "third_party/gpus/nccl/include/alloc.h" +#include "third_party/gpus/nccl/include/comm.h" +#include "third_party/gpus/nccl/include/graph.h" +#include "third_party/gpus/nccl/include/info.h" +#include "third_party/gpus/nccl/include/nccl_common.h" +#include "third_party/nccl/nccl.h" +#include "xla/debug_options_flags.h" +#include "xla/executable_run_options.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/mock_nccl_topo_config.h" +#include "xla/service/gpu/mock_nccl_xml.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/sleep_kernel.h" +#include "xla/service/lockable.h" +#include "xla/service/rendezvous.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/stream.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +//==-----------------------------------------------------------------------===// +// Macros to return or warn on NCCL errors. +//==-----------------------------------------------------------------------===// + +static absl::Status ToStatus(ncclResult_t s, const char* file, int64_t line, + const char* expr) { + if (s == ncclSuccess) return absl::OkStatus(); + + return absl::InternalError(absl::StrFormat( + "%s:%d: NCCL operation %s failed: %s." + " Last NCCL warning(error) log entry (may be unrelated) '%s'.", + file, line, expr, ncclGetErrorString(s), ncclGetLastError(nullptr))); +} + +#define XLA_NCCL_STATUS(expr) \ + xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr) + +#define XLA_NCCL_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +#define XLA_NCCL_LOG_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + LOG(ERROR) << s.ToString(); \ + } \ + } while (0) + +//==-----------------------------------------------------------------------===// + +static absl::StatusOr ToNcclDataType(PrimitiveType element_type, + Thunk::Kind reduction_op) { + switch (element_type) { + case S8: + case F8E5M2: + case F8E4M3FN: + return ncclInt8; + case PRED: + case U8: + return ncclUint8; + case S32: + return ncclInt32; + case U32: + return ncclUint32; + case S64: + return ncclInt64; + case U64: + return ncclUint64; + case F16: + return ncclFloat16; + case F32: + case C64: + return ncclFloat32; + case F64: + case C128: + return ncclFloat64; + case S16: + case U16: + // For all-reduce and reduce-scatter, we expect 16 bit integer types to be + // promoted to 32-bit. + if (reduction_op == Thunk::kNcclAllReduce || + reduction_op == Thunk::kNcclAllReduceStart || + reduction_op == Thunk::kNcclReduceScatter) { + return tsl::errors::InvalidArgument(absl::StrFormat( + "Unsupported data type: %s", PrimitiveType_Name(element_type))); + } + // For collectives that just move data around, we can use ncclFloat16 for + // 16-bit integer data types. + return ncclFloat16; +#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM + case BF16: + return ncclBfloat16; +#endif + default: + return tsl::errors::InvalidArgument(absl::StrFormat( + "Unsupported data type: %s", PrimitiveType_Name(element_type))); + } +} + +static absl::StatusOr> +ToNcclDataTypeAndCountMultiplier(PrimitiveType element_type, + Thunk::Kind reduction_op) { + TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, + ToNcclDataType(element_type, reduction_op)); + bool is_complex = primitive_util::IsComplexType(element_type); + return std::make_pair(dtype, is_complex ? 2 : 1); +} + +using ncclInfo_t = ncclInfo*; + +absl::StatusOr GetNcclDataTypeSize(ncclDataType_t dtype) { + switch (dtype) { + case ncclInt8: + case ncclUint8: + return 1; + case ncclInt32: + case ncclUint32: + return 4; + case ncclInt64: + case ncclUint64: + return 8; + case ncclFloat16: + return 2; + case ncclFloat32: + return 4; + case ncclFloat64: + return 8; +#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM + case ncclBfloat16: + return 2; +#endif + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported nccl data type: %d", dtype)); + } +} + +absl::StatusOr ToNcclFunctionType(Thunk::Kind reduce_op) { + switch (reduce_op) { + case Thunk::kNcclAllReduce: + return ncclFuncAllReduce; + case Thunk::kNcclAllGather: + return ncclFuncAllGather; + case Thunk::kNcclReduceScatter: + return ncclFuncReduceScatter; + case Thunk::kNcclSend: + return ncclFuncSend; + case Thunk::kNcclRecv: + return ncclFuncRecv; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported nccl function type: %d", reduce_op)); + } +} + +absl::Status LaunchSleepKernel(se::StreamExecutor* executor, + se::gpu::GpuStreamHandle gpu_stream, + ncclInfo_t info, int64_t sleep_duration) { + void* kernel = GetSleepKernel(); + int64_t clock_cycles = + sleep_duration * executor->GetDeviceDescription().clock_rate_ghz(); + void* kernel_args[] = {&clock_cycles}; + dim3 gridDim = {1, 1, 1}; + dim3 blockDim = {512, 1, 1}; + cudaError_t launch_status = + cudaLaunchKernel(kernel, gridDim, blockDim, kernel_args, 0, gpu_stream); + if (launch_status != cudaSuccess) { + return absl::InternalError(absl::StrCat("Failed to launch kernel: ", + cudaGetErrorString(launch_status))); + } + return absl::OkStatus(); +} + +inline absl::Status MockNcclInfoSetDerived(ncclInfo_t info, int nRanks) { + TF_ASSIGN_OR_RETURN(int dtype_size, GetNcclDataTypeSize(info->datatype)); + info->nBytes = info->count * dtype_size; + if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) { + info->count = info->nBytes; + info->datatype = ncclInt8; + } + if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) + info->nBytes *= nRanks; // count is per rank + return absl::OkStatus(); +} + +// Return estimated sleep time in nano seconds for simulating the nccl +// collective calls +absl::StatusOr GetMockNcclSleepTime(size_t count, + ncclDataType_t datatype, + ncclComm_t comm, + cudaStream_t stream, + ncclInfo_t info) { + info->count = count; + info->datatype = datatype; + info->nChannels = 1; + info->algorithm = -1; + info->protocol = -1; + + TF_RETURN_IF_ERROR(MockNcclInfoSetDerived(info, comm->nRanks)); + + int numPipeOps = 1; // number of pipelined ops. Used to adjust latency. + // Assume 1 for simplicity. + float minTime = std::numeric_limits::infinity(); + float time = 0.0f; + if (info->coll == ncclFuncAllReduce) { + XLA_NCCL_RETURN_IF_ERROR(ncclTopoGetAlgoTime( + info, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, numPipeOps, &time)); + info->algorithm = NCCL_ALGO_RING; + info->protocol = NCCL_PROTO_SIMPLE; + minTime = time; + } else { + for (int p = 0; p < 3; p++) { + XLA_NCCL_RETURN_IF_ERROR( + ncclTopoGetAlgoTime(info, NCCL_ALGO_RING, p, numPipeOps, &time)); + if (time > 0 && time < minTime) { + info->algorithm = NCCL_ALGO_RING; + info->protocol = p; + minTime = time; + } + } + } + return ceil(minTime * 1000); +} + +// Create the mock nccl communicator assuming all hosts have the same hardwares. +// We first create a local nccl communicator for gpus within a single host; then +// together with the input clique, we re-run nccl algorithms to construct the +// target nccl topology graphs. +absl::StatusOr LockMockNcclComm( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model) { + GlobalDeviceId global_device_id = params.global_device_id; + + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(global_device_id, *params.device_assn, + replica_groups, group_mode)); + + if (IsGlobalNcclConfig() && + (participants.size() != params.device_assn->replica_count())) { + return InvalidArgument( + "Partial replica groups are not allowed when using NCCL_COMM_ID " + "environment configuration."); + } + + std::vector local_devices; + if (params.global_device_id_map) { + local_devices.reserve(params.global_device_id_map->size()); + for (const auto& entry : *params.global_device_id_map) { + local_devices.push_back(entry.second); + } + } else { + local_devices = participants; + } + TF_ASSIGN_OR_RETURN( + const NcclCliqueIdCallback* clique_id_callback, + GetNcclCliqueIdCallback(params.nccl_clique_id_callback, true)); + + size_t num_local_participants = GetNumLocalParticipants( + participants, params.global_device_id_map ? &local_devices : nullptr); + + auto global_it = absl::c_find(participants, global_device_id); + TF_RET_CHECK(global_it != participants.end()); + int global_rank = global_it - participants.begin(); + + if (global_rank != 0) { + return absl::CancelledError("Only mock nccl call for gpu rank 0"); + } + + return AcquireMockNcclComm(params.run_id, OpId(op_id), + std::move(participants), std::move(local_devices), + num_local_participants, *clique_id_callback, + global_rank, stream_id, false, topo_model); +} + +absl::Status RunMockNcclCollectives(NcclApi* nccl_api, + std::vector& buffers, + se::Stream& stream, + NcclApi::NcclCommHandle comm, + Thunk::Kind reduce_op) { + ncclComm_t mock_comm = reinterpret_cast(comm); + + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing the mock nccl collective call from device ordinal: " + << device_ordinal; + se::StreamExecutor* executor = stream.parent(); + se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); + ncclInfo info; + TF_ASSIGN_OR_RETURN(info.coll, ToNcclFunctionType(reduce_op)); + info.comm = reinterpret_cast(mock_comm); + info.stream = gpu_stream; + + int64_t total_element_count = 0; + ncclDataType_t previous_dtype = ncclNumTypes; + int64_t sleep_duration = 0; + for (size_t i = 0; i < buffers.size(); ++i) { + DeviceBufferPair& buffer = buffers[i]; + PrimitiveType element_type = buffer.element_type; + TF_ASSIGN_OR_RETURN( + auto dtype_and_multiplier, + ToNcclDataTypeAndCountMultiplier(element_type, reduce_op)); + ncclDataType_t dtype = dtype_and_multiplier.first; + int64_t element_count = buffer.element_count * dtype_and_multiplier.second; + if (reduce_op == Thunk::kNcclReduceScatter) + element_count = element_count / mock_comm->nRanks; + if (i == 0 || dtype == previous_dtype) { + previous_dtype = dtype; + total_element_count += element_count; + continue; + } + + TF_ASSIGN_OR_RETURN(sleep_duration, GetMockNcclSleepTime( + total_element_count, previous_dtype, + mock_comm, gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + total_element_count = element_count; + previous_dtype = dtype; + } + + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(total_element_count, previous_dtype, + mock_comm, gpu_stream, &info)); + + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + VLOG(3) << "Done performing the mock nccl collective call for ordinal: " + << device_ordinal; + return absl::OkStatus(); +} + +absl::Status RunMockNcclAllToAll(NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, + NcclApi::NcclCommHandle comm) { + ncclComm_t mock_comm = reinterpret_cast(comm); + + se::StreamExecutor* executor = stream.parent(); + se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); + int num_participants = mock_comm->nRanks; + + ncclInfo info; + info.comm = mock_comm; + info.stream = gpu_stream; + + int64_t sleep_duration = 0; + + // AllToAll can operate in two modes. Either it specifies a split dimension, + // in which case inputs are split and outputs concatenated in that dimension + // (here, we only support dimension 0), or it takes a list of inputs + // and produces a tuple of outputs. + if (has_split_dimension) { + for (size_t i = 0; i < buffers.size(); ++i) { + DeviceBufferPair& buffer = buffers[i]; + const uint8_t* send_buffer = + static_cast(buffer.source_buffer.opaque()); + uint8_t* recv_buffer = + static_cast(buffer.destination_buffer.opaque()); + + TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, + ToNcclDataTypeAndCountMultiplier( + buffer.element_type, Thunk::kNcclAllToAll)); + ncclDataType_t dtype = dtype_and_multiplier.first; + int64_t element_count = + buffer.element_count * dtype_and_multiplier.second; + + TF_RET_CHECK(element_count % num_participants == 0) + << "Buffer was not an exact multiple of the number of participants."; + size_t chunk_elements = element_count / num_participants; + size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType( + buffer.element_type); + for (int rank = 0; rank < num_participants; ++rank) { + VLOG(3) << absl::StreamFormat( + "Calling mock ncclSend(sendbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + send_buffer + rank * chunk_bytes, chunk_elements, rank, + static_cast(mock_comm), gpu_stream); + info.coll = ncclFuncSend; + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(chunk_elements, dtype, + mock_comm, gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + + VLOG(3) << absl::StreamFormat( + "Calling mock ncclRecv(recvbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + recv_buffer + rank * chunk_bytes, chunk_elements, rank, + static_cast(mock_comm), gpu_stream); + + info.coll = ncclFuncRecv; + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(chunk_elements, dtype, + mock_comm, gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + } + } + } else { + TF_RET_CHECK(buffers.size() == num_participants) + << "Number of inputs didn't match the number of participants."; + for (size_t i = 0; i < buffers.size(); ++i) { + DeviceBufferPair& buffer = buffers[i]; + const uint8_t* send_buffer = + static_cast(buffer.source_buffer.opaque()); + uint8_t* recv_buffer = + static_cast(buffer.destination_buffer.opaque()); + + TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, + ToNcclDataTypeAndCountMultiplier( + buffer.element_type, Thunk::kNcclAllToAll)); + ncclDataType_t dtype = dtype_and_multiplier.first; + int64_t element_count = + buffer.element_count * dtype_and_multiplier.second; + + VLOG(3) << absl::StreamFormat( + "Calling mock ncclSend(sendbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + send_buffer, element_count, i, static_cast(mock_comm), + gpu_stream); + + info.coll = ncclFuncSend; + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(element_count, dtype, mock_comm, + gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + + VLOG(3) << absl::StreamFormat( + "Calling mock ncclRecv(recvbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + recv_buffer, element_count, i, static_cast(mock_comm), + gpu_stream); + + info.coll = ncclFuncRecv; + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(element_count, dtype, mock_comm, + gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + } + } + + VLOG(3) << "Done performing mock all-to-all "; + return absl::OkStatus(); +} + +absl::Status RunMockCollectivePermute( + NcclApi* nccl_api, NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id) { + ncclComm_t mock_comm = reinterpret_cast(comm); + + se::StreamExecutor* executor = stream.parent(); + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing collective permute from device ordinal: " + << device_ordinal << "current_id " << current_id; + + const std::optional source_id = source_target.source; + const std::optional target_id = source_target.target; + + se::DeviceMemoryBase src_addr = buffer.source_buffer; + se::DeviceMemoryBase dest_addr = buffer.destination_buffer; + + VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d", + device_string, current_id, + source_id.value_or(-1), target_id.value_or(-1)); + + TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, + ToNcclDataTypeAndCountMultiplier( + buffer.element_type, Thunk::kNcclCollectivePermute)); + ncclDataType_t dtype = dtype_and_multiplier.first; + int64_t element_count = buffer.element_count * dtype_and_multiplier.second; + + se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); + ncclInfo info; + info.comm = mock_comm; + info.stream = gpu_stream; + + int64_t sleep_duration = 0; + + // Send source buffer to target peer if needed. + if (target_id) { + info.coll = ncclFuncSend; + VLOG(3) << absl::StreamFormat( + "%s : Calling mock ncclSend(sendbuff=%p, count=%d, peer=%d " + "comm=%p, stream=%p)", + device_string, src_addr.opaque(), element_count, *target_id, + static_cast(mock_comm), gpu_stream); + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(element_count, dtype, mock_comm, + gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + } + + // Receive data from the source peer to the destination buffer. + if (source_id) { + info.coll = ncclFuncRecv; + VLOG(3) << absl::StreamFormat( + "%s : Calling mock ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " + "stream=%p)", + device_string, dest_addr.opaque(), element_count, *source_id, + static_cast(mock_comm), gpu_stream); + TF_ASSIGN_OR_RETURN(sleep_duration, + GetMockNcclSleepTime(element_count, dtype, mock_comm, + gpu_stream, &info)); + TF_RETURN_IF_ERROR( + LaunchSleepKernel(executor, gpu_stream, &info, sleep_duration)); + } + + VLOG(3) << "Done performing the mock nccl collective call for ordinal: " + << device_ordinal; + + if (!source_id) { + // If there is no source peer, i.e. no one send us any data, zero out dest + // buffer. + VLOG(3) << absl::StreamFormat( + "%s : mock collective-Permute: Issuing MemZero", device_string); + return stream.MemZero(&dest_addr, dest_addr.size()); + } + return absl::OkStatus(); +} + +namespace { +void CheckNcclAsyncError(NcclComm& lockable_comm) { + NcclApi::NcclCommHandle comm = *lockable_comm.Acquire(); + if (comm == nullptr) return; + + absl::Status status = NcclApi::Default()->CommGetAsyncError(comm); + if (!status.ok()) LOG(ERROR) << status; +} + +struct NcclCliqueState { + NcclCliqueId clique_id; + int64_t run_id = -1; + + // `mu` guards `communicators` and `status` during initialization. + // Once `ready` has been notified, the communicators may be accessed without + // synchronization. + absl::Mutex mu; + absl::Notification ready; + absl::Status status; + absl::flat_hash_map> communicators; +}; + +using NcclClique = Lockable; + +struct NcclCliques { + NcclClique& operator[](const NcclCliqueKey& key) { + absl::MutexLock lock(&mu); + return cliques[key]; + } + + absl::Mutex mu; + absl::node_hash_map cliques ABSL_GUARDED_BY(mu); +}; + +absl::StatusOr ToNcclUniqueId(const std::string& id_str) { + static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES, + "NCCL_UNIQUE_ID_BYTES"); + + TF_RET_CHECK(id_str.size() == NCCL_UNIQUE_ID_BYTES); + ncclUniqueId id; + absl::c_copy(id_str, id.internal); + return id; +} + +absl::StatusOr> AcquireNcclClique( + RunId run_id, OpId op_id, NcclCliqueKey clique_key, + const NcclCliqueIdCallback& clique_id_callback, + size_t num_local_participants, bool may_skip_rendezvous) { + static auto& cliques = *new NcclCliques; + + VLOG(2) << "AcquireNcclClique Rendezvous key (clique_key:" + << clique_key.ToString() << ", run" << run_id.ToString() << ", op" + << op_id.value() << ")"; + + auto rendezvous_key = std::make_tuple(run_id, op_id, std::move(clique_key)); + + int64_t terminate_timeout = xla::GetDebugOptionsFromFlags() + .xla_gpu_nccl_termination_timeout_seconds(); + + return RendezvousSingle>( + "acquire mock NCCL clique", rendezvous_key, num_local_participants, + [&]() -> absl::StatusOr { + const NcclCliqueKey& clique_key = std::get<2>(rendezvous_key); + NcclClique::Lock clique = cliques[clique_key].Acquire(); + clique->run_id = run_id.ToInt(); + return clique; + }, + /*warn_stuck_timeout=*/absl::Seconds(10), + (terminate_timeout >= 0) ? absl::Seconds(terminate_timeout) + : absl::InfiniteDuration()); +} + +absl::Status InitializeMockNcclCostModel( + int nRanks, int rank, int num_local_participants, + absl::Span> local_ranks, + GpuExecutableRunOptions::MockNcclTopoModel topo_model, + ncclComm_t* comm_ptr) { + XLA_NCCL_RETURN_IF_ERROR(ncclCalloc(comm_ptr, 1)); + ncclComm_t comm = *comm_ptr; + comm->nChannels = 1; + comm->nRanks = nRanks; + comm->rank = rank; + absl::string_view xml_str; + switch (topo_model) { + case GpuExecutableRunOptions::MockNcclTopoModel::kGCPA3: + comm->collNetSupport = false; + comm->nvlsSupport = false; + comm->minCompCap = comm->maxCompCap = stream_executor:: + CudaComputeCapability::CudaComputeCapabilities::HOPPER; + XLA_NCCL_RETURN_IF_ERROR(ncclCalloc(&comm->peerInfo, nRanks + 1)); + xml_str = kGCPA3; + break; + case GpuExecutableRunOptions::MockNcclTopoModel::kNvidia: + comm->collNetSupport = false; + comm->nvlsSupport = false; + comm->minCompCap = comm->maxCompCap = stream_executor:: + CudaComputeCapability::CudaComputeCapabilities::AMPERE; + XLA_NCCL_RETURN_IF_ERROR(ncclCalloc(&comm->peerInfo, nRanks + 1)); + xml_str = kNvidia; + break; + default: + return absl::InvalidArgumentError("Unknown MockNcclTopoModel"); + } + + auto xml = std::make_unique(); + TF_RETURN_IF_ERROR(MockTopoGetXml(xml_str, xml.get())); + TF_RETURN_IF_ERROR(MockNcclTopoUpdateXml(local_ranks, xml.get())); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoTrimXml(xml.get())); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoGetSystemFromXml(xml.get(), &comm->topo)); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoComputePaths(comm->topo, nullptr)); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoTrimSystem(comm->topo, comm)); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoComputePaths(comm->topo, nullptr)); + XLA_NCCL_RETURN_IF_ERROR(ncclTopoSearchInit(comm->topo)); + + ncclTopoGraph ringGraph; + ncclTopoGraph treeGraph; + ncclTopoGraph collNetGraph; + ncclTopoGraph nvlsGraph; + ncclTopoGraph* graphs[] = {&treeGraph, &ringGraph, &collNetGraph, + &collNetGraph, &nvlsGraph, &nvlsGraph}; + + // Get rings and trees + ringGraph.id = 0; + ringGraph.pattern = NCCL_TOPO_PATTERN_RING; + ringGraph.collNet = 0; + ringGraph.minChannels = 1; + ringGraph.maxChannels = MAXCHANNELS / 2; + XLA_NCCL_RETURN_IF_ERROR(ncclTopoCompute(comm->topo, &ringGraph)); + + treeGraph.id = 1; + treeGraph.pattern = NCCL_TOPO_PATTERN_BALANCED_TREE; + treeGraph.collNet = 0; + treeGraph.minChannels = ringGraph.nChannels; + treeGraph.maxChannels = ringGraph.nChannels; + XLA_NCCL_RETURN_IF_ERROR(ncclTopoCompute(comm->topo, &treeGraph)); + + collNetGraph.id = 2; + collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE; + collNetGraph.collNet = 1; + collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels; + if (comm->collNetSupport) { + XLA_NCCL_RETURN_IF_ERROR(ncclTopoCompute(comm->topo, &collNetGraph)); + } else { + collNetGraph.nChannels = 0; + } + + nvlsGraph.id = 3; + nvlsGraph.pattern = NCCL_TOPO_PATTERN_NVLS; + nvlsGraph.collNet = 0; + nvlsGraph.minChannels = 1; + nvlsGraph.maxChannels = MAXCHANNELS; + if (comm->nvlsSupport) { + XLA_NCCL_RETURN_IF_ERROR(ncclTopoCompute(comm->topo, &nvlsGraph)); + } else { + nvlsGraph.nChannels = 0; + } + + comm->nNodes = nRanks / num_local_participants; + XLA_NCCL_RETURN_IF_ERROR( + ncclTopoTuneModel(comm, comm->minCompCap, comm->maxCompCap, graphs)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr AcquireMockNcclComm( + RunId run_id, OpId op_id, std::vector participants, + std::vector local_devices, size_t num_local_participants, + const NcclCliqueIdCallback& clique_id_callback, int rank, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model) { + int nRanks = participants.size(); + std::vector> local_ranks; + for (int i = 0; i < local_devices.size(); i++) { + auto it = absl::c_find(participants, local_devices[i]); + if (it != participants.end()) { + local_ranks.push_back(std::make_pair(it - participants.begin(), i)); + } + } + + // Ensure that this group of threads have exclusive access to the clique to + // prevent threads from different groups locking communicators in the clique. + NcclCliqueKey clique_key(std::move(participants), stream_id); + TF_ASSIGN_OR_RETURN( + auto clique, + AcquireNcclClique( + run_id, op_id, clique_key, clique_id_callback, 1, + enable_clique_optimization || + stream_id == GetStreamId(true, AsyncStreamKind::kP2P0))); + + struct AllCommunicators { + absl::Mutex mu; + std::vector communicators ABSL_GUARDED_BY(mu); + }; + static auto& all_communicators = *new AllCommunicators; + + // Launch a thread that periodically checks all NCCL communicators for + // asynchronous errors. If an asynchronous error is observed, the communicator + // is aborted and an error message logged. + static auto check_async_error_thread = tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "nccl_async_error_thread", [&] { + while (true) { + absl::SleepFor(absl::Seconds(30)); + absl::MutexLock lock(&all_communicators.mu); + for (NcclComm* comm : all_communicators.communicators) { + CheckNcclAsyncError(*comm); + } + } + }); + (void)check_async_error_thread; // Silence unused variable warning. + + NcclCliqueState& state = **clique; + + if (!state.ready.HasBeenNotified()) { + ncclComm_t comm = nullptr; + absl::Status status = InitializeMockNcclCostModel( + nRanks, rank, num_local_participants, local_ranks, topo_model, &comm); + size_t num_initialized = [&] { + absl::MutexLock lock(&state.mu); + state.status.Update(status); + state.communicators[rank] = std::make_unique( + reinterpret_cast(comm)); + return state.communicators.size(); + }(); + + // Wait for all communicators to initialize before allowing any progress. + // Otherwise we may get deadlocks, because ncclCommInitRank may allocate, + // which may block on the completion of device activity on a peer device, + // which may depend on the completion of this collective if we do not have a + // barrier to prevent it. + if (num_initialized == 1) { + state.ready.Notify(); + } else { + TF_RETURN_IF_ERROR(status); + state.ready.WaitForNotification(); + } + + absl::MutexLock lock(&all_communicators.mu); + all_communicators.communicators.push_back(state.communicators[rank].get()); + } + + TF_RETURN_IF_ERROR(state.status); + return state.communicators[rank]->Acquire(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/mock_nccl_utils.h b/xla/service/gpu/mock_nccl_utils.h new file mode 100644 index 0000000000000..146b6441358b7 --- /dev/null +++ b/xla/service/gpu/mock_nccl_utils.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MOCK_NCCL_UTILS_H_ +#define XLA_SERVICE_GPU_MOCK_NCCL_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/executable_run_options.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/lockable.h" +#include "xla/stream_executor/stream.h" +#include "tsl/lib/gtl/int_type.h" + +namespace xla { +namespace gpu { + +TSL_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t); + +struct NcclCommName { + static std::string ToString(NcclApi::NcclCommHandle comm) { + return absl::StrFormat("lockable comm %p", comm); + } +}; + +struct NcclComm : public Lockable { + explicit NcclComm(NcclApi::NcclCommHandle comm) : Lockable(comm) {} +}; + +// Create the mock nccl communicator assuming all hosts have the same hardwares. +absl::StatusOr LockMockNcclComm( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model); + +absl::StatusOr AcquireMockNcclComm( + RunId run_id, OpId op_id, std::vector participants, + std::vector local_devices, size_t num_local_participants, + const NcclCliqueIdCallback& clique_id_callback, int rank, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model); + +// Mock a Nccl collective op including all-reduce, all-gather, and +// reduce-scatter. +absl::Status RunMockNcclCollectives(NcclApi* nccl_api, + std::vector& buffers, + se::Stream& stream, + NcclApi::NcclCommHandle comm, + Thunk::Kind reduce_op); + +// Mock a NCCL-based All-To-All op. +absl::Status RunMockNcclAllToAll(NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, + NcclApi::NcclCommHandle comm); + +// Mock a collective permute op. +absl::Status RunMockCollectivePermute( + NcclApi* nccl_api, NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MOCK_NCCL_UTILS_H_ diff --git a/xla/service/gpu/mock_nccl_utils_default.cc b/xla/service/gpu/mock_nccl_utils_default.cc new file mode 100644 index 0000000000000..948d3825cea3e --- /dev/null +++ b/xla/service/gpu/mock_nccl_utils_default.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/executable_run_options.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/mock_nccl_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/stream.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +absl::StatusOr AcquireMockNcclComm( + RunId run_id, OpId op_id, std::vector participants, + std::vector local_devices, size_t num_local_participants, + const NcclCliqueIdCallback& clique_id_callback, int rank, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model) { + return Unimplemented("AcquireMockNcclComm is not implemented."); +} + +absl::StatusOr LockMockNcclComm( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, + bool enable_clique_optimization, + GpuExecutableRunOptions::MockNcclTopoModel topo_model) { + return Unimplemented("LockMockNcclComm is not implemented."); +} + +absl::Status RunMockNcclCollectives(NcclApi*, std::vector&, + se::Stream&, NcclApi::NcclCommHandle, + Thunk::Kind) { + return Unimplemented("Mock nccl collectives is not implemented."); +} + +absl::Status RunMockNcclAllToAll(NcclApi*, bool, std::vector&, + se::Stream&, NcclApi::NcclCommHandle) { + return Unimplemented("Mock nccl AllToAll is not implemented."); +} + +absl::Status RunMockCollectivePermute(NcclApi*, + NcclP2PConfig::SourceTargetMapEntry, + DeviceBufferPair&, se::Stream&, + NcclApi::NcclCommHandle, + absl::string_view, int64_t) { + return Unimplemented("Mock collective permute is not implemented."); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/mock_nccl_xml.cc b/xla/service/gpu/mock_nccl_xml.cc new file mode 100644 index 0000000000000..881eb061fb5a8 --- /dev/null +++ b/xla/service/gpu/mock_nccl_xml.cc @@ -0,0 +1,197 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/mock_nccl_xml.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/regexp.h" +#if GOOGLE_CUDA +#include "third_party/gpus/nccl/graph/xml.h" +#endif + +namespace xla { +namespace gpu { +namespace { + +#if GOOGLE_CUDA + +class NcclTopoXmlParser { + public: + NcclTopoXmlParser(absl::string_view src, ncclXml* xml) + : src_(src), xml_(xml) {} + // A simple xml parser to parse the NCCL_TOPO_DUMP_FILE (see + // mock_nccl_topo_a3.xml) generated by the nccl library to the ncclXmlNode + // struct. + Status Parse(ncclXmlNode* head) { + if (head && head->type == NODE_TYPE_SINGLE) return absl::OkStatus(); + while (true) { + if (xml_->maxIndex == MAX_NODES) { + return absl::InternalError("XML parser is limited to 1024 nodes"); + } + ncclXmlNode* node = xml_->nodes + xml_->maxIndex; + memset(node, 0, sizeof(ncclXmlNode)); + TF_RETURN_IF_ERROR(GetNode(node)); + if (node->type == NODE_TYPE_NONE) { + if (head) { + return absl::InternalError( + absl::StrFormat("XML Parser : unterminated %s", head->name)); + } else { + // All done + return absl::OkStatus(); + } + } + if (head && node->type == NODE_TYPE_CLOSE) { + if (strcmp(node->name, head->name) != 0) { + return absl::InternalError(absl::StrFormat( + "XML Parser Mismatch : %s / %s", head->name, node->name)); + } + return absl::OkStatus(); + } + if (!head || spec_[head->name].contains(node->name)) { + if (head) head->subs[head->nSubs++] = node; + node->parent = head; + node->nSubs = 0; + xml_->maxIndex++; + TF_RETURN_IF_ERROR(Parse(node)); + } else { + return absl::InternalError( + absl::StrFormat("XML Parser : Unhandled element %s", node->name)); + } + } + } + + private: + std::pair GetAttr(absl::string_view src, + std::string& name, + std::string& value) { + static const LazyRE2 attr_regex = {"\\s*([^=]+)=\"([^\"]*)\""}; + return {RE2::Consume(&src, *attr_regex, &name, &value), src}; + } + + std::pair GetNodeName(absl::string_view src, + std::string& delimiter, + std::string& name) { + static const LazyRE2 name_regex = {"(/?)(\\w+)"}; + return {RE2::Consume(&src, *name_regex, &delimiter, &name), src}; + } + + Status GetNode(ncclXmlNode* node) { + static const LazyRE2 node_regex = {"\\s*<([^>]+)>"}; + std::string node_str; + + if (!RE2::Consume(&src_, *node_regex, &node_str)) return absl::OkStatus(); + absl::string_view node_str_view = absl::string_view(node_str); + std::string delimiter; + std::string name; + bool found_name = false; + std::tie(found_name, node_str_view) = + GetNodeName(node_str_view, delimiter, name); + CHECK(found_name) << "Fail to extract nccl topo node name"; + absl::SNPrintF(node->name, sizeof(node->name), "%s", name.c_str()); + + if (delimiter[0] == '/') { + node->type = NODE_TYPE_CLOSE; + return absl::OkStatus(); + } + node->type = NODE_TYPE_OPEN; + int num_attrs = 0; + bool found_attr = false; + do { + std::string key; + std::string value; + std::tie(found_attr, node_str_view) = GetAttr(node_str_view, key, value); + if (found_attr) { + absl::SNPrintF(node->attrs[num_attrs].key, + sizeof(node->attrs[num_attrs].key), "%s", key.c_str()); + absl::SNPrintF(node->attrs[num_attrs].value, + sizeof(node->attrs[num_attrs].value), "%s", + value.c_str()); + num_attrs++; + } + } while (found_attr); + node->nAttrs = num_attrs; + if (*node_str.rbegin() == '/') node->type = NODE_TYPE_SINGLE; + return absl::OkStatus(); + } + + absl::string_view src_; + ncclXml* xml_; + absl::flat_hash_map> spec_ = { + {"system", {"cpu"}}, {"pci", {"pci", "gpu", "nic"}}, {"gpu", {"nvlink"}}, + {"nic", {"net"}}, {"cpu", {"pci", "nic"}}, {"nvlink", {}}, + {"net", {}}}; +}; + +Status MockNcclTopoUpdateXmlRec( + absl::Span> local_ranks, + struct ncclXmlNode* node) { + if (strcmp(node->name, "gpu") == 0) { + int rank; + xmlGetAttrInt(node, "rank", &rank); + for (auto p : local_ranks) { + if (rank == p.second) { + xmlSetAttrInt(node, "keep", 1); + xmlSetAttrInt(node, "rank", p.first); + break; + } + } + } else if (strcmp(node->name, "net") == 0) { + xmlSetAttrInt(node, "keep", 1); + } + + for (int i = 0; i < node->nSubs; i++) { + TF_RETURN_IF_ERROR(MockNcclTopoUpdateXmlRec(local_ranks, node->subs[i])); + } + return absl::OkStatus(); +} + +#endif + +} // namespace + +Status MockTopoGetXml(absl::string_view xml_str_view, ncclXml* xml) { +#if GOOGLE_CUDA + xml->maxIndex = 0; + NcclTopoXmlParser parser(xml_str_view, xml); + return parser.Parse(nullptr); +#else + return absl::OkStatus(); +#endif +} + +Status MockNcclTopoUpdateXml(absl::Span> local_ranks, + ncclXml* xml) { +#if GOOGLE_CUDA + return MockNcclTopoUpdateXmlRec(local_ranks, xml->nodes); +#else + return absl::OkStatus(); +#endif +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/mock_nccl_xml.h b/xla/service/gpu/mock_nccl_xml.h new file mode 100644 index 0000000000000..9fbe41e92552f --- /dev/null +++ b/xla/service/gpu/mock_nccl_xml.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MOCK_NCCL_XML_H_ +#define XLA_SERVICE_GPU_MOCK_NCCL_XML_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/status.h" + +struct ncclXml; + +namespace xla { +namespace gpu { + +Status MockTopoGetXml(absl::string_view xml_str_view, ncclXml* xml); + +// Based on which local gpu devices participate the input clique, update the xml +// topo graph. +Status MockNcclTopoUpdateXml(absl::Span> local_ranks, + ncclXml* xml); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MOCK_NCCL_XML_H_ diff --git a/xla/service/gpu/mock_nccl_xml_test.cc b/xla/service/gpu/mock_nccl_xml_test.cc new file mode 100644 index 0000000000000..c33f29f98e5ee --- /dev/null +++ b/xla/service/gpu/mock_nccl_xml_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/mock_nccl_xml.h" + +#include +#include + +#include +#include "xla/status.h" +#include "tsl/platform/test.h" +#if GOOGLE_CUDA +#include "third_party/gpus/nccl/graph/xml.h" +#endif + +namespace xla { +namespace gpu { +namespace { + +#if GOOGLE_CUDA + +class MockNcclXmlParserTest : public ::testing::Test {}; + +TEST_F(MockNcclXmlParserTest, PciNic) { + const std::string original = R"( + + + + + + )"; + auto xml = std::make_unique(); + auto result = MockTopoGetXml(original, xml.get()); + + EXPECT_EQ(OkStatus(), result); + EXPECT_EQ(xml->maxIndex, 3); + EXPECT_EQ(std::string(xml->nodes[0].name), "pci"); + EXPECT_EQ(xml->nodes[0].nAttrs, 8); + EXPECT_EQ(std::string(xml->nodes[0].attrs[0].key), "busid"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[0].value), "0000:0c:00.0"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[1].key), "class"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[1].value), "0x020700"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[2].key), "vendor"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[2].value), "0x15b3"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[3].key), "device"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[3].value), "0x101b"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[4].key), "subsystem_vendor"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[4].value), "0x15b3"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[5].key), "subsystem_device"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[5].value), "0x0007"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[6].key), "link_speed"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[6].value), "16.0 GT/s PCIe"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[7].key), "link_width"); + EXPECT_EQ(std::string(xml->nodes[0].attrs[7].value), "16"); + EXPECT_EQ(xml->nodes[0].nSubs, 1); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->name), "nic"); +} + +TEST_F(MockNcclXmlParserTest, GpuNvlink) { + const std::string original = R"( + + + + )"; + auto xml = std::make_unique(); + auto result = MockTopoGetXml(original, xml.get()); + EXPECT_EQ(OkStatus(), result); + EXPECT_EQ(xml->maxIndex, 2); + EXPECT_EQ(std::string(xml->nodes[0].name), "gpu"); + EXPECT_EQ(xml->nodes[0].nAttrs, 4); + EXPECT_EQ(xml->nodes[0].nSubs, 1); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->name), "nvlink"); + EXPECT_EQ(xml->nodes[0].subs[0]->nAttrs, 3); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[0].key), "target"); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[0].value), "0000:c7:00.0"); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[1].key), "count"); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[1].value), "2"); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[2].key), "tclass"); + EXPECT_EQ(std::string(xml->nodes[0].subs[0]->attrs[2].value), "0x068000"); +} + +#endif + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 539502bf84117..9d40d7a510427 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -1,15 +1,17 @@ -# Libraries for performance modeling of HLO. -load("//xla/tests:build_defs.bzl", "xla_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("//xla:xla.bzl", "xla_cc_test", "xla_nvml_deps") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags") load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//xla:xla.bzl", "xla_cc_test", "xla_nvml_deps") + +# Libraries for performance modeling of HLO. +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -25,8 +27,10 @@ cc_library( srcs = ["analytical_latency_estimator.cc"], hdrs = ["analytical_latency_estimator.h"], deps = [ + ":gpu_collective_performance_model", ":gpu_hlo_cost_analysis", ":gpu_performance_model", + ":gpu_performance_model_base", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -34,7 +38,9 @@ cc_library( "//xla/service:latency_hiding_scheduler", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/time", + "@tsl//tsl/platform:status", ], ) @@ -50,15 +56,14 @@ xla_test( deps = [ ":analytical_latency_estimator", "//xla:shape_util", + "//xla:statusor", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:latency_hiding_scheduler", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:statusor", @@ -74,6 +79,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/synchronization", ], @@ -90,6 +96,9 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) @@ -100,6 +109,7 @@ cc_library( deps = [ ":gpu_hlo_cost_analysis", ":gpu_performance_model", + ":gpu_performance_model_base", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", @@ -107,7 +117,6 @@ cc_library( "//xla/service:hlo_pass", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:status", @@ -121,28 +130,28 @@ xla_cc_test( ":gpu_cost_model_stats_collection", ":gpu_hlo_cost_analysis", "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) cc_library( name = "gpu_hlo_cost_analysis", srcs = ["gpu_hlo_cost_analysis.cc"], - hdrs = [ - "gpu_hlo_cost_analysis.h", - "hlo_op_profiles.h", - ], + hdrs = ["gpu_hlo_cost_analysis.h"], compatible_with = get_compatible_with_portable(), deps = [ ":hlo_op_profile_proto_cc", + ":hlo_op_profiles", "//xla:shape_util", + "//xla:status", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", @@ -152,8 +161,15 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -162,8 +178,61 @@ xla_cc_test( srcs = ["gpu_hlo_cost_analysis_test.cc"], deps = [ ":gpu_hlo_cost_analysis", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gpu_performance_model_base", + srcs = ["gpu_performance_model_base.cc"], + hdrs = ["gpu_performance_model_base.h"], + deps = [ + ":fusion_analysis_cache", + ":gpu_hlo_cost_analysis", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "gpu_performance_model_base_test", + srcs = ["gpu_performance_model_base_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) @@ -171,33 +240,144 @@ cc_library( name = "gpu_performance_model", srcs = ["gpu_performance_model.cc"], hdrs = ["gpu_performance_model.h"], + deps = [ + ":coalescing_analysis", + ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:launch_dimensions", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "gpu_performance_model_test", + srcs = ["gpu_performance_model_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_indexing_performance_model", + ":gpu_performance_model", + ":gpu_performance_model_base", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gpu_collective_performance_model", + srcs = ["gpu_collective_performance_model.cc"], + hdrs = ["gpu_collective_performance_model.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ + ":coalescing_analysis", ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", + ":gpu_performance_model_base", + ":hlo_op_profiles", + ":indexing_analysis", + ":indexing_map", "//xla:shape_util", + "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_dataflow_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status", ] + if_cuda_is_configured(xla_nvml_deps()), ) xla_cc_test( - name = "gpu_performance_model_test", - srcs = ["gpu_performance_model_test.cc"], + name = "gpu_collective_performance_model_test", + srcs = ["gpu_collective_performance_model_test.cc"], + deps = [ + ":gpu_collective_performance_model", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gpu_indexing_performance_model", + srcs = ["gpu_indexing_performance_model.cc"], + hdrs = ["gpu_indexing_performance_model.h"], deps = [ + ":coalescing_analysis", ":gpu_hlo_cost_analysis", - ":gpu_performance_model", + ":gpu_performance_model_base", + ":hlo_op_profiles", + ":indexing_analysis", + ":indexing_map", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "gpu_indexing_performance_model_test", + srcs = ["gpu_indexing_performance_model_test.cc"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_indexing_performance_model", + ":gpu_performance_model_base", "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", @@ -205,39 +385,166 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "tile_analysis", - srcs = ["tile_analysis.cc"], - hdrs = ["tile_analysis.h"], + name = "affine_map_printer", + srcs = ["affine_map_printer.cc"], + hdrs = ["affine_map_printer.h"], deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "affine_map_printer_test", + srcs = ["affine_map_printer_test.cc"], + deps = [ + ":affine_map_printer", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "indexing_map", + srcs = ["indexing_map.cc"], + hdrs = ["indexing_map.h"], + deps = [ + ":affine_map_printer", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "indexing_map_test", + srcs = ["indexing_map_test.cc"], + deps = [ + ":affine_map_printer", + ":indexing_analysis", + ":indexing_map", + ":indexing_test_utils", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "indexing_test_utils", + testonly = True, + srcs = ["indexing_test_utils.cc"], + hdrs = ["indexing_test_utils.h"], + deps = [ + ":indexing_analysis", + ":indexing_map", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "indexing_analysis", + srcs = ["indexing_analysis.cc"], + hdrs = ["indexing_analysis.h"], + deps = [ + ":affine_map_printer", + ":indexing_map", + "//xla:permutation_util", "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:gather_simplifier", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/fusions:tiling_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "indexing_analysis_test", + srcs = ["indexing_analysis_test.cc"], + deps = [ + ":indexing_analysis", + ":indexing_map", + ":indexing_test_utils", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions:tiling_util", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "symbolic_tile", + srcs = ["symbolic_tile.cc"], + hdrs = ["symbolic_tile.h"], + deps = [ + ":affine_map_printer", + ":indexing_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@tsl//tsl/platform:statusor", + "@llvm-project//mlir:Support", ], ) xla_cc_test( - name = "tile_analysis_test", - srcs = ["tile_analysis_test.cc"], + name = "symbolic_tile_test", + srcs = ["symbolic_tile_test.cc"], deps = [ - ":tile_analysis", - "//xla:statusor", - "//xla:test_helpers", + ":affine_map_printer", + ":indexing_analysis", + ":indexing_test_utils", + ":symbolic_tile", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -245,7 +552,162 @@ xla_cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", + "@tsl//tsl/platform:test", + ], +) + +cc_library( + name = "symbolic_tiled_hlo_instruction", + srcs = ["symbolic_tiled_hlo_instruction.cc"], + hdrs = ["symbolic_tiled_hlo_instruction.h"], + deps = [ + ":indexing_map", + ":symbolic_tile", + "//xla:status", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "symbolic_tiled_hlo_instruction_test", + srcs = ["symbolic_tiled_hlo_instruction_test.cc"], + deps = [ + ":indexing_analysis", + ":indexing_map", + ":symbolic_tile", + ":symbolic_tiled_hlo_instruction", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_traversal", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "tiled_hlo_instruction", + srcs = ["tiled_hlo_instruction.cc"], + hdrs = ["tiled_hlo_instruction.h"], + deps = [ + ":indexing_map", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "tiled_hlo_instruction_test", + srcs = ["tiled_hlo_instruction_test.cc"], + deps = [ + ":indexing_map", + ":indexing_test_utils", + ":tiled_hlo_instruction", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "symbolic_tile_analysis", + srcs = ["symbolic_tile_analysis.cc"], + hdrs = ["symbolic_tile_analysis.h"], + deps = [ + ":indexing_analysis", + ":indexing_map", + ":symbolic_tile", + ":symbolic_tiled_hlo_instruction", + "//xla:status", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "symbolic_tile_analysis_test", + srcs = ["symbolic_tile_analysis_test.cc"], + deps = [ + ":indexing_test_utils", + ":symbolic_tile_analysis", + ":symbolic_tiled_hlo_instruction", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "coalescing_analysis", + srcs = ["coalescing_analysis.cc"], + hdrs = ["coalescing_analysis.h"], + deps = [ + ":indexing_analysis", + ":indexing_map", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions:fusion_emitter", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "coalescing_analysis_test", + srcs = ["coalescing_analysis_test.cc"], + deps = [ + ":coalescing_analysis", + ":indexing_map", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", "@tsl//tsl/platform:test", ], ) @@ -260,6 +722,42 @@ tf_proto_library( ], ) +cc_library( + name = "hlo_op_profiles", + srcs = ["hlo_op_profiles.cc"], + hdrs = [ + "hlo_op_profiles.h", + "hlo_op_profiles_data.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":hlo_op_profile_proto_cc", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:protobuf", + ], +) + +xla_cc_test( + name = "hlo_op_profiles_test", + srcs = ["hlo_op_profiles_test.cc"], + deps = [ + ":hlo_op_profiles", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "hlo_op_profiler_lib", testonly = True, @@ -284,45 +782,58 @@ cc_library( "//xla/service:interpreter_plugin", "//xla/stream_executor:device_description", "//xla/tests:test_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) -xla_cc_test( - name = "hlo_op_profiler_run", - timeout = "eternal", - srcs = ["hlo_op_profiler_run.cc"], - # Disable backend optimizations (in particular reassociate and instcombine) which would optimize - # expressions like integer add and multiply. - args = ["--xla_backend_optimization_level=0"], - # This is a development tool, not a normal test, and thus should only be run - # manually with --config=cuda. - tags = [ - "gpu", - "manual", - "notap", - "requires-gpu-nvidia", - ], - deps = [ - ":hlo_op_profile_proto_cc", - ":hlo_op_profiler_lib", - "//xla:debug_options_flags", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_runner", - "//xla/service:platform_util", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:path", - "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", - "@tsl//tsl/util:command_line_flags", - ], -) +[ + xla_cc_test( + name = "hlo_op_profiler_run_" + sm, + timeout = "eternal", + srcs = ["hlo_op_profiler_run.cc"], + # Disable backend optimizations (in particular reassociate and instcombine) which would optimize + # expressions like integer add and multiply. + args = ["--xla_backend_optimization_level=0"], + # This is a development tool, not a normal test, and thus should only be run + # manually with --config=cuda. + tags = [ + "gpu", + "manual", + "notap", + "requires-gpu-" + sm + "-only", + ], + deps = [ + ":hlo_op_profile_proto_cc", + ":hlo_op_profiler_lib", + ":hlo_op_profiles", + "//xla:debug_options_flags", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_runner", + "//xla/service:platform_util", + "//xla/stream_executor:device_description", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:platform_port", + "@tsl//tsl/platform:status", + ], + ) + for sm in [ + "sm60", + "sm70", + "sm80", + "sm90", + ] +] xla_cc_test( name = "hlo_op_profiler_test", @@ -334,6 +845,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/gpu/model/affine_map_printer.cc b/xla/service/gpu/model/affine_map_printer.cc new file mode 100644 index 0000000000000..972da89717e32 --- /dev/null +++ b/xla/service/gpu/model/affine_map_printer.cc @@ -0,0 +1,269 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/affine_map_printer.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace xla { +namespace gpu { +namespace { + +using mlir::AffineBinaryOpExpr; +using mlir::AffineConstantExpr; +using mlir::AffineDimExpr; +using mlir::AffineExpr; +using mlir::AffineExprKind; +using mlir::AffineMap; +using mlir::AffineSymbolExpr; + +} // namespace + +AffineMapPrinter::AffineMapPrinter( + absl::Span dim_names, + absl::Span symbol_names) { + dim_id_to_name_.reserve(dim_names.size()); + for (const auto& [index, name] : llvm::enumerate(dim_names)) { + dim_id_to_name_[index] = name; + } + symbol_id_to_name_.reserve(symbol_names.size()); + for (const auto& [index, name] : llvm::enumerate(symbol_names)) { + symbol_id_to_name_[index] = name; + } +} + +void AffineMapPrinter::Print(std::ostream& out, AffineMap affine_map) const { + out << ToString(affine_map); +} + +std::string AffineMapPrinter::ToString(AffineMap affine_map) const { + std::string s; + llvm::raw_string_ostream ss(s); + + if (dim_id_to_name_.empty() && symbol_id_to_name_.empty()) { + affine_map.print(ss); + return s; + } + // Dimension identifiers. + int dim_count = affine_map.getNumDims(); + ss << '('; + for (int i = 0; i < dim_count - 1; ++i) { + ss << GetDimensionName(i) << ", "; + } + if (dim_count >= 1) { + ss << GetDimensionName(dim_count - 1); + } + ss << ')'; + // Symbolic identifiers. + int symbol_count = affine_map.getNumSymbols(); + if (symbol_count != 0) { + ss << '['; + for (unsigned i = 0; i < symbol_count - 1; ++i) { + ss << GetSymbolName(i) << ", "; + } + if (affine_map.getNumSymbols() >= 1) { + ss << GetSymbolName(symbol_count - 1); + } + ss << ']'; + } + // Result affine expressions. + ss << " -> ("; + llvm::interleaveComma(affine_map.getResults(), ss, [&](AffineExpr expr) { + PrintExprImpl(expr, /*add_parentheses=*/false, ss); + }); + ss << ')'; + return s; +} + +void AffineMapPrinter::Print(std::ostream& out, + mlir::AffineExpr affine_expr) const { + out << ToString(affine_expr); +} + +std::string AffineMapPrinter::ToString(mlir::AffineExpr affine_expr) const { + std::string s; + llvm::raw_string_ostream ss(s); + PrintExprImpl(affine_expr, /*add_parentheses=*/false, ss); + return s; +} + +void AffineMapPrinter::PrintExprImpl(const mlir::AffineExpr affine_expr, + bool add_parentheses, + llvm::raw_ostream& os) const { + const char* binopSpelling = nullptr; + switch (affine_expr.getKind()) { + case AffineExprKind::SymbolId: { + unsigned symbol_id = + mlir::cast(affine_expr).getPosition(); + os << GetSymbolName(symbol_id); + return; + } + case AffineExprKind::DimId: { + unsigned dim_id = mlir::cast(affine_expr).getPosition(); + os << GetDimensionName(dim_id); + return; + } + case AffineExprKind::Constant: + os << mlir::cast(affine_expr).getValue(); + return; + case AffineExprKind::Add: + binopSpelling = " + "; + break; + case AffineExprKind::Mul: + binopSpelling = " * "; + break; + case AffineExprKind::FloorDiv: + binopSpelling = " floordiv "; + break; + case AffineExprKind::CeilDiv: + binopSpelling = " ceildiv "; + break; + case AffineExprKind::Mod: + binopSpelling = " mod "; + break; + } + + auto binOp = mlir::cast(affine_expr); + AffineExpr lhsExpr = binOp.getLHS(); + AffineExpr rhsExpr = binOp.getRHS(); + + // Handle tightly binding binary operators. + if (binOp.getKind() != AffineExprKind::Add) { + if (add_parentheses) { + os << '('; + } + + // Pretty print multiplication with -1. + auto rhsConst = mlir::dyn_cast(rhsExpr); + if (rhsConst && binOp.getKind() == AffineExprKind::Mul && + rhsConst.getValue() == -1) { + os << "-"; + PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); + if (add_parentheses) { + os << ')'; + } + return; + } + + PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); + + os << binopSpelling; + PrintExprImpl(rhsExpr, /*add_parentheses=*/true, os); + + if (add_parentheses) { + os << ')'; + } + return; + } + + // Print out special "pretty" forms for add. + if (add_parentheses) { + os << '('; + } + + // Pretty print addition to a product that has a negative operand as a + // subtraction. + if (auto rhs = mlir::dyn_cast(rhsExpr)) { + if (rhs.getKind() == AffineExprKind::Mul) { + AffineExpr rrhsExpr = rhs.getRHS(); + if (auto rrhs = mlir::dyn_cast(rrhsExpr)) { + if (rrhs.getValue() == -1) { + PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); + os << " - "; + if (rhs.getLHS().getKind() == AffineExprKind::Add) { + PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); + } else { + PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/false, os); + } + + if (add_parentheses) { + os << ')'; + } + return; + } + + if (rrhs.getValue() < -1) { + PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); + os << " - "; + PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); + os << " * " << -rrhs.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + } + } + + // Pretty print addition to a negative number as a subtraction. + if (auto rhsConst = mlir::dyn_cast(rhsExpr)) { + if (rhsConst.getValue() < 0) { + PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); + os << " - " << -rhsConst.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + + PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); + + os << " + "; + PrintExprImpl(rhsExpr, /*add_parentheses=*/false, os); + + if (add_parentheses) { + os << ')'; + } +} + +void AffineMapPrinter::SetSymbolName(int64_t symbol_id, llvm::StringRef name) { + symbol_id_to_name_[symbol_id] = name; +} + +void AffineMapPrinter::SetDimensionName(int64_t dim_id, llvm::StringRef name) { + dim_id_to_name_[dim_id] = name; +} + +std::string AffineMapPrinter::GetSymbolName(int64_t symbol_id) const { + auto it = symbol_id_to_name_.find(symbol_id); + if (it == symbol_id_to_name_.end()) { + return absl::StrCat("s", symbol_id); + } + return it->second; +} + +std::string AffineMapPrinter::GetDimensionName(int64_t dim_id) const { + auto it = dim_id_to_name_.find(dim_id); + if (it == dim_id_to_name_.end()) { + return absl::StrCat("d", dim_id); + } + return it->second; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/affine_map_printer.h b/xla/service/gpu/model/affine_map_printer.h new file mode 100644 index 0000000000000..1db3b1a791adc --- /dev/null +++ b/xla/service/gpu/model/affine_map_printer.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ +#define XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project + +namespace xla { +namespace gpu { + +// AffineMapPrinter allows to "pretty print" mlir::AffineMap by setting custom +// symbol and dimension names. +class AffineMapPrinter { + public: + AffineMapPrinter() = default; + AffineMapPrinter(AffineMapPrinter&& other) = default; + AffineMapPrinter& operator=(AffineMapPrinter&& other) = default; + AffineMapPrinter(absl::Span dim_names, + absl::Span symbol_names); + + void SetSymbolName(int64_t symbol_id, llvm::StringRef name); + void SetDimensionName(int64_t dim_id, llvm::StringRef name); + + std::string GetSymbolName(int64_t symbol_id) const; + std::string GetDimensionName(int64_t dim_id) const; + + void Print(std::ostream& out, mlir::AffineMap affine_map) const; + std::string ToString(mlir::AffineMap affine_map) const; + + void Print(std::ostream& out, mlir::AffineExpr affine_expr) const; + std::string ToString(mlir::AffineExpr affine_expr) const; + + private: + void PrintExprImpl(mlir::AffineExpr affine_expr, bool add_parentheses, + llvm::raw_ostream& os) const; + + llvm::DenseMap dim_id_to_name_; + llvm::DenseMap symbol_id_to_name_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ diff --git a/xla/service/gpu/model/affine_map_printer_test.cc b/xla/service/gpu/model/affine_map_printer_test.cc new file mode 100644 index 0000000000000..1544d05686a0f --- /dev/null +++ b/xla/service/gpu/model/affine_map_printer_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/affine_map_printer.h" + +#include +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::mlir::AffineExpr; +using ::mlir::AffineMap; +using ::mlir::bindDims; +using ::mlir::bindSymbols; +using ::testing::HasSubstr; + +class IndexingMapTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; + AffineMapPrinter printer_; +}; + +TEST_F(IndexingMapTest, AffineMapPrinterTest) { + AffineExpr d0, d1, s0, s1; + bindDims(&mlir_context_, d0, d1); + bindSymbols(&mlir_context_, s0, s1); + + // (d0, d1)[s0, s1] -> (d0 + d1 floordiv 8, s0 + s1 mod 16). + auto map = + AffineMap::get(2, 2, {d0 + d1.floorDiv(8), s0 + s1 % 16}, &mlir_context_); + + printer_.SetDimensionName(0, "offset"); + printer_.SetSymbolName(1, "linear_index"); + EXPECT_THAT(printer_.ToString(map), + HasSubstr("(offset, d1)[s0, linear_index] -> " + "(offset + d1 floordiv 8, s0 + linear_index mod 16)")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/analytical_latency_estimator.cc b/xla/service/gpu/model/analytical_latency_estimator.cc index d13c96d435606..fa0908e8dda11 100644 --- a/xla/service/gpu/model/analytical_latency_estimator.cc +++ b/xla/service/gpu/model/analytical_latency_estimator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,15 +18,19 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/model/gpu_collective_performance_model.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/stream_executor/device_description.h" +#include "tsl/platform/status.h" namespace xla { namespace gpu { @@ -52,9 +56,8 @@ LatencyEstimator::TimeCost AnalyticalLatencyEstimator::GetLatencyBetween( LatencyEstimator::TimeCost AnalyticalLatencyEstimator::NodeCost( const HloInstruction* instr) const { - const HloOpcode opcode = instr->opcode(); - if (hlo_query::IsAsyncCollectiveStartOp(opcode, /*include_send_recv=*/true) || - hlo_query::IsAsyncCollectiveDoneOp(opcode, /*include_send_recv=*/true)) { + if (hlo_query::IsAsyncCollectiveStartOp(instr, /*include_send_recv=*/true) || + hlo_query::IsAsyncCollectiveDoneOp(instr, /*include_send_recv=*/true)) { return kLowCost; } diff --git a/xla/service/gpu/model/analytical_latency_estimator.h b/xla/service/gpu/model/analytical_latency_estimator.h index 02944491be4c2..3f3c67f9da40d 100644 --- a/xla/service/gpu/model/analytical_latency_estimator.h +++ b/xla/service/gpu/model/analytical_latency_estimator.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,13 @@ limitations under the License. #include #include -#include -#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" namespace xla { diff --git a/xla/service/gpu/model/analytical_latency_estimator_test.cc b/xla/service/gpu/model/analytical_latency_estimator_test.cc index 4521fe0685939..1fa86f0caa024 100644 --- a/xla/service/gpu/model/analytical_latency_estimator_test.cc +++ b/xla/service/gpu/model/analytical_latency_estimator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -55,7 +56,7 @@ SchedulerConfig GetDefaultSchedulerConfig() { return scheduler_config; } -StatusOr RunScheduler( +absl::StatusOr RunScheduler( HloModule* module, const SchedulerConfig& sched_config, std::unique_ptr latency_estimator = std::make_unique()) { @@ -85,7 +86,7 @@ StatusOr RunScheduler( class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest { public: - StatusOr> ParseHloText( + absl::StatusOr> ParseHloText( absl::string_view hlo_string) { return ParseAndReturnVerifiedModule(hlo_string, GetModuleConfigForTest()); } diff --git a/xla/service/gpu/model/coalescing_analysis.cc b/xla/service/gpu/model/coalescing_analysis.cc new file mode 100644 index 0000000000000..7392ef4a186fd --- /dev/null +++ b/xla/service/gpu/model/coalescing_analysis.cc @@ -0,0 +1,444 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/coalescing_analysis.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla { +namespace gpu { + +// Returns true if all input reads are coalesced. If consumer is not nullptr, +// producer and consumer are considered as one fusion, otherwise it's only the +// producer. +bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const HloInstruction* producer, + const HloInstruction* consumer) { + // Transposing minor dimension breaks coalescing. + if (fusion_kind != HloFusionAnalysis::EmitterFusionKind::kTranspose) { + auto is_broadcast = [&](const HloInstruction* instr) { + while (true) { + if (instr->opcode() == HloOpcode::kBroadcast || + instr->opcode() == HloOpcode::kIota) { + return true; + } + if (instr->operand_count() != 1) return false; + if (instr->opcode() != HloOpcode::kBitcast && !instr->IsElementwise()) { + return false; + } + instr = instr->operand(0); + } + }; + auto is_bad_transpose = [&](const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kFusion) { + for (auto* instr : instr->fused_instructions()) { + // Hack: we allow transposes of broadcasts or iotas. + if (TransposesMinorDimension(instr) && + !is_broadcast(instr->operand(0))) { + return true; + } + } + return false; + } + // Hack: we allow transposes of broadcasts or iotas. + return TransposesMinorDimension(instr) && + !is_broadcast(instr->operand(0)); + }; + if (is_bad_transpose(producer)) return false; + if (consumer && is_bad_transpose(consumer)) return false; + } + // Fusing two row reductions breaks coalescing. + if (fusion_kind == HloFusionAnalysis::EmitterFusionKind::kReduction && + IsInputFusibleReduction(*producer) && consumer && + IsInputFusibleReduction(*consumer)) { + return false; + } + return true; +} + +namespace { + +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::getAffineConstantExpr; +using mlir::MLIRContext; + +// Performs backtracking to find all feasible dimensions, symbols that satisfy +// the constraints and then evaluates the affine map at those. +// For example, for the following indexing map: +// (d0)[s0] -> (d0 + s0) +// domain: +// d0 in [0, 3] +// s0 in [0, 1, 2] +// s0 mod 2 in [0, 0] +// The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5]. +void FindAllIndices(const IndexingMap& thread_id_to_physical_index, + MLIRContext* mlir_context, int dim_id, int symbol_id, + std::vector* dimensions, + std::vector* symbols, + std::vector* indices) { + if (dim_id < thread_id_to_physical_index.GetDimensionCount()) { + Interval dim_range = thread_id_to_physical_index.GetDimensionBound(dim_id); + for (int64_t dim_value = dim_range.lower; dim_value <= dim_range.upper; + ++dim_value) { + dimensions->push_back(getAffineConstantExpr(dim_value, mlir_context)); + FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id + 1, + symbol_id, dimensions, symbols, indices); + dimensions->pop_back(); + } + return; + } + if (symbol_id < thread_id_to_physical_index.GetRangeVarsCount()) { + Interval symbol_range = + thread_id_to_physical_index.GetSymbolBound(symbol_id); + for (int64_t symbol_value = symbol_range.lower; + symbol_value <= symbol_range.upper; ++symbol_value) { + symbols->push_back(getAffineConstantExpr(symbol_value, mlir_context)); + FindAllIndices(thread_id_to_physical_index, mlir_context, dim_id, + symbol_id + 1, dimensions, symbols, indices); + symbols->pop_back(); + } + return; + } + if (!thread_id_to_physical_index.ConstraintsSatisfied(*dimensions, + *symbols)) { + return; + } + indices->push_back( + thread_id_to_physical_index.Evaluate(*dimensions, *symbols).front()); +} + +// Computes contiguous intervals of accessed elements. +// For example, for an indexing map +// (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) +// d0 in [0, 31] +// s0 in [0, 3] +// The intervals are [0, 63] and [2047, 2111]. +// TODO(b/325613460): Make it faster than O(number of elements in the domain). +std::vector FindContiguousIntervals( + const IndexingMap& thread_id_to_physical_index) { + CHECK(thread_id_to_physical_index.GetAffineMap().getNumResults() == 1) + << "Expects an affine map that maps to 1D."; + MLIRContext* mlir_context = thread_id_to_physical_index.GetMLIRContext(); + + // Find all linear indices, sort and deduplicate them. + std::vector dimensions, symbols; + std::vector linear_indices; + FindAllIndices(thread_id_to_physical_index, mlir_context, + /*dim_id=*/0, + /*symbol_id=*/0, &dimensions, &symbols, &linear_indices); + std::sort(linear_indices.begin(), linear_indices.end()); + linear_indices.erase( + std::unique(linear_indices.begin(), linear_indices.end()), + linear_indices.end()); + + // Scan over the sorted unique indices and combine them in intervals. + std::vector intervals; + for (int i = 0, start, end; i < linear_indices.size(); ++i) { + start = linear_indices[i++]; + end = start; + while (i < linear_indices.size() && linear_indices[i] == end + 1) { + ++end; + ++i; + } + intervals.push_back(Interval{start, end}); + } + return intervals; +} + +int64_t CeilDiv(int64_t a, int64_t b) { return a / b + (a % b != 0); } + +// Approximately estimate the number of memory transactions needed to load all +// elements in every range and compare it with the "ideal" number of memory +// transactions, i.e. total number of elements in all ranges / WarpSize(). +// Note, that later we would need to take the element type into account. +bool EstimateCoalescingViaMemoryTransactionsCount( + absl::Span intervals, PrimitiveType element_type) { + constexpr int64_t kBytesPerMemoryTransaction = 128; + int64_t type_size = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + int memory_transactions = 0; + int total_num_elements = 0; + for (const auto& range : intervals) { + int64_t num_elements = range.upper - range.lower + 1; + memory_transactions += + CeilDiv(num_elements * type_size, kBytesPerMemoryTransaction); + total_num_elements += num_elements; + } + if (memory_transactions == 0) { + return true; + } + int memory_transactions_lower_bound = + CeilDiv(total_num_elements * type_size, kBytesPerMemoryTransaction); + // The magic value chosen by an uneducated guess. + constexpr float kIsCoalescedThreshold = 0.9; + return memory_transactions_lower_bound > + memory_transactions * kIsCoalescedThreshold; +} + +bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map, + PrimitiveType element_type) { + // Undefined indexing maps, i.e. those for which we don't know the indexing + // are assumed to be uncoalesced. + if (thread_id_to_input_indexing_map.IsUndefined()) { + return false; + } + // 0d constants are coalesced. + if (thread_id_to_input_indexing_map.GetAffineMap().getNumResults() == 0) { + return true; + } + MLIRContext* mlir_context = thread_id_to_input_indexing_map.GetMLIRContext(); + AffineExpr thread_x_dim = mlir::getAffineDimExpr( + KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context); + AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); + IndexingMap thread_x_first_32_elements{ + AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context), + {DimVar{{0, 31}}}, + /*range_vars=*/{}, + /*rt_vars=*/{}}; + IndexingMap thread_x_to_linearized_input = + thread_x_first_32_elements * thread_id_to_input_indexing_map; + + // If RTVars are present, replace them with constants. + if (thread_x_to_linearized_input.GetRTVarsCount() > 0) { + llvm::SmallVector symbol_replacements; + for (int64_t symbol_id = 0; + symbol_id < thread_x_to_linearized_input.GetRangeVarsCount(); + ++symbol_id) { + symbol_replacements.push_back( + mlir::getAffineSymbolExpr(symbol_id, mlir_context)); + } + for (const RTVar& rt_var : thread_x_to_linearized_input.GetRTVars()) { + // Take midpoint of the feasible interval for the RT variable. + symbol_replacements.push_back(getAffineConstantExpr( + (rt_var.feasible_values.lower + rt_var.feasible_values.upper) / 2, + mlir_context)); + } + AffineMap thread_x_to_input_no_rt_symbols = + thread_x_to_linearized_input.GetAffineMap().replaceDimsAndSymbols( + {}, symbol_replacements, + thread_x_to_linearized_input.GetDimVarsCount(), + thread_x_to_linearized_input.GetRangeVarsCount()); + thread_x_to_linearized_input = IndexingMap{ + thread_x_to_input_no_rt_symbols, + thread_x_to_linearized_input.GetDimVars(), + thread_x_to_linearized_input.GetRangeVars(), + thread_x_to_linearized_input.GetRTVars(), + }; + } + thread_x_to_linearized_input.Simplify(GetIndexingMapForInstruction); + thread_x_to_linearized_input.RescaleSymbols(); + thread_x_to_linearized_input.RemoveUnusedSymbols(); + return EstimateCoalescingViaMemoryTransactionsCount( + FindContiguousIntervals(thread_x_to_linearized_input), element_type); +} + +// Returns a linearized shape, i.e. tensor. +Shape GetLinearizedShape(const Shape& shape) { + if (shape.rank() == 0) { + return shape; + } + std::vector dims{ShapeUtil::ElementsIn(shape)}; + auto result = Shape(shape.element_type(), dims, + absl::InlinedVector(dims.size(), false), {}); + *result.mutable_layout() = xla::Layout({0}); + return result; +} + +// Returns thread ID to linearized physical layout indexing map for each operand +// of the fusion. +std::optional GetThreadIdToInputMemoryLayoutsMaps( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + GroupedByOpIndexingMap result; + for (const auto& [root_index, hero] : + llvm::enumerate(fusion_analysis.fusion_heroes())) { + for (const auto& [hero_operand_index, hero_operand] : + llvm::enumerate(hero->operands())) { + if (hero_operand->shape().rank() == 0) { + continue; + } + // Compute thread ID -> hero operand indexing map. + std::optional thread_id_to_hero_operand_map = + fusion_interface->ComputeThreadIdToInputIndexing( + root_index, hero_operand_index, mlir_context); + if (!thread_id_to_hero_operand_map.has_value()) { + return std::nullopt; + } + // Compute indexing from output to inputs for logical layout. + HloInstructionAdaptor hero_operand_adaptor(*hero_operand); + GroupedByOpIndexingMap instr_indexing_keyed_by_operands = + ComputeGroupedOutputToInputIndexing( + fusion_adaptor, hero_operand_adaptor, mlir_context); + // For every operand compute thread ID -> physical layout of operand + // indexing map. + for (const HloInstruction* operand : operands) { + auto operand_indexing_maps_it = + instr_indexing_keyed_by_operands.find(operand); + if (operand_indexing_maps_it == + instr_indexing_keyed_by_operands.end()) { + continue; + } + const Shape& operand_shape = operand->shape(); + + IndexingMap operand_logical_to_physical_map = + GetIndexingMapFromLogicalToPhysicalLayout(operand_shape, + mlir_context); + IndexingMap operand_physical_to_linearized_shape = GetBitcastMap( + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + operand_shape), + GetLinearizedShape(operand_shape), mlir_context); + IndexingMap operand_logical_to_linearized_physical_shape = + operand_logical_to_physical_map * + operand_physical_to_linearized_shape; + operand_logical_to_linearized_physical_shape.Simplify( + GetIndexingMapForInstruction); + + for (const IndexingMap& operand_indexing_map : + operand_indexing_maps_it->second) { + // If one of the indexing maps for the operand is undefined, we remove + // all indexing maps for it and store only the undefined one. + if (operand_indexing_map.IsUndefined()) { + result[operand] = {operand_indexing_map}; + break; + } + IndexingMap logical_output_to_linearized_physical_input_map = + operand_indexing_map * + operand_logical_to_linearized_physical_shape; + IndexingMap thread_id_to_linearized_physical_input_map = + *thread_id_to_hero_operand_map * + logical_output_to_linearized_physical_input_map; + thread_id_to_linearized_physical_input_map.Simplify( + GetIndexingMapForInstruction); + result[operand].insert(thread_id_to_linearized_physical_input_map); + } + } + } + } + return result; +} + +} // namespace + +CoalescingAnalysis::CoalescingAnalysis( + const HloInstruction* instr, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + bool use_heuristic) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(instr); + if (!use_heuristic && ComputeCoalescingForAllOperands( + *fusion_adaptor, operands, fusion_analysis, + fusion_interface, mlir_context)) { + return; + } + // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. + is_coalesced_computed_by_heuristic_ = + IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), instr); +} + +CoalescingAnalysis::CoalescingAnalysis( + const HloInstruction* producer, const HloInstruction* consumer, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + bool use_heuristic) { + ProducerConsumerFusion fusion_adaptor(producer, consumer); + if (!use_heuristic && + ComputeCoalescingForAllOperands(fusion_adaptor, operands, fusion_analysis, + fusion_interface, mlir_context)) { + return; + } + // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. + is_coalesced_computed_by_heuristic_ = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer, consumer); +} + +bool CoalescingAnalysis::ComputeCoalescingForAllOperands( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + std::optional thread_id_to_input_memory_layouts = + GetThreadIdToInputMemoryLayoutsMaps(fusion_adaptor, operands, + fusion_analysis, fusion_interface, + mlir_context); + if (!thread_id_to_input_memory_layouts.has_value()) { + return false; + } + for (const HloInstruction* operand : operands) { + if (operand->shape().rank() == 0) { + coalescing_per_operand_.insert({operand, true}); + continue; + } + auto operand_indexing_maps = + thread_id_to_input_memory_layouts->find(operand); + // If there is no indexing map for the operand, it means that it is not used + // in the fusion cluster. + if (operand_indexing_maps == thread_id_to_input_memory_layouts->end()) { + coalescing_per_operand_.insert({operand, true}); + continue; + } + for (const IndexingMap& operand_indexing_map : + operand_indexing_maps->second) { + bool is_coalesced = + IsCoalesced(operand_indexing_map, operand->shape().element_type()); + auto [it, inserted] = + coalescing_per_operand_.insert({operand, is_coalesced}); + if (!inserted) { + it->second &= is_coalesced; + } + if (!is_coalesced) break; + } + } + return true; +} + +bool CoalescingAnalysis::IsReadCoalesced(const HloInstruction* operand) const { + auto it = coalescing_per_operand_.find(operand); + if (it == coalescing_per_operand_.end()) { + return is_coalesced_computed_by_heuristic_; + } + return it->second; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/coalescing_analysis.h b/xla/service/gpu/model/coalescing_analysis.h new file mode 100644 index 0000000000000..300036aa453ba --- /dev/null +++ b/xla/service/gpu/model/coalescing_analysis.h @@ -0,0 +1,77 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_COALESCING_ANALYSIS_H_ +#define XLA_SERVICE_GPU_MODEL_COALESCING_ANALYSIS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" + +namespace xla { +namespace gpu { + +// Computes read coalescing for operands of an instruction or a +// producer-consumer fusion. +// Note, that later, after we migrate away from using the heuristic, we might +// want to use HloFusionAdaptor instead of having two different constructors. +class CoalescingAnalysis { + public: + // Computes read coalescing for operands of `instr`. + CoalescingAnalysis(const HloInstruction* instr, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface = nullptr, + mlir::MLIRContext* mlir_context = nullptr, + bool use_heuristic = true); + + // Computes read coalescing for operands of fused `producer` and `consumer`. + CoalescingAnalysis(const HloInstruction* producer, + const HloInstruction* consumer, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface = nullptr, + mlir::MLIRContext* mlir_context = nullptr, + bool use_heuristic = true); + + // Returns true if the operand is read coalesced. + bool IsReadCoalesced(const HloInstruction* operand) const; + + private: + bool ComputeCoalescingForAllOperands( + const HloFusionAdaptor& fusion_adaptor, + absl::Span operands, + const HloFusionAnalysis& fusion_analysis, + KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context); + + absl::flat_hash_map coalescing_per_operand_; + bool is_coalesced_computed_by_heuristic_ = false; +}; + +// Returns true if all input reads are coalesced. If consumer is not nullptr, +// producer and consumer are considered as one fusion, otherwise it's only the +// producer. +bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const HloInstruction* producer, + const HloInstruction* consumer = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_COALESCING_ANALYSIS_H_ diff --git a/xla/service/gpu/model/coalescing_analysis_test.cc b/xla/service/gpu/model/coalescing_analysis_test.cc new file mode 100644 index 0000000000000..9e0251cf01d73 --- /dev/null +++ b/xla/service/gpu/model/coalescing_analysis_test.cc @@ -0,0 +1,493 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/coalescing_analysis.h" + +#include +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; + +class CoalescingTest : public HloTestBase { + public: + std::vector IsReadCoalescedPerOperand(absl::string_view hlo_string) { + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + HloInstruction* root = module->entry_computation()->root_instruction(); + return IsReadCoalescedPerOperand(root); + } + + std::vector IsReadCoalescedPerOperand(const HloInstruction* root) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + auto analysis = AnalyzeFusion(*root, device_info_); + auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); + auto fusion = dynamic_cast(emitter.value().get()); + EXPECT_TRUE(emitter.ok()); + + CoalescingAnalysis coalescing_analysis(root, root->operands(), analysis, + fusion, &mlir_context_, + /*use_heuristic=*/false); + + std::vector results; + for (const HloInstruction* operand : root->operands()) { + results.push_back(coalescing_analysis.IsReadCoalesced(operand)); + } + return results; + } + + bool IsReadCoalescedHeuristic(absl::string_view hlo_string) { + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto analysis = AnalyzeFusion(*root, device_info_); + return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(), + root->operand(0), root); + } + + protected: + stream_executor::DeviceDescription device_info_ = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + mlir::MLIRContext mlir_context_; +}; + +TEST_F(CoalescingTest, IdentityLayout) { + absl::string_view ir = R"( + HloModule m + fusion { + p0 = f32[100, 200] parameter(0) + p1 = f32[100, 200] parameter(1) + ROOT adthread_x = f32[100, 200] add(p0, p1) + } + ENTRY e { + p0 = f32[100, 200] parameter(0) + p1 = f32[100, 200] parameter(1) + ROOT fusion = f32[100, 200] fusion(p0, p1), kind=kInput, calls=fusion + } + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true)); +} + +TEST_F(CoalescingTest, RhsTransposedLayout) { + absl::string_view ir = R"( + HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + } + ENTRY e { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT fusion = f32[100, 200]{1, 0} fusion(p0, p1), kind=kInput, calls=fusion + } + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x * 100) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, false)); +} + +TEST_F(CoalescingTest, OutputTransposedLayout) { + absl::string_view ir = R"( + HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{1, 0} parameter(1) + ROOT exp = f32[100, 200]{0, 1} add(p0, p1) + } + ENTRY e { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{1, 0} parameter(1) + ROOT fusion = f32[100, 200]{0, 1} fusion(p0, p1), kind=kInput, calls=fusion + } + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x * 200) + // Operand 2: (thread_x) -> (thread_x * 200) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(false, false)); +} + +TEST_F(CoalescingTest, OutputAndLhsTransposedLayout) { + absl::string_view ir = R"( + HloModule m + fusion { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + } + ENTRY e { + p0 = f32[100, 200]{1, 0} parameter(0) + p1 = f32[100, 200]{0, 1} parameter(1) + ROOT fusion = f32[100, 200]{1, 0} fusion(p0, p1), kind=kInput, calls=fusion + } + )"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + // Operand 2: (thread_x) -> (thread_x * 100) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, false)); +} + +TEST_F(CoalescingTest, Transpose) { + absl::string_view ir = R"( + HloModule module + + fusion { + %input = f32[100, 64, 32] parameter(0) + ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={2, 0, 1} + } + + ENTRY entry { + %input = f32[100, 64, 32] parameter(0) + ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 128) for s0 in [0, 7] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, TransposeOfBroadcastHeuristic) { + absl::string_view ir = R"( + HloModule module + + fusion { + input = f32[32, 100, 64] parameter(0) + ROOT slice = f32[32, 100, 1] slice(input), slice={[0:32:1], [0:100:1], [0:1:1]} + } + + ENTRY entry { + p0 = f32[32] parameter(0) + broadcast = f32[100, 64, 32] broadcast(p0), dimensions={2} + transpose = f32[32, 100, 64] transpose(broadcast), dimensions={2, 0, 1} + ROOT %fusion = f32[32, 100, 1] fusion(transpose), kind=kLoop, calls=fusion + })"; + EXPECT_TRUE(IsReadCoalescedHeuristic(ir)); +} + +TEST_F(CoalescingTest, TransposeOfIotaHeuristic) { + absl::string_view ir = R"( + HloModule module + + fusion { + p0 = f32[32, 100, 64] parameter(0) + ROOT slice = f32[32, 100, 1] slice(p0), slice={[0:32:1], [0:100:1], [0:1:1]} + } + + ENTRY entry { + iota = f32[100, 64, 32] iota(), iota_dimension=1 + transpose = f32[32, 100, 64] transpose(iota), dimensions={2, 0, 1} + ROOT %fusion = f32[32, 100, 1] fusion(transpose), kind=kLoop, calls=fusion + })"; + EXPECT_TRUE(IsReadCoalescedHeuristic(ir)); +} + +TEST_F(CoalescingTest, TransposeOfAddHeuristic) { + absl::string_view ir = R"( + HloModule module + + fusion { + p0 = f32[32, 100, 64] parameter(0) + ROOT slice = f32[32, 100, 1] slice(p0), slice={[0:32:1], [0:100:1], [0:1:1]} + } + + ENTRY entry { + input = f32[100, 64, 32] parameter(0) + add = f32[100, 64, 32] add(input, input) + transpose = f32[32, 100, 64] transpose(add), dimensions={2, 0, 1} + ROOT %fusion = f32[32, 100, 1] fusion(transpose), kind=kLoop, calls=fusion + })"; + EXPECT_FALSE(IsReadCoalescedHeuristic(ir)); +} + +TEST_F(CoalescingTest, TransposeOnlyOuterDims) { + absl::string_view ir = R"( + HloModule module + + fusion { + %input = f32[100, 32, 64] parameter(0) + ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={1, 0, 2} + } + + ENTRY entry { + %input = f32[100, 32, 64] parameter(0) + ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: + // (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) + // for s0 in [0, 3] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, PadOp) { + absl::string_view ir = R"( + HloModule module + fusion { + p0 = f32[997, 436] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[1024, 512] pad(p0, p1), padding=10_17x24_52 + } + ENTRY entry { + p0 = f32[997, 436] parameter(0) + p1 = f32[] parameter(1) + ROOT %fusion = f32[1024, 512] fusion(p0, p1), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x * 4 + s0 - 4384) + // for s0 in [0, 3] and thread_x * 4 + s0 in [24, 459] + // Operand 2: (thread_x) -> () + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true)); +} + +TEST_F(CoalescingTest, RowReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,512] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + ENTRY entry { + %input = f32[100,64,512] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 32) for s0 in [0, 15] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, MultiRowReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,4] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add + } + ENTRY entry { + %input = f32[100,64,4] parameter(0) + ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x) -> (thread_x) + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, ColumnReduction) { + absl::string_view ir = R"( + HloModule module + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + fusion { + %input = f32[100,64,32] parameter(0) + %c0 = f32[] constant(0) + ROOT reduce = f32[100,32] reduce(%input, %c0), + dimensions={1}, to_apply=add + } + ENTRY entry { + %input = f32[100,64,32] parameter(0) + ROOT %fusion = f32[100,32] fusion(%input), kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 1024) for s0 in [0, 1] + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true)); +} + +TEST_F(CoalescingTest, VariadicReduceViaLoopEmitter) { + absl::string_view ir = R"( + HloModule module + max { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + max01 = s32[] maximum(p0, p1) + max23 = s32[] maximum(p2, p3) + ROOT max = (s32[], s32[]) tuple(max01, max23) + } + fusion { + p0 = s32 [5696,10,4] parameter(0) + p1 = s32 [5696,10,4] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT reduce = (s32[5696,4], s32[5696,4]) reduce(s32[5696,10,4] p0, + s32[5696,10,4] p1, s32[] p2, s32[] p3), dimensions={1}, to_apply=max + } + ENTRY entry { + p0 = s32 [5696,10,4] parameter(0) + p1 = s32 [5696,10,4] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT f = (s32[5696,4], s32[5696,4]) fusion(p0, p1, p2, p3), + kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operands 1, 2: (d0)[s0] -> ((d0 floordiv 4) * 40 + d0 mod 4 + s0 * 4) + // for s0 in [0, 9]. + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + +TEST_F(CoalescingTest, VariadicReduceViaReductionEmitter) { + absl::string_view ir = R"( + HloModule module + max { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + max01 = s32[] maximum(p0, p1) + max23 = s32[] maximum(p2, p3) + ROOT max = (s32[], s32[]) tuple(max01, max23) + } + fusion { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT reduce = (s32[32], s32[32]) + reduce(s32[32,40] p0, s32[32,40] p1, s32[] p2, s32[] p3), + dimensions={1}, to_apply=max + } + ENTRY entry { + p0 = s32[32,40] parameter(0) + p1 = s32[32,40] parameter(1) + p2 = s32[] parameter(2) + p3 = s32[] parameter(3) + ROOT f = (s32[32], s32[32]) fusion(p0, p1, p2, p3), + kind=kInput, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operands 1, 2: (d0)[s0] -> (d0 + s0 * 32) + // for s0 in [0, 1] and d0 + s0 * 32 in [0, 39]. + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + +TEST_F(CoalescingTest, Gather) { + absl::string_view ir = R"( + HloModule module + fusion { + operand = f32[33, 76, 70] parameter(0) + indices = s32[1806, 2] parameter(1) + ROOT gather = f32[1806, 7, 8, 4] gather(operand, indices), + offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0,1}, + index_vector_dim=1, slice_sizes={7,8,4} + } + ENTRY entry { + p0 = f32[33, 76, 70] parameter(0) + p1 = s32[1806, 2] parameter(1) + ROOT %fusion = f32[1806, 7, 8, 4] fusion(p0, p1), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (d0)[s0] -> ( + // (d0 floordiv 8) * 5320 + (d0 mod 8) * 70 + s0 * 70 + 34) for s0 in [0, 3] + // Operand 2: (d0)[s0] -> (s0) + // for s0 in [0, 1]. + EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(false, true)); +} + +TEST_F(CoalescingTest, DynamicSlice) { + absl::string_view ir = R"( + HloModule module + fusion { + %src = s32[2,2,258] parameter(0) + %of1 = s32[] parameter(1) + %of2 = s32[] parameter(2) + %of3 = s32[] parameter(3) + ROOT %ds = s32[1,2,32] dynamic-slice(s32[2,2,258] %src, + s32[] %of1, s32[] %of2, s32[] %of3), + dynamic_slice_sizes={1, 2, 32} + } + ENTRY entry { + %p0 = s32[2,2,258] parameter(0) + %p1 = s32[] parameter(1) + %p2 = s32[] parameter(2) + %p3 = s32[] parameter(3) + ROOT %fusion = s32[1,2,32] fusion(p0, p1, p2, p3), kind=kLoop, calls=fusion + })"; + // thread_x to linearized input mapping for thread_x in [0, 31]: + // Operand 1: (d0) -> (d0). + EXPECT_THAT(IsReadCoalescedPerOperand(ir), + ElementsAre(true, true, true, true)); +} + +TEST_F(CoalescingTest, UnusedParameter) { + Shape shape = ShapeUtil::MakeShape(F32, {100000}); + + auto module = std::make_unique("m", HloModuleConfig{}); + HloComputation::Builder b("b"); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto p1 = b.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + + HloComputation::Builder sub_builder("subcomp"); + HloInstruction* p0f = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p0f")); + // p1f is not used. + HloInstruction* p1f = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p1f")); + ASSERT_NE(p1f, nullptr); + sub_builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0f)); + + HloComputation* subcomp = module->AddEmbeddedComputation(sub_builder.Build()); + auto fusion = HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {p0, p1}, subcomp); + b.AddInstruction(std::move(fusion)); + module->AddEntryComputation(b.Build()); + + EXPECT_THAT(IsReadCoalescedPerOperand( + module->entry_computation()->root_instruction()), + ElementsAre(true, true)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/fusion_analysis_cache.cc b/xla/service/gpu/model/fusion_analysis_cache.cc index 00a294413506a..ba033fb74f9d8 100644 --- a/xla/service/gpu/model/fusion_analysis_cache.cc +++ b/xla/service/gpu/model/fusion_analysis_cache.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,22 +15,25 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" +#include + +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" namespace xla::gpu { -const std::optional& HloFusionAnalysisCache::Get( +const HloFusionAnalysis& HloFusionAnalysisCache::Get( const HloInstruction& instruction) { { - absl::ReaderMutexLock lock(&mutex_); + absl::MutexLock lock(&mutex_); auto it = analyses_.find(instruction.unique_id()); if (it != analyses_.end()) { return it->second; } } - std::optional analysis = - AnalyzeFusion(instruction, device_info_); + HloFusionAnalysis analysis = AnalyzeFusion(instruction, device_info_); absl::MutexLock lock(&mutex_); // If some other thread created an entry for this key concurrently, return @@ -40,21 +43,22 @@ const std::optional& HloFusionAnalysisCache::Get( return it->second; } - return analyses_[instruction.unique_id()] = std::move(analysis); + return analyses_.emplace(instruction.unique_id(), std::move(analysis)) + .first->second; } -const std::optional& HloFusionAnalysisCache::Get( +const HloFusionAnalysis& HloFusionAnalysisCache::Get( const HloInstruction& producer, const HloInstruction& consumer) { std::pair key{producer.unique_id(), consumer.unique_id()}; { - absl::ReaderMutexLock lock(&mutex_); + absl::MutexLock lock(&mutex_); auto it = producer_consumer_analyses_.find(key); if (it != producer_consumer_analyses_.end()) { return it->second; } } - std::optional analysis = + HloFusionAnalysis analysis = AnalyzeProducerConsumerFusion(producer, consumer, device_info_); absl::MutexLock lock(&mutex_); @@ -69,7 +73,8 @@ const std::optional& HloFusionAnalysisCache::Get( producer.unique_id()); consumers_for_producers_[producer.unique_id()].push_back( consumer.unique_id()); - return producer_consumer_analyses_[key] = std::move(analysis); + return producer_consumer_analyses_.emplace(key, std::move(analysis)) + .first->second; } void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { @@ -90,4 +95,13 @@ void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { } } +void HloFusionAnalysisCache::Clear() { + absl::MutexLock lock(&mutex_); + + analyses_.clear(); + producer_consumer_analyses_.clear(); + consumers_for_producers_.clear(); + producers_for_consumers_.clear(); +} + } // namespace xla::gpu diff --git a/xla/service/gpu/model/fusion_analysis_cache.h b/xla/service/gpu/model/fusion_analysis_cache.h index b13c0a102f370..4cf6053e03fed 100644 --- a/xla/service/gpu/model/fusion_analysis_cache.h +++ b/xla/service/gpu/model/fusion_analysis_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ #define XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ +#include +#include + +#include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,25 +39,27 @@ class HloFusionAnalysisCache { // Returns the analysis for the given instruction, creating it if it doesn't // exist yet. Do not call concurrently with `Invalidate` for the same key. - const std::optional& Get( - const HloInstruction& instruction); + const HloFusionAnalysis& Get(const HloInstruction& instruction); // Returns the analysis for the given producer/consumer pair. - const std::optional& Get(const HloInstruction& producer, - const HloInstruction& consumer); + const HloFusionAnalysis& Get(const HloInstruction& producer, + const HloInstruction& consumer); // Removes the cache entry for the given instruction, if it exists. Also // removes all producer-consumer fusions that involve this instruction. void Invalidate(const HloInstruction& instruction); + // Delete all cache entries. + void Clear(); + private: const stream_executor::DeviceDescription& device_info_; absl::Mutex mutex_; -// All `int` keys and values here are unique instruction IDs. - absl::node_hash_map> analyses_; - absl::node_hash_map, std::optional> + // All `int` keys and values here are unique instruction IDs. + absl::node_hash_map analyses_; + absl::node_hash_map, HloFusionAnalysis> producer_consumer_analyses_; // For each instruction `producer`, contains the `consumer`s for which we have diff --git a/xla/service/gpu/model/fusion_analysis_cache_test.cc b/xla/service/gpu/model/fusion_analysis_cache_test.cc index edacd6a7c8666..d507a88a4d803 100644 --- a/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ b/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,15 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" +#include +#include +#include "absl/strings/string_view.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/hlo_parser.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -52,17 +56,17 @@ TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { auto* negate = computation->GetInstructionWithName("n0"); auto* fusion = module->entry_computation()->root_instruction(); - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + EXPECT_THAT(cache_.Get(*fusion).fusion_roots(), ::testing::ElementsAre(negate)); computation->set_root_instruction(broadcast); - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + EXPECT_THAT(cache_.Get(*fusion).fusion_roots(), ::testing::ElementsAre(negate)) << "Analysis should be cached."; cache_.Invalidate(*fusion); - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + EXPECT_THAT(cache_.Get(*fusion).fusion_roots(), ::testing::ElementsAre(broadcast)) << "Analysis should have been recomputed"; } @@ -96,17 +100,17 @@ TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { auto* computation = module->GetComputationWithName("f"); auto* constant = computation->GetInstructionWithName("c0"); - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + EXPECT_EQ(cache_.Get(*fusion, *neg).GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); computation->set_root_instruction(constant); - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + EXPECT_EQ(cache_.Get(*fusion, *neg).GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction) << "Analysis should be cached."; cache_.Invalidate(*fusion); - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + EXPECT_EQ(cache_.Get(*fusion, *neg).GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kLoop) << "Analysis should have been recomputed"; } diff --git a/xla/service/gpu/model/gpu_collective_performance_model.cc b/xla/service/gpu/model/gpu_collective_performance_model.cc new file mode 100644 index 0000000000000..f415b0e34d206 --- /dev/null +++ b/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -0,0 +1,483 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_collective_performance_model.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/numbers.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/nvml/include/nvml.h" +#endif // GOOGLE_CUDA +namespace xla { +namespace gpu { + +namespace { + +int64_t GetNcclMaxNumChannels( + GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { + int64_t max_nchannels = 0; + switch (algorithm) { + // Tree and Ring algos share the same max channel number. + case GpuPerformanceWithCollectiveModel::RING: + case GpuPerformanceWithCollectiveModel::TREE: + max_nchannels = GpuPerformanceWithCollectiveModel::kMaxNumChannelsRing; + break; + } + const char* env = std::getenv("NCCL_MAX_NCHANNELS"); + if (env != nullptr) { + int64_t max_nchannels_from_env; + if (absl::SimpleAtoi(env, &max_nchannels_from_env)) { + max_nchannels = std::min(max_nchannels_from_env, max_nchannels); + } + } + return max_nchannels; +} +// CostModelKind GetCostModelKind(){ +// const char* env = std::getenv("XLA_COLLECTIVE_COST_MODEL_KIND"); +// } +int64_t GetMinNumberOfChannels( + GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { + int64_t min_nchannels = 0; + switch (algorithm) { + // Tree and Ring algos share the same min channel number. + case GpuPerformanceWithCollectiveModel::RING: + case GpuPerformanceWithCollectiveModel::TREE: + min_nchannels = 1; + break; + } + const char* env = std::getenv("NCCL_MIN_NCHANNELS"); + if (env != nullptr) { + int64_t min_nchannels_from_env; + if (absl::SimpleAtoi(env, &min_nchannels_from_env)) { + min_nchannels = std::min(min_nchannels_from_env, min_nchannels); + } + } + return min_nchannels; +} + +int GetNumThreads(int warp_size, int min_num_threads, int max_num_threads, + int default_num_threads) { + int threads_from_env = default_num_threads; + const char* env = std::getenv("NCCL_NTHREADS"); + if (env != nullptr) { + CHECK(absl::SimpleAtoi(env, &threads_from_env)); + } + int num_threads = threads_from_env; + if (num_threads > 0) { + if (num_threads % warp_size != 0) { + num_threads = max_num_threads; + } else if (num_threads > max_num_threads) { + num_threads = max_num_threads; + } else if (num_threads < min_num_threads) { + num_threads = min_num_threads; + } + } else { + num_threads = default_num_threads; + } + return num_threads; +} + +float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, + const double* bandwidths_table) { + switch (cc.major) { + case se::CudaComputeCapability::VOLTA: + return bandwidths_table[0]; + case se::CudaComputeCapability::AMPERE: + return bandwidths_table[1]; + case se::CudaComputeCapability::HOPPER: + return bandwidths_table[2]; + } + return -1; +} + +} // namespace + +// Returns NVLink bw in GB/s +/*static*/ +float GpuPerformanceWithCollectiveModel::GetNvlinkBw( + se::CudaComputeCapability compute_capability) { + return compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER) + ? kSm90NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE) + ? kSm80NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA) + ? kSm70NvlinkBandwidth + : compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_) + ? kSm60NvlinkBandwidth + : kSm80NvlinkBandwidth; +} + +/*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { +#if GOOGLE_CUDA + void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); + CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; + + struct SymbolEntry { + void** functor; + char const* name; + }; + + std::vector symbols = { + {(void**)&xla_nvmlInit, "nvmlInit_v2"}, + {(void**)&xla_nvmlShutdown, "nvmlShutdown"}, + {(void**)&xla_nvmlDeviceGetHandleByIndex, "nvmlDeviceGetHandleByIndex"}, + {(void**)&xla_nvmlDeviceGetNvLinkCapability, + "nvmlDeviceGetNvLinkCapability"}, + }; + for (SymbolEntry se : symbols) { + *se.functor = dlsym(libhandle, se.name); + } + nvmlReturn_t init_result = xla_nvmlInit(); + return init_result == NVML_SUCCESS; +#else + return false; +#endif // GOOGLE_CUDA +} + +/*static*/ bool GpuPerformanceWithCollectiveModel::ShutdownNvml() { +#if GOOGLE_CUDA + nvmlReturn_t shutdown_result = xla_nvmlShutdown(); + return shutdown_result == NVML_SUCCESS; +#else + return false; +#endif // GOOGLE_CUDA +} + +/*static*/ uint32_t +GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { +#if GOOGLE_CUDA + // We will use nvml library to detect nvlink capability + // to see if it supports p2p communication. + // We first load libnvidia-ml.so and assign symbols to function pointers + // to avoid linking errors. + // Then gpu 0 will be used to query for nvlink capability, note that + // we only look at link 0 of gpu 0 since all other links are assumed + // to have the same capability. + CHECK(InitNvml()) << "NVML init failed."; + nvmlDevice_t nvml_device; + nvmlReturn_t get_device_result = + xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); + CHECK(get_device_result == NVML_SUCCESS); + + uint32_t supported_p2p = 0; + + nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( + nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, + &supported_p2p); + if(nvlink_cap_result==NVML_ERROR_NOT_SUPPORTED) + { + VLOG(8) << "nvmlDeviceGetNvLinkCapability is not supported."; + return 0; + } + CHECK(nvlink_cap_result == NVML_SUCCESS); + CHECK(ShutdownNvml()) << "NVML shutdown failed."; + return supported_p2p; +#else + return 0; +#endif // GOOGLE_CUDA +} + +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeAllreduceTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + // We use nccl group call to launch multiple allreduces so launch overhead + // only occurs once. + absl::Duration total_time = kNcclKernelLaunchOverhead; + stream_executor::CudaComputeCapability compute_cap = + gpu_device_info.cuda_compute_capability(); + + int64_t size_of_speed_array = kIntraNodeSpeeds.size(); + int64_t size_of_sm90_speed_array = kIntraNodeSpeedsSm90.size(); + + int num_speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? size_of_sm90_speed_array + : size_of_speed_array; + const double* speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? kIntraNodeSpeedsSm90.data() + : kIntraNodeSpeeds.data(); + + int speed_index = 0; + float max_sys_bw = + GetMaxSysBwFromGpu(compute_cap, kLowLatencyMaxBandwidths.data()); + + CHECK_GT(max_sys_bw, 0); + + while ((speed_index < num_speeds - 1) && speeds[speed_index] > max_sys_bw) { + speed_index++; + } + float bw_intra_node = speeds[speed_index]; + int64_t num_devices = cost_analysis->NumOfDevices(instr); + + int64_t min_nchannels = + std::max(num_devices, GetMinNumberOfChannels(CollectiveAlgo::RING)); + int64_t num_channels = + std::max(min_nchannels, GetNcclMaxNumChannels(CollectiveAlgo::RING)); + int default_threads = + (bw_intra_node * num_channels <= kPciBandwidth) ? 256 : kLL128NumThreads; + + int warp_size = gpu_device_info.threads_per_warp(); + int num_threads = GetNumThreads(warp_size, kLL128NumThreads / 4, + kLL128NumThreads, default_threads); + + // Since channels are pipelined together, compute time will only occur as in a + // single channel. + absl::Duration compute_time_per_channel = + ComputeTime(gpu_device_info, + cost_analysis->flop_count(instr) / num_channels, num_threads); + total_time += compute_time_per_channel; + + uint32_t supported_p2p = CheckIfNvlinkSupportsP2P(); + + if (supported_p2p == 0) { + VLOG(8) << "Nvlink doesn't support p2p communication. Model will " + "continue using default system bandwidth."; + } else { + VLOG(8) << "Nvlink supports p2p communication, setting intra node " + "bandwidth to nvlink bw."; + bw_intra_node = GetNvlinkBw(compute_cap); + } + + double bus_bandwidth = bw_intra_node * num_channels; + + // Get per channel LL128 ring bandwidth + double per_channel_ring_ll128_Bw = + GetMaxSysBwFromGpu(compute_cap, kPerChannelMaxRingLL128Bandwidths.data()); + + bus_bandwidth = std::min(bus_bandwidth * kRingAlgorithmDiscountFactor, + num_channels * per_channel_ring_ll128_Bw); + auto bandswidth_vector = + GetInterInnerBandwidths(instr, cost_analysis, gpu_device_info); + double intra_node_bus_bandwidth = bandswidth_vector[0]; + double inner_node_bus_bandwidth = bandswidth_vector[1]; + // allreduce send(single direction) inter nodes: instr_bytes*(num_devices-kInnerNodeGpu) + auto instr_bytes =cost_analysis->bytes_accessed(instr); + auto intra_nodes_numel_bytes=instr_bytes*(num_devices-kInnerNodeGpu); + + auto local_gpu = std::min(kInnerNodeGpu, num_devices); + // allreduce send(single direction) inter nodes: instr_bytes*(local_gpu-1) + auto inner_node_numel_bytes=instr_bytes*(local_gpu-1); + + double actual_bandwidth = bus_bandwidth * cost_analysis->ScalingRatio(instr); + absl::Duration communication_time = absl::Milliseconds( + std::max(intra_nodes_numel_bytes / (intra_node_bus_bandwidth * 1e6), + inner_node_numel_bytes / (inner_node_bus_bandwidth * 1e6))); + total_time += communication_time; + return total_time; +} +std::vector GpuPerformanceWithCollectiveModel::GetInterInnerBandwidths( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + const char* XLA_INTERNODE_BW = std::getenv("XLA_INTERNODE_BW"); + const char* XLA_INNERNODE_BW = std::getenv("XLA_INNERNODE_BW"); + if (XLA_INTERNODE_BW != nullptr&& XLA_INNERNODE_BW!=nullptr) + { + double inner_node_bus_bandwidth = std::stod(XLA_INNERNODE_BW); + double intra_node_bus_bandwidth = std::stod(XLA_INTERNODE_BW); + return std::vector( + {intra_node_bus_bandwidth, inner_node_bus_bandwidth}); + } + stream_executor::CudaComputeCapability compute_cap = + gpu_device_info.cuda_compute_capability(); + + int64_t size_of_speed_array = kIntraNodeSpeeds.size(); + int64_t size_of_sm90_speed_array = kIntraNodeSpeedsSm90.size(); + + int num_speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? size_of_sm90_speed_array + : size_of_speed_array; + const double* speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER + ? kIntraNodeSpeedsSm90.data() + : kIntraNodeSpeeds.data(); + + int speed_index = 0; + float max_sys_bw = + GetMaxSysBwFromGpu(compute_cap, kLowLatencyMaxBandwidths.data()); + + CHECK_GT(max_sys_bw, 0); + + while ((speed_index < num_speeds - 1) && speeds[speed_index] > max_sys_bw) { + speed_index++; + } + float bw_intra_node = speeds[speed_index]; + int64_t num_devices = cost_analysis->NumOfDevices(instr); + + int64_t min_nchannels = + std::max(num_devices, GetMinNumberOfChannels(CollectiveAlgo::RING)); + int64_t num_channels = + std::max(min_nchannels, GetNcclMaxNumChannels(CollectiveAlgo::RING)); + int default_threads = + (bw_intra_node * num_channels <= kPciBandwidth) ? 256 : kLL128NumThreads; + + int warp_size = gpu_device_info.threads_per_warp(); + int num_threads = GetNumThreads(warp_size, kLL128NumThreads / 4, + kLL128NumThreads, default_threads); + + uint32_t supported_p2p = CheckIfNvlinkSupportsP2P(); + + if (supported_p2p == 0) { + VLOG(8) << "Nvlink doesn't support p2p communication. Model will " + "continue using default system bandwidth."; + } else { + VLOG(8) << "Nvlink supports p2p communication, setting intra node " + "bandwidth to nvlink bw."; + bw_intra_node = GetNvlinkBw(compute_cap); + } + // Get per channel LL128 ring bandwidth + double per_channel_ring_ll128_Bw = + GetMaxSysBwFromGpu(compute_cap, kPerChannelMaxRingLL128Bandwidths.data()); + double bus_bandwidth = bw_intra_node * num_channels; + double intra_node_bus_bandwidth = + bw_intra_node * num_channels * kRingAlgorithmDiscountFactor; + double inner_node_bus_bandwidth = num_channels * per_channel_ring_ll128_Bw; + // maybe get from env is better? + return std::vector( + {intra_node_bus_bandwidth, inner_node_bus_bandwidth}); +} +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeAllgatherTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + // allgather: all devices send their data to all other devices. there is + // allgather ring method + // TODO if using bruck algorithm, the time will be different + // communication inter_node time = bytes_accessed * (total_gpu - local_gpu) / + // bandwidth + VLOG(5) << instr.ToString() << " ComputeAllgatherTime begin"; + absl::Duration total_time = kKernelLaunchOverhead; + auto bandswidth_vector = + GetInterInnerBandwidths(instr, cost_analysis, gpu_device_info); + double intra_node_bus_bandwidth = bandswidth_vector[0]; + double inner_node_bus_bandwidth = bandswidth_vector[1]; + + auto numel_bytes = cost_analysis->bytes_accessed(instr); + + int64_t total_gpu = cost_analysis->NumOfDevices(instr); + //TODO: total_gpu is uncorrect + + int64_t intra_nodes = (total_gpu - kInnerNodeGpu) / kInnerNodeGpu; + // + auto intra_nodes_numel_bytes = + numel_bytes * + ((total_gpu - kInnerNodeGpu) > 0 ? (total_gpu - kInnerNodeGpu) : 0); + auto inner_node_numel_bytes = + numel_bytes * (std::min(kInnerNodeGpu, total_gpu) - 1); + double actual_bandwidth_ratio = cost_analysis->ScalingRatio(instr); + absl::Duration communication_time = absl::Milliseconds( + std::max(intra_nodes_numel_bytes / (actual_bandwidth_ratio*intra_node_bus_bandwidth * 1e6), + inner_node_numel_bytes / (actual_bandwidth_ratio*inner_node_bus_bandwidth * 1e6))); + VLOG(5) << instr.ToString() << " numel_bytes:" << numel_bytes + << " intra_nodes_numel_bytes: " << intra_nodes_numel_bytes + << " inner_node_numel_bytes: " << inner_node_numel_bytes + << " intra_node_bus_bandwidth: " << intra_node_bus_bandwidth + << "GBps,inner_node_bus_bandwidth: " << inner_node_bus_bandwidth + << "GBps communication_time: " << communication_time; + total_time += communication_time; + return total_time; +} +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeReducescatterTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + VLOG(5) << instr.ToString() << " ComputeReducescatterTime begin"; + absl::Duration total_time = kKernelLaunchOverhead; + auto bandswidth_vector = + GetInterInnerBandwidths(instr, cost_analysis, gpu_device_info); + double intra_node_bus_bandwidth = bandswidth_vector[0]; + double inner_node_bus_bandwidth = bandswidth_vector[1]; + auto numel_bytes = cost_analysis->bytes_accessed(instr); + auto num_devices = cost_analysis->NumOfDevices(instr); + auto local_gpu = std::min(kInnerNodeGpu, num_devices); + if (num_devices <= 1) { + return kKernelLaunchOverhead; + } + if (intra_node_bus_bandwidth == 0 || inner_node_bus_bandwidth == 0) { + return kKernelLaunchOverhead; + } + // compute and comm on op will overlap,ignore compute + + auto intra_nodes_numel_bytes = + numel_bytes * (num_devices - kInnerNodeGpu) / num_devices; + auto inner_node_numel_bytes = numel_bytes * (local_gpu - 1) / num_devices; + double actual_bandwidth_ratio = cost_analysis->ScalingRatio(instr); + absl::Duration communication_time = absl::Milliseconds( + std::max(intra_nodes_numel_bytes / (actual_bandwidth_ratio*intra_node_bus_bandwidth * 1e6), + inner_node_numel_bytes / (actual_bandwidth_ratio*inner_node_bus_bandwidth * 1e6))); + VLOG(5) << instr.ToString() << " numel_bytes:" << numel_bytes + << " intra_nodes_numel_bytes: " << intra_nodes_numel_bytes + << " inner_node_numel_bytes: " << inner_node_numel_bytes + << " intra_node_bus_bandwidth: " << intra_node_bus_bandwidth + << "GBps,inner_node_bus_bandwidth: " << inner_node_bus_bandwidth + << "GBps communication_time: " << communication_time; + total_time += communication_time; + return total_time; +} +/*static*/ absl::Duration +GpuPerformanceWithCollectiveModel::ComputeCollectiveTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info) { + if (cost_analysis->NumOfDevices(instr) == 1) { + VLOG(8) << "Returning only kernel launch overhead for a single partition."; + return kKernelLaunchOverhead; + } + + if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { + VLOG(8) << "Returning 0 cost for async done op " << instr.name(); + return absl::ZeroDuration(); + } + switch (instr.opcode()) { + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + return ComputeAllreduceTime(instr, cost_analysis, gpu_device_info); + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + return ComputeAllgatherTime(instr, cost_analysis, gpu_device_info); + case HloOpcode::kReduceScatter: + return ComputeReducescatterTime(instr, cost_analysis, gpu_device_info); + // asyncop+reducescatter + case HloOpcode::kAsyncStart: { + if (instr.async_wrapped_instruction()->opcode() == + HloOpcode::kReduceScatter) { + return ComputeReducescatterTime(*instr.async_wrapped_instruction(), + cost_analysis, gpu_device_info); + } + } + + default: { + LOG(WARNING) + << "Runtime estimate for " << instr.name() << instr.opcode() + << " not implemented. Returning only the kernel launch time."; + return kKernelLaunchOverhead; + } + } +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_collective_performance_model.h b/xla/service/gpu/model/gpu_collective_performance_model.h new file mode 100644 index 0000000000000..0a18d5d71d3f0 --- /dev/null +++ b/xla/service/gpu/model/gpu_collective_performance_model.h @@ -0,0 +1,156 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ + +#include +#include + +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/stream_executor/device_description.h" + +#if GOOGLE_CUDA +#include + +#include "third_party/gpus/cuda/nvml/include/nvml.h" +// Below is a list of function pointers to be used +// for querying device properties through nvml library. +#define NVML_FUNCTOR(name, rettype, args) \ + inline rettype(*xla_##name) args = nullptr; + +NVML_FUNCTOR(nvmlInit, nvmlReturn_t, ()) +NVML_FUNCTOR(nvmlShutdown, nvmlReturn_t, ()) +NVML_FUNCTOR(nvmlDeviceGetHandleByIndex, nvmlReturn_t, + (unsigned int index, nvmlDevice_t* device)) +NVML_FUNCTOR(nvmlDeviceGetNvLinkCapability, nvmlReturn_t, + (nvmlDevice_t device, unsigned int link, + nvmlNvLinkCapability_t capability, unsigned int* capResult)) + +#endif + +namespace xla { +namespace gpu { + +class GpuPerformanceWithCollectiveModel : public GpuPerformanceModelBase { + public: + // Different algorithms that can be used to perform the collective. + enum CollectiveAlgo { + RING = 0, + TREE, + }; + // nccl can't reach bandwidth when below 32M, kind=MEATURE,we fit a linear model + enum CostModelKind{ + THEORY=0, + MEATURE=1, + }; + + // Table for max system bandwidths GB/s for using NCCL's low latency + // algorithm. This is used for intra-node estimate. + static constexpr std::array kLowLatencyMaxBandwidths = { + 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ + }; + + // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a + // single-node + static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { + 20.0 /* Volta */, + 20.0 /* Ampere */, + 36.7 /* Hopper */, + }; + + // Nvlink unidirectional bandwidth for different compute cap. Note this is per + // lane bandwidth. + static constexpr double kSm60NvlinkBandwidth = 18.0; + static constexpr double kSm70NvlinkBandwidth = 20.0; + static constexpr double kSm80NvlinkBandwidth = 20.0; + static constexpr double kSm90NvlinkBandwidth = 20.0; + + // PCIE bandwidth for PCI Gen3 x16 + static constexpr double kPciBandwidth = 12.0; + + // Discount factor for ring algorithm + static constexpr double kRingAlgorithmDiscountFactor = 0.92; + + // Different tiers for intra-node bandwidth. + static constexpr std::array kIntraNodeSpeeds = { + 40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0}; + // SM90 has different bandwidths. + static constexpr std::array kIntraNodeSpeedsSm90 = { + 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0}; + + // Maximum number of channels allowed by NCCL + static constexpr int64_t kMaxNumChannelsRing = 16; + + // ll128 is by default enabled for Volta, Ampere and Hopper, ll128 by default + // launches 640 threads. + static constexpr int64_t kLL128NumThreads = 640; + // baseon meature, when below 32M,it can't reach bandwidth + static constexpr int64_t kLinearThreshold = 33554432; + + // TODO: hard code, inner_node_gpu=8,we can't load this attr from instr + static constexpr int64_t kInnerNodeGpu = 8; + static constexpr absl::Duration kNcclKernelLaunchOverhead = + absl::Microseconds(5); + + static absl::Duration ComputeCollectiveTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); + // x = log2(byteaccessed); busbw(GB) = 3.12*x^2 - 65.14*x + 338.66 + static constexpr std::array allreduce_bandwidth_curve = { + 3.12,-65.14,338.66 + }; + static constexpr std::array allgather_bandwidth_curve = { + 3.12,-65.14,338.66 + }; + + + // Returns NVLink bw in GB/s + static float GetNvlinkBw(se::CudaComputeCapability compute_capability); + + // Initialize nvml library. + static bool InitNvml(); + + // Shut down nvml library. + static bool ShutdownNvml(); + + // This checks if the nvlink supports direct P2P communication, + // If not, we will use PCIE bandwidth to estimate latency. + static uint32_t CheckIfNvlinkSupportsP2P(); + + + private: + static absl::Duration ComputeAllreduceTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); + static absl::Duration ComputeAllgatherTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); + static absl::Duration ComputeReducescatterTime( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); + // return internode/innernode bandwidths in GB/s + static std::vector GetInterInnerBandwidths( + const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_COLLECTIVE_PERFORMANCE_MODEL_H_ diff --git a/xla/service/gpu/model/gpu_collective_performance_model_test.cc b/xla/service/gpu/model/gpu_collective_performance_model_test.cc new file mode 100644 index 0000000000000..4e68c8704001c --- /dev/null +++ b/xla/service/gpu/model/gpu_collective_performance_model_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#include +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using GpuPerformanceWithCollectiveModelTest = HloTestBase; + +TEST_F(GpuPerformanceWithCollectiveModelTest, TestNvmlLibraryLoading) { +#if GOOGLE_CUDA + EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); + // After successful init, we try to use one of the + // nvml functions to see if the result is good. + nvmlDevice_t nvml_device; + nvmlReturn_t get_device_result = + xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); + EXPECT_TRUE(get_device_result == NVML_SUCCESS); + + EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); + +#endif // GOOGLE_CUDA +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_cost_model_stats_collection.cc b/xla/service/gpu/model/gpu_cost_model_stats_collection.cc index 14472872c3ab9..834d9151f3910 100644 --- a/xla/service/gpu/model/gpu_cost_model_stats_collection.cc +++ b/xla/service/gpu/model/gpu_cost_model_stats_collection.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,20 @@ limitations under the License. #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/model/gpu_performance_model.h" -#include "xla/statusor.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "tsl/platform/status.h" namespace xla { namespace gpu { -StatusOr GpuCostModelStatsCollection::Run( +absl::StatusOr GpuCostModelStatsCollection::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Scan all computations for fusion instructions. diff --git a/xla/service/gpu/model/gpu_cost_model_stats_collection.h b/xla/service/gpu/model/gpu_cost_model_stats_collection.h index 26bf5cd8d30c2..6f03d26e985f4 100644 --- a/xla/service/gpu/model/gpu_cost_model_stats_collection.h +++ b/xla/service/gpu/model/gpu_cost_model_stats_collection.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ limitations under the License. #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -44,7 +43,7 @@ class GpuCostModelStatsCollection : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc b/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc index 68fd748a27a19..9e8c78c6aa44f 100644 --- a/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc +++ b/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,8 @@ limitations under the License. #include -#include "absl/status/statusor.h" +#include +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" @@ -28,7 +29,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -67,8 +68,11 @@ TEST_F(GpuCostModelStatsCollectionTest, FusinInEntryComputation) { EXPECT_FALSE(cost_model_stats_.Run(module.get()).value()); HloInstruction* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(auto backend_config, - root->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + root->backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + EXPECT_TRUE(backend_config.has_reification_cost()); EXPECT_GT(backend_config.reification_cost().end_to_end_cycles(), 0); } @@ -103,8 +107,11 @@ TEST_F(GpuCostModelStatsCollectionTest, FusinInWhileComputation) { ->root_instruction() ->while_body() ->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(auto backend_config, - root->backend_config()); + TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, + root->backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + EXPECT_TRUE(backend_config.has_reification_cost()); EXPECT_GT(backend_config.reification_cost().end_to_end_cycles(), 0); } diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis.cc b/xla/service/gpu/model/gpu_hlo_cost_analysis.cc index e468b1825a0b2..20c258b004f2b 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis.cc +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,17 +16,22 @@ limitations under the License. #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include +#include #include #include -#include -#include +#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/map_util.h" @@ -38,8 +43,12 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -59,14 +68,14 @@ static constexpr absl::string_view kCollNumDevicesKey = // type of hardware below. // TODO TJ this needs to be hosted somewhere more centralized. -Status GpuHloCostAnalysis::Preprocess(const HloInstruction* hlo) { +absl::Status GpuHloCostAnalysis::Preprocess(const HloInstruction* hlo) { TF_RETURN_IF_ERROR(HloCostAnalysis::Preprocess(hlo)); current_properties_[kIRSizeKey] = 1; current_properties_[kBasicBlockSplitCountKey] = ElementalIrEmitter::OpInvalidatesCache(hlo); - return OkStatus(); + return absl::OkStatus(); } float GpuHloCostAnalysis::ScalingRatio(const HloInstruction& hlo) const { @@ -85,10 +94,10 @@ int64_t GpuHloCostAnalysis::FusionParameterReadBytes( if (!options_.count_multiple_input_accesses) { utilization = fmin(utilization, 1.0); } - return GetShapeSize(hlo->shape()) * utilization; + return std::llround(GetShapeSize(hlo->shape()) * utilization); } -Status GpuHloCostAnalysis::FusionCalculateUtilizations( +absl::Status GpuHloCostAnalysis::FusionCalculateUtilizations( const HloInstruction* fusion) { const HloInstruction* root = fusion->fused_expression_root(); // Traverse through the computation from the root till parameters propagating @@ -163,15 +172,23 @@ Status GpuHloCostAnalysis::FusionCalculateUtilizations( // to be more realistic. int64_t operand_elements = ShapeUtil::ElementsInRecursive(operand->shape()); - cur_operand_utilization = - ceil(cur_operand_utilization * operand_elements) / operand_elements; + + if (operand_elements == 0) { + // Element count should not be 0 in any production use case, but there + // are valid HLO inputs that occur in tests. + cur_operand_utilization = 0; + } else { + cur_operand_utilization = + ceil(cur_operand_utilization * operand_elements) / + operand_elements; + } root_utilizations_[operand] += cur_operand_utilization; root_ir_sizes[operand] += cur_instr_times_emitted; } } } - return OkStatus(); + return absl::OkStatus(); } float GpuHloCostAnalysis::CommonElementwiseUtilization( @@ -212,14 +229,16 @@ bool GpuHloCostAnalysis::ProducerConsumerMergedTooLarge( return merged_ir_size > kMaxIRSize; } -Status GpuHloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) { +absl::Status GpuHloCostAnalysis::HandleCustomCall( + const HloInstruction* custom_call) { if (IsCublasGemm(*custom_call)) { // The naming conventions and meanings of gemm parameters are documented // here: // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm - TF_ASSIGN_OR_RETURN(auto gemm_config, - custom_call->backend_config()); - + TF_ASSIGN_OR_RETURN(auto gpu_config, + custom_call->backend_config()); + const gpu::GemmBackendConfig& gemm_config = + gpu_config.gemm_backend_config(); // Technically, in addition to the dot product (A * B), cuBLAS gemm also // performs additional scaling (by factor 'alpha') and addition with a // scaled third matrix (beta * C), which will introduce additional @@ -246,7 +265,7 @@ Status GpuHloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) { current_properties_[kFlopsKey] = GetDotFlops(custom_call->operand(0)->shape(), output_shape, gemm_config.dot_dimension_numbers()); - return OkStatus(); + return absl::OkStatus(); } if (IsCustomCallToDnnConvolution(*custom_call)) { @@ -279,7 +298,7 @@ Status GpuHloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) { current_properties_[kBytesAccessedKey] += output_size; current_properties_.set_output_bytes_accessed(output_size); } - return OkStatus(); + return absl::OkStatus(); } return HloCostAnalysis::HandleCustomCall(custom_call); @@ -305,50 +324,13 @@ int64_t GpuHloCostAnalysis::GetConvolutionFlops( result_shape); } -using ProfilesNestedMap = absl::flat_hash_map< - std::string, // compute capability. - absl::flat_hash_map>>; - -const ProfilesNestedMap* LoadOpProfiles() { - ProfilesNestedMap* ret = new ProfilesNestedMap(); - DeviceHloInstructionProfiles all_device_profiles; - CHECK(tsl::protobuf::TextFormat::ParseFromString( - std::string(kDeviceHloOpProfiles), &all_device_profiles)); - for (const auto& device_profile : all_device_profiles.entries()) { - for (const auto& entry : device_profile.second.entries()) { - (*ret)[device_profile.first][entry.instruction().shape().element_type()] - [StringToHloOpcode(entry.instruction().opcode()).value()] = - entry.clock_cycles(); - } - } - return ret; -} - int64_t FlopsPerElement(const se::DeviceDescription* device_info, const PrimitiveType type, const HloOpcode opcode) { - std::string compute_capability = ""; - if (device_info != nullptr) { - if (auto* ptr = std::get_if( - &device_info->gpu_compute_capability())) - compute_capability = absl::StrCat("sm_", ptr->major, ptr->minor); - if (auto* ptr = std::get_if( - &device_info->gpu_compute_capability())) - compute_capability = ptr->gfx_version(); - } - - static const auto* all_profiles = LoadOpProfiles(); - static const auto& default_profile = all_profiles->at("sm_86"); - auto device_profiles = - FindOrDefault(*all_profiles, compute_capability, default_profile); - auto dtype_profiles = MaybeFind(device_profiles, type); - + auto device_profile = HloOpProfiles::Singleton().GetProfile(device_info); // Elementwise instructions typically take at least a few clock cycles. constexpr int64_t kDefaultFlopsPerElement = 3; - if (!dtype_profiles.ok()) { - return kDefaultFlopsPerElement; - } - return FindOrDefault(dtype_profiles->get(), opcode, kDefaultFlopsPerElement); + return FindOrDefault(device_profile, std::make_pair(opcode, type), + kDefaultFlopsPerElement); } int64_t GetFlopsForElementwiseOp(const se::DeviceDescription* gpu_device_info, @@ -363,14 +345,14 @@ int64_t GetFlopsForElementwiseOp(const se::DeviceDescription* gpu_device_info, return GetFlopsForElementwiseOp(gpu_device_info, instr->opcode(), instr->shape()); } - -Status GpuHloCostAnalysis::HandleAllReduce(const HloInstruction* allreduce) { - const HloModuleConfig& config = allreduce->GetModule()->config(); +template +absl::Status GpuHloCostAnalysis::HandleCommOp(const HloInstruction* hlo){ + const HloModuleConfig& config = hlo->GetModule()->config(); TF_ASSIGN_OR_RETURN( CollectiveOpGroupMode group_mode, GetCollectiveOpGroupMode( - allreduce->channel_id().has_value(), - Cast(allreduce)->use_global_device_ids())); + hlo->channel_id().has_value(), + Cast(hlo)->use_global_device_ids())); // Get number of ranks for this instruction based on replica groups and mode. int64_t num_devices = config.num_partitions(); @@ -378,28 +360,30 @@ Status GpuHloCostAnalysis::HandleAllReduce(const HloInstruction* allreduce) { TF_ASSIGN_OR_RETURN( std::vector participant_counts, GetPariticipantCountsForReplicaGroups( - num_replicas, num_devices, allreduce->replica_groups(), group_mode)); + num_replicas, num_devices, hlo->replica_groups(), group_mode)); int64_t num_ranks = 1; for (auto count : participant_counts) { + VLOG(5) << "Computing cost for " << num_ranks << " ranks participant_counts:"<ToString(); + << hlo->ToString(); int64_t output_bytes_accessed = 0; // Since for allreduces, the input shape is the same as output shape and can // be done in-place, we calculate output_bytes_accessed based on just the // output size. ShapeUtil::ForEachSubshape( - allreduce->shape(), [&](const Shape& subshape, const ShapeIndex&) { + hlo->shape(), [&](const Shape& subshape, const ShapeIndex&) { if (subshape.IsArray()) { output_bytes_accessed += GetShapeSize(subshape); } }); int64_t bytes_accessed = output_bytes_accessed; - for (const HloInstruction* operand : allreduce->operands()) { + for (const HloInstruction* operand : hlo->operands()) { bytes_accessed += GetShapeSize(operand->shape()); } current_properties_.set_output_bytes_accessed(output_bytes_accessed); @@ -407,9 +391,12 @@ Status GpuHloCostAnalysis::HandleAllReduce(const HloInstruction* allreduce) { current_properties_[kCollNumDevicesKey] = num_ranks; // Since allreduce has compute, we need to get flops for the compute // part which is an elementwise op. - current_properties_[kFlopsKey] = GetFlopsForElementwiseOp( - device_info_, allreduce->to_apply()->root_instruction()->opcode(), - allreduce->shape()); + + if(std::is_same()){ + current_properties_[kFlopsKey] = GetFlopsForElementwiseOp( + device_info_, hlo->to_apply()->root_instruction()->opcode(), + hlo->shape()); + } // TODO TJ support multi-node case, we need to know how many nodes there are. int num_intra_steps = 2 * (num_ranks - 1); @@ -421,19 +408,144 @@ Status GpuHloCostAnalysis::HandleAllReduce(const HloInstruction* allreduce) { float scaling_ratio = (1.0 * num_ranks) / num_intra_steps; current_properties_[kCollAlgoScaleRatioKey] = scaling_ratio; - return OkStatus(); + return absl::OkStatus(); +} +absl::Status GpuHloCostAnalysis::HandleAllReduce( + const HloInstruction* allreduce) { + return HandleCommOp(allreduce); + +} + +absl::Status GpuHloCostAnalysis::HandleAllGather(const HloInstruction* hlo){ + auto st = HandleCommOp(hlo); + if(!st.ok()){ + return st; + } + //Overwrite bytes + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachSubshape( + hlo->shape(), [&](const Shape& subshape, const ShapeIndex&) { + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + // set allgather bytes_accessed equal input operator size + int64_t bytes_accessed = 0; + for (const HloInstruction* operand : hlo->operands()) { + bytes_accessed += GetShapeSize(operand->shape()); + } + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + current_properties_[kBytesAccessedKey] = bytes_accessed; + // Compute algorithmic scaling ratio + // link: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md + // algbw = Buswidth*n/(n-1) + int64_t num_ranks = current_properties_[kCollNumDevicesKey]; + float scaling_ratio = 1.0; + if(num_ranks>1){ + scaling_ratio = (1.0 * num_ranks) / (num_ranks-1); + } + current_properties_[kCollAlgoScaleRatioKey] = scaling_ratio; + return absl::OkStatus(); +} +absl::Status GpuHloCostAnalysis::HandleReduceScatter(const HloInstruction* hlo){ + auto st = HandleCommOp(hlo); + if(!st.ok()){ + return st; + } + int64_t num_ranks = current_properties_[kCollNumDevicesKey]; + // Compute algorithmic scaling ratio + // link: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md + // algbw = Buswidth*n/(n-1) + float scaling_ratio = 1.0; + if(num_ranks>1){ + scaling_ratio = (1.0 * num_ranks) / (num_ranks-1); + } + current_properties_[kCollAlgoScaleRatioKey] = scaling_ratio; + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleConcatenate(const HloInstruction* hlo) { + // Concat turns into a compare plus branch instruction. + int64_t flop_per_element = 6; + // If a warp crosses the operands boundary, both branches are executed. This + // depends on the tiling of the final fusion and is therefore hard to predict + // at this level. Executing both branches drives up the flops, but not the + // bandwidth. So it might seem like a good idea to fuse a concat into a + // memory-bound consumer. However, the divergent warps increase the cost of + // compute-heavy producers that might be fused later. We see this issue in + // some important LLM models that fuse a concat into a column reduction (see + // PriorityFusionTest.DontFuseConcat test). To prevent this particular fusion, + // we add large number of flops to the concat. Both the condition and the flop + // count are tuned to this particular case. + // TODO(b/315776282): Model this more accurately once we can reason about + // tiling patterns. + int64_t dim = Cast(hlo)->concatenate_dimension(); + if (dim > 0 && hlo->operand(0)->shape().dimensions()[dim] & 31) { + flop_per_element = 400; + } + current_properties_[kFlopsKey] = + flop_per_element * ShapeUtil::ElementsInRecursive(hlo->shape()); + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleReduce(const HloInstruction* hlo) { + // HloCostAnalysis::HandleReduce computes FLOPs for the computation correctly, + // but `bytes_accessed` estimates are different for GPU. + TF_RETURN_IF_ERROR(HloCostAnalysis::HandleReduce(hlo)); + + const HloReduceInstruction* reduce = DynCast(hlo); + auto output_shape = reduce->shape().IsArray() + ? reduce->shape() + : reduce->shape().tuple_shapes(0); + + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + reduce->shape(), [&](const Shape& sub_shape, const ShapeIndex& index) { + output_bytes_accessed += GetShapeSize(sub_shape); + }); + + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + + int64_t bytes_accessed = output_bytes_accessed; + for (int64_t input_operand_id = 0; input_operand_id < reduce->input_count(); + ++input_operand_id) { + bytes_accessed += + current_properties_.operand_bytes_accessed(input_operand_id); + } + + int64_t output_shape_size = ShapeUtil::ElementsIn(output_shape); + for (int64_t init_operand_id = reduce->input_count(); + init_operand_id < reduce->operand_count(); ++init_operand_id) { + auto init_operand = reduce->operand(init_operand_id); + + int64_t operand_bytes_accessed = + output_shape_size * GetShapeSize(init_operand->shape()); + current_properties_.set_operand_bytes_accessed(init_operand_id, + operand_bytes_accessed); + current_properties_.set_operand_utilization(init_operand_id, + output_shape_size); + + bytes_accessed += operand_bytes_accessed; + } + + current_properties_[kBytesAccessedKey] = bytes_accessed; + + return absl::OkStatus(); } -Status GpuHloCostAnalysis::HandleElementwiseOp(const HloInstruction* hlo) { +absl::Status GpuHloCostAnalysis::HandleElementwiseOp( + const HloInstruction* hlo) { current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(device_info_, hlo); - return OkStatus(); + return absl::OkStatus(); } -Status GpuHloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { +absl::Status GpuHloCostAnalysis::HandleElementwiseUnary( + const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status GpuHloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) { +absl::Status GpuHloCostAnalysis::HandleElementwiseBinary( + const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis.h b/xla/service/gpu/model/gpu_hlo_cost_analysis.h index 352891494b6a4..260fa08a489b5 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis.h +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_HLO_COST_ANALYSIS_H_ #define XLA_SERVICE_GPU_MODEL_GPU_HLO_COST_ANALYSIS_H_ +#include +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_cost_analysis.h" @@ -39,20 +44,27 @@ class GpuHloCostAnalysis : public HloCostAnalysis { const se::DeviceDescription* device_info = nullptr) : HloCostAnalysis(options), device_info_(device_info) {} - Status Preprocess(const HloInstruction* hlo) override; + absl::Status Preprocess(const HloInstruction* hlo) override; float ScalingRatio(const HloInstruction& hlo) const; int64_t NumOfDevices(const HloInstruction& hlo) const; - Status HandleCustomCall(const HloInstruction* call) override; + absl::Status HandleCustomCall(const HloInstruction* call) override; int64_t GetConvolutionFlops(const HloInstruction* convolution) override; - Status HandleElementwiseOp(const HloInstruction* hlo); - Status HandleElementwiseUnary(const HloInstruction* hlo) override; - Status HandleElementwiseBinary(const HloInstruction* hlo) override; + absl::Status HandleElementwiseOp(const HloInstruction* hlo); + absl::Status HandleElementwiseUnary(const HloInstruction* hlo) override; + absl::Status HandleElementwiseBinary(const HloInstruction* hlo) override; + + template + absl::Status HandleCommOp(const HloInstruction* hlo); - Status HandleAllReduce(const HloInstruction* allreduce) override; + absl::Status HandleConcatenate(const HloInstruction* hlo) override; + absl::Status HandleAllReduce(const HloInstruction* allreduce) override; + absl::Status HandleAllGather(const HloInstruction* hlo) override; + absl::Status HandleReduceScatter(const HloInstruction* hlo) override; + absl::Status HandleReduce(const HloInstruction* hlo) override; // Estimate the total size of IR accounting for both duplication // of producer code by consumer and the total number of basic blocks. @@ -77,7 +89,8 @@ class GpuHloCostAnalysis : public HloCostAnalysis { protected: std::unique_ptr CreateNestedCostAnalysis() override; int64_t FusionParameterReadBytes(const HloInstruction* hlo) const override; - Status FusionCalculateUtilizations(const HloInstruction* fusion) override; + absl::Status FusionCalculateUtilizations( + const HloInstruction* fusion) override; size_t immediate_constant_max_elements() const override { return 8; } diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index e0f4e2a4cd177..29bedc37edcb1 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,19 @@ limitations under the License. #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -109,6 +121,31 @@ ENTRY entry { sizeof(float) * n_output_elements); } + + +TEST_F(GpuHloCostAnalysisTest, AllGatherDevice) { + absl::string_view hlo_string = R"( +HloModule m +ENTRY e { + %p0 = f32[400,12800]{1,0} parameter(0) + %ag-start = (f32[400,12800], f32[1600,12800]) all-gather-start( + f32[400,12800] %p0), replica_groups={{0,1,2,3}}, dimensions={0}, + metadata={op_type="AllGather" op_name="ag0"} + %ag-done = f32[1600,12800] all-gather-done( + (f32[400,12800], f32[1600,12800]) %ag-start), + metadata={op_type="AllGather" op_name="ag0.done"} + ROOT tuple = (f32[1600,12800]{0,1}) tuple(%ag-done ) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string,/*replica_count*/ 4)); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + auto computation = module->entry_computation(); + const HloInstruction* all_reduce_start = computation->GetInstructionWithName("ag0"); + + EXPECT_EQ(analysis_.NumOfDevices(*all_reduce_start), 4); +} TEST_F(GpuHloCostAnalysisTest, BroadcastWithRepeats) { absl::string_view hlo_string = R"( HloModule m @@ -529,5 +566,82 @@ TEST_F(GpuHloCostAnalysisTest, CommonElementwiseUseParameterAndRoot) { 0.f); } +TEST_F(GpuHloCostAnalysisTest, Reduce) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.0 = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(param_0.3, constant), dimensions={1}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + + int64_t input_bytes_accessed = 4 * 32 * 40; + int64_t init_bytes_accessed = 4 * 32; + int64_t output_bytes_accessed = 4 * 32; + + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 0), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 1), init_bytes_accessed); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce), output_bytes_accessed); + EXPECT_EQ(analysis_.bytes_accessed(*reduce), + input_bytes_accessed + init_bytes_accessed + output_bytes_accessed); + EXPECT_EQ(analysis_.flop_count(*reduce), 32 * 39 * 3); +} + +TEST_F(GpuHloCostAnalysisTest, VariadicReduce) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + + int64_t input_bytes_accessed = 4 * 32 * 40; + int64_t init_bytes_accessed = 4 * 32; + int64_t output_bytes_accessed = 2 * 4 * 32; + + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 0), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 1), input_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 2), init_bytes_accessed); + EXPECT_EQ(analysis_.operand_bytes_accessed(*reduce, 3), init_bytes_accessed); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce), output_bytes_accessed); + EXPECT_EQ(analysis_.bytes_accessed(*reduce), 2 * input_bytes_accessed + + 2 * init_bytes_accessed + + output_bytes_accessed); + EXPECT_EQ(analysis_.flop_count(*reduce), 32 * 39 * 6); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc new file mode 100644 index 0000000000000..6e8b57d62dcbd --- /dev/null +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -0,0 +1,235 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/coalescing_analysis.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/status.h" + +namespace xla { +namespace gpu { + +int64_t GpuPerformanceModelWithIndexingAnalysis::FlopsPerElement( + const HloInstruction* instr) const { + // TODO(shyshkov): Replace dependency on GpuHloCostAnalysis with independent + // flops calculation. + GpuHloCostAnalysis::Options cost_analysis_options{ + shape_size_, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis cost_analysis(cost_analysis_options, device_info_); + TF_CHECK_OK( + cost_analysis.RevisitInstruction(const_cast(instr))); + + int64_t num_elements = [&] { + if (instr->opcode() == HloOpcode::kReduce && instr->shape().IsTuple()) { + return ShapeUtil::ElementsInRecursive(instr->shape().tuple_shapes(0)); + } + return ShapeUtil::ElementsInRecursive(instr->shape()); + }(); + + return cost_analysis.flop_count(*instr) / num_elements; +} + +int64_t GpuPerformanceModelWithIndexingAnalysis::GetShapeSizeRecursive( + const Shape& shape) const { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { + return shape_size_(shape); + } + + int64_t total_size = 0; + for (const auto& element_shape : shape.tuple_shapes()) { + total_size += GetShapeSizeRecursive(element_shape); + } + return total_size; +} + +int64_t GetIterationSpaceSize(const IndexingMap& indexing_map, + const HloInstruction* instr) { + if (indexing_map.IsUndefined()) { + return ShapeUtil::ElementsInRecursive(instr->shape()); + } + + if (indexing_map.IsKnownEmpty()) { + return 0; + } + + auto get_ranges_iteration_space_size = + [](const std::vector& ranges) { + int64_t num_iters = 1; + for (const Interval& range : ranges) { + num_iters *= range.upper - range.lower + 1; + } + return num_iters; + }; + + return get_ranges_iteration_space_size(indexing_map.GetSymbolBounds()) * + get_ranges_iteration_space_size(indexing_map.GetDimensionBounds()); +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( + const HloFusionAnalysis& fusion_analysis, bool is_coalesced) { + auto& fusion_adaptor = fusion_analysis.fusion(); + VLOG(5) << "EstimateRunTimeForFusion: " << fusion_adaptor.ToString(); + + auto roots = fusion_adaptor.GetRoots(); + CHECK_EQ(roots.size(), 1) + << "Indexing cost model doesn't support multi-output fusions."; + auto root_shape = roots.front().shape(); + + LaunchDimensions launch_dimensions = + EstimateFusionLaunchDimensions(ShapeUtil::ElementsInRecursive(root_shape), + fusion_analysis, *device_info_); + + int64_t num_threads = launch_dimensions.launch_bound(); + int64_t num_blocks = launch_dimensions.num_blocks(); + + // Compute indexing from root to each instruction in the fusion and fusion + // operands. For each instruction, tells which elements of the instructions + // result will be used to compute one result element of the fusion. + auto grouped_fusion_indexing = ComputeGroupedOutputToInputIndexing( + fusion_adaptor, roots[0], mlir_context_); + + int64_t flops = 0; + int64_t bytes_read = 0; + absl::Duration read_time = absl::ZeroDuration(); + + for (const auto& [instr, indexing_maps] : grouped_fusion_indexing) { + VLOG(10) << "instr: " << instr->name(); + HloInstructionAdaptor instr_adaptor(*instr); + + // Instructions inside the fusion are computation and account for FLOPs + // count. Instructions outside the fusion are operands of the fusion and + // account for memory read time. + bool is_operand = !fusion_adaptor.ContainsInstruction(instr_adaptor); + + auto element_type = instr->shape().element_type(); + int64_t n_bytes_total = 0; + for (const auto& indexing_map : indexing_maps) { + VLOG(10) << indexing_map.ToString(); + + int64_t num_iters = GetIterationSpaceSize(indexing_map, instr); + + if (is_operand) { + int64_t type_size = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + n_bytes_total += type_size * num_iters; + } else { + int64_t flops_per_element = FlopsPerElement(instr); + flops += flops_per_element * num_iters; + } + } + + if (is_operand) { + int64_t operand_size = shape_size_(instr->shape()); + int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; + + VLogOperandRead(instr, n_bytes_total, n_bytes_net, is_coalesced); + + read_time += + ReadTimeWithDRAMHeuristic(*device_info_, num_blocks, n_bytes_net, + n_bytes_total, element_type, is_coalesced); + } + } + + int64_t bytes_written = GetShapeSizeRecursive(root_shape); + + absl::Duration compute_time = ComputeTime(*device_info_, flops, num_threads); + absl::Duration write_time = WriteTime(*device_info_, bytes_written); + absl::Duration memory_access_time = read_time + write_time; + absl::Duration exec_time = CombineComputeAndMemoryAccessTime( + compute_time, memory_access_time, + GpuPerformanceModelOptions::PriorityFusion()); + + VLogResult(flops, bytes_read, bytes_written, num_threads, compute_time, + read_time, write_time, exec_time); + + return EstimateRunTimeData{flops, bytes_written, num_threads, read_time, + write_time, compute_time, exec_time}; +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction( + const HloInstruction* producer) { + // Stand-alone bitcast is always no-op during runtime. + if (producer->opcode() == HloOpcode::kBitcast) { + return {0, 0, 0, absl::ZeroDuration(), absl::ZeroDuration()}; + } + + auto fusion_analysis = AnalyzeFusion(*producer, *device_info_); + + bool is_coalesced = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer); + return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); +} + +EstimateRunTimeData +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer( + const HloInstruction* producer, const HloInstruction* consumer) { + auto fusion_analysis = + AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_); + + bool is_coalesced = IsReadCoalescedHeuristic( + fusion_analysis.GetEmitterFusionKind(), producer, consumer); + return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); +} + +/*static*/ +GpuPerformanceModelWithIndexingAnalysis::RunTimes +GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimes( + const HloInstruction* producer, + absl::Span fused_consumers) { + auto producer_runtime = EstimateRunTimeForInstruction(producer); + + absl::Duration time_unfused = + kKernelLaunchOverhead * (fused_consumers.size() + 1) + + producer_runtime.exec_time; + + absl::Duration time_fused = kKernelLaunchOverhead * fused_consumers.size(); + + for (const auto& consumer : fused_consumers) { + time_unfused += EstimateRunTimeForInstruction(consumer).exec_time; + time_fused += + EstimateRunTimeForProducerConsumer(producer, consumer).exec_time; + } + + return {time_unfused, time_fused}; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.h b/xla/service/gpu/model/gpu_indexing_performance_model.h new file mode 100644 index 0000000000000..14d7e520a820d --- /dev/null +++ b/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ + +#include + +#include "absl/types/span.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/hlo_op_profiles.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +// Implementation of Cost Model that uses indexing analysis to estimate amount +// of compute and memory access time. +class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { + public: + explicit GpuPerformanceModelWithIndexingAnalysis( + const se::DeviceDescription* device_info, + HloCostAnalysis::ShapeSizeFunction shape_size, + mlir::MLIRContext* mlir_context) + : hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)), + device_info_(device_info), + shape_size_(shape_size), + mlir_context_(mlir_context) {} + + EstimateRunTimeData EstimateRunTimeForFusion( + const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true); + + EstimateRunTimeData EstimateRunTimeForInstruction( + const HloInstruction* producer); + + EstimateRunTimeData EstimateRunTimeForProducerConsumer( + const HloInstruction* producer, const HloInstruction* consumer); + + RunTimes EstimateRunTimes( + const HloInstruction* producer, + absl::Span fused_consumers = {}); + + private: + // Returns an estimate how many FLOPs will be used to produce one element of + // the output. + int64_t FlopsPerElement(const HloInstruction* instr) const; + + int64_t GetShapeSizeRecursive(const Shape& shape) const; + + const HloOpProfiles::HloOpProfile* hlo_op_profile_; + const se::DeviceDescription* device_info_; + HloCostAnalysis::ShapeSizeFunction shape_size_; + mlir::MLIRContext* mlir_context_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_INDEXING_PERFORMANCE_MODEL_H_ diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc new file mode 100644 index 0000000000000..5e52685762e52 --- /dev/null +++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -0,0 +1,172 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuIndexingPerformanceModelTest : public HloTestBase { + GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + public: + mlir::MLIRContext mlir_context_; + // The reference times in the test cases below are measured + // on A6000 by profiling the execution of the HLOs. + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ + &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + + GpuIndexingPerformanceModelTest() : HloTestBase() {} +}; + +TEST_F(GpuIndexingPerformanceModelTest, BroadcastElementwise) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule extracted + +ENTRY entry_computation { + param_0 = f32[32]{0} parameter(0) + broadcast = f32[32,1,768]{2,1,0} broadcast(param_0), dimensions={0} + param_1 = f32[32,1,768]{2,1,0} parameter(1) + ROOT multiply = f32[32,1,768]{2,1,0} multiply(broadcast, param_1) +} +)")); + + auto producer = + module->entry_computation()->GetInstructionWithName("broadcast"); + auto consumer = + module->entry_computation()->GetInstructionWithName("multiply"); + + auto runtime_data = indexing_cost_model_.EstimateRunTimeForProducerConsumer( + producer, consumer); + EXPECT_EQ(runtime_data.flops, 73728); + EXPECT_EQ(runtime_data.bytes_written, 98304); + EXPECT_NEAR(absl::ToInt64Nanoseconds(runtime_data.write_time), 128, 2); + EXPECT_NEAR(absl::ToInt64Nanoseconds(runtime_data.exec_time), 267, 2); +} + +TEST_F(GpuIndexingPerformanceModelTest, Bitcast) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +ENTRY entry_computation { + param_0 = bf16[4,8,65,128]{3,2,1,0} parameter(0) + ROOT bitcast = bf16[8,4,65,128]{3,2,0,1} bitcast(param_0) +} +)")); + + auto instruction = + module->entry_computation()->GetInstructionWithName("bitcast"); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 0); + EXPECT_EQ(runtime_data.bytes_written, 0); + EXPECT_EQ(runtime_data.write_time, absl::ZeroDuration()); + EXPECT_EQ(runtime_data.exec_time, absl::ZeroDuration()); +} + +TEST_F(GpuIndexingPerformanceModelTest, Reduce) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.0 = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(param_0.3, constant), dimensions={1}, to_apply=add +} +)")); + + auto instruction = module->entry_computation()->root_instruction(); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 3744); + EXPECT_EQ(runtime_data.bytes_written, 128); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.write_time), 0, 1); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 29, 1); +} + +TEST_F(GpuIndexingPerformanceModelTest, VariadicReduce) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add +} +)")); + + auto instruction = module->entry_computation()->root_instruction(); + + auto runtime_data = + indexing_cost_model_.EstimateRunTimeForInstruction(instruction); + EXPECT_EQ(runtime_data.flops, 7488); + EXPECT_EQ(runtime_data.bytes_written, 256); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.write_time), 0, 1); + EXPECT_NEAR(absl::ToDoubleNanoseconds(runtime_data.exec_time), 58, 1); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_performance_model.cc b/xla/service/gpu/model/gpu_performance_model.cc index 7bd454134cebb..5144094c377b4 100644 --- a/xla/service/gpu/model/gpu_performance_model.cc +++ b/xla/service/gpu/model/gpu_performance_model.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,384 +18,110 @@ limitations under the License. #include #include #include -#include #include -#include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/coalescing_analysis.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" +#include "tsl/platform/status.h" -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/nvml/include/nvml.h" -#endif // GOOGLE_CUDA namespace xla { namespace gpu { - namespace { -// Estimated values in the absence of easy ways to query them. -static constexpr absl::Duration kKernelLaunchOverhead = absl::Microseconds(5); -static constexpr float kL2CacheSpeedup = 2.5; -static constexpr float kL1CacheSpeedup = 8; -// A very conservative estimate. L1 size varies because it can be dynamically -// configured as shared memory; there is no easy way to query its actual size; -// also we do not count what occupies cache, but rather claim that what is -// much smaller than the cache size will likely stay in it. -// For reference, it can be up to 256 kB per SM on RTX A6000. -static constexpr float kL1CacheSizePerSM = 2 * 1024; - -// Returns whether a fusion uses the parameter at the given index elementwise -// from its root. -bool FusionUsesParameterElementwiseFromRoot( - const HloInstruction* fusion, int parameter_index, - const GpuHloCostAnalysis* cost_analysis) { - return cost_analysis->CommonElementwiseUtilization( - fusion->fused_parameter(parameter_index), - fusion->fused_expression_root()) == 1.f; -} - -int GetCoalescingWasteFactor(PrimitiveType element_type) { - int64_t element_size_bytes = - element_type == PrimitiveType::TUPLE || - element_type == PrimitiveType::TOKEN - ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ - : ShapeUtil::ByteSizeOfPrimitiveType(element_type); - // Cache line is 128B that is split into 4 sectors of 32B. Default transaction - // size from DRAM -> L2 = 64 Bytes = 2 sectors, since V100, but it can be also - // configured. - // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf - // (page 10). - constexpr int kDRAMToL2TransactionSizeBytes = 64; - // Assume we use one element from the cache line and waste the remaining - // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache - // line. - return kDRAMToL2TransactionSizeBytes / element_size_bytes; -} - -// Estimate read time of n_bytes_total bytes from global memory on a -// given GPU. Account for L1 / L2 cache speedup if the input's nominal size -// n_bytes_net is small. -absl::Duration ReadTime(const se::DeviceDescription& gpu_device_info, - int64_t num_blocks, int64_t n_bytes_net, - int64_t n_bytes_total, PrimitiveType element_type, - bool coalesced, bool first_read_from_dram) { - int waste_factor = coalesced ? 1 : GetCoalescingWasteFactor(element_type); - - // Limit the bandwidth for low occupancy cases. Each SM can issue at most - // one 32B memory transaction per clock. H100 needs at least 56.8 active SMs - // (1830 MHz) to saturate the memory bandwidth (3.35 TB/s). - float per_block_bandwidth = gpu_device_info.clock_rate_ghz() * 1.0e9f * 32; - float max_bandwidth = num_blocks * per_block_bandwidth; - - if (first_read_from_dram) { - // The first read of the input buffer always happens from DRAM. If reads are - // no coaleced, bandwidth is reduced by the waste factor. - float dram_bandwidth = gpu_device_info.memory_bandwidth() / waste_factor; - - // Two things can happed on re-reading the buffer: - // - If the buffer fits into cache, the L1/L2 cache speedup is applied. - // - If the buffer doesn't fit, it will be read from DRAM and the same - // coalessing waste factor is applied. - float rest_bandwidth = gpu_device_info.memory_bandwidth(); - if (n_bytes_net < gpu_device_info.l2_cache_size()) { - rest_bandwidth *= kL2CacheSpeedup; - if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { - rest_bandwidth *= kL1CacheSpeedup; - } - } else { - rest_bandwidth /= waste_factor; - } - - dram_bandwidth = std::min(dram_bandwidth, max_bandwidth); - rest_bandwidth = std::min(rest_bandwidth, max_bandwidth); - - // n_bytes_net > n_bytes_total can happend when we compute read time of - // shared operand. This is a flaw in the interface that should be fixed. - int64_t n_bytes_read_dram = std::min(n_bytes_net, n_bytes_total); - - // Number of bytes that we be re-read, potentially from cache. - int64_t n_bytes_read_cache = n_bytes_total - n_bytes_read_dram; - - return absl::Seconds(n_bytes_read_dram / dram_bandwidth) + - absl::Seconds(n_bytes_read_cache / rest_bandwidth); - } else { - float bandwidth = gpu_device_info.memory_bandwidth(); - if (n_bytes_net < gpu_device_info.l2_cache_size()) { - bandwidth *= kL2CacheSpeedup; - if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { - bandwidth *= kL1CacheSpeedup; - } - } else if (!coalesced) { - bandwidth /= waste_factor; - } - - bandwidth = std::min(bandwidth, max_bandwidth); - return absl::Seconds(n_bytes_total / bandwidth); - } -} - -int64_t GetNcclMaxNumChannels( - GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { - int64_t max_nchannels = 0; - switch (algorithm) { - // Tree and Ring algos share the same max channel number. - case GpuPerformanceWithCollectiveModel::RING: - case GpuPerformanceWithCollectiveModel::TREE: - max_nchannels = GpuPerformanceWithCollectiveModel::kMaxNumChannelsRing; - break; - } - const char* env = std::getenv("NCCL_MAX_NCHANNELS"); - if (env != nullptr) { - int64_t max_nchannels_from_env; - if (absl::SimpleAtoi(env, &max_nchannels_from_env)) { - max_nchannels = std::min(max_nchannels_from_env, max_nchannels); - } - } - return max_nchannels; -} - -int64_t GetMinNumberOfChannels( - GpuPerformanceWithCollectiveModel::CollectiveAlgo algorithm) { - int64_t min_nchannels = 0; - switch (algorithm) { - // Tree and Ring algos share the same min channel number. - case GpuPerformanceWithCollectiveModel::RING: - case GpuPerformanceWithCollectiveModel::TREE: - min_nchannels = 1; - break; - } - const char* env = std::getenv("NCCL_MIN_NCHANNELS"); - if (env != nullptr) { - int64_t min_nchannels_from_env; - if (absl::SimpleAtoi(env, &min_nchannels_from_env)) { - min_nchannels = std::min(min_nchannels_from_env, min_nchannels); - } - } - return min_nchannels; -} - -int GetNumThreads(int warp_size, int min_num_threads, int max_num_threads, - int default_num_threads) { - int threads_from_env = default_num_threads; - const char* env = std::getenv("NCCL_NTHREADS"); - if (env != nullptr) { - CHECK(absl::SimpleAtoi(env, &threads_from_env)); - } - int num_threads = threads_from_env; - if (num_threads > 0) { - if (num_threads % warp_size != 0) { - num_threads = max_num_threads; - } else if (num_threads > max_num_threads) { - num_threads = max_num_threads; - } else if (num_threads < min_num_threads) { - num_threads = min_num_threads; - } - } else { - num_threads = default_num_threads; - } - return num_threads; -} - -float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, - const double* bandwidths_table) { - switch (cc.major) { - case se::CudaComputeCapability::VOLTA: - return bandwidths_table[0]; - case se::CudaComputeCapability::AMPERE: - return bandwidths_table[1]; - case se::CudaComputeCapability::HOPPER: - return bandwidths_table[2]; +std::vector GetUniqueFusionOperands( + const HloInstruction* producer, const HloInstruction* consumer) { + std::vector fusion_operands; + for (const HloInstruction* operand : producer->operands()) { + fusion_operands.push_back(operand); } - return -1; -} - -// Uses HloFusionAnalysis for computing the actual number of threads and blocks -// that the IR emitter will use. -LaunchDimensions EstimateFusionLaunchDimensions( - int64_t estimated_num_threads, - const std::optional& fusion_analysis, - const se::DeviceDescription& device_info) { - if (fusion_analysis) { - // TODO(jreiffers): This is the wrong place for this DUS analysis. - const HloInstruction* dus = nullptr; - for (const auto* root : fusion_analysis->fusion_roots()) { - if (root->opcode() == HloOpcode::kDynamicUpdateSlice) { - dus = root; - } else if (root->opcode() == HloOpcode::kBitcast && - root->operand(0)->opcode() == HloOpcode::kDynamicUpdateSlice) { - dus = root->operand(0); - } else { - dus = nullptr; - break; - } - } - - if (dus) { - if (auto dims = - CalculateLaunchDimensions(dus->operand(1)->shape(), device_info); - dims.ok()) { - return dims.value(); - } + for (const HloInstruction* operand : consumer->operands()) { + if (operand != producer) { + fusion_operands.push_back(operand); } - - auto launch_dimensions = fusion_analysis->GetLaunchDimensions(); - if (launch_dimensions.ok()) return *launch_dimensions; } - int64_t block_size = 128; // Result for default LaunchDimensionsConfig. - int64_t num_blocks = CeilOfRatio(estimated_num_threads, block_size); - return LaunchDimensions(num_blocks, block_size); -} - -// Returns true if all input reads are coalesced. If consumer is not nullptr, -// producer and consumer are considered as one fusion, otherwise it's only the -// producer. -// -// This is a crude heuristic until we get proper tile analysis. -bool IsReadCoalesced(const std::optional& fusion_analysis, - const GpuPerformanceModelOptions& config, - const HloInstruction* producer, - const HloInstruction* consumer = nullptr) { - if (!config.consider_coalescing) return true; - - auto analyzed_kind_or_reduction = - fusion_analysis ? fusion_analysis->GetEmitterFusionKind() - : HloFusionAnalysis::EmitterFusionKind::kReduction; - - // Transposing minor dimension breaks coalescing. - if (analyzed_kind_or_reduction != - HloFusionAnalysis::EmitterFusionKind::kTranspose) { - if (TransposesMinorDimension(producer)) return false; - if (consumer && TransposesMinorDimension(consumer)) return false; - } - - // Fusing two row reductions breaks coalescing. - if (analyzed_kind_or_reduction == - HloFusionAnalysis::EmitterFusionKind::kReduction && - IsInputFusibleReduction(*producer) && consumer && - IsInputFusibleReduction(*consumer)) { - return false; - } - - return true; + std::sort(fusion_operands.begin(), fusion_operands.end()); + fusion_operands.erase( + std::unique(fusion_operands.begin(), fusion_operands.end()), + fusion_operands.end()); + return fusion_operands; } } // namespace -std::optional GpuPerformanceModelCache::Get( - const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - - auto it = instruction_runtime_data_.find(HloInstructionAdaptor(instruction)); - if (it != instruction_runtime_data_.end()) { - return it->second; - } - return std::nullopt; -} - -std::optional GpuPerformanceModelCache::Get( - const HloInstruction& producer, const HloInstruction& consumer) { - absl::MutexLock lock(&mutex_); - - auto it = fusion_runtime_data_.find(HloInstructionAdaptor(producer)); - if (it != fusion_runtime_data_.end()) { - auto jt = it->second.find(HloInstructionAdaptor(consumer)); - if (jt != it->second.end()) { - return jt->second; - } - } - return std::nullopt; -} - -void GpuPerformanceModelCache::Set(const HloInstruction& instruction, - const EstimateRunTimeData& runtime_data) { - absl::MutexLock lock(&mutex_); - - instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; -} - -void GpuPerformanceModelCache::Set(const HloInstruction& producer, - const HloInstruction& consumer, - absl::Duration runtime) { - absl::MutexLock lock(&mutex_); - fusion_runtime_data_[HloInstructionAdaptor(producer)] - [HloInstructionAdaptor(consumer)] = runtime; -} - -void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - HloInstructionAdaptor adaptor(instruction); - - // Remove runtime data for the instruction. - instruction_runtime_data_.erase(adaptor); - - // Remove cache for all producer-consumer pairs where the instruction is - // producer. - fusion_runtime_data_.erase(adaptor); - - // Iterate through operands to find all producer-consumer pairs where - // instruction is consumer and remove them from cache. - for (auto* operand : instruction.operands()) { - auto it = fusion_runtime_data_.find(HloInstructionAdaptor(*operand)); - if (it != fusion_runtime_data_.end()) { - it->second.erase(adaptor); - } - } -} - /*static*/ EstimateRunTimeData GpuPerformanceModel::EstimateRunTimeForInstruction( const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config) { + VLOG(8) << "EstimateRunTimeForInstruction: " << instr->name(); const se::DeviceDescription* device_info = cost_analysis->device_info_; int64_t flops = cost_analysis->flop_count(*instr); int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr); - int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; // Use the analysis cache if present. // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*instr, *cost_analysis->device_info_); + std::optional local_analysis; + if (!config.fusion_analysis_cache) { + local_analysis = AnalyzeFusion(*instr, *cost_analysis->device_info_); + } const auto& fusion_analysis = config.fusion_analysis_cache ? config.fusion_analysis_cache->Get(*instr) - : local_analysis; + : local_analysis.value(); LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(instr->shape()), fusion_analysis, *device_info); int64_t num_threads = launch_dimensions.launch_bound(); + int64_t num_blocks = launch_dimensions.num_blocks(); absl::Duration compute_time = ComputeTime(*device_info, flops, num_threads); - absl::Duration read_time = ProducerInputAccessTime( - cost_analysis, *device_info, launch_dimensions.num_blocks(), - /*producer=*/instr, fusion_analysis, config); - absl::Duration write_time = - absl::Seconds(1.0f * bytes_written / device_info->memory_bandwidth()); - absl::Duration exec_time = std::max(compute_time, read_time + write_time); - if (VLOG_IS_ON(8)) { - LOG(INFO) << "FLOPs: " << flops; - LOG(INFO) << "Bytes read: " << bytes_read; - LOG(INFO) << "Bytes written: " << bytes_written; - LOG(INFO) << "Num threads:" << num_threads; - LOG(INFO) << "Compute time: " << compute_time; - LOG(INFO) << "Input read time: " << read_time; - LOG(INFO) << "Output write time: " << write_time; + CoalescingAnalysis coalescing_analysis(instr, instr->operands(), + fusion_analysis); + + absl::Duration read_time; + int64_t bytes_read = 0; + for (const auto [operand_id, operand] : llvm::enumerate(instr->operands())) { + int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); + int64_t n_bytes_total = + GetOperandBytesAccessed(cost_analysis, instr, operand); + int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; + + bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + + VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); + + read_time += ReadTimeWithDRAMHeuristic( + *device_info, num_blocks, n_bytes_net, n_bytes_total, + operand->shape().element_type(), coalesced); } - return {flops, bytes_written, num_threads, write_time, exec_time}; + absl::Duration write_time = WriteTime(*device_info, bytes_written); + absl::Duration exec_time = CombineComputeAndMemoryAccessTime( + compute_time, read_time + write_time, config); + + VLogResult(flops, bytes_read, bytes_written, num_threads, compute_time, + read_time, write_time, exec_time); + + return {flops, bytes_written, num_threads, read_time, + write_time, compute_time, exec_time}; } /*static*/ EstimateRunTimeData @@ -418,167 +144,18 @@ GpuPerformanceModel::EstimateRunTimeForInstructionCached( return runtime_data; } -// Returns utilization of operand by instruction. Returns 0, if the operand is -// not used by the instruction. -float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* instr, - const HloInstruction* operand) { - if (!instr->IsUserOf(operand)) { - return 0.f; - } - - return cost_analysis->operand_utilization(*instr, - instr->operand_index(operand)); -} - -// Returns utilization `overlap` between a common operand of producer and -// consumer on merge. `utilization > 0` means that the operand will be accessed -// more efficiently after fusion. -// -// Currently covers two cases: -// 1) Producer has to use the common operand elementwise from its root if it is -// a fusion or just be an elementwise instruction. -// 2) Consumer has to have common elementwise roots for the producer and the -// common operand if it is a fusion or just be an elementwise instruction. -float GetCommonUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* producer, - int64_t producer_idx_of_operand, - const HloInstruction* consumer) { - const auto* operand = producer->operand(producer_idx_of_operand); - - if (!consumer || !consumer->IsUserOf(operand)) { - return 0.f; - } - - if (producer->IsElementwise() || - (producer->opcode() == HloOpcode::kFusion && - FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, - cost_analysis))) { - if (consumer->opcode() == HloOpcode::kFusion) { - int64_t consumer_idx_of_common_operand = consumer->operand_index(operand); - int64_t consumer_idx_of_producer = consumer->operand_index(producer); - return cost_analysis->CommonElementwiseUtilization( - consumer->fused_parameter(consumer_idx_of_common_operand), - consumer->fused_parameter(consumer_idx_of_producer)); - } else { - if (consumer->IsElementwise()) { - return 1.f; - } - } - } - return 0.f; -} - -// Returns utilization of operand after producer and consumer are fused -// together. `GetCommonUtilization` works only for a limited set of elementwise -// cases. -// TODO(shyshkov): Combine logic from GpuHloCostAnalysis with boundary function -// to properly calculate utilization. -float GetSharedUtilization(const GpuHloCostAnalysis* cost_analysis, - const HloInstruction* producer, - const HloInstruction* consumer, - const HloInstruction* operand) { - float producer_utilization_by_consumer = - GetOperandUtilization(cost_analysis, consumer, producer); - - float operand_utilization_by_producer = - GetOperandUtilization(cost_analysis, producer, operand); - - float operand_utilization_by_consumer = - GetOperandUtilization(cost_analysis, consumer, operand); - - float common_utilization = - producer->IsUserOf(operand) - ? GetCommonUtilization(cost_analysis, producer, - producer->operand_index(operand), consumer) - : 0.f; - - return producer_utilization_by_consumer * operand_utilization_by_producer + - operand_utilization_by_consumer - common_utilization; -} - -// Tells input access time of the producer alone if fused_consumer -// is not specified. Otherwise estimates the access time to producer's -// inputs as if it is fused into the consumer. -/*static*/ absl::Duration GpuPerformanceModel::ProducerInputAccessTime( - const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info, int64_t num_blocks, - const HloInstruction* producer, - const std::optional& fusion_analysis, - const GpuPerformanceModelOptions& config, - const HloInstruction* fused_consumer) { - absl::Duration ret = absl::ZeroDuration(); - float producer_output_utilization = - fused_consumer - ? GetOperandUtilization(cost_analysis, fused_consumer, producer) - : 1.f; - - // TODO(jreiffers): We should be checking each operand. - bool coalesced = - IsReadCoalesced(fusion_analysis, config, producer, fused_consumer); - for (int i = 0; i < producer->operand_count(); ++i) { - // Information about data read taking into account utilization. - // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. - int64_t operand_bytes_accessed = - cost_analysis->operand_bytes_accessed(*producer, i); - float operand_utilization = - cost_analysis->operand_utilization(*producer, i); - - // An estimate how much data would need to fit into L1/L2 cache to speed up - // the operand access. - // If `operand_utilization` < 1, only a part of the full operand size should - // be read. Otherwise, `operand_bytes_accessed / operand_utilization` is the - // size of the operand without reuse. - int64_t n_bytes_net = std::llround(operand_bytes_accessed / - std::max(operand_utilization, 1.0f)); - - // Look if common operand of producer and consumer will be accessed more - // efficiently on merge. - float common_utilization = GetCommonUtilization( - cost_analysis, producer, /*producer_idx_of_operand=*/i, fused_consumer); - - const auto& operand_shape = producer->operand(i)->shape(); - - CHECK_LE(common_utilization, producer_output_utilization); - float n_bytes_total = operand_bytes_accessed * - (producer_output_utilization - common_utilization); - ret += ReadTime(gpu_device_info, num_blocks, /*n_bytes_net=*/n_bytes_net, - n_bytes_total, operand_shape.element_type(), coalesced, - config.first_read_from_dram); - } - return ret; -} - -absl::Duration GpuPerformanceModel::ComputeTime( - const se::DeviceDescription& gpu_device_info, int64_t flops, - int64_t num_threads) { - int64_t fpu_count = - gpu_device_info.core_count() * gpu_device_info.fpus_per_core(); - int64_t n_threads_active = std::min(num_threads, fpu_count); - int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2; - int64_t flop_per_ns_effective = flop_per_ns_per_fpu * n_threads_active; - return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective); -} - +/*static*/ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, - const std::vector& consumer_runtimes) { + absl::Span fused_consumers) { const se::DeviceDescription* device_info = cost_analysis->device_info_; absl::Duration time_unfused = kKernelLaunchOverhead * (fused_consumers.size() + 1) + producer_runtime.exec_time; - if (config.calculate_full_priority) { - for (const auto& consumer_runtime : consumer_runtimes) { - time_unfused += consumer_runtime.exec_time; - } - return time_unfused; - } - for (const HloInstruction* fused_consumer : fused_consumers) { VLOG(8) << "Unfused consumer: " << fused_consumer->name(); float utilization_by_this_consumer = @@ -586,14 +163,14 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( // Use the analysis cache if present. // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*fused_consumer, *device_info); + std::optional local_analysis; + if (!config.fusion_analysis_cache) { + local_analysis = AnalyzeFusion(*fused_consumer, *device_info); + } const auto& analysis_unfused = config.fusion_analysis_cache ? config.fusion_analysis_cache->Get(*fused_consumer) - : local_analysis; + : local_analysis.value(); LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(fused_consumer->shape()), @@ -604,12 +181,9 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( int64_t n_bytes_net = std::min(producer_runtime.bytes_written, n_bytes_total); - bool coalesced = - IsReadCoalesced(analysis_unfused, config, /*producer=*/fused_consumer); - auto read_time_unfused = ReadTime( - *device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net, - n_bytes_total, fused_consumer->shape().element_type(), coalesced, - config.first_read_from_dram); + auto read_time_unfused = + ReadTime(*device_info, launch_dimensions_unfused.num_blocks(), + n_bytes_net, n_bytes_total); VLOG(10) << " Read time unfused: " << read_time_unfused; time_unfused += read_time_unfused; @@ -622,56 +196,99 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( const HloInstruction* producer, const HloInstruction* consumer, const EstimateRunTimeData& producer_runtime, const EstimateRunTimeData& consumer_runtime, - const LaunchDimensions& launch_dimensions, - float utilization_by_this_consumer, const GpuHloCostAnalysis* cost_analysis, - const std::optional& fusion_analysis, + const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config) { + VLOG(8) << "EstimateRunTimeForFusion, producer: " << producer->name() + << " consumer: " << consumer->name(); const se::DeviceDescription* device_info = cost_analysis->device_info_; - int64_t fused_flops = producer_runtime.flops * utilization_by_this_consumer + - consumer_runtime.flops; - - absl::Duration compute_time = - ComputeTime(*device_info, fused_flops, launch_dimensions.launch_bound()); + float utilization_by_this_consumer = cost_analysis->operand_utilization( + *consumer, consumer->operand_index(producer)); - absl::flat_hash_set fusion_operands; - for (auto* operand : producer->operands()) { - fusion_operands.insert(operand); - } - for (auto* operand : consumer->operands()) { - if (operand != producer) { - fusion_operands.insert(operand); - } + std::optional local_analysis_fused; + if (!config.fusion_analysis_cache) { + local_analysis_fused = + AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info); } + const auto& fusion_analysis = + config.fusion_analysis_cache + ? config.fusion_analysis_cache->Get(*producer, *consumer) + : local_analysis_fused.value(); + + LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions( + producer_runtime.num_threads * utilization_by_this_consumer, + fusion_analysis, *device_info); + + int64_t flops = producer_runtime.flops * utilization_by_this_consumer + + consumer_runtime.flops; + + int64_t num_threads = launch_dimensions.launch_bound(); + absl::Duration compute_time = ComputeTime(*device_info, flops, num_threads); + + std::vector fusion_operands = + GetUniqueFusionOperands(producer, consumer); + CoalescingAnalysis coalescing_analysis(producer, consumer, fusion_operands, + fusion_analysis); absl::Duration read_time; + int64_t bytes_read = 0; for (const auto* operand : fusion_operands) { - float operand_utilization = - GetSharedUtilization(cost_analysis, producer, consumer, operand); - int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); - int64_t n_bytes_total = std::llround(operand_size * operand_utilization); + int64_t n_bytes_total = GetSharedOperandBytesAccessed( + cost_analysis, producer, consumer, operand); int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + bytes_read += n_bytes_total; + + bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + + VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); + + read_time += ReadTimeWithDRAMHeuristic( + *device_info, launch_dimensions.num_blocks(), n_bytes_net, + n_bytes_total, operand->shape().element_type(), coalesced); + } + + auto exec_time = CombineComputeAndMemoryAccessTime( + compute_time, read_time + consumer_runtime.write_time, config); - bool coalesced = - IsReadCoalesced(fusion_analysis, config, producer, consumer); + VLogResult(flops, bytes_read, consumer_runtime.bytes_written, num_threads, + compute_time, read_time, consumer_runtime.write_time, exec_time); - read_time += - ReadTime(*device_info, launch_dimensions.num_blocks(), n_bytes_net, - n_bytes_total, operand->shape().element_type(), coalesced, - config.first_read_from_dram); + return exec_time; +} + +/*static*/ +absl::Duration GpuPerformanceModel::EstimateRunTimeForFusionCached( + const HloInstruction* producer, const HloInstruction* consumer, + const EstimateRunTimeData& producer_runtime, + const EstimateRunTimeData& consumer_runtime, + const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config) { + if (config.gpu_performance_model_cache) { + if (auto fusion_runtime = + config.gpu_performance_model_cache->Get(*producer, *consumer)) { + return *fusion_runtime; + } } - return std::max(compute_time, read_time + consumer_runtime.write_time); + auto fusion_runtime = + EstimateRunTimeForFusion(producer, consumer, producer_runtime, + consumer_runtime, cost_analysis, config); + + if (config.gpu_performance_model_cache) { + config.gpu_performance_model_cache->Set(*producer, *consumer, + fusion_runtime); + } + return fusion_runtime; } +/*static*/ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, - const std::vector& consumer_runtimes, + absl::Span fused_consumers, bool multi_output) { const se::DeviceDescription* device_info = cost_analysis->device_info_; @@ -680,52 +297,23 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( for (auto [idx, fused_consumer] : llvm::enumerate(fused_consumers)) { VLOG(8) << "Fused consumer: " << fused_consumer->name(); - if (config.calculate_full_priority && config.gpu_performance_model_cache) { - if (auto fusion_runtime = config.gpu_performance_model_cache->Get( - *producer, *fused_consumer)) { - exec_time_fused += *fusion_runtime; - continue; - } - } - float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - std::optional local_analysis_fused = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeProducerConsumerFusion(*producer, *fused_consumer, - *device_info); + std::optional local_analysis_fused; + if (!config.fusion_analysis_cache) { + local_analysis_fused = AnalyzeProducerConsumerFusion( + *producer, *fused_consumer, *device_info); + } const auto& analysis_fused = config.fusion_analysis_cache ? config.fusion_analysis_cache->Get(*producer, *fused_consumer) - : local_analysis_fused; + : local_analysis_fused.value(); LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions( producer_runtime.num_threads * utilization_by_this_consumer, analysis_fused, *device_info); - // The original model ignores consumer computation and output writes. The - // main goal of the model is to compare estimates of fused and unfused - // cases. Since epilog of the consumers remains unchanged in both bases, we - // only consider duplication of the producer computation and repeated access - // to producer inputs. - // - // With `calculate_full_priority`, consumer computation and full read time - // is accounted in the priority. - if (config.calculate_full_priority) { - auto fusion_runtime = EstimateRunTimeForFusion( - producer, fused_consumer, producer_runtime, consumer_runtimes[idx], - launch_dimensions_fused, utilization_by_this_consumer, cost_analysis, - analysis_fused, config); - exec_time_fused += fusion_runtime; - if (config.gpu_performance_model_cache) { - config.gpu_performance_model_cache->Set(*producer, *fused_consumer, - fusion_runtime); - } - continue; - } - absl::Duration compute_time_by_this_consumer = ComputeTime( *device_info, producer_runtime.flops * utilization_by_this_consumer, launch_dimensions_fused.launch_bound()); @@ -741,8 +329,9 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( VLOG(10) << " Input access time by consumer: " << input_access_time_by_this_consumer; - exec_time_fused += std::max(compute_time_by_this_consumer, - input_access_time_by_this_consumer); + exec_time_fused += CombineComputeAndMemoryAccessTime( + compute_time_by_this_consumer, input_access_time_by_this_consumer, + config); } // Multi-output fusion still writes the initial output of the producer. @@ -754,10 +343,56 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( return exec_time_fused; } +/*static*/ +GpuPerformanceModel::RunTimes +GpuPerformanceModel::EstimateRunTimesForPriorityFusion( + const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + absl::Span fused_consumers, + bool multi_output) { + EstimateRunTimeData producer_runtime = + EstimateRunTimeForInstructionCached(producer, cost_analysis, config); + + absl::Duration time_unfused = + kKernelLaunchOverhead * (fused_consumers.size() + 1) + + producer_runtime.exec_time; + + absl::Duration time_fused = kKernelLaunchOverhead * fused_consumers.size(); + + for (auto fused_consumer : fused_consumers) { + VLOG(8) << "Fused consumer: " << fused_consumer->name(); + + EstimateRunTimeData consumer_runtime = EstimateRunTimeForInstructionCached( + fused_consumer, cost_analysis, config); + + time_unfused += consumer_runtime.exec_time; + + time_fused += EstimateRunTimeForFusionCached( + producer, fused_consumer, producer_runtime, consumer_runtime, + cost_analysis, config); + } + + // Multi-output fusion still writes the initial output of the producer. + // For now assume that the producer's output does not need to be recomputed. + if (multi_output) { + time_fused += producer_runtime.write_time; + } + + if (VLOG_IS_ON(8)) { + LOG(INFO) << "Consumer count: " << fused_consumers.size(); + LOG(INFO) << "Unfused time: " << time_unfused; + LOG(INFO) << "Fused time: " << time_fused; + } + + return {time_unfused, time_fused}; +} + +/*static*/ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers, bool multi_output) { + absl::Span fused_consumers, + bool multi_output) { VLOG(8) << "Producer: " << producer->name(); if (producer->opcode() == HloOpcode::kFusion) { VLOG(10) << producer->fused_instructions_computation()->ToString(); @@ -766,36 +401,15 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( EstimateRunTimeData producer_runtime = EstimateRunTimeForInstructionCached(producer, cost_analysis, config); - std::vector consumer_runtimes; - if (config.calculate_full_priority) { - consumer_runtimes.reserve(fused_consumers.size()); - for (auto* consumer : fused_consumers) { - consumer_runtimes.push_back( - EstimateRunTimeForInstructionCached(consumer, cost_analysis, config)); - } - } - - absl::Duration time_unfused = - EstimateUnfusedExecTime(producer, producer_runtime, cost_analysis, config, - fused_consumers, consumer_runtimes); + absl::Duration time_unfused = EstimateUnfusedExecTime( + producer, producer_runtime, cost_analysis, config, fused_consumers); absl::Duration time_fused = EstimateFusedExecTime(producer, producer_runtime, cost_analysis, config, - fused_consumers, consumer_runtimes, multi_output); - - int64_t fused_consumer_count = fused_consumers.size(); - float total_producer_utilization = 0; - - for (const HloInstruction* fused_consumer : fused_consumers) { - float utilization_by_this_consumer = cost_analysis->operand_utilization( - *fused_consumer, fused_consumer->operand_index(producer)); - total_producer_utilization += utilization_by_this_consumer; - } + fused_consumers, multi_output); if (VLOG_IS_ON(8)) { - LOG(INFO) << "Consumer count: " << fused_consumer_count; - LOG(INFO) << "Utilization of producer output: " - << total_producer_utilization; + LOG(INFO) << "Consumer count: " << fused_consumers.size(); LOG(INFO) << "Unfused time: " << time_unfused; LOG(INFO) << "Fused time: " << time_fused; } @@ -803,6 +417,7 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( return {time_unfused, time_fused}; } +/*static*/ void GpuPerformanceModel::RecordEstimatedRunTime( HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config) { @@ -814,196 +429,21 @@ void GpuPerformanceModel::RecordEstimatedRunTime( double cycles = absl::ToDoubleNanoseconds(data.exec_time) * cost_analysis->device_info_->clock_rate_ghz(); - auto backend_config = instruction->backend_config(); - TF_CHECK_OK(backend_config.status()) << instruction->ToString(); - backend_config->mutable_reification_cost()->set_end_to_end_cycles(cycles); - TF_CHECK_OK(instruction->set_backend_config(*backend_config)); + auto gpu_config = instruction->backend_config(); + TF_CHECK_OK(gpu_config.status()) << instruction->ToString(); + auto reification_cost = + gpu_config->mutable_fusion_backend_config()->mutable_reification_cost(); + reification_cost->set_end_to_end_cycles(cycles); + reification_cost->set_compute_time_us( + absl::ToDoubleMicroseconds(data.compute_time)); + reification_cost->set_memory_access_time_us( + absl::ToDoubleMicroseconds(data.read_time + data.write_time)); + reification_cost->set_exec_time_us( + absl::ToDoubleMicroseconds(data.exec_time)); + TF_CHECK_OK(instruction->set_backend_config(*gpu_config)); VLOG(8) << "RecordEstimatedRunTime: " << instruction->ToString(); } -// Returns NVLink bw in GB/s -/*static*/ -float GpuPerformanceWithCollectiveModel::GetNvlinkBw( - se::CudaComputeCapability compute_capability) { - return compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER) - ? kSm90NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE) - ? kSm80NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::VOLTA) - ? kSm70NvlinkBandwidth - : compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_) - ? kSm60NvlinkBandwidth - : kSm80NvlinkBandwidth; -} - -/*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { -#if GOOGLE_CUDA - void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); - CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; - - struct SymbolEntry { - void** functor; - char const* name; - }; - - std::vector symbols = { - {(void**)&xla_nvmlInit, "nvmlInit_v2"}, - {(void**)&xla_nvmlShutdown, "nvmlShutdown"}, - {(void**)&xla_nvmlDeviceGetHandleByIndex, "nvmlDeviceGetHandleByIndex"}, - {(void**)&xla_nvmlDeviceGetNvLinkCapability, - "nvmlDeviceGetNvLinkCapability"}, - }; - for (SymbolEntry se : symbols) { - *se.functor = dlsym(libhandle, se.name); - } - nvmlReturn_t init_result = xla_nvmlInit(); - return init_result == NVML_SUCCESS; -#else - return false; -#endif // GOOGLE_CUDA -} - -/*static*/ bool GpuPerformanceWithCollectiveModel::ShutdownNvml() { -#if GOOGLE_CUDA - nvmlReturn_t shutdown_result = xla_nvmlShutdown(); - return shutdown_result == NVML_SUCCESS; -#else - return false; -#endif // GOOGLE_CUDA -} - -/*static*/ uint32_t -GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { -#if GOOGLE_CUDA - // We will use nvml library to detect nvlink capability - // to see if it supports p2p communication. - // We first load libnvidia-ml.so and assign symbols to function pointers - // to avoid linking errors. - // Then gpu 0 will be used to query for nvlink capability, note that - // we only look at link 0 of gpu 0 since all other links are assumed - // to have the same capability. - CHECK(InitNvml()) << "NVML init failed."; - nvmlDevice_t nvml_device; - nvmlReturn_t get_device_result = - xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); - CHECK(get_device_result == NVML_SUCCESS); - - uint32_t supported_p2p = 0; - - nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( - nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, - &supported_p2p); - CHECK(nvlink_cap_result == NVML_SUCCESS); - CHECK(ShutdownNvml()) << "NVML shutdown failed."; - return supported_p2p; -#else - return 0; -#endif // GOOGLE_CUDA -} - -/*static*/ absl::Duration -GpuPerformanceWithCollectiveModel::ComputeAllreduceTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info) { - // We use nccl group call to launch multiple allreduces so launch overhead - // only occurs once. - absl::Duration total_time = kKernelLaunchOverhead; - stream_executor::CudaComputeCapability compute_cap = - gpu_device_info.cuda_compute_capability(); - - int64_t size_of_speed_array = kIntraNodeSpeeds.size(); - int64_t size_of_sm90_speed_array = kIntraNodeSpeedsSm90.size(); - - int num_speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER - ? size_of_sm90_speed_array - : size_of_speed_array; - const double* speeds = compute_cap.major >= se::CudaComputeCapability::HOPPER - ? kIntraNodeSpeedsSm90.data() - : kIntraNodeSpeeds.data(); - - int speed_index = 0; - float max_sys_bw = - GetMaxSysBwFromGpu(compute_cap, kLowLatencyMaxBandwidths.data()); - - CHECK_GT(max_sys_bw, 0); - - while ((speed_index < num_speeds - 1) && speeds[speed_index] > max_sys_bw) { - speed_index++; - } - float bw_intra_node = speeds[speed_index]; - int64_t num_devices = cost_analysis->NumOfDevices(instr); - - int64_t min_nchannels = - std::max(num_devices, GetMinNumberOfChannels(CollectiveAlgo::RING)); - int64_t num_channels = - std::max(min_nchannels, GetNcclMaxNumChannels(CollectiveAlgo::RING)); - int default_threads = - (bw_intra_node * num_channels <= kPciBandwidth) ? 256 : kLL128NumThreads; - - int warp_size = gpu_device_info.threads_per_warp(); - int num_threads = GetNumThreads(warp_size, kLL128NumThreads / 4, - kLL128NumThreads, default_threads); - - // Since channels are pipelined together, compute time will only occur as in a - // single channel. - absl::Duration compute_time_per_channel = - ComputeTime(gpu_device_info, - cost_analysis->flop_count(instr) / num_channels, num_threads); - total_time += compute_time_per_channel; - - uint32_t supported_p2p = CheckIfNvlinkSupportsP2P(); - - if (supported_p2p == 0) { - VLOG(8) << "Nvlink doesn't support p2p communication. Model will " - "continue using default system bandwidth."; - } else { - VLOG(8) << "Nvlink supports p2p communication, setting intra node " - "bandwidth to nvlink bw."; - bw_intra_node = GetNvlinkBw(compute_cap); - } - - double bus_bandwidth = bw_intra_node * num_channels; - - // Get per channel LL128 ring bandwidth - double per_channel_ring_ll128_Bw = - GetMaxSysBwFromGpu(compute_cap, kPerChannelMaxRingLL128Bandwidths.data()); - - bus_bandwidth = std::min(bus_bandwidth * kRingAlgorithmDiscountFactor, - num_channels * per_channel_ring_ll128_Bw); - double actual_bandwidth = bus_bandwidth * cost_analysis->ScalingRatio(instr); - - absl::Duration communication_time = absl::Microseconds( - cost_analysis->bytes_accessed(instr) / (1e6 * actual_bandwidth)); - total_time += communication_time; - return total_time; -} - -/*static*/ absl::Duration -GpuPerformanceWithCollectiveModel::ComputeCollectiveTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info) { - if (cost_analysis->NumOfDevices(instr) == 1) { - VLOG(8) << "Returning only kernel launch overhead for a single partition."; - return kKernelLaunchOverhead; - } - - if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { - VLOG(8) << "Returning 0 cost for async done op " << instr.name(); - return absl::ZeroDuration(); - } - switch (instr.opcode()) { - case HloOpcode::kAllReduce: - case HloOpcode::kAllReduceStart: - return ComputeAllreduceTime(instr, cost_analysis, gpu_device_info); - default: { - LOG(WARNING) - << "Runtime estimate for " << instr.name() - << " not implemented. Returning only the kernel launch time."; - return kKernelLaunchOverhead; - } - } -} - } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/gpu_performance_model.h b/xla/service/gpu/model/gpu_performance_model.h index 9e5cfe49d58c2..d23c74d96563c 100644 --- a/xla/service/gpu/model/gpu_performance_model.h +++ b/xla/service/gpu/model/gpu_performance_model.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,128 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_H_ #define XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_H_ -#include -#include -#include "absl/container/flat_hash_map.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/stream_executor/device_description.h" - -#if GOOGLE_CUDA -#include - -#include "third_party/gpus/cuda/nvml/include/nvml.h" -// Below is a list of function pointers to be used -// for querying device properties through nvml library. -#define NVML_FUNCTOR(name, rettype, args) rettype(*xla_##name) args = nullptr; - -NVML_FUNCTOR(nvmlInit, nvmlReturn_t, ()) -NVML_FUNCTOR(nvmlShutdown, nvmlReturn_t, ()) -NVML_FUNCTOR(nvmlDeviceGetHandleByIndex, nvmlReturn_t, - (unsigned int index, nvmlDevice_t* device)) -NVML_FUNCTOR(nvmlDeviceGetNvLinkCapability, nvmlReturn_t, - (nvmlDevice_t device, unsigned int link, - nvmlNvLinkCapability_t capability, unsigned int* capResult)) - -#endif +#include "xla/service/gpu/model/gpu_performance_model_base.h" namespace xla { namespace gpu { -struct EstimateRunTimeData { - int64_t flops; - int64_t bytes_written; - int64_t num_threads; - absl::Duration write_time; - absl::Duration exec_time; -}; - -class GpuPerformanceModelCache { - public: - // Returns cached runtime data for the instruction or producer-consumer pair. - // Returns nullopt if there is no data in cache. - std::optional Get(const HloInstruction& instruction); - std::optional Get(const HloInstruction& producer, - const HloInstruction& consumer); - - // Sets cache value for the instruction or producer-consumer pair. - void Set(const HloInstruction& instruction, - const EstimateRunTimeData& runtime_data); - void Set(const HloInstruction& producer, const HloInstruction& consumer, - absl::Duration runtime); - - // Removes all cache entries for this instruction. The cache contains entries - // for individual instructions in instruction_runtime_data_ and for - // producer-consumer pairs in fusion_runtime_data_. - void Invalidate(const HloInstruction& instruction); - - private: - absl::Mutex mutex_; - - // Stores unfused runtime data for individual instructions. - absl::flat_hash_map - instruction_runtime_data_; - - // Stores fused runtime data for producer-consumer pairs. - absl::flat_hash_map< - HloInstructionAdaptor, - absl::flat_hash_map> - fusion_runtime_data_; -}; - -struct GpuPerformanceModelOptions { - // Whether to attempt to model the effect of uncoalesced reads. - bool consider_coalescing = false; - - // Use better read modelling, when first read always happends from DRAM and - // re-reads can happen from cache. - bool first_read_from_dram = false; - - // Properly calculate read+write and compute time in both fused and unfused - // case for producer and consumer. - bool calculate_full_priority = false; - - // If present, use this to retrieve fusion analyses. - HloFusionAnalysisCache* fusion_analysis_cache = nullptr; - - GpuPerformanceModelCache* gpu_performance_model_cache = nullptr; - - static GpuPerformanceModelOptions Default() { - return GpuPerformanceModelOptions(); - } - - static GpuPerformanceModelOptions PriorityFusion( - HloFusionAnalysisCache* fusion_analysis_cache, - GpuPerformanceModelCache* gpu_performance_model_cache) { - GpuPerformanceModelOptions config; - config.consider_coalescing = true; - config.first_read_from_dram = true; - config.calculate_full_priority = true; - config.fusion_analysis_cache = fusion_analysis_cache; - config.gpu_performance_model_cache = gpu_performance_model_cache; - return config; - } - - static GpuPerformanceModelOptions ForModule(const HloModule* module) { - return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion(nullptr, - nullptr) // Only cache within priority fusion. - : Default(); - } -}; - -class GpuPerformanceModel { +class GpuPerformanceModel : public GpuPerformanceModelBase { public: - struct RunTimes { - absl::Duration time_unfused; - absl::Duration time_fused; - }; - static EstimateRunTimeData EstimateRunTimeForInstruction( const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); @@ -151,10 +41,14 @@ class GpuPerformanceModel { const HloInstruction* producer, const HloInstruction* consumer, const EstimateRunTimeData& producer_runtime, const EstimateRunTimeData& consumer_runtime, - const LaunchDimensions& launch_dimensions, - float utilization_by_this_consumer, const GpuHloCostAnalysis* cost_analysis, - const std::optional& fusion_analysis, + const GpuPerformanceModelOptions& config); + + static absl::Duration EstimateRunTimeForFusionCached( + const HloInstruction* producer, const HloInstruction* consumer, + const EstimateRunTimeData& producer_runtime, + const EstimateRunTimeData& consumer_runtime, + const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); static absl::Duration EstimateUnfusedExecTime( @@ -162,111 +56,32 @@ class GpuPerformanceModel { const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, - const std::vector& consumer_runtime); + absl::Span fused_consumers); static absl::Duration EstimateFusedExecTime( const HloInstruction* producer, const EstimateRunTimeData& producer_runtime, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - const std::vector& fused_consumers, - const std::vector& consumer_runtimes, + absl::Span fused_consumers, bool multi_output); static RunTimes EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers = {}, + absl::Span fused_consumers = {}, + bool multi_output = false); + + static RunTimes EstimateRunTimesForPriorityFusion( + const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + absl::Span fused_consumers = {}, bool multi_output = false); // Writes estimated execution time to FusionBackendConfig.reification_cost. static void RecordEstimatedRunTime(HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); - static absl::Duration ComputeTime( - const se::DeviceDescription& gpu_device_info, int64_t flops, - int64_t num_threads); - - static absl::Duration ProducerInputAccessTime( - const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info, int64_t num_blocks, - const HloInstruction* producer, - const std::optional& fusion_analysis, - const GpuPerformanceModelOptions& config, - const HloInstruction* fused_consumer = nullptr); -}; - -class GpuPerformanceWithCollectiveModel : public GpuPerformanceModel { - public: - // Different algorithms that can be used to perform the collective. - enum CollectiveAlgo { - RING = 0, - TREE, - }; - - // Table for max system bandwidths GB/s for using NCCL's low latency - // algorithm. This is used for intra-node estimate. - static constexpr std::array kLowLatencyMaxBandwidths = { - 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ - }; - - // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a - // single-node - static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { - 20.0 /* Volta */, - 20.0 /* Ampere */, - 36.7 /* Hopper */, - }; - - // Nvlink unidirectional bandwidth for different compute cap. Note this is per - // lane bandwidth. - static constexpr double kSm60NvlinkBandwidth = 18.0; - static constexpr double kSm70NvlinkBandwidth = 20.0; - static constexpr double kSm80NvlinkBandwidth = 20.0; - static constexpr double kSm90NvlinkBandwidth = 20.0; - - // PCIE bandwidth for PCI Gen3 x16 - static constexpr double kPciBandwidth = 12.0; - - // Discount factor for ring algorithm - static constexpr double kRingAlgorithmDiscountFactor = 0.92; - - // Different tiers for intra-node bandwidth. - static constexpr std::array kIntraNodeSpeeds = { - 40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0}; - // SM90 has different bandwidths. - static constexpr std::array kIntraNodeSpeedsSm90 = { - 60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0}; - - // Maximum number of channels allowed by NCCL - static constexpr int64_t kMaxNumChannelsRing = 16; - - // ll128 is by default enabled for Volta, Ampere and Hopper, ll128 by default - // launches 640 threads. - static constexpr int64_t kLL128NumThreads = 640; - - static absl::Duration ComputeCollectiveTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info); - - // Returns NVLink bw in GB/s - static float GetNvlinkBw(se::CudaComputeCapability compute_capability); - - // Initialize nvml library. - static bool InitNvml(); - - // Shut down nvml library. - static bool ShutdownNvml(); - - // This checks if the nvlink supports direct P2P communication, - // If not, we will use PCIE bandwidth to estimate latency. - static uint32_t CheckIfNvlinkSupportsP2P(); - - private: - static absl::Duration ComputeAllreduceTime( - const HloInstruction& instr, const GpuHloCostAnalysis* cost_analysis, - const se::DeviceDescription& gpu_device_info); }; } // namespace gpu diff --git a/xla/service/gpu/model/gpu_performance_model_base.cc b/xla/service/gpu/model/gpu_performance_model_base.cc new file mode 100644 index 0000000000000..40bb1ff69b1b7 --- /dev/null +++ b/xla/service/gpu/model/gpu_performance_model_base.cc @@ -0,0 +1,404 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_performance_model_base.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +namespace { + +// Returns whether a fusion uses the parameter at the given index elementwise +// from its root. +bool FusionUsesParameterElementwiseFromRoot( + const HloInstruction* fusion, int parameter_index, + const GpuHloCostAnalysis* cost_analysis) { + return cost_analysis->CommonElementwiseUtilization( + fusion->fused_parameter(parameter_index), + fusion->fused_expression_root()) == 1.f; +} + +int GetCoalescingWasteFactor(PrimitiveType element_type) { + int64_t element_size_bytes = + element_type == PrimitiveType::TUPLE || + element_type == PrimitiveType::TOKEN + ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ + : ShapeUtil::ByteSizeOfPrimitiveType(element_type); + // Cache line is 128B that is split into 4 sectors of 32B. Default transaction + // size from DRAM -> L2 = 64 Bytes = 2 sectors, since V100, but it can be also + // configured. + // https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21819-optimizing-applications-for-nvidia-ampere-gpu-architecture.pdf + // (page 10). + constexpr int kDRAMToL2TransactionSizeBytes = 64; + // Assume we use one element from the cache line and waste the remaining + // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache + // line. + return kDRAMToL2TransactionSizeBytes / element_size_bytes; +} + +// Limit the bandwidth for low occupancy cases. Each SM can issue at most +// one 32B memory transaction per clock. H100 needs at least 56.8 active SMs +// (1830 MHz) to saturate the memory bandwidth (3.35 TB/s). +float AdjustBandwidth(const se::DeviceDescription& gpu_device_info, + float bandwidth, int64_t num_blocks) { + float per_block_bandwidth = gpu_device_info.clock_rate_ghz() * 1.0e9f * 32; + float max_bandwidth = num_blocks * per_block_bandwidth; + + return std::min(bandwidth, max_bandwidth); +} + +} // namespace + +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + + auto it = instruction_runtime_data_.find(HloInstructionAdaptor(instruction)); + if (it != instruction_runtime_data_.end()) { + return it->second; + } + return std::nullopt; +} + +std::optional GpuPerformanceModelCache::Get( + const HloInstruction& producer, const HloInstruction& consumer) { + absl::MutexLock lock(&mutex_); + + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(producer)); + if (it != fusion_runtime_data_.end()) { + auto jt = it->second.find(HloInstructionAdaptor(consumer)); + if (jt != it->second.end()) { + return jt->second; + } + } + return std::nullopt; +} + +void GpuPerformanceModelCache::Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data) { + absl::MutexLock lock(&mutex_); + + instruction_runtime_data_[HloInstructionAdaptor(instruction)] = runtime_data; +} + +void GpuPerformanceModelCache::Set(const HloInstruction& producer, + const HloInstruction& consumer, + absl::Duration runtime) { + absl::MutexLock lock(&mutex_); + fusion_runtime_data_[HloInstructionAdaptor(producer)] + [HloInstructionAdaptor(consumer)] = runtime; +} + +void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + HloInstructionAdaptor adaptor(instruction); + + // Remove runtime data for the instruction. + instruction_runtime_data_.erase(adaptor); + + // Remove cache for all producer-consumer pairs where the instruction is + // producer. + fusion_runtime_data_.erase(adaptor); + + // Iterate through operands to find all producer-consumer pairs where + // instruction is consumer and remove them from cache. + for (auto* operand : instruction.operands()) { + auto it = fusion_runtime_data_.find(HloInstructionAdaptor(*operand)); + if (it != fusion_runtime_data_.end()) { + it->second.erase(adaptor); + } + } +} + +/*static*/ +LaunchDimensions GpuPerformanceModelBase::EstimateFusionLaunchDimensions( + int64_t estimated_num_threads, const HloFusionAnalysis& fusion_analysis, + const se::DeviceDescription& device_info) { + auto emitter = + GetFusionEmitter(PreBufferAssignmentFusionInfo{fusion_analysis}); + if (emitter.ok()) { + if (const auto* kernel_emitter = + dynamic_cast(emitter->get())) { + return kernel_emitter->launch_dimensions(); + } + } + int64_t block_size = 128; // Result for default LaunchDimensionsConfig. + int64_t num_blocks = CeilOfRatio(estimated_num_threads, block_size); + return LaunchDimensions(num_blocks, block_size); +} + +/*static*/ +int64_t GpuPerformanceModelBase::GetOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand) { + // When called for a consumer-producer fusion, the operand can be from a + // different instruction. GpuHloCostAnalysis can't fail gravefully in this + // case, so we need an explicit check. + if (!instr->IsUserOf(operand)) { + return 0; + } + + return cost_analysis->operand_bytes_accessed(*instr, + instr->operand_index(operand)); +} + +/*static*/ +float GpuPerformanceModelBase::GetOperandUtilization( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand) { + // When called for a consumer-producer fusion, the operand can be from a + // different instruction. GpuHloCostAnalysis can't fail gravefully in this + // case, so we need an explicit check. + if (!instr->IsUserOf(operand)) { + return 0.f; + } + + return cost_analysis->operand_utilization(*instr, + instr->operand_index(operand)); +} + +/*static*/ +float GpuPerformanceModelBase::GetCommonUtilization( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + int64_t producer_idx_of_operand, const HloInstruction* consumer) { + const auto* operand = producer->operand(producer_idx_of_operand); + + if (!consumer || !consumer->IsUserOf(operand)) { + return 0.f; + } + + if (producer->IsElementwise() || + (producer->opcode() == HloOpcode::kFusion && + FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, + cost_analysis))) { + if (consumer->opcode() == HloOpcode::kFusion) { + int64_t consumer_idx_of_common_operand = consumer->operand_index(operand); + int64_t consumer_idx_of_producer = consumer->operand_index(producer); + return cost_analysis->CommonElementwiseUtilization( + consumer->fused_parameter(consumer_idx_of_common_operand), + consumer->fused_parameter(consumer_idx_of_producer)); + } else { + if (consumer->IsElementwise()) { + return 1.f; + } + } + } + return 0.f; +} + +/*static*/ +int64_t GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + const HloInstruction* consumer, const HloInstruction* operand) { + float producer_utilization_by_consumer = + GetOperandUtilization(cost_analysis, consumer, producer); + + int64_t bytes_accessed_by_producer = + GetOperandBytesAccessed(cost_analysis, producer, operand); + + int64_t bytes_accessed_by_consumer = + GetOperandBytesAccessed(cost_analysis, consumer, operand); + + float common_utilization = + producer->IsUserOf(operand) + ? GetCommonUtilization(cost_analysis, producer, + producer->operand_index(operand), consumer) + : 0.f; + + int64_t operand_size = cost_analysis->GetShapeSize(operand->shape()); + int64_t common_bytes_accessed = + std::llround(operand_size * common_utilization); + + return std::llround(bytes_accessed_by_producer * + producer_utilization_by_consumer) + + bytes_accessed_by_consumer - common_bytes_accessed; +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ReadTime( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total) { + float bandwidth = gpu_device_info.memory_bandwidth(); + if (n_bytes_net < gpu_device_info.l2_cache_size()) { + bandwidth *= kL2CacheSpeedup; + if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { + bandwidth *= kL1CacheSpeedup; + } + } + + bandwidth = AdjustBandwidth(gpu_device_info, bandwidth, num_blocks); + return absl::Seconds(n_bytes_total / bandwidth); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ReadTimeWithDRAMHeuristic( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, + bool coalesced) { + int waste_factor = coalesced ? 1 : GetCoalescingWasteFactor(element_type); + + // The first read of the input buffer always happens from DRAM. If reads are + // no coaleced, bandwidth is reduced by the waste factor. + float dram_bandwidth = gpu_device_info.memory_bandwidth() / waste_factor; + + // Two things can happed on re-reading the buffer: + // - If the buffer fits into cache, the L1/L2 cache speedup is applied. + // - If the buffer doesn't fit, it will be read from DRAM and the same + // coalessing waste factor is applied. + float rest_bandwidth = gpu_device_info.memory_bandwidth(); + if (n_bytes_net < gpu_device_info.l2_cache_size()) { + rest_bandwidth *= kL2CacheSpeedup; + if (n_bytes_net < kL1CacheSizePerSM * gpu_device_info.core_count()) { + rest_bandwidth *= kL1CacheSpeedup; + } + } else { + rest_bandwidth /= waste_factor; + } + + dram_bandwidth = AdjustBandwidth(gpu_device_info, dram_bandwidth, num_blocks); + rest_bandwidth = AdjustBandwidth(gpu_device_info, rest_bandwidth, num_blocks); + + // n_bytes_net > n_bytes_total can happen when we compute read time of + // shared operand. This is a flaw in the interface that should be fixed. + int64_t n_bytes_read_dram = std::min(n_bytes_net, n_bytes_total); + + // Number of bytes that we be re-read, potentially from cache. + int64_t n_bytes_read_cache = n_bytes_total - n_bytes_read_dram; + + return absl::Seconds(n_bytes_read_dram / dram_bandwidth) + + absl::Seconds(n_bytes_read_cache / rest_bandwidth); +} + +/*static*/ absl::Duration GpuPerformanceModelBase::ProducerInputAccessTime( + const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, + const GpuPerformanceModelOptions& config, + const HloInstruction* fused_consumer) { + absl::Duration ret = absl::ZeroDuration(); + float producer_output_utilization = + fused_consumer + ? GetOperandUtilization(cost_analysis, fused_consumer, producer) + : 1.f; + + for (int i = 0; i < producer->operand_count(); ++i) { + // Information about data read taking into account utilization. + // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. + int64_t operand_bytes_accessed = + cost_analysis->operand_bytes_accessed(*producer, i); + float operand_utilization = + cost_analysis->operand_utilization(*producer, i); + + // An estimate how much data would need to fit into L1/L2 cache to speed up + // the operand access. + // If `operand_utilization` < 1, only a part of the full operand size should + // be read. Otherwise, `operand_bytes_accessed / operand_utilization` is the + // size of the operand without reuse. + int64_t n_bytes_net = std::llround(operand_bytes_accessed / + std::max(operand_utilization, 1.0f)); + + // Look if common operand of producer and consumer will be accessed more + // efficiently on merge. + float common_utilization = GetCommonUtilization( + cost_analysis, producer, /*producer_idx_of_operand=*/i, fused_consumer); + + CHECK_LE(common_utilization, producer_output_utilization); + float n_bytes_total = operand_bytes_accessed * + (producer_output_utilization - common_utilization); + ret += ReadTime(gpu_device_info, num_blocks, n_bytes_net, n_bytes_total); + } + return ret; +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::WriteTime( + const se::DeviceDescription& gpu_device_info, int64_t bytes_written) { + return absl::Seconds(1.0f * bytes_written / + gpu_device_info.memory_bandwidth()); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::ComputeTime( + const se::DeviceDescription& gpu_device_info, int64_t flops, + int64_t num_threads) { + int64_t fpu_count = + gpu_device_info.core_count() * gpu_device_info.fpus_per_core(); + int64_t n_threads_active = std::min(num_threads, fpu_count); + int64_t flop_per_ns_per_fpu = gpu_device_info.clock_rate_ghz() * /*fma:*/ 2; + int64_t flop_per_ns_effective = flop_per_ns_per_fpu * n_threads_active; + return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective); +} + +/*static*/ +absl::Duration GpuPerformanceModelBase::CombineComputeAndMemoryAccessTime( + absl::Duration compute_time, absl::Duration memory_access_time, + const GpuPerformanceModelOptions& config) { + return compute_time + memory_access_time - + std::min(compute_time, memory_access_time) * + config.memory_compute_parallelism; +} + +/*static*/ +void GpuPerformanceModelBase::VLogOperandRead(const HloInstruction* operand, + int64_t n_bytes_total, + int64_t n_bytes_net, + bool coalesced) { + VLOG(8) << "operand " << operand->name() + << ", n_bytes_total: " << n_bytes_total + << ", n_bytes_net: " << n_bytes_net << ", coalesced: " << coalesced; +} + +/*static*/ +void GpuPerformanceModelBase::VLogResult( + int64_t flops, int64_t bytes_read, int64_t bytes_written, + int64_t num_threads, absl::Duration compute_time, absl::Duration read_time, + absl::Duration write_time, absl::Duration exec_time) { + if (VLOG_IS_ON(8)) { + LOG(INFO) << "FLOPs: " << flops; + LOG(INFO) << "Bytes read: " << bytes_read; + LOG(INFO) << "Bytes written: " << bytes_written; + LOG(INFO) << "Num threads: " << num_threads; + LOG(INFO) << "Compute time: " << compute_time; + LOG(INFO) << "Input read time: " << read_time; + LOG(INFO) << "Output write time: " << write_time; + LOG(INFO) << "Exec time: " << exec_time; + } +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_performance_model_base.h b/xla/service/gpu/model/gpu_performance_model_base.h new file mode 100644 index 0000000000000..7d08a0c68a0bb --- /dev/null +++ b/xla/service/gpu/model/gpu_performance_model_base.h @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ +#define XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +struct EstimateRunTimeData { + int64_t flops; + int64_t bytes_written; + int64_t num_threads; + absl::Duration read_time; + absl::Duration write_time; + absl::Duration compute_time; + absl::Duration exec_time; +}; + +class GpuPerformanceModelCache { + public: + // Returns cached runtime data for the instruction or producer-consumer pair. + // Returns nullopt if there is no data in cache. + std::optional Get(const HloInstruction& instruction); + std::optional Get(const HloInstruction& producer, + const HloInstruction& consumer); + + // Sets cache value for the instruction or producer-consumer pair. + void Set(const HloInstruction& instruction, + const EstimateRunTimeData& runtime_data); + void Set(const HloInstruction& producer, const HloInstruction& consumer, + absl::Duration runtime); + + // Removes all cache entries for this instruction. The cache contains entries + // for individual instructions in instruction_runtime_data_ and for + // producer-consumer pairs in fusion_runtime_data_. + void Invalidate(const HloInstruction& instruction); + + private: + absl::Mutex mutex_; + + // Stores unfused runtime data for individual instructions. + absl::flat_hash_map + instruction_runtime_data_; + + // Stores fused runtime data for producer-consumer pairs. + absl::flat_hash_map< + HloInstructionAdaptor, + absl::flat_hash_map> + fusion_runtime_data_; +}; + +struct GpuPerformanceModelOptions { + // Factor for how much parallelism between compute and memory accesses should + // be assumed. If 1.0, assume perfect parallelism (the run time is the maximum + // of both times). If 0.0, assume no parallelism (the run time is the sum of + // both times). + double memory_compute_parallelism = 1.0; + + // If present, use this to retrieve fusion analyses. + HloFusionAnalysisCache* fusion_analysis_cache = nullptr; + + GpuPerformanceModelCache* gpu_performance_model_cache = nullptr; + + static GpuPerformanceModelOptions Default() { + return GpuPerformanceModelOptions(); + } + + static GpuPerformanceModelOptions PriorityFusion( + HloFusionAnalysisCache* fusion_analysis_cache = nullptr, + GpuPerformanceModelCache* gpu_performance_model_cache = nullptr) { + GpuPerformanceModelOptions config; + config.fusion_analysis_cache = fusion_analysis_cache; + config.gpu_performance_model_cache = gpu_performance_model_cache; + // This constant was chosen empirically in early 2024, based on runtime + // performance on a set of benchmarks internal to Google. Intuitively, we + // expect it to be close to 1, but not quite 1 (i.e., sometimes, compute + // or memory accesses will be stalled waiting for the other, but usually + // they won't). + config.memory_compute_parallelism = 0.95; + return config; + } + + static GpuPerformanceModelOptions ForModule(const HloModule* module) { + return module->config().debug_options().xla_gpu_enable_priority_fusion() + ? PriorityFusion() // Only cache within priority fusion. + : Default(); + } +}; + +class GpuPerformanceModelBase { + public: + struct RunTimes { + absl::Duration time_unfused; + absl::Duration time_fused; + }; + + // Estimated values in the absence of easy ways to query them. + static constexpr absl::Duration kKernelLaunchOverhead = absl::Microseconds(1); + static constexpr absl::Duration kNcclKernelLaunchOverhead = + absl::Microseconds(5); + static constexpr float kL2CacheSpeedup = 2.5; + static constexpr float kL1CacheSpeedup = 8; + // A very conservative estimate. L1 size varies because it can be dynamically + // configured as shared memory; there is no easy way to query its actual size; + // also we do not count what occupies cache, but rather claim that what is + // much smaller than the cache size will likely stay in it. + // For reference, it can be up to 256 kB per SM on RTX A6000. + static constexpr float kL1CacheSizePerSM = 2 * 1024; + + // Uses HloFusionAnalysis for computing the actual number of threads and + // blocks that the IR emitter will use. + static LaunchDimensions EstimateFusionLaunchDimensions( + int64_t estimated_num_threads, const HloFusionAnalysis& fusion_analysis, + const se::DeviceDescription& device_info); + + // Returns bytes accessed of operand output by instruction. Returns 0, if the + // operand is not used by the instruction. + static int64_t GetOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* instr, + const HloInstruction* operand); + + // Returns utilization of operand by instruction. Returns 0, if the operand is + // not used by the instruction. + static float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* instr, + const HloInstruction* operand); + + // Returns utilization `overlap` between a common operand of producer and + // consumer on merge. `utilization > 0` means that the operand will be + // accessed more efficiently after fusion. + // + // Currently covers two cases: + // 1) Producer has to use the common operand elementwise from its root if it + // is a fusion or just be an elementwise instruction. + // 2) Consumer has to have common elementwise roots for the producer and the + // common operand if it is a fusion or just be an elementwise instruction. + static float GetCommonUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* producer, + int64_t producer_idx_of_operand, + const HloInstruction* consumer); + + // Returns bytes accessed of operand after producer and consumer are fused + // together. `GetCommonUtilization` works only for a limited set of + // elementwise cases. + static int64_t GetSharedOperandBytesAccessed( + const GpuHloCostAnalysis* cost_analysis, const HloInstruction* producer, + const HloInstruction* consumer, const HloInstruction* operand); + + // Estimate read time of n_bytes_total bytes from global memory on a + // given GPU. Account for L1 / L2 cache speedup if the input's nominal size + // n_bytes_net is small. + static absl::Duration ReadTime(const se::DeviceDescription& gpu_device_info, + int64_t num_blocks, int64_t n_bytes_net, + int64_t n_bytes_total); + + // Estimate read time of n_bytes_total bytes from global memory on a + // given GPU. + // + // Assumes that the first n_bytes_net are always read from DRAM, but next + // reads can be cached. Applies waste factor if read from DRAM is uncoalesced. + static absl::Duration ReadTimeWithDRAMHeuristic( + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, + bool coalesced); + + // Tells input access time of the producer alone if fused_consumer + // is not specified. Otherwise estimates the access time to producer's + // inputs as if it is fused into the consumer. + static absl::Duration ProducerInputAccessTime( + const GpuHloCostAnalysis* cost_analysis, + const se::DeviceDescription& gpu_device_info, int64_t num_blocks, + const HloInstruction* producer, const HloFusionAnalysis& fusion_analysis, + const GpuPerformanceModelOptions& config, + const HloInstruction* fused_consumer = nullptr); + + static absl::Duration WriteTime(const se::DeviceDescription& gpu_device_info, + int64_t bytes_written); + + static absl::Duration ComputeTime( + const se::DeviceDescription& gpu_device_info, int64_t flops, + int64_t num_threads); + + static absl::Duration CombineComputeAndMemoryAccessTime( + absl::Duration compute_time, absl::Duration memory_access_time, + const GpuPerformanceModelOptions& config); + + // Logs estimates for the operand read if VLOG is enabled. + static void VLogOperandRead(const HloInstruction* operand, + int64_t n_bytes_total, int64_t n_bytes_net, + bool coalesced); + + // Logs estimate results of the performance model if VLOG is enabled. + static void VLogResult(int64_t flops, int64_t bytes_read, + int64_t bytes_written, int64_t num_threads, + absl::Duration compute_time, absl::Duration read_time, + absl::Duration write_time, absl::Duration exec_time); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_GPU_PERFORMANCE_MODEL_BASE_H_ diff --git a/xla/service/gpu/model/gpu_performance_model_base_test.cc b/xla/service/gpu/model/gpu_performance_model_base_test.cc new file mode 100644 index 0000000000000..d15c0d4339bfc --- /dev/null +++ b/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -0,0 +1,196 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/gpu_performance_model_base.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuPerformanceModelBaseTest : public HloTestBase { + public: + GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const Shape& shape) { + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + // The reference times in the test cases below are measured + // on A6000 by profiling the execution of the HLOs. + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuHloCostAnalysis analysis_{options_, &device_info_}; + + GpuPerformanceModelBaseTest() : HloTestBase() {} +}; + +TEST_F(GpuPerformanceModelBaseTest, SharedOperandBytesAccessed_InPlaceDUS) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[8,16] parameter(0) + param_1 = f32[4,4] parameter(1) + c_0 = s32[] constant(0) + log = f32[4,4] log(param_1) + ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(param_0, log, c_0, c_0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto dus_consumer = computation->root_instruction(); + auto log_producer = dus_consumer->mutable_operand(1); + + auto get_shared_operand_bytes_accessed = [&](const HloInstruction* operand) { + return GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + &analysis_, log_producer, dus_consumer, operand); + }; + + EXPECT_EQ(get_shared_operand_bytes_accessed(dus_consumer->operand(0)), 0); + EXPECT_EQ(get_shared_operand_bytes_accessed(log_producer->operand(0)), 64); +} + +TEST_F(GpuPerformanceModelBaseTest, SharedOperandBytesAccessed_DUS) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[8,16] parameter(0) + param_1 = f32[4,4] parameter(1) + c_0 = s32[] constant(0) + log = f32[8,16] log(param_0) + ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(log, param_1, c_0, c_0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto dus_consumer = computation->root_instruction(); + auto log_producer = dus_consumer->mutable_operand(0); + + auto get_shared_operand_bytes_accessed = [&](const HloInstruction* operand) { + return GpuPerformanceModelBase::GetSharedOperandBytesAccessed( + &analysis_, log_producer, dus_consumer, operand); + }; + + EXPECT_EQ(get_shared_operand_bytes_accessed(dus_consumer->operand(1)), 64); + EXPECT_EQ(get_shared_operand_bytes_accessed(log_producer->operand(0)), 448); +} + +// This test documents current behaviour. See comments below how the correct +// result should look like. +TEST_F(GpuPerformanceModelBaseTest, + ReduceBroadcastedDim_IncorrectBytesAccessed) { + absl::string_view hlo_string = R"( +HloModule m + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +f1 { + p0 = f32[128] parameter(0) + c0 = f32[] constant(0) + broadcast = f32[128,256] broadcast(p0), dimensions={0} + ROOT reduce = f32[128] reduce(broadcast, c0), dimensions={1}, to_apply=add +} + +ENTRY entry_computation { + param_0 = f32[128] parameter(0) + param_1 = f32[4,4] parameter(1) + ROOT fusion = f32[128] fusion(param_0), kind=kLoop, calls=f1 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto root = computation->root_instruction(); + + // Cost Model estimates that input element we be re-read in reduce. Each + // element of reduce output needs only one input element. Bytes accessed + // should be 4*128=512. + EXPECT_EQ(GpuPerformanceModelBase::GetOperandBytesAccessed(&analysis_, root, + root->operand(0)), + /*4*128*256=*/131072); +} + +// This test documents current behaviour. See comments below how the correct +// result should look like. +TEST_F(GpuPerformanceModelBaseTest, ElementwiseBitcast_IncorrectBytesAccessed) { + absl::string_view hlo_string = R"( +HloModule m + +f1 { + p0 = f32[128] parameter(0) + bitcast.1 = f32[8,16] bitcast(p0) + log = f32[128] log(p0) + bitcast.2 = f32[8,16] bitcast(log) + ROOT add = f32[8,16] add(bitcast.1, bitcast.2) +} + +ENTRY entry_computation { + param_0 = f32[128] parameter(0) + ROOT fusion = f32[8,16] fusion(param_0), kind=kLoop, calls=f1 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); + ASSERT_IS_OK(computation->Accept(&analysis_)); + + auto root = computation->root_instruction(); + + // Bitcast breaks the chain of elementwise utilization even if the bitcast + // doesn't change physical layout. Each element of `param_0` should be read + // only once, but Cost Model estimates that it will be accessed twice. Bytes + // accessed should be 4*128=512. + EXPECT_EQ(GpuPerformanceModelBase::GetOperandBytesAccessed(&analysis_, root, + root->operand(0)), + /*2*4*128=*/1024); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/gpu_performance_model_test.cc b/xla/service/gpu/model/gpu_performance_model_test.cc index 0ed4f664ef1f6..82984b1193bef 100644 --- a/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_performance_model_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,13 @@ limitations under the License. #include #include #include +#include +#include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -28,11 +32,16 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -47,13 +56,34 @@ class GpuPerformanceModelTest : public HloTestBase { } public: + GpuPerformanceModel::RunTimes EstimateRunTimesDefault( + const HloInstruction* producer, + std::vector fused_consumers = {}) { + return GpuPerformanceModel::EstimateRunTimes( + producer, &analysis_, GpuPerformanceModelOptions::Default(), + fused_consumers); + } + + GpuPerformanceModel::RunTimes EstimateRunTimesForPriorityFusion( + const HloInstruction* producer, + std::vector fused_consumers = {}) { + return GpuPerformanceModel::EstimateRunTimesForPriorityFusion( + producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(), + fused_consumers); + } + + mlir::MLIRContext mlir_context_; GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. - se::DeviceDescription dev_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - GpuHloCostAnalysis analysis_{options_, &dev_info_}; + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + GpuHloCostAnalysis analysis_{options_, &device_info_}; + + GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ + &device_info_, ShapeSizeBytesFunction(), &mlir_context_}; + GpuPerformanceModelTest() : HloTestBase() {} }; @@ -75,10 +105,16 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); + auto t = EstimateRunTimesDefault(root); // Dominated by the DRAM bandwidth. - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 57, 10); + EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 53, 10); + + auto prio_t = EstimateRunTimesForPriorityFusion(root); + // Dominated by the DRAM bandwidth. + EXPECT_NEAR(absl::ToInt64Microseconds(prio_t.time_unfused), 53, 10); + + auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); + EXPECT_NEAR(absl::ToInt64Microseconds(indexing_t.time_unfused), 53, 10); } TEST_F(GpuPerformanceModelTest, SmallReadWrite) { @@ -102,17 +138,20 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); + auto t = EstimateRunTimesDefault(root); // Dominated by the kernel launch overhead. - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 5, 1); + EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 1, 1); GpuPerformanceModel::RecordEstimatedRunTime( root, &analysis_, GpuPerformanceModelOptions::Default()); - double recorded_cycles = root->backend_config() - ->reification_cost() - .end_to_end_cycles(); - EXPECT_NEAR(recorded_cycles, 257.7, 0.1); + auto reification_cost = root->backend_config() + ->fusion_backend_config() + .reification_cost(); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 257.7, 0.1); + EXPECT_NEAR(reification_cost.exec_time_us(), 0, 1); + + auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); + EXPECT_NEAR(absl::ToInt64Microseconds(indexing_t.time_unfused), 1, 1); } TEST_F(GpuPerformanceModelTest, LargeReadWrite) { @@ -136,17 +175,19 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); + auto t = EstimateRunTimesDefault(root); // Dominated by the DRAM bandwidth. EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 175, 30); GpuPerformanceModel::RecordEstimatedRunTime( root, &analysis_, GpuPerformanceModelOptions::Default()); - double recorded_cycles = root->backend_config() - ->reification_cost() - .end_to_end_cycles(); - EXPECT_NEAR(recorded_cycles, 220284, 100); + auto reification_cost = root->backend_config() + ->fusion_backend_config() + .reification_cost(); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 220284, 100); + EXPECT_NEAR(reification_cost.exec_time_us(), 156, 10); + EXPECT_NEAR(reification_cost.compute_time_us(), 1, 1); + EXPECT_NEAR(reification_cost.memory_access_time_us(), 156, 10); } TEST_F(GpuPerformanceModelTest, L1CacheEffect) { @@ -172,8 +213,7 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); + auto t = EstimateRunTimesDefault(root); // Parameter 0 read is accelerated by L1 cache even though the total data // volume is the same as in the test LargeReadWrite above. EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 118, 12); @@ -202,8 +242,7 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); + auto t = EstimateRunTimesDefault(root); // Parameter 0 read is accelerated by L2 cache (does not fit in L1). EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 123, 12); } @@ -235,26 +274,8 @@ TEST_F(GpuPerformanceModelTest, UnusedParameter) { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - root, &analysis_, GpuPerformanceModelOptions::Default()); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 5, 1); -} - -using GpuPerformanceWithCollectiveModelTest = GpuPerformanceModelTest; - -TEST_F(GpuPerformanceWithCollectiveModelTest, TestNvmlLibraryLoading) { -#if GOOGLE_CUDA - EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); - // After successful init, we try to use one of the - // nvml functions to see if the result is good. - nvmlDevice_t nvml_device; - nvmlReturn_t get_device_result = - xla_nvmlDeviceGetHandleByIndex(0, &nvml_device); - EXPECT_TRUE(get_device_result == NVML_SUCCESS); - - EXPECT_TRUE(GpuPerformanceWithCollectiveModel::InitNvml()); - -#endif // GOOGLE_CUDA + auto t = EstimateRunTimesDefault(root); + EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 1, 1); } TEST_F(GpuPerformanceModelTest, ComputeBoundReducesWithSameLaunchDimensions) { @@ -308,9 +329,9 @@ ENTRY fusion { )"; auto run = [&](absl::string_view hlo_text) - -> StatusOr { + -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_text)); - GpuHloCostAnalysis analysis(options_, &dev_info_); + GpuHloCostAnalysis analysis(options_, &device_info_); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); auto* producer = @@ -318,8 +339,7 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.2")}; - return GpuPerformanceModel::EstimateRunTimes( - producer, &analysis, GpuPerformanceModelOptions::Default(), consumers); + return EstimateRunTimesDefault(producer, consumers); }; TF_ASSERT_OK_AND_ASSIGN(auto large_small_reduce_runtime, @@ -359,10 +379,8 @@ ENTRY fusion { module->entry_computation()->GetInstructionWithName("transpose.1"); std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, - GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), consumers); + auto t = EstimateRunTimesForPriorityFusion(producer, consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_fused), 514, 10); } @@ -392,10 +410,12 @@ ENTRY fusion { module->entry_computation()->GetInstructionWithName("transpose.1"); std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::Default(), consumers); + auto t = EstimateRunTimesDefault(producer, consumers); EXPECT_LT(t.time_fused, t.time_unfused); + + auto prio_t = EstimateRunTimesForPriorityFusion(producer, consumers); + EXPECT_LT(prio_t.time_fused, prio_t.time_unfused); } TEST_F(GpuPerformanceModelTest, DusScalesWithUpdates) { @@ -443,17 +463,20 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); - GpuPerformanceModel::RunTimes t1 = GpuPerformanceModel::EstimateRunTimes( - module->entry_computation()->root_instruction()->operand(0), &analysis_, - GpuPerformanceModelOptions::Default()); - GpuPerformanceModel::RunTimes t2 = GpuPerformanceModel::EstimateRunTimes( - module->entry_computation()->root_instruction()->operand(1), &analysis_, - GpuPerformanceModelOptions::Default()); + auto* operand0 = module->entry_computation()->root_instruction()->operand(0); + auto* operand1 = module->entry_computation()->root_instruction()->operand(1); + auto t1 = EstimateRunTimesDefault(operand0); + auto t2 = EstimateRunTimesDefault(operand1); // DUS scales with the size of the updates, so these two fusions should have // the same cost. EXPECT_NEAR(absl::ToInt64Microseconds(t1.time_unfused), absl::ToInt64Microseconds(t2.time_unfused), 10); + + auto prio_t1 = EstimateRunTimesForPriorityFusion(operand0); + auto prio_t2 = EstimateRunTimesForPriorityFusion(operand1); + EXPECT_NEAR(absl::ToInt64Microseconds(prio_t1.time_unfused), + absl::ToInt64Microseconds(prio_t2.time_unfused), 10); } TEST_F(GpuPerformanceModelTest, EqualCostBeforeAndAfterFusion) { @@ -498,9 +521,7 @@ ENTRY e2 { HloInstruction* consumer = computation_without_fusion->root_instruction(); const HloInstruction* producer = consumer->operand(0); - GpuPerformanceModel::RunTimes t1 = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, - GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), {consumer}); + auto t1 = EstimateRunTimesForPriorityFusion(producer, {consumer}); HloComputation* computation_with_fusion = module->GetComputationWithName("e2"); @@ -508,13 +529,95 @@ ENTRY e2 { HloInstruction* root_with_fusion = computation_with_fusion->root_instruction(); - GpuPerformanceModel::RunTimes t2 = GpuPerformanceModel::EstimateRunTimes( - root_with_fusion, &analysis_, - GpuPerformanceModelOptions::PriorityFusion(nullptr, nullptr), {}); - + auto t2 = EstimateRunTimesForPriorityFusion(root_with_fusion); EXPECT_EQ(t1.time_fused, t2.time_unfused); } +TEST_F(GpuPerformanceModelTest, DoNotFuseDivideIntoSmallReduce) { + // Fusing this divide is not supported by reduce epilogue fusion. + constexpr absl::string_view kHlo = R"( +HloModule testmodule + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +ENTRY fusion { + c = f32[] constant(0) + p0 = f32[3072] parameter(0) + p1 = f32[] parameter(1) + reduce = f32[] reduce(p0, c), dimensions={0}, to_apply=add + ROOT divide = f32[] divide(reduce, p1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + auto* producer = + module->entry_computation()->GetInstructionWithName("reduce"); + std::vector consumers{ + module->entry_computation()->GetInstructionWithName("divide")}; + + auto t = EstimateRunTimesForPriorityFusion(producer, consumers); + EXPECT_LT(t.time_unfused, t.time_fused); +} + +TEST_F(GpuPerformanceModelTest, PreferFusingExpensiveInstructionsIntoProducer) { + // All things being equal, prefer fusing instructions into their producer, + // since this avoids potentially expensive recomputations when memory and + // compute aren't perfectly overlapping. + constexpr absl::string_view kHlo = R"( +HloModule testmodule + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fused_computation.0 { + p0 = f32[4,8,8] parameter(0) + bc = f32[1,4,1424,8,8] broadcast(p0), dimensions={1,3,4} + p1 = f32[1,4,1424,8,8] parameter(1) + ROOT sub = f32[1,4,1424,8,8] subtract(bc, p1) +} + +fused_computation.1 { + p0 = f32[1,4,1424,8,8] parameter(0) + bc = f32[4,1424,8,8] bitcast(p0) + c0 = f32[] constant(0) + ROOT reduce = f32[4,8,8] reduce(bc, c0), to_apply=add, dimensions={1} +} + +ENTRY fusion { + p0 = f32[4,8,8] parameter(0) + p1 = f32[1,4,1424,8,8] parameter(1) + fusion.0 = f32[1,4,1424,8,8] fusion(p0, p1), kind=kLoop, calls=fused_computation.0 + exp = f32[1,4,1424,8,8] exponential(fusion.0) + ROOT fusion.1 = f32[4,8,8] fusion(exp), kind=kInput, calls=fused_computation.1 +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + auto* fusion_0 = + module->entry_computation()->GetInstructionWithName("fusion.0"); + auto* exp = module->entry_computation()->GetInstructionWithName("exp"); + auto exp_consumer_runtimes = + EstimateRunTimesForPriorityFusion(fusion_0, {exp}); + auto exp_producer_runtimes = + EstimateRunTimesForPriorityFusion(exp, exp->users()); + + auto exp_consumer_priority = + exp_consumer_runtimes.time_unfused - exp_consumer_runtimes.time_fused; + auto exp_producer_priority = + exp_producer_runtimes.time_unfused - exp_producer_runtimes.time_fused; + + EXPECT_LT(exp_producer_priority, exp_consumer_priority); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/hlo_op_profiler.cc b/xla/service/gpu/model/hlo_op_profiler.cc index 5a3d32e2eca08..27a9cb2452a9e 100644 --- a/xla/service/gpu/model/hlo_op_profiler.cc +++ b/xla/service/gpu/model/hlo_op_profiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profiler.h" -#include -#include #include #include #include @@ -24,6 +22,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/debug_options_flags.h" @@ -43,6 +43,7 @@ limitations under the License. #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #ifdef GOOGLE_CUDA #include "xla/backends/profiler/gpu/cupti_collector.h" @@ -149,7 +150,7 @@ class CuptiKernelTracer { return module; } -StatusOr HloOpProfiler::MeasureOpChainDuration( +absl::StatusOr HloOpProfiler::MeasureOpChainDuration( HloOpcode op, PrimitiveType data_type, int chain_length) { #ifndef GOOGLE_CUDA return FailedPrecondition("Not built with --config=cuda"); @@ -201,7 +202,7 @@ HloOpProfiler::HloOpProfiler(HloRunner& runner) << "Failed to measure kernel runtime"; } -StatusOr HloOpProfiler::MeasureClockCyclesPerOp( +absl::StatusOr HloOpProfiler::MeasureClockCyclesPerOp( HloOpcode op, PrimitiveType data_type) { VLOG(2) << "Measuring " << HloOpcodeString(op) << " " << primitive_util::LowercasePrimitiveTypeName(data_type); diff --git a/xla/service/gpu/model/hlo_op_profiler.h b/xla/service/gpu/model/hlo_op_profiler.h index 86cd08392a299..f9d83f63d559e 100644 --- a/xla/service/gpu/model/hlo_op_profiler.h +++ b/xla/service/gpu/model/hlo_op_profiler.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILER_H_ #define XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILER_H_ -#include #include +#include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -34,13 +34,13 @@ class HloOpProfiler { static std::unique_ptr MakeModuleForMeasurements( HloOpcode op, PrimitiveType data_type, int chain_length); - StatusOr MeasureOpChainDuration(HloOpcode op, - PrimitiveType data_type, - int chain_length); + absl::StatusOr MeasureOpChainDuration(HloOpcode op, + PrimitiveType data_type, + int chain_length); public: explicit HloOpProfiler(HloRunner& runner); - StatusOr MeasureClockCyclesPerOp( + absl::StatusOr MeasureClockCyclesPerOp( HloOpcode op, PrimitiveType data_type); private: diff --git a/xla/service/gpu/model/hlo_op_profiler_run.cc b/xla/service/gpu/model/hlo_op_profiler_run.cc index 606fe4dab5f95..b71dc91505dbd 100644 --- a/xla/service/gpu/model/hlo_op_profiler_run.cc +++ b/xla/service/gpu/model/hlo_op_profiler_run.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,15 +24,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/service/gpu/model/hlo_op_profiler.h" +#include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" namespace xla { namespace gpu { @@ -88,6 +89,7 @@ int RunProfiler(int argc, char** argv) { // Unary HloOpcode::kCbrt, HloOpcode::kCos, + HloOpcode::kErf, HloOpcode::kExp, HloOpcode::kExpm1, HloOpcode::kLog, @@ -119,10 +121,11 @@ int RunProfiler(int argc, char** argv) { } } - VLOG(1) << "\n" << instr_profiles.DebugString(); + VLOG(1) << "\n" << instr_profiles; + auto profile_name = HloOpProfiles::GetProfileName(&dev_info); DeviceHloInstructionProfiles device_profiles; - device_profiles.mutable_entries()->insert({dev_info.name(), instr_profiles}); + device_profiles.mutable_entries()->insert({profile_name, instr_profiles}); if (!output_file.empty()) { WriteOutput(device_profiles, output_file); } diff --git a/xla/service/gpu/model/hlo_op_profiler_test.cc b/xla/service/gpu/model/hlo_op_profiler_test.cc index 755e26bb8656a..6a8ed6538e8ed 100644 --- a/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profiler.h" +#include #include "xla/hlo/ir/hlo_opcode.h" #include "xla/tests/hlo_test_base.h" @@ -38,7 +39,7 @@ TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kDivide, F64) .value() .clock_cycles(), - 500); + 400); // c128 sqrt is slow. EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kSqrt, C128) .value() diff --git a/xla/service/gpu/model/hlo_op_profiles.cc b/xla/service/gpu/model/hlo_op_profiles.cc new file mode 100644 index 0000000000000..e8a46b077478c --- /dev/null +++ b/xla/service/gpu/model/hlo_op_profiles.cc @@ -0,0 +1,81 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/hlo_op_profiles.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/hlo_op_profiles_data.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/protobuf.h" + +namespace xla { +namespace gpu { + +/*static*/ const HloOpProfiles& HloOpProfiles::Singleton() { + static const auto* hlo_op_profiles = + HloOpProfiles::Load(kDeviceHloOpProfiles, + /*default_profile_name=*/"sm_86") + .release(); + return *hlo_op_profiles; +} + +/*static*/ std::string HloOpProfiles::GetProfileName( + const se::DeviceDescription* device_info) { + if (device_info != nullptr) { + if (auto* ptr = std::get_if( + &device_info->gpu_compute_capability())) + return absl::StrCat("sm_", ptr->major, ptr->minor); + } + return ""; +} + +/*static*/ std::unique_ptr HloOpProfiles::Load( + std::string_view profiles_text_proto, + std::string_view default_profile_name) { + ProfilesNestedMap profiles_map; + DeviceHloInstructionProfiles all_device_profiles; + CHECK(tsl::protobuf::TextFormat::ParseFromString( + std::string(profiles_text_proto), &all_device_profiles)); + for (const auto& device_profile : all_device_profiles.entries()) { + for (const auto& entry : device_profile.second.entries()) { + auto op_code = StringToHloOpcode(entry.instruction().opcode()).value(); + auto element_type = entry.instruction().shape().element_type(); + + profiles_map[device_profile.first][std::make_pair( + op_code, element_type)] = entry.clock_cycles(); + } + } + return absl::WrapUnique( + new HloOpProfiles(std::move(profiles_map), default_profile_name)); +} + +const HloOpProfiles::HloOpProfile& HloOpProfiles::GetProfile( + const se::DeviceDescription* device_info) const { + auto it = profiles_.find(GetProfileName(device_info)); + if (it != profiles_.end()) return it->second; + return default_profile_; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/hlo_op_profiles.h b/xla/service/gpu/model/hlo_op_profiles.h index 0db77f1603ae5..fdda6d269e64d 100644 --- a/xla/service/gpu/model/hlo_op_profiles.h +++ b/xla/service/gpu/model/hlo_op_profiles.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,3075 +16,55 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILES_H_ #define XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILES_H_ +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/hlo_op_profile.pb.h" +#include "xla/service/hlo.pb.h" +#include "xla/stream_executor/device_description.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + namespace xla { namespace gpu { -// The data below is obtained with -// xla/service/gpu/model:hlo_op_profiler_run +class HloOpProfiles { + public: + using HloOpProfile = + absl::flat_hash_map, int64_t>; + using ProfilesNestedMap = + absl::flat_hash_map; + + // Returns singleton with profiler data. + static const HloOpProfiles& Singleton(); -constexpr char kDeviceHloOpProfiles[] = R"pb( - entries { - key: "sm_90" - value { - entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 351 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: S8 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: S8 } - } - clock_cycles: 115 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 375 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: S16 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: S16 } - } - clock_cycles: 115 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S32 } - } - clock_cycles: 298 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: S32 } - } - clock_cycles: 3 - } - entries { - instruction { - opcode: "power" - shape { element_type: S32 } - } - clock_cycles: 66 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 698 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: S64 } - } - clock_cycles: 10 - } - entries { - instruction { - opcode: "power" - shape { element_type: S64 } - } - clock_cycles: 238 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 308 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: U8 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: U8 } - } - clock_cycles: 115 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 301 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: U16 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: U16 } - } - clock_cycles: 115 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U32 } - } - clock_cycles: 119 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: U32 } - } - clock_cycles: 3 - } - entries { - instruction { - opcode: "power" - shape { element_type: U32 } - } - clock_cycles: 66 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 621 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: U64 } - } - clock_cycles: 10 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 238 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 466 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F16 } - } - clock_cycles: 329 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F16 } - } - clock_cycles: 105 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F16 } - } - clock_cycles: 217 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 182 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 245 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F16 } - } - clock_cycles: 94 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F16 } - } - clock_cycles: 333 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F16 } - } - clock_cycles: 98 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 200 - } - entries { - instruction { - opcode: "add" - shape { element_type: F16 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F16 } - } - clock_cycles: 449 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F16 } - } - clock_cycles: 45 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: F16 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: F16 } - } - clock_cycles: 491 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: F16 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F32 } - } - clock_cycles: 400 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F32 } - } - clock_cycles: 326 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F32 } - } - clock_cycles: 80 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F32 } - } - clock_cycles: 196 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 157 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F32 } - } - clock_cycles: 221 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F32 } - } - clock_cycles: 77 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F32 } - } - clock_cycles: 933 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F32 } - } - clock_cycles: 77 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 179 - } - entries { - instruction { - opcode: "add" - shape { element_type: F32 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F32 } - } - clock_cycles: 428 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F32 } - } - clock_cycles: 24 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: F32 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 487 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: F32 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 1656 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F64 } - } - clock_cycles: 568 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F64 } - } - clock_cycles: 382 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 403 - } - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 800 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 1210 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 277 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 561 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 333 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 393 - } - entries { - instruction { - opcode: "add" - shape { element_type: F64 } - } - clock_cycles: 14 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 866 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 530 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: F64 } - } - clock_cycles: 14 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 2179 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: F64 } - } - clock_cycles: 14 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C64 } - } - clock_cycles: 579 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C64 } - } - clock_cycles: 635 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C64 } - } - clock_cycles: 631 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 807 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 614 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C64 } - } - clock_cycles: 2815 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C64 } - } - clock_cycles: 723 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C64 } - } - clock_cycles: 4113 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C64 } - } - clock_cycles: 2348 - } - entries { - instruction { - opcode: "add" - shape { element_type: C64 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 6047 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C64 } - } - clock_cycles: 452 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: C64 } - } - clock_cycles: 77 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 4706 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: C64 } - } - clock_cycles: 7 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 1779 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C128 } - } - clock_cycles: 1333 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C128 } - } - clock_cycles: 1288 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 2337 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 2299 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 5036 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 1997 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 6181 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 4419 - } - entries { - instruction { - opcode: "add" - shape { element_type: C128 } - } - clock_cycles: 14 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 12453 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 2270 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: C128 } - } - clock_cycles: 38 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 7339 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: C128 } - } - clock_cycles: 14 - } - } - } - entries { - key: "sm_86" - value { - entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 370 - } - entries { - instruction { - opcode: "power" - shape { element_type: S8 } - } - clock_cycles: 392 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 367 - } - entries { - instruction { - opcode: "power" - shape { element_type: S16 } - } - clock_cycles: 396 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S32 } - } - clock_cycles: 306 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 918 - } - entries { - instruction { - opcode: "power" - shape { element_type: S64 } - } - clock_cycles: 601 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 306 - } - entries { - instruction { - opcode: "power" - shape { element_type: U8 } - } - clock_cycles: 388 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 302 - } - entries { - instruction { - opcode: "power" - shape { element_type: U16 } - } - clock_cycles: 399 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U32 } - } - clock_cycles: 115 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 838 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 604 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 925 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F16 } - } - clock_cycles: 691 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F16 } - } - clock_cycles: 108 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F16 } - } - clock_cycles: 396 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 266 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 284 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F16 } - } - clock_cycles: 226 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F16 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F16 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 212 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F16 } - } - clock_cycles: 482 - } - entries { - instruction { - opcode: "power" - shape { element_type: F16 } - } - clock_cycles: 975 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F32 } - } - clock_cycles: 867 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F32 } - } - clock_cycles: 662 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F32 } - } - clock_cycles: 86 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F32 } - } - clock_cycles: 381 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 244 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F32 } - } - clock_cycles: 262 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F32 } - } - clock_cycles: 176 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F32 } - } - clock_cycles: 75 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F32 } - } - clock_cycles: 662 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F32 } - } - clock_cycles: 75 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 190 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F32 } - } - clock_cycles: 486 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 925 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 6339 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F64 } - } - clock_cycles: 1717 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F64 } - } - clock_cycles: 1652 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 1900 - } - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 608 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 2073 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F64 } - } - clock_cycles: 2412 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 698 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 1789 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 986 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 1609 - } - entries { - instruction { - opcode: "add" - shape { element_type: F64 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 3747 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 2016 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: F64 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 5511 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: F64 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C64 } - } - clock_cycles: 1360 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C64 } - } - clock_cycles: 1400 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 950 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 842 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C64 } - } - clock_cycles: 2383 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C64 } - } - clock_cycles: 3193 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 5353 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C64 } - } - clock_cycles: 687 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 3351 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 6613 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C128 } - } - clock_cycles: 4028 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C128 } - } - clock_cycles: 4161 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 7599 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 6962 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 11318 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 5878 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 15606 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 9939 - } - entries { - instruction { - opcode: "add" - shape { element_type: C128 } - } - clock_cycles: 97 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 39027 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 7941 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: C128 } - } - clock_cycles: 270 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 18205 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: C128 } - } - clock_cycles: 97 - } - } - } + // Returns profile name for the gived device. + // For CUDA, the format is "sm_XX". + static std::string GetProfileName(const se::DeviceDescription* device_info); - entries { key: "sm_80" - value { entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 417 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 468 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 1094 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 420 - } - entries { - instruction { - opcode: "power" - shape { element_type: U8 } - } - clock_cycles: 417 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 391 - } - entries { - instruction { - opcode: "power" - shape { element_type: U16 } - } - clock_cycles: 454 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 908 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 744 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 1195 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 321 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 346 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F16 } - } - clock_cycles: 124 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 499 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 259 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 504 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 1221 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 1638 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 572 - })pb" - R"pb( - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 699 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 1223 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 329 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 597 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 397 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 733 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 1080 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 831 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 1861 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 1037 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 1029 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 6618 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 4131 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 2309 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 2371 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 2405 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 3945 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 2284 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 5304 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 3618 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 13564 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 3037 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 6054 - } - } - } + // Loads profiles from the given text proto data. + static std::unique_ptr Load( + std::string_view profiles_text_proto, + std::string_view default_profile_name); - entries { - key: "sm_70" - value { - entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 345 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 345 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 954 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 302 - } - entries { - instruction { - opcode: "power" - shape { element_type: U8 } - } - clock_cycles: 526 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 309 - } - entries { - instruction { - opcode: "power" - shape { element_type: U16 } - } - clock_cycles: 544 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 749 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 820 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 1227 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F16 } - } - clock_cycles: 865 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F16 } - } - clock_cycles: 137 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F16 } - } - clock_cycles: 544 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 354 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 388 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F16 } - } - clock_cycles: 122 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F16 } - } - clock_cycles: 841 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F16 } - } - clock_cycles: 134 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 556 - } - entries { - instruction { - opcode: "power" - shape { element_type: F16 } - } - clock_cycles: 1279 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F32 } - } - clock_cycles: 1168 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F32 } - } - clock_cycles: 823 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F32 } - } - clock_cycles: 110 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F32 } - } - clock_cycles: 514 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 333 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F32 } - } - clock_cycles: 361 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 529 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F32 } - } - clock_cycles: 660 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 1214 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 1392 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F64 } - } - clock_cycles: 673 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F64 } - } - clock_cycles: 474 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 676 - } - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 618 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 1061 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 290 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 667 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 391 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 709 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 1178 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 682 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 1679 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C64 } - } - clock_cycles: 1762 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 1450 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 1141 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C64 } - } - clock_cycles: 1787 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C64 } - } - clock_cycles: 3935 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 7025 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C64 } - } - clock_cycles: 948 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 4277 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 2386 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C128 } - } - clock_cycles: 1881 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C128 } - } - clock_cycles: 1875 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 2622 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 2328 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 4531 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 2408 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 5388 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 3867 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 13794 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 3001 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 6046 - } - } - } + const HloOpProfile& GetProfile( + const se::DeviceDescription* device_info) const; - entries { - key: "sm_60" - value { - entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 438 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 479 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S32 } - } - clock_cycles: 758 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 2037 - } - entries { - instruction { - opcode: "power" - shape { element_type: S64 } - } - clock_cycles: 2937 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 307 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 293 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 1708 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 2993 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 1661 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F16 } - } - clock_cycles: 213 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F16 } - } - clock_cycles: 778 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 598 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 538 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F16 } - } - clock_cycles: 402 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F16 } - } - clock_cycles: 130 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 453 - } - entries { - instruction { - opcode: "power" - shape { element_type: F16 } - } - clock_cycles: 1717 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F32 } - } - clock_cycles: 1672 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F32 } - } - clock_cycles: 168 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F32 } - } - clock_cycles: 731 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 435 - } - )pb" - R"pb( - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F32 } - } - clock_cycles: 589 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F32 } - } - clock_cycles: 343 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F32 } - } - clock_cycles: 1024 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 417 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F32 } - } - clock_cycles: 873 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 1779 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 1649 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F64 } - } - clock_cycles: 1175 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F64 } - } - clock_cycles: 639 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 911 - } - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 935 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 1421 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F64 } - } - clock_cycles: 1098 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 355 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 1187 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 645 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 917 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 1394 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 959 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 2667 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 1726 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 1518 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C64 } - } - clock_cycles: 4142 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C64 } - } - clock_cycles: 5069 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C64 } - } - clock_cycles: 4053 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 9469 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C64 } - } - clock_cycles: 1317 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 5617 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 3416 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C128 } - } - clock_cycles: 2730 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C128 } - } - clock_cycles: 2765 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 3106 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 2895 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 5922 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 3496 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 7014 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 5400 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 21766 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 4133 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 10458 - } - } - } + private: + HloOpProfiles(ProfilesNestedMap profiles, + std::string_view default_profile_name) + : profiles_(std::move(profiles)), + default_profile_(profiles_.at(default_profile_name)) {} - entries { - key: "sm_75" - value { - entries { - instruction { - opcode: "divide" - shape { element_type: S8 } - } - clock_cycles: 360 - } - entries { - instruction { - opcode: "power" - shape { element_type: S8 } - } - clock_cycles: 336 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S16 } - } - clock_cycles: 357 - } - entries { - instruction { - opcode: "power" - shape { element_type: S16 } - } - clock_cycles: 339 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S32 } - } - clock_cycles: 296 - } - entries { - instruction { - opcode: "divide" - shape { element_type: S64 } - } - clock_cycles: 979 - } - entries { - instruction { - opcode: "power" - shape { element_type: S64 } - } - clock_cycles: 495 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U8 } - } - clock_cycles: 293 - } - entries { - instruction { - opcode: "power" - shape { element_type: U8 } - } - clock_cycles: 334 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U16 } - } - clock_cycles: 290 - } - entries { - instruction { - opcode: "power" - shape { element_type: U16 } - } - clock_cycles: 336 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U32 } - } - clock_cycles: 118 - } - entries { - instruction { - opcode: "divide" - shape { element_type: U64 } - } - clock_cycles: 812 - } - entries { - instruction { - opcode: "power" - shape { element_type: U64 } - } - clock_cycles: 515 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F16 } - } - clock_cycles: 792 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F16 } - } - clock_cycles: 815 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F16 } - } - clock_cycles: 132 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F16 } - } - clock_cycles: 342 - } - entries { - instruction { - opcode: "log" - shape { element_type: F16 } - } - clock_cycles: 239 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F16 } - } - clock_cycles: 239 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F16 } - } - clock_cycles: 262 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F16 } - } - clock_cycles: 126 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F16 } - } - clock_cycles: 794 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F16 } - } - clock_cycles: 123 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F16 } - } - clock_cycles: 175 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F16 } - } - clock_cycles: 414 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F16 } - } - clock_cycles: 74 - } - entries { - instruction { - opcode: "power" - shape { element_type: F16 } - } - clock_cycles: 1120 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F32 } - } - clock_cycles: 783 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F32 } - } - clock_cycles: 737 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F32 } - } - clock_cycles: 83 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F32 } - } - clock_cycles: 319 - } - entries { - instruction { - opcode: "log" - shape { element_type: F32 } - } - clock_cycles: 201 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F32 } - } - clock_cycles: 218 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F32 } - } - clock_cycles: 181 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F32 } - } - clock_cycles: 74 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F32 } - } - clock_cycles: 717 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F32 } - } - clock_cycles: 74 - } - )pb" - R"pb( - entries { - instruction { - opcode: "tanh" - shape { element_type: F32 } - } - clock_cycles: 167 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F32 } - } - clock_cycles: 414 - } - entries { - instruction { - opcode: "power" - shape { element_type: F32 } - } - clock_cycles: 1085 - } - entries { - instruction { - opcode: "cbrt" - shape { element_type: F64 } - } - clock_cycles: 6494 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: F64 } - } - clock_cycles: 1800 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: F64 } - } - clock_cycles: 1630 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: F64 } - } - clock_cycles: 1929 - } - entries { - instruction { - opcode: "log" - shape { element_type: F64 } - } - clock_cycles: 596 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: F64 } - } - clock_cycles: 1774 - } - entries { - instruction { - opcode: "logistic" - shape { element_type: F64 } - } - clock_cycles: 2430 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: F64 } - } - clock_cycles: 705 - } - entries { - instruction { - opcode: "sine" - shape { element_type: F64 } - } - clock_cycles: 1805 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: F64 } - } - clock_cycles: 984 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: F64 } - } - clock_cycles: 1535 - } - entries { - instruction { - opcode: "add" - shape { element_type: F64 } - } - clock_cycles: 95 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: F64 } - } - clock_cycles: 3744 - } - entries { - instruction { - opcode: "divide" - shape { element_type: F64 } - } - clock_cycles: 1915 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: F64 } - } - clock_cycles: 95 - } - entries { - instruction { - opcode: "power" - shape { element_type: F64 } - } - clock_cycles: 5538 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: F64 } - } - clock_cycles: 95 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C64 } - } - clock_cycles: 1702 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C64 } - } - clock_cycles: 1503 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C64 } - } - clock_cycles: 1474 - } - entries { - instruction { - opcode: "log" - shape { element_type: C64 } - } - clock_cycles: 835 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C64 } - } - clock_cycles: 737 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C64 } - } - clock_cycles: 2232 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C64 } - } - clock_cycles: 1632 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C64 } - } - clock_cycles: 2989 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C64 } - } - clock_cycles: 2263 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C64 } - } - clock_cycles: 4847 - } - entries { - instruction { - opcode: "power" - shape { element_type: C64 } - } - clock_cycles: 3219 - } - entries { - instruction { - opcode: "cosine" - shape { element_type: C128 } - } - clock_cycles: 6474 - } - entries { - instruction { - opcode: "exponential" - shape { element_type: C128 } - } - clock_cycles: 4962 - } - entries { - instruction { - opcode: "exponential-minus-one" - shape { element_type: C128 } - } - clock_cycles: 4037 - } - entries { - instruction { - opcode: "log" - shape { element_type: C128 } - } - clock_cycles: 7286 - } - entries { - instruction { - opcode: "log-plus-one" - shape { element_type: C128 } - } - clock_cycles: 6848 - } - entries { - instruction { - opcode: "rsqrt" - shape { element_type: C128 } - } - clock_cycles: 10748 - } - entries { - instruction { - opcode: "sine" - shape { element_type: C128 } - } - clock_cycles: 5391 - } - entries { - instruction { - opcode: "sqrt" - shape { element_type: C128 } - } - clock_cycles: 15981 - } - entries { - instruction { - opcode: "tanh" - shape { element_type: C128 } - } - clock_cycles: 9653 - } - entries { - instruction { - opcode: "add" - shape { element_type: C128 } - } - clock_cycles: 95 - } - entries { - instruction { - opcode: "atan2" - shape { element_type: C128 } - } - clock_cycles: 38206 - } - entries { - instruction { - opcode: "divide" - shape { element_type: C128 } - } - clock_cycles: 8040 - } - entries { - instruction { - opcode: "multiply" - shape { element_type: C128 } - } - clock_cycles: 273 - } - entries { - instruction { - opcode: "power" - shape { element_type: C128 } - } - clock_cycles: 18550 - } - entries { - instruction { - opcode: "subtract" - shape { element_type: C128 } - } - clock_cycles: 97 - } - } - } -)pb"; + ProfilesNestedMap profiles_; + const HloOpProfile& default_profile_; +}; } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/hlo_op_profiles_data.h b/xla/service/gpu/model/hlo_op_profiles_data.h new file mode 100644 index 0000000000000..043596a51fef9 --- /dev/null +++ b/xla/service/gpu/model/hlo_op_profiles_data.h @@ -0,0 +1,3720 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILES_DATA_H_ +#define XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILES_DATA_H_ + +namespace xla { +namespace gpu { + +// The data below is obtained with +// xla/service/gpu/model:hlo_op_profiler_run + +constexpr char kDeviceHloOpProfiles[] = R"pb( + entries { + key: "sm_90" # "NVIDIA H100 80GB HBM3" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 356 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S8 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 364 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S16 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 297 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S32 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "power" + shape { element_type: S32 } + } + clock_cycles: 71 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 685 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S64 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 253 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 300 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U8 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 304 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U16 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 126 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U32 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "power" + shape { element_type: U32 } + } + clock_cycles: 71 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 629 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U64 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 253 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 201 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 997 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 102 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 217 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 182 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 245 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 993 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 502 + } + entries { + instruction { + opcode: "add" + shape { element_type: F16 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 451 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F16 } + } + clock_cycles: 43 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F16 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 526 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F16 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 178 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 978 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 79 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 190 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 166 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 229 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 958 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 467 + } + entries { + instruction { + opcode: "add" + shape { element_type: F32 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 431 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F32 } + } + clock_cycles: 19 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F32 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 510 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F32 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 586 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 558 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 376 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 712 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 815 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1259 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 277 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 554 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 332 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 431 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 930 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 526 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 2205 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 2415 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 641 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 2055 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 756 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 633 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 3148 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 2324 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 4344 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 2379 + } + entries { + instruction { + opcode: "add" + shape { element_type: C64 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 6462 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 498 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C64 } + } + clock_cycles: 79 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 5532 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C64 } + } + clock_cycles: 7 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 1750 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 1342 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 1275 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 2455 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 2403 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 5500 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 1999 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 6636 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 4613 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 13131 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 2280 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 39 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 8363 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 15 + } + } + } + + entries { + key: "sm_86" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 370 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 392 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 367 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 396 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 306 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 918 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 601 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 306 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 388 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 302 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 399 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 115 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 838 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 604 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 925 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 691 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 108 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 396 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 266 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 284 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F16 } + } + clock_cycles: 226 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 212 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 482 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 975 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 867 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 662 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 86 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 381 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 244 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 262 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F32 } + } + clock_cycles: 176 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 662 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 190 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 486 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 925 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 6339 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1717 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 1652 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 1900 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 608 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 2073 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F64 } + } + clock_cycles: 2412 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 698 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1789 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 986 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 1609 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 3747 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 2016 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 5511 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 1360 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 1400 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 950 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 842 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 2383 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 3193 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 5353 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 687 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 3351 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 6613 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 4028 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 4161 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 7599 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 6962 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 11318 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 5878 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 15606 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 9939 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 39027 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 7941 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 270 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 18205 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 97 + } + } + } + + entries { + key: "sm_80" # "NVIDIA A100-SXM4-40GB" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 468 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 1094 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 420 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 391 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 454 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 908 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 744 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 1195 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 321 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 346 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 124 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 499 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 259 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 504 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1221 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 1638 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 572 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 699 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1223 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 329 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 597 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 397 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 733 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 1080 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 831 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 1861 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 1037 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 1029 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 6618 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 4131 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 2309 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 2371 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 2405 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 3945 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 2284 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 5304 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 3618 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 13564 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 3037 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 6054 + } + } + } + + entries { + key: "sm_70" # "Tesla V100-SXM2-16GB" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 336 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S8 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 189 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 345 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S16 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 183 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 287 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S32 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "power" + shape { element_type: S32 } + } + clock_cycles: 104 + } + entries { + instruction { + opcode: "add" + shape { element_type: S64 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 685 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S64 } + } + clock_cycles: 12 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 376 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 293 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U8 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 189 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 293 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U16 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 183 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 113 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U32 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "power" + shape { element_type: U32 } + } + clock_cycles: 104 + } + entries { + instruction { + opcode: "add" + shape { element_type: U64 } + } + clock_cycles: 3 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 599 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U64 } + } + clock_cycles: 12 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 376 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 226 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 425 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 128 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 241 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 232 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 266 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 425 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 284 + } + entries { + instruction { + opcode: "add" + shape { element_type: F16 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 449 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F16 } + } + clock_cycles: 73 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F16 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 709 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F16 } + } + clock_cycles: 9 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 189 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 373 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 79 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 205 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 180 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 217 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 76 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 373 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 76 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 269 + } + entries { + instruction { + opcode: "add" + shape { element_type: F32 } + } + clock_cycles: 6 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 406 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F32 } + } + clock_cycles: 21 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F32 } + } + clock_cycles: 6 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 673 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F32 } + } + clock_cycles: 6 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 599 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 624 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 358 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 410 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 318 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 633 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 263 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 618 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 324 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 406 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 973 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 501 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 2099 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 780 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 722 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 703 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 758 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 654 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 3261 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 789 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 6282 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 1924 + } + entries { + instruction { + opcode: "add" + shape { element_type: C64 } + } + clock_cycles: 12 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 8151 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 480 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C64 } + } + clock_cycles: 42 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 8105 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C64 } + } + clock_cycles: 12 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 1808 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 1487 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 1334 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 1805 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 1618 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 7261 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 2013 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 8237 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 6343 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 15 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 15355 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 2423 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 45 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 9810 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 15 + } + } + } + + entries { + key: "sm_60" # "Tesla P100-SXM2-16GB" + value { + entries { + instruction { + opcode: "add" + shape { element_type: S8 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 426 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S8 } + } + clock_cycles: 5 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 216 + } + entries { + instruction { + opcode: "add" + shape { element_type: S16 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 420 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S16 } + } + clock_cycles: 5 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 216 + } + entries { + instruction { + opcode: "add" + shape { element_type: S32 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 444 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S32 } + } + clock_cycles: 14 + } + entries { + instruction { + opcode: "power" + shape { element_type: S32 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "add" + shape { element_type: S64 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 1018 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: S64 } + } + clock_cycles: 82 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 1569 + } + entries { + instruction { + opcode: "add" + shape { element_type: U8 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 299 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U8 } + } + clock_cycles: 5 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 213 + } + entries { + instruction { + opcode: "add" + shape { element_type: U16 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 307 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U16 } + } + clock_cycles: 5 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 216 + } + entries { + instruction { + opcode: "add" + shape { element_type: U32 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 189 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U32 } + } + clock_cycles: 14 + } + entries { + instruction { + opcode: "power" + shape { element_type: U32 } + } + clock_cycles: 420 + } + entries { + instruction { + opcode: "add" + shape { element_type: U64 } + } + clock_cycles: 2 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 888 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: U64 } + } + clock_cycles: 79 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 1548 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 233 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 532 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 142 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 364 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 325 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 373 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 100 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 497 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 100 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 458 + } + entries { + instruction { + opcode: "add" + shape { element_type: F16 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 675 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F16 } + } + clock_cycles: 68 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F16 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 1012 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F16 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 213 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 494 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 109 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 337 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 284 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 328 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 71 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 473 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 71 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 426 + } + entries { + instruction { + opcode: "add" + shape { element_type: F32 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 663 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F32 } + } + clock_cycles: 35 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F32 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 988 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F32 } + } + clock_cycles: 11 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 645 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1427 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 405 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 544 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 441 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 784 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 355 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1640 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 473 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 14 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 1169 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 565 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 14 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 2682 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 14 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 1128 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 1021 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 991 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 1107 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 994 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 2158 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 1139 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 2934 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 1883 + } + entries { + instruction { + opcode: "add" + shape { element_type: C64 } + } + clock_cycles: 20 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 16282 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 760 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C64 } + } + clock_cycles: 65 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 8335 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C64 } + } + clock_cycles: 20 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 4302 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 3665 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 3656 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 2057 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 1806 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 6135 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 4169 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 8595 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 5294 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 20 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 22278 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 3194 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 65 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 17893 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 20 + } + } + } + + entries { + key: "sm_75" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 360 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 336 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 357 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 339 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 296 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 979 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 495 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 293 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 334 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 290 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 336 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 118 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 812 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 515 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 792 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 815 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 132 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 342 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 239 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 239 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F16 } + } + clock_cycles: 262 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 126 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 794 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 123 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 175 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 414 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F16 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 1120 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 783 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 737 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 83 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 319 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 201 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 218 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F32 } + } + clock_cycles: 181 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 717 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 167 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 414 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1085 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 6494 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1800 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 1630 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 1929 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 596 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1774 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F64 } + } + clock_cycles: 2430 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 705 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1805 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 984 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 1535 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 3744 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 1915 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 5538 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 1702 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 1503 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 1474 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 835 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 737 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 2232 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 1632 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 2989 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 2263 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 4847 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 3219 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 6474 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 4962 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 4037 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 7286 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 6848 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 10748 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 5391 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 15981 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 9653 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 38206 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 8040 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 273 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 18550 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 97 + } + } + } +)pb"; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_HLO_OP_PROFILES_DATA_H_ diff --git a/xla/service/gpu/model/hlo_op_profiles_test.cc b/xla/service/gpu/model/hlo_op_profiles_test.cc new file mode 100644 index 0000000000000..d8b3e953e09e3 --- /dev/null +++ b/xla/service/gpu/model/hlo_op_profiles_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/hlo_op_profiles.h" + +#include +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +constexpr char kDeviceHloOpProfiles[] = R"pb( + entries { + key: "sm_90" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: F32 } + } + clock_cycles: 32 + } + } + } + + entries { + key: "sm_80" + value { + entries { + instruction { + opcode: "multiply" + shape { element_type: F32 } + } + clock_cycles: 64 + } + } + } +)pb"; + +using HloOpProfilesTest = ::testing::Test; + +TEST_F(HloOpProfilesTest, GetProfile) { + auto hlo_op_profiles = HloOpProfiles::Load(kDeviceHloOpProfiles, + /*default_profile_name=*/"sm_80"); + auto device_info_sm_90 = TestGpuDeviceInfo::RTXA6000DeviceInfo( + stream_executor::CudaComputeCapability(9, 0)); + + const auto& op_profile = hlo_op_profiles->GetProfile(&device_info_sm_90); + ASSERT_TRUE(op_profile.contains( + std::make_pair(HloOpcode::kDivide, PrimitiveType::F32))); + EXPECT_EQ( + op_profile.at(std::make_pair(HloOpcode::kDivide, PrimitiveType::F32)), + 32); +} + +TEST_F(HloOpProfilesTest, GetProfileDefault) { + auto hlo_op_profiles = HloOpProfiles::Load(kDeviceHloOpProfiles, + /*default_profile_name=*/"sm_80"); + auto device_info_sm_85 = TestGpuDeviceInfo::RTXA6000DeviceInfo( + stream_executor::CudaComputeCapability(8, 5)); + + // hlo_op_profiles only has sm_80 and sm_90, should return the default sm_80. + const auto& op_profile = hlo_op_profiles->GetProfile(&device_info_sm_85); + ASSERT_TRUE(op_profile.contains( + std::make_pair(HloOpcode::kMultiply, PrimitiveType::F32))); + EXPECT_EQ( + op_profile.at(std::make_pair(HloOpcode::kMultiply, PrimitiveType::F32)), + 64); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_analysis.cc b/xla/service/gpu/model/indexing_analysis.cc new file mode 100644 index 0000000000000..b925b5887494e --- /dev/null +++ b/xla/service/gpu/model/indexing_analysis.cc @@ -0,0 +1,1518 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_analysis.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/permutation_util.h" +#include "xla/service/gather_simplifier.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::getAffineConstantExpr; +using mlir::getAffineDimExpr; +using mlir::getAffineSymbolExpr; +using mlir::MLIRContext; + +HloInstructionIndexing CreateUnknownIndexing(int64_t count = 1) { + HloInstructionIndexing indexing; + indexing.indexing_maps = std::vector>( + count, {IndexingMap::GetUndefined()}); + return indexing; +} + +HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( + const HloInstruction* instr, MLIRContext* mlir_context) { + IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); + + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(instr->operand_count()); + int64_t operand_count = instr->operand_count(); + for (int64_t operand_id = 0; operand_id < operand_count; ++operand_id) { + instr_indexing.indexing_maps[operand_id].insert(identity_map); + } + return instr_indexing; +} + +HloInstructionIndexing ComputeInputToOutputCwiseOpIndexing( + const HloInstruction* instr, MLIRContext* mlir_context) { + IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); + return HloInstructionIndexing::FromIndexingMaps({identity_map}); +} + +HloInstructionIndexing ComputeOutputToInputBroadcastOpIndexing( + const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { + auto output_dims = bcast->shape().dimensions(); + + std::vector exprs; + exprs.reserve(bcast->dimensions().size()); + for (int64_t bcast_dim : bcast->dimensions()) { + exprs.push_back(getAffineDimExpr(bcast_dim, mlir_context)); + } + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + output_dims, {}); + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); +} + +HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( + const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { + absl::Span bcast_dims = bcast->dimensions(); + + const Shape& input_shape = bcast->operand(0)->shape(); + const Shape& output_shape = bcast->shape(); + + std::vector added_dims_sizes; + std::vector exprs; + exprs.reserve(output_shape.rank()); + for (auto [output_dim_id, output_dim] : + llvm::enumerate(output_shape.dimensions())) { + auto bcast_dim = + std::find(bcast_dims.begin(), bcast_dims.end(), output_dim_id); + if (bcast_dim == bcast_dims.end()) { + exprs.push_back( + getAffineSymbolExpr(added_dims_sizes.size(), mlir_context)); + added_dims_sizes.push_back(output_dim); + continue; + } + exprs.push_back(getAffineDimExpr( + std::distance(bcast_dims.begin(), bcast_dim), mlir_context)); + } + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(input_shape.rank(), added_dims_sizes.size(), exprs, + mlir_context), + input_shape.dimensions(), added_dims_sizes); + + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); +} + +HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( + const HloConcatenateInstruction* concat, MLIRContext* mlir_context) { + const auto& operand_0_dims = concat->operand(0)->shape().dimensions(); + + // Initialize affine map and domain. Only concat_dim elements of both have to + // be adjusted for a particular operand_id. + mlir::MutableAffineMap affine_map = + AffineMap::getMultiDimIdentityMap(operand_0_dims.size(), mlir_context); + std::vector dim_vars = DimVarsFromTensorSizes(operand_0_dims); + + HloInstructionIndexing concat_indexing; + concat_indexing.indexing_maps.resize(concat->operand_count()); + int64_t concat_dim = concat->concatenate_dimension(); + AffineExpr concat_dim_expr = getAffineDimExpr(concat_dim, mlir_context); + int64_t offset = 0; + for (const auto [operand_id, operand] : llvm::enumerate(concat->operands())) { + affine_map.setResult(concat_dim, concat_dim_expr - offset); + int64_t operand_concat_dim = operand->shape().dimensions()[concat_dim]; + dim_vars[concat_dim] = DimVar{{offset, offset + operand_concat_dim - 1}}; + concat_indexing.indexing_maps[operand_id].insert( + IndexingMap(affine_map.getAffineMap(), dim_vars, + /*range_vars=*/{}, /*rt_vars=*/{})); + offset += operand_concat_dim; + } + return concat_indexing; +} + +HloInstructionIndexing ComputeInputToOutputConcatenateOpIndexing( + const HloConcatenateInstruction* concat, int input_id, + MLIRContext* mlir_context) { + int64_t concat_dim = concat->concatenate_dimension(); + int64_t offset = 0; + for (int64_t operand_id = 0; operand_id < input_id; ++operand_id) { + offset += concat->operand(operand_id)->shape().dimensions()[concat_dim]; + } + // Initialize affine map. Only concat_dim element has to be adjusted for a + // particular operand_id. + const auto& operand_dims = concat->operand(input_id)->shape().dimensions(); + mlir::MutableAffineMap affine_map = + AffineMap::getMultiDimIdentityMap(operand_dims.size(), mlir_context); + affine_map.setResult(concat_dim, + getAffineDimExpr(concat_dim, mlir_context) + offset); + IndexingMap indexing_map = + IndexingMap::FromTensorSizes(affine_map.getAffineMap(), operand_dims, {}); + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); +} + +// Composes instruction indexing maps starting at the root instruction +// until the HloParameterInstruction is found. +HloInstructionIndexing ComputeOutputToInputFusionOpIndexing( + const HloFusionInstruction* fusion, int output_id, + MLIRContext* mlir_context) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion); + auto grouped_indexing_maps = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, fusion_adaptor->GetRoots()[output_id], mlir_context); + + // After the traversal, `grouped_indexing_maps` is keyed by + // HloParameterInstructions. Convert them back to the operand id and return. + HloInstructionIndexing fusion_indexing; + fusion_indexing.indexing_maps.resize(fusion->operand_count()); + for (auto [operand_id, operand] : llvm::enumerate(fusion->operands())) { + fusion_indexing.indexing_maps[operand_id] = grouped_indexing_maps[operand]; + } + return fusion_indexing; +} + +HloInstructionIndexing ComputeOutputToInputDotOpIndexing( + const HloDotInstruction* dot, MLIRContext* mlir_context) { + CHECK_NE(dot, nullptr); + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + absl::Span lhs_contracting_dims( + dim_numbers.lhs_contracting_dimensions()); + absl::Span rhs_contracting_dims = + dim_numbers.rhs_contracting_dimensions(); + + absl::Span lhs_batch_dims = dim_numbers.lhs_batch_dimensions(); + absl::Span rhs_batch_dims = dim_numbers.rhs_batch_dimensions(); + + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + // According to the StableHLO specification, the dimensions of the output + // shape are ordered as follows: + // lhs_batch_dims | lhs_non_contracting_dims | rhs_non_contracting_dims + SmallVector lhs_exprs(lhs_shape.rank()); + SmallVector rhs_exprs(rhs_shape.rank()); + int64_t output_dim_id = 0; + + // lhs_batch_dims + for (auto [lhs_batch_dim, rhs_batch_dim] : + llvm::zip(lhs_batch_dims, rhs_batch_dims)) { + AffineExpr output_dim_expr = getAffineDimExpr(output_dim_id, mlir_context); + lhs_exprs[lhs_batch_dim] = output_dim_expr; + rhs_exprs[rhs_batch_dim] = output_dim_expr; + ++output_dim_id; + } + + // lhs_non_contracting_dims + auto lhs_non_contracting_dims = + GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_contracting_dims); + assert(lhs_non_contracting_dims.ok()); + + for (int64_t lhs_non_contracting_dim : lhs_non_contracting_dims.value()) { + lhs_exprs[lhs_non_contracting_dim] = + getAffineDimExpr(output_dim_id++, mlir_context); + } + + // rhs_non_contracting_dims + auto rhs_non_contracting_dims = + GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_contracting_dims); + assert(rhs_non_contracting_dims.ok()); + for (int64_t rhs_non_contracting_dim : rhs_non_contracting_dims.value()) { + rhs_exprs[rhs_non_contracting_dim] = + getAffineDimExpr(output_dim_id++, mlir_context); + } + + int64_t input_dim_id = 0; + std::vector input_dim_sizes; + input_dim_sizes.reserve(lhs_contracting_dims.size()); + + for (auto [lhs_contracting_dim, rhs_contracting_dim] : + llvm::zip(lhs_contracting_dims, rhs_contracting_dims)) { + AffineExpr input_dim_expr = getAffineSymbolExpr(input_dim_id, mlir_context); + lhs_exprs[lhs_contracting_dim] = input_dim_expr; + rhs_exprs[rhs_contracting_dim] = input_dim_expr; + ++input_dim_id; + + // LHS and RHS contracting dimensions must match pairwise, and we therefore + // need only populate a single input_dim_sizes vector. + input_dim_sizes.push_back(lhs_shape.dimensions(lhs_contracting_dim)); + } + + IndexingMap lhs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), lhs_exprs, + mlir_context), + dot->shape().dimensions(), input_dim_sizes); + + IndexingMap rhs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), rhs_exprs, + mlir_context), + dot->shape().dimensions(), input_dim_sizes); + return HloInstructionIndexing::FromIndexingMaps( + {lhs_indexing_map, rhs_indexing_map}); +} + +HloInstructionIndexing ComputeOutputToInputDynamicSliceOpIndexing( + const HloDynamicSliceInstruction* dynamic_slice, + MLIRContext* mlir_context) { + const Shape& input_shape = dynamic_slice->operand(0)->shape(); + const Shape& output_shape = dynamic_slice->shape(); + int64_t rank = output_shape.rank(); + const int64_t first_index_num = dynamic_slice->first_index_operand_number(); + + CHECK(dynamic_slice->operand(first_index_num)->shape().rank() == 0) + << "b/118437727: Old form, not supported."; + // A map from tensor iteration space to (), because index operands are 0d + // tensors. + AffineMap empty_results_affine_map = AffineMap::get( + /*dimCount=*/rank, /*symbolCount=*/0, /*results=*/{}, mlir_context); + IndexingMap start_indices_map = IndexingMap::FromTensorSizes( + empty_results_affine_map, output_shape.dimensions(), {}); + + std::vector offsets_rt_vars; + offsets_rt_vars.reserve(rank); + std::vector exprs; + exprs.reserve(rank); + for (auto [dim, slice_size] : + llvm::enumerate(dynamic_slice->dynamic_slice_sizes())) { + exprs.push_back(getAffineDimExpr(dim, mlir_context) + + getAffineSymbolExpr(dim, mlir_context)); + offsets_rt_vars.push_back( + RTVar{Interval{0, input_shape.dimensions(dim) - slice_size}, + dynamic_slice->operand(dim + first_index_num), + empty_results_affine_map}); + } + std::vector indexing_maps(dynamic_slice->operand_count(), + start_indices_map); + indexing_maps.front() = + IndexingMap{AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, exprs, + mlir_context), + start_indices_map.GetDimVars(), /*range_vars=*/{}, + std::move(offsets_rt_vars)}; + return HloInstructionIndexing::FromIndexingMaps(indexing_maps); +} + +HloInstructionIndexing ComputeOutputToInputDynamicUpdateSliceOpIndexing( + const HloDynamicUpdateSliceInstruction* dus, MLIRContext* mlir_context) { + const Shape& update_shape = dus->update()->shape(); + const Shape& output_shape = dus->shape(); + int64_t rank = output_shape.rank(); + + // operand: (d0, ... d_{N-1}) -> (d0, ... d_{N-1}) + std::vector identity; + for (int64_t dim = 0; dim < rank; ++dim) { + identity.push_back(getAffineDimExpr(dim, mlir_context)); + } + IndexingMap operand_map = IndexingMap::FromTensorSizes( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, /*results=*/identity, + mlir_context), + output_shape.dimensions(), {}); + + // start_indices: (d0, ... d_{N-1}) -> () + AffineMap empty_results_affine_map = AffineMap::get( + /*dimCount=*/rank, /*symbolCount=*/0, /*results=*/{}, mlir_context); + IndexingMap start_indices_map = IndexingMap::FromTensorSizes( + empty_results_affine_map, output_shape.dimensions(), {}); + + // update: (d_0 - s_0, ..., d_{N-1} - s_{N-1}) + std::vector exprs; + exprs.reserve(rank); + std::vector rt_vars; + rt_vars.reserve(rank); + for (auto [dim, slice_size] : llvm::enumerate(update_shape.dimensions())) { + exprs.push_back(getAffineDimExpr(dim, mlir_context) - + getAffineSymbolExpr(dim, mlir_context)); + Interval feasible_values{0, output_shape.dimensions(dim) - slice_size}; + rt_vars.push_back(RTVar{feasible_values, dus->operand(2 + dim), + empty_results_affine_map}); + } + IndexingMap update_map{AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, + /*results=*/exprs, mlir_context), + operand_map.GetDimVars(), + /*range_vars=*/{}, rt_vars}; + + std::vector indexing_maps(dus->operand_count(), + start_indices_map); + indexing_maps[0] = std::move(operand_map); + indexing_maps[1] = std::move(update_map); + return HloInstructionIndexing::FromIndexingMaps(indexing_maps); +} + +HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( + const HloGatherInstruction* gather, MLIRContext* mlir_context) { + CHECK(GatherSimplifier::IsSimplifiedGather(gather)) + << "Non-simplified HLO Gather is not supported."; + const Shape& operand_shape = gather->operand(0)->shape(); + const Shape& indices_shape = gather->operand(1)->shape(); + + const GatherDimensionNumbers& dimension_numbers = + gather->gather_dimension_numbers(); + int64_t index_vector_length = + indices_shape.dimensions(dimension_numbers.index_vector_dim()); + + const Shape& output_shape = gather->shape(); + int64_t output_rank = output_shape.rank(); + + // A map for the `indices` operand of gather. It is always + // (d_0, ... d_{rank - 1}) -> (d_0, s_0), + // where 0 <= s_0 <= indices_shape[1] - 1. + AffineExpr indices_id_dim = getAffineDimExpr(0, mlir_context); + std::vector dim_vars = + DimVarsFromTensorSizes(output_shape.dimensions()); + IndexingMap indices_map{ + AffineMap::get(output_rank, 1, + {indices_id_dim, getAffineSymbolExpr(0, mlir_context)}, + mlir_context), + dim_vars, + {RangeVar{{0, index_vector_length - 1}}}, + /*rt_vars=*/{}}; + + // A map for the `operand` operand of gather, from which we extract slices. + // (d_0, ... d_{rank - 1}) -> (d_1 + s0, d_2 + s_1, ...), + // where s_i are RTVars that extract indices from the `indices` operand. + std::vector rt_vars; + std::vector exprs; + exprs.reserve(operand_shape.rank()); + for (auto [operand_dim_id, slice_size] : + llvm::enumerate(gather->gather_slice_sizes())) { + int64_t output_dim_id = dimension_numbers.offset_dims(operand_dim_id); + exprs.push_back(getAffineDimExpr(output_dim_id, mlir_context)); + + if (operand_dim_id >= index_vector_length) continue; + + rt_vars.push_back(RTVar{ + Interval{0, operand_shape.dimensions(operand_dim_id) - slice_size}, + gather->operand(1), + AffineMap::get(output_rank, /*symbolCount=*/0, + {indices_id_dim, + getAffineConstantExpr(operand_dim_id, mlir_context)}, + mlir_context)}); + exprs.back() = + exprs.back() + getAffineSymbolExpr(operand_dim_id, mlir_context); + } + IndexingMap operand_map = { + AffineMap::get(/*dimCount=*/output_rank, + /*symbolCount=*/index_vector_length, exprs, mlir_context), + std::move(dim_vars), /*range_vars=*/{}, std::move(rt_vars)}; + return HloInstructionIndexing::FromIndexingMaps({operand_map, indices_map}); +} + +IndexingMap ComputeOutputToInputPadOpIndexingImpl( + absl::Span output_dims, + absl::Span padding_low, + absl::Span padding_high, + absl::Span padding_interior, MLIRContext* mlir_context) { + int64_t output_rank = output_dims.size(); + + std::vector exprs; + std::vector> constraints; + std::vector dim_vars; + exprs.reserve(output_rank); + constraints.reserve(output_rank); + int64_t output_dim_id = 0; + for (const auto [output_dim, pad_low, pad_high, pad_interior] : + llvm::zip(output_dims, padding_low, padding_high, padding_interior)) { + AffineExpr dim_expr = getAffineDimExpr(output_dim_id, mlir_context); + dim_vars.push_back( + {Interval{std::max(int64_t{0}, pad_low), + std::min(output_dim - 1, output_dim - 1 - pad_high)}}); + if (pad_interior == 0) { + exprs.push_back(dim_expr - pad_low); + } else { + exprs.push_back((dim_expr - pad_low).floorDiv(pad_interior + 1)); + constraints.push_back( + {(dim_expr - pad_low) % (pad_interior + 1), Interval{0, 0}}); + } + ++output_dim_id; + } + return IndexingMap{ + AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), + std::move(dim_vars), + /*range_vars = */ {}, + /*rt_vars = */ {}, + absl::MakeSpan(constraints)}; +} + +HloInstructionIndexing ComputeOutputToInputPadOpIndexing( + const HloPadInstruction* pad, MLIRContext* mlir_context) { + const Shape& output_shape = pad->shape(); + int64_t rank = output_shape.rank(); + SmallVector padding_low, padding_high, padding_interior; + padding_low.reserve(rank); + padding_high.reserve(rank); + padding_interior.reserve(rank); + for (const auto& dim_config : pad->padding_config().dimensions()) { + padding_low.push_back(dim_config.edge_padding_low()); + padding_high.push_back(dim_config.edge_padding_high()); + padding_interior.push_back(dim_config.interior_padding()); + } + IndexingMap input_indexing_map = ComputeOutputToInputPadOpIndexingImpl( + output_shape.dimensions(), padding_low, padding_high, padding_interior, + mlir_context); + IndexingMap padding_value_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); + return HloInstructionIndexing::FromIndexingMaps( + {input_indexing_map, padding_value_indexing_map}); +} + +HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( + const HloReduceInstruction* reduce, int output_id, + MLIRContext* mlir_context) { + absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), + reduce->dimensions().end()); + + const Shape& input_shape = reduce->operand(output_id)->shape(); + const Shape& output_shape = GetOutputShape(reduce, 0); + + std::vector parallel_dims_sizes; + int64_t output_dim_id = 0; + std::vector exprs; + exprs.reserve(input_shape.rank()); + for (auto [input_dim_id, input_dim] : + llvm::enumerate(input_shape.dimensions())) { + if (reduce_dims_ids.contains(input_dim_id)) { + exprs.push_back( + getAffineSymbolExpr(parallel_dims_sizes.size(), mlir_context)); + parallel_dims_sizes.push_back(input_dim); + continue; + } + exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); + } + IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), exprs, + mlir_context), + output_shape.dimensions(), parallel_dims_sizes); + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), {}); + + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(reduce->operand_count()); + for (int64_t id = 0; id < reduce->input_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inputs_indexing_map); + } + for (int64_t id = reduce->input_count(); id < reduce->operand_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inits_indexing_map); + } + return instr_indexing; +} + +HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( + const HloReduceInstruction* reduce, int input_id, + MLIRContext* mlir_context) { + absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), + reduce->dimensions().end()); + const Shape& input_shape = reduce->operand(input_id)->shape(); + const Shape& output_shape = GetOutputShape(reduce, 0); + int64_t output_rank = output_shape.rank(); + + int64_t output_dim_id = 0; + std::vector inputs_exprs, inits_exprs; + inputs_exprs.reserve(output_rank); + inits_exprs.reserve(output_rank); + for (auto [input_dim_id, input_dim] : + llvm::enumerate(input_shape.dimensions())) { + if (reduce_dims_ids.contains(input_dim_id)) { + continue; + } + inputs_exprs.push_back(getAffineDimExpr(input_dim_id, mlir_context)); + inits_exprs.push_back(getAffineSymbolExpr(output_dim_id++, mlir_context)); + } + IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(input_shape.rank(), /*symbolCount=*/0, inputs_exprs, + mlir_context), + input_shape.dimensions(), {}); + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(0, /*symbolCount=*/output_rank, inits_exprs, mlir_context), + {}, output_shape.dimensions()); + + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(reduce->operand_count()); + for (int64_t id = 0; id < reduce->input_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inputs_indexing_map); + } + for (int64_t id = reduce->input_count(); id < reduce->operand_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inits_indexing_map); + } + return instr_indexing; +} + +IndexingMap ComposeIndexingMapsForWindow( + absl::Span input_dimensions, + absl::Span output_dimensions, const Window& window, + MLIRContext* mlir_context) { + size_t rank = input_dimensions.size(); + + // Compute shape of the padded input and the indexing map of pad op required + // to pad the input. + SmallVector padding_low, padding_high, padding_interior, + padded_input_dimensions; + padding_low.reserve(rank); + padding_high.reserve(rank); + padding_interior.reserve(rank); + padded_input_dimensions.reserve(rank); + SmallVector exprs; + std::vector dim_vars; + std::vector range_vars; + exprs.reserve(rank); + dim_vars.reserve(rank); + range_vars.reserve(rank); + for (const auto& [dim_id, window_config] : + llvm::enumerate(window.dimensions())) { + padding_low.push_back(window_config.padding_low()); + padding_high.push_back(window_config.padding_high()); + // For some reason interior_padding in HLO pad is offset from base_dilations + // in HLO reduce-window by 1. + padding_interior.push_back(window_config.base_dilation() - 1); + padded_input_dimensions.push_back( + input_dimensions[dim_id] + window_config.padding_low() + + window_config.padding_high() + + (input_dimensions[dim_id] - 1) * (window_config.base_dilation() - 1)); + AffineExpr dim_expr = getAffineDimExpr(dim_id, mlir_context); + AffineExpr symbol_expr = getAffineSymbolExpr(dim_id, mlir_context); + + exprs.push_back(symbol_expr * window_config.window_dilation() + + window_config.stride() * dim_expr); + dim_vars.push_back({Interval{0, output_dimensions[dim_id] - 1}}); + range_vars.push_back({Interval{0, window_config.size() - 1}}); + } + // Indexing map for pad op that pads the input. + IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl( + padded_input_dimensions, padding_low, padding_high, padding_interior, + mlir_context); + // Indexing map for reduce-window, that does not do any padding. + IndexingMap input_indexing_no_padding( + AffineMap::get(rank, rank, exprs, mlir_context), dim_vars, range_vars, + /*rt_vars=*/{}); + + // Composed indexing. + IndexingMap result = + ComposeIndexingMaps(input_indexing_no_padding, padded_input_indexing); + result.Simplify(GetIndexingMapForInstruction); + result.RemoveUnusedSymbols(); + return result; +} + +// Indexing for reduce-window with dilations and non-trivial padding can be +// represented as a composition of pad op and reduce-window that never goes out +// of bounds. +HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( + const HloReduceWindowInstruction* reduce_window, int output_id, + MLIRContext* mlir_context) { + const Shape& input_shape = reduce_window->operand(0)->shape(); + const Shape& output_shape = GetOutputShape(reduce_window, 0); + + // Indexing map for the input value. + IndexingMap inputs_indexing = ComposeIndexingMapsForWindow( + input_shape.dimensions(), output_shape.dimensions(), + reduce_window->window(), mlir_context); + + // Indexing map for the init value. + IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); + + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(reduce_window->operand_count()); + for (int64_t id = 0; id < reduce_window->input_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inputs_indexing); + } + for (int64_t id = reduce_window->input_count(); + id < reduce_window->operand_count(); ++id) { + instr_indexing.indexing_maps[id].insert(inits_indexing_map); + } + return instr_indexing; +} + +HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( + const HloConvolutionInstruction* convolution, MLIRContext* mlir_context) { + const Shape& input_shape = convolution->operand(0)->shape(); + const Shape& kernel_shape = convolution->operand(1)->shape(); + const Shape& output_shape = convolution->shape(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + size_t rank = output_shape.rank(); + + // Collect sizes for input/output spatial dimensions. + size_t spatial_rank = rank - 2; + std::vector input_spatial_sizes(spatial_rank); + std::vector kernel_spatial_sizes(spatial_rank); + std::vector output_spatial_sizes(spatial_rank); + for (int i = 0; i < spatial_rank; ++i) { + input_spatial_sizes[i] = + input_shape.dimensions(dnums.input_spatial_dimensions(i)); + kernel_spatial_sizes[i] = + kernel_shape.dimensions(dnums.kernel_spatial_dimensions(i)); + output_spatial_sizes[i] = + output_shape.dimensions(dnums.output_spatial_dimensions(i)); + } + + // Indexing map for the input value (spatial dimensions only). + // The dimension numbers in the resulting affine expressions have to be + // remapped to correspond to the correct output dimensions. + IndexingMap input_spatial_indexing = + ComposeIndexingMapsForWindow(input_spatial_sizes, output_spatial_sizes, + convolution->window(), mlir_context); + std::vector replacement_dims(spatial_rank); + for (int i = 0; i < spatial_rank; ++i) { + replacement_dims[i] = + getAffineDimExpr(dnums.output_spatial_dimensions(i), mlir_context); + } + + // Build affine expressions and constraints for input spatial dimensions. + std::vector input_exprs(rank); + for (int i = 0; i < spatial_rank; ++i) { + input_exprs[dnums.input_spatial_dimensions(i)] = + input_spatial_indexing.GetAffineMap().getResult(i).replaceDims( + replacement_dims); + } + llvm::DenseMap input_constraints; + for (const auto& [key, val] : input_spatial_indexing.GetConstraints()) { + input_constraints[key.replaceDims(replacement_dims)] = val; + } + + // Build affine expressions for kernel spatial and output dimensions. + std::vector kernel_exprs(rank); + for (int i = 0; i < spatial_rank; ++i) { + kernel_exprs[dnums.kernel_spatial_dimensions(i)] = + getAffineSymbolExpr(i, mlir_context); + } + AffineExpr dim_expr = + getAffineDimExpr(dnums.output_feature_dimension(), mlir_context); + kernel_exprs[dnums.kernel_output_feature_dimension()] = dim_expr; + + // Build initial symbol ranges. + std::vector input_symbols = input_spatial_indexing.GetRangeVars(); + std::vector kernel_symbols = + RangeVarsFromTensorSizes(kernel_spatial_sizes); + + // Add symbol for input feature dimension. + input_exprs[dnums.input_feature_dimension()] = + getAffineSymbolExpr(input_symbols.size(), mlir_context); + kernel_exprs[dnums.kernel_input_feature_dimension()] = + getAffineSymbolExpr(kernel_symbols.size(), mlir_context); + + int64_t input_group_size = + kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); + Interval input_feature_range{0, input_group_size - 1}; + input_symbols.push_back({input_feature_range}); + kernel_symbols.push_back({input_feature_range}); + + // With multiple feature groups, the input feature dimension is equally split. + if (convolution->feature_group_count() > 1) { + AffineExpr& input_feature = input_exprs[dnums.input_feature_dimension()]; + int64_t output_group_size = + output_shape.dimensions(dnums.output_feature_dimension()); + int64_t feature_group_size = + output_group_size / convolution->feature_group_count(); + input_feature = dim_expr.floorDiv(feature_group_size) * input_group_size + + input_feature; + } + + // With multiple batch groups, the input batch dimension is equally split. + AffineExpr batch_dim_expr = + getAffineDimExpr(dnums.output_batch_dimension(), mlir_context); + if (convolution->batch_group_count() > 1) { + int64_t batch_group_size = + output_shape.dimensions(dnums.output_batch_dimension()); + AffineExpr batch_group_expr = + getAffineSymbolExpr(input_symbols.size(), mlir_context); + input_symbols.push_back({{0, convolution->batch_group_count() - 1}}); + input_exprs[dnums.input_batch_dimension()] = + batch_group_expr * batch_group_size + batch_dim_expr; + } else { + input_exprs[dnums.input_batch_dimension()] = batch_dim_expr; + } + + // Indexing map for the input value. + IndexingMap inputs_indexing( + AffineMap::get(rank, input_symbols.size(), input_exprs, mlir_context), + DimVarsFromTensorSizes(output_shape.dimensions()), input_symbols, + /*rt_vars=*/{}, input_constraints); + + // Indexing map for the kernel value. + IndexingMap kernel_indexing( + AffineMap::get(rank, kernel_symbols.size(), kernel_exprs, mlir_context), + DimVarsFromTensorSizes(output_shape.dimensions()), kernel_symbols, + /*rt_vars=*/{}); + + return HloInstructionIndexing::FromIndexingMaps( + {inputs_indexing, kernel_indexing}); +} + +// Computes strides for a shape. +std::vector ComputeStrides(absl::Span dims) { + int rank = static_cast(dims.size()); + std::vector strides(rank, 1); + for (int i = rank - 2; i >= 0; --i) { + strides[i] = dims[i + 1] * strides[i + 1]; + } + return strides; +} + +} // namespace + +AffineExpr LinearizeShape(absl::Span dims, + absl::Span dimension_exprs, + MLIRContext* mlir_context) { + AffineExpr linear_index = getAffineConstantExpr(0, mlir_context); + + auto strides = ComputeStrides(dims); + for (auto [stride, dimension_expr] : llvm::zip(strides, dimension_exprs)) { + linear_index = linear_index + dimension_expr * stride; + } + return linear_index; +} + +std::vector DelinearizeIndex(absl::Span dims, + AffineExpr linear_index, + MLIRContext* mlir_context) { + std::vector multi_index; + multi_index.reserve(dims.size()); + + AffineExpr remainder = linear_index; + for (int64_t stride : ComputeStrides(dims)) { + multi_index.push_back(remainder.floorDiv(stride)); + remainder = remainder % stride; + } + return multi_index; +} + +namespace { + +// Computes indexing for "minimal" reshapes, i.e. reshapes that cannot be +// represented by a series of composed reshapes, i.e. when there are no +// subshapes in input and output that have the same number of elements. +// For example, [8, 4] -> [8, 2, 2] is not a minimal reshape, it has matching +// subshapes [8] -> [8] and [4] -> [2, 2]. +// +// There are only 4 types of "minimal" reshapes considers only 4 cases: +// 1. Dimension is not changed, e.g. [8] -> [8] +// 2. Dimension is expanded, e.g. [8] -> [4, 2] +// 3. Dimension is collapsed, e.g. [4, 2] -> [8] +// 4. Dimension is collapsed and expanded, e.g. [8, 16] -> [4, 32] +// +// The function computes indexing maps for these 4 cases, i.e. considers given +// input/output shapes and checks if the shapes are the same, expanded or +// collapsed. Otherwise, performs linearization/delinearization. +void ComputeMinimalReshapeIndexing( + absl::Span input_dims, absl::Span output_dims, + absl::Span output_dims_exprs, + std::vector* exprs, MLIRContext* mlir_context) { + // The shape does not change. + if (input_dims.size() == 1 && output_dims.size() == 1) { + absl::c_copy(output_dims_exprs, std::back_inserter(*exprs)); + return; + } + // Expand shape. + if (input_dims.size() == 1) { + exprs->push_back( + LinearizeShape(output_dims, output_dims_exprs, mlir_context)); + return; + } + // Collapse shape. + if (output_dims.size() == 1) { + auto multi_index = + DelinearizeIndex(input_dims, output_dims_exprs.front(), mlir_context); + absl::c_copy(multi_index, std::back_inserter(*exprs)); + return; + } + // Generic case. + AffineExpr linear_index = + LinearizeShape(output_dims, output_dims_exprs, mlir_context); + auto multi_index = DelinearizeIndex(input_dims, linear_index, mlir_context); + absl::c_copy(multi_index, std::back_inserter(*exprs)); +} + +// Scans input and output shapes from left to right in an attempt to find +// subshapes with the same number of elements and then computes indexing map for +// every pair of subshapes. +// +// Example: +// p0 = f32[4, 8, 12] parameter(0) +// reshape = f32[32, 3, 4] reshape(p0) +// +// This reshape can be represented as a composition of two reshapes. +// The first reshape collapses dimensions first two input dimensions [4, 8] onto +// the output dimension [32]. +// The second reshape expands the input dimension [12] into two output +// dimensions [3, 4]. +// This is an optimization that allows us to construct simpler affine maps, +// otherwise we would need to linearize/delinearize even some of the simpler +// cases. +AffineMap ComputeReshapeIndexingMap(const Shape& input, const Shape& output, + MLIRContext* mlir_context) { + absl::Span input_dims = input.dimensions(); + absl::Span output_dims = output.dimensions(); + + std::vector exprs; + exprs.reserve(input.rank()); + + // If the input shape has no elements (e.g. 1000x10x0 -> 100x100x0), just set + // everything to 0. + if (ShapeUtil::ElementsIn(input) == 0) { + for (int i = 0; i < input.rank(); ++i) { + exprs.push_back(getAffineConstantExpr(0, mlir_context)); + } + return AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context); + } + + std::vector output_dims_exprs; + + // Find subshapes with the same element count and compute indexing for them. + int64_t input_num_elements = 1; + int64_t output_num_elements = 1; + std::vector input_subshape, output_subshape; + size_t input_dim_id = 0, output_dim_id = 0; + while (input_dim_id < input.rank() || output_dim_id < output.rank() || + !input_subshape.empty()) { + if (input_dim_id < input.rank() && + (input_subshape.empty() || input_num_elements < output_num_elements || + input_dims[input_dim_id] == 1)) { + input_num_elements *= input_dims[input_dim_id]; + input_subshape.push_back(input_dims[input_dim_id]); + ++input_dim_id; + continue; + } + if (output_dim_id < output.rank() && + (output_subshape.empty() || output_num_elements < input_num_elements || + output_dims[output_dim_id] == 1)) { + output_num_elements *= output_dims[output_dim_id]; + output_subshape.push_back(output_dims[output_dim_id]); + output_dims_exprs.push_back( + getAffineDimExpr(output_dim_id, mlir_context)); + ++output_dim_id; + continue; + } + ComputeMinimalReshapeIndexing(input_subshape, output_subshape, + output_dims_exprs, &exprs, mlir_context); + input_num_elements = 1; + output_num_elements = 1; + input_subshape.clear(); + output_subshape.clear(); + output_dims_exprs.clear(); + } + return AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context); +}; + +HloInstructionIndexing ComputeOutputToInputReshapeOpIndexing( + const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { + const auto& input = reshape->operand(0)->shape(); + const auto& output = reshape->shape(); + + IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(input, output, mlir_context), + output.dimensions(), {}); + reshape_indexing_map.Simplify(GetIndexingMapForInstruction); + return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); +} +HloInstructionIndexing ComputeInputToOutputReshapeOpIndexing( + const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { + const auto& input = reshape->operand(0)->shape(); + const auto& output = reshape->shape(); + + IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(output, input, mlir_context), + input.dimensions(), {}); + reshape_indexing_map.Simplify(GetIndexingMapForInstruction); + return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); +} + +HloInstructionIndexing ComputeReverseOpIndexing( + const HloReverseInstruction* reverse, MLIRContext* mlir_context) { + absl::flat_hash_set reverse_dims(reverse->dimensions().begin(), + reverse->dimensions().end()); + auto output_dims = reverse->shape().dimensions(); + + std::vector exprs; + exprs.reserve(output_dims.size()); + for (auto [output_dim_id, output_dim] : llvm::enumerate(output_dims)) { + auto dim_expr = getAffineDimExpr(output_dim_id, mlir_context); + if (!reverse_dims.contains(output_dim_id)) { + exprs.push_back(dim_expr); + continue; + } + exprs.push_back(-dim_expr + output_dim - 1); + } + + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, + mlir_context), + output_dims, {}); + + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); +} + +HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( + const HloSliceInstruction* slice, MLIRContext* mlir_context) { + auto output_rank = slice->shape().rank(); + + std::vector exprs; + exprs.reserve(output_rank); + for (int64_t dim = 0; dim < output_rank; ++dim) { + AffineExpr dim_expr = getAffineDimExpr(dim, mlir_context); + exprs.push_back(dim_expr * slice->slice_strides()[dim] + + slice->slice_starts()[dim]); + } + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), + slice->shape().dimensions(), {}); + return HloInstructionIndexing::FromIndexingMaps({indexing_map}); +} + +AffineMap ComputeTransposeIndexingMap(absl::Span permutation, + MLIRContext* mlir_context) { + return AffineMap::getPermutationMap( + std::vector(permutation.begin(), permutation.end()), + mlir_context); +} + +HloInstructionIndexing ComputeOutputToInputTransposeOpIndexing( + const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { + AffineMap inverse_permutation = ComputeTransposeIndexingMap( + InversePermutation(transpose->dimensions()), mlir_context); + return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( + inverse_permutation, transpose->shape().dimensions(), {})}); +} + +HloInstructionIndexing ComputeInputToOutputTransposeOpIndexing( + const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { + AffineMap forward_permutation = + ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); + return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( + forward_permutation, transpose->operand(0)->shape().dimensions(), {})}); +} + +} // namespace + +IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, + MLIRContext* mlir_context) { + ShapeUtil::BitcastDecomposition decomposed_bitcast = + ShapeUtil::DecomposeBitcast(input_shape, output_shape); + + if (std::holds_alternative( + decomposed_bitcast)) { + auto permutation = ShapeUtil::DeduceTransposeDimensionsForBitcast( + input_shape, output_shape); + CHECK(permutation.has_value()) + << "Failed to deduce permutation for a bitcast."; + + return IndexingMap::FromTensorSizes( + ComputeTransposeIndexingMap(permutation.value(), mlir_context), + input_shape.dimensions(), {}); + } + if (std::holds_alternative( + decomposed_bitcast)) { + // Note: ComputeReshapeIndexingMap assumes it's computing an output->input + // indexing, so input and output are reversed. + return IndexingMap::FromTensorSizes( + ComputeReshapeIndexingMap(output_shape, input_shape, mlir_context), + input_shape.dimensions(), {}); + } + // `trt` stands for transpose-reshape-transpose decomposition of bitcast. + auto trt = std::get(decomposed_bitcast); + auto transpose_map_1 = + ComputeTransposeIndexingMap(trt.transpose1_dims, mlir_context); + auto reshape_map = ComputeReshapeIndexingMap( + trt.reshape_shape, trt.transpose1_shape, mlir_context); + auto transpose_map_2 = + ComputeTransposeIndexingMap(trt.transpose2_dims, mlir_context); + auto bitcast_map = + transpose_map_2.compose(reshape_map).compose(transpose_map_1); + return IndexingMap::FromTensorSizes(bitcast_map, input_shape.dimensions(), + {}); +} + +namespace { + +HloInstructionIndexing ComputeOutputToInputBitcastOpIndexing( + const HloInstruction* bitcast, MLIRContext* mlir_context) { + auto bitcast_map = GetBitcastMap(bitcast->shape(), + bitcast->operand(0)->shape(), mlir_context); + bitcast_map.Simplify(GetIndexingMapForInstruction); + return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); +} + +HloInstructionIndexing ComputeInputToOutputBitcastOpIndexing( + const HloInstruction* bitcast, MLIRContext* mlir_context) { + auto bitcast_map = GetBitcastMap(bitcast->operand(0)->shape(), + bitcast->shape(), mlir_context); + bitcast_map.Simplify(GetIndexingMapForInstruction); + return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); +} + +// Converts a layout to a dimensions transposition necessary to get to that +// layout from identity. +std::vector ToTransposeDimensions(const Layout& l) { + std::vector out(l.minor_to_major().begin(), + l.minor_to_major().end()); + absl::c_reverse(out); + return out; +} + +AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, + const Tiling& tiling) { + return AffineMap::get( + /*dimCount=*/6, /*symbolCount=*/tiling.GetShape().size(), exprs, + exprs[0].getContext()); +} + +} // namespace + +IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { + if (shape.IsTuple()) { + // Should happen only for variadic reduce. In that case all tuple shapes are + // equal. + return CreateIdentityMap(shape.tuple_shapes(0), mlir_context); + } + + auto dimensions = shape.dimensions(); + IndexingMap identity_map = IndexingMap::FromTensorSizes( + AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), + dimensions, {}); + return identity_map; +} + +llvm::SmallVector DelinearizeInBoundsIndex( + AffineExpr linear, absl::Span sizes, + absl::Span strides) { + llvm::SmallVector result; + result.reserve(sizes.size()); + if (absl::c_linear_search(sizes, 0)) { + for (int dim = 0; dim < sizes.size(); ++dim) { + result.push_back(mlir::getAffineConstantExpr(0, linear.getContext())); + } + return result; + } + + for (auto [size, stride] : llvm::zip(sizes, strides)) { + result.push_back(linear.floorDiv(stride) % size); + } + for (int dim = 0; dim < sizes.size(); ++dim) { + if (sizes[dim] > 1) { + // We assume the linear index is in bounds, so no mod for the first major + // non-degenerate dimension. Degenerate dimensions are already rewritten + // to 0 by operator%. + result[dim] = linear.floorDiv(strides[dim]); + break; + } + } + return result; +} + +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( + const Shape& shape, MLIRContext* mlir_context) { + if (shape.rank() == 0) { + return IndexingMap(AffineMap::get(mlir_context), + /*dim_vars=*/{}, /*range vars=*/{}, /*rt_vars=*/{}); + } + return IndexingMap::FromTensorSizes( + ComputeTransposeIndexingMap( + InversePermutation(ToTransposeDimensions(shape.layout())), + mlir_context), + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape) + .dimensions(), + {}); +} + +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( + const Shape& shape, MLIRContext* mlir_context) { + if (shape.rank() == 0) { + return IndexingMap(AffineMap::get(mlir_context), + /*dim_vars=*/{}, /*range vars=*/{}, /*rt_vars=*/{}); + } + return IndexingMap::FromTensorSizes( + ComputeTransposeIndexingMap(ToTransposeDimensions(shape.layout()), + mlir_context), + shape.dimensions(), {}); +} + +AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), + tiling.GetBlockCounts(), + tiling.GetBlockStrides()); + for (auto&& [offset, tile_size] : + llvm::zip(offsets, tiling.GetBlockTileSize())) { + offset = offset * tile_size; + } + return GetTilingAffineMap(offsets, tiling); +} + +AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), + tiling.GetThreadsPerBlock(), + tiling.GetThreadStrides()); + for (int dim = 0; dim < tiling.GetShape().size(); ++dim) { + if (tiling.GetThreadTileSize()[dim] > 1) { + offsets[dim] = offsets[dim] + getAffineSymbolExpr(dim, mlir_context) * + tiling.GetThreadsPerBlock()[dim]; + } + } + return GetTilingAffineMap(offsets, tiling); +} + +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + return GetIndexingMapForTiling( + GetBlockOffsetsForTiling(tiling, mlir_context), + GetThreadOffsetsForTiling(tiling, mlir_context), + tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(), + tiling.GetThreadTileSize(), tiling.GetShape()); +} + +IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, + AffineMap thread_offsets, + int64_t threads_per_block, + int64_t num_blocks, + absl::Span thread_tile_sizes, + absl::Span tiled_shape) { + auto* mlir_context = block_offsets.getContext(); + llvm::SmallVector offsets; + offsets.reserve(block_offsets.getNumResults()); + for (auto [block, thread] : + llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { + offsets.push_back(block + thread); + } + std::vector dimension_ranges{ + {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, + }; + auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), + block_offsets.getNumSymbols(), offsets, + mlir_context); + IndexingMap map{affine_map, dimension_ranges, + RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}}; + for (int i = 0; i < tiled_shape.size(); ++i) { + map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1}); + } + return map; +} + +bool HloInstructionIndexing::Simplify() { + bool any_simplified = false; + for (auto& operand_indexing : indexing_maps) { + std::vector to_remove, to_add; + for (IndexingMap map : operand_indexing) { + to_remove.push_back(map); + if (map.IsUndefined()) { + to_add.push_back(map); + } else if (map.Simplify(GetIndexingMapForInstruction)) { + map.RemoveUnusedSymbols(); + } else { + to_remove.pop_back(); + } + } + for (auto& map : to_remove) { + operand_indexing.erase(map); + } + for (auto& map : to_add) { + operand_indexing.insert(map); + } + any_simplified |= !to_remove.empty(); + } + return any_simplified; +} + +HloInstructionIndexing HloInstructionIndexing::FromIndexingMaps( + absl::Span indexing_maps) { + HloInstructionIndexing instr_indexing; + instr_indexing.indexing_maps.resize(indexing_maps.size()); + for (const auto& [index, map] : llvm::enumerate(indexing_maps)) { + instr_indexing.indexing_maps[index].insert(map); + } + return instr_indexing; +} + +std::string HloInstructionIndexing::ToString( + const AffineMapPrinter& printer) const { + std::string s; + std::stringstream ss(s); + Print(ss, printer); + return ss.str(); +} + +void HloInstructionIndexing::Print(std::ostream& out, + const AffineMapPrinter& printer) const { + for (const auto& [operand_id, indexing_maps] : + llvm::enumerate(indexing_maps)) { + out << "operand id = " << operand_id << ' '; + for (const auto& indexing_map : indexing_maps) { + if (indexing_map.IsUndefined()) { + out << "unknown indexing"; + continue; + } + indexing_map.Print(out, printer); + } + } +} + +std::ostream& operator<<(std::ostream& out, + const HloInstructionIndexing& instr_indexing) { + AffineMapPrinter printer; + instr_indexing.Print(out, printer); + return out; +} + +const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id) { + return instr->shape().IsTuple() + ? ShapeUtil::GetSubshape(instr->shape(), {output_id}) + : instr->shape(); +} + +GroupedByOpIndexingMap GroupIndexingMapsByProducers( + const HloInstructionIndexing& indexing, const HloInstruction* instr) { + GroupedByOpIndexingMap result; + for (const auto& [operand_id, indexing_maps] : + llvm::enumerate(indexing.indexing_maps)) { + result[instr->operand(operand_id)].insert(indexing_maps.begin(), + indexing_maps.end()); + } + return result; +} + +GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( + const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, + MLIRContext* ctx) { + auto initial_map = CreateIdentityMap(target_instr.instruction().shape(), ctx); + + GroupedByOpIndexingMap grouped_indexing_maps; + // If target_instr is a parameter of a fusion, then we create an identity map + // for the fusion operand. + if (fusion_adaptor.ContainsInstruction(target_instr)) { + if (auto parameter_instr = + DynCast(&target_instr.instruction())) { + const HloInstruction* user = parameter_instr->users().front(); + auto fusion_operand = HloInstructionAdaptor(*user).GetOperand( + parameter_instr->parameter_number()); + grouped_indexing_maps[&fusion_operand.instruction()] = {initial_map}; + return grouped_indexing_maps; + } + } + grouped_indexing_maps[&target_instr.instruction()].insert(initial_map); + + auto post_order = fusion_adaptor.MakeInstructionPostOrder(); + + // Iterator in reversed post-order (use-before-def). + auto it = std::find(post_order.rbegin(), post_order.rend(), target_instr); + for (; it != post_order.rend(); ++it) { + auto producer_indexing = ComputeOutputToInputIndexing(&it->instruction(), + /*output_id=*/0, ctx); + auto consumer_indexing_maps = + grouped_indexing_maps.find(&it->instruction()); + if (consumer_indexing_maps == grouped_indexing_maps.end()) { + continue; + } + // Indexing maps have to be copied because of rehashing. Consider using a + // different container to get better performance. + IndexingMapSet consumer_indexing_maps_copy = consumer_indexing_maps->second; + for (const auto& [producer_operand_id, producer_operand_indexing] : + llvm::enumerate(producer_indexing.indexing_maps)) { + auto producer_operand_adaptor = it->GetOperand(producer_operand_id); + for (const IndexingMap& producer_map : producer_operand_indexing) { + for (const IndexingMap& consumer_map : consumer_indexing_maps_copy) { + auto composed_map = ComposeIndexingMaps(consumer_map, producer_map); + composed_map.Simplify(GetIndexingMapForInstruction); + composed_map.RemoveUnusedSymbols(); + grouped_indexing_maps[&producer_operand_adaptor.instruction()].insert( + composed_map); + } + } + } + } + return grouped_indexing_maps; +} + +bool FuseProducerConsumerOutputToInputIndexing( + const HloInstruction* producer_instr, + absl::flat_hash_map* + consumer_indexing, + MLIRContext* mlir_context) { + auto producer_indexing = ComputeOutputToInputIndexing( + producer_instr, /*output_id=*/0, mlir_context); + auto consumer_indexing_maps = (*consumer_indexing)[producer_instr]; + for (const auto& [producer_operand_id, producer_operand_indexing] : + llvm::enumerate(producer_indexing.indexing_maps)) { + const HloInstruction* producer_operand_instr = + producer_instr->operand(producer_operand_id); + for (const IndexingMap& producer_map : producer_operand_indexing) { + for (const IndexingMap& consumer_map : consumer_indexing_maps) { + (*consumer_indexing)[producer_operand_instr].insert( + ComposeIndexingMaps(producer_map, consumer_map)); + } + } + } + consumer_indexing->erase(producer_instr); + return true; +} + +HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, + int output_id, + MLIRContext* ctx) { + if (HloInstruction::IsOpElementwise(instr->opcode())) { + return ComputeOutputToInputCwiseOpIndexing(instr, ctx); + } + if (instr->opcode() == HloOpcode::kBitcast) { + return ComputeOutputToInputBitcastOpIndexing(instr, ctx); + } + if (auto broadcast = DynCast(instr)) { + return ComputeOutputToInputBroadcastOpIndexing(broadcast, ctx); + } + if (auto concat = DynCast(instr)) { + return ComputeOutputToInputConcatenateOpIndexing(concat, ctx); + } + if (auto constant = DynCast(instr)) { + return HloInstructionIndexing{}; + } + if (auto dot = DynCast(instr)) { + return ComputeOutputToInputDotOpIndexing(dot, ctx); + } + if (auto dynamic_slice = DynCast(instr)) { + return ComputeOutputToInputDynamicSliceOpIndexing(dynamic_slice, ctx); + } + if (auto dus = DynCast(instr)) { + return ComputeOutputToInputDynamicUpdateSliceOpIndexing(dus, ctx); + } + if (auto fusion = DynCast(instr)) { + return ComputeOutputToInputFusionOpIndexing(fusion, output_id, ctx); + } + if (auto gather = DynCast(instr)) { + return ComputeOutputToInputGatherOpIndexing(gather, ctx); + } + if (auto iota = DynCast(instr)) { + return HloInstructionIndexing{}; + } + if (auto pad = DynCast(instr)) { + return ComputeOutputToInputPadOpIndexing(pad, ctx); + } + if (auto reduce = DynCast(instr)) { + return ComputeOutputToInputReduceOpIndexing(reduce, output_id, ctx); + } + if (auto reduce_window = DynCast(instr)) { + return ComputeOutputToInputReduceWindowOpIndexing(reduce_window, output_id, + ctx); + } + if (auto convolution = DynCast(instr)) { + return ComputeOutputToInputConvolutionOpIndexing(convolution, ctx); + } + if (auto reshape = DynCast(instr)) { + return ComputeOutputToInputReshapeOpIndexing(reshape, ctx); + } + if (auto reverse = DynCast(instr)) { + return ComputeReverseOpIndexing(reverse, ctx); + } + if (auto slice = DynCast(instr)) { + return ComputeOutputToInputSliceOpIndexing(slice, ctx); + } + if (auto transpose = DynCast(instr)) { + return ComputeOutputToInputTransposeOpIndexing(transpose, ctx); + } + // If we cannot compute output-to-input indexing, we return std::nullopt for + // every op parameter. + return CreateUnknownIndexing(instr->operand_count()); +} + +HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, + int input_id, + MLIRContext* ctx) { + if (HloInstruction::IsOpElementwise(instr->opcode())) { + return ComputeInputToOutputCwiseOpIndexing(instr, ctx); + } + if (instr->opcode() == HloOpcode::kBitcast) { + return ComputeInputToOutputBitcastOpIndexing(instr, ctx); + } + if (auto broadcast = DynCast(instr)) { + return ComputeInputToOutputBroadcastOpIndexing(broadcast, ctx); + } + if (auto concat = DynCast(instr)) { + return ComputeInputToOutputConcatenateOpIndexing(concat, input_id, ctx); + } + if (auto reduce = DynCast(instr)) { + return ComputeInputToOutputReduceOpIndexing(reduce, input_id, ctx); + } + if (auto reshape = DynCast(instr)) { + return ComputeInputToOutputReshapeOpIndexing(reshape, ctx); + } + if (auto reverse = DynCast(instr)) { + return ComputeReverseOpIndexing(reverse, ctx); + } + if (auto transpose = DynCast(instr)) { + return ComputeInputToOutputTransposeOpIndexing(transpose, ctx); + } + if (instr->opcode() == HloOpcode::kTuple) { + return HloInstructionIndexing::FromIndexingMaps( + {CreateIdentityMap(instr->shape().tuple_shapes(input_id), ctx)}); + } + // If we cannot compute input-to-output indexing, we return std::nullopt for + // every op result. + int64_t num_results = + instr->shape().IsTuple() ? instr->shape().tuple_shapes_size() : 1; + return CreateUnknownIndexing(num_results); +} + +IndexingMap ComputeEpilogueInputToOutputIndexing( + const HloInstruction* epilogue_root, MLIRContext* mlir_context, + std::function is_root) { + auto* instr = epilogue_root; + auto root_indexing = CreateIdentityMap(instr->shape(), mlir_context); + while (!is_root(instr)) { + // There can be multiple users, but they must have compatible indexing maps. + auto* user = instr->users().front(); + auto user_indexing = ComputeInputToOutputIndexing( + user, user->operand_index(instr), mlir_context); + root_indexing = root_indexing * *user_indexing.indexing_maps[0].begin(); + root_indexing.Simplify(GetIndexingMapForInstruction); + instr = user; + } + return root_indexing; +} + +IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, + int64_t operand_idx, + mlir::MLIRContext* mlir_context) { + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, operand_idx, mlir_context); + return *indexing.indexing_maps[0].begin(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_analysis.h b/xla/service/gpu/model/indexing_analysis.h new file mode 100644 index 0000000000000..22012ea472f88 --- /dev/null +++ b/xla/service/gpu/model/indexing_analysis.h @@ -0,0 +1,191 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" + +namespace xla { +namespace gpu { + +using IndexingMapSet = absl::flat_hash_set; + +// Contains indexing maps for all N-dimensional tensor input operands that +// correspond to a particular output. +struct HloInstructionIndexing { + std::string ToString( + const AffineMapPrinter& printer = AffineMapPrinter()) const; + void Print(std::ostream& out, const AffineMapPrinter& printer) const; + + // Returns true if the indexing was simplified. + bool Simplify(); + + // Creates a HloInstructionIndexing from a list of indexing maps for all + // operands and sorted w.r.t. operand index, i.e. indexing_maps[i] corresponds + // to operand[i] of the instruction. + static HloInstructionIndexing FromIndexingMaps( + absl::Span indexing_maps); + + // Maps input operand index to the indexing map for one particular output. + std::vector indexing_maps; +}; +std::ostream& operator<<(std::ostream& out, + const HloInstructionIndexing& instr_indexing); + +std::string ToString(const mlir::AffineMap& affine_map); + +// Computes indexing maps for all input operands necessary to compute an element +// of the `output_id` instruction output. +HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, + int output_id, + mlir::MLIRContext* ctx); + +// Computes indexing maps for all output operands that the element of the +// `input_id` instruction input will participate in. +HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, + int input_id, + mlir::MLIRContext* ctx); + +// Computes the indexing for `epilogue_parent`'s epilogue. For example, if +// `epilogue_parent` is a transpose, computes the input to output indexing for +// everything below the transpose. +// +// transpose +// | +// bitcast +// | +// ROOT +// +// Here, the result will be the input to output indexing for the bitcast. +// `epilogue_root` may be identical to the root of the fusion (if there is no +// epilogue). In this case, the result is the identity indexing map. +// Note: this function assumes the epilogue is compatible with +// FindNonTrivialHero, i.e., each instruction in the epilogue only has a single +// user, or the users have identical indexing maps. +IndexingMap ComputeEpilogueInputToOutputIndexing( + const HloInstruction* epilogue_root, mlir::MLIRContext* ctx, + std::function is_root = + [](const HloInstruction* instr) { return instr->IsRoot(); }); + +using GroupedByOpIndexingMap = + absl::flat_hash_map; + +// Computes output-to-input indexing for every instruction within a fusion +// cluster starting with `target_instr` and going from def to use. +GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( + const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, + mlir::MLIRContext* ctx); + +// Groups indexing maps by instructions. +absl::flat_hash_map +GroupIndexingMapsByProducers(const HloInstructionIndexing& indexing, + const HloInstruction* instr); + +// Computes producer indexing maps and fuse/compose them with the consumer +// indexing maps. +bool FuseProducerConsumerOutputToInputIndexing( + const HloInstruction* producer_instr, + absl::flat_hash_map* + consumer_indexing, + mlir::MLIRContext* mlir_context); + +// Creates an indexing map for bitcasting from `input_shape` to `output_shape`. +// Equivalent to linearizing the input_shape index and then delinearizing it +// to output_shape. +IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, + mlir::MLIRContext* ctx); + +// Creates an indexing map from the physical layout of the tensor to its logical +// layout. +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(const Shape& shape, + mlir::MLIRContext* ctx); + +// Creates an indexing map from the logical layout of the tensor to its physical +// layout. +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(const Shape& shape, + mlir::MLIRContext* ctx); + +// Creates an indexing map from thread and block IDs to elements of the tiled +// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 +// are thread indices (currently only 0 is used), dimensions 3 to 5 are block +// indices (currently only 3 is used). +mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); +mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); + +// Convenience functions for the two functions above +// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up +// the ranges of dimensions and symbols. +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + mlir::MLIRContext* ctx); +IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, + mlir::AffineMap thread_offsets, + int64_t threads_per_block, + int64_t num_blocks, + absl::Span thread_tile_sizes, + absl::Span tiled_shape); + +// Returns the shape of the output of the instruction. +const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id); + +// Computes 1D index given a shape and N-d indexing expressions. +mlir::AffineExpr LinearizeShape( + absl::Span dims, + absl::Span dimension_exprs, + mlir::MLIRContext* mlir_context); + +// Computes N-d indexing expressions given a linear index and a shape. +std::vector DelinearizeIndex(absl::Span dims, + mlir::AffineExpr linear_index, + mlir::MLIRContext* mlir_context); + +// Creates an identity indexing map corresponding to the parameter shape. +IndexingMap CreateIdentityMap(const Shape& shape, + mlir::MLIRContext* mlir_context); + +llvm::SmallVector DelinearizeInBoundsIndex( + mlir::AffineExpr linear, absl::Span sizes, + absl::Span strides); + +// Returns the output-to-input indexing map of the first output of `instr` +IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, + int64_t operand_idx, + mlir::MLIRContext* mlir_context); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_ diff --git a/xla/service/gpu/model/indexing_analysis_test.cc b/xla/service/gpu/model/indexing_analysis_test.cc new file mode 100644 index 0000000000000..55b4b18c53242 --- /dev/null +++ b/xla/service/gpu/model/indexing_analysis_test.cc @@ -0,0 +1,2616 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_analysis.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::ExplainMatchResult; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +MATCHER_P2(MatchInstrIndexing, operand_id, indexing_map_matchers, "") { + return ExplainMatchResult(Eq(operand_id), arg.operand_id, result_listener) && + ExplainMatchResult(indexing_map_matchers, arg.indexing_maps, + result_listener); +} + +using IndexingAnalysisTest = IndexingTestBase; + +TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) + } + )"); + const HloInstruction* parameter = root->operand(0); + const HloInstruction* transpose = root->operand(1); + + auto root_indexing = GetOutputToInputIndexing(root); + auto grouped_by_key = GroupIndexingMapsByProducers(root_indexing, root); + + EXPECT_THAT( + grouped_by_key, + UnorderedElementsAre(Pair(parameter, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))), + Pair(transpose, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))))); +} + +TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) + } + )"); + const HloInstruction* parameter = root->operand(0); + const HloInstruction* transpose = root->operand(1); + + auto fusion_adaptor = ProducerConsumerFusion(transpose, root); + + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + fusion_adaptor, fusion_adaptor.GetRoots()[0], &mlir_context_); + EXPECT_THAT(grouped_indexing, + UnorderedElementsAre( + Pair(root, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))), + Pair(transpose, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))), + Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"), + MatchIndexingMap(R"( + (d0, d1) -> (d1, d0) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))))); +} + +TEST_F(IndexingAnalysisTest, + ComputeGroupedOutputToInputIndexing_VariadicReduce) { + auto root = ParseAndGetRoot(R"( + HloModule m + + add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add.0 = f32[] add(param_0, param_2) + add.1 = f32[] add(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) + } + + ENTRY entry_computation { + param_0.3 = f32[32,40]{1,0} parameter(0) + param_1.3 = f32[32,40]{1,0} parameter(1) + param_2.2 = f32[] parameter(2) + constant = f32[] constant(0) + ROOT reduce = (f32[32]{0}, f32[32]{0}) + reduce(param_0.3, param_1.3, param_2.2, constant), + dimensions={1}, to_apply=add + } + )"); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, fusion_adaptor->GetRoots()[0], &mlir_context_); + + EXPECT_THAT(grouped_indexing, + UnorderedElementsAre( + Pair(root, ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 31] + )"))), + Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 31] + s0 in [0, 39] + )"))), + Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 31] + s0 in [0, 39] + )"))), + Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 31] + )"))), + Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 31] + )"))))); +} + +TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + p1 = f32[1000, 1000] parameter(1) + exp0 = f32[1000, 1000] exponential(p1) + ROOT a0 = f32[1000, 1000] add(p0, exp0) + } + )"); + HloComputation* entry_computation = root->parent(); + const HloInstruction* exponential = + entry_computation->GetInstructionWithName("exp0"); + const HloInstruction* parameter = + entry_computation->GetInstructionWithName("p1"); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(exponential); + HloInstructionAdaptor parameter_adaptor(*parameter); + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, parameter_adaptor, &mlir_context_); + EXPECT_THAT(grouped_indexing, UnorderedElementsAre(Pair( + parameter, ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"))))); +} + +TEST_F(IndexingAnalysisTest, + ComputeGroupedOutputToInputIndexing_StartNotAtRoot) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] parameter(1) + p0_bcast = f32[15, 32, 20, 64] broadcast(p0), dimensions={0, 2} + + ROOT reduce_2 = f32[15, 64] reduce(p0_bcast, p0_init), + dimensions={1, 2}, to_apply=max + } + ENTRY e { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f + } + )"); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + auto root_adaptor = fusion_adaptor->GetRoots()[0]; + + auto bcast = root_adaptor.GetOperand(0); + auto parameter_0 = bcast.GetOperand(0); + + auto grouped_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, bcast, &mlir_context_); + EXPECT_THAT( + grouped_indexing, + UnorderedElementsAre( + Pair(&bcast.instruction(), ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d1, d2, d3) + domain: + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] + )"))), + Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d2) + domain: + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] + )"))))); +} + +TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[10, 20, 30] parameter(0) + ROOT add0 = f32[10, 20, 30]{1, 0, 2} exponential(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root, /*output_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d1, d2, d0) + domain: + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d2, d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] + )")))); +} + +TEST_F(IndexingAnalysisTest, CopyNothing) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[0, 0]{0,1} parameter(0) + ROOT copy0 = f32[0, 0]{1,0} copy(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root, /*output_id=*/0); + input_indexing.Simplify(); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, -1] + d1 in [0, -1] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0); + output_indexing.Simplify(); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, -1] + d1 in [0, -1] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeNothing) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,0,0] parameter(0) + ROOT reshape = f32[0] reshape(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root, /*output_id=*/0); + input_indexing.Simplify(); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> (0, 0, 0) + domain: + d0 in [0, -1] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0); + output_indexing.Simplify(); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (0) + domain: + d0 in [0, 0] + d1 in [0, -1] + d2 in [0, -1] + )")))); +} + +TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[10, 20, 30]{1, 0, 2} parameter(0) + ROOT add0 = f32[10, 20, 30] exponential(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root, /*output_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d2, d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d1, d2, d0) + domain: + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] + )")))); +} + +TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[10, 20, 30]{1, 0, 2} parameter(0) + ROOT add0 = f32[10, 20, 30]{1, 0, 2} exponential(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root, /*output_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, + /*use_physical_layout=*/true); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] + )")))); +} + +TEST_F(IndexingAnalysisTest, ElementwiseOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[10, 20] parameter(0) + p1 = f32[10, 20] parameter(1) + ROOT add0 = f32[10, 20] add(p0, p1) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + )")))); + + auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); + EXPECT_THAT(output_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + )")))); + + auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); + EXPECT_THAT(output_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + )")))); +} + +TEST_F(IndexingAnalysisTest, BitcastIsReshape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4, 32] parameter(0) + ROOT bitcast = f32[4, 8, 4] bitcast(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 4 + d2) + domain: + d0 in [0, 3] + d1 in [0, 7] + d2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d3, d1, d2) + domain: + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] + )")))); +} + +TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[16, 17, 3] parameter(0) + ROOT bitcast = f32[51, 16] {0, 1} bitcast(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3) + domain: + d0 in [0, 50] + d1 in [0, 15] + )")))); + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d1 * 3 + d2, d0) + domain: + d0 in [0, 15] + d1 in [0, 16] + d2 in [0, 2] + )")))); +} + +TEST_F(IndexingAnalysisTest, BroadcastOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[20] parameter(0) + ROOT bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d1) + domain: + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0)[s0, s1] -> (s0, d0, s1) + domain: + d0 in [0, 19] + s0 in [0, 9] + s1 in [0, 29] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConstantOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + ROOT c1 = bf16[17, 22] constant(1) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, IsEmpty()); +} + +TEST_F(IndexingAnalysisTest, ConcatenateOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[2, 5, 7] parameter(0) + p1 = f32[2, 11, 7] parameter(1) + p2 = f32[2, 17, 7] parameter(2) + ROOT concat = f32[2, 33, 7] concatenate( + f32[2, 5, 7] p0, f32[2, 11, 7] p1, f32[2, 17, 7] p2), dimensions={1} + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 1] + d1 in [0, 4] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 - 5, d2) + domain: + d0 in [0, 1] + d1 in [5, 15] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 - 16, d2) + domain: + d0 in [0, 1] + d1 in [16, 32] + d2 in [0, 6] + )")))); + + auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); + EXPECT_THAT(output_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 1] + d1 in [0, 4] + d2 in [0, 6] + )")))); + + auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); + EXPECT_THAT(output_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 + 5, d2) + domain: + d0 in [0, 1] + d1 in [0, 10] + d2 in [0, 6] + )")))); + + auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); + EXPECT_THAT(output_indexing_2.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 + 16, d2) + domain: + d0 in [0, 1] + d1 in [0, 16] + d2 in [0, 6] + )")))); +} + +TEST_F(IndexingAnalysisTest, DynamicSliceOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + %src = s32[2,2,258] parameter(0) + %of1 = s32[] parameter(1) + %of2 = s32[] parameter(2) + %of3 = s32[] parameter(3) + ROOT %ds = s32[1,2,32] dynamic-slice(s32[2,2,258] %src, + s32[] %of1, s32[] %of2, s32[] %of3), + dynamic_slice_sizes={1, 2, 32} + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2) + domain: + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] + s0 in [0, 1] + hlo: %of1 = s32[] parameter(1) + (d0, d1, d2) -> () + s1 in [0, 0] + hlo: %of2 = s32[] parameter(2) + (d0, d1, d2) -> () + s2 in [0, 226] + hlo: %of3 = s32[] parameter(3) + (d0, d1, d2) -> () + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] + )")))); +} + +TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + %src = s32[20,30] parameter(0) + %upd = s32[5,10] parameter(1) + %of1 = s32[] parameter(2) + %of2 = s32[] parameter(3) + ROOT %dus = s32[20,30] dynamic-update-slice( + s32[20,30] %src, s32[5,10] %upd, s32[] %of1, s32[] %of2) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 19] + d1 in [0, 29] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1) + domain: + d0 in [0, 19] + d1 in [0, 29] + s0 in [0, 15] + hlo: %of1 = s32[] parameter(2) + (d0, d1) -> () + s1 in [0, 20] + hlo: %of2 = s32[] parameter(3) + (d0, d1) -> () + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 19] + d1 in [0, 29] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 19] + d1 in [0, 29] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT a0 = f32[100] add(p0, p1) + } + ENTRY e { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT fusion = f32[100] fusion(p0, p1), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 99] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 99] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithDot) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + f { + p0 = s8[3,12288,6,128]{3,2,1,0} parameter(0) + bitcast1 = s8[3,6,128,12288]{2,1,3,0} bitcast(p0) + copy1 = s8[3,6,128,12288]{3,2,1,0} copy(bitcast1) + bitcast2 = s8[2304,12288]{1,0} bitcast(copy1) + convert1 = bf16[2304,12288]{1,0} convert(bitcast2) + bitcast3 = bf16[2304,16,768]{2,1,0} bitcast(convert1) + p3 = bf16[16,12288]{1,0} parameter(3) + convert2 = f32[16,12288]{1,0} convert(p3) + p4 = bf16[16,12288]{1,0} parameter(4) + convert3 = f32[16,12288]{1,0} convert(p4) + add1 = f32[16,12288]{1,0} add(convert2, convert3) + p2 = bf16[16]{0} parameter(2) + convert15 = f32[16]{0} convert(p2) + rsqrt = f32[16]{0} rsqrt(convert15) + convert4 = bf16[16]{0} convert(rsqrt) + bcast1 = bf16[16,12288]{1,0} broadcast(convert4), dimensions={0} + convert5 = f32[16,12288]{1,0} convert(bcast1) + multiply1 = f32[16,12288]{1,0} multiply(add1, convert5) + p1 = bf16[12288]{0} parameter(1) + convert6 = f32[12288]{0} convert(p1) + c1 = bf16[] constant(1) + bcast2 = bf16[12288]{0} broadcast(c1), dimensions={} + convert7 = f32[12288]{0} convert(bcast2) + add2 = f32[12288]{0} add(convert6, convert7) + convert8 = bf16[12288]{0} convert(add2) + bcast3 = bf16[16,12288]{1,0} broadcast(convert8), dimensions={1} + convert9 = f32[16,12288]{1,0} convert(bcast3) + multiply2 = f32[16,12288]{1,0} multiply(multiply1, convert9) + convert10 = bf16[16,12288]{1,0} convert(multiply2) + bcast4 = bf16[16,16,768]{2,1,0} bitcast(convert10) + dot = bf16[16,2304,16]{2,1,0} dot(bitcast3, bcast4), + lhs_batch_dims={1}, lhs_contracting_dims={2}, + rhs_batch_dims={1}, rhs_contracting_dims={2} + bcast5 = bf16[16,3,6,128,16]{4,3,2,1,0} bitcast(dot) + copy2 = bf16[16,3,6,128,16]{3,2,4,1,0} copy(bcast5) + convert13 = f32[16,3,6,128,16]{3,2,4,1,0} convert(copy2) + p5 = bf16[3,6,128]{2,1,0} parameter(5) + bcast6 = bf16[3,6,128,16]{2,1,3,0} broadcast(p5), dimensions={0,1,2} + convert11 = f32[3,6,128,16]{2,1,3,0} convert(bcast6) + bcast7 = f32[16,3,6,128,16]{3,2,4,1,0} broadcast(convert11), + dimensions={1,2,3,4} + multiply3 = f32[16,3,6,128,16]{3,2,4,1,0} multiply(convert13, bcast7) + convert12 = bf16[16,3,6,128,16]{3,2,4,1,0} convert(multiply3) + ROOT bcast8 = bf16[16,16,3,1,6,128]{5,4,1,3,2,0} bitcast(convert12) + } + ENTRY e { + p0 = s8[3,12288,6,128]{3,2,1,0} parameter(0) + p1 = bf16[12288]{0} parameter(1) + p2 = bf16[16]{0} parameter(2) + p3 = bf16[16,12288]{1,0} parameter(3) + p4 = bf16[16,12288]{1,0} parameter(4) + p5 = bf16[3,6,128]{2,1,0} parameter(5) + ROOT fusion = bf16[16,16,3,1,6,128]{5,4,1,3,2,0} + fusion(p0, p1, p2, p3, p4, p5), kind=kLoop, calls=f + } + )")); + + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> (d2, d0 * 768 + s0, d4, d5) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5) -> (d1) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5) + domain: + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + add_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + max_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + softmax { + p0 = f32[2,65,125]{2,1,0} parameter(0) + bitcast0 = f32[65,2,125]{2,1,0} bitcast(p0) + constant_neg_inf_1 = f32[] constant(-inf) + reduce0 = f32[2,65]{1,0} reduce(p0, constant_neg_inf_1), + dimensions={2}, to_apply=max_computation + bitcast1 = f32[130]{0} bitcast(reduce0) + bcast1 = f32[130,125]{1,0} broadcast(bitcast1), dimensions={0} + bitcast2 = f32[65,2,125]{2,1,0} bitcast(bcast1) + subtract0 = f32[65,2,125]{2,1,0} subtract(bitcast0, bitcast2) + exponential0 = f32[65,2,125]{2,1,0} exponential(subtract0) + bitcast3 = f32[65,2,125]{2,1,0} bitcast(p0) + reduce1 = f32[2,65]{1,0} reduce(p0, constant_neg_inf_1), + dimensions={2}, to_apply=max_computation + bitcast4 = f32[130]{0} bitcast(reduce1) + bcast2 = f32[130,125]{1,0} broadcast(bitcast4), dimensions={0} + bitcast5 = f32[65,2,125]{2,1,0} bitcast(bcast2) + subtract1 = f32[65,2,125]{2,1,0} subtract(bitcast3, bitcast5) + exponential1 = f32[65,2,125]{2,1,0} exponential(subtract1) + constant_zero_1 = f32[] constant(0) + reduce2 = f32[65,2]{1,0} reduce(exponential1, constant_zero_1), + dimensions={2}, to_apply=add_computation + bitcast6 = f32[130]{0} bitcast(reduce2) + bcast3 = f32[130,125]{1,0} broadcast(bitcast6), dimensions={0} + bitcast7 = f32[65,2,125]{2,1,0} bitcast(bcast3) + divide = f32[65,2,125]{2,1,0} divide(exponential0, bitcast7) + ROOT bitcast8 = f32[2,65,125]{2,1,0} bitcast(divide) + } + ENTRY e { + p0 = f32[2,65,125]{2,1,0} parameter(0) + ROOT fusion = f32[2,65,125]{2,1,0} + fusion(p0), kind=kLoop, calls=softmax + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( + (d0, d1, d2)[s0] -> (d0, d1, s0) + domain: + d0 in [0, 1] + d1 in [0, 64] + d2 in [0, 124] + s0 in [0, 124] + )"), + MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 1] + d1 in [0, 64] + d2 in [0, 124] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[1000, 1000] parameter(0) + transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) + } + ENTRY e { + p0 = f32[1000,1000] parameter(0) + ROOT fusion = f32[1000,1000] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )"), + MatchIndexingMap(R"( + (d0, d1) -> (d1, d0) + domain: + d0 in [0, 999] + d1 in [0, 999] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule test_module + + fused_computation { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + add0 = f32[4] add(p0, p1) + slice1.0 = f32[3] slice(add0), slice={[0:3]} + slice1.1 = f32[3] slice(add0), slice={[1:4]} + add1 = f32[3]{0} add(slice1.0, slice1.1) + slice2.0 = f32[2] slice(add1), slice={[0:2]} + slice2.1 = f32[2] slice(add1), slice={[1:3]} + ROOT add2 = f32[2] add(slice2.0, slice2.1) + } + + ENTRY entry_computation { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + ROOT fusion = f32[2] fusion(p0, p1), kind=kLoop, + calls=fused_computation + })")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( + (d0) -> (d0 + 1) + domain: + d0 in [0, 1] + )"), + MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 1] + )"), + MatchIndexingMap(R"( + (d0) -> (d0 + 2) + domain: + d0 in [0, 1] + )")), + UnorderedElementsAre(MatchIndexingMap(R"( + (d0) -> (d0 + 2) + domain: + d0 in [0, 1] + )"), + MatchIndexingMap(R"( + (d0) -> (d0 + 1) + domain: + d0 in [0, 1] + )"), + MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 1] + )")))); +} + +TEST_F(IndexingAnalysisTest, GatherOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY main { + operand = f32[33,76,70] parameter(0) + indices = s32[1806,2] parameter(1) + ROOT r = f32[1806,7,8,4] gather(operand, indices), offset_dims={1,2,3}, + collapsed_slice_dims={}, start_index_map={0,1}, + index_vector_dim=1, slice_sizes={7,8,4} + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3) + domain: + d0 in [0, 1805] + d1 in [0, 6] + d2 in [0, 7] + d3 in [0, 3] + s0 in [0, 26] + hlo: %indices = s32[1806,2]{1,0} parameter(1) + (d0, d1, d2, d3) -> (d0, 0) + s1 in [0, 68] + hlo: %indices = s32[1806,2]{1,0} parameter(1) + (d0, d1, d2, d3) -> (d0, 1) + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0] -> (d0, s0) + domain: + d0 in [0, 1805] + d1 in [0, 6] + d2 in [0, 7] + d3 in [0, 3] + s0 in [0, 1] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[150, 20, 10, 50] parameter(0) + p0_init = f32[] parameter(1) + reduce_1 = f32[20, 10] reduce(p0, p0_init), + dimensions={0, 3}, to_apply=max + ROOT reduce_2 = f32[10] reduce(reduce_1, p0_init), + dimensions={0}, to_apply=max + } + ENTRY e { + p0 = f32[150, 20, 10, 50] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[10] fusion(p0, p0_init), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s0, s2, d0, s1) + domain: + d0 in [0, 9] + s0 in [0, 149] + s1 in [0, 49] + s2 in [0, 19] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 9] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] parameter(1) + p0_bcast = f32[15, 32, 20, 64] broadcast(p0), dimensions={0, 2} + + ROOT reduce_2 = f32[15, 64] reduce(p0_bcast, p0_init), + dimensions={1, 2}, to_apply=max + } + ENTRY e { + p0 = f32[15, 20] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0] -> (d0, s0) + domain: + d0 in [0, 14] + d1 in [0, 63] + s0 in [0, 19] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 14] + d1 in [0, 63] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[20, 10, 50] parameter(0) + + lhs_transpose_1 = f32[10, 20, 50] + transpose(p0), dimensions={1, 0, 2} + lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1) + lhs_transpose_2 = f32[10, 50, 20] + transpose(lhs_e), dimensions={0, 2, 1} + + rhs_transpose_1 = f32[50, 10, 20] + transpose(p0), dimensions={2, 1, 0} + rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1) + rhs_transpose_2 = f32[10, 50, 20] + transpose(rhs_log), dimensions={1, 0, 2} + + ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2) + } + ENTRY e { + p0 = f32[20, 10, 50] parameter(0) + ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d2, d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 49] + d2 in [0, 19] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + f { + p0 = f32[150, 64, 1024] parameter(0) + p0_init = f32[] parameter(1) + p0_slice = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), + slice={[5:21:1], [0:64:2], [50:434:3]} + ROOT reduce = f32[32] reduce(p0_slice, p0_init), + dimensions={0, 2}, to_apply=max + } + ENTRY e { + p0 = f32[150, 64, 1024] parameter(0) + p0_init = f32[] constant(-inf) + ROOT fusion = f32[32] fusion(p0, p0_init), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50) + domain: + d0 in [0, 31] + s0 in [0, 15] + s1 in [0, 127] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 31] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[128] parameter(0) + expand = f32[8, 16] reshape(p0) + ROOT collapse = f32[128] reshape(expand) + } + ENTRY e { + p0 = f32[128] parameter(0) + ROOT fusion = f32[128] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0) + domain: + d0 in [0, 127] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[8, 16] parameter(0) + collapse = f32[128] reshape(p0) + ROOT expand = f32[8, 16] reshape(collapse) + } + ENTRY e { + p0 = f32[8, 16] parameter(0) + ROOT fusion = f32[8, 16] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 7] + d1 in [0, 15] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[10, 10, 10] parameter(0) + reshape1 = f32[50, 20] reshape(p0) + ROOT reshape2 = f32[10, 10, 10] reshape(reshape1) + } + ENTRY e { + p0 = f32[10, 10, 10] parameter(0) + ROOT fusion = f32[10, 10, 10] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 9] + d1 in [0, 9] + d2 in [0, 9] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[150, 64, 1024] parameter(0) + p0_slice_1 = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), + slice={[5:21:1], [0:64:2], [50:434:3]} + ROOT p0_slice_2 = f32[7, 9, 24] slice(f32[16, 32, 128] p0_slice_1), + slice={[3:16:2], [4:30:3], [5:100:4]} + } + ENTRY e { + p0 = f32[150, 64, 1024] parameter(0) + ROOT fusion = f32[7, 9, 24] fusion(p0), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 * 2 + 8, + d1 * 6 + 8, + d2 * 12 + 65) + domain: + d0 in [0, 6] + d1 in [0, 8] + d2 in [0, 23] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + %src = s32[150, 64] parameter(0) + %of11 = s32[] parameter(1) + %of12 = s32[] parameter(2) + %of21 = s32[] parameter(3) + %of22 = s32[] parameter(4) + + %ds1 = s32[50, 32] dynamic-slice(s32[150, 64] %src, + s32[] %of11, s32[] %of12), dynamic_slice_sizes={50, 32} + + ROOT %ds2 = s32[25, 16] dynamic-slice(s32[50, 32] %ds1, + s32[] %of21, s32[] %of22), dynamic_slice_sizes={25, 16} + } + ENTRY e { + %p0 = s32[150, 64] parameter(0) + %p1 = s32[] parameter(1) + %p2 = s32[] parameter(2) + %p3 = s32[] parameter(3) + %p4 = s32[] parameter(4) + ROOT fusion = s32[25, 16] fusion(p0, p1, p2, p3, p4), + kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1, s2, s3] -> (d0 + s0 + s2, d1 + s1 + s3) + domain: + d0 in [0, 24] + d1 in [0, 15] + s0 in [0, 100] + hlo: %of11 = s32[] parameter(1) + (d0, d1) -> () + s1 in [0, 32] + hlo: %of12 = s32[] parameter(2) + (d0, d1) -> () + s2 in [0, 25] + hlo: %of21 = s32[] parameter(3) + (d0, d1) -> () + s3 in [0, 16] + hlo: %of22 = s32[] parameter(4) + (d0, d1) -> () + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 24] + d1 in [0, 15] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 24] + d1 in [0, 15] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 24] + d1 in [0, 15] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 24] + d1 in [0, 15] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[2, 5, 7] parameter(0) + p1 = f32[2, 11, 7] parameter(1) + p2 = f32[2, 17, 7] parameter(2) + concat = f32[2, 33, 7] concatenate( + f32[2, 5, 7] p0, f32[2, 11, 7] p1, f32[2, 17, 7] p2), dimensions={1} + ROOT slice = f32[2, 11, 7] slice(f32[2, 33, 7] concat), + slice={[0:2:1], [0:33:3], [0:7:1]} + } + ENTRY e { + p0 = f32[2, 5, 7] parameter(0) + p1 = f32[2, 11, 7] parameter(1) + p2 = f32[2, 17, 7] parameter(2) + ROOT fusion = f32[2, 11, 7] fusion(p0, p1, p2), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 3, d2) + domain: + d0 in [0, 1] + d1 in [0, 1] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 3 - 5, d2) + domain: + d0 in [0, 1] + d1 in [2, 5] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 3 - 16, d2) + domain: + d0 in [0, 1] + d1 in [6, 10] + d2 in [0, 6] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[2, 5, 7] parameter(0) + p1 = f32[2, 11, 7] parameter(1) + p2 = f32[2, 17, 7] parameter(2) + concat = f32[2, 33, 7] concatenate( + f32[2, 5, 7] p0, f32[2, 11, 7] p1, f32[2, 17, 7] p2), dimensions={1} + ROOT slice = f32[2, 3, 7] slice(f32[2, 33, 7] concat), + slice={[0:2:1], [0:5:2], [0:7:1]} + } + ENTRY e { + p0 = f32[2, 5, 7] parameter(0) + p1 = f32[2, 11, 7] parameter(1) + p2 = f32[2, 17, 7] parameter(2) + ROOT fusion = f32[2, 3, 7] fusion(p0, p1, p2), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 2, d2) + domain: + d0 in [0, 1] + d1 in [0, 2] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 2 - 5, d2) + domain: + d0 in [0, 1] + d1 in [3, 2] + d2 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0, d1 * 2 - 16, d2) + domain: + d0 in [0, 1] + d1 in [8, 2] + d2 in [0, 6] + )")))); +} + +TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + f { + p0 = f32[2] parameter(0) + p1 = f32[30] parameter(1) + concat = f32[32] concatenate(f32[2] p0, f32[30] p1), dimensions={0} + ROOT reshape = f32[4, 8] reshape(concat) + } + ENTRY e { + p0 = f32[2] parameter(0) + p1 = f32[30] parameter(1) + ROOT fusion = f32[4, 8] fusion(p0, p1), kind=kLoop, calls=f + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 * 8 + d1) + domain: + d0 in [0, 3] + d1 in [0, 7] + d0 * 8 + d1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 * 8 + d1 - 2) + domain: + d0 in [0, 3] + d1 in [0, 7] + d0 * 8 + d1 in [2, 31] + )")))); +} + +TEST_F(IndexingAnalysisTest, IotaOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + ROOT iota = s32[5,5,111,42] iota(), iota_dimension=0 + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, IsEmpty()); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4,8] parameter(0) + ROOT reshape = f32[32] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> (d0 floordiv 8, d0 mod 8) + domain: + d0 in [0, 31] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[32] parameter(0) + ROOT reshape = f32[4, 8] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 * 8 + d1) + domain: + d0 in [0, 3] + d1 in [0, 7] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4, 8, 12] parameter(0) + ROOT reshape = f32[32, 3, 4] reshape(p0) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) + domain: + d0 in [0, 31] + d1 in [0, 2] + d2 in [0, 3] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) + domain: + d0 in [0, 3] + d1 in [0, 7] + d2 in [0, 11] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[16, 8] parameter(0) + ROOT reshape = f32[4, 4, 8] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 * 4 + d1, d2) + domain: + d0 in [0, 3] + d1 in [0, 3] + d2 in [0, 7] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4,8] parameter(0) + ROOT reshape = f32[2, 4, 4] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, + d2 + (d1 mod 2) * 4) + domain: + d0 in [0, 1] + d1 in [0, 3] + d2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[2, 4, 4] parameter(0) + ROOT reshape = f32[4, 8] reshape(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 floordiv 2, + d1 floordiv 4 + (d0 mod 2) * 2, + d1 mod 4) + domain: + d0 in [0, 3] + d1 in [0, 7] + )")))); +} + +TEST_F(IndexingAnalysisTest, PadOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0 + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> ( + (d0 - 1) floordiv 2, + d1 - 4 + ) + domain: + d0 in [1, 7] + d1 in [4, 7] + (d0 - 1) mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 11] + d1 in [0, 15] + )")))); +} + +TEST_F(IndexingAnalysisTest, PadOpNoInterior) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[2,8] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[10,8] pad(p0, p1), padding=1_7x0_0 + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 - 1, d1) + domain: + d0 in [1, 2] + d1 in [0, 7] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 9] + d1 in [0, 7] + )")))); +} + +TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { + // The interior padding is applied first (even with negative padding), so we + // get a size of 5 (7 + 6 - 8). + // in: 0 1 2 3 4 5 6 + // padded: 0 p 1 p 2 p 3 p 4 p 5 p 6 + // sliced: p 2 p 3 p + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[7] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[5] pad(p0, p1), padding=-3_-5_1 + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0) -> ((d0 + 3) floordiv 2) + domain: + d0 in [0, 4] + (d0 + 3) mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 4] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + p0 = f32[150, 20, 10, 50] parameter(0) + p0_init = f32[] constant(-inf) + ROOT reduce = f32[150, 10] reduce(p0, p0_init), + dimensions={3, 1}, to_apply=max + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0, s0, d1, s1) + domain: + d0 in [0, 149] + d1 in [0, 9] + s0 in [0, 19] + s1 in [0, 49] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 149] + d1 in [0, 9] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d2) + domain: + d0 in [0, 149] + d1 in [0, 19] + d2 in [0, 9] + d3 in [0, 49] + )")), + ElementsAre(MatchIndexingMap(R"( + ()[s0, s1] -> (s0, s1) + domain: + s0 in [0, 149] + s1 in [0, 9] + )")))); +} + +TEST_F(IndexingAnalysisTest, VariadicReduceOp) { + HloInstruction* root = ParseAndGetRoot(R"( + HloModule m + min { + tmp_0 = f32[] parameter(0) + tmp_1 = f32[] parameter(2) + tmp_2 = s32[] parameter(1) + tmp_3 = s32[] parameter(3) + cmp = pred[] compare(tmp_0, tmp_1), direction=GE + select1 = f32[] select(cmp, tmp_0, tmp_1) + select2 = s32[] select(cmp, tmp_2, tmp_3) + ROOT tmp_4 = (f32[], s32[]) tuple(select1, select2) + } + ENTRY e { + p0 = f32[256,10] parameter(0) + p0_init = f32[] constant(-inf) + p1 = s32[256,10] parameter(1) + p1_init = s32[] constant(0) + ROOT reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init), + dimensions={0}, to_apply=min + } + )"); + + auto output_indexing_0 = GetOutputToInputIndexing(root, /*output_id=*/0); + EXPECT_THAT(output_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (s0, d0) + domain: + d0 in [0, 9] + s0 in [0, 255] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (s0, d0) + domain: + d0 in [0, 9] + s0 in [0, 255] + )")), + + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 9] + )")), + + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 9] + )")))); + + auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); + EXPECT_THAT(output_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (s0, d0) + domain: + d0 in [0, 9] + s0 in [0, 255] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0)[s0] -> (s0, d0) + domain: + d0 in [0, 9] + s0 in [0, 255] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0) -> () + domain: + d0 in [0, 9] + )")))); + + auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); + + EXPECT_THAT(input_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d1) + domain: + d0 in [0, 255] + d1 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d1) + domain: + d0 in [0, 255] + d1 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 9] + )")))); + + auto input_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); + EXPECT_THAT(input_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d1) + domain: + d0 in [0, 255] + d1 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d1) + domain: + d0 in [0, 255] + d1 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 9] + )")), + ElementsAre(MatchIndexingMap(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 9] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[1024, 514]parameter(0) + ROOT reduce-window = f32[1024, 3] reduce-window(p0, c_inf), + window={size=1x512 pad=0_0x0_0}, to_apply=max + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0] -> (d0, d1 + s0) + domain: + d0 in [0, 1023] + d1 in [0, 2] + s0 in [0, 511] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 1023] + d1 in [0, 2] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[13, 17] parameter(0) + ROOT reduce-window = f32[7, 17] reduce-window(p0, c_inf), + window={size=3x2 stride=2x1 pad=1_1x0_1}, to_apply=max + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 * 2 + s0 - 1, d1 + s1) + domain: + d0 in [0, 6] + d1 in [0, 16] + s0 in [0, 2] + s1 in [0, 1] + d0 * 2 + s0 in [1, 13] + d1 + s1 in [0, 16] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 6] + d1 in [0, 16] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[2, 3] parameter(0) + ROOT reduce-window = f32[3, 5] reduce-window(p0, c_inf), + window={size=1x1 pad=0_0x0_0 lhs_dilate=2x2}, to_apply=max + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 floordiv 2, d1 floordiv 2) + domain: + d0 in [0, 2] + d1 in [0, 4] + d0 mod 2 in [0, 0] + d1 mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 2] + d1 in [0, 4] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { + auto root = ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + ENTRY e { + c_inf = f32[] constant(-inf) + p0 = f32[7, 3] parameter(0) + ROOT reduce-window = f32[4, 3] reduce-window(p0, c_inf), + window={size=2x1 pad=0_0x0_0 rhs_dilate=3x1}, to_apply=max + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0] -> (d0 + s0 * 3, d1) + domain: + d0 in [0, 3] + d1 in [0, 2] + s0 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 3] + d1 in [0, 2] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { + auto root = ParseAndGetRoot(R"( + HloModule m + combiner { + a0 = f32[] parameter(0) + a1 = s32[] parameter(1) + b0 = f32[] parameter(2) + b1 = s32[] parameter(3) + add0 = f32[] add(a0, b0) + add1 = s32[] add(a1, b1) + ROOT sum2 = (f32[], s32[]) tuple(add0, add1) + } + ENTRY e { + c_f32 = f32[] constant(-inf) + c_s32 = s32[] constant(10) + p0 = f32[2, 3] parameter(0) + p1 = s32[2, 3] parameter(1) + ROOT reduce-window = (f32[1, 2], s32[1, 2]) + reduce-window(p0, p1, c_f32, c_s32), + window={size=2x2 pad=0_0x0_0}, to_apply=combiner + } + )"); + auto input_indexing_0 = GetOutputToInputIndexing(root, /*output_id=*/0); + EXPECT_THAT(input_indexing_0.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")))); + auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); + EXPECT_THAT(input_indexing_1.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (s0, d1 + s1) + domain: + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> () + domain: + d0 in [0, 0] + d1 in [0, 1] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,12,10,4] parameter(0) + p1 = f32[4,3,5,8] parameter(1) + ROOT conv = f32[1,10,6,8] convolution(p0, p1), + window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, s2) + domain: + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,12,10,4] parameter(0) + p1 = f32[4,3,5,8] parameter(1) + ROOT conv = f32[1,6,5,8] convolution(p0, p1), + window={size=3x5 stride=2x2 pad=1_1x2_2}, dim_labels=b01f_i01o->b01f + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 * 2 + s0 - 1, d2 * 2 + s1 - 2, s2) + domain: + d0 in [0, 0] + d1 in [0, 5] + d2 in [0, 4] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + d1 * 2 + s0 in [1, 12] + d2 * 2 + s1 in [2, 11] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 0] + d1 in [0, 5] + d2 in [0, 4] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,12,10,4] parameter(0) + p1 = f32[4,3,5,8] parameter(1) + ROOT conv = f32[1,21,15,8] convolution(p0, p1), + window={size=3x5 pad=0_0x0_0 lhs_dilate=2x2}, dim_labels=b01f_i01o->b01f + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (d0, (d1 + s0) floordiv 2, (d2 + s1) floordiv 2, s2) + domain: + d0 in [0, 0] + d1 in [0, 20] + d2 in [0, 14] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + (d1 + s0) mod 2 in [0, 0] + (d2 + s1) mod 2 in [0, 0] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 0] + d1 in [0, 20] + d2 in [0, 14] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,12,10,4] parameter(0) + p1 = f32[4,3,5,8] parameter(1) + ROOT conv = f32[1,8,2,8] convolution(p0, p1), + window={size=3x5 pad=0_0x0_0 rhs_dilate=2x2}, dim_labels=b01f_i01o->b01f + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0 * 2, d2 + s1 * 2, s2) + domain: + d0 in [0, 0] + d1 in [0, 7] + d2 in [0, 1] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 0] + d1 in [0, 7] + d2 in [0, 1] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,12,10,24] parameter(0) + p1 = f32[4,3,5,48] parameter(1) + ROOT conv = f32[1,10,6,48] convolution(p0, p1), + window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f, feature_group_count=6 + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, (d3 floordiv 8) * 4 + s2) + domain: + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 47] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 47] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[14,12,10,4] parameter(0) + p1 = f32[4,3,5,21] parameter(1) + ROOT conv = f32[2,10,6,21] convolution(p0, p1), + window={size=3x5 pad=0_0x0_0}, dim_labels=b01f_i01o->b01f, batch_group_count=7 + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 + s3 * 2, d1 + s0, d2 + s1, s2) + domain: + d0 in [0, 1] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 20] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + s3 in [0, 6] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) + domain: + d0 in [0, 1] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 20] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReverseOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1, 17, 9, 9] parameter(0) + ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) + domain: + d0 in [0, 0] + d1 in [0, 16] + d2 in [0, 8] + d3 in [0, 8] + )")))); + + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) + domain: + d0 in [0, 0] + d1 in [0, 16] + d2 in [0, 8] + d3 in [0, 8] + )")))); +} + +TEST_F(IndexingAnalysisTest, ReverseReshape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + fused_computation { + p0 = f32[10, 11] parameter(0) + reverse.0 = f32[10, 11] reverse(p0), dimensions={0, 1} + reshape.0 = f32[110] reshape(reverse.0) + reverse.1 = f32[110] reverse(reshape.0), dimensions={0} + ROOT reshape.1 = f32[10, 11] reshape(reverse.1) + } + ENTRY e { + p0 = f32[10, 11] parameter(0) + ROOT fusion = f32[10, 11] fusion(p0), kind=kLoop, + calls=fused_computation + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 9] + d1 in [0, 10] + )")))); +} + +TEST_F(IndexingAnalysisTest, SliceOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[10, 20, 50] parameter(0) + ROOT slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0), + slice={[5:10:1], [3:20:7], [0:50:2]} + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) + domain: + d0 in [0, 4] + d1 in [0, 2] + d2 in [0, 24] + )")))); +} + +TEST_F(IndexingAnalysisTest, TransposeOp) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT transpose = f32[3, 6, 128, 12288] + transpose(p0), dimensions={0, 2, 3, 1} + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d3, d1, d2) + domain: + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] + )")))); + auto output_indexing = GetInputToOutputIndexing(root); + EXPECT_THAT(output_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d2, d3, d1) + domain: + d0 in [0, 2] + d1 in [0, 12287] + d2 in [0, 5] + d3 in [0, 127] + )")))); +} + +TEST_F(IndexingAnalysisTest, TransposeOp4D) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[3, 12288, 6, 128] parameter(0) + ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3) -> (d0, d3, d1, d2) + domain: + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] + )")))); +} + +TEST_F(IndexingAnalysisTest, DotOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4, 38, 17, 11, 18, 10] parameter(0) + p1 = f32[17, 10, 16, 18, 22, 38] parameter(1) + ROOT dot = f32[10, 38, 4, 11, 16, 22] dot(p0, p1), + lhs_batch_dims={5,1}, rhs_batch_dims={1,5}, + lhs_contracting_dims={4,2}, rhs_contracting_dims={3,0} + } + )")); + EXPECT_THAT(input_indexing.indexing_maps, + ElementsAre(ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> (d2, d1, s1, d3, s0, d0) + domain: + d0 in [0, 9] + d1 in [0, 37] + d2 in [0, 3] + d3 in [0, 10] + d4 in [0, 15] + d5 in [0, 21] + s0 in [0, 17] + s1 in [0, 16] + )")), + ElementsAre(MatchIndexingMap(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1) + domain: + d0 in [0, 9] + d1 in [0, 37] + d2 in [0, 3] + d3 in [0, 10] + d4 in [0, 15] + d5 in [0, 21] + s0 in [0, 17] + s1 in [0, 16] + )")))); +} + +TEST_F(IndexingAnalysisTest, UnsupportedOps) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[20, 20] parameter(0) + p1 = f32[4,4] parameter(1) + p2 = f32[4,3] parameter(2) + ROOT out = f32[4,3] triangular-solve(f32[4,4] p1, f32[4,3] p2), + left_side=true, + lower=true, + transpose_a=NO_TRANSPOSE, + unit_diagonal=true + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + EXPECT_THAT( + input_indexing.indexing_maps, + ElementsAre(ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()))); + + auto output_indexing_0 = GetInputToOutputIndexing(root, 0); + EXPECT_THAT(output_indexing_0.indexing_maps, + ElementsAre(ElementsAre(UndefinedMap()))); + + auto output_indexing_1 = GetInputToOutputIndexing(root, 1); + EXPECT_THAT(output_indexing_1.indexing_maps, + ElementsAre(ElementsAre(UndefinedMap()))); +} + +TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + fused_computation { + p0 = f32[20, 20] parameter(0) + p1 = f32[4,4] parameter(1) + p2 = f32[4,3] parameter(2) + lhs = f32[4,3] triangular-solve(f32[4,4] p1, f32[4,3] p2), + left_side=true, + lower=true, + transpose_a=NO_TRANSPOSE, + unit_diagonal=true + rhs = f32[4, 3] slice(f32[20, 20] p0), + slice={[0:20:6], [0:5:2]} + ROOT add = f32[4, 3] add(lhs, rhs) + } + ENTRY e { + p0 = f32[20, 20] parameter(0) + p1 = f32[4, 4] parameter(1) + p2 = f32[4, 3] parameter(2) + ROOT fusion = f32[4, 3] fusion(p0, p1, p2), kind=kLoop, + calls=fused_computation + } + )")); + EXPECT_THAT( + input_indexing.indexing_maps, + ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( + (d0, d1) -> (d0 * 6, d1 * 2) + domain: + d0 in [0, 3] + d1 in [0, 2] + )")), + ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()))); +} + +TEST_F(IndexingAnalysisTest, TilingIndexing) { + Tiling tiling{/*shape=*/{1022, 256, 16}, + /*tile_sizes=*/{8, 1, 4}, + /*num_threads=*/{1, 4, 4}}; + auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d3 floordiv 64) * 8 + s0, + d0 floordiv 4 + (d3 mod 64) * 4, + d0 mod 4 + s2 * 4 + ) + domain: + d0 in [0, 15] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 7] + s1 in [0, 0] + s2 in [0, 3] + (d3 floordiv 64) * 8 + s0 in [0, 1021] + )")); +} + +TEST_F(IndexingAnalysisTest, EpilogueIndexing) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + t = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + a0 = f32[1000000] bitcast(t) + ROOT log = f32[1000000] log(a0) + } + )"); + ASSERT_TRUE(module.ok()); + EXPECT_THAT(ComputeEpilogueInputToOutputIndexing( + (*module)->entry_computation()->GetInstructionWithName("t"), + &mlir_context_) + .ToString(), + MatchIndexingString(R"( + (d0, d1) -> (d0 + d1 * 1000) + domain: + d0 in [0, 999] + d1 in [0, 999] + )")); +} + +TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + ENTRY e { + p0 = f32[1000, 1000] parameter(0) + ROOT t = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} + } + )"); + ASSERT_TRUE(module.ok()); + EXPECT_THAT(ComputeEpilogueInputToOutputIndexing( + (*module)->entry_computation()->GetInstructionWithName("t"), + &mlir_context_) + .ToString(), + MatchIndexingString(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 999] + d1 in [0, 999] + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_context.h b/xla/service/gpu/model/indexing_context.h new file mode 100644 index 0000000000000..e5dfc6adb7d3c --- /dev/null +++ b/xla/service/gpu/model/indexing_context.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +class IndexingContext { + public: + explicit IndexingContext(mlir::MLIRContext* mlir_context) + : mlir_context_(mlir_context) {} + + mlir::MLIRContext* GetMLIRContext() const { return mlir_context_; } + + private: + mlir::MLIRContext* mlir_context_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc new file mode 100644 index 0000000000000..a740243bc2e79 --- /dev/null +++ b/xla/service/gpu/model/indexing_map.cc @@ -0,0 +1,1444 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep + +namespace xla { +namespace gpu { +namespace { + +using llvm::ArrayRef; +using llvm::SmallBitVector; +using llvm::SmallVector; +using mlir::AffineBinaryOpExpr; +using mlir::AffineConstantExpr; +using mlir::AffineDimExpr; +using mlir::AffineExpr; +using mlir::AffineExprKind; +using mlir::AffineMap; +using mlir::AffineSymbolExpr; +using mlir::getAffineBinaryOpExpr; +using mlir::getAffineConstantExpr; +using mlir::MLIRContext; + +int64_t FloorDiv(int64_t dividend, int64_t divisor) { + return dividend / divisor - + (((dividend >= 0) != (divisor >= 0) && dividend % divisor) ? 1 : 0); +} + +int64_t CeilDiv(int64_t dividend, int64_t divisor) { + return dividend / divisor + + (((dividend >= 0) == (divisor >= 0) && dividend % divisor) ? 1 : 0); +} + +class AffineExprSimplifier { + public: + explicit AffineExprSimplifier(RangeEvaluator* range_evaluator) + : range_evaluator_(range_evaluator) {} + + // Simplifies the map as much as possible. + mlir::AffineMap Simplify(mlir::AffineMap affine_map); + + mlir::AffineExpr Simplify(mlir::AffineExpr expr); + + private: + std::optional GetConstantRhs(mlir::AffineExpr expr, + AffineExprKind kind); + + // Simplifier for mod. + // - Rewrites (a * 100 + ...) % 100 to (...) % 100 + // - Rewrites a % b to a if a is known to be less than b. + mlir::AffineExpr RewriteMod(mlir::AffineBinaryOpExpr mod); + + // Simplifier for floordiv. + // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 + // - Rewrites a / 100 to 0 when a is known to be less than 100. + mlir::AffineExpr RewriteFloorDiv(mlir::AffineBinaryOpExpr div); + + mlir::AffineExpr RewriteSum( + mlir::AffineExpr expr, + const std::function& map); + + mlir::AffineExpr RewriteSumIf( + mlir::AffineExpr expr, const std::function& pred); + + // Attempts to simplify the expression, but doesn't attempt to simplify the + // result further. + mlir::AffineExpr SimplifyOnce(mlir::AffineExpr expr); + + // Simplifies the expression using MLIR's simplifier, except for mods. + mlir::AffineExpr SimplifyWithMlir(mlir::AffineExpr expr, int num_dims, + int num_symbols); + + mlir::AffineMap SimplifyWithMlir(mlir::AffineMap map) { + llvm::SmallVector exprs; + for (auto e : map.getResults()) { + exprs.push_back( + SimplifyWithMlir(e, map.getNumDims(), map.getNumSymbols())); + } + return mlir::AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, + map.getContext()); + } + + RangeEvaluator* range_evaluator_; +}; + +AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { + auto rhs = range_evaluator_->ComputeExpressionRange(mod.getRHS()); + + // The logic below assumes we have a constant RHS. + if (!rhs.IsPoint()) { + return mod; + } + int64_t m = rhs.lower; + // Can only happen in cases where it doesn't matter, return 0. + if (m == 0) { + return mlir::getAffineConstantExpr(0, mod.getContext()); + } + + auto lhs_simplified = SimplifyOnce(mod.getLHS()); + auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); + // a % b where b is always larger than a? + if (0 <= lhs.lower && lhs.upper < rhs.lower) { + return lhs_simplified; + } + + // Rewrite `(c * a) % ab` to `(c % b) * a`. + // (c * a) % ab + // = c * a - (c * a) // ab * ab + // = c * a - c // b * ab + // = (c - c // b * b) * a + // = (c % b) * a + if (auto mul = GetConstantRhs(lhs_simplified, AffineExprKind::Mul); + mul && (m % *mul == 0)) { + return (mlir::cast(lhs_simplified).getLHS() % + (m / *mul)) * + *mul; + } + + Interval no_multiplier_range{0, 0}; + int64_t multiplier_gcd = -1; + + int64_t extracted_constant = 0; + auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (auto cst = mlir::dyn_cast(expr); + cst && cst.getValue() >= m) { + extracted_constant += cst.getValue(); + return false; + } + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { + if (*multiplier % m == 0) { + return false; + } + + if (multiplier_gcd == -1) { + multiplier_gcd = *multiplier; + } else { + multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + } + return true; + } + auto range = range_evaluator_->ComputeExpressionRange(expr); + no_multiplier_range.lower += range.lower; + no_multiplier_range.upper += range.upper; + return true; + }); + new_lhs = new_lhs + (extracted_constant % m); + + mlir::AffineExpr extracted = getAffineConstantExpr(0, mod.getContext()); + if (m % multiplier_gcd == 0 && no_multiplier_range.lower >= 0 && + no_multiplier_range.upper < multiplier_gcd) { + // Remove everything that doesn't have a multiplier. + new_lhs = RewriteSumIf(new_lhs, [&](AffineExpr expr) { + if (GetConstantRhs(expr, AffineExprKind::Mul)) { + return true; + } + extracted = extracted + expr; + return false; + }); + } + return new_lhs % mod.getRHS() + extracted; +} + +AffineExpr AffineExprSimplifier::RewriteFloorDiv(AffineBinaryOpExpr div) { + auto mlir_context = range_evaluator_->GetMLIRContext(); + auto lhs_simplified = SimplifyOnce(div.getLHS()); + auto lhs = range_evaluator_->ComputeExpressionRange(lhs_simplified); + auto rhs = range_evaluator_->ComputeExpressionRange(div.getRHS()); + + if (0 <= lhs.lower && lhs.upper < rhs.lower) { + return getAffineConstantExpr(0, mlir_context); + } + + // The logic below assumes we have a constant RHS. + if (!rhs.IsPoint()) { + return div; + } + int64_t d = rhs.lower; + + // Rewrite `(c % ab) // a` to `(c // a) % b`. + // (c % ab) // a + // = (c - c // ab * ab) // a expand mod + // = c // a - (c // ab * b) rhs of - divides a + // = c // a - (c // a) // b * b) split ab + // = (c // a) % b contract mod + if (auto mod = GetConstantRhs(lhs_simplified, AffineExprKind::Mod); + mod && (*mod % d == 0)) { + return mlir::cast(lhs_simplified).getLHS().floorDiv(d) % + (*mod / d); + } + + // If the dividend's range has a single element, return its value. + int64_t a = FloorDiv(lhs.lower, d); + int64_t b = FloorDiv(lhs.upper, d); + if (a == b) { + return getAffineConstantExpr(a, mlir_context); + } + + // Rewrite `(a / b) / c` to `a / (b * c)` if `a >= 0` and `b` and `c` are + // constants. + if (lhs_simplified.getKind() == AffineExprKind::FloorDiv) { + auto lhs_div = mlir::cast(lhs_simplified); + auto lhs_lhs = range_evaluator_->ComputeExpressionRange(lhs_div.getLHS()); + if (lhs_lhs.lower >= 0) { + auto lhs_rhs = range_evaluator_->ComputeExpressionRange(lhs_div.getRHS()); + if (lhs_rhs.IsPoint()) { + return lhs_div.getLHS().floorDiv(lhs_rhs.lower * d); + } + } + } + + Interval no_multiplier_range{0, 0}; + int64_t multiplier_gcd = -1; + // The maximum GCD of any remaining multiplier inside the div and the divisor. + int64_t max_remaining_multiplier_gcd = -1; + AffineExpr zero = getAffineConstantExpr(0, mlir_context); + AffineExpr extracted = zero; + auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { + // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep + // one x, but we currently have no reason to do that. + + if (*multiplier % d == 0) { + int64_t factor = *multiplier / d; + extracted = + extracted + mlir::cast(expr).getLHS() * factor; + // Remove from dividend. + return false; + } + + if (*multiplier > 0) { + if (multiplier_gcd == -1) { + multiplier_gcd = *multiplier; + } else { + multiplier_gcd = std::gcd(multiplier_gcd, *multiplier); + } + max_remaining_multiplier_gcd = + std::max(max_remaining_multiplier_gcd, std::gcd(*multiplier, d)); + return true; + } + } + auto range = range_evaluator_->ComputeExpressionRange(expr); + no_multiplier_range.lower += range.lower; + no_multiplier_range.upper += range.upper; + // Not a constant multiplier, keep in dividend. + return true; + }); + + // If we removed everything, skip the div. + if (new_dividend == zero) { + return extracted; + } + + if ((d % multiplier_gcd) == 0) { + if (no_multiplier_range.lower >= 0 && + no_multiplier_range.upper < multiplier_gcd) { + // Remove everything that doesn't have a multiplier. + new_dividend = RewriteSumIf(new_dividend, [&](AffineExpr expr) { + auto mult = GetConstantRhs(expr, AffineExprKind::Mul); + return mult.has_value(); + }); + } + } + + // If we have a gcd > 1, we can split the div into two: + // (x * 128 + y) // 192 -> (x * 2 + y // 64) // 3 + if (max_remaining_multiplier_gcd > 1) { + AffineExpr partially_extracted = getAffineConstantExpr(0, mlir_context); + new_dividend = RewriteSumIf(new_dividend, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul); + multiplier && (*multiplier > 0) && + ((*multiplier % max_remaining_multiplier_gcd) == 0)) { + auto expr_lhs = mlir::cast(expr).getLHS(); + partially_extracted = + partially_extracted + + expr_lhs * (*multiplier / max_remaining_multiplier_gcd); + // Remove from dividend. + return false; + } + return true; + }); + return extracted + (partially_extracted + + new_dividend.floorDiv(max_remaining_multiplier_gcd)) + .floorDiv(d / max_remaining_multiplier_gcd); + } + + // If we removed nothing, return the original division. + if (extracted == getAffineConstantExpr(0, mlir_context) && + new_dividend == div.getLHS()) { + return div; + } + + return extracted + new_dividend.floorDiv(div.getRHS()); +} + +std::optional AffineExprSimplifier::GetConstantRhs( + AffineExpr expr, AffineExprKind kind) { + if (expr.getKind() != kind) { + return std::nullopt; + } + auto bound = range_evaluator_->ComputeExpressionRange( + mlir::cast(expr).getRHS()); + if (!bound.IsPoint()) { + return std::nullopt; + } + return bound.lower; +} + +AffineExpr AffineExprSimplifier::RewriteSum( + AffineExpr expr, const std::function& map) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + return RewriteSum(add.getLHS(), map) + RewriteSum(add.getRHS(), map); + } + return map(expr); +} + +AffineExpr AffineExprSimplifier::RewriteSumIf( + AffineExpr expr, const std::function& pred) { + if (expr.getKind() == AffineExprKind::Add) { + auto add = mlir::dyn_cast(expr); + auto lhs = RewriteSumIf(add.getLHS(), pred); + auto rhs = RewriteSumIf(add.getRHS(), pred); + if (lhs == add.getLHS() && rhs == add.getRHS()) { + return add; + } + return lhs + rhs; + } + return pred(expr) ? expr : mlir::getAffineConstantExpr(0, expr.getContext()); +} + +AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Mul: { + auto binop = mlir::cast(expr); + auto lhs = SimplifyOnce(binop.getLHS()); + auto rhs = SimplifyOnce(binop.getRHS()); + return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); + } + case AffineExprKind::Add: { + auto binop = mlir::cast(expr); + auto lhs = SimplifyOnce(binop.getLHS()); + auto rhs = SimplifyOnce(binop.getRHS()); + + // Rewrite `(x // c) * c + (x % c)` to `x`. + // This should also work with (a+b)+c. + auto rewrite_add = [&](AffineExpr a, AffineExpr b) -> AffineExpr { + if (auto mod = GetConstantRhs(a, AffineExprKind::Mod)) { + if (auto mul = GetConstantRhs(b, AffineExprKind::Mul); mod == mul) { + auto b_lhs = mlir::cast(b).getLHS(); + if (auto div = GetConstantRhs(b_lhs, AffineExprKind::FloorDiv); + div == mul) { + auto x = mlir::cast(b_lhs).getLHS(); + if (x == mlir::cast(a).getLHS()) { + return x; + } + } + } + } + return nullptr; + }; + + if (auto rewritten = rewrite_add(lhs, rhs)) { + return rewritten; + } + if (auto rewritten = rewrite_add(rhs, lhs)) { + return rewritten; + } + + return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); + } + case AffineExprKind::Mod: + return RewriteMod(mlir::cast(expr)); + case AffineExprKind::FloorDiv: + return RewriteFloorDiv(mlir::cast(expr)); + case AffineExprKind::DimId: + case AffineExprKind::SymbolId: { + auto bounds = range_evaluator_->ComputeExpressionRange(expr); + if (bounds.IsPoint()) { + return getAffineConstantExpr(bounds.lower, + range_evaluator_->GetMLIRContext()); + } + return expr; + } + + default: + return expr; + } +} + +AffineExpr AffineExprSimplifier::SimplifyWithMlir(AffineExpr expr, int num_dims, + int num_symbols) { + int next_symbol = num_symbols; + llvm::DenseMap mod_to_sym; + llvm::DenseMap sym_to_mod; + std::function replace_mods; + replace_mods = [&](AffineExpr e) { + switch (e.getKind()) { + case AffineExprKind::Mul: + case AffineExprKind::Add: + case AffineExprKind::CeilDiv: + case AffineExprKind::FloorDiv: { + auto bin = mlir::cast(e); + return getAffineBinaryOpExpr(e.getKind(), replace_mods(bin.getLHS()), + replace_mods(bin.getRHS())); + } + case AffineExprKind::Mod: { + auto& ret = mod_to_sym[e]; + if (ret) return ret; + + auto bin = mlir::cast(e); + ret = getAffineSymbolExpr(next_symbol++, expr.getContext()); + sym_to_mod[ret] = getAffineBinaryOpExpr( + AffineExprKind::Mod, + SimplifyWithMlir(bin.getLHS(), num_dims, num_symbols), + bin.getRHS()); + return ret; + } + case AffineExprKind::Constant: + case AffineExprKind::DimId: + case AffineExprKind::SymbolId: + return e; + } + }; + + auto m = replace_mods(expr); + return mlir::simplifyAffineExpr(m, num_dims, next_symbol).replace(sym_to_mod); +} + +AffineExpr AffineExprSimplifier::Simplify(AffineExpr expr) { + while (true) { + auto simplified = SimplifyOnce(expr); + if (simplified == expr) { + return expr; + } + expr = simplified; + } +} + +AffineMap AffineExprSimplifier::Simplify(AffineMap affine_map) { + affine_map = SimplifyWithMlir(affine_map); + SmallVector results; + results.reserve(affine_map.getNumResults()); + bool nothing_changed = true; + for (AffineExpr expr : affine_map.getResults()) { + AffineExpr simplified = Simplify(expr); + nothing_changed &= simplified == expr; + results.push_back(simplified); + } + if (nothing_changed) { + return affine_map; + } + return Simplify(AffineMap::get(affine_map.getNumDims(), + affine_map.getNumSymbols(), results, + affine_map.getContext())); +} + +// Computes intersection of two ranges. +Interval Intersect(const Interval& lhs, const Interval& rhs) { + return Interval{std::max(lhs.lower, rhs.lower), + std::min(lhs.upper, rhs.upper)}; +} + +// Simplifies a constraint range, i.e. a constraint d0 + x in [lb, ub] will +// become d0 in [lb - x, ub - x]. Also supports *, floorDiv. +bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { + switch (expr->getKind()) { + case AffineExprKind::DimId: + case AffineExprKind::SymbolId: + // do the trick with constant + case AffineExprKind::Constant: { + return false; + } + default: { + auto binary_op = mlir::cast(*expr); + CHECK(binary_op); + auto lhs = binary_op.getLHS(); + auto rhs = binary_op.getRHS(); + auto constant = mlir::dyn_cast(rhs); + if (!constant) { + return false; + } + switch (expr->getKind()) { + case AffineExprKind::Add: { + int64_t shift = constant.getValue(); + range->lower -= shift; + range->upper -= shift; + *expr = lhs; + return true; + } + case AffineExprKind::Mul: { + int64_t factor = constant.getValue(); + if (factor < 0) { + factor *= -1; + range->lower *= -1; + range->upper *= -1; + std::swap(range->lower, range->upper); + } + range->lower = CeilDiv(range->lower, factor); + range->upper = FloorDiv(range->upper, factor); + *expr = lhs; + return true; + } + case AffineExprKind::FloorDiv: { + int64_t divisor = constant.getValue(); + if (divisor < 0) { + divisor *= -1; + range->lower *= -1; + range->upper *= -1; + std::swap(range->lower, range->upper); + } + range->lower *= divisor; + range->upper = (range->upper + 1) * divisor - 1; + *expr = lhs; + return true; + } + default: { + return false; + } + } + } + } +} + +// Repeatedly simplifies the range of the constraint. +bool SimplifyConstraintRange(AffineExpr* expr, Interval* range) { + bool is_simplified = false; + while (SimplifyConstraintRangeOnce(expr, range)) { + is_simplified = true; + } + return is_simplified; +} + +// Computes the symbols list replacement to go from +// [range_vars(second)|rt_vars(second)|range_vars(first)|rt_vars(first)] +// to +// [range_vars(second)|range_vars(first)|rt_vars(second)|rt_vars(first)]. +SmallVector GetComposedSymbolsPermutationToCorrectOrder( + const IndexingMap& first, const IndexingMap& second) { + SmallVector symbol_replacements; + MLIRContext* mlir_context = first.GetMLIRContext(); + for (int id = 0; id < second.GetRangeVarsCount(); ++id) { + symbol_replacements.push_back(getAffineSymbolExpr(id, mlir_context)); + } + int64_t rt_vars_second_start = + first.GetRangeVarsCount() + second.GetRangeVarsCount(); + for (int64_t id = 0; id < second.GetRTVarsCount(); ++id) { + symbol_replacements.push_back( + getAffineSymbolExpr(rt_vars_second_start++, mlir_context)); + } + int64_t range_vars_first_start = second.GetRangeVarsCount(); + for (int64_t id = 0; id < first.GetRangeVarsCount(); ++id) { + symbol_replacements.push_back( + getAffineSymbolExpr(range_vars_first_start++, mlir_context)); + } + int64_t rt_vars_first_start = rt_vars_second_start + second.GetRTVarsCount(); + for (int64_t id = 0; id < first.GetRTVarsCount(); ++id) { + symbol_replacements.push_back( + getAffineSymbolExpr(rt_vars_first_start++, mlir_context)); + } + return symbol_replacements; +} + +// Computes the symbols list mapping to go from +// [range_vars(map)|rt_vars(map)] +// to +// [range_vars(second)|range_vars(first)|rt_vars(second)|rt_vars(first)]. +SmallVector MapSymbolsToComposedSymbolsList( + const IndexingMap& map, const IndexingMap& composed) { + SmallVector symbol_replacements; + + MLIRContext* mlir_context = map.GetMLIRContext(); + int64_t range_vars_start = + composed.GetRangeVarsCount() - map.GetRangeVarsCount(); + for (int64_t id = 0; id < map.GetRangeVarsCount(); ++id) { + symbol_replacements.push_back( + getAffineSymbolExpr(range_vars_start++, mlir_context)); + } + int64_t rt_vars_start = composed.GetSymbolCount() - map.GetRTVarsCount(); + for (int64_t id = 0; id < map.GetRTVarsCount(); ++id) { + symbol_replacements.push_back( + getAffineSymbolExpr(rt_vars_start++, mlir_context)); + } + return symbol_replacements; +} + +} // namespace + +std::string Interval::ToString() const { + std::stringstream ss; + Print(ss); + return ss.str(); +} + +void Interval::Print(std::ostream& out) const { + out << '[' << lower << ", " << upper << "]"; +} + +std::ostream& operator<<(std::ostream& out, const Interval& range) { + range.Print(out); + return out; +} + +bool operator==(const Interval& lhs, const Interval& rhs) { + return lhs.lower == rhs.lower && lhs.upper == rhs.upper; +} + +bool operator==(const DimVar& lhs, const DimVar& rhs) { + return lhs.bounds == rhs.bounds; +} + +bool operator==(const RangeVar& lhs, const RangeVar& rhs) { + return lhs.range == rhs.range; +} + +bool operator==(const RTVar& lhs, const RTVar& rhs) { + return lhs.feasible_values == rhs.feasible_values && lhs.hlo == rhs.hlo && + lhs.map == rhs.map; +} + +std::vector DimVarsFromTensorSizes( + absl::Span tensor_sizes) { + std::vector ranges; + ranges.reserve(tensor_sizes.size()); + for (int64_t size : tensor_sizes) { + ranges.push_back({Interval{0, size - 1}}); + } + return ranges; +} + +std::vector RangeVarsFromTensorSizes( + absl::Span tensor_sizes) { + std::vector ranges; + ranges.reserve(tensor_sizes.size()); + for (int64_t size : tensor_sizes) { + ranges.push_back({Interval{0, size - 1}}); + } + return ranges; +} + +IndexingMap IndexingMap::FromTensorSizes( + AffineMap affine_map, absl::Span dim_upper_bounds, + absl::Span symbol_upper_bounds) { + return IndexingMap{affine_map, DimVarsFromTensorSizes(dim_upper_bounds), + RangeVarsFromTensorSizes(symbol_upper_bounds), + /*rt_vars=*/{}}; +} + +const Interval& IndexingMap::GetDimensionBound(int64_t dim_id) const { + return dim_vars_[dim_id].bounds; +} + +Interval& IndexingMap::GetMutableDimensionBound(int64_t dim_id) { + return dim_vars_[dim_id].bounds; +} + +std::vector IndexingMap::GetDimensionBounds() const { + std::vector bounds; + bounds.reserve(affine_map_.getNumDims()); + for (const auto& dim : dim_vars_) { + bounds.push_back(dim.bounds); + } + return bounds; +} + +const Interval& IndexingMap::GetSymbolBound(int64_t symbol_id) const { + // Because affine map symbols are packed like [range_vars, rt_vars], + // we have to pick the correct bounds. + int64_t range_var_count = GetRangeVarsCount(); + return symbol_id < range_var_count + ? range_vars_[symbol_id].range + : rt_vars_[symbol_id - range_var_count].feasible_values; +} + +Interval& IndexingMap::GetMutableSymbolBound(int64_t symbol_id) { + // Because affine map symbols are packed like [range_vars, rt_vars], + // we have to pick the correct bounds. + int64_t range_var_count = GetRangeVarsCount(); + return symbol_id < range_var_count + ? range_vars_[symbol_id].range + : rt_vars_[symbol_id - range_var_count].feasible_values; +} + +std::vector IndexingMap::GetSymbolBounds() const { + std::vector bounds; + bounds.reserve(affine_map_.getNumSymbols()); + for (const auto& range_var : range_vars_) { + bounds.push_back(range_var.range); + } + for (const auto& rt_var : rt_vars_) { + bounds.push_back(rt_var.feasible_values); + } + return bounds; +} + +void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { + if (auto dim_expr = mlir::dyn_cast(expr)) { + Interval& current_range = GetMutableDimensionBound(dim_expr.getPosition()); + current_range = Intersect(current_range, range); + return; + } + if (auto symbol_expr = mlir::dyn_cast(expr)) { + Interval& current_range = GetMutableSymbolBound(symbol_expr.getPosition()); + current_range = Intersect(current_range, range); + return; + } + if (SimplifyConstraintRange(&expr, &range)) { + AddConstraint(expr, range); + return; + } + auto [it, inserted] = constraints_.insert({expr, range}); + if (!inserted) { + it->second = Intersect(it->second, range); + } +} + +bool IndexingMap::ConstraintsSatisfied( + ArrayRef dim_const_exprs, + ArrayRef symbol_const_exprs) const { + CHECK(dim_const_exprs.size() == affine_map_.getNumDims()); + CHECK(symbol_const_exprs.size() == affine_map_.getNumSymbols()); + if (IsKnownEmpty()) { + return false; + } + for (auto& [expr, range] : constraints_) { + int64_t expr_value = + mlir::cast( + expr.replaceDimsAndSymbols(dim_const_exprs, symbol_const_exprs)) + .getValue(); + if (expr_value < range.lower || expr_value > range.upper) { + return false; + } + } + return true; +} + +SmallVector IndexingMap::Evaluate( + ArrayRef dim_const_exprs, + ArrayRef symbol_const_exprs) const { + CHECK(dim_const_exprs.size() == GetDimensionCount()); + CHECK(symbol_const_exprs.size() == GetSymbolCount()); + AffineMap eval = affine_map_.replaceDimsAndSymbols( + dim_const_exprs, symbol_const_exprs, dim_const_exprs.size(), + symbol_const_exprs.size()); + return eval.getConstantResults(); +} + +bool IndexingMap::IsKnownEmpty() const { + return llvm::any_of(dim_vars_, + [](const DimVar& dim_var) { + return dim_var.bounds.lower > dim_var.bounds.upper; + }) || + llvm::any_of(range_vars_, + [](const RangeVar& range_var) { + return range_var.range.lower > range_var.range.upper; + }) || + llvm::any_of(constraints_, + [&](const std::pair& item) { + return item.second.lower > item.second.upper; + }); +} + +RangeEvaluator::RangeEvaluator(absl::Span dim_ranges, + absl::Span symbol_ranges, + MLIRContext* mlir_context) + : mlir_context_(mlir_context) { + for (const auto& [index, range] : llvm::enumerate(dim_ranges)) { + expression_ranges_cache_[getAffineDimExpr(index, mlir_context_)] = range; + } + for (const auto& [index, range] : llvm::enumerate(symbol_ranges)) { + expression_ranges_cache_[getAffineSymbolExpr(index, mlir_context_)] = range; + } +} + +bool RangeEvaluator::IsAlwaysPositiveOrZero(mlir::AffineExpr expr) { + return ComputeExpressionRange(expr).lower >= 0; +} + +bool RangeEvaluator::IsAlwaysNegativeOrZero(mlir::AffineExpr expr) { + return ComputeExpressionRange(expr).upper <= 0; +} + +Interval RangeEvaluator::ComputeExpressionRange(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Constant: { + int64_t value = mlir::cast(expr).getValue(); + return Interval{value, value}; + } + case AffineExprKind::DimId: { + return expression_ranges_cache_[expr]; + } + case AffineExprKind::SymbolId: { + return expression_ranges_cache_[expr]; + } + default: + auto bound = expression_ranges_cache_.find(expr); + if (bound != expression_ranges_cache_.end()) { + return bound->second; + } + auto binary_op = mlir::dyn_cast(expr); + CHECK(binary_op); + auto lhs = ComputeExpressionRange(binary_op.getLHS()); + auto rhs = ComputeExpressionRange(binary_op.getRHS()); + + auto& result = expression_ranges_cache_[expr]; + switch (expr.getKind()) { + case AffineExprKind::Add: + return result = {lhs.lower + rhs.lower, lhs.upper + rhs.upper}; + case AffineExprKind::Mul: { + int64_t a = lhs.lower * rhs.lower; + int64_t b = lhs.upper * rhs.upper; + return result = {std::min(a, b), std::max(a, b)}; + } + case AffineExprKind::Mod: { + CHECK(rhs.IsPoint()) << "RHS of mod must be a constant"; + int64_t m = rhs.lower; + if (0 <= lhs.lower && lhs.upper < m) { + return result = lhs; + } + return result = {0, m - 1}; + } + case AffineExprKind::FloorDiv: { + CHECK(rhs.IsPoint()) << "RHS of floor_div must be a constant"; + int64_t d = rhs.lower; + int64_t a = FloorDiv(lhs.lower, d); + int64_t b = FloorDiv(lhs.upper, d); + return result = {std::min(a, b), std::max(a, b)}; + } + default: + // We don't use ceildiv, so we don't support it. + LOG(FATAL) << "Unsupported expression"; + } + } +} + +std::string IndexingMap::ToString(const AffineMapPrinter& printer) const { + std::stringstream ss; + Print(ss, printer); + return ss.str(); +} + +void IndexingMap::Print(std::ostream& out, + const AffineMapPrinter& printer) const { + printer.Print(out, affine_map_); + out << "\ndomain:\n"; + for (const auto& [index, dim_var] : llvm::enumerate(dim_vars_)) { + out << printer.GetDimensionName(static_cast(index)) << " in "; + dim_var.bounds.Print(out); + out << '\n'; + } + for (const auto& [index, range_var] : llvm::enumerate(range_vars_)) { + out << printer.GetSymbolName(static_cast(index)) << " in "; + range_var.range.Print(out); + out << '\n'; + } + int64_t range_vars_count = GetRangeVarsCount(); + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars_)) { + out << printer.GetSymbolName(static_cast(range_vars_count + index)) + << " in "; + rt_var.feasible_values.Print(out); + out << "\n hlo: " + << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << "\n "; + printer.Print(out, rt_var.map); + out << '\n'; + } + std::vector expr_range_strings; + expr_range_strings.reserve(constraints_.size()); + for (const auto& [expr, range] : constraints_) { + std::stringstream ss; + printer.Print(ss, expr); + ss << " in "; + range.Print(ss); + expr_range_strings.push_back(ss.str()); + } + std::sort(expr_range_strings.begin(), expr_range_strings.end()); + for (const auto& expr_range_string : expr_range_strings) { + out << expr_range_string << '\n'; + } +} + +MLIRContext* IndexingMap::GetMLIRContext() const { + return IsUndefined() ? nullptr : affine_map_.getContext(); +} + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { + AffineMapPrinter printer; + indexing_map.Print(out, printer); + return out; +} + +bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { + return lhs.GetAffineMap() == rhs.GetAffineMap() && + lhs.GetDimVars() == rhs.GetDimVars() && + lhs.GetRangeVars() == rhs.GetRangeVars() && + lhs.GetRTVars() == rhs.GetRTVars(); +} + +IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { + return ComposeIndexingMaps(lhs, rhs); +} + +// Simplification of IndexingMap has two main parts. +// At first we optimized constraints to make the domain as small and simple as +// possible. And only then we simplify the affine_map, because its +// simplification relies on lower/upper bounds of dimensions and symbols. + +// Constraint simplification is performed in two stages repeated until +// convergence. +// 1. Simplify affine expressions in all constraints. +// 2. Simplify constraint ranges for all constraints. +// We don't optimize every constraint separately to avoid re-initialization of +// RangeEvaluator for every constraint. Note that we start with "expr" +// simplification, because the ranges of constraints were already optimized once +// when IndexingMap was constructed. +bool IndexingMap::Simplify(IndexingMapProvider indexing_map_provider) { + if (IsUndefined()) return false; + + bool rtvars_were_eliminated = ReplaceConstantRTVars(indexing_map_provider); + + // Simplify constraints to shrink the lower/upper bounds of dims and symbols. + bool constraints_were_simplified = false; + while (true) { + if (!SimplifyConstraintExprs()) break; + constraints_were_simplified = true; + if (!SimplifyConstraintRanges()) break; + } + // Simplify dependent constraints. + MergeModConstraints(); + // Simplify affine_map using the optimized ranges. + // Potentially, we can be smarter about recreating the range_evaluator. + RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), + GetMLIRContext()); + AffineMap simplified_affine_map = + AffineExprSimplifier(&range_evaluator).Simplify(affine_map_); + bool affine_map_was_simplified = simplified_affine_map != affine_map_; + if (affine_map_was_simplified) { + affine_map_ = simplified_affine_map; + } + return affine_map_was_simplified || constraints_were_simplified || + rtvars_were_eliminated; +} + +bool IndexingMap::SimplifyConstraintExprs() { + // Simplify affine expression in the constraints_. + RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), + GetMLIRContext()); + AffineExprSimplifier simplifier(&range_evaluator); + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + AffineExpr simplified = simplifier.Simplify(expr); + + // Skip constraints that are always satisfied. + Interval evaluated_range = + range_evaluator.ComputeExpressionRange(simplified); + if (evaluated_range.upper <= range.upper && + evaluated_range.lower >= range.lower) { + to_remove.push_back(expr); + continue; + } + if (simplified == expr) continue; + to_add.push_back({simplified, range}); + to_remove.push_back(expr); + } + for (const auto& expr : to_remove) { + constraints_.erase(expr); + } + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } + return !to_add.empty(); +} + +bool IndexingMap::SimplifyConstraintRanges() { + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + AffineExpr simplified_expr = expr; + Interval simplified_range = range; + if (SimplifyConstraintRange(&simplified_expr, &simplified_range)) { + to_add.push_back({simplified_expr, simplified_range}); + to_remove.push_back(expr); + } + } + for (const auto& expr : to_remove) { + constraints_.erase(expr); + } + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } + return !to_add.empty(); +} + +namespace { + +struct UsedParameters { + llvm::DenseSet dimension_ids; + llvm::DenseSet symbol_ids; +}; + +void GetUsedParametersImpl(const AffineExpr& expr, + UsedParameters& used_parameters) { + if (auto dim_expr = mlir::dyn_cast(expr)) { + used_parameters.dimension_ids.insert(dim_expr.getPosition()); + return; + } + if (auto symbol_expr = mlir::dyn_cast(expr)) { + used_parameters.symbol_ids.insert(symbol_expr.getPosition()); + return; + } + if (auto binary_expr = mlir::dyn_cast(expr)) { + GetUsedParametersImpl(binary_expr.getLHS(), used_parameters); + GetUsedParametersImpl(binary_expr.getRHS(), used_parameters); + } +} + +// Returns IDs of dimensions and symbols that participate in AffineExpr. +UsedParameters GetUsedParameters(const mlir::AffineExpr& expr) { + UsedParameters used_parameters; + GetUsedParametersImpl(expr, used_parameters); + return used_parameters; +} + +bool IsFunctionOfUnusedDimsAndSymbolsOnly( + const UsedParameters& used_parameters, + const SmallBitVector& unused_dims_bit_vector, + const SmallBitVector& unused_symbols_bit_vector) { + for (int64_t dim_id : used_parameters.dimension_ids) { + if (!unused_dims_bit_vector[dim_id]) return false; + } + for (int64_t symbol_id : used_parameters.symbol_ids) { + if (!unused_symbols_bit_vector[symbol_id]) return false; + } + return true; +} + +} // namespace + +void IndexingMap::RemoveUnusedSymbols() { + if (IsUndefined()) return; + + // Remove unused symbols from the affine_map. + unsigned num_symbols_before = affine_map_.getNumSymbols(); + SmallBitVector unused_symbols_bit_vector = + mlir::getUnusedSymbolsBitVector({affine_map_}); + SmallBitVector unused_dims_bit_vector = + mlir::getUnusedDimsBitVector({affine_map_}); + + // Check if the symbols that are unused in `affine_map` are also unused in + // expressions. + std::vector> candidates_to_remove; + for (const auto& [expr, range] : constraints_) { + UsedParameters used_parameters = GetUsedParameters(expr); + // If the expression uses only symbols and dims that are "unused" in + // `affine_map`, then we can remove it. + if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters, + unused_dims_bit_vector, + unused_symbols_bit_vector)) { + candidates_to_remove.push_back({expr, used_parameters}); + continue; + } + // Otherwise, we need to mark all symbols of these expr as "used". + for (int64_t symbol_id : used_parameters.symbol_ids) { + unused_symbols_bit_vector[symbol_id] = false; + } + } + for (const auto& [expr, used_parameters] : candidates_to_remove) { + if (IsFunctionOfUnusedDimsAndSymbolsOnly(used_parameters, + unused_dims_bit_vector, + unused_symbols_bit_vector)) { + constraints_.erase(expr); + } + } + + // Compress `affine_map` using the updated `unused_symbols_bit_vector`. + affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols_bit_vector); + + // Remap symbols in the constraint expressions accordingly. + unsigned num_symbols_after = affine_map_.getNumSymbols(); + if (num_symbols_after == num_symbols_before) return; + + std::vector compressed_range_vars; + std::vector compressed_rt_vars; + MLIRContext* mlir_context = GetMLIRContext(); + int64_t used_symbols_count = 0; + std::vector symbol_replacements( + num_symbols_before, getAffineConstantExpr(0, mlir_context)); + auto range_vars_count = range_vars_.size(); + for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) { + if (!unused_symbols_bit_vector[i]) { + if (i < range_vars_count) { + compressed_range_vars.push_back(range_vars_[i]); + } else { + compressed_rt_vars.push_back(rt_vars_[i - range_vars_count]); + } + symbol_replacements[i] = + getAffineSymbolExpr(used_symbols_count++, mlir_context); + } + } + range_vars_ = std::move(compressed_range_vars); + rt_vars_ = std::move(compressed_rt_vars); + std::vector to_remove; + std::vector> to_add; + for (const auto& [expr, range] : constraints_) { + auto updated_expr = expr.replaceSymbols(symbol_replacements); + if (updated_expr == expr) continue; + to_add.push_back({updated_expr, range}); + to_remove.push_back(expr); + } + for (const auto& expr : to_remove) { + constraints_.erase(expr); + } + for (const auto& [expr, range] : to_add) { + AddConstraint(expr, range); + } +} + +void IndexingMap::MergeModConstraints() { + RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), + GetMLIRContext()); + + // Group constraints by LHS. + llvm::DenseMap> + grouped_constraints; + for (const auto& [expr, _] : constraints_) { + if (expr.getKind() != AffineExprKind::Mod) continue; + auto binop = mlir::cast(expr); + grouped_constraints[binop.getLHS()].push_back(binop); + } + + // Merge constraints of type MOD. + // (X mod 3 == 0) & (X mod 2 == 0) => (X mod 6 == 0) + for (const auto& [lhs, binops] : grouped_constraints) { + llvm::DenseMap> + mod_groups; + for (const auto& binop : binops) { + Interval mod_result = constraints_[binop]; + if (mod_result.IsPoint()) { + mod_groups[mod_result.lower].push_back(binop); + } + } + if (mod_groups.empty()) continue; + + // Update domain for dimensions and symbols only. + Interval* update = nullptr; + if (lhs.getKind() == AffineExprKind::DimId) { + update = &GetMutableDimensionBound( + mlir::cast(lhs).getPosition()); + } else if (lhs.getKind() == AffineExprKind::SymbolId) { + update = &GetMutableSymbolBound( + mlir::cast(lhs).getPosition()); + } + for (const auto& [res, ops] : mod_groups) { + // Calculate least common multiple for the divisors. + int64_t div = 1; + for (const auto& op : ops) { + int64_t rhs_value = + range_evaluator.ComputeExpressionRange(op.getRHS()).lower; + div = std::lcm(div, rhs_value); + } + // Replace multiple constraints with a merged one. + if (ops.size() > 1) { + for (const auto& op : ops) { + constraints_.erase(op); + } + constraints_[lhs % div] = Interval{res, res}; + } + // Update dimension and symbol bounds. + if (update != nullptr) { + int64_t l = (update->lower / div) * div + res; + update->lower = l >= update->lower ? l : l + div; + int64_t h = (update->upper / div) * div + res; + update->upper = h <= update->upper ? h : h - div; + } + } + } +} + +IndexingMap ComposeIndexingMaps(const IndexingMap& first, + const IndexingMap& second) { + if (second.IsUndefined() || first.IsUndefined()) { + return IndexingMap::GetUndefined(); + } + AffineMap producer_affine_map = second.GetAffineMap(); + AffineMap composed_map = producer_affine_map.compose(first.GetAffineMap()); + + // The symbols in the composed map, i.e. combined + // producer_map.compose(consumer_map) are packed as + // [range_vars(second)|rt_vars(second)|range_vars(first)|rt_vars(first)]. + std::vector combined_range_vars; + combined_range_vars.reserve(second.GetRangeVarsCount() + + first.GetRangeVarsCount()); + for (const RangeVar& range_var : llvm::concat( + second.GetRangeVars(), first.GetRangeVars())) { + combined_range_vars.push_back(range_var); + } + std::vector combined_rt_vars; + combined_rt_vars.reserve(second.GetRTVarsCount() + first.GetRTVarsCount()); + for (const RTVar& rt_var : + llvm::concat(second.GetRTVars(), first.GetRTVars())) { + combined_rt_vars.push_back(rt_var); + } + // The symbols in the composed map have to be permuted to keep the invariant + // that range_vars go before rt_vars in the composed affine map symbols list. + SmallVector symbol_replacements = + GetComposedSymbolsPermutationToCorrectOrder(first, second); + IndexingMap composed_indexing_map(composed_map, first.GetDimVars(), + std::move(combined_range_vars), + std::move(combined_rt_vars)); + + // Add constraints that are already present in the producer_map. We have to + // compute consumer_map(producer_constraints). To keep all symbols and + // dimension IDs the same as in the `composed_indexing_map.affine_map`, we + // create an AffineMap + // (dims of producer_affine_map)[symbols_of_producer_affine_map] = + // (constraint_1, ..., constraint_N) and then compose. + std::vector constraints; + std::vector constraints_ranges; + for (const auto& [expr, range] : second.GetConstraints()) { + constraints.push_back(expr); + constraints_ranges.push_back(range); + } + auto constraints_map = AffineMap::get( + producer_affine_map.getNumDims(), producer_affine_map.getNumSymbols(), + constraints, producer_affine_map.getContext()); + auto remapped_constraints = + constraints_map.compose(first.GetAffineMap()) + .replaceDimsAndSymbols(/*dimReplacements=*/{}, symbol_replacements, + composed_indexing_map.GetDimensionCount(), + composed_indexing_map.GetSymbolCount()); + for (const auto& [expr, range] : + llvm::zip(remapped_constraints.getResults(), constraints_ranges)) { + composed_indexing_map.AddConstraint(expr, range); + } + // Remap symbol ids and add constraints that are already present in the + // consumer_map. + SmallVector first_map_symbols_to_composed_symbols = + MapSymbolsToComposedSymbolsList(first, composed_indexing_map); + for (const auto& [expr, range] : first.GetConstraints()) { + composed_indexing_map.AddConstraint( + expr.replaceSymbols(first_map_symbols_to_composed_symbols), range); + } + // Add constraints for consumer's codomain w.r.t. producer's domain. + for (auto [index, expr] : + llvm::enumerate(first.GetAffineMap().getResults())) { + Interval producer_dim_range = + second.GetDimensionBound(static_cast(index)); + composed_indexing_map.AddConstraint( + expr.replaceSymbols(first_map_symbols_to_composed_symbols), + producer_dim_range); + } + return composed_indexing_map; +} + +bool IndexingMap::RescaleSymbols() { + MergeModConstraints(); + + std::vector to_delete; + + for (const auto& [expr, range] : constraints_) { + if (range.lower != range.upper) continue; + auto shift_value = range.lower; + + if (expr.getKind() != AffineExprKind::Mod) continue; + auto mod_expr = mlir::cast(expr); + + auto constant_expr = mlir::dyn_cast(mod_expr.getRHS()); + if (!constant_expr) continue; + + // We don't rescale mod expressions with non-positive divisors. + if (constant_expr.getValue() <= 0) continue; + auto scaling_factor = constant_expr.getValue(); + + if (mod_expr.getLHS().getKind() != AffineExprKind::SymbolId) continue; + auto symbol_expr = mlir::cast(mod_expr.getLHS()); + + affine_map_ = affine_map_.replace( + symbol_expr, constant_expr * symbol_expr + shift_value, + affine_map_.getNumDims(), affine_map_.getNumSymbols()); + + for (auto& [other_expr, other_range] : constraints_) { + if (other_expr == expr) continue; + if (!other_expr.isFunctionOfSymbol(symbol_expr.getPosition())) continue; + + other_expr = other_expr.replace( + symbol_expr, constant_expr * symbol_expr + shift_value); + } + + auto& symbol_range = range_vars_[symbol_expr.getPosition()].range; + symbol_range.lower = (symbol_range.lower - shift_value) / scaling_factor; + symbol_range.upper = (symbol_range.upper - shift_value) / scaling_factor; + + to_delete.emplace_back(expr); + } + + for (const auto& expr : to_delete) { + constraints_.erase(expr); + } + + return !to_delete.empty(); +} + +// Returns either: +// 1. an AffineExpr if the RTVar folds entirely into a constant expression +// 2. an updated RTVar if some partial optimization was possible +// 3. an unchanged RTVar if no optimization was possible +static std::variant OptimizeRTVar( + RTVar rt_var, MLIRContext* mlir_context, + IndexingMap::IndexingMapProvider indexing_map_provider) { + while (true) { + if (auto constant_expr = DynCast(rt_var.hlo)) { + if (rt_var.map.isConstant()) { + const auto idx = rt_var.map.getConstantResults(); + return getAffineConstantExpr( + constant_expr->literal().GetIntegralAsS64(idx).value(), + mlir_context); + } + return rt_var; + } + + if (auto iota_expr = DynCast(rt_var.hlo)) { + auto iota_dimension = iota_expr->iota_dimension(); + CHECK(iota_dimension < rt_var.map.getNumResults()); + return rt_var.map.getResults()[iota_dimension]; + } + + auto is_indexing_transformation = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kBroadcast || + instr->opcode() == HloOpcode::kReshape || + instr->opcode() == HloOpcode::kReverse || + instr->opcode() == HloOpcode::kSlice || + instr->opcode() == HloOpcode::kTranspose; + }; + + if (is_indexing_transformation(rt_var.hlo)) { + auto instr_indexing_map = + indexing_map_provider(rt_var.hlo, 0, mlir_context); + + rt_var.hlo = rt_var.hlo->operand(0); + rt_var.map = instr_indexing_map.GetAffineMap().compose(rt_var.map); + continue; + } + + return rt_var; + } +} + +bool IndexingMap::ReplaceConstantRTVars( + IndexingMap::IndexingMapProvider indexing_map_provider) { + if (rt_vars_.empty()) return false; + + std::vector to_delete; + + for (auto index = 0; index < rt_vars_.size(); ++index) { + auto& rt_var = rt_vars_[index]; + auto result = + OptimizeRTVar(rt_var, GetMLIRContext(), indexing_map_provider); + + // If we got an RTVar back, then we just replace it and move on. + if (std::holds_alternative(result)) { + rt_var = std::get(std::move(result)); + continue; + } + + // But if we received an AffineExpr we can eliminate the RTVar from + // all expressions in the indexing map. + auto folded_expr = std::get(std::move(result)); + + // range_vars and rt_vars share the symbol space, with the rt_vars coming + // after the range_vars. + auto symbol_index = range_vars_.size() + index; + affine_map_ = affine_map_.replace( + {{mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), + folded_expr}}); + + llvm::DenseMap replacements; + + for (const auto& [constraint, interval] : constraints_) { + auto modified_constraint = constraint.replace( + mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), + folded_expr); + + if (constraint == modified_constraint) continue; + replacements[constraint] = modified_constraint; + } + + for (const auto& [old_expr, new_expr] : replacements) { + auto interval = constraints_.at(old_expr); + constraints_.erase(old_expr); + constraints_[new_expr] = interval; + } + + to_delete.emplace_back(index); + } + + for (auto index : llvm::reverse(to_delete)) { + rt_vars_.erase(rt_vars_.begin() + index); + } + + return !to_delete.empty(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_map.h b/xla/service/gpu/model/indexing_map.h new file mode 100644 index 0000000000000..bfc8abf30bdd3 --- /dev/null +++ b/xla/service/gpu/model/indexing_map.h @@ -0,0 +1,376 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_MAP_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/affine_map_printer.h" + +namespace xla { +namespace gpu { + +// Interval represents a closed interval [lower_bound, upper_bound]. +struct Interval { + std::string ToString() const; + void Print(std::ostream& out) const; + + bool IsPoint() const { return lower == upper; } + + bool Contains(int64_t value) const { + return value >= lower && value <= upper; + } + + // The result of a range comparison. We wrap std::optional in a struct to + // avoid accidental implicit conversion to bool: + // if (range < 42) { + // Executed if the result of the comparison is known to be false! + // } + struct ComparisonResult { + // true or false if the result is known, nullopt otherwise. + std::optional result; + + ComparisonResult operator!() const { + if (result) return {!*result}; + return {result}; + } + bool operator==(const ComparisonResult& other) const { + return result == other.result; + } + bool operator==(bool other) const { return result && *result == other; } + bool operator==(std::nullopt_t) const { return !result; } + bool operator!=(std::nullopt_t) const { return result.has_value(); } + bool operator*() const { return *result; } + }; + + // All comparison operators here return true or false if the result is known, + // or nullopt if it may be either true or false. + ComparisonResult operator>(int64_t value) const { + if (lower > value) { + return {true}; + } + if (upper <= value) { + return {false}; + } + return {std::nullopt}; + } + ComparisonResult operator<(int64_t value) const { + if (upper < value) { + return {true}; + } + if (lower >= value) { + return {false}; + } + return {std::nullopt}; + } + ComparisonResult operator>=(int64_t value) const { return !(*this < value); } + ComparisonResult operator<=(int64_t value) const { return !(*this > value); } + ComparisonResult operator==(int64_t value) const { + if (IsPoint()) return {lower == value}; + if (!Contains(value)) return {false}; + return {std::nullopt}; + } + ComparisonResult operator!=(int64_t value) const { return !(*this == value); } + + int64_t lower = 0; + int64_t upper = 0; +}; + +std::ostream& operator<<(std::ostream& out, const Interval& range); +bool operator==(const Interval& lhs, const Interval& rhs); + +template +H AbslHashValue(H h, const Interval& range) { + return H::combine(std::move(h), range.lower, range.upper); +} + +// Evaluates lower and upper bounds for expressions given the domain. +// Not thread safe. +class RangeEvaluator { + public: + RangeEvaluator(absl::Span dim_ranges, + absl::Span symbol_ranges, + mlir::MLIRContext* mlir_context); + + // Checks whether an `AffineExpr` always describes a non-negative value. + bool IsAlwaysPositiveOrZero(mlir::AffineExpr expr); + + // Checks whether an `AffineExpr` always describes a non-positive value. + bool IsAlwaysNegativeOrZero(mlir::AffineExpr expr); + + // Computes the range of expression using its subexpression ranges. + Interval ComputeExpressionRange(mlir::AffineExpr expr); + + // Return MLIR context. + mlir::MLIRContext* GetMLIRContext() const { return mlir_context_; } + + private: + mlir::MLIRContext* mlir_context_; + llvm::DenseMap expression_ranges_cache_; +}; + +// Dimension variable represents a dimension of a tensor or a GPU grid. +// Dimensions correspond to the dimension parameter of `affine_map_`. +struct DimVar { + Interval bounds; +}; +bool operator==(const DimVar& lhs, const DimVar& rhs); + +template +H AbslHashValue(H h, const DimVar& dimension) { + return H::combine(std::move(h), dimension.bounds); +} + +// RangeSymbol variable represents a range of values, e.g. to compute a single +// element of the reduction's result we need a range of values from the input +// tensor. RangeSymbol variables correspond to the front portion of the +// symbols in `affine_map_`. +struct RangeVar { + Interval range; +}; +bool operator==(const RangeVar& lhs, const RangeVar& rhs); + +template +H AbslHashValue(H h, const RangeVar& range_var) { + return H::combine(std::move(h), range_var.range); +} + +// RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in +// HLO dynamic-update-slice op. RTSymbol variables correspond to the back +// portion of the symbols in `affine_map_`. +struct RTVar { + Interval feasible_values; + const HloInstruction* hlo; + mlir::AffineMap map; +}; +bool operator==(const RTVar& lhs, const RTVar& rhs); + +template +H AbslHashValue(H h, const RTVar& rt_var) { + llvm::hash_code map_hash = llvm::hash_combine(rt_var.map); + return H::combine(std::move(h), rt_var.feasible_values, rt_var.hlo, + static_cast(map_hash)); +} + +std::vector DimVarsFromTensorSizes( + absl::Span tensor_sizes); + +std::vector RangeVarsFromTensorSizes( + absl::Span tensor_sizes); + +// Contains an affine map with N dimension expressions and M symbols: +// (d0, ..., d_{N - 1})[s_0, ..., s_{M - 1}] -> f(d_i, s_j) +// Dimensions d_i correspond to the iteration space of the output tensor. Some +// or all of the dimensions of the input operands can be expressed as a function +// of dimensions of output. For example, for broadcasts and cwise ops all +// dimensions of the inputs are covered by the output dimensions. +// Domain specifies for what ranges of values the indexing map is specified. +// +// Example: +// +// 1. Indexing map for the input of the following reduction +// ``` +// p0 = f32[150, 20, 10, 50] parameter(0) +// reduce = f32[150, 10] reduce(p0, p0_init), dimensions={3, 1} +// ``` +// can be written as `(d0, d1)[s0, s1] -> (d0, s0, d1, s1)` with +// d0 in [0, 149], d1 in [0, 9], s0 in [0, 19] and s1 in [0, 49]. +// +// 2. Indexing map for the input of the reverse op +// ``` +// %p0 = f32[1, 17, 9, 9] parameter(0) +// reverse = f32[1, 17, 9, 9] reverse(%p0), dimensions={1, 2} +// ``` +// can be written as `(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)` with +// d0 in [0, 1), d1 in [0, 16], d2 in [0, 8] and d3 in [0, 8]. +class IndexingMap { + public: + IndexingMap( + mlir::AffineMap affine_map, std::vector dimensions, + std::vector range_vars, std::vector rt_vars, + absl::Span> constraints = {}) + : affine_map_(affine_map), + dim_vars_(std::move(dimensions)), + range_vars_(std::move(range_vars)), + rt_vars_(std::move(rt_vars)) { + for (const auto& [expr, range] : constraints) { + AddConstraint(expr, range); + } + } + IndexingMap(mlir::AffineMap affine_map, std::vector dimensions, + std::vector range_vars, std::vector rt_vars, + const llvm::DenseMap& constraints) + : affine_map_(affine_map), + dim_vars_(std::move(dimensions)), + range_vars_(std::move(range_vars)), + rt_vars_(std::move(rt_vars)), + constraints_(constraints) {} + + static IndexingMap GetUndefined() { return IndexingMap(); } + + static IndexingMap FromTensorSizes( + mlir::AffineMap affine_map, absl::Span dim_upper_bounds, + absl::Span symbol_upper_bounds); + + std::string ToString( + const AffineMapPrinter& printer = AffineMapPrinter()) const; + + void Print(std::ostream& out, const AffineMapPrinter& printer) const; + + // TODO(hebecker): Rearrange code structure so that we can call + // `ComputeInputToOutputIndexing` from `:indexing_analysis` directly. + using IndexingMapProvider = llvm::function_ref; + + // Returns true if the map was simplified. + bool Simplify(IndexingMapProvider indexing_map_provider); + + // Return MLIRContext. + mlir::MLIRContext* GetMLIRContext() const; + + // Returns the affine map. + mlir::AffineMap GetAffineMap() const { return affine_map_; } + + // Getters for dimension vars. + const DimVar& GetDimVars(int64_t id) const { return dim_vars_[id]; } + const std::vector& GetDimVars() const { return dim_vars_; } + int64_t GetDimVarsCount() const { return dim_vars_.size(); } + + // Getters for range vars. + const RangeVar& GetRangeVar(int64_t id) const { return range_vars_[id]; } + const std::vector& GetRangeVars() const { return range_vars_; } + int64_t GetRangeVarsCount() const { return range_vars_.size(); } + + // Getters for runtime vars. + const RTVar& GetRTVar(int64_t id) const { return rt_vars_[id]; } + const std::vector& GetRTVars() const { return rt_vars_; } + int64_t GetRTVarsCount() const { return rt_vars_.size(); } + + // Gets bounds of `affine_map_` dimensions. + const Interval& GetDimensionBound(int64_t dim_id) const; + Interval& GetMutableDimensionBound(int64_t dim_id); + std::vector GetDimensionBounds() const; + int64_t GetDimensionCount() const { return affine_map_.getNumDims(); } + + // Gets bounds of `affine_map_` symbols. + const Interval& GetSymbolBound(int64_t symbol_id) const; + Interval& GetMutableSymbolBound(int64_t symbol_id); + std::vector GetSymbolBounds() const; + int64_t GetSymbolCount() const { return affine_map_.getNumSymbols(); } + + // Getters for affine expression constraints. + const llvm::DenseMap& GetConstraints() const { + return constraints_; + } + int64_t GetConstraintsCount() const { return constraints_.size(); } + + // Allows to add bounds for the affine expression `expr`. If there are + // bounds for the `expr`, then computes intersection of the current and new + // ranges. + void AddConstraint(mlir::AffineExpr expr, Interval range); + + // Evaluates the constraints at a given point and returns `true` if all + // constraints are satisfied. + bool ConstraintsSatisfied( + llvm::ArrayRef dim_const_exprs, + llvm::ArrayRef symbol_const_exprs) const; + + // Evaluates indexing map results at a given point. + llvm::SmallVector Evaluate( + llvm::ArrayRef dim_const_exprs, + llvm::ArrayRef symbol_const_exprs) const; + + // Returns true if the domain is empty. Right now it scans through all + // constraints to find the one where lower_bound > upper_bound. If it returns + // true, that does not mean that the domain is not effectively empty. + // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and + // 0 <= d0 mod 11 <= 0 for a dimension 0<= d0 <= 50 then there is no d0 that + // satisfies both constraints. + bool IsKnownEmpty() const; + + bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); } + + // Removes unused symbols from the `affine_map_` and constraints. + void RemoveUnusedSymbols(); + + // Rescales all symbols that are sufficiently constrained through `s? mod x = + // [N, N]` constraints. Returns true if a rescale took place, otherwise false. + bool RescaleSymbols(); + + private: + IndexingMap() = default; + + // Performs AffineExpr simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintExprs(); + + // Performs range simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintRanges(); + + // Merges "mod" constraints for the same AffineExpr. + void MergeModConstraints(); + + // Replace RTVars that yield constants by indexing expressions. + // Returns true if a replacement was performed, otherwise false. + bool ReplaceConstantRTVars(IndexingMapProvider indexing_map_provider); + + mlir::AffineMap affine_map_; + std::vector dim_vars_; + std::vector range_vars_; + std::vector rt_vars_; + // Inequality constraints for affine expressions. They restrict the feasible + // set for the domain of the indexing map. It contains affine expressions + // other than AffineDimExpr and AffineSymbolExpr. + llvm::DenseMap constraints_; +}; +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); +bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); +IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs); + +// Composes affine maps, i.e. second ∘ first. +IndexingMap ComposeIndexingMaps(const IndexingMap& first, + const IndexingMap& second); + +template +H AbslHashValue(H h, const IndexingMap& indexing_map) { + llvm::hash_code affine_map_hash = + llvm::hash_combine(indexing_map.GetAffineMap()); + return H::combine(std::move(h), static_cast(affine_map_hash), + indexing_map.GetDimVars(), indexing_map.GetRangeVars(), + indexing_map.GetRTVars(), + indexing_map.GetConstraintsCount()); +} + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_MAP_H_ diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc new file mode 100644 index 0000000000000..2b250db787109 --- /dev/null +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -0,0 +1,919 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map.h" + +#include +#include +#include +#include + +#include +#include +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::mlir::AffineMap; +using ::testing::ElementsAre; + +class IndexingMapTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; + AffineMapPrinter printer_; +}; + +TEST_F(IndexingMapTest, RTVar) { + auto zero_dim_map = AffineMap::get(&mlir_context_); + std::vector rt_vars{RTVar{Interval{0, 2}, + /*instr=*/nullptr, zero_dim_map}, + RTVar({Interval{0, 7}, + /*instr=*/nullptr, zero_dim_map})}; + + IndexingMap indexing_map( + ParseAffineMap("(d0, d1)[s0, s1, s2] -> (d1, d0, s0 + s1, s1)", + &mlir_context_), + {DimVar{{0, 99}}, DimVar{{0, 43}}}, {RangeVar{{-99, 99}}}, + std::move(rt_vars)); + printer_.SetSymbolName(0, "range"); + printer_.SetSymbolName(1, "rt_0"); + printer_.SetSymbolName(2, "rt_1"); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0) + domain: + d0 in [0, 99] + d1 in [0, 43] + range in [-99, 99] + rt_0 in [0, 2] + hlo: NULL + () -> () + rt_1 in [0, 7] + hlo: NULL + () -> () + )")); +} + +TEST_F(IndexingMapTest, Evaluation) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {4, 4}, {2, 2}); + + auto results = indexing_map.Evaluate( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); + EXPECT_THAT(results, ElementsAre(2, 1, 4, 3)); + + auto feasible = indexing_map.ConstraintsSatisfied( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); + EXPECT_TRUE(feasible); + + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Interval{0, 0}); + + auto infeasible = indexing_map.ConstraintsSatisfied( + mlir::getAffineConstantExprs({1, 2}, &mlir_context_), + mlir::getAffineConstantExprs({5, 4}, &mlir_context_)); + EXPECT_FALSE(infeasible); +} + +TEST_F(IndexingMapTest, Composition_Permutation) { + IndexingMap producer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {4, 4}, {2, 2}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 3] + )")); +} + +TEST_F(IndexingMapTest, Composition_RestrictedInterval) { + IndexingMap producer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {5, 6}, {7, 2}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 4] + s0 in [0, 6] + s1 in [0, 1] + s2 in [0, 5] + )")); +} + +TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { + IndexingMap producer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), + {50, 60}, {70, 20}); + producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), + Interval{0, 0}); + producer.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Interval{1, 1}); + + IndexingMap consumer = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), + Interval{0, 20}); + consumer.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Interval{0, 0}); + + auto composed = ComposeIndexingMaps(consumer, producer); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 9] + s0 in [0, 69] + s1 in [0, 19] + s2 in [0, 7] + d0 + s2 in [0, 20] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] + )")); + composed.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(composed, MatchIndexingMap(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 8] + s0 in [1, 67] + s1 in [0, 19] + s2 in [0, 4] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), + {50, 60}, {70, 20}); + // This constraint cannot be removed, because it contains a "used symbol". + indexing_map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), + Interval{1, 100}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Interval{0, 0}); + indexing_map.RemoveUnusedSymbols(); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1) + domain: + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 69] + s1 in [0, 19] + s0 + s1 in [1, 100] + s0 mod 3 in [0, 0] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), + {50, 60}, {70, 20}); + // This constraint can be removed, because it contains only the unused symbol. + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Interval{0, 0}); + indexing_map.RemoveUnusedSymbols(); + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0, d1)[s0] -> (d1, d0, s0) + domain: + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 19] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", + &mlir_context_), + {32}, {1, 2, 3, 4, 5}); + indexing_map.AddConstraint( + ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); + indexing_map.RemoveUnusedSymbols(); + // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) + domain: + d0 in [0, 31] + s0 in [0, 1] + s1 in [0, 3] + d0 * 4 + s0 + s1 in [24, 459] + )")); +} + +TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { + auto zero_dim_map = AffineMap::get(&mlir_context_); + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", + &mlir_context_), + {DimVar{{0, 31}}}, {RangeVar{{0, 0}}, RangeVar{{0, 1}}, RangeVar{{0, 2}}}, + {RTVar{Interval{0, 3}, + /*instr=*/nullptr, zero_dim_map}, + RTVar{Interval{0, 4}, + /*instr=*/nullptr, zero_dim_map}}); + indexing_map.AddConstraint( + ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); + indexing_map.RemoveUnusedSymbols(); + // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. + EXPECT_THAT(indexing_map, MatchIndexingMap(R"( + (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) + domain: + d0 in [0, 31] + s0 in [0, 1] + s1 in [0, 3] + hlo: NULL + () -> () + d0 * 4 + s0 + s1 in [24, 459] + )")); +} + +TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), + Interval{50, 54}); + + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [0, 99] + d0 mod 8 in [45, 49] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), + Interval{5, 11}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [40, 95] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_FloorDivPositiveDivisorNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), + Interval{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [-33, -13] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_FloorDivNegativeDivisorNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), + Interval{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [15, 35] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_MulPositiveMultiplierPositiveBounds) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + + indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), + Interval{14, 33}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0) -> (d0) + domain: + d0 in [2, 4] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_MulPositiveMultiplierNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), + Interval{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [-3, -2] + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_MulNegativeMultiplierNegativeBounds) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); + + indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), + Interval{-11, -5}); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0) + domain: + d0 in [0, 99] + s0 in [2, 3] + )")); +} + +TEST_F(IndexingMapTest, ConstraintMerge_Mod) { + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0, s1] -> (d0, s1, s0)", &mlir_context_), + {DimVar{{0, 4}}}, {RangeVar{{-21, -1}}, RangeVar{{0, 10}}}, + /*rt_vars=*/{}); + indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), + Interval{0, 0}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), + Interval{0, 0}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Interval{0, 0}); + indexing_map.AddConstraint(ParseAffineExpr("s1 mod 5", &mlir_context_), + Interval{1, 1}); + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0, s1] -> (d0, s1, s0) + domain: + d0 in [0, 3] + s0 in [-18, -6] + s1 in [1, 6] + d0 mod 3 in [0, 0] + s0 mod 6 in [0, 0] + s1 mod 5 in [1, 1] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { + IndexingMap indexing_map = + IndexingMap(ParseAffineMap("(d0) -> (d0)", &mlir_context_), + {DimVar{{5, 5}}}, /*range_vars=*/{}, /*rt_vars=*/{}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (5) + domain: + d0 in [5, 5] + )")); +} + +TEST_F(IndexingMapTest, + AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { + auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 7] + d1 in [0, 15] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { + auto serialized_map = + "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " + "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, " + "d2 mod 10)"; + + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1, d2) -> (d0, d1, d2) + domain: + d0 in [0, 8] + d1 in [0, 8] + d2 in [0, 8] + )")); +} + +TEST_F(IndexingMapTest, + AffineMapSimplification_DivsAndModsWithDivisibleMultipliers) { + auto serialized_map = + "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " + " (d0 * 16 + d1 * 4 + d2) mod 8)"; + + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1, d2) -> (d0 * 2 + (d1 + d2 floordiv 4) floordiv 2, + (d1 * 4 + d2) mod 8) + domain: + d0 in [0, 9] + d1 in [0, 9] + d2 in [0, 9] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { + auto serialized_map = + "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " + "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1) -> (d0, d1) + domain: + d0 in [0, 7] + d1 in [0, 8] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { + auto serialized_map = + "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0] -> (s0 * 128) + domain: s0 in [0, 127] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { + // We have s0 * 128 in the mod, but s0 * 64 in the floordiv *. + auto serialized_map = + "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715) + domain: s0 in [0, 127] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { + auto serialized_map = + "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " + "14)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0] -> (s0) + domain: + s0 in [0, 1233] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { + auto serialized_map = + "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " + "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) + domain: + s0 in [0, 1233] + s1 in [0, 127] + s2 in [0, 3] + )")); +} + +TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { + auto serialized_map = + "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " + "20000)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1, s2, s3] -> ( + s1 + (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 + ) + domain: + s0 in [0, 871] + s1 in [0, 3] + s2 in [0, 127] + s3 in [0, 895] + )")); +} + +TEST_F(IndexingMapTest, + AffineMapSimplification_ExtractFromDiv_NegativeMultiplier) { + auto serialized_map = + "()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) " + "* 2) floordiv 4)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); + indexing_map.Simplify(GetIndexingMapForInstruction); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + ()[s0, s1] -> ( + s0 * 4 + s1 floordiv 32 + ) + domain: + s0 in [0, 1] + s1 in [0, 127] + )")); +} + +TEST_F(IndexingMapTest, RescaleSymbols_Simple) { + auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), + Interval{0, 0}); + + EXPECT_TRUE(indexing_map.RescaleSymbols()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] + )")); +} + +TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { + auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {4}, {42, 2, 6}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), + Interval{3, 3}); + + // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) + // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 + EXPECT_TRUE(indexing_map.RescaleSymbols()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) + domain: + d0 in [0, 3] + s0 in [0, 6] + s1 in [0, 1] + s2 in [0, 5] + )")); +} + +TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { + auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), + Interval{0, 0}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), + Interval{0, 0}); + + EXPECT_TRUE(indexing_map.RescaleSymbols()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0) + domain: + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] + )")); +} + +TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherConstraint) { + auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap(serialized_map, &mlir_context_), {4}, {10, 2, 6}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), + Interval{3, 3}); + indexing_map.AddConstraint(ParseAffineExpr("s0 * s2", &mlir_context_), + Interval{0, 28}); + + EXPECT_TRUE(indexing_map.RescaleSymbols()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) + domain: + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] + (s0 * 6 + 3) * s2 in [0, 28] + )")); +} + +TEST_F(IndexingMapTest, RangeEvaluatorTest) { + RangeEvaluator range_evaluator( + {Interval{0, 9}, Interval{-10, -1}, Interval{-1, 2}, Interval{0, 0}}, {}, + &mlir_context_); + mlir::AffineExpr d0, d1, d2, d3; + bindDims(&mlir_context_, d0, d1, d2, d3); + + // d0 is always positive. + EXPECT_TRUE(range_evaluator.IsAlwaysPositiveOrZero(d0)); + EXPECT_FALSE(range_evaluator.IsAlwaysNegativeOrZero(d0)); + + // d1 is always negative. + EXPECT_FALSE(range_evaluator.IsAlwaysPositiveOrZero(d1)); + EXPECT_TRUE(range_evaluator.IsAlwaysNegativeOrZero(d1)); + + // d2 is sometimes positive and sometimes negative. + EXPECT_FALSE(range_evaluator.IsAlwaysPositiveOrZero(d2)); + EXPECT_FALSE(range_evaluator.IsAlwaysNegativeOrZero(d2)); + + // d3 is always 0. + EXPECT_TRUE(range_evaluator.IsAlwaysPositiveOrZero(d3)); + EXPECT_TRUE(range_evaluator.IsAlwaysNegativeOrZero(d3)); +} + +TEST(IntervalComparisionTest, Comparisons) { + Interval interval{12, 64}; + EXPECT_EQ(interval > 11, true); + EXPECT_EQ(interval > 12, std::nullopt); + EXPECT_EQ(interval > 65, false); + + EXPECT_EQ(interval < 65, true); + EXPECT_EQ(interval < 64, std::nullopt); + EXPECT_EQ(interval < 10, false); + + EXPECT_EQ(interval == 11, false); + EXPECT_EQ(interval == 15, std::nullopt); + EXPECT_EQ(interval == 65, false); + + EXPECT_EQ(interval != 11, true); + EXPECT_EQ(interval != 15, std::nullopt); + EXPECT_EQ(interval != 65, true); + + EXPECT_EQ(interval >= 12, true); + EXPECT_EQ(interval >= 64, std::nullopt); + EXPECT_EQ(interval >= 65, false); + + EXPECT_EQ(interval <= 11, false); + EXPECT_EQ(interval <= 64, true); + EXPECT_EQ(interval <= 63, std::nullopt); + EXPECT_EQ(interval <= 65, true); + + Interval point{15, 15}; + EXPECT_EQ(point == 15, true); + EXPECT_EQ(point == 16, false); + + EXPECT_EQ(point != 15, false); + EXPECT_EQ(point != 16, true); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ScalarConstant) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)); + + IndexingMap indexing_map(ParseAffineMap("()[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{42, 42}, constant.get(), + AffineMap::get(0, 0, {}, &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + () -> (42) + domain: + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {11, 12, 13, 14}})); + + IndexingMap indexing_map( + ParseAffineMap("()[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{1, 14}, constant.get(), + ParseAffineMap("() -> (1,2)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + () -> (13) + domain: + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_NonFoldableTensor) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {11, 12, 13, 14}})); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{1, 14}, constant.get(), + ParseAffineMap("(d0) -> (1, d0)", &mlir_context_)}}); + + EXPECT_FALSE(indexing_map.Simplify(GetIndexingMapForInstruction)); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 0); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, d0) + domain: + d0 in [0, 255] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 1); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, 7) + domain: + d0 in [0, 255] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 0); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), + Interval{0, 0}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, d0) + domain: + d0 in [0, 254] + d0 mod 2 in [0, 0] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), 0); + auto transpose = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {32, 12}), iota.get(), {1}); + + // (d0, 11): d0 maps into the broadcasted dimension, so it doesn't matter + // and 11 maps to 11 in iota. + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 31}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 11}, transpose.get(), + ParseAffineMap("(d0) -> (d0, 11)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, 11) + domain: + d0 in [0, 31] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), 0); + auto reverse = HloInstruction::CreateReverse( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), iota.get(), {0}); + auto reshape = HloInstruction::CreateReshape( + ShapeUtil::MakeShape(PrimitiveType::S64, {3, 4}), reverse.get()); + auto broadcast = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {36, 3, 4}), reshape.get(), + {1, 2}); + + // - Iota: [0, 1, ,,,, 11] + // - Reverse: [11, 10, ..., 0] + // - Reshape: [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]] + // - Coordinates: (d0 floordiv 12, 3) + // - y-coordinate=3 means we index into [8, 4, 0] + // - x-coordinate=(d0 floordiv 12) means our constant looks like this: + // [8, ..., 8, 4, ..., 4, 0, ..., 0] + // - Hence our final expression: (d0 floordiv 12) * -4 + 8 + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 35}}, + /*range_vars=*/{}, + {RTVar{ + Interval{0, 11}, broadcast.get(), + ParseAffineMap("(d0) -> (d0, d0 floordiv 12, 3)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, (d0 floordiv 12) * -4 + 8) + domain: + d0 in [0, 35] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { + auto iota = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 7, 25, 1, 7, 25, 1, 7, 25, 1, 7, 25})); + auto broadcast = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {24, 12}), iota.get(), {1}); + + // (d0, d0 floordiv 2): d0 maps into the broadcasted dimension, so it can't be + // removed, but d0 floordiv 2 doesn't yield an affine expression so we need to + // keep the RTVar, but can optimize it by removing the broadcast. + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 23}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 512}, broadcast.get(), + ParseAffineMap("(d0) -> (d0, d0 floordiv 2)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 23] + s0 in [0, 512] + hlo: %constant = s64[12]{0} constant({...}) + (d0) -> (d0 floordiv 2) + )")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_test_utils.cc b/xla/service/gpu/model/indexing_test_utils.cc new file mode 100644 index 0000000000000..e7b7e39ac7132 --- /dev/null +++ b/xla/service/gpu/model/indexing_test_utils.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_test_utils.h" + +#include +#include +#include + +#include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +using ::mlir::AffineExpr; +using ::mlir::AffineMap; +using ::mlir::MLIRContext; + +HloInstruction* IndexingTestBase::ParseAndGetRoot( + absl::string_view hlo_string) { + auto module_or = ParseAndReturnVerifiedModule(hlo_string); + CHECK_OK(module_or); + module_ = std::move(module_or.value()); + return module_->entry_computation()->root_instruction(); +} + +HloInstructionIndexing IndexingTestBase::GetOutputToInputIndexing( + const HloInstruction* instr, int output_id, bool use_physical_layout) { + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, output_id, &mlir_context_); + + if (!use_physical_layout) return indexing; + + IndexingMap output_permutation = GetIndexingMapFromPhysicalLayoutToLogical( + GetOutputShape(instr, output_id), &mlir_context_); + + for (const auto& [operand_id, indexing_maps] : + llvm::enumerate(indexing.indexing_maps)) { + IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( + instr->operand(operand_id)->shape(), &mlir_context_); + + absl::flat_hash_set operand_indexing_maps; + for (const IndexingMap& indexing_map : indexing_maps) { + auto normalized_indexing_map = indexing_map; + if (!output_permutation.GetAffineMap().isIdentity()) { + normalized_indexing_map = + ComposeIndexingMaps(output_permutation, normalized_indexing_map); + } + if (!operand_permutation.GetAffineMap().isIdentity()) { + normalized_indexing_map = + ComposeIndexingMaps(normalized_indexing_map, operand_permutation); + } + operand_indexing_maps.insert(normalized_indexing_map); + } + indexing.indexing_maps[operand_id] = operand_indexing_maps; + } + return indexing; +} + +HloInstructionIndexing IndexingTestBase::GetInputToOutputIndexing( + const HloInstruction* instr, int input_id, bool use_physical_layout) { + HloInstructionIndexing indexing = + ComputeInputToOutputIndexing(instr, input_id, &mlir_context_); + + if (!use_physical_layout) return indexing; + + IndexingMap input_permutation = GetIndexingMapFromPhysicalLayoutToLogical( + instr->operand(input_id)->shape(), &mlir_context_); + + for (const auto& [output_id, indexing_maps] : + llvm::enumerate(indexing.indexing_maps)) { + IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( + GetOutputShape(instr, output_id), &mlir_context_); + + absl::flat_hash_set operand_indexing_maps; + for (const IndexingMap& indexing_map : indexing_maps) { + auto normalized_indexing_map = indexing_map; + if (!input_permutation.GetAffineMap().isIdentity()) { + normalized_indexing_map = + ComposeIndexingMaps(input_permutation, normalized_indexing_map); + } + if (!operand_permutation.GetAffineMap().isIdentity()) { + normalized_indexing_map = + ComposeIndexingMaps(normalized_indexing_map, operand_permutation); + } + operand_indexing_maps.insert(normalized_indexing_map); + } + indexing.indexing_maps[output_id] = operand_indexing_maps; + } + return indexing; +} + +AffineMap ParseAffineMap(absl::string_view serialized_affine_map, + MLIRContext* context) { + std::string full_affine_map_string = + absl::StrCat("affine_map<", serialized_affine_map, ">"); + return mlir::parseAttribute(full_affine_map_string, context) + .cast() + .getValue(); +} + +// Since MLIR does not have AffineExprAttr, we construct an AffineMap and then +// retrieve its first result. +AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr, + MLIRContext* context) { + std::string full_affine_map_string = absl::StrCat( + "affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9)" + "[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (", + serialized_affine_expr, ")>"); + return mlir::parseAttribute(full_affine_map_string, context) + .cast() + .getValue() + .getResult(0); +} + +bool ApproximateMatch(std::string_view lhs, std::string_view rhs) { + size_t lhs_length = lhs.size(); + size_t rhs_length = rhs.size(); + size_t l = 0, r = 0; + while (l < lhs_length && r < rhs_length) { + while (l < lhs_length && std::isspace(lhs[l])) { + ++l; + } + while (r < rhs_length && std::isspace(rhs[r])) { + ++r; + } + if (l == lhs_length || r == rhs_length) { + continue; + } + if (lhs[l++] != rhs[r++]) { + return false; + } + } + return l == lhs_length && r == rhs_length; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/indexing_test_utils.h b/xla/service/gpu/model/indexing_test_utils.h new file mode 100644 index 0000000000000..62abd0e5e7fdb --- /dev/null +++ b/xla/service/gpu/model/indexing_test_utils.h @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_TEST_UTILS_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_TEST_UTILS_H_ + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" + +namespace xla { +namespace gpu { + +// Matches two strings ignoring whitespaces. +bool ApproximateMatch(std::string_view lhs, std::string_view rhs); + +MATCHER(UndefinedMap, "") { return arg.IsUndefined(); } + +MATCHER_P(MatchIndexingMap, indexing_string, "") { + if (arg.IsUndefined()) { + return false; + } + return ExplainMatchResult( + true, ApproximateMatch(indexing_string, arg.ToString()), result_listener); +} + +MATCHER_P(MatchIndexingString, indexing_string, "") { + return ExplainMatchResult(true, ApproximateMatch(indexing_string, arg), + result_listener); +} + +class IndexingTestBase : public HloTestBase { + public: + HloInstruction* ParseAndGetRoot(absl::string_view hlo_string); + + HloInstructionIndexing GetOutputToInputIndexing( + const HloInstruction* instr, int output_id = 0, + bool use_physical_layout = false); + + HloInstructionIndexing GetInputToOutputIndexing( + const HloInstruction* instr, int input_id = 0, + bool use_physical_layout = false); + + mlir::MLIRContext mlir_context_; + std::unique_ptr module_; +}; + +HloInstructionIndexing ComputeOutputToInputIndexingForEntryComputation( + HloTestBase* test_base, mlir::MLIRContext* mlir_context, + absl::string_view hlo_string, int output_id = 0, + bool use_physical_layout = false); + +HloInstructionIndexing ComputeInputToOutputIndexingForEntryComputation( + HloTestBase* test_base, mlir::MLIRContext* mlir_context, + absl::string_view hlo_string, int input_id = 0, + bool use_physical_layout = false); + +mlir::AffineMap ParseAffineMap(absl::string_view serialized_affine_map, + mlir::MLIRContext* context); + +mlir::AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr, + mlir::MLIRContext* context); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_TEST_UTILS_H_ diff --git a/xla/service/gpu/model/symbolic_tile.cc b/xla/service/gpu/model/symbolic_tile.cc new file mode 100644 index 0000000000000..98ed3ecc4dff4 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile.cc @@ -0,0 +1,364 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tile.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +using ::mlir::AffineExpr; +using ::mlir::AffineExprKind; +using ::mlir::AffineMap; +using ::mlir::getAffineConstantExpr; +using ::mlir::getAffineDimExpr; +using ::mlir::MLIRContext; + +// Internal helper that checks whether an affine map describes a tileable space. +// In simple terms, this currently returns true if "dimensions don't mix", i.e., +// every result expression only refers to a single dimension (or symbol). +// +// TODO(b/328427138): this is too restrictive for expressions involving e.g. +// (output-to-input) split reshapes, where several symbols may appear within the +// same expression but still yield a tileable space. This will be handled in a +// forthcoming change. +bool IndexingMapDescribesTileableSpace(const IndexingMap& indexing_map) { + for (AffineExpr result_expr : indexing_map.GetAffineMap().getResults()) { + // Using a simple integer here might be overly restrictive, since there may + // be cases where the same symbol appears in several places within the + // expression. It is a bit unclear whether this is a case that would happen + // in practice and whether we would be able to handle it well in all cases + // if it did. For that reason, we err on the side of conservatism and + // explicitly do not support such cases. + int64_t num_hits = 0; + result_expr.walk([&num_hits](AffineExpr expr) { + if (expr.getKind() == AffineExprKind::SymbolId || + expr.getKind() == AffineExprKind::DimId) { + ++num_hits; + } + }); + + if (num_hits > 1) { + return false; + } + } + return true; +} + +// Helper to perform function application to using the same parameter for every +// dimension and symbol parameter. +AffineMap SubstituteAllIndicesAndKnownSymbolsWithSameValue(AffineMap affine_map, + AffineExpr value) { + MLIRContext* mlir_context = affine_map.getContext(); + int64_t num_dims = affine_map.getNumDims(); + int64_t num_symbols = affine_map.getNumSymbols(); + llvm::DenseMap indices; + + for (int64_t i = 0; i < num_dims; ++i) { + indices[getAffineDimExpr(i, mlir_context)] = value; + } + + for (int64_t i = 0; i < num_symbols; ++i) { + indices[getAffineSymbolExpr(i, mlir_context)] = value; + } + + return simplifyAffineMap(affine_map.replace(indices, num_dims, num_symbols)); +} + +struct SizeAndStrideExpression { + AffineExpr size; + AffineExpr stride; +}; + +// Converts a dimension expression to a symbol expression with the corresponding +// index. +AffineExpr ToSymbol(mlir::AffineDimExpr dim_expr) { + return mlir::getAffineSymbolExpr(dim_expr.getPosition(), + dim_expr.getContext()); +} + +// Extracts size and stride expressions from the operands to a modulo +// expression. +// +// TODO(b/326998704): Currently, this fails when the stride is not exactly unit. +std::optional ExtractSizeAndStrideFromMod( + AffineExpr lhs, AffineExpr modulus) { + // TODO(b/326998704): derive constraints here, as well as the non-one stride + // case, both in the code and in the proof. + // Let f(d0) = d0 mod c. Then, given an input tile size n, + // {f(x) | x in Fin(n)} contains: + // * n elements if n < c (and we add a constraint such that c | n); + // * c elements if n >= c (and we add a constraint such that n | c). + // Given these constraints and assumptions, we derive + // card({f(x) | x in Fin(n)}) = n - ((n - 1) floordiv n) * n. + // Proof: + // * n < c (and c | n): + // n - ((n - 1) floordiv c) * c + // = n - 0 * c (n < c => n floordiv c == 0) + // = n + // * n >= c (and n | c): + // n - ((n - 1) floordiv c) * c + // = n - (n / c - 1) * c (n | c => (n - 1) floordiv c = n / c - 1) + // = n - (n - c) + // = c + CHECK(modulus.getKind() == AffineExprKind::Constant); + if (auto dim_expr = llvm::dyn_cast(lhs)) { + AffineExpr sym = ToSymbol(dim_expr); + AffineExpr size = sym - mlir::getAffineBinaryOpExpr( + AffineExprKind::FloorDiv, sym - 1, modulus) * + modulus; + // In this case, stride is effectively 1 mod modulus = 1. + return SizeAndStrideExpression{ + size, /*stride=*/getAffineConstantExpr(1, lhs.getContext())}; + } + + return std::nullopt; +} + +// Extracts size and stride expressions from the operands to a floordiv +// expression. +// +// TODO(b/326998704): Currently, this fails when the numerator of the stride +// is not exactly unit. +std::optional ExtractSizeAndStrideFromFloorDiv( + AffineExpr num, AffineExpr den) { + if (den.getKind() != AffineExprKind::Constant) { + return std::nullopt; + } + + if (auto dim_expr = llvm::dyn_cast(num)) { + // Let f(d0) = d0 floordiv c. Then, given an input tile size n, + // {f(x) | x in Fin(n)} contains n ceildiv c elements, with stride + // (1 ceildiv c) = 1. + // + // We represent `a ceildiv b` as `(a + b - 1) floordiv b`, since indexing + // maps are not compatible with CeilDiv affine expressions. + AffineExpr size = mlir::getAffineBinaryOpExpr( + AffineExprKind::FloorDiv, ToSymbol(dim_expr) + (den - 1), den); + return SizeAndStrideExpression{ + size, /*stride=*/getAffineConstantExpr(1, num.getContext())}; + } + + return std::nullopt; +} + +std::optional ExtractSizeAndStride( + AffineExpr strided_indexing, absl::Span symbol_intervals) { + MLIRContext* ctx = strided_indexing.getContext(); + // Deal with the symbol case (capturing a whole untiled dimension). + // TODO(b/330906085): concatenating across a reduction dimension needs to be + // handled by this code. + if (auto symbol = llvm::dyn_cast(strided_indexing)) { + const Interval& symbol_interval = symbol_intervals[symbol.getPosition()]; + if (symbol_interval.lower != 0) { + return std::nullopt; + } + + return SizeAndStrideExpression{ + /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx), + /*stride=*/getAffineConstantExpr(1, ctx)}; + } + + AffineMapPrinter printer; + + // TODO(b/328427138): support multivariate size expressions. + switch (strided_indexing.getKind()) { + case AffineExprKind::DimId: + return SizeAndStrideExpression{ + /*size=*/ToSymbol(llvm::cast(strided_indexing)), + /*stride=*/getAffineConstantExpr(1, ctx)}; + case mlir::AffineExprKind::Mul: { + const auto mul = llvm::cast(strided_indexing); + AffineExpr lhs = mul.getLHS(); + // The stride may not be fully collapsed if it is negative; in that case, + // we need to extract the negative multiplier first. + if (const auto rhs = + llvm::dyn_cast(mul.getRHS()); + rhs && rhs.getValue() == -1) { + std::optional maybe_size_and_stride = + ExtractSizeAndStride(lhs, symbol_intervals); + if (!maybe_size_and_stride.has_value()) { + return std::nullopt; + } + + return SizeAndStrideExpression{ + /*size=*/maybe_size_and_stride->size, + /*stride=*/maybe_size_and_stride->stride * rhs}; + } + CHECK(lhs.getKind() == AffineExprKind::DimId); + return SizeAndStrideExpression{ + /*size=*/ToSymbol(llvm::cast(lhs)), + /*stride=*/mul.getRHS()}; + } + case mlir::AffineExprKind::Mod: { + auto mod = llvm::cast(strided_indexing); + return ExtractSizeAndStrideFromMod(mod.getLHS(), mod.getRHS()); + } + case mlir::AffineExprKind::FloorDiv: { + auto floor_div = llvm::cast(strided_indexing); + return ExtractSizeAndStrideFromFloorDiv(floor_div.getLHS(), + floor_div.getRHS()); + }; + case mlir::AffineExprKind::Constant: + return SizeAndStrideExpression{/*size=*/getAffineConstantExpr(1, ctx), + /*stride=*/getAffineConstantExpr(0, ctx)}; + case mlir::AffineExprKind::SymbolId: + VLOG(1) << "Encountered complex size expression involving symbol " + << printer.ToString(strided_indexing); + return std::nullopt; + case mlir::AffineExprKind::Add: + // TODO(b/328427138): this should only be necessary in the multivariate + // case, and will be implemented later. + VLOG(1) << "Encountered complex strided indexing expression " + << printer.ToString(strided_indexing); + return std::nullopt; + case mlir::AffineExprKind::CeilDiv: + break; + }; + LOG(FATAL) << "unreachable"; +} + +} // anonymous namespace + +/*static*/ std::optional SymbolicTile::FromIndexingMap( + const IndexingMap& indexing_map) { + // Bail out on runtime offsets. + if (indexing_map.GetRTVarsCount()) { + return std::nullopt; + } + // TODO(b/328427138): handle multiple symbols in a single tile to support + // merging dimensions. + if (!IndexingMapDescribesTileableSpace(indexing_map)) { + return std::nullopt; + } + + AffineMap input_affine_map = indexing_map.GetAffineMap(); + MLIRContext* mlir_context = input_affine_map.getContext(); + + // If indexing_map describes a tileable space, then input_affine_map can be + // expressed as + // f(dim0, ..., dim{M-1})[sym0, ..., sym{P-1}] = (expr0, ..., expr{N-1}) + // where the result expressions expr0, ..., expr{N-1} are strided expressions + // of the form + // offset_expr{i} + stride_expr{i} * index_expr{i} + // with 0 <= i < N. + // + // We are interested in extracting expressions for offset_expr{i}, + // stride_expr{i}, and size_expr{i} (the count of different values that + // expr{i} can represent). + // + // We have that the following equations hold: + // + // (1) f(0, ..., 0)[0, ..., 0]{i} + // = offset_expr{i} + stride_expr{i} * 0 + // = offset_expr{i} + // + // (2) f(x0, ..., x{M-1})[x{M}, ..., x{M+P-1}]{i} - f(0, ..., 0)[0, ..., 0]{i} + // = offset_expr{i} + stride_expr{i} * index_expr{i} - offset_expr{i} + // = stride_expr{i} * index_expr{i} + // + // offset_expressions = f(0, ..., 0)[0, ..., 0]. + std::vector offset_expressions = + SubstituteAllIndicesAndKnownSymbolsWithSameValue( + input_affine_map, getAffineConstantExpr(0, mlir_context)) + .getResults(); + + std::vector size_expressions; + std::vector stride_expressions; + size_expressions.reserve(offset_expressions.size()); + stride_expressions.reserve(offset_expressions.size()); + + // strided_indexing_expressions = + // f(x0, ..., x{M-1})[x{M}, ..., x{M+P-1}] - offset_expressions + for (auto [composite_indexing, offset] : + llvm::zip(input_affine_map.getResults(), offset_expressions)) { + std::optional maybe_size_and_stride = + ExtractSizeAndStride(composite_indexing - offset, + indexing_map.GetSymbolBounds()); + if (!maybe_size_and_stride.has_value()) { + return std::nullopt; + } + size_expressions.push_back(maybe_size_and_stride->size); + stride_expressions.push_back(maybe_size_and_stride->stride); + } + + // Eliminate negative strides and recalculate offsets. + std::vector dim_replacements, sym_replacements; + for (auto [offset, size, stride] : + llvm::zip(offset_expressions, size_expressions, stride_expressions)) { + auto constant = llvm::dyn_cast(stride); + if (!constant) { + AffineMapPrinter printer; + VLOG(1) << "Unexpected non-constant stride expression: " + << printer.ToString(stride); + return std::nullopt; + } + if (constant.getValue() < 0) { + offset = offset + size * stride - stride; + stride = -stride; + } + } + + int64_t num_symbols = input_affine_map.getNumDims(); + AffineMap offset_map = + AffineMap::get(0, num_symbols, offset_expressions, mlir_context); + AffineMap size_map = + AffineMap::get(0, num_symbols, size_expressions, mlir_context); + AffineMap stride_map = + AffineMap::get(0, num_symbols, stride_expressions, mlir_context); + + return SymbolicTile(offset_map, size_map, stride_map); +} + +std::string SymbolicTile::ToString(const AffineMapPrinter& printer) const { + std::string s; + std::stringstream ss(s); + Print(ss, printer); + return ss.str(); +} + +void SymbolicTile::Print(std::ostream& out, + const AffineMapPrinter& printer) const { + out << "Symbolic tile with \n"; + out << "\toffset_map: "; + printer.Print(out, offset_map_); + out << "\n\tsize_map: "; + printer.Print(out, size_map_); + out << "\n\tstride_map: "; + printer.Print(out, stride_map_); + out << "\n"; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/symbolic_tile.h b/xla/service/gpu/model/symbolic_tile.h new file mode 100644 index 0000000000000..93dc5ae25b031 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile.h @@ -0,0 +1,81 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_H_ +#define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_H_ + +#include +#include +#include + +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +// A tile describes a structured subset of indices inside an N-dimensional +// array, where the set of indices captured along each dimension can be +// expressed as a strided expression +// offset + stride * iota(size) +// with offset, stride, and size three integers, and iota the usual range +// function. These values may never be negative. +// +// A N-dimensional symbolic tile is a function from offsets, strides, and sizes +// to a N-dimensional tile. It can be represented as three affine maps with +// domain +// ()[size0, ..., size{M-1}] +// and respective co-domains +// (offset0, ..., offset{N-1}) (offset_map()) +// (size0', ..., size'{N-1}) (size_map()) +// (stride0, ..., stride{N-1}) (stride_map()) +// where maps respectively encode the offset, size, and stride component of +// each strided expression in the tile. The parameters to the maps above are all +// assumed to be strictly positive. The input offsets are assumed to be all 0s, +// and the input strides are assumed to be all 1s. +// +// A symbolic tile with M symbols and N results is constructed using an +// `IndexingMap` with M input dimensions and N results. The construction of the +// symbolic tile may fail if any one of the resulting expressions is not a +// strided expression as described above. +class SymbolicTile { + public: + static std::optional FromIndexingMap( + const IndexingMap& indexing_map); + + std::string ToString( + const AffineMapPrinter& printer = AffineMapPrinter()) const; + + void Print(std::ostream& out, const AffineMapPrinter& printer) const; + + mlir::AffineMap offset_map() const { return offset_map_; } + mlir::AffineMap size_map() const { return size_map_; } + mlir::AffineMap stride_map() const { return stride_map_; } + + private: + mlir::AffineMap offset_map_; + mlir::AffineMap size_map_; + mlir::AffineMap stride_map_; + + SymbolicTile(mlir::AffineMap offset_map, mlir::AffineMap size_map, + mlir::AffineMap stride_map) + : offset_map_(offset_map), size_map_(size_map), stride_map_(stride_map) {} +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_H_ diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc new file mode 100644 index 0000000000000..424c11266e751 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -0,0 +1,275 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tile_analysis.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/instruction_fusion.h" +#include "xla/status.h" +#include "xla/status_macros.h" + +namespace xla { +namespace gpu { + +namespace { + +using ::mlir::AffineExpr; +using ::mlir::MLIRContext; + +// Computes indexing map from program id into the tile offset for the given +// shape and tile sizes. +IndexingMap ComputeProgramIdToOutputTileIndexing( + absl::Span dimensions, absl::Span tile_sizes, + mlir::MLIRContext* mlir_context) { + CHECK_EQ(dimensions.size(), tile_sizes.size()); // Crash OK + + int num_tiles = 1; + std::vector outer_loop_bounds; + outer_loop_bounds.reserve(dimensions.size()); + for (auto [dim_size, tile_size] : llvm::zip(dimensions, tile_sizes)) { + int num_tiles_per_dim = (dim_size + tile_size - 1) / tile_size; + + num_tiles *= num_tiles_per_dim; + outer_loop_bounds.push_back(num_tiles_per_dim); + } + + mlir::AffineExpr program_id = mlir::getAffineDimExpr(0, mlir_context); + + // Delinearize the block id. + auto tile_exprs = + DelinearizeIndex(outer_loop_bounds, program_id, mlir_context); + + // Scale each index by the tile size to produce tile offset. + for (auto [tile_expr, tile_size] : llvm::zip(tile_exprs, tile_sizes)) { + tile_expr = tile_expr * tile_size; + } + + return IndexingMap::FromTensorSizes( + mlir::AffineMap::get( + /*dimCount=*/1, /*symbolCount=*/0, tile_exprs, mlir_context), + /*dim_upper_bounds=*/{num_tiles}, /*symbol_upper_bounds=*/{}); +} + +} // namespace + +/*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( + const HloComputation& computation, MLIRContext* ctx) { + std::vector> + tiled_hlo_instructions; + absl::flat_hash_map, + SymbolicTiledHloInstruction*> + tiled_hlo_instructions_map; + + absl::flat_hash_map topological_order; + + std::function( + const HloInstruction*, IndexingMap)> + get_tiled_hlo_instruction; + + // Create a new tiled hlo instruction or return existing instruction from + // cache for the given hlo and indexing map. + get_tiled_hlo_instruction = [&](const HloInstruction* hlo, + IndexingMap indexing_map) + -> std::variant { + auto key = std::make_pair(hlo, indexing_map); + + auto it = tiled_hlo_instructions_map.find(key); + if (it != tiled_hlo_instructions_map.end()) { + return it->second; + } + + // Bail out on instructions that are known to cause problems down the + // line. This is not an inherent limitation of the approach, but simply + // issues to be resolved in the current implementation. + if (hlo->opcode() == HloOpcode::kDot || + hlo->opcode() == HloOpcode::kReshape || + hlo->opcode() == HloOpcode::kBitcast || + hlo->opcode() == HloOpcode::kConcatenate) { + return FusionDecision{} << "Bailing out on " << hlo->ToString(); + } + + // Bail out on instructions that do not output a single array. + if (!hlo->shape().IsArray()) { + return FusionDecision{} << hlo->ToString() + << " outputs more than a single array"; + } + + auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map); + if (!symbolic_tile.has_value()) { + return FusionDecision{} << "Failed to compute symbolic tile for " + << indexing_map.ToString() << " for HLO " + << hlo->ToString(); + } + + tiled_hlo_instructions.push_back( + std::make_unique( + hlo, std::move(indexing_map), std::move(*symbolic_tile))); + + auto tiled_hlo_instruction = tiled_hlo_instructions.back().get(); + + std::optional operands_indexing = + ComputeOutputToInputIndexing(tiled_hlo_instruction->hlo(), + /*output_id=*/0, ctx); + + if (!operands_indexing.has_value()) { + return FusionDecision{} << "Failed to compute operands indexing for " + << tiled_hlo_instruction->hlo()->ToString(); + } + + for (auto [operand, operand_indexing_map_set] : + llvm::zip(tiled_hlo_instruction->hlo()->operands(), + operands_indexing->indexing_maps)) { + CHECK_EQ(operand_indexing_map_set.size(), 1); // Crash OK + + IndexingMap operand_indexing_map = + ComposeIndexingMaps(tiled_hlo_instruction->indexing_map(), + *operand_indexing_map_set.begin()); + + auto tiled_operand_or = + get_tiled_hlo_instruction(operand, std::move(operand_indexing_map)); + + if (auto fusion_decison = + std::get_if(&tiled_operand_or)) { + return *fusion_decison; + } + + tiled_hlo_instruction->AppendOperand( + std::get(tiled_operand_or)); + } + + topological_order[tiled_hlo_instruction] = topological_order.size(); + tiled_hlo_instructions_map.emplace(key, tiled_hlo_instruction); + return tiled_hlo_instruction; + }; + + const HloInstruction* root = computation.root_instruction(); + auto tiled_root = + get_tiled_hlo_instruction(root, CreateIdentityMap(root->shape(), ctx)); + if (auto* fusion_decision = std::get_if(&tiled_root)) { + return *fusion_decision; + } + + // Order instructions in def-before-use order. + absl::c_sort(tiled_hlo_instructions, [&](const auto& i1, const auto& i2) { + return topological_order.at(i1.get()) < topological_order.at(i2.get()); + }); + + return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), ctx); +} + +std::vector SymbolicTileAnalysis::TileOffsets( + const SymbolicTiledHloInstruction& tiled_hlo) const { + CHECK(tile_parameters_.has_value()) // Crash OK + << "SetTileSizes() must be called before TileOffsets()"; + return tiled_hlo.TileOffsets(*tile_parameters_); +} + +// TODO(bchetioui): remove dependency on stride and offset parameters. +std::vector SymbolicTileAnalysis::TileSizes( + const SymbolicTiledHloInstruction& tiled_hlo) const { + CHECK(tile_parameters_.has_value()) // Crash OK + << "SetTileSizes() must be called before TileSizes()"; + return tiled_hlo.TileSizes(*tile_parameters_); +} + +std::vector SymbolicTileAnalysis::TileStrides( + const SymbolicTiledHloInstruction& tiled_hlo) const { + CHECK(tile_parameters_.has_value()) // Crash OK + << "SetTileSizes() must be called before TileStrides()"; + return tiled_hlo.TileStrides(*tile_parameters_); +} + +absl::StatusOr +SymbolicTileAnalysis::ComputeBlockIdToTileOffsetIndexing( + const SymbolicTiledHloInstruction& tiled_hlo) const { + TF_RET_CHECK(block_id_to_root_tile_offset_.has_value()) + << "SetTileSizes() must be called before " + "ComputeBlockIdToTileOffsetIndexing()"; + + IndexingMap block_id_to_tile_offset_indexing = ComposeIndexingMaps( + *block_id_to_root_tile_offset_, tiled_hlo.indexing_map()); + + // A symbol in an indexing map means that to produce on element of output, we + // need to read all elements of input in the symbol range. Since this function + // computes start of the tile, we need to substitute each symbol with its + // lower bound value. We assume here the iteration order is normalized. + // TODO(b/330906085): Support cases when tile offsets are not 0. + if (absl::c_any_of(block_id_to_tile_offset_indexing.GetSymbolBounds(), + [](const Interval& symbol_bound) { + return symbol_bound.lower != 0; + })) { + return absl::FailedPreconditionError( + absl::StrCat("Symbol lower bound is not zero. ", + block_id_to_tile_offset_indexing.ToString())); + } + + std::vector symbol_lower_bounds( + block_id_to_tile_offset_indexing.GetSymbolCount(), + mlir::getAffineConstantExpr(0, context_)); + + mlir::AffineMap simplified_affine_map = + block_id_to_tile_offset_indexing.GetAffineMap().replaceDimsAndSymbols( + /*dimReplacements=*/{}, symbol_lower_bounds, + block_id_to_tile_offset_indexing.GetDimVarsCount(), + /*numResultSyms=*/ + block_id_to_tile_offset_indexing.GetRangeVarsCount()); + + IndexingMap simplified_indexing_map = IndexingMap{ + simplified_affine_map, block_id_to_tile_offset_indexing.GetDimVars(), + block_id_to_tile_offset_indexing.GetRangeVars(), + block_id_to_tile_offset_indexing.GetRTVars()}; + + simplified_indexing_map.Simplify(GetIndexingMapForInstruction); + simplified_indexing_map.RescaleSymbols(); + simplified_indexing_map.RemoveUnusedSymbols(); + + return simplified_indexing_map; +} + +void SymbolicTileAnalysis::SetTileSizes(std::vector sizes) { + block_id_to_root_tile_offset_ = ComputeProgramIdToOutputTileIndexing( + GetRoot()->hlo()->shape().dimensions(), sizes, context_); + + // TODO(bchetioui): CHECK num parameters somehow? + tile_parameters_ = std::vector(std::move(sizes)); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h new file mode 100644 index 0000000000000..db643f924fe8a --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile_analysis.h @@ -0,0 +1,120 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_ANALYSIS_H_ +#define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/instruction_fusion.h" + +namespace xla { +namespace gpu { + +class SymbolicTileAnalysis; +using SymbolicTileAnalysisOrError = + std::variant; + +// Constructs and holds symbolic tiles for all the instructions within a +// computation. We may hold several different symbolic tiles for the same +// instruction if the instruction is indexed in several different ways in order +// to produce a single chunk of the output. In order to handle this properly, +// we store a symbolic tile for each possible path starting from the root +// instruction of the computation to the relevant instruction. +class SymbolicTileAnalysis { + public: + // Tries to construct a symbolic tile analysis from a computation. Returns + // a diagnostic if the construction fails for any reason. + static SymbolicTileAnalysisOrError AnalyzeComputation( + const HloComputation& computation, mlir::MLIRContext* ctx); + + // Evaluates the tile offsets of an instruction from the analyzed computation + // following the provided path from the root. Tile parameters must have been + // set before calling this method. + std::vector TileOffsets( + const SymbolicTiledHloInstruction& tiled_hlo) const; + // Evaluates the tile sizes of an instruction from the analyzed computation + // following the provided path from the root. Tile parameters must have been + // set before calling this method. + std::vector TileSizes( + const SymbolicTiledHloInstruction& tiled_hlo) const; + // Evaluates the tile strides of an instruction from the analyzed computation + // following the provided path from the root. Tile parameters must have been + // set before calling this method. + std::vector TileStrides( + const SymbolicTiledHloInstruction& tiled_hlo) const; + + // Computes the indexing map from block id to tile offset of the tiled HLO + // instruction. The indexing map has the following form: + // + // (block_id) -> (tile_offset0, tile_offset1, ...) + absl::StatusOr ComputeBlockIdToTileOffsetIndexing( + const SymbolicTiledHloInstruction& tiled_hlo) const; + + // Populates input tile sizes. This is a prerequisite in order to extract + // concrete values using `TileOffsets`, `TileSizes`, and `TileStrides`. + void SetTileSizes(std::vector sizes); + + // Returns the tiled root instruction. + const SymbolicTiledHloInstruction* GetRoot() const { + return tiled_hlo_instructions_.back().get(); + } + + // Returns the tiled HLO instructions in def-before-use order. + const std::vector>& + GetTiledHloInstructions() const { + return tiled_hlo_instructions_; + } + + // Return the underlying MLIRContext. + mlir::MLIRContext* GetMLIRContext() const { return context_; }; + + private: + SymbolicTileAnalysis(std::vector> + tiled_hlo_instructions, + mlir::MLIRContext* context) + : tiled_hlo_instructions_(std::move(tiled_hlo_instructions)), + context_(context) {} + + // The tiled HLO instructions in def-before-use order. + std::vector> + tiled_hlo_instructions_; + + mlir::MLIRContext* context_; + // Optionally set tile parameters. These parameters can be set by calling + // `SetTileParameters`, and correspond to the output tile for the analyzed + // computation. The order and type of parameters are as explained in the + // documentation of `SymbolicTile`. + std::optional> tile_parameters_; + + // Indexing map from block id to root tile offset. Computed from the tile + // parameters. + std::optional block_id_to_root_tile_offset_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_ANALYSIS_H_ diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc new file mode 100644 index 0000000000000..8925c46f91721 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -0,0 +1,236 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tile_analysis.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; + +class SymbolicTileAnalysisTest : public HloTestBase { + public: + bool SetAnalysis(HloModule* module) { + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation(*module->entry_computation(), + &mlir_context_); + + if (std::holds_alternative(analysis_or_error)) { + analysis_ = std::get(std::move(analysis_or_error)); + return true; + } + return false; + } + + mlir::MLIRContext mlir_context_; + std::optional analysis_; +}; + +TEST_F(SymbolicTileAnalysisTest, SimpleNormalizationDiamondIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +max { + p1 = f32[] parameter(1) + p0 = f32[] parameter(0) + ROOT m = f32[] maximum(p0, p1) +} + +ENTRY main { + p0 = f32[2,97]{1,0} parameter(0) + constant = f32[] constant(-inf) + reduce = f32[2] reduce(p0, constant), dimensions={1}, to_apply=max + broadcast = f32[2,97]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[2,97]{1,0} subtract(p0, broadcast) +})")); + + EXPECT_TRUE(SetAnalysis(module.get())); + + const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + + analysis_->SetTileSizes(/*sizes=*/{1, 10}); + + EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) + domain: + d0 in [0, 19] + )")); + + auto p0_from_subtract0 = root->operand(0); + auto p0_from_subtract1 = root->operand(1)->operand(0)->operand(0); + + EXPECT_THAT(analysis_->TileOffsets(*p0_from_subtract0), ElementsAre(0, 0)); + EXPECT_THAT(analysis_->TileSizes(*p0_from_subtract0), ElementsAre(1, 10)); + EXPECT_THAT(analysis_->TileStrides(*p0_from_subtract0), ElementsAre(1, 1)); + + EXPECT_THAT( + *analysis_->ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract0), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, (d0 mod 10) * 10) + domain: + d0 in [0, 19] + )")); + + EXPECT_THAT(analysis_->TileOffsets(*p0_from_subtract1), ElementsAre(0, 0)); + EXPECT_THAT(analysis_->TileSizes(*p0_from_subtract1), ElementsAre(1, 97)); + EXPECT_THAT(analysis_->TileStrides(*p0_from_subtract1), ElementsAre(1, 1)); + + EXPECT_THAT( + *analysis_->ComputeBlockIdToTileOffsetIndexing(*p0_from_subtract1), + MatchIndexingMap(R"( + (d0) -> (d0 floordiv 10, 0) + domain: + d0 in [0, 19] + )")); +} + +TEST_F(SymbolicTileAnalysisTest, ElementwiseDiamondCSEIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[2,97] parameter(0) + exp = f32[2,97] exponential(p0) + log = f32[2,97] log(p0) + ROOT subtract = f32[2,97] subtract(exp, log) +})")); + + EXPECT_TRUE(SetAnalysis(module.get())); + + const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + + auto p0_from_subtract0 = root->operand(0)->operand(0); + auto p0_from_subtract1 = root->operand(1)->operand(0); + + EXPECT_EQ(p0_from_subtract0, p0_from_subtract1); +} + +TEST_F(SymbolicTileAnalysisTest, TransposeOffsetIndexingIsCorrect) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[8,16,4] parameter(0) + ROOT transpose = f32[4,8,16] transpose(p0), dimensions={2,0,1} +})")); + + EXPECT_TRUE(SetAnalysis(module.get())); + + analysis_->SetTileSizes(/*sizes=*/{2, 4, 2}); + + const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + + EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root), + MatchIndexingMap(R"( + (d0) -> ((d0 floordiv 16) * 2, ((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2) + domain: + d0 in [0, 31] + )")); + + EXPECT_THAT(*analysis_->ComputeBlockIdToTileOffsetIndexing(*root->operand(0)), + MatchIndexingMap(R"( + (d0) -> (((d0 floordiv 8) mod 2) * 4, (d0 mod 8) * 2, (d0 floordiv 16) * 2) + domain: + d0 in [0, 31] + )")); +} + +TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedDot) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[1,2]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + ROOT dot = f32[1,3]{1,0} dot(p0, p1), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})")); + + EXPECT_FALSE(SetAnalysis(module.get())); +} + +TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedReshape) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[1,2]{1,0} parameter(0) + ROOT reshape = f32[2] reshape(p0) +})")); + + EXPECT_FALSE(SetAnalysis(module.get())); +} + +TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedBitcast) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[1,2]{1,0} parameter(0) + ROOT bitcast = f32[2] bitcast(p0) +})")); + + mlir::MLIRContext mlir_ctx; + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation(*module->entry_computation(), + &mlir_ctx); + EXPECT_FALSE(SetAnalysis(module.get())); +} + +TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedConcatenate) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[1,3]{1,0} parameter(0) + p1 = f32[1,3]{1,0} parameter(1) + ROOT concatenate = f32[2,3] concatenate(p0, p1), dimensions={0} +})")); + + EXPECT_FALSE(SetAnalysis(module.get())); +} + +TEST_F(SymbolicTileAnalysisTest, ComputingIndexingMapFailsWithoutTileSizes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY main { + p0 = f32[4,8]{1,0} parameter(0) + ROOT exponential = f32[4,8]{1,0} exponential(p0) +})")); + + EXPECT_TRUE(SetAnalysis(module.get())); + + const SymbolicTiledHloInstruction* root = analysis_->GetRoot(); + + EXPECT_THAT( + analysis_->ComputeBlockIdToTileOffsetIndexing(*root).status().message(), + ::testing::HasSubstr("SetTileSizes() must be called before")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/symbolic_tile_test.cc b/xla/service/gpu/model/symbolic_tile_test.cc new file mode 100644 index 0000000000000..74673bd8f4db1 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tile_test.cc @@ -0,0 +1,439 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tile.h" + +#include + +#include +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ExplainMatchResult; +using ::testing::Optional; +using ::testing::StrEq; + +MATCHER_P3(MatchSymbolicTile, offset_map_string, size_map_string, + stride_map_string, + absl::StrCat(negation + ? "equals " + : "doesn't equal symbolic tile with offset_map_ ", + offset_map_string, " and size_map_ ", size_map_string, + " and stride_map_ ", stride_map_string)) { + AffineMapPrinter printer; + return ExplainMatchResult(StrEq(offset_map_string), + printer.ToString(arg.offset_map()), + result_listener) && + ExplainMatchResult(StrEq(size_map_string), + printer.ToString(arg.size_map()), + result_listener) && + ExplainMatchResult(StrEq(stride_map_string), + printer.ToString(arg.stride_map()), + result_listener); +} + +using SymbolicTileTest = IndexingTestBase; + +TEST_F(SymbolicTileTest, CanPropagateTileFromDotOutputToInputs) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[11, 17, 19] parameter(0) + p1 = f32[11, 19, 23] parameter(1) + ROOT dot = f32[11, 17, 23] dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, 0, 0)", + "()[s0, s1, s2] -> (s0, s1, 19)", + "()[s0, s1, s2] -> (1, 1, 1)"))); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, 0, 0)", + "()[s0, s1, s2] -> (s0, 19, s2)", + "()[s0, s1, s2] -> (1, 1, 1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughTrivialReshape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[11, 17, 19] parameter(0) + ROOT reshape = f32[1, 11, 17, 19] reshape(p0) + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2, s3] -> (0, 0, 0)", + "()[s0, s1, s2, s3] -> (s1, s2, s3)", + "()[s0, s1, s2, s3] -> (1, 1, 1)"))); +} + +TEST_F(SymbolicTileTest, + CanPropagateTileThroughNonTrivialMergeReshapeFromOutputToInput) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT bitcast = f32[48,4]{1,0} bitcast(p0) + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile( + "()[s0, s1] -> (0, 0, 0, 0)", + "()[s0, s1] -> " + "(1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, s1)", + "()[s0, s1] -> (0, 1, 1, 1)"))); +} + +TEST_F(SymbolicTileTest, FailsToPropagateTileThroughNonTrivialReshape) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[12, 4, 19] parameter(0) + ROOT reshape = f32[4, 12, 19] reshape(p0) + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + std::nullopt); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughElementwiseOp) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[150] parameter(0) + p1 = f32[150] parameter(1) + ROOT add = f32[150] add(p0, p1) + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0] -> (0)", "()[s0] -> (s0)", + "()[s0] -> (1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileFromBroadcastOutputToInput) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[150] parameter(0) + ROOT broadcast = f32[157,150] broadcast(p0), dimensions={1} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1] -> (0)", "()[s0, s1] -> (s1)", + "()[s0, s1] -> (1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileFromReduceOutputToInput) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) + } + + ENTRY e { + p0 = f32[125,150] parameter(0) + c0 = f32[] constant(-inf) + ROOT reduce = f32[150] reduce(p0, c0), dimensions={0}, to_apply=max + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0] -> (0, 0)", "()[s0] -> (125, s0)", + "()[s0] -> (1, 1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughReverse) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[179] parameter(0) + ROOT reverse = f32[179] reverse(p0), dimensions={0} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0] -> (-s0 + 179)", "()[s0] -> (s0)", + "()[s0] -> (1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileFromSliceOutputToInput) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[120,142] parameter(0) + ROOT slice = f32[10,21] slice(p0), slice={[40:60:2], [20:104:4]} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1] -> (40, 20)", + "()[s0, s1] -> (s0, s1)", + "()[s0, s1] -> (2, 4)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughTranspose) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[21,10] parameter(0) + ROOT transpose = f32[10,21] transpose(p0), dimensions={1,0} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1] -> (0, 0)", + "()[s0, s1] -> (s1, s0)", + "()[s0, s1] -> (1, 1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughConcatenate) { + // TODO(b/325488844): Add additional concat test cases with constraints. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[2,5,7] parameter(0) + p1 = f32[2,11,7] parameter(1) + p2 = f32[2,17,7] parameter(2) + ROOT concat = f32[2,33,7] concatenate(p0, p1, p2), dimensions={1} + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, 0, 0)", + "()[s0, s1, s2] -> (s0, s1, s2)", + "()[s0, s1, s2] -> (1, 1, 1)"))); + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, -5, 0)", + "()[s0, s1, s2] -> (s0, s1, s2)", + "()[s0, s1, s2] -> (1, 1, 1)"))); + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[2].begin()), + Optional(MatchSymbolicTile("()[s0, s1, s2] -> (0, -16, 0)", + "()[s0, s1, s2] -> (s0, s1, s2)", + "()[s0, s1, s2] -> (1, 1, 1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughPadOpWithoutInteriorPadding) { + // TODO(b/325488844): Add pad tests with defined constraints on tile input. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = f32[4, 4] parameter(0) + p1 = f32[] parameter(1) + ROOT pad = f32[8,8] pad(p0, p1), padding=2_2_0x1_3_0 + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile("()[s0, s1] -> (-2, -1)", + "()[s0, s1] -> (s0, s1)", + "()[s0, s1] -> (1, 1)"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReshapeOfReverse) { + // A split reshape of a reverse creates a negative unit stride atop a + // floordiv. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + computation { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + reverse = f32[1,8,6,4]{3,2,1,0} reverse(p0), dimensions={1,2} + ROOT bitcast = f32[48,4]{1,0} bitcast(reverse) + } + + ENTRY e { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[48,4]{1,0} fusion(p0), kind=kLoop, calls=computation + } + )")); + + // TODO(b/331257678): the expected expressions should be simplified. + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTile( + "()[s0, s1] -> (0, -((s0 + 5) floordiv 6) + 8, " + "-(s0 - ((s0 - 1) floordiv 6) * 6) + 6, 0)", + "()[s0, s1] -> " + "(1, (s0 + 5) floordiv 6, s0 - ((s0 - 1) floordiv 6) * 6, s1)", + "()[s0, s1] -> (0, 1, 1, 1)"))); +} + +TEST_F(SymbolicTileTest, + FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) { + // TODO(b/326998704): constraints should allow us to unblock this use case. + // A slice of a split reshape creates a non-unit stride atop a floordiv. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + computation { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + bitcast = f32[48,4]{1,0} bitcast(p0) + ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + } + + ENTRY e { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[5,2]{1,0} fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + std::nullopt); +} + +TEST_F(SymbolicTileTest, + FailsGracefullyAtPropagatingTileThroughMisalignedSliceOfSplitReshape) { + // TODO(b/326998704): constraints should allow us to unblock part of this use + // case. + // TODO(b/331257678): handling correctly cases where offsets don't get + // simplified away perfectly will allow us to unblock part of this use case. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + computation { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + bitcast = f32[48,4]{1,0} bitcast(p0) + ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[20:45:5], [0:4:2]} + } + + ENTRY e { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[5,2]{1,0} fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + std::nullopt); +} + +TEST_F(SymbolicTileTest, + FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshapeOnTranspose) { + // TODO(b/326998704): constraints should allow us to unblock this use case. + // A slice of a split reshape creates a non-unit stride atop a floordiv. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + computation { + p0 = f32[1,6,8,4]{3,2,1,0} parameter(0) + transpose = f32[1,8,6,4]{3,2,1,0} transpose(p0), dimensions={0,2,1,3} + bitcast = f32[48,4]{1,0} bitcast(transpose) + ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + } + + ENTRY e { + p0 = f32[1,6,8,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[5,2]{1,0} fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + std::nullopt); +} + +TEST_F(SymbolicTileTest, + FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshapeOfReverse) { + // TODO(b/326998704): constraints should allow us to unblock this use case. + // A slice of a split reshape of a reverse creates a negative non-unit stride + // atop a floordiv. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + computation { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + reverse = f32[1,8,6,4]{3,2,1,0} reverse(p0), dimensions={1,2} + bitcast = f32[48,4]{1,0} bitcast(reverse) + ROOT slice = f32[5,2]{1,0} slice(bitcast), slice={[18:43:5], [0:4:2]} + } + + ENTRY e { + p0 = f32[1,8,6,4]{3,2,1,0} parameter(0) + ROOT fusion = f32[5,2]{1,0} fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + std::nullopt); +} + +TEST_F(SymbolicTileTest, + FailsGracefullyAtPropagatingTileThroughReductionOfConcatenation) { + // TODO(b/330906085): concatenating across a reduction dimension needs to be + // handled to unblock this. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + max_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(p0, p1) + } + + computation { + p0 = f32[10,8]{1,0} parameter(0) + p1 = f32[20,8]{1,0} parameter(1) + concatenate = f32[30,8]{1,0} concatenate(p0, p1), dimensions={0} + neg_inf = f32[] constant(-inf) + ROOT reduce = f32[8] reduce(concatenate, neg_inf), dimensions={0}, + to_apply=max_computation + } + + ENTRY e { + p0 = f32[10,8]{1,0} parameter(0) + p1 = f32[20,8]{1,0} parameter(1) + ROOT fusion = f32[8] fusion(p0, p1), kind=kLoop, calls=computation + } + )")); + + EXPECT_EQ( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[1].begin()), + std::nullopt); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc b/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc new file mode 100644 index 0000000000000..c419cb91cc36f --- /dev/null +++ b/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "xla/service/gpu/model/symbolic_tile.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { +namespace { + +using ::mlir::AffineExpr; +using ::mlir::AffineMap; +using ::mlir::SmallVector; + +std::vector EvaluateTileMap(AffineMap affine_map, + absl::Span parameters) { + CHECK_EQ(affine_map.getNumSymbols(), parameters.size()); + CHECK_EQ(affine_map.getNumDims(), 0); + + SmallVector symbol_replacements = llvm::to_vector( + llvm::map_range(parameters, [affine_map](const int64_t v) -> AffineExpr { + return mlir::getAffineConstantExpr(v, affine_map.getContext()); + })); + + AffineMap simplified_affine_map = + mlir::simplifyAffineMap(affine_map.replaceDimsAndSymbols( + /*dimReplacements=*/{}, symbol_replacements, /*numResultDims=*/0, + /*numResultSyms=*/0)); + + SmallVector results = llvm::to_vector(llvm::map_range( + simplified_affine_map.getResults(), [](AffineExpr result) -> int64_t { + return llvm::cast(result).getValue(); + })); + + return std::vector(results.begin(), results.end()); +} + +} // namespace + +std::vector SymbolicTiledHloInstruction::TileOffsets( + absl::Span tile_parameters) const { + return EvaluateTileMap(symbolic_tile_.offset_map(), tile_parameters); +} + +std::vector SymbolicTiledHloInstruction::TileSizes( + absl::Span tile_parameters) const { + return EvaluateTileMap(symbolic_tile_.size_map(), tile_parameters); +} + +std::vector SymbolicTiledHloInstruction::TileStrides( + absl::Span tile_parameters) const { + return EvaluateTileMap(symbolic_tile_.stride_map(), tile_parameters); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/symbolic_tiled_hlo_instruction.h b/xla/service/gpu/model/symbolic_tiled_hlo_instruction.h new file mode 100644 index 0000000000000..18cab47c7984a --- /dev/null +++ b/xla/service/gpu/model/symbolic_tiled_hlo_instruction.h @@ -0,0 +1,89 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILED_HLO_INSTRUCTION_H_ +#define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILED_HLO_INSTRUCTION_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile.h" + +namespace xla { +namespace gpu { + +// A node in the symbolic tiled representation of an HLO computation. During +// tiling and codegen an HLO instruction may need to be emitted multiple times +// with different tiling parameters. +class SymbolicTiledHloInstruction { + public: + SymbolicTiledHloInstruction(const HloInstruction* hlo, + IndexingMap indexing_map, + SymbolicTile symbolic_tile) + : hlo_(hlo), + indexing_map_(std::move(indexing_map)), + symbolic_tile_(std::move(symbolic_tile)) {} + + // Evaluates the tile offsets of an instruction with given tile parameters. + std::vector TileOffsets( + absl::Span tile_parameters) const; + // Evaluates the tile sizes of an instruction with given tile parameters. + std::vector TileSizes( + absl::Span tile_parameters) const; + // Evaluates the tile strides of an instruction with given tile parameters. + std::vector TileStrides( + absl::Span tile_parameters) const; + + const HloInstruction* hlo() const { return hlo_; } + const IndexingMap& indexing_map() const { return indexing_map_; } + const SymbolicTile& symbolic_tile() const { return symbolic_tile_; } + + const SymbolicTiledHloInstruction* operand(int64_t operand_id) const { + return operands_[operand_id]; + } + SymbolicTiledHloInstruction* operand(int64_t operand_id) { + return operands_[operand_id]; + } + const std::vector& operands() const { + return operands_; + } + + // Appends an operand to the end of the operand list. + void AppendOperand(SymbolicTiledHloInstruction* operand) { + operands_.push_back(operand); + } + + private: + // Pointer to the original HLO instruction. + const HloInstruction* hlo_; + + // Indexing map from the computation root to this instruction output. + IndexingMap indexing_map_; + + // Symbolic tile derived from the indexing map. + SymbolicTile symbolic_tile_; + + // Operands of the instruction in the tiled computation graph. + std::vector operands_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILED_HLO_INSTRUCTION_H_ diff --git a/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc b/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc new file mode 100644 index 0000000000000..0f8fea04493d5 --- /dev/null +++ b/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" + +#include +#include +#include + +#include +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/symbolic_tile.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; +using SymbolicTiledHloInstructionTest = HloTestBase; + +TEST_F(SymbolicTiledHloInstructionTest, TransposeTileSizesAreSupported) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +fused_computation { + p0 = f32[16,32] parameter(0) + p1 = f32[32,16] parameter(1) + transpose = f32[32,16] transpose(p0), dimensions={1,0} + ROOT subtract = f32[32,16] subtract(transpose, p1) +} + +ENTRY main { + p0 = f32[16,32] parameter(0) + p1 = f32[32,16] parameter(1) + ROOT root = f32[32,16] fusion(p0, p1), kind=kLoop, calls=fused_computation +} +)")); + + mlir::MLIRContext mlir_ctx; + auto fusion = module->entry_computation()->root_instruction(); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion); + + auto output_to_input_indexing = ComputeGroupedOutputToInputIndexing( + *fusion_adaptor, fusion_adaptor->GetRoots()[0], &mlir_ctx); + + HloInstruction* subtract = fusion->fused_expression_root(); + HloInstruction* p0 = subtract->mutable_operand(0)->mutable_operand(0); + HloInstruction* p1 = subtract->mutable_operand(1); + + // We use `fusion->operand(0)` to get indexing from the map instead of `p0`, + // because `HloFusionAdaptor` and `ComputeGroupedOutputToInputIndexing` ignore + // kParameter instructions inside the fusion and produces indexing for fusion + // operands. + IndexingMap p0_indexing = + *output_to_input_indexing[fusion->operand(0)].begin(); + std::optional p0_symbolic_tile = + SymbolicTile::FromIndexingMap(p0_indexing); + ASSERT_TRUE(p0_symbolic_tile.has_value()); + SymbolicTiledHloInstruction tiled_p0(p0, p0_indexing, *p0_symbolic_tile); + ASSERT_TRUE(p0_symbolic_tile.has_value()); + + IndexingMap p1_indexing = + *output_to_input_indexing[fusion->operand(1)].begin(); + std::optional p1_symbolic_tile = + SymbolicTile::FromIndexingMap(p1_indexing); + ASSERT_TRUE(p1_symbolic_tile.has_value()); + SymbolicTiledHloInstruction tiled_p1(p1, p1_indexing, *p1_symbolic_tile); + + std::vector output_tile_sizes = {8, 4}; + + auto p0_tile_sizes = tiled_p0.TileSizes(output_tile_sizes); + EXPECT_THAT(tiled_p0.TileSizes(output_tile_sizes), ElementsAre(4, 8)); + EXPECT_THAT(tiled_p1.TileSizes(output_tile_sizes), ElementsAre(8, 4)); +} + +} // namespace + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/tile_analysis.cc b/xla/service/gpu/model/tile_analysis.cc deleted file mode 100644 index 771ec1494e48e..0000000000000 --- a/xla/service/gpu/model/tile_analysis.cc +++ /dev/null @@ -1,982 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/tile_analysis.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/strings/str_join.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/AffineMap.h" // from @llvm-project -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -using llvm::SmallVector; -using mlir::AffineBinaryOpExpr; -using mlir::AffineDimExpr; -using mlir::AffineExpr; -using mlir::AffineExprKind; -using mlir::AffineMap; -using mlir::AffineSymbolExpr; -using mlir::getAffineBinaryOpExpr; -using mlir::getAffineConstantExpr; -using mlir::getAffineDimExpr; -using mlir::MLIRContext; - -StatusOr ComputeCwiseOpIndexing( - const HloInstruction* instr, MLIRContext* mlir_context) { - auto dims = instr->shape().dimensions(); - IndexingMap identity_map{.affine_map = AffineMap::getMultiDimIdentityMap( - dims.size(), mlir_context), - .input_dims_sizes = {}}; - - std::vector operand_indexing_maps; - int64_t operand_count = instr->operand_count(); - operand_indexing_maps.reserve(operand_count); - for (int64_t operand_id = 0; operand_id < operand_count; ++operand_id) { - operand_indexing_maps.push_back({{identity_map}, operand_id}); - } - return HloInstructionIndexing{std::move(operand_indexing_maps)}; -} - -StatusOr ComputeBroadcastOpIndexing( - const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { - auto output_dims = bcast->shape().dimensions(); - - std::vector exprs; - for (int64_t bcast_dim : bcast->dimensions()) { - exprs.push_back(getAffineDimExpr(bcast_dim, mlir_context)); - } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .input_dims_sizes = {}}; - - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; -} - -// Composes affine maps, i.e. consumer_map ∘ producer_map. -IndexingMap ComposeIndexingMaps(const IndexingMap& producer_map, - const IndexingMap& consumer_map) { - // AffineMap::compose(some_affine_map) actually computes some_affine_map ∘ - // this. - AffineMap composed_map = mlir::simplifyAffineMap( - producer_map.affine_map.compose(consumer_map.affine_map)); - - // After the composition some of the symbols might become unused, e.g. when a - // dimension was added by broadcasting as then reduced. We should remove these - // dimensions from the composed affine map and also from the resulting - // `input_dim_sizes`. - // - // For example, if there is a reduction(broadcast): - // - // param = f32[15] parameter(0) - // bcast = f32[15, 20] broadcast(p0), dimensions={0} - // reduce = f32[15, 20] reduce(bcast, init) dimensions={1} - // - // then `reduce` has (d0)[s0] -> (d0, s0) with size(s0) = 20 - // and `bcast` has (d0, d1) -> (d0) indexing map. - // - // The composition of there two maps yields (d0)[s0] -> (d0) with size(s0), - // although `s0` is not used in the mapping. In order to remove such symbols, - // we get the indices of unused symbols and remove them from the composed - // affine map and the `input_dim_sizes`. - auto unused_symbols_bit_vector = - mlir::getUnusedSymbolsBitVector({composed_map}); - composed_map = mlir::compressSymbols(composed_map, unused_symbols_bit_vector); - - // The input dims symbols in the composed map, i.e. combined - // producer_map.compose(consumer_map) are packed as [symbols(producer_map) | - // symbols(consumer_map)]. In that order we are adding the sizes for the input - // dims while skipping the symbols that are unused. - std::vector combined_sizes; - int64_t symbol_id = 0; - for (int64_t dim : llvm::concat( - producer_map.input_dims_sizes, consumer_map.input_dims_sizes)) { - if (unused_symbols_bit_vector[symbol_id++]) continue; - combined_sizes.push_back(dim); - } - return IndexingMap{.affine_map = std::move(composed_map), - .input_dims_sizes = std::move(combined_sizes)}; -} - -// Computes HloInstructionIndexing that maps the iteration space of the -// consumer's output tensor to the iteration space of the producer's inputs and -// the remaining outputs of the consumer as if the producer was fused. -// -// Example: -// -// operand1 operand2 -// | | # producer_instr_indexing edges -// producer_instr -// | # consumer_operand_indexing edge -// consumer -// -// The function has two inputs: -// -// 1. `producer_instr_indexing` is the producer's HloInstructionIndexing -// that maps the iteration space of its output tensor to the inputs of -// producers. -// 2. `consumer_operand_indexing` is the consumer's HloOperandIndexing for the -// operand that corresponds to the provided producer. -HloInstructionIndexing ComputeFusedProducerConsumerIndexing( - const HloInstructionIndexing& producer_instr_indexing, - const HloOperandIndexing& consumer_operand_indexing) { - HloInstructionIndexing fused_instr_indexing; - - // Every operand can be read 1 or more times by the consumer which also can - // have 1 or more read accesses to its operands. So, to get the composed - // indexing maps we have to compute a "cross product" here. - for (const HloOperandIndexing& producer_operand_indexing : - producer_instr_indexing.operand_indexing_maps) { - auto& composed_operand_indexing = - fused_instr_indexing.operand_indexing_maps.emplace_back(); - composed_operand_indexing.operand_id = producer_operand_indexing.operand_id; - for (const IndexingMap& producer_map : - producer_operand_indexing.indexing_maps) { - for (const IndexingMap& consumer_map : - consumer_operand_indexing.indexing_maps) { - composed_operand_indexing.indexing_maps.insert( - ComposeIndexingMaps(producer_map, consumer_map)); - } - } - fused_instr_indexing.operand_indexing_maps.push_back( - std::move(composed_operand_indexing)); - } - return fused_instr_indexing; -} - -// Composes instruction indexing maps starting at the root instruction -// until the HloParameterInstruction is found. -StatusOr ComputeFusionOpIndexing( - const HloFusionInstruction* fusion, int output_id, - MLIRContext* mlir_context) { - const HloInstruction* root = - fusion->shape().IsTuple() - ? fusion->fused_expression_root()->operand(output_id) - : fusion->fused_expression_root(); - std::queue> bfs; - TF_ASSIGN_OR_RETURN(auto root_indexing, ComputeInstructionIndexing( - root, output_id, mlir_context)); - - bfs.push(std::make_pair(root, root_indexing)); - absl::flat_hash_map> - parameter_indexing_maps; - while (!bfs.empty()) { - const auto& [instr, instr_indexing] = bfs.front(); - for (const auto& operand_indexing : instr_indexing.operand_indexing_maps) { - const HloInstruction* producer_instr = - instr->operand(operand_indexing.operand_id); - // If the producer is a fusion op parameter, store the result. - if (auto parameter = DynCast(producer_instr)) { - parameter_indexing_maps[parameter->parameter_number()].insert( - operand_indexing.indexing_maps.begin(), - operand_indexing.indexing_maps.end()); - continue; - } - TF_ASSIGN_OR_RETURN(auto producer_instr_indexing, - ComputeInstructionIndexing( - producer_instr, /*output_id=*/0, mlir_context)); - bfs.push(std::make_pair(producer_instr, - ComputeFusedProducerConsumerIndexing( - producer_instr_indexing, operand_indexing))); - } - bfs.pop(); - } - HloInstructionIndexing fusion_indexing; - for (const auto& [operand_id, maps] : parameter_indexing_maps) { - fusion_indexing.operand_indexing_maps.push_back({maps, operand_id}); - } - return fusion_indexing; -} - -StatusOr ComputeDotOpIndexing( - const HloDotInstruction* dot, MLIRContext* mlir_context) { - CHECK_NE(dot, nullptr); - const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); - absl::Span lhs_contracting_dims( - dim_numbers.lhs_contracting_dimensions()); - absl::Span rhs_contracting_dims = - dim_numbers.rhs_contracting_dimensions(); - - absl::Span lhs_batch_dims = dim_numbers.lhs_batch_dimensions(); - absl::Span rhs_batch_dims = dim_numbers.rhs_batch_dimensions(); - - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - // According to the StableHLO specification, the dimensions of the output - // shape are ordered as follows: - // lhs_batch_dims | lhs_non_contracting_dims | rhs_non_contracting_dims - SmallVector lhs_exprs(lhs_shape.rank()); - SmallVector rhs_exprs(rhs_shape.rank()); - int64_t output_dim_id = 0; - - // lhs_batch_dims - for (auto [lhs_batch_dim, rhs_batch_dim] : - llvm::zip(lhs_batch_dims, rhs_batch_dims)) { - AffineExpr output_dim_expr = getAffineDimExpr(output_dim_id, mlir_context); - lhs_exprs[lhs_batch_dim] = output_dim_expr; - rhs_exprs[rhs_batch_dim] = output_dim_expr; - ++output_dim_id; - } - - // lhs_non_contracting_dims - TF_ASSIGN_OR_RETURN( - std::vector lhs_non_contracting_dims, - GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_contracting_dims)); - - for (int64_t lhs_non_contracting_dim : lhs_non_contracting_dims) { - lhs_exprs[lhs_non_contracting_dim] = - getAffineDimExpr(output_dim_id++, mlir_context); - } - - // rhs_non_contracting_dims - TF_ASSIGN_OR_RETURN( - std::vector rhs_non_contracting_dims, - GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_contracting_dims)); - - for (int64_t rhs_non_contracting_dim : rhs_non_contracting_dims) { - rhs_exprs[rhs_non_contracting_dim] = - getAffineDimExpr(output_dim_id++, mlir_context); - } - - int64_t input_dim_id = 0; - std::vector input_dim_sizes; - input_dim_sizes.reserve(lhs_contracting_dims.size()); - - for (auto [lhs_contracting_dim, rhs_contracting_dim] : - llvm::zip(lhs_contracting_dims, rhs_contracting_dims)) { - AffineExpr input_dim_expr = getAffineSymbolExpr(input_dim_id, mlir_context); - lhs_exprs[lhs_contracting_dim] = input_dim_expr; - rhs_exprs[rhs_contracting_dim] = input_dim_expr; - ++input_dim_id; - - // LHS and RHS contracting dimensions must match pairwise, and we therefore - // need only populate a single input_dim_sizes vector. - input_dim_sizes.push_back(lhs_shape.dimensions(lhs_contracting_dim)); - } - - IndexingMap lhs_indexing_map{ - .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), - lhs_exprs, mlir_context), - .input_dims_sizes = input_dim_sizes}; - - IndexingMap rhs_indexing_map{ - .affine_map = AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), - rhs_exprs, mlir_context), - .input_dims_sizes = input_dim_sizes}; - - return HloInstructionIndexing{ - {HloOperandIndexing{.indexing_maps = {std::move(lhs_indexing_map)}, - .operand_id = 0}, - HloOperandIndexing{.indexing_maps = {std::move(rhs_indexing_map)}, - .operand_id = 1}}}; -} - -StatusOr ComputeReduceOpIndexing( - const HloReduceInstruction* reduce, int output_id, - MLIRContext* mlir_context) { - absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), - reduce->dimensions().end()); - - const Shape& input_shape = reduce->operand(output_id)->shape(); - const Shape& output_shape = reduce->shape().IsTuple() - ? ShapeUtil::GetSubshape(reduce->shape(), {0}) - : reduce->shape(); - - std::vector input_dims_sizes; - int64_t reduced_dim_id = 0; - int64_t output_dim_id = 0; - std::vector exprs; - for (auto [input_dim_id, input_dim] : - llvm::enumerate(input_shape.dimensions())) { - if (reduce_dims_ids.contains(input_dim_id)) { - exprs.push_back(getAffineSymbolExpr(reduced_dim_id++, mlir_context)); - input_dims_sizes.push_back(input_dim); - continue; - } - exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); - } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), - exprs, mlir_context), - .input_dims_sizes = std::move(input_dims_sizes)}; - - std::vector operand_indexing_maps; - int64_t input_count = reduce->input_count(); - operand_indexing_maps.reserve(input_count); - for (int64_t input_id = 0; input_id < input_count; ++input_id) { - operand_indexing_maps.push_back({{indexing_map}, input_id}); - } - return HloInstructionIndexing{std::move(operand_indexing_maps)}; -} -// Computes strides for a shape. -std::vector ComputeStrides(absl::Span dims) { - int rank = static_cast(dims.size()); - std::vector strides(rank, 1); - for (int i = rank - 2; i >= 0; --i) { - strides[i] = dims[i + 1] * strides[i + 1]; - } - return strides; -} - -// Computes 1D index given a shape and N-d indexing expressions. -AffineExpr LinearizeShape(absl::Span dims, - absl::Span dimension_exprs, - MLIRContext* mlir_context) { - AffineExpr linear_index = getAffineConstantExpr(0, mlir_context); - - auto strides = ComputeStrides(dims); - for (auto [stride, dimension_expr] : llvm::zip(strides, dimension_exprs)) { - linear_index = getAffineBinaryOpExpr( - AffineExprKind::Add, linear_index, - getAffineBinaryOpExpr(AffineExprKind::Mul, - getAffineConstantExpr(stride, mlir_context), - dimension_expr)); - } - return linear_index; -} - -// Computes N-d indexing expressions given a linear index and a shape. -std::vector DelinearizeIndex(absl::Span dims, - AffineExpr linear_index, - MLIRContext* mlir_context) { - std::vector multi_index; - multi_index.reserve(dims.size()); - - AffineExpr remainder = linear_index; - for (int64_t stride : ComputeStrides(dims)) { - AffineExpr stride_expr = getAffineConstantExpr(stride, mlir_context); - multi_index.push_back(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, - remainder, stride_expr)); - remainder = - getAffineBinaryOpExpr(AffineExprKind::Mod, remainder, stride_expr); - } - return multi_index; -} - -// Computes indexing for "minimal" reshapes, i.e. reshapes that cannot be -// represented by a series of composed reshapes, i.e. when there are no -// subshapes in input and output that have the same number of elements. -// For example, [8, 4] -> [8, 2, 2] is not a minimal reshape, it has matching -// subshapes [8] -> [8] and [4] -> [2, 2]. -// -// There are only 4 types of "minimal" reshapes considers only 4 cases: -// 1. Dimension is not changed, e.g. [8] -> [8] -// 2. Dimension is expanded, e.g. [8] -> [4, 2] -// 3. Dimension is collapsed, e.g. [4, 2] -> [8] -// 4. Dimension is collapsed and expanded, e.g. [8, 16] -> [4, 32] -// -// The function computes indexing maps for these 4 cases, i.e. considers given -// input/output shapes and checks if the shapes are the same, expanded or -// collapsed. Otherwise, performs linearization/delinearization. -void ComputeMinimalReshapeIndexing( - absl::Span input_dims, absl::Span output_dims, - absl::Span output_dims_exprs, - std::vector* exprs, MLIRContext* mlir_context) { - // The shape does not change. - if (input_dims.size() == 1 && output_dims.size() == 1) { - absl::c_copy(output_dims_exprs, std::back_inserter(*exprs)); - return; - } - // Expand shape. - if (input_dims.size() == 1) { - exprs->push_back( - LinearizeShape(output_dims, output_dims_exprs, mlir_context)); - return; - } - // Collapse shape. - if (output_dims.size() == 1) { - auto multi_index = - DelinearizeIndex(input_dims, output_dims_exprs.front(), mlir_context); - absl::c_copy(multi_index, std::back_inserter(*exprs)); - return; - } - // Generic case. - AffineExpr linear_index = - LinearizeShape(output_dims, output_dims_exprs, mlir_context); - auto multi_index = DelinearizeIndex(input_dims, linear_index, mlir_context); - absl::c_copy(multi_index, std::back_inserter(*exprs)); -} - -// Scans input and output shapes from left to right in an attempt to find -// subshapes with the same number of elements and then computes indexing map for -// every pair of subshapes. -// -// Example: -// p0 = f32[4, 8, 12] parameter(0) -// reshape = f32[32, 3, 4] reshape(p0) -// -// This reshape can be represented as a composition of two reshapes. -// The first reshape collapses dimensions first two input dimensions [4, 8] onto -// the output dimension [32]. -// The second reshape expands the input dimension [12] into two output -// dimensions [3, 4]. -// This is an optimization that allows us to construct simpler affine maps, -// otherwise we would need to linearize/delinearize even some of the simpler -// cases. -IndexingMap ComputeReshapeIndexingMap(absl::Span input_dims, - absl::Span output_dims, - MLIRContext* mlir_context) { - std::vector exprs; - - size_t input_rank = input_dims.size(); - size_t output_rank = output_dims.size(); - std::vector output_dims_exprs; - - // Find subshapes with the same element count and compute indexing for them. - int64_t input_num_elements = 1; - int64_t output_num_elements = 1; - std::vector input_subshape, output_subshape; - size_t input_dim_id = 0, output_dim_id = 0; - while (input_dim_id < input_rank || output_dim_id < output_rank || - !input_subshape.empty()) { - if (input_dim_id < input_rank && - (input_subshape.empty() || input_num_elements < output_num_elements || - input_dims[input_dim_id] == 1)) { - input_num_elements *= input_dims[input_dim_id]; - input_subshape.push_back(input_dims[input_dim_id]); - ++input_dim_id; - continue; - } - if (output_dim_id < output_rank && - (output_subshape.empty() || output_num_elements < input_num_elements || - output_dims[output_dim_id] == 1)) { - output_num_elements *= output_dims[output_dim_id]; - output_subshape.push_back(output_dims[output_dim_id]); - output_dims_exprs.push_back( - getAffineDimExpr(output_dim_id, mlir_context)); - ++output_dim_id; - continue; - } - ComputeMinimalReshapeIndexing(input_subshape, output_subshape, - output_dims_exprs, &exprs, mlir_context); - input_num_elements = 1; - output_num_elements = 1; - input_subshape.clear(); - output_subshape.clear(); - output_dims_exprs.clear(); - } - return IndexingMap{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .input_dims_sizes = {}}; -} - -StatusOr ComputeReshapeOpIndexing( - const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { - auto input_dims = reshape->operand(0)->shape().dimensions(); - auto output_dims = reshape->shape().dimensions(); - IndexingMap reshape_indexing_map = - ComputeReshapeIndexingMap(input_dims, output_dims, mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; -} - -StatusOr ComputeReverseOpIndexing( - const HloReverseInstruction* reverse, MLIRContext* mlir_context) { - absl::flat_hash_set reverse_dims(reverse->dimensions().begin(), - reverse->dimensions().end()); - auto output_dims = reverse->shape().dimensions(); - - std::vector exprs; - for (auto [output_dim_id, output_dim] : llvm::enumerate(output_dims)) { - auto dim_expr = getAffineDimExpr(output_dim_id, mlir_context); - if (!reverse_dims.contains(output_dim_id)) { - exprs.push_back(dim_expr); - continue; - } - auto dim_size = getAffineConstantExpr(output_dim, mlir_context); - auto neg_dim_expr = getAffineBinaryOpExpr( - AffineExprKind::Mul, getAffineConstantExpr(-1, mlir_context), dim_expr); - exprs.push_back( - getAffineBinaryOpExpr(AffineExprKind::Add, neg_dim_expr, dim_size)); - } - - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .input_dims_sizes = {}}; - - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; -} - -StatusOr ComputeSliceOpIndexing( - const HloSliceInstruction* slice, MLIRContext* mlir_context) { - auto output_dims = slice->shape().dimensions(); - - std::vector exprs; - for (int64_t dim = 0; dim < output_dims.size(); ++dim) { - AffineExpr offset = - getAffineConstantExpr(slice->slice_starts()[dim], mlir_context); - AffineExpr stride = - getAffineConstantExpr(slice->slice_strides()[dim], mlir_context); - AffineExpr dim_expr = getAffineDimExpr(dim, mlir_context); - - AffineExpr mul = - getAffineBinaryOpExpr(AffineExprKind::Mul, stride, dim_expr); - exprs.push_back(getAffineBinaryOpExpr(AffineExprKind::Add, offset, mul)); - } - IndexingMap indexing_map{ - .affine_map = AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, - mlir_context), - .input_dims_sizes = {}}; - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(indexing_map)}, .operand_id = 0}}}; -} - -IndexingMap ComputeTransposeIndexingMap(absl::Span permutation, - MLIRContext* mlir_context) { - auto forward_permutation = AffineMap::getPermutationMap( - std::vector(permutation.begin(), permutation.end()), - mlir_context); - return IndexingMap{ - .affine_map = mlir::inversePermutation(forward_permutation), - .input_dims_sizes = {}}; -} - -StatusOr ComputeTransposeOpIndexing( - const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - IndexingMap transpose_indexing_map = - ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(transpose_indexing_map)}, .operand_id = 0}}}; -} - -StatusOr ComputeBitcastOpIndexing( - const HloInstruction* bitcast, MLIRContext* mlir_context) { - const Shape& input_shape = bitcast->operand(0)->shape(); - const Shape& output_shape = bitcast->shape(); - ShapeUtil::BitcastDecomposition decomposed_bitcast = - ShapeUtil::DecomposeBitcast(input_shape, output_shape); - - if (std::holds_alternative( - decomposed_bitcast)) { - auto permutation = ShapeUtil::DeduceTransposeDimensionsForBitcast( - input_shape, output_shape); - CHECK(permutation.has_value()) - << "Failed to deduce permutation for a bitcast."; - IndexingMap transpose_indexing_map = - ComputeTransposeIndexingMap(*permutation, mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(transpose_indexing_map)}, - .operand_id = 0}}}; - } - if (std::holds_alternative( - decomposed_bitcast)) { - IndexingMap reshape_indexing_map = ComputeReshapeIndexingMap( - input_shape.dimensions(), output_shape.dimensions(), mlir_context); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(reshape_indexing_map)}, .operand_id = 0}}}; - } - // `trt` stands for transpose-reshape-transpose decomposition of bitcast. - auto trt = std::get(decomposed_bitcast); - IndexingMap transpose_map_1 = - ComputeTransposeIndexingMap(trt.transpose1_dims, mlir_context); - IndexingMap reshape_map = - ComputeReshapeIndexingMap(trt.transpose1_shape.dimensions(), - trt.reshape_shape.dimensions(), mlir_context); - IndexingMap transpose_map_2 = - ComputeTransposeIndexingMap(trt.transpose2_dims, mlir_context); - IndexingMap composed_map = ComposeIndexingMaps( - ComposeIndexingMaps(transpose_map_1, reshape_map), transpose_map_2); - return HloInstructionIndexing{{HloOperandIndexing{ - .indexing_maps = {std::move(composed_map)}, .operand_id = 0}}}; -} - -template -std::string ToStringImpl(const T& value) { - std::string s; - std::stringstream ss(s); - ss << value; - return ss.str(); -} - -struct IndexingMapSimplifier { - struct Bounds { - int64_t lower; - int64_t upper; - }; - - Bounds BoundsInclusive(AffineExpr expr) { - auto bound = bounds.find(expr); - if (bound != bounds.end()) return bound->second; - - switch (expr.getKind()) { - case AffineExprKind::Constant: { - int64_t value = mlir::cast(expr).getValue(); - CHECK_GE(value, 0); - return bounds[expr] = {value, value}; - } - case AffineExprKind::DimId: { - int64_t size = - dimension_sizes[mlir::cast(expr).getPosition()]; - return bounds[expr] = {0, size - 1}; - } - case AffineExprKind::SymbolId: { - int64_t size = - symbol_sizes[mlir::cast(expr).getPosition()]; - return bounds[expr] = {0, size - 1}; - } - default: - auto binary_op = mlir::dyn_cast(expr); - CHECK(binary_op); - auto lhs = BoundsInclusive(binary_op.getLHS()); - auto rhs = BoundsInclusive(binary_op.getRHS()); - - auto& result = bounds[expr]; - switch (expr.getKind()) { - case AffineExprKind::Add: - return result = {lhs.lower + rhs.lower, lhs.upper + rhs.upper}; - case AffineExprKind::Mul: - return result = {lhs.lower * rhs.lower, lhs.upper * rhs.upper}; - case AffineExprKind::Mod: { - CHECK_EQ(rhs.lower, rhs.upper) << "RHS of mod must be a constant"; - int64_t m = rhs.lower; - if (lhs.upper < m) { - return result = lhs; - } - return result = {0, m - 1}; - } - case AffineExprKind::FloorDiv: { - CHECK_EQ(rhs.lower, rhs.upper) - << "RHS of floor_div must be a constant"; - int64_t d = rhs.lower; - return result = {lhs.lower / d, lhs.upper / d}; - } - default: - // We don't use ceildiv, so we don't support it. - LOG(FATAL) << "Unsupported expression"; - } - } - } - - // Simplifier for mod. - // - Rewrites (a * 100 + ...) % 100 to (...) % 100 - // - Rewrites a % b to a if a is known to be less than b. - AffineExpr RewriteMod(AffineBinaryOpExpr mod) { - auto lhs_simplified = SimplifyOnce(mod.getLHS()); - - auto lhs = BoundsInclusive(lhs_simplified); - auto rhs = BoundsInclusive(mod.getRHS()); - - // a % b where b is always larger than a? - if (lhs.upper < rhs.lower) return lhs_simplified; - - // The logic below assumes we have a constant RHS. - if (rhs.lower != rhs.upper) return mod; - int64_t m = rhs.lower; - - auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { - if (expr.getKind() != AffineExprKind::Mul) { - return true; - } - - auto mul_rhs = - BoundsInclusive(mlir::cast(expr).getRHS()); - bool remove = mul_rhs.lower == mul_rhs.upper && (mul_rhs.lower % m) == 0; - return !remove; // We keep it if we don't remove it! - }); - - // If we weren't able to remove or simplify anything, return the original - // expression. - if (new_lhs == mod.getLHS()) { - return mod; - } - // If we removed everything, return 0. - if (!new_lhs) { - return getAffineConstantExpr(0, mlir_context); - } - // Otherwise, return new_sum % m. - return getAffineBinaryOpExpr(AffineExprKind::Mod, new_lhs, mod.getRHS()); - } - - // Simplifier for floordiv. - // - Rewrites (a * 100 + ...) / 100 to a + (...) / 100 - // - Rewrites a / 100 to 0 when a is known to be less than 100. - AffineExpr RewriteFloorDiv(AffineBinaryOpExpr div) { - auto lhs_simplified = SimplifyOnce(div.getLHS()); - auto lhs = BoundsInclusive(lhs_simplified); - auto rhs = BoundsInclusive(div.getRHS()); - - if (lhs.upper < rhs.lower) { - return getAffineConstantExpr(0, mlir_context); - } - - // The logic below assumes we have a constant RHS. - if (rhs.lower != rhs.upper) return div; - int64_t d = rhs.lower; - - AffineExpr extracted = getAffineConstantExpr(0, mlir_context); - auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) { - if (auto multiplier = GetConstantRhsMultiplier(expr)) { - // (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep - // one x, but we currently have no reason to do that. - if (*multiplier % d != 0) return true; - int64_t factor = *multiplier / d; - extracted = getAffineBinaryOpExpr( - AffineExprKind::Add, extracted, - getAffineBinaryOpExpr(AffineExprKind::Mul, - cast(expr).getLHS(), - getAffineConstantExpr(factor, mlir_context))); - // Remove from dividend. - return false; - } - - // Not a constant multiplier, keep in dividend. - return true; - }); - - // If we removed everything, skip the div. - if (!new_dividend) return extracted; - // If we removed nothing, return the original division. - if (extracted == getAffineConstantExpr(0, mlir_context) && - new_dividend == div.getLHS()) { - return div; - } - - return getAffineBinaryOpExpr( - AffineExprKind::Add, extracted, - getAffineBinaryOpExpr(AffineExprKind::FloorDiv, new_dividend, - div.getRHS())); - } - - std::optional GetConstantRhsMultiplier(AffineExpr expr) { - if (expr.getKind() != AffineExprKind::Mul) return std::nullopt; - auto bound = BoundsInclusive(mlir::cast(expr).getRHS()); - if (bound.lower != bound.upper) return std::nullopt; - return bound.lower; - } - - AffineExpr RewriteSumIf(AffineExpr expr, - const std::function& pred) { - if (expr.getKind() == AffineExprKind::Add) { - auto add = mlir::dyn_cast(expr); - auto lhs = RewriteSumIf(add.getLHS(), pred); - auto rhs = RewriteSumIf(add.getRHS(), pred); - if (lhs == add.getLHS() && rhs == add.getRHS()) { - return add; - } - if (lhs && rhs) { - return getAffineBinaryOpExpr(AffineExprKind::Add, lhs, rhs); - } - return lhs ? lhs : (rhs ? rhs : nullptr); - } - return pred(expr) ? expr : nullptr; - } - - // Attempts to simplify the expression, but doesn't attempt to simplify the - // result further. - AffineExpr SimplifyOnce(AffineExpr expr) { - switch (expr.getKind()) { - case AffineExprKind::Mul: - case AffineExprKind::Add: { - auto binop = mlir::cast(expr); - auto lhs = SimplifyOnce(binop.getLHS()); - auto rhs = SimplifyOnce(binop.getRHS()); - if (lhs == binop.getLHS() && rhs == binop.getRHS()) { - return expr; - } - return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs); - } - case AffineExprKind::Mod: - return RewriteMod(cast(expr)); - case AffineExprKind::FloorDiv: - return RewriteFloorDiv(cast(expr)); - default: - return expr; - } - } - - // Simplifies the expression as much as possible. - AffineExpr Simplify(AffineExpr expr) { - while (true) { - auto simplified = SimplifyOnce(expr); - if (simplified == expr) return expr; - expr = simplified; - } - } - - MLIRContext* mlir_context; - absl::Span dimension_sizes; - absl::Span symbol_sizes; - llvm::DenseMap bounds{}; -}; - -} // namespace - -bool IndexingMap::Simplify(absl::Span dimension_sizes) { - IndexingMapSimplifier simplifier{affine_map.getContext(), dimension_sizes, - input_dims_sizes}; - std::vector results; - bool any_changed = false; - for (auto expr : affine_map.getResults()) { - auto simplified = simplifier.Simplify(expr); - any_changed |= simplified != expr; - results.push_back(simplified); - } - - if (!any_changed) { - return false; - } - - affine_map = - AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(), - results, affine_map.getContext()); - return true; -} - -bool HloOperandIndexing::Simplify(absl::Span dimension_sizes) { - std::vector to_remove; - std::vector to_add; - for (auto map : indexing_maps) { - to_remove.push_back(map); - if (map.Simplify(dimension_sizes)) { - to_add.push_back(map); - } else { - to_remove.pop_back(); - } - } - for (auto& map : to_remove) { - indexing_maps.erase(map); - } - for (auto& map : to_add) { - indexing_maps.insert(map); - } - return !to_remove.empty(); -} - -bool HloInstructionIndexing::Simplify( - absl::Span dimension_sizes) { - bool any_simplified = false; - for (auto& operand_indexing : operand_indexing_maps) { - any_simplified |= operand_indexing.Simplify(dimension_sizes); - } - return any_simplified; -} - -std::string ToString(const AffineMap& affine_map) { - std::string s; - llvm::raw_string_ostream ss(s); - affine_map.print(ss); - return s; -} - -bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { - return lhs.affine_map == rhs.affine_map && - lhs.input_dims_sizes == rhs.input_dims_sizes; -} - -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { - out << ToString(indexing_map.affine_map) << " with sizes " - << absl::StrJoin(indexing_map.input_dims_sizes, ", ") << "\n"; - return out; -} - -std::ostream& operator<<(std::ostream& out, - const HloOperandIndexing& operand_indexing) { - out << "operand id = " << operand_indexing.operand_id << ' '; - for (const auto& map : operand_indexing.indexing_maps) { - out << map; - } - return out; -} - -std::ostream& operator<<(std::ostream& out, - const HloInstructionIndexing& instr_indexing) { - for (const auto& operand_map : instr_indexing.operand_indexing_maps) { - out << operand_map; - } - return out; -} - -std::string IndexingMap::ToString() const { return ToStringImpl(*this); } - -std::string HloOperandIndexing::ToString() const { return ToStringImpl(*this); } - -std::string HloInstructionIndexing::ToString() const { - return ToStringImpl(*this); -} - -StatusOr ComputeInstructionIndexing( - const HloInstruction* instr, int output_id, MLIRContext* mlir_context) { - if (HloInstruction::IsOpElementwise(instr->opcode())) { - return ComputeCwiseOpIndexing(instr, mlir_context); - } - if (auto bcast = DynCast(instr)) { - return ComputeBroadcastOpIndexing(bcast, mlir_context); - } - if (instr->opcode() == HloOpcode::kBitcast) { - return ComputeBitcastOpIndexing(instr, mlir_context); - } - if (auto dot = DynCast(instr)) { - return ComputeDotOpIndexing(dot, mlir_context); - } - if (auto fusion = DynCast(instr)) { - return ComputeFusionOpIndexing(fusion, output_id, mlir_context); - } - if (auto reduce = DynCast(instr)) { - return ComputeReduceOpIndexing(reduce, output_id, mlir_context); - } - if (auto reshape = DynCast(instr)) { - return ComputeReshapeOpIndexing(reshape, mlir_context); - } - if (auto reverse = DynCast(instr)) { - return ComputeReverseOpIndexing(reverse, mlir_context); - } - if (auto slice = DynCast(instr)) { - return ComputeSliceOpIndexing(slice, mlir_context); - } - if (auto transpose = DynCast(instr)) { - return ComputeTransposeOpIndexing(transpose, mlir_context); - } - return InvalidArgument("Unsupported instruction type"); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/model/tile_analysis.h b/xla/service/gpu/model/tile_analysis.h deleted file mode 100644 index a8bed45911697..0000000000000 --- a/xla/service/gpu/model/tile_analysis.h +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MODEL_TILE_ANALYSIS_H_ -#define XLA_SERVICE_GPU_MODEL_TILE_ANALYSIS_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "llvm/ADT/Hashing.h" -#include "mlir/IR/AffineMap.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/statusor.h" - -namespace xla { -namespace gpu { - -// Contains an affine map with N dimension expressions and M symbols: -// (d0, ..., d_{N - 1})[s_0, ..., s_{M - 1}] -> f(d_i, s_j) -// Dimensions d_i correspond to the iteration space of the output tensor. Some -// or all of the dimensions of the input operands can be expressed as a function -// of dimensions of output. For example, for broadcasts and cwise ops all -// dimensions of the inputs are covered by the output dimensions. -// Symbols s_j correspond to the dimensions that are present ONLY in inputs. -// `input_dims_sizes` is an array that holds the upper bounds for the iteration -// sizes for every input-only dimension. Note, that the sizes have upper -// bounds only and the lower bounds are always 0, since we can encode the -// offsets in the affine map. The sizes for the output dimensions can be deduced -// from the shape of the output tensor. -// -// Example: -// -// 1. Indexing map for the input of the following reduction -// ``` -// p0 = f32[150, 20, 10, 50] parameter(0) -// reduce = f32[150, 10] reduce(p0, p0_init), dimensions={3, 1} -// ``` -// can be written as `(d0, d1)[s0, s1] -> (d0, s0, d1, s1)` with the input -// dimensions sizes `[/*s0 size=*/20, /*s1 size=*/50]`. -// -// 2. Indexing map for the input of the reverse op -// ``` -// %p0 = f32[1, 17, 9, 9] parameter(0) -// reverse = f32[1, 17, 9, 9] reverse(%p0), dimensions={1, 2} -// ``` -// can be written as `(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)` with the -// empty 'input_dims_sizes`, because there are no dimensions in the input that -// could not be expressed via dimensions of the output. -struct IndexingMap { - std::string ToString() const; - // Returns true if the map was simplified. - bool Simplify(absl::Span dimension_sizes); - - mlir::AffineMap affine_map; - std::vector input_dims_sizes; -}; -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); -bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); - -template -H AbslHashValue(H h, const IndexingMap& indexing_map) { - llvm::hash_code affine_map_hash = llvm::hash_combine(indexing_map.affine_map); - return H::combine(std::move(h), static_cast(affine_map_hash)); -} - -// Contains 1 or more indexing maps for the `operand_id`. There are cases, when -// the same input operand is read multiple times in various ways. Especially, it -// happens a lot in fusion ops. -struct HloOperandIndexing { - std::string ToString() const; - - // Returns true if the indexing was simplified. - bool Simplify(absl::Span dimension_sizes); - - absl::flat_hash_set indexing_maps; - int64_t operand_id; -}; -std::ostream& operator<<(std::ostream& out, - const HloOperandIndexing& operand_indexing); - -// Contains indexing maps for all N-dimensional tensor input operands that -// correspond to a particular output. -struct HloInstructionIndexing { - std::string ToString() const; - - // Returns true if the indexing was simplified. - bool Simplify(absl::Span dimension_sizes); - - std::vector operand_indexing_maps; -}; -std::ostream& operator<<(std::ostream& out, - const HloInstructionIndexing& instr_indexing); - -std::string ToString(const mlir::AffineMap& affine_map); - -// Computes indexing maps for all input operands necessary to compute an element -// of the `output_id` instruction output. -StatusOr ComputeInstructionIndexing( - const HloInstruction* instr, int output_id, - mlir::MLIRContext* mlir_context); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_MODEL_TILE_ANALYSIS_H_ diff --git a/xla/service/gpu/model/tile_analysis_test.cc b/xla/service/gpu/model/tile_analysis_test.cc deleted file mode 100644 index 7dadcc852974c..0000000000000 --- a/xla/service/gpu/model/tile_analysis_test.cc +++ /dev/null @@ -1,729 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/tile_analysis.h" - -#include - -#include -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/statusor.h" -#include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -using ::testing::ElementsAre; -using ::testing::ElementsAreArray; -using ::testing::Eq; -using ::testing::ExplainMatchResult; -using ::testing::HasSubstr; -using ::testing::PrintToString; -using ::testing::UnorderedElementsAre; - -MATCHER_P2(MatchIndexingMap, affine_map_string, input_dims_sizes, - absl::StrCat(negation ? "equals " : "doesn't equal ", "affine map ", - affine_map_string, " with input dim sizes ", - PrintToString(input_dims_sizes))) { - return ExplainMatchResult(HasSubstr(affine_map_string), - ToString(arg.affine_map), result_listener) && - ExplainMatchResult(ElementsAreArray(input_dims_sizes), - arg.input_dims_sizes, result_listener); -} - -MATCHER_P2(MatchOperandIndexing, operand_id, indexing_map_matchers, "") { - return ExplainMatchResult(Eq(operand_id), arg.operand_id, result_listener) && - ExplainMatchResult(indexing_map_matchers, arg.indexing_maps, - result_listener); -} - -class TileAnalysisTest : public HloTestBase { - public: - StatusOr GetIndexingMapsForEntryComputation( - absl::string_view hlo_string, int operand_id = 0) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - - return ComputeInstructionIndexing(root, operand_id, &mlir_context_); - } - mlir::MLIRContext mlir_context_; -}; - -TEST_F(TileAnalysisTest, ElementwiseOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[10, 20] parameter(0) - p1 = f32[10, 20] parameter(1) - ROOT add0 = f32[10, 20] add(p0, p1) - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, BitcastIsReshape) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[4, 32] parameter(0) - ROOT bitcast = f32[4, 8, 4] bitcast(p0) - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1 * 4 + d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, BitcastIsTranspose) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[3, 12288, 6, 128] parameter(0) - ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, BitcastIsTransposeReshapeTranspose) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[16, 17, 3] parameter(0) - ROOT bitcast = f32[51, 16] {0, 1} bitcast(p0) - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1) -> (d1, d0 floordiv 3, d0 mod 3)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, BroadcastOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[20] parameter(0) - ROOT bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d1)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithSingleBinaryOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[100] parameter(0) - p1 = f32[100] parameter(1) - ROOT a0 = f32[100] add(p0, p1) - } - ENTRY e { - p0 = f32[100] parameter(0) - p1 = f32[100] parameter(1) - ROOT fusion = f32[100] fusion(p0, p1), kind=kLoop, calls=f - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - UnorderedElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpTensorPlusTransposedTensor) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[1000, 1000] parameter(0) - transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0} - ROOT a0 = f32[1000, 1000] add(p0, transpose_p0) - } - ENTRY e { - p0 = f32[1000,1000] parameter(0) - ROOT fusion = f32[1000,1000] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, - UnorderedElementsAre( - MatchIndexingMap("(d0, d1) -> (d1, d0)", std::vector{}), - MatchIndexingMap("(d0, d1) -> (d0, d1)", std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionExponentialDuplication) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule test_module - ENTRY entry_computation { - p0 = f32[4] parameter(0) - p1 = f32[4] parameter(1) - add0 = f32[4] add(p0, p1) - slice1.0 = f32[3] slice(add0), slice={[0:3]} - slice1.1 = f32[3] slice(add0), slice={[1:4]} - add1 = f32[3]{0} add(slice1.0, slice1.1) - slice2.0 = f32[2] slice(add1), slice={[0:2]} - slice2.1 = f32[2] slice(add1), slice={[1:3]} - ROOT add2 = f32[2] add(slice2.0, slice2.1) - })")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0) -> (d0)", std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReduceOfReduce) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - max { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT max = f32[] maximum(p0, p1) - } - f { - p0 = f32[150, 20, 10, 50] parameter(0) - p0_init = f32[] parameter(1) - reduce_1 = f32[20, 10] reduce(p0, p0_init), - dimensions={0, 3}, to_apply=max - ROOT reduce_2 = f32[10] reduce(reduce_1, p0_init), - dimensions={0}, to_apply=max - } - ENTRY e { - p0 = f32[150, 20, 10, 50] parameter(0) - p0_init = f32[] constant(-inf) - ROOT fusion = f32[10] fusion(p0, p0_init), kind=kLoop, calls=f - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0)[s0, s1, s2] -> (s0, s2, d0, s1)", - std::vector{150, 50, 20}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReduceOfBroadcast) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - max { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT max = f32[] maximum(p0, p1) - } - f { - p0 = f32[15, 20] parameter(0) - p0_init = f32[] parameter(1) - p0_bcast = f32[15, 32, 20, 64] broadcast(p0), dimensions={0, 2} - - ROOT reduce_2 = f32[15, 64] reduce(p0_bcast, p0_init), - dimensions={1, 2}, to_apply=max - } - ENTRY e { - p0 = f32[15, 20] parameter(0) - p0_init = f32[] constant(-inf) - ROOT fusion = f32[15, 64] fusion(p0, p0_init), kind=kLoop, calls=f - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1)[s0] -> (d0, s0)", - std::vector{20}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithTransposeOfTranspose) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[20, 10, 50] parameter(0) - - lhs_transpose_1 = f32[10, 20, 50] - transpose(p0), dimensions={1, 0, 2} - lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1) - lhs_transpose_2 = f32[10, 50, 20] - transpose(lhs_e), dimensions={0, 2, 1} - - rhs_transpose_1 = f32[50, 10, 20] - transpose(p0), dimensions={2, 1, 0} - rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1) - rhs_transpose_2 = f32[10, 50, 20] - transpose(rhs_log), dimensions={1, 0, 2} - - ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2) - } - ENTRY e { - p0 = f32[20, 10, 50] parameter(0) - ROOT fusion = f32[10, 50, 20] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d2, d0, d1)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReducedSlice) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - max { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT max = f32[] maximum(p0, p1) - } - f { - p0 = f32[150, 64, 1024] parameter(0) - p0_init = f32[] parameter(1) - p0_slice = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), - slice={[5:21:1], [0:64:2], [50:434:3]} - ROOT reduce = f32[32] reduce(p0_slice, p0_init), - dimensions={0, 2}, to_apply=max - } - ENTRY e { - p0 = f32[150, 64, 1024] parameter(0) - p0_init = f32[] constant(-inf) - ROOT fusion = f32[32] fusion(p0, p0_init), kind=kLoop, calls=f - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)", - std::vector{16, 128}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[128] parameter(0) - expand = f32[8, 16] reshape(p0) - ROOT collapse = f32[128] reshape(expand) - } - ENTRY e { - p0 = f32[128] parameter(0) - ROOT fusion = f32[128] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0) -> (d0)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[8, 16] parameter(0) - collapse = f32[128] reshape(p0) - ROOT expand = f32[8, 16] reshape(collapse) - } - ENTRY e { - p0 = f32[8, 16] parameter(0) - ROOT fusion = f32[8, 16] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_TRUE(input_indexing.Simplify({8, 16})); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0, d1)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[10, 10, 10] parameter(0) - reshape1 = f32[50, 20] reshape(p0) - ROOT reshape2 = f32[10, 10, 10] reshape(reshape1) - } - ENTRY e { - p0 = f32[10, 10, 10] parameter(0) - ROOT fusion = f32[10, 10, 10] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_TRUE(input_indexing.Simplify({10, 10, 10})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0, d1, d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, FusionOpWithSliceOfSlice) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - f { - p0 = f32[150, 64, 1024] parameter(0) - p0_slice_1 = f32[16, 32, 128] slice(f32[150, 64, 1024] p0), - slice={[5:21:1], [0:64:2], [50:434:3]} - ROOT p0_slice_2 = f32[7, 9, 24] slice(f32[16, 32, 128] p0_slice_1), - slice={[3:16:2], [4:30:3], [5:100:4]} - } - ENTRY e { - p0 = f32[150, 64, 1024] parameter(0) - ROOT fusion = f32[7, 9, 24] fusion(p0), kind=kLoop, calls=f - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 * 2 + 8, d1 * 6 + 8, d2 * 12 + 65)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpCollapseShape) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[4,8] parameter(0) - ROOT reshape = f32[32] reshape(p0) - } - )")); - EXPECT_FALSE(input_indexing.Simplify({32})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0) -> (d0 floordiv 8, d0 mod 8)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpExpandShape) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[32] parameter(0) - ROOT reshape = f32[4, 8] reshape(p0) - } - )")); - EXPECT_FALSE(input_indexing.Simplify({4, 8})); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1) -> (d0 * 8 + d1)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpExpandAndCollapseShape) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[4, 8, 12] parameter(0) - ROOT reshape = f32[32, 3, 4] reshape(p0) - } - )")); - EXPECT_FALSE(input_indexing.Simplify({32, 3, 4})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpExpandSubshapeOnly) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[16, 8] parameter(0) - ROOT reshape = f32[4, 4, 8] reshape(p0) - } - )")); - EXPECT_FALSE(input_indexing.Simplify({4, 4, 8})); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0, d1, d2) -> (d0 * 4 + d1, d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpGenericReshape2DTO3D) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[4,8] parameter(0) - ROOT reshape = f32[2, 4, 4] reshape(p0) - } - )")); - EXPECT_TRUE(input_indexing.Simplify({2, 4, 4})); - // TODO(b/313840171): Simplify `(d1 * 4 + d2) floordiv 8` to `d1 floordiv 2`. - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, " - "(d1 * 4 + d2) mod 8)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReshapeOpGenericReshape3DTO2D) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[2, 4, 4] parameter(0) - ROOT reshape = f32[4, 8] reshape(p0) - } - )")); - EXPECT_FALSE(input_indexing.Simplify({4, 8})); - // TODO(b/313840171): Simplify `(d0 * 8 + d1) floordiv 16` to `d0 floordiv 2`. - // TODO(b/313840171): Simplify `((d0 * 8 + d1) mod 16) floordiv 4` to - // `((d0 * 8 + d1) floordiv 4) mod 4` to `(d0 * 2 + d1 floordiv 4) mod 4`. - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1) -> ((d0 * 8 + d1) floordiv 16, " - "((d0 * 8 + d1) mod 16) floordiv 4, d1 mod 4)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, ReduceOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - max { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT max = f32[] maximum(p0, p1) - } - ENTRY e { - p0 = f32[150, 20, 10, 50] parameter(0) - p0_init = f32[] constant(-inf) - ROOT reduce = f32[150, 10] reduce(p0, p0_init), - dimensions={3, 1}, to_apply=max - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1)[s0, s1] -> (d0, s0, d1, s1)", - std::vector{20, 50}))))); -} - -TEST_F(TileAnalysisTest, VariadicReduceOp) { - absl::string_view hlo_string = R"( - HloModule m - min { - tmp_0 = f32[] parameter(0) - tmp_1 = f32[] parameter(2) - tmp_2 = s32[] parameter(1) - tmp_3 = s32[] parameter(3) - cmp = pred[] compare(tmp_0, tmp_1), direction=GE - select1 = f32[] select(cmp, tmp_0, tmp_1) - select2 = s32[] select(cmp, tmp_2, tmp_3) - ROOT tmp_4 = (f32[], s32[]) tuple(select1, select2) - } - ENTRY e { - p0 = f32[256,10] parameter(0) - p0_init = f32[] constant(-inf) - p1 = s32[256,10] parameter(1) - p1_init = s32[] constant(0) - ROOT reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init), - dimensions={0}, to_apply=min - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - - auto input_indexing_0 = ComputeInstructionIndexing(root, 0, &mlir_context_); - ASSERT_IS_OK(input_indexing_0); - EXPECT_THAT( - input_indexing_0->operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap( - "(d0)[s0] -> (s0, d0)", std::vector{256}))))); - - auto input_indexing_1 = ComputeInstructionIndexing(root, 1, &mlir_context_); - ASSERT_IS_OK(input_indexing_1); - EXPECT_THAT( - input_indexing_1->operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap("(d0)[s0] -> (s0, d0)", - std::vector{256}))), - MatchOperandIndexing( - 1, ElementsAre(MatchIndexingMap( - "(d0)[s0] -> (s0, d0)", std::vector{256}))))); -} - -TEST_F(TileAnalysisTest, ReverseOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[1, 17, 9, 9] parameter(0) - ROOT reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2} - } - )")); - // TODO(b/313840171): Support simplifying this. - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, -d1 + 17, -d2 + 9, d3)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, SliceOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[10, 20, 50] parameter(0) - ROOT slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0), - slice={[5:10:1], [3:20:7], [0:50:2]} - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, TransposeOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[3, 12288, 6, 128] parameter(0) - ROOT transpose = f32[3, 6, 128, 12288] - transpose(p0), dimensions={0, 2, 3, 1} - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, TransposeOp4D) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[3, 12288, 6, 128] parameter(0) - ROOT bitcast = f32[3, 6, 128, 12288] {2, 1, 3, 0} bitcast(p0) - } - )")); - EXPECT_THAT(input_indexing.operand_indexing_maps, - ElementsAre(MatchOperandIndexing( - 0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3) -> (d0, d3, d1, d2)", - std::vector{}))))); -} - -TEST_F(TileAnalysisTest, DotOp) { - TF_ASSERT_OK_AND_ASSIGN(auto input_indexing, - GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[4, 38, 17, 11, 18, 10] parameter(0) - p1 = f32[17, 10, 16, 18, 22, 38] parameter(1) - ROOT dot = f32[10, 38, 4, 11, 16, 22] dot(p0, p1), - lhs_batch_dims={5,1}, rhs_batch_dims={1,5}, - lhs_contracting_dims={4,2}, rhs_contracting_dims={3,0} - } - )")); - EXPECT_THAT( - input_indexing.operand_indexing_maps, - ElementsAre( - MatchOperandIndexing(0, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " - "(d2, d1, s1, d3, s0, d0)", - std::vector{18, 17}))), - MatchOperandIndexing(1, ElementsAre(MatchIndexingMap( - "(d0, d1, d2, d3, d4, d5)[s0, s1] -> " - "(s1, d0, d4, s0, d5, d1)", - std::vector{18, 17}))))); -} - -TEST_F(TileAnalysisTest, UnsupportedOps) { - ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - p0 = f32[1, 17, 9, 9] parameter(0) - p1 = f32[5, 17, 9, 9] parameter(1) - ROOT concat = f32[6, 17, 9, 9] concatenate(p0, p1) - } - )")); - ASSERT_IS_NOT_OK(GetIndexingMapsForEntryComputation(R"( - HloModule m - ENTRY e { - input = s32[1,1,25,1] parameter(0) - update = s32[1,1,2,1] parameter(1) - start_indices = s32[4] parameter(2) - ROOT dyn-update = s32[1,1,25,1] dynamic-update-slice( - s32[1,1,25,1] input, s32[1,1,2,1] update, s32[4] start_indices) - } - )")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/model/tiled_hlo_instruction.cc b/xla/service/gpu/model/tiled_hlo_instruction.cc new file mode 100644 index 0000000000000..7c0698f2e40c0 --- /dev/null +++ b/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/tiled_hlo_instruction.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +size_t TiledHloInstruction::PtrHash::operator()( + const TiledHloInstruction* tiled_hlo) const { + return absl::HashOf(*tiled_hlo); +} + +bool TiledHloInstruction::PtrEqual::operator()( + const TiledHloInstruction* lhs, const TiledHloInstruction* rhs) const { + return *lhs == *rhs; +} + +bool operator==(const TiledHloInstruction& lhs, + const TiledHloInstruction& rhs) { + return lhs.hlo() == rhs.hlo() && lhs.tile_sizes() == rhs.tile_sizes() && + lhs.tile_strides() == rhs.tile_strides() && + lhs.block_id_to_tile_offsets_indexing() == + rhs.block_id_to_tile_offsets_indexing(); +} + +bool operator!=(const TiledHloInstruction& lhs, + const TiledHloInstruction& rhs) { + return !(lhs == rhs); +} + +/*static*/ +absl::StatusOr> +TiledHloInstruction::Create(const HloInstruction* hlo, + std::vector tile_sizes, + std::vector tile_strides, + IndexingMap block_id_to_tile_offsets_indexing) { + int rank = hlo->shape().rank(); + + if (tile_sizes.size() != rank) { + return absl::InvalidArgumentError( + absl::StrCat("Number of tile sizes must be equal to the rank of the " + "hlo shape. tile_sizes = ", + tile_sizes.size(), ", hlo = ", hlo->ToString())); + } + + if (tile_strides.size() != rank) { + return absl::InvalidArgumentError( + absl::StrCat("Number of tile strides must be equal to the rank of the " + "hlo shape. tile_sizes = ", + tile_strides.size(), ", hlo = ", hlo->ToString())); + } + + if (block_id_to_tile_offsets_indexing.GetDimensionCount() != 1 || + block_id_to_tile_offsets_indexing.GetSymbolCount() != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "block_id_to_tile_offsets_indexing must have 1 dim and 0 symbols. " + "block_id_to_tile_offsets_indexing = ", + block_id_to_tile_offsets_indexing.ToString())); + } + + if (block_id_to_tile_offsets_indexing.GetAffineMap().getNumResults() != + rank) { + return absl::InvalidArgumentError(absl::StrCat( + "block_id_to_tile_offsets_indexing must have the same number of " + "results as the rank of the hlo shape. " + "block_id_to_tile_offsets_indexing = ", + block_id_to_tile_offsets_indexing.ToString(), + ", hlo = ", hlo->ToString())); + } + + return absl::WrapUnique(new TiledHloInstruction( + hlo, std::move(tile_sizes), std::move(tile_strides), + std::move(block_id_to_tile_offsets_indexing))); +} + +std::string TiledHloInstruction::ToString() const { + std::stringstream ss; + ss << "hlo: " << hlo_->ToString() << "\n"; + ss << "tile_sizes: {" << absl::StrJoin(tile_sizes_, ", ") << "}\n"; + ss << "tile_strides: {" << absl::StrJoin(tile_strides_, ", ") << "}\n"; + ss << "block_id_to_tile_offsets_indexing: " + << block_id_to_tile_offsets_indexing_; + return ss.str(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/model/tiled_hlo_instruction.h b/xla/service/gpu/model/tiled_hlo_instruction.h new file mode 100644 index 0000000000000..045cfd27dd8aa --- /dev/null +++ b/xla/service/gpu/model/tiled_hlo_instruction.h @@ -0,0 +1,136 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ +#define XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +// A wrapper around HloInstruction that represents a tiled HLO instruction. +// +// The class contains information required to emit this instruction in +// block-level codegen. Tile sizes and strides are constants and do not depend +// on the block id. Tile offsets are computed using an indexing map of form: +// `(block_id) -> (tile_offset0, tile_offset1, ...)`. +class TiledHloInstruction { + public: + // PtrHash and PtrEqual are helper classes to use in hash maps and sets that + // compare values behind the pointers. For example, + // absl::flat_hash_set hlo_set; + struct PtrHash { + size_t operator()(const TiledHloInstruction* tiled_hlo) const; + }; + + struct PtrEqual { + bool operator()(const TiledHloInstruction* lhs, + const TiledHloInstruction* rhs) const; + }; + + // Creates an instance of TiledHloInstruction. Returns an error if any of the + // following preconditions is not met: + // * Number of tile sizes, strides should match HLO shape rank. + // * Number of result of `block_id_to_tile_offsets_indexing` should match HLO + // shape rank. + // * `block_id_to_tile_offsets_indexing` should have only 1 dimension and 0 + // symbols. + static absl::StatusOr> Create( + const HloInstruction* hlo, std::vector tile_sizes, + std::vector tile_strides, + IndexingMap block_id_to_tile_offsets_indexing); + + // Returns the original HLO instruction. + const HloInstruction* hlo() const { return hlo_; } + + // Returns the tile sizes. The number of tile sizes is equal to the rank of + // the output shape. + const std::vector& tile_sizes() const { return tile_sizes_; } + + // Returns the tile strides. The number of tile strides is equal to the rank + // of the output shape. + const std::vector& tile_strides() const { return tile_strides_; } + + // Returns the indexing map from block_id to tile offsets. The map has a form + // of `(block_id) -> (tile_offset0, tile_offset1, ...)`. The number of tile + // offsets is equal to the rank of the output shape. + const IndexingMap& block_id_to_tile_offsets_indexing() const { + return block_id_to_tile_offsets_indexing_; + } + + const TiledHloInstruction* operand(int64_t operand_id) const { + return operands_[operand_id]; + } + + const std::vector& operands() const { + return operands_; + } + + void AppendOperand(TiledHloInstruction* operand) { + operands_.push_back(operand); + } + + std::string ToString() const; + + private: + TiledHloInstruction(const HloInstruction* hlo, + std::vector tile_sizes, + std::vector tile_strides, + IndexingMap block_id_to_tile_offsets_indexing) + : hlo_(hlo), + tile_sizes_(std::move(tile_sizes)), + tile_strides_(std::move(tile_strides)), + block_id_to_tile_offsets_indexing_( + std::move(block_id_to_tile_offsets_indexing)) {} + + // Pointer to the original HLO instruction. + const HloInstruction* hlo_; + + // Tile sizes and strides. + std::vector tile_sizes_; + std::vector tile_strides_; + + // Indexing map from block_id to tile offsets. + IndexingMap block_id_to_tile_offsets_indexing_; + + // Operands of the instruction in the tiled computation graph. + std::vector operands_; +}; + +bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs); +bool operator!=(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs); + +template +H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) { + return H::combine(std::move(h), tiled_hlo_instruction.hlo(), + tiled_hlo_instruction.tile_sizes(), + tiled_hlo_instruction.tile_strides(), + tiled_hlo_instruction.block_id_to_tile_offsets_indexing()); +} + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ diff --git a/xla/service/gpu/model/tiled_hlo_instruction_test.cc b/xla/service/gpu/model/tiled_hlo_instruction_test.cc new file mode 100644 index 0000000000000..dc2db9b3d96bd --- /dev/null +++ b/xla/service/gpu/model/tiled_hlo_instruction_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/tiled_hlo_instruction.h" + +#include + +#include +#include +#include "absl/container/flat_hash_set.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class TiledHloInstructionTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; +}; + +TEST_F(TiledHloInstructionTest, PtrHashAndPtrEqualWorkCorrectly) { + std::unique_ptr hlo = HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64}), "p0"); + + IndexingMap block_id_to_tile_offsets_indexing = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0 floordiv 16, (d0 mod 16) * 16)", + &mlir_context_), + /*dim_upper_bounds=*/{8}, + /*symbol_upper_bounds=*/{}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tiled_hlo1, + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tiled_hlo2, + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tiled_hlo3, + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 32}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing)); + + EXPECT_EQ(*tiled_hlo1, *tiled_hlo2); + EXPECT_NE(*tiled_hlo1, *tiled_hlo3); + + absl::flat_hash_set + tiled_hlo_set = {tiled_hlo1.get(), tiled_hlo2.get(), tiled_hlo3.get()}; + EXPECT_EQ(tiled_hlo_set.size(), 2); +} + +TEST_F(TiledHloInstructionTest, TileSizesAndStridesShouldMatchHloShapeRank) { + std::unique_ptr hlo = HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64}), "p0"); + + IndexingMap block_id_to_tile_offsets_indexing = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0 floordiv 16, (d0 mod 16) * 16)", + &mlir_context_), + /*dim_upper_bounds=*/{8}, + /*symbol_upper_bounds=*/{}); + + EXPECT_THAT( + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing) + .status() + .message(), + ::testing::HasSubstr("Number of tile sizes must be equal to the rank")); + + EXPECT_THAT( + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1, 1}, + block_id_to_tile_offsets_indexing) + .status() + .message(), + ::testing::HasSubstr("Number of tile strides must be equal to the rank")); +} + +TEST_F(TiledHloInstructionTest, + ShouldReturnErrorIfBlockIdToTileOffsetsIndexingIsInvalid) { + std::unique_ptr hlo = HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64}), "p0"); + + IndexingMap block_id_to_tile_offsets_indexing1 = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0 floordiv 16)", &mlir_context_), + /*dim_upper_bounds=*/{8}, + /*symbol_upper_bounds=*/{}); + + EXPECT_THAT( + TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing1) + .status() + .message(), + ::testing::HasSubstr( + "must have the same number of results as the rank of the hlo shape")); + + IndexingMap block_id_to_tile_offsets_indexing2 = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0 + s0, d0 floordiv 16)", &mlir_context_), + /*dim_upper_bounds=*/{8}, + /*symbol_upper_bounds=*/{8}); + + EXPECT_THAT(TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1}, + block_id_to_tile_offsets_indexing2) + .status() + .message(), + ::testing::HasSubstr("must have 1 dim and 0 symbols")); +} + +} // namespace + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/move_copy_to_users.cc b/xla/service/gpu/move_copy_to_users.cc index 367014ab3eb60..51ffbed0ec013 100644 --- a/xla/service/gpu/move_copy_to_users.cc +++ b/xla/service/gpu/move_copy_to_users.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,9 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -35,7 +37,7 @@ namespace { class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { // Turn copy->pad into pad->copy - Status HandlePad(HloInstruction* hlo) override { + absl::Status HandlePad(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); HloInstruction* c = hlo->mutable_operand(1); if (operand->opcode() == HloOpcode::kCopy) { @@ -49,11 +51,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); } // Turn copy->slice into slice->copy, as slice is layout-preserving. - Status HandleSlice(HloInstruction* hlo) override { + absl::Status HandleSlice(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); if (operand->opcode() == HloOpcode::kCopy) { HloInstruction* copied = operand->mutable_operand(0); @@ -66,12 +68,32 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); + } + + // Turn copy->dynamic-slice into dynamic-slice->copy, as dynamic-slice is + // layout-preserving. + absl::Status HandleDynamicSlice(HloInstruction* hlo) override { + HloInstruction* operand = hlo->mutable_operand(0); + if (operand->opcode() == HloOpcode::kCopy) { + HloInstruction* copied = operand->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * earlier_slice, + MakeDynamicSliceHlo( + copied, + absl::Span(hlo->operands()).subspan(1), + hlo->dynamic_slice_sizes(), &hlo->metadata())); + *earlier_slice->mutable_shape()->mutable_layout() = + copied->shape().layout(); + HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape()); + TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + } + return absl::OkStatus(); } // Turn copy->reduce_window into reduce_window->copy, as reduce_window is // layout-preserving. - Status HandleReduceWindow(HloInstruction* hlo) override { + absl::Status HandleReduceWindow(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); if (operand->opcode() == HloOpcode::kCopy) { HloInstruction* copied = operand->mutable_operand(0); @@ -85,10 +107,10 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { MakeCopyHlo(earlier_reduce_window, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); } - Status HandleReduce(HloInstruction* hlo) override { + absl::Status HandleReduce(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); // Reductions can handle transposes, e.g. via column reduction. if (operand->opcode() == HloOpcode::kCopy && !hlo->shape().IsTuple()) { @@ -97,18 +119,18 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { hlo->mutable_operand(1)})); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce)); } - return OkStatus(); + return absl::OkStatus(); } - Status HandleBitcastConvert(HloInstruction* hlo) override { - return OkStatus(); + absl::Status HandleBitcastConvert(HloInstruction* hlo) override { + return absl::OkStatus(); } // Sink kCopy across elementwise unary. - Status HandleElementwiseUnary(HloInstruction* hlo) override { + absl::Status HandleElementwiseUnary(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); if (hlo->opcode() == HloOpcode::kReducePrecision) { - return OkStatus(); + return absl::OkStatus(); } if (operand->opcode() == HloOpcode::kCopy) { HloInstruction* copied = operand->mutable_operand(0); @@ -119,11 +141,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { MakeCopyHlo(earlier_elementwise, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); } // Sink kCopy across reverse - Status HandleReverse(HloInstruction* hlo) override { + absl::Status HandleReverse(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); if (operand->opcode() == HloOpcode::kCopy) { HloInstruction* copied = operand->mutable_operand(0); @@ -133,11 +155,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); } // Sink kCopy across convert. - Status HandleConvert(HloInstruction* hlo) override { + absl::Status HandleConvert(HloInstruction* hlo) override { HloInstruction* operand = hlo->mutable_operand(0); if (operand->opcode() == HloOpcode::kCopy) { HloInstruction* copied = operand->mutable_operand(0); @@ -146,11 +168,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } - return OkStatus(); + return absl::OkStatus(); } // Sink kCopy across elementwise binary. - Status HandleElementwiseBinary(HloInstruction* hlo) override { + absl::Status HandleElementwiseBinary(HloInstruction* hlo) override { HloInstruction* a = hlo->mutable_operand(0); HloInstruction* b = hlo->mutable_operand(1); if (a->opcode() == HloOpcode::kCopy && b->opcode() == HloOpcode::kCopy) { @@ -173,14 +195,14 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } } - return OkStatus(); + return absl::OkStatus(); } // Move copy across kConcat if it occurs on all operands. - Status HandleConcatenate(HloInstruction* hlo) override { + absl::Status HandleConcatenate(HloInstruction* hlo) override { const HloInstruction* first = hlo->operand(0); if (first->opcode() != HloOpcode::kCopy) { - return OkStatus(); + return absl::OkStatus(); } const HloInstruction* inner_op = first->operand(0); const Layout& inner_op_layout = inner_op->shape().layout(); @@ -192,7 +214,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { op->operand(0)->shape().layout() != inner_op_layout) { VLOG(3) << "Mismatch between " << op->ToString() << " and expected op layout " << inner_op_layout.ToString(); - return OkStatus(); + return absl::OkStatus(); } new_operands.push_back(op->mutable_operand(0)); } @@ -204,13 +226,13 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape()); TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy)); - return OkStatus(); + return absl::OkStatus(); } }; } // end namespace -StatusOr MoveCopyToUsers::Run( +absl::StatusOr MoveCopyToUsers::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return MoveCopyToUsersVisitor{}.RunOnModule(module, execution_threads); diff --git a/xla/service/gpu/move_copy_to_users.h b/xla/service/gpu/move_copy_to_users.h index a01d8e97d186f..4a7dfb43bbf6e 100644 --- a/xla/service/gpu/move_copy_to_users.h +++ b/xla/service/gpu/move_copy_to_users.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ #define XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ -#include - +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { @@ -30,7 +29,7 @@ class MoveCopyToUsers : public HloModulePass { public: absl::string_view name() const override { return "move_copy_to_users"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/move_copy_to_users_test.cc b/xla/service/gpu/move_copy_to_users_test.cc index 3e4b008fcff16..718847168f901 100644 --- a/xla/service/gpu/move_copy_to_users_test.cc +++ b/xla/service/gpu/move_copy_to_users_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,9 @@ limitations under the License. #include "xla/service/gpu/move_copy_to_users.h" #include -#include +#include "absl/strings/string_view.h" +#include "xla/service/layout_assignment.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -26,6 +27,10 @@ namespace { class MoveCopyToUsersTest : public HloTestBase { public: + MoveCopyToUsersTest() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/true, + LayoutAssignment::InstructionCanChangeLayout) {} void CheckMoveCopyToUsers(absl::string_view hlo, std::optional expected) { RunAndFilecheckHloRewrite(hlo, MoveCopyToUsers{}, expected); @@ -112,13 +117,34 @@ HloModule module ENTRY main { input = f32[1,17,9,9]{3,2,1,0} parameter(0) copy = f32[1,17,9,9]{1,3,2,0} copy(input) - ROOT converted = f32[1,4,6,6] slice(copy), slice={[0:1],[0:4],[0:6],[0:6]} + ROOT slice = f32[1,4,6,6]{1,3,2,0} slice(copy), slice={[0:1],[0:4],[0:6],[0:6]} } )"; CheckMoveCopyToUsers(hlo, R"( // CHECK: [[slice_0:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} slice([[input_1:%[^ ]+]]), slice={[0:1], [0:4], [0:6], [0:6]} -// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} copy([[slice_0]]) +// CHECK-NEXT: ROOT [[copy_1_2:%[^ ]+]] = f32[1,4,6,6]{1,3,2,0} copy([[slice_0]]) +)"); +} + +TEST_F(MoveCopyToUsersTest, DynamicSlice) { + const char* hlo = R"( +HloModule module + +ENTRY main { + input = f32[1,17,9,9]{3,2,1,0} parameter(0) + copy = f32[1,17,9,9]{1,3,2,0} copy(input) + s0 = s32[] parameter(1) + s1 = s32[] parameter(2) + s2 = s32[] parameter(3) + s3 = s32[] parameter(4) + ROOT ds = f32[1,4,6,6]{1,3,2,0} dynamic-slice(copy, s0, s1, s2, s3), dynamic_slice_sizes={1,4,6,6} +} +)"; + + CheckMoveCopyToUsers(hlo, R"( +// CHECK: [[ds:%[^ ]+]] = f32[1,4,6,6]{3,2,1,0} dynamic-slice({{.*}}), dynamic_slice_sizes={1,4,6,6} +// CHECK-NEXT: ROOT {{.*}} = f32[1,4,6,6]{1,3,2,0} copy([[ds]]) )"); } diff --git a/xla/service/gpu/multi_output_fusion.cc b/xla/service/gpu/multi_output_fusion.cc index a88845be62f3c..a6e60230acd92 100644 --- a/xla/service/gpu/multi_output_fusion.cc +++ b/xla/service/gpu/multi_output_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,27 +16,36 @@ limitations under the License. #include "xla/service/gpu/multi_output_fusion.h" #include -#include +#include +#include #include #include -#include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -153,7 +162,7 @@ HloInstruction* SelectPreferredFusionCandidate( // reachable from the producer, this would create a cycle. FusionDecision OperandReachableFromProducer( const HloInstruction& producer, const HloInstruction& consumer, - const HloReachabilityMap& reachability) { + const HloDfsReachability& reachability) { for (const auto* operand : consumer.operands()) { // If a get-tuple-element instruction is not in the reachability // map, it has been created by fusion in this pass. Simply move @@ -173,8 +182,41 @@ FusionDecision OperandReachableFromProducer( return {}; } +FusionDecision ProducerCandidateIsFusible( + const HloInstruction& producer, const HloInstruction& consumer, + const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache, + GpuHloCostAnalysis* cost_analysis) { + if (!IsFusibleAsMultiOutputFusionRoot(consumer)) { + return "consumer not eligible as multi-output fusion root."; + } + + RETURN_IF_NOT_FUSIBLE( + ShapesCompatibleForMultiOutputFusion(consumer, producer)); + + RETURN_IF_NOT_FUSIBLE( + OperandReachableFromProducer(producer, consumer, reachability)); + + RETURN_IF_NOT_FUSIBLE(FusionFitsInBudget( + producer, consumer, *cost_analysis->device_info_, + /*is_consumer_producer_fusion=*/false, fusion_info_cache)); + + if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) { + return "will generate too large IR"; + } + + GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( + &producer, cost_analysis, GpuPerformanceModelOptions::Default(), + /*fused_consumers=*/{&consumer}, + /*multi_output=*/true); + if (t.time_fused > t.time_unfused) { + return "will execute slower if fused"; + } + + return {}; +} + std::vector GetProducerConsumerMultiOutputFusionCandidates( - const HloInstruction* producer, const HloReachabilityMap& reachability, + const HloInstruction* producer, const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache, GpuHloCostAnalysis* cost_analysis) { std::vector fusion_candidates; const HloComputation* computation = producer->parent(); @@ -196,46 +238,19 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( return fusion_candidates; } - using std::placeholders::_1, std::placeholders::_2; - std::tuple checks{ - [](const HloInstruction& producer, - const HloInstruction& consumer) -> FusionDecision { - return {IsFusibleAsMultiOutputFusionRoot(consumer), - "consumer not eligible as multi-output fusion root."}; - }, - &ShapesCompatibleForMultiOutputFusion, - std::bind(OperandReachableFromProducer, _1, _2, std::cref(reachability)), - std::bind(FusionFitsInBudget, _1, _2, - std::cref(*cost_analysis->device_info_), - /*is_consumer_producer_fusion=*/false, fusion_info_cache), - [&](const HloInstruction& producer, - const HloInstruction& consumer) -> FusionDecision { - return { - !cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer), - "will generate too large IR"}; - }, - [&](const HloInstruction& producer, - const HloInstruction& consumer) -> FusionDecision { - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - &producer, cost_analysis, GpuPerformanceModelOptions::Default(), - // `EstimateRunTimes`'s interface violates const correctness, so we - // need the const cast here. - {const_cast(&consumer)}, - /*multi_output=*/true); - return {t.time_fused <= t.time_unfused, "will execute slower if fused"}; - }}; - for (HloInstruction* consumer : producer->users()) { VLOG(3) << "Looking at producer " << producer->name() << " and its consumer " << consumer->name(); - if (auto decision = FusionDecision::All(checks, *producer, *consumer)) { + if (auto decision = + ProducerCandidateIsFusible(*producer, *consumer, reachability, + fusion_info_cache, cost_analysis)) { fusion_candidates.push_back(consumer); } else if (dump_fusion) { RegisterFusionState( *computation, - absl::StrCat("Not considering fusion of producer |", "|", - producer->name(), "| into consumer |", consumer->name(), + absl::StrCat("Not considering fusion of producer |", producer->name(), + "| into consumer |", consumer->name(), "| due to: ", decision.Explain()), *consumer, producer); } @@ -251,20 +266,46 @@ bool IsSiblingFusionCandidate(const HloInstruction* instr) { // Check if the users of multioutput fusion is not a get-tuple-element. // If this is the case, we bail out because the transformation assumes // the users are get-tuple-element. - if (instr->IsMultiOutputFusion()) { - for (HloInstruction* user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { - return false; - } - } + return (!instr->IsMultiOutputFusion() || + absl::c_all_of(instr->users(), [&](const HloInstruction* user) { + return user->opcode() == HloOpcode::kGetTupleElement; + })); +} + +FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, + const HloInstruction& sibling_consumer_2, + const HloInstruction& common_producer, + const HloDfsReachability& reachability, + FusionInfoCache* fusion_info_cache, + GpuHloCostAnalysis* cost_analysis) { + if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) { + return {absl::StrCat(sibling_consumer_1.name(), " and ", + sibling_consumer_2.name(), " are connected")}; } - return true; + + RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion( + sibling_consumer_1, sibling_consumer_2)); + + // Technically, this check is order-dependent (e.g. siblings A, B, C where + // {A, B} and {B, C} overlap, but {A, C} do not. If the priority order is + // [C, A, B], only {C, B} will be fused, and A will only be fused in the + // next iteration of the fusion pipeline, potentially requiring several + // iterations to converge. We assume this case to be very rare in + // practice. + RETURN_IF_NOT_FUSIBLE(ParameterSlicesAreNonOverlapping( + sibling_consumer_1, sibling_consumer_2, &common_producer)); + + // This check should be last, as it may be expensive. + RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2, + *cost_analysis->device_info_, + fusion_info_cache)); + return {}; } } // namespace void GpuMultiOutputFusion::RecomputeReachability() { - reachability_ = HloReachabilityMap::Build(computation_); + reachability_ = HloDfsReachability::Build(computation_); } bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, @@ -290,24 +331,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, [](const HloInstruction* a, const HloInstruction* b) { return FusionPriority(a) > FusionPriority(b); }); - using std::placeholders::_1, std::placeholders::_2; - std::tuple fusible_checks{ - [&](const HloInstruction& i, const HloInstruction& j) -> FusionDecision { - return FusionDecision{ - !reachability_->IsConnected(&i, &j), - absl::StrCat(i.name(), " and ", j.name(), " are connected")}; - }, - &ShapesCompatibleForMultiOutputFusion, - // Technically, this check is order-dependent (e.g. siblings A, B, C where - // {A, B} and {B, C} overlap, but {A, C} do not. If the priority order is - // [C, A, B], only {C, B} will be fused, and A will only be fused in the - // next iteration of the fusion pipeline, potentially requiring several - // iterations to converge. We assume this case to be very rare in - // practice. - std::bind(ParameterSlicesAreNonOverlapping, _1, _2, parent), - // This check should be last, as it may be expensive. - std::bind(LegalToFuse, _1, _2, std::cref(*cost_analysis->device_info_), - fusion_info_cache)}; + for (auto i = siblings.begin(); i != siblings.end(); ++i) { VLOG(3) << "Considering " << (*i)->name(); if ((*i)->opcode() != HloOpcode::kFusion) { @@ -316,7 +340,8 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, for (auto j = i + 1; j != siblings.end();) { VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name(); - if (auto fusible = FusionDecision::All(fusible_checks, **i, **j); + if (auto fusible = CanFuseSiblings(**i, **j, *parent, *reachability_, + fusion_info_cache, cost_analysis); !fusible) { // We pick `j` arbitrarily as a consumer. if (dump_fusion) { @@ -347,9 +372,9 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, TF_CHECK_OK(cost_analysis->RemoveInstruction(fused)); DumpFusionState(*remaining, - absl::StrCat("About to fuse producer |", fused->name(), - "| into consumer |", remaining->name(), - "| inside GPU multi-output fusion"), + absl::StrCat("About to fuse sibling |", fused->name(), + "| into sibling |", remaining->name(), + "| inside multi-output fusion"), /*producer=*/fused); if (fused->opcode() == HloOpcode::kFusion) { @@ -363,8 +388,8 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, TF_CHECK_OK(computation_->RemoveInstruction(fused)); } DumpFusionState(*remaining, - absl::StrCat("Fused into consumer |", remaining->name(), - "| inside GPU multi-output fusion")); + absl::StrCat("Fused into |", remaining->name(), + "| inside multi-output fusion")); TF_CHECK_OK(cost_analysis->RevisitInstruction(remaining)); changed = true; siblings.erase(j); @@ -374,7 +399,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, return changed; } -StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { +absl::StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { bool changed = false; RecomputeReachability(); GpuHloCostAnalysis cost_analysis({shape_size_function_, @@ -390,7 +415,6 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend(); ++it) { auto* producer = *it; - absl::string_view producer_name = producer->name(); // Never multi-output fuse constants. To the extent that we want to fuse // constants, that should be handled by the regular fusion pass. if (producer->opcode() == HloOpcode::kConstant) { @@ -443,6 +467,12 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { computation_->ReplaceInstruction(consumer_for_fusion, input_fusion)); } + DumpFusionState(*input_fusion, + absl::StrCat("About to fuse producer |", producer->name(), + "| into consumer |", input_fusion->name(), + "| inside multi-output fusion"), + /*producer=*/producer); + if (producer->opcode() == HloOpcode::kFusion) { input_fusion->MergeFusionInstructionIntoMultiOutput(producer); } else { @@ -452,10 +482,9 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { } TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion)); - DumpFusionState( - *input_fusion, - absl::StrCat("Fusing producer |", producer_name, "| into consumer |", - input_fusion->name(), "| inside GPU multi-output fusion")); + DumpFusionState(*input_fusion, + absl::StrCat("Fused into |", input_fusion->name(), + "| inside multi-output fusion")); RecomputeReachability(); } return changed; @@ -472,7 +501,7 @@ void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer, } } -StatusOr GpuMultiOutputFusion::Run( +absl::StatusOr GpuMultiOutputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/multi_output_fusion.h b/xla/service/gpu/multi_output_fusion.h index 73bd2f9eb6ee3..82789d3be5791 100644 --- a/xla/service/gpu/multi_output_fusion.h +++ b/xla/service/gpu/multi_output_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,17 +17,18 @@ limitations under the License. #define XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ #include -#include -#include -#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_dfs_reachability.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -101,7 +102,7 @@ class GpuMultiOutputFusion : public HloModulePass { absl::string_view name() const override { return "multi_output_fusion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -109,7 +110,7 @@ class GpuMultiOutputFusion : public HloModulePass { bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache, GpuHloCostAnalysis* cost_analysis); - StatusOr DoMultiOutputFusion(); + absl::StatusOr DoMultiOutputFusion(); // Recompute reachability for the current computation. void RecomputeReachability(); @@ -121,7 +122,7 @@ class GpuMultiOutputFusion : public HloModulePass { HloComputation* computation_; // The reachability map of current computation. - std::unique_ptr reachability_; + std::unique_ptr reachability_; se::DeviceDescription device_info_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; diff --git a/xla/service/gpu/multi_output_fusion_test.cc b/xla/service/gpu/multi_output_fusion_test.cc index 9ade64c727053..c28c953a58bdb 100644 --- a/xla/service/gpu/multi_output_fusion_test.cc +++ b/xla/service/gpu/multi_output_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,23 @@ limitations under the License. #include "xla/service/gpu/multi_output_fusion.h" +#include #include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/stream_executor/device_description.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -1968,35 +1976,35 @@ ENTRY main { )"); } -TEST_F(TransposeMultiOutputFusionTest, CopyAndInputEpilogueFusion) { +TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) { const char* hlo = R"( HloModule module fused_computation { param_0.1 = f32[16,32]{1,0} parameter(0) s.1 = f32[16,32]{1,0} sqrt(param_0.1) - c.1 = f32[16,32]{0,1} copy(s.1) - ROOT out = f32[16,32,1]{0,1,2} bitcast(c.1) + t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0} + ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1) } ENTRY main { p = f32[16,32]{1,0} parameter(0) - fusion = f32[16,32,1]{0,1,2} fusion(p), kind=kInput, calls=fused_computation + fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation c1 = exponential(p) ROOT t = tuple(fusion, c1) } )"; CheckGpuMultiOutputFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32,1], f32[16,32]) { +// CHECK: %fused_computation // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]]) -// CHECK-NEXT: [[out_3:%[^ ]+]] = f32[16,32,1]{0,1,2} bitcast([[c_1_2]]) +// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]) +// CHECK-NEXT: [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]]) // CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]]) -// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[16,32,1]{0,1,2}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]]) +// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]]) // CHECK-NEXT: } -// CHECK: [[fusion_0:%[^ ]+]] = (f32[16,32,1]{0,1,2}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] +// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] )"); } diff --git a/xla/service/gpu/nccl_all_gather_thunk.cc b/xla/service/gpu/nccl_all_gather_thunk.cc deleted file mode 100644 index a4fcdfa5fefb0..0000000000000 --- a/xla/service/gpu/nccl_all_gather_thunk.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_all_gather_thunk.h" - -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/service/gpu/ir_emission_utils.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo_gpu::AllGatherStartOp; - -namespace impl { -NcclAllGatherConfig GetNcclAllGatherConfig(AllGatherStartOp op) { - NcclAllGatherConfig config; - config.config = - GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()); - return config; -} - -Status CheckImplementable(AllGatherStartOp op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - for (mlir::Value operand : op.getInputs()) { - TF_RETURN_IF_ERROR(IsValidOperand(operand, Thunk::kNcclAllGather)); - Shape shape = GetShape(operand); - if (!ShapeUtil::IsEffectivelyMostMajorDimension( - shape, op.getAllGatherDimension())) { - return tsl::errors::Unimplemented(absl::StrFormat( - "all-gather dim %u is not the most major in input shape %s", - op.getAllGatherDimension(), shape.ToString(/*print_layout=*/true))); - } - } - return OkStatus(); -} -} // namespace impl - -NcclAllGatherStartThunk::NcclAllGatherStartThunk( - ThunkInfo thunk_info, AllGatherStartOp op, - std::vector buffers) - : NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, - op.getIsSync()), - config_(impl::GetNcclAllGatherConfig(op)), - buffers_(std::move(buffers)) { - CHECK_EQ(config_.config.operand_count, buffers_.size()); -} - -/*static*/ Status NcclAllGatherStartThunk::CheckImplementable( - AllGatherStartOp op, int64_t replica_count, int64_t partition_count) { - return AddOpDescription( - impl::CheckImplementable(op), op, replica_count, partition_count); -} - -/*static*/ bool NcclAllGatherStartThunk::IsDegenerate(AllGatherStartOp op, - int64_t replica_count, - int64_t partition_count) { - return impl::GetNcclAllGatherConfig(op).config.IsDegenerate(replica_count, - partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclAllGatherStartThunk::GetGroupMode( - AllGatherStartOp op) { - return impl::GetNcclAllGatherConfig(op).config.group_mode; -} - -Status NcclAllGatherStartThunk::RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, - ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, buffers_, - config_.config.operand_element_type)); - return xla::gpu::RunAllGather(device_buffers, stream, comm); -} - -Status RunAllGather(std::vector& buffers, se::Stream& stream, - ncclComm_t comm) { -#if XLA_ENABLE_XCCL - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); - - PrimitiveType element_type = buffer.element_type; - TF_ASSIGN_OR_RETURN( - auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier(element_type, Thunk::kNcclAllGather)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - VLOG(3) << absl::StreamFormat( - "Calling ncclAllGather(send_buffer=%p, recv_buffer=%p, sendcount=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, element_count, static_cast(comm), - gpu_stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclAllGather( - send_buffer, recv_buffer, element_count, dtype, comm, gpu_stream)); - } - XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); - - VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal; - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_all_gather_thunk.h b/xla/service/gpu/nccl_all_gather_thunk.h deleted file mode 100644 index dfd4e39b4ebb5..0000000000000 --- a/xla/service/gpu/nccl_all_gather_thunk.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ - -#include - -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" - -namespace xla { -namespace gpu { - -struct NcclAllGatherConfig { - NcclCollectiveConfig config; -}; - -// Thunk that performs a NCCL-based All-Gather among CUDA GPU-based replicas. -class NcclAllGatherStartThunk : public NcclCollectiveThunk { - public: - NcclAllGatherStartThunk(ThunkInfo thunk_info, - mlir::lmhlo_gpu::AllGatherStartOp op, - std::vector buffers); - - static const char* GetHloOpName() { return "all-gather-start"; } - - static Status CheckImplementable(mlir::lmhlo_gpu::AllGatherStartOp op, - int64_t replica_count, - int64_t partition_count); - static bool IsDegenerate(mlir::lmhlo_gpu::AllGatherStartOp op, - int64_t replica_count, int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode( - mlir::lmhlo_gpu::AllGatherStartOp op); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; - - private: - const NcclAllGatherConfig config_; - const std::vector buffers_; -}; - -Status RunAllGather(std::vector& buffers, se::Stream& stream, - ncclComm_t comm); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ diff --git a/xla/service/gpu/nccl_all_reduce_thunk.cc b/xla/service/gpu/nccl_all_reduce_thunk.cc deleted file mode 100644 index e8aaf1b429740..0000000000000 --- a/xla/service/gpu/nccl_all_reduce_thunk.cc +++ /dev/null @@ -1,356 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_all_reduce_thunk.h" - -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/xla_data.pb.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo_gpu::AllReduceStartOp; -using mlir::lmhlo_gpu::ReduceScatterStartOp; - -Status RunAllReduce(ReductionKind reduction_kind, - std::vector& buffers, se::Stream& stream, - ncclComm_t comm) { -#if XLA_ENABLE_XCCL - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; - - ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind); - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllReduce)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - VLOG(3) << absl::StreamFormat( - "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, element_count, static_cast(comm), - gpu_stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, - element_count, dtype, reduce_op, - comm, gpu_stream)); - } - return XLA_CUDA_STATUS(ncclGroupEnd()); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -namespace { - -// Generally, the reduction op should be the only operation in the block, except -// the terminator. However, if the type is bf16, the `FloatNormalization` -// pass will have converted the op to float32 and added type conversions. -// TODO(cjfj): Can we prevent the bf16 conversion for this computation? -StatusOr FindReductionOp(mlir::Block& block) { - TF_RET_CHECK(block.getNumArguments() == 2); - mlir::Operation* terminator = block.getTerminator(); - TF_RET_CHECK(terminator); - TF_RET_CHECK(terminator->getNumOperands() == 1); - mlir::Value result = terminator->getOperand(0); - TF_RET_CHECK(block.getArgument(0).getType() == result.getType()); - TF_RET_CHECK(block.getArgument(1).getType() == result.getType()); - - mlir::Operation* result_op = result.getDefiningOp(); - TF_RET_CHECK(result_op); - - // In the bf16 case, the type conversions and op might be fused. - if (mlir::isa(result_op)) { - return FindReductionOp(result_op->getRegion(0).front()); - } - - // Standard case. - if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) { - return result_op; - } - - // bf16 case. - TF_RET_CHECK(mlir::isa(result_op)); - TF_RET_CHECK(result_op->getNumOperands() == 1); - mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp(); - TF_RET_CHECK(reduction_op); - TF_RET_CHECK(reduction_op->getNumOperands() == 2); - mlir::Value operand0 = reduction_op->getOperand(0); - mlir::Value operand1 = reduction_op->getOperand(1); - auto operand0_op = operand0.getDefiningOp(); - auto operand1_op = operand1.getDefiningOp(); - TF_RET_CHECK(operand0_op); - TF_RET_CHECK(operand1_op); - TF_RET_CHECK(operand0_op->getNumOperands() == 1); - TF_RET_CHECK(operand1_op->getNumOperands() == 1); - std::array operands{operand0_op->getOperand(0), - operand1_op->getOperand(0)}; - TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments())); - return reduction_op; -} - -} // namespace - -namespace impl { - -template -Status CheckImplementable(OpT op, Thunk::Kind reduction_op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - for (mlir::Value operand : op.getInputs()) { - TF_RETURN_IF_ERROR(IsValidOperand(operand, reduction_op)); - } - if (!NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( - op.getComputation()) - .has_value()) { - return tsl::errors::Unimplemented("Unrecognized reduction computation"); - } - return OkStatus(); -} - -template -NcclAllReduceConfig GetNcclAllReduceConfig(OpT op) { - std::optional reduction_kind = - NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( - op.getComputation()); - CHECK(reduction_kind.has_value()); - - NcclAllReduceConfig config; - config.config = - GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()); - config.reduction_kind = *reduction_kind; - return config; -} - -template -bool IsDegenerate(OpT op, int64_t replica_count, int64_t partition_count) { - return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds()) - .IsDegenerate(replica_count, partition_count); -} - -template -CollectiveOpGroupMode GetGroupMode(OpT op) { - return GetNcclAllReduceConfig(op).config.group_mode; -} - -} // namespace impl - -std::optional -NcclAllReduceReduceScatterThunkBase::MatchAllReduceComputation( - mlir::Region& computation) { - mlir::Block& block = computation.front(); - StatusOr reduction_op = FindReductionOp(block); - if (!reduction_op.ok()) return std::nullopt; - StatusOr opcode = MhloToHloOpcode(*reduction_op); - if (!opcode.ok()) return std::nullopt; - // Match the operation to a reduction kind. We can represent and/or of pred as - // min/max. This works because pred is stored as an 8-bit int of value 0 or 1. - PrimitiveType type = - TypeToShape(block.getArgument(0).getType()).element_type(); - if (type == PRED) { - switch (opcode.value()) { - case HloOpcode::kAnd: - return ReductionKind::MIN; - case HloOpcode::kOr: - return ReductionKind::MAX; - default: - return std::nullopt; - } - } else if (primitive_util::IsComplexType(type)) { - // Only addition is supported for complex types. - if (*opcode == HloOpcode::kAdd) { - return ReductionKind::SUM; - } else { - return std::nullopt; - } - } else { - switch (*opcode) { - case HloOpcode::kAdd: - return ReductionKind::SUM; - case HloOpcode::kMultiply: - return ReductionKind::PRODUCT; - case HloOpcode::kMaximum: - return ReductionKind::MAX; - case HloOpcode::kMinimum: - return ReductionKind::MIN; - default: - return std::nullopt; - } - } -} - -NcclAllReduceReduceScatterThunkBase::NcclAllReduceReduceScatterThunkBase( - Thunk::Kind kind, ThunkInfo thunk_info, NcclAllReduceConfig config, - std::vector buffers, bool is_sync) - : NcclCollectiveThunk(kind, thunk_info, is_sync), - config_(std::move(config)), - buffers_(std::move(buffers)) { - CHECK_EQ(config_.config.operand_count, buffers_.size()); -} - -NcclAllReduceStartThunk::NcclAllReduceStartThunk(ThunkInfo thunk_info, - AllReduceStartOp op, - std::vector buffers) - : NcclAllReduceReduceScatterThunkBase(Thunk::kNcclAllReduceStart, - thunk_info, - impl::GetNcclAllReduceConfig(op), - std::move(buffers), op.getIsSync()) {} - -Status NcclAllReduceStartThunk::CheckImplementable(AllReduceStartOp op, - int64_t replica_count, - int64_t partition_count) { - return AddOpDescription( - impl::CheckImplementable(op, Thunk::kNcclAllReduceStart), op, - replica_count, partition_count); -} - -bool NcclAllReduceStartThunk::IsDegenerate(AllReduceStartOp op, - int64_t replica_count, - int64_t partition_count) { - return impl::IsDegenerate(op, replica_count, partition_count); -} - -CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode( - AllReduceStartOp op) { - return impl::GetGroupMode(op); -} - -Status NcclAllReduceStartThunk::RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, - ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, buffers_, - config_.config.operand_element_type)); - return ::xla::gpu::RunAllReduce(config_.reduction_kind, device_buffers, - stream, comm); -} - -NcclReduceScatterStartThunk::NcclReduceScatterStartThunk( - ThunkInfo thunk_info, ReduceScatterStartOp op, - std::vector buffers) - : NcclAllReduceReduceScatterThunkBase(Thunk::kNcclReduceScatterStart, - thunk_info, - impl::GetNcclAllReduceConfig(op), - std::move(buffers), op.getIsSync()) {} - -/*static*/ Status NcclReduceScatterStartThunk::CheckImplementable( - ReduceScatterStartOp op, int64_t replica_count, int64_t partition_count) { - return AddOpDescription( - impl::CheckImplementable(op, Thunk::kNcclReduceScatterStart), op, - replica_count, partition_count); -} - -/*static*/ bool NcclReduceScatterStartThunk::IsDegenerate( - ReduceScatterStartOp op, int64_t replica_count, int64_t partition_count) { - return impl::IsDegenerate(op, replica_count, partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclReduceScatterStartThunk::GetGroupMode( - ReduceScatterStartOp op) { - return impl::GetGroupMode(op); -} - -Status NcclReduceScatterStartThunk::RunNcclCollective( - const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, buffers_, - config_.config.operand_element_type)); - return ::xla::gpu::RunReduceScatter(config_.reduction_kind, device_buffers, - stream, comm); -} - -Status RunReduceScatter(ReductionKind reduction_kind, - std::vector& buffers, - se::Stream& stream, ncclComm_t comm) { -#if XLA_ENABLE_XCCL - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing reduce-scatter from device ordinal: " - << device_ordinal; - - ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind); - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - int num_participants = 0; - XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants)); - - XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclReduceScatter)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - // buffer.element_count is the source buffers element count. For - // ncclReduceScatter, we need the destination buffers element count. - TF_RET_CHECK(element_count % num_participants == 0) - << "Source buffer was not an exact multiple of the number of " - "participants."; - - int64_t recv_count = element_count / num_participants; - VLOG(3) << absl::StreamFormat( - "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, " - "recvcount=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, recv_count, static_cast(comm), - gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer, - recv_count, dtype, reduce_op, - comm, gpu_stream)); - } - XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); - - VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal; - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_all_reduce_thunk.h b/xla/service/gpu/nccl_all_reduce_thunk.h deleted file mode 100644 index b70e7dc317f15..0000000000000 --- a/xla/service/gpu/nccl_all_reduce_thunk.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ - -#include -#include - -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" - -namespace xla { -namespace gpu { - -struct NcclAllReduceConfig { - NcclCollectiveConfig config; - ReductionKind reduction_kind; -}; - -// Thunk that performs a NCCL-based All-Reduce or Reduce-Scatter among CUDA -// GPU-based replicas. -class NcclAllReduceReduceScatterThunkBase : public NcclCollectiveThunk { - public: - static std::optional MatchAllReduceComputation( - mlir::Region& computation); - - NcclAllReduceReduceScatterThunkBase(Kind kind, ThunkInfo thunk_info, - NcclAllReduceConfig config, - std::vector buffers, - bool is_sync); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - - const NcclAllReduceConfig config_; - const std::vector buffers_; -}; - -// ----------------------------------------------------------------------------- -// AllReduce thunk. -// ----------------------------------------------------------------------------- - -class NcclAllReduceStartThunk : public NcclAllReduceReduceScatterThunkBase { - public: - NcclAllReduceStartThunk(ThunkInfo thunk_info, - mlir::lmhlo_gpu::AllReduceStartOp op, - std::vector buffers); - - static const char* GetHloOpName() { return "all-reduce-start"; } - - static Status CheckImplementable(mlir::lmhlo_gpu::AllReduceStartOp op, - int64_t replica_count, - int64_t partition_count); - static bool IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op, - int64_t replica_count, int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode( - mlir::lmhlo_gpu::AllReduceStartOp op); - - protected: - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; -}; - -// ----------------------------------------------------------------------------- -// ReduceScatter thunk -// ----------------------------------------------------------------------------- -class NcclReduceScatterStartThunk : public NcclAllReduceReduceScatterThunkBase { - public: - NcclReduceScatterStartThunk(ThunkInfo thunk_info, - mlir::lmhlo_gpu::ReduceScatterStartOp op, - std::vector buffers); - - static const char* GetHloOpName() { return "reduce-scatter-start"; } - - static Status CheckImplementable(mlir::lmhlo_gpu::ReduceScatterStartOp op, - int64_t replica_count, - int64_t partition_count); - static bool IsDegenerate(mlir::lmhlo_gpu::ReduceScatterStartOp op, - int64_t replica_count, int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode( - mlir::lmhlo_gpu::ReduceScatterStartOp op); - - protected: - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; -}; - -// ----------------------------------------------------------------------------- - -Status RunAllReduce(ReductionKind reduction_kind, - std::vector& buffers, se::Stream& stream, - ncclComm_t comm); - -Status RunReduceScatter(ReductionKind reduction_kind, - std::vector& buffers, - se::Stream& stream, ncclComm_t comm); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ diff --git a/xla/service/gpu/nccl_all_to_all_thunk.cc b/xla/service/gpu/nccl_all_to_all_thunk.cc deleted file mode 100644 index a94451fb62822..0000000000000 --- a/xla/service/gpu/nccl_all_to_all_thunk.cc +++ /dev/null @@ -1,209 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_all_to_all_thunk.h" - -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/shape_util.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo_gpu::AllToAllStartOp; - -namespace impl { -NcclAllToAllConfig GetNcclAllToAllConfig(AllToAllStartOp op) { - NcclAllToAllConfig config; - // FIXME(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids - // attribute and it should be removed. - config.config = GetNcclCollectiveConfigForMlir(op, std::nullopt); - config.has_split_dimension = op.getSplitDimension().has_value(); - return config; -} - -Status CheckImplementable(AllToAllStartOp op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - std::optional split_dim = op.getSplitDimension(); - for (mlir::Value operand : op.getInputs()) { - TF_RETURN_IF_ERROR(IsValidOperand(operand, Thunk::kNcclAllToAll)); - Shape shape = GetShape(operand); - if (split_dim && - !ShapeUtil::IsEffectivelyMostMajorDimension(shape, *split_dim)) { - return tsl::errors::Unimplemented( - "all-to-all split dim %u is not the most major in input shape %s", - *split_dim, shape.ToString(/*print_layout=*/true)); - } - } - return OkStatus(); -} -} // namespace impl - -NcclAllToAllStartThunk::NcclAllToAllStartThunk( - ThunkInfo thunk_info, AllToAllStartOp op, - std::vector buffers) - : NcclCollectiveThunk(Thunk::kNcclAllToAllStart, thunk_info, - op.getIsSync()), - config_(impl::GetNcclAllToAllConfig(op)), - buffers_(std::move(buffers)) { - CHECK_EQ(config_.config.operand_count, buffers_.size()); -} - -/*static*/ Status NcclAllToAllStartThunk::CheckImplementable( - AllToAllStartOp op, int64_t replica_count, int64_t partition_count) { - return AddOpDescription( - impl::CheckImplementable(op), op, replica_count, partition_count); -} - -/*static*/ bool NcclAllToAllStartThunk::IsDegenerate(AllToAllStartOp op, - int64_t replica_count, - int64_t partition_count) { - return impl::GetNcclAllToAllConfig(op).config.IsDegenerate(replica_count, - partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclAllToAllStartThunk::GetGroupMode( - AllToAllStartOp op) { - return impl::GetNcclAllToAllConfig(op).config.group_mode; -} - -Status NcclAllToAllStartThunk::RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, - ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, buffers_, - config_.config.operand_element_type)); - return xla::gpu::RunAllToAll(config_.has_split_dimension, device_buffers, - stream, comm); -} - -Status RunAllToAll(bool has_split_dimension, - std::vector& buffers, se::Stream& stream, - ncclComm_t comm) { -#if XLA_ENABLE_XCCL - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - int num_participants; - XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants)); - - XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); - // AllToAll can operate in two modes. Either it specifies a split dimension, - // in which case inputs are split and outputs concatenated in that dimension - // (here, we only support dimension 0), or it takes a list of inputs - // and produces a tuple of outputs. - if (has_split_dimension) { - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const uint8_t* send_buffer = - static_cast(buffer.source_buffer.opaque()); - uint8_t* recv_buffer = - static_cast(buffer.destination_buffer.opaque()); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllToAll)); - auto [dtype, multiplier] = dtype_and_multiplier; - int64_t element_count = buffer.element_count; - - TF_RET_CHECK(element_count % num_participants == 0) - << "Buffer was not an exact multiple of the number of participants."; - size_t chunk_elements = element_count / num_participants; - size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType( - buffer.element_type); - - for (int rank = 0; rank < num_participants; ++rank) { - VLOG(3) << absl::StreamFormat( - "Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - send_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes, - chunk_elements * multiplier, dtype, - rank, comm, gpu_stream)); - - VLOG(3) << absl::StreamFormat( - "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - recv_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank, - static_cast(comm), gpu_stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes, - chunk_elements * multiplier, dtype, - rank, comm, gpu_stream)); - } - } - } else { - TF_RET_CHECK(buffers.size() == num_participants) - << "Number of inputs didn't match the number of participants."; - - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const uint8_t* send_buffer = - static_cast(buffer.source_buffer.opaque()); - uint8_t* recv_buffer = - static_cast(buffer.destination_buffer.opaque()); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllToAll)); - auto [dtype, multiplier] = dtype_and_multiplier; - int64_t element_count = buffer.element_count * multiplier; - - VLOG(3) << absl::StreamFormat( - "Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - send_buffer, element_count, i, static_cast(comm), - gpu_stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype, - /*rank=*/i, comm, gpu_stream)); - - VLOG(3) << absl::StreamFormat( - "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - recv_buffer, element_count, i, static_cast(comm), - gpu_stream); - - XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype, - /*rank=*/i, comm, gpu_stream)); - } - } - XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); - - VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal; - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_all_to_all_thunk.h b/xla/service/gpu/nccl_all_to_all_thunk.h deleted file mode 100644 index e6e60c5d0fb65..0000000000000 --- a/xla/service/gpu/nccl_all_to_all_thunk.h +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ - -#include - -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" - -namespace xla { -namespace gpu { - -struct NcclAllToAllConfig { - NcclCollectiveConfig config; - bool has_split_dimension; -}; - -// Thunk that performs a NCCL-based All-to-All among CUDA GPU-based replicas. -class NcclAllToAllStartThunk : public NcclCollectiveThunk { - public: - NcclAllToAllStartThunk(ThunkInfo thunk_info, - mlir::lmhlo_gpu::AllToAllStartOp op, - std::vector buffers); - - // Returns whether the given instruction can be lowered to a nccl all-to-all - // call. - static Status CheckImplementable(mlir::lmhlo_gpu::AllToAllStartOp op, - int64_t replica_count, - int64_t partition_count); - - static const char* GetHloOpName() { return "all-to-all-start"; } - static bool IsDegenerate(mlir::lmhlo_gpu::AllToAllStartOp op, - int64_t replica_count, int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode( - mlir::lmhlo_gpu::AllToAllStartOp op); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; - - private: - const NcclAllToAllConfig config_; - const std::vector buffers_; -}; - -Status RunAllToAll(bool has_split_dimension, - std::vector& buffers, se::Stream& stream, - ncclComm_t comm); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_ALL_TO_ALL_THUNK_H_ diff --git a/xla/service/gpu/nccl_clique_key.cc b/xla/service/gpu/nccl_clique_key.cc new file mode 100644 index 0000000000000..eebfd23b8dc05 --- /dev/null +++ b/xla/service/gpu/nccl_clique_key.cc @@ -0,0 +1,123 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/nccl_clique_key.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xla/service/global_device_id.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// NcclCliqueKey +//===----------------------------------------------------------------------===// + +NcclCliqueKey::NcclCliqueKey(std::vector devices, + int64_t stream_id, AsyncStreamKind stream_kind) + : devices_(std::move(devices)), + stream_id_(stream_id), + stream_kind_(stream_kind) {} + +absl::Span NcclCliqueKey::devices() const { + return devices_; +} + +int64_t NcclCliqueKey::stream_id() const { return stream_id_; } + +std::optional NcclCliqueKey::rank(GlobalDeviceId id) const { + if (auto it = absl::c_find(devices_, id); it != devices_.end()) { + return it - devices_.begin(); + } + return std::nullopt; +} + +bool NcclCliqueKey::IsSubsetOf(const NcclCliqueKey& other) const { + return stream_id_ == other.stream_id_ && + absl::c_all_of(devices_, [&](GlobalDeviceId id) { + return absl::c_linear_search(other.devices_, id); + }); +} + +std::string NcclCliqueKey::ToString() const { + return absl::StrFormat("devices=[%s]; stream=%d", + GlobalDeviceIdsToString(devices_), stream_id_); +} + +bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { + return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_; +} + +bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { + if (a.devices_.size() < b.devices_.size()) return true; + if (b.devices_.size() < a.devices_.size()) return false; + + if (a.devices_ < b.devices_) return true; + if (b.devices_ < a.devices_) return false; + + return a.stream_id_ < b.stream_id_; +} + +bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b) { + if (a.devices_.size() > b.devices_.size()) return true; + if (b.devices_.size() > a.devices_.size()) return false; + + if (a.devices_ > b.devices_) return true; + if (b.devices_ > a.devices_) return false; + + // We still use `<` to order by stream id as we want to acquire sync cliques + // before async ones. + return a.stream_id_ < b.stream_id_; +} + +//===----------------------------------------------------------------------===// +// NcclCliqueId +//===----------------------------------------------------------------------===// + +NcclCliqueId::NcclCliqueId() { std::fill(data_.begin(), data_.end(), 0); } + +NcclCliqueId::NcclCliqueId(char bytes[kSize]) { + std::copy(bytes, bytes + kSize, data_.data()); +} + +absl::StatusOr NcclCliqueId::FromString(std::string_view str) { + if (str.size() != kSize) { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid NCCL clique id size: %d , expected %d bytes", + str.size(), kSize)); + } + char bytes[kSize]; + std::copy(str.data(), str.data() + kSize, bytes); + return NcclCliqueId(bytes); +} + +absl::Span NcclCliqueId::data() const { return data_; } + +std::string NcclCliqueId::ToString() const { + return std::string(data_.data(), data_.size()); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/nccl_clique_key.h b/xla/service/gpu/nccl_clique_key.h new file mode 100644 index 0000000000000..dbd7ba1200b72 --- /dev/null +++ b/xla/service/gpu/nccl_clique_key.h @@ -0,0 +1,162 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_NCCL_CLIQUE_KEY_H_ +#define XLA_SERVICE_GPU_NCCL_CLIQUE_KEY_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/service/global_device_id.h" + +namespace xla::gpu { + +// A standalone library without any dependencies on NCCL that allows us to +// include this header in all of XLA without worrying about NCCL availability. + +//===----------------------------------------------------------------------===// +// AsyncStreamKind +//===----------------------------------------------------------------------===// + +// We include a stream kind into the NCCL clique key because in XLA we do not +// share communicators for collective operations of different kind (CUDA-graph +// launched, async collectives, sync collectives) as it can lead to dead locks. +// +// We carefully isolate different kinds of collectives using separate +// communicators and guarantee that all collective operations have a total order +// that will not create a deadlock. +// +// See more details in `nccl_clique` library. + +enum class AsyncStreamKind : int64_t { + kCollective = 0, // Stream for asynchronous collective ops. + kP2P0 = 1, // One Stream for P2P Send and Recv ops. + kP2P1 = 2, // Another Stream for P2P Send and Recv ops. +}; + +constexpr static int64_t kAsyncStreamTotal = + static_cast(AsyncStreamKind::kP2P1) + 1; + +// Assigns a unique ID to a stream for asynchronous or synchronous execution. +// These IDs can be used, for example, to look up the NCCL communicator. +inline uint64_t GetStreamId( + bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective) { + return is_async ? static_cast(stream_kind) + 1 : 0; +} + +//===----------------------------------------------------------------------===// +// NcclCliqueKey +//===----------------------------------------------------------------------===// + +// Key for naming up a particular NCCL clique. This is just a set of unique +// device IDs (i.e. GPU IDs) and a stream_id. The device IDs must be global +// within a cluster. The stream_id is used to create different NCCL clique and +// communicators for collectives executed on different streams within an +// executable. +class NcclCliqueKey { + public: + explicit NcclCliqueKey( + std::vector devices, int64_t stream_id = 0, + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); + + absl::Span devices() const; + + int64_t stream_id() const; + + // Returns the rank of the global device in the clique. + std::optional rank(GlobalDeviceId id) const; + + // Returns true if this clique is a subset of `other`: both cliques have the + // same `stream_id` and all clique devices are part of `other` clique. + bool IsSubsetOf(const NcclCliqueKey& other) const; + + // Returns the stream kind for this clique key, + // stream kind will be used to specify what configuration + // to pass for each type of operation. + AsyncStreamKind stream_kind() const { return stream_kind_; } + + std::string ToString() const; + + template + friend H AbslHashValue(H h, const NcclCliqueKey& k); + + friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); + friend bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b); + friend bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b); + + private: + std::vector devices_; + int64_t stream_id_; + AsyncStreamKind stream_kind_; +}; + +template +H AbslHashValue(H h, const NcclCliqueKey& k) { + return H::combine(std::move(h), k.devices_, k.stream_id_); +} + +bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); +bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b); + +//===----------------------------------------------------------------------===// +// NcclCliqueId +//===----------------------------------------------------------------------===// + +// All collective cliques have a globally unique ID (128 bytes long for NCCL) +// that allows multiple hosts and devices to find each other and agree who is a +// member of a clique. It is a user responsibility to redistribute this id to +// all participating hosts (i.e. JAX uses shared KV store for that). For single +// host collective operations XLA automatically generates a unique id for local +// cliques (cliques consisting of devices visible from a process). + +// A globally unique collective clique identifier. +class NcclCliqueId { + public: + static constexpr int32_t kSize = 128; + + static absl::StatusOr FromString(std::string_view str); + + NcclCliqueId(); + explicit NcclCliqueId(char bytes[kSize]); + + absl::Span data() const; + std::string ToString() const; + + template + friend H AbslHashValue(H h, const NcclCliqueId& id); + + private: + std::array data_; +}; + +template +H AbslHashValue(H h, const NcclCliqueId& id) { + return H::combine(std::move(h), id.data()); +} + +// A callback to get a unique clique id (see `ncclUniqueId` documentation). +using NcclCliqueIdCallback = // NOLINT + std::function(const NcclCliqueKey&)>; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_NCCL_CLIQUE_KEY_H_ diff --git a/xla/service/gpu/nccl_clique_key_test.cc b/xla/service/gpu/nccl_clique_key_test.cc new file mode 100644 index 0000000000000..ca804a4e18669 --- /dev/null +++ b/xla/service/gpu/nccl_clique_key_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/nccl_clique_key.h" + +#include +#include + +#include "absl/container/btree_map.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { + +TEST(NcclCliqueKeyTest, IsSubsetOf) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id1}, 0); + NcclCliqueKey key1({id0, id1, id2, id3}, 0); + NcclCliqueKey key2({id0, id1, id2, id3}, 1); + NcclCliqueKey key3({id1, id2, id3}, 0); + + EXPECT_TRUE(key0.IsSubsetOf(key1)); + EXPECT_FALSE(key0.IsSubsetOf(key2)); + EXPECT_FALSE(key0.IsSubsetOf(key3)); +} + +TEST(NcclCliqueKeyTest, Compare) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id1}, 0); + NcclCliqueKey key1({id1, id2, id3}, 0); + + EXPECT_LT(key0, key1); + EXPECT_GT(key1, key0); +} + +TEST(NcclCliqueKeyTest, BtreeIterationOrder) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + NcclCliqueKey key0({id0, id2}, 0); + NcclCliqueKey key1({id0, id1, id2, id3}, 0); + + absl::btree_map> map; + map[key0] = 0; + map[key1] = 1; + + EXPECT_EQ(map.begin()->first, key1); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/nccl_collective_permute_thunk.cc b/xla/service/gpu/nccl_collective_permute_thunk.cc deleted file mode 100644 index d1a8180d58791..0000000000000 --- a/xla/service/gpu/nccl_collective_permute_thunk.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_collective_permute_thunk.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_set.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla_data.pb.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo_gpu::CollectivePermuteStartOp; - -namespace impl { - -CollectiveOpGroupMode GetGroupMode(CollectivePermuteStartOp op) { - return GetCollectiveOpGroupMode(op.getChannelId().has_value(), std::nullopt) - .value(); -} - -NcclP2PConfig GetNcclP2PConfig(CollectivePermuteStartOp op, - int64_t replica_count, int64_t partition_count) { - NcclP2PConfig collective_permute_config; - auto& config = collective_permute_config.config; - - config.operand_count = 1; - const Shape shape = GetShape(op.getOperand()); - config.operand_element_type.push_back(shape.element_type()); - config.SetCollectiveOpKindAndID(op); - config.group_mode = GetGroupMode(op); - - // With a collective permute, all execution instances together form one - // replica group. - const int64_t num_participants = - config.group_mode == CollectiveOpGroupMode::kCrossReplica - ? replica_count - : partition_count; - config.replica_groups.emplace_back(); - ReplicaGroup& replica_group = config.replica_groups.front(); - for (int i = 0; i < num_participants; ++i) { - replica_group.add_replica_ids(i); - } - - const std::vector> source_target_pairs = - ConvertNx2Attribute(op.getSourceTargetPairs()).value(); - - for (const std::pair& source_target : source_target_pairs) { - int64_t source = source_target.first; - int64_t target = source_target.second; - - collective_permute_config.id_to_source_target.insert({target, {}}) - .first->second.source = source; - collective_permute_config.id_to_source_target.insert({source, {}}) - .first->second.target = target; - } - - return collective_permute_config; -} - -// The collective permute is degenerate if all source-target pairs are identity, -// and all the IDs appear in the list. -bool IsDegenerate(CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count) { - const std::vector> source_target_pairs = - ConvertNx2Attribute(op.getSourceTargetPairs()).value(); - // Each ID can appear only once as a source and as a target. So if all pairs - // are identity, all IDs must appear in the list is the size == number of - // replicas/partitions. - const int64_t expected_size = - op.getChannelId() ? partition_count : replica_count; - return source_target_pairs.size() == expected_size && - absl::c_all_of(source_target_pairs, - [](const std::pair& source_target) { - return source_target.first == source_target.second; - }); -} - -Status CheckImplementable(CollectivePermuteStartOp op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - return IsValidOperand(op.getOperand(), Thunk::kNcclCollectivePermute); -} - -} // namespace impl - -NcclCollectivePermuteStartThunk::NcclCollectivePermuteStartThunk( - ThunkInfo thunk_info, CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count, const Buffer& buffer) - : NcclCollectiveThunk(Thunk::kNcclCollectivePermuteStart, thunk_info, - op.getIsSync()), - config_(GetNcclP2PConfig(op, replica_count, partition_count)), - buffer_(buffer) {} - -/*static*/ NcclP2PConfig NcclCollectivePermuteStartThunk::GetNcclP2PConfig( - CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count) { - return impl::GetNcclP2PConfig(op, replica_count, partition_count); -} - -/*static*/ Status NcclCollectivePermuteStartThunk::CheckImplementable( - CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count) { - return AddOpDescription( - impl::CheckImplementable(op), op, replica_count, partition_count); -} - -/*static*/ bool NcclCollectivePermuteStartThunk::IsDegenerate( - CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count) { - return impl::IsDegenerate(op, replica_count, partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclCollectivePermuteStartThunk::GetGroupMode( - CollectivePermuteStartOp op) { - return impl::GetGroupMode(op); -} - -Status NcclCollectivePermuteStartThunk::RunNcclCollective( - const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, {buffer_}, - config_.config.operand_element_type)); - TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; - - TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id, - params.nccl_params.GetGlobalDeviceId()); - TF_ASSIGN_OR_RETURN( - const DeviceAssignment::LogicalID current_logical_id, - params.nccl_params.device_assn->LogicalIdForDevice(global_device_id)); - const int64_t current_id = - config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica - ? current_logical_id.replica_id - : current_logical_id.computation_id; - std::string device_string = GetDeviceString(params.nccl_params); - - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); - - return ::xla::gpu::RunCollectivePermute(source_target, device_buffers[0], - stream, comm, device_string, - current_id); -} - -Status RunCollectivePermute(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, - ncclComm_t comm, absl::string_view device_string, - int64_t current_id) { -#if XLA_ENABLE_XCCL - // Determine the source and target IDs for this instance. The source ID is the - // ID which will copy its data to this instance. The destination ID is the ID - // to which this instance will copy its data. Either are optional. - // - // No source and no dest: - // - this instance does not actually participate, no one send it any data and - // it does not have to send any data as well. Since there is no dest, - // just memzero() the dest buffer as required by the collective permute - // semantics. - // - // No source, dest present: - // - This instance has to send data to 'dest' Issue an send of the input. - // Since there is no source, memzero the dest buffer. - // - // Source present, no destination: - // - This instance received data from the source, does not have to send data - // to anyone, Issue a receive. - // - // Source and dest both present: - // - Issue a send of the input to dest, receive for the output from the - // src. - // - // - - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing collective permute from device ordinal: " - << device_ordinal << "current_id " << current_id; - - const std::optional source_id = source_target.source; - const std::optional target_id = source_target.target; - - se::DeviceMemoryBase src_addr = buffer.source_buffer; - se::DeviceMemoryBase dest_addr = buffer.destination_buffer; - - VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d", - device_string, current_id, - source_id.value_or(-1), target_id.value_or(-1)); - - // ncclGroupStart/end API is needed only if we will issue both ncclSend and - // ncclRecv API calls. - const bool is_nccl_group_needed = (target_id && source_id); - if (is_nccl_group_needed) { - XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart()); - } - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclCollectivePermute)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - // Send source buffer to target peer if needed. - if (target_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - device_string, src_addr.opaque(), element_count, *target_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype, - *target_id, comm, gpu_stream)); - } - - // Receive data from the source peer to the destination buffer. - if (source_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " - "stream=%p)", - device_string, dest_addr.opaque(), element_count, *source_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype, - *source_id, comm, gpu_stream)); - } - if (is_nccl_group_needed) { - XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd()); - } - - if (!source_id) { - // If there is no source peer, i.e. no one send us any data, zero out dest - // buffer. - VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", - device_string); - stream.ThenMemZero(&dest_addr, dest_addr.size()); - } - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_collective_permute_thunk.h b/xla/service/gpu/nccl_collective_permute_thunk.h deleted file mode 100644 index 8f2a675e782c2..0000000000000 --- a/xla/service/gpu/nccl_collective_permute_thunk.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ - -#include - -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_p2p_thunk_common.h" - -namespace xla { -namespace gpu { - -// Thunk that performs a NCCL-based collective permute. -class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { - public: - static NcclP2PConfig GetNcclP2PConfig( - mlir::lmhlo_gpu::CollectivePermuteStartOp op, int64_t replica_count, - int64_t partition_count); - - static Status CheckImplementable(mlir::lmhlo_gpu::CollectivePermuteStartOp op, - int64_t replica_count, - int64_t partition_count); - - static bool IsDegenerate(mlir::lmhlo_gpu::CollectivePermuteStartOp op, - int64_t replica_count, int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode( - mlir::lmhlo_gpu::CollectivePermuteStartOp op); - - static const char* GetHloOpName() { return "collective-permute-start"; } - - NcclCollectivePermuteStartThunk(ThunkInfo thunk_info, - mlir::lmhlo_gpu::CollectivePermuteStartOp op, - int64_t replica_count, - int64_t partition_count, - const Buffer& buffer); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; - - private: - const NcclP2PConfig config_; - const Buffer buffer_; -}; - -Status RunCollectivePermute(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, - ncclComm_t comm, absl::string_view device_string, - int64_t current_id); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ diff --git a/xla/service/gpu/nccl_collective_thunk.cc b/xla/service/gpu/nccl_collective_thunk.cc deleted file mode 100644 index 54baa7228c2e4..0000000000000 --- a/xla/service/gpu/nccl_collective_thunk.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_collective_thunk.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/global_device_id.h" -#include "xla/stream_executor/gpu/gpu_activation.h" -#include "xla/util.h" - -namespace xla { -namespace gpu { -namespace { - -bool IsTypeSupportedByNccl(PrimitiveType element_type, - Thunk::Kind reduction_op) { - switch (element_type) { - case S8: - case PRED: - case U8: - case S32: - case U32: - case S64: - case U64: - case F16: - case F32: - case F64: -#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM - case BF16: -#endif - case C64: - case C128: - return true; - case S16: - case U16: - // 16-bit integer reductions are not directly supported by NCCL and cannot - // be implicitly converted into other 16-bit types like ncclFloat16 as - // they involve actual computation and not just data movement. - case F8E5M2: - case F8E4M3FN: - return !IsReductionCollective(reduction_op); - default: - return false; - } -} - -} // namespace - -// This file runs collective ops (i.e. ops that communicate between multiple -// GPUs) using NCCL. -// -// Here's a high-level overview of how running an op works. -// -// - Multiple threads call ExecuteOnStream. -// - All threads that "go together" (i.e. are participating in the "same" -// collective op) choose the same Rendezvous object from a global map. -// - Once all threads have arrived at the Rendezvous, we know exactly which -// GPUs are participating in the op, so we get or create a NcclClique -// containing those GPUs. -// - We perform the NCCL operation using the clique. - -NcclCollectiveConfig::NcclCollectiveConfig() = default; -NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default; -NcclCollectiveConfig::~NcclCollectiveConfig() = default; -NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) = - default; - -// Returns if the collective communication operation is degenerate because all -// the groups formed by the operation are singleton. A given op can be -// degenerate under several conditions, corresponding to the modes supported -// in GetParticipatingDevices(). -// 1. no channel id, use_global_device_ids = false: -// degenerate if replica_groups are singleton, or groups empty and -// replica_count == 1. -// 2. channel_id is set, use_global_device_ids = false: -// degenerate if replica_groups are singleton and num_partitions == 1, -// or groups empty and num_replicas == 1 && num_partitions == 1. -// 3. channel_id is set, use_global_device_ids = true (flattened-ids): -// degenerate if replica_groups are singleton (groups cannot be empty). -// 4. no channel_id, no use_global_device_ids: -// identical to 1. -// 5. channel_id is set, no use_global_device_ids: -// degenerate if replica_groups are singleton or group emty and -// num_partitions == 1 (since replica groups contain partition ids). -// -bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count, - int64_t partition_count) const { - bool groups_empty = replica_groups.empty(); - - // check if all replica_groups are singleton. If not, then the operation is - // not degenerate. - bool all_groups_singleton = - !groups_empty && - absl::c_all_of(replica_groups, [](const ReplicaGroup& group) { - return group.replica_ids_size() == 1; - }); - - switch (group_mode) { - case CollectiveOpGroupMode::kCrossReplica: - return all_groups_singleton || (groups_empty && replica_count == 1); - case CollectiveOpGroupMode::kCrossPartition: - return all_groups_singleton || (groups_empty && partition_count == 1); - case CollectiveOpGroupMode::kCrossReplicaAndPartition: - return (all_groups_singleton && partition_count == 1) || - (groups_empty && replica_count == 1 && partition_count == 1); - case CollectiveOpGroupMode::kFlattenedID: - CHECK(!groups_empty) - << "replica groups cannot be empty if use_global_device_ids = true"; - return all_groups_singleton; - default: - CHECK(0) << "Invalid collective op mode"; - return false; - } -} - -NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, - bool is_sync) - : Thunk(kind, thunk_info) { - if (!is_sync) { - async_ = std::make_unique(); - } -} - -/* static */ bool NcclCollectiveThunk::NcclIsEnabled() { -#if XLA_ENABLE_XCCL - return true; -#else - return false; -#endif -} - -/* static */ Status NcclCollectiveThunk::CheckImplementable() { - if (!NcclIsEnabled()) { - return tsl::errors::Unimplemented("NCCL is not enabled"); - } - return OkStatus(); -} - -#if XLA_ENABLE_XCCL -StatusOr LockNcclComm( - const NcclExecuteParams& params, - const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, - bool enable_clique_optimization) { - TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, - params.GetGlobalDeviceId()); - - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(global_device_id, *params.device_assn, - replica_groups, group_mode)); - - if (IsGlobalNcclConfig() && - (participants.size() != params.device_assn->replica_count())) { - return InvalidArgument( - "Partial replica groups are not allowed when using NCCL_COMM_ID " - "environment configuration."); - } - - auto it = absl::c_find(participants, global_device_id); - TF_RET_CHECK(it != participants.end()); - int rank = it - participants.begin(); - - std::vector local_devices; - if (params.gpu_global_device_ids) { - local_devices.reserve(params.gpu_global_device_ids->size()); - for (const auto& entry : *params.gpu_global_device_ids) { - local_devices.push_back(entry.second); - } - } - size_t num_local_participants = GetNumLocalParticipants( - participants, params.gpu_global_device_ids ? &local_devices : nullptr); - - bool is_local = participants.size() == num_local_participants; - TF_ASSIGN_OR_RETURN( - const NcclUniqueIdCallback* unique_id_callback, - GetNcclUniqueIdCallback(params.nccl_unique_id_callback, is_local)); - - se::gpu::ScopedActivateExecutorContext scoped_context(params.stream_executor); - - return AcquireNcclComm(params.run_id, OpId(op_id), std::move(participants), - num_local_participants, *unique_id_callback, rank, - stream_id, enable_clique_optimization); -} -#endif // XLA_ENABLE_XCCL - -StatusOr> ConvertToDeviceBuffers( - const Thunk::ExecuteParams& params, - const std::vector& buffers, - const std::vector& element_types) { - if (buffers.size() != element_types.size()) - return FailedPrecondition("Mismatch in operand buffer counts."); - - std::vector device_buffers; - device_buffers.reserve(buffers.size()); - for (int i = 0; i < buffers.size(); ++i) { - device_buffers.emplace_back(DeviceBufferPair{ - element_types[i], buffers[i].element_count, - - params.buffer_allocations->GetDeviceAddress(buffers[i].source_buffer), - params.buffer_allocations->GetDeviceAddress( - buffers[i].destination_buffer)}); - } - return device_buffers; -} - -Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { -#if XLA_ENABLE_XCCL - VLOG(1) << absl::StreamFormat("Starting %s %s.", IsAsync() ? "async" : "sync", - Thunk::KindToString(kind())); - const int64_t stream_id = GetStreamId(); - TF_ASSIGN_OR_RETURN( - NcclComm::Lock comm, - LockNcclComm(params.nccl_params, config().replica_groups, - config().group_mode, config().op_id, stream_id, - /*enable_clique_optimization=*/false)); - - // Run the collective on main stream or using the async executor. - Status status = [&]() { - if (!IsAsync()) { - return RunNcclCollective(params, *params.stream, *comm); - } - return async_->Execute( - [this](const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) { - return RunNcclCollective(params, stream, comm); - }, - params, *comm, GetAsyncStreamKind()); - }(); - TF_RETURN_IF_ERROR(status); - - // Block host on the first call to ensure that all devices have allocated the - // required buffers for their communicators before allowing any device to - // continue enqueuing operations. Otherwise, the allocations can cause - // deadlock in the CUDA driver (b/215649390). - if (first_call_to_execute_) { - se::Stream* stream = IsAsync() - ? params.async_comms_streams[GetAsyncStreamKind()] - : params.stream; - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - first_call_to_execute_ = false; - } - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -std::string NcclCollectiveThunk::GetDeviceString( - const NcclExecuteParams& nccl_params) { - int device_ordinal = nccl_params.stream_executor->device_ordinal(); - GlobalDeviceId global_device_id = nccl_params.GetGlobalDeviceId().value(); - DeviceAssignment::LogicalID logical_id = - nccl_params.device_assn->LogicalIdForDevice(global_device_id).value(); - return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d", - logical_id.replica_id, logical_id.computation_id, - global_device_id.value(), device_ordinal); -} - -Status NcclCollectiveThunk::AsyncExecutor::Execute( - absl::FunctionRef fn, - const ExecuteParams& params, ncclComm_t comm, AsyncStreamKind stream_kind) { - se::Stream& async_comms_stream = *params.async_comms_streams[stream_kind]; - // Wait until compute inputs are ready. - async_comms_stream.ThenWaitFor(params.stream); - - TF_RETURN_IF_ERROR(fn(params, async_comms_stream, comm)); - - // Create an event on the async stream for the completion of the collective. - se::Event done_event(async_comms_stream.parent()); - TF_RET_CHECK(done_event.Init()); - async_comms_stream.ThenRecordEvent(&done_event); - - int device_ordinal = async_comms_stream.parent()->device_ordinal(); - absl::MutexLock lock(&mu_); - auto [_, was_inserted] = - done_events_.insert({device_ordinal, std::move(done_event)}); - TF_RET_CHECK(was_inserted) << "done event has not been consumed"; - return OkStatus(); -} - -Status NcclCollectiveThunk::AsyncExecutor::Await(const ExecuteParams& params) { - int device_ordinal = params.stream->parent()->device_ordinal(); - auto done_event = [this, device_ordinal] { - absl::MutexLock lock(&mu_); - return done_events_.extract(device_ordinal); - }(); - TF_RET_CHECK(done_event) << "done event not found"; - params.stream->ThenWaitFor(&done_event.mapped()); - return OkStatus(); -} - -NcclCollectiveDoneThunk::NcclCollectiveDoneThunk( - Thunk::Kind kind, ThunkInfo thunk_info, - NcclCollectiveThunk::AsyncExecutor& async) - : Thunk(kind, std::move(thunk_info)), async_(async) {} - -Status NcclCollectiveDoneThunk::ExecuteOnStream(const ExecuteParams& params) { - return async_.Await(params); -} - -Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op) { - Shape shape = GetShape(operand); - if (!LayoutUtil::IsDenseArray(shape)) { - return tsl::errors::Unimplemented( - absl::StrFormat("input is not a dense array: %s", - shape.ToString(/*print_layout=*/true))); - } - if (!IsTypeSupportedByNccl(shape.element_type(), reduction_op)) { - return tsl::errors::Unimplemented(absl::StrFormat( - "element type %s not suppored by NCCL", - primitive_util::LowercasePrimitiveTypeName(shape.element_type()))); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_collective_thunk.h b/xla/service/gpu/nccl_collective_thunk.h index e38b98ca76bd2..e69de29bb2d1d 100644 --- a/xla/service/gpu/nccl_collective_thunk.h +++ b/xla/service/gpu/nccl_collective_thunk.h @@ -1,218 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_ - -#include -#include -#include -#include - -#include "absl/functional/function_ref.h" -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla_data.pb.h" - -#if XLA_ENABLE_XCCL -#include "xla/service/gpu/nccl_utils.h" -#endif // XLA_ENABLE_XCCL - -struct ncclComm; -using ncclComm_t = ncclComm*; - -namespace xla { -namespace gpu { - -class NcclClique; - -struct NcclCollectiveConfig { - NcclCollectiveConfig(); - NcclCollectiveConfig(NcclCollectiveConfig&&); - ~NcclCollectiveConfig(); - - NcclCollectiveConfig& operator=(NcclCollectiveConfig&&); - - int64_t operand_count; - std::vector operand_element_type; - std::vector replica_groups; - RendezvousKey::CollectiveOpKind collective_op_kind; - int64_t op_id; - CollectiveOpGroupMode group_mode; - int partition_count; - int replica_count; - - template - void SetCollectiveOpKindAndID(OpT op); - bool IsDegenerate(int64_t replica_count, int64_t partition_count) const; -}; - -template -void NcclCollectiveConfig::SetCollectiveOpKindAndID(OpT op) { - if (op.getChannelId()) { - collective_op_kind = RendezvousKey::kCrossModule; - op_id = static_cast(op.getChannelId()->getHandle()); - } else { - collective_op_kind = RendezvousKey::kCrossReplica; - mlir::ModuleOp parent = op->template getParentOfType(); - mlir::IntegerAttr unique_id = - parent->getAttrOfType("hlo.unique_id"); - op_id = static_cast(unique_id.getInt()); - } -} - -template -NcclCollectiveConfig GetNcclCollectiveConfigForMlir( - OpT op, std::optional use_global_device_ids) { - NcclCollectiveConfig config; - config.operand_count = op.getInputs().size(); - config.operand_element_type.reserve(config.operand_count); - for (int i = 0; i < config.operand_count; i++) { - const Shape shape = GetShape(op.getInputs()[i]); - config.operand_element_type.push_back(shape.element_type()); - } - config.replica_groups = ConvertReplicaGroups(op.getReplicaGroups()).value(); - config.SetCollectiveOpKindAndID(op); - config.group_mode = GetCollectiveOpGroupMode(op.getChannelId().has_value(), - use_global_device_ids) - .value(); - return config; -} - -// Thunk base class for NCCL collective operations. -class NcclCollectiveThunk : public Thunk { - public: - NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, bool is_sync); - - struct Buffer { - int64_t element_count; - BufferAllocation::Slice source_buffer; - BufferAllocation::Slice destination_buffer; - mlir::Value source_value; - mlir::Value destination_value; - }; - - class AsyncExecutor { - public: - // Executes the function on the async communications stream and records a - // completion event. - Status Execute( - absl::FunctionRef - fn, - const ExecuteParams& params, ncclComm_t comm, - AsyncStreamKind stream_kind); - // Blocks the compute stream until async communication is complete. - Status Await(const ExecuteParams& params); - - private: - absl::Mutex mu_; - // Store done events (by device ordinal) for the done thunk to wait on. - absl::flat_hash_map done_events_ ABSL_GUARDED_BY(mu_); - }; - - // Returns whether NCCL operations appear possible to perform; e.g. if we - // haven't done a build with the CUDA compiler enabled, we can't compile the - // NCCL header, and thus this will be false. - // - // When this is false, the ExecuteOnStream() call will simply return a status - // error. - static bool NcclIsEnabled(); - static Status CheckImplementable(); - - // Logging support. - static std::string GetDeviceString(const NcclExecuteParams& params); - - AsyncExecutor* async_executor() { return async_.get(); } - Status ExecuteOnStream(const ExecuteParams& params) override; - - protected: - virtual Status RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, ncclComm_t comm) = 0; - virtual const NcclCollectiveConfig& config() const = 0; - virtual AsyncStreamKind GetAsyncStreamKind() const { - return kAsyncStreamCollective; - } - - private: - bool IsAsync() const { return async_ != nullptr; } - int64_t GetStreamId() const { - return xla::gpu::GetStreamId(IsAsync(), GetAsyncStreamKind()); - } - -#if XLA_ENABLE_XCCL - bool first_call_to_execute_ = true; -#endif // XLA_ENABLE_XCCL - std::unique_ptr async_; // null if not async. -}; - -class NcclCollectiveDoneThunk : public Thunk { - public: - NcclCollectiveDoneThunk(Thunk::Kind kind, ThunkInfo thunk_info, - NcclCollectiveThunk::AsyncExecutor& async); - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - NcclCollectiveThunk::AsyncExecutor& async_; -}; - -Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op); - -template -Status AddOpDescription(Status status, OpT op, int64_t replica_count, - int64_t partition_count) { - if (status.ok()) { - return status; - } - CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op); - return Status( - status.code(), - absl::StrFormat( - "%s\n" - "%s with replica_count: %d, partition_count: %d, group_mode: %s, " - "operand_count: %d\n%s", - status.message(), NcclThunkType::GetHloOpName(), replica_count, - partition_count, CollectiveOpGroupModeToString(group_mode), - op->getNumOperands() / 2, llvm_ir::DumpToString(op.getOperation()))); -} - -#if XLA_ENABLE_XCCL -// TODO(hanbinyoon): Consider moving to nccl_utils.h when deprecating Thunks. -StatusOr LockNcclComm( - const NcclExecuteParams& params, - const std::vector& replica_groups, - CollectiveOpGroupMode group_mode, int64_t op_id, int64_t stream_id, - bool enable_clique_optimization); -#endif // XLA_ENABLE_XCCL - -struct DeviceBufferPair { - PrimitiveType element_type; - int64_t element_count; - se::DeviceMemoryBase source_buffer; - se::DeviceMemoryBase destination_buffer; -}; -StatusOr> ConvertToDeviceBuffers( - const Thunk::ExecuteParams& params, - const std::vector& buffers, - const std::vector& element_types); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_ diff --git a/xla/service/gpu/nccl_p2p_thunk_common.cc b/xla/service/gpu/nccl_p2p_thunk_common.cc deleted file mode 100644 index 448ff63952cd1..0000000000000 --- a/xla/service/gpu/nccl_p2p_thunk_common.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_p2p_thunk_common.h" - -#include -#include - -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "xla/service/hlo_parser.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -StatusOr>> GetSourceTargetPairs( - mlir::DictionaryAttr frontend_attributes) { - mlir::StringAttr src_dst_string = frontend_attributes.getAs( - kSendRecvSourceTargetPairsAttr); - if (!src_dst_string) { - return absl::AbortedError( - absl::StrCat("expecting send/recv op with string attribute ", - kSendRecvSourceTargetPairsAttr)); - } - TF_ASSIGN_OR_RETURN(std::vector replica_groups, - ParseReplicaGroupsOnly(src_dst_string.str())); - std::vector> source_target_pairs; - source_target_pairs.reserve(replica_groups.size()); - for (const ReplicaGroup& replica_group : replica_groups) { - TF_RET_CHECK(replica_group.replica_ids_size() == 2); - source_target_pairs.emplace_back(replica_group.replica_ids(0), - replica_group.replica_ids(1)); - } - return source_target_pairs; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_p2p_thunk_common.h b/xla/service/gpu/nccl_p2p_thunk_common.h deleted file mode 100644 index 24b9ca2eb652c..0000000000000 --- a/xla/service/gpu/nccl_p2p_thunk_common.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_P2P_THUNK_COMMON_H_ -#define XLA_SERVICE_GPU_NCCL_P2P_THUNK_COMMON_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" - -namespace xla { -namespace gpu { - -// Records the information for implementing CollectivePermute, Send and Recv. -struct NcclP2PConfig { - // Record the target ID for sending a data and the source ID from which to - // receive a data. Either target or source can be optional. - struct SourceTargetMapEntry { - std::optional source; - std::optional target; - }; - - using IdToSourceTargetMap = - absl::flat_hash_map; - - // Returns the source and target ID corresponding to the given ID (these IDs - // are replica_ids for cross replica permute or partition_ids for cross - // partition permute). The source ID is the id which will send data to this - // ID and the target ID is the id to which this ID will send its data. Either - // can be optional. - static SourceTargetMapEntry GetSourceTarget( - const IdToSourceTargetMap& id_to_source_target, int64_t id) { - auto it = id_to_source_target.find(id); - if (it != id_to_source_target.end()) return it->second; - return SourceTargetMapEntry{}; - } - - NcclCollectiveConfig config; - IdToSourceTargetMap id_to_source_target; -}; - -// Extracts source/target pairs for send/recv from frontend attributes. -StatusOr>> GetSourceTargetPairs( - mlir::DictionaryAttr frontend_attributes); - -// Returns the GroupMode for Send and Recv. -template -std::enable_if_t || - std::is_same_v, - CollectiveOpGroupMode> -GetGroupModeForSendRecv(OpT op) { - // return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 1, - // std::nullopt) - // .value(); - return CollectiveOpGroupMode::kFlattenedID; -} - -// Constructs the NcclP2PConfig for Send and Recv. -template -std::enable_if_t || - std::is_same_v, - NcclP2PConfig> -GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count, - int64_t partition_count) { - NcclP2PConfig p2p_config; - auto& config = p2p_config.config; - - config.operand_count = 1; - const Shape shape = GetShape(op.getOperand(0)); - config.operand_element_type.push_back(shape.element_type()); - - const int64_t channel_id = op.getChannelHandle().getHandle(); - config.group_mode = GetGroupModeForSendRecv(op); - // Emulate SetCollectiveOpKindAndID. - // Send and Recv ops have a non-optional channel id while collective-permute - // has an optional channel id. We use 0 to encode the send/recv transformed - // from collective-permute without a channel id. - if (channel_id >= 1) { - config.collective_op_kind = RendezvousKey::kCrossModule; - config.op_id = channel_id; - } else { - config.collective_op_kind = RendezvousKey::kCrossReplica; - mlir::ModuleOp parent = op->template getParentOfType(); - mlir::IntegerAttr unique_id = - parent->getAttrOfType("hlo.unique_id"); - config.op_id = static_cast(unique_id.getInt()); - } - - // All execution instances of a send/recv together form a replica group. - config.replica_count = replica_count; - config.partition_count = partition_count; - const int64_t num_participants = replica_count * partition_count; - config.replica_groups.emplace_back(); - ReplicaGroup& replica_group = config.replica_groups.front(); - for (int i = 0; i < num_participants; ++i) { - replica_group.add_replica_ids(i); - } - - auto source_target_pairs = GetSourceTargetPairs(op.getFrontendAttributes()); - TF_CHECK_OK(source_target_pairs.status()); - for (const std::pair& source_target : - *source_target_pairs) { - int64_t source = source_target.first; - int64_t target = source_target.second; - - p2p_config.id_to_source_target.insert({target, {}}).first->second.source = - source; - p2p_config.id_to_source_target.insert({source, {}}).first->second.target = - target; - } - return p2p_config; -} - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_P2P_THUNK_COMMON_H_ diff --git a/xla/service/gpu/nccl_recv_thunk.cc b/xla/service/gpu/nccl_recv_thunk.cc deleted file mode 100644 index 7adce8c91d189..0000000000000 --- a/xla/service/gpu/nccl_recv_thunk.cc +++ /dev/null @@ -1,154 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_recv_thunk.h" - -#include -#include -#include -#include - -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/collective_ops_utils.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo::RecvOp; - -namespace impl { - -NcclP2PConfig GetNcclP2PConfig(RecvOp op, int64_t replica_count, - int64_t partition_count) { - return GetNcclP2PConfigForSendRecv(op, replica_count, partition_count); -} - -Status CheckImplementable(RecvOp op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - return IsValidOperand(op.getOutputs()[0], Thunk::kNcclSend); -} - -} // namespace impl - -NcclRecvThunk::NcclRecvThunk(ThunkInfo thunk_info, RecvOp op, - int64_t replica_count, int64_t partition_count, - const Buffer& buffer) - : NcclCollectiveThunk(Thunk::kNcclRecv, thunk_info, /*is_sync=*/false), - config_(GetNcclP2PConfig(op, replica_count, partition_count)), - buffer_(buffer) {} - -/*static*/ NcclP2PConfig NcclRecvThunk::GetNcclP2PConfig( - RecvOp op, int64_t replica_count, int64_t partition_count) { - return impl::GetNcclP2PConfig(op, replica_count, partition_count); -} - -/*static*/ Status NcclRecvThunk::CheckImplementable(RecvOp op, - int64_t replica_count, - int64_t partition_count) { - return AddOpDescription(impl::CheckImplementable(op), op, - replica_count, partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclRecvThunk::GetGroupMode(RecvOp op) { - return GetGroupModeForSendRecv(op); -} - -Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, {buffer_}, - config_.config.operand_element_type)); - TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; - - TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id, - params.nccl_params.GetGlobalDeviceId()); - TF_ASSIGN_OR_RETURN( - const DeviceAssignment::LogicalID current_logical_id, - params.nccl_params.device_assn->LogicalIdForDevice(global_device_id)); - // const int64_t current_id = - // config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica - // ? current_logical_id.replica_id - // : current_logical_id.computation_id; - const int64_t current_id = - current_logical_id.replica_id * config_.config.partition_count + - current_logical_id.computation_id; - VLOG(3) << "Performing Recv, replica_id: " << current_logical_id.replica_id - << ", partition_count: " << config_.config.partition_count - << ", computation_id: " << current_logical_id.computation_id; - std::string device_string = GetDeviceString(params.nccl_params); - - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); - - return ::xla::gpu::RunRecv(source_target, device_buffers[0], stream, comm, - device_string, current_id); -} - -Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, - absl::string_view device_string, int64_t current_id) { -#if XLA_ENABLE_XCCL - // Determine the source IDs for this instance. The source ID is the ID for - // the peer that will copy its data to this instance. If there is no source, - // just memzero() the destination buffer. - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal - << "current_id " << current_id; - - const std::optional source_id = source_target.source; - se::DeviceMemoryBase dest_addr = buffer.destination_buffer; - - VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d", device_string, - current_id, source_id.value_or(-1)); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclCollectivePermute)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - // Receive data from the source peer to the destination buffer. - // if (source_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " - "stream=%p)", - device_string, dest_addr.opaque(), element_count, *source_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype, - *source_id, comm, gpu_stream)); - // } else { - // // If there is no source peer, i.e. no sender to this instance, zero out - // // the destination buffer. - // VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", - // device_string); - // stream.ThenMemZero(&dest_addr, dest_addr.size()); - // } - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_recv_thunk.h b/xla/service/gpu/nccl_recv_thunk.h deleted file mode 100644 index d8dbc22f98229..0000000000000 --- a/xla/service/gpu/nccl_recv_thunk.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_RECV_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_RECV_THUNK_H_ - -#include - -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_p2p_thunk_common.h" - -namespace xla { -namespace gpu { - -// Thunk that performs a NCCL-recv. -class NcclRecvThunk : public NcclCollectiveThunk { - public: - static NcclP2PConfig GetNcclP2PConfig(mlir::lmhlo::RecvOp, - int64_t replica_count, - int64_t partition_count); - - static Status CheckImplementable(mlir::lmhlo::RecvOp op, - int64_t replica_count, - int64_t partition_count); - - static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::RecvOp op); - static const char* GetHloOpName() { return "recv"; } - - NcclRecvThunk(ThunkInfo thunk_info, mlir::lmhlo::RecvOp op, - int64_t replica_count, int64_t partition_count, - const Buffer& buffer); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; - AsyncStreamKind GetAsyncStreamKind() const override { - return kAsyncStreamP2P; - } - - private: - const NcclP2PConfig config_; - const Buffer buffer_; -}; - -Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, - absl::string_view device_string, int64_t current_id); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_RECV_THUNK_H_ diff --git a/xla/service/gpu/nccl_send_thunk.cc b/xla/service/gpu/nccl_send_thunk.cc deleted file mode 100644 index 0127f62987174..0000000000000 --- a/xla/service/gpu/nccl_send_thunk.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_send_thunk.h" - -#include -#include -#include -#include - -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/service/collective_ops_utils.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using mlir::lmhlo::SendOp; - -namespace impl { - -NcclP2PConfig GetNcclP2PConfig(SendOp op, int64_t replica_count, - int64_t partition_count) { - return GetNcclP2PConfigForSendRecv(op, replica_count, partition_count); -} - -Status CheckImplementable(SendOp op) { - TF_RETURN_IF_ERROR(NcclCollectiveThunk::CheckImplementable()); - return IsValidOperand(op.getInputs()[0], Thunk::kNcclSend); -} - -} // namespace impl - -NcclSendThunk::NcclSendThunk(ThunkInfo thunk_info, SendOp op, - int64_t replica_count, int64_t partition_count, - const Buffer& buffer) - : NcclCollectiveThunk(Thunk::kNcclSend, thunk_info, /*is_sync=*/false), - config_(GetNcclP2PConfig(op, replica_count, partition_count)), - buffer_(buffer) {} - -/*static*/ NcclP2PConfig NcclSendThunk::GetNcclP2PConfig( - SendOp op, int64_t replica_count, int64_t partition_count) { - return impl::GetNcclP2PConfig(op, replica_count, partition_count); -} - -/*static*/ Status NcclSendThunk::CheckImplementable(mlir::lmhlo::SendOp op, - int64_t replica_count, - int64_t partition_count) { - return AddOpDescription(impl::CheckImplementable(op), op, - replica_count, partition_count); -} - -/*static*/ CollectiveOpGroupMode NcclSendThunk::GetGroupMode(SendOp op) { - return GetGroupModeForSendRecv(op); -} - -Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params, - se::Stream& stream, ncclComm_t comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params, {buffer_}, - config_.config.operand_element_type)); - TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; - - TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id, - params.nccl_params.GetGlobalDeviceId()); - TF_ASSIGN_OR_RETURN( - const DeviceAssignment::LogicalID current_logical_id, - params.nccl_params.device_assn->LogicalIdForDevice(global_device_id)); - // const int64_t current_id = - // config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica - // ? current_logical_id.replica_id - // : current_logical_id.computation_id; - const int64_t current_id = - current_logical_id.replica_id * config_.config.partition_count + - current_logical_id.computation_id; - VLOG(3) << "Performing Send, replica_id: " << current_logical_id.replica_id - << ", partition_count: " << config_.config.partition_count - << ", computation_id: " << current_logical_id.computation_id; - std::string device_string = GetDeviceString(params.nccl_params); - - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); - - return ::xla::gpu::RunSend(source_target, device_buffers[0], stream, comm, - device_string, current_id); -} - -Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, - absl::string_view device_string, int64_t current_id) { -#if XLA_ENABLE_XCCL - // Determine the target IDs for this instance. The target ID is the ID - // to which this instance will copy its data. - - int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing Send from device ordinal: " - << device_ordinal << "current_id " << current_id; - - const std::optional target_id = source_target.target; - se::DeviceMemoryBase src_addr = buffer.source_buffer; - - VLOG(3) << absl::StreamFormat("%s : id = %d, target_id = %d", device_string, - current_id, target_id.value_or(-1)); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclCollectivePermute)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - // Send source buffer to target peer if needed. - // if (target_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - device_string, src_addr.opaque(), element_count, *target_id, - static_cast(comm), gpu_stream); - XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype, - *target_id, comm, gpu_stream)); - // } - return OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_send_thunk.h b/xla/service/gpu/nccl_send_thunk.h deleted file mode 100644 index b57b1ee3279d7..0000000000000 --- a/xla/service/gpu/nccl_send_thunk.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_SEND_THUNK_H_ -#define XLA_SERVICE_GPU_NCCL_SEND_THUNK_H_ - -#include - -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_p2p_thunk_common.h" - -namespace xla { -namespace gpu { - -// Thunk that performs a NCCL-send. -class NcclSendThunk : public NcclCollectiveThunk { - public: - static NcclP2PConfig GetNcclP2PConfig(mlir::lmhlo::SendOp, - int64_t replica_count, - int64_t partition_count); - - static Status CheckImplementable(mlir::lmhlo::SendOp op, - int64_t replica_count, - int64_t partition_count); - static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::SendOp op); - static const char* GetHloOpName() { return "send"; } - - NcclSendThunk(ThunkInfo thunk_info, mlir::lmhlo::SendOp op, - int64_t replica_count, int64_t partition_count, - const Buffer& buffer); - - protected: - const NcclCollectiveConfig& config() const override { return config_.config; } - Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, - ncclComm_t comm) override; - AsyncStreamKind GetAsyncStreamKind() const override { - return kAsyncStreamP2P; - } - - private: - const NcclP2PConfig config_; - const Buffer buffer_; -}; - -Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target, - DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, - absl::string_view device_string, int64_t current_id); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_SEND_THUNK_H_ diff --git a/xla/service/gpu/nccl_utils.cc b/xla/service/gpu/nccl_utils.cc deleted file mode 100644 index a1c0beb0b9ab7..0000000000000 --- a/xla/service/gpu/nccl_utils.cc +++ /dev/null @@ -1,334 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/nccl_utils.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/notification.h" -#include "absl/time/time.h" -#include "xla/debug_options_flags.h" -#include "xla/service/global_device_id.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/rendezvous.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "tsl/platform/env.h" - -namespace xla { -namespace gpu { - -bool IsGlobalNcclConfig() { - static const bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr; - return global_nccl_config; -} - -Status ToStatus(ncclResult_t s, const char* file, int64_t line, - const char* expr) { - if (s == ncclSuccess) { - return OkStatus(); - } - return tsl::errors::Internal(absl::StrFormat( - "%s:%d: NCCL operation %s failed: %s." - " Last NCCL warning(error) log entry (may be unrelated) '%s'.", - file, line, expr, ncclGetErrorString(s), ncclGetLastError(NULL))); -} - -ncclRedOp_t ToNcclReduction(ReductionKind kind) { - switch (kind) { - case ReductionKind::SUM: - return ncclSum; - case ReductionKind::PRODUCT: - return ncclProd; - case ReductionKind::MIN: - return ncclMin; - case ReductionKind::MAX: - return ncclMax; - } -} - -namespace { - -StatusOr ToNcclDataType(PrimitiveType element_type, - Thunk::Kind reduction_op) { - switch (element_type) { - case S8: - case F8E5M2: - case F8E4M3FN: - return ncclInt8; - case PRED: - case U8: - return ncclUint8; - case S32: - return ncclInt32; - case U32: - return ncclUint32; - case S64: - return ncclInt64; - case U64: - return ncclUint64; - case F16: - return ncclFloat16; - case F32: - case C64: - return ncclFloat32; - case F64: - case C128: - return ncclFloat64; - case S16: - case U16: - // For all-reduce and reduce-scatter, we expect 16 bit integer types to be - // promoted to 32-bit. - if (reduction_op == Thunk::kNcclAllReduce || - reduction_op == Thunk::kNcclAllReduceStart || - reduction_op == Thunk::kNcclReduceScatter) { - return tsl::errors::InvalidArgument(absl::StrFormat( - "Unsupported data type: %s", PrimitiveType_Name(element_type))); - } - // For collectives that just move data around, we can use ncclFloat16 for - // 16-bit integer data types. - return ncclFloat16; -#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM - case BF16: - return ncclBfloat16; -#endif - default: - return tsl::errors::InvalidArgument(absl::StrFormat( - "Unsupported data type: %s", PrimitiveType_Name(element_type))); - } -} - -StatusOr ToNcclUniqueId(const std::string& id_str) { - static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES, - "NCCL_UNIQUE_ID_BYTES"); - - TF_RET_CHECK(id_str.size() == NCCL_UNIQUE_ID_BYTES); - ncclUniqueId id; - absl::c_copy(id_str, id.internal); - return id; -} - -StatusOr LocalNcclUniqueIdCallback(const NcclCliqueKey&) { - ncclUniqueId id; - XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&id)); - return std::string(id.internal, NCCL_UNIQUE_ID_BYTES); -} - -struct NcclCliqueState { - ncclUniqueId unique_id; - int64_t run_id = -1; - - // `mu` guards `communicators` and `status` during initialization. - // Once `ready` has been notified, the communicators may be accessed without - // synchronization. - absl::Mutex mu; - absl::Notification ready; - Status status; - absl::flat_hash_map> communicators; -}; - -using NcclClique = Lockable; - -std::shared_ptr> AcquireNcclClique( - RunId run_id, OpId op_id, NcclCliqueKey clique_key, - const NcclUniqueIdCallback& unique_id_callback, - size_t num_local_participants, bool may_skip_rendezvous) { - static auto& cliques = *new ThreadSafeMap; - - VLOG(2) << "AcquireNcclClique Rendezvous key (clique_key:" - << clique_key.ToString() << ", run" << run_id.ToString() << ", op" - << op_id.value() << ")"; - - // RendezvousSingle should only be used to guard nccl communicator - // initialization. Return the clique state when we are done with such - // initialization. - // - // TODO(bixia): enable this unconditionally after fixing a deadlock issue. - if (may_skip_rendezvous) { - // Destruct clique if it hasn't been notified. - NcclClique::Lock clique = cliques[clique_key].Acquire(); - if (clique->ready.HasBeenNotified() && clique->run_id == run_id.ToInt()) { - return std::make_shared>(std::move(clique)); - } - } - - auto rendezvous_key = std::make_tuple(run_id, op_id, std::move(clique_key)); - - int64_t terminate_timeout = xla::GetDebugOptionsFromFlags() - .xla_gpu_nccl_termination_timeout_seconds(); - - return RendezvousSingle>( - rendezvous_key, num_local_participants, - [&]() -> StatusOr { - const NcclCliqueKey& clique_key = std::get<2>(rendezvous_key); - NcclClique::Lock clique = cliques[clique_key].Acquire(); - if (clique->run_id < 0) { - TF_ASSIGN_OR_RETURN(std::string id, unique_id_callback(clique_key)); - TF_ASSIGN_OR_RETURN(clique->unique_id, ToNcclUniqueId(id)); - } - // If multiple executable are running simultaneously while using - // multiple hosts, it is possible that different executables could - // acquire the same clique on different hosts. We protect against this - // by checking that the run ID increases monotonically. - bool is_local = clique_key.devices().size() == num_local_participants; - TF_RET_CHECK(is_local || (run_id.ToInt() >= clique->run_id)); - clique->run_id = run_id.ToInt(); - return clique; - }, - /*warn_stuck_timeout=*/absl::Seconds(10), - (terminate_timeout >= 0) ? absl::Seconds(terminate_timeout) - : absl::InfiniteDuration()); -} - -void CheckNcclAsyncError(NcclComm& lockable_comm) { - ncclComm_t comm = *lockable_comm.Acquire(); - if (comm == nullptr) return; - - Status status = [comm] { - ncclResult_t async_err; - XLA_CUDA_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err)); - if (async_err != ncclSuccess) { - LOG(ERROR) << "Aborting communicator: " << comm - << " due to async NCCL error: " - << ncclGetErrorString(async_err) - << ". Last NCCL warning(error) log entry (may be unrelated): " - << ncclGetLastError(NULL); - XLA_CUDA_RETURN_IF_ERROR(ncclCommAbort(comm)); - } - return XLA_CUDA_STATUS(async_err); - }(); - - if (!status.ok()) LOG(ERROR) << status; -} - -} // namespace - -StatusOr> ToNcclDataTypeAndCountMultiplier( - PrimitiveType element_type, Thunk::Kind reduction_op) { - TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, - ToNcclDataType(element_type, reduction_op)); - bool is_complex = primitive_util::IsComplexType(element_type); - return std::make_pair(dtype, is_complex ? 2 : 1); -} - -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices) { - if (local_devices == nullptr) return participants.size(); - - return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) { - return absl::c_linear_search(*local_devices, device_id); - }); -} - -StatusOr GetNcclUniqueIdCallback( - const NcclUniqueIdCallback* unique_id_callback, bool is_local) { - if (unique_id_callback != nullptr) return unique_id_callback; - - TF_RET_CHECK(is_local || IsGlobalNcclConfig()) - << "If non-local devices are taking part of a collective API on " - "GPU, the nccl_unique_id_callback must be provided by the client."; - - static auto* local_callback = - new NcclUniqueIdCallback(LocalNcclUniqueIdCallback); - return local_callback; -} - -StatusOr AcquireNcclComm( - RunId run_id, OpId op_id, std::vector participants, - size_t num_local_participants, - const NcclUniqueIdCallback& unique_id_callback, int rank, int64_t stream_id, - bool enable_clique_optimization) { - // Ensure that this group of threads have exclusive access to the clique to - // prevent threads from different groups locking communicators in the clique. - // The enable_clique_optimization value is only used for asynchronous - // collective stream currenly. For synchronous collectives, we should always - // enable the optimization. For P2P stream, we currently have to always enable - // the optimization, because we initially implement this optimization to - // workaround an NCCL bug related to P2P operations. - NcclCliqueKey clique_key(std::move(participants), stream_id); - std::shared_ptr> clique = AcquireNcclClique( - run_id, op_id, clique_key, unique_id_callback, num_local_participants, - enable_clique_optimization || - stream_id != GetStreamId(/*is_async=*/true, kAsyncStreamCollective)); - - if (!clique->ok()) return clique->status(); - - struct AllCommunicators { - absl::Mutex mu; - std::vector communicators ABSL_GUARDED_BY(mu); - }; - static auto& all_communicators = *new AllCommunicators; - - // Launch a thread that periodically checks all NCCL communicators for - // asynchronous errors. If an asynchronous error is observed, the communicator - // is aborted and an error message logged. - static auto check_async_error_thread = tsl::Env::Default()->StartThread( - tsl::ThreadOptions(), "nccl_async_error_thread", [&] { - while (true) { - absl::SleepFor(absl::Seconds(30)); - absl::MutexLock lock(&all_communicators.mu); - for (NcclComm* comm : all_communicators.communicators) { - CheckNcclAsyncError(*comm); - } - } - }); - (void)check_async_error_thread; // Silence unused variable warning. - - NcclCliqueState& state = ***clique; - if (!state.ready.HasBeenNotified()) { - int nranks = clique_key.devices().size(); - const ncclUniqueId& id = state.unique_id; - - ncclComm_t comm = nullptr; - Status status = XLA_CUDA_STATUS(ncclCommInitRank(&comm, nranks, id, rank)); - - size_t num_initialized = [&] { - absl::MutexLock lock(&state.mu); - state.status.Update(status); - state.communicators[rank] = std::make_unique(comm); - return state.communicators.size(); - }(); - - // Wait for all communicators to initialize before allowing any progress. - // Otherwise we may get deadlocks, because ncclCommInitRank may allocate, - // which may block on the completion of device activity on a peer device, - // which may depend on the completion of this collective if we do not have a - // barrier to prevent it. - if (num_initialized == num_local_participants) { - state.ready.Notify(); - } else { - TF_RETURN_IF_ERROR(status); - state.ready.WaitForNotification(); - } - - absl::MutexLock lock(&all_communicators.mu); - all_communicators.communicators.push_back(state.communicators[rank].get()); - } - - TF_RETURN_IF_ERROR(state.status); - return state.communicators[rank]->Acquire(); -} -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/nccl_utils.h b/xla/service/gpu/nccl_utils.h deleted file mode 100644 index 1d8c9571a81d8..0000000000000 --- a/xla/service/gpu/nccl_utils.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NCCL_UTILS_H_ -#define XLA_SERVICE_GPU_NCCL_UTILS_H_ - -#if TENSORFLOW_USE_ROCM -#define __HIP_DISABLE_CPP_FUNCTIONS__ -#endif - -#include -#include -#include - -#include "absl/synchronization/mutex.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" - -// Common place for all collective thunks to include nccl/rccl headers. -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#if (TF_ROCM_VERSION >= 50200) -#include "rocm/include/rccl/rccl.h" -#else -#include "rocm/include/rccl.h" -#endif -#else -#include "third_party/nccl/nccl.h" -#endif - -namespace xla { -namespace gpu { - -ncclRedOp_t ToNcclReduction(ReductionKind kind); -StatusOr> ToNcclDataTypeAndCountMultiplier( - PrimitiveType element_type, Thunk::Kind reduction_op); - -bool IsGlobalNcclConfig(); - -Status ToStatus(ncclResult_t s, const char* file, int64_t line, - const char* expr); - -// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both -// NCCL and CUDA errors.) -// -// It's tempting to say these macros belong in an XLA header somewhere, but in -// practice we don't do much direct-to-CUDA-API stuff outside of this file. -#define XLA_CUDA_STATUS(expr) \ - xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr) - -#define XLA_CUDA_RETURN_IF_ERROR(expr) \ - do { \ - Status s = XLA_CUDA_STATUS(expr); \ - if (!s.ok()) { \ - return s; \ - } \ - } while (0) - -#define XLA_CUDA_WARN_IF_ERROR(expr) \ - do { \ - Status s = XLA_CUDA_STATUS(expr); \ - if (!s.ok()) { \ - LOG(ERROR) << s.ToString(); \ - } \ - } while (0) - -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices); // may be null - -StatusOr GetNcclUniqueIdCallback( - const NcclUniqueIdCallback* unique_id_callback, // may be null - bool is_local); - -// Represents a type that requires mutually exclusive access. -template -class Lockable { - public: - // RAII type that will release the exclusive lock when it is destroyed. - using Lock = std::unique_ptr>; - - Lockable() = default; - explicit Lockable(T value) : value_(std::move(value)) {} - Lockable(const Lockable&) = delete; - Lockable(Lockable&&) = delete; - Lockable& operator=(const Lockable&) = delete; - Lockable& operator=(Lockable&&) = delete; - - Lock Acquire() { - absl::MutexLock lock(&mutex_); - mutex_.Await(absl::Condition(&is_unlocked_)); - is_unlocked_ = false; - - return {&value_, [this](T*) { - absl::MutexLock lock(&mutex_); - CHECK(!is_unlocked_); - is_unlocked_ = true; - }}; - } - - private: - T value_; - absl::Mutex mutex_; - bool is_unlocked_ ABSL_GUARDED_BY(mutex_) = true; -}; - -TSL_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t); - -struct NcclComm : public Lockable { - explicit NcclComm(ncclComm_t comm) : Lockable(comm) {} -}; - -StatusOr AcquireNcclComm( - RunId run_id, OpId op_id, std::vector participants, - size_t num_local_participants, - const NcclUniqueIdCallback& unique_id_callback, int rank, int64_t stream_id, - bool enable_clique_optimization); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NCCL_UTILS_H_ diff --git a/xla/service/gpu/non_atomically_upgradeable_rw_lock.h b/xla/service/gpu/non_atomically_upgradeable_rw_lock.h deleted file mode 100644 index ea09016bd6f30..0000000000000 --- a/xla/service/gpu/non_atomically_upgradeable_rw_lock.h +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ -#define XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ - -#include -#include - -#include "absl/synchronization/mutex.h" - -namespace xla { -namespace gpu { - -// Augments absl::ReaderMutexLock with a poor man's upgrade/downgrade pair using -// RAII. Instead of a true upgrade (or downgrade), we simply drop the read -// (write) lock and then reacquire it as a write (read) lock. -class ABSL_SCOPED_LOCKABLE NonAtomicallyUpgradeableRWLock { - public: - explicit NonAtomicallyUpgradeableRWLock(absl::Mutex* mu) - ABSL_SHARED_LOCK_FUNCTION(mu) - : mu_(mu), is_reader_(true) { - mu_->ReaderLock(); - } - - NonAtomicallyUpgradeableRWLock(const NonAtomicallyUpgradeableRWLock&) = - delete; - NonAtomicallyUpgradeableRWLock(NonAtomicallyUpgradeableRWLock&&) = delete; - NonAtomicallyUpgradeableRWLock& operator=( - const NonAtomicallyUpgradeableRWLock&) = delete; - NonAtomicallyUpgradeableRWLock& operator=(NonAtomicallyUpgradeableRWLock&&) = - delete; - - ~NonAtomicallyUpgradeableRWLock() ABSL_UNLOCK_FUNCTION() { - if (is_reader_) { - mu_->ReaderUnlock(); - } else { - mu_->WriterUnlock(); - } - } - - // Upgrade and downgrade the reader lock via RAII. - class ABSL_SCOPED_LOCKABLE WriterLock { - public: - explicit WriterLock(NonAtomicallyUpgradeableRWLock* parent) - ABSL_EXCLUSIVE_LOCK_FUNCTION(parent->mu_) - : parent_(parent) { - assert(parent_->is_reader_); - parent_->mu_->ReaderUnlock(); - parent_->mu_->WriterLock(); - parent_->is_reader_ = false; - } - - WriterLock(const WriterLock&) = delete; - WriterLock(WriterLock&&) = delete; - WriterLock& operator=(const WriterLock&) = delete; - WriterLock& operator=(WriterLock&&) = delete; - - ~WriterLock() ABSL_UNLOCK_FUNCTION() { - parent_->mu_->WriterUnlock(); - parent_->mu_->ReaderLock(); - parent_->is_reader_ = true; - } - - private: - NonAtomicallyUpgradeableRWLock* parent_; - }; - - // Update the reader lock to a writer lock. The function is invalid if the - // lock is already upgraded. - WriterLock UpgradeToWriterMutexLock() ABSL_NO_THREAD_SAFETY_ANALYSIS { - return WriterLock(this); - } - - private: - absl::Mutex* const mu_; - bool is_reader_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NON_ATOMICALLY_UPGRADEABLE_RW_LOCK_H_ diff --git a/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc b/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc deleted file mode 100644 index afb33006452e8..0000000000000 --- a/xla/service/gpu/non_atomically_upgradeable_rw_lock_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" - -#include -#include "tsl/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -TEST(NonAtomicallyUpgradeableRWLock, UpgradeReaderMutexLock) { - absl::Mutex mu; - { - NonAtomicallyUpgradeableRWLock reader_lock(&mu); - mu.AssertReaderHeld(); - - { - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - reader_lock.UpgradeToWriterMutexLock(); - mu.AssertHeld(); - } - - // The lock downgrades after the WriterLock goes out of scope. - mu.AssertReaderHeld(); - } - mu.AssertNotHeld(); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/norm_thunk.cc b/xla/service/gpu/norm_thunk.cc deleted file mode 100644 index 915a2e99315a6..0000000000000 --- a/xla/service/gpu/norm_thunk.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/norm_thunk.h" - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -NormThunk::NormThunk(ThunkInfo thunk_info, GpuNormConfig config, - BufferAllocation::Slice input_slice, - BufferAllocation::Slice scale_slice, - BufferAllocation::Slice bias_slice, - BufferAllocation::Slice output_slice, - std::optional expectation_slice, - std::optional norm_factor_slice, - BufferAllocation::Slice scratch_slice) - : Thunk(Kind::kNorm, thunk_info), - input_buffer_(input_slice), - scale_buffer_(scale_slice), - bias_buffer_(bias_slice), - output_buffer_(output_slice), - expectation_buffer_(expectation_slice), - norm_factor_buffer_(norm_factor_slice), - scratch_buffer_(scratch_slice), - config_(config) {} - -NormRunner& NormThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_.insert({stream, std::make_unique(config_)}) - .first; - } - return *it->second; -} - -Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - - se::DeviceMemoryBase input_se_buffer = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase scale_se_buffer = - buffer_allocations.GetDeviceAddress(scale_buffer_); - se::DeviceMemoryBase bias_se_buffer = - buffer_allocations.GetDeviceAddress(bias_buffer_); - se::DeviceMemoryBase output_se_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - - std::optional expectation_se_buffer, - norm_factor_se_buffer; - if (expectation_buffer_) { - expectation_se_buffer = - buffer_allocations.GetDeviceAddress(expectation_buffer_.value()); - } - if (norm_factor_buffer_) { - norm_factor_se_buffer = - buffer_allocations.GetDeviceAddress(norm_factor_buffer_.value()); - } - - se::DeviceMemoryBase scratch = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - RunNormOptions opts; - opts.norm_runner = &GetOrCreateRunner(params.stream); - - TF_RETURN_IF_ERROR(RunGpuNorm(config_, input_se_buffer, scale_se_buffer, - bias_se_buffer, output_se_buffer, - expectation_se_buffer, norm_factor_se_buffer, - scratch, params.stream, opts)); - - if (!params.stream->ok()) { - return InternalError("NormThunk::ExecuteOnStream failed."); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/norm_thunk.h b/xla/service/gpu/norm_thunk.h deleted file mode 100644 index 2c9ad5598ef61..0000000000000 --- a/xla/service/gpu/norm_thunk.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_NORM_THUNK_H_ -#define XLA_SERVICE_GPU_NORM_THUNK_H_ - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_executable.h" -#include "xla/service/gpu/gpu_norm_runner.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" - -namespace xla { -namespace gpu { - -class NormThunk : public Thunk { - public: - NormThunk(ThunkInfo thunk_info, GpuNormConfig config, - BufferAllocation::Slice input, BufferAllocation::Slice scale, - BufferAllocation::Slice bias, BufferAllocation::Slice output, - std::optional expectation, - std::optional norm_factor, - BufferAllocation::Slice scratch); - - NormThunk(const NormThunk&) = delete; - NormThunk& operator=(const NormThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - BufferAllocation::Slice input_buffer_; - BufferAllocation::Slice scale_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice output_buffer_; - std::optional expectation_buffer_; - std::optional norm_factor_buffer_; - BufferAllocation::Slice scratch_buffer_; - NormRunner& GetOrCreateRunner(const stream_executor::Stream*); - - GpuNormConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_NORM_THUNK_H_ diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index 2c1ca2b440034..8620bbbb8c0db 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,30 @@ limitations under the License. #include "xla/service/gpu/nvptx_compiler.h" #include +#include #include +#include #include -#include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/base/call_once.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" @@ -40,32 +51,34 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/conv_algorithm_picker.h" -#include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/cudnn_fused_mha_rewriter.h" #include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h" +#include "xla/service/gpu/cudnn_fusion_compiler.h" #include "xla/service/gpu/cudnn_norm_rewriter.h" #include "xla/service/gpu/cudnn_pad_for_convolutions.h" #include "xla/service/gpu/cudnn_simplify_padding.h" #include "xla/service/gpu/cudnn_vectorize_convolutions.h" #include "xla/service/gpu/cusolver_rewriter.h" #include "xla/service/gpu/gemm_algorithm_picker.h" +#include "xla/service/gpu/gemm_fusion_autotuner.h" #include "xla/service/gpu/gpu_asm_opts_util.h" +#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_conv_padding_legalization.h" #include "xla/service/gpu/gpu_conv_rewriter.h" #include "xla/service/gpu/gpu_sort_rewriter.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/move_copy_to_users.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/triangular_solve_rewriter.h" -#include "xla/service/gpu/triton_autotuner.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cse.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_verifier.h" @@ -74,20 +87,29 @@ limitations under the License. #include "xla/service/reshape_decomposer.h" #include "xla/service/reshape_mover.h" #include "xla/service/tuple_simplifier.h" +#include "xla/status.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/ptx_compiler.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/asm_compiler.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/util/env_var.h" namespace xla { namespace gpu { @@ -125,7 +147,7 @@ class ConvBfloat16Support : public FloatSupport { } // namespace -Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( +absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) { @@ -138,7 +160,7 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Convert upsupported bf16 convolutions to f32. + // Convert unsupported bf16 convolutions to f32. ConvBfloat16Support conv_bf16_support(dnn_version, cuda_compute_capability); pipeline.AddPass(&conv_bf16_support); @@ -155,7 +177,8 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); - AlgebraicSimplifierOptions algsimp_options; + AlgebraicSimplifierOptions algsimp_options = + GetAlgebraicSimplifierOptions(hlo_module->config()); algsimp_options.set_enable_conv_operand_swap(false); algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(algsimp_options); @@ -195,15 +218,13 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - return OkStatus(); + return absl::OkStatus(); } -Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( +absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool) { - HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1"); - // This needs to run before GemmRewriter, which is part of // OptimizeHloPostLayoutAssignment(). auto cuda_compute_capability = std::get( @@ -212,10 +233,10 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_fmha()) { HloPassPipeline mha_fusion_pipeline( "nvptx cudnn multi-headed attention fusion"); - const DebugOptions& debug_options = hlo_module->config().debug_options(); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions alg_sim_options; + AlgebraicSimplifierOptions alg_sim_options = + GetAlgebraicSimplifierOptions(hlo_module->config()); alg_sim_options.set_supports_non_canonical_dots(false); alg_sim_options.set_is_layout_sensitive(true); alg_sim_options.set_enable_conv_operand_swap(false); @@ -224,11 +245,6 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( !hlo_module->config().debug_options().xla_gpu_enable_fast_min_max()); alg_sim_options.set_enable_unconditional_reduce_of_concat_replacement( false); - if (debug_options.xla_gpu_normalize_layouts()) { - mha_fusion_pipeline.AddPass(); - mha_fusion_pipeline.AddPass>(); - mha_fusion_pipeline.AddPass(); - } mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); mha_fusion_pipeline.AddPass>( @@ -249,8 +265,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( TF_RETURN_IF_ERROR(mha_fusion_pipeline.Run(hlo_module).status()); } - // Rewrite normalization patterns into cuDNN Custom Calls. - pre_pipeline.AddPass(cuda_compute_capability); + HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1"); + if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_layer_norm()) { + // Rewrite normalization patterns into cuDNN Custom Calls. + pre_pipeline.AddPass(cuda_compute_capability); + } pre_pipeline.AddPass(); @@ -278,7 +297,7 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); - return OkStatus(); + return absl::OkStatus(); } // Linearize collective schedule under if online autotuning of convolutions is @@ -299,30 +318,38 @@ bool NVPTXCompiler::RequiresCollectiveScheduleLinearizer( return false; } -Status NVPTXCompiler::AddConvAndGemmAutotuningPasses( +absl::Status NVPTXCompiler::AddConvAndGemmAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) { if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) { pipeline->AddPass(autotune_config); } pipeline->AddPass(autotune_config); - return OkStatus(); + return absl::OkStatus(); } -Status NVPTXCompiler::AddTritonGemmAutotuningPasses( +absl::Status NVPTXCompiler::AddGemmFusionAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) { - pipeline->AddPass(autotune_config, thread_pool); - return OkStatus(); + pipeline->AddPass(autotune_config, thread_pool); + return absl::OkStatus(); } -Status NVPTXCompiler::AddCustomKernelReplacementPasses( +absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { if (debug_options.xla_gpu_enable_cub_radix_sort()) { pipeline->AddPass(); } - return OkStatus(); + return absl::OkStatus(); } + +absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass( + HloModule* module, se::StreamExecutor* stream_exec, + Thunk::BinaryMap* dnn_compiled_graphs) { + CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs); + return cudnn_compiler.Run(module).status(); +} + namespace { // Try to load ptx from files defined in the FLAGS. If successful, return true. bool MaybeLoadPtxFromFile(const HloModuleConfig module_config, @@ -452,10 +479,13 @@ HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() const { return &CanShareBufferHint; } -StatusOr NVPTXCompiler::CompileTargetBinary( - const HloModuleConfig& module_config, llvm::Module* llvm_module, - se::GpuComputeCapability gpu_version, bool relocatable, - const HloModule* debug_module, const CompileOptions& options) { +absl::StatusOr +NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, + llvm::Module* llvm_module, + se::GpuComputeCapability gpu_version, + bool relocatable, + const HloModule* debug_module, + const CompileOptions& options) { std::unique_ptr loaded_module = MaybeLoadLLVMFromFile(debug_module, llvm_module); llvm::Module* selected_module = nullptr; @@ -486,19 +516,118 @@ StatusOr NVPTXCompiler::CompileTargetBinary( RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs); } - StatusOr> maybe_cubin = CompileGpuAsmOrGetCachedResult( - ptx, std::get(gpu_version), module_config, - (debug_module != nullptr ? debug_module->name() : "(unknown)"), - relocatable, options); + absl::StatusOr> maybe_cubin = + CompileGpuAsmOrGetCachedResult( + ptx, std::get(gpu_version), module_config, + (debug_module != nullptr ? debug_module->name() : "(unknown)"), + relocatable, options); - if (maybe_cubin.status().code() == absl::StatusCode::kCancelled || - maybe_cubin.status().code() == absl::StatusCode::kResourceExhausted) { + if (!maybe_cubin.ok()) { return maybe_cubin.status(); } return BackendCompileResult{std::move(ptx), std::move(maybe_cubin.value())}; } -StatusOr> NVPTXCompiler::CompileGpuAsmOrGetCachedResult( +static absl::StatusOr> AssembleOptionsAndCompile( + const std::string& ptx, se::CudaComputeCapability cc, + const HloModuleConfig& hlo_module_config, + GpuCompiler::CompileOptions options, bool relocatable) { + if (ptx.empty()) { + return std::vector(); + } + + se::GpuAsmOpts ptxas_config = + PtxOptsFromDebugOptions(hlo_module_config.debug_options()); + if (relocatable) { + ptxas_config.extra_flags.push_back("-c"); + } + uint64_t start_usecs = tsl::Env::Default()->NowMicros(); + + bool cancel_if_reg_spill = + hlo_module_config.debug_options() + .xla_gpu_filter_kernels_spilling_registers_on_autotuning() && + options.is_autotuning_compilation; + + absl::StatusOr> maybe_cubin = [&] { + if (hlo_module_config.debug_options().xla_gpu_enable_libnvptxcompiler() && + se::IsLibNvPtxCompilerSupported()) { + return se::CompileGpuAsmUsingLibNvPtxCompiler( + cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill); + } + + return se::CompileGpuAsmUsingPtxAs(cc.major, cc.minor, ptx.c_str(), + ptxas_config, cancel_if_reg_spill); + }(); + + if (maybe_cubin.ok()) { + uint64_t end_usecs = tsl::Env::Default()->NowMicros(); + // This won't record values for calls that error out (because if they + // error out we have no way of telling how far through the process we + // got). + RecordPtxToCubinDuration(end_usecs - start_usecs); + + VLOG(1) << "Compiled PTX size: " << ptx.size() + << "bytes. CUBIN size: " << maybe_cubin.value().size() << "bytes."; + + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kNotFound) { + if (!hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) { + LOG(WARNING) << nvptx::CantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Custom ptxas " + "location can be specified using $PATH.", + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); + LOG(FATAL) << "Can't find ptxas binary. You can pass the flag " + "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found " + "to use the GPU driver for compiling ptx instead. However " + "this option is discouraged and can lead to increased " + "memory consumptions and other subtle runtime issues."; + } + + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + LOG_FIRST_N(WARNING, 1) << nvptx::CantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to " + "the GPU driver for PTX -> sass compilation. This is OK so " + "long as you don't see a warning below about an out-of-date " + "driver version. Custom ptxas location can be specified " + "using $PATH.", + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); + + // We're going to use the driver to JIT our PTX->SASS, so warn if + // the JIT in the driver has known bugs. + WarnIfBadDriverJITVersion(); + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kCancelled) { + // Register spilling has occurred during autotuning. + CHECK(options.is_autotuning_compilation) << maybe_cubin.status(); + return maybe_cubin; + } + + if (maybe_cubin.status().code() == absl::StatusCode::kResourceExhausted) { + // Exhausting the register limit during autotuning is not a fatal + // error, we should just skip the problematic tiling. + CHECK(options.is_autotuning_compilation) << maybe_cubin.status(); + return maybe_cubin; + } + + if (maybe_cubin.status().code() != absl::StatusCode::kUnimplemented) { + return AppendStatus( + maybe_cubin.status(), + "If the error message indicates that a file could not be written, " + "please verify that sufficient filesystem space is provided."); + } + + return maybe_cubin; +} + +absl::StatusOr> +NVPTXCompiler::CompileGpuAsmOrGetCachedResult( const std::string& ptx, se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config, absl::string_view module_name, bool relocatable, const CompileOptions& options) { @@ -510,129 +639,48 @@ StatusOr> NVPTXCompiler::CompileGpuAsmOrGetCachedResult( !options.is_autotuning_compilation); tsl::profiler::TraceMe activity("PTX->CUBIN", tsl::profiler::TraceMeLevel::kInfo); - bool inserted; - decltype(compilation_cache_.begin()) iter; - // Pointers into compilation_cache_ where the ptx and (optional) cubin are - // stored. - const std::string* cache_ptx = nullptr; CompilationCacheValue* cache_value = nullptr; - - { + bool inserted = [&] { + auto flags = CompilationCacheFlags{ + hlo_module_config.debug_options() + .xla_gpu_filter_kernels_spilling_registers_on_autotuning()}; absl::MutexLock lock(&mutex_); - std::tie(iter, inserted) = compilation_cache_.emplace( + auto [iter, inserted] = compilation_cache_.emplace( std::piecewise_construct, - std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable), + std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable, flags), std::forward_as_tuple()); - cache_ptx = &iter->first.ptx; + // Do not move this assignment outside of the critical section. There is + // a TOCTOU if `compilation_cache_` is rehashed before the iterator is used. cache_value = &iter->second; - } + return inserted; + }(); // Compile the ptx if it wasn't in the cache before we called this function. // Other threads asking for the same compilation key will block on // cache_value->mutex_ until compilation is done. - { - absl::MutexLock lock(&cache_value->mutex); - if (inserted) { - CHECK(!cache_value->compilation_done); - if (!ptx.empty()) { - se::GpuAsmOpts ptxas_config = - PtxOptsFromDebugOptions(hlo_module_config.debug_options()); - if (relocatable) { - ptxas_config.extra_flags.push_back("-c"); - } - uint64_t start_usecs = tsl::Env::Default()->NowMicros(); - - bool cancel_if_reg_spill = - hlo_module_config.debug_options() - .xla_gpu_filter_kernels_spilling_registers_on_autotuning() && - options.is_autotuning_compilation; - StatusOr> maybe_cubin = - se::CompileGpuAsm(cc.major, cc.minor, cache_ptx->c_str(), - ptxas_config, cancel_if_reg_spill); - - if (maybe_cubin.ok()) { - uint64_t end_usecs = tsl::Env::Default()->NowMicros(); - // This won't record values for calls that error out (because if they - // error out we have no way of telling how far through the process we - // got). - RecordPtxToCubinDuration(end_usecs - start_usecs); - cache_value->cubin_data = std::move(maybe_cubin).value(); - VLOG(1) << "Compiled PTX size:" << ptx.size() - << " CUBIN size: " << cache_value->cubin_data.size(); - } else { - if (maybe_cubin.status().code() == absl::StatusCode::kNotFound) { - if (!hlo_module_config.debug_options() - .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) { - LOG(WARNING) << nvptx::CantFindCudaMessage( - "Can't find ptxas binary in ${CUDA_DIR}/bin. Custom ptxas " - "location can be specified using $PATH.", - hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); - LOG(FATAL) - << "Can't find ptxas binary. You can pass the flag " - "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found " - "to use the GPU driver for compiling ptx instead. However " - "this option is discouraged and can lead to increased " - "memory consumptions and other subtle runtime issues."; - } - // Missing ptxas is expected in some environments where CUDA SDK - // binaries are not available. We don't want to spam logs with - // identical warnings in this case. - - LOG_FIRST_N(WARNING, 1) << nvptx::CantFindCudaMessage( - "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to " - "the GPU driver for PTX -> sass compilation. This is OK so " - "long as you don't see a warning below about an out-of-date " - "driver version. Custom ptxas location can be specified " - "using $PATH.", - hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); - } else if (maybe_cubin.status().code() == - absl::StatusCode::kCancelled) { - // Register spilling has occurred during autotuning, this config - // should not be tried further. - CHECK(options.is_autotuning_compilation); - cache_value->compilation_done = true; - cache_value->compilation_done_cv.SignalAll(); - return maybe_cubin; - } else if (maybe_cubin.status().code() == - absl::StatusCode::kResourceExhausted) { - // Exhausting the register limit during autotuning is not a fatal - // error, we should just skip the problematic tiling. - CHECK(options.is_autotuning_compilation); - cache_value->compilation_done = true; - cache_value->compilation_done_cv.SignalAll(); - return maybe_cubin; - } else if (maybe_cubin.status().code() != - absl::StatusCode::kUnimplemented) { - // If unimplemented is returned, we fallback to the driver. - LOG(FATAL) << "ptxas returned an error during compilation of ptx " - "to sass: '" - << maybe_cubin.status() << "' " - << "If the error message indicates that a file could " - "not be written, please verify that sufficient " - "filesystem space is provided."; - } - - // We're going to use the driver to JIT our PTX->SASS, so warn if - // the JIT in the driver has known bugs. - WarnIfBadDriverJITVersion(); - } - } + absl::MutexLock lock(&cache_value->mutex); + if (inserted) { + CHECK(!cache_value->compilation_done); + absl::Cleanup mark_compilation_as_done = [cache_value] { + // Note that we will set this to true also in the error case, so that we + // don't retry this compilation. cache_value->compilation_done = true; cache_value->compilation_done_cv.SignalAll(); - } else { - while (!cache_value->compilation_done) { - cache_value->compilation_done_cv.Wait(&cache_value->mutex); - } - } + }; + + cache_value->maybe_cubin = AssembleOptionsAndCompile( + ptx, cc, hlo_module_config, options, relocatable); + return cache_value->maybe_cubin; + } + + while (!cache_value->compilation_done) { + cache_value->compilation_done_cv.Wait(&cache_value->mutex); } - CHECK(cache_value != nullptr); - CHECK(cache_value->compilation_done); - return cache_value->cubin_data; + return cache_value->maybe_cubin; } -static std::optional> GetNvLinkVersion( - const std::string& preferred_cuda_dir) { +static bool IsNvlinkEnabled() { const bool use_nvlink_by_default = #ifdef TF_DISABLE_NVLINK_BY_DEFAULT false; @@ -643,23 +691,53 @@ static std::optional> GetNvLinkVersion( TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_NVLINK_FOR_PARALLEL_COMPILATION", /*default_val=*/ use_nvlink_by_default, &use_nvlink)); + return use_nvlink; +} + +absl::StatusOr ChooseLinkingMethodImpl( + const DebugOptions& debug_options, const std::string& preferred_cuda_dir) { + using LinkingMethod = NVPTXCompiler::LinkingMethod; + TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, + se::GetAsmCompilerVersion(preferred_cuda_dir)); - if (!use_nvlink) { - return std::nullopt; + auto nvlink_version = stream_executor::GetNvLinkVersion(preferred_cuda_dir); + if (IsNvlinkEnabled() && nvlink_version.ok() && + nvlink_version.value() >= ptxas_version_tuple) { + return LinkingMethod::kNvLink; } - // Make sure nvlink exists and is executable. - const std::string bin_path = - se::FindCudaExecutable("nvlink", preferred_cuda_dir); - auto version = se::GetToolVersion(bin_path); - if (!version.ok()) { - return std::nullopt; + int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + + std::get<1>(ptxas_version_tuple) * 10; + TF_ASSIGN_OR_RETURN(int driver_version, + se::gpu::GpuDriver::GetDriverVersion()); + + if (driver_version >= ptxas_version) { + return LinkingMethod::kDriver; } - return *version; + + LOG_FIRST_N(WARNING, 1) + << "The NVIDIA driver's CUDA version is " + << absl::StrFormat("%d.%d", driver_version / 1000, + (driver_version % 1000) / 10) + << " which is older than the ptxas CUDA version " + << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), + std::get<1>(ptxas_version_tuple), + std::get<2>(ptxas_version_tuple)) + << ". Because the driver is older than the ptxas version, XLA is " + "disabling parallel compilation, which may slow down " + "compilation. " + "You should update your NVIDIA driver or use the " + "NVIDIA-provided " + "CUDA forward compatibility packages."; + + return LinkingMethod::kNone; } -StatusOr NVPTXCompiler::ChooseLinkingMethod( - const std::string& preferred_cuda_dir) { +absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( + const DebugOptions& debug_options) { + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; + { absl::MutexLock lock(&mutex_); auto it = linking_methods_.find(preferred_cuda_dir); @@ -668,47 +746,11 @@ StatusOr NVPTXCompiler::ChooseLinkingMethod( } } - LinkingMethod linking_method = LinkingMethod::kNone; - TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, - se::GetAsmCompilerVersion(preferred_cuda_dir)); - - // ptxas versions prior to 11.8 are not supported anymore. We check this here, - // since we are fetching the ptxas version anyway. Catching the error - // elsewhere might introduce unnecessary overhead. - if (ptxas_version_tuple < std::array{11, 8, 0}) { - return Status(absl::StatusCode::kInternal, - "XLA requires ptxas version 11.8 or higher"); - } - - static const std::optional> nvlink_version = - GetNvLinkVersion(preferred_cuda_dir); - if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { - linking_method = LinkingMethod::kNvLink; - } else { - int ptxas_version = std::get<0>(ptxas_version_tuple) * 1000 + - std::get<1>(ptxas_version_tuple) * 10; - int driver_version; - if (!se::gpu::GpuDriver::GetDriverVersion(&driver_version)) { - return FailedPrecondition("Unable to get CUDA driver version"); - } + // This wrapper only handles caching. The actual choice happens in this call: + TF_ASSIGN_OR_RETURN( + LinkingMethod linking_method, + ChooseLinkingMethodImpl(debug_options, preferred_cuda_dir)); - if (driver_version >= ptxas_version) { - linking_method = LinkingMethod::kDriver; - } else { - LOG_FIRST_N(WARNING, 1) - << "The NVIDIA driver's CUDA version is " - << absl::StrFormat("%d.%d", driver_version / 1000, - (driver_version % 1000) / 10) - << " which is older than the ptxas CUDA version " - << absl::StrFormat("(%d.%d.%d)", std::get<0>(ptxas_version_tuple), - std::get<1>(ptxas_version_tuple), - std::get<2>(ptxas_version_tuple)) - << ". Because the driver is older than the ptxas version, XLA is " - "disabling parallel compilation, which may slow down compilation. " - "You should update your NVIDIA driver or use the NVIDIA-provided " - "CUDA forward compatibility packages."; - } - } { absl::MutexLock lock(&mutex_); linking_methods_[preferred_cuda_dir] = linking_method; @@ -716,18 +758,16 @@ StatusOr NVPTXCompiler::ChooseLinkingMethod( return linking_method; } -StatusOr NVPTXCompiler::CanUseLinkModules( +absl::StatusOr NVPTXCompiler::CanUseLinkModules( const HloModuleConfig& hlo_module_config) { // TODO(phawkins): rather than comparing version numbers, it might be more // robust if we simply tried to link something the first time we compile. - auto ptxas_config = - PtxOptsFromDebugOptions(hlo_module_config.debug_options()); TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, - ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + ChooseLinkingMethod(hlo_module_config.debug_options())); return linking_method != LinkingMethod::kNone; } -StatusOr> NVPTXCompiler::LinkModules( +absl::StatusOr> NVPTXCompiler::LinkModules( se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) { auto ptxas_config = PtxOptsFromDebugOptions(debug_options); @@ -737,11 +777,10 @@ StatusOr> NVPTXCompiler::LinkModules( for (std::vector& module : modules) { images.push_back({"", std::move(module)}); } - auto context = static_cast( - stream_exec->platform_specific_handle().context); + auto context = se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context(); TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, - ChooseLinkingMethod(ptxas_config.preferred_cuda_dir)); + ChooseLinkingMethod(debug_options)); if (linking_method == LinkingMethod::kNvLink) { return LinkUsingNvlink(debug_options.xla_gpu_cuda_data_dir(), context, images); diff --git a/xla/service/gpu/nvptx_compiler.h b/xla/service/gpu/nvptx_compiler.h index 2de9551cb0aec..3d7a770282b13 100644 --- a/xla/service/gpu/nvptx_compiler.h +++ b/xla/service/gpu/nvptx_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,30 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #define XLA_SERVICE_GPU_NVPTX_COMPILER_H_ +#include #include #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "llvm/IR/Module.h" #include "xla/autotune_results.pb.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/statusor.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_pass_pipeline.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/xla.pb.h" #include "tsl/platform/threadpool.h" @@ -39,12 +53,12 @@ class NVPTXCompiler : public GpuCompiler { public: NVPTXCompiler(); - Status OptimizeHloConvolutionCanonicalization( + absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, se::DeviceMemoryAllocator* device_allocator) override; - Status OptimizeHloPostLayoutAssignment( + absl::Status OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, tsl::thread::ThreadPool* thread_pool) override; @@ -52,55 +66,76 @@ class NVPTXCompiler : public GpuCompiler { bool RequiresCollectiveScheduleLinearizer( const HloModule* module, se::StreamExecutor* stream_exec) override; - Status AddConvAndGemmAutotuningPasses( + absl::Status AddConvAndGemmAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) override; - Status AddTritonGemmAutotuningPasses( + absl::Status AddGemmFusionAutotuningPasses( HloPassPipeline* pipeline, HloModule* hlo_module, AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) override; - Status AddCustomKernelReplacementPasses( + absl::Status AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) override; + absl::Status RunCudnnFusionCompilerPass( + HloModule* module, se::StreamExecutor* stream_exec, + Thunk::BinaryMap* dnn_compiled_graphs) override; + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; - StatusOr CompileTargetBinary( + absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override; + enum class LinkingMethod { + kNone, + kNvLink, + kDriver, + }; + private: - StatusOr CanUseLinkModules( + absl::StatusOr CanUseLinkModules( const HloModuleConfig& module_config) override; - StatusOr> LinkModules( + absl::StatusOr> LinkModules( se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) override; absl::Mutex mutex_; - enum class LinkingMethod { - kNone, - kNvLink, - kDriver, - }; absl::flat_hash_map linking_methods_ ABSL_GUARDED_BY(mutex_); - StatusOr ChooseLinkingMethod( - const std::string& preferred_cuda_dir); + absl::StatusOr ChooseLinkingMethod( + const DebugOptions& debug_options); // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin if compilation succeeded. - StatusOr> CompileGpuAsmOrGetCachedResult( + absl::StatusOr> CompileGpuAsmOrGetCachedResult( const std::string& ptx, se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config, absl::string_view module_name, bool relocatable, const CompileOptions& options); + struct CompilationCacheFlags { + template + friend H AbslHashValue(H h, const CompilationCacheFlags& flags) { + return H::combine(std::move(h), + flags.filter_kernels_spilling_registers_on_autotuning); + } + + friend bool operator==(const CompilationCacheFlags& a, + const CompilationCacheFlags& b) { + return a.filter_kernels_spilling_registers_on_autotuning == + b.filter_kernels_spilling_registers_on_autotuning; + } + + bool filter_kernels_spilling_registers_on_autotuning; + }; + // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for // some interactive workflows. (We also cache at the HLO level, but sometimes @@ -115,29 +150,36 @@ class NVPTXCompiler : public GpuCompiler { // and leave compilation up to the driver. struct CompilationCacheKey { CompilationCacheKey(std::string ptx, int cc_major, int cc_minor, - bool relocatable) + bool relocatable, CompilationCacheFlags flags) : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor), - relocatable(relocatable) {} + relocatable(relocatable), + flags(std::move(flags)) {} + template friend H AbslHashValue(H h, const CompilationCacheKey& key) { return H::combine(std::move(h), key.ptx, key.cc_major, key.cc_minor, - key.relocatable); + key.relocatable, key.flags); } + friend bool operator==(const CompilationCacheKey& a, const CompilationCacheKey& b) { return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && - a.ptx == b.ptx && a.relocatable == b.relocatable; + a.ptx == b.ptx && a.relocatable == b.relocatable && + a.flags == b.flags; } + std::string ptx; int cc_major; int cc_minor; bool relocatable; + CompilationCacheFlags flags; }; + struct CompilationCacheValue { bool compilation_done = false; - std::vector cubin_data; + absl::StatusOr> maybe_cubin; // mutex and condition variable to serialize compilation completing. absl::Mutex mutex; absl::CondVar compilation_done_cv; diff --git a/xla/service/gpu/nvptx_compiler_registration.cc b/xla/service/gpu/nvptx_compiler_registration.cc index 22138bb19732f..7d7bd94ac20a7 100644 --- a/xla/service/gpu/nvptx_compiler_registration.cc +++ b/xla/service/gpu/nvptx_compiler_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "xla/service/compiler.h" #include "xla/service/gpu/nvptx_compiler.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" diff --git a/xla/service/gpu/nvptx_compiler_test.cc b/xla/service/gpu/nvptx_compiler_test.cc index f32502deb6075..f8fab8f8a4d04 100644 --- a/xla/service/gpu/nvptx_compiler_test.cc +++ b/xla/service/gpu/nvptx_compiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,16 +15,20 @@ limitations under the License. #include "xla/service/gpu/nvptx_compiler.h" +#include #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/backend.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/gpu_compiler.h" -#include "xla/statusor.h" +#include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -55,11 +59,24 @@ int64_t CountCopies(const HloModule& module) { class NVPTXCompilerTest : public HloTestBase { public: - StatusOr> AssignBuffers(HloModule* module) { - Backend& test_backend = backend(); - NVPTXCompiler compiler; - return compiler.AssignBuffers(module, - test_backend.default_stream_executor()); + absl::StatusOr> AssignBuffers( + HloModule* module) { + constexpr uint64_t pointer_size = 4; + const se::DeviceDescription& gpu_device_info = + backend().default_stream_executor()->GetDeviceDescription(); + TF_RETURN_IF_ERROR( + ScheduleGpuModule(module, pointer_size, gpu_device_info).status()); + + auto buffer_size_bytes_function = + [this](const BufferValue& buffer_value) -> int64_t { + return GetSizeOfShape(buffer_value.shape(), pointer_size); + }; + + return BufferAssigner::Run( + module, std::make_unique(module->schedule()), + buffer_size_bytes_function, + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }); } }; @@ -130,7 +147,7 @@ ENTRY entry { TEST_F(NVPTXCompilerTestTriton, DotDimensionAreSortedBeforePaddingForCublasEnablingTritonFusion) { - MatchOptimizedHlo(R"( + const absl::string_view hlo_string = R"( ENTRY e { p0 = f16[11,22,33,44] parameter(0) p1 = s8[11,22,33,44] parameter(1) @@ -138,13 +155,25 @@ ENTRY e { ROOT d = f16[11,22,44,44] dot(p0, p1c), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -})", - R"( +})"; + + se::CudaComputeCapability cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + + if (cc.IsAtLeastAmpere()) { + MatchOptimizedHlo(hlo_string, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter ; CHECK-NEXT: __triton_gemm - )"); + )"); + } else { + MatchOptimizedHlo(hlo_string, R"( +; CHECK-NOT: triton + )"); + } } TEST_F(NVPTXCompilerTest, RemovesUnnecessaryCopyInPostSchedulingPipelines) { @@ -202,7 +231,9 @@ ENTRY main { HloOpcode::kCopy); NVPTXCompiler compiler; - TF_EXPECT_OK(compiler.RunPostSchedulingPipelines(module.get(), 100000)); + TF_EXPECT_OK(compiler.RunPostSchedulingPipelines( + module.get(), 100000, + backend().default_stream_executor()->GetDeviceDescription())); EXPECT_EQ(CountCopies(*module), 3); while_op = hlo_query::GetFirstInstructionWithOpcode( *module->entry_computation(), HloOpcode::kWhile); diff --git a/xla/service/gpu/outfeed_manager.cc b/xla/service/gpu/outfeed_manager.cc index ca49526e6c00e..a4b25db4706bb 100644 --- a/xla/service/gpu/outfeed_manager.cc +++ b/xla/service/gpu/outfeed_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,10 @@ limitations under the License. #include -#include "xla/map_util.h" +#include "absl/status/status.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "tsl/platform/logging.h" @@ -41,7 +44,7 @@ OutfeedManager *GetOrCreateOutfeedManager(se::StreamExecutor *executor) { #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } -Status OutfeedManager::TransferLiteralFromOutfeed( +absl::Status OutfeedManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, MutableBorrowingLiteral literal) { ShapeTree> outfeed_buffers( &literal.shape()); @@ -68,7 +71,7 @@ Status OutfeedManager::TransferLiteralFromOutfeed( leaf.second->WaitUntilAvailable(); } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/outfeed_manager.h b/xla/service/gpu/outfeed_manager.h index 4da268a26b0b9..c9042af0f9ad8 100644 --- a/xla/service/gpu/outfeed_manager.h +++ b/xla/service/gpu/outfeed_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ #define XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ +#include #include #include +#include "absl/status/status.h" #include "xla/literal.h" #include "xla/service/gpu/xfeed_queue.h" #include "xla/shape_tree.h" @@ -61,8 +63,8 @@ class OutfeedBuffer { class OutfeedManager : public XfeedQueue>*> { public: - Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - MutableBorrowingLiteral literal); + absl::Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + MutableBorrowingLiteral literal); }; // Returns the GPU outfeed manager for the given stream executor. diff --git a/xla/service/gpu/parallel_loop_emitter.cc b/xla/service/gpu/parallel_loop_emitter.cc index d9c79c2a2225b..903ffdbb1d78c 100644 --- a/xla/service/gpu/parallel_loop_emitter.cc +++ b/xla/service/gpu/parallel_loop_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,20 @@ limitations under the License. #include "xla/service/gpu/parallel_loop_emitter.h" #include -#include +#include #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" #include "xla/primitive_util.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -65,7 +74,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, - static_cast(block_id)); + static_cast(block_id), + b_->GetInsertBlock()->getModule()); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: @@ -73,7 +83,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* thread_id_x = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, - static_cast(thread_id_x)); + static_cast(thread_id_x), + b_->GetInsertBlock()->getModule()); thread_id_x = b_->CreateZExtOrTrunc(thread_id_x, index_type, "thread_id_x"); llvm::Value* linear_index_base = @@ -87,7 +98,8 @@ ParallelLoopEmitter::EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* thread_id_y = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdy, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().y, - static_cast(thread_id_y)); + static_cast(thread_id_y), + b_->GetInsertBlock()->getModule()); thread_id_y = b_->CreateZExtOrTrunc(thread_id_y, index_type, "thread_id_y"); linear_index_base = b_->CreateAdd( linear_index_base, @@ -176,41 +188,37 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, EmitLinearBaseAndThreadIdx(index_type, base_index); llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base; - llvm::Value* thread_id_x = linear_base_and_thread_idx.thread_idx; - - // When enable_row_index is true, it means the inner most dimensions - // match the block sizes. So we can generate a simpler indexing - // for that dimensions. This helps LLVM generate vectorized codes - // in that cases. - llvm::Value* row_index = nullptr; - if (!launch_config_.row_vectorized) { - array_indices.emplace_back(linear_index_base, shape_, b_); - } else { - // Simpler index for row computation. - // This will allow LLVM to vectorize. - row_index = b_->CreateMul( - thread_id_x, - llvm::ConstantInt::get(index_type, launch_config_.unroll_factor), - "row_index", /*HasNUW=*/true, /*HasNSW=*/true); - std::vector multidim(shape_.rank(), nullptr); - multidim.back() = row_index; - array_indices.emplace_back(linear_index_base, multidim, shape_, b_); - } - for (int i = 1; i < launch_config_.unroll_factor; ++i) { + llvm::Value* row_index = + launch_config_.row_vectorized + ? b_->CreateMul(linear_base_and_thread_idx.thread_idx, + llvm::ConstantInt::get(index_type, + launch_config_.unroll_factor), + "row_index", /*HasNUW=*/true, /*HasNSW=*/true) + : nullptr; + + std::vector multidim(shape_.rank(), nullptr); + for (int i = 0; i < launch_config_.unroll_factor; ++i) { + // The add operation is needed even if the offset is 0, since when the + // kernel is unrolled, the following GEP instruction shares the same pointer + // and sequential indices with others, allowing the default SLP pass to + // optimize them into vectorized load/store operations. llvm::Value* linear_index = b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), absl::StrCat("linear_index", i), /*HasNUW=*/true, /*HasNSW=*/true); - if (!launch_config_.row_vectorized) { - array_indices.emplace_back(linear_index, shape_, b_); - } else { - std::vector multidim(shape_.rank(), nullptr); - multidim.back() = b_->CreateAdd( - row_index, llvm::ConstantInt::get(index_type, i), - absl::StrCat("row_index_plus", i), /*HasNUW=*/true, /*HasNSW=*/true); - array_indices.emplace_back(linear_index, multidim, shape_, b_); + if (launch_config_.row_vectorized) { + // This lets us avoid emitting the division for the last dimension of the + // index. The check for i > 0 is here for historical reasons, it might not + // do anything. + multidim.back() = + i == 0 ? row_index + : b_->CreateAdd( + row_index, llvm::ConstantInt::get(index_type, i), + absl::StrCat("row_index_plus", i), /*HasNUW=*/true, + /*HasNSW=*/true); } + array_indices.emplace_back(linear_index, multidim, shape_, b_); } auto if_in_bounds = llvm_ir::EmitIfThenElse( @@ -229,9 +237,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, return array_indices; } -Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name, - llvm::Type* index_type, - llvm::Value* base_indvar) { +absl::Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name, + llvm::Type* index_type, + llvm::Value* base_indvar) { int64_t num_elements = ShapeUtil::ElementsIn(shape_); bool check_bounds = num_elements % launch_config_.unroll_factor > 0; for (const llvm_ir::IrArray::Index& array_index : @@ -256,11 +264,11 @@ Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name, llvm_ir::SetToFirstInsertPoint(if_in_bounds.after_block, b_); } } - return OkStatus(); + return absl::OkStatus(); } -Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, - llvm::Type* index_type) { +absl::Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, + llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); } @@ -289,7 +297,7 @@ Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, // code emitted for later instructions will be correctly placed. CHECK(exit_bb_->getTerminator()); b_->SetInsertPoint(exit_bb_->getTerminator()); - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/parallel_loop_emitter.h b/xla/service/gpu/parallel_loop_emitter.h index 07a5b9f128385..87fca57141d34 100644 --- a/xla/service/gpu/parallel_loop_emitter.h +++ b/xla/service/gpu/parallel_loop_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,19 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_ #define XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_ +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" namespace xla { namespace gpu { @@ -56,8 +65,8 @@ class ParallelLoopEmitter { absl::string_view loop_name, llvm::Type* index_type, llvm::Value* base_index); - Status EmitLoop(absl::string_view loop_name = "", - llvm::Type* index_type = nullptr); + absl::Status EmitLoop(absl::string_view loop_name = "", + llvm::Type* index_type = nullptr); private: struct LinearBaseAndThreadIdx { @@ -67,8 +76,9 @@ class ParallelLoopEmitter { LinearBaseAndThreadIdx EmitLinearBaseAndThreadIdx(llvm::Type* index_type, llvm::Value* base_index); - Status EmitSerialLoop(absl::string_view loop_name, llvm::Type* index_type, - llvm::Value* base_indvar = nullptr); + absl::Status EmitSerialLoop(absl::string_view loop_name, + llvm::Type* index_type, + llvm::Value* base_indvar = nullptr); // The thread and block dimension to parallelize the loop on. const LaunchDimensions launch_dimensions_; diff --git a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 34cdfec28a0de..a4ca2005752a2 100644 --- a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,14 @@ limitations under the License. #include #include +#include #include "xla/hlo/ir/hlo_module.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu_gpu_shape_verifier.h" #include "xla/service/gpu/alias_passthrough_params.h" #include "xla/service/gpu/copy_fusion.h" +#include "xla/service/gpu/gpu_flash_attn_normalization.h" #include "xla/service/gpu/gpu_sanitize_constant_names.h" #include "xla/service/gpu/horizontal_loop_fusion.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -85,6 +87,7 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( sub_pipeline.AddPass("copy_"); sub_pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline; } diff --git a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h index 73881e82614ea..8a40b10490fe2 100644 --- a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h +++ b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/priority_fusion.cc b/xla/service/gpu/priority_fusion.cc index 0c0c4c69a5877..c96c65319b259 100644 --- a/xla/service/gpu/priority_fusion.cc +++ b/xla/service/gpu/priority_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,37 +22,45 @@ limitations under the License. #include #include #include -#include +#include #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/meta/type_traits.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/dump.h" -#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/blocking_counter.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { @@ -63,6 +71,45 @@ bool ElementIsF32OrF16(const Shape& shape) { return type == F32 || type == F16; } +bool IsFusible(const HloInstruction& instr) { + // Side-effecting operations are not fusible. + if (!instr.IsFusible()) { + return false; + } + + // Element-wise operations are always fusible. + if (instr.IsElementwise()) { + return true; + } + + // Other non-elementwise ops also supported by elemental fusion. + switch (instr.opcode()) { + case HloOpcode::kFusion: + return instr.fusion_kind() != HloInstruction::FusionKind::kCustom; + + case HloOpcode::kCopy: + case HloOpcode::kIota: + case HloOpcode::kConstant: + case HloOpcode::kReduce: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kGather: + case HloOpcode::kPad: + case HloOpcode::kReduceWindow: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kScatter: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + // An implementation of FusionQueue that determines whether to fuse instructions // according to a cost model, and chooses the next fusion candidate according to // dynamically updated priorities. The elements in the queue are producer nodes @@ -70,7 +117,7 @@ bool ElementIsF32OrF16(const Shape& shape) { // performance when fusing it to all of its fusible users. We greedily pick the // max-benefit producer to fuse, and update the estimated benefits of the fused // nodes and their operands. -class GpuPriorityFusionQueue : public FusionQueue { +class GpuPriorityFusionQueue { using Priority = int64_t; using CanFuseCallback = std::function; @@ -79,19 +126,23 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPriorityFusionQueue( HloComputation* computation, const GpuHloCostAnalysis::Options& cost_analysis_options, - const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, + const se::DeviceDescription* device_info, FusionProcessDumpProto* fusion_process_dump, tsl::thread::ThreadPool* thread_pool, HloFusionAnalysisCache& fusion_analysis_cache) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), - can_fuse_(can_fuse), fusion_process_dump_(fusion_process_dump), thread_pool_(thread_pool), fusion_analysis_cache_(fusion_analysis_cache) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); + dump_fusion_visualization_ = computation->parent() + ->config() + .debug_options() + .xla_dump_fusion_visualization(); + // Initializes the priority queue. std::vector instructions; for (auto* instruction : computation->MakeInstructionPostOrder()) { @@ -103,12 +154,36 @@ class GpuPriorityFusionQueue : public FusionQueue { } instructions.push_back(instruction); } + + ComputeAndSetPriorities(instructions); + } + + void ComputeAndSetPriorities( + const std::vector& instructions) { std::vector priorities = ComputePriorities(instructions); for (auto [instruction, priority] : llvm::zip(instructions, priorities)) { - auto emplace_result = producer_priority_queue_.emplace( - std::make_pair(priority, instruction->unique_id()), instruction); - CHECK(emplace_result.second); + auto key = std::make_pair(priority, instruction->unique_id()); + + // Remove instruction with the old priority from the queue. + auto reverse_it = reverse_map_.find(instruction); + if (reverse_it != reverse_map_.end()) { + const PriorityQueue::iterator& queue_it = reverse_it->second; + // Priority didn't change. Nothing to do. + if (key == queue_it->first) { + continue; + } + producer_priority_queue_.erase(queue_it); + reverse_map_.erase(reverse_it); + } + + // If the priority is negative, it's not helpful to perform fusion on this + // instruction. + if (priority < 0) { + continue; + } + + auto emplace_result = producer_priority_queue_.emplace(key, instruction); reverse_map_.emplace(instruction, emplace_result.first); } } @@ -135,45 +210,84 @@ class GpuPriorityFusionQueue : public FusionQueue { return priorities; } - std::pair> - DequeueNextInstructionAndOperandsToFuseInOrder() override { - while (current_consumers_.empty()) { - if (producer_priority_queue_.empty()) { - return {}; - } + // Gets the next pair of (producer, consumers) from the queue for fusion. + // Returns true if there is the next producer to fuse, otherwise false. Stores + // the producer and consumers in `current_producer_` and `current_consumers_`. + bool DequeueNextProducer() { + current_producer_ = nullptr; + current_consumers_.clear(); + + while (!producer_priority_queue_.empty() && current_consumers_.empty()) { auto next_it = std::prev(producer_priority_queue_.end()); - auto priority = next_it->first.first; current_producer_ = next_it->second; producer_priority_queue_.erase(next_it); reverse_map_.erase(current_producer_); - // If the priority is negative, it's not helpful to perform fusion on this - // instruction. - if (priority < 0) { - continue; - } current_consumers_ = current_producer_->users(); + + if (current_producer_->opcode() == HloOpcode::kBitcast) { + // We don't check if bitcasts can be fused with all consumers, so we + // have to do it here. + llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) { + return !CanFuseCached(current_producer_, consumer); + }); + } + } + + return !current_consumers_.empty(); + } + + // Update priorities of all affected ops. + void UpdatePriorities() { + // Revisit costs of all updated ops. It's important to update cost analysis + // before recalculating priorities. + for (auto instruction : to_update_priority_) { + TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction)); + } + + ComputeAndSetPriorities(std::vector{ + to_update_priority_.begin(), to_update_priority_.end()}); + + to_update_priority_.clear(); + } + + // Prepares producer and consumer instruction to be fused. Invalidates caches + // and writes logs. + void PreFusion(HloInstruction* producer, HloInstruction* consumer) { + if (dump_fusion_visualization_) { + RegisterFusionState( + *computation_, + absl::StrCat("About to fuse |", producer->name(), "| into |", + consumer->name(), "| inside PriorityFusion"), + *consumer, producer); } - auto next_consumer = current_consumers_.back(); - int64_t producer_operand_index = - next_consumer->operand_index(current_producer_); - current_consumers_.pop_back(); - VLOG(5) << "next: " << next_consumer->name() << "(" << next_consumer - << ") + " << current_producer_->name() << "(" << current_producer_ - << ")"; - return {next_consumer, {producer_operand_index}}; + InvalidateCaches(producer); + InvalidateCaches(consumer); } - // Calculates the compute cost and free computation of the new fusion in the - // PreFusion callback. - void PreFusion(HloInstruction* producer, HloInstruction* consumer) override {} + // Invalidates all cached value related to this instruction. Called before the + // instruction is fused. The instruction can be either producer or consumer. + void InvalidateCaches(HloInstruction* instruction) { + HloInstructionAdaptor instruction_adaptor(*instruction); + + can_fuse_cache_.erase(instruction_adaptor); + for (auto operand : instruction_adaptor.GetOperands()) { + auto it = can_fuse_cache_.find(operand); + if (it != can_fuse_cache_.end()) { + it->second.erase(instruction_adaptor); + } + } + + gpu_performance_model_cache_.Invalidate(*instruction); + fusion_analysis_cache_.Invalidate(*instruction); + } // Updates data for the new fusion instruction and its users and operands. void OnFusingInstruction(HloInstruction* fusion, HloInstruction* original_producer, - HloInstruction* original_consumer) override { + HloInstruction* original_consumer) { if (fusion_process_dump_) { auto* fusion_step = fusion_process_dump_->add_fusion_steps()->mutable_fusion(); @@ -184,13 +298,13 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } - HloInstructionAdaptor fusion_adaptor(*fusion); - can_fuse_cache_.erase(fusion_adaptor); - - gpu_performance_model_cache_.Invalidate(*fusion); - - fusion_analysis_cache_.Invalidate(*fusion); - fusion_analysis_cache_.Invalidate(*original_producer); + if (dump_fusion_visualization_) { + RegisterFusionState( + *computation_, + absl::StrCat("Fused |", original_producer->name(), "| into |", + fusion->name(), "| inside PriorityFusion"), + *fusion); + } // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. @@ -211,8 +325,6 @@ class GpuPriorityFusionQueue : public FusionQueue { // Collect the instructions whose priorities need to be updated. for (HloInstruction* operand : fusion->operands()) { if (operand == original_producer || - original_producer->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kBroadcast || operand->opcode() == HloOpcode::kConstant || operand->opcode() == HloOpcode::kGetTupleElement) { continue; @@ -223,52 +335,13 @@ class GpuPriorityFusionQueue : public FusionQueue { continue; } - HloInstructionAdaptor operand_adaptor(*operand); - can_fuse_cache_[operand_adaptor].erase(fusion_adaptor); to_update_priority_.insert(operand); } to_update_priority_.insert(fusion); - - // When current_consumers_ is empty, we will need to dequeue a new producer - // next time, so we update the priorities now. - if (current_consumers_.empty()) { - // Revisit costs of all updated ops. It's important to update cost - // analysis before recalculating priorities. - for (auto instruction : to_update_priority_) { - TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction)); - } - - std::vector to_update_vector{to_update_priority_.begin(), - to_update_priority_.end()}; - std::vector new_priorities = - ComputePriorities(to_update_vector); - - for (auto [instruction, new_priority] : - llvm::zip(to_update_vector, new_priorities)) { - auto reverse_it = reverse_map_.find(instruction); - const auto new_key = - std::make_pair(new_priority, instruction->unique_id()); - if (reverse_it != reverse_map_.end()) { - if (new_key == reverse_it->second->first) { - continue; - } - producer_priority_queue_.erase(reverse_it->second); - } - auto emplace_result = - producer_priority_queue_.emplace(new_key, instruction); - CHECK(emplace_result.second); - if (reverse_it != reverse_map_.end()) { - reverse_it->second = emplace_result.first; - } else { - reverse_map_.emplace(instruction, emplace_result.first); - } - } - to_update_priority_.clear(); - } } // Removes data for the instruction. - void RemoveInstruction(HloInstruction* instruction) override { + void RemoveInstruction(HloInstruction* instruction) { to_update_priority_.erase(instruction); fusion_analysis_cache_.Invalidate(*instruction); @@ -280,14 +353,29 @@ class GpuPriorityFusionQueue : public FusionQueue { reverse_map_.erase(reverse_it); } - const std::vector* FusionConfiguration() override { return nullptr; } + HloInstruction* current_producer() { return current_producer_; } + + const std::vector& current_consumers() { + return current_consumers_; + } private: // Returns the priority of the producer based on its current operands and // users. Priority CalculateProducerPriority(HloInstruction* producer) { + // Bitcasts should always be fused first, since they are no-ops. + if (producer->opcode() == HloOpcode::kBitcast) { + return std::numeric_limits::max(); + } + // We always fuse constants, but the cost model doesn't handle them very + // well: fusing constants changes costs significantly. Also, there's no + // point recomputing priorities. Therefore, we fuse all of them at the end. + if (producer->opcode() == HloOpcode::kConstant) { + return std::numeric_limits::min(); + } + // Don't fuse if we can't fuse in all users. - if (auto fusion_decision = CanFuseWithAllUsers(producer); + if (auto fusion_decision = CanFuseWithAllNonBitcastUsers(producer); !fusion_decision) { if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); @@ -300,11 +388,12 @@ class GpuPriorityFusionQueue : public FusionQueue { } GpuPerformanceModel::RunTimes run_times = - GpuPerformanceModel::EstimateRunTimes( + GpuPerformanceModel::EstimateRunTimesForPriorityFusion( producer, &cost_analysis_, GpuPerformanceModelOptions::PriorityFusion( &fusion_analysis_cache_, &gpu_performance_model_cache_), producer->users()); + if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = @@ -320,6 +409,93 @@ class GpuPriorityFusionQueue : public FusionQueue { run_times.time_fused); } + FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) { + if (!IsFusible(*producer)) { + return "the producer is not fusible"; + } + + if (!IsFusible(*consumer)) { + return "the consumer is not fusible"; + } + + if (consumer->opcode() == HloOpcode::kBitcast) { + return "not fusing into a single bitcast as consumer"; + } + + // Scatter is special as it has no elemental version but is still input + // fusible. Block attempts to create scatter fusions we can't codegen. + if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer); + !can_fuse) { + return can_fuse; + } + + // Avoid fusing reduce into reduce. Our cost model doesn't currently + // understand this case due to a lack of tiling analysis. + // TODO(b/312200883): Remove this. + auto contains_significant_reduce = [&](const HloInstruction* instr) { + auto fusion = HloFusionAdaptor::ForInstruction(instr); + return HloAnyOf(fusion->GetRoots(), *fusion, [](auto node) { + if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) { + return false; + } + + int64_t reduction_size = + ShapeUtil::ElementsIn(node.instruction().operand(0)->shape()) / + ShapeUtil::ElementsIn(node.shape()); + + // Small reductions are emitted using the elemental emitter anyway. + return reduction_size >= 16; + }); + }; + if (contains_significant_reduce(producer) && + contains_significant_reduce(consumer)) { + return "both the producer and the consumer contain a reduce"; + } + + // Avoid doing fusions into the output of an "input" fusion when it would + // switch it to the loop emitter. This often occurs during epilog fusion for + // reductions, which suffer from limited emitter support. + // TODO(b/312686229): Cost model should handle this. + const auto& analysis = fusion_analysis_cache_.Get(*producer); + if (analysis.GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kReduction) { + const auto& analysis_fused = + fusion_analysis_cache_.Get(*producer, *consumer); + if (analysis_fused.GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kLoop) { + return "fusion into output of a reduce fusion would create a loop " + "fusion"; + } + } + + // Avoid cases where we'd create a fusion that hit limitations in ptxas. + // Would be nice to model this with cost instead. + if (auto fits_budget = FusionFitsInBudget( + *consumer, *producer, *cost_analysis_.device_info_, + /*is_consumer_producer_fusion=*/true); + !fits_budget) { + return fits_budget; + } + + // Also check that our emitter can handle the fusion node. We currently can + // have exponential time/memory requirements for emitting certain fusion + // kernels, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. + if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) { + return "the fusion would result in an overly large code duplication"; + } + + // Don't fuse across a root instruction. There are situation when a root + // instruction is not the last in the computation. Instructions after the + // root are not necessary dead. They can be inputs to instructions with side + // effects, like outfeed. + if (producer == producer->parent()->root_instruction()) { + return "not fusing into the output of the root instruction"; + } + + return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer); + } + FusionDecision CanFuseCached(HloInstruction* producer, HloInstruction* consumer) { HloInstructionAdaptor producer_adaptor(*producer); @@ -335,8 +511,7 @@ class GpuPriorityFusionQueue : public FusionQueue { } } - auto fusion_decision = - can_fuse_(consumer, consumer->operand_index(producer)); + auto fusion_decision = CanFuse(producer, consumer); // The lock is required, because writing to a flat_hash_map is not // thread-safe even for different keys. We never call this computation @@ -350,13 +525,18 @@ class GpuPriorityFusionQueue : public FusionQueue { return fusion_decision; } - FusionDecision CanFuseWithAllUsers(HloInstruction* producer) { - if (producer->users().size() == 0) { + FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { + if (producer->users().empty()) { return "No users to fuse"; } FusionDecision result; + bool has_non_bitcast_user = false; for (const auto& user : producer->users()) { + if (user->opcode() == HloOpcode::kBitcast) { + continue; + } + has_non_bitcast_user = true; if (auto fusion_decision = CanFuseCached(producer, user); !fusion_decision) { VLOG(10) << "Cannot fuse " << producer->name() << " with " @@ -364,6 +544,9 @@ class GpuPriorityFusionQueue : public FusionQueue { return fusion_decision; } } + if (!has_non_bitcast_user) { + return "not fusing because there are only bitcast users"; + } return {}; } @@ -388,11 +571,6 @@ class GpuPriorityFusionQueue : public FusionQueue { // The current consumers being visited. std::vector current_consumers_; - // Callbacks passed from the caller to check if we can fuse a pair of - // producer and consumer, where the consumer is given as a HloInstruction* - // and the producer is given as the consumer's operand index. - CanFuseCallback can_fuse_; - // The set of producers whose priorities need to be updated. Their // priorities are changed because their neighbors got fused, but we delay // the priority updates until current_consumers_ becomes empty. This is to @@ -418,6 +596,8 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::Mutex can_fuse_cache_mutex_; GpuPerformanceModelCache gpu_performance_model_cache_; + + bool dump_fusion_visualization_; }; } // namespace @@ -443,12 +623,36 @@ class GpuPriorityFusionQueue : public FusionQueue { return InstructionFusion::IsExpensive(instruction); } -StatusOr GpuPriorityFusion::Run( +// Return true, if instr is a small constant. +// +// There is not single definition for what is a small constant in XLA. +// IrEmitterContext::emit_constant treats as small only constants of 1 element. +// HloPrintOptions::print_large_constants is effective for constants larger +// than 10 elements. +// +// This function matches the emitter logic. +bool IsSmallConstant(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() && + ShapeUtil::ElementsIn(instr->shape()) <= 1; +} + +bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer, + HloInstruction* consumer) { + return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] { + return absl::StrFormat("Not fusing producer %s with consumer %s", + producer->name(), consumer->name()); + }); +}; + +absl::StatusOr GpuPriorityFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - bool dump_enabled = DumpingEnabledForHloModule(*module); + bool dump_enabled = + DumpingEnabledForHloPass(name(), module->config().debug_options()); if (dump_enabled) { fusion_process_dump_ = std::make_unique(); + *fusion_process_dump_->mutable_gpu_device_info() = + device_info_.ToGpuProto(); } // Appends ".0" suffix to all instructions. @@ -458,136 +662,108 @@ StatusOr GpuPriorityFusion::Run( // Before: broadcast.123 -> broadcast.124 // After: broadcast.123.0 -> broadcast.123.1 // - // With this modification it will be easier to match intructions before and + // With this modification it will be easier to match instructions before and // after fusion passes, because they will have the same unique prefix. Names // are not used in the pipeline, but it makes debugging much easier. - for (auto* computation : GetFusionComputations(module, execution_threads)) { + for (auto* computation : + GetNonFusionComputations(module, execution_threads)) { for (auto* instruction : computation->instructions()) { - instruction->SetAndSanitizeName(absl::StrCat(instruction->name(), ".0")); + module->SetAndUniquifyInstrName(instruction, + absl::StrCat(instruction->name(), ".0")); } } - auto result = InstructionFusion::Run(module, execution_threads); - if (dump_enabled) { - DumpPerModuleProtobufToFile(*module, *fusion_process_dump_, - module->config().debug_options(), - "priority_fusion_dump"); + fusion_process_dump_->set_hlo_module_before_fusion( + module->ToString(HloPrintOptions::ShortParsable())); } - return result; -} + int changed = false; + for (auto* computation : + GetNonFusionComputations(module, execution_threads)) { + CHECK(!computation->IsFusionComputation()); -FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, - int64_t operand_index) { - auto isFusible = [](const HloInstruction& instr) { - // Side-effecting operations are not fusible. - if (!instr.IsFusible()) { - return false; - } + auto fusion_queue = std::make_unique( + computation, cost_analysis_options_, &device_info_, + fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_); - // Element-wise operations are always fusible. - if (instr.IsElementwise()) { - return true; - } + while (fusion_queue->DequeueNextProducer()) { + auto producer = fusion_queue->current_producer(); - // Other non-elementwise ops also supported by elemental fusion. - switch (instr.opcode()) { - case HloOpcode::kFusion: - return instr.fusion_kind() != HloInstruction::FusionKind::kCustom; - - case HloOpcode::kCopy: - case HloOpcode::kIota: - case HloOpcode::kConstant: - case HloOpcode::kReduce: - case HloOpcode::kBitcast: - case HloOpcode::kBroadcast: - case HloOpcode::kConcatenate: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kGather: - case HloOpcode::kPad: - case HloOpcode::kReduceWindow: - case HloOpcode::kReshape: - case HloOpcode::kReverse: - case HloOpcode::kScatter: - case HloOpcode::kSlice: - case HloOpcode::kTranspose: - return true; - default: - return false; - } - }; + for (auto* consumer : fusion_queue->current_consumers()) { + // Don't fuse into single bitcasts. We ignore them in the check + // CanFuseWithAllNonBitcastUsers(), so we need to check it here. + if (consumer->opcode() == HloOpcode::kBitcast) { + continue; + } + if (!ConsumeFuel(producer, consumer)) continue; - HloInstruction* producer = consumer->mutable_operand(operand_index); - if (!isFusible(*producer)) { - return "the producer is not fusible"; - } + VLOG(5) << "next: " << consumer->name() << "(" << consumer << ") + " + << producer->name() << "(" << producer << ")"; - if (!isFusible(*consumer)) { - return "the consumer is not fusible"; - } + fusion_queue->PreFusion(producer, consumer); + auto fusion_instruction = Fuse(producer, consumer, computation); + fusion_queue->OnFusingInstruction(fusion_instruction, producer, + consumer); - // Scatter is special as it has no elemental version but is still input - // fusible. Block attempts to create scatter fusions we can't codegen. - if (auto can_fuse = CanEmitInputFusedScatter(*producer, *consumer); - !can_fuse) { - return can_fuse; - } + changed = true; + } - // Avoid fusing reduce into reduce. Our cost model doesn't currently - // understand this case due to a lack of tiling analysis. - // TODO(b/312200883): Remove this. - auto contains_reduce = [&](const HloInstruction* instr) { - return HloAnyOf({HloInstructionAdaptor{*instr}}, - *HloFusionAdaptor::ForInstruction(instr), [](auto node) { - return node.opcode() == HloOpcode::kReduce; - }); - }; - if (contains_reduce(producer) && contains_reduce(consumer)) { - return "both the producer and the consumer contain a reduce"; - } + if (producer->user_count() == 0) { + fusion_queue->RemoveInstruction(producer); + // Remove from computation. + TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + } - // Avoid doing fusions into the output of an "input" fusion when it would - // switch it to the loop emitter. This often occurs during epilog fusion for - // reductions, which suffer from limited emitter support. - // TODO(b/312686229): Cost model should handle this. - auto analysis_fused = - AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_); - if (producer->IsInputFusion() && analysis_fused && - analysis_fused->GetEmitterFusionKind() == - HloFusionAnalysis::EmitterFusionKind::kLoop) { - return "fusion into output of an input fusion would create a loop fusion"; - } + fusion_queue->UpdatePriorities(); + } - // Avoid cases where we'd create a fusion that hit limitations in ptxas. - // Would be nice to model this with cost instead. - if (auto fits_budget = - FusionFitsInBudget(*consumer, *producer, device_info_, - /*is_consumer_producer_fusion=*/true); - !fits_budget) { - return fits_budget; + // Fuse all constants. + std::vector constants; + for (auto* instruction : computation->instructions()) { + // Small constants should be fused, because they can be folded and + // codegened efficiently. + // Fusing large constants doesn't give much benefits, because they're + // treated like parameters and read from global memory anyway. Fusion + // and duplication of large constants can, however, cause problems if we + // want to dump hlo and parse back, because in that case duplicated + // constants will be filled with different data. + if (IsSmallConstant(instruction)) { + constants.push_back(instruction); + } + } + for (auto* constant : constants) { + auto users = constant->users(); + for (auto* user : users) { + if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) { + Fuse(constant, user, computation); + changed = true; + } + } + } } - // Also check that our emitter can handle the fusion node. We currently can - // have exponential time/memory requirements for emitting certain fusion - // kernels, in which case we don't want to fuse. - // TODO(b/119692968): Remove this once we have fixed our fusion emitter. - if (consumer->opcode() == HloOpcode::kFusion) { - absl::MutexLock lock(&fusion_node_evaluations_mutex_); - if (fusion_node_evaluations_.find(consumer) == - fusion_node_evaluations_.end()) { - // We have no cached results for this fusion node yet. Compute it now. - fusion_node_evaluations_.emplace(consumer, - FusionNodeIndexingEvaluation(consumer)); - } - if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( - producer)) { - return "the fusion would result in an overly large code duplication"; - } + // FusionAnalysis cache uses unique_id as key. IDs are only unique inside one + // module. It's important to fully clear the cache if the same instance of the + // pass will be called on a different module. + fusion_analysis_cache_.Clear(); + + if (dump_enabled) { + DumpPerModuleProtobufToFile(*module, *fusion_process_dump_, + module->config().debug_options(), + "priority_fusion_dump"); } - return InstructionFusion::ShouldFuse(consumer, operand_index); + return changed; +} + +FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, + int64_t operand_index) { + // This method is called in `InstructionFusion::Run` right before fusion, but + // it will always return true. Fusion decision are fully controlled by the + // PriorityQueue. If the queue returns a producer that shouldn't be fused, + // it's a bug and should be fixed in the queue logic. + return {}; } HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( @@ -595,15 +771,15 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't // matter but some passes downstream still query these instead of fusion // analysis. - // TODO: Don't recompute this all the time. const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer); - if (!analysis) return HloInstruction::FusionKind::kLoop; - switch (analysis->GetEmitterFusionKind()) { + switch (analysis.GetEmitterFusionKind()) { case HloFusionAnalysis::EmitterFusionKind::kLoop: return HloInstruction::FusionKind::kLoop; case HloFusionAnalysis::EmitterFusionKind::kTriton: case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: + case HloFusionAnalysis::EmitterFusionKind::kCuDnn: return HloInstruction::FusionKind::kCustom; + case HloFusionAnalysis::EmitterFusionKind::kConcatenate: case HloFusionAnalysis::EmitterFusionKind::kReduction: case HloFusionAnalysis::EmitterFusionKind::kTranspose: case HloFusionAnalysis::EmitterFusionKind::kInputSlices: @@ -620,24 +796,12 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( } else { result = InstructionFusion::FuseInstruction(fusion_instruction, producer); } - - // Invalidate cached values that are now invalid. - for (auto* user : fusion_instruction->users()) { - fusion_node_evaluations_.erase(user); - } - fusion_node_evaluations_.erase(fusion_instruction); - return result; } std::unique_ptr GpuPriorityFusion::GetFusionQueue( HloComputation* computation) { - return std::unique_ptr(new GpuPriorityFusionQueue( - computation, cost_analysis_options_, &device_info_, - [this](HloInstruction* consumer, int64_t operand_index) { - return ShouldFuse(consumer, operand_index); - }, - fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_)); + return nullptr; } } // namespace gpu diff --git a/xla/service/gpu/priority_fusion.h b/xla/service/gpu/priority_fusion.h index afc5e8f99003d..c8e45da90cb31 100644 --- a/xla/service/gpu/priority_fusion.h +++ b/xla/service/gpu/priority_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,14 +19,14 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" @@ -34,7 +34,7 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/instruction_fusion.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/threadpool.h" namespace xla { @@ -56,13 +56,14 @@ class GpuPriorityFusion : public InstructionFusion { static bool IsExpensive(const HloInstruction& instruction); using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; protected: std::unique_ptr GetFusionQueue( HloComputation* computation) override; + FusionDecision ShouldFuse(HloInstruction* consumer, int64_t operand_index) override; @@ -73,6 +74,10 @@ class GpuPriorityFusion : public InstructionFusion { HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, HloInstruction* producer) override; + // Consumes a unit of compiler fuel and returns true if we should + // continue with the transformation. + bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer); + tsl::thread::ThreadPool* thread_pool_; se::DeviceDescription device_info_; @@ -83,11 +88,6 @@ class GpuPriorityFusion : public InstructionFusion { // null, logging is disabled. std::unique_ptr fusion_process_dump_; - // Keep track of the number of times each instruction inside a fusion node is - // indexed with different index vectors. - absl::Mutex fusion_node_evaluations_mutex_; - absl::flat_hash_map - fusion_node_evaluations_; HloFusionAnalysisCache fusion_analysis_cache_; }; diff --git a/xla/service/gpu/priority_fusion_test.cc b/xla/service/gpu/priority_fusion_test.cc index 8c122df7f1c00..14831b4cef88e 100644 --- a/xla/service/gpu/priority_fusion_test.cc +++ b/xla/service/gpu/priority_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,13 @@ limitations under the License. #include #include +#include +#include +#include #include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -38,7 +43,6 @@ limitations under the License. namespace m = ::xla::match; -using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; @@ -65,11 +69,9 @@ class PriorityFusionTest : public HloTestBase { if (!computation->FusionInstruction()) continue; auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - auto analysis = - HloFusionAnalysis::Create( - Cast(computation->FusionInstruction()), - &device_info) - .value(); + auto analysis = HloFusionAnalysis::Create( + Cast(computation->FusionInstruction()), + &device_info); kinds.push_back(analysis.GetEmitterFusionKind()); } return kinds; @@ -140,6 +142,24 @@ CHECK-NEXT: ROOT {{.*}} tuple(%[[FUSION_0]], %[[FUSION_1]]) )"); } +TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) { + absl::string_view kHlo = R"( + HloModule test_module + + ENTRY main { + param_0 = f32[96]{0} parameter(0) + broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1} + bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast) + ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2} + } + )"; + RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( +CHECK: ENTRY +CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0) +CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]]) + )"); +} + TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) { absl::string_view kHlo = R"( HloModule test_module @@ -157,8 +177,9 @@ TEST_F(PriorityFusionTest, FuseWideningConvertIntoConsumers) { CHECK: ENTRY CHECK-NEXT: %[[PARAM:.*]] = f16[512]{0} parameter(0) CHECK-NEXT: %[[FUSION_F32:.*]] = f32[512]{0} fusion(%[[PARAM]]) -CHECK-NEXT: %[[FUSION_S32:.*]] = s32[512]{0} fusion(%[[PARAM]]) -CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[FUSION_S32]]) +CHECK-NEXT: %[[CONVERT_FUSION:.*]] = f32[512]{0} fusion(%[[PARAM]]) +CHECK-NEXT: %[[BITCAST:.*]] = s32[512]{0} bitcast(%[[CONVERT_FUSION]]) +CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[BITCAST]]) )"); } @@ -202,7 +223,8 @@ CHECK-COUNT-3: fusion } TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { - // Regression test for epilogue fusion of convert+bitcast into a reduction. + // Regression test for epilogue fusion of convert into a reduction, even if + // the convert has a bitcast as consumer. absl::string_view kHlo = R"( HloModule test_module @@ -251,10 +273,37 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY -CHECK: ROOT {{.*}} fusion( +CHECK: ROOT {{.*}} bitcast({{.*}}fusion{{.*}}) )"); } +TEST_F(PriorityFusionTest, DoNotChangeReductionFusionToLoopFusion) { + // Regression test for epilogue fusion of slice into a reduction. The fusion + // kind for the reduction fusion is intentionally chosen to be set to kLoop, + // as we cannot rely on reductions always having fusion kind kInput. + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + rhs.407 = f32[] parameter(1) + lhs.407 = f32[] parameter(0) + ROOT add.24451 = f32[] add(lhs.407, rhs.407) + } + + fused_computation { + p0 = f32[16,64]{1,0} parameter(0) + zero = f32[] constant(0.0) + ROOT reduce = f32[16]{0} reduce(p0, zero), dimensions={1}, to_apply=add + } + + ENTRY main { + param0 = f32[16,64]{1,0} parameter(0) + fusion = f32[16]{0} fusion(param0), kind=kLoop, calls=fused_computation + ROOT slice = f32[8]{0} slice(fusion), slice={[0:8]} + })"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { absl::string_view kHlo = R"( HloModule test_module @@ -323,10 +372,11 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { })"; using Kind = HloFusionAnalysis::EmitterFusionKind; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), - UnorderedElementsAre(Kind::kLoop, Kind::kReduction, - Kind::kReduction, Kind::kTranspose, - Kind::kTranspose, Kind::kTranspose)); + EXPECT_THAT( + RunAndGetFusionKinds(kHlo), + UnorderedElementsAre(Kind::kLoop, Kind::kLoop, Kind::kLoop, + Kind::kReduction, Kind::kReduction, Kind::kTranspose, + Kind::kTranspose, Kind::kTranspose)); } TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) { @@ -469,33 +519,6 @@ CHECK-COUNT-2: fusion( )"); } -TEST_F(PriorityFusionTest, SingleTransposeFusion) { - // A regression test that verifies the given HLO fuses into a single fusion. - absl::string_view kHlo = R"( - HloModule test_module - - ENTRY main { - param_0.14390 = bf16[2048,24576]{1,0} parameter(0) - convert.34192 = f32[2048,24576]{1,0} convert(param_0.14390) - constant_11107 = bf16[] constant(0.02002) - convert.35472 = f32[] convert(constant_11107) - broadcast.21886 = f32[2048,24576]{1,0} broadcast(convert.35472), dimensions={} - multiply.14420 = f32[2048,24576]{1,0} multiply(convert.34192, broadcast.21886) - fusion.3520 = f32[2048,24576]{1,0} tanh(multiply.14420) - - constant_11286 = bf16[] constant(50) - convert.42562 = f32[] convert(constant_11286) - broadcast.22230 = f32[2048,24576]{1,0} broadcast(convert.42562), dimensions={} - multiply.14798 = f32[2048,24576]{1,0} multiply(fusion.3520, broadcast.22230) - convert.34603 = bf16[2048,24576]{1,0} convert(multiply.14798) - bitcast.21354 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.34603) - ROOT transpose.6502 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.21354), dimensions={0,3,2,1} - })"; - - using Kind = HloFusionAnalysis::EmitterFusionKind; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), ElementsAre(Kind::kTranspose)); -} - TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) { auto module = *ParseAndReturnVerifiedModule(R"( HloModule test_module @@ -532,6 +555,42 @@ TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) { GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add()))); } +// This test is similar to DontFuseIntoFirstOperandOfScatter, but PriorityFusion +// has a separate run to fuse constants. Fusing anything into a scatter fusion +// will fail in the emitter. +TEST_F(PriorityFusionTest, DontFuseConstantIntoFirstOperandOfScatter) { + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule test_module + + add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + + ENTRY FuseIntoScatter { + operand = s32[1] constant({0}) + indices = s32[24,1] parameter(0) + constant = s32[] constant(1) + updates = s32[24,1] broadcast(constant) + ROOT scatter = s32[1] scatter(operand, indices, updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + })"); + + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kInput); + EXPECT_THAT(root->fused_expression_root(), + GmockMatch(m::Scatter(m::Parameter(), m::Parameter(), + m::Broadcast(m::Constant())))); +} + TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) { constexpr absl::string_view kHlo = R"( HloModule test_module @@ -653,5 +712,148 @@ TEST_F(PriorityFusionTest, EpilogueFusionFails) { EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } +TEST_F(PriorityFusionTest, DoNotFuseIntoRoot) { + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule test_module + + ENTRY %main (p.0: u32[2], p.1: u32[]) -> u32[2] { + %p.0 = u32[2]{0} parameter(0) + %p.1 = u32[] parameter(1) + ROOT %broadcast = u32[2]{0} broadcast(u32[] %p.1), dimensions={}, sharding={replicated} + %add = u32[2]{0} add(u32[2]{0} %p.0, u32[2]{0} %broadcast) + %tuple.1 = (u32[2]{0}) tuple(u32[2]{0} %add) + %token.0 = token[] after-all() + %outfeed.6 = token[] outfeed((u32[2]{0}) %tuple.1, token[] %token.0), outfeed_shape=(u32[2]{0}), sharding={maximal device=0} + })"); + + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + +TEST_F(PriorityFusionTest, DontFuseConcat) { + // Regression test that verifies we don't fuse concat into a column reduction. + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule module + + %maximum (param_0: f32[], param_1: f32[]) -> f32[] { + %param_0 = f32[] parameter(0) + %param_1 = f32[] parameter(1) + ROOT %maximum = f32[] maximum(f32[] %param_0, f32[] %param_1) + } + + %fused_concat (param_0: f32[1,4,401,8,8], param_1: f32[1,1,4,1023,8], param_2: bf16[1,4,1023,8,8]) -> f32[1,4,1424,8,8] { + %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2) + %convert = f32[1,4,1023,8,8]{4,3,2,1,0} convert(bf16[1,4,1023,8,8]{4,3,2,1,0} %param_2) + %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1) + %bitcast = f32[4,1023,8]{2,1,0} bitcast(f32[1,1,4,1023,8]{4,3,2,1,0} %param_1) + %broadcast = f32[1,4,1023,8,8]{4,3,2,1,0} broadcast(f32[4,1023,8]{2,1,0} %bitcast), dimensions={1,2,4} + %add = f32[1,4,1023,8,8]{4,3,2,1,0} add(f32[1,4,1023,8,8]{4,3,2,1,0} %convert, f32[1,4,1023,8,8]{4,3,2,1,0} %broadcast) + %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0) + ROOT %concatenate = f32[1,4,1424,8,8]{4,3,2,1,0} concatenate(f32[1,4,1023,8,8]{4,3,2,1,0} %add, f32[1,4,401,8,8]{4,3,2,1,0} %param_0), dimensions={2} + } + + %fused_reduce (param_0: f32[], param_1: f32[1,4,1424,8,8]) -> f32[4,8,8] { + %param_1 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(1) + %bitcast = f32[4,1424,8,8]{3,2,1,0} bitcast(f32[1,4,1424,8,8]{4,3,2,1,0} %param_1) + %param_0 = f32[] parameter(0) + ROOT %reduce = f32[4,8,8]{2,1,0} reduce(f32[4,1424,8,8]{3,2,1,0} %bitcast, f32[] %param_0), dimensions={1}, to_apply=%maximum + } + + %fused_broadcast (param_0: f32[1,4,1424,8,8], param_1: f32[4,8,8]) -> f32[1,4,1424,8,8] { + %param_0 = f32[1,4,1424,8,8]{4,3,2,1,0} parameter(0) + %param_1 = f32[4,8,8]{2,1,0} parameter(1) + %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} broadcast(f32[4,8,8]{2,1,0} %param_1), dimensions={1,3,4} + ROOT %subtract = f32[1,4,1424,8,8]{4,3,2,1,0} subtract(f32[1,4,1424,8,8]{4,3,2,1,0} %param_0, f32[1,4,1424,8,8]{4,3,2,1,0} %broadcast) + } + + ENTRY fusion { + %param_0 = f32[1,4,401,8,8]{4,3,2,1,0} parameter(0) + %param_1 = f32[1,1,4,1023,8]{4,3,2,1,0} parameter(1) + %param_2 = bf16[1,4,1023,8,8]{4,3,2,1,0} parameter(2) + %concat = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%param_0, %param_1, %param_2), kind=kLoop, calls=fused_concat + %param_3 = f32[] parameter(3) + %reduce = f32[4,8,8]{2,1,0} fusion(%param_3, %concat), kind=kLoop, calls=fused_reduce + %param_4 = f32[4,8,8]{2,1,0} parameter(4) + %broadcast = f32[1,4,1424,8,8]{4,3,2,1,0} fusion(%concat, %param_4), kind=kLoop, calls=fused_broadcast + ROOT tuple = (f32[4,8,8]{2,1,0}, f32[1,4,1424,8,8]{4,3,2,1,0}) tuple(%reduce, %broadcast) + } + )"); + + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + +TEST_F(PriorityFusionTest, FuseOnlySmallConstant) { + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule module + + ENTRY main { + param_0 = f32[32,32]{1,0} parameter(0) + c_1 = f32[] constant(1) + c_2 = f32[32,32] constant({...}) + broadcast = f32[32,32]{1,0} broadcast(c_1), dimensions={} + add = f32[32,32]{1,0} add(param_0, broadcast) + ROOT mul = f32[32,32]{1,0} multiply(c_2, add) + } + )"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::Fusion(m::Constant(), m::Parameter()))); + EXPECT_THAT(root->fused_expression_root(), + GmockMatch(m::Multiply( + m::Parameter(), + m::Add(m::Parameter(), m::Broadcast(m::Constant()))))); +} + +TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { + auto module = *ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation.1 { + iota.9.7 = s32[3,1,1]{2,1,0} iota(), iota_dimension=0 + param_3.29 = s32[] parameter(2) + pad.2.7 = s32[3,1,2]{2,1,0} pad(iota.9.7, param_3.29), padding=0_0x0_0x0_1 + param_2.39 = s32[] parameter(1) + broadcast.76.1 = s32[3,1,2]{2,1,0} broadcast(param_2.39), dimensions={} + compare.9.1 = pred[3,1,2]{2,1,0} compare(pad.2.7, broadcast.76.1), direction=GE + param_1.73 = s32[2]{0} parameter(0) + broadcast.78.1 = s32[3,2]{1,0} broadcast(param_1.73), dimensions={1} + bitcast.1 = s32[3,2]{1,0} bitcast(pad.2.7) + compare.10.1 = pred[3,2]{1,0} compare(bitcast.1, broadcast.78.1), direction=LE + bitcast.2 = pred[3,1,2]{2,1,0} bitcast(compare.10.1) + ROOT and.3.1 = pred[3,1,2]{2,1,0} and(compare.9.1, bitcast.2) + } + + and { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) + } + + fused_computation.2 { + param0 = pred[3,1,2]{2,1,0} parameter(0) + slice = pred[1,1,2]{2,1,0} slice(param0), slice={[0:1], [0:1], [0:2]} + bitcast = pred[2]{0} bitcast(slice) + init = pred[] constant(true) + reduce = pred[2]{0} reduce(param0, init), dimensions={0,1}, to_apply=and + and = pred[2]{0} and(bitcast, reduce) + pad = pred[3]{0} pad(and, init), padding=0_1 + broadcast = pred[3,2]{1,0} broadcast(pad), dimensions={0} + bitcast2 = pred[6]{0} bitcast(broadcast) + broadcast2 = pred[2,3]{1,0} broadcast(pad), dimensions={1} + bitcast3 = pred[6]{0} bitcast(broadcast2) + ROOT and2 = pred[6]{0} and(bitcast2, bitcast3) + } + + ENTRY main { + p0 = s32[2]{0} parameter(0) + p1 = s32[] parameter(1) + p2 = s32[] parameter(2) + fusion1 = pred[3,1,2]{2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_computation.1 + ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2 + } + )"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/reduction_degenerate_dim_remover.cc b/xla/service/gpu/reduction_degenerate_dim_remover.cc index 86dfa84a6197a..f6d422b367469 100644 --- a/xla/service/gpu/reduction_degenerate_dim_remover.cc +++ b/xla/service/gpu/reduction_degenerate_dim_remover.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,27 +15,31 @@ limitations under the License. #include "xla/service/gpu/reduction_degenerate_dim_remover.h" -#include +#include +#include +#include +#include #include "absl/algorithm/container.h" -#include "absl/strings/str_join.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { public: - Status HandleReduce(HloInstruction *hlo) override { + absl::Status HandleReduce(HloInstruction *hlo) override { auto instr = Cast(hlo); absl::InlinedVector input_reshapes; absl::InlinedVector canonical_reduce_shapes; @@ -50,7 +54,7 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { : instr->shape(); if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) { - return OkStatus(); + return absl::OkStatus(); } Shape canonical_input_shape = ShapeUtil::DropDegenerateDimensions(input_shape); @@ -113,7 +117,7 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { } }; -StatusOr ReductionDegenerateDimRemover::Run( +absl::StatusOr ReductionDegenerateDimRemover::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { TF_ASSIGN_OR_RETURN(bool changed, diff --git a/xla/service/gpu/reduction_degenerate_dim_remover.h b/xla/service/gpu/reduction_degenerate_dim_remover.h index e146a3a01cbd0..03d6819081d5d 100644 --- a/xla/service/gpu/reduction_degenerate_dim_remover.h +++ b/xla/service/gpu/reduction_degenerate_dim_remover.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ #define XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ -#include - -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -45,7 +45,7 @@ class ReductionDegenerateDimRemover : public HloModulePass { return "reduction-degenerate-dim-remover"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/reduction_dimension_grouper.cc b/xla/service/gpu/reduction_dimension_grouper.cc index 86896ad444f9c..8ab4fcf648a25 100644 --- a/xla/service/gpu/reduction_dimension_grouper.cc +++ b/xla/service/gpu/reduction_dimension_grouper.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,34 @@ limitations under the License. #include "xla/service/gpu/reduction_dimension_grouper.h" -#include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/layout_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { public: - Status HandleReduce(HloInstruction *hlo) override { + absl::Status HandleReduce(HloInstruction *hlo) override { auto reduce = Cast(hlo); VLOG(4) << "Input: " << reduce->ToString(); @@ -82,7 +91,7 @@ class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { } if (!changed) { // Since all inputs have same shape dimensions. - return OkStatus(); + return absl::OkStatus(); } Shape grouped_shape = @@ -101,7 +110,7 @@ class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { } }; -StatusOr ReductionDimensionGrouper::Run( +absl::StatusOr ReductionDimensionGrouper::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule( diff --git a/xla/service/gpu/reduction_dimension_grouper.h b/xla/service/gpu/reduction_dimension_grouper.h index b0d02bde2253e..8ee4efd0cfd26 100644 --- a/xla/service/gpu/reduction_dimension_grouper.h +++ b/xla/service/gpu/reduction_dimension_grouper.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ #define XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ -#include - -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -45,7 +45,7 @@ class ReductionDimensionGrouper : public HloModulePass { return "reduction-dimension-grouper"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/reduction_layout_normalizer.cc b/xla/service/gpu/reduction_layout_normalizer.cc index 5c56e279203ab..a91fdf7e387b7 100644 --- a/xla/service/gpu/reduction_layout_normalizer.cc +++ b/xla/service/gpu/reduction_layout_normalizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,26 +15,38 @@ limitations under the License. #include "xla/service/gpu/reduction_layout_normalizer.h" -#include +#include +#include +#include +#include #include "absl/algorithm/container.h" -#include "absl/strings/str_join.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/pattern_matcher.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/statusor.h" +#include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { - Status HandleReduce(HloInstruction *hlo) override { + absl::Status HandleReduce(HloInstruction *hlo) override { auto reduce = Cast(hlo); VLOG(5) << "Input: " << reduce->ToString(); @@ -123,7 +135,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { new_reduce_shape_layout); if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) { - return OkStatus(); + return absl::OkStatus(); } HloInstruction *canonical_reduce_input = @@ -178,7 +190,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { } }; -StatusOr ReductionLayoutNormalizer::Run( +absl::StatusOr ReductionLayoutNormalizer::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { TF_ASSIGN_OR_RETURN(bool changed, diff --git a/xla/service/gpu/reduction_layout_normalizer.h b/xla/service/gpu/reduction_layout_normalizer.h index b7f0defd619f2..7d2d207773e05 100644 --- a/xla/service/gpu/reduction_layout_normalizer.h +++ b/xla/service/gpu/reduction_layout_normalizer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ #define XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ -#include - -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -43,7 +43,7 @@ class ReductionLayoutNormalizer : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/reduction_splitter.cc b/xla/service/gpu/reduction_splitter.cc index f1380cdc20aae..cd37319a47de3 100644 --- a/xla/service/gpu/reduction_splitter.cc +++ b/xla/service/gpu/reduction_splitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,34 +16,49 @@ limitations under the License. #include "xla/service/gpu/reduction_splitter.h" #include +#include +#include #include #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout_util.h" #include "xla/service/gpu/reduction_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { class ReductionSplitterVisitor : public DfsHloRewriteVisitor { public: - Status HandleReduce(HloInstruction *reduce) override { + explicit ReductionSplitterVisitor(bool ignore_small_dims) + : ignore_small_dims_(ignore_small_dims) {} + absl::Status HandleReduce(HloInstruction *reduce) override { VLOG(4) << "Input: " << reduce->ToString(); // Reductions with contiguous dimensions are lowered to efficient code. No // need to split such ops. if (IsReductionFromOrToContiguousDimensions(*reduce)) { - return OkStatus(); + VLOG(4) << "Reduction with contiguous dimensions. Return."; + return absl::OkStatus(); } if (reduce->dimensions().size() < 2) { - return OkStatus(); + return absl::OkStatus(); } if (!reduce->shape().IsArray()) { // TODO(cheshire): Handle variadic reduction. - return OkStatus(); + return absl::OkStatus(); } HloInstruction *operand = reduce->mutable_operand(0); @@ -71,9 +86,8 @@ class ReductionSplitterVisitor : public DfsHloRewriteVisitor { max_shape_dim = input_shape.dimensions(max_reduce_dim); } } - // TODO(tjoerg): Run microbenchmarks to tune this threshold. - if (max_shape_dim < 128) { - return OkStatus(); + if (ignore_small_dims_ && max_shape_dim <= 8) { + return absl::OkStatus(); } // Split the reduction into a pre-reduction and a final reduction. @@ -108,13 +122,17 @@ class ReductionSplitterVisitor : public DfsHloRewriteVisitor { reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply()); return ReplaceWithNewInstruction(reduce, std::move(final_reduce)); } + + private: + bool ignore_small_dims_; }; -StatusOr ReductionSplitter::Run( +absl::StatusOr ReductionSplitter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, ReductionSplitterVisitor().RunOnModule( - module, execution_threads)); + TF_ASSIGN_OR_RETURN(bool changed, + ReductionSplitterVisitor(ignore_small_dims_) + .RunOnModule(module, execution_threads)); return changed; } diff --git a/xla/service/gpu/reduction_splitter.h b/xla/service/gpu/reduction_splitter.h index 776055acd4bac..7e7652500e6d3 100644 --- a/xla/service/gpu/reduction_splitter.h +++ b/xla/service/gpu/reduction_splitter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,35 +15,42 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ #define XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Splits a reduce op into two consecutive reduce ops if -// * the reduce dimensions are not contiguous and -// * at least one reduce dimension is large (i.e. corresponds to a large input -// shape dimension). +// Splits a reduce op into two consecutive reduce ops if the reduce dimensions +// are not contiguous. Ignores small reduce dimensions if `ignore_small_dims` is +// set. // // Reductions with non-contiguous dimensions are emitted as simple element-wise // loops. This is inefficient when reducing large input shape dimensions. // Splitting such reductions allows using more efficient reduction emitters. // // This pass splits reduce ops into two consecutive reduce ops. Run it to a -// fixpoint to split reduce ops along multiple large dimensions. +// fixpoint to split reduce ops along multiple dimensions. // // Precondition: ReductionDimensionGrouper has been run and adjacent reduce // dimentsions have been grouped. Reduction layouts have been normalized. class ReductionSplitter : public HloModulePass { public: + explicit ReductionSplitter(bool ignore_small_dims) + : ignore_small_dims_(ignore_small_dims) {} absl::string_view name() const override { return "reduction-splitter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + bool ignore_small_dims_; }; } // namespace gpu diff --git a/xla/service/gpu/reduction_splitter_test.cc b/xla/service/gpu/reduction_splitter_test.cc index f98ae743b1af1..13a5210fee2ee 100644 --- a/xla/service/gpu/reduction_splitter_test.cc +++ b/xla/service/gpu/reduction_splitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,15 @@ limitations under the License. #include "xla/service/gpu/reduction_splitter.h" +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -51,7 +54,8 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { } )") .value(); - ASSERT_TRUE(ReductionSplitter().Run(module.get()).value()); + ASSERT_TRUE( + ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -82,7 +86,8 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) { } )") .value(); - ASSERT_TRUE(ReductionSplitter().Run(module.get()).value()); + ASSERT_TRUE( + ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -108,13 +113,16 @@ TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) { } ENTRY entry_computation { - param_0 = f32[8,1024,8]{2,1,0} parameter(0) + param_0 = f32[16,8,1024,8]{3,2,1,0} parameter(0) constant_11111 = f32[] constant(0) - ROOT reduce.982 = f32[1024]{0} reduce(param_0, constant_11111), dimensions={2,0}, to_apply=add_computation + ROOT reduce.982 = f32[16,1024]{1,0} reduce(param_0, constant_11111), dimensions={3,1}, to_apply=add_computation } )") .value(); - EXPECT_FALSE(ReductionSplitter().Run(module.get()).value()); + EXPECT_FALSE( + ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); + EXPECT_TRUE( + ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); } TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { @@ -135,7 +143,8 @@ TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { } )") .value(); - EXPECT_FALSE(ReductionSplitter().Run(module.get()).value()); + EXPECT_FALSE( + ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); } } // namespace diff --git a/xla/service/gpu/reduction_utils.cc b/xla/service/gpu/reduction_utils.cc index a8576a5e4f1c3..5403d8c5b8587 100644 --- a/xla/service/gpu/reduction_utils.cc +++ b/xla/service/gpu/reduction_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,17 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/logging.h" @@ -96,17 +100,26 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { return {1, 128, 1}; } -static bool IsUnnestedReductionFasterThanElemental( +int64_t ReductionDimensionRaceFreeBound( + const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions) { + Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); + if (reduction_dimensions.is_row_reduction) { + return MinThreadsXRowReduction(hlo_module_config) * reduction_tiling[2]; + } + return WarpSize() * reduction_tiling[1]; +} + +bool IsUnnestedReductionFasterThanElemental( const ReductionDimensions& reduction_dimensions) { - const int kWarpSize = 32; if (reduction_dimensions.is_row_reduction) { // For row reduction, the tile block is 1 x tile_size_x, and we are reducing // along tile_size_x which needs to be large enough to make the tiling // implementation efficient. // For very small reductions with a power-of-two size, we can fit multiple // reductions inside a single warp, which is more efficient than a loop. - return (reduction_dimensions.dimensions[2] >= kWarpSize) || - ((kWarpSize % reduction_dimensions.dimensions[2]) == 0); + return (reduction_dimensions.dimensions[2] >= WarpSize()) || + ((WarpSize() % reduction_dimensions.dimensions[2]) == 0); } // For column reduction, the tile block is tile_size_y x tile_size_x, and we @@ -117,10 +130,10 @@ static bool IsUnnestedReductionFasterThanElemental( // Rule generated by sweeping the search space of small column reductions. bool prefer_elemental_emitter = - (major_size < kWarpSize) || - (major_size < 2 * kWarpSize && minor_size < kWarpSize) || - (major_size < 4 * kWarpSize && minor_size < 8) || - (major_size < 8 * kWarpSize && minor_size < 3); + (major_size < WarpSize()) || + (major_size < 2 * WarpSize() && minor_size < WarpSize()) || + (major_size < 4 * WarpSize() && minor_size < 8) || + (major_size < 8 * WarpSize() && minor_size < 3); return !prefer_elemental_emitter; } @@ -153,18 +166,18 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, const ReductionDimensions& reduction_dimensions) { - const int kWarpSize = 32; - Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); if (reduction_dimensions.is_row_reduction) { return reduction_dimensions.dimensions[2] <= - MinThreadsXRowReduction(hlo_module_config) * - reduction_tiling[2] && + ReductionDimensionRaceFreeBound(hlo_module_config, + reduction_dimensions) && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound(); } // Column reduction. - return reduction_dimensions.dimensions[1] <= kWarpSize * reduction_tiling[1]; + return reduction_dimensions.dimensions[1] <= + ReductionDimensionRaceFreeBound(hlo_module_config, + reduction_dimensions); } ReductionDimensions GetReductionKindAndContiguousComponents( @@ -208,5 +221,23 @@ ReductionDimensions GetReductionKindAndContiguousComponents( return {/*is_row_reduction=*/false, shape_partition}; } +bool IsRealReductionHero(const HloInstruction& root, + const HloInstruction& hero) { + if (!IsReductionFromOrToContiguousDimensions(hero)) { + return false; + } + return &root == &hero || + ReductionIsRaceFree(hero.GetModule()->config(), + GetReductionKindAndContiguousComponents(hero)); +} + +bool AreReductionsMultiOutputFusionCompatible( + const HloInstruction* reduce_hero, const HloInstruction* first_reduce) { + // The reduction kind must be the same for all reduce heroes inside of a + // multioutput fusion. + return GetReductionKindAndContiguousComponents(*reduce_hero) == + GetReductionKindAndContiguousComponents(*first_reduce); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/reduction_utils.h b/xla/service/gpu/reduction_utils.h index 610200c888464..8245b34c3d73a 100644 --- a/xla/service/gpu/reduction_utils.h +++ b/xla/service/gpu/reduction_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_REDUCTION_UTILS_H_ #define XLA_SERVICE_GPU_REDUCTION_UTILS_H_ -#include "absl/types/span.h" +#include + #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_module_config.h" #include "xla/util.h" @@ -31,11 +32,16 @@ int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); // When doing batched row reduction, how big the batch dimension could be. inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } -// Returns true if either the dimensions being reduced or the dimensions being -// kept are contiguous in the input of the reduce instruction. -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); - struct ReductionDimensions { + // The reduction dimension indices used below. + constexpr static int kRowMajorReducedDimension = 0; + constexpr static int kRowKeptDimension = 1; + constexpr static int kRowMinorReducedDimension = 2; + + constexpr static int kColMajorKeptDimension = 0; + constexpr static int kColReducedDimension = 1; + constexpr static int kColMinorKeptDimension = 2; + // Indicates whether the reduction is a row reduction or a column reduction. bool is_row_reduction; @@ -45,8 +51,22 @@ struct ReductionDimensions { // For row reduction, we do: [D, H, W] -> [D, H]. // For column reduction, we do: [D, H, W] -> [D, W]. Vector3 dimensions; + + bool operator==(const ReductionDimensions& other) const { + return is_row_reduction == other.is_row_reduction && + dimensions == other.dimensions; + } }; +// Returns true if using the reduction emitter is estimated to be faster than +// using the elemental emitter. +bool IsUnnestedReductionFasterThanElemental( + const ReductionDimensions& reduction_dimensions); + +// Returns true if either the dimensions being reduced or the dimensions being +// kept are contiguous in the input of the reduce instruction. +bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); + // Given the input shape and dimensions to reduce for a reduction, returns // ReductionDimensions. // @@ -59,11 +79,24 @@ ReductionDimensions GetReductionKindAndContiguousComponents( // Get tiling per thread for the given reduction in dimensions [D, H, W]. Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); +// How big the reduction dimension can be to be race free. +int64_t ReductionDimensionRaceFreeBound( + const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions); + // Returns whether the given reduction can be safely generated without atomics : // that is, at most one block will write to every output element. bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, const ReductionDimensions& reduction_dimensions); +// Whether the instruction is a reduction hero for the given root. +bool IsRealReductionHero(const HloInstruction& root, + const HloInstruction& hero); + +// Whether `reduction_hero` is compatible with `first_reduce`. +bool AreReductionsMultiOutputFusionCompatible( + const HloInstruction* reduce_hero, const HloInstruction* first_reduce); + } // namespace gpu } // namespace xla #endif // XLA_SERVICE_GPU_REDUCTION_UTILS_H_ diff --git a/xla/service/gpu/reduction_utils_test.cc b/xla/service/gpu/reduction_utils_test.cc new file mode 100644 index 0000000000000..9ddd2fd6db150 --- /dev/null +++ b/xla/service/gpu/reduction_utils_test.cc @@ -0,0 +1,191 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/reduction_utils.h" + +#include +#include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_parser.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using ReductionUtilsTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + scalar_add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + })"; + +TEST_F(ReductionUtilsTest, ReductionsAreMultioutputFusionCompatible) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_sibling1 { + p_0 = f32[32,64]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(p_0, constant), dimensions={1}, to_apply=scalar_add + } + + fused_sibling2 { + p_0 = f32[32,64]{1,0} parameter(0) + neg = f32[32,64]{1,0} negate(p_0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(neg, constant), dimensions={1}, to_apply=scalar_add + } + + ENTRY entry { + p_0 = f32[32,64]{1,0} parameter(0) + fusion1 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling1 + fusion2 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling2 + ROOT root = (f32[32]{0}, f32[32]{0}) tuple(fusion1, fusion2) + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* fusion1 = root->operand(0); + const HloInstruction* fusion2 = root->operand(1); + EXPECT_TRUE(AreReductionsMultiOutputFusionCompatible( + fusion1->fused_expression_root(), fusion2->fused_expression_root())); +} + +TEST_F(ReductionUtilsTest, + ReductionsWithSameCanonicalizedDimsAreMultioutputFusionCompatible) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_sibling1 { + p_0 = f32[32,64]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(p_0, constant), dimensions={1}, to_apply=scalar_add + } + + fused_sibling2 { + p_0 = f32[32,64]{1,0} parameter(0) + bitcast = f32[32,8,8]{2,1,0} bitcast(p_0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(bitcast, constant), dimensions={1,2}, to_apply=scalar_add + } + + ENTRY entry { + p_0 = f32[32,64]{1,0} parameter(0) + fusion1 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling1 + fusion2 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling2 + ROOT root = (f32[32]{0}, f32[32]{0}) tuple(fusion1, fusion2) + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* fusion1 = root->operand(0); + const HloInstruction* fusion2 = root->operand(1); + EXPECT_TRUE(AreReductionsMultiOutputFusionCompatible( + fusion1->fused_expression_root(), fusion2->fused_expression_root())); +} + +TEST_F(ReductionUtilsTest, + ReductionsAreNotMultioutputFusionCompatible_DifferentOperandShapes) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_sibling1 { + p_0 = f32[32,64]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(p_0, constant), dimensions={1}, to_apply=scalar_add + } + + fused_sibling2 { + p_0 = f32[64,32]{1,0} parameter(0) + neg = f32[64,32]{1,0} negate(p_0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(neg, constant), dimensions={0}, to_apply=scalar_add + } + + ENTRY entry { + p_0 = f32[32,64]{1,0} parameter(0) + p_1 = f32[64,32]{1,0} parameter(1) + fusion1 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling1 + fusion2 = f32[32]{0} fusion(p_1), kind=kInput, calls=fused_sibling2 + ROOT root = (f32[32]{0}, f32[32]{0}) tuple(fusion1, fusion2) + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* fusion1 = root->operand(0); + const HloInstruction* fusion2 = root->operand(1); + EXPECT_FALSE(AreReductionsMultiOutputFusionCompatible( + fusion1->fused_expression_root(), fusion2->fused_expression_root())); +} + +TEST_F(ReductionUtilsTest, + ReductionsAreNotMultioutputFusionCompatible_DifferentOutputShapes) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_sibling1 { + p_0 = f32[32,64]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(p_0, constant), dimensions={1}, to_apply=scalar_add + } + + fused_sibling2 { + p_0 = f32[64,32]{1,0} parameter(0) + neg = f32[64,32]{1,0} negate(p_0) + constant = f32[] constant(0) + ROOT reduce = f32[64]{0} reduce(neg, constant), dimensions={1}, to_apply=scalar_add + } + + ENTRY entry { + p_0 = f32[32,64]{1,0} parameter(0) + p_1 = f32[64,32]{1,0} parameter(1) + fusion1 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling1 + fusion2 = f32[64]{0} fusion(p_1), kind=kInput, calls=fused_sibling2 + ROOT root = (f32[32]{0}, f32[64]{0}) tuple(fusion1, fusion2) + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* fusion1 = root->operand(0); + const HloInstruction* fusion2 = root->operand(1); + EXPECT_FALSE(AreReductionsMultiOutputFusionCompatible( + fusion1->fused_expression_root(), fusion2->fused_expression_root())); +} + +TEST_F(ReductionUtilsTest, + ReductionsAreNotMultioutputFusionCompatible_DifferentReduceDimensions) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_sibling1 { + p_0 = f32[32,32]{1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(p_0, constant), dimensions={0}, to_apply=scalar_add + } + + fused_sibling2 { + p_0 = f32[32,32]{1,0} parameter(0) + neg = f32[32,32]{1,0} negate(p_0) + constant = f32[] constant(0) + ROOT reduce = f32[32]{0} reduce(neg, constant), dimensions={1}, to_apply=scalar_add + } + + ENTRY entry { + p_0 = f32[32,32]{1,0} parameter(0) + fusion1 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling1 + fusion2 = f32[32]{0} fusion(p_0), kind=kInput, calls=fused_sibling2 + ROOT root = (f32[32]{0}, f32[32]{0}) tuple(fusion1, fusion2) + })")) + .value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* fusion1 = root->operand(0); + const HloInstruction* fusion2 = root->operand(1); + EXPECT_FALSE(AreReductionsMultiOutputFusionCompatible( + fusion1->fused_expression_root(), fusion2->fused_expression_root())); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/rename_fusions.cc b/xla/service/gpu/rename_fusions.cc new file mode 100644 index 0000000000000..1a6731cdd4919 --- /dev/null +++ b/xla/service/gpu/rename_fusions.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/rename_fusions.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { +namespace { + +constexpr absl::string_view FusionKindToString( + HloInstruction::FusionKind kind) { + switch (kind) { + case HloInstruction::FusionKind::kCustom: + return "custom"; + case HloInstruction::FusionKind::kLoop: + return "loop"; + case HloInstruction::FusionKind::kInput: + return "input"; + case HloInstruction::FusionKind::kOutput: + return "output"; + } +} + +std::string MakeFusionHeroNames(const HloInstruction* instruction) { + std::unique_ptr fusion_adaptor = + HloFusionAdaptor::ForInstruction(instruction); + absl::btree_set heroes; + + for (auto root : fusion_adaptor->GetRoots()) { + heroes.insert(HloOpcodeString( + FindNonTrivialHero(root.instruction(), *fusion_adaptor).opcode())); + } + return absl::StrReplaceAll(absl::StrJoin(heroes, "_"), {{"-", "_"}}); +} + +void RenameFusion(HloModule* module, HloInstruction* instruction) { + std::string hero_names = MakeFusionHeroNames(instruction); + module->SetAndUniquifyInstrName( + instruction, absl::StrCat(FusionKindToString(instruction->fusion_kind()), + "_", hero_names, "_fusion")); + module->SetAndUniquifyComputationName( + instruction->fused_instructions_computation(), + absl::StrCat("fused_", hero_names)); +} + +} // namespace + +absl::StatusOr RenameFusions::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kFusion || + instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) { + continue; + } + RenameFusion(module, instruction); + } + } + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/rename_fusions.h b/xla/service/gpu/rename_fusions.h new file mode 100644 index 0000000000000..c3065a4dbd1df --- /dev/null +++ b/xla/service/gpu/rename_fusions.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// An HLO pass that gives fusions and fused computations descriptive names. +// +// The name is based on hero instructions and the fusion kind, i.e. +// Fusions get name "__fusion", +// and fused computations get name "fused_". +// In the case of multiple roots, the hero instructions in the name are +// underscore-separated and alphabetically sorted. + +class RenameFusions : public HloModulePass { + absl::string_view name() const override { return "rename_fusions"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RENAME_FUSIONS_H_ diff --git a/xla/service/gpu/rename_fusions_test.cc b/xla/service/gpu/rename_fusions_test.cc new file mode 100644 index 0000000000000..60c97cf2ff943 --- /dev/null +++ b/xla/service/gpu/rename_fusions_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/rename_fusions.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +class RenameFusionsTest : public HloTestBase { + protected: + RenameFusions rename_fusions_; +}; + +TEST_F(RenameFusionsTest, FusionInstructionNames) { + absl::string_view kHlo = R"( + HloModule test_module + + square { + p = f32[16384] parameter(0) + ROOT m = f32[16384] multiply(p, p) + } + + exp { + p = f32[16384] parameter(0) + ROOT e = f32[16384] exponential(p) + } + + log { + p = f32[16384] parameter(0) + ROOT l = f32[16384] log(p) + } + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + ENTRY main { + p0 = bf16[1024,8192] parameter(0) + p1 = f32[8192] parameter(1) + p2 = f32[16384] parameter(2) + convert = f32[1024,8192] convert(p0) + broadcast = f32[1024,8192] broadcast(p1), dimensions={1} + c0 = f32[] constant(0) + multiply = f32[1024,8192] multiply(broadcast, convert) + reduce = f32[1024] reduce(multiply, c0), dimensions={1}, to_apply=add + convert.1 = bf16[1024] convert(reduce) + s = f32[16384] fusion(p2), kind=kLoop, calls=square + e = f32[16384] fusion(s), kind=kLoop, calls=exp + l = f32[16384] fusion(s), kind=kInput, calls=log + ROOT result = (bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(convert.1, l, e) + })"; + + RunAndFilecheckHloRewrite(kHlo, std::move(rename_fusions_), R"( +CHECK: ENTRY %main +CHECK: %loop_multiply_fusion{{.*}} calls=%fused_multiply +CHECK: %input_log_fusion{{.*}} calls=%fused_log +CHECK: %loop_exponential_fusion{{.*}} calls=%fused_exponential +CHECK: ROOT %result + )"); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/replica_id_thunk.cc b/xla/service/gpu/replica_id_thunk.cc deleted file mode 100644 index 7452f2b5cc8ac..0000000000000 --- a/xla/service/gpu/replica_id_thunk.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/replica_id_thunk.h" - -namespace xla { -namespace gpu { - -Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) { - auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_); - - TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id, - params.nccl_params.GetGlobalDeviceId()); - TF_ASSIGN_OR_RETURN( - const DeviceAssignment::LogicalID logical_id, - params.nccl_params.device_assn->LogicalIdForDevice(global_device_id)); - int id = kind() == Kind::kReplicaId ? logical_id.replica_id - : logical_id.computation_id; - params.stream->ThenMemset32(&dest_addr, id, /*size=*/4); - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index 00f806a46354d..078fd98ac569c 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -1,17 +1,10 @@ -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -load("//xla:xla.bzl", "xla_cc_test") -load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", - "rocm_library", -) -load( - "@tsl//tsl/platform:build_config_root.bzl", - "tf_gpu_tests_tags", -) +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@tsl//tsl:tsl.bzl", "if_google", "if_nccl", "nvtx_headers") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +load("//xla/service/gpu:build_defs.bzl", "get_cub_sort_kernel_types") +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,831 +14,1178 @@ package( package_group( name = "friends", - includes = [ - "//xla:friends", - ], + includes = ["//xla:friends"], ) -gpu_kernel_library( - name = "gpu_kernel_helper", - hdrs = if_gpu_is_configured(["gpu_kernel_helper.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - "//xla/stream_executor/platform", - "@tsl//tsl/lib/math:math_util", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) +#===-------------------------------------------------------------------------------------------===// +# Runtime tracing libraries +#===-------------------------------------------------------------------------------------------===// cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + name = "annotation", + srcs = ["annotation.cc"], + hdrs = ["annotation.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu/runtime3:cholesky_thunk", - ], + "//xla:printer", + "//xla:status", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/profiler/lib:nvtx_utils", + "@tsl//tsl/profiler/lib:scoped_annotation", + ] + if_cuda_is_configured(nvtx_headers()), ) -cuda_library( - name = "sleep_kernel_cuda", - srcs = if_cuda_is_configured(["sleep_kernel.cu.cc"]), - hdrs = if_cuda_is_configured(["sleep_kernel.h"]), - compatible_with = [], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_cuda_is_configured([ - ":gpu_kernel_helper", - ]), -) +#===-------------------------------------------------------------------------------------------===// +# Command Buffer Integration +#===-------------------------------------------------------------------------------------------===// -rocm_library( - name = "sleep_kernel_rocm", - srcs = if_rocm_is_configured(["sleep_kernel.cu.cc"]), - hdrs = if_rocm_is_configured(["sleep_kernel.h"]), - compatible_with = [], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - deps = if_rocm_is_configured([ - ":gpu_kernel_helper", - ]), +cc_library( + name = "command_buffer_allocations", + srcs = ["command_buffer_allocations.cc"], + hdrs = ["command_buffer_allocations.h"], + deps = [ + "//xla:status", + "//xla:statusor", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], ) cc_library( - name = "collectives", - srcs = ["collectives.cc"], - hdrs = ["collectives.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", + name = "command_buffer_cmd", + srcs = ["command_buffer_cmd.cc"], + hdrs = ["command_buffer_cmd.h"], + local_defines = if_cuda_is_configured([ + "GOOGLE_CUDA=1", ]), deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + ":annotation", + ":custom_call_thunk", + ":nccl_all_gather_thunk", + ":nccl_all_reduce_thunk", + ":nccl_api", + ":nccl_collective_broadcast_thunk", + ":nccl_collective_thunk", + "//xla:executable_run_options", + "//xla:status", + "//xla:types", + "//xla:util", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service:computation_placer_hdr", + "//xla/service:computation_placer", + "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_status_public_headers", "//xla/service:executable", "//xla/service:global_device_id", - "//xla/service/gpu:gpu_executable_run_options", - "//xla/service/gpu:nccl_collective_thunks", - "//xla/service/gpu:thunk", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:nccl_clique_key", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", - "@com_google_absl//absl/base", + "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ] + if_gpu_is_configured([ - ":gpu_kernel_helper", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/gpu:gpu_stream_header", - ]) + if_cuda_is_configured([ - ":sleep_kernel_cuda", - ]) + if_rocm_is_configured([ - ":sleep_kernel_rocm", - ]), + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:scoped_annotation", + ], ) cc_library( - name = "conv", - srcs = ["conv.cc"], - hdrs = ["conv.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), + name = "command_buffer_cmd_emitter", + srcs = ["command_buffer_cmd_emitter.cc"], + hdrs = ["command_buffer_cmd_emitter.h"], deps = [ - ":support", + ":command_buffer_cmd", + ":conditional_thunk", + ":copy_thunk", + ":cudnn_thunk", + ":custom_call_thunk", + ":gemm_thunk", + ":kernel_thunk", + ":memset_thunk", + ":nccl_all_gather_thunk", + ":nccl_all_reduce_thunk", + ":replica_id_thunk", + ":sequential_thunk", + ":wait_for_streams_thunk", + ":while_thunk", "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + "//xla:statusor", + "//xla:util", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "command_buffer_cmd_test", + srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deps = [ + ":command_buffer_cmd", + "//xla:status", + "//xla:types", + "//xla/service:buffer_assignment", "//xla/service:executable", - "//xla/service/gpu:autotuner_util", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_conv_runner", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "@com_google_absl//absl/container:node_hash_map", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + +#===-------------------------------------------------------------------------------------------===// +# NCCL integration +#===-------------------------------------------------------------------------------------------===// + +# A lot of build complexity below is because NCCL dependency might not always be available and we +# have `if_nccl` and `if_gpu_configured` that do not compose. NCCL header included directly in +# :nccl_api target and all other targets should use this header to launch collective operations. +# This allows to minimize the spreading of #ifdef all over the XLA code base. +alias( + name = "nccl_api", + actual = if_nccl(":_nccl_api_impl", ":_nccl_api_stub"), +) + +cc_library( + name = "_nccl_api_impl", + srcs = if_gpu_is_configured( + ["nccl_api.cc"], + ["nccl_api_stub.cc"], + ), + hdrs = ["nccl_api.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:nccl_clique_key", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ - "//xla/service/gpu:conv_algorithm_picker", + "@local_config_nccl//:nccl", + "//xla/stream_executor/cuda:cuda_driver", + "//xla/stream_executor/cuda:cuda_executor", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", + "//xla/stream_executor/rocm:rocm_driver", + "//xla/stream_executor/rocm:rocm_executor", + ]) + if_gpu_is_configured([ + "//xla/stream_executor/gpu:gpu_stream", ]), ) cc_library( - name = "conv_reorder", - srcs = ["conv_reorder.cc"], - hdrs = ["conv_reorder.h"], + name = "_nccl_api_stub", + srcs = ["nccl_api_stub.cc"], + hdrs = ["nccl_api.h"], + compatible_with = get_compatible_with_portable(), deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:nccl_clique_key", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@tsl//tsl/concurrency:ref_count", + "@tsl//tsl/platform:logging", ], ) cc_library( - name = "norm", - srcs = ["norm.cc"], - hdrs = ["norm.h"], + name = "nccl_clique", + srcs = ["nccl_clique.cc"], + hdrs = ["nccl_clique.h"], deps = [ - ":support", - "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_norm_runner", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", + ":nccl_api", + "//xla:debug_options_flags", + "//xla:executable_run_options", + "//xla:status_macros", + "//xla/service:global_device_id", + "//xla/service:lockable", + "//xla/service:rendezvous", + "//xla/service/gpu:nccl_clique_key", + "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:hash", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) +#===-------------------------------------------------------------------------------------------===// +# XLA Thunks Runtime +#===-------------------------------------------------------------------------------------------===// + cc_library( - name = "fused_attention", - srcs = ["fused_attention.cc"], - hdrs = ["fused_attention.h"], + name = "address_computation_thunk", + srcs = ["address_computation_thunk.cc"], + hdrs = ["address_computation_thunk.h"], deps = [ - ":support", + ":sequential_thunk", + "//xla:shape_util", "//xla:status", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/translate/mhlo_to_hlo:attribute_exporter", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/functional:function_ref", + "//xla:status_macros", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) -cc_library( - name = "cub_sort", - srcs = ["cub_sort.cc"], - hdrs = ["cub_sort.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), +xla_test( + name = "address_computation_thunk_test", + srcs = if_gpu_is_configured(["address_computation_thunk_test.cc"]), + backend_tags = { + "gpu_a100": if_google(["config-cuda-only"]), + "gpu_v100": if_google(["config-cuda-only"]), + }, + backends = [ + "gpu_a100", + "gpu_v100", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", + ":address_computation_thunk", + ":custom_call_thunk", + ":gemm_thunk", + "//xla:shape_util", + "//xla:types", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/service:buffer_assignment", "//xla/service:executable", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cc_library( + name = "cholesky_thunk", + srcs = if_gpu_is_configured(["cholesky_thunk.cc"]), + hdrs = if_gpu_is_configured(["cholesky_thunk.h"]), + deps = if_gpu_is_configured([ + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:cusolver_context", + "//xla/service/gpu:make_batch_pointers", + "//xla/service/gpu/runtime:thunk", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/hlo/ir:hlo", + "@tsl//tsl/platform:logging", + "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/status", - ] + if_gpu_is_configured([ - "//xla/service/gpu:cub_sort_thunk", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", ]), ) cc_library( - name = "custom_call", - srcs = ["custom_call.cc"], - hdrs = ["custom_call.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", + name = "command_buffer_thunk", + srcs = ["command_buffer_thunk.cc"], + hdrs = ["command_buffer_thunk.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":annotation", + ":command_buffer_allocations", + ":command_buffer_cmd", + "//xla:status", + "//xla:statusor", + "//xla/service:buffer_assignment", # build_cleaner: keep + "//xla/service/gpu:buffer_allocations", # build_cleaner: keep + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:profiler_lock", + "@tsl//tsl/profiler/lib:scoped_annotation", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/lib:traceme_encode", ], - features = ["-use_header_modules"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), +) + +xla_test( + name = "command_buffer_thunk_test", + srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), + backend_tags = { + "gpu_a100": if_google(["config-cuda-only"]), + "gpu_v100": if_google(["config-cuda-only"]), + }, + backends = [ + "gpu_a100", + "gpu_v100", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ - ":support", - ":triangular_solve", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_status_public_headers", - "//xla/service:custom_call_target_registry", + ":command_buffer_allocations", + ":command_buffer_cmd", + ":command_buffer_thunk", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", "//xla/service:executable", - "//xla/service/gpu:cublas_cudnn", - "//xla/stream_executor/gpu:gpu_stream_header", - ], + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_types_header", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) cc_library( - name = "custom_call_registry", - srcs = ["custom_call_registry.cc"], - hdrs = ["custom_call_registry.h"], - deps = ["//xla/runtime:custom_call_registry"], + name = "conditional_thunk", + srcs = ["conditional_thunk.cc"], + hdrs = ["conditional_thunk.h"], + deps = [ + ":sequential_thunk", + "//xla:status", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:variant_visitor", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], ) cc_library( - name = "executable", - srcs = ["executable.cc"], - hdrs = ["executable.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + name = "convolution_thunk", + srcs = ["convolution_thunk.cc"], + hdrs = ["convolution_thunk.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ - ":cholesky", - ":collectives", - ":concurrent_region", - ":conv", - ":conv_reorder", - ":cub_sort", - ":custom_call", - ":custom_call_registry", - ":fft", - ":fused_attention", - ":gemm", - ":gpublas_lt_matmul", - ":graph_launch", - ":io_feed", - ":kernel_launch", - ":memcpy", - ":memset", - ":norm", - ":send_recv", - ":stream_synchronization", - ":support", - ":topk", - ":resize_bicubic", - ":tracing", - "//xla:statusor", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:compilation_pipeline_gpu", - "//xla/runtime:executable", - "//xla/runtime:jit_executable", - "//xla/runtime:module_registry", - "//xla/service:executable", - "//xla/service:stream_pool", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", - "//xla/service/gpu:thunk", + "//xla:util", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_conv_runner", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", - "@tsl//tsl/protobuf:dnn_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", ], ) cc_library( - name = "fft", - srcs = ["fft.cc"], - hdrs = ["fft.h"], + name = "copy_thunk", + srcs = ["copy_thunk.cc"], + hdrs = ["copy_thunk.h"], deps = [ - ":support", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service/gpu/runtime3:fft_thunk", - "//xla/stream_executor:fft", - "//xla/translate/mhlo_to_hlo:attribute_exporter", + "//xla:status", + "//xla/service:buffer_assignment", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@llvm-project//mlir:IR", ], ) cc_library( - name = "topk_kernel", - srcs = if_gpu_is_configured(["topk_kernel.cc"]), - hdrs = if_gpu_is_configured(["topk_kernel.h"]), - compatible_with = [], + name = "cub_sort_thunk", + srcs = if_gpu_is_configured(["cub_sort_thunk.cc"]), + hdrs = if_gpu_is_configured(["cub_sort_thunk.h"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - deps = [ - ":gpu_kernel_helper", - ":support", + deps = if_gpu_is_configured([ + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:device_memory", "//xla:shape_util", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/runtime:memref_view", - "//xla/stream_executor", # build_cleaner: keep - "//xla/stream_executor:platform", + "@tsl//tsl/platform:errors", + ] + ["//xla/service/gpu:cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]), +) + +cc_library( + name = "custom_call_thunk", + srcs = ["custom_call_thunk.cc"], + hdrs = ["custom_call_thunk.h"], + local_defines = if_cuda_is_configured([ + "GOOGLE_CUDA=1", + ]), + deps = [ + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_status_internal", + "//xla/service:executable", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:device_memory", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@eigen_archive//:eigen3", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - ":topk_kernel_cuda", - ]) + if_rocm_is_configured([ - ":topk_kernel_rocm", - ]), + ], ) -cuda_library( - name = "topk_kernel_cuda", - srcs = if_cuda_is_configured( - [ - "topk_kernel_bfloat16.cu.cc", - "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ], - ), - hdrs = if_cuda_is_configured(["topk_kernel_common.h"]), - compatible_with = [], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), +cc_library( + name = "fft_thunk", + srcs = ["fft_thunk.cc"], + hdrs = ["fft_thunk.h"], deps = [ - ":gpu_kernel_helper", - "@eigen_archive//:eigen3", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", ], ) -rocm_library( - name = "topk_kernel_rocm", - srcs = if_rocm_is_configured( - [ - "topk_kernel_bfloat16.cu.cc", - "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ], - ), - hdrs = if_rocm_is_configured(["topk_kernel_common.h"]), - compatible_with = [], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), +cc_library( + name = "flash_attn_thunk", + srcs = ["flash_attn_thunk.cc"], + hdrs = ["flash_attn_thunk.h"], deps = [ - ":gpu_kernel_helper", - "@eigen_archive//:eigen3", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_flash_attn", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", ], ) -xla_cc_test( - name = "topk_kernel_test", - srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - tags = tf_gpu_tests_tags(), +cc_library( + name = "fused_mha_thunk", + srcs = ["fused_mha_thunk.cc"], + hdrs = ["fused_mha_thunk.h"], deps = [ - ":gpu_kernel_helper", - ":topk_kernel", + "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", # build_cleaner: keep - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/host:host_platform", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/random", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@eigen_archive//:eigen3", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_fused_mha_runner", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", ], ) -xla_cc_test( - name = "topk_test", - srcs = if_gpu_is_configured(["topk_test.cc"]), +cc_library( + name = "gemm_thunk", + srcs = ["gemm_thunk.cc"], + hdrs = ["gemm_thunk.h"], + deps = [ + "//xla:status", + "//xla/service:buffer_assignment", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "gpublas_lt_matmul_thunk", + srcs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.cc"]), + hdrs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.h"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - tags = tf_gpu_tests_tags(), + deps = if_gpu_is_configured([ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "//xla/service:buffer_assignment", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/runtime:thunk", + "//xla:status", + "//xla/stream_executor:device_memory", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_blas_lt", + "@tsl//tsl/platform:logging", + ]), +) + +cc_library( + name = "infeed_thunk", + srcs = ["infeed_thunk.cc"], + hdrs = ["infeed_thunk.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:io_feed_manager", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "kernel_thunk", + srcs = ["kernel_thunk.cc"], + hdrs = ["kernel_thunk.h"], + deps = [ + "//xla:status", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "memset_thunk", + srcs = ["memset_thunk.cc"], + hdrs = ["memset_thunk.h"], + deps = [ + "//xla:status", + "//xla/service:buffer_assignment", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "nccl_all_gather_thunk", + srcs = ["nccl_all_gather_thunk.cc"], + hdrs = ["nccl_all_gather_thunk.h"], deps = [ - ":topk", - "//xla:error_spec", + ":nccl_api", + ":nccl_collective_thunk", "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:types", "//xla/hlo/ir:hlo", - "//xla/service:gpu_plugin", - "//xla/service:hlo_pass", - "//xla/service:platform_util", - "//xla/service:topk_rewriter", - "//xla/service/gpu:topk_specializer", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test_main", ], ) cc_library( - name = "topk", - srcs = if_gpu_is_configured(["topk.cc"]), - hdrs = ["topk.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([":topk_kernel"]) + [ - ":support", - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:types", + name = "nccl_all_reduce_thunk", + srcs = ["nccl_all_reduce_thunk.cc"], + hdrs = ["nccl_all_reduce_thunk.h"], + deps = [ + ":nccl_api", + ":nccl_collective_thunk", + "//xla:status_macros", "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service:hlo_pass", - "//xla/service:tuple_util", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "resize_bicubic_kernel", - srcs = if_cuda_is_configured( - [ - "resize_bicubic_kernel.cc", - ], - ), - hdrs = if_cuda_is_configured(["resize_bicubic_kernel.h"]), - compatible_with = [], + name = "nccl_all_to_all_thunk", + srcs = ["nccl_all_to_all_thunk.cc"], + hdrs = ["nccl_all_to_all_thunk.h"], deps = [ - ":resize_bicubic_kernel_cuda", - # "//xla:shape_util", - "//xla:xla_proto_cc", - "//xla:xla_data_proto_cc", - "//xla/runtime:memref_view", - "//xla/stream_executor:platform", - "//xla/stream_executor:stream_executor_headers", # build_cleaner: keep - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/numeric:bits", + ":nccl_api", + ":nccl_collective_thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) -cuda_library( - name = "resize_bicubic_kernel_cuda", - srcs = if_cuda_is_configured( - [ - "resize_bicubic_kernel.cu.cc", - ], - ), - hdrs = if_cuda_is_configured(["resize_bicubic_kernel_common.h"]), - compatible_with = [], +cc_library( + name = "nccl_collective_broadcast_thunk", + srcs = ["nccl_collective_broadcast_thunk.cc"], + hdrs = ["nccl_collective_broadcast_thunk.h"], deps = [ - "@eigen_archive//:eigen3", - "@local_config_cuda//cuda:cuda_headers", + ":nccl_api", + ":nccl_collective_thunk", + "//xla:status", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) - cc_library( - name = "resize_bicubic", - srcs = if_cuda_is_configured( - ["resize_bicubic.cc"], - ), - hdrs = ["resize_bicubic.h"], - deps = if_cuda_is_configured([":resize_bicubic_kernel"]) + [ - ":support", - "//xla:executable_run_options", - # "//xla:shape_util", - "//xla:status", - "//xla:statusor", - # "//xla:types", + name = "nccl_collective_permute_thunk", + srcs = ["nccl_collective_permute_thunk.cc"], + hdrs = ["nccl_collective_permute_thunk.h"], + deps = [ + ":nccl_api", + ":nccl_collective_thunk", + ":nccl_p2p_thunk_common", + "//xla:status_macros", "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - # "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - # "//xla/runtime/ffi:ffi_api", - # "//xla/runtime/ffi:ffi_c_api_hdrs", - "//xla/service:executable", - "//xla/service:hlo_pass", - "//xla/service:tuple_util", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/translate/mhlo_to_hlo:attribute_exporter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "gemm", - srcs = ["gemm.cc"], - hdrs = ["gemm.h"], + name = "nccl_collective_thunk", + srcs = ["nccl_collective_thunk.cc"], + hdrs = ["nccl_collective_thunk.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", + "TENSORFLOW=1", ]), deps = [ - ":support", + ":nccl_api", + ":nccl_clique", + "//xla:debug_options_flags", + "//xla:shape_util", "//xla:status", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_memory", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", + "//xla/service:global_device_id", + "//xla/service:rendezvous", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:nccl_clique_key", + "//xla/service/gpu/runtime:thunk", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor/gpu:gpu_activation_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/translate/mhlo_to_hlo:attribute_exporter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@llvm-project//mlir:IR", "@tsl//tsl/platform:errors", - ] + if_gpu_is_configured([ - "//xla/service/gpu:gemm_algorithm_picker", - "//xla/stream_executor/gpu:redzone_allocator", + "@tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rccl", ]), ) cc_library( - name = "graph_launch", - srcs = ["graph_launch.cc"], - hdrs = ["graph_launch.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + name = "nccl_p2p_thunk_common", + srcs = ["nccl_p2p_thunk_common.cc"], + hdrs = ["nccl_p2p_thunk_common.h"], deps = [ - ":concurrent_region", - ":conv", - ":gemm", - ":kernel_launch", - ":support", - "//xla:statusor", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:non_atomically_upgradeable_rw_lock", + ":nccl_collective_thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_parser", + "//xla/service/gpu:nccl_clique_key", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_graph", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@tsl//tsl/profiler/lib:profiler_lock", - "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/profiler/lib:traceme_encode", + "@llvm-project//mlir:IR", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "concurrent_region", - srcs = ["concurrent_region.cc"], - hdrs = ["concurrent_region.h"], + name = "nccl_recv_thunk", + srcs = ["nccl_recv_thunk.cc"], + hdrs = ["nccl_recv_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service:stream_pool", + ":nccl_api", + ":nccl_collective_thunk", + ":nccl_p2p_thunk_common", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", + "//xla/service:global_device_id", + "//xla/service/gpu:nccl_clique_key", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "stream_synchronization", - srcs = ["stream_synchronization.cc"], - hdrs = ["stream_synchronization.h"], + name = "nccl_send_thunk", + srcs = ["nccl_send_thunk.cc"], + hdrs = ["nccl_send_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service:stream_pool", + ":nccl_api", + ":nccl_collective_thunk", + ":nccl_p2p_thunk_common", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", + "//xla/service:global_device_id", + "//xla/service/gpu:nccl_clique_key", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "io_feed", - srcs = ["io_feed.cc"], - hdrs = ["io_feed.h"], + name = "norm_thunk", + srcs = ["norm_thunk.cc"], + hdrs = ["norm_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:io_feed_manager", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service/gpu:gpu_norm_runner", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", ], ) cc_library( - name = "kernel_launch", - srcs = ["kernel_launch.cc"], - hdrs = ["kernel_launch.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + name = "outfeed_thunk", + srcs = ["outfeed_thunk.cc"], + hdrs = ["outfeed_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla:types", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:stream_executor_util", + "//xla:util", + "//xla/service/gpu:io_feed_manager", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_graph", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "replica_id_thunk", + srcs = ["replica_id_thunk.cc"], + hdrs = ["replica_id_thunk.h"], + deps = [ + "//xla/service:buffer_assignment", + "//xla/service:global_device_id", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "gpublas_lt_matmul", - srcs = ["gpublas_lt_matmul.cc"], - hdrs = ["gpublas_lt_matmul.h"], + name = "sequential_thunk", + srcs = ["sequential_thunk.cc"], + hdrs = ["sequential_thunk.h"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:logical_result", - "//xla/runtime:state", - "//xla/service:executable", - "//xla/service/gpu:matmul_utils", - "//xla/stream_executor", - "@tsl//tsl/platform:status", - ] + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), + ":annotation", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/profiler/lib:scoped_annotation", + ], ) cc_library( - name = "memcpy", - srcs = ["memcpy.cc"], - hdrs = ["memcpy.h"], + name = "send_recv_thunk", + srcs = ["send_recv_thunk.cc"], + hdrs = ["send_recv_thunk.h"], deps = [ - ":concurrent_region", - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:io_feed_manager", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:xla_data_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service:global_device_id", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@tsl//tsl/concurrency:async_value", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:traceme", ], ) cc_library( - name = "memset", - srcs = ["memset.cc"], - hdrs = ["memset.h"], + name = "thunk", + srcs = ["thunk.cc"], + hdrs = ["thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", + ":nccl_api", + ":nccl_clique", + "//xla:executable_run_options", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", "//xla/service:executable", - "//xla/service/gpu:io_feed_manager", + "//xla/service:global_device_id", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:gpu_executable_run_options", + "//xla/service/gpu:nccl_clique_key", + "//xla/stream_executor", + "//xla/translate/mhlo_to_hlo:location_exporter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@tsl//tsl/lib/gtl:int_type", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "support", - srcs = ["support.cc"], - hdrs = ["support.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - "//xla:shape_util", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/runtime:custom_call", - "//xla/service/gpu:matmul_utils", - "//xla/stream_executor:blas", + name = "triangular_solve_thunk", + srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]), + hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]), + deps = if_gpu_is_configured([ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:make_batch_pointers", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", "//xla/stream_executor:device_memory", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@tsl//tsl/profiler/lib:scoped_annotation_stack", - ], + "//xla/stream_executor/gpu:gpu_asm_opts", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + ]), ) cc_library( - name = "send_recv", - srcs = ["send_recv.cc"], - hdrs = ["send_recv.h"], + name = "while_thunk", + srcs = ["while_thunk.cc"], + hdrs = ["while_thunk.h"], deps = [ - ":support", - "//xla/mlir/runtime/transforms:custom_call_encoding", - "//xla/mlir_hlo", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/service:executable", + ":sequential_thunk", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/concurrency:async_value", - "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/profiler/lib:traceme_encode", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "tracing", - srcs = ["tracing.cc"], - hdrs = ["tracing.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + name = "wait_for_streams_thunk", + srcs = ["wait_for_streams_thunk.cc"], + hdrs = ["wait_for_streams_thunk.h"], deps = [ - ":support", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:tracing", - "//xla/runtime:type_id", + "//xla/service:global_device_id", + "//xla/service/gpu/runtime:thunk", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/profiler/lib:scoped_annotation_stack", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), + name = "cudnn_thunk", + srcs = ["cudnn_thunk.cc"], + hdrs = ["cudnn_thunk.h"], deps = [ - ":support", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:executable", - "//xla/service:executable", - "//xla/service/gpu:gpu_asm_opts_util", - "//xla/service/gpu/runtime3:triangular_solve_thunk", - "@tsl//tsl/platform:human_readable_json", + "//xla/service:buffer_assignment", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", ], ) diff --git a/xla/service/gpu/runtime/address_computation_thunk.cc b/xla/service/gpu/runtime/address_computation_thunk.cc new file mode 100644 index 0000000000000..b24a4f2b7cc3b --- /dev/null +++ b/xla/service/gpu/runtime/address_computation_thunk.cc @@ -0,0 +1,210 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/address_computation_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +AddressComputationThunk::AddressComputationThunk( + ThunkInfo thunk_info, std::unique_ptr embedded_thunk, + std::vector> arguments, + std::vector> fake_allocations, + std::vector>> + offset_buffer_indices, + std::vector> orig_shapes, + std::vector> sliced_shapes, + std::vector> offset_byte_sizes) + : Thunk(Kind::kAddressComputation, thunk_info), + embedded_thunk_(std::make_unique( + ThunkInfo(thunk_info.op), std::move(*embedded_thunk))), + embedded_thunk_arguments_(std::move(arguments)), + fake_allocations_(std::move(fake_allocations)), + offset_buffer_indices_(std::move(offset_buffer_indices)), + orig_shapes_(std::move(orig_shapes)), + sliced_shapes_(std::move(sliced_shapes)), + offset_byte_sizes_(std::move(offset_byte_sizes)) {} + +absl::Status AddressComputationThunk::Prepare( + const PrepareParams& params, ResourceRequests& resource_requests) { + auto num_arguments = embedded_thunk_arguments_.size(); + TF_RET_CHECK(num_arguments == offset_buffer_indices_.size()); + TF_RET_CHECK(num_arguments == orig_shapes_.size()); + TF_RET_CHECK(num_arguments == sliced_shapes_.size()); + TF_RET_CHECK(num_arguments == offset_byte_sizes_.size()); + for (auto [argument, offset_slice, orig_shape, sliced_shape, + offset_byte_size] : + llvm::zip(embedded_thunk_arguments_, offset_buffer_indices_, + orig_shapes_, sliced_shapes_, offset_byte_sizes_)) { + if (offset_slice.has_value()) { + TF_RET_CHECK(argument.has_value()); + TF_RET_CHECK(orig_shape.has_value()); + TF_RET_CHECK(sliced_shape.has_value()); + TF_RET_CHECK(offset_byte_size.has_value()); + + TF_RET_CHECK(orig_shape->IsArray()); + TF_RET_CHECK(sliced_shape->IsArray()); + + TF_RET_CHECK(offset_slice->size() == orig_shape->rank()); + TF_RET_CHECK(sliced_shape->rank() == orig_shape->rank()); + } + } + + TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests)); + return absl::OkStatus(); +} + +absl::Status AddressComputationThunk::Initialize( + const InitializeParams& params) { + TF_RETURN_IF_ERROR(embedded_thunk_->Initialize(params)); + + unsigned offset_count = 0; + for (auto maybe_shape : sliced_shapes_) { + offset_count += (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); + } + + absl::MutexLock lock(&mutex_); + if (auto it = offsets_.find(params.executor); it == offsets_.end()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr allocation, + params.executor->HostMemoryAllocate(offset_count * sizeof(int64_t))); + offsets_.emplace(params.executor, std::move(allocation)); + } + + return absl::OkStatus(); +} + +absl::Status AddressComputationThunk::ExecuteOnStream( + const ExecuteParams& params) { + auto& stream = *params.stream; + const BufferAllocations& orig_allocations = *params.buffer_allocations; + std::vector new_buffers( + embedded_thunk_arguments_.size(), se::DeviceMemoryBase()); + + // Get memory allocation for copying offsets from device. + int64_t* offsets_base = [&] { + absl::MutexLock lock(&mutex_); + return reinterpret_cast(offsets_.at(stream.parent())->opaque()); + }(); + + for (auto [argument_idx, values] : llvm::enumerate( + llvm::zip(embedded_thunk_arguments_, offset_buffer_indices_, + orig_shapes_, sliced_shapes_, offset_byte_sizes_))) { + auto [argument_slice, offset_slice, orig_shape, sliced_shape, + offset_byte_size] = values; + + if (argument_slice == std::nullopt) { + continue; + } + + // `orig_argument` will contain the original offset for slice + // `argument_slice` within `orig_allocations` + se::DeviceMemoryBase orig_argument = + orig_allocations.GetDeviceAddress(*argument_slice); + + if (offset_slice == std::nullopt) { + new_buffers[argument_idx] = orig_argument; + continue; + } + + const Shape& src_shape = *orig_shape; + const Shape& dst_shape = *sliced_shape; + TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); + + std::vector slice_starts; + slice_starts.reserve(dst_shape.rank()); + + // Get offset for `argument_idx`-th argument, which has `dst_shape.rank()` + // components. + for (auto [offset_idx, values] : llvm::enumerate(llvm::zip( + *offset_slice, src_shape.dimensions(), dst_shape.dimensions()))) { + auto [slice, src_dim, dst_dim] = values; + se::DeviceMemoryBase offset_src = + orig_allocations.GetDeviceAddress(slice); + int64_t* offset_dst = &offsets_base[argument_idx + offset_idx]; + // Copy the `offset_idx`-th component of the offset for the + // `argument_idx`-th argument from device to host. + TF_RETURN_IF_ERROR( + stream.Memcpy(offset_dst, offset_src, offset_byte_size.value())); + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to retrieve all slice offset values on stream %p: %s", + &stream, blocked.message())); + } + // Clamp start indices: + // start_indices[i] = min(max(start_indices[i], 0), + // operand.dimension_size[i] - size_indices[i]) + auto start_index = std::min(std::max(*offset_dst, 0L), src_dim - dst_dim); + slice_starts.push_back(start_index); + } + + // Compute new slice. No need to copy the content to new buffers as we can + // reuse the original buffers since slices are contiguous. + int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); + + int64_t new_offset = 0; + for (auto [start, stride] : + llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { + new_offset += start * stride; + } + + new_buffers[argument_idx] = + orig_argument.GetByteSlice(new_offset, new_size); + } + + // Safe to create a local BufferAllocations here since buffers are only slices + // of bigger ones allocated elsewhere. + BufferAllocations new_allocations(new_buffers, + orig_allocations.device_ordinal(), + orig_allocations.memory_allocator(), + orig_allocations.external_allocations()); + + Thunk::ExecuteParams new_params = + Thunk::ExecuteParams::CloneWithNewAllocations(params, new_allocations); + + // Execute the underlying custom call thunk with the new buffers. + TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params)); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/address_computation_thunk.h b/xla/service/gpu/runtime/address_computation_thunk.h new file mode 100644 index 0000000000000..8d36751b9d830 --- /dev/null +++ b/xla/service/gpu/runtime/address_computation_thunk.h @@ -0,0 +1,84 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// AddressComputationThunk wraps the logic to compute dynamic offsets/sizes from +// dynamic-slice or DUS around some original thunks (e.g. custom call or NCCL +// thunks) +// +// AddressComputationThunk assumes that the slices are contiguous. +class AddressComputationThunk : public Thunk { + public: + AddressComputationThunk( + ThunkInfo thunk_info, std::unique_ptr embedded_thunk, + std::vector> arguments, + std::vector> fake_allocations_, + std::vector>> + offset_buffer_indices, + std::vector> orig_shapes, + std::vector> sliced_shapes, + std::vector> offset_byte_sizes); + + AddressComputationThunk(const AddressComputationThunk&) = delete; + AddressComputationThunk& operator=(const AddressComputationThunk&) = delete; + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + std::unique_ptr embedded_thunk_; + std::vector> + embedded_thunk_arguments_; + std::vector> fake_allocations_; + std::vector>> + offset_buffer_indices_; + std::vector> orig_shapes_; + std::vector> sliced_shapes_; + std::vector> offset_byte_sizes_; + + // Pinned host memory for transferring offset values from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + offsets_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_ diff --git a/xla/service/gpu/runtime/address_computation_thunk_test.cc b/xla/service/gpu/runtime/address_computation_thunk_test.cc new file mode 100644 index 0000000000000..470a137456197 --- /dev/null +++ b/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -0,0 +1,1700 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/address_computation_thunk.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" // IWYU pragma: keep +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/test.h" + +#if GOOGLE_CUDA +#define PLATFORM "CUDA" +#elif TENSORFLOW_USE_ROCM +#define PLATFORM "ROCM" +#endif + +namespace xla::gpu { + +namespace { + +static se::StreamExecutor* GpuExecutor() { + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + auto* platform = se::PlatformManager::PlatformWithName(name).value(); + return platform->ExecutorForDevice(0).value(); +} + +} // namespace + +TEST(AddressComputationThunkTest, SlicedGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs, + slice_out, slice_workspace, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemory rhs = executor->AllocateArray(3 * 1); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(1 * 1); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + +TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 4 * 3; + int64_t out_length = sizeof(float) * 2 * 2; + int64_t offset_length = sizeof(int64_t); + int64_t slice_length = sizeof(float) * 2 * 2; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + slice_length); + + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + + BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_0(/*index=*/6, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_0(&alloc_rhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_1(/*index=*/7, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out, slice_workspace, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + std::vector rhs_offsets{slice_rhs_offset_0, + slice_rhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), + ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), std::nullopt, + std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), std::nullopt, + std::nullopt}, + {sizeof(int64_t), sizeof(int64_t), std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[2,2]{1,0} slice(lhs), slice={[0:2], [1:3]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // Given a `rhs` tensor of shape f32[4,3]{1,0} + // The `rhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[2,2]{1,0} slice(rhs), slice={[2:4], [1:3]} + // rhs = [1.0, 1.0, 1.0, + // 1.0, 1.0, 1.0, + // 1.0, 1.0, 1.0, + // 1.0, 1.0, 1.0] + se::DeviceMemory rhs = executor->AllocateArray(4 * 3); + std::vector rhs_arr(12, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(2 * 2); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + se::DeviceMemory rhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory rhs_offset_1 = executor->AllocateArray(1); + std::vector rhs_offset_arr{2, 1}; + TF_ASSERT_OK( + stream->Memcpy(&rhs_offset_0, &rhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&rhs_offset_1, &rhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, + lhs_offset_1, rhs_offset_0, rhs_offset_1}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Execute address computation thunk and verify that it failed because of non + // contiguous slices on both `lhs` and `rhs`. + ASSERT_FALSE(thunk.ExecuteOnStream(params).ok()); +} + +TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = sizeof(float) * 2 * 4; + int64_t out_length = sizeof(float) * 1; + int64_t offset_length = sizeof(int64_t); + int64_t slice_length = sizeof(float) * 3; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + slice_length); + + BufferAllocation alloc_lhs(/*index=*/0, length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, length); + + BufferAllocation alloc_rhs(/*index=*/1, length, /*color=*/0); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_0(/*index=*/6, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_0(&alloc_rhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_1(/*index=*/7, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out, slice_workspace, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + std::vector rhs_offsets{slice_rhs_offset_0, + slice_rhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1}), std::nullopt, + std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), std::nullopt, + std::nullopt}, + {sizeof(int64_t), sizeof(int64_t), std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + std::vector arr{1, 2, 3, 4, 5, 6, 7, 8}; + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + TF_ASSERT_OK(stream->Memcpy(&lhs, arr.data(), length)); + + // Given a `rhs` tensor of shape f32[8,1]{1,0} + // The `rhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[3,1]{1,0} slice(rhs), slice={[2:5], [0:1]} + // rhs = [1.0, + // 2.0, + // 3.0, + // 4.0, + // 5.0, + // 6.0, + // 7.0, + // 8.0] + se::DeviceMemory rhs = executor->AllocateArray(8); + std::vector rhs_arr(8, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, arr.data(), length)); + + se::DeviceMemory out = executor->AllocateArray(1); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + se::DeviceMemory rhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory rhs_offset_1 = executor->AllocateArray(1); + std::vector rhs_offset_arr{2, 0}; + TF_ASSERT_OK( + stream->Memcpy(&rhs_offset_0, &rhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&rhs_offset_1, &rhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, + lhs_offset_1, rhs_offset_0, rhs_offset_1}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Execute address computation thunk and verify that it executed a GEMM on the + // right slices. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({2 * 3 + 3 * 4 + 4 * 5})); +} + +static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, + ffi::BufferBase dst) { + return stream->MemcpyD2D( + &dst.data, src.data, + absl::c_accumulate(src.dimensions, 1.0, std::multiplies()) * + sizeof(float)); +} + +XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src + .Arg() // dst +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, + kMemcpy); + +TEST(AddressComputationThunkTest, SlicedMemcpy) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t src_count = 8 * 8 * 10 * 8; + int64_t dst_count = 8 * 8; + int64_t src_length = sizeof(int32_t) * src_count; + int64_t dst_length = sizeof(int32_t) * dst_count; + int64_t offset_length = sizeof(int64_t); + int64_t slice_length = sizeof(int32_t) * dst_count; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); + BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, dst_length, /*color=*/0)); + BufferAllocation::Slice slice_dst(fake_allocations.back().get(), 0, + dst_length); + + BufferAllocation alloc_offset_0(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_0(&alloc_offset_0, 0, offset_length); + + BufferAllocation alloc_offset_1(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_1(&alloc_offset_1, 0, offset_length); + + BufferAllocation alloc_offset_2(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_2(&alloc_offset_2, 0, offset_length); + + BufferAllocation alloc_offset_3(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_3(&alloc_offset_3, 0, offset_length); + + // Preparing custom call thunk: setting up call target and operands + results + // buffers. + auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(registration.ok()); + + std::vector> operands{ + CustomCallThunk::Slice{slice_src_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8})}}; + std::vector> results{ + CustomCallThunk::Slice{slice_dst, + ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8})}}; + + // Creating embedded custom call thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), registration->handler, operands, results, + /*attributes=*/CustomCallThunk::AttributesMap(), + /*called_computation=*/nullptr)); + + // Wrapping address computation thunk around the custom call thunk. + std::vector slice_offsets{ + slice_offset_0, slice_offset_1, slice_offset_2, slice_offset_3}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), {slice_src, slice_dst}, + std::move(fake_allocations), {slice_offsets, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8}), std::nullopt}, + // Make sure to pass a dst shape with the same rank as src shape (i.e. + // original slice result and not bitcasted one) + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8}), std::nullopt}, + {sizeof(int64_t), std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `src` tensor of shape s32[8,8,10,8]{3,2,1,0} + // The `src` slice that we want to copy from will be equivalent to this static + // slice op: + // s32[1,1,8,8]{3,2,1,0} slice(src), slice={[3:4], [5:6], [2:10], [0:8]} + + // Preparing memory for thunk arguments. + se::DeviceMemory src = executor->AllocateArray(src_count); + std::vector src_arr(src_count, 0); + for (unsigned i = 0; i < src_count; ++i) src_arr[i] = i; + TF_ASSERT_OK(stream->Memcpy(&src, src_arr.data(), src_length)); + + se::DeviceMemory dst = executor->AllocateArray(dst_count); + TF_ASSERT_OK(stream->MemZero(&dst, dst_length)); + + se::DeviceMemory offset_0 = executor->AllocateArray(1); + se::DeviceMemory offset_1 = executor->AllocateArray(1); + se::DeviceMemory offset_2 = executor->AllocateArray(1); + se::DeviceMemory offset_3 = executor->AllocateArray(1); + std::vector offset_arr{3, 5, 2, 0}; + TF_ASSERT_OK(stream->Memcpy(&offset_0, &offset_arr[0], offset_length)); + TF_ASSERT_OK(stream->Memcpy(&offset_1, &offset_arr[1], offset_length)); + TF_ASSERT_OK(stream->Memcpy(&offset_2, &offset_arr[2], offset_length)); + TF_ASSERT_OK(stream->Memcpy(&offset_3, &offset_arr[3], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {src, dst, offset_0, offset_1, offset_2, offset_3}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `dst` data back to host for verification. + std::vector out(dst_count, 0); + TF_ASSERT_OK(stream->Memcpy(out.data(), dst, dst_length)); + + // Verifying that the right slice of `src` was copied to `dst`. + std::vector ref(dst_count, 0); + int64_t offset_val = + offset_arr[3] + + 8 * (offset_arr[2] + 10 * (offset_arr[1] + 8 * offset_arr[0])); + std::copy(src_arr.begin() + offset_val, + src_arr.begin() + offset_val + dst_count, ref.begin()); + ASSERT_EQ(out, ref); +} + +TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t src_count = 8 * 8 * 10 * 2; + int64_t dst_count = 2 * 2 * 2 * 2; + int64_t slice_count = 2 * 2; + int64_t src_length = sizeof(int32_t) * src_count; + int64_t dst_length = sizeof(int32_t) * dst_count; + int64_t offset_length = sizeof(int64_t); + int64_t slice_length = sizeof(int32_t) * slice_count; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0, + slice_length); + + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); + BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); + + BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); + BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); + + BufferAllocation alloc_src_offset_0(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_0(&alloc_src_offset_0, 0, + offset_length); + + BufferAllocation alloc_src_offset_1(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_1(&alloc_src_offset_1, 0, + offset_length); + + BufferAllocation alloc_src_offset_2(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_2(&alloc_src_offset_2, 0, + offset_length); + + BufferAllocation alloc_src_offset_3(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_3(&alloc_src_offset_3, 0, + offset_length); + + BufferAllocation alloc_dst_offset_0(/*index=*/6, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_0(&alloc_dst_offset_0, 0, + offset_length); + + BufferAllocation alloc_dst_offset_1(/*index=*/7, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_1(&alloc_dst_offset_1, 0, + offset_length); + + BufferAllocation alloc_dst_offset_2(/*index=*/8, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_2(&alloc_dst_offset_2, 0, + offset_length); + + BufferAllocation alloc_dst_offset_3(/*index=*/9, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_3(&alloc_dst_offset_3, 0, + offset_length); + + // Preparing custom call thunk: setting up call target and operands + results + // buffers. + auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(registration.ok()); + + std::vector> operands{ + CustomCallThunk::Slice{slice_src_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + std::vector> results{ + CustomCallThunk::Slice{slice_dst_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + + // Creating embedded custom call thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), registration->handler, operands, results, + /*attributes=*/CustomCallThunk::AttributesMap(), + /*called_computation=*/nullptr)); + + // Wrapping address computation thunk around the custom call thunk. + std::vector slice_src_offsets{ + slice_src_offset_0, slice_src_offset_1, slice_src_offset_2, + slice_src_offset_3}; + std::vector slice_dst_offsets{ + slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2, + slice_dst_offset_3}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), {slice_src, slice_dst}, + std::move(fake_allocations), {slice_src_offsets, slice_dst_offsets}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}, + // Make sure to pass a dst shape with the same rank as src shape (i.e. + // original slice result and not bitcasted one) + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}, + {sizeof(int64_t), sizeof(int64_t)}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `src` tensor of shape s32[8,8,10,2]{3,2,1,0} + // The `src` slice that we want to copy from will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(src), slice={[3:4], [5:6], [2:4], [0:2]} + // + // Given a `dst` tensor of shape s32[2,2,2,2]{3,2,1,0} + // The `dst` slice that we want to copy into will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(dst), slice={[1:2], [1:2], [0:2], [0:2]} + + // Preparing memory for thunk arguments. + se::DeviceMemory src = executor->AllocateArray(src_count); + std::vector src_arr(src_count, 0); + for (unsigned i = 0; i < src_count; ++i) src_arr[i] = i; + TF_ASSERT_OK(stream->Memcpy(&src, src_arr.data(), src_length)); + + se::DeviceMemory dst = executor->AllocateArray(dst_count); + TF_ASSERT_OK(stream->MemZero(&dst, dst_length)); + + se::DeviceMemory src_offset_0 = executor->AllocateArray(1); + se::DeviceMemory src_offset_1 = executor->AllocateArray(1); + se::DeviceMemory src_offset_2 = executor->AllocateArray(1); + se::DeviceMemory src_offset_3 = executor->AllocateArray(1); + std::vector src_offset_arr{3, 5, 2, 0}; + TF_ASSERT_OK( + stream->Memcpy(&src_offset_0, &src_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_1, &src_offset_arr[1], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_2, &src_offset_arr[2], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_3, &src_offset_arr[3], offset_length)); + + se::DeviceMemory dst_offset_0 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_1 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_2 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_3 = executor->AllocateArray(1); + std::vector dst_offset_arr{1, 1, 0, 0}; + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_0, &dst_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_1, &dst_offset_arr[1], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_2, &dst_offset_arr[2], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_3, &dst_offset_arr[3], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3, + dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `dst` data back to host for verification. + std::vector out(dst_count, 0); + TF_ASSERT_OK(stream->Memcpy(out.data(), dst, dst_length)); + + // Verifying that the right slice of `src` was copied to `dst`. + std::vector ref(dst_count, 0); + int64_t src_offset_val = + src_offset_arr[3] + + 2 * (src_offset_arr[2] + + 10 * (src_offset_arr[1] + 8 * src_offset_arr[0])); + int64_t dst_offset_val = + dst_offset_arr[3] + + 2 * (dst_offset_arr[2] + 2 * (dst_offset_arr[1] + 2 * dst_offset_arr[0])); + std::copy(src_arr.begin() + src_offset_val, + src_arr.begin() + src_offset_val + slice_count, + ref.begin() + dst_offset_val); + ASSERT_EQ(out, ref); +} + +TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs(/*index=*/1, lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + + BufferAllocation alloc_rhs(/*index=*/3, rhs_length, /*color=*/0); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + + BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); + BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + + BufferAllocation alloc_workspace(/*index=*/0, 1024 * 1024, /*color=*/0); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemory rhs = executor->AllocateArray(3 * 1); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(1 * 1); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + +TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs(/*index=*/7, lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + + BufferAllocation alloc_rhs(/*index=*/3, rhs_length, /*color=*/0); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + + BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); + BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + + BufferAllocation alloc_workspace(/*index=*/0, 1024 * 1024, /*color=*/0); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemory rhs = executor->AllocateArray(3 * 1); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(1 * 1); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {workspace, /*garbage, to be ignored*/ se::DeviceMemoryBase(), out, rhs, + lhs_offset_0, lhs_offset_1, /*garbage, to be ignored*/ rhs, lhs}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + +TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + BufferAllocation alloc_lhs(/*index=*/0, 3 * lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, lhs_length, lhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs, + slice_out, slice_workspace, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + + // Preparing memory for thunk arguments. + // lhs = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + // + // The real `lhs` tensor will look more like this: + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + se::DeviceMemory lhs_whole_buffer = + executor->AllocateArray(2 * 4 * 3); + TF_ASSERT_OK(stream->MemZero(&lhs_whole_buffer, 2 * 4 * 3)); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + se::DeviceMemoryBase lhs = + lhs_whole_buffer.GetByteSlice(lhs_length, lhs_length); + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemory rhs = executor->AllocateArray(3 * 1); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(1 * 1); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {lhs_whole_buffer, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + +TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t src_count = 8 * 8 * 10 * 2; + int64_t dst_count = 2 * 2 * 2 * 2; + int64_t slice_count = 2 * 2; + int64_t src_length = sizeof(int32_t) * src_count; + int64_t dst_length = sizeof(int32_t) * dst_count; + int64_t offset_length = sizeof(int64_t); + int64_t slice_length = sizeof(int32_t) * slice_count; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0, + slice_length); + + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); + BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); + + BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); + BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); + + BufferAllocation alloc_src_offset_0(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_0(&alloc_src_offset_0, 0, + offset_length); + + BufferAllocation alloc_src_offset_1(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_1(&alloc_src_offset_1, 0, + offset_length); + + BufferAllocation alloc_src_offset_2(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_2(&alloc_src_offset_2, 0, + offset_length); + + BufferAllocation alloc_src_offset_3(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_3(&alloc_src_offset_3, 0, + offset_length); + + BufferAllocation alloc_dst_offset_0(/*index=*/6, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_0(&alloc_dst_offset_0, 0, + offset_length); + + BufferAllocation alloc_dst_offset_1(/*index=*/7, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_1(&alloc_dst_offset_1, 0, + offset_length); + + BufferAllocation alloc_dst_offset_2(/*index=*/8, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_2(&alloc_dst_offset_2, 0, + offset_length); + + BufferAllocation alloc_dst_offset_3(/*index=*/9, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_3(&alloc_dst_offset_3, 0, + offset_length); + + // Preparing custom call thunk: setting up call target and operands + results + // buffers. + auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(registration.ok()); + + std::vector> operands{ + CustomCallThunk::Slice{slice_src_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + std::vector> results{ + CustomCallThunk::Slice{slice_dst_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + + // Creating embedded custom call thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), registration->handler, operands, results, + /*attributes=*/CustomCallThunk::AttributesMap(), + /*called_computation=*/nullptr)); + + // Wrapping address computation thunk around the custom call thunk. + std::vector slice_src_offsets{ + slice_src_offset_0, slice_src_offset_1, slice_src_offset_2, + slice_src_offset_3}; + std::vector slice_dst_offsets{ + slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2, + slice_dst_offset_3}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), {slice_src, slice_dst}, + std::move(fake_allocations), {slice_src_offsets, slice_dst_offsets}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}, + // Make sure to pass a dst shape with the same rank as src shape (i.e. + // original slice result and not bitcasted one) + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}, + {sizeof(int64_t), sizeof(int64_t)}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `src` tensor of shape s32[8,8,10,2]{3,2,1,0} + // The `src` slice that we want to copy from will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(src), slice={[3:4], [5:6], [2:4], [0:2]} + // + // Given a `dst` tensor of shape s32[2,2,2,2]{3,2,1,0} + // The `dst` slice that we want to copy into will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(dst), slice={[1:2], [1:2], [0:2], [0:2]} + + // Preparing memory for thunk arguments. + se::DeviceMemory src = executor->AllocateArray(src_count); + std::vector src_arr(src_count, 0); + for (unsigned i = 0; i < src_count; ++i) src_arr[i] = i; + TF_ASSERT_OK(stream->Memcpy(&src, src_arr.data(), src_length)); + + se::DeviceMemory dst = executor->AllocateArray(dst_count); + TF_ASSERT_OK(stream->MemZero(&dst, dst_length)); + + se::DeviceMemory src_offset_0 = executor->AllocateArray(1); + se::DeviceMemory src_offset_1 = executor->AllocateArray(1); + se::DeviceMemory src_offset_2 = executor->AllocateArray(1); + se::DeviceMemory src_offset_3 = executor->AllocateArray(1); + std::vector src_ref_offset_arr{3, 5, 2, 0}; + std::vector src_offset_arr{3, 5, 2, -3}; + TF_ASSERT_OK( + stream->Memcpy(&src_offset_0, &src_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_1, &src_offset_arr[1], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_2, &src_offset_arr[2], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&src_offset_3, &src_offset_arr[3], offset_length)); + + se::DeviceMemory dst_offset_0 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_1 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_2 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_3 = executor->AllocateArray(1); + std::vector dst_ref_offset_arr{1, 1, 0, 0}; + std::vector dst_offset_arr{3, 2, 5, -4}; + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_0, &dst_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_1, &dst_offset_arr[1], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_2, &dst_offset_arr[2], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&dst_offset_3, &dst_offset_arr[3], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3, + dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `dst` data back to host for verification. + std::vector out(dst_count, 0); + TF_ASSERT_OK(stream->Memcpy(out.data(), dst, dst_length)); + + // Verifying that the right slice of `src` was copied to `dst`. + std::vector ref(dst_count, 0); + int64_t src_offset_val = + src_ref_offset_arr[3] + + 2 * (src_ref_offset_arr[2] + + 10 * (src_ref_offset_arr[1] + 8 * src_ref_offset_arr[0])); + int64_t dst_offset_val = + dst_ref_offset_arr[3] + + 2 * (dst_ref_offset_arr[2] + + 2 * (dst_ref_offset_arr[1] + 2 * dst_ref_offset_arr[0])); + std::copy(src_arr.begin() + src_offset_val, + src_arr.begin() + src_offset_val + slice_count, + ref.begin() + dst_offset_val); + ASSERT_EQ(out, ref); +} + +TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc(/*index=*/0, lhs_length + rhs_length + out_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc, 0, lhs_length); + BufferAllocation::Slice slice_rhs(&alloc, lhs_length, rhs_length); + BufferAllocation::Slice slice_out(&alloc, lhs_length + rhs_length, + out_length); + + BufferAllocation alloc_workspace(/*index=*/1, 1024 * 1024, /*color=*/0); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/2, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/3, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + + // Preparing memory for thunk arguments. + // lhs = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + // + // The real `lhs` tensor will look more like this: + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + se::DeviceMemory buffer = + executor->AllocateArray(lhs_length + rhs_length + out_length); + TF_ASSERT_OK(stream->MemZero(&buffer, lhs_length + rhs_length + out_length)); + + se::DeviceMemoryBase lhs = buffer.GetByteSlice(0, lhs_length); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemoryBase rhs = buffer.GetByteSlice(lhs_length, rhs_length); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemoryBase out = + buffer.GetByteSlice(lhs_length + rhs_length, out_length); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK( + stream->Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({buffer, workspace, lhs_offset_0, lhs_offset_1}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/annotation.cc b/xla/service/gpu/runtime/annotation.cc new file mode 100644 index 0000000000000..809f02d30d10b --- /dev/null +++ b/xla/service/gpu/runtime/annotation.cc @@ -0,0 +1,573 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/annotation.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/printer.h" +#include "xla/status.h" +#include "tsl/platform/errors.h" +#include "tsl/profiler/lib/nvtx_utils.h" +#include "tsl/profiler/lib/scoped_annotation.h" + +#if GOOGLE_CUDA +#include "nvtx3/nvToolsExt.h" +#include "nvtx3/nvToolsExtPayload.h" +#endif + +namespace xla::gpu { + +using ::tsl::profiler::ScopedAnnotation; +using ::tsl::profiler::StringHandle; +namespace { + +StringHandle RegisterString(const std::string& str) { + if (auto domain = tsl::profiler::DefaultProfilerDomain(); domain) { + return tsl::profiler::RegisterString(domain, str); + } + return {}; +} + +// Nsight Systems supports some basic HTML markup in annotation strings. This +// escaping stops things like from disappearing. +std::ostream& PrintEscaped(std::ostream& os, std::string_view str) { + for (char c : str) { + switch (c) { + case '<': + os << "<"; + break; + case '>': + os << ">"; + break; + default: + os << c; + } + } + return os; +} + +// Print options for profiler annotations. +HloPrintOptions PrintOptions() { + auto opts = HloPrintOptions::ShortParsable(); + opts.set_print_large_constants(false); + opts.set_print_control_dependencies(true); + opts.set_print_operand_index_annotation_interval(5); + opts.set_print_backend_config(true); + opts.set_print_metadata(true); + opts.set_print_name_after_closing_brace(true); + return opts; +} + +// Sortable struct representing a frame in the Python stacktrace attached to a +// given instruction. +struct StackFrame { + std::string_view file_name, function_name, op_name; + int line, column; + + private: + auto tied() const { + return std::tie(file_name, line, column, function_name, op_name); + } + friend bool operator==(StackFrame const& lhs, StackFrame const& rhs) { + return lhs.tied() == rhs.tied(); + } + friend bool operator<(StackFrame const& lhs, StackFrame const& rhs) { + return lhs.tied() < rhs.tied(); + } +}; + +// Walk through the HLO graph from an instruction and collect the source +// file/line information we see along the way. This allows us to generate an +// annotation for each kernel that shows the (merged) Python stacktraces of the +// operations that were traced and compiled int this kernel. For example: +// +// - /opt/jax/examples/mnist_vae.py:143[] +// -- /opt/jax/examples/mnist_vae.py:127[run_epoch] +// --- /opt/jax/examples/mnist_vae.py:125[body_fun] +// ---- /opt/jax/examples/mnist_vae.py:124[] +// ----- /opt/jax/examples/mnist_vae.py:122[body_fun] transpose[permutation=(1, +// 0)] +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] add +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] mul +// --- /opt/jax/examples/mnist_vae.py:126[body_fun] sub +// +// shows four merged stacktraces (3 of depth 3, 1 of depth 5). +class SourceLocationVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit SourceLocationVisitor( + std::string_view op_name_prefix_to_remove__ = {}) + : op_name_prefix_to_remove_{op_name_prefix_to_remove__} {} + + std::string AsString(int32_t common_prefix) const { + // Format the call stacks we've collected; if call stack collection was not + // enabled then each "stack" just has depth 1 and no column/function name + // information. Skip the first `common_prefix` elements of each stack trace + if (common_prefix < 0) { + return "[invalid common_prefix]"; + } + std::ostringstream oss{}; + oss << '\n'; + std::vector current_state{}; + for (auto const& call_stack : location_set_) { + for (auto depth = 0; depth < call_stack.size() - common_prefix; ++depth) { + auto const& frame = call_stack[common_prefix + depth]; + if (depth < current_state.size() && current_state[depth] == frame) { + continue; + } + current_state.resize(depth + 1); + current_state[depth] = frame; + FormatFrame(oss, frame, depth); + } + } + return std::move(oss).str(); + } + + Status DefaultAction(HloInstruction const* inst) final { + OpMetadata const& meta = inst->metadata(); + // The full op_name is split across three places: the module-level + // annotation shows the prefix that is common to the whole module, the + // kernel-level annotation removes that prefix and shows whatever middle + // sections of the name are common to all operations in the kernel, and the + // individual call stack frames in the kernel-level annotation show the + // final parts of the op_name that have not already been shown. + std::string_view op_name = meta.op_name(); + if (!op_name.empty()) { + op_name = op_name.substr(op_name_prefix_to_remove_.size()); + } + if (!op_name.empty() && op_name.front() == '/') { + op_name = op_name.substr(1); + } + if (int frame_id = meta.stack_frame_id(); frame_id != 0) { + std::vector call_stack{}; + HloModule const* const hlo_module = inst->parent()->parent(); + while (frame_id != 0) { + HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); + if (frame.empty()) { + break; + } + frame_id = frame.parent_frame_id; + call_stack.emplace_back(StackFrame{frame.file_name, frame.function_name, + op_name, frame.line, frame.column}); + // only attach the op_name to the most-nested frame + op_name = {}; + } + // re-order to be [caller, callee, ...] + std::reverse(call_stack.begin(), call_stack.end()); + location_set_.emplace(call_stack); + } else if (!meta.source_file().empty() && meta.source_line() != 0) { + location_set_.emplace(1, StackFrame{meta.source_file(), + {/* function_name */}, + op_name, + meta.source_line()}); + } + return OkStatus(); + } + + std::pair LongestSourceLocationPrefix() const { + // Find the longest common prefix along the members of location_set_ and + // return a formatted version of that prefix, along with its length. As + // location_set_ is sorted, that just means looking for the longest common + // prefix of the first and last elements. + if (location_set_.size() < 2) { + // Only extract a prefix if there are enough stack traces. + return {}; + } + const auto& first_loc = *location_set_.begin(); + const auto common_end = std::mismatch(first_loc.begin(), first_loc.end(), + location_set_.rbegin()->begin(), + location_set_.rbegin()->end()) + .first; + std::ostringstream oss{}; + oss << '\n'; + std::for_each(first_loc.begin(), common_end, + [&oss](const StackFrame& frame) { FormatFrame(oss, frame); }); + const int32_t prefix_frames = std::distance(first_loc.begin(), common_end); + return {RegisterString(std::move(oss).str()), prefix_frames}; + } + + private: + static void FormatFrame(std::ostringstream& oss, const StackFrame& frame, + int depth = -1) { + if (depth >= 0) { + oss << std::string(depth + 1, '-') << ' '; + } + PrintEscaped(oss, frame.file_name) << ':' << frame.line; + if (frame.column) { + oss << ':' << frame.column; + } + if (!frame.function_name.empty()) { + PrintEscaped(oss << '[', frame.function_name) << ']'; + } + if (!frame.op_name.empty()) { + PrintEscaped(oss << ' ', frame.op_name); + } + oss << '\n'; + } + std::string_view op_name_prefix_to_remove_{}; + std::set> location_set_{}; +}; + +template +absl::Status VisitInstAndCalledButNotOperands(Visitor& visitor, + const HloInstruction& inst) { + // Visit the given instruction, and the things it calls, but not its operands. + TF_RETURN_IF_ERROR(visitor.DefaultAction(&inst)); + for (const HloComputation* called : inst.called_computations()) { + const HloInstruction* const root = called->root_instruction(); + TF_RETURN_IF_ERROR(root->Accept(&visitor, false /* call_finish_visit */, + true /* ignore_control_predecessors */, + true /* cross_computation */)); + } + return absl::OkStatus(); +} + +// Split `a` and `b` by `delim` into two lists of possibly-empty tokens, then +// rejoin the first N of those lists that match by `delim`. Note: it is +// unspecified which argument the return value points into. +std::string_view LongestPrefix(std::string_view a, std::string_view b, + char delim = '/') { + auto split_a = absl::StrSplit(a, delim); + auto split_b = absl::StrSplit(b, delim); + + size_t common_prefix_len = 0; + + for (auto a_it = split_a.begin(), b_it = split_b.begin(); + a_it != split_a.end() && b_it != split_b.end(); ++a_it, ++b_it) { + if (*a_it != *b_it) break; + + if (common_prefix_len) ++common_prefix_len; // account for delimiter + common_prefix_len += a_it->size(); // length of a matching token + } + + return std::string_view(a.data(), common_prefix_len); +} + +// Find the longest prefix among instructions' op_name metadata +// Chunk this by delimiting slashes, i.e. given a/b/cat and a/b/cabbage, the +// longest prefix is a/b not a/b/ca +class OpNamePrefixVisitor : public ConstDfsHloVisitorWithDefault { + public: + absl::Status DefaultAction(const HloInstruction* inst) final { + auto const& op_name = inst->metadata().op_name(); + if (!op_name.empty()) { + prefix_ = prefix_ ? LongestPrefix(*prefix_, op_name) : op_name; + } + return absl::OkStatus(); + } + + std::string_view longest_op_name_prefix() const { + return prefix_.value_or(""); + } + + private: + std::optional prefix_; +}; + +std::string_view GetLongestOpNamePrefix(const HloModule& mod) { + // In the presence of (at least) debug callbacks, calling Accept on the root + // instruction of the module may not reach all instructions in the module. + OpNamePrefixVisitor visitor{}; + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + if (!visitor.DefaultAction(inst).ok()) { + return {}; + } + } + } + return visitor.longest_op_name_prefix(); +} + +std::string_view GetLongestOpNamePrefix(const HloInstruction& inst) { + OpNamePrefixVisitor visitor{}; + if (!VisitInstAndCalledButNotOperands(visitor, inst).ok()) { + return {}; + } + return visitor.longest_op_name_prefix(); +} + +std::string MakeTitle(const HloModule& mod, std::string_view longest_prefix) { + if (longest_prefix.empty()) { + return absl::StrFormat("XlaModule:#hlo_module=%s,program_id=%d#", + mod.name(), mod.unique_id()); + } + return absl::StrFormat("XlaModule:#prefix=%s,hlo_module=%s,program_id=%d#", + longest_prefix, mod.name(), mod.unique_id()); +} + +std::string FormatSourceLocations(HloInstruction const& inst, + int32_t common_frames) { + // Inside the source location/backtrace report the op_name too, but remove the + // kernel-wide prefix for brevity + SourceLocationVisitor visitor{GetLongestOpNamePrefix(inst)}; + // Visit the given instruction, and the things it calls, but not its operands + // -- we don't want to collect the source code locations that produced the + // inputs to this kernel, just those corresponding to the kernel itself. + if (!VisitInstAndCalledButNotOperands(visitor, inst).ok()) { + return "[error]"; + } + return visitor.AsString(common_frames); +} + +// Get the string representation of this instruction as an std::string. +std::string InstructionAsString(HloInstruction const& inst) { + StringPrinter printer; + inst.Print(&printer, PrintOptions()); + return std::move(printer).ToString(); +} + +// Get the string representation of the HLO code called by this instruction, +// but not the instruction itself. The typical example is a fusion instruction, +// where InstructionAsString(fusion_inst) would be something like +// fusion.N = ... fusion(...), calls=fused_computation.N ... +// and CalledInstructionsAsString(fusion_inst) would be something like +// fused_computation.N { ... } +std::string CalledInstructionsAsString(HloInstruction const& inst) { + StringPrinter printer; + auto const opts = PrintOptions(); + for (HloComputation const* called : inst.called_computations()) { + called->Print(&printer, opts); + } + return std::move(printer).ToString(); +} + +// Get a string representing the longest common prefix of source locations in +// this module, and the number of frames that that represents. +std::pair GetLongestSourceLocationPrefix( + const HloModule& mod) { + // In the presence of (at least) debug callbacks, calling Accept on the root + // instruction of the module may not reach all instructions in the module. + SourceLocationVisitor visitor{}; + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + if (!visitor.DefaultAction(inst).ok()) { + return {}; + } + } + } + return visitor.LongestSourceLocationPrefix(); +} +} // namespace + +ModuleAnnotation::ModuleAnnotation(std::string_view module_name_) + : title_str_(absl::StrFormat("XlaModule:#hlo_module=%s#", module_name_)), + title_(RegisterString(title_str_)), + module_name_(RegisterString(std::string{module_name_})) {} + +ModuleAnnotation::ModuleAnnotation(const HloModule& mod) + : longest_prefix_(GetLongestOpNamePrefix(mod)), + title_str_(MakeTitle(mod, longest_prefix_)), + title_(RegisterString(title_str_)), + module_name_(RegisterString(mod.name())), + module_id_(mod.unique_id()) { + std::tie(common_src_locations_, common_stack_frames_) = + GetLongestSourceLocationPrefix(mod); +} + +#if GOOGLE_CUDA +namespace { +auto schema_entry(uint64_t type, const char* name, uint64_t offset) { + nvtxPayloadSchemaEntry_t r{}; + r.type = type; + r.name = name; + r.offset = offset; + return r; +} +} // namespace +#endif + +uint64_t ModuleAnnotation::NvtxSchemaId() { + static std::uint64_t schema_id = []() -> std::uint64_t { +#if GOOGLE_CUDA + auto domain = tsl::profiler::DefaultProfilerDomain(); + if (!domain) { + return 0; + } + const nvtxPayloadSchemaEntry_t schema[] = { + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Name", offsetof(ModuleAnnotation, module_name_)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_INT32, "Unique ID", + offsetof(ModuleAnnotation, module_id_)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Common source locations", + offsetof(ModuleAnnotation, common_src_locations_))}; + const nvtxPayloadSchemaAttr_t schemaAttr = { + /* .fieldMask = */ NVTX_PAYLOAD_SCHEMA_ATTR_NAME | + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE, + /* .name = */ "XlaModule", + /* .type = */ NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + /* .flags = */ NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + /* .entries = */ schema, + /* .numEntries = */ sizeof(schema) / sizeof(schema[0]), + /* .payloadStaticSize = */ sizeof(ModuleAnnotation)}; + return RegisterSchema(domain, &schemaAttr); +#else + return 0; +#endif + }(); + return schema_id; +} + +namespace { +std::string MakeKernelName(std::string_view prefix, + const HloInstruction& inst) { + // Sometimes an instruction doesn't have metadata, but the computations that + // it calls do have metadata. Consider all of those metadata op_name entries + // and attach the longest prefix to this launch. + std::string_view op_name = GetLongestOpNamePrefix(inst); + if (op_name.empty()) { + return absl::StrFormat("Thunk:#hlo_op=%s#", inst.name()); + } else if (op_name.substr(0, prefix.size()) != prefix) { + // the op_name we got for this instruction does not start with the prefix + // that we thought was common to all instructions in the module + return absl::StrFormat("Thunk:#name=%s,hlo_op=%s#", op_name, inst.name()); + } else { + // remove the prefix that's in the parent module annotation + auto short_name = op_name.substr(prefix.size()); + // remove the leading / if there is one (prefix might be an empty string) + if (!short_name.empty() && short_name.front() == '/') { + short_name = short_name.substr(1); + } + return absl::StrFormat("Thunk:#name=%s,hlo_op=%s#", short_name, + inst.name()); + } +} +} // namespace + +KernelAnnotation::KernelAnnotation(const ModuleAnnotation& module_annotation, + const HloInstruction& inst) + : title_str( + MakeKernelName(module_annotation.longest_op_name_prefix(), inst)), + title(RegisterString(title_str)), + hlo_dump(RegisterString(InstructionAsString(inst))), + src_locations(RegisterString(FormatSourceLocations( + inst, module_annotation.common_stack_frames()))), + called_hlo_dump(RegisterString("\n" + CalledInstructionsAsString(inst))) { +} + +ModuleAnnotations::ModuleAnnotations(std::string_view module_name) + : top_level(module_name) {} + +uint64_t KernelAnnotation::NvtxSchemaId() { + static std::uint64_t schema_id = []() -> std::uint64_t { +#if GOOGLE_CUDA + auto domain = tsl::profiler::DefaultProfilerDomain(); + if (!domain) { + return 0; + } + const nvtxPayloadSchemaEntry_t schema[] = { + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Source locations", + offsetof(KernelAnnotation, src_locations)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "HLO", offsetof(KernelAnnotation, hlo_dump)), + schema_entry(NVTX_PAYLOAD_ENTRY_TYPE_NVTX_REGISTERED_STRING_HANDLE, + "Called HLO", + offsetof(KernelAnnotation, called_hlo_dump))}; + const nvtxPayloadSchemaAttr_t schemaAttr = { + /* .fieldMask = */ NVTX_PAYLOAD_SCHEMA_ATTR_NAME | + NVTX_PAYLOAD_SCHEMA_ATTR_TYPE | NVTX_PAYLOAD_SCHEMA_ATTR_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_NUM_ENTRIES | + NVTX_PAYLOAD_SCHEMA_ATTR_STATIC_SIZE, + /* .name = */ "XlaKernel", + /* .type = */ NVTX_PAYLOAD_SCHEMA_TYPE_STATIC, + /* .flags = */ NVTX_PAYLOAD_SCHEMA_FLAG_NONE, + /* .entries = */ schema, + /* .numEntries = */ sizeof(schema) / sizeof(schema[0]), + /* .payloadStaticSize = */ sizeof(KernelAnnotation)}; + return RegisterSchema(domain, &schemaAttr); +#else + return 0; +#endif + }(); + return schema_id; +} + +ModuleAnnotations::ModuleAnnotations(const HloModule& mod) : top_level{mod} { + // loop through `mod` and populate `kernels` (string -> KernelAnnotation map) + // with the information we want to attach to individual kernels. + for (const HloComputation* computation : mod.computations()) { + for (const HloInstruction* inst : computation->instructions()) { + // e.g. inst.name is "fusion.6", inst.opcode is "kFusion" and called + // is ["fused_computation.5"], in which case the content of + // "fused_computation.5" ends up under an NVTX range called + // "fusion.6". We want to construct a useful annotation for that NVTX + // range based on the content of `inst`, including `called` etc. + // FIXME: using try_emplace here was sensitive to + // https://github.com/abseil/abseil-cpp/issues/388. + kernels.insert({inst->name(), {top_level, *inst}}); + } + } +} + +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +namespace { +thread_local const ModuleAnnotations* current_annotations = nullptr; +} // namespace + +ScopedModuleAnnotations::ScopedModuleAnnotations( + const ModuleAnnotations* annotations) + : restore_(std::exchange(current_annotations, annotations)) {} + +ScopedModuleAnnotations::~ScopedModuleAnnotations() { + std::exchange(current_annotations, restore_); +} + +const ModuleAnnotations* GetCurrentModuleAnnotations() { + return current_annotations; +} + +std::optional GetKernelAnnotation( + const ModuleAnnotations* annotations, std::string_view profile_annotation) { + if (profile_annotation.empty()) { + return {}; + } + if (annotations) { + // Have a set of pre-prepared thunk/kernel annotations to use + const auto iter = annotations->kernels.find(profile_annotation); + if (iter != annotations->kernels.end()) { + // Have a pre-prepared annotation, use it + return std::optional{[&] { return iter->second; }}; + } + } + return std::optional{ + [&] { return absl::StrFormat("Thunk:#hlo_op=%s#", profile_annotation); }}; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/annotation.h b/xla/service/gpu/runtime/annotation.h new file mode 100644 index 0000000000000..70a4c8df1ddeb --- /dev/null +++ b/xla/service/gpu/runtime/annotation.h @@ -0,0 +1,111 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ +#define XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "tsl/profiler/lib/nvtx_utils.h" +#include "tsl/profiler/lib/scoped_annotation.h" + +namespace xla::gpu { + +// Prepared information for the top level NVTX/profiler range covering an +// HloModule +class ModuleAnnotation { + public: + explicit ModuleAnnotation(std::string_view module_name); + explicit ModuleAnnotation(const HloModule& mod); + + std::string_view longest_op_name_prefix() const { return longest_prefix_; } + explicit operator std::string_view() const { return title_str_; } + tsl::profiler::StringHandle title() const { return title_; } + static uint64_t NvtxSchemaId(); + int32_t common_stack_frames() const { return common_stack_frames_; } + + private: + friend void RangePush(tsl::profiler::ProfilerDomainHandle domain, + const ModuleAnnotation& annotation) { + tsl::profiler::RangePush(domain, annotation.title(), annotation); + } + + std::string longest_prefix_; + std::string title_str_; + tsl::profiler::StringHandle title_; + tsl::profiler::StringHandle module_name_; + tsl::profiler::StringHandle common_src_locations_{}; + int32_t module_id_{-1}; + int32_t common_stack_frames_{}; +}; + +// Prepared information for a kernel/thunk/fusion/... within an HloModule +struct KernelAnnotation { + KernelAnnotation(const ModuleAnnotation& module_annotation, + const HloInstruction& inst); + + explicit operator std::string_view() const { return title_str; } + static uint64_t NvtxSchemaId(); + + private: + friend void RangePush(tsl::profiler::ProfilerDomainHandle domain, + const KernelAnnotation& annotation) { + tsl::profiler::RangePush(domain, annotation.title, annotation); + } + + std::string title_str; + tsl::profiler::StringHandle title; + tsl::profiler::StringHandle hlo_dump; + tsl::profiler::StringHandle src_locations; + tsl::profiler::StringHandle called_hlo_dump; +}; + +// Parsed/prepared information for an HloModule that gets propagated to NVTX +// ranges/profilers/... at execution time. +struct ModuleAnnotations { + explicit ModuleAnnotations(std::string_view module_name); + explicit ModuleAnnotations(const HloModule&); + + ModuleAnnotation top_level; + absl::flat_hash_map kernels; +}; + +//===----------------------------------------------------------------------===// +// Scoped RAII helper to set and restore thread local module annotations +//===----------------------------------------------------------------------===// + +class ScopedModuleAnnotations { + public: + explicit ScopedModuleAnnotations(const ModuleAnnotations* annotations); + ~ScopedModuleAnnotations(); + + private: + const ModuleAnnotations* restore_; +}; + +const ModuleAnnotations* GetCurrentModuleAnnotations(); + +std::optional GetKernelAnnotation( + const ModuleAnnotations* annotations, std::string_view profile_annotation); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_ANNOTATION_H_ diff --git a/xla/service/gpu/runtime/cholesky.cc b/xla/service/gpu/runtime/cholesky.cc deleted file mode 100644 index 05d9a32ec2ad2..0000000000000 --- a/xla/service/gpu/runtime/cholesky.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/cholesky.h" - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/runtime3/cholesky_thunk.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::MemrefView; -using ::xla::runtime::StridedMemrefView; - -static absl::Status CholeskyImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - StridedMemrefView operand, StridedMemrefView a, - MemrefView workspace, MemrefView info, - int64_t batch_size, bool is_lower, int64_t n) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand); - se::DeviceMemoryBase a_buffer = GetDeviceAddress(a); - se::DeviceMemoryBase workspace_buffer = GetDeviceAddress(workspace); - se::DeviceMemoryBase info_buffer = GetDeviceAddress(info); - - VLOG(3) << "Running Cholesky"; - se::Stream* stream = run_options->stream(); - - // Copy operand to the a buffer if they are different. - if (a.data != operand.data) - stream->ThenMemcpy(&a_buffer, operand_buffer, operand_buffer.size()); - - using UpperLower = se::blas::UpperLower; - UpperLower uplo = is_lower ? UpperLower::kLower : UpperLower::kUpper; - - CholeskyParams params{n, batch_size, uplo, - a_buffer, workspace_buffer, info_buffer}; - return RunCholesky(xla::gpu::PtxOptsFromDebugOptions(*debug_options), - operand.dtype, ¶ms, stream); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return absl::InternalError("Cholesky is not supported without GPU"); -#endif -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Cholesky, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.cholesky") - .UserData() - .UserData() - .Arg() // operand - .Arg() // a - .Arg() // workspace - .Arg() // info - .Attr("batch_size") - .Attr("is_lower") - .Attr("n")); - -void RegisterCholeskyCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.cholesky", Cholesky); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/cholesky.h b/xla/service/gpu/runtime/cholesky.h deleted file mode 100644 index 23e8da2019892..0000000000000 --- a/xla/service/gpu/runtime/cholesky.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ -#define XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime cholesky custom calls. -void RegisterCholeskyCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CHOLESKY_H_ diff --git a/xla/service/gpu/runtime/cholesky_thunk.cc b/xla/service/gpu/runtime/cholesky_thunk.cc new file mode 100644 index 0000000000000..b91be4449d3b8 --- /dev/null +++ b/xla/service/gpu/runtime/cholesky_thunk.cc @@ -0,0 +1,176 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/cholesky_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "xla/service/gpu/cusolver_context.h" +#include "xla/service/gpu/make_batch_pointers.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace gpu { + +namespace { + +template +absl::Status DoPotrfBatched(const se::GpuAsmOpts& asm_opts, + CholeskyParams* params, se::Stream* stream, + GpuSolverContext& context) { + T* a_base = static_cast(params->a_buffer.opaque()); + se::DeviceMemory infos(params->info_buffer); +#if TENSORFLOW_USE_ROCSOLVER + // hipsolver is not supported so allocate a GPU buffer + se::ScopedDeviceMemory ptrs = + stream->parent()->AllocateOwnedArray(batch_size_); + auto as = *ptrs; +#else + se::DeviceMemory as(params->workspace_buffer); +#endif + + CHECK_GE(as.size(), params->batch_size); + CHECK_GE(infos.size(), params->batch_size); + + // Run a kernel that sets as[i] = &a_base[i * stride]. + const int64_t stride_bytes = params->n * params->n * sizeof(T); + TF_RETURN_IF_ERROR(MakeBatchPointers( + stream, se::DeviceMemoryBase(a_base), stride_bytes, + static_cast(params->batch_size), se::DeviceMemoryBase(as))); + + // Now that we've set up the `as` array, we can call cusolver. + return context.PotrfBatched(params->uplo, params->n, as, params->n, infos, + params->batch_size); +} + +template +absl::Status DoPotrfUnbatched(const se::GpuAsmOpts& asm_opts, + CholeskyParams* params, se::Stream* stream, + GpuSolverContext& context) { + T* a_base = static_cast(params->a_buffer.opaque()); + int* info_base = static_cast(params->info_buffer.opaque()); + + int64_t stride = params->n * params->n; + for (int64_t i = 0; i < params->batch_size; ++i) { + se::DeviceMemory a_data( + se::DeviceMemoryBase(&a_base[i * stride], sizeof(T) * stride)); + se::DeviceMemory info_data( + se::DeviceMemoryBase(&info_base[i], sizeof(int))); + se::DeviceMemory workspace_data(params->workspace_buffer); + TF_RETURN_IF_ERROR(context.Potrf(params->uplo, params->n, a_data, params->n, + info_data, workspace_data)); + } + return absl::OkStatus(); +} + +} // namespace + +CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, + const CholeskyOptions& options, + const se::GpuAsmOpts asm_opts, + BufferAllocation::Slice a_buffer, + BufferAllocation::Slice workspace_buffer, + BufferAllocation::Slice info_buffer, + PrimitiveType type, int64_t batch_size, int64_t n) + : Thunk(Kind::kCholesky, thunk_info), + asm_opts_(asm_opts), + uplo_(options.lower() ? se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper), + a_buffer_(a_buffer), + workspace_buffer_(workspace_buffer), + info_buffer_(info_buffer), + type_(type), + batch_size_(batch_size), + n_(n) {} + +absl::Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "type=" << PrimitiveType_Name(type_) + << " uplo=" << se::blas::UpperLowerString(uplo_) + << " batch_size=" << batch_size_ << " n=" << n_ + << " a=" << a_buffer_.ToString() + << " workspace=" << workspace_buffer_.ToString() + << " info=" << info_buffer_.ToString(); + + se::DeviceMemoryBase a_buffer = + params.buffer_allocations->GetDeviceAddress(a_buffer_); + se::DeviceMemoryBase info_buffer = + params.buffer_allocations->GetDeviceAddress(info_buffer_); + se::DeviceMemoryBase workspace_buffer = + params.buffer_allocations->GetDeviceAddress(workspace_buffer_); + CholeskyParams cholesky_params{n_, batch_size_, uplo_, + a_buffer, workspace_buffer, info_buffer}; + return RunCholesky(asm_opts_, type_, &cholesky_params, params.stream); +} + +absl::Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, + CholeskyParams* cholesky_params, se::Stream* stream) { + thread_local absl::StatusOr context = + GpuSolverContext::Create(); + TF_RETURN_IF_ERROR(context.status()); + TF_RETURN_IF_ERROR(context->SetStream(stream)); + + if (cholesky_params->batch_size > 1) { + switch (type) { + case F32: + return DoPotrfBatched(asm_opts, cholesky_params, stream, + *context); + case F64: + return DoPotrfBatched(asm_opts, cholesky_params, stream, + *context); + case C64: + return DoPotrfBatched>(asm_opts, cholesky_params, + stream, *context); + case C128: + return DoPotrfBatched>(asm_opts, cholesky_params, + stream, *context); + default: + return InvalidArgument("Invalid type for cholesky %s", + PrimitiveType_Name(type)); + } + } else { + switch (type) { + case F32: + return DoPotrfUnbatched(asm_opts, cholesky_params, stream, + *context); + case F64: + return DoPotrfUnbatched(asm_opts, cholesky_params, stream, + *context); + case C64: + return DoPotrfUnbatched>(asm_opts, cholesky_params, + stream, *context); + case C128: + return DoPotrfUnbatched>(asm_opts, cholesky_params, + stream, *context); + default: + return InvalidArgument("Invalid type for cholesky %s", + PrimitiveType_Name(type)); + } + } +} +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime3/cholesky_thunk.h b/xla/service/gpu/runtime/cholesky_thunk.h similarity index 76% rename from xla/service/gpu/runtime3/cholesky_thunk.h rename to xla/service/gpu/runtime/cholesky_thunk.h index f226a25f7e488..3fdbf3ebc89f9 100644 --- a/xla/service/gpu/runtime3/cholesky_thunk.h +++ b/xla/service/gpu/runtime/cholesky_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ -#include +#include -#include "xla/hlo/ir/hlo_instruction.h" +#include "absl/status/status.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/cusolver_context.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" namespace xla { namespace gpu { @@ -51,7 +48,7 @@ class CholeskyThunk : public Thunk { CholeskyThunk(const CholeskyThunk&) = delete; CholeskyThunk& operator=(const CholeskyThunk&) = delete; - Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: se::GpuAsmOpts asm_opts_; @@ -74,10 +71,10 @@ struct CholeskyParams { se::DeviceMemoryBase workspace_buffer; se::DeviceMemoryBase info_buffer; }; -Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, - CholeskyParams* params, se::Stream* stream); +absl::Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, + CholeskyParams* params, se::Stream* stream); } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME3_CHOLESKY_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_CHOLESKY_THUNK_H_ diff --git a/xla/service/gpu/runtime/collectives.cc b/xla/service/gpu/runtime/collectives.cc index 8c41e50b41f6d..e69de29bb2d1d 100644 --- a/xla/service/gpu/runtime/collectives.cc +++ b/xla/service/gpu/runtime/collectives.cc @@ -1,912 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/collectives.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/computation_placer.h" -#include "xla/service/global_device_id.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/gpu/nccl_all_gather_thunk.h" -#include "xla/service/gpu/nccl_all_reduce_thunk.h" -#include "xla/service/gpu/nccl_all_to_all_thunk.h" -#include "xla/service/gpu/nccl_collective_permute_thunk.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/service/gpu/nccl_recv_thunk.h" -#include "xla/service/gpu/nccl_send_thunk.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/stream.h" - -#if XLA_ENABLE_XCCL -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/service/gpu/runtime/sleep_kernel.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#endif // XLA_ENABLE_XCCL - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::FlatMemrefView; -using xla::runtime::StridedMemrefView; - -namespace { - -Status RunRepeated(int32_t count, absl::FunctionRef to_run) { - if (count != 0) { - VLOG(3) << "Running each collective " << count << " times\n"; - } - for (int32_t i = 0; i < count; ++i) { - TF_RETURN_IF_ERROR(to_run()); - } - return OkStatus(); -} - -// Helper function to run a collective either synchronously on main stream or -// asynchronously on the async stream. -absl::Status RunSyncOrAsync( - const ServiceExecutableRunOptions* run_options, - CollectivesSupport* collectives, AsyncCollectivesSupport* async_collectives, - int32_t uid, bool is_async, - absl::FunctionRef to_run, - AsyncStreamKind stream_kind = kAsyncStreamCollective) { - se::Stream* main_stream = run_options->stream(); - se::Stream* async_stream = - is_async ? async_collectives->async_comm_stream(stream_kind) : nullptr; - if (is_async) { - // Wait until compute inputs are ready. - async_stream->ThenWaitFor(main_stream); - } - - // Launch the collective on either the main or async stream. - se::Stream* stream = is_async ? async_stream : main_stream; - TF_RETURN_IF_ERROR(to_run(stream)); - - if (is_async) { - TF_RETURN_IF_ERROR(async_collectives->RecordEvent(uid, stream_kind)); - } - int32_t device_ordinal = main_stream->parent()->device_ordinal(); - return collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, main_stream); -} - -#if XLA_ENABLE_XCCL -bool ShouldEnableCliqueOptimization(const NcclExecuteParams& params, - const DebugOptions* debug_options, - bool no_parallel_custom_call) { - // Enable clique optimization for single-host application, which is indicated - // by the absence of nccl_unique_id_callback. For multiple-host, only enable - // when a debug flag is set for now, due to some divergent compilation issues. - return no_parallel_custom_call && - (!params.nccl_unique_id_callback || - debug_options->xla_gpu_enable_nccl_clique_optimization()); -} - -StatusOr GetNcclComm( - const NcclExecuteParams& params, int64_t group_mode, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, int64_t stream_id, - bool enable_clique_optimization) { - // TODO(b/233930690): Pass the attribute below as a nested array. - // Pass an array of arrays using two vectors; one specifying all the values - // and another specifying the (ending) offsets of each array in the other - // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into - // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90]. - std::vector replica_groups; - int i = 0; - for (int64_t replica_group_end : replica_group_offsets) { - ReplicaGroup replica_group; - while (i < replica_group_end) - replica_group.add_replica_ids(replica_group_values[i++]); - replica_groups.push_back(replica_group); - } - - return LockNcclComm(params, replica_groups, - static_cast(group_mode), op_id, - stream_id, enable_clique_optimization); -} -#endif // XLA_ENABLE_XCCL - -StatusOr> GetDeviceBufferPairs( - CustomCall::RemainingArgs& args) { - // Add MemRef arguments as buffer arguments. - TF_RET_CHECK(args.size() % 2 == 0); - const int buffer_pairs = args.size() / 2; - std::vector device_buffers; - device_buffers.reserve(buffer_pairs); - for (int i = 0; i < buffer_pairs; ++i) { - auto source = args.get(i); - auto destination = args.get(i + buffer_pairs); - if (failed(source) || failed(destination)) { - return InvalidArgument("Unsupported device buffer pair type"); - } - - int64_t element_count = 1; - for (int64_t size : source->sizes) element_count *= size; - device_buffers.emplace_back(DeviceBufferPair{ - source->dtype, element_count, GetDeviceAddress(*source), - GetDeviceAddress(*destination)}); - } - return device_buffers; -} - -// Expects a single argument, and returns a device buffer pair with that -// argument replicated in both source and destination buffer. -StatusOr> GetSingleArgAsDeviceBufferPair( - CustomCall::RemainingArgs& args) { - TF_RET_CHECK(args.size() == 1); - auto buffer = args.get(0); - if (failed(buffer)) { - return InvalidArgument("Unsupported device buffer type"); - } - int64_t element_count = 1; - for (int64_t size : buffer->sizes) element_count *= size; - return std::vector{ - DeviceBufferPair{buffer->dtype, element_count, GetDeviceAddress(*buffer), - GetDeviceAddress(*buffer)}}; -} - -absl::Status AsyncDoneImpl(const ServiceExecutableRunOptions* run_options, - AsyncCollectivesSupport* async_collectives, - int32_t uid, std::string_view done_type) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running " << done_type; - se::Stream* stream = run_options->stream(); - - TF_ASSIGN_OR_RETURN(se::Event event, async_collectives->PopEvent(uid)); - stream->ThenWaitFor(&event); - - return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -// TODO: shall we use GpuDriver::LaunchKernel() to avoid macros here ? -#if XLA_ENABLE_XCCL -absl::Status NcclMockImplCommon(se::Stream* stream) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define CHK(x) \ - if (auto res = (x); res != gpuSuccess) { \ - return absl::InternalError( \ - absl::StrFormat("Call failed with '%s' at line %d", \ - gpuGetErrorString(res), __LINE__)); \ - } - auto gpu_stream = se::gpu::AsGpuStreamValue(stream); - uint32_t sleep_duration_ns = 1000; - void* kernel = GetSleepKernel(); - dim3 gridDim = {1, 1, 1}; - dim3 blockDim = {512, 1, 1}; - -#if GOOGLE_CUDA - void* kernel_args[] = {&sleep_duration_ns}; -#else - int devID = 0; - hipDeviceProp_t prop{}; - CHK(hipGetDevice(&devID)); - CHK(hipGetDeviceProperties(&prop, devID)); - void* kernel_args[] = {&sleep_duration_ns, &prop.clockRate}; -#endif - CHK(gpuLaunchKernel(kernel, gridDim, blockDim, kernel_args, 0, gpu_stream)); -#undef CHK -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return absl::OkStatus(); -} -#endif // XLA_ENABLE_XCCL - -//===----------------------------------------------------------------------===// -// CollectivePermute. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -using NcclP2PRunner = absl::FunctionRef; - -using DeviceBuffersGetter = - absl::FunctionRef>( - CustomCall::RemainingArgs& args)>; - -absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - se::Stream* stream, CustomCall::RemainingArgs args, - int64_t group_mode, int64_t op_id, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers, - NcclP2PRunner runner, - DeviceBuffersGetter device_buffers_getter, - uint64_t stream_id) { - (void)no_parallel_custom_call; - NcclExecuteParams params(*run_options, stream->parent()); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - const std::string device_string = - NcclCollectiveThunk::GetDeviceString(params); - auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, stream_id, enable_clique_opt); - if (!comm.ok()) return comm.status(); - - auto device_buffers = device_buffers_getter(args); - if (!device_buffers.ok()) return device_buffers.status(); - if (device_buffers->size() != 1) { - return absl::InternalError(absl::StrFormat( - "Expected device buffer size: 1, got %d", device_buffers->size())); - } - - TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, - params.GetGlobalDeviceId()); - - TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID current_logical_id, - params.device_assn->LogicalIdForDevice(global_device_id)); - - int64_t current_id = 0; - switch (static_cast(group_mode)) { - case CollectiveOpGroupMode::kFlattenedID: { - int replica_count = params.device_assn->replica_count(); - int computation_count = params.device_assn->computation_count(); - current_id = current_logical_id.replica_id * computation_count + - current_logical_id.computation_id; - break; - } - case CollectiveOpGroupMode::kCrossReplica: { - current_id = current_logical_id.replica_id; - break; - } - default: { - current_id = current_logical_id.computation_id; - break; - } - } - - NcclP2PConfig::IdToSourceTargetMap id_to_source_target; - for (int i = 0; i < source_peers.size(); ++i) { - id_to_source_target[target_peers[i]].source = source_peers[i]; - id_to_source_target[source_peers[i]].target = target_peers[i]; - } - const NcclP2PConfig::SourceTargetMapEntry source_target = - NcclP2PConfig::GetSourceTarget(id_to_source_target, current_id); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() -> Status { - return runner(source_target, (*device_buffers)[0], *stream, **comm, - device_string, current_id); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status CollectivePermuteImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, - int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running CollectivePermute " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return NcclMockImplCommon(stream); - } - return P2PImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, no_parallel_custom_call, - replica_group_offsets, replica_group_values, - source_peers, target_peers, RunCollectivePermute, - GetDeviceBufferPairs, GetStreamId(is_async)); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CollectivePermute, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.collective_permute") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// Send. -//===----------------------------------------------------------------------===// - -static absl::Status P2PSendImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, - bool is_async, bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running Send"; - TF_RET_CHECK(is_async); - // The scheduler guarantee no_parallel_custom_call for P2P chain, which is not - // reflected in the default value for the attribute. - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - return P2PImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - /*no_parallel_custom_call=*/true, replica_group_offsets, - replica_group_values, source_peers, target_peers, RunSend, - GetSingleArgAsDeviceBufferPair, - GetStreamId(is_async, kAsyncStreamP2P)); - }, - kAsyncStreamP2P); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - P2PSend, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// Recv. -//===----------------------------------------------------------------------===// - -static absl::Status P2PRecvImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, - bool is_async, bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - absl::Span source_peers, - absl::Span target_peers) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running Recv"; - TF_RET_CHECK(is_async); - // The scheduler guarantee no_parallel_custom_call for P2P chain, which is not - // reflected in the default value for the attribute. - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - return P2PImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - /*no_parallel_custom_call=*/true, replica_group_offsets, - replica_group_values, source_peers, target_peers, RunRecv, - GetSingleArgAsDeviceBufferPair, - GetStreamId(is_async, kAsyncStreamP2P)); - }, - kAsyncStreamP2P); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - P2PRecv, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values") - .Attr>("source_peers") - .Attr>("target_peers")); - -//===----------------------------------------------------------------------===// -// AllGather. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status AllGatherImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - NcclExecuteParams params(*run_options, stream->parent()); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), - [&]() { return RunAllGather(device_buffers, *stream, *comm); }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllGatherImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllGather " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return NcclMockImplCommon(stream); - } - return AllGatherImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, replica_group_offsets, - replica_group_values, is_async, - no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL diasbled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllGather, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_gather") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AllReduce. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status AllReduceImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - NcclExecuteParams params(*run_options, stream->parent()); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunAllReduce(static_cast(reduction_kind), - device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllReduceImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, int64_t reduction_kind, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllReduce " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return NcclMockImplCommon(stream); - } - return AllReduceImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, reduction_kind, - replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - // NCCL disabled. - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllReduce, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_reduce") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr("reduction_kind") // ReductionKind - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AllToAll. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status AllToAllImplCommon(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - se::Stream* stream, - CustomCall::RemainingArgs args, - int64_t group_mode, bool has_split_dimension, - int64_t op_id, - absl::Span replica_group_offsets, - absl::Span replica_group_values, - bool is_async, bool no_parallel_custom_call) { - NcclExecuteParams params(*run_options, stream->parent()); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunAllToAll(has_split_dimension, device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status AllToAllImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, bool has_split_dimension, - int64_t op_id, bool is_async, - bool no_parallel_custom_call, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running AllToAll " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return NcclMockImplCommon(stream); - } - return AllToAllImplCommon(run_options, debug_options, stream, args, - group_mode, has_split_dimension, op_id, - replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AllToAll, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.all_to_all") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("has_split_dimension") - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// ReduceScatter. -//===----------------------------------------------------------------------===// - -#if XLA_ENABLE_XCCL -absl::Status ReduceScatterImplCommon( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, se::Stream* stream, - CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, - int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async, - bool no_parallel_custom_call) { - NcclExecuteParams params(*run_options, stream->parent()); - bool enable_clique_opt = ShouldEnableCliqueOptimization( - params, debug_options, no_parallel_custom_call); - - TF_ASSIGN_OR_RETURN( - auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, - replica_group_values, GetStreamId(is_async), - enable_clique_opt)); - - TF_ASSIGN_OR_RETURN(auto device_buffers, GetDeviceBufferPairs(args)); - - return RunRepeated( - debug_options->xla_gpu_collective_inflation_factor(), [&]() { - return RunReduceScatter(static_cast(reduction_kind), - device_buffers, *stream, *comm); - }); -} -#endif // XLA_ENABLE_XCCL - -absl::Status ReduceScatterImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CollectivesSupport* collectives, - AsyncCollectivesSupport* async_collectives, - CustomCall::RemainingArgs args, int32_t uid, - int64_t group_mode, int64_t op_id, bool is_async, - bool no_parallel_custom_call, - int64_t reduction_kind, - absl::Span replica_group_offsets, - absl::Span replica_group_values) { -#if XLA_ENABLE_XCCL - VLOG(3) << "Running ReduceScatter " << (is_async ? "(Async) " : "(Sync) ") - << no_parallel_custom_call; - return RunSyncOrAsync( - run_options, collectives, async_collectives, uid, is_async, - [&](se::Stream* stream) { - const gpu::GpuExecutableRunOptions* gpu_opts = - run_options->run_options().gpu_executable_run_options(); - if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { - return NcclMockImplCommon(stream); - } - return ReduceScatterImplCommon( - run_options, debug_options, stream, args, group_mode, op_id, - reduction_kind, replica_group_offsets, replica_group_values, - is_async, no_parallel_custom_call); - }); -#else // XLA_ENABLE_XCCL - return absl::InternalError("NCCL disabled"); -#endif // XLA_ENABLE_XCCL -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ReduceScatter, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.reduce_scatter") - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() // args - .Attr("uid") - .Attr("group_mode") // CollectiveOpGroupMode - .Attr("op_id") - .Attr("is_async") - .Attr("no_parallel_custom_call") - .Attr("reduction_kind") // ReductionKind - .Attr>("replica_group_offsets") - .Attr>("replica_group_values")); - -//===----------------------------------------------------------------------===// -// AsyncDone. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - AsyncDone, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.async_collective_done") - .UserData() - .UserData() - .Attr("uid") - .Attr("done_type")); - -//===----------------------------------------------------------------------===// -// ReplicaId. -//===----------------------------------------------------------------------===// - -absl::Status ReplicaPartitionIdImpl( - const ServiceExecutableRunOptions* run_options, FlatMemrefView result, - bool is_replica_id) { - VLOG(3) << "Running " << (is_replica_id ? "ReplicaId" : "PartitionId"); - se::Stream* stream = run_options->stream(); - NcclExecuteParams params(*run_options, stream->parent()); - - TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, - params.GetGlobalDeviceId()); - - TF_ASSIGN_OR_RETURN(DeviceAssignment::LogicalID logical_id, - params.device_assn->LogicalIdForDevice(global_device_id)); - - se::DeviceMemoryBase result_data = GetDeviceAddress(result); - const uint32_t id = - is_replica_id ? logical_id.replica_id : logical_id.computation_id; - stream->ThenMemset32(&result_data, id, /*size=*/4); - return absl::OkStatus(); -} - -absl::Status ReplicaIdImpl(const ServiceExecutableRunOptions* run_options, - FlatMemrefView result) { - return ReplicaPartitionIdImpl(run_options, result, /*is_replica_id=*/true); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ReplicaId, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.replica_id") - .UserData() - .Arg()); - -//===----------------------------------------------------------------------===// -// PartitionId. -//===----------------------------------------------------------------------===// - -absl::Status PartitionIdImpl(const ServiceExecutableRunOptions* run_options, - FlatMemrefView result) { - return ReplicaPartitionIdImpl(run_options, result, /*is_replica_id=*/false); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - PartitionId, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.partition_id") - .UserData() - .Arg()); - -//===----------------------------------------------------------------------===// - -int64_t Key(int32_t uid, int32_t device_ordinal) { - return static_cast(uid) << 32 | device_ordinal; -} - -} // namespace - -//===----------------------------------------------------------------------===// -// Collectives support library. -//===----------------------------------------------------------------------===// - -absl::Status CollectivesSupport::MaybeBlockAfterFirstRun(int32_t uid, - int32_t device_ordinal, - se::Stream* stream) { - bool block = [&] { - absl::MutexLock lock(&mutex_); - return executed_.insert(Key(uid, device_ordinal)).second; - }(); - return block ? stream->BlockHostUntilDone() : absl::OkStatus(); -} - -AsyncCollectivesSupport::AsyncCollectivesSupport( - absl::Span async_streams) - : async_comm_streams_(async_streams.begin(), async_streams.end()) {} - -absl::Status AsyncCollectivesSupport::RecordEvent( - int32_t uid, gpu::AsyncStreamKind async_stream_kind) { - // Create an event on the async stream for the completion of the collective. - se::Event done_event(async_comm_stream(async_stream_kind)->parent()); - if (!done_event.Init()) return absl::InternalError("Failed to create event"); - async_comm_stream(async_stream_kind)->ThenRecordEvent(&done_event); - - absl::MutexLock lock(&mutex_); - auto [_, was_inserted] = done_events_.insert({uid, std::move(done_event)}); - if (!was_inserted) { - return absl::InternalError(absl::StrFormat( - "Async done event has not been consumed (uid=%d, device_ordinal=%d)", - uid, async_comm_stream(async_stream_kind)->parent()->device_ordinal())); - } - return absl::OkStatus(); -} - -absl::StatusOr AsyncCollectivesSupport::PopEvent(int32_t uid) { - absl::MutexLock lock(&mutex_); - auto done_event = done_events_.extract(uid); - if (!done_event) { - return absl::InternalError( - absl::StrFormat("Async done event was not found (uid=%d)", uid)); - } - return std::move(done_event.mapped()); -} - -void RegisterCollectiveCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.collective_permute", CollectivePermute); - registry.Register("xla.gpu.send", P2PSend); - registry.Register("xla.gpu.recv", P2PRecv); - registry.Register("xla.gpu.all_gather", AllGather); - registry.Register("xla.gpu.all_reduce", AllReduce); - registry.Register("xla.gpu.all_to_all", AllToAll); - registry.Register("xla.gpu.reduce_scatter", ReduceScatter); - - registry.Register("xla.gpu.collective_done", AsyncDone); - - registry.Register("xla.gpu.partition_id", PartitionId); - registry.Register("xla.gpu.replica_id", ReplicaId); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/collectives.h b/xla/service/gpu/runtime/collectives.h deleted file mode 100644 index 23bd700cc151f..0000000000000 --- a/xla/service/gpu/runtime/collectives.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ -#define XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/stream_executor/event.h" - -namespace xla { -namespace gpu { - -// Support for running async collective operations communicating via events. -// Registers XLA Gpu runtime collective custom calls. -void RegisterCollectiveCustomCalls(runtime::DirectCustomCallRegistry& registry); - -class CollectivesSupport { - public: - // Maybe block host after the first call to the collective operation with the - // given uid, to ensure that all devices have allocated the required buffers - // for their communicators before allowing any device to continue enqueuing - // operations. Otherwise, the allocations can cause deadlock in the CUDA - // driver. - // - // This basically ports workaround from cr/435058849 to Xla runtime (see - // details in the b/215649390). - absl::Status MaybeBlockAfterFirstRun(int32_t uid, int32_t device_ordinal, - se::Stream* stream); - - private: - absl::Mutex mutex_; - - // Store if a particular collective operation was executed at least once. We - // rely on unique `uid` assigned to each collective operation by the lowering - // pass. - absl::flat_hash_set executed_ ABSL_GUARDED_BY(mutex_); -}; - -// Support for running async collective operations communicating via events. -class AsyncCollectivesSupport { - public: - explicit AsyncCollectivesSupport(absl::Span async_streams); - - absl::Status RecordEvent(int32_t uid, AsyncStreamKind async_stream_kind); - absl::StatusOr PopEvent(int32_t uid); - - se::Stream* async_comm_stream(AsyncStreamKind async_stream_kind) const { - return async_comm_streams_[async_stream_kind]; - } - - private: - absl::Mutex mutex_; - absl::InlinedVector async_comm_streams_; - - // Store done events for the Done ops to wait upon. - absl::flat_hash_map done_events_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ diff --git a/xla/service/gpu/runtime/command_buffer_allocations.cc b/xla/service/gpu/runtime/command_buffer_allocations.cc new file mode 100644 index 0000000000000..a8ac270ff8d2e --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_allocations.cc @@ -0,0 +1,64 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_allocations.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/service/buffer_assignment.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" + +namespace xla::gpu { + +absl::StatusOr CommandBufferAllocations::GetDeviceAddress( + BufferAllocation::Index index) const { + auto base = allocs_.find(index); + if (base == allocs_.end()) { + return absl::InternalError(absl::StrCat("Command buffer allocation #", + index, " was not allocated")); + } + return allocs_.at(index); +} + +absl::Status CommandBufferAllocations::AddAllocation( + BufferAllocation::Index index, se::DeviceMemoryBase memory) { + VLOG(2) << "Add comand buffer allocation: index=" << index + << "; ptr=" << memory.opaque(); + + auto emplaced = allocs_.try_emplace(index, std::move(memory)); + if (emplaced.second == false) { + return absl::InternalError(absl::StrCat("Command buffer allocation #", + index, " was already allocated")); + } + return absl::OkStatus(); +} + +absl::Status CommandBufferAllocations::EraseAllocation( + BufferAllocation::Index index) { + VLOG(2) << "Erase comand buffer allocation: index=" << index; + + if (allocs_.erase(index) == 0) { + return absl::InternalError(absl::StrCat("Command buffer allocation #", + index, " was not allocated")); + } + return absl::OkStatus(); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/command_buffer_allocations.h b/xla/service/gpu/runtime/command_buffer_allocations.h new file mode 100644 index 0000000000000..c4257d0d92fc0 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_allocations.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ + +#include "absl/container/flat_hash_map.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" + +namespace xla::gpu { + +// Command buffer allocations tracks external buffer allocations done via the +// CommandBuffer API and owned by the XLA executable (via instantiated command +// buffers and memory allocation Gpu graph nodes). +class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { + public: + absl::StatusOr GetDeviceAddress( + BufferAllocation::Index index) const override; + + // Adds an external allocation for a given buffer index. Returns error if + // allocation already exists. + absl::Status AddAllocation(BufferAllocation::Index index, + se::DeviceMemoryBase memory) override; + + // Erases an external allocation for a given buffer index. Returns error if + // allocation does not exists. + absl::Status EraseAllocation(BufferAllocation::Index index) override; + + private: + absl::flat_hash_map allocs_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_ALLOCATIONS_H_ diff --git a/xla/service/gpu/runtime/command_buffer_cmd.cc b/xla/service/gpu/runtime/command_buffer_cmd.cc new file mode 100644 index 0000000000000..7f307a0da7442 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -0,0 +1,1678 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_cmd.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/executable_run_options.h" +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" // IWYU pragma: keep +#include "xla/util.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/scoped_annotation.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_status_internal.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#endif + +namespace xla::gpu { + +using ExecutionScopeId = se::CommandBuffer::ExecutionScopeId; +using MemoryAccess = CommandBufferCmd::MemoryAccess; + +static std::string_view ReductionKindString(ReductionKind kind) { + switch (kind) { + case ReductionKind::MAX: + return "max"; + case ReductionKind::MIN: + return "min"; + case ReductionKind::PRODUCT: + return "product"; + case ReductionKind::SUM: + return "sum"; + } +} + +// Creates command buffer builder from a cmd sequence. +static se::CommandBuffer::Builder CreateBuilder( + CommandBufferCmdSequence* commands, + const Thunk::ExecuteParams* execute_params, + const CommandBufferCmd::RecordParams* record_params) { + return [=](se::CommandBuffer* command_buffer) { + return commands->Record(*execute_params, *record_params, command_buffer, + CommandBufferCmdSequence::RecordMode::kConditional); + }; +} + +// Creates command buffer builders from a span of cmd sequences. +static std::vector CreateBuilders( + absl::Span commands, + const Thunk::ExecuteParams* execute_params, + const CommandBufferCmd::RecordParams* record_params) { + std::vector builders; + for (CommandBufferCmdSequence& cmd : commands) { + builders.push_back(CreateBuilder(&cmd, execute_params, record_params)); + } + return builders; +} + +// Creates command buffer execution scope builder from a cmd sequence. +static se::CommandBuffer::ExecutionScopeBuilder CreateExecutionScopeBuilder( + CommandBufferCmdSequence* commands, + const Thunk::ExecuteParams* execute_params, + const CommandBufferCmd::RecordParams* record_params) { + return [=](ExecutionScopeId id, se::CommandBuffer* command_buffer) { + CommandBufferCmd::RecordParams params = *record_params; + params.execution_scope_id = id; + return commands->Record(*execute_params, params, command_buffer, + CommandBufferCmdSequence::RecordMode::kConditional); + }; +} + +//===----------------------------------------------------------------------===// +// CommandBufferCmd +//===----------------------------------------------------------------------===// + +CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrNull( + const CommandBufferCmd* cmd) { + if (auto it = state_.find(cmd); it != state_.end()) { + return it->second.get(); + } + return nullptr; +} + +CommandBufferCmd::State* CommandBufferCmd::StateManager::GetOrCreate( + const CommandBufferCmd* cmd, + absl::FunctionRef()> create) { + if (auto it = state_.find(cmd); it != state_.end()) { + return it->second.get(); + } + return state_.try_emplace(cmd, create()).first->second.get(); +} + +se::CommandBuffer::ExecutionScopeId CommandBufferCmd::GetExecutionScope( + const RecordParams& record_params) const { + int64_t base = record_params.execution_scope_id.value(); + int64_t offset = execution_stream_id_.value(); + return se::CommandBuffer::ExecutionScopeId(base + offset); +} + +//===----------------------------------------------------------------------===// +// CommandBufferCmdSequence +//===----------------------------------------------------------------------===// + +CommandBufferCmdSequence::CommandBufferCmdSequence( + SynchronizationMode synchronization_mode) + : synchronization_mode_(synchronization_mode) {} + +void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { + for (const CommandBufferCmd::BufferUsage& buffer : cmd->buffers()) { + buffers_.insert(buffer); + allocs_indices_.insert(buffer.slice.index()); + } + + ExecutionStreamId execution_stream_id = cmd->execution_stream_id(); + CommandBufferCmd::BufferUsageVector buffers = cmd->buffers(); + bool requires_barrier = HasConflicts(execution_stream_id, buffers); + + // Always add barriers between commands if we want to serialize execution. + if (synchronization_mode_ == SynchronizationMode::kSerialize && + !commands_.empty()) { + requires_barrier = true; + } + + // If the first recorded command is implemented as a nested command buffer we + // force a barrier before recording the next command as a workaround for CUDA + // graph bug, where child CUDA graph must be a single CUDA graph root node. + if (commands_.size() == 1 && commands_.front().cmd->IsNestedCommandBuffer()) { + requires_barrier = true; + } + + if (requires_barrier) ClearTrackedBuffers(execution_stream_id); + + commands_.push_back({std::move(cmd), requires_barrier}); + TrackBuffers(execution_stream_id, buffers); +} + +absl::Status CommandBufferCmdSequence::Prepare( + const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) { + for (auto& command : commands_) { + TF_RETURN_IF_ERROR(command.cmd->Prepare(params, resource_requests)); + } + return absl::OkStatus(); +} + +absl::Status CommandBufferCmdSequence::Initialize( + const Thunk::InitializeParams& params, + CommandBufferCmd::StateManager& state) { + for (auto& command : commands_) { + TF_RETURN_IF_ERROR(command.cmd->Initialize(params, state)); + } + return absl::OkStatus(); +} + +bool CommandBufferCmdSequence::HasConflicts( + ExecutionStreamId execution_stream_id, + const CommandBufferCmd::BufferUsageVector& buffers) { + auto& rwset = read_write_sets_[execution_stream_id]; + + // Returns true if slice overlaps with any of the slices in read set. + auto read_overlap = [&](const BufferAllocation::Slice& slice) { + if (rwset.read.contains(slice)) return true; + for (auto& read : rwset.read) + if (read.OverlapsWith(slice)) return true; + return false; + }; + + // Returns true if slice overlaps with any of the slices in write set. + auto write_overlap = [&](const BufferAllocation::Slice& slice) { + if (rwset.write.contains(slice)) return true; + for (auto& write : rwset.write) + if (write.OverlapsWith(slice)) return true; + return false; + }; + + return absl::c_any_of(buffers, [&](const auto& buffer) { + return buffer.access == MemoryAccess::kWrite + ? write_overlap(buffer.slice) || read_overlap(buffer.slice) + : write_overlap(buffer.slice); + }); +} + +void CommandBufferCmdSequence::TrackBuffers( + ExecutionStreamId execution_stream_id, + const CommandBufferCmd::BufferUsageVector& buffers) { + auto& rwset = read_write_sets_[execution_stream_id]; + for (const CommandBufferCmd::BufferUsage& buffer : buffers) { + if (buffer.access == MemoryAccess::kWrite) rwset.write.insert(buffer.slice); + if (buffer.access == MemoryAccess::kRead) rwset.read.insert(buffer.slice); + } +} + +void CommandBufferCmdSequence::ClearTrackedBuffers( + ExecutionStreamId execution_stream_id) { + read_write_sets_[execution_stream_id] = ReadWriteSet(); +} + +static std::string_view RecordModeString( + CommandBufferCmdSequence::RecordMode mode) { + switch (mode) { + case CommandBufferCmdSequence::RecordMode::kExclusive: + return "exclusive"; + case CommandBufferCmdSequence::RecordMode::kConditional: + return "conditional"; + } +} + +absl::Status CommandBufferCmdSequence::Record( + const Thunk::ExecuteParams& execute_params, + const CommandBufferCmd::RecordParams& record_params, + se::CommandBuffer* command_buffer, RecordMode mode) { + VLOG(3) << "Record " << commands_.size() << " commands into command buffer" + << "; mode=" << RecordModeString(mode); + uint64_t start_micros = tsl::Env::Default()->NowMicros(); + + if (mode == RecordMode::kExclusive) { + if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { + TF_RETURN_IF_ERROR(command_buffer->Update()); + } + } + + se::StreamExecutor* device = execute_params.stream->parent(); + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); + + // Track the number of commands recorded between barriers. + absl::flat_hash_map num_recorded_commands; + + for (auto& command : commands_) { + ExecutionScopeId execution_scope_id = + command.cmd->GetExecutionScope(record_params); + std::optional annotation = + GetKernelAnnotation(annotations, command.cmd->profile_annotation()); + + if (command.requires_barrier) { + VLOG(3) << "Add command buffer barrier after " + << num_recorded_commands[execution_scope_id] + << " recorded commands into the execution scope #" + << execution_scope_id.value(); + TF_RETURN_IF_ERROR(command_buffer->Barrier(device, execution_scope_id)); + num_recorded_commands.erase(execution_scope_id); + } + + TF_RETURN_IF_ERROR( + command.cmd->Record(execute_params, record_params, command_buffer)); + ++num_recorded_commands[execution_scope_id]; + } + + if (mode == RecordMode::kExclusive) { + TF_RETURN_IF_ERROR(command_buffer->Finalize()); + } + + uint64_t end_micros = tsl::Env::Default()->NowMicros(); + VLOG(3) << "Recorded " << commands_.size() + << " commands into command buffer in " << (end_micros - start_micros) + << " μs; mode=" << RecordModeString(mode); + + return absl::OkStatus(); +} + +const absl::flat_hash_set& +CommandBufferCmdSequence::buffers() const { + return buffers_; +} + +const absl::flat_hash_set& +CommandBufferCmdSequence::allocs_indices() const { + return allocs_indices_; +} + +std::vector CommandBufferCmdSequence::barriers() const { + std::vector barriers; + absl::c_transform(commands_, std::back_inserter(barriers), + [](auto& command) { return command.requires_barrier; }); + return barriers; +} + +//===----------------------------------------------------------------------===// +// TracedCommandBuffer +//===----------------------------------------------------------------------===// + +TracedCommandBuffer::TracedCommandBuffer( + CommandBufferCmd::BufferUsageVector buffers, int64_t capacity) + : capacity_(capacity), entries_(capacity) { + CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT + // Collect unique buffer allocation indices in a set first and convert to + // vector as flat hash set iteration has measurable overheads. + absl::flat_hash_set allocs_indices; + for (auto& buffer : buffers) allocs_indices.insert(buffer.slice.index()); + allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); +} + +absl::StatusOr TracedCommandBuffer::GetOrTraceCommandBuffer( + const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, + se::Stream* stream, absl::FunctionRef trace) { + // Collect memory addresses for relevant allocations. + absl::InlinedVector allocs; + allocs.reserve(allocs_indices_.size()); + for (auto& index : allocs_indices_) { + allocs.emplace_back(buffer_allocation->GetDeviceAddress(index)); + } + + // Moves entry at `i` position to front and moves entries in `[0, i)` range + // one element to the right. Returns reference to the first entry. + auto shift_right = [&](size_t i) -> Entry& { + if (i == 0) return entries_[0]; + + Entry entry = std::move(entries_[i]); + do { + entries_[i] = std::move(entries_[i - 1]); + } while (--i > 0); + + return entries_[0] = std::move(entry); + }; + + for (size_t i = 0; i < capacity_; ++i) { + // Found entry for a given allocations, move it to front and return a + // pointer to cached command buffer. + if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) && + entries_[i].command_buffer)) { + return shift_right(i).command_buffer.get(); + } + + // Create a new entry by calling a user-provided tracing function, move it + // to front and return a pointer to cached command buffer. + if (entries_[i].command_buffer == nullptr) { + TF_ASSIGN_OR_RETURN(entries_[i].command_buffer, + se::CommandBuffer::Trace(executor, stream, trace)); + entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); + return shift_right(i).command_buffer.get(); + } + } + + // Create a new entry by calling a user-provided tracing function, replace the + // last entry with it, move it to front and return a pointer to cached command + // buffer. + TF_ASSIGN_OR_RETURN(entries_[capacity_ - 1].command_buffer, + se::CommandBuffer::Trace(executor, stream, trace)); + entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end()); + return shift_right(capacity_ - 1).command_buffer.get(); +} + +//===----------------------------------------------------------------------===// +// TracedCommandBufferCmd +//===----------------------------------------------------------------------===// + +TracedCommandBufferCmd::TracedCommandBufferCmd( + ExecutionStreamId execution_stream_id) + : CommandBufferCmd(execution_stream_id) {} + +absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer, + absl::FunctionRef trace) { + auto traced_cmd = record_params.state.GetOrCreate( + this, [&] { return std::make_unique(buffers()); }); + + TF_ASSIGN_OR_RETURN( + auto nested_cmd, + traced_cmd->GetOrTraceCommandBuffer( + execute_params.buffer_allocations, execute_params.stream->parent(), + execute_params.command_buffer_trace_stream, trace)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "Add nested command buffer to execution scope: " + << execution_scope_id.value(); + return command_buffer->AddNestedCommandBuffer(execution_scope_id, + *nested_cmd); +} + +//===----------------------------------------------------------------------===// +// ComputationId +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): PTX kernel should be replaced with CUDA C++ kernel but +// today we accidentally try to build them without CUDA support. We need to +// clean our build and testing infrastructure first. + +// PTX kernel compiled from: +// +// __global__ void memset32(int64_t n, uint32_t value, uint32_t* dst) +// { +// int i = blockIdx.x*blockDim.x + threadIdx.x; +// if (i < n) dst[i] = value; +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +inline constexpr std::string_view kMemset32Kernel = R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.visible .entry memset32( + .param .u64 memset32_param_0, + .param .u32 memset32_param_1, + .param .u64 memset32_param_2 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<6>; + .reg .b64 %rd<7>; + .loc 1 3 0 + + ld.param.u64 %rd3, [memset32_param_0]; + ld.param.u32 %r1, [memset32_param_1]; + ld.param.u64 %rd2, [memset32_param_2]; + .loc 1 5 3 + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r5, %r2, %r3, %r4; + .loc 1 6 3 + cvt.s64.s32 %rd1, %r5; + setp.ge.s64 %p1, %rd1, %rd3; + @%p1 bra $L__BB0_2; + + .loc 1 5 3 + cvta.to.global.u64 %rd4, %rd2; + .loc 1 6 3 + shl.b64 %rd5, %rd1, 2; + add.s64 %rd6, %rd4, %rd5; + st.global.u32 [%rd6], %r1; + +$L__BB0_2: + .loc 1 7 1 + ret; + +})"; + +ComputationIdCmd::ComputationIdCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dest, Kind kind) + : CommandBufferCmd(execution_stream_id), dest_(dest), kind_(kind) {} + +CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { + return {{dest_, MemoryAccess::kWrite}}; +} + +absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { +#if defined(GOOGLE_CUDA) + { + absl::MutexLock lock(&mutex_); + if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + CreateKernel("memset32", 3, kMemset32Kernel, + /*cubin_data=*/{}, params.executor, + /*shared_mem_bytes=*/0)); + + absl::MutexLock lock(&mutex_); + memset_kernels_.emplace(params.executor, std::move(kernel)); +#endif // GOOGLE_CUDA + return absl::OkStatus(); +} + +absl::Status ComputationIdCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase dst = + execute_params.buffer_allocations->GetDeviceAddress(dest_); + + GlobalDeviceId global_device_id = + execute_params.collective_params->global_device_id; + TF_ASSIGN_OR_RETURN( + const DeviceAssignment::LogicalID logical_id, + execute_params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); + + uint32_t value = kind_ == Kind::kReplica ? logical_id.replica_id + : logical_id.computation_id; + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "ComputationIdCmd" + << ": kind=" << (kind_ == Kind::kReplica ? "replica" : "partition") + << "; value=" << value + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; + +#if defined(GOOGLE_CUDA) + se::Kernel* memset_kernel = [&] { + absl::MutexLock lock(&mutex_); + return memset_kernels_[execute_params.stream->parent()].get(); + }(); + + if (memset_kernel == nullptr) { + return absl::InternalError( + "Memset kernel not loaded on a command buffer executor"); + } + + auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); + return command_buffer->Launch(execution_scope_id, se::ThreadDim(1), + se::BlockDim(1), *memset_kernel, *args); +#else + return command_buffer->Memset(execution_scope_id, &dst, value, + /*num_elements=*/1); +#endif // GOOGLE_CUDA +} + +//===----------------------------------------------------------------------===// +// LaunchCmd +//===----------------------------------------------------------------------===// + +LaunchCmd::LaunchCmd(ExecutionStreamId execution_stream_id, + std::string kernel_name, + absl::Span args, + absl::Span args_access, + LaunchDimensions dims, int64_t shmem_bytes) + : CommandBufferCmd(execution_stream_id), + kernel_name_(std::move(kernel_name)), + args_(args.begin(), args.end()), + args_access_(args_access.begin(), args_access.end()), + dims_(dims), + shmem_bytes_(shmem_bytes) {} + +absl::Status LaunchCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + { + absl::MutexLock lock(&mutex_); + if (kernels_.contains(params.executor)) return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + CreateKernel(kernel_name_, args_.size(), params.src.text, + params.src.binary, params.executor, shmem_bytes_)); + + absl::MutexLock lock(&mutex_); + kernels_.emplace(params.executor, std::move(kernel)); + return absl::OkStatus(); +} + +absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ + << "; shmem_bytes=" << shmem_bytes_ + << "; execution_scope_id=" << execution_scope_id.value(); + + se::Kernel* kernel = [&] { + absl::MutexLock lock(&mutex_); + return kernels_[execute_params.stream->parent()].get(); + }(); + + if (kernel == nullptr) { + return absl::InternalError(absl::StrCat( + "Kernel not loaded on a command buffer executor: ", kernel_name_)); + } + + absl::InlinedVector buffers; + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = + execute_params.buffer_allocations->GetDeviceAddress(arg); + VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); + buffers.push_back(buf); + } + + TF_ASSIGN_OR_RETURN(auto kernel_args, + se::PackKernelArgs(buffers, shmem_bytes_)); + + return command_buffer->Launch(execution_scope_id, + dims_.thread_counts_per_block(), + dims_.block_counts(), *kernel, *kernel_args); +} + +CommandBufferCmd::BufferUsageVector LaunchCmd::buffers() { + BufferUsageVector buffers; + for (int32_t i = 0; i < args_.size(); ++i) { + buffers.emplace_back(args_[i], args_access_[i]); + } + return buffers; +} + +//===----------------------------------------------------------------------===// +// CustomKernelLaunchCmd +//===----------------------------------------------------------------------===// + +CustomKernelLaunchCmd::CustomKernelLaunchCmd( + ExecutionStreamId execution_stream_id, + absl::Span args, + absl::Span args_access, CustomKernel custom_kernel) + : CommandBufferCmd(execution_stream_id), + args_(args.begin(), args.end()), + args_access_(args_access.begin(), args_access.end()), + custom_kernel_(std::move(custom_kernel)) {} + +absl::Status CustomKernelLaunchCmd::Initialize( + const Thunk::InitializeParams& params, StateManager& state) { + { + absl::MutexLock lock(&mutex_); + if (kernels_.contains(params.executor)) return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + se::Kernel::Create(params.executor, custom_kernel_.kernel_spec())); + + absl::MutexLock lock(&mutex_); + kernels_.emplace(params.executor, std::move(kernel)); + return absl::OkStatus(); +} + +absl::Status CustomKernelLaunchCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "CustomKernelLaunchCmd: custom_kernel=" << custom_kernel_.name() + << "; execution_scope_id=" << execution_scope_id.value(); + + se::Kernel* kernel = [&] { + absl::MutexLock lock(&mutex_); + return kernels_[execute_params.stream->parent()].get(); + }(); + + if (kernel == nullptr) { + return absl::InternalError( + absl::StrCat("Custom kernel not loaded on a command buffer executor: ", + custom_kernel_.name())); + } + + absl::InlinedVector buffers; + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = + execute_params.buffer_allocations->GetDeviceAddress(arg); + VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); + buffers.push_back(buf); + } + + se::KernelArgsDeviceMemoryArray kernel_args( + buffers, custom_kernel_.shared_memory_bytes()); + + return command_buffer->Launch( + execution_scope_id, custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, kernel_args); +} + +CommandBufferCmd::BufferUsageVector CustomKernelLaunchCmd::buffers() { + BufferUsageVector buffers; + for (int32_t i = 0; i < args_.size(); ++i) { + buffers.emplace_back(args_[i], args_access_[i]); + } + return buffers; +} + +//===----------------------------------------------------------------------===// +// MemcpyDeviceToDeviceCmd +//===----------------------------------------------------------------------===// + +MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd( + ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst, + BufferAllocation::Slice src, int64_t num_bytes) + : CommandBufferCmd(execution_stream_id), + dst_(dst), + src_(src), + num_bytes_(num_bytes) {} + +absl::Status MemcpyDeviceToDeviceCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase dst = + execute_params.buffer_allocations->GetDeviceAddress(dst_); + se::DeviceMemoryBase src = + execute_params.buffer_allocations->GetDeviceAddress(src_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "MemcpyDeviceToDeviceCmd: num_bytes = " << num_bytes_ + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " Dst: " << dst_ << " (" << dst.opaque() << ")"; + VLOG(5) << " Src: " << src_ << " (" << src.opaque() << ")"; + + if (num_bytes_ == 0) { + VLOG(5) << "Skip recording MemcpyDeviceToDeviceCmd command of 0 bytes"; + return absl::OkStatus(); + } + + return command_buffer->MemcpyDeviceToDevice(execution_scope_id, &dst, src, + num_bytes_); +} + +CommandBufferCmd::BufferUsageVector MemcpyDeviceToDeviceCmd::buffers() { + return {{dst_, MemoryAccess::kWrite}, {src_, MemoryAccess::kRead}}; +} + +//===----------------------------------------------------------------------===// +// MemzeroCmd +//===----------------------------------------------------------------------===// + +MemzeroCmd::MemzeroCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dst) + : CommandBufferCmd(execution_stream_id), dst_(dst) {} + +absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase dst = + execute_params.buffer_allocations->GetDeviceAddress(dst_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "MemzeroCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " Dst: " << dst_ << " (" << dst.opaque() << ")"; + + if (dst_.size() == 0) { + VLOG(5) << "Skip recording MemzeroCmd command of 0 bytes"; + return absl::OkStatus(); + } + + return command_buffer->Memset(execution_scope_id, &dst, uint8_t{0}, + /*num_elements=*/dst_.size()); +} + +CommandBufferCmd::BufferUsageVector MemzeroCmd::buffers() { + return {{dst_, MemoryAccess::kWrite}}; +} + +//===----------------------------------------------------------------------===// +// Memset32Cmd +//===----------------------------------------------------------------------===// + +Memset32Cmd::Memset32Cmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dst, uint32_t bit_pattern) + : CommandBufferCmd(execution_stream_id), + dst_(dst), + bit_pattern_(bit_pattern) {} + +absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase dst = + execute_params.buffer_allocations->GetDeviceAddress(dst_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "Memset32Cmd: bit_pattern=" << bit_pattern_ + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " Dst: " << dst_ << " (" << dst.opaque() << ")"; + + if (dst_.size() == 0) { + VLOG(5) << "Skip recording Memset32Cmd command of 0 bytes"; + return absl::OkStatus(); + } + + return command_buffer->Memset( + execution_scope_id, &dst, bit_pattern_, + /*num_elements=*/dst_.size() / sizeof(uint32_t)); +} + +CommandBufferCmd::BufferUsageVector Memset32Cmd::buffers() { + return {{dst_, MemoryAccess::kWrite}}; +} + +//===----------------------------------------------------------------------===// +// IfCmd +//===----------------------------------------------------------------------===// + +IfCmd::IfCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands) + : CommandBufferCmd(execution_stream_id), + pred_(pred), + then_commands_(std::move(then_commands)) {} + +absl::Status IfCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + return then_commands_.Initialize(params, state); +} + +absl::Status IfCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + execute_params.buffer_allocations->GetDeviceAddress(pred_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "IfCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; + + return command_buffer->If( + execution_scope_id, execute_params.stream->parent(), + se::DeviceMemory(pred), + CreateBuilder(&then_commands_, &execute_params, &record_params)); +} + +CommandBufferCmd::BufferUsageVector IfCmd::buffers() { + absl::flat_hash_set buffers; + buffers.emplace(pred_, MemoryAccess::kRead); + buffers.insert(then_commands_.buffers().begin(), + then_commands_.buffers().end()); + return {buffers.begin(), buffers.end()}; +} + +//===----------------------------------------------------------------------===// +// IfElseCmd +//===----------------------------------------------------------------------===// + +IfElseCmd::IfElseCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands, + CommandBufferCmdSequence else_commands) + : CommandBufferCmd(execution_stream_id), + pred_(pred), + then_commands_(std::move(then_commands)), + else_commands_(std::move(else_commands)) {} + +absl::Status IfElseCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + TF_RETURN_IF_ERROR(then_commands_.Initialize(params, state)); + TF_RETURN_IF_ERROR(else_commands_.Initialize(params, state)); + return absl::OkStatus(); +} + +absl::Status IfElseCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + execute_params.buffer_allocations->GetDeviceAddress(pred_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "IfElseCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; + + return command_buffer->IfElse( + execution_scope_id, execute_params.stream->parent(), + se::DeviceMemory(pred), + CreateBuilder(&then_commands_, &execute_params, &record_params), + CreateBuilder(&else_commands_, &execute_params, &record_params)); +} + +CommandBufferCmd::BufferUsageVector IfElseCmd::buffers() { + absl::flat_hash_set buffers; + buffers.emplace(pred_, MemoryAccess::kRead); + buffers.insert(then_commands_.buffers().begin(), + then_commands_.buffers().end()); + buffers.insert(else_commands_.buffers().begin(), + else_commands_.buffers().end()); + return {buffers.begin(), buffers.end()}; +} + +//===----------------------------------------------------------------------===// +// CaseCmd +//===----------------------------------------------------------------------===// + +CaseCmd::CaseCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice index, + std::vector branches_commands) + : CommandBufferCmd(execution_stream_id), + index_(index), + branches_commands_(std::move(branches_commands)) {} + +absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + for (auto& branch : branches_commands_) { + TF_RETURN_IF_ERROR(branch.Initialize(params, state)); + } + return absl::OkStatus(); +} + +absl::Status CaseCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase index = + execute_params.buffer_allocations->GetDeviceAddress(index_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "CaseCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")"; + + return command_buffer->Case(execution_scope_id, + execute_params.stream->parent(), + se::DeviceMemory(index), + CreateBuilders(absl::MakeSpan(branches_commands_), + &execute_params, &record_params)); +} + +CommandBufferCmd::BufferUsageVector CaseCmd::buffers() { + absl::flat_hash_set buffers; + buffers.emplace(index_, MemoryAccess::kRead); + for (auto& branch : branches_commands_) { + buffers.insert(branch.buffers().begin(), branch.buffers().end()); + } + return {buffers.begin(), buffers.end()}; +} + +//===----------------------------------------------------------------------===// +// ForCmd +//===----------------------------------------------------------------------===// + +ForCmd::ForCmd(ExecutionStreamId execution_stream_id, int32_t num_iterations, + BufferAllocation::Slice loop_counter, + CommandBufferCmdSequence body_commands) + : CommandBufferCmd(execution_stream_id), + num_iterations_(num_iterations), + loop_counter_(loop_counter), + body_commands_(std::move(body_commands)) {} + +absl::Status ForCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + return body_commands_.Initialize(params, state); +} + +absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase loop_counter = + execute_params.buffer_allocations->GetDeviceAddress(loop_counter_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "ForCmd: num_iterations=" << num_iterations_ + << "; body_commands=" << body_commands_.size() + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " loop_counter: " << loop_counter_ << " (" + << loop_counter.opaque() << ")"; + + return command_buffer->For( + execution_scope_id, execute_params.stream->parent(), num_iterations_, + se::DeviceMemory(loop_counter), + CreateBuilder(&body_commands_, &execute_params, &record_params)); +} + +CommandBufferCmd::BufferUsageVector ForCmd::buffers() { + absl::flat_hash_set buffers; + buffers.emplace(loop_counter_, MemoryAccess::kWrite); + buffers.insert(body_commands_.buffers().begin(), + body_commands_.buffers().end()); + return {buffers.begin(), buffers.end()}; +} + +//===----------------------------------------------------------------------===// +// WhileCmd +//===----------------------------------------------------------------------===// + +WhileCmd::WhileCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice pred, + CommandBufferCmdSequence cond_commands, + CommandBufferCmdSequence body_commands) + : CommandBufferCmd(execution_stream_id), + pred_(pred), + cond_commands_(std::move(cond_commands)), + body_commands_(std::move(body_commands)) {} + +absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + TF_RETURN_IF_ERROR(cond_commands_.Initialize(params, state)); + return body_commands_.Initialize(params, state); +} + +absl::Status WhileCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase pred = + execute_params.buffer_allocations->GetDeviceAddress(pred_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "WhileCmd: cond_commands=" << cond_commands_.size() + << " body_commands=" << body_commands_.size() + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")"; + + return command_buffer->While( + execution_scope_id, execute_params.stream->parent(), + se::DeviceMemory(pred), + CreateExecutionScopeBuilder(&cond_commands_, &execute_params, + &record_params), + CreateBuilder(&body_commands_, &execute_params, &record_params)); +} + +CommandBufferCmd::BufferUsageVector WhileCmd::buffers() { + absl::flat_hash_set buffers; + buffers.emplace(pred_, MemoryAccess::kWrite); + buffers.insert(cond_commands_.buffers().begin(), + cond_commands_.buffers().end()); + buffers.insert(body_commands_.buffers().begin(), + body_commands_.buffers().end()); + return {buffers.begin(), buffers.end()}; +} + +//===----------------------------------------------------------------------===// +// AllocateCmd +//===----------------------------------------------------------------------===// + +AllocateCmd::AllocateCmd(ExecutionStreamId execution_stream_id, + BufferAllocation allocation) + : CommandBufferCmd(execution_stream_id), allocation_(allocation) {} + +absl::Status AllocateCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + // Memory allocation address is returned on graph creation, and there is no + // update operation + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(2) << "AllocationCmd: index=" << allocation_.index() + << "; execution_scope_id=" << execution_scope_id.value(); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase buffer, + command_buffer->Allocate(execution_scope_id, allocation_.size())); + return execute_params.buffer_allocations->AddExternalAllocation( + allocation_.index(), buffer); +} + +CommandBufferCmd::BufferUsageVector AllocateCmd::buffers() { return {}; } + +//===----------------------------------------------------------------------===// +// FreeCmd +//===----------------------------------------------------------------------===// + +FreeCmd::FreeCmd(ExecutionStreamId execution_stream_id, + BufferAllocation allocation) + : CommandBufferCmd(execution_stream_id), allocation_(allocation) {} + +absl::Status FreeCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(2) << "FreeCmd: index=" << allocation_.index() + << "; execution_scope_id=" << execution_scope_id.value(); + + se::DeviceMemoryBase address = + execute_params.buffer_allocations->GetDeviceAddress(allocation_.index()); + + // Free is in the same command buffer + TF_RETURN_IF_ERROR(command_buffer->Free(execution_scope_id, address)); + + // Remove the buffer from external allocations. + return execute_params.buffer_allocations->EraseExternalAllocation( + allocation_.index()); +} + +CommandBufferCmd::BufferUsageVector FreeCmd::buffers() { return {}; } + +//===----------------------------------------------------------------------===// +// GemmCmd +//===----------------------------------------------------------------------===// + +GemmCmd::GemmCmd(ExecutionStreamId execution_stream_id, GemmConfig config, + const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& workspace, bool deterministic) + : TracedCommandBufferCmd(execution_stream_id), + config_(std::move(config)), + lhs_buffer_(lhs_buffer), + rhs_buffer_(rhs_buffer), + output_buffer_(output_buffer), + workspace_(workspace), + deterministic_(deterministic) {} + +absl::Status GemmCmd::Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + if (!params.stream->parent()->AsBlas()) { + return absl::InternalError("Failed to initialize BLAS support for GemmCmd"); + } + return absl::OkStatus(); +} + +absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + se::DeviceMemoryBase lhs = + execute_params.buffer_allocations->GetDeviceAddress(lhs_buffer_); + se::DeviceMemoryBase rhs = + execute_params.buffer_allocations->GetDeviceAddress(rhs_buffer_); + se::DeviceMemoryBase out = + execute_params.buffer_allocations->GetDeviceAddress(output_buffer_); + se::DeviceMemoryBase workspace = + execute_params.buffer_allocations->GetDeviceAddress(workspace_); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "GemmCmd: deterministic=" << deterministic_ + << "; execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << " Lhs: " << lhs_buffer_ << " (" << lhs.opaque() << ")"; + VLOG(5) << " Lhs: " << rhs_buffer_ << " (" << rhs.opaque() << ")"; + VLOG(5) << " Out: " << output_buffer_ << " (" << out.opaque() << ")"; + VLOG(5) << " Workspace: " << workspace_ << " (" << workspace.opaque() << ")"; + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, + stream); + }); +} + +CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { + return {{lhs_buffer_, MemoryAccess::kRead}, + {rhs_buffer_, MemoryAccess::kRead}, + {output_buffer_, MemoryAccess::kWrite}, + {workspace_, MemoryAccess::kWrite}}; +} + +//===----------------------------------------------------------------------===// +// CuDnnCmd +//===----------------------------------------------------------------------===// + +CuDnnCmd::CuDnnCmd(ExecutionStreamId execution_stream_id, + absl::Span args, + const std::shared_ptr graph) + : TracedCommandBufferCmd(execution_stream_id), + args_(args.cbegin(), args.cend()), + graph_(graph) {} + +absl::Status CuDnnCmd::Initialize(const Thunk::InitializeParams& params, + StateManager&) { + if (!params.stream->parent()->AsDnn()) { + return absl::InternalError("Failed to initialize DNN support for CuDnnCmd"); + } + return absl::OkStatus(); +} + +absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + CHECK(graph_ != nullptr); + std::vector operands; + operands.reserve(args_.size()); + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = + execute_params.buffer_allocations->GetDeviceAddress(arg); + VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); + operands.push_back(buf); + } + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return graph_->get()->Execute( + *stream, absl::Span(operands)); + }); +} + +CommandBufferCmd::BufferUsageVector CuDnnCmd::buffers() { + CommandBufferCmd::BufferUsageVector buffer_usage; + buffer_usage.reserve(args_.size()); + for (int i = 0; i < args_.size() - 1; ++i) { + buffer_usage.push_back({args_[i], MemoryAccess::kRead}); + } + buffer_usage.push_back({args_.back(), MemoryAccess::kWrite}); + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// CustomCallCmd +//===----------------------------------------------------------------------===// + +absl::Status CustomCallCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + if (handler_ == nullptr) { + return RecordLegacyCustomCall(execute_params, record_params, + command_buffer); + } + return RecordXlaFfiCall(execute_params, record_params, command_buffer); +} + +absl::Status CustomCallCmd::RecordLegacyCustomCall( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + std::vector buffers; + buffers.reserve(operands_.size() + results_.size()); + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (!slice.has_value()) { + buffers.push_back(nullptr); + continue; + } + + if (!slice->slice.allocation()) { + return absl::InternalError( + "custom call input missing buffer allocation"); + } + + buffers.push_back( + execute_params.buffer_allocations->GetDeviceAddress(slice->slice) + .opaque()); + } + } + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "CustomCallCmd: execution_scope_id=" << execution_scope_id.value(); + for (int i = 0; i < operands_.size(); ++i) { + if (operands_[i].has_value()) { + VLOG(5) << " Operand " << i << ": " << operands_[i]->slice << " (" + << buffers[i] << ")"; + } else { + VLOG(5) << " Operand " << i << ": null"; + } + } + for (int i = 0; i < results_.size(); ++i) { + if (results_[i].has_value()) { + VLOG(5) << " Result " << i << ": " << results_[i]->slice << " (" + << buffers[operands_.size() + i] << ")"; + } else { + VLOG(5) << " Result " << i << ": null"; + } + } + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + TF_ASSIGN_OR_RETURN( + auto nested_cmd, + se::CommandBuffer::Trace( + execute_params.stream->parent(), + execute_params.command_buffer_trace_stream, [&](se::Stream* stream) { + se::gpu::GpuStreamHandle gpu_stream = + se::gpu::AsGpuStreamValue(stream); + XlaCustomCallStatus custom_call_status; + call_target_(gpu_stream, buffers.data(), opaque_.data(), + opaque_.size(), &custom_call_status); + auto message = CustomCallStatusGetMessage(&custom_call_status); + if (message) { + return absl::InternalError( + absl::StrCat("CustomCall failed: ", *message)); + } + return absl::OkStatus(); + })); + + return command_buffer->AddNestedCommandBuffer(execution_scope_id, + *nested_cmd); +#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + return Unavailable( + "Custom calls on GPU are not supported in this configuration. Please " + "build with --config=cuda or --config=rocm"); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +absl::Status CustomCallCmd::RecordXlaFfiCall( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing + // a lot of extra allocation on every call. We have to keep attributes + // separate from arguments, as they do not change after thunk is constructed. + ffi::CallFrameBuilder builder; + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "CustomCallCmd: execution_scope_id=" << execution_scope_id.value(); + + for (int i = 0; i < operands_.size(); ++i) { + const std::optional& slice = operands_[i]; + // TODO(ezhulenev): Add a token argument type to XLA:FFI. + if (!slice.has_value()) { + return Internal("FFI handlers do not support tokens (yet)!"); + } + + if (!slice->slice.allocation()) + return Internal("custom call input missing buffer allocation"); + + se::DeviceMemoryBase buffer = + execute_params.buffer_allocations->GetDeviceAddress(slice->slice); + VLOG(5) << " Operand " << i << ": " << slice->slice << " (" + << buffer.opaque() << ")"; + builder.AddBufferArg(buffer, slice->shape.element_type(), + slice->shape.dimensions()); + } + + for (int i = 0; i < results_.size(); ++i) { + const std::optional& slice = results_[i]; + // TODO(ezhulenev): Add a token argument type to XLA:FFI. + if (!slice.has_value()) { + return Internal("FFI handlers do not support tokens (yet)!"); + } + + if (!slice->slice.allocation()) + return Internal("custom call input missing buffer allocation"); + + se::DeviceMemoryBase buffer = + execute_params.buffer_allocations->GetDeviceAddress(slice->slice); + VLOG(5) << " Result " << i << ": " << slice->slice << " (" + << buffer.opaque() << ")"; + builder.AddBufferArg(buffer, slice->shape.element_type(), + slice->shape.dimensions()); + } + + ffi::CallFrameBuilder::AttributesBuilder attrs; + attrs.Append(attributes_); + builder.AddAttributes(attrs.Build()); + ffi::CallFrame call_frame = builder.Build(); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + TF_ASSIGN_OR_RETURN( + auto nested_cmd, + se::CommandBuffer::Trace( + execute_params.stream->parent(), + execute_params.command_buffer_trace_stream, [&](se::Stream* stream) { + ExecutableRunOptions run_options; + run_options.set_stream(stream); + ServiceExecutableRunOptions service_run_options(run_options); + ffi::CallOptions options = {&service_run_options, + called_computation_}; + return ffi::Call(handler_, call_frame, options); + })); + + return command_buffer->AddNestedCommandBuffer(execution_scope_id, + *nested_cmd); +#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + return Unavailable( + "Custom calls on GPU are not supported in this configuration. Please " + "build with --config=cuda or --config=rocm"); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +CommandBufferCmd::BufferUsageVector CustomCallCmd::buffers() { + CommandBufferCmd::BufferUsageVector buffer_usage; + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (!slice.has_value()) continue; + buffer_usage.push_back({slice->slice, MemoryAccess::kWrite}); + } + } + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// CollectiveCmd +//===----------------------------------------------------------------------===// + +CollectiveCmd::CollectiveCmd(ExecutionStreamId execution_stream_id, + NcclApi* nccl_api, NcclCollectiveConfig config) + : TracedCommandBufferCmd(execution_stream_id), + nccl_api_(nccl_api), + config_(std::move(config)) {} + +absl::Status CollectiveCmd::Prepare( + const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) { + const Thunk::CollectiveExecuteParams* collectives = params.collective_params; + + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(collectives->global_device_id, + *collectives->device_assn, + config().replica_groups, config().group_mode)); + + std::vector local_devices; + if (collectives->global_device_id_map) { + local_devices.reserve(collectives->global_device_id_map->size()); + for (const auto& entry : *collectives->global_device_id_map) { + local_devices.push_back(entry.second); + } + } + + size_t num_local_participants = GetNumLocalParticipants( + participants, + collectives->global_device_id_map ? &local_devices : nullptr); + + return resource_requests.AddClique( + NcclCliqueKey(std::move(participants), /*stream_id=*/0, + GetAsyncStreamKind()), + num_local_participants); +} + +//===----------------------------------------------------------------------===// +// AllReduceCmd +//===----------------------------------------------------------------------===// + +AllReduceCmd::AllReduceCmd( + ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, ReductionKind reduction_kind, + absl::Span buffers) + : CollectiveCmd(execution_stream_id, nccl_api, std::move(config)), + reduction_kind_(reduction_kind), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "AllReduceCmd: reduction=" << ReductionKindString(reduction_kind_) + << "; execution_scope_id=" << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "AllReduceCmd requires collective parameters and cliques"); + } + + // Today when recording collective operations into command buffers we always + // use a sync mode and a stream id `0`. + TF_ASSIGN_OR_RETURN(NcclApi::NcclCommHandle comm, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, + config().replica_groups, config().group_mode, + /*stream_id=*/0, GetAsyncStreamKind())); + + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunAllReduce(nccl_api(), reduction_kind_, device_buffers, + *stream, comm); + }); +} + +CommandBufferCmd::BufferUsageVector AllReduceCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// ReduceScatterCmd +//===----------------------------------------------------------------------===// + +ReduceScatterCmd::ReduceScatterCmd( + ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, ReductionKind reduction_kind, + absl::Span buffers) + : CollectiveCmd(execution_stream_id, nccl_api, std::move(config)), + reduction_kind_(reduction_kind), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status ReduceScatterCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "ReduceScatterCmd: reduction=" + << ReductionKindString(reduction_kind_) + << "; execution_scope_id=" << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "ReduceScatterCmd requires collective parameters and cliques"); + } + + // Today when recording collective operations into command buffers we always + // use a sync mode and a stream id `0`. + TF_ASSIGN_OR_RETURN(NcclApi::NcclCommHandle comm, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, + config().replica_groups, config().group_mode, + /*stream_id=*/0, GetAsyncStreamKind())); + + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunReduceScatter(nccl_api(), reduction_kind_, device_buffers, + *stream, comm); + }); +} + +CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// AllGatherCmd +//===----------------------------------------------------------------------===// + +AllGatherCmd::AllGatherCmd( + ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, + absl::Span buffers) + : CollectiveCmd(execution_stream_id, nccl_api, std::move(config)), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "AllGatherCmd: execution_scope_id=" << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "AllGatherCmd requires collective parameters and cliques"); + } + + // Today when recording collective operations into command buffers we always + // use a sync mode and a stream id `0`. + TF_ASSIGN_OR_RETURN(NcclApi::NcclCommHandle comm, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, + config().replica_groups, config().group_mode, + /*stream_id=*/0, GetAsyncStreamKind())); + + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunAllGather(nccl_api(), device_buffers, *stream, comm); + }); +} + +CommandBufferCmd::BufferUsageVector AllGatherCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// CollectiveBroadcastCmd +//===----------------------------------------------------------------------===// + +CollectiveBroadcastCmd::CollectiveBroadcastCmd( + ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, + absl::Span buffers) + : CollectiveCmd(execution_stream_id, nccl_api, std::move(config)), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status CollectiveBroadcastCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "CollectiveBroadcastCmd: execution_scope_id=" + << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "CollectiveBroadcastCmd requires collective parameters and cliques"); + } + + // Today when recording collective operations into command buffers we always + // use a sync mode and a stream id `0`. + TF_ASSIGN_OR_RETURN(NcclApi::NcclCommHandle comm, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, + config().replica_groups, config().group_mode, + /*stream_id=*/0, GetAsyncStreamKind())); + + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunCollectiveBroadcast(device_buffers, *stream, comm, + nccl_api()); + }); +} + +CommandBufferCmd::BufferUsageVector CollectiveBroadcastCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/command_buffer_cmd.h b/xla/service/gpu/runtime/command_buffer_cmd.h new file mode 100644 index 0000000000000..48861587219fa --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_cmd.h @@ -0,0 +1,970 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/ffi/api/c_api.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// CommandBufferCmd +//===----------------------------------------------------------------------===// + +// Command is a Thunk counterpart that instead of launching operations directly +// on the underlying device records them into command buffers. +// +// Commands have the same execution stages as thunks as they are executed by a +// command buffer thunk: Prepare, Initialize and Record (Execute). See Thunk +// documentation for details. +// +// Commands must be thread safe as they can be recorded into multiple command +// buffers concurrently on different stream executors. +class CommandBufferCmd { + public: + explicit CommandBufferCmd(ExecutionStreamId execution_stream_id) + : execution_stream_id_(execution_stream_id) {} + virtual ~CommandBufferCmd() = default; + + enum class MemoryAccess { kRead, kWrite }; + + // BufferUsage tracks memory access type for a buffer slice, so that we can + // correctly insert command buffer barriers to avoid read/write conflicts. + struct BufferUsage { + BufferUsage(BufferAllocation::Slice slice, MemoryAccess access) + : slice(slice), access(access) {} + + template + friend H AbslHashValue(H h, const BufferUsage& buffer) { + return H::combine(std::move(h), buffer.slice, buffer.access); + } + + bool operator==(const BufferUsage& other) const { + return slice == other.slice && access == other.access; + } + + BufferAllocation::Slice slice; + MemoryAccess access; + }; + + using BufferUsageVector = absl::InlinedVector; + + // A base class for externally managed command state. + // + // Commands can be executed concurrently for many stream executors (underlying + // devices) and command buffers. Managing per-executor state can become + // expensive as it requires synchronization. Furthermore the number of command + // buffers command is recorded into is unbounded as they come and go (command + // buffers evicted and reconstructed) which makes it hard to manage the + // lifetime of resources attached to command buffers. + // + // Externally managed state (owned and synchronized by CommandBufferThunk) + // allows commands to attach a piece of information to command buffer in a + // safe and performant way. + class State { + public: + virtual ~State() = default; + }; + + // An external manager for a state attached to commands. + class StateManager { + public: + virtual ~StateManager() = default; + + template + ConcreteState* GetOrNull(const CommandBufferCmd* cmd) { + static_assert(std::is_base_of_v); + return static_cast(GetOrNull(cmd)); + } + + template + ConcreteState* GetOrCreate( + const CommandBufferCmd* cmd, + absl::FunctionRef()> create) { + static_assert(std::is_base_of_v); + return static_cast(GetOrCreate( + cmd, [&]() -> std::unique_ptr { return create(); })); + } + + template + ConcreteState* GetOrCreate(const CommandBufferCmd* cmd) { + static_assert(std::is_base_of_v); + return static_cast( + GetOrCreate(cmd, [] { return std::make_unique(); })); + } + + private: + State* GetOrNull(const CommandBufferCmd* cmd); + + State* GetOrCreate(const CommandBufferCmd* cmd, + absl::FunctionRef()> create); + + absl::flat_hash_map> state_; + }; + + // Parameters for recording commands into the command buffer. + struct RecordParams { + // An external state manager that gives efficient access to per-device state + // to commands without a need to add expensive synchronization. + StateManager& state; + + // Execution scope id defines the default execution scope that should be + // used for recording commands. Each individual command uses this scope plus + // its own execution stream id to compute the execution scope that will be + // used for adding commands to command buffer. It is a command sequence + // responsibility to guarantee that all commands eventually will be + // correctly synchronized with an execution scope id passed as argument. + // + // This argument allows conditional commands to record a command sequence + // into non-default execution scope. + se::CommandBuffer::ExecutionScopeId execution_scope_id = + se::CommandBuffer::kDefaulExecutionScope; + }; + + // See Thunk documentation for XLA execution stages (prepare, initialize, + // execute). Commands mirror thunks as they are executed as CommandBufferThunk + // that is plugged into the Thunk execution cycle. + + // Prepare command for execution by allowing command to request shared state + // required for recording (i.e. collective commands request cliques). + virtual absl::Status Prepare(const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) { + return absl::OkStatus(); + } + + // Initialize a command for recording on a given executor. We split it into a + // separate function to allow expensive initialization (e.g. device kernel + // loading) to happen before a command buffer thunk execution. + virtual absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) { + return absl::OkStatus(); + } + + // Records command into the command buffer using given execution scope. + virtual absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) = 0; + + // Returns all buffers used by the cmd. These will be used to track cmd + // updates, thus they need to be consistent across calls to the function. + virtual BufferUsageVector buffers() = 0; + + // Returns true if command implemented as a nested command buffer. + virtual bool IsNestedCommandBuffer() const { return false; } + + // Returns a command execution scope computed from the command stream id and + // the default command buffer execution scope. + se::CommandBuffer::ExecutionScopeId GetExecutionScope( + const RecordParams& record_params) const; + + std::string_view profile_annotation() const { return profile_annotation_; } + void set_profile_annotation(std::string_view profile_annotation) { + profile_annotation_ = profile_annotation; + } + + ExecutionStreamId execution_stream_id() const { return execution_stream_id_; } + + private: + std::string profile_annotation_; + ExecutionStreamId execution_stream_id_; +}; + +//===----------------------------------------------------------------------===// +// CommandBufferCmdSequence +//===----------------------------------------------------------------------===// + +// A sequence of command buffer commands that create or update a command buffer. +// You can think of CommandBufferCmdSequence as a mini interpreter whose sole +// purpose is to manipulate command buffers at run time. +class CommandBufferCmdSequence { + public: + // Synchronization mode defines how execution streams gets converted to + // command buffer execution scopes and barriers. + // + // Each individual Thunk assigned an execution stream id, and we have explicit + // inter-stream synchronization (`Thunk::Kind::kWaitForStreams`) between + // streams. Thunks assigned to the same stream are implicitly synchronized. + // + // Command buffers on the other hand by default can execute commands + // concurrently and require barriers to enforce execution order. + // + // WARNING: We do not have implicit synchronization between execution scopes + // corresponding to different execution streams and rely on explicit barriers + // emitted from thunks. Synchronization mode controls only barriers within + // a single exection scope (corresponds to execution stream). + enum class SynchronizationMode { + // Adds barriers between all commands recorded into the same execution scope + // (thunks sharing execution stream) and enforces completely serialized + // execution order that matches what would happen in a ThunkSequence. + kSerialize, + + // Relies on buffer use analysis to insert barriers only between commands + // that have read-write conflicts into the same buffers. Conflicts are + // detected only between commands using the same stream id, and inter-stream + // synchronization is a user responsibility. + kAutomatic + }; + + enum class RecordMode { + // In exclusive mode no one else is recording commands into the command + // buffer argument, and cmd sequence is responsible for updating command + // buffer state: finalizing after all commands recorded, and + // switching to update state before recording updates. + kExclusive, + + // In conditional mode multiple cmd sequences can be recorded into the + // command buffer argument, and with command buffer state managed externally + // cmd sequence should not finalize or update it. This mode is used when + // command buffer cmd sequence is recorded into conditional command buffers + // owned by the parent command buffer. + kConditional + }; + + explicit CommandBufferCmdSequence(SynchronizationMode synchronization_mode = + SynchronizationMode::kAutomatic); + + void Append(std::unique_ptr cmd); + + template + void Emplace(Args... args) { + Append(std::make_unique(std::forward(args)...)); + } + + // Prepares all commands added to a sequence. + absl::Status Prepare(const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests); + + // Initializes all commands added to a sequence. + absl::Status Initialize(const Thunk::InitializeParams& params, + CommandBufferCmd::StateManager& state); + + // Records all commands added to a sequence into the given command buffer. + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const CommandBufferCmd::RecordParams& record_params, + se::CommandBuffer* command_buffer, + RecordMode mode = RecordMode::kExclusive); + + // Returns buffers referenced by commands in this sequence. + const absl::flat_hash_set& buffers() const; + + // Returns buffer allocations indices referenced by commands in this sequence. + const absl::flat_hash_set& allocs_indices() const; + + // Returns a vector that tells if command at the given index requires a + // barrier. + std::vector barriers() const; + + bool empty() const { return commands_.empty(); } + size_t size() const { return commands_.size(); } + + private: + struct CommandInfo { + std::unique_ptr cmd; + bool requires_barrier; + }; + + // Functions for tracking buffer usage of recorded commands and figuring out + // when the next command requires a barrier for correctness. + bool HasConflicts(ExecutionStreamId execution_stream_id, + const CommandBufferCmd::BufferUsageVector& buffers); + void TrackBuffers(ExecutionStreamId execution_stream_id, + const CommandBufferCmd::BufferUsageVector& buffers); + void ClearTrackedBuffers(ExecutionStreamId execution_stream_id); + + SynchronizationMode synchronization_mode_; + std::vector commands_; + + // Buffers referenced by commands in this sequence. + absl::flat_hash_set buffers_; + + // Buffer allocations indices referenced by commands in this sequence. + absl::flat_hash_set allocs_indices_; + + // We track read and write sets of commands recorded into the command + // sequence to detect conflicts and insert explicit barriers. These are the + // buffer allocation slices used by commands appended since the last barrier. + struct ReadWriteSet { + absl::flat_hash_set read; + absl::flat_hash_set write; + }; + + absl::flat_hash_map read_write_sets_; +}; + +//===----------------------------------------------------------------------===// +// TracedCommandBuffer +//===----------------------------------------------------------------------===// + +// A cache for traced command buffers that will re-trace on change in buffer +// allocations that are relevant for `buffers` passed to constructor. We use a +// very simple most-recently-used cache of traced command buffers as in practice +// subsequent calls to XLA executable tend to reuse the same allocations. +class TracedCommandBuffer : public CommandBufferCmd::State { + public: + explicit TracedCommandBuffer(CommandBufferCmd::BufferUsageVector buffers, + int64_t capacity = 16); + + // Returns cached command buffer traced using the same buffer addresses or + // traces and caches a new command buffer using user provided callback. + absl::StatusOr GetOrTraceCommandBuffer( + const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, + se::Stream* stream, absl::FunctionRef trace); + + private: + std::vector allocs_indices_; + + struct Entry { + std::vector recorded_allocs; + std::unique_ptr command_buffer; + }; + + int64_t capacity_; + std::vector entries_; +}; + +//===----------------------------------------------------------------------===// +// TracedCommandBufferCmd +//===----------------------------------------------------------------------===// + +// A base class for commands implemented as tracing of stream activities. +class TracedCommandBufferCmd : public CommandBufferCmd { + protected: + explicit TracedCommandBufferCmd(ExecutionStreamId execution_stream_id); + + // Creates a command buffer by calling a user-provided `trace` function and + // adds it as a nested command to `command_buffer`. Traced command buffers + // cached and reused in an instance of `TracedCommandBuffer` kept in `state`. + absl::Status AddTracedCommandBuffer( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer, + absl::FunctionRef trace); +}; + +//===----------------------------------------------------------------------===// +// ComputationIdCmd (ReplicaId and PartitionId) +//===----------------------------------------------------------------------===// + +class ComputationIdCmd : public CommandBufferCmd { + public: + enum class Kind { kReplica, kPartition }; + + ComputationIdCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dest, Kind kind); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice dest_; + Kind kind_; + + // Command sequence can be recorded concurrently for multiple command buffers + // on different stream executors and we need to synchronize mutable state. + absl::Mutex mutex_; + + // TODO(ezhulenev): This is a workaround for CUDA graphs + conditional nodes + // bug that will be fixed in CUDA 12.4.1 release: currently it's impossible to + // update a memset node inside a conditional graph. Instead of using memset + // node we replace it with a kernel launch node of CUDA kernels doing 1D + // memset. This should be removed when bug is fixed in CUDA. + absl::flat_hash_map> + memset_kernels_ ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// LaunchCmd +//===----------------------------------------------------------------------===// + +class LaunchCmd : public CommandBufferCmd { + public: + LaunchCmd(ExecutionStreamId execution_stream_id, std::string kernel_name, + absl::Span args, + absl::Span args_access, LaunchDimensions dims, + int64_t shmem_bytes); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + std::string kernel_name_; + std::vector args_; + std::vector args_access_; + LaunchDimensions dims_; + int64_t shmem_bytes_; + + // Command sequence can be recorded concurrently for multiple command buffers + // on different stream executors and we need to synchronize mutable state. + absl::Mutex mutex_; + absl::flat_hash_map> kernels_ + ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// CustomKenelLaunchCmd +//===----------------------------------------------------------------------===// + +class CustomKernelLaunchCmd : public CommandBufferCmd { + public: + CustomKernelLaunchCmd(ExecutionStreamId execution_stream_id, + absl::Span args, + absl::Span args_access, + CustomKernel custom_kernel); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + std::vector args_; + std::vector args_access_; + CustomKernel custom_kernel_; + + // Command sequence can be recorded concurrently for multiple command buffers + // on different stream executors and we need to synchronize mutable state. + absl::Mutex mutex_; + absl::flat_hash_map> kernels_ + ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// MemcpyDeviceToDeviceCmd +//===----------------------------------------------------------------------===// + +class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { + public: + MemcpyDeviceToDeviceCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dst, + BufferAllocation::Slice src, int64_t num_bytes); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice dst_; + BufferAllocation::Slice src_; + int64_t num_bytes_; +}; + +//===----------------------------------------------------------------------===// +// MemzeroCmd +//===----------------------------------------------------------------------===// + +class MemzeroCmd : public CommandBufferCmd { + public: + MemzeroCmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dst); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice dst_; +}; + +//===----------------------------------------------------------------------===// +// Memset32Cmd +//===----------------------------------------------------------------------===// + +class Memset32Cmd : public CommandBufferCmd { + public: + Memset32Cmd(ExecutionStreamId execution_stream_id, + BufferAllocation::Slice dst, uint32_t bit_pattern); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice dst_; + uint32_t bit_pattern_; +}; + +//===----------------------------------------------------------------------===// +// IfCmd +//===----------------------------------------------------------------------===// + +class IfCmd : public CommandBufferCmd { + public: + IfCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence then_commands_; +}; + +//===----------------------------------------------------------------------===// +// IfElseCmd +//===----------------------------------------------------------------------===// + +class IfElseCmd : public CommandBufferCmd { + public: + IfElseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred, + CommandBufferCmdSequence then_commands, + CommandBufferCmdSequence else_commands); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence then_commands_; + CommandBufferCmdSequence else_commands_; +}; + +//===----------------------------------------------------------------------===// +// CaseCmd +//===----------------------------------------------------------------------===// + +class CaseCmd : public CommandBufferCmd { + public: + CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index, + std::vector branches_commands); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice index_; + std::vector branches_commands_; +}; + +//===----------------------------------------------------------------------===// +// ForCmd +//===----------------------------------------------------------------------===// + +class ForCmd : public CommandBufferCmd { + public: + ForCmd(ExecutionStreamId execution_stream_id, int32_t num_iterations, + BufferAllocation::Slice loop_counter, + CommandBufferCmdSequence body_commands); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + int32_t num_iterations_; + BufferAllocation::Slice loop_counter_; + CommandBufferCmdSequence body_commands_; +}; + +//===----------------------------------------------------------------------===// +// WhileCmd +//===----------------------------------------------------------------------===// + +class WhileCmd : public CommandBufferCmd { + public: + WhileCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred, + CommandBufferCmdSequence cond_commands, + CommandBufferCmdSequence body_commands); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation::Slice pred_; + CommandBufferCmdSequence cond_commands_; + CommandBufferCmdSequence body_commands_; +}; + +//===----------------------------------------------------------------------===// +// AllocateCmd +//===----------------------------------------------------------------------===// + +class AllocateCmd : public CommandBufferCmd { + public: + AllocateCmd(ExecutionStreamId execution_stream_id, + BufferAllocation allocation); + + // After calling this function, the allocated memory is tracked in + // CommandBuffer object. + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation allocation_; +}; + +//===----------------------------------------------------------------------===// +// FreeCmd +//===----------------------------------------------------------------------===// + +class FreeCmd : public CommandBufferCmd { + public: + FreeCmd(ExecutionStreamId execution_stream_id, BufferAllocation allocation); + + // After calling this function, the allocated memory address for dst + // BufferAllocation is freed, no update is required. + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + BufferAllocation allocation_; +}; + +//===----------------------------------------------------------------------===// +// GemmCmd +//===----------------------------------------------------------------------===// + +class GemmCmd : public TracedCommandBufferCmd { + public: + GemmCmd(ExecutionStreamId execution_stream_id, GemmConfig config, + const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& workspace, bool deterministic); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + bool IsNestedCommandBuffer() const final { return true; } + + private: + const GemmConfig config_; + const BufferAllocation::Slice lhs_buffer_; + const BufferAllocation::Slice rhs_buffer_; + const BufferAllocation::Slice output_buffer_; + const BufferAllocation::Slice workspace_; + // Whether to run deterministically. + const bool deterministic_; +}; + +//===----------------------------------------------------------------------===// +// CuDnnCmd +//===----------------------------------------------------------------------===// + +class CuDnnCmd : public TracedCommandBufferCmd { + public: + CuDnnCmd(ExecutionStreamId execution_stream_id, + absl::Span args, + std::shared_ptr graph); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state) override; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + bool IsNestedCommandBuffer() const final { return true; } + + private: + std::vector args_; + const std::shared_ptr graph_; +}; + +//===----------------------------------------------------------------------===// +// CustomCallCmd +//===----------------------------------------------------------------------===// + +class CustomCallCmd : public CommandBufferCmd { + public: + using Slice = CustomCallThunk::Slice; + using Stream = CustomCallThunk::Stream; + using CustomCallTarget = CustomCallThunk::CustomCallTarget; + using AttributesMap = CustomCallThunk::AttributesMap; + + // This is a legacy custom call API that is discouraged, and will be + // deprecated once XLA:FFI mechanism is ready. + // + // TODO(b/323534971): We have an ODR violation somewhere in Tensorflow/XLA and + // include this header with different set of defines and CustomCallTarget + // has different meaning in different translation units. We need to get rid of + // GOOGLE_CUDA defines all over XLA to fix this! As a workaround just keep + // constructor in a header file. + CustomCallCmd(ExecutionStreamId execution_stream_id, + CustomCallTarget call_target, + std::vector> operands, + std::vector> results, + absl::string_view opaque) + : CommandBufferCmd(execution_stream_id), + call_target_(std::move(call_target)), + opaque_(opaque), + operands_(std::move(operands)), + results_(std::move(results)) {} + + CustomCallCmd(ExecutionStreamId execution_stream_id, XLA_FFI_Handler* handler, + std::vector> operands, + std::vector> results, + AttributesMap attributes, + const HloComputation* called_computation) + : CommandBufferCmd(execution_stream_id), + handler_(handler), + attributes_(std::move(attributes)), + called_computation_(called_computation), + operands_(std::move(operands)), + results_(std::move(results)) {} + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + bool IsNestedCommandBuffer() const final { return true; } + + private: + absl::Status RecordLegacyCustomCall(const Thunk::ExecuteParams& execute_param, + const RecordParams& record_params, + se::CommandBuffer* command_buffer); + absl::Status RecordXlaFfiCall(const Thunk::ExecuteParams& execute_param, + const RecordParams& record_params, + se::CommandBuffer* command_buffer); + + // This is a legacy custom call API that is discouraged, and will be + // deprecated once XLA:FFI mechanism is ready. + CustomCallTarget call_target_; + std::string opaque_; + + // XLA FFI provides a right type safe mechanism for registering external + // functions with XLA runtime. It's under construction, and still misses + // a lot of features. Long term it will replace legacy custom calls. + XLA_FFI_Handler* handler_ = nullptr; + AttributesMap attributes_; + const HloComputation* called_computation_; + + std::vector> operands_; + std::vector> results_; +}; + +//===----------------------------------------------------------------------===// +// CollectiveCmd +//===----------------------------------------------------------------------===// + +class CollectiveCmd : public TracedCommandBufferCmd { + public: + CollectiveCmd(ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config); + + absl::Status Prepare(const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) final; + + bool IsNestedCommandBuffer() const final { return true; } + + virtual AsyncStreamKind GetAsyncStreamKind() = 0; + + protected: + NcclApi* nccl_api() const { return nccl_api_; } + const NcclCollectiveConfig& config() const { return config_; } + + private: + NcclApi* nccl_api_; + NcclCollectiveConfig config_; +}; + +//===----------------------------------------------------------------------===// +// AllReduceCmd +//===----------------------------------------------------------------------===// + +class AllReduceCmd : public CollectiveCmd { + public: + AllReduceCmd(ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, ReductionKind reduction_kind, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + + private: + ReductionKind reduction_kind_; + std::vector buffers_; +}; + +//===----------------------------------------------------------------------===// +// ReduceScatterCmd +//===----------------------------------------------------------------------===// + +class ReduceScatterCmd : public CollectiveCmd { + public: + ReduceScatterCmd(ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, ReductionKind reduction_kind, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + + private: + ReductionKind reduction_kind_; + std::vector buffers_; +}; + +//===----------------------------------------------------------------------===// +// AllGatherCmd +//===----------------------------------------------------------------------===// + +class AllGatherCmd : public CollectiveCmd { + public: + AllGatherCmd(ExecutionStreamId execution_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + + private: + std::vector buffers_; +}; + +//===----------------------------------------------------------------------===// +// CollectiveBroadcastCmd +//===----------------------------------------------------------------------===// + +class CollectiveBroadcastCmd : public CollectiveCmd { + public: + CollectiveBroadcastCmd(ExecutionStreamId execution_stream_id, + NcclApi* nccl_api, NcclCollectiveConfig config, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + private: + std::vector buffers_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ diff --git a/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc new file mode 100644 index 0000000000000..e1974c17cfefa --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -0,0 +1,290 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_cmd_emitter.h" + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/conditional_thunk.h" +#include "xla/service/gpu/runtime/copy_thunk.h" +#include "xla/service/gpu/runtime/cudnn_thunk.h" +#include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/gemm_thunk.h" +#include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/memset_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/replica_id_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +// Appends command(s) converted from `thunk` to `cmd_sequence`. +static absl::Status AppendCommands( + CommandBufferCmdSequence& cmd_sequence, const Thunk& thunk, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + +// Appends command(s) converted from `sequence` to `cmd_sequence`. +static absl::Status AppendCommands( + CommandBufferCmdSequence& cmd_sequence, const ThunkSequence& sequence, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + +//===----------------------------------------------------------------------===// +// Conversions from Thunk to Command +//===----------------------------------------------------------------------===// + +using Command = std::unique_ptr; + +static auto ArgsAccess(const std::vector& written) { + absl::InlinedVector args_access; + args_access.reserve(written.size()); + for (bool w : written) { + args_access.push_back(w ? CommandBufferCmd::MemoryAccess::kWrite + : CommandBufferCmd::MemoryAccess::kRead); + } + return args_access; +} + +static absl::StatusOr Convert(const KernelThunk& thunk) { + return std::make_unique( + thunk.execution_stream_id(), thunk.kernel_name(), thunk.arguments(), + ArgsAccess(thunk.written()), thunk.launch_dimensions(), + thunk.shmem_bytes()); +} + +static absl::StatusOr Convert(const CustomKernelThunk& thunk) { + return std::make_unique( + thunk.execution_stream_id(), thunk.arguments(), + ArgsAccess(thunk.written()), thunk.custom_kernel()); +} + +static absl::StatusOr Convert(const DeviceToDeviceCopyThunk& thunk) { + return std::make_unique( + thunk.execution_stream_id(), thunk.destination(), thunk.source(), + thunk.size_bytes()); +} + +static absl::StatusOr Convert(const MemzeroThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.destination()); +} + +static absl::StatusOr Convert(const Memset32BitValueThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.destination(), thunk.value()); +} + +static absl::StatusOr Convert( + const WhileThunk& thunk, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence cond_cmds, + ConvertToCommands(thunk.condition_thunk_sequence()->thunks(), + synchronization_mode)); + TF_ASSIGN_OR_RETURN(CommandBufferCmdSequence body_cmds, + ConvertToCommands(thunk.body_thunk_sequence()->thunks(), + synchronization_mode)); + return std::make_unique(thunk.execution_stream_id(), + thunk.condition_result_buffer(), + std::move(cond_cmds), std::move(body_cmds)); +} + +static absl::StatusOr Convert(const GemmThunk& thunk) { + if (!thunk.workspace().has_value()) { + return absl::InternalError( + "Gemm thunk does not contain a workspace buffer"); + } + return std::make_unique( + thunk.execution_stream_id(), thunk.config(), thunk.lhs_buffer(), + thunk.rhs_buffer(), thunk.output_buffer(), thunk.workspace().value(), + thunk.deterministic()); +} + +static absl::StatusOr Convert( + const ConditionalThunk& thunk, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + std::vector branch_cmds; + branch_cmds.reserve(thunk.branch_thunks().size()); + for (auto& branch_thunk : thunk.branch_thunks()) { + TF_ASSIGN_OR_RETURN( + CommandBufferCmdSequence cmds, + ConvertToCommands(branch_thunk->thunks(), synchronization_mode)); + branch_cmds.emplace_back(std::move(cmds)); + } + return std::make_unique(thunk.execution_stream_id(), + thunk.branch_index_buffer(), + std::move(branch_cmds)); +} + +static absl::StatusOr Convert(const NcclAllReduceStartThunk& thunk) { + return std::make_unique( + thunk.execution_stream_id(), thunk.nccl_api(), thunk.config(), + thunk.reduction_kind(), thunk.buffers()); +} + +static absl::StatusOr Convert( + const NcclReduceScatterStartThunk& thunk) { + return std::make_unique( + thunk.execution_stream_id(), thunk.nccl_api(), thunk.config(), + thunk.reduction_kind(), thunk.buffers()); +} + +static absl::StatusOr Convert(const NcclAllGatherStartThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.nccl_api(), thunk.config(), + thunk.buffers()); +} + +static absl::StatusOr Convert(const PartitionIdThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.dest(), + ComputationIdCmd::Kind::kPartition); +} + +static absl::StatusOr Convert(const ReplicaIdThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.dest(), + ComputationIdCmd::Kind::kReplica); +} + +static absl::StatusOr Convert(const CustomCallThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.call_target(), thunk.operands(), + thunk.results(), thunk.opaque()); +} + +static absl::StatusOr Convert(const CuDnnThunk& thunk) { + return std::make_unique(thunk.execution_stream_id(), + thunk.arguments(), thunk.graph()); +} + +//===----------------------------------------------------------------------===// +static absl::StatusOr CopyMetadata(absl::StatusOr cmd, + const Thunk& thunk) { + if (cmd.ok()) { + (*cmd)->set_profile_annotation(thunk.profile_annotation()); + return cmd; + } + return cmd; +} + +template +static absl::StatusOr Convert(const Thunk& thunk) { + return CopyMetadata(Convert(static_cast(thunk)), thunk); +} + +template +static absl::StatusOr Convert( + const Thunk& thunk, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + return Convert(static_cast(thunk), synchronization_mode); +} + +static absl::Status AppendCommands( + CommandBufferCmdSequence& cmd_sequence, const Thunk& thunk, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + auto append = [&](absl::StatusOr command) -> absl::Status { + if (command.ok()) { + cmd_sequence.Append(std::move(*command)); + return absl::OkStatus(); + } + return command.status(); + }; + + switch (thunk.kind()) { + case Thunk::Kind::kConditional: + return append(Convert(thunk, synchronization_mode)); + case Thunk::Kind::kCopy: + return append(Convert(thunk)); + case Thunk::Kind::kCustomCall: + return append(Convert(thunk)); + case Thunk::Kind::kCustomKernel: + return append(Convert(thunk)); + case Thunk::Kind::kKernel: + return append(Convert(thunk)); + case Thunk::Kind::kGemm: + return append(Convert(thunk)); + case Thunk::Kind::kMemset32BitValue: + return append(Convert(thunk)); + case Thunk::Kind::kMemzero: + return append(Convert(thunk)); + case Thunk::Kind::kNcclAllGatherStart: + return append(Convert(thunk)); + case Thunk::Kind::kNcclAllReduceStart: + return append(Convert(thunk)); + case Thunk::Kind::kNcclReduceScatterStart: + return append(Convert(thunk)); + case Thunk::Kind::kPartitionId: + return append(Convert(thunk)); + case Thunk::Kind::kReplicaId: + return append(Convert(thunk)); + case Thunk::Kind::kWhile: + return append(Convert(thunk, synchronization_mode)); + case Thunk::Kind::kCuDnn: + return append(Convert(thunk)); + + // Sequential thunk does not have any special semantics and we simply inline + // all nested thunks into command buffer. + case Thunk::Kind::kSequential: + return AppendCommands(cmd_sequence, + static_cast(thunk).thunks(), + synchronization_mode); + + // Currently all collective operations recorded on the tracing stream and do + // not need to have a separate done command. + case Thunk::Kind::kNcclAllGatherDone: + case Thunk::Kind::kNcclAllReduceDone: + case Thunk::Kind::kNcclReduceScatterDone: + case Thunk::Kind::kWaitForStreams: + return absl::OkStatus(); + + default: + return Internal("Unsupported thunk kind: %s", + Thunk::KindToString(thunk.kind())); + } +} + +static absl::Status AppendCommands( + CommandBufferCmdSequence& cmd_sequence, const ThunkSequence& sequence, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + for (const std::unique_ptr& thunk : sequence) + TF_RETURN_IF_ERROR( + AppendCommands(cmd_sequence, *thunk, synchronization_mode)); + return absl::OkStatus(); +} + +// TODO(vuson): Add unit tests. +absl::StatusOr ConvertToCommands( + const ThunkSequence& sequence, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { + CommandBufferCmdSequence cmd_sequence(synchronization_mode); + TF_RETURN_IF_ERROR( + AppendCommands(cmd_sequence, sequence, synchronization_mode)); + return cmd_sequence; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/command_buffer_cmd_emitter.h b/xla/service/gpu/runtime/command_buffer_cmd_emitter.h new file mode 100644 index 0000000000000..a5608355a7975 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_cmd_emitter.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ + +#include "absl/status/statusor.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/thunk.h" + +namespace xla::gpu { + +// Converts thunk sequence to a command buffer cmd sequence. If `force_barrier` +// is true we automatically insert barriers between all commands in a sequence. +// Otherwise we use buffer usage aliasing to allow commands to run concurrently +// and insert barriers only when needed for correctness. +absl::StatusOr ConvertToCommands( + const ThunkSequence& sequence, + CommandBufferCmdSequence::SynchronizationMode synchronization_mode); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_EMITTER_H_ diff --git a/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/xla/service/gpu/runtime/command_buffer_cmd_test.cc new file mode 100644 index 0000000000000..d174c4733d886 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -0,0 +1,438 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_cmd.h" + +#include +#include +#include + +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/status.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" // IWYU pragma: keep +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::gpu { + +using BufferUsage = CommandBufferCmd::BufferUsage; +using BufferUsageVector = CommandBufferCmd::BufferUsageVector; +using MemoryAccess = CommandBufferCmd::MemoryAccess; + +static se::StreamExecutor* GpuExecutor() { + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + auto* platform = se::PlatformManager::PlatformWithName(name).value(); + return platform->ExecutorForDevice(0).value(); +} + +// Give a short aliases to execution threads. +static constexpr auto s0 = ExecutionStreamId(0); +static constexpr auto s1 = ExecutionStreamId(1); + +// A command buffer cmd for testing automatic barriers insertion by the command +// buffer cmd sequence. We never execute this command, we need it only to pass +// buffer usage vector to the command buffer cmd sequence. +struct TestOnlyCommandBufferCmd : public CommandBufferCmd { + TestOnlyCommandBufferCmd(ExecutionStreamId execution_stream_id, + BufferUsageVector buffer_usage) + : CommandBufferCmd(execution_stream_id), buffer_usage(buffer_usage) {} + + absl::Status Record(const Thunk::ExecuteParams&, const RecordParams&, + se::CommandBuffer*) override { + return absl::OkStatus(); + } + + BufferUsageVector buffers() override { return buffer_usage; } + + BufferUsageVector buffer_usage; +}; + +TEST(CommandBufferCmdTest, SerializeExecution) { + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + + auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); + auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); + + // Reads from overlapping slices do not require barriers by default. + auto use0 = BufferUsage(slice0, MemoryAccess::kRead); + auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + + CommandBufferCmdSequence commands( + CommandBufferCmdSequence::SynchronizationMode::kSerialize); + commands.Emplace(s0, BufferUsageVector{use0}); + commands.Emplace(s0, BufferUsageVector{use1}); + + ASSERT_EQ(commands.barriers().size(), 2); + EXPECT_EQ(commands.barriers().at(0), false); + EXPECT_EQ(commands.barriers().at(1), true); +} + +TEST(CommandBufferCmdTest, NoReadBarrier) { + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + + auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); + auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); + + // Reads from overlapping slices do not require barriers. + auto use0 = BufferUsage(slice0, MemoryAccess::kRead); + auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUsageVector{use0}); + commands.Emplace(s0, BufferUsageVector{use1}); + + ASSERT_EQ(commands.barriers().size(), 2); + EXPECT_EQ(commands.barriers().at(0), false); + EXPECT_EQ(commands.barriers().at(1), false); +} + +TEST(CommandBufferCmdTest, NoWriteBarrier) { + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + + // Writes to non-overlapping slices do not require barriers. + auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); + auto slice1 = BufferAllocation::Slice(&alloc0, 200, 100); + + auto use0 = BufferUsage(slice0, MemoryAccess::kWrite); + auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUsageVector{use0}); + commands.Emplace(s0, BufferUsageVector{use1}); + + ASSERT_EQ(commands.barriers().size(), 2); + EXPECT_EQ(commands.barriers().at(0), false); + EXPECT_EQ(commands.barriers().at(1), false); +} + +TEST(CommandBufferCmdTest, WriteConflictBarrier) { + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + + auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); + auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); + + // Reads from overlapping slices can be done in parallel, and before a write + // into overlapping slice we need to insert a barrier. + auto use0 = BufferUsage(slice0, MemoryAccess::kRead); + auto use1 = BufferUsage(slice0, MemoryAccess::kRead); + auto use2 = BufferUsage(slice1, MemoryAccess::kWrite); + + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUsageVector{use0}); + commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUsageVector{use2}); + + ASSERT_EQ(commands.barriers().size(), 3); + EXPECT_EQ(commands.barriers().at(0), false); + EXPECT_EQ(commands.barriers().at(1), false); + EXPECT_EQ(commands.barriers().at(2), true); +} + +TEST(CommandBufferCmdTest, NoWriteConflictsAcrossStreams) { + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + + auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); + auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); + + // Read and write happens on different execution streams and we do not insert + // any automatic barriers between streams. + auto use0 = BufferUsage(slice0, MemoryAccess::kRead); + auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + + CommandBufferCmdSequence commands; + commands.Emplace(s0, BufferUsageVector{use0}); + commands.Emplace(s1, BufferUsageVector{use1}); + + ASSERT_EQ(commands.barriers().size(), 2); + EXPECT_EQ(commands.barriers().at(0), false); + EXPECT_EQ(commands.barriers().at(1), false); +} + +TEST(CommandBufferCmdTest, MemcpyCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + auto stream = executor->CreateStream().value(); + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_b, slice_a, byte_length); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + + CommandBufferCmd::StateManager state; + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + CommandBufferCmd::RecordParams record_params = {state}; + + auto command_buffer = se::CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); + + // Execute command buffer and verify that it copied the memory. + TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + +TEST(CommandBufferCmdTest, LaunchCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + auto stream = executor->CreateStream().value(); + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Initialize command sequence and load device kernels. + Thunk::ExecutableSource source = { +#if defined(GOOGLE_CUDA) + /*text=*/se::gpu::internal::kAddI32Kernel, + /*binary=*/{} +#elif defined(TENSORFLOW_USE_ROCM) + /*text=*/{}, + /*binary=*/se::gpu::internal::kAddI32KernelModule +#endif + }; + + CommandBufferCmd::StateManager state; + TF_ASSERT_OK(commands.Initialize({executor, source}, state)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + CommandBufferCmd::RecordParams record_params = {state}; + + auto command_buffer = se::CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); + + // Execute command buffer and verify that it copied the memory. + TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); +} + +TEST(CommandBufferCmdStateManageTest, GetOrCreateState) { + struct TestState : public CommandBufferCmd::State { + int32_t value = 0; + }; + + // We need a fake command buffer pointer to use as a key. + CommandBufferCmd* cmd = reinterpret_cast(0x1234567); + + CommandBufferCmd::StateManager state_manager; + + auto* state0 = state_manager.GetOrNull(cmd); + ASSERT_EQ(state0, nullptr); + + auto* state1 = state_manager.GetOrCreate(cmd); + ASSERT_EQ(state1->value, 0); + state1->value += 42; + + auto* state2 = state_manager.GetOrCreate(cmd); + ASSERT_EQ(state2->value, 42); + ASSERT_EQ(state1, state2); +} + +TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { + se::StreamExecutor* executor = GpuExecutor(); + + auto stream = executor->CreateStream().value(); + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); + + CommandBufferCmd::BufferUsageVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + + TracedCommandBuffer traced_cmd_buffer(buffers, /*capacity=*/2); + + se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); + se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + + BufferAllocations allocations({mem0, mem1}, 0, executor->GetAllocator()); + + // No-op trace callback to count how many times it was called. + int64_t num_calls = 0; + auto trace = [&](se::Stream*) { + num_calls++; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer0, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer1, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + // Check that command buffer was reused as buffer allocations didn't change. + ASSERT_EQ(command_buffer0, command_buffer1); + EXPECT_EQ(num_calls, 1); + + // Check that when memory address changes we re-trace the command buffer. + se::DeviceMemoryBase mem2(reinterpret_cast(0x23456701)); + allocations = BufferAllocations({mem0, mem2}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + + ASSERT_NE(command_buffer0, command_buffer2); + EXPECT_EQ(num_calls, 2); + + // Check that we keep first command buffer in cache. + allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_EQ(command_buffer0, command_buffer3); + EXPECT_EQ(num_calls, 2); + + // Check that we trace a new graph when buffer allocation pattern is new. + allocations = BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_NE(command_buffer4, command_buffer3); + ASSERT_NE(command_buffer4, command_buffer2); + EXPECT_EQ(num_calls, 3); + + // Check that we still keep the previous graph in cache. + allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + + TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5, + traced_cmd_buffer.GetOrTraceCommandBuffer( + &allocations, executor, stream.get(), trace)); + ASSERT_EQ(command_buffer0, command_buffer5); + EXPECT_EQ(num_calls, 3); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks below +//===----------------------------------------------------------------------===// + +static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); + + CommandBufferCmd::BufferUsageVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + + se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); + se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + + std::array allocations = { + BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()), + BufferAllocations({mem1, mem0}, 0, executor->GetAllocator()), + BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()), + BufferAllocations({mem1, mem1}, 0, executor->GetAllocator()), + }; + + int32_t index = 0; + TracedCommandBuffer traced_cmd_buffer(buffers); + + auto trace = [](se::Stream*) { return absl::OkStatus(); }; + absl::FunctionRef trace_ref(trace); + + for (auto s : state) { + TF_CHECK_OK(traced_cmd_buffer + .GetOrTraceCommandBuffer(&allocations[index++ % 4], + executor, stream.get(), trace_ref) + .status()); + } +} + +BENCHMARK(BM_GetOrTraceCommandBuffer); + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/command_buffer_thunk.cc b/xla/service/gpu/runtime/command_buffer_thunk.cc new file mode 100644 index 0000000000000..b1de1a0f86f87 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -0,0 +1,330 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/profiler_lock.h" +#include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/lib/traceme_encode.h" + +namespace xla::gpu { + +using tsl::profiler::TraceMe; +using tsl::profiler::TraceMeEncode; + +//===----------------------------------------------------------------------===// +// CommandBufferThunk +//===----------------------------------------------------------------------===// + +CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( + std::unique_ptr command_buffer) + : command_buffer(std::move(command_buffer)) {} + +CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, + ThunkInfo thunk_info, + std::optional thunks) + : Thunk(Thunk::kCommandBuffer, std::move(thunk_info)), + commands_(std::move(commands)), + thunks_(std::move(thunks)), + state_(std::make_shared()) { + // When we create a new command buffer thunk (which happens when we + // instantiate a new Gpu executable) we evict command buffers for all + // previously instantiated executables. If previously instantiated executable + // will be executed again, it will simply reconstruct command buffer from + // a command buffer cmd sequence which is not terribly expensive (few + // milliseconds for large command buffers). With this approach we keep command + // buffers (CUDA graphs) resident in device memory only for executable that + // are actually used. + // + // In a perfect world higher level framework (JAX, Tensorflow, PyTorch) would + // be more aggressive with destroying unused executables, however today they + // all have a pretty large LRU cache for keeping O(1000) XLA executables. + EvictCommandBuffers(); + TrackCommandBuffers(state_); +} + +bool CommandBufferThunk::ExecutorCommandBuffer::ShouldUpdateCommandBuffer( + const CommandBufferCmdSequence& commands, + const Thunk::ExecuteParams& params) { + bool should_update = false; + const BufferAllocations* allocs = params.buffer_allocations; + + // We check only allocations referenced by commands in a cmd sequence, and + // leave every other entry default initialized (nullptr device memory). + for (BufferAllocation::Index index : commands.allocs_indices()) { + se::DeviceMemoryBase alloc = allocs->GetDeviceAddress(index); + + if (recorded_allocs.size() <= index) { + recorded_allocs.resize(index + 1); + should_update = true; + } + + if (!recorded_allocs[index].IsSameAs(alloc)) { + recorded_allocs[index] = alloc; + should_update = true; + } + } + + return should_update; +} + +absl::Status CommandBufferThunk::Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + // We might end up with empty command sequence if all of the captured fusions + // are no-op (e.g. memcpy of size 0) and we have no emitted thunks for them. + if (commands_.empty()) return absl::OkStatus(); + + TF_RETURN_IF_ERROR(commands_.Prepare(params, resource_requests)); + + // Always prepare thunks if they are present so we are ready to fall back + // on them if we detect profiling activity. + if (thunks_.has_value()) { + for (auto& thunk : *thunks_) { + TF_RETURN_IF_ERROR(thunk->Prepare(params, resource_requests)); + } + } + + return absl::OkStatus(); +} + +absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { + // We might end up with empty command sequence if all of the captured fusions + // are no-op (e.g. memcpy of size 0) and we have no emitted thunks for them. + if (commands_.empty()) return absl::OkStatus(); + + TF_ASSIGN_OR_RETURN(std::shared_ptr cmd_buffer, + GetOrCreateCommandBuffer(params.executor)); + absl::MutexLock lock(&cmd_buffer->mutex); + + // Initialize commands. + TF_RETURN_IF_ERROR(commands_.Initialize(params, cmd_buffer->state)); + + // Always initialize thunks if they are present so we are ready to fall back + // on them if we detect profiling activity. + if (thunks_.has_value()) { + for (auto& thunk : *thunks_) { + TF_RETURN_IF_ERROR(thunk->Initialize(params)); + } + } + + // Construct ExecuteParams with empty fields for everything that is not needed + // for recording commands. + Thunk::ExecuteParams execute_params( + params.buffer_allocations, params.stream, + params.command_buffer_trace_stream, {}, params.collective_params, + params.collective_cliques, /*device_to_host_stream=*/nullptr, + /*host_to_device_stream=*/nullptr, + /*send_device_memory_function=*/nullptr, + /*recv_device_memory_function=*/nullptr); + + // If command buffer is in `kCreate` state it means that command buffer + // sequence was never recorded into it. We initialize all command buffers + // before execution, because command buffers when instantiated will allocate + // memory on device and this might lead to deadlocks when we have concurrent + // NCCL operations in flight. + if (cmd_buffer->command_buffer->state() == + se::CommandBuffer::State::kCreate && + cmd_buffer->ShouldUpdateCommandBuffer(commands_, execute_params)) { + VLOG(3) << "Initialize command buffer on device #" + << params.executor->device_ordinal() + << " by recoding command buffer cmd sequence" + << "; num_commands=" << commands_.size(); + + TraceMe trace([&] { + return TraceMeEncode("command_buffer::initialize", + {{"device", params.executor->device_ordinal()}, + {"num_commands", commands_.size()}}); + }); + + uint64_t start_micros = tsl::Env::Default()->NowMicros(); + + CommandBufferCmd::RecordParams record_params = {cmd_buffer->state}; + TF_RETURN_IF_ERROR(commands_.Record(execute_params, record_params, + cmd_buffer->command_buffer.get())); + + uint64_t end_micros = tsl::Env::Default()->NowMicros(); + VLOG(3) << "Initialized command buffer on device #" + << params.executor->device_ordinal() << " in " + << (end_micros - start_micros) + << " μs; num_commands=" << commands_.size(); + } + + return absl::OkStatus(); +} + +absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { + // We might end up with empty command sequence if all of the captured fusions + // are no-op (e.g. memcpy of size 0) and we have no emitted thunks for them. + if (commands_.empty()) return absl::OkStatus(); + + // TODO(b/290773547): Profiler (CUPTI) + CUDA graphs lead to memory + // corruption. As a work around disable command buffers (CUDA graphs) and run + // everything in op-by-op mode. + if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_.has_value()) { + VLOG(1) << "Execute command buffer thunk as a regular thunk sequence " + "because we detected active profiling session"; + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); + for (auto& thunk : *thunks_) { + auto scoped_annotation = + GetKernelAnnotation(annotations, thunk->profile_annotation()); + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); + } + return absl::OkStatus(); + } + + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(std::shared_ptr cmd_buffer, + GetOrCreateCommandBuffer(executor)); + + absl::MutexLock lock(&cmd_buffer->mutex); + + if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, params)) { + VLOG(3) << "Update command buffer on device #" << executor->device_ordinal() + << " by recoding command buffer cmd sequence" << " after " + << cmd_buffer->num_executions << " executions since last update" + << "; num_commands=" << commands_.size(); + + TraceMe trace([&] { + cmd_buffer->mutex.AssertHeld(); + return TraceMeEncode("command_buffer::update", + {{"device", executor->device_ordinal()}, + {"num_commands", commands_.size()}, + {"num_executions", cmd_buffer->num_executions}}); + }); + + uint64_t start_micros = tsl::Env::Default()->NowMicros(); + + CommandBufferCmd::RecordParams record_params = {cmd_buffer->state}; + TF_RETURN_IF_ERROR(commands_.Record(params, record_params, + cmd_buffer->command_buffer.get())); + + uint64_t end_micros = tsl::Env::Default()->NowMicros(); + VLOG(3) << "Updated command buffer in " << (end_micros - start_micros) + << " μs; num_commands=" << commands_.size(); + cmd_buffer->num_executions = 0; + } + + ++cmd_buffer->num_executions; + + VLOG(3) << "Execute command buffer on device #" << executor->device_ordinal() + << "; num_executions=" << cmd_buffer->num_executions; + + TraceMe trace([&] { + cmd_buffer->mutex.AssertHeld(); + return TraceMeEncode("command_buffer::execute", + {{"device", executor->device_ordinal()}, + {"num_commands", commands_.size()}, + {"num_executions", cmd_buffer->num_executions}}); + }); + + return executor->Submit(params.stream, *cmd_buffer->command_buffer); +} + +absl::StatusOr> +CommandBufferThunk::GetOrCreateCommandBuffer(se::StreamExecutor* executor) { + absl::MutexLock lock(&state_->mutex); + + // Check if command buffer already exists + if (auto it = state_->command_buffers.find(executor); + it != state_->command_buffers.end()) { + return it->second; + } + + // Create a new empty command buffer. + TF_ASSIGN_OR_RETURN(auto command_buffer, se::CommandBuffer::Create(executor)); + auto emplaced = state_->command_buffers.emplace( + executor, + std::make_shared(std::move(command_buffer))); + + return emplaced.first->second; +} + +//===----------------------------------------------------------------------===// +// Command buffer eviction +//===----------------------------------------------------------------------===// + +struct CommandBufferThunk::GlobalState { + absl::Mutex mutex; + std::vector> state + ABSL_GUARDED_BY(mutex); +}; + +CommandBufferThunk::GlobalState* CommandBufferThunk::GetGlobalState() { + static auto* global_state = new GlobalState(); + return global_state; +} + +void CommandBufferThunk::TrackCommandBuffers( + std::weak_ptr state) { + auto* global_state = GetGlobalState(); + absl::MutexLock global_state_lock(&global_state->mutex); + global_state->state.push_back(state); +} + +void CommandBufferThunk::EvictCommandBuffers() { + TraceMe trace([&] { return "EvictCommandBuffers"; }); + + auto* global_state = GetGlobalState(); + absl::MutexLock global_state_lock(&global_state->mutex); + VLOG(3) << "Evict command buffer thunk command buffers; tracked thunks = " + << global_state->state.size(); + + // Erase state for already destroyed thunks. + global_state->state.erase( + std::remove_if(global_state->state.begin(), global_state->state.end(), + [](auto& weak_ptr) { return weak_ptr.expired(); }), + global_state->state.end()); + + // Evict command buffers for all tracked thunks. + int64_t num_evicted = 0; + for (auto& weak_ptr : global_state->state) { + auto ptr = weak_ptr.lock(); + if (!ptr) continue; + + // Evict all command buffers. + absl::MutexLock state_lock(&ptr->mutex); + num_evicted += ptr->command_buffers.size(); + ptr->command_buffers.clear(); + } + + if (num_evicted > 0) { + VLOG(3) << "Evicted " << num_evicted + << " command buffer thunk command buffers"; + } +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/command_buffer_thunk.h b/xla/service/gpu/runtime/command_buffer_thunk.h new file mode 100644 index 0000000000000..c1c5f1feba64c --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_thunk.h @@ -0,0 +1,144 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/gpu/runtime/command_buffer_allocations.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla::gpu { + +class CommandBufferThunk : public Thunk { + public: + CommandBufferThunk(CommandBufferCmdSequence commands, ThunkInfo thunk_info, + std::optional thunks = std::nullopt); + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + // Return the allocation address that was lazilly allocated inside command + // buffer. This API is required when the buffers are allocated inside command + // buffer but will be consumed by non-command buffer operations. + absl::StatusOr GetCommandBufferAllocationAddress( + const ExecuteParams& params, int64_t index); + + private: + // Command buffer instantiated on a `se::StreamExecutor` instance, and + // auxiliary state required for efficient command buffer updates. + struct ExecutorCommandBuffer { + explicit ExecutorCommandBuffer( + std::unique_ptr command_buffer); + + // Returns true if `commands` cmd sequence has to be recorded into + // `command_buffer` to update it (see `recorded_allocs` below). + bool ShouldUpdateCommandBuffer(const CommandBufferCmdSequence& commands, + const Thunk::ExecuteParams& params) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex); + + // se::CommandBuffer is not thread safe, and we guard it with a mutex to + // guarantee that we do not mutate it concurrently. + absl::Mutex mutex; + std::unique_ptr command_buffer ABSL_GUARDED_BY(mutex); + + // A manager for an external state attached by commands in a command + // sequence to a command buffer. + CommandBufferCmd::StateManager state ABSL_GUARDED_BY(mutex); + + // TODO(ezhulenev): We need to move command buffer allocations all the way + // up to the GpuExecutable as we can have Allocate and Free commands in + // different command buffers. Consider making it a part of + // BufferAllocations (as std::unique_ptr member). + + // Memory allocations performed by a `command_buffer`. + CommandBufferAllocations allocations ABSL_GUARDED_BY(mutex); + + // Mapping from buffer allocation index to the device memory passed at + // that index to the last call of `commands_.Record(...)` for + // `command_buffer`. We can just use a vector instead of map because + // `BufferAllocation::Index` is a unique identifier assigned + // contiguously and thus can be used as array index. + // + // If no device memory addresses changed from a previous call to + // `Record`, we can skip command buffer update and simply submit it for + // execution on a stream. All other pieces of information (like thread + // and block sizes) captured by commands at construction time and do not + // change. + std::vector recorded_allocs ABSL_GUARDED_BY(mutex); + + // Number of command buffer executions since last update. + int64_t num_executions ABSL_GUARDED_BY(mutex) = 0; + }; + + // Command buffer thunk owns commands buffers instantiated on all executors. + struct State { + absl::Mutex mutex; + absl::flat_hash_map> + command_buffers ABSL_GUARDED_BY(mutex); + }; + + // Returns a command buffer instantiated for `executor` or creates new one. + absl::StatusOr> + GetOrCreateCommandBuffer(se::StreamExecutor* executor); + + // Each individual command buffer allocates state on device (CUDA graph) and + // it adds up pretty quickly. To prevent OOM errors we proactively evict + // command buffers from device by clearing command buffer thunk state. We use + // global state to track all command buffer thunks in a process and coordinate + // command buffer eviction. + struct GlobalState; + + // Returns a global state of tracked command buffers thunks. + static GlobalState* GetGlobalState(); + + // Adds command buffer thunk state for tracking. + static void TrackCommandBuffers(std::weak_ptr state); + + // Evicts all previously instantiated command buffers. + static void EvictCommandBuffers(); + + // Command sequence that initializes command buffers on each executor. + CommandBufferCmdSequence commands_; + + // Thunk sequence that executes the same commands as in `commands_` but using + // thunk mechanism. We use it as a fallback mechanism to work around CUPTI + // bugs that lead to memory corruption when CUPTI traces CUDA graph execution. + std::optional thunks_; + + // Command buffer thunk state allocated in heap to allow global (per-process) + // management of instantiated command buffers. + std::shared_ptr state_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_THUNK_H_ diff --git a/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/xla/service/gpu/runtime/command_buffer_thunk_test.cc new file mode 100644 index 0000000000000..4a78b621b3dd4 --- /dev/null +++ b/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -0,0 +1,1186 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/command_buffer_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/command_buffer_allocations.h" +#include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" // IWYU pragma: keep +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +#ifdef GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + +namespace xla::gpu { + +using MemoryAccess = CommandBufferCmd::MemoryAccess; +using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; + +static se::StreamExecutor* GpuExecutor() { + auto name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + auto* platform = se::PlatformManager::PlatformWithName(name).value(); + return platform->ExecutorForDevice(0).value(); +} + +static Thunk::ExecutableSource ExecutableSource() { + Thunk::ExecutableSource source = { +#if defined(GOOGLE_CUDA) + /*text=*/se::gpu::internal::kAddI32Kernel, + /*binary=*/{} +#elif defined(TENSORFLOW_USE_ROCM) + /*text=*/{}, + /*binary=*/se::gpu::internal::kAddI32KernelModule +#endif + }; + return source; +} + +static KernelArgsPacking CreateDefaultArgsPacking() { + using Packed = absl::StatusOr>; + + return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { + auto* mem_args = se::Cast(&args); + + return se::PackKernelArgs(mem_args->device_memory_args(), + args.number_of_shared_bytes()); + }; +} + +// Some of the tests rely on CUDA 12.3+ features. +static bool IsAtLeastCuda12300() { +#if defined(TENSORFLOW_USE_ROCM) + return false; +#endif +#if CUDA_VERSION >= 12030 + return true; +#endif + return false; +} + +// Give a short aliases to execution threads. +static constexpr auto s0 = ExecutionStreamId(0); +static constexpr auto s1 = ExecutionStreamId(1); + +TEST(CommandBufferThunkTest, MemcpyCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_b, slice_a, byte_length); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42)); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + +TEST(CommandBufferThunkTest, MemzeroCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42 + se::DeviceMemory a = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it zeroes the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `a` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 0)); +} + +TEST(CommandBufferThunkTest, Memset32Cmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42 + se::DeviceMemory a = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{84}); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it set the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `a` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 84)); +} + +TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + se::DeviceMemory a = executor->AllocateArray(2, 0); + TF_ASSERT_OK(stream->MemZero(&a, 2 * sizeof(int32_t))); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc(/*index=*/0, a.size(), /*color=*/0); + BufferAllocation::Slice slice0(&alloc, 0 * sizeof(int32_t), sizeof(int32_t)); + BufferAllocation::Slice slice1(&alloc, 1 * sizeof(int32_t), sizeof(int32_t)); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice0, int32_t{12}); + commands.Emplace(s1, slice1, int32_t{34}); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it set the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `a` data back to host. + std::vector dst(2, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, a.size())); + + ASSERT_EQ(dst, std::vector({12, 34})); +} + +// This test does the following operations: +// 1. Allocates memory region "a" and "c" outside command buffer. +// 2. Allocates memory region "b" inside command buffer. +// 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer. + +// 4. MemCopyDeviceToDevice from "b" to "c" inside command buffer. +// 5. Free memory region "b" inside command buffer. +// 6. Verify that region "c" has the same content as "a". +TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Prepare arguments: + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, alloc_b); + commands.Emplace(s0, slice_b, slice_a, byte_length); + commands.Emplace(s0, slice_c, slice_b, byte_length); + commands.Emplace(s0, alloc_b); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + se::DeviceMemory b(se::DeviceMemoryBase( + reinterpret_cast(BufferAllocations::kExternalAllocationMarker), + byte_length)); + se::DeviceMemory c = executor->AllocateArray(length, 0); + + auto external_allocation = std::make_unique(); + + BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + external_allocation.get()); + + ServiceExecutableRunOptions run_options; + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy( + dst.data(), allocations.GetMutableDeviceAddress(2), byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + +// This test does the following operations: +// 1. Allocates memory region "a" and "c" outside command buffer. +// 2. Allocates memory region "b" inside command buffer thunk 1. +// 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer 1. +// 4. MemCopyDeviceToDevice from "b" to "c" inside command buffer 2. +// 5. Free memory region "b" inside command buffer 2. +// 6. Verify that region "c" has the same content as "a". +TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Prepare arguments: + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + + // =================Thunk 1================================= + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands1; + commands1.Emplace(s0, alloc_b); + commands1.Emplace(s0, slice_b, slice_a, byte_length); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk1(std::move(commands1), Thunk::ThunkInfo(nullptr)); + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + se::DeviceMemory b(se::DeviceMemoryBase( + reinterpret_cast(BufferAllocations::kExternalAllocationMarker), + byte_length)); + se::DeviceMemory c = executor->AllocateArray(length, 0); + + auto external_allocation = std::make_unique(); + + BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + external_allocation.get()); + + ServiceExecutableRunOptions run_options; + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk1.ExecuteOnStream(params)); + + // =================Thunk 2================================= + CommandBufferCmdSequence commands2; + commands2.Emplace(s0, slice_c, slice_b, byte_length); + commands2.Emplace(s0, alloc_b); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk2(std::move(commands2), Thunk::ThunkInfo(nullptr)); + + // Execute command buffer thunk and verify that it copied the memory. + TF_ASSERT_OK(thunk2.ExecuteOnStream(params)); + + // Copy `c` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy( + dst.data(), allocations.GetMutableDeviceAddress(2), byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42)); +} + +TEST(CommandBufferThunkTest, LaunchCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Prepare buffer allocation for updating command buffer: c=0 + se::DeviceMemory c = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Update buffer allocation #1 to buffer `c`. + allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); +} + +TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + auto packing = CreateDefaultArgsPacking(); + + se::MultiKernelLoaderSpec spec(/*arity=*/3, std::move(packing)); + spec.AddInProcessSymbol(se::gpu::internal::GetAddI32Kernel(), "add"); + + auto custom_kernel = + CustomKernel("add", std::move(spec), se::BlockDim(), + se::ThreadDim(4, 1, 1), /*shared_memory_bytes=*/0); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Prepare buffer allocation for updating command buffer: c=0 + se::DeviceMemory c = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Update buffer allocation #1 to buffer `c`. + allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); +} + +TEST(CommandBufferThunkTest, GemmCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 4 * 3; + int64_t out_length = sizeof(float) * 2 * 3; + + // Prepare arguments: + // lhs = [1.0, 2.0, 3.0, 4.0 + // 5.0, 6.0, 7.0, 8.0] + // rhs = [1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0] + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + se::DeviceMemory rhs = executor->AllocateArray(4 * 3); + std::vector rhs_arr(12, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); + BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); + BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); + BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); + + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, config.value(), slice_lhs, slice_rhs, slice_out, + slice_workspace, + /*deterministic=*/true); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({lhs, rhs, out, workspace}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Execute command buffer thunk and verify that it executed a GEMM. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `out` data back to host. + std::vector dst(6, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Prepare buffer allocation for updating command buffer. + se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); + + // Update buffer allocation to updated `out` buffer. + allocations = BufferAllocations({lhs, rhs, updated_out, workspace}, 0, + executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); +} + +TEST(CommandBufferThunkTest, MultipleLaunchCmd) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory c = executor->AllocateArray(length, 0); + se::DeviceMemory d = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + TF_ASSERT_OK(stream->Memset32(&c, 21, byte_length)); + TF_ASSERT_OK(stream->MemZero(&d, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation alloc_d(/*index=*/3, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + BufferAllocation::Slice slice_d(&alloc_d, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_1 = {slice_c, slice_c, slice_d}; // d = c + c + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + commands.Emplace(s0, "add", args_1, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b, c, d}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length)); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); + + BufferAllocation alloc_e(/*index=*/3, byte_length, /*color=*/0); + BufferAllocation::Slice slice_e(&alloc_e, 0, byte_length); + + // Prepare buffer allocation for updating command buffer: e=0 + se::DeviceMemory e = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&e, byte_length)); + + // Update buffer allocation #1 to buffer `c`. + allocations = BufferAllocations({a, b, c, e}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `e` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), e, byte_length)); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&e, byte_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `e` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), e, byte_length)); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); +} + +TEST(CommandBufferThunkTest, IfCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: pred=true, a=42, b=0 + se::DeviceMemory pred = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_p(&alloc_p, 0, 1); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for `then` branch. + CommandBufferCmdSequence then_commands; + then_commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_p, std::move(then_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Prepare buffer allocation for updating command buffer: c=0 + se::DeviceMemory c = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Update buffer allocation #2 to buffer `c`. + allocations = BufferAllocations({pred, a, c}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); +} + +TEST(CommandBufferThunkTest, IfElseCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: pred=true, a=42, b=0 + se::DeviceMemory pred = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_p(&alloc_p, 0, 1); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for `then` & `else` branches. + CommandBufferCmdSequence then_commands; + CommandBufferCmdSequence else_commands; + + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + { // Then: b = a + a + auto args = {slice_a, slice_a, slice_b}; + then_commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + { // Else: b = b + b + auto args = {slice_b, slice_b, slice_b}; + else_commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_p, std::move(then_commands), + std::move(else_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Change branch to `else` and check that it updated the `b` buffer. + constexpr bool kFalse = false; + TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); + + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); +} + +TEST(CommandBufferThunkTest, CaseCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: index=0, a=42, b=0 + se::DeviceMemory index = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_i(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_i(&alloc_i, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + // Prepare commands sequence for branches. + std::vector branches(2); + + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + { // Case 0: b = a + a + auto args = {slice_a, slice_a, slice_b}; + branches[0].Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + { // Case 1: b = b + b + auto args = {slice_b, slice_b, slice_b}; + branches[1].Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + } + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_i, std::move(branches)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({index, a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Change `index` to `1` and check that it updated the `b` buffer. + TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t))); + + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, std::vector(4, 2 * (42 + 42))); +} + +TEST(CommandBufferThunkTest, ForCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: loop_cnt=0, a=1, b=0 + se::DeviceMemory loop_cnt = executor->AllocateArray(1, 0); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&loop_cnt, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_cnt(/*index=*/0, 1, /*color=*/0); + BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_cnt(&alloc_cnt, 0, sizeof(int32_t)); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + + auto args = {slice_a, slice_b, slice_b}; // b = a + b + auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, + MemoryAccess::kWrite}; + + // Prepare commands sequence for loop `body`. + CommandBufferCmdSequence body_commands; + body_commands.Emplace(s0, "add", args, args_access, + LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Prepare commands sequence for thunk. + CommandBufferCmdSequence commands; + commands.Emplace(s0, /*num_iterations=*/10, slice_cnt, + std::move(body_commands)); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({loop_cnt, a, b}, 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = + Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), + stream.get(), {}, nullptr, nullptr); + + Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, stream.get()})); + + // Execute command buffer thunk and verify that it added the value 10 times. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 10)); +} + +TEST(CommandBufferThunkTest, WhileCmd) { + // TODO(ezhulenev): Find a way to test WhileCmd: add a test only TraceCmd that + // could allow us trace custom kernels to update while loop iterations. Or + // maybe add a CustomLaunchCmd and wrap loop update into custom kernel. +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/concurrent_region.cc b/xla/service/gpu/runtime/concurrent_region.cc deleted file mode 100644 index 083259960f407..0000000000000 --- a/xla/service/gpu/runtime/concurrent_region.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/concurrent_region.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/stream_pool.h" -#include "xla/stream_executor/stream.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -//===----------------------------------------------------------------------===// -// Definitions for ConcurrentRegionStatus. -//===----------------------------------------------------------------------===// - -ConcurrentRegionStatus::ConcurrentRegionStatus( - const ServiceExecutableRunOptions* run_options, int num_borrowed_streams) - : num_borrowed_streams_(num_borrowed_streams), - run_options_(run_options), - stream_index_(0), - capture_stream_(nullptr) {} - -ConcurrentRegionStatus::~ConcurrentRegionStatus() { - DCHECK(!IsInConcurrentRegion()); -} - -// Assign a stream in a round-robin fashion. Either the capture stream or one of -// the borrowed streams is returned. -se::Stream* ConcurrentRegionStatus::GetNextStream() { - DCHECK(IsInConcurrentRegion()); - if (borrowed_streams_.empty()) { - return nullptr; - } - - int index = stream_index_ % (borrowed_streams_.size() + 1); - stream_index_++; - - if (index == 0) { - return capture_stream_; - } - - return borrowed_streams_[index - 1].get(); -} - -absl::StatusOr ConcurrentRegionStatus::GetStream(int index) { - DCHECK(IsInConcurrentRegion()); - - if (index < 0 || index >= region_size_) { - return absl::OutOfRangeError("Invalid stream index"); - } - - if (index == 0) { - return capture_stream_; - } - - return borrowed_streams_[index - 1].get(); -} - -absl::Status ConcurrentRegionStatus::StartConcurrentRegion( - se::Stream* capture_stream, int64_t size) { - if (disabled_) { - return absl::OkStatus(); - } - - DCHECK(!IsInConcurrentRegion()); - se::StreamExecutor* executor = run_options_->stream()->parent(); - - // Stream borrowing should only happen in the first call to this function. - if (borrowed_streams_.empty()) { - TF_ASSIGN_OR_RETURN(std::vector borrowed_streams, - run_options_->BorrowStreams(executor->device_ordinal(), - num_borrowed_streams_)); - for (StreamPool::Ptr& stream : borrowed_streams) { - borrowed_streams_.push_back(std::move(stream)); - } - } - - // Switch borrowed streams into capture mode. We only synchronize enough - // streams to run the kernels. - for (int i = 0; i < std::min(size - 1, num_borrowed_streams_); ++i) { - borrowed_streams_[i]->ThenWaitFor(capture_stream); - } - - region_size_ = size; - capture_stream_ = capture_stream; - return absl::OkStatus(); -} - -void ConcurrentRegionStatus::EndConcurrentRegion() { - if (disabled_) { - return; - } - - DCHECK(IsInConcurrentRegion()); - - // Synchronize main capture stream with all borrowed streams in capture mode. - for (int i = 0; i < std::min(region_size_ - 1, num_borrowed_streams_); - ++i) { - capture_stream_->ThenWaitFor(borrowed_streams_[i].get()); - } - - stream_index_ = 0; - capture_stream_ = nullptr; -} - -bool ConcurrentRegionStatus::IsInConcurrentRegion() { - return capture_stream_ != nullptr; -} - -//===----------------------------------------------------------------------===// -// Define custom calls that mark the concurrent region in CUDA graphs. -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCall; - -static absl::Status RegionBegin(const ServiceExecutableRunOptions* run_options, - ConcurrentRegionStatus* region_status, - int64_t size) { - se::Stream* capture_stream = run_options->stream(); - return region_status->StartConcurrentRegion(capture_stream, size); -} - -static absl::Status RegionEnd(ConcurrentRegionStatus* region_status) { - region_status->EndConcurrentRegion(); - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Begin, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.concurrent_region.begin") - .UserData() - .UserData() - .Attr("size")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(End, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.concurrent_region.end") - .UserData()); - -void RegisterConcurrentRegionCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.concurrent_region.begin", Begin); - registry.Register("xla.gpu.concurrent_region.end", End); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/concurrent_region.h b/xla/service/gpu/runtime/concurrent_region.h deleted file mode 100644 index c662282ef6efa..0000000000000 --- a/xla/service/gpu/runtime/concurrent_region.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime kernel launch custom calls. -void RegisterConcurrentRegionCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// The state to keep track of the information regarding concurrent regions -// between custom calls. -class ConcurrentRegionStatus { - public: - explicit ConcurrentRegionStatus( - const ServiceExecutableRunOptions* run_options, - int num_borrowed_streams = 10); - - ~ConcurrentRegionStatus(); - - absl::Status StartConcurrentRegion(se::Stream* capture_stream, int64_t size); - void EndConcurrentRegion(); - - // Temporarily disable concurrent execution when we run GPU graphs op-by-op. - // If disabled_ is set to true, StartConcurrentRegion will become an no-op and - // IsInConcurrentRegion always returns false. - void DisableConcurrentRegion() { disabled_ = true; } - void EnableConcurrentRegion() { disabled_ = false; } - - // Get a stream on which the concurrent-executable kernel runs. It returns a - // different stream each time to avoid building dependencies in the CUDA - // graph. - se::Stream* GetNextStream(); - - absl::StatusOr GetStream(int index); - - bool IsInConcurrentRegion(); - - private: - const int num_borrowed_streams_; - std::vector borrowed_streams_; - const ServiceExecutableRunOptions* run_options_; - - bool disabled_ = false; - int32_t stream_index_; - - // It is set to nullptr if not in a concurrent region. - se::Stream* capture_stream_; - int region_size_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CONCURRENT_REGION_H_ diff --git a/xla/service/gpu/runtime/conditional_thunk.cc b/xla/service/gpu/runtime/conditional_thunk.cc new file mode 100644 index 0000000000000..8680126792fe0 --- /dev/null +++ b/xla/service/gpu/runtime/conditional_thunk.cc @@ -0,0 +1,128 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/conditional_thunk.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/variant_visitor.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +ConditionalThunk::ConditionalThunk( + ThunkInfo thunk_info, ConditionalThunkConfig config, + const BufferAllocation::Slice& branch_index_buffer_index) + : Thunk(Kind::kConditional, thunk_info), + config_(std::move(config)), + branch_index_buffer_index_(branch_index_buffer_index) {} + +absl::Status ConditionalThunk::Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + if (config_.branch_index_is_bool) { + TF_RET_CHECK(config_.branch_thunks.size() == 2); + } else { + TF_RET_CHECK(!config_.branch_thunks.empty()); + } + for (auto& branch_thunk : config_.branch_thunks) { + TF_RETURN_IF_ERROR(branch_thunk->Prepare(params, resource_requests)); + } + return absl::OkStatus(); +} + +absl::Status ConditionalThunk::Initialize(const InitializeParams& params) { + if (config_.branch_index_is_bool) { + TF_RET_CHECK(config_.branch_thunks.size() == 2); + } else { + TF_RET_CHECK(!config_.branch_thunks.empty()); + } + for (auto& branch_thunk : config_.branch_thunks) { + TF_RETURN_IF_ERROR(branch_thunk->Initialize(params)); + } + + absl::MutexLock lock(&mutex_); + if (auto it = predicates_.find(params.executor); it == predicates_.end()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr allocation, + params.executor->HostMemoryAllocate( + config_.branch_index_is_bool ? sizeof(bool) : sizeof(int32_t))); + predicates_.emplace(params.executor, std::move(allocation)); + } + + return absl::OkStatus(); +} + +absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { + auto& stream = *params.stream; + + // Copy the predicate value from device. + auto branch_index_or_pred = [&]() -> std::variant { + absl::MutexLock lock(&mutex_); + se::StreamExecutor* executor = stream.parent(); + if (config_.branch_index_is_bool) { + return reinterpret_cast(predicates_.at(executor)->opaque()); + } else { + return reinterpret_cast(predicates_.at(executor)->opaque()); + } + }(); + + se::DeviceMemoryBase branch_index_address = + params.buffer_allocations->GetDeviceAddress(branch_index_buffer_index_); + if (config_.branch_index_is_bool) { + TF_RETURN_IF_ERROR(stream.Memcpy(std::get(branch_index_or_pred), + branch_index_address, sizeof(bool))); + } else { + TF_RETURN_IF_ERROR(stream.Memcpy(std::get(branch_index_or_pred), + branch_index_address, sizeof(int32_t))); + } + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return Internal("Failed to retrieve branch_index value on stream %p: %s.", + &stream, blocked.message()); + } + + int32_t branch_index = std::visit( + VariantVisitor{[](int32_t* branch_index) { return *branch_index; }, + [](bool* pred) { return *pred ? 0 : 1; }}, + branch_index_or_pred); + + // Handle default scenario for branch_index not in [0, num_branches). + if (branch_index < 0 || branch_index >= config_.branch_count) { + branch_index = config_.branch_count - 1; + } + + // Execute the branch computation corresponding to the value of branch_index. + TF_RETURN_IF_ERROR( + config_.branch_thunks[branch_index]->ExecuteOnStream(params)); + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/conditional_thunk.h b/xla/service/gpu/runtime/conditional_thunk.h new file mode 100644 index 0000000000000..0d8109b70c09b --- /dev/null +++ b/xla/service/gpu/runtime/conditional_thunk.h @@ -0,0 +1,89 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +struct ConditionalThunkConfig { + bool branch_index_is_bool; + int64_t branch_count; + std::vector> branch_thunks; +}; + +// ConditionalThunk implements the conditional instruction on GPU by reading the +// predicate of the conditional and executing the true or the false computation +// depending on the value of the predicate. +// +// ConditionalThunk assumes that the buffers of the conditional result and the +// result of the true and false computations share the same allocation. Also, +// the buffers of the true operand of the conditional and that of the parameter +// instruction of the true computation share the same allocation. Similarly, the +// buffers of the false operand and that of the parameter instruction of the +// false computation share the same allocation. +class ConditionalThunk : public Thunk { + public: + ConditionalThunk(ThunkInfo thunk_info, ConditionalThunkConfig config, + const BufferAllocation::Slice& branch_index_buffer_index); + + ConditionalThunk(const ConditionalThunk&) = delete; + ConditionalThunk& operator=(const ConditionalThunk&) = delete; + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + absl::Span> branch_thunks() const { + return config_.branch_thunks; + } + + const BufferAllocation::Slice& branch_index_buffer() const { + return branch_index_buffer_index_; + } + + private: + const ConditionalThunkConfig config_; + const BufferAllocation::Slice branch_index_buffer_index_; + + // Pinned host memory for transferring predicate value from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + predicates_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_CONDITIONAL_THUNK_H_ diff --git a/xla/service/gpu/runtime/conv.cc b/xla/service/gpu/runtime/conv.cc deleted file mode 100644 index 568127a6c8bba..0000000000000 --- a/xla/service/gpu/runtime/conv.cc +++ /dev/null @@ -1,689 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/conv.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -#if GOOGLE_CUDA -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/conv_algorithm_picker.h" -#endif - -namespace xla { - -using xla::runtime::AggregateAttrDef; -using xla::runtime::AggregateAttrEncoding; -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace mhlo = ::mlir::mhlo; - -//===----------------------------------------------------------------------===// -// Structs for encoding convolution attributes defined in MHLO dialect. -//===----------------------------------------------------------------------===// - -namespace gpu { - -struct ConvDimensionNumbers { - int64_t input_batch_dim; - int64_t input_feature_dim; - absl::Span input_spatial_dims; - - int64_t kernel_in_feature_dim; - int64_t kernel_out_feature_dim; - absl::Span kernel_spatial_dims; - - int64_t output_batch_dim; - int64_t output_feature_dim; - absl::Span output_spatial_dims; -}; - -struct ConvBackendConfig { - int64_t algorithm; - bool tensor_ops_enabled; - bool is_cudnn_frontend; - bool is_cudnn_reordered_int8; - absl::Span knob_ids; - absl::Span knob_values; - absl::Span operand_0_layout; - absl::Span operand_1_layout; - absl::Span result_layout; - int64_t workspace_size; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register convolution attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::dnn::ActivationMode); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::ConvDimensionNumbers, - // --- input dimensions - AggregateMember("input_batch_dim"), - AggregateMember("input_feature_dim"), - AggregateMember>("input_spatial_dims"), - // --- kernel dimensions - AggregateMember("kernel_in_feature_dim"), - AggregateMember("kernel_out_feature_dim"), - AggregateMember>("kernel_spatial_dims"), - // --- output dimensions - AggregateMember("output_batch_dim"), - AggregateMember("output_feature_dim"), - AggregateMember>("output_spatial_dims")); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::ConvBackendConfig, // - AggregateMember("algorithm"), - AggregateMember("tensor_ops_enabled"), - AggregateMember("is_cudnn_frontend"), - AggregateMember("is_cudnn_reordered_int8"), - AggregateMember>("knob_ids"), - AggregateMember>("knob_values"), - AggregateMember>("operand_0_layout"), - AggregateMember>("operand_1_layout"), - AggregateMember>("result_layout"), - AggregateMember("workspace_size")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void RegisterConvTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>("__type_id_conv_dim_numbers"); - registry.Register>("__type_id_conv_backend_config"); -} - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -// TODO(ezhulenev): We have to support enum encoding that can fail instead of -// always getting the value from returned StatusOr. -static auto EncodeConvActivation(lmhlo_gpu::Activation activation) { - return ConvertConvActivationMode(activation).value(); -} - -void PopulateConvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::ActivationAttr`. - encoding - .Add>(EncodeConvActivation); - } - - { // --- Encode `mhlo::ConvDimensionNumbersAttr`. - using Attr = mhlo::ConvDimensionNumbersAttr; - encoding.Add>( - encoding, - AggregateAttrDef() - .Add("input_batch_dim", &Attr::getInputBatchDimension) - .Add("input_feature_dim", &Attr::getInputFeatureDimension) - .Add("input_spatial_dims", &Attr::getInputSpatialDimensions) - .Add("kernel_in_feature_dim", &Attr::getKernelInputFeatureDimension) - .Add("kernel_out_feature_dim", - &Attr::getKernelOutputFeatureDimension) - .Add("kernel_spatial_dims", &Attr::getKernelSpatialDimensions) - .Add("output_batch_dim", &Attr::getOutputBatchDimension) - .Add("output_feature_dim", &Attr::getOutputFeatureDimension) - .Add("output_spatial_dims", &Attr::getOutputSpatialDimensions)); - } - - { // --- Encode `lmhlo_gpu::ConvolutionBackendConfigAttr`. - using Attr = lmhlo_gpu::ConvolutionBackendConfigAttr; - encoding.Add>( - encoding, - AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("tensor_ops_enabled", &Attr::getTensorOpsEnabled) - .Add("is_cudnn_frontend", &Attr::getIsCudnnFrontend) - .Add("is_cudnn_reordered_int8", &Attr::getIsCudnnReorderedInt8) - .Add("knob_ids", &Attr::getKnobIds) - .Add("knob_values", &Attr::getKnobValues) - .Add("operand_0_layout", &Attr::getOperand_0Layout) - .Add("operand_1_layout", &Attr::getOperand_1Layout) - .Add("result_layout", &Attr::getResultLayout) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} - -//===----------------------------------------------------------------------===// -// Convolution runners caching. -//===----------------------------------------------------------------------===// - -StreamExecutorConvRunners* ConvRunners::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -//===----------------------------------------------------------------------===// -// Convolution custom call implementation. -//===----------------------------------------------------------------------===// - -namespace { - -struct Window { - absl::Span window_strides; - absl::Span padding; - absl::Span lhs_dilation; - absl::Span rhs_dilation; - absl::Span window_reversal; -}; - -struct ConvAttrs { - int64_t feature_group_count; - double result_scale; -}; - -struct FusedConvAttrs { - se::dnn::ActivationMode activation_mode; -}; - -struct SideInputAttrs { - double side_input_scale; -}; - -struct LeakyReluAlphaAttrs { - double leaky_relu_alpha; -}; - -} // namespace - -static GpuConvDescriptor GetConvDescriptor( - CudnnConvKind kind, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - StridedMemrefView output, FlatMemrefView scratch, - // Attributes - ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs, - // Conv-specific arguments and attributes - std::optional fused = std::nullopt, - std::optional side_input = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt) { - // Build a convolution descriptor from the attributes. - GpuConvDescriptor descriptor; - descriptor.kind = kind; - - // Apply backend config layout to the shape. - auto apply_layout = [](StridedMemrefView& memref, - absl::Span minor_to_major) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - descriptor.operand0_shape = apply_layout(operand0, b.operand_0_layout); - descriptor.operand1_shape = apply_layout(operand1, b.operand_1_layout); - descriptor.result_shape = apply_layout(output, b.result_layout); - - // Set up convolution dimensions numbers. - ConvolutionDimensionNumbers dns; - dns.set_input_batch_dimension(dims.input_batch_dim); - dns.set_input_feature_dimension(dims.input_feature_dim); - dns.set_kernel_input_feature_dimension(dims.kernel_in_feature_dim); - dns.set_kernel_output_feature_dimension(dims.kernel_out_feature_dim); - dns.set_output_batch_dimension(dims.output_batch_dim); - dns.set_output_feature_dimension(dims.output_feature_dim); - for (int64_t d : dims.input_spatial_dims) dns.add_input_spatial_dimensions(d); - for (int64_t d : dims.kernel_spatial_dims) - dns.add_kernel_spatial_dimensions(d); - for (int64_t d : dims.output_spatial_dims) - dns.add_output_spatial_dimensions(d); - descriptor.dnums = std::move(dns); - - // Put together convolution window config. - for (auto index : llvm::seq(0, w.window_strides.size())) { - WindowDimension* dim = descriptor.window.add_dimensions(); - // Window size for a convolution is the same as the kernel size. - // Kernel size of the convolution is operand1_shape. We need to look at - // the convolution dimension numbers kernel spatial dimensions to get - // the window size. - int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index); - dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim)); - dim->set_stride(w.window_strides[index]); - dim->set_padding_low(w.padding[index]); - dim->set_padding_high(w.padding[index]); - dim->set_base_dilation(w.lhs_dilation[index]); - dim->set_window_dilation(w.rhs_dilation[index]); - dim->set_window_reversal(w.window_reversal[index]); - } - - descriptor.scratch_size = scratch.size_in_bytes; - descriptor.feature_group_count = attrs.feature_group_count; - descriptor.backend_config.set_conv_result_scale(attrs.result_scale); - descriptor.backend_config.set_reordered_int8_nchw_vect( - b.is_cudnn_reordered_int8); - - // Set up convolution algorigthm. - auto* algo = descriptor.backend_config.mutable_algorithm(); - algo->set_algo_id(b.algorithm); - algo->set_math_type(b.tensor_ops_enabled - ? se::dnn::AlgorithmProto::TENSOR_OP_MATH - : se::dnn::AlgorithmProto::DEFAULT_MATH); - algo->set_is_cudnn_frontend(b.is_cudnn_frontend); - - if (b.workspace_size >= 0) - algo->mutable_workspace_size()->set_value(b.workspace_size); - - for (unsigned i = 0; i < b.knob_ids.size(); ++i) { - algo->mutable_tuning_knobs()->insert({b.knob_ids[i], b.knob_values[i]}); - } - - // Set attributes specific for fused convolutions. - if (fused.has_value()) - descriptor.backend_config.set_activation_mode(fused->activation_mode); - - // Set attributes specific for fused convolutions with leaky_relu_alpha. - if (leakyrelu_alpha.has_value()) - descriptor.backend_config.set_leakyrelu_alpha( - leakyrelu_alpha->leaky_relu_alpha); - - // Set attributes specific for convolutions with side input. - if (side_input.has_value()) - descriptor.backend_config.set_side_input_scale( - side_input->side_input_scale); - - return descriptor; -} - -template -static absl::Status DoConv( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - std::optional bias, - std::optional side_input, - absl::Span outputs, FlatMemrefView scratch, - int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, - // Optional attributes for fused convolutions. - std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt, - // Optional extra arguments for graph convolutions. - absl::Span extra_operands = {}, - std::optional serialized_graph = std::nullopt) { - // Build config for optional attributes. - std::optional fused_attrs = std::nullopt; - if (activation_mode.has_value()) fused_attrs = {*activation_mode}; - - std::optional side_input_attrs = std::nullopt; - if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale}; - - std::optional leakyrelu_alpha_attrs = std::nullopt; - if (leakyrelu_alpha.has_value()) leakyrelu_alpha_attrs = {*leakyrelu_alpha}; - - bool runtime_autotuning = false; - if (backend_config.algorithm == -1) { - // Set the algorithm back to the default algorithm to avoid error from - // cuDNN. - backend_config.algorithm = 0; - runtime_autotuning = true; - } - - // Get or create the convolution runner state. - TF_ASSIGN_OR_RETURN( - ConvRunner * conv, - runner.GetOrCreate([&]() -> absl::StatusOr { - GpuConvDescriptor descriptor = GetConvDescriptor( - kind, operand0, operand1, outputs[0], scratch, conv_dims, - {window_strides, padding, lhs_dilation, rhs_dilation, - window_reversal}, - backend_config, {feature_group_count, result_scale}, fused_attrs, - side_input_attrs, leakyrelu_alpha_attrs); - if (serialized_graph.has_value()) { - descriptor.backend_config.set_serialized_graph( - std::string(serialized_graph.value())); - } - TF_ASSIGN_OR_RETURN(GpuConvConfig conv_config, - GetGpuConvConfig(descriptor, "")); - - return ConvRunner(std::move(conv_config)); - })); - - // Prepare buffer arguments. - std::vector buffers = {GetDeviceAddress(operand0), - GetDeviceAddress(operand1)}; - if (bias.has_value()) buffers.push_back(GetDeviceAddress(*bias)); - if (side_input.has_value()) buffers.push_back(GetDeviceAddress(*side_input)); - for (const StridedMemrefView& operand : extra_operands) { - buffers.push_back(GetDeviceAddress(operand)); - } - - std::vector result_buffers; - for (const StridedMemrefView& output : outputs) { - result_buffers.push_back(GetDeviceAddress(output)); - } - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - int64_t scratch_buffer_size = scratch_buffer.size(); - - // Do runtime conv autotuning. - if (runtime_autotuning) { -#if GOOGLE_CUDA - // Don't run autotuning concurrently on the same GPU. - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - gpu_lock->UpgradeToWriterMutexLock(); - - auto stream_exec = run_options->stream()->parent(); - auto allocator = run_options->allocator(); - AutotuneConfig config(DeviceConfig{stream_exec, allocator}, *debug_options); - GpuConvAlgorithmPicker conv_algorithm_picker(config); - - GpuConvConfig gpu_conv_config = conv->config; - TF_ASSIGN_OR_RETURN( - AutotuneResult best_algo, - conv_algorithm_picker.PickBestAlgorithmWithAllocatedBuffer( - config, gpu_conv_config, run_options, *debug_options, buffers, - result_buffers)); - - // Set algorithm in the convolution runner state. - se::dnn::AlgorithmDesc algo_desc(best_algo.conv().algorithm(), - best_algo.conv().tensor_ops_enabled()); - conv->config.algorithm = algo_desc; - - // Set scratch buffer size according to the selected algorithm. - scratch_buffer_size = best_algo.scratch_bytes(); -#else - return absl::InternalError( - "Failed to run runtime autotuner because CUDA is not enabled"); -#endif - } - - RunConvOptions opts; - opts.runner_cache = &conv->runner; - - if (scratch_buffer_size > scratch_buffer.size()) { - // Need to reallocate scratch buffer. - se::DeviceMemoryAllocator* allocator = run_options->allocator(); - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, - allocator->Allocate(run_options->device_ordinal(), - scratch_buffer_size)); - se::DeviceMemoryBase new_scratch_buffer(allocated_buffer.ptr(), - scratch_buffer_size); - - // Run the convolution using the new scratch buffer. - TF_RETURN_IF_ERROR(RunGpuConv(conv->config, buffers, result_buffers, - new_scratch_buffer, run_options->stream(), - opts)); - if (!run_options->stream()->ok()) { - return absl::InternalError("run_options stream not ok"); - } - return absl::OkStatus(); - } - - // Run the convolution. - TF_RETURN_IF_ERROR(RunGpuConv(conv->config, buffers, result_buffers, - scratch_buffer, run_options->stream(), opts)); - if (!run_options->stream()->ok()) { - return absl::InternalError("run_options stream not ok"); - } - - return absl::OkStatus(); -} - -template -static absl::Status ConvImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - std::optional bias, - std::optional side_input, StridedMemrefView output, - FlatMemrefView scratch, int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, - // Optional attributes for fused convolutions. - std::optional activation_mode = std::nullopt, - std::optional side_input_scale = std::nullopt, - std::optional leakyrelu_alpha = std::nullopt) { - return DoConv(run_options, debug_options, gpu_lock, runner, operand0, - operand1, bias, side_input, {output}, scratch, uid, - conv_dims, window_strides, padding, lhs_dilation, - rhs_dilation, window_reversal, backend_config, - feature_group_count, result_scale, activation_mode, - side_input_scale, leakyrelu_alpha); -} - -template -static absl::Status ConvGraphImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, NonAtomicallyUpgradeableRWLock* gpu_lock, - State runner, - // Arguments - StridedMemrefView operand0, StridedMemrefView operand1, - CustomCall::RemainingArgs args, int64_t uid, - // Convolution config - ConvDimensionNumbers conv_dims, - // Window config - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - absl::Span window_reversal, - // Backend config attributes - ConvBackendConfig backend_config, - // Remaining attributes - int64_t feature_group_count, double result_scale, int32_t n_aux_outputs, - std::string_view serialized_graph) { - // Let N be the size of 'args'. The first (N - n_aux_outputs - 2) elements of - // 'args' are extra operands, which are operands other than the input and - // filter. The next (n_aux_outputs + 1) elements are the outputs -- the first - // being the main convolution output and the others being the "auxiliary" - // outputs (e.g. amax). The last element of 'args' is the scratch space. - std::vector extra_operands; - for (int i = 0; i < args.size() - n_aux_outputs - 2; i++) { - auto arg = args.get(i); - if (failed(arg)) { - return absl::InternalError( - "Failed to get operand buffer for convolution graph"); - } - extra_operands.push_back(arg.value()); - } - - std::vector outputs; - for (int i = args.size() - n_aux_outputs - 2; i < args.size() - 1; i++) { - auto arg = args.get(i); - if (failed(arg)) { - return absl::InternalError( - "Failed to get output buffer for convolution graph"); - } - outputs.push_back(arg.value()); - } - - auto scratch = args.get(args.size() - 1); - if (failed(scratch)) { - return absl::InternalError( - "Failed to get scratch buffer for convolution graph"); - } - - return DoConv(run_options, debug_options, gpu_lock, runner, operand0, - operand1, /*bias=*/{}, - /*side_input=*/{}, outputs, scratch.value(), uid, - conv_dims, window_strides, padding, lhs_dilation, - rhs_dilation, window_reversal, backend_config, - feature_group_count, result_scale, /*activation_mode=*/{}, - /*side_input_scale=*/{}, /*leakyrelu_alpha=*/{}, - extra_operands, serialized_graph); -} - -//===----------------------------------------------------------------------===// -// Convolution custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -using Kind = CudnnConvKind; - -template -static auto BindConvAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - // Unique convolution id for caching state. - .template Attr("uid") - // Convolution dimensions numbers - .template Attr("conv_dims") - // Window config - .template Attr>("window_strides") - .template Attr>("padding") - .template Attr>("lhs_dilation") - .template Attr>("rhs_dilation") - .template Attr>("window_reversal") - // Backend config attributes - .template Attr("backend_config") - // Remaining attributes. - .template Attr("feature_group_count") - .template Attr("result_scale"); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( - Kind kind, Conv, FunctionWrapper>(), checks, - BindConvAttributes( - CustomCall::Bind("xla.gpu.conv") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Value(std::optional()) // bias - .Value(std::optional()) // side_input - .Arg() // output - .Arg() // scratch - ) - .Value(std::optional()) // activation_mode - .Value(std::optional()) // side_input_scale - .Value(std::optional()) // leaky_relu_alpha -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvFused, FunctionWrapper>(), checks, - BindConvAttributes( - CustomCall::Bind("xla.gpu.conv.fused") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Arg() // bias - .Value(std::optional()) // side_input - .Arg() // output - .Arg() // scratch - ) - .Attr("activation_mode") - .Value(std::optional()) // side_input_scale - .Attr("leakyrelu_alpha") // leaky_relu_alpha -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvFusedSideInput, FunctionWrapper>(), - checks, - BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .Arg() // bias - .Arg() // side_input - .Arg() // output - .Arg() // scratch - ) - .Attr("activation_mode") - .Attr("side_input_scale") - .Value(std::optional())); // leaky_relu_alpha - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvForwardGraph, FunctionWrapper>(), - checks, - BindConvAttributes(CustomCall::Bind("xla.gpu.conv.forward.graph") - .UserData() - .UserData() - .UserData() - .State("uid") // runner - .Arg() // operand0 - .Arg() // operand1 - .RemainingArgs() // binary_operands - ) - .Attr("n_aux_outputs") - .Attr("serialized_graph")); - -//===----------------------------------------------------------------------===// - -void RegisterConvCustomCalls(runtime::DirectCustomCallRegistry& registry) { - auto conv = [](std::string name) { return "xla.gpu.conv." + name; }; - registry.Register(conv("forward"), Conv); - registry.Register(conv("backward.input"), Conv); - registry.Register(conv("backward.filter"), Conv); - registry.Register(conv("forward.fused"), ConvFused); - registry.Register(conv("forward.fused.side_input"), ConvFusedSideInput); - registry.Register(conv("forward.graph"), ConvForwardGraph); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/conv.h b/xla/service/gpu/runtime/conv.h deleted file mode 100644 index d8622c16136b4..0000000000000 --- a/xla/service/gpu/runtime/conv.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CONV_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONV_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_conv_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Conv custom calls. -void RegisterConvCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for convoluttion attributes defined by MHLO dialect. -void RegisterConvTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Add attributes encoding for convoluttion attributes defined by MHLO dialect. -void PopulateConvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Cache conv runners between invocations of convolution custom calls. -//===----------------------------------------------------------------------===// - -struct ConvRunner { - explicit ConvRunner(GpuConvConfig config) - : config(std::move(config)), runner(this->config) {} - GpuConvConfig config; - GenericConvRunner runner; -}; - -class StreamExecutorConvRunners : public runtime::StateVector {}; - -// Xla executable keeps a mapping from stream executors to convolution runners. -class ConvRunners { - public: - StreamExecutorConvRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map runners_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CONV_H_ diff --git a/xla/service/gpu/runtime/conv_reorder.cc b/xla/service/gpu/runtime/conv_reorder.cc deleted file mode 100644 index e69a7f74e0ece..0000000000000 --- a/xla/service/gpu/runtime/conv_reorder.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/conv_reorder.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { -namespace { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::FlatMemrefView; -using ::xla::runtime::StridedMemrefView; - -se::dnn::FilterDescriptor GetFilterDescriptor( - absl::Span filter_dims) { - se::dnn::FilterDescriptor filter_desc(2); - filter_desc.set_layout(se::dnn::FilterLayout::kOutputInputYX32); - filter_desc.set_output_feature_map_count(filter_dims[0]); - filter_desc.set_input_feature_map_count(filter_dims[1]); - filter_desc.set_input_filter_height(filter_dims[2]); - filter_desc.set_input_filter_width(filter_dims[3]); - return filter_desc; -} - -absl::Status ConvReorderFilterImpl( - const ServiceExecutableRunOptions* run_options, - StridedMemrefView input_view, StridedMemrefView output_view, - absl::Span filter_dims) { - auto input = se::DeviceMemory(GetDeviceAddress(input_view)); - auto output = se::DeviceMemory(GetDeviceAddress(output_view)); - - return run_options->stream()->CudnnReorderConvolutionFilterAndBias( - GetFilterDescriptor(filter_dims), input, &output, std::nullopt, - std::nullopt); -} - -absl::Status ConvReorderFilterAndBiasImpl( - const ServiceExecutableRunOptions* run_options, - StridedMemrefView filter_input_view, FlatMemrefView bias_input_view, - StridedMemrefView filter_output_view, FlatMemrefView bias_output_view, - absl::Span filter_dims) { - auto filter_input = - se::DeviceMemory(GetDeviceAddress(filter_input_view)); - auto filter_output = - se::DeviceMemory(GetDeviceAddress(filter_output_view)); - auto bias_input = se::DeviceMemory(GetDeviceAddress(bias_input_view)); - auto bias_output = - se::DeviceMemory(GetDeviceAddress(bias_output_view)); - - return run_options->stream()->CudnnReorderConvolutionFilterAndBias( - GetFilterDescriptor(filter_dims), filter_input, &filter_output, - std::make_optional(bias_input), std::make_optional(bias_output)); -} - -} // namespace - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvReorderFilter, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.conv.reorder.filter") - .UserData() - .Arg() // filter_input - .Arg() // filter_output - .Attr>("filter_dims")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - ConvReorderFilterAndBias, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.conv.reorder.filter_and_bias") - .UserData() - .Arg() // filter_input - .Arg() // bias_input - .Arg() // filter_output - .Arg() // bias_output - .Attr>("filter_dims")); - -void RegisterConvReorderCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.conv.reorder.filter", ConvReorderFilter); - registry.Register("xla.gpu.conv.reorder.filter_and_bias", - ConvReorderFilterAndBias); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/conv_reorder.h b/xla/service/gpu/runtime/conv_reorder.h deleted file mode 100644 index 654e9dd56bcb5..0000000000000 --- a/xla/service/gpu/runtime/conv_reorder.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ -#define XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime convolution reorder custom calls. -void RegisterConvReorderCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CONV_REORDER_H_ diff --git a/xla/service/gpu/runtime/convolution_thunk.cc b/xla/service/gpu/runtime/convolution_thunk.cc new file mode 100644 index 0000000000000..6e8158d866aaf --- /dev/null +++ b/xla/service/gpu/runtime/convolution_thunk.cc @@ -0,0 +1,169 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/convolution_thunk.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/gpu_conv_runner.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +ConvolutionThunk::ConvolutionThunk( + ThunkInfo thunk_info, GpuConvConfig config, + std::vector operand_slices, + std::vector result_slices, + BufferAllocation::Slice scratch_slice) + : Thunk(Kind::kConvolution, thunk_info), + operand_buffers_(std::move(operand_slices)), + result_buffers_(std::move(result_slices)), + scratch_buffer_(scratch_slice), + config_(std::move(config)) {} + +GenericConvRunner& ConvolutionThunk::GetOrCreateRunner( + const stream_executor::Stream* stream, bool* runner_created) { + absl::MutexLock lock(&mu_); + auto it = runner_cache_.find(stream); + *runner_created = (it == runner_cache_.end()); + if (*runner_created) { + it = runner_cache_ + .insert({stream, std::make_unique(config_)}) + .first; + } + return *it->second; +} + +absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { + const auto& buffer_allocations = *params.buffer_allocations; + + std::vector operand_se_buffers, result_se_buffers; + operand_se_buffers.reserve(operand_buffers_.size()); + for (BufferAllocation::Slice buffer : operand_buffers_) { + operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + result_se_buffers.reserve(result_buffers_.size()); + for (BufferAllocation::Slice buffer : result_buffers_) { + result_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + bool runner_created = false; + RunConvOptions opts; + opts.runner_cache = &GetOrCreateRunner(params.stream, &runner_created); + +#if TENSORFLOW_USE_ROCM + if (runner_created) { + TF_ASSIGN_OR_RETURN( + GpuConvParams conv_params, + GetGpuConvParams(config_, operand_se_buffers, result_se_buffers)); + + TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, + GetDNNConvKindFromCudnnConvKind(config_.kind)); + + TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type, + GetDNNDataTypeFromPrimitiveType(config_.input_type)); + + TF_ASSIGN_OR_RETURN(auto dnn, + se::dnn::internal::GetDnnFromStream(params.stream)); + se::OwningScratchAllocator<> scratch_allocator( + buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + + std::vector profile_results; + dnn->GetMIOpenConvolveAlgorithms( + kind, input_type, params.stream, config_.input_descriptor, + conv_params.input_buf, config_.filter_descriptor, + conv_params.filter_buf, config_.output_descriptor, + conv_params.output_buf, config_.conv_desc, &scratch_allocator, + &profile_results); + } +#endif // TENSORFLOW_USE_ROCM + + TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), + absl::MakeSpan(result_se_buffers), scratch, + params.stream, opts)); + + // Note: Convolution has a tuple buffer as an output, but we don't need to + // populate it as no one should be reading from the tuple directly. + if (!params.stream->ok()) { + return Internal("ConvolutionThunk::ExecuteOnStream failed."); + } + return absl::OkStatus(); +} + +ConvolutionReorderThunk::ConvolutionReorderThunk( + ThunkInfo thunk_info, absl::Span filter_nchw, + absl::InlinedVector operand_slices, + absl::InlinedVector result_slices) + : Thunk(Kind::kConvolutionReorder, thunk_info), + filter_descriptor_(CreateFilterDescriptor(filter_nchw)), + operand_buffers_(operand_slices), + result_buffers_(result_slices) {} + +absl::Status ConvolutionReorderThunk::ExecuteOnStream( + const ExecuteParams& params) { + bool has_bias = operand_buffers_.size() > 1; + CHECK_EQ(operand_buffers_.size(), result_buffers_.size()); + + const auto& buffer_allocations = *params.buffer_allocations; + + auto filter_input = se::DeviceMemory( + buffer_allocations.GetDeviceAddress(operand_buffers_[0])); + auto filter_output = se::DeviceMemory( + buffer_allocations.GetDeviceAddress(result_buffers_[0])); + auto bias_input = + has_bias ? std::make_optional(se::DeviceMemory( + buffer_allocations.GetDeviceAddress(operand_buffers_[1]))) + : std::nullopt; + auto bias_output = + has_bias ? std::make_optional(se::DeviceMemory( + buffer_allocations.GetDeviceAddress(result_buffers_[1]))) + : std::nullopt; + + auto dnn = params.stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN for stream."); + } + return dnn->CudnnReorderConvolutionFilterAndBias( + params.stream, filter_descriptor_, filter_input, &filter_output, + std::move(bias_input), std::move(bias_output)); +} + +se::dnn::FilterDescriptor ConvolutionReorderThunk::CreateFilterDescriptor( + absl::Span filter_nchw) { + CHECK_EQ(filter_nchw.size(), 4); + se::dnn::FilterDescriptor filter_desc(2); + filter_desc.set_layout(se::dnn::FilterLayout::kOutputInputYX32); + filter_desc.set_output_feature_map_count(filter_nchw[0]); + filter_desc.set_input_feature_map_count(filter_nchw[1]); + filter_desc.set_input_filter_height(filter_nchw[2]); + filter_desc.set_input_filter_width(filter_nchw[3]); + return filter_desc; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/convolution_thunk.h b/xla/service/gpu/runtime/convolution_thunk.h new file mode 100644 index 0000000000000..3f9db4ea26660 --- /dev/null +++ b/xla/service/gpu/runtime/convolution_thunk.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/gpu_conv_runner.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a DNN +// convolution. It is generated by IrEmitter. +// +// This is thread-compatible. +class ConvolutionThunk : public Thunk { + public: + // Constructs a thunk for launching a DNN convolution. + // + // operand_slices should be in the same order as cudnn_call->operands(). + ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig config, + std::vector operand_slices, + std::vector result_slices, + BufferAllocation::Slice scratch_slice); + + ConvolutionThunk(const ConvolutionThunk&) = delete; + ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + std::vector operand_buffers_; + std::vector result_buffers_; + BufferAllocation::Slice scratch_buffer_; + GenericConvRunner& GetOrCreateRunner(const stream_executor::Stream* stream, + bool* runner_created); + + // Convolution config + const GpuConvConfig config_; + absl::Mutex mu_; + absl::flat_hash_map> + runner_cache_ ABSL_GUARDED_BY(mu_); +}; + +// Launches the kernel that reorders input data for int8x32 convolutions. +class ConvolutionReorderThunk : public Thunk { + public: + ConvolutionReorderThunk( + ThunkInfo thunk_info, absl::Span filter_nchw, + absl::InlinedVector operand_slices, + absl::InlinedVector result_slices); + + ConvolutionReorderThunk(const ConvolutionReorderThunk&) = delete; + ConvolutionReorderThunk& operator=(const ConvolutionReorderThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + static se::dnn::FilterDescriptor CreateFilterDescriptor( + absl::Span filter_nchw); + + const se::dnn::FilterDescriptor filter_descriptor_; + absl::InlinedVector operand_buffers_; + absl::InlinedVector result_buffers_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_CONVOLUTION_THUNK_H_ diff --git a/xla/service/gpu/runtime/copy_thunk.cc b/xla/service/gpu/runtime/copy_thunk.cc new file mode 100644 index 0000000000000..9a8698b23f8b5 --- /dev/null +++ b/xla/service/gpu/runtime/copy_thunk.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/copy_thunk.h" + +#include + +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64_t mem_size) + : Thunk(Kind::kCopy, thunk_info), + source_buffer_(source_buffer), + destination_buffer_(destination_buffer), + mem_size_(mem_size) {} + +absl::Status DeviceToDeviceCopyThunk::ExecuteOnStream( + const ExecuteParams& params) { + se::DeviceMemoryBase destination_data = + params.buffer_allocations->GetDeviceAddress(destination_buffer_); + se::DeviceMemoryBase source_data = + params.buffer_allocations->GetDeviceAddress(source_buffer_); + VLOG(3) << "Memcpy D2D of size " << mem_size_ << " from " + << source_data.opaque() << " to " << destination_data.opaque(); + return params.stream->Memcpy(&destination_data, source_data, mem_size_); +} + +DeviceToHostCopyThunk::DeviceToHostCopyThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64_t mem_size) + : DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer, + mem_size) {} + +absl::Status DeviceToHostCopyThunk::ExecuteOnStream( + const ExecuteParams& params) { + se::DeviceMemoryBase destination_data = + params.buffer_allocations->GetDeviceAddress(destination()); + se::DeviceMemoryBase source_data = + params.buffer_allocations->GetDeviceAddress(source()); + void* cpu_dst = destination_data.opaque(); + VLOG(3) << "Memcpy D2H for memory offload from " << source_data.opaque() + << " to " << cpu_dst; + return params.stream->Memcpy(cpu_dst, source_data, size_bytes()); +} + +HostToDeviceCopyThunk::HostToDeviceCopyThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64_t mem_size) + : DeviceToDeviceCopyThunk(thunk_info, source_buffer, destination_buffer, + mem_size) {} + +absl::Status HostToDeviceCopyThunk::ExecuteOnStream( + const ExecuteParams& params) { + se::DeviceMemoryBase destination_data = + params.buffer_allocations->GetDeviceAddress(destination()); + se::DeviceMemoryBase source_data = + params.buffer_allocations->GetDeviceAddress(source()); + void* cpu_src = source_data.opaque(); + VLOG(3) << "Memcpy H2D for memory offload from " << cpu_src << " to " + << destination_data.opaque(); + return params.stream->Memcpy(&destination_data, cpu_src, size_bytes()); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/copy_thunk.h b/xla/service/gpu/runtime/copy_thunk.h new file mode 100644 index 0000000000000..521030ccee233 --- /dev/null +++ b/xla/service/gpu/runtime/copy_thunk.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" + +namespace xla { +namespace gpu { + +// A thunk that copies data from a device buffer to another device buffer. +class DeviceToDeviceCopyThunk : public Thunk { + public: + // Constructs a CopyThunk that copies host data from `source_buffer` to the + // device buffer `destination_buffer`. `mem_size` is the size of the data in + // bytes. + DeviceToDeviceCopyThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + uint64_t mem_size); + + DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; + DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } + + const BufferAllocation::Slice& source() const { return source_buffer_; } + const BufferAllocation::Slice& destination() const { + return destination_buffer_; + } + uint64_t size_bytes() const { return mem_size_; } + + private: + const BufferAllocation::Slice source_buffer_; + const BufferAllocation::Slice destination_buffer_; + const uint64_t mem_size_; +}; + +class DeviceToHostCopyThunk : public DeviceToDeviceCopyThunk { + public: + DeviceToHostCopyThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + uint64_t mem_size); + absl::Status ExecuteOnStream(const ExecuteParams& params) override; +}; + +class HostToDeviceCopyThunk : public DeviceToDeviceCopyThunk { + public: + HostToDeviceCopyThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + uint64_t mem_size); + absl::Status ExecuteOnStream(const ExecuteParams& params) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_COPY_THUNK_H_ diff --git a/xla/service/gpu/runtime/cub_sort.cc b/xla/service/gpu/runtime/cub_sort.cc deleted file mode 100644 index e6cbd13347aab..0000000000000 --- a/xla/service/gpu/runtime/cub_sort.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/cub_sort.h" - -#include - -#include "absl/status/status.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" // IWYU pragma: keep -#include "xla/runtime/memref_view.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/device_memory.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/cub_sort_thunk.h" -#endif - -namespace xla { -namespace gpu { -namespace { - -using ::stream_executor::DeviceMemoryBase; -using ::xla::runtime::CustomCall; -using ::xla::runtime::FlatMemrefView; - -absl::Status CubDeviceRadixSortKeysImpl( - const ServiceExecutableRunOptions* run_options, FlatMemrefView input_view, - FlatMemrefView output_view, FlatMemrefView scratch_view, bool descending) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return RunCubSort(input_view.dtype, std::nullopt, - GetDeviceAddress(input_view), DeviceMemoryBase(), - GetDeviceAddress(output_view), DeviceMemoryBase(), - GetDeviceAddress(scratch_view), descending); -#else - return absl::UnimplementedError("CUB is not available"); -#endif -} - -absl::Status CubDeviceRadixSortPairsImpl( - const ServiceExecutableRunOptions* run_options, - FlatMemrefView input_keys_view, FlatMemrefView input_values_view, - FlatMemrefView output_keys_view, FlatMemrefView output_values_view, - FlatMemrefView scratch_view, bool descending) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return RunCubSort( - input_keys_view.dtype, input_values_view.dtype, - GetDeviceAddress(input_keys_view), GetDeviceAddress(input_values_view), - GetDeviceAddress(output_keys_view), GetDeviceAddress(output_values_view), - GetDeviceAddress(scratch_view), descending); -#else - return absl::UnimplementedError("CUB is not available"); -#endif -} - -} // namespace - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CubDeviceRadixSortKeys, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.radix_sort_keys") - .UserData() - .Arg() // input - .Arg() // output - .Arg() // scratch - .Attr("descending")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CubDeviceRadixSortPairs, FunctionWrapper(), - checks, - CustomCall::Bind("xla.gpu.radix_sort_pairs") - .UserData() - .Arg() // input_keys - .Arg() // input_values - .Arg() // output_keys - .Arg() // output_values - .Arg() // scratch - .Attr("descending")); - -void RegisterCubSortCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.radix_sort_keys", CubDeviceRadixSortKeys); - registry.Register("xla.gpu.radix_sort_pairs", CubDeviceRadixSortPairs); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/cub_sort.h b/xla/service/gpu/runtime/cub_sort.h deleted file mode 100644 index 7014b39de2d8a..0000000000000 --- a/xla/service/gpu/runtime/cub_sort.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime CUB sort custom calls. -void RegisterCubSortCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CUB_SORT_H_ diff --git a/xla/service/gpu/runtime/cub_sort_thunk.cc b/xla/service/gpu/runtime/cub_sort_thunk.cc new file mode 100644 index 0000000000000..5b477dc5fb2a4 --- /dev/null +++ b/xla/service/gpu/runtime/cub_sort_thunk.cc @@ -0,0 +1,276 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/cub_sort_thunk.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/primitive_util.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/cub_sort_kernel.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +// Template class for sorting a single tensor. +class CubSortKeysImpl : public CubSortRunnerInterface { + public: + using SortKeysFn = std::function; + + explicit CubSortKeysImpl(SortKeysFn sort_keys_fn, PrimitiveType type) + : sort_keys_fn_(sort_keys_fn), type_(type) {} + + absl::Status Run(se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, bool descending) override; + absl::Status Run(const Thunk::ExecuteParams& params, + const CubSortThunk* thunk) override; + absl::StatusOr GetScratchSize(int64_t num_items) override; + + private: + SortKeysFn sort_keys_fn_; + PrimitiveType type_; +}; + +absl::Status CubSortKeysImpl::Run(se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, + bool descending) { + size_t temp_bytes = scratch.size(); + size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_); + CHECK(input_values.is_null()); + CHECK(output_values.is_null()); + const char* error = + sort_keys_fn_(scratch.opaque(), temp_bytes, input_keys.opaque(), + output_keys.opaque(), num_items, descending); + if (error != nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("CubSortKeys error: ", error)); + } + return absl::OkStatus(); +} + +absl::Status CubSortKeysImpl::Run(const Thunk::ExecuteParams& params, + const CubSortThunk* thunk) { + const BufferAllocations& allocs = *params.buffer_allocations; + return Run(allocs.GetDeviceAddress(thunk->operand(0)), se::DeviceMemoryBase(), + allocs.GetDeviceAddress(thunk->result(0)), se::DeviceMemoryBase(), + allocs.GetDeviceAddress(thunk->scratch()), thunk->descending()); +} + +absl::StatusOr CubSortKeysImpl::GetScratchSize(int64_t num_items) { + size_t temp_bytes = 0; + const char* error = + sort_keys_fn_(nullptr, temp_bytes, nullptr, nullptr, num_items, false); + if (error != nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("CubSortKeys error: ", error)); + } + return temp_bytes; +} + +// Template class for sorting a pair of tensors. +class CubSortPairsImpl : public CubSortRunnerInterface { + public: + using SortPairsFn = std::function; + + explicit CubSortPairsImpl(SortPairsFn sort_pairs_fn, PrimitiveType type) + : sort_pairs_fn_(sort_pairs_fn), type_(type) {} + + absl::Status Run(se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, bool descending) override; + absl::Status Run(const Thunk::ExecuteParams& params, + const CubSortThunk* thunk) override; + absl::StatusOr GetScratchSize(int64_t num_items) override; + + private: + SortPairsFn sort_pairs_fn_; + PrimitiveType type_; +}; + +absl::Status CubSortPairsImpl::Run(se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, + bool descending) { + size_t temp_bytes = scratch.size(); + size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_); + const char* error = sort_pairs_fn_( + scratch.opaque(), temp_bytes, input_keys.opaque(), output_keys.opaque(), + input_values.opaque(), output_values.opaque(), num_items, descending); + if (error != nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("CubSortPairs error: ", error)); + } + return absl::OkStatus(); +} + +absl::Status CubSortPairsImpl::Run(const Thunk::ExecuteParams& params, + const CubSortThunk* thunk) { + const BufferAllocations& allocs = *params.buffer_allocations; + return Run(allocs.GetDeviceAddress(thunk->operand(0)), + allocs.GetDeviceAddress(thunk->operand(1)), + allocs.GetDeviceAddress(thunk->result(0)), + allocs.GetDeviceAddress(thunk->result(1)), + allocs.GetDeviceAddress(thunk->scratch()), thunk->descending()); +} + +absl::StatusOr CubSortPairsImpl::GetScratchSize(int64_t num_items) { + size_t temp_bytes = 0; + const char* error = sort_pairs_fn_(nullptr, temp_bytes, nullptr, nullptr, + nullptr, nullptr, num_items, false); + if (error != nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("CubSortPairs error: ", error)); + } + return temp_bytes; +} + +absl::StatusOr> CreateCubSortRunner( + PrimitiveType type) { + switch (type) { + case F16: + return std::make_unique(CubSortKeys_f16, F16); + case F32: + return std::make_unique(CubSortKeys_f32, F32); + case F64: + return std::make_unique(CubSortKeys_f64, F64); + case S8: + return std::make_unique(CubSortKeys_s8, S8); + case S16: + return std::make_unique(CubSortKeys_s16, S16); + case S32: + return std::make_unique(CubSortKeys_s32, S32); + case S64: + return std::make_unique(CubSortKeys_s64, S64); + case U8: + return std::make_unique(CubSortKeys_u8, U8); + case U16: + return std::make_unique(CubSortKeys_u16, U16); + case U32: + return std::make_unique(CubSortKeys_u32, U32); + case U64: + return std::make_unique(CubSortKeys_u64, U64); + default: + return InvalidArgument("Unsupported type of the sort kernel: %s", + primitive_util::LowercasePrimitiveTypeName(type)); + } +} + +absl::StatusOr> CreateCubSortRunner( + PrimitiveType key_type, PrimitiveType value_type) { + // Values can be of any type of 16/32/64 bit width. + int valueWidth = primitive_util::BitWidth(value_type); + if (valueWidth != 16 && valueWidth != 32 && valueWidth != 64) { + return InvalidArgument( + "Unsupported value type of the sort kernel: %s", + primitive_util::LowercasePrimitiveTypeName(value_type)); + } + + // Only unsigned integer types could be used for keys. + switch (key_type) { + case U16: + if (valueWidth == 16) { + return std::make_unique(CubSortPairs_u16_b16, U16); + } + if (valueWidth == 32) { + return std::make_unique(CubSortPairs_u16_b32, U16); + } + return std::make_unique(CubSortPairs_u16_b64, U16); + case U32: + if (valueWidth == 16) { + return std::make_unique(CubSortPairs_u32_b16, U32); + } + if (valueWidth == 32) { + return std::make_unique(CubSortPairs_u32_b32, U32); + } + return std::make_unique(CubSortPairs_u32_b64, U32); + case U64: + if (valueWidth == 16) { + return std::make_unique(CubSortPairs_u64_b16, U64); + } + if (valueWidth == 32) { + return std::make_unique(CubSortPairs_u64_b32, U64); + } + return std::make_unique(CubSortPairs_u64_b64, U64); + default: + return InvalidArgument( + "Unsupported key type of the sort kernel: %s", + primitive_util::LowercasePrimitiveTypeName(key_type)); + } +} + +} // namespace + +absl::StatusOr> +CubSortRunnerInterface::Create(PrimitiveType type, + std::optional value_type) { + return value_type.has_value() ? CreateCubSortRunner(type, *value_type) + : CreateCubSortRunner(type); +} + +CubSortThunk::CubSortThunk( + ThunkInfo thunk_info, PrimitiveType type, + std::optional value_type, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending) + : Thunk(Thunk::kCubSort, thunk_info), + runner_(CubSortRunnerInterface::Create(type, value_type).value()), + operands_(std::move(operands)), + results_(std::move(results)), + scratch_(scratch), + descending_(descending) {} + +absl::Status RunCubSort(PrimitiveType type, + std::optional value_type, + se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, bool descending) { + auto runner = CubSortRunnerInterface::Create(type, value_type).value(); + return runner->Run(input_keys, input_values, output_keys, output_values, + scratch, descending); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/cub_sort_thunk.h b/xla/service/gpu/runtime/cub_sort_thunk.h new file mode 100644 index 0000000000000..12ee7a6dd1f3e --- /dev/null +++ b/xla/service/gpu/runtime/cub_sort_thunk.h @@ -0,0 +1,86 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +class CubSortRunnerInterface { + public: + virtual ~CubSortRunnerInterface() = default; + virtual absl::Status Run(se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, bool descending) = 0; + virtual absl::Status Run(const Thunk::ExecuteParams& params, + const class CubSortThunk* thunk) = 0; + virtual absl::StatusOr GetScratchSize(int64_t num_items) = 0; + + static absl::StatusOr> Create( + PrimitiveType type, std::optional value_type); +}; + +class CubSortThunk : public Thunk { + public: + CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, + std::optional value_type, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override { + return runner_->Run(params, this); + } + + BufferAllocation::Slice operand(int i) const { return operands_[i]; } + BufferAllocation::Slice result(int i) const { return results_[i]; } + BufferAllocation::Slice scratch() const { return scratch_; } + bool descending() const { return descending_; } + + private: + std::unique_ptr runner_; + absl::InlinedVector operands_; + absl::InlinedVector results_; + BufferAllocation::Slice scratch_; + bool descending_; +}; + +absl::Status RunCubSort(PrimitiveType type, + std::optional value_type, + se::DeviceMemoryBase input_keys, + se::DeviceMemoryBase input_values, + se::DeviceMemoryBase output_keys, + se::DeviceMemoryBase output_values, + se::DeviceMemoryBase scratch, bool descending); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_CUB_SORT_THUNK_H_ diff --git a/xla/service/gpu/runtime/cudnn_thunk.cc b/xla/service/gpu/runtime/cudnn_thunk.cc new file mode 100644 index 0000000000000..b156aa3228c24 --- /dev/null +++ b/xla/service/gpu/runtime/cudnn_thunk.cc @@ -0,0 +1,67 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/cudnn_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "xla/stream_executor/dnn.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { + +CuDnnThunk::CuDnnThunk(std::string fingerprint, ThunkInfo thunk_info, + absl::Span kernel_arguments) + : Thunk(Kind::kCuDnn, std::move(thunk_info)), + fingerprint_(std::move(fingerprint)), + graph_(std::make_shared(nullptr)) { + args_.reserve(kernel_arguments.size()); + for (const KernelArgument& kernel_argument : kernel_arguments) { + args_.push_back(kernel_argument.slice()); + }; +} + +absl::Status CuDnnThunk::Initialize(const InitializeParams& params) { + absl::Status ret = absl::OkStatus(); + absl::call_once(once_flag_, [&] { + auto result = params.stream->parent()->AsDnn()->DeserializeGraph( + params.src.dnn_compiled_graphs.at(fingerprint_)); + std::string().swap(fingerprint_); + if (result.ok()) { + graph_->swap(*result); + } + ret = result.status(); + }); + return ret; +} + +absl::Status CuDnnThunk::ExecuteOnStream(const ExecuteParams& params) { + InitializeParams initialize_params; + initialize_params.stream = params.stream; + TF_RETURN_IF_ERROR(Initialize(initialize_params)); + std::vector buffer_args; + buffer_args.reserve(args_.size()); + for (const BufferAllocation::Slice& arg : args_) { + buffer_args.push_back(params.buffer_allocations->GetDeviceAddress(arg)); + } + return graph_->get()->Execute(*params.stream, + absl::Span(buffer_args)); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/cudnn_thunk.h b/xla/service/gpu/runtime/cudnn_thunk.h new file mode 100644 index 0000000000000..b1b9988952217 --- /dev/null +++ b/xla/service/gpu/runtime/cudnn_thunk.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_CUDNN_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CUDNN_THUNK_H_ + +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/dnn.h" + +namespace xla { +namespace gpu { + +// Wraps executable cuDNN graph objects. +class CuDnnThunk : public Thunk { + public: + CuDnnThunk(std::string fingerprint, ThunkInfo, + absl::Span); + CuDnnThunk(const CuDnnThunk&) = delete; + CuDnnThunk& operator=(const CuDnnThunk&) = delete; + ~CuDnnThunk() override = default; + + absl::Status Initialize(const InitializeParams&) override; + absl::Status ExecuteOnStream(const ExecuteParams&) override; + + std::shared_ptr graph() const { return graph_; } + const std::vector& arguments() const { + return args_; + } + + private: + absl::once_flag once_flag_; + std::string fingerprint_; + std::shared_ptr graph_; + std::vector args_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_CUDNN_THUNK_H_ diff --git a/xla/service/gpu/runtime/custom_call.cc b/xla/service/gpu/runtime/custom_call.cc deleted file mode 100644 index 90d40a95d8899..0000000000000 --- a/xla/service/gpu/runtime/custom_call.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/custom_call.h" - -#include -#include -#include - -#include "xla/runtime/executable.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/custom_call_target_registry.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/triangular_solve.h" -#include "xla/service/service_executable_run_options.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Custom calls with API version API_VERSION_TYPED_FFI lowered directly to an -// Xla runtime custom calls. Older API versions handled by adapting Xla runtime -// calling convention to the calling convention expected by the registered -// handler. -// -// Once all Xla backends will use Xla runtime we will deprecate older API -// version, and migrate all users to API_VERSION_TYPED_FFI. -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -using xla::runtime::CustomCall; -using xla::runtime::FlatMemrefView; -using xla::runtime::StridedMemrefView; - -static absl::Status XlaCustomCallImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CustomCall::RemainingArgs args, - std::string_view call_target_name, int32_t api_version, - std::string_view backend_config) { - // Pattern match custom call to a few special cases, otherwise find the custom - // call handler regustered with the runtime. - if (call_target_name == kTriangularSolveCallTarget) - return TriangularSolve::run(run_options, debug_options, args, - backend_config); - - // Find the Xla custom call handler. - auto& platform_name = run_options->stream()->parent()->platform()->Name(); - void* call_target = CustomCallTargetRegistry::Global()->Lookup( - std::string(call_target_name), platform_name); - if (!call_target) { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot find the Xla custom call handler ", call_target_name)); - } - - // Prepare pointers to buffers to pass to the Xla custom call handler. - llvm::SmallVector buffers; - for (unsigned i = 0; i < args.size(); ++i) { - if (auto memref = args.get(i); succeeded(memref)) { - buffers.push_back(memref->data); - continue; - } - - if (auto strided = args.get(i); succeeded(strided)) { - buffers.push_back(strided->data); - continue; - } - - // TODO(ezhulenev): Add dialect and type to model Xla custom call holes, - // today we rely on the fact that custom calls do not support scalar - // arguments and we can disambiguate holes from real arguments. - if (auto hole = args.get(i); succeeded(hole)) { - buffers.push_back(nullptr); - continue; - } - - return absl::InvalidArgumentError( - "Failed to get arguments as (strided) memref view"); - } - - // Call custom call handler using the calling convention it requires. - using ApiVersion = CustomCallApiVersion; - - // Original custom call API version that doesn't support returning status. - if (api_version == ApiVersion::API_VERSION_ORIGINAL) { - using XlaCustomCallType = - void (*)(se::gpu::GpuStreamHandle, void**, const char*, size_t); - auto xla_call_target = reinterpret_cast(call_target); - - // As this is calling an external library, we should catch the - // error as there isn't another working correctly path to return - // an error to XLA. - try { - xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()), - buffers.data(), backend_config.data(), - backend_config.size()); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat(call_target_name, - " XLA extension have thrown an exception: ", e.what())); - } catch (...) { - return absl::UnknownError(absl::StrCat( - call_target_name, " XLA extension have thrown an exception.")); - } - - return absl::OkStatus(); - } - - // Xla Custom call API returning status. - if (api_version == ApiVersion::API_VERSION_STATUS_RETURNING || - api_version == ApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED) { - using XlaCustomCallType = - void (*)(se::gpu::GpuStreamHandle, void**, const char*, size_t, - XlaCustomCallStatus*); - auto xla_call_target = reinterpret_cast(call_target); - - XlaCustomCallStatus custom_call_status; - // As this is calling an external library, we should catch the - // error as there isn't another working correctly path to return - // an error to XLA. - try { - xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()), - buffers.data(), backend_config.data(), - backend_config.size(), &custom_call_status); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat(call_target_name, - " XLA extension have thrown an exception: ", e.what())); - } catch (...) { - return absl::UnknownError(absl::StrCat( - call_target_name, " XLA extension have thrown an exception.")); - } - - if (auto message = CustomCallStatusGetMessage(&custom_call_status)) { - return absl::InternalError(message.value()); - } else { - return absl::OkStatus(); - } - } - - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported custom call API version: %d", api_version)); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - XlaCustomCall, FunctionWrapper(), checks, - runtime::CustomCall::Bind("xla.gpu.memcpy") - .UserData() - .UserData() - .Arg() // args - .Attr("call_target_name") - .Attr("api_version") - .Attr("backend_config")); - -void RegisterXlaClassicCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.custom_call", XlaCustomCall); -} - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/custom_call.h b/xla/service/gpu/runtime/custom_call.h deleted file mode 100644 index df1c9f9fca3ef..0000000000000 --- a/xla/service/gpu/runtime/custom_call.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -void RegisterXlaClassicCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_H_ diff --git a/xla/service/gpu/runtime/custom_call_registry.cc b/xla/service/gpu/runtime/custom_call_registry.cc deleted file mode 100644 index 4cca2c15045b2..0000000000000 --- a/xla/service/gpu/runtime/custom_call_registry.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/custom_call_registry.h" - -#include -#include - -#include "xla/runtime/custom_call_registry.h" - -namespace xla::gpu { - -using DirectCustomCallRegistration = - std::function; - -static std::vector* -DirectCustomCallRegistrations() { - static auto* storage = new std::vector(); - return storage; -} - -void AddDirectCustomCallRegistration( - DirectCustomCallRegistration registration) { - DirectCustomCallRegistrations()->push_back(registration); -} - -// Registers all direct custom calls with the given registry. -void RegisterDirectCustomCalls(runtime::DirectCustomCallRegistry& registry) { - for (auto& registration : *DirectCustomCallRegistrations()) { - registration(registry); - } -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/custom_call_registry.h b/xla/service/gpu/runtime/custom_call_registry.h deleted file mode 100644 index 5038e2708d1ed..0000000000000 --- a/xla/service/gpu/runtime/custom_call_registry.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ -#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" - -namespace xla::gpu { - -// This is a static custom call registry for XLA:GPU executables. XLA runtime -// custom calls must not be confused with a "classic" custom calls, they are -// an internal implementation of XLA runtime (and XLA:GPU by extension), and -// do not provide stable ABI across dynamically loaded libraries. XLA runtime -// custom calls must be statically linked. -// -// XLA:FFI is the planned mechanism for registering "custom calls" via a stable -// C ABI for internal and external uses, however it's under construction. -// -// See more XLA runtime and XLA FFI plans here: -// https://docs.google.com/document/d/1XHzJyfq-ZFn9WHoKe4o_urnwS991dFHgWoNRboBK_3I/edit#bookmark=id.696pyshem503 -// -// XLA:FFI will become an official "external custom call" mechanism for XLA:GPU -// and XLA:CPU some time in 2024. - -// Adds a direct custom call registration function to a static registry. -void AddDirectCustomCallRegistration( - std::function registration); - -// Registers all direct custom calls with the given registry. -void RegisterDirectCustomCalls(runtime::DirectCustomCallRegistry& registry); - -//===----------------------------------------------------------------------===// -// Helper macro to define a static module registration. -//===----------------------------------------------------------------------===// - -#define XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL(FUNC) \ - XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL_IMPL(FUNC, __COUNTER__) - -#define XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL_IMPL(FUNC, N) \ - static bool xla_gpu_runtime_custom_call_##N##_registered_ = []() { \ - ::xla::gpu::AddDirectCustomCallRegistration(FUNC); \ - return true; \ - }() - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_REGISTRY_H_ diff --git a/xla/service/gpu/runtime/custom_call_thunk.cc b/xla/service/gpu/runtime/custom_call_thunk.cc new file mode 100644 index 0000000000000..5c4f17fc83154 --- /dev/null +++ b/xla/service/gpu/runtime/custom_call_thunk.cc @@ -0,0 +1,233 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/custom_call_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/executable_run_options.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_status_internal.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/stream_executor/gpu/gpu_stream.h" +#endif + +namespace xla { +namespace gpu { + +using xla::ffi::CallFrame; +using xla::ffi::CallFrameBuilder; +using xla::ffi::CallOptions; + +CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, + CustomCallTarget call_target, + std::vector> operands, + std::vector> results, + const std::string& opaque) + : Thunk(Thunk::kCustomCall, thunk_info), + operands_(std::move(operands)), + results_(std::move(results)), + call_target_(std::move(call_target)), + opaque_(opaque) {} + +CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler, + std::vector> operands, + std::vector> results, + AttributesMap attributes, + const HloComputation* called_computation) + : Thunk(Thunk::kCustomCall, thunk_info), + operands_(std::move(operands)), + results_(std::move(results)), + handler_(std::move(handler)), + attributes_(std::move(attributes)), + called_computation_(called_computation) {} + +absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { + // gpu_stream is CUstream or e.g. the equivalent type in ROCm. + std::vector buffers; + buffers.reserve(operands_.size() + results_.size()); + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (!slice.has_value()) { + buffers.push_back(nullptr); + continue; + } + + if (!slice->slice.allocation()) + return Internal("custom call input missing buffer allocation"); + + buffers.push_back( + params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); + } + } + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream); + XlaCustomCallStatus custom_call_status; + call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(), + &custom_call_status); + auto message = CustomCallStatusGetMessage(&custom_call_status); + if (message) { + return Internal("CustomCall failed: %s", *message); + } else { + return absl::OkStatus(); + } +#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + return Unavailable( + "Custom calls on GPU are not supported in this configuration. Please " + "build with --config=cuda or --config=rocm"); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) { + // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing + // a lot of extra allocation on every call. We have to keep attributes + // separate from arguments, as they do not change after thunk is constructed. + CallFrameBuilder builder; + + for (auto& operand : operands_) { + if (!operand.has_value()) + return Internal("FFI handlers do not support tokens (yet)!"); + if (!operand->slice.allocation()) + return Internal("custom call argument missing buffer allocation"); + + builder.AddBufferArg( + params.buffer_allocations->GetDeviceAddress(operand->slice), + operand->shape.element_type(), operand->shape.dimensions()); + } + + for (auto& result : results_) { + if (!result.has_value()) + return Internal("FFI handlers do not support tokens (yet)!"); + if (!result->slice.allocation()) + return Internal("custom call result missing buffer allocation"); + + builder.AddBufferRet( + params.buffer_allocations->GetDeviceAddress(result->slice), + result->shape.element_type(), result->shape.dimensions()); + } + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Append(attributes_); + + builder.AddAttributes(attrs.Build()); + CallFrame call_frame = builder.Build(); + + // TODO(ezhulenev): Remove `ServiceExecutableRunOptions` from FFI handler + // execution context, as apparently it's not easily accessible from Thunk. + ExecutableRunOptions run_options; + run_options.set_stream(params.stream); + run_options.set_allocator(params.buffer_allocations->memory_allocator()); + run_options.set_device_ordinal(params.buffer_allocations->device_ordinal()); + ServiceExecutableRunOptions service_run_options(run_options); + + CallOptions options = {&service_run_options, called_computation_}; + return Call(handler_, call_frame, options); +} + +absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { + return handler_ ? ExecuteFfiHandler(params) : ExecuteCustomCall(params); +} + +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict) { + CustomCallThunk::AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + + auto integer = [&](mlir::IntegerAttr integer) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + }; + + auto fp = [&](mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(fp.getValue().convertToFloat()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } + }; + + auto arr = [&](mlir::DenseArrayAttr arr) { + if (auto dense = mlir::dyn_cast(arr)) { + attributes[name] = dense.asArrayRef().vec(); + return absl::OkStatus(); + } else if (auto dense = mlir::dyn_cast(arr)) { + attributes[name] = dense.asArrayRef().vec(); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + absl::StrCat("Unsupported array element type for attribute: ", name)); + }; + + auto str = [&](mlir::StringAttr str) { + attributes[name] = str.getValue().str(); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(kv.getValue()) + .Case(integer) + .Case(fp) + .Case(arr) + .Case(str) + .Default([&](mlir::Attribute) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute type for attribute: ", name)); + })); + } + return attributes; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/custom_call_thunk.h b/xla/service/gpu/runtime/custom_call_thunk.h new file mode 100644 index 0000000000000..02679d2e0d21f --- /dev/null +++ b/xla/service/gpu/runtime/custom_call_thunk.h @@ -0,0 +1,131 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/call_frame.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/status.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/stream_executor/gpu/gpu_types.h" +#endif + +namespace xla { +namespace gpu { + +// Thunk to run a GPU custom call. +// +// This thunk's `ExecuteOnStream` implementation executes a host function +// `call_target` which is expected to enqueue operations onto the GPU. +// +// Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. +// XLA itself creates kCustomCall instructions when lowering kConvolution HLOs +// into calls to cudnn. These internally-created custom-calls are run using +// ConvolutionThunk, not CustomCallThunk. There's no ambiguity because they +// have special call target names (e.g. "__cudnn$convForward") that only the +// compiler is allowed to create. +class CustomCallThunk : public Thunk { + public: +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + using Stream = stream_executor::gpu::GpuStreamHandle; +#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + using Stream = void*; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + + using CustomCallTarget = std::function; + + // We keep buffer allocation slice together with its shape to be able to fill + // FFI arguments with required details. + struct Slice { + BufferAllocation::Slice slice; + Shape shape; + }; + + using Attribute = ffi::CallFrameBuilder::FlatAttribute; + using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; + + CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target, + std::vector> operands, + std::vector> results, + const std::string& opaque); + + CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler, + std::vector> operands, + std::vector> results, + AttributesMap attributes, + const HloComputation* called_computation); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + const CustomCallTarget& call_target() const { return call_target_; } + const std::vector>& operands() const { + return operands_; + } + const std::vector>& results() const { return results_; } + absl::string_view opaque() const { return opaque_; } + + private: + absl::Status ExecuteCustomCall(const ExecuteParams& params); + absl::Status ExecuteFfiHandler(const ExecuteParams& params); + + std::vector> operands_; + std::vector> results_; + + // This is a legacy custom call API that is discouraged, and will be + // deprecated once XLA:FFI mechanism is ready. + CustomCallTarget call_target_; + std::string opaque_; + + // XLA FFI provides a right type safe mechanism for registering external + // functions with XLA runtime. It's under construction, and still misses + // a lot of features. Long term it will replace legacy custom calls. + XLA_FFI_Handler* handler_ = nullptr; + AttributesMap attributes_; + + // TODO(ezhulenev): Currently we assume that HloModule that owns this + // computation is owned by a GpuExecutable and stays alive for as long as + // thunk is alive, however in general it might not be true and we can destroy + // underlying HloModule. We have to make a copy of HloComputation for a thunk, + // and also pass some form of relatively-ABI-stable representation to external + // custom calls, i.e. we can pass it as HloComputationProto or as MLIR + // bytecode of the computation serialized to StableHLO. Today we assume that + // custom calls that access called computation can only be linked statically. + const HloComputation* called_computation_ = nullptr; +}; + +// Converts MLIR dictionary attribute attached to a custom call operation to a +// custom call thunk attributes that are forwarded to the FFI handler. +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/xla/service/gpu/runtime/executable.cc b/xla/service/gpu/runtime/executable.cc index ea57b94aac97c..e69de29bb2d1d 100644 --- a/xla/service/gpu/runtime/executable.cc +++ b/xla/service/gpu/runtime/executable.cc @@ -1,521 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/executable.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/cholesky.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/conv_reorder.h" -#include "xla/service/gpu/runtime/cub_sort.h" -#include "xla/service/gpu/runtime/custom_call.h" -#include "xla/service/gpu/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime/fft.h" -#include "xla/service/gpu/runtime/fused_attention.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" -#include "xla/service/gpu/runtime/graph_launch.h" -#include "xla/service/gpu/runtime/io_feed.h" -#include "xla/service/gpu/runtime/memcpy.h" -#include "xla/service/gpu/runtime/memset.h" -#include "xla/service/gpu/runtime/norm.h" -#include "xla/service/gpu/runtime/send_recv.h" -#include "xla/service/gpu/runtime/stream_synchronization.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/topk.h" -#include "xla/service/gpu/runtime/resize_bicubic.h" -#include "xla/service/gpu/runtime/tracing.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/service/stream_pool.h" -#include "xla/statusor.h" -#include "xla/stream_executor/stream.h" -#include "tsl/protobuf/dnn.pb.h" - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCallAttrEncodingSet; -using ::xla::runtime::DirectCustomCallRegistry; -using ::xla::runtime::Executable; -using ::xla::runtime::JitExecutable; -using ::xla::runtime::Tagged; -using ::xla::runtime::TypeIDNameRegistry; - -using ::xla::runtime::CustomCall; -using ::xla::runtime::DiagnosticEngine; -using ::xla::runtime::ExportModules; - -void RegisterXlaGpuRuntimeCustomCalls(DirectCustomCallRegistry& registry) { - // Register custom calls from a static XLA:GPU registry. - RegisterDirectCustomCalls(registry); - - // Register builtin XLA:GPU custom calls (aka GPU runtime). - RegisterKernelLaunchCustomCalls(registry); - RegisterTracingCustomCalls(registry); - RegisterFftCustomCalls(registry); - RegisterCholeskyCustomCalls(registry); - RegisterCollectiveCustomCalls(registry); - RegisterGemmCustomCalls(registry); - RegisterConvCustomCalls(registry); - RegisterConvReorderCustomCalls(registry); - RegisterMemcpyCustomCalls(registry); - RegisterIoFeedCustomCalls(registry); - RegisterMemsetCustomCalls(registry); - RegisterSendRecvCustomCalls(registry); - RegisterResizeBicubicCustomCall(registry); - -#if GOOGLE_CUDA || TF_HIPBLASLT - RegisterMatmulCustomCalls(registry); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA - RegisterNormCustomCalls(registry); - RegisterFusedAttentionCustomCalls(registry); - RegisterFusedAttentionBackwardCustomCalls(registry); -#endif // GOOGLE_CUDA -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Graph launch kernels depend on Cuda Graph API. - RegisterGraphLaunchCustomCalls(registry); - RegisterConcurrentRegionCustomCalls(registry); - RegisterStreamSynchronizationCustomCalls(registry); - RegisterCubSortCustomCalls(registry); - RegisterXlaClassicCustomCalls(registry); - RegisterTopkCustomCall(registry); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} - -void RegisterXlaGpuTypeIdNames(TypeIDNameRegistry& registry) { - registry.Register>( - "__type_id_se_dnn_activation"); - registry.Register>( - "__type_id_dot_dimension_numbers"); - registry.Register>("__type_id_se_fft_type"); - - RegisterTracingTypeIdNames(registry); - RegisterConvTypeIdNames(registry); - RegisterSendRecvTypeIdNames(registry); - -#if GOOGLE_CUDA || TF_HIPBLASLT - registry.Register>( - "__type_id_se_gpublas_lt_epilogue"); - RegisterFusedAttentionTypeIdNames(registry); - RegisterNormTypeIdNames(registry); -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} - -void RegisterXlaGpuAttrEncoding(CustomCallAttrEncodingSet& encoding) { - PopulateConvAttrEncoding(encoding); - PopulateFftAttrEncoding(encoding); - PopulateDotDimsAttrEncoding(encoding); - PopulateSendRecvAttrEncoding(encoding); - -#if GOOGLE_CUDA || TF_HIPBLASLT - PopulateCublasLtMatmulAttrEncoding(encoding); - PopulateFusedAttentionAlgorithmConfigAttrEncoding(encoding); - PopulateFusedAttentionForwardDAGSignatureAttrEncoding(encoding); - PopulateFusedAttentionBackwardDAGSignatureAttrEncoding(encoding); - PopulateNormAlgorithmConfigAttrEncoding(encoding); -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} - -//===----------------------------------------------------------------------===// - -// Executable can have only one "main" function and only graph capture function. -static int64_t GetNumGraphs(const runtime::Executable& executable) { - return executable.num_functions() - 1; -} - -GpuRuntimeExecutable::GpuRuntimeExecutable( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr jit_executable, DebugOptions debug_options, - ModulesState modules_state) - : module_name_(std::move(module_name)), - buffer_sizes_(std::move(buffer_sizes)), - allocation_indices_(std::move(allocation_indices)), - executable_(std::move(jit_executable)), - debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - modules_state_(std::move(modules_state)) { - ExportModules(dynamic_custom_calls_); // export runtime modules -} - -GpuRuntimeExecutable::GpuRuntimeExecutable( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr aot_executable, DebugOptions debug_options, - ModulesState modules_state) - : module_name_(std::move(module_name)), - buffer_sizes_(std::move(buffer_sizes)), - allocation_indices_(std::move(allocation_indices)), - executable_(std::move(aot_executable)), - debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGL_CUDA || TENSORFLOW_USE_ROCM - modules_state_(std::move(modules_state)) { - ExportModules(dynamic_custom_calls_); // export runtime modules -} - -//===----------------------------------------------------------------------===// -// Compile Xla program lowered to runtime dialects to Gpu runtime executable. -//===----------------------------------------------------------------------===// - -/*static*/ StatusOr> -GpuRuntimeExecutable::Create(std::string module_name, - std::unique_ptr program) { - // Options for the default XLA Runtime compilation pipeline. - runtime::CompilationPipelineOptions copts; - - // Populate mapping from XLA (SE) enums/structs type id to symbol names. - copts.populate_type_id_names = RegisterXlaGpuTypeIdNames; - - // For passing LMHLO attributes as XLA (SE) enums/structs to custom calls. - copts.populate_attr_encodings = RegisterXlaGpuAttrEncoding; - - // Options for constructing XLA runtime JitExecutable. - JitExecutable::Options opts; - opts.specialization = JitExecutable::Specialization::kDisabled; - opts.compiler.verification_level = - program->debug_options.xla_gpu_llvm_verification_level(); - opts.compiler.register_dialects = - runtime::RegisterDefaultXlaGpuRuntimeDialects; - - // Register XLA Gpu runtime custom calls with the linker. - opts.compiler.symbols_binding = runtime::ToSymbolsBinding( - RegisterXlaGpuRuntimeCustomCalls, RegisterXlaGpuTypeIdNames); - - // We just use the default compilation pipeline provided by the XLA runtime. - // Alternatively instead of having a separate Xla Runtime program (LMHLO - // lowered to canonical dialects), we can assemble a pipeline that will - // compile starting from the LMHLO dialect. However this intermediate step - // helps with debugging, by materializing IR with XLA runtime custom calls. - opts.compiler.create_compilation_pipeline = - [copts](xla::runtime::PassManager& passes) { - runtime::CreateDefaultXlaGpuRuntimeCompilationPipeline(passes, copts); - }; - - // Do not run expensive optimization passes because we do not expect any - // non-trivial host code in XLA:GPU host executables. - opts.compiler.jit_code_opt_level = llvm::CodeGenOptLevel::None; - - // Instantiate new JitExecutable from the MLIR source. - auto jit_executable = - JitExecutable::Instantiate(program->module, program->entry_point, opts); - if (!jit_executable.ok()) - return InternalError("Failed to compile XLA Runtime program: %s", - jit_executable.status().message()); - - // Instantiate state for all registered runtime modules. - auto modules_state = ModulesState::Instantiate(); - if (!modules_state.ok()) - return InternalError("Failed to instantiate modules state: %s", - modules_state.status().message()); - - return std::unique_ptr(new GpuRuntimeExecutable( - std::move(module_name), std::move(program->buffer_sizes), - std::move(program->allocation_indices), - std::make_unique(std::move(*jit_executable)), - std::move(program->debug_options), std::move(*modules_state))); -} - -//===----------------------------------------------------------------------===// -// Constructs Gpu runtime executable from AOT compiled runtime artifact. -//===----------------------------------------------------------------------===// - -/*static*/ StatusOr> -GpuRuntimeExecutable::Create( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, Executable executable, - DebugOptions debug_options) { - // Instantiate state for all registered runtime modules. - auto modules_state = ModulesState::Instantiate(); - if (!modules_state.ok()) - return InternalError("Failed to instantiate modules state: %s", - modules_state.status().message()); - - return std::unique_ptr(new GpuRuntimeExecutable( - std::move(module_name), std::move(buffer_sizes), - std::move(allocation_indices), - std::make_unique(std::move(executable)), - std::move(debug_options), std::move(*modules_state))); -} - -//===----------------------------------------------------------------------===// -// Executes with the given buffer arguments. -//===----------------------------------------------------------------------===// - -static runtime::AsyncTaskRunner* NoAsyncTaskRunner() { - return reinterpret_cast(0XDEADBEEF); -} - -// TODO(ezhulenev): We rely on implementation details of passing memrefs to the -// compiled kernel. We should have a nicer API to do this, without creating a -// vector of temporary MemrefDesc for passing operands. -static void InitializeCallFrame(runtime::Executable::CallFrame& call_frame, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - llvm::SmallVectorImpl& ptrs) { - size_t num_allocations = buffer_allocations.size(); - assert(ptrs.empty() && "pointers storage must be empty"); - ptrs.resize_for_overwrite(num_allocations); - - // Each buffer allocation passed as 1d memref to the compiled function: - // {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]} - size_t num_args_ptrs = 1 + num_allocations * 5; - call_frame.args.resize_for_overwrite(num_args_ptrs); - - // Pass pointers to these constants as a memref offset and stride. - static int64_t zero = 0; - static int64_t one = 1; - void* offset = &zero; - void* stride = &one; - - // Add a placeholder for the kernel context as the first argument. - call_frame.args[0] = nullptr; - - // Initialize arguments for the buffer operands. - for (unsigned i = 0; i < num_allocations; ++i) { - void* data = &(ptrs[i] = buffer_allocations.GetDeviceAddress(i).opaque()); - void* size = const_cast(&buffer_sizes[i]); - unsigned idx = 1 + i * 5; - call_frame.args[idx + 0] = data; - call_frame.args[idx + 1] = data; - call_frame.args[idx + 2] = offset; - call_frame.args[idx + 3] = size; - call_frame.args[idx + 4] = stride; - } -} - -Status GpuRuntimeExecutable::Execute( - const ServiceExecutableRunOptions* run_options, const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - NonAtomicallyUpgradeableRWLock& gpu_lock, - const BufferAllocation* temp_alloc) { - // We pass a pointer to the executable through UserData, so that we can - // get access to other exported functions from custom call handlers. - runtime::Executable& executable = this->executable(); - - // Pack buffer allocations as executable arguments. It is guaranteed that - // the compiled function will make a copy of all arguments and will write all - // results after the call to `Execute` completes, so it is safe to keep them - // on the stack. - runtime::Executable::CallFrame call_frame; - - llvm::SmallVector ptrs; // storage for device address pointers - InitializeCallFrame(call_frame, buffer_allocations, buffer_sizes_, ptrs); - - // Check that initialized call frame is compatible with the executable - // entry point signature, otherwise compiled executable can read memory out of - // arguments bounds and crash with a segfault. - const runtime::FunctionType& signature = executable.signature(); - if (signature.num_operands() != buffer_allocations.size()) - return InternalError("Expected %d arguments but got %d buffer allocations", - signature.num_operands(), buffer_allocations.size()); - - for (unsigned i = 0; i < executable.signature().num_operands(); ++i) { - auto* memref = llvm::dyn_cast(signature.operand(i)); - if (!memref) return InvalidArgument("Expected memref as %d-th argument", i); - - if (memref->rank() != 1 || memref->sizes()[0] != buffer_sizes_[i]) - return InvalidArgument("Expected a buffer of size %d but got %d", - memref->sizes()[0], buffer_sizes_[i]); - } - - // XLA Runtime executables do not return any values. - runtime::NoResultConverter converter; - - // Get the async communications stream for async collectives. - se::StreamExecutor* executor = run_options->stream()->parent(); - se::StreamPriority stream_priority = se::StreamPriority::Default; - if (debug_options_.xla_gpu_enable_highest_priority_async_stream()) { - stream_priority = se::StreamPriority::Highest; - } - - // Create the needed streams to support NcclCollectiveThunk. - // - // Calling BorrowStream multiple times doesn't work as intended, see - // b/293945751. - absl::InlinedVector async_comm_streams( - kAsyncStreamTotal, nullptr); - StatusOr> streams = run_options->BorrowStreams( - executor->device_ordinal(), kAsyncStreamTotal, stream_priority); - if (streams.ok()) { - for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { - async_comm_streams[i] = streams->at(i).get(); - } - } - - // Async Collectives support and Send/Recv events instantiated for each Gpu - // executable run, so that concurrent executions can run independently using a - // separate set of events for communication. - AsyncCollectivesSupport async_collectives(async_comm_streams); - SendRecvEvents send_recv_events; - - // Always pass in the temp buffer, even if it is null, to accommodate the - // 0-sized buffer corner case. - se::DeviceMemoryBase temp_buffer; - if (temp_alloc) - temp_buffer = buffer_allocations.GetDeviceAddress(temp_alloc->index()); - - // State cached separately for each stream executor. - StreamExecutorKernels::Snapshot kernels = gpu_kernels_(executor)->snapshot(); - StreamExecutorConvRunners::Snapshot conv_runners = - conv_runners_(executor)->snapshot(); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - std::shared_ptr executor_graphs = - graph_instances_(executor); - - StreamExecutorGraphInstances::Snapshot graph_instances = - executor_graphs->snapshot(); - CapturedFunctionExecutionCount::Snapshot execution_count = - captured_function_counts_(executor)->snapshot(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Kernels in concurrent regions should be launched on borrowed stream, so - // that the cuda graph won't record dependencies between kernels. - // This state stores if the kernel being run is in a concurrent region and - // the borrowed streams for executing kernels in concurrent regions. - ConcurrentRegionStatus concurrent_region_status(run_options); - - // State cached globally for gpu executable. - GemmConfigs::Snapshot gemm_configs = gemm_configs_.snapshot(); - FftPlans::Snapshot fft_plans = fft_plans_.snapshot(); - -#if GOOGLE_CUDA || TF_HIPBLASLT - MatmulPlans::Snapshot matmul_plans = gpublas_lt_matmul_plans_.snapshot(); -#endif - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - StreamExecutorNormRunners::Snapshot norm_runners = - norm_runners_(executor)->snapshot(); - StreamExecutorFusedAttentionRunners::Snapshot fused_attention_runners = - fused_attention_runners_(executor)->snapshot(); - StreamExecutorFusedAttentionBackwardRunners::Snapshot - fused_attention_backward_runners = - fused_attention_backward_runners_(executor)->snapshot(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Pass auxiliary data to the custom call handlers. - runtime::CustomCall::UserData user_data( - run_options, &executable, &debug_options_, &temp_buffer, &asm_text, - &binary, &kernels, &gemm_configs, &conv_runners, &collectives_, - &fft_plans, &send_recv_events, &gpu_lock, -#if GOOGLE_CUDA || TF_HIPBLASLT - &matmul_plans, -#endif -#if GOOGLE_CUDA - // Auxiliary data that is available only if compiled with CUDA support - // only. - &norm_runners, &fused_attention_runners, - &fused_attention_backward_runners, -#endif // GOOGLE_CUDA -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - &graph_instances, &execution_count, -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - &concurrent_region_status, - // Null pointer will be interpreted as an absence of async collectives - // support and custom calls will safely return an error. - async_collectives.async_comm_stream(kAsyncStreamCollective) - ? &async_collectives - : nullptr); - - // Initialize state required for running functions from registered modules. - auto state_ref = modules_state_.InitializeUserData(user_data); - if (!state_ref.ok()) - return InternalError("Failed to initialize runtime modules state: %s", - state_ref.status().message()); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Instantiate all CUDA graphs before executing the main function. - if (debug_options_.xla_gpu_graph_num_runs_to_instantiate() < 0 && - !graph_instances_.InstantiatedAllGraphs(run_options, executable)) { - if (auto instantiated = graph_instances_.InstantiateAllGraphs( - run_options, executable, user_data, buffer_allocations, - buffer_sizes_, allocation_indices_, - debug_options_.xla_gpu_graph_eviction_timeout_seconds()); - !instantiated.ok()) { - return InternalError("Failed to instantiate GPU graphs: %s", - instantiated.message()); - } - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic, true); - - // Prepare options for executing XLA Runtime program. - runtime::Executable::ExecuteOpts opts; - opts.async_task_runner = NoAsyncTaskRunner(); - opts.custom_call_data = &user_data; - opts.diagnostic_engine = &diagnostic_engine; - opts.custom_call_registry = &dynamic_custom_calls_; - - // Execute with the prepared call frame. - executable.Execute(call_frame, opts); - - if (auto st = executable.ReturnResults(converter, &call_frame); !st.ok()) { - return InternalError("Failed to execute XLA Runtime executable: %s%s%s.", - st.message(), diagnostic.empty() ? "" : ": ", - diagnostic); - } - - return OkStatus(); -} - -//===----------------------------------------------------------------------===// - -const Executable& GpuRuntimeExecutable::executable() const { - if (auto* jit = std::get_if>(&executable_)) { - return *(*jit)->DefaultExecutable(); - } - return *std::get>(executable_); -} - -StatusOr GpuRuntimeExecutable::GetObjFile() const { - if (auto obj_file = executable().obj_file()) - return std::string_view(obj_file->getBuffer()); - - return InternalError("gpu runtime executable didn't save the obj file"); -} - -StatusOr GpuRuntimeExecutable::GetMlirModule() const { - const auto* jit = std::get_if>(&executable_); - if (!jit) return InternalError("MLIR module is not available"); - - return (*jit)->mlir_module(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/executable.h b/xla/service/gpu/runtime/executable.h deleted file mode 100644 index fbc95fe960ac4..0000000000000 --- a/xla/service/gpu/runtime/executable.h +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ -#define XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" -#include "xla/runtime/module_registry.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/collectives.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/fft.h" -#include "xla/service/gpu/runtime/fused_attention.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" -#include "xla/service/gpu/runtime/graph_launch.h" -#include "xla/service/gpu/runtime/kernel_launch.h" -#include "xla/service/gpu/runtime/norm.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -// Register custom calls implementing Xla Gpu runtime. -void RegisterXlaGpuRuntimeCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Register mapping from XLA (SE) enums/structs type ids to symbol names. -void RegisterXlaGpuTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Register encoding for (L)MHLO attributes required by the runtime functions. -void RegisterXlaGpuAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -// Xla Gpu program lowered to the Xla runtime dialects. Gpu runtime executable -// jit-compiles this program to an executable artifact (via lowering to LLVM). -// -// We have this program as an intermediate step between lowering from HLO to -// runtime executable to be able to introspect the compilation process. Once we -// have this program, the Xla gpu compiler job is done, and lowering to LLVM is -// the responsibility of backend-agnostic Xla runtime passes. This is the last -// stage when IR is still at a fairly high level of abstraction and has a lot of -// Gpu specific details in it. -struct GpuRuntimeProgram { - GpuRuntimeProgram(std::string entry_point, std::string module, - std::vector buffer_sizes, - std::vector> allocation_indices, - DebugOptions debug_options) - : entry_point(std::move(entry_point)), - module(std::move(module)), - buffer_sizes(std::move(buffer_sizes)), - allocation_indices(std::move(allocation_indices)), - debug_options(std::move(debug_options)) {} - - std::string entry_point; - std::string module; - std::vector buffer_sizes; - std::vector> allocation_indices; - DebugOptions debug_options; -}; - -// Gpu runtime executable encapsulates the Xla runtime executable compiled from -// an Xla program and owns all the state required for running it (e.g. it owns -// various caches required for performance). -// -// TODO(ezhulenev): Once thunks are removed from Xla, it might make sense to -// merge this executable into GpuExecutable. Today we keep it separate to manage -// the complexity of mixing two execution modes in the same file. GpuExecutable -// provides an API at XLA level of abstraction (streams and buffers), and this -// executable provides a lower level API exposing some of the implementation -// details. -class GpuRuntimeExecutable { - using ModulesState = ::xla::runtime::ModulesState; - - public: - // Creates GpuRuntimeExecutable from the Xla Gpu Program. - static StatusOr> Create( - std::string module_name, std::unique_ptr program); - - // Creates GpuRuntimeExecutable from the AOT compiled binary. - static StatusOr> Create( - std::string module_name, std::vector buffer_sizes, - std::vector> allocation_indices, - runtime::Executable executable, DebugOptions debug_options); - - // Executes entry function with the given buffer arguments. - Status Execute(const ServiceExecutableRunOptions* run_options, - const std::string& asm_text, - const std::vector& binary, - const BufferAllocations& buffer_allocations, - NonAtomicallyUpgradeableRWLock& gpu_lock, - const BufferAllocation* temp_alloc = nullptr); - - // Returns object file behind the runtime executable. This object file can - // be exported and loaded later to instantiate another executable. - StatusOr GetObjFile() const; - - // Returns MLIR module behind this executable if it is available. - StatusOr GetMlirModule() const; - - std::string_view module_name() const { return module_name_; } - - private: - GpuRuntimeExecutable(std::string module_name, - std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr jit_executable, - DebugOptions debug_options, ModulesState modules_state); - - GpuRuntimeExecutable(std::string module_name, - std::vector buffer_sizes, - std::vector> allocation_indices, - std::unique_ptr aot_executable, - DebugOptions debug_options, ModulesState modules_state); - - std::string module_name_; - - // Depending on the state of `executable_` returns a reference to active - // Xla runtime executable. - runtime::Executable& executable() { - return const_cast( - const_cast(this)->executable()); - } - const runtime::Executable& executable() const; - - std::vector buffer_sizes_; - - // `rt.allocation_index` attributes for all exported functions. Indexed by - // function ordinal. - std::vector> allocation_indices_; - - // In JIT compilation mode `JitExecutable` is used. In AOT compilation mode - // `Executable` is used. - std::variant, - std::unique_ptr> - executable_; - - const DebugOptions debug_options_; - - // Keep gpu kernels loaded by this executable. - GpuExecutableKernels gpu_kernels_; - - // Keep gemm configs for all gemm operation in the program. - GemmConfigs gemm_configs_; - - // Keep a cache for conv configs for all conv operations in the program. - ConvRunners conv_runners_; - - // Keep a cache for fused norm configs for all fused norm operations in the - // program. - NormRunnerStates norm_runners_; - - // Keep a cache for fused_dot_attention configs for all fused_dot_attention - // operations in the program. - FusedAttentionRunners fused_attention_runners_; - - // Keep a cache for fused_dot_attention configs for all fused_dot_attention - // backward - // operations in the program. - FusedAttentionBackwardRunners fused_attention_backward_runners_; - - // Support for running collective operations. - CollectivesSupport collectives_; - - // Keep a cache of fft plans for all FFT operations in the program. - FftPlans fft_plans_; - -#if GOOGLE_CUDA || TF_HIPBLASLT // Keep matmul execution plans. - MatmulPlans gpublas_lt_matmul_plans_; -#endif - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Keep captured and instantiated GPU graphs instances. - GraphInstances graph_instances_; - CapturedFunctionExecutionCounts captured_function_counts_; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Keep an executable state for all registered runtime modules. - ModulesState modules_state_; - - // Dynamic custom calls exported from XLA runtime modules (and FFI modules). - runtime::DynamicCustomCallRegistry dynamic_custom_calls_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_ diff --git a/xla/service/gpu/runtime/fft.cc b/xla/service/gpu/runtime/fft.cc deleted file mode 100644 index a668a8e1ff847..0000000000000 --- a/xla/service/gpu/runtime/fft.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/fft.h" - -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" -#include "xla/stream_executor/fft.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -//===----------------------------------------------------------------------===// -// Register FFT attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::fft::Type); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -namespace mhlo = ::mlir::mhlo; - -static se::fft::Type ConvertFftType(mhlo::FftType type) { - switch (type) { - case mhlo::FftType::FFT: - return se::fft::Type::kC2CForward; - case mhlo::FftType::IFFT: - return se::fft::Type::kC2CInverse; - case mhlo::FftType::RFFT: - return se::fft::Type::kR2C; - case mhlo::FftType::IRFFT: - return se::fft::Type::kC2R; - default: - return se::fft::Type::kInvalid; - } -} - -void PopulateFftAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding) { - encoding.Add>(ConvertFftType); -} - -//===----------------------------------------------------------------------===// -// FFT custom call implementation. -//===----------------------------------------------------------------------===// - -static absl::Status FftImpl(const ServiceExecutableRunOptions* run_options, - State> state, - StridedMemrefView input, StridedMemrefView output, - absl::Span fft_length, - se::fft::Type fft_type) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - - if (input.dtype == PrimitiveType::F64 || input.dtype == PrimitiveType::C128) { - // Adjust FFT type to reflect double precision. - switch (fft_type) { - case se::fft::Type::kC2CForward: - fft_type = se::fft::Type::kZ2ZForward; - break; - case se::fft::Type::kC2CInverse: - fft_type = se::fft::Type::kZ2ZInverse; - break; - case se::fft::Type::kR2C: - fft_type = se::fft::Type::kD2Z; - break; - case se::fft::Type::kC2R: - fft_type = se::fft::Type::kZ2D; - break; - default: - return absl::InvalidArgumentError("Unsupported FFT type"); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr * fft_plan_cache, - state.GetOrCreate([]() -> absl::StatusOr> { - return std::make_unique(); - })); - - return RunFft(GetDeviceAddress(input), ToShape(input), - GetDeviceAddress(output), ToShape(output), fft_type, fft_length, - executor->device_ordinal(), fft_plan_cache->get(), stream, - run_options->allocator()); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Fft, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.fft") - .UserData() - .State>("uid") - .Arg() // input - .Arg() // output - .Attr>("fft_length") - .Attr("fft_type")); - -//===----------------------------------------------------------------------===// - -void RegisterFftCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.fft", Fft); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/fft.h b/xla/service/gpu/runtime/fft.h deleted file mode 100644 index 7a34e2d1db4e6..0000000000000 --- a/xla/service/gpu/runtime/fft.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FFT_H_ -#define XLA_SERVICE_GPU_RUNTIME_FFT_H_ - -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/runtime3/fft_thunk.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime fft custom calls. -void RegisterFftCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Adds attributes encoding set for fft custom calls -void PopulateFftAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -// Keep FftPlanCache for all FFT instances in the executable. -class FftPlans : public runtime::StateVector> {}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_FFT_H_ diff --git a/xla/service/gpu/runtime/fft_thunk.cc b/xla/service/gpu/runtime/fft_thunk.cc new file mode 100644 index 0000000000000..728c36752aeed --- /dev/null +++ b/xla/service/gpu/runtime/fft_thunk.cc @@ -0,0 +1,258 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/fft_thunk.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace gpu { +namespace { + +se::fft::Type FftTypeToSeType(FftType type, bool double_precision) { + switch (type) { + case FftType::FFT: + return double_precision ? se::fft::Type::kZ2ZForward + : se::fft::Type::kC2CForward; + case FftType::IFFT: + return double_precision ? se::fft::Type::kZ2ZInverse + : se::fft::Type::kC2CInverse; + case FftType::IRFFT: + return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R; + case FftType::RFFT: + return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C; + default: + LOG(FATAL) << "unsupported fft type"; + } +} + +std::string FftTypeToString(se::fft::Type type) { + switch (type) { + case se::fft::Type::kC2CForward: + case se::fft::Type::kZ2ZForward: + return "FFT"; + case se::fft::Type::kC2CInverse: + case se::fft::Type::kZ2ZInverse: + return "IFFT"; + case se::fft::Type::kC2R: + case se::fft::Type::kZ2D: + return "IRFFT"; + case se::fft::Type::kR2C: + case se::fft::Type::kD2Z: + return "RFFT"; + default: + LOG(FATAL) << "unknown fft type"; + } +} + +absl::StatusOr GetBlas( + se::Stream* stream) { + auto blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("Unable to get Blas support"); + } + return blas; +} + +absl::StatusOr GetFft(se::Stream* stream) { + auto fft = stream->parent()->AsFft(); + if (fft == nullptr) { + return absl::InternalError("Unable to get fft support"); + } + return fft; +} +} // namespace + +FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type, + absl::Span fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape) + : Thunk(Kind::kFft, thunk_info), + fft_type_( + FftTypeToSeType(fft_type, input_shape.element_type() == F64 || + input_shape.element_type() == C128)), + fft_length_(fft_length.begin(), fft_length.end()), + input_buffer_(input_buffer), + output_buffer_(output_buffer), + input_shape_(input_shape), + output_shape_(output_shape) {} + +absl::Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { + auto& buffer_allocations = *params.buffer_allocations; + + return RunFft( + buffer_allocations.GetDeviceAddress(input_buffer_), input_shape_, + buffer_allocations.GetDeviceAddress(output_buffer_), output_shape_, + fft_type_, fft_length_, buffer_allocations.device_ordinal(), + &fft_plan_cache_, params.stream, buffer_allocations.memory_allocator()); +} + +absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, + se::DeviceMemoryBase output, const Shape& output_shape, + se::fft::Type fft_type, absl::Span fft_len, + int device_ordinal, FftPlanCache* fft_plan_cache, + se::Stream* stream, + se::DeviceMemoryAllocator* memory_allocator) { + VLOG(3) << "FFT type: " << FftTypeToString(fft_type); + VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); + + se::OwningScratchAllocator<2> scratch_allocator(device_ordinal, + memory_allocator); + + // Get the Fft plan for the given device ordinal. + FftPlan* fft_plan_ptr = fft_plan_cache->GetOrCreate(device_ordinal); + + // CuFFT thread-safety requires that separate host threads not share plans; + // protect each plan with a mutex. + absl::MutexLock lock(&fft_plan_ptr->mu); + std::unique_ptr& fft_plan = fft_plan_ptr->plan; + TF_ASSIGN_OR_RETURN(auto fft, GetFft(stream)); + if (fft_plan == nullptr) { + const int64_t fft_rank = fft_len.size(); + CHECK_LE(fft_rank, 3); + int batch_size = 1; + for (int i = 0; i < input_shape.dimensions_size() - fft_rank; ++i) { + batch_size *= input_shape.dimensions(i); + } + uint64_t fft_length[3]; + uint64_t input_embed[3]; + const uint64_t input_stride = 1; + uint64_t input_distance = 1; + uint64_t output_embed[3]; + const uint64_t output_stride = 1; + uint64_t output_distance = 1; + + for (int i = 0; i < fft_rank; ++i) { + auto dim_offset = input_shape.dimensions_size() - fft_rank + i; + fft_length[i] = static_cast(fft_len[i]); + input_embed[i] = input_shape.dimensions(dim_offset); + input_distance *= input_shape.dimensions(dim_offset); + output_embed[i] = output_shape.dimensions(dim_offset); + output_distance *= output_shape.dimensions(dim_offset); + } + + constexpr bool kInPlaceFft = false; + fft_plan = fft->CreateBatchedPlanWithScratchAllocator( + stream, fft_rank, fft_length, input_embed, input_stride, input_distance, + output_embed, output_stride, output_distance, fft_type, kInPlaceFft, + batch_size, &scratch_allocator); + TF_RET_CHECK(fft_plan != nullptr) + << "Failed to create cuFFT batched plan with scratch allocator"; + fft_plan_ptr->scale_factor = 1.0f / output_distance; + } else { + fft->UpdatePlanWithScratchAllocator(stream, fft_plan.get(), + &scratch_allocator); + } + + float scale_factor = fft_plan_ptr->scale_factor; + + bool launch_ok; + switch (fft_type) { + case se::fft::Type::kC2CForward: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + break; + } + case se::fft::Type::kZ2ZForward: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + break; + } + case se::fft::Type::kC2CInverse: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + if (launch_ok) { + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + complex64(scale_factor), &output_data, 1); + } + break; + } + case se::fft::Type::kZ2ZInverse: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + if (launch_ok) { + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + complex128(scale_factor), &output_data, 1); + } + break; + } + case se::fft::Type::kR2C: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + break; + } + case se::fft::Type::kD2Z: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + break; + } + case se::fft::Type::kC2R: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + if (launch_ok) { + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + scale_factor, &output_data, 1); + } + break; + } + case se::fft::Type::kZ2D: { + se::DeviceMemory input_data(input); + se::DeviceMemory output_data(output); + launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); + if (launch_ok) { + TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + launch_ok = + blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), + scale_factor, &output_data, 1); + } + break; + } + default: + LOG(FATAL) << "unsupported fft type"; + } + if (launch_ok) { + return absl::OkStatus(); + } + return Internal("Unable to launch fft with type %s", + FftTypeToString(fft_type)); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/fft_thunk.h b/xla/service/gpu/runtime/fft_thunk.h new file mode 100644 index 0000000000000..ffd45ed804fda --- /dev/null +++ b/xla/service/gpu/runtime/fft_thunk.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/fft.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +struct FftPlan { + // CuFFT thread-safety requires that separate host threads not share plans; + // protect each plan with a mutex. + absl::Mutex mu; + std::unique_ptr plan ABSL_GUARDED_BY(mu); + float scale_factor ABSL_GUARDED_BY(mu); +}; + +class FftPlanCache { + public: + // Returnes Fft plan cached for the given device ordinal or creates a new one. + FftPlan* GetOrCreate(int device_ordinal) { + absl::MutexLock lock(&mu_); + std::unique_ptr& plan = fft_plans_[device_ordinal]; + if (!plan) plan = std::make_unique(); + return plan.get(); + } + + private: + absl::Mutex mu_; + absl::flat_hash_map> fft_plans_ + ABSL_GUARDED_BY(mu_); +}; + +// This class stores everything that StreamExecutor needs to launch an FFT. +// It is generated by IrEmitter. +// +// This is thread-compatible. +class FftThunk : public Thunk { + public: + // Constructs a thunk for launching an FFT on a stream. + // Semantics of null hlo_instruction argument are as in Thunk. + FftThunk(ThunkInfo thunk_info, FftType fft_type, + absl::Span fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape); + + FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ + FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ + + // Does the FFT for the thunk on "stream". + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + const se::fft::Type fft_type_; + const std::vector fft_length_; + + FftPlanCache fft_plan_cache_; + + const BufferAllocation::Slice input_buffer_; + const BufferAllocation::Slice output_buffer_; + + const Shape input_shape_; + const Shape output_shape_; +}; + +absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, + se::DeviceMemoryBase output, const Shape& output_shape, + se::fft::Type fft_type, + absl::Span fft_length, int device_ordinal, + FftPlanCache* fft_plan_cache, se::Stream* stream, + se::DeviceMemoryAllocator* memory_allocator); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_FFT_THUNK_H_ diff --git a/xla/service/gpu/runtime/flash_attn_thunk.cc b/xla/service/gpu/runtime/flash_attn_thunk.cc new file mode 100644 index 0000000000000..c65a9ac8ef185 --- /dev/null +++ b/xla/service/gpu/runtime/flash_attn_thunk.cc @@ -0,0 +1,182 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/flash_attn_thunk.h" + +namespace xla { +namespace gpu { + +FlashAttnFwdThunk::FlashAttnFwdThunk( + ThunkInfo thunk_info, FlashAttnFwdConfig config, + BufferAllocation::Slice query_slice, BufferAllocation::Slice key_slice, + BufferAllocation::Slice value_slice, + BufferAllocation::Slice cu_seqlens_query_slice, /* may be null */ + BufferAllocation::Slice cu_seqlens_key_slice, /* may be null */ + BufferAllocation::Slice alibi_slopes_slice, /* may be null */ + BufferAllocation::Slice output_accum_slice, /* may be null */ + BufferAllocation::Slice softmax_lse_accum_slice, /* may be null */ + BufferAllocation::Slice output_slice, + BufferAllocation::Slice softmax_lse_slice, + BufferAllocation::Slice rng_state_slice, + BufferAllocation::Slice s_dmask_slice /* may be null */ + ) + : Thunk(Kind::kFlashAttn, thunk_info), + config_(std::move(config)), + query_buffer_(query_slice), + key_buffer_(key_slice), + value_buffer_(value_slice), + cu_seqlens_query_buffer_(cu_seqlens_query_slice), + cu_seqlens_key_buffer_(cu_seqlens_key_slice), + alibi_slopes_buffer_(alibi_slopes_slice), + output_accum_buffer_(output_accum_slice), + softmax_lse_accum_buffer_(softmax_lse_accum_slice), + output_buffer_(output_slice), + softmax_lse_buffer_(softmax_lse_slice), + rng_state_buffer_(rng_state_slice), + s_dmask_buffer_(s_dmask_slice) {} + +static std::optional AssignBufferIfNotNull( + const BufferAllocations& buffer_allocations, + BufferAllocation::Slice& slice) { + return slice.allocation() != nullptr + ? std::optional{buffer_allocations + .GetDeviceAddress(slice)} + : std::nullopt; +} + +absl::Status FlashAttnFwdThunk::ExecuteOnStream(const ExecuteParams& params) { + const auto& buffer_allocations = *params.buffer_allocations; + + se::DeviceMemoryBase query_buffer = + buffer_allocations.GetDeviceAddress(query_buffer_); + se::DeviceMemoryBase key_buffer = + buffer_allocations.GetDeviceAddress(key_buffer_); + se::DeviceMemoryBase value_buffer = + buffer_allocations.GetDeviceAddress(value_buffer_); + std::optional cu_seqlens_query_buffer = + AssignBufferIfNotNull(buffer_allocations, cu_seqlens_query_buffer_); + std::optional cu_seqlens_key_buffer = + AssignBufferIfNotNull(buffer_allocations, cu_seqlens_key_buffer_); + std::optional alibi_slopes_buffer = + AssignBufferIfNotNull(buffer_allocations, alibi_slopes_buffer_); + std::optional output_accum_buffer = + AssignBufferIfNotNull(buffer_allocations, output_accum_buffer_); + std::optional softmax_lse_accum_buffer = + AssignBufferIfNotNull(buffer_allocations, softmax_lse_accum_buffer_); + + se::DeviceMemoryBase output_buffer = + buffer_allocations.GetDeviceAddress(output_buffer_); + se::DeviceMemoryBase softmax_lse_buffer = + buffer_allocations.GetDeviceAddress(softmax_lse_buffer_); + se::DeviceMemoryBase rng_state_buffer = + buffer_allocations.GetDeviceAddress(rng_state_buffer_); + std::optional s_dmask_buffer = + AssignBufferIfNotNull(buffer_allocations, s_dmask_buffer_); + + TF_RETURN_IF_ERROR(RunFlashAttnFwd( + params.stream, config_, query_buffer, key_buffer, value_buffer, + cu_seqlens_query_buffer, cu_seqlens_key_buffer, alibi_slopes_buffer, + output_accum_buffer, softmax_lse_accum_buffer, output_buffer, + softmax_lse_buffer, rng_state_buffer, s_dmask_buffer, -1, -1)); + + if (!params.stream->ok()) { + return Internal("FlashAttnFwdThunk::ExecuteOnStream failed."); + } + return absl::OkStatus(); +} + +FlashAttnBwdThunk::FlashAttnBwdThunk( + ThunkInfo thunk_info, FlashAttnBwdConfig config, + BufferAllocation::Slice grad_output_slice, + BufferAllocation::Slice query_slice, BufferAllocation::Slice key_slice, + BufferAllocation::Slice value_slice, BufferAllocation::Slice output_slice, + BufferAllocation::Slice softmax_lse_slice, + BufferAllocation::Slice rng_state_slice, + BufferAllocation::Slice cu_seqlens_query_slice, /* may be null */ + BufferAllocation::Slice cu_seqlens_key_slice, /* may be null */ + BufferAllocation::Slice alibi_slopes_slice, /* may be null */ + BufferAllocation::Slice grad_query_accum_slice, + BufferAllocation::Slice grad_query_slice, + BufferAllocation::Slice grad_key_slice, + BufferAllocation::Slice grad_value_slice, + BufferAllocation::Slice grad_softmax_slice) + : Thunk(Kind::kFlashAttn, thunk_info), + config_(std::move(config)), + grad_output_buffer_(grad_output_slice), + query_buffer_(query_slice), + key_buffer_(key_slice), + value_buffer_(value_slice), + output_buffer_(output_slice), + softmax_lse_buffer_(softmax_lse_slice), + rng_state_buffer_(rng_state_slice), + cu_seqlens_query_buffer_(cu_seqlens_query_slice), + cu_seqlens_key_buffer_(cu_seqlens_key_slice), + alibi_slopes_buffer_(alibi_slopes_slice), + grad_query_accum_buffer_(grad_query_accum_slice), + grad_query_buffer_(grad_query_slice), + grad_key_buffer_(grad_key_slice), + grad_value_buffer_(grad_value_slice), + grad_softmax_buffer_(grad_softmax_slice) {} + +absl::Status FlashAttnBwdThunk::ExecuteOnStream(const ExecuteParams& params) { + const auto& buffer_allocations = *params.buffer_allocations; + + se::DeviceMemoryBase grad_output_buffer = + buffer_allocations.GetDeviceAddress(grad_output_buffer_); + se::DeviceMemoryBase query_buffer = + buffer_allocations.GetDeviceAddress(query_buffer_); + se::DeviceMemoryBase key_buffer = + buffer_allocations.GetDeviceAddress(key_buffer_); + se::DeviceMemoryBase value_buffer = + buffer_allocations.GetDeviceAddress(value_buffer_); + se::DeviceMemoryBase output_buffer = + buffer_allocations.GetDeviceAddress(output_buffer_); + se::DeviceMemoryBase softmax_lse_buffer = + buffer_allocations.GetDeviceAddress(softmax_lse_buffer_); + se::DeviceMemoryBase rng_state_buffer = + buffer_allocations.GetDeviceAddress(rng_state_buffer_); + std::optional cu_seqlens_query_buffer = + AssignBufferIfNotNull(buffer_allocations, cu_seqlens_query_buffer_); + std::optional cu_seqlens_key_buffer = + AssignBufferIfNotNull(buffer_allocations, cu_seqlens_key_buffer_); + std::optional alibi_slopes_buffer = + AssignBufferIfNotNull(buffer_allocations, alibi_slopes_buffer_); + se::DeviceMemoryBase grad_query_accum_buffer = + buffer_allocations.GetDeviceAddress(grad_query_accum_buffer_); + + se::DeviceMemoryBase grad_query_buffer = + buffer_allocations.GetDeviceAddress(grad_query_buffer_); + se::DeviceMemoryBase grad_key_buffer = + buffer_allocations.GetDeviceAddress(grad_key_buffer_); + se::DeviceMemoryBase grad_value_buffer = + buffer_allocations.GetDeviceAddress(grad_value_buffer_); + se::DeviceMemoryBase grad_softmax_buffer = + buffer_allocations.GetDeviceAddress(grad_softmax_buffer_); + + TF_RETURN_IF_ERROR(RunFlashAttnBwd( + params.stream, config_, grad_output_buffer, query_buffer, key_buffer, + value_buffer, output_buffer, softmax_lse_buffer, rng_state_buffer, + cu_seqlens_query_buffer, cu_seqlens_key_buffer, alibi_slopes_buffer, + grad_query_accum_buffer, grad_query_buffer, grad_key_buffer, + grad_value_buffer, grad_softmax_buffer, -1, -1)); + + if (!params.stream->ok()) { + return Internal("FlashAttnBwdThunk::ExecuteOnStream failed."); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/flash_attn_thunk.h b/xla/service/gpu/runtime/flash_attn_thunk.h new file mode 100644 index 0000000000000..a54f382983f9f --- /dev/null +++ b/xla/service/gpu/runtime/flash_attn_thunk.h @@ -0,0 +1,117 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_FLASH_ATTN_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_FLASH_ATTN_THUNK_H_ + +#include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/gpu_flash_attn.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a CUDA +// FlashAttention. It is generated by IrEmitter. +class FlashAttnFwdThunk : public Thunk { + public: + // Constructs a thunk for launching a CUDA FlashAttention. + FlashAttnFwdThunk( + ThunkInfo thunk_info, FlashAttnFwdConfig config, + BufferAllocation::Slice query_slice, BufferAllocation::Slice key_slice, + BufferAllocation::Slice value_slice, + BufferAllocation::Slice cu_seqlens_query_slice, /* may be null */ + BufferAllocation::Slice cu_seqlens_key_slice, /* may be null */ + BufferAllocation::Slice alibi_slopes_slice, /* may be null */ + BufferAllocation::Slice output_accum_slice, /* may be null */ + BufferAllocation::Slice softmax_lse_accum_slice, /* may be null */ + BufferAllocation::Slice output_slice, + BufferAllocation::Slice softmax_lse_slice, + BufferAllocation::Slice rng_state_slice, + BufferAllocation::Slice s_dmask_slice /* may be null */); + + FlashAttnFwdThunk(const FlashAttnFwdThunk &) = delete; + FlashAttnFwdThunk &operator=(const FlashAttnFwdThunk &) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams ¶ms) override; + + private: + const FlashAttnFwdConfig config_; + + BufferAllocation::Slice query_buffer_; // input + BufferAllocation::Slice key_buffer_; // input + BufferAllocation::Slice value_buffer_; // input + BufferAllocation::Slice cu_seqlens_query_buffer_; // input(varlen) + BufferAllocation::Slice cu_seqlens_key_buffer_; // input(varlen) + BufferAllocation::Slice alibi_slopes_buffer_; // input + BufferAllocation::Slice output_accum_buffer_; // input(temp) + BufferAllocation::Slice softmax_lse_accum_buffer_; // input(temp) + + BufferAllocation::Slice output_buffer_; // output + BufferAllocation::Slice softmax_lse_buffer_; // output + BufferAllocation::Slice rng_state_buffer_; // output + BufferAllocation::Slice s_dmask_buffer_; // output +}; + +class FlashAttnBwdThunk : public Thunk { + public: + FlashAttnBwdThunk( + ThunkInfo thunk_info, FlashAttnBwdConfig config, + BufferAllocation::Slice grad_output_slice, + BufferAllocation::Slice query_slice, BufferAllocation::Slice key_slice, + BufferAllocation::Slice value_slice, BufferAllocation::Slice output_slice, + BufferAllocation::Slice softmax_lse_slice, + BufferAllocation::Slice rng_state_slice, + BufferAllocation::Slice cu_seqlens_query_slice, /* may be null */ + BufferAllocation::Slice cu_seqlens_key_slice, /* may be null */ + BufferAllocation::Slice alibi_slopes_slice, /* may be null */ + BufferAllocation::Slice grad_query_accum_slice, + BufferAllocation::Slice grad_query_slice, + BufferAllocation::Slice grad_key_slice, + BufferAllocation::Slice grad_value_slice, + BufferAllocation::Slice grad_softmax_slice); + + FlashAttnBwdThunk(const FlashAttnBwdThunk &) = delete; + FlashAttnBwdThunk &operator=(const FlashAttnBwdThunk &) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams ¶ms) override; + + private: + const FlashAttnBwdConfig config_; + + BufferAllocation::Slice grad_output_buffer_; // input + BufferAllocation::Slice query_buffer_; // input + BufferAllocation::Slice key_buffer_; // input + BufferAllocation::Slice value_buffer_; // input + BufferAllocation::Slice output_buffer_; // input + BufferAllocation::Slice softmax_lse_buffer_; // input + BufferAllocation::Slice rng_state_buffer_; // input + BufferAllocation::Slice cu_seqlens_query_buffer_; // input(varlen) + BufferAllocation::Slice cu_seqlens_key_buffer_; // input(varlen) + BufferAllocation::Slice alibi_slopes_buffer_; // input + BufferAllocation::Slice grad_query_accum_buffer_; // input(temp) + + BufferAllocation::Slice grad_query_buffer_; // output + BufferAllocation::Slice grad_key_buffer_; // output + BufferAllocation::Slice grad_value_buffer_; // output + BufferAllocation::Slice grad_softmax_buffer_; // output +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_FLASH_ATTN_THUNK_H_ diff --git a/xla/service/gpu/runtime/fused_attention.cc b/xla/service/gpu/runtime/fused_attention.cc deleted file mode 100644 index 9dbd3e6aa1907..0000000000000 --- a/xla/service/gpu/runtime/fused_attention.cc +++ /dev/null @@ -1,1200 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.1 -==============================================================================*/ - -#include "xla/service/gpu/runtime/fused_attention.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace gpu { -//===----------------------------------------------------------------------===// -// Structs for encoding fused attention attributes defined in LMHLO dialect. -//===----------------------------------------------------------------------===// -struct AlgorithmConfig { - int64_t algorithm; - absl::Span knob_ids; - absl::Span knob_values; - int64_t workspace_size; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register fused attention attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// -namespace runtime { -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(xla::gpu::CudnnfMHAKind); - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::AlgorithmConfig, // - AggregateMember("algorithm"), - AggregateMember>("knob_ids"), - AggregateMember>("knob_values"), - AggregateMember("workspace_size")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -// Register type names for fused attention attributes defined by LMHLO dialect. -void RegisterFusedAttentionTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>("__type_id_algorithm_config"); - registry.Register>( - "__type_id_xla_gpu_cudnn_fmha_kind"); -} - -static auto EncodeFusedAttentionDAGSignature( - lmhlo_gpu::FusedMhaDagSignature signature) { - switch (signature) { - case mlir::lmhlo_gpu::FusedMhaDagSignature::Default: - return xla::gpu::CudnnfMHAKind::kBmmBmm; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaDagSignature::Softmax: - return xla::gpu::CudnnfMHAKind::kSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout; - } -} - -static auto EncodeFusedAttentionBackwardDAGSignature( - lmhlo_gpu::FusedMhaBackwardDagSignature signature) { - switch (signature) { - // backward - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax; - case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout: - return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; - } -} - -void PopulateFusedAttentionForwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMhaDagSignatureAttr`. - encoding.Add>( - EncodeFusedAttentionDAGSignature); - } -} - -void PopulateFusedAttentionBackwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMhaBackwardDagSignatureAttr`. - encoding.Add>( - EncodeFusedAttentionBackwardDAGSignature); - } -} - -void PopulateFusedAttentionAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::FusedMHAAlgorithmConfigAttr`. - using Attr = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr; - encoding.Add>( - encoding, xla::runtime::AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("knob_ids", &Attr::getKnobIds) - .Add("knob_values", &Attr::getKnobValues) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} - -//===----------------------------------------------------------------------===// -// Fused Dot Attention runners caching. -//===----------------------------------------------------------------------===// - -StreamExecutorFusedAttentionRunners* FusedAttentionRunners::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -StreamExecutorFusedAttentionBackwardRunners* -FusedAttentionBackwardRunners::operator()(se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -namespace { -struct DropoutAttrs { - double dropout_rate; - int64_t seed; -}; -} // namespace - -static GpufMHADescriptor GetGpufMHADescriptor( - CudnnfMHAKind kind, StridedMemrefView lhs_bmm1, StridedMemrefView rhs_bmm1, - StridedMemrefView rhs_bmm2, std::optional mask, - std::optional bias, StridedMemrefView output, - std::optional activation, double fmha_scale, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, AlgorithmConfig algo, - DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, bool is_flash_attention, - bool is_causal_mask, std::optional dropout = std::nullopt) { - GpufMHADescriptor descriptor; - descriptor.backend_config.set_fmha_scale(fmha_scale); - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algo.algorithm); - for (unsigned i = 0; i < algo.knob_ids.size(); ++i) { - algorithm->mutable_tuning_knobs()->insert( - {algo.knob_ids[i], algo.knob_values[i]}); - } - algorithm->set_is_cudnn_frontend(true); - if (algo.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(algo.workspace_size); - } - descriptor.bmm1_dnums = - ConvertDotDimensionNumbers(bmm1_dot_dimension_numbers.lhs_batch, - bmm1_dot_dimension_numbers.lhs_contract, - bmm1_dot_dimension_numbers.rhs_batch, - bmm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_dnums = - ConvertDotDimensionNumbers(bmm2_dot_dimension_numbers.lhs_batch, - bmm2_dot_dimension_numbers.lhs_contract, - bmm2_dot_dimension_numbers.rhs_batch, - bmm2_dot_dimension_numbers.rhs_contract); - // Apply backend config layout to the shape. - auto apply_shape = [](StridedMemrefView& memref) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout(shape.element_type(), - shape.dimensions(), - shape.layout().minor_to_major()); - }; - descriptor.lhs_bmm1_shape = apply_shape(lhs_bmm1); - descriptor.rhs_bmm1_shape = apply_shape(rhs_bmm1); - descriptor.rhs_bmm2_shape = apply_shape(rhs_bmm2); - descriptor.output_shapes.push_back(apply_shape(output)); - if (activation.has_value()) { - descriptor.output_shapes.push_back(apply_shape(*activation)); - } - if (bias.has_value()) { - descriptor.bias_shape = apply_shape(*bias); - } - if (mask.has_value()) { - descriptor.mask_shape = apply_shape(*mask); - } - - Shape out_shape = ToShape(output); - descriptor.intermediate_lhs_bmm2_shape = ShapeUtil::MakeShapeWithDenseLayout( - out_shape.element_type(), intermediate_tensor_dimensions, - intermediate_tensor_layout); - - if (dropout.has_value()) { - descriptor.backend_config.set_dropout_rate(dropout->dropout_rate); - descriptor.backend_config.set_seed(dropout->seed); - } - - descriptor.kind = kind; - descriptor.is_flash_attention = is_flash_attention; - descriptor.is_causal_mask = is_causal_mask; - return descriptor; -} - -static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( - CudnnfMHAKind kind, StridedMemrefView bmm1_grad_gemm1_rhs, - StridedMemrefView bmm1_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, - std::optional d_bias, StridedMemrefView d_bmm1_lhs, - StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - std::optional d_S, - std::optional softmax_sum, - std::optional d_Q_accum, - std::optional fwd_output, - std::optional bias, double fmha_scale, - AlgorithmConfig algo, - DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - bool is_flash_attention, bool is_causal_mask, - std::optional dropout_attrs = std::nullopt) { - GpufMHABackwardDescriptor descriptor; - descriptor.backend_config.set_fmha_scale(fmha_scale); - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algo.algorithm); - for (unsigned i = 0; i < algo.knob_ids.size(); ++i) { - algorithm->mutable_tuning_knobs()->insert( - {algo.knob_ids[i], algo.knob_values[i]}); - } - algorithm->set_is_cudnn_frontend(true); - if (algo.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value(algo.workspace_size); - } - - descriptor.bmm1_grad_gemm1_dnums = ConvertDotDimensionNumbers( - bmm1_grad_gemm1_dot_dimension_numbers.lhs_batch, - bmm1_grad_gemm1_dot_dimension_numbers.lhs_contract, - bmm1_grad_gemm1_dot_dimension_numbers.rhs_batch, - bmm1_grad_gemm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm1_grad_gemm2_dnums = ConvertDotDimensionNumbers( - bmm1_grad_gemm2_dot_dimension_numbers.lhs_batch, - bmm1_grad_gemm2_dot_dimension_numbers.lhs_contract, - bmm1_grad_gemm2_dot_dimension_numbers.rhs_batch, - bmm1_grad_gemm2_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_grad_gemm1_dnums = ConvertDotDimensionNumbers( - bmm2_grad_gemm1_dot_dimension_numbers.lhs_batch, - bmm2_grad_gemm1_dot_dimension_numbers.lhs_contract, - bmm2_grad_gemm1_dot_dimension_numbers.rhs_batch, - bmm2_grad_gemm1_dot_dimension_numbers.rhs_contract); - descriptor.bmm2_grad_gemm2_dnums = ConvertDotDimensionNumbers( - bmm2_grad_gemm2_dot_dimension_numbers.lhs_batch, - bmm2_grad_gemm2_dot_dimension_numbers.lhs_contract, - bmm2_grad_gemm2_dot_dimension_numbers.rhs_batch, - bmm2_grad_gemm2_dot_dimension_numbers.rhs_contract); - - // Apply backend config layout to the shape. - auto apply_shape = [](StridedMemrefView& memref) { - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout(shape.element_type(), - shape.dimensions(), - shape.layout().minor_to_major()); - }; - descriptor.bmm1_grad_gemm1_rhs_shape = apply_shape(bmm1_grad_gemm1_rhs); - descriptor.bmm1_grad_gemm2_rhs_shape = apply_shape(bmm1_grad_gemm2_rhs); - descriptor.bmm2_grad_gemm2_rhs_shape = apply_shape(bmm2_grad_gemm2_rhs); - if (is_flash_attention) { - // if it is flash attention then bmm2_grad_gemm1_lhs will be softmax_stats - // instead of P we need to use real P layout - descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - descriptor.bmm2_grad_gemm2_rhs_shape.element_type(), - intermediate_tensor_dimensions, intermediate_tensor_layout); - } else { - descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); - } - - descriptor.d_output_shape = apply_shape(d_output); - descriptor.d_bmm1_lhs_shape = apply_shape(d_bmm1_lhs); - descriptor.d_bmm1_rhs_shape = apply_shape(d_bmm1_rhs); - descriptor.d_bmm2_rhs_shape = apply_shape(d_bmm2_rhs); - - if (mask.has_value()) { - descriptor.mask_shape = apply_shape(*mask); - } - if (d_bias.has_value()) { - descriptor.d_bias_shape = apply_shape(*d_bias); - } - if (fwd_output.has_value()) { - descriptor.fwd_output_shape = apply_shape(*fwd_output); - } - if (bias.has_value()) { - descriptor.bias_shape = apply_shape(*bias); - } - if (dropout_attrs.has_value()) { - descriptor.backend_config.set_dropout_rate(dropout_attrs->dropout_rate); - descriptor.backend_config.set_seed(dropout_attrs->seed); - } - - descriptor.kind = kind; - descriptor.is_flash_attention = is_flash_attention; - descriptor.is_causal_mask = is_causal_mask; - return descriptor; -} - -static absl::Status FusedAttentionForwardImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State runner, - StridedMemrefView lhs_bmm1, StridedMemrefView rhs_bmm1, - StridedMemrefView rhs_bmm2, std::optional mask, - std::optional bias, StridedMemrefView output, - FlatMemrefView scratch, std::optional activation, - int64_t uid, double fmha_scale, bool is_flash_attention, - bool is_causal_mask, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, - xla::gpu::CudnnfMHAKind kind, AlgorithmConfig algorithm_config, - std::optional dropout_rate = std::nullopt, - std::optional seed = std::nullopt) { - std::optional dropout_attrs = std::nullopt; - if (dropout_rate.has_value() && seed.has_value()) { - dropout_attrs = {*dropout_rate, *seed}; - } - // Get or create the fused attention runner state. - absl::StatusOr fda = - runner.GetOrCreate([&]() -> absl::StatusOr { - GpufMHADescriptor descriptor = GetGpufMHADescriptor( - kind, lhs_bmm1, rhs_bmm1, rhs_bmm2, mask, bias, output, activation, - fmha_scale, intermediate_tensor_dimensions, - intermediate_tensor_layout, algorithm_config, - bmm1_dot_dimension_numbers, bmm2_dot_dimension_numbers, - is_flash_attention, is_causal_mask, dropout_attrs); - - StatusOr config = GpufMHAConfig::For(descriptor); - if (!config.ok()) return tsl::ToAbslStatus(config.status()); - - return FusedAttentionRunner(*std::move(config)); - }); - if (!fda.ok()) return fda.status(); - - se::DeviceMemoryBase lhs_bmm1_buffer = GetDeviceAddress(lhs_bmm1); - se::DeviceMemoryBase rhs_bmm1_buffer = GetDeviceAddress(rhs_bmm1); - se::DeviceMemoryBase rhs_bmm2_buffer = GetDeviceAddress(rhs_bmm2); - se::DeviceMemoryBase output_buffer = GetDeviceAddress(output); - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - se::DeviceMemoryBase mask_buffer; - if (mask.has_value()) { - mask_buffer = GetDeviceAddress(*mask); - } - se::DeviceMemoryBase bias_buffer; - if (bias.has_value()) { - bias_buffer = GetDeviceAddress(*bias); - } - se::DeviceMemoryBase activation_buffer; - if (activation.has_value()) { - activation_buffer = GetDeviceAddress(*activation); - } - - RunFusedMHAOptions opts; - opts.runner_cache = &(*fda)->runner; - - // Run the fused dot attention. - auto st = - RunGpuFMHA((*fda)->config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, mask_buffer, - bias_buffer, activation_buffer, run_options->stream(), opts); - if (!st.ok() || !run_options->stream()->ok()) { - return tsl::ToAbslStatus(st); - } - return absl::OkStatus(); -} - -static absl::Status FusedAttentionBackwardImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - State runner, - StridedMemrefView bmm1_grad_gemm1_rhs, - StridedMemrefView bmm1_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm2_rhs, - StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, - std::optional bias, - std::optional fwd_output, StridedMemrefView d_bmm1_lhs, - StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - std::optional d_S, - std::optional softmax_sum, - std::optional d_Q_accum, FlatMemrefView scratch, - std::optional d_bias, int64_t uid, double fmha_scale, - bool is_flash_attention, bool is_causal_mask, - absl::Span intermediate_tensor_dimensions, - absl::Span intermediate_tensor_layout, - DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, - xla::gpu::CudnnfMHAKind kind, AlgorithmConfig algorithm_config, - std::optional dropout_rate = std::nullopt, - std::optional seed = std::nullopt) { - std::optional dropout_attrs = std::nullopt; - if (dropout_rate.has_value() && seed.has_value()) { - dropout_attrs = {*dropout_rate, *seed}; - } - - // Get or create the fused attention runner state. - absl::StatusOr fda = - runner.GetOrCreate([&]() -> absl::StatusOr { - GpufMHABackwardDescriptor descriptor = GetGpufMHABackwardDescriptor( - kind, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, bmm2_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, d_output, mask, d_bias, d_bmm1_lhs, d_bmm1_rhs, - d_bmm2_rhs, d_S, softmax_sum, d_Q_accum, fwd_output, bias, - fmha_scale, algorithm_config, bmm1_grad_gemm1_dot_dimension_numbers, - bmm1_grad_gemm2_dot_dimension_numbers, - bmm2_grad_gemm1_dot_dimension_numbers, - bmm2_grad_gemm2_dot_dimension_numbers, - intermediate_tensor_dimensions, intermediate_tensor_layout, - is_flash_attention, is_causal_mask, dropout_attrs); - StatusOr config = - GpufMHABackwardConfig::For(descriptor); - if (!config.ok()) return tsl::ToAbslStatus(config.status()); - - return FusedAttentionBackwardRunner(*std::move(config)); - }); - if (!fda.ok()) return fda.status(); - - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - GetDeviceAddress(bmm1_grad_gemm1_rhs); - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - GetDeviceAddress(bmm1_grad_gemm2_rhs); - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - GetDeviceAddress(bmm2_grad_gemm2_rhs); - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - GetDeviceAddress(bmm2_grad_gemm1_lhs); - - se::DeviceMemoryBase d_output_buffer = GetDeviceAddress(d_output); - se::DeviceMemoryBase d_bmm1_lhs_buffer = GetDeviceAddress(d_bmm1_lhs); - se::DeviceMemoryBase d_bmm1_rhs_buffer = GetDeviceAddress(d_bmm1_rhs); - se::DeviceMemoryBase d_bmm2_rhs_buffer = GetDeviceAddress(d_bmm2_rhs); - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); - - se::DeviceMemoryBase d_S_buffer; - if (d_S.has_value()) { - d_S_buffer = GetDeviceAddress(*d_S); - } - - se::DeviceMemoryBase mask_buffer; - if (mask.has_value()) { - mask_buffer = GetDeviceAddress(*mask); - } - - se::DeviceMemoryBase d_bias_buffer; - if (d_bias.has_value()) { - d_bias_buffer = GetDeviceAddress(*d_bias); - } - - se::DeviceMemoryBase softmax_sum_buffer; - if (softmax_sum.has_value()) { - softmax_sum_buffer = GetDeviceAddress(*softmax_sum); - } - - se::DeviceMemoryBase d_Q_accum_buffer; - if (d_Q_accum.has_value()) { - d_Q_accum_buffer = GetDeviceAddress(*d_Q_accum); - } - - se::DeviceMemoryBase fwd_output_buffer; - if (fwd_output.has_value()) { - fwd_output_buffer = GetDeviceAddress(*fwd_output); - } - - se::DeviceMemoryBase bias_buffer; - if (bias.has_value()) { - bias_buffer = GetDeviceAddress(*bias); - } - - RunFusedMHABackwardOptions opts; - opts.runner_cache = &(*fda)->runner; - - // Run the fused attention backward. - auto st = RunGpuFMHABackward( - (*fda)->config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, - scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, - d_bias_buffer, fwd_output_buffer, bias_buffer, run_options->stream(), - opts); - if (!st.ok() || !run_options->stream()->ok()) { - return tsl::ToAbslStatus(st); - } - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Fused Attention custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -template -auto BindFusedAttentionAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("uid") - .template Attr("fmha_scale") - .template Attr("is_flash_attention") - .template Attr("is_causal_mask") - .template Attr>( - "intermediate_tensor_dimensions") - .template Attr>("intermediate_tensor_layout") - .template Attr("bmm1_dot_dimension_numbers") - .template Attr("bmm2_dot_dimension_numbers") - .template Attr("fused_mha_dag") - .template Attr("algorithm_config"); -} - -auto FusedAttentionCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // lhs_bmm1 - .Arg() // rhs_bmm1 - .Arg(); // rhs_bmm2 -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionBmmBmmInference, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.bmm.bmm.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionBmmBmmForward, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.bmm.bmm.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxForward, FunctionWrapper(), - checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.dropout.inference") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.softmax.dropout.forward") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.inference") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.scale.bias.softmax.forward") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.dropout.inference") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.softmax.dropout.forward") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.inference") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall("xla.gpu.fused.attention.scale.mask.softmax.forward") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.dropout.inference") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.mask.softmax.dropout.forward") - .Arg() // mask - .Value(std::optional()) // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.inference") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.forward") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutInference, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.dropout.inference") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Value(std::optional()) // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutForward, - FunctionWrapper(), checks, - BindFusedAttentionAttributes( - FusedAttentionCall( - "xla.gpu.fused.attention.scale.bias.mask.softmax.dropout.forward") - .Arg() // mask - .Arg() // bias - .Arg() // output - .Arg() // scratch - .Arg() // activation - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -template -auto BindFusedAttentionBackwardAttributes( - runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("uid") - .template Attr("fmha_scale") - .template Attr("is_flash_attention") - .template Attr("is_causal_mask") - .template Attr>( - "intermediate_tensor_dimensions") - .template Attr>("intermediate_tensor_layout") - .template Attr( - "bmm1_grad_gemm1_dot_dimension_numbers") - .template Attr( - "bmm1_grad_gemm2_dot_dimension_numbers") - .template Attr( - "bmm2_grad_gemm1_dot_dimension_numbers") - .template Attr( - "bmm2_grad_gemm2_dot_dimension_numbers") - .template Attr("fused_mha_dag") - .template Attr("algorithm_config"); -} - -auto FusedAttentionBackwardCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // bmm1_grad_gemm1_rhs - .Arg() // bmm1_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm1_lhs - .Arg(); // d_output -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.softmax.dropout") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.softmax.dropout") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.mask.softmax") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleBiasMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax.dropout") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Arg() // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FusedAttentionScaleMaskSoftmaxDropoutBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.fused.attention.backward.scale.mask.softmax.dropout") - .Arg() // mask - .Value(std::optional()) // bias - .Value(std::optional()) // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Value(std::optional()) // softmax_sum - .Value(std::optional()) // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Attr("dropout_rate") // dropout_rate - .Attr("seed") // seed -); - -// flash attention backward custom call -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleBiasSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.bias.softmax") - .Value(std::optional()) // mask - .Arg() // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - FlashAttentionScaleSoftmaxBackward, - FunctionWrapper(), checks, - BindFusedAttentionBackwardAttributes( - FusedAttentionBackwardCall( - "xla.gpu.flash.attention.backward.scale.softmax") - .Value(std::optional()) // mask - .Value(std::optional()) // bias - .Arg() // fwd_output - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Value(std::optional()) // d_S - .Arg() // softmax_sum - .Arg() // d_Q_accum - .Arg() // scratch - .Value(std::optional()) // d_bias - ) - .Value(std::optional()) // dropout_rate - .Value(std::optional()) // seed -); - -//===----------------------------------------------------------------------===// -// cuBLASLt custom calls bindings and registration. -//===----------------------------------------------------------------------===// -void RegisterFusedAttentionCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - auto fused_attention = [](std::string name) { - return "xla.gpu.fused.attention." + name; - }; - registry.Register(fused_attention("bmm.bmm.inference"), - FusedAttentionBmmBmmInference); - registry.Register(fused_attention("bmm.bmm.forward"), - FusedAttentionBmmBmmForward); - registry.Register(fused_attention("softmax.inference"), - FusedAttentionSoftmaxInference); - registry.Register(fused_attention("softmax.forward"), - FusedAttentionSoftmaxForward); - registry.Register(fused_attention("softmax.dropout.inference"), - FusedAttentionSoftmaxDropoutInference); - registry.Register(fused_attention("softmax.dropout.forward"), - FusedAttentionSoftmaxDropoutForward); - registry.Register(fused_attention("scale.bias.softmax.inference"), - FusedAttentionScaleBiasSoftmaxInference); - registry.Register(fused_attention("scale.bias.softmax.forward"), - FusedAttentionScaleBiasSoftmaxForward); - registry.Register(fused_attention("scale.bias.softmax.dropout.inference"), - FusedAttentionScaleBiasSoftmaxDropoutInference); - registry.Register(fused_attention("scale.bias.softmax.dropout.forward"), - FusedAttentionScaleBiasSoftmaxDropoutForward); - registry.Register(fused_attention("scale.mask.softmax.inference"), - FusedAttentionScaleMaskSoftmaxInference); - registry.Register(fused_attention("scale.mask.softmax.forward"), - FusedAttentionScaleMaskSoftmaxForward); - registry.Register(fused_attention("scale.mask.softmax.dropout.inference"), - FusedAttentionScaleMaskSoftmaxDropoutInference); - registry.Register(fused_attention("scale.mask.softmax.dropout.forward"), - FusedAttentionScaleMaskSoftmaxDropoutForward); - registry.Register(fused_attention("scale.bias.mask.softmax.inference"), - FusedAttentionScaleBiasMaskSoftmaxInference); - registry.Register(fused_attention("scale.bias.mask.softmax.forward"), - FusedAttentionScaleBiasMaskSoftmaxForward); - registry.Register( - fused_attention("scale.bias.mask.softmax.dropout.inference"), - FusedAttentionScaleBiasMaskSoftmaxDropoutInference); - registry.Register(fused_attention("scale.bias.mask.softmax.dropout.forward"), - FusedAttentionScaleBiasMaskSoftmaxDropoutForward); -} - -void RegisterFusedAttentionBackwardCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - auto fused_attention = [](std::string name) { - return "xla.gpu.fused.attention.backward." + name; - }; - registry.Register(fused_attention("scale.dbias.softmax"), - FusedAttentionScaleBiasSoftmaxBackward); - registry.Register(fused_attention("scale.softmax"), - FusedAttentionScaleSoftmaxBackward); - registry.Register(fused_attention("scale.dbias.softmax.dropout"), - FusedAttentionScaleBiasSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.softmax.dropout"), - FusedAttentionScaleSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.dbias.mask.softmax"), - FusedAttentionScaleBiasMaskSoftmaxBackward); - registry.Register(fused_attention("scale.mask.softmax"), - FusedAttentionScaleMaskSoftmaxBackward); - registry.Register(fused_attention("scale.dbias.mask.softmax.dropout"), - FusedAttentionScaleBiasMaskSoftmaxDropoutBackward); - registry.Register(fused_attention("scale.mask.softmax.dropout"), - FusedAttentionScaleMaskSoftmaxDropoutBackward); - // flash attention bwd - auto flash_attention = [](std::string name) { - return "xla.gpu.flash.attention.backward." + name; - }; - registry.Register(flash_attention("scale.bias.softmax"), - FlashAttentionScaleBiasSoftmaxBackward); - registry.Register(flash_attention("scale.softmax"), - FlashAttentionScaleSoftmaxBackward); -} -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/fused_attention.h b/xla/service/gpu/runtime/fused_attention.h deleted file mode 100644 index d6458f31c4a27..0000000000000 --- a/xla/service/gpu/runtime/fused_attention.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ -#define XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime fused attention custom calls. -void RegisterFusedAttentionCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Register type names for fused attention attributes defined by MHLO dialect. -void RegisterFusedAttentionTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Add attributes encoding for fused attention attributes defined by LMHLO -// dialect. -void PopulateFusedAttentionForwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -// Registers XLA Gpu runtime fused attention backward custom calls. -void RegisterFusedAttentionBackwardCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Add attributes encoding for fused attention backward attributes defined by -// LMHLO dialect. -void PopulateFusedAttentionBackwardDAGSignatureAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -void PopulateFusedAttentionAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Cache fused dot attention runners between invocations of fused dot attention -// custom calls. -//===----------------------------------------------------------------------===// -struct FusedAttentionRunner { - explicit FusedAttentionRunner(GpufMHAConfig config) - : config(std::move(config)), runner(this->config) {} - GpufMHAConfig config; - FusedMultiHeadedAttentionRunner runner; -}; - -struct FusedAttentionBackwardRunner { - explicit FusedAttentionBackwardRunner(GpufMHABackwardConfig config) - : config(std::move(config)), runner(this->config) {} - GpufMHABackwardConfig config; - FusedMultiHeadedAttentionBackwardRunner runner; -}; - -class StreamExecutorFusedAttentionRunners - : public runtime::StateVector {}; - -class StreamExecutorFusedAttentionBackwardRunners - : public runtime::StateVector {}; - -// Xla executable keeps a mapping from stream executors to fused attention -// runners. -class FusedAttentionRunners { - public: - StreamExecutorFusedAttentionRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - runners_ ABSL_GUARDED_BY(mutex_); -}; - -// Xla executable keeps a mapping from stream executors to fused attention -// backward runners. -class FusedAttentionBackwardRunners { - public: - StreamExecutorFusedAttentionBackwardRunners* operator()( - se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - runners_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_ATTENTION_H_ diff --git a/xla/service/gpu/fused_mha_thunk.cc b/xla/service/gpu/runtime/fused_mha_thunk.cc similarity index 82% rename from xla/service/gpu/fused_mha_thunk.cc rename to xla/service/gpu/runtime/fused_mha_thunk.cc index f0ba6f3fbd177..b39ba7a373000 100644 --- a/xla/service/gpu/fused_mha_thunk.cc +++ b/xla/service/gpu/runtime/fused_mha_thunk.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,20 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fused_mha_thunk.h" +#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include -#include #include -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" -#include "tsl/platform/logging.h" namespace xla { namespace gpu { @@ -36,7 +29,8 @@ FusedMHAThunk::FusedMHAThunk( BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation) + BufferAllocation::Slice bias, BufferAllocation::Slice activation, + BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) : Thunk(Kind::kFusedMHA, thunk_info), lhs_bmm1_buffer_(lhs_bmm1), rhs_bmm1_buffer_(rhs_bmm1), @@ -46,6 +40,8 @@ FusedMHAThunk::FusedMHAThunk( mask_buffer_(mask), bias_buffer_(bias), activation_buffer_(activation), + seqlen_q_buffer_(seqlen_q), + seqlen_k_buffer_(seqlen_k), config_(std::move(config)) {} FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner( @@ -70,7 +66,7 @@ std::optional AssignBufferIfNotNull( : std::nullopt; } -Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { +absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; se::DeviceMemoryBase lhs_bmm1_buffer = buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); @@ -89,18 +85,21 @@ Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { AssignBufferIfNotNull(buffer_allocations, bias_buffer_); std::optional activation_buffer = AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - + std::optional seqlen_q_buffer = + AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); + std::optional seqlen_k_buffer = + AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); RunFusedMHAOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); - TF_RETURN_IF_ERROR(RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - mask_buffer, bias_buffer, activation_buffer, - params.stream, opts)); + TF_RETURN_IF_ERROR(RunGpuFMHA( + config_, lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, output_buffer, + scratch_buffer, mask_buffer, bias_buffer, activation_buffer, + seqlen_q_buffer, seqlen_k_buffer, params.stream, opts)); if (!params.stream->ok()) { - return InternalError("FusedMHAThunk::ExecuteOnStream failed."); + return Internal("FusedMHAThunk::ExecuteOnStream failed."); } - return OkStatus(); + return absl::OkStatus(); } FusedMHABackwardThunk::FusedMHABackwardThunk( ThunkInfo thunk_info, GpufMHABackwardConfig config, @@ -113,7 +112,8 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, BufferAllocation::Slice softmax_sum, BufferAllocation::Slice d_Q_accum, BufferAllocation::Slice mask, BufferAllocation::Slice d_bias, - BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias) + BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias, + BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) : Thunk(Kind::kFusedMHA, thunk_info), bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), @@ -131,6 +131,8 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( d_bias_buffer_(d_bias), fwd_output_buffer_(fwd_output), bias_buffer_(bias), + seqlen_q_buffer_(seqlen_q), + seqlen_k_buffer_(seqlen_k), config_(std::move(config)) {} FusedMultiHeadedAttentionBackwardRunner& @@ -148,7 +150,8 @@ FusedMHABackwardThunk::GetOrCreateRunner( return *it->second; } -Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { +absl::Status FusedMHABackwardThunk::ExecuteOnStream( + const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); @@ -191,7 +194,10 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); std::optional bias_buffer = AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - + std::optional seqlen_q_buffer = + AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); + std::optional seqlen_k_buffer = + AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); RunFusedMHABackwardOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); @@ -201,11 +207,12 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, - d_bias_buffer, fwd_output_buffer, bias_buffer, params.stream, opts)); + d_bias_buffer, fwd_output_buffer, bias_buffer, seqlen_q_buffer, + seqlen_k_buffer, params.stream, opts)); if (!params.stream->ok()) { - return InternalError("FusedMHABackwardThunk::ExecuteOnStream failed."); + return Internal("FusedMHABackwardThunk::ExecuteOnStream failed."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/fused_mha_thunk.h b/xla/service/gpu/runtime/fused_mha_thunk.h similarity index 83% rename from xla/service/gpu/fused_mha_thunk.h rename to xla/service/gpu/runtime/fused_mha_thunk.h index a0d9e58aa0e64..6e93541f09c12 100644 --- a/xla/service/gpu/fused_mha_thunk.h +++ b/xla/service/gpu/runtime/fused_mha_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,24 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_ -#define XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ #include -#include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" namespace xla { namespace gpu { @@ -50,12 +46,14 @@ class FusedMHAThunk : public Thunk { BufferAllocation::Slice scratch_slice, BufferAllocation::Slice mask_slice, /* may be null */ BufferAllocation::Slice bias_slice /* may be null */, - BufferAllocation::Slice activation_slice /* may be null */); + BufferAllocation::Slice activation_slice /* may be null */, + BufferAllocation::Slice seqlen_q_slice /* may be null */, + BufferAllocation::Slice seqlen_k_slice /* may be null */); FusedMHAThunk(const FusedMHAThunk&) = delete; FusedMHAThunk& operator=(const FusedMHAThunk&) = delete; - Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: BufferAllocation::Slice lhs_bmm1_buffer_; @@ -66,6 +64,8 @@ class FusedMHAThunk : public Thunk { BufferAllocation::Slice mask_buffer_; BufferAllocation::Slice bias_buffer_; BufferAllocation::Slice activation_buffer_; + BufferAllocation::Slice seqlen_q_buffer_; + BufferAllocation::Slice seqlen_k_buffer_; FusedMultiHeadedAttentionRunner& GetOrCreateRunner( const stream_executor::Stream* stream); @@ -97,12 +97,14 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice mask_slice, BufferAllocation::Slice d_bias_slice, BufferAllocation::Slice fwd_output_slice, - BufferAllocation::Slice bias_slice); + BufferAllocation::Slice bias_slice, + BufferAllocation::Slice seqlen_q_slice, + BufferAllocation::Slice seqlen_k_slice); FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete; FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete; - Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; @@ -121,6 +123,8 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice d_bias_buffer_; BufferAllocation::Slice fwd_output_buffer_; BufferAllocation::Slice bias_buffer_; + BufferAllocation::Slice seqlen_q_buffer_; + BufferAllocation::Slice seqlen_k_buffer_; FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( const stream_executor::Stream* stream); @@ -134,4 +138,4 @@ class FusedMHABackwardThunk : public Thunk { }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSED_MHA_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ diff --git a/xla/service/gpu/runtime/gemm.cc b/xla/service/gpu/runtime/gemm.cc deleted file mode 100644 index d56bbad214b60..0000000000000 --- a/xla/service/gpu/runtime/gemm.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/gemm.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/xla.pb.h" -#include "tsl/platform/errors.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/gemm_algorithm_picker.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// TODO(ezhulenev): Delete run time auto tuning from XLA. -Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, - se::DeviceMemoryBase lhs, se::DeviceMemoryBase rhs, - se::DeviceMemoryBase out, const Shape& output_shape, - double beta, const DebugOptions* debug_options, - NonAtomicallyUpgradeableRWLock* gpu_lock) { - VLOG(3) << "Running GEMM runtime autotuning"; - std::vector algorithms; - stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms); - const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); - - AutotuneConfig autotune_config{ - DeviceConfig{stream->parent(), stream->parent()->GetAllocator()}, - *debug_options}; - - // TODO(jlebar): We should not use stream->parent()->GetAllocator() here; - // that's the global CUDA allocator. There may not be any free space in - // there, because TF usually gobbles it all up for its own BFCAllocator. We - // should use the allocator the user passed when running the XLA program. - se::RedzoneAllocator buffer_allocator( - stream, stream->parent()->GetAllocator(), - PtxOptsFromDebugOptions(*debug_options), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/autotune_config.should_check_correctness() - ? debug_options->xla_gpu_redzone_padding_bytes() - : 0); - - // Upgrade the reader lock for execution to a writer lock to protect runtime - // autotuning. - NonAtomicallyUpgradeableRWLock::WriterLock writer_lock = - gpu_lock->UpgradeToWriterMutexLock(); - - TF_ASSIGN_OR_RETURN( - AutotuneResult best_algorithm, - GetBestBlasAlgorithm( - stream, buffer_allocator, /*gemm_str=*/std::nullopt, autotune_config, - lhs, rhs, out, algorithms, output_shape, HloModuleConfig(), beta, - [&](const se::blas::AlgorithmType& algorithm) - -> StatusOr { - se::blas::ProfileResult profile_result; - // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will - // fail for all algorithms if we're targeting < sm_50. But because - // we pass a non-null ProfileResult, DoGemmWithAlgorithm should - // always return true, and the actual success-ness is returned in - // ProfileResult::is_valid. - TF_RETURN_IF_ERROR( - RunGemm(config, lhs, rhs, out, se::DeviceMemoryBase(nullptr, 0), - deterministic_ops, stream, algorithm, &profile_result)); - return std::move(profile_result); - })); - - if (best_algorithm.has_gemm()) { - config.algorithm = algorithms[best_algorithm.gemm().algorithm()]; - return OkStatus(); - } else { - return InternalError("Runtime autotuning failed to select an algorithm"); - } -} -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - NonAtomicallyUpgradeableRWLock* gpu_lock, - State state, StridedMemrefView lhs, - StridedMemrefView rhs, StridedMemrefView out, - StridedMemrefView workspace, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, - absl::Span precision) { - se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs); - se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs); - se::DeviceMemoryBase output_data = GetDeviceAddress(out); - se::DeviceMemoryBase workspace_data = GetDeviceAddress(workspace); - const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); - - VLOG(3) << "Running GEMM"; - se::Stream* stream = run_options->stream(); - Shape output_shape = ToShape(out); - - // Get the gemm config from the state. - TF_ASSIGN_OR_RETURN(GemmConfig * gemm_config, state.GetOrCreate([&] { - StatusOr gemm_config = - GetGemmConfig(lhs, rhs, out, algorithm, alpha_real, alpha_imag, beta, - dot_dims.lhs_batch, dot_dims.lhs_contract, - dot_dims.rhs_batch, dot_dims.rhs_contract, - precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision)); - return ToAbsl(gemm_config); - })); - - // Set the gemm algorithm by runtime autotuning. We do runtime autotuning - // outside of state.GetOrCreate() because otherwise it would be a potential - // deadlock. - if (gemm_config->algorithm == stream_executor::blas::kRuntimeAutotuning) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - auto status = DoRuntimeAutotuning(stream, *gemm_config, lhs_data, rhs_data, - output_data, output_shape, beta, - debug_options, gpu_lock); - if (!status.ok()) { - return absl::InternalError(status.ToString()); - } -#else - return absl::InternalError( - "Failed to run runtime autotuner because CUDA is not enabled"); -#endif - } - - return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, workspace_data, - deterministic_ops, stream); -} - -static absl::Status InitCuBLASImpl( - const ServiceExecutableRunOptions* run_options) { - // Initialize (with memoization) BlasSupport here because cublasCreate fails - // during gpu graph capturing. - se::StreamExecutor* executor = run_options->stream()->parent(); - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support"); - } - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Gemm, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.gemm") - .UserData() - .UserData() - .UserData() - .State("uid") - .Arg() // lhs - .Arg() // rhs - .Arg() // out - .Arg() // workspace - .Attr("algorithm") - .Attr("alpha_real") - .Attr("alpha_imag") - .Attr("beta") - .Attr("dot_dims") - .Attr>("precision")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - InitCuBLAS, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.init_cublas") - .UserData()); - -void RegisterGemmCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.gemm", Gemm); - registry.Register("xla.gpu.init_cublas", InitCuBLAS); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/gemm.h b/xla/service/gpu/runtime/gemm.h deleted file mode 100644 index 828c8a10602f0..0000000000000 --- a/xla/service/gpu/runtime/gemm.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GEMM_H_ -#define XLA_SERVICE_GPU_RUNTIME_GEMM_H_ - -#include "absl/container/node_hash_map.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/matmul_utils.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Gemm# custom calls. -void RegisterGemmCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Keep GemmConfigs for all gemm/matmul instances in the executable. -class GemmConfigs : public runtime::StateVector {}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GEMM_H_ diff --git a/xla/service/gpu/runtime/gemm_thunk.cc b/xla/service/gpu/runtime/gemm_thunk.cc new file mode 100644 index 0000000000000..4d46a78c4af58 --- /dev/null +++ b/xla/service/gpu/runtime/gemm_thunk.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/gemm_thunk.h" + +#include + +#include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/device_memory.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace gpu { + +GemmThunk::GemmThunk(ThunkInfo thunk_info, GemmConfig config, + const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, + std::optional workspace, + bool deterministic) + : Thunk(Kind::kGemm, thunk_info), + config_(std::move(config)), + lhs_buffer_(lhs_buffer), + rhs_buffer_(rhs_buffer), + output_buffer_(output_buffer), + workspace_(workspace), + deterministic_(deterministic) {} + +absl::Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "Running GEMM thunk"; + const BufferAllocations& allocs = *params.buffer_allocations; + se::DeviceMemoryBase workspace(/*opaque=*/nullptr, /*size=*/0); + if (workspace_.has_value()) { + workspace = allocs.GetDeviceAddress(workspace_.value()); + } + TF_ASSIGN_OR_RETURN( + se::Stream * stream, + GetStreamForExecution(Thunk::execution_stream_id(), params)); + + return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), + allocs.GetDeviceAddress(rhs_buffer_), + allocs.GetDeviceAddress(output_buffer_), workspace, + deterministic_, stream); +} + +absl::Status GemmThunk::Initialize(const InitializeParams& params) { + if (!params.executor->AsBlas()) { + return absl::InternalError("Failed to initialize BLAS support"); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/gemm_thunk.h b/xla/service/gpu/runtime/gemm_thunk.h new file mode 100644 index 0000000000000..58f13d33172bb --- /dev/null +++ b/xla/service/gpu/runtime/gemm_thunk.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ + +#include + +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" + +namespace xla { +namespace gpu { + +// This is thread-compatible. +class GemmThunk : public Thunk { + public: + // Constructs a thunk that computes "output = (lhs rhs) * alpha" using + // BLAS gemm (alpha is stored in the instruction GemmBackendConfig). + GemmThunk(ThunkInfo thunk_info, GemmConfig config, + const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, + std::optional workspace, + bool deterministic); + + GemmThunk(const GemmThunk&) = delete; + GemmThunk& operator=(const GemmThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status Initialize(const InitializeParams& params) override; + + GemmConfig config() const { return config_; } + BufferAllocation::Slice lhs_buffer() const { return lhs_buffer_; } + BufferAllocation::Slice rhs_buffer() const { return rhs_buffer_; } + BufferAllocation::Slice output_buffer() const { return output_buffer_; } + std::optional workspace() const { + return workspace_; + } + bool deterministic() const { return deterministic_; } + + private: + const GemmConfig config_; + const BufferAllocation::Slice lhs_buffer_; + const BufferAllocation::Slice rhs_buffer_; + const BufferAllocation::Slice output_buffer_; + std::optional workspace_; + // Whether to run deterministically. + const bool deterministic_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_GEMM_THUNK_H_ diff --git a/xla/service/gpu/runtime/gpu_kernel_helper.h b/xla/service/gpu/runtime/gpu_kernel_helper.h deleted file mode 100644 index f3ee9b250249d..0000000000000 --- a/xla/service/gpu/runtime/gpu_kernel_helper.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ -#define XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -#include - -#include "tsl/lib/math/math_util.h" - -namespace xla { -namespace gpu { - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#else -#include "rocm/include/hip/hip_runtime.h" -#endif - -#if GOOGLE_CUDA -#define WAVEFRONT_SIZE 32 -#define FORCEINLINE __forceinline__ -using gpuStream_t = cudaStream_t; -using gpuError_t = cudaError_t; -using gpuEvent_t = cudaEvent_t; -#define gpuSuccess cudaSuccess -#define gpuGetLastError cudaGetLastError -#define gpuGetErrorString cudaGetErrorString -#define gpuEventRecord cudaEventRecord -#define gpuEventSynchronize cudaEventSynchronize -#define gpuEventDestroy cudaEventDestroy -#define gpuEventCreate cudaEventCreate -#define gpuEventCreateWithFlags cudaEventCreateWithFlags -#define gpuEventDisableTiming cudaEventDisableTiming -#define gpuEventElapsedTime cudaEventElapsedTime -#define gpuDeviceSynchronize cudaDeviceSynchronize -#define gpuLaunchKernel cudaLaunchKernel -#define gpuMemcpy cudaMemcpy -#define gpuMalloc cudaMalloc -#define gpuFree cudaFree -#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice -#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost -#define gpuStreamCreate cudaStreamCreate -#define gpuStreamSynchronize cudaStreamSynchronize - -#elif TENSORFLOW_USE_ROCM -using gpuStream_t = hipStream_t; -using gpuError_t = hipError_t; -using gpuEvent_t = hipEvent_t; -#define gpuSuccess hipSuccess -#define gpuGetLastError hipGetLastError -#define gpuGetErrorString hipGetErrorString -#define gpuEventRecord hipEventRecord -#define gpuEventDestroy hipEventDestroy -#define gpuEventSynchronize hipEventSynchronize -#define gpuEventCreate hipEventCreate -#define gpuEventCreateWithFlags hipEventCreateWithFlags -#define gpuEventDisableTiming hipEventDisableTiming -#define gpuEventElapsedTime hipEventElapsedTime -#define gpuDeviceSynchronize hipDeviceSynchronize -#define gpuLaunchKernel hipLaunchKernel -#define gpuMemcpy hipMemcpy -#define gpuMalloc hipMalloc -#define gpuFree hipFree -#define gpuMemcpyHostToDevice hipMemcpyHostToDevice -#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost -#define gpuStreamCreate hipStreamCreate -#define gpuStreamSynchronize hipStreamSynchronize - -#ifdef __AMDGCN_WAVEFRONT_SIZE -#define WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE -#else -#define WAVEFRONT_SIZE 64 -#endif -#define FORCEINLINE __forceinline__ -#endif - -// macro wrapper to declare dynamic shared memory -#if GOOGLE_CUDA - -#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \ - extern __shared__ __align__(ALIGN) \ - TYPE NAME[] - -#elif TENSORFLOW_USE_ROCM - -#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \ - HIP_DYNAMIC_SHARED(TYPE, NAME) - -#endif - -enum class ShflType { Sync, Up, Down, Xor }; - -template -__device__ FORCEINLINE NT GpuShuffle(NT val, uint32_t idx, - uint32_t allmsk = 0xffffffffu) { - constexpr uint32_t SZ = - tsl::MathUtil::CeilOfRatio(sizeof(NT), sizeof(uint32_t)); - union S { - NT v; - uint32_t d[SZ]; - }; - S in{val}, res{}; - -#pragma unroll - for (uint32_t i = 0; i < SZ; i++) { -#if GOOGLE_CUDA - if constexpr (Type == ShflType::Sync) - res.d[i] = __shfl_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Up) - res.d[i] = __shfl_up_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Down) - res.d[i] = __shfl_down_sync(allmsk, in.d[i], idx); - else if constexpr (Type == ShflType::Xor) - res.d[i] = __shfl_xor_sync(allmsk, in.d[i], idx); -#elif TENSORFLOW_USE_ROCM // ROcm does not support sync shuffle intrinsics - if constexpr (Type == ShflType::Sync) - res.d[i] = __shfl(in.d[i], idx); - else if constexpr (Type == ShflType::Up) - res.d[i] = __shfl_up(in.d[i], idx); - else if constexpr (Type == ShflType::Down) - res.d[i] = __shfl_down(in.d[i], idx); - else if constexpr (Type == ShflType::Xor) - res.d[i] = __shfl_xor(in.d[i], idx); -#endif - } - return res.v; -} - -} // namespace gpu -} // namespace xla - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#endif // XLA_SERVICE_GPU_RUNTIME_GPU_KERNEL_HELPER_H_ diff --git a/xla/service/gpu/runtime/gpublas_lt_matmul.cc b/xla/service/gpu/runtime/gpublas_lt_matmul.cc deleted file mode 100644 index 151571c938952..0000000000000 --- a/xla/service/gpu/runtime/gpublas_lt_matmul.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.1 -==============================================================================*/ - -#include "xla/service/gpu/runtime/gpublas_lt_matmul.h" - -#include -#include -#include -#include -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/logical_result.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/scratch_allocator.h" -#include "xla/xla.pb.h" -#include "tsl/platform/status.h" - -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - -namespace xla { -#if GOOGLE_CUDA || TF_HIPBLASLT - -using xla::runtime::CustomCall; -using xla::runtime::CustomCallAttrEncodingSet; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; - -//===----------------------------------------------------------------------===// -// Register cuBLASLt attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { -XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(se::gpu::BlasLt::Epilogue); -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime enums. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void PopulateCublasLtMatmulAttrEncoding(CustomCallAttrEncodingSet& encoding) { - encoding.Add>( - [](lmhlo_gpu::CublasLtMatmulEpilogue value) -> se::gpu::BlasLt::Epilogue { - return gpublas_lt::AsBlasLtEpilogue(value).value(); - }); -} - -//===----------------------------------------------------------------------===// -// cuBLASLt matmul custom call implementation. -//===----------------------------------------------------------------------===// - -namespace { - -absl::Status DoMatmul( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView d, - std::optional bias, std::optional aux, - std::optional a_scale, - std::optional b_scale, - std::optional c_scale, - std::optional d_scale, - std::optional d_amax, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - se::Stream* stream = run_options->stream(); - - // Find the gemm config for this instance of matmul. - TF_ASSIGN_OR_RETURN(GemmConfig * config, gemm_config.GetOrCreate([&] { - return ToAbsl(GetGemmConfig( - a, b, d, algorithm, alpha_real, alpha_imag, beta, dot_dims.lhs_batch, - dot_dims.lhs_contract, dot_dims.rhs_batch, dot_dims.rhs_contract, - precision.empty() ? se::blas::kDefaultComputePrecision - : *absl::c_max_element(precision), - c, bias)); - })); - - // Get the matmul plan for this instance of matmul. - TF_ASSIGN_OR_RETURN(auto plan, matmul_plan.GetOrCreate([&] { - return ToAbsl(se::gpu::BlasLt::GetMatmulPlan(stream, *config, epilogue)); - })); - - TF_ASSIGN_OR_RETURN(auto algos, (*plan)->GetAlgorithms()); - if (static_cast(algorithm) >= algos.size()) { - return absl::InternalError( - absl::StrFormat("The requested gpublas-lt matmul " - "algorithm is not found. Total algorithms available: " - "%zu; requested: %zu", - algos.size(), static_cast(algorithm))); - } - - se::DeviceMemoryBase a_data = GetDeviceAddress(a); - se::DeviceMemoryBase b_data = GetDeviceAddress(b); - se::DeviceMemoryBase c_data = GetDeviceAddress(c); - se::DeviceMemoryBase d_data = GetDeviceAddress(d); - se::DeviceMemoryBase bias_data; - if (bias.has_value()) bias_data = GetDeviceAddress(*bias); - se::DeviceMemoryBase aux_data; - if (aux.has_value()) aux_data = GetDeviceAddress(*aux); - - se::DeviceMemoryBase a_scale_data; - if (a_scale.has_value()) a_scale_data = GetDeviceAddress(*a_scale); - se::DeviceMemoryBase b_scale_data; - if (b_scale.has_value()) b_scale_data = GetDeviceAddress(*b_scale); - se::DeviceMemoryBase c_scale_data; - if (c_scale.has_value()) c_scale_data = GetDeviceAddress(*c_scale); - se::DeviceMemoryBase d_scale_data; - if (d_scale.has_value()) d_scale_data = GetDeviceAddress(*d_scale); - se::DeviceMemoryBase d_amax_data; - if (d_amax.has_value()) d_amax_data = GetDeviceAddress(*d_amax); - - se::OwningScratchAllocator<> scratch_allocator( - stream->parent()->device_ordinal(), stream->parent()->GetAllocator()); - - return (*plan)->ExecuteOnStream( - stream, a_data, b_data, c_data, d_data, bias_data, aux_data, a_scale_data, - b_scale_data, c_scale_data, d_scale_data, d_amax_data, algos[algorithm], - scratch_allocator); -} - -} // namespace - -static absl::Status CublasLtMatmulImpl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView d, - std::optional bias, std::optional aux, - int64_t algorithm, double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - VLOG(3) << "Running CublasLtMatmul"; - std::optional a_scale, b_scale, c_scale, d_scale, d_amax; - return DoMatmul(run_options, debug_options, gemm_config, matmul_plan, a, b, c, - d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, - algorithm, alpha_real, alpha_imag, beta, dot_dims, epilogue, - precision); -} - -static absl::Status CublasLtMatmulF8Impl( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, State gemm_config, - State matmul_plan, StridedMemrefView a, - StridedMemrefView b, StridedMemrefView c, StridedMemrefView a_scale, - StridedMemrefView b_scale, StridedMemrefView c_scale, - StridedMemrefView d_scale, StridedMemrefView d, - CustomCall::RemainingArgs remaining_args, int64_t algorithm, - double alpha_real, double alpha_imag, double beta, - DotDimensionNumbers dot_dims, se::gpu::BlasLt::Epilogue epilogue, - absl::Span precision) { - VLOG(3) << "Running CublasLtMatmulF8"; - std::optional bias, d_amax, aux; - int current_remaining_arg = 0; - - // Get bias, if present - if (epilogue == se::gpu::BlasLt::Epilogue::kBias || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenReLU || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenGELU || - epilogue == se::gpu::BlasLt::Epilogue::kBiasThenGELUWithAux) { - if (remaining_args.size() <= current_remaining_arg) { - return absl::InternalError("Epilogue not present in CublasLtMatmulF8 op"); - } - auto bias_or_failure = - remaining_args.get(current_remaining_arg++); - if (failed(bias_or_failure)) { - return absl::InternalError("Failed to get epilogue"); - } - bias = bias_or_failure.value(); - } - - // Get amax, if present - if (remaining_args.size() > current_remaining_arg) { - auto d_amax_or_failure = - remaining_args.get(current_remaining_arg++); - if (failed(d_amax_or_failure)) { - return absl::InternalError("Failed to get d_amax"); - } - d_amax = d_amax_or_failure.value(); - } - - return DoMatmul(run_options, debug_options, gemm_config, matmul_plan, a, b, c, - d, bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, - algorithm, alpha_real, alpha_imag, beta, dot_dims, epilogue, - precision); -} - -//===----------------------------------------------------------------------===// -// cuBLASLt custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -template -auto BindMatmulAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - .template Attr("algorithm") - .template Attr("alpha_real") - .template Attr("alpha_imag") - .template Attr("beta") - .template Attr("dot_dims") - .template Attr("epilogue") - .template Attr>("precision"); -} - -auto CublasLtMatmulCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .State("uid") - .Arg() // a - .Arg() // b - .Arg() // c - .Arg(); // d -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmul, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul") - .Value(std::optional()) // bias - .Value(std::optional()) // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulBias, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.bias") - .Arg() // bias - .Value(std::optional()) // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulAux, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.aux") - .Value(std::optional()) // bias - .Arg() // aux - )); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulBiasAux, FunctionWrapper(), checks, - BindMatmulAttributes(CublasLtMatmulCall("xla.gpu.cublas.lt.matmul.bias.aux") - .Arg() // bias - .Arg() // aux - )); - -auto CublasLtMatmulF8Call(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .State("uid") - .Arg() // a - .Arg() // b - .Arg() // c - .Arg() // a_scale - .Arg() // b_scale - .Arg() // c_scale - .Arg() // d_scale - .Arg(); // d -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - CublasLtMatmulF8, FunctionWrapper(), checks, - BindMatmulAttributes( - CublasLtMatmulF8Call("xla.gpu.cublas.lt.matmul.f8").RemainingArgs())); - -void RegisterMatmulCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.cublas.lt.matmul", CublasLtMatmul); - registry.Register("xla.gpu.cublas.lt.matmul.bias", CublasLtMatmulBias); - registry.Register("xla.gpu.cublas.lt.matmul.aux", CublasLtMatmulAux); - registry.Register("xla.gpu.cublas.lt.matmul.bias.aux", CublasLtMatmulBiasAux); - registry.Register("xla.gpu.cublas.lt.matmul.f8", CublasLtMatmulF8); -} - -} // namespace gpu -#endif // GOOGLE_CUDA || TF_HIPBLASLT -} // namespace xla diff --git a/xla/service/gpu/runtime/gpublas_lt_matmul.h b/xla/service/gpu/runtime/gpublas_lt_matmul.h deleted file mode 100644 index be85ea6e86739..0000000000000 --- a/xla/service/gpu/runtime/gpublas_lt_matmul.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ -#define XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/matmul_utils.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Add cuBLASLt attributes encoding -void PopulateCublasLtMatmulAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -#if GOOGLE_CUDA || TF_HIPBLASLT - -// Registers XLA Gpu runtime cuBLASLt custom calls. -void RegisterMatmulCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Keep cublas_lt::MatmulPlan's for all matmul instances in the executable. -class MatmulPlans - : public runtime::StateVector {}; -#endif // GOOGLE_CUDA || TF_HIPBLASLT - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_ diff --git a/xla/service/gpu/gpublas_lt_matmul_thunk.cc b/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc similarity index 75% rename from xla/service/gpu/gpublas_lt_matmul_thunk.cc rename to xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc index 4b3d4f1fcbe56..a6aaabe5b3189 100644 --- a/xla/service/gpu/gpublas_lt_matmul_thunk.cc +++ b/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpublas_lt_matmul_thunk.h" +#include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" -#include #include #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/scratch_allocator.h" @@ -53,7 +52,7 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( d_scale_buffer_(d_scale), d_amax_buffer_(d_amax) {} -Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { +absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); TF_ASSIGN_OR_RETURN(auto algorithm, GetMatmulAlgorithm(plan)); @@ -91,32 +90,38 @@ Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { params.stream, allocs.GetDeviceAddress(a_buffer_), allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_), allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale, - d_scale, d_amax, *algorithm, scratch_allocator); + d_scale, d_amax, algorithm, scratch_allocator); } -StatusOr CublasLtMatmulThunk::GetMatmulPlan( +absl::StatusOr CublasLtMatmulThunk::GetMatmulPlan( const stream_executor::Stream* stream) { - absl::MutexLock lock(&matmul_plans_cache_mutex_); - auto it = matmul_plans_cache_.find(stream); - if (it == matmul_plans_cache_.end()) { - TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( - stream, gemm_config_, epilogue_)); - it = matmul_plans_cache_.emplace(stream, std::move(plan)).first; + { + absl::MutexLock lock(&matmul_plans_cache_mutex_); + auto it = matmul_plans_cache_.find(stream); + if (it != matmul_plans_cache_.end()) return it->second.get(); } + TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan( + stream, gemm_config_, epilogue_)); + + absl::MutexLock lock(&matmul_plans_cache_mutex_); + auto [it, _] = matmul_plans_cache_.emplace(stream, std::move(plan)); return it->second.get(); } -StatusOr > +absl::StatusOr CublasLtMatmulThunk::GetMatmulAlgorithm( const se::gpu::BlasLt::MatmulPlan* plan) { - absl::MutexLock lock(&matmul_algorithm_cache_mutex_); - auto it = matmul_algorithm_cache_.find(plan); - if (it == matmul_algorithm_cache_.end()) { - TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms()); - TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); - auto algorithm = algorithms[algorithm_idx_]; - it = matmul_algorithm_cache_.emplace(plan, algorithm).first; + { + absl::MutexLock lock(&matmul_algorithm_cache_mutex_); + auto it = matmul_algorithm_cache_.find(plan); + if (it != matmul_algorithm_cache_.end()) return it->second; } + TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms()); + TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size()); + + absl::MutexLock lock(&matmul_algorithm_cache_mutex_); + auto [it, _] = + matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]); return it->second; } diff --git a/xla/service/gpu/gpublas_lt_matmul_thunk.h b/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h similarity index 80% rename from xla/service/gpu/gpublas_lt_matmul_thunk.h rename to xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h index 5a394dbf80bd6..90ff6a1d6c24a 100644 --- a/xla/service/gpu/gpublas_lt_matmul_thunk.h +++ b/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ -#define XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ -#include +#include #include -#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/status.h" -#include "tsl/platform/statusor.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { @@ -45,12 +48,12 @@ class CublasLtMatmulThunk : public Thunk { BufferAllocation::Slice d_scale_buffer /* may be null */, BufferAllocation::Slice d_amax_buffer /* may be null */); - Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: - StatusOr GetMatmulPlan( + absl::StatusOr GetMatmulPlan( const stream_executor::Stream* stream); - StatusOr > GetMatmulAlgorithm( + absl::StatusOr GetMatmulAlgorithm( const se::gpu::BlasLt::MatmulPlan* plan); absl::Mutex matmul_plans_cache_mutex_; @@ -82,4 +85,4 @@ class CublasLtMatmulThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_THUNK_H_ diff --git a/xla/service/gpu/runtime/graph_launch.cc b/xla/service/gpu/runtime/graph_launch.cc deleted file mode 100644 index a62adcd73f638..0000000000000 --- a/xla/service/gpu/runtime/graph_launch.cc +++ /dev/null @@ -1,730 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/graph_launch.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/conv.h" -#include "xla/service/gpu/runtime/gemm.h" -#include "xla/service/gpu/runtime/kernel_launch.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/statusor.h" -#include "tsl/profiler/lib/profiler_lock.h" -#include "tsl/profiler/lib/traceme.h" -#include "tsl/profiler/lib/traceme_encode.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using tsl::profiler::TraceMe; -using tsl::profiler::TraceMeEncode; - -using xla::runtime::Arguments; -using xla::runtime::AsyncTaskRunner; -using xla::runtime::CustomCall; -using xla::runtime::Executable; -using xla::runtime::FunctionRef; -using xla::runtime::FunctionType; -using xla::runtime::MemrefDesc; -using xla::runtime::MemrefType; -using xla::runtime::StridedMemrefView; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -using se::gpu::OwnedGpuGraph; - -// Captures Gpu graph by running given function in capture mode. -static absl::StatusOr CaptureGraph( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, Arguments& args, - CustomCall::UserData user_data); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -//===----------------------------------------------------------------------===// -// GPU graphs caching. -//===----------------------------------------------------------------------===// - -struct GraphInstances::Impl { - struct State { - // A flag signalling if `InstantiateAllGraphs` was already called and we - // have all Gpu graph instantiated ahead of time. - bool instantiated = false; - - // Last time graph instances were used by a particular stream executor. - uint64_t last_use_micros = 0; - - std::shared_ptr instances = - std::make_shared(); - }; - - // XLA module name that owns graph instances. We use it only to produce logs - // that can be attributed back to XLA executables. - std::string module_name; - - // Number of graphs in the parent module. - int64_t num_graphs = 0; - - mutable absl::Mutex mu; - absl::node_hash_map graphs ABSL_GUARDED_BY(mu); -}; - -// Keep track of instantiated graphs on each StreamExecutor, we use this -// information in the graph eviction policy. -using GraphInstancesState = absl::flat_hash_map; - -static absl::Mutex* GetGraphInstancesStateMutex() { - static auto* mu = new absl::Mutex(); - return mu; -} - -static GraphInstancesState& GetGraphInstancesState() { - static auto* state = new GraphInstancesState(); - return *state; -} - -static int64_t NotifyGraphInstancesCreated(se::StreamExecutor* executor, - int64_t num_graphs) { - absl::MutexLock lock(GetGraphInstancesStateMutex()); - return GetGraphInstancesState()[executor] += num_graphs; -} - -static int64_t NotifyGraphInstancesDestroyed(se::StreamExecutor* executor, - int64_t num_graphs) { - absl::MutexLock lock(GetGraphInstancesStateMutex()); - return GetGraphInstancesState()[executor] -= num_graphs; -} - -// We keep track of all graph instances in the process, to implement graph -// eviction on OOM. Graph instances owned by GpuExecutable, so we rely on -// weak ptr to check if they are still alive. -using GraphInstancesVec = std::vector>; - -static absl::Mutex* GetGraphInstancesVecMutex() { - static auto* mu = new absl::Mutex(); - return mu; -} - -static GraphInstancesVec& GetGraphInstancesVec() { - static auto* vec = new GraphInstancesVec(); - return *vec; -} - -static void AddGraphInstances(std::weak_ptr impl) { - absl::MutexLock lock(GetGraphInstancesVecMutex()); - GetGraphInstancesVec().push_back(std::move(impl)); -} - -// Evicts all graphs for a given executor in the current process. -static void EvictAllGraphs( - se::StreamExecutor* executor, - std::optional eviction_timeout_seconds = std::nullopt) { - // We WARN only when we evict all Gpu graphs because it happens when we - // recover from OOM. Eviction by time out is business as usual. - if (eviction_timeout_seconds.has_value()) { - VLOG(3) << "Evict timed out gpu graphs from executor " << executor; - } else { - LOG(WARNING) << "Evict all gpu graphs from executor " << executor; - } - - TraceMe trace_instantiation([&] { - return TraceMeEncode("cuda.graph.evict_all_graphs", - {{"device_ordinal", executor->device_ordinal()}}); - }); - - absl::MutexLock lock(GetGraphInstancesVecMutex()); - auto& vec = GetGraphInstancesVec(); - - // Erase all expired graph instances. - vec.erase(std::remove_if(vec.begin(), vec.end(), - [](auto& weak_ptr) { return weak_ptr.expired(); }), - vec.end()); - - auto timed_out = [&](GraphInstances::Impl::State& state) -> bool { - if (!eviction_timeout_seconds.has_value()) { - return false; - } - - auto diff = tsl::Env::Default()->NowMicros() - state.last_use_micros; - return (diff / (1000 * 1000)) > *eviction_timeout_seconds; - }; - - int64_t num_evicted = 0; - - for (auto& weak_ptr : vec) { - auto ptr = weak_ptr.lock(); - if (!ptr) continue; - - if (!ptr->mu.TryLock()) continue; - - auto it = ptr->graphs.find(executor); - if (it == ptr->graphs.end()) { - ptr->mu.Unlock(); - continue; - } - - // If we have a timeout value, than check it first, otherwise always evict - // graphs for a given executor. - bool is_timed_out = timed_out(it->second); - if (eviction_timeout_seconds.has_value() && !is_timed_out) { - ptr->mu.Unlock(); - continue; - } - - if (ptr->num_graphs > 0) { - VLOG(3) << "Evict " << ptr->num_graphs << " graphs for: @" - << ptr->module_name << " at executor: " << executor - << " (timed_out = " << is_timed_out << ")." - << " Total remaining graphs at given executor: " - << NotifyGraphInstancesDestroyed(executor, ptr->num_graphs); - } - ptr->graphs.erase(it); - ptr->mu.Unlock(); - ++num_evicted; - } - - if (num_evicted > 0) { - VLOG(3) << "Evicted " << num_evicted << " graphs from executor " - << executor; -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::gpu::GpuGraphSupport::TrimDeviceMemory(executor); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } -} - -GraphInstances::GraphInstances(std::string module_name, int64_t num_graphs) - : impl_(std::make_shared()) { - impl_->module_name = std::move(module_name); - impl_->num_graphs = num_graphs; - if (impl_->num_graphs > 0) { - VLOG(3) << "Construct graph instances cache for: @" << impl_->module_name - << " (num_graphs = " << impl_->num_graphs << ")"; - } - AddGraphInstances(impl_); -} - -GraphInstances::~GraphInstances() { - if (impl_->num_graphs > 0) { - VLOG(3) << "Destroy graph instances cache for: @" << impl_->module_name - << " (num_graphs = " << impl_->num_graphs << ")"; - - absl::MutexLock lock(&impl_->mu); - for (auto& [executor, state] : impl_->graphs) { - VLOG(3) << "Destroy " << impl_->num_graphs << " graphs for: @" - << impl_->module_name << " at executor: " << executor - << ". Total remaining graphs at given executor: " - << NotifyGraphInstancesDestroyed(executor, impl_->num_graphs); - } - } -} - -std::shared_ptr GraphInstances::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&impl_->mu); - - auto it = impl_->graphs.try_emplace(executor); - if (it.second && impl_->num_graphs > 0) { - VLOG(3) << "Instantiate " << impl_->num_graphs << " graphs for: @" - << impl_->module_name << " at executor: " << executor - << ". Total graphs at given executor: " - << NotifyGraphInstancesCreated(executor, impl_->num_graphs); - } - - Impl::State& state = it.first->second; - state.last_use_micros = tsl::Env::Default()->NowMicros(); - return state.instances; -} - -bool GraphInstances::InstantiatedAllGraphs( - const ServiceExecutableRunOptions* run_options, - const Executable& executable) { - if (executable.num_functions() == 1) return true; - - absl::MutexLock lock(&impl_->mu); - return impl_->graphs[run_options->stream()->parent()].instantiated; -} - -Status GraphInstances::InstantiateAllGraphs( - const ServiceExecutableRunOptions* run_options, - const Executable& executable, const CustomCall::UserData& user_data, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - absl::Span> allocation_indices, - std::optional eviction_timeout_seconds) { - // We have only "main" function in the executable. - if (executable.num_functions() == 1) return OkStatus(); - - absl::MutexLock lock(&impl_->mu); - se::StreamExecutor* executor = run_options->stream()->parent(); - - Impl::State& state = impl_->graphs[executor]; - - // All Gpu graphs are already instantiated for a given executor. - if (state.instantiated) return OkStatus(); - - TraceMe trace("gpu.graph.instantiate_all"); - - // Evict all timeout graphs before trying to instantiate new ones. - EvictAllGraphs(executor, eviction_timeout_seconds); - - // We'll retry graph instantiation on OOM errors after evicting all graphs - // instantiated on `executor`. - int32_t num_retries = 0; - - StreamExecutorGraphInstances::Snapshot instances = - state.instances->snapshot(); - - // Instantiate all Gpu graphs by calling graph capture functions with fake - // arguments. Once we'll execute them first time for real, they'll be updated - // with correct pointers. - for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) { - if (!absl::StartsWith(executable.function_name(ordinal), - "xla.gpu.graph.capture")) - continue; - - VLOG(3) << "Instantiate Gpu graph defined by capture function @" - << executable.function_name(ordinal) << " (ordinal = " << ordinal - << ")"; - - TraceMe trace_instantiation([&] { - return TraceMeEncode("gpu.graph.instantiate", {{"ordinal", ordinal}}); - }); - - FunctionRef function_ref = executable.function_ref(ordinal); - - const FunctionType& signature = executable.signature(ordinal); - assert(signature.num_results() == 0 && "unexpected number of results"); - Arguments args(signature.num_operands()); - - // Mapping from graph capture argument to buffer allocation index. - absl::Span capture_allocs = allocation_indices[ordinal]; - if (capture_allocs.size() != signature.num_operands()) - return absl::InternalError( - "Invalid number of allocation indices for a graph capture function"); - - // Prepare arguments for the graph capture function. - for (size_t j = 0; j < signature.num_operands(); ++j) { - auto* memref = llvm::dyn_cast(signature.operand(j)); - - if (!memref) - return absl::InternalError(absl::StrFormat( - "Unsupported capture function argument type #%d", j)); - - if (memref->sizes().size() != 1) - return absl::InternalError( - absl::StrFormat("Unsupported capture function memref rank #%d: %d", - j, memref->sizes().size())); - - std::array sizes = {memref->size(0)}; - std::array strides = {1}; - - int64_t allocation_index = capture_allocs[j]; - args.emplace_back( - memref->element_type(), - buffer_allocations.GetDeviceAddress(allocation_index).opaque(), - /*offset=*/0, sizes, strides); - } - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - // Instantiate a Gpu graph with fake arguments. - auto instantiate = [&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data)); - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); - return GraphInstance(0, std::move(e)); - }; - - absl::StatusOr instance = - instances.GetOrCreate(ordinal, instantiate); - - if (instance.status().code() == absl::StatusCode::kResourceExhausted) { - if (num_retries == 0) { - LOG(WARNING) << "InstantiateAllGraph failed due to insufficient memory." - " Try to evict all graphs and free device memory."; - - // Retry on OOM error after evicting all graphs from executor. - EvictAllGraphs(executor); - num_retries++; - ordinal--; // we'll try to instantiate the same graph one more time - continue; - } else { - LOG(WARNING) << "InstantiateAllGraph failed due to insufficient memory." - " Unitialized graphs will run in op-by-op mode."; - return OkStatus(); - } - } - - // Otherwise return an error to the caller. - if (!instance.ok()) return instance.status(); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } - - state.instantiated = true; - return OkStatus(); -} - -CapturedFunctionExecutionCount* CapturedFunctionExecutionCounts::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &counts_[executor]; -} - -//===----------------------------------------------------------------------===// -// Helper structure to hash the remaining arguments' memref pointers. -//===----------------------------------------------------------------------===// - -struct RemainingArgsPtrs { - CustomCall::RemainingArgs args; - se::DeviceMemoryBase* temp_buffer; - - template - friend H AbslHashValue(H h, const RemainingArgsPtrs& m); -}; - -template -H AbslHashValue(H h, const RemainingArgsPtrs& m) { - for (size_t i = 0; i < m.args.size(); ++i) { - if (auto memref = m.args.get(i); succeeded(memref)) - h = H::combine(std::move(h), memref->data); - } - return std::move(H::combine(std::move(h), m.temp_buffer->opaque())); -} - -//----------------------------------------------------------------------------// -// Runs capture function exported by the executable to construct a gpu graph. -//----------------------------------------------------------------------------// - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -static bool InDebugMode() { -#ifdef NDEBUG - return false; -#endif - return true; -} - -// Forwards custom call arguments to an arguments container that can be passed -// to an executable function. -static absl::Status ForwardArguments(CustomCall::RemainingArgs fwd_args, - Arguments& args) { - for (size_t i = 0; i < fwd_args.size(); ++i) { - if (auto memref = fwd_args.get(i); succeeded(memref)) { - args.emplace_back(memref->dtype, memref->data, /*offset=*/0, - memref->sizes, memref->strides); - continue; - } - - return absl::InvalidArgumentError("Unsupported argument type"); - } - - return OkStatus(); -} - -static absl::StatusOr CaptureGraph( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, Arguments& args, - CustomCall::UserData user_data) { - // We capture graph on a borrowed stream because we do not want to - // accidentally record any concurrent kernel launches from other XLA - // executables. - se::StreamExecutor* executor = run_options->stream()->parent(); - - // Initialize (with memoization) BlasSupport here because cublasCreate fails - // during gpu graph capturing. - if (function_ref.RequiresBlas()) { - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support"); - } - } - - StatusOr capture_stream = - run_options->BorrowStream(executor->device_ordinal()); - - if (!capture_stream.ok()) - return absl::InternalError( - absl::StrFormat("Failed to borrow a stream for graph capture: %s", - capture_stream.status().message())); - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.capture", - {{"ordinal", function_ref.ordinal()}}); - }); - - // TODO(ezhulenev): Pass graph capture context explicitly to the custom calls - // via UserData to be able to detect when executing custom call in graph - // capture mode. Currently we rely on the fact that we know for sure that - // operations in the graph capture function do not need anything except the - // main stream (we capture only kernel launches). - ExecutableRunOptions capture_run_options; - capture_run_options.set_stream(capture_stream->get()); - - const ServiceExecutableRunOptions capture_opts(capture_run_options); - user_data.insert(&capture_opts); - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic); - - // Prepare options for executing graph capture function. - Executable::ExecuteOpts opts; - opts.custom_call_data = &user_data; - opts.diagnostic_engine = &diagnostic_engine; - - // Graph capture function should not launch any async tasks. - opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - - // Create a graph from running the graph capture function. - auto captured = se::gpu::CaptureGpuGraph(capture_stream->get(), [&]() { - return function_ref(args, runtime::NoResultConverter{}, opts, - /*verify_arguments=*/InDebugMode()) - .status(); - }); - - if (!captured.ok()) { - return InternalError("CaptureGpuGraph failed (%s): %s", - diagnostic.empty() ? "" : diagnostic, - captured.status().ToString()); - } - return std::move(*captured); -} - -// When graph execution is disabled we run the graph capture function in -// "regular" mode and execute all operation one by one. -static absl::Status RunGraphOpByOp( - const ServiceExecutableRunOptions* run_options, - runtime::FunctionRef function_ref, CustomCall::RemainingArgs fwd_args, - CustomCall::UserData user_data) { - // Prepare options for executing graph capture function. - Executable::ExecuteOpts opts; - auto* concurrent_region_status = user_data.get(); - // Ops should not run in parallel during op-by-op execution. - concurrent_region_status->DisableConcurrentRegion(); - opts.custom_call_data = &user_data; - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.run_op_by_op_fallback", - {{"ordinal", function_ref.ordinal()}}); - }); - - // Collect all emitted diagnostic messages. - std::string diagnostic; - runtime::DiagnosticEngine diagnostic_engine; - AppendDiagnosticToString(diagnostic_engine, &diagnostic); - - opts.diagnostic_engine = &diagnostic_engine; - - // Graph capture function should not launch any async tasks. - opts.async_task_runner = reinterpret_cast(0XDEADBEEF); - - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - auto executed = - function_ref(args, runtime::NoResultConverter{}, opts, InDebugMode()); - concurrent_region_status->EnableConcurrentRegion(); - if (!executed.ok()) { - return InternalError("RunGraphOpByOp failed (%s): %s", - diagnostic.empty() ? "" : diagnostic, - executed.status().ToString()); - } - return absl::OkStatus(); -} - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -//===----------------------------------------------------------------------===// -// Define the gpu graph launch custom call. -//===----------------------------------------------------------------------===// - -static absl::Status LaunchGraph( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, const std::string* ptx, - const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, - StreamExecutorKernels::Snapshot* kernels, - StreamExecutorConvRunners::Snapshot* convs, - StreamExecutorGraphInstances::Snapshot* instances, - CapturedFunctionExecutionCount::Snapshot* counts, - GemmConfigs::Snapshot* gemm_config, runtime::Executable* executable, - NonAtomicallyUpgradeableRWLock* gpu_lock, - ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args, - CustomCall::FunctionOrdinal capture) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - VLOG(1) << "Launch GPU Graph: ordinal = " << capture.ordinal; - - // Get a reference to exported function that captures the gpu graph. - runtime::FunctionRef function_ref = executable->function_ref(capture.ordinal); - - // Compute the hash of the buffer arguments. - size_t ptrs_hash = absl::HashOf(RemainingArgsPtrs{fwd_args, temp_buffer}); - - // Forwards user data required for launching kernels. - auto user_data = [&] { - return CustomCall::UserData(run_options, debug_options, ptx, cubin, - temp_buffer, kernels, convs, executable, - gemm_config, gpu_lock, region_status); - }; - - TF_ASSIGN_OR_RETURN(std::unique_ptr> * get_count, - counts->GetOrCreate(capture.ordinal, [] { - return std::make_unique>(0); - })); - - int64_t count = (*get_count)->fetch_add(1); - int64_t num_runs_to_instantiate = - debug_options->xla_gpu_graph_num_runs_to_instantiate(); - - // TODO(b/290773547): Profiler + CUDA graphs lead to memory corruption. As a - // work around disable graph execution and run everything in op-by-op mode. - bool is_profiling = tsl::profiler::ProfilerLock::HasActiveSession(); - - if (count < num_runs_to_instantiate || is_profiling) { - VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal; - return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data()); - } - - // Instantiate Gpu graph by running graph capture function. - auto instantiate = [&]() -> absl::StatusOr { - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data())); - - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); - - return GraphInstance(ptrs_hash, std::move(e)); - }; - - GraphInstance* instance; - if (num_runs_to_instantiate < 0) { - // If num_runs_to_instantiate is less than 0, all graphs should be - // instantiated ahead-of-time. If we fail to get the graph instance, then - // graph instantiation failed due to OOM. So we run the graph op-by-op. - absl::StatusOr try_get_instance = - instances->Get(capture.ordinal); - if (try_get_instance.ok()) { - instance = try_get_instance.value(); - } else { - return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data()); - } - } else { - TF_ASSIGN_OR_RETURN(instance, - instances->GetOrCreate(capture.ordinal, instantiate)); - } - - { - // Lock graph instance for read only access. If we'll have to update the - // graph, we'll update to a writer lock below. - absl::ReaderMutexLock lock(instance->mutex.get()); - - // If pointers did not change we can run captured graph. - if (ptrs_hash == instance->ptr_hash) { - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.launch_cached", - {{"ordinal", capture.ordinal}}); - }); - - VLOG(3) << "Execute cached graph instance"; - return instance->exec.Launch(run_options->stream()); - } - } - - // Otherwise we have to re-capture the graph and update the graph instance. - VLOG(3) << "Update cached graph instance"; - - Arguments args(fwd_args.size()); - TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - - // Capture GPU graph by running capture function. - TF_ASSIGN_OR_RETURN( - auto g, CaptureGraph(run_options, function_ref, args, user_data())); - - // At this point we have to grab a writer lock, because we might potentially - // have concurrent execution of the cached graph instance. - absl::WriterMutexLock lock(instance->mutex.get()); - - // Update captured graph executable. - TF_RETURN_IF_ERROR(instance->exec.Update(std::move(g))); - - // Update captured pointer hash. - instance->ptr_hash = ptrs_hash; - - TraceMe trace([&] { - return TraceMeEncode("gpu.graph.launch_updated", - {{"ordinal", capture.ordinal}}); - }); - - return instance->exec.Launch(run_options->stream()); - -#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM - - return absl::InternalError("GPU graphs are not supported"); - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Launch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.graph.launch") - .UserData() - .UserData() - .UserData() - .UserData*>() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .UserData() - .RemainingArgs() - .Attr("capture")); - -void RegisterGraphLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.graph.launch", Launch); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/graph_launch.h b/xla/service/gpu/runtime/graph_launch.h deleted file mode 100644 index 6fe5145098dd6..0000000000000 --- a/xla/service/gpu/runtime/graph_launch.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ -#define XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/stream_executor.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime graph launch custom calls. -void RegisterGraphLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -struct GraphInstance; // Forward declare -class StreamExecutorGraphInstances; // Forward declare - -// A state vector that keeps track of the number of times a capture function -// gets executed. Graph capture function ordinal is the key in this container. -class CapturedFunctionExecutionCount - : public runtime::StateVector>> {}; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// A state vector that owns all instantiated GPU graphs. Graph capture function -// ordinal is the key in this container. -class StreamExecutorGraphInstances - : public runtime::StateVector {}; - -// Instantiated GPU graph instance guarded with a mutex for exclusive access. -struct GraphInstance { - GraphInstance(size_t ptr_hash, se::gpu::OwnedGpuGraphExec exec) - : ptr_hash(ptr_hash), exec(std::move(exec)), mutex(new absl::Mutex) {} - - // Graph instance is fully identified by the hash of its pointer arguments - // because currently it's guaranteed that all shapes and launch dimensions - // will be constant from run to run. - size_t ptr_hash ABSL_GUARDED_BY(*mutex); - se::gpu::OwnedGpuGraphExec exec ABSL_GUARDED_BY(*mutex); - - // Access to a graph instance must be synchronized, because we potentially can - // run concurrent graph instance updates. - std::unique_ptr mutex; -}; - -#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM - -// Define empty struct and empty state when GPU is not enabled. -struct GraphInstance {}; -class StreamExecutorGraphInstances - : public runtime::StateVector {}; - -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// Xla executable keeps a mapping from stream executors to graph instances. -// -// Graph instances allocate on-device memory, so we periodically destroy -// them to free up some space on device. JAX for example keeps all XLA -// executables alive, and destroys them when the process shuts down, so we can -// end up with thousands of unused (or rarely used) graphs in device memory. -class GraphInstances { - public: - struct Impl; - - GraphInstances(std::string module_name, int64_t num_graphs); - ~GraphInstances(); - - std::shared_ptr operator()( - se::StreamExecutor* executor); - - // Instantiates all Gpu graphs defined by the given executable using user - // provided run options. This guarantees that once we start execution, all Gpu - // graphs are ready, and will only require cheap update operation and will not - // require allocating new resources (we avoid non deterministic OOM errors). - // - // If timeout is not nullopt it will evict all previously instantiated graphs - // that were used more than `eviction_timeout_seconds` seconds ago. - Status InstantiateAllGraphs( - const ServiceExecutableRunOptions* run_options, - const runtime::Executable& executable, - const runtime::CustomCall::UserData& user_data, - const BufferAllocations& buffer_allocations, - absl::Span buffer_sizes, - absl::Span> allocation_indices, - std::optional eviction_timeout_seconds = std::nullopt); - - // Returns true if all Gpu graphs were already instantiated. - bool InstantiatedAllGraphs(const ServiceExecutableRunOptions* run_options, - const runtime::Executable& executable); - - private: - std::shared_ptr impl_; -}; - -// Xla executable keeps a mapping from stream executors to execution counts. -class CapturedFunctionExecutionCounts { - public: - CapturedFunctionExecutionCount* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map - counts_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_ diff --git a/xla/service/gpu/infeed_thunk.cc b/xla/service/gpu/runtime/infeed_thunk.cc similarity index 79% rename from xla/service/gpu/infeed_thunk.cc rename to xla/service/gpu/runtime/infeed_thunk.cc index e68577d4aaf87..02dabdf25e41c 100644 --- a/xla/service/gpu/infeed_thunk.cc +++ b/xla/service/gpu/runtime/infeed_thunk.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/infeed_thunk.h" +#include "xla/service/gpu/runtime/infeed_thunk.h" -#include "xla/service/buffer_assignment.h" +#include "absl/status/status.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/infeed_manager.h" #include "xla/shape_util.h" @@ -30,7 +30,7 @@ InfeedThunk::InfeedThunk(ThunkInfo thunk_info, std::vector dest_slices) : Thunk(Kind::kInfeed, thunk_info), dest_slices_(std::move(dest_slices)) {} -Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { +absl::Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { se::Stream& stream = *params.stream; const BufferAllocations& buffer_allocations = *params.buffer_allocations; @@ -45,28 +45,30 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { se::ScopedDeviceMemory& buffer = source.second; const Shape& source_shape = ShapeUtil::GetSubshape(source_buffers.shape(), shape_index); - TF_RET_CHECK(ShapeUtil::Equal(dest_slices_[index].shape, source_shape)) + TF_RET_CHECK( + ShapeUtil::ReshapeIsBitcast(dest_slices_[index].shape, source_shape)) << "Mismatch between infeed source buffer shape " << ShapeUtil::HumanStringWithLayout(source_shape) << " and infeed dest buffer shape " << ShapeUtil::HumanStringWithLayout(dest_slices_[index].shape); se::DeviceMemoryBase dest_address = buffer_allocations.GetDeviceAddress(dest_slices_[index++].slice); - stream.ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size()); + TF_RETURN_IF_ERROR( + stream.Memcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size())); } // Make sure that all dest slices have been copied into. CHECK_EQ(index, dest_slices_.size()) << "Infeed did not populate all destination buffers"; - Status block_status = stream.BlockHostUntilDone(); + absl::Status block_status = stream.BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("Failed to complete data transfer on stream %p: %s", - &stream, block_status.message()); + return Internal("Failed to complete data transfer on stream %p: %s", + &stream, block_status.message()); } VLOG(2) << "Infeeding to GPU complete"; - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/runtime/infeed_thunk.h b/xla/service/gpu/runtime/infeed_thunk.h new file mode 100644 index 0000000000000..7a3db689cd6a3 --- /dev/null +++ b/xla/service/gpu/runtime/infeed_thunk.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/runtime/thunk.h" + +namespace xla { +namespace gpu { + +// A thunk that infeeds data. Data must be already resident on the +// device. This thunk performs an intra-device copy from that location +// to the buffer allocated for the infeed op. +class InfeedThunk : public Thunk { + public: + // Constructs a InfeedThunk that copies data from the on-device + // infeed queue into the buffers in the given shape tree. + InfeedThunk(ThunkInfo thunk_info, std::vector dest_slices); + + InfeedThunk(const InfeedThunk&) = delete; + InfeedThunk& operator=(const InfeedThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + const std::vector dest_slices_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_INFEED_THUNK_H_ diff --git a/xla/service/gpu/runtime/io_feed.cc b/xla/service/gpu/runtime/io_feed.cc deleted file mode 100644 index 799c4d64d6f43..0000000000000 --- a/xla/service/gpu/runtime/io_feed.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/io_feed.h" - -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/infeed_manager.h" -#include "xla/service/gpu/outfeed_manager.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using runtime::CustomCall; - -static absl::Status InfeedImpl(const ServiceExecutableRunOptions* run_options, - CustomCall::RemainingArgs args, - std::string_view config) { - VLOG(3) << "Infeeding to GPU"; - - se::Stream* stream = run_options->stream(); - ShapeTree> source_buffers = - GetOrCreateInfeedManager(stream->parent())->BlockingGetNextDestination(); - - // Check that we have correct number of arguments. - if (args.size() != source_buffers.leaf_count()) - return absl::InvalidArgumentError("Incorrect number of arguments"); - - size_t index = 0; - for (auto& source : source_buffers.leaves()) { - // Get the destination buffer. - auto dest = args.get(index); - if (failed(dest)) - return absl::InternalError("Failed to get the destination buffer"); - - // Get the source buffer shape. - const Shape& source_shape = - ShapeUtil::GetSubshape(source_buffers.shape(), source.first); - - // Check that destination shape matches the source shape. - Shape dest_shape = ToShape(*dest); - if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) { - return absl::InvalidArgumentError( - "The destination shape does not match the source shape"); - } - - se::DeviceMemoryBase dest_address = GetDeviceAddress(*dest); - se::ScopedDeviceMemory& buffer = source.second; - stream->ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size()); - - ++index; - } - - // TODO(ezhulenev): Make this function async? - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - - VLOG(3) << "Infeeding to GPU complete"; - - return absl::OkStatus(); -} - -static absl::Status OutfeedImpl(const ServiceExecutableRunOptions* run_options, - CustomCall::RemainingArgs args, - std::string_view config) { - VLOG(3) << "Outfeeding from GPU"; - - se::Stream* stream = run_options->stream(); - OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream->parent()); - ShapeTree>* dest_buffers = - outfeed_manager->BlockingGetNextDestination(); - - // Nothing to be done for an outfeed with no inputs. - // Note: Must do this after `BlockingGetNextDestination` above to dequeue an - // entry from the outfeed manager. - if (args.empty()) return absl::OkStatus(); - - // Check that we have correct number of arguments. - if (args.size() != dest_buffers->leaf_count()) - return absl::InvalidArgumentError("Incorrect number of arguments"); - - int64_t leaf_count = dest_buffers->leaf_count(); - auto dest_leaf_it = dest_buffers->leaf_begin(); - - for (int64_t index = 0; index < leaf_count; ++index) { - const ShapeIndex& shape_index = dest_leaf_it->first; - std::unique_ptr& buffer = dest_leaf_it->second; - - // NOTE: This code needs deal with the `dest_buffers` object getting - // deleted when it is executing. Specifically, objects in the outfeed queue - // are pointers to instances of stack-allocated objects in - // `GpuTransferManager::TransferLiteralFromOutfeed`. When all leaf node - // buffers are notified via "buffer->Done()" below in the stream host - // callback, `TransferLiteralFromOutfeed` deletes this stack-allocated - // object when it returns. This means that it is possible that during the - // last iteration, after the call to "buffer->Done()" is scheduled onto the - // stream, the `dest_buffers` object might get deleted, so we should avoid - // accessing the object after that. - // - // To achieve that, increment the leaf iterator here before the last "Done" - // is enqueued, instead of in the loop increment, which would be after the - // "Done" is scheduled. - ++dest_leaf_it; - - // Get the source buffer. - auto source = args.get(index); - if (failed(source)) - return absl::InternalError("Failed to get the source buffer"); - - // Get the source buffer shape. - const Shape& dest_shape = - ShapeUtil::GetSubshape(dest_buffers->shape(), shape_index); - - // Check that destination shape matches the source shape. - Shape source_shape = ToShape(*source); - if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) { - return absl::InvalidArgumentError( - "The destination shape does not match the source shape"); - } - - se::DeviceMemoryBase source_address = GetDeviceAddress(*source); - - // Schedule the memory transfer. - auto* dest_address = buffer->destination()->untyped_data(); - stream->ThenMemcpy(dest_address, source_address, buffer->length()) - .ThenDoHostCallback([&buffer]() { buffer->Done(); }); - } - - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - - VLOG(3) << "Outfeeding from GPU complete"; - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Define Xla runtime bindings for the custom calls. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Infeed, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.infeed") - .UserData() - .Arg() // args - .Attr("config")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Outfeed, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.outfeed") - .UserData() - .Arg() // args - .Attr("config")); - -//===----------------------------------------------------------------------===// - -void RegisterIoFeedCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.infeed", Infeed); - registry.Register("xla.gpu.outfeed", Outfeed); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/io_feed.h b/xla/service/gpu/runtime/io_feed.h deleted file mode 100644 index ef4e42d770e49..0000000000000 --- a/xla/service/gpu/runtime/io_feed.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ -#define XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime infeed and outfeed custom calls. -void RegisterIoFeedCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_IO_FEED_H_ diff --git a/xla/service/gpu/runtime/kernel_launch.cc b/xla/service/gpu/runtime/kernel_launch.cc deleted file mode 100644 index 3b324049ed175..0000000000000 --- a/xla/service/gpu/runtime/kernel_launch.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/kernel_launch.h" - -#include -#include -#include -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/state.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/kernel.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_graph.h" -#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; - -StreamExecutorKernels* GpuExecutableKernels::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &kernels_[executor]; -} - -//===----------------------------------------------------------------------===// -// Define the kernel launch custom call. -//===----------------------------------------------------------------------===// - -static absl::Status LaunchImpl( - const ServiceExecutableRunOptions* run_options, const std::string* ptx, - const std::vector* cubin, se::DeviceMemoryBase* temp_buffer, - ConcurrentRegionStatus* region_status, - State> device_kernel, - int32_t shared_memory_bytes, int32_t grid_size_x, int32_t grid_size_y, - int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y, - int32_t block_size_z, CustomCall::RemainingArgs args, std::string_view name, - int64_t stream_id) { - se::Stream* stream = run_options->stream(); - se::StreamExecutor* executor = stream->parent(); - - LaunchDimensions launch_dimensions( - {grid_size_x, grid_size_y, grid_size_z}, - {block_size_x, block_size_y, block_size_z}); - - const int args_size_including_temp_buffer = args.size() + 1; - - // If kernel does not exist create it from the ptx and cubin. - TF_ASSIGN_OR_RETURN( - std::unique_ptr * kernel, device_kernel.GetOrCreate([&] { - return ToAbsl(CreateKernel(absl::string_view(name.data(), name.size()), - args_size_including_temp_buffer, *ptx, - *cubin, executor, shared_memory_bytes)); - })); - assert((*kernel)->name() == name && "unexpected loaded kernel"); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (VLOG_IS_ON(3)) { - TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream)); - if (is_capturing) { - if (region_status->IsInConcurrentRegion()) { - LOG(INFO) << "Launching " << (*kernel)->name() - << "in a concurrent region during GPU graph capture"; - } else { - LOG(INFO) << "Launching " << (*kernel)->name() - << "during GPU graph capture"; - } - } else { - LOG(INFO) << "Launching " << (*kernel)->name(); - } - } -#else - VLOG(3) << "Launching " << (*kernel)->name(); -#endif - - absl::InlinedVector buffer_args( - args_size_including_temp_buffer); - - // Add MemRef arguments as buffer arguments. - for (unsigned i = 0; i < args.size(); ++i) { - // We get arguments corresponding to XLA allocations required by the - // compiled device kernel, and not the actual memrefs that device kernel - // writes/reads, so we don't have to pass the size along with the pointer. - if (auto strided = args.get(i); succeeded(strided)) { - buffer_args[i] = se::DeviceMemoryBase(strided->data); - continue; - } - - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported argument #%d type", i)); - } - - // Always add temporary buffer as the last kernel argument. - buffer_args.back() = *temp_buffer; - - // If we are capturing a concurrent region in a GPU graph, then use the - // stream provided by ConcurrentRegionStatus to execute the kernel. - se::Stream* execution_stream = stream; - if (stream_id != 0) { - DCHECK(region_status->IsInConcurrentRegion()); - TF_ASSIGN_OR_RETURN(execution_stream, region_status->GetStream(stream_id)); - } else if (region_status->IsInConcurrentRegion()) { - execution_stream = region_status->GetNextStream(); - } - - // Execute device kernel on the execution stream. - return ExecuteKernelOnStream(**kernel, buffer_args, launch_dimensions, - execution_stream); -} - -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Launch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.func.launch") - .UserData() - .UserData() - .UserData*>() - .UserData() - .UserData() - .State>("uid") - .Arg() // shared_memory_bytes - .Arg() // grid_size_x - .Arg() // grid_size_y - .Arg() // grid_size_z - .Arg() // block_size_x - .Arg() // block_size_y - .Arg() // block_size_x - .RemainingArgs() // args - .Attr("kernel") - .Attr("stream")); - -void RegisterKernelLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.func.launch", Launch); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/kernel_launch.h b/xla/service/gpu/runtime/kernel_launch.h deleted file mode 100644 index 4a7f294f67f70..0000000000000 --- a/xla/service/gpu/runtime/kernel_launch.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ -#define XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ - -#include -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/state.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime kernel launch custom calls. -void RegisterKernelLaunchCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -// Kernels loaded by Gpu executable for a single stream executor. -class StreamExecutorKernels - : public runtime::StateVector> {}; - -// Xla runtime Gpu executable owns the pre-compiled device module (PTX and -// Cubin for Nvidia Gpus) for all device kernels, and the cache keeps a mapping -// from stream executor to pre-loaded kernels -class GpuExecutableKernels { - public: - StreamExecutorKernels* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map kernels_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_KERNEL_LAUNCH_H_ diff --git a/xla/service/gpu/runtime/kernel_thunk.cc b/xla/service/gpu/runtime/kernel_thunk.cc new file mode 100644 index 0000000000000..063940d8ab7cd --- /dev/null +++ b/xla/service/gpu/runtime/kernel_thunk.cc @@ -0,0 +1,237 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/kernel_thunk.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +//===----------------------------------------------------------------------===// +// KernelThunk +//===----------------------------------------------------------------------===// + +KernelThunk::KernelThunk(const HloInstruction* instr, std::string kernel_name, + absl::Span kernel_arguments, + LaunchDimensions launch_dimensions, + std::optional cluster_dim, + int64_t shmem_bytes) + : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), + kernel_name_(std::move(kernel_name)), + launch_dimensions_(std::move(launch_dimensions)), + cluster_dim_(std::move(cluster_dim)), + shmem_bytes_(shmem_bytes) { + args_.reserve(kernel_arguments.size()); + written_.reserve(kernel_arguments.size()); + for (const auto& kernel_argument : kernel_arguments) { + if (!kernel_argument.first_with_same_slice().has_value()) { + args_.push_back(kernel_argument.slice()); + written_.push_back(kernel_argument.written()); + } + } +} + +std::string KernelThunk::ToStringExtra(int indent) const { + return absl::StrFormat( + ", kernel = %s, launch dimensions = %s, cluster_dim = %s", kernel_name_, + launch_dimensions_.ToString(), + cluster_dim_.has_value() ? cluster_dim_->ToString() : "nullopt"); +} + +absl::Status KernelThunk::Initialize(const InitializeParams& params) { + absl::MutexLock lock(&mutex_); + + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(params.executor); + if (kernel_cache_.end() == it) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + CreateKernel(kernel_name_, args_.size(), params.src.text, + params.src.binary, params.executor, shmem_bytes_)); + + kernel_cache_.emplace(params.executor, std::move(kernel)); + } + + return absl::OkStatus(); +} + +static void PrintBufferContents( + se::Stream* stream, absl::Span buffer_args) { + int input_idx = 0; + for (const se::DeviceMemoryBase& buf : buffer_args) { + auto host_buffer = std::make_unique(buf.size()); + CHECK(stream->Memcpy(host_buffer.get(), buf, buf.size()).ok()); + CHECK_OK(stream->BlockHostUntilDone()); + + std::string buffer_contents; + for (int i = 0; i < buf.size(); i++) { + absl::StrAppendFormat(&buffer_contents, "%x ", + static_cast(host_buffer[i])); + } + VLOG(100) << "BUF(" << input_idx++ << ") = " << buffer_contents; + } +} + +absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { + // Load the kernel. + se::StreamExecutor* executor = params.stream->parent(); + LaunchDimensions launch_dimensions; + std::optional cluster_dim; + const se::Kernel* kernel = nullptr; + + TF_ASSIGN_OR_RETURN( + se::Stream * stream, + GetStreamForExecution(Thunk::execution_stream_id(), params)); + + { + absl::MutexLock lock(&mutex_); + auto it = kernel_cache_.find(executor); + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; + launch_dimensions = launch_dimensions_; + cluster_dim = cluster_dim_; + kernel = it->second.get(); + } + + VLOG(3) << "Launching " << kernel->name(); + absl::InlinedVector buffer_args; + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); + VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() + << ": " << buf.opaque() << " (" << buf.size() << "B)"; + buffer_args.push_back(buf); + } + + if (VLOG_IS_ON(100)) { + PrintBufferContents(stream, buffer_args); + } + + if (cluster_dim.has_value()) { + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + cluster_dim.value(), stream); + } else { + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + stream); + } +} + +//===----------------------------------------------------------------------===// +// CustomKernelThunk +//===----------------------------------------------------------------------===// + +CustomKernelThunk::CustomKernelThunk( + const HloInstruction* instr, CustomKernel custom_kernel, + absl::Span kernel_arguments) + : Thunk(Kind::kCustomKernel, + Thunk::ThunkInfo::WithProfileAnnotation(instr)), + custom_kernel_(std::move(custom_kernel)) { + args_.reserve(kernel_arguments.size()); + written_.reserve(kernel_arguments.size()); + for (const auto& kernel_argument : kernel_arguments) { + if (!kernel_argument.first_with_same_slice().has_value()) { + args_.push_back(kernel_argument.slice()); + written_.push_back(kernel_argument.written()); + } + } +} + +std::string CustomKernelThunk::ToStringExtra(int indent) const { + return custom_kernel_.ToString(); +} + +absl::Status CustomKernelThunk::Initialize(const InitializeParams& params) { + absl::MutexLock lock(&mutex_); + + auto it = kernel_cache_.find(params.executor); + if (kernel_cache_.end() == it) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + se::Kernel::Create(params.executor, custom_kernel_.kernel_spec())); + kernel_cache_.emplace(params.executor, std::move(kernel)); + } + + return absl::OkStatus(); +} + +absl::Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { + se::StreamExecutor* executor = params.stream->parent(); + + const se::Kernel* kernel = [&] { + absl::MutexLock lock(&mutex_); + return kernel_cache_[executor].get(); + }(); + + VLOG(3) << "Launching " << custom_kernel_.ToString() << " as device kernel " + << kernel->name(); + + absl::InlinedVector buffer_args; + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); + VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() + << ": " << buf.opaque() << " (" << buf.size() << "B)"; + buffer_args.push_back(buf); + } + + if (VLOG_IS_ON(100)) { + PrintBufferContents(params.stream, buffer_args); + } + + se::KernelArgsDeviceMemoryArray args(buffer_args, + custom_kernel_.shared_memory_bytes()); + + if (auto cluster = custom_kernel_.cluster_dims(); cluster.has_value()) { + return executor->Launch(params.stream, custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *cluster, *kernel, + args); + } else { + return executor->Launch(params.stream, custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, args); + } +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/kernel_thunk.h b/xla/service/gpu/runtime/kernel_thunk.h new file mode 100644 index 0000000000000..1b50ce6be2869 --- /dev/null +++ b/xla/service/gpu/runtime/kernel_thunk.h @@ -0,0 +1,175 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" // IWYU pragma: keep + +namespace xla { +namespace gpu { + +class GpuExecutable; + +// TODO(ezhulenev): Unify KernelThunk and CustomKernelThunk as they are very +// similar. XLA:GPU should use more of kernel loading APIs provided by +// StreamExecutor out of the box and less custom kernel loading solutions. +// +// Today KernelThunk is required for lowering to XLA runtime, and +// CustomKernelThunk is only supported for thunk execution. + +//===----------------------------------------------------------------------===// +// KernelThunk +//===----------------------------------------------------------------------===// + +// This class stores everything that StreamExecutor needs for launching a +// kernel. It implements the ExecuteOnStream interface for GpuExecutable to +// invoke the corresponding kernel. +// +// This is thread-compatible. +class KernelThunk : public Thunk { + public: + // Constructs a thunk for the given kernel. + // + // KernelThunk takes args as `BufferAllocation::Slice`s (wrapped in + // `KernelArgument`s). Each slice directly corresponds to an argument or + // output of the computation. Also, the values must correspond to each arg + // directly, not to their base allocation (e.g. they can be the result of an + // `mlir::memref::ViewOp`). + KernelThunk(const HloInstruction* instr, std::string kernel_name, + absl::Span kernel_arguments, + LaunchDimensions launch_dimensions, + std::optional cluster_dim, int64_t shmem_bytes); + KernelThunk(const KernelThunk&) = delete; + KernelThunk& operator=(const KernelThunk&) = delete; + ~KernelThunk() override = default; + + std::string ToStringExtra(int indent) const override; + + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } + + const std::vector& arguments() const { + return args_; + } + const std::vector& written() const { return written_; } + + const std::string& kernel_name() const { return kernel_name_; } + const LaunchDimensions& launch_dimensions() const { + return launch_dimensions_; + } + // The shared memory required by the kernel. + int64_t shmem_bytes() const { return shmem_bytes_; } + + private: + // Buffer slices passed to the kernel as arguments. + std::vector args_; + + // args_[i] is written iff (written_[i] == true). + std::vector written_; + + // Entry kernel name for the computation. + const std::string kernel_name_; + + // The thread and block dimension used to launch the kernel. + const LaunchDimensions launch_dimensions_; + + // The cluster dimensions used to launch the kernel. + const std::optional cluster_dim_; + + int64_t shmem_bytes_; + + // Loaded kernels for each `StreamExecutor`. + mutable absl::Mutex mutex_; + absl::flat_hash_map> + kernel_cache_ ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// CustomKernelThunk +//===----------------------------------------------------------------------===// + +// CustomKernelThunk loads and executes kernels defined by a custom kernel +// (which in practice means hand written CUDA C++ kernel), instead of a kernel +// compiled by XLA and loaded from an executable source. +class CustomKernelThunk : public Thunk { + public: + CustomKernelThunk(const HloInstruction* inst, CustomKernel custom_kernel, + absl::Span kernel_arguments); + + std::string ToStringExtra(int indent) const override; + + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + const CustomKernel& custom_kernel() const { return custom_kernel_; } + + const std::vector& arguments() const { + return args_; + } + + std::string_view custom_kernel_name() const { return custom_kernel_.name(); } + + const std::vector& written() const { return written_; } + + LaunchDimensions launch_dimensions() const { + return LaunchDimensions(custom_kernel_.block_dims(), + custom_kernel_.thread_dims()); + } + + int64_t shmem_bytes() const { return custom_kernel_.shared_memory_bytes(); } + + private: + // Buffer slices passed to the kernel as arguments. + std::vector args_; + + // args_[i] is written iff (written_[i] == true). + std::vector written_; + + CustomKernel custom_kernel_; + + // Loaded kernels for each `StreamExecutor`. + mutable absl::Mutex mutex_; + absl::flat_hash_map> + kernel_cache_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_KERNEL_THUNK_H_ diff --git a/xla/service/gpu/runtime/memcpy.cc b/xla/service/gpu/runtime/memcpy.cc deleted file mode 100644 index 45955c500e60f..0000000000000 --- a/xla/service/gpu/runtime/memcpy.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/memcpy.h" - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::StridedMemrefView; - -enum class MemcpyDirection { kD2D, kD2H, kH2D }; - -template -absl::Status MemcpyImpl(const ServiceExecutableRunOptions* run_options, - ConcurrentRegionStatus* region_status, - runtime::StridedMemrefView dst, - runtime::StridedMemrefView src, int64_t stream_id) { - se::Stream* stream = run_options->stream(); - if (stream_id != 0) { - DCHECK(region_status->IsInConcurrentRegion()); - TF_ASSIGN_OR_RETURN(stream, region_status->GetStream(stream_id)); - } else if (region_status->IsInConcurrentRegion()) { - stream = region_status->GetNextStream(); - } - - if (dst.sizes != src.sizes) { - return absl::InvalidArgumentError( - "Source memref sizes do not match destination memref sizes"); - } - - if (dst.strides != src.strides) { - return absl::InvalidArgumentError( - "Source memref strides do not match destination memref strides"); - } - - switch (direction) { - case MemcpyDirection::kD2D: { - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - se::DeviceMemoryBase src_data = GetDeviceAddress(src); - stream->ThenMemcpy(&dst_data, src_data, src_data.size()); - } break; - case MemcpyDirection::kD2H: { - se::DeviceMemoryBase src_data = GetDeviceAddress(src); - stream->ThenMemcpy(dst.data, src_data, src_data.size()); - } break; - case MemcpyDirection::kH2D: { - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - stream->ThenMemcpy(&dst_data, src.data, dst_data.size()); - } break; - } - - // TODO(jacksonstokes): H2D and D2H memcpy instead of blocking the execution - // thread should return an async token that will become available when - // transfer is completed. - if (direction != MemcpyDirection::kD2D) { - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - } - - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL_TEMPLATE( - MemcpyDirection direction, Memcpy, FunctionWrapper>(), - checks, - CustomCall::Bind("xla.gpu.memcpy") - .UserData() - .UserData() - .Arg() // dst - .Arg() // src - .Attr("stream")); - -void RegisterMemcpyCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.memcpy.d2d", Memcpy); - registry.Register("xla.gpu.memcpy.h2d", Memcpy); - registry.Register("xla.gpu.memcpy.d2h", Memcpy); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/memcpy.h b/xla/service/gpu/runtime/memcpy.h deleted file mode 100644 index cd4ecb4d5020a..0000000000000 --- a/xla/service/gpu/runtime/memcpy.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ -#define XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime memcpy custom calls. -void RegisterMemcpyCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_MEMCPY_H_ diff --git a/xla/service/gpu/runtime/memset.cc b/xla/service/gpu/runtime/memset.cc deleted file mode 100644 index 0e31161604a41..0000000000000 --- a/xla/service/gpu/runtime/memset.cc +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/memset.h" - -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; -using xla::runtime::StridedMemrefView; - -// Checks all supported data types to see if the value is zero. -static bool IsZero(CustomCall::VariantArg constant) { - if (auto i1 = constant.get(); succeeded(i1)) - return *i1 == false; - else if (auto i8 = constant.get(); succeeded(i8)) - return *i8 == 0; - else if (auto i16 = constant.get(); succeeded(i16)) - return *i16 == 0; - else if (auto i32 = constant.get(); succeeded(i32)) - return *i32 == 0; - else if (auto i64 = constant.get(); succeeded(i64)) - return *i64 == 0; - else if (auto bf16 = constant.get(); succeeded(bf16)) - return *bf16 == bfloat16(0.0); - else if (auto f16 = constant.get(); succeeded(f16)) - return *f16 == half(0.0); - else if (auto f32 = constant.get(); succeeded(f32)) - return *f32 == 0.0; - else if (auto f64 = constant.get(); succeeded(f64)) - return *f64 == 0.0; - - return false; -} - -// Convert constant value to 32-bit pattern. -static absl::StatusOr ToBitPattern(CustomCall::VariantArg constant) { - // If the value is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the value 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. - // - // This code is identical to `ir_emitter_unnested`. - // - // We use `memcpy` operation to copy bytes between value and the uint32_t bit - // pattern because in theory they might have incompatible alignment, and we - // rely on LLVM to optimize it. - auto extend = [](auto value) -> uint32_t { - static constexpr size_t num_bytes = sizeof(value); - static_assert(num_bytes < 4); - - uint16_t pattern16; - if constexpr (num_bytes == 1) { - uint8_t b = value; - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, &value, sizeof(pattern16)); - } - return uint32_t{pattern16} | (uint32_t{pattern16} << 16); - }; - - // Truncate value to 32-bit pattern. - auto truncate = [](auto value) -> uint32_t { - static_assert(sizeof(value) >= 4); - - uint32_t pattern; - memcpy(&pattern, &value, sizeof(pattern)); - return pattern; - }; - - if (auto i1 = constant.get(); succeeded(i1)) - return extend(*i1); - else if (auto i8 = constant.get(); succeeded(i8)) - return extend(*i8); - else if (auto i16 = constant.get(); succeeded(i16)) - return extend(*i16); - else if (auto i32 = constant.get(); succeeded(i32)) - return truncate(*i32); - else if (auto i64 = constant.get(); succeeded(i64)) - return truncate(*i64); - else if (auto bf16 = constant.get(); succeeded(bf16)) - return extend(static_cast(*bf16)); - else if (auto f16 = constant.get(); succeeded(f16)) - return extend(static_cast(*f16)); - else if (auto f32 = constant.get(); succeeded(f32)) - return truncate(*f32); - else if (auto f64 = constant.get(); succeeded(f64)) - return truncate(*f64); - - return absl::InvalidArgumentError("Unsupported memset constant type"); -} - -static absl::Status MemsetImpl(const ServiceExecutableRunOptions* run_options, - StridedMemrefView dst, - CustomCall::VariantArg constant) { - se::Stream* stream = run_options->stream(); - se::DeviceMemoryBase dst_data = GetDeviceAddress(dst); - - // If the constant is zero we can use memzero directly. - if (IsZero(constant)) { - stream->ThenMemZero(&dst_data, dst_data.size()); - return absl::OkStatus(); - } - - // If the constant is not zero, use the given pattern to `memset`. - TF_ASSIGN_OR_RETURN(uint32_t pattern, ToBitPattern(constant)); - - if (dst_data.size() % 4 != 0) { - return absl::InvalidArgumentError("Memref size is not divisible by 4"); - } - - stream->ThenMemset32(&dst_data, pattern, dst_data.size()); - - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Memset, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.memset") - .UserData() - .Arg() // dst - .Arg() // constant -); - -void RegisterMemsetCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.memset", Memset); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/memset.h b/xla/service/gpu/runtime/memset.h deleted file mode 100644 index e6d717941865f..0000000000000 --- a/xla/service/gpu/runtime/memset.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ -#define XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime memset custom calls. -void RegisterMemsetCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_MEMSET_H_ diff --git a/xla/service/gpu/runtime/memset_thunk.cc b/xla/service/gpu/runtime/memset_thunk.cc new file mode 100644 index 0000000000000..573db38dd877e --- /dev/null +++ b/xla/service/gpu/runtime/memset_thunk.cc @@ -0,0 +1,38 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/memset_thunk.h" + +#include "absl/status/status.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +absl::Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { + se::DeviceMemoryBase dest_data = + params.buffer_allocations->GetDeviceAddress(dest_); + return params.stream->MemZero(&dest_data, dest_data.size()); +} + +absl::Status Memset32BitValueThunk::ExecuteOnStream( + const ExecuteParams& params) { + se::DeviceMemoryBase dest_data = + params.buffer_allocations->GetDeviceAddress(dest_); + return params.stream->Memset32(&dest_data, value_, dest_data.size()); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/memset_thunk.h b/xla/service/gpu/runtime/memset_thunk.h new file mode 100644 index 0000000000000..e1eef4c39c5d3 --- /dev/null +++ b/xla/service/gpu/runtime/memset_thunk.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ + +#include + +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" + +// This file contains thunks that set a buffer's elements to a particular value. +// This can be faster than emitting a kernel to set the elements. + +namespace xla { +namespace gpu { + +// Thunk that zeroes out a given chunk of memory. +class MemzeroThunk : public Thunk { + public: + explicit MemzeroThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& dest) + : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {} + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } + + const BufferAllocation::Slice& destination() const { return dest_; } + + private: + const BufferAllocation::Slice dest_; +}; + +// Thunk that sets a given chunk of memory to a particular 32-bit value. The +// destination chunk must have size divisible by 32 bits. +class Memset32BitValueThunk : public Thunk { + public: + explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32_t value, + const BufferAllocation::Slice& dest) + : Thunk(Kind::kMemset32BitValue, thunk_info), + value_(value), + dest_(dest) {} + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + void ClearCompileTimeInfo() override { Thunk::ClearCompileTimeInfo(); } + + const BufferAllocation::Slice& destination() const { return dest_; } + uint32_t value() const { return value_; } + + private: + const uint32_t value_; + const BufferAllocation::Slice dest_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_MEMSET_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_all_gather_thunk.cc b/xla/service/gpu/runtime/nccl_all_gather_thunk.cc new file mode 100644 index 0000000000000..28a39e4a89dc1 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_gather_thunk.cc @@ -0,0 +1,121 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace impl { +NcclAllGatherConfig GetNcclAllGatherConfig( + const HloAllGatherInstruction* inst) { + NcclAllGatherConfig config; + config.config = GetNcclCollectiveConfig(inst, inst->use_global_device_ids()); + return config; +} + +absl::Status CheckImplementableInst(const HloAllGatherInstruction* inst) { + for (HloInstruction* operand : inst->operands()) { + const Shape& shape = operand->shape(); + + TF_RETURN_IF_ERROR(IsValidOperand(shape, Thunk::kNcclAllGather)); + + if (!ShapeUtil::IsEffectivelyMostMajorDimension( + shape, inst->all_gather_dimension())) { + return absl::AbortedError(absl::StrFormat( + "all-gather dim %u is not the most major in input shape %s", + inst->all_gather_dimension(), shape.ToString(/*print_layout=*/true))); + } + } + + return absl::OkStatus(); +} +} // namespace impl + +NcclAllGatherStartThunk::NcclAllGatherStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllGatherInstruction* inst, std::vector buffers) + : NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, nccl_api, + IsSyncCollective(inst)), + config_(impl::GetNcclAllGatherConfig(inst)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.config.operand_count, buffers_.size()); +} + +/*static*/ absl::Status NcclAllGatherStartThunk::CheckImplementable( + const HloAllGatherInstruction* inst, int64_t replica_count, + int64_t partition_count) { + return AddOpDescription( + impl::CheckImplementableInst(inst), inst, replica_count, partition_count); +} + +/*static*/ CollectiveOpGroupMode NcclAllGatherStartThunk::GetGroupMode( + const HloAllGatherInstruction* inst) { + return impl::GetNcclAllGatherConfig(inst).config.group_mode; +} + +absl::Status NcclAllGatherStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, + config_.config.operand_element_type)); + return xla::gpu::RunAllGather(nccl_api(), device_buffers, stream, comm); +} + +absl::Status RunAllGather(NcclApi* nccl_api, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + + for (DeviceBufferPair& buffer : buffers) { + TF_RETURN_IF_ERROR(nccl_api->AllGather( + buffer.source_buffer, buffer.destination_buffer, buffer.element_type, + buffer.element_count, comm, &stream)); + } + + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + + VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal; + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_all_gather_thunk.h b/xla/service/gpu/runtime/nccl_all_gather_thunk.h new file mode 100644 index 0000000000000..2683cca2e9782 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_gather_thunk.h @@ -0,0 +1,73 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +struct NcclAllGatherConfig { + NcclCollectiveConfig config; +}; + +// Thunk that performs a NCCL-based All-Gather among CUDA GPU-based replicas. +class NcclAllGatherStartThunk : public NcclCollectiveThunk { + public: + NcclAllGatherStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllGatherInstruction* inst, + std::vector buffers); + + static const char* GetHloOpName() { return "all-gather-start"; } + + static absl::Status CheckImplementable(const HloAllGatherInstruction* inst, + int64_t replica_count, + int64_t partition_count); + + static CollectiveOpGroupMode GetGroupMode( + const HloAllGatherInstruction* inst); + + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Span buffers() const { return buffers_; } + + protected: + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + + private: + const NcclAllGatherConfig config_; + const std::vector buffers_; +}; + +absl::Status RunAllGather(NcclApi* nccl_api, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_GATHER_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc b/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc new file mode 100644 index 0000000000000..7a3d9cffbf1f2 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc @@ -0,0 +1,254 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::Status RunAllReduce(NcclApi* nccl_api, ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (DeviceBufferPair& buffer : buffers) { + TF_RETURN_IF_ERROR(nccl_api->AllReduce( + buffer.source_buffer, buffer.destination_buffer, buffer.element_type, + buffer.element_count, reduction_kind, comm, &stream)); + } + + return nccl_api->GroupEnd(); +} + +namespace { + +// Generally, the reduction op should be the only operation in the block, except +// the terminator. However, if the type is bf16, the `FloatNormalization` +// pass will have converted the op to float32 and added type conversions. +// TODO(cjfj): Can we prevent the bf16 conversion for this computation? +absl::StatusOr FindReductionOp(mlir::Block& block) { + TF_RET_CHECK(block.getNumArguments() == 2); + mlir::Operation* terminator = block.getTerminator(); + TF_RET_CHECK(terminator); + TF_RET_CHECK(terminator->getNumOperands() == 1); + mlir::Value result = terminator->getOperand(0); + TF_RET_CHECK(block.getArgument(0).getType() == result.getType()); + TF_RET_CHECK(block.getArgument(1).getType() == result.getType()); + + mlir::Operation* result_op = result.getDefiningOp(); + TF_RET_CHECK(result_op); + + // In the bf16 case, the type conversions and op might be fused. + if (mlir::isa(result_op)) { + return FindReductionOp(result_op->getRegion(0).front()); + } + + // Standard case. + if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) { + return result_op; + } + + // bf16 case. + TF_RET_CHECK(mlir::isa(result_op)); + TF_RET_CHECK(result_op->getNumOperands() == 1); + mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp(); + TF_RET_CHECK(reduction_op); + TF_RET_CHECK(reduction_op->getNumOperands() == 2); + mlir::Value operand0 = reduction_op->getOperand(0); + mlir::Value operand1 = reduction_op->getOperand(1); + auto operand0_op = operand0.getDefiningOp(); + auto operand1_op = operand1.getDefiningOp(); + TF_RET_CHECK(operand0_op); + TF_RET_CHECK(operand1_op); + TF_RET_CHECK(operand0_op->getNumOperands() == 1); + TF_RET_CHECK(operand1_op->getNumOperands() == 1); + std::array operands{operand0_op->getOperand(0), + operand1_op->getOperand(0)}; + TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments())); + return reduction_op; +} + +} // namespace + +namespace impl { + +absl::Status CheckImplementableInst(const HloInstruction* inst, + Thunk::Kind reduction_op) { + for (HloInstruction* operand : inst->operands()) { + TF_RETURN_IF_ERROR(IsValidOperand(operand->shape(), reduction_op)); + } + + if (!MatchReductionComputation(inst->called_computations().front()) + .has_value()) { + return absl::UnimplementedError("Unrecognized reduction computation"); + } + + return absl::OkStatus(); +} + +template +NcclAllReduceConfig GetNcclAllReduceConfigInst(HloInstType* inst) { + std::optional reduction_kind = + MatchReductionComputation(inst->called_computations().front()); + CHECK(reduction_kind.has_value()); + + NcclAllReduceConfig config; + config.config = GetNcclCollectiveConfig(inst, inst->use_global_device_ids()); + config.reduction_kind = *reduction_kind; + return config; +} + +template +CollectiveOpGroupMode GetGroupModeInst(HloInstType* inst) { + return GetNcclAllReduceConfigInst(inst).config.group_mode; +} + +} // namespace impl + +NcclAllReduceReduceScatterThunkBase::NcclAllReduceReduceScatterThunkBase( + Thunk::Kind kind, ThunkInfo thunk_info, NcclApi* nccl_api, + NcclAllReduceConfig config, std::vector buffers, bool is_sync) + : NcclCollectiveThunk(kind, thunk_info, nccl_api, is_sync), + config_(std::move(config)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.config.operand_count, buffers_.size()); +} + +NcclAllReduceStartThunk::NcclAllReduceStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllReduceInstruction* inst, std::vector buffers) + : NcclAllReduceReduceScatterThunkBase( + Thunk::kNcclAllReduceStart, thunk_info, nccl_api, + impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), + IsSyncCollective(inst)) {} + +absl::Status NcclAllReduceStartThunk::CheckImplementable( + const HloAllReduceInstruction* inst, int64_t replica_count, + int64_t partition_count) { + return AddOpDescription( + impl::CheckImplementableInst(inst, Thunk::kNcclAllReduceStart), inst, + replica_count, partition_count); +} + +CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode( + const HloAllReduceInstruction* inst) { + return impl::GetGroupModeInst(inst); +} + +absl::Status NcclAllReduceStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, + config_.config.operand_element_type)); + return ::xla::gpu::RunAllReduce(nccl_api(), config_.reduction_kind, + device_buffers, stream, comm); +} + +NcclReduceScatterStartThunk::NcclReduceScatterStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloReduceScatterInstruction* inst, std::vector buffers) + : NcclAllReduceReduceScatterThunkBase( + Thunk::kNcclReduceScatterStart, thunk_info, nccl_api, + impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), + inst->backend_config() + ->collective_backend_config() + .is_sync()) {} + +/*static*/ absl::Status NcclReduceScatterStartThunk::CheckImplementable( + const HloReduceScatterInstruction* inst, int64_t replica_count, + int64_t partition_count) { + return AddOpDescription( + impl::CheckImplementableInst(inst, Thunk::kNcclReduceScatterStart), inst, + replica_count, partition_count); +} + +/*static*/ CollectiveOpGroupMode NcclReduceScatterStartThunk::GetGroupMode( + const HloReduceScatterInstruction* inst) { + return impl::GetGroupModeInst(inst); +} + +absl::Status NcclReduceScatterStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, + config_.config.operand_element_type)); + return ::xla::gpu::RunReduceScatter(nccl_api(), config_.reduction_kind, + device_buffers, stream, comm); +} + +absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, + NcclApi::NcclCommHandle comm) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing reduce-scatter from device ordinal: " + << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + + TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + + for (DeviceBufferPair& buffer : buffers) { + // buffer.element_count is the source buffers element count. For + // ncclReduceScatter, we need the destination buffers element count. + TF_RET_CHECK(buffer.element_count % num_participants == 0) + << "Source buffer was not an exact multiple of the number of " + "participants."; + + TF_RETURN_IF_ERROR(nccl_api->ReduceScatter( + buffer.source_buffer, buffer.destination_buffer, buffer.element_type, + buffer.element_count / num_participants, reduction_kind, comm, + &stream)); + } + + return nccl_api->GroupEnd(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_all_reduce_thunk.h b/xla/service/gpu/runtime/nccl_all_reduce_thunk.h new file mode 100644 index 0000000000000..7c21b0bea2b86 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_reduce_thunk.h @@ -0,0 +1,120 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +struct NcclAllReduceConfig { + NcclCollectiveConfig config; + ReductionKind reduction_kind; +}; + +// Thunk that performs a NCCL-based All-Reduce or Reduce-Scatter among CUDA +// GPU-based replicas. +class NcclAllReduceReduceScatterThunkBase : public NcclCollectiveThunk { + public: + NcclAllReduceReduceScatterThunkBase(Kind kind, ThunkInfo thunk_info, + NcclApi* nccl_api, + NcclAllReduceConfig config, + std::vector buffers, + bool is_sync); + + const NcclCollectiveConfig& config() const override { return config_.config; } + ReductionKind reduction_kind() const { return config_.reduction_kind; } + + absl::Span buffers() const { return buffers_; } + + protected: + const NcclAllReduceConfig config_; + const std::vector buffers_; +}; + +// ----------------------------------------------------------------------------- +// AllReduce thunk. +// ----------------------------------------------------------------------------- + +class NcclAllReduceStartThunk : public NcclAllReduceReduceScatterThunkBase { + public: + NcclAllReduceStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllReduceInstruction* inst, + std::vector buffers); + + static const char* GetHloOpName() { return "all-reduce-start"; } + + static absl::Status CheckImplementable(const HloAllReduceInstruction* inst, + int64_t replica_count, + int64_t partition_count); + + static CollectiveOpGroupMode GetGroupMode( + const HloAllReduceInstruction* inst); + + protected: + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; +}; + +// ----------------------------------------------------------------------------- +// ReduceScatter thunk +// ----------------------------------------------------------------------------- +class NcclReduceScatterStartThunk : public NcclAllReduceReduceScatterThunkBase { + public: + NcclReduceScatterStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloReduceScatterInstruction* inst, + std::vector buffers); + + static const char* GetHloOpName() { return "reduce-scatter-start"; } + + static absl::Status CheckImplementable( + const HloReduceScatterInstruction* inst, int64_t replica_count, + int64_t partition_count); + + static CollectiveOpGroupMode GetGroupMode( + const HloReduceScatterInstruction* inst); + + protected: + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; +}; + +// ----------------------------------------------------------------------------- + +absl::Status RunAllReduce(NcclApi* nccl_api, ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm); + +absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_REDUCE_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc new file mode 100644 index 0000000000000..b0cf0294cfc67 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -0,0 +1,166 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +NcclAllToAllConfig GetNcclAllToAllConfig(const HloAllToAllInstruction* instr) { + NcclAllToAllConfig config; + // FIXME(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids + // attribute and it should be removed. + config.config = GetNcclCollectiveConfig(instr, std::nullopt); + config.has_split_dimension = instr->split_dimension().has_value(); + return config; +} + +} // namespace + +NcclAllToAllStartThunk::NcclAllToAllStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllToAllInstruction* instr, + std::vector buffers) + : NcclCollectiveThunk(Thunk::kNcclAllToAllStart, thunk_info, nccl_api, + IsSyncCollective(instr)), + config_(GetNcclAllToAllConfig(instr)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.config.operand_count, buffers_.size()); +} + +/*static*/ absl::Status NcclAllToAllStartThunk::CheckImplementable( + const HloAllToAllInstruction* instr, int64_t replica_count, + int64_t partition_count) { + auto status = [&instr]() -> absl::Status { + std::optional split_dim = instr->split_dimension(); + for (HloInstruction* operand : instr->operands()) { + Shape shape = operand->shape(); + TF_RETURN_IF_ERROR(IsValidOperand(shape, Thunk::kNcclAllToAll)); + if (split_dim && + !ShapeUtil::IsEffectivelyMostMajorDimension(shape, *split_dim)) { + return absl::UnimplementedError(absl::Substitute( + "all-to-all split dim $0 is not the most major in input shape $1", + *split_dim, shape.ToString(/*print_layout=*/true))); + } + } + return absl::OkStatus(); + }; + return AddOpDescription( + status(), instr, replica_count, partition_count); +} + +/*static*/ CollectiveOpGroupMode NcclAllToAllStartThunk::GetGroupMode( + const HloAllToAllInstruction* instr) { + return GetNcclAllToAllConfig(instr).config.group_mode; +} + +absl::Status NcclAllToAllStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, + config_.config.operand_element_type)); + return xla::gpu::RunAllToAll(nccl_api(), config_.has_split_dimension, + device_buffers, stream, comm); +} + +absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + + TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + + // AllToAll can operate in two modes. Either it specifies a split dimension, + // in which case inputs are split and outputs concatenated in that dimension + // (here, we only support dimension 0), or it takes a list of inputs + // and produces a tuple of outputs. + if (has_split_dimension) { + for (DeviceBufferPair& buffer : buffers) { + TF_RET_CHECK(buffer.element_count % num_participants == 0) + << "Buffer was not an exact multiple of the number of participants."; + + size_t chunk_elements = buffer.element_count / num_participants; + + for (int peer = 0; peer < num_participants; ++peer) { + se::DeviceMemoryBase send_slice = + NcclApi::Slice(buffer.source_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + + se::DeviceMemoryBase recv_slice = + NcclApi::Slice(buffer.destination_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + + TF_RETURN_IF_ERROR(nccl_api->Send(send_slice, buffer.element_type, + chunk_elements, peer, comm, &stream)); + + TF_RETURN_IF_ERROR(nccl_api->Recv(recv_slice, buffer.element_type, + chunk_elements, peer, comm, &stream)); + } + } + } else { + TF_RET_CHECK(buffers.size() == num_participants) + << "Number of inputs didn't match the number of participants."; + + for (size_t i = 0; i < buffers.size(); ++i) { + DeviceBufferPair& buffer = buffers[i]; + + TF_RETURN_IF_ERROR( + nccl_api->Send(buffer.source_buffer, buffer.element_type, + buffer.element_count, i, comm, &stream)); + + TF_RETURN_IF_ERROR( + nccl_api->Recv(buffer.destination_buffer, buffer.element_type, + buffer.element_count, i, comm, &stream)); + } + } + + return nccl_api->GroupEnd(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_all_to_all_thunk.h b/xla/service/gpu/runtime/nccl_all_to_all_thunk.h new file mode 100644 index 0000000000000..65fcc6b18cbe0 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -0,0 +1,73 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ + +#include +#include + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +struct NcclAllToAllConfig { + NcclCollectiveConfig config; + bool has_split_dimension; +}; + +// Thunk that performs a NCCL-based All-to-All among CUDA GPU-based replicas. +class NcclAllToAllStartThunk : public NcclCollectiveThunk { + public: + NcclAllToAllStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloAllToAllInstruction* instr, + std::vector buffers); + + // Returns whether the given instruction can be lowered to a nccl all-to-all + // call. + static absl::Status CheckImplementable(const HloAllToAllInstruction* instr, + int64_t replica_count, + int64_t partition_count); + + static const char* GetHloOpName() { return "all-to-all-start"; } + + static CollectiveOpGroupMode GetGroupMode( + const HloAllToAllInstruction* instr); + + protected: + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + + private: + const NcclAllToAllConfig config_; + const std::vector buffers_; +}; + +absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_ALL_TO_ALL_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc new file mode 100644 index 0000000000000..f360c66cef674 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -0,0 +1,659 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_api.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#if (TF_ROCM_VERSION >= 50200) +#include "rocm/include/rccl/rccl.h" +#else +#include "rocm/include/rccl.h" +#endif // TF_ROCM_VERSION >= 50200 +#else +#include "third_party/nccl/nccl.h" +#endif // TENSORFLOW_USE_ROCM + +namespace xla::gpu { + +//==-----------------------------------------------------------------------===// +// Macros to return or warn on NCCL errors. +//==-----------------------------------------------------------------------===// + +static absl::Status ToStatus(ncclResult_t s, const char* file, int64_t line, + const char* expr) { + if (s == ncclSuccess) return absl::OkStatus(); + + return absl::InternalError(absl::StrFormat( + "%s:%d: NCCL operation %s failed: %s." + " Last NCCL warning(error) log entry (may be unrelated) '%s'.", + file, line, expr, ncclGetErrorString(s), ncclGetLastError(nullptr))); +} + +#define XLA_NCCL_STATUS(expr) \ + xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr) + +#define XLA_NCCL_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +#define XLA_NCCL_LOG_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + LOG(ERROR) << s.ToString(); \ + } \ + } while (0) + +#define XLA_NCCL_CHECK(expr) CHECK(XLA_NCCL_STATUS(expr).ok()) + +//==-----------------------------------------------------------------------===// +// Conversions between XLA and NCCL data types +//==-----------------------------------------------------------------------===// + +static size_t ToNcclCount(PrimitiveType dtype, size_t count) { + return primitive_util::IsComplexType(dtype) ? count * 2 : count; +} + +static absl::StatusOr ToNcclDataType(PrimitiveType dtype, + bool is_reduction_op) { + switch (dtype) { + case S8: + case F8E5M2: + case F8E4M3FN: + return ncclInt8; + case PRED: + case U8: + return ncclUint8; + case S32: + return ncclInt32; + case U32: + return ncclUint32; + case S64: + return ncclInt64; + case U64: + return ncclUint64; + case F16: + return ncclFloat16; + case F32: + case C64: + return ncclFloat32; + case F64: + case C128: + return ncclFloat64; + case S16: + case U16: + // For reductions we expect 16 bit integer types to be promoted to 32-bit. + if (is_reduction_op) { + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported data type for reduction operation: %s", + primitive_util::LowercasePrimitiveTypeName(dtype))); + } + // For collectives that just move data around, we can use ncclFloat16 for + // 16-bit integer data types. + return ncclFloat16; + case BF16: + return ncclBfloat16; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported data type: %s", + primitive_util::LowercasePrimitiveTypeName(dtype))); + } +} + +static ncclRedOp_t ToNcclReduction(ReductionKind kind) { + switch (kind) { + case ReductionKind::SUM: + return ncclSum; + case ReductionKind::PRODUCT: + return ncclProd; + case ReductionKind::MIN: + return ncclMin; + case ReductionKind::MAX: + return ncclMax; + } +} + +static std::string_view ToString(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return "sum"; + case ReductionKind::PRODUCT: + return "prod"; + case ReductionKind::MIN: + return "min"; + case ReductionKind::MAX: + return "max"; + } +} + +//==-----------------------------------------------------------------------===// +// Casting between opaque API structs and NCCL types. +//==-----------------------------------------------------------------------===// + +static NcclApi::NcclCommHandle Cast(ncclComm_t comm) { + return reinterpret_cast(comm); +} + +static ncclComm_t Cast(NcclApi::NcclCommHandle comm) { + return reinterpret_cast(comm); +} + +#ifdef PLATFORM_GOOGLE +static ncclPersistentPlanAllocator* Cast( + NcclApi::NcclPersistentPlanAllocatorHandle handle) { + return reinterpret_cast(handle); +} + +static ncclPersistentPlanAllocator** Cast( + NcclApi::NcclPersistentPlanAllocatorHandle* handle) { + return reinterpret_cast(handle); +} + +static NcclApi::NcclPersistentPlanAllocatorHandle Cast( + ncclPersistentPlanAllocator* ptr) { + return reinterpret_cast(ptr); +} +#endif // PLATFORM_GOOGLE + +//==-----------------------------------------------------------------------===// +// NcclApi::PersistentPlanAllocator +//==-----------------------------------------------------------------------===// + +using PersistentPlanAllocator = NcclApi::PersistentPlanAllocator; +using ScopedPersistentPlanAllocator = NcclApi::ScopedPersistentPlanAllocator; + +PersistentPlanAllocator::PersistentPlanAllocator( + int64_t device_ordinal, se::DeviceMemoryAllocator* allocator, + se::Stream* stream) + : handle_(nullptr), + device_ordinal_(device_ordinal), + allocator_(allocator), + stream_(stream) { + // NCCL persistent plan allocator is implemented as NCCL patch that is not yet + // open sourced and can't be used from OSS XLA. +#ifdef PLATFORM_GOOGLE + auto* nccl_allocator = new ncclPersistentPlanAllocator; + nccl_allocator->ctl = this; + + nccl_allocator->alloc = +[](void** ptr, void* src, size_t size, void* ctl) { + auto allocator = reinterpret_cast(ctl); + auto allocated = allocator->AllocateAndInitialize(src, size); + if (!allocated.ok()) return ncclInternalError; + *ptr = allocated->opaque(); + allocator->AddRef(); + return ncclSuccess; + }; + + nccl_allocator->free = +[](void* ptr, void* ctl) -> ncclResult_t { + auto allocator = reinterpret_cast(ctl); + auto status = allocator->Deallocate(se::DeviceMemoryBase(ptr)); + allocator->DropRef(); + return status.ok() ? ncclSuccess : ncclInternalError; + }; + + handle_ = Cast(nccl_allocator); +#endif // PLATFORM_GOOGLE +} + +PersistentPlanAllocator::~PersistentPlanAllocator() { +#ifdef PLATFORM_GOOGLE + delete Cast(handle_); +#endif // PLATFORM_GOOGLE +} + +absl::StatusOr +PersistentPlanAllocator::AllocateAndInitialize(void* src, size_t size) { + TF_ASSIGN_OR_RETURN(auto owned_mem, + allocator_->Allocate(device_ordinal_, size)); + VLOG(5) << "Allocate and initialize NCCL persistent plan; mem=" + << owned_mem->opaque() << "; size=" << size; + se::DeviceMemoryBase mem = owned_mem.Release(); + TF_RETURN_IF_ERROR(stream_->Memcpy(&mem, src, size)); + return mem; +} + +absl::Status PersistentPlanAllocator::Deallocate(se::DeviceMemoryBase mem) { + VLOG(5) << "Deallocate NCCL persistent plan; mem=" << mem.opaque(); + return allocator_->Deallocate(device_ordinal_, mem); +} + +ScopedPersistentPlanAllocator::ScopedPersistentPlanAllocator( + NcclCommHandle comm, tsl::RCReference allocator) + : comm_(comm), allocator_(std::move(allocator)) { +#ifdef PLATFORM_GOOGLE + XLA_NCCL_CHECK( + ncclCommGetPersistentPlanAllocator(Cast(comm_), Cast(&recover_))) + << "Failed to get NCCL persistent plan allocator"; + XLA_NCCL_CHECK(ncclCommSetPersistentPlanAllocator(Cast(comm_), + Cast(allocator_->handle()))) + << "Failed to set NCCL persistent plan allocator"; +#endif // PLATFORM_GOOGLE +} + +ScopedPersistentPlanAllocator::~ScopedPersistentPlanAllocator() { +#ifdef PLATFORM_GOOGLE + XLA_NCCL_CHECK( + ncclCommSetPersistentPlanAllocator(Cast(comm_), Cast(recover_))) + << "Failed to set NCCL persistent plan allocator"; +#endif // PLATFORM_GOOGLE +} + +//==-----------------------------------------------------------------------===// +// NcclApi +//==-----------------------------------------------------------------------===// + +// This a default NCCL API implementation that forwards all API calls to NCCL +// itself. It is available only if NCCL + CUDA are configured at compile time. +class DefaultNcclApi final : public NcclApi { + public: + absl::StatusOr GetUniqueId() final; + + absl::StatusOr> CommInitRanks( + int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, const Config& config) final; + + absl::StatusOr> CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) final; + + absl::Status CommAbort(NcclCommHandle comm) final; + absl::Status CommFinalize(NcclCommHandle comm) final; + absl::Status CommDestroy(NcclCommHandle comm) final; + + absl::StatusOr CommCount(NcclCommHandle comm) final; + + absl::Status CommGetAsyncError(NcclCommHandle comm) final; + + absl::Status GroupStart() final; + absl::Status GroupEnd() final; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream) final; + + absl::Status Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, size_t root, NcclCommHandle comm, + se::Stream* stream) final; + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, NcclCommHandle comm, + se::Stream* stream) final; + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, NcclCommHandle comm, + se::Stream* stream) final; + + absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + size_t count, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; + + absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; + + absl::StatusOr RegisterBuffer( + NcclCommHandle comm, se::DeviceMemoryBase buffer) final; + + absl::StatusOr DeregisterBuffer( + NcclCommHandle comm, NcclRegisteredBufferHandle handle) final; +}; + +NcclApi* NcclApi::Default() { + static auto* nccl_api = new DefaultNcclApi(); + return nccl_api; +} + +static_assert(NCCL_UNIQUE_ID_BYTES == NcclCliqueId::kSize, + "size of nccl unique id must match the clique id size"); + +static ncclUniqueId AsNcclUniqueId(const NcclCliqueId& clique_id) { + ncclUniqueId id; + absl::c_copy(clique_id.data(), id.internal); + return id; +} + +absl::StatusOr DefaultNcclApi::GetUniqueId() { + VLOG(3) << "Get NCCL unique id"; + ncclUniqueId id; + XLA_NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&id)); + return NcclCliqueId(id.internal); +} + +absl::StatusOr> +DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, + const Config& config) { + VLOG(1) << "Initialize NCCL communicator for " << ranks.size() + << " devices; hash(id)=" << absl::HashOf(clique_id); + + ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; +#if !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION > 50700 + comm_config.splitShare = config.split_share; +#endif + if (config.max_nchannels > 0) { + comm_config.maxCTAs = config.max_nchannels; + VLOG(1) << "Maximum number of channels for hash(id)=" + << absl::HashOf(clique_id) << " is set to: " << comm_config.maxCTAs; + } + + std::vector comm_handles; + std::vector comms; + + comm_handles.resize(ranks.size(), nullptr); + comms.reserve(ranks.size()); + + TF_RETURN_IF_ERROR(GroupStart()); + for (size_t i = 0; i < ranks.size(); ++i) { + VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank + << " of " << nranks << "; hash(id)=" << absl::HashOf(clique_id); + + se::gpu::ScopedActivateExecutorContext activate_context(ranks[i].device); + + XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig( + &comm_handles[i], nranks, AsNcclUniqueId(clique_id), ranks[i].rank, + &comm_config)); + } + TF_RETURN_IF_ERROR(GroupEnd()); + + for (ncclComm_t comm_handle : comm_handles) { + comms.emplace_back(Cast(comm_handle), NcclCommDeleter{this}); + } + + return comms; +} + +absl::StatusOr> DefaultNcclApi::CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) { + VLOG(1) << absl::StreamFormat( + "Split %d NCCL communicators using color %d and keys: [%s]", comms.size(), + color, absl::StrJoin(keys, ",")); + +#if !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000 + if (keys.size() != comms.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("Comms and keys must have the same size, but %d != %d", + comms.size(), keys.size())); + } + + ncclConfig_t comm_config = NCCL_CONFIG_INITIALIZER; + if (config.has_value()) { + comm_config.splitShare = config.value().split_share; + // If max_nchannels is set, then we don't want to + // inherit from parent comm. + if (config.value().max_nchannels > 0) { + comm_config.maxCTAs = config.value().max_nchannels; + VLOG(1) << "CommSplit maximum number of channels " + << " is set to: " << comm_config.maxCTAs; + } + } + + // In contrast to grouped initialization communicator splitting initializes + // communicators only after a successful call to `GroupEnd`, so we keep a + // vector of handles and after successful splitting convert to RAII wrappers. + std::vector split_comms_handles; + split_comms_handles.resize(comms.size(), nullptr); + + ncclConfig_t* comm_config_ptr = config.has_value() ? &comm_config : nullptr; + TF_RETURN_IF_ERROR(GroupStart()); + for (size_t i = 0; i < comms.size(); ++i) { + VLOG(1) << "Split NCCL communicator " << comms[i] << " with color " << color + << " and key " << keys[i]; + XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(Cast(comms[i]), color, keys[i], + &split_comms_handles[i], + /*config=*/comm_config_ptr)); + } + TF_RETURN_IF_ERROR(GroupEnd()); + + std::vector split_comms; + for (size_t i = 0; i < split_comms_handles.size(); ++i) { + split_comms.emplace_back(Cast(split_comms_handles[i]), + NcclCommDeleter{this}); + } + return split_comms; +#else + return absl::UnimplementedError( + absl::StrFormat("%s:%d: NCCL operation ncclCommSplit not implemented", + __FILE__, __LINE__)); +#endif // !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000 +} + +absl::Status DefaultNcclApi::CommAbort(NcclCommHandle comm) { + VLOG(1) << "Abort NCCL communicator: " << comm; + return XLA_NCCL_STATUS(ncclCommAbort(Cast(comm))); +} + +absl::Status DefaultNcclApi::CommFinalize(NcclCommHandle comm) { + VLOG(1) << "Finalize NCCL communicator: " << comm; + return XLA_NCCL_STATUS(ncclCommFinalize(Cast(comm))); +} + +absl::Status DefaultNcclApi::CommDestroy(NcclCommHandle comm) { + VLOG(1) << "Destroy NCCL communicator: " << comm; + return XLA_NCCL_STATUS(ncclCommDestroy(Cast(comm))); +} + +absl::StatusOr DefaultNcclApi::CommCount(NcclCommHandle comm) { + VLOG(5) << "Get the number of ranks in NCCL communicator: " << comm; + int32_t count; + XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(Cast(comm), &count)); + return count; +} + +absl::Status DefaultNcclApi::CommGetAsyncError(NcclCommHandle comm) { + VLOG(5) << "Get last async error for NCCL communicator: " << comm; + + ncclResult_t async_err; + XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(Cast(comm), &async_err)); + if (async_err == ncclSuccess) return absl::OkStatus(); + + return absl::InternalError(absl::StrCat( + ncclGetErrorString(async_err), + ". Last NCCL error (maybe unrelated): ", ncclGetLastError(Cast(comm)))); +} + +absl::Status DefaultNcclApi::GroupStart() { + VLOG(5) << "Start NCCL group"; + return XLA_NCCL_STATUS(ncclGroupStart()); +} + +absl::Status DefaultNcclApi::GroupEnd() { + VLOG(5) << "End NCCL group"; + return XLA_NCCL_STATUS(ncclGroupEnd()); +} + +absl::Status DefaultNcclApi::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL AllReduce operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; reduction_kind=%s; comm=%p; " + "stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, ToString(reduction_kind), comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclAllReduce( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, ToNcclReduction(reduction_kind), Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status DefaultNcclApi::Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + size_t root, NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL Broadcast operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; root=%d; comm=%p; " + "stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, root, comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclBroadcast( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, root, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status DefaultNcclApi::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL ReduceScatter operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; reduction_kind=%s; comm=%p; " + "stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, ToString(reduction_kind), comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclReduceScatter( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, ToNcclReduction(reduction_kind), Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status DefaultNcclApi::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL AllGather operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclAllGather( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status DefaultNcclApi::Send(se::DeviceMemoryBase send_buffer, + PrimitiveType dtype, size_t count, + int32_t peer, NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; " + "count=%d; peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm, + stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS( + ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status DefaultNcclApi::Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + int32_t peer, NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; " + "count=%d; peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), recv_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm, + stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS( + ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +absl::StatusOr +DefaultNcclApi::RegisterBuffer(NcclCommHandle comm, + se::DeviceMemoryBase buffer) { + VLOG(3) << absl::StreamFormat( + "Register buffer for NCCL communicator; buffer=%p; size=%d; comm=%p", + buffer.opaque(), buffer.size(), comm); + void* handle = nullptr; +#if (NCCL_VERSION_CODE >= 21901) + XLA_NCCL_RETURN_IF_ERROR( + ncclCommRegister(Cast(comm), buffer.opaque(), buffer.size(), &handle)); +#endif // NCCL_VERSION_CODE >= 21901 + return reinterpret_cast(handle); +} + +absl::StatusOr +DefaultNcclApi::DeregisterBuffer(NcclCommHandle comm, + NcclRegisteredBufferHandle handle) { + VLOG(3) << absl::StreamFormat( + "Deregister buffer for NCCL communicator; handle=%p; comm=%p", handle, + comm); +#if (NCCL_VERSION_CODE >= 21901) + return XLA_NCCL_STATUS( + ncclCommDeregister(Cast(comm), reinterpret_cast(handle))); +#endif // NCCL_VERSION_CODE >= 21901 +} +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/nccl_api.h b/xla/service/gpu/runtime/nccl_api.h new file mode 100644 index 0000000000000..05f230088826c --- /dev/null +++ b/xla/service/gpu/runtime/nccl_api.h @@ -0,0 +1,274 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_API_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_API_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/logging.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// NcclApi +//===----------------------------------------------------------------------===// + +// NcclApi hides implementation detail of collective operations built on top of +// NCCL library so that no other parts of XLA should include nccl.h header +// directly (or indirectly). + +class NcclApi { + public: + virtual ~NcclApi() = default; + + // Communicator configuration. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig + struct Config { + bool split_share = false; + int64_t max_nchannels = 0; + }; + + // Returns a default NcclApi for a current process. Can be a real one based on + // NCCL or a stub if XLA compiled without NCCL or CUDA support. + static NcclApi* Default(); + + // Forward declarations of opaque structs corresponding to underlying platform + // types (also defined as opaque structs). + struct NcclComm; + struct NcclPersistentPlanAllocator; + struct NcclRegisteredBuffer; + + // Convenience handles for defining API functions. + using NcclCommHandle = NcclComm*; + using NcclPersistentPlanAllocatorHandle = NcclPersistentPlanAllocator*; + using NcclRegisteredBufferHandle = NcclRegisteredBuffer*; + + // RAII handle for NCCL communicator. + struct NcclCommDeleter { + void operator()(NcclCommHandle comm) { + if (auto destroyed = api->CommDestroy(comm); !destroyed.ok()) + LOG(ERROR) << "Failed to destroy communicator: " << destroyed; + } + NcclApi* api; + }; + + using OwnedNcclComm = std::unique_ptr; + + // Persistent plan allocator allows to pass XLA memory allocator to NCCL to + // allocate device memory for persistent execution plans for NCCL operations + // captured into CUDA graphs. It relies on NCCL patch that is not part of + // upstream NCCL. + class PersistentPlanAllocator + : public tsl::ReferenceCounted { + public: + PersistentPlanAllocator(int64_t device_ordinal, + se::DeviceMemoryAllocator* allocator, + se::Stream* stream); + ~PersistentPlanAllocator(); + + // Allocates new device memory buffer and copies `size` bytes from `src` + // into it (NCCL persistent execution plan for a collective operation). + absl::StatusOr AllocateAndInitialize(void* src, + size_t size); + absl::Status Deallocate(se::DeviceMemoryBase mem); + + NcclPersistentPlanAllocatorHandle handle() const { return handle_; } + + private: + NcclPersistentPlanAllocatorHandle handle_; // owned + + int64_t device_ordinal_; + se::DeviceMemoryAllocator* allocator_; + se::Stream* stream_; + }; + + // RAII helper to set NCCL persistent plan `allocator` for `comm`. + class ScopedPersistentPlanAllocator { + public: + ScopedPersistentPlanAllocator( + NcclCommHandle comm, + tsl::RCReference allocator); + ~ScopedPersistentPlanAllocator(); + + private: + NcclCommHandle comm_; + NcclPersistentPlanAllocatorHandle recover_; + tsl::RCReference allocator_; + }; + + struct DeviceRank { + DeviceRank(se::StreamExecutor* device, int32_t rank) + : device(device), rank(rank) {} + + se::StreamExecutor* device; + int32_t rank; + }; + + // Returns a slice of device memory `buff` containing `count` values of data + // type `dtype` starting from `offset`. + static se::DeviceMemoryBase Slice(se::DeviceMemoryBase buff, + PrimitiveType dtype, size_t offset, + size_t count) { + size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype); + return buff.GetByteSlice(offset * multiplier, count * multiplier); + } + + // Creates a new unique clique id. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid + virtual absl::StatusOr GetUniqueId() = 0; + + // Creates new communicators for given devices. + // + // This API doesn't have a corresponding API in NCCL and implemented as + // multiple calls to ncclCommInitRank within a single group. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrank + virtual absl::StatusOr> CommInitRanks( + int32_t nranks, const NcclCliqueId& clique_id, + absl::Span ranks, const Config& config) = 0; + + // Creates new communicators by splitting `comms`. + // + // This API doesn't have a corresponding API in NCCL and implemented as + // multiple calls to ncclCommSplit within a single group. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit + virtual absl::StatusOr> CommSplit( + absl::Span comms, int32_t color, + absl::Span keys, std::optional config) = 0; + + // Abort any uncompleted operations and destroys the communicator. Frees + // resources that are allocated to a communicator object comm. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommabort + virtual absl::Status CommAbort(NcclCommHandle comm) = 0; + + // Finalize a communicator object comm. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommdestroy + virtual absl::Status CommFinalize(NcclCommHandle comm) = 0; + + // Destroy a communicator object comm. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommdestroy + virtual absl::Status CommDestroy(NcclCommHandle comm) = 0; + + // Returns the number of ranks in the NCCL communicator comm. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount + virtual absl::StatusOr CommCount(NcclCommHandle comm) = 0; + + // Queries the progress and potential errors of asynchronous operations + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommgetasyncerror + virtual absl::Status CommGetAsyncError(NcclCommHandle comm) = 0; + + // Starts a group call. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupstart + virtual absl::Status GroupStart() = 0; + + // Ends a group call. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupend + virtual absl::Status GroupEnd() = 0; + + // Reduce buffers of length `count` in `send_buff` using `reduction_kind` + // reduction and leaves identical copies of the result on each `recv_buff`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallreduce + virtual absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream) = 0; + + // Copy data in `send_buff` from the root GPU to the `recv_buff` on + // all GPUs. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclbroadcast + virtual absl::Status Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, size_t root, + NcclCommHandle comm, se::Stream* stream) = 0; + // Reduce data in `send_buff` from all GPUs using the `reduction_kind` + // operation and leave the reduced result scattered over the devices so that + // the `recv_buff` on rank `i` will contain the i-th block of the result. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclreducescatter + virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, + se::Stream* stream) = 0; + + // Gather `count` values from all GPUs into recv_buffer, receiving data from + // rank `i` at offset `i * sendcount`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallgather + virtual absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + NcclCommHandle comm, se::Stream* stream) = 0; + + // Send data from `send_buff` to rank `peer`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend + virtual absl::Status Send(se::DeviceMemoryBase send_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream) = 0; + + // Receive data from rank `peer` into `recv_buff`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv + virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream) = 0; + + // Register `buffer` with communicator `comm` for zero-copy communication. + // Returned handle can be used for future unregistration. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommregister + virtual absl::StatusOr RegisterBuffer( + NcclCommHandle comm, se::DeviceMemoryBase buffer) = 0; + + // Deregister buffer represented by `handle` from communicator `comm`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommderegister + virtual absl::StatusOr DeregisterBuffer( + NcclCommHandle comm, NcclRegisteredBufferHandle handle) = 0; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_API_H_ diff --git a/xla/service/gpu/runtime/nccl_api_stub.cc b/xla/service/gpu/runtime/nccl_api_stub.cc new file mode 100644 index 0000000000000..a0eab15e44542 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_api_stub.cc @@ -0,0 +1,173 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "tsl/concurrency/ref_count.h" + +namespace xla::gpu { + +// This is a NCCL API stub that is linked into the process when XLA compiled +// without NCCL or CUDA support. It returns errors from all API calls. This stub +// makes it always safe to include NCCL API headers everywhere in XLA without +// #ifdefs or complex build rules magic. All magic handled by `:nccl_api`. + +//==-----------------------------------------------------------------------===// +// NcclApi::PersistentPlanAllocator +//==-----------------------------------------------------------------------===// + +using PersistentPlanAllocator = NcclApi::PersistentPlanAllocator; +using ScopedPersistentPlanAllocator = NcclApi::ScopedPersistentPlanAllocator; + +PersistentPlanAllocator::PersistentPlanAllocator(int64_t, + se::DeviceMemoryAllocator*, + se::Stream*) { + // Suppress clang unused private field warnings. + (void)device_ordinal_; + (void)allocator_; + (void)stream_; +} + +PersistentPlanAllocator::~PersistentPlanAllocator() = default; + +absl::StatusOr +PersistentPlanAllocator::AllocateAndInitialize(void*, size_t) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status PersistentPlanAllocator::Deallocate(se::DeviceMemoryBase mem) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +ScopedPersistentPlanAllocator::ScopedPersistentPlanAllocator( + NcclCommHandle, tsl::RCReference) { + // Suppress clang unused private field warnings. + (void)comm_; + (void)recover_; + (void)allocator_; +} + +ScopedPersistentPlanAllocator::~ScopedPersistentPlanAllocator() = default; + +//===----------------------------------------------------------------------===// +// NcclApiStub +//===----------------------------------------------------------------------===// + +static absl::Status UnimplementedError() { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +class NcclApiStub final : public NcclApi { + public: + absl::StatusOr GetUniqueId() final { + return UnimplementedError(); + } + + absl::StatusOr> CommInitRanks( + int32_t, const NcclCliqueId&, absl::Span, + const Config&) final { + return UnimplementedError(); + } + + absl::StatusOr> CommSplit( + absl::Span, int32_t, absl::Span, + std::optional) final { + return UnimplementedError(); + } + + absl::Status CommAbort(NcclCommHandle) final { return UnimplementedError(); } + + absl::Status CommFinalize(NcclCommHandle) final { + return UnimplementedError(); + } + + absl::Status CommDestroy(NcclCommHandle) final { + return UnimplementedError(); + } + + absl::StatusOr CommCount(NcclCommHandle) final { + return UnimplementedError(); + } + + absl::Status CommGetAsyncError(NcclCommHandle) final { + return UnimplementedError(); + } + + absl::Status GroupStart() final { return UnimplementedError(); } + absl::Status GroupEnd() final { return UnimplementedError(); } + + absl::Status AllReduce(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, ReductionKind, NcclCommHandle, + se::Stream*) final { + return UnimplementedError(); + } + + absl::Status Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, size_t root, NcclCommHandle comm, + se::Stream* stream) final { + return UnimplementedError(); + } + + absl::Status ReduceScatter(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, ReductionKind, + NcclCommHandle, se::Stream*) final { + return UnimplementedError(); + } + + absl::Status AllGather(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, NcclCommHandle, + se::Stream*) final { + return UnimplementedError(); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, + NcclCommHandle, se::Stream*) final { + return UnimplementedError(); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, + NcclCommHandle, se::Stream*) final { + return UnimplementedError(); + } + + absl::StatusOr RegisterBuffer( + NcclCommHandle, se::DeviceMemoryBase) final { + return UnimplementedError(); + } + + absl::StatusOr DeregisterBuffer( + NcclCommHandle, NcclRegisteredBufferHandle) final { + return UnimplementedError(); + } +}; + +NcclApi* NcclApi::Default() { + static auto* nccl_api = new NcclApiStub(); + return nccl_api; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/nccl_clique.cc b/xla/service/gpu/runtime/nccl_clique.cc new file mode 100644 index 0000000000000..605c449c3c5fd --- /dev/null +++ b/xla/service/gpu/runtime/nccl_clique.cc @@ -0,0 +1,517 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_clique.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/debug_options_flags.h" +#include "xla/executable_run_options.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/lockable.h" +#include "xla/service/rendezvous.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/hash.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// NcclCliqueIdCallback +//===----------------------------------------------------------------------===// + +bool IsGlobalNcclConfig() { + static const char* const nccl_comm_id = std::getenv("NCCL_COMM_ID"); + return nccl_comm_id != nullptr; +} + +absl::StatusOr GetNcclCliqueIdCallback( + const NcclCliqueIdCallback* clique_id_callback, bool is_local) { + if (clique_id_callback != nullptr) return clique_id_callback; + + TF_RET_CHECK(is_local || IsGlobalNcclConfig()) + << "If non-local devices are taking part of a collective API on " + "GPU, the nccl_clique_id_callback must be provided by the client."; + + static auto* local_callback = new NcclCliqueIdCallback( + [](const NcclCliqueKey&) { return NcclApi::Default()->GetUniqueId(); }); + return local_callback; +} + +//===----------------------------------------------------------------------===// +// NcclClique Acquire and Initialization Timeouts +//===----------------------------------------------------------------------===// + +// We rely on rendezvous (all participating threads arriving to a rendezvous +// barrier at the same time) to guarantee that NCCL communicators used in a way +// that does not lead to a deadlock. This itself can create new deadlocks if +// thread pools sized incorrectly. To prevent hard to debug deadlocks with WARN +// and terminate when detect that rendezvous runs for too long. + +static absl::Duration WarnStuckTimeout() { return absl::Seconds(10); } + +static absl::Duration TerminateTimeout() { + static const int64_t terminate_timeout = + xla::GetDebugOptionsFromFlags() + .xla_gpu_nccl_termination_timeout_seconds(); + return (terminate_timeout >= 0) ? absl::Seconds(terminate_timeout) + : absl::InfiniteDuration(); +} + +//===----------------------------------------------------------------------===// +// NcclClique +//===----------------------------------------------------------------------===// + +NcclCliqueCommunicators::NcclCliqueCommunicators( + NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators) + : clique_key_(std::move(clique_key)), + clique_id_(std::move(clique_id)), + communicators_(std::move(communicators)) {} + +std::optional NcclCliqueCommunicators::comm( + int32_t rank) { + if (auto it = communicators_.find(rank); it != communicators_.end()) { + return it->second.get(); + } + return std::nullopt; +} + +bool NcclCliqueCommunicators::IsLocal() const { + return communicators_.size() == clique_key_.devices().size(); +} + +void NcclCliqueCommunicators::ForEachComm( + absl::FunctionRef fn) { + for (auto& [rank, comm] : communicators_) { + fn(rank, comm.get()); + } +} + +std::string NcclCliqueCommunicators::DebugString() const { + std::string out = + absl::StrFormat("clique_key: %s; hash(id): %d; size: %d; communicators: ", + clique_key_.ToString(), + clique_id_.has_value() ? absl::HashOf(*clique_id_) : 0, + communicators_.size()); + int32_t cnt = 0; + for (const auto& [rank, comm] : communicators_) { + if (cnt++) absl::StrAppend(&out, ", "); + absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank, comm.get()); + } + return out; +} + +std::string NcclClique::DebugString() const { + return absl::StrFormat("NcclClique: %s", value().DebugString()); +} + +namespace { +// Container for initialized and ready to use local (in-process) NCCL cliques. +struct NcclCliques { + absl::Mutex mu; + absl::node_hash_map map ABSL_GUARDED_BY(mu); +}; +} // namespace + +// Returns local (in-process) NcclCliques container. +static NcclCliques& GetNcclCliques() { + static auto* cliques = new NcclCliques; + return *cliques; +} + +//===----------------------------------------------------------------------===// +// NcclClique Heart Beat Monitor +//===----------------------------------------------------------------------===// + +// Runs an async error check for a `comm` and aborts it if it is in the +// error state. It will free resources that are allocated to a communicator +// and abort any uncompleted operations before destroying the communicator. +static absl::Status CheckComm(NcclApi::NcclCommHandle comm) { + absl::Status async_err = NcclApi::Default()->CommGetAsyncError(comm); + if (!async_err.ok()) { + LOG(ERROR) << "Aborting communicator: " << comm + << " due to async NCCL error: " << async_err; + TF_RETURN_IF_ERROR(NcclApi::Default()->CommAbort(comm)); + } + return async_err; +} + +// Runs async check on all communicators in a clique. +static void CheckClique(const NcclCliqueKey& clique_key, + NcclClique& lockable_clique) { + if (NcclClique::Lock clique = lockable_clique.TryAcquire()) { + VLOG(5) << "Checking NCCL clique " << clique_key.ToString() + << " for async errors; num_communicators=" + << clique->num_communicators(); + clique->ForEachComm([](int32_t rank, NcclApi::NcclCommHandle comm) { + if (auto status = CheckComm(comm); !status.ok()) LOG(ERROR) << status; + }); + } else { + VLOG(5) << "Skip checking in-use NCCL clique " << clique_key.ToString(); + } +} + +// TODO(ezhulenev): We need a mechanism to destroy whole clique when one of the +// communicators is aborted to be able to recover from errors. +static void NcclCliqueHeartBeatMonitorThread() { + VLOG(5) << "Starting NCCL clique heart beat monitor"; + while (true) { + absl::SleepFor(absl::Seconds(30)); + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + VLOG(5) << "Checking NCCL communicators for async errors" + << "; num_cliques=" << cliques.map.size(); + for (auto& [clique_key, lockable_clique] : cliques.map) { + CheckClique(clique_key, lockable_clique); + } + } +} + +static void StartNcclCliqueHeartBeatMonitor() { + static auto* monitor_thread = tsl::Env::Default()->StartThread( + tsl::ThreadOptions(), "nccl_clique_heart_beat_monitor", + NcclCliqueHeartBeatMonitorThread); + (void)monitor_thread; // suppress unused variable warning +} + +//===----------------------------------------------------------------------===// +// NcclClique Initialization +//===----------------------------------------------------------------------===// + +// NcclClique initialization must be executed together by all participants, and +// we rely on rendezvous to guarantee that all ranks are ready to initialize +// NCCL communicators. In general collective operations are expected to be +// executed concurrently by all participating ranks, and when some ranks do not +// join the operation it leads to deadlocks. We use a combination of rendezvous +// and locking to guarantee that all collective operations in XLA have a well +// defined order and do not deadlock inside underlying collective communication +// library. + +static auto DeviceRanksToString(absl::Span ranks) { + return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) { + str->append(std::to_string(rank.rank)); + }); +} + +// Joins a NcclClique initialization rendezvous for a `clique_key` and returns +// a lock that gives an access to initialized clique (access is shared between +// all participating ranks that own a shared pointer). +static absl::StatusOr> InitializeNcclClique( + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + const NcclCliqueIdCallback& clique_id_callback, + int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { + int nranks = clique_key.devices().size(); + VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" + << rank << "; num_local_participants=" << num_local_participants; + + // Start NCCL clique heart beat monitor when create a first clique. + StartNcclCliqueHeartBeatMonitor(); + + // Initializes a NcclClique for given device ranks and returns a lock that + // gives access to clique communicators. + auto initialize = [&](absl::Span args) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key)); + + std::vector ranks; + ranks.reserve(args.size()); + for (auto* arg : args) ranks.emplace_back(*arg); + + // Sort device ranks, mainly to get more readable logs below, NCCL does + // not care in what order ranks are initialized. + absl::c_sort(ranks, [](auto& a, auto& b) { return a.rank < b.rank; }); + + VLOG(3) << absl::StreamFormat( + "Create NCCL communicators for clique %s; ranks=[%s]; hash(id)=%d", + clique_key.ToString(), DeviceRanksToString(ranks), + absl::HashOf(clique_id)); + + TF_ASSIGN_OR_RETURN( + std::vector created_comms, + NcclApi::Default()->CommInitRanks(nranks, clique_id, ranks, config)); + + absl::btree_map comms; + for (size_t i = 0; i < ranks.size(); ++i) { + comms[ranks[i].rank] = std::move(created_comms[i]); + } + + VLOG(3) << absl::StreamFormat( + "Created NCCL communicators for clique %s; ranks=[%s]; hash(id)=%d", + clique_key.ToString(), DeviceRanksToString(ranks), + absl::HashOf(clique_id)); + + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + + // Create a new clique with given clique key and communicators. + auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_id, + std::move(comms)); + + // We can have a race to create a clique for a given key, the winner + // inserts it into a map and the looser destroys all communicators. + if (!emplaced.second) { + VLOG(3) << "Clique already exists: " + << emplaced.first->second.DebugString(); + } else { + VLOG(3) << "Created new clique: " << emplaced.first->second.DebugString(); + } + + return emplaced.first->second.Acquire(); + }; + + // We include `run_id` to a rendezvous key to make sure that multiple + // concurrent initializations will not join the same rendezvous. The winner + // will update cliques state, and others will destroy unused communicators. + auto rendezvous_key = std::make_tuple(run_id, clique_key); + auto initialization_rendezvous_name = + absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d", + rank, clique_key.ToString(), run_id.ToInt()); + + NcclApi::DeviceRank device_rank = {device, rank}; + + return RendezvousSingle>( + initialization_rendezvous_name, rendezvous_key, device_rank, + num_local_participants, initialize, WarnStuckTimeout(), + TerminateTimeout()); +} + +// Computes a unique NCCL communicator split color from a clique key. We use a +// deterministic hash function to guarantee that all participating processes get +// the same color value for a clique. +static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) { + std::vector global_device_ids; + global_device_ids.reserve(clique_key.devices().size()); + + for (GlobalDeviceId id : clique_key.devices()) { + global_device_ids.push_back(id.value()); + } + + return abs(static_cast( + tsl::Hash32(reinterpret_cast(global_device_ids.data()), + sizeof(int64_t) * global_device_ids.size(), 0))); +} + +// Joins a NcclClique initialization rendezvous for a `clique_key` and returns +// a lock that gives an access to clique created by splitting already acquired +// `parent_clique` clique (access is shared between all participating ranks that +// own a shared pointer). +static absl::StatusOr> InitializeNcclClique( + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + std::shared_ptr parent_clique, + int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { + // Find our rank in the parent clique. + const NcclCliqueKey& parent_clique_key = (*parent_clique)->clique_key(); + int32_t parent_rank = *parent_clique_key.rank(clique_key.devices()[rank]); + + VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" + << rank << " by splitting rank #" << parent_rank + << " in parent clique " << parent_clique_key.ToString() + << "; num_local_participants=" << num_local_participants; + + using RankPair = std::pair; + RankPair rank_pair = {parent_rank, rank}; + + // Current approach for communicator splitting works because of XLAs SPMD + // programming model where all collective operations have replica groups that + // include all ranks. This property guarantees that we'll split each + // communicator exactly once with a unique color computed from rank mapping + // and each communicator in the parent clique will become a member of exactly + // one new clique. Clique splitting happens concurrently for multiple + // non-overlapping clique and this guarantees forward progress even with + // implicit synchronization inside NCCL. + + // Initializes a NcclClique for given device ranks and returns a lock that + // gives access to clique communicators. + auto split = [&](absl::Span rank_pairs) + -> absl::StatusOr { + // Collect mapping from ranks in parent clique to ranks in a new clique. + absl::btree_map rank_mapping; + for (auto* rank_pair : rank_pairs) { + rank_mapping[rank_pair->first] = rank_pair->second; + } + + auto rank_mapping_formatter = [](std::string* str, auto mapping) { + absl::StrAppend(str, mapping.first, "->", mapping.second); + }; + + // Collect parent communicators we'll be splitting from and keys for + // creating new communicators. + std::vector parent_comms; + std::vector keys; + + for (auto& [parent_rank, split_rank] : rank_mapping) { + auto parent_comm = (*parent_clique)->comm(parent_rank); + if (!parent_comm.has_value()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Parent clique %s does not have a communicator for rank %d", + parent_clique_key.ToString(), parent_rank)); + } + + parent_comms.push_back(*parent_comm); + keys.push_back(split_rank); + } + + // Get a globally consistent color value for newly created clique. + int32_t color = GetCommSplitColor(clique_key); + + VLOG(3) << absl::StreamFormat( + "Create NCCL communicators for clique %s; parent=%s; color=%d; " + "rank_mapping=[%s]", + clique_key.ToString(), parent_clique_key.ToString(), color, + absl::StrJoin(rank_mapping, ",", rank_mapping_formatter)); + + TF_ASSIGN_OR_RETURN( + auto splitted_comms, + NcclApi::Default()->CommSplit(parent_comms, color, keys, config)); + + absl::btree_map comms; + for (size_t i = 0; i < splitted_comms.size(); ++i) { + comms[i] = std::move(splitted_comms[i]); + } + + VLOG(3) << absl::StreamFormat( + "Created NCCL communicators for clique %s; parent=%s; color=%d; " + "rank_mapping=[%s]", + clique_key.ToString(), parent_clique_key.ToString(), color, + absl::StrJoin(rank_mapping, ",", rank_mapping_formatter)); + + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + + // Create a new clique with given clique key and communicators. + auto emplaced = cliques.map.try_emplace(clique_key, clique_key, + std::nullopt, std::move(comms)); + + // We can have a race to create a clique for a given key, the winner + // inserts it into a map and the looser destroys all communicators. + if (!emplaced.second) { + VLOG(3) << "Clique already exists: " + << emplaced.first->second.DebugString(); + } else { + VLOG(3) << "Created new clique: " << emplaced.first->second.DebugString(); + } + + return emplaced.first->second.Acquire(); + }; + + // We include `run_id` to a rendezvous key to make sure that multiple + // concurrent initializations will not join the same rendezvous. The winner + // will update cliques state, and others will destroy unused communicators. + auto rendezvous_key = std::make_tuple(run_id, clique_key, parent_clique_key); + auto initialization_rendezvous_name = absl::StrFormat( + "initialize clique for rank %d; clique=%s; run_id=%d; parent=%s", rank, + clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); + + return RendezvousSingle>( + initialization_rendezvous_name, rendezvous_key, rank_pair, + num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); +} + +//===----------------------------------------------------------------------===// + +using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap; + +absl::StatusOr> AcquireNcclClique( + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + const NcclCliqueIdCallback& clique_id_callback, int32_t rank, + size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques, + int64_t max_nchannels) { + VLOG(2) << "Acquire NCCL clique " << clique_key.ToString() << "; run" + << run_id.ToString() << "; rank " << rank + << "; num_local_participants=" << num_local_participants + << "; acquired_cliques=" << acquired_cliques.size(); + + // Get the clique lock via the rendezvous to guarantee that all clique + // members participate in XLA run. + auto rendezvous_key = std::make_tuple(run_id, clique_key); + auto rendezvous_name = + absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", rank, + clique_key.ToString(), run_id.ToInt()); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr clique, + RendezvousSingle>( + rendezvous_name, rendezvous_key, num_local_participants, + [&] { + NcclCliques& cliques = GetNcclCliques(); + absl::MutexLock lock(&cliques.mu); + // Returns empty lock if we do not have a clique for `clique_key`. + auto it = cliques.map.find(clique_key); + return it == cliques.map.end() ? NcclClique::Lock() + : it->second.Acquire(); + }, + WarnStuckTimeout(), TerminateTimeout())); + + // If lock is not null return it to the caller. + if (*clique) return clique; + + // Maybe find if we acquired a clique with communicators that we can split. + static const int64_t enable_nccl_comm_splitting = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_comm_splitting(); + + // We enable resource sharing between parent and split communicators by + // default because that's the only reason why we use comm splitting. + NcclApi::Config config; + config.split_share = true; + config.max_nchannels = max_nchannels; + + if (enable_nccl_comm_splitting) { + for (auto& [acquired_clique_key, acquired_clique] : acquired_cliques) { + // We don't support splitting non-local cliques as it requires careful + // synchronization between multiple processes. + if (!(*acquired_clique)->IsLocal()) continue; + + if (clique_key.IsSubsetOf(acquired_clique_key)) { + return InitializeNcclClique(device, run_id, clique_key, acquired_clique, + num_local_participants, rank, config); + } + } + } + + // If we can't split any of the acquired cliques, create a new one. + return InitializeNcclClique(device, run_id, clique_key, clique_id_callback, + num_local_participants, rank, config); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/nccl_clique.h b/xla/service/gpu/runtime/nccl_clique.h new file mode 100644 index 0000000000000..d02c68a5e1e7d --- /dev/null +++ b/xla/service/gpu/runtime/nccl_clique.h @@ -0,0 +1,145 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/executable_run_options.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/lockable.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla::gpu { + +// NCCL clique (collective clique) is a set of devices that execute collective +// operations (e.g. all-reduce). It is notoriously easy to misuse NCCL +// communicators (see link below) and get a dead lock at run time, so in XLA we +// take extra care to order all collective operations in a way that would not +// lead to a deadlock. +// +// We rely on exclusive access to a NCCL clique (using Lockable mechanism) to +// guarantee that only a set of threads executing a particular collective +// operation can schedule new work using communicators belonging to a clique. +// +// In XLA process we have multiple cliques for different combinations of +// participating devices and properties of collective operations launched on +// them, e.g. mixing NCCL operations launched from CUDA graphs with regularly +// launched operations is prone to dead locks, and we keep them separate. See +// NcclCliqueKey for details. +// +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently + +//===----------------------------------------------------------------------===// +// NcclUniqueId +//===----------------------------------------------------------------------===// + +// Returns true if the NCCL config is global (NCCL_COMM_ID env variable is set). +bool IsGlobalNcclConfig(); + +// Returns a clique id callback passed as an argument if it's not null or a +// default callback to get create a clique id if we are running in local mode. +absl::StatusOr GetNcclCliqueIdCallback( + const NcclCliqueIdCallback* clique_id_callback, // may be null + bool is_local); + +//===----------------------------------------------------------------------===// +// NcclClique +//===----------------------------------------------------------------------===// + +// A group of NCCL communicators making up a clique. With NCCL it's notoriously +// easy to get a deadlock, so we take extra care by grouping communicators into +// cliques and making sure that we have a well defined order of all collective +// operations that does not lead to deadlocks. +class NcclCliqueCommunicators { + public: + NcclCliqueCommunicators( + NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators); + + // Returns a NCCL communicator for a given rank if it's in a clique. + std::optional comm(int32_t rank); + + // Return true if clique is local: all communicators belong to current + // process. Non-local cliques spans multiple processes (typically hosts). + bool IsLocal() const; + + // Calls `fn` for each communicator in the clique. + void ForEachComm( + absl::FunctionRef fn); + + const NcclCliqueKey& clique_key() const { return clique_key_; } + const std::optional& clique_id() const { return clique_id_; } + size_t num_communicators() const { return communicators_.size(); } + + std::string DebugString() const; + + private: + NcclCliqueKey clique_key_; + std::optional clique_id_; + + // TODO(ezhulenev): Switch this map to GlobalDeviceId key. + absl::btree_map communicators_; +}; + +struct NcclCliqueName { + static std::string ToString(const NcclCliqueCommunicators& comms) { + return absl::StrFormat("lockable clique %s", comms.clique_key().ToString()); + } +}; + +struct NcclClique : public Lockable { + // We keep acquired cliques in a sorted container to guarantee that all + // participants iterate over cliques in the same order. + using AcquiredCliquesMap = + absl::btree_map, + std::greater>; + + NcclClique(NcclCliqueKey clique_key, std::optional clique_id, + absl::btree_map communicators) + : Lockable(std::move(clique_key), clique_id, std::move(communicators)) {} + + std::string DebugString() const; +}; + +// Acquires an shared access to a NCCL clique (NcclClique::Lock collectively +// owned by `num_local_participants` threads). XLA uses this lock to serialize +// execution of all collective operations sharing a `clique_id`. +// +// If clique for a given key does not exist it will be initialized from newly +// created communicators or maybe created by splitting of the already acquired +// cliques. +absl::StatusOr> AcquireNcclClique( + se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, + const NcclCliqueIdCallback& clique_id_callback, int32_t rank, + size_t num_local_participants, + const NcclClique::AcquiredCliquesMap& acquired_cliques, + int64_t max_nchannels = 0); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_H_ diff --git a/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc b/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc new file mode 100644 index 0000000000000..4d58b7d0debdc --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc @@ -0,0 +1,84 @@ +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" + +#include +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +NcclCollectiveBroadcastStartThunk::NcclCollectiveBroadcastStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloCollectiveBroadcastInstruction* instr, std::vector buffers) + : NcclCollectiveThunk(Thunk::kNcclCollectiveBroadcastStart, thunk_info, + nccl_api, IsSyncCollective(instr)), + config_(GetNcclCollectiveConfig(instr, std::nullopt)), + buffers_(std::move(buffers)) {} + +/*static*/ Status NcclCollectiveBroadcastStartThunk::CheckImplementable( + const HloInstruction* instr, int64_t replica_count, + int64_t partition_count) { + return OkStatus(); +} + +/*static*/ CollectiveOpGroupMode +NcclCollectiveBroadcastStartThunk::GetGroupMode( + const HloCollectiveBroadcastInstruction* inst) { + return GetNcclCollectiveConfig(inst, std::nullopt).group_mode; +} + +Status NcclCollectiveBroadcastStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, config_.operand_element_type)); + return ::xla::gpu::RunCollectiveBroadcast(device_buffers, stream, comm, + nccl_api()); +} + +Status RunCollectiveBroadcast(std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm, + NcclApi* nccl_api) { + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (auto buffer : buffers) { + se::DeviceMemoryBase src_addr = buffer.source_buffer; + se::DeviceMemoryBase dest_addr = buffer.destination_buffer; + TF_RETURN_IF_ERROR(nccl_api->Broadcast( + // Always use rank 0 since we always broadcast from the first id in + // replica_groups + src_addr, dest_addr, buffer.element_type, buffer.element_count, 0, comm, + &stream)); + } + return nccl_api->GroupEnd(); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h b/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h new file mode 100644 index 0000000000000..a3311c09e9031 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_BROADCAST_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_BROADCAST_THUNK_H_ + +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/status.h" +#include "xla/stream_executor/stream.h" + +namespace xla::gpu { +// Thunk that performs a NCCL-based collective broadcast. +class NcclCollectiveBroadcastStartThunk : public NcclCollectiveThunk { + public: + static Status CheckImplementable(const HloInstruction* instr, + int64_t replica_count, + int64_t partition_count); + + static CollectiveOpGroupMode GetGroupMode( + const HloCollectiveBroadcastInstruction* inst); + + const NcclCollectiveConfig& config() const override { return config_; } + absl::Span buffers() const { return buffers_; } + + static const char* GetHloOpName() { return "collective-broadcast-start"; } + + NcclCollectiveBroadcastStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloCollectiveBroadcastInstruction* instr, + std::vector buffers); + + protected: + Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + + private: + const NcclCollectiveConfig config_; + const std::vector buffers_; +}; + +Status RunCollectiveBroadcast(std::vector& buffers, + se::Stream& stream, NcclApi::NcclCommHandle comm, + NcclApi* nccl_api); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_BROADCAST_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc new file mode 100644 index 0000000000000..9ae4f07c4b780 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc @@ -0,0 +1,224 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { + +NcclCollectivePermuteStartThunk::NcclCollectivePermuteStartThunk( + ThunkInfo thunk_info, NcclApi* nccl_api, + const HloCollectivePermuteInstruction* instr, int64_t replica_count, + int64_t partition_count, const Buffer& buffer) + : NcclCollectiveThunk(Thunk::kNcclCollectivePermuteStart, thunk_info, + nccl_api, IsSyncCollective(instr)), + config_(GetNcclP2PConfig(instr, replica_count, partition_count)), + buffer_(buffer) {} + +/*static*/ NcclP2PConfig NcclCollectivePermuteStartThunk::GetNcclP2PConfig( + const HloCollectivePermuteInstruction* instr, int64_t replica_count, + int64_t partition_count) { + NcclP2PConfig collective_permute_config; + auto& config = collective_permute_config.config; + + config.operand_count = 1; + const Shape shape = instr->operand(0)->shape(); + config.operand_element_type.push_back(shape.element_type()); + config.SetCollectiveOpKindAndID(instr); + config.group_mode = GetGroupMode(instr); + + // With a collective permute, all execution instances together form one + // replica group. + const int64_t num_participants = + config.group_mode == CollectiveOpGroupMode::kCrossReplica + ? replica_count + : partition_count; + config.replica_groups.emplace_back(); + ReplicaGroup& replica_group = config.replica_groups.front(); + for (int i = 0; i < num_participants; ++i) { + replica_group.add_replica_ids(i); + } + + const std::vector> source_target_pairs = + instr->source_target_pairs(); + + for (const std::pair& source_target : source_target_pairs) { + int64_t source = source_target.first; + int64_t target = source_target.second; + + collective_permute_config.id_to_source_target.insert({target, {}}) + .first->second.source = source; + collective_permute_config.id_to_source_target.insert({source, {}}) + .first->second.target = target; + } + + return collective_permute_config; +} + +/*static*/ bool NcclCollectivePermuteStartThunk::IsDegenerate( + const HloCollectivePermuteInstruction* instr, int64_t replica_count, + int64_t partition_count) { + // The collective permute is degenerate if all source-target pairs are + // identity, and all the IDs appear in the list. + const std::vector> source_target_pairs = + instr->source_target_pairs(); + // Each ID can appear only once as a source and as a target. So if all pairs + // are identity, all IDs must appear in the list is the size == number of + // replicas/partitions. + const int64_t expected_size = + instr->channel_id().has_value() ? partition_count : replica_count; + return source_target_pairs.size() == expected_size && + absl::c_all_of(source_target_pairs, + [](const std::pair& source_target) { + return source_target.first == source_target.second; + }); +} + +/*static*/ CollectiveOpGroupMode NcclCollectivePermuteStartThunk::GetGroupMode( + const HloCollectivePermuteInstruction* instr) { + return GetCollectiveOpGroupMode(instr->channel_id().has_value(), std::nullopt) + .value(); +} + +absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, {buffer_}, + config_.config.operand_element_type)); + TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; + + GlobalDeviceId global_device_id = params.collective_params->global_device_id; + + TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); + const int64_t current_id = + config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica + ? current_logical_id.replica_id + : current_logical_id.computation_id; + std::string device_string = GetDeviceString(*params.collective_params); + + const NcclP2PConfig::SourceTargetMapEntry source_target = + NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); + + return ::xla::gpu::RunCollectivePermute(nccl_api(), source_target, + device_buffers[0], stream, comm, + device_string, current_id); +} + +absl::Status RunCollectivePermute( + NcclApi* nccl_api, NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id) { + // Determine the source and target IDs for this instance. The source ID is the + // ID which will copy its data to this instance. The destination ID is the ID + // to which this instance will copy its data. Either are optional. + // + // No source and no dest: + // - this instance does not actually participate, no one send it any data and + // it does not have to send any data as well. Since there is no dest, + // just memzero() the dest buffer as required by the collective permute + // semantics. + // + // No source, dest present: + // - This instance has to send data to 'dest' Issue an send of the input. + // Since there is no source, memzero the dest buffer. + // + // Source present, no destination: + // - This instance received data from the source, does not have to send data + // to anyone, Issue a receive. + // + // Source and dest both present: + // - Issue a send of the input to dest, receive for the output from the + // src. + // + // + + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing collective permute from device ordinal: " + << device_ordinal << "current_id " << current_id; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, device_ordinal, {buffer}, comm)); + + const std::optional source_id = source_target.source; + const std::optional target_id = source_target.target; + + se::DeviceMemoryBase src_addr = buffer.source_buffer; + se::DeviceMemoryBase dest_addr = buffer.destination_buffer; + + VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d", + device_string, current_id, + source_id.value_or(-1), target_id.value_or(-1)); + + // GroupStart/End API is needed only if we will issue both send & recv calls. + const bool is_nccl_group_needed = (target_id && source_id); + if (is_nccl_group_needed) { + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + } + + // Send source buffer to target peer if needed. + if (target_id) { + TF_RETURN_IF_ERROR(nccl_api->Send(src_addr, buffer.element_type, + buffer.element_count, *target_id, comm, + &stream)); + } + + // Receive data from the source peer to the destination buffer. + if (source_id) { + TF_RETURN_IF_ERROR(nccl_api->Recv(dest_addr, buffer.element_type, + buffer.element_count, *source_id, comm, + &stream)); + } + + if (is_nccl_group_needed) { + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + } + + if (!source_id) { + // If there is no source peer, i.e. no one send us any data, zero out dest + // buffer. + VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", + device_string); + TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.h b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h new file mode 100644 index 0000000000000..062152f79b316 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h @@ -0,0 +1,73 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +// Thunk that performs a NCCL-based collective permute. +class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { + public: + static NcclP2PConfig GetNcclP2PConfig( + const HloCollectivePermuteInstruction* instr, int64_t replica_count, + int64_t partition_count); + + static bool IsDegenerate(const HloCollectivePermuteInstruction* instr, + int64_t replica_count, int64_t partition_count); + + static CollectiveOpGroupMode GetGroupMode( + const HloCollectivePermuteInstruction* instr); + + NcclCollectivePermuteStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloCollectivePermuteInstruction* instr, + int64_t replica_count, + int64_t partition_count, + const Buffer& buffer); + + static const char* GetHloOpName() { return "collective-permute-start"; } + + protected: + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + + private: + const NcclP2PConfig config_; + const Buffer buffer_; +}; + +absl::Status RunCollectivePermute( + NcclApi* nccl_api, NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc new file mode 100644 index 0000000000000..95b808fcda138 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -0,0 +1,553 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/layout_util.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/rendezvous.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/stream.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#endif // GOOGLE_CUDA + +namespace xla { +namespace gpu { +namespace { + +static constexpr int64_t kCollectiveMemorySpaceColor = 1; +static constexpr int64_t kNoStreamId = 0; + +bool IsTypeSupportedByNccl(PrimitiveType element_type, + Thunk::Kind reduction_op) { + switch (element_type) { + case S8: + case PRED: + case U8: + case S32: + case U32: + case S64: + case U64: + case F16: + case F32: + case F64: + case BF16: + case C64: + case C128: + return true; + case S16: + case U16: + // 16-bit integer reductions are not directly supported by NCCL and cannot + // be implicitly converted into other 16-bit types like ncclFloat16 as + // they involve actual computation and not just data movement. + case F8E5M2: + case F8E4M3FN: + return !IsReductionCollective(reduction_op); + default: + return false; + } +} + +} // namespace + +// This file runs collective ops (i.e. ops that communicate between multiple +// GPUs) using NCCL. +// +// Here's a high-level overview of how running an op works. +// +// - Multiple threads call ExecuteOnStream. +// - All threads that "go together" (i.e. are participating in the "same" +// collective op) choose the same Rendezvous object from a global map. +// - Once all threads have arrived at the Rendezvous, we know exactly which +// GPUs are participating in the op, so we get or create a NcclClique +// containing those GPUs. +// - We perform the NCCL operation using the clique. + +// Returns if the collective communication operation is degenerate because all +// the groups formed by the operation are singleton. A given op can be +// degenerate under several conditions, corresponding to the modes supported +// in GetParticipatingDevices(). +// 1. no channel id, use_global_device_ids = false: +// degenerate if replica_groups are singleton, or groups empty and +// replica_count == 1. +// 2. channel_id is set, use_global_device_ids = false: +// degenerate if replica_groups are singleton and num_partitions == 1, +// or groups empty and num_replicas == 1 && num_partitions == 1. +// 3. channel_id is set, use_global_device_ids = true (flattened-ids): +// degenerate if replica_groups are singleton (groups cannot be empty). +// 4. no channel_id, no use_global_device_ids: +// identical to 1. +// 5. channel_id is set, no use_global_device_ids: +// degenerate if replica_groups are singleton or group emty and +// num_partitions == 1 (since replica groups contain partition ids). +// +bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count, + int64_t partition_count) const { + bool groups_empty = replica_groups.empty(); + + // check if all replica_groups are singleton. If not, then the operation is + // not degenerate. + bool all_groups_singleton = + !groups_empty && + absl::c_all_of(replica_groups, [](const ReplicaGroup& group) { + return group.replica_ids_size() == 1; + }); + + switch (group_mode) { + case CollectiveOpGroupMode::kCrossReplica: + return all_groups_singleton || (groups_empty && replica_count == 1); + case CollectiveOpGroupMode::kCrossPartition: + return all_groups_singleton || (groups_empty && partition_count == 1); + case CollectiveOpGroupMode::kCrossReplicaAndPartition: + return (all_groups_singleton && partition_count == 1) || + (groups_empty && replica_count == 1 && partition_count == 1); + case CollectiveOpGroupMode::kFlattenedID: + CHECK(!groups_empty) + << "replica groups cannot be empty if use_global_device_ids = true"; + return all_groups_singleton; + default: + CHECK(0) << "Invalid collective op mode"; + return false; + } +} + +void NcclCollectiveConfig::SetCollectiveOpKindAndID( + const HloCollectivePermuteInstruction* instr) { + if (instr->channel_id().has_value()) { + collective_op_kind = RendezvousKey::kCrossModule; + op_id = instr->channel_id().value(); + } else { + collective_op_kind = RendezvousKey::kCrossReplica; + op_id = static_cast(instr->GetModule()->unique_id()); + } +} + +void NcclCollectiveConfig::SetCollectiveOpKindAndID( + const HloSendRecvInstruction* instr) { + int64_t channel_id = instr->channel_id().value_or(0); + if (channel_id > 0) { + collective_op_kind = RendezvousKey::kCrossModule; + op_id = channel_id; + } else { + collective_op_kind = RendezvousKey::kCrossReplica; + op_id = static_cast(instr->GetModule()->unique_id()); + } +} + +NcclCollectiveConfig GetNcclCollectiveConfig( + const HloInstruction* hlo, std::optional use_global_device_ids) { + NcclCollectiveConfig config; + config.operand_count = hlo->operands().size(); + config.operand_element_type.reserve(config.operand_count); + for (int i = 0; i < config.operand_count; i++) { + config.operand_element_type.push_back( + hlo->operand(i)->shape().element_type()); + } + config.replica_groups = hlo->replica_groups(); + + if (hlo->channel_id().has_value()) { + config.collective_op_kind = RendezvousKey::kCrossModule; + config.op_id = *hlo->channel_id(); + } else { + config.collective_op_kind = RendezvousKey::kCrossReplica; + config.op_id = static_cast(hlo->GetModule()->unique_id()); + } + + config.group_mode = GetCollectiveOpGroupMode(hlo->channel_id().has_value(), + use_global_device_ids) + .value(); + + return config; +} + +NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, + NcclApi* nccl_api, bool is_sync) + : Thunk(kind, thunk_info), + nccl_api_(nccl_api), + async_events_(is_sync ? nullptr : new AsyncEvents()) {} + +static absl::StatusOr GetNcclCliqueKey( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind) { + GlobalDeviceId global_device_id = params.global_device_id; + + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(global_device_id, *params.device_assn, + replica_groups, group_mode)); + + if (IsGlobalNcclConfig() && + (participants.size() != params.device_assn->replica_count())) { + return InvalidArgument( + "Partial replica groups are not allowed when using NCCL_COMM_ID " + "environment configuration."); + } + static const bool enable_per_stream_comms = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_per_stream_comms(); + + return NcclCliqueKey(std::move(participants), + enable_per_stream_comms ? stream_id : kNoStreamId, + stream_kind); +} + +absl::StatusOr GetNcclComm( + const Thunk::CollectiveExecuteParams& params, + const Thunk::CollectiveCliques& collective_cliques, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind) { + TF_ASSIGN_OR_RETURN(NcclCliqueKey clique_key, + GetNcclCliqueKey(params, replica_groups, group_mode, + stream_id, stream_kind)); + + std::optional rank = clique_key.rank(params.global_device_id); + return collective_cliques.GetComm(std::move(clique_key), *rank); +} + +absl::StatusOr> ConvertToDeviceBuffers( + const Thunk::ExecuteParams& params, + const std::vector& buffers, + const std::vector& element_types) { + return ConvertToDeviceBuffers(params.buffer_allocations, buffers, + element_types); +} + +absl::StatusOr> ConvertToDeviceBuffers( + const BufferAllocations* buffer_allocations, + const std::vector& buffers, + const std::vector& element_types) { + if (buffers.size() != element_types.size()) + return FailedPrecondition("Mismatch in operand buffer counts."); + + std::vector device_buffers; + device_buffers.reserve(buffers.size()); + for (int i = 0; i < buffers.size(); ++i) { + device_buffers.emplace_back(DeviceBufferPair{ + element_types[i], buffers[i].element_count, + buffer_allocations->GetDeviceAddress(buffers[i].source_buffer), + buffer_allocations->GetDeviceAddress(buffers[i].destination_buffer), + buffers[i].source_memory_space, buffers[i].destination_memory_space}); + } + return device_buffers; +} + +Status RegisterBufferOnce(NcclApi* nccl_api, int device_ordinal, + NcclApi::NcclCommHandle comm, + se::DeviceMemoryBase buffer) { + // Keep track of which communicators we have registered for already. + // Each ncclMemAlloc'd buffer needs to be registered once per comm. + struct RegisteredBuffers { + absl::Mutex mu; + // Device ordinal, communicator, and base pointer address. + absl::flat_hash_set> records + ABSL_GUARDED_BY(mu); + // Buffers could be deregistered with ncclCommDeregister. + std::vector handles + ABSL_GUARDED_BY(mu); + }; + static auto& all_registered = *new RegisteredBuffers; + + // Since each XLA buffer is a slice into a larger BFCAllocator chunk, first + // get the base address of buffer. We will use the base address to keep track + // of which chunks we have registered. + void* base_ptr; + size_t base_size; +#ifdef GOOGLE_CUDA + TF_RETURN_IF_ERROR(se::gpu::GpuDriver::GetPointerAddressRange( + reinterpret_cast(buffer.opaque()), + reinterpret_cast(&base_ptr), &base_size)); +#else // GOOGLE_CUDA + base_ptr = nullptr; + base_size = 0; +#endif // GOOGLE_CUDA + + absl::MutexLock lock(&all_registered.mu); + if (!all_registered.records.contains({device_ordinal, comm, base_ptr})) { + // ncclCommRegister will internally get and use the base address/size of the + // address we provide. + TF_ASSIGN_OR_RETURN(NcclApi::NcclRegisteredBufferHandle handle, + nccl_api->RegisterBuffer(comm, buffer)); + all_registered.handles.push_back(handle); + all_registered.records.insert({device_ordinal, comm, base_ptr}); + } + return OkStatus(); +} + +Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, + const std::vector& buffers, + NcclApi::NcclCommHandle comm) { + for (int i = 0; i < buffers.size(); ++i) { + if (buffers[i].source_memory_space == kCollectiveMemorySpaceColor) { + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + buffers[i].source_buffer)); + } + if (buffers[i].destination_memory_space == kCollectiveMemorySpaceColor) { + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + buffers[i].destination_buffer)); + } + } + return OkStatus(); +} + +absl::Status NcclCollectiveThunk::AsyncEvents::Initialize( + se::StreamExecutor* executor) { + absl::MutexLock lock(&mu_); + if (events_.contains(executor)) return absl::OkStatus(); + + se::Event event(executor); + if (!event.Init()) { + return absl::InternalError( + "Failed to initialize collective operation async completion event"); + } + + events_.try_emplace(executor, std::move(event)); + return absl::OkStatus(); +} + +absl::StatusOr NcclCollectiveThunk::AsyncEvents::GetEvent( + se::StreamExecutor* executor) { + absl::MutexLock lock(&mu_); + + auto event = events_.find(executor); + if (event == events_.end()) { + return absl::InternalError( + "Collective operation async completion event not initialized"); + } + + return &event->second; +} + +absl::Status NcclCollectiveThunk::Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + const CollectiveExecuteParams* collectives = params.collective_params; + + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(collectives->global_device_id, + *collectives->device_assn, + config().replica_groups, config().group_mode)); + + std::vector local_devices; + if (collectives->global_device_id_map) { + local_devices.reserve(collectives->global_device_id_map->size()); + for (const auto& entry : *collectives->global_device_id_map) { + local_devices.push_back(entry.second); + } + } + + size_t num_local_participants = GetNumLocalParticipants( + participants, + collectives->global_device_id_map ? &local_devices : nullptr); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + static const bool enable_per_stream_comms = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_per_stream_comms(); + return resource_requests.AddClique( + NcclCliqueKey(std::move(participants), + enable_per_stream_comms ? GetStreamId() : kNoStreamId, + stream_kind), + num_local_participants); +} + +absl::Status NcclCollectiveThunk::Initialize(const InitializeParams& params) { + if (async_events_) { + TF_RETURN_IF_ERROR(async_events_->Initialize(params.executor)); + } + return absl::OkStatus(); +} + +namespace { +// Wrap NcclCliqueKey into a unique struct to guarantee we do not accidentally +// try to run multiple unrelated rendezvous for a same key. +struct FirstCallRendezvousKey { + NcclCliqueKey clique_key; + + template + friend H AbslHashValue(H h, const FirstCallRendezvousKey& key) { + return H::combine(std::move(h), key.clique_key); + } +}; + +bool operator==(const FirstCallRendezvousKey& a, + const FirstCallRendezvousKey& b) { + return a.clique_key == b.clique_key; +} +} // namespace + +Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(1) << absl::StreamFormat("Starting %s %s.", IsAsync() ? "async" : "sync", + Thunk::KindToString(kind())); + const int64_t stream_id = GetStreamId(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + TF_ASSIGN_OR_RETURN( + NcclApi::NcclCommHandle comm, + GetNcclComm(*params.collective_params, *params.collective_cliques, + config().replica_groups, config().group_mode, stream_id, + stream_kind)); + + se::StreamExecutor* executor = params.stream->parent(); + int64_t async_stream_idx = static_cast(stream_kind); + + if (IsAsync()) { + // Launch collective operation on an async stream. + se::Stream& async_stream = *params.async_comms_streams[async_stream_idx]; + + // Wait for main compute stream to make sure all buffers are ready. + TF_RETURN_IF_ERROR(async_stream.WaitFor(params.stream)); + + TF_RETURN_IF_ERROR(RunNcclCollective(params, async_stream, comm)); + + // Record collective operation completion. + TF_ASSIGN_OR_RETURN(se::Event * event, async_events_->GetEvent(executor)); + TF_RETURN_IF_ERROR(async_stream.RecordEvent(event)); + + } else { + // Launch collective operation on a main stream. + TF_RETURN_IF_ERROR(RunNcclCollective(params, *params.stream, comm)); + } + + // After a first execution of this instance of collective operation do a + // rendezvous with other participants to make sure that all of them allocated + // required state (internal to NCCL) and ready to continue. Going too far + // ahead on one rank leads to deadlocks in NCCL. + if (NeedFirstCallRendzevous() && !first_call_rendezvous_flag_.IsCompleted()) { + TF_ASSIGN_OR_RETURN( + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, stream_id, stream_kind)); + + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + params.collective_cliques->num_communicators(clique_key)); + + auto global_device_id = params.collective_params->global_device_id; + VLOG(1) << "Do a rendezvous after a first call to " + << Thunk::KindToString(kind()) + << "; run_id=" << params.collective_params->run_id.ToInt() + << "; op_id=" << config().op_id + << "; num_local_participants=" << num_local_participants + << "; rank=" << clique_key.rank(global_device_id).value_or(-1) + << "; clique_key=" << clique_key.ToString(); + + auto rendezvous_key = FirstCallRendezvousKey{std::move(clique_key)}; + auto rendezvous_name = absl::StrFormat( + "first call to collective operation %d; run_id=%d", config().op_id, + params.collective_params->run_id.ToInt()); + + RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name, + rendezvous_key, num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40)); + } + + return absl::OkStatus(); +} + +std::string NcclCollectiveThunk::GetDeviceString( + const Thunk::CollectiveExecuteParams& collective_params) { + GlobalDeviceId global_device_id = collective_params.global_device_id; + DeviceAssignment::LogicalID logical_id = + collective_params.device_assn->LogicalIdForDevice(global_device_id) + .value(); + return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d", + logical_id.replica_id, logical_id.computation_id, + global_device_id.value(), + collective_params.local_device_ordinal); +} + +NcclCollectiveDoneThunk::NcclCollectiveDoneThunk( + Thunk::Kind kind, ThunkInfo thunk_info, + std::shared_ptr async_events) + : Thunk(kind, std::move(thunk_info)), async_events_(async_events) {} + +absl::Status NcclCollectiveDoneThunk::ExecuteOnStream( + const ExecuteParams& params) { + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(se::Event * event, async_events_->GetEvent(executor)); + return params.stream->WaitFor(event); +} + +absl::Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op) { + Shape shape = GetShape(operand); + return IsValidOperand(shape, reduction_op); +} + +absl::Status IsValidOperand(Shape shape, Thunk::Kind reduction_op) { + if (!LayoutUtil::IsDenseArray(shape)) { + return absl::AbortedError( + absl::StrFormat("input is not a dense array: %s", + shape.ToString(/*print_layout=*/true))); + } + if (!IsTypeSupportedByNccl(shape.element_type(), reduction_op)) { + return absl::AbortedError(absl::StrFormat( + "element type %s not suppored by NCCL", + primitive_util::LowercasePrimitiveTypeName(shape.element_type()))); + } + return absl::OkStatus(); +} + +size_t GetNumLocalParticipants( + const std::vector& participants, + const std::vector* local_devices) { + if (local_devices == nullptr) return participants.size(); + + return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) { + return absl::c_linear_search(*local_devices, device_id); + }); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.h b/xla/service/gpu/runtime/nccl_collective_thunk.h new file mode 100644 index 0000000000000..5a879487622dc --- /dev/null +++ b/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -0,0 +1,301 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_THUNK_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/rendezvous.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream.h" +#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +class NcclClique; + +struct NcclCollectiveConfig { + int64_t operand_count; + std::vector operand_element_type; + std::vector replica_groups; + RendezvousKey::CollectiveOpKind collective_op_kind; + int64_t op_id; + CollectiveOpGroupMode group_mode; + + template + void SetCollectiveOpKindAndID(OpT op); + void SetCollectiveOpKindAndID(const HloCollectivePermuteInstruction* instr); + void SetCollectiveOpKindAndID(const HloSendRecvInstruction* instr); + bool IsDegenerate(int64_t replica_count, int64_t partition_count) const; +}; + +template +void NcclCollectiveConfig::SetCollectiveOpKindAndID(OpT op) { + if (op.getChannelId()) { + collective_op_kind = RendezvousKey::kCrossModule; + op_id = static_cast(op.getChannelId()->getHandle()); + } else { + collective_op_kind = RendezvousKey::kCrossReplica; + mlir::ModuleOp parent = op->template getParentOfType(); + mlir::IntegerAttr unique_id = + parent->getAttrOfType("hlo.unique_id"); + op_id = static_cast(unique_id.getInt()); + } +} + +NcclCollectiveConfig GetNcclCollectiveConfig( + const HloInstruction* hlo, std::optional use_global_device_ids); + +template +NcclCollectiveConfig GetNcclCollectiveConfigForMlir( + OpT op, std::optional use_global_device_ids) { + NcclCollectiveConfig config; + config.operand_count = op.getInputs().size(); + config.operand_element_type.reserve(config.operand_count); + for (int i = 0; i < config.operand_count; i++) { + const Shape shape = GetShape(op.getInputs()[i]); + config.operand_element_type.push_back(shape.element_type()); + } + config.replica_groups = ConvertReplicaGroups(op.getReplicaGroups()).value(); + config.SetCollectiveOpKindAndID(op); + config.group_mode = GetCollectiveOpGroupMode(op.getChannelId().has_value(), + use_global_device_ids) + .value(); + return config; +} + +//===----------------------------------------------------------------------===// +// NcclCollectiveThunk +//===----------------------------------------------------------------------===// + +// Forward declare. +class NcclCollectiveDoneThunk; + +// Thunk base class for NCCL collective operations. +class NcclCollectiveThunk : public Thunk { + public: + NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, NcclApi* nccl_api, + bool is_sync); + + struct Buffer { + int64_t element_count; + BufferAllocation::Slice source_buffer; + BufferAllocation::Slice destination_buffer; + int64_t source_memory_space; + int64_t destination_memory_space; + mlir::Value source_value; + mlir::Value destination_value; + }; + + // Completion events for asynchronous collective operations (operations + // launched on a dedicated stream that is synchronized with main compute + // stream only when needed). + class AsyncEvents { + private: + friend class NcclCollectiveThunk; + friend class NcclCollectiveDoneThunk; + + absl::Status Initialize(se::StreamExecutor* executor); + absl::StatusOr GetEvent(se::StreamExecutor* executor); + + private: + absl::Mutex mu_; + absl::node_hash_map events_ + ABSL_GUARDED_BY(mu_); + }; + + // Logging support. + static std::string GetDeviceString( + const Thunk::CollectiveExecuteParams& params); + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + + absl::Status Initialize(const InitializeParams& params) override; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + NcclApi* nccl_api() const { return nccl_api_; } + std::shared_ptr async_events() const { return async_events_; } + void set_async_events(std::shared_ptr async_events) { + async_events_ = async_events; + } + + protected: + virtual absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) = 0; + virtual const NcclCollectiveConfig& config() const = 0; + virtual AsyncStreamKind GetAsyncStreamKind() const { + return AsyncStreamKind::kCollective; + } + + // A collective thunk is normally an independent operation in a sense that + // different instances of the same collective thunk communicate each other. + // The only exception are SendThunk and RecvThunk. Assume two devices are + // executing a program contains the following instructions, the Recv from + // device 1 will release the Send from device 0. Adding first call + // rendezvous on the SendThunk would cause a runtime deadlock. + // Send(src_target={0,1}) + // Recv(src_target={0,1}) + virtual bool NeedFirstCallRendzevous() const { return true; } + + private: + bool IsAsync() const { return async_events_ != nullptr; } + int64_t GetStreamId() const { + return xla::gpu::GetStreamId(IsAsync(), GetAsyncStreamKind()); + } + + NcclApi* nccl_api_; + std::shared_ptr async_events_; + + // After a first call to this particular instance of a NCCL collective thunk + // we do a round of rendezvous to make sure that all participants successfully + // allocated on-device state required for executing collective operation. This + // is required to avoid deadlocks when one device goes too far ahead and + // causes a deadlock in CUDA driver (root cause is mysterious). + // + // TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make + // sure that all NCCL resources are allocated just once. + RendezvousSingleFlag first_call_rendezvous_flag_; +}; + +//===----------------------------------------------------------------------===// +// NcclCollectiveDoneThunk +//===----------------------------------------------------------------------===// + +class NcclCollectiveDoneThunk : public Thunk { + public: + NcclCollectiveDoneThunk( + Thunk::Kind kind, ThunkInfo thunk_info, + std::shared_ptr async_events); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + std::shared_ptr async_events_; +}; + +//===----------------------------------------------------------------------===// + +absl::Status IsValidOperand(mlir::Value operand, Thunk::Kind reduction_op); + +absl::Status IsValidOperand(Shape shape, Thunk::Kind reduction_op); + +template +absl::Status AddOpDescription(absl::Status status, OpT op, + int64_t replica_count, int64_t partition_count) { + if (status.ok()) { + return status; + } + CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op); + + int64_t operand_count = 0; + std::string str; + + if constexpr (std::is_base_of_v>) { + operand_count = op->operand_count(); + str = op->ToString(); + } else { + operand_count = op->getNumOperands() / 2; + str = llvm_ir::DumpToString(op.getOperation()); + } + + return Status( + status.code(), + absl::StrFormat( + "%s\n" + "%s with replica_count: %d, partition_count: %d, group_mode: %s, " + "operand_count: %d\n%s", + status.message(), NcclThunkType::GetHloOpName(), replica_count, + partition_count, CollectiveOpGroupModeToString(group_mode), + operand_count, str)); +} + +//===----------------------------------------------------------------------===// + +size_t GetNumLocalParticipants( + const std::vector& participants, + const std::vector* local_devices); // may be null + +absl::StatusOr GetNcclComm( + const Thunk::CollectiveExecuteParams& params, + const Thunk::CollectiveCliques& collective_cliques, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, int64_t stream_id, + AsyncStreamKind stream_kind); + +struct DeviceBufferPair { + PrimitiveType element_type; + int64_t element_count; + se::DeviceMemoryBase source_buffer; + se::DeviceMemoryBase destination_buffer; + int64_t source_memory_space; + int64_t destination_memory_space; +}; + +absl::StatusOr> ConvertToDeviceBuffers( + const Thunk::ExecuteParams& params, + const std::vector& buffers, + const std::vector& element_types); + +absl::StatusOr> ConvertToDeviceBuffers( + const BufferAllocations* buffer_allocations, + const std::vector& buffers, + const std::vector& element_types); + +// Registers buffers allocated in collective memory (see ncclMemAlloc) with a +// communicator to enable zero-copy collectives. +// +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html +Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, + const std::vector& buffers, + NcclApi::NcclCommHandle comm); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_COLLECTIVE_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc b/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc new file mode 100644 index 0000000000000..30b2d3c4ed988 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc @@ -0,0 +1,179 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::Status ExecutionCounters::Initialize(se::StreamExecutor* executor) { + absl::MutexLock lock(&mu_); + if (counters_.contains(executor)) return absl::OkStatus(); + + counters_.emplace(executor, 0); + return absl::OkStatus(); +} + +absl::StatusOr ExecutionCounters::GetCounter( + se::StreamExecutor* executor) { + absl::MutexLock lock(&mu_); + + auto counter = counters_.find(executor); + if (counter == counters_.end()) { + return absl::InternalError("Execution counter not initialized"); + } + + return &counter->second; +} + +absl::StatusOr>> GetSourceTargetPairs( + mlir::DictionaryAttr frontend_attributes) { + mlir::StringAttr src_dst_string = frontend_attributes.getAs( + kSendRecvSourceTargetPairsAttr); + if (!src_dst_string) { + return absl::AbortedError( + absl::StrCat("expecting send/recv op with string attribute ", + kSendRecvSourceTargetPairsAttr)); + } + TF_ASSIGN_OR_RETURN(std::vector replica_groups, + ParseReplicaGroupsOnly(src_dst_string.str())); + std::vector> source_target_pairs; + source_target_pairs.reserve(replica_groups.size()); + for (const ReplicaGroup& replica_group : replica_groups) { + TF_RET_CHECK(replica_group.replica_ids_size() == 2); + source_target_pairs.emplace_back(replica_group.replica_ids(0), + replica_group.replica_ids(1)); + } + return source_target_pairs; +} + +NcclP2PConfig GetNcclP2PConfigForSendRecv(const HloSendRecvInstruction* instr, + const Shape& shape, + int64_t replica_count, + int64_t partition_count) { + NcclP2PConfig p2p_config; + auto& config = p2p_config.config; + + config.operand_count = 1; + config.operand_element_type.push_back(shape.element_type()); + config.SetCollectiveOpKindAndID(instr); + config.group_mode = GetCollectiveOpGroupMode( + instr->channel_id().value_or(0) > 0, std::nullopt) + .value(); + + // All execution instances of a Send/Recv together form a replica group. + const int64_t num_participants = + config.group_mode == CollectiveOpGroupMode::kCrossReplica + ? replica_count + : partition_count; + config.replica_groups.emplace_back(); + ReplicaGroup& replica_group = config.replica_groups.front(); + for (int i = 0; i < num_participants; ++i) { + replica_group.add_replica_ids(i); + } + + std::optional source_target_pairs_string = + instr->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr); + + // We currently ignore problems related to the source-target-pair string to + // avoid using StatusOr for the return type. This should be ok as Send/Recv + // are generated by the compiler. + if (!source_target_pairs_string.has_value()) { + return p2p_config; + } + auto statusor = ParseReplicaGroupsOnly(*source_target_pairs_string); + if (!statusor.ok()) { + return p2p_config; + } + + std::vector replica_groups = statusor.value(); + auto validation_it = + instr->frontend_attributes().map().find(kSendRecvValidationAttr); + NcclP2PConfig::ValidationKind validation_kind = + NcclP2PConfig::ValidationKind::kValid; + std::vector bounds; + if (validation_it != instr->frontend_attributes().map().end()) { + if (validation_it->second == "invalid") { + validation_kind = NcclP2PConfig::ValidationKind::kInvalid; + } else { + auto statusor_bounds = ParseReplicaGroupsOnly(validation_it->second); + if (!statusor_bounds.ok() || + statusor_bounds.value().size() != replica_groups.size()) { + // Ignore problems related to the source-target-pair string to avoid + // using StatusOr for the return type. + return p2p_config; + } + validation_kind = NcclP2PConfig::ValidationKind::kConditional; + bounds = statusor_bounds.value(); + } + } + + int i = 0; + p2p_config.validation_kind = validation_kind; + NcclP2PConfig::SourceTargetToBounds& source_target_to_bounds = + p2p_config.source_target_to_bounds; + for (const ReplicaGroup& replica_group : replica_groups) { + int64_t source = replica_group.replica_ids(0); + int64_t target = replica_group.replica_ids(1); + + p2p_config.id_to_source_target.insert({target, {}}).first->second.source = + source; + p2p_config.id_to_source_target.insert({source, {}}).first->second.target = + target; + + if (validation_kind == NcclP2PConfig::ValidationKind::kConditional) { + const ReplicaGroup& bound = bounds[i]; + int64_t lower = bound.replica_ids(0); + int64_t upper = bound.replica_ids(1); + source_target_to_bounds[std::make_pair(source, target)] = + std::make_pair(lower, upper); + i++; + } + } + + return p2p_config; +} + +AsyncStreamKind GetStreamKindForSendRecv(const HloSendRecvInstruction* instr) { + auto it = instr->frontend_attributes().map().find(kSendRecvPipelineAttr); + if (it != instr->frontend_attributes().map().end() && it->second == "1") { + return AsyncStreamKind::kP2P1; + } + return AsyncStreamKind::kP2P0; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_p2p_thunk_common.h b/xla/service/gpu/runtime/nccl_p2p_thunk_common.h new file mode 100644 index 0000000000000..0a255fc3012cc --- /dev/null +++ b/xla/service/gpu/runtime/nccl_p2p_thunk_common.h @@ -0,0 +1,106 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_P2P_THUNK_COMMON_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_P2P_THUNK_COMMON_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/shape.h" +#include "xla/stream_executor/stream_executor_pimpl.h" + +namespace xla { +namespace gpu { + +// Count the number of times a Send or Recv instruction executed on a device. +class ExecutionCounters { + public: + absl::Status Initialize(se::StreamExecutor* executor); + absl::StatusOr GetCounter(se::StreamExecutor* executor); + + private: + absl::Mutex mu_; + absl::flat_hash_map counters_ + ABSL_GUARDED_BY(mu_); +}; + +// Records the information for implementing CollectivePermute, Send and Recv. +struct NcclP2PConfig { + // Record the target ID for sending a data and the source ID from which to + // receive a data. Either target or source can be optional. + struct SourceTargetMapEntry { + std::optional source; + std::optional target; + }; + + using IdToSourceTargetMap = + absl::flat_hash_map; + + enum class ValidationKind { kValid = 0, kInvalid = 1, kConditional = 2 }; + + using SourceTargetToBounds = absl::flat_hash_map, + std::pair>; + + // Returns the source and target ID corresponding to the given ID (these IDs + // are replica_ids for cross replica permute or partition_ids for cross + // partition permute). The source ID is the id which will send data to this + // ID and the target ID is the id to which this ID will send its data. Either + // can be optional. + static SourceTargetMapEntry GetSourceTarget( + const IdToSourceTargetMap& id_to_source_target, int64_t id) { + auto it = id_to_source_target.find(id); + if (it != id_to_source_target.end()) return it->second; + return SourceTargetMapEntry{}; + } + + NcclCollectiveConfig config; + IdToSourceTargetMap id_to_source_target; + ValidationKind validation_kind = ValidationKind::kValid; + // When a Send or Recv has validation_kind = ValidationKind::kConditional, + // record the valid execution numbers as a pair of [lower-bound, upper-bound] + // for each source and target pair. + SourceTargetToBounds source_target_to_bounds; +}; + +// Extracts source/target pairs for send/recv from frontend attributes. +absl::StatusOr>> GetSourceTargetPairs( + mlir::DictionaryAttr frontend_attributes); + +// Constructs the NcclP2PConfig for an HLO Send or Recv instruction. +NcclP2PConfig GetNcclP2PConfigForSendRecv(const HloSendRecvInstruction* instr, + const Shape& shape, + int64_t replica_count, + int64_t partition_count); +// Returns the stream kind for the asynchronous stream used to execute an HLO +// Send or Recv instruction, by inspecting the frontend attributes of the +// instruction. +AsyncStreamKind GetStreamKindForSendRecv(const HloSendRecvInstruction* instr); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_P2P_THUNK_COMMON_H_ diff --git a/xla/service/gpu/runtime/nccl_recv_thunk.cc b/xla/service/gpu/runtime/nccl_recv_thunk.cc new file mode 100644 index 0000000000000..a5b82cd2c8fb9 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_recv_thunk.cc @@ -0,0 +1,144 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_recv_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +NcclRecvThunk::NcclRecvThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloRecvInstruction* instr, + int64_t replica_count, int64_t partition_count, + const Buffer& buffer) + : NcclCollectiveThunk(Thunk::kNcclRecv, thunk_info, nccl_api, + /*is_sync=*/false), + config_(GetNcclP2PConfigForSendRecv(instr, instr->shape().tuple_shapes(0), + replica_count, partition_count)), + buffer_(buffer), + stream_kind_(GetStreamKindForSendRecv(instr)), + execution_counters_(config_.validation_kind == + NcclP2PConfig::ValidationKind::kConditional + ? new ExecutionCounters() + : nullptr) {} + +absl::Status NcclRecvThunk::Initialize(const InitializeParams& params) { + TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + if (execution_counters_) { + TF_RETURN_IF_ERROR(execution_counters_->Initialize(params.executor)); + } + return absl::OkStatus(); +} + +absl::Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, {buffer_}, + config_.config.operand_element_type)); + TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; + + GlobalDeviceId global_device_id = params.collective_params->global_device_id; + + TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); + const int64_t current_id = + config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica + ? current_logical_id.replica_id + : current_logical_id.computation_id; + std::string device_string = GetDeviceString(*params.collective_params); + + const NcclP2PConfig::SourceTargetMapEntry source_target = + NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); + DeviceBufferPair& buffer = device_buffers[0]; + + // Determine the source IDs for this instance. The source ID is the ID for + // the peer that will copy its data to this instance. If there is no + // source, just memzero() the destination buffer. + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal + << "current_id " << current_id; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm)); + + const std::optional source_id = source_target.source; + se::DeviceMemoryBase dest_addr = buffer.destination_buffer; + + VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d", device_string, + current_id, source_id.value_or(-1)); + + // Receive data from the source peer to the destination buffer. + if (source_id) { + bool should_run = + config_.validation_kind == NcclP2PConfig::ValidationKind::kInvalid + ? false + : true; + if (config_.validation_kind == + NcclP2PConfig::ValidationKind::kConditional) { + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(int64_t * counter, + execution_counters_->GetCounter(executor)); + auto it = config_.source_target_to_bounds.find( + std::make_pair(*source_target.source, current_id)); + if (it == config_.source_target_to_bounds.end()) { + return absl::InternalError("Missing bounds for conditional Recv"); + } + if (*counter < it->second.first || *counter > it->second.second) { + should_run = false; + } + VLOG(3) << "RunNcclCollective counter " << *counter << " " << should_run; + ++(*counter); + } + if (should_run) { + TF_RETURN_IF_ERROR(nccl_api()->Recv(dest_addr, buffer.element_type, + buffer.element_count, *source_id, + comm, &stream)); + } + + } else { + // If there is no source peer, i.e. no sender to this instance, zero out + // the destination buffer. + VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", + device_string); + TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_recv_thunk.h b/xla/service/gpu/runtime/nccl_recv_thunk.h new file mode 100644 index 0000000000000..9e10a453dfcae --- /dev/null +++ b/xla/service/gpu/runtime/nccl_recv_thunk.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_RECV_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_RECV_THUNK_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +// Thunk that performs a NCCL-recv. +class NcclRecvThunk : public NcclCollectiveThunk { + public: + NcclRecvThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloRecvInstruction* instr, int64_t replica_count, + int64_t partition_count, const Buffer& buffer); + absl::Status Initialize(const InitializeParams& params) override; + + protected: + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + AsyncStreamKind GetAsyncStreamKind() const override { return stream_kind_; } + bool NeedFirstCallRendzevous() const override { return false; } + + private: + const NcclP2PConfig config_; + const Buffer buffer_; + const AsyncStreamKind stream_kind_; + std::shared_ptr execution_counters_; +}; + +absl::Status RunRecv(NcclApi* nccl_api, + NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_RECV_THUNK_H_ diff --git a/xla/service/gpu/runtime/nccl_send_thunk.cc b/xla/service/gpu/runtime/nccl_send_thunk.cc new file mode 100644 index 0000000000000..64a9f55a102bd --- /dev/null +++ b/xla/service/gpu/runtime/nccl_send_thunk.cc @@ -0,0 +1,139 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_send_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +NcclSendThunk::NcclSendThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloSendInstruction* instr, + int64_t replica_count, int64_t partition_count, + const Buffer& buffer) + : NcclCollectiveThunk(Thunk::kNcclSend, thunk_info, nccl_api, + /*is_sync=*/false), + config_(GetNcclP2PConfigForSendRecv(instr, instr->operand(0)->shape(), + replica_count, partition_count)), + buffer_(buffer), + stream_kind_(GetStreamKindForSendRecv(instr)), + execution_counters_(config_.validation_kind == + NcclP2PConfig::ValidationKind::kConditional + ? new ExecutionCounters() + : nullptr) {} + +absl::Status NcclSendThunk::Initialize(const InitializeParams& params) { + TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + if (execution_counters_) { + TF_RETURN_IF_ERROR(execution_counters_->Initialize(params.executor)); + } + return absl::OkStatus(); +} + +absl::Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, {buffer_}, + config_.config.operand_element_type)); + TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; + + GlobalDeviceId global_device_id = params.collective_params->global_device_id; + + TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); + const int64_t current_id = + config_.config.group_mode == CollectiveOpGroupMode::kCrossReplica + ? current_logical_id.replica_id + : current_logical_id.computation_id; + std::string device_string = GetDeviceString(*params.collective_params); + + const NcclP2PConfig::SourceTargetMapEntry source_target = + NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); + DeviceBufferPair& buffer = device_buffers[0]; + + // Determine the target IDs for this instance. The target ID is the ID + // to which this instance will copy its data. + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing collective permute from device ordinal: " + << device_ordinal << "current_id " << current_id; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, comm)); + + const std::optional target_id = source_target.target; + se::DeviceMemoryBase src_addr = buffer.source_buffer; + + VLOG(3) << absl::StreamFormat("%s : id = %d, target_id = %d", device_string, + current_id, target_id.value_or(-1)); + + // Send source buffer to target peer if needed. + if (target_id) { + bool should_run = + config_.validation_kind == NcclP2PConfig::ValidationKind::kInvalid + ? false + : true; + if (config_.validation_kind == + NcclP2PConfig::ValidationKind::kConditional) { + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(int64_t * counter, + execution_counters_->GetCounter(executor)); + auto it = config_.source_target_to_bounds.find( + std::make_pair(current_id, *source_target.target)); + if (it == config_.source_target_to_bounds.end()) { + return absl::InternalError("Missing bounds for conditional Send"); + } + if (*counter < it->second.first || *counter > it->second.second) { + should_run = false; + } + VLOG(3) << "RunNcclCollective counter " << *counter << " " << should_run; + ++(*counter); + } + + if (should_run) { + TF_RETURN_IF_ERROR(nccl_api()->Send(src_addr, buffer.element_type, + buffer.element_count, *target_id, + comm, &stream)); + } + } + + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/nccl_send_thunk.h b/xla/service/gpu/runtime/nccl_send_thunk.h new file mode 100644 index 0000000000000..747df380137e7 --- /dev/null +++ b/xla/service/gpu/runtime/nccl_send_thunk.h @@ -0,0 +1,67 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_SEND_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_SEND_THUNK_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +// Thunk that performs a NCCL-send. +class NcclSendThunk : public NcclCollectiveThunk { + public: + NcclSendThunk(ThunkInfo thunk_info, NcclApi* nccl_api, + const HloSendInstruction* instr, int64_t replica_count, + int64_t partition_count, const Buffer& buffer); + absl::Status Initialize(const InitializeParams& params) override; + + protected: + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + NcclApi::NcclCommHandle comm) override; + AsyncStreamKind GetAsyncStreamKind() const override { return stream_kind_; } + bool NeedFirstCallRendzevous() const override { return false; } + + private: + const NcclP2PConfig config_; + const Buffer buffer_; + const AsyncStreamKind stream_kind_; + std::shared_ptr execution_counters_; +}; + +absl::Status RunSend(NcclApi* nccl_api, + NcclP2PConfig::SourceTargetMapEntry source_target, + DeviceBufferPair& buffer, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::string_view device_string, int64_t current_id); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_SEND_THUNK_H_ diff --git a/xla/service/gpu/runtime/norm.cc b/xla/service/gpu/runtime/norm.cc deleted file mode 100644 index 18e1cc2eeb329..0000000000000 --- a/xla/service/gpu/runtime/norm.cc +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/norm.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/Sequence.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/xla.pb.h" - -namespace xla { - -using xla::runtime::CustomCall; -using xla::runtime::EnumAttrEncoding; -using xla::runtime::FlatMemrefView; -using xla::runtime::State; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; - -namespace lmhlo_gpu = ::mlir::lmhlo_gpu; -namespace gpu { - -struct NormAlgorithmConfig { - int64_t algorithm; - int64_t workspace_size; -}; - -void PopulateNormAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - { // --- Encode `lmhlo_gpu::NormAlgorithmConfigAttr`. - using Attr = mlir::lmhlo_gpu::NormAlgorithmConfigAttr; - encoding - .Add>( - encoding, xla::runtime::AggregateAttrDef() - .Add("algorithm", &Attr::getAlgorithm) - .Add("workspace_size", &Attr::getWorkspaceSize)); - } -} -} // namespace gpu - -namespace runtime { -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::NormAlgorithmConfig, // - AggregateMember("algorithm"), - AggregateMember("workspace_size")); -} // namespace runtime - -namespace gpu { - -void RegisterNormTypeIdNames(runtime::TypeIDNameRegistry& registry) { - registry.Register>( - "__type_id_norm_algorithm_config"); -} - -static GpuNormDescriptor GetGpuNormDescriptor( - StridedMemrefView input, StridedMemrefView scale, StridedMemrefView bias, - StridedMemrefView output, std::optional expectation, - std::optional norm_factor, double epsilon, - NormAlgorithmConfig algorithm_config, - absl::Span operand_layouts) { - GpuNormDescriptor descriptor; - - auto* algorithm = descriptor.backend_config.mutable_algorithm(); - algorithm->set_algo_id(algorithm_config.algorithm); - algorithm->set_is_cudnn_frontend(true); - if (algorithm_config.workspace_size >= 0) { - algorithm->mutable_workspace_size()->set_value( - algorithm_config.workspace_size); - } - - // Apply backend config layout to the shape. - int layout_idx = 0; - auto apply_shape = [&operand_layouts, - &layout_idx](const StridedMemrefView& memref) -> Shape { - std::vector minor_to_major = { - operand_layouts.begin() + layout_idx, - operand_layouts.begin() + layout_idx + memref.sizes.size()}; - layout_idx += memref.sizes.size(); - Shape shape = ToShape(memref); - return ShapeUtil::MakeShapeWithDenseLayout( - shape.element_type(), shape.dimensions(), minor_to_major); - }; - - descriptor.input_shape = apply_shape(input); - descriptor.scale_shape = apply_shape(scale); - descriptor.bias_shape = apply_shape(bias); - descriptor.output_shape = apply_shape(output); - if (expectation) { - descriptor.expectation_shape = apply_shape(*expectation); - } - if (norm_factor) { - descriptor.norm_factor_shape = apply_shape(*norm_factor); - } - - descriptor.backend_config.set_epsilon(epsilon); - - return descriptor; -} - -static absl::Status NormImpl(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - State runner_state, - StridedMemrefView input, StridedMemrefView scale, - StridedMemrefView bias, StridedMemrefView output, - CustomCall::RemainingArgs remaining_args, - int64_t uid, double epsilon, - absl::Span operand_layouts, - NormAlgorithmConfig algorithm_config) { - std::optional expectation, norm_factor; - // Final remaining arg is the scratch space. - if (remaining_args.size() == 3) { - auto expectation_ = remaining_args.get(0); - if (failed(expectation_)) { - return absl::InternalError("Failure while retrieving expectation."); - } - expectation = expectation_.value(); - - auto norm_factor_ = remaining_args.get(1); - if (failed(norm_factor_)) { - return absl::InternalError("Failure while retrieving norm factor."); - } - norm_factor = norm_factor_.value(); - } - - GpuNormDescriptor descriptor = - GetGpuNormDescriptor(input, scale, bias, output, expectation, norm_factor, - epsilon, algorithm_config, operand_layouts); - - auto config = GpuNormConfig::For(descriptor); - if (!config.ok()) { - return tsl::ToAbslStatus(config.status()); - } - auto current_runner = - runner_state.GetOrCreate([&config]() -> absl::StatusOr { - return NormRunnerState(std::move(config.value())); - }); - if (!current_runner.ok()) { - return tsl::ToAbslStatus(current_runner.status()); - } - - se::DeviceMemoryBase input_buffer = GetDeviceAddress(input); - se::DeviceMemoryBase scale_buffer = GetDeviceAddress(scale); - se::DeviceMemoryBase bias_buffer = GetDeviceAddress(bias); - se::DeviceMemoryBase output_buffer = GetDeviceAddress(output); - std::optional expectation_buffer, norm_factor_buffer; - if (expectation) { - expectation_buffer = GetDeviceAddress(expectation.value()); - } - if (norm_factor) { - norm_factor_buffer = GetDeviceAddress(norm_factor.value()); - } - - auto scratch = remaining_args.get(remaining_args.size() - 1); - if (failed(scratch)) { - return absl::InternalError("Failure while retrieving scratch."); - } - se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch.value()); - - RunNormOptions opts; - opts.norm_runner = ¤t_runner.value()->runner; - - // Run the norm. - return RunGpuNorm(current_runner.value()->config, input_buffer, scale_buffer, - bias_buffer, output_buffer, expectation_buffer, - norm_factor_buffer, scratch_buffer, run_options->stream(), - opts); -} - -template -auto BindNormAttributes(runtime::CustomCallBinding binding) { - return std::move(binding) - // Unique convolution id for caching state. - .template Attr("uid") - .template Attr("epsilon") - .template Attr>("operand_layouts") - .template Attr("norm_algorithm_config"); -} - -auto NormCall(const char* name) { - return CustomCall::Bind(name) - .UserData() - .UserData() - .State("uid") - .Arg() // input - .Arg() // scale - .Arg() // bias - .Arg(); // output -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Norm, FunctionWrapper(), checks, - BindNormAttributes(NormCall("xla.gpu.norm").RemainingArgs())); - -void RegisterNormCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.norm", Norm); -} - -StreamExecutorNormRunners* NormRunnerStates::operator()( - se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - return &runners_[executor]; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/norm.h b/xla/service/gpu/runtime/norm.h deleted file mode 100644 index 522180efb2569..0000000000000 --- a/xla/service/gpu/runtime/norm.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_NORM_H_ -#define XLA_SERVICE_GPU_RUNTIME_NORM_H_ - -#include -#include - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/gpu_norm_runner.h" - -namespace xla { -namespace gpu { - -// Registers XLA GPU runtime norm custom calls. -void RegisterNormCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for norm attributes defined by MHLO dialect. -void RegisterNormTypeIdNames(runtime::TypeIDNameRegistry& registry); - -void PopulateNormAlgorithmConfigAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding); - -// State of the norm runners between invocations. -struct NormRunnerState { - explicit NormRunnerState(GpuNormConfig config) - : config(std::move(config)), runner(this->config) {} - GpuNormConfig config; - NormRunner runner; -}; - -class StreamExecutorNormRunners : public runtime::StateVector { -}; - -// XLA executable keeps a mapping from stream executors to norm runners. -class NormRunnerStates { - public: - StreamExecutorNormRunners* operator()(se::StreamExecutor* executor); - - private: - mutable absl::Mutex mutex_; - absl::node_hash_map runners_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_NORM_H_ diff --git a/xla/service/gpu/runtime/norm_thunk.cc b/xla/service/gpu/runtime/norm_thunk.cc new file mode 100644 index 0000000000000..d3862f7bfeac7 --- /dev/null +++ b/xla/service/gpu/runtime/norm_thunk.cc @@ -0,0 +1,110 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/norm_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +NormThunk::NormThunk(ThunkInfo thunk_info, GpuNormConfig config, + BufferAllocation::Slice x_slice, + BufferAllocation::Slice scale_slice, + BufferAllocation::Slice y_or_dx_slice, + std::optional bias_slice, + std::optional expectation_slice, + std::optional norm_factor_slice, + std::optional dy_slice, + std::optional dscale_slice, + std::optional dbias_slice, + BufferAllocation::Slice scratch_slice) + : Thunk(Kind::kNorm, thunk_info), + x_buffer_(x_slice), + scale_buffer_(scale_slice), + y_or_dx_buffer_(y_or_dx_slice), + bias_buffer_(bias_slice), + expectation_buffer_(expectation_slice), + norm_factor_buffer_(norm_factor_slice), + dy_buffer_(dy_slice), + dscale_buffer_(dscale_slice), + dbias_buffer_(dbias_slice), + scratch_buffer_(scratch_slice), + config_(config) {} + +NormRunner& NormThunk::GetOrCreateRunner( + const stream_executor::Stream* stream) { + absl::MutexLock lock(&mu_); + auto it = runner_cache_.find(stream); + if (it == runner_cache_.end()) { + it = runner_cache_.insert({stream, std::make_unique(config_)}) + .first; + } + return *it->second; +} + +absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { + const auto& buffer_allocations = *params.buffer_allocations; + + se::DeviceMemoryBase x_se_buffer = + buffer_allocations.GetDeviceAddress(x_buffer_); + se::DeviceMemoryBase scale_se_buffer = + buffer_allocations.GetDeviceAddress(scale_buffer_); + se::DeviceMemoryBase y_or_dx_se_buffer = + buffer_allocations.GetDeviceAddress(y_or_dx_buffer_); + + std::optional bias_se_buffer, expectation_se_buffer, + norm_factor_se_buffer, dy_se_buffer, dscale_se_buffer, dbias_se_buffer; + if (bias_buffer_) { + bias_se_buffer = buffer_allocations.GetDeviceAddress(bias_buffer_.value()); + } + if (expectation_buffer_) { + expectation_se_buffer = + buffer_allocations.GetDeviceAddress(expectation_buffer_.value()); + norm_factor_se_buffer = + buffer_allocations.GetDeviceAddress(norm_factor_buffer_.value()); + } + if (dscale_buffer_) { + dy_se_buffer = buffer_allocations.GetDeviceAddress(dy_buffer_.value()); + dscale_se_buffer = + buffer_allocations.GetDeviceAddress(dscale_buffer_.value()); + dbias_se_buffer = + buffer_allocations.GetDeviceAddress(dbias_buffer_.value()); + } + + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + RunNormOptions opts; + opts.norm_runner = &GetOrCreateRunner(params.stream); + + TF_RETURN_IF_ERROR(RunGpuNorm( + config_, x_se_buffer, scale_se_buffer, y_or_dx_se_buffer, bias_se_buffer, + dy_se_buffer, expectation_se_buffer, norm_factor_se_buffer, + dscale_se_buffer, dbias_se_buffer, scratch, params.stream, opts)); + + if (!params.stream->ok()) { + return Internal("NormThunk::ExecuteOnStream failed."); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/norm_thunk.h b/xla/service/gpu/runtime/norm_thunk.h new file mode 100644 index 0000000000000..602d504175fb3 --- /dev/null +++ b/xla/service/gpu/runtime/norm_thunk.h @@ -0,0 +1,76 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/gpu_norm_runner.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +class NormThunk : public Thunk { + public: + NormThunk(ThunkInfo thunk_info, GpuNormConfig config, + BufferAllocation::Slice x, BufferAllocation::Slice scale, + BufferAllocation::Slice y_or_dx, + std::optional bias, + std::optional expectation, + std::optional norm_factor, + std::optional dy, + std::optional dscale, + std::optional dbias, + BufferAllocation::Slice scratch); + + NormThunk(const NormThunk&) = delete; + NormThunk& operator=(const NormThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + BufferAllocation::Slice x_buffer_; + BufferAllocation::Slice scale_buffer_; + BufferAllocation::Slice y_or_dx_buffer_; + std::optional bias_buffer_; + std::optional expectation_buffer_; + std::optional norm_factor_buffer_; + std::optional dy_buffer_; + std::optional dscale_buffer_; + std::optional dbias_buffer_; + BufferAllocation::Slice scratch_buffer_; + NormRunner& GetOrCreateRunner(const stream_executor::Stream*); + + GpuNormConfig config_; + absl::Mutex mu_; + absl::flat_hash_map> + runner_cache_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NORM_THUNK_H_ diff --git a/xla/service/gpu/outfeed_thunk.cc b/xla/service/gpu/runtime/outfeed_thunk.cc similarity index 82% rename from xla/service/gpu/outfeed_thunk.cc rename to xla/service/gpu/runtime/outfeed_thunk.cc index ce00e957b82ae..dd3dc2e153dbb 100644 --- a/xla/service/gpu/outfeed_thunk.cc +++ b/xla/service/gpu/runtime/outfeed_thunk.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/outfeed_thunk.h" +#include "xla/service/gpu/runtime/outfeed_thunk.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/literal.h" +#include "absl/status/status.h" #include "xla/service/gpu/outfeed_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" @@ -29,7 +28,7 @@ OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, : Thunk(Kind::kOutfeed, thunk_info), source_slices_(std::move(source_slices)) {} -Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { +absl::Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { se::Stream& stream = *params.stream; const BufferAllocations& buffer_allocations = *params.buffer_allocations; @@ -43,7 +42,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { // Note: Cannot do this before `BlockingGetNextDestination` above to dequeue // an entry from the outfeed manager. if (source_slices_.empty()) { - return OkStatus(); + return absl::OkStatus(); } const int64_t leaf_count = output_buffers->leaf_count(); @@ -74,7 +73,8 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { ++output_leaf_it; const Shape& output_shape = ShapeUtil::GetSubshape(output_buffers->shape(), shape_index); - TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape)) + TF_RET_CHECK( + ShapeUtil::ReshapeIsBitcast(source_slices_[index].shape, output_shape)) << "Mismatch between outfeed output buffer shape " << ShapeUtil::HumanStringWithLayout(output_shape) << " and outfeed source buffer shape " @@ -82,26 +82,25 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { BufferAllocation::Slice source_slice = source_slices_[index].slice; if (!source_slice.allocation()) - return InternalError("outfeed source missing buffer allocation"); + return Internal("outfeed source missing buffer allocation"); se::DeviceMemoryBase data_address = buffer_allocations.GetDeviceAddress(source_slice); // TODO(b/111309141): Run this on a separate stream so it doesn't block // the GPU from doing work during the transfer. - stream - .ThenMemcpy(buffer->destination()->untyped_data(), data_address, - buffer->length()) - .ThenDoHostCallback([&buffer]() { buffer->Done(); }); + TF_RETURN_IF_ERROR(stream.Memcpy(buffer->destination()->untyped_data(), + data_address, buffer->length())); + TF_RETURN_IF_ERROR(stream.DoHostCallback([&buffer]() { buffer->Done(); })); } - Status block_status = stream.BlockHostUntilDone(); + absl::Status block_status = stream.BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("Failed to complete data transfer on stream %p: %s", - &stream, block_status.message()); + return Internal("Failed to complete data transfer on stream %p: %s", + &stream, block_status.message()); } VLOG(2) << "Outfeeding from GPU complete"; - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/outfeed_thunk.h b/xla/service/gpu/runtime/outfeed_thunk.h similarity index 75% rename from xla/service/gpu/outfeed_thunk.h rename to xla/service/gpu/runtime/outfeed_thunk.h index e6abb189a6051..a216431eb0a42 100644 --- a/xla/service/gpu/outfeed_thunk.h +++ b/xla/service/gpu/runtime/outfeed_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_OUTFEED_THUNK_H_ -#define XLA_SERVICE_GPU_OUTFEED_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/runtime/thunk.h" namespace xla { namespace gpu { @@ -36,7 +36,7 @@ class OutfeedThunk : public Thunk { OutfeedThunk(const OutfeedThunk&) = delete; OutfeedThunk& operator=(const OutfeedThunk&) = delete; - Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; private: const std::vector source_slices_; @@ -45,4 +45,4 @@ class OutfeedThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_OUTFEED_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_OUTFEED_THUNK_H_ diff --git a/xla/service/gpu/runtime/replica_id_thunk.cc b/xla/service/gpu/runtime/replica_id_thunk.cc new file mode 100644 index 0000000000000..c563afed6fea1 --- /dev/null +++ b/xla/service/gpu/runtime/replica_id_thunk.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/replica_id_thunk.h" + +#include "absl/status/status.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::Status ReplicaOrPartitionIdThunk::ExecuteOnStream( + const ExecuteParams& params) { + auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_); + + GlobalDeviceId global_device_id = params.collective_params->global_device_id; + TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); + int id = kind() == Kind::kReplicaId ? logical_id.replica_id + : logical_id.computation_id; + return params.stream->Memset32(&dest_addr, id, /*size=*/4); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/replica_id_thunk.h b/xla/service/gpu/runtime/replica_id_thunk.h similarity index 78% rename from xla/service/gpu/replica_id_thunk.h rename to xla/service/gpu/runtime/replica_id_thunk.h index 23f694a0654a8..7b9aa403de1bc 100644 --- a/xla/service/gpu/replica_id_thunk.h +++ b/xla/service/gpu/runtime/replica_id_thunk.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REPLICA_ID_THUNK_H_ -#define XLA_SERVICE_GPU_REPLICA_ID_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ +#include "absl/status/status.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/thunk.h" +#include "xla/service/gpu/runtime/thunk.h" namespace xla { namespace gpu { // Thunk that implements the ReplicaId(Idx == 0) or PartitionId(Idx == 1). class ReplicaOrPartitionIdThunk : public Thunk { - Status ExecuteOnStream(const ExecuteParams& params) override; + public: + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + BufferAllocation::Slice dest() const { return dest_; } protected: ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info, @@ -50,4 +54,4 @@ class PartitionIdThunk : public ReplicaOrPartitionIdThunk { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REPLICA_ID_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME_REPLICA_ID_THUNK_H_ diff --git a/xla/service/gpu/runtime/send_recv.cc b/xla/service/gpu/runtime/send_recv.cc deleted file mode 100644 index 5b70f5322f084..0000000000000 --- a/xla/service/gpu/runtime/send_recv.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/send_recv.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/service_executable_run_options.h" -#include "tsl/concurrency/async_value.h" -#include "tsl/concurrency/async_value_ref.h" -#include "tsl/profiler/lib/traceme.h" -#include "tsl/profiler/lib/traceme_encode.h" - -namespace xla { -namespace gpu { - -using absl::InternalError; -using absl::InvalidArgumentError; -using absl::StrFormat; - -using tsl::AsyncValueRef; -using tsl::profiler::TraceMe; -using tsl::profiler::TraceMeEncode; - -using xla::runtime::AggregateAttrDef; -using xla::runtime::AggregateAttrEncoding; -using xla::runtime::CustomCall; -using xla::runtime::CustomCallAttrEncodingSet; -using xla::runtime::Dictionary; -using xla::runtime::StridedMemrefView; -using xla::runtime::Tagged; -using xla::runtime::TypeIDNameRegistry; - -namespace mhlo = ::mlir::mhlo; - -//===----------------------------------------------------------------------===// -// Structs for encoding send/recv operations attributes. -//===----------------------------------------------------------------------===// - -struct ChannelHandle { - int64_t handle; - int64_t type; -}; - -} // namespace gpu - -//===----------------------------------------------------------------------===// -// Register send/recv attributes decoding with the Xla runtime. -//===----------------------------------------------------------------------===// - -namespace runtime { - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(xla::gpu::ChannelHandle, - AggregateMember("handle"), - AggregateMember("type")); - -} // namespace runtime - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -namespace gpu { - -void RegisterSendRecvTypeIdNames(TypeIDNameRegistry& registry) { - registry.Register>("__type_id_channel_handle"); -} - -//===----------------------------------------------------------------------===// -// Encoding from MHLO attributes to Xla runtime aggregate attributes. -//===----------------------------------------------------------------------===// - -void PopulateSendRecvAttrEncoding(CustomCallAttrEncodingSet& encoding) { - { // --- Encode `mhlo::ChannelHandleAttr`. - using Attr = mhlo::ChannelHandleAttr; - encoding.Add>( - encoding, AggregateAttrDef() - .Add("handle", &Attr::getHandle) - .Add("type", &Attr::getType)); - } -} - -//===----------------------------------------------------------------------===// -// Support for running asynchronous Send/Recv SendDone/RecvDone operations. -//===----------------------------------------------------------------------===// - -absl::Status SendRecvEvents::PushEvent(int32_t handle, - AsyncValueRef event) { - absl::MutexLock lock(&mutex_); - if (auto it = events_.try_emplace(handle, std::move(event)); it.second) - return absl::OkStatus(); - - return InternalError( - StrFormat("Async send/recv event already exists (handle=%d)", handle)); -} - -absl::StatusOr> SendRecvEvents::PopEvent( - int32_t handle) { - absl::MutexLock lock(&mutex_); - if (auto event = events_.extract(handle)) return std::move(event.mapped()); - - return InternalError( - StrFormat("Async send/recv event was not found (handle==%d)", handle)); -} - -//===----------------------------------------------------------------------===// -// Generate a map with frontend attributes. -//===----------------------------------------------------------------------===// - -absl::flat_hash_map GenerateFrontEndAttributeMap( - Dictionary frontend_attrs) { - absl::flat_hash_map frontend_attr_map; - for (std::string_view key : frontend_attrs.keys()) { - auto frontend_attr = frontend_attrs.get(key); - if (mlir::succeeded(frontend_attr)) { - frontend_attr_map.insert({std::string(key), std::string(*frontend_attr)}); - } - } - return frontend_attr_map; -} - -//===----------------------------------------------------------------------===// -// Send/Recv custom call implementation. -//===----------------------------------------------------------------------===// - -static absl::Status SendImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, StridedMemrefView arg, - ChannelHandle channel, Dictionary frontend_attrs) { - VLOG(3) << "Host Send buffer:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.send_host", {{"channel", channel.handle}}); - }); - - // Use device_to_host stream if it is available. - se::Stream* stream = run_options->run_options().device_to_host_stream(); - if (stream) { - stream->ThenWaitFor(run_options->stream()); - } else { - stream = run_options->stream(); - } - - // Send buffer to a handler registered with the run options. - if (auto* send = run_options->run_options().send_device_memory_function()) { - TF_ASSIGN_OR_RETURN( - auto done_event, - (*send)(channel.handle, stream, ToShape(arg), GetDeviceAddress(arg), - GenerateFrontEndAttributeMap(frontend_attrs))); - return events->PushEvent(channel.handle, std::move(done_event)); - } - - return InvalidArgumentError("SendDeviceMemoryFunction is not available"); -} - -static absl::Status RecvImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, StridedMemrefView arg, - ChannelHandle channel, Dictionary frontend_attrs) { - VLOG(3) << "Host Receive buffer:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.recv_host", {{"channel", channel.handle}}); - }); - - // Use host_to_device stream if it is available. - se::Stream* stream = run_options->run_options().host_to_device_stream(); - if (stream) { - stream->ThenWaitFor(run_options->stream()); - } else { - stream = run_options->stream(); - } - - // Recv buffer from a handler registered with the run options. - if (auto* recv = run_options->run_options().recv_device_memory_function()) { - auto dst = GetDeviceAddress(arg); - TF_ASSIGN_OR_RETURN(auto done_event, - (*recv)(channel.handle, stream, ToShape(arg), &dst, - GenerateFrontEndAttributeMap(frontend_attrs))); - return events->PushEvent(channel.handle, std::move(done_event)); - } - - return InvalidArgumentError("RecvDeviceMemoryFunction is not available"); -} - -static absl::Status SendDoneImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, - ChannelHandle channel) { - VLOG(3) << "Wait for Host Send completion:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.send_done_host", - {{"channel", channel.handle}}); - }); - - TF_ASSIGN_OR_RETURN(auto done_event, events->PopEvent(channel.handle)); - - // Wait until send handler will record an event on the stream. - BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); - - VLOG(5) << "Completed Host Send operation: " - << " channel=" << channel.handle; - - // Once event is recorded we can add a stream dependency. - run_options->stream()->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); -} - -static absl::Status RecvDoneImpl(const ServiceExecutableRunOptions* run_options, - SendRecvEvents* events, - ChannelHandle channel) { - VLOG(3) << "Wait for Recv completion:" - << " channel=" << channel.handle; - - TraceMe trace([&] { - return TraceMeEncode("xla.gpu.recv_done_host", - {{"channel", channel.handle}}); - }); - - TF_ASSIGN_OR_RETURN(auto done_event, events->PopEvent(channel.handle)); - - // Wait until send handler will record an event on the stream. - BlockUntilReady(done_event.GetAsyncValue()); - if (done_event.IsError()) return done_event.GetError(); - - VLOG(5) << "Completed Host Recv operation: " - << " channel=" << channel.handle; - - // Once event is recorded we can add a stream dependency. - run_options->stream()->ThenWaitFor(&done_event.get()); - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Send/Recv custom calls bindings and registration. -//===----------------------------------------------------------------------===// - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - SendHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send_host") - .UserData() - .UserData() - .Arg() - .Attr("channel_handle") - .Attr("frontend_attributes")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - RecvHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv_host") - .UserData() - .UserData() - .Arg() - .Attr("channel_handle") - .Attr("frontend_attributes")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - SendDoneHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.send_done_host") - .UserData() - .UserData() - .Attr("channel_handle")); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - RecvDoneHost, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.recv_done_host") - .UserData() - .UserData() - .Attr("channel_handle")); - -//===----------------------------------------------------------------------===// - -// Registers XLA Gpu runtime Host Send/Recv custom calls. -void RegisterSendRecvCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.send_host", SendHost); - registry.Register("xla.gpu.recv_host", RecvHost); - registry.Register("xla.gpu.send_done_host", SendDoneHost); - registry.Register("xla.gpu.recv_done_host", RecvDoneHost); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/send_recv.h b/xla/service/gpu/runtime/send_recv.h deleted file mode 100644 index fcbe1ddcc9f9c..0000000000000 --- a/xla/service/gpu/runtime/send_recv.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ -#define XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ - -#include -#include - -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/stream_executor/event.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime Send/Recv custom calls. -void RegisterSendRecvCustomCalls(runtime::DirectCustomCallRegistry& registry); - -// Register type names for communication attributes defined by MHLO dialect. -void RegisterSendRecvTypeIdNames(runtime::TypeIDNameRegistry& registry); - -// Adds attributes encoding for Send/Recv custom calls -void PopulateSendRecvAttrEncoding(runtime::CustomCallAttrEncodingSet& encoding); - -//===----------------------------------------------------------------------===// -// Support for running asynchronous Send/Recv SendDone/RecvDone operations. -//===----------------------------------------------------------------------===// - -class SendRecvEvents { - public: - absl::Status PushEvent(int32_t handle, tsl::AsyncValueRef event); - absl::StatusOr> PopEvent(int32_t handle); - - private: - absl::Mutex mutex_; - absl::flat_hash_map> events_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_SEND_RECV_H_ diff --git a/xla/service/gpu/runtime/send_recv_thunk.cc b/xla/service/gpu/runtime/send_recv_thunk.cc new file mode 100644 index 0000000000000..ba23806bbbaf1 --- /dev/null +++ b/xla/service/gpu/runtime/send_recv_thunk.cc @@ -0,0 +1,267 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/send_recv_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/concurrency/async_value.h" +#include "tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::gpu { + +using tsl::AsyncValueRef; +using tsl::profiler::TraceMe; +using tsl::profiler::TraceMeEncode; + +// For sharded buffers we should execute Send/Recv operations only on devices +// with maximal sharding, and do nothing on every other device. +static absl::StatusOr ShouldSkip( + std::string_view operation, const Thunk::ExecuteParams& params, + const std::optional& device_constraint) { + if (!device_constraint.has_value()) return false; + + GlobalDeviceId global_device_id = params.collective_params->global_device_id; + bool skip = global_device_id != *device_constraint; + if (skip) { + VLOG(3) << "Skip " << operation << " as device id " << global_device_id + << " doesn't match device id constraint " << *device_constraint; + } + + return skip; +} + +//===----------------------------------------------------------------------===// +// SendRecvAsyncEvents +//===----------------------------------------------------------------------===// + +absl::Status SendRecvAsyncEvents::Emplace(se::StreamExecutor* executor, + int32_t channel_id, + tsl::AsyncValueRef event) { + Key key = {executor, channel_id}; + + absl::MutexLock lock(&mutex_); + if (auto it = events_.try_emplace(key, std::move(event)); it.second) + return absl::OkStatus(); + + return absl::InternalError(absl::StrFormat( + "Async send/recv event already exists (channel_id=%d)", channel_id)); +} + +absl::StatusOr> SendRecvAsyncEvents::Extract( + se::StreamExecutor* executor, int32_t channel_id) { + Key key = {executor, channel_id}; + + absl::MutexLock lock(&mutex_); + if (auto event = events_.extract(key)) return std::move(event.mapped()); + + return absl::InternalError(absl::StrFormat( + "Async send/recv event was not found (channel_id==%d)", channel_id)); +} + +//===----------------------------------------------------------------------===// +// SendThunk +//===----------------------------------------------------------------------===// + +SendThunk::SendThunk( + ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer, + int64_t channel_id, std::shared_ptr events, + absl::flat_hash_map frontend_attrs, + std::optional device_constraint) + : Thunk(Thunk::kSend, thunk_info), + shape_(shape), + buffer_(buffer), + channel_id_(channel_id), + events_(std::move(events)), + frontend_attrs_(std::move(frontend_attrs)), + device_constraint_(device_constraint) {} + +absl::Status SendThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "Send buffer: channel_id=" << channel_id_ + << "; shape=" << shape_.ToString(); + + TF_ASSIGN_OR_RETURN(bool skip, + ShouldSkip("sending buffer", params, device_constraint_)); + if (skip) return absl::OkStatus(); + + TraceMe trace( + [&] { return TraceMeEncode("Send", {{"channel_id", channel_id_}}); }); + + // Use device_to_host stream if it is available. + se::Stream* stream = params.device_to_host_stream; + if (stream) { + TF_RETURN_IF_ERROR(stream->WaitFor(params.stream)); + } else { + stream = params.stream; + } + + se::DeviceMemoryBase src = + params.buffer_allocations->GetDeviceAddress(buffer_); + + // Send buffer to a handler registered with the executable. + if (auto* send = params.send_device_memory_function) { + TF_ASSIGN_OR_RETURN( + AsyncValueRef done, + (*send)(channel_id_, stream, shape_, src, frontend_attrs_)); + return events_->Emplace(stream->parent(), channel_id_, std::move(done)); + } + + return absl::InvalidArgumentError( + "SendDeviceMemoryFunction is not available"); +} + +//===----------------------------------------------------------------------===// +// SendDoneThunk +//===----------------------------------------------------------------------===// + +SendDoneThunk::SendDoneThunk(ThunkInfo thunk_info, int64_t channel_id, + std::shared_ptr events, + std::optional device_constraint) + : Thunk(Thunk::kSend, thunk_info), + channel_id_(channel_id), + events_(std::move(events)), + device_constraint_(device_constraint) {} + +absl::Status SendDoneThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "Wait for send completion: channel_id=" << channel_id_; + + TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for send completion", + params, device_constraint_)); + if (skip) return absl::OkStatus(); + + TraceMe trace( + [&] { return TraceMeEncode("SendDone", {{"channel_id", channel_id_}}); }); + + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(auto done_event, events_->Extract(executor, channel_id_)); + + // Wait until send handler will record an event on the stream. + BlockUntilReady(done_event.GetAsyncValue()); + if (done_event.IsError()) return done_event.GetError(); + + VLOG(5) << "Completed Send operation: channel_id=" << channel_id_; + + // Once event is recorded we can add a stream dependency. + return params.stream->WaitFor(&done_event.get()); +} + +//===----------------------------------------------------------------------===// +// RecvThunk +//===----------------------------------------------------------------------===// + +RecvThunk::RecvThunk( + ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer, + int64_t channel_id, std::shared_ptr events, + absl::flat_hash_map frontend_attrs, + std::optional device_constraint) + : Thunk(Thunk::kSend, thunk_info), + shape_(shape), + buffer_(buffer), + channel_id_(channel_id), + events_(std::move(events)), + frontend_attrs_(std::move(frontend_attrs)), + device_constraint_(device_constraint) {} + +absl::Status RecvThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "Recv buffer: channel_id=" << channel_id_ + << "; shape=" << shape_.ToString(); + + TF_ASSIGN_OR_RETURN( + bool skip, ShouldSkip("receiving buffer", params, device_constraint_)); + if (skip) return absl::OkStatus(); + + TraceMe trace( + [&] { return TraceMeEncode("Recv", {{"channel_id", channel_id_}}); }); + + // Use host_to_device stream if it is available. + se::Stream* stream = params.host_to_device_stream; + if (stream) { + TF_RETURN_IF_ERROR(stream->WaitFor(params.stream)); + } else { + stream = params.stream; + } + + se::DeviceMemoryBase dst = + params.buffer_allocations->GetDeviceAddress(buffer_); + + // Recv buffer from a handler registered with the run options. + if (auto* recv = params.recv_device_memory_function) { + TF_ASSIGN_OR_RETURN( + AsyncValueRef done, + (*recv)(channel_id_, stream, shape_, &dst, frontend_attrs_)); + return events_->Emplace(stream->parent(), channel_id_, std::move(done)); + } + + return absl::InvalidArgumentError( + "RecvDeviceMemoryFunction is not available"); +} + +//===----------------------------------------------------------------------===// +// RecvDoneThunk +//===----------------------------------------------------------------------===// + +RecvDoneThunk::RecvDoneThunk(ThunkInfo thunk_info, int64_t channel_id, + std::shared_ptr events, + std::optional device_constraint) + : Thunk(Thunk::kSend, thunk_info), + channel_id_(channel_id), + events_(std::move(events)) {} + +absl::Status RecvDoneThunk::ExecuteOnStream(const ExecuteParams& params) { + VLOG(3) << "Wait for recv completion: channel_id=" << channel_id_; + + TF_ASSIGN_OR_RETURN(bool skip, ShouldSkip("waiting for recv completion", + params, device_constraint_)); + if (skip) return absl::OkStatus(); + + TraceMe trace( + [&] { return TraceMeEncode("RecvDone", {{"channel_d", channel_id_}}); }); + + se::StreamExecutor* executor = params.stream->parent(); + TF_ASSIGN_OR_RETURN(auto done_event, events_->Extract(executor, channel_id_)); + + // Wait until send handler will record an event on the stream. + BlockUntilReady(done_event.GetAsyncValue()); + if (done_event.IsError()) return done_event.GetError(); + + VLOG(5) << "Completed Recv operation: channel=" << channel_id_; + + // Once event is recorded we can add a stream dependency. + return params.stream->WaitFor(&done_event.get()); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/send_recv_thunk.h b/xla/service/gpu/runtime/send_recv_thunk.h new file mode 100644 index 0000000000000..2c235664e8188 --- /dev/null +++ b/xla/service/gpu/runtime/send_recv_thunk.h @@ -0,0 +1,168 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/async_value_ref.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// SendRecvAsyncEvents +//===----------------------------------------------------------------------===// + +// Send/Recv operations have two levels of async behavior: +// +// (1) AsyncValueRef will become available only after send/recv handler +// schedules all activities on the device. +// +// (2) se::Event will become available when device activity recorded by +// send/recv handlers complete. +// +// We keep track of Send/Recv commands in flight, and synchronize `send` and +// `recv` operations with corresponding `send-done` and `recv-done`. +// +// Each channel can have at most one event in flight for a given executor. +// +// We have a single instance of `SendRecvAsyncEvents` for each Gpu executable, +// and all thunks share it using a shared pointer. +// +// TODO(ezhulenev): Rename to `SendRecvEvents` once we remove deprecated XLA +// runtime, as it has name conflict. +class SendRecvAsyncEvents { + public: + // Emplace a new send/recv completion event. + absl::Status Emplace(se::StreamExecutor* executor, int32_t channel_id, + tsl::AsyncValueRef event); + + // Extract a send/recv completion event. + absl::StatusOr> Extract( + se::StreamExecutor* executor, int32_t channel_id); + + private: + using Key = std::pair; + + absl::Mutex mutex_; + absl::flat_hash_map> events_ + ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// SendThunk +//===----------------------------------------------------------------------===// + +class SendThunk : public Thunk { + public: + SendThunk(ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer, + int64_t channel_id, std::shared_ptr events, + absl::flat_hash_map frontend_attrs, + std::optional device_constraint); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + Shape shape_; + BufferAllocation::Slice buffer_; + + int64_t channel_id_; + + std::shared_ptr events_; + absl::flat_hash_map frontend_attrs_; + std::optional device_constraint_; +}; + +//===----------------------------------------------------------------------===// +// SendDoneThunk +//===----------------------------------------------------------------------===// + +class SendDoneThunk : public Thunk { + public: + SendDoneThunk(ThunkInfo thunk_info, int64_t channel_id, + std::shared_ptr events, + std::optional device_constraint); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + int64_t channel_id_; + + std::shared_ptr events_; + std::optional device_constraint_; +}; + +//===----------------------------------------------------------------------===// +// RecvThunk +//===----------------------------------------------------------------------===// + +class RecvThunk : public Thunk { + public: + RecvThunk(ThunkInfo thunk_info, Shape shape, BufferAllocation::Slice buffer, + int64_t channel_id, std::shared_ptr events, + absl::flat_hash_map frontend_attrs, + std::optional device_constraint); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + Shape shape_; + BufferAllocation::Slice buffer_; + + int64_t channel_id_; + + std::shared_ptr events_; + absl::flat_hash_map frontend_attrs_; + std::optional device_constraint_; +}; + +//===----------------------------------------------------------------------===// +// RecvDoneThunk +//===----------------------------------------------------------------------===// + +class RecvDoneThunk : public Thunk { + public: + RecvDoneThunk(ThunkInfo thunk_info, int64_t channel_id, + std::shared_ptr events, + std::optional device_constraint); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + int64_t channel_id_; + + std::shared_ptr events_; + std::optional device_constraint_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_SEND_RECV_THUNK_H_ diff --git a/xla/service/gpu/runtime/sequential_thunk.cc b/xla/service/gpu/runtime/sequential_thunk.cc new file mode 100644 index 0000000000000..143ad94a29071 --- /dev/null +++ b/xla/service/gpu/runtime/sequential_thunk.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/sequential_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "tsl/platform/errors.h" +#include "tsl/profiler/lib/scoped_annotation.h" + +namespace xla { +namespace gpu { + +using ::tsl::profiler::ScopedAnnotation; + +SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) + : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} + +std::string SequentialThunk::ToStringExtra(int indent) const { + std::string result = "\n"; + absl::StrAppend(&result, thunks().ToString(indent + 1, nullptr)); + return result; +} + +absl::Status SequentialThunk::Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + for (auto& thunk : thunks_) { + TF_RETURN_IF_ERROR(thunk->Prepare(params, resource_requests)); + } + return absl::OkStatus(); +} + +absl::Status SequentialThunk::Initialize(const InitializeParams& params) { + for (auto& thunk : thunks_) { + TF_RETURN_IF_ERROR(thunk->Initialize(params)); + } + return absl::OkStatus(); +} + +absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { + const ModuleAnnotations* annotations = GetCurrentModuleAnnotations(); + for (const auto& thunk : thunks_) { + auto annotation = + GetKernelAnnotation(annotations, thunk->profile_annotation()); + TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/sequential_thunk.h b/xla/service/gpu/runtime/sequential_thunk.h new file mode 100644 index 0000000000000..4642f08becadb --- /dev/null +++ b/xla/service/gpu/runtime/sequential_thunk.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/runtime/thunk.h" + +namespace xla { +namespace gpu { + +// A thunk that wraps a list of sub-thunks. Executing this thunk executes all +// the sub-thunks sequentially. This is useful to implement instructions that +// require multiple kernel launches or library calls. +class SequentialThunk : public Thunk { + public: + SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks); + SequentialThunk(const SequentialThunk&) = delete; + SequentialThunk& operator=(const SequentialThunk&) = delete; + + ThunkSequence& thunks() { return thunks_; } + const ThunkSequence& thunks() const { return thunks_; } + std::string ToStringExtra(int indent) const override; + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + // The list of sub-thunks. + ThunkSequence thunks_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_SEQUENTIAL_THUNK_H_ diff --git a/xla/service/gpu/runtime/sleep_kernel.cu.cc b/xla/service/gpu/runtime/sleep_kernel.cu.cc deleted file mode 100644 index bcdfaeeb09dc9..0000000000000 --- a/xla/service/gpu/runtime/sleep_kernel.cu.cc +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/runtime/sleep_kernel.h" - -namespace xla::gpu { -namespace { - -#if GOOGLE_CUDA -__global__ void mock_nccl_call(unsigned sleep_ns) { -#if __CUDA_ARCH__ >= 700 // __nanosleep requires compute capability 7.0 - // Passing too high a number to __nanosleep makes it sleep for much less time - // than the passed-in number. So only pass 1,000,000 and keep calling - // __nanosleep in a loop. - int n = sleep_ns / 1000000; - unsigned rem = sleep_ns % 1000000; - for (int i = 0; i < n; i++) __nanosleep(1000000U); - __nanosleep(rem); -#endif -} -#elif TENSORFLOW_USE_ROCM -__global__ void mock_nccl_call(unsigned sleep_ns, unsigned clock_rate_khz) { - if (threadIdx.x < warpSize) { - // s_sleep causes a wave to sleep for (64 * SIMM16[6:0] + 1..64) clocks. - uint32_t nclocks = (uint32_t)((float)clock_rate_khz / 64e6 * sleep_ns); - for (uint32_t i = 0; i < nclocks / 64; i++) { - __builtin_amdgcn_s_sleep(64); - } - } - __syncthreads(); -} -#endif // TENSORFLOW_USE_ROCM -} // namespace - -void* GetSleepKernel() { return reinterpret_cast(&mock_nccl_call); } - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/sleep_kernel.h b/xla/service/gpu/runtime/sleep_kernel.h deleted file mode 100644 index 0bcaed7c56cfb..0000000000000 --- a/xla/service/gpu/runtime/sleep_kernel.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RUNTIME_SLEEP_KERNEL_H_ -#define XLA_SERVICE_GPU_RUNTIME_SLEEP_KERNEL_H_ - -namespace xla::gpu { - -void* GetSleepKernel(); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_SLEEP_KERNEL_H_ diff --git a/xla/service/gpu/runtime/stream_synchronization.cc b/xla/service/gpu/runtime/stream_synchronization.cc deleted file mode 100644 index 626acb10bed02..0000000000000 --- a/xla/service/gpu/runtime/stream_synchronization.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/stream_synchronization.h" - -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/concurrent_region.h" -#include "xla/service/gpu/runtime/support.h" - -namespace xla { -namespace gpu { - -static absl::Status AwaitImpl(ConcurrentRegionStatus* region_status, - int64_t from, absl::Span to) { - TF_ASSIGN_OR_RETURN(se::Stream * from_stream, region_status->GetStream(from)); - for (int64_t to_index : to) { - TF_ASSIGN_OR_RETURN(se::Stream * to_stream, - region_status->GetStream(to_index)); - from_stream->ThenWaitFor(to_stream); - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// Define custom calls that mark the concurrent region in CUDA graphs. -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCall; - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(Await, FunctionWrapper(), checks, - CustomCall::Bind("xla.streams.await") - .UserData() - .Attr("from") - .Attr>("to")); - -void RegisterStreamSynchronizationCustomCalls( - runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.streams.await", Await); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/stream_synchronization.h b/xla/service/gpu/runtime/stream_synchronization.h deleted file mode 100644 index 11d3175057c8c..0000000000000 --- a/xla/service/gpu/runtime/stream_synchronization.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ -#define XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace gpu { - -// Registers XLA Gpu runtime stream synchronization custom calls. -void RegisterStreamSynchronizationCustomCalls( - runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_STREAM_SYNCHRONIZATION_H_ diff --git a/xla/service/gpu/runtime/support.cc b/xla/service/gpu/runtime/support.cc deleted file mode 100644 index e1f9753b519a8..0000000000000 --- a/xla/service/gpu/runtime/support.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/support.h" - -#include -#include - -#include "tsl/profiler/lib/scoped_annotation_stack.h" - -namespace xla { -namespace gpu { - -namespace { -static thread_local std::string_view current_tracing_scope = {}; -} // namespace - -void SetCurrentTracingScope(std::string_view scope) { - current_tracing_scope = scope; -} - -void ResetCurrentTracingScope() { current_tracing_scope = std::string_view(); } - -void AppendDiagnosticToString(runtime::DiagnosticEngine& diagnostic_engine, - std::string* diagnostic, - bool append_annotation_stack) { - diagnostic_engine.AddHandler( - [append_annotation_stack, diagnostic](runtime::Diagnostic& d) { - if (!diagnostic->empty()) absl::StrAppend(diagnostic, "; "); - absl::StrAppend(diagnostic, d.status().message()); - - // Append the current trace which should help identifying original HLO - // operation that fails. - if (!current_tracing_scope.empty()) { - absl::StrAppend(diagnostic, - "; current tracing scope: ", current_tracing_scope); - } - - // Append current profiling annotation which will have the XLA - // executable name and program id. - if (append_annotation_stack) { - absl::StrAppend(diagnostic, "; current profiling annotation: ", - tsl::profiler::AnnotationStack::Get()); - } - - LOG(WARNING) << "Intercepted XLA runtime error:\n" - << d.status().ToString( - absl::StatusToStringMode::kWithEverything); - - return runtime::success(); - }); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/support.h b/xla/service/gpu/runtime/support.h deleted file mode 100644 index 98cefce43a7a8..0000000000000 --- a/xla/service/gpu/runtime/support.h +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ -#define XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "llvm/ADT/ArrayRef.h" -#include "xla/mlir/runtime/transforms/custom_call_encoding.h" -#include "xla/runtime/custom_call.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" - -namespace xla { -namespace gpu { - -template -using FunctionWrapper = xla::runtime::CustomCall::FunctionWrapper; - -struct DotDimensionNumbers { - absl::Span lhs_batch; - absl::Span lhs_contract; - absl::Span rhs_batch; - absl::Span rhs_contract; -}; - -// Disable expensive CustomCall checks in optimized build. -inline constexpr runtime::CustomCall::RuntimeChecks checks = // NOLINT -#if defined(NDEBUG) - runtime::CustomCall::RuntimeChecks::kLess; -#else - runtime::CustomCall::RuntimeChecks::kDefault; -#endif - -template -absl::StatusOr ToAbsl(StatusOr status_or) { - if (!status_or.ok()) return status_or.status(); - return std::move(status_or).value(); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::FlatMemrefView& memref) { - return se::DeviceMemoryBase(memref.data, memref.size_in_bytes); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::MemrefView& memref) { - uint64_t size = primitive_util::ByteWidth(memref.dtype); - for (auto dim : memref.sizes) size *= dim; - return se::DeviceMemoryBase(memref.data, size); -} - -inline se::DeviceMemoryBase GetDeviceAddress( - const runtime::StridedMemrefView& memref) { - uint64_t size = primitive_util::ByteWidth(memref.dtype); - for (auto dim : memref.sizes) size *= dim; - if (primitive_util::Is4BitType(memref.dtype)) { - size = (size + 1) / 2; - } - return se::DeviceMemoryBase(memref.data, size); -} - -inline Shape ToShape(const runtime::StridedMemrefView& memref) { - // Recover `minor_to_major` dimensions permutation from strides. - auto indexed_strides_range = - llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) { - return std::pair{pair.value(), pair.index()}; - }); - - auto indexed_strides = llvm::to_vector(indexed_strides_range); - llvm::stable_sort(indexed_strides); - - llvm::SmallVector minor_to_major; - minor_to_major.reserve(indexed_strides.size()); - for (auto& pair : indexed_strides) minor_to_major.push_back(pair.second); - - return ShapeUtil::MakeShapeWithDenseLayout(memref.dtype, memref.sizes, - minor_to_major); -} - -inline StatusOr GetGemmConfig( - const runtime::StridedMemrefView& lhs, - const runtime::StridedMemrefView& rhs, - const runtime::StridedMemrefView& out, int64_t algorithm, double alpha_real, - double alpha_imag, double beta, absl::Span lhs_batch, - absl::Span lhs_contract, absl::Span rhs_batch, - absl::Span rhs_contract, int64_t compute_precision, - const std::optional c = std::nullopt, - const std::optional& bias = std::nullopt, - bool grad_x = false, bool grad_y = false) { - Shape c_shape = ToShape(c.value_or(out)); - Shape bias_shape; - Shape* bias_shape_ptr = nullptr; - if (bias) { - bias_shape = ToShape(*bias); - bias_shape_ptr = &bias_shape; - } - return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), - rhs_batch, rhs_contract, c_shape, bias_shape_ptr, - ToShape(out), alpha_real, alpha_imag, beta, algorithm, - compute_precision, grad_x, grad_y); -} - -// adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt -inline void PopulateDotDimsAttrEncoding( - runtime::CustomCallAttrEncodingSet& encoding) { - using DotDimsAttr = mlir::mhlo::DotDimensionNumbersAttr; - encoding.Add< - xla::runtime::AggregateAttrEncoding>( - encoding, - xla::runtime::AggregateAttrDef() - .Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions) - .Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions) - .Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions) - .Add("rhs_contract", &DotDimsAttr::getRhsContractingDimensions)); -} - -// Appends to `diagnostic_engine` a handler that appends all emitted errors to -// the `diagnostic` string. If `append_annotation_stack` is true, it will append -// current profiler annotation stack to the diagnostic message (annotation used -// in Xprof). -void AppendDiagnosticToString(runtime::DiagnosticEngine& diagnostic_engine, - std::string* diagnostic, - bool append_annotation_stack = false); - -// Sets the current tracing scope that will be added to all emitted diagnostics. -void SetCurrentTracingScope(std::string_view scope); -void ResetCurrentTracingScope(); - -} // namespace gpu -} // namespace xla - -namespace xla { -namespace runtime { - -// using llvm::ArrayRef; - -XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( - xla::gpu::DotDimensionNumbers, - AggregateMember>("lhs_batch"), - AggregateMember>("lhs_contract"), - AggregateMember>("rhs_batch"), - AggregateMember>("rhs_contract")); - -} // namespace runtime -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_SUPPORT_H_ diff --git a/xla/service/gpu/runtime/thunk.cc b/xla/service/gpu/runtime/thunk.cc new file mode 100644 index 0000000000000..88ee7b4b4a25a --- /dev/null +++ b/xla/service/gpu/runtime/thunk.cc @@ -0,0 +1,351 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/thunk.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/stream_executor/stream.h" +#include "xla/translate/mhlo_to_hlo/location_exporter.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +//===----------------------------------------------------------------------===// +// Thunk::CollectiveCliques +//===----------------------------------------------------------------------===// + +Thunk::CollectiveCliques::CollectiveCliques( + NcclClique::AcquiredCliquesMap cliques_map) + : cliques_map_(std::move(cliques_map)) {} + +absl::StatusOr Thunk::CollectiveCliques::GetComm( + const NcclCliqueKey& clique_key, int32_t rank) const { + // Check that we locked access to a clique for `clique_key`. + auto clique = cliques_map_.find(clique_key); + if (clique == cliques_map_.end()) { + return absl::NotFoundError(absl::StrCat("No clique found for clique key: ", + clique_key.ToString())); + } + + // Check that clique has a communicator for our rank. + auto communicator = (*clique->second)->comm(rank); + if (!communicator.has_value()) { + return absl::InternalError(absl::StrCat("Communicator for rank ", rank, + " not found in a NCCL clique ", + clique_key.ToString())); + } + + return *communicator; +} + +absl::StatusOr Thunk::CollectiveCliques::num_communicators( + const NcclCliqueKey& clique_key) const { + // Check that we locked access to a clique for `clique_key`. + auto clique = cliques_map_.find(clique_key); + if (clique == cliques_map_.end()) { + return absl::NotFoundError(absl::StrCat("No clique found for clique key: ", + clique_key.ToString())); + } + + return (*clique->second)->num_communicators(); +} + +//===----------------------------------------------------------------------===// +// Thunk::CollectiveExecuteParams +//===----------------------------------------------------------------------===// + +using GlobalDeviceIdMap = Thunk::CollectiveExecuteParams::GlobalDeviceIdMap; + +// Returns global device id for a local device ordinal or an error if global +// device id map is misconfigured and missing an entry for a local device. +static absl::StatusOr GetGlobalDeviceId( + const GlobalDeviceIdMap* device_id_map, int64_t local_device_ordinal) { + // No local -> global mapping was provided; assume the identity mapping. + if (!device_id_map) return GlobalDeviceId(local_device_ordinal); + + // Find a global device id in a global device id map. + auto it = device_id_map->find(local_device_ordinal); + if (it == device_id_map->end()) + return absl::NotFoundError( + absl::StrCat("No global device id found for local device ordinal: ", + local_device_ordinal)); + + return it->second; +} + +absl::StatusOr +Thunk::CollectiveExecuteParams::Create( + const ServiceExecutableRunOptions& run_options, + int64_t local_device_ordinal, int64_t collective_max_nchannels, + int64_t p2p_max_nchannels) { + const GpuExecutableRunOptions* gpu_options = + run_options.run_options().gpu_executable_run_options(); + + auto* device_id_map = gpu_options && gpu_options->gpu_global_device_ids() + ? &*gpu_options->gpu_global_device_ids() + : nullptr; + + auto* nccl_callback = gpu_options && gpu_options->nccl_clique_id_callback() + ? &gpu_options->nccl_clique_id_callback() + : nullptr; + + TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, + GetGlobalDeviceId(device_id_map, local_device_ordinal)); + + return CollectiveExecuteParams( + run_options.stream()->parent(), run_options.run_options().run_id(), + local_device_ordinal, global_device_id, + run_options.run_options().device_assignment(), device_id_map, + nccl_callback, collective_max_nchannels, p2p_max_nchannels); +} + +Thunk::CollectiveExecuteParams::CollectiveExecuteParams( + se::StreamExecutor* executor, RunId run_id, int64_t local_device_ordinal, + GlobalDeviceId global_device_id, const DeviceAssignment* device_assn, + const GlobalDeviceIdMap* global_device_id_map, + const NcclCliqueIdCallback* nccl_clique_id_callback, + int64_t collective_max_nchannels, int64_t p2p_max_nchannels) + : executor(executor), + run_id(run_id), + local_device_ordinal(local_device_ordinal), + global_device_id(global_device_id), + device_assn(device_assn), + global_device_id_map(global_device_id_map), + nccl_clique_id_callback(nccl_clique_id_callback), + collective_max_nchannels(collective_max_nchannels), + p2p_max_nchannels(p2p_max_nchannels) {} + +//===----------------------------------------------------------------------===// +// Thunk::ExecuteParams +//===----------------------------------------------------------------------===// + +Thunk::ExecuteParams Thunk::ExecuteParams::Create( + const ServiceExecutableRunOptions& run_options, + const BufferAllocations& buffer_allocations, se::Stream* stream, + se::Stream* command_buffer_trace_stream, + absl::Span async_streams, + CollectiveExecuteParams* collective_params, + CollectiveCliques* collective_cliques, + ExecutionStreamIdMap additional_compute_streams) { + return ExecuteParams(&buffer_allocations, stream, command_buffer_trace_stream, + {async_streams.begin(), async_streams.end()}, + collective_params, collective_cliques, + run_options.run_options().device_to_host_stream(), + run_options.run_options().host_to_device_stream(), + run_options.run_options().send_device_memory_function(), + run_options.run_options().recv_device_memory_function(), + additional_compute_streams); +} + +Thunk::ExecuteParams Thunk::ExecuteParams::CloneWithNewAllocations( + const Thunk::ExecuteParams& params, + const BufferAllocations& buffer_allocations) { + return ExecuteParams( + &buffer_allocations, params.stream, params.command_buffer_trace_stream, + {params.async_comms_streams.begin(), params.async_comms_streams.end()}, + params.collective_params, params.collective_cliques, + params.device_to_host_stream, params.host_to_device_stream, + params.send_device_memory_function, params.recv_device_memory_function, + params.additional_compute_streams); +} + +Thunk::ExecuteParams::ExecuteParams( + const BufferAllocations* buffer_allocations, se::Stream* stream, + se::Stream* command_buffer_trace_stream, + absl::InlinedVector async_comms_streams, + CollectiveExecuteParams* collective_params, + CollectiveCliques* collective_cliques, se::Stream* device_to_host_stream, + se::Stream* host_to_device_stream, + SendDeviceMemoryFunction* send_device_memory_function, + RecvDeviceMemoryFunction* recv_device_memory_function, + ExecutionStreamIdMap additional_compute_streams) + : buffer_allocations(buffer_allocations), + stream(stream), + command_buffer_trace_stream(command_buffer_trace_stream), + async_comms_streams(async_comms_streams), + collective_params(collective_params), + collective_cliques(collective_cliques), + device_to_host_stream(device_to_host_stream), + host_to_device_stream(host_to_device_stream), + send_device_memory_function(send_device_memory_function), + recv_device_memory_function(recv_device_memory_function), + additional_compute_streams(additional_compute_streams) {} + +//===----------------------------------------------------------------------===// + +/*static*/ absl::string_view Thunk::KindToString(Thunk::Kind kind) { +#define CASE(x) \ + case Thunk::x: \ + return #x + switch (kind) { + CASE(kAddressComputation); + CASE(kCholesky); + CASE(kCommandBuffer); + CASE(kConditional); + CASE(kConvolution); + CASE(kConvolutionReorder); + CASE(kCopy); + CASE(kCubSort); + CASE(kCublasLtMatmul); + CASE(kCustomCall); + CASE(kCustomKernel); + CASE(kNcclAllGather); + CASE(kNcclAllGatherStart); + CASE(kNcclAllGatherDone); + CASE(kNcclAllReduce); + CASE(kNcclAllReduceStart); + CASE(kNcclAllReduceDone); + CASE(kNcclCollectiveBroadcast); + CASE(kNcclCollectiveBroadcastStart); + CASE(kNcclCollectiveBroadcastDone); + CASE(kNcclCollectivePermute); + CASE(kNcclCollectivePermuteStart); + CASE(kNcclCollectivePermuteDone); + CASE(kNcclReduceScatter); + CASE(kNcclReduceScatterStart); + CASE(kNcclReduceScatterDone); + CASE(kNcclAllToAll); + CASE(kNcclAllToAllStart); + CASE(kNcclAllToAllDone); + CASE(kNcclSend); + CASE(kNcclSendDone); + CASE(kNcclRecv); + CASE(kNcclRecvDone); + CASE(kFft); + CASE(kGemm); + CASE(kInfeed); + CASE(kKernel); + CASE(kMemset32BitValue); + CASE(kMemzero); + CASE(kNorm); + CASE(kOutfeed); + CASE(kSend); + CASE(kSendDone); + CASE(kPartitionId); + CASE(kReplicaId); + CASE(kRecv); + CASE(kRecvDone); + CASE(kSequential); + CASE(kTriangularSolve); + CASE(kWhile); + CASE(kFlashAttn); + CASE(kFusedMHA); + CASE(kWaitForStreams); + CASE(kCuDnn); + } +} + +/*static*/ +absl::StatusOr Thunk::GetStreamForExecution( + ExecutionStreamId stream_id, const ExecuteParams& params) { + if (stream_id == kDefaultExecutionStreamId) { + return params.stream; + } + auto iter = params.additional_compute_streams.find(stream_id); + if (iter == params.additional_compute_streams.end()) { + return absl::InvalidArgumentError("Invalid execution stream id."); + } + return iter->second; +} + +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { + return os << Thunk::KindToString(kind); +} + +std::string ThunkSequence::ToString( + int indent, + std::function get_thunk_annotation) const { + const std::string indent_str(indent * 2, ' '); + if (empty()) return indent_str + "No thunks."; + + auto thunk_with_longest_kind = absl::c_max_element( + *this, + [](const std::unique_ptr& a, const std::unique_ptr& b) { + return Thunk::KindToString(a->kind()).length() < + Thunk::KindToString(b->kind()).length(); + }); + int64_t max_thunk_kind_len = + Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length(); + std::string result; + for (const std::unique_ptr& thunk : *this) { + // Write out the thunk kind, padded out to max_thunk_kind_len. + absl::string_view kind_str = Thunk::KindToString(thunk->kind()); + absl::StrAppend(&result, indent_str, kind_str, + std::string(max_thunk_kind_len - kind_str.length(), ' '), + "\t"); + if (get_thunk_annotation) { + absl::StrAppend(&result, get_thunk_annotation(thunk.get())); + } + absl::StrAppend(&result, thunk->ToStringExtra(indent)); + absl::StrAppend(&result, "\n"); + } + return result; +} + +bool IsReductionCollective(Thunk::Kind kind) { + return kind == Thunk::kNcclAllReduce || kind == Thunk::kNcclAllReduceStart || + kind == Thunk::kNcclReduceScatter || + kind == Thunk::kNcclReduceScatterStart; +} + +Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation(mlir::Operation* op) { + ThunkInfo thunk_info(op); + thunk_info.profile_annotation = + mlir::mhlo::GetDebugNameFromLocation(op->getLoc()); + return thunk_info; +} + +Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation( + const HloInstruction* instr) { + ThunkInfo thunk_info(nullptr); + thunk_info.profile_annotation = instr->name(); + auto gpu_backend_config = instr->backend_config(); + if (gpu_backend_config.ok()) { + thunk_info.execution_stream_id = + std::max(kDefaultExecutionStreamId.value(), + gpu_backend_config->operation_queue_id()); + } + return thunk_info; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/thunk.h b/xla/service/gpu/runtime/thunk.h new file mode 100644 index 0000000000000..4427143a774a3 --- /dev/null +++ b/xla/service/gpu/runtime/thunk.h @@ -0,0 +1,462 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_THUNK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/global_device_id.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/gtl/int_type.h" + +namespace xla { +namespace gpu { + +TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionStreamId, int64_t); + +// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the +// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. +// +// Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable +// to initialize and execute the invocation respectively. Its subclasses are +// supposed to override these interfaces to launch a generated kernel or call an +// external library function (such as operations in cuBLAS). +// +// Thunks have three execution stages: +// +// (1) Prepare: at this stage Thunk can request shared resources required at run +// time, i.e. collective thunks request collective cliques. Executable(s) +// will coordinate resource acquisition. +// +// (2) Initialize: at this stage Thunk must initialize all internal state +// required for execution, maybe using resources requested at prepare stage. +// +// (3) Execute: at this stage Thunk must launch "work" on underlying device +// using given stream, and it's expected that all expensive initialization +// is completed at earlier stages. +// +// This is thread-compatible. Thunk implementation should expect that it will be +// called concurrently from multiple threads, for different run ids and for +// different devices (stream executors). For partitioned XLA programs the +// expectation is that all local participants execute simultaneously on +// different threads and coordinate resource acquisition via rendezvous. +class Thunk { + public: + using ExecutionStreamIdMap = + absl::flat_hash_map; + + // When default execution stream id is used, operations launched by a thunk + // must be synchronized with a stream passed in ExecuteOptions. + static constexpr auto kDefaultExecutionStreamId = ExecutionStreamId(0); + + enum Kind { + kAddressComputation, + kCholesky, + kConditional, + kConvolution, + kConvolutionReorder, + kCopy, + kCommandBuffer, + kCubSort, + kCublasLtMatmul, + kCustomCall, + kCustomKernel, + kFft, + kGemm, + kInfeed, + kKernel, + kMemset32BitValue, + kMemzero, + kNcclAllGather, + kNcclAllGatherStart, + kNcclAllGatherDone, + kNcclAllReduce, + kNcclAllReduceStart, + kNcclAllReduceDone, + kNcclCollectiveBroadcast, + kNcclCollectiveBroadcastStart, + kNcclCollectiveBroadcastDone, + kNcclCollectivePermute, + kNcclCollectivePermuteStart, + kNcclCollectivePermuteDone, + kNcclReduceScatter, + kNcclReduceScatterStart, + kNcclReduceScatterDone, + kNcclAllToAll, + kNcclAllToAllStart, + kNcclAllToAllDone, + kNcclSend, + kNcclSendDone, + kNcclRecv, + kNcclRecvDone, + kNorm, + kOutfeed, + kPartitionId, + kRecv, + kRecvDone, + kReplicaId, + kSequential, + kSend, + kSendDone, + kTriangularSolve, + kWhile, + kFlashAttn, + kFusedMHA, + kWaitForStreams, + kCuDnn + }; + + // . + using BinaryMap = absl::flat_hash_map; + + // TODO(ezhulenev): This should become a part of StreamExecutor library, but + // for now we keep it here as a Thunk implementation detail. It's not yet + // clear what else should become a part of "executable source", we likely + // need to keep some information about available symbols and signatures. + struct ExecutableSource { + std::string_view text; // PTX for NVIDIA backend + absl::Span binary; // CUBIN for NVIDIA backends + BinaryMap dnn_compiled_graphs; + }; + + struct ThunkInfo { + explicit ThunkInfo(mlir::Operation* op) : op(op) {} + static ThunkInfo WithProfileAnnotation(mlir::Operation* op); + static ThunkInfo WithProfileAnnotation(const HloInstruction* instr); + + std::string profile_annotation; + // TODO(b/304613751): This is only needed by the LMHLO. Remove this when + // LMHLO is removed from the runtime pipeline. + mlir::Operation* op; + + ExecutionStreamId execution_stream_id = kDefaultExecutionStreamId; + }; + + //===--------------------------------------------------------------------===// + // ResourceRequests + //===--------------------------------------------------------------------===// + + // Each individual thunk can request various resources required for execution + // at prepare stage. XLA executable is responsible for allocating them before + // initializing and executing thunks. + class ResourceRequests { + public: + virtual ~ResourceRequests() = default; + virtual absl::Status AddClique(const NcclCliqueKey& clique_key, + int32_t num_local_participants) = 0; + }; + + //===--------------------------------------------------------------------===// + // CollectiveCliques + //===--------------------------------------------------------------------===// + + // A collection of collective cliques acquired based on resource requests + // collected from all thunks at prepare stage. + class CollectiveCliques { + public: + CollectiveCliques() = default; + explicit CollectiveCliques(NcclClique::AcquiredCliquesMap cliques_map); + + absl::StatusOr GetComm( + const NcclCliqueKey& clique_key, int32_t rank) const; + + // Returns the number of communicators in a collective clique. Returns error + // if we do not have an acquired clique for a given key. + absl::StatusOr num_communicators( + const NcclCliqueKey& clique_key) const; + + bool empty() const { return cliques_map_.empty(); } + + private: + NcclClique::AcquiredCliquesMap cliques_map_; + }; + + //===--------------------------------------------------------------------===// + // CollectiveExecuteParams + //===--------------------------------------------------------------------===// + + // Parameters capturing all the details required for collective execution of + // XLA executables (multiple partitions and replicas). + struct CollectiveExecuteParams { + // Creates NCCL execution parameters from the run options for the given + // local device. Returns an error if run options are misconfigured (i.e. + // missing a global device mapping for a local device ordinal). + static absl::StatusOr Create( + const ServiceExecutableRunOptions& run_options, + int64_t local_device_ordinal, int64_t collective_max_nchannels = 0, + int64_t p2p_max_nchannels = 0); + + // A mapping from local device ordinals to global device IDs. + using GlobalDeviceIdMap = std::map; + + se::StreamExecutor* executor; + + // XLA execution run id allows us to distinguish collective operations + // from different concurrent executions and avoid deadlocks. + RunId run_id; + + int64_t local_device_ordinal; + GlobalDeviceId global_device_id; + + const DeviceAssignment* device_assn; + const GlobalDeviceIdMap* global_device_id_map; + const NcclCliqueIdCallback* nccl_clique_id_callback; + + int64_t collective_max_nchannels; + int64_t p2p_max_nchannels; + + private: + CollectiveExecuteParams(se::StreamExecutor* executor, RunId run_id, + int64_t local_device_ordinal, + GlobalDeviceId global_device_id, + const DeviceAssignment* device_assn, + const GlobalDeviceIdMap* global_device_id_map, + const NcclCliqueIdCallback* nccl_clique_id_callback, + int64_t collective_max_nchannels, + int64_t p2p_max_nchannels); + }; + + //===--------------------------------------------------------------------===// + // PrepareParams + //===--------------------------------------------------------------------===// + + // Parameters passed to Prepare. At thunk prepare time we do not launch any + // work or do any expensive initialization and only pass resource requirements + // back to executable, i.e. request collective cliques required at run time. + struct PrepareParams { + // Parameters for executing collective operations. + const CollectiveExecuteParams* collective_params = nullptr; + }; + + //===--------------------------------------------------------------------===// + // InitializeParams + //===--------------------------------------------------------------------===// + + // TODO(ezhulenev): Merge InitializeParams and ExecuteParams as they have + // almost the same members and tightly coupled. + + // Parameters passed to Initialize. At thunk initialization time we do not + // launch any "work" on device and only initialize thunks for execution, i.e. + // we pre-load kernels on device and instantiate all command buffers. + struct InitializeParams { + se::StreamExecutor* executor = nullptr; + ExecutableSource src; + + const BufferAllocations* buffer_allocations = nullptr; + + // Main compute stream that will be used, passed via `ExecuteParams` to + // `ExecuteOnStream`. It can be used to initialize on-device "state" (i.e. + // various control structures) at command buffer recording time (we use it + // to initialize NCCL execution plans on device when we trace NCCL + // operations into command buffers); + se::Stream* stream = nullptr; + + // Auxiliary stream for tracing command buffers. We use a separate stream to + // avoid accidental tracing of unrelated activities on a main stream. + se::Stream* command_buffer_trace_stream = nullptr; + + // Parameters for executing collective operations. + CollectiveExecuteParams* collective_params = nullptr; + + // Collective cliques acquired based on resource requests. + CollectiveCliques* collective_cliques = nullptr; + }; + + //===--------------------------------------------------------------------===// + // ExecuteParams + //===--------------------------------------------------------------------===// + + // Parameters passed to ExecuteOnStream. ExecuteOnStream is responsible for + // launching "work" on device, i.e. it launches kernels, executes command + // buffers and calls into libraries (cuBLAS, cuDNN etc.). + struct ExecuteParams { + // Constructs execute parameters from an executable run options. Return + // error if run options are misconfigured. + static ExecuteParams Create( + const ServiceExecutableRunOptions& run_options, + const BufferAllocations& buffer_allocations, se::Stream* stream, + se::Stream* command_buffer_trace_stream, + absl::Span async_streams, + CollectiveExecuteParams* collective_params, + CollectiveCliques* collective_cliques, + ExecutionStreamIdMap additional_compute_streams = {}); + + // Constructs execute parameters from an existing parameters but with + // different buffer allocations. + static ExecuteParams CloneWithNewAllocations( + const ExecuteParams& params, + const BufferAllocations& buffer_allocations); + + const BufferAllocations* buffer_allocations; // never null + + // Main compute stream on which thunks launch operations. + se::Stream* stream; + + // Auxiliary stream for tracing command buffers. We use a separate stream to + // avoid accidental tracing of unrelated activities on a main stream. + se::Stream* command_buffer_trace_stream; + + // Streams for asynchronous collective communications. + // TODO(ezhulenev): Move this into `CollectiveExecuteParams`. + absl::InlinedVector async_comms_streams; + + // Parameters for executing collective operations. + CollectiveExecuteParams* collective_params; + + // Collective cliques acquired based on resource requests. + CollectiveCliques* collective_cliques; + + // Streams for moving data between host and device. + se::Stream* device_to_host_stream; + se::Stream* host_to_device_stream; + + // Send/Recv callbacks passed to XLA from PjRt. + SendDeviceMemoryFunction* send_device_memory_function; + RecvDeviceMemoryFunction* recv_device_memory_function; + + // Additional compute streams on which thunks launch operations. + ExecutionStreamIdMap additional_compute_streams; + + private: + friend class CommandBufferThunk; + + ExecuteParams(const BufferAllocations* buffer_allocations, + se::Stream* stream, se::Stream* command_buffer_trace_stream, + absl::InlinedVector async_comms_streams, + CollectiveExecuteParams* collective_params, + CollectiveCliques* collective_cliques, + se::Stream* device_to_host_stream, + se::Stream* host_to_device_stream, + SendDeviceMemoryFunction* send_device_memory_function, + RecvDeviceMemoryFunction* recv_device_memory_function, + ExecutionStreamIdMap additional_compute_streams = {}); + }; + + //===--------------------------------------------------------------------===// + + // The hlo_instruction argument is meant to be the instruction this thunk was + // generated from, but Thunk never uses this argument other than to save it + // to Thunk::hlo_instruction, so it can be null. + Thunk(Kind kind, ThunkInfo thunk_info) + : kind_(kind), + profile_annotation_(thunk_info.profile_annotation), + op_(thunk_info.op), + execution_stream_id_(thunk_info.execution_stream_id) {} + virtual ~Thunk() = default; + Thunk(const Thunk&) = delete; + Thunk& operator=(const Thunk&) = delete; + + virtual std::string ToStringExtra(int indent) const { return ""; } + Kind kind() const { return kind_; } + std::string_view profile_annotation() const { return profile_annotation_; } + + // Only valid during compilation, i.e., lowering thunks to kernel-launch + // related XLA runtime custom calls). nullptr at runtime. MLIR codegen will + // cease the practice of lowering thunks to XLA runtime custom calls. + mlir::Operation* op() { return op_; } + + // Prepares thunk for execution. + // + // This may be called multiple times. Its main purpose is to pass resource + // requests up to the parent executable so it can acquire them before + // initialization and execution. + virtual absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + return absl::OkStatus(); + } + + // Initializes thunk for execution. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + // + // Precondition: Prepare(initialize_params) has been called. + virtual absl::Status Initialize(const InitializeParams& params) { + return absl::OkStatus(); + } + + // Executes thunk on the given stream. This method must be called after + // Initialize and can be called multiple times over Thunk's lifetime. + // + // Precondition: Initialize(initialize_params) has been called. + virtual absl::Status ExecuteOnStream(const ExecuteParams& params) = 0; + + // Clears metadata that is only valid during compile time. + virtual void ClearCompileTimeInfo() { op_ = nullptr; } + + static absl::string_view KindToString(Thunk::Kind kind); + + ExecutionStreamId execution_stream_id() const { return execution_stream_id_; } + + static absl::StatusOr GetStreamForExecution( + ExecutionStreamId stream_id, const ExecuteParams& params); + + private: + Kind kind_; + std::string profile_annotation_; + mlir::Operation* op_; + ExecutionStreamId execution_stream_id_; +}; + +// A sequence of thunks. +class ThunkSequence : public std::vector> { + public: + std::string ToString(int indent = 0, + std::function + get_thunk_annotation = nullptr) const; +}; + +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); + +// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its +// shape. +struct ShapedSlice { + BufferAllocation::Slice slice; + Shape shape; +}; + +// Returns if the thunk implements a reduction collective (all-reduce or +// reduce-scatter). +bool IsReductionCollective(Thunk::Kind kind); +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_THUNK_H_ diff --git a/xla/service/gpu/runtime/topk.cc b/xla/service/gpu/runtime/topk.cc deleted file mode 100644 index d21d08e98d88f..0000000000000 --- a/xla/service/gpu/runtime/topk.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/topk.h" - -#include - -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/service/gpu/runtime/support.h" -#include "xla/service/gpu/runtime/topk_kernel.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu { -using ::xla::runtime::CustomCall; -using ::xla::runtime::StridedMemrefView; - -static absl::Status TopkImpl(const ServiceExecutableRunOptions* run_options, - StridedMemrefView data, - StridedMemrefView top_elements, - StridedMemrefView indices) { - if (data.sizes.size() > 2) - return absl::InvalidArgumentError("Invalid input shape"); - if (indices.dtype != PrimitiveType::S32) - return absl::InvalidArgumentError("Indices should be S32"); - bool has_batch = data.sizes.size() == 2; - size_t batch_size = has_batch ? data.sizes[0] : 1; - size_t n = has_batch ? data.sizes[1] : data.sizes[0]; - size_t k = has_batch ? top_elements.sizes[1] : top_elements.sizes[0]; - return RunTopk(run_options->stream(), data.dtype, GetDeviceAddress(data), n, - GetDeviceAddress(top_elements), GetDeviceAddress(indices), k, - batch_size); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - Topk, FunctionWrapper(), checks, - CustomCall::Bind("__gpu$TopK") - .UserData() - .Arg() // input - .Arg() // output (values) - .Arg() // output (indices) -); - -void RegisterTopkCustomCall(runtime::DirectCustomCallRegistry& registry) { - registry.Register("__gpu$TopK", Topk); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/topk.h b/xla/service/gpu/runtime/topk.h deleted file mode 100644 index ea669e903dd6f..0000000000000 --- a/xla/service/gpu/runtime/topk.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_TOPK_H_ -#define XLA_SERVICE_GPU_RUNTIME_TOPK_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla::gpu { - -// Registers XLA Gpu runtime TopK custom calls. -void RegisterTopkCustomCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME_TOPK_H_ diff --git a/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc b/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc deleted file mode 100644 index b00d8a373c079..0000000000000 --- a/xla/service/gpu/runtime/topk_kernel_bfloat16.cu.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "Eigen/Core" // from @eigen_archive -#include "xla/service/gpu/runtime/topk_kernel.cu.h" - -namespace xla::gpu { - -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); -template void* GetTopKKernelForK(int n); - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/topk_kernel_test.cc b/xla/service/gpu/runtime/topk_kernel_test.cc deleted file mode 100644 index 9bc2ccff274f8..0000000000000 --- a/xla/service/gpu/runtime/topk_kernel_test.cc +++ /dev/null @@ -1,229 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/topk_kernel.h" - -#include -#include - -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/random/random.h" -#include "absl/strings/substitute.h" -#include "absl/time/time.h" -#include "Eigen/Core" // from @eigen_archive -#include "xla/service/gpu/runtime/gpu_kernel_helper.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/stream.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" - -namespace xla::gpu { -namespace { - -using se::gpu::GpuStreamHandle; -using ::testing::Combine; -using ::testing::Values; - -template -std::vector RandomVecRange(int num_elements, T start, T end) { - std::vector local; - local.reserve(num_elements); - thread_local absl::BitGen gen; - for (int i = 0; i < num_elements; ++i) { - local.push_back(absl::Uniform(gen, start, end)); - } - return local; -} - -template -std::vector RandomVec(int num_elements) { - return RandomVecRange(num_elements, static_cast(0), - static_cast(num_elements)); -} - -template -std::vector RandomVecNegative(int num_elements) { - return RandomVecRange(num_elements, -static_cast(num_elements), - static_cast(0)); -} - -PrimitiveType Get(float) { return PrimitiveType::F32; } -PrimitiveType Get(Eigen::bfloat16) { return PrimitiveType::BF16; } - -// Params: -// - n_kb: number of elements in kilobytes. -// - k: number of elements to return. -// - batch_size -// - offset -using TopkTest = ::testing::TestWithParam>; - -// In this test we only check that the TopK logic works with float. For the full -// dtype coverage suite, please add them to topk_test.cc, where we can use XLA -// utilities to simplify the test logic. -TEST_P(TopkTest, TopKFloat) { - using T = float; - - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); - se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - const auto [n_kb, k, batch_size, offset] = GetParam(); - const size_t n = n_kb * 1024 + offset; - - se::DeviceMemory input_buffer = - executor->AllocateArray(n * batch_size, 0); - se::DeviceMemory output_values = - executor->AllocateArray(k * batch_size, 0); - se::DeviceMemory output_indices = - executor->AllocateArray(k * batch_size, 0); - - auto source = RandomVec(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - - ASSERT_TRUE(RunTopk(&stream, Get(T()), input_buffer, n, output_values, - output_indices, k, batch_size) - .ok()); - std::vector got(k); - ASSERT_TRUE(stream.BlockHostUntilDone().ok()); - for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), - executor->GetSubBuffer(&output_values, k * i, k), - k * sizeof(T)); - std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); - std::sort(slice.begin(), slice.end(), std::greater()); - slice.resize(k); - EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) - << " k=" << k << ", batch_size=" << batch_size << " i=" << i; - } -} - -TEST_P(TopkTest, TopKPackedNegative) { - using T = float; - - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); - se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - const auto [n_kb, k, batch_size, offset] = GetParam(); - const size_t n = n_kb * 1024 + offset; - - se::DeviceMemory input_buffer = - executor->AllocateArray(n * batch_size, 0); - se::DeviceMemory output_values = - executor->AllocateArray(k * batch_size, 0); - se::DeviceMemory output_indices = - executor->AllocateArray(k * batch_size, 0); - - auto source = RandomVecNegative(n * batch_size); - stream.ThenMemcpy(&input_buffer, source.data(), n * batch_size * sizeof(T)); - - ASSERT_TRUE(RunTopk(&stream, Get(T()), input_buffer, n, output_values, - output_indices, k, batch_size) - .ok()); - std::vector got(k); - ASSERT_TRUE(stream.BlockHostUntilDone().ok()); - for (int i = 0; i < batch_size; i++) { - stream.ThenMemcpy(got.data(), - executor->GetSubBuffer(&output_values, k * i, k), - k * sizeof(T)); - std::vector slice(source.data() + n * i, source.data() + n * (i + 1)); - std::sort(slice.begin(), slice.end(), std::greater()); - slice.resize(k); - EXPECT_THAT(got, ::testing::ElementsAreArray(slice)) - << " k=" << k << ", batch_size=" << batch_size << " i=" << i; - } -} - -INSTANTIATE_TEST_SUITE_P(TopkTests, TopkTest, - Combine( - /*n_kb=*/Values(1, 8, 12, 64, 128), - /*k=*/Values(1, 2, 8, 16, 7, 12), - /*batch_size=*/Values(1, 16, 64, 128), - /*offset=*/Values(0, 7, 4)), - [](const auto& info) { - return absl::Substitute( - "n$0KiB_k$1_batch_size$2_offset$3", - std::get<0>(info.param), std::get<1>(info.param), - std::get<2>(info.param), - std::get<3>(info.param)); - }); - -template -void BM_SmallTopk(benchmark::State& state) { - using T = float; - - size_t k = K; - size_t batch_size = state.range(0); - size_t n = state.range(1) * 1024; - state.SetLabel( - absl::Substitute("n=$0Ki k=$1 batch_size=$2", n / 1024, k, batch_size)); - - se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("CUDA").value(); - se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - se::DeviceMemory input_buffer = - executor->AllocateArray(n * batch_size, 0); - se::DeviceMemory output_values = executor->AllocateArray(k, 0); - se::DeviceMemory output_indices = - executor->AllocateArray(k, 0); - - auto source = RandomVec(n); - stream.ThenMemcpy(&input_buffer, source.data(), n * sizeof(T)); - - for (auto _ : state) { - auto timer = se::gpu::GpuTimer::Create(se::gpu::AsGpuStream(&stream)); - CHECK_OK(timer.status()); - CHECK_OK(RunTopk(&stream, Get(T()), input_buffer, n, output_values, - output_indices, k, batch_size)); - CHECK_OK(stream.BlockHostUntilDone()); - auto timer_duration = timer.value().GetElapsedDuration(); - CHECK_OK(timer_duration.status()); - state.SetIterationTime(ToDoubleMicroseconds(timer_duration.value())); - } - size_t items_processed = batch_size * n * state.iterations(); - state.SetItemsProcessed(items_processed); - state.SetBytesProcessed(items_processed * sizeof(T)); -} - -BENCHMARK(BM_SmallTopk<1>)->RangePair(1, 512, 16, 1024)->UseManualTime(); -BENCHMARK(BM_SmallTopk<2>)->RangePair(1, 512, 16, 1024)->UseManualTime(); -BENCHMARK(BM_SmallTopk<4>)->RangePair(1, 512, 16, 1024)->UseManualTime(); -BENCHMARK(BM_SmallTopk<8>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); -BENCHMARK(BM_SmallTopk<16>)->RangePair(1, 1024, 16, 1024)->UseManualTime(); - -} // namespace -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/tracing.cc b/xla/service/gpu/runtime/tracing.cc deleted file mode 100644 index 49767a85500ba..0000000000000 --- a/xla/service/gpu/runtime/tracing.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/tracing.h" - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/tracing.h" -#include "xla/service/gpu/runtime/support.h" -#include "tsl/profiler/lib/scoped_annotation_stack.h" - -namespace xla { -namespace gpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::HloTrace; - -using ::tsl::profiler::ScopedAnnotationStack; - -//===----------------------------------------------------------------------===// -// Type names for encoded attributes. -//===----------------------------------------------------------------------===// - -void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry) { - runtime::PopulateTraceTypeIdNames(registry); -} - -//===----------------------------------------------------------------------===// -// Tracing custom calls implementation. -//===----------------------------------------------------------------------===// - -static absl::StatusOr ActivityStart(runtime::HloTrace annotation) { - SetCurrentTracingScope(annotation.hlo_op); - return ScopedAnnotationStack::ActivityStart([&] { - // We use the same tracing annotation scheme as the ThunkSequence (see - // implementation of `GetThunkInfo` in `ir_emitter_unnested.cc`). - return absl::StrFormat("Thunk:#hlo_op=%s#", annotation.hlo_op); - }); -} - -static absl::Status ActivityEnd(int64_t activity_id) { - ResetCurrentTracingScope(); - ScopedAnnotationStack::ActivityEnd(activity_id); - return absl::OkStatus(); -} - -XLA_RUNTIME_DEFINE_CUSTOM_CALL(Start, FunctionWrapper(), checks, - CustomCall::Bind("xla.trace.activity_start") - .Attr("annotation") - .Ret()); - -XLA_RUNTIME_DEFINE_CUSTOM_CALL( - End, FunctionWrapper(), checks, - CustomCall::Bind("xla.trace.activity_end").Arg()); - -void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.trace.activity_start", Start); - registry.Register("xla.trace.activity_end", End); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/tracing.h b/xla/service/gpu/runtime/tracing.h deleted file mode 100644 index 7f5efe48accac..0000000000000 --- a/xla/service/gpu/runtime/tracing.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_TRACING_H_ -#define XLA_SERVICE_GPU_RUNTIME_TRACING_H_ - -#include - -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/type_id.h" - -namespace xla { -namespace gpu { - -void RegisterTracingTypeIdNames(runtime::TypeIDNameRegistry& registry); - -void RegisterTracingCustomCalls(runtime::DirectCustomCallRegistry& registry); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_TRACING_H_ diff --git a/xla/service/gpu/runtime/triangular_solve.cc b/xla/service/gpu/runtime/triangular_solve.cc deleted file mode 100644 index ccd572cce0785..0000000000000 --- a/xla/service/gpu/runtime/triangular_solve.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/triangular_solve.h" - -#include -#include -#include - -#include "xla/runtime/custom_call.h" -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/runtime/support.h" -#include "tsl/platform/human_readable_json.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace xla { -namespace gpu { - -using xla::runtime::CustomCall; - -using mlir::failure; -using mlir::FailureOr; - -absl::Status TriangularSolve::run( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, CustomCall::RemainingArgs args, - std::string_view backend_config) { - TriangularSolve handler = TriangularSolve::Handler(); - - if (args.size() != 4) - return absl::InvalidArgumentError( - absl::StrFormat("Expected 4 arguments, got %d", args.size())); - - // Check if all arguments have the correct type. - auto a = args.get(0); - auto b = args.get(1); - auto result = args.get(2); - auto temp = args.get(3); - if (failed(a) || failed(b) || failed(result) || failed(temp)) - return absl::InvalidArgumentError("Incorrect argument types"); - - // Parse backend config string. - TriangularSolveOptions opts; - - const std::string backend_config_str = - std::string(backend_config.data(), backend_config.length()); - - TF_RETURN_IF_ERROR(tsl::HumanReadableJsonToProto(backend_config_str, &opts)); - - return handler(run_options, debug_options, *a, *b, *result, *temp, - opts.left_side(), opts.lower(), opts.unit_diagonal(), - opts.transpose_a()); -} - -absl::Status TriangularSolve::operator()( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, runtime::StridedMemrefView a, - runtime::StridedMemrefView b, runtime::StridedMemrefView result, - runtime::FlatMemrefView temp, bool left_side, bool lower, - bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::Stream* stream = run_options->stream(); - - se::DeviceMemoryBase a_data = GetDeviceAddress(a); - se::DeviceMemoryBase b_data = GetDeviceAddress(b); - se::DeviceMemoryBase result_data = GetDeviceAddress(result); - se::DeviceMemoryBase temp_data = GetDeviceAddress(temp); - - // Triangular solve is in-place on 'b', so copy 'b' to the output if they - // aren't the same buffer. - if (b.data != result.data) - stream->ThenMemcpy(&result_data, b_data, b_data.size()); - - Shape b_shape = ToShape(b); - int64_t m = b_shape.dimensions(b_shape.rank() - 2); - int64_t n = b_shape.dimensions(b_shape.rank() - 1); - int64_t batch_size = std::accumulate( - b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1}, - [](int64_t a, int64_t b) { return a * b; }); - - PrimitiveType elem_type = b.dtype; - int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_type); - int64_t a_batch_stride = left_side ? m * m * elem_size : n * n * elem_size; - int64_t b_batch_stride = m * n * elem_size; - - using Side = se::blas::Side; - using Diagonal = se::blas::Diagonal; - using Transpose = se::blas::Transpose; - using UpperLower = se::blas::UpperLower; - - // Convert custom call attributes to se::blas enums. - UpperLower uplo = lower ? UpperLower::kLower : UpperLower::kUpper; - Side side = left_side ? Side::kLeft : Side::kRight; - Diagonal diagonal = unit_diagonal ? Diagonal::kUnit : Diagonal::kNonUnit; - - auto transpose = [&]() -> mlir::FailureOr { - switch (transpose_a) { - case TriangularSolveOptions::NO_TRANSPOSE: - return se::blas::Transpose::kNoTranspose; - case TriangularSolveOptions::TRANSPOSE: - return se::blas::Transpose::kTranspose; - case TriangularSolveOptions::ADJOINT: - return se::blas::Transpose::kConjugateTranspose; - default: - return failure(); - } - }(); - - if (failed(transpose)) - return absl::InternalError("Failed to convert transpose type"); - - return RunTriangularSolve(a_data, result_data, temp_data, - PtxOptsFromDebugOptions(*debug_options), uplo, side, - diagonal, *transpose, elem_type, batch_size, m, n, - a_batch_stride, b_batch_stride, stream); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return absl::InternalError("Not implemented without Gpu"); -#endif -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime/triangular_solve.h b/xla/service/gpu/runtime/triangular_solve.h deleted file mode 100644 index 947633ed2d079..0000000000000 --- a/xla/service/gpu/runtime/triangular_solve.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ -#define XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ - -#include - -#include "xla/runtime/custom_call.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -using runtime::CustomCall; - -struct TriangularSolve { - // Adaptor from XlaCustomCall API to properly typed TriangularSolve handler. - static absl::Status run(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - CustomCall::RemainingArgs args, - std::string_view backend_config); - - absl::Status operator()(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, - runtime::StridedMemrefView a, - runtime::StridedMemrefView b, - runtime::StridedMemrefView result, - runtime::FlatMemrefView temp, bool left_side, - bool lower, bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a) const; - - static TriangularSolve Handler() { return TriangularSolve(); } -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_H_ diff --git a/xla/service/gpu/runtime/triangular_solve_thunk.cc b/xla/service/gpu/runtime/triangular_solve_thunk.cc new file mode 100644 index 0000000000000..8be15a1846e12 --- /dev/null +++ b/xla/service/gpu/runtime/triangular_solve_thunk.cc @@ -0,0 +1,206 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/triangular_solve_thunk.h" + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "xla/service/gpu/make_batch_pointers.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace gpu { + +TriangularSolveThunk::TriangularSolveThunk( + ThunkInfo thunk_info, const TriangularSolveOptions& options, + se::GpuAsmOpts asm_opts, // + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, + const BufferAllocation::Slice& temp_buffer, // + PrimitiveType type, int64_t batch_size, int64_t m, int64_t n, + int64_t a_batch_stride, int64_t b_batch_stride) + : Thunk(Kind::kTriangularSolve, thunk_info), + asm_opts_(asm_opts), + uplo_(options.lower() ? se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper), + side_(options.left_side() ? se::blas::Side::kLeft + : se::blas::Side::kRight), + unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit + : se::blas::Diagonal::kNonUnit), + a_buffer_(a_buffer), + b_buffer_(b_buffer), + temp_buffer_(temp_buffer), + type_(type), + batch_size_(batch_size), + m_(m), + n_(n), + a_batch_stride_(a_batch_stride), + b_batch_stride_(b_batch_stride) { + transpose_a_ = [&] { + switch (options.transpose_a()) { + case TriangularSolveOptions::NO_TRANSPOSE: + return se::blas::Transpose::kNoTranspose; + case TriangularSolveOptions::TRANSPOSE: + return se::blas::Transpose::kTranspose; + case TriangularSolveOptions::ADJOINT: + return se::blas::Transpose::kConjugateTranspose; + default: + LOG(ERROR) << "Invalid triangular solve transpose value " + << options.transpose_a(); + return se::blas::Transpose::kNoTranspose; + } + }(); +} + +absl::Status TriangularSolveThunk::ExecuteOnStream( + const ExecuteParams& params) { + auto& buffer_allocations = *params.buffer_allocations; + return RunTriangularSolve(buffer_allocations.GetDeviceAddress(a_buffer_), + buffer_allocations.GetDeviceAddress(b_buffer_), + buffer_allocations.GetDeviceAddress(temp_buffer_), + asm_opts_, uplo_, side_, unit_diagonal_, + transpose_a_, type_, batch_size_, m_, n_, + a_batch_stride_, b_batch_stride_, params.stream); +} + +absl::Status RunTriangularSolve( + se::DeviceMemoryBase a_data, se::DeviceMemoryBase b_data, + se::DeviceMemoryBase temp_data, se::GpuAsmOpts asm_opts, + se::blas::UpperLower uplo, se::blas::Side side, + se::blas::Diagonal unit_diagonal, se::blas::Transpose transpose_a, + PrimitiveType type, int64_t batch_size, int64_t m, int64_t n, + int64_t a_batch_stride, int64_t b_batch_stride, se::Stream* stream) { + VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo) + << " side=" << se::blas::SideString(side) + << " diagonal=" << se::blas::DiagonalString(unit_diagonal) + << " batch_size=" << batch_size << " m=" << m << " n=" << n + << " a_batch_stride=" << a_batch_stride + << " b_batch_stride=" << b_batch_stride; + + const int lda = side == se::blas::Side::kLeft ? m : n; + const int ldb = m; + + auto blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No BLAS support in stream."); + } + bool launch_ok; + if (batch_size == 1) { + switch (type) { + case F32: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_data), lda, &b_data_typed, + ldb); + break; + } + case F64: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0, se::DeviceMemory(a_data), lda, &b_data_typed, + ldb); + break; + } + case C64: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory>(a_data), lda, + &b_data_typed, ldb); + break; + } + case C128: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = blas->DoBlasTrsm( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0, se::DeviceMemory>(a_data), lda, + &b_data_typed, ldb); + break; + } + default: + return InvalidArgument("Invalid type for triangular solve %d", type); + } + } else { + // cublas trsmBatched requires us to materialize out two arrays of + // batch_size_ pointers, pointing to the individual `a` and `b` matrices of + // our input. batch_pointers_bytes is the size in bytes of one of these + // arrays. + int64_t batch_pointers_bytes = sizeof(void*) * batch_size; + TF_RET_CHECK(temp_data.size() >= 2 * batch_pointers_bytes); + void** temp_base = reinterpret_cast(temp_data.opaque()); + se::DeviceMemoryBase a_pointers(temp_base, batch_pointers_bytes); + se::DeviceMemoryBase b_pointers(temp_base + batch_size, + batch_pointers_bytes); + + TF_RETURN_IF_ERROR(MakeBatchPointers(stream, a_data, a_batch_stride, + batch_size, a_pointers)); + TF_RETURN_IF_ERROR(MakeBatchPointers(stream, b_data, b_batch_stride, + batch_size, b_pointers)); + + switch (type) { + case F32: { + se::DeviceMemory typed_b_pointers(b_pointers); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_pointers), lda, + &typed_b_pointers, ldb, batch_size); + break; + } + case F64: { + se::DeviceMemory typed_b_pointers(b_pointers); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory(a_pointers), lda, + &typed_b_pointers, ldb, batch_size); + break; + } + case C64: { + se::DeviceMemory*> typed_b_pointers(b_pointers); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory*>(a_pointers), + lda, &typed_b_pointers, ldb, batch_size); + break; + } + case C128: { + se::DeviceMemory*> typed_b_pointers(b_pointers); + launch_ok = blas->DoBlasTrsmBatched( + stream, side, uplo, transpose_a, unit_diagonal, m, n, + /*alpha=*/1.0f, se::DeviceMemory*>(a_pointers), + lda, &typed_b_pointers, ldb, batch_size); + break; + } + default: + return InvalidArgument("Invalid type for triangular solve %d", type); + } + } + + if (!launch_ok) { + return Internal("Unable to launch triangular solve"); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/triangular_solve_thunk.h b/xla/service/gpu/runtime/triangular_solve_thunk.h new file mode 100644 index 0000000000000..83b97ae243819 --- /dev/null +++ b/xla/service/gpu/runtime/triangular_solve_thunk.h @@ -0,0 +1,84 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a triangular +// solve (BlasTrsm). It is generated by IrEmitter. +// +// Thread-compatible. +class TriangularSolveThunk : public Thunk { + public: + TriangularSolveThunk(ThunkInfo thunk_info, + const TriangularSolveOptions& options, + se::GpuAsmOpts asm_opts, + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, + const BufferAllocation::Slice& temp_buffer, + PrimitiveType type, int64_t batch_size, int64_t m, + int64_t n, int64_t a_batch_stride, + int64_t b_batch_stride); + + TriangularSolveThunk(const TriangularSolveThunk&) = delete; + TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + se::GpuAsmOpts asm_opts_; + const se::blas::UpperLower uplo_; + const se::blas::Side side_; + const se::blas::Diagonal unit_diagonal_; + se::blas::Transpose transpose_a_; + + const BufferAllocation::Slice a_buffer_; + const BufferAllocation::Slice b_buffer_; + const BufferAllocation::Slice temp_buffer_; + + const PrimitiveType type_; + const int64_t batch_size_; + const int64_t m_; + const int64_t n_; + const int64_t a_batch_stride_; + const int64_t b_batch_stride_; +}; + +absl::Status RunTriangularSolve( + se::DeviceMemoryBase a_data, se::DeviceMemoryBase b_data, + se::DeviceMemoryBase temp_data, se::GpuAsmOpts asm_opts, + se::blas::UpperLower uplo, se::blas::Side side, + se::blas::Diagonal unit_diagonal, se::blas::Transpose transpose_a, + PrimitiveType type, int64_t batch_size, int64_t m, int64_t n, + int64_t a_batch_stride, int64_t b_batch_stride, se::Stream* stream); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_TRIANGULAR_SOLVE_THUNK_H_ diff --git a/xla/service/gpu/runtime/wait_for_streams_thunk.cc b/xla/service/gpu/runtime/wait_for_streams_thunk.cc new file mode 100644 index 0000000000000..2bd961264ee12 --- /dev/null +++ b/xla/service/gpu/runtime/wait_for_streams_thunk.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/wait_for_streams_thunk.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +absl::Status WaitForStreamsThunk::ExecuteOnStream(const ExecuteParams& params) { + TF_ASSIGN_OR_RETURN(se::Stream * stream, + Thunk::GetStreamForExecution(stream_id_, params)); + + VLOG(5) << "Waiting for stream ids: " + << absl::StrJoin( + wait_for_stream_ids_, ", ", + [&](std::string* s, const ExecutionStreamId& stream_id) { + absl::StrAppend(s, stream_id.value()); + }); + for (const auto& stream_id : wait_for_stream_ids_) { + TF_ASSIGN_OR_RETURN(se::Stream * wait_on_stream, + Thunk::GetStreamForExecution(stream_id, params)); + + TF_RETURN_IF_ERROR(stream->WaitFor(wait_on_stream)); + } + return absl::OkStatus(); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/runtime/wait_for_streams_thunk.h b/xla/service/gpu/runtime/wait_for_streams_thunk.h new file mode 100644 index 0000000000000..2a545e0173865 --- /dev/null +++ b/xla/service/gpu/runtime/wait_for_streams_thunk.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ + +#include + +#include "absl/status/status.h" +#include "xla/service/gpu/runtime/thunk.h" + +namespace xla::gpu { + +// This thunk +class WaitForStreamsThunk : public Thunk { + public: + WaitForStreamsThunk(ThunkInfo thunk_info, ExecutionStreamId stream_id, + std::vector wait_for_stream_ids) + : Thunk(Kind::kWaitForStreams, thunk_info), + stream_id_(stream_id), + wait_for_stream_ids_(wait_for_stream_ids){}; + + WaitForStreamsThunk(const WaitForStreamsThunk&) = delete; + WaitForStreamsThunk& operator=(const WaitForStreamsThunk&) = delete; + + const ExecutionStreamId& stream_id() const { return stream_id_; } + + const std::vector& wait_for_stream_ids() const { + return wait_for_stream_ids_; + } + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + ExecutionStreamId stream_id_; + std::vector wait_for_stream_ids_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_RUNTIME_WAIT_FOR_STREAMS_THUNK_H_ diff --git a/xla/service/gpu/runtime/while_thunk.cc b/xla/service/gpu/runtime/while_thunk.cc new file mode 100644 index 0000000000000..cbd9ce3134133 --- /dev/null +++ b/xla/service/gpu/runtime/while_thunk.cc @@ -0,0 +1,125 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/while_thunk.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +WhileThunk::WhileThunk( + ThunkInfo thunk_info, + const BufferAllocation::Slice& condition_result_buffer_index, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, + std::optional trip_count) + : Thunk(Kind::kWhile, thunk_info), + condition_result_buffer_index_(condition_result_buffer_index), + condition_thunk_sequence_(std::make_unique( + ThunkInfo(thunk_info.op), std::move(*condition_thunk_sequence))), + body_thunk_sequence_(std::make_unique( + ThunkInfo(thunk_info.op), std::move(*body_thunk_sequence))), + trip_count_(trip_count) {} + +absl::Status WhileThunk::Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Prepare(params, resource_requests)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Prepare(params, resource_requests)); + return absl::OkStatus(); +} + +absl::Status WhileThunk::Initialize(const InitializeParams& params) { + TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(params)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(params)); + + absl::MutexLock lock(&mutex_); + if (auto it = predicates_.find(params.executor); it == predicates_.end()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, + params.executor->HostMemoryAllocate(sizeof(bool))); + predicates_.emplace(params.executor, std::move(allocation)); + } + + return absl::OkStatus(); +} + +absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { + auto& stream = *params.stream; + + se::DeviceMemoryBase condition_result_data = + params.buffer_allocations->GetDeviceAddress( + condition_result_buffer_index_); + + if (trip_count_.has_value()) { + VLOG(2) << "Executing WhileThunk for " << *trip_count_ << " iterations"; + for (int64_t i = 0; i < trip_count_; ++i) { + VLOG(3) << "Executing iteration # " << i; + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); + } + return absl::OkStatus(); + } + + int64_t iter = 0; + + // Get memory allocation for copying condition result from device. + bool* condition_result = [&] { + absl::MutexLock lock(&mutex_); + return reinterpret_cast(predicates_.at(stream.parent())->opaque()); + }(); + + while (true) { + VLOG(3) << "Executing WhileThunk condition computation; iter=" << iter; + TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); + + // Copy the result of condition computation and break the loop if 'false'. + TF_RETURN_IF_ERROR( + stream.Memcpy(condition_result, condition_result_data, sizeof(bool))); + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to complete all kernels launched on stream %p: %s", &stream, + blocked.message())); + } + + VLOG(3) << "condition_result = " << *condition_result; + if (!*condition_result) { + VLOG(3) << "Break WhileThunk loop; iter=" << iter; + break; + } + + VLOG(3) << "Executing WhileThunk body computation; iter=" << iter++; + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); + } + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/runtime/while_thunk.h b/xla/service/gpu/runtime/while_thunk.h new file mode 100644 index 0000000000000..e1a06c9630859 --- /dev/null +++ b/xla/service/gpu/runtime/while_thunk.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +// WhileThunk implements the while instruction on GPU by invoking a thunk +// sequence for the while 'condition' computation, and (conditionally) another +// thunk sequence for the while 'body' computation. WhileThunk assumes that +// buffers for the following set of while-related instructions share the same +// allocation: +// init, condition.parameter, body.parameter, body.root, while.result +// +// WhileThunk synchronizes the stream to test the result of the 'condition' +// computation. +// +// If `trip_count` is available it means that the while loop trip count is known +// statically and while loop is actually a for loop, and in this case at run +// time condition thunk might not be executed and instead body thunk will be +// executed for `trip_count` times. +class WhileThunk : public Thunk { + public: + // Constructs a WhileThunk to compute while instruction 'hlo'. + WhileThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& condition_result_buffer_index, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, + std::optional trip_count = std::nullopt); + WhileThunk(const WhileThunk&) = delete; + WhileThunk& operator=(const WhileThunk&) = delete; + + absl::Status Prepare(const PrepareParams& params, + ResourceRequests& resource_requests) override; + absl::Status Initialize(const InitializeParams& params) override; + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + SequentialThunk* condition_thunk_sequence() const { + return condition_thunk_sequence_.get(); + } + + SequentialThunk* body_thunk_sequence() const { + return body_thunk_sequence_.get(); + } + + const BufferAllocation::Slice& condition_result_buffer() const { + return condition_result_buffer_index_; + } + + private: + const BufferAllocation::Slice condition_result_buffer_index_; + std::unique_ptr condition_thunk_sequence_; + std::unique_ptr body_thunk_sequence_; + std::optional trip_count_; + + // Pinned host memory for transfering predicate value from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + predicates_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_WHILE_THUNK_H_ diff --git a/xla/service/gpu/runtime3/BUILD b/xla/service/gpu/runtime3/BUILD deleted file mode 100644 index 57d84e65e5267..0000000000000 --- a/xla/service/gpu/runtime3/BUILD +++ /dev/null @@ -1,247 +0,0 @@ -load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], - licenses = ["notice"], -) - -package_group( - name = "friends", - includes = ["//xla:friends"], -) - -#===-------------------------------------------------------------------------------------------===// -# Command Buffer Integration -#===-------------------------------------------------------------------------------------------===// - -cc_library( - name = "command_buffer_allocations", - srcs = ["command_buffer_allocations.cc"], - hdrs = ["command_buffer_allocations.h"], - deps = [ - "//xla:status", - "//xla:statusor", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "command_buffer_cmd", - srcs = ["command_buffer_cmd.cc"], - hdrs = ["command_buffer_cmd.h"], - deps = [ - ":command_buffer_allocations", - "//xla:status", - "//xla:statusor", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gemm_thunk", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "command_buffer_cmd_emitter", - srcs = ["command_buffer_cmd_emitter.cc"], - hdrs = ["command_buffer_cmd_emitter.h"], - deps = [ - ":command_buffer_cmd", - "//xla:statusor", - "//xla:util", - "//xla/service/gpu:gpu_executable", - "//xla/service/gpu:thunk", - "@tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "command_buffer_cmd_test", - srcs = if_gpu_is_configured(["command_buffer_cmd_test.cc"]), - backends = ["gpu"], - deps = [ - ":command_buffer_cmd", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:launch_dimensions", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "//xla/stream_executor/cuda:cuda_test_kernels", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", - ], -) - -#===-------------------------------------------------------------------------------------------===// -# XLA Thunks Runtime -#===-------------------------------------------------------------------------------------------===// - -cc_library( - name = "cholesky_thunk", - srcs = if_gpu_is_configured(["cholesky_thunk.cc"]), - hdrs = if_gpu_is_configured(["cholesky_thunk.h"]), - deps = if_gpu_is_configured([ - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:cusolver_context", - "//xla/service/gpu:make_batch_pointers", - "//xla/service/gpu:thunk", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/hlo/ir:hlo", - "@tsl//tsl/platform:logging", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_asm_opts", - ]) + ["@tsl//tsl/platform:status"], -) - -cc_library( - name = "command_buffer_thunk", - srcs = ["command_buffer_thunk.cc"], - hdrs = ["command_buffer_thunk.h"], - deps = [ - ":command_buffer_allocations", - ":command_buffer_cmd", - "//xla:status", - "//xla:statusor", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "command_buffer_thunk_test", - srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), - backends = ["gpu"], - deps = [ - ":command_buffer_cmd", - ":command_buffer_thunk", - "//xla:shape_util", - "//xla:types", - "//xla/service:buffer_assignment", - "//xla/service:executable", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "//xla/stream_executor/cuda:cuda_test_kernels", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "custom_call_thunk", - srcs = ["custom_call_thunk.cc"], - hdrs = ["custom_call_thunk.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), - deps = [ - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:status", - "//xla:util", - "//xla/ffi:call_frame", - "//xla/ffi:ffi_api", - "//xla/ffi/api:c_api", - "//xla/service:buffer_assignment", - "//xla/service:custom_call_status", - "//xla/service:custom_call_status_internal", - "//xla/service:executable", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "fft_thunk", - srcs = ["fft_thunk.cc"], - hdrs = ["fft_thunk.h"], - deps = [ - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:thunk", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - ], -) - -cc_library( - name = "triangular_solve_thunk", - srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]), - hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]), - deps = if_gpu_is_configured([ - "@com_google_absl//absl/strings:str_format", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:make_batch_pointers", - "//xla/service/gpu:thunk", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/hlo/ir:hlo", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor/gpu:gpu_asm_opts", - ]) + ["@tsl//tsl/platform:status"], -) diff --git a/xla/service/gpu/runtime3/README.md b/xla/service/gpu/runtime3/README.md deleted file mode 100644 index 351de805194d9..0000000000000 --- a/xla/service/gpu/runtime3/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# XLA:GPU Runtime Under Construction - -This is a temporary folder to consolidate and clean up the Thunk-based XLA:GPU -runtime (right now it's all over the xla/servive/gpu folder), with a goal to -eventually delete `runtime` and `runtime2` folders and make `runtime3` the -default and only XLA:GPU runtime. - -Preliminary timeline for completion is late Q4 2023 - early Q1 2024. diff --git a/xla/service/gpu/runtime3/cholesky_thunk.cc b/xla/service/gpu/runtime3/cholesky_thunk.cc deleted file mode 100644 index 10235cd9bf567..0000000000000 --- a/xla/service/gpu/runtime3/cholesky_thunk.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/cholesky_thunk.h" - -#include -#include -#include -#include - -#include "xla/service/gpu/cusolver_context.h" -#include "xla/service/gpu/make_batch_pointers.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -namespace { - -template -Status DoPotrfBatched(const se::GpuAsmOpts& asm_opts, CholeskyParams* params, - se::Stream* stream, GpuSolverContext& context) { - T* a_base = static_cast(params->a_buffer.opaque()); - se::DeviceMemory infos(params->info_buffer); -#if TENSORFLOW_USE_ROCSOLVER - // hipsolver is not supported so allocate a GPU buffer - se::ScopedDeviceMemory ptrs = - stream->parent()->AllocateOwnedArray(batch_size_); - auto as = *ptrs; -#else - se::DeviceMemory as(params->workspace_buffer); -#endif - - CHECK_GE(as.size(), params->batch_size); - CHECK_GE(infos.size(), params->batch_size); - - // Run a kernel that sets as[i] = &a_base[i * stride]. - const int64_t stride_bytes = params->n * params->n * sizeof(T); - TF_RETURN_IF_ERROR(MakeBatchPointers( - stream, se::DeviceMemoryBase(a_base), stride_bytes, - static_cast(params->batch_size), se::DeviceMemoryBase(as))); - - // Now that we've set up the `as` array, we can call cusolver. - return context.PotrfBatched(params->uplo, params->n, as, params->n, infos, - params->batch_size); -} - -} // namespace - -CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, - const CholeskyOptions& options, - const se::GpuAsmOpts asm_opts, - BufferAllocation::Slice a_buffer, - BufferAllocation::Slice workspace_buffer, - BufferAllocation::Slice info_buffer, - PrimitiveType type, int64_t batch_size, int64_t n) - : Thunk(Kind::kCholesky, thunk_info), - asm_opts_(asm_opts), - uplo_(options.lower() ? se::blas::UpperLower::kLower - : se::blas::UpperLower::kUpper), - a_buffer_(a_buffer), - workspace_buffer_(workspace_buffer), - info_buffer_(info_buffer), - type_(type), - batch_size_(batch_size), - n_(n) {} - -Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(3) << "type=" << PrimitiveType_Name(type_) - << " uplo=" << se::blas::UpperLowerString(uplo_) - << " batch_size=" << batch_size_ << " n=" << n_ - << " a=" << a_buffer_.ToString() - << " workspace=" << workspace_buffer_.ToString() - << " info=" << info_buffer_.ToString(); - - se::DeviceMemoryBase a_buffer = - params.buffer_allocations->GetDeviceAddress(a_buffer_); - se::DeviceMemoryBase info_buffer = - params.buffer_allocations->GetDeviceAddress(info_buffer_); - se::DeviceMemoryBase workspace_buffer = - params.buffer_allocations->GetDeviceAddress(workspace_buffer_); - CholeskyParams cholesky_params{n_, batch_size_, uplo_, - a_buffer, workspace_buffer, info_buffer}; - return RunCholesky(asm_opts_, type_, &cholesky_params, params.stream); -} - -Status RunCholesky(const se::GpuAsmOpts& asm_opts, PrimitiveType type, - CholeskyParams* cholesky_params, se::Stream* stream) { - thread_local StatusOr context = GpuSolverContext::Create(); - TF_RETURN_IF_ERROR(context.status()); - TF_RETURN_IF_ERROR(context->SetStream(stream)); - - switch (type) { - case F32: - return DoPotrfBatched(asm_opts, cholesky_params, stream, *context); - case F64: - return DoPotrfBatched(asm_opts, cholesky_params, stream, - *context); - case C64: - return DoPotrfBatched>(asm_opts, cholesky_params, - stream, *context); - case C128: - return DoPotrfBatched>(asm_opts, cholesky_params, - stream, *context); - default: - return InvalidArgument("Invalid type for cholesky %s", - PrimitiveType_Name(type)); - } -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime3/command_buffer_allocations.cc b/xla/service/gpu/runtime3/command_buffer_allocations.cc deleted file mode 100644 index 0aa46ffec10f2..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_allocations.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" - -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "xla/service/buffer_assignment.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" - -namespace xla::gpu { - -StatusOr CommandBufferAllocations::GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const { - auto base = allocs_.find(buffer_slice.index()); - if (base == allocs_.end()) { - return absl::InternalError(absl::StrCat("Command buffer allocation #", - buffer_slice.index(), - " was not allocated")); - } - - char* ptr = static_cast(const_cast(base->second.opaque())); - return se::DeviceMemoryBase(ptr + buffer_slice.offset(), buffer_slice.size()); -} - -Status CommandBufferAllocations::AddAllocation(BufferAllocation::Index index, - se::DeviceMemoryBase memory) { - VLOG(2) << "Add comand buffer allocation: index=" << index - << "; ptr=" << memory.opaque(); - - auto emplaced = allocs_.try_emplace(index, std::move(memory)); - if (emplaced.second == false) { - return absl::InternalError(absl::StrCat("Command buffer allocation #", - index, " was already allocated")); - } - return OkStatus(); -} - -Status CommandBufferAllocations::EraseAllocation( - BufferAllocation::Index index) { - VLOG(2) << "Erase comand buffer allocation: index=" << index; - - if (allocs_.erase(index) == 0) { - return absl::InternalError(absl::StrCat("Command buffer allocation #", - index, " was not allocated")); - } - return OkStatus(); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/command_buffer_allocations.h b/xla/service/gpu/runtime3/command_buffer_allocations.h deleted file mode 100644 index 3435dc0d69434..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_allocations.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ - -#include "absl/container/flat_hash_map.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" - -namespace xla::gpu { - -// Command buffer allocations tracks external buffer allocations done via the -// CommandBuffer API and owned by the XLA executable (via instantiated command -// buffers and memory allocation Gpu graph nodes). -class CommandBufferAllocations : public BufferAllocations::ExternalAllocations { - public: - StatusOr GetDeviceAddress( - BufferAllocation::Slice buffer_slice) const override; - - // Adds an external allocation for a given buffer index. Returns error if - // allocation already exists. - Status AddAllocation(BufferAllocation::Index index, - se::DeviceMemoryBase memory); - - // Erases an external allocation for a given buffer index. Returns error if - // allocation does not exists. - Status EraseAllocation(BufferAllocation::Index index); - - private: - absl::flat_hash_map allocs_; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_ALLOCATIONS_H_ diff --git a/xla/service/gpu/runtime3/command_buffer_cmd.cc b/xla/service/gpu/runtime3/command_buffer_cmd.cc deleted file mode 100644 index 854f0bd790e40..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/status.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" // IWYU pragma: keep -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// CommandBufferCmdSequence -//===----------------------------------------------------------------------===// - -void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { - for (BufferAllocation::Slice& slice : cmd->slices()) { - slices_.insert(slice); - allocs_indices_.insert(slice.index()); - } - commands_.push_back(std::move(cmd)); -} - -Status CommandBufferCmdSequence::Initialize( - se::StreamExecutor* executor, CommandBufferCmd::ExecutableSource source) { - for (auto& cmd : commands_) { - TF_RETURN_IF_ERROR(cmd->Initialize(executor, source)); - } - return OkStatus(); -} - -Status CommandBufferCmdSequence::Record( - const CommandBufferCmd::RecordParams& params, - se::CommandBuffer* command_buffer, RecordMode mode) { - if (mode == RecordMode::kExclusive) { - if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { - TF_RETURN_IF_ERROR(command_buffer->Update()); - } - } - - for (auto& cmd : commands_) { - TF_RETURN_IF_ERROR(cmd->Record(params, command_buffer)); - } - - if (mode == RecordMode::kExclusive) { - TF_RETURN_IF_ERROR(command_buffer->Finalize()); - } - - return OkStatus(); -} - -// Returns buffer allocation slices referenced by commands in this sequence. -const absl::flat_hash_set& -CommandBufferCmdSequence::slices() const { - return slices_; -} - -// Returns buffer allocations indices referenced by commands in this sequence. -const absl::flat_hash_set& -CommandBufferCmdSequence::allocs_indices() const { - return allocs_indices_; -} - -//===----------------------------------------------------------------------===// -// LaunchCmd -//===----------------------------------------------------------------------===// - -LaunchCmd::LaunchCmd(std::string kernel_name, - absl::Span args, - LaunchDimensions dims, int64_t shmem_bytes) - : kernel_name_(std::move(kernel_name)), - args_(args.begin(), args.end()), - dims_(dims), - shmem_bytes_(shmem_bytes) {} - -Status LaunchCmd::Initialize(se::StreamExecutor* executor, - ExecutableSource source) { - if (kernels_.contains(executor)) return OkStatus(); - - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - CreateKernel(kernel_name_, args_.size(), source.text, - source.binary, executor, shmem_bytes_)); - - kernels_.emplace(executor, std::move(kernel)); - return OkStatus(); -} - -Status LaunchCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { - VLOG(5) << "LaunchCmd: kernel=" << kernel_name_ - << ", shmem_bytes=" << shmem_bytes_; - - se::Kernel* kernel = kernels_[params.executor].get(); - if (kernel == nullptr) { - return absl::InternalError( - "Kernel not loaded on a command buffer executor"); - } - - absl::InlinedVector buffers; - for (const BufferAllocation::Slice& arg : args_) { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buf, - params.buffer_allocations->GetDeviceAddress( - arg, *params.command_buffer_allocations)); - VLOG(5) << " Arg: " << arg << ": " << buf.opaque(); - buffers.push_back(buf); - } - - TF_ASSIGN_OR_RETURN(auto kernel_args, - se::PackKernelArgs(buffers, shmem_bytes_)); - - LaunchDimensions::Dim3D thread_counts = dims_.thread_counts_per_block(); - LaunchDimensions::Dim3D block_counts = dims_.block_counts(); - - return command_buffer->Launch( - se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), - se::BlockDim(block_counts.x, block_counts.y, block_counts.z), *kernel, - *kernel_args); -} - -CommandBufferCmd::Slices LaunchCmd::slices() { - return CommandBufferCmd::Slices(args_.begin(), args_.end()); -} - -//===----------------------------------------------------------------------===// -// MemcpyDeviceToDeviceCmd -//===----------------------------------------------------------------------===// - -MemcpyDeviceToDeviceCmd::MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst, - BufferAllocation::Slice src, - int64_t num_bytes) - : dst_(dst), src_(src), num_bytes_(num_bytes) {} - -Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { - VLOG(5) << "MemcpyDeviceToDeviceCmd: dst=" << dst_ << ", src=" << src_ - << ", num_bytes=" << num_bytes_; - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase dst, - params.buffer_allocations->GetDeviceAddress( - dst_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase src, - params.buffer_allocations->GetDeviceAddress( - src_, *params.command_buffer_allocations)); - - return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); -} - -CommandBufferCmd::Slices MemcpyDeviceToDeviceCmd::slices() { - return {dst_, src_}; -} - -//===----------------------------------------------------------------------===// -// IfCmd -//===----------------------------------------------------------------------===// - -IfCmd::IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds) - : pred_(pred), then_cmds_(std::move(then_cmds)) {} - -Status IfCmd::Initialize(se::StreamExecutor* executor, - ExecutableSource source) { - return then_cmds_.Initialize(executor, source); -} - -Status IfCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { - se::DeviceMemoryBase pred = - params.buffer_allocations->GetDeviceAddress(pred_); - - return command_buffer->If( - params.executor, se::DeviceMemory(pred), - [&](se::CommandBuffer* then_cmd_buffer) { - return then_cmds_.Record( - params, then_cmd_buffer, - CommandBufferCmdSequence::RecordMode::kConditional); - }); -} - -CommandBufferCmd::Slices IfCmd::slices() { - auto& slices = then_cmds_.slices(); - return {slices.begin(), slices.end()}; -} - -//===----------------------------------------------------------------------===// -// AllocateCmd -//===----------------------------------------------------------------------===// - -AllocateCmd::AllocateCmd(BufferAllocation* allocation) - : allocation_(allocation) {} - -Status AllocateCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { - // Memory allocation address is returned on graph creation, and there is no - // update operation - VLOG(5) << "AllocationCmd: index=" << allocation_->index(); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - command_buffer->Allocate(allocation_->size())); - - TF_RETURN_IF_ERROR(params.command_buffer_allocations->AddAllocation( - allocation_->index(), buffer)); - - return OkStatus(); -} - -CommandBufferCmd::Slices AllocateCmd::slices() { return {}; } - -//===----------------------------------------------------------------------===// -// GemmCmd -//===----------------------------------------------------------------------===// - -GemmCmd::GemmCmd(GemmConfig config, const BufferAllocation::Slice& lhs_buffer, - const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer, - bool deterministic) - : config_(std::move(config)), - lhs_buffer_(lhs_buffer), - rhs_buffer_(rhs_buffer), - output_buffer_(output_buffer), - deterministic_(deterministic) {} - -Status GemmCmd::Initialize(se::StreamExecutor* executor, - ExecutableSource source) { - if (!executor->AsBlas()) { - return absl::InternalError("Failed to initialize BLAS support for GemmCmd"); - } - return OkStatus(); -} - -Status GemmCmd::Record(const RecordParams& params, - se::CommandBuffer* command_buffer) { - VLOG(5) << "GemmCmd: lhs=" << lhs_buffer_ << ", rhs=" << rhs_buffer_ - << ", output=" << output_buffer_ - << ", deterministic=" << deterministic_; - - se::DeviceMemoryBase workspace(nullptr, 0); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs, - params.buffer_allocations->GetDeviceAddress( - lhs_buffer_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs, - params.buffer_allocations->GetDeviceAddress( - rhs_buffer_, *params.command_buffer_allocations)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase out, - params.buffer_allocations->GetDeviceAddress( - output_buffer_, *params.command_buffer_allocations)); - - TF_ASSIGN_OR_RETURN( - auto nested_buffer, - se::CommandBuffer::Trace(params.executor, [&](se::Stream* stream) { - return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, - stream); - })); - - return command_buffer->AddNestedCommandBuffer(nested_buffer); -} - -CommandBufferCmd::Slices GemmCmd::slices() { - return {lhs_buffer_, rhs_buffer_, output_buffer_}; -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/command_buffer_cmd.h b/xla/service/gpu/runtime3/command_buffer_cmd.h deleted file mode 100644 index 699c9a43283fb..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_cmd.h +++ /dev/null @@ -1,262 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla::gpu { - -//===----------------------------------------------------------------------===// -// CommandBufferCmd -//===----------------------------------------------------------------------===// - -// CommandBufferCmd is an abstract command that creates or updates command -// buffer by recording commands into it. -class CommandBufferCmd { - public: - using ExecutableSource = Thunk::ExecutableSource; - using Slices = absl::InlinedVector; - - // Run time parameters required for recording commands into the command - // buffer. For example when we emit command buffer cmd sequence from an HLO - // module, we only know the buffer slices required for HLO operations, but the - // concrete device pointers become available only at run time. - // - // For allocations that performed through command buffer Allocate command, the - // target addresses are tracked by command buffer runtime. To record command - // that consumes buffers allocated inside command buffer, user should specify - // the target address as se::DeviceMemoryBase{nullptr, size}. - struct RecordParams { - se::StreamExecutor* executor; - const BufferAllocations* buffer_allocations; - CommandBufferAllocations* command_buffer_allocations; - }; - - // Prepares a command for recording on a given executor. We split it into a - // separate function to allow expensive initialization (e.g. device kernel - // loading) to happen before a command buffer thunk execution. - virtual Status Initialize(se::StreamExecutor* executor, - ExecutableSource source) { - return OkStatus(); - } - - // Records command into the command buffer. - virtual Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) = 0; - - // Returns all buffer slices of the cmd. These will be used to track cmd - // updates, thus they need to be consistent across calls to the function. - virtual Slices slices() = 0; - - virtual ~CommandBufferCmd() = default; -}; - -//===----------------------------------------------------------------------===// -// CommandBufferCmdSequence -//===----------------------------------------------------------------------===// - -// A sequence of command buffer commands that create or update a command buffer. -// You can think of CommandBufferCmdSequence as a mini interpreter whose sole -// purpose is to manipulate command buffers at run time. -class CommandBufferCmdSequence { - public: - CommandBufferCmdSequence() = default; - - enum class RecordMode { - // In exclusive mode no one else is recording commands into the command - // buffer argument, and cmd sequence is responsible for updating command - // buffer state: finalizing after all commands recorded, and - // switching to update state before recording updates. - kExclusive, - - // In conditional mode multiple cmd sequences can be recorded into the - // command buffer argument, and with command buffer state managed externally - // cmd sequence should not finalize or update it. This mode is used when - // command buffer cmd sequence is recorded into conditional command buffers - // owned by the parent command buffer. - kConditional - }; - - void Append(std::unique_ptr cmd); - - template - void Emplace(Args... args) { - Append(std::make_unique(std::forward(args)...)); - } - - // Initialized all commands added to a sequence. - Status Initialize(se::StreamExecutor* executor, - CommandBufferCmd::ExecutableSource source); - - // Records all commands added to a sequence into the given command buffer. - Status Record(const CommandBufferCmd::RecordParams& params, - se::CommandBuffer* command_buffer, - RecordMode mode = RecordMode::kExclusive); - - // Returns buffer allocation slices referenced by commands in this sequence. - const absl::flat_hash_set& slices() const; - - // Returns buffer allocations indices referenced by commands in this sequence. - const absl::flat_hash_set& allocs_indices() const; - - private: - std::vector> commands_; - - // Buffer allocation slices referenced by commands in this sequence. - absl::flat_hash_set slices_; - - // Buffer allocations indices referenced by commands in this sequence. - absl::flat_hash_set allocs_indices_; -}; - -//===----------------------------------------------------------------------===// -// LaunchCmd -//===----------------------------------------------------------------------===// - -class LaunchCmd : public CommandBufferCmd { - public: - LaunchCmd(std::string kernel_name, - absl::Span args, - LaunchDimensions dims, int64_t shmem_bytes); - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource source) override; - - Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) override; - - Slices slices() override; - - private: - using OwnedKernel = std::unique_ptr; - - std::string kernel_name_; - std::vector args_; - LaunchDimensions dims_; - int64_t shmem_bytes_; - - absl::flat_hash_map kernels_; -}; - -//===----------------------------------------------------------------------===// -// MemcpyDeviceToDeviceCmd -//===----------------------------------------------------------------------===// - -class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { - public: - MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst, - BufferAllocation::Slice src, int64_t num_bytes); - - Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) override; - - Slices slices() override; - - private: - BufferAllocation::Slice dst_; - BufferAllocation::Slice src_; - int64_t num_bytes_; -}; - -//===----------------------------------------------------------------------===// -// IfCmd -//===----------------------------------------------------------------------===// - -class IfCmd : public CommandBufferCmd { - public: - IfCmd(BufferAllocation::Slice pred, CommandBufferCmdSequence then_cmds); - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource source) override; - - Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) override; - - Slices slices() override; - - private: - BufferAllocation::Slice pred_; - CommandBufferCmdSequence then_cmds_; -}; - -//===----------------------------------------------------------------------===// -// AllocateCmd -//===----------------------------------------------------------------------===// - -class AllocateCmd : public CommandBufferCmd { - public: - explicit AllocateCmd(BufferAllocation* allocation); - - // After calling this function, the allocated memory address is updated to - Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) override; - - Slices slices() override; - - private: - BufferAllocation* allocation_; -}; - -//===----------------------------------------------------------------------===// -// GemmCmd -//===----------------------------------------------------------------------===// - -class GemmCmd : public CommandBufferCmd { - public: - GemmCmd(GemmConfig config, const BufferAllocation::Slice& lhs_buffer, - const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer, bool deterministic); - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource source) override; - - Status Record(const RecordParams& params, - se::CommandBuffer* command_buffer) override; - - Slices slices() override; - - private: - const GemmConfig config_; - const BufferAllocation::Slice lhs_buffer_; - const BufferAllocation::Slice rhs_buffer_; - const BufferAllocation::Slice output_buffer_; - // Whether to run deterministically. - const bool deterministic_; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_H_ diff --git a/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc b/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc deleted file mode 100644 index 23e2d5407f3e7..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_cmd_emitter.h" - -#include -#include - -#include "xla/service/gpu/kernel_thunk.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/thunk.h" -#include "xla/statusor.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla::gpu { - -namespace { - -StatusOr> ConvertToCommand( - const Thunk& thunk) { - switch (thunk.kind()) { - // TODO(anlunx): Support other thunk kinds. - case Thunk::Kind::kKernel: { - auto& kernel_thunk = static_cast(thunk); - auto kernel_cmd = std::make_unique( - kernel_thunk.kernel_name(), kernel_thunk.arguments(), - kernel_thunk.launch_dimensions(), kernel_thunk.shmem_bytes()); - return kernel_cmd; - } - default: - return InternalError("Unsupported thunk kind"); - } -} - -} // namespace - -StatusOr ConvertToCommands( - const ThunkSequence& sequence) { - CommandBufferCmdSequence cmd_sequence; - for (const std::unique_ptr& thunk : sequence) { - TF_ASSIGN_OR_RETURN(std::unique_ptr cmd, - ConvertToCommand(*thunk)); - cmd_sequence.Append(std::move(cmd)); - } - return cmd_sequence; -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h b/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h deleted file mode 100644 index f5abc2cd5d3f1..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_cmd_emitter.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ - -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/thunk.h" -#include "xla/statusor.h" - -namespace xla::gpu { - -StatusOr ConvertToCommands( - const ThunkSequence& sequence); - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_CMD_EMITTER_H_ diff --git a/xla/service/gpu/runtime3/command_buffer_cmd_test.cc b/xla/service/gpu/runtime3/command_buffer_cmd_test.cc deleted file mode 100644 index 61474a5ba05e6..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_cmd_test.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" - -#include -#include - -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_test_kernels.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" - -namespace xla::gpu { - -static se::StreamExecutor* CudaExecutor() { - auto* platform = se::MultiPlatformManager::PlatformWithName("CUDA").value(); - return platform->ExecutorForDevice(0).value(); -} - -TEST(CommandBufferCmdTest, MemcpyCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(slice_b, slice_a, byte_length); - - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); - - auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({executor, &allocations}, &command_buffer)); - - // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42)); -} - -TEST(CommandBufferCmdTest, LaunchCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - auto args = {slice_a, slice_a, slice_b}; // b = a + a - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace("add", args, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - - // Initialize command sequence and load device kernels. - CommandBufferCmd::ExecutableSource source = { - /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; - TF_ASSERT_OK(commands.Initialize(executor, source)); - - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); - - auto command_buffer = se::CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(commands.Record({executor, &allocations}, &command_buffer)); - - // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(&stream, command_buffer)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/command_buffer_thunk.cc b/xla/service/gpu/runtime3/command_buffer_thunk.cc deleted file mode 100644 index 3638935f2c0c9..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_thunk.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" - -#include -#include - -#include "absl/synchronization/mutex.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla::gpu { - -CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( - se::CommandBuffer command_buffer) - : command_buffer(std::move(command_buffer)) {} - -CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, - ThunkInfo thunk_info) - : Thunk(Thunk::kCommandBuffer, std::move(thunk_info)), - commands_(std::move(commands)) {} - -Status CommandBufferThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource executable_source) { - return commands_.Initialize(executor, executable_source); -} - -bool CommandBufferThunk::ExecutorCommandBuffer::ShouldUpdateCommandBuffer( - const CommandBufferCmdSequence& commands, - const CommandBufferCmd::RecordParams& params) { - bool should_update = false; - const BufferAllocations* allocs = params.buffer_allocations; - - // We check only allocations referenced by commands in a cmd sequence, and - // leave every other entry default initialized (nullptr device memory). - for (BufferAllocation::Index index : commands.allocs_indices()) { - se::DeviceMemoryBase alloc = allocs->GetDeviceAddress(index); - - if (recorded_allocs.size() <= index) { - recorded_allocs.resize(index + 1); - } - - if (!recorded_allocs[index].IsSameAs(alloc)) { - recorded_allocs[index] = alloc; - should_update = true; - } - } - - return should_update; -} - -Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { - se::StreamExecutor* executor = params.stream->parent(); - TF_ASSIGN_OR_RETURN(ExecutorCommandBuffer * cmd_buffer, - GetOrCreateCommandBuffer(executor)); - - absl::MutexLock lock(&cmd_buffer->mutex); - - CommandBufferCmd::RecordParams record_params = { - executor, params.buffer_allocations, &cmd_buffer->allocations}; - - if (cmd_buffer->ShouldUpdateCommandBuffer(commands_, record_params)) { - TF_RETURN_IF_ERROR( - commands_.Record(record_params, &cmd_buffer->command_buffer)); - } - - return executor->Submit(params.stream, cmd_buffer->command_buffer); -} - -StatusOr -CommandBufferThunk::GetOrCreateCommandBuffer(se::StreamExecutor* executor) { - absl::MutexLock lock(&mutex_); - - // Check if command buffer already exists - if (auto it = command_buffers_.find(executor); it != command_buffers_.end()) { - return &it->second; - } - - // Create a new empty command buffer. - TF_ASSIGN_OR_RETURN(auto command_buffer, se::CommandBuffer::Create(executor)); - auto emplaced = command_buffers_.emplace(executor, std::move(command_buffer)); - - return &emplaced.first->second; -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/command_buffer_thunk.h b/xla/service/gpu/runtime3/command_buffer_thunk.h deleted file mode 100644 index e5e2a3ed6abda..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_thunk.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ - -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/service/gpu/runtime3/command_buffer_allocations.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/thunk.h" -#include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla::gpu { - -class CommandBufferThunk : public Thunk { - public: - explicit CommandBufferThunk(CommandBufferCmdSequence commands, - ThunkInfo thunk_info); - - Status Initialize(se::StreamExecutor*, ExecutableSource) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - // Command buffer instantiated on a `se::StreamExecutor` instance, and - // auxiliary state required for efficient command buffer updates. - struct ExecutorCommandBuffer { - explicit ExecutorCommandBuffer(se::CommandBuffer command_buffer); - - // Returns true if `commands` cmd sequence has to be recorded into - // `command_buffer` to update it (see `recorded_allocs` below). - bool ShouldUpdateCommandBuffer(const CommandBufferCmdSequence& commands, - const CommandBufferCmd::RecordParams& params) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex); - - // se::CommandBuffer is not thread safe, and we guard it with a mutex to - // guarantee that we do not mutate it concurrently. - absl::Mutex mutex; - se::CommandBuffer command_buffer ABSL_GUARDED_BY(mutex); - - // TODO(ezhulenev): We need to move command buffer allocations all the way - // up to the GpuExecutable as we can have Allocate and Free commands in - // different command buffers. Consider making it a part of - // BufferAllocations (as std::unique_ptr member). - - // Memory allocations performed by a `command_buffer`. - CommandBufferAllocations allocations ABSL_GUARDED_BY(mutex); - - // Mapping from buffer allocation index to the device memory passed at - // that index to the last call of `commands_.Record(...)` for - // `command_buffer`. We can just use a vector instead of map because - // `BufferAllocation::Index` is a unique identifier assigned - // contiguously and thus can be used as array index. - // - // If no device memory addresses changed from a previous call to - // `Record`, we can skip command buffer update and simply submit it for - // execution on a stream. All other pieces of information (like thread - // and block sizes) captured by commands at construction time and do not - // change. - std::vector recorded_allocs ABSL_GUARDED_BY(mutex); - }; - - // Returns a command buffer instantiated for `executor` or creates new one. - StatusOr GetOrCreateCommandBuffer( - se::StreamExecutor* executor); - - // Command sequence that initializes command buffers on each executor. - CommandBufferCmdSequence commands_; - - // Command buffer sequence instantiates command buffers on all executors. - absl::Mutex mutex_; - absl::node_hash_map - command_buffers_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_RUNTIME3_COMMAND_BUFFER_THUNK_H_ diff --git a/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/xla/service/gpu/runtime3/command_buffer_thunk_test.cc deleted file mode 100644 index f976cdecdbf36..0000000000000 --- a/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ /dev/null @@ -1,528 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/command_buffer_thunk.h" - -#include -#include -#include -#include -#include - -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/runtime3/command_buffer_cmd.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_test_kernels.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" - -namespace xla::gpu { - -static se::StreamExecutor* CudaExecutor() { - auto* platform = se::MultiPlatformManager::PlatformWithName("CUDA").value(); - return platform->ExecutorForDevice(0).value(); -} - -TEST(CommandBufferThunkTest, MemcpyCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(slice_b, slice_a, byte_length); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - // Execute command buffer thunk and verify that it copied the memory. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42)); - - // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&b, byte_length); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42)); -} - -// This test does the following operations: -// 1. Allocates memory region "a" and "c" outside command buffer. -// 2. Allocates memory region "b" inside command buffer. -// 3. MemCopyDeviceToDevice from "a" to "b" inside command buffer. -// 4. MemCopyDEviceToDevice from "b" to "c" inside command buffer. -// 5. Verify that region "c" has the same content as "a". -TEST(CommandBufferThunkTest, AllocateCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - // Prepare arguments: - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(&alloc_b); - commands.Emplace(slice_b, slice_a, byte_length); - commands.Emplace(slice_c, slice_b, byte_length); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - stream.ThenMemset32(&a, 42, byte_length); - se::DeviceMemory b(se::DeviceMemoryBase( - reinterpret_cast(BufferAllocations::kExternalAllocationMarker), - byte_length)); - se::DeviceMemory c = executor->AllocateArray(length, 0); - BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator()); - - ServiceExecutableRunOptions run_options; - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - // Execute command buffer thunk and verify that it copied the memory. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), allocations.GetMutableDeviceAddress(2), - byte_length); - - ASSERT_EQ(dst, std::vector(4, 42)); -} - -TEST(CommandBufferThunkTest, LaunchCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - auto args = {slice_a, slice_a, slice_b}; // b = a + a - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace("add", args, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - CommandBufferCmd::ExecutableSource source = { - /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; - TF_ASSERT_OK(thunk.Initialize(executor, source)); - - // Execute command buffer thunk and verify that it added the value. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Prepare buffer allocation for updating command buffer: c=0 - se::DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemZero(&c, byte_length); - - // Update buffer allocation #1 to buffer `c`. - allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `c` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&c, byte_length); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `c` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); -} - -TEST(CommandBufferThunkTest, GemmCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t lhs_length = sizeof(float) * 2 * 4; - int64_t rhs_length = sizeof(float) * 4 * 3; - int64_t out_length = sizeof(float) * 2 * 3; - - // Prepare arguments: - // lhs = [1.0, 2.0, 3.0, 4.0 - // 5.0, 6.0, 7.0, 8.0] - // rhs = [1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0 - // 1.0, 1.0, 1.0] - se::DeviceMemory lhs = executor->AllocateArray(2 * 4); - std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; - stream.ThenMemcpy(&lhs, lhs_arr.data(), lhs_length); - - se::DeviceMemory rhs = executor->AllocateArray(4 * 3); - std::vector rhs_arr(12, 1); - stream.ThenMemcpy(&rhs, rhs_arr.data(), rhs_length); - - se::DeviceMemory out = executor->AllocateArray(2 * 3); - stream.ThenMemZero(&out, out_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); - BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); - BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); - - BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); - BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); - BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); - - auto config = GemmConfig::For( - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {}, {1}, - ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), {}, {0}, - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0, 0.0, 0.0, - std::nullopt, se::blas::kDefaultComputePrecision, false, false); - ASSERT_TRUE(config.ok()); - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace(config.value(), slice_lhs, slice_rhs, slice_out, - /*deterministic=*/true); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - ServiceExecutableRunOptions run_options; - BufferAllocations allocations({lhs, rhs, out}, 0, executor->GetAllocator()); - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - CommandBufferCmd::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; - TF_ASSERT_OK(thunk.Initialize(executor, source)); - - // Execute command buffer thunk and verify that it executed a GEMM. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `out` data back to host. - std::vector dst(6, 0); - stream.ThenMemcpy(dst.data(), out, out_length); - - ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); - - // Prepare buffer allocation for updating command buffer. - se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); - stream.ThenMemZero(&updated_out, out_length); - - // Update buffer allocation to updated `out` buffer. - allocations = - BufferAllocations({lhs, rhs, updated_out}, 0, executor->GetAllocator()); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `updated_out` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), updated_out, out_length); - - ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); - - // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&updated_out, out_length); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `updated_out` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), updated_out, out_length); - - ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); -} - -TEST(CommandBufferThunkTest, MultipleLaunchCmd) { - se::StreamExecutor* executor = CudaExecutor(); - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=0 - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - se::DeviceMemory c = executor->AllocateArray(length, 0); - se::DeviceMemory d = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - stream.ThenMemset32(&c, 21, byte_length); - stream.ThenMemZero(&d, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); - BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); - BufferAllocation alloc_d(/*index=*/3, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); - BufferAllocation::Slice slice_d(&alloc_d, 0, byte_length); - - auto args = {slice_a, slice_a, slice_b}; // b = a + a - auto args_1 = {slice_c, slice_c, slice_d}; // d = c + c - - // Prepare commands sequence for constructing command buffer. - CommandBufferCmdSequence commands; - commands.Emplace("add", args, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - commands.Emplace("add", args_1, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b, c, d}, 0, executor->GetAllocator()); - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - CommandBufferCmd::ExecutableSource source = { - /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; - TF_ASSERT_OK(thunk.Initialize(executor, source)); - - // Execute command buffer thunk and verify that it added the value. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), d, byte_length); - ASSERT_EQ(dst, std::vector(4, 21 + 21)); - - BufferAllocation alloc_e(/*index=*/3, byte_length, /*color=*/0); - BufferAllocation::Slice slice_e(&alloc_e, 0, byte_length); - - // Prepare buffer allocation for updating command buffer: e=0 - se::DeviceMemory e = executor->AllocateArray(length, 0); - stream.ThenMemZero(&e, byte_length); - - // Update buffer allocation #1 to buffer `c`. - allocations = BufferAllocations({a, b, c, e}, 0, executor->GetAllocator()); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Copy `e` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), e, byte_length); - ASSERT_EQ(dst, std::vector(4, 21 + 21)); - - // Try to update the command buffer with the same buffers. - stream.ThenMemZero(&e, byte_length); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Copy `e` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), e, byte_length); - ASSERT_EQ(dst, std::vector(4, 21 + 21)); -} - -TEST(CommandBufferThunkTest, IfCmd) { - se::StreamExecutor* executor = CudaExecutor(); - if (!se::CommandBuffer::SupportsConditionalCommands(executor->platform())) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - se::Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: pred=true, a=42, b=0 - se::DeviceMemory pred = executor->AllocateArray(1, 0); - se::DeviceMemory a = executor->AllocateArray(length, 0); - se::DeviceMemory b = executor->AllocateArray(length, 0); - - constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 42, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Prepare buffer allocations for recording command buffer. - BufferAllocation alloc_p(/*index=*/0, 1, /*color=*/0); - BufferAllocation alloc_a(/*index=*/1, byte_length, /*color=*/0); - BufferAllocation alloc_b(/*index=*/2, byte_length, /*color=*/0); - - BufferAllocation::Slice slice_p(&alloc_p, 0, 1); - BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); - BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - - auto args = {slice_a, slice_a, slice_b}; // b = a + a - - // Prepare commands sequence for `then` branch. - CommandBufferCmdSequence then_commands; - then_commands.Emplace("add", args, LaunchDimensions(1, 4), - /*shmem_bytes=*/0); - - // Prepare commands sequence for thunk. - CommandBufferCmdSequence commands; - commands.Emplace(slice_p, std::move(then_commands)); - - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); - - ServiceExecutableRunOptions run_options; - BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); - Thunk::ExecuteParams params(run_options, allocations, &stream, {}); - - CommandBufferCmd::ExecutableSource source = { - /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; - TF_ASSERT_OK(thunk.Initialize(executor, source)); - - // Execute command buffer thunk and verify that it added the value. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), b, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); - - // Prepare buffer allocation for updating command buffer: c=0 - se::DeviceMemory c = executor->AllocateArray(length, 0); - stream.ThenMemZero(&c, byte_length); - - // Update buffer allocation #2 to buffer `c`. - allocations = BufferAllocations({pred, a, c}, 0, executor->GetAllocator()); - - // Thunk execution should automatically update underlying command buffer. - TF_ASSERT_OK(thunk.ExecuteOnStream(params)); - - // Copy `c` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), c, byte_length); - - ASSERT_EQ(dst, std::vector(4, 42 + 42)); -} - -} // namespace xla::gpu diff --git a/xla/service/gpu/runtime3/custom_call_thunk.cc b/xla/service/gpu/runtime3/custom_call_thunk.cc deleted file mode 100644 index 62751ef883077..0000000000000 --- a/xla/service/gpu/runtime3/custom_call_thunk.cc +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/custom_call_thunk.h" - -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/executable_run_options.h" -#include "xla/ffi/api/c_api.h" -#include "xla/ffi/call_frame.h" -#include "xla/ffi/ffi_api.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/gpu/thunk.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/status.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/util.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - -namespace xla { -namespace gpu { - -using xla::ffi::CallFrame; -using xla::ffi::CallFrameBuilder; -using xla::ffi::CallOptions; - -CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, - CustomCallTarget call_target, - std::vector> operands, - std::vector> results, - const std::string& opaque) - : Thunk(Thunk::kCustomCall, thunk_info), - operands_(std::move(operands)), - results_(std::move(results)), - call_target_(std::move(call_target)), - opaque_(opaque) {} - -CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler, - std::vector> operands, - std::vector> results, - AttributesMap attributes) - : Thunk(Thunk::kCustomCall, thunk_info), - operands_(std::move(operands)), - results_(std::move(results)), - handler_(std::move(handler)), - attributes_(std::move(attributes)) {} - -Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { - // gpu_stream is CUstream or e.g. the equivalent type in ROCm. - std::vector buffers; - buffers.reserve(operands_.size() + results_.size()); - for (auto& slices : {operands_, results_}) { - for (const std::optional& slice : slices) { - if (!slice.has_value()) { - buffers.push_back(nullptr); - continue; - } - - if (!slice->slice.allocation()) - return InternalError("custom call input missing buffer allocation"); - - buffers.push_back( - params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); - } - } - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream); - XlaCustomCallStatus custom_call_status; - call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(), - &custom_call_status); - auto message = CustomCallStatusGetMessage(&custom_call_status); - if (message) { - return InternalError("CustomCall failed: %s", *message); - } else { - return OkStatus(); - } -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return Unavailable( - "Custom calls on GPU are not supported in this configuration. Please " - "build with --config=cuda or --config=rocm"); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} - -Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) { - // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing - // a lot of extra allocation on every call. We have to keep attributes - // separate from arguments, as they do not change after thunk is constructed. - CallFrameBuilder builder; - - for (auto& slices : {operands_, results_}) { - for (const std::optional& slice : slices) { - // TODO(ezhulenev): Add a token argument type to XLA:FFI. - if (!slice.has_value()) { - return InternalError("FFI handlers do not support tokens (yet)!"); - } - - if (!slice->slice.allocation()) - return InternalError("custom call input missing buffer allocation"); - - builder.AddBufferArg( - params.buffer_allocations->GetDeviceAddress(slice->slice), - slice->shape.element_type(), slice->shape.dimensions()); - } - } - - CallFrameBuilder::AttributesBuilder attrs; - attrs.Append(attributes_); - - builder.AddAttributes(attrs.Build()); - CallFrame call_frame = builder.Build(); - - // TODO(ezhulenev): Remove `ServiceExecutableRunOptions` from FFI handler - // execution context, as apparently it's not easily accessible from Thunk. - ExecutableRunOptions run_options; - run_options.set_stream(params.stream); - ServiceExecutableRunOptions service_run_options(run_options); - - CallOptions options = {&service_run_options}; - return Call(handler_, call_frame, options); -} - -Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { - return handler_ ? ExecuteFfiHandler(params) : ExecuteCustomCall(params); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime3/custom_call_thunk.h b/xla/service/gpu/runtime3/custom_call_thunk.h deleted file mode 100644 index 22e8f676fd3aa..0000000000000 --- a/xla/service/gpu/runtime3/custom_call_thunk.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/ffi/api/c_api.h" -#include "xla/ffi/call_frame.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/gpu/thunk.h" -#include "xla/shape.h" -#include "xla/status.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_types.h" -#endif - -namespace xla { -namespace gpu { - -// Thunk to run a GPU custom call. -// -// This thunk's `ExecuteOnStream` implementation executes a host function -// `call_target` which is expected to enqueue operations onto the GPU. -// -// Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. -// XLA itself creates kCustomCall instructions when lowering kConvolution HLOs -// into calls to cudnn. These internally-created custom-calls are run using -// ConvolutionThunk, not CustomCallThunk. There's no ambiguity because they -// have special call target names (e.g. "__cudnn$convForward") that only the -// compiler is allowed to create. -class CustomCallThunk : public Thunk { - public: -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - using Stream = stream_executor::gpu::GpuStreamHandle; -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - using Stream = void*; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - using CustomCallTarget = std::function; - - // We keep buffer allocation slice together with its shape to be able to fill - // FFI arguments with required details. - struct Slice { - BufferAllocation::Slice slice; - Shape shape; - }; - - using Attribute = ffi::CallFrameBuilder::FlatAttribute; - using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; - - CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target, - std::vector> operands, - std::vector> results, - const std::string& opaque); - - CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler, - std::vector> operands, - std::vector> results, - AttributesMap attributes); - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - Status ExecuteCustomCall(const ExecuteParams& params); - Status ExecuteFfiHandler(const ExecuteParams& params); - - std::vector> operands_; - std::vector> results_; - - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. - CustomCallTarget call_target_; - std::string opaque_; - - // XLA FFI provides a right type safe mechanism for registering external - // functions with XLA runtime. It's under construction, and still misses - // a lot of features. Long term it will replace legacy custom calls. - XLA_FFI_Handler* handler_ = nullptr; - AttributesMap attributes_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_ diff --git a/xla/service/gpu/runtime3/fft_thunk.cc b/xla/service/gpu/runtime3/fft_thunk.cc deleted file mode 100644 index 711ae990769c5..0000000000000 --- a/xla/service/gpu/runtime3/fft_thunk.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/fft_thunk.h" - -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "xla/stream_executor/scratch_allocator.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { -namespace { - -se::fft::Type FftTypeToSeType(FftType type, bool double_precision) { - switch (type) { - case FftType::FFT: - return double_precision ? se::fft::Type::kZ2ZForward - : se::fft::Type::kC2CForward; - case FftType::IFFT: - return double_precision ? se::fft::Type::kZ2ZInverse - : se::fft::Type::kC2CInverse; - case FftType::IRFFT: - return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R; - case FftType::RFFT: - return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C; - default: - LOG(FATAL) << "unsupported fft type"; - } -} - -std::string FftTypeToString(se::fft::Type type) { - switch (type) { - case se::fft::Type::kC2CForward: - case se::fft::Type::kZ2ZForward: - return "FFT"; - case se::fft::Type::kC2CInverse: - case se::fft::Type::kZ2ZInverse: - return "IFFT"; - case se::fft::Type::kC2R: - case se::fft::Type::kZ2D: - return "IRFFT"; - case se::fft::Type::kR2C: - case se::fft::Type::kD2Z: - return "RFFT"; - default: - LOG(FATAL) << "unknown fft type"; - } -} - -} // namespace - -FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type, - absl::Span fft_length, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& output_buffer, - const Shape& input_shape, const Shape& output_shape) - : Thunk(Kind::kFft, thunk_info), - fft_type_( - FftTypeToSeType(fft_type, input_shape.element_type() == F64 || - input_shape.element_type() == C128)), - fft_length_(fft_length.begin(), fft_length.end()), - input_buffer_(input_buffer), - output_buffer_(output_buffer), - input_shape_(input_shape), - output_shape_(output_shape) {} - -Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { - auto& buffer_allocations = *params.buffer_allocations; - - return RunFft( - buffer_allocations.GetDeviceAddress(input_buffer_), input_shape_, - buffer_allocations.GetDeviceAddress(output_buffer_), output_shape_, - fft_type_, fft_length_, buffer_allocations.device_ordinal(), - &fft_plan_cache_, params.stream, buffer_allocations.memory_allocator()); -} - -Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, - se::DeviceMemoryBase output, const Shape& output_shape, - se::fft::Type fft_type, absl::Span fft_len, - int device_ordinal, FftPlanCache* fft_plan_cache, - se::Stream* stream, se::DeviceMemoryAllocator* memory_allocator) { - VLOG(3) << "FFT type: " << FftTypeToString(fft_type); - VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); - VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); - - se::OwningScratchAllocator<2> scratch_allocator(device_ordinal, - memory_allocator); - - // Get the Fft plan for the given device ordinal. - FftPlan* fft_plan_ptr = fft_plan_cache->GetOrCreate(device_ordinal); - - // CuFFT thread-safety requires that separate host threads not share plans; - // protect each plan with a mutex. - absl::MutexLock lock(&fft_plan_ptr->mu); - std::unique_ptr& fft_plan = fft_plan_ptr->plan; - if (fft_plan == nullptr) { - const int64_t fft_rank = fft_len.size(); - CHECK_LE(fft_rank, 3); - int batch_size = 1; - for (int i = 0; i < input_shape.dimensions_size() - fft_rank; ++i) { - batch_size *= input_shape.dimensions(i); - } - uint64_t fft_length[3]; - uint64_t input_embed[3]; - const uint64_t input_stride = 1; - uint64_t input_distance = 1; - uint64_t output_embed[3]; - const uint64_t output_stride = 1; - uint64_t output_distance = 1; - - for (int i = 0; i < fft_rank; ++i) { - auto dim_offset = input_shape.dimensions_size() - fft_rank + i; - fft_length[i] = static_cast(fft_len[i]); - input_embed[i] = input_shape.dimensions(dim_offset); - input_distance *= input_shape.dimensions(dim_offset); - output_embed[i] = output_shape.dimensions(dim_offset); - output_distance *= output_shape.dimensions(dim_offset); - } - - constexpr bool kInPlaceFft = false; - fft_plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( - stream, fft_rank, fft_length, input_embed, input_stride, input_distance, - output_embed, output_stride, output_distance, fft_type, kInPlaceFft, - batch_size, &scratch_allocator); - TF_RET_CHECK(fft_plan != nullptr) - << "Failed to create cuFFT batched plan with scratch allocator"; - fft_plan_ptr->scale_factor = 1.0f / output_distance; - } else { - stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( - stream, fft_plan.get(), &scratch_allocator); - } - - float scale_factor = fft_plan_ptr->scale_factor; - - bool launch_ok; - switch (fft_type) { - case se::fft::Type::kC2CForward: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - break; - } - case se::fft::Type::kZ2ZForward: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - break; - } - case se::fft::Type::kC2CInverse: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - complex64(scale_factor), &output_data, 1) - .ok(); - } - break; - } - case se::fft::Type::kZ2ZInverse: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - if (launch_ok) { - launch_ok = - stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - complex128(scale_factor), &output_data, 1) - .ok(); - } - break; - } - case se::fft::Type::kR2C: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - break; - } - case se::fft::Type::kD2Z: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - break; - } - case se::fft::Type::kC2R: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1) - .ok(); - } - break; - } - case se::fft::Type::kZ2D: { - se::DeviceMemory input_data(input); - se::DeviceMemory output_data(output); - launch_ok = - stream->ThenFft(fft_plan.get(), input_data, &output_data).ok(); - if (launch_ok) { - launch_ok = stream - ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1) - .ok(); - } - break; - } - default: - LOG(FATAL) << "unsupported fft type"; - } - if (launch_ok) { - return OkStatus(); - } - return InternalError("Unable to launch fft with type %s", - FftTypeToString(fft_type)); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime3/fft_thunk.h b/xla/service/gpu/runtime3/fft_thunk.h deleted file mode 100644 index 4e0de39e1ad4a..0000000000000 --- a/xla/service/gpu/runtime3/fft_thunk.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ - -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" - -namespace xla { -namespace gpu { - -struct FftPlan { - // CuFFT thread-safety requires that separate host threads not share plans; - // protect each plan with a mutex. - absl::Mutex mu; - std::unique_ptr plan ABSL_GUARDED_BY(mu); - float scale_factor ABSL_GUARDED_BY(mu); -}; - -class FftPlanCache { - public: - // Returnes Fft plan cached for the given device ordinal or creates a new one. - FftPlan* GetOrCreate(int device_ordinal) { - absl::MutexLock lock(&mu_); - std::unique_ptr& plan = fft_plans_[device_ordinal]; - if (!plan) plan = std::make_unique(); - return plan.get(); - } - - private: - absl::Mutex mu_; - absl::flat_hash_map> fft_plans_ - ABSL_GUARDED_BY(mu_); -}; - -// This class stores everything that StreamExecutor needs to launch an FFT. -// It is generated by IrEmitter. -// -// This is thread-compatible. -class FftThunk : public Thunk { - public: - // Constructs a thunk for launching an FFT on a stream. - // Semantics of null hlo_instruction argument are as in Thunk. - FftThunk(ThunkInfo thunk_info, FftType fft_type, - absl::Span fft_length, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& output_buffer, - const Shape& input_shape, const Shape& output_shape); - - FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ - FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ - - // Does the FFT for the thunk on "stream". - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - const se::fft::Type fft_type_; - const std::vector fft_length_; - - FftPlanCache fft_plan_cache_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice output_buffer_; - - const Shape input_shape_; - const Shape output_shape_; -}; - -Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, - se::DeviceMemoryBase output, const Shape& output_shape, - se::fft::Type fft_type, absl::Span fft_length, - int device_ordinal, FftPlanCache* fft_plan_cache, - se::Stream* stream, se::DeviceMemoryAllocator* memory_allocator); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ diff --git a/xla/service/gpu/runtime3/triangular_solve_thunk.cc b/xla/service/gpu/runtime3/triangular_solve_thunk.cc deleted file mode 100644 index 47729a82f7e5e..0000000000000 --- a/xla/service/gpu/runtime3/triangular_solve_thunk.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime3/triangular_solve_thunk.h" - -#include "absl/strings/str_format.h" -#include "xla/service/gpu/make_batch_pointers.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -namespace xla { -namespace gpu { - -TriangularSolveThunk::TriangularSolveThunk( - ThunkInfo thunk_info, const TriangularSolveOptions& options, - se::GpuAsmOpts asm_opts, // - const BufferAllocation::Slice& a_buffer, - const BufferAllocation::Slice& b_buffer, - const BufferAllocation::Slice& temp_buffer, // - PrimitiveType type, int64_t batch_size, int64_t m, int64_t n, - int64_t a_batch_stride, int64_t b_batch_stride) - : Thunk(Kind::kTriangularSolve, thunk_info), - asm_opts_(asm_opts), - uplo_(options.lower() ? se::blas::UpperLower::kLower - : se::blas::UpperLower::kUpper), - side_(options.left_side() ? se::blas::Side::kLeft - : se::blas::Side::kRight), - unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit - : se::blas::Diagonal::kNonUnit), - a_buffer_(a_buffer), - b_buffer_(b_buffer), - temp_buffer_(temp_buffer), - type_(type), - batch_size_(batch_size), - m_(m), - n_(n), - a_batch_stride_(a_batch_stride), - b_batch_stride_(b_batch_stride) { - transpose_a_ = [&] { - switch (options.transpose_a()) { - case TriangularSolveOptions::NO_TRANSPOSE: - return se::blas::Transpose::kNoTranspose; - case TriangularSolveOptions::TRANSPOSE: - return se::blas::Transpose::kTranspose; - case TriangularSolveOptions::ADJOINT: - return se::blas::Transpose::kConjugateTranspose; - default: - LOG(ERROR) << "Invalid triangular solve transpose value " - << options.transpose_a(); - return se::blas::Transpose::kNoTranspose; - } - }(); -} - -Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) { - auto& buffer_allocations = *params.buffer_allocations; - return RunTriangularSolve(buffer_allocations.GetDeviceAddress(a_buffer_), - buffer_allocations.GetDeviceAddress(b_buffer_), - buffer_allocations.GetDeviceAddress(temp_buffer_), - asm_opts_, uplo_, side_, unit_diagonal_, - transpose_a_, type_, batch_size_, m_, n_, - a_batch_stride_, b_batch_stride_, params.stream); -} - -Status RunTriangularSolve(se::DeviceMemoryBase a_data, - se::DeviceMemoryBase b_data, - se::DeviceMemoryBase temp_data, - se::GpuAsmOpts asm_opts, se::blas::UpperLower uplo, - se::blas::Side side, se::blas::Diagonal unit_diagonal, - se::blas::Transpose transpose_a, PrimitiveType type, - int64_t batch_size, int64_t m, int64_t n, - int64_t a_batch_stride, int64_t b_batch_stride, - se::Stream* stream) { - VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo) - << " side=" << se::blas::SideString(side) - << " diagonal=" << se::blas::DiagonalString(unit_diagonal) - << " batch_size=" << batch_size << " m=" << m << " n=" << n - << " a_batch_stride=" << a_batch_stride - << " b_batch_stride=" << b_batch_stride; - - const int lda = side == se::blas::Side::kLeft ? m : n; - const int ldb = m; - - bool launch_ok; - if (batch_size == 1) { - switch (type) { - case F32: { - se::DeviceMemory b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, se::DeviceMemory(a_data), - lda, &b_data_typed, ldb) - .ok(); - break; - } - case F64: { - se::DeviceMemory b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0, se::DeviceMemory(a_data), - lda, &b_data_typed, ldb) - .ok(); - break; - } - case C64: { - se::DeviceMemory> b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory>(a_data), - lda, &b_data_typed, ldb) - .ok(); - break; - } - case C128: { - se::DeviceMemory> b_data_typed(b_data); - launch_ok = - stream - ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0, - se::DeviceMemory>(a_data), - lda, &b_data_typed, ldb) - .ok(); - break; - } - default: - return InvalidArgument("Invalid type for triangular solve %d", type); - } - } else { - // cublas trsmBatched requires us to materialize out two arrays of - // batch_size_ pointers, pointing to the individual `a` and `b` matrices of - // our input. batch_pointers_bytes is the size in bytes of one of these - // arrays. - int64_t batch_pointers_bytes = sizeof(void*) * batch_size; - TF_RET_CHECK(temp_data.size() >= 2 * batch_pointers_bytes); - void** temp_base = reinterpret_cast(temp_data.opaque()); - se::DeviceMemoryBase a_pointers(temp_base, batch_pointers_bytes); - se::DeviceMemoryBase b_pointers(temp_base + batch_size, - batch_pointers_bytes); - - TF_RETURN_IF_ERROR(MakeBatchPointers(stream, a_data, a_batch_stride, - batch_size, a_pointers)); - TF_RETURN_IF_ERROR(MakeBatchPointers(stream, b_data, b_batch_stride, - batch_size, b_pointers)); - - switch (type) { - case F32: { - se::DeviceMemory typed_b_pointers(b_pointers); - launch_ok = - stream - ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m, - n, /*alpha=*/1.0f, - se::DeviceMemory(a_pointers), lda, - &typed_b_pointers, ldb, batch_size) - .ok(); - break; - } - case F64: { - se::DeviceMemory typed_b_pointers(b_pointers); - launch_ok = - stream - ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m, - n, /*alpha=*/1.0f, - se::DeviceMemory(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); - break; - } - case C64: { - se::DeviceMemory*> typed_b_pointers(b_pointers); - launch_ok = stream - ->ThenBlasTrsmBatched( - side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory*>(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); - break; - } - case C128: { - se::DeviceMemory*> typed_b_pointers(b_pointers); - launch_ok = stream - ->ThenBlasTrsmBatched( - side, uplo, transpose_a, unit_diagonal, m, n, - /*alpha=*/1.0f, - se::DeviceMemory*>(a_pointers), - lda, &typed_b_pointers, ldb, batch_size) - .ok(); - break; - } - default: - return InvalidArgument("Invalid type for triangular solve %d", type); - } - } - - if (!launch_ok) { - return InternalError("Unable to launch triangular solve"); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/runtime3/triangular_solve_thunk.h b/xla/service/gpu/runtime3/triangular_solve_thunk.h deleted file mode 100644 index 8f582db16f7bd..0000000000000 --- a/xla/service/gpu/runtime3/triangular_solve_thunk.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" - -namespace xla { -namespace gpu { - -// This class stores everything that StreamExecutor needs to launch a triangular -// solve (BlasTrsm). It is generated by IrEmitter. -// -// Thread-compatible. -class TriangularSolveThunk : public Thunk { - public: - TriangularSolveThunk(ThunkInfo thunk_info, - const TriangularSolveOptions& options, - se::GpuAsmOpts asm_opts, - const BufferAllocation::Slice& a_buffer, - const BufferAllocation::Slice& b_buffer, - const BufferAllocation::Slice& temp_buffer, - PrimitiveType type, int64_t batch_size, int64_t m, - int64_t n, int64_t a_batch_stride, - int64_t b_batch_stride); - - TriangularSolveThunk(const TriangularSolveThunk&) = delete; - TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; - - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - se::GpuAsmOpts asm_opts_; - const se::blas::UpperLower uplo_; - const se::blas::Side side_; - const se::blas::Diagonal unit_diagonal_; - se::blas::Transpose transpose_a_; - - const BufferAllocation::Slice a_buffer_; - const BufferAllocation::Slice b_buffer_; - const BufferAllocation::Slice temp_buffer_; - - const PrimitiveType type_; - const int64_t batch_size_; - const int64_t m_; - const int64_t n_; - const int64_t a_batch_stride_; - const int64_t b_batch_stride_; -}; - -Status RunTriangularSolve(se::DeviceMemoryBase a_data, - se::DeviceMemoryBase b_data, - se::DeviceMemoryBase temp_data, - se::GpuAsmOpts asm_opts, se::blas::UpperLower uplo, - se::blas::Side side, se::blas::Diagonal unit_diagonal, - se::blas::Transpose transpose_a, PrimitiveType type, - int64_t batch_size, int64_t m, int64_t n, - int64_t a_batch_stride, int64_t b_batch_stride, - se::Stream* stream); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_RUNTIME3_TRIANGULAR_SOLVE_THUNK_H_ diff --git a/xla/service/gpu/runtime_intrinsics.cc b/xla/service/gpu/runtime_intrinsics.cc index 4e3dd1f383440..8e4413c00e335 100644 --- a/xla/service/gpu/runtime_intrinsics.cc +++ b/xla/service/gpu/runtime_intrinsics.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "xla/service/collective_ops_utils.h" @@ -27,9 +29,8 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" @@ -46,38 +47,38 @@ std::string GetGpuPlatformName() { PlatformUtil::CanonicalPlatformName("gpu").value()); } -Status AssertOnGpu(void* stream_handle, void* buffer, - absl::string_view error_msg) { +absl::Status AssertOnGpu(void* stream_handle, void* buffer, + absl::string_view error_msg) { TF_ASSIGN_OR_RETURN( se::Platform * platform, - se::MultiPlatformManager::PlatformWithName(GetGpuPlatformName())); + se::PlatformManager::PlatformWithName(GetGpuPlatformName())); se::StreamExecutorConfig config; config.gpu_stream = stream_handle; TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, platform->GetExecutor(config)); se::Stream* stream = executor->FindAllocatedStream(stream_handle); if (!stream) { - return InternalError("Stream not found for: %p", stream_handle); + return Internal("Stream not found for: %p", stream_handle); } int8_t expected = false; int64_t byte_size = sizeof(int8_t); CHECK_EQ(byte_size, ShapeUtil::ByteSizeOfPrimitiveType(PrimitiveType::PRED)); - stream->ThenMemcpy( + TF_RETURN_IF_ERROR(stream->Memcpy( &expected, se::DeviceMemoryBase{buffer, static_cast(byte_size)}, - byte_size); + byte_size)); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); if (!static_cast(expected)) { - return InternalError("%s", error_msg); + return Internal("%s", error_msg); } - return OkStatus(); + return absl::OkStatus(); } void AssertionCustomCall(void* stream_handle, void** buffers, const char* opaque, int opaque_len, XlaCustomCallStatus* status) { - Status s = + absl::Status s = AssertOnGpu(stream_handle, buffers[0], absl::string_view{opaque, static_cast(opaque_len)}); if (!s.ok()) { diff --git a/xla/service/gpu/runtime_intrinsics.h b/xla/service/gpu/runtime_intrinsics.h index d73aced0e2a1f..c73f9f13c25ed 100644 --- a/xla/service/gpu/runtime_intrinsics.h +++ b/xla/service/gpu/runtime_intrinsics.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/runtime_intrinsics_test.cc b/xla/service/gpu/runtime_intrinsics_test.cc index 7a90539581ec6..f7de70ae7b8b0 100644 --- a/xla/service/gpu/runtime_intrinsics_test.cc +++ b/xla/service/gpu/runtime_intrinsics_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/scatter_slice_simplifier.cc b/xla/service/gpu/scatter_slice_simplifier.cc index ba945e04f413a..9672bf259a328 100644 --- a/xla/service/gpu/scatter_slice_simplifier.cc +++ b/xla/service/gpu/scatter_slice_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,30 @@ limitations under the License. #include "xla/service/gpu/scatter_slice_simplifier.h" +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -160,14 +173,14 @@ HloInstruction* CreateScatterFrom(HloScatterInstruction* scatter, class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { public: - Status HandleScatter(HloInstruction* instruction) override { + absl::Status HandleScatter(HloInstruction* instruction) override { auto* scatter = Cast(instruction); // Infer scatter shape from the slice users. std::optional result_shape = ScatterSliceMatcher(scatter).InferShape(); if (!result_shape.has_value()) { - return OkStatus(); + return absl::OkStatus(); } VLOG(2) << "Matched scatter " << scatter->name() << " with shape " << scatter->shape().ToString() << ", inferred result shape " @@ -181,8 +194,8 @@ class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { private: // Create a replacement for every user. If the user is a slice operation, // replace it in the computation graph, the old branch will be removed. - Status ReplaceAllUsersRecursive(HloInstruction* old_instruction, - HloInstruction* new_instruction) { + absl::Status ReplaceAllUsersRecursive(HloInstruction* old_instruction, + HloInstruction* new_instruction) { // Maintain the replacement map, needed for non-unary elementwise users. replacements_[old_instruction] = new_instruction; @@ -196,13 +209,14 @@ class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { } TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction)); } - return OkStatus(); + return absl::OkStatus(); } // Replace the slice user with a new scatter (or a new chain of operations // starting with a scatter). For elementwise operations, create a new user // with updated operands (build the chain). - Status ReplaceUserRecursive(HloInstruction* user, HloInstruction* operand) { + absl::Status ReplaceUserRecursive(HloInstruction* user, + HloInstruction* operand) { VLOG(3) << "Replacing scatter user " << user->name(); if (user->opcode() == HloOpcode::kSlice) { return ReplaceInstruction(user, operand); @@ -241,7 +255,7 @@ class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { } // namespace -StatusOr ScatterSliceSimplifier::Run( +absl::StatusOr ScatterSliceSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return ScatterSliceSimplifierVisitor{}.RunOnModule(module, execution_threads); diff --git a/xla/service/gpu/scatter_slice_simplifier.h b/xla/service/gpu/scatter_slice_simplifier.h index 374498fc31c5c..349837747466b 100644 --- a/xla/service/gpu/scatter_slice_simplifier.h +++ b/xla/service/gpu/scatter_slice_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ #define XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" namespace xla { @@ -43,7 +47,8 @@ class ScatterSliceSimplifier : public HloModulePass { public: absl::string_view name() const override { return "scatter-slice-simplifier"; } - StatusOr Run( + using HloPassInterface::Run; + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/scatter_slice_simplifier_test.cc b/xla/service/gpu/scatter_slice_simplifier_test.cc index 89b43e74acf64..281a4f0576e0c 100644 --- a/xla/service/gpu/scatter_slice_simplifier_test.cc +++ b/xla/service/gpu/scatter_slice_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/scatter_slice_simplifier.h" +#include +#include #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" diff --git a/xla/service/gpu/sequential_thunk.cc b/xla/service/gpu/sequential_thunk.cc deleted file mode 100644 index e403170af957e..0000000000000 --- a/xla/service/gpu/sequential_thunk.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/sequential_thunk.h" - -#include "tsl/platform/errors.h" -#include "tsl/profiler/lib/scoped_annotation.h" - -namespace xla { -namespace gpu { - -using ::tsl::profiler::ScopedAnnotation; - -SequentialThunk::SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks) - : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} - -std::string SequentialThunk::ToStringExtra(int indent) const { - std::string result = "\n"; - absl::StrAppend(&result, thunks().ToString(indent + 1, nullptr)); - return result; -} - -Status SequentialThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executor, src)); - } - return OkStatus(); -} - -Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { - for (const auto& thunk : thunks_) { - ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/sequential_thunk.h b/xla/service/gpu/sequential_thunk.h deleted file mode 100644 index e839f560e6e14..0000000000000 --- a/xla/service/gpu/sequential_thunk.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ -#define XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// A thunk that wraps a list of sub-thunks. Executing this thunk executes all -// the sub-thunks sequentially. This is useful to implement instructions that -// require multiple kernel launches or library calls. -class SequentialThunk : public Thunk { - public: - SequentialThunk(ThunkInfo thunk_info, ThunkSequence thunks); - SequentialThunk(const SequentialThunk&) = delete; - SequentialThunk& operator=(const SequentialThunk&) = delete; - - ThunkSequence& thunks() { return thunks_; } - const ThunkSequence& thunks() const { return thunks_; } - std::string ToStringExtra(int indent) const override; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - // The list of sub-thunks. - ThunkSequence thunks_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_ diff --git a/xla/service/gpu/sleep_kernel.cu.cc b/xla/service/gpu/sleep_kernel.cu.cc new file mode 100644 index 0000000000000..8f37d47e67734 --- /dev/null +++ b/xla/service/gpu/sleep_kernel.cu.cc @@ -0,0 +1,32 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/sleep_kernel.h" + +namespace xla::gpu { +namespace { + +// Use busy waiting instead of __nanosleep() to make the code more portable +// (__nanosleep requires __CUDA_ARCH__ >= 700) +__global__ void sleep(int64_t num_clocks) { + int64_t start = clock64(); + while (clock64() - start < num_clocks) continue; +} + +} // namespace + +void* GetSleepKernel() { return reinterpret_cast(&sleep); } + +} // namespace xla::gpu diff --git a/xla/service/gpu/sleep_kernel.h b/xla/service/gpu/sleep_kernel.h new file mode 100644 index 0000000000000..3e040b10860e6 --- /dev/null +++ b/xla/service/gpu/sleep_kernel.h @@ -0,0 +1,26 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_SLEEP_KERNEL_H_ +#define XLA_SERVICE_GPU_SLEEP_KERNEL_H_ + +namespace xla::gpu { + +// Returns a pointer to CUDA kernel that does sleep operation on device. +void* GetSleepKernel(); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_SLEEP_KERNEL_H_ diff --git a/xla/service/gpu/softmax_fusion.h b/xla/service/gpu/softmax_fusion.h deleted file mode 100644 index dc1e11e3455b5..0000000000000 --- a/xla/service/gpu/softmax_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_SOFTMAX_FUSION_H_ -#define XLA_SERVICE_GPU_SOFTMAX_FUSION_H_ - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" - -namespace xla::gpu { - -// Pass to match softmax patterns and replace them with custom calls. -class SoftmaxFusion : public HloModulePass { - public: - SoftmaxFusion() = default; - - absl::string_view name() const override { return "softmax_fusion"; } - - using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_SOFTMAX_FUSION_H_ diff --git a/xla/service/gpu/softmax_rewriter_triton.cc b/xla/service/gpu/softmax_rewriter_triton.cc index 039bc7baf371d..2db88bc59379a 100644 --- a/xla/service/gpu/softmax_rewriter_triton.cc +++ b/xla/service/gpu/softmax_rewriter_triton.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -13,8 +13,8 @@ limitations under the License. #include "xla/service/gpu/softmax_rewriter_triton.h" #include -#include #include +#include #include #include "absl/algorithm/container.h" @@ -22,10 +22,13 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" @@ -33,11 +36,11 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/triton_support.h" +#include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -48,43 +51,14 @@ limitations under the License. namespace xla::gpu { namespace { +using hlo_query::IsBroadcastOfParameter; +using hlo_query::IsBroadcastOfScalarConstant; + bool HasDefaultLayout(const Shape& shape) { return shape.has_layout() && LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -bool IsTritonSupportedInstruction(const HloInstruction* instr, - const se::GpuComputeCapability& gpu_version) { - if (!instr->shape().IsArray()) { - return false; - } - - if (!IsTritonSupportedDataType(instr->shape().element_type(), gpu_version)) { - return false; - } - - for (const HloInstruction* operand : instr->operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { - return false; - } - } - - // TODO(bchetioui): expand with non-trivial instructions. - if (instr->IsElementwise()) { - return IsTritonSupportedElementwise(instr->opcode(), - instr->shape().element_type()); - } - - switch (instr->opcode()) { - case HloOpcode::kBitcast: - case HloOpcode::kParameter: - return true; - default: - return false; - } -} - // Returns true if a trivially connected producer of 'consumer' with opcode // 'opcode' exists. If such an instruction is found, the value of 'producer' is // set to it. The definition of "trivial" operations is as given in @@ -128,7 +102,103 @@ inline bool HasOneUse(const HloInstruction* instr) { return instr->user_count() == 1; } -using hlo_query::IsBroadcastOfScalarConstant; +// Supports two types of broadcast of parameters. Either to one batch +// dim, or one reduction dim. For example the following cases are supported: +// +// Case #1: +// p = f32[a] parameter(0) +// b = f32[a,x] broadcast(p), dimensions={0} +// +// Case #2: +// p = f32[a] parameter(0) +// b = f32[x,a] broadcast(p), dimensions={1} +// +// Case #3: +// p = f32[a,b] parameter(0) +// b = f32[x,a,b] broadcast(p), dimensions={1,2} +// +// Other broadcast tiling patterns are currently unsupported. +// See b/328049138 for details. +// +// Unsupported case #1: +// p = f32[a] parameter(0) +// b = f32[x,a,y] broadcast(p), dimensions={1} +// +// Unsupported case #2: +// p = f32[a,b] parameter(0) +// b = f32[x,a,y,b] broadcast(p), dimensions={1,3} +// +// Unsupported case #3: +// p = f32[a] parameter(0) +// b = f32[x,y,a] broadcast(p), dimensions={2} +// +// Unsupported case #4: +// p = f32[a,b] parameter(0) +// b = f32[a,x,b] broadcast(p), dimensions={0,2} +bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) { + CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast) + << "Expected broadcast " << hlo.ToShortString(); + CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter) + << "Expected parameter " << hlo.operand(0)->ToShortString(); + + const HloBroadcastInstruction* broadcast = + Cast(&hlo); + + const HloParameterInstruction* parameter = + Cast(hlo.operand(0)); + + // Support only one dim broadcast. + if (parameter->shape().dimensions_size() + 1 != + broadcast->shape().dimensions_size()) { + return false; + } + + // It is enough to ensure that the broadcast does not preserve both last, and + // first dimensions of the parameter at the same time. Otherwise the broadcast + // is the unsupported case #4. + // + // Preserve the first dim: + // p = f32[a,b] parameter(0) + // b1 = f32[a,b,c] broadcast(p), dimensions={0,1} + bool preserve_first_dim = broadcast->dimensions().front() == 0; + // Preserve the last dim: + // p = f32[a,b] parameter(0) + // b1 = f32[c,a,b] broadcast(p), dimensions={1,2} + bool preserve_last_dim = broadcast->dimensions().back() == + broadcast->shape().dimensions_size() - 1; + // We do not want to preserve both first and last dim, as it means the + // broadcast is not expanding on outermost dims. + return !(preserve_first_dim && preserve_last_dim); +} + +bool IsBroadcastOfAScalar(const HloInstruction& hlo) { + CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast) + << "Expected broadcast " << hlo.ToShortString(); + return ShapeUtil::IsScalar(hlo.operand(0)->shape()); +} + +bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) { + CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast) + << "Expected broadcast " << hlo.ToShortString(); + CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter) + << "Expected parameter " << hlo.operand(0)->ToShortString(); + + const HloBroadcastInstruction* broadcast = + Cast(&hlo); + const HloParameterInstruction* parameter = + Cast(hlo.operand(0)); + + if (parameter->shape().dimensions_size() != 1) { + return false; + } + return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1; +} + +bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) { + return IsBroadcastOfParameter(hlo) && + (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) || + IsSingleRowParameterBroadcast(hlo)); +} // Chooses which operand to use for fusion processing. Taking in a unary or // binary instruction, returns the first non-splat operand. If none is @@ -137,8 +207,11 @@ HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) { CHECK_GT(instr->operand_count(), 0); CHECK_LE(instr->operand_count(), 2); + // TODO(b/326217416): Extend the broadcast of splat constants/parameters to a + // broadcast of any op. if (instr->operand_count() > 1 && - IsBroadcastOfScalarConstant(*instr->operand(0))) { + (IsBroadcastOfScalarConstant(*instr->operand(0)) || + IsSupportedBroadcastOfParameter(*instr->operand(0)))) { return instr->mutable_operand(1); } return instr->mutable_operand(0); @@ -163,7 +236,7 @@ bool IsTriviallyFusible(HloInstruction* instr, } if (instr->IsElementwise() && instr->operand_count() == 1) { - return IsTritonSupportedInstruction(instr, gpu_version); + return static_cast(IsTritonSupportedInstruction(*instr, gpu_version)); } // Elementwise binary ops are trivially fusible if the operands are the same, @@ -175,14 +248,20 @@ bool IsTriviallyFusible(HloInstruction* instr, // Elementwise binary ops should be fused if both operands are the same and // if the operand is triton supported. if (operand_0 == operand_1) { - return IsTritonSupportedInstruction(instr, gpu_version); + return static_cast( + IsTritonSupportedInstruction(*instr, gpu_version)); } // For simplicity we only fuse elementwise binary ops with splat operands // if they contain one non-splat operand. - if (IsBroadcastOfScalarConstant(*operand_0) ^ - IsBroadcastOfScalarConstant(*operand_1)) { - return IsTritonSupportedInstruction(instr, gpu_version); + // TODO(b/326217416): Extend the broadcast of splat constants/parameters to + // a broadcast of any op. + if ((IsBroadcastOfScalarConstant(*operand_0) || + IsSupportedBroadcastOfParameter(*operand_0)) ^ + (IsBroadcastOfScalarConstant(*operand_1) || + IsSupportedBroadcastOfParameter(*operand_1))) { + return static_cast( + IsTritonSupportedInstruction(*instr, gpu_version)); } } @@ -228,85 +307,6 @@ bool IsTriviallyConnectedProducerOf( return false; } -bool IsTritonSupportedComputation(const HloComputation* computation, - const se::GpuComputeCapability& gpu_version) { - for (const HloInstruction* instr : computation->instructions()) { - if (!IsTritonSupportedInstruction(instr, gpu_version)) { - return false; - } - } - return true; -} - -std::optional MatchesTritonCompatibleClosedReductionDiamond( - HloInstruction* instr, const se::GpuComputeCapability& gpu_version) { - // Return the producer of the following pattern: - // - // producer - // | \ - // | reduce_{max,sum,...} - // | | - // | broadcast - // | / - // binop (elementwise) - // - // where each edge is allowed to contain also trivial operations that can be - // generated by Triton. We mean by "trivial" here those operations that do not - // increase the amount of memory read/written by the fusion, and that are - // compatible with any chosen tiling. - // - // We also assume that the reduction is done on the last axis of the producer - // array. - std::optional match_failure = std::nullopt; - - if (!instr->IsElementwiseBinary() || - !IsTritonSupportedInstruction(instr, gpu_version)) { - return match_failure; - } - - HloInstruction* producer; - HloInstruction* broadcast; - HloInstruction* reduce; - - if (!(TrivialEdge(&broadcast, instr->mutable_operand(1), - HloOpcode::kBroadcast, gpu_version) && - TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, - gpu_version) && - HasDefaultLayout(broadcast->shape()) && - HasDefaultLayout(reduce->shape()) && reduce->operand_count() == 2 && - reduce->operand(1)->opcode() == HloOpcode::kConstant && - IsTritonSupportedComputation(reduce->to_apply(), gpu_version))) { - return match_failure; - } - - if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { - return match_failure; - } - - producer = reduce->mutable_operand(0); - - if (!(reduce->dimensions().size() == 1 && - reduce->dimensions(0) == producer->shape().rank() - 1 && - !absl::c_linear_search(broadcast->dimensions(), - broadcast->shape().rank() - 1))) { - return match_failure; - } - - while (IsTriviallyFusible(producer, gpu_version)) { - producer = ChooseOperandForFusionProcessing(producer); - } - - if (!HasDefaultLayout(producer->shape()) || - !IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), - gpu_version) || - !(producer == instr->operand(0) || - instr->operand(0)->user_count() == 1)) { - return match_failure; - } - - return producer; -} - // Finds the first non-fusible producer of a diamond. This instruction is either // 1. the direct producer of the diamond, if that producer is used more than // twice and/or is not otherwise trivially fusible @@ -327,7 +327,7 @@ HloInstruction* FindFirstNonFusibleDiamondProducer( return diamond_producer; } -Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { +absl::Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { auto [root, producer] = diamond_chain; std::string suggested_name = "triton_softmax"; @@ -336,40 +336,54 @@ Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { absl::flat_hash_map old_to_new_mapping; - old_to_new_mapping[producer] = builder.AddInstruction( - HloInstruction::CreateParameter(0, producer->shape(), "parameter_0")); + int param = 0; + old_to_new_mapping[producer] = + builder.AddInstruction(HloInstruction::CreateParameter( + param, producer->shape(), absl::StrCat("parameter_", param))); + param++; + + std::vector parameters = {producer}; - std::function create_computation = - [&](const HloInstruction* instr) -> void { + std::function create_computation = + [&](HloInstruction* instr) -> void { if (old_to_new_mapping.contains(instr)) { return; } std::vector new_operands; - for (const HloInstruction* operand : instr->operands()) { + for (HloInstruction* operand : instr->mutable_operands()) { create_computation(operand); new_operands.push_back(old_to_new_mapping[operand]); } - old_to_new_mapping[instr] = builder.AddInstruction( - instr->CloneWithNewOperands(instr->shape(), new_operands)); + if (instr->opcode() == HloOpcode::kParameter) { + old_to_new_mapping[instr] = + builder.AddInstruction(HloInstruction::CreateParameter( + param, instr->shape(), absl::StrCat("parameter_", param))); + parameters.push_back(instr); + param++; + } else { + old_to_new_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), new_operands)); + } }; create_computation(root); - HloComputation* computation = root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); HloInstruction* softmax_fusion = root->parent()->AddInstruction(HloInstruction::CreateFusion( - root->shape(), HloInstruction::FusionKind::kCustom, - std::vector({producer}), computation)); + root->shape(), HloInstruction::FusionKind::kCustom, parameters, + computation)); softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion, suggested_name); - TF_ASSIGN_OR_RETURN(auto backend_config, - softmax_fusion->backend_config()); + TF_ASSIGN_OR_RETURN(auto gpu_config, + softmax_fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind(std::string(kTritonSoftmaxFusionKind)); - TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config)); if (root->IsRoot()) { root->parent()->set_root_instruction(softmax_fusion); @@ -381,13 +395,88 @@ Status FuseDiamondChainImpl(const DiamondChainDescriptor& diamond_chain) { } VLOG(5) << softmax_fusion->ToString(); - return OkStatus(); + return absl::OkStatus(); } using DiamondDescriptor = DiamondChainDescriptor; } // anonymous namespace +DiamondMatchingDecision +SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( + HloInstruction* instr) const { + if (!instr->IsElementwiseBinary()) { + return "Root is not elementwise binary."; + } + + if (!IsTritonSupportedInstruction(*instr, gpu_version_)) { + return "Root is not supported for Triton instruction."; + } + + HloInstruction* producer; + HloInstruction* broadcast; + HloInstruction* reduce; + + if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast, + gpu_version_)) { + return "Could not find a trivial connection from root to a broadcast."; + } + + if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, + gpu_version_)) { + return "Could not find a trivial connection from matched broadcast to a " + "reduction."; + } + + if (!(HasDefaultLayout(broadcast->shape()) && + HasDefaultLayout(reduce->shape()))) { + return "Broadcast or reduce have non-default layouts."; + } + + if (CodegenDecision is_supported = + IsTritonSupportedInstruction(*reduce, gpu_version_); + !is_supported) { + VLOG(3) << is_supported.Explain(); + return is_supported; + } + + if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { + return "More than one use of broadcast or reduce."; + } + + producer = reduce->mutable_operand(0); + + if (absl::c_linear_search(broadcast->dimensions(), + broadcast->shape().rank() - 1)) { + return "Broadcast is not along the reduction dimension."; + } + + while (IsTriviallyFusible(producer, gpu_version_)) { + producer = ChooseOperandForFusionProcessing(producer); + } + + if (!HasDefaultLayout(producer->shape())) { + return "Producer has non-default layout."; + } + + if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), + gpu_version_)) { + return "Producer is not trivially connected."; + } + + if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) { + return "Unsupported root-producer connection."; + } + + VLOG(5) << "Matched Softmax diamond with: "; + VLOG(5) << "root: " << instr->ToString(); + VLOG(5) << "producer: " << producer->ToString(); + VLOG(5) << "broadcast: " << broadcast->ToString(); + VLOG(5) << "reduce: " << reduce->ToString(); + + return producer; +} + std::vector SoftmaxRewriterTriton::FindAllFusibleDiamondChains( HloModule& module, @@ -408,9 +497,16 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( continue; } - if (auto producer = MatchesTritonCompatibleClosedReductionDiamond( - instr, gpu_version_)) { - matched_diamonds.push_back(DiamondDescriptor{instr, producer.value()}); + auto producer = MatchesTritonCompatibleClosedReductionDiamond(instr); + if (std::holds_alternative(producer)) { + matched_diamonds.push_back(DiamondDescriptor{ + instr, + std::get(producer), + }); + } else { + VLOG(5) << "Cannot match the diamond pattern for instruction " + << instr->ToString() + << ". Reason: " << std::get(producer).Explain(); } } } @@ -506,7 +602,8 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( // diamond producer of diamond chain n+1. diamond_chains.push_back(DiamondChainDescriptor{ last_trivially_fusible_user(previous_diamond_root), - current_fusion_producer}); + current_fusion_producer, + }); current_fusion_producer = first_non_fusible_diamond_producer; current_reduce_dimension_size = diamond_reduce_dimension_size; @@ -520,14 +617,21 @@ SoftmaxRewriterTriton::FindAllFusibleDiamondChains( return diamond_chains; } -Status SoftmaxRewriterTriton::FuseDiamondChain( +absl::Status SoftmaxRewriterTriton::FuseDiamondChain( const DiamondChainDescriptor& diamond_chain) { return FuseDiamondChainImpl(diamond_chain); } -StatusOr SoftmaxRewriterTriton::Run( +absl::StatusOr SoftmaxRewriterTriton::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + auto cuda_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + std::vector diamond_chains = FindAllFusibleDiamondChains(*module, execution_threads); diff --git a/xla/service/gpu/softmax_rewriter_triton.h b/xla/service/gpu/softmax_rewriter_triton.h index fd131f8336d75..9463d510f4590 100644 --- a/xla/service/gpu/softmax_rewriter_triton.h +++ b/xla/service/gpu/softmax_rewriter_triton.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ #define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#include #include #include "absl/container/flat_hash_set.h" @@ -23,8 +24,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/service/instruction_fusion.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -35,6 +36,8 @@ struct DiamondChainDescriptor { HloInstruction* producer = nullptr; }; +using DiamondMatchingDecision = std::variant; + // Rewrite compatible Softmax into a custom fusion region to be code-generated // with the Triton-based Softmax emitter. class SoftmaxRewriterTriton : public HloModulePass { @@ -44,7 +47,7 @@ class SoftmaxRewriterTriton : public HloModulePass { absl::string_view name() const override { return "triton-softmax-rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -58,7 +61,27 @@ class SoftmaxRewriterTriton : public HloModulePass { // Constructs a Softmax fusion containing all the instructions between the // root and the producer of a diamond chain. The producer is excluded from the // fusion. - Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain); + absl::Status FuseDiamondChain(const DiamondChainDescriptor& diamond_chain); + + // Return the producer of the following pattern: + // + // producer + // | \ + // | reduce_{max,sum,...} + // | | + // | broadcast + // | / + // binop (elementwise) + // + // where each edge is allowed to contain also trivial operations that can be + // generated by Triton. We mean by "trivial" here those operations that do not + // increase the amount of memory read/written by the fusion, and that are + // compatible with any chosen tiling. + // + // We also assume that the reduction is done on the last axis of the producer + // array. + DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamond( + HloInstruction* instr) const; private: se::GpuComputeCapability gpu_version_; diff --git a/xla/service/gpu/softmax_rewriter_triton_test.cc b/xla/service/gpu/softmax_rewriter_triton_test.cc index ad055a03414be..74e800f9a815c 100644 --- a/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -1,8 +1,11 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. + You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,6 +16,7 @@ limitations under the License. #include #include +#include #include #include @@ -21,8 +25,10 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" +#include "xla/service/instruction_fusion.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/statusor.h" @@ -38,10 +44,12 @@ namespace { namespace m = ::xla::match; +using ::testing::HasSubstr; + // Wrapper around SoftmaxRewriterTriton(gpu_version).Run(module) that finds // and fuses as many diamond chains as possible without invoking any kind of // cost analysis. -StatusOr SoftmaxRewriterTritonMatchAndRewrite( +absl::StatusOr SoftmaxRewriterTritonMatchAndRewrite( se::GpuComputeCapability gpu_version, HloModule* module) { CHECK_NE(module, nullptr); SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version); @@ -1068,9 +1076,8 @@ ENTRY main { GmockMatch(m::Fusion(m::Parameter()))); } -TEST_P( - SoftmaxRewriterTritonTest, - CanOnlyFuseConvertInvolvingBF16InputIntoSoftmaxDiamondWithAtLeastAmpereComputeCapability) { // NOLINT(whitespace/line_length) +TEST_P(SoftmaxRewriterTritonTest, + CanFuseConvertInvolvingBF16InputIntoSoftmaxDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_string_template = R"( HloModule softmax @@ -1086,52 +1093,51 @@ ENTRY main { reduce = $0[127]{0} reduce(param_0_$0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} ROOT subtract = $0[127,125]{1,0} subtract(param_0_$0, broadcast) -} -)"; +})"; const std::string hlo_string = absl::Substitute(hlo_string_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - auto ampere_module = ParseAndReturnVerifiedModule(hlo_string).value(); - auto volta_module = ampere_module->Clone(); + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - // Ampere EXPECT_TRUE( SoftmaxRewriterTritonMatchAndRewrite( se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}, - ampere_module.get()) + module.get()) .value()); - EXPECT_TRUE(verifier().Run(ampere_module.get()).status().ok()); - VLOG(2) << ampere_module->ToString(); - EXPECT_THAT(ampere_module->entry_computation()->root_instruction(), + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter()))); +} - // Volta (pre-Ampere) - VLOG(2) << volta_module->ToString(); +TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereGpu) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +})"; - switch (data_type) { - case F32: - case F16: - EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite( - se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}, - volta_module.get()) - .value()); - EXPECT_TRUE(verifier().Run(volta_module.get()).status().ok()); - EXPECT_THAT(volta_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Convert(m::Parameter())))); - break; - case BF16: - // When bf16 is used, no fusion is possible on Volta. - EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite( - se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}, - volta_module.get()) - .value()); - break; - default: - ABSL_UNREACHABLE(); - } + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + + EXPECT_THAT( + SoftmaxRewriterTriton( + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()), + tsl::testing::StatusIs( + tsl::error::FAILED_PRECONDITION, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); } TEST_P(SoftmaxRewriterTritonTest, DoesNotFuseConvertWithC64DataType) { @@ -1736,6 +1742,301 @@ ENTRY main { SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); } +TEST_F(SoftmaxRewriterTritonTest, FusionDecisionIsCapturedExplicitly) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,125]{1,0} parameter(0) + identity = f32[] parameter(1) + reduce = f32[127]{0} reduce(param_0, identity), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast) +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + SoftmaxRewriterTriton softmax_rewriter_triton(gpu_version_); + int unmatched = 0, matched = 0; + for (HloInstruction* instruction : + module->entry_computation()->MakeInstructionPostOrder()) { + DiamondMatchingDecision decision = + softmax_rewriter_triton.MatchesTritonCompatibleClosedReductionDiamond( + instruction); + if (std::holds_alternative(decision)) { + std::string actual_decision = + std::get(decision).Explain(); + EXPECT_THAT( + actual_decision, + AnyOf(HasSubstr("Root is not elementwise binary"), + HasSubstr("Reduction init value should be a constant or a " + "convert of a constant."))); + unmatched++; + } else { + matched++; + } + } + EXPECT_EQ(unmatched, 5); + EXPECT_EQ(matched, 0); +} + +TEST_F( + SoftmaxRewriterTritonTest, + FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongReductionDimAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[32]{0} parameter(0) + p1 = f32[32,16]{1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation + b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0} + b1 = f32[32,16]{1,0} broadcast(p0), dimensions={0} + add0 = f32[32,16]{1,0} add(b1, p1) + ROOT add1 = f32[32,16]{1,0} add(add0, b0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + FusesBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[16]{0} parameter(0) + p1 = f32[32,16]{1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[32]{0} reduce(p1, c), dimensions={1}, to_apply=add_computation + b0 = f32[32,16]{1,0} broadcast(r0), dimensions={0} + b1 = f32[32,16]{1,0} broadcast(p0), dimensions={1} + add0 = f32[32,16]{1,0} add(b1, p1) + ROOT add1 = f32[32,16]{1,0} add(add0, b0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + FusesBinaryElementwiseIfIntermediateDiamondOpWithMultiDimTensorBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[32,16]{1,0} parameter(0) + p1 = f32[64,32,16]{2,1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[64,32]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation + b0 = f32[64,32,16]{2,1,0} broadcast(r0), dimensions={0,1} + b1 = f32[64,32,16]{2,1,0} broadcast(p0), dimensions={1,2} + add0 = f32[64,32,16]{2,1,0} add(b1, p1) + ROOT add1 = f32[64,32,16]{2,1,0} add(add0, b0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + FusesBinaryElementwiseIfIntermediateDiamondOpWithZeroDimTensorBroadcastAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + parameter_0 = f32[] parameter(0) + parameter_1 = f32[64,32,16]{2,1,0} parameter(1) + c = f32[] constant(0) + + reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation + broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1} + broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={} + add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1) + ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongNonReductionDimensions) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + parameter_0 = f32[16] parameter(0) + parameter_1 = f32[64,32,16]{2,1,0} parameter(1) + c = f32[] constant(0) + + reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation + broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1} + broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={2} + add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1) + ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + parameter_0 = f32[64] parameter(0) + parameter_1 = f32[64,32,16]{2,1,0} parameter(1) + c = f32[] constant(0) + + reduce_0 = f32[64,32]{1,0} reduce(parameter_1, c), dimensions={2}, to_apply=add_computation + broadcast_0 = f32[64,32,16]{2,1,0} broadcast(reduce_0), dimensions={0,1} + broadcast_1 = f32[64,32,16]{2,1,0} broadcast(parameter_0), dimensions={0} + add_0 = f32[64,32,16]{2,1,0} add(broadcast_1, parameter_1) + ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_FALSE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[8]{0} parameter(0) + p1 = f32[32,8,16]{2,1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation + b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1} + b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1} + add0 = f32[32,8,16]{2,1,0} add(b1, p1) + ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_FALSE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[16,64]{1,0} parameter(0) + p1 = f32[8,16,32,64]{3,2,1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation + b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2} + b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3} + add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1) + ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_FALSE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + +TEST_F( + SoftmaxRewriterTritonTest, + DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) { // NOLINT(whitespace/line_length) + const std::string hlo_string = R"( +HloModule h1 + +add_computation { + y = f32[] parameter(1) + x = f32[] parameter(0) + ROOT add = f32[] add(x, y) +} + +ENTRY main { + p0 = f32[32,16]{1,0} parameter(0) + p1 = f32[128,64,32,16]{3,2,1,0} parameter(1) + c = f32[] constant(0) + + r0 = f32[128,64,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation + b0 = f32[128,64,32,16]{3,2,1,0} broadcast(r0), dimensions={0,1,2} + b1 = f32[128,64,32,16]{3,2,1,0} broadcast(p0), dimensions={2,3} + add0 = f32[128,64,32,16]{3,2,1,0} add(b1, p1) + ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0) +})"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_FALSE( + SoftmaxRewriterTritonMatchAndRewrite(gpu_version_, module.get()).value()); +} + INSTANTIATE_TEST_SUITE_P(SoftmaxRewriterTritonTestSuite, SoftmaxRewriterTritonTest, ::testing::Values(F32, F16, BF16)); diff --git a/xla/service/gpu/split_k_gemm_rewriter.cc b/xla/service/gpu/split_k_gemm_rewriter.cc index 2da6897e22734..2ff84c2c61ea9 100644 --- a/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/xla/service/gpu/split_k_gemm_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,15 +102,35 @@ void CopyIncrementingAboveThreshold(absl::Span source, } } -Status UncompilableMatmul(absl::string_view explanation) { - Status s = absl::CancelledError(explanation); +absl::Status UncompilableMatmul(absl::string_view explanation) { + absl::Status s = absl::CancelledError(explanation); s.SetPayload(kUncompilableFusion, absl::Cord(explanation)); return s; } +absl::StatusOr MakeSparseMetaOperand( + HloDotInstruction& dot, const TritonGemmConfig& config) { + CHECK_EQ(dot.sparse_operands(), 1); + CHECK_EQ(dot.sparsity().front().index(), 0); + + HloInstruction* meta = dot.mutable_operand(2); + const Shape& shape = meta->shape(); + if (shape.dimensions().back() % config.split_k != 0) { + return UncompilableMatmul("Sparsity metadata has incorrect shape."); + } + + std::vector dimensions(shape.dimensions().begin(), + shape.dimensions().end() - 1); + dimensions.push_back(config.split_k); + dimensions.push_back(shape.dimensions().back() / config.split_k); + Shape new_shape = ShapeUtil::MakeShapeWithDescendingLayout( + shape.element_type(), dimensions); + return MakeBitcastHlo(meta, new_shape); +} + } // namespace -StatusOr MakeSplitKOperand( +absl::StatusOr MakeSplitKOperand( HloInstruction& dot, const TritonFusionAnalysis& analysis, const TritonGemmConfig& config, const int64_t contracting_dim_idx, const int operand_number) { @@ -127,7 +147,7 @@ StatusOr MakeSplitKOperand( analysis.IterSpec(scope, &hlo, contracting_dim_idx); if (spec == nullptr) { // No contracting dimension - no checks needed. - return OkStatus(); + return absl::OkStatus(); } if (spec->size() != 1) { return UncompilableMatmul("Unsupported case."); @@ -145,7 +165,7 @@ StatusOr MakeSplitKOperand( return UncompilableMatmul( "Too small divisible part of the contracting dimension."); } - return OkStatus(); + return absl::OkStatus(); }; // The divisibility check is only used to ensure that the TritonFusionAnalysis @@ -214,17 +234,18 @@ StatusOr MakeSplitKOperand( // Apply split K configuration from the tiling config to the fused dot() // computation: bitcast the operands, change the output shape and the dot // dimensions. -Status MakeDotComputationSplitKBatch(HloComputation* computation, - const TritonGemmConfig& config, - bool disable_reduced_precision_reduction) { - HloInstruction* dot = - hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); +absl::Status MakeDotComputationSplitKBatch( + HloComputation* computation, const TritonGemmConfig& config, + bool disable_reduced_precision_reduction) { + HloDotInstruction* dot = Cast( + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot)); TF_ASSIGN_OR_RETURN(const auto analysis, TritonFusionAnalysis::Execute(*computation)); const DotDimensionNumbers& old_dim_numbers = dot->dot_dimension_numbers(); DotDimensionNumbers new_dim_numbers; - const int64_t lhs_contracting_idx = ContractingDimensionIndex(*dot, 0); + TF_ASSIGN_OR_RETURN(const int64_t lhs_contracting_idx, + ContractingDimensionIndex(*dot, 0)); CopyIncrementingAboveThreshold( old_dim_numbers.lhs_contracting_dimensions(), *new_dim_numbers.mutable_lhs_contracting_dimensions(), @@ -234,7 +255,8 @@ Status MakeDotComputationSplitKBatch(HloComputation* computation, old_dim_numbers.lhs_batch_dimensions(), *new_dim_numbers.mutable_lhs_batch_dimensions(), lhs_contracting_idx); - const int64_t rhs_contracting_idx = ContractingDimensionIndex(*dot, 1); + TF_ASSIGN_OR_RETURN(const int64_t rhs_contracting_idx, + ContractingDimensionIndex(*dot, 1)); CopyIncrementingAboveThreshold( old_dim_numbers.rhs_contracting_dimensions(), *new_dim_numbers.mutable_rhs_contracting_dimensions(), @@ -244,6 +266,13 @@ Status MakeDotComputationSplitKBatch(HloComputation* computation, old_dim_numbers.rhs_batch_dimensions(), *new_dim_numbers.mutable_rhs_batch_dimensions(), rhs_contracting_idx); + // Make sure we have a supported sparse dot. + if (dot->sparse_operands()) { + if (dot->sparsity().size() != 1 || dot->sparsity().front().index() != 0) { + return UncompilableMatmul("Sparsity is only supported on left operand."); + } + } + // Collect HLOs to transform between dot output and root. These will // get a new major most batch dimension sized as split K factor. Other inputs // of these HLOs will get broadcasted. @@ -282,8 +311,17 @@ Status MakeDotComputationSplitKBatch(HloComputation* computation, CHECK_EQ(rhs->operand(0)->opcode(), HloOpcode::kPad); did_pad = true; } + std::vector sparsity(dot->sparsity().begin(), + dot->sparsity().end()); + std::vector sparse_meta(sparsity.size()); + for (int i = 0; i < sparsity.size(); ++i) { + // This is only correct for LHS sparse operand after dot decomposition. + sparsity[i].set_dimension(sparsity[i].dimension() + 1); + TF_ASSIGN_OR_RETURN(sparse_meta[i], + MakeSparseMetaOperand(*dot, config)); + } expanded = MakeDotHlo(lhs, rhs, new_dim_numbers, dot->precision_config(), - dot->shape().element_type()) + dot->shape().element_type(), sparsity, sparse_meta) .value(); // Make the added batch dimension the major-most, keep the order of the // original dimensions. @@ -350,11 +388,11 @@ Status MakeDotComputationSplitKBatch(HloComputation* computation, TritonFusionAnalysis::Execute(*computation, config.split_k).status()); } - return OkStatus(); + return absl::OkStatus(); } -Status MakeDotSplitKBatch(HloInstruction* dot_fusion, - const TritonGemmConfig& config) { +absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion, + const TritonGemmConfig& config) { CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion); if (dot_fusion->shape().IsTuple()) { @@ -379,9 +417,9 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, dot_fusion->parent()->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(root->shape().element_type()))); // The batch dimension to reduce is the first one by construction. - TF_ASSIGN_OR_RETURN( - HloInstruction * reduce, - MakeReduceHlo(dot_fusion, zero, /*dimensions=*/{0}, HloOpcode::kAdd)); + TF_ASSIGN_OR_RETURN(HloInstruction * reduce, + MakeReduceHlo(dot_fusion, zero, /*dimensions=*/{0}, + HloOpcode::kAdd, &dot_fusion->metadata())); // The output of the reduce has to have the layout of the original dot. *reduce->mutable_shape()->mutable_layout() = output_layout; @@ -403,7 +441,7 @@ Status MakeDotSplitKBatch(HloInstruction* dot_fusion, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace gpu diff --git a/xla/service/gpu/split_k_gemm_rewriter.h b/xla/service/gpu/split_k_gemm_rewriter.h index c74b4dc181f7e..234288e9b1956 100644 --- a/xla/service/gpu/split_k_gemm_rewriter.h +++ b/xla/service/gpu/split_k_gemm_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,8 +34,8 @@ bool HasDivisibleSuffixAllowingSplit(absl::Span span, // Apply split K configuration from the tiling config to the fusion instruction: // in addition to MakeDotComputationSplitKBatch on its computation add the // necessary reduction after it. -Status MakeDotSplitKBatch(HloInstruction* dot_fusion, - const TritonGemmConfig& config); +absl::Status MakeDotSplitKBatch(HloInstruction* dot_fusion, + const TritonGemmConfig& config); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/split_k_gemm_rewriter_test.cc b/xla/service/gpu/split_k_gemm_rewriter_test.cc index 40b2bde030dd3..3a66bc9c2b83c 100644 --- a/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,9 @@ limitations under the License. #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +// TODO(b/317016172): Inspect usages of TritonGemmConfig and potentially update +// them to to use newly exposed parameters. + namespace xla { namespace gpu { namespace { @@ -89,15 +92,17 @@ ENTRY e { p0 = s8[3,128,5,32]{3,2,1,0} parameter(0) p1 = bf16[16,128]{1,0} parameter(1) ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1), - kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm" + kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm", + metadata={op_name="foo"} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); TritonGemmConfig config(16, 16, 16, 4, 1, 4); TF_EXPECT_OK(MakeDotSplitKBatch( module->entry_computation()->root_instruction(), config)); - EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), - HloOpcode::kReduce); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kReduce); + EXPECT_EQ(root->metadata().op_name(), "foo"); } TEST_F(SplitKTest, MakeSplitKWithOutputFusion) { @@ -615,6 +620,72 @@ ENTRY e { TritonFusionAnalysis::Execute(*dot_computation)); } +TEST_F(SplitKTest, SparseDotWithLhsSparseOperandIsRewritten) { + const std::string hlo_text = R"( +HloModule test + +triton_gemm { + lhs = f16[2,5,1600] parameter(0) + rhs = f16[2,3200,10] parameter(1) + meta = u16[2,5,200] parameter(2) + ROOT dot = f32[2,5,10] dot(lhs, rhs, meta), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, sparsity=L.2@2:4 +} + +ENTRY e { + lhs = f16[2,5,1600] parameter(0) + rhs = f16[2,3200,10] parameter(1) + meta = u16[2,5,200] parameter(2) + ROOT fusion = f32[2,5,10] fusion(lhs, rhs, meta), + kind=kCustom, calls=triton_gemm, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + TritonGemmConfig config(16, 16, 16, /*split_k=*/4, 1, 1); + TF_EXPECT_OK(MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kReduce); + + HloInstruction* dot = + module->GetComputationWithName("triton_gemm")->root_instruction(); + EXPECT_EQ(dot->operand(0)->shape(), + ShapeUtil::MakeShapeWithDescendingLayout(F16, {2, 5, 4, 400})); + EXPECT_EQ(dot->operand(1)->shape(), + ShapeUtil::MakeShapeWithDescendingLayout(F16, {2, 4, 800, 10})); + EXPECT_EQ(dot->operand(2)->shape(), + ShapeUtil::MakeShapeWithDescendingLayout(U16, {2, 5, 4, 50})); +} + +TEST_F(SplitKTest, SparseDotWithRhsSparseOperandTriggersError) { + const std::string hlo_text = R"( +HloModule test + +triton_gemm { + lhs = f16[2,5,3200] parameter(0) + rhs = f16[2,1600,10] parameter(1) + meta = u16[2,200,10] parameter(2) + ROOT dot = f32[2,5,10] dot(lhs, rhs, meta), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, sparsity=R.1@2:4 +} + +ENTRY e { + lhs = f16[2,5,3200] parameter(0) + rhs = f16[2,1600,10] parameter(1) + meta = u16[2,200,10] parameter(2) + ROOT fusion = f32[2,5,10] fusion(lhs, rhs, meta), + kind=kCustom, calls=triton_gemm, backend_config="__triton_gemm" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + TritonGemmConfig config(16, 16, 16, /*split_k=*/4, 1, 1); + auto result = MakeDotSplitKBatch( + module->entry_computation()->root_instruction(), config); + EXPECT_FALSE(result.ok()); +} + class SplitKTestWithMorePreciseReduction : public HloTestBase, public ::testing::WithParamInterface { diff --git a/xla/service/gpu/stream_attribute_annotator.cc b/xla/service/gpu/stream_attribute_annotator.cc new file mode 100644 index 0000000000000..0bfa2cef837e2 --- /dev/null +++ b/xla/service/gpu/stream_attribute_annotator.cc @@ -0,0 +1,176 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/stream_attribute_annotator.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +bool IsOnlyRootNonDefaultStream(HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + auto root_gpu_config = root->backend_config(); + if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) { + return false; + } + int64_t root_stream_id = root_gpu_config->operation_queue_id(); + VLOG(2) << "Found fusion computation's root stream id to be " + << root_stream_id; + if (root_stream_id == Thunk::kDefaultExecutionStreamId.value()) { + return false; + } + for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { + if (instr == root) { + continue; + } + int64_t instr_stream_id = + instr->backend_config()->operation_queue_id(); + if (instr_stream_id != Thunk::kDefaultExecutionStreamId.value() && + instr_stream_id != root_stream_id) { + return false; + } + } + return true; +} + +absl::StatusOr AnnotateStreamAttributesForInstruction( + HloInstruction* instr, GpuBackendConfig& instr_gpu_config) { + if (instr->called_computations().size() != 1) { + return false; + } + HloComputation* called_comp = instr->called_computations()[0]; + int64_t stream_id = instr_gpu_config.operation_queue_id(); + + if (!IsOnlyRootNonDefaultStream(called_comp) || + stream_id != Thunk::kDefaultExecutionStreamId.value()) { + return false; + } + + auto comp_root_gpu_config = + called_comp->root_instruction()->backend_config(); + + instr_gpu_config.set_operation_queue_id( + comp_root_gpu_config->operation_queue_id()); + *instr_gpu_config.mutable_wait_on_operation_queues() = + comp_root_gpu_config->wait_on_operation_queues(); + TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); + return true; +} + +absl::StatusOr AnnotateStreamAttributesForCopyStart( + HloInstruction* instr, int64_t channel_id, + GpuBackendConfig& instr_gpu_config) { + // Do nothing if copy-start has already been annotated + if (instr_gpu_config.operation_queue_id() != + Thunk::kDefaultExecutionStreamId.value()) { + return false; + } + instr_gpu_config.set_operation_queue_id(channel_id); + TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); + VLOG(3) << "Add copy-start's backend config: " << channel_id; + return true; +} + +absl::StatusOr AnnotateStreamAttributesForUsers( + HloInstruction* instr, GpuBackendConfig& instr_gpu_config) { + bool changed = false; + int64_t stream_id = instr_gpu_config.operation_queue_id(); + if (stream_id == Thunk::kDefaultExecutionStreamId.value()) { + return changed; + } + std::vector all_consumers; + for (auto user : instr->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + user = user->users()[0]; + } + all_consumers.push_back(user); + } + + for (auto user : all_consumers) { + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + user->backend_config()); + auto it = absl::c_find(gpu_config.wait_on_operation_queues(), stream_id); + if (it == gpu_config.wait_on_operation_queues().end() && + gpu_config.operation_queue_id() != stream_id) { + gpu_config.mutable_wait_on_operation_queues()->Add(stream_id); + TF_RETURN_IF_ERROR(user->set_backend_config(gpu_config)); + changed = true; + } + } + + return changed; +} +} // namespace + +absl::StatusOr StreamAttributeAnnotator::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + XLA_VLOG_LINES( + 5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString()); + bool changed = false; + int64_t channel_id = hlo_query::NextChannelId(*module); + for (const HloComputation* comp : module->computations(execution_threads)) { + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + auto instr_gpu_config = instr->backend_config(); + if (!instr_gpu_config.ok()) { + continue; + } + // For fusion instruction, only annotate + // when the root of fusion is a single instruction + // running on non-default stream. + if (instr->opcode() == HloOpcode::kFusion) { + TF_ASSIGN_OR_RETURN(bool comp_result, + AnnotateStreamAttributesForInstruction( + instr, instr_gpu_config.value())); + changed |= comp_result; + } else if (instr->opcode() == HloOpcode::kCopyStart) { + TF_ASSIGN_OR_RETURN(bool comp_result, + AnnotateStreamAttributesForCopyStart( + instr, channel_id, instr_gpu_config.value())); + changed |= comp_result; + continue; + } + + TF_ASSIGN_OR_RETURN( + bool user_result, + AnnotateStreamAttributesForUsers(instr, instr_gpu_config.value())); + changed |= user_result; + } + } + XLA_VLOG_LINES( + 5, "StreamAttributeAnnotator::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/stream_attribute_annotator.h b/xla/service/gpu/stream_attribute_annotator.h new file mode 100644 index 0000000000000..8a0284adee390 --- /dev/null +++ b/xla/service/gpu/stream_attribute_annotator.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla::gpu { + +// This pass checks to see if: +// 1. there's any instruction, that +// consumes data from other computes streams, +// is missing "wait_on_operation_queues" attribute. +// 2. there's any fusion instruction with non-default +// stream fusion root. +// It will annotate the corresponding instruction with +// the correct attribute in GpuBackendConfig. +// Instructions annotated with operation_queue_id > 0 +// will be wrapped with AsyncInstruction and split into +// AsyncStart and AsyncDone in the +// StreamAttributeAsyncWrapper pass. +// We also check if there's any non-default-stream +// instruction's user doesn't have the correct "wait_on_operation_queues" +// attribute and set it with producer's operation_queue_id. +// "wait_on_operation_queues" will need to used by the emitter to emit the +// correct WaitForStreams thunk. + +class StreamAttributeAnnotator : public HloModulePass { + public: + absl::string_view name() const override { + return "stream-attribute-annotator"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ diff --git a/xla/service/gpu/stream_attribute_annotator_test.cc b/xla/service/gpu/stream_attribute_annotator_test.cc new file mode 100644 index 0000000000000..2861f9a82a7ef --- /dev/null +++ b/xla/service/gpu/stream_attribute_annotator_test.cc @@ -0,0 +1,212 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/stream_attribute_annotator.h" + +#include +#include +#include +#include + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using StreamAttributeAnnotatorTest = HloTestBase; + +TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p1_32 = f32[1] parameter(0) + p2_32 = f32[1] parameter(1) + add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]} + exp_32 = f32[1] exponential(add_32) + + neg32 = f32[1] negate(add_32) + ROOT add_out_32 = f32[1] add(neg32, exp_32) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAnnotator attr_annotator; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); + EXPECT_TRUE(changed); + + const HloInstruction* add = FindInstruction(module.get(), "add_32"); + for (auto user : add->users()) { + // Every user should have an annotation. + EXPECT_TRUE(user->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + user->backend_config()); + EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1); + } +} + +TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p1_32 = f32[1] parameter(0) + p2_32 = f32[1] parameter(1) + add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[]} + exp_32 = f32[1] exponential(p2_32), backend_config={"operation_queue_id":"2", "wait_on_operation_queues":[]} + + ROOT add_out_32 = f32[1] add(add_32, exp_32) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAnnotator attr_annotator; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); + EXPECT_TRUE(changed); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + // Root should wait on 2 streams. + EXPECT_TRUE(root->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + root->backend_config()); + std::vector expected_stream_ids = {1, 2}; + for (auto id : expected_stream_ids) { + auto it = absl::c_find(gpu_config.wait_on_operation_queues(), id); + EXPECT_NE(it, gpu_config.wait_on_operation_queues().end()); + } +} + +TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p1_32 = f32[16,32] parameter(0) + p2_32 = f32[32,16] parameter(1) + + custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}} + get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0 + + exp_32 = f32[16,16] exponential(get-tuple-element.24) + + ROOT neg32 = f32[16,16] negate(exp_32) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAnnotator attr_annotator; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); + EXPECT_TRUE(changed); + + const HloInstruction* exp = FindInstruction(module.get(), "exp_32"); + EXPECT_TRUE(exp->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + exp->backend_config()); + EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1); +} + +TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithFusion + + fused_computation.1 { + fusion_p0_32 = f32[16,16] parameter(0) + fusion_p2_32 = f32[16,16] parameter(1) + ROOT add = f32[16,16] add(fusion_p0_32, fusion_p2_32), backend_config={"operation_queue_id":"1","wait_on_operation_queues":[]} + } + + ENTRY entry { + p1_32 = f32[16,16] parameter(0) + p2_32 = f32[16,16] parameter(1) + ROOT fusion.1 = f32[16,16] fusion(p1_32, p2_32), kind=kLoop, calls=fused_computation.1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAnnotator attr_annotator; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); + EXPECT_TRUE(changed); + + const HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + EXPECT_TRUE(fusion->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + fusion->backend_config()); + EXPECT_EQ(gpu_config.operation_queue_id(), 1); +} + +TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { + constexpr absl::string_view kHloString = R"( + HloModule offloading + ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] { + %param_1 = f32[1024]{0} parameter(1) + %param_0 = f32[1024]{0} parameter(0) + %res_3 = f32[1024]{0} add(f32[1024]{0} %param_0, f32[1024]{0} %param_1) + %copy-start = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_3) + %res_4 = f32[1024]{0} tanh(f32[1024]{0} %res_3) + %copy-start.2 = (f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) copy-start(f32[1024]{0} %res_4) + %res_5 = f32[1024]{0} tanh(f32[1024]{0} %res_4) + %copy-done = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start) + %res_6 = f32[1024]{0} tanh(f32[1024]{0} %res_5) + %copy-done.2 = f32[1024]{0:S(5)} copy-done((f32[1024]{0:S(5)}, f32[1024]{0}, u32[]) %copy-start.2) + %copy-start.3 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done.2) + %res_7 = f32[1024]{0} add(f32[1024]{0} %res_6, f32[1024]{0} %res_6) + %copy-start.1 = (f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) copy-start(f32[1024]{0:S(5)} %copy-done) + %res_8 = f32[1024]{0} add(f32[1024]{0} %res_7, f32[1024]{0} %res_5) + %copy-done.3 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.3) + %res_9 = f32[1024]{0} add(f32[1024]{0} %res_8, f32[1024]{0} %copy-done.3) + %copy-done.1 = f32[1024]{0} copy-done((f32[1024]{0}, f32[1024]{0:S(5)}, u32[]) %copy-start.1) + %res_10 = f32[1024]{0} add(f32[1024]{0} %res_9, f32[1024]{0} %copy-done.1) + ROOT %res_11 = f32[1024]{0} tanh(f32[1024]{0} %res_10) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAnnotator attr_annotator; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); + EXPECT_TRUE(changed); + + for (std::string i : {"", ".1", ".2", ".3"}) { + const HloInstruction* cp_start = + FindInstruction(module.get(), "copy-start" + i); + EXPECT_TRUE(cp_start->has_backend_config()); + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + cp_start->backend_config()); + EXPECT_EQ(gpu_config.operation_queue_id(), 1); + } +} +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/stream_attribute_async_wrapper.cc b/xla/service/gpu/stream_attribute_async_wrapper.cc new file mode 100644 index 0000000000000..822c6473dba48 --- /dev/null +++ b/xla/service/gpu/stream_attribute_async_wrapper.cc @@ -0,0 +1,74 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/stream_attribute_async_wrapper.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +namespace { +static absl::StatusOr AsynchronizeInstruction(HloInstruction* instr) { + auto instr_gpu_config = instr->backend_config(); + if (!instr_gpu_config.ok() || instr_gpu_config->operation_queue_id() == + Thunk::kDefaultExecutionStreamId.value()) { + return false; + } + HloComputation* computation = instr->parent(); + TF_ASSIGN_OR_RETURN( + HloInstruction * done, + computation->CreateAsyncInstructions( + instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread, + /*replace=*/true)); + TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + done->backend_config()); + // Set the false delay of done op to be false so it can be scheduled + // far apart from start. + gpu_config.set_force_earliest_schedule(false); + TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config)); + VLOG(5) << "Created async instruction: " << done->ToString(); + return true; +} +} // namespace + +absl::StatusOr StreamAttributeAsyncWrapper::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + XLA_VLOG_LINES( + 2, "StreamAttributeAsyncWrapper::Run(), before:\n" + module->ToString()); + bool changed = false; + for (const HloComputation* comp : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instr : comp->instructions()) { + TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr)); + changed |= result; + } + } + XLA_VLOG_LINES( + 2, "StreamAttributeAsyncWrapper::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/stream_attribute_async_wrapper.h b/xla/service/gpu/stream_attribute_async_wrapper.h new file mode 100644 index 0000000000000..95fe7bba66508 --- /dev/null +++ b/xla/service/gpu/stream_attribute_async_wrapper.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla::gpu { + +// This pass will find the instructions that +// are annotated with non-default stream id in backend configs +// by the StreamAttributeAnnotator pass +// and wrap them using AsyncStartDone pairs to achieve +// asynchronous executions. +class StreamAttributeAsyncWrapper : public HloModulePass { + public: + inline static constexpr char kParallelExecutionThread[] = "parallel"; + + absl::string_view name() const override { + return "async-stream-attribute-wrapper"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ diff --git a/xla/service/gpu/stream_attribute_async_wrapper_test.cc b/xla/service/gpu/stream_attribute_async_wrapper_test.cc new file mode 100644 index 0000000000000..8b3dcb23eac7b --- /dev/null +++ b/xla/service/gpu/stream_attribute_async_wrapper_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/stream_attribute_async_wrapper.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using StreamAttributeAsyncWrapperTest = HloTestBase; + +TEST_F(StreamAttributeAsyncWrapperTest, NonDefaultOpIsWrapped) { + constexpr absl::string_view kHloString = R"( + HloModule ModuleWithAsync + + ENTRY entry { + p1_32 = f32[1] parameter(0) + p2_32 = f32[1] parameter(1) + add_32 = f32[1] add(p1_32, p2_32), backend_config={"operation_queue_id":"1", "wait_on_operation_queues":[], "force_earliest_schedule":true} + ROOT exp_32 = f32[1] exponential(add_32), backend_config={"operation_queue_id":"0", "wait_on_operation_queues":[1]} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + StreamAttributeAsyncWrapper async_wrapper; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, async_wrapper.Run(module.get())); + EXPECT_TRUE(changed); + const HloInstruction* producer = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_EQ(producer->opcode(), HloOpcode::kAsyncDone); + // Verify that the force_earliest_schedule is set to false for the done op. + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig done_gpu_config, + producer->backend_config()); + EXPECT_EQ(done_gpu_config.force_earliest_schedule(), false); + + const HloInstruction* producer_start = producer->operand(0); + EXPECT_EQ(producer_start->opcode(), HloOpcode::kAsyncStart); + + const xla::HloAsyncInstruction* async = + Cast(producer_start); + EXPECT_EQ(async->async_wrapped_opcode(), HloOpcode::kAdd); + // Verify that the backend config is kept intact + TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, + async->backend_config()); + EXPECT_EQ(gpu_config.operation_queue_id(), 1); + EXPECT_EQ(gpu_config.force_earliest_schedule(), true); + EXPECT_EQ(async->async_execution_thread(), "parallel"); +} +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/stream_executor_util.cc b/xla/service/gpu/stream_executor_util.cc index 803342532b999..4827fc251ce08 100644 --- a/xla/service/gpu/stream_executor_util.cc +++ b/xla/service/gpu/stream_executor_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" +#include #include #include #include @@ -24,21 +25,60 @@ limitations under the License. #include #include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/base/const_init.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/data_type.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/util/env_var.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/util/env_var.h" -#include "tsl/util/proto/proto_utils.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { +se::dnn::VersionInfo GetDnnVersionInfo( + stream_executor::StreamExecutor* stream_exec, + se::dnn::VersionInfo fallback_version) { + if (!stream_exec) { + return fallback_version; + } + stream_executor::dnn::DnnSupport* dnn = stream_exec->AsDnn(); + if (!dnn) { + return fallback_version; + } + return dnn->GetVersion().value_or(fallback_version); +} + namespace { using se::dnn::DataLayout; @@ -64,7 +104,7 @@ int64_t FindMissingDnum(absl::Span vals) { return vals.size(); } -StatusOr DataLayoutToXlaLayout( +absl::StatusOr DataLayoutToXlaLayout( DataLayout data_layout, int64_t batch_dimension, int64_t feature_dimension, absl::Span spatial_dimensions) { std::vector layout; @@ -90,14 +130,14 @@ StatusOr DataLayoutToXlaLayout( layout.push_back(feature_dimension); break; default: - return InternalError("Invalid layout %s", DataLayoutString(data_layout)); + return Internal("Invalid layout %s", DataLayoutString(data_layout)); } return LayoutUtil::MakeLayoutFromMajorToMinor(layout); } } // anonymous namespace -StatusOr> +absl::StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, DataLayout input, FilterLayout filter, DataLayout output) { @@ -137,9 +177,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, filter_layout.push_back(dnums.kernel_input_feature_dimension()); break; default: - return InternalError("Invalid filter layout %s for conv with dnums %s,", - FilterLayoutString(filter), - ConvolutionDimensionNumbersToString(dnums)); + return Internal("Invalid filter layout %s for conv with dnums %s,", + FilterLayoutString(filter), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, @@ -147,7 +187,7 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, output_layout); } -StatusOr> +absl::StatusOr> XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Shape& input, const Shape& filter, const Shape& output) { @@ -188,7 +228,7 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (vect_size == 32) { input_layout = DataLayout::kBatchDepthYX32; } else { - return InternalError( + return Internal( "Invalid input shape %s for conv with dnums %s. Most-minor dim " "should be 4 or 32, but was %d.", ShapeUtil::HumanStringWithLayout(input), @@ -197,7 +237,7 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(input.layout(), nhwc_input)) { input_layout = DataLayout::kBatchYXDepth; } else { - return InternalError( + return Internal( "Invalid input layout %s for conv with dnums %s; expected one of (%s, " "%s, %s)", LayoutUtil::HumanString(input.layout()), @@ -215,7 +255,7 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (vect_size == 32) { filter_layout = FilterLayout::kOutputInputYX32; } else { - return InternalError( + return Internal( "Invalid filter shape %s for conv with dnums %s. Most-minor dim " "should be 4 or 32, but was %d.", ShapeUtil::HumanStringWithLayout(filter), @@ -224,7 +264,7 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(filter.layout(), nhwc_filter)) { filter_layout = FilterLayout::kOutputYXInput; } else { - return InternalError( + return Internal( "Invalid filter layout %s for conv with dnums %s, expected one of (%s, " "%s, %s)", LayoutUtil::HumanString(filter.layout()), @@ -242,7 +282,7 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (vect_size == 32) { output_layout = DataLayout::kBatchDepthYX32; } else { - return InternalError( + return Internal( "Invalid output shape %s for conv with dnums %s. Most-minor dim " "should be 4 or 32, but was %d.", ShapeUtil::HumanStringWithLayout(output), @@ -251,9 +291,9 @@ XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(output.layout(), nhwc_output)) { output_layout = DataLayout::kBatchYXDepth; } else { - return InternalError("Invalid output layout %s for conv with dnums %s", - LayoutUtil::HumanString(output.layout()), - ConvolutionDimensionNumbersToString(dnums)); + return Internal("Invalid output layout %s for conv with dnums %s", + LayoutUtil::HumanString(output.layout()), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, filter_layout, output_layout); @@ -316,7 +356,7 @@ absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec) { return it->second; } -StatusOr> CreateKernel( +absl::StatusOr> CreateKernel( absl::string_view kernel_name, uint64_t num_args, absl::string_view ptx, absl::Span cubin_data, se::StreamExecutor* stream_exec, uint32_t shared_mem_bytes) { @@ -324,31 +364,42 @@ StatusOr> CreateKernel( loader_spec.AddCudaPtxInMemory(ptx, kernel_name); if (!cubin_data.empty()) { - loader_spec.AddCudaCubinInMemory( - reinterpret_cast(cubin_data.data()), kernel_name); + loader_spec.AddCudaCubinInMemory(cubin_data, kernel_name); } - auto kernel_base = std::make_unique(stream_exec); - TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get())); + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + se::Kernel::Create(stream_exec, loader_spec)); + se::KernelMetadata m; m.set_shared_memory_bytes(shared_mem_bytes); - kernel_base->set_metadata(m); - return std::move(kernel_base); + kernel->set_metadata(m); + return kernel; +} + +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + se::Stream* stream) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel_args, + se::PackKernelArgs(args, kernel.metadata())); + + return stream->parent()->Launch(stream, dims.thread_counts_per_block(), + dims.block_counts(), kernel, *kernel_args); } -Status ExecuteKernelOnStream(const se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, se::Stream* stream) { +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + const se::ClusterDim& cluster_dim, + se::Stream* stream) { TF_ASSIGN_OR_RETURN( std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); - LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block(); - LaunchDimensions::Dim3D block_counts = dims.block_counts(); - return stream->parent()->Launch( - stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), - se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, - *kernel_args); + return stream->parent()->Launch(stream, dims.thread_counts_per_block(), + dims.block_counts(), cluster_dim, kernel, + *kernel_args); } // Unimplemented for integers yet. @@ -410,8 +461,8 @@ static void InitializeTypedBuffer(se::Stream* stream, int64_t elements_copied = std::min(host_buffer->size() - host_index, elements_left); se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T)); - stream->ThenMemcpy(&mem, host_buffer->data() + host_index, - elements_copied * sizeof(T)); + TF_CHECK_OK(stream->Memcpy(&mem, host_buffer->data() + host_index, + elements_copied * sizeof(T))); current_addr += elements_copied * sizeof(T); elements_left -= elements_copied; host_index += elements_copied; @@ -446,7 +497,7 @@ void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, buffer_type); } -StatusOr GetDNNConvKindFromCudnnConvKind( +absl::StatusOr GetDNNConvKindFromCudnnConvKind( CudnnConvKind kind) { switch (kind) { case CudnnConvKind::kBackwardFilter: @@ -462,10 +513,24 @@ StatusOr GetDNNConvKindFromCudnnConvKind( default: break; } - return InternalError("Unexpected convolution kind"); + return Internal("Unexpected convolution kind"); } -StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( +absl::StatusOr GetDNNNormKindFromCudnnNormKind( + CudnnNormKind kind) { + switch (kind) { + case CudnnNormKind::kLayerForwardInfer: + return se::dnn::LAYER_FWD_INFER; + case CudnnNormKind::kLayerForwardTrain: + return se::dnn::LAYER_FWD_TRAIN; + case CudnnNormKind::kLayerBackward: + return se::dnn::LAYER_BWD; + default: + return Internal("Unexpected norm kind"); + } +} + +absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind) { switch (kind) { case CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: @@ -491,10 +556,10 @@ StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( case CudnnfMHAKind::kBackwardSoftmax: return se::dnn::FusedMHAKind::BMM1_OUTPUT_INPUT_TYPE; } - return InternalError("Unexpected fMHA kind"); + return Internal("Unexpected fMHA kind"); } -StatusOr GetDNNDataTypeFromPrimitiveType( +absl::StatusOr GetDNNDataTypeFromPrimitiveType( PrimitiveType type) { switch (type) { case F16: @@ -516,7 +581,7 @@ StatusOr GetDNNDataTypeFromPrimitiveType( default: break; } - return InternalError("Unsupported convolution datatype"); + return Internal("Unsupported datatype"); } bool RequireDeterminism(const HloModuleConfig& config) { @@ -532,49 +597,114 @@ bool RequireDeterminism(const HloModuleConfig& config) { config.debug_options().xla_gpu_deterministic_ops(); } -StatusOr PickBestResult( +namespace { +std::vector KeepNonFailures( + absl::Span profile_results) { + // Filter out all failures except WRONG_RESULT, because false-positives are + // possible (e.g. perhaps the reference algorithm is the one that's + // incorrect!). Other failures can be detected with high accuracy. E.g. + // REDZONE_MODIFIED which is also quite severe. + std::vector filtered_results; + absl::c_copy_if(profile_results, std::back_inserter(filtered_results), + [](const AutotuneResult& r) { + return !r.has_failure() || + r.failure().kind() == AutotuneResult::WRONG_RESULT; + }); + return filtered_results; +} + +absl::Status AllAlgorithmsFailedInternalError( + std::optional instr_str, + absl::Span profile_results) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "All algorithms tried for " << instr_str.value() + << " failed. Falling back to default algorithm. Per-algorithm " + "errors:"; + } else { + msg << "All algorithms failed. Falling back to the default algorithm. " + << "Per-algorithm errors:"; + } + for (const auto& result : profile_results) { + msg << "\n " << result.failure().msg(); + } + return Internal("%s", msg.str()); +} + +absl::Status NoAlgorithmSuppliedInternalError( + std::optional instr_str) { + std::ostringstream msg; + if (instr_str.has_value()) { + msg << "There are no algorithm candiates for computing: \n " + << instr_str.value() + << "\nThis likely means that the instruction shape is not supported by " + "the target GPU library."; + } else { + msg << "There are no algorithm candiates for computing the instruction.\n" + "This likely means that the instruction shape is not supported by " + "the target GPU library."; + } + return Internal("%s", msg.str()); +} + +void SortAutotuningResultsByRunTime(std::vector& results) { + absl::c_sort(results, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return tsl::proto_utils::FromDurationProto(lhs.run_time()) < + tsl::proto_utils::FromDurationProto(rhs.run_time()); + }); +} + +absl::Span TopResultsWithinMeasurementError( + std::vector& results_sorted_by_runtime) { + // This value was picked by repeatedly running a few kernels that run for a + // short time and observing the run-time variance. A more rigorous analysis + // of the measurement error might yield a better error threshold. + constexpr absl::Duration kMeasurementError = absl::Microseconds(2); + + absl::Duration min_time = tsl::proto_utils::FromDurationProto( + results_sorted_by_runtime.front().run_time()); + absl::Duration limit_time = min_time + kMeasurementError; + + auto limit_time_it = absl::c_find_if( + results_sorted_by_runtime, [limit_time](const AutotuneResult& x) { + return tsl::proto_utils::FromDurationProto(x.run_time()) > limit_time; + }); + return absl::MakeSpan(&*results_sorted_by_runtime.begin(), &*limit_time_it); +} +} // namespace + +absl::StatusOr PickBestResult( absl::Span profile_results, std::optional instr_str, HloModuleConfig hlo_module_config) { - std::vector filtered_results; + if (profile_results.empty()) { + return NoAlgorithmSuppliedInternalError(instr_str); + } - // For now, we ignore WRONG_RESULT failures because false-positives are - // possible (e.g. perhaps the reference algorithm is the one that's - // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're - // quite severe and can be detected with high accuracy. - absl::c_copy_if( - profile_results, std::back_inserter(filtered_results), - [](const AutotuneResult& r) { - return !(r.has_failure() && - r.failure().kind() != AutotuneResult::WRONG_RESULT); - }); + std::vector filtered_results = + KeepNonFailures(profile_results); if (filtered_results.empty()) { - std::ostringstream msg; - if (instr_str.has_value()) { - msg << "All algorithms tried for " << instr_str.value() - << " failed. Falling back to default algorithm. Per-algorithm " - "errors:"; - } else { - msg << "All algorithms failed. Falling back to the default algorithm. " - << "Per-algorithm errors:"; - } - for (const auto& result : profile_results) { - msg << "\n " << result.failure().msg(); - } - return InternalError("%s", msg.str()); + return AllAlgorithmsFailedInternalError(instr_str, profile_results); } - auto selected_result = filtered_results.begin(); - if (!RequireDeterminism(hlo_module_config)) { - selected_result = absl::c_min_element( - filtered_results, - [](const AutotuneResult& lhs, const AutotuneResult& rhs) { - return tsl::proto_utils::FromDurationProto(lhs.run_time()) < - tsl::proto_utils::FromDurationProto(rhs.run_time()); - }); + if (RequireDeterminism(hlo_module_config)) { + // If determinism is required (usually for debugging purposes) then always + // pick the first algorithm, instead of searching for the best, which can + // be noisy. + return *filtered_results.begin(); } - return *selected_result; + + // Kernel run-time measurements within kMeasurementError are not precise. + // Consider the lowest measurements within the error margin as equivalent and + // within them prefer algorithms that use the least amount of scratch memory. + SortAutotuningResultsByRunTime(filtered_results); + auto top_within_error = TopResultsWithinMeasurementError(filtered_results); + return *absl::c_min_element(top_within_error, [](const AutotuneResult& lhs, + const AutotuneResult& rhs) { + return lhs.scratch_bytes() < rhs.scratch_bytes(); + }); } } // namespace gpu diff --git a/xla/service/gpu/stream_executor_util.h b/xla/service/gpu/stream_executor_util.h index 52f39ac32cabe..b52062ee314f5 100644 --- a/xla/service/gpu/stream_executor_util.h +++ b/xla/service/gpu/stream_executor_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,27 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include +#include +#include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" #include "xla/layout.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/hlo_module_config.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" // Helper functions for interacting with StreamExecutor. @@ -36,16 +44,22 @@ limitations under the License. namespace xla { namespace gpu { +// Returns DNN version info from provided stream executor when possible, +// fallback version otherwise. +se::dnn::VersionInfo GetDnnVersionInfo( + stream_executor::StreamExecutor* stream_exec, + se::dnn::VersionInfo fallback_version = se::dnn::VersionInfo{0, 0, 0}); + // Returns (input, filter, output) XLA Layout protos given the StreamExecutor // layouts. -StatusOr> +absl::StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, se::dnn::DataLayout input, se::dnn::FilterLayout filter, se::dnn::DataLayout output); // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. -StatusOr< +absl::StatusOr< std::tuple> XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Shape& input, const Shape& filter, @@ -80,15 +94,23 @@ absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec); // // The canonical storage for both ptx and cubin_data should outlive // the lifetime of the kernel. -StatusOr> CreateKernel( +absl::StatusOr> CreateKernel( absl::string_view kernel_name, uint64_t num_args, absl::string_view ptx, absl::Span cubin_data, se::StreamExecutor* stream_exec, uint32_t shared_mem_bytes = 0); // Runs loaded kernel on the stream with the provided arguments. -Status ExecuteKernelOnStream(const se::Kernel& kernel, - absl::Span args, - const LaunchDimensions& dims, se::Stream* stream); +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + se::Stream* stream); + +// Runs loaded kernel on the stream with the provided arguments. +absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, + absl::Span args, + const LaunchDimensions& dims, + const se::ClusterDim& cluster_dim, + se::Stream* stream); // Initializes `buffer` with random data on `stream`. // `rng_state` is an inout parameter for the pseudorandom generator state. @@ -99,17 +121,21 @@ Status ExecuteKernelOnStream(const se::Kernel& kernel, void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, int64_t* rng_state, se::DeviceMemoryBase buffer); -StatusOr GetDNNConvKindFromCudnnConvKind( +absl::StatusOr GetDNNConvKindFromCudnnConvKind( CudnnConvKind kind); -StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( +absl::StatusOr GetDNNNormKindFromCudnnNormKind( + CudnnNormKind kind); + +absl::StatusOr GetDNNFusedMHAKindFromCudnnfMHAKind( CudnnfMHAKind kind); -StatusOr GetDNNDataTypeFromPrimitiveType(PrimitiveType type); +absl::StatusOr GetDNNDataTypeFromPrimitiveType( + PrimitiveType type); // Returns result with the smallest time which has not failed. // If deterministic output is requested, returns first (not failing) result. -StatusOr PickBestResult( +absl::StatusOr PickBestResult( absl::Span profile_results, std::optional instr_str, HloModuleConfig hlo_module_config); diff --git a/xla/service/gpu/stream_executor_util_test.cc b/xla/service/gpu/stream_executor_util_test.cc new file mode 100644 index 0000000000000..cb3be24a6ceaa --- /dev/null +++ b/xla/service/gpu/stream_executor_util_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/stream_executor_util.h" + +#include +#include + +#include +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "xla/autotuning.pb.h" +#include "xla/service/hlo_module_config.h" +#include "xla/tsl/util/proto/proto_utils.h" + +namespace xla::gpu { +namespace { + +struct Result { + int64_t run_time_ns; + int64_t scratch_bytes; + + bool operator==(const Result& other) const { + return other.run_time_ns == run_time_ns && + other.scratch_bytes == scratch_bytes; + }; + + explicit operator AutotuneResult() const { + AutotuneResult result; + *result.mutable_run_time() = + tsl::proto_utils::ToDurationProto(absl::Nanoseconds(run_time_ns)); + result.set_scratch_bytes(scratch_bytes); + return result; + } +}; + +static Result ATRToResult(AutotuneResult atr) { + return Result{.run_time_ns = absl::ToInt64Nanoseconds( + tsl::proto_utils::FromDurationProto(atr.run_time())), + .scratch_bytes = atr.scratch_bytes()}; +} + +std::vector Results(const std::vector& stats) { + std::vector results; + for (const auto& s : stats) results.push_back(AutotuneResult(s)); + return results; +} + +TEST(StreamExecutorTest, PickBestResult) { + absl::StatusOr atr; + + atr = PickBestResult(Results({{5000, 0}, {1000, 0}, {6000, 0}}), "", {}); + EXPECT_EQ(ATRToResult(atr.value()), Result({1000, 0})); + + atr = PickBestResult(Results({{4700, 0}, {4600, 0}, {4500, 0}}), "", {}); + EXPECT_EQ(ATRToResult(atr.value()), Result({4500, 0})); + + atr = PickBestResult(Results({{4700, 0}, {4600, 2}, {4500, 1}}), "", {}); + EXPECT_EQ(ATRToResult(atr.value()), Result({4700, 0})); + + atr = PickBestResult(Results({{5000, 1}, {6000, 0}, {7500, 0}}), "", {}); + EXPECT_EQ(ATRToResult(atr.value()), Result({6000, 0})); +} + +} // namespace + +} // namespace xla::gpu diff --git a/xla/service/gpu/target_constants.h b/xla/service/gpu/target_constants.h index 92f31a6c13506..13190ae690c13 100644 --- a/xla/service/gpu/target_constants.h +++ b/xla/service/gpu/target_constants.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -54,6 +54,24 @@ inline const char* DataLayout() { } // namespace amdgpu +namespace spir { +// The triple that represents our target on SPIR backend. +inline const char* TargetTriple() { + static constexpr char kTargetTriple[] = "spir64-unknown-unknown"; + return kTargetTriple; +} + +// The data layout of the emitted module. +inline const char* DataLayout() { + static constexpr char kDataLayout[] = + "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:" + "32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:" + "128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:" + "1024"; + return kDataLayout; +} +} // namespace spir + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/target_util.cc b/xla/service/gpu/target_util.cc index 12b12b3f19fae..3209c69f6cf51 100644 --- a/xla/service/gpu/target_util.cc +++ b/xla/service/gpu/target_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,17 +17,36 @@ limitations under the License. #include "xla/service/gpu/target_util.h" +#include #include +#include +#include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/FPEnv.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status.h" +#include "xla/util.h" #include "tsl/platform/logging.h" namespace xla { @@ -45,6 +64,9 @@ struct TargetIntrinsics { std::variant*)>> amdgpu_intrinsic_or_function; + std::variant*)>> + spir_intrinsic_or_function; }; // Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU) @@ -52,32 +74,82 @@ struct TargetIntrinsics { struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { switch (intrin) { case TargetIntrinsicID::kThreadIdx: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - llvm::Intrinsic::amdgcn_workitem_id_x}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, + llvm::Intrinsic::amdgcn_workitem_id_x, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z32__spirv_BuiltInLocalInvocationIdi", {b_->getInt32(0)}, + {U32}, U64, {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kThreadIdy: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y, - llvm::Intrinsic::amdgcn_workitem_id_y}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y, + llvm::Intrinsic::amdgcn_workitem_id_y, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z32__spirv_BuiltInLocalInvocationIdi", {b_->getInt32(1)}, + {U32}, U64, {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kThreadIdz: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z, - llvm::Intrinsic::amdgcn_workitem_id_z}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z, + llvm::Intrinsic::amdgcn_workitem_id_z, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z32__spirv_BuiltInLocalInvocationIdi", {b_->getInt32(2)}, + {U32}, U64, {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kBlockIdx: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - llvm::Intrinsic::amdgcn_workgroup_id_x}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, + llvm::Intrinsic::amdgcn_workgroup_id_x, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("_Z26__spirv_BuiltInWorkgroupIdi", + {b_->getInt32(0)}, {U32}, U64, + {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kBlockIdy: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y, - llvm::Intrinsic::amdgcn_workgroup_id_y}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y, + llvm::Intrinsic::amdgcn_workgroup_id_y, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("_Z26__spirv_BuiltInWorkgroupIdi", + {b_->getInt32(1)}, {U32}, U64, + {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kBlockIdz: { - return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z, - llvm::Intrinsic::amdgcn_workgroup_id_z}; + return { + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z, + llvm::Intrinsic::amdgcn_workgroup_id_z, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("_Z26__spirv_BuiltInWorkgroupIdi", + {b_->getInt32(2)}, {U32}, U64, + {b_->getContext()}, b_); + }, + }; } case TargetIntrinsicID::kBarrierId: { - return {llvm::Intrinsic::nvvm_barrier0, - llvm::Intrinsic::amdgcn_s_barrier}; + return {llvm::Intrinsic::nvvm_barrier0, llvm::Intrinsic::amdgcn_s_barrier, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z22__spirv_ControlBarrierjjj", + {b_->getInt32(2), b_->getInt32(2), b_->getInt32(272)}, + {U32, U32, U32}, U32, + llvm::AttrBuilder(b_->getContext()) + .addAttribute(llvm::Attribute::Convergent), + b_); + }}; } case TargetIntrinsicID::kBlockDimx: { return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, @@ -85,6 +157,11 @@ struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { return EmitDeviceFunctionCall("__ockl_get_local_size", {b_->getInt32(0)}, {U32}, U64, {b_->getContext()}, b_); + }, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z28__spirv_BuiltInWorkgroupSizei", {b_->getInt32(0)}, + {U32}, U64, {b_->getContext()}, b_); }}; } case TargetIntrinsicID::kBlockDimy: { @@ -93,6 +170,11 @@ struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { return EmitDeviceFunctionCall("__ockl_get_local_size", {b_->getInt32(1)}, {U32}, U64, {b_->getContext()}, b_); + }, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z28__spirv_BuiltInWorkgroupSizei", {b_->getInt32(1)}, + {U32}, U64, {b_->getContext()}, b_); }}; } case TargetIntrinsicID::kBlockDimz: { @@ -101,11 +183,25 @@ struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { return EmitDeviceFunctionCall("__ockl_get_local_size", {b_->getInt32(2)}, {U32}, U64, {b_->getContext()}, b_); + }, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z28__spirv_BuiltInWorkgroupSizei", {b_->getInt32(2)}, + {U32}, U64, {b_->getContext()}, b_); }}; } case TargetIntrinsicID::kGroupBarrierId: { return {llvm::Intrinsic::nvvm_bar_warp_sync, - llvm::Intrinsic::amdgcn_wave_barrier}; + llvm::Intrinsic::amdgcn_wave_barrier, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall( + "_Z22__spirv_ControlBarrierjjj", + {b_->getInt32(2), b_->getInt32(2), b_->getInt32(272)}, + {U32, U32, U32}, U32, + llvm::AttrBuilder(b_->getContext()) + .addAttribute(llvm::Attribute::Convergent), + b_); + }}; } } } @@ -114,6 +210,7 @@ struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { struct TargetDeviceFunction { const std::string nvptx_root; const std::string amdgpu_root; + const std::string spir_root; }; // Gets the device function name on different platforms (NVPTX, AMDGPU) @@ -122,55 +219,58 @@ struct TargetDeviceFunction GetDeviceFunctionRoot( TargetDeviceFunctionID func_id) { switch (func_id) { case TargetDeviceFunctionID::kAtan2: { - return {"__nv_atan2", "__ocml_atan2"}; + return {"__nv_atan2", "__ocml_atan2", "_Z17__spirv_ocl_atan2"}; } case TargetDeviceFunctionID::kCos: { - return {"__nv_cos", "__ocml_cos"}; + return {"__nv_cos", "__ocml_cos", "_Z15__spirv_ocl_cos"}; + } + case TargetDeviceFunctionID::kErf: { + return {"__nv_erf", "__ocml_erf", "_Z15__spirv_ocl_erf"}; } case TargetDeviceFunctionID::kExp: { - return {"__nv_exp", "__ocml_exp"}; + return {"__nv_exp", "__ocml_exp", "_Z15__spirv_ocl_exp"}; } case TargetDeviceFunctionID::kExpm1: { - return {"__nv_expm1", "__ocml_expm1"}; + return {"__nv_expm1", "__ocml_expm1", "_Z17__spirv_ocl_expm1"}; } case TargetDeviceFunctionID::kFmod: { - return {"__nv_fmod", "__ocml_fmod"}; + return {"__nv_fmod", "__ocml_fmod", "_Z16__spirv_ocl_fmod"}; } case TargetDeviceFunctionID::kHypot: { - return {"__nv_hypot", "__ocml_hypot"}; + return {"__nv_hypot", "__ocml_hypot", "_Z17__spirv_ocl_hypot"}; } case TargetDeviceFunctionID::kLog: { - return {"__nv_log", "__ocml_log"}; + return {"__nv_log", "__ocml_log", "_Z15__spirv_ocl_log"}; } case TargetDeviceFunctionID::kLog1p: { - return {"__nv_log1p", "__ocml_log1p"}; + return {"__nv_log1p", "__ocml_log1p", "_Z17__spirv_ocl_log1p"}; } case TargetDeviceFunctionID::kPow: { - return {"__nv_pow", "__ocml_pow"}; + return {"__nv_pow", "__ocml_pow", "_Z15__spirv_ocl_pow"}; } case TargetDeviceFunctionID::kRsqrt: { - return {"__nv_rsqrt", "__ocml_rsqrt"}; + return {"__nv_rsqrt", "__ocml_rsqrt", "_Z17__spirv_ocl_rsqrt"}; } case TargetDeviceFunctionID::kSin: { - return {"__nv_sin", "__ocml_sin"}; + return {"__nv_sin", "__ocml_sin", "_Z15__spirv_ocl_sin"}; } case TargetDeviceFunctionID::kSqrt: { - return {"__nv_sqrt", "__ocml_sqrt"}; + return {"__nv_sqrt", "__ocml_sqrt", "_Z16__spirv_ocl_sqrt"}; } case TargetDeviceFunctionID::kTan: { - return {"__nv_tan", "__ocml_tan"}; + return {"__nv_tan", "__ocml_tan", "_Z15__spirv_ocl_tan"}; } case TargetDeviceFunctionID::kTanh: { - return {"__nv_tanh", "__ocml_tanh"}; + return {"__nv_tanh", "__ocml_tanh", "_Z16__spirv_ocl_tanh"}; } case TargetDeviceFunctionID::kCbrt: { - return {"__nv_cbrt", "__ocml_cbrt"}; + return {"__nv_cbrt", "__ocml_cbrt", "_Z16__spirv_ocl_cbrt"}; } } } } // namespace -StatusOr GetTargetDeviceFunctionID(HloOpcode op) { +absl::StatusOr GetTargetDeviceFunctionID(HloOpcode op) { switch (op) { case HloOpcode::kAtan2: return TargetDeviceFunctionID::kAtan2; @@ -178,6 +278,8 @@ StatusOr GetTargetDeviceFunctionID(HloOpcode op) { return TargetDeviceFunctionID::kCos; case HloOpcode::kExp: return TargetDeviceFunctionID::kExp; + case HloOpcode::kErf: + return TargetDeviceFunctionID::kErf; case HloOpcode::kExpm1: return TargetDeviceFunctionID::kExpm1; case HloOpcode::kLog: @@ -231,6 +333,28 @@ std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id, } else { LOG(FATAL) << "Unexpected type while getting device function name."; } + } else if (target_triple.isSPIR()) { + if (output_type == F32) { + if (gpu_root_names.spir_root == "_Z17__spirv_ocl_hypot" || + gpu_root_names.spir_root == "_Z15__spirv_ocl_pow" || + gpu_root_names.spir_root == "_Z17__spirv_ocl_atan2" || + gpu_root_names.spir_root == "_Z16__spirv_ocl_fmod") { + return StrCat(gpu_root_names.spir_root, "ff"); + } else { + return StrCat(gpu_root_names.spir_root, "f"); + } + } else if (output_type == F64) { + if (gpu_root_names.spir_root == "_Z17__spirv_ocl_hypot" || + gpu_root_names.spir_root == "_Z15__spirv_ocl_pow" || + gpu_root_names.spir_root == "_Z17__spirv_ocl_atan2" || + gpu_root_names.spir_root == "_Z16__spirv_ocl_fmod") { + return StrCat(gpu_root_names.spir_root, "dd"); + } else { + return StrCat(gpu_root_names.spir_root, "d"); + } + } else { + LOG(FATAL) << "Unexpected type while getting device function name."; + } } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } @@ -243,6 +367,7 @@ llvm::CallInst* EmitDeviceFunctionCall( absl::string_view name) { std::vector ir_input_types; llvm::Module* module = b->GetInsertBlock()->getModule(); + llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); for (PrimitiveType input_type : input_types) { ir_input_types.push_back( llvm_ir::PrimitiveTypeToIrType(input_type, module)); @@ -260,6 +385,8 @@ llvm::CallInst* EmitDeviceFunctionCall( .getCallee()); callee->addFnAttrs(attributes); + if (target_triple.isSPIR()) + callee->setCallingConv(llvm::CallingConv::SPIR_FUNC); return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data()); } @@ -285,6 +412,18 @@ llvm::CallInst* EmitCallToTargetIntrinsic( &gpu_intrinsic_id.amdgpu_intrinsic_or_function); return (*builder_func)(b); } + } else if (target_triple.isSPIR()) { + llvm::Intrinsic::ID* llvm_intrinsic_id_ptr = + std::get_if( + &gpu_intrinsic_id.spir_intrinsic_or_function); + if (llvm_intrinsic_id_ptr) { + llvm_intrinsic_id = *llvm_intrinsic_id_ptr; + } else { + std::function*)>* builder_func = + std::get_if*)>>( + &gpu_intrinsic_id.spir_intrinsic_or_function); + return (*builder_func)(b); + } } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } @@ -312,6 +451,9 @@ void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func, // Attach information so AMDGPU can recognize function as a AMDGPU kernel. func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024"); + } else if (target_triple.isSPIR()) { + // Attach information so that it can be recognized as a SPIR kernel. + func->setCallingConv(llvm::CallingConv::SPIR_KERNEL); } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } diff --git a/xla/service/gpu/target_util.h b/xla/service/gpu/target_util.h index 19b00307b3fea..cb321fb0ba990 100644 --- a/xla/service/gpu/target_util.h +++ b/xla/service/gpu/target_util.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,15 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" #include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/xla_data.pb.h" @@ -62,10 +66,11 @@ enum class TargetDeviceFunctionID { kSqrt, kTan, kTanh, + kErf, }; // HLO opcode -> TargetDeviceFunctionID mapping. -StatusOr GetTargetDeviceFunctionID(HloOpcode); +absl::StatusOr GetTargetDeviceFunctionID(HloOpcode); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. diff --git a/xla/service/gpu/target_util_test.cc b/xla/service/gpu/target_util_test.cc index 11d4e8232f798..751efdecf23b5 100644 --- a/xla/service/gpu/target_util_test.cc +++ b/xla/service/gpu/target_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ limitations under the License. #include "xla/service/gpu/target_util.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Verifier.h" +#include "llvm/Support/raw_ostream.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index d437f1b8a4564..b1cdb53cc34f5 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -1,14 +1,6 @@ # Description: GPU-specific XLA tests. For example, codegen tests that # verify the IR emitted. -load("//xla/tests:build_defs.bzl", "xla_test") -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") -load( - "//xla:xla.bzl", - "xla_cc_binary", - "xla_cc_test", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", @@ -17,11 +9,18 @@ load("@tsl//tsl:tsl.default.bzl", "filegroup") load( "@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", + "tf_gpu_tests_tags", ) load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -54,17 +53,16 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:shape_util", - "//xla:types", + "//xla/service:executable", "//xla/service:gpu_plugin", + "//xla/service:hlo_module_config", "//xla/service/gpu:gpu_executable", - "//xla/stream_executor", + "//xla/stream_executor:platform_manager", "//xla/tests:filecheck", "//xla/tests:llvm_irgen_test_base", "//xla/tests:verified_hlo_module", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", ], ) @@ -73,7 +71,6 @@ xla_cc_test( srcs = ["element_wise_row_vectorization_test.cc"], tags = tf_cuda_tests_tags(), deps = [ - ":gpu_codegen_test", "//xla:error_spec", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -87,6 +84,21 @@ xla_cc_test( deps = [ ":gpu_codegen_test", "//xla:literal_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) + +xla_test( + name = "float_conversions_test", + srcs = ["float_conversions_test.cc"], + backends = ["gpu"], + deps = [ + ":gpu_codegen_test", + "//xla:error_spec", + "//xla/tests:test_utils", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) @@ -96,14 +108,17 @@ xla_cc_test( srcs = ["gpu_reduce_scatter_creator_test.cc"], deps = [ "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:gpu_reduce_scatter_creator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/platform:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", ], ) @@ -118,7 +133,9 @@ xla_cc_test( "//xla/service/gpu:gpu_all_gather_optimizer", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -130,34 +147,50 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:debug_options_flags", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:executable", "//xla/service:hlo_module_config", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) -xla_cc_test( +xla_test( name = "gemm_rewrite_test", - srcs = if_cuda_is_configured(["gemm_rewrite_test.cc"]) + if_rocm_is_configured(["gemm_rewrite_test.cc"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - tags = tf_cuda_tests_tags(), + srcs = ["gemm_rewrite_test.cc"], + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":gpu_codegen_test", + "//xla:error_spec", "//xla:statusor", "//xla:test", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:gpu_plugin", + "//xla/service:buffer_assignment", + "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:gemm_rewriter", "//xla/service/gpu:gpu_executable", + "//xla/service/gpu:variant_visitor", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -168,18 +201,19 @@ xla_cc_test( xla_cc_test( name = "gemm_broadcast_folding_rewrite_test", - srcs = [ - "gemm_broadcast_folding_rewrite_test.cc", - ], - tags = tf_cuda_tests_tags() + [ - "no_rocm", - ], + srcs = ["gemm_broadcast_folding_rewrite_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), deps = [ ":gpu_codegen_test", "//xla:error_spec", + "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/service/gpu:gemm_broadcast_folding_rewriter", "//xla/service/gpu:gemm_rewriter", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -193,9 +227,10 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla:statusor", - "//xla/tests:hlo_test_base", - "@tsl//tsl/lib/core:status_test_util", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -207,7 +242,6 @@ xla_cc_test( "reduction_degenerate_dim_remover_test.cc", ], deps = [ - "//xla:debug_options_flags", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", @@ -215,10 +249,8 @@ xla_cc_test( "//xla/service/gpu:reduction_degenerate_dim_remover", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -231,19 +263,15 @@ xla_cc_test( ], tags = tf_cuda_tests_tags(), deps = [ - "//xla:debug_options_flags", - "//xla:statusor", + "//xla:error_spec", "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", - "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service/gpu:reduction_layout_normalizer", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -255,19 +283,16 @@ xla_cc_test( "tree_reduction_rewriter_test.cc", ], deps = [ - "//xla:debug_options_flags", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service/gpu:tree_reduction_rewriter", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -281,19 +306,8 @@ xla_cc_test( tags = ["no_rocm"] + tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla:debug_options_flags", - "//xla:statusor", - "//xla/hlo/ir:hlo", + "//xla:error_spec", "//xla/service:gpu_plugin", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service/gpu:gemm_rewriter", - "//xla/service/gpu:gpu_executable", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -307,21 +321,19 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla:debug_options_flags", + "//xla:error_spec", "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/service:hlo_module_config", "//xla/service:hlo_parser", - "//xla/service/gpu:gemm_rewriter", "//xla/service/gpu:gpu_executable", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -333,18 +345,13 @@ xla_cc_test( "reduction_dimension_grouper_test.cc", ], deps = [ - "//xla:debug_options_flags", - "//xla:statusor", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service/gpu:reduction_dimension_grouper", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -358,16 +365,16 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:error_spec", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", - "//xla/service:hlo_module_config", "//xla/service:hlo_parser", - "//xla/service/gpu:gpu_executable", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:llvm_irgen_test_base", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", + "//xla/tests:verified_hlo_module", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -380,8 +387,11 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:error_spec", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", + "//xla/tests:verified_hlo_module", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -392,12 +402,12 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/memory", + "//xla/tests:verified_hlo_module", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -411,8 +421,11 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:error_spec", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", + "//xla/tests:verified_hlo_module", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) @@ -425,7 +438,31 @@ xla_cc_test( ":gpu_codegen_test", "//xla:shape_util", "//xla/hlo/ir:hlo", - "@com_google_googletest//:gtest", + "@tsl//tsl/platform:test_main", + ], +) + +xla_test( + name = "gpu_triton_custom_call_test", + srcs = ["gpu_triton_custom_call_test.cc"], + backends = [ + "gpu_a100", + "gpu_v100", + ], + deps = [ + ":gpu_codegen_test", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:test_main", ], ) @@ -436,6 +473,8 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/tests:verified_hlo_module", "@tsl//tsl/platform:test_main", ], @@ -447,16 +486,15 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:comparison_util", "//xla:literal", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/memory", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -467,11 +505,13 @@ xla_cc_test( srcs = ["infeed_test.cc"], tags = tf_cuda_tests_tags(), deps = [ - ":gpu_codegen_test", + ":gpu_codegen_test", # build_cleaner: keep + "//xla:array3d", + "//xla:array4d", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test_helpers", - "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/client:xla_builder", @@ -479,7 +519,6 @@ xla_cc_test( "//xla/tests:client_library_test_base", "//xla/tests:literal_test_util", "@tsl//tsl/platform:env", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test_main", ], ) @@ -497,27 +536,38 @@ xla_test( ], deps = [ ":gpu_codegen_test", - "//xla/hlo/ir:hlo", + "//xla:error_spec", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) +xla_cc_test( + name = "concatenate_emitter_test", + srcs = ["concatenate_emitter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//xla:error_spec", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "transpose_emitter_test", srcs = ["transpose_emitter_test.cc"], tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla:error_spec", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -529,11 +579,8 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla:error_spec", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -546,11 +593,10 @@ xla_cc_test( deps = [ ":gpu_codegen_test", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/memory", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -565,10 +611,8 @@ xla_cc_test( ":gpu_codegen_test", "//xla:literal", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/memory", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -580,9 +624,12 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:instruction_fusion", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:test_main", ], ) @@ -593,14 +640,15 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:shape_util", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass_pipeline", "//xla/service/gpu:fusion_merger", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:instruction_fusion", "//xla/service/gpu:multi_output_fusion", "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "@tsl//tsl/platform:test", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:test_main", ], ) @@ -611,8 +659,8 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:debug_options_flags", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -630,7 +678,6 @@ xla_cc_test( "//xla/service:gpu_plugin", "//xla/service/llvm_ir:alias_analysis", "//xla/tests:filecheck", - "//xla/tests:llvm_irgen_test_base", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -654,7 +701,7 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//xla/hlo/ir:hlo", + "//xla:error_spec", "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", @@ -677,8 +724,11 @@ xla_test( "notap", ], deps = [ + "//xla:debug_options_flags", + "//xla/service:hlo_module_config", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", ], ) @@ -711,99 +761,86 @@ xla_cc_test( "//xla:xla_data_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) -build_test( - name = "hlo_to_llvm_ir_build_test", - targets = [ - ":hlo_to_llvm_ir", - ], -) - -xla_cc_binary( - name = "hlo_to_llvm_ir", - testonly = True, - srcs = ["hlo_to_llvm_ir.cc"], - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "-DTENSORFLOW_USE_ROCM=1", - ]), - deps = [ - "//xla:status", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:buffer_sharing", - "//xla/service/gpu:compile_module_to_llvm_ir", - "//xla/service/gpu:gpu_compiler", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:gpu_hlo_schedule", - "//xla/service/gpu:target_constants", - "//xla/service/gpu/llvm_gpu_backend", - "//xla/stream_executor", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/stream_executor/host:host_platform", - "//xla/tests:test_utils", - "//xla/tools:hlo_module_loader", - "@llvm-project//llvm:Target", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:platform_port", - "@tsl//tsl/util:command_line_flags", - ] + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/service/gpu:nvptx_compiler_impl", - "//xla/stream_executor/cuda:cublas_plugin", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_platform_id", - "//xla/stream_executor/rocm:rocblas_plugin", - "//xla/stream_executor/rocm:rocm_helpers", - "//xla/service/gpu:amdgpu_compiler_impl", - "@tsl//tsl/platform:rocm_rocdl_path", - ]), -) - -glob_lit_tests( +lit_test_suite( name = "all_tests", - data = [":test_utilities"], - default_tags = tf_cuda_tests_tags() + [ - ], - driver = "//xla:run_lit.sh", - exclude = ["execute_memzero_thunk.mlir"], - features = if_cuda_is_configured([ + srcs = enforce_glob( + [ + "add_preds.hlo", + "calling_convention.hlo", + "copy.hlo", + "dynamic_update_slice_inplace.hlo", + "element_wise_row_vectorization.hlo", + "fused_scatter.hlo", + "fused_slice.hlo", + "kernel_reuse.hlo", + "launch_dimensions.hlo", + "pad_to_static.hlo", + "reduce_atomic_min.hlo", + "reduce_column_layout_change.hlo", + "reduce_f64_column.hlo", + "reduce_large_row_to_scalar.hlo", + "reduce_row_vectorized.hlo", + "reduce_unnested.hlo", + "reduce_variadic_column.hlo", + "reduction_vectorization_sm_all.hlo", + "rng_get_and_update_state.hlo", + "scatter.hlo", + "select_and_scatter.hlo", + "single_instruction.hlo", + "slice_to_dynamic.hlo", + "sorting.hlo", + "transpose_021.hlo", + "transpose_021_extra_output.hlo", + "transpose_210.hlo", + "transpose_210_extra_output.hlo", + "triton_naming.hlo", + ], + include = [ + "*.hlo", + ], + ), + args = if_cuda_is_configured([ "--param=PTX=PTX", - "--param=SUBST_TIDX=@llvm.nvvm.read.ptx.sreg.tid.x", - "--param=SUBST_CTAIDX=@llvm.nvvm.read.ptx.sreg.ctaid.x", - "--param=SUBST_BARRIER=@llvm.nvvm.barrier0", - "--param=SUBST_KERNEL_ANNOTATION=''", - "--param=SUBST_ADDRSPACE_ANNOTATION=''", + "--param=GPU=a6000", ]) + if_rocm_is_configured([ "--param=PTX=GCN", - "--param=SUBST_TIDX=@llvm.amdgcn.workitem.id.x", - "--param=SUBST_CTAIDX=@llvm.amdgcn.workgroup.id.x", - "--param=SUBST_BARRIER=@llvm.amdgcn.s.barrier", - "--param=SUBST_KERNEL_ANNOTATION=amdgpu_kernel[SPACE]", - "--param=SUBST_ADDRSPACE_ANNOTATION='addrspace(5)[SPACE]'", + "--param=GPU=mi200", ]), + cfg = "//xla:lit.cfg.py", + data = [ + ":test_utilities", + ], + default_tags = tf_cuda_tests_tags(), tags_override = { "reduction_vectorization_sm_all.hlo": ["no_rocm"], "element_wise_row_vectorization.hlo": ["no_rocm"], "single_instruction.hlo": ["no_rocm"], "reduce_unnested.hlo": ["no_rocm"], }, - test_file_exts = ["hlo"], + tools = [ + "//xla/tools:hlo-opt", + "@llvm-project//llvm:FileCheck", + ], ) -# Bundle together all of the test utilities that are used by tests. filegroup( name = "test_utilities", testonly = True, data = [ - ":hlo_to_llvm_ir", + "//xla/tools:hlo-opt", + "//xla/tools/hlo_opt:gpu_specs/a100_80.txtpb", + "//xla/tools/hlo_opt:gpu_specs/a6000.txtpb", + "//xla/tools/hlo_opt:gpu_specs/h100.txtpb", + "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", + "//xla/tools/hlo_opt:gpu_specs/p100.txtpb", + "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", ], ) @@ -813,9 +850,8 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", + "//xla:error_spec", "//xla/tests:hlo_test_base", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) @@ -827,8 +863,6 @@ xla_cc_test( deps = [ ":gpu_codegen_test", "//xla/tests:hlo_test_base", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) @@ -838,6 +872,7 @@ xla_cc_test( srcs = ["in_place_op_test.cc"], tags = tf_cuda_tests_tags(), deps = [ + "//xla:debug_options_flags", "//xla/service:gpu_plugin", "//xla/tests:hlo_test_base", "@tsl//tsl/platform:test_main", @@ -852,6 +887,8 @@ xla_cc_test( "//xla:shape_util", "//xla:types", "//xla:xla_proto_cc", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -877,23 +914,30 @@ xla_test( "gpu", ], deps = [ - ":gpu_codegen_test", "//xla:error_spec", "//xla/tests:hlo_test_base", - "@com_google_googletest//:gtest", "@tsl//tsl/platform:tensor_float_32_utils", "@tsl//tsl/platform:test_main", ], ) -xla_cc_test( +xla_test( name = "gpu_fused_mha_test", srcs = ["gpu_fused_mha_test.cc"], - tags = tf_cuda_tests_tags(), + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + backends = [ + "gpu", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + shard_count = 2, deps = [ ":gpu_codegen_test", "//xla:array4d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla:statusor", @@ -904,7 +948,9 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:stream_executor_util", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", @@ -914,12 +960,16 @@ xla_cc_test( "//xla/tests:test_macros_header", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", - "@tsl//tsl/platform:logging", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) # This library is here to be reused by tests. @@ -932,7 +982,6 @@ cc_library( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", ], ) @@ -997,3 +1046,15 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "simplify_fp_conversions_test", + srcs = ["simplify_fp_conversions_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//xla:xla_proto_cc", + "//xla/service:gpu_plugin", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/xla/service/gpu/tests/add_preds.hlo b/xla/service/gpu/tests/add_preds.hlo index af446aa619ec3..120b6a5ad686b 100644 --- a/xla/service/gpu/tests/add_preds.hlo +++ b/xla/service/gpu/tests/add_preds.hlo @@ -1,27 +1,9 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 1 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_3]], 1 -// CHECK: br i1 %[[VAL_5]], label %[[VAL_6:.*]], label %[[VAL_7:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_6]], %[[VAL_8:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_8]] -// CHECK: %[[VAL_9:.*]] = load i8, ptr %[[VAL_10:.*]], align 1, !invariant.load -// CHECK: %[[VAL_11:.*]] = load i8, ptr %[[VAL_12:.*]], align 1, !invariant.load -// CHECK: %[[VAL_13:.*]] = or i8 %[[VAL_9]], %[[VAL_11]] -// CHECK: %[[VAL_14:.*]] = trunc i8 %[[VAL_13]] to i1 -// CHECK: %[[VAL_15:.*]] = xor i1 %[[VAL_14]], true -// CHECK: %[[VAL_16:.*]] = zext i1 %[[VAL_15]] to i8 -// CHECK: store i8 %[[VAL_16]], ptr %[[VAL_17:.*]], align 1 -// CHECK: br label %[[VAL_7]] +// CHECK: define void @fusion({{.*}}%[[ARG0:.*]], {{.*}}%[[ARG1:.*]], +// CHECK: %[[A:.*]] = load {{.*}} ptr %[[ARG0]] +// CHECK: %[[B:.*]] = load {{.*}} ptr %[[ARG1]] +// CHECK: or {{.*}} %[[A]], %[[B]] HloModule xla_computation_f.8, is_scheduled=true diff --git a/xla/service/gpu/tests/calling_convention.hlo b/xla/service/gpu/tests/calling_convention.hlo index 7729ba7cdcff0..c84e0194c347c 100644 --- a/xla/service/gpu/tests/calling_convention.hlo +++ b/xla/service/gpu/tests/calling_convention.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // Arguments are passed separately. // Even constant arguments are passed as arguments. @@ -6,7 +6,8 @@ // CHECK-LABEL: target triple // CHECK: @buffer_for_dynamic // CHECK: @buffer_for_static -// CHECK: define [[KERNEL_ANNOTATION]]void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(32) %arg3) +// CHECK-PTX: define void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(44) %arg3) +// CHECK-GCN: define amdgpu_kernel void @custom_call(ptr noalias align 16 dereferenceable(32) %arg0, ptr noalias align 128 dereferenceable(4) %arg1, ptr noalias align 128 dereferenceable(4) %arg2, ptr noalias align 128 dereferenceable(44) %arg3) // CHECK-NOT: @buffer_for_dynamic // CHECK-NOT: @buffer_for_static diff --git a/xla/service/gpu/tests/concat.hlo b/xla/service/gpu/tests/concat.hlo deleted file mode 100644 index 9fbcca44675fe..0000000000000 --- a/xla/service/gpu/tests/concat.hlo +++ /dev/null @@ -1,197 +0,0 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 11008 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = udiv i32 %[[VAL_3]], 1 -// CHECK: %[[VAL_6:.*]] = icmp ult i32 %[[VAL_3]], 11000 -// CHECK: br i1 %[[VAL_6]], label %[[VAL_7:.*]], label %[[VAL_8:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_9:.*]], %[[VAL_10:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_10]] -// CHECK: br label %[[VAL_11:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_12:.*]] -// CHECK: %[[VAL_13:.*]] = phi i32 [ 0, %[[VAL_12]] ] -// CHECK: %[[VAL_14:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_13]] -// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_16:.*]], i32 0, i32 %[[VAL_14]] -// CHECK: %[[VAL_17:.*]] = load float, ptr %[[VAL_15]], align 4, !invariant.load -// CHECK: %[[VAL_18:.*]] = fptrunc float %[[VAL_17]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_19:.*]] -// CHECK: %[[VAL_20:.*]] = phi i32 [ 1000, %[[VAL_19]] ] -// CHECK: %[[VAL_21:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_20]] -// CHECK: %[[VAL_22:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_23:.*]], i32 0, i32 %[[VAL_21]] -// CHECK: %[[VAL_24:.*]] = load float, ptr %[[VAL_22]], align 4, !invariant.load -// CHECK: %[[VAL_25:.*]] = fptrunc float %[[VAL_24]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id2: ; preds = %[[VAL_26:.*]] -// CHECK: %[[VAL_27:.*]] = phi i32 [ 2000, %[[VAL_26]] ] -// CHECK: %[[VAL_28:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_27]] -// CHECK: %[[VAL_29:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_30:.*]], i32 0, i32 %[[VAL_28]] -// CHECK: %[[VAL_31:.*]] = load float, ptr %[[VAL_29]], align 4, !invariant.load -// CHECK: %[[VAL_32:.*]] = fptrunc float %[[VAL_31]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id3: ; preds = %[[VAL_33:.*]] -// CHECK: %[[VAL_34:.*]] = phi i32 [ 3000, %[[VAL_33]] ] -// CHECK: %[[VAL_35:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_34]] -// CHECK: %[[VAL_36:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_37:.*]], i32 0, i32 %[[VAL_35]] -// CHECK: %[[VAL_38:.*]] = load float, ptr %[[VAL_36]], align 4, !invariant.load -// CHECK: %[[VAL_39:.*]] = fptrunc float %[[VAL_38]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id4: ; preds = %[[VAL_40:.*]] -// CHECK: %[[VAL_41:.*]] = phi i32 [ 4000, %[[VAL_40]] ] -// CHECK: %[[VAL_42:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_41]] -// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_44:.*]], i32 0, i32 %[[VAL_42]] -// CHECK: %[[VAL_45:.*]] = load float, ptr %[[VAL_43]], align 4, !invariant.load -// CHECK: %[[VAL_46:.*]] = fptrunc float %[[VAL_45]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id5: ; preds = %[[VAL_47:.*]] -// CHECK: %[[VAL_48:.*]] = phi i32 [ 5000, %[[VAL_47]] ] -// CHECK: %[[VAL_49:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_48]] -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_51:.*]], i32 0, i32 %[[VAL_49]] -// CHECK: %[[VAL_52:.*]] = load float, ptr %[[VAL_50]], align 4, !invariant.load -// CHECK: %[[VAL_53:.*]] = fptrunc float %[[VAL_52]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id6: ; preds = %[[VAL_54:.*]] -// CHECK: %[[VAL_55:.*]] = phi i32 [ 6000, %[[VAL_54]] ] -// CHECK: %[[VAL_56:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_55]] -// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_58:.*]], i32 0, i32 %[[VAL_56]] -// CHECK: %[[VAL_59:.*]] = load float, ptr %[[VAL_57]], align 4, !invariant.load -// CHECK: %[[VAL_60:.*]] = fptrunc float %[[VAL_59]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id7: ; preds = %[[VAL_61:.*]] -// CHECK: %[[VAL_62:.*]] = phi i32 [ 7000, %[[VAL_61]] ] -// CHECK: %[[VAL_63:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_62]] -// CHECK: %[[VAL_64:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_65:.*]], i32 0, i32 %[[VAL_63]] -// CHECK: %[[VAL_66:.*]] = load float, ptr %[[VAL_64]], align 4, !invariant.load -// CHECK: %[[VAL_67:.*]] = fptrunc float %[[VAL_66]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id8: ; preds = %[[VAL_68:.*]] -// CHECK: %[[VAL_69:.*]] = phi i32 [ 8000, %[[VAL_68]] ] -// CHECK: %[[VAL_70:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_69]] -// CHECK: %[[VAL_71:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_72:.*]], i32 0, i32 %[[VAL_70]] -// CHECK: %[[VAL_73:.*]] = load float, ptr %[[VAL_71]], align 4, !invariant.load -// CHECK: %[[VAL_74:.*]] = fptrunc float %[[VAL_73]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id9: ; preds = %[[VAL_75:.*]] -// CHECK: %[[VAL_76:.*]] = phi i32 [ 9000, %[[VAL_75]] ] -// CHECK: %[[VAL_77:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_76]] -// CHECK: %[[VAL_78:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_79:.*]], i32 0, i32 %[[VAL_77]] -// CHECK: %[[VAL_80:.*]] = load float, ptr %[[VAL_78]], align 4, !invariant.load -// CHECK: %[[VAL_81:.*]] = fptrunc float %[[VAL_80]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id10: ; preds = %[[VAL_82:.*]] -// CHECK: %[[VAL_83:.*]] = phi i32 [ 10000, %[[VAL_82]] ] -// CHECK: %[[VAL_84:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_83]] -// CHECK: %[[VAL_85:.*]] = getelementptr inbounds [1000 x float], ptr %[[VAL_86:.*]], i32 0, i32 %[[VAL_84]] -// CHECK: %[[VAL_87:.*]] = load float, ptr %[[VAL_85]], align 4, !invariant.load -// CHECK: %[[VAL_88:.*]] = fptrunc float %[[VAL_87]] to half -// CHECK: br label %[[VAL_9]] -// CHECK: concatenate.pivot.5000.: ; preds = %[[VAL_7]] -// CHECK: %[[VAL_89:.*]] = icmp ult i32 %[[VAL_5]], 5000 -// CHECK: br i1 %[[VAL_89]], label %[[VAL_90:.*]], label %[[VAL_91:.*]] -// CHECK: concatenate.pivot.2000.: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_92:.*]] = icmp ult i32 %[[VAL_5]], 2000 -// CHECK: br i1 %[[VAL_92]], label %[[VAL_93:.*]], label %[[VAL_94:.*]] -// CHECK: concatenate.pivot.1000.: ; preds = %[[VAL_90]] -// CHECK: %[[VAL_95:.*]] = icmp ult i32 %[[VAL_5]], 1000 -// CHECK: br i1 %[[VAL_95]], label %[[VAL_12]], label %[[VAL_19]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_93]] -// CHECK: br label %[[VAL_96:.*]] -// CHECK: concatenate.pivot.1000.1: ; preds = %[[VAL_93]] -// CHECK: br label %[[VAL_97:.*]] -// CHECK: concatenate.pivot.3000.: ; preds = %[[VAL_90]] -// CHECK: %[[VAL_98:.*]] = icmp ult i32 %[[VAL_5]], 3000 -// CHECK: br i1 %[[VAL_98]], label %[[VAL_26]], label %[[VAL_99:.*]] -// CHECK: concatenate.pivot.2000.2: ; preds = %[[VAL_94]] -// CHECK: br label %[[VAL_100:.*]] -// CHECK: concatenate.pivot.4000.: ; preds = %[[VAL_94]] -// CHECK: %[[VAL_101:.*]] = icmp ult i32 %[[VAL_5]], 4000 -// CHECK: br i1 %[[VAL_101]], label %[[VAL_33]], label %[[VAL_40]] -// CHECK: concatenate.pivot.3000.3: ; preds = %[[VAL_99]] -// CHECK: br label %[[VAL_102:.*]] -// CHECK: concatenate.pivot.4000.4: ; preds = %[[VAL_99]] -// CHECK: br label %[[VAL_103:.*]] -// CHECK: concatenate.pivot.8000.: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_104:.*]] = icmp ult i32 %[[VAL_5]], 8000 -// CHECK: br i1 %[[VAL_104]], label %[[VAL_105:.*]], label %[[VAL_106:.*]] -// CHECK: concatenate.pivot.6000.: ; preds = %[[VAL_91]] -// CHECK: %[[VAL_107:.*]] = icmp ult i32 %[[VAL_5]], 6000 -// CHECK: br i1 %[[VAL_107]], label %[[VAL_47]], label %[[VAL_108:.*]] -// CHECK: concatenate.pivot.5000.5: ; preds = %[[VAL_105]] -// CHECK: br label %[[VAL_109:.*]] -// CHECK: concatenate.pivot.7000.: ; preds = %[[VAL_105]] -// CHECK: %[[VAL_110:.*]] = icmp ult i32 %[[VAL_5]], 7000 -// CHECK: br i1 %[[VAL_110]], label %[[VAL_54]], label %[[VAL_61]] -// CHECK: concatenate.pivot.6000.6: ; preds = %[[VAL_108]] -// CHECK: br label %[[VAL_111:.*]] -// CHECK: concatenate.pivot.7000.7: ; preds = %[[VAL_108]] -// CHECK: br label %[[VAL_112:.*]] -// CHECK: concatenate.pivot.9000.: ; preds = %[[VAL_91]] -// CHECK: %[[VAL_113:.*]] = icmp ult i32 %[[VAL_5]], 9000 -// CHECK: br i1 %[[VAL_113]], label %[[VAL_68]], label %[[VAL_114:.*]] -// CHECK: concatenate.pivot.8000.8: ; preds = %[[VAL_106]] -// CHECK: br label %[[VAL_115:.*]] -// CHECK: concatenate.pivot.10000.: ; preds = %[[VAL_106]] -// CHECK: %[[VAL_116:.*]] = icmp ult i32 %[[VAL_5]], 10000 -// CHECK: br i1 %[[VAL_116]], label %[[VAL_75]], label %[[VAL_82]] -// CHECK: concatenate.pivot.9000.9: ; preds = %[[VAL_114]] -// CHECK: br label %[[VAL_117:.*]] -// CHECK: concatenate.pivot.10000.10: ; preds = %[[VAL_114]] -// CHECK: br label %[[VAL_118:.*]] -// CHECK: out.1.merge: ; preds = %[[VAL_118]], %[[VAL_117]], %[[VAL_115]], %[[VAL_112]], %[[VAL_111]], %[[VAL_109]], %[[VAL_103]], %[[VAL_102]], %[[VAL_100]], %[[VAL_97]], %[[VAL_96]] -// CHECK: %[[VAL_119:.*]] = phi half [ %[[VAL_18]], %[[VAL_96]] ], [ %[[VAL_25]], %[[VAL_97]] ], [ %[[VAL_32]], %[[VAL_100]] ], [ %[[VAL_39]], %[[VAL_102]] ], [ %[[VAL_46]], %[[VAL_103]] ], [ %[[VAL_53]], %[[VAL_109]] ], [ %[[VAL_60]], %[[VAL_111]] ], [ %[[VAL_67]], %[[VAL_112]] ], [ %[[VAL_74]], %[[VAL_115]] ], [ %[[VAL_81]], %[[VAL_117]] ], [ %[[VAL_88]], %[[VAL_118]] ] -// CHECK: %[[VAL_120:.*]] = getelementptr inbounds half, ptr %[[VAL_121:.*]], i32 %[[VAL_3]] -// CHECK: store half %[[VAL_119]], ptr %[[VAL_120]], align 2 -// CHECK: br label %[[VAL_8]] - -HloModule module, is_scheduled=true - -%fused_computation (param_0.1: f32[1000], param_1.2: f32[1000], param_2.3: f32[1000], param_3.4: f32[1000], param_4.5: f32[1000], param_5.6: f32[1000], param_6.7: f32[1000], param_7.8: f32[1000], param_8.9: f32[1000], param_9.10: f32[1000], param_10.11: f32[1000]) -> f16[11000] { - %param_10.11 = f32[1000]{0} parameter(10) - %converted0.1 = f16[1000]{0} convert(f32[1000]{0} %param_10.11) - %param_9.10 = f32[1000]{0} parameter(9) - %converted1.1 = f16[1000]{0} convert(f32[1000]{0} %param_9.10) - %param_8.9 = f32[1000]{0} parameter(8) - %converted2.1 = f16[1000]{0} convert(f32[1000]{0} %param_8.9) - %param_7.8 = f32[1000]{0} parameter(7) - %converted3.1 = f16[1000]{0} convert(f32[1000]{0} %param_7.8) - %param_6.7 = f32[1000]{0} parameter(6) - %converted4.1 = f16[1000]{0} convert(f32[1000]{0} %param_6.7) - %param_5.6 = f32[1000]{0} parameter(5) - %converted5.1 = f16[1000]{0} convert(f32[1000]{0} %param_5.6) - %param_4.5 = f32[1000]{0} parameter(4) - %converted6.1 = f16[1000]{0} convert(f32[1000]{0} %param_4.5) - %param_3.4 = f32[1000]{0} parameter(3) - %converted7.1 = f16[1000]{0} convert(f32[1000]{0} %param_3.4) - %param_2.3 = f32[1000]{0} parameter(2) - %converted8.1 = f16[1000]{0} convert(f32[1000]{0} %param_2.3) - %param_1.2 = f32[1000]{0} parameter(1) - %converted9.1 = f16[1000]{0} convert(f32[1000]{0} %param_1.2) - %param_0.1 = f32[1000]{0} parameter(0) - %converted10.1 = f16[1000]{0} convert(f32[1000]{0} %param_0.1) - ROOT %out.1 = f16[11000]{0} concatenate(f16[1000]{0} %converted0.1, f16[1000]{0} %converted1.1, f16[1000]{0} %converted2.1, f16[1000]{0} %converted3.1, f16[1000]{0} %converted4.1, /*index=5*/f16[1000]{0} %converted5.1, f16[1000]{0} %converted6.1, f16[1000]{0} %converted7.1, f16[1000]{0} %converted8.1, f16[1000]{0} %converted9.1, /*index=10*/f16[1000]{0} %converted10.1), dimensions={0} -} - -ENTRY %computation (p0: f32[1000], p1: f32[1000], p2: f32[1000], p3: f32[1000], p4: f32[1000], p5: f32[1000], p6: f32[1000], p7: f32[1000], p8: f32[1000], p9: f32[1000], p10: f32[1000]) -> f16[11000] { - %p10 = f32[1000]{0} parameter(10) - %p9 = f32[1000]{0} parameter(9) - %p8 = f32[1000]{0} parameter(8) - %p7 = f32[1000]{0} parameter(7) - %p6 = f32[1000]{0} parameter(6) - %p5 = f32[1000]{0} parameter(5) - %p4 = f32[1000]{0} parameter(4) - %p3 = f32[1000]{0} parameter(3) - %p2 = f32[1000]{0} parameter(2) - %p1 = f32[1000]{0} parameter(1) - %p0 = f32[1000]{0} parameter(0) - ROOT %fusion = f16[11000]{0} fusion(f32[1000]{0} %p10, f32[1000]{0} %p9, f32[1000]{0} %p8, f32[1000]{0} %p7, f32[1000]{0} %p6, /*index=5*/f32[1000]{0} %p5, f32[1000]{0} %p4, f32[1000]{0} %p3, f32[1000]{0} %p2, f32[1000]{0} %p1, /*index=10*/f32[1000]{0} %p0), kind=kLoop, calls=%fused_computation -} - diff --git a/xla/service/gpu/tests/concatenate_emitter_test.cc b/xla/service/gpu/tests/concatenate_emitter_test.cc new file mode 100644 index 0000000000000..5dd1ad48c5e9a --- /dev/null +++ b/xla/service/gpu/tests/concatenate_emitter_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/error_spec.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +class ConcatenateEmitterTest : public gpu::GpuCodegenTest { + protected: + ConcatenateEmitterTest() = default; +}; + +TEST_F(ConcatenateEmitterTest, Simple) { + const char* const kHloString = R"( + HloModule module + + ENTRY main { + param0 = f32[128] parameter(0) + param1 = f32[128] parameter(1) + ROOT concat = f32[256] concatenate(param0, param1), dimensions={0} + })"; + + auto expected_ir = R"( +; CHECK-DAG: %[[ARG0:.*]] = addrspacecast ptr %arg0 +; CHECK-DAG: %[[ARG1:.*]] = addrspacecast ptr %arg1 +; CHECK-DAG: %[[ARG2:.*]] = addrspacecast ptr %arg2 +; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG0]] +; CHECK-DAG: %[[VAL:.*]] = load float, ptr addrspace(1) %[[PTR]] +; CHECK-DAG: %[[DST:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG2]] +; CHECK: store float %[[VAL]], ptr addrspace(1) %[[DST]] +; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG1]] +; CHECK-DAG: %[[VAL:.*]] = load float, ptr addrspace(1) %[[PTR]] +; CHECK-DAG: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[DST]], i64 512 +; CHECK: store float %[[VAL]], ptr addrspace(1) %[[PTR]] +; CHECK: !"reqntidx", i32 128 +)"; + CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), + /*match_optimized_ir=*/true, + /*run_optimization_passes=*/false); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ConcatenateEmitterTest, PrologueAndEpilogue) { + const char* const kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[128] parameter(0) + negate = f32[128] negate(param0) + param1 = f32[128] parameter(1) + concat = f32[256] concatenate(negate, param1), dimensions={0} + param2 = f32[256] parameter(2) + ROOT add = f32[256] add(concat, param2) + } + + ENTRY main { + param0 = f32[128] parameter(0) + param1 = f32[128] parameter(1) + param2 = f32[256] parameter(2) + ROOT %fusion = f32[256] fusion(param0, param1, param2), kind=kInput, calls=fused_computation + })"; + + auto expected_ir = R"( +; CHECK-DAG: %[[ARG0:.*]] = addrspacecast ptr %arg0 +; CHECK-DAG: %[[ARG1:.*]] = addrspacecast ptr %arg1 +; CHECK-DAG: %[[ARG2:.*]] = addrspacecast ptr %arg2 +; CHECK-DAG: %[[ARG3:.*]] = addrspacecast ptr %arg3 +; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG0]] +; CHECK: %[[RHS:.*]] = load float, ptr addrspace(1) %[[PTR]] +; CHECK: %[[SRC:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG2]] +; CHECK: %[[LHS:.*]] = load float, ptr addrspace(1) %[[SRC]] +; CHECK: %[[VAL:.*]] = fsub float %[[LHS]], %[[RHS]] +; CHECK: %[[DST:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG3]] +; CHECK: store float %[[VAL]], ptr addrspace(1) %[[DST]] +; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG1]] +; CHECK: %[[LHS:.*]] = load float, ptr addrspace(1) %[[PTR]] +; CHECK: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[SRC]], i64 512 +; CHECK: %[[RHS:.*]] = load float, ptr addrspace(1) %[[PTR]] +; CHECK: %[[VAL:.*]] = fadd float %[[LHS]], %[[RHS]] +; CHECK: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[DST]], i64 512 +; CHECK: store float %[[VAL]], ptr addrspace(1) %[[PTR]] +)"; + CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), + /*match_optimized_ir=*/true, + /*run_optimization_passes=*/false); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ConcatenateEmitterTest, MajorDimension) { + const char* const kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[16,16] parameter(0) + param1 = f32[16,16] parameter(1) + ROOT concat = f32[32,16] concatenate(param0, param1), dimensions={0} + } + + ENTRY main { + param0 = f32[16,16] parameter(0) + param1 = f32[16,16] parameter(1) + ROOT %fusion = f32[32,16] fusion(param0, param1), kind=kInput, calls=fused_computation + })"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ConcatenateEmitterTest, DifferentSizes) { + const char* const kHloString = R"( + HloModule module + + ENTRY main { + param0 = f32[112] parameter(0) + param1 = f32[128] parameter(1) + ROOT concat = f32[240] concatenate(param0, param1), dimensions={0} + })"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ConcatenateEmitterTest, RepeatedInput) { + const char* const kHloString = R"( + HloModule module + + ENTRY main { + param0 = f32[128] parameter(0) + ROOT concat = f32[256] concatenate(param0, param0), dimensions={0} + })"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(ConcatenateEmitterTest, BitcastEpilogue) { + const char* const kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[128] parameter(0) + param1 = f32[128] parameter(1) + concat = f32[256] concatenate(param0, param1), dimensions={0} + ROOT bitcast = f32[1,16,16] bitcast(concat) + } + + ENTRY main { + param0 = f32[128] parameter(0) + param1 = f32[128] parameter(1) + ROOT %fusion = f32[1,16,16] fusion(param0, param1), kind=kInput, calls=fused_computation + })"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +} // namespace +} // namespace xla diff --git a/xla/service/gpu/tests/constant.hlo b/xla/service/gpu/tests/constant.hlo deleted file mode 100644 index 5dc79d68cf8a0..0000000000000 --- a/xla/service/gpu/tests/constant.hlo +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %s - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = pred[2,2]{1,0} parameter(0) - param_1 = pred[2,2]{1,0} parameter(1) - ROOT xor.1 = pred[2,2]{1,0} xor(pred[2,2]{1,0} param_0, pred[2,2]{1,0} param_1) -} - -ENTRY main { -// CHECK: %{{.*}} = getelementptr inbounds i8, ptr %arg0, i32 %{{.*}} -// CHECK: %{{.*}} = getelementptr inbounds i8, ptr %arg1, i32 %{{.*}} - a = pred[2, 2]{1,0} constant({{false, true}, {true, false}}) - b = pred[2, 2]{1,0} constant({{false, true}, {false, true}}) - ROOT wrapped_xor = pred[2,2]{1,0} fusion(pred[2,2]{1,0} a, pred[2,2]{1,0} b), kind=kLoop, calls=fused_computation -} diff --git a/xla/service/gpu/tests/copy.hlo b/xla/service/gpu/tests/copy.hlo index 3b608d880ffbe..beac8e6d36b11 100644 --- a/xla/service/gpu/tests/copy.hlo +++ b/xla/service/gpu/tests/copy.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py @@ -8,7 +8,8 @@ // minimized and named to reflect the test intent. -// CHECK: call void [[BARRIER]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier HloModule Test, is_scheduled=true diff --git a/xla/service/gpu/tests/copy_nested.hlo b/xla/service/gpu/tests/copy_nested.hlo deleted file mode 100644 index d09d584eeaa0c..0000000000000 --- a/xla/service/gpu/tests/copy_nested.hlo +++ /dev/null @@ -1,85 +0,0 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: store i32 0, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 4 -// CHECK: br label %[[VAL_1:.*]] -// CHECK: loop.loop_header: ; preds = %[[VAL_2:.*]], %[[VAL_3:.*]] -// CHECK: %[[VAL_4:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 4 -// CHECK: %[[VAL_5:.*]] = icmp uge i32 %[[VAL_4]], 6000000 -// CHECK: br i1 %[[VAL_5]], label %[[VAL_6:.*]], label %[[VAL_7:.*]] -// CHECK: loop.loop_body: ; preds = %[[VAL_1]] -// CHECK-PTX: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_4]], 516096 -// CHECK-GCN: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_4]], 851968 -// CHECK: store i32 %[[VAL_8]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 4 -// CHECK: %[[VAL_9:.*]] = icmp eq i32 %[[VAL_4]], 0 -// CHECK: %[[VAL_10:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_11:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_12:.*]] = mul nuw nsw i32 %[[VAL_10]], 128 -// CHECK: %[[VAL_13:.*]] = add nuw nsw i32 %[[VAL_12]], %[[VAL_11]] -// CHECK-PTX: %[[VAL_14:.*]] = icmp ult i32 %[[VAL_13]], 129024 -// CHECK-GCN: %[[VAL_14:.*]] = icmp ult i32 %[[VAL_13]], 212992 -// CHECK: call void @llvm.assume(i1 %[[VAL_14]]) -// CHECK: %[[VAL_15:.*]] = mul nuw nsw i32 %[[VAL_13]], 4 -// CHECK: %[[VAL_16:.*]] = add nuw nsw i32 %[[VAL_15]], %[[VAL_4]] -// CHECK: %[[VAL_17:.*]] = udiv i32 %[[VAL_16]], 1 -// CHECK: %[[VAL_18:.*]] = urem i32 %[[VAL_17]], 300 -// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_16]], 300 -// CHECK: %[[VAL_20:.*]] = urem i32 %[[VAL_19]], 100 -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_16]], 30000 -// CHECK: %[[VAL_22:.*]] = add nuw nsw i32 %[[VAL_16]], 1 -// CHECK: %[[VAL_23:.*]] = udiv i32 %[[VAL_22]], 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 300 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_22]], 300 -// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 100 -// CHECK: %[[VAL_27:.*]] = udiv i32 %[[VAL_22]], 30000 -// CHECK: %[[VAL_28:.*]] = add nuw nsw i32 %[[VAL_16]], 2 -// CHECK: %[[VAL_29:.*]] = udiv i32 %[[VAL_28]], 1 -// CHECK: %[[VAL_30:.*]] = urem i32 %[[VAL_29]], 300 -// CHECK: %[[VAL_31:.*]] = udiv i32 %[[VAL_28]], 300 -// CHECK: %[[VAL_32:.*]] = urem i32 %[[VAL_31]], 100 -// CHECK: %[[VAL_33:.*]] = udiv i32 %[[VAL_28]], 30000 -// CHECK: %[[VAL_34:.*]] = add nuw nsw i32 %[[VAL_16]], 3 -// CHECK: %[[VAL_35:.*]] = udiv i32 %[[VAL_34]], 1 -// CHECK: %[[VAL_36:.*]] = urem i32 %[[VAL_35]], 300 -// CHECK: %[[VAL_37:.*]] = udiv i32 %[[VAL_34]], 300 -// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 100 -// CHECK: %[[VAL_39:.*]] = udiv i32 %[[VAL_34]], 30000 -// CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_16]], 6000000 -// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_2]] -// CHECK: wrapped_b.in_bounds-after: -// CHECK: br label %[[VAL_1]], !llvm.loop -// CHECK: loop.loop_exit: ; preds = %[[VAL_1]] -// CHECK: ret void -// CHECK: wrapped_b.in_bounds-true: -// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_43:.*]], i32 0, i32 %[[VAL_20]], i32 %[[VAL_21]], i32 %[[VAL_18]] -// CHECK: %[[VAL_44:.*]] = load float, ptr %[[VAL_42]], align 4, !invariant.load -// CHECK: %[[VAL_45:.*]] = getelementptr inbounds float, ptr %[[VAL_46:.*]], i32 %[[VAL_16]] -// CHECK: store float %[[VAL_44]], ptr %[[VAL_45]], align 4 -// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_43]], i32 0, i32 %[[VAL_26]], i32 %[[VAL_27]], i32 %[[VAL_24]] -// CHECK: %[[VAL_48:.*]] = load float, ptr %[[VAL_47]], align 4, !invariant.load -// CHECK: %[[VAL_49:.*]] = getelementptr inbounds float, ptr %[[VAL_46]], i32 %[[VAL_22]] -// CHECK: store float %[[VAL_48]], ptr %[[VAL_49]], align 4 -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_43]], i32 0, i32 %[[VAL_32]], i32 %[[VAL_33]], i32 %[[VAL_30]] -// CHECK: %[[VAL_51:.*]] = load float, ptr %[[VAL_50]], align 4, !invariant.load -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds float, ptr %[[VAL_46]], i32 %[[VAL_28]] -// CHECK: store float %[[VAL_51]], ptr %[[VAL_52]], align 4 -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], ptr %[[VAL_43]], i32 0, i32 %[[VAL_38]], i32 %[[VAL_39]], i32 %[[VAL_36]] -// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_53]], align 4, !invariant.load -// CHECK: %[[VAL_55:.*]] = getelementptr inbounds float, ptr %[[VAL_46]], i32 %[[VAL_34]] -// CHECK: store float %[[VAL_54]], ptr %[[VAL_55]], align 4 -// CHECK: br label %[[VAL_2]] - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[100,200,300]{2,1,0} parameter(0) - ROOT b.1 = f32[100,200,300]{2,0,1} copy(f32[100,200,300]{2,1,0} param_0) -} - -ENTRY main { - a = f32[100, 200, 300]{2,1,0} parameter(0) - ROOT wrapped_b = f32[100,200,300]{2,0,1} fusion(f32[100,200,300]{2,1,0} %a), kind=kLoop, calls=fused_computation -} diff --git a/xla/service/gpu/tests/dynamic_shared_memory_test.cc b/xla/service/gpu/tests/dynamic_shared_memory_test.cc index 28a8f0d82e71a..3442e24ad83a1 100644 --- a/xla/service/gpu/tests/dynamic_shared_memory_test.cc +++ b/xla/service/gpu/tests/dynamic_shared_memory_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,16 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/asm_compiler.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla.pb.h" #include "tsl/platform/status.h" @@ -130,15 +134,15 @@ TEST(SharedMemoryUseTest, ArrayReversalWorks) { // memory with it, read it back inverting both axes, // copy the result back to the host and verify it. se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("cuda").value(); + se::PlatformManager::PlatformWithName("cuda").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Use 90% of the available shared memory to verify that a fractional // amount works as well, not only the full size. - const int n_cols = executor->GetDeviceDescription().threads_per_block_limit(); - const int n_rows = + const unsigned n_cols = + executor->GetDeviceDescription().threads_per_block_limit(); + const unsigned n_rows = 0.9 * executor->GetDeviceDescription().shared_memory_per_block_optin() / n_cols; const int n_elements = n_cols * n_rows; @@ -169,18 +173,21 @@ TEST(SharedMemoryUseTest, ArrayReversalWorks) { } } - stream.ThenMemcpy(&device_buffer, host_buffer.data(), buffer_size_bytes); + TF_CHECK_OK( + stream->Memcpy(&device_buffer, host_buffer.data(), buffer_size_bytes)); se::DeviceMemory dev_n_cols = executor->AllocateScalar(); - stream.ThenMemcpy(&dev_n_cols, &n_cols, sizeof(uint32_t)); + TF_CHECK_OK(stream->Memcpy(&dev_n_cols, &n_cols, sizeof(uint32_t))); se::DeviceMemory dev_n_rows = executor->AllocateScalar(); - stream.ThenMemcpy(&dev_n_rows, &n_rows, sizeof(uint32_t)); - TF_CHECK_OK(stream.BlockHostUntilDone()); + TF_CHECK_OK(stream->Memcpy(&dev_n_rows, &n_rows, sizeof(uint32_t))); + TF_CHECK_OK(stream->BlockHostUntilDone()); TF_CHECK_OK(ExecuteKernelOnStream( *kernel, {device_buffer, dev_n_cols, dev_n_rows}, - {/*block_x_count=*/1, /*thread_x_count_per_block=*/n_cols}, &stream)); - TF_CHECK_OK(stream.BlockHostUntilDone()); - stream.ThenMemcpy(host_buffer.data(), device_buffer, buffer_size_bytes); - TF_CHECK_OK(stream.BlockHostUntilDone()); + {/*block_x_count=*/1, /*thread_x_count_per_block=*/n_cols}, + stream.get())); + TF_CHECK_OK(stream->BlockHostUntilDone()); + TF_CHECK_OK( + stream->Memcpy(host_buffer.data(), device_buffer, buffer_size_bytes)); + TF_CHECK_OK(stream->BlockHostUntilDone()); for (int row = 0; row < n_rows; ++row) { for (int col = 0; col < n_cols; ++col) { diff --git a/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo b/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo index 0c56797d56c45..3d0af18b08110 100644 --- a/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo +++ b/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py @@ -18,36 +18,40 @@ // CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_11]], i32 0, i32 %[[VAL_10]] // CHECK: %[[VAL_13:.*]] = icmp sle i32 0, %[[VAL_12]] // CHECK: %[[VAL_14:.*]] = select i1 %[[VAL_13]], i32 0, i32 %[[VAL_12]] -// CHECK: %[[VAL_15:.*]] = call i32 [[CTAIDX]] +// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.x // CHECK: %[[VAL_16:.*]] = zext i32 %[[VAL_15]] to i64 -// CHECK: %[[VAL_17:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_17:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_17:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_18:.*]] = zext i32 %[[VAL_17]] to i64 // CHECK-PTX: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 128 // CHECK-GCN: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 256 // CHECK: %[[VAL_20:.*]] = add nuw nsw i64 %[[VAL_19]], %[[VAL_18]] // CHECK: %[[VAL_21:.*]] = icmp ult i64 %[[VAL_20]], 98304 // CHECK: call void @llvm.assume(i1 %[[VAL_21]]) -// CHECK: %[[VAL_22:.*]] = udiv i64 %[[VAL_20]], 1 -// CHECK: %[[VAL_23:.*]] = urem i64 %[[VAL_22]], 1024 -// CHECK: %[[VAL_24:.*]] = udiv i64 %[[VAL_20]], 1024 -// CHECK: %[[VAL_25:.*]] = urem i64 %[[VAL_24]], 96 -// CHECK: %[[VAL_26:.*]] = udiv i64 %[[VAL_20]], 98304 -// CHECK: %[[VAL_27:.*]] = icmp ult i64 %[[VAL_20]], 98304 -// CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]] -// CHECK: dynamic-update-slice.in_bounds-after: ; preds = %[[VAL_28]], %[[VAL_30:.*]] +// CHECK: %[[VAL_22:.*]] = add nuw nsw i64 %[[VAL_20]], 0 +// CHECK: %[[VAL_23:.*]] = udiv i64 %[[VAL_22]], 1 +// CHECK: %[[VAL_24:.*]] = urem i64 %[[VAL_23]], 1024 +// CHECK: %[[VAL_25:.*]] = udiv i64 %[[VAL_22]], 1024 +// CHECK: %[[VAL_26:.*]] = urem i64 %[[VAL_25]], 96 +// CHECK: %[[VAL_27:.*]] = udiv i64 %[[VAL_22]], 98304 +// CHECK: %[[VAL_28:.*]] = icmp ult i64 %[[VAL_20]], 98304 +// CHECK: br i1 %[[VAL_28]], label %[[VAL_29:.*]], label %[[VAL_30:.*]] +// CHECK: dynamic-update-slice.in_bounds-after: ; preds = %[[VAL_29]], %[[VAL_31:.*]] // CHECK: ret void -// CHECK: dynamic-update-slice.in_bounds-true: ; preds = %[[VAL_30]] -// CHECK: %[[VAL_31:.*]] = sext i32 %[[VAL_4]] to i64 -// CHECK: %[[VAL_32:.*]] = add i64 %[[VAL_31]], %[[VAL_26]] -// CHECK: %[[VAL_33:.*]] = sext i32 %[[VAL_9]] to i64 -// CHECK: %[[VAL_34:.*]] = add i64 %[[VAL_33]], %[[VAL_25]] -// CHECK: %[[VAL_35:.*]] = sext i32 %[[VAL_14]] to i64 -// CHECK: %[[VAL_36:.*]] = add i64 %[[VAL_35]], %[[VAL_23]] -// CHECK: %[[VAL_37:.*]] = getelementptr inbounds half, ptr %[[VAL_38:.*]], i64 %[[VAL_20]] -// CHECK: %[[VAL_39:.*]] = load half, ptr %[[VAL_37]], align 2, !invariant.load -// CHECK: %[[VAL_40:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], ptr %[[VAL_41:.*]], i64 0, i64 %[[VAL_32]], i64 %[[VAL_34]], i64 %[[VAL_36]] -// CHECK: store half %[[VAL_39]], ptr %[[VAL_40]], align 2 -// CHECK: br label %[[VAL_29]] +// CHECK: dynamic-update-slice.in_bounds-true: ; preds = %[[VAL_31]] +// CHECK: %[[VAL_32:.*]] = sext i32 %[[VAL_4]] to i64 +// CHECK: %[[VAL_33:.*]] = add i64 %[[VAL_32]], %[[VAL_27]] +// CHECK: %[[VAL_34:.*]] = sext i32 %[[VAL_9]] to i64 +// CHECK: %[[VAL_35:.*]] = add i64 %[[VAL_34]], %[[VAL_26]] +// CHECK: %[[VAL_36:.*]] = sext i32 %[[VAL_14]] to i64 +// CHECK: %[[VAL_37:.*]] = add i64 %[[VAL_36]], %[[VAL_24]] +// CHECK: %[[VAL_38:.*]] = getelementptr half, ptr %[[VAL_39:.*]], i64 %[[VAL_20]] +// CHECK: %[[VAL_40:.*]] = getelementptr inbounds half, ptr %[[VAL_38]], i64 0 +// CHECK: %[[VAL_41:.*]] = load half, ptr %[[VAL_40]], align 2, !invariant.load +// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], ptr %[[VAL_43:.*]], i64 0, i64 %[[VAL_33]], i64 %[[VAL_35]], i64 %[[VAL_37]] +// CHECK: store half %[[VAL_41]], ptr %[[VAL_42]], align 2 +// CHECK: br label %[[VAL_30]] HloModule TestModule, is_scheduled=true @@ -78,21 +82,26 @@ ENTRY entry { // CHECK: [[FUSION]].in_bounds-true: // CHECK: br i1 %{{.*}}, label %[[SLICE:.*]]-true, label %[[SLICE]]-false // CHECK: [[SLICE]]-after: -// CHECK: %[[VAL_82:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[RET_VALUE_ADDR:.*]], align 4 -// CHECK: %[[VAL_83:.*]] = getelementptr inbounds i32, ptr %[[ARG2]], i32 %[[LINEAR_INDEX]] -// CHECK: store i32 %[[VAL_82]], ptr %[[VAL_83]], align 4 +// CHECK-PTX: %[[VAL_85:.*]] = load i32, ptr %[[RET_VALUE_ADDR:.*]], align 4 +// CHECK-GCN: %[[VAL_85:.*]] = load i32, ptr addrspace(5) %[[RET_VALUE_ADDR:.*]], align 4 +// CHECK: %[[VAL_86:.*]] = getelementptr i32, ptr %[[ARG2]], i32 %[[LINEAR_INDEX]] +// CHECK: %[[VAL_88:.*]] = getelementptr inbounds i32, ptr %[[VAL_86]], i32 0 +// CHECK: store i32 %[[VAL_85]], ptr %[[VAL_88]], align 4 // CHECK: br label %[[FUSION]].in_bounds-after // CHECK: [[SLICE]]-true: -// CHECK: %[[VAL_104:.*]] = getelementptr inbounds [6 x i32], ptr %[[ARG0]], i32 0, i32 %[[VAL_102:.*]] -// CHECK: %[[VAL_106:.*]] = load i32, ptr %[[VAL_104]], align 4, !invariant.load -// CHECK: %[[VAL_107:.*]] = load i32, ptr @1, align 4 -// CHECK: %[[VAL_108:.*]] = add i32 %[[VAL_106]], %[[VAL_107]] -// CHECK: store i32 %[[VAL_108]], ptr [[ADDRSPACE_ANNOTATION]]%[[RET_VALUE_ADDR]], align 4 +// CHECK: %[[VAL_108:.*]] = getelementptr inbounds [6 x i32], ptr %[[ARG0]], i32 0, i32 %[[VAL_106:.*]] +// CHECK: %[[VAL_110:.*]] = load i32, ptr %[[VAL_108]], align 4, !invariant.load +// CHECK: %[[VAL_111:.*]] = load i32, ptr @1, align 4 +// CHECK: %[[VAL_112:.*]] = add i32 %[[VAL_110]], %[[VAL_111]] +// CHECK-PTX: store i32 %[[VAL_112]], ptr %[[RET_VALUE_ADDR]], align 4 +// CHECK-GCN: store i32 %[[VAL_112]], ptr addrspace(5) %[[RET_VALUE_ADDR]], align 4 // CHECK: br label %[[SLICE]]-after // CHECK: [[SLICE]]-false: -// CHECK: %[[VAL_114:.*]] = getelementptr inbounds i32, ptr %[[ARG0]], i32 %[[LINEAR_INDEX]] -// CHECK: %[[VAL_115:.*]] = load i32, ptr %[[VAL_114]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_115]], ptr [[ADDRSPACE_ANNOTATION]]%[[RET_VALUE_ADDR]], align 4 +// CHECK: %[[VAL_118:.*]] = getelementptr i32, ptr %[[ARG0]], i32 %[[LINEAR_INDEX]] +// CHECK: %[[VAL_119:.*]] = getelementptr inbounds i32, ptr %[[VAL_118]], i32 0 +// CHECK: %[[VAL_120:.*]] = load i32, ptr %[[VAL_119]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_120]], ptr %[[RET_VALUE_ADDR]], align 4 +// CHECK-GCN: store i32 %[[VAL_120]], ptr addrspace(5) %[[RET_VALUE_ADDR]], align 4 // CHECK: br label %[[SLICE]]-after HloModule fusion, is_scheduled=true @@ -132,13 +141,13 @@ ENTRY main { // CHECK: [[DUS1]].in_bounds-after: // CHECK-NEXT: ret void // CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG3]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x i16]]], ptr %[[ARG0]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_141:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_141]] +// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_185:.*]], i64 %[[VAL_187:.*]], i64 %[[VAL_189:.*]] // CHECK: [[DUS1]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG3]] -// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x i16]]], ptr %[[ARG2]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_173:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_173]] +// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x bfloat]]], ptr %[[ARG2]], i64 0, i64 %[[VAL_208:.*]], i64 %[[VAL_210:.*]], i64 %[[VAL_212:.*]] HloModule MultipleInplaceDus, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -182,13 +191,13 @@ ENTRY main { // CHECK: [[DUS1]].in_bounds-after: // CHECK-NEXT: ret void // CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG3]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x i16]]], ptr %[[ARG0]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_247:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_247]] +// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_291:.*]], i64 %[[VAL_293:.*]], i64 %[[VAL_295:.*]] // CHECK: [[DUS1]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG3]] -// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x i16]]], ptr %[[ARG2]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_279:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_279]] +// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x bfloat]]], ptr %[[ARG2]], i64 0, i64 %[[VAL_314:.*]], i64 %[[VAL_316:.*]], i64 %[[VAL_318:.*]] HloModule MultipleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } @@ -229,9 +238,9 @@ ENTRY main { // CHECK: [[DUS0]].in_bounds-after: // CHECK-NEXT: ret void // CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG2]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x i16]]], ptr %[[ARG0]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_353:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_353]] +// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_366:.*]], i64 %[[VAL_368:.*]], i64 %[[VAL_370:.*]] HloModule SingleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } @@ -268,9 +277,9 @@ ENTRY main { // CHECK: [[DUS0]].in_bounds-after: // CHECK-NEXT: ret void // CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG2]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x i16]]], ptr %[[ARG0]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_408:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_408:.*]] +// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_421:.*]], i64 %[[VAL_423:.*]], i64 %[[VAL_425:.*]] HloModule SingleInplaceDusWithReshapeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } @@ -307,9 +316,9 @@ ENTRY main { // CHECK: [[DUS0]].in_bounds-after: // CHECK-NEXT: ret void // CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG1]] -// CHECK-DAG: getelementptr inbounds i16, ptr %[[ARG2]] -// CHECK-DAG: getelementptr inbounds [10 x [6 x [2 x [11 x i16]]]], ptr %[[ARG0]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_468:.*]] +// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_468]] +// CHECK-DAG: getelementptr inbounds [10 x [6 x [2 x [11 x bfloat]]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_483:.*]], i64 %[[VAL_485:.*]], i64 %[[VAL_487:.*]], i64 %[[VAL_489:.*]] HloModule SingleInplaceDusWithBitcastToTheRootAndFromTheParameter, is_scheduled=true, input_output_alias={ {}: (0, {}) } diff --git a/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/xla/service/gpu/tests/element_wise_row_vectorization.hlo index bd6ab76d2e8a2..b49e155da0a68 100644 --- a/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ b/xla/service/gpu/tests/element_wise_row_vectorization.hlo @@ -1,5 +1,5 @@ -// RUN: hlo_to_llvm_ir --ptx --sm=70 --xla_disable_all_hlo_passes=true %s | FileCheck %s -// RUN: hlo_to_llvm_ir --xla_disable_all_hlo_passes=true %s | FileCheck --check-prefix=CHECK-LLVM %s +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK-LLVM %s // We check that the row loads are vectorized. HloModule SimpleAddRowBroadcasting, is_scheduled=true @@ -310,5 +310,8 @@ ENTRY computation { ROOT %fusion.9 = f16[5000,65,65,32] fusion(p0, zero), kind=kLoop, calls=%fused_computation.1 } -// Check that we emit vectorized read. -// CHECK: ld.global.nc.v4.f32 +// Our codegen can't emit a vectorized load here, but it can emit a vectorized +// store. +// CHECK-LABEL: .visible .entry fusion_9 +// CHECK-COUNT-4: ld.global.nc.u16 +// CHECK: st.global.v4.b16 diff --git a/xla/service/gpu/tests/element_wise_row_vectorization_test.cc b/xla/service/gpu/tests/element_wise_row_vectorization_test.cc index 099a2edf17b20..ffd507f94959c 100644 --- a/xla/service/gpu/tests/element_wise_row_vectorization_test.cc +++ b/xla/service/gpu/tests/element_wise_row_vectorization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -10,8 +10,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc new file mode 100644 index 0000000000000..fbc078752d7b6 --- /dev/null +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -0,0 +1,198 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { + +class FloatConversionTest : public GpuCodegenTest {}; + +TEST_F(FloatConversionTest, F8E5M2ToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e5m2[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3FNToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3fn[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3B11FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3b11fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E5M2FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e5m2fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F8E4M3FNUZToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f8e4m3fnuz[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, BF16ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = bf16[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F64ToF32) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f64[] parameter(0) + ROOT %c = f32[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E5M2) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e5m2[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3FN) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3fn[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3B11FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3b11fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E5M2FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e5m2fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F16ToF8E4M3FNUZ) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f16[] parameter(0) + ROOT %c = f8e4m3fnuz[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToBF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = bf16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToF16) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = f16[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToF64) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + %p = f32[] parameter(0) + ROOT %c = f64[] convert(%p) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToPred) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = f32[1000] iota(), iota_dimension=0 + c500 = f32[] constant(500) + c500_b = f32[1000] broadcast(c500), dimensions={} + sub = f32[1000] subtract(iota, c500_b) + ROOT c = pred[1000] convert(sub) + })", + ErrorSpec{1e-5, 1e-5})); + + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + n = f32[] constant(nan) + ROOT c = pred[] convert(n) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, F32ToS8) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = f32[1000] iota(), iota_dimension=0 + c500 = f32[] constant(500) + c500_b = f32[1000] broadcast(c500), dimensions={} + sub = f32[1000] subtract(iota, c500_b) + ROOT c = s8[1000] convert(sub) + })", + ErrorSpec{1e-5, 1e-5})); + + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + n = f32[] constant(nan) + ROOT c = s8[] convert(n) + })", + ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(FloatConversionTest, BF16ToS16IsBroken) { + EXPECT_TRUE(RunAndCompare(R"(ENTRY m { + iota = u16[65536] iota(), iota_dimension=0 + bc = bf16[65536] bitcast-convert(iota) + ROOT c = s16[65536] convert(bc) + })", + ErrorSpec{1e-5, 1e-5})); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/tests/fused_scatter.hlo b/xla/service/gpu/tests/fused_scatter.hlo index ab2a5d490eb38..9a30436ebfa38 100644 --- a/xla/service/gpu/tests/fused_scatter.hlo +++ b/xla/service/gpu/tests/fused_scatter.hlo @@ -1,106 +1,45 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 2 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = udiv i32 %[[VAL_3]], 1 -// CHECK: %[[VAL_6:.*]] = icmp ult i32 %[[VAL_3]], 2 -// CHECK: br i1 %[[VAL_6]], label %[[VAL_7:.*]], label %[[VAL_8:.*]] -// CHECK: indices.in_bounds-after: ; preds = %[[VAL_7]], %[[VAL_9:.*]] +// CHECK: define void @wrapped_scatter +// CHECK: %[[VAL_70:.*]] = alloca i32, align 4 +// CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_71:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_72:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_72:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_73:.*]] = mul nuw nsw i32 %[[VAL_71]], 6 +// CHECK: %[[VAL_74:.*]] = add nuw nsw i32 %[[VAL_73]], %[[VAL_72]] +// CHECK: %[[VAL_75:.*]] = icmp ult i32 %[[VAL_74]], 6 +// CHECK: call void @llvm.assume(i1 %[[VAL_75]]) +// CHECK: %[[VAL_76:.*]] = add nuw nsw i32 %[[VAL_74]], 0 +// CHECK: %[[VAL_77:.*]] = udiv i32 %[[VAL_76]], 1 +// CHECK: %[[VAL_78:.*]] = urem i32 %[[VAL_77]], 3 +// CHECK: %[[VAL_79:.*]] = udiv i32 %[[VAL_76]], 3 +// CHECK: %[[VAL_80:.*]] = icmp ult i32 %[[VAL_74]], 6 +// CHECK: br i1 %[[VAL_80]], label %[[VAL_81:.*]], label %[[VAL_82:.*]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_83:.*]], %[[VAL_84:.*]] // CHECK: ret void -// CHECK: indices.in_bounds-true: ; preds = %[[VAL_9]] -// CHECK: %[[VAL_10:.*]] = getelementptr inbounds i32, ptr %[[VAL_11:.*]], i32 %[[VAL_3]] -// CHECK: %[[VAL_12:.*]] = load i32, ptr %[[VAL_10]], align 4, !invariant.load -// CHECK: %[[VAL_13:.*]] = getelementptr inbounds i32, ptr %[[VAL_11]], i32 %[[VAL_3]] -// CHECK: %[[VAL_14:.*]] = load i32, ptr %[[VAL_13]], align 4, !invariant.load -// CHECK: %[[VAL_15:.*]] = add i32 %[[VAL_12]], %[[VAL_14]] -// CHECK: %[[VAL_16:.*]] = getelementptr inbounds i32, ptr %[[VAL_17:.*]], i32 %[[VAL_3]] -// CHECK: store i32 %[[VAL_15]], ptr %[[VAL_16]], align 4 -// CHECK: br label %[[VAL_8]] -// CHECK: entry: -// CHECK: %[[VAL_18:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_19:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_20:.*]] = mul nuw nsw i32 %[[VAL_18]], 6 -// CHECK: %[[VAL_21:.*]] = add nuw nsw i32 %[[VAL_20]], %[[VAL_19]] -// CHECK: %[[VAL_22:.*]] = icmp ult i32 %[[VAL_21]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_22]]) -// CHECK: %[[VAL_23:.*]] = udiv i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 3 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_21]], 3 -// CHECK: %[[VAL_26:.*]] = icmp ult i32 %[[VAL_21]], 6 -// CHECK: br i1 %[[VAL_26]], label %[[VAL_27:.*]], label %[[VAL_28:.*]] -// CHECK: updates.in_bounds-after: ; preds = %[[VAL_27]], %[[VAL_29:.*]] -// CHECK: ret void -// CHECK: updates.in_bounds-true: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds i32, ptr %[[VAL_31:.*]], i32 %[[VAL_21]] -// CHECK: %[[VAL_32:.*]] = load i32, ptr %[[VAL_30]], align 4, !invariant.load -// CHECK: %[[VAL_33:.*]] = getelementptr inbounds i32, ptr %[[VAL_31]], i32 %[[VAL_21]] -// CHECK: %[[VAL_34:.*]] = load i32, ptr %[[VAL_33]], align 4, !invariant.load -// CHECK: %[[VAL_35:.*]] = add i32 %[[VAL_32]], %[[VAL_34]] -// CHECK: %[[VAL_36:.*]] = getelementptr inbounds i32, ptr %[[VAL_37:.*]], i32 %[[VAL_21]] -// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_36]], align 4 -// CHECK: br label %[[VAL_28]] -// CHECK: entry: -// CHECK: %[[VAL_38:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_39:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_40:.*]] = mul nuw nsw i32 %[[VAL_38]], 9 -// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_40]], %[[VAL_39]] -// CHECK: %[[VAL_42:.*]] = icmp ult i32 %[[VAL_41]], 9 -// CHECK: call void @llvm.assume(i1 %[[VAL_42]]) -// CHECK: %[[VAL_43:.*]] = udiv i32 %[[VAL_41]], 1 -// CHECK: %[[VAL_44:.*]] = urem i32 %[[VAL_43]], 3 -// CHECK: %[[VAL_45:.*]] = udiv i32 %[[VAL_41]], 3 -// CHECK: %[[VAL_46:.*]] = icmp ult i32 %[[VAL_41]], 9 -// CHECK: br i1 %[[VAL_46]], label %[[VAL_47:.*]], label %[[VAL_48:.*]] -// CHECK: operand.in_bounds-after: ; preds = %[[VAL_47]], %[[VAL_49:.*]] -// CHECK: ret void -// CHECK: operand.in_bounds-true: ; preds = %[[VAL_49]] -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds i32, ptr %[[VAL_51:.*]], i32 %[[VAL_41]] -// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_50]], align 4, !invariant.load -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds i32, ptr %[[VAL_51]], i32 %[[VAL_41]] -// CHECK: %[[VAL_54:.*]] = load i32, ptr %[[VAL_53]], align 4, !invariant.load -// CHECK: %[[VAL_55:.*]] = add i32 %[[VAL_52]], %[[VAL_54]] -// CHECK: %[[VAL_56:.*]] = getelementptr inbounds i32, ptr %[[VAL_57:.*]], i32 %[[VAL_41]] -// CHECK: store i32 %[[VAL_55]], ptr %[[VAL_56]], align 4 -// CHECK: br label %[[VAL_48]] -// CHECK: entry: -// CHECK: %[[VAL_58:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_59:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_60:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_61:.*]] = mul nuw nsw i32 %[[VAL_59]], 6 -// CHECK: %[[VAL_62:.*]] = add nuw nsw i32 %[[VAL_61]], %[[VAL_60]] -// CHECK: %[[VAL_63:.*]] = icmp ult i32 %[[VAL_62]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_63]]) -// CHECK: %[[VAL_64:.*]] = udiv i32 %[[VAL_62]], 1 -// CHECK: %[[VAL_65:.*]] = urem i32 %[[VAL_64]], 3 -// CHECK: %[[VAL_66:.*]] = udiv i32 %[[VAL_62]], 3 -// CHECK: %[[VAL_67:.*]] = icmp ult i32 %[[VAL_62]], 6 -// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_69:.*]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_70:.*]], %[[VAL_71:.*]] -// CHECK: ret void -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_71]] -// CHECK: %[[VAL_72:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_73:.*]], i32 0, i32 %[[VAL_66]] -// CHECK: %[[VAL_74:.*]] = load i32, ptr %[[VAL_72]], align 4, !invariant.load -// CHECK: %[[VAL_75:.*]] = add i32 0, %[[VAL_74]] -// CHECK: %[[VAL_76:.*]] = icmp ult i32 %[[VAL_74]], 3 -// CHECK: %[[VAL_77:.*]] = and i1 true, %[[VAL_76]] -// CHECK: br i1 %[[VAL_77]], label %[[VAL_78:.*]], label %[[VAL_70]] -// CHECK: scatter.in_bounds-after3: ; preds = %[[VAL_78]], %[[VAL_68]] -// CHECK: br label %[[VAL_69]] -// CHECK: scatter.in_bounds-true2: ; preds = %[[VAL_68]] -// CHECK: %[[VAL_79:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_80:.*]], i32 0, i32 %[[VAL_75]], i32 %[[VAL_65]] -// CHECK: %[[VAL_81:.*]] = getelementptr inbounds i32, ptr %[[VAL_82:.*]], i32 %[[VAL_62]] -// CHECK: %[[VAL_83:.*]] = load i32, ptr %[[VAL_81]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_83]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_58]], align 4 -// CHECK: %[[VAL_84:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_58]], align 4 -// CHECK: store atomic i32 %[[VAL_84]], ptr %[[VAL_79]] unordered, align 4 -// CHECK: br label %[[VAL_70]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_84]] +// CHECK: %[[VAL_85:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_86:.*]], i32 0, i32 %[[VAL_79]] +// CHECK: %[[VAL_87:.*]] = load i32, ptr %[[VAL_85]], align 4, !invariant.load +// CHECK: %[[VAL_88:.*]] = add i32 0, %[[VAL_87]] +// CHECK: %[[VAL_89:.*]] = icmp ult i32 %[[VAL_87]], 3 +// CHECK: %[[VAL_90:.*]] = and i1 true, %[[VAL_89]] +// CHECK: br i1 %[[VAL_90]], label %[[VAL_91:.*]], label %[[VAL_83]] +// CHECK: scatter.in_bounds-after3: ; preds = %[[VAL_91]], %[[VAL_81]] +// CHECK: br label %[[VAL_82]] +// CHECK: scatter.in_bounds-true2: ; preds = %[[VAL_81]] +// CHECK: %[[VAL_92:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_93:.*]], i32 0, i32 %[[VAL_88]], i32 %[[VAL_78]] +// CHECK: %[[VAL_94:.*]] = getelementptr i32, ptr %[[VAL_95:.*]], i32 %[[VAL_74]] +// CHECK: %[[VAL_96:.*]] = getelementptr inbounds i32, ptr %[[VAL_94]], i32 0 +// CHECK: %[[VAL_97:.*]] = load i32, ptr %[[VAL_96]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_97]], ptr %[[VAL_70]], align 4 +// CHECK-GCN: store i32 %[[VAL_97]], ptr addrspace(5) %[[VAL_70]], align 4 +// CHECK-PTX: %[[VAL_98:.*]] = load i32, ptr %[[VAL_70]], align 4 +// CHECK-GCN: %[[VAL_98:.*]] = load i32, ptr addrspace(5) %[[VAL_70]], align 4 +// CHECK: store atomic i32 %[[VAL_98]], ptr %[[VAL_92]] unordered, align 4 +// CHECK: br label %[[VAL_83]] HloModule TensorFlowScatterV1, is_scheduled=true diff --git a/xla/service/gpu/tests/fused_slice.hlo b/xla/service/gpu/tests/fused_slice.hlo index 2a1ba79a77ba7..4affcb0de7533 100644 --- a/xla/service/gpu/tests/fused_slice.hlo +++ b/xla/service/gpu/tests/fused_slice.hlo @@ -1,80 +1,83 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 // CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 // CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] // CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2048 // CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = udiv i32 %[[VAL_3]], 1 -// CHECK: %[[VAL_6:.*]] = icmp ult i32 %[[VAL_3]], 2047 -// CHECK: br i1 %[[VAL_6]], label %[[VAL_7:.*]], label %[[VAL_8:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_9:.*]], %[[VAL_10:.*]] +// CHECK: %[[VAL_5:.*]] = add nuw nsw i32 %[[VAL_3]], 0 +// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 +// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_3]], 2047 +// CHECK: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]] +// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_10:.*]], %[[VAL_11:.*]] // CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_10]] -// CHECK: br label %[[VAL_11:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_12:.*]] -// CHECK: %[[VAL_13:.*]] = phi i32 [ 0, %[[VAL_12]] ] -// CHECK: %[[VAL_14:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_13]] -// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_16:.*]], i32 0, i32 %[[VAL_14]] -// CHECK: %[[VAL_17:.*]] = load half, ptr %[[VAL_15]], align 2, !invariant.load -// CHECK: %[[VAL_18:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_19:.*]], i32 0, i32 %[[VAL_14]] -// CHECK: %[[VAL_20:.*]] = load half, ptr %[[VAL_18]], align 2, !invariant.load -// CHECK: %[[VAL_21:.*]] = fmul half %[[VAL_17]], %[[VAL_20]] -// CHECK: br label %[[VAL_22:.*]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_23:.*]] -// CHECK: %[[VAL_24:.*]] = phi i32 [ 1024, %[[VAL_23]] ] -// CHECK: %[[VAL_25:.*]] = sub nsw i32 %[[VAL_5]], %[[VAL_24]] -// CHECK: %[[VAL_26:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_27:.*]], i32 0, i32 %[[VAL_25]] -// CHECK: %[[VAL_28:.*]] = load half, ptr %[[VAL_26]], align 2, !invariant.load -// CHECK: %[[VAL_29:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_30:.*]], i32 0, i32 %[[VAL_25]] -// CHECK: %[[VAL_31:.*]] = load half, ptr %[[VAL_29]], align 2, !invariant.load -// CHECK: %[[VAL_32:.*]] = fadd half %[[VAL_28]], %[[VAL_31]] -// CHECK: br label %[[VAL_22]] -// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_7]] -// CHECK: %[[VAL_33:.*]] = icmp ult i32 %[[VAL_5]], 1024 -// CHECK: br i1 %[[VAL_33]], label %[[VAL_12]], label %[[VAL_23]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_11]] -// CHECK: br label %[[VAL_34:.*]] -// CHECK: concatenate.pivot.1024.1: ; preds = %[[VAL_11]] +// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_11]] +// CHECK: br label %[[VAL_12:.*]] +// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_13:.*]] +// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_13]] ] +// CHECK: %[[VAL_15:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_14]] +// CHECK: %[[VAL_16:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_17:.*]], i32 0, i32 %[[VAL_15]] +// CHECK: %[[VAL_18:.*]] = load half, ptr %[[VAL_16]], align 2, !invariant.load +// CHECK: %[[VAL_19:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_20:.*]], i32 0, i32 %[[VAL_15]] +// CHECK: %[[VAL_21:.*]] = load half, ptr %[[VAL_19]], align 2, !invariant.load +// CHECK: %[[VAL_22:.*]] = fmul half %[[VAL_18]], %[[VAL_21]] +// CHECK: br label %[[VAL_23:.*]] +// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_24:.*]] +// CHECK: %[[VAL_25:.*]] = phi i32 [ 1024, %[[VAL_24]] ] +// CHECK: %[[VAL_26:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_25]] +// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_28:.*]], i32 0, i32 %[[VAL_26]] +// CHECK: %[[VAL_29:.*]] = load half, ptr %[[VAL_27]], align 2, !invariant.load +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_26]] +// CHECK: %[[VAL_32:.*]] = load half, ptr %[[VAL_30]], align 2, !invariant.load +// CHECK: %[[VAL_33:.*]] = fadd half %[[VAL_29]], %[[VAL_32]] +// CHECK: br label %[[VAL_23]] +// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_8]] +// CHECK: %[[VAL_34:.*]] = icmp ult i32 %[[VAL_6]], 1024 +// CHECK: br i1 %[[VAL_34]], label %[[VAL_13]], label %[[VAL_24]] +// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_35:.*]] -// CHECK: concat.1.merge: ; preds = %[[VAL_35]], %[[VAL_34]] -// CHECK: %[[VAL_36:.*]] = phi half [ %[[VAL_21]], %[[VAL_34]] ], [ %[[VAL_32]], %[[VAL_35]] ] -// CHECK: %[[VAL_37:.*]] = icmp sge i32 %[[VAL_5]], 0 -// CHECK: %[[VAL_38:.*]] = icmp slt i32 %[[VAL_5]], 1024 -// CHECK: %[[VAL_39:.*]] = and i1 %[[VAL_37]], %[[VAL_38]] -// CHECK: br i1 %[[VAL_39]], label %[[VAL_40:.*]], label %[[VAL_41:.*]] -// CHECK: slice0-after: ; preds = %[[VAL_40]], %[[VAL_22]] -// CHECK: %[[VAL_42:.*]] = icmp sge i32 %[[VAL_5]], 1024 -// CHECK: %[[VAL_43:.*]] = icmp slt i32 %[[VAL_5]], 2047 -// CHECK: %[[VAL_44:.*]] = and i1 %[[VAL_42]], %[[VAL_43]] -// CHECK: br i1 %[[VAL_44]], label %[[VAL_45:.*]], label %[[VAL_46:.*]] -// CHECK: slice1-after: ; preds = %[[VAL_45]], %[[VAL_41]] -// CHECK: %[[VAL_47:.*]] = icmp sge i32 %[[VAL_5]], 2047 -// CHECK: %[[VAL_48:.*]] = icmp slt i32 %[[VAL_5]], 2047 -// CHECK: %[[VAL_49:.*]] = and i1 %[[VAL_47]], %[[VAL_48]] -// CHECK: br i1 %[[VAL_49]], label %[[VAL_50:.*]], label %[[VAL_9]] -// CHECK: slice2-after: ; preds = %[[VAL_50]], %[[VAL_46]] -// CHECK: br label %[[VAL_8]] -// CHECK: slice0-true: ; preds = %[[VAL_22]] -// CHECK: %[[VAL_51:.*]] = sub i32 %[[VAL_5]], 0 -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_53:.*]], i32 0, i32 %[[VAL_51]] -// CHECK: store half %[[VAL_36]], ptr %[[VAL_52]], align 2 -// CHECK: br label %[[VAL_41]] -// CHECK: slice1-true: ; preds = %[[VAL_41]] -// CHECK: %[[VAL_54:.*]] = sub i32 %[[VAL_5]], 1024 -// CHECK: %[[VAL_55:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_56:.*]], i32 0, i32 %[[VAL_54]] -// CHECK: store half %[[VAL_36]], ptr %[[VAL_55]], align 2 -// CHECK: br label %[[VAL_46]] -// CHECK: slice2-true: ; preds = %[[VAL_46]] -// CHECK: %[[VAL_57:.*]] = sub i32 %[[VAL_5]], 2047 -// CHECK: %[[VAL_58:.*]] = getelementptr inbounds [0 x half], ptr %[[VAL_59:.*]], i32 0, i32 %[[VAL_57]] -// CHECK: store half %[[VAL_36]], ptr %[[VAL_58]], align 2 +// CHECK: concatenate.pivot.1024.1: ; preds = %[[VAL_12]] +// CHECK: br label %[[VAL_36:.*]] +// CHECK: concat.1.merge: ; preds = %[[VAL_36]], %[[VAL_35]] +// CHECK: %[[VAL_37:.*]] = phi half [ %[[VAL_22]], %[[VAL_35]] ], [ %[[VAL_33]], %[[VAL_36]] ] +// CHECK: %[[VAL_38:.*]] = icmp sge i32 %[[VAL_6]], 0 +// CHECK: %[[VAL_39:.*]] = icmp slt i32 %[[VAL_6]], 1024 +// CHECK: %[[VAL_40:.*]] = and i1 %[[VAL_38]], %[[VAL_39]] +// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_42:.*]] +// CHECK: slice0-after: ; preds = %[[VAL_41]], %[[VAL_23]] +// CHECK: %[[VAL_43:.*]] = icmp sge i32 %[[VAL_6]], 1024 +// CHECK: %[[VAL_44:.*]] = icmp slt i32 %[[VAL_6]], 2047 +// CHECK: %[[VAL_45:.*]] = and i1 %[[VAL_43]], %[[VAL_44]] +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_47:.*]] +// CHECK: slice1-after: ; preds = %[[VAL_46]], %[[VAL_42]] +// CHECK: %[[VAL_48:.*]] = icmp sge i32 %[[VAL_6]], 2047 +// CHECK: %[[VAL_49:.*]] = icmp slt i32 %[[VAL_6]], 2047 +// CHECK: %[[VAL_50:.*]] = and i1 %[[VAL_48]], %[[VAL_49]] +// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_10]] +// CHECK: slice2-after: ; preds = %[[VAL_51]], %[[VAL_47]] // CHECK: br label %[[VAL_9]] +// CHECK: slice0-true: ; preds = %[[VAL_23]] +// CHECK: %[[VAL_52:.*]] = sub i32 %[[VAL_6]], 0 +// CHECK: %[[VAL_53:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_54:.*]], i32 0, i32 %[[VAL_52]] +// CHECK: store half %[[VAL_37]], ptr %[[VAL_53]], align 2 +// CHECK: br label %[[VAL_42]] +// CHECK: slice1-true: ; preds = %[[VAL_42]] +// CHECK: %[[VAL_55:.*]] = sub i32 %[[VAL_6]], 1024 +// CHECK: %[[VAL_56:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_57:.*]], i32 0, i32 %[[VAL_55]] +// CHECK: store half %[[VAL_37]], ptr %[[VAL_56]], align 2 +// CHECK: br label %[[VAL_47]] +// CHECK: slice2-true: ; preds = %[[VAL_47]] +// CHECK: %[[VAL_58:.*]] = sub i32 %[[VAL_6]], 2047 +// CHECK: %[[VAL_59:.*]] = getelementptr inbounds [0 x half], ptr %[[VAL_60:.*]], i32 0, i32 %[[VAL_58]] +// CHECK: store half %[[VAL_37]], ptr %[[VAL_59]], align 2 +// CHECK: br label %[[VAL_10]] HloModule input_fusion_with_a_tuple_of_slices, is_scheduled=true diff --git a/xla/service/gpu/tests/fused_slice_different_operands.hlo b/xla/service/gpu/tests/fused_slice_different_operands.hlo deleted file mode 100644 index 5d1f11d0ba9cb..0000000000000 --- a/xla/service/gpu/tests/fused_slice_different_operands.hlo +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// The script is designed to make adding checks to -// a test case fast, it is *not* designed to be authoritative -// about what constitutes a good test! The CHECK should be -// minimized and named to reflect the test intent. - - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 1024 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = udiv i32 %[[VAL_3]], 1 -// CHECK: %[[VAL_6:.*]] = icmp ult i32 %[[VAL_3]], 1024 -// CHECK: br i1 %[[VAL_6]], label %[[VAL_7:.*]], label %[[VAL_8:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_9:.*]], %[[VAL_10:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_10]] -// CHECK: %[[VAL_11:.*]] = add i32 %[[VAL_5]], 0 -// CHECK: br label %[[VAL_12:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_13]] ] -// CHECK: %[[VAL_15:.*]] = sub nsw i32 %[[VAL_11]], %[[VAL_14]] -// CHECK: %[[VAL_16:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_17:.*]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_18:.*]] = load half, ptr %[[VAL_16]], align 2, !invariant.load -// CHECK: %[[VAL_19:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_20:.*]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_21:.*]] = load half, ptr %[[VAL_19]], align 2, !invariant.load -// CHECK: %[[VAL_22:.*]] = fmul half %[[VAL_18]], %[[VAL_21]] -// CHECK: br label %[[VAL_9]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_23:.*]] -// CHECK: %[[VAL_24:.*]] = phi i32 [ 1024, %[[VAL_23]] ] -// CHECK: %[[VAL_25:.*]] = sub nsw i32 %[[VAL_11]], %[[VAL_24]] -// CHECK: %[[VAL_26:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_27:.*]], i32 0, i32 %[[VAL_25]] -// CHECK: %[[VAL_28:.*]] = load half, ptr %[[VAL_26]], align 2, !invariant.load -// CHECK: %[[VAL_29:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_30:.*]], i32 0, i32 %[[VAL_25]] -// CHECK: %[[VAL_31:.*]] = load half, ptr %[[VAL_29]], align 2, !invariant.load -// CHECK: %[[VAL_32:.*]] = fadd half %[[VAL_28]], %[[VAL_31]] -// CHECK: br label %[[VAL_9]] -// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_7]] -// CHECK: %[[VAL_33:.*]] = icmp ult i32 %[[VAL_11]], 1024 -// CHECK: br i1 %[[VAL_33]], label %[[VAL_13]], label %[[VAL_23]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_12]] -// CHECK: br label %[[VAL_34:.*]] -// CHECK: concatenate.pivot.1024.1: ; preds = %[[VAL_12]] -// CHECK: br label %[[VAL_35:.*]] -// CHECK: concat.1.merge: ; preds = %[[VAL_35]], %[[VAL_34]] -// CHECK: %[[VAL_36:.*]] = phi half [ %[[VAL_22]], %[[VAL_34]] ], [ %[[VAL_32]], %[[VAL_35]] ] -// CHECK: %[[VAL_37:.*]] = insertvalue { half, half } undef, half %[[VAL_36]], 0 -// CHECK: %[[VAL_38:.*]] = add i32 %[[VAL_5]], 0 -// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_17]], i32 0, i32 %[[VAL_38]] -// CHECK: %[[VAL_40:.*]] = load half, ptr %[[VAL_39]], align 2, !invariant.load -// CHECK: %[[VAL_41:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_20]], i32 0, i32 %[[VAL_38]] -// CHECK: %[[VAL_42:.*]] = load half, ptr %[[VAL_41]], align 2, !invariant.load -// CHECK: %[[VAL_43:.*]] = fmul half %[[VAL_40]], %[[VAL_42]] -// CHECK: %[[VAL_44:.*]] = insertvalue { half, half } %[[VAL_37]], half %[[VAL_43]], 1 -// CHECK: %[[VAL_45:.*]] = extractvalue { half, half } %[[VAL_44]], 0 -// CHECK: %[[VAL_46:.*]] = getelementptr inbounds half, ptr %[[VAL_47:.*]], i32 %[[VAL_3]] -// CHECK: store half %[[VAL_45]], ptr %[[VAL_46]], align 2 -// CHECK: %[[VAL_48:.*]] = extractvalue { half, half } %[[VAL_44]], 1 -// CHECK: %[[VAL_49:.*]] = getelementptr inbounds half, ptr %[[VAL_50:.*]], i32 %[[VAL_3]] -// CHECK: store half %[[VAL_48]], ptr %[[VAL_49]], align 2 -// CHECK: br label %[[VAL_8]] - -HloModule input_fusion_with_a_tuple_of_slices, is_scheduled=true - -fused_computation { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - mul.1 = f16[1024]{0} multiply(arg.1, arg.2) - add.1 = f16[1023]{0} add(arg.3, arg.4) - concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0} - slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]} - slice.2 = f16[1024]{0} slice(mul.1), slice={[0:1024]} - ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}) tuple(slice.1, slice.2) -} - -ENTRY kernel_entry { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - ROOT fusion = (f16[1024]{0}, f16[1024]{0}) - fusion(arg.1, arg.2, arg.3, arg.4), kind=kLoop, calls=fused_computation -} - diff --git a/xla/service/gpu/tests/fusion.hlo b/xla/service/gpu/tests/fusion.hlo deleted file mode 100644 index 00b0acdc22015..0000000000000 --- a/xla/service/gpu/tests/fusion.hlo +++ /dev/null @@ -1,271 +0,0 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s - -HloModule TestModule, is_scheduled=true - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_1:.*]] = call i32 [[TIDX]] -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 25690112 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = mul nuw nsw i32 %[[VAL_3]], 4 -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = urem i32 %[[VAL_6]], 64 -// CHECK: %[[VAL_8:.*]] = udiv i32 %[[VAL_5]], 64 -// CHECK: %[[VAL_9:.*]] = urem i32 %[[VAL_8]], 112 -// CHECK: %[[VAL_10:.*]] = udiv i32 %[[VAL_5]], 7168 -// CHECK: %[[VAL_11:.*]] = urem i32 %[[VAL_10]], 112 -// CHECK: %[[VAL_12:.*]] = udiv i32 %[[VAL_5]], 802816 -// CHECK: %[[VAL_13:.*]] = add nuw nsw i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_14:.*]] = udiv i32 %[[VAL_13]], 1 -// CHECK: %[[VAL_15:.*]] = urem i32 %[[VAL_14]], 64 -// CHECK: %[[VAL_16:.*]] = udiv i32 %[[VAL_13]], 64 -// CHECK: %[[VAL_17:.*]] = urem i32 %[[VAL_16]], 112 -// CHECK: %[[VAL_18:.*]] = udiv i32 %[[VAL_13]], 7168 -// CHECK: %[[VAL_19:.*]] = urem i32 %[[VAL_18]], 112 -// CHECK: %[[VAL_20:.*]] = udiv i32 %[[VAL_13]], 802816 -// CHECK: %[[VAL_21:.*]] = add nuw nsw i32 %[[VAL_5]], 2 -// CHECK: %[[VAL_22:.*]] = udiv i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_23:.*]] = urem i32 %[[VAL_22]], 64 -// CHECK: %[[VAL_24:.*]] = udiv i32 %[[VAL_21]], 64 -// CHECK: %[[VAL_25:.*]] = urem i32 %[[VAL_24]], 112 -// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_21]], 7168 -// CHECK: %[[VAL_27:.*]] = urem i32 %[[VAL_26]], 112 -// CHECK: %[[VAL_28:.*]] = udiv i32 %[[VAL_21]], 802816 -// CHECK: %[[VAL_29:.*]] = add nuw nsw i32 %[[VAL_5]], 3 -// CHECK: %[[VAL_30:.*]] = udiv i32 %[[VAL_29]], 1 -// CHECK: %[[VAL_31:.*]] = urem i32 %[[VAL_30]], 64 -// CHECK: %[[VAL_32:.*]] = udiv i32 %[[VAL_29]], 64 -// CHECK: %[[VAL_33:.*]] = urem i32 %[[VAL_32]], 112 -// CHECK: %[[VAL_34:.*]] = udiv i32 %[[VAL_29]], 7168 -// CHECK: %[[VAL_35:.*]] = urem i32 %[[VAL_34]], 112 -// CHECK: %[[VAL_36:.*]] = udiv i32 %[[VAL_29]], 802816 -// CHECK: %[[VAL_37:.*]] = icmp ult i32 %[[VAL_5]], 102760448 -// CHECK: br i1 %[[VAL_37]], label %[[VAL_38:.*]], label %[[VAL_39:.*]] -// CHECK: fusion.1.in_bounds-after: ; preds = %[[VAL_38]], %[[VAL_40:.*]] -// CHECK: ret void -// CHECK: fusion.1.in_bounds-true: ; preds = %[[VAL_40]] -// CHECK: %[[VAL_41:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_42:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_43:.*]] = load float, ptr %[[VAL_41]], align 4, !invariant.load -// CHECK: %[[VAL_44:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_45:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_46:.*]] = load float, ptr %[[VAL_44]], align 4, !invariant.load -// CHECK: %[[VAL_47:.*]] = fmul float %[[VAL_43]], %[[VAL_46]] -// CHECK: %[[VAL_48:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_49:.*]] = fmul float %[[VAL_47]], %[[VAL_48]] -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds half, ptr %[[VAL_51:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_52:.*]] = load half, ptr %[[VAL_50]], align 2, !invariant.load -// CHECK: %[[VAL_53:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_54:.*]] = fcmp ogt half %[[VAL_52]], %[[VAL_53]] -// CHECK: %[[VAL_55:.*]] = zext i1 %[[VAL_54]] to i8 -// CHECK: %[[VAL_56:.*]] = getelementptr inbounds half, ptr %[[VAL_57:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_58:.*]] = load half, ptr %[[VAL_56]], align 2, !invariant.load -// CHECK: %[[VAL_59:.*]] = trunc i8 %[[VAL_55]] to i1 -// CHECK: %[[VAL_60:.*]] = select i1 %[[VAL_59]], half %[[VAL_58]], half %[[VAL_53]] -// CHECK: %[[VAL_61:.*]] = fpext half %[[VAL_60]] to float -// CHECK: %[[VAL_62:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_63:.*]] = fmul float %[[VAL_61]], %[[VAL_62]] -// CHECK: %[[VAL_64:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_65:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_66:.*]] = load float, ptr %[[VAL_64]], align 4, !invariant.load -// CHECK: %[[VAL_67:.*]] = fsub float %[[VAL_63]], %[[VAL_66]] -// CHECK: %[[VAL_68:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_69:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_70:.*]] = load float, ptr %[[VAL_68]], align 4, !invariant.load -// CHECK: %[[VAL_71:.*]] = getelementptr inbounds half, ptr %[[VAL_72:.*]], i32 %[[VAL_5]] -// CHECK: %[[VAL_73:.*]] = load half, ptr %[[VAL_71]], align 2, !invariant.load -// CHECK: %[[VAL_74:.*]] = fpext half %[[VAL_73]] to float -// CHECK: %[[VAL_75:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_76:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_77:.*]] = load float, ptr %[[VAL_75]], align 4, !invariant.load -// CHECK: %[[VAL_78:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_79:.*]] = fmul float %[[VAL_77]], %[[VAL_78]] -// CHECK: %[[VAL_80:.*]] = fsub float %[[VAL_74]], %[[VAL_79]] -// CHECK: %[[VAL_81:.*]] = fmul float %[[VAL_70]], %[[VAL_80]] -// CHECK: %[[VAL_82:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_83:.*]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_84:.*]] = load float, ptr %[[VAL_82]], align 4, !invariant.load -// CHECK: %[[VAL_85:.*]] = fdiv float %[[VAL_81]], %[[VAL_84]] -// CHECK: %[[VAL_86:.*]] = fsub float %[[VAL_67]], %[[VAL_85]] -// CHECK: %[[VAL_87:.*]] = fmul float %[[VAL_49]], %[[VAL_86]] -// CHECK: %[[VAL_88:.*]] = fptrunc float %[[VAL_87]] to half -// CHECK: %[[VAL_89:.*]] = getelementptr inbounds half, ptr %[[VAL_90:.*]], i32 %[[VAL_5]] -// CHECK: store half %[[VAL_88]], ptr %[[VAL_89]], align 2 -// CHECK: %[[VAL_91:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_42]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_92:.*]] = load float, ptr %[[VAL_91]], align 4, !invariant.load -// CHECK: %[[VAL_93:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_45]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_94:.*]] = load float, ptr %[[VAL_93]], align 4, !invariant.load -// CHECK: %[[VAL_95:.*]] = fmul float %[[VAL_92]], %[[VAL_94]] -// CHECK: %[[VAL_96:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_97:.*]] = fmul float %[[VAL_95]], %[[VAL_96]] -// CHECK: %[[VAL_98:.*]] = getelementptr inbounds half, ptr %[[VAL_51]], i32 %[[VAL_13]] -// CHECK: %[[VAL_99:.*]] = load half, ptr %[[VAL_98]], align 2, !invariant.load -// CHECK: %[[VAL_100:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_101:.*]] = fcmp ogt half %[[VAL_99]], %[[VAL_100]] -// CHECK: %[[VAL_102:.*]] = zext i1 %[[VAL_101]] to i8 -// CHECK: %[[VAL_103:.*]] = getelementptr inbounds half, ptr %[[VAL_57]], i32 %[[VAL_13]] -// CHECK: %[[VAL_104:.*]] = load half, ptr %[[VAL_103]], align 2, !invariant.load -// CHECK: %[[VAL_105:.*]] = trunc i8 %[[VAL_102]] to i1 -// CHECK: %[[VAL_106:.*]] = select i1 %[[VAL_105]], half %[[VAL_104]], half %[[VAL_100]] -// CHECK: %[[VAL_107:.*]] = fpext half %[[VAL_106]] to float -// CHECK: %[[VAL_108:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_109:.*]] = fmul float %[[VAL_107]], %[[VAL_108]] -// CHECK: %[[VAL_110:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_65]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_111:.*]] = load float, ptr %[[VAL_110]], align 4, !invariant.load -// CHECK: %[[VAL_112:.*]] = fsub float %[[VAL_109]], %[[VAL_111]] -// CHECK: %[[VAL_113:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_69]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_114:.*]] = load float, ptr %[[VAL_113]], align 4, !invariant.load -// CHECK: %[[VAL_115:.*]] = getelementptr inbounds half, ptr %[[VAL_72]], i32 %[[VAL_13]] -// CHECK: %[[VAL_116:.*]] = load half, ptr %[[VAL_115]], align 2, !invariant.load -// CHECK: %[[VAL_117:.*]] = fpext half %[[VAL_116]] to float -// CHECK: %[[VAL_118:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_76]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_119:.*]] = load float, ptr %[[VAL_118]], align 4, !invariant.load -// CHECK: %[[VAL_120:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_121:.*]] = fmul float %[[VAL_119]], %[[VAL_120]] -// CHECK: %[[VAL_122:.*]] = fsub float %[[VAL_117]], %[[VAL_121]] -// CHECK: %[[VAL_123:.*]] = fmul float %[[VAL_114]], %[[VAL_122]] -// CHECK: %[[VAL_124:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_83]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_125:.*]] = load float, ptr %[[VAL_124]], align 4, !invariant.load -// CHECK: %[[VAL_126:.*]] = fdiv float %[[VAL_123]], %[[VAL_125]] -// CHECK: %[[VAL_127:.*]] = fsub float %[[VAL_112]], %[[VAL_126]] -// CHECK: %[[VAL_128:.*]] = fmul float %[[VAL_97]], %[[VAL_127]] -// CHECK: %[[VAL_129:.*]] = fptrunc float %[[VAL_128]] to half -// CHECK: %[[VAL_130:.*]] = getelementptr inbounds half, ptr %[[VAL_90]], i32 %[[VAL_13]] -// CHECK: store half %[[VAL_129]], ptr %[[VAL_130]], align 2 -// CHECK: %[[VAL_131:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_42]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_132:.*]] = load float, ptr %[[VAL_131]], align 4, !invariant.load -// CHECK: %[[VAL_133:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_45]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_134:.*]] = load float, ptr %[[VAL_133]], align 4, !invariant.load -// CHECK: %[[VAL_135:.*]] = fmul float %[[VAL_132]], %[[VAL_134]] -// CHECK: %[[VAL_136:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_137:.*]] = fmul float %[[VAL_135]], %[[VAL_136]] -// CHECK: %[[VAL_138:.*]] = getelementptr inbounds half, ptr %[[VAL_51]], i32 %[[VAL_21]] -// CHECK: %[[VAL_139:.*]] = load half, ptr %[[VAL_138]], align 2, !invariant.load -// CHECK: %[[VAL_140:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_141:.*]] = fcmp ogt half %[[VAL_139]], %[[VAL_140]] -// CHECK: %[[VAL_142:.*]] = zext i1 %[[VAL_141]] to i8 -// CHECK: %[[VAL_143:.*]] = getelementptr inbounds half, ptr %[[VAL_57]], i32 %[[VAL_21]] -// CHECK: %[[VAL_144:.*]] = load half, ptr %[[VAL_143]], align 2, !invariant.load -// CHECK: %[[VAL_145:.*]] = trunc i8 %[[VAL_142]] to i1 -// CHECK: %[[VAL_146:.*]] = select i1 %[[VAL_145]], half %[[VAL_144]], half %[[VAL_140]] -// CHECK: %[[VAL_147:.*]] = fpext half %[[VAL_146]] to float -// CHECK: %[[VAL_148:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_149:.*]] = fmul float %[[VAL_147]], %[[VAL_148]] -// CHECK: %[[VAL_150:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_65]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_151:.*]] = load float, ptr %[[VAL_150]], align 4, !invariant.load -// CHECK: %[[VAL_152:.*]] = fsub float %[[VAL_149]], %[[VAL_151]] -// CHECK: %[[VAL_153:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_69]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_154:.*]] = load float, ptr %[[VAL_153]], align 4, !invariant.load -// CHECK: %[[VAL_155:.*]] = getelementptr inbounds half, ptr %[[VAL_72]], i32 %[[VAL_21]] -// CHECK: %[[VAL_156:.*]] = load half, ptr %[[VAL_155]], align 2, !invariant.load -// CHECK: %[[VAL_157:.*]] = fpext half %[[VAL_156]] to float -// CHECK: %[[VAL_158:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_76]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_158]], align 4, !invariant.load -// CHECK: %[[VAL_160:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_161:.*]] = fmul float %[[VAL_159]], %[[VAL_160]] -// CHECK: %[[VAL_162:.*]] = fsub float %[[VAL_157]], %[[VAL_161]] -// CHECK: %[[VAL_163:.*]] = fmul float %[[VAL_154]], %[[VAL_162]] -// CHECK: %[[VAL_164:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_83]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_164]], align 4, !invariant.load -// CHECK: %[[VAL_166:.*]] = fdiv float %[[VAL_163]], %[[VAL_165]] -// CHECK: %[[VAL_167:.*]] = fsub float %[[VAL_152]], %[[VAL_166]] -// CHECK: %[[VAL_168:.*]] = fmul float %[[VAL_137]], %[[VAL_167]] -// CHECK: %[[VAL_169:.*]] = fptrunc float %[[VAL_168]] to half -// CHECK: %[[VAL_170:.*]] = getelementptr inbounds half, ptr %[[VAL_90]], i32 %[[VAL_21]] -// CHECK: store half %[[VAL_169]], ptr %[[VAL_170]], align 2 -// CHECK: %[[VAL_171:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_42]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_172:.*]] = load float, ptr %[[VAL_171]], align 4, !invariant.load -// CHECK: %[[VAL_173:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_45]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_174:.*]] = load float, ptr %[[VAL_173]], align 4, !invariant.load -// CHECK: %[[VAL_175:.*]] = fmul float %[[VAL_172]], %[[VAL_174]] -// CHECK: %[[VAL_176:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_177:.*]] = fmul float %[[VAL_175]], %[[VAL_176]] -// CHECK: %[[VAL_178:.*]] = getelementptr inbounds half, ptr %[[VAL_51]], i32 %[[VAL_29]] -// CHECK: %[[VAL_179:.*]] = load half, ptr %[[VAL_178]], align 2, !invariant.load -// CHECK: %[[VAL_180:.*]] = load half, ptr @2, align 2 -// CHECK: %[[VAL_181:.*]] = fcmp ogt half %[[VAL_179]], %[[VAL_180]] -// CHECK: %[[VAL_182:.*]] = zext i1 %[[VAL_181]] to i8 -// CHECK: %[[VAL_183:.*]] = getelementptr inbounds half, ptr %[[VAL_57]], i32 %[[VAL_29]] -// CHECK: %[[VAL_184:.*]] = load half, ptr %[[VAL_183]], align 2, !invariant.load -// CHECK: %[[VAL_185:.*]] = trunc i8 %[[VAL_182]] to i1 -// CHECK: %[[VAL_186:.*]] = select i1 %[[VAL_185]], half %[[VAL_184]], half %[[VAL_180]] -// CHECK: %[[VAL_187:.*]] = fpext half %[[VAL_186]] to float -// CHECK: %[[VAL_188:.*]] = load float, ptr @1, align 4 -// CHECK: %[[VAL_189:.*]] = fmul float %[[VAL_187]], %[[VAL_188]] -// CHECK: %[[VAL_190:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_65]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_191:.*]] = load float, ptr %[[VAL_190]], align 4, !invariant.load -// CHECK: %[[VAL_192:.*]] = fsub float %[[VAL_189]], %[[VAL_191]] -// CHECK: %[[VAL_193:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_69]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_194:.*]] = load float, ptr %[[VAL_193]], align 4, !invariant.load -// CHECK: %[[VAL_195:.*]] = getelementptr inbounds half, ptr %[[VAL_72]], i32 %[[VAL_29]] -// CHECK: %[[VAL_196:.*]] = load half, ptr %[[VAL_195]], align 2, !invariant.load -// CHECK: %[[VAL_197:.*]] = fpext half %[[VAL_196]] to float -// CHECK: %[[VAL_198:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_76]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_199:.*]] = load float, ptr %[[VAL_198]], align 4, !invariant.load -// CHECK: %[[VAL_200:.*]] = load float, ptr @0, align 4 -// CHECK: %[[VAL_201:.*]] = fmul float %[[VAL_199]], %[[VAL_200]] -// CHECK: %[[VAL_202:.*]] = fsub float %[[VAL_197]], %[[VAL_201]] -// CHECK: %[[VAL_203:.*]] = fmul float %[[VAL_194]], %[[VAL_202]] -// CHECK: %[[VAL_204:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_83]], i32 0, i32 %[[VAL_31]] -// CHECK: %[[VAL_205:.*]] = load float, ptr %[[VAL_204]], align 4, !invariant.load -// CHECK: %[[VAL_206:.*]] = fdiv float %[[VAL_203]], %[[VAL_205]] -// CHECK: %[[VAL_207:.*]] = fsub float %[[VAL_192]], %[[VAL_206]] -// CHECK: %[[VAL_208:.*]] = fmul float %[[VAL_177]], %[[VAL_207]] -// CHECK: %[[VAL_209:.*]] = fptrunc float %[[VAL_208]] to half -// CHECK: %[[VAL_210:.*]] = getelementptr inbounds half, ptr %[[VAL_90]], i32 %[[VAL_29]] -// CHECK: store half %[[VAL_209]], ptr %[[VAL_210]], align 2 -// CHECK: br label %[[VAL_39]] - - -%fused_computation.1 (param_0.5: f32[64], param_1.3088: f32[64], param_2.2116: f32[64], param_3.974: f32[64], param_4.1162: f32[64], param_5.893: f32[64], param_6.809: f16[128,64,112,112], param_7.770: f16[128,64,112,112], param_8.637: f16[128,64,112,112]) -> f16[128,64,112,112] { - %param_4.1162 = f32[64]{0} parameter(4) - %broadcast.2313 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_4.1162), dimensions={1} - %param_3.974 = f32[64]{0} parameter(3) - %broadcast.1844 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_3.974), dimensions={1} - %multiply.1049 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.2313, f32[128,64,112,112]{1,3,2,0} %broadcast.1844) - %constant_1404 = f32[] constant(6.22807704e-07) - %broadcast.1843 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1404), dimensions={} - %multiply.1048 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1049, f32[128,64,112,112]{1,3,2,0} %broadcast.1843) - %param_8.637 = f16[128,64,112,112]{1,3,2,0} parameter(8) - %constant_3626 = f16[] constant(0) - %broadcast.4770 = f16[128,64,112,112]{1,3,2,0} broadcast(f16[] %constant_3626), dimensions={} - %compare.259 = pred[128,64,112,112]{1,3,2,0} compare(f16[128,64,112,112]{1,3,2,0} %param_8.637, f16[128,64,112,112]{1,3,2,0} %broadcast.4770), direction=GT - %param_7.770 = f16[128,64,112,112]{1,3,2,0} parameter(7) - %select.254 = f16[128,64,112,112]{1,3,2,0} select(pred[128,64,112,112]{1,3,2,0} %compare.259, f16[128,64,112,112]{1,3,2,0} %param_7.770, f16[128,64,112,112]{1,3,2,0} %broadcast.4770) - %convert.108 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %select.254) - %constant_1390 = f32[] constant(1605632) - %broadcast.1841 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1390), dimensions={} - %multiply.1046 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %convert.108, f32[128,64,112,112]{1,3,2,0} %broadcast.1841) - %param_2.2116 = f32[64]{0} parameter(2) - %broadcast.1840 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_2.2116), dimensions={1} - %subtract.266 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %multiply.1046, f32[128,64,112,112]{1,3,2,0} %broadcast.1840) - %param_1.3088 = f32[64]{0} parameter(1) - %broadcast.1839 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_1.3088), dimensions={1} - %param_6.809 = f16[128,64,112,112]{1,3,2,0} parameter(6) - %convert.644 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %param_6.809) - %param_5.893 = f32[64]{0} parameter(5) - %broadcast.3388 = f32[64]{0} broadcast(f32[] %constant_1404), dimensions={} - %multiply.2336 = f32[64]{0} multiply(f32[64]{0} %param_5.893, f32[64]{0} %broadcast.3388) - %broadcast.3387 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %multiply.2336), dimensions={1} - %subtract.591 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %convert.644, f32[128,64,112,112]{1,3,2,0} %broadcast.3387) - %multiply.1045 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.1839, f32[128,64,112,112]{1,3,2,0} %subtract.591) - %param_0.5 = f32[64]{0} parameter(0) - %broadcast.1838 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_0.5), dimensions={1} - %divide.212 = f32[128,64,112,112]{1,3,2,0} divide(f32[128,64,112,112]{1,3,2,0} %multiply.1045, f32[128,64,112,112]{1,3,2,0} %broadcast.1838) - %subtract.265 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %subtract.266, f32[128,64,112,112]{1,3,2,0} %divide.212) - %multiply.1044 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1048, f32[128,64,112,112]{1,3,2,0} %subtract.265) - ROOT %convert.107 = f16[128,64,112,112]{1,3,2,0} convert(f32[128,64,112,112]{1,3,2,0} %multiply.1044) -} - -ENTRY main { - %get-tuple-element.1532 = f32[64]{0} parameter(0) - %get-tuple-element.876 = f32[64]{0} parameter(1) - %get-tuple-element.877 = f32[64]{0} parameter(2) - %get-tuple-element.1530 = f32[64]{0} parameter(3) - %arg112.113 = f32[64]{0} parameter(4) - %get-tuple-element.881 = f32[64]{0} parameter(5) - %get-tuple-element.872 = f16[128,64,112,112]{1,3,2,0} parameter(6) - %select-and-scatter.3626 = f16[128,64,112,112]{1,3,2,0} parameter(7) - %fusion.845 = f16[128,64,112,112]{1,3,2,0} parameter(8) - - ROOT %fusion.1 = f16[128,64,112,112]{1,3,2,0} fusion(f32[64]{0} %get-tuple-element.1532, f32[64]{0} %get-tuple-element.876, f32[64]{0} %get-tuple-element.877, f32[64]{0} %get-tuple-element.1530, f32[64]{0} %arg112.113, f32[64]{0} %get-tuple-element.881, f16[128,64,112,112]{1,3,2,0} %get-tuple-element.872, f16[128,64,112,112]{1,3,2,0} %select-and-scatter.3626, f16[128,64,112,112]{1,3,2,0} %fusion.845), kind=kLoop, calls=%fused_computation.1 -} diff --git a/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc index c314853f9aba7..7d6c68bbc3d24 100644 --- a/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -26,11 +30,11 @@ namespace { class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest { protected: - se::CudaComputeCapability GetCudaComputeCapability() { + const auto& GpuComputeComp() { return backend() .default_stream_executor() ->GetDeviceDescription() - .cuda_compute_capability(); + .gpu_compute_capability(); } DebugOptions GetDebugOptionsForTest() override { @@ -38,6 +42,7 @@ class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest { // These tests test the cuBLAS rewriter so we have to make sure that we use // cuBLAS for them. debug_options.set_xla_gpu_enable_triton_gemm(false); + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); return debug_options; } }; @@ -58,7 +63,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,2], y: f32[2,2]) -> f32[3,2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,2], {{.*}}: f32[2,2]) -> f32[3,2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -97,7 +102,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[3,2,2]) -> f32[3,2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[3,2,2]) -> f32[3,2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -137,7 +142,7 @@ ENTRY AddDotsFunc { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); // Use GemmRewriter to generate cublasGemm call. - GemmRewriter gemm_rewriter(GetCudaComputeCapability()); + GemmRewriter gemm_rewriter(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&gemm_rewriter, module.get())); EXPECT_TRUE(changed); @@ -163,7 +168,7 @@ ENTRY AddDotsFunc { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); // Use GemmRewriter to generate cublasGemm call. - GemmRewriter gemm_rewriter(GetCudaComputeCapability()); + GemmRewriter gemm_rewriter(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&gemm_rewriter, module.get())); EXPECT_TRUE(changed); @@ -187,7 +192,7 @@ ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); // Use GemmRewriter to generate cublasGemm call. - GemmRewriter gemm_rewriter(GetCudaComputeCapability()); + GemmRewriter gemm_rewriter(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&gemm_rewriter, module.get())); EXPECT_TRUE(changed); @@ -210,7 +215,7 @@ ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter gemm_rewriter(GetCudaComputeCapability()); + GemmRewriter gemm_rewriter(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&gemm_rewriter, module.get())); EXPECT_TRUE(changed); diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index 11f917376ebab..69e956f62deac 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,29 +13,45 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include #include #include #include +#include #include -#include +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/executable.h" #include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/filecheck.h" #include "xla/xla.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" #endif namespace xla { @@ -46,12 +62,69 @@ namespace { namespace m = ::xla::match; class GemmRewriteTest : public GpuCodegenTest { + const auto& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } + public: - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); + const se::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + se::GpuComputeCapability CudaHopperOrRocmMI300() { + return std::visit( + VariantVisitor{[](const se::CudaComputeCapability&) { + return se::GpuComputeCapability{ + se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}}; + }, + [](const se::RocmComputeCapability&) { + return se::GpuComputeCapability{ + se::RocmComputeCapability{"gfx942"}}; + }}, + GpuComputeComp()); + } + + enum class Switch : uint32_t { + False, // check always fails + True, // check always succeeds + }; + // Switch based on GPU platform only: true/false for both + bool CudaOrRocmCheck(Switch cuda_set, Switch rocm_set) { + return CudaOrRocmCheck( + [cuda_set](const se::CudaComputeCapability&) { + return cuda_set == Switch::True; + }, + [rocm_set](const se::RocmComputeCapability&) { + return rocm_set == Switch::True; + }); + } + // Major version check for CUDA and true/false for ROCM + bool CudaOrRocmCheck(int cuda_major, Switch rocm_set) { + return CudaOrRocmCheck(cuda_major, 0, rocm_set); + } + // Full version check for CUDA and true/false for ROCM + bool CudaOrRocmCheck(int cuda_major, int cuda_minor, Switch rocm_set) { + return CudaOrRocmCheck(cuda_major, cuda_minor, + [rocm_set](const se::RocmComputeCapability&) { + return rocm_set == Switch::True; + }); + } + // Full version check for CUDA and generic version for ROCM + bool CudaOrRocmCheck( + int cuda_major, int cuda_minor, + absl::AnyInvocable rocm_fun) { + return CudaOrRocmCheck( + [cuda_major, cuda_minor](const se::CudaComputeCapability& cc) { + return cc.IsAtLeast(cuda_major, cuda_minor); + }, + std::move(rocm_fun)); + } + // The most generic version for both platforms + bool CudaOrRocmCheck( + absl::AnyInvocable cuda_fun, + absl::AnyInvocable rocm_fun) { + return std::visit(VariantVisitor{std::move(cuda_fun), std::move(rocm_fun)}, + GpuComputeComp()); } DebugOptions GetDebugOptionsForTest() override { @@ -59,11 +132,28 @@ class GemmRewriteTest : public GpuCodegenTest { // These tests test the cuBLAS rewriter so we have to make sure that we use // cuBLAS for them. debug_options.set_xla_gpu_enable_triton_gemm(false); + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); return debug_options; } + + bool SkipGpuBlasLtTest() { + return CudaOrRocmCheck( + [](const se::CudaComputeCapability&) { // never skip gpublas-lt tests + // for CUDA + return false; + }, + [this](const se::RocmComputeCapability& rocm) { + bool blaslt = GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); + return (blaslt && !rocm.has_hipblaslt()); + }); + } }; TEST_F(GemmRewriteTest, CheckCustomCallTarget) { + if (SkipGpuBlasLtTest()) { + GTEST_SKIP() << "BlasLt is not supported on this GPU architecture"; + } + const char* hlo_text = R"( HloModule SimpleGemm @@ -74,7 +164,6 @@ ENTRY AddDotsFunc { } )"; - DebugOptions debug_options = GetDebugOptionsForTest(); if (debug_options.xla_gpu_enable_cublaslt()) { MatchOptimizedHlo(hlo_text, @@ -85,9 +174,9 @@ ENTRY AddDotsFunc { } } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TEST_F(GemmRewriteTest, TestBatchedAutotuning) { - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::False)) { GTEST_SKIP() << "There is no autotuning starting with the Nvidia Ampere generation"; } @@ -110,6 +199,10 @@ ENTRY %test { #endif TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) { + if (SkipGpuBlasLtTest()) { + GTEST_SKIP() << "BlasLt is not supported on this GPU architecture"; + } + const char* hlo_text = R"( HloModule SimpleGemm @@ -143,8 +236,9 @@ ENTRY AddDotsFunc { *get_module(), backend().default_stream_executor(), backend().default_stream_executor()->GetAllocator())); - StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(), - R"( + absl::StatusOr filecheck_result = + RunFileCheck(optimized_module->ToString(), + R"( ; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}" )"); TF_ASSERT_OK(filecheck_result.status()); @@ -163,18 +257,32 @@ ENTRY bf16gemm { } )"; - MatchOptimizedHlo(hlo_text, R"( -; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1) -; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]]) -; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0) -; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]]) -; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]]) -; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0) -; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]] -; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]]) - )"); + if (CudaOrRocmCheck(9, 0, Switch::False)) { + // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has + // native BF16 multiply support. + MatchOptimizedHlo(hlo_text, R"( + ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0) + ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1) + ; CHECK: [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]]) + ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]]) + ; CHECK: [[INSTR_4:%[^ ]+]] = f32[] constant(0) + ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]] + ; CHECK: ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]]) + )"); + } else { + MatchOptimizedHlo(hlo_text, R"( + ; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1) + ; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]]) + ; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0) + ; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]]) + ; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]]) + ; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0) + ; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]] + ; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]]) + )"); + } - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4})); } TEST_F(GemmRewriteTest, BF16Transpose) { @@ -195,7 +303,7 @@ ENTRY broadcast { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // A test fixture class for tests which should have similar results with legacy // cublas and cublasLt class ParameterizedGemmRewriteTest @@ -222,6 +330,13 @@ class ParameterizedGemmRewriteTest return replacements_[kCustomCallTargetPlaceholder]; } + protected: + void SetUp() override { + if (SkipGpuBlasLtTest()) { + GTEST_SKIP() << "BlasLt is not supported on this GPU architecture"; + } + } + protected: absl::flat_hash_map replacements_; @@ -245,7 +360,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -283,7 +398,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -323,7 +438,7 @@ ENTRY AddDotsFunc { R"( ; CHECK-NOT: copy ; -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,4,2], y: f32[3,4,5]) -> f32[2,5] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,4,2], {{.*}}: f32[3,4,5]) -> f32[2,5] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1) ; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]]) @@ -364,7 +479,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -403,7 +518,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,3,2], {{.*}}: f32[5,3,4]) -> f32[5,2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -442,7 +557,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,5,3], {{.*}}: f32[5,3,4]) -> f32[5,2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -481,7 +596,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[3,2,5], {{.*}}: f32[5,3,4]) -> f32[5,2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) ; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]]) @@ -522,7 +637,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[20000,4,3,2], y: f32[20000,4,3,4]) -> f32[20000,4,2,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[20000,4,3,2], {{.*}}: f32[20000,4,3,4]) -> f32[20000,4,2,4] { ; CHECK: [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0) ; CHECK: [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]]) ; CHECK: [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1) @@ -564,7 +679,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[4,2] { ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]), @@ -603,7 +718,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,5,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -643,7 +758,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[5,2,3], {{.*}}: f32[5,3,4]) -> f32[2,4,5] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -685,7 +800,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -709,6 +824,13 @@ ENTRY AddDotsFunc { } TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) { + if (CudaOrRocmCheck( + [](se::CudaComputeCapability) { return false; }, + [this](se::RocmComputeCapability rocm) { + return GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); + })) { + GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM"; + } const char* hlo_text = R"( HloModule ComplexAlphaSimpleRewrite @@ -726,7 +848,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: c64[2,2], {{.*}}: c64[2,2]) -> c64[2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -805,7 +927,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]), @@ -840,7 +962,7 @@ ENTRY bf16gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<>" @@ -868,7 +990,7 @@ ENTRY bf16gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<>" @@ -901,7 +1023,7 @@ ENTRY int8gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" @@ -918,6 +1040,10 @@ ENTRY int8gemm { } TEST_F(GemmRewriteTest, Int8GemmRankGreaterThanTwo) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; + } + const char* hlo_text = R"( HloModule int8gemm @@ -931,10 +1057,10 @@ ENTRY main.4 { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( -; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %fusion.1, s8[4,4]{0,1} %bitcast.13), custom_call_target="__cublas$gemm", +; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm", )", /*print_operand_shape=*/true); } @@ -955,7 +1081,7 @@ ENTRY int8gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), @@ -989,7 +1115,7 @@ ENTRY int8gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), @@ -1011,6 +1137,10 @@ ENTRY int8gemm { } TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; + } + const char* hlo_text = R"( HloModule int8gemm @@ -1022,7 +1152,7 @@ ENTRY int8gemm { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { MatchOptimizedHlo(hlo_text, R"( ; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" @@ -1039,6 +1169,10 @@ ENTRY int8gemm { } TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "DoBlasGemmWithAlgorithm is not yet implemented on ROCm"; + } + std::vector> type_combinations = {{"s8", "s8", true}, {"s32", "s32", true}, @@ -1054,7 +1188,7 @@ TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) { {"f16", "f32", true}, {"bf16", "f32", true}}; - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { // For compute capabilities before volta, we always do upcasting, so it // would be impossible for this test to fail. That is why we only add these // cases when the compute capability is at least Volta. @@ -1118,7 +1252,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1142,7 +1276,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1166,7 +1300,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1193,7 +1327,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1217,7 +1351,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1245,7 +1379,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1278,7 +1412,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1313,7 +1447,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1345,6 +1479,67 @@ class LegacyCublasGemmRewriteTest : public GemmRewriteTest { } }; +TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplication) { + const char* hlo_text = R"( +HloModule m + +ENTRY e { + p0 = f32[2048] parameter(0) + p1 = f32[2048, 16384] parameter(1) + ROOT d = f32[16384] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +})"; + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}), + R"( +; CHECK: %[[P0:.+]] = f32[2048]{0} parameter(0) +; CHECK: %[[P1:.+]] = f32[2048,16384]{1,0} parameter(1) +; CHECK: %[[CUSTOM_CALL:.+]] = (f32[16384]{0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, MatrixVectorMultiplicationWithBatch) { + const char* hlo_text = R"( +HloModule m + +ENTRY e { + p0 = f32[10, 10, 2048] parameter(0) + p1 = f32[10, 10, 2048, 16384] parameter(1) + ROOT d = f32[10, 10, 16384] dot(p0, p1), + lhs_batch_dims={0, 1}, rhs_batch_dims={0, 1}, + lhs_contracting_dims={2}, rhs_contracting_dims={2} +})"; + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}), + R"( +; CHECK: %[[P0:.+]] = f32[10,10,2048]{2,1,0} parameter(0) +; CHECK: %[[P1:.+]] = f32[10,10,2048,16384]{3,2,1,0} parameter(1) +; CHECK: %[[CUSTOM_CALL:.+]] = (f32[10,10,16384]{2,1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]), custom_call_target="__cublas$gemm" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, SparseDotNotSupported) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + lhs = f16[5,16] parameter(0) + rhs = f16[32,10] parameter(1) + meta = u16[5,2] parameter(2) + ROOT dot = f32[5,10] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})"; + auto hlo_pass = GemmRewriter( + se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get())); + EXPECT_FALSE(changed); +} + // Test that the alpha and beta fields of the GemmBackendConfig are updated. // A bias must be present for the beta value to be set. // In order to have a bias add fused, the bias term must be overwritable. @@ -1373,7 +1568,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK: [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}), @@ -1420,7 +1615,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]), @@ -1459,7 +1654,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]), @@ -1499,7 +1694,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: (f32[2,2], f32[3,3])) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: (f32[2,2], f32[3,3])) -> f32[2,2] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2) @@ -1548,7 +1743,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) @@ -1593,7 +1788,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: f32[1024,1024]) -> f32[1024,1024] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]), @@ -1638,7 +1833,7 @@ ENTRY BF16GemmWithBias { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], param_2: bf16[8,8]) -> bf16[8,8] { +; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] { ; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) ; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) ; CHECK: [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}), @@ -1687,7 +1882,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], param_2: f32[2,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK: [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}), @@ -1732,7 +1927,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) @@ -1778,7 +1973,7 @@ ENTRY test { )"); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Test gemm matrix bias add fusion with mix type TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixType) { std::vector> @@ -1855,7 +2050,7 @@ ENTRY test { } #endif -// Test batch gemm matrix bias add fusion with mix type that is not supported +// Test batch gemm matrix bias add fusion with mix type that is not supported. TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeNotSupported) { const char* hlo_text = R"( HloModule test @@ -1875,17 +2070,15 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo_text)); - EXPECT_THAT( - optimized_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion( - m::Parameter(2), - m::GetTupleElement(m::CustomCall({"__cublas$gemm"}, m::Parameter(0), - m::Parameter(1)), - 0)))); + MatchOptimizedHlo(hlo_text, R"( +; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm +; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]] +; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]] +)"); } // Test batch gemm matrix bias add fusion with mix type that is not supported -// cuz there are consumers of bias add +// because there are consumers of bias add. TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasMixTypeAddWithMoreConsumers) { const char* hlo_text = R"( HloModule test @@ -1906,13 +2099,11 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo_text)); - EXPECT_THAT( - optimized_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion( - m::Parameter(2), - m::GetTupleElement(m::CustomCall({"__cublas$gemm"}, m::Parameter(0), - m::Parameter(1)), - 0)))); + MatchOptimizedHlo(hlo_text, R"( +; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$gemm +; CHECK: %[[gte:.*]] = {{.*}} get-tuple-element{{.*}}%[[custom_call]] +; CHECK: ROOT {{.*}} fusion({{.*}}%[[gte]] +)"); } TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) { @@ -1929,7 +2120,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -1979,7 +2170,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); SCOPED_TRACE(module->ToString()); EXPECT_TRUE(changed); @@ -2001,7 +2192,7 @@ ENTRY test { 0)))); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // A test fixture class for tests which are specific to cublasLt class CublasLtGemmRewriteTest : public GemmRewriteTest { public: @@ -2011,6 +2202,13 @@ class CublasLtGemmRewriteTest : public GemmRewriteTest { debug_options.set_xla_gpu_enable_triton_gemm(false); return debug_options; } + + protected: + void SetUp() override { + if (SkipGpuBlasLtTest()) { + GTEST_SKIP() << "BlasLt is not supported on this GPU architecture"; + } + } }; TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) { @@ -2033,7 +2231,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) @@ -2077,7 +2275,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[2,2], {{.*}}: f32[2,2], {{.*}}: f32[2,2]) -> f32[2,2] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) @@ -2120,7 +2318,7 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: f32[1024,1024]) -> f32[1024,1024] { +; CHECK-LABEL: ENTRY %AddDotsFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) ; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2) @@ -2161,7 +2359,7 @@ ENTRY BF16GemmWithBias { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] { +; CHECK-LABEL: ENTRY %BF16GemmWithBias ({{.*}}: bf16[8,8], {{.*}}: bf16[8,8], {{.*}}: bf16[8,8]) -> bf16[8,8] { ; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) ; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) ; CHECK-DAG: [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) @@ -2202,7 +2400,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) @@ -2245,7 +2443,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) @@ -2306,7 +2504,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -2361,7 +2559,7 @@ ENTRY test { ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]]) } -; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4]) -> f32[4,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4]) -> f32[4,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) ; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]]), @@ -2383,9 +2581,7 @@ ENTRY test { ; CHECK: } ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) ; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) -; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[4,4]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} custom-call([[MATMUL0]], [[C0_BCAST]]), +; CHECK: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} custom-call([[MATMUL0]] ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -2441,15 +2637,8 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) -; CHECK-NEXT: [[BROADCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P2]]), dimensions={1,2,3} -; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[6,30]{1,0} bitcast([[BROADCAST]]) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[BITCAST]]), +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] { +; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call( ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, ; CHECK: backend_config={ @@ -2490,15 +2679,8 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) -; CHECK-NEXT: [[BROADCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P2]]), dimensions={3} -; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[6,30]{1,0} bitcast([[BROADCAST]]) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[BITCAST]]), +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[6]) -> f32[2,3,5,6] { +; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call( ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, ; CHECK: backend_config={ @@ -2539,7 +2721,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) @@ -2584,7 +2766,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[4,3], y: f32[3,4], z: f32[3]) -> f32[2,3] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,3], {{.*}}: f32[3,4], {{.*}}: f32[3]) -> f32[2,3] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3]{0} parameter(2) @@ -2633,16 +2815,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[2], [[DUMMY1:[^ ]+]]: f32[2,4]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2]{0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,4]{1,0} parameter(1) -; CHECK-DAG: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[P1]]), slice={[0:2], [0:2]} -; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[P0]]), dimensions={1} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} add([[SLICE]], [[P0_BCAST]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[2,2] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) @@ -2663,11 +2836,7 @@ ENTRY test { ; CHECK-DAG: } ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,2]{1,0} fusion([[P2]], [[MATMUL0]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL0]]), slice={[0:2], [0:2]} -; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) -; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} custom-call([[SLICE]], [[C0_BCAST]]), +; CHECK: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} custom-call( ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -2684,7 +2853,7 @@ ENTRY test { ; CHECK-DAG: } ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[FUSION]], [[MATMUL1]]), +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call{{.*}}[[MATMUL1]] ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -2765,7 +2934,7 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4], z2: f32[2,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-DAG: [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2) @@ -2809,7 +2978,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: bf16[16,24], y: bf16[24,32], z: bf16[32]) -> bf16[16,32] { +; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[16,24], {{.*}}: bf16[24,32], {{.*}}: bf16[32]) -> bf16[16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32]{0} parameter(2) @@ -2833,8 +3002,7 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on " "architectures with bf16 Tensor Cores."; } @@ -2848,38 +3016,13 @@ ENTRY test { dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} z_bcast = bf16[2,4] broadcast(z), dimensions={1} ROOT out = bf16[2,4] add(dot_a, z_bcast) -} - -)"; +})"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[2,3], y: bf16[3,4], z: bf16[4]) -> bf16[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P0]], [[C0]]), padding=0_6x0_5 -; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P1]], [[C0]]), padding=0_5x0_4 -; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[4]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[2,4]{1,0} slice([[MATMUL]]), slice={[0:2], [0:4]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4], {{.*}}: bf16[4]) -> bf16[2,4] { +; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5 +; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4 )"); } @@ -2902,7 +3045,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), @@ -2944,7 +3087,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6]) -> f32[2,3,5,6] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6]) -> f32[2,3,5,6] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) @@ -2990,7 +3133,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,2] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), @@ -3035,7 +3178,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) @@ -3080,7 +3223,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4,4]) -> f32[4,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[4,4], {{.*}}: f32[4,4], {{.*}}: f32[4,4]) -> f32[4,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2) @@ -3126,7 +3269,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -3172,15 +3315,8 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) -; CHECK-NEXT: [[BROADCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P2]]), dimensions={1,2,3} -; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[6,30]{1,0} bitcast([[BROADCAST]]) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[BITCAST]]), +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3,4], {{.*}}: f32[4,5,6], {{.*}}: f32[3,5,6]) -> f32[2,3,5,6] { +; CHECK: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call( ; CHECK: custom_call_target="__cublas$lt$matmul", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -3224,7 +3360,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[2]) -> f32[4,2] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) @@ -3273,7 +3409,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z_vec: f32[4], z_matrix: f32[2,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4], {{.*}}: f32[2,4]) -> f32[2,4] { ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -3331,7 +3467,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), @@ -3392,6 +3528,14 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) { +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000 + auto rocm_switch = Switch::False; // GELU is only available from ROCM 6.0 +#else + auto rocm_switch = Switch::True; +#endif + if (CudaOrRocmCheck(Switch::False, rocm_switch)) { + GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; + } const char* hlo_text = R"( HloModule test @@ -3427,7 +3571,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -3452,6 +3596,9 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; + } const char* hlo_text = R"( HloModule test @@ -3485,7 +3632,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> (f32[2,4], f32[2,4]) { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4]) -> (f32[2,4], f32[2,4]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}) custom-call([[P0]], [[P1]]), @@ -3509,6 +3656,9 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; + } const char* hlo_text = R"( HloModule test @@ -3545,7 +3695,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> (f32[2,4], f32[2,4]) { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> (f32[2,4], f32[2,4]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -3570,8 +3720,7 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationBF16) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on " "architectures with bf16 Tensor Cores."; } @@ -3599,38 +3748,13 @@ ENTRY test { bcast.3 = bf16[2,4] broadcast(const.3), dimensions={} mul.4 = bf16[2,4] multiply(add.2, bcast.3) ROOT out = bf16[2,4] multiply(dot, mul.4) -} - -)"; +})"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{5e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[2,3], y: bf16[3,4]) -> bf16[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = bf16[8,8]{1,0} pad([[P0]], [[C0]]), padding=0_6x0_5 -; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = bf16[8,8]{1,0} pad([[P1]], [[C0]]), padding=0_5x0_4 -; CHECK-NEXT: [[DOT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PAD]], [[P1_PAD]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"GELU" -; CHECK: } -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[2,4]{1,0} slice([[DOT]]), slice={[0:2], [0:4]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: bf16[2,3], {{.*}}: bf16[3,4]) -> bf16[2,4] { +; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_6x0_5 +; CHECK-DAG: bf16[8,8]{1,0} pad({{.*}}), padding=0_5x0_4 )"); } @@ -3666,7 +3790,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -3698,7 +3822,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) @@ -3740,7 +3864,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -3766,39 +3890,18 @@ ENTRY test { dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} z_bcast = f16[8,8] broadcast(z), dimensions={1} ROOT add = f16[8,8] add(dot_a, z_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS" -; CHECK: } + MatchOptimizedHlo(hlo_text, R"( +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" )"); } TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands only implemented on " "architectures with Tensor Cores."; } @@ -3812,38 +3915,15 @@ ENTRY test { dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} z_bcast = f16[6,6] broadcast(z), dimensions={1} ROOT add = f16[6,6] add(dot_a, z_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS" -; CHECK: } -; CHECK-NEXT: [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} +; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] { +; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } @@ -3860,38 +3940,18 @@ ENTRY test { c = f16[] constant(0) c_bcast = f16[8,8] broadcast(c), dimensions={} ROOT out = f16[8,8] maximum(dot_a, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"RELU" -; CHECK: } + MatchOptimizedHlo(hlo_text, R"( +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" )"); } TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands only implemented on " "architectures with Tensor Cores."; } @@ -3905,37 +3965,13 @@ ENTRY test { c = f16[] constant(0) c_bcast = f16[6,6] broadcast(c), dimensions={} ROOT out = f16[6,6] maximum(dot_a, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"RELU" -; CHECK: } -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6]) -> f16[6,6] { +; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } @@ -3960,7 +3996,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f16[8,16], {{.*}}: f16[16,8], {{.*}}: f16[8,8]) -> f16[8,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) @@ -4000,39 +4036,18 @@ ENTRY test { c = f16[] constant(0) c_bcast = f16[8,8] broadcast(c), dimensions={} ROOT out = f16[8,8] maximum(add, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS_RELU" -; CHECK: } - )"); -} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" +)"); +} TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::VOLTA, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands only implemented on " "architectures with Tensor Cores."; } @@ -4049,37 +4064,13 @@ ENTRY test { c = f16[] constant(0) c_bcast = f16[6,6] broadcast(c), dimensions={} ROOT out = f16[6,6] maximum(add, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS_RELU" -; CHECK: } + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: f16[6,12], {{.*}}: f16[12,6], {{.*}}: f16[6]) -> f16[6,6] { +; CHECK-DAG: f16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: f16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } @@ -4103,7 +4094,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8,8]) -> bf16[8,8] { +; CHECK-LABEL: ENTRY %test ({{.*}}: bf16[8,16], {{.*}}: bf16[16,8], {{.*}}: bf16[8,8]) -> bf16[8,8] { ; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) ; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) ; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) @@ -4144,7 +4135,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -4172,40 +4163,18 @@ ENTRY test { dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} z_bcast = bf16[8,8] broadcast(z), dimensions={1} ROOT add = bf16[8,8] add(dot_a, z_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS" -; CHECK: } + MatchOptimizedHlo(hlo_text, R"( +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" )"); } TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " "Ampere and newer architectures."; } @@ -4219,38 +4188,13 @@ ENTRY test { dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} z_bcast = bf16[6,6] broadcast(z), dimensions={1} ROOT add = bf16[6,6] add(dot_a, z_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS" -; CHECK: } -; CHECK-NEXT: [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] { +; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } @@ -4271,35 +4215,15 @@ ENTRY test { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"RELU" -; CHECK: } + MatchOptimizedHlo(hlo_text, R"( +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" )"); } TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " "Ampere and newer architectures."; } @@ -4313,37 +4237,13 @@ ENTRY test { c = bf16[] constant(0) c_bcast = bf16[6,6] broadcast(c), dimensions={} ROOT out = bf16[6,6] maximum(dot_a, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"RELU" -; CHECK: } -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6]) -> bf16[6,6] { +; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } @@ -4363,40 +4263,19 @@ ENTRY test { c = bf16[] constant(0) c_bcast = bf16[8,8] broadcast(c), dimensions={} ROOT out = bf16[8,8] maximum(add, c_bcast) -} +})"; -)"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); MatchOptimizedHlo(hlo_text, R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS_RELU" -; CHECK: } +; CHECK-NOT: pad(" +; CHECK: custom-call +; CHECK-SAME: custom_call_target="__cublas$lt$matmul" )"); } TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (!CudaOrRocmCheck(se::CudaComputeCapability::AMPERE, Switch::True)) { GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " "Ampere and newer architectures."; } @@ -4417,38 +4296,17 @@ ENTRY test { )"; EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config={ -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] -; CHECK-DAG: } -; CHECK-DAG: "precision_config":{ -; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] -; CHECK-DAG: } -; CHECK-DAG: "epilogue":"BIAS_RELU" -; CHECK: } -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + MatchOptimizedHlo(hlo_text, R"( +; CHECK-DAG: ENTRY %test ({{.*}}: bf16[6,12], {{.*}}: bf16[12,6], {{.*}}: bf16[6]) -> bf16[6,6] { +; CHECK-DAG: bf16[8,16]{1,0} pad({{.*}}), padding=0_2x0_4 +; CHECK-DAG: bf16[16,8]{1,0} pad({{.*}}), padding=0_4x0_2 )"); } TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM"; + } const char* hlo_text = R"( HloModule test @@ -4470,7 +4328,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f64[2,3], y: f64[3,4], z: f64[4]) -> f64[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f64[2,3], {{.*}}: f64[3,4], {{.*}}: f64[4]) -> f64[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f64[4]{0} parameter(2) @@ -4519,7 +4377,7 @@ ENTRY test { MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,3], {{.*}}: f32[3,4], {{.*}}: f32[4]) -> f32[2,4] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) @@ -4571,7 +4429,7 @@ ENTRY test { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); + GemmRewriter pass(GpuComputeComp()); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); SCOPED_TRACE(module->ToString()); EXPECT_TRUE(changed); @@ -4626,6 +4484,9 @@ ENTRY main { // Test gemm matrix bias add fusion with mix type and out of place update(C != // D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; + } std::vector> type_combinations = { {"f16", "f32"}, @@ -4660,6 +4521,9 @@ ENTRY test { // Test batch gemm matrix bias add fusion with mix type and out of place // update(C != D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; + } std::vector> type_combinations = { {"f16", "f32"}, @@ -4693,6 +4557,9 @@ ENTRY test { // Test gemm matrix bias add fusion with mix type and in place update(C = D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) { + if (CudaOrRocmCheck(Switch::False, Switch::True)) { + GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; + } std::vector> type_combinations = { {"f16", "f32"}, @@ -4745,92 +4612,224 @@ ENTRY test { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo_text)); - EXPECT_THAT( - optimized_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(2), - m::CustomCall({"__cublas$lt$matmul"}, - m::Parameter(0), m::Parameter(1))))); + MatchOptimizedHlo(hlo_text, R"( +; CHECK: %[[custom_call:.*]] = {{.*}} custom-call{{.*}}__cublas$lt$matmul +; CHECK: ROOT {{.*}} fusion({{.*}}%[[custom_call]] +)"); } class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { + public: + ParameterizedFp8GemmRewriteTest() { + replacements_[kF8E4M3DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e4m3fn"; +#else + "f8e4m3fnuz"; +#endif + replacements_[kF8E5M2DatatypePlaceholder] = +#if GOOGLE_CUDA + "f8e5m2"; +#else + "f8e5m2fnuz"; +#endif + replacements_[kF8E4M3AmaxPlaceholder] = +#if GOOGLE_CUDA + "448."; +#else + "240."; +#endif + } + protected: // Check the HLO runs and has an FP8 cuBLAS LT custom call on supported // architectures (Ada, Hopper, and later). void CheckFp8IfSupported(absl::string_view hlo_text, ErrorSpec error_spec = ErrorSpec{1e-2, 1e-2}) { - if (!GetCudaComputeCapability().IsAtLeast(8, 9)) { + if (!CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { return; } - EXPECT_TRUE(RunAndCompare(hlo_text, error_spec)); + std::string replaced_hlo_text = + absl::StrReplaceAll(hlo_text, replacements_); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + error_spec)); // Most FP8 tests directly create a GemmRewriter and check the output. // Here, also run the entire HLO pass pipeline to ensure no other passes // interfere with GemmRewriter's pattern matching. TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(hlo_text)); + GetOptimizedModule(replaced_hlo_text)); const HloInstruction* call = FindInstruction(optimized_module.get(), HloOpcode::kCustomCall); ASSERT_NE(call, nullptr); EXPECT_EQ(call->custom_call_target(), "__cublas$lt$matmul$f8"); } + + void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern, + bool print_operand_shape = false) { + GemmRewriteTest::MatchOptimizedHlo( + absl::StrReplaceAll(hlo, replacements_), + absl::StrReplaceAll(pattern, replacements_), print_operand_shape); + } + + void RunAndFilecheckHloRewrite( + absl::string_view hlo, HloPassInterface&& hlo_pass, + std::optional expected, + std::function after_pass_checks = nullptr, + const HloModuleConfig* config = nullptr) { + if (expected.has_value()) { + std::string replaced_pattern = + absl::StrReplaceAll(expected.value(), replacements_); + GemmRewriteTest::RunAndFilecheckHloRewrite( + absl::StrReplaceAll(hlo, replacements_), std::move(hlo_pass), + replaced_pattern, after_pass_checks, config); + } + } + + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count = 1, + int64_t num_partitions = 1) { + return GemmRewriteTest::ParseAndReturnVerifiedModule( + absl::StrReplaceAll(hlo_text, replacements_)); + } + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E4M3AmaxPlaceholder{"<>"}; }; TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) { - if (GetCudaComputeCapability().IsAtLeast(8, 9)) { - GTEST_SKIP() << "Test requires a pre-Ada GPU."; + if (CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { + GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; } const char* hlo_text = R"( HloModule test ENTRY PreAdaTest { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) - ROOT out = f8e4m3fn[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); MatchOptimizedHlo(hlo_text, R"( -; CHECK-LABEL: ENTRY %PreAdaTest (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] { +; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { +; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}}) +; CHECK-DAG: custom_call_target="<>" + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) { + if (CudaOrRocmCheck(8, 9, [](const se::RocmComputeCapability& cc) { + return cc.has_fp8_support(); + })) { + GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; + } + const char* hlo_text = R"( + HloModule test + + ENTRY PreAdaTest { + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + ROOT out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + +)"; + + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %PreAdaTest ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> f32[16,16] { ; CHECK: {{.*}} = {{.*}} custom-call({{.*}}, {{.*}}) ; CHECK-DAG: custom_call_target="<>" )"); } TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // Test with types unsupported by cuBLAS LT when FP8 is used. cuBLAS LT with // FP8 requires one of the operands to be F8E4M3FN. const char* hlo_text = R"( HloModule test ENTRY unsupported_types { - x = f8e5m2[16,16] parameter(0) - y = f8e5m2[16,16] parameter(1) - ROOT out = f8e5m2[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,16] parameter(0) + y = <>[16,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); - RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(GetCudaComputeCapability()), - absl::StrReplaceAll(R"( -; CHECK-LABEL: ENTRY %unsupported_types (x: f8e5m2[16,16], y: f8e5m2[16,16]) -> f8e5m2[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e5m2[16,16]{1,0} parameter(0) + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(hlo_text, replacements_), + ErrorSpec{1e-2, 1e-2})); + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(GpuComputeComp(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <>[16,16], {{.*}}: <>[16,16]) -> <>[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,16]{1,0} parameter(0) ; CHECK-NEXT: [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e5m2[16,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]]) -; CHECK-NEXT: [[DOT:%[^ ]+]] = {{.*}} custom-call([[P0_CONVERT]], [[P1_CONVERT]]), -; CHECK: custom_call_target="<>", +; CHECK-NEXT: [[DOT:%[^ ]+]] = f16[16,16]{1,0} dot([[P0_CONVERT]], [[P1_CONVERT]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} convert([[DOT]]) + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* hlo_text = R"( + HloModule test + + ENTRY test { + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + ROOT out = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 ; CHECK-DAG: "beta":0 ; CHECK-DAG: "dot_dimension_numbers":{ ; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] ; CHECK-DAG: "lhs_batch_dimensions":[] ; CHECK-DAG: "rhs_batch_dimensions":[] ; CHECK-DAG: } @@ -4839,37 +4838,43 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) { ; CHECK-DAG: } ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } -; CHECK: ROOT [[OUT:%[^ ]+]] = f8e5m2[16,16]{1,0} convert - )", - replacements_)); + )"); } -TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { -#if CUDA_VERSION < 12000 +// Do not fuse FP8 matrix bias. +TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) - ROOT out = f8e4m3fn[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + dot_a = <>[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + b = <>[16,16] parameter(2) + ROOT out = <>[16,16] add(dot_a, b) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: <>[16,16]) -> <>[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[DOT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[DOT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -4886,19 +4891,26 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { ; CHECK-DAG: } ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } +; CHECK-NEXT: [[P2:%[^ ]+]] = <>[16,16]{1,0} parameter(2) +; CHECK-NEXT: [[ROOT:%[^ ]+]] = <>[16,16]{1,0} add([[DOT]], [[P2]]) )"); } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -4913,14 +4925,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -4945,15 +4956,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[13,17] parameter(0) - y = f8e4m3fn[17,31] parameter(1) + x = <>[13,17] parameter(0) + y = <>[17,31] parameter(1) x_f32 = f32[13,17] convert(x) y_f32 = f32[17,31] convert(y) x_scale = f32[] parameter(2) @@ -4968,18 +4984,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[13,17], y: f8e4m3fn[17,31], x_scale: f32[], y_scale: f32[]) -> f32[13,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[13,17]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[17,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,17]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[13,17], {{.*}}: <>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[13,17]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_3x0_15 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[17,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,17]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1) @@ -5005,15 +5020,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[2,8,16] parameter(0) - y = f8e4m3fn[16,16] parameter(1) + x = <>[2,8,16] parameter(0) + y = <>[16,16] parameter(1) x_f32 = f32[2,8,16] convert(x) y_f32 = f32[16,16] convert(y) x_scale = f32[] parameter(2) @@ -5030,8 +5050,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5042,15 +5061,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[3] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[3] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[3] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5068,22 +5092,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { } )"; + CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[3], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[3]{0} parameter(0) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3]{0} parameter(0) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) -; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = f8e4m3fn[] convert([[C0]]) -; CHECK-NEXT: [[P0_U0:%[^ ]+]] = f8e4m3fn[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27 -; CHECK-NEXT: [[P0_U1:%[^ ]+]] = f8e4m3fn[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0} -; CHECK-NEXT: [[P0_U2:%[^ ]+]] = f8e4m3fn[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]} -; CHECK-NEXT: [[P0_U3:%[^ ]+]] = f8e4m3fn[16,32]{1,0} reshape([[P0_U2]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <>[] convert([[C0]]) +; CHECK-NEXT: [[P0_U0:%[^ ]+]] = <>[30]{0} pad([[P0]], [[C0_CONVERT]]), padding=0_27 +; CHECK-NEXT: [[P0_U1:%[^ ]+]] = <>[30,8,5]{2,1,0} broadcast([[P0_U0]]), dimensions={0} +; CHECK-NEXT: [[P0_U2:%[^ ]+]] = <>[16,8,4]{2,1,0} slice([[P0_U1]]), slice={[2:18], [0:8], [0:4]} +; CHECK-NEXT: [[P0_U3:%[^ ]+]] = <>[16,32]{1,0} reshape([[P0_U2]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) @@ -5108,16 +5132,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[32,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[32,32] parameter(0) + y = <>[16,32] parameter(1) zero = s32[] constant(0) x_f32 = f32[32,32] convert(x) y_f32 = f32[16,32] convert(y) @@ -5131,23 +5159,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[32,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[32,32]{1,0} parameter(0) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[32,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[32,32]{1,0} parameter(0) ; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0) -; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32} -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[DYN_SLICE:%[^ ]+]] = <>[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32} +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5172,16 +5199,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -5197,26 +5228,25 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[], k: pred[16,32]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) ; CHECK-NEXT: [[P4:%[^ ]+]] = pred[16,32]{1,0} parameter(4) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) ; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[16,32]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = f8e4m3fn[16,32]{1,0} convert([[C0_BCAST]]) -; CHECK-NEXT: [[SELECT:%[^ ]+]] = f8e4m3fn[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) +; CHECK-NEXT: [[C0_CONVERT:%[^ ]+]] = <>[16,32]{1,0} convert([[C0_BCAST]]) +; CHECK-NEXT: [[SELECT:%[^ ]+]] = <>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5242,16 +5272,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectNonzeroConstantF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -5267,32 +5301,29 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ROOT dot_a = f32[16,16] dot(x_unscaled, select_a), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[], k: pred[16,32]) -> f32[16,16] { -; CHECK-NOT: custom_call_target="__cublas$lt$matmul$f8" - )"); + EXPECT_FALSE(changed); } TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[10,16,32] parameter(0) - y = f8e4m3fn[10,32,16] parameter(1) + x = <>[10,16,32] parameter(0) + y = <>[10,32,16] parameter(1) x_f32 = f32[10,16,32] convert(x) y_f32 = f32[10,32,16] convert(y) x_scale = f32[] parameter(2) @@ -5307,14 +5338,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[10,16,32], y: f8e4m3fn[10,32,16], x_scale: f32[], y_scale: f32[]) -> f32[10,16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[10,32,16]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[10,16,32], {{.*}}: <>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[10,16,32]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[10,32,16]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5339,15 +5369,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5365,15 +5400,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) @@ -5398,15 +5432,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5424,19 +5463,190 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"RELU" +; CHECK: } + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + x_bf16 = bf16[16,32] convert(x) + y_bf16 = bf16[32,16] convert(y) + x_scale = bf16[] parameter(2) + y_scale = bf16[] parameter(3) + bias = bf16[16] parameter(4) + x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={} + y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={} + x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast) + y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast) + dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + b_bcast = bf16[16,16] broadcast(bias), dimensions={1} + dot = bf16[16,16] add(dot1, b_bcast) + mul.0 = bf16[16,16] multiply(dot, dot) + mul.1 = bf16[16,16] multiply(dot, mul.0) + const.0 = bf16[] constant(0.044715) + bcast.0 = bf16[16,16] broadcast(const.0), dimensions={} + mul.2 = bf16[16,16] multiply(mul.1, bcast.0) + add.0 = bf16[16,16] add(dot, mul.2) + const.1 = bf16[] constant(0.797884583) + bcast.1 = bf16[16,16] broadcast(const.1), dimensions={} + mul.3 = bf16[16,16] multiply(add.0, bcast.1) + tanh = bf16[16,16] tanh(mul.3) + const.2 = bf16[] constant(1) + bcast.2 = bf16[16,16] broadcast(const.2), dimensions={} + add.2 = bf16[16,16] add(tanh, bcast.2) + const.3 = bf16[] constant(0.5) + bcast.3 = bf16[16,16] broadcast(const.3), dimensions={} + mul.4 = bf16[16,16] multiply(add.2, bcast.3) + ROOT out = bf16[16,16] multiply(dot, mul.4) + } +)"; + + CheckFp8IfSupported(hlo_text); + +// Fusing gelu into FP8 cublas matmuls is disabled on CUDA versions less +// than 12.4. +#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) +; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) +; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-PTX-DAG: "epilogue":"BIAS_GELU" +; CHECK-GCN-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM +} + +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDApproxGeluActivationF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + x_bf16 = bf16[16,32] convert(x) + y_bf16 = bf16[32,16] convert(y) + x_scale = bf16[] parameter(2) + y_scale = bf16[] parameter(3) + x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={} + y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={} + x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast) + y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast) + dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul.0 = bf16[16,16] multiply(dot, dot) + mul.1 = bf16[16,16] multiply(dot, mul.0) + const.0 = bf16[] constant(0.044715) + bcast.0 = bf16[16,16] broadcast(const.0), dimensions={} + mul.2 = bf16[16,16] multiply(mul.1, bcast.0) + add.0 = bf16[16,16] add(dot, mul.2) + const.1 = bf16[] constant(0.797884583) + bcast.1 = bf16[16,16] broadcast(const.1), dimensions={} + mul.3 = bf16[16,16] multiply(add.0, bcast.1) + tanh = bf16[16,16] tanh(mul.3) + const.2 = bf16[] constant(1) + bcast.2 = bf16[16,16] broadcast(const.2), dimensions={} + add.2 = bf16[16,16] add(tanh, bcast.2) + const.3 = bf16[] constant(0.5) + bcast.3 = bf16[16,16] broadcast(const.3), dimensions={} + mul.4 = bf16[16,16] multiply(add.2, bcast.3) + ROOT out = bf16[16,16] multiply(dot, mul.4) + } +)"; + + CheckFp8IfSupported(hlo_text); + +// Fusing gelu into FP8 cublas matmuls is disabled on CUDA versions less +// than 12.4. +#if (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM + // Currently, hipBlasLt does not support output datatype bf16 for fp8 matmul. + // And no fusion was done for such cases. + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) +; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) +; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5451,21 +5661,28 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] ; CHECK-DAG: } -; CHECK-DAG: "epilogue":"RELU" +; CHECK-PTX-DAG: "epilogue":"GELU" +; CHECK-GCN-DAG: "epilogue":"DEFAULT" ; CHECK: } )"); +#endif // (GOOGLE_CUDA && CUDA_VERSION >= 12040) || TENSORFLOW_USE_ROCM } TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5480,24 +5697,28 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) b = f32[16,16] parameter(2) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) @@ -5514,15 +5735,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f32[16,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) @@ -5548,15 +5768,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[14,31] parameter(0) - y = f8e4m3fn[31,14] parameter(1) + x = <>[14,31] parameter(0) + y = <>[31,14] parameter(1) b = f32[14,14] parameter(2) x_f32 = f32[14,31] convert(x) y_f32 = f32[31,14] convert(y) @@ -5573,19 +5798,18 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[14,31], y: f8e4m3fn[31,14], b: f32[14,14], x_scale: f32[], y_scale: f32[]) -> f32[14,14] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[14,31]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[31,14]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[14,31]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1 + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[14,31], {{.*}}: <>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[14,31]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[31,14]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[14,31]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2) ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(0) ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2 @@ -5613,16 +5837,83 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* hlo_text = R"( + HloModule test + + ENTRY test { + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) + z_scale = f32[] parameter(2) + z_scale_bcast = f32[16,16] broadcast(z_scale), dimensions={} + dot_a = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[16,16] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) + c2_bcast = f32[16,16] broadcast(c2), dimensions={} + dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) + } + +)"; + + CheckFp8IfSupported(hlo_text, ErrorSpec{1e-2, 1e-1}); + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-PTX-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C2]], [[P2]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[P2_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5635,32 +5926,32 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast) dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> f8e4m3fn[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5681,15 +5972,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5702,21 +5998,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { y_unscaled = f32[32,16] multiply(y_f32, y_scale_bcast) dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_scaled = f32[16,16] multiply(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( ; CHECK-NOT: divide @@ -5726,14 +6021,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -5749,31 +6049,31 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { dot_a = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} relu_a = f32[16,16] maximum(dot_a, c_bcast) relu_a_scaled = f32[16,16] divide(relu_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} relu_a_clamped = f32[16,16] clamp(c1_bcast, relu_a_scaled, c2_bcast) - ROOT out = f8e4m3fn[16,16] convert(relu_a_clamped) + ROOT out = <>[16,16] convert(relu_a_clamped) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> f8e4m3fn[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5794,15 +6094,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16,16] parameter(2) @@ -5817,32 +6122,32 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasF8) { dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_bias = f16[16,16] add(dot_a, b) dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) + c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16,16], x_scale: f16[], y_scale: f16[], z_scale: f16[]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C0:%[^ ]+]] = f16[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) +; CHECK-PTX: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5863,15 +6168,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16] parameter(2) @@ -5887,37 +6197,37 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { dot_a = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} dot_a_bias = f16[16,16] add(dot_a, b_bcast) dot_a_scaled = f16[16,16] divide(dot_a_bias, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) + c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - ROOT dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) + ROOT dot_a_f8 = <>[16,16] convert(dot_a_clamped) } )"; CheckFp8IfSupported(hlo_text, ErrorSpec{0.1, 0.1}); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( - -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16], x_scale: f16[], y_scale: f16[], z_scale: f16[]) -> f8e4m3fn[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) +; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK: ROOT [[OUT:%[^ ]+]] = f8e4m3fn[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), +; CHECK-PTX: ROOT [[OUT:%[^ ]+]] = <>[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5938,15 +6248,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) b = f32[16] parameter(2) @@ -5966,14 +6281,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f32[16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) @@ -6001,15 +6315,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDVectorBiasThenReluActivationF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) b = f16[16] parameter(2) b_bcast = f16[16,16] broadcast(b), dimensions={1} x_f32 = f16[16,32] convert(x) @@ -6029,14 +6348,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"; CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16], x_scale: f16[], y_scale: f16[]) -> f16[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6064,14 +6382,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; -#endif +#endif // CUDA_VERSION < 12000 + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,16,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[4,16,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[32] parameter(2) b_f16 = f16[32] convert(b) b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2} @@ -6092,8 +6415,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6102,15 +6424,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { .WithShape(F16, {64, 32})) .WithShape(F16, {4, 16, 32}))); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,16,16], y: f8e4m3fn[16,32], b: f32[32], x_scale: f16[], y_scale: f16[]) -> f16[4,16,32] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6141,14 +6462,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasPaddedF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,15,15] parameter(0) - y = f8e4m3fn[15,31] parameter(1) + x = <>[4,15,15] parameter(0) + y = <>[15,31] parameter(1) b = f32[31] parameter(2) b_f16 = f16[31] convert(b) b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2} @@ -6169,8 +6495,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6181,19 +6506,18 @@ TEST_P(ParameterizedFp8GemmRewriteTest, .WithShape(F16, {60, 31})) .WithShape(F16, {4, 15, 31}))); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,15,15], y: f8e4m3fn[15,31], b: f32[31], x_scale: f16[], y_scale: f16[]) -> f16[4,15,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,15,15]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[60,15]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = f8e4m3fn[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,15,15], {{.*}}: <>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[60,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = <>[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = <>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) @@ -6226,14 +6550,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,16,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[4,16,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[4,16,32] parameter(2) x_f32 = f32[4,16,16] convert(x) y_f32 = f32[16,32] convert(y) @@ -6252,8 +6581,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6262,15 +6590,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { .WithShape(F32, {64, 32})) .WithShape(F32, {4, 16, 32}))); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,16,16], y: f8e4m3fn[16,32], b: f32[4,16,32], x_scale: f32[], y_scale: f32[]) -> f32[4,16,32] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2) ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) @@ -6299,14 +6626,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasPaddedF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[3,15,15] parameter(0) - y = f8e4m3fn[15,31] parameter(1) + x = <>[3,15,15] parameter(0) + y = <>[15,31] parameter(1) b = f32[3,15,31] parameter(2) x_f32 = f32[3,15,15] convert(x) y_f32 = f32[15,31] convert(y) @@ -6325,8 +6657,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6337,19 +6668,18 @@ TEST_P(ParameterizedFp8GemmRewriteTest, .WithShape(F32, {45, 31})) .WithShape(F32, {3, 15, 31}))); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[3,15,15], y: f8e4m3fn[15,31], b: f32[3,15,31], x_scale: f32[], y_scale: f32[]) -> f32[3,15,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[3,15,15]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[45,15]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1 -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3,15,15], {{.*}}: <>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = <>[45,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = <>[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = <>[] constant(0) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = <>[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 ; CHECK-NEXT: [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2) ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(0) @@ -6383,14 +6713,19 @@ TEST_P(ParameterizedFp8GemmRewriteTest, // of dimensions. TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasWithSliceF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[48,16] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[48,16] parameter(0) + y = <>[16,32] parameter(1) b = f32[32,16] parameter(2) x_f32 = f32[48,16] convert(x) y_f32 = f32[16,32] convert(y) @@ -6408,19 +6743,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[48,16], y: f8e4m3fn[16,32], b: f32[32,16], x_scale: f32[], y_scale: f32[]) -> f32[32,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[48,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[48,16], {{.*}}: <>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[48,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) @@ -6448,16 +6781,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { -#if CUDA_VERSION < 12000 - GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6476,16 +6813,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { config.set_use_spmd_partitioning(true); config.set_num_partitions(8); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[]) -> f32[16,32] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AG:%[^ ]+]] = f8e4m3fn[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK: [[AG1:%[^ ]+]] = f8e4m3fn[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}} -; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AG:%[^ ]+]] = <>[16,64]{1,0} all-gather([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) +; CHECK: [[AG1:%[^ ]+]] = <>[64,32]{1,0} all-gather([[P1]]), {{[^ ]+}} +; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6507,20 +6843,24 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } )", - nullptr, &config); + nullptr, &config); } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { -#if CUDA_VERSION < 12000 - GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; #endif +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6538,14 +6878,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { config.set_use_spmd_partitioning(true); config.set_num_partitions(8); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AA:%[^ ]+]] = f8e4m3fn[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AA:%[^ ]+]] = <>[16,32]{1,0} all-to-all([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6567,21 +6906,25 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } )", - nullptr, &config); + nullptr, &config); } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithCollectivePermuteF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + absl::string_view hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[16,32] parameter(1) + x = <>[16,32] parameter(0) + y = <>[16,32] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[16,32] convert(y) x_scale = f32[] parameter(2) @@ -6599,14 +6942,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, config.set_use_spmd_partitioning(true); config.set_num_partitions(8); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { -; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK: [[AA:%[^ ]+]] = f8e4m3fn[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}} -; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK: [[AA:%[^ ]+]] = <>[16,32]{1,0} collective-permute([[P0]]), {{[^ ]+}} +; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK: [[C:%[^ ]+]] = f32[] constant(1) @@ -6628,20 +6970,25 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } )", - nullptr, &config); + nullptr, &config); } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasThenVectorBiasF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) b = f16[16] parameter(2) @@ -6659,15 +7006,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } )"; + CheckFp8IfSupported(hlo_text, ErrorSpec{2e-3, 0.}); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16], b2: f16[16,16], x_scale: f16[], y_scale: f16[]) -> f16[16,16] { -; CHECK-DAG: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { +; CHECK-DAG: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[MB:%[^ ]+]] = f16[16,16]{1,0} parameter(3) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]]) @@ -6698,9 +7045,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test @@ -6711,8 +7063,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -6728,33 +7080,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { c0 = f32[] constant(-inf) amax = f32[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f32[16,16] divide(dot_a, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f32[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f32[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> (f8e4m3fn[16,16], f32[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6776,9 +7128,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8WithF16Intermediates) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // This is the same as ScaledABScaledDWithDAmaxF8, but uses F16 intermediate // values instead of F32 intermediate values. const char* hlo_text = R"( @@ -6791,8 +7148,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f16 = f16[16,32] convert(x) y_f16 = f16[32,16] convert(y) x_scale = f16[] parameter(2) @@ -6808,36 +7165,36 @@ TEST_P(ParameterizedFp8GemmRewriteTest, c0 = f16[] constant(-inf) amax = f16[] reduce(abs_dot_a, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f16[16,16] divide(dot_a, z_scale_bcast) - c1 = f16[] constant(-448.) + c1 = f16[] constant(-<>) c1_bcast = f16[16,16] broadcast(c1), dimensions={} - c2 = f16[] constant(448.) + c2 = f16[] constant(<>) c2_bcast = f16[16,16] broadcast(c2), dimensions={} dot_a_clamped = f16[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f16[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f16[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f16[], y_scale: f16[], z_scale: f16[]) -> (f8e4m3fn[16,16], f16[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(2) ; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6859,9 +7216,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationWithDAmaxF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test @@ -6872,8 +7234,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -6891,33 +7253,33 @@ TEST_P(ParameterizedFp8GemmRewriteTest, c0 = f32[] constant(-inf) amax = f32[] reduce(dot_a_relu, c0), dimensions={0,1}, to_apply=apply dot_a_scaled = f32[16,16] divide(dot_a_relu, z_scale_bcast) - c1 = f32[] constant(-448.) + c1 = f32[] constant(-<>) c1_bcast = f32[16,16] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) + c2 = f32[] constant(<>) c2_bcast = f32[16,16] broadcast(c2), dimensions={} dot_a_clamped = f32[16,16] clamp(c1_bcast, dot_a_scaled, c2_bcast) - dot_a_f8 = f8e4m3fn[16,16] convert(dot_a_clamped) - ROOT out = (f8e4m3fn[16,16], f32[]) tuple(dot_a_f8, amax) + dot_a_f8 = <>[16,16] convert(dot_a_clamped) + ROOT out = (<>[16,16], f32[]) tuple(dot_a_f8, amax) } )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: f32[], y_scale: f32[], z_scale: f32[]) -> (f8e4m3fn[16,16], f32[]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]) + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { +; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) +; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[]) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6938,15 +7300,20 @@ TEST_P(ParameterizedFp8GemmRewriteTest, } TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; #endif // CUDA_VERSION < 12000 - const char* hlo_template = R"( + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + const char* raw_hlo_template = R"( HloModule test ENTRY test { - x = f8e4m3fn[1600,3200] parameter(0) - y = f8e4m3fn[3200,1600] parameter(1) + x = <>[1600,3200] parameter(0) + y = <>[3200,1600] parameter(1) x_f32 = f32[1600,3200] convert(x) y_f32 = f32[3200,1600] convert(y) x_scale = f32[] parameter(2) @@ -6959,6 +7326,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) { } )"; + std::string hlo_template = + absl::StrReplaceAll(raw_hlo_template, replacements_); + absl::flat_hash_map replacements; replacements["<>"] = "default"; const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements); @@ -6970,9 +7340,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) { } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + std::array, 32> combinations; int i = 0; @@ -7004,12 +7379,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { const char* hlo_template = R"( HloModule test ENTRY test { - x = f8e4m3fn<><> parameter(0) + x = <><><> parameter(0) x_f32 = f32<><> convert(x) x_scale = f32[] parameter(2) x_scale_bcast = f32<> broadcast(x_scale), dimensions={} x_unscaled = f32<> multiply(x_f32, x_scale_bcast) - y = f8e4m3fn<><> parameter(1) + y = <><><> parameter(1) y_f32 = f32<><> convert(y) y_scale = f32[] parameter(3) y_scale_bcast = f32<> broadcast(y_scale), dimensions={} @@ -7029,10 +7404,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements); CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); } @@ -7040,10 +7414,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8ParameterizedBatched) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 - // TODO(wenscarl): For batched matmaul, not all combinations of A, B and +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + + // TODO(wenscarl): For batched matmul, not all combinations of A, B and // output layouts get pattern matched successfully to FP8 custom call. Only // a handful of cases are tested here. std::array, 32> combinations; @@ -7070,13 +7449,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, const char* hlo_template = R"( HloModule m ENTRY f { - x_q = f8e4m3fn<><> parameter(0) + x_q = <><><> parameter(0) x_scale = f32[] parameter(2) x_scale_broadcast = f32<><> broadcast(x_scale), dimensions={} x_q_convert = f32<><> convert(x_q) x_qdq = f32<><> multiply(x_q_convert, x_scale_broadcast) - y_q = f8e4m3fn<><> parameter(1) + y_q = <><><> parameter(1) y_scale = f32[] parameter(3) y_scale_broadcast = f32<><> broadcast(y_scale), dimensions={} y_q_convert = f32<><> convert(y_q) @@ -7085,6 +7464,7 @@ ENTRY f { ROOT out = f32[2,64,16]<> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<>, rhs_batch_dims={0}, rhs_contracting_dims=<> } )"; + for (const auto& combination : combinations) { absl::flat_hash_map replacements; replacements["<>"] = std::get<0>(combination); @@ -7098,25 +7478,29 @@ ENTRY f { const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements); CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); } } TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e5m2[32,16] parameter(1) + x = <>[16,32] parameter(0) + y = <>[32,16] parameter(1) x_f32 = f32[16,32] convert(x) y_f32 = f32[32,16] convert(y) x_scale = f32[] parameter(2) @@ -7131,18 +7515,22 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { )"; CheckFp8IfSupported(hlo_text); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); } TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { -#if CUDA_VERSION < 12000 +#if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 +#endif + +#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above."; +#endif // TF_ROCM_VERSION < 60000 + // Test that FNUZ FP8 gemms are not rewritten, as cuBLAS does not support them const char* hlo_text = R"( HloModule test @@ -7161,31 +7549,45 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ROOT out = f32[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; +#if GOOGLE_CUDA + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); +#endif +#if TENSORFLOW_USE_ROCM EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - absl::StrReplaceAll(R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fnuz[16,32], y: f8e4m3fnuz[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { + RunAndFilecheckHloRewrite( + hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), /*f8_rewrite=*/true), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]]) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) -; CHECK-NEXT: [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={} -; CHECK-NEXT: [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]]) -; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={} -; CHECK-NEXT: [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]]) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]), -; CHECK: custom_call_target="<>", +; CHECK-PTX-NEXT: [[P0_CV:%[^ ]+]] = f32[16,32]{1,0} convert([[P0]]) +; CHECK-PTX-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-PTX-NEXT: [[P2_B:%[^ ]+]] = f32[16,32]{1,0} broadcast([[P2]]), dimensions={} +; CHECK-PTX-NEXT: [[P0_UNSCALED:%[^ ]+]] = f32[16,32]{1,0} multiply([[P0_CV]], [[P2_B]]) +; CHECK-PTX-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1) +; CHECK-PTX-NEXT: [[P1_CV:%[^ ]+]] = f32[32,16]{1,0} convert([[P1]]) +; CHECK-PTX-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-PTX-NEXT: [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={} +; CHECK-PTX-NEXT: [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]]) +; CHECK-PTX-NEXT: [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]), +; CHECK-GCN-NEXT: [[P1:%[^ ]+]] = f8e4m3fnuz[32,16]{1,0} parameter(1) +; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) +; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX: custom_call_target="<>", +; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 ; CHECK-DAG: "beta":0 ; CHECK-DAG: "dot_dimension_numbers":{ ; CHECK-DAG: "lhs_contracting_dimensions":["1"] -; CHECK-DAG: "rhs_contracting_dimensions":["0"] +; CHECK-PTX-DAG: "rhs_contracting_dimensions":["0"] +; CHECK-GCN-DAG: "rhs_contracting_dimensions":["1"] ; CHECK-DAG: "lhs_batch_dimensions":[] ; CHECK-DAG: "rhs_batch_dimensions":[] ; CHECK-DAG: } @@ -7194,8 +7596,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ; CHECK-DAG: } ; CHECK-DAG: "epilogue":"DEFAULT" ; CHECK: } - )", - replacements_)); + )"); +#endif } INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt, @@ -7241,6 +7643,13 @@ class GemmRewriteAllocationTest : public GpuCodegenTest { gpu_executable->GetAllocations(); ASSERT_EQ(allocations.size(), expected_number_of_allocations); } + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // Make sure the rewriter does not skip the rewrite for being too small. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } }; TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) { @@ -7262,6 +7671,56 @@ ENTRY AddDotsFunc { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } +class SmallDotGemmRewriteTest : public GemmRewriteTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100); + return debug_options; + } +}; + +TEST_F(SmallDotGemmRewriteTest, SkipSmallMatrixMultiplicationRewrite) { + const char* hlo_text = R"( +HloModule SkipSmallMatrixRewrite + +ENTRY DotFunc { + x = f32[3,3] parameter(0) + y = f32[3,3] parameter(1) + ROOT out = f32[3,3] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[3,3], {{.*}}: f32[3,3]) -> f32[3,3] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,3]{1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = {{.*}} dot([[P0]], [[P1]]), +; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} +)"); +} + +TEST_F(SmallDotGemmRewriteTest, LargeMatrixMultiplicationIsRewritten) { + const char* hlo_text = R"( +HloModule SkipSmallMatrixRewrite + +ENTRY DotFunc { + x = f32[8,8] parameter(0) + y = f32[8,8] parameter(1) + ROOT out = f32[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[8,8], {{.*}}: f32[8,8]) -> f32[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[8,8]{1,0} parameter(1) +; CHECK: {{[^ ]+}} = {{.*}} custom-call([[P0]], [[P1]]) +)"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/tests/gpu_alignment_test.cc b/xla/service/gpu/tests/gpu_alignment_test.cc index da5fee9dd8d3c..27e7a5925e261 100644 --- a/xla/service/gpu/tests/gpu_alignment_test.cc +++ b/xla/service/gpu/tests/gpu_alignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/llvm_ir/alias_analysis.h" -#include "xla/tests/filecheck.h" #include "tsl/platform/test.h" namespace xla { @@ -44,13 +38,10 @@ ENTRY main { } )"; - auto expected_ir = is_built_with_rocm_ ? R"( -CHECK: @fusion(ptr noalias align 128 dereferenceable(800) %arg0, ptr noalias align 16 dereferenceable(400) %arg1, ptr noalias align 128 dereferenceable(600) %arg2) -)" - : R"( -CHECK: define void @fusion(ptr noalias align 128 dereferenceable(800) %arg0, ptr noalias align 16 dereferenceable(400) %arg1, ptr noalias align 128 dereferenceable(600) %arg2) -)"; - CompileAndVerifyIr(hlo_string, expected_ir); + CompileAndVerifyIr( + hlo_string, + "CHECK: {{.*}}align 128 dereferenceable(800) %{{.*}}align 16 " + "dereferenceable(400) %{{.*}}align 128 dereferenceable(600) %"); } } // namespace diff --git a/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc index bd6bb0001c45e..5db5ffd47def7 100644 --- a/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc +++ b/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,13 +20,14 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_module_config.h" -#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -35,10 +36,9 @@ namespace { class GpuAllGatherOptimizerTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module, - int64_t num_replicas, - int64_t num_partitions, - bool expect_change) { + absl::StatusOr> RunPass( + absl::string_view hlo_module, int64_t num_replicas, + int64_t num_partitions, bool expect_change) { HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/num_replicas, /*num_partitions=*/num_partitions); @@ -51,7 +51,7 @@ class GpuAllGatherOptimizerTest : public HloTestBase { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } template diff --git a/xla/service/gpu/tests/gpu_atomic_test.cc b/xla/service/gpu/tests/gpu_atomic_test.cc index 7793a186af547..6897b9fa850e3 100644 --- a/xla/service/gpu/tests/gpu_atomic_test.cc +++ b/xla/service/gpu/tests/gpu_atomic_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/filecheck.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/gpu_codegen_test.cc b/xla/service/gpu/tests/gpu_codegen_test.cc index 5ade1fa13889a..28228093fa7ba 100644 --- a/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/xla/service/gpu/tests/gpu_codegen_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,16 @@ limitations under the License. #include "xla/service/gpu/tests/gpu_codegen_test.h" #include +#include +#include +#include "absl/status/statusor.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" +#include "xla/service/executable.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" @@ -52,7 +58,7 @@ void GpuCodegenTest::CompileAndOptionallyVerifyPtx( // executable, and hence the "ptx_str" will be empty. So disabling the // pattern check on the ROCm platform if (!is_built_with_rocm_) { - StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); + absl::StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.value()); } diff --git a/xla/service/gpu/tests/gpu_codegen_test.h b/xla/service/gpu/tests/gpu_codegen_test.h index 8bae78395c49f..a6269783536b7 100644 --- a/xla/service/gpu/tests/gpu_codegen_test.h +++ b/xla/service/gpu/tests/gpu_codegen_test.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -30,7 +32,7 @@ class GpuCodegenTest : public LlvmIrGenTestBase { public: GpuCodegenTest() : is_built_with_rocm_( - se::MultiPlatformManager::PlatformWithName("ROCM").ok()) {} + se::PlatformManager::PlatformWithName("ROCM").ok()) {} protected: // Converts LLVM match to be platform-specific. diff --git a/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc b/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc index 1c060f377562e..18b5feeeb6e73 100644 --- a/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc +++ b/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_module_config.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -29,6 +31,7 @@ class CompilationParallelismTest : public GpuCodegenTest { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // Use multiple threads for compilation debug_options.set_xla_gpu_force_compilation_parallelism(4); + debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true); return debug_options; } }; diff --git a/xla/service/gpu/tests/gpu_convolution_regression_test.cc b/xla/service/gpu/tests/gpu_convolution_regression_test.cc index bf2612291a4d3..2d382808bb20d 100644 --- a/xla/service/gpu/tests/gpu_convolution_regression_test.cc +++ b/xla/service/gpu/tests/gpu_convolution_regression_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include "absl/strings/string_view.h" +#include "xla/debug_options_flags.h" +#include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/xla/service/gpu/tests/gpu_copy_alone_test.cc b/xla/service/gpu/tests/gpu_copy_alone_test.cc index 5f0bb25aa3832..65e538a6a61d9 100644 --- a/xla/service/gpu/tests/gpu_copy_alone_test.cc +++ b/xla/service/gpu/tests/gpu_copy_alone_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_module_config.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/gpu_copy_test.cc b/xla/service/gpu/tests/gpu_copy_test.cc index 830a6a53b79b2..8ef34f0ba6336 100644 --- a/xla/service/gpu/tests/gpu_copy_test.cc +++ b/xla/service/gpu/tests/gpu_copy_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,15 @@ limitations under the License. #include #include +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -52,5 +54,24 @@ TEST_F(GpuCopyTest, UseMemcpy) { /*match_optimized_ir=*/false); } +TEST_F(GpuCopyTest, CopyTranspose) { + const char* hlo_text = R"( + HloModule Test + + fused_computation { + param_0 = f32[100,200,300]{2,1,0} parameter(0) + ROOT b.1 = f32[100,200,300]{2,0,1} copy(f32[100,200,300]{2,1,0} param_0) + } + + ENTRY main { + a = f32[100, 200, 300]{2,1,0} parameter(0) + ROOT wrapped_b = f32[100,200,300]{2,0,1} fusion(f32[100,200,300]{2,1,0} %a), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/tests/gpu_dyn_shape_test.cc b/xla/service/gpu/tests/gpu_dyn_shape_test.cc index 73fb7e8776fca..94dc2823e2c2f 100644 --- a/xla/service/gpu/tests/gpu_dyn_shape_test.cc +++ b/xla/service/gpu/tests/gpu_dyn_shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" @@ -40,9 +39,9 @@ TEST_F(GpuDynamicShapeTest, DynamicShapeR2) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: is_thread_0-true -; CHECK-LABEL: x_padded.in_dyn_bounds-true -; CHECK-LABEL: x_padded.in_bounds-true +; CHECK-DAG: is_thread_0-true +; CHECK-DAG: x.padded{{.*}}.in_dyn_bounds-true +; CHECK-DAG: x.padded{{.*}}.in_bounds-true ; CHECK: %[[dyn_dim_size:.*]] = load i32, ptr ; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]] ; CHECK: %[[linear_index:.*]] = add nuw nsw i32 diff --git a/xla/service/gpu/tests/gpu_ftz_test.cc b/xla/service/gpu/tests/gpu_ftz_test.cc index 1a5a02459f15a..f0338549b37d8 100644 --- a/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/xla/service/gpu/tests/gpu_ftz_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,12 @@ limitations under the License. #include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" // Check that the ftz (flush denormals to zero) flag is reflected in PTX as diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 85ec435927e80..a000aa51c2978 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,48 +13,76 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include +#include #include #include #include #include -#include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/reference_util.h" -#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" +#include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" -#include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" #include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + namespace xla { namespace gpu { +namespace { class MultiHeadedAttentionTest : public GpuCodegenTest { public: + MultiHeadedAttentionTest() { +#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000 + skip_reason_ = "cuDNN Fused MHA requires CUDA 12 or later."; + return; +#endif + stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); + // Enforce capability minor == 0 because hardware with a non-zero minor + // number typically has insufficient shared memory for cuDNN FMHA. + if (!cc.IsAtLeastAmpere() || cc.minor != 0) { + skip_reason_ = + "cuDNN Fused MHA requires Nvidia AMPERE+ GPUs with minor " + "compute capability == 0."; + return; + } + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 8, 0)) { + skip_reason_ = "cuDNN Fused MHA requires cuDNN 8.8.0 or later."; + return; + } + } + se::CudaComputeCapability GetCudaComputeCapability() { return backend() .default_stream_executor() @@ -62,67 +90,48 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { .cuda_compute_capability(); } - se::dnn::VersionInfo GetCudnnVersion() { - se::dnn::VersionInfo cudnn_version; - stream_executor::StreamExecutor *stream_exec = - backend().default_stream_executor(); - stream_executor::dnn::DnnSupport *dnn = stream_exec->AsDnn(); - if (!dnn) { - return se::dnn::VersionInfo(0, 0, 0); - } - StatusOr se_cudnn_version = dnn->GetVersion(); - if (se_cudnn_version.ok()) { - cudnn_version = (*se_cudnn_version); - } else { - cudnn_version = se::dnn::VersionInfo(0, 0, 0); - } - return cudnn_version; - } ErrorSpec error_spec_{2.5E-3, 1e-5}; protected: DebugOptions GetDebugOptionsForTest() override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_xla_runtime_executable(true); + debug_options.set_xla_gpu_enable_xla_runtime_executable(false); + debug_options.set_xla_gpu_enable_cudnn_fmha(false); return debug_options; } - void IsFMHACalled(const std::string &hlo_string, - HloModuleConfig &config_with_fmha, - const std::string &prefix, bool is_training) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(hlo_string, config_with_fmha)); + absl::StatusOr CountFMHACalls(absl::string_view hlo_string, + const HloModuleConfig &config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(hlo_string, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr optimized_verified_module, - GetOptimizedModule(std::move(verified_module))); + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_verified_module, + GetOptimizedModule(std::move(verified_module))); - auto count = absl::c_count_if( + return absl::c_count_if( optimized_verified_module->entry_computation()->instructions(), [&](const HloInstruction *inst) { return inst->opcode() == HloOpcode::kCustomCall && - absl::StrContains(inst->custom_call_target(), prefix); + absl::StrContains(inst->custom_call_target(), "__cudnn$fmha"); }); - if (is_training) { - EXPECT_EQ(count, 2); - } else { - EXPECT_EQ(count, 1); - } } - void ExecuteAndCompare(const std::string hlo_string, + void ExecuteAndCompare(absl::string_view hlo_string, const std::vector &literals, - bool is_training = false) { + int expected_num_fmha_calls = 1) { HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); + DebugOptions debug_options = GetDebugOptionsForTest(); + config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string, config)); auto expected_result = ExecuteAndTransfer(std::move(module), literals); - DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_cudnn_fmha(true); + // Sanity check to ensure the first computation doesn't use FMHA. + TF_ASSERT_OK_AND_ASSIGN(int num_fmha_calls, + CountFMHACalls(hlo_string, config)); + EXPECT_EQ(num_fmha_calls, 0); + debug_options.set_xla_gpu_enable_cudnn_fmha(true); HloModuleConfig config_with_fmha; config_with_fmha.set_debug_options(debug_options); @@ -133,8 +142,9 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { EXPECT_TRUE( LiteralTestUtil::Near(expected_result, actual_result, error_spec_)); - std::string prefix = "__cudnn$fhma"; - IsFMHACalled(hlo_string, config_with_fmha, prefix, is_training); + TF_ASSERT_OK_AND_ASSIGN(num_fmha_calls, + CountFMHACalls(hlo_string, config_with_fmha)); + EXPECT_EQ(num_fmha_calls, expected_num_fmha_calls); } template @@ -157,6 +167,14 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { return LiteralUtil::CreateR4FromArray4DWithLayout( input_data, LayoutUtil::MakeLayout(minor_to_major)); } + + // Centralize skip checks in the constructor. Unfortunately we cannot call + // GTEST_SKIP from the constructor. Instead, we set (if needed) `skip_reason`, + // and then check it from all test fixtures. + // An alternative would be to use the SetUp() override, but for this to be + // correct we'd have to ensure that all the parents' SetUp() methods are + // called, which is error prone. + std::optional skip_reason_; }; class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { @@ -440,13 +458,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -469,13 +481,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_arg_reversal() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -497,13 +503,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_arg_layout_manipulation_arg_reversal_fusion() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -527,13 +527,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_arg_reversal_epilogue_transpose_fusion() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -558,13 +552,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_arg_layout_manipulation_arg_reversal_prologue_transpose_fusion() { // NOLINT - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -588,13 +576,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_all_canonicalization_transpose_fusion() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -618,13 +600,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_all_canonicalization() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -646,13 +622,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_FMHABMM_BMM_all_canonicalization_transpose_fusion_small() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = GetInput4DLiteral({2, 4, 2, 64}, {3, 2, 1, 0}); @@ -670,13 +640,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_BMM_BMM1_contracting_dim_stride_not_1() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -698,13 +662,7 @@ class MultiHeadedAttentionBMMBMM : public MultiHeadedAttentionTest { template void TestImpl_BMM_BMM2_non_contracting_dim_stride_not_1() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1268,13 +1226,7 @@ class MultiHeadedAttentionBMMScaleBiasMaskSoftmaxBMM // BMM1 - Scale - Bias - Mask - Softmax - BMM2 template void TestImpl_FMHABMM1_Scale_Bias_Mask_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1300,13 +1252,7 @@ class MultiHeadedAttentionBMMScaleBiasMaskSoftmaxBMM template void TestImpl_FMHABMM1_Scale_Bias_Mask_Softmax_BMM2_vanilla_smaller() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = GetInput4DLiteral({2, 6, 40, 64}, {3, 2, 1, 0}); @@ -1329,13 +1275,7 @@ class MultiHeadedAttentionBMMScaleBiasMaskSoftmaxBMM template void TestImpl_FMHABMM1_Scale_Bias_Mask_Softmax_BMM2_arg_reversal() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1361,13 +1301,7 @@ class MultiHeadedAttentionBMMScaleBiasMaskSoftmaxBMM // Traning BMM1 - Scale - bias - Mask - Softmax - BMM2 template void TestImpl_FMHA_Training_BMM1_Scale_Bias_Mask_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 9, 1))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.9.1."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = GetInput4DLiteral({2, 6, 128, 64}, {3, 2, 1, 0}); @@ -1387,7 +1321,7 @@ class MultiHeadedAttentionBMMScaleBiasMaskSoftmaxBMM ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, &do_bmm2_literal, &mask_literal}, - true); + /*expected_num_fmha_calls=*/2); } }; @@ -1610,13 +1544,7 @@ class MultiHeadedAttentionBMMScaleMaskSoftmaxBMM // BMM1 - Scale - Mask - Softmax - BMM2 template void TestImpl_FMHABMM1_Scale_Mask_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1640,13 +1568,7 @@ class MultiHeadedAttentionBMMScaleMaskSoftmaxBMM template void TestImpl_FMHABMM1_Scale_Mask_Softmax_BMM2_arg_reversal() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1671,8 +1593,8 @@ class MultiHeadedAttentionBMMScaleMaskSoftmaxBMM } }; +// Bmm1 - Softmax - Bmm2 class MultiHeadedAttentionBMMSoftmaxBMM : public MultiHeadedAttentionTest { - // Bmm1 - Softmax - Bmm2 protected: std::string GetModuleFMHABMM1_Softmax_BMM2_HloString_F16() { const std::string hlo_text = R"( @@ -1767,13 +1689,7 @@ class MultiHeadedAttentionBMMSoftmaxBMM : public MultiHeadedAttentionTest { // BMM1 - Softmax - BMM2 template void TestImpl_FMHABMM1_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -2155,13 +2071,7 @@ class MultiHeadedAttentionBMMScaleBiasSoftmaxBMM // BMM1 - Scale - bias - Softmax - BMM2 template void TestImpl_FMHABMM1_Scale_Bias_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 8, 0))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.8.0."; - } + if (skip_reason_) GTEST_SKIP() << *skip_reason_; XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -2184,12 +2094,10 @@ class MultiHeadedAttentionBMMScaleBiasSoftmaxBMM // Training BMM1 - Scale - bias - Softmax - BMM2 template void TestImpl_FMHA_Training_BMM1_Scale_Bias_Softmax_BMM2_vanilla() { - stream_executor::CudaComputeCapability cc = GetCudaComputeCapability(); - se::dnn::VersionInfo real_cudnn_version = GetCudnnVersion(); - if (!(cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0 && - real_cudnn_version >= se::dnn::VersionInfo(8, 9, 1))) { - GTEST_SKIP() << "Fused MHA is supported with the Nvidia AMPERE+ GPUs and " - "cuDNN >= 8.9.1."; + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 1)) { + GTEST_SKIP() << "Backward fused MHA requires cuDNN >= 8.9.1."; } XlaBuilder builder(TestName()); @@ -2211,7 +2119,862 @@ class MultiHeadedAttentionBMMScaleBiasSoftmaxBMM ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, &bias_literal, &mask_literal, &do_literal}, - true); + /*expected_num_fmha_calls=*/2); + } +}; + +class FlashAttentionBMMScaleCausalMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_BMM1_CausalMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0})->bf16[2,6,2048,128]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + + region_0.28 { + Arg_0.29 = bf16[] parameter(0) + Arg_1.30 = bf16[] parameter(1) + ROOT maximum.31 = bf16[] maximum(Arg_0.29, Arg_1.30) + } + + region_1.40 { + Arg_0.41 = f32[] parameter(0) + Arg_1.42 = f32[] parameter(1) + ROOT add.43 = f32[] add(Arg_0.41, Arg_1.42) + } + + ENTRY main.52 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.10 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.6 = bf16[] constant(2) + broadcast.7 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.6), dimensions={} + multiply.11 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.10, broadcast.7) + iota.16 = s32[2048]{0} iota(), iota_dimension=0 + reshape.17 = s32[1,2048,1]{2,1,0} reshape(iota.16) + broadcast.18 = s32[1,2048,2048,1]{3,2,1,0} broadcast(reshape.17), dimensions={0,1,3} + reshape.19 = s32[2048,2048]{1,0} reshape(broadcast.18) + iota.12 = s32[2048]{0} iota(), iota_dimension=0 + reshape.13 = s32[1,1,2048]{2,1,0} reshape(iota.12) + broadcast.14 = s32[2048,1,1,2048]{3,2,1,0} broadcast(reshape.13), dimensions={1,2,3} + reshape.15 = s32[2048,2048]{1,0} reshape(broadcast.14) + compare.20 = pred[2048,2048]{1,0} compare(reshape.19, reshape.15), direction=LT + convert.21 = bf16[2048,2048]{1,0} convert(compare.20) + constant.4 = bf16[] constant(-2.366e+38) + broadcast.5 = bf16[2048,2048]{1,0} broadcast(constant.4), dimensions={} + multiply.22 = bf16[2048,2048]{1,0} multiply(convert.21, broadcast.5) + reshape.23 = bf16[1,1,2048,2048]{3,2,1,0} reshape(multiply.22) + broadcast.24 = bf16[1,1,2048,2048]{3,2,1,0} broadcast(reshape.23), dimensions={0,1,2,3} + reshape.25 = bf16[2048,2048]{1,0} reshape(broadcast.24) + broadcast.26 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.25), dimensions={2,3} + add.27 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.11, broadcast.26) + constant.9 = bf16[] constant(-inf) + reduce.32 = bf16[2,6,2048]{2,1,0} reduce(add.27, constant.9), dimensions={3}, to_apply=region_0.28 + reshape.33 = bf16[2,6,2048,1]{3,2,1,0} reshape(reduce.32) + broadcast.34 = bf16[2,6,2048,1]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2,3} + reshape.35 = bf16[2,6,2048]{2,1,0} reshape(broadcast.34) + broadcast.36 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.35), dimensions={0,1,2} + subtract.37 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.27, broadcast.36) + exponential.38 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.37) + convert.39 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.38) + constant.8 = f32[] constant(0) + reduce.44 = f32[2,6,2048]{2,1,0} reduce(convert.39, constant.8), dimensions={3}, to_apply=region_1.40 + reshape.45 = f32[2,6,2048,1]{3,2,1,0} reshape(reduce.44) + convert.46 = bf16[2,6,2048,1]{3,2,1,0} convert(reshape.45) + broadcast.47 = bf16[2,6,2048,1]{3,2,1,0} broadcast(convert.46), dimensions={0,1,2,3} + reshape.48 = bf16[2,6,2048]{2,1,0} reshape(broadcast.47) + broadcast.49 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.48), dimensions={0,1,2} + divide.50 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.38, broadcast.49) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + ROOT dot.51 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.50, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + } + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.29 { + Arg_0.30 = bf16[] parameter(0) + Arg_1.31 = bf16[] parameter(1) + ROOT maximum.32 = bf16[] maximum(Arg_0.30, Arg_1.31) + } + + region_1.41 { + Arg_0.42 = f32[] parameter(0) + Arg_1.43 = f32[] parameter(1) + ROOT add.44 = f32[] add(Arg_0.42, Arg_1.43) + } + + region_2.63 { + Arg_0.64 = bf16[] parameter(0) + Arg_1.65 = bf16[] parameter(1) + ROOT add.66 = bf16[] add(Arg_0.64, Arg_1.65) + } + + region_3.75 { + Arg_0.76 = f32[] parameter(0) + Arg_1.77 = f32[] parameter(1) + ROOT add.78 = f32[] add(Arg_0.76, Arg_1.77) + } + + ENTRY main.88 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.12 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + iota.17 = s32[1024]{0} iota(), iota_dimension=0 + reshape.18 = s32[1,1024,1]{2,1,0} reshape(iota.17) + broadcast.19 = s32[1,1024,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,3} + reshape.20 = s32[1024,1024]{1,0} reshape(broadcast.19) + iota.13 = s32[1024]{0} iota(), iota_dimension=0 + reshape.14 = s32[1,1,1024]{2,1,0} reshape(iota.13) + broadcast.15 = s32[1024,1,1,1024]{3,2,1,0} broadcast(reshape.14), dimensions={1,2,3} + reshape.16 = s32[1024,1024]{1,0} reshape(broadcast.15) + compare.21 = pred[1024,1024]{1,0} compare(reshape.20, reshape.16), direction=LT + convert.22 = bf16[1024,1024]{1,0} convert(compare.21) + constant.7 = bf16[] constant(-2.366e+38) + broadcast.8 = bf16[1024,1024]{1,0} broadcast(constant.7), dimensions={} + multiply.23 = bf16[1024,1024]{1,0} multiply(convert.22, broadcast.8) + reshape.24 = bf16[1,1,1024,1024]{3,2,1,0} reshape(multiply.23) + broadcast.25 = bf16[1,1,1024,1024]{3,2,1,0} broadcast(reshape.24), dimensions={0,1,2,3} + reshape.26 = bf16[1024,1024]{1,0} reshape(broadcast.25) + broadcast.27 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.26), dimensions={2,3} + add.28 = bf16[2,6,1024,1024]{3,2,1,0} add(dot.12, broadcast.27) + constant.11 = bf16[] constant(-inf) + reduce.33 = bf16[2,6,1024]{2,1,0} reduce(add.28, constant.11), dimensions={3}, to_apply=region_0.29 + reshape.34 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.33) + broadcast.35 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.34), dimensions={0,1,2,3} + reshape.36 = bf16[2,6,1024]{2,1,0} reshape(broadcast.35) + broadcast.37 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.36), dimensions={0,1,2} + subtract.38 = bf16[2,6,1024,1024]{3,2,1,0} subtract(add.28, broadcast.37) + exponential.39 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.38) + convert.40 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.39) + constant.10 = f32[] constant(0) + reduce.45 = f32[2,6,1024]{2,1,0} reduce(convert.40, constant.10), dimensions={3}, to_apply=region_1.41 + reshape.46 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.45) + convert.47 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.46) + broadcast.48 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2,3} + reshape.49 = bf16[2,6,1024]{2,1,0} reshape(broadcast.48) + broadcast.50 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.49), dimensions={0,1,2} + divide.51 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.39, broadcast.50) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.54 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.51, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.57 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.70 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.47), dimensions={0,1,2,3} + reshape.71 = bf16[2,6,1024]{2,1,0} reshape(broadcast.70) + broadcast.72 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.71), dimensions={0,1,2} + divide.73 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.57, broadcast.72) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.52 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.47, convert.47) + divide.53 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.52) + broadcast.58 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.53), dimensions={0,1,2,3} + reshape.59 = bf16[2,6,1024]{2,1,0} reshape(broadcast.58) + broadcast.60 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.59), dimensions={0,1,2} + multiply.61 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.57, broadcast.60) + multiply.62 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.61, exponential.39) + constant.9 = bf16[] constant(0) + reduce.67 = bf16[2,6,1024]{2,1,0} reduce(multiply.62, constant.9), dimensions={3}, to_apply=region_2.63 + reshape.68 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.67) + negate.69 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.68) + convert.74 = f32[2,6,1024,1]{3,2,1,0} convert(negate.69) + reduce.79 = f32[2,6,1024]{2,1,0} reduce(convert.74, constant.10), dimensions={3}, to_apply=region_3.75 + broadcast.80 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.79), dimensions={0,1,2} + convert.81 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.80) + add.82 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.73, convert.81) + multiply.83 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.82, exponential.39) + dot.86 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.83, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.84 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.83, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.85 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.84), dimensions={0,1,3,2} + dot.55 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_3.4, divide.51), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.56 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.55), dimensions={0,1,3,2} + ROOT tuple.87 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.54, dot.86, transpose.85, transpose.56) + } + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 128, 2048}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_BMM1_CausalMask_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare( + hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal}); + } + + template + void TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare( + hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, &do_literal}, + /*expected_num_fmha_calls=*/2); + } +}; + +class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,2048]{3,2,1,0},bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,2048,2048]{3,2,1,0})->bf16[2,6,2048,128]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + + region_0.28 { + Arg_0.29 = bf16[] parameter(0) + Arg_1.30 = bf16[] parameter(1) + ROOT maximum.31 = bf16[] maximum(Arg_0.29, Arg_1.30) + } + + region_1.40 { + Arg_0.41 = f32[] parameter(0) + Arg_1.42 = f32[] parameter(1) + ROOT add.43 = f32[] add(Arg_0.41, Arg_1.42) + } + + ENTRY main.52 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,2048]{3,2,1,0} parameter(1), sharding={replicated} + dot.10 = bf16[2,6,2048,2048]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.6 = bf16[] constant(2) + broadcast.7 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(constant.6), dimensions={} + multiply.11 = bf16[2,6,2048,2048]{3,2,1,0} multiply(dot.10, broadcast.7) + Arg_3.4 = bf16[2,6,2048,2048]{3,2,1,0} parameter(3), sharding={replicated} + add.27 = bf16[2,6,2048,2048]{3,2,1,0} add(multiply.11, Arg_3.4) + constant.9 = bf16[] constant(-inf) + reduce.32 = bf16[2,6,2048]{2,1,0} reduce(add.27, constant.9), dimensions={3}, to_apply=region_0.28 + reshape.33 = bf16[2,6,2048,1]{3,2,1,0} reshape(reduce.32) + broadcast.34 = bf16[2,6,2048,1]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2,3} + reshape.35 = bf16[2,6,2048]{2,1,0} reshape(broadcast.34) + broadcast.36 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.35), dimensions={0,1,2} + subtract.37 = bf16[2,6,2048,2048]{3,2,1,0} subtract(add.27, broadcast.36) + exponential.38 = bf16[2,6,2048,2048]{3,2,1,0} exponential(subtract.37) + convert.39 = f32[2,6,2048,2048]{3,2,1,0} convert(exponential.38) + constant.8 = f32[] constant(0) + reduce.44 = f32[2,6,2048]{2,1,0} reduce(convert.39, constant.8), dimensions={3}, to_apply=region_1.40 + reshape.45 = f32[2,6,2048,1]{3,2,1,0} reshape(reduce.44) + convert.46 = bf16[2,6,2048,1]{3,2,1,0} convert(reshape.45) + broadcast.47 = bf16[2,6,2048,1]{3,2,1,0} broadcast(convert.46), dimensions={0,1,2,3} + reshape.48 = bf16[2,6,2048]{2,1,0} reshape(broadcast.47) + broadcast.49 = bf16[2,6,2048,2048]{3,2,1,0} broadcast(reshape.48), dimensions={0,1,2} + divide.50 = bf16[2,6,2048,2048]{3,2,1,0} divide(exponential.38, broadcast.49) + Arg_2.3 = bf16[2,6,2048,128]{3,2,1,0} parameter(2), sharding={replicated} + ROOT dot.51 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.50, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + } + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Bias_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.13 { + Arg_0.14 = bf16[] parameter(0) + Arg_1.15 = bf16[] parameter(1) + ROOT maximum.16 = bf16[] maximum(Arg_0.14, Arg_1.15) + } + + region_1.25 { + Arg_0.26 = f32[] parameter(0) + Arg_1.27 = f32[] parameter(1) + ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27) + } + + region_2.47 { + Arg_0.48 = bf16[] parameter(0) + Arg_1.49 = bf16[] parameter(1) + ROOT add.50 = bf16[] add(Arg_0.48, Arg_1.49) + } + + region_3.59 { + Arg_0.60 = f32[] parameter(0) + Arg_1.61 = f32[] parameter(1) + ROOT add.62 = f32[] add(Arg_0.60, Arg_1.61) + } + + ENTRY main.72 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.11 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,1024]{3,2,1,0} parameter(3), sharding={replicated} + add.12 = bf16[2,6,1024,1024]{3,2,1,0} add(dot.11, Arg_3.4) + constant.9 = bf16[] constant(-inf) + reduce.17 = bf16[2,6,1024]{2,1,0} reduce(add.12, constant.9), dimensions={3}, to_apply=region_0.13 + reshape.18 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.17) + broadcast.19 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,2,3} + reshape.20 = bf16[2,6,1024]{2,1,0} reshape(broadcast.19) + broadcast.21 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.20), dimensions={0,1,2} + subtract.22 = bf16[2,6,1024,1024]{3,2,1,0} subtract(add.12, broadcast.21) + exponential.23 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.22) + convert.24 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.23) + constant.8 = f32[] constant(0) + reduce.29 = f32[2,6,1024]{2,1,0} reduce(convert.24, constant.8), dimensions={3}, to_apply=region_1.25 + reshape.30 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.29) + convert.31 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.30) + broadcast.32 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.33 = bf16[2,6,1024]{2,1,0} reshape(broadcast.32) + broadcast.34 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2} + divide.35 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.23, broadcast.34) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.38 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.35, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_4.5 = bf16[2,6,1024,64]{3,2,1,0} parameter(4), sharding={replicated} + dot.41 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.54 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.55 = bf16[2,6,1024]{2,1,0} reshape(broadcast.54) + broadcast.56 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.55), dimensions={0,1,2} + divide.57 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.41, broadcast.56) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.36 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.31, convert.31) + divide.37 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.36) + broadcast.42 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.37), dimensions={0,1,2,3} + reshape.43 = bf16[2,6,1024]{2,1,0} reshape(broadcast.42) + broadcast.44 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.43), dimensions={0,1,2} + multiply.45 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.41, broadcast.44) + multiply.46 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.45, exponential.23) + constant.7 = bf16[] constant(0) + reduce.51 = bf16[2,6,1024]{2,1,0} reduce(multiply.46, constant.7), dimensions={3}, to_apply=region_2.47 + reshape.52 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.51) + negate.53 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.52) + convert.58 = f32[2,6,1024,1]{3,2,1,0} convert(negate.53) + reduce.63 = f32[2,6,1024]{2,1,0} reduce(convert.58, constant.8), dimensions={3}, to_apply=region_3.59 + broadcast.64 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.63), dimensions={0,1,2} + convert.65 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.64) + add.66 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.57, convert.65) + multiply.67 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.66, exponential.23) + dot.70 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.68 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.69 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.68), dimensions={0,1,3,2} + dot.39 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_4.5, divide.35), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.40 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.39), dimensions={0,1,3,2} + ROOT tuple.71 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.38, dot.70, transpose.69, transpose.40) + } + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,2048,128]{3,2,1,0},bf16[2,6,128,1024]{3,2,1,0},bf16[2,6,1024,128]{3,2,1,0},bf16[2,6,2048,1024]{3,2,1,0})->bf16[2,6,2048,128]{3,2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + + region_0.28 { + Arg_0.29 = bf16[] parameter(0) + Arg_1.30 = bf16[] parameter(1) + ROOT maximum.31 = bf16[] maximum(Arg_0.29, Arg_1.30) + } + + region_1.40 { + Arg_0.41 = f32[] parameter(0) + Arg_1.42 = f32[] parameter(1) + ROOT add.43 = f32[] add(Arg_0.41, Arg_1.42) + } + + ENTRY main.52 { + Arg_0.1 = bf16[2,6,2048,128]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,128,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.10 = bf16[2,6,2048,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.6 = bf16[] constant(2) + broadcast.7 = bf16[2,6,2048,1024]{3,2,1,0} broadcast(constant.6), dimensions={} + multiply.11 = bf16[2,6,2048,1024]{3,2,1,0} multiply(dot.10, broadcast.7) + Arg_3.4 = bf16[2,6,2048,1024]{3,2,1,0} parameter(3), sharding={replicated} + add.27 = bf16[2,6,2048,1024]{3,2,1,0} add(multiply.11, Arg_3.4) + constant.9 = bf16[] constant(-inf) + reduce.32 = bf16[2,6,2048]{2,1,0} reduce(add.27, constant.9), dimensions={3}, to_apply=region_0.28 + reshape.33 = bf16[2,6,2048,1]{3,2,1,0} reshape(reduce.32) + broadcast.34 = bf16[2,6,2048,1]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2,3} + reshape.35 = bf16[2,6,2048]{2,1,0} reshape(broadcast.34) + broadcast.36 = bf16[2,6,2048,1024]{3,2,1,0} broadcast(reshape.35), dimensions={0,1,2} + subtract.37 = bf16[2,6,2048,1024]{3,2,1,0} subtract(add.27, broadcast.36) + exponential.38 = bf16[2,6,2048,1024]{3,2,1,0} exponential(subtract.37) + convert.39 = f32[2,6,2048,1024]{3,2,1,0} convert(exponential.38) + constant.8 = f32[] constant(0) + reduce.44 = f32[2,6,2048]{2,1,0} reduce(convert.39, constant.8), dimensions={3}, to_apply=region_1.40 + reshape.45 = f32[2,6,2048,1]{3,2,1,0} reshape(reduce.44) + convert.46 = bf16[2,6,2048,1]{3,2,1,0} convert(reshape.45) + broadcast.47 = bf16[2,6,2048,1]{3,2,1,0} broadcast(convert.46), dimensions={0,1,2,3} + reshape.48 = bf16[2,6,2048]{2,1,0} reshape(broadcast.47) + broadcast.49 = bf16[2,6,2048,1024]{3,2,1,0} broadcast(reshape.48), dimensions={0,1,2} + divide.50 = bf16[2,6,2048,1024]{3,2,1,0} divide(exponential.38, broadcast.49) + Arg_2.3 = bf16[2,6,1024,128]{3,2,1,0} parameter(2), sharding={replicated} + ROOT dot.51 = bf16[2,6,2048,128]{3,2,1,0} dot(divide.50, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + } + )"; + return hlo_text; + } + template + void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 128, 2048}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto bias_literal = GetInput4DLiteral({2, 6, 2048, 2048}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_HloString_BF16(); + ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, + &rhs_bmm2_literal, &bias_literal}); + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto bias_literal = GetInput4DLiteral({2, 6, 1024, 1024}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_Training_BMM1_Bias_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare(hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, + &bias_literal, &do_literal}, + /*expected_num_fmha_calls=*/2); + } + + template + void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 4)) { + GTEST_SKIP() << "Flash Attention cross attention requires " + "cuDNN >= 8.9.4."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 2048, 128}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 128, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 128}, {3, 2, 1, 0}); + auto bias_literal = GetInput4DLiteral({2, 6, 2048, 1024}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention_HloString_BF16(); // NOLINT + ExecuteAndCompare(hlo_string, {&lhs_bmm1_literal, &rhs_bmm1_literal, + &rhs_bmm2_literal, &bias_literal}); + } +}; + +class FlashAttentionBMMScaleBiasMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,64,1024]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},pred[2,6,1024,1024]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.21 { + Arg_0.22 = bf16[] parameter(0) + Arg_1.23 = bf16[] parameter(1) + ROOT maximum.24 = bf16[] maximum(Arg_0.22, Arg_1.23) + } + + region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add.36 = f32[] add(Arg_0.34, Arg_1.35) + } + + region_2.55 { + Arg_0.56 = bf16[] parameter(0) + Arg_1.57 = bf16[] parameter(1) + ROOT add.58 = bf16[] add(Arg_0.56, Arg_1.57) + } + + region_3.67 { + Arg_0.68 = f32[] parameter(0) + Arg_1.69 = f32[] parameter(1) + ROOT add.70 = f32[] add(Arg_0.68, Arg_1.69) + } + + ENTRY main.82 { + constant.16 = pred[2,6,1024,1024]{3,2,1,0} parameter(4) + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1) + dot.17 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.5 = bf16[] constant(2) + broadcast.6 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.18 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.17, broadcast.6) + constant.14 = bf16[] constant(1) + broadcast.15 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.14), dimensions={} + add.19 = bf16[2,6,1024,1024]{3,2,1,0} add(multiply.18, broadcast.15) + constant.7 = bf16[] constant(0) + broadcast.8 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.7), dimensions={} + select.20 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, add.19, broadcast.8) + constant.12 = bf16[] constant(-inf) + reduce.25 = bf16[2,6,1024]{2,1,0} reduce(select.20, constant.12), dimensions={3}, to_apply=region_0.21 + reshape.26 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.25) + broadcast.27 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.26), dimensions={0,1,2,3} + reshape.28 = bf16[2,6,1024]{2,1,0} reshape(broadcast.27) + broadcast.29 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.28), dimensions={0,1,2} + subtract.30 = bf16[2,6,1024,1024]{3,2,1,0} subtract(select.20, broadcast.29) + exponential.31 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.30) + convert.32 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.31) + constant.11 = f32[] constant(0) + reduce.37 = f32[2,6,1024]{2,1,0} reduce(convert.32, constant.11), dimensions={3}, to_apply=region_1.33 + reshape.38 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.37) + convert.39 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.38) + broadcast.40 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.41 = bf16[2,6,1024]{2,1,0} reshape(broadcast.40) + broadcast.42 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.41), dimensions={0,1,2} + divide.43 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.31, broadcast.42) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2) + dot.46 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.43, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,64]{3,2,1,0} parameter(3) + dot.49 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.62 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.63 = bf16[2,6,1024]{2,1,0} reshape(broadcast.62) + broadcast.64 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.63), dimensions={0,1,2} + divide.65 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.49, broadcast.64) + constant.9 = bf16[] constant(1) + broadcast.10 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.9), dimensions={} + multiply.44 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.39, convert.39) + divide.45 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.10, multiply.44) + broadcast.50 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.45), dimensions={0,1,2,3} + reshape.51 = bf16[2,6,1024]{2,1,0} reshape(broadcast.50) + broadcast.52 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.51), dimensions={0,1,2} + multiply.53 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.49, broadcast.52) + multiply.54 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.53, exponential.31) + constant.13 = bf16[] constant(0) + reduce.59 = bf16[2,6,1024]{2,1,0} reduce(multiply.54, constant.13), dimensions={3}, to_apply=region_2.55 + reshape.60 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.59) + negate.61 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.60) + convert.66 = f32[2,6,1024,1]{3,2,1,0} convert(negate.61) + reduce.71 = f32[2,6,1024]{2,1,0} reduce(convert.66, constant.11), dimensions={3}, to_apply=region_3.67 + broadcast.72 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.71), dimensions={0,1,2} + convert.73 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.72) + add.74 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.65, convert.73) + multiply.75 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.74, exponential.31) + select.76 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, multiply.75, broadcast.8) + multiply.77 = bf16[2,6,1024,1024]{3,2,1,0} multiply(select.76, broadcast.6) + dot.80 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.78 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.79 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.78), dimensions={0,1,3,2} + dot.47 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_3.4, divide.43), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.48 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.47), dimensions={0,1,3,2} + ROOT tuple.81 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.46, dot.80, transpose.79, transpose.48) + } + )"; + + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto mask_literal = GetMask4DLiteral({2, 6, 1024, 1024}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = + GetModuleFlash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare(hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, + &do_literal, &mask_literal}, + /*expected_num_fmha_calls=*/2); + } +}; + +class FlashAttentionBMMScaleSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.13 { + Arg_0.14 = bf16[] parameter(0) + Arg_1.15 = bf16[] parameter(1) + ROOT maximum.16 = bf16[] maximum(Arg_0.14, Arg_1.15) + } + + region_1.25 { + Arg_0.26 = f32[] parameter(0) + Arg_1.27 = f32[] parameter(1) + ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27) + } + + region_2.47 { + Arg_0.48 = bf16[] parameter(0) + Arg_1.49 = bf16[] parameter(1) + ROOT add.50 = bf16[] add(Arg_0.48, Arg_1.49) + } + + region_3.59 { + Arg_0.60 = f32[] parameter(0) + Arg_1.61 = f32[] parameter(1) + ROOT add.62 = f32[] add(Arg_0.60, Arg_1.61) + } + + ENTRY main.72 { + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0), sharding={replicated} + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1), sharding={replicated} + dot.11 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.17 = bf16[] constant(37) + broadcast.29 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.17), dimensions={} + multiply.2 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.11, broadcast.29) + constant.9 = bf16[] constant(-inf) + reduce.17 = bf16[2,6,1024]{2,1,0} reduce(multiply.2, constant.9), dimensions={3}, to_apply=region_0.13 + reshape.18 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.17) + broadcast.19 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.18), dimensions={0,1,2,3} + reshape.20 = bf16[2,6,1024]{2,1,0} reshape(broadcast.19) + broadcast.21 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.20), dimensions={0,1,2} + subtract.22 = bf16[2,6,1024,1024]{3,2,1,0} subtract(multiply.2, broadcast.21) + exponential.23 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.22) + convert.24 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.23) + constant.8 = f32[] constant(0) + reduce.29 = f32[2,6,1024]{2,1,0} reduce(convert.24, constant.8), dimensions={3}, to_apply=region_1.25 + reshape.30 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.29) + convert.31 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.30) + broadcast.32 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.33 = bf16[2,6,1024]{2,1,0} reshape(broadcast.32) + broadcast.34 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.33), dimensions={0,1,2} + divide.35 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.23, broadcast.34) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2), sharding={replicated} + dot.38 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.35, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_4.5 = bf16[2,6,1024,64]{3,2,1,0} parameter(3), sharding={replicated} + dot.41 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.54 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.31), dimensions={0,1,2,3} + reshape.55 = bf16[2,6,1024]{2,1,0} reshape(broadcast.54) + broadcast.56 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.55), dimensions={0,1,2} + divide.57 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.41, broadcast.56) + constant.5 = bf16[] constant(1) + broadcast.6 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.36 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.31, convert.31) + divide.37 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.6, multiply.36) + broadcast.42 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.37), dimensions={0,1,2,3} + reshape.43 = bf16[2,6,1024]{2,1,0} reshape(broadcast.42) + broadcast.44 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.43), dimensions={0,1,2} + multiply.45 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.41, broadcast.44) + multiply.46 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.45, exponential.23) + constant.7 = bf16[] constant(0) + reduce.51 = bf16[2,6,1024]{2,1,0} reduce(multiply.46, constant.7), dimensions={3}, to_apply=region_2.47 + reshape.52 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.51) + negate.53 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.52) + convert.58 = f32[2,6,1024,1]{3,2,1,0} convert(negate.53) + reduce.63 = f32[2,6,1024]{2,1,0} reduce(convert.58, constant.8), dimensions={3}, to_apply=region_3.59 + broadcast.64 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.63), dimensions={0,1,2} + convert.65 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.64) + add.66 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.57, convert.65) + multiply.67 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.66, exponential.23) + dot.70 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.68 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.67, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.69 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.68), dimensions={0,1,3,2} + dot.39 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_4.5, divide.35), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.40 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.39), dimensions={0,1,3,2} + ROOT tuple.71 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.38, dot.70, transpose.69, transpose.40) + } + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = ""; + hlo_string = + GetModuleFlash_Attention_Training_BMM1_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare( + hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, &do_literal}, + /*expected_num_fmha_calls=*/2); + } +}; + +class FlashAttentionBMMScaleMaskSoftmaxBMM : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_Mask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,64,1024]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},bf16[2,6,1024,64]{3,2,1,0},pred[2,6,1024,1024]{3,2,1,0})->(bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + region_0.21 { + Arg_0.22 = bf16[] parameter(0) + Arg_1.23 = bf16[] parameter(1) + ROOT maximum.24 = bf16[] maximum(Arg_0.22, Arg_1.23) + } + + region_1.33 { + Arg_0.34 = f32[] parameter(0) + Arg_1.35 = f32[] parameter(1) + ROOT add.36 = f32[] add(Arg_0.34, Arg_1.35) + } + + region_2.55 { + Arg_0.56 = bf16[] parameter(0) + Arg_1.57 = bf16[] parameter(1) + ROOT add.58 = bf16[] add(Arg_0.56, Arg_1.57) + } + + region_3.67 { + Arg_0.68 = f32[] parameter(0) + Arg_1.69 = f32[] parameter(1) + ROOT add.70 = f32[] add(Arg_0.68, Arg_1.69) + } + + ENTRY main.82 { + constant.16 = pred[2,6,1024,1024]{3,2,1,0} parameter(4) + Arg_0.1 = bf16[2,6,1024,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,6,64,1024]{3,2,1,0} parameter(1) + dot.17 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.5 = bf16[] constant(2) + broadcast.6 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.18 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.17, broadcast.6) + constant.7 = bf16[] constant(0) + broadcast.8 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(constant.7), dimensions={} + select.20 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, multiply.18, broadcast.8) + constant.12 = bf16[] constant(-inf) + reduce.25 = bf16[2,6,1024]{2,1,0} reduce(select.20, constant.12), dimensions={3}, to_apply=region_0.21 + reshape.26 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.25) + broadcast.27 = bf16[2,6,1024,1]{3,2,1,0} broadcast(reshape.26), dimensions={0,1,2,3} + reshape.28 = bf16[2,6,1024]{2,1,0} reshape(broadcast.27) + broadcast.29 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.28), dimensions={0,1,2} + subtract.30 = bf16[2,6,1024,1024]{3,2,1,0} subtract(select.20, broadcast.29) + exponential.31 = bf16[2,6,1024,1024]{3,2,1,0} exponential(subtract.30) + convert.32 = f32[2,6,1024,1024]{3,2,1,0} convert(exponential.31) + constant.11 = f32[] constant(0) + reduce.37 = f32[2,6,1024]{2,1,0} reduce(convert.32, constant.11), dimensions={3}, to_apply=region_1.33 + reshape.38 = f32[2,6,1024,1]{3,2,1,0} reshape(reduce.37) + convert.39 = bf16[2,6,1024,1]{3,2,1,0} convert(reshape.38) + broadcast.40 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.41 = bf16[2,6,1024]{2,1,0} reshape(broadcast.40) + broadcast.42 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.41), dimensions={0,1,2} + divide.43 = bf16[2,6,1024,1024]{3,2,1,0} divide(exponential.31, broadcast.42) + Arg_2.3 = bf16[2,6,1024,64]{3,2,1,0} parameter(2) + dot.46 = bf16[2,6,1024,64]{3,2,1,0} dot(divide.43, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + Arg_3.4 = bf16[2,6,1024,64]{3,2,1,0} parameter(3) + dot.49 = bf16[2,6,1024,1024]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + broadcast.62 = bf16[2,6,1024,1]{3,2,1,0} broadcast(convert.39), dimensions={0,1,2,3} + reshape.63 = bf16[2,6,1024]{2,1,0} reshape(broadcast.62) + broadcast.64 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.63), dimensions={0,1,2} + divide.65 = bf16[2,6,1024,1024]{3,2,1,0} divide(dot.49, broadcast.64) + constant.9 = bf16[] constant(1) + broadcast.10 = bf16[2,6,1024,1]{3,2,1,0} broadcast(constant.9), dimensions={} + multiply.44 = bf16[2,6,1024,1]{3,2,1,0} multiply(convert.39, convert.39) + divide.45 = bf16[2,6,1024,1]{3,2,1,0} divide(broadcast.10, multiply.44) + broadcast.50 = bf16[2,6,1024,1]{3,2,1,0} broadcast(divide.45), dimensions={0,1,2,3} + reshape.51 = bf16[2,6,1024]{2,1,0} reshape(broadcast.50) + broadcast.52 = bf16[2,6,1024,1024]{3,2,1,0} broadcast(reshape.51), dimensions={0,1,2} + multiply.53 = bf16[2,6,1024,1024]{3,2,1,0} multiply(dot.49, broadcast.52) + multiply.54 = bf16[2,6,1024,1024]{3,2,1,0} multiply(multiply.53, exponential.31) + constant.13 = bf16[] constant(0) + reduce.59 = bf16[2,6,1024]{2,1,0} reduce(multiply.54, constant.13), dimensions={3}, to_apply=region_2.55 + reshape.60 = bf16[2,6,1024,1]{3,2,1,0} reshape(reduce.59) + negate.61 = bf16[2,6,1024,1]{3,2,1,0} negate(reshape.60) + convert.66 = f32[2,6,1024,1]{3,2,1,0} convert(negate.61) + reduce.71 = f32[2,6,1024]{2,1,0} reduce(convert.66, constant.11), dimensions={3}, to_apply=region_3.67 + broadcast.72 = f32[2,6,1024,1024]{3,2,1,0} broadcast(reduce.71), dimensions={0,1,2} + convert.73 = bf16[2,6,1024,1024]{3,2,1,0} convert(broadcast.72) + add.74 = bf16[2,6,1024,1024]{3,2,1,0} add(divide.65, convert.73) + multiply.75 = bf16[2,6,1024,1024]{3,2,1,0} multiply(add.74, exponential.31) + select.76 = bf16[2,6,1024,1024]{3,2,1,0} select(constant.16, multiply.75, broadcast.8) + multiply.77 = bf16[2,6,1024,1024]{3,2,1,0} multiply(select.76, broadcast.6) + dot.80 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + dot.78 = bf16[2,6,1024,64]{3,2,1,0} dot(multiply.77, Arg_0.1), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.79 = bf16[2,6,64,1024]{2,3,1,0} transpose(dot.78), dimensions={0,1,3,2} + dot.47 = bf16[2,6,64,1024]{3,2,1,0} dot(Arg_3.4, divide.43), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.48 = bf16[2,6,1024,64]{2,3,1,0} transpose(dot.47), dimensions={0,1,3,2} + ROOT tuple.81 = (bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,1024,64]{3,2,1,0}, bf16[2,6,64,1024]{2,3,1,0}, bf16[2,6,1024,64]{2,3,1,0}) tuple(dot.46, dot.80, transpose.79, transpose.48) + } + )"; + + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_Mask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfo(backend().default_stream_executor()) < + se::dnn::VersionInfo(8, 9, 3)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + } + XlaBuilder builder(TestName()); + auto lhs_bmm1_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto rhs_bmm1_literal = + GetInput4DLiteral({2, 6, 64, 1024}, {3, 2, 1, 0}); + auto rhs_bmm2_literal = + GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + auto mask_literal = GetMask4DLiteral({2, 6, 1024, 1024}, {3, 2, 1, 0}); + auto do_literal = GetInput4DLiteral({2, 6, 1024, 64}, {3, 2, 1, 0}); + std::string hlo_string = + GetModuleFlash_Attention_Training_BMM1_Mask_Softmax_BMM2_HloString_BF16(); // NOLINT + ExecuteAndCompare(hlo_string, + {&lhs_bmm1_literal, &rhs_bmm1_literal, &rhs_bmm2_literal, + &do_literal, &mask_literal}, + /*expected_num_fmha_calls=*/2); } }; @@ -2345,5 +3108,52 @@ XLA_TEST_F(MultiHeadedAttentionBMMScaleBiasSoftmaxBMM, FMHA_Training_BMM1_Scale_Bias_Softmax_BMM2_vanilla_BF16) { TestImpl_FMHA_Training_BMM1_Scale_Bias_Softmax_BMM2_vanilla(); } + +// flash attention +// BMM1 - Scale - CausalMask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleCausalMaskSoftmaxBMM, + Flash_Attention_BMM1_CausalMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2(); +} + +XLA_TEST_F(FlashAttentionBMMScaleCausalMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2(); +} + +// BMM1 - Scale - Bias - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleBiasSoftmaxBMM, + Flash_Attention_BMM1_Bias_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2(); +} + +XLA_TEST_F(FlashAttentionBMMScaleBiasSoftmaxBMM, + Flash_Attention_Training_BMM1_Bias_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2(); +} + +XLA_TEST_F(FlashAttentionBMMScaleBiasSoftmaxBMM, + Flash_Attention_BMM1_Bias_Softmax_BMM2_BF16_Cross_Attention) { + TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention(); +} + +// BMM1 - Scale - Bias - Mask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleBiasMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Bias_Mask_Softmax_BMM2(); +} + +// BMM1 - Scale - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMM, + Flash_Attention_Training_BMM1_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2(); +} + +// BMM1 - Scale - Mask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_Mask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_Mask_Softmax_BMM2(); +} +} // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc index 2c96ac7e31678..3e573eb569bb6 100644 --- a/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc +++ b/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include -#include +#include "absl/strings/string_view.h" #include "xla/service/gpu/fusion_merger.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/multi_output_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" namespace xla { diff --git a/xla/service/gpu/tests/gpu_fusion_test.cc b/xla/service/gpu/tests/gpu_fusion_test.cc index dc5113a1c4207..849cf1dcaf5bb 100644 --- a/xla/service/gpu/tests/gpu_fusion_test.cc +++ b/xla/service/gpu/tests/gpu_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { @@ -123,26 +130,26 @@ ENTRY main { )"); } -TEST_F(TransposeFusionTest, ReshapeAfterCopyFused) { +TEST_F(TransposeFusionTest, ReshapeAfterTransposeFused) { const char* hlo = R"( HloModule module ENTRY main { p = f32[16,32]{1,0} parameter(0) s = sqrt(p) - c = f32[16,32]{0,1} copy(s) - ROOT r = f32[16,32,1]{0,1,2} reshape(c) + t = f32[32,16]{1,0} transpose(s), dimensions={1,0} + ROOT r = f32[32,16,1]{2,1,0} reshape(t) } )"; CheckGpuFusion(hlo, R"( -// CHECK: %fused_computation (param_0.2: f32[16,32]) -> f32[16,32,1] { +// CHECK: %fused_computation (param_0.2: f32[16,32]) -> f32[32,16,1] { // CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_2_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]]) -// CHECK-NEXT: ROOT [[r_1_3:%[^ ]+]] = f32[16,32,1]{0,1,2} reshape([[c_1_2]]) +// CHECK-NEXT: [[t_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]) +// CHECK-NEXT: ROOT [[r_1_3:%[^ ]+]] = f32[32,16,1]{2,1,0} reshape([[t_1_2]]) // CHECK-NEXT: } -// CHECK: ROOT [[fusion_1:%[^ ]+]] = f32[16,32,1]{0,1,2} fusion +// CHECK: ROOT [[fusion_1:%[^ ]+]] = f32[32,16,1]{2,1,0} fusion )"); } diff --git a/xla/service/gpu/tests/gpu_index_test.cc b/xla/service/gpu/tests/gpu_index_test.cc index 261b272167540..dd5461eab5301 100644 --- a/xla/service/gpu/tests/gpu_index_test.cc +++ b/xla/service/gpu/tests/gpu_index_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,11 @@ limitations under the License. #include #include +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" @@ -87,36 +85,6 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { /*match_optimized_ir=*/true); } -TEST_F(GpuIndexTest, - ReuseMultidimIndexWithTrivialReshapeAndNonContiguousBroadcast) { - HloModuleConfig config; - config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); - auto module = ParseAndReturnVerifiedModule(R"( - HloModule test_module - - ENTRY CompatibleUseLinearIndexWithReshape { - x = f32[1,7,2,5,3]{4,3,2,1,0} parameter(0) - y = f32[2,1,3]{2,1,0} parameter(1) - reshape = f32[1,2,3]{2,1,0} reshape(y) - broadcast = f32[1,7,2,5,3]{4,3,2,1,0} broadcast(reshape), dimensions={0,2,4} - ROOT gte = pred[1,7,2,5,3]{4,3,2,1,0} compare(x, broadcast), direction=GE - })", - config) - .value(); - CompileAndVerifyIr(std::move(module), - R"( -; CHECK: %[[tmp4:.*]] = udiv i32 %[[linear_index:.*]], 1 -; CHECK: %[[dim4:.*]] = urem i32 %[[tmp4]], 3 -; CHECK: %[[tmp3:.*]] = udiv i32 %[[linear_index]], 3 -; CHECK: %[[dim3:.*]] = urem i32 %[[tmp3]], 5 -; CHECK: %[[tmp2:.*]] = udiv i32 %[[linear_index]], 15 -; CHECK: %[[dim2:.*]] = urem i32 %[[tmp2]], 2 -; CHECK: %[[tmp1:.*]] = udiv i32 %[[linear_index]], 30 -; CHECK: %{{.*}} = getelementptr inbounds [2 x [1 x [3 x float]]], ptr %{{.*}}, i32 0, i32 %[[dim2]], i32 0, i32 %[[dim4]] - )", - /*match_optimized_ir=*/false); -} - #if TENSORFLOW_USE_ROCM #else TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { @@ -171,9 +139,11 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithSizeOneDimensions) { R"( ; CHECK-LABEL: @wrapped_convert ; CHECK: icmp ult i32 %[[linear_index:.*]], 262144 -; CHECK: %[[ld_addr:.*]] = getelementptr inbounds float, ptr {{.*}}, i32 %[[linear_index]] +; CHECK: %[[ld_addr_base:.*]] = getelementptr float, ptr {{.*}}, i32 %[[linear_index]] +; CHECK: %[[ld_addr:.*]] = getelementptr inbounds float, ptr %[[ld_addr_base]], i32 0 ; CHECK: load float, ptr %[[ld_addr]] -; CHECK: %[[st_addr:.*]] = getelementptr inbounds half, ptr {{.*}}, i32 %[[linear_index]] +; CHECK: %[[st_addr_base:.*]] = getelementptr half, ptr {{.*}}, i32 %[[linear_index]] +; CHECK: %[[st_addr:.*]] = getelementptr inbounds half, ptr %[[st_addr_base]], i32 0 ; CHECK: store half {{.*}}, ptr %[[st_addr]] )", /*match_optimized_ir=*/false); diff --git a/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc b/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc index 3c309471817e4..ce0b693575720 100644 --- a/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc +++ b/xla/service/gpu/tests/gpu_input_fusible_slice_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ limitations under the License. #include +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -67,12 +67,12 @@ TEST_F(GpuSliceInputFusionTest, InputFusionWithATupleOfSlices) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK-LABEL: define amdgpu_kernel void @fusion +; CHECK-LABEL: define amdgpu_kernel void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )" : R"( -; CHECK-LABEL: define void @fusion +; CHECK-LABEL: define void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )"; @@ -114,12 +114,12 @@ TEST_F(GpuSliceInputFusionTest, ConcatThenSplit) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK-LABEL: define amdgpu_kernel void @fusion +; CHECK-LABEL: define amdgpu_kernel void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )" : R"( -; CHECK-LABEL: define void @fusion +; CHECK-LABEL: define void @{{[a-z_]*}}fusion ; CHECK: slice2 ; CHECK: } )"; diff --git a/xla/service/gpu/tests/gpu_int4_test.cc b/xla/service/gpu/tests/gpu_int4_test.cc index 17eab52099808..b55a23d410860 100644 --- a/xla/service/gpu/tests/gpu_int4_test.cc +++ b/xla/service/gpu/tests/gpu_int4_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ TEST_F(GpuInt4Test, TestInt4ParameterSize) { // The input should be 2 bytes and the output should be 4 bytes auto expected_ir = R"( -; CHECK: define void {{.*}} dereferenceable(2){{.*}} dereferenceable(4) +; CHECK: define KERNEL_ANNOTATION {{.*}} dereferenceable(2){{.*}} dereferenceable(4) )"; CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecificLlvm(expected_ir), @@ -58,7 +58,7 @@ TEST_F(GpuInt4Test, TestInt4OutputSize) { // The input should be 4 bytes and the output should be 2 bytes auto expected_ir = R"( -; CHECK: define void {{.*}} dereferenceable(4){{.*}} dereferenceable(2) +; CHECK: define KERNEL_ANNOTATION {{.*}} dereferenceable(4){{.*}} dereferenceable(2) )"; CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecificLlvm(expected_ir), @@ -78,7 +78,7 @@ TEST_F(GpuInt4Test, TestConstantSize) { // The constant should be 2 bytes and the output should be 4 bytes auto expected_ir = R"( -; CHECK: define void {{.*}} dereferenceable(2){{.*}} dereferenceable(4) +; CHECK: define KERNEL_ANNOTATION {{.*}} dereferenceable(2){{.*}} dereferenceable(4) )"; CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecificLlvm(expected_ir), @@ -100,7 +100,7 @@ TEST_F(GpuInt4Test, TestOddElements) { // unrolled loop auto expected_ir = R"( ; CHECK: {{.*}}.in_bounds-true: -; CHECK-NEXT: %[[in_bounds:.*]] = icmp ult i32 %linear_index_base, 5 +; CHECK-NEXT: %[[in_bounds:.*]] = icmp ult i32 %linear_index0, 5 ; CHECK-NEXT: br i1 %{{.*}}, label %[[in_bounds_true:.*unrolled_in_bounds-true]], label %[[in_bounds_after:.*unrolled_in_bounds-after]] ; ; CHECK: [[in_bounds_true]]: diff --git a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 18322ffc5e7cf..e86f2c09b06ce 100644 --- a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,12 @@ limitations under the License. #include #include -#include "absl/strings/str_replace.h" +#include "absl/status/status.h" +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/test.h" namespace xla { @@ -134,7 +136,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -170,7 +172,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -202,7 +204,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK-NOT: call void BARRIER() ; CHECK: } )"; @@ -231,7 +233,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK-NOT: call void BARRIER() ; CHECK: } )"; @@ -268,7 +270,7 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion +; CHECK-LABEL: define KERNEL_ANNOTATION @{{[a-z_]*}}fusion ; CHECK: call void BARRIER() ; CHECK: } )"; @@ -279,88 +281,6 @@ TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0001})); } -TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) { - const char *const kHloString = R"( - HloModule column_reduce_powerof2 - - reduction { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x, y) - } - - ENTRY kernel_entry { - constant0 = f32[] constant(0) - arg1 = f16[1024,512,128]{2,1,0} parameter(0) - arg1_conv = f32[1024,512,128]{2,1,0} convert(arg1) - ROOT reduce = f32[512,128]{1,0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction - })"; - - // Check that two calls to llvm.nvvm.atomic are generated. - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - const char *expected_ir = R"( -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK: store float %{{.*}}, ptr addrspace(1) -)"; - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); -} - -TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { - const char *const kHloString = R"( - HloModule column_reduce_powerof2_mof - - reduction22 { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x, y) - } - - fused_computation { - constant0 = f32[] constant(0) - arg.1 = f16[1024,512,128]{2,1,0} parameter(0) - arg.2 = f16[1024,512,128]{2,1,0} parameter(1) - arg1.conv = f32[1024,512,128]{2,1,0} convert(arg.1) - arg2.conv = f32[1024,512,128]{2,1,0} convert(arg.2) - reduce1 = f32[512,128]{1,0} reduce(arg1.conv, constant0), dimensions={0}, - to_apply=reduction22 - reduce2 = f32[512,128]{1,0} reduce(arg2.conv, constant0), dimensions={0}, - to_apply=reduction22 - add = f32[1024,512,128]{2,1,0} add(arg1.conv, arg2.conv) - ROOT tuple = (f32[512,128]{1,0}, f32[512,128]{1,0}, f32[1024,512,128]{2,1,0}) - tuple(reduce1, reduce2, add) - } - - ENTRY kernel_entry { - arg1 = f16[1024,512,128]{2,1,0} parameter(0) - arg2 = f16[1024,512,128]{2,1,0} parameter(1) - ROOT fusion = (f32[512,128]{1,0}, f32[512,128]{1,0}, f32[1024,512,128]{2,1,0}) - fusion(arg1, arg2), kind=kInput, calls=fused_computation - })"; - - // Check that four calls to llvm.nvvm.atomic are generated. - std::unique_ptr hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - const char *expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK-NOT: store float %{{.*}}, ptr addrspace(1) -)"; - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); -} - TEST_F(GpuKernelTilingTest, MofReduceDifferentType) { const char *const kHloString = R"( HloModule module, entry_computation_layout={(f32[128,1024]{1,0})->(f16[128]{0}, f32[128]{0})} @@ -424,7 +344,7 @@ TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { /*match_optimized_ir=*/true); // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); + EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); } TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { @@ -448,7 +368,7 @@ TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: call SHUFFLE ; CHECK: } )"; @@ -457,7 +377,7 @@ TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { /*match_optimized_ir=*/true); // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); + EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); } TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { @@ -482,7 +402,7 @@ TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15 ; CHECK: call SHUFFLE @@ -521,14 +441,11 @@ TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) { ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7 ; CHECK: call SHUFFLE ; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 -; CHECK: LCAL -; CHECK: EXTV -; CHECK: BR_CAL )"; CompileAndVerifyIr(std::move(hlo_module), @@ -561,7 +478,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); const char *expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK: store float %{{.*}}, ptr addrspace(1) ; CHECK: } )"; @@ -617,49 +534,6 @@ TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); } -TEST_F(GpuKernelTilingTest, ColReductionWithSmallDtype) { - const char *const kHloString = R"( -HloModule mod -region_0 { - Arg_0 = f32[] parameter(0) - Arg_1 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0, Arg_1) -} -fused_computation { - param_0.4 = bf16[32,16,512,512]{3,2,1,0} parameter(0) - convert.31 = f32[32,16,512,512]{3,2,1,0} convert(param_0.4) - constant_1 = f32[] constant(0) - ROOT reduce = f32[16,512,512]{2,1,0} reduce(convert.31, constant_1), dimensions={0}, to_apply=region_0 -} -ENTRY main { - Arg_3.4 = bf16[32,16,512,512]{3,2,1,0} parameter(0) - ROOT fusion = f32[16,512,512]{2,1,0} fusion(Arg_3.4), kind=kInput, calls=fused_computation -})"; - - // Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down. - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - std::string expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @fusion -; CHECK: load <4 x i16> -; CHECK-COUNT-4: load PLATFORM_SPECIFIC_TYPE -; CHECK-NOT: load -; CHECK: } -)"; - - expected_ir = absl::StrReplaceAll( - expected_ir, - {{"PLATFORM_SPECIFIC_TYPE", is_built_with_rocm_ ? "i32" : "float"}}); - - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); -} - TEST_F(GpuKernelTilingTest, RowReductionWithSmallNonPowerOfTwoDimensionNotTiled) { const char *const kHloString = R"( @@ -682,7 +556,7 @@ TEST_F(GpuKernelTilingTest, ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .value(); auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @wrapped_reduce +; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} ; CHECK-NOT: call SHUFFLE ; CHECK: } )"; @@ -719,59 +593,6 @@ TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) { /*match_optimized_ir=*/true); } -TEST_F(GpuKernelTilingTest, ColumnReductionVectorization) { - const char *const kHloString = R"( -HloModule column_reduce_powerof2 - -reduction { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x, y) -} - -ENTRY kernel_entry { - constant0 = f32[] constant(0) - arg1 = f32[1024,512,128]{2,1,0} parameter(0) - ROOT reduce = f32[512,128]{1,0} reduce(arg1, constant0), dimensions={0}, to_apply=reduction -} - )"; - auto expected_ir = R"( -; CHECK: load <2 x float>, ptr - )"; - auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); -} - -TEST_F(GpuKernelTilingTest, ColumnMultiOutputVectorization) { - const char *const kHloString = R"( -HloModule HandleReductionToVectorAndOtherReduction - -add { - acc = f16[] parameter(1) - op = f16[] parameter(0) - ROOT out = f16[] add(acc, op) -} - -ENTRY main { - p = f16[4096,4096] parameter(0) - l1 = log(p) - l2 = log(l1) - s = log(l2) - z = f16[] constant(0) - r1 = f16[4096] reduce(p, z), dimensions={0}, to_apply=add - r2 = f16[4096] reduce(s, z), dimensions={0}, to_apply=add - ROOT out = (f16[4096], f16[4096]) tuple(r1, r2) -} - )"; - auto expected_ir = R"( -; CHECK: load <2 x half>, ptr - )"; - auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); -} - TEST_F(GpuKernelTilingTest, Hlo021CopyNoOobAccess) { const char *const kHloString = R"( HloModule primitive_computation_svd.38 @@ -815,10 +636,11 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); auto expected_ir = is_built_with_rocm_ ? R"( -; CHECK: initial_value_addr = internal unnamed_addr addrspace({{[0-9]*}}) global [1024 x float] poison, align 4 +; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } +; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison )" : R"( -; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [1 x [1 x [2 x float]]] +; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [4 x [2 x float]] )"; CompileAndVerifyIr(std::move(hlo_module), expected_ir, /*match_optimized_ir=*/true); @@ -835,18 +657,17 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) { } ENTRY reduce.1 { - parameter = f32[4,1048576,1024,1024] parameter(0) + parameter = f32[16,1048576,1024,1024] parameter(0) init_value = f32[] constant(0) - ROOT reduce = f32[4,1048576,1024] reduce(parameter, init_value), dimensions={3}, to_apply=Sum + ROOT reduce = f32[16,1048576,1024] reduce(parameter, init_value), dimensions={3}, to_apply=Sum } )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); - Status status = CompileToExecutable(std::move(hlo_module)).status(); - EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition); - EXPECT_THAT( - status.message(), - ::testing::HasSubstr( - "Number of physical blocks (4294967296) does not fit in an i32")); + absl::Status status = CompileToExecutable(std::move(hlo_module)).status(); + EXPECT_THAT(status.message(), + ::testing::ContainsRegex( + "Kernel '.*' launch needs more blocks [(]4294967296[)] than " + "allowed by hardware [(]2147483647[)]")); } } // namespace diff --git a/xla/service/gpu/tests/gpu_ldg_test.cc b/xla/service/gpu/tests/gpu_ldg_test.cc index d871995f5cdb6..5fbee5e076d78 100644 --- a/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/xla/service/gpu/tests/gpu_ldg_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,12 +22,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/gpu_noalias_test.cc b/xla/service/gpu/tests/gpu_noalias_test.cc index 39398c5ea821e..6d5b520bfc61c 100644 --- a/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/xla/service/gpu/tests/gpu_noalias_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc index 41bb65d0a7a6a..2711e76aebcef 100644 --- a/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc +++ b/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,25 @@ limitations under the License. #include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -35,10 +43,9 @@ namespace m = ::xla::match; class GpuReduceScatterCreatorTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module, - int64_t num_replicas, - int64_t num_partitions, - bool expect_change) { + absl::StatusOr> RunPass( + absl::string_view hlo_module, int64_t num_replicas, + int64_t num_partitions, bool expect_change) { HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/num_replicas, /*num_partitions=*/num_partitions); @@ -50,7 +57,7 @@ class GpuReduceScatterCreatorTest : public HloTestBase { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t AllReduceCount(std::unique_ptr &module) { diff --git a/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc b/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc index e4016acc8f4c5..0276247d3b113 100644 --- a/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc +++ b/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,18 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -43,7 +51,7 @@ ENTRY entry { auto hlo_module = ParseAndReturnVerifiedModule(hlo_string, config).value(); // Verify that compilation succeeded. - StatusOr> executable = + absl::StatusOr> executable = CompileToExecutable(std::move(hlo_module)); TF_EXPECT_OK(executable.status()); } diff --git a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc index 20d84aa7473db..87b2c79723268 100644 --- a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc +++ b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/statusor.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -29,6 +31,10 @@ namespace { class TooManyBlocksTest : public GpuCodegenTest {}; TEST_F(TooManyBlocksTest, FailsWithInvalidStatus) { + // This test ensures that invalid (too large) launch grids are caught + // somewhere in the pipeline. The practical relevance is low, since as of + // 2024, the inputs or outputs have to be way too large to fit on any GPU + // anyway. const char* hlo_text = R"( HloModule primitive_computation_mul.8 @@ -45,14 +51,15 @@ ENTRY primitive_computation_mul.8 { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo_text)); - StatusOr> failed_executable = + absl::StatusOr> failed_executable = backend().compiler()->RunBackend( std::move(optimized_module), backend().default_stream_executor(), backend().default_stream_executor()->GetAllocator()); EXPECT_FALSE(failed_executable.ok()); - EXPECT_THAT(failed_executable.status().ToString(), - ::testing::HasSubstr("Kernel launch needs more blocks")); + EXPECT_THAT( + failed_executable.status().ToString(), + ::testing::ContainsRegex("Kernel '.*fusion.*' launch needs more blocks")); } } // namespace diff --git a/xla/service/gpu/tests/gpu_triton_custom_call_test.cc b/xla/service/gpu/tests/gpu_triton_custom_call_test.cc new file mode 100644 index 0000000000000..52351018c743b --- /dev/null +++ b/xla/service/gpu/tests/gpu_triton_custom_call_test.cc @@ -0,0 +1,251 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace gpu { + +using ::mlir::ArrayRef; +using ::mlir::NamedAttribute; + +namespace { + +std::unique_ptr CreateAddTritonCustomCall( + Shape tuple_shape, HloInstruction* param_0, HloInstruction* param_1) { + mlir::MLIRContext context_; + mlir::Builder builder(&context_); + + // Create the backend_config for the triton custom call. + const std::string kMLIRText = R"( + module { + tt.func public @add_one(%arg0: !tt.ptr {tt.divisibility = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 32 : i32}) { + %0 = tt.get_program_id x : i32 + %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + %cst = arith.constant 1.000000e+00 : f32 + %3 = arith.addf %1, %cst : f32 + %4 = tt.load %arg2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : f32 + %5 = tt.load %arg3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : f32 + tt.return + } + } + )"; + + NamedAttribute name = + builder.getNamedAttr("name", builder.getStringAttr("add_one")); + NamedAttribute ir = + builder.getNamedAttr("ir", builder.getStringAttr(kMLIRText)); + NamedAttribute num_stages = + builder.getNamedAttr("num_stages", builder.getI32IntegerAttr(3)); + NamedAttribute num_warps = + builder.getNamedAttr("num_warps", builder.getI32IntegerAttr(4)); + NamedAttribute grid_x = + builder.getNamedAttr("grid_x", builder.getI32IntegerAttr(1)); + NamedAttribute grid_y = + builder.getNamedAttr("grid_y", builder.getI32IntegerAttr(1)); + NamedAttribute grid_z = + builder.getNamedAttr("grid_z", builder.getI32IntegerAttr(1)); + NamedAttribute debug = + builder.getNamedAttr("debug", builder.getBoolAttr(false)); + + std::vector attributes = { + name, ir, num_stages, num_warps, grid_x, grid_y, grid_z, debug}; + ArrayRef attributesRef(attributes); + mlir::DictionaryAttr backend_config = + mlir::DictionaryAttr::get(&context_, attributesRef); + + // Parse the backend_config into a string. + std::string backend_config_str; + llvm::raw_string_ostream(backend_config_str) << backend_config; + + return HloInstruction::CreateCustomCall(tuple_shape, {param_0, param_1}, + "__gpu$xla.gpu.triton", + backend_config_str); +} + +} // namespace + +class GpuIrEmitterUnnestedTest : public GpuCodegenTest { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } +}; + +TEST_F(GpuIrEmitterUnnestedTest, + EmitTritonCustomCallWithCorrectLoweringAndWithoutNoaliasOrAlignment) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; + } + + // Tests that the lowering of a Triton custom call produces the correct LLVM + // IR, and that the arguments do not specify noalias or alignment attributes. + + HloComputation::Builder computation_builder(TestName()); + + // Create parameters and custom call in the computation builder. + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(computation_builder.Build()); + + // Check that the compiled llvm ir matches the expected lowering of our tt ir. + // We check that the arguments do not specify noalias or alignment attributes, + // as this prevents recompilation based on the alignment of the input buffers. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK: @add_one +; CHECK-NOT: noalias align +; CHECK-SAME: dereferenceable(4) %arg0 +; CHECK-NOT: noalias align +; CHECK-SAME: dereferenceable(4) %arg1 +; CHECK-NOT: noalias align +; CHECK-SAME: dereferenceable(4) %arg2 +; CHECK-NOT: noalias align +; CHECK-SAME: dereferenceable(4) %arg3 +; CHECK-DAG: addrspacecast ptr %arg0 to ptr addrspace(1) +; CHECK-DAG: addrspacecast ptr %arg1 to ptr addrspace(1) +; CHECK-DAG: addrspacecast ptr %arg2 to ptr addrspace(1) +; CHECK-DAG: addrspacecast ptr %arg3 to ptr addrspace(1) +; CHECK: tail call i32 asm sideeffect +; CHECK: tail call i32 asm sideeffect +; CHECK: fadd float +; CHECK-SAME: 1.000000e+00 +; CHECK-DAG: tail call void asm sideeffect +; CHECK-DAG: tail call void asm sideeffect +; CHECK: ret void + )", + /*match_optimized_ir=*/false); +} + +TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) { + if (GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping."; + } + + HloComputation::Builder computation_builder(TestName()); + + // Create parameters and custom call in the computation builder. + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(computation_builder.Build()); + + EXPECT_THAT( + CompileToExecutable(std::move(module), /*run_optimization_passes=*/false), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); +} + +class TritonCustomCallTest : public HloTestBase {}; + +TEST_F(TritonCustomCallTest, NoArgumentDeduplication) { + if (auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + !cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; + } + + // Tests that no argument deduplication is done for Triton kernels. + // + // Triton kernels are compiled on the first call and re-used for all the + // following calls. So, if we are unlucky, we could end up calling the + // compiled kernel with fewer arguments than it expects in the presence + // of argument deduplication. + // + // For example, + // + // * The first call is f(x, y). The arguments are distinct, no deduplication + // is done at compilation time and the compiled kernel expects two + // arguments. + // * The second call is f(x, x). The arguments are deduplicated and we + // call the previously compiled kernel with just x, causing a crash. + + HloComputation::Builder computation_builder(TestName()); + + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + auto* instr_0 = computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, instr_0, instr_0)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(computation_builder.Build()); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/tests/gpu_unrolling_test.cc b/xla/service/gpu/tests/gpu_unrolling_test.cc index 3969034aaa123..8d0363478a75b 100644 --- a/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ limitations under the License. #include +#include "xla/debug_options_flags.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" @@ -50,27 +51,14 @@ TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: @fusion -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK-NOT: fadd -; CHECK: } +; CHECK-LABEL: @{{[a-z_]*}}fusion +; CHECK-NOT: load float +; CHECK-NOT: store float +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { @@ -91,26 +79,13 @@ TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: @wrapped_add -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK: load float -; CHECK: load float -; CHECK: fadd -; CHECK: store float -; CHECK-NOT: fadd -; CHECK: } +; CHECK-NOT: load float +; CHECK-NOT: store float +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) { @@ -243,38 +218,15 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: @fusion -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fadd -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fmul -; CHECK: store float -; CHECK: store float -; CHECK-NOT: store float -; CHECK-NOT: store float -; CHECK: load float -; CHECK: load float +; CHECK-LABEL: @{{[a-z_]*}}fusion ; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fadd -; CHECK: load float -; CHECK: load float -; CHECK-NOT: load float -; CHECK-NOT: load float -; CHECK: fmul -; CHECK: store float -; CHECK: store float -; CHECK-NOT: store float ; CHECK-NOT: store float -; CHECK: } +; CHECK: load <4 x float> +; CHECK: load <4 x float> +; CHECK: store <4 x float> +; CHECK: store <4 x float> )", - /*match_optimized_ir=*/false); + /*match_optimized_ir=*/true); } } // namespace diff --git a/xla/service/gpu/tests/hlo_to_llvm_ir.cc b/xla/service/gpu/tests/hlo_to_llvm_ir.cc deleted file mode 100644 index a3bc1ef8263a1..0000000000000 --- a/xla/service/gpu/tests/hlo_to_llvm_ir.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/buffer_sharing.h" -#include "xla/service/gpu/compile_module_to_llvm_ir.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "xla/service/gpu/target_constants.h" -#include "xla/status.h" - -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_platform_id.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#include "tsl/platform/rocm_rocdl_path.h" -#endif -#include "xla/tests/test_utils.h" -#include "xla/tools/hlo_module_loader.h" -#include "tsl/platform/init_main.h" -#include "tsl/util/command_line_flags.h" - -const char* const kUsage = R"( -This tool reads in an scheduled HloModule from a file, compiles it using the -NVPTX or AMDGPU compiler and prints out the LLVM IR generated by the IR emitter. -The LLVM IR is not optimized by the LLVM pass pipeline, so this tool can be used -to unit test the XLA GPU IR emitters. - -Note that the LLVM IR does not contain the *full* module, but only parts that -will be code generated into PTX/Hsaco. The NVPTX/Hsaco compiler also generates a -GpuExecutable on the side that is not printed. - -When passed the parameter `--ptx`, the LLVM IR will be optimized and PTX -will be emitted and printed instead of the non-optimized LLVM. -By default SM 70 is targeted. But this can be changed with `--sm=SM`.)"; - -namespace { -xla::Status CompileAndPrintLlvmIr(const std::string& hlo_text, - bool generate_ptx, int sm) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - xla::LoadModuleFromData(/*data=*/hlo_text, /*format=*/"hlo")); - - CHECK(hlo_module->has_schedule()); - TF_RETURN_IF_ERROR(VerifyHloModule(hlo_module.get(), - /*layout_sensitive=*/false, - /*allow_mixed_precision=*/true)); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - llvm::LLVMContext llvm_context; - - stream_executor::CudaComputeCapability cuda_compute_capability; - cuda_compute_capability.major = sm / 10; - cuda_compute_capability.minor = sm % 10; - -#if GOOGLE_CUDA - stream_executor::DeviceDescription gpu_device_info = - xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo(cuda_compute_capability); - std::string target_triple = xla::gpu::nvptx::TargetTriple(); - std::string data_layout = xla::gpu::nvptx::DataLayout(); - std::string platform_name = "CUDA"; - stream_executor::Platform::Id platform_id = - stream_executor::cuda::kCudaPlatformId; -#elif TENSORFLOW_USE_ROCM - stream_executor::DeviceDescription gpu_device_info = - xla::gpu::TestGpuDeviceInfo::AMDMI210DeviceInfo(); - std::string target_triple = xla::gpu::amdgpu::TargetTriple(); - std::string data_layout = xla::gpu::amdgpu::DataLayout(); - std::string platform_name = "ROCm"; - stream_executor::Platform::Id platform_id = - stream_executor::rocm::kROCmPlatformId; -#endif - - auto buffer_size_bytes_function = [](const xla::BufferValue& buffer) { - return xla::gpu::GetSizeOfShape(buffer.shape(), /*pointer_size=*/8); - }; - - TF_ASSIGN_OR_RETURN( - xla::gpu::CompileModuleResults results, - xla::gpu::CompileModuleToLlvmIr( - hlo_module.get(), &llvm_context, target_triple, data_layout, - platform_name, platform_id, gpu_device_info, - &xla::gpu::CanShareBufferHint, buffer_size_bytes_function)); - llvm::Module* llvm_module = results.llvm_module.get(); - - if (!generate_ptx) { - llvm_module->print(llvm::outs(), nullptr); - } else { -#if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN( - std::string ptx, - xla::gpu::nvptx::CompileToPtx(llvm_module, cuda_compute_capability, - hlo_module->config().debug_options())); - std::cout << ptx << std::endl; -#elif TENSORFLOW_USE_ROCM - return {absl::StatusCode::kUnimplemented, - "Feature not yet implemented in ROCm"}; -#endif - } -#endif - return xla::OkStatus(); -} - -xla::Status CompileAndPrintLlvmIrFromFile(const std::string& file_name, - bool ptx, int sm) { - std::string full_text; - TF_RETURN_IF_ERROR( - tsl::ReadFileToString(tsl::Env::Default(), file_name, &full_text)); - - std::vector hlo_module_texts = - absl::StrSplit(full_text, "// -----"); - for (const std::string& hlo_module_text : hlo_module_texts) { - TF_RETURN_IF_ERROR(CompileAndPrintLlvmIr(hlo_module_text, ptx, sm)); - } - - return xla::OkStatus(); -} -} // namespace - -int main(int argc, char** argv) { - bool ptx = false; - int sm = 70; - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - flag_list.emplace_back("ptx", &ptx, - "Print PTX instead of not optimized LLVM."); - flag_list.emplace_back("sm", &sm, - "Specify the SM to target (useful only with --ptx)."); - // The usage string includes the message at the top of the file, the - // DebugOptions flags and the flags defined above. - const std::string kUsageString = - absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list)); - bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list); - tsl::port::InitMain(kUsageString.c_str(), &argc, &argv); - if (!parse_ok) { - LOG(QFATAL) << kUsageString; - } - - QCHECK(argc == 2) << "Must specify a single input file"; - TF_CHECK_OK(CompileAndPrintLlvmIrFromFile(argv[1], ptx, sm)); - - return 0; -} diff --git a/xla/service/gpu/tests/in_place_op_test.cc b/xla/service/gpu/tests/in_place_op_test.cc index db028c34c7aad..853aaf1843094 100644 --- a/xla/service/gpu/tests/in_place_op_test.cc +++ b/xla/service/gpu/tests/in_place_op_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "xla/debug_options_flags.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/xla/service/gpu/tests/infeed_test.cc b/xla/service/gpu/tests/infeed_test.cc index 08c0a0d5ac592..933313f39710b 100644 --- a/xla/service/gpu/tests/infeed_test.cc +++ b/xla/service/gpu/tests/infeed_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,19 @@ limitations under the License. #include +#include #include -#include "xla/client/global_data.h" -#include "xla/client/lib/arithmetic.h" +#include "xla/array3d.h" +#include "xla/array4d.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/shape_util.h" +#include "xla/literal_util.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "tsl/platform/env.h" namespace xla { diff --git a/xla/service/gpu/tests/kernel_launch_test.cc b/xla/service/gpu/tests/kernel_launch_test.cc index b2c0f27eba5e2..c073f877e9ed4 100644 --- a/xla/service/gpu/tests/kernel_launch_test.cc +++ b/xla/service/gpu/tests/kernel_launch_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/kernel_reuse.hlo b/xla/service/gpu/tests/kernel_reuse.hlo index c497a09d92c92..41734e06259a0 100644 --- a/xla/service/gpu/tests/kernel_reuse.hlo +++ b/xla/service/gpu/tests/kernel_reuse.hlo @@ -1,9 +1,11 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // All fusions must reuse the same kernel: // CHECK-LABEL: target triple -// CHECK: define [[KERNEL_ANNOTATION]]void -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void +// CHECK-PTX: define void +// CHECK-GCN: define amdgpu_kernel void +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true @@ -30,13 +32,13 @@ fused_computation.2 { ENTRY main { a = f32[5,5]{1,0} parameter(0) - custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.2 = f32[5,5]{1,0} fusion(custom-call), kind=kLoop, calls=fused_computation.2 - custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.1 = f32[5,5]{1,0} fusion(custom-call.1), kind=kLoop, calls=fused_computation.1 - custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion = f32[5,5]{1,0} fusion(custom-call.2), kind=kLoop, calls=fused_computation - custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" ROOT tuple = (f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}) tuple(custom-call, custom-call.1, custom-call.2, custom-call.3) } @@ -44,9 +46,12 @@ ENTRY main { // All (Triton) fusions must reuse the same kernel: // CHECK-LABEL: target triple -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void -// CHECK: define [[KERNEL_ANNOTATION]]void @triton_gemm_dot1( -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void +// CHECK-PTX: define void @triton_gemm_dot1( +// CHECK-GCN: define amdgpu_kernel void @triton_gemm_dot1( +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule t, is_scheduled=true @@ -69,8 +74,8 @@ ENTRY e { p2 = f16[15,19]{1,0} parameter(2) p1 = s8[19,17]{1,0} parameter(1) p0 = f16[15,19]{1,0} parameter(0) - triton_gemm_dot1 = f16[15,17]{1,0} fusion(p3, p2), kind=kCustom, calls=triton_gemm_dot1, backend_config="{kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\"}}" - triton_gemm_dot0 = f16[15,17]{1,0} fusion(p1, p0), kind=kCustom, calls=triton_gemm_dot0, backend_config="{kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\"}}" + triton_gemm_dot1 = f16[15,17]{1,0} fusion(p3, p2), kind=kCustom, calls=triton_gemm_dot1, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\",\"num_ctas\":\"1\"}}}" + triton_gemm_dot0 = f16[15,17]{1,0} fusion(p1, p0), kind=kCustom, calls=triton_gemm_dot0, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"4\",\"num_warps\":\"4\",\"num_ctas\":\"1\"}}}" ROOT tuple = (f16[15,17]{1,0}, f16[15,17]{1,0}) tuple(triton_gemm_dot0, triton_gemm_dot1) } @@ -80,9 +85,12 @@ ENTRY e { // - @fusion_2's %arg0 must have align 16, because we are passing a module input // - @fusion_1's %arg0 must have align 128, because we are passing an internal buffer // CHECK-LABEL: target triple -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion_2(ptr noalias align 16 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void +// CHECK-PTX-DAG: define void @fusion_2(ptr noalias align 16 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_2(ptr noalias align 16 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-DAG: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true @@ -110,11 +118,11 @@ fused_computation.2 { ENTRY main { a = f32[5,5]{1,0} parameter(0) fusion.2 = f32[5,5]{1,0} fusion(a), kind=kLoop, calls=fused_computation.2 - custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.1 = f32[5,5]{1,0} fusion(custom-call.1), kind=kLoop, calls=fused_computation.1 - custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion = f32[5,5]{1,0} fusion(custom-call.2), kind=kLoop, calls=fused_computation - custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" ROOT tuple = (f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}) tuple(custom-call.1, custom-call.2, custom-call.3) } @@ -124,8 +132,10 @@ ENTRY main { // The first has just 2 parameters (1 input, 1 output) and the second has 3 (2 input, 1 output). // All the parameters are noalias, because we are not passing the same argument twice to the kernel. // CHECK-LABEL: target triple -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion_1(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1, ptr noalias align 128 dereferenceable(100) %arg2) +// CHECK-PTX-DAG: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-DAG: define void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN-DAG: define amdgpu_kernel void @fusion_1(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-NOT: define void HloModule KernelReuse, is_scheduled=true @@ -150,13 +160,13 @@ fused_computation.2 { ENTRY main { a = f32[5,5]{1,0} parameter(0) - custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.2 = f32[5,5]{1,0} fusion(custom-call, custom-call), kind=kLoop, calls=fused_computation.2 - custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.1 = f32[5,5]{1,0} fusion(custom-call, custom-call.1), kind=kLoop, calls=fused_computation.1 - custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion = f32[5,5]{1,0} fusion(custom-call.1, custom-call.2), kind=kLoop, calls=fused_computation - custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" ROOT tuple = (f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}) tuple(custom-call, custom-call.1, custom-call.2, custom-call.3) } @@ -168,12 +178,16 @@ ENTRY main { // "!invariant.load" (thanks to ir_array.MarkInvariantOverWholeProgram). // // CHECK-LABEL: target triple -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion_2(ptr noalias align 128 dereferenceable(100) %arg0) +// CHECK-PTX: define void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN: define amdgpu_kernel void @fusion_2(ptr noalias align 128 dereferenceable(100) %{{.*}}) // CHECK-NOT: !invariant.load -// CHECK: define [[KERNEL_ANNOTATION]]void @fusion(ptr noalias align 128 dereferenceable(100) %arg0, ptr noalias align 128 dereferenceable(100) %arg1) -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void +// CHECK-PTX: define void @fusion(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-GCN: define amdgpu_kernel void @fusion(ptr noalias align 128 dereferenceable(100) %{{.*}}, ptr noalias align 128 dereferenceable(100) %{{.*}}) +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void // CHECK: !invariant.load -// CHECK-NOT: define [[KERNEL_ANNOTATION]]void +// CHECK-PTX-NOT: define void +// CHECK-GCN-NOT: define amdgpu_kernel void HloModule KernelReuse, is_scheduled=true @@ -194,13 +208,13 @@ fused_computation.2 { ENTRY main { a = f32[5,5]{1,0} parameter(0) - custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call = f32[5,5]{1,0} custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion.2 = f32[5,5]{1,0} fusion(custom-call), kind=kLoop, calls=fused_computation.2 - custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.1 = f32[5,5]{1,0} custom-call(fusion.2, fusion.2), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" fusion = f32[5,5]{1,0} fusion(custom-call.1), kind=kLoop, calls=fused_computation fusion.1 = f32[5,5]{1,0} fusion(custom-call.1), kind=kLoop, calls=fused_computation.1 - custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" - custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" + custom-call.2 = f32[5,5]{1,0} custom-call(fusion.1, fusion.1), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" + custom-call.3 = f32[5,5]{1,0} custom-call(fusion, fusion), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" // We don't output custom-call, so fusion.2 can change its input. ROOT tuple = (f32[5,5]{1,0}, f32[5,5]{1,0}, f32[5,5]{1,0}) tuple(custom-call.1, custom-call.2, custom-call.3) } diff --git a/xla/service/gpu/tests/launch_dimensions.hlo b/xla/service/gpu/tests/launch_dimensions.hlo index 8e3077a3620f9..bcfa37733f7e6 100644 --- a/xla/service/gpu/tests/launch_dimensions.hlo +++ b/xla/service/gpu/tests/launch_dimensions.hlo @@ -1,12 +1,14 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // This tests that we do not increase the grid launch size when // few_waves is enabled. -// CHECK-LABEL: entry: -// CHECK: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK: ![[ctaid_range]] = !{i32 0, i32 2} -// CHECK: ![[tid_range]] = !{i32 0, i32 1024} +// CHECK-LABEL: define void @wrapped_b +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 1024} HloModule Test, is_scheduled=true @@ -25,12 +27,14 @@ ENTRY main { // This tests that we cap grid launch code when few_waves is enabled. -// CHECK-LABEL: entry: -// CHECK: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @wrapped_b +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule Test, is_scheduled=true @@ -49,15 +53,15 @@ ENTRY main { // This tests that we cap grid launch code when few_waves is enabled // and scalar broadcast are present. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion_3 +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule ScalarBroadcast, is_scheduled=true @@ -80,15 +84,15 @@ ENTRY main { // This tests that we enable few_waves in a simple fusion. It is the baseline // for the tests below. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule SimpleFusion, is_scheduled=true @@ -109,15 +113,15 @@ ENTRY main { // This tests that we keep few_waves enabled for large constants. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule LargeConstant, is_scheduled=true @@ -137,15 +141,15 @@ ENTRY main { // This tests that we disable few_waves if a non-elementwise op is present. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 195313} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 97657} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 256} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 195313} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 97657} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} HloModule NonElementwise, is_scheduled=true @@ -171,15 +175,15 @@ ENTRY main { // - the fusion is not row-vectorizable // It serves as a baseline for the tests below. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 7813} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 3907} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 256} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 7813} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 3907} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} HloModule NoFewWaves, is_scheduled=true @@ -215,15 +219,15 @@ ENTRY main { // - the fusion IS row-vectorizable // In this case, the block count is changed from 7813 to 2000. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 500} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 500} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 500} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 500} HloModule RowVectorizable, is_scheduled=true @@ -256,15 +260,15 @@ ENTRY main { // - the fusion is not row-vectorizable // In this case, the block count is changed from 7813 to 1008. -// CHECK-LABEL: entry: -// CHECK-PTX: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK-GCN: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule ScalarBroadcastFourInputs, is_scheduled=true @@ -296,12 +300,14 @@ ENTRY main { // This tests the GELU kernel. The original kernel that // motivated few_waves implementation. -// CHECK-LABEL: entry: -// CHECK: call i32 [[CTAIDX]](), !range ![[ctaid_range:[0-9]+]] -// CHECK: call i32 [[TIDX]](), !range ![[tid_range:[0-9]+]] -// CHECK-PTX: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN: ![[ctaid_range]] = !{i32 0, i32 1664} -// CHECK: ![[tid_range]] = !{i32 0, i32 128} +// CHECK-LABEL: define void @fusion +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] +// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] +// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] +// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} +// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} +// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} HloModule Test, is_scheduled=true diff --git a/xla/service/gpu/tests/mock_custom_call_test.cc b/xla/service/gpu/tests/mock_custom_call_test.cc index bcb6de5dd13ee..90380c4a5db72 100644 --- a/xla/service/gpu/tests/mock_custom_call_test.cc +++ b/xla/service/gpu/tests/mock_custom_call_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/pad_to_static.hlo b/xla/service/gpu/tests/pad_to_static.hlo index 77273eb930204..6e147df3928c0 100644 --- a/xla/service/gpu/tests/pad_to_static.hlo +++ b/xla/service/gpu/tests/pad_to_static.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py @@ -9,9 +9,11 @@ // CHECK: %[[VAL_4:.*]] = load i32, ptr %[[VAL_3]], align 4 // CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, ptr %[[VAL_1]], i32 40 // CHECK: %[[VAL_6:.*]] = load i32, ptr %[[VAL_5]], align 4 -// CHECK: %[[VAL_7:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_7:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_8:.*]] = icmp eq i32 0, %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = call i32 [[CTAIDX]] +// CHECK-PTX: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_9:.*]] = call i32 @llvm.amdgcn.workgroup.id.x // CHECK: %[[VAL_10:.*]] = icmp eq i32 0, %[[VAL_9]] // CHECK: %[[VAL_11:.*]] = and i1 %[[VAL_8]], %[[VAL_10]] // CHECK: br i1 %[[VAL_11]], label %[[VAL_12:.*]], label %[[VAL_13:.*]] @@ -19,50 +21,54 @@ // CHECK: %[[VAL_15:.*]] = mul i32 1, %[[VAL_2]] // CHECK: %[[VAL_16:.*]] = mul i32 %[[VAL_15]], %[[VAL_4]] // CHECK: %[[VAL_17:.*]] = mul i32 %[[VAL_16]], %[[VAL_6]] -// CHECK: %[[VAL_18:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_19:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_18:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_18:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_19:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_19:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_20:.*]] = mul nuw nsw i32 %[[VAL_18]], 8 // CHECK: %[[VAL_21:.*]] = add nuw nsw i32 %[[VAL_20]], %[[VAL_19]] // CHECK: %[[VAL_22:.*]] = icmp ult i32 %[[VAL_21]], 8 // CHECK: call void @llvm.assume(i1 %[[VAL_22]]) -// CHECK: %[[VAL_23:.*]] = udiv i32 %[[VAL_21]], 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 2 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_21]], 2 -// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 2 -// CHECK: %[[VAL_27:.*]] = udiv i32 %[[VAL_21]], 4 -// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_21]], 8 -// CHECK: br i1 %[[VAL_28]], label %[[VAL_29:.*]], label %[[VAL_30:.*]] -// CHECK: custom_call_2.in_bounds-after: ; preds = %[[VAL_31:.*]], %[[VAL_13]] +// CHECK: %[[VAL_23:.*]] = add nuw nsw i32 %[[VAL_21]], 0 +// CHECK: %[[VAL_24:.*]] = udiv i32 %[[VAL_23]], 1 +// CHECK: %[[VAL_25:.*]] = urem i32 %[[VAL_24]], 2 +// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_23]], 2 +// CHECK: %[[VAL_27:.*]] = urem i32 %[[VAL_26]], 2 +// CHECK: %[[VAL_28:.*]] = udiv i32 %[[VAL_23]], 4 +// CHECK: %[[VAL_29:.*]] = icmp ult i32 %[[VAL_21]], 8 +// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_31:.*]] +// CHECK: custom_call_2.in_bounds-after: ; preds = %[[VAL_32:.*]], %[[VAL_13]] // CHECK: ret void // CHECK: is_thread_0-true: ; preds = %[[VAL_14]] -// CHECK: store i32 %[[VAL_2]], ptr %[[VAL_32:.*]], align 4 -// CHECK: store i32 %[[VAL_4]], ptr %[[VAL_33:.*]], align 4 -// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_34:.*]], align 4 +// CHECK: store i32 %[[VAL_2]], ptr %[[VAL_33:.*]], align 4 +// CHECK: store i32 %[[VAL_4]], ptr %[[VAL_34:.*]], align 4 +// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_35:.*]], align 4 // CHECK: br label %[[VAL_13]] // CHECK: custom_call_2.in_bounds-true: ; preds = %[[VAL_13]] -// CHECK: %[[VAL_35:.*]] = mul nuw nsw i32 %[[VAL_24]], 1 -// CHECK: %[[VAL_36:.*]] = add nuw nsw i32 0, %[[VAL_35]] -// CHECK: %[[VAL_37:.*]] = mul nuw nsw i32 %[[VAL_26]], 2 -// CHECK: %[[VAL_38:.*]] = add nuw nsw i32 %[[VAL_36]], %[[VAL_37]] -// CHECK: %[[VAL_39:.*]] = mul nuw nsw i32 %[[VAL_27]], 4 -// CHECK: %[[VAL_40:.*]] = add nuw nsw i32 %[[VAL_38]], %[[VAL_39]] -// CHECK: %[[VAL_41:.*]] = icmp ult i32 %[[VAL_40]], %[[VAL_17]] -// CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_31]] -// CHECK: custom_call_2.in_dyn_bounds-after: ; preds = %[[VAL_42]], %[[VAL_29]] -// CHECK: br label %[[VAL_30]] -// CHECK: custom_call_2.in_dyn_bounds-true: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_43:.*]] = udiv i32 %[[VAL_40]], 1 -// CHECK: %[[VAL_44:.*]] = urem i32 %[[VAL_43]], %[[VAL_6]] -// CHECK: %[[VAL_45:.*]] = mul i32 1, %[[VAL_6]] -// CHECK: %[[VAL_46:.*]] = udiv i32 %[[VAL_40]], %[[VAL_45]] -// CHECK: %[[VAL_47:.*]] = urem i32 %[[VAL_46]], %[[VAL_4]] -// CHECK: %[[VAL_48:.*]] = mul i32 %[[VAL_45]], %[[VAL_4]] -// CHECK: %[[VAL_49:.*]] = udiv i32 %[[VAL_40]], %[[VAL_48]] -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds i32, ptr %[[VAL_1]], i32 %[[VAL_21]] -// CHECK: %[[VAL_51:.*]] = load i32, ptr %[[VAL_50]], align 4, !invariant.load -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr %[[VAL_53:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_47]], i32 %[[VAL_44]] -// CHECK: store i32 %[[VAL_51]], ptr %[[VAL_52]], align 4 +// CHECK: %[[VAL_36:.*]] = mul nuw nsw i32 %[[VAL_25]], 1 +// CHECK: %[[VAL_37:.*]] = add nuw nsw i32 0, %[[VAL_36]] +// CHECK: %[[VAL_38:.*]] = mul nuw nsw i32 %[[VAL_27]], 2 +// CHECK: %[[VAL_39:.*]] = add nuw nsw i32 %[[VAL_37]], %[[VAL_38]] +// CHECK: %[[VAL_40:.*]] = mul nuw nsw i32 %[[VAL_28]], 4 +// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_39]], %[[VAL_40]] +// CHECK: %[[VAL_42:.*]] = icmp ult i32 %[[VAL_41]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_42]], label %[[VAL_43:.*]], label %[[VAL_32]] +// CHECK: custom_call_2.in_dyn_bounds-after: ; preds = %[[VAL_43]], %[[VAL_30]] // CHECK: br label %[[VAL_31]] +// CHECK: custom_call_2.in_dyn_bounds-true: ; preds = %[[VAL_30]] +// CHECK: %[[VAL_44:.*]] = udiv i32 %[[VAL_41]], 1 +// CHECK: %[[VAL_45:.*]] = urem i32 %[[VAL_44]], %[[VAL_6]] +// CHECK: %[[VAL_46:.*]] = mul i32 1, %[[VAL_6]] +// CHECK: %[[VAL_47:.*]] = udiv i32 %[[VAL_41]], %[[VAL_46]] +// CHECK: %[[VAL_48:.*]] = urem i32 %[[VAL_47]], %[[VAL_4]] +// CHECK: %[[VAL_49:.*]] = mul i32 %[[VAL_46]], %[[VAL_4]] +// CHECK: %[[VAL_50:.*]] = udiv i32 %[[VAL_41]], %[[VAL_49]] +// CHECK: %[[VAL_51:.*]] = getelementptr i32, ptr %[[VAL_1]], i32 %[[VAL_21]] +// CHECK: %[[VAL_52:.*]] = getelementptr inbounds i32, ptr %[[VAL_51]], i32 0 +// CHECK: %[[VAL_53:.*]] = load i32, ptr %[[VAL_52]], align 4, !invariant.load +// CHECK: %[[VAL_54:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr %[[VAL_55:.*]], i32 0, i32 %[[VAL_50]], i32 %[[VAL_48]], i32 %[[VAL_45]] +// CHECK: store i32 %[[VAL_53]], ptr %[[VAL_54]], align 4 +// CHECK: br label %[[VAL_32]] HloModule PadToStatic, is_scheduled=true diff --git a/xla/service/gpu/tests/parallel_reduction_test.cc b/xla/service/gpu/tests/parallel_reduction_test.cc index 39f81b7895f32..30e37759c0519 100644 --- a/xla/service/gpu/tests/parallel_reduction_test.cc +++ b/xla/service/gpu/tests/parallel_reduction_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,11 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include - +#include +#include + +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/filecheck.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/pred_arithmetic_test.cc b/xla/service/gpu/tests/pred_arithmetic_test.cc index d29898e27ba74..17fa51972ce4d 100644 --- a/xla/service/gpu/tests/pred_arithmetic_test.cc +++ b/xla/service/gpu/tests/pred_arithmetic_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/reduce_atomic_min.hlo b/xla/service/gpu/tests/reduce_atomic_min.hlo new file mode 100644 index 0000000000000..cff4ff8a79a5b --- /dev/null +++ b/xla/service/gpu/tests/reduce_atomic_min.hlo @@ -0,0 +1,310 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +// Check that for "min" we are still using atomics (CAS loop). + +HloModule MinReduce, is_scheduled=true + +Min { + x.1 = f32[] parameter(0) + y.1 = f32[] parameter(1) + ROOT min.1 = f32[] minimum(x.1, y.1) +} + +fused_computation { + param_0 = f32[300000]{0} parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce.1 = f32[] reduce(f32[300000]{0} param_0, f32[] param_1), dimensions={0}, to_apply=Min +} + +ENTRY reduce.1 { + parameter = f32[300000] parameter(0) + init_value = f32[] constant(0) + ROOT wrapped_reduce = f32[] fusion(f32[300000]{0} parameter, f32[] init_value), kind=kInput, calls=fused_computation +} + +// CHECK-LABEL: entry: +// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_1:.*]] = zext i32 %[[VAL_0]] to i64 +// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 +// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_3:.*]] = zext i32 %[[VAL_2]] to i64 +// CHECK: %[[VAL_4:.*]] = mul nuw nsw i64 %[[VAL_1]], 1 +// CHECK: %[[VAL_5:.*]] = add nuw nsw i64 %[[VAL_4]], %[[VAL_3]] +// CHECK: %[[VAL_6:.*]] = icmp ult i64 %[[VAL_5]], 1 +// CHECK: call void @llvm.assume(i1 %[[VAL_6]]) +// CHECK: %[[VAL_7:.*]] = add nuw nsw i64 %[[VAL_5]], 0 +// CHECK: %[[VAL_8:.*]] = icmp ult i64 %[[VAL_5]], 1 +// CHECK: br i1 %[[VAL_8]], label %[[VAL_9:.*]], label %[[VAL_10:.*]] +// CHECK: wrapped_reduce.in_bounds-after: ; preds = %[[VAL_9]], %[[VAL_11:.*]] +// CHECK: ret void +// CHECK: wrapped_reduce.in_bounds-true: ; preds = %[[VAL_11]] +// CHECK: %[[VAL_12:.*]] = load float, ptr %[[VAL_13:.*]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_12]], ptr %[[VAL_14:.*]], align 4 +// CHECK: br label %[[VAL_10]] +// CHECK: entry: +// CHECK: %[[VAL_15:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_16:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_17:.*]] = alloca float, align 4 +// CHECK: %[[VAL_18:.*]] = alloca float, align 4 +// CHECK: %[[VAL_19:.*]] = alloca float, align 4 +// CHECK: %[[VAL_20:.*]] = alloca float, align 4 +// CHECK: %[[VAL_21:.*]] = alloca float, align 4 +// CHECK: %[[VAL_22:.*]] = alloca float, align 4 +// CHECK: %[[VAL_23:.*]] = alloca float, align 4 +// CHECK: %[[VAL_24:.*]] = alloca float, align 4 +// CHECK: %[[VAL_25:.*]] = alloca float, align 4 +// CHECK: %[[VAL_26:.*]] = alloca float, align 4 +// CHECK: %[[VAL_27:.*]] = alloca float, align 4 +// CHECK: %[[VAL_28:.*]] = alloca float, align 4 +// CHECK: %[[VAL_29:.*]] = alloca float, align 4 +// CHECK: %[[VAL_30:.*]] = alloca float, align 4 +// CHECK: %[[VAL_31:.*]] = alloca float, align 4 +// CHECK: %[[VAL_32:.*]] = alloca float, align 4 +// CHECK: %[[VAL_33:.*]] = alloca float, align 4 +// CHECK: %[[VAL_34:.*]] = alloca float, align 4 +// CHECK: %[[VAL_35:.*]] = alloca float, align 4 +// CHECK: %[[VAL_36:.*]] = alloca float, align 4 +// CHECK: %[[VAL_37:.*]] = alloca float, align 4 +// CHECK: %[[VAL_38:.*]] = alloca float, align 4 +// CHECK: %[[LOOP3_I_2:loop3.invar_address.*]] = alloca i32, align 4 +// CHECK: %[[LOOP2_I_2:loop2.invar_address.*]] = alloca i32, align 4 +// CHECK: %[[VAL_42:return_buffer.*]] = alloca float, align 4 +// CHECK: %[[VAL_40:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_43:.*]] = alloca i32, align 4 +// CHECK: %partial_reduction_result = alloca float, align 4 +// CHECK: %reduction_input_address = alloca float, align 4 +// CHECK-PTX: %[[VAL_47:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !4 +// CHECK-GCN: %[[VAL_47:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 +// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] +// CHECK: reduce-group-0-after: ; preds = %[[VAL_51:.*]], %[[VAL_52:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_53:.*]] = load float, ptr %[[VAL_54:.*]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_53]], ptr %partial_reduction_result, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !6 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !7 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 1024 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_63:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VECTOR_OFFSET:.*]] = urem i32 %[[VAL_63]], 1 +// CHECK: %[[VAL_63_2:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_64:.*]] = urem i32 %[[VAL_63_2]], 19 +// CHECK: %[[VAL_65:.*]] = udiv i32 %block.id.x, 19 +// CHECK: %[[VAL_66:.*]] = urem i32 %[[VAL_65]], 1 +// CHECK: %[[VAL_67:.*]] = udiv i32 %block.id.x, 19 +// CHECK: %[[VAL_68:.*]] = icmp eq i32 %[[VAL_64]], 18 +// CHECK: %tile_bound.2 = select i1 %[[VAL_68]], i32 2544, i32 8192 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_67]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_66]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_64]], 8192 +// CHECK: %tile_origin.3 = mul i32 %[[VECTOR_OFFSET]], 2 +// CHECK: %[[VAL_81:.*]] = icmp eq i32 8192, %tile_bound.2 +// CHECK: br i1 %[[VAL_81]], label %[[VAL_82:.*]], label %[[VAL_83:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_84:.*]], %[[VAL_85:.*]] +// CHECK: %[[VAL_86:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 16, i32 31) +// CHECK: store float %[[VAL_87]], ptr %[[VAL_37]], align 4 +// CHECK: call void @[[MIN:Min.*]](ptr %partial_reduction_result, ptr %[[VAL_37]], ptr %[[VAL_36]]) +// CHECK: %[[VAL_88:.*]] = load float, ptr %[[VAL_36]], align 4 +// CHECK: store float %[[VAL_88]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_89:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_90:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_89]], i32 8, i32 31) +// CHECK: store float %[[VAL_90]], ptr %[[VAL_35]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_35]], ptr %[[VAL_34]]) +// CHECK: %[[VAL_91:.*]] = load float, ptr %[[VAL_34]], align 4 +// CHECK: store float %[[VAL_91]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_92:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) +// CHECK: store float %[[VAL_93]], ptr %[[VAL_33]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_33]], ptr %[[VAL_32]]) +// CHECK: %[[VAL_94:.*]] = load float, ptr %[[VAL_32]], align 4 +// CHECK: store float %[[VAL_94]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_95:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_96:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_95]], i32 2, i32 31) +// CHECK: store float %[[VAL_96]], ptr %[[VAL_31]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_31]], ptr %[[VAL_30]]) +// CHECK: %[[VAL_97:.*]] = load float, ptr %[[VAL_30]], align 4 +// CHECK: store float %[[VAL_97]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_98:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 1, i32 31) +// CHECK: store float %[[VAL_99]], ptr %[[VAL_29]], align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_29]], ptr %[[VAL_28]]) +// CHECK: %[[VAL_100:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: store float %[[VAL_100]], ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_101:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: br i1 true, label %[[VAL_105:.*]], label %[[VAL_51]] +// CHECK: thread_in_bounds-after: +// CHECK: br label %[[VAL_50]] +// CHECK: is_full_tile-true: +// CHECK: store i32 0, ptr %[[VAL_43]], align 4 +// CHECK: br label %[[VAL_107:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_108:.*]], %[[VAL_82]] +// CHECK: %[[VAL_109:.*]] = load i32, ptr %[[VAL_43]], align 4 +// CHECK: %[[VAL_110:.*]] = icmp uge i32 %[[VAL_109]], 8 +// CHECK: br i1 %[[VAL_110]], label %loop2.loop_exit, label %loop2.loop_body +// CHECK: loop2.loop_body: ; preds = %[[VAL_107]] +// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_109]], 1 +// CHECK: store i32 %[[VAL_111]], ptr %[[VAL_43]], align 4 +// CHECK: %[[VAL_112:.*]] = icmp eq i32 %[[VAL_109]], 0 +// CHECK: %[[OFFSET_2:.*]] = add i32 %loop2.indvar, %thread.id.2 +// CHECK: store i32 0, ptr %loop3.invar_address, align 4 +// CHECK: br label %loop3.loop_header +// CHECK: loop3.loop_header: +// CHECK: %loop3.indvar = load i32, ptr %loop3.invar_address, align 4 +// CHECK: %[[LOOP3_OOB:.*]] = icmp uge i32 %loop3.indvar, 2 +// CHECK: br i1 %[[LOOP3_OOB]], label %loop3.loop_exit, label %loop3.loop_body +// CHECK: loop3.loop_body: +// CHECK: %[[LOOP3_INC:.*]] = add nuw nsw i32 %loop3.indvar, 1 +// CHECK: store i32 %[[LOOP3_INC]], ptr %loop3.invar_address, align 4 +// CHECK: %[[START_0:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[START_1:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[START_2:.*]] = add i32 %tile_origin.2, %[[OFFSET_2]] +// CHECK: %[[START_3:.*]] = add i32 %tile_origin.3, %loop3.indvar +// CHECK: %[[VAL_113:.*]] = mul nuw nsw i32 %[[START_3]], 1 +// CHECK: %[[VAL_114:.*]] = add nuw nsw i32 0, %[[VAL_113]] +// CHECK: %[[VAL_115:.*]] = mul nuw nsw i32 %[[START_2]], 2 +// CHECK: %[[VAL_116:.*]] = add nuw nsw i32 %[[VAL_114]], %[[VAL_115]] +// CHECK: %[[VAL_119:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120:.*]], i32 0, i32 %[[VAL_116]] +// CHECK: %[[VAL_121:.*]] = load float, ptr %[[VAL_119]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_121]], ptr %reduction_input_address, align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_42]]) +// CHECK: %[[VAL_123:.*]] = load float, ptr %[[VAL_42]], align 4 +// CHECK: store float %[[VAL_123]], ptr %partial_reduction_result, align 4 +// CHECK: br label %loop3.loop_header +// CHECK: loop3.loop_exit: +// CHECK: br label %loop2.loop_header +// CHECK: loop2.loop_exit: +// CHECK: br label %is_full_tile-after +// CHECK: is_full_tile-false: +// CHECK: store i32 0, ptr %[[LOOP2_I_2]], align 4 +// CHECK: br label %[[VAL_134:.*]] +// CHECK: loop2.loop_header4: +// CHECK: %[[VAL_136:.*]] = load i32, ptr %[[LOOP2_I_2]], align 4 +// CHECK: %[[VAL_137:.*]] = icmp uge i32 %[[VAL_136]], 8 +// CHECK: br i1 %[[VAL_137]], label %[[VAL_84]], label %[[VAL_138:.*]] +// CHECK: loop2.loop_body5: +// CHECK: %[[VAL_139:.*]] = add nuw nsw i32 %[[VAL_136]], 1 +// CHECK: store i32 %[[VAL_139]], ptr %[[LOOP2_I_2]], align 4 +// CHECK: %[[VAL_140:.*]] = icmp eq i32 %[[VAL_136]], 0 +// CHECK: %[[VAL_141:.*]] = add i32 %[[VAL_136]], %thread.id.2 +// CHECK: %[[VAL_144:.*]] = icmp ult i32 %[[VAL_141]], %tile_bound.2 +// CHECK: br i1 %[[VAL_144]], label %x_in_tile-true, label %x_in_tile-after +// CHECK: x_in_tile-after: +// CHECK: br label %loop2.loop_header4 +// CHECK: loop2.loop_exit3: +// CHECK: br label %is_full_tile-after +// CHECK: x_in_tile-true: ; preds = %[[VAL_138]] +// CHECK: store i32 0, ptr %[[LOOP3_I_2]], align 4 +// CHECK: br label %loop3.loop_header10 +// CHECK: loop3.loop_header10: +// CHECK: %[[VAL_145:.*]] = load i32, ptr %[[LOOP3_I_2]], align 4 +// CHECK: %[[VAL_146:.*]] = icmp uge i32 %[[VAL_145]], 2 +// CHECK: br i1 %[[VAL_146]], label %loop3.loop_exit9, label %loop3.loop_body11 +// CHECK: loop3.loop_body11: +// CHECK: %[[VAL_147:.*]] = add nuw nsw i32 %[[VAL_145]], 1 +// CHECK: store i32 %[[VAL_147]], ptr %[[LOOP3_I_2]], align 4 +// CHECK: %[[IDX0:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[IDX1:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[IDX2:.*]] = add i32 %tile_origin.2, %[[VAL_141]] +// CHECK: %[[IDX3:.*]] = add i32 %tile_origin.3, %[[VAL_145]] +// CHECK: %[[VAL_148:.*]] = mul nuw nsw i32 %[[IDX3]], 1 +// CHECK: %[[VAL_149:.*]] = add nuw nsw i32 0, %[[VAL_148]] +// CHECK: %[[VAL_150:.*]] = mul nuw nsw i32 %[[IDX2]], 2 +// CHECK: %[[VAL_151:.*]] = add nuw nsw i32 %[[VAL_149]], %[[VAL_150]] +// CHECK: %[[VAL_155:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120]], i32 0, i32 %[[VAL_151]] +// CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_155]], align 4, !invariant.load !5 +// CHECK: store float %[[VAL_156]], ptr %reduction_input_address, align 4 +// CHECK: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_38]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr %[[VAL_38]], align 4 +// CHECK: store float %[[VAL_158]], ptr %partial_reduction_result, align 4 +// CHECK: br label %loop3.loop_header10 +// CHECK: loop3.loop_exit9: +// CHECK: br label %x_in_tile-after +// CHECK: thread_in_bounds-true: +// CHECK: %[[VAL_166:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_166]], label %[[VAL_167:.*]], label %[[VAL_168:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_167]], %[[VAL_105]] +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_169:.*]] = icmp eq i32 %[[VAL_101]], 0 +// CHECK: br i1 %[[VAL_169]], label %inter_warp_reduce-true, label %inter_warp_reduce-after +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_171:.*]], %[[VAL_168]] +// CHECK: br label %[[VAL_51]] +// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_105]] +// CHECK: %[[VAL_172:.*]] = load float, ptr %partial_reduction_result, align 4 +// CHECK: %[[VAL_173:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_101]] +// CHECK: %[[VAL_174:.*]] = addrspacecast ptr addrspace(3) %[[VAL_173]] to ptr +// CHECK: store float %[[VAL_172]], ptr %[[VAL_174]], align 4 +// CHECK: br label %[[VAL_168]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_168]] +// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id +// CHECK: %[[VAL_176:.*]] = addrspacecast ptr addrspace(3) %[[VAL_175]] to ptr +// CHECK: store float %[[VAL_53]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_177:.*]] = icmp ult i32 %thread.id.2, 32 +// CHECK: %[[VAL_178:.*]] = select i1 %[[VAL_177]], ptr %[[VAL_176]], ptr %[[VAL_27]] +// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_180:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_179]], i32 16, i32 31) +// CHECK: store float %[[VAL_180]], ptr %[[VAL_26]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_26]], ptr %[[VAL_25]]) +// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_25]], align 4 +// CHECK: store float %[[VAL_181]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_182:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_183:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_182]], i32 8, i32 31) +// CHECK: store float %[[VAL_183]], ptr %[[VAL_24]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_24]], ptr %[[VAL_23]]) +// CHECK: %[[VAL_184:.*]] = load float, ptr %[[VAL_23]], align 4 +// CHECK: store float %[[VAL_184]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_186:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_185]], i32 4, i32 31) +// CHECK: store float %[[VAL_186]], ptr %[[VAL_22]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_22]], ptr %[[VAL_21]]) +// CHECK: %[[VAL_187:.*]] = load float, ptr %[[VAL_21]], align 4 +// CHECK: store float %[[VAL_187]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_188:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_189:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_188]], i32 2, i32 31) +// CHECK: store float %[[VAL_189]], ptr %[[VAL_20]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_20]], ptr %[[VAL_19]]) +// CHECK: %[[VAL_190:.*]] = load float, ptr %[[VAL_19]], align 4 +// CHECK: store float %[[VAL_190]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_191:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_192:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_191]], i32 1, i32 31) +// CHECK: store float %[[VAL_192]], ptr %[[VAL_18]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_18]], ptr %[[VAL_17]]) +// CHECK: %[[VAL_193:.*]] = load float, ptr %[[VAL_17]], align 4 +// CHECK: store float %[[VAL_193]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_194:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_194]], label %[[VAL_195:.*]], label %[[VAL_171]] +// CHECK: reduction_write_output-after: +// CHECK: br label %inter_warp_reduce-after +// CHECK: reduction_write_output-true: +// CHECK: %[[VAL_200:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_201:.*]] = load i32, ptr %[[VAL_202:.*]], align 4 +// CHECK: store i32 %[[VAL_201]], ptr %[[VAL_16]], align 4 +// CHECK: br label %[[VAL_203:.*]] +// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_204:.*]], %[[VAL_203]] +// CHECK: br label %[[VAL_171]] +// CHECK: atomic_op_loop_body: ; preds = %[[VAL_204]], %[[VAL_195]] +// CHECK: %[[VAL_205:.*]] = load i32, ptr %[[VAL_16]], align 4 +// CHECK: store i32 %[[VAL_205]], ptr %[[VAL_15]], align 4 +// CHECK: call void @[[MIN]](ptr %[[VAL_15]], ptr %[[VAL_178]], ptr %[[VAL_15]]) +// CHECK: %[[VAL_206:.*]] = load i32, ptr %[[VAL_15]], align 4 +// CHECK: %[[VAL_207:.*]] = icmp eq i32 %[[VAL_205]], %[[VAL_206]] +// CHECK: br i1 %[[VAL_207]], label %atomic_op_loop_exit, label %atomic_op_loop_cas +// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_203]] +// CHECK: %[[VAL_208:.*]] = cmpxchg ptr %[[VAL_202]], i32 %[[VAL_205]], i32 %[[VAL_206]] seq_cst seq_cst, align 4 +// CHECK: %[[VAL_209:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 0 +// CHECK: store i32 %[[VAL_209]], ptr %[[VAL_16]], align 4 +// CHECK: %[[VAL_210:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 1 +// CHECK: br i1 %[[VAL_210]], label %atomic_op_loop_exit, label %atomic_op_loop_body +// CHECK: entry: +// CHECK: %[[VAL_211:.*]] = alloca float, align 4 +// CHECK: %[[VAL_212:.*]] = load float, ptr %[[VAL_213:.*]], align 4 +// CHECK: %[[VAL_214:.*]] = load float, ptr %[[VAL_215:.*]], align 4 +// CHECK: %[[VAL_216:.*]] = call float @llvm.minimum.f32(float %[[VAL_212]], float %[[VAL_214]]) +// CHECK: store float %[[VAL_216]], ptr %[[VAL_211]], align 4 +// CHECK: %[[VAL_217:.*]] = load float, ptr %[[VAL_211]], align 4 +// CHECK: store float %[[VAL_217]], ptr %[[VAL_218:.*]], align 4 +// CHECK: ret void diff --git a/xla/service/gpu/tests/reduce_column_layout_change.hlo b/xla/service/gpu/tests/reduce_column_layout_change.hlo new file mode 100644 index 0000000000000..4c90b12e02ca7 --- /dev/null +++ b/xla/service/gpu/tests/reduce_column_layout_change.hlo @@ -0,0 +1,196 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule reduce_with_layout_change, is_scheduled=true + +reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) +} + +fused_computation { + arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[16,32,4,3,12]{1,3,2,0,4} reduce(arg0, constant0), dimensions={0,1,2}, to_apply=reduction0 +} + +ENTRY kernel_entry { + arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) + ROOT fusion = f32[16,32,4,3,12]{1,3,2,0,4} fusion(arg0), kind=kInput, calls=fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca float, align 4 +// CHECK: %[[VAL_1:.*]] = alloca float, align 4 +// CHECK: %[[VAL_2:.*]] = alloca float, align 4 +// CHECK: %[[VAL_3:.*]] = alloca float, align 4 +// CHECK: %[[VAL_4:.*]] = alloca float, align 4 +// CHECK: %[[VAL_5:.*]] = alloca float, align 4 +// CHECK: %[[VAL_6:.*]] = alloca float, align 4 +// CHECK: %[[VAL_7:.*]] = alloca float, align 4 +// CHECK: %[[VAL_8:.*]] = alloca float, align 4 +// CHECK: %[[VAL_9:.*]] = alloca float, align 4 +// CHECK: %[[VAL_10:.*]] = alloca float, align 4 +// CHECK: %[[VAL_11:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_13:.*]] = alloca float, align 4 +// CHECK: %[[VAL_14:.*]] = alloca float, align 4 +// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_16:.*]] = icmp eq i32 %[[VAL_15]], 0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: reduce-group-0-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] +// CHECK: %[[VAL_21:.*]] = load float, ptr @0, align 4 +// CHECK: store float %[[VAL_21]], ptr %[[VAL_13]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 2304 +// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 2304 +// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 1 +// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 2304 +// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1152, i32 4096 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_27]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 +// CHECK: store i32 %thread.id.1, ptr %[[VAL_12]], align 4 +// CHECK: br label %[[VAL_29:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] +// CHECK: %[[VAL_31:.*]] = load i32, ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 +// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 +// CHECK: store i32 %[[VAL_35]], ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %thread.id.1 +// CHECK: store i32 0, ptr %[[VAL_11]], align 4 +// CHECK: br label %[[VAL_37:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] +// CHECK: %[[VAL_39:.*]] = load i32, ptr %[[VAL_11]], align 4 +// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 +// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 +// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_11]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp eq i32 %[[VAL_39]], 0 +// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 +// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] +// CHECK: br label %[[VAL_37]], !llvm.loop !5 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] +// CHECK: br label %[[VAL_29]], !llvm.loop !8 +// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_47:.*]] = load float, ptr %[[VAL_13]], align 4 +// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 +// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr +// CHECK: store float %[[VAL_47]], ptr %[[VAL_49]], align 4 +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 +// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr +// CHECK: %[[VAL_52:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_53:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_52]], i32 16, i32 31) +// CHECK: store float %[[VAL_53]], ptr %[[VAL_9]], align 4 +// CHECK: call void @[[REDUCTION0:reduction0.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_8]], align 4 +// CHECK: store float %[[VAL_54]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_55:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_56:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_55]], i32 8, i32 31) +// CHECK: store float %[[VAL_56]], ptr %[[VAL_7]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK: %[[VAL_57:.*]] = load float, ptr %[[VAL_6]], align 4 +// CHECK: store float %[[VAL_57]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_58:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_59:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_58]], i32 4, i32 31) +// CHECK: store float %[[VAL_59]], ptr %[[VAL_5]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK: %[[VAL_60:.*]] = load float, ptr %[[VAL_4]], align 4 +// CHECK: store float %[[VAL_60]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_62:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_61]], i32 2, i32 31) +// CHECK: store float %[[VAL_62]], ptr %[[VAL_3]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK: %[[VAL_63:.*]] = load float, ptr %[[VAL_2]], align 4 +// CHECK: store float %[[VAL_63]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_64:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_65:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_64]], i32 1, i32 31) +// CHECK: store float %[[VAL_65]], ptr %[[VAL_1]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_66:.*]] = load float, ptr %[[VAL_0]], align 4 +// CHECK: store float %[[VAL_66]], ptr %[[VAL_51]], align 4 +// CHECK: %[[VAL_67:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK: %[[VAL_68:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK: %[[VAL_69:.*]] = and i1 %[[VAL_67]], %[[VAL_68]] +// CHECK: %[[VAL_70:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: %[[VAL_71:.*]] = and i1 %[[VAL_69]], %[[VAL_70]] +// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_19]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_72]], %[[VAL_33]] +// CHECK: br label %[[VAL_18]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_73:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_74:.*]] = add i32 %tile_origin.1, %[[VAL_31]] +// CHECK: %[[VAL_75:.*]] = add i32 %tile_origin.2, %[[VAL_44]] +// CHECK: %[[VAL_76:.*]] = mul nuw nsw i32 %[[VAL_75]], 1 +// CHECK: %[[VAL_77:.*]] = add nuw nsw i32 0, %[[VAL_76]] +// CHECK: %[[VAL_78:.*]] = urem i32 %[[VAL_77]], 12 +// CHECK: %[[VAL_79:.*]] = udiv i32 %[[VAL_77]], 12 +// CHECK: %[[VAL_80:.*]] = urem i32 %[[VAL_79]], 3 +// CHECK: %[[VAL_81:.*]] = udiv i32 %[[VAL_79]], 3 +// CHECK: %[[VAL_82:.*]] = urem i32 %[[VAL_81]], 4 +// CHECK: %[[VAL_83:.*]] = udiv i32 %[[VAL_81]], 4 +// CHECK: %[[VAL_84:.*]] = urem i32 %[[VAL_83]], 32 +// CHECK: %[[VAL_85:.*]] = udiv i32 %[[VAL_83]], 32 +// CHECK: %[[VAL_86:.*]] = udiv i32 %[[VAL_85]], 16 +// CHECK: %[[VAL_87:.*]] = mul nuw nsw i32 %[[VAL_74]], 1 +// CHECK: %[[VAL_88:.*]] = add nuw nsw i32 0, %[[VAL_87]] +// CHECK: %[[VAL_89:.*]] = urem i32 %[[VAL_88]], 32 +// CHECK: %[[VAL_90:.*]] = udiv i32 %[[VAL_88]], 32 +// CHECK: %[[VAL_91:.*]] = urem i32 %[[VAL_90]], 3 +// CHECK: %[[VAL_92:.*]] = udiv i32 %[[VAL_90]], 3 +// CHECK: %[[VAL_93:.*]] = udiv i32 %[[VAL_92]], 12 +// CHECK: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_73]], 1 +// CHECK: %[[VAL_95:.*]] = add nuw nsw i32 0, %[[VAL_94]] +// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [12 x [3 x [32 x [16 x [32 x [4 x [3 x [12 x float]]]]]]]], ptr %[[VAL_97:.*]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_91]], i32 %[[VAL_89]], i32 %[[VAL_85]], i32 %[[VAL_84]], i32 %[[VAL_82]], i32 %[[VAL_80]], i32 %[[VAL_78]] +// CHECK: %[[VAL_98:.*]] = load float, ptr %[[VAL_96]], align 4, !invariant.load !9 +// CHECK: store float %[[VAL_98]], ptr %[[VAL_14]], align 4 +// CHECK: call void @[[REDUCTION0]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) +// CHECK: %[[VAL_99:.*]] = load float, ptr %[[VAL_10]], align 4 +// CHECK: store float %[[VAL_99]], ptr %[[VAL_13]], align 4 +// CHECK: br label %[[VAL_38]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] +// CHECK: %[[VAL_100:.*]] = add i32 %tile_origin.2, %thread.id.1 +// CHECK: %[[VAL_101:.*]] = mul nuw nsw i32 %[[VAL_100]], 1 +// CHECK: %[[VAL_102:.*]] = add nuw nsw i32 0, %[[VAL_101]] +// CHECK: %[[VAL_103:.*]] = urem i32 %[[VAL_102]], 12 +// CHECK: %[[VAL_104:.*]] = udiv i32 %[[VAL_102]], 12 +// CHECK: %[[VAL_105:.*]] = urem i32 %[[VAL_104]], 3 +// CHECK: %[[VAL_106:.*]] = udiv i32 %[[VAL_104]], 3 +// CHECK: %[[VAL_107:.*]] = urem i32 %[[VAL_106]], 4 +// CHECK: %[[VAL_108:.*]] = udiv i32 %[[VAL_106]], 4 +// CHECK: %[[VAL_109:.*]] = urem i32 %[[VAL_108]], 32 +// CHECK: %[[VAL_110:.*]] = udiv i32 %[[VAL_108]], 32 +// CHECK: %[[VAL_111:.*]] = udiv i32 %[[VAL_110]], 16 +// CHECK: %[[VAL_112:.*]] = mul nuw nsw i32 %tile_origin.0, 1 +// CHECK: %[[VAL_113:.*]] = add nuw nsw i32 0, %[[VAL_112]] +// CHECK: %[[VAL_114:.*]] = getelementptr inbounds [12 x [16 x [4 x [3 x [32 x float]]]]], ptr %[[VAL_115:.*]], i32 0, i32 %[[VAL_103]], i32 %[[VAL_110]], i32 %[[VAL_107]], i32 %[[VAL_105]], i32 %[[VAL_109]] +// CHECK: %[[VAL_116:.*]] = load float, ptr %[[VAL_51]], align 4 +// CHECK: store float %[[VAL_116]], ptr %[[VAL_114]], align 4 +// CHECK: br label %[[VAL_19]] +// CHECK: entry: +// CHECK: %[[VAL_117:.*]] = alloca float, align 4 +// CHECK: %[[VAL_118:.*]] = load float, ptr %[[VAL_119:.*]], align 4 +// CHECK: %[[VAL_120:.*]] = load float, ptr %[[VAL_121:.*]], align 4 +// CHECK: %[[VAL_122:.*]] = fadd float %[[VAL_118]], %[[VAL_120]] +// CHECK: store float %[[VAL_122]], ptr %[[VAL_117]], align 4 +// CHECK: %[[VAL_123:.*]] = load float, ptr %[[VAL_117]], align 4 +// CHECK: store float %[[VAL_123]], ptr %[[VAL_124:.*]], align 4 +// CHECK: ret void diff --git a/xla/service/gpu/tests/reduce_f64_column.hlo b/xla/service/gpu/tests/reduce_f64_column.hlo new file mode 100644 index 0000000000000..abfab46233238 --- /dev/null +++ b/xla/service/gpu/tests/reduce_f64_column.hlo @@ -0,0 +1,256 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule m, is_scheduled=true + +add { + a = f64[] parameter(0) + b = f64[] parameter(1) + ROOT out = f64[] add(a, b) +} + +fused_computation { + p1 = f64[1024,1024]{1,0} parameter(0) + p2 = f64[1024,1024]{1,0} parameter(1) + s = pred[1024,1024]{1,0} parameter(2) + p = f64[1024,1024]{1,0} select(s, p1, p2) + z = f64[] constant(0) + ROOT out = f64[1024]{0} reduce(p, z), to_apply=add, dimensions={0} +} + +ENTRY e { + p1 = f64[1024,1024]{1,0} parameter(0) + p2 = f64[1024,1024]{1,0} parameter(1) + s = pred[1024,1024]{1,0} parameter(2) + ROOT f = f64[1024]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation +} + +// CHECK: @shared_cache = private addrspace(3) global [32 x [33 x double]] + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca double, align 8 +// CHECK: %[[VAL_1:.*]] = alloca double, align 8 +// CHECK: %[[VAL_2:.*]] = alloca double, align 8 +// CHECK: %[[VAL_3:.*]] = alloca double, align 8 +// CHECK: %[[VAL_4:.*]] = alloca double, align 8 +// CHECK: %[[VAL_5:.*]] = alloca double, align 8 +// CHECK: %[[VAL_6:.*]] = alloca double, align 8 +// CHECK: %[[VAL_7:.*]] = alloca double, align 8 +// CHECK: %[[VAL_8:.*]] = alloca double, align 8 +// CHECK: %[[VAL_9:.*]] = alloca double, align 8 +// CHECK: %[[VAL_10:.*]] = alloca double, align 8 +// CHECK: %[[VAL_11:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_13:.*]] = alloca double, align 8 +// CHECK: %[[VAL_14:.*]] = alloca double, align 8 +// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_16:.*]] = icmp eq i32 %[[VAL_15]], 0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: reduce-group-0-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] +// CHECK: %[[VAL_21:.*]] = load double, ptr @0, align 8 +// CHECK: store double %[[VAL_21]], ptr{{.*}}%[[VAL_13]], align 8 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 32 +// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 32 +// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 1 +// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 32 +// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1024, i32 4096 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_27]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: br label %[[VAL_29:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] +// CHECK: %[[VAL_31:.*]] = load i32, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 +// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 +// CHECK: store i32 %[[VAL_35]], ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %thread.id.1 +// CHECK: store i32 0, ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: br label %[[VAL_37:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] +// CHECK: %[[VAL_39:.*]] = load i32, ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 +// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 +// CHECK: store i32 %[[VAL_42]], ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp eq i32 %[[VAL_39]], 0 +// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 +// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] +// CHECK: br label %[[VAL_37]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] +// CHECK: br label %[[VAL_29]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] +// CHECK: %[[VAL_47:.*]] = load double, ptr{{.*}}%[[VAL_13]], align 8 +// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 +// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr +// CHECK: store double %[[VAL_47]], ptr{{.*}}%[[VAL_49]], align 8 +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 +// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr +// CHECK: %[[VAL_52:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_53:.*]] = bitcast double %[[VAL_52]] to i64 +// CHECK: %[[VAL_54:.*]] = bitcast i64 %[[VAL_53]] to <2 x i32> +// CHECK: %[[VAL_55:.*]] = extractelement <2 x i32> %[[VAL_54]], i64 0 +// CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_56:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_55]], i32 16) +// CHECK: %[[VAL_57:.*]] = insertelement <2 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 0 +// CHECK: %[[VAL_58:.*]] = extractelement <2 x i32> %[[VAL_57]], i64 1 +// CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_59:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_58]], i32 16) +// CHECK: %[[VAL_60:.*]] = insertelement <2 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 1 +// CHECK: %[[VAL_61:.*]] = bitcast <2 x i32> %[[VAL_60]] to i64 +// CHECK: %[[VAL_62:.*]] = bitcast i64 %[[VAL_61]] to double +// CHECK: store double %[[VAL_62]], ptr{{.*}}%[[VAL_9]], align 8 +// CHECK-PTX: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK-GCN: %[[VAL_9_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr +// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr +// CHECK-GCN: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9_1]], ptr %[[VAL_8_1]]) +// CHECK: %[[VAL_63:.*]] = load double, ptr{{.*}}%[[VAL_8]], align 8 +// CHECK: store double %[[VAL_63]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_64:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_65:.*]] = bitcast double %[[VAL_64]] to i64 +// CHECK: %[[VAL_66:.*]] = bitcast i64 %[[VAL_65]] to <2 x i32> +// CHECK: %[[VAL_67:.*]] = extractelement <2 x i32> %[[VAL_66]], i64 0 +// CHECK-PTX: %[[VAL_68:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_67]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_68:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_67]], i32 8) +// CHECK: %[[VAL_69:.*]] = insertelement <2 x i32> %[[VAL_66]], i32 %[[VAL_68]], i64 0 +// CHECK: %[[VAL_70:.*]] = extractelement <2 x i32> %[[VAL_69]], i64 1 +// CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_70]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_71:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_70]], i32 8) +// CHECK: %[[VAL_72:.*]] = insertelement <2 x i32> %[[VAL_69]], i32 %[[VAL_71]], i64 1 +// CHECK: %[[VAL_73:.*]] = bitcast <2 x i32> %[[VAL_72]] to i64 +// CHECK: %[[VAL_74:.*]] = bitcast i64 %[[VAL_73]] to double +// CHECK: store double %[[VAL_74]], ptr{{.*}}%[[VAL_7]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr +// CHECK-GCN: %[[VAL_6_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7_1]], ptr %[[VAL_6_1]]) +// CHECK: %[[VAL_75:.*]] = load double, ptr{{.*}}%[[VAL_6]], align 8 +// CHECK: store double %[[VAL_75]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_76:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_77:.*]] = bitcast double %[[VAL_76]] to i64 +// CHECK: %[[VAL_78:.*]] = bitcast i64 %[[VAL_77]] to <2 x i32> +// CHECK: %[[VAL_79:.*]] = extractelement <2 x i32> %[[VAL_78]], i64 0 +// CHECK-PTX: %[[VAL_80:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_79]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_80:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_79]], i32 4) +// CHECK: %[[VAL_81:.*]] = insertelement <2 x i32> %[[VAL_78]], i32 %[[VAL_80]], i64 0 +// CHECK: %[[VAL_82:.*]] = extractelement <2 x i32> %[[VAL_81]], i64 1 +// CHECK-PTX: %[[VAL_83:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_82]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_83:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_82]], i32 4) +// CHECK: %[[VAL_84:.*]] = insertelement <2 x i32> %[[VAL_81]], i32 %[[VAL_83]], i64 1 +// CHECK: %[[VAL_85:.*]] = bitcast <2 x i32> %[[VAL_84]] to i64 +// CHECK: %[[VAL_86:.*]] = bitcast i64 %[[VAL_85]] to double +// CHECK: store double %[[VAL_86]], ptr{{.*}}%[[VAL_5]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK-GCN: %[[VAL_5_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr +// CHECK-GCN: %[[VAL_4_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5_1]], ptr %[[VAL_4_1]]) +// CHECK: %[[VAL_87:.*]] = load double, ptr{{.*}}%[[VAL_4]], align 8 +// CHECK: store double %[[VAL_87]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_88:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_89:.*]] = bitcast double %[[VAL_88]] to i64 +// CHECK: %[[VAL_90:.*]] = bitcast i64 %[[VAL_89]] to <2 x i32> +// CHECK: %[[VAL_91:.*]] = extractelement <2 x i32> %[[VAL_90]], i64 0 +// CHECK-PTX: %[[VAL_92:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_91]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_92:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_91]], i32 2) +// CHECK: %[[VAL_93:.*]] = insertelement <2 x i32> %[[VAL_90]], i32 %[[VAL_92]], i64 0 +// CHECK: %[[VAL_94:.*]] = extractelement <2 x i32> %[[VAL_93]], i64 1 +// CHECK-PTX: %[[VAL_95:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_94]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_95:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94]], i32 2) +// CHECK: %[[VAL_96:.*]] = insertelement <2 x i32> %[[VAL_93]], i32 %[[VAL_95]], i64 1 +// CHECK: %[[VAL_97:.*]] = bitcast <2 x i32> %[[VAL_96]] to i64 +// CHECK: %[[VAL_98:.*]] = bitcast i64 %[[VAL_97]] to double +// CHECK: store double %[[VAL_98]], ptr{{.*}}%[[VAL_3]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr +// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3_1]], ptr %[[VAL_2_1]]) +// CHECK: %[[VAL_99:.*]] = load double, ptr{{.*}}%[[VAL_2]], align 8 +// CHECK: store double %[[VAL_99]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_100:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: %[[VAL_101:.*]] = bitcast double %[[VAL_100]] to i64 +// CHECK: %[[VAL_102:.*]] = bitcast i64 %[[VAL_101]] to <2 x i32> +// CHECK: %[[VAL_103:.*]] = extractelement <2 x i32> %[[VAL_102]], i64 0 +// CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_104:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_103]], i32 1) +// CHECK: %[[VAL_105:.*]] = insertelement <2 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 0 +// CHECK: %[[VAL_106:.*]] = extractelement <2 x i32> %[[VAL_105]], i64 1 +// CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_107:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_106]], i32 1) +// CHECK: %[[VAL_108:.*]] = insertelement <2 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 1 +// CHECK: %[[VAL_109:.*]] = bitcast <2 x i32> %[[VAL_108]] to i64 +// CHECK: %[[VAL_110:.*]] = bitcast i64 %[[VAL_109]] to double +// CHECK: store double %[[VAL_110]], ptr{{.*}}%[[VAL_1]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr +// CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_0]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1_1]], ptr %[[VAL_0_1]]) +// CHECK: %[[VAL_111:.*]] = load double, ptr{{.*}}%[[VAL_0]], align 8 +// CHECK: store double %[[VAL_111]], ptr{{.*}}%[[VAL_51]], align 8 +// CHECK-PTX: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK-PTX: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK-GCN: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 +// CHECK-GCN: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 +// CHECK: %[[VAL_114:.*]] = and i1 %[[VAL_112]], %[[VAL_113]] +// CHECK: %[[VAL_115:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: %[[VAL_116:.*]] = and i1 %[[VAL_114]], %[[VAL_115]] +// CHECK: br i1 %[[VAL_116]], label %[[VAL_117:.*]], label %[[VAL_19]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_117]], %[[VAL_33]] +// CHECK: br label %[[VAL_18]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_118:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_119:.*]] = add i32 %tile_origin.1, %[[VAL_31]] +// CHECK: %[[VAL_120:.*]] = add i32 %tile_origin.2, %[[VAL_44]] +// CHECK: %[[VAL_121:.*]] = getelementptr inbounds [1024 x [1024 x i8]], ptr{{.*}}%[[VAL_122:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_123:.*]] = load i8, ptr{{.*}}%[[VAL_121]], align 1, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_124:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_125:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_126:.*]] = load double, ptr{{.*}}%[[VAL_124]], align 8, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_127:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_128:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] +// CHECK: %[[VAL_129:.*]] = load double, ptr{{.*}}%[[VAL_127]], align 8, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_130:.*]] = trunc i8 %[[VAL_123]] to i1 +// CHECK: %[[VAL_131:.*]] = select i1 %[[VAL_130]], double %[[VAL_126]], double %[[VAL_129]] +// CHECK: store double %[[VAL_131]], ptr{{.*}}%[[VAL_14]], align 8 +// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) +// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr +// CHECK-GCN: %[[VAL_14_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr +// CHECK-GCN: %[[VAL_10_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr +// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_13_1]], ptr %[[VAL_14_1]], ptr %[[VAL_10_1]]) +// CHECK: %[[VAL_132:.*]] = load double, ptr{{.*}}%[[VAL_10]], align 8 +// CHECK: store double %[[VAL_132]], ptr{{.*}}%[[VAL_13]], align 8 +// CHECK: br label %[[VAL_38]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] +// CHECK: %[[VAL_135:.*]] = add i32 %tile_origin.2, %thread.id.1 +// CHECK: %[[VAL_139:.*]] = getelementptr inbounds [1024 x double], ptr{{.*}}%[[VAL_140:.*]], i32 0, i32 %[[VAL_135]] +// CHECK: %[[VAL_141:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 +// CHECK: store double %[[VAL_141]], ptr{{.*}}%[[VAL_139]], align 8 +// CHECK: br label %[[VAL_19]] +// CHECK: entry: +// CHECK: %[[VAL_142:.*]] = alloca double, align 8 +// CHECK: %[[VAL_143:.*]] = load double, ptr{{.*}}%[[VAL_144:.*]], align 8 +// CHECK: %[[VAL_145:.*]] = load double, ptr{{.*}}%[[VAL_146:.*]], align 8 +// CHECK: %[[VAL_147:.*]] = fadd double %[[VAL_143]], %[[VAL_145]] +// CHECK: store double %[[VAL_147]], ptr{{.*}}%[[VAL_142]], align 8 +// CHECK: %[[VAL_148:.*]] = load double, ptr{{.*}}%[[VAL_142]], align 8 +// CHECK: store double %[[VAL_148]], ptr{{.*}}%[[VAL_149:.*]], align 8 +// CHECK: ret void + +// CHECK-PTX: !3 = !{i32 0, i32 1024} +// CHECK-PTX: !4 = !{i32 0, i32 32} diff --git a/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo b/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo new file mode 100644 index 0000000000000..21d32aebf1915 --- /dev/null +++ b/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo @@ -0,0 +1,420 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule LargeReduction, is_scheduled=true + +Sum { + x.1 = c128[] parameter(0) + y.1 = c128[] parameter(1) + ROOT add.1 = c128[] add(x.1, y.1) +} + +fused_computation { + param_0 = c128[10000]{0} parameter(0) + param_1 = c128[] parameter(1) + ROOT out1.1 = c128[] reduce(c128[10000]{0} param_0, c128[] param_1), dimensions={0}, to_apply=Sum +} + +ENTRY reduce.1 { + parameter = c128[10000] parameter(0) + init_value = c128[] constant((0, 0)) + ROOT wrapped_out1 = c128[] fusion(c128[10000]{0} parameter, c128[] init_value), kind=kInput, calls=fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca %[[VAL_1:.*]], align 8 +// CHECK: %[[VAL_2:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_3:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_4:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_5:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_6:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_7:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_8:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_9:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_10:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_11:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_12:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_13:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_14:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_15:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_16:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_17:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_18:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_19:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_20:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_21:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_22:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_23:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_24:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_25:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_26:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_27:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_28:.*]] = alloca %[[VAL_1]], align 8 +// CHECK: %[[VAL_29:.*]] = alloca %[[VAL_1]], align 8 +// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 +// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] +// CHECK: %[[VAL_35:.*]] = load %[[VAL_1]], ptr %[[VAL_36:.*]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_28]], align 1 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 640 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 +// CHECK: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 1 +// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_44]], i32 5000, i32 5120 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_42]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_40]], 5120 +// CHECK: %tile_origin.3 = mul i32 %[[VAL_38]], 2 +// CHECK: %[[VAL_45:.*]] = icmp eq i32 5120, %tile_bound.2 +// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_47:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_48:.*]], %[[VAL_49:.*]] +// CHECK: %[[VAL_50:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_51:.*]] = bitcast i128 %[[VAL_50]] to <4 x i32> +// CHECK: %[[VAL_52:.*]] = extractelement <4 x i32> %[[VAL_51]], i64 0 +// CHECK: %[[VAL_53:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_52]], i32 16, i32 31) +// CHECK: %[[VAL_54:.*]] = insertelement <4 x i32> %[[VAL_51]], i32 %[[VAL_53]], i64 0 +// CHECK: %[[VAL_55:.*]] = extractelement <4 x i32> %[[VAL_54]], i64 1 +// CHECK: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) +// CHECK: %[[VAL_57:.*]] = insertelement <4 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 1 +// CHECK: %[[VAL_58:.*]] = extractelement <4 x i32> %[[VAL_57]], i64 2 +// CHECK: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) +// CHECK: %[[VAL_60:.*]] = insertelement <4 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 2 +// CHECK: %[[VAL_61:.*]] = extractelement <4 x i32> %[[VAL_60]], i64 3 +// CHECK: %[[VAL_62:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_61]], i32 16, i32 31) +// CHECK: %[[VAL_63:.*]] = insertelement <4 x i32> %[[VAL_60]], i32 %[[VAL_62]], i64 3 +// CHECK: %[[VAL_64:.*]] = bitcast <4 x i32> %[[VAL_63]] to i128 +// CHECK: store i128 %[[VAL_64]], ptr %[[VAL_21]], align 16 +// CHECK: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_21]], ptr %[[VAL_20]]) +// CHECK: %[[VAL_65:.*]] = load %[[VAL_1]], ptr %[[VAL_20]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_65]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_66:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_67:.*]] = bitcast i128 %[[VAL_66]] to <4 x i32> +// CHECK: %[[VAL_68:.*]] = extractelement <4 x i32> %[[VAL_67]], i64 0 +// CHECK: %[[VAL_69:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_68]], i32 8, i32 31) +// CHECK: %[[VAL_70:.*]] = insertelement <4 x i32> %[[VAL_67]], i32 %[[VAL_69]], i64 0 +// CHECK: %[[VAL_71:.*]] = extractelement <4 x i32> %[[VAL_70]], i64 1 +// CHECK: %[[VAL_72:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_71]], i32 8, i32 31) +// CHECK: %[[VAL_73:.*]] = insertelement <4 x i32> %[[VAL_70]], i32 %[[VAL_72]], i64 1 +// CHECK: %[[VAL_74:.*]] = extractelement <4 x i32> %[[VAL_73]], i64 2 +// CHECK: %[[VAL_75:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_74]], i32 8, i32 31) +// CHECK: %[[VAL_76:.*]] = insertelement <4 x i32> %[[VAL_73]], i32 %[[VAL_75]], i64 2 +// CHECK: %[[VAL_77:.*]] = extractelement <4 x i32> %[[VAL_76]], i64 3 +// CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_77]], i32 8, i32 31) +// CHECK: %[[VAL_79:.*]] = insertelement <4 x i32> %[[VAL_76]], i32 %[[VAL_78]], i64 3 +// CHECK: %[[VAL_80:.*]] = bitcast <4 x i32> %[[VAL_79]] to i128 +// CHECK: store i128 %[[VAL_80]], ptr %[[VAL_19]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_19]], ptr %[[VAL_18]]) +// CHECK: %[[VAL_81:.*]] = load %[[VAL_1]], ptr %[[VAL_18]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_81]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_82:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_83:.*]] = bitcast i128 %[[VAL_82]] to <4 x i32> +// CHECK: %[[VAL_84:.*]] = extractelement <4 x i32> %[[VAL_83]], i64 0 +// CHECK: %[[VAL_85:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_84]], i32 4, i32 31) +// CHECK: %[[VAL_86:.*]] = insertelement <4 x i32> %[[VAL_83]], i32 %[[VAL_85]], i64 0 +// CHECK: %[[VAL_87:.*]] = extractelement <4 x i32> %[[VAL_86]], i64 1 +// CHECK: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 4, i32 31) +// CHECK: %[[VAL_89:.*]] = insertelement <4 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 1 +// CHECK: %[[VAL_90:.*]] = extractelement <4 x i32> %[[VAL_89]], i64 2 +// CHECK: %[[VAL_91:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_90]], i32 4, i32 31) +// CHECK: %[[VAL_92:.*]] = insertelement <4 x i32> %[[VAL_89]], i32 %[[VAL_91]], i64 2 +// CHECK: %[[VAL_93:.*]] = extractelement <4 x i32> %[[VAL_92]], i64 3 +// CHECK: %[[VAL_94:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_93]], i32 4, i32 31) +// CHECK: %[[VAL_95:.*]] = insertelement <4 x i32> %[[VAL_92]], i32 %[[VAL_94]], i64 3 +// CHECK: %[[VAL_96:.*]] = bitcast <4 x i32> %[[VAL_95]] to i128 +// CHECK: store i128 %[[VAL_96]], ptr %[[VAL_17]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_17]], ptr %[[VAL_16]]) +// CHECK: %[[VAL_97:.*]] = load %[[VAL_1]], ptr %[[VAL_16]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_97]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_98:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_99:.*]] = bitcast i128 %[[VAL_98]] to <4 x i32> +// CHECK: %[[VAL_100:.*]] = extractelement <4 x i32> %[[VAL_99]], i64 0 +// CHECK: %[[VAL_101:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_100]], i32 2, i32 31) +// CHECK: %[[VAL_102:.*]] = insertelement <4 x i32> %[[VAL_99]], i32 %[[VAL_101]], i64 0 +// CHECK: %[[VAL_103:.*]] = extractelement <4 x i32> %[[VAL_102]], i64 1 +// CHECK: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 2, i32 31) +// CHECK: %[[VAL_105:.*]] = insertelement <4 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 1 +// CHECK: %[[VAL_106:.*]] = extractelement <4 x i32> %[[VAL_105]], i64 2 +// CHECK: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 2, i32 31) +// CHECK: %[[VAL_108:.*]] = insertelement <4 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 2 +// CHECK: %[[VAL_109:.*]] = extractelement <4 x i32> %[[VAL_108]], i64 3 +// CHECK: %[[VAL_110:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_109]], i32 2, i32 31) +// CHECK: %[[VAL_111:.*]] = insertelement <4 x i32> %[[VAL_108]], i32 %[[VAL_110]], i64 3 +// CHECK: %[[VAL_112:.*]] = bitcast <4 x i32> %[[VAL_111]] to i128 +// CHECK: store i128 %[[VAL_112]], ptr %[[VAL_15]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_15]], ptr %[[VAL_14]]) +// CHECK: %[[VAL_113:.*]] = load %[[VAL_1]], ptr %[[VAL_14]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_113]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_114:.*]] = load i128, ptr %[[VAL_28]], align 16 +// CHECK: %[[VAL_115:.*]] = bitcast i128 %[[VAL_114]] to <4 x i32> +// CHECK: %[[VAL_116:.*]] = extractelement <4 x i32> %[[VAL_115]], i64 0 +// CHECK: %[[VAL_117:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_116]], i32 1, i32 31) +// CHECK: %[[VAL_118:.*]] = insertelement <4 x i32> %[[VAL_115]], i32 %[[VAL_117]], i64 0 +// CHECK: %[[VAL_119:.*]] = extractelement <4 x i32> %[[VAL_118]], i64 1 +// CHECK: %[[VAL_120:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_119]], i32 1, i32 31) +// CHECK: %[[VAL_121:.*]] = insertelement <4 x i32> %[[VAL_118]], i32 %[[VAL_120]], i64 1 +// CHECK: %[[VAL_122:.*]] = extractelement <4 x i32> %[[VAL_121]], i64 2 +// CHECK: %[[VAL_123:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_122]], i32 1, i32 31) +// CHECK: %[[VAL_124:.*]] = insertelement <4 x i32> %[[VAL_121]], i32 %[[VAL_123]], i64 2 +// CHECK: %[[VAL_125:.*]] = extractelement <4 x i32> %[[VAL_124]], i64 3 +// CHECK: %[[VAL_126:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_125]], i32 1, i32 31) +// CHECK: %[[VAL_127:.*]] = insertelement <4 x i32> %[[VAL_124]], i32 %[[VAL_126]], i64 3 +// CHECK: %[[VAL_128:.*]] = bitcast <4 x i32> %[[VAL_127]] to i128 +// CHECK: store i128 %[[VAL_128]], ptr %[[VAL_13]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_13]], ptr %[[VAL_12]]) +// CHECK: %[[VAL_129:.*]] = load %[[VAL_1]], ptr %[[VAL_12]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_129]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_130:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: br i1 true, label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_131:.*]], %[[VAL_132:.*]] +// CHECK: br label %[[VAL_33]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_32]] +// CHECK: store i32 0, ptr %[[VAL_27]], align 4 +// CHECK: br label %[[VAL_133:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_134:.*]], %[[VAL_46]] +// CHECK: %[[VAL_135:.*]] = load i32, ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_136:.*]] = icmp uge i32 %[[VAL_135]], 5120 +// CHECK: br i1 %[[VAL_136]], label %[[VAL_49]], label %[[VAL_137:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_133]] +// CHECK: %[[VAL_138:.*]] = add nuw nsw i32 %[[VAL_135]], 640 +// CHECK: store i32 %[[VAL_138]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_139:.*]] = icmp eq i32 %[[VAL_135]], 0 +// CHECK: %[[VAL_140:.*]] = add i32 %[[VAL_135]], %thread.id.2 +// CHECK: store i32 0, ptr %[[VAL_26]], align 4 +// CHECK: br label %[[VAL_141:.*]] +// CHECK: loop3.loop_header: ; preds = %[[VAL_142:.*]], %[[VAL_137]] +// CHECK: %[[VAL_143:.*]] = load i32, ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_144:.*]] = icmp uge i32 %[[VAL_143]], 2 +// CHECK: br i1 %[[VAL_144]], label %[[VAL_134]], label %[[VAL_142]] +// CHECK: loop3.loop_body: ; preds = %[[VAL_141]] +// CHECK: %[[VAL_145:.*]] = add nuw nsw i32 %[[VAL_143]], 1 +// CHECK: store i32 %[[VAL_145]], ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_146:.*]] = icmp eq i32 %[[VAL_143]], 0 +// CHECK: %[[VAL_147:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_148:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_149:.*]] = add i32 %tile_origin.2, %[[VAL_140]] +// CHECK: %[[VAL_150:.*]] = add i32 %tile_origin.3, %[[VAL_143]] +// CHECK: %[[VAL_151:.*]] = mul nuw nsw i32 %[[VAL_150]], 1 +// CHECK: %[[VAL_152:.*]] = add nuw nsw i32 0, %[[VAL_151]] +// CHECK: %[[VAL_153:.*]] = mul nuw nsw i32 %[[VAL_149]], 2 +// CHECK: %[[VAL_154:.*]] = add nuw nsw i32 %[[VAL_152]], %[[VAL_153]] +// CHECK: %[[VAL_155:.*]] = udiv i32 %[[VAL_154]], 10000 +// CHECK: %[[VAL_156:.*]] = mul nuw nsw i32 %[[VAL_148]], 1 +// CHECK: %[[VAL_157:.*]] = add nuw nsw i32 0, %[[VAL_156]] +// CHECK: %[[VAL_158:.*]] = mul nuw nsw i32 %[[VAL_147]], 1 +// CHECK: %[[VAL_159:.*]] = add nuw nsw i32 0, %[[VAL_158]] +// CHECK: %[[VAL_160:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161:.*]], i32 0, i32 %[[VAL_154]] +// CHECK: %[[VAL_162:.*]] = load %[[VAL_1]], ptr %[[VAL_160]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_162]], ptr %[[VAL_29]], align 1 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_25]]) +// CHECK: %[[VAL_163:.*]] = load %[[VAL_1]], ptr %[[VAL_25]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_163]], ptr %[[VAL_28]], align 1 +// CHECK: br label %[[VAL_141]], !llvm.loop !5 +// CHECK: loop3.loop_exit: ; preds = %[[VAL_141]] +// CHECK: br label %[[VAL_133]], !llvm.loop !7 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_133]] +// CHECK: br label %[[VAL_132]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_32]] +// CHECK: store i32 0, ptr %[[VAL_24]], align 4 +// CHECK: br label %[[VAL_164:.*]] +// CHECK: loop2.loop_header4: ; preds = %[[VAL_165:.*]], %[[VAL_47]] +// CHECK: %[[VAL_166:.*]] = load i32, ptr %[[VAL_24]], align 4 +// CHECK: %[[VAL_167:.*]] = icmp uge i32 %[[VAL_166]], 5120 +// CHECK: br i1 %[[VAL_167]], label %[[VAL_48]], label %[[VAL_168:.*]] +// CHECK: loop2.loop_body5: ; preds = %[[VAL_164]] +// CHECK: %[[VAL_169:.*]] = add nuw nsw i32 %[[VAL_166]], 640 +// CHECK: store i32 %[[VAL_169]], ptr %[[VAL_24]], align 4 +// CHECK: %[[VAL_170:.*]] = icmp eq i32 %[[VAL_166]], 0 +// CHECK: %[[VAL_171:.*]] = add i32 %[[VAL_166]], %thread.id.2 +// CHECK: %[[VAL_172:.*]] = icmp ult i32 %[[VAL_171]], %tile_bound.2 +// CHECK: br i1 %[[VAL_172]], label %[[VAL_173:.*]], label %[[VAL_165]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_174:.*]], %[[VAL_168]] +// CHECK: br label %[[VAL_164]], !llvm.loop !9 +// CHECK: loop2.loop_exit3: ; preds = %[[VAL_164]] +// CHECK: br label %[[VAL_132]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_168]] +// CHECK: store i32 0, ptr %[[VAL_23]], align 4 +// CHECK: br label %[[VAL_175:.*]] +// CHECK: loop3.loop_header10: ; preds = %[[VAL_176:.*]], %[[VAL_173]] +// CHECK: %[[VAL_177:.*]] = load i32, ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_178:.*]] = icmp uge i32 %[[VAL_177]], 2 +// CHECK: br i1 %[[VAL_178]], label %[[VAL_174]], label %[[VAL_176]] +// CHECK: loop3.loop_body11: ; preds = %[[VAL_175]] +// CHECK: %[[VAL_179:.*]] = add nuw nsw i32 %[[VAL_177]], 1 +// CHECK: store i32 %[[VAL_179]], ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_180:.*]] = icmp eq i32 %[[VAL_177]], 0 +// CHECK: %[[VAL_181:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_182:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_183:.*]] = add i32 %tile_origin.2, %[[VAL_171]] +// CHECK: %[[VAL_184:.*]] = add i32 %tile_origin.3, %[[VAL_177]] +// CHECK: %[[VAL_185:.*]] = mul nuw nsw i32 %[[VAL_184]], 1 +// CHECK: %[[VAL_186:.*]] = add nuw nsw i32 0, %[[VAL_185]] +// CHECK: %[[VAL_187:.*]] = mul nuw nsw i32 %[[VAL_183]], 2 +// CHECK: %[[VAL_188:.*]] = add nuw nsw i32 %[[VAL_186]], %[[VAL_187]] +// CHECK: %[[VAL_189:.*]] = udiv i32 %[[VAL_188]], 10000 +// CHECK: %[[VAL_190:.*]] = mul nuw nsw i32 %[[VAL_182]], 1 +// CHECK: %[[VAL_191:.*]] = add nuw nsw i32 0, %[[VAL_190]] +// CHECK: %[[VAL_192:.*]] = mul nuw nsw i32 %[[VAL_181]], 1 +// CHECK: %[[VAL_193:.*]] = add nuw nsw i32 0, %[[VAL_192]] +// CHECK: %[[VAL_194:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161]], i32 0, i32 %[[VAL_188]] +// CHECK: %[[VAL_195:.*]] = load %[[VAL_1]], ptr %[[VAL_194]], align 1, !invariant.load !3 +// CHECK: store %[[VAL_1]] %[[VAL_195]], ptr %[[VAL_29]], align 1 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_22]]) +// CHECK: %[[VAL_196:.*]] = load %[[VAL_1]], ptr %[[VAL_22]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_196]], ptr %[[VAL_28]], align 1 +// CHECK: br label %[[VAL_175]], !llvm.loop !10 +// CHECK: loop3.loop_exit9: ; preds = %[[VAL_175]] +// CHECK: br label %[[VAL_165]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_132]] +// CHECK: %[[VAL_197:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_197]], label %[[VAL_198:.*]], label %[[VAL_199:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_198]], %thread_in_bounds-true +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_200:.*]] = icmp eq i32 %[[VAL_130]], 0 +// CHECK: br i1 %[[VAL_200]], label %[[VAL_201:.*]], label %[[VAL_131]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_202:.*]], %[[VAL_199]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_203:.*]] = load %[[VAL_1]], ptr %[[VAL_28]], align 1 +// CHECK: %[[VAL_204:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_130]] +// CHECK: %[[VAL_205:.*]] = addrspacecast ptr addrspace(3) %[[VAL_204]] to ptr +// CHECK: store %[[VAL_1]] %[[VAL_203]], ptr %[[VAL_205]], align 1 +// CHECK: br label %[[VAL_199]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_199]] +// CHECK: %[[VAL_206:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id +// CHECK: %[[VAL_207:.*]] = addrspacecast ptr addrspace(3) %[[VAL_206]] to ptr +// CHECK: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_11]], align 1 +// CHECK: %[[VAL_208:.*]] = icmp ult i32 %thread.id.2, 20 +// CHECK: %[[VAL_209:.*]] = select i1 %[[VAL_208]], ptr %[[VAL_207]], ptr %[[VAL_11]] +// CHECK: %[[VAL_210:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_211:.*]] = bitcast i128 %[[VAL_210]] to <4 x i32> +// CHECK: %[[VAL_212:.*]] = extractelement <4 x i32> %[[VAL_211]], i64 0 +// CHECK: %[[VAL_213:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_212]], i32 16, i32 31) +// CHECK: %[[VAL_214:.*]] = insertelement <4 x i32> %[[VAL_211]], i32 %[[VAL_213]], i64 0 +// CHECK: %[[VAL_215:.*]] = extractelement <4 x i32> %[[VAL_214]], i64 1 +// CHECK: %[[VAL_216:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_215]], i32 16, i32 31) +// CHECK: %[[VAL_217:.*]] = insertelement <4 x i32> %[[VAL_214]], i32 %[[VAL_216]], i64 1 +// CHECK: %[[VAL_218:.*]] = extractelement <4 x i32> %[[VAL_217]], i64 2 +// CHECK: %[[VAL_219:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_218]], i32 16, i32 31) +// CHECK: %[[VAL_220:.*]] = insertelement <4 x i32> %[[VAL_217]], i32 %[[VAL_219]], i64 2 +// CHECK: %[[VAL_221:.*]] = extractelement <4 x i32> %[[VAL_220]], i64 3 +// CHECK: %[[VAL_222:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_221]], i32 16, i32 31) +// CHECK: %[[VAL_223:.*]] = insertelement <4 x i32> %[[VAL_220]], i32 %[[VAL_222]], i64 3 +// CHECK: %[[VAL_224:.*]] = bitcast <4 x i32> %[[VAL_223]] to i128 +// CHECK: store i128 %[[VAL_224]], ptr %[[VAL_10]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_10]], ptr %[[VAL_9]]) +// CHECK: %[[VAL_225:.*]] = load %[[VAL_1]], ptr %[[VAL_9]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_225]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_226:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_227:.*]] = bitcast i128 %[[VAL_226]] to <4 x i32> +// CHECK: %[[VAL_228:.*]] = extractelement <4 x i32> %[[VAL_227]], i64 0 +// CHECK: %[[VAL_229:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_228]], i32 8, i32 31) +// CHECK: %[[VAL_230:.*]] = insertelement <4 x i32> %[[VAL_227]], i32 %[[VAL_229]], i64 0 +// CHECK: %[[VAL_231:.*]] = extractelement <4 x i32> %[[VAL_230]], i64 1 +// CHECK: %[[VAL_232:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_231]], i32 8, i32 31) +// CHECK: %[[VAL_233:.*]] = insertelement <4 x i32> %[[VAL_230]], i32 %[[VAL_232]], i64 1 +// CHECK: %[[VAL_234:.*]] = extractelement <4 x i32> %[[VAL_233]], i64 2 +// CHECK: %[[VAL_235:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_234]], i32 8, i32 31) +// CHECK: %[[VAL_236:.*]] = insertelement <4 x i32> %[[VAL_233]], i32 %[[VAL_235]], i64 2 +// CHECK: %[[VAL_237:.*]] = extractelement <4 x i32> %[[VAL_236]], i64 3 +// CHECK: %[[VAL_238:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_237]], i32 8, i32 31) +// CHECK: %[[VAL_239:.*]] = insertelement <4 x i32> %[[VAL_236]], i32 %[[VAL_238]], i64 3 +// CHECK: %[[VAL_240:.*]] = bitcast <4 x i32> %[[VAL_239]] to i128 +// CHECK: store i128 %[[VAL_240]], ptr %[[VAL_8]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_8]], ptr %[[VAL_7]]) +// CHECK: %[[VAL_241:.*]] = load %[[VAL_1]], ptr %[[VAL_7]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_241]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_242:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_243:.*]] = bitcast i128 %[[VAL_242]] to <4 x i32> +// CHECK: %[[VAL_244:.*]] = extractelement <4 x i32> %[[VAL_243]], i64 0 +// CHECK: %[[VAL_245:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_244]], i32 4, i32 31) +// CHECK: %[[VAL_246:.*]] = insertelement <4 x i32> %[[VAL_243]], i32 %[[VAL_245]], i64 0 +// CHECK: %[[VAL_247:.*]] = extractelement <4 x i32> %[[VAL_246]], i64 1 +// CHECK: %[[VAL_248:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_247]], i32 4, i32 31) +// CHECK: %[[VAL_249:.*]] = insertelement <4 x i32> %[[VAL_246]], i32 %[[VAL_248]], i64 1 +// CHECK: %[[VAL_250:.*]] = extractelement <4 x i32> %[[VAL_249]], i64 2 +// CHECK: %[[VAL_251:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_250]], i32 4, i32 31) +// CHECK: %[[VAL_252:.*]] = insertelement <4 x i32> %[[VAL_249]], i32 %[[VAL_251]], i64 2 +// CHECK: %[[VAL_253:.*]] = extractelement <4 x i32> %[[VAL_252]], i64 3 +// CHECK: %[[VAL_254:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_253]], i32 4, i32 31) +// CHECK: %[[VAL_255:.*]] = insertelement <4 x i32> %[[VAL_252]], i32 %[[VAL_254]], i64 3 +// CHECK: %[[VAL_256:.*]] = bitcast <4 x i32> %[[VAL_255]] to i128 +// CHECK: store i128 %[[VAL_256]], ptr %[[VAL_6]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_6]], ptr %[[VAL_5]]) +// CHECK: %[[VAL_257:.*]] = load %[[VAL_1]], ptr %[[VAL_5]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_257]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_258:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_259:.*]] = bitcast i128 %[[VAL_258]] to <4 x i32> +// CHECK: %[[VAL_260:.*]] = extractelement <4 x i32> %[[VAL_259]], i64 0 +// CHECK: %[[VAL_261:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_260]], i32 2, i32 31) +// CHECK: %[[VAL_262:.*]] = insertelement <4 x i32> %[[VAL_259]], i32 %[[VAL_261]], i64 0 +// CHECK: %[[VAL_263:.*]] = extractelement <4 x i32> %[[VAL_262]], i64 1 +// CHECK: %[[VAL_264:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_263]], i32 2, i32 31) +// CHECK: %[[VAL_265:.*]] = insertelement <4 x i32> %[[VAL_262]], i32 %[[VAL_264]], i64 1 +// CHECK: %[[VAL_266:.*]] = extractelement <4 x i32> %[[VAL_265]], i64 2 +// CHECK: %[[VAL_267:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_266]], i32 2, i32 31) +// CHECK: %[[VAL_268:.*]] = insertelement <4 x i32> %[[VAL_265]], i32 %[[VAL_267]], i64 2 +// CHECK: %[[VAL_269:.*]] = extractelement <4 x i32> %[[VAL_268]], i64 3 +// CHECK: %[[VAL_270:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_269]], i32 2, i32 31) +// CHECK: %[[VAL_271:.*]] = insertelement <4 x i32> %[[VAL_268]], i32 %[[VAL_270]], i64 3 +// CHECK: %[[VAL_272:.*]] = bitcast <4 x i32> %[[VAL_271]] to i128 +// CHECK: store i128 %[[VAL_272]], ptr %[[VAL_4]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_4]], ptr %[[VAL_3]]) +// CHECK: %[[VAL_273:.*]] = load %[[VAL_1]], ptr %[[VAL_3]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_273]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_274:.*]] = load i128, ptr %[[VAL_209]], align 16 +// CHECK: %[[VAL_275:.*]] = bitcast i128 %[[VAL_274]] to <4 x i32> +// CHECK: %[[VAL_276:.*]] = extractelement <4 x i32> %[[VAL_275]], i64 0 +// CHECK: %[[VAL_277:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_276]], i32 1, i32 31) +// CHECK: %[[VAL_278:.*]] = insertelement <4 x i32> %[[VAL_275]], i32 %[[VAL_277]], i64 0 +// CHECK: %[[VAL_279:.*]] = extractelement <4 x i32> %[[VAL_278]], i64 1 +// CHECK: %[[VAL_280:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_279]], i32 1, i32 31) +// CHECK: %[[VAL_281:.*]] = insertelement <4 x i32> %[[VAL_278]], i32 %[[VAL_280]], i64 1 +// CHECK: %[[VAL_282:.*]] = extractelement <4 x i32> %[[VAL_281]], i64 2 +// CHECK: %[[VAL_283:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_282]], i32 1, i32 31) +// CHECK: %[[VAL_284:.*]] = insertelement <4 x i32> %[[VAL_281]], i32 %[[VAL_283]], i64 2 +// CHECK: %[[VAL_285:.*]] = extractelement <4 x i32> %[[VAL_284]], i64 3 +// CHECK: %[[VAL_286:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_285]], i32 1, i32 31) +// CHECK: %[[VAL_287:.*]] = insertelement <4 x i32> %[[VAL_284]], i32 %[[VAL_286]], i64 3 +// CHECK: %[[VAL_288:.*]] = bitcast <4 x i32> %[[VAL_287]] to i128 +// CHECK: store i128 %[[VAL_288]], ptr %[[VAL_2]], align 16 +// CHECK: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_2]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_289:.*]] = load %[[VAL_1]], ptr %[[VAL_0]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_289]], ptr %[[VAL_209]], align 1 +// CHECK: %[[VAL_290:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_290]], label %[[VAL_291:.*]], label %[[VAL_202]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_291]], %[[VAL_201]] +// CHECK: br label %[[VAL_131]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_201]] +// CHECK: %[[VAL_293:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_296:.*]] = load %[[VAL_1]], ptr %[[VAL_209]], align 1 +// CHECK: store %[[VAL_1]] %[[VAL_296]], ptr %[[VAL_297:.*]], align 1 +// CHECK: br label %[[VAL_202]] +// CHECK: entry: +// CHECK: %[[VAL_298:.*]] = alloca %[[VAL_299:.*]], align 8 +// CHECK: %[[VAL_300:.*]] = load %[[VAL_299]], ptr %[[VAL_301:.*]], align 1 +// CHECK: %[[VAL_302:.*]] = load %[[VAL_299]], ptr %[[VAL_303:.*]], align 1 +// CHECK: %[[VAL_304:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 0 +// CHECK: %[[VAL_305:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 0 +// CHECK: %[[VAL_306:.*]] = fadd double %[[VAL_304]], %[[VAL_305]] +// CHECK: %[[VAL_307:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 1 +// CHECK: %[[VAL_308:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 1 +// CHECK: %[[VAL_309:.*]] = fadd double %[[VAL_307]], %[[VAL_308]] +// CHECK: %[[VAL_310:.*]] = insertvalue %[[VAL_299]] zeroinitializer, double %[[VAL_306]], 0 +// CHECK: %[[VAL_311:.*]] = insertvalue %[[VAL_299]] %[[VAL_310]], double %[[VAL_309]], 1 +// CHECK: store %[[VAL_299]] %[[VAL_311]], ptr %[[VAL_298]], align 1 +// CHECK: %[[VAL_312:.*]] = load %[[VAL_299]], ptr %[[VAL_298]], align 1 +// CHECK: store %[[VAL_299]] %[[VAL_312]], ptr %[[VAL_313:.*]], align 1 +// CHECK: ret void diff --git a/xla/service/gpu/tests/reduce_row_vectorized.hlo b/xla/service/gpu/tests/reduce_row_vectorized.hlo new file mode 100644 index 0000000000000..b85eeb0ac8831 --- /dev/null +++ b/xla/service/gpu/tests/reduce_row_vectorized.hlo @@ -0,0 +1,300 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule RowReductionVectorized, is_scheduled=true + +Sum { + x.1 = f32[] parameter(0) + y.1 = f32[] parameter(1) + ROOT add.1 = f32[] add(x.1, y.1) +} + +fusion_vectorized { + a = f32[131072,1024] parameter(0) + init = f32[] constant(0) + ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum +} + +ENTRY reduce.1 { + parameter0 = f32[131072,1024] parameter(0) + ROOT fusion_row_reduction_vectorized = f32[131072] fusion( + f32[131072,1024] parameter0 + ), kind=kLoop, calls=fusion_vectorized +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca float, align 4 +// CHECK: %[[VAL_1:.*]] = alloca float, align 4 +// CHECK: %[[VAL_2:.*]] = alloca float, align 4 +// CHECK: %[[VAL_3:.*]] = alloca float, align 4 +// CHECK: %[[VAL_4:.*]] = alloca float, align 4 +// CHECK: %[[VAL_5:.*]] = alloca float, align 4 +// CHECK: %[[VAL_6:.*]] = alloca float, align 4 +// CHECK: %[[VAL_7:.*]] = alloca float, align 4 +// CHECK: %[[VAL_8:.*]] = alloca float, align 4 +// CHECK: %[[VAL_9:.*]] = alloca float, align 4 +// CHECK: %[[VAL_10:.*]] = alloca float, align 4 +// CHECK: %[[VAL_11:.*]] = alloca float, align 4 +// CHECK: %[[VAL_12:.*]] = alloca float, align 4 +// CHECK: %[[VAL_13:.*]] = alloca float, align 4 +// CHECK: %[[VAL_14:.*]] = alloca float, align 4 +// CHECK: %[[VAL_15:.*]] = alloca float, align 4 +// CHECK: %[[VAL_16:.*]] = alloca float, align 4 +// CHECK: %[[VAL_17:.*]] = alloca float, align 4 +// CHECK: %[[VAL_18:.*]] = alloca float, align 4 +// CHECK: %[[VAL_19:.*]] = alloca float, align 4 +// CHECK: %[[VAL_20:.*]] = alloca float, align 4 +// CHECK: %[[VAL_21:.*]] = alloca float, align 4 +// CHECK: %[[VAL_22:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_23:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_24:.*]] = alloca float, align 4 +// CHECK: %[[VAL_25:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_26:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_27:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_28:.*]] = alloca float, align 4 +// CHECK: %[[VAL_29:.*]] = alloca float, align 4 +// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 +// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] +// CHECK: %[[VAL_35:.*]] = load float, ptr @0, align 4 +// CHECK: store float %[[VAL_35]], ptr %[[VAL_28]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_36:.*]] = udiv i32 %thread.id.x, 64 +// CHECK: %thread.id.1 = urem i32 %[[VAL_36]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 64 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 +// CHECK: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 32768 +// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 32768 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_42]], 4 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_40]], 512 +// CHECK: %tile_origin.3 = mul i32 %[[VAL_38]], 2 +// CHECK: store i32 %thread.id.1, ptr %[[VAL_27]], align 4 +// CHECK: br label %[[VAL_44:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_45:.*]], %[[VAL_32]] +// CHECK: %[[VAL_46:.*]] = load i32, ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp uge i32 %[[VAL_46]], 4 +// CHECK: br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_49:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_44]] +// CHECK: %[[VAL_50:.*]] = add nuw nsw i32 %[[VAL_46]], 4 +// CHECK: store i32 %[[VAL_50]], ptr %[[VAL_27]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp eq i32 %[[VAL_46]], %thread.id.1 +// CHECK: br i1 true, label %[[VAL_52:.*]], label %[[VAL_53:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_54:.*]], %[[VAL_55:.*]] +// CHECK: br label %[[VAL_44]], !llvm.loop !5 +// CHECK: loop1.loop_exit: ; preds = %[[VAL_44]] +// CHECK: %[[VAL_56:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_57:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_56]], i32 16, i32 31) +// CHECK: store float %[[VAL_57]], ptr %[[VAL_20]], align 4 +// CHECK: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_20]], ptr %[[VAL_19]]) +// CHECK: %[[VAL_58:.*]] = load float, ptr %[[VAL_19]], align 4 +// CHECK: store float %[[VAL_58]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_59:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_60:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_59]], i32 8, i32 31) +// CHECK: store float %[[VAL_60]], ptr %[[VAL_18]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_18]], ptr %[[VAL_17]]) +// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_17]], align 4 +// CHECK: store float %[[VAL_61]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_62:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_63:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_62]], i32 4, i32 31) +// CHECK: store float %[[VAL_63]], ptr %[[VAL_16]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_16]], ptr %[[VAL_15]]) +// CHECK: %[[VAL_64:.*]] = load float, ptr %[[VAL_15]], align 4 +// CHECK: store float %[[VAL_64]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_65:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_66:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_65]], i32 2, i32 31) +// CHECK: store float %[[VAL_66]], ptr %[[VAL_14]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_14]], ptr %[[VAL_13]]) +// CHECK: %[[VAL_67:.*]] = load float, ptr %[[VAL_13]], align 4 +// CHECK: store float %[[VAL_67]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_68:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_69:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_68]], i32 1, i32 31) +// CHECK: store float %[[VAL_69]], ptr %[[VAL_12]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_12]], ptr %[[VAL_11]]) +// CHECK: %[[VAL_70:.*]] = load float, ptr %[[VAL_11]], align 4 +// CHECK: store float %[[VAL_70]], ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_71:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: %[[VAL_72:.*]] = icmp ult i32 %thread.id.1, 4 +// CHECK: br i1 %[[VAL_72]], label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_73:.*]], %[[VAL_48]] +// CHECK: br label %[[VAL_33]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_49]] +// CHECK: store i32 0, ptr %[[VAL_26]], align 4 +// CHECK: br label %[[VAL_74:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_75:.*]], %[[VAL_52]] +// CHECK: %[[VAL_76:.*]] = load i32, ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_77:.*]] = icmp uge i32 %[[VAL_76]], 512 +// CHECK: br i1 %[[VAL_77]], label %[[VAL_55]], label %[[VAL_78:.*]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_74]] +// CHECK: %[[VAL_79:.*]] = add nuw nsw i32 %[[VAL_76]], 64 +// CHECK: store i32 %[[VAL_79]], ptr %[[VAL_26]], align 4 +// CHECK: %[[VAL_80:.*]] = icmp eq i32 %[[VAL_76]], 0 +// CHECK: %[[VAL_81:.*]] = add i32 %[[VAL_76]], %thread.id.2 +// CHECK: store i32 0, ptr %[[VAL_25]], align 4 +// CHECK: br label %[[VAL_82:.*]] +// CHECK: loop3.loop_header: ; preds = %[[VAL_83:.*]], %[[VAL_78]] +// CHECK: %[[VAL_84:.*]] = load i32, ptr %[[VAL_25]], align 4 +// CHECK: %[[VAL_85:.*]] = icmp uge i32 %[[VAL_84]], 2 +// CHECK: br i1 %[[VAL_85]], label %[[VAL_75]], label %[[VAL_83]] +// CHECK: loop3.loop_body: ; preds = %[[VAL_82]] +// CHECK: %[[VAL_86:.*]] = add nuw nsw i32 %[[VAL_84]], 1 +// CHECK: store i32 %[[VAL_86]], ptr %[[VAL_25]], align 4 +// CHECK: %[[VAL_87:.*]] = icmp eq i32 %[[VAL_84]], 0 +// CHECK: %[[VAL_88:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_89:.*]] = add i32 %tile_origin.1, %[[VAL_46]] +// CHECK: %[[VAL_90:.*]] = add i32 %tile_origin.2, %[[VAL_81]] +// CHECK: %[[VAL_91:.*]] = add i32 %tile_origin.3, %[[VAL_84]] +// CHECK: %[[VAL_92:.*]] = mul nuw nsw i32 %[[VAL_91]], 1 +// CHECK: %[[VAL_93:.*]] = add nuw nsw i32 0, %[[VAL_92]] +// CHECK: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_90]], 2 +// CHECK: %[[VAL_95:.*]] = add nuw nsw i32 %[[VAL_93]], %[[VAL_94]] +// CHECK: %[[VAL_96:.*]] = udiv i32 %[[VAL_95]], 1024 +// CHECK: %[[VAL_97:.*]] = mul nuw nsw i32 %[[VAL_89]], 1 +// CHECK: %[[VAL_98:.*]] = add nuw nsw i32 0, %[[VAL_97]] +// CHECK: %[[VAL_99:.*]] = udiv i32 %[[VAL_98]], 131072 +// CHECK: %[[VAL_100:.*]] = mul nuw nsw i32 %[[VAL_88]], 1 +// CHECK: %[[VAL_101:.*]] = add nuw nsw i32 0, %[[VAL_100]] +// CHECK: %[[VAL_102:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103:.*]], i32 0, i32 %[[VAL_98]], i32 %[[VAL_95]] +// CHECK: %[[VAL_104:.*]] = load float, ptr %[[VAL_102]], align 4, !invariant.load !7 +// CHECK: store float %[[VAL_104]], ptr %[[VAL_29]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_24]]) +// CHECK: %[[VAL_105:.*]] = load float, ptr %[[VAL_24]], align 4 +// CHECK: store float %[[VAL_105]], ptr %[[VAL_28]], align 4 +// CHECK: br label %[[VAL_82]], !llvm.loop !8 +// CHECK: loop3.loop_exit: ; preds = %[[VAL_82]] +// CHECK: br label %[[VAL_74]], !llvm.loop !9 +// CHECK: loop2.loop_exit: ; preds = %[[VAL_74]] +// CHECK: br label %[[VAL_45]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_49]] +// CHECK: store i32 0, ptr %[[VAL_23]], align 4 +// CHECK: br label %[[VAL_106:.*]] +// CHECK: loop2.loop_header5: ; preds = %[[VAL_107:.*]], %[[VAL_53]] +// CHECK: %[[VAL_108:.*]] = load i32, ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_109:.*]] = icmp uge i32 %[[VAL_108]], 512 +// CHECK: br i1 %[[VAL_109]], label %[[VAL_54]], label %[[VAL_110:.*]] +// CHECK: loop2.loop_body6: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_108]], 64 +// CHECK: store i32 %[[VAL_111]], ptr %[[VAL_23]], align 4 +// CHECK: %[[VAL_112:.*]] = icmp eq i32 %[[VAL_108]], 0 +// CHECK: %[[VAL_113:.*]] = add i32 %[[VAL_108]], %thread.id.2 +// CHECK: %[[VAL_114:.*]] = icmp ult i32 %[[VAL_113]], 512 +// CHECK: br i1 %[[VAL_114]], label %[[VAL_115:.*]], label %[[VAL_107]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_116:.*]], %[[VAL_110]] +// CHECK: br label %[[VAL_106]], !llvm.loop !11 +// CHECK: loop2.loop_exit4: ; preds = %[[VAL_106]] +// CHECK: br label %[[VAL_45]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_110]] +// CHECK: store i32 0, ptr %[[VAL_22]], align 4 +// CHECK: br label %[[VAL_117:.*]] +// CHECK: loop3.loop_header11: ; preds = %[[VAL_118:.*]], %[[VAL_115]] +// CHECK: %[[VAL_119:.*]] = load i32, ptr %[[VAL_22]], align 4 +// CHECK: %[[VAL_120:.*]] = icmp uge i32 %[[VAL_119]], 2 +// CHECK: br i1 %[[VAL_120]], label %[[VAL_116]], label %[[VAL_118]] +// CHECK: loop3.loop_body12: ; preds = %[[VAL_117]] +// CHECK: %[[VAL_121:.*]] = add nuw nsw i32 %[[VAL_119]], 1 +// CHECK: store i32 %[[VAL_121]], ptr %[[VAL_22]], align 4 +// CHECK: %[[VAL_122:.*]] = icmp eq i32 %[[VAL_119]], 0 +// CHECK: %[[VAL_123:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_124:.*]] = add i32 %tile_origin.1, %[[VAL_46]] +// CHECK: %[[VAL_125:.*]] = add i32 %tile_origin.2, %[[VAL_113]] +// CHECK: %[[VAL_126:.*]] = add i32 %tile_origin.3, %[[VAL_119]] +// CHECK: %[[VAL_127:.*]] = mul nuw nsw i32 %[[VAL_126]], 1 +// CHECK: %[[VAL_128:.*]] = add nuw nsw i32 0, %[[VAL_127]] +// CHECK: %[[VAL_129:.*]] = mul nuw nsw i32 %[[VAL_125]], 2 +// CHECK: %[[VAL_130:.*]] = add nuw nsw i32 %[[VAL_128]], %[[VAL_129]] +// CHECK: %[[VAL_131:.*]] = udiv i32 %[[VAL_130]], 1024 +// CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_124]], 1 +// CHECK: %[[VAL_133:.*]] = add nuw nsw i32 0, %[[VAL_132]] +// CHECK: %[[VAL_134:.*]] = udiv i32 %[[VAL_133]], 131072 +// CHECK: %[[VAL_135:.*]] = mul nuw nsw i32 %[[VAL_123]], 1 +// CHECK: %[[VAL_136:.*]] = add nuw nsw i32 0, %[[VAL_135]] +// CHECK: %[[VAL_137:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103]], i32 0, i32 %[[VAL_133]], i32 %[[VAL_130]] +// CHECK: %[[VAL_138:.*]] = load float, ptr %[[VAL_137]], align 4, !invariant.load !7 +// CHECK: store float %[[VAL_138]], ptr %[[VAL_29]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_21]]) +// CHECK: %[[VAL_139:.*]] = load float, ptr %[[VAL_21]], align 4 +// CHECK: store float %[[VAL_139]], ptr %[[VAL_28]], align 4 +// CHECK: br label %[[VAL_117]], !llvm.loop !12 +// CHECK: loop3.loop_exit10: ; preds = %[[VAL_117]] +// CHECK: br label %[[VAL_107]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_140:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_140]], label %[[VAL_141:.*]], label %[[VAL_142:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_141]], %thread_in_bounds-true +// CHECK: call void @llvm.nvvm.barrier0() +// CHECK: %[[VAL_143:.*]] = icmp eq i32 %[[VAL_71]], 0 +// CHECK: br i1 %[[VAL_143]], label %[[VAL_144:.*]], label %[[VAL_73]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_145:.*]], %[[VAL_142]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_146:.*]] = load float, ptr %[[VAL_28]], align 4 +// CHECK: %[[VAL_147:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_71]] +// CHECK: %[[VAL_148:.*]] = addrspacecast ptr addrspace(3) %[[VAL_147]] to ptr +// CHECK: store float %[[VAL_146]], ptr %[[VAL_148]], align 4 +// CHECK: br label %[[VAL_142]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_142]] +// CHECK: %[[VAL_149:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_150:.*]] = addrspacecast ptr addrspace(3) %[[VAL_149]] to ptr +// CHECK: store float %[[VAL_35]], ptr %[[VAL_10]], align 4 +// CHECK: %[[VAL_151:.*]] = icmp ult i32 %thread.id.2, 2 +// CHECK: %[[VAL_152:.*]] = select i1 %[[VAL_151]], ptr %[[VAL_150]], ptr %[[VAL_10]] +// CHECK: %[[VAL_153:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_154:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_153]], i32 16, i32 31) +// CHECK: store float %[[VAL_154]], ptr %[[VAL_9]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_9]], ptr %[[VAL_8]]) +// CHECK: %[[VAL_155:.*]] = load float, ptr %[[VAL_8]], align 4 +// CHECK: store float %[[VAL_155]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_157:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_156]], i32 8, i32 31) +// CHECK: store float %[[VAL_157]], ptr %[[VAL_7]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_7]], ptr %[[VAL_6]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr %[[VAL_6]], align 4 +// CHECK: store float %[[VAL_158]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_160:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_159]], i32 4, i32 31) +// CHECK: store float %[[VAL_160]], ptr %[[VAL_5]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK: %[[VAL_161:.*]] = load float, ptr %[[VAL_4]], align 4 +// CHECK: store float %[[VAL_161]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_163:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_162]], i32 2, i32 31) +// CHECK: store float %[[VAL_163]], ptr %[[VAL_3]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_3]], ptr %[[VAL_2]]) +// CHECK: %[[VAL_164:.*]] = load float, ptr %[[VAL_2]], align 4 +// CHECK: store float %[[VAL_164]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_166:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_165]], i32 1, i32 31) +// CHECK: store float %[[VAL_166]], ptr %[[VAL_1]], align 4 +// CHECK: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_1]], ptr %[[VAL_0]]) +// CHECK: %[[VAL_167:.*]] = load float, ptr %[[VAL_0]], align 4 +// CHECK: store float %[[VAL_167]], ptr %[[VAL_152]], align 4 +// CHECK: %[[VAL_168:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_168]], label %[[VAL_169:.*]], label %[[VAL_145]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_169]], %[[VAL_144]] +// CHECK: br label %[[VAL_73]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_144]] +// CHECK: %[[VAL_171:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [131072 x float], ptr %[[VAL_176:.*]], i32 0, i32 %[[VAL_171]] +// CHECK: %[[VAL_177:.*]] = load float, ptr %[[VAL_152]], align 4 +// CHECK: store float %[[VAL_177]], ptr %[[VAL_175]], align 4 +// CHECK: br label %[[VAL_145]] +// CHECK: entry: +// CHECK: %[[VAL_178:.*]] = alloca float, align 4 +// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_180:.*]], align 4 +// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_182:.*]], align 4 +// CHECK: %[[VAL_183:.*]] = fadd float %[[VAL_179]], %[[VAL_181]] +// CHECK: store float %[[VAL_183]], ptr %[[VAL_178]], align 4 +// CHECK: %[[VAL_184:.*]] = load float, ptr %[[VAL_178]], align 4 +// CHECK: store float %[[VAL_184]], ptr %[[VAL_185:.*]], align 4 +// CHECK: ret void diff --git a/xla/service/gpu/tests/reduce_unnested.hlo b/xla/service/gpu/tests/reduce_unnested.hlo index 9f70f2749e576..844c3ded2ef02 100644 --- a/xla/service/gpu/tests/reduce_unnested.hlo +++ b/xla/service/gpu/tests/reduce_unnested.hlo @@ -1,293 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %s - -// CHECK: target datalayout -// CHECK: call float @llvm.nvvm.shfl.sync.down.f32 - -HloModule Test, is_scheduled=true - -Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_rhs.0 = f32[] parameter(1) - scalar_lhs.1 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) - add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add.0, add.1) -} - -fused_computation { - param_0 = f32[5,200,300]{2,1,0} parameter(0) - param_1 = f32[5,200,300]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[200]{0}, f32[200]{0}) reduce(f32[5,200,300]{2,1,0} param_0, f32[5,200,300]{2,1,0} %param_1, f32[] param_2, f32[] param_2), dimensions={0,2}, to_apply=Add -} - -ENTRY main { - a = f32[5, 200, 300]{2,1,0} parameter(0) - b = f32[5, 200, 300]{2,1,0} parameter(1) - c = f32[] constant(0) - ROOT wrapped_d = (f32[200]{0}, f32[200]{0}) fusion(f32[5,200,300]{2,1,0} a, f32[5,200,300]{2,1,0} b, f32[] c), kind=kInput, calls=fused_computation -} - -// ----- - -// CHECK: target datalayout -// CHECK: @llvm.nvvm.shfl.sync.down - -HloModule LargeReduction, is_scheduled=true - -Sum { - x.1 = c128[] parameter(0) - y.1 = c128[] parameter(1) - ROOT add.1 = c128[] add(x.1, y.1) -} - -fused_computation { - param_0 = c128[10000]{0} parameter(0) - param_1 = c128[] parameter(1) - ROOT out1.1 = c128[] reduce(c128[10000]{0} param_0, c128[] param_1), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter = c128[10000] parameter(0) - init_value = c128[] constant((0, 0)) - ROOT wrapped_out1 = c128[] fusion(c128[10000]{0} parameter, c128[] init_value), kind=kInput, calls=fused_computation -} - -// ----- - -// Check that for "max" we are still using atomics (CAS loop). - -// CHECK: target datalayout -// CHECK: cmpxchg - -HloModule MinReduce, is_scheduled=true - -Min { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT min.1 = f32[] minimum(x.1, y.1) -} - -fused_computation { - param_0 = f32[300000]{0} parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce.1 = f32[] reduce(f32[300000]{0} param_0, f32[] param_1), dimensions={0}, to_apply=Min -} - -ENTRY reduce.1 { - parameter = f32[300000] parameter(0) - init_value = f32[] constant(0) - ROOT wrapped_reduce = f32[] fusion(f32[300000]{0} parameter, f32[] init_value), kind=kInput, calls=fused_computation -} - -// ----- - -// CHECK: define void @wrapped_vectorized_reduce( -// CHECK-COUNT-12: call void @Sum -// CHECK-NOT: call void @Sum - -HloModule Abc1, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fused_computation { - param_0 = f32[64,1048576]{1,0} parameter(0) - param_1 = f32[] parameter(1) - ROOT vectorized_reduce.1 = f32[1048576]{0} reduce(f32[64,1048576]{1,0} param_0, f32[] param_1), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter = f32[64,1048576] parameter(0) - init_value = f32[] constant(0) - ROOT wrapped_vectorized_reduce = f32[1048576]{0} fusion(f32[64,1048576]{1,0} parameter, f32[] init_value), kind=kInput, calls=fused_computation -} - -// ----- - -// CHECK-LABEL: define void @fusion_not_vectorized_too_many_large_params( -// CHECK-COUNT-6: call void @Sum -// CHECK-NOT: call void @Sum - -HloModule ColumnReductionNotVectorizedTooManyLargeParams, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -// Not vectorized because there are too many large parameters. -fusion_not_vectorized { - a = f32[64,65536] parameter(0) - b = f32[128,65536] parameter(1) - c = f32[128,65536] parameter(2) - d = f32[128,65536] parameter(3) - - sum.1 = f32[128,65536] add(b, c) - sum.2 = f32[128,65536] add(sum.1, d) - slice = f32[1,1] slice(sum.2), slice={[0:1],[0:1]} - init = f32[] reshape(f32[1,1] slice) - - ROOT reduce = f32[65536] reduce(a, init), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[64,65536] parameter(0) - parameter1 = f32[128,65536] parameter(1) - parameter2 = f32[128,65536] parameter(2) - parameter3 = f32[128,65536] parameter(3) - ROOT fusion_not_vectorized_too_many_large_params = f32[65536] fusion( - f32[64,65536] parameter0, - f32[128,65536] parameter1, - f32[128,65536] parameter2, - f32[128,65536] parameter3 - ), kind=kLoop, calls=fusion_not_vectorized -} - -// ----- - -// CHECK: define void @fusion_vectorized_some_large_params( -// CHECK-COUNT-12: call void @Sum -// CHECK-NOT: call void @Sum - -HloModule ColumnReductionVectorizedSomeLargeParams, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[64,65536] parameter(0) - b = f32[128,65536] parameter(1) - c = f32[128,65536] parameter(2) - - sum = f32[128,65536] add(b, c) - slice = f32[1,1] slice(sum), slice={[0:1],[0:1]} - init = f32[] reshape(f32[1,1] slice) - - ROOT reduce = f32[65536] reduce(a, init), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[64,65536] parameter(0) - parameter1 = f32[128,65536] parameter(1) - parameter2 = f32[128,65536] parameter(2) - ROOT fusion_vectorized_some_large_params = f32[65536] fusion( - f32[64,65536] parameter0, - f32[128,65536] parameter1, - f32[128,65536] parameter2 - ), kind=kLoop, calls=fusion_vectorized -} - -// ----- - -// CHECK: define void @fusion_vectorized_non_elementwise( -// CHECK-COUNT-12: call void @Sum -// CHECK-NOT: call void @Sum - -HloModule ColumnReductionVectorizedNonElementwise, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[64,65536] parameter(0) - b = f32[128,65536] parameter(1) - - slice = f32[1,1] slice(b), slice={[0:1],[0:1]} - init = f32[] reshape(f32[1,1] slice) - - reverse = f32[64,65536] reverse(f32[64,65536] a), dimensions={0} - ROOT reduce = f32[65536] reduce(reverse, init), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[64,65536] parameter(0) - parameter1 = f32[128,65536] parameter(1) - ROOT fusion_vectorized_non_elementwise = f32[65536] fusion( - f32[64,65536] parameter0, - f32[128,65536] parameter1 - ), kind=kLoop, calls=fusion_vectorized -} - -// ----- - -// CHECK: define void @fusion_not_vectorized_non_elementwise( -// CHECK-COUNT-6: call void @Sum -// CHECK-NOT: call void @Sum - -HloModule ColumnReductionNotVectorizedNonElementwise, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -// Not vectorized because there are two large parameters and the `a` parameter -// goes through a non-elementwise instruction. -fusion_not_vectorized { - a = f32[64,65536] parameter(0) - b = f32[128,65536] parameter(1) - c = f32[128,65536] parameter(2) - - sum = f32[128,65536] add(b, c) - slice = f32[1,1] slice(sum), slice={[0:1],[0:1]} - init = f32[] reshape(f32[1,1] slice) - - reverse = f32[64,65536] reverse(f32[64,65536] a), dimensions={0} - ROOT reduce = f32[65536] reduce(reverse, init), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[64,65536] parameter(0) - parameter1 = f32[128,65536] parameter(1) - parameter2 = f32[128,65536] parameter(2) - ROOT fusion_not_vectorized_non_elementwise = f32[65536] fusion( - f32[64,65536] parameter0, - f32[128,65536] parameter1, - f32[128,65536] parameter2 - ), kind=kLoop, calls=fusion_not_vectorized -} - -// ----- - -// CHECK: define void @fusion_row_reduction_vectorized( -// CHECK-COUNT-2: {{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[131072,1024] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1024] parameter(0) - ROOT fusion_row_reduction_vectorized = f32[131072] fusion( - f32[131072,1024] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} - -// ----- +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s // CHECK: define void @fusion_row_reduction_too_small( // CHECK-COUNT-1: {{^x_in_tile-true}} @@ -369,69 +80,3 @@ ENTRY reduce.1 { f32[131072,1024] parameter0 ), kind=kLoop, calls=fusion_not_vectorized } - -// ----- - -// TODO(jreiffers): This should most likely not be unrolled. The heuristic only -// checks instructions that are directly in the fusion, not nested computations. - -// CHECK: define void @fusion_row_reduction_sin_does_not_prevent_vectorization( -// CHECK-COUNT-2: {{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - sin = f32[] sine(y.1) - ROOT add.1 = f32[] add(x.1, sin) -} - -fusion_vectorized { - a = f32[131072,1024] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1024] parameter(0) - ROOT fusion_row_reduction_sin_does_not_prevent_vectorization = f32[131072] fusion( - f32[131072,1024] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} - -// ----- - -// CHECK: define void @vectorized_col_reduction_exceeding_shmem_budget( -// CHECK-COUNT-12: call void @add -// CHECK-NOT: call void @add - -// We are trying to have a column reduction that: -// - triggers vectorization (thus large number of elements 1048576) -// - has a small "smallest input size" (1 for pred) -// - exceeds the shmem budget because `num_partial_results` is 8 - -HloModule m, is_scheduled=true - -add { - a = f64[] parameter(0) - b = f64[] parameter(1) - ROOT out = f64[] add(a, b) -} - -fused_computation { - p1 = f64[1048576,1048576]{1,0} parameter(0) - p2 = f64[1048576,1048576]{1,0} parameter(1) - s = pred[1048576,1048576]{1,0} parameter(2) - p = f64[1048576,1048576]{1,0} select(s, p1, p2) - z = f64[] constant(0) - ROOT out = f64[1048576]{0} reduce(p, z), to_apply=add, dimensions={0} -} - -ENTRY e { - p1 = f64[1048576,1048576]{1,0} parameter(0) - p2 = f64[1048576,1048576]{1,0} parameter(1) - s = pred[1048576,1048576]{1,0} parameter(2) - ROOT vectorized_col_reduction_exceeding_shmem_budget = f64[1048576]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation -} diff --git a/xla/service/gpu/tests/reduce_variadic_column.hlo b/xla/service/gpu/tests/reduce_variadic_column.hlo new file mode 100644 index 0000000000000..2032a64930717 --- /dev/null +++ b/xla/service/gpu/tests/reduce_variadic_column.hlo @@ -0,0 +1,464 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} + +HloModule Test, is_scheduled=true + +Add { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + scalar_lhs.1 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) + add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +fused_computation { + param_0 = f32[5,200,300]{2,1,0} parameter(0) + param_1 = f32[5,200,300]{2,1,0} parameter(1) + param_2 = f32[] parameter(2) + ROOT d.1 = (f32[200]{0}, f32[200]{0}) reduce(f32[5,200,300]{2,1,0} param_0, f32[5,200,300]{2,1,0} %param_1, f32[] param_2, f32[] param_2), dimensions={0,2}, to_apply=Add +} + +ENTRY main { + a = f32[5, 200, 300]{2,1,0} parameter(0) + b = f32[5, 200, 300]{2,1,0} parameter(1) + c = f32[] constant(0) + ROOT wrapped_d = (f32[200]{0}, f32[200]{0}) fusion(f32[5,200,300]{2,1,0} a, f32[5,200,300]{2,1,0} b, f32[] c), kind=kInput, calls=fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca float, align 4 +// CHECK: %[[VAL_1:.*]] = alloca float, align 4 +// CHECK: %[[VAL_2:.*]] = alloca float, align 4 +// CHECK: %[[VAL_3:.*]] = alloca float, align 4 +// CHECK: %[[VAL_4:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_5:.*]] = alloca float, align 4 +// CHECK: %[[VAL_6:.*]] = alloca float, align 4 +// CHECK: %[[VAL_7:.*]] = alloca float, align 4 +// CHECK: %[[VAL_8:.*]] = alloca float, align 4 +// CHECK: %[[VAL_9:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_10:.*]] = alloca float, align 4 +// CHECK: %[[VAL_11:.*]] = alloca float, align 4 +// CHECK: %[[VAL_12:.*]] = alloca float, align 4 +// CHECK: %[[VAL_13:.*]] = alloca float, align 4 +// CHECK: %[[VAL_14:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_15:.*]] = alloca float, align 4 +// CHECK: %[[VAL_16:.*]] = alloca float, align 4 +// CHECK: %[[VAL_17:.*]] = alloca float, align 4 +// CHECK: %[[VAL_18:.*]] = alloca float, align 4 +// CHECK: %[[VAL_19:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_20:.*]] = alloca float, align 4 +// CHECK: %[[VAL_21:.*]] = alloca float, align 4 +// CHECK: %[[VAL_22:.*]] = alloca float, align 4 +// CHECK: %[[VAL_23:.*]] = alloca float, align 4 +// CHECK: %[[VAL_24:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_25:.*]] = alloca float, align 4 +// CHECK: %[[VAL_26:.*]] = alloca float, align 4 +// CHECK: %[[VAL_27:.*]] = alloca float, align 4 +// CHECK: %[[VAL_28:.*]] = alloca float, align 4 +// CHECK: %[[VAL_29:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_30:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_31:.*]] = alloca float, align 4 +// CHECK: %[[VAL_32:.*]] = alloca float, align 4 +// CHECK: %[[VAL_33:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_34:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_35:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_36:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_37:.*]] = alloca float, align 4 +// CHECK: %[[VAL_38:.*]] = alloca float, align 4 +// CHECK: %[[VAL_39:.*]] = alloca float, align 4 +// CHECK: %[[VAL_40:.*]] = alloca float, align 4 +// CHECK-PTX: %[[VAL_41:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 +// CHECK-GCN: %[[VAL_41:.*]] = call i32 @llvm.amdgcn.workgroup.id.y +// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_41]], 0 +// CHECK: br i1 %[[VAL_42]], label %[[VAL_43:.*]], label %[[VAL_44:.*]] +// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_45:.*]] +// CHECK: ret void +// CHECK: reduce-group-0-true: ; preds = %[[VAL_45]] +// CHECK: %[[VAL_46:.*]] = load float, ptr{{.*}}%[[VAL_47:.*]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_46]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: %[[VAL_48:.*]] = load float, ptr{{.*}}%[[VAL_47]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_48]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !5 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_49:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_49]], 8 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_50:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_51:.*]] = urem i32 %[[VAL_50]], 1 +// CHECK: %[[VAL_52:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_53:.*]] = urem i32 %[[VAL_52]], 25 +// CHECK: %[[VAL_54:.*]] = udiv i32 %block.id.x, 25 +// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_51]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_55]], i32 300, i32 512 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_54]], 5 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_53]], 8 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_51]], 512 +// CHECK: store i32 0, ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: br label %[[VAL_56:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_57:.*]], %[[VAL_43]] +// CHECK: %[[VAL_58:.*]] = load i32, ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: %[[VAL_59:.*]] = icmp uge i32 %[[VAL_58]], 5 +// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_61:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_62:.*]] = add nuw nsw i32 %[[VAL_58]], 1 +// CHECK: store i32 %[[VAL_62]], ptr{{.*}}%[[VAL_36]], align 4 +// CHECK: %[[VAL_63:.*]] = icmp eq i32 %[[VAL_58]], 0 +// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: br label %[[VAL_64:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_65:.*]], %[[VAL_61]] +// CHECK: %[[VAL_66:.*]] = load i32, ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: %[[VAL_67:.*]] = icmp uge i32 %[[VAL_66]], 8 +// CHECK: br i1 %[[VAL_67]], label %[[VAL_57]], label %[[VAL_68:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_64]] +// CHECK: %[[VAL_69:.*]] = add nuw nsw i32 %[[VAL_66]], 8 +// CHECK: store i32 %[[VAL_69]], ptr{{.*}}%[[VAL_35]], align 4 +// CHECK: %[[VAL_70:.*]] = icmp eq i32 %[[VAL_66]], %thread.id.1 +// CHECK: %[[VAL_71:.*]] = icmp eq i32 512, %tile_bound.2 +// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_73:.*]] +// CHECK: is_full_tile-after: ; preds = %[[VAL_74:.*]], %[[VAL_75:.*]] +// CHECK: br label %[[VAL_64]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_64]] +// CHECK: br label %[[VAL_56]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_76:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_77:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_76]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_76_1:.*]] = bitcast float %[[VAL_76]] to i32 +// CHECK-GCN: %[[VAL_77_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_76_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_77:.*]] = bitcast i32 %[[VAL_77_1:.*]] to float +// CHECK: store float %[[VAL_77]], ptr{{.*}}%[[VAL_26]], align 4 +// CHECK: %[[VAL_78:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_79:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_78]], i32 16, i32 31) +// CHECK-GCN: %[[VAL_78_1:.*]] = bitcast float %[[VAL_78]] to i32 +// CHECK-GCN: %[[VAL_79_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_78_1:.*]], i32 16) +// CHECK-GCN: %[[VAL_79:.*]] = bitcast i32 %[[VAL_79_1:.*]] to float +// CHECK: store float %[[VAL_79]], ptr{{.*}}%[[VAL_25]], align 4 +// CHECK-GCN: %[[VAL_22_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_22]] to ptr +// CHECK: %[[VAL_80:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_22]], ptr %[[VAL_80]], align 8 +// CHECK-GCN: store ptr %[[VAL_22_1]], ptr{{.*}}%[[VAL_80]], align 8 +// CHECK-GCN: %[[VAL_23_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_23]] to ptr +// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_23]], ptr %[[VAL_81]], align 8 +// CHECK-GCN: store ptr %[[VAL_23_1]], ptr{{.*}}%[[VAL_81]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_26]], ptr %[[VAL_25]], ptr %[[VAL_24]]) +// CHECK-GCN: %[[VAL_39_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_26_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_26]] to ptr +// CHECK-GCN: %[[VAL_25_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_25]] to ptr +// CHECK-GCN: %[[VAL_24_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_24]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_1]], ptr %[[VAL_37_1]], ptr %[[VAL_26_1]], ptr %[[VAL_25_1]], ptr %[[VAL_24_1]]) +// CHECK: %[[VAL_82:.*]] = load float, ptr{{.*}}%[[VAL_22]], align 4 +// CHECK: %[[VAL_83:.*]] = load float, ptr{{.*}}%[[VAL_23]], align 4 +// CHECK: store float %[[VAL_82]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_83]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_84:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_85:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_84]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_84_1:.*]] = bitcast float %[[VAL_84]] to i32 +// CHECK-GCN: %[[VAL_85_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_84_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_85:.*]] = bitcast i32 %[[VAL_85_1:.*]] to float +// CHECK: store float %[[VAL_85]], ptr{{.*}}%[[VAL_21]], align 4 +// CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 8, i32 31) +// CHECK-GCN: %[[VAL_86_1:.*]] = bitcast float %[[VAL_86]] to i32 +// CHECK-GCN: %[[VAL_87_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_86_1:.*]], i32 8) +// CHECK-GCN: %[[VAL_87:.*]] = bitcast i32 %[[VAL_87_1:.*]] to float +// CHECK: store float %[[VAL_87]], ptr{{.*}}%[[VAL_20]], align 4 +// CHECK-GCN: %[[VAL_17_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_17]] to ptr +// CHECK: %[[VAL_88:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_17]], ptr %[[VAL_88]], align 8 +// CHECK-GCN: store ptr %[[VAL_17_1]], ptr{{.*}}%[[VAL_88]], align 8 +// CHECK-GCN: %[[VAL_18_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_18]] to ptr +// CHECK: %[[VAL_89:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_18]], ptr %[[VAL_89]], align 8 +// CHECK-GCN: store ptr %[[VAL_18_1]], ptr{{.*}}%[[VAL_89]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_21]], ptr %[[VAL_20]], ptr %[[VAL_19]]) +// CHECK-GCN: %[[VAL_39_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_21_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_21]] to ptr +// CHECK-GCN: %[[VAL_20_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_20]] to ptr +// CHECK-GCN: %[[VAL_19_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_19]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_2]], ptr %[[VAL_37_2]], ptr %[[VAL_21_2]], ptr %[[VAL_20_2]], ptr %[[VAL_19_2]]) +// CHECK: %[[VAL_90:.*]] = load float, ptr{{.*}}%[[VAL_17]], align 4 +// CHECK: %[[VAL_91:.*]] = load float, ptr{{.*}}%[[VAL_18]], align 4 +// CHECK: store float %[[VAL_90]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_91]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_92_1:.*]] = bitcast float %[[VAL_92]] to i32 +// CHECK-GCN: %[[VAL_93_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_92_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_1:.*]] to float +// CHECK: store float %[[VAL_93]], ptr{{.*}}%[[VAL_16]], align 4 +// CHECK: %[[VAL_94:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_95:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_94]], i32 4, i32 31) +// CHECK-GCN: %[[VAL_94_1:.*]] = bitcast float %[[VAL_94]] to i32 +// CHECK-GCN: %[[VAL_95_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_94_1:.*]], i32 4) +// CHECK-GCN: %[[VAL_95:.*]] = bitcast i32 %[[VAL_95_1:.*]] to float +// CHECK: store float %[[VAL_95]], ptr{{.*}}%[[VAL_15]], align 4 +// CHECK-GCN: %[[VAL_12_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_12]] to ptr +// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_12]], ptr %[[VAL_96]], align 8 +// CHECK-GCN: store ptr %[[VAL_12_1]], ptr{{.*}}%[[VAL_96]], align 8 +// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr +// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_13]], ptr %[[VAL_97]], align 8 +// CHECK-GCN: store ptr %[[VAL_13_1]], ptr{{.*}}%[[VAL_97]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_16]], ptr %[[VAL_15]], ptr %[[VAL_14]]) +// CHECK-GCN: %[[VAL_39_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_16_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_16]] to ptr +// CHECK-GCN: %[[VAL_15_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_15]] to ptr +// CHECK-GCN: %[[VAL_14_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_3]], ptr %[[VAL_37_3]], ptr %[[VAL_16_3]], ptr %[[VAL_15_3]], ptr %[[VAL_14_3]]) +// CHECK: %[[VAL_98:.*]] = load float, ptr{{.*}}%[[VAL_12]], align 4 +// CHECK: %[[VAL_99:.*]] = load float, ptr{{.*}}%[[VAL_13]], align 4 +// CHECK: store float %[[VAL_98]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_99]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_100:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_101:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_100]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_100_1:.*]] = bitcast float %[[VAL_100]] to i32 +// CHECK-GCN: %[[VAL_101_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_100_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_101:.*]] = bitcast i32 %[[VAL_101_1:.*]] to float +// CHECK: store float %[[VAL_101]], ptr{{.*}}%[[VAL_11]], align 4 +// CHECK: %[[VAL_102:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_103:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_102]], i32 2, i32 31) +// CHECK-GCN: %[[VAL_102_1:.*]] = bitcast float %[[VAL_102]] to i32 +// CHECK-GCN: %[[VAL_103_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_102_1:.*]], i32 2) +// CHECK-GCN: %[[VAL_103:.*]] = bitcast i32 %[[VAL_103_1:.*]] to float +// CHECK: store float %[[VAL_103]], ptr{{.*}}%[[VAL_10]], align 4 +// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr +// CHECK: %[[VAL_104:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_7]], ptr %[[VAL_104]], align 8 +// CHECK-GCN: store ptr %[[VAL_7_1]], ptr{{.*}}%[[VAL_104]], align 8 +// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr +// CHECK: %[[VAL_105:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_8]], ptr %[[VAL_105]], align 8 +// CHECK-GCN: store ptr %[[VAL_8_1]], ptr{{.*}}%[[VAL_105]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_11]], ptr %[[VAL_10]], ptr %[[VAL_9]]) +// CHECK-GCN: %[[VAL_39_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_11_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_11]] to ptr +// CHECK-GCN: %[[VAL_10_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr +// CHECK-GCN: %[[VAL_9_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_4]], ptr %[[VAL_37_4]], ptr %[[VAL_11_4]], ptr %[[VAL_10_4]], ptr %[[VAL_9_4]]) +// CHECK: %[[VAL_106:.*]] = load float, ptr{{.*}}%[[VAL_7]], align 4 +// CHECK: %[[VAL_107:.*]] = load float, ptr{{.*}}%[[VAL_8]], align 4 +// CHECK: store float %[[VAL_106]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_107]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_108:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK-PTX: %[[VAL_109:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_108]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_108_1:.*]] = bitcast float %[[VAL_108]] to i32 +// CHECK-GCN: %[[VAL_109_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_108_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_109:.*]] = bitcast i32 %[[VAL_109_1:.*]] to float +// CHECK: store float %[[VAL_109]], ptr{{.*}}%[[VAL_6]], align 4 +// CHECK: %[[VAL_110:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK-PTX: %[[VAL_111:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_110]], i32 1, i32 31) +// CHECK-GCN: %[[VAL_110_1:.*]] = bitcast float %[[VAL_110]] to i32 +// CHECK-GCN: %[[VAL_111_1:.*]] = call i32 @__ockl_readuplane_i32(i32 %[[VAL_110_1:.*]], i32 1) +// CHECK-GCN: %[[VAL_111:.*]] = bitcast i32 %[[VAL_111_1:.*]] to float +// CHECK: store float %[[VAL_111]], ptr{{.*}}%[[VAL_5]], align 4 +// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr +// CHECK: %[[VAL_112:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_2]], ptr %[[VAL_112]], align 8 +// CHECK-GCN: store ptr %[[VAL_2_1]], ptr{{.*}}%[[VAL_112]], align 8 +// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr +// CHECK: %[[VAL_113:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_3]], ptr %[[VAL_113]], align 8 +// CHECK-GCN: store ptr %[[VAL_3_1]], ptr{{.*}}%[[VAL_113]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_6]], ptr %[[VAL_5]], ptr %[[VAL_4]]) +// CHECK-GCN: %[[VAL_39_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_6_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr +// CHECK-GCN: %[[VAL_5_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr +// CHECK-GCN: %[[VAL_4_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_5]], ptr %[[VAL_37_5]], ptr %[[VAL_6_5]], ptr %[[VAL_5_5]], ptr %[[VAL_4_5]]) +// CHECK: %[[VAL_114:.*]] = load float, ptr{{.*}}%[[VAL_2]], align 4 +// CHECK: %[[VAL_115:.*]] = load float, ptr{{.*}}%[[VAL_3]], align 4 +// CHECK: store float %[[VAL_114]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_115]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_116:.*]] = udiv i32 %thread.id.2, 32 +// CHECK: %[[VAL_117:.*]] = icmp ult i32 %thread.id.1, 8 +// CHECK: br i1 %[[VAL_117]], label %thread_in_bounds-true, label %thread_in_bounds-after +// CHECK: thread_in_bounds-after: ; preds = %[[VAL_118:.*]], %[[VAL_60]] +// CHECK: br label %[[VAL_44]] +// CHECK: is_full_tile-true: ; preds = %[[VAL_68]] +// CHECK: store i32 0, ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: br label %[[VAL_119:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_120:.*]], %[[VAL_72]] +// CHECK: %[[VAL_121:.*]] = load i32, ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: %[[VAL_122:.*]] = icmp uge i32 %[[VAL_121]], 512 +// CHECK: br i1 %[[VAL_122]], label %[[VAL_75]], label %[[VAL_120]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_119]] +// CHECK: %[[VAL_123:.*]] = add nuw nsw i32 %[[VAL_121]], 32 +// CHECK: store i32 %[[VAL_123]], ptr{{.*}}%[[VAL_34]], align 4 +// CHECK: %[[VAL_124:.*]] = icmp eq i32 %[[VAL_121]], 0 +// CHECK: %[[VAL_125:.*]] = add i32 %[[VAL_121]], %thread.id.2 +// CHECK: %[[VAL_126:.*]] = add i32 %tile_origin.0, %[[VAL_58]] +// CHECK: %[[VAL_127:.*]] = add i32 %tile_origin.1, %[[VAL_66]] +// CHECK: %[[VAL_128:.*]] = add i32 %tile_origin.2, %[[VAL_125]] +// CHECK: %[[VAL_129:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_130:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] +// CHECK: %[[VAL_131:.*]] = load float, ptr{{.*}}%[[VAL_129]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_131]], ptr{{.*}}%[[VAL_40]], align 4 +// CHECK: %[[VAL_132:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] +// CHECK: %[[VAL_134:.*]] = load float, ptr{{.*}}%[[VAL_132]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_134]], ptr{{.*}}%[[VAL_38]], align 4 +// CHECK-GCN: %[[VAL_31_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_31]] to ptr +// CHECK: %[[VAL_135:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_31]], ptr %[[VAL_135]], align 8 +// CHECK-GCN: store ptr %[[VAL_31_1]], ptr{{.*}}%[[VAL_135]], align 8 +// CHECK-GCN: %[[VAL_32_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_32]] to ptr +// CHECK: %[[VAL_136:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_32]], ptr %[[VAL_136]], align 8 +// CHECK-GCN: store ptr %[[VAL_32_1]], ptr{{.*}}%[[VAL_136]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_33]]) +// CHECK-GCN: %[[VAL_39_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_40_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr +// CHECK-GCN: %[[VAL_38_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr +// CHECK-GCN: %[[VAL_33_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_33]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_6]], ptr %[[VAL_37_6]], ptr %[[VAL_40_6]], ptr %[[VAL_38_6]], ptr %[[VAL_33_6]]) +// CHECK: %[[VAL_137:.*]] = load float, ptr{{.*}}%[[VAL_31]], align 4 +// CHECK: %[[VAL_138:.*]] = load float, ptr{{.*}}%[[VAL_32]], align 4 +// CHECK: store float %[[VAL_137]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_138]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: br label %[[VAL_119]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_119]] +// CHECK: br label %[[VAL_65]] +// CHECK: is_full_tile-false: ; preds = %[[VAL_68]] +// CHECK: store i32 0, ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: br label %[[VAL_139:.*]] +// CHECK: loop2.loop_header9: ; preds = %[[VAL_140:.*]], %[[VAL_73]] +// CHECK: %[[VAL_141:.*]] = load i32, ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: %[[VAL_142:.*]] = icmp uge i32 %[[VAL_141]], 512 +// CHECK: br i1 %[[VAL_142]], label %[[VAL_74]], label %[[VAL_143:.*]] +// CHECK: loop2.loop_body10: ; preds = %[[VAL_139]] +// CHECK: %[[VAL_144:.*]] = add nuw nsw i32 %[[VAL_141]], 32 +// CHECK: store i32 %[[VAL_144]], ptr{{.*}}%[[VAL_30]], align 4 +// CHECK: %[[VAL_145:.*]] = icmp eq i32 %[[VAL_141]], 0 +// CHECK: %[[VAL_146:.*]] = add i32 %[[VAL_141]], %thread.id.2 +// CHECK: %[[VAL_147:.*]] = icmp ult i32 %[[VAL_146]], %tile_bound.2 +// CHECK: br i1 %[[VAL_147]], label %[[VAL_148:.*]], label %[[VAL_140]] +// CHECK: x_in_tile-after: ; preds = %[[VAL_148]], %[[VAL_143]] +// CHECK: br label %[[VAL_139]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit8: ; preds = %[[VAL_139]] +// CHECK: br label %[[VAL_65]] +// CHECK: x_in_tile-true: ; preds = %[[VAL_143]] +// CHECK: %[[VAL_149:.*]] = add i32 %tile_origin.0, %[[VAL_58]] +// CHECK: %[[VAL_150:.*]] = add i32 %tile_origin.1, %[[VAL_66]] +// CHECK: %[[VAL_151:.*]] = add i32 %tile_origin.2, %[[VAL_146]] +// CHECK: %[[VAL_152:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_130]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 %[[VAL_151]] +// CHECK: %[[VAL_153:.*]] = load float, ptr{{.*}}%[[VAL_152]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_153]], ptr{{.*}}%[[VAL_40]], align 4 +// CHECK: %[[VAL_154:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 %[[VAL_151]] +// CHECK: %[[VAL_155:.*]] = load float, ptr{{.*}}%[[VAL_154]], align 4, !invariant.load !{{[0-9]}} +// CHECK: store float %[[VAL_155]], ptr{{.*}}%[[VAL_38]], align 4 +// CHECK-GCN: %[[VAL_27_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_27]] to ptr +// CHECK: %[[VAL_156:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_27]], ptr %[[VAL_156]], align 8 +// CHECK-GCN: store ptr %[[VAL_27_1]], ptr{{.*}}%[[VAL_156]], align 8 +// CHECK-GCN: %[[VAL_28_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_28]] to ptr +// CHECK: %[[VAL_157:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_28]], ptr %[[VAL_157]], align 8 +// CHECK-GCN: store ptr %[[VAL_28_1]], ptr{{.*}}%[[VAL_157]], align 8 +// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_29]]) +// CHECK-GCN: %[[VAL_39_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr +// CHECK-GCN: %[[VAL_37_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr +// CHECK-GCN: %[[VAL_40_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr +// CHECK-GCN: %[[VAL_38_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr +// CHECK-GCN: %[[VAL_29_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_29]] to ptr +// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_7]], ptr %[[VAL_37_7]], ptr %[[VAL_40_7]], ptr %[[VAL_38_7]], ptr %[[VAL_29_7]]) +// CHECK: %[[VAL_158:.*]] = load float, ptr{{.*}}%[[VAL_27]], align 4 +// CHECK: %[[VAL_159:.*]] = load float, ptr{{.*}}%[[VAL_28]], align 4 +// CHECK: store float %[[VAL_158]], ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: store float %[[VAL_159]], ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: br label %[[VAL_140]] +// CHECK: thread_in_bounds-true: ; preds = %[[VAL_60]] +// CHECK: %[[VAL_160:.*]] = icmp eq i32 %lane_id, 0 +// CHECK: br i1 %[[VAL_160]], label %[[VAL_161:.*]], label %[[VAL_162:.*]] +// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_161]], %thread_in_bounds-true +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: %[[VAL_163:.*]] = icmp eq i32 %[[VAL_116]], 0 +// CHECK: br i1 %[[VAL_163]], label %[[VAL_164:.*]], label %[[VAL_118]] +// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_165:.*]], %[[VAL_162]] +// CHECK: br label %thread_in_bounds-after +// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true +// CHECK: %[[VAL_166:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 +// CHECK: %[[VAL_167:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] +// CHECK: %[[VAL_168:.*]] = addrspacecast ptr addrspace(3) %[[VAL_167]] to ptr +// CHECK: store float %[[VAL_166]], ptr{{.*}}%[[VAL_168]], align 4 +// CHECK: %[[VAL_169:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 +// CHECK: %[[VAL_170:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] +// CHECK: %[[VAL_171:.*]] = addrspacecast ptr addrspace(3) %[[VAL_170]] to ptr +// CHECK: store float %[[VAL_169]], ptr{{.*}}%[[VAL_171]], align 4 +// CHECK: br label %[[VAL_162]] +// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_162]] +// CHECK: %[[VAL_172:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_173:.*]] = addrspacecast ptr addrspace(3) %[[VAL_172]] to ptr +// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr +// CHECK-PTX: store float %[[VAL_46]], ptr %[[VAL_1]], align 4 +// CHECK-GCN: store float %[[VAL_46]], ptr %[[VAL_1_1]], align 4 +// CHECK: %[[VAL_174:.*]] = icmp ult i32 %thread.id.2, 1 +// CHECK-PTX: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1]] +// CHECK-GCN: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1_1]] +// CHECK: %[[VAL_176:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %lane_id +// CHECK: %[[VAL_177:.*]] = addrspacecast ptr addrspace(3) %[[VAL_176]] to ptr +// CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_0]] to ptr +// CHECK-PTX: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0]], align 4 +// CHECK-GCN: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0_1]], align 4 +// CHECK: %[[VAL_178:.*]] = icmp ult i32 %thread.id.2, 1 +// CHECK-PTX: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0]] +// CHECK-GCN: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0_1]] +// CHECK: %[[VAL_180:.*]] = icmp eq i32 %thread.id.2, 0 +// CHECK: br i1 %[[VAL_180]], label %[[VAL_181:.*]], label %[[VAL_165]] +// CHECK: reduction_write_output-after: ; preds = %[[VAL_181]], %[[VAL_164]] +// CHECK: br label %[[VAL_118]] +// CHECK: reduction_write_output-true: ; preds = %[[VAL_164]] +// CHECK: %[[VAL_183:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_186:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_187:.*]], i32 0, i32 %[[VAL_183]] +// CHECK: %[[VAL_188:.*]] = load float, ptr{{.*}}%[[VAL_175]], align 4 +// CHECK: store float %[[VAL_188]], ptr{{.*}}%[[VAL_186]], align 4 +// CHECK: %[[VAL_190:.*]] = add i32 %tile_origin.1, %thread.id.1 +// CHECK: %[[VAL_193:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_194:.*]], i32 0, i32 %[[VAL_190]] +// CHECK: %[[VAL_195:.*]] = load float, ptr{{.*}}%[[VAL_179]], align 4 +// CHECK: store float %[[VAL_195]], ptr{{.*}}%[[VAL_193]], align 4 +// CHECK: br label %[[VAL_165]] +// CHECK: entry: +// CHECK: %[[VAL_196:.*]] = alloca float, align 4 +// CHECK: %[[VAL_197:.*]] = alloca float, align 4 +// CHECK: %[[VAL_198:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_199:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_200:.*]] = alloca [2 x ptr], align 8 +// CHECK: %[[VAL_201:.*]] = load float, ptr{{.*}}%[[VAL_202:.*]], align 4 +// CHECK: %[[VAL_203:.*]] = load float, ptr{{.*}}%[[VAL_204:.*]], align 4 +// CHECK: %[[VAL_205:.*]] = fadd float %[[VAL_201]], %[[VAL_203]] +// CHECK: store float %[[VAL_205]], ptr{{.*}}%[[VAL_197]], align 4 +// CHECK: %[[VAL_206:.*]] = load float, ptr{{.*}}%[[VAL_207:.*]], align 4 +// CHECK: %[[VAL_208:.*]] = load float, ptr{{.*}}%[[VAL_209:.*]], align 4 +// CHECK: %[[VAL_210:.*]] = fadd float %[[VAL_206]], %[[VAL_208]] +// CHECK: store float %[[VAL_210]], ptr{{.*}}%[[VAL_196]], align 4 +// CHECK-GCN: %[[VAL_197_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_197]] to ptr +// CHECK: %[[VAL_211:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 +// CHECK-PTX: store ptr %[[VAL_197]], ptr %[[VAL_211]], align 8 +// CHECK-GCN: store ptr %[[VAL_197_1]], ptr{{.*}}%[[VAL_211]], align 8 +// CHECK-GCN: %[[VAL_196_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_196]] to ptr +// CHECK: %[[VAL_212:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 +// CHECK-PTX: store ptr %[[VAL_196]], ptr %[[VAL_212]], align 8 +// CHECK-GCN: store ptr %[[VAL_196_1]], ptr{{.*}}%[[VAL_212]], align 8 +// CHECK: %[[VAL_213:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214:.*]], i64 0, i64 0 +// CHECK: %[[VAL_215:.*]] = load ptr, ptr{{.*}}%[[VAL_213]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_216:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 +// CHECK: %[[VAL_217:.*]] = load ptr, ptr{{.*}}%[[VAL_216]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_218:.*]] = load float, ptr{{.*}}%[[VAL_217]], align 4 +// CHECK: store float %[[VAL_218]], ptr{{.*}}%[[VAL_215]], align 4 +// CHECK: %[[VAL_219:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214]], i64 0, i64 1 +// CHECK: %[[VAL_220:.*]] = load ptr, ptr{{.*}}%[[VAL_219]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_221:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 +// CHECK: %[[VAL_222:.*]] = load ptr, ptr{{.*}}%[[VAL_221]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} +// CHECK: %[[VAL_223:.*]] = load float, ptr{{.*}}%[[VAL_222]], align 4 +// CHECK: store float %[[VAL_223]], ptr{{.*}}%[[VAL_220]], align 4 +// CHECK: ret void + diff --git a/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc index 1cda938d7c55c..bb6eb634db78a 100644 --- a/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc +++ b/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,9 @@ limitations under the License. #include "xla/service/gpu/reduction_degenerate_dim_remover.h" #include -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" -#include "xla/statusor.h" -#include "xla/tests/filecheck.h" +#include "absl/strings/string_view.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/xla/service/gpu/tests/reduction_dimension_grouper_test.cc index afae3c455ad55..fa149a13b940c 100644 --- a/xla/service/gpu/tests/reduction_dimension_grouper_test.cc +++ b/xla/service/gpu/tests/reduction_dimension_grouper_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,9 @@ limitations under the License. #include "xla/service/gpu/reduction_dimension_grouper.h" #include -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/filecheck.h" +#include "absl/strings/string_view.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/reduction_emitter_test.cc b/xla/service/gpu/tests/reduction_emitter_test.cc index 25aec5332d366..b6ecfd527dee5 100644 --- a/xla/service/gpu/tests/reduction_emitter_test.cc +++ b/xla/service/gpu/tests/reduction_emitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include - +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/xla/service/gpu/tests/reduction_layout_normalizer_test.cc index 50d4141da4e35..817d9c73c95b1 100644 --- a/xla/service/gpu/tests/reduction_layout_normalizer_test.cc +++ b/xla/service/gpu/tests/reduction_layout_normalizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,10 @@ limitations under the License. #include "xla/service/gpu/reduction_layout_normalizer.h" #include -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/filecheck.h" +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo b/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo index fb3509310802a..6a25580a4bcff 100644 --- a/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo +++ b/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo @@ -1,8 +1,7 @@ -// RUN: hlo_to_llvm_ir --ptx %s | FileCheck %s -// RUN: hlo_to_llvm_ir --ptx --sm=50 %s | FileCheck %s --check-prefix=CHECK-SM50 -// RUN: hlo_to_llvm_ir --ptx --sm=60 %s | FileCheck %s --check-prefix=CHECK-SM60 -// RUN: hlo_to_llvm_ir --ptx --sm=70 %s | FileCheck %s --check-prefix=CHECK-SM70 -// RUN: hlo_to_llvm_ir --ptx --sm=86 %s | FileCheck %s --check-prefix=CHECK-SM86 +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/p100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM60 +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM70 +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a6000.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM86 // CHECK-LABEL: .entry wrapped_reduce_odd_row // CHECK-NOT: ld.global.nc.v2.f32 @@ -33,7 +32,7 @@ ENTRY %main { // ----- // CHECK-SM86-LABEL: .entry wrapped_reduce_small_row -// CHECK-SM86: .reqntid 96, 1, 1 +// CHECK-SM86: .reqntid 256, 1, 1 HloModule ReduceSmallRow, is_scheduled=true @@ -90,10 +89,6 @@ ENTRY %main { // SM dependent tests -// CHECK-SM50-LABEL: .entry wrapped_reduce_exp -// CHECK-SM50-NOT: ld.global.nc.v2.f32 -// CHECK-SM50-COUNT-8: ld.global.nc.f32 - // CHECK-SM60: .entry wrapped_exp // CHECK-SM60-LABEL: .entry wrapped_reduce_exp // CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 @@ -132,10 +127,6 @@ ENTRY %main { HloModule ReduceTileFit, is_scheduled=true -// CHECK-SM50-LABEL: .entry wrapped_reduce_tile_fit -// CHECK-SM50-NOT: ld.global.nc.v2.f32 -// CHECK-SM50-COUNT-8: ld.global.nc.f32 - // CHECK-SM60-LABEL: .entry wrapped_reduce_tile_fit // CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 @@ -164,10 +155,6 @@ ENTRY %main { HloModule ReducePower2, is_scheduled=true -// CHECK-SM50-LABEL: .entry wrapped_reduce_pow_2 -// CHECK-SM50-NOT: ld.global.nc.v2.f32 -// CHECK-SM50-COUNT-8: ld.global.nc.f32 - // CHECK-SM60-LABEL: .entry wrapped_reduce_pow_2 // CHECK-SM60-COUNT-4: ld.global.nc.v2.f32 @@ -202,7 +189,7 @@ HloModule ReduceEvenColumns, is_scheduled=true // CHECK-SM70-LABEL: .entry wrapped_reduce_even_col // CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 -// CHECK-SM70-COUNT-2: ld.global.nc.f32 +// CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 %max_ { %x = f32[] parameter(0) diff --git a/xla/service/gpu/tests/reduction_vectorization_test.cc b/xla/service/gpu/tests/reduction_vectorization_test.cc index b216afba8b5ee..3a4f912a205d3 100644 --- a/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "absl/strings/str_replace.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/gpu_executable.h" +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" -#include "xla/statusor.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/tests/rng_get_and_update_state.hlo b/xla/service/gpu/tests/rng_get_and_update_state.hlo index 297d60f372050..e140b56af9d60 100644 --- a/xla/service/gpu/tests/rng_get_and_update_state.hlo +++ b/xla/service/gpu/tests/rng_get_and_update_state.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s HloModule TestModule, is_scheduled=true diff --git a/xla/service/gpu/tests/scatter.hlo b/xla/service/gpu/tests/scatter.hlo index 9edff4fc317a8..20211bdbe892f 100644 --- a/xla/service/gpu/tests/scatter.hlo +++ b/xla/service/gpu/tests/scatter.hlo @@ -1,152 +1,181 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_2:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_3:.*]] = mul nuw nsw i32 %[[VAL_1]], 6 // CHECK: %[[VAL_4:.*]] = add nuw nsw i32 %[[VAL_3]], %[[VAL_2]] // CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_4]], 6 // CHECK: call void @llvm.assume(i1 %[[VAL_5]]) -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_4]], 1 -// CHECK: %[[VAL_7:.*]] = urem i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_8:.*]] = udiv i32 %[[VAL_4]], 3 -// CHECK: %[[VAL_9:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: br i1 %[[VAL_9]], label %[[VAL_10:.*]], label %[[VAL_11:.*]] -// CHECK: scatter_TensorFlowScatterV1.in_bounds-after: ; preds = %[[VAL_12:.*]], %[[VAL_13:.*]] +// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_4]], 0 +// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 3 +// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 3 +// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_4]], 6 +// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] +// CHECK: scatter_TensorFlowScatterV1.in_bounds-after: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] // CHECK: ret void -// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_13]] -// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_15:.*]], i32 0, i32 %[[VAL_8]] -// CHECK: %[[VAL_16:.*]] = load i32, ptr %[[VAL_14]], align 4, !invariant.load -// CHECK: %[[VAL_17:.*]] = add i32 0, %[[VAL_16]] -// CHECK: %[[VAL_18:.*]] = icmp ult i32 %[[VAL_16]], 3 -// CHECK: %[[VAL_19:.*]] = and i1 true, %[[VAL_18]] -// CHECK: br i1 %[[VAL_19]], label %[[VAL_20:.*]], label %[[VAL_12]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_20]], %[[VAL_10]] -// CHECK: br label %[[VAL_11]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_10]] -// CHECK: %[[VAL_21:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_22:.*]], i32 0, i32 %[[VAL_17]], i32 %[[VAL_7]] -// CHECK: %[[VAL_23:.*]] = getelementptr inbounds i32, ptr %[[VAL_24:.*]], i32 %[[VAL_4]] -// CHECK: %[[VAL_25:.*]] = load i32, ptr %[[VAL_23]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_25]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 4 -// CHECK: %[[VAL_26:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 4 -// CHECK: store atomic i32 %[[VAL_26]], ptr %[[VAL_21]] unordered, align 4 +// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_14]] +// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_16:.*]], i32 0, i32 %[[VAL_9]] +// CHECK: %[[VAL_17:.*]] = load i32, ptr %[[VAL_15]], align 4, !invariant.load +// CHECK: %[[VAL_18:.*]] = add i32 0, %[[VAL_17]] +// CHECK: %[[VAL_19:.*]] = icmp ult i32 %[[VAL_17]], 3 +// CHECK: %[[VAL_20:.*]] = and i1 true, %[[VAL_19]] +// CHECK: br i1 %[[VAL_20]], label %[[VAL_21:.*]], label %[[VAL_13]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_21]], %[[VAL_11]] // CHECK: br label %[[VAL_12]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_11]] +// CHECK: %[[VAL_22:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_23:.*]], i32 0, i32 %[[VAL_18]], i32 %[[VAL_8]] +// CHECK: %[[VAL_24:.*]] = getelementptr i32, ptr %[[VAL_25:.*]], i32 %[[VAL_4]] +// CHECK: %[[VAL_26:.*]] = getelementptr inbounds i32, ptr %[[VAL_24]], i32 0 +// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_26]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_27]], ptr %[[VAL_0]], align 4 +// CHECK-GCN: store i32 %[[VAL_27]], ptr addrspace(5) %[[VAL_0]], align 4 +// CHECK-PTX: %[[VAL_28:.*]] = load i32, ptr %[[VAL_0]], align 4 +// CHECK-GCN: %[[VAL_28:.*]] = load i32, ptr addrspace(5) %[[VAL_0]], align 4 +// CHECK: store atomic i32 %[[VAL_28]], ptr %[[VAL_22]] unordered, align 4 +// CHECK: br label %[[VAL_13]] // CHECK: entry: -// CHECK: %[[VAL_27:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_28:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_29:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_30:.*]] = mul nuw nsw i32 %[[VAL_28]], 1 -// CHECK: %[[VAL_31:.*]] = add nuw nsw i32 %[[VAL_30]], %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = icmp ult i32 %[[VAL_31]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_32]]) -// CHECK: %[[VAL_33:.*]] = icmp ult i32 %[[VAL_31]], 1 -// CHECK: br i1 %[[VAL_33]], label %[[VAL_34:.*]], label %[[VAL_35:.*]] -// CHECK: scatter_ScatterIntoScalar.in_bounds-after: ; preds = %[[VAL_36:.*]], %[[VAL_37:.*]] +// CHECK: %[[VAL_29:.*]] = alloca i32, align 4 +// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_31:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_31:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_32:.*]] = mul nuw nsw i32 %[[VAL_30]], 1 +// CHECK: %[[VAL_33:.*]] = add nuw nsw i32 %[[VAL_32]], %[[VAL_31]] +// CHECK: %[[VAL_34:.*]] = icmp ult i32 %[[VAL_33]], 1 +// CHECK: call void @llvm.assume(i1 %[[VAL_34]]) +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_33]], 0 +// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], 1 +// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]] +// CHECK: scatter_ScatterIntoScalar.in_bounds-after: ; preds = %[[VAL_39:.*]], %[[VAL_40:.*]] // CHECK: ret void -// CHECK: scatter_ScatterIntoScalar.in_bounds-true: ; preds = %[[VAL_37]] -// CHECK: br i1 true, label %[[VAL_38:.*]], label %[[VAL_36]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_38]], %[[VAL_34]] -// CHECK: br label %[[VAL_35]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_34]] -// CHECK: %[[VAL_39:.*]] = load i32, ptr %[[VAL_40:.*]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_39]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_27]], align 4 -// CHECK: %[[VAL_41:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_27]], align 4 -// CHECK: store atomic i32 %[[VAL_41]], ptr %[[VAL_42:.*]] unordered, align 4 -// CHECK: br label %[[VAL_36]] +// CHECK: scatter_ScatterIntoScalar.in_bounds-true: ; preds = %[[VAL_40]] +// CHECK: br i1 true, label %[[VAL_41:.*]], label %[[VAL_39]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_41]], %[[VAL_37]] +// CHECK: br label %[[VAL_38]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_42:.*]] = load i32, ptr %[[VAL_43:.*]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_42]], ptr %[[VAL_29]], align 4 +// CHECK-GCN: store i32 %[[VAL_42]], ptr addrspace(5) %[[VAL_29]], align 4 +// CHECK-PTX: %[[VAL_44:.*]] = load i32, ptr %[[VAL_29]], align 4 +// CHECK-GCN: %[[VAL_44:.*]] = load i32, ptr addrspace(5) %[[VAL_29]], align 4 +// CHECK: store atomic i32 %[[VAL_44]], ptr %[[VAL_45:.*]] unordered, align 4 +// CHECK: br label %[[VAL_39]] // CHECK: entry: -// CHECK: %[[VAL_43:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_44:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_45:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_46:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_47:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_48:.*]] = mul nuw nsw i32 %[[VAL_46]], 6 -// CHECK: %[[VAL_49:.*]] = add nuw nsw i32 %[[VAL_48]], %[[VAL_47]] -// CHECK: %[[VAL_50:.*]] = icmp ult i32 %[[VAL_49]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_50]]) -// CHECK: %[[VAL_51:.*]] = udiv i32 %[[VAL_49]], 1 -// CHECK: %[[VAL_52:.*]] = urem i32 %[[VAL_51]], 3 -// CHECK: %[[VAL_53:.*]] = udiv i32 %[[VAL_49]], 3 -// CHECK: %[[VAL_54:.*]] = icmp ult i32 %[[VAL_49]], 6 -// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] -// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-after: ; preds = %[[VAL_57:.*]], %[[VAL_58:.*]] +// CHECK: %[[VAL_46:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_47:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_48:.*]] = alloca i32, align 4 +// CHECK-PTX: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_49:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_50:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 6 +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_51]], %[[VAL_50]] +// CHECK: %[[VAL_53:.*]] = icmp ult i32 %[[VAL_52]], 6 +// CHECK: call void @llvm.assume(i1 %[[VAL_53]]) +// CHECK: %[[VAL_54:.*]] = add nuw nsw i32 %[[VAL_52]], 0 +// CHECK: %[[VAL_55:.*]] = udiv i32 %[[VAL_54]], 1 +// CHECK: %[[VAL_56:.*]] = urem i32 %[[VAL_55]], 3 +// CHECK: %[[VAL_57:.*]] = udiv i32 %[[VAL_54]], 3 +// CHECK: %[[VAL_58:.*]] = icmp ult i32 %[[VAL_52]], 6 +// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_60:.*]] +// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-after: ; preds = %[[VAL_61:.*]], %[[VAL_62:.*]] // CHECK: ret void -// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_58]] -// CHECK: %[[VAL_59:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_60:.*]], i32 0, i32 %[[VAL_53]] -// CHECK: %[[VAL_61:.*]] = load i32, ptr %[[VAL_59]], align 4, !invariant.load -// CHECK: %[[VAL_62:.*]] = add i32 0, %[[VAL_61]] -// CHECK: %[[VAL_63:.*]] = icmp ult i32 %[[VAL_61]], 3 -// CHECK: %[[VAL_64:.*]] = and i1 true, %[[VAL_63]] -// CHECK: br i1 %[[VAL_64]], label %[[VAL_65:.*]], label %[[VAL_57]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_66:.*]], %[[VAL_55]] -// CHECK: br label %[[VAL_56]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]] -// CHECK: %[[VAL_67:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_68:.*]], i32 0, i32 %[[VAL_62]], i32 %[[VAL_52]] -// CHECK: %[[VAL_69:.*]] = getelementptr inbounds i32, ptr %[[VAL_70:.*]], i32 %[[VAL_49]] -// CHECK: %[[VAL_71:.*]] = load i32, ptr %[[VAL_69]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_71]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_45]], align 4 -// CHECK: %[[VAL_72:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_45]], align 4 -// CHECK: %[[VAL_73:.*]] = load i32, ptr %[[VAL_67]], align 4 -// CHECK: store i32 %[[VAL_73]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_44]], align 4 -// CHECK: br label %[[VAL_74:.*]] -// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_75:.*]], %[[VAL_74]] -// CHECK: br label %[[VAL_57]] -// CHECK: atomic_op_loop_body: ; preds = %[[VAL_75]], %[[VAL_65]] -// CHECK: %[[VAL_76:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_44]], align 4 -// CHECK: store i32 %[[VAL_76]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_43]], align 4 -// CHECK-GCN: %[[VAL_43_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_43]] to ptr -// CHECK-GCN: %[[VAL_45_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_45]] to ptr -// CHECK-GCN: %[[VAL_43_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_43]] to ptr -// CHECK-GCN: call void @mul_{{.*}}(ptr %[[VAL_43_2]], ptr %[[VAL_45_2]], ptr %[[VAL_43_3]]) -// CHECK-PTX: call void @mul_{{.*}}(ptr %[[VAL_43]], ptr %[[VAL_45]], ptr %[[VAL_43]]) -// CHECK: %[[VAL_77:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_43]], align 4 -// CHECK: %[[VAL_78:.*]] = icmp eq i32 %[[VAL_76]], %[[VAL_77]] -// CHECK: br i1 %[[VAL_78]], label %[[VAL_66]], label %[[VAL_75]] -// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_74]] -// CHECK-PTX: %[[VAL_79:.*]] = cmpxchg ptr %[[VAL_67]], i32 %[[VAL_76]], i32 %[[VAL_77]] seq_cst seq_cst, align 4 -// CHECK-GCN: %[[VAL_79:.*]] = cmpxchg ptr %[[VAL_67]], i32 %[[VAL_76]], i32 %[[VAL_77]] {{.*}} seq_cst seq_cst, align 4 -// CHECK: %[[VAL_80:.*]] = extractvalue { i32, i1 } %[[VAL_79]], 0 -// CHECK: store i32 %[[VAL_80]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_44]], align 4 -// CHECK: %[[VAL_81:.*]] = extractvalue { i32, i1 } %[[VAL_79]], 1 -// CHECK: br i1 %[[VAL_81]], label %[[VAL_66]], label %[[VAL_74]] +// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_62]] +// CHECK: %[[VAL_63:.*]] = getelementptr inbounds [2 x i32], ptr %[[VAL_64:.*]], i32 0, i32 %[[VAL_57]] +// CHECK: %[[VAL_65:.*]] = load i32, ptr %[[VAL_63]], align 4, !invariant.load +// CHECK: %[[VAL_66:.*]] = add i32 0, %[[VAL_65]] +// CHECK: %[[VAL_67:.*]] = icmp ult i32 %[[VAL_65]], 3 +// CHECK: %[[VAL_68:.*]] = and i1 true, %[[VAL_67]] +// CHECK: br i1 %[[VAL_68]], label %[[VAL_69:.*]], label %[[VAL_61]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_70:.*]], %[[VAL_59]] +// CHECK: br label %[[VAL_60]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_59]] +// CHECK: %[[VAL_71:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_72:.*]], i32 0, i32 %[[VAL_66]], i32 %[[VAL_56]] +// CHECK: %[[VAL_73:.*]] = getelementptr i32, ptr %[[VAL_74:.*]], i32 %[[VAL_52]] +// CHECK: %[[VAL_75:.*]] = getelementptr inbounds i32, ptr %[[VAL_73]], i32 0 +// CHECK: %[[VAL_76:.*]] = load i32, ptr %[[VAL_75]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_76]], ptr %[[VAL_48]], align 4 +// CHECK-GCN: store i32 %[[VAL_76]], ptr addrspace(5) %[[VAL_48]], align 4 +// CHECK-PTX: %[[VAL_77:.*]] = load i32, ptr %[[VAL_48]], align 4 +// CHECK-GCN: %[[VAL_77:.*]] = load i32, ptr addrspace(5) %[[VAL_48]], align 4 +// CHECK: %[[VAL_78:.*]] = load i32, ptr %[[VAL_71]], align 4 +// CHECK-PTX: store i32 %[[VAL_78]], ptr %[[VAL_47]], align 4 +// CHECK-GCN: store i32 %[[VAL_78]], ptr addrspace(5) %[[VAL_47]], align 4 +// CHECK: br label %[[VAL_79:.*]] +// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_80:.*]], %[[VAL_79]] +// CHECK: br label %[[VAL_61]] +// CHECK: atomic_op_loop_body: ; preds = %[[VAL_80]], %[[VAL_69]] +// CHECK-PTX: %[[VAL_81:.*]] = load i32, ptr %[[VAL_47]], align 4 +// CHECK-GCN: %[[VAL_81:.*]] = load i32, ptr addrspace(5) %[[VAL_47]], align 4 +// CHECK-PTX: store i32 %[[VAL_81]], ptr %[[VAL_46]], align 4 +// CHECK-GCN: store i32 %[[VAL_81]], ptr addrspace(5) %[[VAL_46]], align 4 +// CHECK-GCN: %[[VAL_46_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_46]] to ptr +// CHECK-GCN: %[[VAL_48_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_48]] to ptr +// CHECK-GCN: %[[VAL_46_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_46]] to ptr +// CHECK-GCN: call void @mul_{{.*}}(ptr %[[VAL_46_2]], ptr %[[VAL_48_2]], ptr %[[VAL_46_3]]) +// CHECK-PTX: call void @mul_{{.*}}(ptr %[[VAL_46]], ptr %[[VAL_48]], ptr %[[VAL_46]]) +// CHECK-PTX: %[[VAL_82:.*]] = load i32, ptr %[[VAL_46]], align 4 +// CHECK-GCN: %[[VAL_82:.*]] = load i32, ptr addrspace(5) %[[VAL_46]], align 4 +// CHECK: %[[VAL_83:.*]] = icmp eq i32 %[[VAL_81]], %[[VAL_82]] +// CHECK: br i1 %[[VAL_83]], label %[[VAL_70]], label %[[VAL_80]] +// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_79]] +// CHECK-PTX: %[[VAL_84:.*]] = cmpxchg ptr %[[VAL_71]], i32 %[[VAL_81]], i32 %[[VAL_82]] seq_cst seq_cst, align 4 +// CHECK-GCN: %[[VAL_84:.*]] = cmpxchg ptr %[[VAL_71]], i32 %[[VAL_81]], i32 %[[VAL_82]] {{.*}} seq_cst seq_cst, align 4 +// CHECK: %[[VAL_85:.*]] = extractvalue { i32, i1 } %[[VAL_84]], 0 +// CHECK-PTX: store i32 %[[VAL_85]], ptr %[[VAL_47]], align 4 +// CHECK-GCN: store i32 %[[VAL_85]], ptr addrspace(5) %[[VAL_47]], align 4 +// CHECK: %[[VAL_86:.*]] = extractvalue { i32, i1 } %[[VAL_84]], 1 +// CHECK: br i1 %[[VAL_86]], label %[[VAL_70]], label %[[VAL_79]] // CHECK: entry: -// CHECK: %[[VAL_82:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_83:.*]] = load i32, ptr %[[VAL_84:.*]], align 4 -// CHECK: %[[VAL_85:.*]] = load i32, ptr %[[VAL_86:.*]], align 4 -// CHECK: %[[VAL_87:.*]] = mul i32 %[[VAL_83]], %[[VAL_85]] -// CHECK: store i32 %[[VAL_87]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_82]], align 4 -// CHECK: %[[VAL_88:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_82]], align 4 -// CHECK: store i32 %[[VAL_88]], ptr %[[VAL_89:.*]], align 4 +// CHECK: %[[VAL_87:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_88:.*]] = load i32, ptr %[[VAL_89:.*]], align 4 +// CHECK: %[[VAL_90:.*]] = load i32, ptr %[[VAL_91:.*]], align 4 +// CHECK: %[[VAL_92:.*]] = mul i32 %[[VAL_88]], %[[VAL_90]] +// CHECK-PTX: store i32 %[[VAL_92]], ptr %[[VAL_87]], align 4 +// CHECK-GCN: store i32 %[[VAL_92]], ptr addrspace(5) %[[VAL_87]], align 4 +// CHECK-PTX: %[[VAL_93:.*]] = load i32, ptr %[[VAL_87]], align 4 +// CHECK-GCN: %[[VAL_93:.*]] = load i32, ptr addrspace(5) %[[VAL_87]], align 4 +// CHECK: store i32 %[[VAL_93]], ptr %[[VAL_94:.*]], align 4 // CHECK: ret void // CHECK: entry: -// CHECK: %[[VAL_90:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_91:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_92:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_93:.*]] = mul nuw nsw i32 %[[VAL_91]], 1 -// CHECK: %[[VAL_94:.*]] = add nuw nsw i32 %[[VAL_93]], %[[VAL_92]] -// CHECK: %[[VAL_95:.*]] = icmp ult i32 %[[VAL_94]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_95]]) -// CHECK: %[[VAL_96:.*]] = icmp ult i32 %[[VAL_94]], 1 -// CHECK: br i1 %[[VAL_96]], label %[[VAL_97:.*]], label %[[VAL_98:.*]] -// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_99:.*]], %[[VAL_100:.*]] +// CHECK: %[[VAL_95:.*]] = alloca i32, align 4 +// CHECK-PTX: %[[VAL_96:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_96:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_97:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_97:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_98:.*]] = mul nuw nsw i32 %[[VAL_96]], 1 +// CHECK: %[[VAL_99:.*]] = add nuw nsw i32 %[[VAL_98]], %[[VAL_97]] +// CHECK: %[[VAL_100:.*]] = icmp ult i32 %[[VAL_99]], 1 +// CHECK: call void @llvm.assume(i1 %[[VAL_100]]) +// CHECK: %[[VAL_101:.*]] = add nuw nsw i32 %[[VAL_99]], 0 +// CHECK: %[[VAL_102:.*]] = icmp ult i32 %[[VAL_99]], 1 +// CHECK: br i1 %[[VAL_102]], label %[[VAL_103:.*]], label %[[VAL_104:.*]] +// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_105:.*]], %[[VAL_106:.*]] // CHECK: ret void -// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_100]] -// CHECK: %[[VAL_101:.*]] = load i32, ptr %[[VAL_102:.*]], align 4, !invariant.load -// CHECK: %[[VAL_103:.*]] = add i32 0, %[[VAL_101]] -// CHECK: %[[VAL_104:.*]] = icmp ult i32 %[[VAL_101]], 4 -// CHECK: %[[VAL_105:.*]] = and i1 true, %[[VAL_104]] -// CHECK: br i1 %[[VAL_105]], label %[[VAL_106:.*]], label %[[VAL_99]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_106]], %[[VAL_97]] -// CHECK: br label %[[VAL_98]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_97]] -// CHECK: %[[VAL_107:.*]] = getelementptr inbounds [4 x i32], ptr %[[VAL_108:.*]], i32 0, i32 %[[VAL_103]] -// CHECK: %[[VAL_109:.*]] = load i32, ptr %[[VAL_110:.*]], align 4, !invariant.load -// CHECK: store i32 %[[VAL_109]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_90]], align 4 -// CHECK: %[[VAL_111:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_90]], align 4 -// CHECK: store atomic i32 %[[VAL_111]], ptr %[[VAL_107]] unordered, align 4 -// CHECK: br label %[[VAL_99]] +// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_107:.*]] = load i32, ptr %[[VAL_108:.*]], align 4, !invariant.load +// CHECK: %[[VAL_109:.*]] = add i32 0, %[[VAL_107]] +// CHECK: %[[VAL_110:.*]] = icmp ult i32 %[[VAL_107]], 4 +// CHECK: %[[VAL_111:.*]] = and i1 true, %[[VAL_110]] +// CHECK: br i1 %[[VAL_111]], label %[[VAL_112:.*]], label %[[VAL_105]] +// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_112]], %[[VAL_103]] +// CHECK: br label %[[VAL_104]] +// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_103]] +// CHECK: %[[VAL_113:.*]] = getelementptr inbounds [4 x i32], ptr %[[VAL_114:.*]], i32 0, i32 %[[VAL_109]] +// CHECK: %[[VAL_115:.*]] = load i32, ptr %[[VAL_116:.*]], align 4, !invariant.load +// CHECK-PTX: store i32 %[[VAL_115]], ptr %[[VAL_95]], align 4 +// CHECK-GCN: store i32 %[[VAL_115]], ptr addrspace(5) %[[VAL_95]], align 4 +// CHECK-PTX: %[[VAL_117:.*]] = load i32, ptr %[[VAL_95]], align 4 +// CHECK-GCN: %[[VAL_117:.*]] = load i32, ptr addrspace(5) %[[VAL_95]], align 4 +// CHECK: store atomic i32 %[[VAL_117]], ptr %[[VAL_113]] unordered, align 4 +// CHECK: br label %[[VAL_105]] HloModule TensorFlowScatterV1, is_scheduled=true @@ -263,3 +292,35 @@ ENTRY main { p2 = s32[] parameter(2) ROOT wrapped_scatter = s32[4] fusion(p0, p1, p2), kind=kInput, calls=fused_computation } + +// ----- + + +HloModule TensorFlowScatter_Add, is_scheduled=true + +add_f16 (lhs: f16[], rhs: f16[]) -> f16[] { + lhs = f16[] parameter(0) + rhs = f16[] parameter(1) + ROOT add = f16[] add(f16[] lhs, f16[] rhs) +} + +fused_computation { + operand = f16[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f16[2,3] parameter(2) + ROOT scatter_TensorFlowScatter_Mul = f16[3,3] scatter(operand, indices, updates), + to_apply=add_f16, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} + +ENTRY main { + p0 = f16[3,3] parameter(0) + p1 = s32[2] parameter(1) + p2 = f16[2,3] parameter(2) + ROOT wrapped_scatter = f16[3,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation +} + +// CHECK-PTX: atomicrmw fadd diff --git a/xla/service/gpu/tests/select_and_scatter.hlo b/xla/service/gpu/tests/select_and_scatter.hlo index 1bbe30d156c8a..08751943c13ef 100644 --- a/xla/service/gpu/tests/select_and_scatter.hlo +++ b/xla/service/gpu/tests/select_and_scatter.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py @@ -8,94 +8,116 @@ // CHECK: %[[VAL_2:.*]] = alloca i1, align 1 // CHECK: %[[VAL_3:.*]] = alloca i32, align 4 // CHECK: %[[VAL_4:.*]] = alloca float, align 4 -// CHECK: %[[VAL_5:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_6:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_6:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_7:.*]] = mul nuw nsw i32 %[[VAL_5]], 2 // CHECK: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_7]], %[[VAL_6]] // CHECK: %[[VAL_9:.*]] = icmp ult i32 %[[VAL_8]], 2 // CHECK: call void @llvm.assume(i1 %[[VAL_9]]) -// CHECK: %[[VAL_10:.*]] = udiv i32 %[[VAL_8]], 1 -// CHECK: %[[VAL_11:.*]] = icmp ult i32 %[[VAL_8]], 2 -// CHECK: br i1 %[[VAL_11]], label %[[VAL_12:.*]], label %[[VAL_13:.*]] -// CHECK: select_and_scatter_12.in_bounds-after: ; preds = %[[VAL_14:.*]], %[[VAL_15:.*]] +// CHECK: %[[VAL_10:.*]] = add nuw nsw i32 %[[VAL_8]], 0 +// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_10]], 1 +// CHECK: %[[VAL_12:.*]] = icmp ult i32 %[[VAL_8]], 2 +// CHECK: br i1 %[[VAL_12]], label %[[VAL_13:.*]], label %[[VAL_14:.*]] +// CHECK: select_and_scatter_12.in_bounds-after: ; preds = %[[VAL_15:.*]], %[[VAL_16:.*]] // CHECK: ret void -// CHECK: select_and_scatter_12.in_bounds-true: ; preds = %[[VAL_15]] -// CHECK: store i1 false, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_2]], align 1 -// CHECK: store i32 0, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_1]], align 4 -// CHECK: br label %[[VAL_16:.*]] -// CHECK: select_and_scatter_12inner.loop_header.window.0: ; preds = %[[VAL_17:.*]], %[[VAL_12]] -// CHECK: %[[VAL_18:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_1]], align 4 -// CHECK: %[[VAL_19:.*]] = icmp uge i32 %[[VAL_18]], 3 -// CHECK: br i1 %[[VAL_19]], label %[[VAL_20:.*]], label %[[VAL_21:.*]] -// CHECK: select_and_scatter_12inner.loop_body.window.0: ; preds = %[[VAL_16]] -// CHECK: %[[VAL_22:.*]] = mul nsw i32 %[[VAL_10]], 3 -// CHECK: %[[VAL_23:.*]] = add nsw i32 %[[VAL_22]], %[[VAL_18]] -// CHECK: %[[VAL_24:.*]] = sub nsw i32 %[[VAL_23]], 0 -// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[VAL_24]], 6 -// CHECK: %[[VAL_26:.*]] = and i1 true, %[[VAL_25]] -// CHECK: br i1 %[[VAL_26]], label %[[VAL_27:.*]], label %[[VAL_28:.*]] -// CHECK: in-bounds-after: ; preds = %[[VAL_28]], %[[VAL_29:.*]] -// CHECK: %[[VAL_30:.*]] = add nuw nsw i32 %[[VAL_18]], 1 -// CHECK: store i32 %[[VAL_30]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_1]], align 4 -// CHECK: br label %[[VAL_16]] -// CHECK: select_and_scatter_12inner.loop_exit.window.0: ; preds = %[[VAL_16]] -// CHECK: %[[VAL_31:.*]] = load i1, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_2]], align 1 -// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_14]] -// CHECK: should-store-after: ; preds = %[[VAL_32]], %[[VAL_20]] -// CHECK: br label %[[VAL_13]] -// CHECK: in-bounds-true: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_33:.*]] = load i1, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_2]], align 1 -// CHECK: br i1 %[[VAL_33]], label %[[VAL_34:.*]], label %[[VAL_35:.*]] -// CHECK: initialized-after: ; preds = %[[VAL_35]], %[[VAL_36:.*]] +// CHECK: select_and_scatter_12.in_bounds-true: ; preds = %[[VAL_16]] +// CHECK-PTX: store i1 false, ptr %[[VAL_2]], align 1 +// CHECK-GCN: store i1 false, ptr addrspace(5) %[[VAL_2]], align 1 +// CHECK-PTX: store i32 0, ptr %[[VAL_1]], align 4 +// CHECK-GCN: store i32 0, ptr addrspace(5) %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_17:.*]] +// CHECK: select_and_scatter_12inner.loop_header.window.0: ; preds = %[[VAL_18:.*]], %[[VAL_13]] +// CHECK-PTX: %[[VAL_19:.*]] = load i32, ptr %[[VAL_1]], align 4 +// CHECK-GCN: %[[VAL_19:.*]] = load i32, ptr addrspace(5) %[[VAL_1]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp uge i32 %[[VAL_19]], 3 +// CHECK: br i1 %[[VAL_20]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] +// CHECK: select_and_scatter_12inner.loop_body.window.0: ; preds = %[[VAL_17]] +// CHECK: %[[VAL_23:.*]] = mul nsw i32 %[[VAL_11]], 3 +// CHECK: %[[VAL_24:.*]] = add nsw i32 %[[VAL_23]], %[[VAL_19]] +// CHECK: %[[VAL_25:.*]] = sub nsw i32 %[[VAL_24]], 0 +// CHECK: %[[VAL_26:.*]] = icmp ult i32 %[[VAL_25]], 6 +// CHECK: %[[VAL_27:.*]] = and i1 true, %[[VAL_26]] +// CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]] +// CHECK: in-bounds-after: ; preds = %[[VAL_29]], %[[VAL_30:.*]] +// CHECK: %[[VAL_31:.*]] = add nuw nsw i32 %[[VAL_19]], 1 +// CHECK-PTX: store i32 %[[VAL_31]], ptr %[[VAL_1]], align 4 +// CHECK-GCN: store i32 %[[VAL_31]], ptr addrspace(5) %[[VAL_1]], align 4 // CHECK: br label %[[VAL_17]] -// CHECK: in-bounds-false: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_17]] -// CHECK: initialized-true: ; preds = %[[VAL_27]] -// CHECK: %[[VAL_37:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_38:.*]], i32 0, i32 %[[VAL_24]] +// CHECK: select_and_scatter_12inner.loop_exit.window.0: ; preds = %[[VAL_17]] +// CHECK-PTX: %[[VAL_32:.*]] = load i1, ptr %[[VAL_2]], align 1 +// CHECK-GCN: %[[VAL_32:.*]] = load i1, ptr addrspace(5) %[[VAL_2]], align 1 +// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_15]] +// CHECK: should-store-after: ; preds = %[[VAL_33]], %[[VAL_21]] +// CHECK: br label %[[VAL_14]] +// CHECK: in-bounds-true: ; preds = %[[VAL_22]] +// CHECK-PTX: %[[VAL_34:.*]] = load i1, ptr %[[VAL_2]], align 1 +// CHECK-GCN: %[[VAL_34:.*]] = load i1, ptr addrspace(5) %[[VAL_2]], align 1 +// CHECK: br i1 %[[VAL_34]], label %[[VAL_35:.*]], label %[[VAL_36:.*]] +// CHECK: initialized-after: ; preds = %[[VAL_36]], %[[VAL_37:.*]] +// CHECK: br label %[[VAL_18]] +// CHECK: in-bounds-false: ; preds = %[[VAL_22]] +// CHECK: br label %[[VAL_18]] +// CHECK: initialized-true: ; preds = %[[VAL_28]] +// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_39:.*]], i32 0, i32 %[[VAL_25]] // CHECK-GCN: %[[VAL_4_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_4]] to ptr // CHECK-GCN: %[[VAL_0_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_0]] to ptr -// CHECK-GCN: call void @ge_{{.*}}(ptr %[[VAL_4_2]], ptr %[[VAL_37]], ptr %[[VAL_0_2]]) -// CHECK-PTX: call void @ge_{{.*}}(ptr %[[VAL_4]], ptr %[[VAL_37]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_39:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 1 -// CHECK: %[[VAL_40:.*]] = icmp ne i8 %[[VAL_39]], 0 -// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_42:.*]] -// CHECK: if-select-lhs-after: ; preds = %[[VAL_42]], %[[VAL_41]] -// CHECK: br label %[[VAL_29]] -// CHECK: initialized-false: ; preds = %[[VAL_27]] -// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_38]], i32 0, i32 %[[VAL_24]] -// CHECK: %[[VAL_44:.*]] = load float, ptr %[[VAL_43]], align 4, !invariant.load -// CHECK: store float %[[VAL_44]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_4]], align 4 -// CHECK: %[[VAL_45:.*]] = getelementptr inbounds i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_3]], i32 0 -// CHECK: store i32 %[[VAL_24]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_45]], align 4 -// CHECK: store i1 true, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_2]], align 1 -// CHECK: br label %[[VAL_29]] -// CHECK: if-select-lhs-true: ; preds = %[[VAL_34]] -// CHECK: br label %[[VAL_36]] -// CHECK: if-select-lhs-false: ; preds = %[[VAL_34]] -// CHECK: %[[VAL_46:.*]] = load float, ptr %[[VAL_37]], align 4 -// CHECK: store float %[[VAL_46]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_4]], align 4 -// CHECK: %[[VAL_47:.*]] = getelementptr inbounds i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_3]], i32 0 -// CHECK: store i32 %[[VAL_24]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_47]], align 4 -// CHECK: br label %[[VAL_36]] -// CHECK: should-store-true: ; preds = %[[VAL_20]] -// CHECK: %[[VAL_48:.*]] = getelementptr inbounds i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_3]], i32 0 -// CHECK: %[[VAL_49:.*]] = load i32, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_48]], align 4 -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, ptr %[[VAL_51:.*]], i32 %[[VAL_8]] -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_53:.*]], i32 0, i32 %[[VAL_49]] -// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_50]], align 4 -// CHECK-GCN: %[[VAL_52_2:.*]] = addrspacecast ptr %[[VAL_52]] to ptr addrspace(1) -// CHECK-GCN: %[[VAL_55:.*]] = atomicrmw fadd ptr {{.*}} %[[VAL_52_2]], float %[[VAL_54]] {{.*}} seq_cst, align 4 -// CHECK-PTX: %[[VAL_55:.*]] = atomicrmw fadd ptr %[[VAL_52]], float %[[VAL_54]] seq_cst, align 4 -// CHECK: br label %[[VAL_14]] +// CHECK-GCN: call void @ge_{{.*}}(ptr %[[VAL_4_2]], ptr %[[VAL_38]], ptr %[[VAL_0_2]]) +// CHECK-PTX: call void @ge_{{.*}}(ptr %[[VAL_4]], ptr %[[VAL_38]], ptr %[[VAL_0]]) +// CHECK-PTX: %[[VAL_40:.*]] = load i8, ptr %[[VAL_0]], align 1 +// CHECK-GCN: %[[VAL_40:.*]] = load i8, ptr addrspace(5) %[[VAL_0]], align 1 +// CHECK: %[[VAL_41:.*]] = icmp ne i8 %[[VAL_40]], 0 +// CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_43:.*]] +// CHECK: if-select-lhs-after: ; preds = %[[VAL_43]], %[[VAL_42]] +// CHECK: br label %[[VAL_30]] +// CHECK: initialized-false: ; preds = %[[VAL_28]] +// CHECK: %[[VAL_44:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_39]], i32 0, i32 %[[VAL_25]] +// CHECK: %[[VAL_45:.*]] = load float, ptr %[[VAL_44]], align 4, !invariant.load +// CHECK-PTX: store float %[[VAL_45]], ptr %[[VAL_4]], align 4 +// CHECK-GCN: store float %[[VAL_45]], ptr addrspace(5) %[[VAL_4]], align 4 +// CHECK-PTX: %[[VAL_46:.*]] = getelementptr inbounds i32, ptr %[[VAL_3]], i32 0 +// CHECK-GCN: %[[VAL_46:.*]] = getelementptr inbounds i32, ptr addrspace(5) %[[VAL_3]], i32 0 +// CHECK-PTX: store i32 %[[VAL_25]], ptr %[[VAL_46]], align 4 +// CHECK-GCN: store i32 %[[VAL_25]], ptr addrspace(5) %[[VAL_46]], align 4 +// CHECK-PTX: store i1 true, ptr %[[VAL_2]], align 1 +// CHECK-GCN: store i1 true, ptr addrspace(5) %[[VAL_2]], align 1 +// CHECK: br label %[[VAL_30]] +// CHECK: if-select-lhs-true: ; preds = %[[VAL_35]] +// CHECK: br label %[[VAL_37]] +// CHECK: if-select-lhs-false: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_47:.*]] = load float, ptr %[[VAL_38]], align 4 +// CHECK-PTX: store float %[[VAL_47]], ptr %[[VAL_4]], align 4 +// CHECK-GCN: store float %[[VAL_47]], ptr addrspace(5) %[[VAL_4]], align 4 +// CHECK-PTX: %[[VAL_48:.*]] = getelementptr inbounds i32, ptr %[[VAL_3]], i32 0 +// CHECK-GCN: %[[VAL_48:.*]] = getelementptr inbounds i32, ptr addrspace(5) %[[VAL_3]], i32 0 +// CHECK-PTX: store i32 %[[VAL_25]], ptr %[[VAL_48]], align 4 +// CHECK-GCN: store i32 %[[VAL_25]], ptr addrspace(5) %[[VAL_48]], align 4 +// CHECK: br label %[[VAL_37]] +// CHECK: should-store-true: ; preds = %[[VAL_21]] +// CHECK-PTX: %[[VAL_49:.*]] = getelementptr inbounds i32, ptr %[[VAL_3]], i32 0 +// CHECK-GCN: %[[VAL_49:.*]] = getelementptr inbounds i32, ptr addrspace(5) %[[VAL_3]], i32 0 +// CHECK-PTX: %[[VAL_50:.*]] = load i32, ptr %[[VAL_49]], align 4 +// CHECK-GCN: %[[VAL_50:.*]] = load i32, ptr addrspace(5) %[[VAL_49]], align 4 +// CHECK: %[[VAL_51:.*]] = getelementptr float, ptr %[[VAL_52:.*]], i32 %[[VAL_8]] +// CHECK: %[[VAL_53:.*]] = getelementptr inbounds float, ptr %[[VAL_51]], i32 0 +// CHECK: %[[VAL_54:.*]] = getelementptr inbounds [6 x float], ptr %[[VAL_55:.*]], i32 0, i32 %[[VAL_50]] +// CHECK: %[[VAL_56:.*]] = load float, ptr %[[VAL_53]], align 4 +// CHECK-GCN: %[[VAL_54_2:.*]] = addrspacecast ptr %[[VAL_54]] to ptr addrspace(1) +// CHECK-GCN: %[[VAL_57:.*]] = atomicrmw fadd ptr {{.*}} %[[VAL_54_2]], float %[[VAL_56]] {{.*}} seq_cst, align 4 +// CHECK-PTX: %[[VAL_57:.*]] = atomicrmw fadd ptr %[[VAL_54]], float %[[VAL_56]] seq_cst, align 4 +// CHECK: br label %[[VAL_15]] // CHECK: entry: -// CHECK: %[[VAL_56:.*]] = alloca i8, align 1 -// CHECK: %[[VAL_57:.*]] = load float, ptr %[[VAL_58:.*]], align 4 +// CHECK: %[[VAL_58:.*]] = alloca i8, align 1 // CHECK: %[[VAL_59:.*]] = load float, ptr %[[VAL_60:.*]], align 4 -// CHECK: %[[VAL_61:.*]] = fcmp oge float %[[VAL_57]], %[[VAL_59]] -// CHECK: %[[VAL_62:.*]] = zext i1 %[[VAL_61]] to i8 -// CHECK: store i8 %[[VAL_62]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_56]], align 1 -// CHECK: %[[VAL_63:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_56]], align 1 -// CHECK: store i8 %[[VAL_63]], ptr %[[VAL_64:.*]], align 1 +// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_62:.*]], align 4 +// CHECK: %[[VAL_63:.*]] = fcmp oge float %[[VAL_59]], %[[VAL_61]] +// CHECK: %[[VAL_64:.*]] = zext i1 %[[VAL_63]] to i8 +// CHECK-PTX: store i8 %[[VAL_64]], ptr %[[VAL_58]], align 1 +// CHECK-GCN: store i8 %[[VAL_64]], ptr addrspace(5) %[[VAL_58]], align 1 +// CHECK-PTX: %[[VAL_65:.*]] = load i8, ptr %[[VAL_58]], align 1 +// CHECK-GCN: %[[VAL_65:.*]] = load i8, ptr addrspace(5) %[[VAL_58]], align 1 +// CHECK: store i8 %[[VAL_65]], ptr %[[VAL_66:.*]], align 1 // CHECK: ret void HloModule SelectAndScatter, is_scheduled=true diff --git a/xla/service/gpu/tests/select_and_scatter_test.cc b/xla/service/gpu/tests/select_and_scatter_test.cc index 9a371757013eb..0e72a55fd3549 100644 --- a/xla/service/gpu/tests/select_and_scatter_test.cc +++ b/xla/service/gpu/tests/select_and_scatter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/simple_optimization_test.cc b/xla/service/gpu/tests/simple_optimization_test.cc index 47154827994fd..a18d58d6df333 100644 --- a/xla/service/gpu/tests/simple_optimization_test.cc +++ b/xla/service/gpu/tests/simple_optimization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "absl/strings/string_view.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" diff --git a/xla/service/gpu/tests/simplify_fp_conversions_test.cc b/xla/service/gpu/tests/simplify_fp_conversions_test.cc new file mode 100644 index 0000000000000..9a73cc62e3177 --- /dev/null +++ b/xla/service/gpu/tests/simplify_fp_conversions_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { +namespace { + +class SimplifyFPConversionsTest : public HloTestBase { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_allow_excess_precision( + enable_simplify_all_fp_conversions_); + return debug_options; + } + + void SetEnableSimplifyFpConversions(bool enable_simplify_all_fp_conversions) { + enable_simplify_all_fp_conversions_ = enable_simplify_all_fp_conversions; + } + + static constexpr std::string_view kHloText = R"( +HloModule module + +ENTRY main { + param0 = bf16[1536]{0} parameter(0) + param1 = bf16[4,1536]{1,0} parameter(1) + + s = bf16[1536]{0} rsqrt(param0) + // Redundant conversions appear here when the algebraic simplifier + // pushes the broadcast op further down + b = bf16[4,1536]{1,0} broadcast(s), dimensions={1} + + ROOT d = bf16[4,1536]{1,0} multiply(b, param1) +} + )"; + + private: + bool enable_simplify_all_fp_conversions_ = false; +}; + +TEST_F(SimplifyFPConversionsTest, RedundantTypeConversionsGetCleanedUp) { + // The algebraic simplifier might expose redundant type conversions, + // i.e. f32 -> bf16 -> f32. This test ensures that they will get cleaned up + // eventually by the SimplifyFPConversion pass. + + SetEnableSimplifyFpConversions(true); + + // This matcher ensures that there will be no convert in between the rsqrt and + // the broadcast instruction. + MatchOptimizedHlo(kHloText, R"( +// CHECK: rsqrt( +// CHECK-NOT: convert( +// CHECK: broadcast( +)"); +} + +TEST_F(SimplifyFPConversionsTest, RedundantTypeConversionsArePresentInTest) { + // This test ensures that the HLO that we use in the previous test is actually + // meaningful and would lead to redundant type conversions if the simplifier + // didn't clean them up. + + SetEnableSimplifyFpConversions(false); + + MatchOptimizedHlo(kHloText, R"( +// CHECK: rsqrt( +// CHECK-NEXT: convert( +// CHECK-NEXT: convert( +// CHECK-NEXT: broadcast( +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/tests/single_instruction.hlo b/xla/service/gpu/tests/single_instruction.hlo index 50af63fb3a866..c8378f746aa98 100644 --- a/xla/service/gpu/tests/single_instruction.hlo +++ b/xla/service/gpu/tests/single_instruction.hlo @@ -1,4 +1,6 @@ -// RUN: hlo_to_llvm_ir --ptx %s | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a100_80.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM80 +// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/h100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM90 // CHECK-DAG: sqrt.approx.f32 @@ -61,3 +63,39 @@ ENTRY main { a = f32[] parameter(0) ROOT wrapped_b = f32[] fusion(f32[] a), kind=kLoop, calls=fused_computation } + +// ----- + +// CHECK-SM80: min.NaN.f32 + +HloModule Test, is_scheduled=true + +fused_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT b.1 = f32[] minimum(f32[] param_0, f32[] param_1) +} + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT wrapped_b = f32[] fusion(f32[] a, f32[] b), kind=kLoop, calls=fused_computation +} + +// ----- + +// CHECK-SM80: cvt.rn.f32.s16 +// CHECK-SM80: cvt.rn.bf16.f32 +// CHECK-SM90: cvt.rn.bf16.s16 + +HloModule Test, is_scheduled=true + +fused_computation { + param_0 = s16[] parameter(0) + ROOT b.1 = bf16[] convert(s16[] param_0) +} + +ENTRY main { + a = s16[] parameter(0) + ROOT wrapped_b = bf16[] fusion(s16[] a), kind=kLoop, calls=fused_computation +} diff --git a/xla/service/gpu/tests/slice_to_dynamic.hlo b/xla/service/gpu/tests/slice_to_dynamic.hlo index b67edac01be63..242bd749bdaf1 100644 --- a/xla/service/gpu/tests/slice_to_dynamic.hlo +++ b/xla/service/gpu/tests/slice_to_dynamic.hlo @@ -1,4 +1,4 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py @@ -6,9 +6,11 @@ // CHECK: %[[VAL_0:.*]] = load i32, ptr %[[VAL_1:.*]], align 4 // CHECK: %[[VAL_2:.*]] = load i32, ptr %[[VAL_3:.*]], align 4 // CHECK: %[[VAL_4:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_5:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_5:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_5:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_6:.*]] = icmp eq i32 0, %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = call i32 [[CTAIDX]] +// CHECK-PTX: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_7:.*]] = call i32 @llvm.amdgcn.workgroup.id.x // CHECK: %[[VAL_8:.*]] = icmp eq i32 0, %[[VAL_7]] // CHECK: %[[VAL_9:.*]] = and i1 %[[VAL_6]], %[[VAL_8]] // CHECK: br i1 %[[VAL_9]], label %[[VAL_10:.*]], label %[[VAL_11:.*]] @@ -16,53 +18,57 @@ // CHECK: %[[VAL_13:.*]] = mul i32 1, %[[VAL_0]] // CHECK: %[[VAL_14:.*]] = mul i32 %[[VAL_13]], %[[VAL_2]] // CHECK: %[[VAL_15:.*]] = mul i32 %[[VAL_14]], %[[VAL_4]] -// CHECK: %[[VAL_16:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_17:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_16:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_16:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK-PTX: %[[VAL_17:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_17:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_18:.*]] = mul nuw nsw i32 %[[VAL_16]], 8 // CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_18]], %[[VAL_17]] // CHECK: %[[VAL_20:.*]] = icmp ult i32 %[[VAL_19]], 8 // CHECK: call void @llvm.assume(i1 %[[VAL_20]]) -// CHECK: %[[VAL_21:.*]] = udiv i32 %[[VAL_19]], 1 -// CHECK: %[[VAL_22:.*]] = urem i32 %[[VAL_21]], 2 -// CHECK: %[[VAL_23:.*]] = udiv i32 %[[VAL_19]], 2 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 2 -// CHECK: %[[VAL_25:.*]] = udiv i32 %[[VAL_19]], 4 -// CHECK: %[[VAL_26:.*]] = icmp ult i32 %[[VAL_19]], 8 -// CHECK: br i1 %[[VAL_26]], label %[[VAL_27:.*]], label %[[VAL_28:.*]] -// CHECK: custom_call.in_bounds-after: ; preds = %[[VAL_29:.*]], %[[VAL_11]] +// CHECK: %[[VAL_21:.*]] = add nuw nsw i32 %[[VAL_19]], 0 +// CHECK: %[[VAL_22:.*]] = udiv i32 %[[VAL_21]], 1 +// CHECK: %[[VAL_23:.*]] = urem i32 %[[VAL_22]], 2 +// CHECK: %[[VAL_24:.*]] = udiv i32 %[[VAL_21]], 2 +// CHECK: %[[VAL_25:.*]] = urem i32 %[[VAL_24]], 2 +// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_21]], 4 +// CHECK: %[[VAL_27:.*]] = icmp ult i32 %[[VAL_19]], 8 +// CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]] +// CHECK: custom-call.in_bounds-after: ; preds = %[[VAL_30:.*]], %[[VAL_11]] // CHECK: ret void // CHECK: is_thread_0-true: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds i8, ptr %[[VAL_31:.*]], i32 32 -// CHECK: store i32 %[[VAL_0]], ptr %[[VAL_30]], align 4 -// CHECK: %[[VAL_32:.*]] = getelementptr inbounds i8, ptr %[[VAL_31]], i32 36 -// CHECK: store i32 %[[VAL_2]], ptr %[[VAL_32]], align 4 -// CHECK: %[[VAL_33:.*]] = getelementptr inbounds i8, ptr %[[VAL_31]], i32 40 -// CHECK: store i32 %[[VAL_4]], ptr %[[VAL_33]], align 4 +// CHECK: %[[VAL_31:.*]] = getelementptr inbounds i8, ptr %[[VAL_32:.*]], i32 32 +// CHECK: store i32 %[[VAL_0]], ptr %[[VAL_31]], align 4 +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds i8, ptr %[[VAL_32]], i32 36 +// CHECK: store i32 %[[VAL_2]], ptr %[[VAL_33]], align 4 +// CHECK: %[[VAL_34:.*]] = getelementptr inbounds i8, ptr %[[VAL_32]], i32 40 +// CHECK: store i32 %[[VAL_4]], ptr %[[VAL_34]], align 4 // CHECK: br label %[[VAL_11]] -// CHECK: custom_call.in_bounds-true: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_34:.*]] = mul nuw nsw i32 %[[VAL_22]], 1 -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 0, %[[VAL_34]] -// CHECK: %[[VAL_36:.*]] = mul nuw nsw i32 %[[VAL_24]], 2 -// CHECK: %[[VAL_37:.*]] = add nuw nsw i32 %[[VAL_35]], %[[VAL_36]] -// CHECK: %[[VAL_38:.*]] = mul nuw nsw i32 %[[VAL_25]], 4 -// CHECK: %[[VAL_39:.*]] = add nuw nsw i32 %[[VAL_37]], %[[VAL_38]] -// CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_39]], %[[VAL_15]] -// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_29]] -// CHECK: custom_call.in_dyn_bounds-after: ; preds = %[[VAL_41]], %[[VAL_27]] -// CHECK: br label %[[VAL_28]] -// CHECK: custom_call.in_dyn_bounds-true: ; preds = %[[VAL_27]] -// CHECK: %[[VAL_42:.*]] = udiv i32 %[[VAL_39]], 1 -// CHECK: %[[VAL_43:.*]] = urem i32 %[[VAL_42]], %[[VAL_4]] -// CHECK: %[[VAL_44:.*]] = mul i32 1, %[[VAL_4]] -// CHECK: %[[VAL_45:.*]] = udiv i32 %[[VAL_39]], %[[VAL_44]] -// CHECK: %[[VAL_46:.*]] = urem i32 %[[VAL_45]], %[[VAL_0]] -// CHECK: %[[VAL_47:.*]] = mul i32 %[[VAL_44]], %[[VAL_0]] -// CHECK: %[[VAL_48:.*]] = udiv i32 %[[VAL_39]], %[[VAL_47]] -// CHECK: %[[VAL_49:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr %[[VAL_50:.*]], i32 0, i32 %[[VAL_48]], i32 %[[VAL_46]], i32 %[[VAL_43]] -// CHECK: %[[VAL_51:.*]] = load i32, ptr %[[VAL_49]], align 4, !invariant.load -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds i32, ptr %[[VAL_31]], i32 %[[VAL_19]] -// CHECK: store i32 %[[VAL_51]], ptr %[[VAL_52]], align 4 +// CHECK: custom-call.in_bounds-true: ; preds = %[[VAL_11]] +// CHECK: %[[VAL_35:.*]] = mul nuw nsw i32 %[[VAL_23]], 1 +// CHECK: %[[VAL_36:.*]] = add nuw nsw i32 0, %[[VAL_35]] +// CHECK: %[[VAL_37:.*]] = mul nuw nsw i32 %[[VAL_25]], 2 +// CHECK: %[[VAL_38:.*]] = add nuw nsw i32 %[[VAL_36]], %[[VAL_37]] +// CHECK: %[[VAL_39:.*]] = mul nuw nsw i32 %[[VAL_26]], 4 +// CHECK: %[[VAL_40:.*]] = add nuw nsw i32 %[[VAL_38]], %[[VAL_39]] +// CHECK: %[[VAL_41:.*]] = icmp ult i32 %[[VAL_40]], %[[VAL_15]] +// CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_30]] +// CHECK: custom-call.in_dyn_bounds-after: ; preds = %[[VAL_42]], %[[VAL_28]] // CHECK: br label %[[VAL_29]] +// CHECK: custom-call.in_dyn_bounds-true: ; preds = %[[VAL_28]] +// CHECK: %[[VAL_43:.*]] = udiv i32 %[[VAL_40]], 1 +// CHECK: %[[VAL_44:.*]] = urem i32 %[[VAL_43]], %[[VAL_4]] +// CHECK: %[[VAL_45:.*]] = mul i32 1, %[[VAL_4]] +// CHECK: %[[VAL_46:.*]] = udiv i32 %[[VAL_40]], %[[VAL_45]] +// CHECK: %[[VAL_47:.*]] = urem i32 %[[VAL_46]], %[[VAL_0]] +// CHECK: %[[VAL_48:.*]] = mul i32 %[[VAL_45]], %[[VAL_0]] +// CHECK: %[[VAL_49:.*]] = udiv i32 %[[VAL_40]], %[[VAL_48]] +// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr %[[VAL_51:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_47]], i32 %[[VAL_44]] +// CHECK: %[[VAL_52:.*]] = load i32, ptr %[[VAL_50]], align 4, !invariant.load +// CHECK: %[[VAL_53:.*]] = getelementptr i32, ptr %[[VAL_32]], i32 %[[VAL_19]] +// CHECK: %[[VAL_54:.*]] = getelementptr inbounds i32, ptr %[[VAL_53]], i32 0 +// CHECK: store i32 %[[VAL_52]], ptr %[[VAL_54]], align 4 +// CHECK: br label %[[VAL_30]] HloModule SliceToDynamic, is_scheduled=true diff --git a/xla/service/gpu/tests/sorting.hlo b/xla/service/gpu/tests/sorting.hlo index c6b405ce17a4c..dd1152b7e176f 100644 --- a/xla/service/gpu/tests/sorting.hlo +++ b/xla/service/gpu/tests/sorting.hlo @@ -1,5 +1,5 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo_to_llvm_ir %s | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s HloModule TestModule, is_scheduled=true @@ -16,269 +16,285 @@ compare { // CHECK: %[[VAL_3:.*]] = alloca i8, align 1 // CHECK: %[[VAL_4:.*]] = alloca i8, align 1 // CHECK: %[[VAL_5:.*]] = alloca i8, align 1 -// CHECK: %[[VAL_6:.*]] = call i32 [[CTAIDX]] +// CHECK-PTX: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_6:.*]] = call i32 @llvm.amdgcn.workgroup.id.x // CHECK: %[[VAL_7:.*]] = zext i32 %[[VAL_6]] to i64 -// CHECK: %[[VAL_8:.*]] = call i32 [[TIDX]] +// CHECK-PTX: %[[VAL_8:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_8:.*]] = call i32 @llvm.amdgcn.workitem.id.x // CHECK: %[[VAL_9:.*]] = zext i32 %[[VAL_8]] to i64 // CHECK: %[[VAL_10:.*]] = mul nuw nsw i64 %[[VAL_7]], 2 // CHECK: %[[VAL_11:.*]] = add nuw nsw i64 %[[VAL_10]], %[[VAL_9]] // CHECK: %[[VAL_12:.*]] = icmp ult i64 %[[VAL_11]], 4 // CHECK: call void @llvm.assume(i1 %[[VAL_12]]) -// CHECK: %[[VAL_13:.*]] = udiv i64 %[[VAL_11]], 1 -// CHECK: %[[VAL_14:.*]] = urem i64 %[[VAL_13]], 2 -// CHECK: %[[VAL_15:.*]] = udiv i64 %[[VAL_11]], 2 -// CHECK: %[[VAL_16:.*]] = icmp ult i64 %[[VAL_11]], 4 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: sort.in_bounds-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] +// CHECK: %[[VAL_13:.*]] = add nuw nsw i64 %[[VAL_11]], 0 +// CHECK: %[[VAL_14:.*]] = udiv i64 %[[VAL_13]], 1 +// CHECK: %[[VAL_15:.*]] = urem i64 %[[VAL_14]], 2 +// CHECK: %[[VAL_16:.*]] = udiv i64 %[[VAL_13]], 2 +// CHECK: %[[VAL_17:.*]] = icmp ult i64 %[[VAL_11]], 4 +// CHECK: br i1 %[[VAL_17]], label %[[VAL_18:.*]], label %[[VAL_19:.*]] +// CHECK: sort.in_bounds-after: ; preds = %[[VAL_20:.*]], %[[VAL_21:.*]] // CHECK: ret void -// CHECK: sort.in_bounds-true: ; preds = %[[VAL_20]] -// CHECK: %[[VAL_21:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_22:.*]] = sext i32 %[[VAL_21]] to i64 -// CHECK: %[[VAL_23:.*]] = shl i64 %[[VAL_14]], 1 -// CHECK: %[[VAL_24:.*]] = icmp slt i64 %[[VAL_23]], 3 -// CHECK: br i1 %[[VAL_24]], label %[[VAL_25:.*]], label %[[VAL_26:.*]] -// CHECK: smaller_keys_index-after: ; preds = %[[VAL_27:.*]], %[[VAL_17]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_28:.*]] = mul i64 %[[VAL_14]], 2 -// CHECK: %[[VAL_29:.*]] = icmp uge i64 %[[VAL_28]], 0 -// CHECK: br i1 %[[VAL_29]], label %[[VAL_30:.*]], label %[[VAL_31:.*]] -// CHECK: is_last_tile-after: ; preds = %[[VAL_32:.*]], %[[VAL_33:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_34:.*]] = mul i64 %[[VAL_14]], 2 -// CHECK: %[[VAL_35:.*]] = icmp uge i64 %[[VAL_34]], 0 -// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]] -// CHECK: is_last_tile-after9: ; preds = %[[VAL_38:.*]], %[[VAL_39:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_40:.*]] = mul i64 %[[VAL_14]], 2 -// CHECK: %[[VAL_41:.*]] = icmp uge i64 %[[VAL_40]], 0 -// CHECK: br i1 %[[VAL_41]], label %[[VAL_42:.*]], label %[[VAL_43:.*]] -// CHECK: is_last_tile-after24: ; preds = %[[VAL_44:.*]], %[[VAL_45:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_46:.*]] = shl i64 %[[VAL_14]], 1 -// CHECK: %[[VAL_47:.*]] = icmp slt i64 %[[VAL_46]], 3 -// CHECK: br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_19]] -// CHECK: smaller_keys_index-after38: ; preds = %[[VAL_49:.*]], %[[VAL_50:.*]] -// CHECK: br label %[[VAL_18]] -// CHECK: smaller_keys_index-true: ; preds = %[[VAL_17]] -// CHECK: %[[VAL_51:.*]] = shl i64 %[[VAL_22]], 1 -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_53:.*]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_23]] -// CHECK: %[[VAL_54:.*]] = load float, ptr %[[VAL_52]], align 4 -// CHECK: %[[VAL_55:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_51]] -// CHECK: store float %[[VAL_54]], ptr addrspace(3) %[[VAL_55]], align 4 -// CHECK: %[[VAL_56:.*]] = add i64 %[[VAL_23]], 1 -// CHECK: %[[VAL_57:.*]] = icmp slt i64 %[[VAL_56]], 3 -// CHECK: br i1 %[[VAL_57]], label %[[VAL_58:.*]], label %[[VAL_27]] -// CHECK: inner_smaller_keys_index-after: ; preds = %[[VAL_58]], %[[VAL_25]] -// CHECK: br label %[[VAL_26]] -// CHECK: inner_smaller_keys_index-true: ; preds = %[[VAL_25]] -// CHECK: %[[VAL_59:.*]] = add i64 %[[VAL_51]], 1 -// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_53]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_56]] -// CHECK: %[[VAL_61:.*]] = load float, ptr %[[VAL_60]], align 4 -// CHECK: %[[VAL_62:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_59]] -// CHECK: store float %[[VAL_61]], ptr addrspace(3) %[[VAL_62]], align 4 +// CHECK: sort.in_bounds-true: ; preds = %[[VAL_21]] +// CHECK-PTX: %[[VAL_22:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_22:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_23:.*]] = sext i32 %[[VAL_22]] to i64 +// CHECK: %[[VAL_24:.*]] = shl i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_25:.*]] = icmp slt i64 %[[VAL_24]], 3 +// CHECK: br i1 %[[VAL_25]], label %[[VAL_26:.*]], label %[[VAL_27:.*]] +// CHECK: smaller_keys_index-after: ; preds = %[[VAL_28:.*]], %[[VAL_18]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_29:.*]] = mul i64 %[[VAL_15]], 2 +// CHECK: %[[VAL_30:.*]] = icmp uge i64 %[[VAL_29]], 0 +// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]] +// CHECK: is_last_tile-after: ; preds = %[[VAL_33:.*]], %[[VAL_34:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_35:.*]] = mul i64 %[[VAL_15]], 2 +// CHECK: %[[VAL_36:.*]] = icmp uge i64 %[[VAL_35]], 0 +// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]] +// CHECK: is_last_tile-after9: ; preds = %[[VAL_39:.*]], %[[VAL_40:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_41:.*]] = mul i64 %[[VAL_15]], 2 +// CHECK: %[[VAL_42:.*]] = icmp uge i64 %[[VAL_41]], 0 +// CHECK: br i1 %[[VAL_42]], label %[[VAL_43:.*]], label %[[VAL_44:.*]] +// CHECK: is_last_tile-after24: ; preds = %[[VAL_45:.*]], %[[VAL_46:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_48:.*]] = icmp slt i64 %[[VAL_47]], 3 +// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_20]] +// CHECK: smaller_keys_index-after38: ; preds = %[[VAL_50:.*]], %[[VAL_51:.*]] +// CHECK: br label %[[VAL_19]] +// CHECK: smaller_keys_index-true: ; preds = %[[VAL_18]] +// CHECK: %[[VAL_52:.*]] = shl i64 %[[VAL_23]], 1 +// CHECK: %[[VAL_53:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_54:.*]], i64 0, i64 %[[VAL_16]], i64 %[[VAL_24]] +// CHECK: %[[VAL_55:.*]] = load float, ptr %[[VAL_53]], align 4 +// CHECK: %[[VAL_56:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_52]] +// CHECK: store float %[[VAL_55]], ptr addrspace(3) %[[VAL_56]], align 4 +// CHECK: %[[VAL_57:.*]] = add i64 %[[VAL_24]], 1 +// CHECK: %[[VAL_58:.*]] = icmp slt i64 %[[VAL_57]], 3 +// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_28]] +// CHECK: inner_smaller_keys_index-after: ; preds = %[[VAL_59]], %[[VAL_26]] // CHECK: br label %[[VAL_27]] -// CHECK: is_last_tile-true: ; preds = %[[VAL_26]] -// CHECK: %[[VAL_63:.*]] = mul i64 %[[VAL_22]], 2 -// CHECK: %[[VAL_64:.*]] = xor i64 %[[VAL_63]], 1 -// CHECK: %[[VAL_65:.*]] = icmp slt i64 %[[VAL_63]], %[[VAL_64]] -// CHECK: %[[VAL_66:.*]] = icmp slt i64 %[[VAL_64]], 3 -// CHECK: %[[VAL_67:.*]] = and i1 %[[VAL_65]], %[[VAL_66]] -// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_33]] -// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_69:.*]], %[[VAL_30]] -// CHECK: br label %[[VAL_70:.*]] -// CHECK: is_last_tile-false: ; preds = %[[VAL_26]] -// CHECK: %[[VAL_71:.*]] = mul i64 %[[VAL_22]], 2 -// CHECK: %[[VAL_72:.*]] = xor i64 %[[VAL_71]], 1 -// CHECK: %[[VAL_73:.*]] = icmp slt i64 %[[VAL_71]], %[[VAL_72]] -// CHECK: %[[VAL_74:.*]] = icmp slt i64 %[[VAL_72]], 4 -// CHECK: br i1 true, label %[[VAL_75:.*]], label %[[VAL_32]] -// CHECK: smaller_comparison_index-after2: ; preds = %[[VAL_76:.*]], %[[VAL_31]] -// CHECK: br label %[[VAL_70]] -// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_30]] -// CHECK: %[[VAL_77:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_64]] -// CHECK: %[[VAL_78:.*]] = addrspacecast ptr addrspace(3) %[[VAL_77]] to ptr -// CHECK: %[[VAL_79:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_63]] -// CHECK: %[[VAL_80:.*]] = addrspacecast ptr addrspace(3) %[[VAL_79]] to ptr +// CHECK: inner_smaller_keys_index-true: ; preds = %[[VAL_26]] +// CHECK: %[[VAL_60:.*]] = add i64 %[[VAL_52]], 1 +// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_54]], i64 0, i64 %[[VAL_16]], i64 %[[VAL_57]] +// CHECK: %[[VAL_62:.*]] = load float, ptr %[[VAL_61]], align 4 +// CHECK: %[[VAL_63:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_60]] +// CHECK: store float %[[VAL_62]], ptr addrspace(3) %[[VAL_63]], align 4 +// CHECK: br label %[[VAL_28]] +// CHECK: is_last_tile-true: ; preds = %[[VAL_27]] +// CHECK: %[[VAL_64:.*]] = mul i64 %[[VAL_23]], 2 +// CHECK: %[[VAL_65:.*]] = xor i64 %[[VAL_64]], 1 +// CHECK: %[[VAL_66:.*]] = icmp slt i64 %[[VAL_64]], %[[VAL_65]] +// CHECK: %[[VAL_67:.*]] = icmp slt i64 %[[VAL_65]], 3 +// CHECK: %[[VAL_68:.*]] = and i1 %[[VAL_66]], %[[VAL_67]] +// CHECK: br i1 %[[VAL_68]], label %[[VAL_69:.*]], label %[[VAL_34]] +// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_70:.*]], %[[VAL_31]] +// CHECK: br label %[[VAL_71:.*]] +// CHECK: is_last_tile-false: ; preds = %[[VAL_27]] +// CHECK: %[[VAL_72:.*]] = mul i64 %[[VAL_23]], 2 +// CHECK: %[[VAL_73:.*]] = xor i64 %[[VAL_72]], 1 +// CHECK: %[[VAL_74:.*]] = icmp slt i64 %[[VAL_72]], %[[VAL_73]] +// CHECK: %[[VAL_75:.*]] = icmp slt i64 %[[VAL_73]], 4 +// CHECK: br i1 true, label %[[VAL_76:.*]], label %[[VAL_33]] +// CHECK: smaller_comparison_index-after2: ; preds = %[[VAL_77:.*]], %[[VAL_32]] +// CHECK: br label %[[VAL_71]] +// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_31]] +// CHECK: %[[VAL_78:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_65]] +// CHECK: %[[VAL_79:.*]] = addrspacecast ptr addrspace(3) %[[VAL_78]] to ptr +// CHECK: %[[VAL_80:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_64]] +// CHECK: %[[VAL_81:.*]] = addrspacecast ptr addrspace(3) %[[VAL_80]] to ptr // CHECK-GCN: %[[VAL_5_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_5]] to ptr -// CHECK-GCN: call void @[[REGION:compare_.*]](ptr %[[VAL_78]], ptr %[[VAL_80]], ptr %[[VAL_5_2]]) -// CHECK-PTX: call void @[[REGION:compare_.*]](ptr %[[VAL_78]], ptr %[[VAL_80]], ptr %[[VAL_5]]) -// CHECK: %[[VAL_81:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_5]], align 1 -// CHECK: %[[VAL_82:.*]] = icmp ne i8 %[[VAL_81]], 0 -// CHECK: br i1 %[[VAL_82]], label %[[VAL_83:.*]], label %[[VAL_69]] -// CHECK: is_smaller_than-after: ; preds = %[[VAL_83]], %[[VAL_68]] -// CHECK: br label %[[VAL_33]] -// CHECK: is_smaller_than-true: ; preds = %[[VAL_68]] -// CHECK: %[[VAL_84:.*]] = load float, ptr %[[VAL_78]], align 4 -// CHECK: %[[VAL_85:.*]] = load float, ptr %[[VAL_80]], align 4 -// CHECK: %[[VAL_86:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_63]] -// CHECK: store float %[[VAL_84]], ptr addrspace(3) %[[VAL_86]], align 4 +// CHECK-GCN: call void @[[REGION:compare_.*]](ptr %[[VAL_79]], ptr %[[VAL_81]], ptr %[[VAL_5_2]]) +// CHECK-PTX: call void @[[REGION:compare_.*]](ptr %[[VAL_79]], ptr %[[VAL_81]], ptr %[[VAL_5]]) +// CHECK-PTX: %[[VAL_82:.*]] = load i8, ptr %[[VAL_5]], align 1 +// CHECK-GCN: %[[VAL_82:.*]] = load i8, ptr addrspace(5) %[[VAL_5]], align 1 +// CHECK: %[[VAL_83:.*]] = icmp ne i8 %[[VAL_82]], 0 +// CHECK: br i1 %[[VAL_83]], label %[[VAL_84:.*]], label %[[VAL_70]] +// CHECK: is_smaller_than-after: ; preds = %[[VAL_84]], %[[VAL_69]] +// CHECK: br label %[[VAL_34]] +// CHECK: is_smaller_than-true: ; preds = %[[VAL_69]] +// CHECK: %[[VAL_85:.*]] = load float, ptr %[[VAL_79]], align 4 +// CHECK: %[[VAL_86:.*]] = load float, ptr %[[VAL_81]], align 4 // CHECK: %[[VAL_87:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_64]] // CHECK: store float %[[VAL_85]], ptr addrspace(3) %[[VAL_87]], align 4 -// CHECK: br label %[[VAL_69]] -// CHECK: smaller_comparison_index-true1: ; preds = %[[VAL_31]] -// CHECK: %[[VAL_88:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_72]] -// CHECK: %[[VAL_89:.*]] = addrspacecast ptr addrspace(3) %[[VAL_88]] to ptr -// CHECK: %[[VAL_90:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_71]] -// CHECK: %[[VAL_91:.*]] = addrspacecast ptr addrspace(3) %[[VAL_90]] to ptr +// CHECK: %[[VAL_88:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_65]] +// CHECK: store float %[[VAL_86]], ptr addrspace(3) %[[VAL_88]], align 4 +// CHECK: br label %[[VAL_70]] +// CHECK: smaller_comparison_index-true1: ; preds = %[[VAL_32]] +// CHECK: %[[VAL_89:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_73]] +// CHECK: %[[VAL_90:.*]] = addrspacecast ptr addrspace(3) %[[VAL_89]] to ptr +// CHECK: %[[VAL_91:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_72]] +// CHECK: %[[VAL_92:.*]] = addrspacecast ptr addrspace(3) %[[VAL_91]] to ptr // CHECK-GCN: %[[VAL_4_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_4]] to ptr -// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_89]], ptr %[[VAL_91]], ptr %[[VAL_4_2]]) -// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_89]], ptr %[[VAL_91]], ptr %[[VAL_4]]) -// CHECK: %[[VAL_92:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_4]], align 1 -// CHECK: %[[VAL_93:.*]] = icmp ne i8 %[[VAL_92]], 0 -// CHECK: br i1 %[[VAL_93]], label %[[VAL_94:.*]], label %[[VAL_76]] -// CHECK: is_smaller_than-after6: ; preds = %[[VAL_94]], %[[VAL_75]] -// CHECK: br label %[[VAL_32]] -// CHECK: is_smaller_than-true5: ; preds = %[[VAL_75]] -// CHECK: %[[VAL_95:.*]] = load float, ptr %[[VAL_89]], align 4 -// CHECK: %[[VAL_96:.*]] = load float, ptr %[[VAL_91]], align 4 -// CHECK: %[[VAL_97:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_71]] -// CHECK: store float %[[VAL_95]], ptr addrspace(3) %[[VAL_97]], align 4 +// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_90]], ptr %[[VAL_92]], ptr %[[VAL_4_2]]) +// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_90]], ptr %[[VAL_92]], ptr %[[VAL_4]]) +// CHECK-PTX: %[[VAL_93:.*]] = load i8, ptr %[[VAL_4]], align 1 +// CHECK-GCN: %[[VAL_93:.*]] = load i8, ptr addrspace(5) %[[VAL_4]], align 1 +// CHECK: %[[VAL_94:.*]] = icmp ne i8 %[[VAL_93]], 0 +// CHECK: br i1 %[[VAL_94]], label %[[VAL_95:.*]], label %[[VAL_77]] +// CHECK: is_smaller_than-after6: ; preds = %[[VAL_95]], %[[VAL_76]] +// CHECK: br label %[[VAL_33]] +// CHECK: is_smaller_than-true5: ; preds = %[[VAL_76]] +// CHECK: %[[VAL_96:.*]] = load float, ptr %[[VAL_90]], align 4 +// CHECK: %[[VAL_97:.*]] = load float, ptr %[[VAL_92]], align 4 // CHECK: %[[VAL_98:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_72]] // CHECK: store float %[[VAL_96]], ptr addrspace(3) %[[VAL_98]], align 4 -// CHECK: br label %[[VAL_76]] -// CHECK: is_last_tile-true7: ; preds = %[[VAL_70]] -// CHECK: %[[VAL_99:.*]] = xor i64 %[[VAL_22]], 3 -// CHECK: %[[VAL_100:.*]] = icmp slt i64 %[[VAL_22]], %[[VAL_99]] -// CHECK: %[[VAL_101:.*]] = icmp slt i64 %[[VAL_99]], 3 -// CHECK: %[[VAL_102:.*]] = and i1 %[[VAL_100]], %[[VAL_101]] -// CHECK: br i1 %[[VAL_102]], label %[[VAL_103:.*]], label %[[VAL_39]] -// CHECK: smaller_comparison_index-after11: ; preds = %[[VAL_104:.*]], %[[VAL_36]] -// CHECK: br label %[[VAL_105:.*]] -// CHECK: is_last_tile-false8: ; preds = %[[VAL_70]] -// CHECK: %[[VAL_106:.*]] = xor i64 %[[VAL_22]], 3 -// CHECK: %[[VAL_107:.*]] = icmp slt i64 %[[VAL_22]], %[[VAL_106]] -// CHECK: %[[VAL_108:.*]] = icmp slt i64 %[[VAL_106]], 4 -// CHECK: br i1 true, label %[[VAL_109:.*]], label %[[VAL_38]] -// CHECK: smaller_comparison_index-after17: ; preds = %[[VAL_110:.*]], %[[VAL_37]] -// CHECK: br label %[[VAL_105]] -// CHECK: smaller_comparison_index-true10: ; preds = %[[VAL_36]] -// CHECK: %[[VAL_111:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_99]] -// CHECK: %[[VAL_112:.*]] = addrspacecast ptr addrspace(3) %[[VAL_111]] to ptr -// CHECK: %[[VAL_113:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_22]] -// CHECK: %[[VAL_114:.*]] = addrspacecast ptr addrspace(3) %[[VAL_113]] to ptr +// CHECK: %[[VAL_99:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_73]] +// CHECK: store float %[[VAL_97]], ptr addrspace(3) %[[VAL_99]], align 4 +// CHECK: br label %[[VAL_77]] +// CHECK: is_last_tile-true7: ; preds = %[[VAL_71]] +// CHECK: %[[VAL_100:.*]] = xor i64 %[[VAL_23]], 3 +// CHECK: %[[VAL_101:.*]] = icmp slt i64 %[[VAL_23]], %[[VAL_100]] +// CHECK: %[[VAL_102:.*]] = icmp slt i64 %[[VAL_100]], 3 +// CHECK: %[[VAL_103:.*]] = and i1 %[[VAL_101]], %[[VAL_102]] +// CHECK: br i1 %[[VAL_103]], label %[[VAL_104:.*]], label %[[VAL_40]] +// CHECK: smaller_comparison_index-after11: ; preds = %[[VAL_105:.*]], %[[VAL_37]] +// CHECK: br label %[[VAL_106:.*]] +// CHECK: is_last_tile-false8: ; preds = %[[VAL_71]] +// CHECK: %[[VAL_107:.*]] = xor i64 %[[VAL_23]], 3 +// CHECK: %[[VAL_108:.*]] = icmp slt i64 %[[VAL_23]], %[[VAL_107]] +// CHECK: %[[VAL_109:.*]] = icmp slt i64 %[[VAL_107]], 4 +// CHECK: br i1 true, label %[[VAL_110:.*]], label %[[VAL_39]] +// CHECK: smaller_comparison_index-after17: ; preds = %[[VAL_111:.*]], %[[VAL_38]] +// CHECK: br label %[[VAL_106]] +// CHECK: smaller_comparison_index-true10: ; preds = %[[VAL_37]] +// CHECK: %[[VAL_112:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_100]] +// CHECK: %[[VAL_113:.*]] = addrspacecast ptr addrspace(3) %[[VAL_112]] to ptr +// CHECK: %[[VAL_114:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_23]] +// CHECK: %[[VAL_115:.*]] = addrspacecast ptr addrspace(3) %[[VAL_114]] to ptr // CHECK-GCN: %[[VAL_3_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_3]] to ptr -// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_112]], ptr %[[VAL_114]], ptr %[[VAL_3_2]]) -// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_112]], ptr %[[VAL_114]], ptr %[[VAL_3]]) -// CHECK: %[[VAL_115:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_3]], align 1 -// CHECK: %[[VAL_116:.*]] = icmp ne i8 %[[VAL_115]], 0 -// CHECK: br i1 %[[VAL_116]], label %[[VAL_117:.*]], label %[[VAL_104]] -// CHECK: is_smaller_than-after15: ; preds = %[[VAL_117]], %[[VAL_103]] -// CHECK: br label %[[VAL_39]] -// CHECK: is_smaller_than-true14: ; preds = %[[VAL_103]] -// CHECK: %[[VAL_118:.*]] = load float, ptr %[[VAL_112]], align 4 -// CHECK: %[[VAL_119:.*]] = load float, ptr %[[VAL_114]], align 4 -// CHECK: %[[VAL_120:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_22]] -// CHECK: store float %[[VAL_118]], ptr addrspace(3) %[[VAL_120]], align 4 -// CHECK: %[[VAL_121:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_99]] +// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_113]], ptr %[[VAL_115]], ptr %[[VAL_3_2]]) +// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_113]], ptr %[[VAL_115]], ptr %[[VAL_3]]) +// CHECK-PTX: %[[VAL_116:.*]] = load i8, ptr %[[VAL_3]], align 1 +// CHECK-GCN: %[[VAL_116:.*]] = load i8, ptr addrspace(5) %[[VAL_3]], align 1 +// CHECK: %[[VAL_117:.*]] = icmp ne i8 %[[VAL_116]], 0 +// CHECK: br i1 %[[VAL_117]], label %[[VAL_118:.*]], label %[[VAL_105]] +// CHECK: is_smaller_than-after15: ; preds = %[[VAL_118]], %[[VAL_104]] +// CHECK: br label %[[VAL_40]] +// CHECK: is_smaller_than-true14: ; preds = %[[VAL_104]] +// CHECK: %[[VAL_119:.*]] = load float, ptr %[[VAL_113]], align 4 +// CHECK: %[[VAL_120:.*]] = load float, ptr %[[VAL_115]], align 4 +// CHECK: %[[VAL_121:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_23]] // CHECK: store float %[[VAL_119]], ptr addrspace(3) %[[VAL_121]], align 4 -// CHECK: br label %[[VAL_104]] -// CHECK: smaller_comparison_index-true16: ; preds = %[[VAL_37]] -// CHECK: %[[VAL_122:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_106]] -// CHECK: %[[VAL_123:.*]] = addrspacecast ptr addrspace(3) %[[VAL_122]] to ptr -// CHECK: %[[VAL_124:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_22]] -// CHECK: %[[VAL_125:.*]] = addrspacecast ptr addrspace(3) %[[VAL_124]] to ptr +// CHECK: %[[VAL_122:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_100]] +// CHECK: store float %[[VAL_120]], ptr addrspace(3) %[[VAL_122]], align 4 +// CHECK: br label %[[VAL_105]] +// CHECK: smaller_comparison_index-true16: ; preds = %[[VAL_38]] +// CHECK: %[[VAL_123:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_107]] +// CHECK: %[[VAL_124:.*]] = addrspacecast ptr addrspace(3) %[[VAL_123]] to ptr +// CHECK: %[[VAL_125:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_23]] +// CHECK: %[[VAL_126:.*]] = addrspacecast ptr addrspace(3) %[[VAL_125]] to ptr // CHECK-GCN: %[[VAL_2_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_2]] to ptr -// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_123]], ptr %[[VAL_125]], ptr %[[VAL_2_2]]) -// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_123]], ptr %[[VAL_125]], ptr %[[VAL_2]]) -// CHECK: %[[VAL_126:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_2]], align 1 -// CHECK: %[[VAL_127:.*]] = icmp ne i8 %[[VAL_126]], 0 -// CHECK: br i1 %[[VAL_127]], label %[[VAL_128:.*]], label %[[VAL_110]] -// CHECK: is_smaller_than-after21: ; preds = %[[VAL_128]], %[[VAL_109]] -// CHECK: br label %[[VAL_38]] -// CHECK: is_smaller_than-true20: ; preds = %[[VAL_109]] -// CHECK: %[[VAL_129:.*]] = load float, ptr %[[VAL_123]], align 4 -// CHECK: %[[VAL_130:.*]] = load float, ptr %[[VAL_125]], align 4 -// CHECK: %[[VAL_131:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_22]] -// CHECK: store float %[[VAL_129]], ptr addrspace(3) %[[VAL_131]], align 4 -// CHECK: %[[VAL_132:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_106]] +// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_124]], ptr %[[VAL_126]], ptr %[[VAL_2_2]]) +// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_124]], ptr %[[VAL_126]], ptr %[[VAL_2]]) +// CHECK-PTX: %[[VAL_127:.*]] = load i8, ptr %[[VAL_2]], align 1 +// CHECK-GCN: %[[VAL_127:.*]] = load i8, ptr addrspace(5) %[[VAL_2]], align 1 +// CHECK: %[[VAL_128:.*]] = icmp ne i8 %[[VAL_127]], 0 +// CHECK: br i1 %[[VAL_128]], label %[[VAL_129:.*]], label %[[VAL_111]] +// CHECK: is_smaller_than-after21: ; preds = %[[VAL_129]], %[[VAL_110]] +// CHECK: br label %[[VAL_39]] +// CHECK: is_smaller_than-true20: ; preds = %[[VAL_110]] +// CHECK: %[[VAL_130:.*]] = load float, ptr %[[VAL_124]], align 4 +// CHECK: %[[VAL_131:.*]] = load float, ptr %[[VAL_126]], align 4 +// CHECK: %[[VAL_132:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_23]] // CHECK: store float %[[VAL_130]], ptr addrspace(3) %[[VAL_132]], align 4 -// CHECK: br label %[[VAL_110]] -// CHECK: is_last_tile-true22: ; preds = %[[VAL_105]] -// CHECK: %[[VAL_133:.*]] = mul i64 %[[VAL_22]], 2 -// CHECK: %[[VAL_134:.*]] = xor i64 %[[VAL_133]], 1 -// CHECK: %[[VAL_135:.*]] = icmp slt i64 %[[VAL_133]], %[[VAL_134]] -// CHECK: %[[VAL_136:.*]] = icmp slt i64 %[[VAL_134]], 3 -// CHECK: %[[VAL_137:.*]] = and i1 %[[VAL_135]], %[[VAL_136]] -// CHECK: br i1 %[[VAL_137]], label %[[VAL_138:.*]], label %[[VAL_45]] -// CHECK: smaller_comparison_index-after26: ; preds = %[[VAL_139:.*]], %[[VAL_42]] -// CHECK: br label %[[VAL_50]] -// CHECK: is_last_tile-false23: ; preds = %[[VAL_105]] -// CHECK: %[[VAL_140:.*]] = mul i64 %[[VAL_22]], 2 -// CHECK: %[[VAL_141:.*]] = xor i64 %[[VAL_140]], 1 -// CHECK: %[[VAL_142:.*]] = icmp slt i64 %[[VAL_140]], %[[VAL_141]] -// CHECK: %[[VAL_143:.*]] = icmp slt i64 %[[VAL_141]], 4 -// CHECK: br i1 true, label %[[VAL_144:.*]], label %[[VAL_44]] -// CHECK: smaller_comparison_index-after32: ; preds = %[[VAL_145:.*]], %[[VAL_43]] -// CHECK: br label %[[VAL_50]] -// CHECK: smaller_comparison_index-true25: ; preds = %[[VAL_42]] -// CHECK: %[[VAL_146:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_134]] -// CHECK: %[[VAL_147:.*]] = addrspacecast ptr addrspace(3) %[[VAL_146]] to ptr -// CHECK: %[[VAL_148:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_133]] -// CHECK: %[[VAL_149:.*]] = addrspacecast ptr addrspace(3) %[[VAL_148]] to ptr +// CHECK: %[[VAL_133:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_107]] +// CHECK: store float %[[VAL_131]], ptr addrspace(3) %[[VAL_133]], align 4 +// CHECK: br label %[[VAL_111]] +// CHECK: is_last_tile-true22: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_134:.*]] = mul i64 %[[VAL_23]], 2 +// CHECK: %[[VAL_135:.*]] = xor i64 %[[VAL_134]], 1 +// CHECK: %[[VAL_136:.*]] = icmp slt i64 %[[VAL_134]], %[[VAL_135]] +// CHECK: %[[VAL_137:.*]] = icmp slt i64 %[[VAL_135]], 3 +// CHECK: %[[VAL_138:.*]] = and i1 %[[VAL_136]], %[[VAL_137]] +// CHECK: br i1 %[[VAL_138]], label %[[VAL_139:.*]], label %[[VAL_46]] +// CHECK: smaller_comparison_index-after26: ; preds = %[[VAL_140:.*]], %[[VAL_43]] +// CHECK: br label %[[VAL_51]] +// CHECK: is_last_tile-false23: ; preds = %[[VAL_106]] +// CHECK: %[[VAL_141:.*]] = mul i64 %[[VAL_23]], 2 +// CHECK: %[[VAL_142:.*]] = xor i64 %[[VAL_141]], 1 +// CHECK: %[[VAL_143:.*]] = icmp slt i64 %[[VAL_141]], %[[VAL_142]] +// CHECK: %[[VAL_144:.*]] = icmp slt i64 %[[VAL_142]], 4 +// CHECK: br i1 true, label %[[VAL_145:.*]], label %[[VAL_45]] +// CHECK: smaller_comparison_index-after32: ; preds = %[[VAL_146:.*]], %[[VAL_44]] +// CHECK: br label %[[VAL_51]] +// CHECK: smaller_comparison_index-true25: ; preds = %[[VAL_43]] +// CHECK: %[[VAL_147:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_135]] +// CHECK: %[[VAL_148:.*]] = addrspacecast ptr addrspace(3) %[[VAL_147]] to ptr +// CHECK: %[[VAL_149:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_134]] +// CHECK: %[[VAL_150:.*]] = addrspacecast ptr addrspace(3) %[[VAL_149]] to ptr // CHECK-GCN: %[[VAL_1_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_1]] to ptr -// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_147]], ptr %[[VAL_149]], ptr %[[VAL_1_2]]) -// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_147]], ptr %[[VAL_149]], ptr %[[VAL_1]]) -// CHECK: %[[VAL_150:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_1]], align 1 -// CHECK: %[[VAL_151:.*]] = icmp ne i8 %[[VAL_150]], 0 -// CHECK: br i1 %[[VAL_151]], label %[[VAL_152:.*]], label %[[VAL_139]] -// CHECK: is_smaller_than-after30: ; preds = %[[VAL_152]], %[[VAL_138]] -// CHECK: br label %[[VAL_45]] -// CHECK: is_smaller_than-true29: ; preds = %[[VAL_138]] -// CHECK: %[[VAL_153:.*]] = load float, ptr %[[VAL_147]], align 4 -// CHECK: %[[VAL_154:.*]] = load float, ptr %[[VAL_149]], align 4 -// CHECK: %[[VAL_155:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_133]] -// CHECK: store float %[[VAL_153]], ptr addrspace(3) %[[VAL_155]], align 4 +// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_148]], ptr %[[VAL_150]], ptr %[[VAL_1_2]]) +// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_148]], ptr %[[VAL_150]], ptr %[[VAL_1]]) +// CHECK-PTX: %[[VAL_151:.*]] = load i8, ptr %[[VAL_1]], align 1 +// CHECK-GCN: %[[VAL_151:.*]] = load i8, ptr addrspace(5) %[[VAL_1]], align 1 +// CHECK: %[[VAL_152:.*]] = icmp ne i8 %[[VAL_151]], 0 +// CHECK: br i1 %[[VAL_152]], label %[[VAL_153:.*]], label %[[VAL_140]] +// CHECK: is_smaller_than-after30: ; preds = %[[VAL_153]], %[[VAL_139]] +// CHECK: br label %[[VAL_46]] +// CHECK: is_smaller_than-true29: ; preds = %[[VAL_139]] +// CHECK: %[[VAL_154:.*]] = load float, ptr %[[VAL_148]], align 4 +// CHECK: %[[VAL_155:.*]] = load float, ptr %[[VAL_150]], align 4 // CHECK: %[[VAL_156:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_134]] // CHECK: store float %[[VAL_154]], ptr addrspace(3) %[[VAL_156]], align 4 -// CHECK: br label %[[VAL_139]] -// CHECK: smaller_comparison_index-true31: ; preds = %[[VAL_43]] -// CHECK: %[[VAL_157:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_141]] -// CHECK: %[[VAL_158:.*]] = addrspacecast ptr addrspace(3) %[[VAL_157]] to ptr -// CHECK: %[[VAL_159:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_140]] -// CHECK: %[[VAL_160:.*]] = addrspacecast ptr addrspace(3) %[[VAL_159]] to ptr +// CHECK: %[[VAL_157:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_135]] +// CHECK: store float %[[VAL_155]], ptr addrspace(3) %[[VAL_157]], align 4 +// CHECK: br label %[[VAL_140]] +// CHECK: smaller_comparison_index-true31: ; preds = %[[VAL_44]] +// CHECK: %[[VAL_158:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_142]] +// CHECK: %[[VAL_159:.*]] = addrspacecast ptr addrspace(3) %[[VAL_158]] to ptr +// CHECK: %[[VAL_160:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_141]] +// CHECK: %[[VAL_161:.*]] = addrspacecast ptr addrspace(3) %[[VAL_160]] to ptr // CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_0]] to ptr -// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_158]], ptr %[[VAL_160]], ptr %[[VAL_0_1]]) -// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_158]], ptr %[[VAL_160]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_161:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_0]], align 1 -// CHECK: %[[VAL_162:.*]] = icmp ne i8 %[[VAL_161]], 0 -// CHECK: br i1 %[[VAL_162]], label %[[VAL_163:.*]], label %[[VAL_145]] -// CHECK: is_smaller_than-after36: ; preds = %[[VAL_163]], %[[VAL_144]] -// CHECK: br label %[[VAL_44]] -// CHECK: is_smaller_than-true35: ; preds = %[[VAL_144]] -// CHECK: %[[VAL_164:.*]] = load float, ptr %[[VAL_158]], align 4 -// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_160]], align 4 -// CHECK: %[[VAL_166:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_140]] -// CHECK: store float %[[VAL_164]], ptr addrspace(3) %[[VAL_166]], align 4 +// CHECK-GCN: call void @[[REGION]](ptr %[[VAL_159]], ptr %[[VAL_161]], ptr %[[VAL_0_1]]) +// CHECK-PTX: call void @[[REGION]](ptr %[[VAL_159]], ptr %[[VAL_161]], ptr %[[VAL_0]]) +// CHECK-PTX: %[[VAL_162:.*]] = load i8, ptr %[[VAL_0]], align 1 +// CHECK-GCN: %[[VAL_162:.*]] = load i8, ptr addrspace(5) %[[VAL_0]], align 1 +// CHECK: %[[VAL_163:.*]] = icmp ne i8 %[[VAL_162]], 0 +// CHECK: br i1 %[[VAL_163]], label %[[VAL_164:.*]], label %[[VAL_146]] +// CHECK: is_smaller_than-after36: ; preds = %[[VAL_164]], %[[VAL_145]] +// CHECK: br label %[[VAL_45]] +// CHECK: is_smaller_than-true35: ; preds = %[[VAL_145]] +// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_159]], align 4 +// CHECK: %[[VAL_166:.*]] = load float, ptr %[[VAL_161]], align 4 // CHECK: %[[VAL_167:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_141]] // CHECK: store float %[[VAL_165]], ptr addrspace(3) %[[VAL_167]], align 4 -// CHECK: br label %[[VAL_145]] -// CHECK: smaller_keys_index-true37: ; preds = %[[VAL_50]] -// CHECK: %[[VAL_168:.*]] = shl i64 %[[VAL_22]], 1 -// CHECK: %[[VAL_169:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_168]] -// CHECK: %[[VAL_170:.*]] = load float, ptr addrspace(3) %[[VAL_169]], align 4 -// CHECK: %[[VAL_171:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_53]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_46]] -// CHECK: store float %[[VAL_170]], ptr %[[VAL_171]], align 4 -// CHECK: %[[VAL_172:.*]] = add i64 %[[VAL_46]], 1 -// CHECK: %[[VAL_173:.*]] = icmp slt i64 %[[VAL_172]], 3 -// CHECK: br i1 %[[VAL_173]], label %[[VAL_174:.*]], label %[[VAL_49]] -// CHECK: inner_smaller_keys_index-after40: ; preds = %[[VAL_174]], %[[VAL_48]] -// CHECK: br label %[[VAL_19]] -// CHECK: inner_smaller_keys_index-true39: ; preds = %[[VAL_48]] -// CHECK: %[[VAL_175:.*]] = add i64 %[[VAL_168]], 1 -// CHECK: %[[VAL_176:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_175]] -// CHECK: %[[VAL_177:.*]] = load float, ptr addrspace(3) %[[VAL_176]], align 4 -// CHECK: %[[VAL_178:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_53]], i64 0, i64 %[[VAL_15]], i64 %[[VAL_172]] -// CHECK: store float %[[VAL_177]], ptr %[[VAL_178]], align 4 -// CHECK: br label %[[VAL_49]] +// CHECK: %[[VAL_168:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_142]] +// CHECK: store float %[[VAL_166]], ptr addrspace(3) %[[VAL_168]], align 4 +// CHECK: br label %[[VAL_146]] +// CHECK: smaller_keys_index-true37: ; preds = %[[VAL_51]] +// CHECK: %[[VAL_169:.*]] = shl i64 %[[VAL_23]], 1 +// CHECK: %[[VAL_170:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_169]] +// CHECK: %[[VAL_171:.*]] = load float, ptr addrspace(3) %[[VAL_170]], align 4 +// CHECK: %[[VAL_172:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_54]], i64 0, i64 %[[VAL_16]], i64 %[[VAL_47]] +// CHECK: store float %[[VAL_171]], ptr %[[VAL_172]], align 4 +// CHECK: %[[VAL_173:.*]] = add i64 %[[VAL_47]], 1 +// CHECK: %[[VAL_174:.*]] = icmp slt i64 %[[VAL_173]], 3 +// CHECK: br i1 %[[VAL_174]], label %[[VAL_175:.*]], label %[[VAL_50]] +// CHECK: inner_smaller_keys_index-after40: ; preds = %[[VAL_175]], %[[VAL_49]] +// CHECK: br label %[[VAL_20]] +// CHECK: inner_smaller_keys_index-true39: ; preds = %[[VAL_49]] +// CHECK: %[[VAL_176:.*]] = add i64 %[[VAL_169]], 1 +// CHECK: %[[VAL_177:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_176]] +// CHECK: %[[VAL_178:.*]] = load float, ptr addrspace(3) %[[VAL_177]], align 4 +// CHECK: %[[VAL_179:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_54]], i64 0, i64 %[[VAL_16]], i64 %[[VAL_173]] +// CHECK: store float %[[VAL_178]], ptr %[[VAL_179]], align 4 +// CHECK: br label %[[VAL_50]] // CHECK: entry: -// CHECK: %[[VAL_179:.*]] = alloca i8, align 1 -// CHECK: %[[VAL_180:.*]] = load float, ptr %[[VAL_181:.*]], align 4 -// CHECK: %[[VAL_182:.*]] = load float, ptr %[[VAL_183:.*]], align 4 -// CHECK: %[[VAL_184:.*]] = fcmp olt float %[[VAL_180]], %[[VAL_182]] -// CHECK: %[[VAL_185:.*]] = zext i1 %[[VAL_184]] to i8 -// CHECK: store i8 %[[VAL_185]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_179]], align 1 -// CHECK: %[[VAL_186:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_179]], align 1 -// CHECK: store i8 %[[VAL_186]], ptr %[[VAL_187:.*]], align 1 +// CHECK: %[[VAL_180:.*]] = alloca i8, align 1 +// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_182:.*]], align 4 +// CHECK: %[[VAL_183:.*]] = load float, ptr %[[VAL_184:.*]], align 4 +// CHECK: %[[VAL_185:.*]] = fcmp olt float %[[VAL_181]], %[[VAL_183]] +// CHECK: %[[VAL_186:.*]] = zext i1 %[[VAL_185]] to i8 +// CHECK-PTX: store i8 %[[VAL_186]], ptr %[[VAL_180]], align 1 +// CHECK-GCN: store i8 %[[VAL_186]], ptr addrspace(5) %[[VAL_180]], align 1 +// CHECK-PTX: %[[VAL_187:.*]] = load i8, ptr %[[VAL_180]], align 1 +// CHECK-GCN: %[[VAL_187:.*]] = load i8, ptr addrspace(5) %[[VAL_180]], align 1 +// CHECK: store i8 %[[VAL_187]], ptr %[[VAL_188:.*]], align 1 // CHECK: ret void ENTRY main { @@ -299,379 +315,395 @@ compare { } // CHECK: entry: -// CHECK: %[[VAL_188:.*]] = alloca i8, align 1 // CHECK: %[[VAL_189:.*]] = alloca i8, align 1 // CHECK: %[[VAL_190:.*]] = alloca i8, align 1 // CHECK: %[[VAL_191:.*]] = alloca i8, align 1 // CHECK: %[[VAL_192:.*]] = alloca i8, align 1 // CHECK: %[[VAL_193:.*]] = alloca i8, align 1 -// CHECK: %[[VAL_194:.*]] = call i32 [[CTAIDX]] -// CHECK: %[[VAL_195:.*]] = zext i32 %[[VAL_194]] to i64 -// CHECK: %[[VAL_196:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_197:.*]] = zext i32 %[[VAL_196]] to i64 -// CHECK: %[[VAL_198:.*]] = mul nuw nsw i64 %[[VAL_195]], 2 -// CHECK: %[[VAL_199:.*]] = add nuw nsw i64 %[[VAL_198]], %[[VAL_197]] -// CHECK: %[[VAL_200:.*]] = icmp ult i64 %[[VAL_199]], 4 -// CHECK: call void @llvm.assume(i1 %[[VAL_200]]) -// CHECK: %[[VAL_201:.*]] = udiv i64 %[[VAL_199]], 1 -// CHECK: %[[VAL_202:.*]] = urem i64 %[[VAL_201]], 2 -// CHECK: %[[VAL_203:.*]] = udiv i64 %[[VAL_199]], 2 -// CHECK: %[[VAL_204:.*]] = icmp ult i64 %[[VAL_199]], 4 -// CHECK: br i1 %[[VAL_204]], label %[[VAL_205:.*]], label %[[VAL_206:.*]] -// CHECK: sort.in_bounds-after: ; preds = %[[VAL_207:.*]], %[[VAL_208:.*]] +// CHECK: %[[VAL_194:.*]] = alloca i8, align 1 +// CHECK-PTX: %[[VAL_195:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x +// CHECK-GCN: %[[VAL_195:.*]] = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_196:.*]] = zext i32 %[[VAL_195]] to i64 +// CHECK-PTX: %[[VAL_197:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_197:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_198:.*]] = zext i32 %[[VAL_197]] to i64 +// CHECK: %[[VAL_199:.*]] = mul nuw nsw i64 %[[VAL_196]], 2 +// CHECK: %[[VAL_200:.*]] = add nuw nsw i64 %[[VAL_199]], %[[VAL_198]] +// CHECK: %[[VAL_201:.*]] = icmp ult i64 %[[VAL_200]], 4 +// CHECK: call void @llvm.assume(i1 %[[VAL_201]]) +// CHECK: %[[VAL_202:.*]] = add nuw nsw i64 %[[VAL_200]], 0 +// CHECK: %[[VAL_203:.*]] = udiv i64 %[[VAL_202]], 1 +// CHECK: %[[VAL_204:.*]] = urem i64 %[[VAL_203]], 2 +// CHECK: %[[VAL_205:.*]] = udiv i64 %[[VAL_202]], 2 +// CHECK: %[[VAL_206:.*]] = icmp ult i64 %[[VAL_200]], 4 +// CHECK: br i1 %[[VAL_206]], label %[[VAL_207:.*]], label %[[VAL_208:.*]] +// CHECK: sort.in_bounds-after: ; preds = %[[VAL_209:.*]], %[[VAL_210:.*]] // CHECK: ret void -// CHECK: sort.in_bounds-true: ; preds = %[[VAL_208]] -// CHECK: %[[VAL_209:.*]] = call i32 [[TIDX]] -// CHECK: %[[VAL_210:.*]] = sext i32 %[[VAL_209]] to i64 -// CHECK: %[[VAL_211:.*]] = shl i64 %[[VAL_202]], 1 -// CHECK: %[[VAL_212:.*]] = icmp slt i64 %[[VAL_211]], 3 -// CHECK: br i1 %[[VAL_212]], label %[[VAL_213:.*]], label %[[VAL_214:.*]] -// CHECK: smaller_keys_index-after: ; preds = %[[VAL_215:.*]], %[[VAL_205]] -// CHECK: %[[VAL_216:.*]] = shl i64 %[[VAL_202]], 1 -// CHECK: %[[VAL_217:.*]] = icmp slt i64 %[[VAL_216]], 3 -// CHECK: br i1 %[[VAL_217]], label %[[VAL_218:.*]], label %[[VAL_219:.*]] -// CHECK: smaller_keys_index-after2: ; preds = %[[VAL_220:.*]], %[[VAL_214]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_221:.*]] = mul i64 %[[VAL_202]], 2 -// CHECK: %[[VAL_222:.*]] = icmp uge i64 %[[VAL_221]], 0 -// CHECK: br i1 %[[VAL_222]], label %[[VAL_223:.*]], label %[[VAL_224:.*]] -// CHECK: is_last_tile-after: ; preds = %[[VAL_225:.*]], %[[VAL_226:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_227:.*]] = mul i64 %[[VAL_202]], 2 -// CHECK: %[[VAL_228:.*]] = icmp uge i64 %[[VAL_227]], 0 -// CHECK: br i1 %[[VAL_228]], label %[[VAL_229:.*]], label %[[VAL_230:.*]] -// CHECK: is_last_tile-after13: ; preds = %[[VAL_231:.*]], %[[VAL_232:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_233:.*]] = mul i64 %[[VAL_202]], 2 -// CHECK: %[[VAL_234:.*]] = icmp uge i64 %[[VAL_233]], 0 -// CHECK: br i1 %[[VAL_234]], label %[[VAL_235:.*]], label %[[VAL_236:.*]] -// CHECK: is_last_tile-after28: ; preds = %[[VAL_237:.*]], %[[VAL_238:.*]] -// CHECK: call void [[BARRIER]] -// CHECK: %[[VAL_239:.*]] = shl i64 %[[VAL_202]], 1 -// CHECK: %[[VAL_240:.*]] = icmp slt i64 %[[VAL_239]], 3 -// CHECK: br i1 %[[VAL_240]], label %[[VAL_241:.*]], label %[[VAL_242:.*]] -// CHECK: smaller_keys_index-after42: ; preds = %[[VAL_243:.*]], %[[VAL_244:.*]] -// CHECK: %[[VAL_245:.*]] = shl i64 %[[VAL_202]], 1 -// CHECK: %[[VAL_246:.*]] = icmp slt i64 %[[VAL_245]], 3 -// CHECK: br i1 %[[VAL_246]], label %[[VAL_247:.*]], label %[[VAL_207]] -// CHECK: smaller_keys_index-after46: ; preds = %[[VAL_248:.*]], %[[VAL_242]] -// CHECK: br label %[[VAL_206]] -// CHECK: smaller_keys_index-true: ; preds = %[[VAL_205]] -// CHECK: %[[VAL_249:.*]] = shl i64 %[[VAL_210]], 1 -// CHECK: %[[VAL_250:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_251:.*]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_211]] -// CHECK: %[[VAL_252:.*]] = load i32, ptr %[[VAL_250]], align 4 -// CHECK: %[[VAL_253:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_249]] -// CHECK: store i32 %[[VAL_252]], ptr addrspace(3) %[[VAL_253]], align 4 -// CHECK: %[[VAL_254:.*]] = add i64 %[[VAL_211]], 1 -// CHECK: %[[VAL_255:.*]] = icmp slt i64 %[[VAL_254]], 3 -// CHECK: br i1 %[[VAL_255]], label %[[VAL_256:.*]], label %[[VAL_215]] -// CHECK: inner_smaller_keys_index-after: ; preds = %[[VAL_256]], %[[VAL_213]] -// CHECK: br label %[[VAL_214]] -// CHECK: inner_smaller_keys_index-true: ; preds = %[[VAL_213]] -// CHECK: %[[VAL_257:.*]] = add i64 %[[VAL_249]], 1 -// CHECK: %[[VAL_258:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_251]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_254]] -// CHECK: %[[VAL_259:.*]] = load i32, ptr %[[VAL_258]], align 4 -// CHECK: %[[VAL_260:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_257]] -// CHECK: store i32 %[[VAL_259]], ptr addrspace(3) %[[VAL_260]], align 4 -// CHECK: br label %[[VAL_215]] -// CHECK: smaller_keys_index-true1: ; preds = %[[VAL_214]] -// CHECK: %[[VAL_261:.*]] = shl i64 %[[VAL_210]], 1 -// CHECK: %[[VAL_262:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_263:.*]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_216]] -// CHECK: %[[VAL_264:.*]] = load float, ptr %[[VAL_262]], align 4 -// CHECK: %[[VAL_265:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_261]] -// CHECK: store float %[[VAL_264]], ptr addrspace(3) %[[VAL_265]], align 4 -// CHECK: %[[VAL_266:.*]] = add i64 %[[VAL_216]], 1 -// CHECK: %[[VAL_267:.*]] = icmp slt i64 %[[VAL_266]], 3 -// CHECK: br i1 %[[VAL_267]], label %[[VAL_268:.*]], label %[[VAL_220]] -// CHECK: inner_smaller_keys_index-after4: ; preds = %[[VAL_268]], %[[VAL_218]] -// CHECK: br label %[[VAL_219]] -// CHECK: inner_smaller_keys_index-true3: ; preds = %[[VAL_218]] -// CHECK: %[[VAL_269:.*]] = add i64 %[[VAL_261]], 1 -// CHECK: %[[VAL_270:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_263]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_266]] -// CHECK: %[[VAL_271:.*]] = load float, ptr %[[VAL_270]], align 4 -// CHECK: %[[VAL_272:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_269]] -// CHECK: store float %[[VAL_271]], ptr addrspace(3) %[[VAL_272]], align 4 -// CHECK: br label %[[VAL_220]] -// CHECK: is_last_tile-true: ; preds = %[[VAL_219]] -// CHECK: %[[VAL_273:.*]] = mul i64 %[[VAL_210]], 2 -// CHECK: %[[VAL_274:.*]] = xor i64 %[[VAL_273]], 1 -// CHECK: %[[VAL_275:.*]] = icmp slt i64 %[[VAL_273]], %[[VAL_274]] -// CHECK: %[[VAL_276:.*]] = icmp slt i64 %[[VAL_274]], 3 -// CHECK: %[[VAL_277:.*]] = and i1 %[[VAL_275]], %[[VAL_276]] -// CHECK: br i1 %[[VAL_277]], label %[[VAL_278:.*]], label %[[VAL_226]] -// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_279:.*]], %[[VAL_223]] -// CHECK: br label %[[VAL_280:.*]] -// CHECK: is_last_tile-false: ; preds = %[[VAL_219]] -// CHECK: %[[VAL_281:.*]] = mul i64 %[[VAL_210]], 2 -// CHECK: %[[VAL_282:.*]] = xor i64 %[[VAL_281]], 1 -// CHECK: %[[VAL_283:.*]] = icmp slt i64 %[[VAL_281]], %[[VAL_282]] -// CHECK: %[[VAL_284:.*]] = icmp slt i64 %[[VAL_282]], 4 -// CHECK: br i1 true, label %[[VAL_285:.*]], label %[[VAL_225]] -// CHECK: smaller_comparison_index-after6: ; preds = %[[VAL_286:.*]], %[[VAL_224]] -// CHECK: br label %[[VAL_280]] -// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_223]] -// CHECK: %[[VAL_287:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_274]] -// CHECK: %[[VAL_288:.*]] = addrspacecast ptr addrspace(3) %[[VAL_287]] to ptr -// CHECK: %[[VAL_289:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_273]] +// CHECK: sort.in_bounds-true: ; preds = %[[VAL_210]] +// CHECK-PTX: %[[VAL_211:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x +// CHECK-GCN: %[[VAL_211:.*]] = call i32 @llvm.amdgcn.workitem.id.x +// CHECK: %[[VAL_212:.*]] = sext i32 %[[VAL_211]] to i64 +// CHECK: %[[VAL_213:.*]] = shl i64 %[[VAL_204]], 1 +// CHECK: %[[VAL_214:.*]] = icmp slt i64 %[[VAL_213]], 3 +// CHECK: br i1 %[[VAL_214]], label %[[VAL_215:.*]], label %[[VAL_216:.*]] +// CHECK: smaller_keys_index-after: ; preds = %[[VAL_217:.*]], %[[VAL_207]] +// CHECK: %[[VAL_218:.*]] = shl i64 %[[VAL_204]], 1 +// CHECK: %[[VAL_219:.*]] = icmp slt i64 %[[VAL_218]], 3 +// CHECK: br i1 %[[VAL_219]], label %[[VAL_220:.*]], label %[[VAL_221:.*]] +// CHECK: smaller_keys_index-after2: ; preds = %[[VAL_222:.*]], %[[VAL_216]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_223:.*]] = mul i64 %[[VAL_204]], 2 +// CHECK: %[[VAL_224:.*]] = icmp uge i64 %[[VAL_223]], 0 +// CHECK: br i1 %[[VAL_224]], label %[[VAL_225:.*]], label %[[VAL_226:.*]] +// CHECK: is_last_tile-after: ; preds = %[[VAL_227:.*]], %[[VAL_228:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_229:.*]] = mul i64 %[[VAL_204]], 2 +// CHECK: %[[VAL_230:.*]] = icmp uge i64 %[[VAL_229]], 0 +// CHECK: br i1 %[[VAL_230]], label %[[VAL_231:.*]], label %[[VAL_232:.*]] +// CHECK: is_last_tile-after13: ; preds = %[[VAL_233:.*]], %[[VAL_234:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_235:.*]] = mul i64 %[[VAL_204]], 2 +// CHECK: %[[VAL_236:.*]] = icmp uge i64 %[[VAL_235]], 0 +// CHECK: br i1 %[[VAL_236]], label %[[VAL_237:.*]], label %[[VAL_238:.*]] +// CHECK: is_last_tile-after28: ; preds = %[[VAL_239:.*]], %[[VAL_240:.*]] +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +// CHECK: %[[VAL_241:.*]] = shl i64 %[[VAL_204]], 1 +// CHECK: %[[VAL_242:.*]] = icmp slt i64 %[[VAL_241]], 3 +// CHECK: br i1 %[[VAL_242]], label %[[VAL_243:.*]], label %[[VAL_244:.*]] +// CHECK: smaller_keys_index-after42: ; preds = %[[VAL_245:.*]], %[[VAL_246:.*]] +// CHECK: %[[VAL_247:.*]] = shl i64 %[[VAL_204]], 1 +// CHECK: %[[VAL_248:.*]] = icmp slt i64 %[[VAL_247]], 3 +// CHECK: br i1 %[[VAL_248]], label %[[VAL_249:.*]], label %[[VAL_209]] +// CHECK: smaller_keys_index-after46: ; preds = %[[VAL_250:.*]], %[[VAL_244]] +// CHECK: br label %[[VAL_208]] +// CHECK: smaller_keys_index-true: ; preds = %[[VAL_207]] +// CHECK: %[[VAL_251:.*]] = shl i64 %[[VAL_212]], 1 +// CHECK: %[[VAL_252:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_253:.*]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_213]] +// CHECK: %[[VAL_254:.*]] = load i32, ptr %[[VAL_252]], align 4 +// CHECK: %[[VAL_255:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_251]] +// CHECK: store i32 %[[VAL_254]], ptr addrspace(3) %[[VAL_255]], align 4 +// CHECK: %[[VAL_256:.*]] = add i64 %[[VAL_213]], 1 +// CHECK: %[[VAL_257:.*]] = icmp slt i64 %[[VAL_256]], 3 +// CHECK: br i1 %[[VAL_257]], label %[[VAL_258:.*]], label %[[VAL_217]] +// CHECK: inner_smaller_keys_index-after: ; preds = %[[VAL_258]], %[[VAL_215]] +// CHECK: br label %[[VAL_216]] +// CHECK: inner_smaller_keys_index-true: ; preds = %[[VAL_215]] +// CHECK: %[[VAL_259:.*]] = add i64 %[[VAL_251]], 1 +// CHECK: %[[VAL_260:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_253]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_256]] +// CHECK: %[[VAL_261:.*]] = load i32, ptr %[[VAL_260]], align 4 +// CHECK: %[[VAL_262:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_259]] +// CHECK: store i32 %[[VAL_261]], ptr addrspace(3) %[[VAL_262]], align 4 +// CHECK: br label %[[VAL_217]] +// CHECK: smaller_keys_index-true1: ; preds = %[[VAL_216]] +// CHECK: %[[VAL_263:.*]] = shl i64 %[[VAL_212]], 1 +// CHECK: %[[VAL_264:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_265:.*]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_218]] +// CHECK: %[[VAL_266:.*]] = load float, ptr %[[VAL_264]], align 4 +// CHECK: %[[VAL_267:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_263]] +// CHECK: store float %[[VAL_266]], ptr addrspace(3) %[[VAL_267]], align 4 +// CHECK: %[[VAL_268:.*]] = add i64 %[[VAL_218]], 1 +// CHECK: %[[VAL_269:.*]] = icmp slt i64 %[[VAL_268]], 3 +// CHECK: br i1 %[[VAL_269]], label %[[VAL_270:.*]], label %[[VAL_222]] +// CHECK: inner_smaller_keys_index-after4: ; preds = %[[VAL_270]], %[[VAL_220]] +// CHECK: br label %[[VAL_221]] +// CHECK: inner_smaller_keys_index-true3: ; preds = %[[VAL_220]] +// CHECK: %[[VAL_271:.*]] = add i64 %[[VAL_263]], 1 +// CHECK: %[[VAL_272:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_265]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_268]] +// CHECK: %[[VAL_273:.*]] = load float, ptr %[[VAL_272]], align 4 +// CHECK: %[[VAL_274:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_271]] +// CHECK: store float %[[VAL_273]], ptr addrspace(3) %[[VAL_274]], align 4 +// CHECK: br label %[[VAL_222]] +// CHECK: is_last_tile-true: ; preds = %[[VAL_221]] +// CHECK: %[[VAL_275:.*]] = mul i64 %[[VAL_212]], 2 +// CHECK: %[[VAL_276:.*]] = xor i64 %[[VAL_275]], 1 +// CHECK: %[[VAL_277:.*]] = icmp slt i64 %[[VAL_275]], %[[VAL_276]] +// CHECK: %[[VAL_278:.*]] = icmp slt i64 %[[VAL_276]], 3 +// CHECK: %[[VAL_279:.*]] = and i1 %[[VAL_277]], %[[VAL_278]] +// CHECK: br i1 %[[VAL_279]], label %[[VAL_280:.*]], label %[[VAL_228]] +// CHECK: smaller_comparison_index-after: ; preds = %[[VAL_281:.*]], %[[VAL_225]] +// CHECK: br label %[[VAL_282:.*]] +// CHECK: is_last_tile-false: ; preds = %[[VAL_221]] +// CHECK: %[[VAL_283:.*]] = mul i64 %[[VAL_212]], 2 +// CHECK: %[[VAL_284:.*]] = xor i64 %[[VAL_283]], 1 +// CHECK: %[[VAL_285:.*]] = icmp slt i64 %[[VAL_283]], %[[VAL_284]] +// CHECK: %[[VAL_286:.*]] = icmp slt i64 %[[VAL_284]], 4 +// CHECK: br i1 true, label %[[VAL_287:.*]], label %[[VAL_227]] +// CHECK: smaller_comparison_index-after6: ; preds = %[[VAL_288:.*]], %[[VAL_226]] +// CHECK: br label %[[VAL_282]] +// CHECK: smaller_comparison_index-true: ; preds = %[[VAL_225]] +// CHECK: %[[VAL_289:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_276]] // CHECK: %[[VAL_290:.*]] = addrspacecast ptr addrspace(3) %[[VAL_289]] to ptr -// CHECK: %[[VAL_291:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_274]] +// CHECK: %[[VAL_291:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_275]] // CHECK: %[[VAL_292:.*]] = addrspacecast ptr addrspace(3) %[[VAL_291]] to ptr -// CHECK: %[[VAL_293:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_273]] +// CHECK: %[[VAL_293:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_276]] // CHECK: %[[VAL_294:.*]] = addrspacecast ptr addrspace(3) %[[VAL_293]] to ptr -// CHECK-GCN: %[[VAL_193_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_193]] to ptr -// CHECK-PTX: call void @[[REGION2:compare_.*]](ptr %[[VAL_288]], ptr %[[VAL_290]], ptr %[[VAL_292]], ptr %[[VAL_294]], ptr %[[VAL_193]]) -// CHECK-GCN: call void @[[REGION2:compare_.*]](ptr %[[VAL_288]], ptr %[[VAL_290]], ptr %[[VAL_292]], ptr %[[VAL_294]], ptr %[[VAL_193_2]]) -// CHECK: %[[VAL_295:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_193]], align 1 -// CHECK: %[[VAL_296:.*]] = icmp ne i8 %[[VAL_295]], 0 -// CHECK: br i1 %[[VAL_296]], label %[[VAL_297:.*]], label %[[VAL_279]] -// CHECK: is_smaller_than-after: ; preds = %[[VAL_297]], %[[VAL_278]] -// CHECK: br label %[[VAL_226]] -// CHECK: is_smaller_than-true: ; preds = %[[VAL_278]] -// CHECK: %[[VAL_298:.*]] = load i32, ptr %[[VAL_288]], align 4 -// CHECK: %[[VAL_299:.*]] = load i32, ptr %[[VAL_290]], align 4 -// CHECK: %[[VAL_300:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_273]] -// CHECK: store i32 %[[VAL_298]], ptr addrspace(3) %[[VAL_300]], align 4 -// CHECK: %[[VAL_301:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_274]] -// CHECK: store i32 %[[VAL_299]], ptr addrspace(3) %[[VAL_301]], align 4 -// CHECK: %[[VAL_302:.*]] = load float, ptr %[[VAL_292]], align 4 -// CHECK: %[[VAL_303:.*]] = load float, ptr %[[VAL_294]], align 4 -// CHECK: %[[VAL_304:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_273]] -// CHECK: store float %[[VAL_302]], ptr addrspace(3) %[[VAL_304]], align 4 -// CHECK: %[[VAL_305:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_274]] -// CHECK: store float %[[VAL_303]], ptr addrspace(3) %[[VAL_305]], align 4 -// CHECK: br label %[[VAL_279]] -// CHECK: smaller_comparison_index-true5: ; preds = %[[VAL_224]] -// CHECK: %[[VAL_306:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_282]] -// CHECK: %[[VAL_307:.*]] = addrspacecast ptr addrspace(3) %[[VAL_306]] to ptr -// CHECK: %[[VAL_308:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_281]] +// CHECK: %[[VAL_295:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_275]] +// CHECK: %[[VAL_296:.*]] = addrspacecast ptr addrspace(3) %[[VAL_295]] to ptr +// CHECK-GCN: %[[VAL_194_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_194]] to ptr +// CHECK-PTX: call void @[[REGION2:compare_.*]](ptr %[[VAL_290]], ptr %[[VAL_292]], ptr %[[VAL_294]], ptr %[[VAL_296]], ptr %[[VAL_194]]) +// CHECK-GCN: call void @[[REGION2:compare_.*]](ptr %[[VAL_290]], ptr %[[VAL_292]], ptr %[[VAL_294]], ptr %[[VAL_296]], ptr %[[VAL_194_2]]) +// CHECK-PTX: %[[VAL_297:.*]] = load i8, ptr %[[VAL_194]], align 1 +// CHECK-GCN: %[[VAL_297:.*]] = load i8, ptr addrspace(5) %[[VAL_194]], align 1 +// CHECK: %[[VAL_298:.*]] = icmp ne i8 %[[VAL_297]], 0 +// CHECK: br i1 %[[VAL_298]], label %[[VAL_299:.*]], label %[[VAL_281]] +// CHECK: is_smaller_than-after: ; preds = %[[VAL_299]], %[[VAL_280]] +// CHECK: br label %[[VAL_228]] +// CHECK: is_smaller_than-true: ; preds = %[[VAL_280]] +// CHECK: %[[VAL_300:.*]] = load i32, ptr %[[VAL_290]], align 4 +// CHECK: %[[VAL_301:.*]] = load i32, ptr %[[VAL_292]], align 4 +// CHECK: %[[VAL_302:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_275]] +// CHECK: store i32 %[[VAL_300]], ptr addrspace(3) %[[VAL_302]], align 4 +// CHECK: %[[VAL_303:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_276]] +// CHECK: store i32 %[[VAL_301]], ptr addrspace(3) %[[VAL_303]], align 4 +// CHECK: %[[VAL_304:.*]] = load float, ptr %[[VAL_294]], align 4 +// CHECK: %[[VAL_305:.*]] = load float, ptr %[[VAL_296]], align 4 +// CHECK: %[[VAL_306:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_275]] +// CHECK: store float %[[VAL_304]], ptr addrspace(3) %[[VAL_306]], align 4 +// CHECK: %[[VAL_307:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_276]] +// CHECK: store float %[[VAL_305]], ptr addrspace(3) %[[VAL_307]], align 4 +// CHECK: br label %[[VAL_281]] +// CHECK: smaller_comparison_index-true5: ; preds = %[[VAL_226]] +// CHECK: %[[VAL_308:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_284]] // CHECK: %[[VAL_309:.*]] = addrspacecast ptr addrspace(3) %[[VAL_308]] to ptr -// CHECK: %[[VAL_310:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_282]] +// CHECK: %[[VAL_310:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_283]] // CHECK: %[[VAL_311:.*]] = addrspacecast ptr addrspace(3) %[[VAL_310]] to ptr -// CHECK: %[[VAL_312:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_281]] +// CHECK: %[[VAL_312:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_284]] // CHECK: %[[VAL_313:.*]] = addrspacecast ptr addrspace(3) %[[VAL_312]] to ptr -// CHECK-GCN: %[[VAL_192_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_192]] to ptr -// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_307]], ptr %[[VAL_309]], ptr %[[VAL_311]], ptr %[[VAL_313]], ptr %[[VAL_192]]) -// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_307]], ptr %[[VAL_309]], ptr %[[VAL_311]], ptr %[[VAL_313]], ptr %[[VAL_192_2]]) -// CHECK: %[[VAL_314:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_192]], align 1 -// CHECK: %[[VAL_315:.*]] = icmp ne i8 %[[VAL_314]], 0 -// CHECK: br i1 %[[VAL_315]], label %[[VAL_316:.*]], label %[[VAL_286]] -// CHECK: is_smaller_than-after10: ; preds = %[[VAL_316]], %[[VAL_285]] -// CHECK: br label %[[VAL_225]] -// CHECK: is_smaller_than-true9: ; preds = %[[VAL_285]] -// CHECK: %[[VAL_317:.*]] = load i32, ptr %[[VAL_307]], align 4 -// CHECK: %[[VAL_318:.*]] = load i32, ptr %[[VAL_309]], align 4 -// CHECK: %[[VAL_319:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_281]] -// CHECK: store i32 %[[VAL_317]], ptr addrspace(3) %[[VAL_319]], align 4 -// CHECK: %[[VAL_320:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_282]] -// CHECK: store i32 %[[VAL_318]], ptr addrspace(3) %[[VAL_320]], align 4 -// CHECK: %[[VAL_321:.*]] = load float, ptr %[[VAL_311]], align 4 -// CHECK: %[[VAL_322:.*]] = load float, ptr %[[VAL_313]], align 4 -// CHECK: %[[VAL_323:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_281]] -// CHECK: store float %[[VAL_321]], ptr addrspace(3) %[[VAL_323]], align 4 -// CHECK: %[[VAL_324:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_282]] -// CHECK: store float %[[VAL_322]], ptr addrspace(3) %[[VAL_324]], align 4 -// CHECK: br label %[[VAL_286]] -// CHECK: is_last_tile-true11: ; preds = %[[VAL_280]] -// CHECK: %[[VAL_325:.*]] = xor i64 %[[VAL_210]], 3 -// CHECK: %[[VAL_326:.*]] = icmp slt i64 %[[VAL_210]], %[[VAL_325]] -// CHECK: %[[VAL_327:.*]] = icmp slt i64 %[[VAL_325]], 3 -// CHECK: %[[VAL_328:.*]] = and i1 %[[VAL_326]], %[[VAL_327]] -// CHECK: br i1 %[[VAL_328]], label %[[VAL_329:.*]], label %[[VAL_232]] -// CHECK: smaller_comparison_index-after15: ; preds = %[[VAL_330:.*]], %[[VAL_229]] -// CHECK: br label %[[VAL_331:.*]] -// CHECK: is_last_tile-false12: ; preds = %[[VAL_280]] -// CHECK: %[[VAL_332:.*]] = xor i64 %[[VAL_210]], 3 -// CHECK: %[[VAL_333:.*]] = icmp slt i64 %[[VAL_210]], %[[VAL_332]] -// CHECK: %[[VAL_334:.*]] = icmp slt i64 %[[VAL_332]], 4 -// CHECK: br i1 true, label %[[VAL_335:.*]], label %[[VAL_231]] -// CHECK: smaller_comparison_index-after21: ; preds = %[[VAL_336:.*]], %[[VAL_230]] -// CHECK: br label %[[VAL_331]] -// CHECK: smaller_comparison_index-true14: ; preds = %[[VAL_229]] -// CHECK: %[[VAL_337:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_325]] -// CHECK: %[[VAL_338:.*]] = addrspacecast ptr addrspace(3) %[[VAL_337]] to ptr -// CHECK: %[[VAL_339:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_210]] +// CHECK: %[[VAL_314:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_283]] +// CHECK: %[[VAL_315:.*]] = addrspacecast ptr addrspace(3) %[[VAL_314]] to ptr +// CHECK-GCN: %[[VAL_193_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_193]] to ptr +// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_309]], ptr %[[VAL_311]], ptr %[[VAL_313]], ptr %[[VAL_315]], ptr %[[VAL_193]]) +// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_309]], ptr %[[VAL_311]], ptr %[[VAL_313]], ptr %[[VAL_315]], ptr %[[VAL_193_2]]) +// CHECK-PTX: %[[VAL_316:.*]] = load i8, ptr %[[VAL_193]], align 1 +// CHECK-GCN: %[[VAL_316:.*]] = load i8, ptr addrspace(5) %[[VAL_193]], align 1 +// CHECK: %[[VAL_317:.*]] = icmp ne i8 %[[VAL_316]], 0 +// CHECK: br i1 %[[VAL_317]], label %[[VAL_318:.*]], label %[[VAL_288]] +// CHECK: is_smaller_than-after10: ; preds = %[[VAL_318]], %[[VAL_287]] +// CHECK: br label %[[VAL_227]] +// CHECK: is_smaller_than-true9: ; preds = %[[VAL_287]] +// CHECK: %[[VAL_319:.*]] = load i32, ptr %[[VAL_309]], align 4 +// CHECK: %[[VAL_320:.*]] = load i32, ptr %[[VAL_311]], align 4 +// CHECK: %[[VAL_321:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_283]] +// CHECK: store i32 %[[VAL_319]], ptr addrspace(3) %[[VAL_321]], align 4 +// CHECK: %[[VAL_322:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_284]] +// CHECK: store i32 %[[VAL_320]], ptr addrspace(3) %[[VAL_322]], align 4 +// CHECK: %[[VAL_323:.*]] = load float, ptr %[[VAL_313]], align 4 +// CHECK: %[[VAL_324:.*]] = load float, ptr %[[VAL_315]], align 4 +// CHECK: %[[VAL_325:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_283]] +// CHECK: store float %[[VAL_323]], ptr addrspace(3) %[[VAL_325]], align 4 +// CHECK: %[[VAL_326:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_284]] +// CHECK: store float %[[VAL_324]], ptr addrspace(3) %[[VAL_326]], align 4 +// CHECK: br label %[[VAL_288]] +// CHECK: is_last_tile-true11: ; preds = %[[VAL_282]] +// CHECK: %[[VAL_327:.*]] = xor i64 %[[VAL_212]], 3 +// CHECK: %[[VAL_328:.*]] = icmp slt i64 %[[VAL_212]], %[[VAL_327]] +// CHECK: %[[VAL_329:.*]] = icmp slt i64 %[[VAL_327]], 3 +// CHECK: %[[VAL_330:.*]] = and i1 %[[VAL_328]], %[[VAL_329]] +// CHECK: br i1 %[[VAL_330]], label %[[VAL_331:.*]], label %[[VAL_234]] +// CHECK: smaller_comparison_index-after15: ; preds = %[[VAL_332:.*]], %[[VAL_231]] +// CHECK: br label %[[VAL_333:.*]] +// CHECK: is_last_tile-false12: ; preds = %[[VAL_282]] +// CHECK: %[[VAL_334:.*]] = xor i64 %[[VAL_212]], 3 +// CHECK: %[[VAL_335:.*]] = icmp slt i64 %[[VAL_212]], %[[VAL_334]] +// CHECK: %[[VAL_336:.*]] = icmp slt i64 %[[VAL_334]], 4 +// CHECK: br i1 true, label %[[VAL_337:.*]], label %[[VAL_233]] +// CHECK: smaller_comparison_index-after21: ; preds = %[[VAL_338:.*]], %[[VAL_232]] +// CHECK: br label %[[VAL_333]] +// CHECK: smaller_comparison_index-true14: ; preds = %[[VAL_231]] +// CHECK: %[[VAL_339:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_327]] // CHECK: %[[VAL_340:.*]] = addrspacecast ptr addrspace(3) %[[VAL_339]] to ptr -// CHECK: %[[VAL_341:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_325]] +// CHECK: %[[VAL_341:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_212]] // CHECK: %[[VAL_342:.*]] = addrspacecast ptr addrspace(3) %[[VAL_341]] to ptr -// CHECK: %[[VAL_343:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_210]] +// CHECK: %[[VAL_343:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_327]] // CHECK: %[[VAL_344:.*]] = addrspacecast ptr addrspace(3) %[[VAL_343]] to ptr -// CHECK-GCN: %[[VAL_191_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_191]] to ptr -// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_338]], ptr %[[VAL_340]], ptr %[[VAL_342]], ptr %[[VAL_344]], ptr %[[VAL_191]]) -// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_338]], ptr %[[VAL_340]], ptr %[[VAL_342]], ptr %[[VAL_344]], ptr %[[VAL_191_2]]) -// CHECK: %[[VAL_345:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_191]], align 1 -// CHECK: %[[VAL_346:.*]] = icmp ne i8 %[[VAL_345]], 0 -// CHECK: br i1 %[[VAL_346]], label %[[VAL_347:.*]], label %[[VAL_330]] -// CHECK: is_smaller_than-after19: ; preds = %[[VAL_347]], %[[VAL_329]] -// CHECK: br label %[[VAL_232]] -// CHECK: is_smaller_than-true18: ; preds = %[[VAL_329]] -// CHECK: %[[VAL_348:.*]] = load i32, ptr %[[VAL_338]], align 4 -// CHECK: %[[VAL_349:.*]] = load i32, ptr %[[VAL_340]], align 4 -// CHECK: %[[VAL_350:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_210]] -// CHECK: store i32 %[[VAL_348]], ptr addrspace(3) %[[VAL_350]], align 4 -// CHECK: %[[VAL_351:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_325]] -// CHECK: store i32 %[[VAL_349]], ptr addrspace(3) %[[VAL_351]], align 4 -// CHECK: %[[VAL_352:.*]] = load float, ptr %[[VAL_342]], align 4 -// CHECK: %[[VAL_353:.*]] = load float, ptr %[[VAL_344]], align 4 -// CHECK: %[[VAL_354:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_210]] -// CHECK: store float %[[VAL_352]], ptr addrspace(3) %[[VAL_354]], align 4 -// CHECK: %[[VAL_355:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_325]] -// CHECK: store float %[[VAL_353]], ptr addrspace(3) %[[VAL_355]], align 4 -// CHECK: br label %[[VAL_330]] -// CHECK: smaller_comparison_index-true20: ; preds = %[[VAL_230]] -// CHECK: %[[VAL_356:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_332]] -// CHECK: %[[VAL_357:.*]] = addrspacecast ptr addrspace(3) %[[VAL_356]] to ptr -// CHECK: %[[VAL_358:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_210]] +// CHECK: %[[VAL_345:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_212]] +// CHECK: %[[VAL_346:.*]] = addrspacecast ptr addrspace(3) %[[VAL_345]] to ptr +// CHECK-GCN: %[[VAL_192_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_192]] to ptr +// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_340]], ptr %[[VAL_342]], ptr %[[VAL_344]], ptr %[[VAL_346]], ptr %[[VAL_192]]) +// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_340]], ptr %[[VAL_342]], ptr %[[VAL_344]], ptr %[[VAL_346]], ptr %[[VAL_192_2]]) +// CHECK-PTX: %[[VAL_347:.*]] = load i8, ptr %[[VAL_192]], align 1 +// CHECK-GCN: %[[VAL_347:.*]] = load i8, ptr addrspace(5) %[[VAL_192]], align 1 +// CHECK: %[[VAL_348:.*]] = icmp ne i8 %[[VAL_347]], 0 +// CHECK: br i1 %[[VAL_348]], label %[[VAL_349:.*]], label %[[VAL_332]] +// CHECK: is_smaller_than-after19: ; preds = %[[VAL_349]], %[[VAL_331]] +// CHECK: br label %[[VAL_234]] +// CHECK: is_smaller_than-true18: ; preds = %[[VAL_331]] +// CHECK: %[[VAL_350:.*]] = load i32, ptr %[[VAL_340]], align 4 +// CHECK: %[[VAL_351:.*]] = load i32, ptr %[[VAL_342]], align 4 +// CHECK: %[[VAL_352:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_212]] +// CHECK: store i32 %[[VAL_350]], ptr addrspace(3) %[[VAL_352]], align 4 +// CHECK: %[[VAL_353:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_327]] +// CHECK: store i32 %[[VAL_351]], ptr addrspace(3) %[[VAL_353]], align 4 +// CHECK: %[[VAL_354:.*]] = load float, ptr %[[VAL_344]], align 4 +// CHECK: %[[VAL_355:.*]] = load float, ptr %[[VAL_346]], align 4 +// CHECK: %[[VAL_356:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_212]] +// CHECK: store float %[[VAL_354]], ptr addrspace(3) %[[VAL_356]], align 4 +// CHECK: %[[VAL_357:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_327]] +// CHECK: store float %[[VAL_355]], ptr addrspace(3) %[[VAL_357]], align 4 +// CHECK: br label %[[VAL_332]] +// CHECK: smaller_comparison_index-true20: ; preds = %[[VAL_232]] +// CHECK: %[[VAL_358:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_334]] // CHECK: %[[VAL_359:.*]] = addrspacecast ptr addrspace(3) %[[VAL_358]] to ptr -// CHECK: %[[VAL_360:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_332]] +// CHECK: %[[VAL_360:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_212]] // CHECK: %[[VAL_361:.*]] = addrspacecast ptr addrspace(3) %[[VAL_360]] to ptr -// CHECK: %[[VAL_362:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_210]] +// CHECK: %[[VAL_362:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_334]] // CHECK: %[[VAL_363:.*]] = addrspacecast ptr addrspace(3) %[[VAL_362]] to ptr -// CHECK-GCN: %[[VAL_190_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_190]] to ptr -// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_357]], ptr %[[VAL_359]], ptr %[[VAL_361]], ptr %[[VAL_363]], ptr %[[VAL_190]]) -// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_357]], ptr %[[VAL_359]], ptr %[[VAL_361]], ptr %[[VAL_363]], ptr %[[VAL_190_2]]) -// CHECK: %[[VAL_364:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_190]], align 1 -// CHECK: %[[VAL_365:.*]] = icmp ne i8 %[[VAL_364]], 0 -// CHECK: br i1 %[[VAL_365]], label %[[VAL_366:.*]], label %[[VAL_336]] -// CHECK: is_smaller_than-after25: ; preds = %[[VAL_366]], %[[VAL_335]] -// CHECK: br label %[[VAL_231]] -// CHECK: is_smaller_than-true24: ; preds = %[[VAL_335]] -// CHECK: %[[VAL_367:.*]] = load i32, ptr %[[VAL_357]], align 4 -// CHECK: %[[VAL_368:.*]] = load i32, ptr %[[VAL_359]], align 4 -// CHECK: %[[VAL_369:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_210]] -// CHECK: store i32 %[[VAL_367]], ptr addrspace(3) %[[VAL_369]], align 4 -// CHECK: %[[VAL_370:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_332]] -// CHECK: store i32 %[[VAL_368]], ptr addrspace(3) %[[VAL_370]], align 4 -// CHECK: %[[VAL_371:.*]] = load float, ptr %[[VAL_361]], align 4 -// CHECK: %[[VAL_372:.*]] = load float, ptr %[[VAL_363]], align 4 -// CHECK: %[[VAL_373:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_210]] -// CHECK: store float %[[VAL_371]], ptr addrspace(3) %[[VAL_373]], align 4 -// CHECK: %[[VAL_374:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_332]] -// CHECK: store float %[[VAL_372]], ptr addrspace(3) %[[VAL_374]], align 4 -// CHECK: br label %[[VAL_336]] -// CHECK: is_last_tile-true26: ; preds = %[[VAL_331]] -// CHECK: %[[VAL_375:.*]] = mul i64 %[[VAL_210]], 2 -// CHECK: %[[VAL_376:.*]] = xor i64 %[[VAL_375]], 1 -// CHECK: %[[VAL_377:.*]] = icmp slt i64 %[[VAL_375]], %[[VAL_376]] -// CHECK: %[[VAL_378:.*]] = icmp slt i64 %[[VAL_376]], 3 -// CHECK: %[[VAL_379:.*]] = and i1 %[[VAL_377]], %[[VAL_378]] -// CHECK: br i1 %[[VAL_379]], label %[[VAL_380:.*]], label %[[VAL_238]] -// CHECK: smaller_comparison_index-after30: ; preds = %[[VAL_381:.*]], %[[VAL_235]] -// CHECK: br label %[[VAL_244]] -// CHECK: is_last_tile-false27: ; preds = %[[VAL_331]] -// CHECK: %[[VAL_382:.*]] = mul i64 %[[VAL_210]], 2 -// CHECK: %[[VAL_383:.*]] = xor i64 %[[VAL_382]], 1 -// CHECK: %[[VAL_384:.*]] = icmp slt i64 %[[VAL_382]], %[[VAL_383]] -// CHECK: %[[VAL_385:.*]] = icmp slt i64 %[[VAL_383]], 4 -// CHECK: br i1 true, label %[[VAL_386:.*]], label %[[VAL_237]] -// CHECK: smaller_comparison_index-after36: ; preds = %[[VAL_387:.*]], %[[VAL_236]] -// CHECK: br label %[[VAL_244]] -// CHECK: smaller_comparison_index-true29: ; preds = %[[VAL_235]] -// CHECK: %[[VAL_388:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_376]] -// CHECK: %[[VAL_389:.*]] = addrspacecast ptr addrspace(3) %[[VAL_388]] to ptr -// CHECK: %[[VAL_390:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_375]] +// CHECK: %[[VAL_364:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_212]] +// CHECK: %[[VAL_365:.*]] = addrspacecast ptr addrspace(3) %[[VAL_364]] to ptr +// CHECK-GCN: %[[VAL_191_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_191]] to ptr +// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_359]], ptr %[[VAL_361]], ptr %[[VAL_363]], ptr %[[VAL_365]], ptr %[[VAL_191]]) +// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_359]], ptr %[[VAL_361]], ptr %[[VAL_363]], ptr %[[VAL_365]], ptr %[[VAL_191_2]]) +// CHECK-PTX: %[[VAL_366:.*]] = load i8, ptr %[[VAL_191]], align 1 +// CHECK-GCN: %[[VAL_366:.*]] = load i8, ptr addrspace(5) %[[VAL_191]], align 1 +// CHECK: %[[VAL_367:.*]] = icmp ne i8 %[[VAL_366]], 0 +// CHECK: br i1 %[[VAL_367]], label %[[VAL_368:.*]], label %[[VAL_338]] +// CHECK: is_smaller_than-after25: ; preds = %[[VAL_368]], %[[VAL_337]] +// CHECK: br label %[[VAL_233]] +// CHECK: is_smaller_than-true24: ; preds = %[[VAL_337]] +// CHECK: %[[VAL_369:.*]] = load i32, ptr %[[VAL_359]], align 4 +// CHECK: %[[VAL_370:.*]] = load i32, ptr %[[VAL_361]], align 4 +// CHECK: %[[VAL_371:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_212]] +// CHECK: store i32 %[[VAL_369]], ptr addrspace(3) %[[VAL_371]], align 4 +// CHECK: %[[VAL_372:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_334]] +// CHECK: store i32 %[[VAL_370]], ptr addrspace(3) %[[VAL_372]], align 4 +// CHECK: %[[VAL_373:.*]] = load float, ptr %[[VAL_363]], align 4 +// CHECK: %[[VAL_374:.*]] = load float, ptr %[[VAL_365]], align 4 +// CHECK: %[[VAL_375:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_212]] +// CHECK: store float %[[VAL_373]], ptr addrspace(3) %[[VAL_375]], align 4 +// CHECK: %[[VAL_376:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_334]] +// CHECK: store float %[[VAL_374]], ptr addrspace(3) %[[VAL_376]], align 4 +// CHECK: br label %[[VAL_338]] +// CHECK: is_last_tile-true26: ; preds = %[[VAL_333]] +// CHECK: %[[VAL_377:.*]] = mul i64 %[[VAL_212]], 2 +// CHECK: %[[VAL_378:.*]] = xor i64 %[[VAL_377]], 1 +// CHECK: %[[VAL_379:.*]] = icmp slt i64 %[[VAL_377]], %[[VAL_378]] +// CHECK: %[[VAL_380:.*]] = icmp slt i64 %[[VAL_378]], 3 +// CHECK: %[[VAL_381:.*]] = and i1 %[[VAL_379]], %[[VAL_380]] +// CHECK: br i1 %[[VAL_381]], label %[[VAL_382:.*]], label %[[VAL_240]] +// CHECK: smaller_comparison_index-after30: ; preds = %[[VAL_383:.*]], %[[VAL_237]] +// CHECK: br label %[[VAL_246]] +// CHECK: is_last_tile-false27: ; preds = %[[VAL_333]] +// CHECK: %[[VAL_384:.*]] = mul i64 %[[VAL_212]], 2 +// CHECK: %[[VAL_385:.*]] = xor i64 %[[VAL_384]], 1 +// CHECK: %[[VAL_386:.*]] = icmp slt i64 %[[VAL_384]], %[[VAL_385]] +// CHECK: %[[VAL_387:.*]] = icmp slt i64 %[[VAL_385]], 4 +// CHECK: br i1 true, label %[[VAL_388:.*]], label %[[VAL_239]] +// CHECK: smaller_comparison_index-after36: ; preds = %[[VAL_389:.*]], %[[VAL_238]] +// CHECK: br label %[[VAL_246]] +// CHECK: smaller_comparison_index-true29: ; preds = %[[VAL_237]] +// CHECK: %[[VAL_390:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_378]] // CHECK: %[[VAL_391:.*]] = addrspacecast ptr addrspace(3) %[[VAL_390]] to ptr -// CHECK: %[[VAL_392:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_376]] +// CHECK: %[[VAL_392:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_377]] // CHECK: %[[VAL_393:.*]] = addrspacecast ptr addrspace(3) %[[VAL_392]] to ptr -// CHECK: %[[VAL_394:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_375]] +// CHECK: %[[VAL_394:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_378]] // CHECK: %[[VAL_395:.*]] = addrspacecast ptr addrspace(3) %[[VAL_394]] to ptr -// CHECK-GCN: %[[VAL_189_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_189]] to ptr -// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_389]], ptr %[[VAL_391]], ptr %[[VAL_393]], ptr %[[VAL_395]], ptr %[[VAL_189]]) -// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_389]], ptr %[[VAL_391]], ptr %[[VAL_393]], ptr %[[VAL_395]], ptr %[[VAL_189_2]]) -// CHECK: %[[VAL_396:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_189]], align 1 -// CHECK: %[[VAL_397:.*]] = icmp ne i8 %[[VAL_396]], 0 -// CHECK: br i1 %[[VAL_397]], label %[[VAL_398:.*]], label %[[VAL_381]] -// CHECK: is_smaller_than-after34: ; preds = %[[VAL_398]], %[[VAL_380]] -// CHECK: br label %[[VAL_238]] -// CHECK: is_smaller_than-true33: ; preds = %[[VAL_380]] -// CHECK: %[[VAL_399:.*]] = load i32, ptr %[[VAL_389]], align 4 -// CHECK: %[[VAL_400:.*]] = load i32, ptr %[[VAL_391]], align 4 -// CHECK: %[[VAL_401:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_375]] -// CHECK: store i32 %[[VAL_399]], ptr addrspace(3) %[[VAL_401]], align 4 -// CHECK: %[[VAL_402:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_376]] -// CHECK: store i32 %[[VAL_400]], ptr addrspace(3) %[[VAL_402]], align 4 -// CHECK: %[[VAL_403:.*]] = load float, ptr %[[VAL_393]], align 4 -// CHECK: %[[VAL_404:.*]] = load float, ptr %[[VAL_395]], align 4 -// CHECK: %[[VAL_405:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_375]] -// CHECK: store float %[[VAL_403]], ptr addrspace(3) %[[VAL_405]], align 4 -// CHECK: %[[VAL_406:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_376]] -// CHECK: store float %[[VAL_404]], ptr addrspace(3) %[[VAL_406]], align 4 -// CHECK: br label %[[VAL_381]] -// CHECK: smaller_comparison_index-true35: ; preds = %[[VAL_236]] -// CHECK: %[[VAL_407:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_383]] -// CHECK: %[[VAL_408:.*]] = addrspacecast ptr addrspace(3) %[[VAL_407]] to ptr -// CHECK: %[[VAL_409:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_382]] +// CHECK: %[[VAL_396:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_377]] +// CHECK: %[[VAL_397:.*]] = addrspacecast ptr addrspace(3) %[[VAL_396]] to ptr +// CHECK-GCN: %[[VAL_190_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_190]] to ptr +// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_391]], ptr %[[VAL_393]], ptr %[[VAL_395]], ptr %[[VAL_397]], ptr %[[VAL_190]]) +// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_391]], ptr %[[VAL_393]], ptr %[[VAL_395]], ptr %[[VAL_397]], ptr %[[VAL_190_2]]) +// CHECK-PTX: %[[VAL_398:.*]] = load i8, ptr %[[VAL_190]], align 1 +// CHECK-GCN: %[[VAL_398:.*]] = load i8, ptr addrspace(5) %[[VAL_190]], align 1 +// CHECK: %[[VAL_399:.*]] = icmp ne i8 %[[VAL_398]], 0 +// CHECK: br i1 %[[VAL_399]], label %[[VAL_400:.*]], label %[[VAL_383]] +// CHECK: is_smaller_than-after34: ; preds = %[[VAL_400]], %[[VAL_382]] +// CHECK: br label %[[VAL_240]] +// CHECK: is_smaller_than-true33: ; preds = %[[VAL_382]] +// CHECK: %[[VAL_401:.*]] = load i32, ptr %[[VAL_391]], align 4 +// CHECK: %[[VAL_402:.*]] = load i32, ptr %[[VAL_393]], align 4 +// CHECK: %[[VAL_403:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_377]] +// CHECK: store i32 %[[VAL_401]], ptr addrspace(3) %[[VAL_403]], align 4 +// CHECK: %[[VAL_404:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_378]] +// CHECK: store i32 %[[VAL_402]], ptr addrspace(3) %[[VAL_404]], align 4 +// CHECK: %[[VAL_405:.*]] = load float, ptr %[[VAL_395]], align 4 +// CHECK: %[[VAL_406:.*]] = load float, ptr %[[VAL_397]], align 4 +// CHECK: %[[VAL_407:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_377]] +// CHECK: store float %[[VAL_405]], ptr addrspace(3) %[[VAL_407]], align 4 +// CHECK: %[[VAL_408:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_378]] +// CHECK: store float %[[VAL_406]], ptr addrspace(3) %[[VAL_408]], align 4 +// CHECK: br label %[[VAL_383]] +// CHECK: smaller_comparison_index-true35: ; preds = %[[VAL_238]] +// CHECK: %[[VAL_409:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_385]] // CHECK: %[[VAL_410:.*]] = addrspacecast ptr addrspace(3) %[[VAL_409]] to ptr -// CHECK: %[[VAL_411:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_383]] +// CHECK: %[[VAL_411:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_384]] // CHECK: %[[VAL_412:.*]] = addrspacecast ptr addrspace(3) %[[VAL_411]] to ptr -// CHECK: %[[VAL_413:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_382]] +// CHECK: %[[VAL_413:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_385]] // CHECK: %[[VAL_414:.*]] = addrspacecast ptr addrspace(3) %[[VAL_413]] to ptr -// CHECK-GCN: %[[VAL_188_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_188]] to ptr -// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_408]], ptr %[[VAL_410]], ptr %[[VAL_412]], ptr %[[VAL_414]], ptr %[[VAL_188]]) -// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_408]], ptr %[[VAL_410]], ptr %[[VAL_412]], ptr %[[VAL_414]], ptr %[[VAL_188_2]]) -// CHECK: %[[VAL_415:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_188]], align 1 -// CHECK: %[[VAL_416:.*]] = icmp ne i8 %[[VAL_415]], 0 -// CHECK: br i1 %[[VAL_416]], label %[[VAL_417:.*]], label %[[VAL_387]] -// CHECK: is_smaller_than-after40: ; preds = %[[VAL_417]], %[[VAL_386]] -// CHECK: br label %[[VAL_237]] -// CHECK: is_smaller_than-true39: ; preds = %[[VAL_386]] -// CHECK: %[[VAL_418:.*]] = load i32, ptr %[[VAL_408]], align 4 -// CHECK: %[[VAL_419:.*]] = load i32, ptr %[[VAL_410]], align 4 -// CHECK: %[[VAL_420:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_382]] -// CHECK: store i32 %[[VAL_418]], ptr addrspace(3) %[[VAL_420]], align 4 -// CHECK: %[[VAL_421:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_383]] -// CHECK: store i32 %[[VAL_419]], ptr addrspace(3) %[[VAL_421]], align 4 -// CHECK: %[[VAL_422:.*]] = load float, ptr %[[VAL_412]], align 4 -// CHECK: %[[VAL_423:.*]] = load float, ptr %[[VAL_414]], align 4 -// CHECK: %[[VAL_424:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_382]] -// CHECK: store float %[[VAL_422]], ptr addrspace(3) %[[VAL_424]], align 4 -// CHECK: %[[VAL_425:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_383]] -// CHECK: store float %[[VAL_423]], ptr addrspace(3) %[[VAL_425]], align 4 -// CHECK: br label %[[VAL_387]] -// CHECK: smaller_keys_index-true41: ; preds = %[[VAL_244]] -// CHECK: %[[VAL_426:.*]] = shl i64 %[[VAL_210]], 1 -// CHECK: %[[VAL_427:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_426]] -// CHECK: %[[VAL_428:.*]] = load i32, ptr addrspace(3) %[[VAL_427]], align 4 -// CHECK: %[[VAL_429:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_251]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_239]] -// CHECK: store i32 %[[VAL_428]], ptr %[[VAL_429]], align 4 -// CHECK: %[[VAL_430:.*]] = add i64 %[[VAL_239]], 1 -// CHECK: %[[VAL_431:.*]] = icmp slt i64 %[[VAL_430]], 3 -// CHECK: br i1 %[[VAL_431]], label %[[VAL_432:.*]], label %[[VAL_243]] -// CHECK: inner_smaller_keys_index-after44: ; preds = %[[VAL_432]], %[[VAL_241]] -// CHECK: br label %[[VAL_242]] -// CHECK: inner_smaller_keys_index-true43: ; preds = %[[VAL_241]] -// CHECK: %[[VAL_433:.*]] = add i64 %[[VAL_426]], 1 -// CHECK: %[[VAL_434:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_433]] -// CHECK: %[[VAL_435:.*]] = load i32, ptr addrspace(3) %[[VAL_434]], align 4 -// CHECK: %[[VAL_436:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_251]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_430]] -// CHECK: store i32 %[[VAL_435]], ptr %[[VAL_436]], align 4 -// CHECK: br label %[[VAL_243]] -// CHECK: smaller_keys_index-true45: ; preds = %[[VAL_242]] -// CHECK: %[[VAL_437:.*]] = shl i64 %[[VAL_210]], 1 -// CHECK: %[[VAL_438:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_437]] -// CHECK: %[[VAL_439:.*]] = load float, ptr addrspace(3) %[[VAL_438]], align 4 -// CHECK: %[[VAL_440:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_263]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_245]] -// CHECK: store float %[[VAL_439]], ptr %[[VAL_440]], align 4 -// CHECK: %[[VAL_441:.*]] = add i64 %[[VAL_245]], 1 -// CHECK: %[[VAL_442:.*]] = icmp slt i64 %[[VAL_441]], 3 -// CHECK: br i1 %[[VAL_442]], label %[[VAL_443:.*]], label %[[VAL_248]] -// CHECK: inner_smaller_keys_index-after48: ; preds = %[[VAL_443]], %[[VAL_247]] -// CHECK: br label %[[VAL_207]] -// CHECK: inner_smaller_keys_index-true47: ; preds = %[[VAL_247]] -// CHECK: %[[VAL_444:.*]] = add i64 %[[VAL_437]], 1 -// CHECK: %[[VAL_445:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_444]] -// CHECK: %[[VAL_446:.*]] = load float, ptr addrspace(3) %[[VAL_445]], align 4 -// CHECK: %[[VAL_447:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_263]], i64 0, i64 %[[VAL_203]], i64 %[[VAL_441]] -// CHECK: store float %[[VAL_446]], ptr %[[VAL_447]], align 4 -// CHECK: br label %[[VAL_248]] +// CHECK: %[[VAL_415:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_384]] +// CHECK: %[[VAL_416:.*]] = addrspacecast ptr addrspace(3) %[[VAL_415]] to ptr +// CHECK-GCN: %[[VAL_189_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_189]] to ptr +// CHECK-PTX: call void @[[REGION2]](ptr %[[VAL_410]], ptr %[[VAL_412]], ptr %[[VAL_414]], ptr %[[VAL_416]], ptr %[[VAL_189]]) +// CHECK-GCN: call void @[[REGION2]](ptr %[[VAL_410]], ptr %[[VAL_412]], ptr %[[VAL_414]], ptr %[[VAL_416]], ptr %[[VAL_189_2]]) +// CHECK-PTX: %[[VAL_417:.*]] = load i8, ptr %[[VAL_189]], align 1 +// CHECK-GCN: %[[VAL_417:.*]] = load i8, ptr addrspace(5) %[[VAL_189]], align 1 +// CHECK: %[[VAL_418:.*]] = icmp ne i8 %[[VAL_417]], 0 +// CHECK: br i1 %[[VAL_418]], label %[[VAL_419:.*]], label %[[VAL_389]] +// CHECK: is_smaller_than-after40: ; preds = %[[VAL_419]], %[[VAL_388]] +// CHECK: br label %[[VAL_239]] +// CHECK: is_smaller_than-true39: ; preds = %[[VAL_388]] +// CHECK: %[[VAL_420:.*]] = load i32, ptr %[[VAL_410]], align 4 +// CHECK: %[[VAL_421:.*]] = load i32, ptr %[[VAL_412]], align 4 +// CHECK: %[[VAL_422:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_384]] +// CHECK: store i32 %[[VAL_420]], ptr addrspace(3) %[[VAL_422]], align 4 +// CHECK: %[[VAL_423:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_385]] +// CHECK: store i32 %[[VAL_421]], ptr addrspace(3) %[[VAL_423]], align 4 +// CHECK: %[[VAL_424:.*]] = load float, ptr %[[VAL_414]], align 4 +// CHECK: %[[VAL_425:.*]] = load float, ptr %[[VAL_416]], align 4 +// CHECK: %[[VAL_426:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_384]] +// CHECK: store float %[[VAL_424]], ptr addrspace(3) %[[VAL_426]], align 4 +// CHECK: %[[VAL_427:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_385]] +// CHECK: store float %[[VAL_425]], ptr addrspace(3) %[[VAL_427]], align 4 +// CHECK: br label %[[VAL_389]] +// CHECK: smaller_keys_index-true41: ; preds = %[[VAL_246]] +// CHECK: %[[VAL_428:.*]] = shl i64 %[[VAL_212]], 1 +// CHECK: %[[VAL_429:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_428]] +// CHECK: %[[VAL_430:.*]] = load i32, ptr addrspace(3) %[[VAL_429]], align 4 +// CHECK: %[[VAL_431:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_253]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_241]] +// CHECK: store i32 %[[VAL_430]], ptr %[[VAL_431]], align 4 +// CHECK: %[[VAL_432:.*]] = add i64 %[[VAL_241]], 1 +// CHECK: %[[VAL_433:.*]] = icmp slt i64 %[[VAL_432]], 3 +// CHECK: br i1 %[[VAL_433]], label %[[VAL_434:.*]], label %[[VAL_245]] +// CHECK: inner_smaller_keys_index-after44: ; preds = %[[VAL_434]], %[[VAL_243]] +// CHECK: br label %[[VAL_244]] +// CHECK: inner_smaller_keys_index-true43: ; preds = %[[VAL_243]] +// CHECK: %[[VAL_435:.*]] = add i64 %[[VAL_428]], 1 +// CHECK: %[[VAL_436:.*]] = getelementptr [64 x i32], ptr addrspace(3) @sort_tile_param_0, i64 0, i64 %[[VAL_435]] +// CHECK: %[[VAL_437:.*]] = load i32, ptr addrspace(3) %[[VAL_436]], align 4 +// CHECK: %[[VAL_438:.*]] = getelementptr inbounds [2 x [3 x i32]], ptr %[[VAL_253]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_432]] +// CHECK: store i32 %[[VAL_437]], ptr %[[VAL_438]], align 4 +// CHECK: br label %[[VAL_245]] +// CHECK: smaller_keys_index-true45: ; preds = %[[VAL_244]] +// CHECK: %[[VAL_439:.*]] = shl i64 %[[VAL_212]], 1 +// CHECK: %[[VAL_440:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_439]] +// CHECK: %[[VAL_441:.*]] = load float, ptr addrspace(3) %[[VAL_440]], align 4 +// CHECK: %[[VAL_442:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_265]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_247]] +// CHECK: store float %[[VAL_441]], ptr %[[VAL_442]], align 4 +// CHECK: %[[VAL_443:.*]] = add i64 %[[VAL_247]], 1 +// CHECK: %[[VAL_444:.*]] = icmp slt i64 %[[VAL_443]], 3 +// CHECK: br i1 %[[VAL_444]], label %[[VAL_445:.*]], label %[[VAL_250]] +// CHECK: inner_smaller_keys_index-after48: ; preds = %[[VAL_445]], %[[VAL_249]] +// CHECK: br label %[[VAL_209]] +// CHECK: inner_smaller_keys_index-true47: ; preds = %[[VAL_249]] +// CHECK: %[[VAL_446:.*]] = add i64 %[[VAL_439]], 1 +// CHECK: %[[VAL_447:.*]] = getelementptr [64 x float], ptr addrspace(3) @sort_tile_param_1, i64 0, i64 %[[VAL_446]] +// CHECK: %[[VAL_448:.*]] = load float, ptr addrspace(3) %[[VAL_447]], align 4 +// CHECK: %[[VAL_449:.*]] = getelementptr inbounds [2 x [3 x float]], ptr %[[VAL_265]], i64 0, i64 %[[VAL_205]], i64 %[[VAL_443]] +// CHECK: store float %[[VAL_448]], ptr %[[VAL_449]], align 4 +// CHECK: br label %[[VAL_250]] // CHECK: entry: -// CHECK: %[[VAL_448:.*]] = alloca i8, align 1 -// CHECK: %[[VAL_449:.*]] = load float, ptr %[[VAL_450:.*]], align 4 +// CHECK: %[[VAL_450:.*]] = alloca i8, align 1 // CHECK: %[[VAL_451:.*]] = load float, ptr %[[VAL_452:.*]], align 4 -// CHECK: %[[VAL_453:.*]] = fcmp olt float %[[VAL_449]], %[[VAL_451]] -// CHECK: %[[VAL_454:.*]] = zext i1 %[[VAL_453]] to i8 -// CHECK: store i8 %[[VAL_454]], ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_448]], align 1 -// CHECK: %[[VAL_455:.*]] = load i8, ptr [[ADDRSPACE_ANNOTATION]]%[[VAL_448]], align 1 -// CHECK: store i8 %[[VAL_455]], ptr %[[VAL_456:.*]], align 1 +// CHECK: %[[VAL_453:.*]] = load float, ptr %[[VAL_454:.*]], align 4 +// CHECK: %[[VAL_455:.*]] = fcmp olt float %[[VAL_451]], %[[VAL_453]] +// CHECK: %[[VAL_456:.*]] = zext i1 %[[VAL_455]] to i8 +// CHECK-PTX: store i8 %[[VAL_456]], ptr %[[VAL_450]], align 1 +// CHECK-GCN: store i8 %[[VAL_456]], ptr addrspace(5) %[[VAL_450]], align 1 +// CHECK-PTX: %[[VAL_457:.*]] = load i8, ptr %[[VAL_450]], align 1 +// CHECK-GCN: %[[VAL_457:.*]] = load i8, ptr addrspace(5) %[[VAL_450]], align 1 +// CHECK: store i8 %[[VAL_457]], ptr %[[VAL_458:.*]], align 1 // CHECK: ret void ENTRY main { diff --git a/xla/service/gpu/tests/sorting_test.cc b/xla/service/gpu/tests/sorting_test.cc index 2d6cca3c51b24..686cdbafd4f7a 100644 --- a/xla/service/gpu/tests/sorting_test.cc +++ b/xla/service/gpu/tests/sorting_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include #include "absl/log/check.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" diff --git a/xla/service/gpu/tests/swap_conv_operands_test.cc b/xla/service/gpu/tests/swap_conv_operands_test.cc index 1e2566dae69cb..2885c8af11ff3 100644 --- a/xla/service/gpu/tests/swap_conv_operands_test.cc +++ b/xla/service/gpu/tests/swap_conv_operands_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/tests/tensor_float_32_global_var_test.cc b/xla/service/gpu/tests/tensor_float_32_global_var_test.cc index cec3db1c9a1d7..a8d338b066387 100644 --- a/xla/service/gpu/tests/tensor_float_32_global_var_test.cc +++ b/xla/service/gpu/tests/tensor_float_32_global_var_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ limitations under the License. #include -#include #include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/tensor_float_32_utils.h" diff --git a/xla/service/gpu/tests/test_autotune_cache.textproto b/xla/service/gpu/tests/test_autotune_cache.textproto index 2dee80241f5da..b20a9d20ece50 100644 --- a/xla/service/gpu/tests/test_autotune_cache.textproto +++ b/xla/service/gpu/tests/test_autotune_cache.textproto @@ -1,4 +1,18 @@ -version: 2 +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 3 results { device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" @@ -13,6 +27,7 @@ results { split_k: 1 num_stages: 1 num_warps: 4 + num_ctas: 1 } } } diff --git a/xla/service/gpu/tests/transpose_021.hlo b/xla/service/gpu/tests/transpose_021.hlo new file mode 100644 index 0000000000000..370779aed2142 --- /dev/null +++ b/xla/service/gpu/tests/transpose_021.hlo @@ -0,0 +1,107 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s + +HloModule Transpose, is_scheduled=true + +%fused_computation { + %p0 = f32[2,16,17]{2,1,0} parameter(0) + ROOT %transpose = f32[2,17,16]{2,1,0} transpose(%p0), dimensions={0,2,1} +} + +ENTRY main { + %param = f32[2,16,17]{2,1,0} parameter(0) + ROOT %fusion = f32[2,17,16] fusion(%param), kind=kInput, calls=%fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr{{.*}} inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_35:.*]] +// CHECK: loop1.loop_header4: ; preds = %[[VAL_36:.*]], %[[VAL_17]] +// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 +// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] +// CHECK: loop1.loop_body5: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 +// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_37]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_43:.*]] +// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] +// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.1 +// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] +// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] +// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 +// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_45]], %thread.id.2 +// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.2, %[[VAL_37]] +// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.1, %[[VAL_45]] +// CHECK: %[[VAL_52:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_45]], i32 %[[VAL_37]] +// CHECK: %[[VAL_53:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_52]] to ptr +// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 +// CHECK: %[[VAL_55:.*]] = getelementptr{{.*}} inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] +// CHECK: store float %[[VAL_54]], ptr{{.*}} %[[VAL_55]], align 4 +// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] +// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit3: ; preds = %[[VAL_35]] +// CHECK: ret void diff --git a/xla/service/gpu/tests/transpose_021_extra_output.hlo b/xla/service/gpu/tests/transpose_021_extra_output.hlo new file mode 100644 index 0000000000000..b285b335e9161 --- /dev/null +++ b/xla/service/gpu/tests/transpose_021_extra_output.hlo @@ -0,0 +1,115 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s + +HloModule Transpose, is_scheduled=true + +%fused_computation { + %p0 = f32[2,16,17] parameter(0) + %neg = f32[2,16,17] negate(%p0) + %transpose = f32[2,17,16] transpose(%p0), dimensions={0,2,1} + ROOT %tuple = (f32[2,16,17], f32[2,17,16]) tuple(%neg, %transpose) +} + +ENTRY main { + %param = f32[2,16,17] parameter(0) + ROOT %fusion = (f32[2,16,17], f32[2,17,16]) fusion(%param), kind=kInput, calls=%fused_computation +} + + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 +// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] +// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_40:.*]] +// CHECK: loop1.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] +// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 +// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] +// CHECK: loop1.loop_body7: ; preds = %[[VAL_40]] +// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 +// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp eq i32 %[[VAL_42]], %thread.id.1 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_48:.*]] +// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] +// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.1 +// CHECK: br i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] +// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 +// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_53:.*]] = icmp eq i32 %[[VAL_50]], %thread.id.2 +// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.0, 0 +// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.2, %[[VAL_42]] +// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.1, %[[VAL_50]] +// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_50]], i32 %[[VAL_42]] +// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr +// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 +// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] +// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 +// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} +// CHECK: loop1.loop_exit5: ; preds = %[[VAL_40]] +// CHECK: ret void diff --git a/xla/service/gpu/tests/transpose_210.hlo b/xla/service/gpu/tests/transpose_210.hlo new file mode 100644 index 0000000000000..cf83fa7a8c029 --- /dev/null +++ b/xla/service/gpu/tests/transpose_210.hlo @@ -0,0 +1,106 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s + +HloModule Transpose, is_scheduled=true + +%fused_computation { + %p0 = f32[33,49,65]{2,1,0} parameter(0) + ROOT %transpose = f32[65,49,33]{2,1,0} transpose(%p0), dimensions={2,1,0} +} + +ENTRY main { + %param = f32[33,49,65]{2,1,0} parameter(0) + ROOT %fusion = f32[65,49,33] fusion(%param), kind=kInput, calls=%fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 +// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr %[[VAL_34]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_35:.*]] +// CHECK: loop0.loop_header4: ; preds = %[[VAL_36:.*]], %[[VAL_17]] +// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 +// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] +// CHECK: loop0.loop_body5: ; preds = %[[VAL_35]] +// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 +// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_37]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_43:.*]] +// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] +// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.0 +// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] +// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] +// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 +// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_45]], %thread.id.2 +// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.2, %[[VAL_37]] +// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.0, %[[VAL_45]] +// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_45]], i32 0, i32 %[[VAL_37]] +// CHECK: %[[VAL_53:.*]] = addrspacecast ptr addrspace(3) %[[VAL_52]] to ptr +// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 +// CHECK: %[[VAL_55:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] +// CHECK: store float %[[VAL_54]], ptr %[[VAL_55]], align 4 +// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] +// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit3: ; preds = %[[VAL_35]] +// CHECK: ret void diff --git a/xla/service/gpu/tests/transpose_210_extra_output.hlo b/xla/service/gpu/tests/transpose_210_extra_output.hlo new file mode 100644 index 0000000000000..9581099deed5c --- /dev/null +++ b/xla/service/gpu/tests/transpose_210_extra_output.hlo @@ -0,0 +1,113 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s + +HloModule Transpose, is_scheduled=true + +%fused_computation { + %p0 = f32[33,49,65] parameter(0) + %neg = f32[33,49,65] negate(%p0) + %transpose = f32[65,49,33] transpose(%p0), dimensions={2,1,0} + ROOT %tuple = (f32[33,49,65], f32[65,49,33]) tuple(%neg, %transpose) +} + +ENTRY main { + %param = f32[33,49,65]{2,1,0} parameter(0) + ROOT %fusion = (f32[33,49,65], f32[65,49,33]) fusion(%param), kind=kInput, calls=%fused_computation +} + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 +// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 +// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 +// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x +// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 +// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x +// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 +// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 +// CHECK: %lane_id = urem i32 %thread.id.x, 32 +// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 +// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 +// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 +// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 +// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 +// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 +// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 +// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 +// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 +// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 +// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 +// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: br label %[[VAL_12:.*]] +// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] +// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 +// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] +// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] +// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 +// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 +// CHECK: %[[VAL_20:.*]] = icmp eq i32 %[[VAL_15]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: br label %[[VAL_21:.*]] +// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] +// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 +// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] +// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] +// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 +// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 +// CHECK: %[[VAL_26:.*]] = icmp eq i32 %[[VAL_23]], %thread.id.2 +// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] +// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] +// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr +// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 +// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} +// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] +// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] +// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 +// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] +// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] +// CHECK-PTX: call void @llvm.nvvm.barrier0() +// CHECK-GCN: call void @llvm.amdgcn.s.barrier() +// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: br label %[[VAL_40:.*]] +// CHECK: loop0.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] +// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 +// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] +// CHECK: loop0.loop_body7: ; preds = %[[VAL_40]] +// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 +// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 +// CHECK: %[[VAL_47:.*]] = icmp eq i32 %[[VAL_42]], %thread.id.0 +// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: br label %[[VAL_48:.*]] +// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] +// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.0 +// CHECK: br i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] +// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] +// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 +// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 +// CHECK: %[[VAL_53:.*]] = icmp eq i32 %[[VAL_50]], %thread.id.2 +// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.2, %[[VAL_42]] +// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.1, 0 +// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.0, %[[VAL_50]] +// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_50]], i32 0, i32 %[[VAL_42]] +// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr +// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 +// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] +// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 +// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} +// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] +// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} +// CHECK: loop0.loop_exit5: ; preds = %[[VAL_40]] +// CHECK: ret void diff --git a/xla/service/gpu/tests/transpose_emitter_test.cc b/xla/service/gpu/tests/transpose_emitter_test.cc index aa3d3b257c7ea..b8a263df629a6 100644 --- a/xla/service/gpu/tests/transpose_emitter_test.cc +++ b/xla/service/gpu/tests/transpose_emitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include +#include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" namespace xla { @@ -92,14 +90,14 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[16,32]{0,1} copy(%s.1) - b = f32[16,32,1]{0,1,2} bitcast(%c.1) - ROOT o = f32[16,32,1]{0,1,2} sqrt(b) + %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + b = f32[32,16,1]{2,1,0} bitcast(%t.1) + ROOT o = f32[32,16,1]{2,1,0} sqrt(b) } ENTRY main { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = f32[16,32,1]{0,1,2} fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = f32[32,16,1]{2,1,0} fusion(%p), kind=kInput, calls=%fused_computation } )"; @@ -111,23 +109,23 @@ ENTRY main { EXPECT_TRUE(RunAndCompareNoHloPasses(hlo, ErrorSpec{1e-3})); } -TEST_F(TransposeEmitterTest, MultipleCopiesWithPostFusion) { +TEST_F(TransposeEmitterTest, MultipleTransposesWithPostFusion) { const char* hlo = R"( HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[16,32]{0,1} copy(%s.1) - %c1.1 = f32[16,32]{0,1} copy(%param_0.1) - %r.1 = f32[16,32,1]{0,1,2} reshape(%c.1) - %r1.1 = f32[16,32,1]{0,1,2} reshape(%c1.1) - ROOT %tuple = (f32[16,32,1]{0,1,2}, f32[16,32,1]{0,1,2}) tuple(%r.1, %r1.1) + %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + %t1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} + %r.1 = f32[32,16,1]{2,1,0} reshape(%t.1) + %r1.1 = f32[32,16,1]{2,1,0} reshape(%t1.1) + ROOT %tuple = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) tuple(%r.1, %r1.1) } ENTRY main { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[16,32,1]{0,1,2}, f32[16,32,1]{0,1,2}) fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) fusion(%p), kind=kInput, calls=%fused_computation } )"; diff --git a/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/xla/service/gpu/tests/tree_reduction_rewriter_test.cc index 4cfd03a38a66e..4827717e3c3cd 100644 --- a/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/xla/service/gpu/tests/tree_reduction_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,10 @@ limitations under the License. #include "xla/service/gpu/tree_reduction_rewriter.h" #include -#include #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" -#include "xla/statusor.h" -#include "xla/tests/filecheck.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { @@ -60,7 +54,7 @@ add { } ENTRY main { - input = f32[50000] parameter(0) + input = f32[50021] parameter(0) zero = f32[] constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } @@ -68,9 +62,9 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[pad_0:%[^ ]+]] = f32[50048]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_48 -// CHECK: [[bitcast_3:%[^ ]+]] = f32[128,391]{1,0} bitcast([[pad_0]]) -// CHECK: [[reduce_4:%[^ ]+]] = f32[128]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]] +// CHECK: [[pad_0:%[^ ]+]] = f32[50022]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1 +// CHECK: [[bitcast_3:%[^ ]+]] = f32[397,126]{1,0} bitcast([[pad_0]]) +// CHECK: [[reduce_4:%[^ ]+]] = f32[397]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]] // CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]] )"); } @@ -120,9 +114,9 @@ ENTRY main { CheckTreeRewriter(hlo, R"( // CHECK: [[input_0:%[^ ]+]] = f32[50048]{0} parameter(0) -// CHECK: [[bitcast_1:%[^ ]+]] = f32[128,391]{1,0} bitcast([[input_0]]) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[391,128]{1,0} bitcast([[input_0]]) // CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0) -// CHECK: [[reduce_3:%[^ ]+]] = f32[128]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] +// CHECK: [[reduce_3:%[^ ]+]] = f32[391]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] // CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]] )"); } @@ -269,7 +263,7 @@ add { } ENTRY main { - input = f32[10302,100] parameter(0) + input = f32[10303,100] parameter(0) zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add } @@ -277,11 +271,11 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[input_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0) +// CHECK: [[input_0:%[^ ]+]] = f32[10303,100]{1,0} parameter(0) // CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0) -// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_2x0_0 -// CHECK: [[bitcast_1:%[^ ]+]] = f32[64,161,100]{2,1,0} bitcast([[pad_0]]) -// CHECK: [[reduce_3:%[^ ]+]] = f32[64,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] +// CHECK: [[pad_0:%[^ ]+]] = f32[10304,100]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_1x0_0 +// CHECK: [[bitcast_1:%[^ ]+]] = f32[161,64,100]{2,1,0} bitcast([[pad_0]]) +// CHECK: [[reduce_3:%[^ ]+]] = f32[161,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]] // CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]] )"); } @@ -362,8 +356,8 @@ argmax { } ENTRY main { - input = f32[2,100000] parameter(0) - idxs = u32[2,100000] iota(), iota_dimension=0 + input = f32[2,100003] parameter(0) + idxs = u32[2,100003] iota(), iota_dimension=0 zero = f32[] constant(0) zero_idx = u32[] constant(0) @@ -376,14 +370,14 @@ ENTRY main { CheckTreeRewriter(hlo, R"( -// CHECK: [[pad_0:%[^ ]+]] = f32[2,100096]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_96 -// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,256,391]{2,1,0} bitcast([[pad_0]]) +// CHECK: [[pad_0:%[^ ]+]] = f32[2,100005]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_2 +// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,339,295]{2,1,0} bitcast([[pad_0]]) // CHECK: [[zero_idx_4:%[^ ]+]] = u32[] constant(0) -// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100096]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_96 -// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,256,391]{2,1,0} bitcast([[pad_1_5]]) -// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,256]{1,0}, u32[2,256]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]] -// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,256]{1,0} get-tuple-element([[reduce_8]]), index=0 -// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,256]{1,0} get-tuple-element([[reduce_8]]), index=1 +// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100005]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_2 +// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,339,295]{2,1,0} bitcast([[pad_1_5]]) +// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,339]{1,0}, u32[2,339]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]] +// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,339]{1,0} get-tuple-element([[reduce_8]]), index=0 +// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,339]{1,0} get-tuple-element([[reduce_8]]), index=1 // CHECK: ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]] )"); } diff --git a/xla/service/gpu/tests/triton_naming.hlo b/xla/service/gpu/tests/triton_naming.hlo index 53be10f72e3c8..2739e34918178 100644 --- a/xla/service/gpu/tests/triton_naming.hlo +++ b/xla/service/gpu/tests/triton_naming.hlo @@ -1,6 +1,7 @@ -// RUN: hlo_to_llvm_ir %s | FileCheck %{IR_SUBST} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s -// CHECK: define [[KERNEL_ANNOTATION]]void @triton_gemm_r( +// CHECK-PTX: define void @triton_gemm_r( +// CHECK-GCN: define amdgpu_kernel void @triton_gemm_r( HloModule t, is_scheduled=true, entry_computation_layout={(f16[15,19]{1,0},s8[19,17]{1,0})->f16[15,17]{1,0}} @@ -14,5 +15,5 @@ HloModule t, is_scheduled=true, entry_computation_layout={(f16[15,19]{1,0},s8[19 ENTRY %e (p0: f16[15,19], p1: s8[19,17]) -> f16[15,17] { %p1 = s8[19,17]{1,0} parameter(1) %p0 = f16[15,19]{1,0} parameter(0) - ROOT %triton_gemm_r = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, calls=%triton_gemm_r, backend_config="{kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"2\",\"num_warps\":\"8\"}}" + ROOT %triton_gemm_r = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom, calls=%triton_gemm_r, backend_config="{ \"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"64\",\"block_n\":\"32\",\"block_k\":\"64\",\"split_k\":\"1\",\"num_stages\":\"2\",\"num_warps\":\"8\",\"num_ctas\":\"1\"}}}" } diff --git a/xla/service/gpu/thunk.cc b/xla/service/gpu/thunk.cc deleted file mode 100644 index 6c544a8568513..0000000000000 --- a/xla/service/gpu/thunk.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/thunk.h" - -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" - -namespace xla { -namespace gpu { - -Thunk::ExecuteParams::ExecuteParams( - const ServiceExecutableRunOptions& run_options, - const BufferAllocations& buffer_allocations, se::Stream* stream, - absl::Span async_streams) - : buffer_allocations(&buffer_allocations), - stream(stream), - async_comms_streams(async_streams.begin(), async_streams.end()), - nccl_params(run_options, stream->parent()) {} - -/*static*/ absl::string_view Thunk::KindToString(Thunk::Kind kind) { -#define CASE(x) \ - case Thunk::x: \ - return #x - switch (kind) { - CASE(kCholesky); - CASE(kCommandBuffer); - CASE(kConditional); - CASE(kConvolution); - CASE(kConvolutionReorder); - CASE(kCopy); - CASE(kCubSort); - CASE(kCublasLtMatmul); - CASE(kCustomCall); - CASE(kNcclAllGather); - CASE(kNcclAllGatherStart); - CASE(kNcclAllGatherDone); - CASE(kNcclAllReduce); - CASE(kNcclAllReduceStart); - CASE(kNcclAllReduceDone); - CASE(kNcclCollectivePermute); - CASE(kNcclCollectivePermuteStart); - CASE(kNcclCollectivePermuteDone); - CASE(kNcclReduceScatter); - CASE(kNcclReduceScatterStart); - CASE(kNcclReduceScatterDone); - CASE(kNcclAllToAll); - CASE(kNcclAllToAllStart); - CASE(kNcclAllToAllDone); - CASE(kNcclSend); - CASE(kNcclRecv); - CASE(kFft); - CASE(kFor); - CASE(kGemm); - CASE(kInfeed); - CASE(kKernel); - CASE(kMemset32BitValue); - CASE(kMemzero); - CASE(kNorm); - CASE(kOutfeed); - CASE(kReplicaId); - CASE(kPartitionId); - CASE(kSequential); - CASE(kTriangularSolve); - CASE(kWhile); - CASE(kFusedMHA); - } -} - -std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { - return os << Thunk::KindToString(kind); -} - -std::string ThunkSequence::ToString( - int indent, - std::function get_thunk_annotation) const { - const std::string indent_str(indent * 2, ' '); - if (empty()) return indent_str + "No thunks."; - - auto thunk_with_longest_kind = absl::c_max_element( - *this, - [](const std::unique_ptr& a, const std::unique_ptr& b) { - return Thunk::KindToString(a->kind()).length() < - Thunk::KindToString(b->kind()).length(); - }); - int64_t max_thunk_kind_len = - Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length(); - std::string result; - for (const std::unique_ptr& thunk : *this) { - // Write out the thunk kind, padded out to max_thunk_kind_len. - absl::string_view kind_str = Thunk::KindToString(thunk->kind()); - absl::StrAppend(&result, indent_str, kind_str, - std::string(max_thunk_kind_len - kind_str.length(), ' '), - "\t"); - if (get_thunk_annotation) { - absl::StrAppend(&result, get_thunk_annotation(thunk.get())); - } - absl::StrAppend(&result, thunk->ToStringExtra(indent)); - absl::StrAppend(&result, "\n"); - } - return result; -} - -bool IsReductionCollective(Thunk::Kind kind) { - return kind == Thunk::kNcclAllReduce || kind == Thunk::kNcclAllReduceStart || - kind == Thunk::kNcclReduceScatter || - kind == Thunk::kNcclReduceScatterStart; -} - -Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation(mlir::Operation* op) { - ThunkInfo thunk_info(op); - thunk_info.profile_annotation = absl::StrFormat( - "Thunk:#hlo_op=%s#", mlir::mhlo::GetDebugNameFromLocation(op->getLoc())); - return thunk_info; -} - -Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation( - const HloInstruction* instr) { - ThunkInfo thunk_info(nullptr); - thunk_info.profile_annotation = - absl::StrFormat("Thunk:#hlo_op=%s#", instr->name()); - return thunk_info; -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h deleted file mode 100644 index a560899dcebf7..0000000000000 --- a/xla/service/gpu/thunk.h +++ /dev/null @@ -1,212 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_THUNK_H_ -#define XLA_SERVICE_GPU_THUNK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/service_executable_run_options.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/status.h" - -namespace xla { -namespace gpu { - -class GpuExecutable; - -enum AsyncStreamKind { - kAsyncStreamCollective = 0, // Stream for asynchronous collective ops. - kAsyncStreamP2P = 1, // Stream for P2P Send and Recv ops. -}; -constexpr static int64_t kAsyncStreamTotal = kAsyncStreamP2P + 1; -// Assigns a unique ID to a stream for asynchronous or synchronous execution. -// These IDs can be used, for example, to look up the NCCL communicator. -inline uint64_t GetStreamId( - bool is_async, AsyncStreamKind stream_kind = kAsyncStreamCollective) { - return is_async ? stream_kind + 1 : 0; -} - -// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the -// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. -// -// Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable -// to initialize and execute the invocation respectively. Its subclasses are -// supposed to override these interfaces to launch a generated kernel or call an -// external library function (such as operations in cuBLAS). -// -// This is thread-compatible. -class Thunk { - public: - enum Kind { - kCholesky, - kConditional, - kConvolution, - kConvolutionReorder, - kCopy, - kCommandBuffer, - kCubSort, - kCublasLtMatmul, - kCustomCall, - kFft, - kFor, - kGemm, - kInfeed, - kKernel, - kMemset32BitValue, - kMemzero, - kNcclAllGather, - kNcclAllGatherStart, - kNcclAllGatherDone, - kNcclAllReduce, - kNcclAllReduceStart, - kNcclAllReduceDone, - kNcclCollectivePermute, - kNcclCollectivePermuteStart, - kNcclCollectivePermuteDone, - kNcclReduceScatter, - kNcclReduceScatterStart, - kNcclReduceScatterDone, - kNcclAllToAll, - kNcclAllToAllStart, - kNcclAllToAllDone, - kNcclSend, - kNcclRecv, - kNorm, - kOutfeed, - kReplicaId, - kPartitionId, - kSequential, - kTriangularSolve, - kWhile, - kFusedMHA - }; - - // TODO(ezhulenev): This should become a part of StreamExecutor library, but - // for now we keep it here as a Thunk implementation detail. It's not yet - // clear what else should become a part of "executable source", we likely - // need to keep some information about available symbols and signatures. - struct ExecutableSource { - std::string_view text; // PTX for NVIDIA backend - absl::Span binary; // CUBIN for NVIDIA backends - }; - - struct ThunkInfo { - explicit ThunkInfo(mlir::Operation* op) : op(op) {} - static ThunkInfo WithProfileAnnotation(mlir::Operation* op); - static ThunkInfo WithProfileAnnotation(const HloInstruction* instr); - - std::string profile_annotation; - // TODO(b/304613751): This is only needed by the LMHLO. Remove this when - // LMHLO is removed from the runtime pipeline. - mlir::Operation* op; - }; - - // The hlo_instruction argument is meant to be the instruction this thunk was - // generated from, but Thunk never uses this argument other than to save it - // to Thunk::hlo_instruction, so it can be null. - Thunk(Kind kind, ThunkInfo thunk_info) - : kind_(kind), - profile_annotation_(thunk_info.profile_annotation), - op_(thunk_info.op) {} - virtual ~Thunk() = default; - Thunk(const Thunk&) = delete; - Thunk& operator=(const Thunk&) = delete; - - virtual std::string ToStringExtra(int indent) const { return ""; } - Kind kind() const { return kind_; } - std::string profile_annotation() const { return profile_annotation_; } - // Only valid during compilation, i.e., lowering thunks to kernel-launch - // related XLA runtime custom calls). nullptr at runtime. MLIR codegen will - // cease the practice of lowering thunks to XLA runtime custom calls. - mlir::Operation* op() { return op_; } - - // Prepares the thunk for execution on the given StreamExecutor. - // - // This may be called multiple times. Its main purpose is to give us a chance - // to do initialization outside of ExecuteOnStream() so that the - // time spent initializing doesn't count towards our execution profile. - virtual Status Initialize(se::StreamExecutor*, ExecutableSource) { - return OkStatus(); - } - - // Parameters passed to ExecuteOnStream. Encapsulated in a struct so that - // when we add something we don't have to change every subclass of Thunk. - struct ExecuteParams { - ExecuteParams(const ServiceExecutableRunOptions& run_options, - const BufferAllocations& buffer_allocations, - se::Stream* stream, - absl::Span async_streams); - - const BufferAllocations* buffer_allocations; // never null - se::Stream* stream; - absl::InlinedVector async_comms_streams; - NcclExecuteParams nccl_params; - }; - - // Execute the kernel for the thunk on the given stream. This method must be - // called after Initialize and can be called multiple times over Thunk's - // lifetime. - // - // Precondition: Initialize(stream->parent()) has been called. - virtual Status ExecuteOnStream(const ExecuteParams& params) = 0; - - // Clears metadata that is only valid during compile time. - virtual void ClearCompileTimeInfo() { op_ = nullptr; } - - static absl::string_view KindToString(Thunk::Kind kind); - - private: - Kind kind_; - std::string profile_annotation_; - mlir::Operation* op_; -}; - -// A sequence of thunks. -class ThunkSequence : public std::vector> { - public: - std::string ToString(int indent = 0, - std::function - get_thunk_annotation = nullptr) const; -}; - -std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); - -// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its -// shape. -struct ShapedSlice { - BufferAllocation::Slice slice; - Shape shape; -}; - -// Returns if the thunk implements a reduction collective (all-reduce or -// reduce-scatter). -bool IsReductionCollective(Thunk::Kind kind); -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_THUNK_H_ diff --git a/xla/service/gpu/topk_specializer.cc b/xla/service/gpu/topk_specializer.cc index e69597bb42e85..7cf655eff0059 100644 --- a/xla/service/gpu/topk_specializer.cc +++ b/xla/service/gpu/topk_specializer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -36,10 +37,9 @@ limitations under the License. #include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -47,7 +47,7 @@ namespace gpu { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace { -StatusOr SmallBufferOptimization( +absl::StatusOr SmallBufferOptimization( HloCustomCallInstruction* topk) { Shape data_shape = topk->operand(0)->shape(); auto supported_dtypes = {F32, BF16}; @@ -85,10 +85,10 @@ StatusOr SmallBufferOptimization( class SpecializeTopkVisitor : public DfsHloRewriteVisitor { public: - Status HandleCustomCall(HloInstruction* inst) override { + absl::Status HandleCustomCall(HloInstruction* inst) override { HloCustomCallInstruction* topk = DynCast(inst); if (topk == nullptr || topk->custom_call_target() != "TopK") { - return OkStatus(); + return absl::OkStatus(); } TF_RET_CHECK(topk->operand_count() == 1); @@ -99,13 +99,13 @@ class SpecializeTopkVisitor : public DfsHloRewriteVisitor { << small_topk.status(); } - return OkStatus(); + return absl::OkStatus(); } }; } // namespace -StatusOr TopkSpecializer::Run( +absl::StatusOr TopkSpecializer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return SpecializeTopkVisitor().RunOnModule(module, execution_threads); @@ -113,7 +113,7 @@ StatusOr TopkSpecializer::Run( #else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -StatusOr TopkSpecializer::Run( +absl::StatusOr TopkSpecializer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return false; diff --git a/xla/service/gpu/topk_specializer.h b/xla/service/gpu/topk_specializer.h index c2c02cfe1864f..5b57f57b77bba 100644 --- a/xla/service/gpu/topk_specializer.h +++ b/xla/service/gpu/topk_specializer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,10 @@ limitations under the License. #define XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" -#include "tsl/platform/statusor.h" namespace xla::gpu { @@ -33,7 +31,7 @@ class TopkSpecializer : public HloModulePass { absl::string_view name() const override { return "topk-specializer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/topk_splitter.cc b/xla/service/gpu/topk_splitter.cc index 2aea2d48be38b..33e271de2c32b 100644 --- a/xla/service/gpu/topk_splitter.cc +++ b/xla/service/gpu/topk_splitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/topk_splitter.h" #include +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/numeric/bits.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -32,13 +34,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal_util.h" #include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -52,29 +54,29 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { explicit TopkSplitterVisitor(size_t split_threshold) : split_threshold_(split_threshold) {} - Status HandleCustomCall(HloInstruction* inst) override { + absl::Status HandleCustomCall(HloInstruction* inst) override { HloCustomCallInstruction* topk = DynCast(inst); if (topk == nullptr || topk->custom_call_target() != "TopK") { - return OkStatus(); + return absl::OkStatus(); } HloComputation* comp = inst->parent(); Shape data_shape = topk->operand(0)->shape(); bool has_batch = data_shape.dimensions_size() == 2; // TODO(doak): Support multiple batches. if (has_batch && data_shape.dimensions(0) != 1) { - return OkStatus(); + return absl::OkStatus(); } size_t n = data_shape.dimensions(has_batch ? 1 : 0); int64_t k = topk->shape().tuple_shapes(0).dimensions(has_batch ? 1 : 0); // If K approaches N, splitting the input will not be beneficial anymore. if (k > sqrt(n)) { - return OkStatus(); + return absl::OkStatus(); } // TODO(doak): Relax this alignment requirement. if (n % kRequiredAlignment != 0) { - return OkStatus(); + return absl::OkStatus(); } - if (n < split_threshold_) return OkStatus(); + if (n < split_threshold_) return absl::OkStatus(); int new_batch = std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize); int new_n = n / new_batch; @@ -142,7 +144,7 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { } // namespace -StatusOr TopKSplitter::Run( +absl::StatusOr TopKSplitter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return TopkSplitterVisitor(split_threshold_) diff --git a/xla/service/gpu/topk_splitter.h b/xla/service/gpu/topk_splitter.h index ac177d10151ec..8fee2dc4975db 100644 --- a/xla/service/gpu/topk_splitter.h +++ b/xla/service/gpu/topk_splitter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,16 +17,13 @@ limitations under the License. #define XLA_SERVICE_GPU_TOPK_SPLITTER_H_ #include -#include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -41,7 +38,7 @@ class TopKSplitter : public HloModulePass { absl::string_view name() const override { return "topk-splitter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/topk_splitter_test.cc b/xla/service/gpu/topk_splitter_test.cc index 7fb76f73ca089..834185f990956 100644 --- a/xla/service/gpu/topk_splitter_test.cc +++ b/xla/service/gpu/topk_splitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" -#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -34,7 +33,6 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/runtime/topk_test.cc b/xla/service/gpu/topk_test.cc similarity index 89% rename from xla/service/gpu/runtime/topk_test.cc rename to xla/service/gpu/topk_test.cc index 6a21ec1087668..52018e55315f8 100644 --- a/xla/service/gpu/runtime/topk_test.cc +++ b/xla/service/gpu/topk_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/topk.h" - #include #include @@ -23,12 +21,12 @@ limitations under the License. #include #include #include -#include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" -#include "xla/error_spec.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -41,10 +39,7 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" -#include "xla/types.h" #include "tsl/platform/statusor.h" namespace xla { @@ -68,8 +63,9 @@ class TopkTest : public HloTestBase, public ParameterizedInterface { *PlatformUtil::GetPlatform("gpu"), true, true, {}) {} protected: - StatusOr> TopkHlo(int n, int k, int batch_size, - std::string_view dtype) { + absl::StatusOr> TopkHlo(int n, int k, + int batch_size, + std::string_view dtype) { return ParseAndReturnVerifiedModule(absl::Substitute( R"( %compare { @@ -94,10 +90,10 @@ class TopkTest : public HloTestBase, public ParameterizedInterface { class GeneralizeTopkVisitor : public DfsHloRewriteVisitor { public: - Status HandleCustomCall(HloInstruction* inst) override { + absl::Status HandleCustomCall(HloInstruction* inst) override { HloCustomCallInstruction* topk = DynCast(inst); if (topk == nullptr || topk->custom_call_target() != "__gpu$TopK") { - return OkStatus(); + return absl::OkStatus(); } HloComputation* comp = topk->parent(); auto original_shape = ShapeUtil::SliceTuple(topk->shape(), 0, 2); @@ -122,9 +118,9 @@ class GeneralizeTopk : public HloModulePass { absl::string_view name() const override { return "generalized-topk"; } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { return GeneralizeTopkVisitor().RunOnModule(module, execution_threads); } }; diff --git a/xla/service/gpu/tree_reduction_rewriter.cc b/xla/service/gpu/tree_reduction_rewriter.cc index bd37ce540540f..3ed9e365102cd 100644 --- a/xla/service/gpu/tree_reduction_rewriter.cc +++ b/xla/service/gpu/tree_reduction_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,27 +14,37 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/tree_reduction_rewriter.h" -#include #include +#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -44,16 +54,16 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version) : gpu_version_(gpu_version) {} - Status HandleReduce(HloInstruction *hlo) override { + absl::Status HandleReduce(HloInstruction *hlo) override { if (IsMinMaxReduction(hlo)) { // TODO(cheshire): Also enable for integers. VLOG(1) << "Not performing tree expansion on min/max-reduction: " << hlo->ToString() << " since min/max operations are associative"; - return OkStatus(); + return absl::OkStatus(); } if (!IsReductionFromOrToContiguousDimensions(*hlo)) { - return OkStatus(); + return absl::OkStatus(); } return RewriteReduction(hlo); } @@ -69,7 +79,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { return false; } - Status RewriteReduction(HloInstruction *hlo) { + absl::Status RewriteReduction(HloInstruction *hlo) { ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hlo); VLOG(5) << "Input: " << hlo->ToString(); @@ -101,7 +111,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { // Base case: everything fits. if (ReductionIsRaceFree(hlo->GetModule()->config(), reduction_dimensions)) { VLOG(3) << "Base case: dimensions fit"; - return OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Input: " << hlo->ToString(); @@ -115,32 +125,36 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { // We do this by splitting the input shape [a, n, b] into [a, k, n / k, b]. // // We want to choose k to be roughly equal to sqrt(n) so that we process - // "most of" the reduction in the first step. We also want k to be a power - // of 2, so that the GPU kernel doesn't spend all its time doing slow - // integer divmods to compute indices into the shape [a,k,n/k,b]. This - // means we may need to pad n so that n is divisible by k. - // - // Thus we consider two options for k: - // - // k1 = round_up_pow2(sqrt(n)) - // k2 = round_down_pow2(sqrt(n)) - // - // and we choose the value of k that results in the least amount of padding. - int64_t k1 = absl::bit_ceil(static_cast(std::ceil(std::sqrt(n)))); - int64_t k2 = - absl::bit_floor(static_cast(std::floor(std::sqrt(n)))); - int64_t padded_n_k1 = RoundUpTo(n, k1); - int64_t padded_n_k2 = RoundUpTo(n, k2); - - int64_t k; - int64_t padded_n; - if (padded_n_k1 < padded_n_k2) { - k = k1; - padded_n = padded_n_k1; - } else { - k = k2; - padded_n = padded_n_k2; + // "most of" the reduction in the first step. But it is also important that + // we choose a value of k with the least amount of padding we need to add to + // n to make it divisible by k. We search for the best value of n / k + // between sqrt(n)/2 and sqrt(n). If there are several possible values for + // n / k that result in the minimum amount of padding, we also want n / k to + // be a power of 2, so that the GPU kernel doesn't spend all its time doing + // slow integer divmods to compute indices into the shape [a,k,n/k,b]. + // Note that by searching in the range between sqrt(n)/2 and sqrt(n), we + // will have a power of 2 in that range. + uint64_t n_div_k = static_cast(std::floor(std::sqrt(n))); + int64_t race_free_bound = ReductionDimensionRaceFreeBound( + hlo->GetModule()->config(), reduction_dimensions); + if (n_div_k > race_free_bound) { + // This means we need more than one split. It is best to limit the n/k + // dimension to the maximum size that doesn't require further splitting. + // Otherwise we might choose a rather small reduce dimension size for the + // first step (in the worst case, sqrt(race_free_bound + 1)). + n_div_k = race_free_bound; + } + uint64_t minimum_padding = (n_div_k - n % n_div_k) % n_div_k; + uint64_t best_k = (n + minimum_padding) / n_div_k; + for (uint64_t i = n_div_k - 1; i > n_div_k / 2; --i) { + uint64_t padding = (i - n % i) % i; + if (padding < minimum_padding || + (padding == minimum_padding && absl::has_single_bit(i))) { + minimum_padding = padding; + best_k = (n + padding) / i; + } } + uint64_t padded_n = n + minimum_padding; // Pad reduced dimension to the required number of elements. bool no_padding_necessary = n == padded_n; @@ -179,8 +193,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { for (int64_t dim_idx = 0; dim_idx < padded[0]->shape().dimensions_size(); dim_idx++) { if (dim_idx == reduced_input_dimension) { - reshaped_dimensions.push_back(k); - reshaped_dimensions.push_back(padded_n / k); + reshaped_dimensions.push_back(best_k); + reshaped_dimensions.push_back(padded_n / best_k); } else { reshaped_dimensions.push_back(padded[0]->shape().dimensions(dim_idx)); } @@ -248,7 +262,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { } // Rewrites batch dimension reduction into a separate reduce operation. - Status RewriteBatchDimensionLargerThanTile( + absl::Status RewriteBatchDimensionLargerThanTile( HloReduceInstruction *hlo, const ReductionDimensions &reduction_dimensions, int64_t reduced_input_dimension) { @@ -278,7 +292,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { se::GpuComputeCapability gpu_version_; }; -StatusOr GpuTreeReductionRewriter::Run( +absl::StatusOr GpuTreeReductionRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); diff --git a/xla/service/gpu/tree_reduction_rewriter.h b/xla/service/gpu/tree_reduction_rewriter.h index 5c3e8d9bf7311..5f6edf8ac33e4 100644 --- a/xla/service/gpu/tree_reduction_rewriter.h +++ b/xla/service/gpu/tree_reduction_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ #define XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ -#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -85,7 +85,7 @@ class GpuTreeReductionRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/gpu/triangular_solve_rewriter.cc b/xla/service/gpu/triangular_solve_rewriter.cc index 0b46e941e52bf..2dcd36569b707 100644 --- a/xla/service/gpu/triangular_solve_rewriter.cc +++ b/xla/service/gpu/triangular_solve_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,29 @@ limitations under the License. #include "xla/service/gpu/triangular_solve_rewriter.h" +#include #include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { -StatusOr TriangularSolveRewriter::Run( +absl::StatusOr TriangularSolveRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/triangular_solve_rewriter.h b/xla/service/gpu/triangular_solve_rewriter.h index 7e1e2e0c69538..6d4b1c14188a0 100644 --- a/xla/service/gpu/triangular_solve_rewriter.h +++ b/xla/service/gpu/triangular_solve_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ #define XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -48,7 +49,7 @@ class TriangularSolveRewriter : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/triton_autotuner.cc b/xla/service/gpu/triton_autotuner.cc deleted file mode 100644 index 8c18f253933c3..0000000000000 --- a/xla/service/gpu/triton_autotuner.cc +++ /dev/null @@ -1,957 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/triton_autotuner.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "third_party/gpus/cuda/include/cublas_v2.h" -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_clone_context.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/service/dump.h" -#include "xla/service/executable.h" -#include "xla/service/float_normalization.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/buffer_comparator.h" -#include "xla/service/gpu/gemm_rewriter.h" -#include "xla/service/gpu/gpu_float_support.h" -#include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/instruction_fusion.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/split_k_gemm_rewriter.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/shaped_buffer.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#include "xla/stream_executor/stream.h" -#include "xla/util.h" -#include "xla/xla.pb.h" -#include "tsl/lib/core/bits.h" -#include "tsl/platform/blocking_counter.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" -#include "tsl/util/proto/proto_utils.h" - -// Log levels used in this file: -// VLOG(1): Overview -// VLOG(2): Autotuning progress -// VLOG(3): Autotuning progress - more frequent -// VLOG(4): Print all fusions -// VLOG(5): Profiling information for every tiling - -namespace xla { -namespace gpu { - -using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput; - -namespace { - -// Currently supported minimum tile size. -constexpr int kMinTileSize = 16; -// Not a hard limit, just an assumption that should stay valid. -constexpr int kMaxTileSize = 512; - -// Default tiling when autotuning is disabled. -constexpr TritonGemmConfig kDefaultGemmTiling = {32, 32, 32, 1, 1, 4}; - -class TritonAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit TritonAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto backend_config, - hlo->backend_config()); - if (backend_config.kind() != kTritonGemmFusionKind) { - return OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString())); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(backend_config)); - } else { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } - } - - // This cannot be the "else" branch of the previous "if". - if (backend_config.has_triton_gemm_config()) { - const TritonGemmConfig config = - TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return OkStatus(); - } - - private: - AutotuneConfig config_; -}; - -// This contains all alternative Triton GEMM configs related to one fusion. -struct GemmConfigSet { - std::vector configs; -}; - -struct ExecutableCandidate { - TritonGemmConfig config; - // Not nullptr. - std::unique_ptr executable; -}; - -// This contains all alternative executables related to one fusion. -struct ExecutableSet { - std::vector candidates; - // Not nullptr. - std::unique_ptr reference; -}; - -class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { - public: - explicit GemmConfigSetCollector(const AutotuneConfig& config) - : config_(config) {} - - StatusOr> - CollectGemmConfigSets( - const HloModule* module, - const absl::flat_hash_set& execution_threads = {}) { - gemm_config_sets_.clear(); - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } - return std::move(gemm_config_sets_); - } - - Status HandleFusion(const HloInstruction* hlo) override { - const HloFusionInstruction* fusion = Cast(hlo); - - TF_ASSIGN_OR_RETURN(auto backend_config, - hlo->backend_config()); - if (backend_config.kind() != kTritonGemmFusionKind || - backend_config.has_triton_gemm_config()) { - return OkStatus(); - } - - AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, config_); - if (AutotunerUtil::IsInCache(key) || handled_fusions_.contains(key)) { - return OkStatus(); - } - - CHECK(gemm_config_sets_.insert({fusion, GetGemmConfigSet(fusion)}).second); - - handled_fusions_.insert(key); - return OkStatus(); - } - - Status DefaultAction(const HloInstruction* hlo) override { - return OkStatus(); - } - - private: - GemmConfigSet GetGemmConfigSet(const HloFusionInstruction* fusion) { - const DebugOptions& debug_options = - fusion->GetModule()->config().debug_options(); - return {GetPossibleMatmulAutotuneConfigs( - *Cast(hlo_query::GetFirstInstructionWithOpcode( - *fusion->called_computations().at(0), HloOpcode::kDot)), - config_.GetCudaComputeCapability(), debug_options, - config_.ExhaustiveTilingSearch())}; - } - - AutotuneConfig config_; - absl::flat_hash_map - gemm_config_sets_; - absl::flat_hash_set handled_fusions_; -}; - -struct TileSizeLimit { - int64_t block_m = 0; - int64_t block_n = 0; - int64_t block_k = 0; -}; - -TileSizeLimit GetUpperLimit(const HloDotInstruction& dot) { - // This is not a sharp upper limit, the actual m value can be much smaller - // based on how much of the m dimension is physically contiguous. - // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis. - const int64_t m = dot.operand(0)->shape().dimensions( - NonContractingDimensionIndex(dot, /*operand_number=*/0)); - // Theoretically the same is true as for m, but that is not possible in - // practice with the current implementation. - const int64_t n = dot.operand(1)->shape().dimensions( - NonContractingDimensionIndex(dot, /*operand_number=*/1)); - // This is before doing the split-k transform. - const int64_t k = dot.operand(0)->shape().dimensions( - ContractingDimensionIndex(dot, /*operand_number=*/0)); - const int64_t block_m_limit = - std::max(tsl::NextPowerOfTwoS64(m), kMinTileSize); - const int64_t block_n_limit = - std::max(tsl::NextPowerOfTwoS64(n), kMinTileSize); - const int64_t block_k_limit = - std::max(tsl::NextPowerOfTwoS64(k), kMinTileSize); - return {block_m_limit, block_n_limit, block_k_limit}; -} - -int64_t GetSplitKLimit(int64_t block_k, int64_t block_k_limit) { - return std::max(block_k_limit / block_k, 1); -} - -// Search space for exhaustive matmul autotuning. -constexpr std::array BLOCK_SIZES = {16, 32, 64, 128, 256, 512}; -constexpr std::array NUM_STAGES = {1, 2, 3, 4}; -constexpr std::array NUM_WARPS = {2, 4, 8, 16}; -constexpr std::array SPLIT_K = {1, 2, 4, 8, 16}; - -std::vector GetExhaustiveMatmulAutotuneConfigs( - const HloDotInstruction& dot, - const se::CudaComputeCapability compute_capability, const int max_split_k) { - const TileSizeLimit limit = GetUpperLimit(dot); - std::vector configs; - bool mma_layout_v2 = - compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE); - for (int num_warps : NUM_WARPS) { - for (int num_stages : NUM_STAGES) { - // Volta doesn't support num_stages > 2. - if (!mma_layout_v2 && num_stages > 2) { - continue; - } - for (int block_m : BLOCK_SIZES) { - if (block_m > limit.block_m) { - continue; - } - for (int block_n : BLOCK_SIZES) { - // Exclude configs not supported by MMA layout v2. - if (block_n > limit.block_n || - (mma_layout_v2 && (block_m * block_n / 256) % num_warps != 0)) { - continue; - } - for (int block_k : BLOCK_SIZES) { - if (block_k > limit.block_k) { - continue; - } - for (int split_k : SPLIT_K) { - if (split_k > - std::min(max_split_k, - GetSplitKLimit(block_k, limit.block_k))) { - continue; - } - auto config = TritonGemmConfig(block_m, block_n, block_k, split_k, - num_stages, num_warps); - configs.push_back(std::move(config)); - } - } - } - } - } - } - return configs; -} - -std::vector GetFixedMatmulAutotuneConfigs( - const se::CudaComputeCapability compute_capability, const int max_split_k) { - // Shorter name for better formatting. - using Config = TritonGemmConfig; - std::vector configs = { - Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), - Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), - Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), - Config(64, 32, 64, 1, 2, 8)}; - if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - absl::c_copy( - std::vector{ - Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8), - Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4), - Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4), - Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4), - Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4), - Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4), - Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4), - Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4), - Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8), - Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4), - Config(256, 256, 128, 1, 3, 8)}, - std::back_inserter(configs)); - } - if (compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER)) { - configs.erase( - std::remove_if(configs.begin(), configs.end(), - [](const Config& config) { - return (config.block_m * config.block_n / 256) % - config.num_warps != - 0; - }), - configs.end()); - } - configs.erase(std::remove_if(configs.begin(), configs.end(), - [&](const Config& config) { - return config.split_k > max_split_k; - }), - configs.end()); - return configs; -} - -// This prefers to take the parameter by moving it. -std::vector ReduceTileSizes( - const HloDotInstruction& dot, std::vector configs) { - const TileSizeLimit limit = GetUpperLimit(dot); - // Decrease the block sizes and split_k if they are unnecessarily big. - for (TritonGemmConfig& config : configs) { - config.block_m = std::min(config.block_m, limit.block_m); - config.block_n = std::min(config.block_n, limit.block_n); - config.block_k = std::min(config.block_k, limit.block_k); - config.split_k = std::min( - config.split_k, GetSplitKLimit(config.block_k, limit.block_k)); - } - - // Remove duplicates. - absl::flat_hash_set configs_so_far; - configs.erase(std::remove_if(configs.begin(), configs.end(), - [&](const TritonGemmConfig& config) { - return !configs_so_far.insert(config).second; - }), - configs.end()); - CHECK(!configs.empty()); - return configs; -} - -int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; } - -StatusOr> TritonGemmAutotuneExtractor( - const TritonGemmConfig& config, - const se::DeviceDescription& gpu_device_info, - const HloFusionInstruction* fusion, DebugOptions debug_opts, - bool allow_filtering_kernels_spilling_registers) { - std::unique_ptr new_module = - AutotunerUtil::ExtractInstructionIntoNewModule(*fusion); - // Reduce memory usage during compilation by disabling GPU runtime. - debug_opts.set_xla_gpu_enable_xla_runtime_executable(false); - // TODO(anlunx): Disable command buffers for now because it breaks triton - // autotuner test. Enable this when the function of command buffers is stable. - debug_opts.clear_xla_gpu_enable_command_buffer(); - if (!allow_filtering_kernels_spilling_registers) { - debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( - false); - } - new_module->mutable_config().set_debug_options(debug_opts); - - HloComputation* entry_computation = new_module->entry_computation(); - HloInstruction* cloned_dot_fusion = entry_computation->root_instruction(); - - TF_ASSIGN_OR_RETURN(auto backend_config, - cloned_dot_fusion->backend_config()); - *backend_config.mutable_triton_gemm_config() = config.ToProto(); - TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(backend_config)); - - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config)); - GpuFloatSupport bf16_support(BF16); - FloatNormalization float_normalization(&bf16_support); - TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); - GpuInstructionFusion instruction_fusion(/*may_duplicate=*/false, - gpu_device_info); - TF_RETURN_IF_ERROR(instruction_fusion.Run(new_module.get()).status()); - HloInstruction* root = entry_computation->root_instruction(); - // If the instruction fusion pass above skipped the reduction, turn it - // into a fusion for a universal set of arguments for execution. - if (root->opcode() == HloOpcode::kReduce) { - HloInstruction* fusion_instruction = - entry_computation->AddInstruction(HloInstruction::CreateFusion( - root->shape(), ChooseFusionKind(*root->operand(0), *root), root)); - HloInstruction* init_value = root->mutable_operand(1); - TF_CHECK_OK( - entry_computation->ReplaceInstruction(root, fusion_instruction)); - fusion_instruction->FuseInstruction(init_value); - TF_CHECK_OK(entry_computation->RemoveInstruction(init_value)); - } - } - return new_module; -} - -StatusOr> CublasGemmAutotuneExtractor( - const AutotuneConfig& config, const HloFusionInstruction* fusion, - const DebugOptions& debug_opts) { - const HloComputation* fusion_computation = - fusion->called_computations().at(0); - std::unique_ptr new_module = - AutotunerUtil::ExtractComputationIntoNewModule(*fusion_computation); - new_module->mutable_config().set_debug_options(debug_opts); - - GemmRewriter rewriter(config.GetCudaComputeCapability()); - GpuInstructionFusion fusion_pass( - /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); - TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); - TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); - // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS - // performance. It is probably not needed on Ampere and later because cuBLAS - // ignores the algorithm parameter for those targets. If we run - // GemmAlgorithmPicker, we probably should not run this in parallel with other - // compilations. - return new_module; -} - -bool ShouldAllowFilteringKernelsSpillingRegisters( - const GemmConfigSet& gemm_config_set) { - return gemm_config_set.configs.size() > 1; -} - -StatusOr> -CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, - tsl::thread::ThreadPool* thread_pool, - const DebugOptions& debug_opts, - const absl::flat_hash_map& gemm_config_sets) { - absl::Mutex executable_sets_mu; - absl::flat_hash_map - executable_sets; - - if (gemm_config_sets.empty()) { - return executable_sets; - } - - const se::DeviceDescription& gpu_device_info = - config.GetExecutor()->GetDeviceDescription(); - - const int log_every_n = GetLogEveryN(); - int64_t config_count = 0; - for (const auto& key_value : gemm_config_sets) { - const GemmConfigSet& gemm_config_set = key_value.second; - config_count += gemm_config_set.configs.size(); - } - // The cuBLAS configs: - config_count += gemm_config_sets.size(); - - std::atomic done_count = 0; - std::atomic good_count = 0; - auto log = [&](bool success) { - const int done_so_far = done_count.fetch_add(1) + 1; - const int good_so_far = - success ? good_count.fetch_add(1) + 1 : good_count.load(); - if (done_so_far % log_every_n == 0) { - VLOG(2) << "Compiled " << done_so_far << " of " << config_count - << " configs (successful: " << good_so_far << ")"; - } - }; - - // Returns true on success. - auto compile = - [&](const HloFusionInstruction* fusion, const TritonGemmConfig& conf, - bool allow_filtering_kernels_spilling_registers) -> StatusOr { - CHECK_LE(conf.block_m, kMaxTileSize); - CHECK_LE(conf.block_n, kMaxTileSize); - CHECK_LE(conf.block_k, kMaxTileSize); - // TODO(b/296884861): Reenable GPU runtime, when it will have much smaller - // memory overhead (regarding the size of the executables). - // We can also remove the force_disable_gpu_runtime argument at that - // point. - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - util.Compile([&](const DebugOptions& opts) { - return TritonGemmAutotuneExtractor( - conf, gpu_device_info, fusion, opts, - allow_filtering_kernels_spilling_registers); - })); - - if (executable != nullptr) { - absl::MutexLock lock(&executable_sets_mu); - ExecutableSet& executable_set = executable_sets[fusion]; - executable_set.candidates.push_back( - ExecutableCandidate{conf, std::move(executable)}); - return true; - } - - return false; - }; - - // Returns true on success. - auto compile_reference_executable = - [&](const HloFusionInstruction* fusion) -> StatusOr { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - util.Compile([&](const DebugOptions& opts) { - return CublasGemmAutotuneExtractor(config, fusion, - opts); - })); - - if (executable != nullptr) { - absl::MutexLock lock(&executable_sets_mu); - ExecutableSet& executable_set = executable_sets[fusion]; - TF_RET_CHECK(executable_set.reference == nullptr); - executable_set.reference = std::move(executable); - return true; - } - - return false; - }; - - // If the thread pool has only one thread, then it is actually slower to - // offload the tasks there. - if (thread_pool && thread_pool->NumThreads() > 1 && - debug_opts.xla_gpu_force_compilation_parallelism() != 1) { - if (gemm_config_sets.size() == 1) { - absl::string_view fusion_name = gemm_config_sets.begin()->first->name(); - VLOG(1) << "Compiling " << config_count << " configs for " << fusion_name - << " on " << thread_pool->NumThreads() << " threads."; - } else { - VLOG(1) << "Compiling " << config_count << " configs for " - << gemm_config_sets.size() << " fusions on " - << thread_pool->NumThreads() << " threads."; - } - - tsl::BlockingCounter counter(config_count); - for (const auto& key_value : gemm_config_sets) { - const HloFusionInstruction* fusion = key_value.first; - const GemmConfigSet& gemm_config_set = key_value.second; - - for (const TritonGemmConfig& conf : gemm_config_set.configs) { - thread_pool->Schedule([&, fusion] { - StatusOr has_executable = compile( - fusion, conf, - ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set)); - TF_CHECK_OK(has_executable.status()) - << "Failure occured when compiling fusion " << fusion->name() - << " with config '" << conf.ToString() - << "'\nFused HLO computation:\n" - << fusion->fused_instructions_computation()->ToString(); - log(has_executable.value()); - counter.DecrementCount(); - }); - } - - thread_pool->Schedule([&, fusion] { - StatusOr has_executable = compile_reference_executable(fusion); - TF_CHECK_OK(has_executable.status()); - log(has_executable.value()); - counter.DecrementCount(); - }); - } - counter.Wait(); - } else { - if (gemm_config_sets.size() == 1) { - absl::string_view fusion_name = gemm_config_sets.begin()->first->name(); - LOG(WARNING) << "Compiling " << config_count << " configs for " - << fusion_name << " on a single thread."; - - } else { - LOG(WARNING) << "Compiling " << config_count << " configs for " - << gemm_config_sets.size() << " fusions on a single thread."; - } - - for (const auto& key_value : gemm_config_sets) { - const HloFusionInstruction* fusion = key_value.first; - const GemmConfigSet& gemm_config_set = key_value.second; - - for (const TritonGemmConfig& gemm_config : gemm_config_set.configs) { - TF_ASSIGN_OR_RETURN( - bool has_executable, - compile( - fusion, gemm_config, - ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set))); - log(has_executable); - } - - TF_ASSIGN_OR_RETURN(bool has_executable, - compile_reference_executable(fusion)); - log(has_executable); - } - } - - VLOG(1) << "Done compiling (successful: " << good_count.load() << ")."; - - return executable_sets; -} - -// Runs matmul fusion contents without Triton - with cuBLAS, to measure time and -// generate a reference output. -StatusOr RunMatmulWithCublas( - AutotunerCompileUtil& util, se::Stream* stream, Executable& executable, - absl::Span input_buffers, - absl::Span input_shapes) { - TF_ASSIGN_OR_RETURN( - std::optional output, - util.ProfileExecutable(&executable, stream, input_buffers, input_shapes)); - TF_RET_CHECK(output.has_value()); - return std::move(output.value()); -} - -StatusOr Execute(const AutotuneConfig& config, - AutotunerCompileUtil& util, - const DebugOptions& debug_opts, - const HloFusionInstruction* fusion, - const ExecutableSet& executable_set) { - const HloComputation* fusion_computation = - fusion->called_computations().at(0); - - se::StreamExecutor* stream_exec = config.GetExecutor(); - if (!stream_exec->SynchronizeAllActivity()) { - return InternalError("Failed to synchronize GPU for autotuning."); - } - se::DeviceMemoryAllocator* allocator = config.GetAllocator(); - if (allocator == nullptr) { - allocator = stream_exec->GetAllocator(); - } - TF_ASSIGN_OR_RETURN(se::Stream* const stream, - allocator->GetStream(stream_exec->device_ordinal())); - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator rz_allocator, - AutotunerUtil::CreateRedzoneAllocator(config, debug_opts)); - - const HloInstruction& root = *fusion_computation->root_instruction(); - BufferComparator comparator(root.shape(), - fusion_computation->parent()->config()); - - std::vector inputs; - inputs.reserve(fusion_computation->parameter_instructions().size()); - std::vector input_shapes; - input_shapes.reserve(fusion_computation->parameter_instructions().size()); - int64_t rng_state = 0; - for (const HloInstruction* param : - fusion_computation->parameter_instructions()) { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase param_buffer, - AutotunerUtil::CreateBuffer( - rz_allocator, param->shape(), config, rng_state)); - inputs.push_back(param_buffer); - input_shapes.push_back(param->shape()); - } - - // Run with cuBLAS. - std::optional reference_buffer; - absl::Duration cublas_duration; - { - TF_RET_CHECK(executable_set.reference != nullptr); - TF_ASSIGN_OR_RETURN( - ProfilingOutput output, - RunMatmulWithCublas(util, stream, *executable_set.reference, inputs, - input_shapes)); - if (config.should_check_correctness()) { - reference_buffer = std::move(output.output); - } - cublas_duration = output.duration; - } - - const int log_every_n = GetLogEveryN(); - int64_t executable_count = - static_cast(executable_set.candidates.size()); - int ran_so_far = 0; - std::vector results; - VLOG(2) << "Running " << executable_count << " configs for " << fusion->name() - << "."; - for (const ExecutableCandidate& candidate : executable_set.candidates) { - VLOG(5) << "Trying triton tiling: " << candidate.config.ToString(); - - AutotuneResult res; - *res.mutable_triton() = candidate.config.ToProto(); - - TF_ASSIGN_OR_RETURN(std::optional profiling_output, - util.ProfileExecutable(candidate.executable.get(), - stream, inputs, input_shapes)); - ran_so_far += 1; - if (ran_so_far % log_every_n == 0) { - VLOG(2) << "Ran " << ran_so_far << " configs of " << executable_count - << "."; - } - - if (!profiling_output) { - VLOG(5) << "Skipping this tiling."; - continue; - } - - VLOG(5) << "Running the kernel took: " << profiling_output->duration; - if (profiling_output->duration >= absl::Seconds(1)) { - LOG(WARNING) << "Slow kernel for " << fusion->name() - << " took: " << profiling_output->duration - << ". config: " << candidate.config.ToString(); - } - *res.mutable_run_time() = - tsl::proto_utils::ToDurationProto(profiling_output->duration); - - if (config.should_check_correctness()) { - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - rz_allocator.CheckRedzones()); - if (!rz_check_status.ok()) { - LOG(ERROR) << "Red zone modified"; - res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); - res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg()); - CHECK(!config.should_crash_on_check_failure()); - continue; - } - - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual( - stream, /*current=*/profiling_output->output.root_buffer(), - /*expected=*/reference_buffer->root_buffer())); - if (!outputs_match) { - const char kMessage[] = - "Results do not match the reference. This is likely a " - "bug/unexpected loss of precision."; - LOG(ERROR) << kMessage; - CHECK(!config.should_crash_on_check_failure()); - // WRONG_RESULT is not taken seriously by PickBestResult(), so - // use DISQUALIFIED. - res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); - res.mutable_failure()->set_msg(kMessage); - } - } - results.push_back(res); - } - VLOG(2) << "Done running."; - - TF_ASSIGN_OR_RETURN( - AutotuneResult best_triton, - PickBestResult(results, root.ToString(), root.GetModule()->config())); - - if (debug_opts.xla_gpu_cublas_fallback() && - !debug_opts.xla_gpu_deterministic_ops()) { - const absl::Duration best_triton_duration = - tsl::proto_utils::FromDurationProto(best_triton.run_time()); - VLOG(2) << fusion->name() << ": time with cuBLAS: " << cublas_duration - << ", best time with Triton: " << best_triton_duration; - if (cublas_duration < best_triton_duration) { - VLOG(2) << "Falling back to cuBLAS for " << fusion->name(); - - AutotuneResult cublas; - *cublas.mutable_run_time() = - tsl::proto_utils::ToDurationProto(cublas_duration); - // We will ignore this value anyway. - cublas.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); - - return cublas; - } - } - - return best_triton; -} - -Status DumpAutotunedFusion(const AutotuneConfig& config, - AutotunerCompileUtil& util, - const AutotuneResult result, - const HloFusionInstruction* fusion, int fusion_id) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - util.ExtractModule([&](const DebugOptions& debug_opts) { - return TritonGemmAutotuneExtractor( - TritonGemmConfig::FromProto(result.triton()), - config.GetExecutor()->GetDeviceDescription(), fusion, debug_opts, - /*allow_filtering_kernels_spilling_registers=*/true); - })); - module->set_name(std::string(fusion->name())); - // Using the original module for its debug info and name in the first - // parameter. It's better to include the name of both the original module - // and the extracted module, to avoid name clashes. - DumpToFileInDirOrStdout( - /*module=*/*fusion->GetModule(), - /*file_prefix=*/"", - /*file_suffix=*/ - absl::StrCat("triton_fusion_", fusion_id, ".", module->name(), - ".optimized.txt"), - /*contents=*/module->ToString()); - return OkStatus(); -} - -Status Autotune(const AutotuneConfig& config, AutotunerCompileUtil& util, - tsl::thread::ThreadPool* thread_pool, - const DebugOptions& debug_opts, - const absl::flat_hash_map& gemm_config_sets) { - absl::flat_hash_map - executable_sets; - TF_ASSIGN_OR_RETURN( - executable_sets, - CompileMany(config, util, thread_pool, debug_opts, gemm_config_sets)); - - // Sort the candidates to make their execution order well-defined for each - // fusion. - for (auto& key_value : executable_sets) { - ExecutableSet& executable_set = key_value.second; - std::vector& candidates = executable_set.candidates; - absl::c_sort(candidates, [](const ExecutableCandidate& a, - const ExecutableCandidate& b) { - return a.config < b.config; - }); - } - - int fusion_id = 0; - for (const auto& key_value : executable_sets) { - const HloFusionInstruction* fusion = key_value.first; - const ExecutableSet& executable_set = key_value.second; - - TF_ASSIGN_OR_RETURN(AutotuneResult result, Execute(config, util, debug_opts, - fusion, executable_set)); - - if (debug_opts.xla_gpu_dump_autotuned_triton_fusions()) { - TF_RETURN_IF_ERROR( - DumpAutotunedFusion(config, util, result, fusion, fusion_id++)); - } - - const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config); - if (!AutotunerUtil::AddResult(key, std::move(result))) { - // In the context of model server, concurrent autotuning is expected and - // insertion of identical autotuning keys is accepted. - LOG(WARNING) << "AutotunerUtil::AddResult already existed: " - << key.ToString(); - } - } - - return OkStatus(); -} - -} // anonymous namespace - -std::vector GetPossibleMatmulAutotuneConfigs( - const HloDotInstruction& dot, - const se::CudaComputeCapability compute_capability, - const DebugOptions& debug_options, bool exhaustive_tiling_search) { - // Avoid autotuning tiny fusions. - constexpr int kMinGemmElements = 32 * 32; - if (ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements && - ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements) { - return ReduceTileSizes(dot, {kDefaultGemmTiling}); - } - // Split-K optimization enables more even utilization of a GPU in cases - // where tiling just the non-contracting dimensions of a GEMM does not create - // a sufficient number of thread block programs to occupy all available cores. - // Given the typical ~100 cores per GPU 500 tiles make around 5 full - // waves that completely avoid the need for split-K. The formula below is - // n_tiles = split_k * (M * N) / (block_m * block_n) - // with pessimistically assumed maximum block_m and block_n. - // Most likely there is no need for split-K already at much smaller output - // tensor sizes. - constexpr int kSufficientNumberOfTiles = 500; - const int max_split_k = - debug_options.xla_gpu_enable_split_k_autotuning() - ? std::max(1L, kSufficientNumberOfTiles * kMaxTileSize * - kMaxTileSize / - ShapeUtil::ElementsIn(dot.shape())) - : 1; - return exhaustive_tiling_search - ? GetExhaustiveMatmulAutotuneConfigs(dot, compute_capability, - max_split_k) - : ReduceTileSizes(dot, GetFixedMatmulAutotuneConfigs( - compute_capability, max_split_k)); -} - -StatusOr TritonAutotuner::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_SCOPED_LOGGING_TIMER("Triton autotuner"); - const DebugOptions& debug_options = module->config().debug_options(); - TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, - AutotunerCompileUtil::Create(config_, debug_options)); - - GemmConfigSetCollector gemm_config_set_collector(config_); - absl::flat_hash_map - gemm_config_sets; - TF_ASSIGN_OR_RETURN(gemm_config_sets, - gemm_config_set_collector.CollectGemmConfigSets( - module, execution_threads)); - - if (debug_options.xla_gpu_autotune_level() == 0 || - debug_options.xla_gpu_deterministic_ops()) { - // Pick the first option for each gemm instead of autotuning.. - for (const auto& [fusion, tilings] : gemm_config_sets) { - const AutotuneCacheKey key = AutotunerUtil::GetKey(fusion, config_); - AutotuneResult res; - *res.mutable_triton() = kDefaultGemmTiling.ToProto(); - *res.mutable_run_time() = - tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); - AutotunerUtil::AddResult(key, res); - } - } else if (!config_.IsDeviceless()) { - TF_RET_CHECK(opt_compile_util.has_value()); - if (!gemm_config_sets.empty()) { - std::string correctness_check_str = config_.should_check_correctness() - ? "(with correctness check)" - : "(without correctness check)"; - - VLOG(1) << "Autotuning " << gemm_config_sets.size() << " fusions " - << correctness_check_str << "."; - TF_RETURN_IF_ERROR(Autotune(config_, *opt_compile_util, thread_pool_, - debug_options, gemm_config_sets)); - VLOG(1) << "Done autotuning."; - } - } - - return TritonAutotunerVisitor(config_).RunOnModule(module, execution_threads); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/triton_autotuner.h b/xla/service/gpu/triton_autotuner.h deleted file mode 100644 index 9e1e07f9d787e..0000000000000 --- a/xla/service/gpu/triton_autotuner.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRITON_AUTOTUNER_H_ -#define XLA_SERVICE_GPU_TRITON_AUTOTUNER_H_ - -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/autotuning.pb.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_description.h" -#include "xla/xla.pb.h" -#include "tsl/platform/threadpool.h" - -namespace xla { -namespace gpu { - -// Find best tiling configuration for each triton fusion outlined. -class TritonAutotuner : public HloModulePass { - public: - explicit TritonAutotuner(const AutotuneConfig& config, - tsl::thread::ThreadPool* thread_pool) - : config_(config), thread_pool_(thread_pool) {} - - absl::string_view name() const override { return "triton-autotuner"; } - - using HloPassInterface::Run; - StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - AutotuneConfig config_; - tsl::thread::ThreadPool* thread_pool_; -}; - -// TODO(b/266210099): have a way to generate/load these dynamically. -// Returns a list of possible tilings for a GEMM performed in Triton. -std::vector GetPossibleMatmulAutotuneConfigs( - const HloDotInstruction& dot, se::CudaComputeCapability compute_capability, - const DebugOptions& debug_options, bool exhaustive_tiling_search = false); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_TRITON_AUTOTUNER_H_ diff --git a/xla/service/gpu/triton_autotuner_test.cc b/xla/service/gpu/triton_autotuner_test.cc deleted file mode 100644 index 0edf9c86a987d..0000000000000 --- a/xla/service/gpu/triton_autotuner_test.cc +++ /dev/null @@ -1,710 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/triton_autotuner.h" - -#include -#include -#include -#include -#include - -#include -#include -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "xla/autotuning.pb.h" -#include "xla/error_spec.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_pass_pipeline.h" -#include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/cpu_info.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" - -namespace xla { -namespace gpu { -namespace { - -namespace m = ::xla::match; - -using HloExtractionTest = HloTestBase; - -TEST_F(HloExtractionTest, InstructionExtractionIsCorrect) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -HloModule module - -triton_gemm_dot { - p0 = s8[10,10] parameter(0) - p1 = f32[10,10] parameter(1) - c0 = f32[10,10] convert(p0) - ROOT dot.0 = f32[10,10] dot(c0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY entry { - p0 = s8[10,10] parameter(0) - p1 = f32[10,10] parameter(1) - s = f32[10,10] sqrt(p1) - d = f32[10,10] fusion(p0, p1), - kind=kCustom, calls=triton_gemm_dot - ROOT r = f32[10,10] add(d, s) -})") - .value(); - - std::unique_ptr extracted_module = - AutotunerUtil::ExtractInstructionIntoNewModule( - *module->entry_computation()->root_instruction()->operand(0)); - - // Destroy the original module to be sure that the extracted one has no - // dependency on it. - module.release(); - - EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); - EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 3); - TF_EXPECT_OK(VerifyHloModule(extracted_module.get(), - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false)); -} - -TEST_F(HloExtractionTest, ComputationExtractionIsCorrect) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -HloModule module - -triton_gemm_dot { - p0 = s8[10,10] parameter(0) - p1 = f32[10,10] parameter(1) - c0 = f32[10,10] convert(p0) - ROOT dot.0 = f32[10,10] dot(c0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY entry { - p0 = s8[10,10] parameter(0) - p1 = f32[10,10] parameter(1) - s = f32[10,10] sqrt(p1) - d = f32[10,10] fusion(p0, p1), - kind=kCustom, calls=triton_gemm_dot - ROOT r = f32[10,10] add(d, s) -})") - .value(); - - std::unique_ptr extracted_module = - AutotunerUtil::ExtractComputationIntoNewModule( - *module->entry_computation() - ->root_instruction() - ->operand(0) - ->fused_instructions_computation()); - - // Destroy the original module to be sure that the extracted one has no - // dependency on it. - module.release(); - - EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), - GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter()))); - EXPECT_EQ(extracted_module->entry_computation()->instruction_count(), 4); - TF_EXPECT_OK(VerifyHloModule(extracted_module.get(), - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false)); -} - -class StatelessAutotunerTest : public HloTestBase { - public: - StatelessAutotunerTest() - : HloTestBase(/*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/false) {} - - void SetUp() override { - AutotunerUtil::ClearAutotuneResults(); - HloTestBase::SetUp(); - } - - void TearDown() override { - AutotunerUtil::ClearAutotuneResults(); - HloTestBase::TearDown(); - } -}; - -class TritonAutotunerTest : public StatelessAutotunerTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = - StatelessAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_triton_gemm(true); - debug_options.set_xla_gpu_cublas_fallback(false); - return debug_options; - } - - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } - - void CheckTritonAutotuning(absl::string_view hlo, - absl::string_view expected) { - HloPassPipeline pipeline("gemm_rewrite"); - pipeline.AddPass(backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()); - tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", - tsl::port::MaxParallelism()); - DebugOptions opts; - pipeline.AddPass( - AutotuneConfig{DeviceConfig{backend().default_stream_executor(), - backend().memory_allocator()}, - opts}, - &thread_pool); - - RunAndFilecheckHloRewrite( - hlo, std::move(pipeline), expected, [](const HloModule* m) { - VLOG(5) << m->ToString(); - const HloInstruction* dot_fusion = - m->entry_computation()->root_instruction(); - if (dot_fusion->opcode() == HloOpcode::kReduce) { - dot_fusion = dot_fusion->operand(0); - } - CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion); - CHECK_GT(dot_fusion->backend_config() - .value() - .triton_gemm_config() - .block_m(), - 0); - }); - } -}; - -class TritonAutotunerTestWithMorePreciseReduction : public TritonAutotunerTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( - true); - return debug_options; - } -}; - -TEST_F(TritonAutotunerTest, VoltaUsesNoMoreThanTwoStages) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::VOLTA, /*minor=*/0}; - const std::vector configs = - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest()); - EXPECT_FALSE(std::any_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.num_stages > 2; })); -} - -TEST_F(TritonAutotunerTest, AmpereUsesMoreThanTwoStages) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - const std::vector configs = - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest()); - EXPECT_TRUE(std::any_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.num_stages > 2; })); -} - -TEST_F(TritonAutotunerTest, SmallOutputCanUseLargeSplitK) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - const std::vector configs = - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest()); - EXPECT_TRUE(std::any_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.split_k >= 16; })); -} - -TEST_F(TritonAutotunerTest, LargeOutputDoesNotUseLargeSplitK) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[20480,20480] parameter(0) - p1 = f32[20480,20480] parameter(1) - ROOT r = f32[20480,20480] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - const std::vector configs = - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest()); - EXPECT_FALSE(std::any_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.split_k > 1; })); -} - -TEST_F(TritonAutotunerTest, Int8FusedGemm) { - const std::string hlo = R"( -HloModule module - -ENTRY e { - x = s8[128,64] parameter(0) - c = f16[128,64] convert(x) - - y = f16[64,6144] parameter(1) - - ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - CheckTritonAutotuning(hlo, R"( -// CHECK: ENTRY -// CHECK: ROOT -// CHECK-SAME: kCustom -// CHECK-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/5e-3, /*arel=*/5e-3})); -} - -TEST_F(TritonAutotunerTest, Int8FusedGemm256) { - const std::string hlo = R"( -HloModule module - -ENTRY e { - x = s8[128,256] parameter(0) - c = f16[128,256] convert(x) - - y = f16[256,6144] parameter(1) - - ROOT out = f16[128,6144] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - CheckTritonAutotuning(hlo, R"( -// CHECK: ENTRY -// CHECK: ROOT -// CHECK-SAME: kCustom -// CHECK-SAME: block_m -)"); - - EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonAutotunerTest, SelectsSplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } - // Shapes with K >> M, N have to force split-K configurations. - const std::string kHloText = R"( -HloModule t - -ENTRY e { - p0 = s8[7,8192] parameter(0) - p0c = bf16[7,8192] convert(p0) - p1 = bf16[8192,18] parameter(1) - ROOT dot.0 = bf16[7,18] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: reduce -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: kCustom -; CHECK-NEXT: kLoop -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/4, /*arel=*/1e-1})); -} - -TEST_F(TritonAutotunerTestWithMorePreciseReduction, SelectsSplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } - // Shapes with K >> M, N have to force split-K configurations. - constexpr absl::string_view kHloText = R"( -HloModule t - -ENTRY e { - p0 = s8[7,8192] parameter(0) - p0c = bf16[7,8192] convert(p0) - p1 = bf16[8192,18] parameter(1) - ROOT dot.0 = bf16[7,18] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: reduce -; CHECK: ENTRY -; CHECK-NEXT: parameter -; CHECK-NEXT: parameter -; CHECK-NEXT: kCustom -; CHECK-NEXT: kLoop -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonAutotunerTest, ApplySplitKWithoutAlteringTiling) { - const std::string kHloText = R"( -triton_dot { - p0 = f16[55,120] parameter(0) - p1 = f16[120,20] parameter(1) - ROOT dot = f16[55,20] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f16[55,120]{1,0} parameter(0) - p1 = f16[120,20]{1,0} parameter(1) - ROOT _ = f16[55,20] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2}} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: f16[3,55,20] -; CHECK: {"block_m":16,"block_n":64,"block_k":32,"split_k":3,"num_stages":1,"num_warps":2} -; CHECK: reduce -)"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) { - const std::string kHloText = R"( -HloModule m - -%triton_gemm_dot { - %p1 = s8[4,12288]{1,0} parameter(1) - %p0 = s8[12288,1536]{1,0} parameter(0) - %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) - %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1) - %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot) -} - -ENTRY %e { - %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) - %convert = s8[4,12288]{1,0} parameter(1) - ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16"}} -})"; - - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Not enough shared memory to run big tiles before Ampere."; - } - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - EXPECT_THAT( - backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor(), - {/*device_allocator=*/nullptr, - /*thread_pool=*/nullptr, - /*layout_canonicalization_callback=*/{}, - /*is_autotuning_compilation=*/true}), - tsl::testing::StatusIs( - tsl::error::CANCELLED, - absl::StrFormat( - "Compilation result discarded due to register spilling"))); -} - -TEST_F(TritonAutotunerTest, DoNotFilterOutAutotuningKernelSpillingRegisters) { - const std::string kHloText = R"( -HloModule m - -%triton_gemm_dot { - %p1 = s8[4,12288]{1,0} parameter(1) - %p0 = s8[12288,1536]{1,0} parameter(0) - %convert.p0 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) - %convert.p1 = f16[4,12288]{1,0} convert(s8[4,12288]{1,0} %p1) - %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %convert.p1, f16[12288,1536]{1,0} %convert.p0), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT %convert = s8[4,1536]{1,0} convert(f16[4,1536]{1,0} %dot) -} - -ENTRY %e { - %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) - %convert = s8[4,12288]{1,0} parameter(1) - ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16"}} -})"; - - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Not enough shared memory to run big tiles before Ampere."; - } - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - HloModuleConfig config = module->config(); - DebugOptions debug_options = config.debug_options(); - debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( - false); - config.set_debug_options(debug_options); - module->set_config(config); - - std::unique_ptr executable = - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - {/*device_allocator=*/nullptr, - /*thread_pool=*/nullptr, - /*layout_canonicalization_callback=*/{}, - /*is_autotuning_compilation=*/true}) - .value(); - EXPECT_NE(executable, nullptr); -} - -TEST_F(TritonAutotunerTest, RunAutotuningKernelNotSpillingRegisters) { - const std::string kHloText = R"( -HloModule m - -%triton_gemm_dot { - %p1 = f16[4,12288]{1,0} parameter(1) - %p0 = s8[12288,1536]{1,0} parameter(0) - %convert.10406 = f16[12288,1536]{1,0} convert(s8[12288,1536]{1,0} %p0) - ROOT %dot = f16[4,1536]{1,0} dot(f16[4,12288]{1,0} %p1, f16[12288,1536]{1,0} %convert.10406), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY %e { - %p0 = s8[12288,1536]{1,0} parameter(0) - %p1 = f16[4,12288]{1,0} parameter(1) - ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2"}} -})"; - - auto module = ParseAndReturnVerifiedModule(kHloText).value(); - std::unique_ptr executable = - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - {/*device_allocator=*/nullptr, - /*thread_pool=*/nullptr, - /*layout_canonicalization_callback=*/{}, - /*is_autotuning_compilation=*/true}) - .value(); - EXPECT_NE(executable, nullptr); -} - -// TODO(b/281489442): Write a testcase called -// `SkipConfigsProducingDeviantResults` or similar. - -class TritonAutotunerLevelTest : public StatelessAutotunerTest, - public ::testing::WithParamInterface { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = - StatelessAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_autotune_level(GetParam()); - debug_options.set_xla_gpu_cublas_fallback(false); - return debug_options; - } -}; - -TEST_P(TritonAutotunerLevelTest, AllAutotuningLevelsWorkCorrectly) { - const std::string kHloText = R"( -HloModule m - -ENTRY e { - p0 = pred[64,10] parameter(0) - p0c = f32[64,10] convert(p0) - p1 = f32[10,128] parameter(1) - ROOT r = f32[64,128] dot(p0c, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - - MatchOptimizedHlo(kHloText, R"( -; CHECK: kind=kCustom -; CHECK-SAME: block_m - )"); - - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_P(TritonAutotunerLevelTest, Deviceless) { - const std::string hlo = R"( -HloModule module - -ENTRY e { - x = s8[16,16] parameter(0) - c = f16[16,16] convert(x) - y = f16[16,16] parameter(1) - ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - HloPassPipeline pipeline("gemm_rewrite_deviceless"); - pipeline.AddPass(backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()); - tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "", - tsl::port::MaxParallelism()); - DebugOptions opts; - pipeline.AddPass( - AutotuneConfig{DevicelessConfig{backend() - .default_stream_executor() - ->GetDeviceDescription() - .model_str(), - backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()}, - opts}, - &thread_pool); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - if (GetDebugOptionsForTest().xla_gpu_autotune_level() == 0) { - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloTestBase::RunHloPass(&pipeline, module.get())); - EXPECT_TRUE(changed); - - // Check default configuration. - TF_ASSERT_OK_AND_ASSIGN( - bool filecheck_matches, - RunFileCheck( - module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), - R"( -// CHECK: backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4"}} - )")); - EXPECT_TRUE(filecheck_matches); - } else { - EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()), - tsl::testing::StatusIs( - tsl::error::INTERNAL, - ::testing::HasSubstr( - "Expect autotune result cache hit for deviceless"))); - } -} - -INSTANTIATE_TEST_SUITE_P(TritonAutotunerLevelSweep, TritonAutotunerLevelTest, - ::testing::Range(0, 5)); - -class TritonAutotunerExhaustiveTest : public TritonAutotunerTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_exhaustive_tiling_search(true); - return debug_options; - } -}; - -TEST_F(TritonAutotunerExhaustiveTest, DISABLED_CompileOnly) { - const std::string hlo = R"( -HloModule module - -ENTRY e { - x = s8[16,16] parameter(0) - c = f16[16,16] convert(x) - y = f16[16,16] parameter(1) - ROOT out = f16[16,16] dot(c, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - CheckTritonAutotuning(hlo, R"( -// CHECK: %triton_gemm_out_computation ( -// CHECK: ROOT %out.1 = f16[16,16]{1,0} dot(%c.1, %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -// CHECK: ROOT %triton_gemm_out = f16[16,16]{1,0} fusion(%x, %y), kind=kCustom, calls=%triton_gemm_out_computation -// CHECK-SAME: "block_m": -)"); -} - - -class TritonAutotunerDisableSplitK : public TritonAutotunerTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonAutotunerTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_split_k_autotuning(false); - return debug_options; - } -}; - -TEST_F(TritonAutotunerDisableSplitK, SplitKIsDisabled) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - const std::vector configs = - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest()); - EXPECT_TRUE(std::all_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.split_k == 1; })); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/triton_call.cc b/xla/service/gpu/triton_call.cc new file mode 100644 index 0000000000000..dfc88e578b0eb --- /dev/null +++ b/xla/service/gpu/triton_call.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_call.h" + +#include +#include +#include + +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace xla::gpu { + +TritonCall TritonCall::Parse(std::string_view backend_config, + mlir::MLIRContext* mlir_context) { + // TODO(slebedev): Plumb through num_ctas and enable_wrap_specialization. + auto attrs = mlir::cast( + mlir::parseAttribute(backend_config, mlir_context)); + auto name = attrs.getAs("name").getValue().str(); + auto ir = attrs.getAs("ir").str(); + auto grid_x = static_cast( + attrs.getAs("grid_x").getValue().getSExtValue()); + auto grid_y = static_cast( + attrs.getAs("grid_y").getValue().getSExtValue()); + auto grid_z = static_cast( + attrs.getAs("grid_z").getValue().getSExtValue()); + auto num_stages = + attrs.getAs("num_stages").getValue().getSExtValue(); + auto num_warps = + attrs.getAs("num_warps").getValue().getSExtValue(); + return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, + grid_x, grid_y, grid_z}; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/triton_call.h b/xla/service/gpu/triton_call.h new file mode 100644 index 0000000000000..169e4e703e7dc --- /dev/null +++ b/xla/service/gpu/triton_call.h @@ -0,0 +1,43 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRITON_CALL_H_ +#define XLA_SERVICE_GPU_TRITON_CALL_H_ + +#include +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace xla::gpu { + +struct TritonCall { + std::string name; + std::string ir; + int64_t num_stages; + int64_t num_warps; + int32_t grid_x; + int32_t grid_y; + int32_t grid_z; + + // Parse the metadata of a __gpu$xla.gpu.triton call. + static TritonCall Parse(std::string_view backend_config, + mlir::MLIRContext* mlir_context); +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRITON_CALL_H_ diff --git a/xla/service/gpu/triton_fusion_analysis.cc b/xla/service/gpu/triton_fusion_analysis.cc index 1d9188667a9bb..7af7b2a484188 100644 --- a/xla/service/gpu/triton_fusion_analysis.cc +++ b/xla/service/gpu/triton_fusion_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/triton_fusion_analysis.h" #include +#include #include #include #include @@ -24,10 +25,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/matmul_utils.h" @@ -36,8 +39,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" +#include "xla/tools/hlo_decomposer.h" +#include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -46,6 +51,7 @@ namespace { using triton_fusion::DimOrdersAndReqs; using triton_fusion::DimOrdersAndReqsOrError; +using triton_fusion::DotRequirements; using triton_fusion::FusionContext; using triton_fusion::GetPropagatedDimOrdersAndRequirements; using triton_fusion::kNoSplitRequirement; @@ -55,14 +61,17 @@ using triton_fusion::TransformDirection; namespace triton_fusion { -/*static*/ FusionContext FusionContext::FromDotOperand( +/*static*/ absl::StatusOr FusionContext::FromDotOperand( const HloInstruction& dot, const int operand_number, const int split_k) { // There can be either none or one split-K batch dimension. const int num_split_k_batch_dims = split_k > 1; int split_k_dimension_index = kNoDimensionIndex; + TF_ASSIGN_OR_RETURN(int contracting_dimension_index, + ContractingDimensionIndex(dot, operand_number)); + TF_ASSIGN_OR_RETURN(int non_contracting_dimension_index, + NonContractingDimensionIndex(dot, operand_number)); if (split_k > 1) { - split_k_dimension_index = - ContractingDimensionIndex(dot, operand_number) - 1; + split_k_dimension_index = contracting_dimension_index - 1; } int splittable_dimension_index = kNoDimensionIndex; // LHS non-contracting dimension can be split if non-splitK batch is absent. @@ -70,14 +79,11 @@ namespace triton_fusion { dot.dot_dimension_numbers().lhs_batch_dimensions_size() - num_split_k_batch_dims == 0) { - splittable_dimension_index = - NonContractingDimensionIndex(dot, operand_number); + splittable_dimension_index = non_contracting_dimension_index; } - FusionContext context( - DotProperties{ - static_cast(NonContractingDimensionIndex(dot, operand_number)), - splittable_dimension_index}, - DotRequirements(kNoSplitRequirement)); + FusionContext context(DotProperties{non_contracting_dimension_index, + splittable_dimension_index}, + DotRequirements(kNoSplitRequirement)); context.dim_orders_[dot.operand(operand_number)] = DimensionOrder::FromDotOperandOrOutput(*dot.operand(operand_number), split_k_dimension_index); @@ -86,19 +92,19 @@ namespace triton_fusion { /*static*/ FusionContext FusionContext::FromDotOutput( const HloInstruction& dot, const int split_k, - const int64_t splittable_dimension_major_part_size) { + DotRequirements requirements) { // Allow non-contracting dimension originating from LHS to split if // this dimension is split at the output at the same ratio as // at the input. int splittable_dimension_index = kNoDimensionIndex; - if (splittable_dimension_major_part_size > 1) { + if (requirements.splittable_dimension_major_part_size > 1) { // Split-K dimension is the first one in the output if present; // LHS non-contracting follows (batch is absent in this case). splittable_dimension_index = (split_k > 1) ? 1 : 0; } FusionContext context(DotProperties{/*noncontracting_dimension=*/-1, splittable_dimension_index}, - DotRequirements(splittable_dimension_major_part_size)); + std::move(requirements)); context.dim_orders_[&dot] = DimensionOrder::FromDotOperandOrOutput(dot); return context; } @@ -150,7 +156,7 @@ bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { return true; } -Status FusionContext::PropagateDimensionOrdersToParameters( +absl::Status FusionContext::PropagateDimensionOrdersToParameters( const HloInstruction& origin, ConstHloInstructionSet& parameters, ConstHloInstructionMap& iter_specs) { absl::flat_hash_set visited; @@ -167,14 +173,26 @@ Status FusionContext::PropagateDimensionOrdersToParameters( // more elementwise users - they share the same tiling. Situations when // one instruction is read differently by different users in the same // scope of the dot are currently prevented during the fusion. - TF_RET_CHECK(parameters.insert(hlo).second); + if (!parameters.insert(hlo).second) { + return FailedPrecondition( + "A parameter is read differently by different users. hlo: %s", + hlo->ToString()); + } VLOG(5) << hlo->ToString(); } DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( *hlo, dim_orders_.at(hlo), TransformDirection::kOutputToInput, properties_); - TF_RET_CHECK(std::holds_alternative(result)); - TF_RET_CHECK(CombineDimOrdersAndReqs(std::get(result))); + + if (!std::holds_alternative(result)) { + return FailedPrecondition( + "Can not propagate dim orders and requirements."); + } + + if (!CombineDimOrdersAndReqs(std::get(result))) { + return FailedPrecondition("Can not combine dim orders and requirements."); + } + iter_specs[hlo] = dim_orders_.at(hlo).ToTensorIterationSpec(); for (const HloInstruction* operand : hlo->operands()) { if (!visited.insert(operand).second) { @@ -188,12 +206,12 @@ Status FusionContext::PropagateDimensionOrdersToParameters( to_process.push(operand); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace triton_fusion -StatusOr TritonFusionAnalysis::Execute( +absl::StatusOr TritonFusionAnalysis::Execute( const HloComputation& computation, const int split_k) { VLOG(5) << computation.ToString(HloPrintOptions::ShortParsable()); TritonFusionAnalysis analysis; @@ -208,7 +226,7 @@ StatusOr TritonFusionAnalysis::Execute( return analysis; } -Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( +absl::Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( const HloInstruction& root) { auto context = FusionContext::FromSoftmaxRoot(root); // Softmax fusion uses one tiled scope. @@ -216,25 +234,64 @@ Status TritonFusionAnalysis::ExecuteForSoftmaxFusion( root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); iter_specs_[Scope::LHS] = {}; iter_specs_[Scope::RHS] = {}; - return OkStatus(); + return absl::OkStatus(); } -Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, - const int split_k) { - int64_t lhs_nc_split_major_part_size = kNoSplitRequirement; - for (const Scope scope : {Scope::LHS, Scope::RHS}) { +absl::Status TritonFusionAnalysis::ExecuteForProducerConsumer( + const HloInstruction& producer, const HloInstruction& consumer, + int split_k) { + // TODO(shyshkov): Use HloFusionAdaptor to avoid the need to materialize the + // hlo fusion. + std::unique_ptr new_module = + ExtractProducerConsumerIntoNewModule(producer, consumer); + + auto* new_producer = + new_module->entry_computation()->GetInstructionWithName(producer.name()); + auto* new_consumer = + new_module->entry_computation()->GetInstructionWithName(consumer.name()); + + std::unique_ptr fusion_instruction_holder; + HloInstruction* fusion_instruction; + if (new_consumer->opcode() == HloOpcode::kFusion) { + fusion_instruction = new_consumer; + } else { + fusion_instruction_holder = HloInstruction::CreateFusion( + new_consumer->shape(), new_producer->fusion_kind(), new_consumer); + fusion_instruction = fusion_instruction_holder.get(); + } + + // Try to merge the producer into candidate fusion. + if (new_producer->opcode() == HloOpcode::kFusion) { + fusion_instruction->MergeFusionInstruction(new_producer); + } else { + fusion_instruction->FuseInstruction(new_producer); + } + + auto* fused_computation = + fusion_instruction->fused_instructions_computation(); + return Execute(*fused_computation, split_k).status(); +} + +absl::Status TritonFusionAnalysis::ExecuteForDotFusion( + const HloInstruction& dot, const int split_k) { + DotRequirements lhs_requirements(kNoSplitRequirement); + for (const Scope scope : {Scope::LHS, Scope::RHS, Scope::META}) { const int operand_number = static_cast(scope); - auto context = FusionContext::FromDotOperand(dot, operand_number, split_k); + if (dot.operand_count() < operand_number + 1) { + continue; // Meta scope is optional. + } + TF_ASSIGN_OR_RETURN(auto context, FusionContext::FromDotOperand( + dot, operand_number, split_k)); TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( *dot.operand(operand_number), parameters_[scope], iter_specs_[scope])); if (scope == Scope::LHS) { - lhs_nc_split_major_part_size = - context.splittable_dimension_major_part_size(); + lhs_requirements = std::get(context.requirements()); } } - auto context = - FusionContext::FromDotOutput(dot, split_k, lhs_nc_split_major_part_size); + // For now the RHS doesn't support splits, so it also doesn't impose any + // requirements. + auto context = FusionContext::FromDotOutput(dot, split_k, lhs_requirements); const HloInstruction* output = ˙ // Currently supported is one fusion output and one path from dot to it. // Propagate dimension order from dot to root. @@ -254,12 +311,13 @@ Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot, .insert( {output, context.dim_orders().at(output).ToTensorIterationSpec()}) .second); + parameters_[Scope::OUTPUT] = {}; if (output != &dot) { // Propagate back to parameters of the output fusion. TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters( *output, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); } - return OkStatus(); + return absl::OkStatus(); } const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec( @@ -267,10 +325,8 @@ const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec( const int dimension) const { auto hlo_spec = iter_specs_.at(scope).find(hlo); if (hlo_spec != iter_specs_.at(scope).cend()) { - auto dim_spec = hlo_spec->second.Storage().find(dimension); - if (dim_spec != hlo_spec->second.Storage().cend()) { - return &dim_spec->second; - } + // The pointer returned here may also be nullptr. + return hlo_spec->second.Find(dimension); } return nullptr; } @@ -294,6 +350,8 @@ std::string ScopeToString(TritonFusionAnalysis::Scope s) { return "LHS"; case TritonFusionAnalysis::Scope::RHS: return "RHS"; + case TritonFusionAnalysis::Scope::META: + return "META"; case TritonFusionAnalysis::Scope::OUTPUT: return "OUTPUT"; } diff --git a/xla/service/gpu/triton_fusion_analysis.h b/xla/service/gpu/triton_fusion_analysis.h index 4f8c225ab378d..8459f86c8ffce 100644 --- a/xla/service/gpu/triton_fusion_analysis.h +++ b/xla/service/gpu/triton_fusion_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,18 +17,14 @@ limitations under the License. // This file contains TritonFusionAnalysis and FusionContext. -#include #include #include -#include -#include "absl/log/check.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { @@ -36,20 +32,29 @@ namespace gpu { // Analysis of tensor iteration orders within tiled fusions. class TritonFusionAnalysis { - Status ExecuteForDotFusion(const HloInstruction& dot, int split_k); - Status ExecuteForSoftmaxFusion(const HloInstruction& root); + absl::Status ExecuteForDotFusion(const HloInstruction& dot, int split_k); + absl::Status ExecuteForSoftmaxFusion(const HloInstruction& root); public: // Execute the analysis of a fusion computation. // `split_k` indicates whether this operation was converted to the split-K // form and tells the analysis how to interpret the batch dimensions. - static StatusOr Execute( + static absl::StatusOr Execute( const HloComputation& computation, int split_k = 1); + // Execute the analysis of a produce-consumer fusion. Returns OkStatus, if the + // analysis can find a valid tiling for the producer-consumer fusion. + // `split_k` indicates whether this operation was converted to the split-K + // form and tells the analysis how to interpret the batch dimensions. + static absl::Status ExecuteForProducerConsumer(const HloInstruction& producer, + const HloInstruction& consumer, + int split_k = 1); + // A scope is an HLO graph that can be tiled efficiently using same or - // compatible tile shapes on all operations. GEMM fusion has 3 scopes - // defined by left operand, right operand and output. - enum class Scope { LHS = 0, RHS = 1, OUTPUT = 2 }; + // compatible tile shapes on all operations. GEMM fusion has 3 or 4 scopes + // defined by left operand, right operand, optional meta (third operand) and + // output. + enum class Scope { LHS = 0, RHS = 1, META = 2, OUTPUT = 3 }; using IterationSpecByInstructionMap = ConstHloInstructionMap; @@ -59,10 +64,11 @@ class TritonFusionAnalysis { // Every parameter requires a separate piece of shared memory for asynchronous // loads. Multiple parameters are approximately equivalent to multiple // pipeline stages. - // Note: this has been tuned specifically for GEMMs, where pipelining with + // Note: This has been tuned specifically for GEMMs, where pipelining with // more than 4 stages has been shown to rarely be practical. This limitation // is not necessarily applicable to other operations. - static constexpr int kMaxParameterPerDotScope = 4; + // Note: The limit doesn't apply to the epilogue of the fusion. + static constexpr int kMaxParameterPerDotOperand = 4; // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, @@ -91,13 +97,13 @@ class FusionContext { public: // Create fusion context from a dot operand according to // the currently supported configurations. - static FusionContext FromDotOperand(const HloInstruction& dot, - int operand_number, int split_k = 1); + static absl::StatusOr FromDotOperand(const HloInstruction& dot, + int operand_number, + int split_k = 1); // Create fusion context from dot's output. - static FusionContext FromDotOutput( - const HloInstruction& dot, int split_k, - int64_t splittable_dimension_major_part_size); + static FusionContext FromDotOutput(const HloInstruction& dot, int split_k, + DotRequirements requirements); static FusionContext FromSoftmaxRoot(const HloInstruction&); @@ -108,17 +114,13 @@ class FusionContext { // Propagate dimension orders in consumer->producer direction starting at // `origin` with output `origin_dim_order` till parameters of the // computation. Store the found parameters and their iteration specs. - Status PropagateDimensionOrdersToParameters( + absl::Status PropagateDimensionOrdersToParameters( const HloInstruction& origin, ConstHloInstructionSet& parameters, ConstHloInstructionMap& iter_specs); - int64_t splittable_dimension_major_part_size() const { - CHECK(std::holds_alternative(requirements_)); - return std::get(requirements_) - .splittable_dimension_major_part_size; - } const HeroProperties& hero_properties() const { return properties_; } const DimOrderMap& dim_orders() const { return dim_orders_; } + const Requirements& requirements() const { return requirements_; } private: const HeroProperties properties_; diff --git a/xla/service/gpu/triton_fusion_analysis_test.cc b/xla/service/gpu/triton_fusion_analysis_test.cc index d2f67580234f6..4b0e2a481cd0d 100644 --- a/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/xla/service/gpu/triton_fusion_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,10 +20,11 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/gemm_rewriter_triton.h" +#include "xla/service/gpu/gemm_fusion.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -38,6 +39,28 @@ using ::testing::FieldsAre; using TritonDotAnalysisTest = HloTestBase; +TEST_F(TritonDotAnalysisTest, QueryingOutputScopeParametersAlwaysWorks) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_dot { + p0 = f32[8,8] parameter(0) + ROOT dot = f32[8,8] dot(p0, p0), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[8,8] parameter(0) + ROOT r = f32[8,8] fusion(p0), kind=kCustom, calls=triton_dot +})")); + TF_ASSERT_OK_AND_ASSIGN( + const auto analysis, + TritonFusionAnalysis::Execute(*module->entry_computation() + ->root_instruction() + ->called_computations()[0])); + EXPECT_TRUE( + analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT).empty()); +} + TEST_F(TritonDotAnalysisTest, NopBitcasts) { const std::string hlo_text = R"( HloModule t @@ -92,6 +115,57 @@ ENTRY e { /*slice_limit=*/3, ElementsAre(3)))); } +TEST_F(TritonDotAnalysisTest, DoNotRemoveTrivialDimensionForDot) { + const std::string hlo_text = R"( +HloModule t, is_scheduled=true + +triton_dot { + param_0.1 = f32[137,115]{1,0} parameter(0) + param_1.1 = f32[1,115]{1,0} parameter(1) + ROOT dot = f32[137,1]{1,0} dot(param_0.1, param_1.1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f32[137,115]{1,0} parameter(0) + p1 = f32[1,115]{1,0} parameter(1) + ROOT custom-call = f32[137,1]{1,0} fusion(p0, p1), kind=kCustom, + calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":2, + "num_ctas":1}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const HloInstruction* p0 = dot_computation->parameter_instruction(0); + const HloInstruction* p1 = dot_computation->parameter_instruction(1); + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).begin(), + p0); + EXPECT_EQ(*analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).begin(), + p1); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 0), + ElementsAre(FieldsAre(/*stride=*/115, /*count=*/137, /*slice_start=*/0, + /*slice_limit=*/137, ElementsAre(137)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, p0, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/115, /*slice_start=*/0, + /*slice_limit=*/115, ElementsAre(115)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 0), + ElementsAre(FieldsAre(/*stride=*/115, /*count=*/1, /*slice_start=*/0, + /*slice_limit=*/1, ElementsAre(1)))); + EXPECT_THAT( + *analysis.IterSpec(TritonFusionAnalysis::Scope::RHS, p1, 1), + ElementsAre(FieldsAre(/*stride=*/1, /*count=*/115, /*slice_start=*/0, + /*slice_limit=*/115, ElementsAre(115)))); +} + TEST_F(TritonDotAnalysisTest, Merge) { const std::string hlo_text = R"( HloModule t @@ -522,8 +596,8 @@ ENTRY e { lhs_contracting_dims={1}, rhs_contracting_dims={0} ROOT bc = bf16[2,2,100] broadcast(dot), dimensions={0,1} })")); - EXPECT_TRUE(GemmRewriterTriton(se::CudaComputeCapability{ - se::CudaComputeCapability::AMPERE, 0}) + EXPECT_TRUE(GemmFusion(se::CudaComputeCapability{ + se::CudaComputeCapability::AMPERE, 0}) .Run(module.get()) .value()); EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), @@ -568,6 +642,67 @@ ENTRY e { /*subfragments=*/ElementsAre(30)))); } +TEST_F(TritonDotAnalysisTest, + HandlesFurtherPropagationFromTrivialSizedTensorGracefully) { + // We could probably support this better, just checking to avoid a crash for + // now. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm_r { + a = f32[3,3]{1,0} parameter(0) + constant = f32[1,1]{1,0} constant({ {0} }) + broadcast = f32[1,1]{1,0} broadcast(constant), dimensions={0,1} + reshape = f32[] reshape(broadcast) + broadcast2 = f32[3,3]{1,0} broadcast(reshape), dimensions={} + ROOT dot = f32[3,3]{1,0} dot(a, broadcast2), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + a = f32[3,3]{1,0} parameter(0) + ROOT dot = f32[3,3]{1,0} fusion(a), kind=kCustom, calls=triton_gemm_r, + backend_config={kind: "__triton_gemm"} +} +)")); + + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + + absl::StatusOr analysis = + TritonFusionAnalysis::Execute(*dot_computation); + // It can fail but shouldn't crash. + (void)analysis; +} + +TEST_F(TritonDotAnalysisTest, SparseDot) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_gemm { + lhs = bf16[5,16] parameter(0) + rhs = bf16[32,10] parameter(1) + meta = u16[5,2] parameter(2) + ROOT dot = f32[5,10] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +} + +ENTRY main { + lhs = bf16[5,16] parameter(0) + rhs = bf16[32,10] parameter(1) + meta = u16[5,2] parameter(2) + ROOT out = f32[5,10] fusion(lhs, rhs, meta), + kind=kCustom, calls=triton_gemm, backend_config={kind:"__triton_gemm"} +} +)")); + + const HloComputation* dot_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*dot_computation)); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::META, + dot_computation->parameter_instruction(2), 0), + ::testing::SizeIs(1)); +} + using TritonSoftmaxAnalysisTest = HloTestBase; TEST_F(TritonSoftmaxAnalysisTest, DegenerateBatchDimensionIsSupported) { @@ -603,9 +738,11 @@ ENTRY e { ElementsAre(FieldsAre(/*stride=*/1, /*count=*/97, /*slice_start=*/0, /*slice_limit=*/97, /*subfragments=*/ElementsAre(97)))); - EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, - computation->root_instruction(), 1), - nullptr); + EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + computation->root_instruction(), 1), + ElementsAre(FieldsAre(/*stride=*/97, /*count=*/1, + /*slice_start=*/0, /*slice_limit=*/1, + /*subfragments=*/ElementsAre(1)))); } TEST_F(TritonSoftmaxAnalysisTest, BroadcastIntoBatchDimensionIsSupported) { @@ -673,6 +810,208 @@ ENTRY main { EXPECT_FALSE(analysis.ok()); } +TEST_F(TritonSoftmaxAnalysisTest, PadWithinTritonSoftmaxIsNotSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_1 = f32[4,127]{1,0} parameter(0) + constant_0 = f32[] constant(0) + reduce = f32[4]{0} reduce(param_1, constant_0), dimensions={1}, to_apply=add + broadcast = f32[4,127]{1,0} broadcast(reduce), dimensions={0} + ROOT pad = f32[8,127]{1,0} pad(broadcast, constant_0), padding=0_4x0_0 +} + +ENTRY main { + param_0 = f32[4,127]{1,0} parameter(0) + ROOT fusion = f32[8,127]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + +TEST_F(TritonSoftmaxAnalysisTest, + BitcastWhichSplitsBatchAndReduceDimensionsIsNotSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[8,16129]{1,0} parameter(0) + bitcast = f32[8,127,127]{2,1,0} bitcast(param_0) + constant = f32[] constant(0) + reduce = f32[8,127]{1,0} reduce(bitcast, f32[] constant), dimensions={2}, to_apply=add + ROOT broadcast = f32[8,127,127]{2,1,0} broadcast(reduce), dimensions={0,1} +} + +ENTRY main { + param_1 = f32[8,16129]{1,0} parameter(0) + ROOT fusion = f32[8,127,127]{2,1,0} fusion(param_1), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + +TEST_F(TritonSoftmaxAnalysisTest, + BitcastWhichSplitsReduceDimensionIsSupported) { + // Clone of BitcastWhichSplitsBatchAndReduceDimensionsIsNotSupported, + // But in this case the split dimension can be fully tiled as a reduce dim. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[1,8,127,128]{3,2,1,0} parameter(0) + intermediate_bitcast = f32[8,127,2,64]{3,2,1,0} bitcast(param_0) + bitcast = f32[8,127,128]{2,1,0} bitcast(intermediate_bitcast) + constant = f32[] constant(0) + reduce = f32[8,127]{1,0} reduce(bitcast, constant), dimensions={2}, to_apply=add + ROOT broadcast = f32[8,127,128]{2,1 ,0} broadcast(reduce), dimensions={0,1} +} + +ENTRY main { + param_1 = f32[1,8,127,128]{3,2,1,0} parameter(0) + ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); +} + +TEST_F(TritonSoftmaxAnalysisTest, + BitcastWhichDoesNotAffectReduceDimIsSupported) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[1,2,4,127,128]{4,3,2,1,0} parameter(0) + bitcast = f32[8,127,128]{2,1,0} bitcast(param_0) + constant = f32[] constant(0) + reduce = f32[8,127]{1,0} reduce(bitcast, constant), dimensions={2}, to_apply=add + ROOT broadcast = f32[8,127,128]{2,1,0} broadcast(reduce), dimensions={0,1} +} + +ENTRY main { + param_1 = f32[1,2,4,127,128]{4,3,2,1,0} parameter(0) + ROOT fusion = f32[8,127,128]{2,1,0} fusion(param_1), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + TF_ASSERT_OK_AND_ASSIGN(const auto analysis, + TritonFusionAnalysis::Execute(*computation)); +} + +TEST_F(TritonSoftmaxAnalysisTest, SliceWithinTritonSoftmaxIsNotSupported) { + // Slices cannot yet be tiled into triton softmax (b/316637896) because they + // cannot be emitted. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule t + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_softmax_computation { + param_0 = f32[27,260]{1,0} parameter(0) + slice = f32[4,127]{1,0} slice(param_0), slice={[7:27:5], [6:260:2]} + constant_0 = f32[] constant(0) + reduce = f32[4]{0} reduce(slice, constant_0), dimensions={1}, to_apply=add + ROOT broadcast = f32[4,127]{1,0} broadcast(reduce), dimensions={0} +} + +ENTRY main { + param_0 = f32[27,260]{1,0} parameter(0) + ROOT fusion = f32[4,127]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"kind":"__triton_softmax"} +})")); + + const HloComputation* computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto analysis = TritonFusionAnalysis::Execute(*computation); + EXPECT_FALSE(analysis.ok()); +} + +TEST_F(TritonSoftmaxAnalysisTest, ProducerConsumerFusion) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +producer_computation { + parameter_0 = f32[125] parameter(0) + ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} +} + +triton_softmax_computation { + parameter_0 = f32[125,127] parameter(0) + multiply_0 = f32[125,127] multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125] reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127] broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127] multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125] parameter(0) + param_1 = f32[125,127] parameter(1) + producer_fusion = f32[125,127] fusion(param_0), kind=kLoop, calls=producer_computation + ROOT triton_softmax = f32[125,127] fusion(producer_fusion), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_softmax"}} +})")); + + auto consumer = module->entry_computation()->root_instruction(); + auto producer = consumer->operand(0); + + EXPECT_TRUE( + TritonFusionAnalysis::ExecuteForProducerConsumer(*producer, *consumer) + .ok()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/triton_support.cc b/xla/service/gpu/triton_support.cc index 1f7845ab79bb4..66631dbd19ad3 100644 --- a/xla/service/gpu/triton_support.cc +++ b/xla/service/gpu/triton_support.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,19 @@ limitations under the License. #include "xla/service/gpu/triton_support.h" #include +#include #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace xla { namespace gpu { @@ -47,10 +54,9 @@ bool IsDistributiveOverAddition(const HloInstruction& hlo) { // BF16 is supported in a sense that all operations on it are implemented // through F32 and converts have to be inserted into the HLO graph, but // they can be missing during fusion. +// TODO(b/266862493): Support more data types (F8, F64, etc.). bool IsTritonSupportedDataType(PrimitiveType type, - se::GpuComputeCapability gpu_version) { - auto cuda_compute_capability = - std::get(gpu_version); + const se::GpuComputeCapability& gpu_version) { switch (type) { case PRED: case S8: @@ -60,8 +66,13 @@ bool IsTritonSupportedDataType(PrimitiveType type, case F32: return true; case BF16: - return cuda_compute_capability.IsAtLeast( - stream_executor::CudaComputeCapability::AMPERE); + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return true; + }, + [](const se::RocmComputeCapability& cc) { + return cc.has_bf16_dtype_support(); + }}, + gpu_version); default: return false; } @@ -84,7 +95,7 @@ std::vector TritonSupportedUnaryElementwise( HloOpcode::kLog1p, HloOpcode::kRsqrt, HloOpcode::kSin, HloOpcode::kSqrt, HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh}, + HloOpcode::kTanh, HloOpcode::kErf}, std::back_inserter(ret)); } return ret; @@ -124,5 +135,218 @@ bool IsTritonSupportedElementwise(HloOpcode opcode, opcode); } +CodegenDecision CanTritonHandleElementwise( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + + for (const HloInstruction* operand : instr.operands()) { + if (!IsTritonSupportedDataType(operand->shape().element_type(), + gpu_version)) { + return "Unsupported input data type."; + } + } + + if (instr.opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } else if (!IsTritonSupportedElementwise( + instr.opcode(), instr.operand(0)->shape().element_type())) { + return "Unsupported elementwise operation."; + } + return CodegenDecision{}; +} + +bool IsDotAlgorithmSupportedByTriton( + PrecisionConfig::Algorithm algorithm, + const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + if (cuda_compute_capability) { + return true; + } + return false; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + + // TODO(b/326579472): Fix the support of this algorithm and maybe allow it + // here. + case PrecisionConfig::ALG_DOT_F16_F16_F32: + // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is + // slow to compile. Disable it for now. + case PrecisionConfig::ALG_DOT_F32_F32_F32: + default: + return false; + } +} + +// Filters GEMMs which can be handled using Triton. +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + + CHECK(cuda_compute_capability || rocm_compute_capability); + + if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { + if (!tsl::tensor_float_32_execution_enabled() || + absl::c_any_of(dot.precision_config().operand_precision(), + [](int x) { return x != PrecisionConfig::DEFAULT; })) { + return "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."; + } + } else { + if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), + gpu_version)) { + return "Unsupported algorithm on the current device(s)."; + } + } + + auto supported_output_type = [&](const PrimitiveType t) { + switch (t) { + case F16: + case F32: + return true; + case BF16: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + default: + return false; + } + }; + + // TODO(b/266862493): Support more output types. + if (!supported_output_type(dot.shape().element_type())) { + return "Unsupported output data type for Dot op."; + } + + if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), + gpu_version) || + !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Dot op."; + } + + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + // TODO(b/269580541): support multiple batch dimensions. + if (dim_numbers.lhs_batch_dimensions().size() > 1) { + return "Multiple batch dimensions."; + } + + // Cases where lhs or rhs have no non-contracting dims are not handled. + if (dim_numbers.lhs_batch_dimensions().size() + + dim_numbers.lhs_contracting_dimensions().size() == + dot.operand(0)->shape().rank() || + dim_numbers.rhs_batch_dimensions().size() + + dim_numbers.rhs_contracting_dimensions().size() == + dot.operand(1)->shape().rank()) { + return "No non-contracting dimensions."; + } + + return CodegenDecision{}; +} + +// Filters Reduces which can be handled using Triton. +CodegenDecision CanTritonHandleReduce( + const HloReduceInstruction& reduce, + const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(reduce.shape().element_type(), gpu_version)) { + return "Unsupported output data type for Reduce op."; + } + + for (const HloInstruction* operand : reduce.operands()) { + if (!IsTritonSupportedDataType(operand->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Reduce op."; + } + } + + bool is_triton_supported_reduction_computation = [&]() { + return absl::c_all_of( + reduce.to_apply()->instructions(), [&](const HloInstruction* instr) { + return IsTritonSupportedInstruction(*instr, gpu_version); + }); + }(); + if (!is_triton_supported_reduction_computation) { + return "Unsupported reduction computation by Triton."; + } + + if (reduce.dimensions().size() == 1 && + reduce.dimensions().front() == reduce.operand(0)->shape().rank() - 1 && + reduce.operand_count() == 2) { + const HloInstruction* operand = reduce.operand(1); + // We assume that the reduction init value was input as a constant, or in + // the case of a data type affected by float normalization, a convert of a + // constant. + if (operand->opcode() == HloOpcode::kConvert) { + if (operand->operand(0)->opcode() == HloOpcode::kConstant && + operand->operand(0)->shape().element_type() == BF16 && + operand->shape().element_type() == F32) { + return CodegenDecision{}; + } + } else if (operand->opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } + return "Reduction init value should be a constant or a convert of a " + "constant."; + } + return "Reduction is not a row-reduction of a single operand."; +} + +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (instr.IsElementwise()) { + return CanTritonHandleElementwise(instr, gpu_version); + } + + switch (instr.opcode()) { + case HloOpcode::kDot: { + return CanTritonHandleGEMM(*Cast(&instr), gpu_version); + } + case HloOpcode::kReduce: { + return CanTritonHandleReduce(*Cast(&instr), + gpu_version); + } + case HloOpcode::kTuple: { + if (instr.IsRoot()) { + return CodegenDecision{}; + } + return "Only supports root tuples."; + } + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kSlice: + case HloOpcode::kReshape: + case HloOpcode::kPad: + case HloOpcode::kConcatenate: + case HloOpcode::kParameter: + case HloOpcode::kBroadcast: + return CodegenDecision{}; + default: + break; + } + return "Unsupported opcode."; +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/triton_support.h b/xla/service/gpu/triton_support.h index f3dcd23d954d3..072c9ab948ec0 100644 --- a/xla/service/gpu/triton_support.h +++ b/xla/service/gpu/triton_support.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,11 +22,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { namespace gpu { +using CodegenDecision = FusionDecision; // Tells if f(a+b) == f(a) + f(b). bool IsDistributiveOverAddition(const HloInstruction& hlo); @@ -41,11 +43,15 @@ std::vector TritonSupportedBinaryElementwise(PrimitiveType); std::vector TritonSupportedTernaryElementwise(PrimitiveType); // Data types that are supported by the Triton emitters. -bool IsTritonSupportedDataType(PrimitiveType, se::GpuComputeCapability); +bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); // Checks elementwise operation against all supported by Triton GEMM codegen. bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType); +// Checks instruction against requirements of triton emitter. +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/triton_support_test.cc b/xla/service/gpu/triton_support_test.cc new file mode 100644 index 0000000000000..e3ad43b2f0f78 --- /dev/null +++ b/xla/service/gpu/triton_support_test.cc @@ -0,0 +1,940 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_support.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "llvm/IR/Module.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/primitive_util.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/ir_emitter_triton.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class TritonSupportTest : public GpuCodegenTest { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + absl::StatusOr ApplyFloatNormalization(HloModule* module) { + const GpuFloatSupport bf16_support(GetCudaComputeCapability(), BF16); + HloPassPipeline pipeline("hlo float normalization"); + pipeline.AddPass(&bf16_support); + return pipeline.Run(module); + } + + float getTolerance(PrimitiveType data_type) { + float tolerance; + switch (data_type) { + case F64: + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + case PRED: + case S8: + tolerance = 3e-2; + break; + case S16: + tolerance = 3e-3; + break; + case S32: + tolerance = 3e-3; + break; + default: + ABSL_UNREACHABLE(); + } + return tolerance; + } + + protected: + llvm::LLVMContext llvm_ctx_; + llvm::Module llvm_module_{"module", llvm_ctx_}; + mlir::MLIRContext mlir_context_; + TritonGemmConfig config_{16, 32, 512, 1, 4, 8}; +}; + +class TritonSupportTestWithParam : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string TestParamsToString( + const ::testing::TestParamInfo>& + data) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = data.param; + return absl::StrCat( + primitive_util::LowercasePrimitiveTypeName(data_type), "_", + absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); +} + +using UnaryElementwiseTest = TritonSupportTestWithParam; + +// TODO(b/331636835): updates elementwise op tests to directly emit single op +// instead of relying on triton gemm kernel. +TEST_P(UnaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForUnary) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[15,33]{1,0} parameter(0) + parameter_1 = $0[33,68]{1,0} parameter(1) + unary = $0[33,68]{1,0} $1(parameter_1) + convert = f32[33,68]{1,0} convert(unary) + ROOT dot = f32[15,68]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[15,33]{1,0} parameter(0) + parameter_1 = $0[33,68]{1,0} parameter(1) + ROOT triton_gemm = f32[15,68]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + // TODO(b/331632717): update the check to use SymbolicTileAnalysis to avoid + // tiling failures and check triton emitter fails gracefully. + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryElementwiseTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kConvert, HloOpcode::kAbs, + HloOpcode::kNegate)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + UnaryPREDTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED), + ::testing::Values(HloOpcode::kConvert, HloOpcode::kNot)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + UnaryMathTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, + HloOpcode::kSin, HloOpcode::kSqrt, + HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh, HloOpcode::kErf)), + TestParamsToString); + +using BinaryElementwiseTest = TritonSupportTestWithParam; + +TEST_P(BinaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForBinaryE) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + binary = $0[11,63]{1,0} $1(parameter_1, parameter_2) + convert = f32[11,63]{1,0} convert(binary) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + ::testing::AnyOf( + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr( + "std::holds_alternative")), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements")))); + } +} + +INSTANTIATE_TEST_SUITE_P( + BinaryElementwiseTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kAdd, HloOpcode::kMultiply, + HloOpcode::kMaximum, + HloOpcode::kMinimum, + HloOpcode::kSubtract)), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(BinaryPREDTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED), + ::testing::Values(HloOpcode::kAnd, + HloOpcode::kOr, + HloOpcode::kXor)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + BinaryMathTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kAtan2, HloOpcode::kDivide, + HloOpcode::kPower)), + TestParamsToString); + +using CompareTest = TritonSupportTestWithParam; + +TEST_P(CompareTest, IsTritonSupportedExecutesCorrectlyForCompare) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + compare = pred[11,63]{1,0} $1(parameter_1, parameter_2), direction=GE + convert = f32[11,63]{1,0} convert(compare) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT( + TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("std::holds_alternative"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + CompareTestSuite, CompareTest, + ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kCompare)), + TestParamsToString); + +using TernaryElementwiseTest = TritonSupportTestWithParam; + +TEST_P(TernaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForTernary) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + ternary = $0[13,63]{1,0} $1(parameter_3, parameter_1, parameter_2) + convert = f32[13,63]{1,0} convert(ternary) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2, parameter_3), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT( + TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("std::holds_alternative"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + TernaryElementwiseTestSuite, TernaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kSelect)), + TestParamsToString); + +using DotTest = TritonSupportTestWithParam; + +TEST_P(DotTest, IsTritonSupportedExecutesCorrectlyForDot) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = $0[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + ROOT dot = $0[92,63]{1,0} $1(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = $0[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + ROOT triton_gemm = $0[92,63]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, + config_, &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("Failed to compile Triton kernel"))); + } +} + +INSTANTIATE_TEST_SUITE_P(DotTestTestSuite, DotTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kDot)), + TestParamsToString); + +TEST_F(TritonSupportTest, UnsupportedDotOutputTypeFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = f32[11,63]{1,0} parameter(1) + ROOT dot = pred[92,63]{1,0} dot(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = f32[11,63]{1,0} parameter(1) + ROOT triton_gemm = pred[92,63]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported output data type for Dot op.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("pm.run(triton_module.get()).succeeded()"))); +} + +TEST_F(TritonSupportTest, + UnsupportedDotWithMultipleBatchDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[2,2,2,2]{3,2,1,0} parameter(0) + parameter_1 = f32[2,2,2,2]{3,2,1,0} parameter(1) + ROOT dot = f32[2,2,2,2]{3,2,1,0} dot(parameter_0, parameter_1), + lhs_contracting_dims={3}, lhs_batch_dims={1,0}, rhs_contracting_dims={2}, + rhs_batch_dims={1,0} +} + +ENTRY e { + parameter_0 = f32[2,2,2,2]{3,2,1,0} parameter(0) + parameter_1 = f32[2,2,2,2]{3,2,1,0} parameter(1) + ROOT triton_gemm = f32[2,2,2,2]{3,2,1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Multiple batch dimensions")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs(absl::StatusCode::kInternal, + ::testing::HasSubstr("num_batch_dims <= 1"))); +} + +TEST_F(TritonSupportTest, + UnsupportedDotWithNoNonContractingDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[2]{0} parameter(0) + parameter_1 = f32[2]{0} parameter(1) + ROOT dot = f32[] dot(parameter_0, parameter_1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = f32[2]{0} parameter(0) + parameter_1 = f32[2]{0} parameter(1) + ROOT triton_gemm = f32[] fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("No non-contracting dimensions.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("non_contracting_dims.size() == 1"))); +} + +using ReduceConstTest = TritonSupportTestWithParam; +TEST_P(ReduceConstTest, + IsTritonSupportedExecutesCorrectlyForReduceWithConstInit) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +HloModule t +add { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = $0[125,127]{1,0} parameter(0) + multiply_0 = $0[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = $0[] constant(0) + reduce = $0[125]{0} $1(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast = $0[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = $0[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = $0[125,127]{1,0} parameter(0) + ROOT triton_softmax = $0[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + + const HloComputation* computation = + module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, + config_, &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("Failed to compile Triton kernel"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + ReduceConstTestSuite, ReduceConstTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kReduce)), + TestParamsToString); + +TEST_F(TritonSupportTest, + SupportedReduceWithConvertConstantIsCodegenedSuccessfullyWithTriton) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = bf16[] constant(0) + convert_0 = f32[] convert(constant_0) + reduce = f32[125]{0} reduce(multiply_0, convert_0), dimensions={1}, to_apply=add + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_TRUE(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .CanFuse()); + EXPECT_OK(ApplyFloatNormalization(hlo_module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(hlo_module), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); +} + +TEST_F( + TritonSupportTest, + UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + multiply_0 = f32[2,125,127]{2,1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[2]{0} reduce(multiply_0, constant_0), dimensions={1,2}, to_apply=add + broadcast = f32[2,125,127]{2,1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[2,125,127]{2,1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + ROOT triton_softmax = f32[2,125,127]{2,1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr( + "Reduction is not a row-reduction of a single operand.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithNoneLastReduceDimensionFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + multiply_0 = f32[2,125,127]{2,1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[2,127]{1,0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast = f32[2,125,127]{2,1,0} broadcast(reduce), dimensions={0,2} + ROOT multiply = f32[2,125,127]{2,1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + ROOT triton_softmax = f32[2,125,127]{2,1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr( + "Reduction is not a row-reduction of a single operand.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_2 = f32[] parameter(1) + Arg_1 = f32[] parameter(2) + Arg_3 = f32[] parameter(3) + add_0 = f32[] add(Arg_0, Arg_2) + add_1 = f32[] add(Arg_1, Arg_3) + ROOT pair = (f32[], f32[]) tuple(add_0, add_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127] parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(multiply_0, multiply_0, constant_0, constant_0), dimensions={1}, to_apply=add + reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT( + IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported output data type for Reduce op.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithNonConstReduceValueFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + init = f32[] parameter(1) + reduce = f32[125]{0} reduce(multiply_0, init), dimensions={1}, to_apply=add + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + parameter_1 = f32[] parameter(1) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Reduction init value should be a constant " + "or a convert of a constant.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("operand->opcode() == HloOpcode::kConstant"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReductionComputationFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +custom_call { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=custom_call + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported reduction computation by Triton.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("Unsupported operation"))); +} +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/triton_tiling_propagation.cc b/xla/service/gpu/triton_tiling_propagation.cc index ec45285cb517f..9abec4560784a 100644 --- a/xla/service/gpu/triton_tiling_propagation.cc +++ b/xla/service/gpu/triton_tiling_propagation.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/triton_tiling_propagation.h" #include +#include #include #include #include @@ -48,40 +49,86 @@ limitations under the License. namespace xla { namespace gpu { -bool TensorIterationSpec::operator==(const TensorIterationSpec& other) const { - VLOG(9) << this->ToString(); - VLOG(9) << other.ToString(); - auto it_this = dim_iteration_specs_.cbegin(); - while (it_this != dim_iteration_specs_.cend()) { - auto it_other = other.dim_iteration_specs_.find(it_this->first); - if (it_other == other.dim_iteration_specs_.cend()) { +namespace { + +// The input is a map from dimension index to DimIterationSpec. The function +// removes dimensions that have a trivial DimIterationSpec. +absl::flat_hash_map +FilterTrivialDims( + const absl::flat_hash_map& + dim_iter_specs) { + absl::flat_hash_map + non_trivial_dim_iteration_specs; + for (const auto& [dim, dim_spec] : dim_iter_specs) { + if (dim_spec.size() == 1 && dim_spec[0].count == 1) { + continue; + } + non_trivial_dim_iteration_specs[dim] = dim_spec; + } + return non_trivial_dim_iteration_specs; +} + +} // namespace + +const TensorIterationSpec::DimIterationSpec* TensorIterationSpec::Find( + const int dimension) const { + if (auto it = dim_iteration_specs_.find(dimension); + it != dim_iteration_specs_.end()) { + return &it->second; + } + return nullptr; +} + +std::vector TensorIterationSpec::GetDimensions() const { + std::vector result; + result.reserve(dim_iteration_specs_.size()); + for (const auto& [dim, _] : dim_iteration_specs_) { + result.push_back(dim); + } + return result; +} + +bool TensorIterationSpec::IsPhysicallyEquivalent( + const TensorIterationSpec& other) const { + // Filter out trivial dims since they don't affect physical representation. + const absl::flat_hash_map + non_trivial_dim_iteration_specs = FilterTrivialDims(dim_iteration_specs_); + const absl::flat_hash_map + other_non_trivial_dim_iteration_specs = + FilterTrivialDims(other.dim_iteration_specs_); + + if (non_trivial_dim_iteration_specs.size() != + other_non_trivial_dim_iteration_specs.size()) { + return false; + } + + for (const auto& pair : non_trivial_dim_iteration_specs) { + int dimension = pair.first; + const DimIterationSpec& dim_iter_spec = pair.second; + auto other_it = other_non_trivial_dim_iteration_specs.find(dimension); + if (other_it == other_non_trivial_dim_iteration_specs.end()) { return false; } - if (it_this->second.size() != it_other->second.size()) { + const DimIterationSpec& other_dim_iter_spec = other_it->second; + if (dim_iter_spec.size() != other_dim_iter_spec.size()) { return false; } - for (int fragment = 0; fragment < it_this->second.size(); ++fragment) { - if (it_this->second[fragment] != it_other->second[fragment]) { + for (size_t i = 0; i < dim_iter_spec.size(); i++) { + if (!dim_iter_spec[i].IsPhysicallyEquivalent(other_dim_iter_spec[i])) { return false; } } - ++it_this; } return true; } std::string TensorIterationSpec::IterationSpecFragment::ToString() const { return absl::StrCat("{stride=", stride, ", count=", count, - ", slice_start=", slice_start, ", subfragments=[", + ", slice_start=", slice_start, + ", sliced_count=", sliced_count, ", subfragments=[", absl::StrJoin(subfragments, ", "), "]}"); } -bool TensorIterationSpec::IterationSpecFragment::operator!=( - const IterationSpecFragment& other) const { - return stride != other.stride || count != other.count || - slice_start != other.slice_start || sliced_count != other.sliced_count; -} - std::string TensorIterationSpec::ToString() const { return absl::StrCat( "{", @@ -164,12 +211,6 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; int last_dim = -1; - auto remove_last_fragment_if_degenerate = [&tensor_spec](const int dim_idx) { - if (dim_idx >= 0 && !tensor_spec[dim_idx].empty() && - tensor_spec[dim_idx].back().count == 1) { - tensor_spec[dim_idx].pop_back(); - } - }; for (int dim_order_index = 0; dim_order_index < dim_fragments.size(); ++dim_order_index) { const DimensionOrder::Fragment& fragment = dim_fragments[dim_order_index]; @@ -196,7 +237,6 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { dim_spec.back().subfragments.push_back(fragment.sliced_count()); } } else { - remove_last_fragment_if_degenerate(last_dim); // Add part of the dimension. dim_spec.push_back(TensorIterationSpec::IterationSpecFragment{ accumulated_stride, @@ -209,7 +249,23 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { accumulated_stride *= fragment.full_count(); last_dim = fragment.dst_dim_number(); } - remove_last_fragment_if_degenerate(last_dim); + + // Remove degenerate fragments. + for (int dim_idx : tensor_spec.GetDimensions()) { + TensorIterationSpec::DimIterationSpec& dim_spec = tensor_spec[dim_idx]; + + // We should not remove the only fragment in a dimension, because if it is + // removed, the dimension will be removed from the TensorIterationSpec. + if (dim_spec.size() <= 1) continue; + + TensorIterationSpec::DimIterationSpec filtered_dim_spec; + absl::c_copy_if(dim_spec, std::back_inserter(filtered_dim_spec), + [](const TensorIterationSpec::IterationSpecFragment& f) { + return f.count != 1; + }); + tensor_spec[dim_idx] = filtered_dim_spec; + } + tensor_spec.RemoveEmptyDimensions(); return tensor_spec; } @@ -398,17 +454,7 @@ DimOrderMap GetPropagatedDimOrdersForElementwise( return map; } - DimOrderMap map; - map.insert({&hlo, src_dim_order}); - // TODO(tdanyluk): For now, the "input to output" direction of this function - // also returns the dim orders for the operands, not just the output. This is - // needed to propagate the dim order of one input to the other(s) when fusing - // elementwise ops to the output. Perhaps we can separate the "input to - // output" and "output to input" directions of that in a later CL. - for (const HloInstruction* operand : hlo.operands()) { - map.insert({operand, src_dim_order}); - } - return map; + return {{&hlo, src_dim_order}}; } const HloInstruction& GetSourceHlo(const HloInstruction& hlo, @@ -457,6 +503,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( DimensionOrder& dst_dim_order = dst_dim_orders.insert({&dst, DimensionOrder()}).first->second; Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder(); + bool dst_remainder_comes_from_reduce_dim = false; // Size of not yet assigned part of current target dimension. int64_t dst_remaining_size = 1; // Track destination fragments created from a source one. @@ -481,6 +528,14 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // Find a continuous group of fragments corresponding to this dimension in // the source and assign the corresponding size in fragments of the // destination ignoring the source ones. + + // If there is dst_remaining_size leftover from our previous src_dim, + // and it came from a reduce dim, we cannot tile it in a batch dim. + if (dst_remainder_comes_from_reduce_dim) { + return R"(Unsupported bitcast splits dimension between batch and + reduction dimensions in softmax)"; + } + dst_remaining_size = src_dim->full_count(); while (src_dim + 1 != src_fragments_order.cend() && (src_dim + 1)->dst_dim_number() == src_dim->dst_dim_number()) { @@ -548,6 +603,16 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( ++dst_dim_it; } } + + // We cannot tile a single dim with fragments across both reduce and batch + // dimensions. As such, if we have a dst remainder leftover from tiling a + // src fragment on the reduce dimension in softmax, we must only tile it + // with other src_dim fragments on the reduce dimension. + dst_remainder_comes_from_reduce_dim = + (dst_remaining_size > 1 && + std::holds_alternative(properties) && + src_dim->dst_dim_number() == std::get(properties) + .softmax_reduction_dimension); } CHECK_EQ(dst_remaining_size, 1); @@ -610,6 +675,12 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( // full dimensions and matching by total size. std::vector> src_physical; src_physical.reserve(src.shape().rank()); + if (src_fragments_order.size() < src.shape().rank()) { + // It's not supported currently to further propagate dimensions after + // reaching a trivial sized tensor. We could probably support it, but now we + // just prevent crashing here. + return FusionDecision("Cannot propagate further from trivial sized tensor"); + } auto src_fragment_it = src_fragments_order.begin(); for (int64_t dim_index : src.shape().layout().minor_to_major()) { const int64_t dim_size = src.shape().dimensions(dim_index); @@ -633,6 +704,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } DimOrderMap dst_dim_orders; + int64_t concat_accumulated_size = 0; for (const HloInstruction* dst : GetDestHlos(hlo, direction)) { DimensionOrder& dst_dim_order = dst_dim_orders.insert({dst, DimensionOrder()}).first->second; @@ -685,13 +757,19 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } else if (hlo.opcode() == HloOpcode::kConcatenate) { dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { - dst_logical[i] = src_logical[i]; if (i == hlo.concatenate_dimension()) { if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { return FusionDecision("Unsupported concatenation."); } - dst_logical[i][0]->set_count(dst->shape().dimensions(i)); - dst_logical[i][0]->set_slice(0, dst->shape().dimensions(i)); + const Fragment& src_fragment = *src_logical[i][0]; + Fragment& dst_fragment = new_fragments.emplace_back( + src_fragment.dst_dim_number(), dst->shape().dimensions(i)); + dst_fragment.set_slice(-concat_accumulated_size, + dst->shape().dimensions(i)); + concat_accumulated_size += dst->shape().dimensions(i); + dst_logical[i].push_back(&dst_fragment); + } else { + dst_logical[i] = src_logical[i]; } } } else if (hlo.opcode() == HloOpcode::kCopy) { @@ -828,6 +906,10 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); } else if (hlo.opcode() == HloOpcode::kPad) { + if (std::holds_alternative(properties)) { + return "Pad ops are only supported when they are generated as part of " + "the split-k transform of dot fusions."; + } if (direction != TransformDirection::kOutputToInput) { return "Unsupported pad direction."; } @@ -841,6 +923,10 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, properties); } else if (hlo.opcode() == HloOpcode::kSlice) { + // TODO(b/316637896) Add support for slices in softmax. + if (std::holds_alternative(properties)) { + return "Slices are not supported in Softmax fusions yet."; + } if (direction != TransformDirection::kOutputToInput) { return "Unsupported slice direction."; } @@ -857,18 +943,27 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, if (!std::holds_alternative(properties)) { return "Concatenations for now are only supported in GEMM fusions."; } - auto dim = LogicalIndexOfLabeledDimension( - hlo.shape(), src_dim_order, - std::get(properties).noncontracting_dimension); + + int64_t noncontracting_dim_label = + std::get(properties).noncontracting_dimension; + const FragmentOrders& src_dim_fragments_orders = + src_dim_order.DimFragmentsOrders(); + + auto noncontracting_dim_fragment_order_it = + src_dim_fragments_orders.find(noncontracting_dim_label); + if (noncontracting_dim_fragment_order_it != + src_dim_fragments_orders.end()) { + if (noncontracting_dim_fragment_order_it->second.size() > 1) { + return "Concatenations on split non-contracting dimensions are " + "unsupported."; + } + } + + auto dim = LogicalIndexOfLabeledDimension(hlo.shape(), src_dim_order, + noncontracting_dim_label); if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { return "Unsupported concatenation."; } - if (absl::c_any_of(hlo.operands(), [](const HloInstruction* operand) { - return operand->user_count() > 1; - })) { - return FusionDecision( - "Concatenation has to be the only user of its inputs."); - } if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { // In the current simple implementation of concatenation the size of // each of its inputs along the concatenated dimension has to be @@ -1005,10 +1100,7 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( std::move(std::get(result_or_error)); int fusion_level = hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level(); - if (!std::get(gpu_version) - .IsAtLeast(se::CudaComputeCapability::AMPERE)) { - fusion_level = std::min(fusion_level, 1); - } + // TODO(ROCm): Check fusion level for ROCm. if (transform_direction == TransformDirection::kOutputToInput) { if (fusion_level < 2) { if (hlo.opcode() == HloOpcode::kConvert) { diff --git a/xla/service/gpu/triton_tiling_propagation.h b/xla/service/gpu/triton_tiling_propagation.h index 08445f73b7608..87ff11ae7c741 100644 --- a/xla/service/gpu/triton_tiling_propagation.h +++ b/xla/service/gpu/triton_tiling_propagation.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -34,6 +35,33 @@ limitations under the License. namespace xla { namespace gpu { +// Illustration explaining why slice_start for concatenations is negative: + +// Slice +// ===== +// input +// [--------------------------] +// . . . +// . offset . +// |------> output . +// [--------] +// +// output[x] = input[x + offset] + +// Concatenation +// ============= +// +// input_n +// [......][--------][........] +// . . +// offset . . +// <-------| . +// . . . +// . . output . +// [--------------------------] +// +// output[x] = input_n[x - offset] + class TensorIterationSpec { public: // Description of basic iteration: `count` elements separated by `stride` @@ -43,12 +71,33 @@ class TensorIterationSpec { int64_t count; int64_t slice_start; int64_t sliced_count; - // Logical subfragments when this iteration is composed - // of several HLO dimensions. + // Logical subfragments: + // These are the sizes of the HLO dimensions which make up this basic + // iteration. std::vector subfragments; bool is_sliced() const { return count != sliced_count; } - bool operator!=(const IterationSpecFragment& other) const; + + auto ToTuple() const { + return std::make_tuple(stride, count, slice_start, sliced_count, + subfragments); + } + + bool operator==(const IterationSpecFragment& other) const { + return ToTuple() == other.ToTuple(); + } + template + friend H AbslHashValue(H h, const IterationSpecFragment& fragment) { + return H::combine(std::move(h), fragment.ToTuple()); + } + + bool IsPhysicallyEquivalent(const IterationSpecFragment& other) const { + // Subfragments don't change the physical layout. + return stride == other.stride && count == other.count && + slice_start == other.slice_start && + sliced_count == other.sliced_count; + } + std::string ToString() const; }; // Description of complex iteration over a sequence of several strides. @@ -56,26 +105,41 @@ class TensorIterationSpec { // separated into multiple fragments by other dimensions. using DimIterationSpec = std::vector; - using StorageType = absl::flat_hash_map; const DimIterationSpec& operator[](const int dimension) const { return dim_iteration_specs_.at(dimension); } DimIterationSpec& operator[](const int dimension) { return dim_iteration_specs_[dimension]; } - const StorageType& Storage() const { return dim_iteration_specs_; } + // Returns nullptr if not found. + const DimIterationSpec* Find(int dimension) const; + + std::vector GetDimensions() const; + void RemoveEmptyDimensions() { absl::erase_if(dim_iteration_specs_, [](const auto& it) { return it.second.empty(); }); } + bool operator==(const TensorIterationSpec& other) const { + return dim_iteration_specs_ == other.dim_iteration_specs_; + } + + template + friend H AbslHashValue(H h, const TensorIterationSpec& spec) { + return H::combine(std::move(h), spec.dim_iteration_specs_); + } + // Compares physical layouts of tensors ignoring subfragments of dimensions. - bool operator==(const TensorIterationSpec& other) const; + // Checking with this, instead of "==" allows a few more edge cases to be + // fused. + bool IsPhysicallyEquivalent(const TensorIterationSpec& other) const; std::string ToString() const; private: - StorageType dim_iteration_specs_; + // Maps dimensions to DimIterationSpecs. + absl::flat_hash_map dim_iteration_specs_; }; // The details of the Triton fusion / tiling propagation are in a separate @@ -152,7 +216,8 @@ class DimensionOrder { // Tells that two dimension orders describe the same tensor physical layout. bool IsPhysicallyEquivalent(const DimensionOrder& other) const { - return ToTensorIterationSpec() == other.ToTensorIterationSpec(); + return ToTensorIterationSpec().IsPhysicallyEquivalent( + other.ToTensorIterationSpec()); } private: diff --git a/xla/service/gpu/triton_tiling_propagation_test.cc b/xla/service/gpu/triton_tiling_propagation_test.cc new file mode 100644 index 0000000000000..515bffbe0eb64 --- /dev/null +++ b/xla/service/gpu/triton_tiling_propagation_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_tiling_propagation.h" + +#include + +#include +#include "xla/tests/hlo_test_base.h" + +namespace xla::gpu { +namespace { + +using TritonTilingPropagationTest = HloTestBase; +using triton_fusion::DimensionOrder; + +DimensionOrder FromFragments(DimensionOrder::Fragments fragments) { + DimensionOrder dim_order; + DimensionOrder::Fragments& tensor_fragments_order = + dim_order.TensorFragmentsOrder(); + DimensionOrder::FragmentOrders& dim_fragments_orders = + dim_order.DimFragmentsOrders(); + for (const DimensionOrder::Fragment& fragment : fragments) { + tensor_fragments_order.push_back(fragment); + dim_fragments_orders[fragment.dst_dim_number()].push_back( + tensor_fragments_order.size()); + } + return dim_order; +} + +TEST_F( + TritonTilingPropagationTest, + DimensionOrdersRemainPhysicallyEquivalentAfterInsertingTrivialDimensions) { + DimensionOrder::Fragment fragment_1(/*dst_dim_number=*/0, /*count=*/97); + DimensionOrder::Fragment fragment_2(/*dst_dim_number=*/0, /*count=*/1); + DimensionOrder dimension_order_1 = FromFragments({fragment_1, fragment_2}); + + DimensionOrder::Fragment fragment_3(/*dst_dim_number=*/0, /*count=*/97); + DimensionOrder::Fragment fragment_4(/*dst_dim_number=*/1, /*count=*/1); + DimensionOrder dimension_order_2 = FromFragments({fragment_3, fragment_4}); + + // They should be equivalent because fragment_2 and fragment_4 both have count + // 1, so they don't affect the physical representation. + EXPECT_TRUE(dimension_order_1.IsPhysicallyEquivalent(dimension_order_2)); +} + +TEST_F( + TritonTilingPropagationTest, + IterationSpecsRemainPhysicallyEquivalentAfterInsertingTrivialDimensions) { + TensorIterationSpec::IterationSpecFragment fragment_1 = { + /*stride=*/1, /*count=*/97, /*slice_start=*/0, /*sliced_count=*/97, + /*subfragments=*/{97}}; + TensorIterationSpec spec_1; + spec_1[0].push_back(fragment_1); + + TensorIterationSpec::IterationSpecFragment fragment_2 = { + /*stride=*/1, /*count=*/97, /*slice_start=*/0, /*sliced_count=*/97, + /*subfragments=*/{97}}; + TensorIterationSpec::IterationSpecFragment fragment_3 = { + /*stride=*/97, /*count=*/1, /*slice_start=*/0, /*sliced_count=*/1, + /*subfragments=*/{1}}; + TensorIterationSpec spec_2; + spec_2[0].push_back(fragment_2); + spec_2[1].push_back(fragment_3); + + // spec_2's extra dimension is degenerate, so it should have the same physical + // representation as spec_1. + EXPECT_TRUE(spec_1.IsPhysicallyEquivalent(spec_2)); +} + +TEST_F(TritonTilingPropagationTest, + DimensionsShouldNotBeRemovedByToTensorIterationSpec) { + DimensionOrder::Fragment fragment_0(/*dst_dim_number=*/0, /*count=*/97); + DimensionOrder::Fragment fragment_1(/*dst_dim_number=*/1, /*count=*/1); + DimensionOrder dimension_order = FromFragments({fragment_0, fragment_1}); + TensorIterationSpec spec = dimension_order.ToTensorIterationSpec(); + const TensorIterationSpec::DimIterationSpec* dim_spec_0 = spec.Find(0); + EXPECT_NE(dim_spec_0, nullptr); + EXPECT_EQ(dim_spec_0->size(), 1); + EXPECT_EQ(dim_spec_0->at(0).count, 97); + + const TensorIterationSpec::DimIterationSpec* dim_spec_1 = spec.Find(1); + EXPECT_NE(dim_spec_1, nullptr); + EXPECT_EQ(dim_spec_1->size(), 1); + EXPECT_EQ(dim_spec_1->at(0).count, 1); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/service/gpu/variadic_op_splitter.cc b/xla/service/gpu/variadic_op_splitter.cc index 4d91f3833afd4..f1371575b7d62 100644 --- a/xla/service/gpu/variadic_op_splitter.cc +++ b/xla/service/gpu/variadic_op_splitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,15 +15,21 @@ limitations under the License. #include "xla/service/gpu/variadic_op_splitter.h" +#include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/statusor.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -35,7 +41,8 @@ namespace { // how big these buffers are. constexpr int32_t kMaxParameters = 128; -StatusOr SplitConcatenate(HloInstruction* concat, HloComputation* comp) { +absl::StatusOr SplitConcatenate(HloInstruction* concat, + HloComputation* comp) { auto operands = concat->operands(); std::vector operands_to_split(operands.begin(), operands.end()); @@ -89,7 +96,7 @@ std::vector GetRelevantVariadicOps(HloComputation* comp) { } // namespace -StatusOr VariadicOpSplitter::Run( +absl::StatusOr VariadicOpSplitter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/gpu/variadic_op_splitter.h b/xla/service/gpu/variadic_op_splitter.h index bba95c0ae179d..4449ce2a0bdcd 100644 --- a/xla/service/gpu/variadic_op_splitter.h +++ b/xla/service/gpu/variadic_op_splitter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ #define XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -30,7 +32,7 @@ class VariadicOpSplitter : public HloModulePass { absl::string_view name() const override { return "variadic-op-splitter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/gpu/variadic_op_splitter_test.cc b/xla/service/gpu/variadic_op_splitter_test.cc index af32fbde0573b..6d7b72eebe0ba 100644 --- a/xla/service/gpu/variadic_op_splitter_test.cc +++ b/xla/service/gpu/variadic_op_splitter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,13 @@ limitations under the License. #include "xla/service/gpu/variadic_op_splitter.h" +#include +#include + +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" diff --git a/xla/service/gpu/variant_visitor.h b/xla/service/gpu/variant_visitor.h new file mode 100644 index 0000000000000..c4ff4aa89b3fd --- /dev/null +++ b/xla/service/gpu/variant_visitor.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_VARIANT_VISITOR_H_ +#define XLA_SERVICE_GPU_VARIANT_VISITOR_H_ + +namespace xla::gpu { +// This structure is used to support C++17 overload pattern as described in +// https://en.cppreference.com/w/cpp/utility/variant/visit +// +// TODO(b/319202112): Replace with absl::Overload once abs lts_2024_XXX is +// tagged. +template +struct VariantVisitor : Ts... { + using Ts::operator()...; +}; +template +VariantVisitor(Ts...) -> VariantVisitor; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_VARIANT_VISITOR_H_ diff --git a/xla/service/gpu/while_thunk.cc b/xla/service/gpu/while_thunk.cc deleted file mode 100644 index 7c096227d8dbc..0000000000000 --- a/xla/service/gpu/while_thunk.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/while_thunk.h" - -#include -#include - -#include "xla/util.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace gpu { - -WhileThunk::WhileThunk( - ThunkInfo thunk_info, - const BufferAllocation::Slice& condition_result_buffer_index, - std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence) - : Thunk(Kind::kWhile, thunk_info), - condition_result_buffer_index_(condition_result_buffer_index), - condition_thunk_sequence_(std::make_unique( - ThunkInfo(thunk_info.op), std::move(*condition_thunk_sequence))), - body_thunk_sequence_(std::make_unique( - ThunkInfo(thunk_info.op), std::move(*body_thunk_sequence))) {} - -Status WhileThunk::Initialize(se::StreamExecutor* executor, - ExecutableSource src) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executor, src)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executor, src)); - return OkStatus(); -} - -Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { - auto& stream = *params.stream; - - se::DeviceMemoryBase condition_result_data = - params.buffer_allocations->GetDeviceAddress( - condition_result_buffer_index_); - - while (true) { - // Invoke thunk sequence for while 'condition' computation. - VLOG(3) << "Executing condition computation"; - TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); - - // Copy the result of condition computation and break the loop if 'false'. - bool condition_result; - stream.ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - VLOG(3) << "condition_result = " << condition_result; - Status block_status = stream.BlockHostUntilDone(); - if (!block_status.ok()) { - return InternalError( - "Failed to complete all kernels launched on stream %p: %s", &stream, - block_status.message()); - } - - if (!condition_result) { - break; - } - - VLOG(3) << "Executing body computation"; - // Invoke thunk sequence for while 'body' computation. - TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - } - return OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/xla/service/gpu/while_thunk.h b/xla/service/gpu/while_thunk.h deleted file mode 100644 index 9484446d75aa3..0000000000000 --- a/xla/service/gpu/while_thunk.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_WHILE_THUNK_H_ -#define XLA_SERVICE_GPU_WHILE_THUNK_H_ - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/sequential_thunk.h" -#include "xla/service/gpu/thunk.h" -#include "xla/stream_executor/stream_executor.h" - -namespace xla { -namespace gpu { - -// WhileThunk implements the while instruction on GPU by invoking a thunk -// sequence for the while 'condition' computation, and (conditionally) another -// thunk sequence for the while 'body' computation. WhileThunk assumes that -// buffers for the following set of while-related instructions share the same -// allocation: -// init, condition.parameter, body.parameter, body.root, while.result -// WhileThunk synchronizes the stream to test the result of the 'condition' -// computation. -class WhileThunk : public Thunk { - public: - // Constructs a WhileThunk to compute while instruction 'hlo'. - WhileThunk(ThunkInfo thunk_info, - const BufferAllocation::Slice& condition_result_buffer_index, - std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence); - WhileThunk(const WhileThunk&) = delete; - WhileThunk& operator=(const WhileThunk&) = delete; - - Status Initialize(se::StreamExecutor* executor, - ExecutableSource src) override; - Status ExecuteOnStream(const ExecuteParams& params) override; - - SequentialThunk* condition_thunk_sequence() { - return condition_thunk_sequence_.get(); - } - SequentialThunk* body_thunk_sequence() { return body_thunk_sequence_.get(); } - - private: - const BufferAllocation::Slice condition_result_buffer_index_; - std::unique_ptr condition_thunk_sequence_; - std::unique_ptr body_thunk_sequence_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_WHILE_THUNK_H_ diff --git a/xla/service/gpu/while_transformer_test.cc b/xla/service/gpu/while_transformer_test.cc index 079581d3a1004..a5bf72cb4d8b6 100644 --- a/xla/service/gpu/while_transformer_test.cc +++ b/xla/service/gpu/while_transformer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/copy_insertion.h" -#include "xla/service/gpu/instruction_fusion.h" -#include "xla/service/hlo_verifier.h" +#include +#include + +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" #include "xla/service/while_loop_analysis.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace { diff --git a/xla/service/gpu/xfeed_queue.h b/xla/service/gpu/xfeed_queue.h index 1cb8dec6d5f91..18f63a934a17c 100644 --- a/xla/service/gpu/xfeed_queue.h +++ b/xla/service/gpu/xfeed_queue.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/gpu/xla_executor_state.h b/xla/service/gpu/xla_executor_state.h index 8714d0b24bb81..60aabcd65ff12 100644 --- a/xla/service/gpu/xla_executor_state.h +++ b/xla/service/gpu/xla_executor_state.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,11 @@ limitations under the License. #include +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" #include "xla/service/gpu/infeed_manager.h" #include "xla/service/gpu/outfeed_manager.h" +#include "xla/stream_executor/stream_executor_pimpl.h" // Defines XLA:GPU specific state that will be attached to the GpuExecutor. diff --git a/xla/service/gpu_compilation_environment.cc b/xla/service/gpu_compilation_environment.cc index 8f5b0b3c31d12..d598c02df3d5f 100644 --- a/xla/service/gpu_compilation_environment.cc +++ b/xla/service/gpu_compilation_environment.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,12 +23,13 @@ limitations under the License. #include "absl/strings/str_join.h" #include "xla/parse_flags_from_env.h" #include "xla/service/compilation_environments.h" +#include "xla/status.h" #include "xla/statusor.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace xla { @@ -48,7 +49,7 @@ void InitializeFlagsForGpuCompEnv(std::vector* flag_list, gpu_comp_env->dummy_flag(), "Dummy flag to demonstrate the flow")); } -StatusOr CreateGpuCompEnvFromFlagStrings( +absl::StatusOr CreateGpuCompEnvFromFlagStrings( std::vector& flags, bool strict) { GpuCompilationEnvironment gpu_comp_env; std::vector flag_objects; @@ -61,7 +62,7 @@ StatusOr CreateGpuCompEnvFromFlagStrings( return gpu_comp_env; } -StatusOr CreateGpuCompEnvFromEnvVar() { +absl::StatusOr CreateGpuCompEnvFromEnvVar() { GpuCompilationEnvironment env; std::vector flag_objects; InitializeFlagsForGpuCompEnv(&flag_objects, &env); @@ -78,47 +79,51 @@ GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues() { return env; } -namespace { - -// Implement a CompilationEnvironment::ProcessNewEnvFn for -// GpuCompilationEnvironment, so that we can add GpuCompilationEnvironments -// to CompilationEnvironments. -// -// The implementation returns Default env if one doesn't exist already. -// NOLINTNEXTLINE -StatusOr> -ProcessNewGpuCompilationEnvironment( - std::unique_ptr env) { // NOLINT - if (!env) { - env = std::make_unique(); - } +Status InitializeMissingFieldsFromXLAFlags(GpuCompilationEnvironment& env) { TF_ASSIGN_OR_RETURN(GpuCompilationEnvironment from_env, CreateGpuCompEnvFromEnvVar()); auto default_env = CreateGpuCompEnvWithDefaultValues(); - auto reflection = env->GetReflection(); + auto reflection = env.GetReflection(); auto reflection_from_env = from_env.GetReflection(); auto descriptor = GpuCompilationEnvironment::descriptor(); std::vector missing_fields; for (int j = 0; j < descriptor->field_count(); ++j) { const tsl::protobuf::FieldDescriptor* field = descriptor->field(j); - if (reflection->HasField(*env, field) && + if (reflection->HasField(env, field) && reflection_from_env->HasField(from_env, field)) { return InvalidArgument( "Flag %s is set in both XLA_FLAGS env var and " "GpuCompilationEnvironment.", field->name()); - } else if (!reflection->HasField(*env, field) && + } else if (!reflection->HasField(env, field) && !reflection_from_env->HasField(from_env, field)) { missing_fields.push_back(field); } } - env->MergeFrom(from_env); + env.MergeFrom(from_env); if (!missing_fields.empty()) { - reflection->SwapFields(env.get(), &default_env, missing_fields); + reflection->SwapFields(&env, &default_env, missing_fields); + } + return OkStatus(); +} + +namespace { + +// Implement a CompilationEnvironment::ProcessNewEnvFn for +// GpuCompilationEnvironment, so that we can add GpuCompilationEnvironments +// to CompilationEnvironments. +// +// The implementation returns Empty env if one doesn't exist already. +// NOLINTNEXTLINE +absl::StatusOr> +ProcessNewGpuCompilationEnvironment( + std::unique_ptr env) { // NOLINT + if (!env) { + env = std::make_unique(); } return env; } diff --git a/xla/service/gpu_compilation_environment.h b/xla/service/gpu_compilation_environment.h index 99d93c185042c..23f2a30273c8a 100644 --- a/xla/service/gpu_compilation_environment.h +++ b/xla/service/gpu_compilation_environment.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,12 +23,16 @@ limitations under the License. namespace xla { -StatusOr CreateGpuCompEnvFromFlagStrings( +absl::StatusOr CreateGpuCompEnvFromFlagStrings( std::vector& flags, bool strict); -StatusOr CreateGpuCompEnvFromEnvVar(); +absl::StatusOr CreateGpuCompEnvFromEnvVar(); GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues(); +// Returns non-OK status if XLA_FLAGS env var has malformed values or +// if it has conflict with the GpuCompilationEnvironment proto +Status InitializeMissingFieldsFromXLAFlags(GpuCompilationEnvironment& env); + } // namespace xla #endif // XLA_SERVICE_GPU_COMPILATION_ENVIRONMENT_H_ diff --git a/xla/service/gpu_compilation_environment_test.cc b/xla/service/gpu_compilation_environment_test.cc index 22efaa4a317d6..85dde778f8aee 100644 --- a/xla/service/gpu_compilation_environment_test.cc +++ b/xla/service/gpu_compilation_environment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,7 @@ limitations under the License. #include "xla/service/gpu_compilation_environment.h" -#include #include -#include #include #include @@ -96,54 +94,50 @@ TEST(CreateGpuCompEnvFromEnvVarTest, InvalidFlagValue) { StatusIs(tsl::error::INVALID_ARGUMENT)); } -TEST(ProcessNewEnvTest, BothProtoAndEnvVarUnset) { +TEST(InitializeMissingFieldsFromXLAFlagsTest, BothProtoAndEnvVarUnset) { set_xla_flags_env_var(""); - CompilationEnvironments envs; - - const auto& env = envs.GetEnv(); + GpuCompilationEnvironment env; + TF_ASSERT_OK(InitializeMissingFieldsFromXLAFlags(env)); EXPECT_EQ(env.dummy_flag(), 1); } -TEST(ProcessNewEnvTest, ProtoSetButEnvVarUnset) { +TEST(InitializeMissingFieldsFromXLAFlagsTest, ProtoSetButEnvVarUnset) { set_xla_flags_env_var(""); - CompilationEnvironments envs; - { - auto env = std::make_unique(); - env->set_dummy_flag(2); - TF_ASSERT_OK(envs.AddEnv(std::move(env))); - } - const auto& env = envs.GetEnv(); + GpuCompilationEnvironment env; + env.set_dummy_flag(2); + + TF_ASSERT_OK(InitializeMissingFieldsFromXLAFlags(env)); EXPECT_EQ(env.dummy_flag(), 2); } -TEST(ProcessNewEnvTest, ProtoUnsetButEnvVarSet) { +TEST(InitializeMissingFieldsFromXLAFlagsTest, ProtoUnsetButEnvVarSet) { set_xla_flags_env_var("--dummy_flag=4"); - CompilationEnvironments envs; - const auto& env = envs.GetEnv(); + + GpuCompilationEnvironment env; + TF_ASSERT_OK(InitializeMissingFieldsFromXLAFlags(env)); EXPECT_EQ(env.dummy_flag(), 4); } -TEST(ProcessNewEnvTest, BothProtoAndEnvVarSetButNoConflict) { +TEST(InitializeMissingFieldsFromXLAFlagsTest, + BothProtoAndEnvVarSetButNoConflict) { set_xla_flags_env_var("--dummy_flag=4"); CompilationEnvironments envs; - { - auto env = std::make_unique(); - TF_ASSERT_OK(envs.AddEnv(std::move(env))); - } - const auto& env = envs.GetEnv(); + GpuCompilationEnvironment env; + TF_ASSERT_OK(InitializeMissingFieldsFromXLAFlags(env)); EXPECT_EQ(env.dummy_flag(), 4); } -TEST(ProcessNewEnvTest, BothProtoAndEnvVarSetWithConflict) { +TEST(InitializeMissingFieldsFromXLAFlagsTest, + BothProtoAndEnvVarSetWithConflict) { set_xla_flags_env_var("--dummy_flag=4"); CompilationEnvironments envs; - auto env = std::make_unique(); - env->set_dummy_flag(2); - EXPECT_THAT(envs.AddEnv(std::move(env)), + GpuCompilationEnvironment env; + env.set_dummy_flag(2); + EXPECT_THAT(InitializeMissingFieldsFromXLAFlags(env), StatusIs(tsl::error::INVALID_ARGUMENT)); } diff --git a/xla/service/graphcycles/BUILD b/xla/service/graphcycles/BUILD index 215bfe438778e..6779d9b0c0496 100644 --- a/xla/service/graphcycles/BUILD +++ b/xla/service/graphcycles/BUILD @@ -1,10 +1,10 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") -load("@tsl//tsl:tsl.bzl", "set_external_visibility") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tensorflow/compiler:__subpackages__", ]), licenses = ["notice"], @@ -41,6 +41,7 @@ xla_cc_test( deps = [ ":graphcycles", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/random", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", @@ -53,9 +54,7 @@ xla_cc_test( srcs = ["ordered_set_test.cc"], deps = [ ":ordered_set", - "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/graphcycles/graphcycles.cc b/xla/service/graphcycles/graphcycles.cc index 1cdf406204458..c8308d4d2e13f 100644 --- a/xla/service/graphcycles/graphcycles.cc +++ b/xla/service/graphcycles/graphcycles.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,11 +32,15 @@ limitations under the License. #include "xla/service/graphcycles/graphcycles.h" #include +#include +#include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "xla/service/graphcycles/ordered_set.h" #include "tsl/platform/logging.h" @@ -47,17 +51,19 @@ namespace { using NodeSet = absl::flat_hash_set; using OrderedNodeSet = OrderedSet; -template -struct VecStruct { - typedef absl::InlinedVector type; -}; -template -using Vec = typename VecStruct::type; - struct Node { int32_t rank; // rank number assigned by Pearce-Kelly algorithm + // Note (ecg@): the padding between these two fields bothered me, so I tried + // the following alternatives: + // - Separate bitmap to track visited[]. + // - Separate std::vector visited. + // - Tagged top or bottom bit of "rank" to keep track of "visited". + // However, keeping the bool here (despite the padding) achieves the best + // performance for the IsReachableNonConst microbenchmark. bool visited; // Temporary marker used by depth-first-search - void* data; // User-supplied data +}; + +struct NodeIO { OrderedNodeSet in; // List of immediate predecessor nodes in graph OrderedNodeSet out; // List of immediate successor nodes in graph }; @@ -65,40 +71,43 @@ struct Node { } // namespace struct GraphCycles::Rep { - Vec nodes_; - Vec free_nodes_; // Indices for unused entries in nodes_ + std::vector nodes_; + std::vector node_io_; + std::vector free_nodes_; // Indices for unused entries in nodes_ // Temporary state. - Vec deltaf_; // Results of forward DFS - Vec deltab_; // Results of backward DFS - Vec list_; // All nodes to reprocess - Vec merged_; // Rank values to assign to list_ entries - Vec + std::vector deltaf_; // Results of forward DFS + std::vector deltab_; // Results of backward DFS + std::vector list_; // All nodes to reprocess + std::vector merged_; // Rank values to assign to list_ entries + std::vector stack_; // Emulates recursion stack when doing depth first search + + // User-supplied data. Stored outside of Node since it is rarely accessed. + std::vector node_data_; }; GraphCycles::GraphCycles() : rep_(new Rep) {} +// Define the destructor here because Rep is also defined in this file. GraphCycles::~GraphCycles() { - for (Vec::size_type i = 0; i < rep_->nodes_.size(); i++) { - delete rep_->nodes_[i]; - } delete rep_; } bool GraphCycles::CheckInvariants() const { Rep* r = rep_; NodeSet ranks; // Set of ranks seen so far. - for (Vec::size_type x = 0; x < r->nodes_.size(); x++) { - Node* nx = r->nodes_[x]; + for (size_t x = 0; x < r->nodes_.size(); x++) { + Node* nx = &r->nodes_[x]; if (nx->visited) { LOG(FATAL) << "Did not clear visited marker on node " << x; } if (!ranks.insert(nx->rank).second) { LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank; } - for (int32_t y : nx->out.GetSequence()) { - Node* ny = r->nodes_[y]; + NodeIO* nx_io = &r->node_io_[x]; + for (int32_t y : nx_io->out.GetSequence()) { + Node* ny = &r->nodes_[y]; if (nx->rank >= ny->rank) { LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment " << nx->rank << "->" << ny->rank; @@ -110,29 +119,30 @@ bool GraphCycles::CheckInvariants() const { int32_t GraphCycles::NewNode() { if (rep_->free_nodes_.empty()) { - Node* n = new Node; - n->visited = false; - n->data = nullptr; - n->rank = rep_->nodes_.size(); - rep_->nodes_.push_back(n); - return n->rank; + Node n; + n.visited = false; + n.rank = rep_->nodes_.size(); + rep_->nodes_.emplace_back(n); + rep_->node_io_.emplace_back(); + rep_->node_data_.push_back(nullptr); + return n.rank; } else { // Preserve preceding rank since the set of ranks in use must be // a permutation of [0,rep_->nodes_.size()-1]. int32_t r = rep_->free_nodes_.back(); - rep_->nodes_[r]->data = nullptr; rep_->free_nodes_.pop_back(); + rep_->node_data_[r] = nullptr; return r; } } void GraphCycles::RemoveNode(int32_t node) { - Node* x = rep_->nodes_[node]; + NodeIO* x = &rep_->node_io_[node]; for (int32_t y : x->out.GetSequence()) { - rep_->nodes_[y]->in.Erase(node); + rep_->node_io_[y].in.Erase(node); } for (int32_t y : x->in.GetSequence()) { - rep_->nodes_[y]->out.Erase(node); + rep_->node_io_[y].out.Erase(node); } x->in.Clear(); x->out.Clear(); @@ -140,20 +150,20 @@ void GraphCycles::RemoveNode(int32_t node) { } void* GraphCycles::GetNodeData(int32_t node) const { - return rep_->nodes_[node]->data; + return rep_->node_data_[node]; } void GraphCycles::SetNodeData(int32_t node, void* data) { - rep_->nodes_[node]->data = data; + rep_->node_data_[node] = data; } bool GraphCycles::HasEdge(int32_t x, int32_t y) const { - return rep_->nodes_[x]->out.Contains(y); + return rep_->node_io_[x].out.Contains(y); } void GraphCycles::RemoveEdge(int32_t x, int32_t y) { - rep_->nodes_[x]->out.Erase(y); - rep_->nodes_[y]->in.Erase(x); + rep_->node_io_[x].out.Erase(y); + rep_->node_io_[y].in.Erase(x); // No need to update the rank assignment since a previous valid // rank assignment remains valid after an edge deletion. } @@ -161,23 +171,26 @@ void GraphCycles::RemoveEdge(int32_t x, int32_t y) { static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound); static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound); static void Reorder(GraphCycles::Rep* r); -static void Sort(const Vec&, Vec* delta); -static void MoveToList(GraphCycles::Rep* r, Vec* src, - Vec* dst); -static void ClearVisitedBits(GraphCycles::Rep* r, const Vec& nodes); +static void Sort(absl::Span, std::vector* delta); +static void MoveToList(GraphCycles::Rep* r, std::vector* src, + std::vector* dst); +static void ClearVisitedBits(GraphCycles::Rep* r, + absl::Span visited_indices); bool GraphCycles::InsertEdge(int32_t x, int32_t y) { if (x == y) return false; Rep* r = rep_; - Node* nx = r->nodes_[x]; - if (!nx->out.Insert(y)) { + NodeIO* nx_io = &r->node_io_[x]; + if (!nx_io->out.Insert(y)) { // Edge already exists. return true; } - Node* ny = r->nodes_[y]; - ny->in.Insert(x); + NodeIO* ny_io = &r->node_io_[y]; + ny_io->in.Insert(x); + Node* nx = &r->nodes_[x]; + Node* ny = &r->nodes_[y]; if (nx->rank <= ny->rank) { // New edge is consistent with existing rank assignment. return true; @@ -187,8 +200,8 @@ bool GraphCycles::InsertEdge(int32_t x, int32_t y) { // We only need to consider nodes that fall in the range [ny->rank,nx->rank]. if (!ForwardDFS(r, y, nx->rank)) { // Found a cycle. Undo the insertion and tell caller. - nx->out.Erase(y); - ny->in.Erase(x); + nx_io->out.Erase(y); + ny_io->in.Erase(x); // Since we do not call Reorder() on this path, clear any visited // markers left by ForwardDFS. ClearVisitedBits(r, r->deltaf_); @@ -208,14 +221,15 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) { while (!r->stack_.empty()) { n = r->stack_.back(); r->stack_.pop_back(); - Node* nn = r->nodes_[n]; + Node* nn = &r->nodes_[n]; if (nn->visited) continue; nn->visited = true; r->deltaf_.push_back(n); - for (auto w : nn->out.GetSequence()) { - Node* nw = r->nodes_[w]; + NodeIO* nn_io = &r->node_io_[n]; + for (auto w : nn_io->out.GetSequence()) { + Node* nw = &r->nodes_[w]; if (nw->rank == upper_bound) { return false; // Cycle } @@ -234,14 +248,15 @@ static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) { while (!r->stack_.empty()) { n = r->stack_.back(); r->stack_.pop_back(); - Node* nn = r->nodes_[n]; + Node* nn = &r->nodes_[n]; if (nn->visited) continue; nn->visited = true; r->deltab_.push_back(n); - for (auto w : nn->in.GetSequence()) { - Node* nw = r->nodes_[w]; + NodeIO* nn_io = &r->node_io_[n]; + for (auto w : nn_io->in.GetSequence()) { + Node* nw = &r->nodes_[w]; if (!nw->visited && lower_bound < nw->rank) { r->stack_.push_back(w); } @@ -264,36 +279,31 @@ static void Reorder(GraphCycles::Rep* r) { r->deltaf_.end(), r->merged_.begin()); // Assign the ranks in order to the collected list. - for (Vec::size_type i = 0; i < r->list_.size(); i++) { - r->nodes_[r->list_[i]]->rank = r->merged_[i]; + for (size_t i = 0; i < r->list_.size(); i++) { + r->nodes_[r->list_[i]].rank = r->merged_[i]; } } -static void Sort(const Vec& nodes, Vec* delta) { - struct ByRank { - const Vec* nodes; - bool operator()(int32_t a, int32_t b) const { - return (*nodes)[a]->rank < (*nodes)[b]->rank; - } - }; - ByRank cmp; - cmp.nodes = &nodes; - std::sort(delta->begin(), delta->end(), cmp); +static void Sort(absl::Span nodes, std::vector* delta) { + std::sort(delta->begin(), delta->end(), [&](int32_t a, int32_t b) { + return nodes[a].rank < nodes[b].rank; + }); } -static void MoveToList(GraphCycles::Rep* r, Vec* src, - Vec* dst) { - for (Vec::size_type i = 0; i < src->size(); i++) { +static void MoveToList(GraphCycles::Rep* r, std::vector* src, + std::vector* dst) { + for (size_t i = 0; i < src->size(); i++) { int32_t w = (*src)[i]; - (*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank - r->nodes_[w]->visited = false; // Prepare for future DFS calls + (*src)[i] = r->nodes_[w].rank; // Replace src entry with its rank + r->nodes_[w].visited = false; // Prepare for future DFS calls dst->push_back(w); } } -static void ClearVisitedBits(GraphCycles::Rep* r, const Vec& nodes) { - for (Vec::size_type i = 0; i < nodes.size(); i++) { - r->nodes_[nodes[i]]->visited = false; +static void ClearVisitedBits(GraphCycles::Rep* r, + absl::Span visited_indices) { + for (auto index : visited_indices) { + r->nodes_[index].visited = false; } } @@ -327,7 +337,7 @@ int GraphCycles::FindPath(int32_t x, int32_t y, int max_path_len, return path_len; } - for (auto w : r->nodes_[n]->out.GetSequence()) { + for (auto w : r->node_io_[n].out.GetSequence()) { if (seen.insert(w).second) { r->stack_.push_back(w); } @@ -344,8 +354,8 @@ bool GraphCycles::IsReachable(int32_t x, int32_t y) const { bool GraphCycles::IsReachableNonConst(int32_t x, int32_t y) { if (x == y) return true; Rep* r = rep_; - Node* nx = r->nodes_[x]; - Node* ny = r->nodes_[y]; + Node* nx = &r->nodes_[x]; + Node* ny = &r->nodes_[y]; if (nx->rank >= ny->rank) { // x cannot reach y since it is after it in the topological ordering @@ -380,29 +390,29 @@ std::optional GraphCycles::ContractEdge(int32_t a, int32_t b) { return std::nullopt; } - if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() > - rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) { + if (rep_->node_io_[b].in.Size() + rep_->node_io_[b].out.Size() > + rep_->node_io_[a].in.Size() + rep_->node_io_[a].out.Size()) { // Swap "a" and "b" to minimize copying. std::swap(a, b); } - Node* nb = rep_->nodes_[b]; - OrderedNodeSet out = std::move(nb->out); - OrderedNodeSet in = std::move(nb->in); + NodeIO* nb_io = &rep_->node_io_[b]; + OrderedNodeSet out = std::move(nb_io->out); + OrderedNodeSet in = std::move(nb_io->in); for (int32_t y : out.GetSequence()) { - rep_->nodes_[y]->in.Erase(b); + rep_->node_io_[y].in.Erase(b); } for (int32_t y : in.GetSequence()) { - rep_->nodes_[y]->out.Erase(b); + rep_->node_io_[y].out.Erase(b); } rep_->free_nodes_.push_back(b); - rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size()); + rep_->node_io_[a].out.Reserve(rep_->node_io_[a].out.Size() + out.Size()); for (int32_t y : out.GetSequence()) { InsertEdge(a, y); } - rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size()); + rep_->node_io_[a].in.Reserve(rep_->node_io_[a].in.Size() + in.Size()); for (int32_t y : in.GetSequence()) { InsertEdge(y, a); } @@ -412,11 +422,11 @@ std::optional GraphCycles::ContractEdge(int32_t a, int32_t b) { } absl::Span GraphCycles::Successors(int32_t node) const { - return rep_->nodes_[node]->out.GetSequence(); + return rep_->node_io_[node].out.GetSequence(); } absl::Span GraphCycles::Predecessors(int32_t node) const { - return rep_->nodes_[node]->in.GetSequence(); + return rep_->node_io_[node].in.GetSequence(); } std::vector GraphCycles::SuccessorsCopy(int32_t node) const { @@ -430,11 +440,11 @@ std::vector GraphCycles::PredecessorsCopy(int32_t node) const { } namespace { -void SortInPostOrder(absl::Span nodes, +void SortInPostOrder(absl::Span nodes, std::vector* to_sort) { absl::c_sort(*to_sort, [&](int32_t a, int32_t b) { - DCHECK(a == b || nodes[a]->rank != nodes[b]->rank); - return nodes[a]->rank > nodes[b]->rank; + DCHECK(a == b || nodes[a].rank != nodes[b].rank); + return nodes[a].rank > nodes[b].rank; }); } } // namespace @@ -457,10 +467,8 @@ std::vector GraphCycles::AllNodesInPostOrder() const { } std::string GraphCycles::DebugString() const { - absl::flat_hash_set free_nodes_set; - for (int32_t free_node : rep_->free_nodes_) { - free_nodes_set.insert(free_node); - } + absl::flat_hash_set free_nodes_set(rep_->free_nodes_.begin(), + rep_->free_nodes_.end()); std::string result = "digraph {\n"; for (int i = 0, end = rep_->nodes_.size(); i < end; i++) { @@ -468,7 +476,7 @@ std::string GraphCycles::DebugString() const { continue; } - for (int32_t succ : rep_->nodes_[i]->out.GetSequence()) { + for (int32_t succ : rep_->node_io_[i].out.GetSequence()) { absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n"); } } diff --git a/xla/service/graphcycles/graphcycles.h b/xla/service/graphcycles/graphcycles.h index 778ed446fba70..41970b0beaff8 100644 --- a/xla/service/graphcycles/graphcycles.h +++ b/xla/service/graphcycles/graphcycles.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/graphcycles/graphcycles_test.cc b/xla/service/graphcycles/graphcycles_test.cc index 3b3399432ba0f..0c6fa481b16c3 100644 --- a/xla/service/graphcycles/graphcycles_test.cc +++ b/xla/service/graphcycles/graphcycles_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,11 +17,13 @@ limitations under the License. #include "xla/service/graphcycles/graphcycles.h" +#include #include #include #include #include "absl/container/flat_hash_set.h" +#include "absl/random/random.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -510,7 +512,7 @@ TEST_F(GraphCyclesTest, CanContractEdge) { static void BM_StressTest(::testing::benchmark::State &state) { const int num_nodes = state.range(0); - for (auto s : state) { + while (state.KeepRunningBatch(num_nodes)) { tensorflow::GraphCycles g; int32_t *nodes = new int32_t[num_nodes]; for (int i = 0; i < num_nodes; i++) { @@ -532,7 +534,7 @@ BENCHMARK(BM_StressTest)->Range(2048, 1048576); static void BM_ContractEdge(::testing::benchmark::State &state) { const int num_nodes = state.range(0); - for (auto s : state) { + while (state.KeepRunningBatch(num_nodes)) { state.PauseTiming(); tensorflow::GraphCycles g; std::vector nodes; @@ -553,3 +555,50 @@ static void BM_ContractEdge(::testing::benchmark::State &state) { } } BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000); + +static void BM_IsReachableNonConst(testing::benchmark::State &state) { + const int num_nodes = state.range(0); + + tensorflow::GraphCycles g; + std::vector nodes; + nodes.reserve(num_nodes); + for (int i = 0; i < num_nodes; i++) { + nodes.push_back(g.NewNode()); + } + + // Add forward edges. + absl::BitGen bitgen; + for (int i = 0; i < num_nodes; i++) { + int max = num_nodes - 1 - i; + if (max == 0) break; + constexpr int branch_factor = 2; + for (int b = 0; b < branch_factor; b++) { + int j = i + 1 + absl::Uniform(bitgen, 0, max); + CHECK_LT(j, num_nodes); + CHECK(g.InsertEdge(nodes[i], nodes[j])); + } + } + + auto get_random_node = [&]() { + return nodes[absl::Uniform(bitgen, 0, num_nodes)]; + }; + + uint32_t src, dst; + int i = 0; + for (auto s : state) { + if (i % 256 == 0) { + src = get_random_node(); + dst = get_random_node(); + } + bool reachable = g.IsReachableNonConst(src, dst); + benchmark::DoNotOptimize(reachable); + i++; + } +} +BENCHMARK(BM_IsReachableNonConst) + ->Arg(10) + ->Arg(50) + ->Arg(100) + ->Arg(200) + ->Arg(1000) + ->Arg(30000); diff --git a/xla/service/graphcycles/ordered_set.h b/xla/service/graphcycles/ordered_set.h index 16258a395ae43..b5e0e1fbd0efc 100644 --- a/xla/service/graphcycles/ordered_set.h +++ b/xla/service/graphcycles/ordered_set.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/graphcycles/ordered_set_test.cc b/xla/service/graphcycles/ordered_set_test.cc index d4e59086b2e4f..4845cdd7be9fa 100644 --- a/xla/service/graphcycles/ordered_set_test.cc +++ b/xla/service/graphcycles/ordered_set_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,7 @@ limitations under the License. #include "xla/service/graphcycles/ordered_set.h" -#include "tsl/platform/logging.h" #include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" namespace tensorflow { namespace { diff --git a/xla/service/heap_simulator/BUILD b/xla/service/heap_simulator/BUILD new file mode 100644 index 0000000000000..f51576e24bc05 --- /dev/null +++ b/xla/service/heap_simulator/BUILD @@ -0,0 +1,91 @@ +# Description: +# XLA Heap simulator. + +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla:xla.bzl", + "xla_cc_test", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "allocation_block", + srcs = ["allocation_block.cc"], + hdrs = ["allocation_block.h"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "heap_simulator", + srcs = ["heap_simulator.cc"], + hdrs = ["heap_simulator.h"], + deps = [ + ":allocation_block", + "//xla:comparison_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:buffer_value", + "//xla/service:buffer_value_containers", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_ordering", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "//xla/service:time_utils", + "//xla/service:tuple_points_to_analysis", + "//xla/service/memory_space_assignment:repacking", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "heap_simulator_test", + srcs = ["heap_simulator_test.cc"], + deps = [ + ":allocation_block", + ":heap_simulator", + "//xla:literal", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_value", + "//xla/service:hlo_ordering", + "//xla/service:hlo_parser", + "//xla/service:hlo_value", + "//xla/service:tuple_points_to_analysis", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + ], +) diff --git a/xla/service/heap_simulator/allocation_block.cc b/xla/service/heap_simulator/allocation_block.cc new file mode 100644 index 0000000000000..8b776d7e87d0e --- /dev/null +++ b/xla/service/heap_simulator/allocation_block.cc @@ -0,0 +1,119 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/heap_simulator/allocation_block.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tsl/platform/logging.h" + +namespace xla { + +std::string AllocatedSlice::ToString() const { + return absl::StrCat("{ size: ", size, ", offset: ", offset, + ", inclusive_start_time: ", inclusive_start_time, " }"); +} + +std::tuple AllocatedSlice::ToTuple() const { + return std::make_tuple(size, offset, inclusive_start_time); +} + +bool AllocatedSlice::operator==(const AllocatedSlice& rhs) const { + return ToTuple() == rhs.ToTuple(); +} + +std::vector SlicedAllocationData::SizesSortedByOffset() const { + std::vector sizes_sorted_by_offset; + sizes_sorted_by_offset.reserve(slices_sorted_by_offset.size()); + absl::c_for_each(slices_sorted_by_offset, + [&sizes_sorted_by_offset](const AllocatedSlice& slice) { + sizes_sorted_by_offset.push_back(slice.size); + }); + return sizes_sorted_by_offset; +} + +std::vector SlicedAllocationData::SortedInclusiveStartTimes() const { + std::vector sorted_inclusive_start_times; + sorted_inclusive_start_times.reserve(slices_sorted_by_offset.size()); + absl::c_for_each(slices_sorted_by_offset, [&sorted_inclusive_start_times]( + const AllocatedSlice& slice) { + sorted_inclusive_start_times.push_back(slice.inclusive_start_time); + }); + absl::c_sort(sorted_inclusive_start_times); + return sorted_inclusive_start_times; +} + +std::string SlicedAllocationData::ToString() const { + return absl::StrCat( + "{ slices_sorted_by_offset: [ ", + absl::StrJoin(slices_sorted_by_offset, ", ", + [](std::string* out, const AllocatedSlice& slice) { + absl::StrAppend(out, slice.ToString()); + }), + " ] }"); +} + +bool SlicedAllocationData::operator==(const SlicedAllocationData& rhs) const { + return slices_sorted_by_offset == rhs.slices_sorted_by_offset; +} + +std::string AllocationBlock::ToString() const { + std::string original_slicing_str; + if (original_slice_data.has_value()) { + original_slicing_str = absl::StrCat("; original_slice_data: ", + original_slice_data->ToString()); + } + std::string repacked_slicing_str; + if (repacked_slice_data.has_value()) { + repacked_slicing_str = absl::StrCat("; repacked_slice_data: ", + repacked_slice_data->ToString()); + } + return absl::StrCat("[", inclusive_start_time, ", ", end_time, + "]; size: ", size, "; offset: ", offset, + "; initial offset: ", initial_offset, + "; # colocations: ", GetColocationsCount(), + original_slicing_str, repacked_slicing_str); +} + +int AllocationBlock::GetColocationsCount() const { + int count = 1; + for (const AllocationBlock* colocated = next_colocated; colocated != this; + colocated = colocated->next_colocated, ++count) { + CHECK_NE(colocated, nullptr); + } + return count; +} + +std::vector AllocationBlock::GetColocations() { + std::vector colocations{this}; + for (AllocationBlock* colocated = next_colocated; colocated != this; + colocated = colocated->next_colocated) { + CHECK_NE(colocated, nullptr); + colocations.push_back(colocated); + } + return colocations; +} + +bool AllocationBlock::operator<(const AllocationBlock& other) const { + return id < other.id; +} + +} // namespace xla diff --git a/xla/service/heap_simulator/allocation_block.h b/xla/service/heap_simulator/allocation_block.h new file mode 100644 index 0000000000000..f87491e4cb964 --- /dev/null +++ b/xla/service/heap_simulator/allocation_block.h @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(b/319135034): create a heap_simulator sub directory and move +// allocation_block.h/cc to it. + +// This file contains a number of data structures to describe how blocks of +// data are allocated. It is used by Memory Space Assignment repacking to +// understand how data was allocated before the repacking. + +#ifndef XLA_SERVICE_HEAP_SIMULATOR_ALLOCATION_BLOCK_H_ +#define XLA_SERVICE_HEAP_SIMULATOR_ALLOCATION_BLOCK_H_ + +#include +#include +#include +#include +#include + +namespace xla { + +// Data about a slice in a sliced allocation. +struct AllocatedSlice { + int64_t size; + int64_t offset; + int64_t inclusive_start_time; + + std::string ToString() const; + + std::tuple ToTuple() const; + + bool operator==(const AllocatedSlice& rhs) const; +}; + +// Slice data about a sliced allocation. +struct SlicedAllocationData { + std::vector slices_sorted_by_offset; + + std::vector SizesSortedByOffset() const; + + std::vector SortedInclusiveStartTimes() const; + + int64_t num_slices() const { return slices_sorted_by_offset.size(); } + + std::string ToString() const; + + bool operator==(const SlicedAllocationData& rhs) const; +}; + +// A contiguous block of allocation consisting of start and end (logical) +// times, size, and the initial offset. After repacking, if the repacking was +// successful and the allocations were modified, the offset field holds the +// new offset. To support aliased allocations, AllocationBlock also includes a +// pointer to the next colocated AllocationBlock called next_colocated. The +// colocations form a circular singly-linked list. Therefore, next_colocated +// should never be a nullptr (it should point to itself for AllocationBlocks +// without any other colocations). All AllocationBlock objects within the +// colocations must get the same offset. The id should be unique and is used +// to ensure determinism for comparison tie-breaker. +// +// Each AllocationBlock can be treated as an allocation that requires size +// space from start_time to end_time. However, some allocations are really +// composed of slices. In such cases, the repacker can utilize +// the information in the original_slice_data field to achieve an even more +// efficient repacking. +struct AllocationBlock { + int64_t inclusive_start_time; + int64_t end_time; + int64_t size; + int64_t offset; + int64_t initial_offset; + int64_t id; + AllocationBlock* next_colocated; + + // Optional data structures that are used to improve repacking, when an + // allocation is sliced, e.g., from a sliced prefetch. + std::optional original_slice_data; + std::optional repacked_slice_data; + + std::string ToString() const; + + // Returns the number of AllocationBlocks colocated with this (including + // this AllocationBlock). + int GetColocationsCount() const; + + // Returns the AllocationBlocks colocated with this (including this + // AllocationBlock). + std::vector GetColocations(); + + // This is required by BufferIntervalCompare as a tie breaker. Use a unique + // and deterministic id. + bool operator<(const AllocationBlock& other) const; +}; + +} // namespace xla + +#endif // XLA_SERVICE_HEAP_SIMULATOR_ALLOCATION_BLOCK_H_ diff --git a/xla/service/heap_simulator.cc b/xla/service/heap_simulator/heap_simulator.cc similarity index 75% rename from xla/service/heap_simulator.cc rename to xla/service/heap_simulator/heap_simulator.cc index 1496df490595d..0b6223e3344f3 100644 --- a/xla/service/heap_simulator.cc +++ b/xla/service/heap_simulator/heap_simulator.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include #include @@ -43,6 +43,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/time_utils.h" #include "xla/status.h" @@ -53,6 +55,10 @@ namespace xla { using absl::flat_hash_map; using absl::flat_hash_set; +bool IsOdd(int x) { return (x % 2) == 1; } + +bool IsEven(int x) { return (x % 2) == 0; } + HeapSimulator::Chunk HeapSimulator::Chunk::FromOffsetEnd(int64_t offset, int64_t end) { return FromOffsetSize(offset, end - offset); @@ -80,7 +86,7 @@ std::ostream& operator<<(std::ostream& stream, } /*static*/ -StatusOr HeapSimulator::MinimumMemoryForModule( +absl::StatusOr HeapSimulator::MinimumMemoryForModule( const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { if (schedule.empty()) { @@ -104,7 +110,7 @@ StatusOr HeapSimulator::MinimumMemoryForModule( } /*static*/ -StatusOr HeapSimulator::MinimumMemoryForComputation( +absl::StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -118,7 +124,7 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( return result.heap_size; } -StatusOr HeapSimulator::MinimumMemoryForComputation( +absl::StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -132,7 +138,7 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( } /*static*/ -StatusOr> HeapSimulator::Run( +absl::StatusOr> HeapSimulator::Run( std::unique_ptr> algorithm, const HloModule& module, const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { @@ -150,7 +156,7 @@ StatusOr> HeapSimulator::Run( } /*static*/ -StatusOr> HeapSimulator::Run( +absl::StatusOr> HeapSimulator::Run( std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, @@ -171,7 +177,7 @@ StatusOr> HeapSimulator::Run( } /*static*/ -StatusOr> HeapSimulator::Run( +absl::StatusOr> HeapSimulator::Run( std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, @@ -458,8 +464,8 @@ int64_t HeapSimulator::GetBufferSize(const HloValue* buffer) const { return it->second; } -HeapSimulator::Result HeapSimulator::Finish() { - Result result = algorithm_->Finish(); +absl::StatusOr> HeapSimulator::Finish() { + TF_ASSIGN_OR_RETURN(Result result, algorithm_->Finish()); // Post-process the result to add chunks for shared buffers. An empty chunk // map means that either no buffers were allocated, or the heap was only @@ -478,7 +484,8 @@ HeapSimulator::Result HeapSimulator::Finish() { } // Fragmentation is the difference between the actual and ideal sizes. - const Result no_frag_result = no_fragmentation_stats_->Finish(); + TF_ASSIGN_OR_RETURN(const Result no_frag_result, + no_fragmentation_stats_->Finish()); result.fragmentation_size = result.heap_size - no_frag_result.heap_size; // Copy the debug trace we collected to the final result. @@ -550,7 +557,7 @@ void NoFragmentationStatsHeap::Free(const BufferType* buffer, } template -HeapSimulator::Result +StatusOr> NoFragmentationStatsHeap::Finish() { // The result.chunk_map is empty, since we only collect stats, and don't // actually compute chunk assignments. @@ -561,8 +568,11 @@ NoFragmentationStatsHeap::Finish() { template GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( - int64_t alignment, Type type, BufferIntervalCompare buffer_interval_compare) - : alignment_(alignment) { + int64_t alignment, Type type, BufferIntervalCompare buffer_interval_compare, + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type) + : alignment_(alignment), + slice_time_permutation_iteration_type_( + slice_time_permutation_iterator_type) { if (type == kTemporal) { buffer_interval_compare_ = GetTemporalBufferIntervalCompare(); CHECK(buffer_interval_compare == nullptr); @@ -590,6 +600,12 @@ GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() }); } +template +SliceTimePermutationIterator::Ty GlobalDecreasingSizeBestFitHeap< + BufferType>::slice_time_permutation_iterator_type() const { + return slice_time_permutation_iteration_type_; +} + template /*static*/ typename GlobalDecreasingSizeBestFitHeap< BufferType>::BufferIntervalCompare @@ -976,6 +992,18 @@ const std::vector& GlobalDecreasingSizeBestFitHeap< return slice_sizes_sorted_by_offset_; } +template +std::vector GlobalDecreasingSizeBestFitHeap< + BufferType>::SlicedBufferInterval::inclusive_start_times() const { + std::vector inclusive_start_times; + inclusive_start_times.reserve(num_slices()); + for (const BufferInterval& buffer_interval : make_free_chunks_intervals_) { + inclusive_start_times.push_back(buffer_interval.start); + } + + return inclusive_start_times; +} + template const typename GlobalDecreasingSizeBestFitHeap::BufferInterval& GlobalDecreasingSizeBestFitHeap::SlicedBufferInterval:: @@ -1008,6 +1036,510 @@ std::string GlobalDecreasingSizeBestFitHeap< absl::StrJoin(slice_sizes_sorted_by_offset_, ", "), " } }"); } +namespace { + +// A class that indicates if a permutation of starting slice times is valid. See +// SliceTimePermutationIterator for the meaning of slice time permutations. +// +// In non-repacking scenarios, all slices are valid. In repacking scenarios, +// a permutation is invalid if it does not maintain the mapping between slice +// times and slice sizes of the original placement. +class SliceTimePermutationValidator { + public: + explicit SliceTimePermutationValidator( + const SlicedAllocationData* original_slices) + : original_num_slices_(original_slices ? original_slices->num_slices() + : 0) { + if (original_num_slices_ <= 0) { + return; + } + slice_time_to_inclusive_schedule_time_ = + original_slices->SortedInclusiveStartTimes(); + absl::c_sort(slice_time_to_inclusive_schedule_time_); + + original_slice_sizes_and_start_times_pairwise_sorted_.reserve( + original_num_slices_); + for (const AllocatedSlice& slice : + original_slices->slices_sorted_by_offset) { + original_slice_sizes_and_start_times_pairwise_sorted_.push_back( + std::make_pair(slice.size, slice.inclusive_start_time)); + } + absl::c_sort(original_slice_sizes_and_start_times_pairwise_sorted_); + + sizes_sorted_by_offset_ = original_slices->SizesSortedByOffset(); + } + + bool IsValid(absl::Span permutation) { + if (original_num_slices_ <= 0) { + return true; + } + + // Compute the slice size to slice start time mapping proposed by the + // permutation. + std::vector> + proposed_slice_sizes_and_start_times_pairwise_sorted; + proposed_slice_sizes_and_start_times_pairwise_sorted.reserve( + original_num_slices_); + CHECK_EQ(sizes_sorted_by_offset_.size(), original_num_slices_); + CHECK_EQ(permutation.size(), original_num_slices_); + for (int i = 0; i < original_num_slices_; ++i) { + proposed_slice_sizes_and_start_times_pairwise_sorted.push_back( + std::make_pair( + sizes_sorted_by_offset_[i], + slice_time_to_inclusive_schedule_time_[permutation[i]])); + } + absl::c_sort(proposed_slice_sizes_and_start_times_pairwise_sorted); + + bool allowed = (original_slice_sizes_and_start_times_pairwise_sorted_ == + proposed_slice_sizes_and_start_times_pairwise_sorted); + VLOG(3) << [&]() { + auto export_pair = [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, "<", p.first, ", ", p.second, ">"); + }; + return absl::StrCat( + "Slice permutation ", (allowed ? "allowed" : "disallowed"), + ". Original slice mapping: ", + absl::StrJoin(original_slice_sizes_and_start_times_pairwise_sorted_, + ", ", export_pair), + ". Proposed mapping: ", + absl::StrJoin(proposed_slice_sizes_and_start_times_pairwise_sorted, + ", ", export_pair), + "."); + }(); + + return allowed; + } + + private: + int64_t original_num_slices_; + + // The original allocation mapping from slice times to schedule times. + std::vector slice_time_to_inclusive_schedule_time_; + + std::vector> + original_slice_sizes_and_start_times_pairwise_sorted_; + + std::vector sizes_sorted_by_offset_; +}; + +// A manager class that tracks if we've already observed an equivalent +// permutation. See the description of SliceTimePermutationIterator for a +// definition of permutation equivalence. +class ObservedPermutationManager { + public: + explicit ObservedPermutationManager( + absl::Span inclusive_start_times) { + slice_time_to_inclusive_start_time_ = std::vector( + inclusive_start_times.begin(), inclusive_start_times.end()); + absl::c_sort(slice_time_to_inclusive_start_time_); + } + + // Returns true if an equivalent permutation was already seen. If false is + // returned, we track that we've now observed permutation. + bool Insert(absl::Span permutation) { + std::vector permutation_inclusive_start_times; + permutation_inclusive_start_times.reserve(permutation.size()); + for (int64_t slice_time : permutation) { + permutation_inclusive_start_times.push_back( + slice_time_to_inclusive_start_time_[slice_time]); + } + + return observed_inclusive_start_time_permutation_ + .insert(permutation_inclusive_start_times) + .second; + } + + void Clear() { observed_inclusive_start_time_permutation_.clear(); } + + protected: + std::vector slice_time_to_inclusive_start_time_; + absl::flat_hash_set> + observed_inclusive_start_time_permutation_; +}; + +// A SliceTimePermutationIterator that iterates over all valid (see +// SliceTimePermutationValidator for more details) permutations of slice times. +class SliceTimeAllPermutationIterator : public SliceTimePermutationIterator { + public: + explicit SliceTimeAllPermutationIterator(int64_t num_slices) + : num_slices_(num_slices), permutation_(num_slices, 0) {} + + ~SliceTimeAllPermutationIterator() override = default; + + void Begin() override { + done_ = (num_slices_ <= 0); + + for (int64_t i = 0; i < num_slices_; ++i) { + permutation_[i] = i; + } + } + + bool Done() const override { return done_; } + + void Next() override { + if (Done()) { + return; + } + done_ = !absl::c_next_permutation(permutation_); + } + + absl::Span Get() const override { return permutation_; } + + private: + SliceTimeAllPermutationIterator() = default; + + int64_t num_slices_; + bool done_ = true; + std::vector permutation_; +}; + +// A SliceTimePermutationIterator that iterates over "preferred" shapes, as +// described in SliceTimePermutationIterator::Ty::kPreferred. When we have +// original sliced allocation data available (from a repack), before +// generating preferred permutation, we fix the slice time of any slice whose +// size is different from the first slice. We fix the slice time for such slices +// to their slice times in the original sliced data. Doing so avoids generating +// invalid permutations (as defined in SliceTimePermutationIterator). +// +// Note, in repacking situations, we don't know the exact slice time that each +// slice was assigned. We only know the inclusive start time of each slice. +// This gives us the slice time, except in cases where 2 slices have the same +// inclusive slice time. We choose to break such ties using offset, which is +// fine because it doesn't hurt performance. +class SliceTimePreferredPermutationIterator + : public SliceTimePermutationIterator { + public: + SliceTimePreferredPermutationIterator( + int64_t num_slices, + const SlicedAllocationData* original_sliced_allocation) + : num_slices_(num_slices), + fixed_permutation_values_(num_slices, false), + permutation_(num_slices, 0) { + // In the body of the constructor we need to: + // - If original_sliced_allocation is specified, we update + // fixed_permutation_values_ and permutation_ accordingly + // - Initialize slice_times_available_for_permutation_. + + if (!original_sliced_allocation) { + // If there are no original slice times, then any slice time can appear + // at any permutation index. + slice_times_available_for_permutation_.reserve(num_slices_); + for (int64_t slice_time = 0; slice_time < num_slices_; ++slice_time) { + slice_times_available_for_permutation_.push_back(slice_time); + } + return; + } + + absl::flat_hash_map + slice_to_slice_time_map = + BuildSliceToSliceTimeMap(original_sliced_allocation); + const AllocatedSlice* first_slice = nullptr; + if (!original_sliced_allocation->slices_sorted_by_offset.empty()) { + first_slice = + &original_sliced_allocation->slices_sorted_by_offset.front(); + } + for (int offset_index = 0; offset_index < num_slices_; ++offset_index) { + CHECK(first_slice); + const AllocatedSlice& slice = + original_sliced_allocation->slices_sorted_by_offset[offset_index]; + if (slice.size != first_slice->size) { + fixed_permutation_values_[offset_index] = true; + permutation_[offset_index] = slice_to_slice_time_map[&slice]; + continue; + } + slice_times_available_for_permutation_.push_back( + slice_to_slice_time_map[&slice]); + } + absl::c_sort(slice_times_available_for_permutation_); + } + + ~SliceTimePreferredPermutationIterator() override = default; + + void Begin() override { + permutation_type_ = NextPermutationType(PermutationType::kUninitialized); + SetUpPermutationForCurrentType(); + } + + bool Done() const override { + return permutation_type_ == PermutationType::kDone; + } + + void Next() override { + permutation_type_ = NextPermutationType(permutation_type_); + SetUpPermutationForCurrentType(); + } + + absl::Span Get() const override { return permutation_; } + + private: + enum class PermutationType { + kUninitialized, + // space + // ^ + // | +--+ + // | +--+ | + // | +--+ | + // | +--+ | + // | +--+ | + // | +--------------+ + // +------------------> time + kSmallerOffsetSmallerSliceTime, + // space + // ^ + // | +--------------+ + // | +--+ | + // | +--+ | + // | +--+ | + // | +--+ | + // | +--+ + // +------------------> time + kSmallerOffsetLargerSliceTime, + // space + // ^ + // | +--+ + // | +-----+ | + // | +-----+ | + // | +--+ | + // | +-----+ | + // | +-----+ + // +------------------> time + kDistributeSmallSliceTimesAroundMiddleOffset, + kDone, + }; + + SliceTimePreferredPermutationIterator() = default; + + // Increments from one PermutationType to the next. Note, we skip some + // PermutationTypes if the number of slices is small enough to make some + // PermutationTypes generate the same permutation. + PermutationType NextPermutationType(PermutationType ty) { + switch (ty) { + case PermutationType::kUninitialized: + if (num_slices_ <= 0) { + return PermutationType::kDone; + } + return PermutationType::kSmallerOffsetSmallerSliceTime; + case PermutationType::kSmallerOffsetSmallerSliceTime: + if (num_slices_ <= 1) { + return PermutationType::kDone; + } + return PermutationType::kSmallerOffsetLargerSliceTime; + case PermutationType::kSmallerOffsetLargerSliceTime: + if (num_slices_ <= 2) { + return PermutationType::kDone; + } + return PermutationType::kDistributeSmallSliceTimesAroundMiddleOffset; + case PermutationType::kDistributeSmallSliceTimesAroundMiddleOffset: + case PermutationType::kDone: + return PermutationType::kDone; + } + } + + // Maps slices in original_sliced_allocation to their slice time. + // + // REQUIRES: + // - original_sliced_allocation may not be null + absl::flat_hash_map BuildSliceToSliceTimeMap( + const SlicedAllocationData* original_sliced_allocation) { + CHECK(original_sliced_allocation); + std::vector slice_time_to_slice; + slice_time_to_slice.reserve(num_slices_); + for (const AllocatedSlice& slice : + original_sliced_allocation->slices_sorted_by_offset) { + slice_time_to_slice.push_back(&slice); + } + absl::c_sort(slice_time_to_slice, [](const AllocatedSlice* lhs, + const AllocatedSlice* rhs) { + return std::make_tuple(lhs->inclusive_start_time, lhs->offset) < + std::make_tuple(rhs->inclusive_start_time, rhs->offset); + }); + + absl::flat_hash_map map; + for (int slice_time = 0; slice_time < slice_time_to_slice.size(); + ++slice_time) { + map[slice_time_to_slice[slice_time]] = slice_time; + } + + return map; + } + + // Builds permutation_ according to permutation_type_. + // + // REQUIRES: + // - permutation_type_ != kUninitialized + void SetUpPermutationForCurrentType() { + CHECK(permutation_type_ != PermutationType::kUninitialized); + if (Done()) { + return; + } + + int permutation_index = NextAvailablePermutationIndex(-1); + + for (int i = slice_times_available_for_permutation_.size() - 1; i >= 0; + --i) { + if (permutation_type_ == PermutationType::kSmallerOffsetLargerSliceTime || + (permutation_type_ == + PermutationType::kDistributeSmallSliceTimesAroundMiddleOffset && + IsOdd(i))) { + CHECK_LT(permutation_index, permutation_.size()); + permutation_[permutation_index] = + slice_times_available_for_permutation_[i]; + permutation_index = NextAvailablePermutationIndex(permutation_index); + } + } + for (int i = 0; i < slice_times_available_for_permutation_.size(); ++i) { + if (permutation_type_ == + PermutationType::kSmallerOffsetSmallerSliceTime || + (permutation_type_ == + PermutationType::kDistributeSmallSliceTimesAroundMiddleOffset && + IsEven(i))) { + CHECK_LT(permutation_index, permutation_.size()); + permutation_[permutation_index] = + slice_times_available_for_permutation_[i]; + permutation_index = NextAvailablePermutationIndex(permutation_index); + } + } + CHECK_EQ(permutation_index, permutation_.size()); + } + + // Increments permutation_index. We skip over indices with fixed slice times. + int NextAvailablePermutationIndex(int permutation_index) { + do { + ++permutation_index; + } while (permutation_index < permutation_.size() && + fixed_permutation_values_[permutation_index]); + return permutation_index; + } + + int64_t num_slices_; + // For each value in permutation, indicates if it has a fixed value tied to + // a sliced allocation before repacking. If fixed_permutation_values[i] is + // true, permutation_[i] holds the fixed slice time for the slice with the + // ith smallest offset. + std::vector fixed_permutation_values_; + // Slice times that are available for permutation. A slice time is not + // available for permutation if we have to fix it to an offset to generate + // valid permutations, due to repacking. + std::vector slice_times_available_for_permutation_; + // The current type of permutation we are generating. + PermutationType permutation_type_ = PermutationType::kUninitialized; + // The permutation pertaining to permutation_type_. + std::vector permutation_; +}; + +// A ComposedSliceTimePermutationIterator uses a base_iterator to generate +// permutations. However, it only returns valid permutations, for which we +// have not already emitted an equivalent permutation. +class ComposedSliceTimePermutationIterator + : public SliceTimePermutationIterator { + public: + ComposedSliceTimePermutationIterator( + SliceTimePermutationValidator validator, + ObservedPermutationManager seen_manager, + std::unique_ptr base_iterator) + : validator_(std::move(validator)), + seen_(std::move(seen_manager)), + base_iterator_(std::move(base_iterator)) {} + + ~ComposedSliceTimePermutationIterator() override = default; + + void Begin() override { NextImpl(/*initialize=*/true); } + + bool Done() const override { return base_iterator_->Done(); } + + void Next() override { NextImpl(/*initialize=*/false); } + + absl::Span Get() const override { + return base_iterator_->Get(); + } + + private: + void NextImpl(bool initialize) { + if (initialize) { + seen_.Clear(); + base_iterator_->Begin(); + } + + if (Done()) { + return; + } + + if (!initialize) { + base_iterator_->Next(); + } + + // Keep advancing if we're not done, and the permutation is invalid or an + // equivalent permutation has already been observed. + while (!Done() && (!validator_.IsValid(Get()) || !seen_.Insert(Get()))) { + base_iterator_->Next(); + } + } + + SliceTimePermutationValidator validator_; + ObservedPermutationManager seen_; + std::unique_ptr base_iterator_; +}; + +} // namespace + +std::unique_ptr +SliceTimePermutationIterator::CreateForNewAllocation( + Ty ty, absl::Span inclusive_slice_start_times) { + switch (ty) { + case Ty::kAll: + return std::make_unique( + SliceTimePermutationValidator(/*original_slices=*/nullptr), + ObservedPermutationManager(inclusive_slice_start_times), + std::make_unique( + inclusive_slice_start_times.size())); + case Ty::kPreferred: + return std::make_unique( + SliceTimePermutationValidator(/*original_slices=*/nullptr), + ObservedPermutationManager(inclusive_slice_start_times), + std::make_unique( + inclusive_slice_start_times.size(), + /*original_sliced_allocation=*/nullptr)); + } +} + +std::unique_ptr +SliceTimePermutationIterator::CreateForRepack( + Ty ty, const SlicedAllocationData* original_sliced_allocation) { + // Repacking defaults to 1 slice in the absence of slicing data. + int64_t num_slices = 1; + if (original_sliced_allocation) { + num_slices = original_sliced_allocation->num_slices(); + } + + std::vector inclusive_start_times; + if (original_sliced_allocation) { + inclusive_start_times = + original_sliced_allocation->SortedInclusiveStartTimes(); + } else { + // We don't actually know the first inclusive start time, but the actual + // values don't matter, just their uniqueness within + // inclusive_start_times. So, for a single slice, which is how we + // treat any repacked allocation without slice data, any start time will + // work. + inclusive_start_times.push_back(0); + } + + switch (ty) { + case Ty::kAll: + return std::make_unique( + SliceTimePermutationValidator(original_sliced_allocation), + ObservedPermutationManager(inclusive_start_times), + std::make_unique(num_slices)); + case Ty::kPreferred: + return std::make_unique( + SliceTimePermutationValidator(original_sliced_allocation), + ObservedPermutationManager(inclusive_start_times), + std::make_unique( + num_slices, original_sliced_allocation)); + } +} + template std::string GlobalDecreasingSizeBestFitHeap< BufferType>::SlicedAllocationFinder::FreeChunkPiece::ToString() const { @@ -1154,6 +1686,8 @@ GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: absl::Span free_chunks_per_slice_time, std::vector sorted_slice_sizes, int64_t max_colocation_size, int64_t preferred_offset, int64_t alignment, + std::unique_ptr + slice_time_permutation_iterator, absl::AnyInvocable is_offset_allowed) : sorted_slice_sizes_(std::move(sorted_slice_sizes)), slice_size_sum_(std::accumulate(sorted_slice_sizes_.begin(), @@ -1162,6 +1696,8 @@ GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: max_colocation_size_(max_colocation_size), preferred_offset_(preferred_offset), alignment_(alignment), + slice_time_permutation_iterator_( + std::move(slice_time_permutation_iterator)), is_offset_allowed_(std::move(is_offset_allowed)) { CHECK_EQ(sorted_slice_sizes_.size(), free_chunks_per_slice_time.size()) << "We expect a data structure explaining the free chunks at each slice " @@ -1403,7 +1939,7 @@ GlobalDecreasingSizeBestFitHeap< template Status GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: - DoesPermutationFit(const std::vector& permutation_of_slice_times, + DoesPermutationFit(absl::Span permutation_of_slice_times, const FreeChunkRoot& root, int64_t offset) const { Status result = DoesPermutationFitImpl(permutation_of_slice_times, root, offset); @@ -1418,9 +1954,8 @@ Status GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: template Status GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: - DoesPermutationFitImpl( - const std::vector& permutation_of_slice_times, - const FreeChunkRoot& root, int64_t offset) const { + DoesPermutationFitImpl(absl::Span permutation_of_slice_times, + const FreeChunkRoot& root, int64_t offset) const { if (permutation_of_slice_times.size() != sorted_slice_sizes_.size()) { return InvalidArgumentStrCat( sorted_slice_sizes_.size(), " slices times expected in permutation. ", @@ -1496,46 +2031,14 @@ Status GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: } if (!out_of_slices(slice_index)) { - return InternalErrorStrCat("Ran out of space in root ", - root.chunk.ToString(), - " to fit slice permutation; however, we should " - "have caught such a condition earlier."); + return InternalStrCat("Ran out of space in root ", root.chunk.ToString(), + " to fit slice permutation; however, we should " + "have caught such a condition earlier."); } return OkStatus(); } -namespace { - -// An iterator for iterating through permutations of slice times. -class SliceTimePermutationIterator { - public: - explicit SliceTimePermutationIterator(int64_t latest_slice_time) - : done_(latest_slice_time < 0) { - permutation_.reserve(latest_slice_time + 1); - for (int64_t i = 0; i <= latest_slice_time; ++i) { - permutation_.push_back(i); - } - } - - bool Done() const { return done_; } - - void Next() { - if (Done()) { - return; - } - done_ = !absl::c_next_permutation(permutation_); - } - - const std::vector& Get() const { return permutation_; } - - private: - bool done_ = false; - std::vector permutation_; -}; - -} // namespace - // Future opportunities: // 1) Potential optimization: We don't have to try every offset in // [root.chunk.offset, root.chunk.chunk_end()). If a permutation doesn't fit @@ -1567,10 +2070,14 @@ GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder::FindInRoot( CHECK_EQ(first_offset % alignment_, 0); for (int64_t offset = first_offset; offset + max_colocation_size_ <= last_end; offset += alignment_) { - for (SliceTimePermutationIterator permutation_it(LatestSliceTime()); - !permutation_it.Done(); permutation_it.Next()) { - if (DoesPermutationFit(permutation_it.Get(), root, offset).ok()) { - return PermutationToChunks(permutation_it.Get(), offset); + for (slice_time_permutation_iterator_->Begin(); + !slice_time_permutation_iterator_->Done(); + slice_time_permutation_iterator_->Next()) { + if (DoesPermutationFit(slice_time_permutation_iterator_->Get(), root, + offset) + .ok()) { + return PermutationToChunks(slice_time_permutation_iterator_->Get(), + offset); } } @@ -1589,7 +2096,7 @@ template typename GlobalDecreasingSizeBestFitHeap< BufferType>::SlicedAllocationFinder::ChunksSortedBySliceTime GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: - PermutationToChunks(const std::vector& permutation_of_slice_times, + PermutationToChunks(absl::Span permutation_of_slice_times, int64_t offset) const { ChunksSortedBySliceTime chunks(permutation_of_slice_times.size() + 1, Chunk::FromOffsetSize(-1, 1)); @@ -1612,7 +2119,7 @@ GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder:: } template -HeapSimulator::Result +StatusOr> GlobalDecreasingSizeBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -1732,8 +2239,11 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidates( int64_t max_colocation_size = GetMaxColocationSize(sliced_buffer_interval.full_buffer_interval()); auto chunks = - CreateSlicedAllocationFinder(sliced_buffer_interval, max_colocation_size, - preferred_offset) + CreateSlicedAllocationFinder( + sliced_buffer_interval, max_colocation_size, preferred_offset, + SliceTimePermutationIterator::CreateForNewAllocation( + slice_time_permutation_iteration_type_, + sliced_buffer_interval.inclusive_start_times())) .Find(); return PostProcessFindChunkCandidatesResult(sliced_buffer_interval, std::move(chunks)); @@ -1757,6 +2267,8 @@ typename GlobalDecreasingSizeBestFitHeap::SlicedAllocationFinder GlobalDecreasingSizeBestFitHeap::CreateSlicedAllocationFinder( const SlicedBufferInterval& sliced_interval, int64_t max_colocation_size, int64_t preferred_offset, + std::unique_ptr + slice_time_permutation_iterator, absl::AnyInvocable is_offset_allowed) const { // Build up a list of free chunks for each slice time. std::vector free_chunks_per_slice_time; @@ -1777,10 +2289,10 @@ GlobalDecreasingSizeBestFitHeap::CreateSlicedAllocationFinder( 1), max_colocation_size)); - return SlicedAllocationFinder(free_chunks_per_slice_time, - sliced_interval.SliceSizesSortedByOffset(), - max_colocation_size, preferred_offset, - alignment_, std::move(is_offset_allowed)); + return SlicedAllocationFinder( + free_chunks_per_slice_time, sliced_interval.SliceSizesSortedByOffset(), + max_colocation_size, preferred_offset, alignment_, + std::move(slice_time_permutation_iterator), std::move(is_offset_allowed)); } template @@ -1831,7 +2343,7 @@ void GlobalDecreasingSizeBestFitHeap::AddToChunkMap( DCHECK(emplace_result.second); } -HeapSimulator::Result +absl::StatusOr> ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() { std::vector sorted_buffer_vec = GetSortedBufferIntervals(); // Convert into std::list so that erase() is O(1). @@ -1884,14 +2396,14 @@ ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() { } template -HeapSimulator::Result +StatusOr> ChooseBestHeapAlgorithm::Finish() { DCHECK(!algorithms_.empty()); std::vector results(algorithms_.size()); int64_t min_size = INT64_MAX; int min_size_index = -1; for (int i = 0; i < algorithms_.size(); ++i) { - results[i] = algorithms_[i]->Finish(); + TF_ASSIGN_OR_RETURN(results[i], algorithms_[i]->Finish()); if (results[i].heap_size < min_size) { min_size = results[i].heap_size; min_size_index = i; @@ -1903,8 +2415,7 @@ ChooseBestHeapAlgorithm::Finish() { } template class GlobalDecreasingSizeBestFitHeap; -template class GlobalDecreasingSizeBestFitHeap< - memory_space_assignment::MemorySpaceAssignmentRepacker::AllocationBlock>; +template class GlobalDecreasingSizeBestFitHeap; template class ChooseBestHeapAlgorithm; } // namespace xla diff --git a/xla/service/heap_simulator.h b/xla/service/heap_simulator/heap_simulator.h similarity index 86% rename from xla/service/heap_simulator.h rename to xla/service/heap_simulator/heap_simulator.h index b6b78d7a02803..849be24cb4b08 100644 --- a/xla/service/heap_simulator.h +++ b/xla/service/heap_simulator/heap_simulator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_HEAP_SIMULATOR_H_ -#define XLA_SERVICE_HEAP_SIMULATOR_H_ +#ifndef XLA_SERVICE_HEAP_SIMULATOR_HEAP_SIMULATOR_H_ +#define XLA_SERVICE_HEAP_SIMULATOR_HEAP_SIMULATOR_H_ #include +#include #include #include #include +#include #include #include #include @@ -41,11 +43,13 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/buffer_value_containers.h" +#include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_ordering.h" +#include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/tuple_points_to_analysis.h" #include "xla/statusor.h" @@ -141,20 +145,20 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given // schedule), assuming no fragmentation. - static StatusOr MinimumMemoryForModule( + static absl::StatusOr MinimumMemoryForModule( const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. - static StatusOr MinimumMemoryForComputation( + static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map* memory_by_computation = nullptr); - static StatusOr MinimumMemoryForComputation( + static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -169,7 +173,7 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr> Run( + static absl::StatusOr> Run( std::unique_ptr> algorithm, const HloModule& module, const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, @@ -180,7 +184,7 @@ class HeapSimulator { // must contain a topologically-consistent total ordering of all instructions // in the computation. The result is invalid if instructions are not run in // exactly this sequence. - static StatusOr> Run( + static absl::StatusOr> Run( std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, @@ -192,7 +196,7 @@ class HeapSimulator { // Same as above, but runs on with a schedule that covers all nested // computations. - static StatusOr> Run( + static absl::StatusOr> Run( std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, @@ -232,7 +236,7 @@ class HeapSimulator { // Two buffers belong to the same shared group. // Eight of the buffer has no shared group assigned. bool InSameSharedGroup(const HloValue* left, const HloValue* right); - Result Finish(); + absl::StatusOr> Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const HloValue* buffer, const HloInstruction* instruction, @@ -312,7 +316,7 @@ class HeapAlgorithm { // Finish collects the buffer offset assignment results. Finish may only be // called once, after all Alloc and Free calls. - virtual Result Finish() = 0; + virtual absl::StatusOr Finish() = 0; }; // NoFragmentationStatsHeap computes the heap size assuming no fragmentation; @@ -336,7 +340,7 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Free(const BufferType* buffer, int64_t size) override; - Result Finish() override; + absl::StatusOr Finish() override; private: int64_t current_heap_size_ = 0; @@ -384,6 +388,79 @@ class BufferIntervalTree { std::list node_storage_; }; +// An iterator that is passed to +// GlobalDecreasingSizeBestFitHeap::CreateSlicedAllocationFinder() when trying +// to place a buffer, telling the finder which permutations of starting slice +// times to try (and in which order to try them). +// * The set of slice times is the set {x : x ∈ [0, num_slices - 1]}. If a +// buffer is not sliced, it will only have 1 permutation, containing slice +// time 0. +// * The ith value in a permutation is the slice time for the slice at the +// ith smallest offset. +// * Iterators skip permutations that are equivalent to previously emitted +// permutations. The ith smallest slice time corresponds to the ith smallest +// inclusive start time. Let the start_time_permutation be the mapping of a +// permutation to its corresponding start times. Two permutations are +// equivalent if their start_time_permutations are equivalent. For example, +// let's say slice time 0 and slice time 1 both map to inclusive start time +// 1000. There is no difference in permutation [0, 1, x] and [1, 0, x] +// because the first two slices map to the same inclusive start time. +// * When repacking slice data is provided, iterators skip invalid +// permutations. A permutation is invalid if the mapping from inclusive +// start times to slice sizes is not maintained from before the repack. +// * Begin() must be called to initialize the iterator before it can be used. +class SliceTimePermutationIterator { + public: + enum class Ty : std::int8_t { + // Include all valid permutations + kAll, + // Only include perferred valid permutations. Heap simulator is trying to + // optimize fitting allocations into a grid of (heap) space by time. The + // preferred permutation iterator only allows the following triagular + // shapes: + // + // Smaller offsets Smaller offsets Slice times are + // get smaller slice get larger slice distributed around + // times times the middle offset + // + // space space space + // ^ ^ ^ + // | +--+ | +--------------+ | +--+ + // | +--+ | | +--+ | | +-----+ | + // | +--+ | | +--+ | | +-----+ | + // | +--+ | | +--+ | | +--+ | + // | +--+ | | +--+ | | +-----+ | + // | +--------------+ | +--+ | +-----+ + // +------------------> +------------------> +------------------> time + // + // We deviate from those shapes as needed to make valid permutations. + kPreferred, + }; + + // A new iterator is typically created for each buffer to be placed. + // - num_slices: number of slices in the buffer. 1 if not sliced. + // - original_sliced_allocation: For a repacking scenario, the original + // details of each slice in a sliced buffer. nullptr is used if the buffer + // was not sliced. (Note, if the repacker has no slicing data, it is + // treated as unsliced in the repacker and by this iterator.) + static std::unique_ptr CreateForNewAllocation( + Ty ty, absl::Span inclusive_slice_start_times); + static std::unique_ptr CreateForRepack( + Ty ty, const SlicedAllocationData* original_sliced_allocation); + + virtual ~SliceTimePermutationIterator() = default; + + virtual void Begin() = 0; + virtual bool Done() const = 0; + virtual void Next() = 0; + + // A permutation of starting slice times. + virtual absl::Span Get() const = 0; + + protected: + SliceTimePermutationIterator() = default; +}; + // GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, // then allocates them in decreasing spatial or temporal size regardless of the // alloc/free time. It internally tracks the allocated buffers and their live @@ -501,6 +578,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { const BufferInterval& full_buffer_interval() const; size_t num_slices() const { return slice_sizes_sorted_by_offset_.size(); } const std::vector& SliceSizesSortedByOffset() const; + std::vector inclusive_start_times() const; // Returns a BufferInterval with the requirements to call // GlobalDecreasingSizeBestFitHeap::MakeFreeChunks at the specified slice @@ -618,6 +696,11 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // with the fully allocated sliced allocation. // - preferred_offset: The preferred starting offset for the fully allocated // sliced allocation. + // - slice_time_permutation_iterator: An iterator for iterating over the + // different slice time permutations for slices. Users may specify the + // order in which different permutations are tried by the HeapSimulator. + // Users are also responsbile for ensuring that returned permutations are + // legal. // - is_offset_allowed: Indicates if a the entire sliced allocation is // allowed to be allocated at a given offset. // @@ -634,6 +717,8 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { absl::Span free_chunks_per_slice_time, std::vector sorted_slice_sizes, int64_t max_colocation_size, int64_t preferred_offset, int64_t alignment, + std::unique_ptr + slice_time_permutation_iterator, absl::AnyInvocable is_offset_allowed = &AllOffsetsAllowed); @@ -672,14 +757,14 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // sorted_slice_sizes_[i] and would be allocated at offset + // sum(sorted_slice_sizes[j], for j in [0, i-1]). Status DoesPermutationFit( - const std::vector& permutation_of_slice_times, + absl::Span permutation_of_slice_times, const FreeChunkRoot& root, int64_t offset) const; // Only DoesSlicedPermutationFit() should call this method directly. Other // callers should call DoesSlicedPermutationFit(), which contains some // wrapper VLOGGING. Status DoesPermutationFitImpl( - const std::vector& permutation_of_slice_times, + absl::Span permutation_of_slice_times, const FreeChunkRoot& root, int64_t offset) const; // Same as Find() except only checks root, to see if it can hold the sliced @@ -699,7 +784,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // end of the result to account for an additional colocation space that // need to be allocated. This Chunk is added, even if it is of size 0. ChunksSortedBySliceTime PermutationToChunks( - const std::vector& permutation_of_slice_times, + absl::Span permutation_of_slice_times, int64_t offset) const; std::vector sorted_slice_sizes_; @@ -708,12 +793,16 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { int64_t preferred_offset_; int64_t alignment_; FreeChunkRoots free_chunks_; + std::unique_ptr + slice_time_permutation_iterator_; absl::AnyInvocable is_offset_allowed_; }; explicit GlobalDecreasingSizeBestFitHeap( int64_t alignment, Type type = kSpatial, - BufferIntervalCompare buffer_interval_compare = nullptr); + BufferIntervalCompare buffer_interval_compare = nullptr, + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type = + SliceTimePermutationIterator::Ty::kAll); ~GlobalDecreasingSizeBestFitHeap() override {} void Alloc(const BufferType* buffer, int64_t size) override; @@ -722,7 +811,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { void ShareWith(const BufferType* buffer, const BufferType* share_with, int64_t size) override; - Result Finish() override; + StatusOr Finish() override; // Return a BufferIntervalCompare function that sort by spatial size. We don't // look at co-locates as they should have the same size. @@ -789,6 +878,8 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { SlicedAllocationFinder CreateSlicedAllocationFinder( const SlicedBufferInterval& sliced_interval, int64_t max_colocation_size, int64_t preferred_offset, + std::unique_ptr + slice_time_permutation_iterator, absl::AnyInvocable is_offset_allowed = &SlicedAllocationFinder::AllOffsetsAllowed) const; std::vector PostProcessFindChunkCandidatesResult( @@ -807,6 +898,8 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // contiguous. BufferIntervalCompare GetTemporalBufferIntervalCompare() const; + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type() const; + absl::flat_hash_map buffer_intervals_; HeapResult result_; BufferIntervalCompare buffer_interval_compare_; @@ -819,6 +912,9 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // Alloc or Free call. int64_t current_time_ = 0; + SliceTimePermutationIterator::Ty slice_time_permutation_iteration_type_ = + SliceTimePermutationIterator::Ty::kAll; + protected: // Returns all transitive colocated buffers of this buffer interval. I.e., If // a buffer A is colocated with B and B is colocated with C, this function @@ -858,7 +954,7 @@ class ConstrainedGlobalDecreasingSizeBestFitHeap size_limit_per_heap_(size_limit_per_heap) {} ~ConstrainedGlobalDecreasingSizeBestFitHeap() override {} - Result Finish() override; + absl::StatusOr Finish() override; private: uint64_t size_limit_per_heap_; @@ -896,17 +992,16 @@ class ChooseBestHeapAlgorithm : public HeapAlgorithm { } } - Result Finish() override; + StatusOr Finish() override; private: std::vector>> algorithms_; }; extern template class GlobalDecreasingSizeBestFitHeap; -extern template class GlobalDecreasingSizeBestFitHeap< - memory_space_assignment::MemorySpaceAssignmentRepacker::AllocationBlock>; +extern template class GlobalDecreasingSizeBestFitHeap; extern template class ChooseBestHeapAlgorithm; } // namespace xla -#endif // XLA_SERVICE_HEAP_SIMULATOR_H_ +#endif // XLA_SERVICE_HEAP_SIMULATOR_HEAP_SIMULATOR_H_ diff --git a/xla/service/heap_simulator_test.cc b/xla/service/heap_simulator/heap_simulator_test.cc similarity index 85% rename from xla/service/heap_simulator_test.cc rename to xla/service/heap_simulator/heap_simulator_test.cc index 8ee842c3f1603..480213f78e8f7 100644 --- a/xla/service/heap_simulator_test.cc +++ b/xla/service/heap_simulator/heap_simulator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include #include #include #include +#include +#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" -#include "xla/service/async_op_canonicalizer.h" #include "xla/service/buffer_value.h" -#include "xla/service/hlo_dce.h" +#include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo_ordering.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" @@ -39,6 +41,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/logging.h" #include "tsl/platform/test.h" namespace xla { @@ -260,7 +263,7 @@ class HeapCallRecorder : public HeapAlgorithm { void Free(const HloValue* buffer, int64_t size) override { calls_->emplace_back(kFree, buffer); } - Result Finish() override { + absl::StatusOr Finish() override { calls_->emplace_back(kFinish, nullptr); HeapSimulator::Result result; result.heap_size = result_.heap_size; @@ -957,16 +960,12 @@ TEST_F(HeapSimulatorTest, AsyncCallImplicitSharding) { ENTRY entry { p0 = f32[8] parameter(0) call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation - ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation + ROOT call-done = f32[8] call-done(call-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); - AsyncOpCanonicalizer canonicalizer; - TF_ASSERT_OK(canonicalizer.Run(module.get()).status()); - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, HloAliasAnalysis::Run(module.get())); auto size_fn = [](const BufferValue& buffer) -> int64_t { @@ -1034,7 +1033,9 @@ class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(NoFragmentationStatsHeapTest, Empty) { NoFragmentationStatsHeap heap; - EXPECT_EQ(0, heap.Finish().heap_size); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); + EXPECT_EQ(0, result.heap_size); } TEST_F(NoFragmentationStatsHeapTest, Simple) { @@ -1047,7 +1048,9 @@ TEST_F(NoFragmentationStatsHeapTest, Simple) { heap.Free(buffer_b_, 20); heap.Free(buffer_c_, 30); heap.Free(buffer_d_, 30); - EXPECT_EQ(90, heap.Finish().heap_size); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); + EXPECT_EQ(90, result.heap_size); } TEST_F(NoFragmentationStatsHeapTest, Mixed) { @@ -1064,14 +1067,17 @@ TEST_F(NoFragmentationStatsHeapTest, Mixed) { heap.Free(buffer_d_, 5); heap.Free(buffer_a_, 10); - EXPECT_EQ(40, heap.Finish().heap_size); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); + EXPECT_EQ(40, result.heap_size); } class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {}; TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); - const HeapSimulator::Result result = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); EXPECT_EQ(0, result.heap_size); EXPECT_EQ(1, result.heap_results.size()); EXPECT_EQ(0, result.heap_results.at(0).chunk_map.size()); @@ -1101,7 +1107,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { heap.Free(buffer_c_, 20); heap.Free(buffer_d_, 40); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1143,7 +1150,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { heap.Free(buffer_c_, 50); heap.Free(buffer_d_, 40); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1189,7 +1197,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { heap.Free(buffer_d_, 30); heap.Free(buffer_e_, 50); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1224,7 +1233,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) { heap.ShareWith(buffer_c_, buffer_a_, 40); heap.Free(buffer_c_, 40); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1256,7 +1266,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) { heap.Free(buffer_c_, 40); heap.Free(buffer_b_, 20); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1289,7 +1300,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) { heap.Free(buffer_c_, 10); heap.Free(buffer_b_, 30); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1321,7 +1333,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize1) { heap.Free(buffer_c_, 30); heap.Free(buffer_b_, 20); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1354,7 +1367,8 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize2) { heap.Free(buffer_c_, 50); heap.Free(buffer_b_, 20); - const HeapSimulator::Result results = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result results, + heap.Finish()); EXPECT_EQ(1, results.heap_results.size()); const HeapSimulator::HeapResult& result = results.heap_results.at(0); @@ -1724,7 +1738,8 @@ TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { heap.Free(buffer_c_, 20); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); EXPECT_EQ(100, result.heap_size); EXPECT_EQ(2, result.heap_results.size()); @@ -1766,7 +1781,8 @@ TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, heap.Free(buffer_c_, 50); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); EXPECT_EQ(130, result.heap_size); // 70 + 60 EXPECT_EQ(2, result.heap_results.size()); @@ -1799,7 +1815,8 @@ TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, ColocatedII) { heap.Free(buffer_c_, 40); heap.Free(buffer_b_, 20); - const HeapSimulator::Result result = heap.Finish(); + TF_ASSERT_OK_AND_ASSIGN(const HeapSimulator::Result result, + heap.Finish()); EXPECT_EQ(60, result.heap_size); // 40 + 20 EXPECT_EQ(2, result.heap_results.size()); @@ -2154,6 +2171,20 @@ class SlicedAllocationFinderTest : public ::testing::Test { using FreeChunks = typename HeapTy::FreeChunks; using Chunk = HeapSimulator::Chunk; using Finder = typename HeapTy::SlicedAllocationFinder; + + protected: + std::unique_ptr NewPermutationIterator( + int64_t num_slices) { + // For these tests, map each slice time to a unique incrementing start time. + std::vector inclusive_start_times; + inclusive_start_times.reserve(num_slices); + for (int64_t start_time = 0; start_time < num_slices; ++start_time) { + inclusive_start_times.push_back(start_time); + } + + return SliceTimePermutationIterator::CreateForNewAllocation( + SliceTimePermutationIterator::Ty::kAll, inclusive_start_times); + } }; TEST_F(SlicedAllocationFinderTest, NoSlices) { @@ -2183,7 +2214,8 @@ The full buffer goes in the smallest chunk that fits. int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre(Chunk::FromOffsetSize(45, 3), @@ -2217,7 +2249,8 @@ The max colocation size does not fit in the smallest free chunk. int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre(Chunk::FromOffsetSize(60, 3), @@ -2252,7 +2285,8 @@ Multiple free chunks have size 3. We pick the one with the smallest offset. int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre(Chunk::FromOffsetSize(10, 3), @@ -2300,7 +2334,8 @@ t0 |xxxxx xxx xxxxx000xxxxxxxxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2349,7 +2384,8 @@ t0 |xxxxx xxx xxxxxxxxxxx222xxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2398,7 +2434,8 @@ t0 |xxxxx xxx xxxxxxxx111xxxxxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2452,7 +2489,8 @@ subsliced by MSA.) int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2498,7 +2536,8 @@ t0 |xxxxxx 111 xxx int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2547,7 +2586,8 @@ t0 |xxxxx xxx xxxxxx 111 xxxxxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2596,7 +2636,8 @@ t0 |xxxxx xxx00000 xxxxxx xxxxxxxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2645,7 +2686,8 @@ t0 |xxxxxxxxxx xxxxxxxxxxxxxxxxxxxx000 x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2694,7 +2736,8 @@ t0 |xxxxx xxx xxxxxxxx xxxxxxxxx000 x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2743,7 +2786,8 @@ t0 |xxxxx xxx 000 xxxxxxxx xxxxxxxxx x int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2794,7 +2838,8 @@ The sliced allocation does not fit at the preferred offset. int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2846,7 +2891,8 @@ on spatial boundaries of 2. int64_t alignment = 2; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2898,7 +2944,8 @@ on spatial boundaries of 2. int64_t alignment = 2; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre( @@ -2935,7 +2982,8 @@ t0 |xxxxx000 xxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre(Chunk::FromOffsetSize(5, 3), @@ -2974,7 +3022,8 @@ t0 |xxxxx000 xxxx xxxxxxxxxxxxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxx int64_t alignment = 1; Finder finder(free_chunks_per_slice_time, sorted_slice_sizes, - max_colocation_size, preferred_offset, alignment); + max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size())); EXPECT_THAT(finder.Find(), ::testing::ElementsAre(Chunk::FromOffsetSize(5, 3), @@ -3025,6 +3074,7 @@ t0 |xxxxx xxx xxxxx 000xxxxxxxxxxx x Finder finder( free_chunks_per_slice_time, sorted_slice_sizes, max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size()), /*is_offset_allowed=*/[](int64_t offset) { return offset != 45; }); EXPECT_THAT(finder.Find(), @@ -3076,6 +3126,7 @@ t0 |xxxxx xxx xxxxx000xxxxxxxxxxxx x Finder finder( free_chunks_per_slice_time, sorted_slice_sizes, max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size()), // We're not allowed to start at offset 46, but we can include it. /*is_offset_allowed=*/[](int64_t offset) { return offset != 46; }); @@ -3128,6 +3179,7 @@ t0 |xxxxx xxx xxxxxxxxxxx 222xxxxx x Finder finder( free_chunks_per_slice_time, sorted_slice_sizes, max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size()), /*is_offset_allowed=*/[](int64_t offset) { return offset != 45; }); EXPECT_THAT(finder.Find(), @@ -3179,6 +3231,7 @@ t0 |xxxxx xxx xxxxxxxxxxx xxxxxx000 x Finder finder( free_chunks_per_slice_time, sorted_slice_sizes, max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size()), /*is_offset_allowed=*/[](int64_t offset) { return offset != 45; }); EXPECT_THAT(finder.Find(), @@ -3230,6 +3283,7 @@ t0 |xxxxx xxx xxxxx xxxxxxxxxxx x Finder finder( free_chunks_per_slice_time, sorted_slice_sizes, max_colocation_size, preferred_offset, alignment, + NewPermutationIterator(sorted_slice_sizes.size()), /*is_offset_allowed=*/[](int64_t offset) { return offset != 45; }); EXPECT_THAT(finder.FindForOffset(10), @@ -3256,5 +3310,388 @@ t0 |xxxxx xxx xxxxx xxxxxxxxxxx x Chunk::FromOffsetSize(67, 3), Chunk::FromOffsetSize(70, 0))); } +class SliceTimePermutationIteratorTest : public ::testing::Test { + protected: + struct NewAllocationTestCase { + void Test() const { + auto iterator = SliceTimePermutationIterator::CreateForNewAllocation( + ty, inclusive_start_times); + + // Run the iterator multiple times to make sure it can be reused. + for (int i = 0; i < 5; ++i) { + VLOG(2) << "Test case try #" << i << ": NewAllocation, " << name; + EXPECT_THAT(GetPermutations(iterator.get()), + ::testing::ElementsAreArray(expected_permutations)) + << "Failed NewAllocation, " << name; + } + } + + std::string name; + SliceTimePermutationIterator::Ty ty; + std::vector inclusive_start_times; + std::vector> expected_permutations; + }; + + struct RepackTestCase { + void Test() const { + auto iterator = SliceTimePermutationIterator::CreateForRepack( + ty, (original_slice_data.has_value() ? &(*original_slice_data) + : nullptr)); + + // Run the iterator multiple times to make sure it can be reused. + for (int i = 0; i < 5; ++i) { + VLOG(2) << "Test case try #" << i << ": Repack, " << name; + EXPECT_THAT(GetPermutations(iterator.get()), + ::testing::ElementsAreArray(expected_permutations)) + << "Failed Repack, " << name; + } + } + + std::string name; + SliceTimePermutationIterator::Ty ty; + std::optional original_slice_data; + std::vector> expected_permutations; + }; + + static std::vector> GetPermutations( + SliceTimePermutationIterator* it) { + std::vector> results; + for (it->Begin(); !it->Done(); it->Next()) { + absl::Span permutation = it->Get(); + results.push_back( + std::vector(permutation.begin(), permutation.end())); + } + + return results; + } +}; + +TEST_F(SliceTimePermutationIteratorTest, NewAllocations) { + std::vector test_cases = { + { + "0 slices, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*inclusive_start_times=*/{}, + /*expected_permutations=*/{}, + }, + { + "1 slice, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*inclusive_start_times=*/{0}, + /*expected_permutations=*/{{0}}, + }, + { + "2 slices, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*inclusive_start_times=*/{10, 20}, + /*expected_permutations=*/{{0, 1}, {1, 0}}, + }, + { + "many slices, all permutations, unique start times", + SliceTimePermutationIterator::Ty::kAll, + /*inclusive_start_times=*/{40, 10, 450}, + /*expected_permutations=*/ + {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}, + }, + { + "many slices, all permutations, non-unique start times", + SliceTimePermutationIterator::Ty::kAll, + /*inclusive_start_times=*/{40, 10, 450, 10}, + /*expected_permutations=*/ + { + // The two smallest start times are the same. Thus, when we + // compare permutations for equivalence, if index i is assigned + // slice time 0, and index j is assigned slice time 1, its + // equivalent to i being assigned 1, and j being assigned 0. + // + // Note, the order of inclusive start times is irrelevant. The ith + // earliest slice time is associated with the ith earliest + // inclusive start time. + {0, 1, 2, 3}, + {0, 1, 3, 2}, + {0, 2, 1, 3}, + {0, 2, 3, 1}, + {0, 3, 1, 2}, + {0, 3, 2, 1}, + // {1, 0, 2, 3}, equivalent emitted + // {1, 0, 3, 2}, equivalent emitted + // {1, 2, 0, 3}, equivalent emitted + // {1, 2, 3, 0}, equivalent emitted + // {1, 3, 0, 2}, equivalent emitted + // {1, 3, 2, 0}, equivalent emitted + {2, 0, 1, 3}, + {2, 0, 3, 1}, + // {2, 1, 0, 3}, equivalent emitted + // {2, 1, 3, 0}, equivalent emitted + {2, 3, 0, 1}, + // {2, 3, 1, 0}, equivalent emitted + {3, 0, 1, 2}, + {3, 0, 2, 1}, + // {3, 1, 0, 2}, equivalent emitted + // {3, 1, 2, 0}, equivalent emitted + {3, 2, 0, 1}, + // {3, 2, 1, 0}, equivalent emitted + }, + }, + { + "0 slices, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{}, + /*expected_permutations=*/{}, + }, + { + "1 slice, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{0}, + /*expected_permutations=*/{{0}}, + }, + { + "2 slices, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{10, 20}, + /*expected_permutations=*/{{0, 1}, {1, 0}}, + }, + { + "many slices, preferred permutations, unique start times", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{40, 10, 450, 12, 14}, + /*expected_permutations=*/ + {{0, 1, 2, 3, 4}, {4, 3, 2, 1, 0}, {3, 1, 0, 2, 4}}, + }, + { + "many slices, preferred permutations, non-unique start times 1", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{40, 10, 450, 10}, + /*expected_permutations=*/ + {// This case is not impacted by non-unique start times. + {0, 1, 2, 3}, + {3, 2, 1, 0}, + {3, 1, 0, 2}}, + }, + { + "many slices, preferred permutations, non-unique start times 2", + SliceTimePermutationIterator::Ty::kPreferred, + /*inclusive_start_times=*/{40, 40}, + /*expected_permutations=*/ + { + // The two smallest start times are the same. Thus, we must ignore + // duplicate permutations, when we ignore the order of slice times + // 0 and 1. + {0, 1}, + // This is a duplicate of {0, 1}, when ignoring the order of 0 and + // 1. + // {1, 0}, + }, + }, + }; + + for (const NewAllocationTestCase& test_case : test_cases) { + test_case.Test(); + } +} + +TEST_F(SliceTimePermutationIteratorTest, Repacks) { + std::vector test_cases = { + { + "no slice data, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/std::nullopt, + /*expected_permutations=*/{{0}}, + }, + { + "0 slices, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/SlicedAllocationData{}, + /*expected_permutations=*/{}, + }, + { + "1 slice, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + }}, + /*expected_permutations=*/{{0}}, + }, + { + "2 slices, uniform slice size, all permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + }}, + /*expected_permutations=*/{{0, 1}, {1, 0}}, + }, + { + "many slices, uniform slice size, unique start times, all " + "permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + {/*size=*/1, /*offset=*/3, /*inclusive_start_time=*/3}, + }}, + /*expected_permutations=*/ + {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}, + }, + { + "many slices, non-uniform slice size, unique start times, all " + "permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/2, /*offset=*/2, /*inclusive_start_time=*/3}, + {/*size=*/1, /*offset=*/3, /*inclusive_start_time=*/2}, + }}, + /*expected_permutations=*/ + { + // The slice at index 0 has a different size than any other slice, + // so it's invalid to give it any slice time other than its + // original slice time of 2. + {0, 2, 1}, + {1, 2, 0}, + }, + }, + { + "many slices, non-uniform slice size, non-unique start times, all " + "permutations", + SliceTimePermutationIterator::Ty::kAll, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + {/*size=*/2, /*offset=*/3, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/5, /*inclusive_start_time=*/1}, + {/*size=*/2, /*offset=*/6, /*inclusive_start_time=*/3}, + {/*size=*/3, /*offset=*/8, /*inclusive_start_time=*/4}, + }}, + /*expected_permutations=*/ + { + // All permutations such that: + // * The first 3 slice times hold 2 slices with size 1, and 1 + // slice with size 2. + // * Slice time 3 holds a slice with size 1. + // * Slice time 4 holds a slice with size 2. + // * Slice time 5 holds a slice with size 3, which can only be the + // slice at index 5. + // * We throw away permutations where the first 3 slice times are + // given to the same slice offsets. + {0, 1, 2, 3, 4, 5}, + {0, 1, 4, 3, 2, 5}, + {0, 3, 1, 2, 4, 5}, + {0, 3, 4, 1, 2, 5}, + {3, 0, 1, 2, 4, 5}, + {3, 0, 4, 1, 2, 5}, + }, + }, + { + "no slice data, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/std::nullopt, + /*expected_permutations=*/{{0}}, + }, + { + "0 slices, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/SlicedAllocationData{}, + /*expected_permutations=*/{}, + }, + { + "1 slice, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + }}, + /*expected_permutations=*/{{0}}, + }, + { + "2 slices, uniform slice size, preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + }}, + /*expected_permutations=*/{{0, 1}, {1, 0}}, + }, + { + "many slices, uniform slice size, unique start times, preferred " + "permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + {/*size=*/1, /*offset=*/3, /*inclusive_start_time=*/3}, + }}, + /*expected_permutations=*/ + {{0, 1, 2}, {2, 1, 0}, {1, 0, 2}}, + }, + { + "many slices, non-uniform slice size, unique start times, preferred " + "permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/2, /*offset=*/2, /*inclusive_start_time=*/3}, + {/*size=*/1, /*offset=*/3, /*inclusive_start_time=*/2}, + }}, + /*expected_permutations=*/ + { + // The 2nd slice has a different size than the first, so we must + // fix it to its original slice time, i.e., slice time 2. + {0, 2, 1}, + {1, 2, 0}, + }, + }, + { + "many slices, non-uniform slice size, non-unique start times, " + "preferred permutations", + SliceTimePermutationIterator::Ty::kPreferred, + /*original_slice_data=*/ + SlicedAllocationData{/*slices_sorted_by_offset=*/{ + {/*size=*/1, /*offset=*/1, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/2, /*inclusive_start_time=*/2}, + {/*size=*/2, /*offset=*/3, /*inclusive_start_time=*/1}, + {/*size=*/1, /*offset=*/5, /*inclusive_start_time=*/1}, + {/*size=*/2, /*offset=*/6, /*inclusive_start_time=*/3}, + {/*size=*/3, /*offset=*/8, /*inclusive_start_time=*/4}, + }}, + /*expected_permutations=*/ + { + // First we fix the slice times of slices that have different + // sizes than the first slice, i.e., slices 3, 5, and 6. If we + // sort the slices by , the fixed + // slice time for those slices will be their index in the sorted + // order. Thus, slice 3 is fixed to slice time 1. Slice 5 is fixed + // to slice time 4. And, slice 6 is fixed to slice time 5. + // + // The remaining slices are given preferred slice times, throwing + // out any equivalent permutations. Two permutations are + // equivalent if they are equal, after ignoring permutations of + // the slice times that map to the same inclusive start time. In + // our case, slice times 0, 1, and 2 map to inclusive start + // time 1. Thus, if indices i, j, and k are given slice times 0, + // 1, and 2, it doesn't matter which of i, j, and k maps to 0, 1, + // and 2 (for the purposes of equivalence). + {0, 2, 1, 3, 4, 5}, + {3, 2, 1, 0, 4, 5}, + // The next permutation is the same as the previous, except slice + // times 0 and 2 are permuted, so we throw it out. + // {3, 0, 1, 2, 4, 5}, + }, + }, + }; + + for (const RepackTestCase& test_case : test_cases) { + test_case.Test(); + } +} + } // namespace } // namespace xla diff --git a/xla/service/hlo.proto b/xla/service/hlo.proto index ea63665e83e85..b79805ec37c94 100644 --- a/xla/service/hlo.proto +++ b/xla/service/hlo.proto @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ syntax = "proto3"; package xla; +import "google/protobuf/any.proto"; import "xla/xla_data.proto"; option cc_enable_arenas = true; @@ -111,7 +112,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 86 +// Next ID: 87 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -351,10 +352,8 @@ message HloInstructionProto { // status-returning API. CustomCallApiVersion custom_call_api_version = 77; - // Represents a unique identifier for an async group which consists of an - // async start, async done, and zero or more async update operations. - // Negative async_group_id is equivalent to no async group id. - int64 async_group_id = 78; + // Used to be async_group_id. + reserved 78; // Represents a unique execution thread name for one or more async groups. // Each HLO module may contain a main thread and one or more parallel threads. @@ -371,17 +370,13 @@ message HloInstructionProto { // graph. xla.StatisticsViz statistics_viz = 82; - // Specifies which operation queue the current instruction will run on. - // A backend may have multiple operation queues to run instructions - // concurrently, use this to signal the backend which queue to dispatch to. - // The backend should keep a mapping of - // operation_queue_id->actual_hardware_queue_id if runtime will create - // different IDs. - int64 operation_queue_id = 83; - - // Specifies which operation queues to await for data when running with - // multiple operation queues. - repeated int64 wait_on_operation_queues = 84; + // Used to be operation_queue_id. + reserved 83; + // Used to be wait_on_operation_queues. + reserved 84; + + // Sparsity descriptor for dot operation. + repeated xla.SparsityDescriptor dot_sparsity = 86; } // Serialization of HloComputation. @@ -799,6 +794,9 @@ message HloPassMetadata { // Timestamp before and after the pass is run. Note they may be equal. int64 start_timestamp_usec = 8; int64 end_timestamp_usec = 9; + + // Custom metadata for the pass. + google.protobuf.Any custom_metadata = 10; } // Encodes the underlying Xla runtime executable compiled from the XLA module. diff --git a/xla/service/hlo_alias_analysis.cc b/xla/service/hlo_alias_analysis.cc index 186352bd84144..444255d9c344b 100644 --- a/xla/service/hlo_alias_analysis.cc +++ b/xla/service/hlo_alias_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -381,7 +381,7 @@ std::string HloAliasAnalysis::ToString() const { } /* static */ -StatusOr> HloAliasAnalysis::Run( +absl::StatusOr> HloAliasAnalysis::Run( const HloModule* module, const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) { VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); diff --git a/xla/service/hlo_alias_analysis.h b/xla/service/hlo_alias_analysis.h index e3ff80c4f8199..b0896a93c6883 100644 --- a/xla/service/hlo_alias_analysis.h +++ b/xla/service/hlo_alias_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ class HloAliasAnalysis { public: // The callgraph of the given HloModule must be flattened // (xla::FlattenCallGraph) prior to running the analysis. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module, const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr); diff --git a/xla/service/hlo_alias_analysis_test.cc b/xla/service/hlo_alias_analysis_test.cc index b87a4964c2a96..36709bc0a8e79 100644 --- a/xla/service/hlo_alias_analysis_test.cc +++ b/xla/service/hlo_alias_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_buffer.cc b/xla/service/hlo_buffer.cc index 3641b60b760ed..4c7569d10d9b2 100644 --- a/xla/service/hlo_buffer.cc +++ b/xla/service/hlo_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_buffer.h b/xla/service/hlo_buffer.h index 90d064d9b7cc4..398b08d8cfdee 100644 --- a/xla/service/hlo_buffer.h +++ b/xla/service/hlo_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_casting_utils_test.cc b/xla/service/hlo_casting_utils_test.cc index 4c0acf40e00c0..e4cb40244f22e 100644 --- a/xla/service/hlo_casting_utils_test.cc +++ b/xla/service/hlo_casting_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_computation_deduplicator.cc b/xla/service/hlo_computation_deduplicator.cc index fc1f4971093ac..52c5ef8326380 100644 --- a/xla/service/hlo_computation_deduplicator.cc +++ b/xla/service/hlo_computation_deduplicator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ bool HloComputationDeduplicator::ContainsLargeConstants(HloComputation* comp) { } return false; } -StatusOr HloComputationDeduplicator::Run( +absl::StatusOr HloComputationDeduplicator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { absl::flat_hash_map unique_comps; diff --git a/xla/service/hlo_computation_deduplicator.h b/xla/service/hlo_computation_deduplicator.h index 42fcb191ce943..96d67c6eb73bf 100644 --- a/xla/service/hlo_computation_deduplicator.h +++ b/xla/service/hlo_computation_deduplicator.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class HloComputationDeduplicator : public HloModulePass { absl::string_view name() const override { return "computation-deduplicator"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/hlo_computation_deduplicator_test.cc b/xla/service/hlo_computation_deduplicator_test.cc index 3dab6805d21a9..df7f725c6668a 100644 --- a/xla/service/hlo_computation_deduplicator_test.cc +++ b/xla/service/hlo_computation_deduplicator_test.cc @@ -1,5 +1,5 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_computation_test.cc b/xla/service/hlo_computation_test.cc index 2e576161664e9..a0a2f5ecb01a3 100644 --- a/xla/service/hlo_computation_test.cc +++ b/xla/service/hlo_computation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" +#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" @@ -487,6 +489,44 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { EXPECT_EQ(negate, computation->root_instruction()); } +TEST_F(HloComputationTest, ReplaceParameter) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(-1) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] compare(gte, const), direction=EQ + } + + ENTRY entry { + param.1 = s32[] parameter(0) + const = f32[2] constant({0,1}) + while_init = (f32[2], s32[]) tuple(const, param.1) + while = (f32[2], s32[]) while(while_init), condition=condition, body=body + ROOT out = s32[] get-tuple-element(while), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(kHloModule)); + HloComputation* body = module->GetComputationWithName("body"); + + Shape new_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {2}), ShapeUtil::MakeShape(S32, {})}); + body->ReplaceParameter( + 0, HloInstruction::CreateParameter(0, new_param_shape, "new_p_body")); + + EXPECT_TRUE(ShapeUtil::Equal(body->parameter_instruction(0)->shape(), + new_param_shape)); +} + TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( @@ -855,52 +895,13 @@ TEST_F(HloComputationTest, CloneWrappedAsyncInstructionSameWrappedFunc) { ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* start = FindInstruction(module.get(), "reduce-scatter-start"); HloInstruction* done = FindInstruction(module.get(), "reduce-scatter-done"); - EXPECT_EQ(start->called_computations()[0], done->called_computations()[0]); - std::unique_ptr cloned_start = start->Clone(); - std::unique_ptr cloned_done = - done->CloneWithNewOperands(done->shape(), {cloned_start.get()}); - EXPECT_EQ(cloned_start.get()->called_computations()[0], - cloned_done.get()->called_computations()[0]); -} - -TEST_F(HloComputationTest, CloneWrappedAsyncInstructionDiffWrappedFunc) { - const char* const hlo_string = R"( - HloModule Module - add (lhs: u32[], rhs: u32[]) -> u32[] { - lhs = u32[] parameter(0) - rhs = u32[] parameter(1) - ROOT add = u32[] add(u32[] lhs, u32[] rhs) - } - - async_wrapped_1 (async_param: u32[8]) -> u32[4] { - async_param = u32[8]{0} parameter(0) - ROOT %reduce-scatter = u32[4]{0} reduce-scatter(u32[8]{0} async_param), - replica_groups={}, dimensions={0}, to_apply=add - } - - async_wrapped_2 (async_param.1: u32[8]) -> u32[4] { - async_param.1 = u32[8]{0} parameter(0) - ROOT reduce-scatter.1 = u32[4]{0} reduce-scatter(u32[8]{0} async_param.1), - replica_groups={}, dimensions={0}, to_apply=add - } - - ENTRY main (data: u32[8]) -> u32[4] { - data = u32[8]{0} parameter(0) - reduce-scatter-start = ((u32[8]{0}), u32[4]{0}) async-start(u32[8]{0} data), - calls=async_wrapped_1, backend_config={"is_sync":false} - ROOT reduce-scatter-done = u32[4]{0} async-done(((u32[8]{0}), u32[4]{0}) reduce-scatter-start), - calls=async_wrapped_2 -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* start = FindInstruction(module.get(), "reduce-scatter-start"); - HloInstruction* done = FindInstruction(module.get(), "reduce-scatter-done"); - EXPECT_NE(start->called_computations()[0], done->called_computations()[0]); + EXPECT_EQ(start->async_wrapped_computation(), + done->async_wrapped_computation()); std::unique_ptr cloned_start = start->Clone(); std::unique_ptr cloned_done = done->CloneWithNewOperands(done->shape(), {cloned_start.get()}); - EXPECT_NE(cloned_start.get()->called_computations()[0], - cloned_done.get()->called_computations()[0]); + EXPECT_EQ(cloned_start.get()->async_wrapped_computation(), + cloned_done.get()->async_wrapped_computation()); } } // namespace diff --git a/xla/service/hlo_constant_folding.cc b/xla/service/hlo_constant_folding.cc index 88d016ae5fc6b..22950e91bcdbf 100644 --- a/xla/service/hlo_constant_folding.cc +++ b/xla/service/hlo_constant_folding.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ static bool IsOrContainsIllegalInstr(const HloInstruction* instr) { /*static*/ std::atomic HloConstantFolding::slow_op_counter_{0}; -StatusOr HloConstantFolding::Run( +absl::StatusOr HloConstantFolding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Limit the constant folding to 0 iterations to skip folding loops. This @@ -233,6 +233,16 @@ StatusOr HloConstantFolding::Run( dead_instructions.push_back(instruction); HloInstruction* new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(result))); + if (new_constant->shape().has_layout()) { + // Update element_size_in_bits on the new instruction's layout. Literals + // always have element_size_in_bits set to 0, and CreateConstant copies + // the shape/layout from the Literal, so we need to set + // element_size_in_bits here. + new_constant->mutable_shape() + ->mutable_layout() + ->set_element_size_in_bits( + instruction->shape().layout().element_size_in_bits()); + } TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_constant)); } } diff --git a/xla/service/hlo_constant_folding.h b/xla/service/hlo_constant_folding.h index 217333f23ed98..e20be0d2db9ad 100644 --- a/xla/service/hlo_constant_folding.h +++ b/xla/service/hlo_constant_folding.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class HloConstantFolding : public HloModulePass { // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/hlo_constant_folding_test.cc b/xla/service/hlo_constant_folding_test.cc index 5e51acfc874eb..4958bee65f54d 100644 --- a/xla/service/hlo_constant_folding_test.cc +++ b/xla/service/hlo_constant_folding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -346,6 +346,32 @@ TEST_F(HloConstantFoldingTest, FoldOpsWhereOneOperandIsBroadcast) { ))); } +TEST_F(HloConstantFoldingTest, FoldInt4Ops) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY entry { + c0 = s4[2]{0:E(4)} constant({1, 2}) + c1 = s4[2]{0:E(4)} constant({3, 4}) + add1 = s4[2]{0:E(4)} add(c0, c1) + c2 = s4[]{:E(4)} constant(5) + add2 = s4[2]{0:E(4)} add(c0, s4[2]{0:E(4)} broadcast(c2)) + ROOT root = tuple(add1, add2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_TRUE(result); + auto is_4_bit = [](const HloInstruction* instr) { + return instr->shape().layout().element_size_in_bits() == 4; + }; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Constant().WithPredicate(is_4_bit), + m::Constant().WithPredicate(is_4_bit)))); +} + TEST_F(HloConstantFoldingTest, BigReduceWindow) { constexpr absl::string_view kModuleStr = R"( HloModule test diff --git a/xla/service/hlo_cost_analysis.cc b/xla/service/hlo_cost_analysis.cc index 9fe36cd2eb16e..c60a4193aaee2 100644 --- a/xla/service/hlo_cost_analysis.cc +++ b/xla/service/hlo_cost_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -25,11 +25,13 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" @@ -78,7 +80,7 @@ Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { if (key == kOptimalSecondsKey) { return; } - float per_second_rate = options_.per_second_rates[key]; + float per_second_rate = options_.per_second_rate(key); if (per_second_rate != 0) { optimal_seconds = std::max(optimal_seconds, val / per_second_rate); } @@ -133,13 +135,14 @@ Status HloCostAnalysis::HandleElementwiseOp( // operation can correspond to several floating point ops. // kLogistic is included in "trascendental" as it is implemented using // trascendental ops (tanh or exp). - if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || - opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower || - opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt || - opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || - opcode == HloOpcode::kSin || opcode == HloOpcode::kCos || - opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p || - opcode == HloOpcode::kAtan2 || opcode == HloOpcode::kTan) { + if (opcode == HloOpcode::kErf || opcode == HloOpcode::kExp || + opcode == HloOpcode::kLog || opcode == HloOpcode::kLogistic || + opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || + opcode == HloOpcode::kCbrt || opcode == HloOpcode::kRsqrt || + opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin || + opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 || + opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2 || + opcode == HloOpcode::kTan) { current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from @@ -179,7 +182,28 @@ int64_t HloCostAnalysis::FusionParameterReadBytes( switch (user->opcode()) { case HloOpcode::kFusion: { for (int64_t idx : user->OperandIndices(hlo)) { - size += FusionParameterReadBytes(user->fused_parameter(idx)); + auto nested_size = + FusionParameterReadBytes(user->fused_parameter(idx)); + const HloInstruction* root_instruction = + user->fused_instructions_computation()->root_instruction(); + // We define the nested fusion as simple if the parameter directly + // feeds the root. + const bool fusion_is_simple = + user->fused_parameter(idx) == root_instruction->operand(0); + const auto& fusion_users = user->users(); + auto is_slice = [](const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kSlice || + hlo->opcode() == HloOpcode::kDynamicSlice; + }; + // If the nested fusion is simple and the user is a slice, + // we only load that portion of the parameter. + // TODO(b/332998529): deal with nested fusions more generally. + if (fusion_is_simple && fusion_users.size() == 1 && + is_slice(fusion_users[0])) { + size += GetShapeSize(fusion_users[0]->shape()); + } else { + size += nested_size; + } } break; } @@ -410,11 +434,12 @@ Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) { // Count nested infeed output tuples. int64_t size = 0; - for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) { - size += GetShapeSize(indexed_shape.shape); - current_properties_.set_output_bytes_accessed( - indexed_shape.index, GetShapeSize(indexed_shape.shape)); - } + ShapeUtil::ForEachLeafShape( + infeed->shape(), [&](const Shape& sub_shape, const ShapeIndex& index) { + size += GetShapeSize(sub_shape); + current_properties_.set_output_bytes_accessed(index, + GetShapeSize(sub_shape)); + }); current_properties_.set_output_bytes_accessed(size); current_properties_[kBytesAccessedKey] = size; return OkStatus(); @@ -426,12 +451,14 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) { for (int64_t i = 0; i < outfeed->operand_count(); ++i) { const HloInstruction* operand = outfeed->operand(i); int64_t size = 0; - for (const auto& indexed_shape : - ShapeUtil::GetLeafShapes(operand->shape())) { - size += GetShapeSize(indexed_shape.shape); - current_properties_.set_operand_bytes_accessed( - i, indexed_shape.index, GetShapeSize(indexed_shape.shape)); - } + + ShapeUtil::ForEachLeafShape( + operand->shape(), [&](const Shape& sub_shape, const ShapeIndex& index) { + size += GetShapeSize(sub_shape); + current_properties_.set_operand_bytes_accessed( + i, index, GetShapeSize(sub_shape)); + }); + current_properties_.set_operand_bytes_accessed(i, size); current_properties_[kBytesAccessedKey] += size; } @@ -605,7 +632,7 @@ Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { Status HloCostAnalysis::HandleBroadcast(const HloInstruction* broadcast) { if (options_.count_multiple_input_accesses) { current_properties_.set_operand_bytes_accessed( - 0, ShapeUtil::ElementsIn(broadcast->shape())); + 0, GetShapeSize(broadcast->shape())); current_properties_.set_operand_utilization( 0, 1.0 * ShapeUtil::ElementsIn(broadcast->shape()) / ShapeUtil::ElementsIn(broadcast->operand(0)->shape())); @@ -969,6 +996,11 @@ Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { return OkStatus(); } +Status HloCostAnalysis::HandleCollectiveBroadcast( + const HloInstruction* /*hlo*/) { + return OkStatus(); +} + Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return OkStatus(); } @@ -1102,23 +1134,23 @@ Status HloCostAnalysis::FusionProcessOperandBytesRead( } else { // If the fusion parameter is a tuple type, find the gte for the leaf // shape and calculate the bytes accessed for those array types. - for (const auto& indexed_shape : - ShapeUtil::GetLeafShapes(operand->shape())) { - const HloInstruction* gte = operand; - for (int64_t index : indexed_shape.index) { - for (const HloInstruction* user : gte->users()) { - if (user->opcode() == HloOpcode::kGetTupleElement && - user->tuple_index() == index) { - gte = user; - break; + ShapeUtil::ForEachLeafShape( + operand->shape(), + [&](const Shape& /*sub_shape*/, const ShapeIndex& index) { + const HloInstruction* gte = operand; + for (int64_t sub_index : index) { + for (const HloInstruction* user : gte->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == sub_index) { + gte = user; + break; + } + } } - } - } - int64_t size = FusionParameterReadBytes(gte); - operand_size += size; - current_properties_.set_operand_bytes_accessed(i, indexed_shape.index, - size); - } + int64_t size = FusionParameterReadBytes(gte); + operand_size += size; + current_properties_.set_operand_bytes_accessed(i, index, size); + }); } current_properties_[kBytesAccessedKey] += operand_size; current_properties_.set_operand_bytes_accessed(i, operand_size); @@ -1390,21 +1422,23 @@ int64_t HloCostAnalysis::GetBytesRead( int64_t HloCostAnalysis::GetBytesWritten( const HloInstruction& hlo, std::optional memory_space) const { int64_t bytes_written = 0; - for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes(hlo.shape())) { - std::optional index_memory_space; - if (indexed_shape.shape.has_layout()) { - index_memory_space = indexed_shape.shape.layout().memory_space(); - } - if (!memory_space || memory_space == index_memory_space) { - bytes_written += output_bytes_accessed(hlo, indexed_shape.index); - } - } + + ShapeUtil::ForEachLeafShape( + hlo.shape(), [&](const Shape& sub_shape, const ShapeIndex& index) { + std::optional index_memory_space; + if (sub_shape.has_layout()) { + index_memory_space = sub_shape.layout().memory_space(); + } + if (!memory_space || memory_space == index_memory_space) { + bytes_written += output_bytes_accessed(hlo, index); + } + }); + return bytes_written; } -StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation) { +absl::StatusOr +HloCostAnalysis::ProcessSubcomputation(HloComputation* computation) { auto visitor = CreateNestedCostAnalysis(); visitor->ReserveVisitStates(computation->instruction_count()); TF_RETURN_IF_ERROR(computation->Accept(visitor.get())); @@ -1420,19 +1454,17 @@ std::unique_ptr HloCostAnalysis::CreateNestedCostAnalysis() { /*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey( int64_t operand_num, const ShapeIndex& index) { - return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ", - index.ToString()); + return absl::StrCat(kBytesAccessedKey, operand_num, index.ToString()); } /*static*/ std::string HloCostAnalysis::GetOperandUtilizationKey( int64_t operand_num, const ShapeIndex& index) { - return absl::StrCat(kUtilizationKey, " operand ", operand_num, " ", - index.ToString()); + return absl::StrCat(kUtilizationKey, operand_num, index.ToString()); } /*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey( const ShapeIndex& index) { - return absl::StrCat(kBytesAccessedKey, " output ", index.ToString()); + return absl::StrCat(kBytesAccessedKey, "out", index.ToString()); } bool HloCostAnalysis::KeyToCopyFromSubcomputation(absl::string_view key) const { diff --git a/xla/service/hlo_cost_analysis.h b/xla/service/hlo_cost_analysis.h index 705eb2437860f..0a6ecebec265c 100644 --- a/xla/service/hlo_cost_analysis.h +++ b/xla/service/hlo_cost_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -334,15 +334,15 @@ class HloCostAnalysis : public ConstDfsHloVisitor { private: // These must match GetOperandUtilizationKey(0, {}) etc. static inline constexpr absl::string_view kOperand0UtilizationKey = - "utilization operand 0 {}"; + "utilization0{}"; static inline constexpr absl::string_view kOperand1UtilizationKey = - "utilization operand 1 {}"; + "utilization1{}"; static inline constexpr absl::string_view kOperand0BytesAccessedKey = - "bytes accessed operand 0 {}"; + "bytes accessed0{}"; static inline constexpr absl::string_view kOperand1BytesAccessedKey = - "bytes accessed operand 1 {}"; + "bytes accessed1{}"; static inline constexpr absl::string_view kOutputRootBytesAccessedKey = - "bytes accessed output {}"; + "bytes accessedout{}"; float flops_; float transcendentals_; @@ -399,7 +399,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { } // Returns the specified per-second rate used by cost analysis. - float per_second_rate(const std::string& key) const { + float per_second_rate(absl::string_view key) const { return per_second_rates[key]; } }; @@ -445,6 +445,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleAllReduceStart(const HloInstruction* hlo) override; Status HandleAllReduceDone(const HloInstruction* hlo) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectiveBroadcast(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; @@ -547,7 +548,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Returns the specified per-second rate used by cost analysis. float per_second_rate(absl::string_view key) const { - return options_.per_second_rates[key]; + return options_.per_second_rate(key); } // Return the key that is used to index into Properties for the specified @@ -620,7 +621,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // given hlo. The cost of visited sub HLO instructions is saved to // hlo_properties_, which will be used by functions such as // flop_count(hlo_instruction) to return cost of a particular HLO instruction. - virtual StatusOr ProcessSubcomputation( + virtual absl::StatusOr ProcessSubcomputation( HloComputation* computation); // Utility function to handle all element-wise operations. diff --git a/xla/service/hlo_cost_analysis_test.cc b/xla/service/hlo_cost_analysis_test.cc index a656bf47a337f..341003499a623 100644 --- a/xla/service/hlo_cost_analysis_test.cc +++ b/xla/service/hlo_cost_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -631,6 +631,23 @@ TEST_F(HloCostAnalysisTest, Broadcast) { EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 7); } +TEST_F(HloCostAnalysisTest, BroadcastCountMultipleInputAccesses) { + XlaBuilder b("broadcast"); + Broadcast(ConstantR0(&b, 42), {10, 7}); + auto hlo_module = BuildHloGraph(&b); + HloCostAnalysis analysis(HloCostAnalysis::Options{ + .shape_size = ShapeSize, .count_multiple_input_accesses = true}); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + EXPECT_EQ(analysis.flop_count(), 0); + + EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (1 + 10 * 7)); + + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 7); + EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 7); +} + // Calculates the computation cost of a graph with more than one HLO node. TEST_F(HloCostAnalysisTest, FullyConnectedForward) { XlaBuilder builder("fully_connected_forward"); @@ -815,6 +832,62 @@ TEST_F(FusionCostAnalysis, LoopFusion) { } } +TEST_F(FusionCostAnalysis, NestedCopyFusion) { + absl::string_view nested_fusion_text = R"( +HloModule temp, is_scheduled=true + +copy_fusion.1291.clone { + input.1291 = s8[2,6144,2,256]{3,1,0,2:T(32,128)(4,1)S(1)} parameter(0) + ROOT copy.74276 = s8[2,6144,2,256]{3,1,0,2:T(8,128)(4,1)} copy(input.1291) +} + +fused_computation.4150.clone { + param_0.185389 = s8[2,6144,2,256]{3,1,0,2:T(32,128)(4,1)} parameter(0) + fusion.103344 = s8[2,6144,2,256]{3,1,0,2:T(8,128)(4,1)} fusion(param_0.185389), kind=kLoop, calls=copy_fusion.1291.clone + constant.230138 = s32[]{:T(128)} constant(0) + param_1.219146 = s32[]{:T(128)S(6)} parameter(1) + ROOT dynamic-slice.40526 = s8[2,384,2,256]{3,1,0,2:T(8,128)(4,1)} dynamic-slice(fusion.103344, constant.230138, param_1.219146, constant.230138, constant.230138), dynamic_slice_sizes={2,384,2,256} +} + +ENTRY temp { + param_2.123719 = s8[2,6144,2,256]{3,1,0,2:T(32,128)(4,1)} parameter(0) + param_3.66279 = s32[]{:T(128)S(6)} parameter(1) + ROOT fusion.85943 = s8[2,384,2,256]{3,1,0,2:T(8,128)(4,1)} fusion(param_2.123719, param_3.66279), kind=kLoop, calls=fused_computation.4150.clone +} +)"; + absl::string_view fusion_text = R"( +HloModule temp, is_scheduled=true + +fused_computation.4150.clone { + param_0.185389 = s8[2,6144,2,256]{3,1,0,2:T(32,128)(4,1)} parameter(0) + constant.230138 = s32[]{:T(128)} constant(0) + param_1.219146 = s32[]{:T(128)S(6)} parameter(1) + ROOT dynamic-slice.40526 = s8[2,384,2,256]{3,1,0,2:T(8,128)(4,1)} dynamic-slice(param_0.185389, constant.230138, param_1.219146, constant.230138, constant.230138), dynamic_slice_sizes={2,384,2,256} +} + +ENTRY temp { + param_2.123719 = s8[2,6144,2,256]{3,1,0,2:T(32,128)(4,1)} parameter(0) + param_3.66279 = s32[]{:T(128)S(6)} parameter(1) + ROOT fusion.85943 = s8[2,384,2,256]{3,1,0,2:T(8,128)(4,1)} fusion(param_2.123719, param_3.66279), kind=kLoop, calls=fused_computation.4150.clone +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto nested_fusion_module, + ParseAndReturnVerifiedModule(nested_fusion_text)); + HloCostAnalysis nested_analysis(ShapeSize); + auto* nested_root = + nested_fusion_module->entry_computation()->root_instruction(); + ASSERT_IS_OK(nested_root->Accept(&nested_analysis)); + TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, + ParseAndReturnVerifiedModule(fusion_text)); + HloCostAnalysis fusion_analysis(ShapeSize); + auto* fusion_root = fusion_module->entry_computation()->root_instruction(); + ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); + // The nested fusion should only access the bytes size amount of the parameter + // based on the size of the consuming dynamic slice. + EXPECT_EQ(nested_analysis.bytes_accessed(*nested_root), + fusion_analysis.bytes_accessed(*fusion_root)); +} + TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); diff --git a/xla/service/hlo_creation_utils.cc b/xla/service/hlo_creation_utils.cc index 34e188e7be003..d5bb5a2307937 100644 --- a/xla/service/hlo_creation_utils.cc +++ b/xla/service/hlo_creation_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,9 +52,9 @@ limitations under the License. namespace xla { using absl::StrCat; -StatusOr MakeUnaryHlo(HloOpcode opcode, - HloInstruction* operand, - const OpMetadata* metadata) { +absl::StatusOr MakeUnaryHlo(HloOpcode opcode, + HloInstruction* operand, + const OpMetadata* metadata) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape unary_op_shape, ShapeInference::InferUnaryOpShape(opcode, operand)); @@ -67,9 +67,10 @@ HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to) { HloInstruction::CreateUnary(to, HloOpcode::kCopy, from)); } -StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, - HloInstruction* rhs, - const OpMetadata* metadata) { +absl::StatusOr MakeBinaryHlo(HloOpcode opcode, + HloInstruction* lhs, + HloInstruction* rhs, + const OpMetadata* metadata) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape binary_op_shape, @@ -79,10 +80,10 @@ StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, metadata); } -StatusOr MakeCompareHlo(ComparisonDirection direction, - HloInstruction* lhs, - HloInstruction* rhs, - const OpMetadata* metadata) { +absl::StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs, + const OpMetadata* metadata) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( @@ -93,10 +94,10 @@ StatusOr MakeCompareHlo(ComparisonDirection direction, metadata); } -StatusOr MakePadHlo(HloInstruction* operand, - HloInstruction* padding_value, - const PaddingConfig& padding_config, - const OpMetadata* metadata) { +absl::StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config, + const OpMetadata* metadata) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, padding_value->parent()); TF_ASSIGN_OR_RETURN( @@ -109,11 +110,10 @@ StatusOr MakePadHlo(HloInstruction* operand, metadata); } -StatusOr MakeSliceHlo(HloInstruction* operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides, - const OpMetadata* metadata) { +absl::StatusOr MakeSliceHlo( + HloInstruction* operand, absl::Span start_indices, + absl::Span limit_indices, absl::Span strides, + const OpMetadata* metadata) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -124,7 +124,7 @@ StatusOr MakeSliceHlo(HloInstruction* operand, metadata); } -StatusOr MakeConvolveHlo( +absl::StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, @@ -145,7 +145,7 @@ StatusOr MakeConvolveHlo( metadata); } -StatusOr MakeTransposeHlo( +absl::StatusOr MakeTransposeHlo( HloInstruction* operand, absl::Span dimensions) { TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -154,13 +154,13 @@ StatusOr MakeTransposeHlo( HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); } -StatusOr MakeReshapeHlo(const Shape& result_shape, - HloInstruction* operand) { +absl::StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand) { return operand->AddInstruction( HloInstruction::CreateReshape(result_shape, operand)); } -StatusOr MakeReshapeHlo( +absl::StatusOr MakeReshapeHlo( absl::Span result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), @@ -168,7 +168,7 @@ StatusOr MakeReshapeHlo( return MakeReshapeHlo(new_shape, operand); } -StatusOr MakeDynamicSliceHlo( +absl::StatusOr MakeDynamicSliceHlo( HloInstruction* operand, absl::Span start_indices, absl::Span slice_sizes, const OpMetadata* metadata) { // slice of a scalar is no-op @@ -189,7 +189,7 @@ StatusOr MakeDynamicSliceHlo( metadata); } -StatusOr MakeDynamicSliceHlo( +absl::StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); @@ -218,7 +218,7 @@ StatusOr MakeDynamicSliceHlo( metadata); } -StatusOr MakeDynamicUpdateSliceHlo( +absl::StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); @@ -248,7 +248,7 @@ StatusOr MakeDynamicUpdateSliceHlo( metadata); } -StatusOr MakeDynamicUpdateSliceHlo( +absl::StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, absl::Span start_indices, const OpMetadata* metadata) { @@ -289,9 +289,8 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand, metadata); } -StatusOr MakeGetTupleElementHlo(HloInstruction* operand, - int64_t index, - const OpMetadata* metadata) { +absl::StatusOr MakeGetTupleElementHlo( + HloInstruction* operand, int64_t index, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( @@ -302,7 +301,7 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, metadata); } -StatusOr MakeConcatHlo( +absl::StatusOr MakeConcatHlo( absl::Span operands, int64_t dimension, const OpMetadata* metadata) { CHECK_GT(operands.size(), 0); @@ -370,27 +369,28 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, HloInstruction::CreateIota(shape, iota_dimension)); } -StatusOr MakeDotHlo( +absl::StatusOr MakeDotHlo( HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, std::optional preferred_element_type, - const OpMetadata* metadata) { + std::vector sparsity, + absl::Span sparse_meta, const OpMetadata* metadata) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN( - Shape dot_shape, - ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers, - preferred_element_type)); + TF_ASSIGN_OR_RETURN(Shape dot_shape, + ShapeInference::InferDotOpShape( + lhs->shape(), rhs->shape(), dim_numbers, + preferred_element_type, absl::MakeSpan(sparsity))); return computation->AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers, - precision_config), + precision_config, sparsity, sparse_meta), metadata); } -StatusOr MakeMapHlo(absl::Span operands, - HloComputation* map_computation, - const OpMetadata* metadata) { +absl::StatusOr MakeMapHlo( + absl::Span operands, HloComputation* map_computation, + const OpMetadata* metadata) { CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; HloComputation* computation = operands.front()->parent(); std::vector operand_shapes; @@ -420,11 +420,10 @@ HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand, metadata); } -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - absl::Span dimensions, - HloComputation* reduce_computation, - const OpMetadata* metadata) { +absl::StatusOr MakeReduceHlo( + HloInstruction* operand, HloInstruction* init_value, + absl::Span dimensions, HloComputation* reduce_computation, + const OpMetadata* metadata) { auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); auto result_shape = ShapeUtil::DeleteDimensions(dimensions, operand->shape()); @@ -434,7 +433,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, metadata); } -StatusOr MakeReduceWindowHlo( +absl::StatusOr MakeReduceWindowHlo( HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation, const OpMetadata* metadata) { TF_ASSIGN_OR_RETURN(Shape inferred_shape, @@ -447,11 +446,10 @@ StatusOr MakeReduceWindowHlo( metadata); } -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - absl::Span dimensions, - HloOpcode binary_opcode, - const OpMetadata* metadata) { +absl::StatusOr MakeReduceHlo( + HloInstruction* operand, HloInstruction* init_value, + absl::Span dimensions, HloOpcode binary_opcode, + const OpMetadata* metadata) { auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); HloComputation* reduce_computation; { @@ -470,11 +468,11 @@ StatusOr MakeReduceHlo(HloInstruction* operand, metadata); } -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - HloOpcode binary_opcode, - HloModule* module, - const OpMetadata* metadata) { +absl::StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module, + const OpMetadata* metadata) { DCHECK_NE(nullptr, module); std::vector all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); @@ -496,7 +494,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, metadata); } -StatusOr MakeReduceHlo( +absl::StatusOr MakeReduceHlo( absl::Span operands, absl::Span init_values, absl::Span dimensions, HloComputation* reduce_computation, @@ -527,9 +525,9 @@ StatusOr MakeReduceHlo( metadata); } -StatusOr MakeReverseHlo(HloInstruction* operand, - absl::Span dimensions, - const OpMetadata* metadata) { +absl::StatusOr MakeReverseHlo( + HloInstruction* operand, absl::Span dimensions, + const OpMetadata* metadata) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape( operand->shape(), dimensions)); @@ -538,10 +536,10 @@ StatusOr MakeReverseHlo(HloInstruction* operand, metadata); } -StatusOr MakeSelectHlo(HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false, - HloInstruction* derived_from) { +absl::StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false, + HloInstruction* derived_from) { HloComputation* computation = pred->parent(); DCHECK_EQ(computation, on_true->parent()); DCHECK_EQ(computation, on_false->parent()); @@ -581,7 +579,7 @@ HloInstruction* MaybeMakeTuple(absl::Span operands) { HloInstruction::CreateTuple(operands)); } -StatusOr MakeSortHlo( +absl::StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module, const OpMetadata* metadata) { @@ -607,8 +605,8 @@ StatusOr MakeSortHlo( sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } -StatusOr CollapseFirstNDims(HloInstruction* operand, - int64_t n) { +absl::StatusOr CollapseFirstNDims(HloInstruction* operand, + int64_t n) { CHECK_GT(n, 0); const Shape& operand_shape = operand->shape(); @@ -641,8 +639,8 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, return MakeReshapeHlo(output_shape, operand); } -StatusOr PrependDegenerateDims(HloInstruction* operand, - int64_t n) { +absl::StatusOr PrependDegenerateDims(HloInstruction* operand, + int64_t n) { CHECK_GT(n, 0); std::vector new_shape_dims; const Shape& operand_shape = operand->shape(); @@ -652,7 +650,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, return MakeReshapeHlo(new_shape_dims, operand); } -StatusOr ExpandFirstDimIntoNDims( +absl::StatusOr ExpandFirstDimIntoNDims( HloInstruction* operand, absl::Span expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); @@ -669,7 +667,7 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ElideDegenerateDims( +absl::StatusOr ElideDegenerateDims( HloInstruction* operand, absl::Span dims_to_elide) { return MakeReshapeHlo(ShapeUtil::FilterDimensions( [&](int64_t dim) { @@ -679,7 +677,7 @@ StatusOr ElideDegenerateDims( operand); } -StatusOr InsertDegenerateDims( +absl::StatusOr InsertDegenerateDims( HloInstruction* operand, absl::Span dims_to_insert) { CHECK(absl::c_is_sorted(dims_to_insert)); @@ -711,9 +709,9 @@ StatusOr InsertDegenerateDims( return MakeReshapeHlo(output_shape, operand); } -StatusOr PadVectorWithZeros(HloInstruction* operand, - int64_t zeros_to_prepend, - int64_t zeros_to_append) { +absl::StatusOr PadVectorWithZeros(HloInstruction* operand, + int64_t zeros_to_prepend, + int64_t zeros_to_append) { HloComputation* computation = operand->parent(); CHECK_EQ(operand->shape().dimensions_size(), 1); PaddingConfig padding_config; @@ -752,7 +750,7 @@ HloInstruction* BroadcastOnes(HloComputation* computation, /*result_shape_bounds=*/broadcast_dimensions); } -StatusOr MakeFusionInstruction( +absl::StatusOr MakeFusionInstruction( HloInstruction* fused, HloInstruction::FusionKind kind) { HloComputation* comp = fused->parent(); HloInstruction* fusion_instruction = comp->AddInstruction( @@ -777,7 +775,7 @@ HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) { return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions)); } -StatusOr> CreateComputationWithSignature( +absl::StatusOr> CreateComputationWithSignature( absl::Span domain, const Shape& range, absl::string_view name) { HloComputation::Builder b{std::string(name)}; diff --git a/xla/service/hlo_creation_utils.h b/xla/service/hlo_creation_utils.h index 5f5640ad3ec46..443bb93a0bc5a 100644 --- a/xla/service/hlo_creation_utils.h +++ b/xla/service/hlo_creation_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,14 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal_util.h" #include "xla/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { @@ -32,47 +35,44 @@ namespace xla { // Creates a unary HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeUnaryHlo(HloOpcode opcode, - HloInstruction* operand, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeUnaryHlo( + HloOpcode opcode, HloInstruction* operand, + const OpMetadata* metadata = nullptr); // Creates a binary HLO instruction and adds it to the computation containing // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). -StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, - HloInstruction* rhs, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeBinaryHlo( + HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs, + const OpMetadata* metadata = nullptr); // Creates a kCopy HLO. HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to); // Creates a compare HLO instruction and adds it to the computation containing // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). -StatusOr MakeCompareHlo(Comparison::Direction direction, - HloInstruction* lhs, - HloInstruction* rhs, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeCompareHlo( + Comparison::Direction direction, HloInstruction* lhs, HloInstruction* rhs, + const OpMetadata* metadata = nullptr); // Creates a pad HLO instruction and adds it to the computation containing // `operand` and `padding_value` (`operand` and `padding_value` must be in the // same computation). -StatusOr MakePadHlo(HloInstruction* operand, - HloInstruction* padding_value, - const PaddingConfig& padding_config, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakePadHlo( + HloInstruction* operand, HloInstruction* padding_value, + const PaddingConfig& padding_config, const OpMetadata* metadata = nullptr); // Creates a slice HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeSliceHlo(HloInstruction* operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeSliceHlo( + HloInstruction* operand, absl::Span start_indices, + absl::Span limit_indices, absl::Span strides, + const OpMetadata* metadata = nullptr); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). // If the result shape has integral element type, an optional // preferred_element_type can be specified to override the element type. -StatusOr MakeConvolveHlo( +absl::StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, @@ -82,25 +82,25 @@ StatusOr MakeConvolveHlo( // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeTransposeHlo( +absl::StatusOr MakeTransposeHlo( HloInstruction* operand, absl::Span dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeReshapeHlo(const Shape& result_shape, - HloInstruction* operand); +absl::StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand); -StatusOr MakeReshapeHlo( +absl::StatusOr MakeReshapeHlo( absl::Span result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). -StatusOr MakeDynamicSliceHlo( +absl::StatusOr MakeDynamicSliceHlo( HloInstruction* operand, absl::Span start_indices, absl::Span slice_sizes, const OpMetadata* metadata = nullptr); -StatusOr MakeDynamicSliceHlo( +absl::StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes, const OpMetadata* metadata = nullptr); @@ -108,13 +108,13 @@ StatusOr MakeDynamicSliceHlo( // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and // `start_indices` must be in the same computation). -StatusOr MakeDynamicUpdateSliceHlo( +absl::StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices, const OpMetadata* metadata = nullptr); // a variant of dynamic-update-slice where `start_indices` is a vector of HLO // instructions -StatusOr MakeDynamicUpdateSliceHlo( +absl::StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, absl::Span start_indices, const OpMetadata* metadata = nullptr); @@ -132,14 +132,14 @@ HloInstruction* MakeBroadcastHlo(HloInstruction* operand, // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. -StatusOr MakeGetTupleElementHlo( +absl::StatusOr MakeGetTupleElementHlo( HloInstruction* operand, int64_t index, const OpMetadata* metadata = nullptr); // Creates a Concatenate HLO instruction and adds it to the computation // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). -StatusOr MakeConcatHlo( +absl::StatusOr MakeConcatHlo( absl::Span operands, int64_t dimension, const OpMetadata* metadata = nullptr); @@ -163,19 +163,22 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape, // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). If the result shape has // integral element type, an optional preferred_element_type can be specified to -// override the element type. -StatusOr MakeDotHlo( +// override the element type. If 'sparsity' is set, then 'sparse_meta' must also +// be present (and have the same size). +absl::StatusOr MakeDotHlo( HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, std::optional preferred_element_type, + std::vector sparsity = {}, + absl::Span sparse_meta = {}, const OpMetadata* metadata = nullptr); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. -StatusOr MakeMapHlo(absl::Span operands, - HloComputation* map_computation, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeMapHlo( + absl::Span operands, HloComputation* map_computation, + const OpMetadata* metadata = nullptr); // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to @@ -184,30 +187,27 @@ HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand, int exponent_bits, int mantissa_bits, const OpMetadata* metadata = nullptr); -StatusOr MakeReduceWindowHlo( +absl::StatusOr MakeReduceWindowHlo( HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation, const OpMetadata* metadata = nullptr); // Creates a Reduce HLO instruction and adds it to the computation containing // the operand. This will create the sub-computation needed for the reduction in // the given module. binary_opcode should represent a binary operation. -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - absl::Span dimensions, - HloOpcode binary_opcode, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeReduceHlo( + HloInstruction* operand, HloInstruction* init_value, + absl::Span dimensions, HloOpcode binary_opcode, + const OpMetadata* metadata = nullptr); -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - absl::Span dimensions, - HloComputation* reduce_computation, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeReduceHlo( + HloInstruction* operand, HloInstruction* init_value, + absl::Span dimensions, HloComputation* reduce_computation, + const OpMetadata* metadata = nullptr); -StatusOr MakeReduceHlo(HloInstruction* operand, - HloInstruction* init_value, - HloOpcode binary_opcode, - HloModule* module, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeReduceHlo( + HloInstruction* operand, HloInstruction* init_value, + HloOpcode binary_opcode, HloModule* module, + const OpMetadata* metadata = nullptr); // Generic helper function to create a reduction. // @@ -216,7 +216,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, // // Creates a non-variadic reduction if the size is singular, and a variadic one // otherwise. -StatusOr MakeReduceHlo( +absl::StatusOr MakeReduceHlo( absl::Span operands, absl::Span init_values, absl::Span dimensions, HloComputation* reduce_computation, @@ -224,18 +224,17 @@ StatusOr MakeReduceHlo( // Creates a Reverse HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeReverseHlo(HloInstruction* operand, - absl::Span dimensions, - const OpMetadata* metadata = nullptr); +absl::StatusOr MakeReverseHlo( + HloInstruction* operand, absl::Span dimensions, + const OpMetadata* metadata = nullptr); // Creates a Select HLO instruction and adds it to the computation containing // the predicate. The on_true and on_false instructions must also be contained // in the same computation. If on_true and on_false are tuples, create a tuple // select instead. `pred` is broadcasted up from a scalar if necessary. -StatusOr MakeSelectHlo(HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false, - HloInstruction* derived_from = nullptr); +absl::StatusOr MakeSelectHlo( + HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false, + HloInstruction* derived_from = nullptr); // Forwards the first operand if operands.size() == 1, or creates a tuple // instruction with all the operands. Crashes if `operands` is empty. @@ -245,7 +244,7 @@ HloInstruction* MaybeMakeTuple(absl::Span operands); // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending // order. 'is_stable' specifies whether the sorting should be stable. -StatusOr MakeSortHlo( +absl::StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module, const OpMetadata* metadata = nullptr); @@ -253,9 +252,9 @@ StatusOr MakeSortHlo( // Creates an R1 Constant HLO instruction of the given PrimitiveType with the // given values and adds it to the given computation. template -StatusOr MakeR1ConstantHlo(HloComputation* computation, - PrimitiveType type, - absl::Span values) { +absl::StatusOr MakeR1ConstantHlo( + HloComputation* computation, PrimitiveType type, + absl::Span values) { Literal literal = LiteralUtil::CreateR1(values); if (literal.shape().element_type() != type) { TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); @@ -284,13 +283,13 @@ HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) { *scalar->mutable_shape() = base->shape(); return scalar; } - return base->AddInstruction( - HloInstruction::CreateBroadcast(base->shape(), scalar, {})); + return base->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeStaticShape(base->shape()), scalar, {})); } // Creates a fusion instruction and fuses `fused` into the created fusion // instruction. -StatusOr MakeFusionInstruction( +absl::StatusOr MakeFusionInstruction( HloInstruction* fused, HloInstruction::FusionKind kind); // ----------------------------------------------------------------------------- @@ -304,8 +303,8 @@ StatusOr MakeFusionInstruction( // // For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is // the `operand` reshaped to [56,9]. -StatusOr CollapseFirstNDims(HloInstruction* operand, - int64_t n); +absl::StatusOr CollapseFirstNDims(HloInstruction* operand, + int64_t n); // Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand` // using a reshape. @@ -313,8 +312,8 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, // For instance if operand has shape f32[3,4,5] then this returns the operand // reshaped to f32[1,3,4,5]. If the operand is a f32 scalar (i.e. has shape // f32[]) then this returns the operand reshaped to f32[1]. -StatusOr PrependDegenerateDims(HloInstruction* operand, - int64_t n); +absl::StatusOr PrependDegenerateDims(HloInstruction* operand, + int64_t n); // Expands (via reshape) the first (logical) dimension of `operand` into a // sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 @@ -323,7 +322,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, // // For instance if `operand` has shape f32[200,9,7] and expanded_dims is // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. -StatusOr ExpandFirstDimIntoNDims( +absl::StatusOr ExpandFirstDimIntoNDims( HloInstruction* operand, absl::Span expanded_dims); // Elides (via reshape) a set of degenerate dimensions (dimensions containing @@ -333,7 +332,7 @@ StatusOr ExpandFirstDimIntoNDims( // // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. -StatusOr ElideDegenerateDims( +absl::StatusOr ElideDegenerateDims( HloInstruction* operand, absl::Span dims_to_elide); // Inserts (via reshape) a set of degenerate dimensions (dimensions containing @@ -343,14 +342,14 @@ StatusOr ElideDegenerateDims( // // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. -StatusOr InsertDegenerateDims( +absl::StatusOr InsertDegenerateDims( HloInstruction* operand, absl::Span dims_to_insert); // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. -StatusOr PadVectorWithZeros(HloInstruction* operand, - int64_t zeros_to_prepend, - int64_t zeros_to_append); +absl::StatusOr PadVectorWithZeros(HloInstruction* operand, + int64_t zeros_to_prepend, + int64_t zeros_to_append); // Broadcasts a zero value of type `element_type` into a tensor with element // type `element_type` and dimension bounds `broadcast_dimensions`. The @@ -370,7 +369,7 @@ HloInstruction* BroadcastOnes(HloComputation* computation, // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. -StatusOr> CreateComputationWithSignature( +absl::StatusOr> CreateComputationWithSignature( absl::Span domain, const Shape& range, absl::string_view name); diff --git a/xla/service/hlo_creation_utils_test.cc b/xla/service/hlo_creation_utils_test.cc index 48708f2aa218b..4df62a6463e48 100644 --- a/xla/service/hlo_creation_utils_test.cc +++ b/xla/service/hlo_creation_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -372,6 +373,7 @@ TEST_F(HloCreationUtilsTest, MaybeMakeTupleTuplizesMultipleOperands) { Literal expected_result = LiteralUtil::MakeTuple({&input1, &input0}); EXPECT_EQ(result_literal, expected_result); } + TEST_F(HloCreationUtilsTest, DynamicUpdateSliceVectorStartIndices) { auto module = CreateNewVerifiedModule("dus-creation-test"); // arg: @@ -485,5 +487,19 @@ TEST_F(HloCreationUtilsTest, ReduceWindow) { expected_output_shape); } +TEST_F(HloCreationUtilsTest, DynamicBroadcastShape) { + HloInstruction* param; + HloComputation* entry_computation; + + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{10}, + /*output_shape_dims=*/{10}, ¶m, + &entry_computation); + param->mutable_shape()->set_dynamic_dimension(0, true); + + HloInstruction* one_constant = MakeScalarLike(param, 1.0f); + // Broadcasts should always have a static shape that is inferred. + EXPECT_TRUE(one_constant->shape().is_static()); +} + } // namespace } // namespace xla diff --git a/xla/service/hlo_cse.cc b/xla/service/hlo_cse.cc index 74932943d4f54..9ade2a7d03d9c 100644 --- a/xla/service/hlo_cse.cc +++ b/xla/service/hlo_cse.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -60,7 +61,8 @@ struct ConstantKey { // While we're here, also combine identical iota instructions, since they need // similar treatment. template -StatusOr CombineConstants(HloComputation* computation) { +absl::StatusOr CombineConstants(HloComputation* computation, + bool only_scalars) { // Populating the domain map is somewhat expensive -- only do it if there are // kDomain ops in the computation. If there are no kDomain ops, the domain // map is trivial, every op gets mapped to the same domain. @@ -85,6 +87,10 @@ StatusOr CombineConstants(HloComputation* computation) { // invalidated due to deletion. ++inst_it; + if (only_scalars && !ShapeUtil::IsScalar(instruction->shape())) { + continue; + } + HloInstruction* match = nullptr; if (auto* constant_inst = DynCast(instruction)) { auto insert_result = constants.insert(ConstantKey{ @@ -215,7 +221,7 @@ struct CseKey { } // namespace -StatusOr HloCSE::Run( +absl::StatusOr HloCSE::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -249,10 +255,11 @@ StatusOr HloCSE::Run( continue; } - TF_ASSIGN_OR_RETURN(bool combined, - is_layout_sensitive_ - ? CombineConstants(computation) - : CombineConstants(computation)); + TF_ASSIGN_OR_RETURN( + bool combined, + is_layout_sensitive_ + ? CombineConstants(computation, only_scalars_) + : CombineConstants(computation, only_scalars_)); changed |= combined; // HLO instructions are grouped into equivalency classes by using the @@ -274,6 +281,10 @@ StatusOr HloCSE::Run( continue; } + if (only_scalars_ && !ShapeUtil::IsScalar(instruction->shape())) { + continue; + } + auto pair = representatives.insert(CseKey{instruction}); if (!pair.second) { HloInstruction* equivalent_instruction = pair.first->hlo; @@ -282,6 +293,8 @@ StatusOr HloCSE::Run( TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( instruction, /*cleanup=*/std::nullopt, ignore_control_dependencies_)); + VLOG(4) << "Replaced " << instruction->name() << " with " + << equivalent_instruction->name(); changed = true; continue; } diff --git a/xla/service/hlo_cse.h b/xla/service/hlo_cse.h index a7f703d08f3d2..1ccab0d5872eb 100644 --- a/xla/service/hlo_cse.h +++ b/xla/service/hlo_cse.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,17 +33,19 @@ class HloCSE : public HloModulePass { // when replacing instructions with their equivalents. explicit HloCSE(bool is_layout_sensitive, bool only_fusion_computations = false, - bool ignore_control_dependencies = false) + bool ignore_control_dependencies = false, + bool only_scalars = false) : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations), - ignore_control_dependencies_(ignore_control_dependencies) {} + ignore_control_dependencies_(ignore_control_dependencies), + only_scalars_(only_scalars) {} ~HloCSE() override = default; absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -51,6 +53,7 @@ class HloCSE : public HloModulePass { const bool is_layout_sensitive_; const bool only_fusion_computations_; const bool ignore_control_dependencies_; + const bool only_scalars_; }; } // namespace xla diff --git a/xla/service/hlo_cse_test.cc b/xla/service/hlo_cse_test.cc index 4de4c80aa272c..bd93f7cd88b8c 100644 --- a/xla/service/hlo_cse_test.cc +++ b/xla/service/hlo_cse_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -709,6 +710,32 @@ TEST_F(HloCseTest, OptimizationBarrier) { EXPECT_FALSE(changed); } +TEST_F(HloCseTest, OnlyScalar) { + const char* const hlo_string = R"( + HloModule m + + ENTRY entry { + %const1 = f32[] constant(1) + %const2 = f32[] constant(1) + %const3 = f32[2] constant({1,2}) + %const4 = f32[2] constant({1,2}) + %add.0 = f32[] add(%const1, %const2) + %add.1 = f32[2] add(%const3, %const4) + ROOT out = (f32[], f32[2]) tuple(%add.0, %add.1) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + HloCSE cse(/*is_layout_sensitive=*/false, /*only_fusion_computations=*/false, + /*ignore_control_dependencies=*/false, /*only_scalars=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); + EXPECT_TRUE(changed); + EXPECT_EQ(absl::c_count_if(m->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return instruction->IsConstant(); + }), + 3); +} + class HloCseCustomCallTest : public HloCseTest, public ::testing::WithParamInterface #include #include +#include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -331,7 +333,7 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, } void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { - const HloValue& value = *values_.at(value_id); + const HloValue& value = GetValue(value_id); VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; value_ids_to_delete_.push_back(value_id); @@ -525,11 +527,13 @@ bool HloDataflowAnalysis::Phi( } const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const { - return *values_.at(value_id); + DCHECK(values_.contains(value_id)) << "Value not found: " << value_id; + return *values_.find(value_id)->second; } HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) { - return *values_.at(value_id); + DCHECK(values_.contains(value_id)) << "Value not found: " << value_id; + return *values_.find(value_id)->second; } HloValueSet HloDataflowAnalysis::GetFlattenedValueSet( @@ -1402,12 +1406,16 @@ void HloDataflowAnalysis::Propagate() { const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( const HloInstruction* instruction) const { - return *value_sets_.at(instruction); + DCHECK(value_sets_.contains(instruction)) + << "Instruction " << instruction->ToString() << " not found."; + return *value_sets_.find(instruction)->second; } InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( const HloInstruction* instruction) { - return *value_sets_.at(instruction); + DCHECK(value_sets_.contains(instruction)) + << "Instruction " << instruction->ToString() << " not found."; + return *value_sets_.find(instruction)->second; } Status HloDataflowAnalysis::InitializeInstructionValueSets() { @@ -1647,8 +1655,8 @@ void HloDataflowAnalysis::OptimizePhiValues() { HloValue::Id phi_id = values[0]->id(); HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id); if (new_id != phi_id) { - VLOG(1) << "Replacing " << values[0]->ToString() << " with " - << GetValue(new_id).ToString(); + VLOG(1) << "Replacing " << values[0]->ToShortString() << " with " + << GetValue(new_id).ToShortString(); value_set->Clear(); const HloValue& new_value = GetValue(new_id); value_set->AddValue(&new_value); @@ -1660,7 +1668,7 @@ void HloDataflowAnalysis::OptimizePhiValues() { } /* static */ -StatusOr> HloDataflowAnalysis::Run( +absl::StatusOr> HloDataflowAnalysis::Run( const HloModule& module, bool ssa_form, bool bitcast_defines_value, const CanShareBuffer& can_share_buffer, const ForwardsValue& forwards_value, absl::flat_hash_set execution_threads) { @@ -1846,59 +1854,63 @@ std::vector> GetFusionInstructionInPlaceInputOutputPairs(const HloInstruction* instruction) { std::vector> in_place_input_output_pairs; + // Each of these leaves represents one array output of the fusion that might // be aliased with one of the fusion computation's array inputs (both could be // nested arbitrarily deep inside tuples). - for (const auto& fusion_output_array_shape : - ShapeUtil::GetLeafShapes(instruction->shape())) { - // Start from the root instruction of the fusion computation and follow - // tuple indirection backwards to find the "output source", i.e. the - // instruction that is the original source of the array output in question. - // If there is no such indirection the "output source" will just be the - // fusion root instruction itself. - const HloInstruction* output_source_instruction = - instruction->fused_expression_root(); - ShapeIndex output_source_index = fusion_output_array_shape.index; - std::tie(output_source_instruction, output_source_index) = - FollowTupleIndirection(output_source_instruction, output_source_index); - - // The aliasing rules of the "output source" instruction determine the - // aliasing rules for the entire fusion. If we can connect (following tuple - // indirection) the input of an "in-place" pair to one of the fusion's - // inputs, and the output of this "in-place" pair to the fusion output - // in question, then this fusion input and output must alias. - auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs( - output_source_instruction); - ShapeIndex in_place_input_index; - const HloInstruction* in_place_input_source = nullptr; - - for (const auto& output_source_in_place_pair : in_place_pairs) { - const HloOperandIndex& input = output_source_in_place_pair.first; - const ShapeIndex& output_index = output_source_in_place_pair.second; - if (output_index == output_source_index) { - // It is not possible for the same output to alias multiple inputs. - CHECK(in_place_input_source == nullptr); - in_place_input_source = - output_source_instruction->operand(input.operand_number); - in_place_input_index = input.operand_index; - } - } - - if (in_place_input_source) { - // Follow tuple indirection backwards from the instruction input to try to - // find a fusion parameter. If found, that parameter aliases the current - // output. If not, the current output aliases no input. - std::tie(in_place_input_source, in_place_input_index) = - FollowTupleIndirection(in_place_input_source, in_place_input_index); + ShapeUtil::ForEachLeafShape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& index) { + // Start from the root instruction of the fusion computation and follow + // tuple indirection backwards to find the "output source", i.e. the + // instruction that is the original source of the array output in + // question. If there is no such indirection the "output source" will + // just be the fusion root instruction itself. + const HloInstruction* output_source_instruction = + instruction->fused_expression_root(); + ShapeIndex output_source_index = index; + std::tie(output_source_instruction, output_source_index) = + FollowTupleIndirection(output_source_instruction, + output_source_index); + + // The aliasing rules of the "output source" instruction determine the + // aliasing rules for the entire fusion. If we can connect (following + // tuple indirection) the input of an "in-place" pair to one of the + // fusion's inputs, and the output of this "in-place" pair to the fusion + // output in question, then this fusion input and output must alias. + auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs( + output_source_instruction); + ShapeIndex in_place_input_index; + const HloInstruction* in_place_input_source = nullptr; + + for (const auto& output_source_in_place_pair : in_place_pairs) { + const HloOperandIndex& input = output_source_in_place_pair.first; + const ShapeIndex& output_index = output_source_in_place_pair.second; + if (output_index == output_source_index) { + // It is not possible for the same output to alias multiple inputs. + CHECK(in_place_input_source == nullptr); + in_place_input_source = + output_source_instruction->operand(input.operand_number); + in_place_input_index = input.operand_index; + } + } - if (in_place_input_source->opcode() == HloOpcode::kParameter) { - in_place_input_output_pairs.emplace_back( - HloOperandIndex{in_place_input_source->parameter_number(), - in_place_input_index}, - fusion_output_array_shape.index); - } - } - } + if (in_place_input_source) { + // Follow tuple indirection backwards from the instruction input to + // try to find a fusion parameter. If found, that parameter aliases + // the current output. If not, the current output aliases no input. + std::tie(in_place_input_source, in_place_input_index) = + FollowTupleIndirection(in_place_input_source, + in_place_input_index); + + if (in_place_input_source->opcode() == HloOpcode::kParameter) { + in_place_input_output_pairs.emplace_back( + HloOperandIndex{in_place_input_source->parameter_number(), + in_place_input_index}, + index); + } + } + }); return in_place_input_output_pairs; } diff --git a/xla/service/hlo_dataflow_analysis.h b/xla/service/hlo_dataflow_analysis.h index 99453fd2dd31c..a5cf84409aae4 100644 --- a/xla/service/hlo_dataflow_analysis.h +++ b/xla/service/hlo_dataflow_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -110,7 +110,7 @@ class HloDataflowAnalysis { // bitcast_defines_value : If true then the Bitcast HLO instruction defines // a new HLO value in the analysis. If false then Bitcast forwards the // value of its operand. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule& module, bool ssa_form = false, bool bitcast_defines_value = false, const CanShareBuffer& can_share_buffer = nullptr, diff --git a/xla/service/hlo_dataflow_analysis_test.cc b/xla/service/hlo_dataflow_analysis_test.cc index c5f67ab658dc9..2244f5e37fb92 100644 --- a/xla/service/hlo_dataflow_analysis_test.cc +++ b/xla/service/hlo_dataflow_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/literal_util.h" -#include "xla/service/async_op_canonicalizer.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_dce.h" @@ -66,8 +65,6 @@ class HloDataflowAnalysisTest : public HloTestBase, const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false, bool run_dce = true) { - AsyncOpCanonicalizer async_op_canonicalizer; - EXPECT_TRUE(async_op_canonicalizer.Run(module_.get()).ok()); if (run_dce) { HloDCE dce; EXPECT_TRUE(dce.Run(module_.get()).ok()); @@ -1085,8 +1082,8 @@ TEST_P(HloDataflowAnalysisTest, AsyncOps) { ENTRY entry { p0 = f32[2,3] parameter(0) async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" - async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo" - ROOT async-done = f32[2,3] custom-call-done(async-update), custom_call_target="foo" + async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start) + ROOT async-done = f32[2,3] custom-call-done(async-update) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -1151,10 +1148,10 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { %b = f32[4096]{0} parameter(1) %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(f32[4096]{0} %a, f32[4096]{0} %b), to_apply=%called_computation %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a) - %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation + %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start) %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b) %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3) - %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update), to_apply=%called_computation + %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update) ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done) } )"; @@ -1199,8 +1196,8 @@ TEST_P(HloDataflowAnalysisTest, TupleShapedAsyncOp) { ENTRY entry { p0 = f32[2,3] parameter(0) async-start = ((f32[2,3]), (f32[2,3], f32[2,3]), u32[]) custom-call-start(p0), custom_call_target="foo" - async-update = ((f32[2,3]), (f32[2,3], f32[2,3]), u32[]) custom-call-update(async-start), custom_call_target="foo" - ROOT async-done = (f32[2,3], f32[2,3]) custom-call-done(async-update), custom_call_target="foo" + async-update = ((f32[2,3]), (f32[2,3], f32[2,3]), u32[]) custom-call-update(async-start) + ROOT async-done = (f32[2,3], f32[2,3]) custom-call-done(async-update) } )"; TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/service/hlo_dce.cc b/xla/service/hlo_dce.cc index ed24416f5c7a9..acb205d5143f9 100644 --- a/xla/service/hlo_dce.cc +++ b/xla/service/hlo_dce.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ bool IsRemovableWhile(HloInstruction* instruction, } // namespace -/*static*/ StatusOr HloDCE::RunOnComputation( +/*static*/ absl::StatusOr HloDCE::RunOnComputation( HloComputation* computation, bool remove_cross_partition_collective_ops) { bool changed = false; VLOG(3) << "Before dce:"; @@ -75,6 +75,7 @@ bool IsRemovableWhile(HloInstruction* instruction, if (instruction->IsDead() && computation->IsSafelyRemovable(instruction) && (!instruction->IsCustomCall("Sharding") || (!instruction->operand(0)->IsRoot() && + instruction->operand(0)->opcode() != HloOpcode::kParameter && instruction->operand(0)->user_count() == 1)) && (!instruction->HasSideEffect() || (remove_cross_partition_collective_ops && maybe_collective_op && @@ -127,7 +128,8 @@ Status HloDCE::RecursivelyRemoveDeadComputation( return module->RemoveEmbeddedComputation(computation); } -StatusOr HloDCE::RecursivelyRemoveDeadComputations(HloModule* module) { +absl::StatusOr HloDCE::RecursivelyRemoveDeadComputations( + HloModule* module) { // Tracks whether any dead code is eliminated by this pass. bool module_contains_dead_code = false; @@ -150,7 +152,6 @@ StatusOr HloDCE::RecursivelyRemoveDeadComputations(HloModule* module) { } // Find dead computations. - absl::flat_hash_set dead_computations; for (auto* computation : module->MakeComputationPostOrder()) { // Finds all "top-level" dead computations not called by any instructions. // contains(comp) = true and live_computation_call_count[comp] = 0 also @@ -166,7 +167,7 @@ StatusOr HloDCE::RecursivelyRemoveDeadComputations(HloModule* module) { return module_contains_dead_code; } -StatusOr HloDCE::Run( +absl::StatusOr HloDCE::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/hlo_dce.h b/xla/service/hlo_dce.h index 5d9d6ea54ab33..8e8e3fab4d58e 100644 --- a/xla/service/hlo_dce.h +++ b/xla/service/hlo_dce.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,20 +47,20 @@ class HloDCE : public HloModulePass { absl::string_view name() const override { return "dce"; } // Run DCE on a computation. - static StatusOr RunOnComputation( + static absl::StatusOr RunOnComputation( HloComputation* computation, bool remove_cross_partition_collective_ops); // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: // Finds all computations that are not called by any instruction and removes // them from the module. Returns whether any dead code was removed. - StatusOr RecursivelyRemoveDeadComputations(HloModule* module); + absl::StatusOr RecursivelyRemoveDeadComputations(HloModule* module); // Given a dead computation, decrements the ref count of all its called // computations and checks if any of the subcomputations become dead after the diff --git a/xla/service/hlo_dce_test.cc b/xla/service/hlo_dce_test.cc index b59c2d5fa8f70..e80cb84c67a5e 100644 --- a/xla/service/hlo_dce_test.cc +++ b/xla/service/hlo_dce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -111,6 +111,30 @@ TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) { EXPECT_FALSE(result); } +TEST_F(HloDceTest, AsyncCustomCallInstructionsWithSideEffect) { + // Verify that custom call instruction with side-effect is not removed. + auto builder = HloComputation::Builder(TestName()); + auto instr = Cast(builder.AddInstruction( + HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"))); + instr->set_custom_call_has_side_effect(true); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN([[maybe_unused]] HloInstruction * async_done, + module->entry_computation()->CreateAsyncInstructions( + instr, {{ShapeUtil::MakeScalarShape(U32)}}, + HloInstruction::kMainExecutionThread, + /*replace=*/true, /*override_names=*/true)); + + HloDCE dce; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get())); + EXPECT_FALSE(result); +} + TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) { // Verify that custom call instruction without side-effect is removed. auto builder = HloComputation::Builder(TestName()); @@ -128,6 +152,30 @@ TEST_F(HloDceTest, CustomCallInstructionsWithoutSideEffect) { EXPECT_TRUE(result); } +TEST_F(HloDceTest, AsyncCustomCallInstructionsWithoutSideEffect) { + // Verify that custom call instruction without side-effect is removed. + auto builder = HloComputation::Builder(TestName()); + auto instr = Cast(builder.AddInstruction( + HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"))); + instr->set_custom_call_has_side_effect(false); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN([[maybe_unused]] HloInstruction * async_done, + module->entry_computation()->CreateAsyncInstructions( + instr, {{ShapeUtil::MakeScalarShape(U32)}}, + HloInstruction::kMainExecutionThread, + /*replace=*/true, /*override_names=*/true)); + + HloDCE dce; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&dce, module.get())); + EXPECT_TRUE(result); +} + TEST_F(HloDceTest, ShardingCustomCallInstruction) { // Verify that sharding custom call instruction is not removed. auto builder = HloComputation::Builder(TestName()); diff --git a/xla/service/hlo_dfs_reachability_test.cc b/xla/service/hlo_dfs_reachability_test.cc new file mode 100644 index 0000000000000..9bc77c75f42ea --- /dev/null +++ b/xla/service/hlo_dfs_reachability_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_dfs_reachability.h" + +#include +#include +#include + +#include "absl/random/random.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla { + +namespace { + +class HloDfsReachabilityTest : public HloTestBase {}; + +TEST_F(HloDfsReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloDfsReachability::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); +} + +TEST_F(HloDfsReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + module->mutable_config().set_use_spmd_partitioning(false); + module->mutable_config().set_static_device_assignment(DeviceAssignment(1, 2)); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloDfsReachability::Build(computation); + EXPECT_FALSE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +class HloDfsReachabilityBenchmark { + public: + HloDfsReachabilityBenchmark(int size, std::string_view name) : name_(name) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(name); + + // Build a graph of chained Exponentials, i.e. Exp(...(Exp(Input))...). + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction* prev = constant; + for (int i = 1; i < size; ++i) { + prev = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, prev)); + } + + HloModuleConfig hlo_config; + module_ = std::make_unique(name_, hlo_config); + computation_ = + module_->AddEntryComputation(builder.Build(/*root_instruction=*/prev)); + } + + std::unique_ptr Build() { + return HloDfsReachability::Build(computation_); + } + + const HloComputation* computation() { return computation_; } + + private: + std::unique_ptr module_; + HloComputation* computation_; + const std::string name_; +}; + +void BM_HloDfsReachabilityBuild(benchmark::State& state) { + int num_nodes = state.range(0); + HloDfsReachabilityBenchmark bm(num_nodes, state.name()); + while (state.KeepRunningBatch(num_nodes)) { + benchmark::DoNotOptimize(bm.Build()); + } +} + +void BM_HloDfsReachabilityCheck(benchmark::State& state) { + size_t size = state.range(0); + + HloDfsReachabilityBenchmark bm(size, state.name()); + auto reachability = bm.Build(); + auto instrs = bm.computation()->MakeInstructionPostOrder(); + + size_t i = 0; + for (auto s : state) { + size_t from = i % size; + size_t to = (++i + size / 2) % size; + reachability->IsReachable(instrs[from], instrs[to]); + } +} + +#define BM_ARGS Arg(1)->Arg(64)->Arg(128)->Arg(256)->Range(512, 256 * 1024) +BENCHMARK(BM_HloDfsReachabilityBuild)->BM_ARGS; +BENCHMARK(BM_HloDfsReachabilityCheck)->BM_ARGS; + +} // namespace + +} // namespace xla diff --git a/xla/service/hlo_domain_isolator.cc b/xla/service/hlo_domain_isolator.cc index 1b3305bcf3ef4..c5355ef98d677 100644 --- a/xla/service/hlo_domain_isolator.cc +++ b/xla/service/hlo_domain_isolator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,8 @@ namespace xla { namespace { // Add domains which are used as users of a specific instruction. -StatusOr AddExitDomains(HloInstruction* instruction, - HloDomainIsolator::DomainCreator* creator) { +absl::StatusOr AddExitDomains( + HloInstruction* instruction, HloDomainIsolator::DomainCreator* creator) { int64_t added_domains = 0; if (instruction->opcode() == HloOpcode::kDomain) { return added_domains; @@ -55,7 +55,7 @@ StatusOr AddExitDomains(HloInstruction* instruction, return added_domains; } -StatusOr RunInternal( +absl::StatusOr RunInternal( HloModule* module, const absl::flat_hash_set& execution_threads, HloDomainIsolator::DomainCreator* creator) { @@ -98,7 +98,8 @@ StatusOr RunInternal( HloDomainIsolator::HloDomainIsolator(DomainCreatorFactory creator_factory) : creator_factory_(std::move(creator_factory)) {} -StatusOr HloDomainIsolator::UpdateDomains(HloInstruction* instruction) { +absl::StatusOr HloDomainIsolator::UpdateDomains( + HloInstruction* instruction) { DomainCreator creator = creator_factory_(); bool changed = false; // Update exit domains. @@ -122,7 +123,7 @@ StatusOr HloDomainIsolator::UpdateDomains(HloInstruction* instruction) { return changed; } -StatusOr HloDomainIsolator::Run( +absl::StatusOr HloDomainIsolator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { DomainCreator creator = creator_factory_(); diff --git a/xla/service/hlo_domain_isolator.h b/xla/service/hlo_domain_isolator.h index 59c3646342759..db9385847f034 100644 --- a/xla/service/hlo_domain_isolator.h +++ b/xla/service/hlo_domain_isolator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,10 +43,10 @@ class HloDomainIsolator : public HloModulePass { absl::string_view name() const override { return "domain_isolator"; } // Update domains for an instruction. - StatusOr UpdateDomains(HloInstruction* instruction); + absl::StatusOr UpdateDomains(HloInstruction* instruction); using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/hlo_domain_map.cc b/xla/service/hlo_domain_map.cc index d5397f05696f5..6aa3933244de8 100644 --- a/xla/service/hlo_domain_map.cc +++ b/xla/service/hlo_domain_map.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,14 +28,14 @@ limitations under the License. namespace xla { -/* static */ StatusOr> HloDomainMap::Create( +/* static */ absl::StatusOr> HloDomainMap::Create( HloComputation* computation, std::string domain_kind) { auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } -/* static */ StatusOr> HloDomainMap::Create( +/* static */ absl::StatusOr> HloDomainMap::Create( HloModule* module, std::string domain_kind) { auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { @@ -193,7 +193,8 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, return OkStatus(); } -StatusOr> HloDomainMap::CreateDomain( +absl::StatusOr> +HloDomainMap::CreateDomain( HloInstruction* instruction, const InstructionOrderMap& instructions_order) const { auto domain = std::make_unique(); diff --git a/xla/service/hlo_domain_map.h b/xla/service/hlo_domain_map.h index 2ab375fd3f260..9781457127433 100644 --- a/xla/service/hlo_domain_map.h +++ b/xla/service/hlo_domain_map.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,14 +41,14 @@ class HloDomainMap { // computation, of the given kind. If domain_kind is not empty, only the // kDomain instructions of domain_kind will be considered as separators. // Otherwise every kDomain instruction will be splitting domains. - static StatusOr> Create( + static absl::StatusOr> Create( HloComputation* computation, std::string domain_kind); // Creates a new HloDomainMap, creating all the domains within the input // module, of the given kind. If domain_kind is not empty, only the // kDomain instructions of domain_kind will be considered as separators. // Otherwise every kDomain instruction will be splitting domains. - static StatusOr> Create( + static absl::StatusOr> Create( HloModule* module, std::string domain_kind); // Retrieves all the domains the input module or computation are composed by. @@ -105,7 +105,7 @@ class HloDomainMap { DomainMetadata::Domain* domain) const; // Creates a domain data structure using the ExpandDomain() API. - StatusOr> CreateDomain( + absl::StatusOr> CreateDomain( HloInstruction* instruction, const InstructionOrderMap& instructions_order) const; diff --git a/xla/service/hlo_domain_remover.cc b/xla/service/hlo_domain_remover.cc index 5d37b0ff3519d..3d35a25595e1b 100644 --- a/xla/service/hlo_domain_remover.cc +++ b/xla/service/hlo_domain_remover.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class HloDomainRemover::RunContext { RunContext(HloModule* module, HloDomainRemover* remover) : module_(module), remover_(remover) {} - StatusOr Run( + absl::StatusOr Run( const absl::flat_hash_set& execution_threads); private: @@ -58,7 +58,7 @@ Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( return OkStatus(); } -StatusOr HloDomainRemover::RunContext::Run( +absl::StatusOr HloDomainRemover::RunContext::Run( const absl::flat_hash_set& execution_threads) { VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; int64_t removed_domains = 0; @@ -100,7 +100,7 @@ StatusOr HloDomainRemover::RunContext::Run( return removed_domains > 0; } -StatusOr HloDomainRemover::RemoveExitDomains( +absl::StatusOr HloDomainRemover::RemoveExitDomains( HloInstruction* instruction, absl::string_view domain_kind) { int64_t removed_domains = 0; HloComputation* computation = instruction->parent(); @@ -120,7 +120,7 @@ StatusOr HloDomainRemover::RemoveExitDomains( return removed_domains; } -StatusOr HloDomainRemover::Run( +absl::StatusOr HloDomainRemover::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { RunContext run_context(module, this); diff --git a/xla/service/hlo_domain_remover.h b/xla/service/hlo_domain_remover.h index 25039fa0e6664..9cda66d36e845 100644 --- a/xla/service/hlo_domain_remover.h +++ b/xla/service/hlo_domain_remover.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,11 +46,11 @@ class HloDomainRemover : public HloModulePass { // Remove domains of a given kind which are used as users of a specific // instruction. - static StatusOr RemoveExitDomains(HloInstruction* instruction, - absl::string_view domain_kind); + static absl::StatusOr RemoveExitDomains( + HloInstruction* instruction, absl::string_view domain_kind); using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/hlo_domain_test.cc b/xla/service/hlo_domain_test.cc index b938e8a32b121..8e82878602f1a 100644 --- a/xla/service/hlo_domain_test.cc +++ b/xla/service/hlo_domain_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_domain_verifier.cc b/xla/service/hlo_domain_verifier.cc index afa07acce1db2..0d83a318dbf9e 100644 --- a/xla/service/hlo_domain_verifier.cc +++ b/xla/service/hlo_domain_verifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ Status HloDomainVerifier::RunContext::Run( return OkStatus(); } -StatusOr HloDomainVerifier::Run( +absl::StatusOr HloDomainVerifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { RunContext run_context(module, this); @@ -92,7 +92,7 @@ StatusOr HloDomainVerifier::Run( return false; } -StatusOr HloDomainVerifier::VerifyDomain( +absl::StatusOr HloDomainVerifier::VerifyDomain( const DomainMetadata::Domain& domain) { const DomainMetadata* ref_metadata = nullptr; VLOG(4) << "Reach set:"; diff --git a/xla/service/hlo_domain_verifier.h b/xla/service/hlo_domain_verifier.h index 4ceb131f38626..15ab7e3894bef 100644 --- a/xla/service/hlo_domain_verifier.h +++ b/xla/service/hlo_domain_verifier.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class HloDomainVerifier : public HloModulePass { absl::string_view name() const override { return "domain_verifier"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -55,7 +55,7 @@ class HloDomainVerifier : public HloModulePass { // represents the common metadata within such domain. If the returned // DomainMetadata pointer is nullptr, the input domain had no kDomain // boundary. - static StatusOr VerifyDomain( + static absl::StatusOr VerifyDomain( const DomainMetadata::Domain& domain); private: diff --git a/xla/service/hlo_element_type_converter.cc b/xla/service/hlo_element_type_converter.cc index f33b73c7f84c9..05721ae9b2a2c 100644 --- a/xla/service/hlo_element_type_converter.cc +++ b/xla/service/hlo_element_type_converter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -113,7 +113,7 @@ HloElementTypeConverter::HloElementTypeConverter( // This routine converts the arithmetic operations in the given module that use // eliminate_type_ to operations that use replace_with_type_. -StatusOr HloElementTypeConverter::Run( +absl::StatusOr HloElementTypeConverter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/hlo_element_type_converter.h b/xla/service/hlo_element_type_converter.h index 9cfd95fe75a95..6f1eced4bf3d9 100644 --- a/xla/service/hlo_element_type_converter.h +++ b/xla/service/hlo_element_type_converter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class HloElementTypeConverter : public HloModulePass { // Returns the pass on the module and returns whether the module was modified. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/hlo_element_type_converter_test.cc b/xla/service/hlo_element_type_converter_test.cc index 257ee81683f8f..d7bfc8d0a0961 100644 --- a/xla/service/hlo_element_type_converter_test.cc +++ b/xla/service/hlo_element_type_converter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_execution_profile.cc b/xla/service/hlo_execution_profile.cc index b782998641dc6..78d823726bf33 100644 --- a/xla/service/hlo_execution_profile.cc +++ b/xla/service/hlo_execution_profile.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_execution_profile.h b/xla/service/hlo_execution_profile.h index 306c81b43e0fa..9a5f6348965ad 100644 --- a/xla/service/hlo_execution_profile.h +++ b/xla/service/hlo_execution_profile.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_execution_profile_data.proto b/xla/service/hlo_execution_profile_data.proto index 1577880764030..b0897bd066ce7 100644 --- a/xla/service/hlo_execution_profile_data.proto +++ b/xla/service/hlo_execution_profile_data.proto @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_execution_profile_test.cc b/xla/service/hlo_execution_profile_test.cc index d7dfe7e61ed74..5a27be1d80028 100644 --- a/xla/service/hlo_execution_profile_test.cc +++ b/xla/service/hlo_execution_profile_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_graph_dumper.cc b/xla/service/hlo_graph_dumper.cc index bcb761531d2a5..8a2dfe9b4ef63 100644 --- a/xla/service/hlo_graph_dumper.cc +++ b/xla/service/hlo_graph_dumper.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,28 @@ limitations under the License. #include "xla/service/hlo_graph_dumper.h" +#include +#include + +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/thread_annotations.h" + #ifndef _WIN32 #include #endif @@ -22,14 +44,17 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" @@ -98,8 +123,9 @@ class NodeFilter { NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {} explicit NodeFilter( - std::function filter) - : filter_(std::move(filter)) {} + std::function filter, + std::optional num_rendered = std::nullopt) + : filter_(std::move(filter)), num_rendered_(num_rendered) {} bool Show(const HloInstruction* instr) const { return filter_(instr) != kHideNode; @@ -120,8 +146,12 @@ class NodeFilter { result == kSomeUsersOmitted; } + // Returns an optionally recorded number of nodes which will be rendered. + std::optional GetNumRendered() const { return num_rendered_; } + private: std::function filter_; + std::optional num_rendered_; }; // We arbitrarily set this as the boundary between "large" and "small" @@ -157,10 +187,10 @@ enum ColorScheme { // Graphviz attributes/colors that make up a color scheme. struct NodeColors { - const char* style; - const char* fill_color; - const char* stroke_color; - const char* font_color; + std::string style; + std::string fill_color; + std::string stroke_color; + std::string font_color; }; NodeColors NodeColorsForScheme(ColorScheme color) { @@ -204,7 +234,7 @@ NodeColors NodeColorsForScheme(ColorScheme color) { // Given a Statistic object, returns a hex string for the fill color of the node // with that statistic. -const char* NodeFillColorForStatistic(const Statistic& statistic) { +std::string NodeFillColorForStatistic(const Statistic& statistic) { auto stat_val = statistic.stat_val(); if (stat_val == 0) { return "#f5f5f5"; @@ -233,7 +263,7 @@ const char* NodeFillColorForStatistic(const Statistic& statistic) { // Given a Statistic object, returns a hex string for the font color of the node // with that statistic. -const char* NodeFontColorForStatistic(const Statistic& statistic) { +std::string NodeFontColorForStatistic(const Statistic& statistic) { if (statistic.stat_val() < 60) { return "black"; } else { @@ -368,14 +398,18 @@ optional MatchTrivialComputation( // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, absl::string_view label, - const DebugOptions& debug_options, - HloRenderOptions hlo_render_options, NodeFilter filter) + HloDotDumper( + const HloComputation* computation, absl::string_view label, + const DebugOptions& debug_options, HloRenderOptions hlo_render_options, + NodeFilter filter, + std::optional> + color_map = std::nullopt) : computation_(computation), label_(label), debug_options_(debug_options), hlo_render_options_(hlo_render_options), - filter_(std::move(filter)) {} + filter_(std::move(filter)), + color_map_(color_map) {} std::string Dump(); @@ -459,7 +493,8 @@ class HloDotDumper { const DebugOptions& debug_options_; const HloRenderOptions hlo_render_options_; const NodeFilter filter_; - + const std::optional> + color_map_; // Each HloInstruction dumped gets a monotonically-increasing node ID. This // must start at 1, because that's where graphviz's accounting starts. int64_t next_node_id_ = 1; @@ -551,15 +586,15 @@ stylesheet=< // because the "X ~ Y" CSS selector finds a sibling of X that *comes // after X in the DOM* and matches Y. std::vector edge_css_rules; - const char* kBlue = "#1976d2"; - const char* kRed = "#d32f2f"; + std::string kBlue = "#1976d2"; + std::string kRed = "#d32f2f"; for (const auto& kv : edge_ids_) { const HloInstruction* from_node = kv.first.first; const HloInstruction* to_node = kv.first.second; int64_t edge_id = kv.second; auto add_hover_css_rule = [&](std::string elem_type, int64_t elem_id, - const char* color) { + std::string color) { // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. edge_css_rules.push_back( @@ -644,6 +679,11 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { return false; } + if (subcomp->WhileCallInstruction() != nullptr && + !hlo_render_options_.show_while_subcomputations) { + return false; + } + // Show the subcomputation if we're showing any of its members. return absl::c_any_of( subcomp->instructions(), @@ -695,8 +735,8 @@ std::string HloDotDumper::DumpSubcomputation( } bool highlight = filter_.Highlight(parent_instr); - const char* fillcolor; - const char* strokecolor; + std::string fillcolor; + std::string strokecolor; if (!highlight && (parent_instr->module_has_statistics() || parent_instr->has_statistics())) { @@ -864,8 +904,6 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_; node_ids_[instr] = next_node_id_++; - - ColorScheme color = GetInstructionColor(instr); std::string node_shape = GetInstructionNodeShape(instr); std::string node_label = GetInstructionNodeLabel(instr); std::string node_metadata = GetInstructionNodeMetadata(instr); @@ -875,42 +913,66 @@ std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) { std::string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); - - if (!debug_options_.xla_hlo_graph_sharding_color()) { - // Override the node's styling if it should be (de-)emphasized. - if (filter_.Deemphasized(instr)) { - color = kDashedBorder; + NodeColors node_colors; + std::string node_style; + std::string node_attributes; + if (hlo_render_options_.override_node_colors && color_map_.has_value()) { + if (color_map_->contains(instr)) { + // look up color stats in the color_map_ + node_colors.fill_color = color_map_->at(instr).color; + node_attributes = color_map_->at(instr).stats; + } else { + VLOG(2) << "color_map_ for instruction:" << instr->name() << "is empty" + << "\n"; + node_colors.fill_color = "#808080"; } - if (filter_.Highlight(instr)) { - node_shape = "diamond"; - color = kDarkRed; + node_colors.style = "filled"; + node_colors.font_color = "black"; + node_colors.stroke_color = "#c2c2c2"; + node_style = + StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); + } else { + ColorScheme color = GetInstructionColor(instr); + if (!debug_options_.xla_hlo_graph_sharding_color()) { + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } } - } - NodeColors node_colors = NodeColorsForScheme(color); - if (instr->has_statistics()) { - // override node's color to show statistics - const auto& statistic_to_visualize = instr->statistic_to_visualize(); - node_colors.fill_color = NodeFillColorForStatistic(statistic_to_visualize); - node_colors.stroke_color = "#c2c2c2"; - node_colors.font_color = NodeFontColorForStatistic(statistic_to_visualize); - } else if (instr->module_has_statistics()) { - // all other nodes without statistics must be gray - node_colors.fill_color = "#f5f5f5"; - node_colors.stroke_color = "#c2c2c2"; - node_colors.font_color = "black"; + node_colors = NodeColorsForScheme(color); + if (instr->has_statistics()) { + // override node's color to show statistics + const auto& statistic_to_visualize = instr->statistic_to_visualize(); + node_colors.fill_color = + NodeFillColorForStatistic(statistic_to_visualize); + node_colors.stroke_color = "#c2c2c2"; + node_colors.font_color = + NodeFontColorForStatistic(statistic_to_visualize); + } else if (instr->module_has_statistics()) { + // all other nodes without statistics must be gray + node_colors.fill_color = "#f5f5f5"; + node_colors.stroke_color = "#c2c2c2"; + node_colors.font_color = "black"; + } + + // Build the node style + node_style = + StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } - - // Build the node style - std::string node_style = - StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, - node_colors.stroke_color, node_colors.fill_color); - // Build the text that will be displayed inside the node. std::string node_body = node_label; - for (const std::string& s : {trivial_subcomputation, extra_info, - inlined_constants, node_backend_config}) { + for (const std::string& s : + {trivial_subcomputation, extra_info, inlined_constants, + node_backend_config, node_attributes}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -1081,6 +1143,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1196,6 +1259,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: @@ -1346,18 +1410,20 @@ std::string HloDotDumper::GetInstructionNodeBackendConfig( // !show_backend_config, but this is simpler, and it's not too noisy.) std::vector> props; if (gpu::IsCustomCallToDnnConvolution(*instr)) { - StatusOr config = - instr->backend_config(); + absl::StatusOr config = + instr->backend_config(); if (config.ok()) { - props = ExtractCudnnConvBackendConfigProps(*config); + props = ExtractCudnnConvBackendConfigProps( + config->cudnn_conv_backend_config()); } } else if (gpu::IsCublasGemm(*instr)) { - StatusOr config = - instr->backend_config(); + absl::StatusOr config = + instr->backend_config(); if (config.ok()) { // gemm strides are generally uninteresting (derived from the instruction // shape), so we hide them by default. - props = ExtractGemmBackendConfigProps(*config, instr); + props = + ExtractGemmBackendConfigProps(config->gemm_backend_config(), instr); } } @@ -1392,11 +1458,12 @@ std::string HloDotDumper::GetInstructionNodeExtraInfo( for (const auto& line : instr->ExtraAttributesToString( HloPrintOptions().set_print_subcomputation_mode( HloPrintOptions::PrintSubcomputationMode::kOff))) { - // Some instructions have giant device identifier fields, so truncate their - // length to 128. + // Some instructions have giant device identifier or control-predecessor + // fields, so truncate their length to 128. constexpr int kMaxDeviceIdFieldLen = 128; if ((absl::StartsWith(line, "replica_groups=") || - absl::StartsWith(line, "source_target_pairs=")) && + absl::StartsWith(line, "source_target_pairs=") || + absl::StartsWith(line, "control-predecessors=")) && line.length() > kMaxDeviceIdFieldLen) { lines.push_back(HtmlLikeStringSanitize( StrCat(line.substr(0, kMaxDeviceIdFieldLen - 3), "..."))); @@ -1575,8 +1642,20 @@ NodeFilter MakeNodeRadiusAroundFilter( // are not interesting to the graph at hand. if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { + // Special logic for handling bitcasts: since sometimes bitcasts are not + // fused, they create a lot of extra nodes in the graph, with exactly + // one input and output. Adding such nodes does not "really" increase + // the size of the graph (since they don't add extra information), and + // stopping the rendering early cuts off important information (you + // almost never want the rendering to be cutoff at the bitcast: you'd + // like to see its parent). if (!nodes.contains(operand)) { - worklist.push_back({operand, depth + 1}); + int new_depth = (operand->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kBitcast) + ? depth + : depth + 1; + + worklist.push_back({operand, new_depth}); } } } @@ -1643,17 +1722,19 @@ NodeFilter MakeNodeRadiusAroundFilter( // Highlight the root node. nodes[root] = kHighlightNode; - return NodeFilter([=](const HloInstruction* instr) { - auto it = nodes.find(instr); - if (it != nodes.end()) { - return it->second; - } - // Show all nodes in subcomputations. - if (instr->parent() != root->parent()) { - return kNormalNode; - } - return kHideNode; - }); + return NodeFilter( + [=](const HloInstruction* instr) { + auto it = nodes.find(instr); + if (it != nodes.end()) { + return it->second; + } + // Show all nodes in subcomputations. + if (instr->parent() != root->parent()) { + return kNormalNode; + } + return kHideNode; + }, + nodes.size()); } // Gets a node filter that includes nodes on all paths from `from` to `to`. If @@ -1707,7 +1788,7 @@ NodeFilter MakeNodeFromToFilter(const HloInstruction* from, } absl::Mutex url_renderer_mu(absl::kConstInit); -std::function(absl::string_view)>* url_renderer +std::function(absl::string_view)>* url_renderer ABSL_GUARDED_BY(url_renderer_mu) = nullptr; // Storage for fusion visualization: (module_id, computation_id) -> sequence of @@ -1752,11 +1833,10 @@ static std::pair FusionVisualizerStateKey( computation.unique_id()); } - } // namespace // Compress with zlib + b64 encode. -static StatusOr CompressAndEncode(absl::string_view input) { +static absl::StatusOr CompressAndEncode(absl::string_view input) { class WritableStringFile : public tsl::WritableFile { public: explicit WritableStringFile(std::string* data) : data_(data){}; @@ -1797,11 +1877,11 @@ static std::string EscapeJSONString(absl::string_view raw) { "\""); } -StatusOr WrapFusionExplorer( +absl::StatusOr WrapFusionExplorer( const FusionVisualizerProgress& visualizer_progress, absl::string_view graph_title) { if (visualizer_progress.frames.empty()) { - return InternalError("Empty"); + return Internal("Empty"); } std::string dot_graphs = @@ -1901,7 +1981,7 @@ StatusOr WrapFusionExplorer( var area = document.getElementById('rendered'); area.innerHTML = `${svg}`; var panzoom = svgPanZoom(area.children[0], { - zoomEnabled: true, controlIconsEnabled: true, }); + zoomEnabled: true, controlIconsEnabled: true, maxZoom: 200, }); var to_highlight = frame[2].length ? document.querySelector(`${frame[2]}`) : null; if (to_highlight) { @@ -1909,6 +1989,14 @@ StatusOr WrapFusionExplorer( } document.getElementById('performance_note').innerText = `Rendering took ${(performance.now() - render_start).toFixed(2)}ms`; + + // Change cursor. + let text_nodes = document.getElementsByTagName("text"); + for (var el of text_nodes) { + if (title_to_id.has(el.innerHTML)) { + el.style.cursor = "pointer"; + } + } }; if (renderCache[dot_ptr]) { render_callback(renderCache[dot_ptr]); @@ -1979,6 +2067,20 @@ StatusOr WrapFusionExplorer( renderFrameList(); renderCurrentFrame(); }); + + window.title_to_id = new Map(); + for (let i=0; i < frames.length; i++) { + title_to_id.set(frames[i][1], i); + } + + // Navigate to next elements on click. + document.addEventListener("click", (event) => { + let txt = event.target.innerHTML; + if (title_to_id.has(txt)) { + let id = title_to_id.get(txt); + window.location.hash = `#frame${id}`; + } + }); }); //--> @@ -1995,15 +2097,16 @@ static std::string GraphTitle(const HloComputation& computation) { return absl::StrCat(computation.parent()->name(), "_", computation.name()); } -StatusOr WrapFusionExplorer(const HloComputation& computation) { +absl::StatusOr WrapFusionExplorer( + const HloComputation& computation) { absl::MutexLock lock(&fusion_visualizer_state_mu); const FusionVisualizerProgress& visualizer_progress = fusion_visualizer_states[FusionVisualizerStateKey(computation)]; return WrapFusionExplorer(visualizer_progress, GraphTitle(computation)); } -static StatusOr WrapDotInHtml(absl::string_view dot, - absl::string_view title) { +static absl::StatusOr WrapDotInHtml(absl::string_view dot, + absl::string_view title) { FusionVisualizerProgress progress; progress.AddState(dot, title, std::nullopt); return WrapFusionExplorer(progress, title); @@ -2015,10 +2118,9 @@ static StatusOr WrapDotInHtml(absl::string_view dot, // returning an error because we want to fail quickly when there's no URL // renderer available, and this function runs only after we've done all the work // of producing dot for the graph.) -static StatusOr WrapDotInFormat(const HloComputation& computation, - absl::string_view dot, - RenderedGraphFormat format) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { +static absl::StatusOr WrapDotInFormat( + const HloComputation& computation, absl::string_view dot, + RenderedGraphFormat format) ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { switch (format) { case RenderedGraphFormat::kUrl: CHECK(url_renderer != nullptr) @@ -2032,7 +2134,7 @@ static StatusOr WrapDotInFormat(const HloComputation& computation, } void RegisterGraphToURLRenderer( - std::function(absl::string_view)> renderer) { + std::function(absl::string_view)> renderer) { absl::MutexLock lock(&url_renderer_mu); if (url_renderer != nullptr) { LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call " @@ -2040,8 +2142,9 @@ void RegisterGraphToURLRenderer( "nondeterministic, this may not be what you want."; } delete url_renderer; - url_renderer = new std::function(absl::string_view)>( - std::move(renderer)); + url_renderer = + new std::function(absl::string_view)>( + std::move(renderer)); } void RegisterFusionState(const HloComputation& computation, @@ -2075,26 +2178,69 @@ void RegisterFusionState(const HloComputation& computation, fusion_progress.AddState(dot_txt, label, producer_to_highlight); } -StatusOr RenderGraph(const HloComputation& computation, - absl::string_view label, - const DebugOptions& debug_options, - RenderedGraphFormat format, - HloRenderOptions hlo_render_options) { +absl::StatusOr RenderGraph( + const HloComputation& computation, absl::string_view label, + const DebugOptions& debug_options, RenderedGraphFormat format, + HloRenderOptions hlo_render_options, + std::optional> + color_map) { absl::MutexLock lock(&url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return Unavailable("Can't render as URL; no URL renderer was registered."); } - std::string rendered_dot = HloDotDumper(&computation, label, debug_options, - hlo_render_options, NodeFilter()) - .Dump(); + std::string rendered_dot = + HloDotDumper(&computation, label, debug_options, hlo_render_options, + NodeFilter(), color_map) + .Dump(); return WrapDotInFormat(computation, rendered_dot, format); } -StatusOr RenderNeighborhoodAround( +absl::StatusOr RenderAllComputationsToHtml( + const HloModule& module) { + FusionVisualizerProgress progress; + + std::vector instrs = + module.entry_computation()->MakeInstructionPostOrder(); + absl::c_reverse(instrs); + for (const HloInstruction* instr : instrs) { + if (absl::c_linear_search( + std::vector{HloOpcode::kConstant, + HloOpcode::kGetTupleElement}, + instr->opcode())) { + continue; + } + + HloRenderOptions opts; + opts.show_fusion_subcomputations = true; + opts.show_backend_config = true; + opts.show_while_subcomputations = instr->opcode() == HloOpcode::kWhile; + + // Dynamically adjusts the radius with a magical cutoff of 100. + static constexpr int64_t max_nodes_to_render = 100; + absl::flat_hash_set render_boundary; + + NodeFilter filter = MakeNodeRadiusAroundFilter(instr, 2, render_boundary); + if (filter.GetNumRendered().value_or(1) > max_nodes_to_render) { + filter = MakeNodeRadiusAroundFilter(instr, 1, render_boundary); + } + + std::string dot = + HloDotDumper(module.entry_computation(), instr->name(), + module.config().debug_options(), opts, filter) + .Dump(); + progress.AddState(dot, instr->name(), std::nullopt); + } + + return WrapFusionExplorer(progress, module.name()); +} + +absl::StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, HloRenderOptions hlo_render_options, - const absl::flat_hash_set& boundary) { + const absl::flat_hash_set& boundary, + std::optional> + color_map) { absl::MutexLock lock(&url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return FailedPrecondition( @@ -2104,15 +2250,15 @@ StatusOr RenderNeighborhoodAround( std::string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); std::string rendered_dot = - HloDotDumper(node.parent(), label, - node.GetModule()->config().debug_options(), - hlo_render_options, - MakeNodeRadiusAroundFilter(&node, radius, boundary)) + HloDotDumper( + node.parent(), label, node.GetModule()->config().debug_options(), + hlo_render_options, + MakeNodeRadiusAroundFilter(&node, radius, boundary), color_map) .Dump(); return WrapDotInFormat(*node.parent(), rendered_dot, format); } -StatusOr RenderAllPathsFromTo( +absl::StatusOr RenderAllPathsFromTo( const HloInstruction& from, const HloInstruction& to, int64_t max_nodes, RenderedGraphFormat format, HloRenderOptions hlo_render_options) { absl::MutexLock lock(&url_renderer_mu); diff --git a/xla/service/hlo_graph_dumper.h b/xla/service/hlo_graph_dumper.h index 12b589abb93ce..1b5c7353d0f5c 100644 --- a/xla/service/hlo_graph_dumper.h +++ b/xla/service/hlo_graph_dumper.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_GRAPH_DUMPER_H_ #define XLA_SERVICE_HLO_GRAPH_DUMPER_H_ +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/types.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/statusor.h" #include "xla/xla.pb.h" // This file contains routines for rendering HLO computations into a @@ -30,35 +35,19 @@ limitations under the License. // // - as a raw DOT file, which can be rendered using `graphviz`. // -// - as an HTML file with an embedded DOT file, which can be viewed in a -// browser using a version of graphviz compiled to JavaScript +// - as an HTML file with an embedded DOT file, rendered in JavaScript. // -// - as a URL hosted somewhere which somehow embeds the DOT file. +// - as an HTML page showing the fusion progress, rendered in JavaScript. // -// - as an HTML page showing the fusion progress. +// - as a URL hosted somewhere which somehow embeds the DOT file. // -// Two last options are not implemented by default, but you can add a plugin to +// The last option is not implemented by default, but you can add a plugin to // implement it via RegisterGraphToURLRenderer. // // TODO(jlebar): Rename this file to hlo_graph_renderer. namespace xla { -inline constexpr char kRenderDotJS[] = R"( - - - - - -)"; - // Different formats that a graph can be packaged as. enum class RenderedGraphFormat { kDot, @@ -72,6 +61,17 @@ struct HloRenderOptions { // Include the fusion subcomputations in the rendered graph. bool show_fusion_subcomputations = true; + + // Include the while subcomputations in the rendered graph. + bool show_while_subcomputations = true; + + bool override_node_colors = false; +}; + +// Contains color computed according to the numerical diff of an HloInstruction +struct ColorStats { + std::string color; + std::string stats; }; // Renders an HLO module as a human-readable visual graph. @@ -81,11 +81,15 @@ struct HloRenderOptions { // unreadable, or both. To view such graphs, use a tool such as // interactive_graphviz, which calls RenderNeighborhoodAround to render subsets // of a graph. -StatusOr RenderGraph(const HloComputation& computation, - absl::string_view label, - const DebugOptions& debug_options, - RenderedGraphFormat format, - HloRenderOptions hlo_render_options = {}); +absl::StatusOr RenderGraph( + const HloComputation& computation, absl::string_view label, + const DebugOptions& debug_options, RenderedGraphFormat format, + HloRenderOptions hlo_render_options = {}, + std::optional> + color_map = std::nullopt); + +absl::StatusOr RenderAllComputationsToHtml( + const HloModule& module); // Like RenderGraph, but renders only nodes "near" the given node in the graph. // @@ -95,15 +99,17 @@ StatusOr RenderGraph(const HloComputation& computation, // // The optional boundary specifies a set of boundary nodes, beyond which nodes // will be omitted even if they are within the radius. -StatusOr RenderNeighborhoodAround( +absl::StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, HloRenderOptions hlo_render_options = {}, - const absl::flat_hash_set& boundary = {}); + const absl::flat_hash_set& boundary = {}, + std::optional> + color_map = std::nullopt); // Renders nodes on any of the paths from `from` to `to`. If there are more // than max_nodes on all paths, restricts to the max_nodes nodes on the shortest // paths. -StatusOr RenderAllPathsFromTo( +absl::StatusOr RenderAllPathsFromTo( const HloInstruction& from, const HloInstruction& to, int64_t max_nodes, RenderedGraphFormat format, HloRenderOptions hlo_render_options = {}); @@ -127,11 +133,12 @@ void RegisterFusionState(const HloComputation& computation, // There can only be one active renderer, and the last call to this function // wins. void RegisterGraphToURLRenderer( - std::function(absl::string_view dot)> renderer); + std::function(absl::string_view dot)> renderer); // Generates a fusion explorer for the given computation using the data in // fusion_visualizer_state. -StatusOr WrapFusionExplorer(const HloComputation& computation); +absl::StatusOr WrapFusionExplorer( + const HloComputation& computation); } // namespace xla diff --git a/xla/service/hlo_graph_dumper_test.cc b/xla/service/hlo_graph_dumper_test.cc index 4d2d8e8589131..d76e734f33ec7 100644 --- a/xla/service/hlo_graph_dumper_test.cc +++ b/xla/service/hlo_graph_dumper_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,9 @@ limitations under the License. #include "xla/service/hlo_graph_dumper.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -212,5 +214,42 @@ ENTRY %conditional_select (constant: pred[]) -> (f32[]) { DebugOptions(), RenderedGraphFormat::kDot)); } +TEST_F(HloGraphDumperTest, OverrideColors) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0) + param.1 = f32[10] parameter(1) + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // Create a color map with color and stats + absl::flat_hash_map color_map; + ColorStats color_stats_1; + color_stats_1.color = "#A9C343"; + color_stats_1.stats = absl::StrFormat("%.3f", 1.11); + ColorStats color_stats_2; + color_stats_2.color = "#BC8A3F"; + color_stats_2.stats = absl::StrFormat("%.3f", 2.22); + color_map[module->entry_computation()->GetInstructionWithName("param.0")] = + color_stats_1; + color_map[module->entry_computation()->GetInstructionWithName("param.1")] = + color_stats_2; + + HloRenderOptions hlo_render_options; + hlo_render_options.override_node_colors = true; + TF_ASSERT_OK_AND_ASSIGN( + std::string graph, + RenderGraph(*module->entry_computation(), /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot, hlo_render_options, + color_map)); + EXPECT_THAT(graph, HasSubstr("#A9C343")); + EXPECT_THAT(graph, HasSubstr("1.110")); + EXPECT_THAT(graph, HasSubstr("#BC8A3F")); + EXPECT_THAT(graph, HasSubstr("2.220")); +} + } // anonymous namespace } // namespace xla diff --git a/xla/service/hlo_input_output_alias_config_test.cc b/xla/service/hlo_input_output_alias_config_test.cc index e5c3d01092db4..5f583c1f38f0b 100644 --- a/xla/service/hlo_input_output_alias_config_test.cc +++ b/xla/service/hlo_input_output_alias_config_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_instruction_test.cc b/xla/service/hlo_instruction_test.cc index 62d86cfa0c70c..408ae55b2aaec 100644 --- a/xla/service/hlo_instruction_test.cc +++ b/xla/service/hlo_instruction_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/protobuf_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -35,6 +36,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/window_util.h" +#include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" namespace xla { @@ -855,6 +857,60 @@ TEST_F(HloInstructionTest, AsyncOp) { EXPECT_EQ(computation->root_instruction(), async_done); } +TEST_F(HloInstructionTest, AsyncOpWithDeps) { + HloComputation::Builder builder(TestName()); + // Create a call instruction containing a single binary operation. + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); + + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); + auto constant4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); + + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant3, constant4)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + // control chain is add1 <- add <- add2 + TF_ASSERT_OK(add1->AddControlDependencyTo(add)); + + TF_ASSERT_OK(add->AddControlDependencyTo(add2)); + + auto module = CreateNewVerifiedModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + auto* async_done, + computation->CreateAsyncInstructions( + add, {ShapeUtil::MakeScalarShape(U32)}, "parallel_thread")); + auto* async_start = async_done->operand(0); + // Verify that control chain is not broken. + // New chain should be add1 <- asyncStart <- asyncDone <- add2 + EXPECT_EQ(async_start->control_predecessors().size(), 1); + EXPECT_EQ(async_start->control_predecessors()[0], add1); + + EXPECT_EQ(async_done->control_successors().size(), 1); + EXPECT_EQ(async_done->control_successors()[0], add2); + + EXPECT_EQ(async_start->shape().tuple_shapes_size(), 3); + EXPECT_EQ(async_start->async_execution_thread(), "parallel_thread"); + EXPECT_EQ(async_done->async_execution_thread(), "parallel_thread"); + EXPECT_TRUE(ShapeUtil::Equal(async_start->shape().tuple_shapes(2), + ShapeUtil::MakeScalarShape(U32))); + EXPECT_EQ(async_start->async_wrapped_computation()->execution_thread(), + "parallel_thread"); + EXPECT_EQ(async_done->async_wrapped_computation()->execution_thread(), + "parallel_thread"); + EXPECT_THAT(async_start->operands(), ElementsAre(constant1, constant2)); +} + TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( @@ -1572,8 +1628,7 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4"); } -TEST_F(HloInstructionTest, Stringification) { - // Tests stringification of a simple op, fusion, while, and conditional. +TEST_F(HloInstructionTest, StringifyDot) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); @@ -1607,16 +1662,60 @@ TEST_F(HloInstructionTest, Stringification) { EXPECT_EQ(dot->ToString(options2), "dot = f32[5,20] dot(x, transpose), " "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); +} + +TEST_F(HloInstructionTest, StringifySparseDot) { + HloComputation::Builder builder("SparseDot"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 16}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {32, 20}), "y")); + HloInstruction* meta = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(U16, {5, 2}), "meta")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(0); + sparsity_descriptor.set_dimension(1); + std::vector meta_operands = {meta}; + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {5, 20}), x, y, dot_dnums, + DefaultPrecisionConfig(2), {sparsity_descriptor}, meta_operands)); + + EXPECT_EQ(dot->ToString(), + "%dot = f32[5,20]{1,0} dot(f32[5,16]{1,0} %x, f32[32,20]{1,0} %y, " + "u16[5,2]{1,0} %meta), lhs_contracting_dims={1}, " + "rhs_contracting_dims={0}, sparsity=L.1@2:4"); +} + +TEST_F(HloInstructionTest, StringifyConditional) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(sout, x, reshape, dot_dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); - HloInstruction* loop = builder.AddInstruction( - HloInstruction::CreateWhile(sout, computation, computation, x)); - EXPECT_EQ(loop->ToString(options), - "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " - "condition=%TransposeDot, body=%TransposeDot"); - + auto options = HloPrintOptions().set_print_metadata(false); auto pred = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = @@ -1628,6 +1727,36 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, StringifyWhile) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(sout, x, reshape, dot_dnums, + DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + EXPECT_EQ(loop->ToString(options), + "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " + "condition=%TransposeDot, body=%TransposeDot"); +} + TEST_F(HloInstructionTest, GetSetStatisticsViz) { const Shape shape = ShapeUtil::MakeShape(F32, {5, 10}); @@ -1853,17 +1982,11 @@ TEST_F(HloInstructionTest, StringifyAsyncOps) { HloInstruction* async_start = entry_builder.AddInstruction(HloInstruction::CreateAsyncStart( s_tuple, {entry_param}, async_computation.get(), - /*async_group_id=*/std::nullopt, - /*async_execution_thread=*/"parallel_thread")); - HloInstruction* async_update = - entry_builder.AddInstruction(HloInstruction::CreateAsyncUpdate( - s_tuple, async_start, async_computation.get(), - /*async_group_id=*/std::nullopt, /*async_execution_thread=*/"parallel_thread")); - entry_builder.AddInstruction(HloInstruction::CreateAsyncDone( - s2, async_update, async_computation.get(), - /*async_group_id=*/std::nullopt, - /*async_execution_thread=*/"parallel_thread")); + HloInstruction* async_update = entry_builder.AddInstruction( + HloInstruction::CreateAsyncUpdate(s_tuple, async_start)); + entry_builder.AddInstruction( + HloInstruction::CreateAsyncDone(s2, async_update)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(entry_builder.Build()); @@ -1875,8 +1998,8 @@ TEST_F(HloInstructionTest, StringifyAsyncOps) { ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo" - %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT %custom-call-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update), async_execution_thread="parallel_thread", custom_call_target="foo" + %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start) + ROOT %custom-call-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update) } )"; @@ -1892,8 +2015,8 @@ ENTRY %Entry (p0: f32[10]) -> f32[20] { ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", calls=%AsyncOp - %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start), async_execution_thread="parallel_thread", calls=%AsyncOp - ROOT %custom-call-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update), async_execution_thread="parallel_thread", calls=%AsyncOp + %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start) + ROOT %custom-call-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update) } )"; @@ -1938,17 +2061,11 @@ TEST_F(HloInstructionTest, StringifyAsyncOpsWithReduceScatter) { HloInstruction* async_start = entry_builder.AddInstruction(HloInstruction::CreateAsyncStart( async_start_shape, {entry_param}, async_computation.get(), - /*async_group_id=*/std::nullopt, /*async_execution_thread=*/"parallel_thread")); - HloInstruction* async_update = - entry_builder.AddInstruction(HloInstruction::CreateAsyncUpdate( - async_start_shape, async_start, async_computation.get(), - /*async_group_id=*/std::nullopt, - /*async_execution_thread=*/"parallel_thread")); - entry_builder.AddInstruction(HloInstruction::CreateAsyncDone( - rs_output_shape, async_update, async_computation.get(), - /*async_group_id=*/std::nullopt, - /*async_execution_thread=*/"parallel_thread")); + HloInstruction* async_update = entry_builder.AddInstruction( + HloInstruction::CreateAsyncUpdate(async_start_shape, async_start)); + entry_builder.AddInstruction( + HloInstruction::CreateAsyncDone(rs_output_shape, async_update)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(entry_builder.Build()); @@ -1967,8 +2084,8 @@ TEST_F(HloInstructionTest, StringifyAsyncOpsWithReduceScatter) { ENTRY %Entry (pentry: f32[20]) -> f32[10] { %pentry = f32[20]{0} parameter(0) %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) reduce-scatter-start(f32[20]{0} %pentry), async_execution_thread="parallel_thread", replica_groups={}, dimensions={0}, to_apply=%add - %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) reduce-scatter-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start), async_execution_thread="parallel_thread", replica_groups={}, dimensions={0}, to_apply=%add - ROOT %reduce-scatter-done = f32[10]{0} reduce-scatter-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update), async_execution_thread="parallel_thread", replica_groups={}, dimensions={0}, to_apply=%add + %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) reduce-scatter-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start) + ROOT %reduce-scatter-done = f32[10]{0} reduce-scatter-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update) } )"; @@ -1991,8 +2108,8 @@ ENTRY %Entry (pentry: f32[20]) -> f32[10] { ENTRY %Entry (pentry: f32[20]) -> f32[10] { %pentry = f32[20]{0} parameter(0) %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) async-start(f32[20]{0} %pentry), async_execution_thread="parallel_thread", calls=%AsyncOp - %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) async-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start), async_execution_thread="parallel_thread", calls=%AsyncOp - ROOT %reduce-scatter-done = f32[10]{0} async-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update), async_execution_thread="parallel_thread", calls=%AsyncOp + %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) async-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start) + ROOT %reduce-scatter-done = f32[10]{0} async-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update) } )"; @@ -2128,9 +2245,6 @@ TEST_F(HloInstructionTest, CanonicalStringificationConditional) { computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); - builder.AddInstruction( - HloInstruction::CreateWhile(sout, computation, computation, x)); - auto pred = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = @@ -2417,16 +2531,19 @@ TEST_F(HloInstructionTest, BackendConfigCanContainNonFiniteFloats) { dot_dnums.add_rhs_contracting_dimensions(0); auto dot = b.AddInstruction(HloInstruction::CreateDot( shape, p0, p0, dot_dnums, DefaultPrecisionConfig(2))); - - gpu::GemmBackendConfig orig_config; + gpu::GpuBackendConfig gpu_config; + gpu::GemmBackendConfig& orig_config = + *gpu_config.mutable_gemm_backend_config(); orig_config.set_alpha_real(std::numeric_limits::infinity()); orig_config.set_alpha_imag(std::numeric_limits::quiet_NaN()); - TF_ASSERT_OK(dot->set_backend_config(orig_config)); + TF_ASSERT_OK(dot->set_backend_config(gpu_config)); - TF_ASSERT_OK_AND_ASSIGN(auto new_config, - dot->backend_config()); - EXPECT_GT(new_config.alpha_real(), std::numeric_limits::max()); - EXPECT_NE(new_config.alpha_imag(), new_config.alpha_imag()); + TF_ASSERT_OK_AND_ASSIGN(auto new_gpu_config, + dot->backend_config()); + EXPECT_GT(new_gpu_config.gemm_backend_config().alpha_real(), + std::numeric_limits::max()); + EXPECT_NE(new_gpu_config.gemm_backend_config().alpha_imag(), + new_gpu_config.gemm_backend_config().alpha_imag()); } TEST_F(HloInstructionTest, VerifyToApplyRegionPointsToReduceScatter) { @@ -2541,88 +2658,12 @@ TEST_F(HloInstructionTest, PrintCycle) { ASSERT_IS_OK(send_done->DropAllControlDeps()); } -TEST_F(HloInstructionTest, SetOperationQueueId) { - std::unique_ptr main_computation; - HloComputation::Builder main_builder("Entry"); - const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); - HloInstruction* param0 = main_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "p0")); - HloInstruction* param1 = main_builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "p1")); - - HloInstruction* add = - main_builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param0, param1)); - add->set_operation_queue_id(3); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(main_builder.Build()); - - auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(module->entry_computation()->root_instruction()->ToString(options), - "%add = f32[] add(f32[] %p0, f32[] %p1), operation_queue_id=3"); -} - -TEST_F(HloInstructionTest, SetWaitOnOperationQueues) { - std::unique_ptr main_computation; - HloComputation::Builder main_builder("Entry"); - const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); - HloInstruction* param0 = main_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "p0")); - HloInstruction* param1 = main_builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "p1")); - - HloInstruction* add = - main_builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param0, param1)); - std::vector wait_on_queues = {0, 2}; - add->set_wait_on_operation_queues(wait_on_queues); - add->add_wait_on_operation_queues(5); - - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(main_builder.Build()); - - auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(module->entry_computation()->root_instruction()->ToString(options), - "%add = f32[] add(f32[] %p0, f32[] %p1), " - "wait_on_operation_queues={0, 2, 5}"); -} - -TEST_F(HloInstructionTest, ParseOperationQueueId) { - constexpr char kHloString[] = R"( - ENTRY main { - c0 = f32[] constant(0) - ROOT add0 = f32[] add(c0, c0), operation_queue_id=2 - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kHloString)); - EXPECT_EQ( - module->entry_computation()->root_instruction()->operation_queue_id(), 2); -} - -TEST_F(HloInstructionTest, ParseWaitOnOperationQueues) { - constexpr char kHloString[] = R"( - ENTRY main { - c0 = f32[] constant(0) - ROOT add0 = f32[] add(c0, c0), wait_on_operation_queues={0,2} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kHloString)); - std::vector expected_wait_on_queue_ids = {0, 2}; - for (int64_t i = 0; i < expected_wait_on_queue_ids.size(); i++) { - EXPECT_EQ(expected_wait_on_queue_ids[i], - module->entry_computation() - ->root_instruction() - ->wait_on_operation_queues()[i]); - } -} - TEST_F(HloInstructionTest, VerifyBodyComputationPointsToWhile) { auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); HloComputation::Builder cond_builder("cond"); { - const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); HloInstruction* param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "p0")); HloInstruction* constant = cond_builder.AddInstruction( @@ -2635,7 +2676,6 @@ TEST_F(HloInstructionTest, VerifyBodyComputationPointsToWhile) { HloComputation::Builder body_builder("body"); { - const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); HloInstruction* param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "p0")); body_builder.AddInstruction(HloInstruction::CreateBinary( @@ -2664,5 +2704,172 @@ TEST_F(HloInstructionTest, VerifyBodyComputationPointsToWhile) { EXPECT_EQ(num_while_body_comp, 1); } +TEST_F(HloInstructionTest, + VerifyBranchComputationPointsToConditonal_TrueFalseConstructor) { + auto module = CreateNewVerifiedModule(); + const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); + + HloComputation::Builder branch_0_builder("branch_0"); + { + HloInstruction* param = branch_0_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* constant = branch_0_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1024.0))); + branch_0_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param, constant)); + } + auto branch_0_computation = + module->AddEmbeddedComputation(branch_0_builder.Build()); + + HloComputation::Builder branch_1_builder("branch_1"); + { + HloInstruction* param = branch_1_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + branch_1_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, param, param)); + } + auto branch_1_computation = + module->AddEmbeddedComputation(branch_1_builder.Build()); + + std::unique_ptr main_computation; + HloComputation::Builder main_builder("Entry"); + + HloInstruction* pred_param = + main_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(PRED, {}), "pred_param")); + HloInstruction* param = main_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "input")); + + main_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape, pred_param, /*true_computation_arg=*/param, + /*true_computation=*/branch_0_computation, + /*false_computation_arg=*/param, + /*false_computation=*/branch_1_computation)); + + module->AddEntryComputation(main_builder.Build()); + // Should find conditional branch computations in the graph and it should + // point to the conditonal instruction. + int num_conditional_branch_comp = 0; + for (HloComputation* comp : module->MakeComputationPostOrder()) { + if (comp->IsConditionalBranchComputation()) { + num_conditional_branch_comp += 1; + EXPECT_EQ(comp->ConditionalCallInstruction(), + module->entry_computation()->root_instruction()); + } + } + EXPECT_EQ(num_conditional_branch_comp, 2); +} + +TEST_F(HloInstructionTest, + VerifyBranchComputationPointsToConditonal_BranchIndexConstructor) { + auto module = CreateNewVerifiedModule(); + const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); + + std::vector branch_computations; + + { + HloComputation::Builder builder("branch_0"); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1024.0))); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param, constant)); + + branch_computations.push_back( + module->AddEmbeddedComputation(builder.Build())); + } + + { + HloComputation::Builder builder("branch_1"); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, param, param)); + + branch_computations.push_back( + module->AddEmbeddedComputation(builder.Build())); + } + + { + HloComputation::Builder builder("branch_2"); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kLog, param)); + + branch_computations.push_back( + module->AddEmbeddedComputation(builder.Build())); + } + + std::unique_ptr main_computation; + HloComputation::Builder main_builder("Entry"); + + HloInstruction* branch_index = + main_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeScalarShape(S32), "branch_index_param")); + HloInstruction* param = main_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "input")); + + std::vector branch_computation_args( + branch_computations.size(), param); + + main_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape, branch_index, branch_computations, + branch_computation_args)); + + module->AddEntryComputation(main_builder.Build()); + // Should find conditional branch computations in the graph and it should + // point to the conditonal instruction. + int num_conditional_branch_comp = 0; + for (HloComputation* comp : module->MakeComputationPostOrder()) { + if (comp->IsConditionalBranchComputation()) { + num_conditional_branch_comp += 1; + EXPECT_EQ(comp->ConditionalCallInstruction(), + module->entry_computation()->root_instruction()); + } + } + EXPECT_EQ(num_conditional_branch_comp, branch_computations.size()); +} + +TEST_F(HloInstructionTest, BackendConfigCopiedToDerived) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto p1 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p1")); + auto add = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1)); + + gpu::GpuBackendConfig gpu_config; + gpu_config.set_operation_queue_id(2); + TF_ASSERT_OK(add->set_backend_config(gpu_config)); + auto add2 = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0)); + add->SetupDerivedInstruction(add2); + auto backend_config = add2->backend_config(); + EXPECT_TRUE(backend_config.ok()); + EXPECT_EQ(backend_config->operation_queue_id(), 2); +} + +TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithDiffOpcode) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto p1 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p1")); + auto or1 = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kOr, p0, p1)); + + gpu::GpuBackendConfig gpu_config; + gpu_config.set_operation_queue_id(2); + TF_ASSERT_OK(or1->set_backend_config(gpu_config)); + auto add2 = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1)); + or1->SetupDerivedInstruction(add2); + EXPECT_FALSE(add2->has_backend_config()); +} + } // namespace } // namespace xla diff --git a/xla/service/hlo_lexer.cc b/xla/service/hlo_lexer.cc index bd516129caa7a..92053e86f7617 100644 --- a/xla/service/hlo_lexer.cc +++ b/xla/service/hlo_lexer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -347,6 +347,13 @@ TokKind HloLexer::LexIdentifier() { token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } + static LazyRE2 sparsity_desc_pattern = { + R"(([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+)"}; + if (RE2::Consume(&consumable, *sparsity_desc_pattern)) { + current_ptr_ = consumable.data(); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); + return TokKind::kSparsityDesc; + } } token_state_.str_val = std::string(identifier); @@ -638,6 +645,8 @@ std::string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kSparsityDesc: + return "kSparsityDesc"; case TokKind::kIdent: return "kIdent"; case TokKind::kString: diff --git a/xla/service/hlo_lexer.h b/xla/service/hlo_lexer.h index 5681818c07162..8a7547ff67983 100644 --- a/xla/service/hlo_lexer.h +++ b/xla/service/hlo_lexer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -82,6 +82,7 @@ enum class TokKind { kDimLabels, // [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,} kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kSparsityDesc, // ([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+ kIdent, // other identifiers kString, // "abcd\"\n" kInt, // 42 @@ -110,6 +111,7 @@ class HloLexer { case TokKind::kDimLabels: case TokKind::kDxD: case TokKind::kPad: + case TokKind::kSparsityDesc: case TokKind::kString: case TokKind::kIdent: return token_state_.str_val; diff --git a/xla/service/hlo_liveness_analysis.cc b/xla/service/hlo_liveness_analysis.cc index 91ec0a6f8d747..83ae6250e3d8c 100644 --- a/xla/service/hlo_liveness_analysis.cc +++ b/xla/service/hlo_liveness_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,22 @@ limitations under the License. #include "xla/service/hlo_liveness_analysis.h" +#include +#include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/call_graph.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/types.h" @@ -116,25 +120,28 @@ void PropagateLivenessThroughTuple( HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, Workset* workset) { CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); - for (int64_t operand_index = 0; operand_index < instruction->operand_count(); - ++operand_index) { - const ShapeTree& index_tree = *live_index_map->at(instruction); - ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { - if (shape_index.empty() || shape_index[0] != operand_index) { - return; - } - // Mark top-level index of operand at 'operand_index'. - MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, - worklist, workset); - // Mark sub-shape index of operand at 'operand_index'. - ShapeIndex operand_shape_index; - for (int i = 1; i < shape_index.size(); ++i) { - operand_shape_index.push_back(shape_index[i]); - } - MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, - live_index_map, worklist, workset); - }); - } + const ShapeTree& index_tree = *live_index_map->at(instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + const size_t size = shape_index.size(); + if (size == 0) { + return; + } + const int64_t operand_index = shape_index[0]; + if (operand_index >= instruction->operand_count()) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index(size - 1); + for (int i = 1; i < size; ++i) { + operand_shape_index[i - 1] = shape_index[i]; + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); } // Propagates liveness through GetTupleElement instructions. @@ -334,7 +341,7 @@ bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, } /* static */ -StatusOr> HloLivenessAnalysis::Run( +absl::StatusOr> HloLivenessAnalysis::Run( const HloModule& module) { VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); diff --git a/xla/service/hlo_liveness_analysis.h b/xla/service/hlo_liveness_analysis.h index a021573a89259..0c49dda4aaaf8 100644 --- a/xla/service/hlo_liveness_analysis.h +++ b/xla/service/hlo_liveness_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ class HloLivenessAnalysis { // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule& module); // Returns true if output of 'instruction' at 'shape_index' is live. diff --git a/xla/service/hlo_liveness_analysis_test.cc b/xla/service/hlo_liveness_analysis_test.cc index f87586df26e1e..4a132059878c3 100644 --- a/xla/service/hlo_liveness_analysis_test.cc +++ b/xla/service/hlo_liveness_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_memory_scheduler.cc b/xla/service/hlo_memory_scheduler.cc index 65320dd96d112..f073b34556ef5 100644 --- a/xla/service/hlo_memory_scheduler.cc +++ b/xla/service/hlo_memory_scheduler.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/tuple_points_to_analysis.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -77,7 +77,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr Run( + static absl::StatusOr Run( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_function, @@ -410,7 +410,7 @@ int64_t SumLogicalBufferSizes( return size; } -StatusOr ScheduleComputationHelper( +absl::StatusOr ScheduleComputationHelper( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -433,7 +433,7 @@ StatusOr ScheduleComputationHelper( } // namespace -StatusOr DFSMemoryScheduler( +absl::StatusOr DFSMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -532,7 +532,7 @@ ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_func, const absl::flat_hash_set& execution_threads, - int64_t* peak_memory) -> StatusOr { + int64_t* peak_memory) -> absl::StatusOr { HloSchedule schedule(module); absl::flat_hash_map memory_by_computation; for (auto* computation : @@ -555,7 +555,7 @@ ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( }; } -StatusOr ListMemoryScheduler( +absl::StatusOr ListMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -578,7 +578,7 @@ StatusOr ListMemoryScheduler( return sequence; } -StatusOr PostOrderMemoryScheduler( +absl::StatusOr PostOrderMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -599,7 +599,7 @@ StatusOr PostOrderMemoryScheduler( return sequence; } -StatusOr DefaultMemoryScheduler( +absl::StatusOr DefaultMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -660,7 +660,7 @@ StatusOr DefaultMemoryScheduler( } } -StatusOr DefaultModuleScheduler( +absl::StatusOr DefaultModuleScheduler( const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, @@ -720,7 +720,7 @@ StatusOr DefaultModuleScheduler( } } -StatusOr ScheduleModule( +absl::StatusOr ScheduleModule( const HloModule* module, const BufferValue::SizeFunction& size_function, const ModuleSchedulerAlgorithm& algorithm, const absl::flat_hash_set& execution_threads, @@ -740,7 +740,7 @@ StatusOr ScheduleModule( return std::move(schedule); } -StatusOr ScheduleComputation( +absl::StatusOr ScheduleComputation( HloComputation* computation, const BufferValue::SizeFunction& size_function, const MemorySchedulerPostprocessor& postprocessor) { CHECK(!computation->IsFusionComputation()); @@ -760,7 +760,7 @@ HloMemoryScheduler::HloMemoryScheduler( const ModuleSchedulerAlgorithm& algorithm) : size_function_(size_function), algorithm_(algorithm) {} -StatusOr HloMemoryScheduler::Run( +absl::StatusOr HloMemoryScheduler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN( @@ -770,7 +770,7 @@ StatusOr HloMemoryScheduler::Run( return true; } -StatusOr HloTrivialScheduler::Run( +absl::StatusOr HloTrivialScheduler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { HloSchedule schedule(module); @@ -792,7 +792,7 @@ StatusOr HloTrivialScheduler::Run( return true; } -StatusOr HloDescheduler::Run( +absl::StatusOr HloDescheduler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = module->has_schedule(); diff --git a/xla/service/hlo_memory_scheduler.h b/xla/service/hlo_memory_scheduler.h index 2d61c077a5055..24ca085a9a40b 100644 --- a/xla/service/hlo_memory_scheduler.h +++ b/xla/service/hlo_memory_scheduler.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ using MemorySchedulerPostprocessor = // HeapSimulator. // // TODO(yunxing): Cleanup usage of TuplePointsToAnalysis. -typedef std::function( +typedef std::function( HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_map&, @@ -51,7 +51,7 @@ typedef std::function( MemorySchedulerAlgorithm; // Scheduler for the entire module. -typedef std::function( +typedef std::function( const HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_set& execution_threads, @@ -64,7 +64,7 @@ ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {}); // List scheduler -StatusOr ListMemoryScheduler( +absl::StatusOr ListMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -74,7 +74,7 @@ StatusOr ListMemoryScheduler( const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // DFS-order scheduler -StatusOr DFSMemoryScheduler( +absl::StatusOr DFSMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -84,7 +84,7 @@ StatusOr DFSMemoryScheduler( const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // Naive Post Order scheduler -StatusOr PostOrderMemoryScheduler( +absl::StatusOr PostOrderMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -97,7 +97,7 @@ StatusOr PostOrderMemoryScheduler( // and the post-order scheduler and chooses whichever returns a lower min- // memory, not accounting for fragmentation. peak_memory (may be nullptr) is set // to the peak memory of the resulting schedule according to the HeapSimulator. -StatusOr DefaultMemoryScheduler( +absl::StatusOr DefaultMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, @@ -106,7 +106,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); -StatusOr DefaultModuleScheduler( +absl::StatusOr DefaultModuleScheduler( const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -117,7 +117,7 @@ StatusOr DefaultModuleScheduler( // module. size_function is the function returning the number of bytes required // for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak // memory (according to the HeapSimulator) of all computations in the module. -StatusOr ScheduleModule( +absl::StatusOr ScheduleModule( const HloModule* module, const LogicalBuffer::SizeFunction& size_function, const ModuleSchedulerAlgorithm& algorithm = {}, const absl::flat_hash_set& execution_threads = {}, @@ -125,7 +125,7 @@ StatusOr ScheduleModule( // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr ScheduleComputation( +absl::StatusOr ScheduleComputation( HloComputation* computation, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerPostprocessor& postprocessor); @@ -146,7 +146,7 @@ class HloMemoryScheduler : public HloModulePass { absl::string_view name() const override { return "hlo-memory-scheduler"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -163,7 +163,7 @@ class HloTrivialScheduler : public HloModulePass { absl::string_view name() const override { return "hlo-trivial-scheduler"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; @@ -177,7 +177,7 @@ class HloDescheduler : public HloModulePass { absl::string_view name() const override { return "hlo-descheduler"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/hlo_memory_scheduler_test.cc b/xla/service/hlo_memory_scheduler_test.cc index c3db985d55adf..17a10ae1ea4d4 100644 --- a/xla/service/hlo_memory_scheduler_test.cc +++ b/xla/service/hlo_memory_scheduler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_ordering.h" #include "xla/shape_util.h" diff --git a/xla/service/hlo_module_config.cc b/xla/service/hlo_module_config.cc index a293887d0ed8e..0cf04f10387f4 100644 --- a/xla/service/hlo_module_config.cc +++ b/xla/service/hlo_module_config.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,6 +85,9 @@ std::string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, device_type()); } StrAppend(&key, "::alias_passthrough_params=", alias_passthrough_params_); + StrAppend(&key, "::allow_spmd_sharding_propagation_to_parameters={", + absl::StrJoin(allow_spmd_sharding_propagation_to_parameters_, ","), + "}"); StrAppend(&key, "::allow_spmd_sharding_propagation_to_output={", absl::StrJoin(allow_spmd_sharding_propagation_to_output_, ","), "}"); @@ -147,7 +150,7 @@ static void AssignProtoDotConfig( for (int64_t val : list_vector) { list.add_vals(val); } - proto.mutable_dot_config()->insert({key, std::move(list)}); + proto.mutable_dot_config()->try_emplace(key, std::move(list)); } } @@ -195,7 +198,7 @@ static void AssignProtoPhaseOrderingConfig( pair.output_shape_index.assign(output_idx.begin(), output_idx.end()); cfg_pairs.push_back(pair); } - config.set_shardable_value_update_pairs(cfg_pairs); + config.set_shardable_value_update_pairs(std::move(cfg_pairs)); } static void AssignStructFusionConfig(HloModuleConfig& config, @@ -255,7 +258,7 @@ static void AssignStructPhaseOrderingConfig(HloModuleConfig& config, *config.mutable_phase_ordering_config() = std::move(module_config); } -StatusOr HloModuleConfig::ToProto() const { +absl::StatusOr HloModuleConfig::ToProto() const { HloModuleConfigProto proto; if (has_entry_computation_layout()) { *proto.mutable_entry_computation_layout() = @@ -303,6 +306,9 @@ StatusOr HloModuleConfig::ToProto() const { AssignProtoPhaseOrderingConfig(proto, phase_ordering_config_); proto.set_phase_index(phase_index_); + for (bool value : allow_spmd_sharding_propagation_to_parameters_) { + proto.add_allow_spmd_sharding_propagation_to_parameters(value); + } for (bool value : allow_spmd_sharding_propagation_to_output_) { proto.add_allow_spmd_sharding_propagation_to_output(value); } @@ -318,8 +324,8 @@ StatusOr HloModuleConfig::ToProto() const { return proto; } -StatusOr> HloModuleConfig::CreateFromProto( - const HloModuleConfigProto& proto) { +absl::StatusOr> +HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) { auto config = std::make_unique(); if (proto.has_entry_computation_layout()) { @@ -370,6 +376,9 @@ StatusOr> HloModuleConfig::CreateFromProto( proto.memory_space_assignment_config().end()); AssignStructPhaseOrderingConfig(*config, proto); config->phase_index_ = proto.phase_index(); + config->allow_spmd_sharding_propagation_to_parameters_.assign( + proto.allow_spmd_sharding_propagation_to_parameters().begin(), + proto.allow_spmd_sharding_propagation_to_parameters().end()); config->allow_spmd_sharding_propagation_to_output_.assign( proto.allow_spmd_sharding_propagation_to_output().begin(), proto.allow_spmd_sharding_propagation_to_output().end()); diff --git a/xla/service/hlo_module_config.h b/xla/service/hlo_module_config.h index bac9a26e28493..754750018f6bd 100644 --- a/xla/service/hlo_module_config.h +++ b/xla/service/hlo_module_config.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,8 +85,8 @@ class HloModuleConfig { explicit HloModuleConfig(ComputationLayout entry_computation_layout); // Convert an HloModuleConfig to or from a proto. - StatusOr ToProto() const; - static StatusOr> CreateFromProto( + absl::StatusOr ToProto() const; + static absl::StatusOr> CreateFromProto( const HloModuleConfigProto& proto); // Assigns the repeated ShardableValueUpdatePairProto field to the given @@ -169,7 +169,7 @@ class HloModuleConfig { return param_requires_broadcast_via_collectives_; } void set_param_requires_broadcast_via_collectives( - const std::vector require_broadcast) { + std::vector require_broadcast) { param_requires_broadcast_via_collectives_ = std::move(require_broadcast); } @@ -195,16 +195,16 @@ class HloModuleConfig { } void set_auto_spmd_partitioning_mesh_shape(std::vector mesh_shape) { - auto_spmd_partitioning_mesh_shape_ = mesh_shape; + auto_spmd_partitioning_mesh_shape_ = std::move(mesh_shape); } - std::vector auto_spmd_partitioning_mesh_shape() const { + const std::vector& auto_spmd_partitioning_mesh_shape() const { return auto_spmd_partitioning_mesh_shape_; } void set_auto_spmd_partitioning_mesh_ids(std::vector mesh_ids) { - auto_spmd_partitioning_mesh_ids_ = mesh_ids; + auto_spmd_partitioning_mesh_ids_ = std::move(mesh_ids); } - std::vector auto_spmd_partitioning_mesh_ids() const { + const std::vector& auto_spmd_partitioning_mesh_ids() const { return auto_spmd_partitioning_mesh_ids_; } @@ -324,9 +324,17 @@ class HloModuleConfig { int phase_index() const { return phase_index_; } void set_phase_index(const int phase_index) { phase_index_ = phase_index; } + absl::Span allow_spmd_sharding_propagation_to_parameters() const { + return allow_spmd_sharding_propagation_to_parameters_; + } absl::Span allow_spmd_sharding_propagation_to_output() const { return allow_spmd_sharding_propagation_to_output_; } + void set_allow_spmd_sharding_propagation_to_parameters( + absl::Span data) { + return allow_spmd_sharding_propagation_to_parameters_.assign(data.begin(), + data.end()); + } void set_allow_spmd_sharding_propagation_to_output( absl::Span data) { return allow_spmd_sharding_propagation_to_output_.assign(data.begin(), @@ -453,6 +461,18 @@ class HloModuleConfig { // config across functions during compilation. int phase_index_ = 0; + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + absl::InlinedVector allow_spmd_sharding_propagation_to_parameters_ = + {false}; // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output diff --git a/xla/service/hlo_module_config_test.cc b/xla/service/hlo_module_config_test.cc index ac9b0c9ad3f0f..952aec1244fec 100644 --- a/xla/service/hlo_module_config_test.cc +++ b/xla/service/hlo_module_config_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_module_dce.cc b/xla/service/hlo_module_dce.cc index f00f6d028509d..de04957795aa2 100644 --- a/xla/service/hlo_module_dce.cc +++ b/xla/service/hlo_module_dce.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ namespace xla { namespace { -StatusOr RunWhileDCE( +absl::StatusOr RunWhileDCE( HloModule* module, HloLivenessAnalysis* liveness, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -105,7 +105,7 @@ StatusOr RunWhileDCE( } // namespace -StatusOr HloModuleDCE::Run( +absl::StatusOr HloModuleDCE::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Before HloModuleDCE:"; diff --git a/xla/service/hlo_module_dce.h b/xla/service/hlo_module_dce.h index 56bda9c5684a3..2bd5df9deb87c 100644 --- a/xla/service/hlo_module_dce.h +++ b/xla/service/hlo_module_dce.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class HloModuleDCE : public HloModulePass { // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/hlo_module_dce_test.cc b/xla/service/hlo_module_dce_test.cc index 19dd36aef4439..c192429c2f30e 100644 --- a/xla/service/hlo_module_dce_test.cc +++ b/xla/service/hlo_module_dce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_module_group_metadata.cc b/xla/service/hlo_module_group_metadata.cc index 31234894b5211..2ab8c03c7ec03 100644 --- a/xla/service/hlo_module_group_metadata.cc +++ b/xla/service/hlo_module_group_metadata.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ std::string HloModuleGroupMetadata::TrackedInstruction::ToString() const { return repr; } -/* static */ StatusOr> +/* static */ absl::StatusOr> HloModuleGroupMetadata::Build(absl::Span modules) { auto metadata = std::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); @@ -78,11 +78,11 @@ Status HloModuleGroupMetadata::Build() { return OkStatus(); } - if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) { + if (IsChannelInstruction(hlo) || IsNonSpmdCrossModuleAllReduce(hlo)) { std::vector peers; if (IsChannelInstruction(hlo)) { peers.push_back(PeerComputation(hlo)); - } else if (hlo->IsCrossModuleAllReduce()) { + } else if (IsNonSpmdCrossModuleAllReduce(hlo)) { for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) { if (instr == hlo) { continue; @@ -217,10 +217,16 @@ bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { return companion_set_index_.contains(hlo); } +bool HloModuleGroupMetadata::IsNonSpmdCrossModuleAllReduce( + HloInstruction* hlo) const { + return hlo->IsCrossModuleAllReduce() && + !hlo->GetModule()->config().use_spmd_partitioning(); +} + bool HloModuleGroupMetadata::InstructionCommunicates( HloInstruction* hlo) const { return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) || - hlo->IsCrossModuleAllReduce(); + IsNonSpmdCrossModuleAllReduce(hlo); } const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( @@ -332,7 +338,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { } // Group cross module all-reduce instructions by the channel id. - if (hlo->IsCrossModuleAllReduce()) { + if (IsNonSpmdCrossModuleAllReduce(hlo)) { TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) << "channel_id " << *hlo->channel_id() diff --git a/xla/service/hlo_module_group_metadata.h b/xla/service/hlo_module_group_metadata.h index c5bedb109176d..f977b19ab2ed7 100644 --- a/xla/service/hlo_module_group_metadata.h +++ b/xla/service/hlo_module_group_metadata.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -114,7 +114,7 @@ class HloModuleGroupMetadata { ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. - static StatusOr> Build( + static absl::StatusOr> Build( absl::Span modules); // Returns true if the instruction is one of the 4 channel instructions (Send, @@ -125,8 +125,12 @@ class HloModuleGroupMetadata { // comment above on companion instructions. bool IsCompanionInstruction(HloInstruction* hlo) const; + // Returns true if the instruction is either a cross-module all-reduce + // instruction in a non-spmd module. + bool IsNonSpmdCrossModuleAllReduce(HloInstruction* hlo) const; + // Returns true if the instruction is either a channel instruction, a - // cross-module all-reduce instruction, or a companion instruction. + // cross-module non-spmd all-reduce instruction, or a companion instruction. bool InstructionCommunicates(HloInstruction* hlo) const; // Returns the Channel instance for the given channel id. diff --git a/xla/service/hlo_module_group_test.cc b/xla/service/hlo_module_group_test.cc index f061a80dce243..b56b53b4952e0 100644 --- a/xla/service/hlo_module_group_test.cc +++ b/xla/service/hlo_module_group_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_module_group_util.cc b/xla/service/hlo_module_group_util.cc index 3b5f052c2a67d..362c3ce46c3e3 100644 --- a/xla/service/hlo_module_group_util.cc +++ b/xla/service/hlo_module_group_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -362,7 +362,7 @@ Status HloModuleGroupUtil::VerifyComputations( return OkStatus(); } -StatusOr> +absl::StatusOr> HloModuleGroupUtil::ComputeReachability( absl::Span computations) { std::vector post_order; diff --git a/xla/service/hlo_module_group_util.h b/xla/service/hlo_module_group_util.h index 595712af811fd..9d5e9af41e478 100644 --- a/xla/service/hlo_module_group_util.h +++ b/xla/service/hlo_module_group_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -103,7 +103,7 @@ class HloModuleGroupUtil { // they can handle instructions across multiple computations. // // Creates the reachability map for the instructions in the computations. - StatusOr> ComputeReachability( + absl::StatusOr> ComputeReachability( absl::Span computations); // Updates the reachability of the given instruction, taking the global diff --git a/xla/service/hlo_module_metadata_test.cc b/xla/service/hlo_module_metadata_test.cc index 834c0214cd17e..fc64b31d019ab 100644 --- a/xla/service/hlo_module_metadata_test.cc +++ b/xla/service/hlo_module_metadata_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_module_test.cc b/xla/service/hlo_module_test.cc index 34c4d20a0f4e7..49d00abadb03a 100644 --- a/xla/service/hlo_module_test.cc +++ b/xla/service/hlo_module_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -791,7 +791,7 @@ static HloModuleConfigProto::BoolList MakeOneHotBoolList(unsigned num_vals, return list; } -static StatusOr MakeTestModuleConfigProto() { +static absl::StatusOr MakeTestModuleConfigProto() { HloModuleConfigProto proto; // entry_computation_layout_ is optional proto.set_seed(0xdeadbeef); diff --git a/xla/service/hlo_module_util.cc b/xla/service/hlo_module_util.cc index 26cc35cbf0d21..eca1d682061e6 100644 --- a/xla/service/hlo_module_util.cc +++ b/xla/service/hlo_module_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ Status ValidateResultShape(const Shape& client_shape, } } // namespace -StatusOr> CreateModuleConfig( +absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, int default_num_replicas, @@ -100,6 +100,11 @@ StatusOr> CreateModuleConfig( } config->set_use_spmd_partitioning( execution_options->use_spmd_partitioning()); + if (!execution_options->allow_spmd_sharding_propagation_to_parameters() + .empty()) { + config->set_allow_spmd_sharding_propagation_to_parameters( + execution_options->allow_spmd_sharding_propagation_to_parameters()); + } if (!execution_options->allow_spmd_sharding_propagation_to_output() .empty()) { config->set_allow_spmd_sharding_propagation_to_output( @@ -107,16 +112,12 @@ StatusOr> CreateModuleConfig( } config->set_use_auto_spmd_partitioning( execution_options->use_auto_spmd_partitioning()); - std::vector mesh_shape; - for (auto t : execution_options->auto_spmd_partitioning_mesh_shape()) { - mesh_shape.push_back(t); - } - config->set_auto_spmd_partitioning_mesh_shape(mesh_shape); - std::vector mesh_ids; - for (auto t : execution_options->auto_spmd_partitioning_mesh_ids()) { - mesh_ids.push_back(t); - } - config->set_auto_spmd_partitioning_mesh_ids(mesh_ids); + config->set_auto_spmd_partitioning_mesh_shape(std::vector( + execution_options->auto_spmd_partitioning_mesh_shape().begin(), + execution_options->auto_spmd_partitioning_mesh_shape().end())); + config->set_auto_spmd_partitioning_mesh_ids(std::vector( + execution_options->auto_spmd_partitioning_mesh_ids().begin(), + execution_options->auto_spmd_partitioning_mesh_ids().end())); config->set_deduplicate_hlo(execution_options->deduplicate_hlo()); config->set_seed(execution_options->seed()); config->set_launch_id(execution_options->launch_id()); diff --git a/xla/service/hlo_module_util.h b/xla/service/hlo_module_util.h index f74faeb287a53..8692bf927e444 100644 --- a/xla/service/hlo_module_util.h +++ b/xla/service/hlo_module_util.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace xla { // If execution_options does not set num_replicas, default_num_replicas is used. // num_threads is optional; if not given, intra_op_parallelism_threads not set. // aot_options is optional; if not given a default is used. -StatusOr> CreateModuleConfig( +absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, int default_num_replicas, diff --git a/xla/service/hlo_opcode_test.cc b/xla/service/hlo_opcode_test.cc index 48a79d9ee8827..3f1e5c44fc63a 100644 --- a/xla/service/hlo_opcode_test.cc +++ b/xla/service/hlo_opcode_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,11 +58,13 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllToAll: case HloOpcode::kCall: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConcatenate: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDot: // Sparse dot has an extra meta argument. case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kDynamicReshape: diff --git a/xla/service/hlo_ordering.cc b/xla/service/hlo_ordering.cc index cb23dad2517d9..b60e9ce44afe1 100644 --- a/xla/service/hlo_ordering.cc +++ b/xla/service/hlo_ordering.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -54,7 +54,11 @@ HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint( // callgraph ancestor instructions which call (potentially transitively) the // computations containing 'a' and 'b' and use these ancestor instructions to // compare order. - if (a == b) { + auto is_async_wrapped = [](const HloInstruction* a, const HloInstruction* b) { + // Treats the async wrapped instruction as same as the wrapper. + return a->IsAsynchronous() && a->async_wrapped_instruction() == b; + }; + if (a == b || is_async_wrapped(a, b) || is_async_wrapped(b, a)) { return ExecutionConstraint::kIsSame; } const HloInstruction* a_ancestor; diff --git a/xla/service/hlo_ordering.h b/xla/service/hlo_ordering.h index c1d255bc0f8cd..6b070798f9ebb 100644 --- a/xla/service/hlo_ordering.h +++ b/xla/service/hlo_ordering.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_ordering_test.cc b/xla/service/hlo_ordering_test.cc index c3329f768e13d..1f2bad04127b4 100644 --- a/xla/service/hlo_ordering_test.cc +++ b/xla/service/hlo_ordering_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -283,7 +283,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).GetUses().size(), 1); const HloUse* while_use = - &dataflow->GetValueDefinedAt(xla_while).GetUses()[0]; + dataflow->GetValueDefinedAt(xla_while).GetUses().data(); EXPECT_EQ(while_use->instruction, add); EXPECT_TRUE(ordering.UsesBeforeValueDefinition( {&while_use, 1}, dataflow->GetValueDefinedAt(add), *dataflow)); @@ -613,8 +613,8 @@ HloModule single_sc_async_call ENTRY %main { %input.1 = s32[1024]{0} parameter(0) %buf = s32[1024]{0} custom-call(), custom_call_target="AllocateBuffer" - %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped - ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_group_id=0, async_execution_thread="foobar", calls=%async_wrapped + %async-start = ((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) async-start(s32[1024]{0} %input.1, s32[1024]{0} %buf), async_execution_thread="foobar", calls=%async_wrapped + ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[1024]{0}), s32[1024]{0}, u32[]) %async-start), async_execution_thread="foobar", calls=%async_wrapped } )"; HloModuleConfig hlo_config; @@ -636,5 +636,60 @@ ENTRY %main { {&async_start_use, &call_use, &async_done_use}, value, *dataflow)); } +TEST_F(HloOrderingTest, OrderingBetweenAsyncOpAndItsWrapped) { + constexpr absl::string_view hlo = R"( +HloModule test + +%async_computation { + %param_0 = f32[10,32,512]{2,1,0:T(8,128)S(5)} parameter(0) + %param_1 = f32[1,32,512]{2,1,0:T(8,128)} parameter(1) + %param_2 = s32[]{:T(128)} parameter(2) + %param_3 = s32[]{:T(128)} parameter(3) + %param_4 = s32[]{:T(128)} parameter(4) + ROOT %dynamic-update-slice.1 = f32[10,32,512]{2,1,0:T(8,128)S(5)} + dynamic-update-slice(%param_0, %param_1, %param_2, %param_3, %param_4) +} + +ENTRY %main { + %param.1 = (s32[]{:T(128)}, f32[32,512]{1,0:T(8,128)}, + f32[10,32,512]{2,1,0:T(8,128)S(5)}) parameter(0) + %get-tuple-element.132 = f32[10,32,512]{2,1,0:T(8,128)S(5)} get-tuple-element( + %param.1), index=2 + %get-tuple-element.131 = f32[32,512]{1,0:T(8,128)} get-tuple-element( + %param.1), index=1 + %cosine.0 = f32[32,512]{1,0:T(8,128)} cosine(%get-tuple-element.131) + %reshape.6 = f32[1,32,512]{2,1,0:T(8,128)} reshape(%cosine.0) + %get-tuple-element.130 = s32[]{:T(128)} get-tuple-element(%param.1), index=0 + %constant.49 = s32[]{:T(128)} constant(0) + %compare.13 = pred[]{:T(512)} compare( + %get-tuple-element.130, %constant.49), direction=LT + %constant.50 = s32[]{:T(128)} constant(10) + %add.22 = s32[]{:T(128)} add(%get-tuple-element.130, %constant.50) + %select.6 = s32[]{:T(128)} select( + %compare.13, %add.22, %get-tuple-element.130) + %dynamic-update-slice-start = ( + (f32[10,32,512]{2,1,0:T(8,128)S(5)}, f32[1,32,512]{2,1,0:T(8,128)}, + s32[]{:T(128)}, s32[]{:T(128)}, s32[]{:T(128)}), + f32[10,32,512]{2,1,0:T(8,128)S(5)}, u32[]) async-start( + %get-tuple-element.132, %reshape.6, %select.6, + %constant.49, %constant.49), calls=%async_computation + ROOT %dynamic-update-slice-done = f32[10,32,512]{2,1,0:T(8,128)S(5)} + async-done(%dynamic-update-slice-start), calls=%async_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + auto* async_start = + FindInstruction(module.get(), "dynamic-update-slice-start"); + auto* async_done = FindInstruction(module.get(), "dynamic-update-slice-done"); + auto* dus = FindInstruction(module.get(), "dynamic-update-slice.1"); + EXPECT_EQ(ordering.GetExecutionConstraint(async_start, dus), + HloOrdering::ExecutionConstraint::kIsSame); + EXPECT_EQ(ordering.GetExecutionConstraint(async_done, dus), + HloOrdering::ExecutionConstraint::kIsSame); +} } // namespace } // namespace xla diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 3a31abe766ba8..2937c1ff91861 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -129,6 +129,7 @@ bool CanInferShape(HloOpcode code) { case HloOpcode::kDivide: case HloOpcode::kDomain: case HloOpcode::kDot: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: @@ -196,6 +197,7 @@ bool CanInferShape(HloOpcode code) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: @@ -247,17 +249,18 @@ class HloParserImpl : public HloParser { std::string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. - StatusOr ParseShapeOnly(); - StatusOr ParseLayoutOnly(); - StatusOr ParseShardingOnly(); - StatusOr ParseFrontendAttributesOnly(); - StatusOr ParseStatisticsVizOnly(); - StatusOr> ParseParameterReplicationOnly(); - StatusOr ParseBooleanListOrSingleBooleanOnly(); - StatusOr ParseWindowOnly(); - StatusOr ParseConvolutionDimensionNumbersOnly(); - StatusOr ParsePaddingConfigOnly(); - StatusOr> ParseReplicaGroupsOnly(); + absl::StatusOr ParseShapeOnly(); + absl::StatusOr ParseLayoutOnly(); + absl::StatusOr ParseShardingOnly(); + absl::StatusOr ParseFrontendAttributesOnly(); + absl::StatusOr ParseStatisticsVizOnly(); + absl::StatusOr> ParseParameterReplicationOnly(); + absl::StatusOr ParseBooleanListOrSingleBooleanOnly(); + absl::StatusOr ParseWindowOnly(); + absl::StatusOr + ParseConvolutionDimensionNumbersOnly(); + absl::StatusOr ParsePaddingConfigOnly(); + absl::StatusOr> ParseReplicaGroupsOnly(); private: // Types of attributes. @@ -295,12 +298,14 @@ class HloParserImpl : public HloParser { kShapeList, kEnum, kRandomAlgorithm, + kPrecisionAlgorithm, kAliasing, kBufferDonor, kComputationLayout, kInstructionAliasing, kCustomCallSchedule, kCustomCallApiVersion, + kSparsityDescriptor, // A double-quoted string, or a string that looks like a JSON dictionary // enclosed in matching curly braces (returned value includes the curlies). kStringOrJsonDict, @@ -532,6 +537,7 @@ class HloParserImpl : public HloParser { absl::InlinedVector* dim_unique, absl::InlinedVector* dim_ordered); bool ParseTiles(std::vector* tiles); + bool ParseSplitConfigs(std::vector& split_configs); bool ParsePhysicalShape(Shape* physical_shape); bool ParseOpcode(HloOpcode* opcode, std::optional* async_wrapped_opcode); @@ -544,6 +550,7 @@ class HloParserImpl : public HloParser { bool ParseRandomDistribution(RandomDistribution* result); bool ParseRandomAlgorithm(RandomAlgorithm* result); bool ParsePrecision(PrecisionConfig::Precision* result); + bool ParseAlgorithm(PrecisionConfig::Algorithm* result); bool ParseInt64(int64_t* result); bool ParseDouble(double* result); bool ParseComplex(std::complex* result); @@ -571,6 +578,7 @@ class HloParserImpl : public HloParser { bool ParseCustomCallSchedule(CustomCallSchedule* result); bool ParseCustomCallApiVersion(CustomCallApiVersion* result); + bool ParseSparsityDescriptor(std::vector* result); bool ParseShapeIndex(ShapeIndex* out); // Returns true if the current token is the beginning of a shape. @@ -1001,6 +1009,38 @@ bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) { return true; } +bool HloParserImpl::ParseSparsityDescriptor( + std::vector* result) { + VLOG(3) << "ParseSparsityDescriptor"; + if (lexer_.GetKind() != TokKind::kSparsityDesc) { + return TokenError("expects sparsity descriptor, e.g. L.0@2:4"); + } + std::string val = lexer_.GetStrVal(); + std::vector split = absl::StrSplit(val, '_'); + for (absl::string_view item : split) { + std::vector splitA = absl::StrSplit(item, '@'); + std::vector splitB = absl::StrSplit(splitA[0], '.'); + std::vector splitC = absl::StrSplit(splitA[1], ':'); + SparsityDescriptor descriptor; + int dim, n, m; + if (!absl::SimpleAtoi(splitB[1], &dim) || dim < 0) { + return TokenError("Invalid dimension number"); + } + if (!absl::SimpleAtoi(splitC[0], &n) || !absl::SimpleAtoi(splitC[1], &m) || + n < 1 || m <= n) { + return TokenError("Invalid structured sparsity type"); + } + descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + descriptor.set_index(splitB[0] == "L" ? 0 : 1); + descriptor.set_dimension(dim); + descriptor.set_n(n); + descriptor.set_m(m); + result->push_back(descriptor); + } + lexer_.Lex(); + return true; +} + // ::= 'HloModule' name computations bool HloParserImpl::ParseHloModule(HloModule* module, bool parse_module_without_header) { @@ -1014,6 +1054,7 @@ bool HloParserImpl::ParseHloModule(HloModule* module, absl::flat_hash_map attrs; std::optional entry_computation_layout; std::optional frontend_attributes; + BoolList allow_spmd_sharding_propagation_to_parameters; BoolList allow_spmd_sharding_propagation_to_output; attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; @@ -1031,6 +1072,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module, &entry_computation_layout}; attrs["frontend_attributes"] = { /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; + attrs["allow_spmd_sharding_propagation_to_parameters"] = { + /*required=*/false, AttrTy::kBracedBoolListOrBool, + &allow_spmd_sharding_propagation_to_parameters}; attrs["allow_spmd_sharding_propagation_to_output"] = { /*required=*/false, AttrTy::kBracedBoolListOrBool, &allow_spmd_sharding_propagation_to_output}; @@ -1089,6 +1133,11 @@ bool HloParserImpl::ParseHloModule(HloModule* module, if (frontend_attributes) { module->set_frontend_attributes(frontend_attributes.value()); } + if (!allow_spmd_sharding_propagation_to_parameters.empty()) { + config.set_allow_spmd_sharding_propagation_to_parameters( + allow_spmd_sharding_propagation_to_parameters); + default_config = false; + } if (!allow_spmd_sharding_propagation_to_output.empty()) { config.set_allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output); @@ -1322,14 +1371,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kStringOrJsonDict, &backend_config}; - optional operation_queue_id; - attrs["operation_queue_id"] = {/*required=*/false, AttrTy::kInt64, - &operation_queue_id}; - - optional> wait_on_operation_queues; - attrs["wait_on_operation_queues"] = { - /*required=*/false, AttrTy::kBracedInt64List, &wait_on_operation_queues}; - std::optional maybe_shape; if (parse_shape) { maybe_shape = shape; @@ -1401,12 +1442,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (statistics_viz) { instruction->set_statistics_viz(*statistics_viz); } - if (operation_queue_id) { - instruction->set_operation_queue_id(*operation_queue_id); - } - if (wait_on_operation_queues) { - instruction->set_wait_on_operation_queues(*wait_on_operation_queues); - } return AddInstruction(name, instruction, name_loc); } @@ -1422,7 +1457,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT operands = *preset_operands; } const auto maybe_infer_shape = - [&](absl::FunctionRef()> infer) { + [&](absl::FunctionRef()> infer) { if (shape.has_value()) { return true; } @@ -1518,6 +1553,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kOptimizationBarrier: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -1732,6 +1768,23 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT constrain_layout ? *constrain_layout : false, channel_id, split_dimension)); } + case HloOpcode::kCollectiveBroadcast: { + optional>> tmp_groups; + attrs["replica_groups"] = {/*required=*/true, + AttrTy::kBracedInt64ListList, &tmp_groups}; + optional channel_id; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + if ((!preset_operands && !ParseOperands(&operands, builder)) || + !ParseAttributes(attrs, allow_attributes)) { + return nullptr; + } + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); + } + return builder->AddInstruction(HloInstruction::CreateCollectiveBroadcast( + *shape, operands, replica_groups, false, channel_id)); + } case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: { optional>> source_targets; @@ -1810,82 +1863,146 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return shape.IsTuple() && shape.tuple_shapes_size() >= 2 && shape.tuple_shapes(0).IsTuple(); }; - optional async_group_id; - attrs["async_group_id"] = {/*required=*/false, AttrTy::kInt64, - &async_group_id}; - optional async_execution_thread = - HloInstruction::kMainExecutionThread; + // Verify operand/resulting shapes + if (opcode == HloOpcode::kAsyncUpdate || + opcode == HloOpcode::kAsyncDone) { + if (operands.size() != 1 || + !is_async_shape_correct(operands[0]->shape())) { + TokenError( + "AsyncUpdate and AsyncDone expect a single operand in the form " + "of ((async-operands), async-outputs, state)."); + return nullptr; + } + } + if (opcode == HloOpcode::kAsyncStart || + opcode == HloOpcode::kAsyncUpdate) { + if (!is_async_shape_correct(*shape)) { + TokenError( + "AsyncStart and AsyncUpdate expect the op shape to be in the " + "form of " + "((async-operands), async-outputs, state)."); + return nullptr; + } + } + // async-{update,done} expect their one singular operand to be the + // previous async op. + if (opcode == HloOpcode::kAsyncUpdate || + opcode == HloOpcode::kAsyncDone) { + if (operands.size() != 1 || + !is_async_shape_correct(operands[0]->shape())) { + TokenError( + "AsyncUpdate and AsyncDone expect a single operand in the form " + "of ((async-operands), async-outputs, state)."); + return nullptr; + } + if (!operands[0]->IsAsynchronous()) { + TokenError( + "AsyncUpdate and AsyncDone expect their operand to be the " + "previous async op."); + return nullptr; + } + } + optional async_execution_thread; attrs["async_execution_thread"] = {/*required=*/false, AttrTy::kString, &async_execution_thread}; if (async_wrapped_opcode) { - std::vector async_wrapped_operands; - std::vector async_wrapped_operand_shapes; - Shape async_wrapped_root_shape; + // Only generate async-wrapper for async-start. if (opcode == HloOpcode::kAsyncStart) { + std::vector async_wrapped_operands; + std::vector async_wrapped_operand_shapes; + Shape async_wrapped_root_shape; for (const HloInstruction* operand : operands) { async_wrapped_operand_shapes.push_back(operand->shape()); } - } else { - if (operands.size() != 1 || - !is_async_shape_correct(operands[0]->shape())) { - TokenError( - "AsyncUpdate and AsyncDone expect a single operand in the form " - "of ((async-operands), async-outputs, state)."); + async_wrapped_root_shape = shape->tuple_shapes(1); + HloComputation::Builder async_wrapped_builder("async_wrapped"); + async_wrapped_operands.reserve(async_wrapped_operand_shapes.size()); + for (int i = 0; i < async_wrapped_operand_shapes.size(); ++i) { + async_wrapped_operands.push_back( + async_wrapped_builder.AddInstruction( + HloInstruction::CreateParameter( + i, async_wrapped_operand_shapes.at(i), "async_param"))); + } + HloInstruction* root = + CreateInstruction(&async_wrapped_builder, "async_op", + async_wrapped_root_shape, *async_wrapped_opcode, + /*async_wrapped_opcode=*/std::nullopt, attrs, + allow_attributes, &async_wrapped_operands); + if (!root) { return nullptr; } - async_wrapped_operand_shapes = - operands[0]->shape().tuple_shapes(0).tuple_shapes(); - } - - if (opcode == HloOpcode::kAsyncDone) { - async_wrapped_root_shape = *shape; + computations_.emplace_back(async_wrapped_builder.Build(root)); + async_computation = computations_.back().get(); } else { - if (!is_async_shape_correct(*shape)) { + // Since async-{update,done} will inherit the computation from + // async-start, we'll only need to make sure it matches what was + // specified explicitily. + if (operands[0]->async_wrapped_opcode() != *async_wrapped_opcode) { TokenError( - "AsyncStart and AsyncUpdate expect the op shape to be in the " - "form of " - "((async-operands), async-outputs, state)."); + StrFormat("Expect async wrapped opcode to be %s, but got %s", + HloOpcodeString(operands[0]->async_wrapped_opcode()), + HloOpcodeString(*async_wrapped_opcode))); return nullptr; } - async_wrapped_root_shape = shape->tuple_shapes(1); } - HloComputation::Builder async_wrapped_builder("async_wrapped"); - async_wrapped_operands.reserve(async_wrapped_operand_shapes.size()); - for (int i = 0; i < async_wrapped_operand_shapes.size(); ++i) { - async_wrapped_operands.push_back(async_wrapped_builder.AddInstruction( - HloInstruction::CreateParameter( - i, async_wrapped_operand_shapes.at(i), "async_param"))); + } else { + attrs["calls"] = {/*required=*/opcode == HloOpcode::kAsyncStart, + AttrTy::kHloComputation, &async_computation}; + } + // Attributes would have already been consumed when constructing the + // async wrapped computation for async-start. + if (!(async_wrapped_opcode && opcode == HloOpcode::kAsyncStart)) { + if (!ParseAttributes(attrs, allow_attributes)) { + return nullptr; } - HloInstruction* root = - CreateInstruction(&async_wrapped_builder, "async_op", - async_wrapped_root_shape, *async_wrapped_opcode, - /*async_wrapped_opcode=*/std::nullopt, attrs, - allow_attributes, &async_wrapped_operands); - if (!root) { + } + // Async attributes on async-{update,done} are allowed for backward + // compatibility reasons, but are ignored, since they are inherited + // from the async-start op. Simply check that whatever is explicitly + // specified matches what is inherited. + if (opcode == HloOpcode::kAsyncUpdate || + opcode == HloOpcode::kAsyncDone) { + if (async_execution_thread && + operands[0]->async_execution_thread() != *async_execution_thread) { + TokenError(StrFormat( + "Expect async_execution_thread to be %s, but got %s", + operands[0]->async_execution_thread(), *async_execution_thread)); return nullptr; } - computations_.emplace_back(async_wrapped_builder.Build(root)); - async_computation = computations_.back().get(); - } else { - attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, - &async_computation}; - if (!ParseAttributes(attrs, allow_attributes)) { + if (async_computation && + operands[0]->async_wrapped_computation() != *async_computation) { + TokenError( + StrFormat("Expect async_wrapped_computation to be %s, but got %s", + operands[0]->async_wrapped_computation()->name(), + (*async_computation)->name())); return nullptr; } } + // There should be a 1:1 correspondence between async-start ops and + // async wrapped computations. At this stage, the computation should + // not be referenced by any other async op. + if (opcode == HloOpcode::kAsyncStart && + (*async_computation)->IsAsyncComputation()) { + TokenError(StrFormat( + "Computation %s is already referenced by another async op", + (*async_computation)->name())); + return nullptr; + } if (opcode == HloOpcode::kAsyncStart) { + // async_execution_thread only needs to be populated for async-start, + // as the rest of the async chain will reference the root op. + if (!async_execution_thread) { + async_execution_thread = HloInstruction::kMainExecutionThread; + } return builder->AddInstruction(HloInstruction::CreateAsyncStart( - *shape, operands, *async_computation, async_group_id, - *async_execution_thread)); + *shape, operands, *async_computation, *async_execution_thread)); } if (opcode == HloOpcode::kAsyncUpdate) { - return builder->AddInstruction(HloInstruction::CreateAsyncUpdate( - *shape, operands[0], *async_computation, async_group_id, - *async_execution_thread)); + return builder->AddInstruction( + HloInstruction::CreateAsyncUpdate(*shape, operands[0])); } - return builder->AddInstruction(HloInstruction::CreateAsyncDone( - *shape, operands[0], *async_computation, async_group_id, - *async_execution_thread)); + return builder->AddInstruction( + HloInstruction::CreateAsyncDone(*shape, operands[0])); } case HloOpcode::kCopyStart: { optional cross_program_prefetch_index = std::nullopt; @@ -2055,14 +2172,15 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT !ParseAttributes(attrs, allow_attributes)) { return nullptr; } - if (dynamic_cast(operands[0]) == nullptr) { - return nullptr; - } - if (channel_id != operands[0]->channel_id()) { - return nullptr; + + if (dynamic_cast(operands[0]) != nullptr) { + if (channel_id != operands[0]->channel_id()) { + return nullptr; + } } - return builder->AddInstruction( - HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); + + return builder->AddInstruction(HloInstruction::CreateRecvDone( + operands[0], channel_id.value(), *is_host_transfer)); } case HloOpcode::kSend: { optional channel_id; @@ -2091,14 +2209,15 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT !ParseAttributes(attrs, allow_attributes)) { return nullptr; } - if (dynamic_cast(operands[0]) == nullptr) { - return nullptr; - } - if (channel_id != operands[0]->channel_id()) { - return nullptr; + + if (dynamic_cast(operands[0]) != nullptr) { + if (channel_id != operands[0]->channel_id()) { + return nullptr; + } } - return builder->AddInstruction( - HloInstruction::CreateSendDone(operands[0], *is_host_transfer)); + + return builder->AddInstruction(HloInstruction::CreateSendDone( + operands[0], channel_id.value(), *is_host_transfer)); } case HloOpcode::kGetTupleElement: { optional index; @@ -2963,13 +3082,32 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; + std::vector sparsity; + attrs["sparsity"] = {/*required=*/false, AttrTy::kSparsityDescriptor, + &sparsity}; - if ((!preset_operands && - !ParseOperands(&operands, builder, /*expected_size=*/2)) || + optional algorithm; + attrs["algorithm"] = {/*required=*/false, AttrTy::kPrecisionAlgorithm, + &algorithm}; + + LocTy loc = lexer_.GetLoc(); + if ((!preset_operands && !ParseOperands(&operands, builder)) || !ParseAttributes(attrs, allow_attributes)) { return nullptr; } + int expected_size = HloDotInstruction::kOperands + sparsity.size(); + if (sparsity.size() > HloDotInstruction::kOperands) { + Error(loc, + StrCat("too many sparse dot descriptors: ", sparsity.size())); + return nullptr; + } + if (operands.size() != expected_size) { + Error(loc, StrCat("expects ", expected_size, " operands, but has ", + operands.size(), " operands")); + return nullptr; + } + DotDimensionNumbers dnum; if (lhs_contracting_dims) { *dnum.mutable_lhs_contracting_dimensions() = { @@ -2994,17 +3132,21 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT operand_precision->begin(), operand_precision->end()}; } else { precision_config.mutable_operand_precision()->Resize( - operands.size(), PrecisionConfig::DEFAULT); + HloDotInstruction::kOperands, PrecisionConfig::DEFAULT); + } + if (algorithm) { + precision_config.set_algorithm(*algorithm); } if (!maybe_infer_shape([&] { return ShapeInference::InferDotOpShape( operands[0]->shape(), operands[1]->shape(), dnum, - /*preferred_element_type=*/std::nullopt); + /*preferred_element_type=*/std::nullopt, sparsity); })) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateDot( - *shape, operands[0], operands[1], dnum, precision_config)); + *shape, operands[0], operands[1], dnum, precision_config, sparsity, + absl::MakeSpan(operands).subspan(HloDotInstruction::kOperands))); } case HloOpcode::kGather: { optional> offset_dims; @@ -3164,8 +3306,9 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return builder->AddInstruction(HloInstruction::CreateSetDimensionSize( *shape, operands[0], operands[1], (*dimensions)[0])); } + default: + return nullptr; } - return nullptr; } // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' @@ -4187,14 +4330,26 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { // away. This is a best-effort approach to make sure replaying a HLO // gives us same optimized HLO graph. static uint32_t data = 0; + + // According to the System V ABI not all 8 bit values are valid booleans + // - only the values 0 and 1 are allowed. So to avoid undefined + // behaviour we mask elements of type PRED accordingly. The mask assumes + // that the C++ data type `bool` is represented as a single byte. + static_assert(sizeof(bool) == 1); + constexpr uint32_t kBooleanMask = 0x01010101; + + constexpr uint32_t kNoMask = 0xFFFFFFFF; + const uint32_t mask = + (shape.element_type() == PRED) ? kBooleanMask : kNoMask; + uint32_t* raw_data = static_cast(literal->untyped_data()); for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) { - raw_data[i] = data++; + raw_data[i] = data++ & mask; } uint8_t* raw_data_int8 = static_cast(literal->untyped_data()); static uint8_t data_int8 = 0; for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) { - raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++; + raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++ & mask; } break; } @@ -4763,6 +4918,15 @@ bool HloParserImpl::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kPrecisionAlgorithm: { + PrecisionConfig::Algorithm result; + if (!ParseAlgorithm(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kAliasing: { AliasingData aliasing_data; if (!ParseAliasing(&aliasing_data)) { @@ -4830,6 +4994,15 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kSparsityDescriptor: { + std::vector result; + if (!ParseSparsityDescriptor(&result)) { + return false; + } + *static_cast*>(attr_out_ptr) = + std::move(result); + return true; + } } }(); if (!success) { @@ -5506,7 +5679,7 @@ bool HloParserImpl::ParseDimLevelTypes( // tiles // ::= /*empty*/ -// ::= 'T' '(' dim_list ')' +// ::= 'T' ('(' dim_list ')')+ // dim_list // ::= /*empty*/ // ::= (int64_t | '*') (',' (int64_t | '*'))* @@ -5598,12 +5771,46 @@ bool HloParserImpl::ParseLayoutIntAttribute( return true; } +// split_configs +// ::= /*empty*/ +// ::= 'SC' ('(' int64_t ':' int64_list ')')+ +bool HloParserImpl::ParseSplitConfigs(std::vector& split_configs) { + auto parse_and_add_split_index = [&]() { + int64_t i; + if (ParseInt64(&i)) { + split_configs.back().add_split_indices(i); + return true; + } + return false; + }; + + do { + if (!ParseToken(TokKind::kLparen, + StrCat("expects split configs to start with ", + TokKindToString(TokKind::kLparen)))) { + return false; + } + int64_t dimension; + if (!ParseInt64(&dimension)) { + return false; + } + split_configs.push_back(SplitConfig(dimension, {})); + if (!ParseList(TokKind::kColon, TokKind::kRparen, TokKind::kComma, + parse_and_add_split_index)) { + return false; + } + } while (lexer_.GetKind() == TokKind::kLparen); + return true; +} + // layout // ::= '{' int64_list // (':' dim_level_types // tiles +// tail_padding_alignment_in_elements // element_size_in_bits // memory_space +// split_configs // physical_shape // dynamic_shape_metadata_prefix_bytes)? // '}' @@ -5623,8 +5830,10 @@ bool HloParserImpl::ParseLayout(Layout* layout) { PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID; int64_t element_size_in_bits = 0; int64_t memory_space = 0; + std::vector split_configs; std::optional physical_shape; int64_t dynamic_shape_metadata_prefix_bytes = 0; + int64_t tail_padding_alignment_in_elements = 1; auto parse_and_add_item = [&]() { int64_t i; @@ -5663,6 +5872,12 @@ bool HloParserImpl::ParseLayout(Layout* layout) { ParseTiles(&tiles); } + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "L") { + lexer_.Lex(); + ParseLayoutIntAttribute(&tail_padding_alignment_in_elements, + "multiple padded to in elements"); + } + if (lexer_.GetKind() == TokKind::kOctothorp) { lexer_.Lex(); ParseToken( @@ -5697,6 +5912,11 @@ bool HloParserImpl::ParseLayout(Layout* layout) { ParseLayoutIntAttribute(&memory_space, "memory space"); } + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "SC") { + lexer_.Lex(); + ParseSplitConfigs(split_configs); + } + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "P") { lexer_.Lex(); physical_shape.emplace(); @@ -5720,11 +5940,11 @@ bool HloParserImpl::ParseLayout(Layout* layout) { for (int i = 0; i < tiles.size(); i++) { vec_tiles[i] = Tile(tiles[i]); } - *layout = LayoutUtil::MakeLayout(minor_to_major, dim_level_types, dim_unique, - dim_ordered, vec_tiles, index_primitive_type, - pointer_primitive_type, element_size_in_bits, - memory_space, std::move(physical_shape), - dynamic_shape_metadata_prefix_bytes); + *layout = LayoutUtil::MakeLayout( + minor_to_major, dim_level_types, dim_unique, dim_ordered, vec_tiles, + tail_padding_alignment_in_elements, index_primitive_type, + pointer_primitive_type, element_size_in_bits, memory_space, split_configs, + std::move(physical_shape), dynamic_shape_metadata_prefix_bytes); return true; } @@ -6219,6 +6439,22 @@ bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) { return true; } +bool HloParserImpl::ParseAlgorithm(PrecisionConfig::Algorithm* result) { + VLOG(3) << "ParseAlgorithm"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects algorithm"); + } + std::string val = lexer_.GetStrVal(); + auto status_or_result = StringToAlgorithm(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects algorithm but sees: %s, error: %s", + val, status_or_result.status().message())); + } + *result = status_or_result.value(); + lexer_.Lex(); + return true; +} + bool HloParserImpl::ParseInt64(int64_t* result) { VLOG(3) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { @@ -6346,7 +6582,7 @@ bool HloParserImpl::AddComputation(const std::string& name, return true; } -StatusOr HloParserImpl::ParseShapeOnly() { +absl::StatusOr HloParserImpl::ParseShapeOnly() { lexer_.Lex(); Shape shape; if (!ParseShape(&shape)) { @@ -6358,7 +6594,7 @@ StatusOr HloParserImpl::ParseShapeOnly() { return shape; } -StatusOr HloParserImpl::ParseLayoutOnly() { +absl::StatusOr HloParserImpl::ParseLayoutOnly() { lexer_.Lex(); Layout layout; if (!ParseLayout(&layout)) { @@ -6370,7 +6606,7 @@ StatusOr HloParserImpl::ParseLayoutOnly() { return layout; } -StatusOr HloParserImpl::ParseShardingOnly() { +absl::StatusOr HloParserImpl::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { @@ -6382,7 +6618,8 @@ StatusOr HloParserImpl::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } -StatusOr HloParserImpl::ParseFrontendAttributesOnly() { +absl::StatusOr +HloParserImpl::ParseFrontendAttributesOnly() { lexer_.Lex(); FrontendAttributes attributes; if (!ParseFrontendAttributes(&attributes)) { @@ -6395,7 +6632,7 @@ StatusOr HloParserImpl::ParseFrontendAttributesOnly() { return attributes; } -StatusOr HloParserImpl::ParseStatisticsVizOnly() { +absl::StatusOr HloParserImpl::ParseStatisticsVizOnly() { lexer_.Lex(); StatisticsViz statistics_viz; if (!ParseStatisticsViz(&statistics_viz)) { @@ -6407,7 +6644,8 @@ StatusOr HloParserImpl::ParseStatisticsVizOnly() { return statistics_viz; } -StatusOr> HloParserImpl::ParseParameterReplicationOnly() { +absl::StatusOr> +HloParserImpl::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -6422,7 +6660,7 @@ StatusOr> HloParserImpl::ParseParameterReplicationOnly() { parameter_replication.replicated_at_leaf_buffers().end()); } -StatusOr +absl::StatusOr HloParserImpl::ParseBooleanListOrSingleBooleanOnly() { lexer_.Lex(); BoolList booleans; @@ -6435,7 +6673,8 @@ HloParserImpl::ParseBooleanListOrSingleBooleanOnly() { return booleans; } -StatusOr> HloParserImpl::ParseReplicaGroupsOnly() { +absl::StatusOr> +HloParserImpl::ParseReplicaGroupsOnly() { lexer_.Lex(); std::vector replica_groups; if (!ParseReplicaGroupsOnly(&replica_groups)) { @@ -6447,7 +6686,7 @@ StatusOr> HloParserImpl::ParseReplicaGroupsOnly() { return replica_groups; } -StatusOr HloParserImpl::ParseWindowOnly() { +absl::StatusOr HloParserImpl::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { @@ -6459,7 +6698,7 @@ StatusOr HloParserImpl::ParseWindowOnly() { return window; } -StatusOr +absl::StatusOr HloParserImpl::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; @@ -6473,7 +6712,7 @@ HloParserImpl::ParseConvolutionDimensionNumbersOnly() { return dnums; } -StatusOr HloParserImpl::ParsePaddingConfigOnly() { +absl::StatusOr HloParserImpl::ParsePaddingConfigOnly() { lexer_.Lex(); PaddingConfig padding_config; if (!ParsePaddingConfig(&padding_config)) { @@ -6544,7 +6783,7 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) { } // namespace -StatusOr> ParseAndReturnUnverifiedModule( +absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config) { auto module = std::make_unique(/*name=*/"_", config); HloParserImpl parser(str); @@ -6552,65 +6791,67 @@ StatusOr> ParseAndReturnUnverifiedModule( return std::move(module); } -StatusOr> ParseAndReturnUnverifiedModule( +absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str) { return ParseAndReturnUnverifiedModule(str, HloModuleConfig()); } -StatusOr ParseSharding(absl::string_view str) { +absl::StatusOr ParseSharding(absl::string_view str) { HloParserImpl parser(str); return parser.ParseShardingOnly(); } -StatusOr ParseFrontendAttributes(absl::string_view str) { +absl::StatusOr ParseFrontendAttributes( + absl::string_view str) { HloParserImpl parser(str); return parser.ParseFrontendAttributesOnly(); } -StatusOr ParseStatisticsViz(absl::string_view str) { +absl::StatusOr ParseStatisticsViz(absl::string_view str) { HloParserImpl parser(str); return parser.ParseStatisticsVizOnly(); } -StatusOr> ParseParameterReplication(absl::string_view str) { +absl::StatusOr> ParseParameterReplication( + absl::string_view str) { HloParserImpl parser(str); return parser.ParseParameterReplicationOnly(); } -StatusOr ParseBooleanListOrSingleBoolean( +absl::StatusOr ParseBooleanListOrSingleBoolean( absl::string_view str) { HloParserImpl parser(str); return parser.ParseBooleanListOrSingleBooleanOnly(); } -StatusOr> ParseReplicaGroupsOnly( +absl::StatusOr> ParseReplicaGroupsOnly( absl::string_view str) { HloParserImpl parser(str); return parser.ParseReplicaGroupsOnly(); } -StatusOr ParseWindow(absl::string_view str) { +absl::StatusOr ParseWindow(absl::string_view str) { HloParserImpl parser(str); return parser.ParseWindowOnly(); } -StatusOr ParseConvolutionDimensionNumbers( +absl::StatusOr ParseConvolutionDimensionNumbers( absl::string_view str) { HloParserImpl parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } -StatusOr ParsePaddingConfig(absl::string_view str) { +absl::StatusOr ParsePaddingConfig(absl::string_view str) { HloParserImpl parser(str); return parser.ParsePaddingConfigOnly(); } -StatusOr ParseShape(absl::string_view str) { +absl::StatusOr ParseShape(absl::string_view str) { HloParserImpl parser(str); return parser.ParseShapeOnly(); } -StatusOr ParseLayout(absl::string_view str) { +absl::StatusOr ParseLayout(absl::string_view str) { HloParserImpl parser(str); return parser.ParseLayoutOnly(); } diff --git a/xla/service/hlo_parser.h b/xla/service/hlo_parser.h index e130b94949401..1a29800d6fedc 100644 --- a/xla/service/hlo_parser.h +++ b/xla/service/hlo_parser.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,58 +33,60 @@ namespace xla { // creates a HloModule with the given config. // Note: Tests derived from HloTestBase should use // ParseAndReturnVerifiedModule() instead! -StatusOr> ParseAndReturnUnverifiedModule( +absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config); // Given a string in the HloModule::ToString() format, parses the string and // creates a HloModule with default config. // Note: Tests derived from HloTestBase should use // ParseAndReturnVerifiedModule() instead! -StatusOr> ParseAndReturnUnverifiedModule( +absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str); // Parses sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., // "{replicated}". -StatusOr ParseSharding(absl::string_view str); +absl::StatusOr ParseSharding(absl::string_view str); // Parses frontend attributes from str. str is supposed to contain the body of // the frontend attributes , i.e. just the rhs of the // "frontend_attributes={...}" attribute string, e.g., // "{attr_a=a,attr_b=b}". -StatusOr ParseFrontendAttributes(absl::string_view str); +absl::StatusOr ParseFrontendAttributes( + absl::string_view str); // Parses statistics viz from str. str is supposed to contain the body of the // statistics visualization, i.e. just the rhs of the "statistics={...}" // attribute string, e.g., "{visualizing_index=1,nan_percent=50}". -StatusOr ParseStatisticsViz(absl::string_view str); +absl::StatusOr ParseStatisticsViz(absl::string_view str); // Parses parameter replication from str. str is supposed to contain the body of // the parameter replication, i.e. just the rhs of the // "parameter_replication={...}" attribute string, e.g., "{true, false}". -StatusOr> ParseParameterReplication(absl::string_view str); +absl::StatusOr> ParseParameterReplication( + absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr ParseWindow(absl::string_view str); +absl::StatusOr ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". -StatusOr ParseConvolutionDimensionNumbers( +absl::StatusOr ParseConvolutionDimensionNumbers( absl::string_view str); // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". -StatusOr ParsePaddingConfig(absl::string_view str); +absl::StatusOr ParsePaddingConfig(absl::string_view str); // Parses and returns a Shape::ToString-format string. -StatusOr ParseShape(absl::string_view str); +absl::StatusOr ParseShape(absl::string_view str); // Parses and returns a Layout::ToString-format string. -StatusOr ParseLayout(absl::string_view str); +absl::StatusOr ParseLayout(absl::string_view str); // Parses and returns a std::vector from str. str is supposed to // contain a list of the replica groups, i.e. just the rhs of the // "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". -StatusOr> ParseReplicaGroupsOnly( +absl::StatusOr> ParseReplicaGroupsOnly( absl::string_view str); class HloParser { diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index d3a0b01054f88..7ff7b661208ac 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1333,21 +1333,8 @@ R"(HloModule AsyncOpsWithSyntaxSugar, entry_computation_layout={(f32[10]{0})->f3 ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo" - %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), custom_call_target="foo" - ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), custom_call_target="foo" -} - -)" -}, -{ -"AsyncOpsWithSyntaxSugarAndGroupId", -R"(HloModule AsyncOpsWithSyntaxSugarAndGroupId, entry_computation_layout={(f32[10]{0})->f32[20]{0}} - -ENTRY %Entry (p0: f32[10]) -> f32[20] { - %p0 = f32[10]{0} parameter(0) - %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_group_id=3, custom_call_target="foo" - %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_group_id=3, custom_call_target="foo" - ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_group_id=3, custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) } )" @@ -1360,8 +1347,8 @@ R"(HloModule AsyncOpsWithSyntaxSugarAndThreadName, entry_computation_layout={(f3 ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo" - %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) } )" @@ -1374,8 +1361,8 @@ R"(HloModule HloComputationWithParallelThreadName, entry_computation_layout={(f3 ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo" - %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) }, execution_thread="main_thread" )" @@ -1776,6 +1763,45 @@ ENTRY dot { ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0} } +)" +}, +{ +"DotSparseOperand", +R"(HloModule dot, entry_computation_layout={(f16[32,32]{1,0}, f16[64,32]{1,0}, u16[32,4]{1,0})->f16[32,32]{1,0}} + +ENTRY dot { + a = f16[32,32]{1,0} parameter(0) + b = f16[64,32]{1,0} parameter(1) + meta = u16[32,4]{1,0} parameter(2) + ROOT dot = f16[32,32]{1,0} dot(a, b, meta), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +} + +)" +}, +{ +"DotSparseOperands", +R"(HloModule dot, entry_computation_layout={(f16[32,32]{1,0}, f16[32,32]{1,0}, u16[32,4]{1,0}, u16[4,32]{1,0})->f16[32,32]{1,0}} + +ENTRY dot { + a = f16[32,32]{1,0} parameter(0) + b = f16[32,32]{1,0} parameter(1) + a_meta = u16[32,4]{1,0} parameter(2) + b_meta = u16[4,32]{1,0} parameter(3) + ROOT dot = f16[32,32]{1,0} dot(a, b, a_meta, b_meta), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4_R.0@2:4 +} + +)" +}, +{ +"DotWithAlgorithm", +R"(HloModule dot, entry_computation_layout={(f32[2,10]{1,0}, f32[10,2]{1,0})->f32[2]{0}} + +ENTRY dot { + a = f32[2,10]{1,0} parameter(0) + b = f32[10,2]{1,0} parameter(1) + ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}, algorithm=dot_tf32_tf32_f32 +} + )" }, { @@ -1961,6 +1987,19 @@ ENTRY AllToAllWithSubgroups { ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}} } +)", +/*replica_count=*/4, +}, +// collective-broadcast +{ +"CollectiveBroadcast", +R"(HloModule CollectiveBroadcast, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}, replica_count=4 + +ENTRY CollectiveBroadcast { + input = f32[128,32]{0,1} parameter(0) + ROOT cb = f32[128,32]{0,1} collective-broadcast(input), replica_groups={{1,0},{2,3}} +} + )", /*replica_count=*/4, }, @@ -2558,8 +2597,8 @@ class HloParserTest : public ::testing::Test { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text) { + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text) { auto module = std::make_unique( ::testing::UnitTest::GetInstance()->current_test_info()->name(), HloModuleConfig(), @@ -3708,7 +3747,7 @@ TEST(HloParserSingleOpTest, SingleOp) { TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { const std::string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; - StatusOr> module = + absl::StatusOr> module = ParseAndReturnUnverifiedModule(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); @@ -3718,7 +3757,7 @@ TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const std::string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr> module = + absl::StatusOr> module = ParseAndReturnUnverifiedModule(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); @@ -4151,7 +4190,7 @@ TEST_F(HloParserTest, ParseShapeStringWithElementSizeInBits) { std::string shape_string = "s4[123,456]{1,0:T(2,128)E(4)}"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); Shape expected = ShapeUtil::MakeShapeWithDenseLayout(S4, {123, 456}, {1, 0}, - {Tile({2, 128})}, 4); + {Tile({2, 128})}, 1, 4); EXPECT_EQ(expected, actual) << "expected: " << ShapeUtil::HumanStringWithLayout(expected) << "actual: " << ShapeUtil::HumanStringWithLayout(actual); @@ -4161,8 +4200,8 @@ TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) { // Tile, element size, and memory space. std::string shape_string = "pred[123,456]{1,0:T(2,128)S(3)}"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithDenseLayout(PRED, {123, 456}, {1, 0}, - {Tile({2, 128})}, 0, 3); + Shape expected = ShapeUtil::MakeShapeWithDenseLayout( + PRED, {123, 456}, {1, 0}, {Tile({2, 128})}, 1, 0, 3); EXPECT_EQ(expected, actual) << "expected: " << ShapeUtil::HumanStringWithLayout(expected) << "actual: " << ShapeUtil::HumanStringWithLayout(actual); @@ -4170,8 +4209,8 @@ TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) { // Element size and memory space. shape_string = "pred[123,456]{1,0:S(3)}"; TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); - expected = - ShapeUtil::MakeShapeWithDenseLayout(PRED, {123, 456}, {1, 0}, {}, 0, 3); + expected = ShapeUtil::MakeShapeWithDenseLayout(PRED, {123, 456}, {1, 0}, {}, + 1, 0, 3); EXPECT_EQ(expected, actual) << "expected: " << ShapeUtil::HumanStringWithLayout(expected) << "actual: " << ShapeUtil::HumanStringWithLayout(actual); @@ -4179,8 +4218,8 @@ TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) { // Memory space only. shape_string = "pred[123,456]{1,0:S(3)}"; TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); - expected = - ShapeUtil::MakeShapeWithDenseLayout(PRED, {123, 456}, {1, 0}, {}, 0, 3); + expected = ShapeUtil::MakeShapeWithDenseLayout(PRED, {123, 456}, {1, 0}, {}, + 1, 0, 3); EXPECT_EQ(expected, actual) << "expected: " << ShapeUtil::HumanStringWithLayout(expected) << "actual: " << ShapeUtil::HumanStringWithLayout(actual); @@ -4198,6 +4237,37 @@ TEST_F(HloParserTest, ParseShapeStringWithDynamicShapeMetadataPrefix) { << "actual: " << ShapeUtil::HumanStringWithLayout(actual); } +TEST_F(HloParserTest, ParseShapeStringWithSplitConfigLayout) { + // Tile, memory space, and split config. + std::string shape_string = "pred[123,456]{1,0:T(2,128)S(3)SC(1:200)}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithDenseLayout( + PRED, {123, 456}, {1, 0}, {Tile({2, 128})}, 1, 0, 3, + {SplitConfig(1, {200})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Memory space and split config. + shape_string = "pred[123,456]{1,0:S(3)SC(0:10)(1:4,5)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithDenseLayout( + PRED, {123, 456}, {1, 0}, {}, 1, 0, 3, + {SplitConfig(0, {10}), SplitConfig(1, {4, 5})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Split config only. + shape_string = "pred[123,456]{1,0:SC(1:50,200)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithDenseLayout( + PRED, {123, 456}, {1, 0}, {}, 1, 0, 0, {SplitConfig(1, {50, 200})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); +} + TEST_F(HloParserTest, ParseOpaqueType) { TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); Shape expected = ShapeUtil::MakeOpaqueShape(); @@ -4218,7 +4288,7 @@ TEST_F(HloParserTest, ParseInvalidShapeString) { std::string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}", "f32[123,456]dense{foo}"}; for (const std::string& shape_string : shape_strings) { - StatusOr result = ParseShape(shape_string); + absl::StatusOr result = ParseShape(shape_string); ASSERT_FALSE(result.ok()) << "shape: " << shape_string; } } @@ -4245,7 +4315,7 @@ TEST_F(HloParserTest, ParseDynamicTuple) { TEST_F(HloParserTest, ParseInvalidDimLevel) { constexpr std::string_view shape_string = "f32[123]{0:D(D+~)}"; - StatusOr result = ParseShape(shape_string); + absl::StatusOr result = ParseShape(shape_string); ASSERT_THAT( result.status(), tsl::testing::StatusIs( @@ -4436,6 +4506,21 @@ ENTRY InferDotShape { ShapeUtil::MakeShape(F32, {2}, {0}))); } +TEST_F(HloParserTest, InferSparseDotShape) { + constexpr char text[] = R"(HloModule InferSparseDotShapeTest +ENTRY InferSparseDotShape { + a = f32[2,16]{1,0} parameter(0) + b = f32[32,2]{1,0} parameter(1) + meta = u16[2,2]{1,0} parameter(2) + ROOT dot = dot(a, b, meta), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text)); + EXPECT_TRUE(ShapeUtil::Equal( + module->entry_computation()->ComputeProgramShape().result(), + ShapeUtil::MakeShape(F32, {2}, {0}))); +} + TEST_F(HloParserTest, InferTupleShape) { constexpr char text[] = R"(HloModule InferTupleShapeTest ENTRY InferTupleShape () -> s32[2,3] { @@ -4542,6 +4627,50 @@ ENTRY TestComputation { "attr_value"); } +TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToParameters) { + const char* const hlo_string = R"( +HloModule TestModule, allow_spmd_sharding_propagation_to_parameters=true + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ((*result) + ->config() + .allow_spmd_sharding_propagation_to_parameters() + .size(), + 1); + EXPECT_TRUE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[0]); +} + +TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToParametersVec) { + const char* const hlo_string = R"( +HloModule TestModule, allow_spmd_sharding_propagation_to_parameters={true,false} + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ((*result) + ->config() + .allow_spmd_sharding_propagation_to_parameters() + .size(), + 2); + EXPECT_TRUE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[0]); + EXPECT_FALSE( + (*result)->config().allow_spmd_sharding_propagation_to_parameters()[1]); +} + TEST_F(HloParserTest, CheckAllowSpmdShardingPropagationToOutput) { const char* const hlo_string = R"( HloModule TestModule, allow_spmd_sharding_propagation_to_output=true @@ -4771,8 +4900,8 @@ ENTRY %main { %input.5 = s32[] parameter(1) %broadcast = s32[1024]{0} broadcast(s32[] %input.5), dimensions={} %input.0 = s32[256]{0} parameter(0) - %async-start = ((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) async-start(%broadcast, %input.0, %input.5), async_group_id=0, calls=%async_wrapped - ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) %async-start), async_group_id=0, calls=%async_wrapped + %async-start = ((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) async-start(%broadcast, %input.0, %input.5), calls=%async_wrapped + ROOT %async-done = s32[1024]{0} async-done(((s32[1024]{0}, s32[256]{0}, s32[]), s32[1024]{0}, u32[]) %async-start), calls=%async_wrapped } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -4803,5 +4932,359 @@ TEST_F(HloParserTest, LexesAsJsonDict) { EXPECT_FALSE(LexesAsJsonDict("{{{{}}}")); } +TEST_F(HloParserTest, AsyncStartMissingOperandWrapper) { + const char* const hlo_string = R"( +HloModule Module + +async_computation { + p = f32[2,3] parameter(0) + ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo" +} + +ENTRY AsyncStartMissingOperandWrapper { + p0 = f32[2,3] parameter(0) + async-start = (f32[2,3], f32[3,2], s32[]) async-start(p0), calls=async_computation + async-update = ((f32[2,3]), f32[3,2], s32[]) async-update(async-start), calls=async_computation + ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be " + "in the form of " + "((async-operands), async-outputs, state)."))); +} + +TEST_F(HloParserTest, AsyncUpdateMissingOperandWrapper) { + const char* const hlo_string = R"( +HloModule Module + +async_computation { + p = f32[2,3] parameter(0) + ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo" +} + +ENTRY AsyncUpdateMissingOperandWrapper { + p0 = f32[2,3] parameter(0) + async-start = ((f32[2,3]), f32[3,2], s32[]) async-start(p0), calls=async_computation + async-update = (f32[2,3], f32[3,2], s32[]) async-update(async-start), calls=async_computation + ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be " + "in the form of " + "((async-operands), async-outputs, state)."))); +} + +TEST_F(HloParserTest, AsyncOpTupleWrongType) { + const char* const hlo_string = R"( +HloModule Module + +async_computation { + p = f32[2,3] parameter(0) + ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo" +} + +ENTRY AsyncStartAndAsyncDone { + p0 = f32[2,3] parameter(0) + async-start = ((f32[2,3])) async-start(p0), calls=async_computation + ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("AsyncStart and AsyncUpdate expect the op shape to be " + "in the form of " + "((async-operands), async-outputs, state)."))); +} + +TEST_F(HloParserTest, AsyncDoneNoAsyncStart) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY AsyncStartAndAsyncDone { + p0 = f32[2,3] parameter(0) + p1 = u32[] parameter(1) + tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1) + ROOT async-done = f32[2,3] custom-call-done(tuple) +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("AsyncUpdate and AsyncDone expect their operand to be " + "the previous async op."))); +} + +TEST_F(HloParserTest, AsyncUpdateAndAsyncDoneNoAsyncStart) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY AsyncStartAndAsyncDone { + p0 = f32[2,3] parameter(0) + p1 = u32[] parameter(1) + tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1) + async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(tuple) + ROOT async-done = f32[2,3] custom-call-done(tuple) +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("AsyncUpdate and AsyncDone expect their operand to be " + "the previous async op."))); +} + +TEST_F(HloParserTest, AsyncUpdateWithSyntaxSugarWrongOp) { + const char* const hlo_string = R"( +HloModule AsyncUpdateWithSyntaxSugarWrongOp + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) add-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) +} + )"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async wrapped opcode to be custom-call, " + "but got add"))); +} + +TEST_F(HloParserTest, AsyncDoneWithSyntaxSugarWrongOp) { + const char* const hlo_string = R"( +HloModule AsyncUpdateWithSyntaxSugarWrongOp + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} add-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) +} + )"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async wrapped opcode to be custom-call, " + "but got add"))); +} + +TEST_F(HloParserTest, AsyncOpSharedComputation) { + const char* const hlo_string = R"( +HloModule AsyncOpSharedComputation + +%async_wrapped (async_param: f32[10]) -> f32[20] { + %async_param = f32[10]{0} parameter(0) + ROOT %call = f32[20]{0} custom-call(f32[10]{0} %async_param), custom_call_target="foo" +} + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start.0 = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), calls=%async_wrapped + %async-done.0 = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-start.0) + %async-start.1 = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), calls=%async_wrapped + ROOT %async-done.1 = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-start.1) +} + )"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Computation async_wrapped is already referenced " + "by another async op"))); +} + +TEST_F(HloParserTest, AsyncUpdateWrongComputation) { + const char* const hlo_string = R"( +HloModule AsyncUpdateWrongComputation + +%async_wrapped.0 (async_param: f32[10]) -> f32[20] { + %async_param = f32[10]{0} parameter(0) + ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %async_param), custom_call_target="foo" +} + +%async_wrapped.1 (async_param: f32[10]) -> f32[20] { + %async_param = f32[10]{0} parameter(0) + ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %async_param), custom_call_target="foo" +} + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), calls=%async_wrapped.0 + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), calls=%async_wrapped.1 + ROOT %async-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async_wrapped_computation to be async_wrapped.0, " + "but got async_wrapped.1"))); +} + +TEST_F(HloParserTest, AsyncDoneWrongComputation) { + const char* const hlo_string = R"( +HloModule AsyncDoneWrongComputation + +%async_wrapped.0 (async_param: f32[10]) -> f32[20] { + %async_param = f32[10]{0} parameter(0) + ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %async_param), custom_call_target="foo" +} + +%async_wrapped.1 (async_param: f32[10]) -> f32[20] { + %async_param = f32[10]{0} parameter(0) + ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %async_param), custom_call_target="foo" +} + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), calls=%async_wrapped.0 + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), calls=%async_wrapped.1 +} + )"; + EXPECT_THAT( + ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async_wrapped_computation to be async_wrapped.0, " + "but got async_wrapped.1"))); +} + +TEST_F(HloParserTest, AsyncUpdateWrongDefaultThread) { + const char* const hlo_string = R"( +HloModule AsyncUpdateWrongDefaultThread + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="foo_thread" + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update) +} + )"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async_execution_thread to be main, " + "but got foo_thread"))); +} + +TEST_F(HloParserTest, AsyncDoneWrongDefaultThread) { + const char* const hlo_string = R"( +HloModule AsyncDoneWrongDefaultThread + +ENTRY %Entry (p0: f32[10]) -> f32[20] { + %p0 = f32[10]{0} parameter(0) + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo" + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start) + ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="foo_thread" +} + )"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + HasSubstr("Expect async_execution_thread to be main, " + "but got foo_thread"))); +} + +TEST_F(HloParserTest, PipelinedSendRecv) { + const std::string hlo_string = R"( + HloModule test + cond { + param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(1) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + + recv.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=1 + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + send.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=2 + send-done.0 = (u32[2], token[]) recv-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.0.n = token[] after-all() + recv.0.n = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + + after-all.1.n = token[] after-all() + send.0.n = (u32[2], u32[], token[]) send(recv-data.0, after-all.1.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + ROOT result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) tuple(new_count, recv.0.n, send.0.n) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + init = u32[2] broadcast(c0), dimensions={} + after-all.0.p = token[] after-all() + recv.0.p = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + after-all.1.p = token[] after-all() + send.0.p = (u32[2], u32[], token[]) send(init, after-all.1.p), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + while_init = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) tuple(c0, recv.0.p, send.0.p) + while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) while(while_init), body=body, condition=cond + + recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 + recv-done.0.q = (u32[2], token[]) recv-done(recv.0.q), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=2 + send-done.0.q = token[] send-done(send.0.q), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + ROOT recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 + })"; + auto result = ParseAndReturnUnverifiedModule(hlo_string); + EXPECT_EQ(OkStatus(), result.status()); +} + } // namespace } // namespace xla diff --git a/xla/service/hlo_pass_fix.h b/xla/service/hlo_pass_fix.h index 2b62053b3b877..3835a013b9031 100644 --- a/xla/service/hlo_pass_fix.h +++ b/xla/service/hlo_pass_fix.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,18 +50,19 @@ class HloPassFix : public Pass { } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { RunState run_state(module); TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); return !run_state.changed.empty(); } using HloPassInterface::RunOnModuleGroup; - StatusOr RunOnModuleGroup(HloModuleGroup* module_group, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) + override { bool changed = false; bool changed_this_iteration = true; int64_t iteration_count = 0; diff --git a/xla/service/hlo_pass_interface.h b/xla/service/hlo_pass_interface.h index d3b5c34078b6b..64fee1155c7f1 100644 --- a/xla/service/hlo_pass_interface.h +++ b/xla/service/hlo_pass_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,10 +83,10 @@ class HloPassInterface { // override; // }; // - StatusOr Run(HloModule* module) { + absl::StatusOr Run(HloModule* module) { return Run(module, /*execution_threads=*/{}); } - virtual StatusOr Run( + virtual absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) = 0; @@ -132,10 +132,10 @@ class HloPassInterface { // override; // }; // - StatusOr RunOnModuleGroup(HloModuleGroup* module_group) { + absl::StatusOr RunOnModuleGroup(HloModuleGroup* module_group) { return RunOnModuleGroup(module_group, /*execution_threads=*/{}); } - virtual StatusOr RunOnModuleGroup( + virtual absl::StatusOr RunOnModuleGroup( HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) = 0; @@ -147,9 +147,10 @@ class HloModulePass : public HloPassInterface { public: // Runs the pass on a module group by iterating through each module in the // group. - StatusOr RunOnModuleGroup(HloModuleGroup* module_group, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) + override { bool changed = false; for (HloModule* module : module_group->modules()) { TF_ASSIGN_OR_RETURN(bool module_changed, Run(module, execution_threads)); @@ -171,10 +172,10 @@ class HloModulePass : public HloPassInterface { // on an HLO module. class HloModuleGroupPass : public HloPassInterface { public: - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { - return InternalError("Module group pass cannot be run on a module"); + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { + return Internal("Module group pass cannot be run on a module"); } }; diff --git a/xla/service/hlo_pass_pipeline.cc b/xla/service/hlo_pass_pipeline.cc index b795e65faf4b4..ece25312e9dcf 100644 --- a/xla/service/hlo_pass_pipeline.cc +++ b/xla/service/hlo_pass_pipeline.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla { @@ -98,29 +99,6 @@ void RecordPassEndMetadata(HloModuleGroup& module_group, } } -void SetInstructionMetadata(HloModule& module) { - StatusOr pass_id = module.metadata()->current_pass_id(); - if (!pass_id.ok()) { - LOG(FATAL) << pass_id.status(); - } - for (xla::HloComputation* computation : module.computations()) { - for (xla::HloInstruction* instruction : computation->instructions()) { - if (instruction->metadata().creation_pass_id() == 0) { - instruction->set_creation_pass_id(*pass_id); - } - if (instruction->metadata().logical_creation_pass_id() == 0) { - instruction->set_logical_creation_pass_id(*pass_id); - } - } - } -} - -void SetInstructionMetadata(HloModuleGroup& module_group) { - for (HloModule* module : module_group.modules()) { - SetInstructionMetadata(*module); - } -} - } // namespace template @@ -129,7 +107,7 @@ Status HloPassPipeline::RunInvariantCheckers( const absl::flat_hash_set& execution_threads) { for (auto& invariant_checker : invariant_checkers_) { VLOG(1) << " Invariant checker " << invariant_checker->name(); - StatusOr changed_status = + absl::StatusOr changed_status = RunHelper(invariant_checker.get(), hlo, execution_threads); VLOG(1) << " Invariant checker done " << invariant_checker->name(); if (!changed_status.ok()) { @@ -146,8 +124,20 @@ Status HloPassPipeline::RunInvariantCheckers( return OkStatus(); } +namespace { +std::string UniqueId(const HloModule& mod) { + return std::to_string(mod.unique_id()); +} +std::string UniqueId(const HloModuleGroup& group) { + return absl::StrJoin(group.modules(), "-", + [](std::string* out, const HloModule* mod) { + out->append(std::to_string(mod->unique_id())); + }); +} +} // namespace + template -StatusOr HloPassPipeline::RunPassesInternal( +absl::StatusOr HloPassPipeline::RunPassesInternal( HloT* hlo, const DebugOptions& debug_options, const absl::flat_hash_set& execution_threads) { auto passes = GetEnabledPasses(debug_options); @@ -157,12 +147,15 @@ StatusOr HloPassPipeline::RunPassesInternal( static constexpr absl::string_view kPipelineStart = "pipeline-start"; static constexpr absl::string_view kPipelineEnd = "pipeline-end"; std::string pipeline_name = std::string(name()); + tsl::profiler::ScopedAnnotation annotation{[&] { + return absl::StrFormat("XlaPassPipeline:#name=%s,module=%s,program_id=%s#", + pipeline_name, hlo->name(), UniqueId(*hlo)); + }}; TF_RETURN_IF_ERROR( RunInvariantCheckers(hlo, kPipelineStart, execution_threads)); RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name); - SetInstructionMetadata(*hlo); MaybeDumpHloAndSaveFilenames(*hlo, /*after_pass_name=*/kPipelineStart, /*before_pass_name=*/passes.empty() @@ -174,29 +167,22 @@ StatusOr HloPassPipeline::RunPassesInternal( bool changed = false; for (int i = 0; i < passes.size(); i++) { HloPassInterface* pass = passes[i]; - XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name())); std::string pass_name = std::string(pass->name()); + XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass_name)); + tsl::profiler::ScopedAnnotation annotation{ + [&] { return "XlaPass:" + pass_name; }}; VLOG(1) << " HLO pass " << pass_name; VLOG(2) << " Module hash " << absl::HashOf(*hlo); if (!pass->IsPassPipeline()) { compilation_stats_->StartPass(pass_name); } RecordPassStartMetadata(*hlo, pass_name, pipeline_name); - // Embed RunHelper into lambda to enable recording of error statuses - auto run_helper_lambda = - [this, pass_name]( - HloPassInterface* pass, HloT* hlo, - const absl::flat_hash_set& execution_threads) { - auto status_or = RunHelper(pass, hlo, execution_threads); - if (!status_or.ok()) { - compilation_stats_->RecordPassError( - pass_name, absl::StatusCodeToString(status_or.status().code())); - } - return status_or; - }; - TF_ASSIGN_OR_RETURN(bool pass_changed, - run_helper_lambda(pass, hlo, execution_threads)); - SetInstructionMetadata(*hlo); + auto status_or_changed = RunHelper(pass, hlo, execution_threads); + if (auto status = status_or_changed.status(); !status.ok()) { + compilation_stats_->RecordPassError( + pass_name, absl::StatusCodeToString(status.code())); + } + TF_ASSIGN_OR_RETURN(bool pass_changed, status_or_changed); if (!dump_regex.empty() && (pass_changed || dump_regex != ".*")) { MaybeDumpHloAndSaveFilenames(*hlo, /*after_pass_name=*/pass_name, @@ -207,22 +193,13 @@ StatusOr HloPassPipeline::RunPassesInternal( RecordPassEndMetadata(*hlo, pass_name, pass_changed); changed |= pass_changed; if (pass_changed) { - VLOG(3) << " Pass caused changes " << pass->name(); - // Embed RunInvariantCheckers into lambda to enable recording of errors - auto run_invariant_checkers_lambda = - [this]( - HloT* hlo, absl::string_view pass_name, - const absl::flat_hash_set& execution_threads) { - auto status = - RunInvariantCheckers(hlo, pass_name, execution_threads); - if (!status.ok()) { - compilation_stats_->RecordPassError( - pass_name, absl::StatusCodeToString(status.code())); - } - return status; - }; - TF_RETURN_IF_ERROR( - run_invariant_checkers_lambda(hlo, pass_name, execution_threads)); + VLOG(3) << " Pass caused changes " << pass_name; + auto status = RunInvariantCheckers(hlo, pass_name, execution_threads); + if (!status.ok()) { + compilation_stats_->RecordPassError( + pass_name, absl::StatusCodeToString(status.code())); + } + TF_RETURN_IF_ERROR(status); } if (!pass->IsPassPipeline()) { compilation_stats_->EndPass(pass_name); @@ -307,7 +284,7 @@ void HloPassPipeline::MaybeDumpHloAndSaveFilenames( } } -StatusOr HloPassPipeline::Run( +absl::StatusOr HloPassPipeline::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { run_called_ = true; @@ -319,7 +296,7 @@ StatusOr HloPassPipeline::Run( execution_threads); } -StatusOr HloPassPipeline::RunOnModuleGroup( +absl::StatusOr HloPassPipeline::RunOnModuleGroup( HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) { run_called_ = true; diff --git a/xla/service/hlo_pass_pipeline.h b/xla/service/hlo_pass_pipeline.h index 2fc9ded2afc2f..6ee1fc5e20056 100644 --- a/xla/service/hlo_pass_pipeline.h +++ b/xla/service/hlo_pass_pipeline.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -79,11 +79,11 @@ class HloPassPipeline : public HloPassInterface { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; using HloPassInterface::RunOnModuleGroup; - StatusOr RunOnModuleGroup( + absl::StatusOr RunOnModuleGroup( HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) override; @@ -125,7 +125,7 @@ class HloPassPipeline : public HloPassInterface { // Helper which runs the given pass on the given HLO. HloT can be either // HloModule or HloModuleGroup. template - StatusOr RunPassesInternal( + absl::StatusOr RunPassesInternal( HloT* hlo, const DebugOptions& debug_options, const absl::flat_hash_set& execution_threads); @@ -134,14 +134,14 @@ class HloPassPipeline : public HloPassInterface { // empty thread list means all `execution_threads` are considered. These // helpers enable templating of the core of the pipeline logic by providing // HloModule and HloModuleGroup specific methods with the same name. - static StatusOr RunHelper( + static absl::StatusOr RunHelper( HloPassInterface* pass, HloModule* module, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module, execution_threads)); module->Cleanup(); return changed; } - static StatusOr RunHelper( + static absl::StatusOr RunHelper( HloPassInterface* pass, HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN( diff --git a/xla/service/hlo_pass_pipeline_test.cc b/xla/service/hlo_pass_pipeline_test.cc index 4dc43b0ff610e..252f9ccb7124e 100644 --- a/xla/service/hlo_pass_pipeline_test.cc +++ b/xla/service/hlo_pass_pipeline_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ using ::testing::StrEq; class HloPassPipelineTest : public HloTestBase { protected: - StatusOr ParseModuleGroup( + absl::StatusOr ParseModuleGroup( absl::Span hlo_strings) { HloModuleGroup group(TestName()); for (const std::string& hlo_string : hlo_strings) { @@ -49,9 +49,9 @@ class FooToBarModulePass : public HloModulePass { absl::string_view name() const override { return "foo2bar"; } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { bool changed = false; for (HloComputation* computation : module->computations(execution_threads)) { @@ -72,9 +72,9 @@ class ReverseStringModulePass : public HloModulePass { absl::string_view name() const override { return "reverse"; } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { bool changed = false; for (HloComputation* computation : module->computations(execution_threads)) { @@ -93,9 +93,10 @@ class BazToQuxModuleGroupPass : public HloModuleGroupPass { absl::string_view name() const override { return "baz2qux"; } using HloPassInterface::RunOnModuleGroup; - StatusOr RunOnModuleGroup(HloModuleGroup* module_group, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) + override { bool changed = false; for (HloModule* module : module_group->modules()) { for (HloComputation* computation : @@ -118,14 +119,14 @@ class BarBlowerUpper : public HloModulePass { absl::string_view name() const override { return "bar-blower-upper"; } using HloPassInterface::Run; - StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { for (HloComputation* computation : module->computations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->name() == "bar") { - return InternalError("Module has instruction named bar"); + return Internal("Module has instruction named bar"); } } } diff --git a/xla/service/hlo_phi_graph.cc b/xla/service/hlo_phi_graph.cc index a984d436fd076..d272907ee9b0b 100644 --- a/xla/service/hlo_phi_graph.cc +++ b/xla/service/hlo_phi_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_phi_graph.h b/xla/service/hlo_phi_graph.h index f4c00a8964e17..b8c56454f07ca 100644 --- a/xla/service/hlo_phi_graph.h +++ b/xla/service/hlo_phi_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_phi_graph_test.cc b/xla/service/hlo_phi_graph_test.cc index f4f60512bd8f9..b5184aa4d41ed 100644 --- a/xla/service/hlo_phi_graph_test.cc +++ b/xla/service/hlo_phi_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_profile_printer.cc b/xla/service/hlo_profile_printer.cc index fbe8ec930ced8..f26e430323fb5 100644 --- a/xla/service/hlo_profile_printer.cc +++ b/xla/service/hlo_profile_printer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_profile_printer.h b/xla/service/hlo_profile_printer.h index 10c2ace2c9b85..e15c1ab8c8c32 100644 --- a/xla/service/hlo_profile_printer.h +++ b/xla/service/hlo_profile_printer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_profile_printer_data.proto b/xla/service/hlo_profile_printer_data.proto index f752bc08154db..5231d13d65853 100644 --- a/xla/service/hlo_profile_printer_data.proto +++ b/xla/service/hlo_profile_printer_data.proto @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_proto_util.cc b/xla/service/hlo_proto_util.cc index d03d28145079d..b04fddf8c2fa6 100644 --- a/xla/service/hlo_proto_util.cc +++ b/xla/service/hlo_proto_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } -StatusOr> CreateModuleFromProto( +absl::StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool is_module_post_optimizations) { VLOG(4) << proto.ShortDebugString(); @@ -53,7 +53,7 @@ StatusOr> CreateModuleFromProto( return module; } -StatusOr> EntryComputationParameterShapes( +absl::StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); @@ -70,7 +70,7 @@ StatusOr> EntryComputationParameterShapes( return parameter_shapes; } -StatusOr EntryComputationOutputShape( +absl::StatusOr EntryComputationOutputShape( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); diff --git a/xla/service/hlo_proto_util.h b/xla/service/hlo_proto_util.h index f52fa8d3f0ef5..8bdaf38b0c04d 100644 --- a/xla/service/hlo_proto_util.h +++ b/xla/service/hlo_proto_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,18 +41,18 @@ HloProto MakeHloProto(const HloModule& module); // The HLO module could be a pre-optimizations (default) or post-optimizations // module, which affects how the HLO module is verified, e.g., mixed-precision // is allowed in post-optimizations HLOs. -StatusOr> CreateModuleFromProto( +absl::StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool is_module_post_optimizations = false); // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. -StatusOr> EntryComputationParameterShapes( +absl::StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto); // Returns the shape of the output of the entry computation. The shape pointer // refers to the output shape inside of the given HloProto. -StatusOr EntryComputationOutputShape( +absl::StatusOr EntryComputationOutputShape( const HloProto& hlo_proto); } // namespace xla diff --git a/xla/service/hlo_proto_util_test.cc b/xla/service/hlo_proto_util_test.cc index d1b4922796b8c..d5ef461e58878 100644 --- a/xla/service/hlo_proto_util_test.cc +++ b/xla/service/hlo_proto_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_reachability_test.cc b/xla/service/hlo_reachability_test.cc index 981a635b86c7f..bc0d2b7293b47 100644 --- a/xla/service/hlo_reachability_test.cc +++ b/xla/service/hlo_reachability_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_reachability.h" +#include #include +#include +#include "absl/random/random.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/computation_placer.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test_benchmark.h" namespace xla { @@ -242,4 +246,74 @@ TEST_F(HloReachabilityTest, ReplaceInstructions) { } // namespace +class HloReachabilityMapBitSetBenchmark { + public: + explicit HloReachabilityMapBitSetBenchmark(int size) : a_(size), b_(size) { + // Initialize the bit sets to random inputs. Done out of caution -- note + // that a sufficiently smart optimizer might realize that the bit sets + // are otherwise initialized to 0. + absl::BitGen gen; + for (int i = 0; i < size; ++i) { + if (absl::Bernoulli(gen, 0.5)) a_.Set(i); + if (absl::Bernoulli(gen, 0.5)) b_.Set(i); + } + } + void Union() { a_ |= b_; } + + private: + HloReachabilityMap::BitSet a_; + HloReachabilityMap::BitSet b_; +}; + +namespace { + +void BM_HloReachabilityBitSetUnion(benchmark::State& state) { + HloReachabilityMapBitSetBenchmark bm(state.range(0)); + for (auto s : state) { + bm.Union(); + } +} +#define BM_ARGS Arg(1)->Arg(64)->Arg(128)->Arg(256)->Range(512, 256 * 1024) +BENCHMARK(BM_HloReachabilityBitSetUnion)->BM_ARGS; + +class HloReachabilityBenchmark { + public: + HloReachabilityBenchmark(int size, std::string_view name) : name_(name) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(name); + + // Build a graph of chained Exponentials, i.e. Exp(...(Exp(Input))...). + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction* prev = constant; + for (int i = 1; i < size; ++i) { + prev = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, prev)); + } + + HloModuleConfig hlo_config; + module_ = std::make_unique(name_, hlo_config); + computation_ = + module_->AddEntryComputation(builder.Build(/*root_instruction=*/prev)); + } + std::unique_ptr Build() { + return HloReachabilityMap::Build(computation_); + } + + private: + std::unique_ptr module_; + HloComputation* computation_; + const std::string name_; +}; + +void BM_HloReachabilityBuild(benchmark::State& state) { + HloReachabilityBenchmark bm(state.range(0), state.name()); + for (auto s : state) { + benchmark::DoNotOptimize(bm.Build()); + } +} +BENCHMARK(BM_HloReachabilityBuild)->BM_ARGS; + +} // namespace + } // namespace xla diff --git a/xla/service/hlo_rematerialization.cc b/xla/service/hlo_rematerialization.cc index 70054780ee382..a524e3cca559d 100644 --- a/xla/service/hlo_rematerialization.cc +++ b/xla/service/hlo_rematerialization.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -723,7 +723,7 @@ class MemoryUsageTracker { // Get the compact shape of given hlo instruction. An internal cache is used // to avoid computing the shape multiple times. - StatusOr GetCompactShape(const HloInstruction* hlo); + absl::StatusOr GetCompactShape(const HloInstruction* hlo); // Creates a Buffer representing the given logical buffer. The buffer is added // to buffers_ and a reference is returned. @@ -1506,7 +1506,8 @@ std::string MemoryUsageTracker::ToString() const { return output; } -StatusOr MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) { +absl::StatusOr MemoryUsageTracker::GetCompactShape( + const HloInstruction* hlo) { auto it = compact_shape_.find(hlo); if (it != compact_shape_.end()) { return it->second; @@ -1988,7 +1989,7 @@ UsesList MemoryUsageTracker::GetItemUses(Item* item) const { return combined_users; } -StatusOr RematerializeInstructions( +absl::StatusOr RematerializeInstructions( MemoryUsageTracker* memory_tracker, std::vector* best_items, absl::flat_hash_set* remat_move_instructions, InstructionList* instruction_list, HloSchedule* schedule, @@ -2032,12 +2033,7 @@ StatusOr RematerializeInstructions( } // Add control dependencies to the new operation. - for (auto successor : best->control_successors()) { - TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor)); - } - for (auto predecessor : best->control_predecessors()) { - TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat)); - } + TF_RETURN_IF_ERROR(remat->CopyAllControlDepsFrom(best)); Item* remat_item = instruction_list->CreateItem(remat); @@ -2166,16 +2162,20 @@ StatusOr RematerializeInstructions( VLOG(2) << "The old instruction " << best->name() << " is an async op. Removing to maintain one start to one done " "invariant to keep the HLO valid."; + // We need to remove all control dependencies from best before removing it + // from the computation. Its control dependencies were previously copied + // to the remat instruction. + TF_RETURN_IF_ERROR(best->DropAllControlDeps()); TF_RETURN_IF_ERROR(computation->RemoveInstruction(best)); } } return net_instructions_added; } -StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, - Item* best_item, - const Shape& compact_shape, - InstructionList* instruction_list) { +absl::StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, + Item* best_item, + const Shape& compact_shape, + InstructionList* instruction_list) { HloInstruction* best = best_item->instruction; VLOG(5) << "Transposing instruction " << best->name() << " (saving " << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed( @@ -2225,9 +2225,9 @@ StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, return 2; } -StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, - Item* best_item, - InstructionList* instruction_list) { +absl::StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, + Item* best_item, + InstructionList* instruction_list) { HloInstruction* best_instruction = best_item->instruction; HloComputation* computation = best_instruction->parent(); VLOG(2) << "Best_instruction's users: " @@ -2502,7 +2502,7 @@ struct InstructionsAdded { // Rematerializes the best block of instructions of size between min_block_size // and max_block_size (both inclusive) if at least one candidate block of // instructions can be found. Returns number of instructions rematerialized. -StatusOr RematerializeBestBlock( +absl::StatusOr RematerializeBestBlock( int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker, InstructionList* instruction_list, HloSchedule* schedule, int64_t memory_limit_bytes, @@ -2571,7 +2571,7 @@ StatusOr RematerializeBestBlock( } } // namespace -StatusOr HloRematerialization::ComputePeakMemory( +absl::StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order, const absl::flat_hash_set& execution_threads) const { InstructionList instruction_list(order); @@ -2594,7 +2594,7 @@ StatusOr HloRematerialization::ComputePeakMemory( return peak_memory; } -StatusOr HloRematerialization::CalledComputationsMemoryUsage( +absl::StatusOr HloRematerialization::CalledComputationsMemoryUsage( const HloInstruction* instruction, const absl::flat_hash_set& execution_threads) const { const CallSite* callsite = @@ -2614,7 +2614,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( return callee_usage; } -StatusOr HloRematerialization::RematerializeComputation( +absl::StatusOr HloRematerialization::RematerializeComputation( HloComputation* computation, HloSchedule* schedule, int64_t memory_limit_bytes, int64_t min_remat_size, const absl::flat_hash_set& execution_threads) { @@ -2816,7 +2816,7 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( +absl::StatusOr HloRematerialization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (options_.remat_mode_config.host_offload) { diff --git a/xla/service/hlo_rematerialization.h b/xla/service/hlo_rematerialization.h index e73de37112aa1..4fae1cc2ddbbd 100644 --- a/xla/service/hlo_rematerialization.h +++ b/xla/service/hlo_rematerialization.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,8 @@ class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function; - using CompactShapeFunction = std::function(const Shape&)>; + using CompactShapeFunction = + std::function(const Shape&)>; // Helper struct that communicates the before / after sizes for the // rematerialization process. @@ -166,7 +167,7 @@ class HloRematerialization : public HloModulePass { // specified in the constructor then no instructions are rematerialized and // false is returned. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -175,15 +176,15 @@ class HloRematerialization : public HloModulePass { // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation(HloComputation* computation, - HloSchedule* schedule, - int64_t memory_limit_bytes, - int64_t min_remat_size) { + absl::StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64_t memory_limit_bytes, + int64_t min_remat_size) { return RematerializeComputation(computation, schedule, memory_limit_bytes, min_remat_size, /*execution_threads=*/{}); } - virtual StatusOr RematerializeComputation( + virtual absl::StatusOr RematerializeComputation( HloComputation* computation, HloSchedule* schedule, int64_t memory_limit_bytes, int64_t min_remat_size, const absl::flat_hash_set& execution_threads); @@ -192,13 +193,13 @@ class HloRematerialization : public HloModulePass { // peak memory is the maximum total size of all live HLO instruction values at // any program point. 'order' is the order in which the HLO instructions will // be emitted which is used to determine lifespans of HLO values. - StatusOr ComputePeakMemory( + absl::StatusOr ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order, const absl::flat_hash_set& execution_threads) const; // Returns the peak memory usage of the called computations for the given // instruction. Zero is returned if the instruction calls no computations. - StatusOr CalledComputationsMemoryUsage( + absl::StatusOr CalledComputationsMemoryUsage( const HloInstruction* instruction, const absl::flat_hash_set& execution_threads) const; diff --git a/xla/service/hlo_rematerialization_test.cc b/xla/service/hlo_rematerialization_test.cc index e8c9355fe866d..bf3ef3cb0e2ef 100644 --- a/xla/service/hlo_rematerialization_test.cc +++ b/xla/service/hlo_rematerialization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,9 +48,9 @@ using ::testing::_; class RecomputeAndCompressHloRematerializationTest : public RematerializationTestBase { protected: - StatusOr RunHloRematerialization(int64_t memory_limit_bytes, - HloModule* module, - int64_t min_remat_size = 0) { + absl::StatusOr RunHloRematerialization(int64_t memory_limit_bytes, + HloModule* module, + int64_t min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); if (!module->has_schedule()) { HloMemoryScheduler scheduler( @@ -995,7 +995,7 @@ class CompressingRematerializationTest : public RematerializationTestBase { // Swap the layout of the two most-minor dimensions if the second-minor // dimension is bigger than the most-minor dimension. - static StatusOr ChooseCompactLayoutForShape(const Shape& shape) { + static absl::StatusOr ChooseCompactLayoutForShape(const Shape& shape) { if (shape.rank() != 2) { return shape; } @@ -1014,9 +1014,9 @@ class CompressingRematerializationTest : public RematerializationTestBase { return result; } - StatusOr RunHloRematerialization(int64_t memory_limit_bytes, - HloModule* module, - int64_t min_remat_size = 0) { + absl::StatusOr RunHloRematerialization(int64_t memory_limit_bytes, + HloModule* module, + int64_t min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); HloRematerialization::RematerializationModeConfig config( /*recompute=*/false, /*compress=*/true, /*host_offload=*/false); @@ -1209,9 +1209,9 @@ ENTRY %entry { class OffloadingRematerializationTest : public RematerializationTestBase { protected: - StatusOr RunHloRematerialization(int64_t memory_limit_bytes, - HloModule* module, - int64_t min_remat_size = 0) { + absl::StatusOr RunHloRematerialization(int64_t memory_limit_bytes, + HloModule* module, + int64_t min_remat_size = 0) { TF_EXPECT_OK(verifier().Run(module).status()); if (!module->has_schedule()) { HloMemoryScheduler scheduler( diff --git a/xla/service/hlo_rematerialization_test_utils.h b/xla/service/hlo_rematerialization_test_utils.h index b58fb49c5dbd7..069494536f263 100644 --- a/xla/service/hlo_rematerialization_test_utils.h +++ b/xla/service/hlo_rematerialization_test_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_rematerialization_test_utils_test.cc b/xla/service/hlo_rematerialization_test_utils_test.cc index 5d0d661c81e47..803a0704fde83 100644 --- a/xla/service/hlo_rematerialization_test_utils_test.cc +++ b/xla/service/hlo_rematerialization_test_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_replication_analysis.cc b/xla/service/hlo_replication_analysis.cc index 70b73ed079892..b92b6fe816e3d 100644 --- a/xla/service/hlo_replication_analysis.cc +++ b/xla/service/hlo_replication_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -98,9 +98,13 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( bool replicated_across_replicas = true; const int64_t num_partitions = hlo->GetModule()->config().num_partitions(); + absl::flat_hash_set visited_partitions; + absl::flat_hash_set visited_replicas; for (const auto& group : hlo->replica_groups()) { - absl::flat_hash_set visited_partitions; - absl::flat_hash_set visited_replicas; + visited_partitions.clear(); + visited_replicas.clear(); + visited_replicas.reserve(group.replica_ids().size()); + visited_partitions.reserve(group.replica_ids().size()); for (int64_t id : group.replica_ids()) { int64_t rid = id / num_partitions; int64_t pid = id % num_partitions; @@ -421,13 +425,13 @@ Status HloReplicationAnalysis::ComputeHloReplication() { if (replication) { // If parameter replication status has been set explicitly, use that // instead. - if (!cross_partition_spmd_ && replication->at(leaf_index)) { + if (!cross_partition_spmd_ && (*replication)[leaf_index]) { // Setting parameter replication status for replicas in // non cross-partition spmd mode. *shape_tree.mutable_element(index) = HloReplication::ReplicatedOnAllDevices(); } - if (cross_partition_spmd_ && !replication->at(leaf_index)) { + if (cross_partition_spmd_ && !(*replication)[leaf_index]) { // Setting paramemter replication status for partitions in // cross-partition spmd mode. *shape_tree.mutable_element(index) = @@ -482,14 +486,14 @@ bool HloReplicationAnalysis::HloInstructionIsReplicatedAt( return true; } -/* static */ StatusOr> +/* static */ absl::StatusOr> HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd) { const absl::flat_hash_set empty; return Run(module, cross_partition_spmd, &empty); } -/* static */ StatusOr> +/* static */ absl::StatusOr> HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations) { @@ -500,7 +504,7 @@ HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd, return analysis; } -/* static */ StatusOr> +/* static */ absl::StatusOr> HloReplicationAnalysis::RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd) { const absl::flat_hash_set empty; diff --git a/xla/service/hlo_replication_analysis.h b/xla/service/hlo_replication_analysis.h index 2267a80664979..e6f680214d799 100644 --- a/xla/service/hlo_replication_analysis.h +++ b/xla/service/hlo_replication_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,19 +35,19 @@ namespace xla { class HloReplicationAnalysis { public: // Runs the analysis on module and returns the result or an error. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module, bool cross_partition_spmd); // Same as above, but the caller can provide additional annotations: a set of // while loops that are known to have the same iteration counts across // replicas or partitions. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module, bool cross_partition_spmd, const absl::flat_hash_set* loops_known_with_same_iterations); // Same as above but supports finding partially replicated HLOs. - static StatusOr> + static absl::StatusOr> RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd); // Returns if the HLO instruction outputs the same value (i.e., replicated) at diff --git a/xla/service/hlo_replication_analysis_test.cc b/xla/service/hlo_replication_analysis_test.cc index e5f07d0b586d3..388797c35437e 100644 --- a/xla/service/hlo_replication_analysis_test.cc +++ b/xla/service/hlo_replication_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_runner.cc b/xla/service/hlo_runner.cc index 19aee14693b42..6e4331a9fb61e 100644 --- a/xla/service/hlo_runner.cc +++ b/xla/service/hlo_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ limitations under the License. #include "xla/shape_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -49,7 +50,7 @@ HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) { HloRunner::~HloRunner() {} -StatusOr HloRunner::TransferLiteralToDevice( +absl::StatusOr HloRunner::TransferLiteralToDevice( const Literal& literal, int64_t param_no) { auto shape_representation_fn = [this, param_no](const Shape& shape) { Shape new_shape = device_shape_representation_fn_(shape); @@ -89,8 +90,8 @@ StatusOr HloRunner::TransferLiteralToDevice( return std::move(buffer); } -StatusOr> HloRunner::TransferLiteralsToDevice( - absl::Span literals) { +absl::StatusOr> +HloRunner::TransferLiteralsToDevice(absl::Span literals) { std::vector buffers; buffers.reserve(literals.size()); for (auto i = 0; i < literals.size(); i++) { @@ -103,8 +104,8 @@ StatusOr> HloRunner::TransferLiteralsToDevice( return std::move(buffers); } -StatusOr> HloRunner::TransferLiteralsToDevice( - absl::Span literals) { +absl::StatusOr> +HloRunner::TransferLiteralsToDevice(absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { @@ -113,7 +114,7 @@ StatusOr> HloRunner::TransferLiteralsToDevice( return TransferLiteralsToDevice(literal_pointers); } -StatusOr HloRunner::TransferLiteralFromDevice( +absl::StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -138,10 +139,10 @@ StatusOr HloRunner::TransferLiteralFromDevice( shaped_buffer); } -StatusOr HloRunner::Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes, - ExecutionProfile* profile) { +absl::StatusOr HloRunner::Execute( + std::unique_ptr module, + absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { xla::UpdateEntryComputationLayout(module.get(), device_shape_representation_fn_); entry_computation_layout_ = &(module->entry_computation_layout()); @@ -157,7 +158,7 @@ StatusOr HloRunner::Execute(std::unique_ptr module, return TransferLiteralFromDevice(result.Result()); } -StatusOr HloRunner::ExecuteWithBufferAssignment( +absl::StatusOr HloRunner::ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span arguments, bool run_hlo_passes, @@ -176,7 +177,7 @@ StatusOr HloRunner::ExecuteWithBufferAssignment( return TransferLiteralFromDevice(result.Result()); } -StatusOr HloRunner::ExecuteWithExecutable( +absl::StatusOr HloRunner::ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) { entry_computation_layout_ = @@ -273,7 +274,7 @@ static void ExecutionInputsFromMovedScopedShapedBuffers( } } -StatusOr HloRunner::ExecuteWithDeviceBuffers( +absl::StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -282,7 +283,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( return ExecuteWithDeviceBuffers(executable.get(), arguments, profile); } -StatusOr HloRunner::ExecuteWithDeviceBuffers( +absl::StatusOr HloRunner::ExecuteWithDeviceBuffers( Executable* executable, absl::Span arguments, ExecutionProfile* profile) { std::vector execution_arguments = @@ -294,7 +295,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( profile); } -StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( +absl::StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( std::unique_ptr module, std::vector arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -303,7 +304,7 @@ StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( std::move(arguments), run_hlo_passes, profile); } -StatusOr +absl::StatusOr HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, @@ -317,7 +318,7 @@ HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment( profile); } -StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( +absl::StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( Executable* executable, std::vector arguments, ExecutionProfile* profile) { std::vector execution_arguments; @@ -341,28 +342,28 @@ StatusOr HloRunner::ExecuteWithMovedDeviceBuffers( return retval; } -StatusOr HloRunner::ExecuteWithExecutionInputs( +absl::StatusOr HloRunner::ExecuteWithExecutionInputs( Executable* executable, std::vector arguments, ExecutionProfile* profile) { xla::UpdateEntryComputationLayout(&executable->module(), device_shape_representation_fn_); // Get service run options. - se::Stream stream(backend().default_stream_executor()); - stream.Init(); + TF_ASSIGN_OR_RETURN(auto stream, + backend().default_stream_executor()->CreateStream()); ServiceExecutableRunOptions service_run_options = - GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, - nullptr, RunId()); + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), + stream.get(), nullptr, RunId()); service_run_options.mutable_run_options()->set_execution_profile(profile); TF_ASSIGN_OR_RETURN(ExecutionOutput retval, executable->ExecuteOnStreamWrapper(&service_run_options, std::move(arguments))); - TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); return std::move(retval); } -StatusOr> HloRunner::ExecuteReplicated( +absl::StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { TF_ASSIGN_OR_RETURN( @@ -371,8 +372,8 @@ StatusOr> HloRunner::ExecuteReplicated( return ExecuteReplicated(executable.get(), options, device_assignment); } -StatusOr> HloRunner::ExecuteReplicatedImpl( - std::function>( +absl::StatusOr> HloRunner::ExecuteReplicatedImpl( + std::function>( const std::vector&, const std::vector>&)> execution_helper, @@ -409,8 +410,8 @@ StatusOr> HloRunner::ExecuteReplicatedImpl( (*device_assignment)(i / num_partitions, i % num_partitions); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(std::make_unique(executor)); - streams.back()->Init(); + TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); + streams.emplace_back(std::move(stream)); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), device_assignment, run_id)); @@ -509,14 +510,14 @@ StatusOr> HloRunner::ExecuteReplicatedImpl( return std::move(exec_results); } -StatusOr> HloRunner::ExecuteReplicated( +absl::StatusOr> HloRunner::ExecuteReplicated( Executable* executable, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile) { return ExecuteReplicatedImpl( [&](const std::vector& service_run_options, const std::vector>& argument_buffer_slices) - -> StatusOr> { + -> absl::StatusOr> { std::vector results; if (!options.use_threads) { TF_ASSIGN_OR_RETURN( @@ -558,7 +559,7 @@ StatusOr> HloRunner::ExecuteReplicated( options, device_assignment); } -StatusOr> HloRunner::ExecuteReplicated( +absl::StatusOr> HloRunner::ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -576,7 +577,7 @@ StatusOr> HloRunner::ExecuteReplicated( [&](const std::vector& service_run_options, const std::vector>& argument_buffer_slices) - -> StatusOr> { + -> absl::StatusOr> { TF_RET_CHECK(options.use_threads); std::vector results; absl::Mutex mutex; @@ -613,7 +614,7 @@ StatusOr> HloRunner::ExecuteReplicated( argument_count_provider, argument_provider, options, device_assignment); } -StatusOr> HloRunner::ExecuteReplicated( +absl::StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -622,14 +623,14 @@ StatusOr> HloRunner::ExecuteReplicated( return ExecuteReplicated(std::move(module), options, &device_assignment); } -StatusOr> HloRunner::CreateExecutable( +absl::StatusOr> HloRunner::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { return CreateExecutableWithBufferAssignment( std::move(module), /*buffer_assignment_proto=*/nullptr, run_hlo_passes); } -StatusOr> +absl::StatusOr> HloRunner::CreateExecutableWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, bool run_hlo_passes) { @@ -640,6 +641,11 @@ HloRunner::CreateExecutableWithBufferAssignment( LOG(WARNING) << "Ignoring buffer assignment provided because hlo passes " "are enabled."; } + // Setup intra-op threads in module config + if (backend().eigen_intra_op_thread_pool() != nullptr) { + module->mutable_config().set_intra_op_parallelism_threads( + backend().eigen_intra_op_thread_pool()->NumThreads()); + } auto module_group = std::make_unique(std::move(module)); TF_ASSIGN_OR_RETURN( auto executables, diff --git a/xla/service/hlo_runner.h b/xla/service/hlo_runner.h index eeb3d4bdcc427..d62786fbd329f 100644 --- a/xla/service/hlo_runner.h +++ b/xla/service/hlo_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -58,13 +58,13 @@ class HloRunner : public HloRunnerInterface { ~HloRunner() override; // Transfers data between the host and device. - StatusOr TransferLiteralToDevice(const Literal& literal, - int64_t param_no); - StatusOr> TransferLiteralsToDevice( + absl::StatusOr TransferLiteralToDevice( + const Literal& literal, int64_t param_no); + absl::StatusOr> TransferLiteralsToDevice( absl::Span literals); - StatusOr> TransferLiteralsToDevice( + absl::StatusOr> TransferLiteralsToDevice( absl::Span literals); - StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); + absl::StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. @@ -74,14 +74,14 @@ class HloRunner : public HloRunnerInterface { using HloRunnerInterface::Execute; - StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes, - ExecutionProfile* profile) override; + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) override; using HloRunnerInterface::ExecuteWithBufferAssignment; - StatusOr ExecuteWithBufferAssignment( + absl::StatusOr ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span arguments, bool run_hlo_passes, @@ -89,7 +89,7 @@ class HloRunner : public HloRunnerInterface { using HloRunnerInterface::ExecuteWithExecutable; - StatusOr ExecuteWithExecutable( + absl::StatusOr ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) override; @@ -101,12 +101,12 @@ class HloRunner : public HloRunnerInterface { // // This may overwrite the values of the arguments if the the module has // aliasing. - StatusOr ExecuteWithDeviceBuffers( + absl::StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); - StatusOr ExecuteWithDeviceBuffers( + absl::StatusOr ExecuteWithDeviceBuffers( Executable* executable, absl::Span arguments, ExecutionProfile* profile = nullptr); @@ -115,27 +115,29 @@ class HloRunner : public HloRunnerInterface { // // This is a memory-safer version of ExecuteWithDeviceBuffers, but it consumes // the arguments. - StatusOr ExecuteWithMovedDeviceBuffers( + absl::StatusOr ExecuteWithMovedDeviceBuffers( std::unique_ptr module, std::vector arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); - StatusOr ExecuteWithMovedDeviceBuffersAndBufferAssignment( + absl::StatusOr + ExecuteWithMovedDeviceBuffersAndBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, std::vector arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); - StatusOr ExecuteWithMovedDeviceBuffers( + absl::StatusOr ExecuteWithMovedDeviceBuffers( Executable* executable, std::vector arguments, ExecutionProfile* profile = nullptr); // Creates an executable object given an HLO module. If run_hlo_passes is // true, the HLO passes will be run as part of compilation. - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( std::unique_ptr module, bool run_hlo_passes) override; - StatusOr> CreateExecutableWithBufferAssignment( + absl::StatusOr> + CreateExecutableWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* /*buffer_assignment_proto*/, bool run_hlo_passes) override; @@ -143,12 +145,12 @@ class HloRunner : public HloRunnerInterface { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) override; // Same as above, but with specified device assignment. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) override; @@ -158,7 +160,7 @@ class HloRunner : public HloRunnerInterface { // // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, // since we've already compiled the Executable. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( Executable* executable, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); @@ -167,7 +169,7 @@ class HloRunner : public HloRunnerInterface { // // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, // since we've already compiled the Executable. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -189,7 +191,7 @@ class HloRunner : public HloRunnerInterface { } private: - StatusOr ExecuteWithExecutionInputs( + absl::StatusOr ExecuteWithExecutionInputs( Executable* executable, std::vector arguments, ExecutionProfile* profile); @@ -202,8 +204,8 @@ class HloRunner : public HloRunnerInterface { RunId run_id); // Common implementation code for ExecuteReplicated() above. - StatusOr> ExecuteReplicatedImpl( - std::function>( + absl::StatusOr> ExecuteReplicatedImpl( + std::function>( const std::vector&, const std::vector>&)> execution_helper, diff --git a/xla/service/hlo_runner_interface.cc b/xla/service/hlo_runner_interface.cc index 171a4f7105483..d5507bfdef898 100644 --- a/xla/service/hlo_runner_interface.cc +++ b/xla/service/hlo_runner_interface.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ limitations under the License. namespace xla { -/*static*/ StatusOr> +/*static*/ absl::StatusOr> HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; @@ -30,7 +30,7 @@ HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string, namespace { // Creates an HloModule from the given proto. -StatusOr> HloProtoToModule( +absl::StatusOr> HloProtoToModule( const HloProto& proto, const DebugOptions& debug_options) { TF_ASSIGN_OR_RETURN(HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module(), @@ -51,7 +51,7 @@ std::vector MakePointerVector(absl::Span input_vec) { } // namespace -/*static*/ StatusOr> +/*static*/ absl::StatusOr> HloRunnerInterface::ReadModuleFromBinaryProtoFile( const std::string& filename, const DebugOptions& debug_options) { HloProto proto; @@ -60,7 +60,7 @@ HloRunnerInterface::ReadModuleFromBinaryProtoFile( return HloProtoToModule(proto, debug_options); } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> HloRunnerInterface::ReadModuleFromTextProtoFile( const std::string& filename, const DebugOptions& debug_options) { HloProto proto; @@ -68,7 +68,7 @@ HloRunnerInterface::ReadModuleFromTextProtoFile( return HloProtoToModule(proto, debug_options); } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> HloRunnerInterface::ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options) { std::string hlo_string; @@ -79,7 +79,7 @@ HloRunnerInterface::ReadModuleFromHloTextFile( return ParseAndReturnUnverifiedModule(hlo_string, config); } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> HloRunnerInterface::ReadModuleFromModuleBinaryProtofile( const std::string& filename, const DebugOptions& debug_options) { HloModuleProto module_proto; @@ -93,7 +93,7 @@ HloRunnerInterface::ReadModuleFromModuleBinaryProtofile( return HloModule::CreateFromProto(module_proto, module_config); } -StatusOr HloRunnerInterface::Execute( +absl::StatusOr HloRunnerInterface::Execute( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. @@ -105,7 +105,7 @@ StatusOr HloRunnerInterface::Execute( /*profile=*/profile); } -StatusOr HloRunnerInterface::ExecuteWithBufferAssignment( +absl::StatusOr HloRunnerInterface::ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span arguments, bool run_hlo_passes, @@ -120,7 +120,7 @@ StatusOr HloRunnerInterface::ExecuteWithBufferAssignment( /*profile=*/profile); } -StatusOr HloRunnerInterface::ExecuteWithExecutable( +absl::StatusOr HloRunnerInterface::ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. diff --git a/xla/service/hlo_runner_interface.h b/xla/service/hlo_runner_interface.h index 0ba562a373df1..9727b9fc1d56a 100644 --- a/xla/service/hlo_runner_interface.h +++ b/xla/service/hlo_runner_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -90,37 +90,38 @@ class HloRunnerInterface { // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). - static StatusOr> CreateModuleFromString( + static absl::StatusOr> CreateModuleFromString( const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. - static StatusOr> ReadModuleFromBinaryProtoFile( - const std::string& filename, const DebugOptions& debug_options); - static StatusOr> ReadModuleFromTextProtoFile( + static absl::StatusOr> + ReadModuleFromBinaryProtoFile(const std::string& filename, + const DebugOptions& debug_options); + static absl::StatusOr> ReadModuleFromTextProtoFile( const std::string& filename, const DebugOptions& debug_options); // Reads the proto file in xla.HloModule format, creates and returns the // HloModule. - static StatusOr> + static absl::StatusOr> ReadModuleFromModuleBinaryProtofile(const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. - static StatusOr> ReadModuleFromHloTextFile( + static absl::StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); // Creates an executable object given an HLO module. If run_hlo_passes is // true, the HLO passes will be run as part of compilation. - virtual StatusOr> CreateExecutable( + virtual absl::StatusOr> CreateExecutable( std::unique_ptr module, bool run_hlo_passes) = 0; // Same as above, except it takes buffer assignment as input. // Note: The default implementation of the API here does not utilize the given // buffer assignment. A derived runner interface is expected to override the // following method to achieve this functionality. - virtual StatusOr> + virtual absl::StatusOr> CreateExecutableWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* /*buffer_assignment_proto*/, @@ -134,24 +135,24 @@ class HloRunnerInterface { // // If run_hlo_passes is false, the module will be executed without Hlo // optimization - StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes = true) { + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes = true) { return Execute(std::move(module), arguments, run_hlo_passes, nullptr); } - StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes = true, - ExecutionProfile* profile = nullptr); + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - virtual StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes, - ExecutionProfile* profile) = 0; + virtual absl::StatusOr Execute( + std::unique_ptr module, + absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) = 0; // Same as above 3 methods, but with buffer assignment specified. - StatusOr ExecuteWithBufferAssignment( + absl::StatusOr ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span arguments, bool run_hlo_passes = true) { @@ -160,7 +161,7 @@ class HloRunnerInterface { run_hlo_passes, nullptr); } - StatusOr ExecuteWithBufferAssignment( + absl::StatusOr ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span arguments, bool run_hlo_passes = true, @@ -169,7 +170,7 @@ class HloRunnerInterface { // Note: The default implementation of the API here does not utilize the given // buffer assignment. A derived runner interface is expected to override the // following method to achieve this functionality. - virtual StatusOr ExecuteWithBufferAssignment( + virtual absl::StatusOr ExecuteWithBufferAssignment( std::unique_ptr module, const BufferAssignmentProto* /*buffer_assignment_proto*/, absl::Span arguments, bool run_hlo_passes, @@ -179,16 +180,16 @@ class HloRunnerInterface { } // Same as 3 Execute methods above, but with Executable as input. - StatusOr ExecuteWithExecutable(Executable* executable, - absl::Span arguments, - ExecutionProfile* profile = nullptr); + absl::StatusOr ExecuteWithExecutable( + Executable* executable, absl::Span arguments, + ExecutionProfile* profile = nullptr); - StatusOr ExecuteWithExecutable( + absl::StatusOr ExecuteWithExecutable( Executable* executable, absl::Span arguments) { return ExecuteWithExecutable(executable, arguments, nullptr); } - virtual StatusOr ExecuteWithExecutable( + virtual absl::StatusOr ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) = 0; @@ -196,17 +197,17 @@ class HloRunnerInterface { // with the replica number as key, and the corresponding returned literal as // value. // TODO(b/172931928): change to non-virtual function. - virtual StatusOr> ExecuteReplicated( + virtual absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) = 0; // Same as above, but with specified device assignment. - virtual StatusOr> ExecuteReplicated( + virtual absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) = 0; - virtual StatusOr> ExecuteReplicated( + virtual absl::StatusOr> ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index 2f47c10cb503b..fbf1dcd539153 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ class PjRtWrappedExecutable : public Executable { : Executable(hlo_module), pjrt_loaded_executable_(pjrt_loaded_executable) {} - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; @@ -55,7 +55,7 @@ class PjRtWrappedExecutable : public Executable { PjRtLoadedExecutable* pjrt_loaded_executable_; }; -StatusOr PjRtWrappedExecutable::ExecuteAsyncOnStream( +absl::StatusOr PjRtWrappedExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { @@ -73,7 +73,7 @@ HloRunnerPjRt::HloRunnerPjRt( HloRunnerPjRt::~HloRunnerPjRt() = default; -StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( +absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes) { TF_ASSIGN_OR_RETURN( auto device_assignment, @@ -105,15 +105,16 @@ StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( return compile_options; } -StatusOr HloRunnerPjRt::TransferLiteralFromDevice(PjRtBuffer& buffer) { +absl::StatusOr HloRunnerPjRt::TransferLiteralFromDevice( + PjRtBuffer& buffer) { TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); TF_ASSIGN_OR_RETURN(auto literal, buffer.ToLiteralSync()); return std::move(*literal); } -StatusOr> HloRunnerPjRt::TransferLiteralToDevice( - const Literal& literal) { +absl::StatusOr> +HloRunnerPjRt::TransferLiteralToDevice(const Literal& literal) { auto devices = pjrt_client_->addressable_devices(); TF_ASSIGN_OR_RETURN(auto assignment, pjrt_client_->BufferFromHostLiteral( @@ -122,7 +123,7 @@ StatusOr> HloRunnerPjRt::TransferLiteralToDevice( return std::move(assignment); } -StatusOr>> +absl::StatusOr>> HloRunnerPjRt::TransferLiteralsToDevice( absl::Span literals) { std::vector> buffers; @@ -137,7 +138,7 @@ HloRunnerPjRt::TransferLiteralsToDevice( return std::move(buffers); } -StatusOr HloRunnerPjRt::Execute( +absl::StatusOr HloRunnerPjRt::Execute( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -174,14 +175,15 @@ std::vector> HloRunnerPjRt::BufferMatToPointerMat( return argument_ptrs; } -StatusOr> HloRunnerPjRt::CreateExecutable( - HloModule* module, CompileOptions compile_options) { +absl::StatusOr> +HloRunnerPjRt::CreateExecutable(HloModule* module, + CompileOptions compile_options) { XlaComputation computation(module->ToProto()); return pjrt_client_->Compile(computation, compile_options); } -StatusOr>> +absl::StatusOr>> HloRunnerPjRt::ExecuteWithDeviceBuffers( PjRtLoadedExecutable* executable, const std::vector>& arguments) { @@ -201,7 +203,7 @@ HloRunnerPjRt::ExecuteWithDeviceBuffers( return output_buffers; } -StatusOr HloRunnerPjRt::ExecuteWithExecutable( +absl::StatusOr HloRunnerPjRt::ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) { PjRtWrappedExecutable* wrapped_executable = @@ -220,7 +222,7 @@ StatusOr HloRunnerPjRt::ExecuteWithExecutable( return TransferLiteralFromDevice(*output_buffer[0]); } -StatusOr> HloRunnerPjRt::CreateExecutable( +absl::StatusOr> HloRunnerPjRt::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { TF_ASSIGN_OR_RETURN(auto compile_options, GenerateDefaultCompileOptions( module.get(), run_hlo_passes)); @@ -236,7 +238,7 @@ StatusOr> HloRunnerPjRt::CreateExecutable( return exec; } -StatusOr> HloRunnerPjRt::ExecuteReplicated( +absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( std::unique_ptr module, const HloRunnerInterface::ReplicatedExecuteOptions& options) { xla::UpdateEntryComputationLayout(module.get(), @@ -249,7 +251,7 @@ StatusOr> HloRunnerPjRt::ExecuteReplicated( return ExecuteReplicated(std::move(module), options, &device_assignment); } -StatusOr> HloRunnerPjRt::ExecuteReplicated( +absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( std::unique_ptr module, const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { @@ -262,13 +264,13 @@ StatusOr> HloRunnerPjRt::ExecuteReplicated( return ExecuteReplicated(executable.get(), options, device_assignment); } -StatusOr> HloRunnerPjRt::ExecuteReplicated( +absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( Executable* executable, const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile) { return ExecuteReplicatedImpl( [&](absl::Span>& argument_buffer_slices) - -> StatusOr>> { + -> absl::StatusOr>> { PjRtWrappedExecutable* wrapped_executable = static_cast(executable); @@ -292,7 +294,7 @@ StatusOr> HloRunnerPjRt::ExecuteReplicated( options, device_assignment); } -StatusOr> HloRunnerPjRt::ExecuteReplicated( +absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -301,8 +303,8 @@ StatusOr> HloRunnerPjRt::ExecuteReplicated( return Unimplemented("Unimplemeneted ExecuteReplicated"); } -StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( - std::function>>( +absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( + std::function>>( absl::Span>&)> execution_helper, std::function argument_count_provider, diff --git a/xla/service/hlo_runner_pjrt.h b/xla/service/hlo_runner_pjrt.h index 0f6c47b0ebda0..baa71fb33b44a 100644 --- a/xla/service/hlo_runner_pjrt.h +++ b/xla/service/hlo_runner_pjrt.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,56 +39,57 @@ class HloRunnerPjRt : public HloRunnerInterface { ~HloRunnerPjRt() override; // Transfers data between the host and device. - StatusOr> TransferLiteralToDevice( + absl::StatusOr> TransferLiteralToDevice( const Literal& literal); - StatusOr>> TransferLiteralsToDevice( - absl::Span literals); - StatusOr TransferLiteralFromDevice(PjRtBuffer& buffer); + absl::StatusOr>> + TransferLiteralsToDevice(absl::Span literals); + absl::StatusOr TransferLiteralFromDevice(PjRtBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. - StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes, - ExecutionProfile* profile) override; + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) override; // As Execute(), but accepts and returns device buffers instead of host // buffers. - StatusOr>> ExecuteWithDeviceBuffers( + absl::StatusOr>> + ExecuteWithDeviceBuffers( PjRtLoadedExecutable* executable, const std::vector>& arguments); // Creates an executable object for an HloModule. - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( HloModule* module, CompileOptions compile_options); // Creates an executable object given an HLO module. If run_hlo_passes is // true, the HLO passes will be run as part of compilation. - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( std::unique_ptr module, bool run_hlo_passes) override; - StatusOr ExecuteWithExecutable( + absl::StatusOr ExecuteWithExecutable( Executable* executable, absl::Span arguments, ExecutionProfile* profile) override; - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) override; // Same as above, but with specified device assignment. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) override; - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) override; - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( Executable* executable, const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); @@ -105,11 +106,11 @@ class HloRunnerPjRt : public HloRunnerInterface { std::vector> BufferMatToPointerMat( std::vector>>& buffer); - StatusOr GenerateDefaultCompileOptions(HloModule* module, - bool run_hlo_passes); + absl::StatusOr GenerateDefaultCompileOptions( + HloModule* module, bool run_hlo_passes); - StatusOr> ExecuteReplicatedImpl( - std::function>>( + absl::StatusOr> ExecuteReplicatedImpl( + std::function>>( absl::Span>&)> execution_helper, std::function argument_count_provider, diff --git a/xla/service/hlo_schedule_test.cc b/xla/service/hlo_schedule_test.cc index 99112c105620d..4ba1a982def9e 100644 --- a/xla/service/hlo_schedule_test.cc +++ b/xla/service/hlo_schedule_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_sharding_test.cc b/xla/service/hlo_sharding_test.cc index b2c2535297709..7a9a5d6ece4ab 100644 --- a/xla/service/hlo_sharding_test.cc +++ b/xla/service/hlo_sharding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/hlo_value.cc b/xla/service/hlo_value.cc index 9969d94ab16e9..a74e43b5996ec 100644 --- a/xla/service/hlo_value.cc +++ b/xla/service/hlo_value.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,27 +16,27 @@ limitations under the License. #include "xla/service/hlo_value.h" #include -#include +#include +#include +#include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/map_util.h" +#include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status.h" -#include "xla/types.h" #include "xla/util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { @@ -95,12 +95,16 @@ std::string HloValue::ToString(int indent) const { for (const HloPosition& position : positions()) { StrAppend(&out, indentation, " ", position.ToString(), "\n"); } - StrAppend(&out, indentation, " uses:\n"); - for (const HloUse& use : GetUses()) { - StrAppend(&out, indentation, " ", use.ToString(), "\n"); + if (uses_.has_value()) { + StrAppend(&out, indentation, " uses:\n"); + for (const HloUse& use : GetUses()) { + StrAppend(&out, indentation, " ", use.ToString(), "\n"); + } + } else { + StrAppend(&out, indentation, " uses are not initialized yet.\n"); } - StrAppend(&out, indentation, " from instruction:", instruction()->ToString(), - "\n"); + StrAppend(&out, indentation, + " from instruction: ", instruction()->ToString()); return out; } @@ -110,28 +114,23 @@ namespace { // ShapeIndex in the given operand. Generally, instruction which pass through // values transparently without reading the value are not considered to use the // value. -bool MayUseOperandValue(int64_t operand_number, const ShapeIndex& index, - const HloInstruction* user) { +bool MayUseOperandValue(const ShapeIndex& index, const HloInstruction* user) { switch (user->opcode()) { case HloOpcode::kGetTupleElement: case HloOpcode::kCopy: // These instructions only access the top-level values of their // operand. Non-top-level (nested) values are passed through // transparently. - CHECK_EQ(operand_number, 0); return index.empty(); case HloOpcode::kDomain: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; - case HloOpcode::kCall: - case HloOpcode::kWhile: - // Although call and while instructions pass through their operands, they - // are considered uses. - return true; - default: + // Although call (HloOpcode::kCall) and while (HloOpcode::kWhile) + // instructions pass through their operands as are all other opcode types, + // they are considered uses. return true; } } @@ -171,6 +170,19 @@ HloValue::Uses HloValue::ComputeUses() const { // Build vector of HloUses for the value. for (const HloPosition& position : positions_) { for (HloInstruction* const user : position.instruction->users()) { +#ifndef NDEBUG + // If user is in the root positions of this value, it must be a root. + if (root_positions.contains(user)) { + CHECK(user->IsRoot()); + } +#endif // NDEBUG + // Root instructions of computations are considered to be uses whether + // or not the root instruction itself actually uses the value. + if (!MayUseOperandValue(position.index, user) && + !(user->IsRoot() && root_positions.contains(user))) { + continue; + } + int i = -1; for (const auto& operand : user->operands()) { ++i; @@ -179,21 +191,19 @@ HloValue::Uses HloValue::ComputeUses() const { continue; } - // Root instructions of computations are considered to be uses whether - // or not the root instruction itself actually uses the value. - if (MayUseOperandValue(i, position.index, user) || - root_positions.contains(user)) { - HloUse new_use{user, i, position.index}; - + uses.emplace_back(user, i, position.index); #ifndef NDEBUG - // The new use must not already exist in uses. - for (const HloUse& use : uses) { - DCHECK_NE(use, new_use); - } -#endif // NDEBUG - - uses.push_back(std::move(new_use)); + // The new use must not already exist in uses. + for (int index = 0; index + 1 < uses.size(); ++index) { + DCHECK_NE(uses[index], uses.back()); } +#endif // NDEBUG + } + // In case of HloOpcode::kGetTupleElement or HloOpcode::kCopy instruction, + // ensure that user has at most one operand. + if (user->opcode() == HloOpcode::kGetTupleElement || + user->opcode() == HloOpcode::kCopy) { + CHECK_LE(i, 0); } } } diff --git a/xla/service/hlo_value.h b/xla/service/hlo_value.h index a3b0638a3e141..3f4f4699cde54 100644 --- a/xla/service/hlo_value.h +++ b/xla/service/hlo_value.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,20 +18,22 @@ limitations under the License. #include +#include +#include #include #include #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/lazy.h" #include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -79,6 +81,15 @@ struct HloUse { // The shape index within the operand in which the value appears. ShapeIndex operand_index; + HloUse() = default; + HloUse(HloInstruction* instruction, int64_t operand_number) + : instruction(instruction), operand_number(operand_number) {} + HloUse(HloInstruction* instruction, int64_t operand_number, + ShapeIndex operand_index) + : instruction(instruction), + operand_number(operand_number), + operand_index(std::move(operand_index)) {} + std::string ToString() const; bool operator==(const HloUse& other) const { @@ -164,6 +175,8 @@ class HloValue : public BufferValue { // Return a single-line string representation of the value. std::string ToShortString() const; + // The returned string doesn't include `uses` if the ToString is called before + // `GetUses` is called. std::string ToString(int indent) const; std::string ToString() const override { return ToString(0); } @@ -250,7 +263,7 @@ class HloValueSet { std::vector values_; }; -std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); +std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set); // A class collecting the HloValues which might be contained in the output of // an HLO instruction. For array-shaped instructions, an InstructionValueSet diff --git a/xla/service/hlo_value_semantics_analysis.cc b/xla/service/hlo_value_semantics_analysis.cc index 2e7f98cf5f357..53bb988b83360 100644 --- a/xla/service/hlo_value_semantics_analysis.cc +++ b/xla/service/hlo_value_semantics_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,7 @@ limitations under the License. #include "xla/service/hlo_value_semantics_analysis.h" #include +#include #include #include #include @@ -27,10 +28,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -43,6 +46,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/side_effect_util.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/util.h" @@ -51,6 +55,48 @@ limitations under the License. namespace xla { +SendRecvGroupMap::SendRecvGroupMap(const HloModule& hlo_module) { + for (HloComputation* computation : hlo_module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kSend && + instruction->opcode() != HloOpcode::kRecv) { + continue; + } + std::string rendezvous = instruction->frontend_attributes().map().at( + kXlaHostTransferRendezvousNameAttr); + auto send_recv_iter = host_transfer_rendezvous_map_.find(rendezvous); + if (send_recv_iter == host_transfer_rendezvous_map_.end()) { + auto insert_success = host_transfer_rendezvous_map_.insert( + {rendezvous, SendRecvGroup{nullptr, nullptr}}); + send_recv_iter = insert_success.first; + } + if (instruction->opcode() == HloOpcode::kSend) { + send_recv_iter->second.send = instruction; + } else { + send_recv_iter->second.recv = instruction; + } + } + } +} + +absl::StatusOr SendRecvGroupMap::GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const { + if (send_or_recv->opcode() != HloOpcode::kSend && + send_or_recv->opcode() != HloOpcode::kRecv) { + return InvalidArgument("Expecting only send or recv"); + } + std::string rendezvous = send_or_recv->frontend_attributes().map().at( + kXlaHostTransferRendezvousNameAttr); + auto send_recv_iter = host_transfer_rendezvous_map_.find(rendezvous); + if (send_recv_iter == host_transfer_rendezvous_map_.end()) { + return Internal("Missing send or recv from send recv group."); + } + if (send_or_recv->opcode() == HloOpcode::kSend) { + return send_recv_iter->second.recv; + } + return send_recv_iter->second.send; +} + bool HloPreOrderDFS::IsReady(const HloInstruction* instruction) const { for (HloInstruction* user : instruction->users()) { if (!visited_.contains(user)) { @@ -75,8 +121,8 @@ std::vector GetAllInstructionsWithZeroUsers( } // namespace -Status HloPreOrderDFS::Run(const HloComputation& computation, - DfsHloVisitorBase* visitor) { +absl::Status HloPreOrderDFS::Run(const HloComputation& computation, + DfsHloVisitorBase* visitor) { stack_.clear(); visited_.clear(); std::vector roots = @@ -103,7 +149,32 @@ Status HloPreOrderDFS::Run(const HloComputation& computation, return OkStatus(); } -Status EinsumDepthAnalysis::RunInternal( +namespace { + +template +std::string ToString(T element) { + return absl::StrCat(element); +} + +template <> +std::string ToString(const HloValueSemantics* element) { + return element->ToString(); +} + +template +std::string ToString(const ShapeTree& tree) { + std::string str; + tree.ForEachElement([&str, &tree](const ShapeIndex& shape_index, T element) { + auto subshape = ShapeUtil::GetSubshape(tree.shape(), (shape_index)); + absl::StrAppend(&str, shape_index.ToString(), ", ", subshape.ToString(), + ": ", ToString(element), "\n"); + }); + return str; +} + +} // namespace + +absl::Status EinsumDepthAnalysis::RunInternal( const HloComputation& computation, const std::optional>& root_depth) { std::vector roots = @@ -123,9 +194,11 @@ Status EinsumDepthAnalysis::RunInternal( return dfs.Run(computation, this); } -StatusOr> EinsumDepthAnalysis::Run( - const HloComputation& computation) { - EinsumDepthAnalysis* analysis_ptr = new EinsumDepthAnalysis(); +absl::StatusOr> EinsumDepthAnalysis::Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map) { + EinsumDepthAnalysis* analysis_ptr = + new EinsumDepthAnalysis(send_recv_group_map); std::unique_ptr analysis(analysis_ptr); TF_RETURN_IF_ERROR(analysis->RunInternal(computation, std::nullopt)); return analysis; @@ -187,21 +260,22 @@ int GetMaxDepth(const ShapeTree& depth_tree) { void SetDepthFromTupleDepth(ShapeTree& depth_tree, const ShapeTree& tuple_depth_tree, int tuple_index) { - depth_tree.ForEachMutableElement([&depth_tree, &tuple_depth_tree, - tuple_index](const ShapeIndex& shape_index, - int* depth_ptr) { - if (depth_tree.IsLeaf(shape_index)) { - ShapeIndex output_index = shape_index; - output_index.push_front(tuple_index); - *depth_ptr = std::max(*depth_ptr, tuple_depth_tree.element(output_index)); - } - }); + depth_tree.ForEachMutableElement( + [&depth_tree, &tuple_depth_tree, tuple_index]( + const ShapeIndex& shape_index, int* depth_ptr) { + if (depth_tree.IsLeaf(shape_index)) { + ShapeIndex output_index = shape_index; + output_index.push_front(tuple_index); + *depth_ptr = + MergeDepth(*depth_ptr, tuple_depth_tree.element(output_index)); + } + }); } } // namespace EinsumDepthMap::iterator EinsumDepthAnalysis::GetOrCreateDepthTree( - HloInstruction* instruction) { + const HloInstruction* instruction) { auto depth_iter = einsum_depth_map_.find(instruction); if (depth_iter == einsum_depth_map_.end()) { ShapeTree depth_tree(instruction->shape(), -1); @@ -212,82 +286,68 @@ EinsumDepthMap::iterator EinsumDepthAnalysis::GetOrCreateDepthTree( return depth_iter; } -Status EinsumDepthAnalysis::SetInstructionDepth(HloInstruction* instruction, - int depth) { +EinsumDepthMap::iterator EinsumDepthAnalysis::GetDepthTreeOrDie( + const HloInstruction* instruction) { + auto depth_iter = einsum_depth_map_.find(instruction); + CHECK(depth_iter != einsum_depth_map_.end()) + << "No depth tree found for instruction: " << instruction->ToString(); + return depth_iter; +} + +absl::Status EinsumDepthAnalysis::SetInstructionDepth( + const HloInstruction* instruction, int depth) { auto depth_iter = GetOrCreateDepthTree(instruction); ShapeTree& depth_tree = depth_iter->second; SetDepth(depth_tree, depth); return OkStatus(); } -Status EinsumDepthAnalysis::SetInstructionDepth(HloInstruction* instruction, - const ShapeTree& depth) { +absl::Status EinsumDepthAnalysis::SetInstructionDepth( + const HloInstruction* instruction, const ShapeTree& depth) { auto depth_iter = GetOrCreateDepthTree(instruction); ShapeTree& depth_tree = depth_iter->second; SetDepth(depth_tree, depth); return OkStatus(); } -Status EinsumDepthAnalysis::DefaultAction(HloInstruction* instruction) { - if (!instruction->shape().IsToken() && !instruction->shape().IsArray() && - !instruction->shape().IsTuple()) { - return InvalidArgument("Unexpected shape for default action."); - } - auto depth_iter = einsum_depth_map_.find(instruction); - CHECK(depth_iter != einsum_depth_map_.end()); - const ShapeTree depth_tree = depth_iter->second; +absl::Status EinsumDepthAnalysis::SetInstructionDepthFromTupleDepth( + const HloInstruction* instruction, const ShapeTree& tuple_depth_tree, + int tuple_index) { + auto depth_iter = GetOrCreateDepthTree(instruction); + ShapeTree& depth_tree = depth_iter->second; + SetDepthFromTupleDepth(depth_tree, tuple_depth_tree, tuple_index); + return OkStatus(); +} + +absl::Status EinsumDepthAnalysis::DefaultAction(HloInstruction* instruction) { + auto depth_iter = GetDepthTreeOrDie(instruction); + const ShapeTree& depth_tree = depth_iter->second; int max_depth = GetMaxDepth(depth_tree); - if (instruction->shape().IsToken()) { - for (HloInstruction* operand : instruction->mutable_operands()) { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, max_depth)); - } - } - if (instruction->operand_count() == 1) { - HloInstruction* operand = instruction->mutable_operand(0); - if (Shape::Equal().IgnoreLayout()(instruction->shape(), operand->shape())) { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, depth_tree)); - return OkStatus(); - } - } - // If the instruction is an array, the output depends on all operands. - if (instruction->shape().IsArray()) { - int instruction_depth = depth_tree.element({}); - for (HloInstruction* operand : instruction->mutable_operands()) { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, instruction_depth)); - } - return OkStatus(); - } - // If the instruction is a tuple and the output size is larger than the - // operand count, each tuple element depends on all operands. - int tuple_shape_size = instruction->shape().tuple_shapes_size(); - if (instruction->operand_count() < tuple_shape_size) { - for (HloInstruction* operand : instruction->mutable_operands()) { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, max_depth)); - } - return OkStatus(); - } - // Each tuple element depends on a specific operand. for (int operand_index = 0; operand_index < instruction->operand_count(); ++operand_index) { - HloInstruction* operand = instruction->mutable_operand(operand_index); - if (operand_index < tuple_shape_size) { - auto operand_depth_iter = GetOrCreateDepthTree(operand); - ShapeTree& operand_depth = operand_depth_iter->second; - SetDepthFromTupleDepth(operand_depth, depth_tree, operand_index); - } else { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, max_depth)); - } + const HloInstruction* operand = instruction->operand(operand_index); + TF_RETURN_IF_ERROR(SetInstructionDepth(operand, max_depth)); } return OkStatus(); } -Status EinsumDepthAnalysis::HandleTuple(HloInstruction* tuple) { - auto depth_iter = einsum_depth_map_.find(tuple); - CHECK(depth_iter != einsum_depth_map_.end()); +absl::Status EinsumDepthAnalysis::HandleTuple(HloInstruction* tuple) { + return HandleTupleLike(tuple); +} + +absl::Status EinsumDepthAnalysis::HandleAllReduce(HloInstruction* all_reduce) { + if (all_reduce->shape().IsArray()) { + return DefaultAction(all_reduce); + } + return HandleTupleLike(all_reduce); +} + +absl::Status EinsumDepthAnalysis::HandleTupleLike(HloInstruction* tuple_like) { + auto depth_iter = GetDepthTreeOrDie(tuple_like); const ShapeTree depth_tree = depth_iter->second; - for (int operand_index = 0; operand_index < tuple->operand_count(); + for (int operand_index = 0; operand_index < tuple_like->operand_count(); ++operand_index) { - HloInstruction* operand = tuple->mutable_operand(operand_index); + HloInstruction* operand = tuple_like->mutable_operand(operand_index); auto operand_depth_iter = GetOrCreateDepthTree(operand); ShapeTree& operand_depth = operand_depth_iter->second; SetDepthFromTupleDepth(operand_depth, depth_tree, operand_index); @@ -295,10 +355,9 @@ Status EinsumDepthAnalysis::HandleTuple(HloInstruction* tuple) { return OkStatus(); } -Status EinsumDepthAnalysis::HandleGetTupleElement( +absl::Status EinsumDepthAnalysis::HandleGetTupleElement( HloInstruction* get_tuple_element) { - auto depth_iter = einsum_depth_map_.find(get_tuple_element); - CHECK(depth_iter != einsum_depth_map_.end()); + auto depth_iter = GetDepthTreeOrDie(get_tuple_element); const ShapeTree depth_tree = depth_iter->second; HloInstruction* operand = get_tuple_element->mutable_operand(0); @@ -314,16 +373,15 @@ Status EinsumDepthAnalysis::HandleGetTupleElement( if (operand_depth.IsLeaf(shape_index)) { ShapeIndex output_index = shape_index; output_index.pop_front(); - *depth_ptr = std::max(*depth_ptr, depth_tree.element(output_index)); + *depth_ptr = MergeDepth(*depth_ptr, depth_tree.element(output_index)); } }); return OkStatus(); } -Status EinsumDepthAnalysis::HandleDepthIncrementInstruction( +absl::Status EinsumDepthAnalysis::HandleDepthIncrementInstruction( HloInstruction* instruction) { - auto depth_iter = einsum_depth_map_.find(instruction); - CHECK(depth_iter != einsum_depth_map_.end()); + auto depth_iter = GetDepthTreeOrDie(instruction); int instruction_depth = depth_iter->second.element({}); for (HloInstruction* operand : instruction->mutable_operands()) { TF_RETURN_IF_ERROR(SetInstructionDepth( @@ -333,82 +391,74 @@ Status EinsumDepthAnalysis::HandleDepthIncrementInstruction( return OkStatus(); } -Status EinsumDepthAnalysis::HandleDot(HloInstruction* dot) { - auto depth_iter = einsum_depth_map_.find(dot); - CHECK(depth_iter != einsum_depth_map_.end()); +absl::Status EinsumDepthAnalysis::HandleDot(HloInstruction* dot) { return HandleDepthIncrementInstruction(dot); } -Status EinsumDepthAnalysis::HandleConvolution(HloInstruction* convolution) { +absl::Status EinsumDepthAnalysis::HandleConvolution( + HloInstruction* convolution) { return HandleDepthIncrementInstruction(convolution); } -Status EinsumDepthAnalysis::HandleCall(HloInstruction* call) { - auto depth_iter = einsum_depth_map_.find(call); - CHECK(depth_iter != einsum_depth_map_.end()); +absl::Status EinsumDepthAnalysis::HandleCall(HloInstruction* call) { + auto depth_iter = GetDepthTreeOrDie(call); const ShapeTree depth_tree = depth_iter->second; return HandleCalledComputation(*call->called_computations()[0], depth_tree, call->operands()); } -Status EinsumDepthAnalysis::HandleFusion(HloInstruction* fusion) { - auto depth_iter = einsum_depth_map_.find(fusion); - CHECK(depth_iter != einsum_depth_map_.end()); +absl::Status EinsumDepthAnalysis::HandleFusion(HloInstruction* fusion) { + auto depth_iter = GetDepthTreeOrDie(fusion); const ShapeTree depth_tree = depth_iter->second; return HandleCalledComputation(*fusion->called_computations()[0], depth_tree, fusion->operands()); } -Status EinsumDepthAnalysis::HandleCustomCall(HloInstruction* custom_call) { - if (custom_call->shape().IsToken() || custom_call->shape().IsArray() || - custom_call->shape().IsTuple()) { - return DefaultAction(custom_call); - } - return Unimplemented("Unimplemented custom-call: %s", - custom_call->custom_call_target()); -} - -Status EinsumDepthAnalysis::HandleWhile(HloInstruction* xla_while) { - auto depth_iter = einsum_depth_map_.find(xla_while); - CHECK(depth_iter != einsum_depth_map_.end()); - const ShapeTree depth_tree = depth_iter->second; +absl::Status EinsumDepthAnalysis::HandleWhile(HloInstruction* xla_while) { + auto depth_iter = GetDepthTreeOrDie(xla_while); + const ShapeTree& depth_tree = depth_iter->second; int max_depth = GetMaxDepth(depth_tree); HloComputation* condition_computation = xla_while->while_condition(); HloInstruction* condition_root = condition_computation->root_instruction(); ShapeTree condition_depth(condition_root->shape(), max_depth); TF_RETURN_IF_ERROR(HandleCalledComputation( *condition_computation, condition_depth, xla_while->operands())); + const ShapeTree* root_depth_ptr = &depth_tree; HloComputation* body_computation = xla_while->while_body(); - TF_RETURN_IF_ERROR(HandleCalledComputation(*body_computation, depth_tree, - xla_while->operands())); - // Elements of while loop outputs may only be used within the while loop. - // Set the depth of the while body outputs to have the max of their original - // depth and their corresponding operand depth if their original depth was - // negative. Then recompute while loop instruction depths. - auto body_depth_iter = + bool run_depth_propagation_on_body = true; + auto root_depth_iter = GetOrCreateDepthTree(body_computation->root_instruction()); - ShapeTree& body_depth = body_depth_iter->second; - // Note: while body computations have a single parameter. See - // ShapeVerifier::HandleWhile. - HloInstruction* operand = body_computation->parameter_instruction(0); - auto operand_depth = GetOrCreateDepthTree(operand)->second; - body_depth.ForEachMutableElement( - [&body_depth, &operand_depth](const ShapeIndex& shape_index, - int* depth_ptr) { - if (body_depth.IsLeaf(shape_index)) { - if (body_depth.element(shape_index) < 0 && + ShapeTree& root_depth = root_depth_iter->second; + while (run_depth_propagation_on_body) { + run_depth_propagation_on_body = false; + TF_RETURN_IF_ERROR(HandleCalledComputation( + *body_computation, *root_depth_ptr, xla_while->operands())); + // Elements of while loop outputs may only be used within the while loop. + // If such elements exist, we set its root depth to it operand depth. Then + // recompute while loop instruction depths. + HloInstruction* operand = body_computation->parameter_instruction(0); + const ShapeTree& operand_depth = GetOrCreateDepthTree(operand)->second; + + root_depth.ForEachMutableElement( + [&run_depth_propagation_on_body, &root_depth, &operand_depth]( + const ShapeIndex& shape_index, int* depth_ptr) { + if (!root_depth.IsLeaf(shape_index)) { + return; + } + if (root_depth.element(shape_index) < 0 && operand_depth.element(shape_index) >= 0) { - *depth_ptr = 0; + *depth_ptr = operand_depth.element(shape_index); + run_depth_propagation_on_body = true; } - } - }); - return HandleCalledComputation(*body_computation, body_depth, - xla_while->operands()); + }); + root_depth_ptr = &root_depth; + } + return OkStatus(); } -Status EinsumDepthAnalysis::HandleConditional(HloInstruction* conditional) { - auto depth_iter = einsum_depth_map_.find(conditional); - CHECK(depth_iter != einsum_depth_map_.end()); +absl::Status EinsumDepthAnalysis::HandleConditional( + HloInstruction* conditional) { + auto depth_iter = GetDepthTreeOrDie(conditional); const ShapeTree depth_tree = depth_iter->second; // Conditionals have one more operand than the number of branches. The first // operand is the pred. @@ -422,7 +472,7 @@ Status EinsumDepthAnalysis::HandleConditional(HloInstruction* conditional) { return OkStatus(); } -Status EinsumDepthAnalysis::HandleCalledComputation( +absl::Status EinsumDepthAnalysis::HandleCalledComputation( const HloComputation& called_computation, const ShapeTree& root_depth, absl::Span operands) { TF_RETURN_IF_ERROR(RunInternal(called_computation, @@ -437,21 +487,404 @@ Status EinsumDepthAnalysis::HandleCalledComputation( return OkStatus(); } -Status EinsumDepthAnalysis::HandleAfterAll(HloInstruction* after_all) { +absl::Status EinsumDepthAnalysis::HandleAfterAll(HloInstruction* after_all) { + auto depth_iter = GetDepthTreeOrDie(after_all); + const ShapeTree& depth_tree = depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + for (HloInstruction* operand_token : after_all->mutable_operands()) { + CHECK(operand_token->shape().IsToken()); + TF_RETURN_IF_ERROR(SetInstructionDepth(operand_token, max_depth)); + } return OkStatus(); } -Status EinsumDepthAnalysis::HandleOutfeed(HloInstruction* outfeed) { - auto depth_iter = einsum_depth_map_.find(outfeed); - CHECK(depth_iter != einsum_depth_map_.end()); - const ShapeTree depth_tree = depth_iter->second; +absl::Status EinsumDepthAnalysis::HandleSend(HloInstruction* send) { + auto depth_iter = GetDepthTreeOrDie(send); + const ShapeTree& depth_tree = depth_iter->second; + HloInstruction* send_buffer = send->mutable_operand(0); + auto send_buffer_depth_iter = GetOrCreateDepthTree(send_buffer); + ShapeTree& send_buffer_depth = send_buffer_depth_iter->second; + SetDepthFromTupleDepth(send_buffer_depth, depth_tree, 0); int max_depth = GetMaxDepth(depth_tree); - for (HloInstruction* operand : outfeed->mutable_operands()) { - TF_RETURN_IF_ERROR(SetInstructionDepth(operand, max_depth)); + HloInstruction* token = send->mutable_operand(1); + return SetInstructionDepth(token, max_depth); +} + +absl::Status EinsumDepthAnalysis::HandleRecv(HloInstruction* recv) { + auto depth_iter = GetDepthTreeOrDie(recv); + const ShapeTree& depth_tree = depth_iter->second; + TF_ASSIGN_OR_RETURN(HloInstruction * send, + send_recv_group_map_->GetMatchingSendOrRecv(recv)); + CHECK(send) << "recv: " << recv->name() + << " not found in send_recv_group_map: " << recv->ToString(); + auto send_depth_iter = GetOrCreateDepthTree(send); + ShapeTree& send_depth = send_depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + send_depth.ForEachMutableElement([&depth_tree, &send_depth, max_depth]( + const ShapeIndex& index, int* depth) { + if (!send_depth.IsLeaf(index)) { + return; + } + if (index.front() == 0) { + *depth = MergeDepth(*depth, depth_tree.element(index)); + return; + } + *depth = MergeDepth(*depth, max_depth); + }); + HloInstruction* after_all = recv->mutable_operand(0); + return SetInstructionDepth(after_all, max_depth); +} + +absl::Status EinsumDepthAnalysis::HandleSendDone(HloInstruction* send_done) { + HloInstruction* send = send_done->mutable_operand(0); + auto depth_iter = GetDepthTreeOrDie(send_done); + const ShapeTree& depth_tree = depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + return SetInstructionDepth(send, max_depth); +} + +absl::Status EinsumDepthAnalysis::HandleRecvDone(HloInstruction* recv_done) { + auto depth_iter = GetDepthTreeOrDie(recv_done); + const ShapeTree& depth_tree = depth_iter->second; + int max_depth = GetMaxDepth(depth_tree); + HloInstruction* recv = recv_done->mutable_operand(0); + auto recv_depth_iter = GetOrCreateDepthTree(recv); + ShapeTree& recv_depth = recv_depth_iter->second; + recv_depth.ForEachMutableElement([&depth_tree, &recv_depth, max_depth]( + const ShapeIndex& index, int* depth) { + if (!recv_depth.IsLeaf(index)) { + return; + } + if (index.front() == 0) { + *depth = MergeDepth(*depth, depth_tree.element(index)); + return; + } + *depth = MergeDepth(*depth, max_depth); + }); + return OkStatus(); +} + +namespace { + +int MergeHeight(int original_height, int new_height) { + return std::max(original_height, new_height); +} + +void SetHeight(ShapeTree& height_tree, int height) { + height_tree.ForEachMutableElement( + [height, &height_tree](const ShapeIndex& shape_index, int* height_ptr) { + if (height_tree.IsLeaf(shape_index)) { + *height_ptr = MergeHeight(*height_ptr, height); + } + }); +} + +void SetHeight(ShapeTree& height_tree, const ShapeTree& source, + const ShapeIndex& source_index = {}, + const ShapeIndex& target_index = {}) { + height_tree.ForEachMutableElement( + [&source, &source_index, &target_index](const ShapeIndex& shape_index, + int* height_ptr) { + if (shape_index.size() < target_index.size()) { + return; + } + for (int i = 0; i < target_index.size(); ++i) { + if (shape_index[i] != target_index[i]) { + return; + } + } + ShapeIndex complete_source_index = source_index; + for (int i = target_index.size(); i < shape_index.size(); ++i) { + complete_source_index.push_back(shape_index[i]); + } + *height_ptr = + MergeHeight(*height_ptr, source.element(complete_source_index)); + }); +} + +int GetMaxHeight(const ShapeTree& height_tree) { + int max_height = 0; + height_tree.ForEachElement( + [&max_height](const ShapeIndex& shape_index, int height) { + max_height = std::max(max_height, height); + return OkStatus(); + }); + return max_height; +} + +int GetMaxOperandHeight(HloInstruction* instruction, + const EinsumHeightMap& einsum_height_map) { + int max_height = 0; + for (HloInstruction* operand : instruction->mutable_operands()) { + auto operand_height_iter = einsum_height_map.find(operand); + CHECK(operand_height_iter != einsum_height_map.end()) + << "operand: " << operand->name(); + const ShapeTree& operand_height_tree = operand_height_iter->second; + int max_operand_height = GetMaxHeight(operand_height_tree); + max_height = std::max(max_height, max_operand_height); + } + return max_height; +} + +} // namespace + +absl::StatusOr> EinsumHeightAnalysis::Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map) { + EinsumHeightAnalysis* analysis_ptr = + new EinsumHeightAnalysis(send_recv_group_map); + std::unique_ptr analysis(analysis_ptr); + TF_RETURN_IF_ERROR(analysis->RunInternal(computation, {})); + TF_RETURN_IF_ERROR(analysis->RunInternal(computation, {})); + return analysis; +} + +absl::Status EinsumHeightAnalysis::RunInternal( + const HloComputation& computation, + absl::Span operands) { + return HandleCalledComputation(computation, operands); +} + +EinsumHeightMap::iterator EinsumHeightAnalysis::GetOrCreateHeightTree( + const HloInstruction* instruction) { + auto height_iter = einsum_height_map_.find(instruction); + if (height_iter == einsum_height_map_.end()) { + ShapeTree height_tree(instruction->shape(), 0); + auto inserted = einsum_height_map_.insert( + std::make_pair(instruction, std::move(height_tree))); + height_iter = inserted.first; } + return height_iter; +} + +EinsumHeightMap::iterator EinsumHeightAnalysis::GetHeightTreeOrDie( + const HloInstruction* instruction) { + auto height_iter = einsum_height_map_.find(instruction); + CHECK(height_iter != einsum_height_map_.end()); + return height_iter; +} + +bool EinsumHeightAnalysis::HasHeightFor( + const HloInstruction* instruction) const { + return einsum_height_map_.contains(instruction); +} + +absl::Status EinsumHeightAnalysis::SetInstructionHeight( + const HloInstruction* instruction, int height) { + auto height_iter = GetOrCreateHeightTree(instruction); + ShapeTree& height_tree = height_iter->second; + SetHeight(height_tree, height); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::SetInstructionHeight( + const HloInstruction* instruction, const ShapeTree& height) { + auto height_iter = GetOrCreateHeightTree(instruction); + ShapeTree& height_tree = height_iter->second; + SetHeight(height_tree, height); + return OkStatus(); +} + +#define RETURN_IF_HEIGHT_EXISTS(instruction) \ + if (HasHeightFor(instruction)) { \ + return OkStatus(); \ + } + +absl::Status EinsumHeightAnalysis::HandleHeightIncrementInstruction( + HloInstruction* instruction) { + auto height_iter = GetOrCreateHeightTree(instruction); + for (HloInstruction* operand : instruction->mutable_operands()) { + auto operand_height_iter = GetHeightTreeOrDie(operand); + int operand_height = operand_height_iter->second.element({}); + SetHeight(height_iter->second, operand_height + 1); + } + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleCalledComputation( + const HloComputation& computation, + absl::Span operands) { + if (!operands.empty()) { + if (computation.num_parameters() != operands.size()) { + return absl::InvalidArgumentError(absl::StrCat( + operands.size(), " operands were passed for the computation ", + computation.name(), " with ", computation.num_parameters(), + " parameters.")); + } + for (int parameter_index = 0; + parameter_index < computation.num_parameters(); ++parameter_index) { + HloInstruction* parameter = + computation.parameter_instruction(parameter_index); + HloInstruction* operand = operands[parameter_index]; + auto operand_height_iter = GetHeightTreeOrDie(operand); + TF_RETURN_IF_ERROR( + SetInstructionHeight(parameter, operand_height_iter->second)); + } + } + for (HloInstruction* instruction : computation.instructions()) { + if (instruction->user_count() == 0) { + TF_RETURN_IF_ERROR(instruction->Accept(this)); + } + } + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::DefaultAction(HloInstruction* instruction) { + RETURN_IF_HEIGHT_EXISTS(instruction); + int instruction_height = GetMaxOperandHeight(instruction, einsum_height_map_); + return SetInstructionHeight(instruction, instruction_height); +} + +absl::Status EinsumHeightAnalysis::HandleTupleLike(HloInstruction* tuple_like) { + auto height_iter = GetOrCreateHeightTree(tuple_like); + ShapeTree& height_tree = height_iter->second; + height_tree.ForEachMutableElement([&height_tree, tuple_like, this]( + const ShapeIndex& index, int* height) { + if (!height_tree.IsLeaf(index)) { + return; + } + int operand_index = index.front(); + const HloInstruction* operand = tuple_like->operand(operand_index); + auto operand_height_iter = GetHeightTreeOrDie(operand); + CHECK(operand_height_iter != einsum_height_map_.end()) + << "operand: " << operand->name(); + ShapeIndex source_index = index; + source_index.pop_front(); + *height = + MergeHeight(*height, operand_height_iter->second.element(source_index)); + }); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleTuple(HloInstruction* tuple) { + RETURN_IF_HEIGHT_EXISTS(tuple); + return HandleTupleLike(tuple); +} + +absl::Status EinsumHeightAnalysis::HandleGetTupleElement( + HloInstruction* get_tuple_element) { + RETURN_IF_HEIGHT_EXISTS(get_tuple_element); + auto height_iter = GetOrCreateHeightTree(get_tuple_element); + ShapeTree& height_tree = height_iter->second; + auto tuple_height_iter = GetHeightTreeOrDie(get_tuple_element->operand(0)); + const ShapeTree& tuple_height_tree = tuple_height_iter->second; + int tuple_index = get_tuple_element->tuple_index(); + SetHeight(height_tree, tuple_height_tree, {tuple_index}, {}); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleDot(HloInstruction* dot) { + RETURN_IF_HEIGHT_EXISTS(dot); + return HandleHeightIncrementInstruction(dot); +} + +absl::Status EinsumHeightAnalysis::HandleConvolution( + HloInstruction* convolution) { + RETURN_IF_HEIGHT_EXISTS(convolution); + return HandleHeightIncrementInstruction(convolution); +} + +absl::Status EinsumHeightAnalysis::HandleCall(HloInstruction* call) { + RETURN_IF_HEIGHT_EXISTS(call); + TF_RETURN_IF_ERROR(HandleCalledComputation(*(call->called_computations()[0]), + call->mutable_operands())); + auto root_height_iter = + GetHeightTreeOrDie(call->called_computations()[0]->root_instruction()); + TF_RETURN_IF_ERROR(SetInstructionHeight(call, root_height_iter->second)); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleFusion(HloInstruction* fusion) { + RETURN_IF_HEIGHT_EXISTS(fusion); + return HandleCall(fusion); +} + +absl::Status EinsumHeightAnalysis::HandleWhile(HloInstruction* xla_while) { + RETURN_IF_HEIGHT_EXISTS(xla_while); + TF_RETURN_IF_ERROR(HandleCalledComputation(*(xla_while->while_condition()), + xla_while->mutable_operands())); + TF_RETURN_IF_ERROR(HandleCalledComputation(*(xla_while->while_body()), + xla_while->mutable_operands())); + auto root_height_iter = + GetHeightTreeOrDie(xla_while->while_body()->root_instruction()); + return SetInstructionHeight(xla_while, root_height_iter->second); +} + +absl::Status EinsumHeightAnalysis::HandleConditional( + HloInstruction* conditional) { + RETURN_IF_HEIGHT_EXISTS(conditional); + auto conditional_height_iter = GetOrCreateHeightTree(conditional); + ShapeTree& height_tree = conditional_height_iter->second; + for (size_t i = 0; i < conditional->branch_count(); ++i) { + HloComputation* computation = conditional->branch_computation(i); + // An N-way conditional op has N + 1 operands where the first one is the + // branch index determining what branch to take, and the remaining N + // operands correspond to arguments to be passed to each of the N branch + // computations, if they are executed. So the (i + 1)th operand corresponds + // to the ith branch computation. + TF_RETURN_IF_ERROR(HandleCalledComputation( + *computation, {conditional->mutable_operands()[i + 1]})); + auto branch_root_height_iter = + GetHeightTreeOrDie(computation->root_instruction()); + SetHeight(height_tree, branch_root_height_iter->second); + } + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleSend(HloInstruction* send) { + RETURN_IF_HEIGHT_EXISTS(send); + HloInstruction* send_buffer = send->mutable_operand(0); + auto send_buffer_height_iter = GetHeightTreeOrDie(send_buffer); + const ShapeTree& send_buffer_height_tree = + send_buffer_height_iter->second; + + auto height_iter = GetOrCreateHeightTree(send); + ShapeTree& height_tree = height_iter->second; + SetHeight(height_tree, send_buffer_height_tree, {}, {0}); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleRecv(HloInstruction* recv) { + RETURN_IF_HEIGHT_EXISTS(recv); + TF_ASSIGN_OR_RETURN(HloInstruction * send, + send_recv_group_map_->GetMatchingSendOrRecv(recv)); + TF_RETURN_IF_ERROR(send->Accept(this)); + HloInstruction* send_buffer = send->mutable_operand(0); + auto send_buffer_height_iter = GetHeightTreeOrDie(send_buffer); + const ShapeTree& send_buffer_height_tree = + send_buffer_height_iter->second; + + auto height_iter = GetOrCreateHeightTree(recv); + ShapeTree& height_tree = height_iter->second; + SetHeight(height_tree, send_buffer_height_tree, {}, {0}); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleSendDone(HloInstruction* send_done) { + RETURN_IF_HEIGHT_EXISTS(send_done); + GetOrCreateHeightTree(send_done); return OkStatus(); } +absl::Status EinsumHeightAnalysis::HandleRecvDone(HloInstruction* recv_done) { + RETURN_IF_HEIGHT_EXISTS(recv_done); + HloInstruction* recv = recv_done->mutable_operand(0); + auto recv_height_iter = GetHeightTreeOrDie(recv); + const ShapeTree& recv_height_tree = recv_height_iter->second; + auto height_iter = GetOrCreateHeightTree(recv_done); + ShapeTree& height_tree = height_iter->second; + SetHeight(height_tree, recv_height_tree, {0}, {0}); + return OkStatus(); +} + +absl::Status EinsumHeightAnalysis::HandleAllReduce(HloInstruction* all_reduce) { + RETURN_IF_HEIGHT_EXISTS(all_reduce); + if (all_reduce->shape().IsArray()) { + return DefaultAction(all_reduce); + } + return HandleTupleLike(all_reduce); +} + std::string HloValueSemanticLabelToString(HloValueSemanticLabel label) { switch (label) { case HloValueSemanticLabel::kStatic: @@ -487,6 +920,11 @@ HloValueSemantics::HloValueSemantics(Id id, HloValueSemanticLabel label, const HloPosition& origin) : id_(id), label_(label), origin_(origin) {} +std::string HloValueSemanticsTreeToString( + const ShapeTree& tree) { + return ToString(tree); +} + HloValueSemanticsAnalysis::HloValueSemanticsAnalysis(const HloModule& module) : module_(module), next_id_(0) {} @@ -495,25 +933,66 @@ const HloValueSemantics* HloValueSemanticsAnalysis::GetSemantics( return GetInstructionSemantics(instruction).element(index); } -StatusOr> +int HloValueSemanticsAnalysis::GetDepth(const HloInstruction* instruction, + const ShapeIndex& index) const { + auto depth_iter = einsum_depth_map_.find(instruction); + CHECK(depth_iter != einsum_depth_map_.end()); + return depth_iter->second.element(index); +} + +int HloValueSemanticsAnalysis::GetHeight(const HloInstruction* instruction, + const ShapeIndex& index) const { + auto height_iter = einsum_height_map_.find(instruction); + CHECK(height_iter != einsum_height_map_.end()); + return height_iter->second.element(index); +} + +absl::StatusOr> HloValueSemanticsAnalysis::Run(const HloModule& module) { std::unique_ptr value_semantics_analysis = absl::WrapUnique(new HloValueSemanticsAnalysis(module)); + value_semantics_analysis->InitializeSendRecvGroups(); TF_RETURN_IF_ERROR(value_semantics_analysis->InitializeEinsumDepth()); + TF_RETURN_IF_ERROR(value_semantics_analysis->InitializeEinsumHeight()); value_semantics_analysis->AnnotateWeights(); TF_RETURN_IF_ERROR( value_semantics_analysis->RunOnComputation(*module.entry_computation())); return value_semantics_analysis; } -Status HloValueSemanticsAnalysis::InitializeEinsumDepth() { +absl::Status HloValueSemanticsAnalysis::InitializeEinsumDepth() { TF_ASSIGN_OR_RETURN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module_.entry_computation())); + EinsumDepthAnalysis::Run(*module_.entry_computation(), + *send_recv_group_map_)); einsum_depth_map_ = einsum_depth_analysis->GetEinsumDepthMap(); return OkStatus(); } +absl::Status HloValueSemanticsAnalysis::InitializeEinsumHeight() { + TF_ASSIGN_OR_RETURN( + std::unique_ptr einsum_height_analysis, + EinsumHeightAnalysis::Run(*module_.entry_computation(), + *send_recv_group_map_)); + einsum_height_map_ = einsum_height_analysis->GetEinsumHeightMap(); + return OkStatus(); +} + +void HloValueSemanticsAnalysis::InitializeSendRecvGroups() { + send_recv_group_map_ = std::make_unique(module_); +} + +bool HloValueSemanticsAnalysis::HasSemanticsFor( + const HloInstruction* instruction) const { + return value_semantics_.contains(instruction); +} + +absl::StatusOr +HloValueSemanticsAnalysis::GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const { + return send_recv_group_map_->GetMatchingSendOrRecv(send_or_recv); +} + HloValueSemantics::Id HloValueSemanticsAnalysis::NextId() { return next_id_++; } const HloValueSemantics* HloValueSemanticsAnalysis::NewHloValueSemantics( @@ -624,7 +1103,7 @@ void HloValueSemanticsAnalysis::AnnotateWeights() { } } -Status HloValueSemanticsAnalysis::RunOnComputation( +absl::Status HloValueSemanticsAnalysis::RunOnComputation( const HloComputation& computation, absl::Span operands) { CHECK_EQ(computation.num_parameters(), operands.size()); @@ -637,7 +1116,7 @@ Status HloValueSemanticsAnalysis::RunOnComputation( return RunOnComputation(computation); } -Status HloValueSemanticsAnalysis::RunOnComputation( +absl::Status HloValueSemanticsAnalysis::RunOnComputation( const HloComputation& computation) { HloValueSemanticsPropagation propagation(this); return propagation.Run(computation); @@ -647,8 +1126,15 @@ HloValueSemanticsPropagation::HloValueSemanticsPropagation( HloValueSemanticsAnalysis* analysis) : analysis_(analysis) {} -Status HloValueSemanticsPropagation::Run(const HloComputation& computation) { - return computation.root_instruction()->Accept(this); +absl::Status HloValueSemanticsPropagation::Run( + const HloComputation& computation) { + TF_RETURN_IF_ERROR(computation.root_instruction()->Accept(this)); + for (HloInstruction* instruction : computation.instructions()) { + if (instruction->user_count() == 0) { + TF_RETURN_IF_ERROR(instruction->Accept(this)); + } + } + return OkStatus(); } HloValueSemantics HloValueSemanticsPropagation::CopySemantics( @@ -728,7 +1214,7 @@ bool HloValueSemanticsPropagation::OriginDependsOn( return !dependent_einsums.empty(); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther( const HloValueSemantics& static_semantics, const HloValueSemantics& other_semantics, @@ -743,12 +1229,13 @@ HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther( instruction->opcode() == HloOpcode::kConvolution; if (is_dot_or_convolution && other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemantics(other_semantics); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromRandomAndOther( const HloValueSemantics& random_semantics, const HloValueSemantics& other_semantics, @@ -761,13 +1248,11 @@ HloValueSemanticsPropagation::ComputeSemanticsFromRandomAndOther( return CopySemantics(other_semantics); } -StatusOr -HloValueSemanticsPropagation::CreateGradientSemantics( - HloInstruction* gradient_candidate) const { - const EinsumDepthMap& einsum_depth_map = analysis_->GetEinsumDepthMap(); - auto depth_iter = einsum_depth_map.find(gradient_candidate); - CHECK(depth_iter != einsum_depth_map.end()); - int gradient_depth = depth_iter->second.element({}); +absl::StatusOr +HloValueSemanticsPropagation::MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const { + int gradient_depth = analysis_->GetDepth(gradient_candidate, {}); if (gradient_depth < 0) { // There is dependency between the two operands of the dot, but the dot // is not used by root. This is likely eval computation in a TF program. @@ -779,11 +1264,10 @@ HloValueSemanticsPropagation::CreateGradientSemantics( return HloValueSemantics(HloValueSemanticLabel::kWeightGradient, {gradient_candidate, {}}); } - return HloValueSemantics(HloValueSemanticLabel::kActivationGradient, - {gradient_candidate, {}}); + return HloValueSemantics(fallback_label, {gradient_candidate, {}}); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( const HloValueSemantics& weight_semantics, const HloValueSemantics& other_semantics, @@ -795,6 +1279,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( instruction->opcode() == HloOpcode::kConvolution; if (other_semantics.label() == HloValueSemanticLabel::kWeight) { if (!is_dot_or_convolution) { + if (weight_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } return HloValueSemantics(HloValueSemanticLabel::kActivation, @@ -809,9 +1296,22 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // is the loss. We distinguish this case from regular Activations by // checking whether X is computed from some einsum that takes W as an // operand. - if (OriginDependsOn(other_semantics, weight_semantics.origin(), - /*recursive=*/true)) { - return CreateGradientSemantics(instruction); + int instruction_depth = analysis_->GetDepth(instruction, {}); + auto dependent_einsums = FindEinsumsWhereOriginDependsOnOther( + other_semantics, weight_semantics.origin(), /*recursive=*/true); + bool all_dependent_einsums_immediately_proceeds_instruction = + absl::c_all_of(dependent_einsums, + [instruction_depth, + this](const EinsumAndOperandIndex& dependent_einsum) { + int dependent_einsum_depth = + analysis_->GetDepth(dependent_einsum.einsum, {}); + return dependent_einsum_depth > 0 && + dependent_einsum_depth == instruction_depth + 1; + }); + if (!dependent_einsums.empty() && + all_dependent_einsums_immediately_proceeds_instruction) { + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } @@ -820,13 +1320,14 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // which produce an Activation. The ActivationGradient to this Activation // could be used in an einsum with one of the Weights to compute // the WeightGradient for the other Weight. - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient); return CopySemantics(other_semantics); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( const HloValueSemantics& activation_semantics, const HloValueSemantics& other_semantics, @@ -838,14 +1339,16 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || instruction->opcode() == HloOpcode::kConvolution; if (!is_dot_or_convolution) { + if (activation_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivation) { // Like said above, since loss is classified as Activation, an einsum - // between an Activation X and an Activation Y could be WeightGradient or - // even ActivationGradient when either X or Y is the loss. This case is - // different from other Activation einsums because there must a dependency - // between X and Y. + // between an Activation X and an Activation Y could be WeightGradient if + // either X or Y is the loss. This case is different from other Activation + // einsums because there must a dependency between X and Y. bool other_depends_on_activation = OriginDependsOn( other_semantics, activation_semantics.origin(), /*recursive=*/true); bool activation_depends_on_other = @@ -855,14 +1358,19 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( // If there is no dependency between the two Activations, the output must // be an Activation. if (other_depends_on_activation || activation_depends_on_other) { - return CreateGradientSemantics(instruction); + // We check if the einsum is actually weight gradient. If it is not, fall + // back to activation, since we expect the loss to be computed from an + // activation-weight einsum. + return MaybeCreateGradientSemantics(instruction, + HloValueSemanticLabel::kActivation); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { // An Activation-ActivationGradient einsum could be computing // WeightGradient or ActivationGradient. - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient) << "instruction: " << instruction->ToString() @@ -872,7 +1380,7 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( return CopySemantics(other_semantics); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromActivationGradientAndOther( const HloValueSemantics& activation_gradient_semantics, const HloValueSemantics& other_semantics, @@ -884,6 +1392,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationGradientAndOther( other_semantics.label() != HloValueSemanticLabel::kWeight && other_semantics.label() != HloValueSemanticLabel::kActivation); if (other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { + if (other_semantics.origin() == activation_gradient_semantics.origin()) { + return CopySemantics(activation_gradient_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } @@ -891,7 +1402,7 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationGradientAndOther( return CopySemantics(other_semantics); } -StatusOr +absl::StatusOr HloValueSemanticsPropagation::ComputeSemanticsFromWeightGradientAndOther( const HloValueSemantics& weight_gradient_semantics, const HloValueSemantics& other_semantics, @@ -906,32 +1417,15 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightGradientAndOther( return CopySemantics(weight_gradient_semantics); } -StatusOr -HloValueSemanticsPropagation::ComputeSemanticsFromOperands( - HloInstruction* instruction, absl::Span operand_indices, - absl::Span operand_shape_indices) const { - CHECK(!operand_indices.empty()); - CHECK(operand_shape_indices.empty() || - operand_indices.size() == operand_shape_indices.size()); - VLOG(3) << __func__ << ", instruction: " << instruction->ToString(); - std::vector semantics_vec; - for (int64_t operand_index : operand_indices) { - const HloInstruction* operand = instruction->operand(operand_index); - const HloValueSemantics* operand_semantics = analysis_->GetSemantics( - operand, operand_shape_indices.empty() - ? ShapeIndex() - : operand_shape_indices[operand_index]); - VLOG(3) << __func__ << ", operand_index: " << operand_index - << ", operand: " << operand->name() - << ", operand_semantics: " << operand_semantics->ToString(); - semantics_vec.push_back(*operand_semantics); - } +absl::StatusOr +HloValueSemanticsPropagation::MergeSemanticsForAnInstruction( + HloInstruction* instruction, + std::vector& semantics_vec) const { while (semantics_vec.size() >= 2) { absl::Span operand_list = absl::MakeConstSpan(semantics_vec).subspan(semantics_vec.size() - 2, 2); auto find_operand_index_with_label = - [&operand_list]( - HloValueSemanticLabel label) -> std::optional { + [&operand_list](HloValueSemanticLabel label) -> std::optional { auto iter = absl::c_find_if(operand_list, [label](const HloValueSemantics& operand) { return operand.label() == label; @@ -1001,6 +1495,13 @@ HloValueSemanticsPropagation::ComputeSemanticsFromOperands( replace_operands_semantics_with(semantics); continue; } + if (operand_list[0].label() == HloValueSemanticLabel::kTupleOrToken && + operand_list[1].label() == HloValueSemanticLabel::kTupleOrToken) { + HloValueSemantics semantics = + CopySemanticsWithNewOrigin(operand_list[0], instruction); + replace_operands_semantics_with(semantics); + continue; + } LOG(FATAL) << "We don't expect to handle operands of label " << HloValueSemanticLabelToString(operand_list[0].label()) << " and " @@ -1015,26 +1516,80 @@ HloValueSemanticsPropagation::ComputeSemanticsFromOperands( return semantics_vec.back(); } -Status HloValueSemanticsPropagation::DefaultAction( +absl::StatusOr +HloValueSemanticsPropagation::ComputeSemanticsFromOperands( + HloInstruction* instruction, absl::Span operand_indices, + absl::Span operand_shape_indices) const { + CHECK(!operand_indices.empty()); + CHECK(operand_shape_indices.empty() || + operand_indices.size() == operand_shape_indices.size()); + VLOG(3) << __func__ << ", instruction: " << instruction->ToString(); + std::vector semantics_vec; + for (int64_t operand_index : operand_indices) { + const HloInstruction* operand = instruction->operand(operand_index); + const HloValueSemantics* operand_semantics = analysis_->GetSemantics( + operand, operand_shape_indices.empty() + ? ShapeIndex() + : operand_shape_indices[operand_index]); + auto operand_height_iter = analysis_->GetEinsumHeightMap().find(operand); + CHECK(operand_height_iter != analysis_->GetEinsumHeightMap().end()); + VLOG(3) << __func__ << ", operand_index: " << operand_index + << ", operand: " << operand->name() + << ", operand_semantics: " << operand_semantics->ToString() + << ", height: " << ToString(operand_height_iter->second); + semantics_vec.push_back(*operand_semantics); + } + return MergeSemanticsForAnInstruction(instruction, semantics_vec); +} + +#define RETURN_IF_ALREADY_PROPAGATED(instruction) \ + if (analysis_->HasSemanticsFor(instruction)) { \ + return OkStatus(); \ + } + +absl::Status HloValueSemanticsPropagation::DefaultAction( HloInstruction* instruction) { + RETURN_IF_ALREADY_PROPAGATED(instruction); std::vector operand_indices(instruction->operand_count()); std::iota(operand_indices.begin(), operand_indices.end(), 0); TF_ASSIGN_OR_RETURN( HloValueSemantics semantics, ComputeSemanticsFromOperands(instruction, operand_indices)); - const HloValueSemantics* semantics_ptr = AddSemantics(semantics); - ShapeTree semantics_shape_tree(instruction->shape(), - semantics_ptr); - analysis_->SetHloValueSemantics(instruction, semantics_shape_tree); + + if (instruction->shape().IsTuple()) { + ShapeTree semantics_shape_tree( + instruction->shape(), nullptr); + semantics_shape_tree.ForEachMutableElement( + [this, &semantics, &semantics_shape_tree, instruction]( + const ShapeIndex& index, const HloValueSemantics** semantics_ptr) { + if (semantics_shape_tree.IsLeaf(index)) { + HloValueSemantics sub_semantics = + CopySemanticsWithNewOrigin(semantics, instruction, index); + *semantics_ptr = AddSemantics(sub_semantics); + } else { + HloValueSemantics sub_semantics( + HloValueSemanticLabel::kTupleOrToken, {instruction, index}); + *semantics_ptr = AddSemantics(sub_semantics); + } + }); + analysis_->SetHloValueSemantics(instruction, semantics_shape_tree); + } else { + const HloValueSemantics* semantics_ptr = AddSemantics(semantics); + ShapeTree semantics_shape_tree( + instruction->shape(), semantics_ptr); + analysis_->SetHloValueSemantics(instruction, semantics_shape_tree); + } return OkStatus(); } -Status HloValueSemanticsPropagation::HandleParameter( +absl::Status HloValueSemanticsPropagation::HandleParameter( HloInstruction* parameter) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleConstant(HloInstruction* constant) { +absl::Status HloValueSemanticsPropagation::HandleConstant( + HloInstruction* constant) { + RETURN_IF_ALREADY_PROPAGATED(constant); const HloValueSemantics* constant_semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {constant, {}}); ShapeTree semantics_shape_tree(constant->shape(), @@ -1043,7 +1598,8 @@ Status HloValueSemanticsPropagation::HandleConstant(HloInstruction* constant) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleIota(HloInstruction* iota) { +absl::Status HloValueSemanticsPropagation::HandleIota(HloInstruction* iota) { + RETURN_IF_ALREADY_PROPAGATED(iota); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {iota, {}}); ShapeTree semantics_shape_tree(iota->shape(), @@ -1052,8 +1608,9 @@ Status HloValueSemanticsPropagation::HandleIota(HloInstruction* iota) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandlePartitionId( +absl::Status HloValueSemanticsPropagation::HandlePartitionId( HloInstruction* partition_id) { + RETURN_IF_ALREADY_PROPAGATED(partition_id); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {partition_id, {}}); ShapeTree semantics_shape_tree( @@ -1061,8 +1618,9 @@ Status HloValueSemanticsPropagation::HandlePartitionId( analysis_->SetHloValueSemantics(partition_id, semantics_shape_tree); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleReplicaId( +absl::Status HloValueSemanticsPropagation::HandleReplicaId( HloInstruction* replica_id) { + RETURN_IF_ALREADY_PROPAGATED(replica_id); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kStatic, {replica_id, {}}); ShapeTree semantics_shape_tree(replica_id->shape(), @@ -1071,7 +1629,7 @@ Status HloValueSemanticsPropagation::HandleReplicaId( return OkStatus(); } -Status HloValueSemanticsPropagation::HandleRngBitGenerator( +absl::Status HloValueSemanticsPropagation::HandleRngBitGenerator( HloInstruction* rng_bit_generator) { const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kRandom, {rng_bit_generator, {}}); @@ -1081,39 +1639,22 @@ Status HloValueSemanticsPropagation::HandleRngBitGenerator( return OkStatus(); } -Status HloValueSemanticsPropagation::HandleClamp(HloInstruction* clamp) { +absl::Status HloValueSemanticsPropagation::HandleClamp(HloInstruction* clamp) { + RETURN_IF_ALREADY_PROPAGATED(clamp); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(clamp->operand(1)); analysis_->DeepCopyHloValueSemantics(clamp, operand_semantics); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleTuple(HloInstruction* tuple) { - ShapeTree semantics_shape_tree(tuple->shape(), - nullptr); - for (int operand_index = 0; operand_index < tuple->operand_count(); - ++operand_index) { - const HloInstruction* operand = tuple->operand(operand_index); - const ShapeTree& operand_semantics = - analysis_->GetInstructionSemantics(operand); - analysis_->DeepCopyHloValueSemantics( - semantics_shape_tree, operand_semantics, {}, {operand_index}); - } - semantics_shape_tree.ForEachMutableElement( - [tuple, this](const ShapeIndex& index, - const HloValueSemantics** semantics) { - if (index.empty()) { - *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kTupleOrToken, {tuple, {}}); - return; - } - }); - analysis_->SetHloValueSemantics(tuple, semantics_shape_tree); - return OkStatus(); +absl::Status HloValueSemanticsPropagation::HandleTuple(HloInstruction* tuple) { + RETURN_IF_ALREADY_PROPAGATED(tuple); + return HandleTupleLike(tuple); } -Status HloValueSemanticsPropagation::HandleGetTupleElement( +absl::Status HloValueSemanticsPropagation::HandleGetTupleElement( HloInstruction* get_tuple_element) { + RETURN_IF_ALREADY_PROPAGATED(get_tuple_element); const HloInstruction* tuple = get_tuple_element->operand(0); int64_t tuple_index = get_tuple_element->tuple_index(); const ShapeTree& tuple_semantics = @@ -1126,7 +1667,8 @@ Status HloValueSemanticsPropagation::HandleGetTupleElement( return OkStatus(); } -Status HloValueSemanticsPropagation::HandleCall(HloInstruction* call) { +absl::Status HloValueSemanticsPropagation::HandleCall(HloInstruction* call) { + RETURN_IF_ALREADY_PROPAGATED(call); HloComputation* computation = call->called_computations()[0]; TF_RETURN_IF_ERROR( analysis_->RunOnComputation(*computation, call->operands())); @@ -1136,17 +1678,14 @@ Status HloValueSemanticsPropagation::HandleCall(HloInstruction* call) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleFusion(HloInstruction* fusion) { - HloComputation* computation = fusion->called_computations()[0]; - TF_RETURN_IF_ERROR( - analysis_->RunOnComputation(*computation, fusion->operands())); - const ShapeTree& root_semantics = - analysis_->GetInstructionSemantics(computation->root_instruction()); - analysis_->DeepCopyHloValueSemantics(fusion, root_semantics); - return OkStatus(); +absl::Status HloValueSemanticsPropagation::HandleFusion( + HloInstruction* fusion) { + return HandleCall(fusion); } -Status HloValueSemanticsPropagation::HandleWhile(HloInstruction* xla_while) { +absl::Status HloValueSemanticsPropagation::HandleWhile( + HloInstruction* xla_while) { + RETURN_IF_ALREADY_PROPAGATED(xla_while); TF_RETURN_IF_ERROR(analysis_->RunOnComputation(*xla_while->while_condition(), xla_while->operands())); HloComputation* computation = xla_while->while_body(); @@ -1158,9 +1697,12 @@ Status HloValueSemanticsPropagation::HandleWhile(HloInstruction* xla_while) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleCustomCall( +absl::Status HloValueSemanticsPropagation::HandleCustomCall( HloInstruction* custom_call) { - if (custom_call->custom_call_target() == "Sharding") { + RETURN_IF_ALREADY_PROPAGATED(custom_call); + if (custom_call->custom_call_target() == "Sharding" || + custom_call->custom_call_target() == "SPMDFullToShardShape" || + custom_call->custom_call_target() == "SPMDShardToFullShape") { const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(custom_call->operand(0)); analysis_->DeepCopyHloValueSemantics(custom_call, operand_semantics); @@ -1170,21 +1712,49 @@ Status HloValueSemanticsPropagation::HandleCustomCall( custom_call->custom_call_target()); } -Status HloValueSemanticsPropagation::HandleConditional( +absl::Status HloValueSemanticsPropagation::HandleConditional( HloInstruction* conditional) { + RETURN_IF_ALREADY_PROPAGATED(conditional); + std::vector> semantics_tree_vec; for (int i = 0; i < conditional->called_computations().size(); ++i) { - TF_RETURN_IF_ERROR( - analysis_->RunOnComputation(*conditional->called_computations()[i], - {conditional->operands()[i + 1]})); + HloComputation* computation = conditional->called_computations()[i]; + TF_RETURN_IF_ERROR(analysis_->RunOnComputation( + *computation, {conditional->operands()[i + 1]})); + const ShapeTree& root_semantics = + analysis_->GetInstructionSemantics(computation->root_instruction()); + semantics_tree_vec.push_back(root_semantics); } - HloComputation* computation = conditional->called_computations()[0]; - const ShapeTree& root_semantics = - analysis_->GetInstructionSemantics(computation->root_instruction()); - analysis_->DeepCopyHloValueSemantics(conditional, root_semantics); + + std::vector merged_semantics_leaves; + TF_RETURN_IF_ERROR(semantics_tree_vec[0].ForEachElementWithStatus( + [&](const ShapeIndex& index, + const HloValueSemantics* semantics) -> Status { + std::vector semantics_vector; + for (size_t i = 0; i < semantics_tree_vec.size(); ++i) { + semantics_vector.push_back( + *(semantics_tree_vec[i].find(index)->second)); + } + TF_ASSIGN_OR_RETURN( + HloValueSemantics merged, + MergeSemanticsForAnInstruction(conditional, semantics_vector)); + merged_semantics_leaves.push_back(merged); + return OkStatus(); + })); + + ShapeTree merged_semantics(conditional->shape()); + int idx = 0; + merged_semantics.ForEachMutableElement( + [&](const ShapeIndex& index, + const HloValueSemantics** semantics) -> void { + *semantics = &merged_semantics_leaves[idx++]; + }); + analysis_->DeepCopyHloValueSemantics(conditional, merged_semantics); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleSelect(HloInstruction* select) { +absl::Status HloValueSemanticsPropagation::HandleSelect( + HloInstruction* select) { + RETURN_IF_ALREADY_PROPAGATED(select); TF_ASSIGN_OR_RETURN(HloValueSemantics semantics, ComputeSemanticsFromOperands(select, {1, 2})); const HloValueSemantics* semantics_ptr = AddSemantics(semantics); @@ -1194,37 +1764,31 @@ Status HloValueSemanticsPropagation::HandleSelect(HloInstruction* select) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleConcatenate( +absl::Status HloValueSemanticsPropagation::HandleConcatenate( HloInstruction* concatenate) { + RETURN_IF_ALREADY_PROPAGATED(concatenate); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(concatenate->operand(0)); analysis_->DeepCopyHloValueSemantics(concatenate, operand_semantics); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleDynamicSlice( +absl::Status HloValueSemanticsPropagation::HandleDynamicSlice( HloInstruction* dynamic_slice) { + RETURN_IF_ALREADY_PROPAGATED(dynamic_slice); const HloInstruction* dynamic_slice_operand = dynamic_slice->operand(0); const HloValueSemantics* operand_semantics = analysis_->GetSemantics(dynamic_slice_operand); - const HloValueSemantics* semantics = nullptr; - if (operand_semantics->label() == HloValueSemanticLabel::kStatic || - operand_semantics->label() == HloValueSemanticLabel::kRandom || - operand_semantics->label() == HloValueSemanticLabel::kWeight) { - semantics = analysis_->NewHloValueSemantics(operand_semantics->label(), - {dynamic_slice, {}}); - } else { - HloValueSemantics semantics_value = CopySemantics(*operand_semantics); - semantics = AddSemantics(semantics_value); - } + const HloValueSemantics* semantics = AddSemantics(*operand_semantics); ShapeTree semantics_shape_tree( dynamic_slice->shape(), semantics); analysis_->SetHloValueSemantics(dynamic_slice, semantics_shape_tree); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleDynamicUpdateSlice( +absl::Status HloValueSemanticsPropagation::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + RETURN_IF_ALREADY_PROPAGATED(dynamic_update_slice); TF_ASSIGN_OR_RETURN( HloValueSemantics semantics, ComputeSemanticsFromOperands(dynamic_update_slice, {0, 1})); @@ -1235,91 +1799,86 @@ Status HloValueSemanticsPropagation::HandleDynamicUpdateSlice( return OkStatus(); } -Status HloValueSemanticsPropagation::HandleCopyStart( +absl::Status HloValueSemanticsPropagation::HandleCopyStart( HloInstruction* copy_start) { - ShapeTree semantics_shape_tree(copy_start->shape()); + return HandleCollectiveOrCopyStart(copy_start); +} + +absl::Status HloValueSemanticsPropagation::HandleCopyDone( + HloInstruction* copy_done) { + return HandleCollectiveOrCopyDone(copy_done); +} + +absl::Status HloValueSemanticsPropagation::HandleCollectiveOrCopyStart( + HloInstruction* op_start) { + RETURN_IF_ALREADY_PROPAGATED(op_start); + ShapeTree semantics_shape_tree(op_start->shape()); const ShapeTree& operand_semantics_shape_tree = - analysis_->GetInstructionSemantics(copy_start->operand(0)); + analysis_->GetInstructionSemantics(op_start->operand(0)); analysis_->DeepCopyHloValueSemantics(semantics_shape_tree, operand_semantics_shape_tree, {}, {0}); analysis_->DeepCopyHloValueSemantics(semantics_shape_tree, operand_semantics_shape_tree, {}, {1}); semantics_shape_tree.ForEachMutableElement( - [this, copy_start](const ShapeIndex& shape_index, - const HloValueSemantics** semantics) { + [this, op_start](const ShapeIndex& shape_index, + const HloValueSemantics** semantics) { if (shape_index.empty()) { *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kTupleOrToken, {copy_start, shape_index}); + HloValueSemanticLabel::kTupleOrToken, {op_start, {}}); } if (shape_index == ShapeIndex{2}) { *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kRandom, {copy_start, shape_index}); + HloValueSemanticLabel::kRandom, {op_start, shape_index}); } if (shape_index == ShapeIndex{3}) { *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kRandom, {copy_start, shape_index}); + HloValueSemanticLabel::kRandom, {op_start, shape_index}); } }); - analysis_->SetHloValueSemantics(copy_start, semantics_shape_tree); + analysis_->SetHloValueSemantics(op_start, semantics_shape_tree); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleCopyDone(HloInstruction* copy_done) { +absl::Status HloValueSemanticsPropagation::HandleCollectiveOrCopyDone( + HloInstruction* op_done) { + RETURN_IF_ALREADY_PROPAGATED(op_done); const ShapeTree& operand_semantics_shape_tree = - analysis_->GetInstructionSemantics(copy_done->operand(0)); - analysis_->DeepCopyHloValueSemantics(copy_done, operand_semantics_shape_tree, - {0}); + analysis_->GetInstructionSemantics(op_done->operand(0)); + analysis_->DeepCopyHloValueSemantics(op_done, operand_semantics_shape_tree, + {1}); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleCollectivePermuteStart( + +absl::Status HloValueSemanticsPropagation::HandleAllGatherStart( + HloInstruction* all_gather_start) { + return HandleCollectiveOrCopyStart(all_gather_start); +} + +absl::Status HloValueSemanticsPropagation::HandleAllGatherDone( + HloInstruction* all_gather_done) { + return HandleCollectiveOrCopyDone(all_gather_done); +} + +absl::Status HloValueSemanticsPropagation::HandleCollectivePermuteStart( HloInstruction* collective_permute_start) { - ShapeTree semantics_shape_tree( - collective_permute_start->shape()); - const ShapeTree& operand_semantics_shape_tree = - analysis_->GetInstructionSemantics(collective_permute_start->operand(0)); - analysis_->DeepCopyHloValueSemantics(semantics_shape_tree, - operand_semantics_shape_tree, {}, {0}); - analysis_->DeepCopyHloValueSemantics(semantics_shape_tree, - operand_semantics_shape_tree, {}, {1}); - semantics_shape_tree.ForEachMutableElement( - [this, collective_permute_start](const ShapeIndex& shape_index, - const HloValueSemantics** semantics) { - if (shape_index.empty()) { - *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kTupleOrToken, - {collective_permute_start, {}}); - } - if (shape_index == ShapeIndex{2}) { - *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kRandom, - {collective_permute_start, shape_index}); - } - if (shape_index == ShapeIndex{3}) { - *semantics = analysis_->NewHloValueSemantics( - HloValueSemanticLabel::kRandom, - {collective_permute_start, shape_index}); - } - }); - analysis_->SetHloValueSemantics(collective_permute_start, - semantics_shape_tree); - return OkStatus(); + return HandleCollectiveOrCopyStart(collective_permute_start); } -Status HloValueSemanticsPropagation::HandleCollectivePermuteDone( +absl::Status HloValueSemanticsPropagation::HandleCollectivePermuteDone( HloInstruction* collective_permute_done) { - const ShapeTree& operand_semantics_shape_tree = - analysis_->GetInstructionSemantics(collective_permute_done->operand(0)); - analysis_->DeepCopyHloValueSemantics(collective_permute_done, - operand_semantics_shape_tree, {1}); - return OkStatus(); + return HandleCollectiveOrCopyDone(collective_permute_done); } -Status HloValueSemanticsPropagation::HandleGather(HloInstruction* gather) { +absl::Status HloValueSemanticsPropagation::HandleGather( + HloInstruction* gather) { + RETURN_IF_ALREADY_PROPAGATED(gather); const ShapeTree& operand_semantics_shape_tree = analysis_->GetInstructionSemantics(gather->operand(0)); analysis_->DeepCopyHloValueSemantics(gather, operand_semantics_shape_tree); return OkStatus(); } -Status HloValueSemanticsPropagation::HandleScatter(HloInstruction* scatter) { +absl::Status HloValueSemanticsPropagation::HandleScatter( + HloInstruction* scatter) { + RETURN_IF_ALREADY_PROPAGATED(scatter); TF_ASSIGN_OR_RETURN(HloValueSemantics semantics, ComputeSemanticsFromOperands(scatter, {0, 2})); const HloValueSemantics* semantics_ptr = AddSemantics(semantics); @@ -1329,7 +1888,9 @@ Status HloValueSemanticsPropagation::HandleScatter(HloInstruction* scatter) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleAfterAll(HloInstruction* after_all) { +absl::Status HloValueSemanticsPropagation::HandleAfterAll( + HloInstruction* after_all) { + RETURN_IF_ALREADY_PROPAGATED(after_all); const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( HloValueSemanticLabel::kTupleOrToken, {after_all, {}}); ShapeTree semantics_shape_tree(after_all->shape(), @@ -1338,16 +1899,71 @@ Status HloValueSemanticsPropagation::HandleAfterAll(HloInstruction* after_all) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleAsyncStart( +absl::Status HloValueSemanticsPropagation::HandleAllReduce( + HloInstruction* all_reduce) { + RETURN_IF_ALREADY_PROPAGATED(all_reduce); + if (all_reduce->shape().IsArray()) { + return DefaultAction(all_reduce); + } + CHECK(all_reduce->shape().IsTuple()); + return HandleTupleLike(all_reduce); +} + +absl::Status HloValueSemanticsPropagation::HandleAsyncStart( HloInstruction* async_start) { - return Unimplemented("AsyncStart is not supported yet."); + RETURN_IF_ALREADY_PROPAGATED(async_start); + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {async_start, {}}); + ShapeTree semantics_shape_tree(async_start->shape(), + semantics); + for (int operand_index = 0; operand_index < async_start->operand_count(); + ++operand_index) { + HloInstruction* operand = async_start->mutable_operand(operand_index); + const ShapeTree& operand_semantics_tree = + analysis_->GetInstructionSemantics(operand); + analysis_->DeepCopyHloValueSemantics( + semantics_shape_tree, operand_semantics_tree, {}, {0, operand_index}); + } + std::vector operand_indices(async_start->operand_count()); + std::iota(operand_indices.begin(), operand_indices.end(), 0); + TF_ASSIGN_OR_RETURN( + HloValueSemantics output_semantics, + ComputeSemanticsFromOperands(async_start, operand_indices)); + semantics_shape_tree.ForEachMutableElement( + [&output_semantics, &semantics_shape_tree, this, async_start]( + const ShapeIndex& index, const HloValueSemantics** semantics_ptr) { + if (index.empty() || index.front() == 0) { + return; + } + if (!semantics_shape_tree.IsLeaf(index)) { + *semantics_ptr = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {async_start, {}}); + return; + } + if (index.front() == 1) { + *semantics_ptr = AddSemantics(output_semantics); + return; + } + if (index.front() == 2) { + *semantics_ptr = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {async_start, {}}); + } + }); + analysis_->SetHloValueSemantics(async_start, semantics_shape_tree); + return OkStatus(); } -Status HloValueSemanticsPropagation::HandleAsyncDone( +absl::Status HloValueSemanticsPropagation::HandleAsyncDone( HloInstruction* async_done) { - return Unimplemented("AsyncDone is not supported yet."); + RETURN_IF_ALREADY_PROPAGATED(async_done); + const ShapeTree& operand_semantics_tree = + analysis_->GetInstructionSemantics(async_done->operand(0)); + analysis_->DeepCopyHloValueSemantics(async_done, operand_semantics_tree, {1}); + return OkStatus(); } -Status HloValueSemanticsPropagation::HandleInfeed(HloInstruction* infeed) { +absl::Status HloValueSemanticsPropagation::HandleInfeed( + HloInstruction* infeed) { + RETURN_IF_ALREADY_PROPAGATED(infeed); ShapeTree semantics_shape_tree(infeed->shape(), nullptr); semantics_shape_tree.ForEachMutableElement( @@ -1365,7 +1981,20 @@ Status HloValueSemanticsPropagation::HandleInfeed(HloInstruction* infeed) { return OkStatus(); } -Status HloValueSemanticsPropagation::HandleDomain(HloInstruction* domain) { +absl::Status HloValueSemanticsPropagation::HandleOutfeed( + HloInstruction* outfeed) { + RETURN_IF_ALREADY_PROPAGATED(outfeed); + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {outfeed, {}}); + ShapeTree outfeed_semantics_tree(outfeed->shape(), + semantics); + analysis_->SetHloValueSemantics(outfeed, outfeed_semantics_tree); + return OkStatus(); +} + +absl::Status HloValueSemanticsPropagation::HandleDomain( + HloInstruction* domain) { + RETURN_IF_ALREADY_PROPAGATED(domain); HloInstruction* domain_operand = domain->mutable_operand(0); const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(domain_operand); @@ -1373,4 +2002,136 @@ Status HloValueSemanticsPropagation::HandleDomain(HloInstruction* domain) { return OkStatus(); } +absl::Status HloValueSemanticsPropagation::HandleOptimizationBarrier( + HloInstruction* opt_barrier) { + RETURN_IF_ALREADY_PROPAGATED(opt_barrier); + HloInstruction* opt_barrier_operand = opt_barrier->mutable_operand(0); + const ShapeTree& operand_semantics = + analysis_->GetInstructionSemantics(opt_barrier_operand); + analysis_->DeepCopyHloValueSemantics(opt_barrier, operand_semantics); + return OkStatus(); +} + +absl::Status HloValueSemanticsPropagation::HandleSend(HloInstruction* send) { + RETURN_IF_ALREADY_PROPAGATED(send); + ShapeTree semantics_tree(send->shape(), nullptr); + HloInstruction* source_buffer = send->mutable_operand(0); + const ShapeTree& source_buffer_semantics = + analysis_->GetInstructionSemantics(source_buffer); + analysis_->DeepCopyHloValueSemantics(semantics_tree, source_buffer_semantics, + {}, {0}); + + semantics_tree.ForEachMutableElement( + [this, send, &semantics_tree](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty()) { + if (index.front() == 1 && semantics_tree.IsLeaf(index)) { + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {send, index}); + return; + } + if (index.front() == 0) { + return; + } + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {send, index}); + }); + analysis_->SetHloValueSemantics(send, semantics_tree); + return OkStatus(); +} + +absl::Status HloValueSemanticsPropagation::HandleRecv(HloInstruction* recv) { + // Since recv is not a prerequisite of send, we might have not propagated + // semantics to the corresponding send when we reach this recv. So we visit + // the send first before visiting this recv. + // We use RETURN_IF_ALREADY_PROPAGATED to avoid processing an HLO more than + // once. + RETURN_IF_ALREADY_PROPAGATED(recv); + TF_ASSIGN_OR_RETURN(HloInstruction * send, + analysis_->GetMatchingSendOrRecv(recv)); + TF_RETURN_IF_ERROR(send->Accept(this)); + ShapeTree semantics_tree(recv->shape(), nullptr); + const ShapeTree& send_buffer_semantics = + analysis_->GetInstructionSemantics(send); + analysis_->DeepCopyHloValueSemantics(semantics_tree, send_buffer_semantics, + {0}, {0}); + semantics_tree.ForEachMutableElement( + [this, recv, &semantics_tree](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty()) { + if (index.front() == 1 && semantics_tree.IsLeaf(index)) { + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {recv, index}); + return; + } + if (index.front() == 0) { + return; + } + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {recv, index}); + }); + analysis_->SetHloValueSemantics(recv, semantics_tree); + return OkStatus(); +} + +absl::Status HloValueSemanticsPropagation::HandleSendDone( + HloInstruction* send_done) { + RETURN_IF_ALREADY_PROPAGATED(send_done); + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {send_done, {}}); + ShapeTree send_done_semantics_tree( + send_done->shape(), semantics); + analysis_->SetHloValueSemantics(send_done, send_done_semantics_tree); + return OkStatus(); +} +absl::Status HloValueSemanticsPropagation::HandleRecvDone( + HloInstruction* recv_done) { + RETURN_IF_ALREADY_PROPAGATED(recv_done); + ShapeTree semantics_tree(recv_done->shape(), + nullptr); + HloInstruction* recv = recv_done->mutable_operand(0); + const ShapeTree& recv_semantics = + analysis_->GetInstructionSemantics(recv); + analysis_->DeepCopyHloValueSemantics(semantics_tree, recv_semantics, {0}, + {0}); + semantics_tree.ForEachMutableElement( + [this, recv_done](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (!index.empty() && index.front() == 0) { + return; + } + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {recv_done, index}); + }); + analysis_->SetHloValueSemantics(recv_done, semantics_tree); + return OkStatus(); +} + +absl::Status HloValueSemanticsPropagation::HandleTupleLike( + HloInstruction* tuple_like) { + ShapeTree semantics_shape_tree(tuple_like->shape(), + nullptr); + for (int operand_index = 0; operand_index < tuple_like->operand_count(); + ++operand_index) { + const HloInstruction* operand = tuple_like->operand(operand_index); + const ShapeTree& operand_semantics = + analysis_->GetInstructionSemantics(operand); + analysis_->DeepCopyHloValueSemantics( + semantics_shape_tree, operand_semantics, {}, {operand_index}); + } + semantics_shape_tree.ForEachMutableElement( + [tuple_like, this](const ShapeIndex& index, + const HloValueSemantics** semantics) { + if (index.empty()) { + *semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kTupleOrToken, {tuple_like, {}}); + return; + } + }); + analysis_->SetHloValueSemantics(tuple_like, semantics_shape_tree); + return OkStatus(); +} + } // namespace xla diff --git a/xla/service/hlo_value_semantics_analysis.h b/xla/service/hlo_value_semantics_analysis.h index 7037951eddaaf..4cff3c4984f14 100644 --- a/xla/service/hlo_value_semantics_analysis.h +++ b/xla/service/hlo_value_semantics_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,12 +24,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_value.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status.h" @@ -37,6 +39,24 @@ limitations under the License. namespace xla { +struct SendRecvGroup { + HloInstruction* send; + HloInstruction* recv; +}; + +class SendRecvGroupMap { + public: + explicit SendRecvGroupMap(const HloModule& hlo_module); + SendRecvGroupMap(SendRecvGroupMap&& other) = default; + SendRecvGroupMap(const SendRecvGroupMap& other) = default; + virtual ~SendRecvGroupMap() = default; + virtual absl::StatusOr GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const; + + private: + absl::flat_hash_map host_transfer_rendezvous_map_; +}; + class HloPreOrderDFS { public: HloPreOrderDFS() = default; @@ -51,7 +71,7 @@ class HloPreOrderDFS { }; using EinsumDepthMap = - absl::flat_hash_map>; + absl::node_hash_map>; // The einsum depth is the length of the einsum dependency chain. And we // distinguish instructions that are used by root and that are not used by @@ -69,8 +89,9 @@ using EinsumDepthMap = class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { public: - static StatusOr> Run( - const HloComputation& computation); + static absl::StatusOr> Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map); ~EinsumDepthAnalysis() override = default; Status DefaultAction(HloInstruction* instruction) override; Status HandleTuple(HloInstruction* tuple) override; @@ -79,26 +100,89 @@ class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; Status HandleCall(HloInstruction* call) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleAfterAll(HloInstruction* after_all) override; - Status HandleOutfeed(HloInstruction* outfeed) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleAllReduce(HloInstruction* all_reduce) override; const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } private: - EinsumDepthAnalysis() = default; + explicit EinsumDepthAnalysis(const SendRecvGroupMap& send_recv_group_map) + : send_recv_group_map_(&send_recv_group_map) {} Status RunInternal(const HloComputation& computation, const std::optional>& root_depth); - EinsumDepthMap::iterator GetOrCreateDepthTree(HloInstruction* instruction); - Status SetInstructionDepth(HloInstruction* instruction, int depth); - Status SetInstructionDepth(HloInstruction* instruction, + EinsumDepthMap::iterator GetOrCreateDepthTree( + const HloInstruction* instruction); + EinsumDepthMap::iterator GetDepthTreeOrDie(const HloInstruction* instruction); + Status SetInstructionDepth(const HloInstruction* instruction, int depth); + Status SetInstructionDepth(const HloInstruction* instruction, const ShapeTree& depth); + Status SetInstructionDepthFromTupleDepth( + const HloInstruction* instruction, const ShapeTree& tuple_depth_tree, + int tuple_index); Status HandleDepthIncrementInstruction(HloInstruction* instruction); Status HandleCalledComputation(const HloComputation& called_computation, const ShapeTree& root_depth, absl::Span operands); + Status HandleTupleLike(HloInstruction* tuple_like); EinsumDepthMap einsum_depth_map_; + const SendRecvGroupMap* const send_recv_group_map_; +}; + +using EinsumHeightMap = + absl::node_hash_map>; + +// Einsum height is the maximum number of einsums between this instruction and +// any leaf. + +class EinsumHeightAnalysis : public DfsHloVisitorWithDefault { + public: + static absl::StatusOr> Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map); + ~EinsumHeightAnalysis() override = default; + Status DefaultAction(HloInstruction* instruction) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCall(HloInstruction* call) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleAllReduce(HloInstruction* all_reduce) override; + const EinsumHeightMap& GetEinsumHeightMap() const { + return einsum_height_map_; + } + + private: + explicit EinsumHeightAnalysis(const SendRecvGroupMap& send_recv_group_map) + : send_recv_group_map_(&send_recv_group_map) {} + Status RunInternal(const HloComputation& computation, + absl::Span operands); + EinsumHeightMap::iterator GetOrCreateHeightTree( + const HloInstruction* instruction); + EinsumHeightMap::iterator GetHeightTreeOrDie( + const HloInstruction* instruction); + bool HasHeightFor(const HloInstruction* instruction) const; + Status SetInstructionHeight(const HloInstruction* instruction, int height); + Status SetInstructionHeight(const HloInstruction* instruction, + const ShapeTree& height); + Status HandleHeightIncrementInstruction(HloInstruction* instruction); + Status HandleCalledComputation(const HloComputation& computation, + absl::Span operands); + Status HandleTupleLike(HloInstruction* tuple_like); + + EinsumHeightMap einsum_height_map_; + const SendRecvGroupMap* const send_recv_group_map_; }; // The comment below explains where the labels could originate from. Once @@ -144,16 +228,20 @@ class HloValueSemantics { const HloPosition origin_; }; +std::string HloValueSemanticsTreeToString( + const ShapeTree& tree); + using HloValueSemanticsMap = - absl::flat_hash_map>; class HloValueSemanticsPropagation; class HloValueSemanticsAnalysis { public: - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule& module); virtual ~HloValueSemanticsAnalysis() = default; + bool HasSemanticsFor(const HloInstruction* instruction) const; const HloValueSemantics* GetSemantics(const HloInstruction* instruction, const ShapeIndex& index = {}) const; @@ -162,11 +250,28 @@ class HloValueSemanticsAnalysis { } const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } + const EinsumHeightMap& GetEinsumHeightMap() const { + return einsum_height_map_; + } + int GetDepth(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + int GetHeight(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + + const SendRecvGroupMap& GetSendRecvGroupMap() const { + return *send_recv_group_map_; + } + + absl::StatusOr GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const; protected: friend class HloValueSemanticsPropagation; explicit HloValueSemanticsAnalysis(const HloModule& module); - Status InitializeEinsumDepth(); + virtual Status InitializeEinsumDepth(); + virtual Status InitializeEinsumHeight(); + // We match send and recv HLOs to propagate semantics from send to recv. + virtual void InitializeSendRecvGroups(); void AnnotateWeights(); // Infer semantics for all instructions in the computation. Computation @@ -201,6 +306,8 @@ class HloValueSemanticsAnalysis { value_semantics_map_; HloValueSemantics::Id next_id_; EinsumDepthMap einsum_depth_map_; + EinsumHeightMap einsum_height_map_; + std::unique_ptr send_recv_group_map_; }; class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { @@ -229,6 +336,8 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { HloInstruction* dynamic_update_slice) override; Status HandleCopyStart(HloInstruction* copy_start) override; Status HandleCopyDone(HloInstruction* copy_done) override; + Status HandleAllGatherStart(HloInstruction* all_gather_start) override; + Status HandleAllGatherDone(HloInstruction* all_gather_done) override; Status HandleCollectivePermuteStart( HloInstruction* collective_permute_start) override; Status HandleCollectivePermuteDone( @@ -236,11 +345,18 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { Status HandleGather(HloInstruction* gather) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* after_all) override; + Status HandleAllReduce(HloInstruction* all_reduce) override; Status HandleAsyncStart(HloInstruction* async_start) override; Status HandleAsyncDone(HloInstruction* async_done) override; Status HandleInfeed(HloInstruction* infeed) override; + Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleDomain(HloInstruction* domain) override; + Status HandleOptimizationBarrier(HloInstruction* opt_barrier) override; Status HandleRngBitGenerator(HloInstruction* rng_bit_generator) override; + Status HandleSend(HloInstruction* send) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecvDone(HloInstruction* recv_done) override; protected: HloValueSemantics CopySemantics(const HloValueSemantics& semantics) const; @@ -268,35 +384,43 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { bool OriginDependsOn(const HloValueSemantics& semantics, const HloPosition& origin_dependence, bool recursive = false) const; - StatusOr CreateGradientSemantics( - HloInstruction* gradient_candidate) const; - StatusOr ComputeSemanticsFromStaticAndOther( + absl::StatusOr MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const; + absl::StatusOr ComputeSemanticsFromStaticAndOther( const HloValueSemantics& static_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromRandomAndOther( + absl::StatusOr ComputeSemanticsFromRandomAndOther( const HloValueSemantics& random_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromWeightAndOther( + absl::StatusOr ComputeSemanticsFromWeightAndOther( const HloValueSemantics& weight_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromActivationAndOther( + absl::StatusOr ComputeSemanticsFromActivationAndOther( const HloValueSemantics& activation_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromActivationGradientAndOther( + absl::StatusOr + ComputeSemanticsFromActivationGradientAndOther( const HloValueSemantics& activation_gradient_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromWeightGradientAndOther( + absl::StatusOr ComputeSemanticsFromWeightGradientAndOther( const HloValueSemantics& weight_gradient_semantics, const HloValueSemantics& other_semantics, HloInstruction* instruction) const; - StatusOr ComputeSemanticsFromOperands( + absl::StatusOr MergeSemanticsForAnInstruction( + HloInstruction* instruction, + std::vector& semantics_vec) const; + absl::StatusOr ComputeSemanticsFromOperands( HloInstruction* instruction, absl::Span operand_indices, absl::Span operand_shape_indices = {}) const; + Status HandleTupleLike(HloInstruction* tuple_like); + Status HandleCollectiveOrCopyStart(HloInstruction* op_start); + Status HandleCollectiveOrCopyDone(HloInstruction* op_done); HloValueSemanticsAnalysis* analysis_; }; diff --git a/xla/service/hlo_value_semantics_analysis_test.cc b/xla/service/hlo_value_semantics_analysis_test.cc index fd1704f9192e8..aabff9a72183d 100644 --- a/xla/service/hlo_value_semantics_analysis_test.cc +++ b/xla/service/hlo_value_semantics_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -195,6 +195,12 @@ class HloValueSemanticsAnalysisTest : public HloTestBase { return HasLabel(hlo_value_semantics_analysis, module, instruction_name, HloValueSemanticLabel::kWeightGradient); } + bool IsTupleOrToken( + const HloValueSemanticsAnalysis& hlo_value_semantics_analysis, + HloModule* module, absl::string_view instruction_name) { + return HasLabel(hlo_value_semantics_analysis, module, instruction_name, + HloValueSemanticLabel::kTupleOrToken); + } }; TEST_F(HloValueSemanticsAnalysisTest, OneMatmul) { @@ -244,6 +250,41 @@ ENTRY entry { EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "dot.2")); } +TEST_F(HloValueSemanticsAnalysisTest, HandleConditional) { + const std::string module_str = R"( + HloModule Module + + branch0 { + tparam = f32[4] parameter(0) + tgte1 = f32[4] ceil(tparam) + ROOT tuple = (f32[4], f32[4]) tuple(tparam, tgte1) + } + + branch1 { + fparam = f32[4] parameter(0) + %async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo" + %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start) + ROOT tuple = (f32[4], f32[4]) tuple(fparam, %async-done) + } + + ENTRY entry { + p0 = f32[4] parameter(0) + b0 = s32[] parameter(1) + ROOT conditional = (f32[4], f32[4]) conditional(b0, p0, p0), + branch_computations={branch0, branch1} + } +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1, + /*num_partitions=*/2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_value_semantics_analysis, + HloValueSemanticsAnalysis::Run(*module)); + EXPECT_TRUE(IsTupleOrToken(*hlo_value_semantics_analysis, module.get(), + "conditional")); +} + TEST_F(HloValueSemanticsAnalysisTest, TwoMatmuls) { const std::string module_str = R"( HloModule TwoMatmuls @@ -567,7 +608,8 @@ TEST_F(EinsumDepthAnalysisTest, MnistTrainingLoop) { /*num_partitions=*/1)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module->entry_computation())); + EinsumDepthAnalysis::Run(*module->entry_computation(), + SendRecvGroupMap(*module))); const EinsumDepthMap& einsum_depth_map = einsum_depth_analysis->GetEinsumDepthMap(); HloComputation* computation = module->GetComputationWithName("body.49"); @@ -593,7 +635,7 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) { branch1 { fparam = f32[4] parameter(0) %async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo" + ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start) } branch2 { @@ -612,7 +654,8 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) { ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr einsum_depth_analysis, - EinsumDepthAnalysis::Run(*module->entry_computation())); + EinsumDepthAnalysis::Run(*module->entry_computation(), + SendRecvGroupMap(*module))); const EinsumDepthMap& einsum_depth_map = einsum_depth_analysis->GetEinsumDepthMap(); HloComputation* computation = module->GetComputationWithName("entry"); @@ -620,5 +663,62 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) { 0); } +TEST_F(EinsumDepthAnalysisTest, HandleAfterAll) { + const char* const hlo_string = R"( + ENTRY entry { + after-all.1 = token[] after-all() + parameter.1 = f32[] parameter(0) + send.1 = (f32[], u32[], token[]) send(parameter.1, after-all.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"} + send-done.1 = token[] send-done(send.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"} + ROOT after-all.2 = token[] after-all(send-done.1), frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr einsum_depth_analysis, + EinsumDepthAnalysis::Run(*module->entry_computation(), + SendRecvGroupMap(*module))); + const EinsumDepthMap& einsum_depth_map = + einsum_depth_analysis->GetEinsumDepthMap(); + HloComputation* computation = module->GetComputationWithName("entry"); + EXPECT_EQ(GetInstructionDepth(einsum_depth_map, computation, "after-all.2"), + 0); +} + +class EinsumHeightAnalysisTest : public HloTestBase { + public: + int GetInstructionHeight(const EinsumHeightMap& height_map, + HloComputation* computation, + absl::string_view name) { + HloInstruction* instruction = computation->GetInstructionWithName(name); + auto height_iter = height_map.find(instruction); + EXPECT_NE(height_iter, height_map.end()); + return height_iter->second.element({}); + } +}; + +TEST_F(EinsumHeightAnalysisTest, MnistTrainingLoop) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kMnistHlo, + /*replica_count=*/1, + /*num_partitions=*/1)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr einsum_height_analysis, + EinsumHeightAnalysis::Run(*module->entry_computation(), + SendRecvGroupMap(*module))); + const EinsumHeightMap& einsum_height_map = + einsum_height_analysis->GetEinsumHeightMap(); + HloComputation* computation = module->GetComputationWithName("body.49"); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.63"), 1); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.67"), 2); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.71"), 3); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.89"), 4); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.96"), 5); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.92"), 5); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.99"), 6); + EXPECT_EQ(GetInstructionHeight(einsum_height_map, computation, "dot.85"), 4); +} + } // namespace } // namespace xla diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index b78b796f8b985..6de670ace1086 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include +#include +#include #include #include #include @@ -27,10 +29,16 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -39,54 +47,34 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/pattern_matcher.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" +#include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { - -namespace m = match; - namespace { bool IsCallerInstruction(HloInstruction* hlo) { - switch (hlo->opcode()) { - case HloOpcode::kAsyncStart: - case HloOpcode::kAsyncUpdate: - case HloOpcode::kAsyncDone: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kWhile: - case HloOpcode::kAllReduce: - case HloOpcode::kReduceScatter: - case HloOpcode::kAllReduceStart: - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kScatter: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kSort: - case HloOpcode::kTopK: - case HloOpcode::kFusion: - case HloOpcode::kCustomCall: - return true; - default: - return false; - } + return HloInstruction::MightHaveCalledComputations(hlo->opcode()); } Status CheckOperandCount(const HloInstruction* hlo, int expected) { if (hlo->operand_count() != expected) { - return InternalError("Expected %d operands for %s instruction: %s", - expected, HloOpcodeString(hlo->opcode()), - hlo->ToString()); + return Internal("Expected %d operands for %s instruction: %s", expected, + HloOpcodeString(hlo->opcode()), hlo->ToString()); } return OkStatus(); } @@ -125,7 +113,7 @@ Status CheckNestedComputationThreadNameEqual(const HloComputation* comp, } for (const HloComputation* called_cmp : instr->called_computations()) { if (called_cmp->execution_thread() != comp->execution_thread()) { - return InternalError( + return Internal( "Nested computations expects same computation's thread name (%s vs " "%s).", called_cmp->execution_thread(), comp->execution_thread()); @@ -142,7 +130,7 @@ Status CheckNestedComputationThreadNameEqual(const HloComputation* comp, const HloInstruction* calling_instruction, const HloComputation* computation, int expected) { if (computation->num_parameters() != expected) { - return InternalError( + return Internal( "Expected computation %s called from %s to have %d parameters, has %d", computation->name(), calling_instruction->name(), expected, computation->num_parameters()); @@ -152,7 +140,7 @@ Status CheckNestedComputationThreadNameEqual(const HloComputation* comp, Status ShapeVerifier::Preprocess(HloInstruction* hlo) { if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { - return InternalError( + return Internal( "Called computations specified for non-caller instruction %s", hlo->ToString()); } @@ -217,12 +205,15 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + auto sparsity = Cast(dot)->sparsity(); + TF_RETURN_IF_ERROR( + CheckOperandCount(dot, HloDotInstruction::kOperands + sparsity.size())); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), dot->dot_dimension_numbers(), - /*preferred_element_type=*/dot->shape().element_type())); + /*preferred_element_type=*/dot->shape().element_type(), sparsity)); if (auto nibble_count = absl::c_count(dot->precision_config().operand_precision(), PrecisionConfig::PACKED_NIBBLE)) { @@ -244,6 +235,24 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } } } + for (int i = 0; i < sparsity.size(); ++i) { + const SparsityDescriptor& descriptor = sparsity[i]; + TF_RET_CHECK(descriptor.index() == 0 || descriptor.index() == 1); + TF_ASSIGN_OR_RETURN(const Shape expected_metadata_shape, + ShapeInference::InferSparseDotMetadataShape( + dot->operand(descriptor.index())->shape(), + dot->dot_dimension_numbers(), descriptor)); + const Shape actual_metadata_shape = + dot->operand(HloDotInstruction::kOperands + i)->shape(); + if (!ShapeUtil::Compatible(actual_metadata_shape, + expected_metadata_shape)) { + return Internal( + "Expected sparse dot metadata to have shape equal to %s, actual " + "shape is %s:\n%s", + StringifyShape(expected_metadata_shape), + StringifyShape(actual_metadata_shape), dot->ToString()); + } + } return CheckShape(dot, expected); } @@ -323,21 +332,10 @@ Status ShapeVerifier::HandleOptimizationBarrier(HloInstruction* hlo) { } bool ShapeVerifier::ShapesSame(const Shape& a, const Shape& b, - bool minor_to_major_only, - bool ignore_memory_space, bool ignore_tiles) { + Shape::Equal equal) { if (!opts_.layout_sensitive) { return ShapeUtil::Compatible(a, b); } - Shape::Equal equal; - if (ignore_memory_space) { - equal.IgnoreMemorySpaceInLayout(); - } - if (minor_to_major_only) { - equal.MinorToMajorOnlyInLayout(); - } - if (ignore_tiles) { - equal.IgnoreTilesInLayout(); - } return equal(a, b); } @@ -361,13 +359,12 @@ static Status CheckReplicaGroups(HloInstruction* hlo, absl::flat_hash_set replicas_seen; for (const ReplicaGroup& g : hlo->replica_groups()) { if (g.replica_ids().empty()) { - return InternalError( - "Instruction cannot have an empty replica group: %s", - hlo->ToString()); + return Internal("Instruction cannot have an empty replica group: %s", + hlo->ToString()); } for (int64_t i : g.replica_ids()) { if (!replicas_seen.insert(i).second) { - return InternalError( + return Internal( "Replica %d is repeated in instruction's replica-groups: %s", i, hlo->ToString()); } @@ -376,7 +373,7 @@ static Status CheckReplicaGroups(HloInstruction* hlo, size_t n = replicas_seen.size(); for (int64_t i = 0; i < n; ++i) { if (!replicas_seen.count(i)) { - return InternalError( + return Internal( "Replica %d is not named in instruction's replica-groups: %s", i, hlo->ToString()); } @@ -620,7 +617,7 @@ namespace { Status CheckBufferOffset(const Shape& buffer_shape, const Shape& buffer_offset_shape) { if (!buffer_offset_shape.IsTuple()) { - return InternalError("Buffer offset is not tuple."); + return Internal("Buffer offset is not tuple."); } bool all_is_array = absl::c_all_of(buffer_offset_shape.tuple_shapes(), @@ -629,7 +626,7 @@ Status CheckBufferOffset(const Shape& buffer_shape, absl::c_all_of(buffer_offset_shape.tuple_shapes(), [](const Shape& shape) { return shape.IsTuple(); }); if (!all_is_array && !all_is_tuple) { - return InternalError( + return Internal( "Buffer offset should either be a tuple of arrays or " " a tuple of tuples."); } @@ -640,13 +637,13 @@ Status CheckBufferOffset(const Shape& buffer_shape, return ShapeUtil::TupleElementCount(shape) != buffer_shape.rank(); })) { - return InternalError( + return Internal( "Buffer offset index should have the same number of " "elements as the buffer's rank."); } } else { if (buffer_offset_shape.tuple_shapes_size() != buffer_shape.rank()) { - return InternalError( + return Internal( "Buffer offset index should have the same number of " "elements as the buffer's rank."); } @@ -659,8 +656,8 @@ Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) { return OkStatus(); } if (collective_permute->operand_count() != 4) { - return InternalError("Unexpected number of operands: %d.", - collective_permute->operand_count()); + return Internal("Unexpected number of operands: %d.", + collective_permute->operand_count()); } const Shape& input_buffer_shape = collective_permute->operand(0)->shape(); @@ -682,12 +679,12 @@ Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) { } else if (input_buffer_shape.IsTuple() && output_buffer_shape.IsTuple()) { if (ShapeUtil::TupleElementCount(input_buffer_shape) != ShapeUtil::TupleElementCount(output_buffer_shape)) { - return InternalError("Unmatching input buffers and output buffers."); + return Internal("Unmatching input buffers and output buffers."); } if (!input_offset_shape.IsTuple() || ShapeUtil::TupleElementCount(input_offset_shape) != ShapeUtil::TupleElementCount(input_buffer_shape)) { - return InternalError("Unmatching input buffers and input offset."); + return Internal("Unmatching input buffers and input offset."); } for (int i = 0; i < input_buffer_shape.tuple_shapes_size(); ++i) { Status check_input_buffer_offset = @@ -700,7 +697,7 @@ Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) { if (!output_offset_shape.IsTuple() || ShapeUtil::TupleElementCount(output_offset_shape) != ShapeUtil::TupleElementCount(output_buffer_shape)) { - return InternalError("Unmatching output buffers and output offset."); + return Internal("Unmatching output buffers and output offset."); } for (int i = 0; i < output_buffer_shape.tuple_shapes_size(); ++i) { Status check_output_buffer_offset = @@ -711,7 +708,7 @@ Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) { } } } else { - return InternalError("Unmatching input buffers and output buffers."); + return Internal("Unmatching input buffers and output buffers."); } return OkStatus(); } @@ -752,12 +749,12 @@ Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo, if (seen_source_to_targets.contains(p.first) && seen_source_to_targets[p.first].size() == allowed_seen_count) { if (allowed_seen_count == 1) { - return InternalError( + return Internal( "Source %d appears more than once in instruction's source-target " "pairs: %s", p.first, hlo->ToString()); } else { - return InternalError( + return Internal( "Source %d appears more than %d times in instruction's " "source-target " "pairs: %s", @@ -777,12 +774,12 @@ Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo, if (seen_target_to_sources.contains(p.second) && seen_target_to_sources[p.second].size() == allowed_seen_count) { if (allowed_seen_count == 1) { - return InternalError( + return Internal( "Target %d appears more than once in instruction's source-target " "pairs: %s", p.second, hlo->ToString()); } else { - return InternalError( + return Internal( "Target %d appears more than %d times in instruction's " "source-target " "pairs: %s", @@ -797,6 +794,15 @@ Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo, } // namespace +Status ShapeVerifier::HandleCollectiveBroadcast(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape( + hlo, ShapeInference::InferCollectiveBroadcastShape(operand_shapes)); +} + Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { TF_ASSIGN_OR_RETURN( CollectiveOpGroupMode group_mode, @@ -848,7 +854,7 @@ Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, int64_t operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { - return InternalError( + return Internal( "Expected operand %d to be token-shaped, actual shape is " "%s:\n%s", operand_no, StringifyShape(token->shape()), instruction->ToString()); @@ -863,9 +869,9 @@ Status ShapeVerifier::CheckOperandAndParameter( const HloInstruction* parameter = computation->parameter_instruction(parameter_number); if (!ShapesSame(operand->shape(), parameter->shape())) { - return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString(), parameter->ToString(), - instruction->ToString()); + return Internal("Operand %s shape does not match parameter's %s in %s", + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return OkStatus(); } @@ -888,7 +894,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InternalError( + return Internal( "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", StringifyShape(outfeed->operand(0)->shape()), @@ -913,13 +919,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { - return InternalError( + return Internal( "Expected scalar types for the two operands of Rng instruction: %s", instruction->ToString()); } if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { - return InternalError( + return Internal( "Expected compatible element types for the result and the two operands" " of Rng instruction: %s", instruction->ToString()); @@ -931,7 +937,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (!primitive_util::IsFloatingPointType(element_type) && !primitive_util::IsIntegralType(element_type) && element_type != PRED) { - return InternalError( + return Internal( "Element type not supported." " Expected element to be of floating point type, integral type or" " predicate type for RngUniform: %s", @@ -941,14 +947,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { case RNG_NORMAL: if (!primitive_util::IsFloatingPointType(element_type)) { - return InternalError( + return Internal( "Element type not supported." " Expected element to be FloatingPointType for RngNormal: %s", instruction->ToString()); } break; default: - return InternalError( + return Internal( "Invalid Rng distribution %s", RandomDistribution_Name(instruction->random_distribution())); } @@ -961,16 +967,16 @@ Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) { return OkStatus(); } if (hlo->shape().IsTuple() && hlo->shape().tuple_shapes_size() != 2) { - return InternalError( + return Internal( "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s", - hlo->shape().ToString()); + hlo->shape().ToString(true)); } if (!ShapeUtil::Compatible(hlo->operand(0)->shape(), hlo->shape().tuple_shapes(0))) { - return InternalError( + return Internal( "Expected state shape to match between input and output for " "RngBitGenerator. Got %s vs. %s", - hlo->operand(0)->shape().ToString(), + hlo->operand(0)->shape().ToString(true), hlo->shape().tuple_shapes(0).ToString()); } return OkStatus(); @@ -981,7 +987,7 @@ Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) { const Shape& result_shape = instruction->shape(); const Shape expected_shape = ShapeUtil::MakeShape(U64, {2}); if (!ShapeUtil::Compatible(result_shape, expected_shape)) { - return InternalError( + return Internal( "Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ", StringifyShape(expected_shape), StringifyShape(result_shape)); } @@ -1004,15 +1010,15 @@ Status ShapeVerifier::HandleTopK(HloInstruction* hlo) { Status ShapeVerifier::HandleSort(HloInstruction* hlo) { HloSortInstruction* sort = Cast(hlo); if (sort->operand_count() < 1) { - return InternalError("Expected at least 1 operand for %s instruction: %s", - HloOpcodeString(sort->opcode()), sort->ToString()); + return Internal("Expected at least 1 operand for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); } HloComputation* compare = sort->to_apply(); // Check that the 'compare' computation returns a PRED. Shape compare_shape = compare->root_instruction()->shape(); if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { - return InternalError( + return Internal( "The Sort compare computation shape does not lead to a scalar " "predicate shape: %s", StringifyShape(compare_shape)); @@ -1034,7 +1040,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* hlo) { compare->parameter_instruction(parameter_idx)->shape(); if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape, actual_parameter_shape)) { - return InternalError( + return Internal( "Expected the %lld-th parameter of the compare computation of sort " "to have shape %s, but got %s", parameter_idx, StringifyShape(expected_scalar_shape), @@ -1046,7 +1052,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* hlo) { for (int64_t operand = 1; operand < sort->operand_count(); ++operand) { if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), sort->operand(operand)->shape())) { - return InternalError( + return Internal( "Expected sort to have to have the same dimensions for all operands. " "First operand shape is: %s\n, shape (operand index %lld) is: %s", StringifyShape(sort->operand(0)->shape()), operand, @@ -1056,7 +1062,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* hlo) { // Verify the sort_dimension. if (sort->sort_dimension() >= sort->operand(0)->shape().rank()) { - return InternalError( + return Internal( "Expected the sort_dimension %d of sort to be smaller than the rank %d " "of the operand(s).", sort->sort_dimension(), sort->shape().rank()); @@ -1067,8 +1073,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* hlo) { Status ShapeVerifier::HandleConstant(HloInstruction* constant) { if (!Cast(constant)->HasLiteral()) { - return InternalError("Constant is required to have a valid literal: %s", - constant->ToString()); + return Internal("Constant is required to have a valid literal: %s", + constant->ToString()); } return CheckShape(constant, constant->literal().shape(), /*only_compare_minor_to_major_in_layout=*/true); @@ -1077,15 +1083,15 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { Status ShapeVerifier::HandleIota(HloInstruction* hlo) { auto* iota = Cast(hlo); if (!iota->shape().IsArray()) { - return InternalError("Iota does not support non-array result."); + return Internal("Iota does not support non-array result."); } const int64_t rank = iota->shape().rank(); if (rank == 0) { - return InternalError("Iota does not support scalars."); + return Internal("Iota does not support scalars."); } int64_t iota_dimension = iota->iota_dimension(); if (iota_dimension >= rank || iota_dimension < 0) { - return InternalError( + return Internal( "The iota dimension cannot go beyond the operation rank or be " "negative."); } @@ -1130,7 +1136,7 @@ Status SameElementTypesForOperandsAndToApplyParameters( Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { if (reduce->operand_count() % 2 != 0) { - return InternalError( + return Internal( "Expected an even number of operands for %s instruction: %s", HloOpcodeString(reduce->opcode()), reduce->ToString()); } @@ -1161,12 +1167,13 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { !(output_shape.is_static() && operand_shape.is_static() && (ShapeUtil::ArrayDataSize(output_shape) == ShapeUtil::ArrayDataSize(operand_shape)))) { - return InternalError( - "Bitcast cannot have different shape sizes of output (%d) and " + return Internal( + "%s: Bitcast cannot have different shape sizes of output (%d) and " "operand " "(%d) (%s) (%s)", - opts_.shape_size(output_shape), opts_.shape_size(operand_shape), - output_shape.ToString(true), operand_shape.ToString(true)); + bitcast->ToString(), opts_.shape_size(output_shape), + opts_.shape_size(operand_shape), output_shape.ToString(true), + operand_shape.ToString(true)); } } return OkStatus(); @@ -1177,8 +1184,10 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. - TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); - TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); + TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)) + << broadcast->ToString(); + TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()) + << broadcast->ToString(); for (int64_t operand_dimension = 0; operand_dimension < operand_shape.rank(); ++operand_dimension) { int64_t output_dimension = broadcast->dimensions()[operand_dimension]; @@ -1226,21 +1235,20 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { if (fusion->called_computations().size() != 1) { - return InternalError( - "Fusion has a non-unary number of called computations (%s)", - fusion->ToString().c_str()); + return Internal("Fusion has a non-unary number of called computations (%s)", + fusion->ToString().c_str()); } const Shape& root_computation_shape = fusion->called_computations()[0]->root_instruction()->shape(); if (!ShapesSame(fusion->shape(), root_computation_shape)) { - return InternalError( + return Internal( "Fused computation shape (%s) is not equal to the fusion shape (%s)", root_computation_shape.ToString(true), fusion->shape().ToString(true)); } auto& fused_parameters = fusion->fused_parameters(); if (fused_parameters.size() != fusion->operand_count()) { - return InternalError( + return Internal( "Fused parameter count (%d) does not match the number of operands (%d)" " passed to the fusion instruction in: %s.", fused_parameters.size(), fusion->operand_count(), @@ -1249,7 +1257,7 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { for (HloInstruction* fused_param : fused_parameters) { int64_t param_no = fused_param->parameter_number(); if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { - return InternalError( + return Internal( "Shape mismatch between parameter number %d and its operand in " "%s.", param_no, fusion->ToString().c_str()); @@ -1308,7 +1316,7 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), operand_shape_with_layout)) - << custom_call->operand(i)->shape().ToString() << " operand " + << custom_call->operand(i)->shape().ToString(true) << " operand " << operand_shape_with_layout.ToString(); TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } @@ -1430,7 +1438,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { xla_while->while_condition()->root_instruction()->shape(); if (!ShapeUtil::Compatible(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { - return InternalError( + return Internal( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", StringifyShape(conditional_shape)); @@ -1484,7 +1492,7 @@ Status CheckAsyncOpOperand(const HloInstruction* async_op) { const HloInstruction* operand = async_op->operand(0); if (operand->opcode() != HloOpcode::kAsyncStart && operand->opcode() != HloOpcode::kAsyncUpdate) { - return InternalError( + return Internal( "%s expects operand to be async-update or async-done, found " "%s.", HloOpcodeString(async_op->opcode()), @@ -1492,7 +1500,7 @@ Status CheckAsyncOpOperand(const HloInstruction* async_op) { } if (*async_op->async_wrapped_computation() != *operand->async_wrapped_computation()) { - return InternalError( + return Internal( "The %s expects its wrapped async computation to be identical to its " "operand's wrapped async computation (%s vs %s), thread name (%s vs " "%s).", @@ -1502,15 +1510,6 @@ Status CheckAsyncOpOperand(const HloInstruction* async_op) { async_op->async_wrapped_computation()->execution_thread(), operand->async_wrapped_computation()->execution_thread()); } - if (async_op->async_group_id() != operand->async_group_id()) { - return InternalError( - "%s expects its operand to have the same group id (%s vs %s).", - HloOpcodeString(async_op->opcode()), - async_op->async_group_id() ? absl::StrCat(*async_op->async_group_id()) - : "none", - operand->async_group_id() ? absl::StrCat(*operand->async_group_id()) - : "none"); - } return OkStatus(); } @@ -1518,7 +1517,7 @@ Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) { absl::string_view async_execution_thread = async_op->async_execution_thread(); if (async_execution_thread != async_op->async_wrapped_computation()->execution_thread()) { - return InternalError( + return Internal( "%s expects same async thread name as wrapped computation's " "thread name (%s vs %s).", HloOpcodeString(async_op->opcode()), async_execution_thread, @@ -1535,7 +1534,7 @@ Status CheckCallableInstructionThreadName(const HloInstruction* instruction, if (instruction->parent() != nullptr) { if (instruction->parent()->execution_thread() != computation->execution_thread()) { - return InternalError( + return Internal( "callable instruction %s expects parent computation thread name " "same as called computation's thread name (%s vs %s).", instruction->ToString(), instruction->parent()->execution_thread(), @@ -1552,7 +1551,7 @@ Status CheckCallableInstructionThreadName(const HloInstruction* instruction, Status ShapeVerifier::CheckAsyncOpComputationShapes( const HloInstruction* async_op, const Shape& async_shape) { if (!async_shape.IsTuple() || async_shape.tuple_shapes_size() < 2) { - return InternalError( + return Internal( "The %s expects the async shape to be a tuple of at least two " "elements, found %s.", HloOpcodeString(async_op->opcode()), async_shape.ToString()); @@ -1561,7 +1560,7 @@ Status ShapeVerifier::CheckAsyncOpComputationShapes( async_op->async_wrapped_computation()->ComputeProgramShape(); Shape param_shape = ShapeUtil::MakeTupleShape(computation_shape.parameters()); if (!ShapesSame(async_shape.tuple_shapes(0), param_shape)) { - return InternalError( + return Internal( "The %s expects the async shape at index {0} to match async " "computation parameter shape (%s vs %s).", HloOpcodeString(async_op->opcode()), @@ -1569,7 +1568,7 @@ Status ShapeVerifier::CheckAsyncOpComputationShapes( param_shape.ToString(/*print_layout=*/true)); } if (!ShapesSame(async_shape.tuple_shapes(1), computation_shape.result())) { - return InternalError( + return Internal( "The %s expects the async shape at index {1} to match the async " "computation root shape (%s vs %s).", HloOpcodeString(async_op->opcode()), @@ -1587,7 +1586,7 @@ Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { for (int i = 0; i < async_start->operand_count(); ++i) { if (!ShapesSame(param_shape.tuple_shapes(i), async_start->operand(i)->shape())) { - return InternalError( + return Internal( "The %s expects the shape of operand %d to match the async shape at " "index {0} (%s vs %s).", HloOpcodeString(async_start->opcode()), i, @@ -1601,11 +1600,11 @@ Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) { TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update)); if (!ShapesSame(async_update->operand(0)->shape(), async_update->shape())) { - return InternalError( + return Internal( "The %s expects the shape of operand and output to match (%s vs %s).", HloOpcodeString(async_update->opcode()), - async_update->operand(0)->shape().ToString(), - async_update->shape().ToString()); + async_update->operand(0)->shape().ToString(true), + async_update->shape().ToString(true)); } TF_RETURN_IF_ERROR( CheckAsyncOpComputationShapes(async_update, async_update->shape())); @@ -1618,11 +1617,11 @@ Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) { async_done, async_done->operand(0)->shape())); const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1); if (!ShapesSame(root_shape, async_done->shape())) { - return InternalError( + return Internal( "The %s expects the shape of output to match the async shape at index " "{1} (%s vs %s).", - HloOpcodeString(async_done->opcode()), async_done->shape().ToString(), - root_shape.ToString()); + HloOpcodeString(async_done->opcode()), + async_done->shape().ToString(true), root_shape.ToString(true)); } return CheckAsyncOpOperand(async_done); } @@ -1640,9 +1639,8 @@ Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) { const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0); const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1); if (!ShapesSame(dest_shape, src_shape, - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/true)) { - return InternalError( + Shape::Equal().IgnoreMemorySpaceInLayout())) { + return Internal( "Source and destination buffers in CopyDone arguments need to be the " "same shape found %s and %s\n%s", StringifyShape(dest_shape), StringifyShape(src_shape), @@ -1761,14 +1759,15 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { for (auto operand : instruction->operands()) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( operand->shape(), - [&](const Shape& subshape, const ShapeIndex& index) { + [&](const Shape& subshape, + const ShapeIndex& index) -> absl::Status { if (!ShapeUtil::ElementIsFloating(subshape)) { return OkStatus(); } if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return InternalError( + return Internal( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString()); @@ -1866,18 +1865,27 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kTuple: - case HloOpcode::kWhile: - return ShapesSame(instruction->shape(), inferred_shape, - only_compare_minor_to_major_in_layout); - case HloOpcode::kDynamicUpdateSlice: - // For DynamicUpdateSlice it has an "in-place" update semantics, but - // inside of fusions memory space propagation doesn't propagate the - // memory spaces all the way, causing possible mismatches. Relax the - // constraint in that condition. - return ShapesSame(instruction->shape(), inferred_shape, - only_compare_minor_to_major_in_layout, - /*ignore_memory_space=*/ - instruction->parent()->IsFusionComputation()); + case HloOpcode::kWhile: { + Shape::Equal equal; + if (only_compare_minor_to_major_in_layout) { + equal.MinorToMajorOnlyInLayout(); + } + return ShapesSame(instruction->shape(), inferred_shape, equal); + } + case HloOpcode::kDynamicUpdateSlice: { + Shape::Equal equal; + if (only_compare_minor_to_major_in_layout) { + equal.MinorToMajorOnlyInLayout(); + } + if (instruction->parent()->IsFusionComputation()) { + // For DynamicUpdateSlice it has an "in-place" update semantics, but + // inside of fusions memory space propagation doesn't propagate the + // memory spaces all the way, causing possible mismatches. Relax the + // constraint in that condition. + equal.IgnoreMemorySpaceInLayout(); + } + return ShapesSame(instruction->shape(), inferred_shape, equal); + } // We allow arbitrary layout and f32->bf16 transformations on all other // instructions, although this may be made more strict pending discussion @@ -1892,7 +1900,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } }(); if (!equal) { - return InternalError( + return Internal( "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", StringifyShape(inferred_shape), StringifyShape(instruction->shape()), @@ -1901,8 +1909,9 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, return OkStatus(); } -Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const StatusOr& inferred_shape_status) { +Status ShapeVerifier::CheckShape( + const HloInstruction* instruction, + const absl::StatusOr& inferred_shape_status) { if (!inferred_shape_status.ok()) { Status s = inferred_shape_status.status(); tsl::errors::AppendToMessage(&s, ", for instruction ", @@ -1950,9 +1959,11 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { // let's not check that. if (!ShapesSame(computation->root_instruction()->shape(), result_layout.shape(), - /*minor_to_major_only=*/false, /*ignore_memory_space=*/false, - /*ignore_tiles=*/true)) { - return InternalError( + Shape::Equal() + .IgnoreTilesInLayout() + .IgnoreTailPaddingAlignmentInElements() + .IgnoreMemorySpaceInLayout())) { + return Internal( "Shape of the root instruction of entry computation (%s) should be " "compatible to one specified in module's entry computation layout (%s)", StringifyShape(computation->root_instruction()->shape()), @@ -1960,7 +1971,7 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { } if (computation->num_parameters() != layout.parameter_count()) { - return InternalError( + return Internal( "Number of parameters in entry computation layout (%d) must be same " "as number of parameters of entry computation (%d)", layout.parameter_count(), computation->num_parameters()); @@ -1973,10 +1984,11 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { // TPU layout assignment doesn't set the tiles on entry_computation_layout, // so let's not check that. if (!ShapesSame(parameter->shape(), layout.parameter_shape(i), - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/false, - /*ignore_tiles=*/true)) { - return InternalError( + Shape::Equal() + .IgnoreTilesInLayout() + .IgnoreTailPaddingAlignmentInElements() + .IgnoreMemorySpaceInLayout())) { + return Internal( "Shape of the entry computation parameter %d is %s should be " "compatible to the one specified in module's entry computation " "layout %s", @@ -2010,22 +2022,21 @@ std::string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return InternalError("Computation %s has a null parent pointer", - computation->name()); + return Internal("Computation %s has a null parent pointer", + computation->name()); } if (computation->parent() != module) { - return InternalError( - "Computation %s parent() does not point to parent module", - computation->name()); + return Internal("Computation %s parent() does not point to parent module", + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return InternalError("Instruction %s has a null parent pointer", - instruction->name()); + return Internal("Instruction %s has a null parent pointer", + instruction->name()); } if (instruction->parent() != computation) { - return InternalError( + return Internal( "Instruction %s parent() does not point to parent computation", instruction->name()); } @@ -2041,7 +2052,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return InternalError( + return Internal( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name(), instruction->name(), @@ -2068,28 +2079,11 @@ bool ShapeContainsToken(const Shape& shape) { return contains_token; } -// Verifies that all types entering and exiting the entry computation are -// legal. -Status VerifyEntryAndExitShapes(const HloModule& module) { - // Tokens cannot be passed as entry parameters. - // TODO(b/80000000): Remove this constraint. - for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { - HloInstruction* param = - module.entry_computation()->parameter_instruction(i); - if (ShapeContainsToken(param->shape())) { - return InternalError( - "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape())); - } - } - return OkStatus(); -} - // Checks if the given two instructions share the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( + return Internal( "Expected to have the same channel id, actual channel ids are: %s " "(%d), %s (%d)", instr1->ToString(), *instr1->channel_id(), instr2->ToString(), @@ -2099,7 +2093,7 @@ Status CheckSameChannel(const HloInstruction* instr1, } // Checks if the given two instructions have the same is_host_transfer -// attribute value. Intsructions must be send/recv instructions or their +// attribute value. Instructions must be send/recv instructions or their // 'done' variant. Status CheckSameIsHostTransfer(const HloInstruction* instr1, const HloInstruction* instr2) { @@ -2110,7 +2104,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, TF_RET_CHECK(send_recv1 != nullptr); TF_RET_CHECK(send_recv2 != nullptr); if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { - return InternalError( + return Internal( "Expected instructions to have the same is-host-transfer property: " "%s, " "%s ", @@ -2261,8 +2255,40 @@ Status VerifyChannels(const HloModule& module) { absl::flat_hash_map> channel_instructions; - // Send/Recv instruction must have a single user: the corresponding - // SendDone/RecvDone. with matching channel. + // For Async operations, we need to make sure: + // (1) AsyncStart and AsyncDone are used in pairs + // (2) AsynStart and Asyndone are connected, that is, an AsynDone has an + // AsyncStart as its only operand, and an AsynStart has an AsyncDone as + // its only user + // (3) the channel ID used by a pair of Async operations is unique + // + // Send and SendDone, Recv and RecvDone are such pairs of Async operations. + // Different from other Async operations, a channel ID can be used by one + // Send-SendDone pair and one Recv-RecvDone pair. As such, we verify the + // above three invariants for Send/Recv related instructions with adjustment + // to (3): + // (3*) the channel ID used by a pair of Send-SendDone can be shared by at + // most one pair of Recv-RecvDone. + // + // Currently, the GPU compiler can decomposed collective-permute into a group + // of instructions with a pair of Send-SendDone and a pair of Recv-RecvDone + // that use the same channel ID. When a while-body contains such instructions, + // the GPU compiler can also peel off Send and Recv, and statically order + // SendDone/RecvDone inside the while-body before Send/Recv. This breaks + // invariants (2) and (3*) for the pipelined Send/Recv case. We verify the + // following for a group of instructions using the same channel ID but don't + // satisfy invariants (1)(2)(3*): + // (4) All instructions in the group are annotated with frontend attributes. + // We avoid verifying the content of such a frontend attribute to avoid + // making the general HLO instruction verifier depend on the compiler pass + // that performs the transformation. + // (5) the group should contain equal number uses of each Send/Recv related + // instructions. + // + // Comparing the verification of unpipelined Send/Recv with the verification + // of pipelined, what we missing verifying is that the direct connection + // between Send/Recv and SendDone/RecvDone through operands. + // for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { auto channel_instr = DynCast(instruction); @@ -2273,29 +2299,55 @@ Status VerifyChannels(const HloModule& module) { switch (instruction->opcode()) { case HloOpcode::kSend: { - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* send_done = instruction->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + bool pipelined = true; + if (instruction->users().size() == 1) { + const HloInstruction* send_user = instruction->users().front(); + if (send_user->opcode() == HloOpcode::kSendDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_user)); + TF_RETURN_IF_ERROR( + CheckSameIsHostTransfer(instruction, send_user)); + pipelined = false; + } + } + // Pipelined Send should be annotated with frontend attributes. + TF_RET_CHECK(pipelined == false || + !instruction->frontend_attributes().map().empty()); break; } case HloOpcode::kRecv: { - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + bool pipelined = true; + if (instruction->users().size() == 1) { + const HloInstruction* recv_user = instruction->users().front(); + if (recv_user->opcode() == HloOpcode::kRecvDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_user)); + TF_RETURN_IF_ERROR( + CheckSameIsHostTransfer(instruction, recv_user)); + pipelined = false; + } + } + // Pipelined Recv should be annotated with frontend attributes. + TF_RET_CHECK(pipelined == false || + !instruction->frontend_attributes().map().empty()); break; } - case HloOpcode::kSendDone: + case HloOpcode::kSendDone: { TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + const HloInstruction* send_done_operand = instruction->operand(0); + // If the operand is not a Send, the Send-done is pipelined and should + // have frontend attributes. + TF_RET_CHECK(send_done_operand->opcode() == HloOpcode::kSend || + !instruction->frontend_attributes().map().empty()); break; - case HloOpcode::kRecvDone: + } + case HloOpcode::kRecvDone: { TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + const HloInstruction* recv_done_operand = instruction->operand(0); + // If the operand is not a Recv, the Recv-done is pipelined and should + // have frontend attributes. + TF_RET_CHECK(recv_done_operand->opcode() == HloOpcode::kRecv || + !instruction->frontend_attributes().map().empty()); break; + } default: break; } @@ -2308,33 +2360,50 @@ Status VerifyChannels(const HloModule& module) { const HloInstruction* first = instructions[0]; auto sendrecv = DynCast(first); if (sendrecv) { - absl::flat_hash_set opcodes; - bool maybe_send_recv_pipeline = false; + // Check that all instructions are Send/Recv related and count the + // appearance of each opcode in the group. + absl::flat_hash_map opcode_to_count; for (const HloInstruction* instr : instructions) { - if (opcodes.insert(instr->opcode()).second == false) { - // A channel is used by multiple instructions with the same opcode. - // This is only allows for pipelining Send and Recv, assuming such - // instructions have non-empty frontend attributes. - if (DynCast(instr) || - DynCast(instr)) { - maybe_send_recv_pipeline = - (!instr->frontend_attributes().map().empty()); - } + auto it = opcode_to_count.find(instr->opcode()); + if (it != opcode_to_count.end()) { + it->second++; + } else { + opcode_to_count[instr->opcode()] = 1; } - auto cast = DynCast(instr); - TF_RET_CHECK(cast != nullptr) + TF_RET_CHECK(DynCast(instr) != nullptr) << "channel " << pair.first << " is used for different types of channel instructions"; } + + int count = opcode_to_count.begin()->second; + bool consistent_count = + absl::c_all_of(opcode_to_count, [count](const auto& opcode_count) { + return opcode_count.second == count; + }); + // A pipelined group of Send/Recv should all have frontend attributes. + bool maybe_pipelined = + absl::c_all_of(instructions, [](const HloInstruction* inst) { + return !inst->frontend_attributes().map().empty(); + }); + if (sendrecv->is_host_transfer()) { - TF_RET_CHECK(instructions.size() == 2) + TF_RET_CHECK(consistent_count && count == 1 && instructions.size() == 2) << "channel " << pair.first << " is used for multiple host send/recv instructions"; } else { - if (!maybe_send_recv_pipeline) { - TF_RET_CHECK(instructions.size() == opcodes.size()) + if (consistent_count && count == 1) { + TF_RET_CHECK(instructions.size() == opcode_to_count.size()) << "channel " << pair.first << " is used for multiple send/recv instructions"; + } else { + TF_RET_CHECK(maybe_pipelined) << "channel " << pair.first + << " is used for multiple send/recv " + "instructions but not pipelined"; + TF_RET_CHECK(consistent_count && opcode_to_count.size() % 2 == 0) + << "channel " << pair.first + << " is pipelined. Not all Send/Recv related instructions are" + " used the same number of times or channel is used for other " + "instructions"; } } } else { @@ -2354,7 +2423,7 @@ Status CheckFusionInstruction(HloInstruction* fusion) { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return InternalError( + return Internal( "Instruction of fused computation does not match expected " "instruction " "%s.", @@ -2370,36 +2439,35 @@ Status CheckFusionInstruction(HloInstruction* fusion) { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return InternalError("Root appears more than once in %s.", - fusion->ToString()); + return Internal("Root appears more than once in %s.", + fusion->ToString()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return InternalError("Parameter appears more than once in %s.", - fusion->ToString()); + return Internal("Parameter appears more than once in %s.", + fusion->ToString()); } parameter_owned[i] = true; } } } if (!root_owned) { - return InternalError("Root not found in computation of %s.", - fusion->ToString()); + return Internal("Root not found in computation of %s.", fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString()); + return Internal("Parameter %d not found in computation of %s.", i, + fusion->ToString()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", fusion->ToString()); + return Internal("Root of %s may not have users.", fusion->ToString()); } // All uses of fused instructions must be in the fusion computation, and @@ -2408,12 +2476,12 @@ Status CheckFusionInstruction(HloInstruction* fusion) { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString(), fusion->ToString()); + return Internal("Non-root instruction %s in %s must have users.", + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return InternalError( + return Internal( "Non-root instruction %s in %s may not have external users.", instruction->ToString(), fusion->ToString()); } @@ -2428,17 +2496,17 @@ Status CheckFusionInstruction(HloInstruction* fusion) { for (auto fused_param : fused_parameters) { int64_t param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %d in %s.", - param_no, fusion->ToString()); + return Internal("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { - return InternalError( + return Internal( "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return InternalError( + return Internal( "Did not expect parameter number %d more than once in %s.", param_no, fusion->ToString()); } @@ -2447,8 +2515,8 @@ Status CheckFusionInstruction(HloInstruction* fusion) { // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString()); + return Internal("Did not see parameter number %d in %s.", i, + fusion->ToString()); } } @@ -2738,10 +2806,18 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { operand_shape.rank() == result_shape.rank() && operand_shape.has_layout()) { const Layout& operand_layout = operand_shape.layout(); - Layout::Equal equal_predicate = Layout::Equal(); + Layout::Equal equal_predicate = + Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); if (instruction->opcode() == HloOpcode::kConvert) { // Convert instructions can change element_size_in_bits equal_predicate.IgnoreElementSize(); + } else if (instruction->opcode() == HloOpcode::kDynamicSlice || + instruction->opcode() == HloOpcode::kDynamicUpdateSlice || + instruction->opcode() == HloOpcode::kCopy) { + TF_RETURN_IF_ERROR(HostOffloadInstructionCanChangeMemorySpace( + instruction, operand_layout.memory_space(), + result_layout.memory_space())); + equal_predicate.IgnoreMemorySpace(); } TF_RET_CHECK(equal_predicate(result_layout, operand_layout)) << "Instruction shouldn't change layouts " @@ -2775,6 +2851,39 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return OkStatus(); } + // Verifies whether a given `instruction` is permitted to change the layout + // memory space from `operand_memory_space` to `result_memory_space`. + // Returns OkStatus() if the instruction's layout changes are valid; + // otherwise, returns an appropriate error status. + static Status HostOffloadInstructionCanChangeMemorySpace( + const HloInstruction* instruction, const int64_t operand_memory_space, + const int64_t result_memory_space) { + TF_RET_CHECK(!(operand_memory_space == Layout::kGenericFastMemorySpace && + result_memory_space != Layout::kGenericFastMemorySpace) || + (operand_memory_space != Layout::kGenericFastMemorySpace && + result_memory_space == Layout::kGenericFastMemorySpace)) + << "Instruction shouldn't change layout memory space between generic " + "fast memory space and others for instruction: " + << instruction->ToString(); + + if (instruction->opcode() == HloOpcode::kDynamicSlice) { + TF_RET_CHECK(!(operand_memory_space == Layout::kDefaultMemorySpace && + result_memory_space == Layout::kHostMemorySpace)) + << "DynamicSlice instruction shouldn't change layout memory " + << "space from device to host: " << instruction->ToString(); + } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + TF_RET_CHECK(!(operand_memory_space == Layout::kHostMemorySpace && + result_memory_space == Layout::kDefaultMemorySpace)) + << "DynamicUpdateSlice instruction shouldn't change layout " + << "memory space from host to device: " << instruction->ToString(); + } else if (instruction->opcode() != HloOpcode::kCopy) { + return absl::InvalidArgumentError( + absl::StrCat("Instruction shouldn't change layout memory space: ", + instruction->ToString())); + } + return OkStatus(); + } + absl::flat_hash_map instructions_by_name_; const HloVerifierOpts& opts_; std::optional num_devices_; @@ -2782,14 +2891,14 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } // namespace -StatusOr HloVerifier::Run( +absl::StatusOr HloVerifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { auto disabled = module->config().debug_options().xla_disable_hlo_passes(); if (std::find(disabled.begin(), disabled.end(), name()) != disabled.end()) { return false; } - auto status_or_changed = [&]() -> StatusOr { + auto status_or_changed = [&]() -> absl::StatusOr { TF_RET_CHECK(!module->name().empty()); if (module->entry_computation()->IsFusionComputation()) { @@ -2814,7 +2923,6 @@ StatusOr HloVerifier::Run( } TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module)); - TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); // If the module has a schedule, it must be valid. if (module->has_schedule()) { @@ -2894,12 +3002,6 @@ void MetadataTracker::HandleMetadata(const OpMetadata& metadata) { if (metadata.source_line() != 0) { ++has_source_line_count_; } - if (metadata.creation_pass_id() != 0) { - ++has_creation_pass_id_count_; - } - if (metadata.logical_creation_pass_id() != 0) { - ++has_logical_creation_pass_id_count_; - } if (metadata.size_of_generated_code_in_bytes() != 0) { ++has_size_of_generated_code_in_bytes_count_; } diff --git a/xla/service/hlo_verifier.h b/xla/service/hlo_verifier.h index 29744af10982b..91806e35976a0 100644 --- a/xla/service/hlo_verifier.h +++ b/xla/service/hlo_verifier.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -181,6 +181,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleAllReduceStart(HloInstruction* hlo) override; Status HandleAllReduceDone(HloInstruction* hlo) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectiveBroadcast(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleCollectivePermuteStart(HloInstruction* hlo) override; Status HandleCollectivePermuteDone(HloInstruction* hlo) override; @@ -244,9 +245,7 @@ class ShapeVerifier : public DfsHloVisitor { protected: // Helpers that switch on layout_sensitive_. - bool ShapesSame(const Shape& a, const Shape& b, - bool minor_to_major_only = false, - bool ignore_memory_space = false, bool ignore_tiles = false); + bool ShapesSame(const Shape& a, const Shape& b, Shape::Equal equal = {}); // Check the instruction's shape against the shape given by ShapeInference // and return an appropriate error if there is a mismatch. @@ -256,7 +255,7 @@ class ShapeVerifier : public DfsHloVisitor { // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, - const StatusOr& inferred_shape_status); + const absl::StatusOr& inferred_shape_status); static Status CheckParameterCount(const HloInstruction* calling_instruction, const HloComputation* computation, @@ -269,19 +268,6 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, - bool minor_to_major_only = false) { - if (!opts_.layout_sensitive) { - return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); - } - Shape::Equal equal; - if (minor_to_major_only) { - equal.MinorToMajorOnlyInLayout(); - } - equal.IgnoreFpPrecision(); - return equal(a, b); - } - std::string StringifyShape(const Shape& s) { return opts_.layout_sensitive ? ShapeUtil::HumanStringWithLayout(s) : ShapeUtil::HumanString(s); @@ -390,7 +376,7 @@ class HloVerifier : public HloModulePass { // Never returns true; no instructions are ever modified by this pass. using HloPassInterface::Run; using HloPassInterface::RunOnModuleGroup; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 7d9a5a93470b6..15dd3208efb25 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -330,7 +330,7 @@ TEST_F(HloVerifierTest, CheckConditionalBranchContainsAsyncThread) { branch1 { fparam = f32[4] parameter(0) %async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo" + ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start) } branch2 { @@ -856,7 +856,7 @@ TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncDone) { ENTRY AsyncStartAndAsyncDone { p0 = f32[2,3]{1,0:S(1)} parameter(0) async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo" - ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-start), custom_call_target="foo" + ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -873,9 +873,9 @@ TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncUpdateAndAsyncDone) { ENTRY AsyncStartAndAsyncUpdateAndAsyncDone { p0 = f32[2,3]{1,0:S(1)} parameter(0) async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo" - async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo" - async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), custom_call_target="foo" - ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), custom_call_target="foo" + async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start) + async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1) + ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -893,9 +893,9 @@ TEST_F(HloVerifierTestLayoutSensitive, ENTRY AsyncStartAndAsyncUpdateAndAsyncDone { p0 = f32[2,3]{1,0:S(1)} parameter(0) async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo" - async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_execution_thread="parallel_thread", custom_call_target="foo" - async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_execution_thread="parallel_thread", custom_call_target="foo" + async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start) + async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1) + ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -912,7 +912,7 @@ TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongType) { ENTRY AsyncStartAndAsyncDone { p0 = f32[2,3] parameter(0) async-start = ((f32[2,3]), f32[3,2], u32[]) custom-call-start(p0), custom_call_target="foo" - ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="foo" + ROOT async-done = f32[2,3] custom-call-done(async-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -921,47 +921,8 @@ TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongType) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - HasSubstr("async-done expects the async shape at index {1} to " - "match the async computation root shape")); -} - -TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongThreadName) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY AsyncStartAndAsyncDone { - p0 = f32[2,3] parameter(0) - async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo" - ROOT async-done = f32[2,3] custom-call-done(async-start), async_execution_thread="main_thread", custom_call_target="bar" - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("thread name (main_thread vs parallel_thread).")); -} - -TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongAttr) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY AsyncStartAndAsyncDone { - p0 = f32[2,3] parameter(0) - async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" - ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="bar" - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("async-done expects its wrapped async computation to " - "be identical to its operand's")); + HasSubstr("async-done expects the shape of output to match the " + "async shape at index {1}")); } TEST_F(HloVerifierTest, AsyncStartMultipleAsyncDone) { @@ -971,8 +932,8 @@ TEST_F(HloVerifierTest, AsyncStartMultipleAsyncDone) { ENTRY AsyncStartAndAsyncDone { p0 = f32[2,3] parameter(0) async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" - async-done.1 = f32[2,3] custom-call-done(async-start), custom_call_target="foo" - async-done.2 = f32[2,3] custom-call-done(async-start), custom_call_target="foo" + async-done.1 = f32[2,3] custom-call-done(async-start) + async-done.2 = f32[2,3] custom-call-done(async-start) ROOT tuple = (f32[2,3], f32[2,3]) tuple(async-done.1, async-done.2) } )"; @@ -1012,7 +973,7 @@ TEST_F(HloVerifierTest, AsyncStartAndAsyncUpdateNoAsyncDone) { ENTRY AsyncStartAndAsyncDone { p0 = f32[2,3] parameter(0) async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" - ROOT async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo" + ROOT async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -1029,16 +990,27 @@ TEST_F(HloVerifierTest, AsyncDoneNoAsyncStart) { const char* const hlo_string = R"( HloModule Module - ENTRY AsyncStartAndAsyncDone { + ENTRY AsyncDoneNoAsyncStart { p0 = f32[2,3] parameter(0) p1 = u32[] parameter(1) tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1) - ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo" + async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" + ROOT async-done = f32[2,3] custom-call-done(async-start) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); + // The parser checks that the async-{update,done} operand is an async op, + // so we need to invalidate it in the C++ representation. + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + TF_ASSERT_OK(async_done->ReplaceOperandWith(0, tuple)); + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + HloComputation* computation = + FindComputation(module.get(), "AsyncDoneNoAsyncStart"); + TF_ASSERT_OK(computation->RemoveInstruction(async_start)); + auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), @@ -1050,17 +1022,30 @@ TEST_F(HloVerifierTest, AsyncUpdateAndAsyncDoneNoAsyncStart) { const char* const hlo_string = R"( HloModule Module - ENTRY AsyncStartAndAsyncDone { + ENTRY AsyncUpdateAndAsyncDoneNoAsyncStart { p0 = f32[2,3] parameter(0) p1 = u32[] parameter(1) tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1) - async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(tuple), custom_call_target="foo" - ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo" + async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo" + async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start) + ROOT async-done = f32[2,3] custom-call-done(async-update) } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); + // The parser checks that the async-{update,done} operand is an async op, + // so we need to invalidate it in the C++ representation. + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + HloInstruction* async_update = FindInstruction(module.get(), "async-update"); + TF_ASSERT_OK(async_update->ReplaceOperandWith(0, tuple)); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + TF_ASSERT_OK(async_done->ReplaceOperandWith(0, tuple)); + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + HloComputation* computation = + FindComputation(module.get(), "AsyncUpdateAndAsyncDoneNoAsyncStart"); + TF_ASSERT_OK(computation->RemoveInstruction(async_start)); + auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), @@ -1129,13 +1114,18 @@ TEST_F(HloVerifierTest, AsyncOpTupleWrongType) { ENTRY AsyncStartAndAsyncDone { p0 = f32[2,3] parameter(0) - async-start = ((f32[2,3])) async-start(p0), calls=async_computation + async-start = ((f32[2,3]), f32[3,2], s32[]) async-start(p0), calls=async_computation ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); + // The parser checks that the async op's shape type is valid, so we need to + // invalidate it in the C++ representation. + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + async_start->mutable_shape()->clear_tuple_shapes(); + auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), @@ -1247,50 +1237,6 @@ TEST_F(HloVerifierTest, AsyncOpComputationNotTrivial) { "expected to contain only the root and parameter instructions")); } -TEST_F(HloVerifierTestLayoutSensitive, AsyncDoneWrongGroupId) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY AsyncStartAndAsyncUpdateAndAsyncDone { - p0 = f32[2,3]{1,0:S(1)} parameter(0) - async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo" - async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_group_id=0, custom_call_target="foo" - async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo" - ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=1, custom_call_target="foo" - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("async-done expects its operand to have the same group " - "id (1 vs 0).")); -} - -TEST_F(HloVerifierTestLayoutSensitive, AsyncUpdateWrongGroupId) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY AsyncStartAndAsyncUpdateAndAsyncDone { - p0 = f32[2,3]{1,0:S(1)} parameter(0) - async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo" - async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo" - async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo" - ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=0, custom_call_target="foo" - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - HasSubstr("async-update expects its operand to have the same " - "group id (none vs 0).")); -} - TEST_F(HloVerifierTest, IotaNonArrayResult) { const char* const hlo_string = R"( HloModule IotaTupleResult @@ -1424,7 +1370,7 @@ int64_t ReplicaCount(const std::vector>& replica_groups) { return replica_count; } -StatusOr> MakeCollectiveCommOpComputation( +absl::StatusOr> MakeCollectiveCommOpComputation( std::vector> replica_groups, std::optional replica_count, std::optional num_partitions, absl::string_view other_attributes, absl::string_view template_str) { @@ -1442,7 +1388,7 @@ StatusOr> MakeCollectiveCommOpComputation( config); } -StatusOr> MakeAllReduceComputation( +absl::StatusOr> MakeAllReduceComputation( std::vector> replica_groups, std::optional replica_count = std::nullopt, std::optional num_partitions = std::nullopt, @@ -1635,7 +1581,7 @@ TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) { "needs to be all-reduce-start, found tuple")); } -StatusOr> MakeAllToAllComputation( +absl::StatusOr> MakeAllToAllComputation( std::vector> replica_groups, std::optional replica_count = std::nullopt, std::optional num_partitions = std::nullopt, @@ -2097,6 +2043,95 @@ TEST_F(HloVerifierTest, ChannelVerifier) { HasSubstr("used for different types of channel instructions")); } +TEST_F(HloVerifierTest, ChannelVerifierPipelinedMissingDones) { + const char* const kModuleStr = R"( + HloModule test + cond { + param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(1) + ROOT result = pred[] compare(count, ub), direction=LT + } + + body { + param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + count = get-tuple-element(%param), index=0 + + recv.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=1 + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + send.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=2 + send-done.0 = (u32[2], token[]) recv-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.0.n = token[] after-all() + recv.0.n = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + + after-all.1.n = token[] after-all() + send.0.n = (u32[2], u32[], token[]) send(recv-data.0, after-all.1.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + ROOT result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) + tuple(new_count, recv.0.n, send.0.n) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + init = u32[2] broadcast(c0), dimensions={} + after-all.0.p = token[] after-all() + recv.0.p = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + after-all.1.p = token[] after-all() + send.0.p = (u32[2], u32[], token[]) send(init, after-all.1.p), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="0" + } + + while_init = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) + tuple(c0, recv.0.p, send.0.p) + while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) + while(while_init), body=body, condition=cond + + recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 + recv-done.0.q = (u32[2], token[]) recv-done(recv.0.q), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + ROOT recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT( + verifier().Run(module.get()).status().message(), + HasSubstr("is pipelined. Not all Send/Recv related instructions are used" + " the same number of times")); +} + TEST_F(HloVerifierTest, CollectiveChannelVerifier) { const char* const kModuleStr = R"( HloModule test @@ -2569,7 +2604,7 @@ TEST_F(HloVerifierTest, CheckWhileContainsAsyncThread) { %constant.1 = s32[] constant(5) %prev.2 = s32[] parameter(0) %async-start = ((s32[]), s32[], s32[]) custom-call-start(s32[] %prev.2), async_execution_thread="parallel_thread", custom_call_target="async_add" - %async-done = s32[] custom-call-done(((s32[]), s32[], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="async_add" + %async-done = s32[] custom-call-done(((s32[]), s32[], s32[]) %async-start) ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %async-done), direction=GT } @@ -2910,5 +2945,84 @@ TEST_F(HloVerifierTest, EnableUnboundedDynamism) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, SparseDotMetadataShape) { + const char* const kHlo = R"( + HloModule test + ENTRY entry { + %lhs = f32[10,16] parameter(0) + %rhs = f32[32,20] parameter(1) + %meta = u16[10,4] parameter(2) + ROOT %dot = f32[10,20] dot(%lhs, %rhs, %meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + HloVerifier verifier{HloVerifierOpts{}.WithAllowUnboundedDynamism(true)}; + auto status = verifier.Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), HasSubstr("Expected sparse dot metadata")); +} + +TEST_F(HloVerifierTestLayoutSensitive, + HostOffloadingDUSAndDSAreVerifiedWhenChangingLayout) { + const char* const hlo_string = R"( + HloModule m + + ENTRY main { + constant_f32_0 = f32[] constant(0) + custom-call = f32[2,2048,2048]{2,1,0:S(5)} custom-call(), custom_call_target="AllocateBuffer" + data_param = f32[1,2048,2048]{2,1,0} parameter(0) + index_param = s32[] parameter(1) + constant_s32_0 = s32[] constant(0) + dynamic_update_slice = f32[2,2048,2048]{2,1,0:S(5)} dynamic-update-slice(custom-call, data_param, index_param, constant_s32_0, constant_s32_0) + ROOT dynamic_slice = f32[1,2048,2048]{2,1,0} dynamic-slice(f32[2,2048,2048]{2,1,0:S(5)} dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, + HostOffloadingCopyIsVerifiedWhenChangingLayout) { + const char* const hlo_string = R"( + HloModule m + + ENTRY main { + data_param = f32[2048]{0} parameter(0) + copy_0 = f32[2048]{0:S(5)} copy(f32[2048]{0} data_param) + ROOT copy_1 = f32[2048]{0} copy(f32[2048]{0:S(5)} copy_0) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, + HostOffloadingDSCannotChangeLayoutFromDeviceToHost) { + const char* const hlo_string = R"( + HloModule m + + ENTRY main { + constant_f32_0 = f32[] constant(0) + custom-call = f32[2,2048,2048]{2,1,0} custom-call(), custom_call_target="AllocateBuffer" + data_param = f32[1,2048,2048]{2,1,0} parameter(0) + index_param = s32[] parameter(1) + constant_s32_0 = s32[] constant(0) + dynamic_update_slice = f32[2,2048,2048]{2,1,0} dynamic-update-slice(custom-call, data_param, index_param, constant_s32_0, constant_s32_0) + ROOT dynamic_slice = f32[1,2048,2048]{2,1,0:S(5)} dynamic-slice(f32[2,2048,2048]{2,1,0} dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("DynamicSlice instruction shouldn't change layout " + "memory space from device to host")); +} } // namespace } // namespace xla diff --git a/xla/service/host_memory_offload_annotations.h b/xla/service/host_memory_offload_annotations.h new file mode 100644 index 0000000000000..a0b7e3decaea3 --- /dev/null +++ b/xla/service/host_memory_offload_annotations.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ +#define XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ + +#include "absl/strings/string_view.h" + +namespace xla { +namespace host_memory_offload_annotations { + +// External annotations: +inline const absl::string_view kDevicePlacement = "annotate_device_placement"; +inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host"; +inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host"; +inline const absl::string_view kMemoryTargetDevice = "device"; + +// Internal annotations: +inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost"; +inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice"; + +} // namespace host_memory_offload_annotations +} // namespace xla + +#endif // XLA_SERVICE_HOST_MEMORY_OFFLOAD_ANNOTATIONS_H_ diff --git a/xla/service/host_memory_transfer_asyncifier.cc b/xla/service/host_memory_transfer_asyncifier.cc new file mode 100644 index 0000000000000..c704183e325d6 --- /dev/null +++ b/xla/service/host_memory_transfer_asyncifier.cc @@ -0,0 +1,210 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_memory_transfer_asyncifier.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace { + +class HostMemoryTransferAsyncifierVisitor : public DfsHloVisitorWithDefault { + public: + explicit HostMemoryTransferAsyncifierVisitor(int64_t host_memory_space_color) + : kHostMemorySpaceColor(host_memory_space_color) {} + bool Changed() const { return changed_; } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return OkStatus(); + } + + // Replace all dynamic-slice ops which slice from host memory to device memory + // with an asynchronous dynamic-slice. + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + // Check that the dynamic_slice and its first operand have layouts. This + // pass must only be run after LayoutAssignment. + HloInstruction* dynamic_slice_operand = dynamic_slice->mutable_operand(0); + if (!dynamic_slice->shape().has_layout()) { + return InternalStrCat(dynamic_slice->name(), " does not have a layout."); + } + if (!dynamic_slice_operand->shape().has_layout()) { + return InternalStrCat(dynamic_slice->name(), "'s operand, ", + dynamic_slice_operand->name(), + ", does not have a layout."); + } + + VLOG(3) << absl::StreamFormat( + "\"%s\" from S(%d) to S(%d)", dynamic_slice->name(), + dynamic_slice_operand->shape().layout().memory_space(), + dynamic_slice->shape().layout().memory_space()); + // Check that this is a dynamic-slice slicing from host memory to device + // memory. + if (dynamic_slice_operand->shape().layout().memory_space() != + kHostMemorySpaceColor) { + // Only care about dynamic-slice from host memory. + return OkStatus(); + } + if (dynamic_slice->shape().layout().memory_space() != + xla::Layout::kDefaultMemorySpace) { + // Only care about dynamic-slice to device memory. + return OkStatus(); + } + + // Everything is as expected. Replace this dynamic-slice with the async + // equivalent. + VLOG(1) << "DynamicSlice \"" << dynamic_slice->name() + << "\" is slicing from host memory. Converting to async."; + const Shape context_shape = ShapeUtil::MakeScalarShape(U32); + const Shape transfer_bytes_shape = ShapeUtil::MakeScalarShape(S32); + TF_ASSIGN_OR_RETURN( + HloInstruction * async_done, + dynamic_slice->parent()->CreateAsyncInstructions( + dynamic_slice, {context_shape, transfer_bytes_shape})); + (void)async_done; + MarkAsChanged(); + return OkStatus(); + } + + // Replace all dynamic-update-slice ops which update into host memory from + // device memory with an asynchronous dynamic-update-slice. + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + // Check that the dynamic-update-slice and its first two operands have + // layouts. This pass must only be run after LayoutAssignment. + HloInstruction* dynamic_update_slice_operand = + dynamic_update_slice->mutable_operand(0); + HloInstruction* dynamic_update_slice_update = + dynamic_update_slice->mutable_operand(1); + if (!dynamic_update_slice->shape().has_layout()) { + return InternalStrCat(dynamic_update_slice->name(), + " does not have a layout."); + } + if (!dynamic_update_slice_operand->shape().has_layout()) { + return InternalStrCat(dynamic_update_slice->name(), "'s operand, ", + dynamic_update_slice_operand->name(), + ", does not have a layout."); + } + if (!dynamic_update_slice_update->shape().has_layout()) { + return InternalStrCat(dynamic_update_slice->name(), "'s update, ", + dynamic_update_slice_update->name(), + ", does not have a layout."); + } + + // Check that this is a dynamic-update-slice updating from device memory + // into host memory. + if (dynamic_update_slice_update->shape().layout().memory_space() != + xla::Layout::kDefaultMemorySpace) { + // Only care about dynamic-update-slice from device memory. + return OkStatus(); + } + if (dynamic_update_slice->shape().layout().memory_space() != + kHostMemorySpaceColor) { + // Only care about dynamic-update-slice to host memory. + return OkStatus(); + } + if (dynamic_update_slice_operand->shape().layout().memory_space() != + dynamic_update_slice->shape().layout().memory_space()) { + return InternalStrCat( + "Unexpected that ", dynamic_update_slice_operand->name(), + "'s memory space is not the same as the dynamic-update-slice."); + } + + // Everything is as expected. Replace this dynamic-update-slice with the + // async equivalent. + VLOG(1) << "DynamicUpdateSlice \"" << dynamic_update_slice->name() + << "\" is slicing into host memory space. Converting to async."; + const Shape context_shape = ShapeUtil::MakeScalarShape(U32); + TF_ASSIGN_OR_RETURN(HloInstruction * async_done, + dynamic_update_slice->parent()->CreateAsyncInstructions( + dynamic_update_slice, {context_shape})); + (void)async_done; + MarkAsChanged(); + return OkStatus(); + } + + // Replace all copy ops which copy from host memory to device memory or from + // device memory to host memory with an asynchronous copy. + Status HandleCopy(HloInstruction* copy) override { + HloInstruction* operand = copy->mutable_operand(0); + if (!operand->shape().has_layout()) { + return InternalStrCat(operand->name(), " does not have a layout."); + } + if (!copy->shape().has_layout()) { + return InternalStrCat(copy->name(), " does not have a layout."); + } + + const auto copy_src_memory_space = operand->shape().layout().memory_space(); + const auto copy_dst_memory_space = copy->shape().layout().memory_space(); + if (!((copy_src_memory_space == kHostMemorySpaceColor && + copy_dst_memory_space == xla::Layout::kDefaultMemorySpace) || + (copy_src_memory_space == xla::Layout::kDefaultMemorySpace && + copy_dst_memory_space == kHostMemorySpaceColor))) { + VLOG(2) + << "Skipping copy because it is not a copy between device memory and " + "host memory: " + << copy->ToString(); + // Only care about copies between device memory and host memory. + return OkStatus(); + } + + // Everything is as expected. Replace this copy with the async equivalent. + VLOG(1) + << "Copy \"" << copy->name() + << "\" is between device and host memory space. Converting to async."; + const Shape context_shape = ShapeUtil::MakeScalarShape(U32); + TF_ASSIGN_OR_RETURN( + HloInstruction * async_done, + copy->parent()->CreateAsyncInstructions(copy, {context_shape})); + (void)async_done; + MarkAsChanged(); + return OkStatus(); + } + + private: + const int64_t kHostMemorySpaceColor; + bool changed_ = false; + + void MarkAsChanged() { changed_ = true; } +}; + +} // namespace + +absl::StatusOr HostMemoryTransferAsyncifier::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + HostMemoryTransferAsyncifierVisitor visitor(kHostMemorySpaceColor); + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + } + return visitor.Changed(); +} + +} // namespace xla diff --git a/xla/service/host_memory_transfer_asyncifier.h b/xla/service/host_memory_transfer_asyncifier.h new file mode 100644 index 0000000000000..5368162e75817 --- /dev/null +++ b/xla/service/host_memory_transfer_asyncifier.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ +#define XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla { + +/* +This pass finds copies between the host memory and device memory and converts +them into the async ops. This includes, but is not limited to: + - device to host DynamicUpdateSlice + - host to device DynamicSlice +* The examples below are not yet supported * + - host to device DynamicUpdateSlice + - device to host DynamicSlice + - host to device Copy + - device to host Copy +*/ +class HostMemoryTransferAsyncifier : public HloModulePass { + public: + explicit HostMemoryTransferAsyncifier(int64_t host_memory_space_color) + : kHostMemorySpaceColor(host_memory_space_color) {} + ~HostMemoryTransferAsyncifier() override = default; + + absl::string_view name() const override { + return "host-memory-transfer-asyncifier"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const int64_t kHostMemorySpaceColor; +}; + +} // namespace xla + +#endif // XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ diff --git a/xla/service/host_memory_transfer_asyncifier_test.cc b/xla/service/host_memory_transfer_asyncifier_test.cc new file mode 100644 index 0000000000000..092f4fe7ef256 --- /dev/null +++ b/xla/service/host_memory_transfer_asyncifier_test.cc @@ -0,0 +1,424 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_memory_transfer_asyncifier.h" + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +namespace m = ::xla::match; + +class HostMemoryTransferAsyncifierTest : public HloTestBase { + protected: + absl::StatusOr RunAsyncifier(absl::string_view hlo_string) { + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSIGN_OR_RETURN(bool changed, RunAsyncifier(module.get())); + return changed; + } + + absl::StatusOr RunAsyncifier(HloModule* module) { + TF_EXPECT_OK(verifier().Run(module).status()); + if (module->has_schedule()) { + return absl::InternalError("Expected a non-scheduled module"); + } + + HostMemoryTransferAsyncifier asyncifier(kHostMemorySpaceColor); + return asyncifier.Run(module); + } + + private: + static constexpr int64_t kHostMemorySpaceColor{5}; +}; + +// =============================DynamicUpdateSlice============================== + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicUpdateSliceFromHostToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_operand = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + host_update = f32[1,1,1]{2,1,0:T(2,128)S(5)} parameter(1) + constant_0 = s32[] constant(0) + ROOT dynamic-update-slice = f32[32,1,1]{2,1,0:T(2,128)S(5)} dynamic-update-slice(host_operand, host_update, constant_0, constant_0, constant_0) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-update-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicUpdateSliceFromDeviceToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + operand = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + update = f32[1,1,1]{2,1,0:T(2,128)} parameter(1) + constant_0 = s32[] constant(0) + ROOT dynamic-update-slice = f32[32,1,1]{2,1,0:T(2,128)} dynamic-update-slice(operand, update, constant_0, constant_0, constant_0) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-update-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicUpdateSliceFromHostToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + operand = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + host_update = f32[1,1,1]{2,1,0:T(2,128)S(5)} parameter(1) + constant_0 = s32[] constant(0) + ROOT dynamic-update-slice = f32[32,1,1]{2,1,0:T(2,128)} dynamic-update-slice(operand, host_update, constant_0, constant_0, constant_0) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-update-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicUpdateSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicUpdateSliceFromDeviceToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_operand = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + update = f32[1,1,1]{2,1,0:T(2,128)} parameter(1) + constant_0 = s32[] constant(0) + ROOT dynamic-update-slice = f32[32,1,1]{2,1,0:T(2,128)S(5)} dynamic-update-slice(host_operand, update, constant_0, constant_0, constant_0) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // dynamic-update-slice should have been converted into an + // async-dynamic-update-slice. + HloInstruction* dynamic_update_slice_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Op() + .WithOpcode(HloOpcode::kAsyncDone) + .WithOperand(0, m::Op(&dynamic_update_slice_start) + .WithOpcode(HloOpcode::kAsyncStart)))); + ASSERT_EQ(dynamic_update_slice_start->called_computations().size(), 1); + HloComputation* async_dynamic_slice_computation = + dynamic_update_slice_start->called_computations().at(0); + EXPECT_THAT(async_dynamic_slice_computation->root_instruction(), + GmockMatch(m::DynamicUpdateSlice())); +} + +// ================================DynamicSlice================================= + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicSliceFromHostToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_memory = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + constant_0 = s32[] constant(0) + ROOT dynamic-slice = f32[1,1,1]{2,1,0:T(2,128)S(5)} dynamic-slice(host_memory, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,1,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicSliceFromDeviceToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + device = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + constant_0 = s32[] constant(0) + ROOT dynamic-slice = f32[1,1,1]{2,1,0:T(2,128)} dynamic-slice(device, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,1,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicSliceFromDeviceToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + device = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + constant_0 = s32[] constant(0) + ROOT dynamic-slice = f32[1,1,1]{2,1,0:T(2,128)S(5)} dynamic-slice(device, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,1,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular dynamic-slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, DynamicSliceFromHostToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_memory = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + constant_0 = s32[] constant(0) + ROOT dynamic-slice = f32[1,1,1]{2,1,0:T(2,128)} dynamic-slice(host_memory, constant_0, constant_0, constant_0), dynamic_slice_sizes={1,1,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // dynamic-slice should have been converted into an async-dynamic-slice. + HloInstruction* dynamic_slice_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Op() + .WithOpcode(HloOpcode::kAsyncDone) + .WithOperand(0, m::Op(&dynamic_slice_start) + .WithOpcode(HloOpcode::kAsyncStart)))); + ASSERT_EQ(dynamic_slice_start->called_computations().size(), 1); + HloComputation* async_dynamic_slice_computation = + dynamic_slice_start->called_computations().at(0); + EXPECT_THAT(async_dynamic_slice_computation->root_instruction(), + GmockMatch(m::DynamicSlice())); +} + +// ====================================Copy===================================== + +TEST_F(HostMemoryTransferAsyncifierTest, CopyFromHostToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_memory = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)S(5)} copy(host_memory) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular copy + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Copy())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, CopyFromDeviceToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + device = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)} copy(device) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_FALSE(changed); + // The root instruction should still be a regular copy + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Copy())); +} + +// TODO(b/319466176): Once this bug is fixed, enable this test and delete the +// OldCopyFromDeviceToHost test. +TEST_F(HostMemoryTransferAsyncifierTest, DISABLED_CopyFromDeviceToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + device = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)S(5)} copy(device) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // copy should have been converted into an async-copy. + HloInstruction* copy_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Op() + .WithOpcode(HloOpcode::kAsyncDone) + .WithOperand( + 0, m::Op(©_start).WithOpcode(HloOpcode::kAsyncStart)))); + ASSERT_EQ(copy_start->called_computations().size(), 1); + HloComputation* async_copy_computation = + copy_start->called_computations().at(0); + EXPECT_THAT(async_copy_computation->root_instruction(), + GmockMatch(m::Copy())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, OldCopyFromDeviceToHost) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + device = f32[32,1,1]{2,1,0:T(2,128)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)S(5)} copy(device) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // copy should have been converted into an async-copy. + HloInstruction* copy_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Op() + .WithOpcode(HloOpcode::kCopyDone) + .WithOperand( + 0, m::Op(©_start).WithOpcode(HloOpcode::kCopyStart)))); +} + +// TODO(b/319466176): Once this bug is fixed, enable this test and delete the +// OldCopyFromHostToDevice test. +TEST_F(HostMemoryTransferAsyncifierTest, DISABLED_CopyFromHostToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_memory = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)} copy(host_memory) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // copy should have been converted into an async-copy. + HloInstruction* copy_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Op() + .WithOpcode(HloOpcode::kAsyncDone) + .WithOperand( + 0, m::Op(©_start).WithOpcode(HloOpcode::kAsyncStart)))); + ASSERT_EQ(copy_start->called_computations().size(), 1); + HloComputation* async_copy_computation = + copy_start->called_computations().at(0); + EXPECT_THAT(async_copy_computation->root_instruction(), + GmockMatch(m::Copy())); +} + +TEST_F(HostMemoryTransferAsyncifierTest, OldCopyFromHostToDevice) { + const std::string& hlo_string = R"( +HloModule MyModule + +ENTRY main { + host_memory = f32[32,1,1]{2,1,0:T(2,128)S(5)} parameter(0) + ROOT copy = f32[32,1,1]{2,1,0:T(2,128)} copy(host_memory) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunAsyncifier(module.get())); + + EXPECT_TRUE(changed); + // copy should have been converted into an async-copy. + HloInstruction* copy_start; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Op() + .WithOpcode(HloOpcode::kCopyDone) + .WithOperand( + 0, m::Op(©_start).WithOpcode(HloOpcode::kCopyStart)))); +} + +// ============================================================================= + +} // namespace + +} // namespace xla diff --git a/xla/service/host_offload_legalize.cc b/xla/service/host_offload_legalize.cc new file mode 100644 index 0000000000000..c6349f7680a54 --- /dev/null +++ b/xla/service/host_offload_legalize.cc @@ -0,0 +1,662 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_legalize.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_value.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace { + +constexpr std::array kUsersOpcodes = {HloOpcode::kSlice, + HloOpcode::kDynamicSlice}; + +// Find an annotation moving up. Meant to find an annotation from a DUS operand. +HloInstruction* FindToHostAnnotationToUpdate(HloInstruction* instr) { + while (!instr->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + if ((instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kCopy && + instr->opcode() != HloOpcode::kReshape) || + instr->mutable_operand(0)->user_count() != 1) { + return nullptr; + } + instr = instr->mutable_operand(0); + } + return instr; +} + +// Find an annotation moving up. Meant to find an annotation from a DUS +// instruction. +HloInstruction* FindToDeviceAnnotationToUpdate(HloInstruction* instr) { + while (!instr->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + if (instr->user_count() != 1 || + (instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kReshape && + instr->opcode() != HloOpcode::kCopy && + !absl::c_linear_search(kUsersOpcodes, instr->opcode()))) { + return nullptr; + } + instr = instr->users()[0]; + } + return instr; +} + +// Find a DUS starting from an annotation. +HloInstruction* FindDUSFromAnnotation(HloInstruction* instr) { + while (instr->opcode() != HloOpcode::kDynamicUpdateSlice) { + if (instr->user_count() != 1 || (instr->opcode() != HloOpcode::kBitcast && + instr->opcode() != HloOpcode::kReshape)) { + break; + } + instr = instr->users()[0]; + } + return instr; +} + +// Make sure that broadcasts are duplicated for each use. +absl::StatusOr DuplicateBroadcastForEachUse(HloModule* module) { + bool split_at_least_one = false; + for (HloComputation* computation : module->computations()) { + std::vector broadcasts; + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kBroadcast || + !instruction->HasConstantOperand()) { + continue; + } + broadcasts.push_back(instruction); + } + for (HloInstruction* instruction : broadcasts) { + if (instruction->opcode() != HloOpcode::kBroadcast || + !instruction->HasConstantOperand()) { + continue; + } + absl::InlinedVector uses; + for (HloInstruction* user : instruction->users()) { + for (int64_t i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) != instruction) { + continue; + } + uses.push_back(HloUse{user, i, /*operand_index=*/{}}); + } + } + + if (uses.size() <= 1) { + VLOG(5) << "Skipping broadcast " << instruction->ToString() + << " which has " << uses.size() << " uses"; + continue; + } + + VLOG(5) << "Splitting broadcast " << instruction->ToString() + << " which has " << uses.size() << " uses"; + split_at_least_one = true; + // Don't create a new broadcast for the first use; we can still use the + // original. + for (int i = 1; i < uses.size(); ++i) { + const HloUse& use = uses[i]; + HloInstruction* new_broadcast = + instruction->parent()->AddInstruction(instruction->Clone()); + VLOG(5) << "New broadcast " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( + use.operand_number, new_broadcast)); + } + } + } + return split_at_least_one; +} + +// Walk up in the chain of memory offloaded instructions. Status not-ok when +// an instructions not supported or end of chain reached. +// Walks one instruction at a time. +absl::StatusOr> WalkUpMemoryOffload( + std::pair current_value, + const CallGraph& call_graph) { + // TODO(maggioni): Verify that set of instructions supported in chain by + // legalization is in sync with host_offloader. + auto& [instruction, index] = current_value; + // Walk up to find definition + switch (instruction->opcode()) { + case HloOpcode::kGetTupleElement: { + CHECK_EQ(index, -1); + return std::make_pair(instruction->mutable_operand(0), + instruction->tuple_index()); + } + case HloOpcode::kBitcast: + case HloOpcode::kReshape: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kTuple: { + return std::make_pair(instruction->mutable_operand(index), -1); + } + case HloOpcode::kOptimizationBarrier: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kWhile: { + HloComputation* while_body = instruction->while_body(); + HloInstruction* root = while_body->root_instruction(); + CHECK_EQ(root->opcode(), HloOpcode::kTuple); + return std::make_pair(root, index); + } + case HloOpcode::kParameter: { + CHECK_NE(instruction->parent(), + instruction->GetModule()->entry_computation()); + auto callers = call_graph.GetComputationCallers(instruction->parent()); + if (callers.size() != 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } + auto* caller = callers[0]; + if (caller->opcode() != HloOpcode::kWhile) { + return absl::InvalidArgumentError( + "Expected to be called by a while loop"); + } + return std::make_pair(caller->mutable_operand(0), index); + } + case HloOpcode::kDynamicUpdateSlice: { + return std::make_pair(instruction->mutable_operand(0), index); + } + case HloOpcode::kCustomCall: { + if (!instruction->IsCustomCall("AllocateBuffer") && + !instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + return absl::InvalidArgumentError( + "Expected AllocateBuffer or MoveToHost custom-call"); + } + return std::make_pair(instruction, index); + } + case HloOpcode::kBroadcast: { + auto* broadcast_operand = instruction->mutable_operand(0); + if (broadcast_operand->opcode() != HloOpcode::kConstant) { + return absl::InvalidArgumentError("Expected a constant as operand"); + } + if (!ShapeUtil::IsEffectiveScalar(broadcast_operand->shape())) { + return absl::InvalidArgumentError("Expected a scalar broadcast"); + } + return std::make_pair(instruction, index); + } + default: { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid opcode %s", instruction->ToString())); + } + } +} + +// Walk down in the chain of memory offloaded instructions. Status not-ok when +// an instructions not supported or end of chain reached. +// Walks one instruction at a time, but returns multiple instructions for each +// conforming user. +absl::StatusOr>> +WalkDownMemoryOffload(const std::pair& current_value, + const CallGraph& call_graph) { + // TODO(maggioni): Verify that set of instructions supported in chain by + // legalization is in sync with host_offloader. + VLOG(5) << "Current value in progress: " << current_value.first->ToString() + << " idx: " << current_value.second; + std::vector> results; + auto add_gte_for_idx = [&results](HloInstruction* instr, int idx) -> Status { + HloInstruction* gte = nullptr; + for (HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return absl::InvalidArgumentError( + "Expected users to be only get-tuple-elements"); + } + if (user->tuple_index() != idx) { + continue; + } + if (gte != nullptr) { + return absl::InvalidArgumentError( + "Expected to find only one gte per index."); + } + results.push_back(std::make_pair(user, -1)); + } + return OkStatus(); + }; + if (current_value.first->user_count() == 0) { + if (current_value.first->parent()->root_instruction() == + current_value.first) { + auto callers = + call_graph.GetComputationCallers(current_value.first->parent()); + if (callers.size() != 1 || callers[0]->opcode() != HloOpcode::kWhile) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller and caller be a While"); + } + TF_RETURN_IF_ERROR(add_gte_for_idx(callers[0], current_value.second)); + return results; + } + } + if (current_value.first->opcode() == HloOpcode::kParameter && + current_value.first->shape().IsTuple()) { + TF_RETURN_IF_ERROR( + add_gte_for_idx(current_value.first, current_value.second)); + return results; + } + for (HloInstruction* user : current_value.first->users()) { + switch (user->opcode()) { + case HloOpcode::kGetTupleElement: { + CHECK_NE(user->tuple_index(), -1); + if (user->tuple_index() != current_value.second) { + continue; + } + results.push_back(std::make_pair(user, -1)); + break; + } + case HloOpcode::kTuple: { + auto output_indices = user->OperandIndices(current_value.first); + if (output_indices.size() != 1) { + return absl::InvalidArgumentError( + "Expected operand to be used only once in the tuple."); + } + results.push_back(std::make_pair(user, output_indices[0])); + break; + } + case HloOpcode::kOptimizationBarrier: { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + case HloOpcode::kWhile: { + HloComputation* while_body = user->while_body(); + HloInstruction* parameter = while_body->parameter_instruction(0); + results.push_back(std::make_pair(parameter, current_value.second)); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + if (user->OperandIndices(current_value.first)[0] != 0) { + return absl::InvalidArgumentError( + "Expected to be used by first operand of dynamic-update-slice"); + } + results.push_back(std::make_pair(user, current_value.second)); + break; + } + case HloOpcode::kCustomCall: { + if (user->IsCustomCall(host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)) { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + return absl::InvalidArgumentError("Invalid custom-call found."); + } + case HloOpcode::kBitcast: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kReshape: + case HloOpcode::kSlice: { + results.push_back(std::make_pair(user, current_value.second)); + break; + } + default: { + return absl::InvalidArgumentError("Unrecognized user opcode"); + } + } + } + return results; +} + +absl::StatusOr ProcessAnnotationForCopyMovement( + HloInstruction* instruction, const CallGraph* call_graph, + absl::flat_hash_set& processed_annotations, + std::vector& to_remove) { + auto is_entry_computation_parameter = [](HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kParameter && + instruction->parent()->IsEntryComputation(); + }; + + if (instruction->IsRoot()) { + return false; + } + HloInstruction* starting_instr = + FindDUSFromAnnotation(instruction->users().at(0)); + // If it's the pure copy case reset instruction. + if (starting_instr->opcode() != HloOpcode::kDynamicUpdateSlice) { + starting_instr = instruction; + } + VLOG(3) << "Dus or Annotation: " << starting_instr->ToString(); + std::pair current_value = + std::make_pair(starting_instr, -1); + // Found a copy that would block offloading. Walk up to find all annotations + // to update (required in case there are multiple insertions in the buffer). + processed_annotations.insert(current_value.first); + if (!current_value.first->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget) && + !is_entry_computation_parameter(current_value.first)) { + CHECK_EQ(current_value.first->opcode(), HloOpcode::kDynamicUpdateSlice); + while (true) { + VLOG(10) << "Current value before: " << current_value.first->ToString(); + auto current_value_up = WalkUpMemoryOffload(current_value, *call_graph); + // Invalid upward walking means the chain is unrecognized. + if (!current_value_up.ok()) { + return false; + } + // This means we encountered a broadcast with constant 0 expansion. + if (current_value_up.value() == current_value) { + break; + } + current_value = current_value_up.value(); + VLOG(10) << "Current value after: " << current_value.first->ToString(); + HloInstruction* annotation = current_value.first; + if (annotation->opcode() == HloOpcode::kDynamicUpdateSlice) { + HloInstruction* real_annotation = + FindToHostAnnotationToUpdate(annotation->mutable_operand(1)); + // Check if this dynamic-update-slice doesn't have an annotation + // attached. + if (!real_annotation->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + return false; + } + } + } + } + std::vector> copies_to_move; + // Do a final walkdown from the top to collect all the instructions that need + // their shape updated. + std::vector> stack(1, current_value); + while (!stack.empty()) { + VLOG(5) << "Current value before down: " << stack.back().first->ToString(); + if (absl::c_linear_search(kUsersOpcodes, stack.back().first->opcode()) || + stack.back().first->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + HloInstruction* annotation = + FindToDeviceAnnotationToUpdate(stack.back().first); + if (!annotation || + !annotation->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + VLOG(5) << "Couldn't find annotation for consumer instruction in chain"; + return false; + } + stack.pop_back(); + continue; + } + auto current_value_down = WalkDownMemoryOffload(stack.back(), *call_graph); + if (!current_value_down.ok()) { + VLOG(5) << "Current value down failed: " << current_value_down.status(); + break; + } + stack.pop_back(); + stack.insert(stack.end(), current_value_down.value().begin(), + current_value_down.value().end()); + for (auto& instruction : current_value_down.value()) { + VLOG(5) << "Current value last down: " << stack.back().first->ToString(); + if (instruction.first->opcode() == HloOpcode::kCopy) { + copies_to_move.push_back(instruction); + } + } + } + + auto update_shape_layout = + [&](const std::pair& instruction, + HloInstruction* copy_to_move) { + VLOG(5) << "Update shape layout: " << instruction.first->ToString() + << " " << instruction.second; + // Update shape. Tuple shape vs array shape. + if (instruction.second != -1) { + *instruction.first->mutable_shape() + ->mutable_tuple_shapes(instruction.second) + ->mutable_layout() = copy_to_move->operand(0)->shape().layout(); + } else { + *instruction.first->mutable_shape()->mutable_layout() = + copy_to_move->operand(0)->shape().layout(); + } + + if (instruction.first->opcode() == HloOpcode::kWhile) { + // Fix up while body's root instruction shape and condition's + // parameter shape for while loops. + Shape new_shape = copy_to_move->operand(0)->shape(); + *instruction.first->while_body() + ->root_instruction() + ->mutable_shape() + ->mutable_tuple_shapes(instruction.second) + ->mutable_layout() = new_shape.layout(); + *instruction.first->while_condition() + ->parameter_instruction(0) + ->mutable_shape() + ->mutable_tuple_shapes(instruction.second) + ->mutable_layout() = new_shape.layout(); + } + }; + + // Process all copies one at a time from the last to the first and push it to + // its specific user. + while (!copies_to_move.empty()) { + auto& copy_to_move = copies_to_move.back(); + VLOG(5) << "Copy to move: " << copy_to_move.first->ToString(); + stack.clear(); + stack.push_back(copy_to_move); + while (!stack.empty()) { + VLOG(5) << "Current value before down: " << stack.back().first->ToString() + << " " << stack.back().second; + auto current_value_down = + WalkDownMemoryOffload(stack.back(), *call_graph); + if (!current_value_down.ok()) { + VLOG(5) << "Current value down failed: " << current_value_down.status(); + break; + } + for (auto& instruction : current_value_down.value()) { + update_shape_layout(instruction, copy_to_move.first); + if (instruction.first->opcode() == HloOpcode::kParameter) { + auto callers = + call_graph->GetComputationCallers(instruction.first->parent()); + if (callers.size() != 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } + auto* caller = callers[0]; + if (caller->opcode() == HloOpcode::kWhile) { + update_shape_layout(std::make_pair(caller, instruction.second), + copy_to_move.first); + + HloInstruction* root_instruction = + caller->while_body()->root_instruction(); + // Fix while loop's result tuple to not use move-to-device since + // at loop entry it's still on host. + if (root_instruction->operand(instruction.second) + ->IsCustomCall(host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)) { + root_instruction + ->ReplaceOperandWith( + instruction.second, + root_instruction->mutable_operand(instruction.second) + ->mutable_operand(0)) + .IgnoreError(); + } + } + } + } + stack.pop_back(); + for (auto& instruction : current_value_down.value()) { + VLOG(5) << "Current value last down: " << instruction.first->ToString(); + CHECK_NE(instruction.first->opcode(), HloOpcode::kCopy) + << "Copies should be processed in order"; + if (absl::c_linear_search(kUsersOpcodes, instruction.first->opcode()) || + instruction.first->IsCustomCall( + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget)) { + HloInstruction* annotation = + FindToDeviceAnnotationToUpdate(instruction.first); + CHECK_NE(annotation, nullptr) + << "We already verified we could find an annotation here. " + "Something went wrong."; + HloInstruction* new_annotation = nullptr; + if (instruction.first->opcode() == HloOpcode::kCustomCall) { + new_annotation = annotation; + } else { + new_annotation = instruction.first->AddInstruction( + annotation->CloneWithNewOperands(instruction.first->shape(), + {instruction.first})); + } + update_shape_layout(std::make_pair(new_annotation, -1), + copy_to_move.first); + Shape new_copy_shape = new_annotation->shape(); + *new_copy_shape.mutable_layout() = + copy_to_move.first->shape().layout(); + HloInstruction* new_copy = instruction.first->AddInstruction( + copy_to_move.first->CloneWithNewOperands(new_copy_shape, + {new_annotation})); + std::vector users = instruction.first->users(); + for (auto* use : users) { + if (use == new_copy || use == new_annotation) { + continue; + } + TF_RETURN_IF_ERROR( + instruction.first->ReplaceUseWithDifferentShape(use, new_copy)); + } + // Move the copy here. + if (new_annotation != annotation) { + TF_RETURN_IF_ERROR(annotation->ReplaceAllUsesWithDifferentShape( + annotation->mutable_operand(0))); + to_remove.push_back(annotation); + } + continue; + } + // Move the annotation first just before dynamic-update-slice to avoid + // shape changes. + if (instruction.first->opcode() == HloOpcode::kDynamicUpdateSlice) { + HloInstruction* annotation = FindToHostAnnotationToUpdate( + instruction.first->mutable_operand(1)); + if (annotation == nullptr) { + CHECK(false); + return false; + } + CHECK(annotation->opcode() == HloOpcode::kCustomCall); + HloInstruction* new_annotation = instruction.first->AddInstruction( + annotation->CloneWithNewOperands( + instruction.first->operand(1)->shape(), + {instruction.first->mutable_operand(1)})); + TF_RETURN_IF_ERROR( + instruction.first->ReplaceOperandWith(1, new_annotation)); + TF_RETURN_IF_ERROR( + annotation->ReplaceAllUsesWith(annotation->mutable_operand(0))); + processed_annotations.insert(annotation); + processed_annotations.insert(new_annotation); + to_remove.push_back(annotation); + } + stack.push_back(instruction); + } + } + VLOG(5) << "MOVED: " << copy_to_move.first->ToString(); + TF_RETURN_IF_ERROR(copy_to_move.first->ReplaceAllUsesWithDifferentShape( + copy_to_move.first->mutable_operand(0))); + TF_RETURN_IF_ERROR( + copy_to_move.first->parent()->RemoveInstruction(copy_to_move.first)); + copies_to_move.pop_back(); + } + return true; +} + +// Fixes layout changing copies in between on the path to users. +absl::StatusOr FixupInterveningCopies( + const std::vector& copy_to_host_annotations, + const CallGraph* call_graph) { + absl::flat_hash_set processed_annotations; + std::vector annotations_to_remove; + bool changed = false; + for (HloInstruction* instruction : copy_to_host_annotations) { + if (processed_annotations.contains(instruction)) { + continue; + } + TF_ASSIGN_OR_RETURN(bool changed_annotation_for_copy_movement, + ProcessAnnotationForCopyMovement( + instruction, call_graph, processed_annotations, + annotations_to_remove)); + changed |= changed_annotation_for_copy_movement; + } + for (HloInstruction* instruction : annotations_to_remove) { + TF_RETURN_IF_ERROR(instruction->parent()->RemoveInstruction(instruction)); + } + return changed; +} + +} // namespace + +absl::StatusOr HostOffloadLegalize::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + // Split broadcasts so that each HloUse of a broadcast instruction will get + // its own copy. + // TODO(b/319293925): Do not blindly duplicate all broadcasts, instead do it + // only when necessary. + TF_ASSIGN_OR_RETURN(bool duplicated_at_least_one_broadcast, + DuplicateBroadcastForEachUse(module)); + if (duplicated_at_least_one_broadcast) { + changed = true; + } + if (!after_layout_) { + return changed; + } + std::unique_ptr call_graph = CallGraph::Build(module); + std::vector copy_to_host_annotations; + + // Iterate over all instructions and look for XLA host offload annotations. + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter && + instruction->parent()->IsEntryComputation()) { + Shape param_shape = + module->entry_computation_layout() + .parameter_layout(instruction->parameter_number()) + .shape(); + // TODO(mingyao): Add support for tuple parameter. + if (param_shape.has_layout() && + param_shape.layout().memory_space() == kHostMemorySpaceColor) { + copy_to_host_annotations.push_back(instruction); + continue; + } + } + + if (instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + copy_to_host_annotations.push_back(instruction); + } + } + } + // Fixup layout changing copies that are in between memory offloaded sections. + // Move them before the data is moved to the host. + TF_ASSIGN_OR_RETURN( + bool changed_intervening_copies, + FixupInterveningCopies(copy_to_host_annotations, call_graph.get())); + changed |= changed_intervening_copies; + + return changed; +} + +} // namespace xla diff --git a/xla/service/host_offload_legalize.h b/xla/service/host_offload_legalize.h new file mode 100644 index 0000000000000..73f050fbdfed3 --- /dev/null +++ b/xla/service/host_offload_legalize.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ +#define XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class HloCostAnalysis; + +// This pass legalizes the graph for the "host memory offloading" pass to +// correctly identified buffers that are meant to be move on the host. Any +// legalization that could block that is welcome into this pass. +class HostOffloadLegalize : public HloModulePass { + public: + explicit HostOffloadLegalize(int64_t host_memory_space_color, + bool after_layout) + : kHostMemorySpaceColor(host_memory_space_color), + after_layout_(after_layout) {} + ~HostOffloadLegalize() override = default; + + absl::string_view name() const override { return "host-offload-legalize"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const int64_t kHostMemorySpaceColor; + const bool after_layout_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ diff --git a/xla/service/host_offload_legalize_test.cc b/xla/service/host_offload_legalize_test.cc new file mode 100644 index 0000000000000..f9929648a3f12 --- /dev/null +++ b/xla/service/host_offload_legalize_test.cc @@ -0,0 +1,405 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_legalize.h" + +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace { + +class HostOffloadLegalizeTest : public HloTestBase { + protected: + static constexpr int64_t kHostMemorySpaceColor{5}; + + absl::StatusOr RunHostOffloadLegalize(HloModule* module) { + TF_EXPECT_OK(verifier().Run(module).status()); + if (module->has_schedule()) { + return absl::InternalError("Expected a non-scheduled module"); + } + HostOffloadLegalize host_offload_legalize(kHostMemorySpaceColor, + /*after_layout=*/true); + return host_offload_legalize.Run(module); + } + + void TestShapeHasMemorySpace(const Shape& shape, int64_t memory_space) { + ASSERT_TRUE(shape.has_layout()); + EXPECT_EQ(shape.layout().memory_space(), memory_space); + } + + bool HaveRemainingOffloadAnnotations(const HloModule* module) { + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->IsCustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget})) { + return true; + } + } + } + return false; + } +}; + +TEST_F(HostOffloadLegalizeTest, NoCopyWithOptBarrierMoreElaborate) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{1,0}} + +ENTRY main.24 { + Arg_0.1 = f32[16,256]{0,1} parameter(0) + cosine.4 = f32[16,256]{0,1} cosine(Arg_0.1) + custom-call.5 = f32[16,256]{0,1} custom-call(cosine.4), custom_call_target="MoveToHost" + sine.3 = f32[16,256]{0,1} sine(Arg_0.1) + cosine.7 = f32[16,256]{0,1} cosine(sine.3) + custom-call.8 = f32[16,256]{0,1} custom-call(cosine.7), custom_call_target="MoveToHost" + sine.6 = f32[16,256]{0,1} sine(sine.3) + cosine.9 = f32[16,256]{0,1} cosine(sine.6) + custom-call.10 = f32[16,256]{0,1} custom-call(cosine.9), custom_call_target="MoveToHost" + constant.2 = f32[] constant(1) + cp = f32[16,256]{1,0} copy(custom-call.8) + tuple.11 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{0,1}, f32[]) tuple(custom-call.5, cp, custom-call.10, constant.2) + opt-barrier.12 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{0,1}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16,256]{0,1} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16,256]{0,1} custom-call(get-tuple-element.15), custom_call_target="MoveToDevice" + multiply.21 = f32[16,256]{0,1} multiply(broadcast.20, custom-call.19) + cp2 = f32[16,256]{1,0} copy(multiply.21) + get-tuple-element.14 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16,256]{1,0} custom-call(get-tuple-element.14), custom_call_target="MoveToDevice" + multiply.22 = f32[16,256]{1,0} multiply(cp2, custom-call.18) + get-tuple-element.13 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16,256]{0,1} custom-call(get-tuple-element.13), custom_call_target="MoveToDevice" + cp3 = f32[16,256]{1,0} copy(custom-call.17) + ROOT multiply.23 = f32[16,256]{1,0} multiply(multiply.22, cp3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); +} + +TEST_F(HostOffloadLegalizeTest, XposeCopyOnParameterStreaming) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1},f32[16,256]{0,1:T(8,128)S(5)})->f32[16,256]{1,0}} + +ENTRY main.24 { + Arg_0.1 = f32[16,256]{0,1} parameter(0) + Arg_0.2 = f32[16,256]{0,1:T(8,128)} parameter(1) + cp0 = f32[16,256]{1,0} copy(Arg_0.2) + cosine.4 = f32[16,256]{0,1} cosine(Arg_0.1) + custom-call.5 = f32[16,256]{0,1} custom-call(cosine.4), custom_call_target="MoveToHost" + sine.3 = f32[16,256]{0,1} sine(Arg_0.1) + cosine.7 = f32[16,256]{0,1} cosine(sine.3) + custom-call.8 = f32[16,256]{0,1} custom-call(cosine.7), custom_call_target="MoveToHost" + constant.2 = f32[] constant(1) + cp1 = f32[16,256]{1,0} copy(custom-call.8) + tuple.11 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{1,0}, f32[]) tuple(custom-call.5, cp1, cp0, constant.2) + opt-barrier.12 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{1,0}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16,256]{0,1} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16,256]{1,0} custom-call(get-tuple-element.15), custom_call_target="MoveToDevice" + multiply.21 = f32[16,256]{0,1} multiply(broadcast.20, custom-call.19) + cp2 = f32[16,256]{1,0} copy(multiply.21) + get-tuple-element.14 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16,256]{1,0} custom-call(get-tuple-element.14), custom_call_target="MoveToDevice" + multiply.22 = f32[16,256]{1,0} multiply(cp2, custom-call.18) + get-tuple-element.13 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16,256]{0,1} custom-call(get-tuple-element.13), custom_call_target="MoveToDevice" + cp3 = f32[16,256]{1,0} copy(custom-call.17) + ROOT multiply.23 = f32[16,256]{1,0} multiply(multiply.22, cp3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); + + custom_call = FindInstruction(module.get(), "custom-call.19"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), + LayoutUtil::MakeLayout({0, 1}, {}, {}, {}, {Tile{{8, 128}}})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); +} + +TEST_F(HostOffloadLegalizeTest, LlmActivationHostMemoryMultipleConsumers) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + cs0 = f32[] constant(0) + broadcast_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} broadcast(cs0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(producing_while), index=1 + cp = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(while_output_1) + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(constant_s32_0, cp) + consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="MoveToDevice" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="MoveToDevice" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + HloInstruction* copy = FindInstruction(module.get(), HloOpcode::kCopy); + HloInstruction* consuming_while = + FindInstruction(module.get(), "consuming_while"); + EXPECT_NE(copy, nullptr); + EXPECT_NE(consuming_while, nullptr); + EXPECT_EQ(copy->parent(), consuming_while->while_body()); + XLA_VLOG_LINES(1, module->ToString()); +} + +TEST_F(HostOffloadLegalizeTest, LlmActivationHostMemoryMultipleCopies) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + cs0 = f32[] constant(0) + broadcast_0 = f32[96,8,6,2048,2048]{0,1,2,3,4} broadcast(cs0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]{0,1,2,3,4}) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048]{0,1,2,3,4} get-tuple-element(producing_while), index=1 + cp = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(while_output_1) + cp1 = f32[96,8,6,2048,2048]{0,1,3,2,4} copy(cp) + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) tuple(constant_s32_0, cp1) + consuming_while = (s32[], f32[96,8,6,2048,2048]{0,1,3,2,4}) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048]{0,1,3,2,4} get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="MoveToDevice" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="MoveToDevice" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + HloInstruction* copy_0 = FindInstruction(module.get(), "cp.2"); + HloInstruction* copy_1 = FindInstruction(module.get(), "cp1.2"); + HloInstruction* consuming_while = + FindInstruction(module.get(), "consuming_while"); + EXPECT_NE(copy_0, nullptr); + EXPECT_NE(copy_1, nullptr); + EXPECT_NE(consuming_while, nullptr); + EXPECT_EQ(copy_0->parent(), module->entry_computation()); + EXPECT_EQ(copy_1->operand(0), copy_0); + XLA_VLOG_LINES(1, module->ToString()); +} + +} // namespace + +} // namespace xla diff --git a/xla/service/host_offloader.cc b/xla/service/host_offloader.cc new file mode 100644 index 0000000000000..747b877617b16 --- /dev/null +++ b/xla/service/host_offloader.cc @@ -0,0 +1,809 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offloader.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_value.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace { + +void SetMemorySpace(Shape* shape, int64_t memory_space_color) { + CHECK(shape->has_layout()); + shape->mutable_layout()->set_memory_space(memory_space_color); +} + +// Checks if all of the HloPositions of this HloValue, apart from the defining +// position, are allowed when doing memory-only offload. +bool AllPositionsAreAllowed(const HloValue* value) { + // Given an HloValue, validate that none of its positions are doing any + // compute. + for (const HloPosition& position : value->positions()) { + if (position == value->defining_position()) { + // Skip defining positions. + continue; + } + // Check if this position is of an allowed type. + if (!absl::c_linear_search(HostOffloader::GetAllowedPositionOpcodes(), + position.instruction->opcode())) { + VLOG(1) << "Position " << position.instruction->ToString() + << " is not supported."; + return false; + } + } + + // Did not find any invalid ops. + return true; +} + +bool DefiningPositionIsAllowed(const HloInstruction* instruction) { + static constexpr std::array kAllowedOpcodes = {HloOpcode::kWhile, + HloOpcode::kParameter}; + return absl::c_linear_search(kAllowedOpcodes, instruction->opcode()); +} + +template +absl::StatusOr BufferHasPositionWithUser( + const HloBuffer& buffer, MatcherType matcher) { + HloInstruction* result = nullptr; + for (const HloValue* value : buffer.values()) { + for (const HloPosition& position : value->positions()) { + for (HloInstruction* user : position.instruction->users()) { + if (Match(user, matcher)) { + if (result != nullptr && result != user) { + return Internal("Found multiple matching users! At least %s and %s", + result->name(), user->name()); + } + result = user; + } + } + } + } + return result; +} + +template +absl::StatusOr> GetBufferPositionsWithUser( + const HloBuffer& buffer, MatcherType matcher) { + std::vector result; + for (const HloValue* value : buffer.values()) { + for (const HloPosition& position : value->positions()) { + for (HloInstruction* user : position.instruction->users()) { + if (Match(user, matcher)) { + result.emplace_back(user); + } + } + } + } + return result; +} + +template +absl::StatusOr> GetBufferUsersOfType( + const HloBuffer& buffer, MatcherType matcher) { + std::vector result; + for (const HloValue* value : buffer.values()) { + VLOG(3) << "Buffer defined at " << value->defining_instruction()->name() + << " has positions [" + << absl::StrJoin(value->positions(), ", ", + [](std::string* out, const HloPosition& position) { + out->append(position.instruction->name()); + }) + << "]"; + for (const HloPosition& position : value->positions()) { + VLOG(4) << " Position " << position.instruction->name() << " has users [" + << absl::StrJoin( + position.instruction->users(), ", ", + [](std::string* out, const HloInstruction* user) { + out->append(user->name()); + }) + << "]"; + for (HloInstruction* user : position.instruction->users()) { + if (Match(user, matcher)) { + result.emplace_back(user); + } + } + } + } + return result; +} + +HloInstruction* FindDSAnnotation(HloInstruction* hlo) { + while (!hlo->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + if (hlo->opcode() != HloOpcode::kReshape && + hlo->opcode() != HloOpcode::kBitcast) { + break; + } + if (hlo->user_count() != 1) { + break; + } + hlo = hlo->users()[0]; + } + return hlo; +} + +} // namespace + +absl::StatusOr HostOffloader::TryOutputStreaming( + HloInstruction* custom_call) { + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(custom_call); + bool is_used_as_output_with_host_memory_space = false; + const HloComputation* const entry_computation = + custom_call->GetModule()->entry_computation(); + for (const HloValue* value : unique_buffer.values()) { + // Check if this is memory-only. + if (!AllPositionsAreAllowed(value)) { + // Found a position which is not allowed. + return false; + } + + // Look for a value used as a output. + for (const auto& position : value->positions()) { + const HloInstruction* instruction = position.instruction; + const ShapeIndex& index = position.index; + if (instruction->parent() == entry_computation && instruction->IsRoot()) { + const Shape& output_shape = + ShapeUtil::GetSubshape(entry_computation->parent() + ->entry_computation_layout() + .result_shape(), + index); + CHECK(output_shape.has_layout()); + + if (output_shape.layout().memory_space() != kHostMemorySpaceColor) { + return FailedPrecondition( + "Output buffer is annotated with %s but is not marked with host " + "memory space in the entry computation.", + custom_call->name()); + } + is_used_as_output_with_host_memory_space = true; + } + } + } + if (!is_used_as_output_with_host_memory_space) { + VLOG(1) << "Buffer annotated by " << custom_call->name() + << " is not used as an output with host memory space."; + return false; + } + + VLOG(3) << "Found an output buffer annotated with " << custom_call->name() + << ". Expecting that we'll need to insert copies."; + + annotations_for_copy_to_host_to_insert_.emplace(custom_call); + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return true; +} + +Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) { + VLOG(2) << "Found a custom call annotating start-of-host-offload: " + << custom_call->ToString(); + // Save a pointer to this custom call for when we want to remove it later. + custom_calls_to_remove_.emplace(custom_call); + + // We expect that either the custom call is the root or the DUS is the only + // user of this custom call. + if (!custom_call->IsRoot() && custom_call->user_count() != 1) { + return FailedPrecondition( + "Expecting custom call %s to either be the root or only have 1 user; " + "it is not the root and has %d users: [%s]", + custom_call->name(), custom_call->user_count(), + absl::StrJoin(custom_call->users(), ", ", + [](std::string* out, const HloInstruction* user) { + out->append(user->name()); + })); + } + + HloInstruction* consumer = nullptr; + if (!custom_call->IsRoot()) { + consumer = custom_call->users().at(0); + // Skip past any bitcasts. + while (consumer != nullptr && consumer->opcode() == HloOpcode::kBitcast) { + VLOG(1) << "Skipping bitcast " << consumer->ToString(); + consumer = consumer->users().at(0); + } + } + + if (consumer != nullptr && + consumer->opcode() == HloOpcode::kDynamicUpdateSlice) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithDus(consumer)); + } else if (consumer != nullptr && consumer->opcode() == HloOpcode::kCopy) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(consumer)); + } else { + TF_ASSIGN_OR_RETURN(bool did_output_streaming, + TryOutputStreaming(custom_call)); + if (!did_output_streaming) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call)); + } + } + return OkStatus(); +} + +Status HostOffloader::MemoryOnlyOffloadStartingWithDus( + const HloInstruction* dynamic_update_slice) { + // The user wants to offload the data defined by this dynamic-update-slice. + VLOG(2) << "Host memory offload starts with a dynamic-update-slice: " + << dynamic_update_slice->name(); + // Get the buffer for this DUS. + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(dynamic_update_slice); + + // We must find at least two HloValues: + // 1. Defined by a broadcast. + // a. For now, we only offload if the original destination of DUS is + // created by a broadcast. + // 2. Defined by a dynamic-update-slice. + const HloValue* dus_value = nullptr; + const HloValue* broadcast_value = nullptr; + for (const HloValue* value : unique_buffer.values()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (defining_instruction->opcode() == HloOpcode::kBroadcast) { + if (broadcast_value != nullptr) { + LOG(WARNING) << "Already found one broadcast (" + << broadcast_value->defining_position().instruction->name() + << ") value for this buffer. This one is " + << defining_instruction->name(); + } + broadcast_value = value; + } else if (defining_instruction->opcode() == + HloOpcode::kDynamicUpdateSlice) { + if (dus_value != nullptr) { + LOG(WARNING) << "Already found one dynamic-update-slice (" + << dus_value->defining_position().instruction->name() + << ") value for this buffer. This one is " + << defining_instruction->name(); + } + dus_value = value; + } else { + // For all values other than the two we were looking for, ensure that the + // defining position is non-compute as well as all other positions. + if (!DefiningPositionIsAllowed(value->defining_position().instruction)) { + return Internal( + "HloValue is defined by an unsupported op: %s. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + } + } + + // For the two found HloValues, ensure that all other positions are + // non-compute. + if (dus_value == nullptr) { + return Internal( + "DynamicUpdateSlice's buffer does not have a value which is defined by " + "a dynamic update slice. HloBuffer: %s", + unique_buffer.ToString()); + } + if (!AllPositionsAreAllowed(dus_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + dus_value->defining_position().instruction->name(), + dus_value->ToString()); + } + if (broadcast_value == nullptr) { + return Internal( + "DynamicUpdateSlice's buffer does not have a value which is defined by " + "a broadcast. HloBuffer: %s", + unique_buffer.ToString()); + } + if (!AllPositionsAreAllowed(broadcast_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + broadcast_value->defining_position().instruction->name(), + broadcast_value->ToString()); + } + + // TODO(b/319681297): Further analyze the HloValue defined by the broadcast. + // Make sure that nothing is expecting the result of the broadcast, as we'll + // be replacing it. + + // Check that this buffer is finally an input to at least one slice or + // dynamic-slice. + TF_ASSIGN_OR_RETURN( + std::vector consuming_slices, + GetBufferUsersOfType( + unique_buffer, + match::AnyOf(match::Slice(), match::DynamicSlice()))); + VLOG(2) << dynamic_update_slice->name() + << " is consumed by [dynamic-]slices: [" + << absl::StrJoin(consuming_slices, ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + }) + << ']'; + if (consuming_slices.empty()) { + return Internal( + "The dynamic-update-slice (%s) never feeds into a slice nor " + "dynamic-slice.", + dynamic_update_slice->name()); + } + + // Each dynamic_slice and slice should feed into another annotation. + for (HloInstruction* consuming_slice : consuming_slices) { + VLOG(1) << "Host data produced by " << dynamic_update_slice->name() + << " is consumed by " << consuming_slice->name(); + if (consuming_slice->user_count() != 1) { + return Internal( + "Slice/Dynamic-slice %s should only have one user. It should be an " + "annotation " + "to load the data back on the device. Instead, it has users [%s]", + consuming_slice->name(), + absl::StrJoin(consuming_slice->users(), ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + })); + } + HloInstruction* consuming_slice_user = + FindDSAnnotation(consuming_slice->users()[0]); + if (consuming_slice_user->opcode() != HloOpcode::kCustomCall) { + return Internal( + "Slice/Dynamic-slice %s does not have a matching annotation.", + consuming_slice->name()); + } + if (consuming_slice_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { + return Internal( + "Found custom-call (%s) is not the expected matching host offload " + "annotation", + consuming_slice_user->name()); + } + expected_host_to_device_annotations_.emplace(consuming_slice_user); + } + + // Save the broadcast to later be replaced with a + // custom-call("AllocateBuffer") + broadcasts_to_replace_.emplace( + broadcast_value->defining_position().instruction); + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return OkStatus(); +} + +void HostOffloader::AddAllPositionsToBeMovedToHostMemory( + const HloBuffer& unique_buffer) { + for (const HloValue* value : unique_buffer.values()) { + for (const HloPosition& position : value->positions()) { + positions_to_move_to_host_memory_.emplace(position); + } + } +} + +Status HostOffloader::MemoryOnlyOffloadStartingWithCopy( + const HloInstruction* copy) { + // The user wants to offload the data defined by this copy. + VLOG(2) << "Host memory offload starts with a copy: " << copy->name(); + + // Get the buffer for this copy. + const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(copy); + + // Look for a value defined by a copy. + const HloValue* copy_value = nullptr; + for (const HloValue* value : unique_buffer.values()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (defining_instruction->opcode() == HloOpcode::kCopy) { + if (copy_value != nullptr) { + LOG(WARNING) + << "Already found one dynamic-update-slice value for this buffer"; + } + copy_value = value; + } else { + // For all other values (that aren't defined by a copy), ensure that the + // defining position is non-compute as well as all other positions. + if (!DefiningPositionIsAllowed(value->defining_position().instruction)) { + return Internal( + "HloValue is defined by an unsupported op: %s. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + } + } + + if (copy_value == nullptr) { + return Internal( + "Copy's buffer does not have a value which is defined by a copy. " + "HloBuffer: %s", + unique_buffer.ToString()); + } + // For the copy, ensure that all other positions are non-compute. + if (!AllPositionsAreAllowed(copy_value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + copy_value->defining_position().instruction->name(), + copy_value->ToString()); + } + + // Check that this buffer is finally an input to another copy. + TF_ASSIGN_OR_RETURN(HloInstruction * consuming_copy, + BufferHasPositionWithUser(unique_buffer, match::Copy())); + if (consuming_copy == nullptr) { + return Internal("The copy (%s) never feeds into another copy.", + copy->name()); + } + + // The copy should feed into another annotation. + if (consuming_copy->user_count() != 1) { + return Internal( + "Copy should only have one user. It should be an annotation to load " + "the data back on the device. Instead, it has users [%s]", + absl::StrJoin(consuming_copy->users(), ", ", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + })); + } + HloInstruction* consuming_copy_user = consuming_copy->users()[0]; + if (consuming_copy_user->opcode() != HloOpcode::kCustomCall) { + return Internal("Copy does not have a matching annotation."); + } + if (consuming_copy_user->custom_call_target() != + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) { + return Internal( + "Found custom-call is not the expected matching host offload " + "annotation"); + } + expected_host_to_device_annotations_.emplace(consuming_copy_user); + + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return OkStatus(); +} + +Status HostOffloader::MemoryOnlyOffloadInsertCopies( + HloInstruction* custom_call) { + VLOG(3) << "Found an offload annotation (" << custom_call->name() + << "). Expecting that we'll need to insert copies"; + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(custom_call); + for (const HloValue* value : unique_buffer.values()) { + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (!AllPositionsAreAllowed(value)) { + return Internal( + "HloValue defined by %s has an invalid position. HloValue: %s", + defining_instruction->name(), value->ToString()); + } + } + + // Check that this buffer is finally an input to a load-from-host custom-call. + TF_ASSIGN_OR_RETURN( + std::vector matching_annotations, + GetBufferPositionsWithUser( + unique_buffer, + match::CustomCall({host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget}))); + if (matching_annotations.empty()) { + return Internal( + "The offloaded data (from %s) never feeds into a matching \"load\" " + "annotation.", + custom_call->name()); + } + + // This fits the pattern that we're looking for. Save these annotations to + // later insert copies around. + annotations_for_copy_to_host_to_insert_.emplace(custom_call); + for (HloInstruction* matching_annotation : matching_annotations) { + annotations_for_copy_to_device_to_insert_.emplace(matching_annotation); + + // Save the matching annotation to later be removed. + expected_host_to_device_annotations_.emplace(matching_annotation); + } + + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return OkStatus(); +} + +absl::StatusOr HostOffloader::TryParameterStreaming( + HloInstruction* custom_call) { + HloInstruction* operand_of_load_annotation = custom_call->mutable_operand(0); + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(operand_of_load_annotation); + bool is_defined_by_entry_param_with_host_memory_space = false; + const HloComputation* const entry_computation = + custom_call->GetModule()->entry_computation(); + for (const HloValue* value : unique_buffer.values()) { + // Check if this is memory-only. + if (!AllPositionsAreAllowed(value)) { + // Found a position which is not allowed. + return false; + } + // Look for a value defined by a entry computation parameter. + HloInstruction* defining_instruction = + value->defining_position().instruction; + if (defining_instruction->opcode() == HloOpcode::kParameter) { + if (defining_instruction->parent() == entry_computation) { + const Shape& param_shape = + entry_computation->parent() + ->entry_computation_layout() + .parameter_shape(defining_instruction->parameter_number()); + CHECK(param_shape.has_layout()); + if (param_shape.layout().memory_space() == kHostMemorySpaceColor) { + is_defined_by_entry_param_with_host_memory_space = true; + } + } + } + } + if (!is_defined_by_entry_param_with_host_memory_space) { + VLOG(1) << absl::StreamFormat( + "Buffer annotated by %s is not defined by an entry computation " + "parameter with host memory space.", + custom_call->name()); + return false; + } + + // Create a copy to the device. + Shape copy_shape = operand_of_load_annotation->shape(); + SetMemorySpace(©_shape, Layout::kDefaultMemorySpace); + HloInstruction* copy_to_device = + custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( + copy_shape, HloOpcode::kCopy, operand_of_load_annotation)); + + auto users = operand_of_load_annotation->users(); + for (HloInstruction* use : users) { + if (use == copy_to_device) { + continue; + } + auto callers = call_graph_->GetComputationCallers(copy_to_device->parent()); + if (callers.size() > 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } else if (callers.size() == 1) { + auto* caller = callers[0]; + if (caller->opcode() == HloOpcode::kWhile && + use->opcode() == HloOpcode::kTuple && use->IsRoot()) { + // Do not replace the while loop parameter with the moved data. Because + // of the nature of while loops, since the data started on the host, it + // must end on the host. Only the while loop body's root should not use + // copy_to_device since it's on host at the loop entry. + continue; + } + } + + TF_RETURN_IF_ERROR( + operand_of_load_annotation->ReplaceUseWith(use, copy_to_device)); + } + + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return true; +} + +Status HostOffloader::HandleMoveToDeviceCustomCall( + HloInstruction* custom_call) { + VLOG(2) << "Found a custom call annotating end-of-host-offload: " + << custom_call->ToString(); + TF_ASSIGN_OR_RETURN(bool did_parameter_streaming, + TryParameterStreaming(custom_call)); + if (did_parameter_streaming) { + expected_host_to_device_annotations_.emplace(custom_call); + } + // Save a pointer to this custom call for later removal. + found_host_to_device_annotations_.emplace(custom_call); + return OkStatus(); +} + +Status HostOffloader::DynamifySlice(HloInstruction* slice) { + VLOG(3) << "Dynamifying slice " << slice->ToString(); + std::vector start_constants; + for (int64_t start : slice->slice_starts()) { + HloInstruction* constant = slice->parent()->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(start))); + start_constants.push_back(constant); + } + std::vector slice_sizes; + slice_sizes.reserve(slice->slice_limits().size()); + for (int i = 0; i < slice->slice_limits().size(); ++i) { + slice_sizes.push_back(slice->slice_limits()[i] - slice->slice_starts()[i]); + } + HloInstruction* new_ds = + slice->parent()->AddInstruction(HloInstruction::CreateDynamicSlice( + slice->shape(), slice->mutable_operand(0), start_constants, + slice_sizes)); + VLOG(3) << "Newly created dynamic slice: " << new_ds->name(); + TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(new_ds)); + TF_RETURN_IF_ERROR(slice->parent()->RemoveInstruction(slice)); + return OkStatus(); +} + +absl::StatusOr HostOffloader::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + call_graph_ = CallGraph::Build(module); + + // Run HloAliasAnalysis on module. + TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); + + // Iterate over all instructions and look for XLA host offload annotations. + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() != HloOpcode::kCustomCall) { + continue; + } + if (instruction->custom_call_target() == + host_memory_offload_annotations::kMoveToHostCustomCallTarget) { + TF_RETURN_IF_ERROR(HandleMoveToHostCustomCall(instruction)); + } else if (instruction->custom_call_target() == + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget) { + TF_RETURN_IF_ERROR(HandleMoveToDeviceCustomCall(instruction)); + } + } + } + + // Insert copies to the host for the saved annotations. + for (HloInstruction* to_host_annotation : + annotations_for_copy_to_host_to_insert_) { + HloInstruction* data_to_host = to_host_annotation->mutable_operand(0); + // Create a copy (to host) of the first and only operand to the given custom + // call. + HloInstruction* copy_to_host = + data_to_host->parent()->AddInstruction(HloInstruction::CreateUnary( + data_to_host->shape(), HloOpcode::kCopy, data_to_host)); + // Replace all uses of the to-host annotation with the first copy. + TF_RETURN_IF_ERROR(to_host_annotation->ReplaceAllUsesWith(copy_to_host)); + // Also save the position of the newly created copy-to-host to later have + // its memory space updated. + positions_to_move_to_host_memory_.emplace(HloPosition{copy_to_host}); + } + + // Insert copies to the device for the saved annotations. + for (HloInstruction* to_device_annotation : + annotations_for_copy_to_device_to_insert_) { + HloInstruction* data_to_device = to_device_annotation->mutable_operand(0); + // Create another copy (back to device) of that copy. + HloInstruction* copy_to_device = + data_to_device->parent()->AddInstruction(HloInstruction::CreateUnary( + data_to_device->shape(), HloOpcode::kCopy, data_to_device)); + // Replace all uses of the to-device annotation with the second copy. + TF_RETURN_IF_ERROR( + to_device_annotation->ReplaceAllUsesWith(copy_to_device)); + } + + // Check that we found all the annotations that we expected. + if (found_host_to_device_annotations_ != + expected_host_to_device_annotations_) { + return Internal( + "There is a mismatch between the expected host-to-device annotations " + "(%s) and the found host-to-device annotations (%s)", + absl::StrJoin(expected_host_to_device_annotations_, ", ", + [](std::string* str, HloInstruction* instr) { + str->append(instr->name()); + }), + absl::StrJoin(found_host_to_device_annotations_, ", ", + [](std::string* str, HloInstruction* instr) { + str->append(instr->name()); + })); + } + + // Remove these host-to-device annotations. + for (HloInstruction* instr : found_host_to_device_annotations_) { + custom_calls_to_remove_.emplace(instr); + } + + absl::flat_hash_set slices_to_dynamify; + // Change the memory space of these positions to the host memory space. + for (const HloPosition& position : positions_to_move_to_host_memory_) { + // If a user of this position is a slice, change it to be a dynamic-slice. + for (HloInstruction* user : position.instruction->users()) { + if (user->opcode() == HloOpcode::kSlice) { + slices_to_dynamify.emplace(user); + } + } + Shape* shape_to_change = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + VLOG(2) << "Setting instruction to have host memory space: " + << position.instruction->name(); + SetMemorySpace(shape_to_change, kHostMemorySpaceColor); + changed = true; + } + + for (HloInstruction* user : slices_to_dynamify) { + TF_RETURN_IF_ERROR(DynamifySlice(user)); + } + + // Replace these broadcasts with AllocateBuffer instructions for host memory. + for (HloInstruction* broadcast : broadcasts_to_replace_) { + HloInstruction* allocate_buffer = + broadcast->parent()->AddInstruction(HloInstruction::CreateCustomCall( + broadcast->shape(), {}, "AllocateBuffer")); + VLOG(2) << "Replacing broadcast " << broadcast->name() + << " with AllocateBuffer " << allocate_buffer->ToString(); + SetMemorySpace(allocate_buffer->mutable_shape(), kHostMemorySpaceColor); + CHECK_OK(broadcast->ReplaceAllUsesWith(allocate_buffer)); + TF_RETURN_IF_ERROR(broadcast->parent()->RemoveInstruction(broadcast)); + changed = true; + } + + // Recompute alias analysis after changes. + TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); + auto uses_parameter_buffer = [this](HloInstruction* hlo) { + for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(hlo)) { + for (const HloValue* value : buffer->values()) { + for (const HloPosition& pos : value->positions()) { + if (absl::c_linear_search(hlo->parent()->parameter_instructions(), + pos.instruction)) { + return true; + } + } + } + } + return false; + }; + // Remove these custom-calls that were previously used for annotation. + for (HloInstruction* custom_call : custom_calls_to_remove_) { + CHECK_EQ(custom_call->operand_count(), 1); + HloInstruction* operand = custom_call->operands()[0]; + if (custom_call->parent() != + custom_call->GetModule()->entry_computation() && + custom_call->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + // Replace custom call with a copy for dynamic-update-slice in case it + // used parameter buffer directly because in case of aliasing with loop + // parameters control dependencies can mess with scheduling. + if (uses_parameter_buffer(operand)) { + VLOG(10) << "Adding copy for custom call " << custom_call->name(); + operand = + custom_call->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + } else { + VLOG(10) << "NOT Adding copy for custom call " << custom_call->name(); + } + } + CHECK_OK(custom_call->ReplaceAllUsesWith(operand)); + TF_RETURN_IF_ERROR(custom_call->parent()->RemoveInstruction(custom_call)); + changed = true; + } + + return changed; +} + +} // namespace xla diff --git a/xla/service/host_offloader.h b/xla/service/host_offloader.h new file mode 100644 index 0000000000000..cd6e319a6fd50 --- /dev/null +++ b/xla/service/host_offloader.h @@ -0,0 +1,102 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_SERVICE_HOST_OFFLOADER_H_ +#define XLA_SERVICE_HOST_OFFLOADER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +class HloCostAnalysis; + +// This pass does "host memory offloading". If a tensor is annotated to be moved +// to or from the host, this pass will remove the annotations and update each +// tensor's layout with host memory spaces and insert copies if necessary. This +// pass checks to make sure that no compute is done on the tensors annotated for +// host memory offload; if there is compute, it is considered a user error and +// an error will be returned. +class HostOffloader : public HloModulePass { + public: + explicit HostOffloader(int64_t host_memory_space_color) + : kHostMemorySpaceColor(host_memory_space_color) {} + ~HostOffloader() override = default; + + absl::string_view name() const override { return "host-offloader"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + static absl::Span GetAllowedPositionOpcodes() { + return kAllowedPositionOpcodes; + } + + private: + const int64_t kHostMemorySpaceColor; + std::unique_ptr alias_analysis_; + absl::flat_hash_set found_host_to_device_annotations_; + absl::flat_hash_set expected_host_to_device_annotations_; + absl::flat_hash_set custom_calls_to_remove_; + absl::flat_hash_set broadcasts_to_replace_; + absl::flat_hash_set positions_to_move_to_host_memory_; + absl::flat_hash_set annotations_for_copy_to_host_to_insert_; + absl::flat_hash_set + annotations_for_copy_to_device_to_insert_; + std::unique_ptr call_graph_; + + // Positions of all HloValues of the given HloBuffer will be added to + // positions_to_move_to_host_memory_. + void AddAllPositionsToBeMovedToHostMemory(const HloBuffer& unique_buffer); + + absl::StatusOr TryParameterStreaming(HloInstruction* custom_call); + absl::StatusOr TryOutputStreaming(HloInstruction* custom_call); + Status HandleMoveToHostCustomCall(HloInstruction* custom_call); + Status HandleMoveToDeviceCustomCall(HloInstruction* custom_call); + + // Handle memory-only offloading where the data is written to the host via a + // dynamic-update-slice and is read back via a dynamic-slice. + Status MemoryOnlyOffloadStartingWithDus( + const HloInstruction* dynamic_update_slice); + + // Handle memory-only offloading where the data is written to the host via a + // copy and is read back via a copy. + Status MemoryOnlyOffloadStartingWithCopy(const HloInstruction* copy); + + // Handle memory-only offloading where there are no ops yet for data movement. + // We will insert copies at the points where the annotations are. + Status MemoryOnlyOffloadInsertCopies(HloInstruction* custom_call); + + Status DynamifySlice(HloInstruction* slice); + + static constexpr std::array kAllowedPositionOpcodes = { + HloOpcode::kBitcast, + HloOpcode::kGetTupleElement, + HloOpcode::kOptimizationBarrier, + HloOpcode::kParameter, + HloOpcode::kTuple, + HloOpcode::kWhile}; +}; + +} // namespace xla + +#endif // XLA_SERVICE_HOST_OFFLOADER_H_ diff --git a/xla/service/host_offloader_test.cc b/xla/service/host_offloader_test.cc new file mode 100644 index 0000000000000..73579a11d1c3d --- /dev/null +++ b/xla/service/host_offloader_test.cc @@ -0,0 +1,2062 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offloader.h" + +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/host_offload_legalize.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace m = ::xla::match; + +namespace xla { +namespace { + +class HostOffloaderTest : public HloTestBase { + protected: + static constexpr int64_t kHostMemorySpaceColor{5}; + + absl::StatusOr RunHostOffloader(HloModule* module, + bool after_layout = false) { + TF_EXPECT_OK(verifier().Run(module).status()); + if (module->has_schedule()) { + return absl::InternalError("Expected a non-scheduled module"); + } + bool changed = false; + HostOffloadLegalize host_offload_legalize(kHostMemorySpaceColor, + after_layout); + TF_ASSIGN_OR_RETURN(bool legal_changed, host_offload_legalize.Run(module)); + changed |= legal_changed; + HostOffloader host_offloader(kHostMemorySpaceColor); + TF_ASSIGN_OR_RETURN(bool offload_changed, host_offloader.Run(module)); + changed |= offload_changed; + return changed; + } + + void TestShapeHasMemorySpace(const Shape& shape, int64_t memory_space) { + ASSERT_TRUE(shape.has_layout()); + EXPECT_EQ(shape.layout().memory_space(), memory_space); + } + + bool HaveRemainingOffloadAnnotations(const HloModule* module) { + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->IsCustomCall( + {host_memory_offload_annotations::kMoveToHostCustomCallTarget, + host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget})) { + return true; + } + } + } + return false; + } +}; + +TEST_F(HostOffloaderTest, BasicDusDs) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + offload_custom_call = f32[1,2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, offload_custom_call, index_param, constant_s32_0, constant_s32_0) + dynamic_slice = f32[1,2048,2048] dynamic-slice(dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // "AllocateBuffer" param_0 _... + // | / / + // dynamic-update-slice _... + // | / + // dynamic-slice + HloInstruction* param; + HloInstruction* allocate_buffer; + HloInstruction* dynamic_update_slice; + HloInstruction* dynamic_slice; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice( + &dynamic_slice, + m::DynamicUpdateSlice( + &dynamic_update_slice, + m::CustomCall(&allocate_buffer, {"AllocateBuffer"}), + m::Parameter(¶m, 0), m::Op(), m::Op(), m::Op()), + m::Op(), m::Op(), m::Op()))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(allocate_buffer->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, BasicCopy) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + copy_0 = f32[2048] copy(offload_custom_call) + copy_1 = f32[2048] copy(copy_0) + ROOT load_custom_call = f32[2048] custom-call(copy_1), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy(©_to_device, + m::Copy(©_to_host, m::Parameter(¶m, 0))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, ParameterStreamingWithXposeCopyFeedingIntoWhile) { + const std::string& hlo_string = R"( +HloModule jit__prefill_impl, entry_computation_layout={(bf16[2,16,16]{2,1,0:T(8,128)(2,1)S(5)})->bf16[2,16,16]{1,2,0:T(8,128)(2,1)}} + +while_condition { + condition_param = (s32[], bf16[2,16,16]{1,2,0:T(8,128)(2,1)}, bf16[2,16,16]{1,2,0:T(8,128)(2,1)}) parameter(0) + condition_current_iteration_index = s32[] get-tuple-element(condition_param), index=0 + condition_iteration_count = s32[] constant(16) + ROOT condition_result = pred[] compare(condition_current_iteration_index, condition_iteration_count), direction=LT +} + +while_body { + input_tuple.0 = (s32[], bf16[2,16,16]{1,2,0:T(8,128)(2,1)}, bf16[2,16,16]{1,2,0:T(8,128)(2,1)}) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + orig_data = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=1 + custom-call.0 = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} custom-call(orig_data), custom_call_target="MoveToDevice" + sum = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} get-tuple-element(input_tuple.0), index=2 + sum.1 = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} add(custom-call.0, sum) + + constant_1 = s32[] constant(1) + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1) + ROOT tuple_result.0 = (s32[], bf16[2,16,16]{1,2,0:T(8,128)(2,1)}, bf16[2,16,16]{1,2,0:T(8,128)(2,1)}) tuple(incremented_index.0, custom-call.0, sum.1) +} + +ENTRY main { + param.0 = bf16[2,16,16]{2,1,0:T(8,128)(2,1)} parameter(0) + copy = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} copy(param.0) + constant_0 = s32[] constant(0) + constant_0.0 = bf16[] constant(0.0) + broadcast = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} broadcast(constant_0.0), dimensions={} + tuple_for_while = (s32[], bf16[2,16,16]{1,2,0:T(8,128)(2,1)}, bf16[2,16,16]{1,2,0:T(8,128)(2,1)}) tuple(constant_0, copy, broadcast) + while = (s32[], bf16[2,16,16]{1,2,0:T(8,128)(2,1)}, bf16[2,16,16]{1,2,0:T(8,128)(2,1)}) while(tuple_for_while), condition=while_condition, body=while_body + ROOT gte = bf16[2,16,16]{1,2,0:T(8,128)(2,1)} get-tuple-element(while), index=2 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, RunHostOffloader(module.get(), /*after_layout=*/true)); + EXPECT_TRUE(changed); + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); + HloVerifier verifier(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/true); + TF_EXPECT_OK(verifier.Run(module.get()).status()); + VLOG(1) << "module after: " << module->ToString(); +} + +TEST_F(HostOffloaderTest, BasicNoCopy) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + ROOT load_custom_call = f32[2048] custom-call(offload_custom_call), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy(©_to_device, + m::Copy(©_to_host, m::Parameter(¶m, 0))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyWithOptBarrier) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + tuple = (f32[2048]) tuple(offload_custom_call) + opt_barrier = (f32[2048]) opt-barrier(tuple) + get_tuple_element = f32[2048] get-tuple-element(opt_barrier), index=0 + ROOT load_custom_call = f32[2048] custom-call(get_tuple_element), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param + // | + // copy (to host) + // | + // tuple + // | + // opt-barrier + // | + // get-tuple-element + // | + // copy (to device) + + HloInstruction* param; + HloInstruction* copy_to_host; + HloInstruction* tuple; + HloInstruction* opt_barrier; + HloInstruction* gte; + HloInstruction* copy_to_device; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Copy( + ©_to_device, + m::GetTupleElement( + >e, m::OptimizationBarrier( + &opt_barrier, + m::Tuple(&tuple, m::Copy(©_to_host, + m::Parameter(¶m, 0)))))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyMultipleToDevice) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + constant = f32[] constant(0) + custom_call_0 = f32[] custom-call(constant), custom_call_target="MoveToHost" + tuple_0 = (f32[], f32[]) tuple(custom_call_0, custom_call_0) + opt_barrier = (f32[], f32[]) opt-barrier(tuple_0) + gte_0 = f32[] get-tuple-element(opt_barrier), index=0 + custom_call_1 = f32[] custom-call(gte_0), custom_call_target="MoveToDevice" + gte_1 = f32[] get-tuple-element(opt_barrier), index=1 + custom_call_2 = f32[] custom-call(gte_1), custom_call_target="MoveToDevice" + ROOT tuple_1 = (f32[], f32[]) tuple(custom_call_1, custom_call_2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // copy + // | | + // tuple + // | + // opt-barrier + // / \ + // gte gte + // | | + // copy copy + // \ / + // tuple + HloInstruction* constant; + HloInstruction* copy_to_host_1; + HloInstruction* copy_to_host_2; + HloInstruction* tuple_1; + HloInstruction* opt_barrier; + HloInstruction* gte_1; + HloInstruction* copy_to_device_1; + HloInstruction* gte_2; + HloInstruction* copy_to_device_2; + HloInstruction* tuple_2; + const auto constant_pattern = m::ConstantScalar(&constant, 0); + const auto opt_barrier_pattern = m::OptimizationBarrier( + &opt_barrier, + m::Tuple(&tuple_1, m::Copy(©_to_host_1, constant_pattern), + m::Copy(©_to_host_2, constant_pattern))); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + &tuple_2, + m::Copy(©_to_device_1, + m::GetTupleElement(>e_1, opt_barrier_pattern)), + m::Copy(©_to_device_2, + m::GetTupleElement(>e_2, opt_barrier_pattern))))); + TestShapeHasMemorySpace(constant->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_host_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple_1->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple_1->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device_1->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(gte_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device_2->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple_2->shape(), {0}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple_2->shape(), {1}), + Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyWithOptBarrierMoreElaborate) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16]{0})->f32[16]{0}} + +ENTRY main.24 { + Arg_0.1 = f32[16]{0} parameter(0), sharding={devices=[2]<=[2]} + cosine.4 = f32[16]{0} cosine(Arg_0.1) + custom-call.5 = f32[16]{0} custom-call(cosine.4), custom_call_target="MoveToHost" + sine.3 = f32[16]{0} sine(Arg_0.1) + cosine.7 = f32[16]{0} cosine(sine.3) + custom-call.8 = f32[16]{0} custom-call(cosine.7), custom_call_target="MoveToHost" + sine.6 = f32[16]{0} sine(sine.3) + cosine.9 = f32[16]{0} cosine(sine.6) + custom-call.10 = f32[16]{0} custom-call(cosine.9), custom_call_target="MoveToHost" + constant.2 = f32[] constant(1) + tuple.11 = (f32[16]{0}, f32[16]{0}, f32[16]{0}, f32[]) tuple(custom-call.5, custom-call.8, custom-call.10, constant.2) + opt-barrier.12 = (f32[16]{0}, f32[16]{0}, f32[16]{0}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16]{0} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16]{0} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16]{0} custom-call(get-tuple-element.15), custom_call_target="MoveToDevice" + multiply.21 = f32[16]{0} multiply(broadcast.20, custom-call.19) + get-tuple-element.14 = f32[16]{0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16]{0} custom-call(get-tuple-element.14), custom_call_target="MoveToDevice" + multiply.22 = f32[16]{0} multiply(multiply.21, custom-call.18) + get-tuple-element.13 = f32[16]{0} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16]{0} custom-call(get-tuple-element.13), custom_call_target="MoveToDevice" + ROOT multiply.23 = f32[16]{0} multiply(multiply.22, custom-call.17) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param constant + // __________/ | | + // / | | + // cosine sine | + // | | \____________ | + // | | \ | + // | | sine | + // | | | | + // | cosine cosine | + // | | | | + // copy(to host) copy(to host) copy(to host) | + // \ \ / | + // \______________ | | _________________/ + // \ | | / + // tuple + // | + // opt-barrier + // _____________/ / \ \_____________ + // / / \ \ + // get-tuple-element get-tuple-element get-tuple-element get-tuple-element + // | | | | + // copy(to device) copy(to device) copy(to device) broadcast + // \ \ \ / + // \ \__________ multiply + // \ \ / + // \ multiply + // \_________________________ / + // \ / + // multiply + + HloInstruction* param; + HloInstruction* constant; + HloInstruction* sine_0; + HloInstruction* sine_1; + HloInstruction* cosine_0; + HloInstruction* cosine_1; + HloInstruction* cosine_2; + HloInstruction* copy_to_host_0; + HloInstruction* copy_to_host_1; + HloInstruction* copy_to_host_2; + HloInstruction* tuple; + HloInstruction* opt_barrier; + HloInstruction* gte_0; + HloInstruction* gte_1; + HloInstruction* gte_2; + HloInstruction* gte_3; + HloInstruction* broadcast; + HloInstruction* copy_to_device_0; + HloInstruction* copy_to_device_1; + HloInstruction* copy_to_device_2; + HloInstruction* multiply_0; + HloInstruction* multiply_1; + HloInstruction* multiply_2; + + auto parameter_matcher = m::Parameter(¶m, 0); + auto first_sine_matcher = m::Op(&sine_0) + .WithOpcode(xla::HloOpcode::kSin) + .WithOperand(0, parameter_matcher); + auto opt_barrier_matcher = m::OptimizationBarrier( + &opt_barrier, + m::Tuple( + &tuple, + m::Copy(©_to_host_0, m::Op(&cosine_0) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, parameter_matcher)), + m::Copy(©_to_host_1, m::Op(&cosine_1) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, first_sine_matcher)), + m::Copy(©_to_host_2, + m::Op(&cosine_2) + .WithOpcode(xla::HloOpcode::kCos) + .WithOperand(0, m::Op(&sine_1) + .WithOpcode(xla::HloOpcode::kSin) + .WithOperand(0, first_sine_matcher))), + m::Constant(&constant))); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + &multiply_0, + m::Multiply( + &multiply_1, + m::Multiply( + &multiply_2, + m::Broadcast(&broadcast, m::GetTupleElement( + >e_3, opt_barrier_matcher, 3)), + m::Copy(©_to_device_2, + m::GetTupleElement(>e_2, opt_barrier_matcher, 2))), + m::Copy(©_to_device_1, + m::GetTupleElement(>e_1, opt_barrier_matcher, 1))), + m::Copy(©_to_device_0, + m::GetTupleElement(>e_0, opt_barrier_matcher, 0))))); + + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(constant->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(sine_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(sine_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(cosine_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_host_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_host_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {3}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {0}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(opt_barrier->shape(), {3}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_2->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_3->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_1->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_device_2->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, NoCopyMultipleUsers) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048] parameter(0) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + sine = f32[2048] sine(data_param) + load_custom_call = f32[2048] custom-call(offload_custom_call), custom_call_target="MoveToDevice" + ROOT add = f32[2048] add(sine, load_custom_call) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // parameter + // / \ + // sine copy + // | | + // | copy + // | / + // add + HloInstruction* param; + HloInstruction* sine; + HloInstruction* copy_to_host; + HloInstruction* copy_to_device; + HloInstruction* add; + const auto param_pattern = m::Parameter(¶m, 0); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Add( + &add, m::Sin(&sine, param_pattern), + m::Copy(©_to_device, m::Copy(©_to_host, param_pattern))))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(sine->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy_to_host->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(add->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, BasicDusDsWithMultipleBroadcastUsers) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + tanh = f32[2,2048,2048] tanh(broadcast) + offload_custom_call = f32[1,2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, offload_custom_call, index_param, constant_s32_0, constant_s32_0) + dynamic_slice = f32[1,2048,2048] dynamic-slice(dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // "AllocateBuffer" param_0 _... + // | / / + // dynamic-update-slice _... + // | / + // dynamic-slice + HloInstruction* param; + HloInstruction* allocate_buffer; + HloInstruction* dynamic_update_slice; + HloInstruction* dynamic_slice; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice( + &dynamic_slice, + m::DynamicUpdateSlice( + &dynamic_update_slice, + m::CustomCall(&allocate_buffer, {"AllocateBuffer"}), + m::Parameter(¶m, 0), m::Op(), m::Op(), m::Op()), + m::Op(), m::Op(), m::Op()))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(allocate_buffer->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); + + // Look for the tanh and make sure that it still uses the original broadcast. + HloInstruction* tanh = nullptr; + for (HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kTanh) { + tanh = instruction; + break; + } + } + ASSERT_NE(tanh, nullptr); + HloInstruction* broadcast; + EXPECT_THAT(tanh, GmockMatch(m::Tanh(m::Broadcast(&broadcast)))); + TestShapeHasMemorySpace(broadcast->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(tanh->shape(), Layout::kDefaultMemorySpace); +} + +TEST_F(HostOffloaderTest, BasicDusDsBitcastBeforeDus) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + offload_custom_call = f32[2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + bitcast = f32[1,2048,2048] bitcast(offload_custom_call) + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, bitcast, index_param, constant_s32_0, constant_s32_0) + dynamic_slice = f32[1,2048,2048] dynamic-slice(dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // param_0 + // | + // "AllocateBuffer" bitcast _... + // | / / + // dynamic-update-slice _... + // | / + // dynamic-slice + HloInstruction* param; + HloInstruction* bitcast; + HloInstruction* allocate_buffer; + HloInstruction* dynamic_update_slice; + HloInstruction* dynamic_slice; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice( + &dynamic_slice, + m::DynamicUpdateSlice( + &dynamic_update_slice, + m::CustomCall(&allocate_buffer, {"AllocateBuffer"}), + m::Bitcast(&bitcast, m::Parameter(¶m, 0)), m::Op(), + m::Op(), m::Op()), + m::Op(), m::Op(), m::Op()))); + TestShapeHasMemorySpace(param->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(bitcast->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(allocate_buffer->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +// The annotation is mistakenly after the dynamic-update-slice; it should be +// before. +TEST_F(HostOffloaderTest, BasicDusDsDusAnnotationOnWrongSide) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, data_param, index_param, constant_s32_0, constant_s32_0) + offload_custom_call = f32[1,2048,2048] custom-call(dynamic_update_slice), custom_call_target="MoveToHost" + dynamic_slice = f32[1,2048,2048] dynamic-slice(offload_custom_call, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + absl::StatusOr statusOrChanged = RunHostOffloader(module.get()); + // The pass should return an error. + ASSERT_FALSE(statusOrChanged.ok()); +} + +// The annotation is mistakenly before the dynamic-slice; it should be after. +TEST_F(HostOffloaderTest, BasicDusDsDsAnnotationOnWrongSide) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + offload_custom_call = f32[1,2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, offload_custom_call, index_param, constant_s32_0, constant_s32_0) + load_custom_call = f32[2,2048,2048] custom-call(dynamic_update_slice), custom_call_target="MoveToDevice" + ROOT dynamic_slice = f32[1,2048,2048] dynamic-slice(load_custom_call, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + absl::StatusOr statusOrChanged = RunHostOffloader(module.get()); + // The pass should return an error. + ASSERT_FALSE(statusOrChanged.ok()); +} + +TEST_F(HostOffloaderTest, LlmActivation) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + data_1.0 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.0), index=2 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + slice_data_1 = f32[1,8,6,2048,1] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + custom_call_1.0 = f32[1,8,6,2048,1] custom-call(slice_data_1), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + dynamic_update_slice_1 = f32[96,8,6,2048,1] dynamic-update-slice(data_1.0, custom_call_1.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.0, dynamic_update_slice_0, dynamic_update_slice_1) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + data_1.1 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.1), index=2 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + dynamic_slice_1 = f32[1,8,6,2048,1] dynamic-slice(data_1.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,1} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="MoveToDevice" + custom_call_1.1 = f32[1,8,6,2048,1] custom-call(dynamic_slice_1), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + tanh_1 = f32[1,8,6,2048,1] tanh(custom_call_1.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.1, data_0.1, data_1.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1) + producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + while_output_2 = f32[96,8,6,2048,1] get-tuple-element(producing_while), index=2 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, while_output_1, while_output_2) + ROOT consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // First, look for the pattern: + // producing_while + // / \ + // gte gte constant + // \ / / + // \/ / + // tuple + // | + // consuming_while + HloInstruction* consuming_while; + HloInstruction* producing_while_0; + HloInstruction* producing_while_1; + { + HloInstruction* tuple; + HloInstruction* gte_0; + HloInstruction* gte_1; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::While( + &consuming_while, + m::Tuple( + &tuple, m::Constant(), + m::GetTupleElement(>e_0, m::While(&producing_while_0)), + m::GetTupleElement(>e_1, m::While(&producing_while_1)))))); + ASSERT_EQ(producing_while_0, producing_while_1); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + } + + // Now, look for the AllocateBuffers leading into the producing while. + { + HloInstruction* allocate_buffer_0; + HloInstruction* allocate_buffer_1; + ASSERT_THAT(producing_while_0, + GmockMatch(m::While(m::Tuple( + m::Constant(), + m::CustomCall(&allocate_buffer_0, {"AllocateBuffer"}), + m::CustomCall(&allocate_buffer_1, {"AllocateBuffer"}))))); + // Check that the memory spaces were properly set. + ASSERT_TRUE(allocate_buffer_0->shape().has_layout()); + EXPECT_EQ(allocate_buffer_0->shape().layout().memory_space(), + kHostMemorySpaceColor); + ASSERT_TRUE(allocate_buffer_1->shape().has_layout()); + EXPECT_EQ(allocate_buffer_1->shape().layout().memory_space(), + kHostMemorySpaceColor); + } + + // There are 4 computations to look at: + // - Consuming while's body + // - Consuming while's condition + // - Producing while's body + // - Producing while's condition + + // For the condition computations, just check that the parameters have the + // right memory space. + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {2}), + kHostMemorySpaceColor); + + // Now, check the producing while for the following pattern: + // param param + // | | + // gte _... gte _... + // | / | / + // | / | / + // | / | / + // dus dus + // | / + // | / + // _ | / + // \ | / + // \ | / + // \| / + // tuple + { + HloInstruction* tuple; + HloInstruction* dynamic_update_slice_0; + HloInstruction* dynamic_update_slice_1; + HloInstruction* dynamic_update_slice_second_param_0; + HloInstruction* dynamic_update_slice_second_param_1; + HloInstruction* gte_0; + HloInstruction* gte_1; + HloInstruction* param_0; + HloInstruction* param_1; + ASSERT_THAT(producing_while_0->while_body()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, m::Op(), + m::DynamicUpdateSlice( + &dynamic_update_slice_0, + m::GetTupleElement(>e_0, m::Parameter(¶m_0)), + m::Op(&dynamic_update_slice_second_param_0), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op()), + m::DynamicUpdateSlice( + &dynamic_update_slice_1, + m::GetTupleElement(>e_1, m::Parameter(¶m_1)), + m::Op(&dynamic_update_slice_second_param_1), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op())))); + EXPECT_EQ(param_0, param_1); + + // Check that the memory spaces were properly set. + // HOST: + // tuple subshape 1 + // tuple subshape 2 + // dynamic_update_slice_0 shape + // dynamic_update_slice_1 shape + // gte_0 shape + // gte_1 shape + // param_0 subshape 1 + // param_0 subshape 2 + // DEVICE: + // dynamic_update_slice_second_param_0 + // dynamic_update_slice_second_param_1 + + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_0->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_1->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_1->shape(), + Layout::kDefaultMemorySpace); + } + + // Now, check the consuming while for the following pattern: + // param + // | | + // gte gte + // | | + // ds ds + { + // Since we do not do anything meaningful with the result of the + // dynamic-slices, there is no easy way to access them from the root. + // Instead, search from the parameter and find all dynamic-slices. + EXPECT_EQ(consuming_while->while_body()->parameter_instructions().size(), + 1); + const HloInstruction* param = + consuming_while->while_body()->parameter_instruction(0); + absl::flat_hash_set dynamic_slices; + std::stack stack; + stack.emplace(param); + while (!stack.empty()) { + const HloInstruction* current = stack.top(); + stack.pop(); + if (current->opcode() == HloOpcode::kDynamicSlice) { + dynamic_slices.emplace(current); + continue; + } + // Add all users. + for (const HloInstruction* user : current->users()) { + stack.emplace(user); + } + } + // There should only be two dynamic-slices. + ASSERT_EQ(dynamic_slices.size(), 2); + for (const HloInstruction* dynamic_slice : dynamic_slices) { + const HloInstruction* get_tuple_element; + const HloInstruction* parameter; + ASSERT_THAT( + dynamic_slice, + GmockMatch(m::DynamicSlice( + m::GetTupleElement(&get_tuple_element, m::Parameter(¶meter)), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()))); + + // Check that the memory spaces were properly set. + // HOST: + // parameter subshape 1 + // parameter subshape 2 + // get_tuple_element + // DEVICE: + // dynamic_slice + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(get_tuple_element->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), + Layout::kDefaultMemorySpace); + } + } + + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, LlmActivationDsWithReshape) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + data_1.0 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.0), index=2 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + slice_data_1 = f32[1,8,6,2048,1] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + custom_call_1.0 = f32[1,8,6,2048,1] custom-call(slice_data_1), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + dynamic_update_slice_1 = f32[96,8,6,2048,1] dynamic-update-slice(data_1.0, custom_call_1.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.0, dynamic_update_slice_0, dynamic_update_slice_1) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + data_1.1 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.1), index=2 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + dynamic_slice_1 = f32[1,8,6,2048,1] dynamic-slice(data_1.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,1} + rs = f32[1,8,6,2048,2048] reshape(dynamic_slice_0) + rs2 = f32[1,8,6,2048,1] reshape(dynamic_slice_1) + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(rs), custom_call_target="MoveToDevice" + custom_call_1.1 = f32[1,8,6,2048,1] custom-call(rs2), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + tanh_1 = f32[1,8,6,2048,1] tanh(custom_call_1.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.1, data_0.1, data_1.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1) + producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + while_output_2 = f32[96,8,6,2048,1] get-tuple-element(producing_while), index=2 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, while_output_1, while_output_2) + ROOT consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // First, look for the pattern: + // producing_while + // / \ + // gte gte constant + // \ / / + // \/ / + // tuple + // | + // consuming_while + HloInstruction* consuming_while; + HloInstruction* producing_while_0; + HloInstruction* producing_while_1; + { + HloInstruction* tuple; + HloInstruction* gte_0; + HloInstruction* gte_1; + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::While( + &consuming_while, + m::Tuple( + &tuple, m::Constant(), + m::GetTupleElement(>e_0, m::While(&producing_while_0)), + m::GetTupleElement(>e_1, m::While(&producing_while_1)))))); + ASSERT_EQ(producing_while_0, producing_while_1); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + } + + // Now, look for the AllocateBuffers leading into the producing while. + { + HloInstruction* allocate_buffer_0; + HloInstruction* allocate_buffer_1; + ASSERT_THAT(producing_while_0, + GmockMatch(m::While(m::Tuple( + m::Constant(), + m::CustomCall(&allocate_buffer_0, {"AllocateBuffer"}), + m::CustomCall(&allocate_buffer_1, {"AllocateBuffer"}))))); + // Check that the memory spaces were properly set. + ASSERT_TRUE(allocate_buffer_0->shape().has_layout()); + EXPECT_EQ(allocate_buffer_0->shape().layout().memory_space(), + kHostMemorySpaceColor); + ASSERT_TRUE(allocate_buffer_1->shape().has_layout()); + EXPECT_EQ(allocate_buffer_1->shape().layout().memory_space(), + kHostMemorySpaceColor); + } + + // There are 4 computations to look at: + // - Consuming while's body + // - Consuming while's condition + // - Producing while's body + // - Producing while's condition + + // For the condition computations, just check that the parameters have the + // right memory space. + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {2}), + kHostMemorySpaceColor); + + // Now, check the producing while for the following pattern: + // param param + // | | + // gte _... gte _... + // | / | / + // | / | / + // | / | / + // dus dus + // | / + // | / + // _ | / + // \ | / + // \ | / + // \| / + // tuple + { + HloInstruction* tuple; + HloInstruction* dynamic_update_slice_0; + HloInstruction* dynamic_update_slice_1; + HloInstruction* dynamic_update_slice_second_param_0; + HloInstruction* dynamic_update_slice_second_param_1; + HloInstruction* gte_0; + HloInstruction* gte_1; + HloInstruction* param_0; + HloInstruction* param_1; + ASSERT_THAT(producing_while_0->while_body()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, m::Op(), + m::DynamicUpdateSlice( + &dynamic_update_slice_0, + m::GetTupleElement(>e_0, m::Parameter(¶m_0)), + m::Op(&dynamic_update_slice_second_param_0), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op()), + m::DynamicUpdateSlice( + &dynamic_update_slice_1, + m::GetTupleElement(>e_1, m::Parameter(¶m_1)), + m::Op(&dynamic_update_slice_second_param_1), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op())))); + EXPECT_EQ(param_0, param_1); + + // Check that the memory spaces were properly set. + // HOST: + // tuple subshape 1 + // tuple subshape 2 + // dynamic_update_slice_0 shape + // dynamic_update_slice_1 shape + // gte_0 shape + // gte_1 shape + // param_0 subshape 1 + // param_0 subshape 2 + // DEVICE: + // dynamic_update_slice_second_param_0 + // dynamic_update_slice_second_param_1 + + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_0->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_1->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param_0->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(dynamic_update_slice_second_param_1->shape(), + Layout::kDefaultMemorySpace); + } + + // Now, check the consuming while for the following pattern: + // param + // | | + // gte gte + // | | + // ds ds + { + // Since we do not do anything meaningful with the result of the + // dynamic-slices, there is no easy way to access them from the root. + // Instead, search from the parameter and find all dynamic-slices. + EXPECT_EQ(consuming_while->while_body()->parameter_instructions().size(), + 1); + const HloInstruction* param = + consuming_while->while_body()->parameter_instruction(0); + absl::flat_hash_set dynamic_slices; + std::stack stack; + stack.emplace(param); + while (!stack.empty()) { + const HloInstruction* current = stack.top(); + stack.pop(); + if (current->opcode() == HloOpcode::kDynamicSlice) { + dynamic_slices.emplace(current); + continue; + } + // Add all users. + for (const HloInstruction* user : current->users()) { + stack.emplace(user); + } + } + // There should only be two dynamic-slices. + ASSERT_EQ(dynamic_slices.size(), 2); + for (const HloInstruction* dynamic_slice : dynamic_slices) { + const HloInstruction* get_tuple_element; + const HloInstruction* parameter; + ASSERT_THAT( + dynamic_slice, + GmockMatch(m::DynamicSlice( + m::GetTupleElement(&get_tuple_element, m::Parameter(¶meter)), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()))); + + // Check that the memory spaces were properly set. + // HOST: + // parameter subshape 1 + // parameter subshape 2 + // get_tuple_element + // DEVICE: + // dynamic_slice + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {2}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(get_tuple_element->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), + Layout::kDefaultMemorySpace); + } + } + + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, LlmActivationHostMemoryMultipleConsumers) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048]) tuple(incremented_index.0, dynamic_update_slice_0) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048]) tuple(incremented_index.1, data_0.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + entry_param_1 = s32[] parameter(1) + entry_param_2 = s32[] parameter(2) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048]) tuple(constant_s32_0, broadcast_0) + producing_while = (s32[], f32[96,8,6,2048,2048]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048]) tuple(constant_s32_0, while_output_1) + consuming_while = (s32[], f32[96,8,6,2048,2048]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body + second_while_output = f32[96,8,6,2048,2048] get-tuple-element(consuming_while), index=1 + final_dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(second_while_output, entry_param_1, constant_s32_0, constant_s32_0, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,8,6,2048,2048} + final_host_to_device_custom_call_0 = f32[1,8,6,2048,2048] custom-call(final_dynamic_slice_0), custom_call_target="MoveToDevice" + final_slice_0 = f32[1,8,6,2048,2048] slice(second_while_output), slice={[41:42], [0:8], [0:6], [0:2048], [0:2048]} + final_host_to_device_custom_call_1 = f32[1,8,6,2048,2048] custom-call(final_slice_0), custom_call_target="MoveToDevice" + ROOT add = f32[1,8,6,2048,2048] add(final_host_to_device_custom_call_0, final_host_to_device_custom_call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // First, look for the pattern: + // producing_while + // | + // constant gte + // \ | + // \ | + // tuple + // | + // consuming_while + // | + // gte + // / \ + // dynamic-slice dynamic-slice + // \ / + // add + // Note: The second dynamic-slice was originally a slice. + HloInstruction* consuming_while; + HloInstruction* producing_while; + { + HloInstruction* tuple; + HloInstruction* gte_between_whiles; + HloInstruction* final_gte; + HloInstruction* dynamic_slice_0; + HloInstruction* dynalic_slice_1; + HloInstruction* add; + auto pattern_ending_in_gte = m::GetTupleElement( + &final_gte, + m::While(&consuming_while, + m::Tuple(&tuple, m::Constant(), + m::GetTupleElement(>e_between_whiles, + m::While(&producing_while))))); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Add(&add, + m::DynamicSlice(&dynamic_slice_0, pattern_ending_in_gte, + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()), + m::DynamicSlice(&dynalic_slice_1, pattern_ending_in_gte, + m::ConstantScalar(41), m::Op(), m::Op(), + m::Op(), m::Op())))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(gte_between_whiles->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(consuming_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape(producing_while->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(final_gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice_0->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(dynalic_slice_1->shape(), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(add->shape(), Layout::kDefaultMemorySpace); + } + + // Now, look for the AllocateBuffers leading into the producing while. + { + HloInstruction* allocate_buffer; + ASSERT_THAT(producing_while, + GmockMatch(m::While(m::Tuple( + m::Constant(), + m::CustomCall(&allocate_buffer, {"AllocateBuffer"}))))); + // Check that the memory spaces were properly set. + ASSERT_TRUE(allocate_buffer->shape().has_layout()); + EXPECT_EQ(allocate_buffer->shape().layout().memory_space(), + kHostMemorySpaceColor); + } + + // There are 4 computations to look at: + // - Consuming while's body + // - Consuming while's condition + // - Producing while's body + // - Producing while's condition + + // For the condition computations, just check that the parameters have the + // right memory space. + TestShapeHasMemorySpace( + ShapeUtil::GetSubshape( + consuming_while->while_condition()->parameter_instruction(0)->shape(), + {1}), + kHostMemorySpaceColor); + + // Now, check the producing while for the following pattern: + // param + // | + // gte _ + // | / + // | / + // _ dus + // \ | + // tuple + { + HloInstruction* tuple; + HloInstruction* dynamic_update_slice; + HloInstruction* dynamic_update_slice_second_param; + HloInstruction* gte; + HloInstruction* param; + ASSERT_THAT( + producing_while->while_body()->root_instruction(), + GmockMatch(m::Tuple(&tuple, m::Op(), + m::DynamicUpdateSlice( + &dynamic_update_slice, + m::GetTupleElement(>e, m::Parameter(¶m)), + m::Op(&dynamic_update_slice_second_param), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op())))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice->shape(), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(gte->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(param->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_update_slice_second_param->shape(), + Layout::kDefaultMemorySpace); + } + + // Now, check the consuming while for the following pattern: + // param + // | + // gte + // | + // ds + { + // Since we do not do anything meaningful with the result of the + // dynamic-slices, there is no easy way to access them from the root. + // Instead, search from the parameter and find all dynamic-slices. + EXPECT_EQ(consuming_while->while_body()->parameter_instructions().size(), + 1); + const HloInstruction* param = + consuming_while->while_body()->parameter_instruction(0); + absl::flat_hash_set dynamic_slices; + std::stack stack; + stack.emplace(param); + while (!stack.empty()) { + const HloInstruction* current = stack.top(); + stack.pop(); + if (current->opcode() == HloOpcode::kDynamicSlice) { + dynamic_slices.emplace(current); + continue; + } + // Add all users. + for (const HloInstruction* user : current->users()) { + stack.emplace(user); + } + } + // There should only be one dynamic-slice. + ASSERT_EQ(dynamic_slices.size(), 1); + const HloInstruction* dynamic_slice = *dynamic_slices.begin(); + const HloInstruction* get_tuple_element; + const HloInstruction* parameter; + ASSERT_THAT( + dynamic_slice, + GmockMatch(m::DynamicSlice( + m::GetTupleElement(&get_tuple_element, m::Parameter(¶meter)), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op()))); + + // Check that the memory spaces were properly set. + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(parameter->shape(), {1}), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(get_tuple_element->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(dynamic_slice->shape(), + Layout::kDefaultMemorySpace); + } + + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, InsertExtraCopyForScheduling) { + const std::string& hlo_string = R"( +HloModule llm_while + +producing_while_condition { + producing_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) parameter(0) + producing_condition_current_iteration_index = s32[] get-tuple-element(producing_condition_param), index=0 + producing_condition_iteration_count = s32[] constant(96) + ROOT producing_condition_result = pred[] compare(producing_condition_current_iteration_index, producing_condition_iteration_count), direction=LT +} + +consuming_while_condition { + consuming_condition_param = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + consuming_condition_current_iteration_index = s32[] get-tuple-element(consuming_condition_param), index=0 + consuming_condition_iteration_count = s32[] constant(96) + ROOT consuming_condition_result = pred[] compare(consuming_condition_current_iteration_index, consuming_condition_iteration_count), direction=LT +} + +producing_while_body { + input_tuple.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) parameter(0) + current_iteration_index.0 = s32[] get-tuple-element(input_tuple.0), index=0 + data_0.0 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.0), index=1 + data_1.0 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.0), index=2 + data_2.1 = f32[1,8,6,2048,1] get-tuple-element(input_tuple.0), index=3 + constant_0.0 = s32[] constant(0) + constant_1.0 = s32[] constant(1) + constant_96 = s32[] constant(96) + + /* Create dummy data used in DUS */ + slice_data_0 = f32[1,8,6,2048,2048] constant({...}) + slice_data_1 = f32[1,8,6,2048,1] constant({...}) + + /* Build DUS index */ + compare_result.0 = pred[] compare(current_iteration_index.0, constant_0.0), direction=LT + add_result = s32[] add(current_iteration_index.0, constant_96) + select_result.0 = s32[] select(compare_result.0, add_result, current_iteration_index.0) + + /* Annotate DUS for offload */ + custom_call_0.0 = f32[1,8,6,2048,2048] custom-call(slice_data_0), custom_call_target="MoveToHost" + custom_call_1.0 = f32[1,8,6,2048,1] custom-call(data_2.1), custom_call_target="MoveToHost" + + dynamic_update_slice_0 = f32[96,8,6,2048,2048] dynamic-update-slice(data_0.0, custom_call_0.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + dynamic_update_slice_1 = f32[96,8,6,2048,1] dynamic-update-slice(data_1.0, custom_call_1.0, select_result.0, constant_0.0, constant_0.0, constant_0.0, constant_0.0) + + /* Increment iteration index */ + incremented_index.0 = s32[] add(current_iteration_index.0, constant_1.0) + ROOT tuple_result.0 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) tuple(incremented_index.0, dynamic_update_slice_0, dynamic_update_slice_1, data_2.1) +} + +consuming_while_body { + input_tuple.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) parameter(0) + current_iteration_index.1 = s32[] get-tuple-element(input_tuple.1), index=0 + data_0.1 = f32[96,8,6,2048,2048] get-tuple-element(input_tuple.1), index=1 + data_1.1 = f32[96,8,6,2048,1] get-tuple-element(input_tuple.1), index=2 + constant_0.1 = s32[] constant(0) + constant_1.1 = s32[] constant(1) + constant_95 = s32[] constant(95) + constant_191 = s32[] constant(191) + + /* Build DS index */ + subtract_0 = s32[] subtract(constant_95, current_iteration_index.1) + compare_result.1 = pred[] compare(subtract_0, constant_0.1), direction=LT + subtract_1 = s32[] subtract(constant_191, current_iteration_index.1) + select_result.1 = s32[] select(compare_result.1, subtract_1, subtract_0) + + dynamic_slice_0 = f32[1,8,6,2048,2048] dynamic-slice(data_0.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,2048} + dynamic_slice_1 = f32[1,8,6,2048,1] dynamic-slice(data_1.1, select_result.1, constant_0.1, constant_0.1, constant_0.1, constant_0.1), dynamic_slice_sizes={1,8,6,2048,1} + + /* Annotate DS for offload */ + custom_call_0.1 = f32[1,8,6,2048,2048] custom-call(dynamic_slice_0), custom_call_target="MoveToDevice" + custom_call_1.1 = f32[1,8,6,2048,1] custom-call(dynamic_slice_1), custom_call_target="MoveToDevice" + + /* Do some work with the dynamic slice outputs. */ + tanh_0 = f32[1,8,6,2048,2048] tanh(custom_call_0.1) + tanh_1 = f32[1,8,6,2048,1] tanh(custom_call_1.1) + + /* Increment iteration index */ + incremented_index.1 = s32[] add(current_iteration_index.1, constant_1.1) + ROOT tuple_result.1 = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(incremented_index.1, data_0.1, data_1.1) +} + +ENTRY main { + entry_param_0 = f32[] parameter(0) + broadcast_0 = f32[96,8,6,2048,2048] broadcast(entry_param_0), dimensions={} + broadcast_1 = f32[96,8,6,2048,1] broadcast(entry_param_0), dimensions={} + broadcast_2 = f32[1,8,6,2048,1] broadcast(entry_param_0), dimensions={} + constant_s32_0 = s32[] constant(0) + tuple_for_producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) tuple(constant_s32_0, broadcast_0, broadcast_1, broadcast_2) + producing_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1], f32[1,8,6,2048,1]) while(tuple_for_producing_while), condition=producing_while_condition, body=producing_while_body + while_output_1 = f32[96,8,6,2048,2048] get-tuple-element(producing_while), index=1 + while_output_2 = f32[96,8,6,2048,1] get-tuple-element(producing_while), index=2 + tuple_for_consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) tuple(constant_s32_0, while_output_1, while_output_2) + ROOT consuming_while = (s32[], f32[96,8,6,2048,2048], f32[96,8,6,2048,1]) while(tuple_for_consuming_while), condition=consuming_while_condition, body=consuming_while_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + // Finally, ensure that all annotations have been removed. + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); + const HloInstruction* dus0 = + FindInstruction(module.get(), "dynamic_update_slice_0"); + const HloInstruction* dus1 = + FindInstruction(module.get(), "dynamic_update_slice_1"); + EXPECT_THAT(dus0, GmockMatch(m::DynamicUpdateSlice(m::Op(), m::Constant(), + m::Op(), m::Op(), m::Op(), + m::Op(), m::Op()))); + EXPECT_THAT(dus1, GmockMatch(m::DynamicUpdateSlice(m::Op(), m::Copy(), + m::Op(), m::Op(), m::Op(), + m::Op(), m::Op()))); +} + +TEST_F(HostOffloaderTest, ParameterStreaming) { + const std::string& hlo_string = R"( +HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})} + +ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + custom_call = s32[2,1]{1,0} custom-call(param_0), custom_call_target="MoveToDevice" + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, custom_call) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + ROOT tuple = (s32[2,1]{1,0}, s32[2,1]{1,0}) tuple(multiply_2, multiply_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // param1 broadcast param0 + // \ / / + // multiply copy + // \ / + // \ / + // multiply constant + // | | | + // | ---+---broadcast + // | / | + // multiply | + // \ | + // tuple + HloInstruction* param_1; + HloInstruction* broadcast_0; + HloInstruction* multiply_0; + HloInstruction* param_0; + HloInstruction* copy; + HloInstruction* multiply_1; + HloInstruction* broadcast_1; + HloInstruction* multiply_2; + HloInstruction* tuple; + auto multiplyPattern = + m::Multiply(&multiply_1, + m::Multiply(&multiply_0, m::Parameter(¶m_1), + m::Broadcast(&broadcast_0, m::ConstantScalar(2))), + m::Copy(©, m::Parameter(¶m_0))); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, + m::Multiply(&multiply_2, multiplyPattern, + m::Broadcast(&broadcast_1, m::ConstantScalar(4))), + multiplyPattern))); + TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(param_0->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(copy->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, OutputStreaming) { + const std::string& hlo_string = R"( + HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})} + + ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" + ROOT tuple = (s32[2,1]{1,0}, s32[2,1]{1,0}) tuple(custom_call, multiply_1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // param1 broadcast param0 + // \ / / + // multiply / + // \ / + // \ / + // multiply constant + // | | | + // | ---+---broadcast + // | / | + // multiply | + // | | + // copy | + // \ | + // tuple + HloInstruction* param_1; + HloInstruction* broadcast_0; + HloInstruction* multiply_0; + HloInstruction* param_0; + HloInstruction* multiply_1; + HloInstruction* broadcast_1; + HloInstruction* multiply_2; + HloInstruction* copy; + HloInstruction* tuple; + auto multiplyPattern = + m::Multiply(&multiply_1, + m::Multiply(&multiply_0, m::Parameter(¶m_1), + m::Broadcast(&broadcast_0, m::ConstantScalar(2))), + m::Parameter(¶m_0)); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, + m::Copy(©, m::Multiply( + &multiply_2, multiplyPattern, + m::Broadcast(&broadcast_1, m::ConstantScalar(4)))), + multiplyPattern))); + TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + Layout::kHostMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +TEST_F(HostOffloaderTest, OutputStreamingCustomCallRoot) { + const std::string& hlo_string = R"( + HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}} + + ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + ROOT custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // param1 broadcast param0 + // \ / / + // multiply / + // \ / + // \ / + // multiply constant + // | | + // | ---+---broadcast + // | / + // multiply + // | + // copy + HloInstruction* param_1; + HloInstruction* broadcast_0; + HloInstruction* multiply_0; + HloInstruction* param_0; + HloInstruction* multiply_1; + HloInstruction* broadcast_1; + HloInstruction* multiply_2; + HloInstruction* copy; + auto multiplyPattern = + m::Multiply(&multiply_1, + m::Multiply(&multiply_0, m::Parameter(¶m_1), + m::Broadcast(&broadcast_0, m::ConstantScalar(2))), + m::Parameter(¶m_0)); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Copy( + ©, m::Multiply(&multiply_2, multiplyPattern, + m::Broadcast(&broadcast_1, + m::ConstantScalar(4)))))); + TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + +} // namespace + +} // namespace xla diff --git a/xla/service/human_readable_profile_builder.cc b/xla/service/human_readable_profile_builder.cc index b78de778978e0..3e46a3e775415 100644 --- a/xla/service/human_readable_profile_builder.cc +++ b/xla/service/human_readable_profile_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/human_readable_profile_builder.h b/xla/service/human_readable_profile_builder.h index cbd7ebdcd0731..8f192eebdc46b 100644 --- a/xla/service/human_readable_profile_builder.h +++ b/xla/service/human_readable_profile_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/indexed_array_analysis.cc b/xla/service/indexed_array_analysis.cc index 90b15f5032523..fe4b20cdc3781 100644 --- a/xla/service/indexed_array_analysis.cc +++ b/xla/service/indexed_array_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -81,7 +81,7 @@ std::string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { } } -StatusOr IndexedArrayAnalysis::GetArrayFor( +absl::StatusOr IndexedArrayAnalysis::GetArrayFor( const HloInstruction* instr) { auto it = cache_.find(instr); if (it != cache_.end()) { @@ -137,7 +137,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( return OkStatus(); } -StatusOr IndexedArrayAnalysis::ComputeArrayFor( +absl::StatusOr IndexedArrayAnalysis::ComputeArrayFor( const HloInstruction* instr) { Array* computed_array; if (instr->IsElementwise() && instr->operand_count() == 1) { @@ -184,12 +184,12 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( return computed_array; } -StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( const Literal& literal) { return Construct(&literal); } -StatusOr IndexedArrayAnalysis::FoldGatherOfGather( +absl::StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64_t source_dim, absl::Span output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). @@ -255,7 +255,7 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( std::move(shape)); } -StatusOr IndexedArrayAnalysis::ComputeArrayForGather( +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, absl::Span slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { @@ -468,7 +468,7 @@ Shape StripDegenerateDimensions(const Shape& shape) { } }; // namespace -StatusOr +absl::StatusOr IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( ScalarIndexedArray* operand) { const Shape& shape = operand->shape(); @@ -525,7 +525,8 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( StripDegenerateDimensions(operand->shape())); } -StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( +absl::StatusOr +IndexedArrayAnalysis::ReshapeToAddDegenerateDims( ScalarIndexedArray* operand, absl::Span degenerate_dims) { if (degenerate_dims.empty()) { return operand; @@ -602,7 +603,7 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( InlinedVectorToVector(new_output_dims), new_result_shape); } -StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( +absl::StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand) { VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; @@ -636,7 +637,7 @@ StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( degenerate_result_dims); } -StatusOr +absl::StatusOr IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) @@ -803,7 +804,7 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( output_dims_for_new_scalar_indexed_node, shape); } -StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( const Shape& shape, Array* operand) { if (ShapeUtil::Compatible(operand->shape(), shape)) { return operand; @@ -828,7 +829,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( return Construct(operand, shape); } -StatusOr +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, Array* lhs, Array* rhs) { @@ -949,7 +950,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, scalar_indexed_const->shape()); } -StatusOr +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, Array* operand) { auto* scalar_indexed_const = @@ -1032,7 +1033,7 @@ bool CanFoldDotIntoIndexedArray( } // namespace -StatusOr +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, @@ -1067,7 +1068,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( SpanToVector(lhs->output_dims()), shape); } -StatusOr +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, ConstantArray* lhs, @@ -1103,7 +1104,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( SpanToVector(rhs->output_dims()), shape); } -StatusOr IndexedArrayAnalysis::ComputeArrayForDot( +absl::StatusOr IndexedArrayAnalysis::ComputeArrayForDot( const Shape& shape, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if @@ -1149,7 +1150,7 @@ absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } -StatusOr IndexedArrayAnalysisPrinterPass::Run( +absl::StatusOr IndexedArrayAnalysisPrinterPass::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!VLOG_IS_ON(2)) { diff --git a/xla/service/indexed_array_analysis.h b/xla/service/indexed_array_analysis.h index 4ea4d30e87d8d..c0746b1f68dc6 100644 --- a/xla/service/indexed_array_analysis.h +++ b/xla/service/indexed_array_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -245,7 +245,7 @@ class IndexedArrayAnalysis { // NB! By inspecting the implementation, you may be able to infer a stronger // caching guarantee than what is mentioned above. Nevertheless, what is // stated above is the contract. - StatusOr GetArrayFor(const HloInstruction* instr); + absl::StatusOr GetArrayFor(const HloInstruction* instr); // Pretty-prints the expression rooted at `root`. std::string ToString(Array* root, bool print_constants = false); @@ -257,28 +257,27 @@ class IndexedArrayAnalysis { // Creates an Array instance for `instr` under the assumption that all // operations of `instr` are present in `cache_`. - StatusOr ComputeArrayFor(const HloInstruction* instr); + absl::StatusOr ComputeArrayFor(const HloInstruction* instr); - StatusOr ComputeArrayForConstant(const Literal& literal); + absl::StatusOr ComputeArrayForConstant(const Literal& literal); - StatusOr ComputeArrayForGather( + absl::StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, absl::Span slice_sizes, Array* source, Array* indices); - StatusOr ComputeArrayForDotWithIndexedLhs( + absl::StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, ConstantArray* rhs); - StatusOr ComputeArrayForDotWithIndexedRhs( + absl::StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, ConstantArray* lhs, ScalarIndexedConstantArray* rhs); - StatusOr ComputeArrayForDot(const Shape& shape, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, - Array* lhs, Array* rhs); + absl::StatusOr ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a @@ -301,30 +300,32 @@ class IndexedArrayAnalysis { // // I2 = [I0[i] for i in I1] // G1 = [Arr[i] for i in I2] - StatusOr FoldGatherOfGather( + absl::StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64_t source_dim, absl::Span output_dims, Shape shape); // Reshapes a scalar-indexed node to remove the degenerate dimensions in its // output. The result is always a scalar-indexed node. - StatusOr ReshapeToRemoveDegenerateDims( + absl::StatusOr ReshapeToRemoveDegenerateDims( ScalarIndexedArray* operand); // Reshapes a scalar-indexed node such that the result has the degenerate // dimensions `degenerate_dims`. The result is always a scalar-indexed node. - StatusOr ReshapeToAddDegenerateDims( + absl::StatusOr ReshapeToAddDegenerateDims( ScalarIndexedArray* operand, absl::Span degenerate_dims); - StatusOr FoldReshapeOfGather( + absl::StatusOr FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand); - StatusOr FoldReshapeOfGatherNoDegenerateDims( + absl::StatusOr FoldReshapeOfGatherNoDegenerateDims( const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); - StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + absl::StatusOr ComputeArrayForReshape(const Shape& shape, + Array* operand); - StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, - Array* lhs, Array* rhs); - StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, - Array* operand); + absl::StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs); + absl::StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); template T* Construct(Args&&... args) { @@ -352,7 +353,8 @@ class IndexedArrayAnalysis { return &owned_literals_.back(); } - StatusOr TakeOwnership(StatusOr literal_or_error) { + absl::StatusOr TakeOwnership( + absl::StatusOr literal_or_error) { TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); return &owned_literals_.back(); @@ -370,7 +372,7 @@ class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: absl::string_view name() const override; using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/indexed_array_analysis_test.cc b/xla/service/indexed_array_analysis_test.cc index c11d483b3f1c3..d711c8f1fdc6b 100644 --- a/xla/service/indexed_array_analysis_test.cc +++ b/xla/service/indexed_array_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/instruction_fusion.cc b/xla/service/instruction_fusion.cc index 140144f1086e3..bcc9d2e4549f4 100644 --- a/xla/service/instruction_fusion.cc +++ b/xla/service/instruction_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -177,12 +177,14 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: @@ -498,7 +500,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { } // namespace -std::vector InstructionFusion::GetFusionComputations( +std::vector InstructionFusion::GetNonFusionComputations( HloModule* module, const absl::flat_hash_set& execution_threads) { // Use sorted computations because fusion configuration is order-sensitive. @@ -510,7 +512,7 @@ std::unique_ptr InstructionFusion::GetFusionQueue( return std::make_unique(computation); } -StatusOr InstructionFusion::Run( +absl::StatusOr InstructionFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -520,7 +522,8 @@ StatusOr InstructionFusion::Run( bool dump_fusion = module->config().debug_options().xla_dump_fusion_visualization(); - for (auto* computation : GetFusionComputations(module, execution_threads)) { + for (auto* computation : + GetNonFusionComputations(module, execution_threads)) { CHECK(!computation->IsFusionComputation()); std::unique_ptr reachability = HloReachabilityMap::Build(computation); diff --git a/xla/service/instruction_fusion.h b/xla/service/instruction_fusion.h index 81a22adf0747f..8ace349f141db 100644 --- a/xla/service/instruction_fusion.h +++ b/xla/service/instruction_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -101,24 +101,6 @@ class FusionDecision { return *this; } - // Executes the given fusibility checks in order, until one fails. Returns the - // result of the first failure (or a `FusionDecision` with `CanFuse() == true` - // if none did). - // Usage: - // FusionDecision result = FusionDecision::All(std::tuple{FnOne, FnTwo}, - // arg1, arg2) - template - static FusionDecision All(const std::tuple& checks, - const Args&... args) { - FusionDecision result = {}; - std::apply( - [&](auto&&... fns) { - ((result = result ? fns(args...) : result), ...); - }, - checks); - return result; - } - // Appends to explanation, or turns the decision negative. FusionDecision operator<<(absl::string_view explanation) const { return {absl::StrCat(explanation_.value_or(""), explanation)}; @@ -137,6 +119,14 @@ class FusionDecision { std::optional explanation_; }; +#define RETURN_IF_NOT_FUSIBLE(...) \ + do { \ + ::xla::FusionDecision _decision = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_decision.CanFuse())) { \ + return _decision; \ + } \ + } while (0) + // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -158,7 +148,7 @@ class InstructionFusion : public HloModulePass { // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -176,8 +166,9 @@ class InstructionFusion : public HloModulePass { const HloInstruction* consumer); protected: - // Returns a list of computations on which Fusion is performed. - virtual std::vector GetFusionComputations( + // Returns a list of computations that are not fusion computations. These + // computations contain instructions which are candidates for fusions. + virtual std::vector GetNonFusionComputations( HloModule* module, const absl::flat_hash_set& execution_threads); diff --git a/xla/service/instruction_fusion_test.cc b/xla/service/instruction_fusion_test.cc index fc7029ceaeb34..9d225b9107b97 100644 --- a/xla/service/instruction_fusion_test.cc +++ b/xla/service/instruction_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -770,56 +770,4 @@ TEST_F(FusionDecisionTest, NotFusionPossibleDisjunction) { EXPECT_FALSE(!a || !b); } -TEST_F(FusionDecisionTest, AllExecutesAllChecks) { - bool first_called = false; - bool second_called = false; - auto result = FusionDecision::All(std::tuple{ - [&]() -> FusionDecision { - first_called = true; - return {}; - }, - [&]() -> FusionDecision { - second_called = true; - return {}; - }, - }); - - EXPECT_TRUE(result.CanFuse()); - EXPECT_TRUE(first_called); - EXPECT_TRUE(second_called); -} - -TEST_F(FusionDecisionTest, AllShortCircuits) { - bool second_called = false; - auto result = FusionDecision::All(std::tuple{ - [&]() -> FusionDecision { return "failure"; }, - [&]() -> FusionDecision { - second_called = true; - return {}; - }, - }); - - EXPECT_EQ(result.Explain(), "failure"); - EXPECT_FALSE(second_called); -} - -TEST_F(FusionDecisionTest, AllForwardsArgs) { - int64_t sum = 0; - auto result = FusionDecision::All( - std::tuple{ - [&](int64_t value1, int64_t value2) -> FusionDecision { - sum += value1; - return {}; - }, - [&](int64_t value1, int64_t value2) -> FusionDecision { - sum += value2; - return {}; - }, - }, - 42, 9000); - - EXPECT_TRUE(result.CanFuse()); - EXPECT_EQ(sum, 9042); -} - } // namespace xla diff --git a/xla/service/instruction_hoister.cc b/xla/service/instruction_hoister.cc index 551eaa0572037..58e27f3dbf118 100644 --- a/xla/service/instruction_hoister.cc +++ b/xla/service/instruction_hoister.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -129,7 +129,7 @@ bool HoistConstantOperations( } } // namespace -StatusOr InstructionHoister::Run( +absl::StatusOr InstructionHoister::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool modified = false; diff --git a/xla/service/instruction_hoister.h b/xla/service/instruction_hoister.h index c066c4c8e5eec..0f0f1683e314d 100644 --- a/xla/service/instruction_hoister.h +++ b/xla/service/instruction_hoister.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class InstructionHoister : public HloModulePass { absl::string_view name() const override { return "instruction-hoister"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 84f8888cf763e..aa0a2fd45c43b 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,11 +52,17 @@ limitations under the License. namespace xla { namespace { + +const int64_t kDefaultMemorySpace = 0; + bool IsNopInstruction(const HloInstruction& hlo) { HloOpcode op = hlo.opcode(); return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || op == HloOpcode::kConstant || op == HloOpcode::kParameter || - hlo.IsEffectiveBitcast(); + op == HloOpcode::kBroadcast || op == HloOpcode::kIota || + hlo.IsEffectiveBitcast() || + (op == HloOpcode::kTuple && hlo.user_count() == 1 && + hlo.users().front()->opcode() == HloOpcode::kWhile); } } // namespace @@ -71,6 +77,10 @@ CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) { return {HloOpcode::kAsyncStart, HloOpcode::kAllGather}; case HloOpcode::kCollectivePermuteStart: return {HloOpcode::kAsyncStart, HloOpcode::kCollectivePermute}; + case HloOpcode::kCopyStart: + return {HloOpcode::kAsyncStart, HloOpcode::kCopy}; + case HloOpcode::kCopyDone: + return {HloOpcode::kAsyncDone, HloOpcode::kCopy}; case HloOpcode::kAllReduceDone: return {HloOpcode::kAsyncDone, HloOpcode::kAllReduce}; case HloOpcode::kAllGatherDone: @@ -130,7 +140,9 @@ bool AsyncTracker::IsSupportedAsyncDone(const HloInstruction& hlo) const { case HloOpcode::kAllToAll: case HloOpcode::kAllGather: case HloOpcode::kAllReduce: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: + case HloOpcode::kCopy: case HloOpcode::kReduceScatter: return true; default: @@ -157,7 +169,9 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const { case HloOpcode::kAllToAll: case HloOpcode::kAllGather: case HloOpcode::kAllReduce: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: + case HloOpcode::kCopy: case HloOpcode::kReduceScatter: return true; default: @@ -167,7 +181,7 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const { return false; } -ResourcesVector AsyncTracker::GetResourcesFromInstruction( +ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl( const HloInstruction& hlo) const { CanonicalAsyncOp op = GetCanonicalAsyncOp(hlo); auto get_resource_for_op = [](HloOpcode op) -> ResourceType { @@ -178,8 +192,12 @@ ResourcesVector AsyncTracker::GetResourcesFromInstruction( return ResourceType::kAllGather; case HloOpcode::kAllToAll: return ResourceType::kAllToAll; + case HloOpcode::kCollectiveBroadcast: + return ResourceType::kCollectiveBroadcast; case HloOpcode::kCollectivePermute: return ResourceType::kCollectivePermute; + case HloOpcode::kCopy: + return ResourceType::kCopy; case HloOpcode::kReduceScatter: return ResourceType::kReduceScatter; default: @@ -244,6 +262,14 @@ ResourcesVector AsyncTracker::GetResourcesFromInstruction( } } +ResourcesVector AsyncTracker::GetResourcesFromInstruction( + const HloInstruction& hlo) const { + if (!resources_cache_.contains(&hlo)) { + resources_cache_.insert({&hlo, GetResourcesFromInstructionImpl(hlo)}); + } + return resources_cache_.at(&hlo); +} + int64_t AsyncTracker::GetNumResourcesPerInstruction( ResourceType resource_type, const HloInstruction& instr) const { return GetNumResourcesPerInstruction(ResourceTypeToIndex(resource_type), @@ -312,9 +338,14 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction( void AsyncTracker::SetConcurrentResourceLimits( absl::flat_hash_map& max_concurrent_resource) const { // Set the limits for default resources + max_concurrent_resource[ResourceTypeToIndex( + ResourceType::kCollectiveBroadcast)] = + config_.collective_broadcast_overlap_limit; max_concurrent_resource[ResourceTypeToIndex( ResourceType::kCollectivePermute)] = config_.collective_permute_overlap_limit; + max_concurrent_resource[ResourceTypeToIndex(ResourceType::kCopy)] = + config_.copy_overlap_limit; max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllToAll)] = config_.all_to_all_overlap_limit; max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllGather)] = @@ -348,14 +379,20 @@ absl::string_view AsyncTracker::GetResourceName(int64_t resource_type) const { return "kAllGather"; case ResourceTypeToIndex(ResourceType::kAllReduce): return "kAllReduce"; + case ResourceTypeToIndex(ResourceType::kCollectiveBroadcast): + return "kCollectiveBroadcast"; case ResourceTypeToIndex(ResourceType::kCollectivePermute): return "kCollectivePermute"; + case ResourceTypeToIndex(ResourceType::kCopy): + return "kCopy"; case ResourceTypeToIndex(ResourceType::kSendRecv): return "kSendRecv"; case ResourceTypeToIndex(ResourceType::kSendHost): return "kSendHost"; case ResourceTypeToIndex(ResourceType::kRecvHost): return "kRecvHost"; + case ResourceTypeToIndex(ResourceType::kReduceScatter): + return "kReduceScatter"; default: return "Not a valid default resource"; } @@ -527,6 +564,10 @@ void MemoryPressureTracker::Initialize( if (!initial_live_buffers.empty()) { for (HloBuffer::Id id : initial_live_buffers) { auto& buffer = buffer_tracker_.GetBufferInfo(id); + if (buffer.value->values()[0]->shape().has_layout() && + buffer.value->values()[0]->shape().layout().memory_space() != 0) { + continue; + } live_buffers_[buffer.value->id()] = 1; initial_memory_pressure_ += buffer.buffer_size; } @@ -553,7 +594,10 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) { for (auto* op : instruction->operands()) { auto& output_values = output_buffers_[op]; for (auto& info : output_values) { - if (ShouldSkipBufferAllocations(instruction, info.second)) { + if (ShouldSkipBufferAllocations(instruction, info.second) || + (info.first.value->values()[0]->shape().has_layout() && + info.first.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace)) { continue; } if (live_buffers_[info.first.value->id()] == 0) { @@ -569,6 +613,11 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) { CHECK(it != defined_buffers_.end()); if (!ShouldSkipBufferReleases(instruction)) { for (auto& b : it->second) { + if (b.value->values()[0]->shape().has_layout() && + b.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace) { + continue; + } if (live_buffers_[b.value->id()] != 0) { if (b.first_definition == instruction) { live_memory_usage_ -= b.buffer_size; @@ -604,7 +653,10 @@ std::pair MemoryPressureTracker::MemoryPressureDifference( auto it = output_buffers_.find(op); CHECK(it != output_buffers_.end()); for (auto& b : it->second) { - if (ShouldSkipBufferAllocations(instruction, b.second)) { + if (ShouldSkipBufferAllocations(instruction, b.second) || + (b.first.value->values()[0]->shape().has_layout() && + b.first.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace)) { continue; } if (!live_buffers_[b.first.value->id()]) { @@ -618,6 +670,11 @@ std::pair MemoryPressureTracker::MemoryPressureDifference( // Decrease memory pressure if some buffers are released. if (!ShouldSkipBufferReleases(instruction)) { for (auto& b : it->second) { + if (b.value->values()[0]->shape().has_layout() && + b.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace) { + continue; + } if (live_buffers_[b.value->id()]) { if (b.first_definition == instruction) { increase -= b.buffer_size; @@ -658,17 +715,18 @@ class ReadySetLt { DefaultSchedulerCore::CandidateResult operator()( DefaultSchedulerCore::ScheduleCandidate& a, DefaultSchedulerCore::ScheduleCandidate& b) const { + // Schedule according to ForceEarly. + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + a.node->GetForceEarly(), a, b.node->GetForceEarly(), b, + "kForceEarly")) { + return *value; + } // Schedule according to ForceDelay first. if (auto value = DefaultSchedulerCore::ChooseBestCandidate( !a.node->GetForceDelay(), a, !b.node->GetForceDelay(), b, "kForceDelay")) { return *value; } - if (early_target_scheduling_rule_) { - if (auto value = early_target_scheduling_rule_(a, b)) { - return *value; - } - } // Prioritize instructions that are NOPs as they have no memory pressure // issue and unlock different operations for being scheduled. if (auto value = DefaultSchedulerCore::ChooseBestCandidate( @@ -730,6 +788,11 @@ class ReadySetLt { return *value; } } + if (early_target_scheduling_rule_) { + if (auto value = early_target_scheduling_rule_(a, b)) { + return *value; + } + } // Some heuristic that try to prioritize unlocking "done" instructions // so that we can perform overlap. More fancy heuristics can be used by // discovering the closest "done" to every instruction and prioritize @@ -837,8 +900,10 @@ class ReadySetLt { HloGraphNode::TimeCost b_cost_diff = std::abs( latest_ready - sched_state_.current_time - b.node->GetCost()); if (auto value = DefaultSchedulerCore::ChooseBestCandidate( - a_cost_diff < b_cost_diff, a, b_cost_diff < a_cost_diff, b, - "kAvoidWaste")) { + !a.node->DoesReleaseAnyResource() && a_cost_diff < b_cost_diff, + a, + !b.node->DoesReleaseAnyResource() && b_cost_diff < a_cost_diff, + b, "kAvoidWaste")) { return *value; } } @@ -1248,7 +1313,7 @@ bool DefaultSchedulerCore::AddOccupierToResource( return true; } -StatusOr DefaultSchedulerCore::ScheduleNode( +absl::StatusOr DefaultSchedulerCore::ScheduleNode( HloGraphNode* n, DefaultSchedulerCore::SchedulingState* sched_state) const { // Insert the node into the sequence and mark it as scheduled. sched_state->new_sequence_reversed.push_back( @@ -1503,6 +1568,8 @@ HloScheduleGraph::HloScheduleGraph( HloGraphNode* ctrl_succ_node = ctrl_succ_node_it->second.get(); add_dependency_helper(instr_node, ctrl_succ_node); } + //TODO: HloBuffer outdegree have no decrese, remove it + continue; // To make sure an instruction that aliases with the buffer produced // by the async-done operation is not scheduled in between the start and the // done instruction as that buffer is in flux when the start happens. @@ -1676,7 +1743,7 @@ Status DefaultSchedulerCore::SchedulingStep(SchedulingState* sched_state) { return OkStatus(); } -StatusOr> +absl::StatusOr> DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { const HloSchedule& module_schedule = computation->parent()->schedule(); MemoryPressureTracker memory_pressure_tracker( @@ -1812,6 +1879,7 @@ LatencyHidingScheduler::LatencyHidingStatistics( kReduceScatter, kSend, kRecv, + kCollectiveBroadcast, }; auto opcode_to_async_kind = [](HloOpcode opcode) { switch (opcode) { @@ -1819,6 +1887,8 @@ LatencyHidingScheduler::LatencyHidingStatistics( return AsyncKind::kAllGather; case HloOpcode::kAllReduce: return AsyncKind::kAllReduce; + case HloOpcode::kCollectiveBroadcast: + return AsyncKind::kCollectiveBroadcast; case HloOpcode::kCollectivePermute: return AsyncKind::kCollectivePermute; case HloOpcode::kAllToAll: @@ -1877,15 +1947,21 @@ LatencyHidingScheduler::LatencyHidingStatistics( .push_back({instr, current_time, curr_pos}); } else if (async_tracker->IsSupportedAsyncDone(*instr)) { const HloInstruction* start_instr = instr->operand(0); - auto it = find_outstanding_async(start_instr); - const HloGraphNode& start_node = schedule_graph.GetNode(std::get<0>(*it)); - auto edge_it = find_node_successor_edge(start_node, instr_node); - const double async_wasted_cycles = - std::max(0.0, edge_it->Latency() - (current_time - std::get<1>(*it))); - AsyncKind kind = opcode_to_async_kind( - async_tracker->GetCanonicalAsyncOp(*start_instr).inner); - wasted_time_per_collective[kind] += async_wasted_cycles; - current_time += async_wasted_cycles; + // TODO(b/329731042): Handle pipelined Send/Recv in while-body, which + // is the only situation where an async done operand is not an async + // start. + if (async_tracker->IsSupportedAsyncStart(*start_instr)) { + auto it = find_outstanding_async(start_instr); + const HloGraphNode& start_node = + schedule_graph.GetNode(std::get<0>(*it)); + auto edge_it = find_node_successor_edge(start_node, instr_node); + const double async_wasted_cycles = std::max( + 0.0, edge_it->Latency() - (current_time - std::get<1>(*it))); + AsyncKind kind = opcode_to_async_kind( + async_tracker->GetCanonicalAsyncOp(*start_instr).inner); + wasted_time_per_collective[kind] += async_wasted_cycles; + current_time += async_wasted_cycles; + } } curr_pos++; } @@ -1909,6 +1985,8 @@ LatencyHidingScheduler::LatencyHidingStatistics( wasted_time_per_collective[AsyncKind::kAllGather], /*all_reduce_wasted_cycles=*/ wasted_time_per_collective[AsyncKind::kAllReduce], + /*collective_broadcast_wasted_cycles=*/ + wasted_time_per_collective[AsyncKind::kCollectiveBroadcast], /*collective_permute_wasted_cycles=*/ wasted_time_per_collective[AsyncKind::kCollectivePermute], /*all_to_all_wasted_cycles=*/ @@ -1936,6 +2014,7 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString( absl::StrAppend(&result, "Total wasted cycles: ", sched_stats.all_gather_wasted_cycles + sched_stats.all_reduce_wasted_cycles + + sched_stats.collective_broadcast_wasted_cycles + sched_stats.collective_permute_wasted_cycles + sched_stats.all_to_all_wasted_cycles + sched_stats.reduce_scatter_wasted_cycles + @@ -1946,6 +2025,8 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString( sched_stats.all_reduce_wasted_cycles, "\n"); absl::StrAppend(&result, "Wasted cycles for all-gather: ", sched_stats.all_gather_wasted_cycles, "\n"); + absl::StrAppend(&result, "Wasted cycles for collective-broadcast: ", + sched_stats.collective_broadcast_wasted_cycles, "\n"); absl::StrAppend(&result, "Wasted cycles for collective-permute: ", sched_stats.collective_permute_wasted_cycles, "\n"); absl::StrAppend(&result, "Wasted cycles for all-to-all: ", @@ -1971,7 +2052,7 @@ void LatencyHidingScheduler::LogScheduleStatistics( async_tracker_.get(), shape_size_bytes_))); } -StatusOr LatencyHidingScheduler::Run( +absl::StatusOr LatencyHidingScheduler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(5) << "Original module:"; diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h index ca6ccf54f4c49..436eeaad40df0 100644 --- a/xla/service/latency_hiding_scheduler.h +++ b/xla/service/latency_hiding_scheduler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,8 +38,8 @@ namespace xla { struct CanonicalAsyncOp { HloOpcode outer; // kAsyncStart or kAsyncDone - HloOpcode inner; // kAllReduce, kAllGather, kAllToAll, kCollectivePermute, - // or kReduceScatter + HloOpcode inner; // kAllReduce, kAllGather, kAllToAll, kCollectiveBroadcast, + // kCollectivePermute, or kReduceScatter }; CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo); @@ -56,11 +56,13 @@ enum class ResourceType { kAllGather = 2, kAllReduce = 3, kCollectivePermute = 4, - kReduceScatter = 5, - kSendRecv = 6, - kSendHost = 7, - kRecvHost = 8, - kNumResources = 9, + kCopy = 5, + kReduceScatter = 6, + kSendRecv = 7, + kSendHost = 8, + kRecvHost = 9, + kCollectiveBroadcast = 10, + kNumResources = 11, kTargetDefinedResourcesBound = 10000, }; @@ -92,6 +94,7 @@ class HloGraphNode; class HloScheduleGraph; struct SchedulerConfig { + int64_t collective_broadcast_overlap_limit = 1; int64_t collective_permute_overlap_limit = 1; int64_t all_to_all_overlap_limit = 1; int64_t all_gather_overlap_limit = 1; @@ -99,6 +102,7 @@ struct SchedulerConfig { int64_t reduce_scatter_overlap_limit = 1; int64_t send_recv_overlap_limit = 1; int64_t send_recv_host_overlap_limit = 1; + int64_t copy_overlap_limit = 1; uint64_t memory_limit = UINT64_MAX; bool schedule_send_recvs = false; // Consider send recv as the same resource. Some platforms do not take well @@ -181,6 +185,10 @@ class AsyncTracker { // Returns if this is an Async op start that the scheduler supports. virtual bool IsSupportedAsyncStart(const HloInstruction& hlo) const; + // Returns resources used (i.e., occupied or released) by this instruction + virtual ResourcesVector GetResourcesFromInstructionImpl( + const HloInstruction& hlo) const; + // Returns resources used (i.e., occupied or released) by this instruction virtual ResourcesVector GetResourcesFromInstruction( const HloInstruction& hlo) const; @@ -258,13 +266,17 @@ class AsyncTracker { absl::flat_hash_map> async_in_computation_cache_; GetCanonicalAsyncOpFunc get_canonical_async_op_; + + protected: + mutable absl::flat_hash_map + resources_cache_; }; // Base class for the core scheduling algorithm. class SchedulerCore { public: virtual Status InitializeScheduler(const HloModule* module) = 0; - virtual StatusOr> ScheduleComputation( + virtual absl::StatusOr> ScheduleComputation( const HloComputation* computation) = 0; virtual ~SchedulerCore() = default; virtual int64_t GetMemoryPeak() = 0; @@ -330,6 +342,8 @@ class HloGraphNode { void SetGraphDepth(TimeCost graph_depth) { graph_depth_ = graph_depth; } bool GetForceDelay() const { return force_delay_; } void SetForceDelay(bool force_delay) { force_delay_ = force_delay; } + bool GetForceEarly() const { return force_early_; } + void SetForceEarly(bool force_early) { force_early_ = force_early; } ResourcesVector GetResources() const { return resources_; } bool DoesOccupyAnyResource() const { return absl::c_any_of(resources_, [](const ResourcePair& resource) { @@ -402,6 +416,7 @@ class HloGraphNode { absl::StrAppend(&result, "Depth: ", depth_, "\n"); absl::StrAppend(&result, "Graph Depth: ", graph_depth_, "\n"); absl::StrAppend(&result, "Force Delay: ", force_delay_, "\n"); + absl::StrAppend(&result, "Force Early: ", force_early_, "\n"); absl::StrAppend(&result, "Predecessors:\n"); for (const HloEdge& e : predecessors_) { absl::StrAppend(&result, e.ToString()); @@ -454,6 +469,8 @@ class HloGraphNode { ResourcesVector resources_; // Force the scheduling of the nodes with attribute set as late as possible. bool force_delay_ = false; + // Force the scheduling of the nodes with attribute set as early as possible. + bool force_early_ = false; // Whether this node has been scheduled or not yet. bool scheduled_ = false; // Shareable resources released by this node. @@ -817,7 +834,7 @@ class DefaultSchedulerCore : public SchedulerCore { early_target_scheduling_rule_(early_target_scheduling_rule), post_processing_fn_(post_processing_fn) {} Status InitializeScheduler(const HloModule* module) override; - StatusOr> ScheduleComputation( + absl::StatusOr> ScheduleComputation( const HloComputation* computation) override; static bool AddOccupierToResource( HloGraphNode::TimeCost current_time, HloEdge& new_edge, @@ -837,7 +854,7 @@ class DefaultSchedulerCore : public SchedulerCore { protected: virtual void LogInstruction(const HloInstruction* instr) const; // Update node that has been scheduled. - virtual StatusOr ScheduleNode( + virtual absl::StatusOr ScheduleNode( HloGraphNode* n, SchedulingState* sched_state) const; // Perform the scheduling of one or more instructions. Called every time the // ready set is not empty. @@ -870,6 +887,7 @@ class LatencyHidingScheduler : public HloModulePass { const HloComputation* computation = nullptr; double all_gather_wasted_cycles = 0; double all_reduce_wasted_cycles = 0; + double collective_broadcast_wasted_cycles = 0; double collective_permute_wasted_cycles = 0; double all_to_all_wasted_cycles = 0; double reduce_scatter_wasted_cycles = 0; @@ -902,7 +920,7 @@ class LatencyHidingScheduler : public HloModulePass { static std::string SchedulerStatisticsString( const SchedulerStatistics& sched_stats); using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc index 4edcc1a682c8b..2ed00500230ec 100644 --- a/xla/service/latency_hiding_scheduler_test.cc +++ b/xla/service/latency_hiding_scheduler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,13 +127,14 @@ class TestLatencyEstimator : public LatencyEstimator { static constexpr TimeCost kHighCost = 5000.0; }; -StatusOr RunScheduler( +absl::StatusOr RunScheduler( HloModule* module, SchedulerConfig sched_config = GetDefaultSchedConfig(), std::unique_ptr latency_estimator = std::make_unique()) { AsyncCollectiveCreator::CollectiveCreatorConfig config{ /*convert_all_reduce=*/HloPredicateTrue, /*convert_all_gather=*/HloPredicateTrue, + /*convert_collective_broadcast=*/HloPredicateTrue, /*convert_collective_permute=*/HloPredicateTrue}; TF_ASSIGN_OR_RETURN(bool value, AsyncCollectiveCreator(std::move(config)).Run(module)); @@ -165,12 +166,12 @@ StatusOr RunScheduler( class LatencyHidingSchedulerTest : public HloTestBase { public: - StatusOr> ParseHloText( + absl::StatusOr> ParseHloText( absl::string_view hlo_string) { TF_ASSIGN_OR_RETURN( auto hlo_module, ParseAndReturnVerifiedModule(hlo_string, GetModuleConfigForTest())); - return StatusOr>(std::move(hlo_module)); + return absl::StatusOr>(std::move(hlo_module)); } }; @@ -996,9 +997,9 @@ TEST_F(LatencyHidingSchedulerTest, SerialCollectivePermutesTest) { ENTRY after_optimizations_test { %parameter.1 = bf16[8]{0} parameter(0) %collective-permute.2 = bf16[8]{0} collective-permute(bf16[8]{0} parameter.1), source_target_pairs={{0,1},{1,2},{2,3}} - %constant.3 = bf16[] constant(1) - %broadcast.4 = bf16[8]{0} broadcast(bf16[] %constant.3), dimensions={} - %add.5 = bf16[8]{0} add(%collective-permute.2, %broadcast.4) + %add.3 = bf16[8]{0} add(%parameter.1, %parameter.1) + %add.4 = bf16[8]{0} add(%add.3, parameter.1) + %add.5 = bf16[8]{0} add(%collective-permute.2, %add.4) %collective-permute.6 = bf16[8]{0} collective-permute(bf16[8]{0} add.5), source_target_pairs={{1,0},{0,3},{3,2}} } )"; @@ -1035,7 +1036,7 @@ TEST_F(LatencyHidingSchedulerTest, SerialCollectivePermutesTest) { original_instruction_sequence[3]), PositionInVector(new_instruction_sequence, original_instruction_sequence[4])); - EXPECT_EQ(original_instruction_sequence[0]->user_count(), 1); + EXPECT_EQ(original_instruction_sequence[0]->user_count(), 3); EXPECT_EQ(original_instruction_sequence[0]->users()[0]->opcode(), HloOpcode::kCollectivePermuteStart); HloInstruction* collective_permute_start_1 = @@ -2938,27 +2939,24 @@ ENTRY main { %call-start.1 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) call-start(s32[<=4096]{0:T(8)M(1024)} %get-tuple-element.1), - async_group_id=17, async_execution_thread="sparsecore", to_apply=%called_computation + async_execution_thread="sparsecore", to_apply=%called_computation %call-done.1 = s32[<=4096]{0:T(8)M(1024)} - call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.1), - async_group_id=17, async_execution_thread="sparsecore", to_apply=%called_computation + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.1) %call-start.2 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) call-start(s32[<=4096]{0:T(8)M(1024)} %call-done.1), - async_group_id=27, async_execution_thread="sparsecore", to_apply=%called_computation + async_execution_thread="sparsecore", to_apply=%called_computation %call-done.2 = s32[<=4096]{0:T(8)M(1024)} - call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.2), - async_group_id=27, async_execution_thread="sparsecore", to_apply=%called_computation + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.2) %call-start.3 = ((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) call-start(s32[<=4096]{0:T(8)M(1024)} %get-tuple-element.0), - async_group_id=14, async_execution_thread="sparsecore", to_apply=%called_computation + async_execution_thread="sparsecore", to_apply=%called_computation %call-done.3 = s32[<=4096]{0:T(8)M(1024)} - call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.3), - async_group_id=14, async_execution_thread="sparsecore", to_apply=%called_computation + call-done(((s32[<=4096]{0:T(8)M(1024)}), s32[<=4096]{0:T(8)M(1024)}, u32[]{:T(8)S(8)}) %call-start.3) ROOT %tuple.6 = (s32[<=4096]{0:T(8)M(1024)}, s32[<=4096]{0:T(8)M(1024)}) tuple(s32[<=4096]{0:T(8)M(1024)} %call-done.2, s32[<=4096]{0:T(8)M(1024)} %call-done.3), @@ -2977,4 +2975,79 @@ ENTRY main { // not create a failure of scheduling by the async done checks. EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); } + +TEST_F(LatencyHidingSchedulerTest, CopyScheduling) { + absl::string_view hlo_string = R"( +HloModule EinsumTest, is_scheduled=true +ENTRY AddR2 { + y_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(1) + z = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(2) + x = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(0) + convolution = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(x, z), dim_labels=bf_io->bf + copy-start = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(y_host) + copy-done = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start) + ROOT convolution.1 = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(convolution, copy-done), dim_labels=bf_io->bf +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); + const HloInstruction* conv = FindInstruction(hlo_module.get(), "convolution"); + const HloInstruction* cps = FindInstruction(hlo_module.get(), "copy-start"); + const HloInstruction* cpd = FindInstruction(hlo_module.get(), "copy-done"); + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps), + PositionInVector(new_instruction_sequence, conv)); + EXPECT_LT(PositionInVector(new_instruction_sequence, conv), + PositionInVector(new_instruction_sequence, cpd)); + XLA_VLOG_LINES(1, hlo_module->ToString()); +} + +TEST_F(LatencyHidingSchedulerTest, MaxCopyScheduling) { + absl::string_view hlo_string = R"( +HloModule EinsumTest, is_scheduled=true +ENTRY AddR2 { + y_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(1) + q_host = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(3) + z = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(2) + x = bf16[12800,12800]{1,0:T(8,128)(2,1)} parameter(0) + convolution = bf16[12800,12800]{1,0:T(8,128)(2,1)} convolution(x, z), dim_labels=bf_io->bf + copy-start = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(y_host) + copy-done = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start) + copy-start2 = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(q_host) + copy-done2 = bf16[12800,12800]{1,0:T(8,128)(2,1)} copy-done(copy-start2) + ROOT t = (bf16[12800,12800]{1,0:T(8,128)(2,1)}, bf16[12800,12800]{1,0:T(8,128)(2,1)}) tuple(copy-done2, copy-done) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + auto sched_config = GetDefaultSchedConfig(); + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok()); + const HloInstruction* conv = FindInstruction(hlo_module.get(), "convolution"); + const HloInstruction* cps = FindInstruction(hlo_module.get(), "copy-start"); + const HloInstruction* cps2 = FindInstruction(hlo_module.get(), "copy-start2"); + const HloInstruction* cpd2 = FindInstruction(hlo_module.get(), "copy-done2"); + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps2), + PositionInVector(new_instruction_sequence, conv)); + EXPECT_LT(PositionInVector(new_instruction_sequence, conv), + PositionInVector(new_instruction_sequence, cpd2)); + EXPECT_LT(PositionInVector(new_instruction_sequence, cps), + PositionInVector(new_instruction_sequence, cpd2)); + XLA_VLOG_LINES(1, hlo_module->ToString()); +} + } // namespace xla diff --git a/xla/service/layout_assignment.cc b/xla/service/layout_assignment.cc index 3e5e64a51823e..2ac95a1be55d9 100644 --- a/xla/service/layout_assignment.cc +++ b/xla/service/layout_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,36 +15,37 @@ limitations under the License. #include "xla/service/layout_assignment.h" -#include +#include #include -#include #include #include -#include #include #include #include -#include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/map_util.h" #include "xla/permutation_util.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/logical_buffer.h" #include "xla/service/tuple_points_to_analysis.h" @@ -52,15 +53,15 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -720,11 +721,23 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - TF_RETURN_IF_ERROR( - SetInstructionLayout(parameter_layout.shape(), instruction)); + Shape param_shape = parameter_layout.shape(); + // Clear out memory space in layout. Host offloader will do the + // analysis later. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + ¶m_shape, [](Shape* subshape, const ShapeIndex& index) { + if (!subshape->has_layout() || !subshape->IsArray()) { + return OkStatus(); + } + subshape->mutable_layout()->set_memory_space( + Layout::kDefaultMemorySpace); + return OkStatus(); + })); + + TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction)); if (reverse_computation_order_) { TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( - instruction, parameter_layout.shape(), this)); + instruction, param_shape, this)); } } } @@ -793,14 +806,17 @@ Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout& called_computation_layout = FindOrDie(computation_layouts_, instruction->to_apply()) ->computation_layout(); - TF_RETURN_IF_ERROR(SetInstructionLayout( - called_computation_layout.result_layout().shape(), instruction)); + auto result_shape = UnShardedShape( + instruction, called_computation_layout.result_layout().shape(), -1); + TF_RETURN_IF_ERROR(SetInstructionLayout(result_shape, instruction)); TF_RET_CHECK(instruction->operand_count() == called_computation_layout.parameter_count()); for (int64_t i = 0; i < instruction->operand_count(); ++i) { - TF_RETURN_IF_ERROR(SetOperandLayout( - called_computation_layout.parameter_layout(i).shape(), instruction, - i, /*mandatory=*/true, /*dfs=*/true)); + auto operand_shape = UnShardedShape( + instruction, called_computation_layout.parameter_layout(i).shape(), + i); + TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, i, + /*mandatory=*/true, /*dfs=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile && computation_layouts_.find(instruction->while_body()) != @@ -950,22 +966,6 @@ bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout()); } -// The operands of a call must match the layouts of parameters in the -// ComputationLayout, and the call instruction itself must match the result -// layout in the ComputationLayout. -Status CheckCallLayout(HloInstruction* call, - const ComputationLayout& computation_layout) { - HloComputation* computation = call->to_apply(); - TF_RET_CHECK(computation->num_parameters() == call->operand_count()); - for (int64_t i = 0; i < computation->num_parameters(); ++i) { - TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( - call->operand(i)->shape(), /*minor_to_major_only=*/true)); - } - TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( - call->shape(), /*minor_to_major_only=*/true)); - return OkStatus(); -} - // Operands of layout-constrained custom calls must match the expected // constrained layouts. Status CheckCustomCallLayout(HloInstruction* instruction) { @@ -1055,7 +1055,8 @@ Status CheckParameterLayout(HloInstruction* parameter, computation_layout.parameter_layout(parameter->parameter_number()); return ShapeUtil::ForEachSubshapeWithStatus( parameter_layout.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { + [&](const Shape& subshape, + const ShapeIndex& shape_index) -> absl::Status { if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) || !subshape.has_layout()) { return OkStatus(); @@ -1063,7 +1064,7 @@ Status CheckParameterLayout(HloInstruction* parameter, if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()( subshape, ShapeUtil::GetSubshape(parameter->shape(), shape_index))) { - return InternalError( + return Internal( "parameter instruction %s does not match layout of computation " "shape: %s", parameter->ToString(), parameter_layout.ToString()); @@ -1075,7 +1076,7 @@ Status CheckParameterLayout(HloInstruction* parameter, // The layout of a constant instruction must match the layout of its literal. Status CheckConstantLayout(HloInstruction* constant) { if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) { - return InternalError( + return Internal( "constant instruction %s does not match the layout of its literal %s", constant->ToString(), ShapeUtil::HumanStringWithLayout(constant->literal().shape())); @@ -1104,7 +1105,7 @@ Status CheckBroadcastLayout(HloInstruction* broadcast) { }, broadcast->shape()); if (!LayoutsInShapesEqual(shape, broadcast->operand(0)->shape())) { - return InternalError( + return Internal( "broadcast instruction %s does not match the layout of its operand %s", broadcast->ToString(), broadcast->operand(0)->ToString()); } @@ -1113,7 +1114,21 @@ Status CheckBroadcastLayout(HloInstruction* broadcast) { } // namespace -StatusOr LayoutAssignment::CreateCopyWithNewLayout( +Status LayoutAssignment::CheckCallLayout( + HloInstruction* call, const ComputationLayout& computation_layout) { + HloComputation* computation = call->to_apply(); + TF_RET_CHECK(computation->num_parameters() == call->operand_count()); + for (int64_t i = 0; i < computation->num_parameters(); ++i) { + TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( + ShardedShape(call, call->operand(i)->shape(), i), + /*minor_to_major_only=*/true)); + } + TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( + ShardedShape(call, call->shape(), -1), /*minor_to_major_only=*/true)); + return OkStatus(); +} + +absl::StatusOr LayoutAssignment::CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction) { TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) @@ -1282,7 +1297,7 @@ Status LayoutAssignment::CheckLayouts( .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, buffer->shape())) { - return InternalError( + return Internal( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", instruction->name(), absl::StrJoin(index, ","), @@ -2003,9 +2018,14 @@ Status LayoutAssignment::PropagateBufferConstraintToUses( VLOG(3) << "Propagating layout through backedge" << buffer_constraint.layout().ToString(); int64_t index = user->operand_index(buffer.instruction()); - TF_ASSIGN_OR_RETURN( - auto buffer, points_to_analysis_->GetBufferDefinedAt( - user->parent()->parameter_instruction(0), {index})); + + const HloInstruction* inputs = user->parent()->parameter_instruction(0); + + ShapeIndex used_index = buffer.index(); + used_index.push_front(index); + + TF_ASSIGN_OR_RETURN(auto buffer, points_to_analysis_->GetBufferDefinedAt( + inputs, used_index)); TF_RETURN_IF_ERROR(SetBufferLayout(buffer_constraint.layout(), *buffer, /*mandatory=*/false)); @@ -2018,19 +2038,35 @@ Status LayoutAssignment::PropagateBufferConstraintToUses( Status LayoutAssignment::PropagateResultConstraint( const ComputationLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { + ShapeLayout result_layout = + layout_constraint.computation_layout().result_layout(); + // Clear out memory space in layout for entry computation root. Host offloader + // will do the analysis later and add back the memory space for host outputs. + if (constraints->computation()->IsEntryComputation()) { + Shape result_shape = result_layout.shape(); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) { + if (subshape->has_layout() && subshape->IsArray()) { + subshape->mutable_layout()->set_memory_space( + Layout::kDefaultMemorySpace); + } + return OkStatus(); + })); + TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape)); + } + // Propagate the use constraint of the root instruction up to the logical // buffers which make up the result. return PropagateUseConstraintToDefs( - layout_constraint.computation_layout().result_layout(), - constraints->computation()->root_instruction(), constraints, - current_priority_); + result_layout, constraints->computation()->root_instruction(), + constraints, current_priority_); } // Infers the layout of the array at the given index in the given instruction's // output using points-to analysis. Precondition: The given instruction must // not produce this array value (that is, the array is forwarded from the // instruction's operands). -StatusOr LayoutAssignment::InferArrayLayout( +absl::StatusOr LayoutAssignment::InferArrayLayout( const HloInstruction* instruction, const ShapeIndex& index) { const auto& source_buffers = points_to_analysis_->GetPointsToSet(instruction).element(index); @@ -2045,8 +2081,8 @@ StatusOr LayoutAssignment::InferArrayLayout( if (source_buffer_constraint == nullptr) { // This should not happen because we've assigned layouts to all // instructions preceding this one. - return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString()); + return Internal("LogicalBuffer %s does not have a layout", + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -2129,7 +2165,7 @@ Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction // present in the IR before layout assignment is a bug. - return InternalError( + return Internal( "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } @@ -2218,7 +2254,8 @@ Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { computation->root_instruction())); computation->set_root_instruction(new_root); } else { - // Copy the tiling info specified in result layout. + // Copy the tiling info/tail_padding_alignment_in_elements specified in + // result layout. auto copy_tiling = [&constraints](xla::Shape* subshape, const xla::ShapeIndex& index) { if (subshape->IsArray()) { @@ -2229,6 +2266,10 @@ Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { result_shape.layout().tiles().begin(), result_shape.layout().tiles().end()); } + subshape->mutable_layout()->set_element_size_in_bits( + result_shape.layout().element_size_in_bits()); + subshape->mutable_layout()->set_tail_padding_alignment_in_elements( + result_shape.layout().tail_padding_alignment_in_elements()); } }; xla::ShapeUtil::ForEachMutableSubshape( @@ -2314,7 +2355,8 @@ Status LayoutAssignment::CalculateComputationLayout( SetCalleeLayout( instruction, instruction->operands(), mutable_computation_constraints(instruction->to_apply()), - current_priority_ + 1) == OkStatus()) { + current_priority_ + 1) + .ok()) { VLOG(2) << "Successfully propagated to callee layout\n"; } break; @@ -2377,7 +2419,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction // present in the IR before layout assignment is a bug. - return InternalError( + return Internal( "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } @@ -2499,7 +2541,8 @@ Status LayoutAssignment::PropagateComputationLayouts( bool needs_assign = false; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( param_layout->shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { + [&](const Shape& subshape, + const ShapeIndex& shape_index) -> absl::Status { if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) { return OkStatus(); } @@ -2509,8 +2552,9 @@ Status LayoutAssignment::PropagateComputationLayouts( } const auto& computed_subshape = ShapeUtil::GetSubshape( computed_computation_layout.parameter_shape(i), shape_index); - if (subshape.layout() != computed_subshape.layout()) { - return InternalError( + if (!Layout::Equal().IgnoreMemorySpace()( + subshape.layout(), computed_subshape.layout())) { + return Internal( "Assigned parameter shape %s does not match layout of " "computation shape: %s", computed_computation_layout.ToString(), @@ -2539,7 +2583,7 @@ Status LayoutAssignment::PropagateComputationLayouts( return OkStatus(); } -StatusOr LayoutAssignment::Run( +absl::StatusOr LayoutAssignment::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Running layout assignment on module " << module->name(); @@ -2671,6 +2715,62 @@ StatusOr LayoutAssignment::Run( ? LayoutConstraint::kGivenPriority : LayoutConstraint::kDefaultPriority)); for (int64_t i = 0; i < kNumberOfPropagationRounds; ++i) { + if (i > 0) { + LayoutConstraints* constraints = + mutable_computation_constraints(module->entry_computation()); + + bool changed = false; + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + const auto param = alias.parameter_number; + const auto& index = alias.parameter_index; + bool param_is_forced = + ShapeUtil::GetSubshape( + saved_entry_computation_layout_.parameter_shape(param), + index) + .has_layout(); + bool result_is_forced = + ShapeUtil::GetSubshape( + saved_entry_computation_layout_.result_shape(), + output_index) + .has_layout(); + Shape* param_shape = + ShapeUtil::GetMutableSubshape(module->entry_computation() + ->parameter_instruction(param) + ->mutable_shape(), + index); + Shape* result_shape = + ShapeUtil::GetMutableSubshape(module->entry_computation() + ->root_instruction() + ->mutable_shape(), + output_index); + if (param_is_forced && result_is_forced) { + return; + } + + if (param_shape->layout().minor_to_major() == + result_shape->layout().minor_to_major()) { + return; + } + changed = true; + if (!param_is_forced) { + *param_shape = *result_shape; + return; + } + *result_shape = *param_shape; + }); + if (changed) { + auto computed_program_shape = + module->entry_computation()->ComputeProgramShape(); + constraints->mutable_computation_constraint()->ResetComputationLayout( + ComputationLayout{ + module->entry_computation()->ComputeProgramShape(), false}, + LayoutConstraint::kGivenPriority, true, true); + *entry_computation_layout_ = + constraints->computation_constraint().computation_layout(); + } + } VLOG(1) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads)); for (auto* computation : computations_to_work) { @@ -2721,10 +2821,12 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kAllGatherStart: case HloOpcode::kAllGatherDone: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: @@ -2847,7 +2949,6 @@ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { Status LayoutAssignment::Init(HloModule* module) { computation_layouts_.clear(); conditional_mismatch_.clear(); - *entry_computation_layout_ = saved_entry_computation_layout_; current_priority_ = LayoutConstraint::kBeginningPriority; // Clear all the copies which have been added, and all the related // instructions (like GTE and tuples). diff --git a/xla/service/layout_assignment.h b/xla/service/layout_assignment.h index 18fd5df38f41a..2258649391710 100644 --- a/xla/service/layout_assignment.h +++ b/xla/service/layout_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ #define XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ +#include #include #include #include @@ -28,19 +29,26 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/logical_buffer.h" #include "xla/service/tuple_points_to_analysis.h" +#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status.h" @@ -117,7 +125,7 @@ class OperandLayoutConstraint : public LayoutConstraint { const ShapeLayout& shape_layout() const { return shape_layout_[0]; } const HloInstruction* instruction() const { return instruction_; } - const int64_t operand_no() const { return operand_no_; } + int64_t operand_no() const { return operand_no_; } const HloInstruction* operand() const { return instruction_->operand(operand_no_); } @@ -196,7 +204,7 @@ class ComputationLayoutConstraint : public LayoutConstraint { class ChannelLayoutConstraints { public: // Construct an empty constraint set. - ChannelLayoutConstraints() {} + ChannelLayoutConstraints() = default; // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64_t channel_id) const { @@ -263,7 +271,7 @@ class LayoutAssignment : public HloModulePass { // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -464,8 +472,8 @@ class LayoutAssignment : public HloModulePass { Status PropagateUnconstraintedBuffers(LayoutConstraints* constraints); const BufferLayoutConstraint* GetBufferLayoutConstraint( const LogicalBuffer& buffer) const; - StatusOr GetInstructionBufferLayoutConstraint( - const HloInstruction* instruction) const; + absl::StatusOr + GetInstructionBufferLayoutConstraint(const HloInstruction* instruction) const; // Find a bufferset in the bufferset cache. This is useful since we can // currently create the flattened buffer set for the same instruction many // times, which is often slow. @@ -479,8 +487,8 @@ class LayoutAssignment : public HloModulePass { // buffers of its operands and would return true for each of its operands. bool AnyOperandBufferForwarded(const HloInstruction* instruction, int64_t operand_no) const; - StatusOr InferArrayLayout(const HloInstruction* instruction, - const ShapeIndex& index); + absl::StatusOr InferArrayLayout(const HloInstruction* instruction, + const ShapeIndex& index); // Propagates a buffer layout constraint into the operands that use it. Status PropagateBufferConstraintToUses( @@ -516,6 +524,31 @@ class LayoutAssignment : public HloModulePass { virtual bool InstructionCanChangeLayoutInstance( const HloInstruction* instruction); + // The shapes in caller can be different from the shapes in callee. For + // example, a shape (1024, 128) of an array can be distributed to four threads + // so the shape for each thread is (256, 128). When verifying the callee's + // shapes based on the caller, we should use this function to compute the + // expected shape. The param_id should be the parameter id of the shape or -1 + // for the result output or unknown. + virtual Shape ShardedShape(const HloInstruction* call, const Shape& shape, + int param_id) { + return shape; + } + // When verifying the caller's shapes based on the callee, we should use this + // function to compute the expected shape. + // The param_id should be the parameter id of the shape or -1 for the result + // output or unknown. + virtual Shape UnShardedShape(const HloInstruction* call, const Shape& shape, + int param_id) { + return shape; + } + + // The operands of a call must match the layouts of parameters in the + // ComputationLayout, and the call instruction itself must match the result + // layout in the ComputationLayout. + Status CheckCallLayout(HloInstruction* call, + const ComputationLayout& computation_layout); + private: // Initializes the layout assignment object for a new Run() call. Status Init(HloModule* module); @@ -619,7 +652,7 @@ class LayoutAssignment : public HloModulePass { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - StatusOr CreateCopyWithNewLayout( + absl::StatusOr CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match diff --git a/xla/service/layout_assignment_test.cc b/xla/service/layout_assignment_test.cc index e0db0346e4c82..9a0766e8bc92e 100644 --- a/xla/service/layout_assignment_test.cc +++ b/xla/service/layout_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1612,7 +1612,7 @@ ENTRY main { TEST_F(LayoutAssignmentTest, PropagateOperandLayout2) { const char* module_str = R"( HloModule TensorFlowGather, entry_computation_layout={(f32[32,650]{1,0},s32[16,1,18]{0,1,2})->f32[16,1,18,32]{3,1,2,0}} - + ENTRY %main (operand: f32[32,650], indices: s32[16,1,18]) -> f32[16,1,18,32] { %operand = f32[32,650]{1,0} parameter(0) %transpose = f32[650,32]{0,1} transpose(f32[32,650]{1,0} %operand), dimensions={1,0} @@ -1638,7 +1638,7 @@ TEST_F(LayoutAssignmentTest, PropagateOperandLayout2) { TEST_F(LayoutAssignmentTest, PreserveInstructionLayout) { const char* module_str = R"( HloModule TensorFlowGather, entry_computation_layout={(f32[32,650]{1,0},s32[16,1,18]{0,1,2})->(f32[16,1,18,32]{3,1,2,0})} - + ENTRY %main { %operand = f32[32,650]{1,0} parameter(0) %transpose = f32[650,32]{0,1} transpose(f32[32,650]{1,0} %operand), dimensions={1,0} @@ -1697,7 +1697,7 @@ ENTRY main { TEST_F(LayoutAssignmentTest, PartialEntryParameterLayout) { const char* module_str = R"( HloModule EntryLayout, entry_computation_layout={(f32[32,650]{1,0},s32[16,1,18]{0,1,2})->(f32[650,32]{1,0},s32[18,16,1]{0,1,2})} - + ENTRY %main { operand = f32[32,650] parameter(0) transpose = transpose(operand), dimensions={1,0} @@ -1722,5 +1722,132 @@ TEST_F(LayoutAssignmentTest, PartialEntryParameterLayout) { {0, 1, 2}); } +// Test the ability to enforce aliasing . +TEST_F(LayoutAssignmentTest, AliasParameterAndOutput) { + const char* module_str = R"( + HloModule EntryAlias, input_output_alias={ {}: (0, {}, may-alias) } + + ENTRY %main { + p0 = f32[65,65] parameter(0) + p1 = f32[4225] parameter(1) + r = f32[65,65] reshape(p1) + a = add(p0,r) + ROOT t = transpose(a), dimensions={1,0} + } )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + m->mutable_entry_computation_layout()->SetToDefaultLayout(); + m->mutable_entry_computation_layout()->mutable_result_layout()->Clear(); + m->mutable_entry_computation_layout()->mutable_parameter_layout(0)->Clear(); + + LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), + nullptr); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); + EXPECT_EQ(m->entry_computation_layout().result_layout().shape(), + m->entry_computation_layout().parameter_layout(0).shape()); +} + +// Test the ability to enforce aliasing . +TEST_F(LayoutAssignmentTest, AliasUnconstrainedParamterWithConstrainedOutput) { + const char* module_str = R"( + HloModule EntryAlias, input_output_alias={ {}: (0, {}, may-alias) } + + ENTRY %main { + p0 = f32[65,65] parameter(0) + p1 = f32[4225] parameter(1) + r = f32[65,65] reshape(p1) + a = add(p0,r) + ROOT t = transpose(a), dimensions={1,0} + } )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + m->mutable_entry_computation_layout()->SetToDefaultLayout(); + m->mutable_entry_computation_layout()->mutable_parameter_layout(0)->Clear(); + + LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), + nullptr); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); + EXPECT_EQ(m->entry_computation_layout().result_layout().shape(), + m->entry_computation_layout().parameter_layout(0).shape()); +} + +TEST_F(LayoutAssignmentTest, AliasConstrainedParamterWithUnconstrainedOutput) { + const char* module_str = R"( + HloModule EntryAlias, input_output_alias={ {}: (0, {}, may-alias) } + + ENTRY %main { + p0 = f32[65,65] parameter(0) + p1 = f32[4225] parameter(1) + r = f32[65,65] reshape(p1) + a = add(p0,r) + ROOT t = transpose(a), dimensions={1,0} + } )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + m->mutable_entry_computation_layout()->SetToDefaultLayout(); + m->mutable_entry_computation_layout()->mutable_result_layout()->Clear(); + + LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), + nullptr); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); + EXPECT_EQ(m->entry_computation_layout().result_layout().shape(), + m->entry_computation_layout().parameter_layout(0).shape()); +} + +TEST_F(LayoutAssignmentTest, NestedTupleInLoop) { + const char* module_str = R"( +HloModule Module + +condition { + p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0) + ROOT lt = pred[] constant(1) +} + +body { + p = (f32[100,100], (f32[100,100], u32[], token[])) parameter(0) + + t1 = f32[100,100] get-tuple-element(p), index=0 + t = (f32[100,100], u32[], token[]) get-tuple-element(p), index=1 + sdone = token[] send-done(t), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + tk = token[] after-all() + snd = (f32[100,100], u32[], token[]) send(t1, tk), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + a = add(t1, t1) + ROOT tup = tuple(a, snd) +} + +ENTRY %main { + p0 = f32[100,100] parameter(0) + tk = token[] after-all() + snd = (f32[100,100], u32[], token[]) send(p0, tk), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + t = tuple(p0, snd) + loop = while(t), condition=condition, body=body + ssend = (f32[100,100], u32[], token[]) get-tuple-element(loop), index=1 + sdone = token[] send-done(ssend), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT result = f32[100,100] get-tuple-element(loop), index=0 +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + + LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), + nullptr); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); +} + } // namespace } // namespace xla diff --git a/xla/service/layout_normalization.cc b/xla/service/layout_normalization.cc index a4c002ac04e2e..6837437f8f93e 100644 --- a/xla/service/layout_normalization.cc +++ b/xla/service/layout_normalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -238,6 +238,24 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return OkStatus(); } + Status HandleIota(HloInstruction* hlo) override { + VLOG(3) << "Input iota: " << hlo->ToString(); + auto s = hlo->shape(); + auto normalized_shape = Normalize(s); + std::vector orig_output_layout_as_permutation = + ToTransposeDimensions(s.layout()); + int64_t iota_dimension = hlo->dimensions()[0]; + int64_t new_iota_dimension = + FindIndex(orig_output_layout_as_permutation, iota_dimension); + auto normalized_iota = hlo->AddInstruction( + HloInstruction::CreateIota(normalized_shape, new_iota_dimension)); + SetVisited(*normalized_iota); + VLOG(3) << "Generated iota: " << normalized_iota->ToString(); + auto bc_to_orig = MakeBitcastHlo(normalized_iota, s); + TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + return OkStatus(); + } + // BitcastConvert is only layout-preserving if it doesn't change the rank. Status HandleBitcastConvert(HloInstruction* hlo) override { // If the rank isn't changing this is just an unary op. @@ -634,7 +652,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { // Due to Local Precondition we have, the input to all processed ops should // be HLO in descending layout piped through bitcast. - StatusOr GetNormalizedInput(HloInstruction* hlo) { + absl::StatusOr GetNormalizedInput(HloInstruction* hlo) { TF_RET_CHECK(hlo->opcode() == HloOpcode::kBitcast) << "Unexpected HLO input: " << hlo->ToString(); auto input = hlo->mutable_operand(0); @@ -655,7 +673,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { } // end namespace -StatusOr LayoutNormalization::Run( +absl::StatusOr LayoutNormalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return LayoutNormalizationVisitor{custom_call_transformer_}.RunOnModule( diff --git a/xla/service/layout_normalization.h b/xla/service/layout_normalization.h index ec03f48a0c962..9fd3c00d3108e 100644 --- a/xla/service/layout_normalization.h +++ b/xla/service/layout_normalization.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ limitations under the License. namespace xla { using CustomCallTransformer = - std::function>( + std::function>( HloCustomCallInstruction*)>; // Normalize shapes for some subsets of HLOs. @@ -49,7 +49,7 @@ class LayoutNormalization : public HloModulePass { absl::string_view name() const override { return "layout_normalization"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/layout_normalization_test.cc b/xla/service/layout_normalization_test.cc index 7c972ebef4290..6cebe0f2a858c 100644 --- a/xla/service/layout_normalization_test.cc +++ b/xla/service/layout_normalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -311,6 +311,23 @@ ENTRY main { )"); } +TEST_F(LayoutNormalizationTest, IotaCustomOutputLayout) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = f32[2,4,3]{1,2,0} iota(), iota_dimension=2 + ROOT out = abs(a) +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: [[iota_2:%[^ ]+]] = f32[2,3,4]{2,1,0} iota(), iota_dimension=1 +// CHECK: [[abs_3:%[^ ]+]] = f32[2,3,4]{2,1,0} abs([[iota_2]]) +// CHECK: ROOT [[bitcast_3_4:%[^ ]+]] = f32[2,4,3]{1,2,0} bitcast([[abs_3]]) +)"); +} + TEST_F(LayoutNormalizationTest, Concatenate) { const char* hlo = R"( HloModule module diff --git a/xla/service/llvm_compiler.cc b/xla/service/llvm_compiler.cc index ea20c00592b95..c34b286b7c028 100644 --- a/xla/service/llvm_compiler.cc +++ b/xla/service/llvm_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ limitations under the License. #endif namespace xla { -StatusOr>> LLVMCompiler::Compile( +absl::StatusOr>> LLVMCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, const CompileOptions& options) { diff --git a/xla/service/llvm_compiler.h b/xla/service/llvm_compiler.h index 14a44b18eeee7..cd2699c2074fd 100644 --- a/xla/service/llvm_compiler.h +++ b/xla/service/llvm_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -70,7 +70,7 @@ class LLVMCompiler : public Compiler { using Compiler::RunBackend; using Compiler::RunHloPasses; - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, const CompileOptions& options) override; diff --git a/xla/service/llvm_ir/BUILD b/xla/service/llvm_ir/BUILD index 86334f5676360..b38d45a5a1a54 100644 --- a/xla/service/llvm_ir/BUILD +++ b/xla/service/llvm_ir/BUILD @@ -1,13 +1,14 @@ # Description: # Libraries for helping construct LLVM IR for XLA backends. +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -42,7 +43,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", - "@tsl//tsl/platform:logging", ], ) @@ -50,12 +50,9 @@ xla_cc_test( name = "alias_analysis_test", srcs = ["alias_analysis_test.cc"], deps = [ - ":alias_analysis", "//xla/service:custom_call_status_public_headers", "//xla/service:custom_call_target_registry", - "//xla/service:hlo_parser", "//xla/service/cpu/tests:cpu_codegen_test", - "//xla/tests:filecheck", "@tsl//tsl/platform:test", ], ) @@ -81,6 +78,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//mlir:IR", "@tsl//tsl/platform:byte_order", @@ -160,7 +158,6 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:types", - "//xla:xla_data_proto_cc", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Core", "@tsl//tsl/platform:errors", @@ -186,6 +183,7 @@ cc_library( "//xla/service:fusion_node_indexing_evaluation", "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//llvm:Core", + "@llvm-project//llvm:TargetParser", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", diff --git a/xla/service/llvm_ir/alias_analysis.cc b/xla/service/llvm_ir/alias_analysis.cc index 2d4faa6a590c0..d239c7520bd22 100644 --- a/xla/service/llvm_ir/alias_analysis.cc +++ b/xla/service/llvm_ir/alias_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/alias_analysis.h b/xla/service/llvm_ir/alias_analysis.h index 9b55939679a62..91d850e61db0e 100644 --- a/xla/service/llvm_ir/alias_analysis.h +++ b/xla/service/llvm_ir/alias_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/alias_analysis_test.cc b/xla/service/llvm_ir/alias_analysis_test.cc index e736d1940c65c..8016fd24aaa9d 100644 --- a/xla/service/llvm_ir/alias_analysis_test.cc +++ b/xla/service/llvm_ir/alias_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/llvm_ir/alias_analysis.h" - -#include -#include - #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/tests/filecheck.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/llvm_ir/buffer_assignment_util.cc b/xla/service/llvm_ir/buffer_assignment_util.cc index f92b3732c9ca4..5fa4e61bd69b1 100644 --- a/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/xla/service/llvm_ir/buffer_assignment_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/buffer_assignment_util.h b/xla/service/llvm_ir/buffer_assignment_util.h index 90805feaa9cdf..bc3a0581de74e 100644 --- a/xla/service/llvm_ir/buffer_assignment_util.h +++ b/xla/service/llvm_ir/buffer_assignment_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/dynamic_update_slice_util.cc b/xla/service/llvm_ir/dynamic_update_slice_util.cc index 1b8a00d5bdd64..6b17b94236479 100644 --- a/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -98,7 +98,7 @@ bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. -using IndexGenerator = std::function(int64_t)>; +using IndexGenerator = std::function(int64_t)>; static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const IndexGenerator& start_indices_generator, @@ -235,7 +235,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( fused_emitter->GetGenerator(*update)); IndexGenerator start_indices_generator = - [&](int64_t index) -> StatusOr { + [&](int64_t index) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(ElementGenerator element_generator, fused_emitter->GetGenerator( *dynamic_update_slice->operand(2 + index))); diff --git a/xla/service/llvm_ir/dynamic_update_slice_util.h b/xla/service/llvm_ir/dynamic_update_slice_util.h index 88c43e55416b5..cf015c8692226 100644 --- a/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/fused_ir_emitter.cc b/xla/service/llvm_ir/fused_ir_emitter.cc index 4b66bd023b9cc..f69cc88a14f71 100644 --- a/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/xla/service/llvm_ir/fused_ir_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -44,14 +45,14 @@ namespace xla { using llvm_ir::IrArray; -StatusOr FusedIrEmitter::DefaultAction( +absl::StatusOr FusedIrEmitter::DefaultAction( const HloInstruction& instruction) { IndexedGenerator generator = elemental_emitter_.MakeElementGenerator( &instruction, indexed_generators_); - return StatusOr([&, generator = std::move(generator)]( - const IrArray::Index& index) - -> StatusOr { + return absl::StatusOr([&, generator = std::move(generator)]( + const IrArray::Index& index) + -> absl::StatusOr { ValueCacheKey key{&instruction, index.multidim()}; llvm::Value* value = value_cache_.insert({key, nullptr}).first->second; @@ -95,6 +96,8 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( llvm::Module* module = elemental_emitter_.module(); llvm::IRBuilder<>* b = elemental_emitter_.b(); + // Explicitly set global addrspace for SPIR backend. + int addrspace = llvm::Triple(module->getTargetTriple()).isSPIR() ? 1 : 0; llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(constant.literal(), module); llvm::GlobalVariable* global = new llvm::GlobalVariable( @@ -104,7 +107,7 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( /*Initializer=*/initializer, /*Name=*/"", /*InsertBefore=*/nullptr, /*TLMode=*/llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/0, + /*AddressSpace=*/addrspace, /*isExternallyInitialized=*/false); global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global); @@ -116,7 +119,7 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( }; } -StatusOr FusedIrEmitter::HandleTuple( +absl::StatusOr FusedIrEmitter::HandleTuple( const HloInstruction& tuple) { std::vector element_ir_types; element_ir_types.reserve(tuple.operand_count()); @@ -128,33 +131,32 @@ StatusOr FusedIrEmitter::HandleTuple( llvm::IRBuilder<>* b = elemental_emitter_.b(); llvm::Type* type = llvm::StructType::get(b->getContext(), element_ir_types); - return StatusOr( - [&, b, type](const IrArray::Index& index) -> StatusOr { - llvm::Value* ret = llvm::UndefValue::get(type); - for (size_t i = 0; i < tuple.operand_count(); ++i) { - IrArray::Index used_index = index; - if (i > 0 && - !ShapeUtil::EqualIgnoringElementType(tuple.operand(i)->shape(), - tuple.operand(0)->shape())) { - used_index = used_index.SourceIndexOfBitcast( - tuple.operand(0)->shape(), tuple.operand(i)->shape(), b); - } - TF_ASSIGN_OR_RETURN( - llvm::Value * value, - indexed_generators_.at(tuple.operand(i))(used_index)); - ret = b->CreateInsertValue(ret, value, i); - } - return ret; - }); + return absl::StatusOr([&, b, + type](const IrArray::Index& index) + -> absl::StatusOr { + llvm::Value* ret = llvm::UndefValue::get(type); + for (size_t i = 0; i < tuple.operand_count(); ++i) { + IrArray::Index used_index = index; + if (i > 0 && !ShapeUtil::EqualIgnoringElementType( + tuple.operand(i)->shape(), tuple.operand(0)->shape())) { + used_index = used_index.SourceIndexOfBitcast( + tuple.operand(0)->shape(), tuple.operand(i)->shape(), b); + } + TF_ASSIGN_OR_RETURN(llvm::Value * value, + indexed_generators_.at(tuple.operand(i))(used_index)); + ret = b->CreateInsertValue(ret, value, i); + } + return ret; + }); } -StatusOr FusedIrEmitter::CreateGenerator( - const HloInstruction& instruction) { +absl::StatusOr +FusedIrEmitter::CreateGenerator(const HloInstruction& instruction) { switch (instruction.opcode()) { case HloOpcode::kConstant: return HandleConstant(instruction); case HloOpcode::kGetTupleElement: - return InternalError("Tuple parameters are not supported for fusion"); + return Internal("Tuple parameters are not supported for fusion"); case HloOpcode::kParameter: return InvalidArgument("Unbound parameter: %s", instruction.ToString()); case HloOpcode::kTuple: @@ -164,7 +166,7 @@ StatusOr FusedIrEmitter::CreateGenerator( } } -StatusOr FusedIrEmitter::GetGenerator( +absl::StatusOr FusedIrEmitter::GetGenerator( const HloInstruction& instruction) { std::vector stack = {&instruction}; while (!stack.empty()) { diff --git a/xla/service/llvm_ir/fused_ir_emitter.h b/xla/service/llvm_ir/fused_ir_emitter.h index ec958615ab67e..098666d890eb9 100644 --- a/xla/service/llvm_ir/fused_ir_emitter.h +++ b/xla/service/llvm_ir/fused_ir_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,13 +56,16 @@ class FusedIrEmitter { } // Returns the generator function for the given instruction. - StatusOr GetGenerator(const HloInstruction& instruction); + absl::StatusOr GetGenerator( + const HloInstruction& instruction); private: - StatusOr CreateGenerator(const HloInstruction& instruction); - StatusOr DefaultAction(const HloInstruction& instruction); + absl::StatusOr CreateGenerator( + const HloInstruction& instruction); + absl::StatusOr DefaultAction( + const HloInstruction& instruction); IndexedGenerator HandleConstant(const HloInstruction& constant); - StatusOr HandleTuple(const HloInstruction& tuple); + absl::StatusOr HandleTuple(const HloInstruction& tuple); ElementalIrEmitter& elemental_emitter_; diff --git a/xla/service/llvm_ir/ir_array.cc b/xla/service/llvm_ir/ir_array.cc index acffe65df84ef..25785d175b108 100644 --- a/xla/service/llvm_ir/ir_array.cc +++ b/xla/service/llvm_ir/ir_array.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -191,8 +191,6 @@ IrArray::IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape) shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); CHECK(base_ptr_->getType()->isPointerTy()); - CHECK(llvm::cast(base_ptr_->getType()) - ->isOpaqueOrPointeeTypeMatches(pointee_type)); int depth = 0; element_type_ = pointee_type; while (llvm::ArrayType* array_type = @@ -375,6 +373,13 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( return index; } +IrArray::Index IrArray::Index::SourceIndexOfBitcast( + const Shape& operand_shape, llvm::IRBuilder<>* builder) const { + auto shape = ShapeUtil::MakeShape(F32, dims_); + *shape.mutable_layout() = layout_; + return SourceIndexOfBitcast(shape, operand_shape, builder); +} + IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, @@ -553,8 +558,19 @@ llvm::Value* IrArray::EmitLinearArrayElementAddress( llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); llvm::Type* type = PrimitiveTypeToIrType(shape_.element_type(), module); if (!primitive_util::Is4BitType(shape_.element_type())) { - return b->CreateInBoundsGEP(type, base_ptr_, index.linear(), - llvm_ir::AsStringRef(name)); + auto linear_index = llvm::dyn_cast(index.linear()); + if (linear_index && (linear_index->getOpcode() == llvm::Instruction::Add)) { + llvm::Value* index_operand_0 = linear_index->getOperand(0); + llvm::Value* index_operand_1 = linear_index->getOperand(1); + llvm::Value* ptr_address = + b->CreateGEP(type, base_ptr_, index_operand_0, ""); + + return b->CreateInBoundsGEP(type, ptr_address, index_operand_1, + llvm_ir::AsStringRef(name)); + } else { + return b->CreateInBoundsGEP(type, base_ptr_, index.linear(), + llvm_ir::AsStringRef(name)); + } } // Handle int4 case by dividing index by 2. Int4 arrays are represented in diff --git a/xla/service/llvm_ir/ir_array.h b/xla/service/llvm_ir/ir_array.h index 1ee0995b523a6..0212e90e8da0d 100644 --- a/xla/service/llvm_ir/ir_array.h +++ b/xla/service/llvm_ir/ir_array.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -101,6 +101,17 @@ class IrArray { return with_offset; } + Index AddOffset(absl::Span offsets, + llvm::IRBuilder<>* b) const { + CHECK_EQ(multidim_.size(), offsets.size()); + Index with_offset = *this; + with_offset.linear_ = nullptr; + for (auto&& [dim, offset] : llvm::zip(with_offset.multidim_, offsets)) { + dim = b->CreateAdd(dim, offset); + } + return with_offset; + } + const std::vector& multidim() const { return multidim_; } const std::vector& dims() const { return dims_; } llvm::Value* linear() const { return linear_; } @@ -154,6 +165,9 @@ class IrArray { // to `shape`, returns the source index. Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const; + // Same as above, but for bitcasts from `operand_shape` to `this->dims`. + Index SourceIndexOfBitcast(const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. diff --git a/xla/service/llvm_ir/ir_array_test.cc b/xla/service/llvm_ir/ir_array_test.cc index ea1717ce4371f..138f4efe8004e 100644 --- a/xla/service/llvm_ir/ir_array_test.cc +++ b/xla/service/llvm_ir/ir_array_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/ir_builder_mixin.h b/xla/service/llvm_ir/ir_builder_mixin.h index 6f5f6a3f615c3..23a00c242f9ad 100644 --- a/xla/service/llvm_ir/ir_builder_mixin.h +++ b/xla/service/llvm_ir/ir_builder_mixin.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/kernel_support_library.cc b/xla/service/llvm_ir/kernel_support_library.cc index 7e246bcdeca82..7d72be76a0f7f 100644 --- a/xla/service/llvm_ir/kernel_support_library.cc +++ b/xla/service/llvm_ir/kernel_support_library.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/kernel_support_library.h b/xla/service/llvm_ir/kernel_support_library.h index 6f8600f42b105..b395877c8b0a6 100644 --- a/xla/service/llvm_ir/kernel_support_library.h +++ b/xla/service/llvm_ir/kernel_support_library.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/llvm_command_line_options.h b/xla/service/llvm_ir/llvm_command_line_options.h index aa55982ff0394..9a90c3b7cbc9d 100644 --- a/xla/service/llvm_ir/llvm_command_line_options.h +++ b/xla/service/llvm_ir/llvm_command_line_options.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ void InitializeLLVMCommandLineOptions(const T& options) { VLOG(2) << s; } llvm::cl::ParseCommandLineOptions(static_cast(fake_argv.size()), - &fake_argv[0]); + fake_argv.data()); } } diff --git a/xla/service/llvm_ir/llvm_loop.cc b/xla/service/llvm_ir/llvm_loop.cc index 4d3eb6554a3b0..bf5d49b78a828 100644 --- a/xla/service/llvm_ir/llvm_loop.cc +++ b/xla/service/llvm_ir/llvm_loop.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/llvm_loop.h b/xla/service/llvm_ir/llvm_loop.h index a30ec8c8bea92..8e5bb5b2a38a4 100644 --- a/xla/service/llvm_ir/llvm_loop.h +++ b/xla/service/llvm_ir/llvm_loop.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/llvm_type_conversion_util.h b/xla/service/llvm_ir/llvm_type_conversion_util.h index 033c933deb157..2c2af057aceb7 100644 --- a/xla/service/llvm_ir/llvm_type_conversion_util.h +++ b/xla/service/llvm_ir/llvm_type_conversion_util.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 59124f4e6a787..332c36ddd1534 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -50,6 +50,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/Cloning.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -193,9 +194,6 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type, llvm::Value* index, llvm::IRBuilder<>* b) { llvm::Type* array_type = array->getType(); CHECK(array_type->isPointerTy()); - llvm::PointerType* array_type_as_pointer = - llvm::cast(array_type); - CHECK(array_type_as_pointer->isOpaqueOrPointeeTypeMatches(element_type)); VLOG(2) << "EmitBufferIndexingGEP with type=" << llvm_ir::DumpToString(array_type) << " array=" << llvm_ir::DumpToString(array) @@ -225,22 +223,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: - case BF16: - // For BF16 we just need some type that is 16 bits wide so that it will - // take up the right amount of space in memory. LLVM does not have a BF16 - // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so - // we can't map it directly to an LLVM type. We will not map a BF16 - // addition to an addition on this type (int16_t) - this is just the type - // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: case F8E5M2FNUZ: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E4M3FNUZ: - // Similarly as with BF16, we represent F8 as an int since there is no - // LLVM F8 dtype. + // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(module->getContext()); + case BF16: + return llvm::Type::getBFloatTy(module->getContext()); case F16: return llvm::Type::getHalfTy(module->getContext()); case S32: @@ -324,12 +316,11 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { return result_type; } -StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, - int32_t* shape_size, - llvm::IRBuilder<>* b) { +absl::StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32_t* shape_size, llvm::IRBuilder<>* b) { std::string encoded_shape = shape.SerializeAsString(); if (encoded_shape.size() > std::numeric_limits::max()) { - return InternalError("Encoded shape size exceeded int32_t size limit."); + return Internal("Encoded shape size exceeded int32_t size limit."); } *shape_size = static_cast(encoded_shape.size()); return b->CreateGlobalStringPtr(encoded_shape); @@ -364,6 +355,49 @@ llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace); } +SharedMemoryTile AllocateSharedMemoryTile( + llvm::Module* module, llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name) { + llvm::Type* ty = element_type; + for (auto dim : llvm::reverse(dimensions_major_to_minor)) { + ty = llvm::ArrayType::get(ty, dim); + } + return SharedMemoryTile{ + llvm_ir::AllocateSharedMemoryTile(module, ty, buffer_name), element_type}; +} + +static std::vector IndexWith0( + absl::Span index, llvm::IRBuilder<>* b) { + std::vector index_with_0{ + llvm::ConstantInt::get(index.front()->getType(), 0)}; + absl::c_copy(index, std::back_inserter(index_with_0)); + return index_with_0; +} + +llvm::Value* SharedMemoryTile::Address(absl::Span index, + llvm::IRBuilder<>* b) const { + llvm::Value* gep = b->CreateInBoundsGEP(base_ptr_->getValueType(), base_ptr_, + IndexWith0(index, b)); + // __shared__ memory uses a different address space, so we cast it + // to global address space before writing or reading. + return b->CreateAddrSpaceCast(gep, + llvm::PointerType::get(b->getContext(), 0)); +}; + +llvm::Value* SharedMemoryTile::Load(absl::Span index, + llvm::IRBuilder<>* b) const { + auto* load_type = llvm::GetElementPtrInst::getIndexedType( + base_ptr_->getValueType(), IndexWith0(index, b)); + return b->CreateLoad(load_type, Address(index, b)); +} + +llvm::StoreInst* SharedMemoryTile::Store(llvm::Value* value, + absl::Span index, + llvm::IRBuilder<>* b) const { + return b->CreateStore(value, Address(index, b)); +} + llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, @@ -380,8 +414,12 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), function->getEntryBlock().getFirstInsertionPt()); + llvm::Module* module = b->GetInsertBlock()->getModule(); + // Explicitly set local addrspace for SPIR backend. + llvm::Triple target(module->getTargetTriple()); + int addrspace = target.isSPIR() || target.isAMDGPU() ? 5 : 0; llvm::AllocaInst* alloca = - b->CreateAlloca(type, element_count, AsStringRef(name)); + b->CreateAlloca(type, addrspace, element_count, AsStringRef(name)); if (alignment != 0) { alloca->setAlignment(llvm::Align(alignment)); } @@ -499,7 +537,11 @@ void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, } llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper, - llvm::Instruction* inst) { + llvm::Instruction* inst, + llvm::Module* module) { + if (llvm::Triple(module->getTargetTriple()).isSPIR()) { + return inst; + } llvm::LLVMContext& context = inst->getParent()->getContext(); llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context); inst->setMetadata( diff --git a/xla/service/llvm_ir/llvm_util.h b/xla/service/llvm_ir/llvm_util.h index 10f710309781e..41945f6cbbeb8 100644 --- a/xla/service/llvm_ir/llvm_util.h +++ b/xla/service/llvm_ir/llvm_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -143,9 +143,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. -StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, - int32_t* shape_size, - llvm::IRBuilder<>* b); +absl::StatusOr EncodeSelfDescribingShapeConstant( + const Shape& shape, int32_t* shape_size, llvm::IRBuilder<>* b); // Converts a given literal to an IR Constant. Literals have known constant // values at IR emission time. @@ -157,6 +156,33 @@ llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, llvm::Type* tile_type, absl::string_view name); +// Utility class for working with shared memory. +class SharedMemoryTile { + public: + SharedMemoryTile() = default; + explicit SharedMemoryTile(llvm::GlobalVariable* base_ptr, + llvm::Type* element_type) + : base_ptr_(base_ptr), element_type_(element_type) {} + + llvm::Value* Address(absl::Span index, + llvm::IRBuilder<>* b) const; + llvm::Value* Load(absl::Span index, + llvm::IRBuilder<>* b) const; + llvm::StoreInst* Store(llvm::Value* value, + absl::Span index, + llvm::IRBuilder<>* b) const; + llvm::Type* GetElementType() const { return element_type_; } + + private: + llvm::GlobalVariable* base_ptr_; + llvm::Type* element_type_; +}; + +SharedMemoryTile AllocateSharedMemoryTile( + llvm::Module* module, llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name); + // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point // of the builder is set to the same place after calling this function @@ -246,7 +272,8 @@ void SetDereferenceableMetadataForLoad(llvm::LoadInst* load, // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience. llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper, - llvm::Instruction* inst); + llvm::Instruction* inst, + llvm::Module* module); void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); diff --git a/xla/service/llvm_ir/loop_emitter.cc b/xla/service/llvm_ir/loop_emitter.cc index be8ade56d5f5b..8e07365a88c03 100644 --- a/xla/service/llvm_ir/loop_emitter.cc +++ b/xla/service/llvm_ir/loop_emitter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/loop_emitter.h b/xla/service/llvm_ir/loop_emitter.h index c5f3c2322db05..da23eeb5ce37c 100644 --- a/xla/service/llvm_ir/loop_emitter.h +++ b/xla/service/llvm_ir/loop_emitter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace llvm_ir { // The function has to emit code to compute this value and return the resulting // llvm::Value*. using ElementGenerator = - std::function(const IrArray::Index& index)>; + std::function(const IrArray::Index& index)>; using BodyEmitter = std::function; // Creates the body emitter from target arrays. diff --git a/xla/service/llvm_ir/math_ops.cc b/xla/service/llvm_ir/math_ops.cc index d5089ddddd1bc..f33e8ec40bb3b 100644 --- a/xla/service/llvm_ir/math_ops.cc +++ b/xla/service/llvm_ir/math_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,13 +14,25 @@ limitations under the License. ==============================================================================*/ #include "xla/service/llvm_ir/math_ops.h" + #include "xla/service/llvm_ir/llvm_util.h" namespace xla { namespace llvm_ir { -llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input, + bool with_fma) { llvm::Type* type = input->getType(); + const float plus_clamp = + with_fma ? 7.99881172180175781f : 7.90531110763549805f; + const float minus_clamp = -plus_clamp; + // Inputs in the range [plus_clamp, 9.0] may cause the output + // of EmitFastTanh to be greater than 1, so we set the input to be at most + // 'plus_clamp'. We choose 'plus_clamp' in a way that the + // tanh approximation on that input is exactly 1.0. Similarly for + // 'minus_clamp', where the tanh approximation will return exactly + // -1.0. + // Taken from tanh(Eigen/src/Core/MathFunctionsImpl.h). // For small values of x, we can approximate tanh(x)=x. For extremely small // values of x (|x| < 1e-37), the other approximation evaluates tanh(x) = 0. @@ -30,14 +42,12 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { auto use_aprox = b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox)); - // Clamp the input to [-9, 9]. - // // To simplify the code base until it's an issue, don't have a slow min/max in // this approximation. llvm::Value* input_clamped = llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b, + llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, minus_clamp), b, /*enable_fast_min_max=*/true), - llvm::ConstantFP::get(type, 9.0), b, /*enable_fast_min_max=*/true); + llvm::ConstantFP::get(type, plus_clamp), b, /*enable_fast_min_max=*/true); static constexpr std::array numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, @@ -68,5 +78,76 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { b->CreateFDiv(numerator, denominator)); } +llvm::Value* EmitErfF32(llvm::IRBuilder<>* b, llvm::Value* x) { + auto type = x->getType(); + constexpr float kErfInvOneMinusHalfULP = 3.832506856900711f; + auto call_fabs = [b](llvm::Value* operand_value) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {operand_value}, + {operand_value->getType()}, b); + }; + auto fcmp_le = [b](llvm::Value* lhs_value, llvm::Value* rhs_value) { + return b->CreateFCmpOLE(lhs_value, rhs_value); + }; + llvm::Value* const clamp = fcmp_le( + llvm::ConstantFP::get(type, kErfInvOneMinusHalfULP), call_fabs(x)); + // The monomial coefficients of the numerator polynomial (odd). + llvm::Value* const alpha_1 = llvm::ConstantFP::get(type, 1.128379143519084f); + llvm::Value* const alpha_3 = + llvm::ConstantFP::get(type, 0.18520832239976145f); + llvm::Value* const alpha_5 = + llvm::ConstantFP::get(type, 0.050955695062380861f); + llvm::Value* const alpha_7 = + llvm::ConstantFP::get(type, 0.0034082910107109506f); + llvm::Value* const alpha_9 = + llvm::ConstantFP::get(type, 0.00022905065861350646f); + + // The monomial coefficients of the denominator polynomial (even). + llvm::Value* const beta_0 = llvm::ConstantFP::get(type, 1.0f); + llvm::Value* const beta_2 = llvm::ConstantFP::get(type, 0.49746925110067538f); + llvm::Value* const beta_4 = llvm::ConstantFP::get(type, 0.11098505178285362f); + llvm::Value* const beta_6 = + llvm::ConstantFP::get(type, 0.014070470171167667f); + llvm::Value* const beta_8 = + llvm::ConstantFP::get(type, 0.0010179625278914885f); + llvm::Value* const beta_10 = + llvm::ConstantFP::get(type, 0.000023547966471313185f); + llvm::Value* const beta_12 = + llvm::ConstantFP::get(type, -1.1791602954361697e-7f); + + // Since the polynomials are odd/even, we need x^2. + llvm::Value* const x2 = b->CreateFMul(x, x); + + // Evaluate the numerator polynomial p. + auto call_fma = [b](llvm::Value* multiplier, llvm::Value* multiplicand, + llvm::Value* addend) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fma, + {multiplier, multiplicand, addend}, + {multiplier->getType()}, b); + }; + llvm::Value* p = call_fma(x2, alpha_9, alpha_7); + p = call_fma(x2, p, alpha_5); + p = call_fma(x2, p, alpha_3); + p = call_fma(x2, p, alpha_1); + p = b->CreateFMul(x, p); + + // Evaluate the denominator polynomial p. + llvm::Value* q = call_fma(x2, beta_12, beta_10); + q = call_fma(x2, q, beta_8); + q = call_fma(x2, q, beta_6); + q = call_fma(x2, q, beta_4); + q = call_fma(x2, q, beta_2); + q = call_fma(x2, q, beta_0); + + // Divide the numerator by the denominator. + auto call_copysign = [b](llvm::Value* mag, llvm::Value* sign) { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, {mag, sign}, + {mag->getType()}, b); + }; + auto* result = + b->CreateSelect(clamp, call_copysign(llvm::ConstantFP::get(type, 1.0), x), + b->CreateFDiv(p, q)); + return result; +} + } // namespace llvm_ir } // namespace xla diff --git a/xla/service/llvm_ir/math_ops.h b/xla/service/llvm_ir/math_ops.h index bc16d4cc593d6..7c5bf27c55de0 100644 --- a/xla/service/llvm_ir/math_ops.h +++ b/xla/service/llvm_ir/math_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,8 +23,14 @@ namespace xla { namespace llvm_ir { // Emits an approximation of tanh. The implementation uses the same rational +// interpolant as implemented in Eigen3. 'with_fma' should be set to true if FMA +// instructions are available. +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input, + bool with_fma = false); + +// Emits an approximation of erf. The implementation uses the same rational // interpolant as implemented in Eigen3. -llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input); +llvm::Value* EmitErfF32(llvm::IRBuilder<>* b, llvm::Value* x); } // namespace llvm_ir } // namespace xla diff --git a/xla/service/llvm_ir/sort_util.cc b/xla/service/llvm_ir/sort_util.cc index f8f717846f71f..22ed179ae4095 100644 --- a/xla/service/llvm_ir/sort_util.cc +++ b/xla/service/llvm_ir/sort_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -165,7 +165,8 @@ Status EmitTiledCompareLoop( llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic( gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b); llvm_ir::AddRangeMetadata(0, tile_size / 2, - llvm::cast(thread_id)); + llvm::cast(thread_id), + b->GetInsertBlock()->getModule()); thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), /*isSigned=*/true, "thread.id.x"); @@ -227,10 +228,11 @@ Status EmitTiledCompareLoop( // We need a generic pointer with address space 0 instead of a pointer to // shared memory (address space 3) so that we can pass it to the comparison // computation. - return b->CreateAddrSpaceCast(shared_memory_address, - llvm::PointerType::getWithSamePointeeType( - llvm::cast(ptr_type), - /*AddressSpace=*/0)); + return b->CreateAddrSpaceCast( + shared_memory_address, + llvm::PointerType::get( + llvm::cast(ptr_type)->getContext(), + /*AddressSpace=*/0)); }; auto element_address_pointee_type = [&](int64_t operand, llvm::Value* index) { return llvm::GetElementPtrInst::getIndexedType( diff --git a/xla/service/llvm_ir/sort_util.h b/xla/service/llvm_ir/sort_util.h index 92cfa679ae284..e84c2c9194144 100644 --- a/xla/service/llvm_ir/sort_util.h +++ b/xla/service/llvm_ir/sort_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/llvm_ir/tuple_ops.cc b/xla/service/llvm_ir/tuple_ops.cc index e5f5d98a5c4bd..ce064fefb3188 100644 --- a/xla/service/llvm_ir/tuple_ops.cc +++ b/xla/service/llvm_ir/tuple_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -89,8 +89,6 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64_t index, int alignment, llvm::Value* operand, llvm::Type* operand_pointee_type, llvm::IRBuilder<>* b) { - CHECK(llvm::cast(operand->getType()) - ->isOpaqueOrPointeeTypeMatches(operand_pointee_type)); const std::vector gep_index = {b->getInt64(0), b->getInt64(index)}; llvm::Value* element_ptr = diff --git a/xla/service/llvm_ir/tuple_ops.h b/xla/service/llvm_ir/tuple_ops.h index 0ce8249caa9f1..5506e72fd8759 100644 --- a/xla/service/llvm_ir/tuple_ops.h +++ b/xla/service/llvm_ir/tuple_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/local_service.cc b/xla/service/local_service.cc index cbd81a3915b7c..573b372be61d4 100644 --- a/xla/service/local_service.cc +++ b/xla/service/local_service.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,8 +46,8 @@ limitations under the License. namespace xla { -/* static */ StatusOr> LocalService::NewService( - const ServiceOptions& options) { +/* static */ absl::StatusOr> +LocalService::NewService(const ServiceOptions& options) { se::Platform* platform = options.platform(); if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); @@ -70,7 +70,7 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} -StatusOr>> +absl::StatusOr>> LocalService::CompileExecutables( const XlaComputation& computation, const absl::Span argument_layouts, @@ -121,7 +121,7 @@ LocalService::CompileExecutables( } } -StatusOr>> +absl::StatusOr>> LocalService::CompileAotResults( const XlaComputation& computation, const absl::Span argument_layouts, @@ -150,13 +150,14 @@ LocalService::CompileAotResults( build_options.run_backend_only()); } -StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { +absl::StatusOr LocalService::ReplicaNumberToDeviceOrdinal( + int replica_number) { return backend().computation_placer()->DeviceId( replica_number, /*computation=*/0, options_.number_of_replicas(), /*computation_count=*/1); } -StatusOr LocalService::GlobalDataToShapedBuffer( +absl::StatusOr LocalService::GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number) { TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { @@ -167,7 +168,7 @@ StatusOr LocalService::GlobalDataToShapedBuffer( return buffers[replica_number]; } -StatusOr LocalService::RegisterReplicatedBuffers( +absl::StatusOr LocalService::RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag) { return allocation_tracker_.RegisterReplicatedBuffers( diff --git a/xla/service/local_service.h b/xla/service/local_service.h index 4bea89ef1f587..27d27e622f51d 100644 --- a/xla/service/local_service.h +++ b/xla/service/local_service.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ namespace xla { class LocalService : public Service { public: // Factory for creating a LocalService. - static StatusOr> NewService( + static absl::StatusOr> NewService( const ServiceOptions& options); // Builds Executables with the given XlaComputation, argument layouts and @@ -47,7 +47,7 @@ class LocalService : public Service { // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - StatusOr>> CompileExecutables( + absl::StatusOr>> CompileExecutables( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); @@ -55,7 +55,7 @@ class LocalService : public Service { // Same as CompileExecutables() above, but return AotCompilationResult objects // (instead of Executable objects), which can be persisted to later load // Executable objects. - StatusOr>> + absl::StatusOr>> CompileAotResults(const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); @@ -65,16 +65,16 @@ class LocalService : public Service { // This returns an error if there is not a one-to-one correspondence of // replicas to device ordinals, but is useful as a short term mechanism for // the "easy" case where a single replica is a single device. - StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + absl::StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. - StatusOr GlobalDataToShapedBuffer( + absl::StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); // Registers a vector of shaped buffers of device memory, one per replica, and // returns a corresponding handle that can be used for talking to XLA clients. - StatusOr RegisterReplicatedBuffers( + absl::StatusOr RegisterReplicatedBuffers( std::vector replicated_buffers, const std::string& tag); diff --git a/xla/service/local_service_utils.cc b/xla/service/local_service_utils.cc index e0b2d81da3ace..b07dc887cef41 100644 --- a/xla/service/local_service_utils.cc +++ b/xla/service/local_service_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -63,7 +63,7 @@ std::optional ParameterMetadata( } } // namespace -StatusOr> GetHloModuleConfig( +absl::StatusOr> GetHloModuleConfig( const XlaComputation& computation, absl::Span argument_layouts, const ExecutableBuildOptions& build_options, ServiceOptions* options, diff --git a/xla/service/local_service_utils.h b/xla/service/local_service_utils.h index e71f4605aec2e..5e3bb9a5d0f99 100644 --- a/xla/service/local_service_utils.h +++ b/xla/service/local_service_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ limitations under the License. namespace xla { // Validates the computation argument layouts, and returns the corresponding // HloModuleConfig. -StatusOr> GetHloModuleConfig( +absl::StatusOr> GetHloModuleConfig( const XlaComputation& computation, absl::Span argument_layouts, const ExecutableBuildOptions& build_options, diff --git a/xla/service/lockable.h b/xla/service/lockable.h new file mode 100644 index 0000000000000..3a71685c65343 --- /dev/null +++ b/xla/service/lockable.h @@ -0,0 +1,140 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_LOCKABLE_H_ +#define XLA_SERVICE_LOCKABLE_H_ + +#include + +#include "absl/base/thread_annotations.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "tsl/platform/logging.h" + +namespace xla { + +// A template that can be specialized to give a human readable name to lockable +// of type `T`. +template +struct LockableName { + static std::string ToString(const T& value) { + return absl::StrFormat("lockable %p", &value); + } +}; + +// An RAII helper for a value of type `T` that requires exclusive access. +template > +class Lockable { + public: + // RAII type that will release the exclusive lock when it is destroyed. + class Lock { + public: + Lock() = default; + + Lock(Lock&& other) { + lockable_ = other.lockable_; + other.lockable_ = nullptr; + } + + Lock& operator=(Lock&& other) { + lockable_ = other.lockable_; + other.lockable_ = nullptr; + return *this; + } + + ~Lock() { + if (lockable_) lockable_->Release(); + } + + T& operator*() const { return lockable_->value_; } + T* operator->() const { return &lockable_->value_; } + operator bool() const { return lockable_ != nullptr; } // NOLINT + + std::string ToString() const { + return lockable_ ? lockable_->ToString() : ""; + } + + private: + friend class Lockable; + explicit Lock(Lockable* lockable) : lockable_(lockable) {} + Lockable* lockable_ = nullptr; + }; + + Lockable() = default; + + explicit Lockable(T value) : value_(std::move(value)) { + VLOG(2) << "Constructed " << LockableName::ToString(value_); + } + + template + explicit Lockable(Args&&... args) : value_(std::forward(args)...) { + VLOG(2) << "Constructed " << LockableName::ToString(value_); + } + + Lockable(const Lockable&) = delete; + Lockable& operator=(const Lockable&) = delete; + + ~Lockable() { + VLOG(2) << "Destroy " << LockableName::ToString(value_); + absl::MutexLock lock(&mutex_); + CHECK_EQ(is_unlocked_, true); // NOLINT + } + + Lock Acquire() { + absl::MutexLock lock(&mutex_); + mutex_.Await(absl::Condition(&is_unlocked_)); + VLOG(2) << "Acquired " << LockableName::ToString(value_); + is_unlocked_ = false; + + return Lock(this); + } + + Lock TryAcquire() { + absl::MutexLock lock(&mutex_); + + // Someone already locked this object, return an empty lock. + if (is_unlocked_ == false) { + VLOG(2) << "Failed to acquire " << LockableName::ToString(value_); + return Lock(); + } + + VLOG(2) << "Acquired " << LockableName::ToString(value_); + is_unlocked_ = false; + return Lock(this); + } + + std::string ToString() const { return LockableName::ToString(value_); } + + protected: + const T& value() const { return value_; } + + private: + friend class Lock; + + void Release() { + absl::MutexLock lock(&mutex_); + VLOG(2) << "Released " << LockableName::ToString(value_); + CHECK(!is_unlocked_); // NOLINT + is_unlocked_ = true; + } + + T value_; + absl::Mutex mutex_; + bool is_unlocked_ ABSL_GUARDED_BY(mutex_) = true; +}; + +} // namespace xla + +#endif // XLA_SERVICE_LOCKABLE_H_ diff --git a/xla/service/lockable_test.cc b/xla/service/lockable_test.cc new file mode 100644 index 0000000000000..9118fb9e7276b --- /dev/null +++ b/xla/service/lockable_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/lockable.h" + +#include +#include +#include +#include + +#include "absl/synchronization/blocking_counter.h" +#include "tsl/platform/env.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace xla { + +tsl::thread::ThreadPool CreateThreadPool(int32_t size) { + return tsl::thread::ThreadPool(tsl::Env::Default(), "lockable_test", size); +} + +template <> +struct LockableName { + static std::string ToString(const std::string& str) { + return "lockable string " + str; + } +}; + +class LockableString : public Lockable { + using Lockable::Lockable; +}; + +TEST(LockableTest, LockProperties) { + // Lock can be default constructed and implicitly casted to bool. + LockableString::Lock lock0; + EXPECT_FALSE(lock0); + + // Lock can be locked from a lockable object. + LockableString str("foo"); + LockableString::Lock lock1 = str.Acquire(); + EXPECT_TRUE(lock1); + + // Lock can be moved. + LockableString::Lock lock2 = std::move(lock1); + EXPECT_FALSE(lock1); + EXPECT_TRUE(lock2); + + // TryAcquire will return empty lock for locked object. + LockableString::Lock lock3 = str.TryAcquire(); + EXPECT_FALSE(lock3); + + // Locks have human readable names. + EXPECT_EQ(lock1.ToString(), ""); + EXPECT_EQ(lock2.ToString(), "lockable string foo"); + + // Lockable has human readable name. + EXPECT_EQ(str.ToString(), "lockable string foo"); + + // After lock is destructed we can acquire lockable with TryLock. + auto sink = [](LockableString::Lock) {}; + sink(std::move(lock2)); + + LockableString::Lock lock4 = str.TryAcquire(); + EXPECT_TRUE(lock4); +} + +TEST(LockableTest, ExclusiveAccess) { + absl::BlockingCounter counter(100); + auto thread_pool = CreateThreadPool(10); + + LockableString str("foo"); + + for (size_t i = 0; i < 100; ++i) { + thread_pool.Schedule([&] { + { // Decrement counter only after lock is released. + auto exclusive_str = str.Acquire(); + ASSERT_EQ(*exclusive_str, "foo"); + } + counter.DecrementCount(); + }); + } + + counter.Wait(); +} + +} // namespace xla diff --git a/xla/service/logical_buffer.cc b/xla/service/logical_buffer.cc index 56fa248b00e89..0a69feaf0641e 100644 --- a/xla/service/logical_buffer.cc +++ b/xla/service/logical_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/logical_buffer.h b/xla/service/logical_buffer.h index 8666d6d62614f..1bc6ab4811442 100644 --- a/xla/service/logical_buffer.h +++ b/xla/service/logical_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/logical_buffer_analysis.cc b/xla/service/logical_buffer_analysis.cc index 15afe61bea624..38bda9003bbc1 100644 --- a/xla/service/logical_buffer_analysis.cc +++ b/xla/service/logical_buffer_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ void GatherFusionInstructions( } // namespace -/* static */ StatusOr> +/* static */ absl::StatusOr> LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( new LogicalBufferAnalysis(module)); @@ -78,7 +78,7 @@ Status LogicalBufferAnalysis::Analyze() { } LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const { - return *logical_buffers_.at(id); + return *logical_buffers_[id]; } LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction, diff --git a/xla/service/logical_buffer_analysis.h b/xla/service/logical_buffer_analysis.h index ce86d5a2adc43..b7ca006ea6b02 100644 --- a/xla/service/logical_buffer_analysis.h +++ b/xla/service/logical_buffer_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ namespace xla { class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { public: // Runs points-to analysis on 'module'. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module); // Returns the logical buffer with the given ID. diff --git a/xla/service/logistic_expander.cc b/xla/service/logistic_expander.cc index a6606f775e9f4..85972849c9eeb 100644 --- a/xla/service/logistic_expander.cc +++ b/xla/service/logistic_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ bool LogisticExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kLogistic; } -StatusOr LogisticExpander::ExpandInstruction( +absl::StatusOr LogisticExpander::ExpandInstruction( HloInstruction* instruction) { HloInstruction* operand = instruction->mutable_operand(0); const Shape operand_shape = operand->shape(); diff --git a/xla/service/logistic_expander.h b/xla/service/logistic_expander.h index b754b71266d1f..c0dd55a42bf69 100644 --- a/xla/service/logistic_expander.h +++ b/xla/service/logistic_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class LogisticExpander : public OpExpanderPass { // Returns a replacement for `instruction`, or nullptr if no replacement is // needed (e.g. only the to_apply subcomputation of the instruction was // modified). - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/logistic_expander_test.cc b/xla/service/logistic_expander_test.cc index a3a64c804b950..dae9292b715e4 100644 --- a/xla/service/logistic_expander_test.cc +++ b/xla/service/logistic_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,30 +16,19 @@ limitations under the License. #include "xla/service/logistic_expander.h" #include +#include -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/service/hlo_creation_utils.h" +#include "xla/service/dynamic_padder.h" #include "xla/service/hlo_parser.h" -#include "xla/service/hlo_pass_fix.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/shape_inference.h" -#include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" -#include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { @@ -73,5 +62,21 @@ TEST_F(LogisticExpanderTest, ExpandWith) { m::Exp(m::Negate(m::Parameter(0))))))); } +TEST_F(LogisticExpanderTest, DynamicDimensions) { + constexpr std::string_view hlo = R"( +HloModule DynamicDimensions + +ENTRY main { + p = f32[<=10] parameter(0) + ROOT root = f32[<=10] logistic(p) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + + LogisticExpander logistic_expander; + ASSERT_TRUE(logistic_expander.Run(module.get()).value()); + DynamicPadder dynamic_padder; + EXPECT_TRUE(dynamic_padder.Run(module.get()).value()); +} } // namespace } // namespace xla diff --git a/xla/service/loop_schedule_linearizer.cc b/xla/service/loop_schedule_linearizer.cc index 5cc07be9becb5..9acd4597ffb6b 100644 --- a/xla/service/loop_schedule_linearizer.cc +++ b/xla/service/loop_schedule_linearizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -70,7 +70,7 @@ class ComputationInstructionOrdering { } // namespace -static StatusOr AddControlEdgesForLoopWrites( +static absl::StatusOr AddControlEdgesForLoopWrites( HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) { HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis(); HloComputation* body = xla_while->while_body(); @@ -145,7 +145,7 @@ static StatusOr AddControlEdgesForLoopWrites( return changed; } -StatusOr LoopScheduleLinearizer::Run( +absl::StatusOr LoopScheduleLinearizer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Constructing HloAliasAnalysis is expensive, so don't do it until we find at @@ -166,11 +166,10 @@ StatusOr LoopScheduleLinearizer::Run( const HloComputation* body = instruction->while_body(); bool has_async_collectives = absl::c_any_of(body->instructions(), [](const HloInstruction* instr) { - HloOpcode op = instr->opcode(); return hlo_query::IsAsyncCollectiveStartOp( - op, /*include_send_recv=*/true) || + instr, /*include_send_recv=*/true) || hlo_query::IsAsyncCollectiveDoneOp( - op, /*include_send_recv=*/true); + instr, /*include_send_recv=*/true); }); if (has_async_collectives) { diff --git a/xla/service/loop_schedule_linearizer.h b/xla/service/loop_schedule_linearizer.h index a3a1b9e4beb15..ac57ef4b8ce80 100644 --- a/xla/service/loop_schedule_linearizer.h +++ b/xla/service/loop_schedule_linearizer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class LoopScheduleLinearizer : public HloModulePass { : can_share_buffer_(can_share_buffer) {} using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/loop_schedule_linearizer_test.cc b/xla/service/loop_schedule_linearizer_test.cc index ac101a1e7cc24..bcc023f672246 100644 --- a/xla/service/loop_schedule_linearizer_test.cc +++ b/xla/service/loop_schedule_linearizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/map_inliner.cc b/xla/service/map_inliner.cc index 1f6317f24d2eb..eea543fe2d6e2 100644 --- a/xla/service/map_inliner.cc +++ b/xla/service/map_inliner.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ class MapInlinerVisitor : public DfsHloVisitorWithDefault { Status HandleMap(HloInstruction* map) override; // Runs the visitor on a computation. - StatusOr Run(HloComputation* computation); + absl::StatusOr Run(HloComputation* computation); private: // Current HloComputation instance the MapInlinerVisitor is traversing. @@ -56,7 +56,7 @@ class MapInlinerVisitor : public DfsHloVisitorWithDefault { bool changed_ = false; }; -StatusOr MapInlinerVisitor::Run(HloComputation* computation) { +absl::StatusOr MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); @@ -110,7 +110,7 @@ Status MapInlinerVisitor::HandleMap(HloInstruction* map) { return OkStatus(); } -StatusOr MapInliner::Run( +absl::StatusOr MapInliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { MapInlinerVisitor visitor(/*computation=*/nullptr); diff --git a/xla/service/map_inliner.h b/xla/service/map_inliner.h index ebb10cdf3d87d..4735f84240f92 100644 --- a/xla/service/map_inliner.h +++ b/xla/service/map_inliner.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ class MapInliner : public HloModulePass { // Run map inlining on the given computation. Returns whether the computation // was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/map_inliner_test.cc b/xla/service/map_inliner_test.cc index d4219447ad8ea..3d58214907296 100644 --- a/xla/service/map_inliner_test.cc +++ b/xla/service/map_inliner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/mapped_ptr_container_sorter.h b/xla/service/mapped_ptr_container_sorter.h index 262f5d455f9c0..1b1dc5df19d5c 100644 --- a/xla/service/mapped_ptr_container_sorter.h +++ b/xla/service/mapped_ptr_container_sorter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,7 +127,7 @@ class MappedPtrContainerSorter { // The result maps each element in the unordered_container to the target // index that it will occupy in the sorted result. - StatusOr> Flatten() const; + absl::StatusOr> Flatten() const; private: SortedIndices() = delete; @@ -152,7 +152,7 @@ class MappedPtrContainerSorter { // Returns a mapping in which the element at index i indicates the target // index that unordered_container[i] should occupy in the sorted result. template - static StatusOr> ComputeNewIndices( + static absl::StatusOr> ComputeNewIndices( MapPtrFn map_ptr, UnmappedPtrIndexFn unmapped_index, const OrderedTy& ordered_container, const UnorderedTy& unordered_container); @@ -230,9 +230,8 @@ template Status MappedPtrContainerSorter::SortedIndices::AddMappedElement( size_t unordered_container_index, size_t partial_order) { if (partial_order >= mapped_element_indices_by_partial_order_.size()) { - return InternalErrorStrCat( - "invalid partial order: ", partial_order, " v max(", - mapped_element_indices_by_partial_order_.size(), ")"); + return InternalStrCat("invalid partial order: ", partial_order, " v max(", + mapped_element_indices_by_partial_order_.size(), ")"); } mapped_element_indices_by_partial_order_[partial_order].push_back( @@ -284,13 +283,13 @@ std::string MappedPtrContainerSorter::SortedIndices::ToString() } template -StatusOr> +absl::StatusOr> MappedPtrContainerSorter::SortedIndices::Flatten() const { std::vector result(unordered_container_size_, InvalidIndex()); size_t next_available_index = 0; - auto next_index_fn = [&]() -> StatusOr { + auto next_index_fn = [&]() -> absl::StatusOr { if (next_available_index >= unordered_container_size_) { - return InternalErrorStrCat( + return InternalStrCat( "invalid unordered_container index: ", next_available_index, " v size(", unordered_container_size_, ")"); } @@ -335,7 +334,7 @@ MappedPtrContainerSorter::SortedIndices::Flatten() const { absl::flat_hash_set used_indices; for (size_t index : result) { if (used_indices.contains(index)) { - return InternalErrorStrCat( + return InternalStrCat( "2 elements in unordered_container are destined for the same " "index: ", index); @@ -351,7 +350,7 @@ MappedPtrContainerSorter::SortedIndices::Flatten() const { template template -StatusOr> +absl::StatusOr> MappedPtrContainerSorter::ComputeNewIndices( MapPtrFn map_ptr, UnmappedPtrIndexFn unmapped_index, const OrderedTy& ordered_container, diff --git a/xla/service/mapped_ptr_container_sorter_test.cc b/xla/service/mapped_ptr_container_sorter_test.cc index 26f77e7c8ccf5..f0501ace594ec 100644 --- a/xla/service/mapped_ptr_container_sorter_test.cc +++ b/xla/service/mapped_ptr_container_sorter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/maybe_owning_device_memory.cc b/xla/service/maybe_owning_device_memory.cc index 0081f965dc79d..febb92d1b1387 100644 --- a/xla/service/maybe_owning_device_memory.cc +++ b/xla/service/maybe_owning_device_memory.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/maybe_owning_device_memory.h b/xla/service/maybe_owning_device_memory.h index 7ca8666611bf9..1992b4055e390 100644 --- a/xla/service/maybe_owning_device_memory.h +++ b/xla/service/maybe_owning_device_memory.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index 2f9e592e22ff8..a90a6c5d2f2cc 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -1,19 +1,20 @@ # Description: # Memory Space Assignment service implementation. -load( - "//xla:xla.bzl", - "xla_cc_test", -) +load("@tsl//tsl:tsl.bzl", "internal_visibility") load( "@tsl//tsl/platform:build_config.bzl", "tf_proto_library", ) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla:xla.bzl", + "xla_cc_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -36,8 +37,13 @@ cc_library( srcs = ["memory_space_assignment.cc"], hdrs = ["memory_space_assignment.h"], deps = [ + ":allocation", + ":cost_analysis", + ":memory_bound_loop_optimizer", ":memory_space_assignment_proto_cc", + ":options", ":repacking", + ":slice", ":tuning_utils", ":utils", "//xla:debug_options_flags", @@ -46,11 +52,11 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", "//xla/service:call_graph", - "//xla/service:heap_simulator", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", @@ -58,21 +64,22 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:time_utils", - "//xla/service:tuple_util", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@com_googlesource_code_re2//:re2", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) @@ -81,20 +88,29 @@ xla_cc_test( name = "memory_space_assignment_test", srcs = ["memory_space_assignment_test.cc"], deps = [ + ":allocation", + ":cost_analysis", ":memory_space_assignment", ":memory_space_assignment_proto_cc", + ":options", + ":prefetch_interval_picker", ":repacking", + ":slice", + ":testing_utils", "//xla:shape_util", "//xla:status", + "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", - "//xla/service:heap_simulator", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/service:instruction_hoister", "//xla/service:time_utils", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", @@ -103,13 +119,15 @@ xla_cc_test( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -120,9 +138,7 @@ cc_library( hdrs = ["repacking.h"], deps = [ "//xla:statusor", - "//xla:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", + "//xla/service/heap_simulator:allocation_block", "@com_google_absl//absl/types:span", ], ) @@ -135,13 +151,16 @@ cc_library( ":repacking", "//xla:comparison_util", "//xla:statusor", - "//xla/service:heap_simulator", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", ], ) @@ -151,7 +170,53 @@ cc_library( hdrs = ["utils.h"], deps = [ "//xla/hlo/ir:hlo", - "//xla/service:heap_simulator", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + ], +) + +cc_library( + name = "slice", + srcs = ["slice.cc"], + hdrs = [ + "slice.h", + ], + deps = [ + ":memory_space_assignment_proto_cc", + "//xla:shape_util", + "//xla/service/heap_simulator", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "allocation", + srcs = ["allocation.cc"], + hdrs = [ + "allocation.h", + ], + deps = [ + ":cost_analysis", + ":memory_space_assignment_proto_cc", + ":slice", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "//xla/service:time_utils", + "//xla/service:tuple_util", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", ], ) @@ -161,7 +226,206 @@ cc_library( hdrs = ["tuning_utils.h"], deps = [ "//xla/hlo/ir:hlo", - "//xla/service:heap_simulator", + "//xla/service/heap_simulator", + ], +) + +cc_library( + name = "options", + srcs = [], + hdrs = ["options.h"], + deps = [ + ":cost_analysis", + ":memory_space_assignment_proto_cc", + ":prefetch_interval_picker", + ":repacking", + ":slice", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_value", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "cost_analysis", + srcs = ["cost_analysis.cc"], + hdrs = ["cost_analysis.h"], + deps = [ + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:call_graph", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cost_analysis_test", + srcs = ["cost_analysis_test.cc"], + deps = [ + ":cost_analysis", + "//xla:shape_util", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "prefetch_interval_picker", + srcs = ["prefetch_interval_picker.cc"], + hdrs = ["prefetch_interval_picker.h"], + deps = [ + ":cost_analysis", + ":memory_space_assignment_proto_cc", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "testing_utils", + testonly = True, + hdrs = ["testing_utils.h"], + deps = [ + ":cost_analysis", + "//xla:shape_util", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:call_graph", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_cost_analysis", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "memory_bound_loop_optimizer", + srcs = ["memory_bound_loop_optimizer.cc"], + hdrs = ["memory_bound_loop_optimizer.h"], + deps = [ + ":allocation", + ":cost_analysis", + ":memory_space_assignment_proto_cc", + ":options", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:buffer_value", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "memory_bound_loop_optimizer_test", + srcs = ["memory_bound_loop_optimizer_test.cc"], + deps = [ + ":allocation", + ":cost_analysis", + ":memory_bound_loop_optimizer", + ":memory_space_assignment", + ":memory_space_assignment_proto_cc", + ":options", + ":prefetch_interval_picker", + "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:buffer_value", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@com_googlesource_code_re2//:re2", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + +xla_cc_test( + name = "prefetch_interval_picker_test", + srcs = ["prefetch_interval_picker_test.cc"], + deps = [ + ":cost_analysis", + ":prefetch_interval_picker", + ":testing_utils", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -172,6 +436,8 @@ xla_cc_test( ":best_fit_repacker", ":repacking", "//xla:comparison_util", + "//xla/service/heap_simulator", + "//xla/service/heap_simulator:allocation_block", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", diff --git a/xla/service/memory_space_assignment/allocation.cc b/xla/service/memory_space_assignment/allocation.cc new file mode 100644 index 0000000000000..49808709832f5 --- /dev/null +++ b/xla/service/memory_space_assignment/allocation.cc @@ -0,0 +1,854 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/allocation.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/service/time_utils.h" +#include "xla/service/tuple_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla::memory_space_assignment { +namespace { + +std::string UsesToString(const std::vector& uses) { + if (uses.empty()) { + return "none"; + } + std::vector uses_str; + uses_str.reserve(uses.size()); + for (const auto& use : uses) { + uses_str.push_back(use.ToString()); + } + return absl::StrJoin(uses_str, ","); +} + +// Helper function to compute the start time for a SlicedCopyAllocation. +int64_t GetSlicedCopyAllocationExclusiveStartTime( + const std::vector& + slice_decisions_sorted_by_exclusive_start_time) { + if (slice_decisions_sorted_by_exclusive_start_time.empty()) { + return -1; + } + + return slice_decisions_sorted_by_exclusive_start_time.front() + .exclusive_start_time; +} + +// Helper function to compute the underlying Allocation chunk for a +// SlicedCopyAllocation. +std::optional GetSlicedCopyAllocationChunk( + const std::vector& slice_decisions_sorted_by_start_time) { + if (slice_decisions_sorted_by_start_time.empty()) { + return std::nullopt; + } + auto offset_cmp = [](const SliceDecision& lhs, const SliceDecision& rhs) { + return lhs.chunk.offset < rhs.chunk.offset; + }; + auto end_cmp = [](const SliceDecision& lhs, const SliceDecision& rhs) { + return lhs.chunk.chunk_end() < rhs.chunk.chunk_end(); + }; + return HeapSimulator::Chunk::FromOffsetEnd( + std::min_element(slice_decisions_sorted_by_start_time.begin(), + slice_decisions_sorted_by_start_time.end(), offset_cmp) + ->chunk.offset, + std::max_element(slice_decisions_sorted_by_start_time.begin(), + slice_decisions_sorted_by_start_time.end(), end_cmp) + ->chunk.chunk_end()); +} + +} // namespace + +std::optional Allocation::cross_program_prefetch_index() const { + return cross_program_prefetch_index_; +} + +HeapSimulator::Chunk Allocation::chunk() const { + CHECK(chunk_.has_value()); + return *chunk_; +} + +void Allocation::set_offset(int64_t offset) { + CHECK(chunk_.has_value()); + *chunk_ = HeapSimulator::Chunk::FromOffsetSize(offset, chunk_->size); +} + +bool Allocation::is_in_alternate_mem() const { + return memory_space_ == MemorySpace::kAlternate; +} + +bool Allocation::is_in_default_mem() const { + return memory_space_ == MemorySpace::kDefault; +} + +void Allocation::AddUse(HloUse use) { + HloInstruction* operand = + use.instruction->mutable_operand(use.operand_number); + // If the use is a tuple, look inside the tuple to find the actual use. + for (int64_t index : use.operand_index) { + if (operand->opcode() != HloOpcode::kTuple) { + break; + } + operand = operand->mutable_operand(index); + } + + // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts. + std::function get_simplified_operand; + get_simplified_operand = [&](HloInstruction* instruction) { + while (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* operand = + get_simplified_operand(instruction->mutable_operand(0)); + if (operand->opcode() == HloOpcode::kTuple) { + instruction = operand->mutable_operand(instruction->tuple_index()); + } else { + return instruction; + } + } + return instruction; + }; + operand = get_simplified_operand(operand); + + uses_.push_back(use); +} + +Status Allocation::UpdateUses(HloComputation* computation, + HloInstruction* producing_instruction) { + for (const HloUse& use : uses()) { + HloInstruction* replacement_instruction = producing_instruction; + Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); + if (operand_shape.IsTuple()) { + TF_ASSIGN_OR_RETURN( + replacement_instruction, + TupleUtil::ReplaceTupleWith( + producing_instruction, + use.instruction->mutable_operand(use.operand_number), + use.operand_index)); + } else if (operand_shape != producing_instruction->shape()) { + // When processing allocations, we treat bitcasts as trivial positions and + // do not create allocations for them. We insert bitcasts after copies, to + // account for the fact that we don't have an allocation for the bitcast. + VLOG(4) << "Old shape = " << operand_shape.ToString() + << ", new shape = " << producing_instruction->shape().ToString() + << "; inserting a bitcast."; + replacement_instruction = computation->AddInstruction( + HloInstruction::CreateBitcast(operand_shape, producing_instruction)); + } + TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( + use.operand_number, replacement_instruction)); + } + return OkStatus(); +} + +bool Allocation::is_copy_like_allocation() const { + return is_copy_allocation() || is_sliced_copy_allocation(); +} + +HloInstruction* Allocation::AddGetTupleElements() const { + CHECK_NE(defining_position().instruction, nullptr); + + Shape shape = defining_position().shape(); + CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " + << shape.ToString() + << " position = " << defining_position().shape(); + return TupleUtil::AddGetTupleElements(defining_position()); +} + +Allocation::Allocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation, + std::optional cross_program_prefetch_index) + : original_defining_position_(std::move(defining_position)), + memory_space_(memory_space), + chunk_(chunk), + start_time_(start_time), + end_time_(end_time), + is_scoped_allocation_(is_scoped_allocation), + cross_program_prefetch_index_(cross_program_prefetch_index) { + CHECK(!is_scoped_allocation || + original_defining_position_.index == ShapeIndex({})); +} + +HloPosition Allocation::original_defining_position() const { + return original_defining_position_; +} + +void Allocation::set_original_defining_position(HloPosition defining_position) { + original_defining_position_ = std::move(defining_position); +} + +bool Allocation::base_is_equal(const Allocation& other) const { + return defining_position() == other.defining_position() && + uses() == other.uses() && memory_space() == other.memory_space() && + chunk() == other.chunk() && start_time() == other.start_time() && + end_time() == other.end_time() && + earliest_available_time() == other.earliest_available_time() && + is_copy_allocation() == other.is_copy_allocation() && + is_scoped_allocation() == other.is_scoped_allocation(); +} + +PinnedAllocation::PinnedAllocation(HloPosition defining_position, + MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation) + : Allocation(std::move(defining_position), memory_space, chunk, start_time, + end_time, is_scoped_allocation, + /*cross_program_prefetch_index=*/std::nullopt) {} + +HloPosition PinnedAllocation::defining_position() const { + return original_defining_position(); +} + +bool PinnedAllocation::operator==(const PinnedAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool MirroredAllocation::operator==(const MirroredAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool ParentAllocation::operator==(const ParentAllocation& other) const { + return this->base_is_equal(static_cast(other)); +} + +bool PinnedAllocation::operator==(const Allocation& other) const { + const PinnedAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +Status PinnedAllocation::Process() { + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + return UpdateUses(computation, producing_instruction); +} + +std::string PinnedAllocation::ToString() const { + std::string memory_space_str = + memory_space() == MemorySpace::kDefault ? "def" : "alt"; + std::optional chunk = maybe_chunk(); + if (chunk) { + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + } + return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""), + "PinnedAllocation in ", memory_space_str, " defined at ", + original_defining_position().ToString(), + ", start_time:", start_time(), ", end_time:", end_time(), + ", uses: ", UsesToString(uses())); +} + +void PinnedAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void PinnedAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); +} + +CopyAllocation::CopyAllocation( + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, + int64_t copy_start_schedule_after_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + std::optional cross_program_prefetch_index) + : Allocation( + /*defining_position=*/{nullptr, {}}, memory_space, chunk, + // Allocation uses an inclusive start time + ExclusiveToInclusiveStartTime(copy_start_schedule_after_time), + end_time, + /*is_scoped_allocation=*/false, cross_program_prefetch_index), + prev_allocation_(prev_allocation), + copy_start_schedule_after_(copy_start_schedule_after_time), + copy_done_schedule_before_(copy_done_schedule_before_time) {} + +int64_t CopyAllocation::earliest_available_time() const { + return copy_done_schedule_before_; +} + +Status CopyAllocation::Process() { + // Copy allocations need to insert asynchronous copy nodes. + Shape shape = defining_position().shape(); + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( + ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), + producing_instruction, cross_program_prefetch_index())); + copy_done_ = computation->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + VLOG(4) << "Created " << copy_start_->name() + << " for copy allocation: " << ToString(); + + // Update the allocation position with the copy complete instruction, so that + // if there are further copies from it, they can find the correct position. + set_original_defining_position(HloPosition{copy_done_, {}}); + return UpdateUses(computation, copy_done_); +} + +void CopyAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void CopyAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + prev_allocation_.MarkNeeded(needed_allocations); +} + +std::string CopyAllocation::ToString() const { + std::string memory_space_str = + memory_space() == MemorySpace::kDefault ? "def" : "alt"; + std::optional chunk = maybe_chunk(); + if (chunk) { + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + } + return absl::StrCat("Copy Allocation in ", memory_space_str, + ", start_time:", start_time(), ", end_time:", end_time(), + ", copy_start_after_time: ", copy_start_schedule_after(), + ", copy_done_before_time: ", copy_done_schedule_before(), + ", uses: ", UsesToString(uses()), ", from ", + prev_allocation_.ToString()); +} + +HloPosition CopyAllocation::defining_position() const { + // Unless explicitly set, the defining position of a copy allocation is + // retrieved from the previous allocation. This is because we don't create + // new CopyStart/CopyDone instructions until later and the position should + // point to the previous (copy or otherwise) allocation's position for the + // original defining position. + HloPosition defining_position = original_defining_position(); + if (defining_position.instruction == nullptr) { + return prev_allocation_.defining_position(); + } + return defining_position; +} + +bool CopyAllocation::operator==(const CopyAllocation& other) const { + return this->base_is_equal(static_cast(other)) && + copy_done_schedule_before() == other.copy_done_schedule_before() && + copy_start_schedule_after() == other.copy_start_schedule_after() && + copy_start() == other.copy_start() && copy_done() == other.copy_done(); +} + +bool CopyAllocation::operator==(const Allocation& other) const { + const CopyAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +void CopyAllocation::set_copy_start_schedule_after( + int64_t copy_start_schedule_after) { + copy_start_schedule_after_ = copy_start_schedule_after; +} + +void CopyAllocation::set_copy_done_schedule_before( + int64_t copy_done_schedule_before) { + copy_done_schedule_before_ = copy_done_schedule_before; +} + +int64_t CopyAllocation::copy_start_schedule_after() const { + return copy_start_schedule_after_; +} + +int64_t CopyAllocation::copy_done_schedule_before() const { + return copy_done_schedule_before_; +} + +SlicedCopyAllocation::SlicedCopyAllocation( + const Allocation& prev_allocation, MemorySpace memory_space, + std::vector slice_decisions_sorted_by_exclusive_start_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + const SlicedPrefetchOptions& sliced_prefetch_options, + absl::FunctionRef get_equivalent_s8_shape_fn) + : Allocation( + /*defining_position=*/{nullptr, {}}, memory_space, + GetSlicedCopyAllocationChunk( + slice_decisions_sorted_by_exclusive_start_time), + // Allocation uses an inclusive start time + ExclusiveToInclusiveStartTime( + GetSlicedCopyAllocationExclusiveStartTime( + slice_decisions_sorted_by_exclusive_start_time)), + end_time, + /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_shape_to_slice_(prev_allocation.defining_position().shape()), + prev_allocation_(prev_allocation), + sliced_prefetch_options_(sliced_prefetch_options), + get_equivalent_s8_shape_fn_(get_equivalent_s8_shape_fn) { + CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2); + slice_details_sorted_by_exclusive_start_time_.reserve( + slice_decisions_sorted_by_exclusive_start_time.size()); + for (SliceDecision& decision : + slice_decisions_sorted_by_exclusive_start_time) { + int64_t copy_done_schedule_after_time = decision.exclusive_start_time; + slice_details_sorted_by_exclusive_start_time_.push_back(SliceDetail{ + std::move(decision), + copy_done_schedule_after_time, + copy_done_schedule_before_time, + /*copy_start=*/nullptr, + /*copy_done=*/nullptr, + }); + } +} + +Status SlicedCopyAllocation::Process() { + Shape shape = defining_position().shape(); + HloInstruction* producing_instruction = AddGetTupleElements(); + + // Calling Process() over the previous allocation might have modified the + // defining position, and hence the shape that was used when we computed + // the slices. In cases where the shape has changed, we insert a bitcast, so + // slice instructions operate on the originally sliced shape. + // + // Note, these bitcasts are being inserted in the same cases that + // UpdateUses() is inserting bitcasts, except we are + // inserting the bitcasts before the copy, instead of after the copy. + if (!Shape::Equal().IgnoreMemorySpaceInLayout()(shape, + original_shape_to_slice_)) { + int64_t new_memory_space = shape.layout().memory_space(); + shape = original_shape_to_slice_; + shape.mutable_layout()->set_memory_space(new_memory_space); + producing_instruction = producing_instruction->parent()->AddInstruction( + HloInstruction::CreateBitcast(shape, producing_instruction)); + } + + HloComputation* computation = producing_instruction->parent(); + std::vector slice_dones; + slice_dones.reserve(slice_details_sorted_by_exclusive_start_time_.size()); + + // If we are trying to make all slices a uniform size, we bitcast the + // producing instruction to an array of bytes, so it is easy to slice into any + // size. + Shape slice_shape = shape; + if (IsUniformSliceSizingEnabled(sliced_prefetch_options_)) { + slice_shape = get_equivalent_s8_shape_fn_(shape); + producing_instruction = producing_instruction->parent()->AddInstruction( + HloInstruction::CreateBitcast(slice_shape, producing_instruction)); + } + + // Sliced copy allocations need to insert asynchronous copy nodes. + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + TF_RETURN_IF_ERROR(slice_detail.CreateAsyncSlice( + slice_shape, *producing_instruction, *computation)); + VLOG(4) << "Created " << slice_detail.copy_start->name() + << " for sliced copy allocation: " << ToString(); + slice_dones.push_back(slice_detail.copy_done); + } + + TF_RETURN_IF_ERROR(CreateBitcastConcat(shape, slice_dones)); + + // If we bitcast to an array of bytes above, the result of the concatenated + // slices will also be an array of bytes. Thus, we need to cast the + // concatentation back to the original shape. + if (IsUniformSliceSizingEnabled(sliced_prefetch_options_)) { + concat_ = concat_->parent()->AddInstruction( + HloInstruction::CreateBitcast(shape, concat_)); + } + + // Update the allocation position with the copy complete instruction, so that + // if there are further copies from it, they can find the correct position. + set_original_defining_position(HloPosition{concat_, {}}); + return UpdateUses(computation, concat_); +} + +void SlicedCopyAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void SlicedCopyAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + prev_allocation_.MarkNeeded(needed_allocations); +} + +HloPosition SlicedCopyAllocation::defining_position() const { + // Unless explicitly set, the defining position of a sliced copy allocation is + // retrieved from the previous allocation. This is because we don't create + // new CopyStart/CopyDone instructions until later and the position should + // point to the previous (copy or otherwise) allocation's position for the + // original defining position. + HloPosition defining_position = original_defining_position(); + if (defining_position.instruction == nullptr) { + return prev_allocation_.defining_position(); + } + return defining_position; +} + +int64_t SlicedCopyAllocation::earliest_available_time() const { + return slice_details_sorted_by_start_time().back().copy_done_before_time; +} + +std::vector SlicedCopyAllocation::SliceOffsetsSortedByStartTime() + const { + std::vector offsets; + offsets.reserve(slice_details_sorted_by_exclusive_start_time_.size()); + + for (const SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + offsets.push_back(slice_detail.slice_decision.chunk.offset); + } + + return offsets; +} + +void SlicedCopyAllocation::AddDiffToAllSliceOffsets(int64_t diff) { + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + HeapSimulator::Chunk& chunk = slice_detail.slice_decision.chunk; + chunk = + HeapSimulator::Chunk::FromOffsetSize(chunk.offset + diff, chunk.size); + } +} + +void SlicedCopyAllocation::ImportRepackedSliceData( + const SlicedAllocationData& data) { + int num_slices = slice_details_sorted_by_exclusive_start_time_.size(); + CHECK_EQ(data.slices_sorted_by_offset.size(), num_slices); + + std::vector slice_details_sorted_by_offset; + slice_details_sorted_by_offset.reserve(num_slices); + for (SliceDetail& slice_detail : + slice_details_sorted_by_exclusive_start_time_) { + slice_details_sorted_by_offset.push_back(&slice_detail); + } + absl::c_sort(slice_details_sorted_by_offset, [](const SliceDetail* lhs, + const SliceDetail* rhs) { + return lhs->slice_decision.chunk.offset < rhs->slice_decision.chunk.offset; + }); + + for (int i = 0; i < num_slices; ++i) { + SliceDetail* slice_detail = slice_details_sorted_by_offset[i]; + HeapSimulator::Chunk& chunk = slice_detail->slice_decision.chunk; + const AllocatedSlice& repacked_slice_data = data.slices_sorted_by_offset[i]; + chunk = HeapSimulator::Chunk::FromOffsetSize(repacked_slice_data.offset, + chunk.size); + slice_detail->copy_start_after_time = + repacked_slice_data.inclusive_start_time - 1; + slice_detail->slice_decision.exclusive_start_time = + InclusiveToExclusiveStartTime(repacked_slice_data.inclusive_start_time); + } + + absl::c_sort(slice_details_sorted_by_exclusive_start_time_, + [](const SliceDetail& lhs, const SliceDetail& rhs) { + return std::make_tuple(lhs.copy_start_after_time, + lhs.slice_decision.chunk.offset) < + std::make_tuple(rhs.copy_start_after_time, + rhs.slice_decision.chunk.offset); + }); +} + +const std::vector& +SlicedCopyAllocation::slice_details_sorted_by_start_time() const { + return slice_details_sorted_by_exclusive_start_time_; +} + +std::vector& +SlicedCopyAllocation::mutable_slice_details_sorted_by_start_time() { + return slice_details_sorted_by_exclusive_start_time_; +} + +bool SlicedCopyAllocation::operator==(const SlicedCopyAllocation& other) const { + return this->base_is_equal(static_cast(other)) && + slice_details_sorted_by_exclusive_start_time_ == + other.slice_details_sorted_by_exclusive_start_time_ && + concat_ == other.concat_; +} + +std::string SlicedCopyAllocation::ToString() const { + std::string memory_space_str = "def"; + if (memory_space() == MemorySpace::kAlternate) { + memory_space_str = absl::StrCat("alt (off: ", maybe_chunk()->offset, ")"); + } + return absl::StrCat( + "Sliced Copy Allocation in ", memory_space_str, + ", start_time:", start_time(), ", end_time:", end_time(), + ", first_slice_copy_start_after_time: ", + slice_details_sorted_by_start_time().front().copy_start_after_time, + ", last_slice_copy_done_before_time: ", + slice_details_sorted_by_start_time().back().copy_done_before_time, + ", uses: ", UsesToString(uses()), ", from ", prev_allocation_.ToString()); +} + +Status SlicedCopyAllocation::CreateBitcastConcat( + const Shape& shape, absl::Span slices) { + CHECK(!slices.empty()); + concat_ = + slices.front()->parent()->AddInstruction(HloInstruction::CreateCustomCall( + shape, slices, + xla::memory_space_assignment::kConcatBitcastCustomCall)); + return OkStatus(); +} + +std::string SlicedCopyAllocation::SliceDetail::ToString() const { + return absl::StrCat("{ slice_decision: ", slice_decision.ToString(), + ", copy_start_after_time: ", copy_start_after_time, + ", copy_done_before_time: ", copy_done_before_time, " }"); +} + +std::tuple +SliceDetailToTuple(const SlicedCopyAllocation::SliceDetail& slice_detail) { + return std::make_tuple(std::ref(slice_detail.slice_decision), + slice_detail.copy_start_after_time, + slice_detail.copy_done_before_time, + slice_detail.copy_start, slice_detail.copy_done); +} + +bool SlicedCopyAllocation::SliceDetail::operator==( + const SliceDetail& other) const { + return SliceDetailToTuple(*this) == SliceDetailToTuple(other); +} + +Status SlicedCopyAllocation::SliceDetail::CreateAsyncSlice( + const Shape& original_shape, HloInstruction& producer, + HloComputation& parent) { + if (original_shape.rank() != slice_decision.sizing.slice_params.size()) { + return FailedPrecondition( + "%s", absl::StrCat("The number of SlicedCopyAllocation parameters ", + slice_decision.sizing.slice_params.size(), + " does not match the rank ", original_shape.rank(), + " of the tensor we are slicing.")); + } + + std::vector start_indices; + start_indices.reserve(slice_decision.sizing.slice_params.size()); + std::vector limit_indices; + limit_indices.reserve(slice_decision.sizing.slice_params.size()); + std::vector strides; + strides.reserve(slice_decision.sizing.slice_params.size()); + + for (int i = 0; i < slice_decision.sizing.slice_params.size(); ++i) { + const SliceParam& slice_param = slice_decision.sizing.slice_params[i]; + start_indices.push_back(slice_param.start_inclusive); + limit_indices.push_back(slice_param.end_exclusive); + strides.push_back(1); + const int64_t new_dim = + slice_param.end_exclusive - slice_param.start_inclusive; + if (new_dim <= 0) { + return FailedPrecondition( + "%s", absl::StrCat("SlicedCopyAllocation new dimension size is ", + new_dim, ", expected something > 0.")); + } + if (original_shape.dimensions(i) < new_dim) { + return FailedPrecondition( + "%s", + absl::StrCat("SlicedCopyAllocation sliced dimension size ", new_dim, + " is bigger than its original dimension size of ", + original_shape.dimensions(i), ".")); + } + } + + HloInstruction* slice = parent.AddInstruction( + HloInstruction::CreateSlice(slice_decision.sizing.slice_shape, &producer, + start_indices, limit_indices, strides)); + TF_ASSIGN_OR_RETURN(copy_done, parent.CreateAsyncInstructions( + slice, {ShapeUtil::MakeShape(S32, {})})); + copy_start = copy_done->mutable_operand(0); + + return OkStatus(); +} + +bool SlicedCopyAllocation::operator==(const Allocation& other) const { + const SlicedCopyAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +HloPosition MirroredAllocation::defining_position() const { + return original_defining_position(); +} + +std::string MirroredAllocation::ToString() const { + return absl::StrCat("Mirrored Allocation for ", + original_allocation_.ToString()); +} + +std::string ParentAllocation::ToString() const { + return absl::StrCat("Parent Allocation mirrored at ", + original_defining_position().ToString(), ", originally ", + original_allocation_.ToString()); +} + +MirroredAllocation::MirroredAllocation(const Allocation& original_allocation, + int64_t time) + : Allocation(original_allocation.defining_position(), MemorySpace::kDefault, + original_allocation.maybe_chunk(), + /*start_time=*/time, + /*end_time=*/time, /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_allocation_(original_allocation) {} + +Status MirroredAllocation::Process() { + set_original_defining_position(original_allocation_.defining_position()); + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + return UpdateUses(computation, producing_instruction); +} + +ParentAllocation::ParentAllocation(const Allocation& original_allocation, + HloInstruction* calling_instruction, + HloPosition position, int64_t time) + : Allocation(std::move(position), MemorySpace::kDefault, + original_allocation.maybe_chunk(), + /*start_time=*/time, + /*end_time=*/time, /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + original_allocation_(original_allocation), + calling_instruction_(calling_instruction) {} + +HloPosition ParentAllocation::defining_position() const { + return original_defining_position(); +} + +Status ParentAllocation::Process() { + // Add an additional parameter to the while HLO with a reference to the buffer + // in the default memory space. + HloInstruction* producing_instruction = + original_allocation_.AddGetTupleElements(); + int new_tuple_index = calling_instruction_->shape().tuple_shapes_size(); + + TF_ASSIGN_OR_RETURN( + HloInstruction * new_while_operand, + TupleUtil::ReplaceTupleWith(producing_instruction, + calling_instruction_->mutable_operand(0), + {new_tuple_index})); + TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape( + 0, new_while_operand)); + *calling_instruction_->mutable_shape() = new_while_operand->shape(); + *calling_instruction_->while_condition() + ->parameter_instruction(0) + ->mutable_shape() = new_while_operand->shape(); + *calling_instruction_->while_body() + ->parameter_instruction(0) + ->mutable_shape() = new_while_operand->shape(); + HloPosition defining_position = original_defining_position(); + defining_position.index = {new_tuple_index}; + set_original_defining_position(defining_position); + // Also replace the while op with a tuple that has the old shape. Note that we + // need to first take a snapshot of the users before calling ExtractPrefix + // since ExtractPrefix introduces additional gte users. + std::vector while_users = calling_instruction_->users(); + HloInstruction* tuple_with_old_shape = + TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index); + TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape( + while_users, tuple_with_old_shape)); + + if (is_scoped_allocation()) { + // Nothing to do here for scoped allocations. + return OkStatus(); + } + HloInstruction* final_instruction = AddGetTupleElements(); + HloComputation* computation = final_instruction->parent(); + return UpdateUses(computation, final_instruction); +} + +Status ParentAllocation::PostProcess() { + // Update the root of the while body with the new parameter. The reason why we + // need a separate post-process for this is because other allocations may have + // while body root as a use, so they would update the old root instead of the + // new root. Doing the post-process step later ensures the root has been + // updated with other changes, and we can safely add the additional parameter. + HloComputation* while_body = calling_instruction_->while_body(); + TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root, + TupleUtil::ReplaceTupleWith( + AddGetTupleElements(), while_body->root_instruction(), + original_defining_position().index)); + while_body->set_root_instruction(new_while_body_root, + /*accept_different_shape=*/true); + return OkStatus(); +} + +void ParentAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + // Parent allocations are only needed if they have any uses or if there is a + // copy allocation that copies this value (in that case, the copy allocation + // will call this allocation's MarkNeeded function). + if (!has_no_uses()) { + MarkNeeded(needed_allocations); + } +} + +void ParentAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + original_allocation_.MarkNeeded(needed_allocations); +} + +bool ParentAllocation::operator==(const Allocation& other) const { + const ParentAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +void MirroredAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void MirroredAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + original_allocation_.MarkNeeded(needed_allocations); +} + +bool MirroredAllocation::operator==(const Allocation& other) const { + const MirroredAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + +} // namespace xla::memory_space_assignment diff --git a/xla/service/memory_space_assignment/allocation.h b/xla/service/memory_space_assignment/allocation.h new file mode 100644 index 0000000000000..e9977dda40ab8 --- /dev/null +++ b/xla/service/memory_space_assignment/allocation.h @@ -0,0 +1,449 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/shape.h" +#include "xla/status.h" + +namespace xla::memory_space_assignment { + +// MemorySpaceAssignment uses a notion of a slow and large default memory +// space and a fast and small alternate memory space. +enum class MemorySpace : std::uint8_t { kDefault, kAlternate }; + +// An interface describing what to do with a value in memory over its lifetime. +// An allocation might either be placed in the default or alternate memory. An +// HloValue might live in multiple different allocations over its lifetime. The +// lifetimes of the allocations are defined using start_time and end_time, which +// corresponds to the instruction indexes in the flattened schedule. Each of +// these allocations might partially overlap with each other. +// +// Consider an instruction Foo, and its users Bar and Baz, and the times given +// in terms of the flattened schedule of the entire module: +// +// Foo:10 +// / \ +// Bar:14 \ +// Baz:25 +// +// A valid memory space assignment could be like the following: +// +// Time: 10 ... 14 ... 25 +// Foo Bar Baz +// Alternate +-------+ +-----+ +// Default +---------------------+ +// ^ ^ ^ ^ +// | | | | +// evict evict prefetch prefetch +// start end start end +// +// This would be represented with: +// - PinnedAllocation(memory_space=kAlternate, start_time=10, end_time=14) +// - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) +// - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) +class Allocation { + public: + virtual ~Allocation() = default; + + // Allocation source methods + // -------------------------------------------------------------------------- + // Returns the defining position for this allocation. + virtual HloPosition defining_position() const = 0; + // Returns the cross-program prefetch index for this allocation. + std::optional cross_program_prefetch_index() const; + + // Allocation timing methods + // -------------------------------------------------------------------------- + // TODO(cl/604356742): update all timing methods to explicitly state that + // they're representing inclusive intervals. + int64_t start_time() const { return start_time_; } + int64_t end_time() const { return end_time_; } + // Returns the time the buffer is first available to be used + virtual int64_t earliest_available_time() const = 0; + void set_start_time(int64_t start_time) { start_time_ = start_time; } + void set_end_time(int64_t end_time) { end_time_ = end_time; } + // Extends the end time of this allocation. + void Extend(int64_t end_time) { end_time_ = std::max(end_time_, end_time); } + + // Allocation space methods + // -------------------------------------------------------------------------- + MemorySpace memory_space() const { return memory_space_; } + // Returns the associated chunk that may be a nullopt if the allocation is + // in the default memory space. + std::optional maybe_chunk() const { return chunk_; } + // Returns the associated chunk. The caller should ensure that the chunk is + // defined (the allocation should be in the alternate memory space). + HeapSimulator::Chunk chunk() const; + HeapSimulator::Chunk* mutable_chunk() { return &*chunk_; } + void set_offset(int64_t offset); + bool is_scoped_allocation() const { return is_scoped_allocation_; } + // Returns true if the allocation is in the alternate memory space. + bool is_in_alternate_mem() const; + // Returns true if the allocation is in the default memory space. + bool is_in_default_mem() const; + + // Use methods + // -------------------------------------------------------------------------- + const std::vector& uses() const { return uses_; } + void clear_uses() { uses_.clear(); } + bool has_no_uses() const { return uses_.empty(); } + // Adds a use to this allocation. + void AddUse(HloUse use); + // Replaces all uses of the allocation with the copy_complete instruction. + Status UpdateUses(HloComputation* computation, + HloInstruction* producing_instruction); + + // Allocation type methods + // -------------------------------------------------------------------------- + virtual bool is_copy_allocation() const = 0; + virtual bool is_sliced_copy_allocation() const = 0; + // True if the allocation is for a copy or a sliced-copy. + bool is_copy_like_allocation() const; + + // Processing methods + // -------------------------------------------------------------------------- + // Recursively create kGetTupleElement instructions if the defining position + // shape is not an array. Returns the new instruction that has array shape. + HloInstruction* AddGetTupleElements() const; + // After all of the time ranges for the allocations have been assigned, + // Process morphs the instructions affected to assign the memory spaces and + // insert asynchronous copy instructions if necessary. + virtual Status Process() = 0; + // An optional post-process step that will be called after all allocations + // have been processed. + virtual Status PostProcess() = 0; + // Marks (adds this allocation to needed_allocations) if this allocation is + // needed. PinnedAllocation and CopyAllocations are always needed and + // ParentAllocations are needed if they have any uses or if other + // CopyAllocation or ParentAllocations depend on them. + virtual void MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const = 0; + // Marks this allocation as needed. + virtual void MarkNeeded( + absl::flat_hash_set& needed_allocations) const = 0; + + // Utility methods + // -------------------------------------------------------------------------- + virtual std::string ToString() const = 0; + virtual bool operator==(const Allocation& other) const = 0; + + protected: + // Protected constructor to encourage use of the final subclasses (e.g., + // PinnedAllocation, CopyAllocation, etc.). + Allocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, int64_t start_time, + int64_t end_time, bool is_scoped_allocation, + std::optional cross_program_prefetch_index); + + // Returns the original defining position of this allocation. + HloPosition original_defining_position() const; + // Sets the original defining position of this allocation. + void set_original_defining_position(HloPosition defining_position); + bool base_is_equal(const Allocation& other) const; + + private: + HloPosition original_defining_position_; + MemorySpace memory_space_; + std::optional chunk_; + int64_t start_time_; + int64_t end_time_; + const bool is_scoped_allocation_; + std::vector uses_; + std::optional cross_program_prefetch_index_; +}; + +using AllocationSequence = std::vector>; + +// This class represents an allocation that pins a tensor to +// a specific memory space. +class PinnedAllocation final : public Allocation { + public: + PinnedAllocation(HloPosition defining_position, MemorySpace memory_space, + std::optional chunk, + int64_t start_time, int64_t end_time, + bool is_scoped_allocation); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const PinnedAllocation& other) const; +}; + +// This class represents an allocation as a result of an asynchronous copy. +// Note: CopyStart instructions are inserted after +// `copy_start_schedule_after`, while CopyDone instructions are inserted +// before `copy_done_schedule_before_time`. +class CopyAllocation final : public Allocation { + public: + // TODO(b/307342076): Reorder scheduling times to be + // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time + CopyAllocation( + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, + int64_t copy_start_schedule_after_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + std::optional cross_program_prefetch_index = std::nullopt); + + // Overridden methods + // + HloPosition defining_position() const override; + // Returns the time the buffer is first available to be used. For + // CopyAllocation, this is when the copy ends, which is + // copy_done_schedule_before. + int64_t earliest_available_time() const override; + bool is_copy_allocation() const override { return true; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const CopyAllocation& other) const; + + const Allocation& prev_allocation() { return prev_allocation_; } + Allocation& mutable_prev_allocation() { return prev_allocation_; } + + HloInstruction* copy_start() const { return copy_start_; } + HloInstruction* copy_done() const { return copy_done_; } + + void set_copy_start_schedule_after(int64_t copy_start_schedule_after); + void set_copy_done_schedule_before(int64_t copy_done_schedule_before); + int64_t copy_start_schedule_after() const; + int64_t copy_done_schedule_before() const; + + private: + Allocation& prev_allocation_; + // These variables define the scheduling boundaries where CopyStart and + // CopyDone can be scheduled. The earliest CopyStart can be scheduled is + // after copy_start_schedule_after_ and the latest CopyDone can be scheduled + // is before copy_done_schedule_before_. + int64_t copy_start_schedule_after_; + int64_t copy_done_schedule_before_; + HloInstruction* copy_start_ = nullptr; + HloInstruction* copy_done_ = nullptr; +}; + +// This class represents an allocation resulting from asynchronous sliced +// copies. +// +// Let the sliced allocation be represented as follows, and imagine that t3 +// is the time when the entire buffer [p0, p3) is available for use +// +// space +// ^ +// p3 | +-----------+ +// | | | +// p2 | +---+ | +// | | | +// p1 | +-------+ | +// | | | +// p0 | +-------+ +// +---|---|---|---|---|----> time +// t0 t1 t2 t3 t4 +// +// The PinnedAllocation underlying the SlicedCopyAllocation will use the +// following dimensions: +// - chunk = [p0, p3) +// - start time = t2 +// - earliest_available_time = t3 +// - end_time = t4 +class SlicedCopyAllocation final : public Allocation { + public: + // Full details about a slice in the sliced allocation. + struct SliceDetail { + std::string ToString() const; + std::tuple + ToTuple() const; + bool operator==(const SliceDetail& other) const; + + // Create the instructions to copy the slice. This method updates + // copy_start and copy_done. + Status CreateAsyncSlice(const Shape& original_shape, + HloInstruction& producer, HloComputation& parent); + + SliceDecision slice_decision; + int64_t copy_start_after_time = -1; + int64_t copy_done_before_time = -1; + HloInstruction* copy_start = nullptr; + HloInstruction* copy_done = nullptr; + }; + + // REQUIRES: + // - slice_decisions_sorted_by_exclusive_start_time.size() >= 2, otherwise, + // CopyAllocation should be used. + SlicedCopyAllocation( + const Allocation& prev_allocation, MemorySpace memory_space, + std::vector slice_decisions_sorted_by_exclusive_start_time, + int64_t copy_done_schedule_before_time, int64_t end_time, + const SlicedPrefetchOptions& sliced_prefetch_options, + absl::FunctionRef get_equivalent_s8_shape_fn); + + // Overridden methods + // + HloPosition defining_position() const override; + // Returns the time the buffer is first available to be used. For + // SlicedCopyAllocation, this is when all copies have ended. + int64_t earliest_available_time() const override; + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return true; } + // MemorySpaceAssignment::Process() calls Process() to create asynchronous + // slice copies, and a bitcast-concat call to glue the slices back together. + Status Process() override; + Status PostProcess() override { return OkStatus(); } + // Marks the allocation as needed. + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const SlicedCopyAllocation& other) const; + + std::vector SliceOffsetsSortedByStartTime() const; + void AddDiffToAllSliceOffsets(int64_t diff); + // Used to update offsets and start times after repacking. + void ImportRepackedSliceData(const SlicedAllocationData& data); + const std::vector& slice_details_sorted_by_start_time() const; + std::vector& mutable_slice_details_sorted_by_start_time(); + HloInstruction* concat() const { return concat_; } + + private: + SlicedCopyAllocation() = delete; + + // Create an instruction to concatenate the slices. Populates concat_. + Status CreateBitcastConcat(const Shape& shape, + absl::Span slices); + + Shape original_shape_to_slice_; + const Allocation& prev_allocation_; + // REQUIRES: + // - sorted_segments_[i].copy_start_after_time <= + // sorted_segments_[i+j].copy.start_after_time + // - sorted_segments_[i].copy_done_before_time <= + // sorted_segments_[i+j].copy.start_before_time + std::vector slice_details_sorted_by_exclusive_start_time_; + HloInstruction* concat_ = nullptr; + const SlicedPrefetchOptions& sliced_prefetch_options_; + absl::FunctionRef get_equivalent_s8_shape_fn_; +}; + +// An allocation in the default memory space that mirrors another Allocation +// object. This is useful to model an eviction that happens before a while op +// so that we don't need to redundantly evict the buffer after the while op as +// well. +class MirroredAllocation final : public Allocation { + public: + MirroredAllocation(const Allocation& original_allocation, int64_t time); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override { return OkStatus(); } + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const MirroredAllocation& other) const; + + private: + const Allocation& original_allocation_; +}; + +// An allocation in default memory space that is defined in the parent +// computation. If a value has a copy in the default memory space in the +// parent computation, we don't need to evict this buffer in a while loop. +class ParentAllocation final : public Allocation { + public: + ParentAllocation(const Allocation& original_allocation, + HloInstruction* calling_instruction, HloPosition position, + int64_t time); + + // Overridden methods + // + // Returns the original defining position. + HloPosition defining_position() const override; + int64_t earliest_available_time() const override { return start_time(); } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + Status Process() override; + Status PostProcess() override; + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const Allocation& other) const override; + + // New non-virtual methods + bool operator==(const ParentAllocation& other) const; + + private: + const Allocation& original_allocation_; + HloInstruction* calling_instruction_; +}; + +} // namespace xla::memory_space_assignment + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ diff --git a/xla/service/memory_space_assignment/best_fit_repacker.cc b/xla/service/memory_space_assignment/best_fit_repacker.cc index 7d2aa7af63196..e99e4ed085c28 100644 --- a/xla/service/memory_space_assignment/best_fit_repacker.cc +++ b/xla/service/memory_space_assignment/best_fit_repacker.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,24 +127,21 @@ Step 5: Update AllocationBlocks with the repacking placements #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/comparison_util.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/statusor.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { namespace { -using AllocationBlock = - memory_space_assignment::MemorySpaceAssignmentRepacker::AllocationBlock; -using Type = GlobalDecreasingSizeBestFitHeap::Type; -using SlicedAllocationData = memory_space_assignment:: - MemorySpaceAssignmentRepacker::SlicedAllocationData; -using Slice = memory_space_assignment::MemorySpaceAssignmentRepacker::Slice; - bool IsSliced(const AllocationBlock* block) { return block->original_slice_data.has_value(); } @@ -164,6 +161,14 @@ std::vector SortAllocationBlocks(const T& container) { return result; } +const SlicedAllocationData* GetSlicedAllocationDataPointer( + const std::optional& sliced_allocation_data) { + if (!sliced_allocation_data.has_value()) { + return nullptr; + } + return &(*sliced_allocation_data); +} + // A slice-aware best-fit repacker. class BestFitRepacker : public GlobalDecreasingSizeBestFitHeap { @@ -171,11 +176,13 @@ class BestFitRepacker BestFitRepacker( const memory_space_assignment::MemorySpaceAssignmentBestFitRepacker:: BestFitRepackOptions& options, + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type, int64_t max_size, int64_t alignment) : GlobalDecreasingSizeBestFitHeap( alignment, kCustom, (options.buffer_interval_compare ? options.buffer_interval_compare - : DefaultBufferIntervalCompare())), + : DefaultBufferIntervalCompare()), + slice_time_permutation_iterator_type), validate_(options.validate), max_size_(max_size) {} @@ -189,14 +196,17 @@ class BestFitRepacker for (AllocationBlock* allocation_block : allocation_blocks_) { // Check if any of the colocations are already added to buffer_intervals_. bool need_allocation = true; - auto aliased_it = absl::c_find_if( - allocation_block->colocations, [&](AllocationBlock* search) { - return full_buffer_interval_map_.contains(search); - }); - if (aliased_it != allocation_block->colocations.end()) { - full_buffer_interval_map_[*aliased_it].colocations.push_back( - allocation_block); - need_allocation = false; + CHECK_NE(allocation_block->next_colocated, nullptr); + for (AllocationBlock* colocated = allocation_block->next_colocated; + colocated != allocation_block; + colocated = colocated->next_colocated) { + auto aliased_it = full_buffer_interval_map_.find(colocated); + if (aliased_it != full_buffer_interval_map_.end() && + aliased_it->second.need_allocation) { + aliased_it->second.colocations.push_back(allocation_block); + need_allocation = false; + break; + } } full_buffer_interval_map_.insert( std::make_pair(allocation_block, @@ -351,10 +361,10 @@ class BestFitRepacker new_offset = (new_offset == -1 ? chunk.offset : std::min(new_offset, chunk.offset)); repacked_slice_data->slices_sorted_by_offset.push_back( - Slice({chunk.size, chunk.offset, start_time})); + AllocatedSlice({chunk.size, chunk.offset, start_time})); } absl::c_sort(repacked_slice_data->slices_sorted_by_offset, - [](const Slice& lhs, const Slice& rhs) { + [](const AllocatedSlice& lhs, const AllocatedSlice& rhs) { return lhs.offset < rhs.offset; }); } else { @@ -408,9 +418,14 @@ class BestFitRepacker SlicedBufferInterval& colocation_sliced_buffer_interval = sliced_buffer_interval_map_.at(colocation); SlicedAllocationFinder sliced_colocation_finder = - CreateSlicedAllocationFinder(colocation_sliced_buffer_interval, - max_colocation_size, - /*preferred_offset=*/-1); + CreateSlicedAllocationFinder( + colocation_sliced_buffer_interval, max_colocation_size, + /*preferred_offset=*/-1, + SliceTimePermutationIterator::CreateForRepack( + slice_time_permutation_iterator_type(), + GetSlicedAllocationDataPointer( + colocation->original_slice_data)), + &SlicedAllocationFinder::AllOffsetsAllowed); sliced_buffer_map.insert(std::make_pair( colocation, SlicedColocationData{&colocation_sliced_buffer_interval, @@ -444,6 +459,10 @@ class BestFitRepacker // Find chunks for allocation_block and its colocations. SlicedAllocationFinder finder = CreateSlicedAllocationFinder( sliced_buffer_interval, max_colocation_size, /*preferred_offset=*/-1, + SliceTimePermutationIterator::CreateForRepack( + slice_time_permutation_iterator_type(), + GetSlicedAllocationDataPointer( + allocation_block->original_slice_data)), is_offset_allowed); std::vector chunks = PostProcessFindChunkCandidatesResult( sliced_buffer_interval, finder.Find()); @@ -476,7 +495,7 @@ class BestFitRepacker LOG(FATAL) << "We should never get here."; } - Result Finish() override { + absl::StatusOr Finish() override { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -519,7 +538,7 @@ class BestFitRepacker for (int i = 0; i < block->repacked_slice_data->slices_sorted_by_offset.size(); ++i) { - const Slice& slice = + const AllocatedSlice& slice = block->repacked_slice_data->slices_sorted_by_offset[i]; timed_chunks.push_back( TimedChunk{absl::StrCat(((int64_t)block), "_slice_", i), block, @@ -552,7 +571,7 @@ class BestFitRepacker } bool Repack() { - Finish(); + TF_CHECK_OK(Finish().status()); bool success = result_.heap_size <= max_size_; if (!success) { VLOG(1) << "Repacking unsuccessful with heap size " << result_.heap_size; @@ -613,10 +632,10 @@ class BestFitRepacker namespace memory_space_assignment { -StatusOr MemorySpaceAssignmentBestFitRepacker::Repack( +absl::StatusOr MemorySpaceAssignmentBestFitRepacker::Repack( absl::Span allocations) { - BestFitRepacker best_fit_repacker = - BestFitRepacker(options_, max_size_, alignment_); + BestFitRepacker best_fit_repacker = BestFitRepacker( + options_, slice_time_permutation_iterator_type_, max_size_, alignment_); best_fit_repacker.ImportAllocationBlocks(allocations); return best_fit_repacker.Repack(); } diff --git a/xla/service/memory_space_assignment/best_fit_repacker.h b/xla/service/memory_space_assignment/best_fit_repacker.h index eb69edb1ce93c..816031d383359 100644 --- a/xla/service/memory_space_assignment/best_fit_repacker.h +++ b/xla/service/memory_space_assignment/best_fit_repacker.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,8 @@ limitations under the License. #include #include "absl/types/span.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/statusor.h" @@ -45,17 +46,28 @@ class MemorySpaceAssignmentBestFitRepacker BufferIntervalCompare buffer_interval_compare = nullptr; }; - MemorySpaceAssignmentBestFitRepacker(int64_t max_size, int64_t alignment) + MemorySpaceAssignmentBestFitRepacker( + int64_t max_size, int64_t alignment, + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type) : MemorySpaceAssignmentRepacker(max_size, alignment), - options_(BestFitRepackOptions()) {} - MemorySpaceAssignmentBestFitRepacker(int64_t max_size, int64_t alignment, - BestFitRepackOptions options) - : MemorySpaceAssignmentRepacker(max_size, alignment), options_(options) {} + options_(BestFitRepackOptions()), + slice_time_permutation_iterator_type_( + slice_time_permutation_iterator_type) {} + MemorySpaceAssignmentBestFitRepacker( + int64_t max_size, int64_t alignment, + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type, + BestFitRepackOptions options) + : MemorySpaceAssignmentRepacker(max_size, alignment), + options_(std::move(options)), + slice_time_permutation_iterator_type_( + slice_time_permutation_iterator_type) {} - StatusOr Repack(absl::Span allocations) override; + absl::StatusOr Repack( + absl::Span allocations) override; private: BestFitRepackOptions options_; + SliceTimePermutationIterator::Ty slice_time_permutation_iterator_type_; }; } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/best_fit_repacker_test.cc b/xla/service/memory_space_assignment/best_fit_repacker_test.cc index 4ff2e55e6fc1f..ac631841d3bda 100644 --- a/xla/service/memory_space_assignment/best_fit_repacker_test.cc +++ b/xla/service/memory_space_assignment/best_fit_repacker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/memory_space_assignment/repacking.h" #include "tsl/platform/test.h" @@ -27,27 +29,17 @@ namespace xla { class MemorySpaceAssignmentBestFitRepackerTest : public ::testing::Test { protected: - using AllocationBlock = - memory_space_assignment::MemorySpaceAssignmentRepacker::AllocationBlock; - using SlicedAllocationData = memory_space_assignment:: - MemorySpaceAssignmentRepacker::SlicedAllocationData; - using Slice = memory_space_assignment::MemorySpaceAssignmentRepacker::Slice; - - MemorySpaceAssignmentBestFitRepackerTest() : repacker_(100, 1, options_) {} + MemorySpaceAssignmentBestFitRepackerTest() + : repacker_(100, 1, SliceTimePermutationIterator::Ty::kAll, options_) {} AllocationBlock* MakeAllocationBlock(int64_t start_time, int64_t end_time, int64_t size, int64_t initial_offset = -1) { allocation_blocks_.push_back( - {start_time, - end_time, - size, - -1, - initial_offset, - static_cast(allocation_blocks_.size()), - {}}); + {start_time, end_time, size, -1, initial_offset, + static_cast(allocation_blocks_.size())}); AllocationBlock* block = &allocation_blocks_.back(); - block->colocations.push_back(block); + block->next_colocated = block; return block; } @@ -73,8 +65,8 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, Colocation) { allocation_blocks.push_back(MakeAllocationBlock(0, 2, 10)); allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); // Allocation blocks 0 and 1 are colocated. - allocation_blocks[0]->colocations.push_back(allocation_blocks[1]); - allocation_blocks[1]->colocations.push_back(allocation_blocks[0]); + allocation_blocks[0]->next_colocated = allocation_blocks[1]; + allocation_blocks[1]->next_colocated = allocation_blocks[0]; allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); @@ -106,8 +98,8 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, ColocationDifferentSizes) { allocation_blocks.push_back(MakeAllocationBlock(0, 2, 5)); allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); // Allocation blocks 0 and 1 are colocated. - allocation_blocks[0]->colocations.push_back(allocation_blocks[1]); - allocation_blocks[1]->colocations.push_back(allocation_blocks[0]); + allocation_blocks[0]->next_colocated = allocation_blocks[1]; + allocation_blocks[1]->next_colocated = allocation_blocks[0]; allocation_blocks.push_back(MakeAllocationBlock(9, 11, 2)); allocation_blocks.push_back(MakeAllocationBlock(1, 2, 2)); EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); @@ -122,11 +114,10 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, RepackedSlicesFit) { // Expected repacking: // // space - // ^ - // 8 | - // 7 | +-----+ - // 6 | | E | - // 5 | +-------+-++ | + // ^ + // 7 | + // 6 | +-----+ + // 5 | +-------+-++ E| // 4 | | B |+--++--++--+ // 3 | | || || F|| | // 2 +----------+---++----++ |+--++ | @@ -142,16 +133,16 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, RepackedSlicesFit) { allocation_blocks.push_back(MakeAllocationBlock(11, 21, 3)); // Block C allocation_blocks.push_back(MakeAllocationBlock(16, 25, 4)); - allocation_blocks.back()->original_slice_data = - SlicedAllocationData({{Slice{2, -1, 16}, Slice{2, -1, 22}}}); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, -1, 16}, AllocatedSlice{2, -1, 22}}}); // Block D allocation_blocks.push_back(MakeAllocationBlock(26, 33, 4)); - allocation_blocks.back()->original_slice_data = - SlicedAllocationData({{Slice{2, -1, 26}, Slice{2, -1, 30}}}); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, -1, 26}, AllocatedSlice{2, -1, 30}}}); // Block E - allocation_blocks.push_back(MakeAllocationBlock(19, 25, 3)); - allocation_blocks.back()->original_slice_data = - SlicedAllocationData({{Slice{1, -1, 19}, Slice{2, -1, 22}}}); + allocation_blocks.push_back(MakeAllocationBlock(19, 25, 2)); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{1, -1, 19}, AllocatedSlice{1, -1, 22}}}); // Block F allocation_blocks.push_back(MakeAllocationBlock(26, 29, 2)); @@ -171,7 +162,7 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, RepackedSlicesFit) { return sort_keys.at(x.buffer); }); repacker_ = memory_space_assignment::MemorySpaceAssignmentBestFitRepacker( - 100, 1, options_); + 100, 1, SliceTimePermutationIterator::Ty::kAll, options_); EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); @@ -186,27 +177,170 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, RepackedSlicesFit) { EXPECT_EQ(allocation_blocks[2]->offset, 0); ASSERT_TRUE(allocation_blocks[2]->repacked_slice_data.has_value()); EXPECT_EQ(*allocation_blocks[2]->repacked_slice_data, - (SlicedAllocationData({{Slice{2, 0, 16}, Slice{2, 2, 22}}}))); + (SlicedAllocationData( + {{AllocatedSlice{2, 0, 16}, AllocatedSlice{2, 2, 22}}}))); // Block D EXPECT_EQ(allocation_blocks[3]->offset, 0); ASSERT_TRUE(allocation_blocks[3]->repacked_slice_data.has_value()); EXPECT_EQ(*allocation_blocks[3]->repacked_slice_data, - (SlicedAllocationData({{Slice{2, 0, 26}, Slice{2, 2, 30}}}))); + (SlicedAllocationData( + {{AllocatedSlice{2, 0, 26}, AllocatedSlice{2, 2, 30}}}))); // Block E EXPECT_EQ(allocation_blocks[4]->offset, 4); ASSERT_TRUE(allocation_blocks[4]->repacked_slice_data.has_value()); EXPECT_EQ(*allocation_blocks[4]->repacked_slice_data, - (SlicedAllocationData({{Slice{1, 4, 22}, Slice{2, 5, 19}}}))); + (SlicedAllocationData( + {{AllocatedSlice{1, 4, 22}, AllocatedSlice{1, 5, 19}}}))); // Block F EXPECT_EQ(allocation_blocks[5]->offset, 2); EXPECT_FALSE(allocation_blocks[5]->repacked_slice_data.has_value()); } +// Test that we do not permute slice start times in a way that changes the +// original slice size-start time mappings. Doing so breaks assumptions that +// MSA uses to construct its internal state prior to repacking. +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, + SliceTimePermutationsMatchOriginalSizeTimeMapping) { + // Original placement: Ideal repacking, but unsupported: + // + // space space + // ^ ^ + // 7 | +---------+ 7 | + // 6 | | C | 6 | +---------+ + // 5 | +-----+---+ 5 | | C | + // 4 | +-----+ | 4 | +---------+ + // 3 | | B | 3 | | B | + // 2 +----+----+----+ 2 +----+----++ | + // 1 | A | 1 | A |+---+ + // 0 +---------+ 0 +---------+ + // +----|----|----|----> time +----|----|----|----> time + // 0 5 10 15 0 5 10 15 + + std::vector allocation_blocks; + // Block A + allocation_blocks.push_back(MakeAllocationBlock(0, 10, 2, 0)); + // Block B + allocation_blocks.push_back(MakeAllocationBlock(5, 15, 3, 2)); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, 2, 5}, AllocatedSlice{1, 4, 11}}}); + // Block C + allocation_blocks.push_back(MakeAllocationBlock(5, 15, 2, 6)); + + // Specify the repacking sort order as the order in which blocks were added to + // allocation_blocks. We need to do this so that B is placed before C. If C + // is placed before B, C will sit directly on top of A, and the repacker would + // never try to permute B's slice size-start time mapping. + absl::flat_hash_map sort_keys; + for (int i = 0; i < allocation_blocks.size(); ++i) { + sort_keys[allocation_blocks[i]] = i; + } + options_.buffer_interval_compare = LessThanByKey( + [sort_keys](const memory_space_assignment:: + MemorySpaceAssignmentBestFitRepacker::BufferInterval& x) { + return sort_keys.at(x.buffer); + }); + repacker_ = memory_space_assignment::MemorySpaceAssignmentBestFitRepacker( + 100, 1, SliceTimePermutationIterator::Ty::kAll, options_); + + // The repacker returns true as long as the result fits in the max size, + // regardless of whether it has actually changed anything. + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + // Typically the heap_simulator would prefer to start Block B at a smaller + // offset, i.e., offset 1 rather than offset 2. However, in order to do so, + // the repacker would have to permute the original slice size-start time + // mapping, which is not permitted. Thus, we ensure that the repacked B's + // larger slice is assigned the smaller offset and earlier start time. + ASSERT_TRUE(allocation_blocks[1]->repacked_slice_data.has_value()); + ASSERT_EQ( + allocation_blocks[1]->repacked_slice_data->slices_sorted_by_offset.size(), + 2); + const AllocatedSlice& slice_with_smaller_offset = + allocation_blocks[1]->repacked_slice_data->slices_sorted_by_offset[0]; + const AllocatedSlice& slice_with_larger_offset = + allocation_blocks[1]->repacked_slice_data->slices_sorted_by_offset[1]; + // The larger slice is assigned to the smaller offset. + ASSERT_GT(slice_with_smaller_offset.size, slice_with_larger_offset.size); + const AllocatedSlice& larger_slice = slice_with_smaller_offset; + const AllocatedSlice& smaller_slice = slice_with_larger_offset; + // The larger slice is assigned to the earlier start time. + ASSERT_LT(larger_slice.inclusive_start_time, + smaller_slice.inclusive_start_time); +} + +// Test that we do not permute slice start times in a way that changes the +// original slice size-start time mappings. Doing so breaks assumptions that +// MSA uses to construct its internal state prior to repacking. +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, + SliceTimePermutationsMatchOriginalSizeTimeMapping2) { + // Original placement: New placement: + // + // space space + // ^ ^ + // 7 | 7 | + // 6 | +--------+ 6 | + // 5 | | B | 5 | +---------+ + // 4 | +-----+---+----+ 4 | | C | + // 3 | | C | 3 | +-----+ | + // 2 +----+----++ | 2 +---------++---+----+ + // 1 | A |+---+ 1 | A || B | + // 0 +---------+ 0 +---------++--------+ + // +----|----|----|----|--> time +----|----|----|----|--> time + // 0 5 10 15 20 0 5 10 15 20 + + std::vector allocation_blocks; + // Block A + allocation_blocks.push_back(MakeAllocationBlock(0, 10, 2, 0)); + // Block B + allocation_blocks.push_back(MakeAllocationBlock(11, 20, 2, 4)); + // Block C + allocation_blocks.push_back(MakeAllocationBlock(5, 15, 3, 1)); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{1, 1, 5}, AllocatedSlice{2, 2, 11}}}); + + // Specify the repacking sort order as the order in which blocks were added to + // allocation_blocks. We need to do this so that B is placed before C. + absl::flat_hash_map sort_keys; + for (int i = 0; i < allocation_blocks.size(); ++i) { + sort_keys[allocation_blocks[i]] = i; + } + options_.buffer_interval_compare = LessThanByKey( + [sort_keys](const memory_space_assignment:: + MemorySpaceAssignmentBestFitRepacker::BufferInterval& x) { + return sort_keys.at(x.buffer); + }); + repacker_ = memory_space_assignment::MemorySpaceAssignmentBestFitRepacker( + 100, 1, SliceTimePermutationIterator::Ty::kAll, options_); + + // The repacker returns true as long as the result fits in the max size, + // regardless of whether it has actually changed anything. + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + // Check results + // + // Typically the heap_simulator would prefer to start the first slice of + // Block C at time 5 and the second block at time 11, but that is not allowed + // because it permutes the original slice size-time mapping. + // + // Block A + EXPECT_EQ(allocation_blocks[0]->offset, 0); + EXPECT_FALSE(allocation_blocks[0]->repacked_slice_data.has_value()); + // Block B + EXPECT_EQ(allocation_blocks[1]->offset, 0); + EXPECT_FALSE(allocation_blocks[1]->repacked_slice_data.has_value()); + // Block C + EXPECT_EQ(allocation_blocks[2]->offset, 2); + ASSERT_TRUE(allocation_blocks[2]->repacked_slice_data.has_value()); + EXPECT_EQ(*allocation_blocks[2]->repacked_slice_data, + (SlicedAllocationData( + {{AllocatedSlice{1, 2, 5}, AllocatedSlice{2, 3, 11}}}))); +} + TEST_F(MemorySpaceAssignmentBestFitRepackerTest, SlicedColocationsFit) { // Expected repacking: // // space - // ^ + // ^ // 9 | +-+ // 8 | | | // 7 | |F|+-+ @@ -230,17 +364,15 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, SlicedColocationsFit) { allocation_blocks.push_back(MakeAllocationBlock(5, 11, 2)); // Block D allocation_blocks.push_back(MakeAllocationBlock(15, 20, 5)); - // Note, below we put the later start time first in the original slice data. - // This shouldn't make a difference to the repacker because it will map - // sizes to times as it sees fit. - allocation_blocks.back()->original_slice_data = - SlicedAllocationData({{Slice{2, -1, 18}, Slice{3, -1, 15}}}); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, -1, 15}, AllocatedSlice{3, -1, 18}}}); // Block E allocation_blocks.push_back(MakeAllocationBlock(9, 14, 4)); - allocation_blocks.back()->original_slice_data = - SlicedAllocationData({{Slice{2, -1, 9}, Slice{2, -1, 12}}}); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, -1, 9}, AllocatedSlice{2, -1, 12}}}); // Colocate E with D. - allocation_blocks.back()->colocations.push_back(allocation_blocks[3]); + allocation_blocks.back()->next_colocated = allocation_blocks[3]; + allocation_blocks[3]->next_colocated = allocation_blocks.back(); // Block F allocation_blocks.push_back(MakeAllocationBlock(15, 17, 5)); @@ -260,7 +392,7 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, SlicedColocationsFit) { return sort_keys.at(x.buffer); }); repacker_ = memory_space_assignment::MemorySpaceAssignmentBestFitRepacker( - 100, 1, options_); + 100, 1, SliceTimePermutationIterator::Ty::kAll, options_); EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); @@ -278,15 +410,99 @@ TEST_F(MemorySpaceAssignmentBestFitRepackerTest, SlicedColocationsFit) { EXPECT_EQ(allocation_blocks[3]->offset, 2); ASSERT_TRUE(allocation_blocks[3]->repacked_slice_data.has_value()); EXPECT_EQ(*allocation_blocks[3]->repacked_slice_data, - (SlicedAllocationData({{Slice{2, 2, 15}, Slice{3, 4, 18}}}))); + (SlicedAllocationData( + {{AllocatedSlice{2, 2, 15}, AllocatedSlice{3, 4, 18}}}))); // Block E EXPECT_EQ(allocation_blocks[4]->offset, 2); ASSERT_TRUE(allocation_blocks[4]->repacked_slice_data.has_value()); EXPECT_EQ(*allocation_blocks[4]->repacked_slice_data, - (SlicedAllocationData({{Slice{2, 2, 9}, Slice{2, 4, 12}}}))); + (SlicedAllocationData( + {{AllocatedSlice{2, 2, 9}, AllocatedSlice{2, 4, 12}}}))); // Block F EXPECT_EQ(allocation_blocks[5]->offset, 4); EXPECT_FALSE(allocation_blocks[5]->repacked_slice_data.has_value()); } +// Test that we do not permute slice start times in a way that changes the +// original slice size-start time mappings. Doing so breaks assumptions that +// MSA uses to construct its internal state prior to repacking. +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, + SlicedColocationsPermutationsMatchOriginalSizeTimeMapping) { + // Original placement: Ideal repacking, but unsupported: + // + // space space + // ^ ^ + // 8 | 8 | + // 7 |+--------+ +---+ 7 | + // 6 || | | | 6 | + // 5 || C | | D | 5 |+--------++--------+ + // 4 |+----+ |+----+ | 4 || || | + // 3 | | || | 3 || C || D | + // 2 |+---++---++---+----+ 2 |+---++ |+---++ | + // 1 || A | | B | 1 || A || || B || | + // 0 |+---+ +---+ 0 |+---++---++---++---+ + // +----|----|----|----|----> time +----|----|----|----|----> time + // 0 5 10 15 20 0 5 10 15 20 + + std::vector allocation_blocks; + // Block A + allocation_blocks.push_back(MakeAllocationBlock(1, 5, 2)); + // Block B + allocation_blocks.push_back(MakeAllocationBlock(11, 15, 2)); + // Block C + allocation_blocks.push_back(MakeAllocationBlock(1, 10, 5)); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, 2, 6}, AllocatedSlice{3, 4, 1}}}); + // Block D + allocation_blocks.push_back(MakeAllocationBlock(15, 20, 5)); + allocation_blocks.back()->original_slice_data = SlicedAllocationData( + {{AllocatedSlice{2, 2, 11}, AllocatedSlice{3, 4, 16}}}); + // Colocate D with C. + allocation_blocks.back()->next_colocated = allocation_blocks[2]; + allocation_blocks[2]->next_colocated = allocation_blocks.back(); + + // Specify the repacking sort order as the order in which blocks were added to + // allocation_blocks. By placing A and B before C/D, the repacker will try + // permutations of C/D's slices that fit around A and B. + absl::flat_hash_map sort_keys; + for (int i = 0; i < allocation_blocks.size(); ++i) { + sort_keys[allocation_blocks[i]] = i; + } + options_.buffer_interval_compare = LessThanByKey( + [sort_keys](const memory_space_assignment:: + MemorySpaceAssignmentBestFitRepacker::BufferInterval& x) { + return sort_keys.at(x.buffer); + }); + repacker_ = memory_space_assignment::MemorySpaceAssignmentBestFitRepacker( + 100, 1, SliceTimePermutationIterator::Ty::kAll, options_); + + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + // Check results + // + // Typically the heap simulator would like start C/D at offset 0, which is + // lower than C/D's actual placement at offset 2. However, in order to place + // C/D at offset 0, we would need to permute the slice-time mappings of the + // colocation D, which is not permitted. + // + // Block A + EXPECT_EQ(allocation_blocks[0]->offset, 0); + EXPECT_FALSE(allocation_blocks[0]->repacked_slice_data.has_value()); + // Block B + EXPECT_EQ(allocation_blocks[1]->offset, 0); + EXPECT_FALSE(allocation_blocks[1]->repacked_slice_data.has_value()); + // Block C + EXPECT_EQ(allocation_blocks[2]->offset, 2); + ASSERT_TRUE(allocation_blocks[2]->repacked_slice_data.has_value()); + EXPECT_EQ(*allocation_blocks[3]->repacked_slice_data, + (SlicedAllocationData( + {{AllocatedSlice{2, 2, 11}, AllocatedSlice{3, 4, 16}}}))); + // Block D + EXPECT_EQ(allocation_blocks[3]->offset, 2); + ASSERT_TRUE(allocation_blocks[3]->repacked_slice_data.has_value()); + EXPECT_EQ(*allocation_blocks[3]->repacked_slice_data, + (SlicedAllocationData( + {{AllocatedSlice{2, 2, 11}, AllocatedSlice{3, 4, 16}}}))); +} + } // namespace xla diff --git a/xla/service/memory_space_assignment/cost_analysis.cc b/xla/service/memory_space_assignment/cost_analysis.cc new file mode 100644 index 0000000000000..72f1c226e9652 --- /dev/null +++ b/xla/service/memory_space_assignment/cost_analysis.cc @@ -0,0 +1,403 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/cost_analysis.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { +/*static*/ absl::StatusOr> CostAnalysis::Create( + const HloCostAnalysis& cost_analysis, const CostAnalysisOptions& options, + const HloModule& module) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + // Using `new` to access a non-public constructor. + return absl::WrapUnique( + new CostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph))); +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, + CostAnalysis::Cache* cache) const { + float elapsed_time_due_to_compute = + GetInstructionElapsedDueToCompute(instruction); + float elapsed_time_due_to_memory = + GetInstructionElapsedDueToMemory(instruction); + if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { + // Memory bound, return how much alternate memory is better. + float while_nest_multiplier; + if (cache) { + // If there is a cache provided, memoize the while nest multiplier. + auto it = cache->while_nest_multiplier.find(&instruction); + if (it != cache->while_nest_multiplier.end()) { + while_nest_multiplier = it->second; + } else { + while_nest_multiplier = GetWhileNestMultiplier( + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); + cache->while_nest_multiplier[&instruction] = while_nest_multiplier; + } + } else { + while_nest_multiplier = GetWhileNestMultiplier( + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); + } + return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * + while_nest_multiplier; + } else { + // Compute bound, return how far off are we to memory boundedness. + return elapsed_time_due_to_memory - elapsed_time_due_to_compute; + } +} + +float CostAnalysis::GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + CostAnalysis::Cache* cache) const { + if (cache) { + auto it = + cache->memory_boundedness.find(interval.buffer->defining_position()); + if (it != cache->memory_boundedness.end()) { + return it->second; + } + } + float alternate_mem_benefit = + GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache); + + for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt( + interval.buffer->defining_position().instruction, + interval.buffer->defining_position().index)) { + for (const HloValue* value : buffer->values()) { + for (const HloUse& use : value->GetUses()) { + // We look inside the called computations of while and conditional, so + // don't use the benefit of while and conditional directly. + if (use.instruction->opcode() == HloOpcode::kWhile || + use.instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache); + // If the benefit is positive (memory bound), add it to this buffer's + // benefit. If the benefit is negative (compute bound), calculate the + // maximum. + if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { + alternate_mem_benefit += use_alternate_mem_benefit; + } else { + alternate_mem_benefit = + std::max(alternate_mem_benefit, use_alternate_mem_benefit); + } + } + } + } + + // Penalize larger buffers by dividing the benefit by the square root of + // the size. Empirically, we observed this resulted in better performance + // compared to dividing by the size. + float memory_boundedness = 1; + if (options_ + .xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers == + "NO_SCALE") { + memory_boundedness = alternate_mem_benefit; + } else { + memory_boundedness = alternate_mem_benefit / std::sqrt(interval.size); + } + + if (cache) { + cache->memory_boundedness[interval.buffer->defining_position()] = + memory_boundedness; + } + return memory_boundedness; +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloPosition& position, CostAnalysis::Cache* cache) const { + return GetAlternateMemoryBenefit( + *position.instruction, + GetInstructionElapsedDueToMemory( + *position.instruction, + /*operands_in_alternate_mem=*/{}, + /*outputs_in_alternate_mem=*/{position.index}), + cache); +} + +float CostAnalysis::GetAlternateMemoryBenefit( + const HloUse& use, CostAnalysis::Cache* cache) const { + return GetAlternateMemoryBenefit( + *use.instruction, + GetInstructionElapsedDueToMemory( + *use.instruction, + /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number, + use.operand_index)}), + cache); +} + +int CostAnalysis::CalculateComputationNestLevel( + const HloInstruction* instruction, bool while_only) const { + int nest_level = 0; + const HloComputation* computation = instruction->parent(); + while (!computation->IsEntryComputation()) { + auto& node = call_graph_->GetNode(computation); + auto callsites = node.caller_callsites(); + CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1) + << "The module is not flattened!"; + auto& callsite = callsites[0]; + if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { + ++nest_level; + } + computation = callsite.instruction()->parent(); + } + return nest_level; +} + +float CostAnalysis::GetWhileNestMultiplier(int while_nest_level) const { + return IPow( + options_.xla_tpu_memory_space_assignment_while_execution_count, + while_nest_level); +} + +float CostAnalysis::GetDefaultMemoryAccessOverhead( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + // Calculate the pipeline overhead of accessing the default memory. We use the + // maximum of the window size heuristic and the actual default memory bytes + // accessed multiplied with the compute as the overhead. So, the math is: + // + // overhead = compute_per_iteration + // = compute_elapsed / num_iterations + // = compute_elapsed / (bytes_accessed / window_size) + // = (window_size / bytes_accessed) * compute_elapsed + const float window_size_bytes = + options_.pipeline_overhead_window_size_mib * 1024 * 1024; + const float bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + const float default_memory_bytes_accessed = + bytes_accessed - + GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + const float compute_elapsed = GetInstructionElapsedDueToCompute(instruction); + const float effective_window_size_bytes = + std::min(window_size_bytes, default_memory_bytes_accessed); + float overhead = 0; + if (bytes_accessed > 0) { + overhead = (effective_window_size_bytes / bytes_accessed) * compute_elapsed; + } + return overhead; +} + +float CostAnalysis::GetDefaultMemoryBandwidthIdleTime( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + const float default_memory_bytes_accessed = + hlo_cost_analysis_.bytes_accessed(instruction) - + GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + const float elapsed_due_to_default_mem = + default_memory_bytes_accessed / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + const float elapsed = GetInstructionElapsedInAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + return elapsed - elapsed_due_to_default_mem; +} + +float CostAnalysis::GetBytesAccessedFromAlternateMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + float bytes_accessed_from_alternate_mem = 0.0; + for (auto& operand : operands_in_alternate_mem) { + const float operand_bytes_accessed = + hlo_cost_analysis_.operand_bytes_accessed(instruction, operand.first, + operand.second); + bytes_accessed_from_alternate_mem += operand_bytes_accessed; + } + + for (auto& shape_idx : outputs_in_alternate_mem) { + const float output_bytes_accessed = + hlo_cost_analysis_.output_bytes_accessed(instruction, shape_idx); + bytes_accessed_from_alternate_mem += output_bytes_accessed; + } + return bytes_accessed_from_alternate_mem; +} + +namespace { +// Returns true on async instructions since we assume they are already +// efficiently scheduled such that they are not in the critical path and appear +// to take no time. +bool ExcludeInstructionFromElapsed(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAllGatherStart || + instruction.opcode() == HloOpcode::kAllGatherDone || + instruction.opcode() == HloOpcode::kAllReduceStart || + instruction.opcode() == HloOpcode::kAllReduceDone || + instruction.opcode() == HloOpcode::kAsyncStart || + instruction.opcode() == HloOpcode::kAsyncDone || + instruction.opcode() == HloOpcode::kCollectivePermuteStart || + instruction.opcode() == HloOpcode::kCollectivePermuteDone || + instruction.opcode() == HloOpcode::kCopyStart || + instruction.opcode() == HloOpcode::kCopyDone; +} +} // namespace + +float CostAnalysis::GetInstructionElapsedDueToCompute( + const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + return std::max( + hlo_cost_analysis_.flop_count(instruction) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), + hlo_cost_analysis_.transcendental_count(instruction) / + hlo_cost_analysis_.per_second_rate( + HloCostAnalysis::kTranscendentalsKey)); +} + +float CostAnalysis::GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float total_bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + float elapsed_due_to_alternate_mem = + bytes_accessed_from_alternate_mem / + options_.alternate_mem_bandwidth_bytes_per_second; + float elapsed_due_to_default_mem = + (total_bytes_accessed - bytes_accessed_from_alternate_mem) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; +} + +float CostAnalysis::GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float total_bytes_accessed = hlo_cost_analysis_.bytes_accessed(instruction); + float bytes_accessed_from_alternate_mem = 0.0; + for (int operand_num = 0; operand_num < instruction.operand_count(); + ++operand_num) { + ShapeUtil::ForEachSubshape( + instruction.operand(operand_num)->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + if (is_in_alternate_mem(operand_num, index, subshape)) { + bytes_accessed_from_alternate_mem += + hlo_cost_analysis_.operand_bytes_accessed(instruction, + operand_num, index); + } + }); + } + ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape, + const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) { + bytes_accessed_from_alternate_mem += + hlo_cost_analysis_.output_bytes_accessed(instruction, index); + } + }); + float elapsed_due_to_alternate_mem = + bytes_accessed_from_alternate_mem / + options_.alternate_mem_bandwidth_bytes_per_second; + float elapsed_due_to_default_mem = + (total_bytes_accessed - bytes_accessed_from_alternate_mem) / + hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); + return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; +} + +float CostAnalysis::GetInstructionElapsed( + const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float overhead = GetDefaultMemoryAccessOverhead(instruction); + return std::max(GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction) + overhead); +} + +float CostAnalysis::GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + float overhead = GetDefaultMemoryAccessOverhead( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + return std::max( + GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem, + outputs_in_alternate_mem) + + overhead); +} + +float CostAnalysis::GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } + return std::max( + GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem)); +} + +float CostAnalysis::GetAsyncCopyElapsed(const Shape& shape) const { + int64_t size_in_bytes = hlo_cost_analysis_.GetShapeSize(shape); + return static_cast(size_in_bytes) / + (options_.async_copy_bandwidth_bytes_per_second * + options_.async_copy_bandwidth_scaling_factor); +} + +int64_t CostAnalysis::GetScheduleEndTime() const { + return hlo_live_range_->schedule_end_time(); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/xla/service/memory_space_assignment/cost_analysis.h b/xla/service/memory_space_assignment/cost_analysis.h new file mode 100644 index 0000000000000..b8152dbc23fda --- /dev/null +++ b/xla/service/memory_space_assignment/cost_analysis.h @@ -0,0 +1,237 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +// Options to be passed to the CostAnalysis. +struct CostAnalysisOptions { + // This variable is used by the cost analysis in estimating how many times + // each while loop will execute. Nested loops will be assumed to have + // executed pow(while_execution_count, nesting_level) times. + uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL; + + // This variable is used to scale the alternate memory benefit factor for + // large buffers. The default scaling function is sqrt. + std::string + xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers = + "SQRT"; + + // The window size used to calculate the pipeline overhead when HLO accesses + // the default memory, in MiB. + float pipeline_overhead_window_size_mib = 0; + + float alternate_mem_bandwidth_bytes_per_second = 0.0f; + + float async_copy_bandwidth_bytes_per_second = 0.0f; + + // Scales effective bandwidth for async copies. Valid range is (0, 1]. + float async_copy_bandwidth_scaling_factor = 1.0; +}; + +// A wrapper class around HloCostAnalysis with additional knowledge about the +// bandwidths of different memory spaces. +class CostAnalysis { + public: + // An optional Cache object may be provided to some of the methods below to + // speed up the lookup. + struct Cache { + absl::flat_hash_map while_nest_multiplier; + absl::flat_hash_map memory_boundedness; + }; + + // Function type that can be used to indicate which input/output values are in + // the alternate memory. + using IsInAlternateMemoryFun = absl::FunctionRef /*operand_num*/, const ShapeIndex& /*index*/, + const Shape& /*shape*/)>; + + virtual ~CostAnalysis() = default; + + static absl::StatusOr> Create( + const HloCostAnalysis& cost_analysis, const CostAnalysisOptions& options, + const HloModule& module); + + const HloCostAnalysis& hlo_cost_analysis() const { + return hlo_cost_analysis_; + } + + // Returns a heuristic value that captures how much putting this tensor to the + // alternate memory would help if the op is memory bound, or otherwise how far + // off is the op to memory boundedness. The larger this number, the higher + // priority it will be placed in the alternate memory. + float GetAlternateMemoryBenefit(const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem, + Cache* cache = nullptr) const; + // Like above, return the benefit of putting the output tensor in the + // alternate memory. + float GetAlternateMemoryBenefit(const HloPosition& position, + Cache* cache = nullptr) const; + // Like above, return the benefit of putting the input tensor in the alternate + // memory. + float GetAlternateMemoryBenefit(const HloUse& use, + Cache* cache = nullptr) const; + + // Returns a heuristic value of memory boundedness for the given + // BufferInterval. The larger this number, the higher priority it will be + // placed in the alternate memory. + float GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + Cache* cache = nullptr) const; + + // If enabled in CostAnalysisOptions::pipeline_overhead_window_size_mib, + // returns the overhead of accessing the default memory, in seconds. The + // source of the overhead is the software pipelining ovehead. The lowering of + // the operations typically use tiling to copy one window at a time from + // default memory, and perform compute: + // + // Pipeline overhead: <-> + // +----+----+----+----+ + // Copy from default mem: | | | | | + // +----+----+----+----+ + // \ \ \ \ + // \ \ \ \ + // V V V V + // +--+ +--+ +--+ +--+ + // Compute: | | | | | | | | + // +--+ +--+ +--+ +--+ + float GetDefaultMemoryAccessOverhead( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the amount of time the default memory bandwidth is idle, while + // executing this instruction, in seconds. This value can be multiplied with + // the default memory bandwidth to get the amount of bytes that are available + // to be copied to/from default memory during the execution of this + // instruction. + float GetDefaultMemoryBandwidthIdleTime( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the bytes accessed from alternate memory. + float GetBytesAccessedFromAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Returns the elapsed time in seconds due to compute only. + float GetInstructionElapsedDueToCompute( + const HloInstruction& instruction) const; + + // Returns the elapsed time in seconds due to memory only. If + // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will + // assume that the corresponding operands or output will be in the alternate + // memory space. This is useful for calculating the benefit of placing the + // buffer in alternate memory. + float GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem = {}, + absl::Span outputs_in_alternate_mem = {}) const; + + // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in + // the alternate memory. + float GetInstructionElapsedDueToMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const; + + // Returns the estimated elapsed duration of the instruction in seconds. It + // assumes all operands and outputs of the instruction are in the default + // memory. + virtual float GetInstructionElapsed(const HloInstruction& instruction) const; + + // Returns the estimated elapsed duration of the instruction in seconds. It + // assumes all operands and outputs of the instruction are in the default + // memory, except for the operands and outputs specified to be in the + // alternate memory. + virtual float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const; + + // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in + // the alternate memory. + float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + IsInAlternateMemoryFun is_in_alternate_mem) const; + + // Returns the elapsed time it would take to asynchronously copy the shape + // from default to alternate memory space (or vice versa). + virtual float GetAsyncCopyElapsed(const Shape& shape) const; + + int64_t GetScheduleEndTime() const; + + // Returns the number of nested computation levels this instruction resides + // in. If while_only is true, it returns the while loop nest level and 0 + // means the instruction is not in a while loop. + int CalculateComputationNestLevel(const HloInstruction* instruction, + bool while_only) const; + float GetWhileNestMultiplier(int while_nest_level) const; + + const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } + + protected: + CostAnalysis(const HloCostAnalysis& hlo_cost_analysis, + const CostAnalysisOptions& options, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : hlo_cost_analysis_(hlo_cost_analysis), + options_(options), + alias_analysis_(std::move(alias_analysis)), + hlo_live_range_(std::move(hlo_live_range)), + call_graph_(std::move(call_graph)) {} + + private: + const HloCostAnalysis& hlo_cost_analysis_; + const CostAnalysisOptions options_; + std::unique_ptr alias_analysis_; + std::unique_ptr hlo_live_range_; + std::unique_ptr call_graph_; +}; + +} // namespace memory_space_assignment +} // namespace xla +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ diff --git a/xla/service/memory_space_assignment/cost_analysis_test.cc b/xla/service/memory_space_assignment/cost_analysis_test.cc new file mode 100644 index 0000000000000..54567bb3855d0 --- /dev/null +++ b/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/cost_analysis.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using memory_space_assignment::CostAnalysis; +using memory_space_assignment::CostAnalysisOptions; + +constexpr int64_t kPointerSize = 8; + +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { + protected: + Status Initialize(const HloModule* module, + float pipeline_overhead_window_size_mib = 0.0) { + HloCostAnalysis::Options options; + options_.alternate_mem_bandwidth_bytes_per_second = 128; + options_.async_copy_bandwidth_bytes_per_second = 32; + options_.pipeline_overhead_window_size_mib = + pipeline_overhead_window_size_mib; + options.shape_size = ShapeSize; + options.set_flops_per_second(8); + options.set_bytes_per_second(32); + options.set_transcendentals_per_second(16); + hlo_cost_analysis_ = std::make_unique(options); + TF_RETURN_IF_ERROR( + module->entry_computation()->Accept(hlo_cost_analysis_.get())); + TF_ASSIGN_OR_RETURN( + cost_analysis_, + CostAnalysis::Create(*hlo_cost_analysis_, options_, *module)); + return OkStatus(); + } + + CostAnalysisOptions options_; + std::unique_ptr hlo_cost_analysis_; + std::unique_ptr cost_analysis_; +}; + +TEST_F(MemorySpaceAssignmentCostAnalysisTest, NoPipelineOverhead) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + param1 = f32[2,4] parameter(1) + ROOT add = f32[2,4] add(param0, param1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK(Initialize(module.get())); + + const HloInstruction* add = module->entry_computation()->root_instruction(); + const float expected_compute_elapsed = + /*num_flops=*/8 / /*flops_per_second=*/8.0; + LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), + expected_compute_elapsed); + float expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), + expected_memory_elapsed); + + // This HLO is memory-bound. + EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), + expected_memory_elapsed); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), + expected_memory_elapsed); + + // Put operand 0 in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {}), + expected_memory_elapsed); + + // Put operand 0 and output in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {{}}), + expected_memory_elapsed); + + // Put everything in alternate memory. We're now compute bound. + expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_compute_elapsed); +} + +TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + param1 = f32[2,4] parameter(1) + ROOT add = f32[2,4] add(param0, param1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // Set the window size 64B. + TF_ASSERT_OK( + Initialize(module.get(), + /*pipeline_overhead_window_size_mib=*/(64.0 / 1024 / 1024))); + + const HloInstruction* add = module->entry_computation()->root_instruction(); + const float expected_compute_elapsed = + /*num_flops=*/8 / /*flops_per_second=*/8.0; + LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), + expected_compute_elapsed); + float expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), + expected_memory_elapsed); + + float expected_overhead = expected_compute_elapsed * 2 / 3; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add), + expected_overhead); + // This HLO is memory-bound. + EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), + expected_memory_elapsed + expected_overhead); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), + expected_memory_elapsed + expected_overhead); + + // Put operand 0 in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}), + expected_overhead); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {}), + expected_memory_elapsed + expected_overhead); + + // Put operand 0 and output in alternate memory. Still memory bound. + expected_memory_elapsed = + (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + + (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + expected_overhead = expected_compute_elapsed / 3; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ( + cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}, {{}}), + expected_overhead); + EXPECT_EQ( + cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}}, {{}}), + expected_memory_elapsed + expected_overhead); + + // Put everything in alternate memory. We're now compute bound. + expected_memory_elapsed = + /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; + LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; + expected_overhead = 0; + LOG(INFO) << "Expected overhead = " << expected_overhead; + EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_overhead); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_memory_elapsed); + EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( + *add, {{0, {}}, {1, {}}}, {{}}), + expected_compute_elapsed); +} + +} // namespace +} // namespace xla diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc new file mode 100644 index 0000000000000..c2f481c458be8 --- /dev/null +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -0,0 +1,1238 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/memory_bound_loop_optimizer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace memory_space_assignment { + +/*static*/ absl::StatusOr> +MemoryBoundLoopOptimizer::Create( + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) { + std::unique_ptr optimizer = + absl::WrapUnique(new MemoryBoundLoopOptimizer( + loop_start, loop_end, alternate_memory_size, options, hlo_live_range, + alias_analysis, cost_analysis, size_function, + reserved_scoped_memory_fn)); + TF_RETURN_IF_ERROR(optimizer->Initialize()); + return std::move(optimizer); +} + +MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) + : loop_start_(loop_start), + loop_end_(loop_end), + loop_size_(loop_end - loop_start), + alternate_memory_size_(alternate_memory_size), + options_(options), + hlo_live_range_(hlo_live_range), + alias_analysis_(alias_analysis), + cost_analysis_(cost_analysis), + size_function_(size_function), + reserved_scoped_memory_fn_(reserved_scoped_memory_fn) {} + +Status MemoryBoundLoopOptimizer::Initialize() { + const auto& instruction_sequence = + hlo_live_range_.flattened_instruction_sequence().instructions(); + VLOG(3) << "MemoryBoundLoopOptimizer::Initialize, loop start: " << loop_start_ + << ", loop end: " << loop_end_ << ", loop size: " << loop_size_; + const HloComputation* loop_computation = nullptr; + // Initialize the remaining memory array with the size of the alternate + // memory. Also populate instructions_in_loop_ and + // instructions_in_{prev,next}_iterations_ data structures to help find the + // loop values. + for (int i = loop_start_; i < loop_end_; ++i) { + const HloInstruction* inst = instruction_sequence[i]; + instructions_in_loop_[inst] = i - loop_start_; + VLOG(3) << " inst in loop [" << (i - loop_start_) << "]: " << inst->name(); + if (!loop_computation) { + loop_computation = inst->parent(); + } else { + TF_RET_CHECK(loop_computation == inst->parent()); + } + remaining_memory_.push_back( + alternate_memory_size_ - + reserved_scoped_memory_fn_(inst, /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{})); + } + + for (int i = loop_start_ - loop_size_; i < loop_start_; ++i) { + const HloInstruction* inst = instruction_sequence[i]; + instructions_in_prev_iteration_[inst] = i - loop_start_ + loop_size_; + } + for (int i = loop_end_; i < loop_end_ + loop_size_; ++i) { + const HloInstruction* inst = instruction_sequence[i]; + instructions_in_next_iteration_[inst] = i - loop_end_; + } + + // Create a tree set to keep track of all the values that the loop + // instructions produce and consume. We use a tree set instead of a hash set + // to ensure the iteration order is the same as insertion order. Since we + // traverse the program in instruction order, the buffers would be inserted in + // a deterministic order, so we'll be able to iterate over these buffers in a + // deterministic order. + std::set buffers_to_process; + for (const auto& [instruction, idx] : instructions_in_loop_) { + auto maybe_add_buffer = [&](const HloInstruction* instruction) { + return [this, &buffers_to_process, instruction](const Shape& subshape, + const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + const HloBuffer& buffer = + alias_analysis_.GetUniqueBufferAt(instruction, index); + if (buffers_to_process.find(&buffer) == buffers_to_process.end()) { + buffers_to_process.insert(&buffer); + } + }; + }; + ShapeUtil::ForEachSubshape(instruction->shape(), + maybe_add_buffer(instruction)); + for (const HloInstruction* operand : instruction->operands()) { + ShapeUtil::ForEachSubshape(operand->shape(), maybe_add_buffer(operand)); + } + } + + // Process the buffers and decide if they should be added as LoopValues. + for (const HloBuffer* buffer : buffers_to_process) { + MaybeCreateLoopValue(*buffer, loop_computation); + } + return OkStatus(); +} + +void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( + const HloBuffer& buffer, const HloComputation* loop_computation) { + // Define helper lambdas to get the loop-relative index of the given + // instruction. + auto get_index_in_loop = + [&](const HloInstruction* instruction, + const absl::flat_hash_map& + instructions_in_loop, + int64_t relative_index = 0) { + std::optional loop_index; + if (instructions_in_loop.contains(instruction)) { + loop_index = hlo_live_range_.instruction_schedule().at(instruction) - + loop_start_ + relative_index; + CHECK_GE(*loop_index, 0); + CHECK_LT(*loop_index, loop_size_); + } + return loop_index; + }; + auto get_index_in_current_iteration = [&](const HloInstruction* instruction) { + return get_index_in_loop(instruction, instructions_in_loop_); + }; + auto get_index_in_prev_iteration = [&](const HloInstruction* instruction) { + return get_index_in_loop(instruction, instructions_in_prev_iteration_, + loop_size_); + }; + auto get_index_in_next_iteration = [&](const HloInstruction* instruction) { + return get_index_in_loop(instruction, instructions_in_next_iteration_, + -loop_size_); + }; + + loop_values_.push_back({}); + LoopValue& loop_value = loop_values_.back(); + float pos_bytes = 0; + float use_bytes = 0; + bool has_footer_consumer = false; + for (const HloValue* value : buffer.values()) { + // For each position and use of the value, populate the respecive position + // and use fields for the current, previous, and next iterations along with + // the loop indices. + for (const HloPosition& position : value->positions()) { + if (position.instruction->opcode() == HloOpcode::kGetTupleElement) { + continue; + } + std::optional loop_index = + get_index_in_current_iteration(position.instruction); + std::optional prev_iteration_index; + if (loop_index) { + loop_value.loop_positions.push_back({*loop_index, position}); + VLOG(3) << "Pos match: " << position.instruction->name() << " at " + << *loop_index; + } else if ((prev_iteration_index = + get_index_in_prev_iteration(position.instruction))) { + loop_value.prev_iteration_positions.push_back( + {*prev_iteration_index, position}); + VLOG(3) << "Pos match (prev iteration): " + << position.instruction->name() << " at " + << *prev_iteration_index; + } else if (loop_value.prev_iteration_positions.empty() && + loop_value.loop_positions.empty() && + position.instruction->parent() == loop_computation && + !loop_value.header_position) { + loop_value.header_position = position; + } + + // Keep track of bytes accessed by this value. + if (loop_index || prev_iteration_index) { + float bytes_accessed = + cost_analysis_.hlo_cost_analysis().output_bytes_accessed( + *position.instruction, position.index); + pos_bytes += bytes_accessed; + VLOG(3) << " accessed: " << bytes_accessed; + } + } + + for (const HloUse& use : value->GetUses()) { + if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { + continue; + } + std::optional loop_index = + get_index_in_current_iteration(use.instruction); + std::optional next_iteration_index; + if (loop_index) { + loop_value.loop_uses.push_back({*loop_index, use}); + VLOG(3) << "Use match: " << use.instruction->name() << " at " + << *loop_index; + } else if ((next_iteration_index = + get_index_in_next_iteration(use.instruction))) { + loop_value.next_iteration_uses.push_back({*next_iteration_index, use}); + VLOG(3) << "Use match (next iteration): " << use.instruction->name() + << " at " << *next_iteration_index; + } else if (!loop_value.loop_positions.empty() || + !loop_value.loop_uses.empty()) { + has_footer_consumer = true; + } + + // Keep track of bytes accessed by this value. + if (loop_index || next_iteration_index) { + float bytes_accessed = + cost_analysis_.hlo_cost_analysis().operand_bytes_accessed( + *use.instruction, use.operand_number, use.operand_index); + use_bytes += bytes_accessed; + VLOG(3) << " accessed: " << bytes_accessed; + } + } + } + + // We only add the loop position if it has a position or use in the current + // iteration and its previous iteration positions are empty. The reason why we + // disallow values with previous iteration positions is because there will be + // a different value that corresponds to the same value but one iteration + // later, so we will add that one instead. + if ((!loop_value.loop_positions.empty() || !loop_value.loop_uses.empty()) && + loop_value.prev_iteration_positions.empty()) { + loop_value.size = size_function_(**buffer.values().begin()); + VLOG(3) << "Size: " << loop_value.size; + // Classify the type of allocation. See the comment in LoopValue definition. + loop_value.allocation_type = LoopValue::AllocationType::kUnsupported; + auto position_compare = [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }; + auto use_compare = [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }; + absl::c_sort(loop_value.loop_positions, position_compare); + absl::c_sort(loop_value.prev_iteration_positions, position_compare); + absl::c_sort(loop_value.loop_uses, use_compare); + absl::c_sort(loop_value.next_iteration_uses, use_compare); + if (!loop_value.loop_positions.empty()) { + if (loop_value.next_iteration_uses.empty() && + !loop_value.loop_uses.empty()) { + loop_value.allocation_type = LoopValue::AllocationType::kTemporary; + } else if (!loop_value.next_iteration_uses.empty()) { + if (loop_value.next_iteration_uses.back().first >= + loop_value.loop_positions.front().first) { + loop_value.allocation_type = + LoopValue::AllocationType::kLoopCarriedDependence; + } else { + loop_value.allocation_type = LoopValue::AllocationType::kTemporary; + } + } + } else if (loop_value.header_position && !loop_value.loop_uses.empty()) { + if (loop_value.loop_uses.size() == + loop_value.next_iteration_uses.size() && + loop_value.loop_uses.front().first == + loop_value.next_iteration_uses.front().first) { + loop_value.allocation_type = LoopValue::AllocationType::kPinned; + } else if (loop_value.next_iteration_uses.empty() || + loop_value.next_iteration_uses.back().first < + loop_value.loop_uses.front().first) { + loop_value.allocation_type = LoopValue::AllocationType::kPrefetch; + } + } + + VLOG(3) << "Allocation type " + << LoopValue::AllocationTypeToString(loop_value.allocation_type); + VLOG(3) << "Pos bytes: " << pos_bytes << " use bytes: " << use_bytes; + + // We calculate the savings of allocating this buffer in the alternate + // memory. + float savings = pos_bytes + use_bytes; + if (loop_value.header_position) { + savings -= loop_value.size; + } + if (!loop_value.loop_positions.empty() && has_footer_consumer) { + savings -= loop_value.size; + } + loop_value.savings = savings; + loop_value.savings_per_byte = savings / loop_value.size; + VLOG(3) << "Savings: " << loop_value.savings; + VLOG(3) << "Savings per byte: " << loop_value.savings_per_byte; + for (const HloValue* value : buffer.values()) { + VLOG(3) << value->ToString(); + } + auto sort_positions = [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }; + auto sort_uses = [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }; + absl::c_sort(loop_value.loop_positions, sort_positions); + absl::c_sort(loop_value.prev_iteration_positions, sort_positions); + absl::c_sort(loop_value.loop_uses, sort_uses); + absl::c_sort(loop_value.next_iteration_uses, sort_uses); + loop_value.hlo_values = buffer.values(); + } else { + loop_values_.pop_back(); + } +} + +void MemoryBoundLoopOptimizer::Optimize() { + SortLoopValues(); + AllocateLoopValues(); + PostProcess(); +} + +float MemoryBoundLoopOptimizer::CalculateExecutionTime() const { + // First populate the list of prefetches. + std::vector> prefetches; + for (const LoopValue& value : loop_values_) { + if (!value.allocations.empty() && + value.allocations.back()->is_copy_allocation()) { + prefetches.push_back( + {static_cast(value.allocations.back().get()), + cost_analysis_.GetAsyncCopyElapsed( + value.hlo_values.front()->shape())}); + } + } + + // Returns the effective prefetch completion time. The effective time is a + // value that will be larger than loop size for prefetches that start in this + // iteration but complete in the next iteration. + auto get_effective_done_time = + [&](int64_t copy_start_schedule_after, + int64_t copy_done_schedule_before) -> int64_t { + if (copy_start_schedule_after == loop_size_ - 1 && + copy_done_schedule_before == 0) { + return 2 * loop_size_; + } + if (copy_start_schedule_after + 1 >= copy_done_schedule_before) { + return copy_done_schedule_before + loop_size_; + } + return copy_done_schedule_before; + }; + + // Sort the prefetches by first the start time, then the effective done time. + absl::c_sort( + prefetches, [&](const std::pair& a, + const std::pair& b) { + return std::forward_as_tuple( + a.first->copy_start_schedule_after(), + get_effective_done_time( + a.first->copy_start_schedule_after(), + a.first->copy_done_schedule_before())) < + std::forward_as_tuple(b.first->copy_start_schedule_after(), + get_effective_done_time( + b.first->copy_start_schedule_after(), + b.first->copy_done_schedule_before())); + }); + // Populate the required prefetch completions array. For each instruction in + // the loop, this vector holds the index of the latest-issued prefetch that + // needs to be completed before the instruction executes, or nullopt if there + // is no prefetch that needs to finish by this instruction. To represent + // prefetches that started in the previous iteration, we use negative numbers. + std::vector> required_prefetch_completions(loop_size_); + for (int i = 0; i < prefetches.size(); ++i) { + const auto& [prefetch, elapsed] = prefetches[i]; + int required_prefetch_completion = i; + if (prefetch->copy_start_schedule_after() == loop_size_ - 1 && + prefetch->copy_done_schedule_before() == 0) { + required_prefetch_completion -= 2 * prefetches.size(); + } else if (prefetch->copy_start_schedule_after() + 1 >= + prefetch->copy_done_schedule_before()) { + required_prefetch_completion -= prefetches.size(); + } + VLOG(3) << "Prefetch #" << i << " (elapsed " << elapsed + << "): " << prefetch->ToString(); + if (required_prefetch_completions[prefetch->copy_done_schedule_before()]) { + required_prefetch_completions[prefetch->copy_done_schedule_before()] = + std::max( + *required_prefetch_completions[prefetch + ->copy_done_schedule_before()], + required_prefetch_completion); + } else { + required_prefetch_completions[prefetch->copy_done_schedule_before()] = + required_prefetch_completion; + } + VLOG(4) + << "Required completion at " << prefetch->copy_done_schedule_before() + << " = " + << *required_prefetch_completions[prefetch + ->copy_done_schedule_before()]; + } + + // Populate the elapsed times of instructions and bandwidth idle times at each + // point. + float result; + std::vector bandwidth_idle_times; + std::vector instructions_elapsed; + bandwidth_idle_times.reserve(loop_size_); + instructions_elapsed.reserve(loop_size_); + for (int i = 0; i < loop_size_; ++i) { + bandwidth_idle_times.push_back(GetBandwidthIdleTime(i)); + instructions_elapsed.push_back(GetInstructionElapsed(i)); + } + // We simulate the loop for three iterations to measure the steady state. + const int kNumIterations = 3; + // This data structure keeps track of the elapsed time remaining of each + // prefetch. Note that there is a separate entry for each prefetch in each + // iteration simulated. + std::vector prefetch_remaining_elapsed_times(prefetches.size() * + kNumIterations); + int prefetch_start_index = 0; + int prefetch_done_index = 0; + int prefetch_completed_index = 0; + + for (int iteration = 0; iteration < kNumIterations; ++iteration) { + float total_elapsed = 0; + float total_bandwidth_idle_time = 0; + float total_critical_prefetch = 0; + for (int i = 0; i < loop_size_; ++i) { + // If any prefetches are expected to be completed, check if they have any + // remaining elapsed time associated with them, and if so add this to + // critical prefetch time. + std::optional required_prefetch_completion = + required_prefetch_completions[i]; + if (required_prefetch_completion) { + int required_prefetch_done_index = + iteration * static_cast(prefetches.size()) + + *required_prefetch_completion; + VLOG(4) << "Prefetch #" + << ((*required_prefetch_completion + prefetches.size()) % + prefetches.size()) + << " (" << required_prefetch_done_index + << ") is required to be completed at " << i; + for (; prefetch_done_index <= required_prefetch_done_index; + ++prefetch_done_index) { + CHECK_LE(prefetch_done_index, prefetch_start_index); + if (prefetch_done_index == prefetch_completed_index) { + float& prefetch_remaining = + prefetch_remaining_elapsed_times[prefetch_done_index]; + VLOG(4) << "Prefetch #" << (prefetch_done_index % prefetches.size()) + << " (" << prefetch_done_index + << ") did not complete, remaining elapsed = " + << prefetch_remaining; + total_critical_prefetch += prefetch_remaining; + prefetch_remaining = 0; + ++prefetch_completed_index; + } + } + } + + float elapsed = instructions_elapsed[i]; + total_elapsed += elapsed; + float bandwidth_idle_time = bandwidth_idle_times[i]; + // Find the outstanding prefetches during this instruction, and if any of + // them have remaining time, spend some or all of the bandwidth idle time + // to satisfy them. + for (; prefetch_completed_index < prefetch_start_index; + ++prefetch_completed_index) { + float& prefetch_remaining = + prefetch_remaining_elapsed_times[prefetch_completed_index]; + if (bandwidth_idle_time < prefetch_remaining) { + prefetch_remaining -= bandwidth_idle_time; + bandwidth_idle_time = 0; + VLOG(4) << "Prefetch #" + << (prefetch_completed_index % prefetches.size()) << " (" + << prefetch_completed_index << ") still ongoing at " << i + << ", remaining elapsed = " << prefetch_remaining; + break; + } + bandwidth_idle_time -= prefetch_remaining; + prefetch_remaining = 0; + VLOG(4) << "Prefetch #" + << (prefetch_completed_index % prefetches.size()) << " (" + << prefetch_completed_index << ") completed at " << i + << ", bandwidth idle time = " << bandwidth_idle_time; + } + if (bandwidth_idle_time > 0) { + VLOG(4) << "Bandwidth idle time at " << i << " = " + << bandwidth_idle_time; + total_bandwidth_idle_time += bandwidth_idle_time; + } + + // Start new prefetches that are scheduled to start after this + // instruction. + for (; prefetch_start_index < (iteration + 1) * prefetches.size() && + prefetches[prefetch_start_index % prefetches.size()] + .first->copy_start_schedule_after() == i; + ++prefetch_start_index) { + float& prefetch_remaining = + prefetch_remaining_elapsed_times[prefetch_start_index]; + prefetch_remaining = + prefetches[prefetch_start_index % prefetches.size()].second; + VLOG(4) << "Prefetch #" << (prefetch_start_index % prefetches.size()) + << " (" << prefetch_start_index << ") started at " << i + << ", remaining elapsed = " << prefetch_remaining; + } + } + VLOG(3) << "Iteration " << iteration; + VLOG(3) << "Total elapsed: " << total_elapsed + << ", total critical prefetch: " << total_critical_prefetch + << ", total bandwidth idle time: " << total_bandwidth_idle_time; + result = total_elapsed + total_critical_prefetch; + } + return result; +} + +/*static*/ std::string +MemoryBoundLoopOptimizer::LoopValue::AllocationTypeToString( + LoopValue::AllocationType allocation_type) { + switch (allocation_type) { + case AllocationType::kTemporary: + return "temporary"; + case AllocationType::kLoopCarriedDependence: + return "loop-carried dependence"; + case AllocationType::kPinned: + return "pinned"; + case AllocationType::kPrefetch: + return "prefetch"; + default: + CHECK(allocation_type == AllocationType::kUnsupported); + return "unsupported"; + } +} + +std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { + std::string values_str; + absl::StrAppend(&values_str, "Values:"); + for (const HloValue* hlo_value : hlo_values) { + absl::StrAppend(&values_str, "\n - ", hlo_value->ToShortString()); + } + std::string allocations_str; + if (!allocations.empty()) { + absl::StrAppend(&allocations_str, "Allocations:"); + } + for (const auto& allocation : allocations) { + absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); + } + return absl::StrCat( + "Size: ", size, " savings: ", savings, + " savings per byte: ", savings_per_byte, + " allocation type: ", AllocationTypeToString(allocation_type), "\n", + values_str, "\n", allocations_str); +} + +bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { + return allocation_type == AllocationType::kTemporary || + allocation_type == AllocationType::kPinned || + allocation_type == AllocationType::kPrefetch; +} + +void MemoryBoundLoopOptimizer::SortLoopValues() { + absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { + return a.savings_per_byte > b.savings_per_byte; + }); +} + +void MemoryBoundLoopOptimizer::AllocateLoopValues() { + // This function allocates loop values. + std::vector prefetch_values; + VLOG(3) << "Pre optimization execution time: " << CalculateExecutionTime(); + for (LoopValue& value : loop_values_) { + switch (value.allocation_type) { + case LoopValue::AllocationType::kTemporary: + AllocateTemporary(value); + break; + case LoopValue::AllocationType::kPinned: + if (value.savings > 0) { + AllocatePinned(value); + } + break; + case LoopValue::AllocationType::kPrefetch: + prefetch_values.push_back(&value); + break; + case LoopValue::AllocationType::kLoopCarriedDependence: + case LoopValue::AllocationType::kUnsupported: + VLOG(1) << "Unsupported allocation: " << value.ToString(); + } + } + VLOG(3) << "Execution time after allocating temporaries: " + << CalculateExecutionTime(); + AllocatePrefetches(absl::MakeSpan(prefetch_values)); + VLOG(3) << "Execution time after allocating prefetches: " + << CalculateExecutionTime(); +} + +void MemoryBoundLoopOptimizer::PostProcess() { + // At the end, ensure that all loop uses have a corresponding Allocation and + // create one in the default memory space if they don't. + for (LoopValue& value : loop_values_) { + absl::flat_hash_set allocated_uses; + for (const auto& allocation : value.allocations) { + for (const HloUse& use : allocation->uses()) { + allocated_uses.insert(use); + } + } + std::vector unallocated_uses; + absl::flat_hash_set use_indices; + for (const auto& [idx, use] : value.loop_uses) { + use_indices.insert(idx); + if (!allocated_uses.contains(use)) { + unallocated_uses.push_back(use); + } + } + for (const auto& [next_iteration_idx, use] : value.next_iteration_uses) { + if (use_indices.contains(next_iteration_idx)) { + continue; + } + HloInstruction* loop_instruction = + hlo_live_range_.flattened_instruction_sequence().instructions().at( + loop_start_ + next_iteration_idx); + HloUse loop_use{loop_instruction, use.operand_number, use.operand_index}; + if (!allocated_uses.contains(loop_use)) { + unallocated_uses.push_back(loop_use); + } + } + if (!unallocated_uses.empty()) { + // TODO(b/281582241): We should find the correct position. For now, we're + // using the defining position on the first HLO value. + value.allocations.push_back(std::make_unique( + value.hlo_values.front()->defining_position(), MemorySpace::kDefault, + std::nullopt, 0, loop_size_, /*is_scoped_allocation=*/false)); + for (const HloUse& use : unallocated_uses) { + value.allocations.back()->AddUse(use); + } + } + } +} + +bool MemoryBoundLoopOptimizer::AllocateBetween(int64_t begin_idx, + int64_t end_idx, int64_t size) { + int64_t end_idx_sentinel = end_idx; + if (end_idx < begin_idx) { + end_idx_sentinel += loop_size_; + } + for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { + if (remaining_memory_[i % loop_size_] < size) { + return false; + } + } + for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { + remaining_memory_[i % loop_size_] -= size; + } + return true; +} + +bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { + VLOG(3) << "AllocateTemporary: " << value.ToString(); + if (value.hlo_values.size() > 1) { + VLOG(3) << "LoopValue has more than one hlo value associated."; + return false; + } + int64_t definition_idx = value.loop_positions.front().first; + int64_t max_use_idx; + if (!value.next_iteration_uses.empty()) { + max_use_idx = value.next_iteration_uses.back().first; + // If max_use_idx >= definition_idx, then this is a loop carried dependence + // and we should not have called this function. + CHECK_LT(max_use_idx, definition_idx); + } else { + max_use_idx = value.loop_uses.back().first; + } + bool success = AllocateBetween(definition_idx, max_use_idx, value.size); + if (success) { + VLOG(3) << "Pos: " << value.loop_positions[0].second; + value.allocations.push_back(std::make_unique( + value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, + definition_idx, max_use_idx, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); + } + return success; +} + +bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { + bool success = AllocateBetween(0, loop_size_ - 1, value.size); + if (success) { + CHECK(value.header_position); + value.allocations.push_back(std::make_unique( + *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, + loop_size_, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); + } + return success; +} + +bool MemoryBoundLoopOptimizer::AllocatePrefetches( + absl::Span values) { + VLOG(3) << "Allocating prefetches num values: " << values.size(); + AllocatePrefetchesContext context; + context.values = values; + // Populate value_indices, which is a list of indices into values array sorted + // by the start time of the first use. + context.value_indices.resize(values.size()); + absl::c_iota(context.value_indices, 0); + absl::c_stable_sort(context.value_indices, [&](int a, int b) { + return std::forward_as_tuple( + values[a]->loop_uses.begin()->first, + values[a]->loop_uses.begin()->second.operand_number) > + std::forward_as_tuple( + values[b]->loop_uses.begin()->first, + values[b]->loop_uses.begin()->second.operand_number); + }); + + // Populate the data structures that contain additional positions and uses + // that would get alternate memory allocations if all of the prefetches were + // successful. + absl::flat_hash_map>> + additional_uses_in_alternate_mem; + absl::flat_hash_map> + additional_positions_in_alternate_mem; + for (const LoopValue* value : values) { + VLOG(3) << " prefetch value: " << value->ToString(); + for (const auto& [idx, use] : value->loop_uses) { + additional_uses_in_alternate_mem[use.instruction].push_back( + {use.operand_number, use.operand_index}); + } + for (const auto& [idx, position] : value->loop_positions) { + additional_positions_in_alternate_mem[position.instruction].push_back( + position.index); + } + } + // Calculate the default-memory remaining bandwidths assuming all prefetches + // succeed. + for (int i = 0; i < loop_size_; ++i) { + context.bandwidth_idle_times.push_back( + GetBandwidthIdleTime(i, additional_uses_in_alternate_mem, + additional_positions_in_alternate_mem)); + VLOG(3) << "Remaining bandwidth at " << i << " = " + << *context.bandwidth_idle_times.rbegin(); + } + + context.additional_memory_used.resize(loop_size_, 0); + + // Allocate prefetches by traversing the loop values in reverse order of + // the first uses. + for (int value_index : context.value_indices) { + AllocatePrefetch(value_index, context); + } + + for (int i = 0; i < loop_size_; ++i) { + remaining_memory_[i] -= context.additional_memory_used[i]; + VLOG(3) << "Additional memory [" << i + << "]: " << context.additional_memory_used[i]; + VLOG(3) << "Remaining memory [" << i << "]: " << remaining_memory_[i]; + VLOG(3) << "Remaining bandwidth [" << i + << "] : " << context.bandwidth_idle_times[i]; + } + return true; +} + +bool MemoryBoundLoopOptimizer::AllocatePrefetch( + int value_index, AllocatePrefetchesContext& context) { + LoopValue* value = context.values.at(value_index); + VLOG(3) << "Allocating value: " << value->ToString(); + int first_use_idx = value->loop_uses.front().first; + int last_use_idx = value->loop_uses.back().first; + int last_use_idx_sentinel = last_use_idx; + if (!value->next_iteration_uses.empty()) { + last_use_idx = value->next_iteration_uses.back().first; + last_use_idx_sentinel = last_use_idx + loop_size_; + CHECK_LT(last_use_idx, first_use_idx); + } + bool out_of_memory = false; + for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { + int loop_idx = i % loop_size_; + if (context.additional_memory_used[loop_idx] + value->size > + remaining_memory_[loop_idx]) { + VLOG(3) << "Ran out of memory allocating for uses."; + out_of_memory = true; + } + } + if (out_of_memory) { + return false; + } + float copy_resource = + cost_analysis_.GetAsyncCopyElapsed(value->hlo_values.front()->shape()); + VLOG(3) << "First use: " << value->loop_uses.begin()->second + << " use idx: " << first_use_idx + << " copy resource: " << copy_resource; + std::optional copy_start_time; + // The general allocation algorithm for prefetches is to first calculate the + // default-memory bandwidth idle times at each point (assuming all prefetches + // succeeded). We show this pictorially below. We also show the previous + // iteration for clarity. The algorithm solves allocation for one iteration + // and this will be used for all iterations. + // + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 2 2 1 2 3 1| 2 2 1 2 3 1| + // additional memory: 0 0 0 0 0 0| 0 0 0 0 0 0| + // iteration: prev | current | + // + // Now, let's assume there are two prefetches that need to be scheduled. For + // the sake of the example, assume 1 MiB of prefetch uses 1 memory bandwidth + // resource: + // - Prefetch 1 is 4 MiB and is first used at index 5. + // - Prefetch 2 is 5 MiB and is first used at index 1. + // + // We first order these prefetches by their first use from latest to earliest. + // Then starting from the prefetch completion time (i.e. the first use time), + // move the prefetch start time earlier until the copy resource is satisfied + // (or reaching another resource satisfaction criteria explained below) by + // consuming the bandwidth idle time of the overlapped instructions. We also + // keep track of the additional memory required. Note that index 5 also + // accounts for the additional 4 MiB consumed since the data needs to reside + // during the execution of the instruction at index 5. Below is the updated + // state after scheduling prefetch 1: + // + // prefetch 1: +====+ +====+ + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 2 2 1 1 0 1| 2 2 1 1 0 1| + // additional memory: 0 0 0 4 4 4| 0 0 0 4 4 4| + // iteration: prev | current | + // + // To schedule prefetch 2, we similarly start the same way, from its first use + // and bring the prefetch start earlier. We first reach index 0 with still an + // unsatisfied copy resource of 3: + // + // prefetch 2: +=+ +=+ unsat res: 3 + // prefetch 1: +====+ +====+ + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 0 2 1 1 0 1| 0 2 1 1 0 1| + // additional memory: 5 5 0 4 4 4| 5 5 0 4 4 4| + // iteration: prev | current | + // + // We continue onto the previous iteration: + // + // prefetch 2:===+ +====+ +== unsat res: 2 + // prefetch 1: +====+ +====+ + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 0 2 1 1 0 0| 0 2 1 1 0 0| + // additional memory: 5 5 0 4 4 9| 5 5 0 4 4 9| + // iteration: prev | current | + // + // As we bring the start time of prefetch 2 earlier, it starts overlapping + // with prefetch 1: + // + // prefetch 2:===+ +==========+ +======== unsat res: 1 + // prefetch 1: +====+ +====+ + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 0 2 1 0 0 0| 0 2 1 0 0 0| + // additional memory: 5 5 0 9 9 9| 5 5 0 9 9 9| + // iteration: prev | current | + // + // The prefetch resource is still unsatisfied at this point. We can bring the + // prefetch earlier. However, the first prefetch's end time is earlier than + // the second and we need to maintain FIFO order with regard to prefetches. In + // order to maintain this FIFO order, we "early force" prefetches that are + // already scheduled by moving the start time earlier along with prefetch 2: + // + // prefetch 2:===+ +=============+ +=========== + // prefetch 1: +=======+ +=======+ + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 0 2 0 0 0 0| 0 2 0 0 0 0| + // additional memory: 5 5 9 9 9 9| 5 5 9 9 9 9| + // iteration: prev | current | + // + // Depending on the options provided, we can use alternative resource + // satisfaction criteria. One option is to specify a percentage of the copy + // resource that needs to be satisfied instead of the complete amount (100%). + // This is called the "desired copy ratio". The reason why desired copy ratio + // can be less than 100% is that in a memory-bound loop, we probably do not + // have enough aggregate bandwidth resources to satisfy all of the prefetches, + // but using up all of the default-memory bandwidth is more important than + // having some prefetches with unsatisfied resources. In a similar vein, + // another option is to accept prefetches that are fully pipelined, i.e. + // their copy start time is scheduled the same time as the copy done time in + // the previous iteration, regardless of how much of its copy resources are + // actually satisfied. To illustrate a fully pipelined prefetch, consider + // prefetch 3 (assume no prefetch 1 or 2 in this example) which is 15 MiB and + // its first use is at index 4: + // + // prefetch 3:=============+=================+===== unsat res: 4 + // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| + // bw idle time: 0 0 0 0 0 0| 0 0 0 0 0 0| + // additional memory: 15 15 15 15 30 15|15 15 15 15 30 15| + // iteration: prev | current | + // + // Note that the additional memory consumption at index 4 is actually twice + // the size of the prefetch as we are effectively double buffering. Also note + // that the prefetch has an unsatisfied copy resource of 4 meaning the copy + // will be in the critical path, but this actually will be faster than not + // scheduling this particular prefetch in the first place since the bandwidth + // idle time resource would go unused. + float accumulated_copy_resource = 0; + std::vector early_forced_prefetch_value_indices; + int early_forced_prefetch_value_search_index = 0; + float early_forced_prefetch_additional_memory = 0; + for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; + --i) { + int loop_idx = (i + loop_size_) % loop_size_; + // Check if this prefetch rolls over to the previous iteration, check if any + // already-scheduled prefetches would violate the FIFO order, and if so, + // "early-force" them to be co-scheduled with this prefetch to maintain the + // FIFO order. This of course increases the required memory, so also keep + // track of additional memory that would be consumed. + if (i < 0) { + for (; context.value_indices[early_forced_prefetch_value_search_index] != + value_index; + ++early_forced_prefetch_value_search_index) { + VLOG(3) << "Searching for early forced: " + << early_forced_prefetch_value_search_index; + LoopValue* early_forced_value = context.values.at( + context.value_indices[early_forced_prefetch_value_search_index]); + if (early_forced_value->allocations.empty()) { + continue; + } + const CopyAllocation* early_forced_prefetch = + static_cast( + early_forced_value->allocations.back().get()); + VLOG(3) << "Prefetch: " << early_forced_prefetch->ToString(); + + // If the prefetch is already a roll-around prefetch, no need to further + // early force it. + if (early_forced_prefetch->copy_done_schedule_before() <= + early_forced_prefetch->copy_start_schedule_after() + 1 || + (early_forced_prefetch->copy_start_schedule_after() == + loop_size_ - 1 && + early_forced_prefetch->copy_done_schedule_before() == 0)) { + break; + } + if (early_forced_prefetch->copy_start_schedule_after() != loop_idx) { + break; + } + early_forced_prefetch_value_indices.push_back( + early_forced_prefetch_value_search_index); + early_forced_prefetch_additional_memory += early_forced_value->size; + VLOG(3) << "Found early-forced prefetch value: " + << early_forced_value->ToString(); + VLOG(3) << "Early forced prefetch additional memory: " + << early_forced_prefetch_additional_memory; + } + } + + // Overlap memory overhead only happens if the copy start overlaps with the + // first use (i.e. fully pipelined), so we'd need to account for 2X the + // buffer at this time. + int64_t overlap_memory_overhead = 0; + if (loop_idx == last_use_idx) { + overlap_memory_overhead = value->size; + VLOG(3) << "Loop idx == last use idx (" << loop_idx + << "), overlap memory overhead = " << overlap_memory_overhead; + } + + // OOM; give up prefetch. + if (context.additional_memory_used[loop_idx] + value->size + + overlap_memory_overhead + early_forced_prefetch_additional_memory > + remaining_memory_[loop_idx]) { + VLOG(3) << "Ran out of memory. Accumulated copy resource " + << accumulated_copy_resource << " out of " << copy_resource + << " at " << loop_idx; + break; + } + + // We ideally find a time to overlap the prefetch fully where the previous + // iteration's memory use is disjoint from this iteration. If that is not + // possible, there are two compromises we could pick: + // - Find a prefetch time that satisfies a desired ratio < 1 of the + // prefetch elapsed time. This means the prefetch will be critical. + // - Overlap the prefetch with the previous iteration's buffer use, i.e. + // full pipelining. This would increase the peak memory consumption. + float bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; + VLOG(3) << "Idx " << loop_idx + << " bandwidth_idle_time: " << bandwidth_idle_time + << " copy resource remaining: " + << (copy_resource - accumulated_copy_resource) << " diff: " + << (bandwidth_idle_time - + (copy_resource - accumulated_copy_resource)); + if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { + accumulated_copy_resource = copy_resource; + copy_start_time = loop_idx; + VLOG(3) << "Found the complete copy ratio and updated accumulated copy " + "resource: " + << accumulated_copy_resource; + break; + } else if (!copy_start_time && + accumulated_copy_resource + bandwidth_idle_time >= + copy_resource * options_.desired_copy_ratio()) { + accumulated_copy_resource += bandwidth_idle_time; + copy_start_time = loop_idx; + VLOG(3) << "Found the desired copy ratio and updated accumulated copy " + "resource: " + << accumulated_copy_resource; + } else if (options_.allow_unsatisfied_fully_pipelined_prefetch() && + loop_idx == last_use_idx) { + // Even if desired resource isn't reached, and if the options allow it, + // allow a fully pipelined prefetch. + accumulated_copy_resource += bandwidth_idle_time; + copy_start_time = loop_idx; + VLOG(3) << "Could not reach the desired copy ratio but scheduling " + "fully pipelined prefetch anyway: " + << accumulated_copy_resource; + break; + } else { + accumulated_copy_resource += bandwidth_idle_time; + VLOG(3) << "Updated accumulated copy resource: " + << accumulated_copy_resource; + } + } + + // Could not find a suitable copy start time. + if (!copy_start_time) { + return false; + } + + VLOG(3) << "Success: copy_start_time: " << *copy_start_time + << " leftover copy resource: " + << (copy_resource - accumulated_copy_resource); + auto update_additional_memory_used = [&](int loop_idx, int64_t addition) { + VLOG(4) << "Updating additional memory used at " << loop_idx << ". " + << context.additional_memory_used[loop_idx] << " + " << addition + << " => " << (context.additional_memory_used[loop_idx] + addition) + << " (remaining: " << remaining_memory_[loop_idx] << ")"; + context.additional_memory_used[loop_idx] += addition; + CHECK_LE(context.additional_memory_used[loop_idx], + remaining_memory_[loop_idx]); + }; + for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { + int loop_idx = i % loop_size_; + update_additional_memory_used(loop_idx, value->size); + } + // We reset accumulated copy resource and then reuse it to accumulate copy + // resource time in order to replay the previous for loop. It is important + // that we use the same arithmetic operations (as opposed to subtracting from + // copy_resource) because floating point operations aren't commutative. + accumulated_copy_resource = 0.0; + for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; + --i) { + int loop_idx = (i + loop_size_) % loop_size_; + float& bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; + // Overlap memory overhead only happens if the copy start overlaps with the + // first use (i.e. fully pipelined), so we'd need to account for 2X the + // buffer at this time. + int64_t overlap_memory_overhead = 0; + update_additional_memory_used(loop_idx, + value->size + overlap_memory_overhead); + if (bandwidth_idle_time < copy_resource - accumulated_copy_resource) { + accumulated_copy_resource += bandwidth_idle_time; + bandwidth_idle_time = 0; + if (loop_idx == *copy_start_time) { + VLOG(3) << "Remaining copy resource: " + << (copy_resource - accumulated_copy_resource); + break; + } + } else { + bandwidth_idle_time -= copy_resource - accumulated_copy_resource; + CHECK_EQ(loop_idx, *copy_start_time); + break; + } + } + + // Create the Allocation objects that correspond to the scheduled prefetch. + CHECK(value->header_position); + value->allocations.push_back(std::make_unique( + *value->header_position, MemorySpace::kDefault, std::nullopt, 0, + loop_size_, /*is_scoped_allocation=*/false)); + value->allocations.push_back(std::make_unique( + *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, + ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, + last_use_idx_sentinel)); + AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); + + // Account for the additional memory used by early forcing the already + // scheduled prefetches. Also modify the start times of these to this + // prefetch's copy start time. + for (int early_forced_prefetch_value_index : + early_forced_prefetch_value_indices) { + LoopValue* early_forced_value = context.values.at( + context.value_indices[early_forced_prefetch_value_index]); + CHECK(!early_forced_value->allocations.empty()); + CopyAllocation* early_forced_prefetch = static_cast( + early_forced_value->allocations.back().get()); + for (int index = early_forced_prefetch->copy_start_schedule_after(); + index >= *copy_start_time; --index) { + update_additional_memory_used(index, early_forced_value->size); + VLOG(3) << "Additional memory used: " << index << " " + << context.additional_memory_used[index]; + } + early_forced_prefetch->set_copy_start_schedule_after( + ((*copy_start_time - 1) + loop_size_) % loop_size_); + VLOG(3) << "Updated prefetch: " << early_forced_prefetch->ToString(); + } + return true; +} + +void MemoryBoundLoopOptimizer::AddAllLoopPositionsAndUses( + LoopValue& value, bool allocate_next_iteration_uses) { + CHECK_GE(value.allocations.size(), 1); + Allocation& allocation = *value.allocations.back(); + for (const auto& [idx, position] : value.loop_positions) { + positions_in_alternate_mem_[position.instruction].push_back(position.index); + } + for (const auto& [idx, use] : value.loop_uses) { + uses_in_alternate_mem_[use.instruction].push_back( + {use.operand_number, use.operand_index}); + allocation.AddUse(use); + } + if (allocate_next_iteration_uses) { + for (const auto& [next_iteration_idx, use] : value.next_iteration_uses) { + HloInstruction* loop_instruction = + hlo_live_range_.flattened_instruction_sequence().instructions().at( + loop_start_ + next_iteration_idx); + uses_in_alternate_mem_[loop_instruction].push_back( + {use.operand_number, use.operand_index}); + allocation.AddUse( + {loop_instruction, use.operand_number, use.operand_index}); + } + } +} + +float MemoryBoundLoopOptimizer::GetBandwidthIdleTime(int idx) const { + const HloInstruction* inst = + hlo_live_range_.flattened_instruction_sequence().instructions().at( + loop_start_ + idx); + std::vector> empty_operands; + std::vector empty_outputs; + const std::vector>* operands_in_alternate_mem = + &empty_operands; + const std::vector* outputs_in_alternate_mem = &empty_outputs; + auto uses_it = uses_in_alternate_mem_.find(inst); + if (uses_it != uses_in_alternate_mem_.end()) { + operands_in_alternate_mem = &uses_it->second; + } + auto positions_it = positions_in_alternate_mem_.find(inst); + if (positions_it != positions_in_alternate_mem_.end()) { + outputs_in_alternate_mem = &positions_it->second; + } + return cost_analysis_.GetDefaultMemoryBandwidthIdleTime( + *inst, *operands_in_alternate_mem, *outputs_in_alternate_mem); +} + +float MemoryBoundLoopOptimizer::GetBandwidthIdleTime( + int idx, + const absl::flat_hash_map>>& + additional_uses_in_alternate_mem, + const absl::flat_hash_map>& + additional_positions_in_alternate_mem) const { + const HloInstruction* inst = + hlo_live_range_.flattened_instruction_sequence().instructions().at( + loop_start_ + idx); + std::vector> operands_in_alternate_mem; + std::vector outputs_in_alternate_mem; + auto uses_it = uses_in_alternate_mem_.find(inst); + if (uses_it != uses_in_alternate_mem_.end()) { + operands_in_alternate_mem = uses_it->second; + } + auto additional_uses_it = additional_uses_in_alternate_mem.find(inst); + if (additional_uses_it != additional_uses_in_alternate_mem.end()) { + absl::c_copy(additional_uses_it->second, + std::back_inserter(operands_in_alternate_mem)); + } + auto positions_it = positions_in_alternate_mem_.find(inst); + if (positions_it != positions_in_alternate_mem_.end()) { + outputs_in_alternate_mem = positions_it->second; + } + auto additional_positions_it = + additional_positions_in_alternate_mem.find(inst); + if (additional_positions_it != additional_positions_in_alternate_mem.end()) { + absl::c_copy(additional_positions_it->second, + std::back_inserter(outputs_in_alternate_mem)); + } + return cost_analysis_.GetDefaultMemoryBandwidthIdleTime( + *inst, operands_in_alternate_mem, outputs_in_alternate_mem); +} + +float MemoryBoundLoopOptimizer::GetInstructionElapsed(int idx) const { + const HloInstruction* inst = + hlo_live_range_.flattened_instruction_sequence().instructions().at( + loop_start_ + idx); + std::vector> empty_operands; + std::vector empty_outputs; + const std::vector>* operands_in_alternate_mem = + &empty_operands; + const std::vector* outputs_in_alternate_mem = &empty_outputs; + auto uses_it = uses_in_alternate_mem_.find(inst); + if (uses_it != uses_in_alternate_mem_.end()) { + operands_in_alternate_mem = &uses_it->second; + } + auto positions_it = positions_in_alternate_mem_.find(inst); + if (positions_it != positions_in_alternate_mem_.end()) { + outputs_in_alternate_mem = &positions_it->second; + } + return cost_analysis_.GetInstructionElapsedInAlternateMemory( + *inst, *operands_in_alternate_mem, *outputs_in_alternate_mem); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h new file mode 100644 index 0000000000000..1f75ae693d34b --- /dev/null +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h @@ -0,0 +1,297 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +// An optimizer for unrolled memory-bound loops. It keeps track of alternate +// memory capacity and default memory bandwidth to decide the allocations of +// each tensor within a loop iteration. The assumption is that all of the +// unrolled loop iterations will use the same allocation decisions, so we can +// spend more time to optimize this one iteration as optimally as possible. +// +// To represent instructions, we keep track of three iterations (previous, +// current, and next), as well as the header and footer regions that are before +// and after the loop, respectively. +// +// We classify each tensor used in the current iteration as one of the following +// allocations based on its positions and uses: +// +// Temporary Allocations: These are produced by a producer in the current +// iteration and consumed either in this or the next iteration. For these, we +// try to give them alternate memory allocations for their entire live range. +// +// Case 1: producer and consumer all in the current iteration. +// p-----c--c +// Case 2: producer is in the current iter, consumer is in the next iter. +// p-----c +// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| +// iter: head |...| prev | current | next |...| foot +// +// Loop Carried Dependences: This is where the last use is at a larger index +// than the producer. This would require 2X peak buffer consumption because both +// this and next iteration's buffer is alive at the same time. This case is +// currently not supported. +// +// Case 3: producer is in the current iter, consumer is in the next iter +// (consumer idx >= producer idx). +// p-----------------c +// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| +// iter: head |...| prev | current | next |...| foot +// +// Pinned Allocations: These are values produced at the header and are used in +// every iteration at the same indices. For these, we just allocate the buffer +// for the duration of the loop: +// +// Case 4: producer: kHead, consumer: kCurrent +// p---------------c--------------c--------------c-------- +// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| +// iter: head |...| prev | current | next |...| foot +// +// Prefetch Allocations: These are values produced at the header and are used in +// the current (and possibly next) iteration. We will try to prefetch these +// values into the alternate memory: +// +// Case 5: producer: kHead, consumer: kCurrent +// p---------------------------------c--------c +// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| +// iter: head |...| prev | current | next |...| foot +class MemoryBoundLoopOptimizer { + public: + // We represent each tensor used in the current iteration as a LoopValue, + // wrapping the relevant information such as its HLO value, indices and + // pointers to its use and position sites in different iterations. + struct LoopValue { + // An enum that encodes the allocation type that is suitable for this + // LoopValue. See the comment above on what each of these mean. + enum class AllocationType { + kTemporary, + kLoopCarriedDependence, + kPinned, + kPrefetch, + kUnsupported + }; + + // ToString methods for logging/debugging. + static std::string AllocationTypeToString(AllocationType allocation_type); + std::string ToString() const; + + // Returns true if memory-bound loop optimizer supports allocating this type + // of a loop value. + bool IsAllocationTypeSupported() const; + + // The HloValues that correspond to this LoopValue. + std::vector hlo_values; + // The position in the header, if any. + std::optional header_position; + // The loop index and position in the previous and current iterations. + std::vector> prev_iteration_positions; + std::vector> loop_positions; + // The loop index and use in the current and next iterations. + std::vector> loop_uses; + std::vector> next_iteration_uses; + // The allocation type. + AllocationType allocation_type; + // Size of this tensor. + int64_t size; + // The default memory bandwidth savings were we to successfully put this in + // the alternate memory using the allocation type, in bytes. + float savings; + // The savings divided by the size. This is typically 2 for temporary + // allocations (skip a write and a read to the default memory). More complex + // production/consumption patterns may result in higher or lower values. We + // use this value to sort LoopValues so that the algorithm can prioritize + // allocating the buffers with the highest savings per byte to the alternate + // memory. + float savings_per_byte; + // The optimized AllocationSequence. + AllocationSequence allocations; + }; + + // Factory method to create and initialize a MemoryBoundLoopOptimizer. + static absl::StatusOr> Create( + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis_, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + + // Optimize the loop. Initialize must be called first. + void Optimize(); + + // Calculate the steady-state execution time of one loop iteration using the + // allocation decisions so far. + float CalculateExecutionTime() const; + + // Return the LoopValues. + const std::vector& loop_values() const { return loop_values_; } + std::vector& loop_values() { return loop_values_; } + + // Return the remaining memory vector for each point in time in the loop using + // the allocation decisions so far. + const std::vector& remaining_memory() const { + return remaining_memory_; + } + + // The loop start, end, and size accessors. + int loop_start() const { return loop_start_; } + int loop_end() const { return loop_end_; } + int loop_size() const { return loop_size_; } + + private: + // Temporary data structures used by the AllocatePrefetch function. + struct AllocatePrefetchesContext { + // The values that are requested to be prefetched. + absl::Span values; + + // A list of indices into values array, sorted by the start time of the + // first use. + std::vector value_indices; + + // Default memory remaining bandwidths assuming all prefetches succeeded. + std::vector bandwidth_idle_times; + + // Additional memory used while performing prefetching. + std::vector additional_memory_used; + }; + + MemoryBoundLoopOptimizer( + int loop_start, int loop_end, uint64_t alternate_memory_size, + const MemoryBoundLoopOptimizerOptions& options, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis_, + const CostAnalysis& cost_analysis, + const BufferValue::SizeFunction& size_function, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + + // Initializes the data structures used by the optimizer. + Status Initialize(); + + // Given an HloBuffer object, determines if this buffer represents a LoopValue + // that can be optimized by the optimizer, and if so it adds a LoopValue to + // the back of loop_values_ that represents the HloBuffer. Otherwise, no new + // LoopValue is added to loop_values_. + void MaybeCreateLoopValue(const HloBuffer& buffer, + const HloComputation* loop_computation); + + // Sort LoopValues by savings_per_byte. + void SortLoopValues(); + + // After allocation finishes, we fix up by creating Allocation objects to any + // LoopValues that didn't get alternate memory allocations. + void PostProcess(); + + // Allocate LoopValues by dispatching to the correct Allocate method. + void AllocateLoopValues(); + + // Allocate and reserve memory between the given indices. + bool AllocateBetween(int64_t begin_idx, int64_t end_idx, int64_t size); + + // Perform allocation type kTemporary. Return true if successful. + bool AllocateTemporary(LoopValue& value); + + // Perform allocation type kPinned. Return true if successful. + bool AllocatePinned(LoopValue& value); + + // Perform allocation type kPrefetch. Unlike the other Allocate methods, this + // performs allocation of multiple LoopValues in order to consider the effect + // of remaining bandwidth assuming the other prefetches were successful. + // Return true if successful. + bool AllocatePrefetches(absl::Span values); + + // Allocate one prefetch for the loop value index that corresponds to + // context.context.values. Returns true if successful. + bool AllocatePrefetch(int value_index, AllocatePrefetchesContext& context); + + // Keeps track of successful allocation of all uses and positions of this + // LoopValue. + void AddAllLoopPositionsAndUses(LoopValue& value, + bool allocate_next_iteration_uses); + + // Returns the default memory bandwidth idle time at the index. + float GetBandwidthIdleTime(int idx) const; + + // Returns the default memory bandwidth idle time at the index assuming the + // given uses and positions got alternate memory allocations. + float GetBandwidthIdleTime( + int idx, + const absl::flat_hash_map>>& + additional_uses_in_alternate_mem, + const absl::flat_hash_map>& + additional_positions_in_alternate_mem) const; + + // Returns the instruction elapsed at the index. + float GetInstructionElapsed(int idx) const; + + int loop_start_; + int loop_end_; + int loop_size_; + uint64_t alternate_memory_size_; + MemoryBoundLoopOptimizerOptions options_; + const HloLiveRange& hlo_live_range_; + const HloAliasAnalysis& alias_analysis_; + const CostAnalysis& cost_analysis_; + BufferValue::SizeFunction size_function_; + + absl::flat_hash_map instructions_in_loop_; + absl::flat_hash_map + instructions_in_prev_iteration_; + absl::flat_hash_map + instructions_in_next_iteration_; + std::vector loop_values_; + std::vector remaining_memory_; + absl::flat_hash_map>> + uses_in_alternate_mem_; + absl::flat_hash_map> + positions_in_alternate_mem_; + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn_; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc new file mode 100644 index 0000000000000..8403316ffdb2d --- /dev/null +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -0,0 +1,1381 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/memory_bound_loop_optimizer.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "re2/re2.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace memory_space_assignment { +namespace { + +constexpr int64_t kPointerSize = 8; + +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +int64_t SizeFunction(const BufferValue& value) { + return ShapeSize(value.shape()); +} + +int64_t ReservedScopedMemoryFn( + const HloInstruction* instruction, + const absl::flat_hash_set>& + operands_in_alternate_memory, + const absl::flat_hash_set& outputs_in_alternate_memory) { + return 0; +} + +class MemoryBoundLoopOptimizerTest : public HloTestBase { + public: + MemoryBoundLoopOptimizerTest() = default; + + protected: + const int64_t kAlternateMemorySpace = 1; + const int64_t kDefaultMemorySpace = 0; + + Status Initialize(const HloModule* module, + uint64_t alternate_memory_size = 256) { + HloCostAnalysis::Options options; + MemoryBoundLoopOptimizerOptions optimizer_options; + optimizer_options.set_enabled(true); + optimizer_options.set_desired_copy_ratio(0.7); + optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); + optimizer_options.set_min_num_iterations(3.0); + options_.memory_bound_loop_optimizer_options = optimizer_options; + cost_analysis_options_.alternate_mem_bandwidth_bytes_per_second = 128; + cost_analysis_options_.async_copy_bandwidth_bytes_per_second = 32; + cost_analysis_options_.pipeline_overhead_window_size_mib = 1; + options.shape_size = ShapeSize; + options.set_flops_per_second(16); + options.set_bytes_per_second(32); + options.set_transcendentals_per_second(16); + hlo_cost_analysis_ = std::make_unique(options); + TF_RETURN_IF_ERROR( + module->entry_computation()->Accept(hlo_cost_analysis_.get())); + TF_ASSIGN_OR_RETURN(cost_analysis_, + CostAnalysis::Create(*hlo_cost_analysis_, + cost_analysis_options_, *module)); + TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(live_range_, + HloLiveRange::Run(module->schedule(), *alias_analysis_, + module->entry_computation())); + return OkStatus(); + } + + absl::StatusOr CreateOptimizer( + int loop_start, int loop_end, const HloModule* module, + uint64_t alternate_memory_size = 256, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn = + ReservedScopedMemoryFn) { + TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size)); + MemoryBoundLoopOptimizerOptions optimizer_options; + optimizer_options.set_enabled(true); + optimizer_options.set_desired_copy_ratio(0.7); + optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); + TF_ASSIGN_OR_RETURN( + optimizer_, + MemoryBoundLoopOptimizer::Create( + loop_start, loop_end, alternate_memory_size, optimizer_options, + *live_range_, *alias_analysis_, *cost_analysis_, SizeFunction, + reserved_scoped_memory_fn)); + return optimizer_.get(); + } + + absl::StatusOr> ParseAndCreateOptimizer( + absl::string_view hlo_loop_str, uint64_t alternate_memory_size, + int& loop_start_idx, MemoryBoundLoopOptimizer** optimizer, + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn = + ReservedScopedMemoryFn) { + int loop_end_idx; + TF_ASSIGN_OR_RETURN( + std::string module_str, + ParseAndCreateModuleString(hlo_loop_str, loop_start_idx, loop_end_idx)); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + TF_ASSIGN_OR_RETURN( + *optimizer, + CreateOptimizer(loop_start_idx, loop_end_idx, module.get(), + alternate_memory_size, reserved_scoped_memory_fn)); + return std::move(module); + } + + // Parse a loop string description like the following: + // $op0 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $prev_op4) + // $op1 = f32[8,4] add(f32[8,4] $param1, f32[8,4] $prev_op3) + // $op2 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op0) + // $op3 = f32[8,4] add(f32[8,4] $param3, f32[8,4] $op1) + // $op4 = f32[1,4] add(f32[1,4] $param4, f32[1,4] $op2) + absl::StatusOr ParseAndCreateModuleString( + absl::string_view hlo_loop_str, int& loop_start_idx, int& loop_end_idx) { + // Parse op name and types first. + RE2 op_re("\\$op([0-9]+) += +(\\S+).*"); + std::vector ops; + std::vector op_types; + int begin_pos = 0; + absl::string_view submatch[3]; + while (op_re.Match(hlo_loop_str, begin_pos, hlo_loop_str.size(), + RE2::UNANCHORED, submatch, /*nsubmatch=*/3)) { + for (int i = 0; i < 3; ++i) { + if (submatch[i].data() == nullptr) { + VLOG(4) << "Submatch[" << i << "] = nullptr"; + } else { + VLOG(4) << "Submatch[" << i << "] = " << submatch[i] + << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) + << ")"; + } + } + int op_num; + if (!absl::SimpleAtoi(submatch[1], &op_num)) { + return InvalidArgument("Op name expects to contain a number, found %s.", + submatch[1]); + } + if (op_num != ops.size()) { + return InvalidArgument("Op number expected to be %d found %d.", + op_types.size(), op_num); + } + ops.push_back(submatch[0]); + op_types.push_back(submatch[2]); + begin_pos = submatch[0].data() - hlo_loop_str.data() + submatch[0].size(); + } + + RE2 param_re("([[:alnum:]]+\\[\\S*\\]) +\\$param([0-9]+)"); + std::vector param_types; + begin_pos = 0; + while (param_re.Match(hlo_loop_str, begin_pos, hlo_loop_str.size(), + RE2::UNANCHORED, submatch, /*nsubmatch=*/3)) { + for (int i = 0; i < 3; ++i) { + if (submatch[i].data() == nullptr) { + VLOG(4) << "Submatch[" << i << "] = nullptr"; + } else { + VLOG(4) << "Submatch[" << i << "] = " << submatch[i] + << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) + << ")"; + } + } + int param_num; + if (!absl::SimpleAtoi(submatch[2], ¶m_num)) { + return InvalidArgument( + "Param name expects to contain a number, found %s.", submatch[2]); + } + while (param_num >= param_types.size()) { + param_types.push_back({}); + } + param_types[param_num] = submatch[1]; + + begin_pos = submatch[0].data() - hlo_loop_str.data() + submatch[0].size(); + } + + RE2 root_re("ROOT \\$root += +tuple\\((.*)\\)"); + absl::string_view root_values; + if (root_re.Match(hlo_loop_str, 0, hlo_loop_str.size(), RE2::UNANCHORED, + submatch, /*nsubmatch=*/2)) { + for (int i = 0; i < 2; ++i) { + if (submatch[i].data() == nullptr) { + VLOG(4) << "Submatch[" << i << "] = nullptr"; + } else { + VLOG(4) << "Submatch[" << i << "] = " << submatch[i] + << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) + << ")"; + } + } + root_values = submatch[1]; + } + + for (absl::string_view op_type : op_types) { + VLOG(4) << "op_type: " << op_type; + } + for (absl::string_view param_type : param_types) { + VLOG(4) << "param_type: " << param_type; + } + + std::string hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY Entry { +)"; + int total_instructions = 0; + for (absl::string_view param_prefix : {"prev_", "", "next_"}) { + for (int i = 0; i < param_types.size(); ++i) { + int parameter_number = total_instructions; + absl::StrAppend(&hlo_string, " ", param_prefix, "param", i, " = ", + param_types[i], " parameter(", parameter_number, + ") // ", total_instructions++, "\n"); + } + } + + for (int i = 0; i < op_types.size(); ++i) { + int parameter_number = total_instructions; + absl::StrAppend(&hlo_string, " ", "prev_prev_op", i, " = ", op_types[i], + " parameter(", parameter_number, ") // ", + total_instructions++, "\n"); + } + + std::string new_root_values; + auto print_ops = + [&](const std::vector>& + replacements) { + for (int i = 0; i < ops.size(); ++i) { + absl::StrAppend(&hlo_string, " ", + absl::StrReplaceAll(ops[i], replacements), " // ", + total_instructions++, "\n"); + } + if (!root_values.empty()) { + absl::StrAppend(&new_root_values, + new_root_values.empty() ? "" : ", ", + absl::StrReplaceAll(root_values, replacements)); + } + }; + + std::vector> + prev_replacements; + prev_replacements.push_back({"$prev_op", "prev_prev_op"}); + prev_replacements.push_back({"$op", "prev_op"}); + prev_replacements.push_back({"$param", "prev_param"}); + absl::StrAppend(&hlo_string, " // Prev iteration body:\n"); + print_ops(prev_replacements); + + loop_start_idx = total_instructions; + std::vector> replacements; + replacements.push_back({"$", ""}); + absl::StrAppend(&hlo_string, " // Loop body:\n"); + print_ops(replacements); + loop_end_idx = total_instructions; + + std::vector> + next_replacements; + next_replacements.push_back({"$prev_op", "op"}); + next_replacements.push_back({"$op", "next_op"}); + next_replacements.push_back({"$param", "next_param"}); + absl::StrAppend(&hlo_string, " // Next iteration body:\n"); + print_ops(next_replacements); + + absl::StrAppend(&hlo_string, " ROOT root = tuple(", new_root_values, + ")\n"); + absl::StrAppend(&hlo_string, "}"); + + VLOG(1) << hlo_string; + return hlo_string; + } + + absl::StatusOr> RunMsa( + HloModule* module, uint64_t alternate_memory_size = 256) { + options_.max_size_in_bytes = alternate_memory_size; + options_.alignment_in_bytes = 8; + options_.verify = true; + + options_.alternate_memory_space = kAlternateMemorySpace; + + if (!cost_analysis_) { + TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size)); + } + CostAnalysis::Cache cache; + MemoryBoundednessBufferIntervalComparator comparator(*cost_analysis_, + &cache); + options_.buffer_interval_comparator = &comparator; + CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( + CostAnalysisPrefetchIntervalPicker( + *cost_analysis_, /*min_overlap_to_async_copy_ratio=*/0.8, + /*preferred_overlap_to_async_copy_ratio=*/1.5, + /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, + /*mem_size_bytes=*/alternate_memory_size)); + options_.prefetch_interval_picker = &prefetch_interval_picker; + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + options_.size_fn = size_fn; + + auto is_allowed_in_alternate_mem = [](const HloValue& value) { + // Check if the value belongs to the entry computation. + HloInstruction* instruction = value.instruction(); + HloComputation* computation = instruction->parent(); + bool in_entry_computation = + (computation == computation->parent()->entry_computation()); + if (in_entry_computation && + instruction->opcode() == HloOpcode::kParameter) { + return false; + } + return true; + }; + options_.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; + options_.max_outstanding_prefetches = -1; + options_.max_outstanding_evictions = -1; + options_.allocate_across_sequential_calls = true; + options_.cost_analysis = cost_analysis_.get(); + + std::unique_ptr preset_assignments = + MemorySpaceAssignment::Run(module, *live_range_, *alias_analysis_, + options_) + .value(); + return preset_assignments; + } + + Status VerifyMsaEquivalence(HloModule* module, + bool expect_unsupported_allocations = false) { + // Create a map indexed by instruction number and operand number. + absl::flat_hash_map, const Allocation*> allocation_map; + for (const MemoryBoundLoopOptimizer::LoopValue& value : + optimizer_->loop_values()) { + // Skip verification for unsupported allocations as they will go through + // the usual MSA algorithm and may actually get an alternate memory + // allocation. + if (!value.IsAllocationTypeSupported()) { + continue; + } + for (const auto& allocation : value.allocations) { + for (const HloUse& use : allocation->uses()) { + absl::string_view inst_name = use.instruction->name(); + TF_RET_CHECK(absl::StartsWith(inst_name, "op")); + int inst_number; + TF_RET_CHECK(absl::SimpleAtoi(inst_name.substr(2), &inst_number)); + allocation_map[{inst_number, use.operand_number}] = allocation.get(); + } + } + } + + auto get_inst_prefix_in_iter = [](int iteration) { + switch (iteration) { + case 0: + return "prev_"; + case 1: + return ""; + case 2: + return "next_"; + default: + LOG(FATAL) << "Invalid iteration " << iteration; + return "INVALID"; + } + }; + + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(std::unique_ptr live_range, + HloLiveRange::Run(module->schedule(), *alias_analysis, + module->entry_computation())); + const auto& flattened_instructions = + live_range->flattened_instruction_sequence().instructions(); + for (int iteration = 1; iteration < 3; ++iteration) { + for (int inst_number = 0; inst_number < optimizer_->loop_size(); + ++inst_number) { + HloInstruction* inst = FindInstruction( + module, absl::StrCat(get_inst_prefix_in_iter(iteration), "op", + inst_number)); + for (int operand_number = 0; operand_number < 2; ++operand_number) { + const HloInstruction* operand = inst->operand(operand_number); + LOG(INFO) << inst->name() << ", operand " << operand_number; + if (!allocation_map.contains({inst_number, operand_number})) { + TF_RET_CHECK(expect_unsupported_allocations); + continue; + } + const Allocation* allocation = + allocation_map.at({inst_number, operand_number}); + if (!allocation->is_copy_allocation()) { + // We don't expect a prefetch here. + EXPECT_NE(operand->opcode(), HloOpcode::kCopyDone); + int expected_memory_space = + allocation->memory_space() == MemorySpace::kDefault + ? kDefaultMemorySpace + : kAlternateMemorySpace; + EXPECT_EQ(operand->shape().layout().memory_space(), + expected_memory_space); + } else { + EXPECT_EQ(allocation->memory_space(), MemorySpace::kAlternate); + TF_RET_CHECK(operand->opcode() == HloOpcode::kCopyDone); + const CopyAllocation* copy_allocation = + static_cast(allocation); + if (copy_allocation->copy_done_schedule_before() != inst_number) { + // The only case where the copy done schedule before is not the + // same as this use would be that this use is not the first use of + // the copy allocation. + EXPECT_NE(allocation->uses().front(), + (HloUse{inst, operand_number})); + continue; + } + int expected_copy_start_iteration = iteration; + if (copy_allocation->copy_start_schedule_after() == + optimizer_->loop_size() && + copy_allocation->copy_done_schedule_before() == 0) { + expected_copy_start_iteration -= 2; + } else if (copy_allocation->copy_start_schedule_after() + 1 >= + copy_allocation->copy_done_schedule_before()) { + expected_copy_start_iteration -= 1; + } + + if (expected_copy_start_iteration >= 0) { + const HloInstruction* expected_copy_start_schedule_after = + FindInstruction( + module, + absl::StrCat( + get_inst_prefix_in_iter( + expected_copy_start_iteration), + "op", copy_allocation->copy_start_schedule_after())); + LOG(INFO) << "Expected copy start schedule after: " + << expected_copy_start_schedule_after->name(); + const HloInstruction* copy_start = operand->operand(0); + TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart); + // Find the instruction before this copy start that is not an + // async copy or gte or parameter. + int copy_start_idx = + live_range->instruction_schedule().at(copy_start); + const HloInstruction* copy_start_schedule_after = nullptr; + for (int i = copy_start_idx - 1; i >= 0; --i) { + HloOpcode opcode = flattened_instructions.at(i)->opcode(); + if (opcode != HloOpcode::kCopyStart && + opcode != HloOpcode::kCopyDone && + opcode != HloOpcode::kGetTupleElement && + opcode != HloOpcode::kParameter) { + copy_start_schedule_after = flattened_instructions.at(i); + break; + } + } + TF_RET_CHECK(copy_start_schedule_after != nullptr); + EXPECT_EQ(copy_start_schedule_after, + expected_copy_start_schedule_after); + } + } + } + } + } + return OkStatus(); + } + + private: + Options options_; + CostAnalysisOptions cost_analysis_options_; + std::unique_ptr hlo_cost_analysis_; + std::unique_ptr cost_analysis_; + std::unique_ptr alias_analysis_; + std::unique_ptr live_range_; + std::unique_ptr optimizer_; +}; + +TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) + ROOT $root = tuple($op4, $param0) + )"; + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/128, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + absl::flat_hash_set seen_uses; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + LOG(INFO) << loop_value.ToString(); + if (loop_value.hlo_values.front() + ->defining_position() + .instruction->name() == "param0") { + EXPECT_TRUE(loop_value.allocations.back()->is_copy_allocation()); + } + for (const auto& allocation : loop_value.allocations) { + for (const HloUse& use : allocation->uses()) { + EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); + seen_uses.insert(use); + } + } + } + + // Ensure all of the uses in the loop have an associated use. + for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { + HloInstruction* inst = + module->entry_computation()->GetInstructionWithName(inst_name); + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; + } +} + +// Specify a ReservedScopedMemoryFunction to the loop optimizer that causes each +// HLO to reserve the entire alternate memory. If the loop optimizer is +// correctly accounting for reserved scoped memory, it should not put any +// allocations in alternate memory, which we test. +TEST_F(MemoryBoundLoopOptimizerTest, ReservedScopedMemory) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) + ROOT $root = tuple($op4, $param0) + )"; + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndCreateOptimizer( + hlo_loop_str, + /*alternate_memory_size=*/128, loop_start_idx, &optimizer, + [](const HloInstruction*, + const absl::flat_hash_set>&, + const absl::flat_hash_set&) { return 128; })); + + optimizer->Optimize(); + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + LOG(INFO) << "Loop value: " << loop_value.ToString(); + for (const auto& allocation : loop_value.allocations) { + ASSERT_NE(static_cast(allocation->memory_space()), + kAlternateMemorySpace); + } + } +} + +// Check that a spurious GetTupleElement instruction in a later iteration of a +// loop does not cause MSA to CHECK fail, when identifying loops. Prior to the +// change instroduced with this test, IdentifyAndOptimizeMemoryBoundLoops() +// would recognize 4 iterations to the loop thinking that gte is a repeat of +// op2. Doing so triggers the CHECKs introduced by the change that added this +// test to fail. So, the point of this test is to verfiy that we do not check +// fail. +TEST_F(MemoryBoundLoopOptimizerTest, GetTupleElement) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[1,4] parameter(0) + p1 = f32[1,4] parameter(1) + p2 = f32[1,4] parameter(2) + p3 = f32[1,4] parameter(3) + p4 = f32[1,4] parameter(4) + p5 = f32[1,4] parameter(5) + p6 = f32[1,4] parameter(6) + tupleparam = (f32[1,4], f32[1,4]) parameter(7) + + // Iteration 0 + op1 = tanh(p0) + op2 = tanh(p1) + op3 = tanh(op2) + op4 = add(op1, op3) + + // Iteration 1 + op5 = tanh(p2) + op6 = tanh(p3) + op7 = tanh(op6) + op8 = add(op5, op7) + + // Iteration 2 + op9 = tanh(p4) + op10 = tanh(p5) + op11 = tanh(op10) + op12 = add(op9, op11) + + // Not an iteration + op13 = tanh(p6) + gte = get-tuple-element(tupleparam), index=1 + op14 = tanh(gte) + op15 = tanh(op14) + op16 = add(op13, op15) + + ROOT root = tuple(tupleparam, op4, op8, op12, op16) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + VLOG(1) << "Original module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, RunMsa(module.get())); +} + +TEST_F(MemoryBoundLoopOptimizerTest, NoAlternateMem) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) + ROOT $root = tuple($op4, $param0) + )"; + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + // Set alternate memory size to zero so nothing should be in the alternate + // memory. We still expect to find an allocation for all uses. + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/0, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + absl::flat_hash_set seen_uses; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + LOG(INFO) << loop_value.ToString(); + for (const auto& allocation : loop_value.allocations) { + EXPECT_EQ(allocation->memory_space(), MemorySpace::kDefault); + for (const HloUse& use : allocation->uses()) { + EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); + seen_uses.insert(use); + } + } + } + + // Ensure all of the uses in the loop have an associated use. + for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { + HloInstruction* inst = + module->entry_computation()->GetInstructionWithName(inst_name); + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; + } +} + +TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { + // Test for enforcing FIFO order of prefetches. There are three parameters + // that will be prefetched (param0, param1, and param2). param2 is one eighth + // the size of the other parameters and is scheduled later in the loop. So, we + // expect the allocation algorithm to initially allocate param2's prefetch + // with a short live range (since copying it doesn't take very long), but then + // as we try to prefetch param0 and param1, we will wrap around into the + // previous iterations and would need to "early force" param2's prefetch to be + // scheduled earlier to enforce the FIFO order. + // + // alternate_mem_bytes_per_second = 128 + // default_mem_bytes_per_second = 32 + // flops_per_second = 16 + // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 + // - All default memory elapsed: 1.5 + // - All alternate memory elapsed: 0.375 + // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 + // - All default memory elapsed: 12 + // - All alternate memory elapsed: 3 + // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 + // f32[8,4] copy: bytes: 128, memory elapsed: 4 + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) + $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) + $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) + $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) + $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) + $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) + $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) + $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) + $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) + $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) + $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/512, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + // We expect the prefetches to be scheduled this way: + // + // + // param0 or param1: + // ===========> =====================================> + // param1 or param0: + // ===========> === + // ==============================================> + // param2: + // =====> ========================================> === + // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 + // prev | loop | next + // + // Temporaries: + // +======+ + // +=========+ + // +=========+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +======+ + // +===+ + // +======+ + // +=========+ + // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 + // prev | loop | next + std::vector prefetches; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + if (!loop_value.allocations.empty() && + loop_value.allocations.back()->is_copy_allocation()) { + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); + } + } + EXPECT_EQ(prefetches.size(), 3); + bool seen_overlap = false; + bool seen_nonoverlap = false; + for (const CopyAllocation* prefetch : prefetches) { + const HloUse& use = *prefetch->uses().begin(); + if (use.instruction->name() == "op14") { + EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); + EXPECT_EQ(prefetch->copy_start_schedule_after(), 0); + } else { + ASSERT_EQ(use.instruction->name(), "op1"); + EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); + if (prefetch->copy_start_schedule_after() == 0) { + EXPECT_FALSE(seen_overlap); + seen_overlap = true; + } else { + EXPECT_GT(prefetch->copy_start_schedule_after(), 1); + EXPECT_FALSE(seen_nonoverlap); + seen_nonoverlap = true; + } + } + } + // We expect to fully saturate the default memory bandwidth. Total default + // memory accesses: + // param0 (128 B) + param1 (128 B) + op1 (128 B) + param2 (16 B) = 400 B + // execution time: + // 400 B / 32 B/s = 12.5 s. + EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); + + // Check the memory used at each point of the loop. + const std::vector& remaining_memory = optimizer->remaining_memory(); + // Time 0: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) + EXPECT_EQ(remaining_memory.at(0), 512 - (3 * 16 + 128 + 128)); + // Time 1: 2 temporaries (16 B) + 2*param0 (128 B) + param1 (128 B) + // + param2 (16 B) + EXPECT_EQ(remaining_memory.at(1), 512 - (2 * 16 + 2 * 128 + 128 + 16)); + // Times 2 and 3: 3 temporaries (16 B) + param0 (128 B) + param2 (16 B) + EXPECT_EQ(remaining_memory.at(2), 512 - (3 * 16 + 128 + 16)); + EXPECT_EQ(remaining_memory.at(3), 512 - (3 * 16 + 128 + 16)); + // Times 4 to 13: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) + // + param2 (16 B) + for (int i = 4; i <= 13; ++i) { + EXPECT_EQ(remaining_memory.at(i), 512 - (3 * 16 + 128 + 128 + 16)); + } + // Time 14: 2 temporaries (16 B) + param0 (128 B) + param1 (128 B) + // + param2 (16 B) + EXPECT_EQ(remaining_memory.at(14), 512 - (2 * 16 + 128 + 128 + 16)); +} + +TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { + // Same as the test above, except the size of alternate memory is less than + // 384, which is the minimum amount needed to keep the three 128-byte sized + // parameters alive (one of the parameters would need to be overlapped with + // the previous iteration, so counts 2X). In that case, we won't be able to + // fully saturate the bandwidth. + // + // alternate_mem_bytes_per_second = 128 + // default_mem_bytes_per_second = 32 + // flops_per_second = 16 + // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 + // - All default memory elapsed: 1.5 + // - All alternate memory elapsed: 0.375 + // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 + // - All default memory elapsed: 12 + // - All alternate memory elapsed: 3 + // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 + // f32[8,4] copy: bytes: 128, memory elapsed: 4 + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) + $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) + $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) + $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) + $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) + $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) + $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) + $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) + $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) + $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) + $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/350, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + // We expect the prefetches to be scheduled this way: + // + // + // param0 or param1: + // ===========> =====================================> + // param2: + // =====> ===============================> + // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 + // prev | loop | next + std::vector prefetches; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + if (!loop_value.allocations.empty() && + loop_value.allocations.back()->is_copy_allocation()) { + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); + } + } + EXPECT_EQ(prefetches.size(), 2); + std::optional expected_op14_copy_start_time; + for (const CopyAllocation* prefetch : prefetches) { + const HloUse& use = *prefetch->uses().begin(); + if (use.instruction->name() == "op1") { + EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); + EXPECT_GT(prefetch->copy_start_schedule_after(), 1); + expected_op14_copy_start_time = prefetch->copy_start_schedule_after(); + } + } + EXPECT_TRUE(expected_op14_copy_start_time.has_value()); + for (const CopyAllocation* prefetch : prefetches) { + const HloUse& use = *prefetch->uses().begin(); + if (use.instruction->name() == "op14") { + EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); + EXPECT_EQ(prefetch->copy_start_schedule_after(), + *expected_op14_copy_start_time); + } + } + // We expect not to fully saturate the default memory bandwidth. + EXPECT_GT(optimizer->CalculateExecutionTime(), 12.5); +} + +TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { + // Same as PrefetchFifoOrderWithOverlap, except the instructions are shifted + // earlier by one such that param0 and param1 are used by op0. This tests that + // we are accounting for overlaps for prefetches that span three iterations. + // + // alternate_mem_bytes_per_second = 128 + // default_mem_bytes_per_second = 32 + // flops_per_second = 16 + // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 + // - All default memory elapsed: 1.5 + // - All alternate memory elapsed: 0.375 + // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 + // - All default memory elapsed: 12 + // - All alternate memory elapsed: 3 + // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 + // f32[8,4] copy: bytes: 128, memory elapsed: 4 + absl::string_view hlo_loop_str = R"( + $op0 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op1 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) + $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) + $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) + $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) + $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) + $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) + $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) + $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) + $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) + $op13 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op12) + $op14 = f32[1,4] add(f32[1,4] $op12, f32[1,4] $op13) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/512, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + // We expect the prefetches to be scheduled this way: + // + // + // param0 or param1: + // ========> =====================================> === + // param1 or param0: + // ========> ====== + // ==============================================> + // param2: + // ==> ========================================> ====== + // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 + // prev | loop | next + std::vector prefetches; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + if (!loop_value.allocations.empty() && + loop_value.allocations.back()->is_copy_allocation()) { + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); + } + } + EXPECT_EQ(prefetches.size(), 3); + bool seen_overlap = false; + bool seen_nonoverlap = false; + for (const CopyAllocation* prefetch : prefetches) { + const HloUse& use = *prefetch->uses().begin(); + if (use.instruction->name() == "op13") { + EXPECT_EQ(prefetch->copy_done_schedule_before(), 13); + EXPECT_EQ(prefetch->copy_start_schedule_after(), 14); + } else { + ASSERT_EQ(use.instruction->name(), "op0"); + EXPECT_EQ(prefetch->copy_done_schedule_before(), 0); + if (prefetch->copy_start_schedule_after() == 14) { + EXPECT_FALSE(seen_overlap); + seen_overlap = true; + } else { + EXPECT_LT(prefetch->copy_start_schedule_after(), 14); + EXPECT_FALSE(seen_nonoverlap); + seen_nonoverlap = true; + } + } + } + // We expect to fully saturate the default memory bandwidth. Total default + // memory accesses: + // param0 (128 B) + param1 (128 B) + op1 (128 B) + param2 (16 B) = 400 B + // execution time: + // 400 B / 32 B/s = 12.5 s. + EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); +} + +TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) + $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) + $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) + $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) + $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) + $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) + $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) + $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) + $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) + $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) + $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) + ROOT $root = tuple($op1, $op14) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/1024, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, + RunMsa(module.get(), /*alternate_memory_size=*/1024)); + + TF_ASSERT_OK(VerifyMsaEquivalence(module.get())); +} + +TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndUnsupportedAllocation) { + // op2 is a loop-carried dependency, which is currently not supported. But the + // usual MSA algorithm should still be able to give it an alternate memory + // allocation. + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $prev_op2, f32[1,4] $op0) + $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + ROOT $root = tuple($op1, $op4) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/1024, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, + RunMsa(module.get(), /*alternate_memory_size=*/1024)); + + TF_ASSERT_OK(VerifyMsaEquivalence(module.get(), + /*expect_unsupported_allocations=*/true)); + + const HloInstruction* op2 = FindInstruction(module.get(), "op2"); + EXPECT_EQ(op2->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemoryBoundLoopOptimizerTest, TempAndPinnedAllocations) { + absl::string_view hlo_str = R"( + HloModule module, is_scheduled=true + + while_cond { + while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(while_cond_param), index=5 + } + + while_body { + while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + pinned_prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0 + next_param0 = f32[1,4] get-tuple-element(while_body_param), index=1 + prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=2 + prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=3 + prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) + prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) + prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) + prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) + prev_op4 = f32[1,4] multiply(f32[1,4] pinned_prev_param0, f32[1,4] prev_op3) + op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) + op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) + op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) + op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) + op4 = f32[1,4] multiply(f32[1,4] pinned_prev_param0, f32[1,4] op3) + next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) + next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) + next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) + next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) + next_op4 = f32[1,4] multiply(f32[1,4] pinned_prev_param0, f32[1,4] next_op3) + p = pred[] get-tuple-element(while_body_param), index=5 + ROOT root = tuple(pinned_prev_param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, p) + } + + ENTRY entry { + p0 = f32[1,4] parameter(0) + p1 = f32[1,4] parameter(1) + p2 = f32[1,4] parameter(2) + p3 = f32[1,4] parameter(3) + p4 = pred[] parameter(4) + copy = f32[1,4] copy(p3) + tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, copy, p4) + while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body + ROOT root = f32[1,4] get-tuple-element(while), index=4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); + + TF_ASSERT_OK_AND_ASSIGN(auto optimizer, + CreateOptimizer(19, 24, module.get(), + /*alternate_memory_size=*/512)); + optimizer->Optimize(); + + const std::vector& remaining_memory = optimizer->remaining_memory(); + // Time 0: 3 temporaries (16 B) + 1 pinned (16 B) + EXPECT_EQ(remaining_memory.at(0), 512 - (3 * 16 + 16)); + // Time 1: 3 temporaries (16 B) + 1 pinned (16 B) + EXPECT_EQ(remaining_memory.at(1), 512 - (3 * 16 + 16)); + // Time 2: 3 temporaries (16 B) + 1 pinned (16 B) + EXPECT_EQ(remaining_memory.at(2), 512 - (3 * 16 + 16)); + // Time 3: 3 temporaries (16 B) + 1 pinned (16 B) + EXPECT_EQ(remaining_memory.at(3), 512 - (3 * 16 + 16)); + // Time 4: 2 temporaries (16 B) + 1 pinned (16 B) + EXPECT_EQ(remaining_memory.at(4), 512 - (2 * 16 + 16)); +} + +TEST_F(MemoryBoundLoopOptimizerTest, NegativeSavingNotPinned) { + absl::string_view hlo_str = R"( + HloModule module, is_scheduled=true + + while_cond { + while_cond_param = (f32[28,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(while_cond_param), index=5 + } + + while_body { + while_body_param = (f32[28,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + pinned_prev_param0 = f32[28,4] get-tuple-element(while_body_param), index=0 + zero = s32[] constant(0) + next_param0 = f32[1,4] get-tuple-element(while_body_param), index=1 + prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=2 + prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=3 + prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) + prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) + prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) + prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) + pinned_slice = f32[1,4] dynamic-slice(pinned_prev_param0, zero, zero), dynamic_slice_sizes={1,4} + prev_op4 = f32[1,4] multiply(f32[1,4] pinned_slice, f32[1,4] prev_op3) + op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) + op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) + op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) + op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) + pinned_slice2 = f32[1,4] dynamic-slice(pinned_prev_param0, zero, zero), dynamic_slice_sizes={1,4} + op4 = f32[1,4] multiply(f32[1,4] pinned_slice2, f32[1,4] op3) + next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) + next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) + next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) + next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) + pinned_slice3 = f32[1,4] dynamic-slice(pinned_prev_param0, zero, zero), dynamic_slice_sizes={1,4} + next_op4 = f32[1,4] multiply(f32[1,4] pinned_slice3, f32[1,4] next_op3) + p = pred[] get-tuple-element(while_body_param), index=5 + ROOT root = tuple(pinned_prev_param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, p) + } + + ENTRY entry { + p0 = f32[28,4] parameter(0) + p1 = f32[1,4] parameter(1) + p2 = f32[1,4] parameter(2) + p3 = f32[1,4] parameter(3) + p4 = pred[] parameter(4) + copy = f32[1,4] copy(p3) + tuple = (f32[28,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, copy, p4) + while = (f32[28,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body + ROOT root = f32[1,4] get-tuple-element(while), index=4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); + + TF_ASSERT_OK_AND_ASSIGN(auto optimizer, + CreateOptimizer(21, 27, module.get(), + /*alternate_memory_size=*/512)); + optimizer->Optimize(); + + const std::vector& remaining_memory = optimizer->remaining_memory(); + // We expect that pinned_prev_param0 would not get pinned due to negative + // savings: 32(uses) - 28 * 16(size) = -416 Time 0: 3 temporaries (16 B) + 1 + // pinned (4 B) + EXPECT_EQ(remaining_memory.at(0), 512 - (3 * 16 + 4)); +} + +TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndWhileLoop) { + absl::string_view hlo_str = R"( +HloModule module, is_scheduled=true + +while_cond { + while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(while_cond_param), index=6 +} + +while_body { + while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0 + param0 = f32[1,4] get-tuple-element(while_body_param), index=1 + next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2 + prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3 + prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4 + prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) + prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) + prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) + prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) + prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_op3) + op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) + op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) + op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) + op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) + op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] op3) + next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) + next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) + next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) + next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) + next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_op3) + p = pred[] get-tuple-element(while_body_param), index=6 + ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, p) +} + +ENTRY entry { + p0 = f32[1,4] parameter(0) + p1 = f32[1,4] parameter(1) + p2 = f32[1,4] parameter(2) + p3 = f32[1,4] parameter(3) + p4 = f32[1,4] parameter(4) + p5 = pred[] parameter(5) + copy = f32[1,4] copy(p4) + tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5) + while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body + ROOT root = f32[1,4] get-tuple-element(while), index=5 +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); + + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, + RunMsa(module.get(), /*alternate_memory_size=*/512)); + + // We expect operand 0 of prev_op4, op4, and next_op4 to all be prefetches of + // same distance from the user. + TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, + HloAliasAnalysis::Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range, + HloLiveRange::Run(module->schedule(), *alias_analysis, + module->entry_computation())); + const HloInstruction* prev_copy_done = + FindInstruction(module.get(), "prev_op4")->operand(0); + const HloInstruction* copy_done = + FindInstruction(module.get(), "op4")->operand(0); + const HloInstruction* next_copy_done = + FindInstruction(module.get(), "next_op4")->operand(0); + ASSERT_EQ(prev_copy_done->opcode(), HloOpcode::kCopyDone); + ASSERT_EQ(copy_done->opcode(), HloOpcode::kCopyDone); + ASSERT_EQ(next_copy_done->opcode(), HloOpcode::kCopyDone); + EXPECT_EQ(prev_copy_done->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(copy_done->shape().layout().memory_space(), kAlternateMemorySpace); + EXPECT_EQ(next_copy_done->shape().layout().memory_space(), + kAlternateMemorySpace); + auto prefetch_distance = [&](const HloInstruction* copy_done) { + return hlo_live_range->instruction_schedule().at(copy_done) - + hlo_live_range->instruction_schedule().at(copy_done->operand(0)); + }; + EXPECT_EQ(prefetch_distance(prev_copy_done), prefetch_distance(copy_done)); + EXPECT_EQ(prefetch_distance(next_copy_done), prefetch_distance(copy_done)); +} + +TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndNestedWhileLoopBug) { + absl::string_view hlo_str = R"( +HloModule module, is_scheduled=true + +prev_while_cond { + prev_while_cond_param = (f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(prev_while_cond_param), index=1 +} + +prev_while_body { + prev_while_body_param = (f32[1,4], pred[]) parameter(0) + prev_while_body_gte = f32[1,4] get-tuple-element(prev_while_body_param), index=0 + prev_while_body_pred = pred[] get-tuple-element(prev_while_body_param), index=1 + prev_while_body_op = f32[1,4] negate(prev_while_body_gte) + ROOT prev_while_body_root = (f32[1,4], pred[]) tuple(prev_while_body_op, prev_while_body_pred) +} + +current_while_cond { + current_while_cond_param = (f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(current_while_cond_param), index=1 +} + +current_while_body { + current_while_body_param = (f32[1,4], pred[]) parameter(0) + current_while_body_gte = f32[1,4] get-tuple-element(current_while_body_param), index=0 + current_while_body_pred = pred[] get-tuple-element(current_while_body_param), index=1 + current_while_body_op = f32[1,4] negate(current_while_body_gte) + ROOT current_while_body_root = (f32[1,4], pred[]) tuple(current_while_body_op, current_while_body_pred) +} + +next_while_cond { + next_while_cond_param = (f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(next_while_cond_param), index=1 +} + +next_while_body { + next_while_body_param = (f32[1,4], pred[]) parameter(0) + next_while_body_gte = f32[1,4] get-tuple-element(next_while_body_param), index=0 + next_while_body_pred = pred[] get-tuple-element(next_while_body_param), index=1 + next_while_body_op = f32[1,4] negate(next_while_body_gte) + ROOT next_while_body_root = (f32[1,4], pred[]) tuple(next_while_body_op, next_while_body_pred) +} + +while_cond { + while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + ROOT p = pred[] get-tuple-element(while_cond_param), index=6 +} + +while_body { + while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) + prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0 + param0 = f32[1,4] get-tuple-element(while_body_param), index=1 + next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2 + prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3 + prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4 + while_pred = pred[] get-tuple-element(while_body_param), index=6 + prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) + prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) + prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) + prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) + prev_tuple = (f32[1,4], pred[]) tuple(prev_op3, while_pred) + prev_while = (f32[1,4], pred[]) while(prev_tuple), condition=prev_while_cond, body=prev_while_body + prev_gte = f32[1,4] get-tuple-element(prev_while), index=0 + prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_gte) + op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) + op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) + op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) + op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) + current_tuple = (f32[1,4], pred[]) tuple(op3, while_pred) + current_while = (f32[1,4], pred[]) while(current_tuple), condition=current_while_cond, body=current_while_body + current_gte = f32[1,4] get-tuple-element(current_while), index=0 + op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] current_gte) + next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) + next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) + next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) + next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) + next_tuple = (f32[1,4], pred[]) tuple(next_op3, while_pred) + next_while = (f32[1,4], pred[]) while(next_tuple), condition=next_while_cond, body=next_while_body + next_gte = f32[1,4] get-tuple-element(next_while), index=0 + next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_gte) + ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, while_pred) +} + +ENTRY entry { + p0 = f32[1,4] parameter(0) + p1 = f32[1,4] parameter(1) + p2 = f32[1,4] parameter(2) + p3 = f32[1,4] parameter(3) + p4 = f32[1,4] parameter(4) + p5 = pred[] parameter(5) + copy = f32[1,4] copy(p4) + tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5) + while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body + ROOT root = f32[1,4] get-tuple-element(while), index=5 +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); + + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, + RunMsa(module.get(), /*alternate_memory_size=*/512)); +} + +} // namespace +} // namespace memory_space_assignment +} // namespace xla diff --git a/xla/service/memory_space_assignment/memory_space_assignment.cc b/xla/service/memory_space_assignment/memory_space_assignment.cc index 85196bb522582..7d1274f69f261 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/memory_space_assignment/memory_space_assignment.h" #include -#include #include #include #include @@ -26,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -39,16 +37,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" -#include "absl/functional/function_ref.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -57,25 +52,34 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_bound_loop_optimizer.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/options.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" #include "xla/service/memory_space_assignment/tuning_utils.h" #include "xla/service/memory_space_assignment/utils.h" #include "xla/service/time_utils.h" -#include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -89,13 +93,6 @@ const HeapSimulator::Chunk kDummyChunk = // For cross-program prefetched buffer, we only perform the freeing optimization // if the buffer occupies less of the execution time ratio than this value. const float kCrossProgramPrefetchOccupyFreeingLimit = 0.6; -// Each time we retry compilation, increase the preferred eviction end time by -// this amount multiplied by preferred overlap to async copy ratio. -const float kEvictionRetryMultiplier = 2.0; -// The number of decreasing intervals for CostAnalysisPrefetchIntervalPicker to -// return when it runs out of increasing intervals. Increasing this number may -// hurt compilation time. -const int kNumExploredDecreasingIntervals = 100; template std::string VectorToString(const std::vector& v, @@ -229,6 +226,17 @@ std::vector FindCrossProgramPrefetchUses( bool IsCrossProgramPrefetchCandidate(const HloValue& value, const HloAliasAnalysis& alias_analysis, const Options& options) { + // Filter out values that alias with the entry computation root. + const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value); + const HloInstruction* root = alias_analysis.dataflow_analysis() + .module() + .entry_computation() + ->root_instruction(); + for (const HloPosition& position : buffer.ComputePositions()) { + if (position.instruction == root) { + return false; + } + } std::vector uses = FindCrossProgramPrefetchUses(value.GetUses(), alias_analysis); return value.defining_instruction()->parent() == @@ -264,16 +272,15 @@ struct CrossProgramPrefetchBufferSortValues { int64_t use_size = 0; }; -std::vector -FindCrossProgramPrefetchCandidates(const HloAliasAnalysis& alias_analysis, - const HloLiveRange& hlo_live_range, - const Options& options) { - std::vector candidates; +std::vector FindCrossProgramPrefetchCandidates( + const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, + const Options& options) { + std::vector candidates; for (const HloBuffer& buffer : alias_analysis.buffers()) { CHECK_GE(buffer.values().size(), 1); const HloValue* value = buffer.values().at(0); if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) { - MemorySpaceAssignment::BufferInterval interval; + MsaBufferInterval interval; interval.buffer = value; interval.size = options.size_fn(*value); interval.start = 0; @@ -286,7 +293,7 @@ FindCrossProgramPrefetchCandidates(const HloAliasAnalysis& alias_analysis, DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator( hlo_live_range); - MemorySpaceAssignment::BufferIntervalComparator* comparator = + BufferIntervalComparator* comparator = (options.default_cross_program_prefetch_heuristic && options.buffer_interval_comparator ? options.buffer_interval_comparator @@ -336,19 +343,8 @@ Status InsertInstructionAndEnsureOperandsInserted( return OkStatus(); } -std::string UsesToString(const std::vector& uses) { - if (uses.empty()) { - return "none"; - } - std::vector uses_str; - uses_str.reserve(uses.size()); - for (const auto& use : uses) { - uses_str.push_back(use.ToString()); - } - return absl::StrJoin(uses_str, ","); -} - -StatusOr GetScheduleTimeFromInstructionName( +absl::StatusOr +GetScheduleTimeFromInstructionName( absl::string_view name, const absl::flat_hash_map& schedule) { @@ -361,81 +357,170 @@ StatusOr GetScheduleTimeFromInstructionName( name); } -StatusOr GetFilterResult( - const std::pair& - filter, - int64_t operand_size, const HloUse& hlo_use) { - switch (filter.first) { - case FilterUpdatePreferredPrefetch::FilterType::OP_SIZE_GTE: - return FilterUpdatePreferredPrefetch::IsOpSizeGte(operand_size, - filter.second); - case FilterUpdatePreferredPrefetch::FilterType::OP_SIZE_LTE: - return FilterUpdatePreferredPrefetch::IsOpSizeLte(operand_size, - filter.second); - case FilterUpdatePreferredPrefetch::FilterType::INSTRUCTION_NAME_EXACT: - return FilterUpdatePreferredPrefetch::IsInstructionNameExact( - hlo_use.instruction->name(), filter.second); - case FilterUpdatePreferredPrefetch::FilterType::OP_NUMBER_EXACT: - return FilterUpdatePreferredPrefetch::IsOpNumberExact( - hlo_use.operand_number, filter.second); - case FilterUpdatePreferredPrefetch::FilterType::OP_INDEX_EXACT: - return FilterUpdatePreferredPrefetch::IsOpIndexExact( - hlo_use.operand_index, filter.second); - default: - return InvalidArgument("Unknown filter type."); - } -} - -StatusOr> GetOverriddenPreferredPrefetchTime( - const std::vector& - filter_update_preferred_prefetches, +bool DoesOperandMatchFilter(const HloOperandFilter& filter, + int64_t operand_size, const HloUse& hlo_use) { + if (filter.has_size_gte() && operand_size < filter.size_gte()) { + return false; + } + if (filter.has_size_lte() && operand_size > filter.size_lte()) { + return false; + } + if (filter.has_operand_number() && + hlo_use.operand_number != filter.operand_number()) { + return false; + } + if (filter.has_instruction_name_regex() && + !RE2::FullMatch(hlo_use.instruction->name(), + filter.instruction_name_regex())) { + return false; + } + if (filter.has_tuple_index() && + hlo_use.operand_index != ShapeIndex(filter.tuple_index().index().begin(), + filter.tuple_index().index().end())) { + return false; + } + return true; +} + +absl::StatusOr> GetPrefetchTimeByEagerness( + float prefetch_eagerness, int64_t earliest_prefetch_time, + int64_t latest_prefetch_time) { + CHECK_GE(prefetch_eagerness, 0.0); + CHECK_LE(prefetch_eagerness, 1.0); + if (earliest_prefetch_time > latest_prefetch_time) { + return static_cast>(std::nullopt); + } + return static_cast>( + earliest_prefetch_time + + (latest_prefetch_time - earliest_prefetch_time) * prefetch_eagerness); +} + +absl::StatusOr> GetPrefetchTimeAfterInstruction( + const std::string& after_instruction_name, + const absl::flat_hash_map& schedule) { + TF_ASSIGN_OR_RETURN( + auto reference_instruction_time, + GetScheduleTimeFromInstructionName(after_instruction_name, schedule)); + return static_cast>(reference_instruction_time); +} + +absl::StatusOr> GetPrefetchTimeBeforeInstruction( + const std::string& before_instruction_name, + const absl::flat_hash_map& schedule) { + TF_ASSIGN_OR_RETURN( + auto reference_instruction_time, + GetScheduleTimeFromInstructionName(before_instruction_name, schedule)); + return static_cast>(reference_instruction_time - 1); +} + +absl::StatusOr> GetPrefetchTime( + const PreferredPrefetchOverrideOptions& override_options, + int64_t earliest_prefetch_time, int64_t latest_prefetch_time, + const absl::flat_hash_map& + instruction_schedule) { + switch (override_options.options_case()) { + case PreferredPrefetchOverrideOptions::kPrefetchEagerness: + return GetPrefetchTimeByEagerness(override_options.prefetch_eagerness(), + earliest_prefetch_time, + latest_prefetch_time); + case PreferredPrefetchOverrideOptions::kAfterInstructionName: + return GetPrefetchTimeAfterInstruction( + override_options.after_instruction_name(), instruction_schedule); + case PreferredPrefetchOverrideOptions::kBeforeInstructionName: + return GetPrefetchTimeBeforeInstruction( + override_options.before_instruction_name(), instruction_schedule); + case PreferredPrefetchOverrideOptions::OPTIONS_NOT_SET: + break; + } + return static_cast>>(std::nullopt); +} + +absl::StatusOr> GetOverriddenPreferredPrefetchTime( + const PreferredPrefetchOverrides& preferred_prefetch_overrides, int64_t operand_size, const HloUse& hlo_use, const absl::flat_hash_map& instruction_schedule, int64_t earliest_prefetch_time, int64_t latest_prefetch_time) { - for (const auto& filter_update_preferred_prefetch : - filter_update_preferred_prefetches) { - bool match = true; - for (const auto& filter : filter_update_preferred_prefetch.filter_list_) { - TF_ASSIGN_OR_RETURN(auto filter_result, - GetFilterResult(filter, operand_size, hlo_use)); - match &= filter_result; - } - if (match) { - LOG(INFO) << "Config " << filter_update_preferred_prefetch.ToString() - << " match for instruction " << hlo_use.instruction->name() - << " operand number " << hlo_use.operand_number - << " operand index " << hlo_use.operand_index.ToString() - << " size " << operand_size << " live range (" - << earliest_prefetch_time << ", " << latest_prefetch_time - << ")"; - switch (filter_update_preferred_prefetch.override_type_) { - case FilterUpdatePreferredPrefetch::OverrideType::PREFETCH_EAGERNESS: - return filter_update_preferred_prefetch.GetPrefetchByEagerness( - earliest_prefetch_time, latest_prefetch_time); - case FilterUpdatePreferredPrefetch::OverrideType::PUT_AFTER_INSTRUCTION: - return filter_update_preferred_prefetch - .GetPrefetchTimeAfterInstruction(instruction_schedule); - case FilterUpdatePreferredPrefetch::OverrideType:: - PUT_BEFORE_INSTRUCTION: - return filter_update_preferred_prefetch - .GetPrefetchTimeBeforeInstruction(instruction_schedule); - default: - return InvalidArgument("Unknown override type."); - } + for (const auto& override : preferred_prefetch_overrides.overrides()) { + if (!DoesOperandMatchFilter(override.hlo_operand_filter(), operand_size, + hlo_use)) { + continue; + } + LOG(INFO) << "Config match for instruction " << hlo_use.instruction->name() + << " operand number " << hlo_use.operand_number + << " operand index " << hlo_use.operand_index.ToString() + << " size " << operand_size << " live range (" + << earliest_prefetch_time << ", " << latest_prefetch_time << ")"; + TF_ASSIGN_OR_RETURN( + auto prefetch_time, + GetPrefetchTime(override.override_options(), earliest_prefetch_time, + latest_prefetch_time, instruction_schedule)); + if (prefetch_time.has_value() && + prefetch_time.value() >= earliest_prefetch_time && + prefetch_time.value() <= latest_prefetch_time) { + return prefetch_time; } } return static_cast>>(std::nullopt); } +bool DoesResultMatchFilter(const HloPositionMatcher& filter, + const ShapeIndex& index, + HloInstruction* instruction) { + if (filter.has_instruction_regex() && + !RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) { + return false; + } + if (filter.has_instruction_name_regex() && + !RE2::FullMatch(instruction->name(), filter.instruction_name_regex())) { + return false; + } + if (filter.has_tuple_index() && + index != ShapeIndex(filter.tuple_index().index().begin(), + filter.tuple_index().index().end())) { + return false; + } + return true; +} + +// Returns an integer representing the priority of a BufferInterval during +// assignment, a smaller number indicates a higher priority. +int64_t GetBufferIntervalOverridePriority( + const MsaSortOrderOverrides& msa_sort_order_overrides, + const BufferInterval& buffer_interval) { + if (msa_sort_order_overrides.overrides_size() == 0) { + return 0; + } + for (int64_t i = 0; i < msa_sort_order_overrides.overrides_size(); ++i) { + const auto& override = msa_sort_order_overrides.overrides(i); + if (!DoesResultMatchFilter(override.hlo_position_matcher(), + buffer_interval.buffer->index(), + buffer_interval.buffer->instruction())) { + continue; + } + LOG(INFO) << "Override Sort Order Config " << i << " matches " + << buffer_interval.buffer->instruction()->ToString(); + switch (override.override_options().options_case()) { + case MsaSortOrderOverrideOptions::kAssignFirst: + return std::numeric_limits::lowest() + i; + case MsaSortOrderOverrideOptions::kAssignLast: + return std::numeric_limits::max() - i; + case MsaSortOrderOverrideOptions::OPTIONS_NOT_SET: + continue; + } + } + return 0; +} + std::tuple GetAllocationSortTuple( - const std::unique_ptr& allocation) { + const std::unique_ptr& allocation) { int64_t scheduled_on_or_before = allocation->start_time(); int64_t scheduled_on_or_after = allocation->start_time(); if (allocation->is_copy_allocation()) { auto copy_allocation = - tensorflow::down_cast( - allocation.get()); + tensorflow::down_cast(allocation.get()); scheduled_on_or_before = copy_allocation->copy_done_schedule_before(); scheduled_on_or_after = copy_allocation->copy_start_schedule_after(); } @@ -444,25 +529,20 @@ std::tuple GetAllocationSortTuple( scheduled_on_or_after); } -void SortAllocationSequence( - MemorySpaceAssignment::AllocationSequence& allocations) { - absl::c_sort( - allocations, - [](const std::unique_ptr& lhs, - const std::unique_ptr& rhs) { - return GetAllocationSortTuple(lhs) < GetAllocationSortTuple(rhs); - }); +void SortAllocationSequence(AllocationSequence& allocations) { + absl::c_sort(allocations, [](const std::unique_ptr& lhs, + const std::unique_ptr& rhs) { + return GetAllocationSortTuple(lhs) < GetAllocationSortTuple(rhs); + }); } -std::string AllocationSequenceToString( - MemorySpaceAssignment::AllocationSequence& allocations, - bool sort_allocations = false) { +std::string AllocationSequenceToString(AllocationSequence& allocations, + bool sort_allocations = false) { if (sort_allocations) { SortAllocationSequence(allocations); } std::string allocations_str = "\n"; - for (const std::unique_ptr& allocation : - allocations) { + for (const std::unique_ptr& allocation : allocations) { absl::StrAppend(&allocations_str, allocation->ToString(), "\n"); } return allocations_str; @@ -486,15 +566,12 @@ std::string InstructionScheduleToString(const HloLiveRange& hlo_live_range) { return instruction_schedule_str; } -void EnsureParentAllocationIsAvailableForCopy( - MemorySpaceAssignment::CopyAllocation* copy_allocation) { - MemorySpaceAssignment::Allocation& parent_allocation = - copy_allocation->mutable_prev_allocation(); +void EnsureParentAllocationIsAvailableForCopy(CopyAllocation* copy_allocation) { + Allocation& parent_allocation = copy_allocation->mutable_prev_allocation(); parent_allocation.Extend(copy_allocation->copy_done_schedule_before()); if (parent_allocation.is_copy_allocation()) { auto parent_copy_allocation = - tensorflow::down_cast( - &parent_allocation); + tensorflow::down_cast(&parent_allocation); parent_copy_allocation->set_copy_done_schedule_before( std::min(parent_copy_allocation->copy_done_schedule_before(), copy_allocation->start_time())); @@ -504,8 +581,8 @@ void EnsureParentAllocationIsAvailableForCopy( } } -void MakeCopyAllocationJitForSingleUse( - MemorySpaceAssignment::CopyAllocation* copy_allocation, int64_t use_time) { +void MakeCopyAllocationJitForSingleUse(CopyAllocation* copy_allocation, + int64_t use_time) { copy_allocation->set_start_time(use_time - 1); copy_allocation->set_copy_start_schedule_after(use_time - 1); copy_allocation->set_end_time(use_time); @@ -517,28 +594,24 @@ int64_t GetUseTime(const HloUse& use, const HloLiveRange& hlo_live_range) { return hlo_live_range.instruction_schedule().at(use.instruction); } -std::vector -GetAllocationSequenceInRawPointers( - MemorySpaceAssignment::AllocationSequence& allocations) { - std::vector allocations_in_raw_pointers; - for (const std::unique_ptr& allocation : - allocations) { +std::vector GetAllocationSequenceInRawPointers( + AllocationSequence& allocations) { + std::vector allocations_in_raw_pointers; + for (const std::unique_ptr& allocation : allocations) { allocations_in_raw_pointers.push_back(allocation.get()); } return allocations_in_raw_pointers; } -void ProcessPrefetchesToAlternateMemory( - MemorySpaceAssignment::AllocationSequence& allocations, - const HloLiveRange& hlo_live_range) { - std::vector allocations_in_raw_pointers = +void ProcessPrefetchesToAlternateMemory(AllocationSequence& allocations, + const HloLiveRange& hlo_live_range) { + std::vector allocations_in_raw_pointers = GetAllocationSequenceInRawPointers(allocations); for (auto allocation : allocations_in_raw_pointers) { if (allocation->is_copy_allocation() && allocation->is_in_alternate_mem() && !allocation->uses().empty()) { - MemorySpaceAssignment::CopyAllocation* prefetch = - tensorflow::down_cast( - allocation); + CopyAllocation* prefetch = + tensorflow::down_cast(allocation); std::vector uses = prefetch->uses(); // Create a copy of uses. prefetch->clear_uses(); // Clear old uses. // For every prefetch, update prefetch to serve earliest use just in time. @@ -550,11 +623,9 @@ void ProcessPrefetchesToAlternateMemory( for (size_t use_index = 1; use_index < uses.size(); ++use_index) { const HloUse& use = uses[use_index]; int64_t use_time = GetUseTime(use, hlo_live_range); - auto jit_single_use_prefetch = - std::make_unique( - prefetch->mutable_prev_allocation(), - MemorySpaceAssignment::MemorySpace::kAlternate, - prefetch->chunk(), use_time - 1, use_time, use_time); + auto jit_single_use_prefetch = std::make_unique( + prefetch->mutable_prev_allocation(), MemorySpace::kAlternate, + prefetch->chunk(), use_time - 1, use_time, use_time); jit_single_use_prefetch->set_copy_start_schedule_after(use_time - 1); jit_single_use_prefetch->AddUse(use); EnsureParentAllocationIsAvailableForCopy(jit_single_use_prefetch.get()); @@ -564,28 +635,21 @@ void ProcessPrefetchesToAlternateMemory( } } -void MakeEvictionImmediate(MemorySpaceAssignment::CopyAllocation* eviction) { - const MemorySpaceAssignment::Allocation& parent_allocation = - eviction->prev_allocation(); +void MakeEvictionImmediate(CopyAllocation* eviction) { + const Allocation& parent_allocation = eviction->prev_allocation(); eviction->set_start_time(parent_allocation.start_time()); eviction->set_copy_start_schedule_after(parent_allocation.start_time()); eviction->set_copy_done_schedule_before(parent_allocation.start_time() + 1); eviction->Extend(parent_allocation.start_time() + 1); } -absl::flat_hash_map -GetEvictionsMap(std::vector& allocations) { - absl::flat_hash_map - evictions_map; +absl::flat_hash_map GetEvictionsMap( + std::vector& allocations) { + absl::flat_hash_map evictions_map; for (auto& allocation : allocations) { if (allocation->is_copy_allocation() && allocation->is_in_default_mem()) { - auto eviction = - tensorflow::down_cast( - allocation); - MemorySpaceAssignment::Allocation& parent_allocation = - eviction->mutable_prev_allocation(); + auto eviction = tensorflow::down_cast(allocation); + Allocation& parent_allocation = eviction->mutable_prev_allocation(); if (!parent_allocation.is_copy_allocation()) { evictions_map[&parent_allocation] = eviction; } @@ -595,15 +659,13 @@ GetEvictionsMap(std::vector& allocations) { } void ProcessBuffersProducedInAlternateMemory( - MemorySpaceAssignment::AllocationSequence& allocations, - const HloLiveRange& hlo_live_range) { - std::vector allocations_in_raw_pointers = + AllocationSequence& allocations, const HloLiveRange& hlo_live_range) { + std::vector allocations_in_raw_pointers = GetAllocationSequenceInRawPointers(allocations); // For all parent allocations produced in alternate memory, create a map from // parent allocation -> eviction. - absl::flat_hash_map - evictions_map = GetEvictionsMap(allocations_in_raw_pointers); + absl::flat_hash_map evictions_map = + GetEvictionsMap(allocations_in_raw_pointers); // Make all such evictions immediate. for (auto& [_, eviction] : evictions_map) { MakeEvictionImmediate(eviction); @@ -629,22 +691,19 @@ void ProcessBuffersProducedInAlternateMemory( continue; } if (!evictions_map.contains(allocation)) { - auto eviction_unique_ptr = - std::make_unique( - *allocation, MemorySpaceAssignment::MemorySpace::kDefault, - std::nullopt, allocation->start_time(), - allocation->start_time() + 1, allocation->start_time() + 1); + auto eviction_unique_ptr = std::make_unique( + *allocation, MemorySpace::kDefault, std::nullopt, + allocation->start_time(), allocation->start_time() + 1, + allocation->start_time() + 1); eviction_unique_ptr->set_copy_start_schedule_after( allocation->start_time()); evictions_map[allocation] = eviction_unique_ptr.get(); allocations.push_back(std::move(eviction_unique_ptr)); } - MemorySpaceAssignment::CopyAllocation* eviction = - evictions_map[allocation]; - auto jit_single_use_prefetch = - std::make_unique( - *eviction, MemorySpaceAssignment::MemorySpace::kAlternate, - allocation->chunk(), use_time - 1, use_time, use_time); + CopyAllocation* eviction = evictions_map[allocation]; + auto jit_single_use_prefetch = std::make_unique( + *eviction, MemorySpace::kAlternate, allocation->chunk(), + use_time - 1, use_time, use_time); jit_single_use_prefetch->set_copy_start_schedule_after(use_time - 1); jit_single_use_prefetch->AddUse(use); EnsureParentAllocationIsAvailableForCopy(jit_single_use_prefetch.get()); @@ -654,9 +713,8 @@ void ProcessBuffersProducedInAlternateMemory( } } -void TransformAllocationSequenceToSpill( - MemorySpaceAssignment::AllocationSequence& allocations, - const HloLiveRange& hlo_live_range) { +void TransformAllocationSequenceToSpill(AllocationSequence& allocations, + const HloLiveRange& hlo_live_range) { VLOG(2) << "InstructionSchedule before transform\n"; XLA_LOG_LINES(2, InstructionScheduleToString(hlo_live_range)); VLOG(2) << "AllocationSequence before transform\n"; @@ -669,1060 +727,9 @@ void TransformAllocationSequenceToSpill( XLA_LOG_LINES(2, AllocationSequenceToString(allocations, true)); SortAllocationSequence(allocations); } -} // namespace - -/*static*/ StatusOr> -MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis, - const Options& options, - const HloModule& module) { - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); - TF_ASSIGN_OR_RETURN(auto hlo_live_range, - HloLiveRange::Run(module.schedule(), *alias_analysis, - module.entry_computation())); - auto call_graph = CallGraph::Build(&module); - return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph))); -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - float elapsed_time_due_to_compute = - GetInstructionElapsedDueToCompute(instruction); - float elapsed_time_due_to_memory = - GetInstructionElapsedDueToMemory(instruction); - if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { - // Memory bound, return how much alternate memory is better. - float while_nest_multiplier; - if (cache) { - // If there is a cache provided, memoize the while nest multiplier. - auto it = cache->while_nest_multiplier.find(&instruction); - if (it != cache->while_nest_multiplier.end()) { - while_nest_multiplier = it->second; - } else { - while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - CalculateComputationNestLevel(&instruction, - /*while_only=*/true)); - cache->while_nest_multiplier[&instruction] = while_nest_multiplier; - } - } else { - while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - CalculateComputationNestLevel(&instruction, - /*while_only=*/true)); - } - return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * - while_nest_multiplier; - } else { - // Compute bound, return how far off are we to memory boundedness. - return elapsed_time_due_to_memory - elapsed_time_due_to_compute; - } -} -float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - if (cache) { - auto it = - cache->memory_boundedness.find(interval.buffer->defining_position()); - if (it != cache->memory_boundedness.end()) { - return it->second; - } - } - float alternate_mem_benefit = - GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache); - - for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt( - interval.buffer->defining_position().instruction, - interval.buffer->defining_position().index)) { - for (const HloValue* value : buffer->values()) { - for (const HloUse& use : value->GetUses()) { - // We look inside the called computations of while and conditional, so - // don't use the benefit of while and conditional directly. - if (use.instruction->opcode() == HloOpcode::kWhile || - use.instruction->opcode() == HloOpcode::kConditional) { - continue; - } - float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache); - // If the benefit is positive (memory bound), add it to this buffer's - // benefit. If the benefit is negative (compute bound), calculate the - // maximum. - if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { - alternate_mem_benefit += use_alternate_mem_benefit; - } else { - alternate_mem_benefit = - std::max(alternate_mem_benefit, use_alternate_mem_benefit); - } - } - } - } - - // Penalize larger buffers by dividing the benefit by the square root of - // the size. Empirically, we observed this resulted in better performance - // compared to dividing by the size. - float memory_boundedness = 1; - if (options_ - .xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers == - "NO_SCALE") { - memory_boundedness = alternate_mem_benefit; - } else { - memory_boundedness = alternate_mem_benefit / std::sqrt(interval.size); - } - - if (cache) { - cache->memory_boundedness[interval.buffer->defining_position()] = - memory_boundedness; - } - return memory_boundedness; -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloPosition& position, - MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - return GetAlternateMemoryBenefit( - *position.instruction, - GetInstructionElapsedDueToMemory( - *position.instruction, - /*operands_in_alternate_mem=*/{}, - /*outputs_in_alternate_mem=*/{position.index}), - cache); -} - -float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloUse& use, MemorySpaceAssignmentCostAnalysis::Cache* cache) const { - return GetAlternateMemoryBenefit( - *use.instruction, - GetInstructionElapsedDueToMemory( - *use.instruction, - /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number, - use.operand_index)}), - cache); -} - -int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel( - const HloInstruction* instruction, bool while_only) const { - int nest_level = 0; - const HloComputation* computation = instruction->parent(); - while (!computation->IsEntryComputation()) { - auto& node = call_graph_->GetNode(computation); - auto callsites = node.caller_callsites(); - CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1) - << "The module is not flattened!"; - auto& callsite = callsites[0]; - if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { - ++nest_level; - } - computation = callsite.instruction()->parent(); - } - return nest_level; -} - -float MemorySpaceAssignmentCostAnalysis::GetDefaultMemoryAccessOverhead( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - // Calculate the pipeline overhead of accessing the default memory. We use the - // maximum of the window size heuristic and the actual default memory bytes - // accessed multiplied with the compute as the overhead. So, the math is: - // - // overhead = compute_per_iteration - // = compute_elapsed / num_iterations - // = compute_elapsed / (bytes_accessed / window_size) - // = (window_size / bytes_accessed) * compute_elapsed - const float window_size_bytes = - options_.pipeline_overhead_window_size_mib * 1024 * 1024; - const float bytes_accessed = cost_analysis_.bytes_accessed(instruction); - const float default_memory_bytes_accessed = - bytes_accessed - - GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - const float compute_elapsed = GetInstructionElapsedDueToCompute(instruction); - const float effective_window_size_bytes = - std::min(window_size_bytes, default_memory_bytes_accessed); - float overhead = 0; - if (bytes_accessed > 0) { - overhead = (effective_window_size_bytes / bytes_accessed) * compute_elapsed; - } - return overhead; -} - -float MemorySpaceAssignmentCostAnalysis::GetDefaultMemoryBandwidthIdleTime( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - const float default_memory_bytes_accessed = - cost_analysis_.bytes_accessed(instruction) - - GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - const float elapsed_due_to_default_mem = - default_memory_bytes_accessed / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - const float elapsed = GetInstructionElapsedInAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - return elapsed - elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetBytesAccessedFromAlternateMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - float bytes_accessed_from_alternate_mem = 0.0; - for (auto& operand : operands_in_alternate_mem) { - const float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed( - instruction, operand.first, operand.second); - bytes_accessed_from_alternate_mem += operand_bytes_accessed; - } - - for (auto& shape_idx : outputs_in_alternate_mem) { - const float output_bytes_accessed = - cost_analysis_.output_bytes_accessed(instruction, shape_idx); - bytes_accessed_from_alternate_mem += output_bytes_accessed; - } - return bytes_accessed_from_alternate_mem; -} - -namespace { -// Returns true on async instructions since we assume they are already -// efficiently scheduled such that they are not in the critical path and appear -// to take no time. -bool ExcludeInstructionFromElapsed(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kAllGatherStart || - instruction.opcode() == HloOpcode::kAllGatherDone || - instruction.opcode() == HloOpcode::kAllReduceStart || - instruction.opcode() == HloOpcode::kAllReduceDone || - instruction.opcode() == HloOpcode::kAsyncStart || - instruction.opcode() == HloOpcode::kAsyncDone || - instruction.opcode() == HloOpcode::kCollectivePermuteStart || - instruction.opcode() == HloOpcode::kCollectivePermuteDone || - instruction.opcode() == HloOpcode::kCopyStart || - instruction.opcode() == HloOpcode::kCopyDone; -} } // namespace -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( - const HloInstruction& instruction) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - return std::max( - cost_analysis_.flop_count(instruction) / - cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), - cost_analysis_.transcendental_count(instruction) / - cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey)); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); - float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - float elapsed_due_to_alternate_mem = - bytes_accessed_from_alternate_mem / - options().alternate_mem_bandwidth_bytes_per_second; - float elapsed_due_to_default_mem = - (total_bytes_accessed - bytes_accessed_from_alternate_mem) / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); - float bytes_accessed_from_alternate_mem = 0.0; - for (int operand_num = 0; operand_num < instruction.operand_count(); - ++operand_num) { - ShapeUtil::ForEachSubshape( - instruction.operand(operand_num)->shape(), - [&](const Shape& subshape, const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - if (is_in_alternate_mem(operand_num, index, subshape)) { - bytes_accessed_from_alternate_mem += - cost_analysis_.operand_bytes_accessed(instruction, operand_num, - index); - } - }); - } - ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape, - const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) { - bytes_accessed_from_alternate_mem += - cost_analysis_.output_bytes_accessed(instruction, index); - } - }); - float elapsed_due_to_alternate_mem = - bytes_accessed_from_alternate_mem / - options().alternate_mem_bandwidth_bytes_per_second; - float elapsed_due_to_default_mem = - (total_bytes_accessed - bytes_accessed_from_alternate_mem) / - cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); - return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem; -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( - const HloInstruction& instruction) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float overhead = GetDefaultMemoryAccessOverhead(instruction); - return std::max(GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction) + overhead); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - float overhead = GetDefaultMemoryAccessOverhead( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - return std::max( - GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem, - outputs_in_alternate_mem) + - overhead); -} - -float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const { - if (ExcludeInstructionFromElapsed(instruction)) { - return 0.0f; - } - return std::max( - GetInstructionElapsedDueToCompute(instruction), - GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem)); -} - -float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( - const Shape& shape) const { - int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape); - return static_cast(size_in_bytes) / - (options().async_copy_bandwidth_bytes_per_second * - options().async_copy_bandwidth_scaling_factor); -} - -int64_t MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const { - return hlo_live_range_->schedule_end_time(); -} - -bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( - const Shape& shape, int64_t start_time, int64_t end_time) const { - return end_time - start_time <= max_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( - const Shape& shape, int64_t start_time, int64_t latest_end_time) const { - return std::min(start_time + min_overlap_count_, latest_end_time); -} - -int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( - const Shape& shape, int64_t start_time, int64_t end_time, - const HloUse* use) const { - return end_time - min_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { - return std::max(earliest_prefetch_start_time, - prefetch_end_time - max_overlap_count_); -} - -int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime( - const Shape& shape, int64_t start_time, int64_t end_time) const { - // For testing, assume the end time is the estimated prefetch end time. - return end_time; -} - -float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed( - int64_t start_time, int64_t end_time) const { - // For testing, just assume every HLO takes 1 second. - return static_cast(end_time - start_time - 1); -} - -void InstructionCountPrefetchIntervalPicker::Begin( - const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) { - end_time_ = end_time; - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); - if (preferred_time) { - current_prefetch_time_ = *preferred_time; - } else { - current_prefetch_time_ = - PreferredPrefetchStartTime(shape, start_time, end_time, end_time); - } -} - -int64_t InstructionCountPrefetchIntervalPicker::Next() { - CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " - "Done() is false"; - return current_prefetch_time_++; -} - -bool InstructionCountPrefetchIntervalPicker::Done() const { - return end_time_ - current_prefetch_time_ <= min_overlap_count_; -} - -int64_t InstructionCountPrefetchIntervalPicker::latest_time() const { - return end_time_ - min_overlap_count_ - 1; -} - -std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { - return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); -} - -std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( - const Shape& shape, int64_t start_time, int64_t end_time) const { - return absl::StrCat("Overlapped HLOs = ", end_time - start_time); -} - -CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - float min_overlap_to_async_copy_ratio, - float preferred_overlap_to_async_copy_ratio, - float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, - const Shape* shape_override) - : while_nest_level_( - cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), - computation_nest_level_( - cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), - cost_analysis_(cost_analysis), - min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio), - preferred_overlap_to_async_copy_ratio_( - preferred_overlap_to_async_copy_ratio), - max_async_copy_elapsed_( - cost_analysis_.GetAsyncCopyElapsed( - ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) * - max_overlap_to_mem_size_async_copy_ratio), - shape_override_(shape_override ? std::optional(*shape_override) - : std::nullopt) { - instruction_schedule_ = - &cost_analysis_.hlo_live_range().instruction_schedule(); - - // Create a vector of elapsed times and while nesting levels of HLO - // instructions. The elapsed times are multiplied by - // pow(while_execution_count, nest_level) to account for executing the HLOs - // multiple times in while loops. - std::vector instructions_elapsed_time( - instruction_schedule_->size() + 1, 0.0); - int max_while_nest_level = 0; - for (const auto& instruction_and_logical_time : *instruction_schedule_) { - // To avoid double counting, don't include the elapsed time of while and - // conditional HLOs. - const HloInstruction* instruction = instruction_and_logical_time.first; - int64_t logical_time = instruction_and_logical_time.second; - if (logical_time >= instructions_elapsed_time.size()) { - instructions_elapsed_time.resize(logical_time + 1, 0.0); - while_nest_level_.resize(logical_time + 1, 0); - } - int while_nest_level = cost_analysis_.CalculateComputationNestLevel( - instruction_and_logical_time.first, /*while_only=*/true); - while_nest_level_[logical_time] = while_nest_level; - max_while_nest_level = std::max(max_while_nest_level, while_nest_level); - int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( - instruction_and_logical_time.first, /*while_only=*/false); - computation_nest_level_[logical_time] = computation_nest_level; - if (instruction->opcode() == HloOpcode::kWhile || - instruction->opcode() == HloOpcode::kConditional) { - continue; - } - float elapsed_time = cost_analysis_.GetInstructionElapsed( - *instruction_and_logical_time.first); - instructions_elapsed_time[logical_time] = - elapsed_time * - IPow(cost_analysis_.options() - .xla_tpu_memory_space_assignment_while_execution_count, - while_nest_level); - } - // As an optimization, create a cumulative sum vector of elapsed time. - float cumsum = 0.0; - elapsed_time_cumsum_.reserve(instructions_elapsed_time.size()); - for (float elapsed_time : instructions_elapsed_time) { - cumsum += elapsed_time; - elapsed_time_cumsum_.push_back(cumsum); - } - // To be able to accurately determine the minimum nest level between a start - // time and an end time efficiently, populate a data structure that stores the - // closest 'smaller' nest level change index. - const int64_t size = instructions_elapsed_time.size(); - CHECK_EQ(size, while_nest_level_.size()); - std::vector most_recent_by_level(while_nest_level_.size(), -1); - int prev_nest_level = 0; - int change_idx = -1; - while_nest_level_change_.reserve(size); - for (int i = 0; i < size; ++i) { - int nest_level = while_nest_level_[i]; - if (nest_level != prev_nest_level) { - prev_nest_level = nest_level; - // Compute last change index by choosing the most recent instruction index - // with smaller nesting level. Note that it may happen that even though - // there were few different regions with other nest levels before, all of - // then are same or bigger than this one, in which case we'll end up with - // -1, e.g. if you got nest level 0 no need checking anything else. - change_idx = -1; - for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) { - change_idx = std::max(change_idx, most_recent_by_level[smaller_level]); - } - } - most_recent_by_level[nest_level] = i; - while_nest_level_change_.push_back(change_idx); - } - for (int i = 0; i <= max_while_nest_level; ++i) { - while_execution_counts_.push_back( - IPow(cost_analysis_.options() - .xla_tpu_memory_space_assignment_while_execution_count, - i)); - } -} - -float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory( - float async_copy_elapsed) const { - return max_async_copy_elapsed_; -} - -bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( - const Shape& shape, int64_t start_time, int64_t end_time) const { - // Even though this method returns if we allow the buffer in alternate memory - // _without_ asynchronous copies, calculate how long it would have taken to - // copy it and compare it to the elapsed time in the logical interval. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - return GetMaxElapsedInAlternateMemory(async_copy_elapsed) > - logical_interval_elapsed; -} - -int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( - const Shape& shape, int64_t start_time, int64_t latest_end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t end_time; - for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - if (logical_interval_elapsed >= - (1 + kEvictionRetryMultiplier * retry_number_) * - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) { - break; - } - } - return end_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( - const Shape& shape, int64_t start_time, int64_t end_time, - const HloUse* use) const { - // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - // If there is a use, estimate the time we would save by having this op in - // alternate memory. - float inst_elapsed_reduction = 0.0f; - if (use) { - float elapsed_time = - cost_analysis_.GetInstructionElapsed(*use->instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use->instruction, - /*operands_in_alternate_mem=*/ - {std::make_pair(use->operand_number, use->operand_index)}, - /*outputs_in_alternate_mem=*/{}); - inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; - } - int end_nest_level = computation_nest_level_[end_time]; - - // Find the latest time we're allowed to start prefetching. - float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed; - int latest_prefetch_time; - for (latest_prefetch_time = end_time - 1; - latest_prefetch_time >= start_time && - (computation_nest_level_[latest_prefetch_time] != end_nest_level || - min_interval > - GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + - inst_elapsed_reduction); - --latest_prefetch_time) { - } - - return latest_prefetch_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { - // Between the earliest and latest prefetch interval, find the interval - // closest to the preferred interval and start iterating from there. - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t preferred_prefetch_start_time = earliest_prefetch_start_time; - float preferred_interval = - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed; - float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, - prefetch_end_time); - int end_nest_level = computation_nest_level_[prefetch_end_time]; - for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1; - prefetch_start_time <= latest_prefetch_start_time; - ++prefetch_start_time) { - float interval = - GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); - if (computation_nest_level_[prefetch_start_time] == end_nest_level && - std::abs(preferred_interval - interval) < - std::abs(preferred_interval - best_interval)) { - best_interval = interval; - preferred_prefetch_start_time = prefetch_start_time; - } - } - return preferred_prefetch_start_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const { - // Iterate towards the beginning until we find a suitable end time that is the - // same while nest level as the original prefetch end time. - int64_t original_nest_level = - computation_nest_level_[original_prefetch_end_time]; - int64_t new_prefetch_end_time; - for (new_prefetch_end_time = proposed_prefetch_end_time; - computation_nest_level_[new_prefetch_end_time] != original_nest_level; - --new_prefetch_end_time) { - } - return new_prefetch_end_time; -} - -int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime( - const Shape& shape, int64_t start_time, int64_t end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - int64_t estimated_end_time; - for (estimated_end_time = start_time + 1; estimated_end_time < end_time; - ++estimated_end_time) { - float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time); - if (interval >= async_copy_elapsed) { - break; - } - } - return estimated_end_time; -} - -void CostAnalysisPrefetchIntervalPicker::Begin( - const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) { - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); - // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. - async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - // Estimate the time we would save by having this op in alternate memory. - float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use.instruction, /*operands_in_alternate_mem=*/ - {std::make_pair(use.operand_number, use.operand_index)}, - /*outputs_in_alternate_mem=*/{}); - inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; - end_logical_time_ = end_time; - int end_nest_level = computation_nest_level_[end_logical_time_]; - - // Find the latest time we're allowed to start prefetching. - float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_; - latest_prefetch_time_ = - LatestPrefetchStartTime(shape, start_time, end_time, &use); - - // Find the earliest time we're allowed to start prefetching. - float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_); - for (earliest_prefetch_time_ = start_time; - earliest_prefetch_time_ < latest_prefetch_time_ && - (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || - max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, - end_logical_time_)); - ++earliest_prefetch_time_) { - } - if (earliest_prefetch_time_ > latest_prefetch_time_) { - // There is no available prefetch interval for the given start and end - // times. Set the iterators accordingly to ensure Done() returns true. - increasing_prefetch_time_iterator_ = earliest_prefetch_time_; - decreasing_prefetch_time_iterator_ = latest_prefetch_time_; - CHECK(Done()); - return; - } - - int64_t starting_prefetch_time; - if (preferred_time && *preferred_time <= latest_prefetch_time_) { - starting_prefetch_time = *preferred_time; - } else { - starting_prefetch_time = - PreferredPrefetchStartTime(shape, earliest_prefetch_time_, - latest_prefetch_time_, end_logical_time_); - } - float preferred_interval = - preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_; - VLOG(4) << "Interval min/max/preferred = " << min_interval << " " - << max_interval << " " << preferred_interval - << " prefetch time earliest/latest/starting = " - << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " " - << starting_prefetch_time; - - increasing_prefetch_time_iterator_ = starting_prefetch_time; - decreasing_prefetch_time_iterator_ = starting_prefetch_time; - using_increasing_prefetch_time_iterator_ = true; - // Since both iterators start at the same position, call Next() once to - // advance one of the iterators. - Next(); -} - -int64_t CostAnalysisPrefetchIntervalPicker::Next() { - CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " - "Done() is false"; - if (using_increasing_prefetch_time_iterator_) { - int64_t prefetch_time = increasing_prefetch_time_iterator_++; - while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && - computation_nest_level_[increasing_prefetch_time_iterator_] != - computation_nest_level_[end_logical_time_]) { - ++increasing_prefetch_time_iterator_; - } - if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { - using_increasing_prefetch_time_iterator_ = false; - } - return prefetch_time; - } else { - int64_t prefetch_time = decreasing_prefetch_time_iterator_--; - // As a compilation time optimization, reduce the number of intervals that - // this prefetch interval picker returns. When we run out of the increasing - // prefetch time iterator, only explore up to - // kNumExploredDecreasingIntervals intervals. To do that, calculate the - // 1/kNumExploredDecreasingIntervals of the elapsed time between the - // earliest prefetch time and the use, and decrement the iterator until the - // prefetch elapsed time is at least as large as this target value. This - // allows us to reduce the number of expensive heap fit and resource checks - // when the graph consists of a large number of fast-executing HLOs. - // - // Shown pictorially, assuming kNumExploredDecreasingIntervals = 3 and the - // numbers indicating the elapsed time of the HLOs, only the indicated - // options for prefetch start time would be explored: - // - // ---1---1---3---1---1---1---1---0---0---0---0---1---5---X - // ^ ^ ^ ^ - // Option3 Option2 Option1 Use - // (Earliest) - float next_target_interval_elapsed = 0; - if (increasing_prefetch_time_iterator_ > latest_prefetch_time_) { - next_target_interval_elapsed = - GetLogicalIntervalElapsed(prefetch_time, end_logical_time_) + - (GetLogicalIntervalElapsed(earliest_prefetch_time_, - end_logical_time_) / - kNumExploredDecreasingIntervals); - VLOG(3) << "Next target interval elapsed: " - << next_target_interval_elapsed; - } - while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && - (computation_nest_level_[decreasing_prefetch_time_iterator_] != - computation_nest_level_[end_logical_time_] || - GetLogicalIntervalElapsed(decreasing_prefetch_time_iterator_, - end_logical_time_) < - next_target_interval_elapsed)) { - --decreasing_prefetch_time_iterator_; - } - if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { - using_increasing_prefetch_time_iterator_ = true; - } - return prefetch_time; - } -} - -bool CostAnalysisPrefetchIntervalPicker::Done() const { - return increasing_prefetch_time_iterator_ > latest_prefetch_time_ && - decreasing_prefetch_time_iterator_ < earliest_prefetch_time_; -} - -int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const { - return latest_prefetch_time_; -} - -void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { - retry_number_ = retry_number; -} - -int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( - int64_t start_time, int64_t end_time) const { - int min_nest_level = - std::min(while_nest_level_[start_time], while_nest_level_[end_time]); - int change_idx = while_nest_level_change_[end_time]; - while (change_idx >= start_time) { - min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]); - change_idx = while_nest_level_change_[change_idx]; - } - return min_nest_level; -} - -float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( - int64_t start_time, int64_t end_time) const { - CHECK_LE(start_time, end_time); - if (start_time == end_time) { - return 0.0; - } - if (start_time < 0) { - start_time = 0; - } - // Since elapsed_time_cumsum_ is already weighed by the while loop nesting - // level, normalize the elapsed time by dividing with the nesting factor of - // the interval (start and end times). - int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); - return (elapsed_time_cumsum_[end_time - 1] - - elapsed_time_cumsum_[start_time]) / - while_execution_counts_[interval_while_nest_level]; -} - -std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { - int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ - ? increasing_prefetch_time_iterator_ - : decreasing_prefetch_time_iterator_; - float logical_interval_elapsed = GetLogicalIntervalElapsed( - current_logical_prefetch_time, end_logical_time_); - return absl::StrCat( - "Async copy elapsed (s) = ", async_copy_elapsed_, - ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, - ", logical interval elapsed (s) = ", logical_interval_elapsed, - ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, - ")"); -} - -std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( - const Shape& shape, int64_t start_time, int64_t end_time) const { - float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( - shape_override_ ? *shape_override_ : shape); - float logical_interval_elapsed = - GetLogicalIntervalElapsed(start_time, end_time); - return absl::StrCat( - "Async copy elapsed (s) = ", async_copy_elapsed, - ", logical interval elapsed (s) = ", logical_interval_elapsed); -} - -std::optional -CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const { - return cost_analysis_.GetMemoryBoundedness(interval); -} - -/*static*/ StatusOr> -FilterUpdatePreferredPrefetch::ParseFilterUpdatePreferredPrefetches( - std::string config) { - if (config.empty()) { - return std::vector(); - } - std::vector filter_update_prefetches; - std::vector filter_update_configs = absl::StrSplit(config, ';'); - for (const auto& config : filter_update_configs) { - TF_ASSIGN_OR_RETURN(auto filter_update_prefetch, - ParseFilterUpdatePreferredPrefetch(config)); - filter_update_prefetches.push_back(filter_update_prefetch); - } - return filter_update_prefetches; -} - -/*static*/ StatusOr FilterUpdatePreferredPrefetch::IsOpSizeGte( - int64_t operand_size, std::string config) { - int64_t config_value; - if (!absl::SimpleAtoi(config, &config_value)) { - return InvalidArgument("Expected integer, got %s for operand size filter", - config); - } - return operand_size >= config_value; -} - -/*static*/ StatusOr FilterUpdatePreferredPrefetch::IsOpSizeLte( - int64_t operand_size, std::string config) { - int64_t config_value; - if (!absl::SimpleAtoi(config, &config_value)) { - return InvalidArgument("Expected integer, got %s for operand size filter", - config); - } - return operand_size <= config_value; -} - -/*static*/ StatusOr FilterUpdatePreferredPrefetch::IsInstructionNameExact( - const absl::string_view instruction_name, std::string config) { - return instruction_name == config; -} - -/*static*/ StatusOr FilterUpdatePreferredPrefetch::IsOpNumberExact( - int64_t operand_number, std::string config) { - int64_t config_value; - if (!absl::SimpleAtoi(config, &config_value)) { - return InvalidArgument("Expected integer, got %s for operand number filter", - config); - } - return operand_number == config_value; -} - -/*static*/ StatusOr FilterUpdatePreferredPrefetch::IsOpIndexExact( - const ShapeIndex& operand_index, std::string config) { - TF_ASSIGN_OR_RETURN(auto config_value, ParseOperandIndex(config)); - return operand_index == config_value; -} - -StatusOr> -FilterUpdatePreferredPrefetch::GetPrefetchByEagerness( - int64_t earliest_prefetch_time, int64_t latest_prefetch_time) const { - if (earliest_prefetch_time > latest_prefetch_time) { - return static_cast>(std::nullopt); - } - float override_value; - if (!absl::SimpleAtof(override_value_, &override_value)) { - return InvalidArgument("Expected float, got %s for prefetch eagerness", - override_value_); - } - return static_cast>( - earliest_prefetch_time * override_value + - latest_prefetch_time * (1.0 - override_value)); -} - -StatusOr> -FilterUpdatePreferredPrefetch::GetPrefetchTimeAfterInstruction( - const absl::flat_hash_map& schedule) const { - TF_ASSIGN_OR_RETURN(auto reference_instruction_time, - GetScheduleTimeFromInstructionName(schedule)); - return static_cast>(reference_instruction_time); -} - -StatusOr> -FilterUpdatePreferredPrefetch::GetPrefetchTimeBeforeInstruction( - const absl::flat_hash_map& schedule) const { - TF_ASSIGN_OR_RETURN(auto reference_instruction_time, - GetScheduleTimeFromInstructionName(schedule)); - return static_cast>(reference_instruction_time - 1); -} - -StatusOr -FilterUpdatePreferredPrefetch::GetScheduleTimeFromInstructionName( - const absl::flat_hash_map& schedule) const { - for (auto schedule_entry : schedule) { - if (schedule_entry.first->name() == override_value_) { - return schedule_entry.second; - } - } - return NotFound("Reference instruction %s was not found in the schedule.", - override_value_); -} - -/*static*/ StatusOr -FilterUpdatePreferredPrefetch::ParseFilterType(std::string config) { - if (config == "op_size_lte") { - return FilterType::OP_SIZE_LTE; - } - if (config == "op_size_gte") { - return FilterType::OP_SIZE_GTE; - } - if (config == "instruction_name_exact") { - return FilterType::INSTRUCTION_NAME_EXACT; - } - if (config == "op_number_exact") { - return FilterType::OP_NUMBER_EXACT; - } - if (config == "op_index_exact") { - return FilterType::OP_INDEX_EXACT; - } - return InvalidArgument("Failed to parse filter type %s", config); -} - -/*static*/ StatusOr -FilterUpdatePreferredPrefetch::ParseOverrideType(std::string config) { - if (config == "prefetch_eagerness") { - return OverrideType::PREFETCH_EAGERNESS; - } - if (config == "put_after_instruction") { - return OverrideType::PUT_AFTER_INSTRUCTION; - } - if (config == "put_before_instruction") { - return OverrideType::PUT_BEFORE_INSTRUCTION; - } - return InvalidArgument("Failed to parse override type %s", config); -} - -/*static*/ StatusOr -FilterUpdatePreferredPrefetch::ParseOperandIndex(std::string config) { - ShapeIndex operand_index{}; - if (config.empty()) { - return operand_index; - } - for (const absl::string_view& index_string : absl::StrSplit(config, '#')) { - int64_t index; - if (!absl::SimpleAtoi(index_string, &index)) { - return InvalidArgument("Failed to parse operand_index %s", config); - } - operand_index.push_back(index); - } - return operand_index; -} - -/*static*/ StatusOr -FilterUpdatePreferredPrefetch::ParseFilterUpdatePreferredPrefetch( - std::string config) { - std::vector filter_update_config = absl::StrSplit(config, ':'); - if (filter_update_config.size() < 4 || filter_update_config.size() % 2 != 0) { - return InvalidArgument( - "Failed to parse filter update config %s, incorrect number of " - "arguments", - config); - } - FilterUpdatePreferredPrefetch result; - result.config_string_ = config; - for (int i = 0; i < filter_update_config.size() - 2; i += 2) { - TF_ASSIGN_OR_RETURN(auto filter_type, - ParseFilterType(filter_update_config[i])); - result.filter_list_.push_back( - std::make_pair(filter_type, filter_update_config[i + 1])); - } - TF_ASSIGN_OR_RETURN( - result.override_type_, - ParseOverrideType(filter_update_config[filter_update_config.size() - 2])); - result.override_value_ = filter_update_config.back(); - return result; -} - -bool MemorySpaceAssignment::Allocation::operator==( - const MemorySpaceAssignment::Allocation& other) const { - return defining_position() == other.defining_position() && - uses() == other.uses() && memory_space() == other.memory_space() && - chunk() == other.chunk() && start_time() == other.start_time() && - end_time() == other.end_time() && - earliest_available_time() == other.earliest_available_time() && - is_copy_allocation() == other.is_copy_allocation() && - is_scoped_allocation() == other.is_scoped_allocation(); -} - -bool MemorySpaceAssignment::CopyAllocation::operator==( - const MemorySpaceAssignment::CopyAllocation& other) const { - return static_cast(*this) == - static_cast(other) && - copy_done_schedule_before() == other.copy_done_schedule_before() && - copy_start_schedule_after() == other.copy_start_schedule_after() && - copy_start() == other.copy_start() && copy_done() == other.copy_done(); -} - std::string MemorySpaceAssignment::AllocationValue::ToString() const { std::string out = absl::StrCat("computation = ", computation()->name()); absl::StrAppend(&out, @@ -1743,22 +750,51 @@ std::string MemorySpaceAssignment::AllocationValue::ToShortString() const { (requires_contiguous_allocation_ ? " (cont alloc)" : "")); } +bool AlternateMemoryBestFitHeap::IsIntervalPinnedToAlternateMemory( + const AlternateMemoryBestFitHeap::BufferInterval& interval) const { + const Shape& shape = interval.buffer->shape(); + return shape.has_layout() && + shape.layout().memory_space() == options_.alternate_memory_space; +} + AlternateMemoryBestFitHeap::AlternateMemoryBestFitHeap( - MemorySpaceAssignment::AllocationSequence* allocations, - const Options& options, const HloAliasAnalysis& alias_analysis, - const HloLiveRange& hlo_live_range) - : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes), + AllocationSequence* allocations, const Options& options, + const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range) + : GlobalDecreasingSizeBestFitHeap( + options.alignment_in_bytes, + /*type=*/kSpatial, /*buffer_interval_compare=*/nullptr, + (options.sliced_prefetch_options.max_slices() > + options.sliced_prefetch_options + .all_slice_time_permutations_threshold() + ? SliceTimePermutationIterator::Ty::kPreferred + : SliceTimePermutationIterator::Ty::kAll)), allocations_(allocations), options_(options), alias_analysis_(alias_analysis), hlo_live_range_(hlo_live_range), peak_memory_usage_(hlo_live_range.schedule_end_time() + 1) { // Override buffer interval compare if provided. + auto comparison_function = GetSpatialBufferIntervalCompare(); if (options.buffer_interval_comparator) { - buffer_interval_compare_ = + comparison_function = options.buffer_interval_comparator->GetComparisonFunctor(); } + // Prioritize pinned buffers in the buffer interval order. + buffer_interval_compare_ = + [this, comparison_function = std::move(comparison_function)]( + const BufferInterval& a, const BufferInterval& b) { + bool is_a_pinned = IsIntervalPinnedToAlternateMemory(a); + bool is_b_pinned = IsIntervalPinnedToAlternateMemory(b); + if (is_a_pinned && !is_b_pinned) { + return true; + } + if (!is_a_pinned && is_b_pinned) { + return false; + } + return comparison_function(a, b); + }; + call_graph_ = CallGraph::Build(&alias_analysis_.dataflow_analysis().module()); std::vector initial_resources(hlo_live_range.schedule_end_time(), 1.0); @@ -1829,7 +865,6 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( }); // Create an AllocationValue for each non-trivial position. - absl::flat_hash_set computations; int beginning_idx = allocation_values.size(); for (int i = 0; i < positions.size(); ++i) { const HloPosition& position = positions.at(i); @@ -1884,7 +919,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( CHECK(HloDataflowAnalysis::IsAsynchronousOperationDone( allocation_value.uses().at(0).hlo_use.instruction->opcode())); VLOG(3) << "Mark " << allocation_value.ToShortString() - << " to require contiguous allocation."; + << " to require contiguous allocation because it is an async " + "start operation."; + allocation_value.set_requires_contiguous_allocation(true); + } else if (options_.position_requires_contiguous_allocation_fn( + allocation_value.defining_position())) { + VLOG(3) << "Mark " << allocation_value.ToShortString() + << " to require contiguous allocation because of options."; allocation_value.set_requires_contiguous_allocation(true); } VLOG(3) << "Created allocation value: " @@ -2162,8 +1203,7 @@ void AlternateMemoryBestFitHeap::AppendScopedAllocationBufferInfoDebugString( } void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( - const MemorySpaceAssignment::Allocation& allocation, - std::string& debug_str) const { + const Allocation& allocation, std::string& debug_str) const { // Columns in allocation information: // buffer_id: int. This value can be used the match with buffer info. // size: int. In bytes. @@ -2193,1192 +1233,9 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { if (!options_.dump_fn) { return; } - options_.dump_fn("bufferinfo", buffer_info_str_); - options_.dump_fn("allocinfo", allocation_info_str_); - options_.dump_fn("scheduleinfo", instruction_schedule_str_); -} - -/*static*/ StatusOr> -MemoryBoundLoopOptimizer::Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function) { - std::unique_ptr optimizer = - absl::WrapUnique(new MemoryBoundLoopOptimizer( - loop_start, loop_end, alternate_memory_size, options, hlo_live_range, - alias_analysis, cost_analysis, size_function)); - TF_RETURN_IF_ERROR(optimizer->Initialize()); - return std::move(optimizer); -} - -MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function) - : loop_start_(loop_start), - loop_end_(loop_end), - loop_size_(loop_end - loop_start), - alternate_memory_size_(alternate_memory_size), - options_(options), - hlo_live_range_(hlo_live_range), - alias_analysis_(alias_analysis), - cost_analysis_(cost_analysis), - size_function_(size_function) {} - -Status MemoryBoundLoopOptimizer::Initialize() { - const auto& instruction_sequence = - hlo_live_range_.flattened_instruction_sequence().instructions(); - VLOG(3) << "MemoryBoundLoopOptimizer::Initialize, loop start: " << loop_start_ - << ", loop end: " << loop_end_ << ", loop size: " << loop_size_; - const HloComputation* loop_computation = nullptr; - // Initialize the remaining memory array with the size of the alternate - // memory. Also populate instructions_in_loop_ and - // instructions_in_{prev,next}_iterations_ data structures to help find the - // loop values. - for (int i = loop_start_; i < loop_end_; ++i) { - const HloInstruction* inst = instruction_sequence[i]; - instructions_in_loop_[inst] = i - loop_start_; - VLOG(3) << " inst in loop [" << (i - loop_start_) << "]: " << inst->name(); - if (!loop_computation) { - loop_computation = inst->parent(); - } else { - TF_RET_CHECK(loop_computation == inst->parent()); - } - remaining_memory_.push_back(alternate_memory_size_); - } - - for (int i = loop_start_ - loop_size_; i < loop_start_; ++i) { - const HloInstruction* inst = instruction_sequence[i]; - instructions_in_prev_iteration_[inst] = i - loop_start_ + loop_size_; - } - for (int i = loop_end_; i < loop_end_ + loop_size_; ++i) { - const HloInstruction* inst = instruction_sequence[i]; - instructions_in_next_iteration_[inst] = i - loop_end_; - } - - // Create a tree set to keep track of all the values that the loop - // instructions produce and consume. We use a tree set instead of a hash set - // to ensure the iteration order is the same as insertion order. Since we - // traverse the program in instruction order, the buffers would be inserted in - // a deterministic order, so we'll be able to iterate over these buffers in a - // deterministic order. - std::set buffers_to_process; - for (const auto& [instruction, idx] : instructions_in_loop_) { - auto maybe_add_buffer = [&](const HloInstruction* instruction) { - return [this, &buffers_to_process, instruction](const Shape& subshape, - const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - const HloBuffer& buffer = - alias_analysis_.GetUniqueBufferAt(instruction, index); - if (buffers_to_process.find(&buffer) == buffers_to_process.end()) { - buffers_to_process.insert(&buffer); - } - }; - }; - ShapeUtil::ForEachSubshape(instruction->shape(), - maybe_add_buffer(instruction)); - for (const HloInstruction* operand : instruction->operands()) { - ShapeUtil::ForEachSubshape(operand->shape(), maybe_add_buffer(operand)); - } - } - - // Process the buffers and decide if they should be added as LoopValues. - for (const HloBuffer* buffer : buffers_to_process) { - MaybeCreateLoopValue(*buffer, loop_computation); - } - return OkStatus(); -} - -void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( - const HloBuffer& buffer, const HloComputation* loop_computation) { - // Define helper lambdas to get the loop-relative index of the given - // instruction. - auto get_index_in_loop = - [&](const HloInstruction* instruction, - const absl::flat_hash_map& - instructions_in_loop, - int64_t relative_index = 0) { - std::optional loop_index; - if (instructions_in_loop.contains(instruction)) { - loop_index = hlo_live_range_.instruction_schedule().at(instruction) - - loop_start_ + relative_index; - CHECK_GE(*loop_index, 0); - CHECK_LT(*loop_index, loop_size_); - } - return loop_index; - }; - auto get_index_in_current_iteration = [&](const HloInstruction* instruction) { - return get_index_in_loop(instruction, instructions_in_loop_); - }; - auto get_index_in_prev_iteration = [&](const HloInstruction* instruction) { - return get_index_in_loop(instruction, instructions_in_prev_iteration_, - loop_size_); - }; - auto get_index_in_next_iteration = [&](const HloInstruction* instruction) { - return get_index_in_loop(instruction, instructions_in_next_iteration_, - -loop_size_); - }; - - loop_values_.push_back({}); - LoopValue& loop_value = loop_values_.back(); - float pos_bytes = 0; - float use_bytes = 0; - bool has_footer_consumer = false; - for (const HloValue* value : buffer.values()) { - // For each position and use of the value, populate the respecive position - // and use fields for the current, previous, and next iterations along with - // the loop indices. - for (const HloPosition& position : value->positions()) { - if (position.instruction->opcode() == HloOpcode::kGetTupleElement) { - continue; - } - std::optional loop_index = - get_index_in_current_iteration(position.instruction); - std::optional prev_iteration_index; - if (loop_index) { - loop_value.loop_positions.push_back({*loop_index, position}); - VLOG(3) << "Pos match: " << position.instruction->name() << " at " - << *loop_index; - } else if ((prev_iteration_index = - get_index_in_prev_iteration(position.instruction))) { - loop_value.prev_iteration_positions.push_back( - {*prev_iteration_index, position}); - VLOG(3) << "Pos match (prev iteration): " - << position.instruction->name() << " at " - << *prev_iteration_index; - } else if (loop_value.prev_iteration_positions.empty() && - loop_value.loop_positions.empty() && - position.instruction->parent() == loop_computation && - !loop_value.header_position) { - loop_value.header_position = position; - } - - // Keep track of bytes accessed by this value. - if (loop_index || prev_iteration_index) { - float bytes_accessed = - cost_analysis_.cost_analysis().output_bytes_accessed( - *position.instruction, position.index); - pos_bytes += bytes_accessed; - VLOG(3) << " accessed: " << bytes_accessed; - } - } - - for (const HloUse& use : value->GetUses()) { - if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { - continue; - } - std::optional loop_index = - get_index_in_current_iteration(use.instruction); - std::optional next_iteration_index; - if (loop_index) { - loop_value.loop_uses.push_back({*loop_index, use}); - VLOG(3) << "Use match: " << use.instruction->name() << " at " - << *loop_index; - } else if ((next_iteration_index = - get_index_in_next_iteration(use.instruction))) { - loop_value.next_iteration_uses.push_back({*next_iteration_index, use}); - VLOG(3) << "Use match (next iteration): " << use.instruction->name() - << " at " << *next_iteration_index; - } else if (!loop_value.loop_positions.empty() || - !loop_value.loop_uses.empty()) { - has_footer_consumer = true; - } - - // Keep track of bytes accessed by this value. - if (loop_index || next_iteration_index) { - float bytes_accessed = - cost_analysis_.cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); - use_bytes += bytes_accessed; - VLOG(3) << " accessed: " << bytes_accessed; - } - } - } - - // We only add the loop position if it has a position or use in the current - // iteration and its previous iteration positions are empty. The reason why we - // disallow values with previous iteration positions is because there will be - // a different value that corresponds to the same value but one iteration - // later, so we will add that one instead. - if ((!loop_value.loop_positions.empty() || !loop_value.loop_uses.empty()) && - loop_value.prev_iteration_positions.empty()) { - loop_value.size = size_function_(**buffer.values().begin()); - VLOG(3) << "Size: " << loop_value.size; - // Classify the type of allocation. See the comment in LoopValue definition. - loop_value.allocation_type = LoopValue::AllocationType::kUnsupported; - auto position_compare = [](const std::pair& a, - const std::pair& b) { - return a.first < b.first; - }; - auto use_compare = [](const std::pair& a, - const std::pair& b) { - return a.first < b.first; - }; - absl::c_sort(loop_value.loop_positions, position_compare); - absl::c_sort(loop_value.prev_iteration_positions, position_compare); - absl::c_sort(loop_value.loop_uses, use_compare); - absl::c_sort(loop_value.next_iteration_uses, use_compare); - if (!loop_value.loop_positions.empty()) { - if (loop_value.next_iteration_uses.empty() && - !loop_value.loop_uses.empty()) { - loop_value.allocation_type = LoopValue::AllocationType::kTemporary; - } else if (!loop_value.next_iteration_uses.empty()) { - if (loop_value.next_iteration_uses.back().first >= - loop_value.loop_positions.front().first) { - loop_value.allocation_type = - LoopValue::AllocationType::kLoopCarriedDependence; - } else { - loop_value.allocation_type = LoopValue::AllocationType::kTemporary; - } - } - } else if (loop_value.header_position && !loop_value.loop_uses.empty()) { - if (loop_value.loop_uses.size() == - loop_value.next_iteration_uses.size() && - loop_value.loop_uses.front().first == - loop_value.next_iteration_uses.front().first) { - loop_value.allocation_type = LoopValue::AllocationType::kPinned; - } else if (loop_value.next_iteration_uses.empty() || - loop_value.next_iteration_uses.back().first < - loop_value.loop_uses.front().first) { - loop_value.allocation_type = LoopValue::AllocationType::kPrefetch; - } - } - - VLOG(3) << "Allocation type " - << LoopValue::AllocationTypeToString(loop_value.allocation_type); - VLOG(3) << "Pos bytes: " << pos_bytes << " use bytes: " << use_bytes; - - // We calculate the savings of allocating this buffer in the alternate - // memory. - float savings = pos_bytes + use_bytes; - if (loop_value.header_position) { - savings -= loop_value.size; - } - if (!loop_value.loop_positions.empty() && has_footer_consumer) { - savings -= loop_value.size; - } - loop_value.savings = savings; - loop_value.savings_per_byte = savings / loop_value.size; - VLOG(3) << "Savings: " << loop_value.savings; - VLOG(3) << "Savings per byte: " << loop_value.savings_per_byte; - for (const HloValue* value : buffer.values()) { - VLOG(3) << value->ToString(); - } - auto sort_positions = [](const std::pair& a, - const std::pair& b) { - return a.first < b.first; - }; - auto sort_uses = [](const std::pair& a, - const std::pair& b) { - return a.first < b.first; - }; - absl::c_sort(loop_value.loop_positions, sort_positions); - absl::c_sort(loop_value.prev_iteration_positions, sort_positions); - absl::c_sort(loop_value.loop_uses, sort_uses); - absl::c_sort(loop_value.next_iteration_uses, sort_uses); - loop_value.hlo_values = buffer.values(); - } else { - loop_values_.pop_back(); - } -} - -void MemoryBoundLoopOptimizer::Optimize() { - SortLoopValues(); - AllocateLoopValues(); - PostProcess(); -} - -float MemoryBoundLoopOptimizer::CalculateExecutionTime() const { - // First populate the list of prefetches. - std::vector> - prefetches; - for (const LoopValue& value : loop_values_) { - if (!value.allocations.empty() && - value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - {static_cast( - value.allocations.back().get()), - cost_analysis_.GetAsyncCopyElapsed( - value.hlo_values.front()->shape())}); - } - } - - // Returns the effective prefetch completion time. The effective time is a - // value that will be larger than loop size for prefetches that start in this - // iteration but complete in the next iteration. - auto get_effective_done_time = - [&](int64_t copy_start_schedule_after, - int64_t copy_done_schedule_before) -> int64_t { - if (copy_start_schedule_after == loop_size_ - 1 && - copy_done_schedule_before == 0) { - return 2 * loop_size_; - } - if (copy_start_schedule_after + 1 >= copy_done_schedule_before) { - return copy_done_schedule_before + loop_size_; - } - return copy_done_schedule_before; - }; - - // Sort the prefetches by first the start time, then the effective done time. - absl::c_sort( - prefetches, - [&](const std::pair& - a, - const std::pair& - b) { - return std::forward_as_tuple( - a.first->copy_start_schedule_after(), - get_effective_done_time( - a.first->copy_start_schedule_after(), - a.first->copy_done_schedule_before())) < - std::forward_as_tuple(b.first->copy_start_schedule_after(), - get_effective_done_time( - b.first->copy_start_schedule_after(), - b.first->copy_done_schedule_before())); - }); - // Populate the required prefetch completions array. For each instruction in - // the loop, this vector holds the index of the latest-issued prefetch that - // needs to be completed before the instruction executes, or nullopt if there - // is no prefetch that needs to finish by this instruction. To represent - // prefetches that started in the previous iteration, we use negative numbers. - std::vector> required_prefetch_completions(loop_size_); - for (int i = 0; i < prefetches.size(); ++i) { - const auto& [prefetch, elapsed] = prefetches[i]; - int required_prefetch_completion = i; - if (prefetch->copy_start_schedule_after() == loop_size_ - 1 && - prefetch->copy_done_schedule_before() == 0) { - required_prefetch_completion -= 2 * prefetches.size(); - } else if (prefetch->copy_start_schedule_after() + 1 >= - prefetch->copy_done_schedule_before()) { - required_prefetch_completion -= prefetches.size(); - } - VLOG(3) << "Prefetch #" << i << " (elapsed " << elapsed - << "): " << prefetch->ToString(); - if (required_prefetch_completions[prefetch->copy_done_schedule_before()]) { - required_prefetch_completions[prefetch->copy_done_schedule_before()] = - std::max( - *required_prefetch_completions[prefetch - ->copy_done_schedule_before()], - required_prefetch_completion); - } else { - required_prefetch_completions[prefetch->copy_done_schedule_before()] = - required_prefetch_completion; - } - VLOG(4) - << "Required completion at " << prefetch->copy_done_schedule_before() - << " = " - << *required_prefetch_completions[prefetch - ->copy_done_schedule_before()]; - } - - // Populate the elapsed times of instructions and bandwidth idle times at each - // point. - float result; - std::vector bandwidth_idle_times; - std::vector instructions_elapsed; - bandwidth_idle_times.reserve(loop_size_); - instructions_elapsed.reserve(loop_size_); - for (int i = 0; i < loop_size_; ++i) { - bandwidth_idle_times.push_back(GetBandwidthIdleTime(i)); - instructions_elapsed.push_back(GetInstructionElapsed(i)); - } - // We simulate the loop for three iterations to measure the steady state. - const int kNumIterations = 3; - // This data structure keeps track of the elapsed time remaining of each - // prefetch. Note that there is a separate entry for each prefetch in each - // iteration simulated. - std::vector prefetch_remaining_elapsed_times(prefetches.size() * - kNumIterations); - int prefetch_start_index = 0; - int prefetch_done_index = 0; - int prefetch_completed_index = 0; - - for (int iteration = 0; iteration < kNumIterations; ++iteration) { - float total_elapsed = 0; - float total_bandwidth_idle_time = 0; - float total_critical_prefetch = 0; - for (int i = 0; i < loop_size_; ++i) { - // If any prefetches are expected to be completed, check if they have any - // remaining elapsed time associated with them, and if so add this to - // critical prefetch time. - std::optional required_prefetch_completion = - required_prefetch_completions[i]; - if (required_prefetch_completion) { - int required_prefetch_done_index = - iteration * static_cast(prefetches.size()) + - *required_prefetch_completion; - VLOG(4) << "Prefetch #" - << ((*required_prefetch_completion + prefetches.size()) % - prefetches.size()) - << " (" << required_prefetch_done_index - << ") is required to be completed at " << i; - for (; prefetch_done_index <= required_prefetch_done_index; - ++prefetch_done_index) { - CHECK_LE(prefetch_done_index, prefetch_start_index); - if (prefetch_done_index == prefetch_completed_index) { - float& prefetch_remaining = - prefetch_remaining_elapsed_times[prefetch_done_index]; - VLOG(4) << "Prefetch #" << (prefetch_done_index % prefetches.size()) - << " (" << prefetch_done_index - << ") did not complete, remaining elapsed = " - << prefetch_remaining; - total_critical_prefetch += prefetch_remaining; - prefetch_remaining = 0; - ++prefetch_completed_index; - } - } - } - - float elapsed = instructions_elapsed[i]; - total_elapsed += elapsed; - float bandwidth_idle_time = bandwidth_idle_times[i]; - // Find the outstanding prefetches during this instruction, and if any of - // them have remaining time, spend some or all of the bandwidth idle time - // to satisfy them. - for (; prefetch_completed_index < prefetch_start_index; - ++prefetch_completed_index) { - float& prefetch_remaining = - prefetch_remaining_elapsed_times[prefetch_completed_index]; - if (bandwidth_idle_time < prefetch_remaining) { - prefetch_remaining -= bandwidth_idle_time; - bandwidth_idle_time = 0; - VLOG(4) << "Prefetch #" - << (prefetch_completed_index % prefetches.size()) << " (" - << prefetch_completed_index << ") still ongoing at " << i - << ", remaining elapsed = " << prefetch_remaining; - break; - } - bandwidth_idle_time -= prefetch_remaining; - prefetch_remaining = 0; - VLOG(4) << "Prefetch #" - << (prefetch_completed_index % prefetches.size()) << " (" - << prefetch_completed_index << ") completed at " << i - << ", bandwidth idle time = " << bandwidth_idle_time; - } - if (bandwidth_idle_time > 0) { - VLOG(4) << "Bandwidth idle time at " << i << " = " - << bandwidth_idle_time; - total_bandwidth_idle_time += bandwidth_idle_time; - } - - // Start new prefetches that are scheduled to start after this - // instruction. - for (; prefetch_start_index < (iteration + 1) * prefetches.size() && - prefetches[prefetch_start_index % prefetches.size()] - .first->copy_start_schedule_after() == i; - ++prefetch_start_index) { - float& prefetch_remaining = - prefetch_remaining_elapsed_times[prefetch_start_index]; - prefetch_remaining = - prefetches[prefetch_start_index % prefetches.size()].second; - VLOG(4) << "Prefetch #" << (prefetch_start_index % prefetches.size()) - << " (" << prefetch_start_index << ") started at " << i - << ", remaining elapsed = " << prefetch_remaining; - } - } - VLOG(3) << "Iteration " << iteration; - VLOG(3) << "Total elapsed: " << total_elapsed - << ", total critical prefetch: " << total_critical_prefetch - << ", total bandwidth idle time: " << total_bandwidth_idle_time; - result = total_elapsed + total_critical_prefetch; - } - return result; -} - -/*static*/ std::string -MemoryBoundLoopOptimizer::LoopValue::AllocationTypeToString( - LoopValue::AllocationType allocation_type) { - switch (allocation_type) { - case AllocationType::kTemporary: - return "temporary"; - case AllocationType::kLoopCarriedDependence: - return "loop-carried dependence"; - case AllocationType::kPinned: - return "pinned"; - case AllocationType::kPrefetch: - return "prefetch"; - default: - CHECK(allocation_type == AllocationType::kUnsupported); - return "unsupported"; - } -} - -std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { - std::string values_str; - absl::StrAppend(&values_str, "Values:"); - for (const HloValue* hlo_value : hlo_values) { - absl::StrAppend(&values_str, "\n - ", hlo_value->ToShortString()); - } - std::string allocations_str; - if (!allocations.empty()) { - absl::StrAppend(&allocations_str, "Allocations:"); - } - for (const auto& allocation : allocations) { - absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); - } - return absl::StrCat( - "Size: ", size, " savings: ", savings, - " savings per byte: ", savings_per_byte, - " allocation type: ", AllocationTypeToString(allocation_type), "\n", - values_str, "\n", allocations_str); -} - -bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { - return allocation_type == AllocationType::kTemporary || - allocation_type == AllocationType::kPinned || - allocation_type == AllocationType::kPrefetch; -} - -void MemoryBoundLoopOptimizer::SortLoopValues() { - absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { - return a.savings_per_byte > b.savings_per_byte; - }); -} - -void MemoryBoundLoopOptimizer::AllocateLoopValues() { - // This function allocates loop values. - std::vector prefetch_values; - VLOG(3) << "Pre optimization execution time: " << CalculateExecutionTime(); - for (LoopValue& value : loop_values_) { - switch (value.allocation_type) { - case LoopValue::AllocationType::kTemporary: - AllocateTemporary(value); - break; - case LoopValue::AllocationType::kPinned: - AllocatePinned(value); - break; - case LoopValue::AllocationType::kPrefetch: - prefetch_values.push_back(&value); - break; - case LoopValue::AllocationType::kLoopCarriedDependence: - case LoopValue::AllocationType::kUnsupported: - VLOG(1) << "Unsupported allocation: " << value.ToString(); - } - } - VLOG(3) << "Execution time after allocating temporaries: " - << CalculateExecutionTime(); - AllocatePrefetches(absl::MakeSpan(prefetch_values)); - VLOG(3) << "Execution time after allocating prefetches: " - << CalculateExecutionTime(); -} - -void MemoryBoundLoopOptimizer::PostProcess() { - // At the end, ensure that all loop uses have a corresponding Allocation and - // create one in the default memory space if they don't. - for (LoopValue& value : loop_values_) { - absl::flat_hash_set allocated_uses; - for (const auto& allocation : value.allocations) { - for (const HloUse& use : allocation->uses()) { - allocated_uses.insert(use); - } - } - std::vector unallocated_uses; - absl::flat_hash_set use_indices; - for (const auto& [idx, use] : value.loop_uses) { - use_indices.insert(idx); - if (!allocated_uses.contains(use)) { - unallocated_uses.push_back(use); - } - } - for (const auto& [next_iteration_idx, use] : value.next_iteration_uses) { - if (use_indices.contains(next_iteration_idx)) { - continue; - } - HloInstruction* loop_instruction = - hlo_live_range_.flattened_instruction_sequence().instructions().at( - loop_start_ + next_iteration_idx); - HloUse loop_use{loop_instruction, use.operand_number, use.operand_index}; - if (!allocated_uses.contains(loop_use)) { - unallocated_uses.push_back(loop_use); - } - } - if (!unallocated_uses.empty()) { - // TODO(b/281582241): We should find the correct position. For now, we're - // using the defining position on the first HLO value. - value.allocations.push_back( - std::make_unique( - value.hlo_values.front()->defining_position(), - MemorySpaceAssignment::MemorySpace::kDefault, std::nullopt, 0, - loop_size_, /*is_scoped_allocation=*/false)); - for (const HloUse& use : unallocated_uses) { - value.allocations.back()->AddUse(use); - } - } - } -} - -bool MemoryBoundLoopOptimizer::AllocateBetween(int64_t begin_idx, - int64_t end_idx, int64_t size) { - int64_t end_idx_sentinel = end_idx; - if (end_idx < begin_idx) { - end_idx_sentinel += loop_size_; - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - if (remaining_memory_[i % loop_size_] < size) { - return false; - } - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - remaining_memory_[i % loop_size_] -= size; - } - return true; -} - -bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { - VLOG(3) << "AllocateTemporary: " << value.ToString(); - if (value.hlo_values.size() > 1) { - VLOG(3) << "LoopValue has more than one hlo value associated."; - return false; - } - int64_t definition_idx = value.loop_positions.front().first; - int64_t max_use_idx; - if (!value.next_iteration_uses.empty()) { - max_use_idx = value.next_iteration_uses.back().first; - // If max_use_idx >= definition_idx, then this is a loop carried dependence - // and we should not have called this function. - CHECK_LT(max_use_idx, definition_idx); - } else { - max_use_idx = value.loop_uses.back().first; - } - bool success = AllocateBetween(definition_idx, max_use_idx, value.size); - if (success) { - VLOG(3) << "Pos: " << value.loop_positions[0].second; - value.allocations.push_back( - std::make_unique( - value.loop_positions[0].second, - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, - definition_idx, max_use_idx, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); - } - return success; -} - -bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { - bool success = AllocateBetween(0, loop_size_, value.size); - if (success) { - CHECK(value.header_position); - value.allocations.push_back( - std::make_unique( - *value.header_position, - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, 0, - loop_size_, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); - } - return success; -} - -bool MemoryBoundLoopOptimizer::AllocatePrefetches( - absl::Span values) { - VLOG(3) << "Allocating prefetches num values: " << values.size(); - AllocatePrefetchesContext context; - context.values = values; - // Populate value_indices, which is a list of indices into values array sorted - // by the start time of the first use. - context.value_indices.resize(values.size()); - absl::c_iota(context.value_indices, 0); - absl::c_stable_sort(context.value_indices, [&](int a, int b) { - return std::forward_as_tuple( - values[a]->loop_uses.begin()->first, - values[a]->loop_uses.begin()->second.operand_number) > - std::forward_as_tuple( - values[b]->loop_uses.begin()->first, - values[b]->loop_uses.begin()->second.operand_number); - }); - - // Populate the data structures that contain additional positions and uses - // that would get alternate memory allocations if all of the prefetches were - // successful. - absl::flat_hash_map>> - additional_uses_in_alternate_mem; - absl::flat_hash_map> - additional_positions_in_alternate_mem; - for (const LoopValue* value : values) { - VLOG(3) << " prefetch value: " << value->ToString(); - for (const auto& [idx, use] : value->loop_uses) { - additional_uses_in_alternate_mem[use.instruction].push_back( - {use.operand_number, use.operand_index}); - } - for (const auto& [idx, position] : value->loop_positions) { - additional_positions_in_alternate_mem[position.instruction].push_back( - position.index); - } - } - // Calculate the default-memory remaining bandwidths assuming all prefetches - // succeed. - for (int i = 0; i < loop_size_; ++i) { - context.bandwidth_idle_times.push_back( - GetBandwidthIdleTime(i, additional_uses_in_alternate_mem, - additional_positions_in_alternate_mem)); - VLOG(3) << "Remaining bandwidth at " << i << " = " - << *context.bandwidth_idle_times.rbegin(); - } - - context.additional_memory_used.resize(loop_size_, 0); - - // Allocate prefetches by traversing the loop values in reverse order of - // the first uses. - for (int value_index : context.value_indices) { - AllocatePrefetch(value_index, context); - } - - for (int i = 0; i < loop_size_; ++i) { - remaining_memory_[i] -= context.additional_memory_used[i]; - VLOG(3) << "Additional memory [" << i - << "]: " << context.additional_memory_used[i]; - VLOG(3) << "Remaining memory [" << i << "]: " << remaining_memory_[i]; - VLOG(3) << "Remaining bandwidth [" << i - << "] : " << context.bandwidth_idle_times[i]; - } - return true; -} - -bool MemoryBoundLoopOptimizer::AllocatePrefetch( - int value_index, AllocatePrefetchesContext& context) { - LoopValue* value = context.values.at(value_index); - VLOG(3) << "Allocating value: " << value->ToString(); - int first_use_idx = value->loop_uses.front().first; - int last_use_idx = value->loop_uses.back().first; - int last_use_idx_sentinel = last_use_idx; - if (!value->next_iteration_uses.empty()) { - last_use_idx = value->next_iteration_uses.back().first; - last_use_idx_sentinel = last_use_idx + loop_size_; - CHECK_LT(last_use_idx, first_use_idx); - } - bool out_of_memory = false; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - if (context.additional_memory_used[loop_idx] + value->size > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory allocating for uses."; - out_of_memory = true; - } - } - if (out_of_memory) { - return false; - } - float copy_resource = - cost_analysis_.GetAsyncCopyElapsed(value->hlo_values.front()->shape()); - VLOG(3) << "First use: " << value->loop_uses.begin()->second - << " use idx: " << first_use_idx - << " copy resource: " << copy_resource; - std::optional copy_start_time; - // The general allocation algorithm for prefetches is to first calculate the - // default-memory bandwidth idle times at each point (assuming all prefetches - // succeeded). We show this pictorially below. We also show the previous - // iteration for clarity. The algorithm solves allocation for one iteration - // and this will be used for all iterations. - // - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 2 2 1 2 3 1| 2 2 1 2 3 1| - // additional memory: 0 0 0 0 0 0| 0 0 0 0 0 0| - // iteration: prev | current | - // - // Now, let's assume there are two prefetches that need to be scheduled. For - // the sake of the example, assume 1 MiB of prefetch uses 1 memory bandwidth - // resource: - // - Prefetch 1 is 4 MiB and is first used at index 5. - // - Prefetch 2 is 5 MiB and is first used at index 1. - // - // We first order these prefetches by their first use from latest to earliest. - // Then starting from the prefetch completion time (i.e. the first use time), - // move the prefetch start time earlier until the copy resource is satisfied - // (or reaching another resource satisfaction criteria explained below) by - // consuming the bandwidth idle time of the overlapped instructions. We also - // keep track of the additional memory required. Note that index 5 also - // accounts for the additional 4 MiB consumed since the data needs to reside - // during the execution of the instruction at index 5. Below is the updated - // state after scheduling prefetch 1: - // - // prefetch 1: +====+ +====+ - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 2 2 1 1 0 1| 2 2 1 1 0 1| - // additional memory: 0 0 0 4 4 4| 0 0 0 4 4 4| - // iteration: prev | current | - // - // To schedule prefetch 2, we similarly start the same way, from its first use - // and bring the prefetch start earlier. We first reach index 0 with still an - // unsatisfied copy resource of 3: - // - // prefetch 2: +=+ +=+ unsat res: 3 - // prefetch 1: +====+ +====+ - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 0 2 1 1 0 1| 0 2 1 1 0 1| - // additional memory: 5 5 0 4 4 4| 5 5 0 4 4 4| - // iteration: prev | current | - // - // We continue onto the previous iteration: - // - // prefetch 2:===+ +====+ +== unsat res: 2 - // prefetch 1: +====+ +====+ - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 0 2 1 1 0 0| 0 2 1 1 0 0| - // additional memory: 5 5 0 4 4 9| 5 5 0 4 4 9| - // iteration: prev | current | - // - // As we bring the start time of prefetch 2 earlier, it starts overlapping - // with prefetch 1: - // - // prefetch 2:===+ +==========+ +======== unsat res: 1 - // prefetch 1: +====+ +====+ - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 0 2 1 0 0 0| 0 2 1 0 0 0| - // additional memory: 5 5 0 9 9 9| 5 5 0 9 9 9| - // iteration: prev | current | - // - // The prefetch resource is still unsatisfied at this point. We can bring the - // prefetch earlier. However, the first prefetch's end time is earlier than - // the second and we need to maintain FIFO order with regard to prefetches. In - // order to maintain this FIFO order, we "early force" prefetches that are - // already scheduled by moving the start time earlier along with prefetch 2: - // - // prefetch 2:===+ +=============+ +=========== - // prefetch 1: +=======+ +=======+ - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 0 2 0 0 0 0| 0 2 0 0 0 0| - // additional memory: 5 5 9 9 9 9| 5 5 9 9 9 9| - // iteration: prev | current | - // - // Depending on the options provided, we can use alternative resource - // satisfaction criteria. One option is to specify a percentage of the copy - // resource that needs to be satisfied instead of the complete amount (100%). - // This is called the "desired copy ratio". The reason why desired copy ratio - // can be less than 100% is that in a memory-bound loop, we probably do not - // have enough aggregate bandwidth resources to satisfy all of the prefetches, - // but using up all of the default-memory bandwidth is more important than - // having some prefetches with unsatisfied resources. In a similar vein, - // another option is to accept prefetches that are fully pipelined, i.e. - // their copy start time is scheduled the same time as the copy done time in - // the previous iteration, regardless of how much of its copy resources are - // actually satisfied. To illustrate a fully pipelined prefetch, consider - // prefetch 3 (assume no prefetch 1 or 2 in this example) which is 15 MiB and - // its first use is at index 4: - // - // prefetch 3:=============+=================+===== unsat res: 4 - // idx: 0 1 2 3 4 5| 0 1 2 3 4 5| - // bw idle time: 0 0 0 0 0 0| 0 0 0 0 0 0| - // additional memory: 15 15 15 15 30 15|15 15 15 15 30 15| - // iteration: prev | current | - // - // Note that the additional memory consumption at index 4 is actually twice - // the size of the prefetch as we are effectively double buffering. Also note - // that the prefetch has an unsatisfied copy resource of 4 meaning the copy - // will be in the critical path, but this actually will be faster than not - // scheduling this particular prefetch in the first place since the bandwidth - // idle time resource would go unused. - float accumulated_copy_resource = 0; - std::vector early_forced_prefetch_value_indices; - int early_forced_prefetch_value_search_index = 0; - float early_forced_prefetch_additional_memory = 0; - for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; - --i) { - int loop_idx = (i + loop_size_) % loop_size_; - // Check if this prefetch rolls over to the previous iteration, check if any - // already-scheduled prefetches would violate the FIFO order, and if so, - // "early-force" them to be co-scheduled with this prefetch to maintain the - // FIFO order. This of course increases the required memory, so also keep - // track of additional memory that would be consumed. - if (i < 0) { - for (; context.value_indices[early_forced_prefetch_value_search_index] != - value_index; - ++early_forced_prefetch_value_search_index) { - VLOG(3) << "Searching for early forced: " - << early_forced_prefetch_value_search_index; - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_search_index]); - if (early_forced_value->allocations.empty()) { - continue; - } - const MemorySpaceAssignment::CopyAllocation* early_forced_prefetch = - static_cast( - early_forced_value->allocations.back().get()); - VLOG(3) << "Prefetch: " << early_forced_prefetch->ToString(); - - // If the prefetch is already a roll-around prefetch, no need to further - // early force it. - if (early_forced_prefetch->copy_done_schedule_before() <= - early_forced_prefetch->copy_start_schedule_after() + 1 || - (early_forced_prefetch->copy_start_schedule_after() == - loop_size_ - 1 && - early_forced_prefetch->copy_done_schedule_before() == 0)) { - break; - } - if (early_forced_prefetch->copy_start_schedule_after() != loop_idx) { - break; - } - early_forced_prefetch_value_indices.push_back( - early_forced_prefetch_value_search_index); - early_forced_prefetch_additional_memory += early_forced_value->size; - VLOG(3) << "Found early-forced prefetch value: " - << early_forced_value->ToString(); - VLOG(3) << "Early forced prefetch additional memory: " - << early_forced_prefetch_additional_memory; - } - } - - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - if (loop_idx == last_use_idx) { - overlap_memory_overhead = value->size; - VLOG(3) << "Loop idx == last use idx (" << loop_idx - << "), overlap memory overhead = " << overlap_memory_overhead; - } - - // OOM; give up prefetch. - if (context.additional_memory_used[loop_idx] + value->size + - overlap_memory_overhead + early_forced_prefetch_additional_memory > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory. Accumulated copy resource " - << accumulated_copy_resource << " out of " << copy_resource - << " at " << loop_idx; - break; - } - - // We ideally find a time to overlap the prefetch fully where the previous - // iteration's memory use is disjoint from this iteration. If that is not - // possible, there are two compromises we could pick: - // - Find a prefetch time that satisfies a desired ratio < 1 of the - // prefetch elapsed time. This means the prefetch will be critical. - // - Overlap the prefetch with the previous iteration's buffer use, i.e. - // full pipelining. This would increase the peak memory consumption. - float bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; - VLOG(3) << "Idx " << loop_idx - << " bandwidth_idle_time: " << bandwidth_idle_time - << " copy resource remaining: " - << (copy_resource - accumulated_copy_resource) << " diff: " - << (bandwidth_idle_time - - (copy_resource - accumulated_copy_resource)); - if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { - accumulated_copy_resource = copy_resource; - copy_start_time = loop_idx; - VLOG(3) << "Found the complete copy ratio and updated accumulated copy " - "resource: " - << accumulated_copy_resource; - break; - } else if (!copy_start_time && - accumulated_copy_resource + bandwidth_idle_time >= - copy_resource * options_.desired_copy_ratio()) { - accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; - VLOG(3) << "Found the desired copy ratio and updated accumulated copy " - "resource: " - << accumulated_copy_resource; - } else if (options_.allow_unsatisfied_fully_pipelined_prefetch() && - loop_idx == last_use_idx) { - // Even if desired resource isn't reached, and if the options allow it, - // allow a fully pipelined prefetch. - accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; - VLOG(3) << "Could not reach the desired copy ratio but scheduling " - "fully pipelined prefetch anyway: " - << accumulated_copy_resource; - break; - } else { - accumulated_copy_resource += bandwidth_idle_time; - VLOG(3) << "Updated accumulated copy resource: " - << accumulated_copy_resource; - } - } - - // Could not find a suitable copy start time. - if (!copy_start_time) { - return false; - } - - VLOG(3) << "Success: copy_start_time: " << *copy_start_time - << " leftover copy resource: " - << (copy_resource - accumulated_copy_resource); - auto update_additional_memory_used = [&](int loop_idx, int64_t addition) { - VLOG(4) << "Updating additional memory used at " << loop_idx << ". " - << context.additional_memory_used[loop_idx] << " + " << addition - << " => " << (context.additional_memory_used[loop_idx] + addition) - << " (remaining: " << remaining_memory_[loop_idx] << ")"; - context.additional_memory_used[loop_idx] += addition; - CHECK_LE(context.additional_memory_used[loop_idx], - remaining_memory_[loop_idx]); - }; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - update_additional_memory_used(loop_idx, value->size); - } - // We reset accumulated copy resource and then reuse it to accumulate copy - // resource time in order to replay the previous for loop. It is important - // that we use the same arithmetic operations (as opposed to subtracting from - // copy_resource) because floating point operations aren't commutative. - accumulated_copy_resource = 0.0; - for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; - --i) { - int loop_idx = (i + loop_size_) % loop_size_; - float& bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - update_additional_memory_used(loop_idx, - value->size + overlap_memory_overhead); - if (bandwidth_idle_time < copy_resource - accumulated_copy_resource) { - accumulated_copy_resource += bandwidth_idle_time; - bandwidth_idle_time = 0; - if (loop_idx == *copy_start_time) { - VLOG(3) << "Remaining copy resource: " - << (copy_resource - accumulated_copy_resource); - break; - } - } else { - bandwidth_idle_time -= copy_resource - accumulated_copy_resource; - CHECK_EQ(loop_idx, *copy_start_time); - break; - } - } - - // Create the Allocation objects that correspond to the scheduled prefetch. - CHECK(value->header_position); - value->allocations.push_back( - std::make_unique( - *value->header_position, MemorySpaceAssignment::MemorySpace::kDefault, - std::nullopt, 0, loop_size_, /*is_scoped_allocation=*/false)); - value->allocations.push_back( - std::make_unique( - *value->allocations.back(), - MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt, - ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, - last_use_idx_sentinel)); - AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); - - // Account for the additional memory used by early forcing the already - // scheduled prefetches. Also modify the start times of these to this - // prefetch's copy start time. - for (int early_forced_prefetch_value_index : - early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_index]); - CHECK(!early_forced_value->allocations.empty()); - MemorySpaceAssignment::CopyAllocation* early_forced_prefetch = - static_cast( - early_forced_value->allocations.back().get()); - for (int index = early_forced_prefetch->copy_start_schedule_after(); - index >= *copy_start_time; --index) { - update_additional_memory_used(index, early_forced_value->size); - VLOG(3) << "Additional memory used: " << index << " " - << context.additional_memory_used[index]; - } - early_forced_prefetch->set_copy_start_schedule_after( - ((*copy_start_time - 1) + loop_size_) % loop_size_); - VLOG(3) << "Updated prefetch: " << early_forced_prefetch->ToString(); - } - return true; -} - -void MemoryBoundLoopOptimizer::AddAllLoopPositionsAndUses( - LoopValue& value, bool allocate_next_iteration_uses) { - CHECK_GE(value.allocations.size(), 1); - MemorySpaceAssignment::Allocation& allocation = *value.allocations.back(); - for (const auto& [idx, position] : value.loop_positions) { - positions_in_alternate_mem_[position.instruction].push_back(position.index); - } - for (const auto& [idx, use] : value.loop_uses) { - uses_in_alternate_mem_[use.instruction].push_back( - {use.operand_number, use.operand_index}); - allocation.AddUse(use); - } - if (allocate_next_iteration_uses) { - for (const auto& [next_iteration_idx, use] : value.next_iteration_uses) { - HloInstruction* loop_instruction = - hlo_live_range_.flattened_instruction_sequence().instructions().at( - loop_start_ + next_iteration_idx); - uses_in_alternate_mem_[loop_instruction].push_back( - {use.operand_number, use.operand_index}); - allocation.AddUse( - {loop_instruction, use.operand_number, use.operand_index}); - } - } -} - -float MemoryBoundLoopOptimizer::GetBandwidthIdleTime(int idx) const { - const HloInstruction* inst = - hlo_live_range_.flattened_instruction_sequence().instructions().at( - loop_start_ + idx); - std::vector> empty_operands; - std::vector empty_outputs; - const std::vector>* operands_in_alternate_mem = - &empty_operands; - const std::vector* outputs_in_alternate_mem = &empty_outputs; - auto uses_it = uses_in_alternate_mem_.find(inst); - if (uses_it != uses_in_alternate_mem_.end()) { - operands_in_alternate_mem = &uses_it->second; - } - auto positions_it = positions_in_alternate_mem_.find(inst); - if (positions_it != positions_in_alternate_mem_.end()) { - outputs_in_alternate_mem = &positions_it->second; - } - return cost_analysis_.GetDefaultMemoryBandwidthIdleTime( - *inst, *operands_in_alternate_mem, *outputs_in_alternate_mem); -} - -float MemoryBoundLoopOptimizer::GetBandwidthIdleTime( - int idx, - const absl::flat_hash_map>>& - additional_uses_in_alternate_mem, - const absl::flat_hash_map>& - additional_positions_in_alternate_mem) const { - const HloInstruction* inst = - hlo_live_range_.flattened_instruction_sequence().instructions().at( - loop_start_ + idx); - std::vector> operands_in_alternate_mem; - std::vector outputs_in_alternate_mem; - auto uses_it = uses_in_alternate_mem_.find(inst); - if (uses_it != uses_in_alternate_mem_.end()) { - operands_in_alternate_mem = uses_it->second; - } - auto additional_uses_it = additional_uses_in_alternate_mem.find(inst); - if (additional_uses_it != additional_uses_in_alternate_mem.end()) { - absl::c_copy(additional_uses_it->second, - std::back_inserter(operands_in_alternate_mem)); - } - auto positions_it = positions_in_alternate_mem_.find(inst); - if (positions_it != positions_in_alternate_mem_.end()) { - outputs_in_alternate_mem = positions_it->second; - } - auto additional_positions_it = - additional_positions_in_alternate_mem.find(inst); - if (additional_positions_it != additional_positions_in_alternate_mem.end()) { - absl::c_copy(additional_positions_it->second, - std::back_inserter(outputs_in_alternate_mem)); - } - return cost_analysis_.GetDefaultMemoryBandwidthIdleTime( - *inst, operands_in_alternate_mem, outputs_in_alternate_mem); -} - -float MemoryBoundLoopOptimizer::GetInstructionElapsed(int idx) const { - const HloInstruction* inst = - hlo_live_range_.flattened_instruction_sequence().instructions().at( - loop_start_ + idx); - std::vector> empty_operands; - std::vector empty_outputs; - const std::vector>* operands_in_alternate_mem = - &empty_operands; - const std::vector* outputs_in_alternate_mem = &empty_outputs; - auto uses_it = uses_in_alternate_mem_.find(inst); - if (uses_it != uses_in_alternate_mem_.end()) { - operands_in_alternate_mem = &uses_it->second; - } - auto positions_it = positions_in_alternate_mem_.find(inst); - if (positions_it != positions_in_alternate_mem_.end()) { - outputs_in_alternate_mem = &positions_it->second; - } - return cost_analysis_.GetInstructionElapsedInAlternateMemory( - *inst, *operands_in_alternate_mem, *outputs_in_alternate_mem); + options_.dump_fn("bufferinfo", buffer_info_str_); + options_.dump_fn("allocinfo", allocation_info_str_); + options_.dump_fn("scheduleinfo", instruction_schedule_str_); } Status AlternateMemoryBestFitHeap::OptimizeMemoryBoundLoop(int loop_start_idx, @@ -3395,7 +1252,8 @@ Status AlternateMemoryBestFitHeap::OptimizeMemoryBoundLoop(int loop_start_idx, MemoryBoundLoopOptimizer::Create( iteration_start_idx, iteration_end_idx, options_.max_size_in_bytes, options_.memory_bound_loop_optimizer_options, hlo_live_range_, - alias_analysis_, *options_.cost_analysis, options_.size_fn)); + alias_analysis_, *options_.cost_analysis, options_.size_fn, + options_.reserved_scoped_memory_fn)); optimizer->Optimize(); const int loop_optimized_allocations_original_size = @@ -3418,8 +1276,7 @@ Status AlternateMemoryBestFitHeap::OptimizeMemoryBoundLoop(int loop_start_idx, // optimizer. for (int i = loop_optimized_allocations_original_size; i < loop_optimized_allocations_.size(); ++i) { - const MemorySpaceAssignment::AllocationSequence& sequence = - loop_optimized_allocations_.at(i); + const AllocationSequence& sequence = loop_optimized_allocations_.at(i); CHECK(!sequence.empty()); VLOG(3) << " alloc: " << sequence.back()->ToString(); for (const auto& allocation : sequence) { @@ -3438,6 +1295,10 @@ Status AlternateMemoryBestFitHeap::OptimizeMemoryBoundLoop(int loop_start_idx, for (int64_t i = loop_start_idx + use_idx; i <= loop_end_idx; i += loop_size) { HloInstruction* repeated_inst = instruction_sequence[i]; + CHECK_EQ(use.instruction->opcode(), repeated_inst->opcode()); + CHECK_EQ(use.instruction->operand_count(), + repeated_inst->operand_count()); + CHECK_LT(use.operand_number, repeated_inst->operand_count()); HloUse repeated_use{repeated_inst, use.operand_number, use.operand_index}; loop_optimized_allocations_map_[repeated_use] = {use_idx, loop_size, @@ -3478,11 +1339,13 @@ std::function GetOperandDistanceFunction( const HloLiveRange& hlo_live_range, const HloInstruction* use_inst) { const int use_idx = hlo_live_range.instruction_schedule().at(use_inst); return [&, use_idx](const HloInstruction* operand) -> int { - // We just use -1 for parameter, tuple, and gte instructions. We could make - // this "see through" the gtes if we get too many false positives. + // We just use -1 for parameter, tuple, gte and constant instructions. We + // could make this "see through" the gtes if we get too many false + // positives. if (operand->opcode() == HloOpcode::kParameter || operand->opcode() == HloOpcode::kTuple || - operand->opcode() == HloOpcode::kGetTupleElement) { + operand->opcode() == HloOpcode::kGetTupleElement || + operand->opcode() == HloOpcode::kConstant) { return -1; } return use_idx - hlo_live_range.instruction_schedule().at(operand); @@ -3587,6 +1450,7 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement; }; + // We trigger this if statement until we find the start of the loop. if (loop_start_idx == -1) { if (i > optimized_loop_idx - loop_size_candidate) { break; @@ -3634,7 +1498,7 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { break; } operand_distances.push_back({}); - if (ignore_op(inst) || fingerprint_it == fingerprint_map_.end()) { + if (fingerprint_it == fingerprint_map_.end()) { continue; } absl::c_transform(inst->operands(), @@ -3649,6 +1513,21 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { if (prev_fingerprint_it == fingerprint_map_.end()) { break; } + if (ignore_op(inst) || ignore_op(prev_inst)) { + if (inst->opcode() != prev_inst->opcode()) { + VLOG(3) << "Mismatch (opcode) at " << i << ", " + << (i - loop_size_candidate) << ": " << inst->opcode() + << " vs " << prev_inst->opcode(); + break; + } + if (inst->operand_count() != prev_inst->operand_count()) { + VLOG(3) << "Mismatch (# operands) at " << i << ", " + << (i - loop_size_candidate) << ": " + << inst->operand_count() << " vs " + << prev_inst->operand_count(); + break; + } + } if (fingerprint_it->second != prev_fingerprint_it->second) { VLOG(3) << "Mismatch (fp) at " << i << ", " << (i - loop_size_candidate) << ": " << fingerprint_it->second @@ -3692,7 +1571,8 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { } } -HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { +absl::StatusOr> +AlternateMemoryBestFitHeap::Finish() { if (options_.autotuning_config.has_value()) { CHECK_EQ((*options_.autotuning_config).size(), buffer_intervals_.size()); } @@ -3721,7 +1601,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation || !MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - interval) || + interval, options_.alternate_memory_space) || interval.size > available_heap_size()) { continue; } @@ -3788,11 +1668,15 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because it doesn't need an allocation."; continue; } if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - interval)) { + interval, options_.alternate_memory_space)) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because it is not allowed in the alternate memory."; continue; } @@ -3870,8 +1754,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { retry_number++) { AddRequiredAssignmentsForColocatedIntervals(colocated_intervals); options_.prefetch_interval_picker->SetRetryNumber(retry_number); - Result result = - AllocateAllocationValues(absl::MakeSpan(allocation_values)); + TF_ASSIGN_OR_RETURN( + Result result, + AllocateAllocationValues(absl::MakeSpan(allocation_values))); VLOG(2) << "Allocation result = " << absl::StrFormat("%x", static_cast(result)); if (result_requires_uncommit(result)) { @@ -3884,8 +1769,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { ++num_repacks_; repacked = true; CHECK_NE(options_.repacker, nullptr); - std::vector - repack_allocation_blocks; + std::vector repack_allocation_blocks; ExportAllocationsForRepacking(repack_allocation_blocks); VLOG(2) << "Repacking."; auto repack_status = @@ -3933,8 +1817,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } if (options_.repack_after_every_allocation) { CHECK_NE(options_.repacker, nullptr); - std::vector - repack_allocation_blocks; + std::vector repack_allocation_blocks; ExportAllocationsForRepacking(repack_allocation_blocks); VLOG(2) << "Final Repacking."; auto repack_status = @@ -3976,26 +1859,23 @@ HloPosition TupleUseToPosition(const HloUse& use) { } // Returns the memory space of the defining position of an Allocation object. -MemorySpaceAssignment::MemorySpace GetDefiningPositionMemorySpace( - const MemorySpaceAssignment::Allocation& allocation) { +MemorySpace GetDefiningPositionMemorySpace(const Allocation& allocation) { if (!allocation.is_copy_like_allocation()) { return allocation.memory_space(); } - if (allocation.memory_space() == - MemorySpaceAssignment::MemorySpace::kDefault) { - return MemorySpaceAssignment::MemorySpace::kAlternate; + if (allocation.memory_space() == MemorySpace::kDefault) { + return MemorySpace::kAlternate; } - return MemorySpaceAssignment::MemorySpace::kDefault; + return MemorySpace::kDefault; } } // namespace -std::vector> +std::vector> AlternateMemoryBestFitHeap::GetLinkedAllocationsInAlternateMemory( absl::Span allocation_values) const { - std::vector> - linked_allocations; + std::vector> linked_allocations; // A map from position to index into linked_allocations. absl::flat_hash_map link_id_map; // Iterate over the allocation values. Find Allocation objects across the @@ -4088,8 +1968,7 @@ AlternateMemoryBestFitHeap::GetLinkedAllocationsInAlternateMemory( if (VLOG_IS_ON(3)) { for (int i = 0; i < linked_allocations.size(); ++i) { VLOG(3) << "Link id = " << i; - for (const MemorySpaceAssignment::Allocation* allocation : - linked_allocations[i]) { + for (const Allocation* allocation : linked_allocations[i]) { VLOG(3) << " " << allocation->ToString(); } } @@ -4128,15 +2007,16 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( const HloPosition& defining_position = allocation->defining_position(); int64_t accessed = - options_.cost_analysis->cost_analysis().output_bytes_accessed( + options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( *defining_position.instruction, defining_position.index); VLOG(3) << " pos: " << defining_position.ToString() << ", accessed: " << accessed << " / " << size; } for (const HloUse& use : allocation->uses()) { int64_t accessed = - options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + options_.cost_analysis->hlo_cost_analysis() + .operand_bytes_accessed(*use.instruction, use.operand_number, + use.operand_index); VLOG(3) << " use: " << use.ToString() << ", accessed: " << accessed << " / " << size; } @@ -4144,12 +2024,11 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( } } - std::vector> - linked_allocations = - GetLinkedAllocationsInAlternateMemory(allocation_values); + std::vector> linked_allocations = + GetLinkedAllocationsInAlternateMemory(allocation_values); std::vector inefficient_sites; - for (const std::vector& - allocation_group : linked_allocations) { + for (const std::vector& allocation_group : + linked_allocations) { // For all of allocation in the linked allocation group, calculate the total // use bytes in alternate memory and async copy bytes. If the ratio between // the two is below inefficient_use_to_copy_ratio, add all of the @@ -4157,8 +2036,7 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( VLOG(3) << "AllocationGroup:"; int64_t copy_bytes = 0; int64_t use_bytes = 0; - for (const MemorySpaceAssignment::Allocation* allocation : - allocation_group) { + for (const Allocation* allocation : allocation_group) { VLOG(3) << " Allocation: " << allocation->ToString(); MemorySpace position_memory_space = GetDefiningPositionMemorySpace(*allocation); @@ -4167,22 +2045,22 @@ AlternateMemoryBestFitHeap::GetInefficientAllocationSites( } if (position_memory_space == MemorySpace::kAlternate) { use_bytes += - options_.cost_analysis->cost_analysis().output_bytes_accessed( + options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( *allocation->defining_position().instruction, allocation->defining_position().index); } if (allocation->memory_space() == MemorySpace::kAlternate) { for (const HloUse& use : allocation->uses()) { use_bytes += - options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + options_.cost_analysis->hlo_cost_analysis() + .operand_bytes_accessed(*use.instruction, use.operand_number, + use.operand_index); } } } VLOG(3) << " use bytes: " << use_bytes << ", copy bytes: " << copy_bytes; if (options_.inefficient_use_to_copy_ratio * copy_bytes > use_bytes) { - for (const MemorySpaceAssignment::Allocation* allocation : - allocation_group) { + for (const Allocation* allocation : allocation_group) { MemorySpace position_memory_space = GetDefiningPositionMemorySpace(*allocation); if (position_memory_space == MemorySpace::kAlternate) { @@ -4266,7 +2144,7 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( FindAliases(&allocation_values); } -AlternateMemoryBestFitHeap::Result +absl::StatusOr AlternateMemoryBestFitHeap::AllocateAllocationValues( absl::Span allocation_values) { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); @@ -4293,11 +2171,26 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( int64_t definition_time = instruction_schedule.at(allocation_value.defining_instruction()); + bool require_no_copy_alternate_mem_allocation = + allocation_value.value()->shape().has_layout() && + allocation_value.value()->shape().layout().memory_space() == + options_.alternate_memory_space; + VLOG(3) << "require_no_copy_alternate_mem_allocation = " + << require_no_copy_alternate_mem_allocation; if (!options_.is_position_allowed_in_alternate_mem_fn( allocation_value.defining_position())) { - AddRequiredAssignment(allocation_value.value(), - allocation_value.defining_instruction(), - MemorySpace::kDefault, definition_time); + if (require_no_copy_alternate_mem_allocation) { + LOG(WARNING) + << "The value " << allocation_value.value()->ToShortString() + << " is pre-colored for alternate memory but the position " + << allocation_value.defining_position().ToString() + << " is not allowed in the alternate memory. Respecting the color " + "but this may break things later in compilation."; + } else { + AddRequiredAssignment(allocation_value.value(), + allocation_value.defining_instruction(), + MemorySpace::kDefault, definition_time); + } } AliasedOffset* preferred_offset = nullptr; @@ -4312,12 +2205,30 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( const AllocationValue::Use& use = allocation_value.uses().at(use_idx); const HloUse hlo_use = use.hlo_use; int64_t use_time = instruction_schedule.at(hlo_use.instruction); - int64_t latest_prefetch_time = use_time; bool allow_no_copy_alternate_mem_allocation = true; bool allow_prefetch = true; bool prefer_no_copy_alternate_mem_allocation = false; + // TODO(b/318886791): Rename boundary variables (here and other places) + // like `latest_prefetch_time` and `earliest_prefetch_time` indicate + // whether they are exclusive or inclusive boundaries. + int64_t latest_prefetch_time = use_time; std::optional earliest_prefetch_time = std::nullopt; + // Assign the required assignment offset as a preferred offset. + std::optional required_assignment = + AliasedRequiredAssignmentForUse(use); + if (required_assignment && + required_assignment->memory_space == MemorySpace::kAlternate) { + if (preferred_offset) { + CHECK_EQ(preferred_offset, required_assignment->offset); + } else { + preferred_offset = required_assignment->offset; + VLOG(3) + << "Setting preferred offset due to required assignment for use: " + << preferred_offset->offset; + } + } + // Control flow calls include kWhile, kCall, and kConditional opcodes. bool is_sequential_call = (GetInstructionCallContext(hlo_use.instruction->opcode()) == @@ -4377,8 +2288,17 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( // Add a required assignment in default memory if the use not allowed in // alternate memory. if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) { - AddRequiredAssignment(allocation_value.value(), hlo_use.instruction, - MemorySpace::kDefault, use_time); + if (require_no_copy_alternate_mem_allocation) { + LOG(WARNING) + << "The value " << allocation_value.value()->ToShortString() + << " is pre-colored for alternate memory but the use " + << hlo_use.ToString() + << " is not allowed in the alternate memory. Respecting the " + "color but this may break things later in compilation."; + } else { + AddRequiredAssignment(allocation_value.value(), hlo_use.instruction, + MemorySpace::kDefault, use_time); + } } else if (use_idx > 0) { // We allow buffers in alternate memory that are passed into // conditionals to give up their alternate memory allocation inside the @@ -4415,16 +2335,22 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( loop_optimized_allocations_map_.end()) { const LoopOptimizedAllocationInfo& loop_optimized_allocation_info = loop_optimized_allocation_it->second; - const MemorySpaceAssignment::Allocation* allocation = + const Allocation* allocation = loop_optimized_allocation_info.loop_optimized_allocation; VLOG(3) << "Found optimized allocation for " << use.hlo_use.ToString() << " (loop idx: " << loop_optimized_allocation_info.use_index << "): " << allocation->ToString(); - if (allocation->is_copy_allocation()) { + if (require_no_copy_alternate_mem_allocation) { + if (allocation->is_copy_allocation() || + allocation->memory_space() == MemorySpace::kDefault) { + LOG(WARNING) << "Optimized allocation could not be applied " + "because the tensor is pre-colored, allocation: " + << allocation->ToString(); + } + } else if (allocation->is_copy_allocation()) { allow_no_copy_alternate_mem_allocation = true; - const MemorySpaceAssignment::CopyAllocation* copy_allocation = - static_cast( - allocation); + const CopyAllocation* copy_allocation = + static_cast(allocation); int64_t effective_copy_start_time = copy_allocation->copy_start_schedule_after(); if (copy_allocation->copy_start_schedule_after() == @@ -4487,9 +2413,9 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( : std::min(definition_time, use_time)); auto overridden_preferred_prefetch_time = GetOverriddenPreferredPrefetchTime( - options_.filter_update_preferred_prefetches, - allocation_value.size(), hlo_use, instruction_schedule, - live_range_start_time, latest_prefetch_time); + options_.preferred_prefetch_overrides, allocation_value.size(), + hlo_use, instruction_schedule, live_range_start_time, + latest_prefetch_time); TF_CHECK_OK(overridden_preferred_prefetch_time.status()); if (overridden_preferred_prefetch_time.value().has_value()) { LOG(INFO) << "Overriding preferred prefetch for " @@ -4519,6 +2445,8 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( request.allow_no_copy_alternate_mem_allocation = allow_no_copy_alternate_mem_allocation; request.allow_prefetch = allow_prefetch; + request.require_no_copy_alternate_mem_allocation = + require_no_copy_alternate_mem_allocation; request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_prefetch_time = preferred_prefetch_time; request.preferred_offset = preferred_offset; @@ -4526,6 +2454,17 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( request.allocation_value = &allocation_value; request.all_use_times = all_use_times; result_mark(AllocateSegment(request), result); + if (request.require_no_copy_alternate_mem_allocation && + result != Result::kSuccess) { + Status failed_precondition = FailedPrecondition( + "The value defined at %s requires allocation in the alternate " + "memory, which could not be satisfied. This typically happens " + "because more pinned buffers are live than the alternate memory " + "capacity.", + allocation_value.defining_instruction()->ToString()); + LOG(ERROR) << failed_precondition; + return failed_precondition; + } if (result_requires_uncommit(result)) { // If the allocation finding failed (e.g., due to running out of // asynchronous copies), then fall back to allocating the buffer @@ -4539,9 +2478,8 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( } // Propagate the allocation to any aliases this use might have had. - MemorySpaceAssignment::Allocation* aliased_allocation = - GetLiveAllocationAt(*allocation_value.allocation_sequence(), - use_time); + Allocation* aliased_allocation = GetLiveAllocationAt( + *allocation_value.allocation_sequence(), use_time); for (const HloPosition& aliased_position : use.aliases) { AddAliasedRequiredAssignment(aliased_position.instruction, aliased_position.index, @@ -4592,7 +2530,7 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( int64_t body_parameter_time = instruction_schedule.at( body_allocation_value_it->defining_instruction()); body_allocation_value_it->mutable_allocation_sequence()->push_back( - std::make_unique( + std::make_unique( **prev_allocation_in_default_mem_it, hlo_use.instruction, body_allocation_value_it->defining_position(), body_parameter_time)); @@ -4610,9 +2548,8 @@ AlternateMemoryBestFitHeap::AllocateAllocationValues( << after_while_allocation_value_it->ToShortString(); int64_t while_time = instruction_schedule.at(hlo_use.instruction); after_while_allocation_value_it->mutable_allocation_sequence() - ->push_back( - std::make_unique( - **prev_allocation_in_default_mem_it, while_time)); + ->push_back(std::make_unique( + **prev_allocation_in_default_mem_it, while_time)); VLOG(3) << "Created: " << after_while_allocation_value_it->allocation_sequence() ->back() @@ -4914,7 +2851,7 @@ struct CopyResourceDumpData { std::string AsynchronousCopyResource::Dump( int64_t start_time, int64_t end_time, - MemorySpaceAssignment::MemorySpace memory_space_filter) const { + MemorySpace memory_space_filter) const { std::vector available = GetCurrentResources(); std::vector time_dump_data; for (int i = start_time; i < end_time; ++i) { @@ -4976,15 +2913,14 @@ std::string AsynchronousCopyResource::Dump( } AlternateMemoryBestFitHeap::AliasedOffset* -AlternateMemoryBestFitHeap::GetAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation) { +AlternateMemoryBestFitHeap::GetAliasedOffset(const Allocation& allocation) { auto aliased_offset_it = aliased_offset_map_.find(&allocation); CHECK(aliased_offset_it != aliased_offset_map_.end()); return aliased_offset_it->second; } void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation, + const Allocation& allocation, AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) { CHECK(allocation.memory_space() == MemorySpace::kAlternate); CHECK(!aliased_offset_map_.contains(&allocation)); @@ -4997,10 +2933,8 @@ void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset( aliased_offset_map_[&allocation] = aliased_offset; } -/*static*/ MemorySpaceAssignment::Allocation* -AlternateMemoryBestFitHeap::GetLiveAllocationAt( - const MemorySpaceAssignment::AllocationSequence& allocations, - int64_t time) { +/*static*/ Allocation* AlternateMemoryBestFitHeap::GetLiveAllocationAt( + const AllocationSequence& allocations, int64_t time) { for (auto allocation_it = allocations.rbegin(); allocation_it != allocations.rend(); ++allocation_it) { if ((*allocation_it)->start_time() <= time && @@ -5024,8 +2958,8 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( int cross_program_prefetch_index = module->CrossProgramPrefetches().size(); module->AddCrossProgramPrefetch(parameter, buffer->index()); - MemorySpaceAssignment::AllocationSequence allocations; - allocations.push_back(std::make_unique( + AllocationSequence allocations; + allocations.push_back(std::make_unique( buffer->defining_position(), MemorySpace::kDefault, kDummyChunk, prefetch_candidate.start, prefetch_candidate.end, /*is_scoped_allocation=*/false)); @@ -5142,23 +3076,24 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( // Add a repack allocation block for the Allocation objects in alternate // memory. - std::vector colocations; + std::vector colocations; for (int i = allocations_initial_size; i < allocations_->size(); ++i) { const auto& allocation = allocations_->at(i); if (allocation->memory_space() == MemorySpace::kAlternate) { repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( allocation->start_time(), allocation->end_time(), allocation->chunk().size, allocation->chunk().offset, - static_cast(colocations.size()), allocation.get())); - RepackAllocationBlock* inserted = &repack_allocation_blocks_.back(); - for (RepackAllocationBlock* colocation : colocations) { - inserted->colocations.push_back(colocation); - colocation->colocations.push_back(inserted); - } - inserted->colocations.emplace_back(inserted); - colocations.emplace_back(inserted); + static_cast(repack_allocation_blocks_.size()), + allocation.get())); + colocations.push_back(&repack_allocation_blocks_.back()); } } + for (int i = 0; i < colocations.size() - 1; ++i) { + colocations[i]->next_colocated = colocations[i + 1]; + } + if (!colocations.empty()) { + colocations.back()->next_colocated = colocations.front(); + } ClearPendingChunks(); } @@ -5166,22 +3101,22 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() { const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); - std::vector colocations; for (int i = 0; i < instruction_sequence.size(); ++i) { const HloInstruction* instruction = instruction_sequence[i]; - int64_t reserved_scoped_memory = options_.reserved_scoped_memory_fn( - instruction, /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{}); + int64_t reserved_scoped_memory = + std::min(options_.reserved_scoped_memory_fn( + instruction, /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{}), + options_.max_size_in_bytes); if (reserved_scoped_memory != 0) { VLOG(1) << "Allocate reserved scoped memory at " << i << " (" << instruction->name() << "): " << reserved_scoped_memory; - MemorySpaceAssignment::BufferInterval interval; + MsaBufferInterval interval; interval.buffer = nullptr; interval.size = reserved_scoped_memory; interval.start = i; interval.end = i; interval.need_allocation = true; - interval.colocations = {}; Chunk chunk_candidate = FindChunkCandidate(interval, /*preferred_offset=*/0); CHECK_EQ(chunk_candidate.offset, 0); @@ -5192,17 +3127,15 @@ void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() { instruction, i, reserved_scoped_memory, buffer_info_str_); } - allocations_->push_back( - std::make_unique( - HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate, - chunk_candidate, i, i, /*is_scoped_allocation=*/true)); + allocations_->push_back(std::make_unique( + HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate, + chunk_candidate, i, i, /*is_scoped_allocation=*/true)); repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( i, i, reserved_scoped_memory, /*initial_offset=*/0, static_cast(repack_allocation_blocks_.size()), allocations_->back().get())); - colocations.push_back(&repack_allocation_blocks_.back()); } } // If requested, make all scoped allocations to colocate with each other so @@ -5211,13 +3144,19 @@ void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() { // opportunity to deduplicate different ops. However, this may hurt the // memory packing efficiency. if (options_.allocate_reserved_scoped_memory_at_same_offset) { - for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block : - colocations) { - repack_block->colocations = colocations; + for (auto allocation_block_it = repack_allocation_blocks_.begin(); + allocation_block_it != repack_allocation_blocks_.end() && + std::next(allocation_block_it) != repack_allocation_blocks_.end(); + ++allocation_block_it) { + allocation_block_it->next_colocated = &*std::next(allocation_block_it); + } + if (!repack_allocation_blocks_.empty()) { + repack_allocation_blocks_.back().next_colocated = + &repack_allocation_blocks_.front(); } } else { for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { - allocation_block.colocations.push_back(&allocation_block); + allocation_block.next_colocated = &allocation_block; } } ClearPendingChunks(); @@ -5267,7 +3206,7 @@ AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, - const MemorySpaceAssignment::Allocation* aliased_allocation) { + const Allocation* aliased_allocation) { AliasedOffset* offset = nullptr; if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { offset = GetAliasedOffset(*aliased_allocation); @@ -5278,8 +3217,8 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( void AlternateMemoryBestFitHeap::AddRequiredAssignment( const HloValue* value, const HloInstruction* instruction, - MemorySpaceAssignment::MemorySpace memory_space, int64_t time, - AliasedOffset* offset, bool add_to_pending) { + MemorySpace memory_space, int64_t time, AliasedOffset* offset, + bool add_to_pending) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); @@ -5399,23 +3338,24 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { continue; } int64_t constant_instruction_time = constant_instruction_it->second; - for (const auto& indexed_shape : - ShapeUtil::GetLeafShapes(instruction->shape())) { - const ShapeIndex& index = indexed_shape.index; - for (const HloBuffer* buffer : - alias_analysis_.ComputeBuffersAt(instruction, index)) { - for (const HloValue* value : buffer->values()) { - VLOG(3) << "Adding required assignment for constant value = " - << value->ToShortString() - << " time = " << constant_instruction_time - << " space = def"; - AddRequiredAssignment(value, instruction, MemorySpace::kDefault, - constant_instruction_time, - /*offset=*/nullptr, - /*add_to_pending=*/false); - } - } - } + ShapeUtil::ForEachLeafShape( + instruction->shape(), + [&](const Shape& /*sub_shape*/, const ShapeIndex& index) { + for (const HloBuffer* buffer : + alias_analysis_.ComputeBuffersAt(instruction, index)) { + for (const HloValue* value : buffer->values()) { + VLOG(3) << "Adding required assignment for constant value = " + << value->ToShortString() + << " time = " << constant_instruction_time + << " space = def"; + AddRequiredAssignment(value, instruction, + MemorySpace::kDefault, + constant_instruction_time, + /*offset=*/nullptr, + /*add_to_pending=*/false); + } + } + }); } } } @@ -5446,6 +3386,8 @@ void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { << "Mismatch in required assignments at time " << instruction_time << " value: " << value->ToString(); } else { + VLOG(3) << "Adding required assignment: " << value->ToShortString() + << " at " << instruction_time << " at def"; required_assignments.push_back( {MemorySpace::kDefault, instruction_time}); } @@ -5511,7 +3453,7 @@ void AlternateMemoryBestFitHeap::UpdateReservedScopedAllocationSize() { } // Update scoped allocation sizes. for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { - MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation; + Allocation* allocation = allocation_block.allocation; if (allocation->is_scoped_allocation()) { allocation_block.size = reserved_scoped_memory_map[allocation->start_time()]; @@ -5522,8 +3464,7 @@ void AlternateMemoryBestFitHeap::UpdateReservedScopedAllocationSize() { } void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( - std::vector& allocations) { - using SlicedCopyAllocation = MemorySpaceAssignment::SlicedCopyAllocation; + std::vector& allocations) { using SliceDetail = SlicedCopyAllocation::SliceDetail; if (options_.reduce_scoped_memory_limit) { @@ -5556,17 +3497,16 @@ void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( // Since this is a sliced allocation, construct SlicedAllocationData to // attach to the AllocationBlock. - MemorySpaceAssignmentRepacker::SlicedAllocationData original_slice_data; + SlicedAllocationData original_slice_data; for (const SliceDetail* slice_detail : slice_details_sorted_by_offset) { CHECK_EQ(slice_detail->copy_start_after_time, slice_detail->slice_decision.exclusive_start_time); - original_slice_data.slices_sorted_by_offset.push_back( - MemorySpaceAssignmentRepacker::Slice{ - slice_detail->slice_decision.chunk.size, - slice_detail->slice_decision.chunk.offset, - /*inclusive_start_time=*/ - ExclusiveToInclusiveStartTime( - slice_detail->slice_decision.exclusive_start_time)}); + original_slice_data.slices_sorted_by_offset.push_back(AllocatedSlice{ + slice_detail->slice_decision.chunk.size, + slice_detail->slice_decision.chunk.offset, + /*inclusive_start_time=*/ + ExclusiveToInclusiveStartTime( + slice_detail->slice_decision.exclusive_start_time)}); } allocation_block.original_slice_data = std::move(original_slice_data); @@ -5587,7 +3527,7 @@ void AlternateMemoryBestFitHeap::ImportRepackedAllocations() { void AlternateMemoryBestFitHeap::ImportRepackedNonSlicedAllocation( RepackAllocationBlock& block) { - MemorySpaceAssignment::Allocation* allocation = block.allocation; + Allocation* allocation = block.allocation; int64_t original_offset = block.initial_offset; int64_t repacked_offset = block.offset; @@ -5606,9 +3546,11 @@ void AlternateMemoryBestFitHeap::ImportRepackedNonSlicedAllocation( void AlternateMemoryBestFitHeap::ImportRepackedSlicedAllocation( RepackAllocationBlock& block) { - using SlicedCopyAllocation = MemorySpaceAssignment::SlicedCopyAllocation; + using SlicedCopyAllocation = memory_space_assignment::SlicedCopyAllocation; using SliceDetail = SlicedCopyAllocation::SliceDetail; + CHECK_OK(AreRepackedSlicesValid(block)); + SlicedCopyAllocation* allocation = dynamic_cast(block.allocation); CHECK(block.allocation->is_sliced_copy_allocation()); @@ -5620,9 +3562,6 @@ void AlternateMemoryBestFitHeap::ImportRepackedSlicedAllocation( // Update the Allocation, AllocationBlock, and interval_tree_. allocation->set_offset(repacked_offset); if (block.repacked_slice_data.has_value()) { - CHECK(block.original_slice_data.has_value()); - CHECK_EQ(allocation->slice_details_sorted_by_start_time().size(), - block.repacked_slice_data->slices_sorted_by_offset.size()); allocation->ImportRepackedSliceData(*block.repacked_slice_data); } else { allocation->AddDiffToAllSliceOffsets(repacked_offset - original_offset); @@ -5661,6 +3600,53 @@ void AlternateMemoryBestFitHeap::ImportRepackedSlicedAllocation( << "; Allocation: " << allocation->ToString(); } +Status AlternateMemoryBestFitHeap::AreRepackedSlicesValid( + const RepackAllocationBlock& block) { + if (!block.repacked_slice_data.has_value()) { + return OkStatus(); + } + if (!block.original_slice_data.has_value()) { + return InvalidArgumentStrCat( + "Repacked sliced allocation has repacked slice data but not original " + "slice data."); + } + int64_t num_slices = + block.original_slice_data->slices_sorted_by_offset.size(); + if (num_slices != block.repacked_slice_data->slices_sorted_by_offset.size()) { + return InvalidArgumentStrCat( + "Repacked sliced allocation has ", num_slices, + " slices but repacking has data for ", + block.repacked_slice_data->slices_sorted_by_offset.size(), " slices."); + } + + // Ensure that the slice size to start time mapping has not changed. If it + // changes, its invalidates MSA's internal state, e.g., the peak_memory_usage_ + // data structure. + std::vector> original_size_to_time_mapping; + original_size_to_time_mapping.reserve(num_slices); + for (const AllocatedSlice& slice : + block.original_slice_data->slices_sorted_by_offset) { + original_size_to_time_mapping.push_back( + std::make_pair(slice.size, slice.inclusive_start_time)); + }; + absl::c_sort(original_size_to_time_mapping); + std::vector> repacked_size_to_time_mapping; + repacked_size_to_time_mapping.reserve(num_slices); + for (const AllocatedSlice& slice : + block.repacked_slice_data->slices_sorted_by_offset) { + repacked_size_to_time_mapping.push_back( + std::make_pair(slice.size, slice.inclusive_start_time)); + }; + absl::c_sort(repacked_size_to_time_mapping); + if (original_size_to_time_mapping != repacked_size_to_time_mapping) { + return InvalidArgumentStrCat( + "Repacked slices do not preserve the initial slice size-start time " + "mappings."); + } + + return OkStatus(); +} + void AlternateMemoryBestFitHeap::UncommitPendingChunks( absl::Span allocation_values) { // Clear the allocation sequence of the allocation values so that in case we @@ -5727,8 +3713,7 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks( void AlternateMemoryBestFitHeap::FinalizeAllocations( absl::Span allocation_values) { - absl::flat_hash_map> + absl::flat_hash_map> colocation_map; for (AllocationValue& allocation_value : allocation_values) { for (auto& allocation : *allocation_value.mutable_allocation_sequence()) { @@ -5745,8 +3730,7 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( } } allocations_->push_back(std::move(allocation)); - MemorySpaceAssignment::Allocation* inserted_allocation = - allocations_->back().get(); + Allocation* inserted_allocation = allocations_->back().get(); if (inserted_allocation->memory_space() == MemorySpace::kAlternate) { colocation_map[GetAliasedOffset(*inserted_allocation)].push_back( inserted_allocation); @@ -5757,9 +3741,8 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( // Export these to repack_allocation_blocks_ so that we can repack them to // reduce fragmentation. for (auto& colocation : colocation_map) { - std::vector colocations; - for (MemorySpaceAssignment::Allocation* colocated_allocation : - colocation.second) { + std::vector colocations; + for (Allocation* colocated_allocation : colocation.second) { repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( colocated_allocation->start_time(), colocated_allocation->end_time(), colocated_allocation->chunk().size, @@ -5768,9 +3751,11 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( colocated_allocation)); colocations.push_back(&repack_allocation_blocks_.back()); } - for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block : - colocations) { - repack_block->colocations = colocations; + for (int i = 0; i < colocations.size() - 1; ++i) { + colocations[i]->next_colocated = colocations[i + 1]; + } + if (!colocations.empty()) { + colocations.back()->next_colocated = colocations.front(); } } ClearPendingChunks(); @@ -5825,7 +3810,7 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( // consumed multiple times by the same instruction. We can just find the // previous allocation and use that allocation. if (request.inclusive_start_time == request.end_time) { - MemorySpaceAssignment::Allocation* allocation = + Allocation* allocation = GetLiveAllocationAt(*allocation_sequence, request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use->hlo_use); @@ -5842,6 +3827,9 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( << " use = " << request.use->hlo_use.ToString() << ". Size = " << request.size << ", def pos = " << defining_position.ToString(); + if (request.require_no_copy_alternate_mem_allocation) { + VLOG(2) << "Requiring alternate memory allocation."; + } CHECK_LE(request.inclusive_start_time, request.end_time); if (VLOG_IS_ON(3) && options_.cost_analysis) { const HloPosition& defining_position = @@ -5853,12 +3841,13 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( << " use benefit = " << options_.cost_analysis->GetAlternateMemoryBenefit( request.use->hlo_use); - VLOG(3) << "Definition bytes accessed = " - << options_.cost_analysis->cost_analysis().output_bytes_accessed( - *defining_position.instruction, defining_position.index) - << ", use bytes accessed = " - << options_.cost_analysis->cost_analysis().operand_bytes_accessed( - *use.instruction, use.operand_number, use.operand_index); + VLOG(3) + << "Definition bytes accessed = " + << options_.cost_analysis->hlo_cost_analysis().output_bytes_accessed( + *defining_position.instruction, defining_position.index) + << ", use bytes accessed = " + << options_.cost_analysis->hlo_cost_analysis().operand_bytes_accessed( + *use.instruction, use.operand_number, use.operand_index); } // There could be a requirement to pin this buffer to default memory either @@ -5912,12 +3901,11 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( aliased_chunk = Chunk::FromOffsetSize( required_assignment_at_start->offset->offset, request.size); } - allocation_sequence->push_back( - std::make_unique( - defining_position, required_assignment_at_start->memory_space, - aliased_chunk, request.inclusive_start_time, - request.inclusive_start_time, - /*is_scoped_allocation=*/false)); + allocation_sequence->push_back(std::make_unique( + defining_position, required_assignment_at_start->memory_space, + aliased_chunk, request.inclusive_start_time, + request.inclusive_start_time, + /*is_scoped_allocation=*/false)); if (required_assignment_at_start->memory_space == MemorySpace::kAlternate) { CreateOrAddToAliasedOffset(*allocation_sequence->back(), @@ -5935,8 +3923,14 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( if (allocation_result == Result::kSuccess) { return Result::kSuccess; } + // If we required alternate memory allocation, return on failure. + if (request.require_no_copy_alternate_mem_allocation) { + return allocation_result; + } } + CHECK(!request.require_no_copy_alternate_mem_allocation); + auto prev_allocation_it = allocation_sequence->rbegin(); // Find a previous allocation that is in the default memory space (not // necessarily the very last allocation). @@ -5960,12 +3954,10 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( } prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { - allocation_sequence->push_back( - std::make_unique( - defining_position, MemorySpace::kDefault, - /*chunk=*/std::nullopt, request.inclusive_start_time, - request.end_time, - /*is_scoped_allocation=*/false)); + allocation_sequence->push_back(std::make_unique( + defining_position, MemorySpace::kDefault, + /*chunk=*/std::nullopt, request.inclusive_start_time, request.end_time, + /*is_scoped_allocation=*/false)); prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } @@ -6004,21 +3996,17 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( // Warn if the prefetch time picked doesn't match the preferred prefetch // time. CHECK(!request.allocation_value->allocation_sequence()->empty()); - const MemorySpaceAssignment::Allocation* allocation = + const Allocation* allocation = request.allocation_value->allocation_sequence()->back().get(); int64_t prefetch_time = 0; if (allocation->is_copy_allocation()) { - prefetch_time = - static_cast( - allocation) - ->copy_start_schedule_after(); + prefetch_time = static_cast(allocation) + ->copy_start_schedule_after(); } else if (allocation->is_sliced_copy_allocation()) { - prefetch_time = - static_cast( - allocation) - ->slice_details_sorted_by_start_time() - .front() - .copy_start_after_time; + prefetch_time = static_cast(allocation) + ->slice_details_sorted_by_start_time() + .front() + .copy_start_after_time; } else { LOG(FATAL) << "Prefetch allocation are expected to be " "CopyAllocations or SlicedCopyAllocations."; @@ -6065,34 +4053,28 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( } void AlternateMemoryBestFitHeap::AddAsyncCopy( - MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpace memory_space, std::optional chunk, - int64_t exclusive_start_time, int64_t end_time, - int64_t copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations, + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, int64_t exclusive_start_time, int64_t end_time, + int64_t copy_done_schedule_before_time, AllocationSequence* allocations, AliasedOffset* aliased_offset, float resource, std::optional cross_program_prefetch_index) { VLOG(3) << "Copy to " - << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault - ? "default" - : "alternate") + << (memory_space == MemorySpace::kDefault ? "default" : "alternate") << " memory in (" << exclusive_start_time << ", " << copy_done_schedule_before_time << "), keeping until " << end_time << ", estimated copy resource is " << resource; CHECK_LT(exclusive_start_time, copy_done_schedule_before_time); - allocations->push_back( - std::make_unique( - prev_allocation, memory_space, chunk, exclusive_start_time, - copy_done_schedule_before_time, end_time, - cross_program_prefetch_index)); + allocations->push_back(std::make_unique( + prev_allocation, memory_space, chunk, exclusive_start_time, + copy_done_schedule_before_time, end_time, cross_program_prefetch_index)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. pending_async_copies_.push_back({exclusive_start_time, copy_done_schedule_before_time, resource, memory_space, next_async_copy_id_++}); - if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { + if (memory_space == MemorySpace::kAlternate) { prefetch_interval_tree_.Add( /*start=*/ ExclusiveToInclusiveStartTime(exclusive_start_time), @@ -6120,8 +4102,8 @@ namespace { // - When the allocation for the slice ends // - An estimation of how much copy resource the slice consumes std::string SliceTimesAndCopyResourcesToString( - const std::vector& slice_decisions, - int64_t prefetch_end, int64_t allocation_end) { + const std::vector& slice_decisions, int64_t prefetch_end, + int64_t allocation_end) { std::vector slice_strings; slice_strings.reserve(slice_decisions.size()); @@ -6145,11 +4127,9 @@ std::string SliceTimesAndCopyResourcesToString( } // namespace void AlternateMemoryBestFitHeap::AddAsyncSlicesForPrefetch( - const MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpaceAssignment::AllocationSequence* allocations, + const Allocation& prev_allocation, AllocationSequence* allocations, AliasedOffset* aliased_offset, - const std::vector& - slice_decisions_sorted_by_start_time, + const std::vector& slice_decisions_sorted_by_start_time, int64_t prefetch_end_time, int64_t allocation_end_time) { VLOG(3) << "Sliced copy to alternate memory. " << SliceTimesAndCopyResourcesToString( @@ -6160,19 +4140,18 @@ void AlternateMemoryBestFitHeap::AddAsyncSlicesForPrefetch( return slice_decision.exclusive_start_time < prefetch_end_time; })); - allocations->push_back( - std::make_unique( - prev_allocation, MemorySpaceAssignment::MemorySpace::kAlternate, - slice_decisions_sorted_by_start_time, prefetch_end_time, - allocation_end_time, options_.update_layout_fn)); + allocations->push_back(std::make_unique( + prev_allocation, MemorySpace::kAlternate, + slice_decisions_sorted_by_start_time, prefetch_end_time, + allocation_end_time, options_.sliced_prefetch_options, + options_.get_equivalent_s8_shape_fn)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. for (const auto& slice_decision : slice_decisions_sorted_by_start_time) { pending_async_copies_.push_back( {slice_decision.exclusive_start_time, prefetch_end_time, - slice_decision.copy_resource_consumed, - MemorySpaceAssignment::MemorySpace::kAlternate, + slice_decision.copy_resource_consumed, MemorySpace::kAlternate, next_async_copy_id_++}); prefetch_interval_tree_.Add(slice_decision.exclusive_start_time, prefetch_end_time, kDummyChunk); @@ -6217,7 +4196,7 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { - MemorySpaceAssignment::Allocation* prev_allocation = nullptr; + Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; if (request.allocation_value->allocation_sequence()->empty()) { // There hasn't been any allocations for this interval so far. We can @@ -6242,7 +4221,8 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( request.allocation_value->defining_position(); // If prefer_no_copy_alternate_mem_allocation is true, bypass the live range // duration checks. - if (!request.prefer_no_copy_alternate_mem_allocation && + if (!request.require_no_copy_alternate_mem_allocation && + !request.prefer_no_copy_alternate_mem_allocation && !options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( defining_position.shape(), request.inclusive_start_time, request.end_time)) { @@ -6266,11 +4246,15 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( } if (request.preferred_offset) { - // Sanity check that if there is a preferred offset provided in the request, - // it matches with the previous allocation. - CHECK(!preferred_offset || request.preferred_offset == preferred_offset) - << "preferred_offset = " << preferred_offset->offset - << ", request.preferred_offset = " << request.preferred_offset->offset; + // If there is a preferred offset provided in the request and if it doesn't + // match the previous allocation, this request cannot be satisified. + if (preferred_offset && request.preferred_offset != preferred_offset) { + VLOG(3) << "Cannot perform no-copy allocation due to mismatch: " + "preferred_offset = " + << preferred_offset->offset << ", request.preferred_offset = " + << request.preferred_offset->offset; + return Result::kFailConflictingPreferredOffsets; + } preferred_offset = request.preferred_offset; } @@ -6321,7 +4305,7 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( prev_allocation->set_end_time(request.end_time); } else { request.allocation_value->mutable_allocation_sequence()->push_back( - std::make_unique( + std::make_unique( defining_position, MemorySpace::kAlternate, chunk_candidate, request.inclusive_start_time, request.end_time, /*is_scoped_allocation=*/false)); @@ -6343,7 +4327,7 @@ AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( const AllocationRequest& request) { CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); - MemorySpaceAssignment::Allocation* prev_allocation = + Allocation* prev_allocation = request.allocation_value->allocation_sequence()->back().get(); // We do not ever expect an Evict() to be immediately proceeded by a prefetch. // If that case ever occurs, the eviction_exclusive_start_time below will be @@ -6476,7 +4460,7 @@ namespace { // A debugging/logging method for describing a sliced solution. std::string DescribeSlicedBufferMove( - const std::vector& slice_decisions, + const std::vector& slice_decisions, const AlternateMemoryBestFitHeap::HeapResult& heap_result, const AlternateMemoryBestFitHeap::Chunk& full_chunk, absl::string_view prefetch_picker_debug_string) { @@ -6501,7 +4485,7 @@ std::string DescribeSlicedBufferMove( AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( const AllocationRequest& request, - MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) { + Allocation& prev_allocation_in_default_mem) { // Try partially placing the buffer in the alternate space. The time that is // overlapped will be used to asynchronously copy the buffer from the // default memory to the alternate memory. @@ -6686,12 +4670,10 @@ void AlternateMemoryBestFitHeap::GenerateSliceProposal( } VLOG(6) << log_prefix() << ". Slice proposal = [" - << absl::StrJoin( - status_or_proposal.value(), ", ", - [](std::string* out, - const MemorySpaceAssignment::SliceProposal& proposal) { - absl::StrAppend(out, proposal.ToString()); - }) + << absl::StrJoin(status_or_proposal.value(), ", ", + [](std::string* out, const SliceProposal& proposal) { + absl::StrAppend(out, proposal.ToString()); + }) << "]"; context.slice_proposal_collection = std::move(status_or_proposal.value()); @@ -6725,7 +4707,7 @@ void AlternateMemoryBestFitHeap::SetupPrefetchWorkingIntervalsAndSliceProposal( context.sliced_solution_intervals.full)); std::vector sizes; sizes.reserve(context.slice_proposal_collection->size()); - for (const MemorySpaceAssignment::SliceProposal& single_slice_proposal : + for (const SliceProposal& single_slice_proposal : *context.slice_proposal_collection) { sizes.push_back(single_slice_proposal.slice_size); } @@ -6828,12 +4810,10 @@ float CopyResourceForShape(const Options& options, const Shape& shape) { // collection, in descending order. std::vector GetCopyResourcesSortedDescending( const Options& options, - const MemorySpaceAssignment::SliceProposalCollection& - slice_proposal_collection) { + const SliceProposalCollection& slice_proposal_collection) { std::vector copy_resources; copy_resources.reserve(slice_proposal_collection.size()); - for (const MemorySpaceAssignment::SliceProposal& proposal : - slice_proposal_collection) { + for (const SliceProposal& proposal : slice_proposal_collection) { copy_resources.push_back( CopyResourceForShape(options, proposal.slice_shape)); } @@ -7051,20 +5031,16 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::CheckPrefetchFit( GetCandidateToProposalIndexMap(chunk_candidates); // Create slice decisions, sorted by time. - std::vector - slice_decisions_sorted_by_start_time; + std::vector slice_decisions_sorted_by_start_time; for (int64_t slice_time = 0; slice_time < sliced_buffer_interval->num_slices(); ++slice_time) { - const MemorySpaceAssignment::SliceProposal& proposal = - context.slice_proposal_collection->at( - candidate_to_proposal_index_map[slice_time]); + const SliceProposal& proposal = context.slice_proposal_collection->at( + candidate_to_proposal_index_map[slice_time]); copy_resource_per_slice_sorted_by_start_time[slice_time] = CopyResourceForShape(options_, proposal.slice_shape); - slice_decisions_sorted_by_start_time.push_back( - MemorySpaceAssignment::SliceDecision{ - chunk_candidates[slice_time], - exclusive_slice_start_times[slice_time], proposal, - copy_resource_per_slice_sorted_by_start_time[slice_time]}); + slice_decisions_sorted_by_start_time.push_back(SliceDecision{ + chunk_candidates[slice_time], exclusive_slice_start_times[slice_time], + proposal, copy_resource_per_slice_sorted_by_start_time[slice_time]}); } // Check that we have enough copy resources for all the slice decisions. @@ -7357,7 +5333,7 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidates( return {}; } -StatusOr +absl::StatusOr MemorySpaceAssignment::CalculateAsyncCopyStats() const { AsyncCopyStats stats; int64_t current_copies = 0; @@ -7401,7 +5377,7 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const { return stats; } -/*static*/ StatusOr> +/*static*/ absl::StatusOr> MemorySpaceAssignment::Run(HloModule* module, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, @@ -7417,7 +5393,7 @@ MemorySpaceAssignment::Run(HloModule* module, alias_analysis); } -StatusOr> +absl::StatusOr> MemorySpaceAssignment::RunMemorySpaceAssignment( const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis) { @@ -7471,45 +5447,6 @@ Status MemorySpaceAssignment::FindAllocationSequence( return OkStatus(); } -bool MemorySpaceAssignment::Allocation::is_copy_like_allocation() const { - return is_copy_allocation() || is_sliced_copy_allocation(); -} - -void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { - HloInstruction* operand = - use.instruction->mutable_operand(use.operand_number); - // If the use is a tuple, look inside the tuple to find the actual use. - for (int64_t index : use.operand_index) { - if (operand->opcode() != HloOpcode::kTuple) { - break; - } - operand = operand->mutable_operand(index); - } - - // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts. - std::function get_simplified_operand; - get_simplified_operand = [&](HloInstruction* instruction) { - while (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* operand = - get_simplified_operand(instruction->mutable_operand(0)); - if (operand->opcode() == HloOpcode::kTuple) { - instruction = operand->mutable_operand(instruction->tuple_index()); - } else { - return instruction; - } - } - return instruction; - }; - operand = get_simplified_operand(operand); - - uses_.push_back(use); -} - -void MemorySpaceAssignment::Allocation::set_offset(int64_t offset) { - CHECK(chunk_.has_value()); - *chunk_ = Chunk::FromOffsetSize(offset, chunk_->size); -} - float MemorySpaceAssignment::ComputeEstimatedElapsedTime( const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) { absl::flat_hash_map> @@ -7552,660 +5489,16 @@ float MemorySpaceAssignment::ComputeEstimatedElapsedTime( options_.cost_analysis->GetInstructionElapsedInAlternateMemory( *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); - float while_nest_multiplier = IPow( - options_.xla_tpu_memory_space_assignment_while_execution_count, - options_.cost_analysis->CalculateComputationNestLevel( - instruction, - /*while_only=*/true)); + float while_nest_multiplier = + options_.cost_analysis->GetWhileNestMultiplier( + options_.cost_analysis->CalculateComputationNestLevel( + instruction, + /*while_only=*/true)); total_elapsed += while_nest_multiplier * instruction_elapsed; } return total_elapsed; } -Status MemorySpaceAssignment::Allocation::Process() { - if (is_scoped_allocation()) { - // Nothing to do here for scoped allocations. - return OkStatus(); - } - HloInstruction* producing_instruction = AddGetTupleElements(); - HloComputation* computation = producing_instruction->parent(); - for (const HloUse& use : uses_) { - Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); - HloInstruction* replacement_instruction = producing_instruction; - if (operand_shape.IsTuple()) { - TF_ASSIGN_OR_RETURN( - replacement_instruction, - TupleUtil::ReplaceTupleWith( - producing_instruction, - use.instruction->mutable_operand(use.operand_number), - use.operand_index)); - } else if (operand_shape != producing_instruction->shape()) { - VLOG(4) << "Old shape = " << operand_shape.ToString() - << ", new shape = " << producing_instruction->shape().ToString() - << "; inserting a bitcast."; - replacement_instruction = computation->AddInstruction( - HloInstruction::CreateBitcast(operand_shape, producing_instruction)); - } - TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( - use.operand_number, replacement_instruction)); - } - return OkStatus(); -} - -HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const { - CHECK_NE(defining_position().instruction, nullptr); - - Shape shape = defining_position().shape(); - CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " - << shape.ToString() - << " position = " << defining_position().shape(); - return TupleUtil::AddGetTupleElements(defining_position()); -} - -std::string MemorySpaceAssignment::Allocation::ToString() const { - std::string memory_space_str = - memory_space_ == MemorySpace::kDefault ? "def" : "alt"; - if (chunk_) { - absl::StrAppend(&memory_space_str, " (off: ", chunk_->offset, ")"); - } - return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""), - "Allocation in ", memory_space_str, " defined at ", - defining_position_.ToString(), - ", start_time:", start_time(), ", end_time:", end_time(), - ", uses: ", UsesToString(uses())); -} - -std::string MemorySpaceAssignment::CopyAllocation::ToString() const { - std::string memory_space_str = - memory_space_ == MemorySpace::kDefault ? "def" : "alt"; - if (chunk_) { - absl::StrAppend(&memory_space_str, " (off: ", chunk_->offset, ")"); - } - return absl::StrCat("Copy Allocation in ", memory_space_str, - ", start_time:", start_time(), ", end_time:", end_time(), - ", copy_start_after_time: ", copy_start_schedule_after(), - ", copy_done_before_time: ", copy_done_schedule_before(), - ", uses: ", UsesToString(uses()), ", from ", - prev_allocation_.ToString()); -} - -std::string MemorySpaceAssignment::SliceParam::ToString() const { - return absl::StrCat("[", start_inclusive, ",", end_exclusive, ")"); -} - -bool MemorySpaceAssignment::SliceParam::operator==( - const SliceParam& other) const { - return start_inclusive == other.start_inclusive && - end_exclusive == other.end_exclusive; -} - -std::string MemorySpaceAssignment::SliceProposal::ToString() const { - return absl::StrCat( - "{ slice_shape: ", slice_shape.ToString(true), ", slice_params: { ", - absl::StrJoin(slice_params, ", ", - [](std::string* out, const SliceParam& param) { - absl::StrAppend(out, param.ToString()); - }), - " }, slice_size: ", slice_size, " }"); -} - -std::ostream& operator<<(std::ostream& os, - const MemorySpaceAssignment::SliceProposal& proposal) { - os << proposal.ToString(); - return os; -} - -std::tuple&, - int64_t> -MemorySpaceAssignment::SliceProposal::ToTuple() const { - return std::make_tuple(std::ref(slice_shape), std::ref(slice_params), - slice_size); -} - -bool MemorySpaceAssignment::SliceProposal::operator==( - const SliceProposal& other) const { - return ToTuple() == other.ToTuple(); -} - -std::string MemorySpaceAssignment::SliceDecision::ToString() const { - return absl::StrCat("{ chunk: ", chunk.ToString(), - ", (exclusive) start_time: ", exclusive_start_time, - ", sizing: ", sizing.ToString(), - ", copy_resource_consumed: ", copy_resource_consumed, - " }"); -} - -namespace { - -std::tuple -SliceDecisionToTuple(const MemorySpaceAssignment::SliceDecision& decision) { - return std::make_tuple( - std::ref(decision.chunk), decision.exclusive_start_time, - std::ref(decision.sizing), decision.copy_resource_consumed); -} - -} // namespace - -bool MemorySpaceAssignment::SliceDecision::operator==( - const SliceDecision& other) const { - return SliceDecisionToTuple(*this) == SliceDecisionToTuple(other); -} - -std::string MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::ToString() - const { - return absl::StrCat("{ slice_decision: ", slice_decision.ToString(), - ", copy_start_after_time: ", copy_start_after_time, - ", copy_done_before_time: ", copy_done_before_time, " }"); -} - -namespace { - -std::tuple -SliceDetailToTuple( - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_detail) { - return std::make_tuple(std::ref(slice_detail.slice_decision), - slice_detail.copy_start_after_time, - slice_detail.copy_done_before_time, - slice_detail.copy_start, slice_detail.copy_done); -} - -} // namespace - -bool MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::operator==( - const SliceDetail& other) const { - return SliceDetailToTuple(*this) == SliceDetailToTuple(other); -} - -Status -MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail::CreateAsyncSlice( - const Shape& original_shape, HloInstruction& producer, - HloComputation& parent, absl::FunctionRef update_layout_fn) { - if (original_shape.rank() != slice_decision.sizing.slice_params.size()) { - return FailedPrecondition( - "%s", absl::StrCat("The number of SlicedCopyAllocation parameters ", - slice_decision.sizing.slice_params.size(), - " does not match the rank ", original_shape.rank(), - " of the tensor we are slicing.")); - } - - std::vector start_indices; - start_indices.reserve(slice_decision.sizing.slice_params.size()); - std::vector limit_indices; - limit_indices.reserve(slice_decision.sizing.slice_params.size()); - std::vector strides; - strides.reserve(slice_decision.sizing.slice_params.size()); - Shape new_shape(original_shape); - - for (int i = 0; i < slice_decision.sizing.slice_params.size(); ++i) { - const SliceParam& slice_param = slice_decision.sizing.slice_params[i]; - start_indices.push_back(slice_param.start_inclusive); - limit_indices.push_back(slice_param.end_exclusive); - strides.push_back(1); - int64_t new_value = slice_param.end_exclusive - slice_param.start_inclusive; - if (new_value <= 0) { - return FailedPrecondition( - "%s", absl::StrCat("SlicedCopyAllocation new dimension size is ", - new_value, ", expected something > 0.")); - } - if (new_shape.dimensions(i) < new_value) { - return FailedPrecondition( - "%s", - absl::StrCat("SlicedCopyAllocation sliced dimension size ", new_value, - " is bigger than its original dimension size of ", - new_shape.dimensions(i), ".")); - } - new_shape.set_dimensions(i, new_value); - } - update_layout_fn(&new_shape); - if (!Shape::Equal().IgnoreMemorySpaceInLayout()( - slice_decision.sizing.slice_shape, new_shape)) { - return FailedPrecondition( - "%s", - absl::StrCat( - "Slice was calculated to have shape ", - slice_decision.sizing.slice_shape.ToString(true), - ", but we are trying to create the slice instruction with shape ", - new_shape.ToString(true), ".")); - } - - HloInstruction* slice = parent.AddInstruction(HloInstruction::CreateSlice( - new_shape, &producer, start_indices, limit_indices, strides)); - TF_ASSIGN_OR_RETURN(copy_done, parent.CreateAsyncInstructions( - slice, {ShapeUtil::MakeShape(S32, {})})); - copy_start = copy_done->mutable_operand(0); - - return OkStatus(); -} - -namespace { - -// Helper function to compute the underlying Allocation chunk for a -// SlicedCopyAllocation. -std::optional GetSlicedCopyAllocationChunk( - const std::vector& - slice_decisions_sorted_by_start_time) { - if (slice_decisions_sorted_by_start_time.empty()) { - return std::nullopt; - } - auto offset_cmp = [](const MemorySpaceAssignment::SliceDecision& lhs, - const MemorySpaceAssignment::SliceDecision& rhs) { - return lhs.chunk.offset < rhs.chunk.offset; - }; - auto end_cmp = [](const MemorySpaceAssignment::SliceDecision& lhs, - const MemorySpaceAssignment::SliceDecision& rhs) { - return lhs.chunk.chunk_end() < rhs.chunk.chunk_end(); - }; - return MemorySpaceAssignment::Chunk::FromOffsetEnd( - std::min_element(slice_decisions_sorted_by_start_time.begin(), - slice_decisions_sorted_by_start_time.end(), offset_cmp) - ->chunk.offset, - std::max_element(slice_decisions_sorted_by_start_time.begin(), - slice_decisions_sorted_by_start_time.end(), end_cmp) - ->chunk.chunk_end()); -} - -// Helper function to compute the start time for a SlicedCopyAllocation. -int64_t GetSlicedCopyAllocationExclusiveStartTime( - const std::vector& - slice_decisions_sorted_by_exclusive_start_time) { - if (slice_decisions_sorted_by_exclusive_start_time.empty()) { - return -1; - } - - return slice_decisions_sorted_by_exclusive_start_time.front() - .exclusive_start_time; -} - -} // namespace - -MemorySpaceAssignment::SlicedCopyAllocation::SlicedCopyAllocation( - const Allocation& prev_allocation, MemorySpace memory_space, - std::vector slice_decisions_sorted_by_exclusive_start_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - absl::FunctionRef update_layout_fn) - : Allocation( - /*defining_position=*/{nullptr, {}}, memory_space, - GetSlicedCopyAllocationChunk( - slice_decisions_sorted_by_exclusive_start_time), - // Allocation uses an inclusive start time - ExclusiveToInclusiveStartTime( - GetSlicedCopyAllocationExclusiveStartTime( - slice_decisions_sorted_by_exclusive_start_time)), - end_time, - /*is_scoped_allocation=*/false), - original_shape_to_slice_(prev_allocation.defining_position().shape()), - prev_allocation_(prev_allocation), - update_layout_fn_(update_layout_fn) { - CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2); - slice_details_sorted_by_start_time_.reserve( - slice_decisions_sorted_by_exclusive_start_time.size()); - for (SliceDecision& decision : - slice_decisions_sorted_by_exclusive_start_time) { - int64_t copy_done_schedule_after_time = decision.exclusive_start_time; - slice_details_sorted_by_start_time_.push_back(SliceDetail{ - std::move(decision), - copy_done_schedule_after_time, - copy_done_schedule_before_time, - /*copy_start=*/nullptr, - /*copy_done=*/nullptr, - }); - } -} - -namespace { - -// Sets defining_position with the copy_complete instruction and replaces all -// uses of the allocation with the copy_complete instruction. -Status ProcessCopyLikeAllocationUses(HloPosition& defining_position, - std::vector& uses, - HloComputation* computation, - HloInstruction* copy_complete) { - // Update the allocation position with the copy complete instruction, so that - // if there are further copies from it, they can find the correct position. - defining_position = HloPosition{copy_complete, {}}; - - // Replace all the uses of the copy-like allocation with the copy complete - // instruction. - for (HloUse use : uses) { - // If the operand is a tuple, we need to descend to the actual instruction - // we want to replace. - HloInstruction* replacement_instruction = copy_complete; - Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); - if (operand_shape.IsTuple()) { - TF_ASSIGN_OR_RETURN( - replacement_instruction, - TupleUtil::ReplaceTupleWith( - copy_complete, - use.instruction->mutable_operand(use.operand_number), - use.operand_index)); - } else if (operand_shape != copy_complete->shape()) { - // When processing allocations, we treat bitcasts as trivial positions and - // do not create allocations for them. We insert bitcasts after copies, to - // account for the fact that we don't have an allocation for the bitcast. - VLOG(4) << "Old shape = " << operand_shape.ToString() - << ", new shape = " << copy_complete->shape().ToString() - << "; inserting a bitcast."; - replacement_instruction = computation->AddInstruction( - HloInstruction::CreateBitcast(operand_shape, copy_complete)); - } - TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( - use.operand_number, replacement_instruction)); - } - - return OkStatus(); -} - -} // namespace - -Status MemorySpaceAssignment::SlicedCopyAllocation::Process() { - Shape shape = defining_position().shape(); - HloInstruction* producing_instruction = AddGetTupleElements(); - - // Calling Process() over the previous allocation might have modified the - // defining position, and hence the shape that was used when we computed - // the slices. In cases where the shape has changed, we insert a bitcast, so - // slice instructions operate on the originally sliced shape. - // - // Note, these bitcasts are being inserted in the same cases that - // ProcessCopyLikeAllocationUses() is inserting bitcasts, except we are - // inserting the bitcasts before the copy, instead of after the copy. - if (!Shape::Equal().IgnoreMemorySpaceInLayout()(shape, - original_shape_to_slice_)) { - int64_t new_memory_space = shape.layout().memory_space(); - shape = original_shape_to_slice_; - shape.mutable_layout()->set_memory_space(new_memory_space); - producing_instruction = producing_instruction->parent()->AddInstruction( - HloInstruction::CreateBitcast(shape, producing_instruction)); - } - - HloComputation* computation = producing_instruction->parent(); - std::vector slice_dones; - slice_dones.reserve(slice_details_sorted_by_start_time_.size()); - - // Sliced copy allocations need to insert asynchronous copy nodes. - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - TF_RETURN_IF_ERROR(slice_detail.CreateAsyncSlice( - shape, *producing_instruction, *computation, update_layout_fn_)); - VLOG(4) << "Created " << slice_detail.copy_start->name() - << " for sliced copy allocation: " << ToString(); - slice_dones.push_back(slice_detail.copy_done); - } - - TF_RETURN_IF_ERROR(CreateBitcastConcat(shape, slice_dones)); - - return ProcessCopyLikeAllocationUses(defining_position_, uses_, computation, - concat_); -} - -void MemorySpaceAssignment::SlicedCopyAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - prev_allocation_.MarkNeeded(needed_allocations); -} - -HloPosition MemorySpaceAssignment::SlicedCopyAllocation::defining_position() - const { - // Unless explicitly set, the defining position of a sliced copy allocation is - // retrieved from the previous allocation. This is because we don't create - // new CopyStart/CopyDone instructions until later and the position should - // point to the previous (copy or otherwise) allocation's position for the - // original defining position. - if (defining_position_.instruction == nullptr) { - return prev_allocation_.defining_position(); - } - return defining_position_; -} - -int64_t MemorySpaceAssignment::SlicedCopyAllocation::earliest_available_time() - const { - return slice_details_sorted_by_start_time().back().copy_done_before_time; -} - -std::vector -MemorySpaceAssignment::SlicedCopyAllocation::SliceOffsetsSortedByStartTime() - const { - std::vector offsets; - offsets.reserve(slice_details_sorted_by_start_time_.size()); - - for (const SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - offsets.push_back(slice_detail.slice_decision.chunk.offset); - } - - return offsets; -} - -void MemorySpaceAssignment::SlicedCopyAllocation::AddDiffToAllSliceOffsets( - int64_t diff) { - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - Chunk& chunk = slice_detail.slice_decision.chunk; - chunk = Chunk::FromOffsetSize(chunk.offset + diff, chunk.size); - } -} - -void MemorySpaceAssignment::SlicedCopyAllocation::ImportRepackedSliceData( - const MemorySpaceAssignmentRepacker::SlicedAllocationData& data) { - int num_slices = slice_details_sorted_by_start_time_.size(); - CHECK_EQ(data.slices_sorted_by_offset.size(), num_slices); - - std::vector slice_details_sorted_by_offset; - slice_details_sorted_by_offset.reserve(num_slices); - for (SliceDetail& slice_detail : slice_details_sorted_by_start_time_) { - slice_details_sorted_by_offset.push_back(&slice_detail); - } - absl::c_sort(slice_details_sorted_by_offset, [](const SliceDetail* lhs, - const SliceDetail* rhs) { - return lhs->slice_decision.chunk.offset < rhs->slice_decision.chunk.offset; - }); - - for (int i = 0; i < num_slices; ++i) { - SliceDetail* slice_detail = slice_details_sorted_by_offset[i]; - Chunk& chunk = slice_detail->slice_decision.chunk; - const MemorySpaceAssignmentRepacker::Slice& repacked_slice_data = - data.slices_sorted_by_offset[i]; - chunk = Chunk::FromOffsetSize(repacked_slice_data.offset, chunk.size); - slice_detail->copy_start_after_time = - repacked_slice_data.inclusive_start_time - 1; - slice_detail->slice_decision.exclusive_start_time = - InclusiveToExclusiveStartTime(repacked_slice_data.inclusive_start_time); - } - - absl::c_sort(slice_details_sorted_by_start_time_, - [](const SliceDetail& lhs, const SliceDetail& rhs) { - return std::make_tuple(lhs.copy_start_after_time, - lhs.slice_decision.chunk.offset) < - std::make_tuple(rhs.copy_start_after_time, - rhs.slice_decision.chunk.offset); - }); -} - -const std::vector& -MemorySpaceAssignment::SlicedCopyAllocation:: - slice_details_sorted_by_start_time() const { - return slice_details_sorted_by_start_time_; -} - -std::vector& -MemorySpaceAssignment::SlicedCopyAllocation:: - mutable_slice_details_sorted_by_start_time() { - return slice_details_sorted_by_start_time_; -} - -std::tuple&, - const HloInstruction*> -MemorySpaceAssignment::SlicedCopyAllocation::ToTuple() const { - return std::make_tuple( - std::ref(*this), std::ref(slice_details_sorted_by_start_time_), concat_); -} - -bool MemorySpaceAssignment::SlicedCopyAllocation::operator==( - const SlicedCopyAllocation& other) const { - return ToTuple() == other.ToTuple(); -} - -std::string MemorySpaceAssignment::SlicedCopyAllocation::ToString() const { - std::string memory_space_str = "def"; - if (memory_space_ == MemorySpace::kAlternate) { - memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")"); - } - return absl::StrCat( - "Sliced Copy Allocation in ", memory_space_str, - ", start_time:", start_time(), ", end_time:", end_time(), - ", first_slice_copy_start_after_time: ", - slice_details_sorted_by_start_time().front().copy_start_after_time, - ", last_slice_copy_done_before_time: ", - slice_details_sorted_by_start_time().back().copy_done_before_time, - ", uses: ", UsesToString(uses()), ", from ", prev_allocation_.ToString()); -} - -Status MemorySpaceAssignment::SlicedCopyAllocation::CreateBitcastConcat( - const Shape& shape, absl::Span slices) { - CHECK(!slices.empty()); - concat_ = - slices.front()->parent()->AddInstruction(HloInstruction::CreateCustomCall( - shape, slices, kConcatBitcastCustomCall)); - return OkStatus(); -} - -std::string MemorySpaceAssignment::MirroredAllocation::ToString() const { - return absl::StrCat("Mirrored Allocation for ", - original_allocation_.ToString()); -} - -std::string MemorySpaceAssignment::ParentAllocation::ToString() const { - return absl::StrCat("Parent Allocation mirrored at ", - defining_position_.ToString(), ", originally ", - original_allocation_.ToString()); -} - -MemorySpaceAssignment::CopyAllocation::CopyAllocation( - Allocation& prev_allocation, MemorySpace memory_space, - std::optional chunk, int64_t copy_start_schedule_after_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index) - : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk, - // Allocation uses an inclusive start time - ExclusiveToInclusiveStartTime(copy_start_schedule_after_time), - end_time, - /*is_scoped_allocation=*/false), - prev_allocation_(prev_allocation), - copy_start_schedule_after_(copy_start_schedule_after_time), - copy_done_schedule_before_(copy_done_schedule_before_time), - cross_program_prefetch_index_(cross_program_prefetch_index) {} - -Status MemorySpaceAssignment::CopyAllocation::Process() { - // Copy allocations need to insert asynchronous copy nodes. - Shape shape = defining_position().shape(); - HloInstruction* producing_instruction = AddGetTupleElements(); - HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( - ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - producing_instruction, cross_program_prefetch_index_)); - copy_done_ = computation->AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); - VLOG(4) << "Created " << copy_start_->name() - << " for copy allocation: " << ToString(); - - return ProcessCopyLikeAllocationUses(defining_position_, uses_, computation, - copy_done_); -} - -Status MemorySpaceAssignment::MirroredAllocation::Process() { - defining_position_ = original_allocation_.defining_position(); - return Allocation::Process(); -} - -Status MemorySpaceAssignment::ParentAllocation::Process() { - // Add an additional parameter to the while HLO with a reference to the buffer - // in the default memory space. - HloInstruction* producing_instruction = - original_allocation_.AddGetTupleElements(); - int new_tuple_index = calling_instruction_->shape().tuple_shapes_size(); - - TF_ASSIGN_OR_RETURN( - HloInstruction * new_while_operand, - TupleUtil::ReplaceTupleWith(producing_instruction, - calling_instruction_->mutable_operand(0), - {new_tuple_index})); - TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape( - 0, new_while_operand)); - *calling_instruction_->mutable_shape() = new_while_operand->shape(); - *calling_instruction_->while_condition() - ->parameter_instruction(0) - ->mutable_shape() = new_while_operand->shape(); - *calling_instruction_->while_body() - ->parameter_instruction(0) - ->mutable_shape() = new_while_operand->shape(); - defining_position_.index = {new_tuple_index}; - // Also replace the while op with a tuple that has the old shape. Note that we - // need to first take a snapshot of the users before calling ExtractPrefix - // since ExtractPrefix introduces additional gte users. - std::vector while_users = calling_instruction_->users(); - HloInstruction* tuple_with_old_shape = - TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index); - TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape( - while_users, tuple_with_old_shape)); - return Allocation::Process(); -} - -Status MemorySpaceAssignment::ParentAllocation::PostProcess() { - // Update the root of the while body with the new parameter. The reason why we - // need a separate post-process for this is because other allocations may have - // while body root as a use, so they would update the old root instead of the - // new root. Doing the post-process step later ensures the root has been - // updated with other changes, and we can safely add the additional parameter. - HloComputation* while_body = calling_instruction_->while_body(); - TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root, - TupleUtil::ReplaceTupleWith( - AddGetTupleElements(), while_body->root_instruction(), - defining_position_.index)); - while_body->set_root_instruction(new_while_body_root, - /*accept_different_shape=*/true); - return OkStatus(); -} - -void MemorySpaceAssignment::Allocation::MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const { - MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::Allocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); -} - -void MemorySpaceAssignment::CopyAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - prev_allocation_.MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const { - // Parent allocations are only needed if they have any uses or if there is a - // copy allocation that copies this value (in that case, the copy allocation - // will call this allocation's MarkNeeded function). - if (!uses_.empty()) { - MarkNeeded(needed_allocations); - } -} - -void MemorySpaceAssignment::ParentAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - original_allocation_.MarkNeeded(needed_allocations); -} - -void MemorySpaceAssignment::MirroredAllocation::MarkNeeded( - absl::flat_hash_set& needed_allocations) const { - needed_allocations.insert(this); - original_allocation_.MarkNeeded(needed_allocations); -} - Status MemorySpaceAssignment::Process(const HloLiveRange& hlo_live_range) { VLOG(1) << "Processing assigned buffers..."; // Since some parent allocations may not be needed (e.g. when they don't have @@ -8283,7 +5576,7 @@ Status MemorySpaceAssignment::ExportAndColorBuffers() { VLOG(3) << "Exported alternate memory allocations:"; for (const auto& position_and_chunk : alternate_memory_assignments_) { const HloPosition& defining_position = position_and_chunk.first; - const Chunk& chunk = position_and_chunk.second; + const HeapSimulator::Chunk& chunk = position_and_chunk.second; const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt( defining_position.instruction, defining_position.index); auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id()); @@ -8303,7 +5596,7 @@ Status MemorySpaceAssignment::ExportAndColorBuffers() { VLOG(3) << "Exported scoped allocations in alternate memory:"; for (const auto& instruction_and_chunk : scoped_memory_assignments_) { HloInstruction* instruction = instruction_and_chunk.first; - const Chunk& chunk = instruction_and_chunk.second; + const HeapSimulator::Chunk& chunk = instruction_and_chunk.second; VLOG(3) << " [" << chunk.offset << ", " << chunk.size << "] : " << instruction->name(); preset_assignments_->add_scoped_allocation_chunk(instruction, chunk); @@ -8516,8 +5809,7 @@ class AsyncCopyStep { class AsyncCopyStepForCopyAllocation : public AsyncCopyStep { public: - explicit AsyncCopyStepForCopyAllocation( - MemorySpaceAssignment::CopyAllocation* copy_allocation) + explicit AsyncCopyStepForCopyAllocation(CopyAllocation* copy_allocation) : AsyncCopyStep(), copy_allocation_(copy_allocation) {} ~AsyncCopyStepForCopyAllocation() override = default; @@ -8543,14 +5835,13 @@ class AsyncCopyStepForCopyAllocation : public AsyncCopyStep { } private: - MemorySpaceAssignment::CopyAllocation* copy_allocation_ = nullptr; + CopyAllocation* copy_allocation_ = nullptr; }; class AsyncCopyStepForSlice : public AsyncCopyStep { public: - AsyncCopyStepForSlice( - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation, - size_t slice_index) + AsyncCopyStepForSlice(SlicedCopyAllocation* sliced_copy_allocation, + size_t slice_index) : AsyncCopyStep(), sliced_copy_allocation_(sliced_copy_allocation), slice_index_(slice_index) {} @@ -8562,10 +5853,9 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } std::optional start_phase() const override { - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_details = - sliced_copy_allocation_ - ->slice_details_sorted_by_start_time()[slice_index_]; + const SlicedCopyAllocation::SliceDetail& slice_details = + sliced_copy_allocation_ + ->slice_details_sorted_by_start_time()[slice_index_]; StartPhase phase{slice_details.copy_start_after_time, slice_details.copy_start}; @@ -8579,10 +5869,9 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } DonePhase done_phase() const override { - const MemorySpaceAssignment::SlicedCopyAllocation::SliceDetail& - slice_details = - sliced_copy_allocation_ - ->slice_details_sorted_by_start_time()[slice_index_]; + const SlicedCopyAllocation::SliceDetail& slice_details = + sliced_copy_allocation_ + ->slice_details_sorted_by_start_time()[slice_index_]; DonePhase phase{slice_details.copy_done_before_time, slice_details.copy_done}; @@ -8590,15 +5879,14 @@ class AsyncCopyStepForSlice : public AsyncCopyStep { } private: - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation_ = - nullptr; + SlicedCopyAllocation* sliced_copy_allocation_ = nullptr; size_t slice_index_; }; class AsyncCopyStepForSliceConcat : public AsyncCopyStep { public: explicit AsyncCopyStepForSliceConcat( - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation) + SlicedCopyAllocation* sliced_copy_allocation) : AsyncCopyStep(), sliced_copy_allocation_(sliced_copy_allocation) {} ~AsyncCopyStepForSliceConcat() override = default; @@ -8619,8 +5907,7 @@ class AsyncCopyStepForSliceConcat : public AsyncCopyStep { } private: - MemorySpaceAssignment::SlicedCopyAllocation* sliced_copy_allocation_ = - nullptr; + SlicedCopyAllocation* sliced_copy_allocation_ = nullptr; }; } // namespace @@ -8723,7 +6010,7 @@ Status MemorySpaceAssignment::FixSchedule() { VLOG(4) << "Scheduling: " << computation->ToString(); - for (int64_t instruction_index = 0;; ++instruction_index) { + for (int64_t instruction_index = -1;; ++instruction_index) { auto insts_before_iter = schedule_before_.find(instruction_index); if (insts_before_iter != schedule_before_.end()) { for (HloInstruction* new_instruction : insts_before_iter->second) { @@ -8735,25 +6022,32 @@ Status MemorySpaceAssignment::FixSchedule() { } } } - // We allow scheduling copy dones past the root instruction (for - // end-of-program cross-program prefetch). So the loop exit condition is - // actually here. - if (instruction_index >= flattened_instructions_.size()) { - break; - } - HloInstruction* instruction = flattened_instructions_[instruction_index]; - // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if - // it was deleted) and not previously inserted. Also bitcasts and tuples - // are treated specially and only inserted as a result of operand - // dependencies. - if (instruction != nullptr && instruction->parent() == computation && - instruction->opcode() != HloOpcode::kBitcast && - instruction->opcode() != HloOpcode::kTuple && - !inserted_instructions.contains(instruction)) { - VLOG(4) << "inst " << instruction_index << ": " << instruction->name(); - TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted( - instruction, &new_sequence, &inserted_instructions)); + + if (instruction_index != -1) { + // We allow scheduling copy dones past the root instruction (for + // end-of-program cross-program prefetch). So the loop exit condition is + // actually here. + if (instruction_index >= flattened_instructions_.size()) { + break; + } + + HloInstruction* instruction = + flattened_instructions_[instruction_index]; + // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if + // it was deleted) and not previously inserted. Also bitcasts and tuples + // are treated specially and only inserted as a result of operand + // dependencies. + if (instruction != nullptr && instruction->parent() == computation && + instruction->opcode() != HloOpcode::kBitcast && + instruction->opcode() != HloOpcode::kTuple && + !inserted_instructions.contains(instruction)) { + VLOG(4) << "inst " << instruction_index << ": " + << instruction->name(); + TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted( + instruction, &new_sequence, &inserted_instructions)); + } } + auto insts_after_iter = schedule_after_.find(instruction_index); if (insts_after_iter != schedule_after_.end()) { for (HloInstruction* new_instruction : insts_after_iter->second) { @@ -8766,6 +6060,7 @@ Status MemorySpaceAssignment::FixSchedule() { } } } + // For rare cases where the original sequence is empty, ensure the root // instruction and its dependencies are scheduled. TF_RETURN_IF_ERROR(EnsureInstructionAndOperandsInserted( @@ -8797,12 +6092,13 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { // are sorted first by time, then within the same time, allocations are sorted // earlier than frees, and finally the value id as a tie breaker. std::map, - std::tuple> + std::tuple> events; auto add_allocation_and_verify = [&](int64_t start_time, int64_t end_time, - const Chunk& chunk, - const HloValue* value) { + const HeapSimulator::Chunk& chunk, + const HloValue* value) -> absl::Status { events[std::make_tuple(start_time, /*is_free=*/false, value->id())] = std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); events[std::make_tuple(end_time, /*is_free=*/true, value->id())] = @@ -8814,10 +6110,10 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { // really should check against end_time (inclusive) for cases where the // operand can't share buffer with user (see // HloDataflowAnalysis::CanShareOperandBufferWithUser). - for (const Chunk& overlapping_chunk : + for (const HeapSimulator::Chunk& overlapping_chunk : interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { if (chunk.OverlapsWith(overlapping_chunk)) { - return InternalError( + return Internal( ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk" " off: %d size: %d"), value->ToShortString(), start_time, end_time, chunk.offset, @@ -8851,7 +6147,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { for (const auto& position_and_chunk : preset_assignments_->chunks()) { const HloPosition& position = position_and_chunk.first; - const Chunk& chunk = position_and_chunk.second; + const HeapSimulator::Chunk& chunk = position_and_chunk.second; const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(position.instruction, position.index); CHECK(!seen_buffers.contains(buffer.id())) @@ -8964,7 +6260,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { int64_t buffer_id; std::tie(time, is_free, buffer_id) = event.first; const HloValue* value; - Chunk chunk; + HeapSimulator::Chunk chunk; HeapSimulatorTrace::Event::Kind kind; std::tie(value, chunk, kind) = event.second; HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events(); @@ -9004,8 +6300,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { DefaultCrossProgramPrefetchBufferIntervalComparator:: DefaultCrossProgramPrefetchBufferIntervalComparator( const HloLiveRange& hlo_live_range) - : MemorySpaceAssignment::BufferIntervalComparator(), - hlo_live_range_(hlo_live_range) {} + : BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {} std::string DefaultCrossProgramPrefetchBufferIntervalComparator:: DescribeComparisonCriteria() const { @@ -9014,19 +6309,19 @@ std::string DefaultCrossProgramPrefetchBufferIntervalComparator:: std::string DefaultCrossProgramPrefetchBufferIntervalComparator::CriteriaToString( - const BufferInterval& buffer_interval) { + const MsaBufferInterval& buffer_interval) { return absl::StrCat("[ ", absl::StrJoin(GetTuple(buffer_interval), ", "), " ]"); } bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan( - const BufferInterval& lhs, const BufferInterval& rhs) { + const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { return GetTuple(lhs) < GetTuple(rhs); } DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( - const BufferInterval& buffer_interval) { + const MsaBufferInterval& buffer_interval) { auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer); if (sort_data_it == additional_sort_data_.end()) { AdditionalSortData sort_data; @@ -9039,10 +6334,9 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( sort_data.cumulative_use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape()); }); - sort_data_it = additional_sort_data_ - .insert(std::make_pair(buffer_interval.buffer, - std::move(sort_data))) - .first; + sort_data_it = + additional_sort_data_.try_emplace(buffer_interval.buffer, sort_data) + .first; } return std::make_tuple( @@ -9052,32 +6346,41 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( MemoryBoundednessBufferIntervalComparator:: MemoryBoundednessBufferIntervalComparator( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache) - : MemorySpaceAssignment::BufferIntervalComparator(), + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache) + : BufferIntervalComparator(), cost_analysis_(cost_analysis), cost_analysis_cache_(cost_analysis_cache) {} +MemoryBoundednessBufferIntervalComparator:: + MemoryBoundednessBufferIntervalComparator( + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache, + MsaSortOrderOverrides msa_sort_order_overrides) + : BufferIntervalComparator(), + cost_analysis_(cost_analysis), + cost_analysis_cache_(cost_analysis_cache), + msa_sort_order_overrides_(msa_sort_order_overrides) {} + std::string MemoryBoundednessBufferIntervalComparator::DescribeComparisonCriteria() const { - return "[ -memory boundedness, -size, -buffer duration, latest use time, " - "(inclusive) start time, instruction id ]"; + return "[override priority, -memory boundedness, -size, -buffer duration, " + "latest use time, (inclusive) start time, instruction id ]"; } std::string MemoryBoundednessBufferIntervalComparator::CriteriaToString( - const BufferInterval& buffer_interval) { + const MsaBufferInterval& buffer_interval) { return absl::StrCat("[ ", absl::StrJoin(GetTuple(buffer_interval), ", "), " ]"); } bool MemoryBoundednessBufferIntervalComparator::LessThan( - const BufferInterval& lhs, const BufferInterval& rhs) { + const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { return GetTuple(lhs) < GetTuple(rhs); } -MemoryBoundednessBufferIntervalComparator::ComparisonTuple -MemoryBoundednessBufferIntervalComparator::GetTuple( - const BufferInterval& buffer_interval) { +int64_t MemoryBoundednessBufferIntervalComparator::GetLatestUseTime( + const MsaBufferInterval& buffer_interval) { auto latest_use_it = buffer_to_latest_use_.find(buffer_interval.buffer); if (latest_use_it == buffer_to_latest_use_.end()) { int64_t latest_use_time = 0; @@ -9093,13 +6396,25 @@ MemoryBoundednessBufferIntervalComparator::GetTuple( .insert(std::make_pair(buffer_interval.buffer, latest_use_time)) .first; } + return latest_use_it->second; +} - return std::make_tuple(-1.0 * cost_analysis_.GetMemoryBoundedness( - buffer_interval, cost_analysis_cache_), - -1 * buffer_interval.size, - buffer_interval.start - buffer_interval.end, - latest_use_it->second, buffer_interval.start, - buffer_interval.buffer->id()); +MemoryBoundednessBufferIntervalComparator::ComparisonTuple +MemoryBoundednessBufferIntervalComparator::GetTuple( + const MsaBufferInterval& buffer_interval) { + int64_t priority = GetBufferIntervalOverridePriority( + msa_sort_order_overrides_, buffer_interval); + float inverse_memory_boundedness = + -1.0 * cost_analysis_.GetMemoryBoundedness(buffer_interval, + cost_analysis_cache_); + int64_t inverse_buffer_size = -1 * buffer_interval.size; + int64_t inverse_buffer_duration = buffer_interval.start - buffer_interval.end; + int64_t latest_use_time = GetLatestUseTime(buffer_interval); + int64_t buffer_start_time = buffer_interval.start; + auto buffer_id = buffer_interval.buffer->id(); + return std::make_tuple(priority, inverse_memory_boundedness, + inverse_buffer_size, inverse_buffer_duration, + latest_use_time, buffer_start_time, buffer_id); } } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/memory_space_assignment.h b/xla/service/memory_space_assignment/memory_space_assignment.h index cd26fdb21ee45..e262349fbb5cb 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/xla/service/memory_space_assignment/memory_space_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,165 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +/* +Quick reference + +This section is meant as to be a quick reference for getting the gist of +commonly used terminology in the code and logging. Please see the code for more +details. + +General concepts + + - Time: In MSA, time typically refers to an index into the flattened + instruction schedule. + + - Cross-program prefetch: Cross-program prefetched tensors are copied from + memory to alternate the first time a program executes, like usual + prefetches. MSA keeps these buffers alive in alternate memory at the end of + the program, such that if the same program is executed again, these tensors + would not need to be prefetched again. + +Classes + + - HloPosition (Hlo dataflow analysis concept): Identifies a tensor referenced + in an instruction's output. Defined by . + + - HloValue (Hlo dataflow analysis concept): The value of a tensor. Each + HloValue is represented by a collection of HloPositions. Exactly 1 of those + positions is the HloValue's defining position, i.e., the point in code where + the value is created/written. The rest of the positions pertain to read-only + uses of the value. + * Example: A tensor that is inserted in a Tuple has 2 HloPositions, one for + the instruction that creates the tensor, and one indexing into the Tuple + instruction result. + * The read-only positions of an HloValue should not be confused with + HloUses. Read-only positions are references to the HloValue in the output + of an instruction. Uses are references to an HloValue in the input of an + instruction. + * Dataflow analysis assigns HloValues for the instructions in computations + pertaining to while loops, conditionals, and call ops. However, it does + not assign HloValues to the computations pertaining to instructions with + "call" semantics (e.g., fusions, reduce, and custom-call) because those + computations are treated as black boxes. + * If a while loop does not modify an input tensor, that tensor will be + assigned 1 HloValue that lasts from its creation point through the while + loop. + * If a while loop modifies one of its input tensors, that tensor will + receive at least the following HloValues: + - An HloValue for the tensor's creation, with a use at the operand of the + while instruction. + - An HloValue with its defining position at the while body's parameter. + - An HloValue whose defining position is an instruction in the while body + that feeds the new tensor value to the body's ROOT instruction. + - An HloValue with its defining position at the while instruction's + result. + + - HloBuffer (Hlo alias analysis concept): A memory container that holds one + or more HloValues that must alias. Typically, each HloValue corresponds to + 1 HloBuffer; however, many exceptions exist. For example, tensors that are + modified by a while loop have their HloValues share an HloBuffer, for the + HloValues that come immediately before, during, and immediately after the + loop. HloBuffers are shared between HloValues wherever their is aliasing, + whether implicit by the nature of the instruction (e.g., + dynamic-update-slice) or explicit (e.g., fusion input-output aliasing). + + - BufferInterval (HeapSimulator concept): A BufferInterval is defined by a + buffer of a given size, with a defined lifetime. In MSA, the buffer + corresponds to an HloValue. + + - AllocationValue: An AllocationValue is defined by an HloValue, and *one* of + its HloPositions. + * We do not create AllocationValues for non-trivial HloPositions, e.g., ones + defined by Tuple, GetTupleElement, and Bitcast instructions. + * The HloPosition used to define the AllocationValue is referred to as the + AllocationValue's defining position. + * Typically, this is also the defining position of the HloValue. However, + it may not be. For example, we would create an AllocationValue with an + HloPosition of a read-only while loop parameter, but the HloValue + corresponding to that HloPosition would have a different defining + position. + * The uses of an AllocationValue are limited to the direct uses of the + AllocationValue's defining position. + * An AllocationValue is associated with an AllocationSequence, describing + what to do with the underlying tensor, in memory, over the lifetime of the + AllocationValue. + + - (Use) Segment: Each AllocationValue and its uses are separated into periods + of time called use segments. The first use segment is from the (inclusive) + time of the AllocationValue's defining position to its first use + (inclusive). The second use segment is from the first use (inclusive) to + the second use (inclusive), etc. + + - AllocationRequest: A request to determine what to do with an + AllocationValue, in memory, during a use segment. It also contains + restrictions and preferences on what to do. + * A request results in updates to the AllocationValue's AllocationSequence. + It may add Allocations, or modify existing Allocations in the sequence. + + - Allocation: A description of what to do with an AllocationValue in memory, + over a period of time. + * Pure virtual base class of all Allocations. + + - AllocationSequence: A sequential list of Allocations, explaining what to do + with an AllocationValue over its lifetime. Allocations in the sequence may + overlap. + + - Pinned Allocation: Represents producing a tensor in a particular memory + space, or keeping a tensor in a memory space in which it already exists. + + - Copy Allocation: Instructions to copy an AllocationValue from one memory + space to another. Used for prefetching (default mem -> alt mem), and + eviction (alt mem -> default mem). + * A copy Allocation contains a copy_done_schedule_before_time. The buffer is + available for use at that schedule time, through the Allocation's + end_time. + + - Sliced Copy Allocation: Similar to a Copy Allocation, except the memory is + copied in slices, in an effort to delay allocating memory in the destination + memory space, for as long as possible. + + - Mirrored Allocation and Parent Allocation: R/W tensors passed to while loops + typically have at least 3 AllocationValues, 1 for the producer of the tensor + before the while loop, 1 for the while loop's body parameter, and 1 for the + result of the while loop. There are situations heading into a while loop, in + which the while loop input is both in alternate memory and default memory. + (For example, this could happen beause we want the buffer in alternate + memory for the while loop and default memory after the while loop, but we + don't have resources to evict the buffer after the while loop.) In those + cases, we use a mirrored allocation for the AllocationValue inside the + while loop, to mirror the allocation in default memory. We use a parent + allocation for the AllocationValue resulting from the while loop result. + +Useful logging and error messages + + - Live range too long: The live range of a use segement is too long to for an + alternate memory no copy, i.e., its longer than we want to keep a buffer in + alternate memory wihtout being used. + * If the CostAnalysisPrefetchIntervalPicker is used, which is the default, + live range too long is governed by the picker's + max_overlap_to_mem_size_async_copy_ratio argument. + + - Live range too short: The live range of a use segement is too short to + prefetch a buffer to alternate memory, according to some heuristic and not + based on limited copy resource. + * If the CostAnalysisPrefetchIntervalPicker is used, which is the default, + live range too long is governed by the picker's + min_overlap_to_async_copy_ratio argument. + + - "Finding allocation for": Magical logging phrase indicating the point in + time where we are are trying to determine how to update an AllocationValue's + AllocationSequenece, for a particular use segment. +*/ + #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ #include #include -#include #include #include #include #include -#include #include #include #include @@ -39,34 +187,30 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" -#include "absl/functional/function_ref.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" -#include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/service/memory_space_assignment/slice.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/util.h" namespace xla { namespace memory_space_assignment { -// Forward Declaration of Options. -struct Options; - -inline constexpr char kConcatBitcastCustomCall[] = "ConcatBitcast"; - // This class contains pre-set assignments determined by memory space // assignment. It contains two data structures: (1) a chunks vector that maps a // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info @@ -137,428 +281,6 @@ class PresetAssignments { std::string instruction_schedule_str_; }; -// A wrapper class around HloCostAnalysis with additional knowledge about the -// bandwidths of different memory spaces. -class MemorySpaceAssignmentCostAnalysis { - public: - // An optional Cache object may be provided to some of the methods below to - // speed up the lookup. - struct Cache { - absl::flat_hash_map while_nest_multiplier; - absl::flat_hash_map memory_boundedness; - }; - - // Function type that can be used to indicate which input/output values are in - // the alternate memory. - using IsInAlternateMemoryFun = absl::FunctionRef /*operand_num*/, const ShapeIndex& /*index*/, - const Shape& /*shape*/)>; - - virtual ~MemorySpaceAssignmentCostAnalysis() = default; - - static StatusOr> Create( - const HloCostAnalysis& cost_analysis, const Options& options, - const HloModule& module); - - const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } - - // Returns a heuristic value that captures how much putting this tensor to the - // alternate memory would help if the op is memory bound, or otherwise how far - // off is the op to memory boundedness. The larger this number, the higher - // priority it will be placed in the alternate memory. - float GetAlternateMemoryBenefit(const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem, - Cache* cache = nullptr) const; - // Like above, return the benefit of putting the output tensor in the - // alternate memory. - float GetAlternateMemoryBenefit(const HloPosition& position, - Cache* cache = nullptr) const; - // Like above, return the benefit of putting the input tensor in the alternate - // memory. - float GetAlternateMemoryBenefit(const HloUse& use, - Cache* cache = nullptr) const; - - // Returns a heuristic value of memory boundedness for the given - // BufferInterval. The larger this number, the higher priority it will be - // placed in the alternate memory. - float GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, - Cache* cache = nullptr) const; - - // If enabled in Options::pipeline_overhead_window_size_mib, returns the - // overhead of accessing the default memory, in seconds. The source of the - // overhead is the software pipelining ovehead. The lowering of the operations - // typically use tiling to copy one window at a time from default memory, and - // perform compute: - // - // Pipeline overhead: <-> - // +----+----+----+----+ - // Copy from default mem: | | | | | - // +----+----+----+----+ - // \ \ \ \ - // \ \ \ \ - // V V V V - // +--+ +--+ +--+ +--+ - // Compute: | | | | | | | | - // +--+ +--+ +--+ +--+ - float GetDefaultMemoryAccessOverhead( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the amount of time the default memory bandwidth is idle, while - // executing this instruction, in seconds. This value can be multiplied with - // the default memory bandwidth to get the amount of bytes that are available - // to be copied to/from default memory during the execution of this - // instruction. - float GetDefaultMemoryBandwidthIdleTime( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the bytes accessed from alternate memory. - float GetBytesAccessedFromAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Returns the elapsed time in seconds due to compute only. - float GetInstructionElapsedDueToCompute( - const HloInstruction& instruction) const; - - // Returns the elapsed time in seconds due to memory only. If - // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will - // assume that the corresponding operands or output will be in the alternate - // memory space. This is useful for calculating the benefit of placing the - // buffer in alternate memory. - float GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem = {}, - absl::Span outputs_in_alternate_mem = {}) const; - - // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in - // the alternate memory. - float GetInstructionElapsedDueToMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const; - - // Returns the estimated elapsed duration of the instruction in seconds. It - // assumes all operands and outputs of the instruction are in the default - // memory. - virtual float GetInstructionElapsed(const HloInstruction& instruction) const; - - // Returns the estimated elapsed duration of the instruction in seconds. It - // assumes all operands and outputs of the instruction are in the default - // memory, except for the operands and outputs specified to be in the - // alternate memory. - virtual float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const; - - // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in - // the alternate memory. - float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - IsInAlternateMemoryFun is_in_alternate_mem) const; - - // Returns the elapsed time it would take to asynchronously copy the shape - // from default to alternate memory space (or vice versa). - virtual float GetAsyncCopyElapsed(const Shape& shape) const; - - int64_t GetScheduleEndTime() const; - - // Returns the number of nested computation levels this instruction resides - // in. If while_only is true, it returns the while loop nest level and 0 - // means the instruction is not in a while loop. - int CalculateComputationNestLevel(const HloInstruction* instruction, - bool while_only) const; - - const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } - const Options& options() const { return options_; } - - protected: - MemorySpaceAssignmentCostAnalysis( - const HloCostAnalysis& cost_analysis, const Options& options, - std::unique_ptr alias_analysis, - std::unique_ptr hlo_live_range, - std::unique_ptr call_graph) - : cost_analysis_(cost_analysis), - options_(options), - alias_analysis_(std::move(alias_analysis)), - hlo_live_range_(std::move(hlo_live_range)), - call_graph_(std::move(call_graph)) {} - - private: - const HloCostAnalysis& cost_analysis_; - const Options& options_; - std::unique_ptr alias_analysis_; - std::unique_ptr hlo_live_range_; - std::unique_ptr call_graph_; -}; - -// Abstract base class that memory space assignment uses to pick prefetch -// intervals. -class PrefetchIntervalPicker { - public: - PrefetchIntervalPicker() = default; - virtual ~PrefetchIntervalPicker() = default; - - // Returns true if the buffer can be allocated in alternate memory space - // without any copies (prefetches). - virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Returns the preferred end time for an eviction that starts at a given time - // and must end by the given end time. - virtual int64_t PreferredEvictionEndTime(const Shape& shape, - int64_t start_time, - int64_t latest_end_time) const = 0; - - // Returns the latest time that a prefetch can start. - virtual int64_t LatestPrefetchStartTime(const Shape& shape, - int64_t start_time, int64_t end_time, - const HloUse* use) const = 0; - - // Returns the preferred time that a prefetch can start. - virtual int64_t PreferredPrefetchStartTime( - const Shape& shape, int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0; - - // Returns the latest time that a prefetch can end that is less than or equal - // to proposed_prefetch_end_time. - virtual int64_t LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const { - return proposed_prefetch_end_time; - } - - // Returns the estimated end time of a prefetch that starts at the given time. - virtual int64_t EstimatedPrefetchEndTime(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Returns the elapsed time in seconds between the logical interval that - // corresponds to the instruction schedule. - virtual float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const = 0; - - // Begins the iterator for the first start time of the prefetch. - virtual void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) = 0; - - // Advances the start time of the prefetch and returns that value. - virtual int64_t Next() = 0; - - // Returns true if the available prefetch intervals have been exhausted. - virtual bool Done() const = 0; - - // Returns the latest time the prefetch interval picker will have pick. - virtual int64_t latest_time() const = 0; - - // The retry number can be used to modify the interval picking policies. The - // first attempt will have a retry_number of 0, then 1, etc. - virtual void SetRetryNumber(int retry_number) { - retry_number_ = retry_number; - } - int retry_number() const { return retry_number_; } - - // Returns a debug string for the current state of the prefetch interval - // picker. - virtual std::string ToDebugString() const = 0; - - // Returns a debug string for no-copy allocation. - virtual std::string ToNoCopyDebugString(const Shape& shape, - int64_t start_time, - int64_t end_time) const = 0; - - // Prefetch interval pickers may return a value corresponding to the benefit - // of placing the BufferInterval in the alternate memory. The larger value, - // the more beneficial. - virtual std::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const { - return std::nullopt; - } - - protected: - const absl::flat_hash_map* - instruction_schedule_ = nullptr; - int retry_number_ = 0; -}; - -// Prefetch interval picker that uses instruction count to overlap asynchronous -// copies with independent computation. The min and max overlap counts describe -// the number of independent HLOs overlapped while a value is being prefetched -// into the alternate memory (between CopyStart and CopyDone HLO instructions). -// max_overlap_count attempts to prevent bringing tensors into the alternate -// memory too eagerly and hence occupying the space for other tensors which -// might use it. min_overlap_count attempts to prevent cases where tensors are -// prefetched into the alternate memory without sufficient time for the copy to -// take place. In those cases, it's just better to keep the tensor in the -// default memory instead of hurting the critical path with this copy that -// likely won't finish in time. -class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { - public: - InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count, - int64_t max_overlap_count) - : min_overlap_count_(min_overlap_count), - max_overlap_count_(max_overlap_count) {} - - bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const override; - - int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, - int64_t latest_end_time) const override; - - int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, - int64_t end_time, - const HloUse* use) const override; - - int64_t PreferredPrefetchStartTime(const Shape& shape, - int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, - int64_t prefetch_end_time) const override; - - int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const override; - - void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) override; - - int64_t Next() override; - bool Done() const override; - - int64_t latest_time() const override; - - std::string ToDebugString() const override; - std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - - private: - int64_t min_overlap_count_; - int64_t max_overlap_count_; - int64_t end_time_; - int64_t current_prefetch_time_; -}; - -// Forward Declaration of MemorySpaceAssignmentCostAnalysis -class MemorySpaceAssignmentCostAnalysis; -// Prefetch interval picker that uses cost analysis to overlap asynchronous -// copies with independent computation. It uses min (independent computation -// duration) / (asynchronous copy duration) ratio to guide whether the prefetch -// is within the lower bound. For the upper bound, it restricts the maximum -// duration that a buffer may occupy the alternate memory space as a multiple of -// the time it would take to copy a buffer that is the size of the alternate -// memory. It starts with the preferred ratio in Begin() and works its way for -// alternately earlier and later prefetches until hitting min and max ratios. -// The value for buffer size for max async copy is a mechanism to prevent -// copying small buffers between the two memories unnecessarily. For calculating -// the max time that the buffer can reside in alternate memory, we use the -// larger of this value and the actual size of the buffer. A shape override can -// also be provided which causes the interval picker to use that shape for async -// copy durations instead of the actual shape of the copy. -class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { - public: - CostAnalysisPrefetchIntervalPicker( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - float min_overlap_to_async_copy_ratio, - float preferred_overlap_to_async_copy_ratio, - float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, - const Shape* shape_override = nullptr); - - bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, - int64_t start_time, - int64_t end_time) const override; - - int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, - int64_t latest_end_time) const override; - - int64_t LatestPrefetchEndTime( - int64_t original_prefetch_end_time, - int64_t proposed_prefetch_end_time) const override; - - int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, - int64_t end_time, - const HloUse* use) const override; - - int64_t PreferredPrefetchStartTime(const Shape& shape, - int64_t earliest_prefetch_start_time, - int64_t latest_prefetch_start_time, - int64_t prefetch_end_time) const override; - - int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - float GetLogicalIntervalElapsed(int64_t start_time, - int64_t end_time) const override; - - void Begin(const HloUse& use, int64_t start_time, int64_t end_time, - std::optional preferred_time) override; - - int64_t Next() override; - bool Done() const override; - - int64_t latest_time() const override; - - void SetRetryNumber(int retry_number) override; - - std::string ToDebugString() const override; - std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, - int64_t end_time) const override; - - std::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) - const override; - - private: - // Finds the minimum nest level in the given interval. - int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const; - - // Given the elapsed time to copy this buffer to the alternate memory, returns - // the longest time that this buffer may reside in the alternate memory space. - float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const; - - // For each instruction in the flattened schedule, maintain their elapsed time - // (in cumulative sum) and while nesting level. - std::vector elapsed_time_cumsum_; - std::vector while_nest_level_; - std::vector computation_nest_level_; - // Maintain the index of the most recent (before this instruction) nest level - // change in order to efficiently determine the minimum nest level in an - // interval. - std::vector while_nest_level_change_; - - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; - float min_overlap_to_async_copy_ratio_; - float preferred_overlap_to_async_copy_ratio_; - float max_async_copy_elapsed_; - float max_overlap_multiplier_ = 1.0; - - float async_copy_elapsed_; - float inst_elapsed_reduction_; - int64_t end_logical_time_; - int64_t earliest_prefetch_time_; - int64_t latest_prefetch_time_; - bool using_increasing_prefetch_time_iterator_ = true; - int64_t increasing_prefetch_time_iterator_; - int64_t decreasing_prefetch_time_iterator_; - - std::vector while_execution_counts_; - // Shape override is used to override the shape of the shape of the async copy - // to treat all async copies the same duration. Having an override forces - // prefetches to be scheduled roughly in FIFO order. - std::optional shape_override_; -}; - // A class for turning a copy start time and end time into slice start times. class SlicedPrefetchStartTimePicker { public: @@ -595,503 +317,6 @@ class SlicedPrefetchStartTimePicker { // memory space. class MemorySpaceAssignment { public: - using Chunk = HeapSimulator::Chunk; - using BufferInterval = - GlobalDecreasingSizeBestFitHeap::BufferInterval; - using BufferIntervalCompare = - GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; - using IsAllowedInAlternateMemoryFunction = - std::function; - using IsUseAllowedInAlternateMemoryFunction = - std::function; - using IsPositionAllowedInAlternateMemoryFunction = - std::function; - using ReservedScopedMemoryFunction = std::function>& /*operands_in_alternate_memory*/, - const absl::flat_hash_set& /*outputs_in_alternate_memory*/)>; - using UpdateLayoutFunction = std::function; - - // The BufferInterval sorting interface that MemorySpaceAssignment expects. - class BufferIntervalComparator { - public: - using BufferInterval = MemorySpaceAssignment::BufferInterval; - - virtual ~BufferIntervalComparator() = default; - - // A logging string explaining the sorting criteria. E.g., [ -size, offset ] - // indicates we sort (desc) size, then (asc) offset. - virtual std::string DescribeComparisonCriteria() const = 0; - - // A logging string containing the values used to sort buffer_interval. - // E.g., we might return [ -1024, 100 ], if the criteria is [ -size, - // offset ]. - virtual std::string CriteriaToString( - const BufferInterval& buffer_interval) = 0; - - // comparator.LessThan(lhs, rhs) will be used for BufferIntervalCompare. - virtual bool LessThan(const BufferInterval& lhs, - const BufferInterval& rhs) = 0; - - // Used to create a functor that can be passed to a method like std::sort. - // E.g., absl::c_sort(v, comparator.GetComparisonFunctor()); - BufferIntervalCompare GetComparisonFunctor() { - return [this](const BufferInterval& lhs, const BufferInterval& rhs) { - return LessThan(lhs, rhs); - }; - } - - protected: - BufferIntervalComparator() = default; - }; - - // MemorySpaceAssignment uses a notion of a slow and large default memory - // space and a fast and small alternate memory space. - enum class MemorySpace { kDefault, kAlternate }; - - // Forward declaration for Allocation. - class Allocation; - class ParentAllocation; - - // This class represents an allocation that might either be in the default or - // alternate memory. An HloValue might live in multiple different allocations - // over its lifetime. The lifetimes of the allocations are defined using - // start_time and end_time, which corresponds to the instruction indexes in - // the flattened schedule. Each of these allocations might partially overlap - // with each other. CopyAllocation defined below represents asynchronous - // copies between Allocations. - // - // Consider an instruction Foo, and its users Bar and Baz, and the times given - // in terms of the flattened schedule of the entire module: - // - // Foo:10 - // / \ - // Bar:14 \ - // Baz:25 - // - // A valid memory space assignment could be like the following: - // - // Time: 10 ... 14 ... 25 - // Foo Bar Baz - // Alternate +-------+ +-----+ - // Default +---------------------+ - // ^ ^ ^ ^ - // | | | | - // evict evict prefetch prefetch - // start end start end - // - // This would be represented with: - // - Allocation(memory_space=kAlternate, start_time=10, end_time=14) - // - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) - // - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) - class Allocation { - friend class ParentAllocation; - - public: - Allocation(HloPosition defining_position, MemorySpace memory_space, - std::optional chunk, int64_t start_time, int64_t end_time, - bool is_scoped_allocation) - : defining_position_(defining_position), - memory_space_(memory_space), - chunk_(chunk), - start_time_(start_time), - end_time_(end_time), - is_scoped_allocation_(is_scoped_allocation) { - CHECK(!is_scoped_allocation || defining_position.index == ShapeIndex({})); - } - virtual ~Allocation() = default; - - // True if the allocation is for a copy or a sliced-copy. - bool is_copy_like_allocation() const; - - virtual bool is_copy_allocation() const { return false; } - virtual bool is_sliced_copy_allocation() const { return false; } - - // Adds a use to this allocation. - void AddUse(HloUse use); - - // Extends the end time of this allocation. - void Extend(int64_t end_time) { end_time_ = std::max(end_time_, end_time); } - - // After all of the time ranges for the allocations have been assigned, - // Process morphs the instructions affected to assign the memory spaces and - // insert asynchronous copy instructions if necessary. - virtual Status Process(); - - // An optional post-process step that will be called after all allocations - // have been processed. - virtual Status PostProcess() { return OkStatus(); } - - // Marks (adds this allocation to needed_allocations) if this allocation is - // needed. Allocation and CopyAllocations are always needed and - // ParentAllocations are needed if they have any uses or if other - // CopyAllocation or ParentAllocations depend on them. - virtual void MarkIfNeeded( - absl::flat_hash_set& needed_allocations) const; - - // Marks this allocation as needed. - virtual void MarkNeeded( - absl::flat_hash_set& needed_allocations) const; - - // Returns the defining position for this allocation. - virtual HloPosition defining_position() const { return defining_position_; } - - // Returns the time the buffer is first available to be used. For - // Allocation, this is start_time. - virtual int64_t earliest_available_time() const { return start_time_; } - - const std::vector& uses() const { return uses_; } - void clear_uses() { uses_.clear(); } - MemorySpace memory_space() const { return memory_space_; } - // Returns the associated chunk that may be a nullopt if the allocation is - // in the default memory space. - std::optional maybe_chunk() const { return chunk_; } - // Returns the associated chunk. The caller should ensure that the chunk is - // defined (the allocation should be in the alternate memory space). - Chunk chunk() const { - CHECK(chunk_.has_value()); - return *chunk_; - } - Chunk* mutable_chunk() { return &*chunk_; } - void set_offset(int64_t offset); - void set_start_time(int64_t start_time) { start_time_ = start_time; } - void set_end_time(int64_t end_time) { end_time_ = end_time; } - int64_t start_time() const { return start_time_; } - int64_t end_time() const { return end_time_; } - bool is_scoped_allocation() const { return is_scoped_allocation_; } - virtual std::optional cross_program_prefetch_index() const { - return std::nullopt; - } - - bool operator==(const Allocation& other) const; - virtual std::string ToString() const; - - bool is_in_alternate_mem() const { - return memory_space_ == MemorySpace::kAlternate; - } - bool is_in_default_mem() const { - return memory_space_ == MemorySpace::kDefault; - } - - protected: - // Recursively create kGetTupleElement instructions if the defining position - // shape is not an array. Returns the new instruction that has array shape. - HloInstruction* AddGetTupleElements() const; - - HloPosition defining_position_; - std::vector uses_; - MemorySpace memory_space_; - std::optional chunk_; - int64_t start_time_; - int64_t end_time_; - const bool is_scoped_allocation_; - }; - - // This class represents an allocation as a result of an asynchronous copy. - // Note: CopyStart instructions are inserted after - // `copy_start_schedule_after`, while CopyDone instructions are inserted - // before `copy_done_schedule_before_time`. - class CopyAllocation : public Allocation { - public: - // TODO(b/307342076): Reorder scheduling times to be - // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time - CopyAllocation( - Allocation& prev_allocation, MemorySpace memory_space, - std::optional chunk, int64_t copy_start_schedule_after_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index = std::nullopt); - - bool is_copy_allocation() const override { return true; } - - Status Process() override; - - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - HloPosition defining_position() const override { - // Unless explicitly set, the defining position of a copy allocation in - // retrieved from the previous allocation. This is because we don't create - // new CopyStart/CopyDone instructions until later and the position should - // point to the previous (copy or otherwise) allocation's position for the - // original defining position. - if (defining_position_.instruction == nullptr) { - return prev_allocation_.defining_position(); - } - return defining_position_; - } - - HloInstruction* copy_start() const { return copy_start_; } - HloInstruction* copy_done() const { return copy_done_; } - - // Returns the time the buffer is first available to be used. For - // CopyAllocation, this is when the copy ends, which is - // copy_done_schedule_before. - int64_t earliest_available_time() const override { - return copy_done_schedule_before_; - } - - int64_t copy_start_schedule_after() const { - return copy_start_schedule_after_; - } - int64_t copy_done_schedule_before() const { - return copy_done_schedule_before_; - } - - void set_copy_start_schedule_after(int64_t copy_start_schedule_after) { - copy_start_schedule_after_ = copy_start_schedule_after; - } - - void set_copy_done_schedule_before(int64_t copy_done_schedule_before) { - copy_done_schedule_before_ = copy_done_schedule_before; - } - - std::optional cross_program_prefetch_index() const override { - return cross_program_prefetch_index_; - } - - bool operator==(const CopyAllocation& other) const; - std::string ToString() const override; - - const Allocation& prev_allocation() { return prev_allocation_; } - Allocation& mutable_prev_allocation() { return prev_allocation_; } - - private: - Allocation& prev_allocation_; - // These variables define the scheduling boundaries where CopyStart and - // CopyDone can be scheduled. The earliest CopyStart can be scheduled is - // after copy_start_schedule_after_ and the latest CopyDone can be scheduled - // is before copy_done_schedule_before_. - int64_t copy_start_schedule_after_; - int64_t copy_done_schedule_before_; - HloInstruction* copy_start_; - HloInstruction* copy_done_; - std::optional cross_program_prefetch_index_; - }; - - // The parameters for slicing a single dimension of a tensor. - struct SliceParam { - std::string ToString() const; - bool operator==(const SliceParam& other) const; - - int64_t start_inclusive; - int64_t end_exclusive; - }; - - // A proposed way to slice a buffer. - struct SliceProposal { - std::string ToString() const; - friend std::ostream& operator<<(std::ostream& os, - const SliceProposal& proposal); - std::tuple&, int64_t> - ToTuple() const; - bool operator==(const SliceProposal& other) const; - - // Shape resulting from the slice. - Shape slice_shape; - - // slice_params map to the parameters that would be passed to a slice - // instruction. Thus: - // * There should be a slice parameter for every dimension in the shape of - // the tensor being sliced. - // * The ith slice_param applies to the ith logical dimension in the shape - // being sliced. - // * If a dimension is not being sliced, it should have a SliceParam of - // {0, dim size}. - std::vector slice_params; - - // The size to be allocated for the slice. Note, this may be > the size of - // the slice shape, due to additional padding that may occur when the slices - // are concatenated back together. - int64_t slice_size; - }; - - // A SliceProposalCollection proposes a way to to slice an AllocationRequest. - // A SliceProposalCollection is generated from a SliceProposalFunction and is - // used when we want to slice a prefetch. - using SliceProposalCollection = std::vector; - using SliceProposalFunction = std::function( - const Shape& shape, const SlicedPrefetchOptions& options)>; - - // A SliceDecision is a SliceProposal that we've determined where and when to - // allocate. - struct SliceDecision { - std::string ToString() const; - bool operator==(const SliceDecision& other) const; - - Chunk chunk; - int64_t exclusive_start_time; - SliceProposal sizing; - float copy_resource_consumed; - }; - - // This class represents an allocation resulting from asynchronous sliced - // copies. - // - // Let the sliced allocation be represented as follows, and imagine that t3 - // is the time when the entire buffer [p0, p3) is available for use - // - // space - // ^ - // p3 | +-----------+ - // | | | - // p2 | +---+ | - // | | | - // p1 | +-------+ | - // | | | - // p0 | +-------+ - // +---|---|---|---|---|----> time - // t0 t1 t2 t3 t4 - // - // The Allocation underlying the SlicedCopyAllocation will use the following - // dimensions: - // - chunk = [p0, p3) - // - start time = t2 - // - earliest_available_time = t3 - // - end_time = t4 - class SlicedCopyAllocation : public Allocation { - public: - // Full details about a slice in the sliced allocation. - struct SliceDetail { - std::string ToString() const; - std::tuple - ToTuple() const; - bool operator==(const SliceDetail& other) const; - - // Create the instructions to copy the slice. This method updates - // copy_start and copy_done. Given a Shape, the hardware may have - // constraints on how the shape is physically laid out in memory. - // update_layout_fn updates a Shape's layout in accordance with those - // constraints. - Status CreateAsyncSlice(const Shape& original_shape, - HloInstruction& producer, HloComputation& parent, - absl::FunctionRef update_layout_fn); - - SliceDecision slice_decision; - int64_t copy_start_after_time = -1; - int64_t copy_done_before_time = -1; - HloInstruction* copy_start = nullptr; - HloInstruction* copy_done = nullptr; - }; - - // REQUIRES: - // - slice_decisions_sorted_by_start_time.size() >= 2, otherwise, - // CopyAllocation should be used. - SlicedCopyAllocation( - const Allocation& prev_allocation, MemorySpace memory_space, - std::vector slice_decisions_sorted_by_start_time, - int64_t copy_done_schedule_before_time, int64_t end_time, - absl::FunctionRef update_layout_fn); - - bool is_sliced_copy_allocation() const override { return true; } - - // MemorySpaceAssignment::Process() calls Process() to create asynchronous - // slice copies, and a bitcast-concat call to glue the slices back together. - Status Process() override; - - // Marks the allocation as needed. - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - // Returns the defining position for this allocation. - HloPosition defining_position() const override; - - // Returns the time the buffer is first available to be used. For - // SlicedCopyAllocation, this is when all copies have ended. - int64_t earliest_available_time() const override; - - std::vector SliceOffsetsSortedByStartTime() const; - void AddDiffToAllSliceOffsets(int64_t diff); - - // Used to update offsets and start times after repacking. - void ImportRepackedSliceData( - const MemorySpaceAssignmentRepacker::SlicedAllocationData& data); - - const std::vector& slice_details_sorted_by_start_time() const; - std::vector& mutable_slice_details_sorted_by_start_time(); - HloInstruction* concat() const { return concat_; } - - std::tuple&, - const HloInstruction*> - ToTuple() const; - bool operator==(const SlicedCopyAllocation& other) const; - std::string ToString() const override; - - private: - SlicedCopyAllocation() = delete; - - // Create an instruction to concatenate the slices. Populates concat_. - Status CreateBitcastConcat(const Shape& shape, - absl::Span slices); - - Shape original_shape_to_slice_; - const Allocation& prev_allocation_; - // REQUIRES: - // - sorted_segments_[i].copy_start_after_time <= - // sorted_segments_[i+j].copy.start_after_time - // - sorted_segments_[i].copy_done_before_time <= - // sorted_segments_[i+j].copy.start_before_time - std::vector slice_details_sorted_by_start_time_; - HloInstruction* concat_ = nullptr; - absl::FunctionRef update_layout_fn_; - }; - - // An allocation in the default memory space that mirrors another Allocation - // object. This is useful to model an eviction that happens before a while op - // so that we don't need to redundantly evict the buffer after the while op as - // well. - class MirroredAllocation : public Allocation { - public: - MirroredAllocation(const Allocation& original_allocation, int64_t time) - : Allocation(original_allocation.defining_position(), - MemorySpace::kDefault, original_allocation.maybe_chunk(), - /*start_time=*/time, - /*end_time=*/time, /*is_scoped_allocation=*/false), - original_allocation_(original_allocation) {} - - Status Process() override; - - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - std::string ToString() const override; - - private: - const Allocation& original_allocation_; - }; - - // An allocation in default memory space that is defined in the parent - // computation. If a value has a copy in the default memory space in the - // parent computation, we don't need to evict this buffer in a while loop. - class ParentAllocation : public Allocation { - public: - ParentAllocation(const Allocation& original_allocation, - HloInstruction* calling_instruction, HloPosition position, - int64_t time) - : Allocation(position, MemorySpace::kDefault, - original_allocation.maybe_chunk(), /*start_time=*/time, - /*end_time=*/time, /*is_scoped_allocation=*/false), - original_allocation_(original_allocation), - calling_instruction_(calling_instruction) {} - - Status Process() override; - Status PostProcess() override; - - void MarkIfNeeded(absl::flat_hash_set& - needed_allocations) const override; - void MarkNeeded(absl::flat_hash_set& needed_allocations) - const override; - - std::string ToString() const override; - - private: - const Allocation& original_allocation_; - HloInstruction* calling_instruction_; - }; - - using AllocationSequence = std::vector>; // AllocationValue is used to break up HloValues for each non-trivial position // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An // HloValue may include positions and uses that alias with each other across @@ -1252,12 +477,12 @@ class MemorySpaceAssignment { virtual ~MemorySpaceAssignment() = default; // Runs the MemorySpaceAssignment pass. - static StatusOr> Run( + static absl::StatusOr> Run( HloModule* module, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const Options& options); // Calculates asynchronous copy statistics. - StatusOr CalculateAsyncCopyStats() const; + absl::StatusOr CalculateAsyncCopyStats() const; // Verify that the memory space assignment is free of overlapping buffers and // export heap simulator trace to be used by buffer_assignment. @@ -1265,9 +490,9 @@ class MemorySpaceAssignment { protected: // Main driver of the memory space assignment pass. - virtual StatusOr> RunMemorySpaceAssignment( - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis); + virtual absl::StatusOr> + RunMemorySpaceAssignment(const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); // Finds an AllocationSequence for placing buffers in alternate memory using // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is @@ -1336,8 +561,10 @@ class MemorySpaceAssignment { std::vector flattened_instructions_; absl::flat_hash_set computations_in_schedule_; std::unique_ptr preset_assignments_; - std::vector> alternate_memory_assignments_; - std::vector> scoped_memory_assignments_; + std::vector> + alternate_memory_assignments_; + std::vector> + scoped_memory_assignments_; int64_t alternate_memory_size_ = 0; // These maps hold vectors of new instructions that need to be scheduled after @@ -1352,36 +579,47 @@ class MemorySpaceAssignment { // // This comparator caches HloValues -> latest use time. class MemoryBoundednessBufferIntervalComparator - : public MemorySpaceAssignment::BufferIntervalComparator { + : public BufferIntervalComparator { public: MemoryBoundednessBufferIntervalComparator( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache); + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache); + + MemoryBoundednessBufferIntervalComparator( + const CostAnalysis& cost_analysis, + CostAnalysis::Cache* cost_analysis_cache, + MsaSortOrderOverrides msa_sort_order_overrides); ~MemoryBoundednessBufferIntervalComparator() override = default; std::string DescribeComparisonCriteria() const override; - std::string CriteriaToString(const BufferInterval& buffer_interval) override; - bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override; + std::string CriteriaToString( + const MsaBufferInterval& buffer_interval) override; + bool LessThan(const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) override; private: // See the value returned by DescribeComparisonCriteria() for the meaning of // each tuple element. - using ComparisonTuple = - std::tuple; - - ComparisonTuple GetTuple(const BufferInterval& buffer_interval); + using ComparisonTuple = std::tuple; + ComparisonTuple GetTuple(const MsaBufferInterval& buffer_interval); + int64_t GetLatestUseTime(const MsaBufferInterval& buffer_interval); absl::flat_hash_map buffer_to_latest_use_; - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; - MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache_; + const CostAnalysis& cost_analysis_; + CostAnalysis::Cache* cost_analysis_cache_; + + // Config to override alternate memory assignment sorting order for filtered + // buffers. + MsaSortOrderOverrides msa_sort_order_overrides_; }; // The default BufferIntervalComparator used for cross-program prefetching. // // This class caches HloValue -> {latest use, cumulative use size }. class DefaultCrossProgramPrefetchBufferIntervalComparator - : public MemorySpaceAssignment::BufferIntervalComparator { + : public BufferIntervalComparator { public: explicit DefaultCrossProgramPrefetchBufferIntervalComparator( const HloLiveRange& hlo_live_range); @@ -1389,8 +627,10 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator ~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default; std::string DescribeComparisonCriteria() const override; - std::string CriteriaToString(const BufferInterval& buffer_interval) override; - bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override; + std::string CriteriaToString( + const MsaBufferInterval& buffer_interval) override; + bool LessThan(const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) override; private: // See the value returned by DescribeComparisonCriteria() for the meaning of @@ -1403,290 +643,13 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator int64_t cumulative_use_size = 0; }; - ComparisonTuple GetTuple(const BufferInterval& buffer_interval); + ComparisonTuple GetTuple(const MsaBufferInterval& buffer_interval); absl::flat_hash_map additional_sort_data_; const HloLiveRange& hlo_live_range_; }; -// Filters prefetches by matching against multiple filters and overrides the -// preferred prefetch time for matching prefetches by the provided override -// strategy. -class FilterUpdatePreferredPrefetch { - public: - // Supported filters for prefetch filtering by operand size, instruction name, - // operand number and operand index matching. - enum class FilterType { - OP_SIZE_LTE, // sting value: op_size_lte, filter value type: integer - OP_SIZE_GTE, // sting value: op_size_gte, filter value type: integer - INSTRUCTION_NAME_EXACT, // sting value: instruction_name_exact, - // filter value type: string - OP_NUMBER_EXACT, // sting value: op_number_exact, - // filter value type: integer - OP_INDEX_EXACT // sting value: op_index_exact, filter value type: string - // (empty string for {}, 1 for {1} and 1#2 for {1,2}) - }; - // Strategies to compute new perferred prefetch time. Prefetch eagerness - // sets prefetch time to a time within the live-range depending on a value, - // e.g. 0.5 sets it exactly in the middle of the live-range. Put after - // instruction or put before instruction finds an instruction in the schedule - // and puts the preferred prefetch time before or after the found instruction. - enum class OverrideType { - PREFETCH_EAGERNESS, // sting value: prefetch_eagerness, - // override value type : float - PUT_AFTER_INSTRUCTION, // sting value: put_after_instruction, - // override value type: string - PUT_BEFORE_INSTRUCTION // sting value: put_before_instruction, - // override value type: string - }; - std::vector> filter_list_; - OverrideType override_type_; - std::string override_value_; - - std::string ToString() const { return config_string_; } - - static StatusOr> - ParseFilterUpdatePreferredPrefetches(std::string config); - - static StatusOr IsOpSizeGte(int64_t operand_size, std::string config); - - static StatusOr IsOpSizeLte(int64_t operand_size, std::string config); - - static StatusOr IsInstructionNameExact( - absl::string_view instruction_name, std::string config); - - static StatusOr IsOpNumberExact(int64_t operand_number, - std::string config); - - static StatusOr IsOpIndexExact(const ShapeIndex& operand_index, - std::string config); - - StatusOr> GetPrefetchByEagerness( - int64_t earliest_prefetch_time, int64_t latest_prefetch_time) const; - - StatusOr> GetPrefetchTimeAfterInstruction( - const absl::flat_hash_map& schedule) - const; - - StatusOr> GetPrefetchTimeBeforeInstruction( - const absl::flat_hash_map& schedule) - const; - - private: - std::string config_string_; - StatusOr GetScheduleTimeFromInstructionName( - const absl::flat_hash_map& schedule) - const; - - static StatusOr ParseFilterType(std::string config); - - static StatusOr ParseOverrideType(std::string config); - - static StatusOr ParseOperandIndex(std::string config); - - static StatusOr - ParseFilterUpdatePreferredPrefetch(std::string config); -}; - -// The different options to be passed to the Run() API. -struct Options { - // Backend-specific integer value that describes the alternate memory. - int64_t alternate_memory_space = 0; - - // Maximum size of the alternate memory space. - int64_t max_size_in_bytes = 0; - - // Memory alignment of the alternate memory space. - int64_t alignment_in_bytes = 1; - - // If provided, we sort the buffers using this comparator. Otherwise, we use - // GlobalDecreasingSizeBestFitHeap::kSpatial. - MemorySpaceAssignment::BufferIntervalComparator* buffer_interval_comparator = - nullptr; - - // This object determines how early and how late prefetches can occur. - PrefetchIntervalPicker* prefetch_interval_picker = nullptr; - - // This object is used to determine the benefit of a particular allocation. - MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr; - - // Size function for buffer values. - BufferValue::SizeFunction size_fn; - - // This function can be used to prevent certain HloValues (e.g., based on - // the opcode) to be placed on the alternate memory. - MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction - is_allowed_in_alternate_mem_fn; - - // This function can be used to prevent certain HloUses (e.g., based on - // the opcode) to be placed on the alternate memory. - MemorySpaceAssignment::IsUseAllowedInAlternateMemoryFunction - is_use_allowed_in_alternate_mem_fn = [](const HloUse&) { return true; }; - - // Specifies if the given position is allowed in the alternate memory. - MemorySpaceAssignment::IsPositionAllowedInAlternateMemoryFunction - is_position_allowed_in_alternate_mem_fn = - [](const HloPosition&) { return true; }; - - // This function returns the amount of scoped memory in bytes that should be - // reserved during the execution of this instruction. - MemorySpaceAssignment::ReservedScopedMemoryFunction - reserved_scoped_memory_fn = - [](const HloInstruction*, - const absl::flat_hash_set< - std::pair>& /*operands_in_alternate_memory*/, - const absl::flat_hash_set< - ShapeIndex>& /*outputs_in_alternate_memory*/) { return 0; }; - - // If true, we will try to reduce scoped allocation buffer size for all - // instructions if their operand/output has been allocated in alternate - // memory. - bool reduce_scoped_memory_limit = false; - - // If true, we allocate the reserved scoped memory at the same offset. This - // is useful to enable more deduplication between HLOs that have reserved - // scoped memories, but may result in less efficient memory packing. - bool allocate_reserved_scoped_memory_at_same_offset = true; - - // Specifies the upper bound for number of outstanding prefetches and - // evictions, -1 for unlimited. - int64_t max_outstanding_prefetches = -1; - int64_t max_outstanding_evictions = -1; - - // Extra outstanding prefetch limit for while uses (in addition to - // max_outstanding_prefetches). - int64_t while_use_extra_outstanding_prefetch_limit = 0; - - // Specifies the maximum number of retries that will be performed for each - // value in case prefetching failed due to running out of asynchronous - // copies or asynchronous copy resource. - int64_t max_retries = 1; - - // The maximum number of repacks that we are willing to perform in case we - // can't allocate a buffer due to running out of memory. If this value is - // greater than 0, repacker must be non-nullptr. - int64_t max_repacks = 0; - - // This variable is used by the cost analysis in estimating how many times - // each while loop will execute. Nested loops will be assumed to have - // executed pow(while_execution_count, nesting_level) times. - uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL; - - // This variable is used to scale the alternate memory benefit factor for - // large buffers. The default scaling function is sqrt. - std::string - xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers = - "SQRT"; - - float async_copy_bandwidth_bytes_per_second = 0.0f; - - float alternate_mem_bandwidth_bytes_per_second = 0.0f; - - // The repacking algorithm to reduce fragmentation. Must be non-null if - // max_repacks is greater than 0. - MemorySpaceAssignmentRepacker* repacker = nullptr; - - // This is only useful for testing, repack after every allocation. - bool repack_after_every_allocation = false; - - // If true, tries allocating buffers across (e.g., before and inside a while - // loop body) sequential calls (kWhile, kCall, and kConditional). - bool allocate_across_sequential_calls = false; - - // If true, verifies the memory space assignment against overlapping - // buffers. - bool verify = false; - - // If not nullptr, this function is called to dump debugging information. - // The first argument is appended to the file name and the second argument - // is the contents of the file. - std::function dump_fn = nullptr; - - // Enable prefetching buffers into preferred memory across program - // boundaries - bool enable_cross_program_prefetch = true; - - // If true, use buffer_interval_compare to determine which buffers to - // prefetch across program boundaries. - bool default_cross_program_prefetch_heuristic = false; - - // Enable cross-program prefetch freeing optimization where the - // cross-program-prefetched buffer can be reused. - bool enable_cross_program_prefetch_freeing = true; - - // The maximum number of cross program prefetches. - // TODO(tjablin): Use a heuristic to determine this automatically. - int max_cross_program_prefetches = 1; - - // Enable redundant eviction optimization in/around while loops. If enabled, - // this optimization would keep a copy of the buffer in the default memory in - // addition to alternate memory to eliminate redundant evictions. - bool enable_while_redundant_eviction_elimination = true; - - // An optional memory space assignment autotuning config, which is used - // to sort allocated buffers. - std::optional> autotuning_config = std::nullopt; - - // Scales effective bandwidth for async copies. Valid range is (0, 1]. - float async_copy_bandwidth_scaling_factor = 1.0; - - // If true, uses the earlier instance of the same instruction to use as - // preferred prefetch start time. - bool use_repeated_instance_for_preferred_prefetch_time = false; - - // If true, enforces the FIFO order for prefetches. - bool enforce_prefetch_fifo_order = false; - - // The ratio of use bytes to copy bytes for a given allocation site below - // which we consider the site to be inefficient. A value of 0 would treat all - // sites as efficient and a value of 1 would require the amount of bytes used - // at the site to be at least as much as the async copy bytes. There are two - // factors that determine the copy and use bytes: - // - Some uses don't actually access the entire tensor, e.g. in - // dynamic-update-slice. - // - copy_bytes may be larger than the size of the tensor as well. An - // example is a tensor may be prefetched, used, and then evicted. In that - // case copy_bytes would be twice the size of the tensor. - float inefficient_use_to_copy_ratio = 0.0; - - // This is mostly used for testing, it allows a test case to inject its own - // logic for AlternateMemoryBestFitHeap::GetInefficientAllocationSites. - std::function>( - absl::Span)> - get_inefficient_allocation_sites_fn = nullptr; - - // The window size used to calculate the pipeline overhead when HLO accesses - // the default memory, in MiB. - float pipeline_overhead_window_size_mib = 0; - - // Config to filter prefetches and update preferred prefetch times for the - // filtered prefetches according to an update config. - std::vector filter_update_preferred_prefetches; - - // Options for slicing prefetches into smaller asynchronously copied pieces. - SlicedPrefetchOptions sliced_prefetch_options; - - // Options for the memory-bound loop optimizer feature. - MemoryBoundLoopOptimizerOptions memory_bound_loop_optimizer_options; - - // A function for updating shape layouts. - MemorySpaceAssignment::UpdateLayoutFunction update_layout_fn = [](Shape*) {}; - - MemorySpaceAssignment::SliceProposalFunction propose_slice_fn = - [](const Shape&, const SlicedPrefetchOptions&) - -> xla::StatusOr { - return UnimplementedStrCat("Generation of SliceProposals unimplemented"); - }; - - // Option to always spill buffers from alternate memory to default memory - // and prefetching back to alternate memory(if needed) just in time for use. - bool always_spill_to_default_memory = false; -}; - // A struct representing an asynchronous copy with its logical start and end // time (time that copy done is scheduled), the resource this copy would use, // its destination memory space, and a unique ID. @@ -1694,12 +657,10 @@ struct AsynchronousCopy { int64_t exclusive_start_time; int64_t end_time; float resource; - MemorySpaceAssignment::MemorySpace destination; + MemorySpace destination; int64_t id; - std::tuple - AsTuple() const { + std::tuple AsTuple() const { return std::make_tuple(exclusive_start_time, end_time, resource, destination, id); } @@ -1814,9 +775,8 @@ class AsynchronousCopyResource { // A useful debugging tool for printing several pieces of information about // AsynchronousCopyResource. - std::string Dump( - int64_t start_time, int64_t end_time, - MemorySpaceAssignment::MemorySpace memory_space_filter) const; + std::string Dump(int64_t start_time, int64_t end_time, + MemorySpace memory_space_filter) const; private: // Internal helper method to implement adding/removing/checking resources. @@ -1853,265 +813,18 @@ class AsynchronousCopyResource { std::vector delay_; }; -// TODO(b/280618622): Refactor this class out of this file. -// -// An optimizer for unrolled memory-bound loops. It keeps track of alternate -// memory capacity and default memory bandwidth to decide the allocations of -// each tensor within a loop iteration. The assumption is that all of the -// unrolled loop iterations will use the same allocation decisions, so we can -// spend more time to optimize this one iteration as optimally as possible. -// -// To represent instructions, we keep track of three iterations (previous, -// current, and next), as well as the header and footer regions that are before -// and after the loop, respectively. -// -// We classify each tensor used in the current iteration as one of the following -// allocations based on its positions and uses: -// -// Temporary Allocations: These are produced by a producer in the current -// iteration and consumed either in this or the next iteration. For these, we -// try to give them alternate memory allocations for their entire live range. -// -// Case 1: producer and consumer all in the current iteration. -// p-----c--c -// Case 2: producer is in the current iter, consumer is in the next iter. -// p-----c -// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| -// iter: head |...| prev | current | next |...| foot -// -// Loop Carried Dependences: This is where the last use is at a larger index -// than the producer. This would require 2X peak buffer consumption because both -// this and next iteration's buffer is alive at the same time. This case is -// currently not supported. -// -// Case 3: producer is in the current iter, consumer is in the next iter -// (consumer idx >= producer idx). -// p-----------------c -// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| -// iter: head |...| prev | current | next |...| foot -// -// Pinned Allocations: These are values produced at the header and are used in -// every iteration at the same indices. For these, we just allocate the buffer -// for the duration of the loop: -// -// Case 4: producer: kHead, consumer: kCurrent -// p---------------c--------------c--------------c-------- -// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| -// iter: head |...| prev | current | next |...| foot -// -// Prefetch Allocations: These are values produced at the header and are used in -// the current (and possibly next) iteration. We will try to prefetch these -// values into the alternate memory: -// -// Case 5: producer: kHead, consumer: kCurrent -// p---------------------------------c--------c -// idx: |...| 0 1 2 3 4| 0 1 2 3 4| 0 1 2 3 4|...| -// iter: head |...| prev | current | next |...| foot -class MemoryBoundLoopOptimizer { - public: - // We represent each tensor used in the current iteration as a LoopValue, - // wrapping the relevant information such as its HLO value, indices and - // pointers to its use and position sites in different iterations. - struct LoopValue { - // An enum that encodes the allocation type that is suitable for this - // LoopValue. See the comment above on what each of these mean. - enum class AllocationType { - kTemporary, - kLoopCarriedDependence, - kPinned, - kPrefetch, - kUnsupported - }; - - // ToString methods for logging/debugging. - static std::string AllocationTypeToString(AllocationType allocation_type); - std::string ToString() const; - - // Returns true if memory-bound loop optimizer supports allocating this type - // of a loop value. - bool IsAllocationTypeSupported() const; - - // The HloValues that correspond to this LoopValue. - std::vector hlo_values; - // The position in the header, if any. - std::optional header_position; - // The loop index and position in the previous and current iterations. - std::vector> prev_iteration_positions; - std::vector> loop_positions; - // The loop index and use in the current and next iterations. - std::vector> loop_uses; - std::vector> next_iteration_uses; - // The allocation type. - AllocationType allocation_type; - // Size of this tensor. - int64_t size; - // The default memory bandwidth savings were we to successfully put this in - // the alternate memory using the allocation type, in bytes. - float savings; - // The savings divided by the size. This is typically 2 for temporary - // allocations (skip a write and a read to the default memory). More complex - // production/consumption patterns may result in higher or lower values. We - // use this value to sort LoopValues so that the algorithm can prioritize - // allocating the buffers with the highest savings per byte to the alternate - // memory. - float savings_per_byte; - // The optimized AllocationSequence. - MemorySpaceAssignment::AllocationSequence allocations; - }; - - // Factory method to create and initialize a MemoryBoundLoopOptimizer. - static StatusOr> Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis_, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function); - - // Optimize the loop. Initialize must be called first. - void Optimize(); - - // Calculate the steady-state execution time of one loop iteration using the - // allocation decisions so far. - float CalculateExecutionTime() const; - - // Return the LoopValues. - const std::vector& loop_values() const { return loop_values_; } - std::vector& loop_values() { return loop_values_; } - - // Return the remaining memory vector for each point in time in the loop using - // the allocation decisions so far. - const std::vector& remaining_memory() const { - return remaining_memory_; - } - - // The loop start, end, and size accessors. - int loop_start() const { return loop_start_; } - int loop_end() const { return loop_end_; } - int loop_size() const { return loop_size_; } - - private: - // Temporary data structures used by the AllocatePrefetch function. - struct AllocatePrefetchesContext { - // The values that are requested to be prefetched. - absl::Span values; - - // A list of indices into values array, sorted by the start time of the - // first use. - std::vector value_indices; - - // Default memory remaining bandwidths assuming all prefetches succeeded. - std::vector bandwidth_idle_times; - - // Additional memory used while performing prefetching. - std::vector additional_memory_used; - }; - - MemoryBoundLoopOptimizer( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis_, - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function); - - // Initializes the data structures used by the optimizer. - Status Initialize(); - - // Given an HloBuffer object, determines if this buffer represents a LoopValue - // that can be optimized by the optimizer, and if so it adds a LoopValue to - // the back of loop_values_ that represents the HloBuffer. Otherwise, no new - // LoopValue is added to loop_values_. - void MaybeCreateLoopValue(const HloBuffer& buffer, - const HloComputation* loop_computation); - - // Sort LoopValues by savings_per_byte. - void SortLoopValues(); - - // After allocation finishes, we fix up by creating Allocation objects to any - // LoopValues that didn't get alternate memory allocations. - void PostProcess(); - - // Allocate LoopValues by dispatching to the correct Allocate method. - void AllocateLoopValues(); - - // Allocate and reserve memory between the given indices. - bool AllocateBetween(int64_t begin_idx, int64_t end_idx, int64_t size); - - // Perform allocation type kTemporary. Return true if successful. - bool AllocateTemporary(LoopValue& value); - - // Perform allocation type kPinned. Return true if successful. - bool AllocatePinned(LoopValue& value); - - // Perform allocation type kPrefetch. Unlike the other Allocate methods, this - // performs allocation of multiple LoopValues in order to consider the effect - // of remaining bandwidth assuming the other prefetches were successful. - // Return true if successful. - bool AllocatePrefetches(absl::Span values); - - // Allocate one prefetch for the loop value index that corresponds to - // context.context.values. Returns true if successful. - bool AllocatePrefetch(int value_index, AllocatePrefetchesContext& context); - - // Keeps track of successful allocation of all uses and positions of this - // LoopValue. - void AddAllLoopPositionsAndUses(LoopValue& value, - bool allocate_next_iteration_uses); - - // Returns the default memory bandwidth idle time at the index. - float GetBandwidthIdleTime(int idx) const; - - // Returns the default memory bandwidth idle time at the index assuming the - // given uses and positions got alternate memory allocations. - float GetBandwidthIdleTime( - int idx, - const absl::flat_hash_map>>& - additional_uses_in_alternate_mem, - const absl::flat_hash_map>& - additional_positions_in_alternate_mem) const; - - // Returns the instruction elapsed at the index. - float GetInstructionElapsed(int idx) const; - - int loop_start_; - int loop_end_; - int loop_size_; - uint64_t alternate_memory_size_; - MemoryBoundLoopOptimizerOptions options_; - const HloLiveRange& hlo_live_range_; - const HloAliasAnalysis& alias_analysis_; - const MemorySpaceAssignmentCostAnalysis& cost_analysis_; - BufferValue::SizeFunction size_function_; - - absl::flat_hash_map instructions_in_loop_; - absl::flat_hash_map - instructions_in_prev_iteration_; - absl::flat_hash_map - instructions_in_next_iteration_; - std::vector loop_values_; - std::vector remaining_memory_; - absl::flat_hash_map>> - uses_in_alternate_mem_; - absl::flat_hash_map> - positions_in_alternate_mem_; -}; - // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of // maximum size. class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { public: - using MemorySpace = MemorySpaceAssignment::MemorySpace; using AllocationValue = MemorySpaceAssignment::AllocationValue; using HloPositionOrUse = std::variant; - AlternateMemoryBestFitHeap( - MemorySpaceAssignment::AllocationSequence* allocations, - const Options& options, const HloAliasAnalysis& alias_analysis, - const HloLiveRange& hlo_live_range); + AlternateMemoryBestFitHeap(AllocationSequence* allocations, + const Options& options, + const HloAliasAnalysis& alias_analysis, + const HloLiveRange& hlo_live_range); // Allocates a buffer in preferred memory with whole program lifetime and // enables prefetching prefetch_candidate from default memory across program @@ -2119,7 +832,7 @@ class AlternateMemoryBestFitHeap void AllocateCrossProgramPrefetchBuffer( HloModule* module, const BufferInterval& prefetch_candidate); - HeapSimulator::Result Finish() override; + absl::StatusOr> Finish() override; protected: // Given a buffer interval, returns the colocated intervals. Unlike the @@ -2145,9 +858,7 @@ class AlternateMemoryBestFitHeap // positions. void FindAliases(std::vector* allocation_values) const; - MemorySpaceAssignment::AllocationSequence* allocations() { - return allocations_; - } + AllocationSequence* allocations() { return allocations_; } const Options& options() const { return options_; } const HloAliasAnalysis& alias_analysis() { return alias_analysis_; } const HloLiveRange& hlo_live_range() { return hlo_live_range_; } @@ -2155,16 +866,15 @@ class AlternateMemoryBestFitHeap private: // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. - struct RepackAllocationBlock - : MemorySpaceAssignmentRepacker::AllocationBlock { - MemorySpaceAssignment::Allocation* allocation; + struct RepackAllocationBlock : AllocationBlock { + Allocation* allocation; }; // A data structure we use to associate Allocation objects that are aliased // and must get the same offset. struct AliasedOffset { int64_t offset; - absl::flat_hash_set allocations; + absl::flat_hash_set allocations; }; // An allocation request for a use segment. A use segment is the time segment @@ -2194,6 +904,7 @@ class AlternateMemoryBestFitHeap int64_t size; bool prefer_no_copy_alternate_mem_allocation; bool allow_no_copy_alternate_mem_allocation; + bool require_no_copy_alternate_mem_allocation; bool allow_prefetch; std::optional earliest_prefetch_time; std::optional preferred_prefetch_time; @@ -2208,7 +919,7 @@ class AlternateMemoryBestFitHeap // time of the parameter instruction, and an output's time would correspond to // the time of last use. struct RequiredMemoryAssignment { - MemorySpaceAssignment::MemorySpace memory_space; + MemorySpace memory_space; int64_t time; AliasedOffset* offset; @@ -2236,7 +947,7 @@ class AlternateMemoryBestFitHeap // instruction. int64_t loop_size; // A pointer into an Allocation in loop_optimized_allocations_. - const MemorySpaceAssignment::Allocation* loop_optimized_allocation; + const Allocation* loop_optimized_allocation; }; // A context object that is used to share state amongst the methods that @@ -2271,8 +982,7 @@ class AlternateMemoryBestFitHeap // p0 | +-------+ // +---|---|---|---|---|----> time // t0 t1 t2 t3 t4 - std::vector - slice_decisions_sorted_by_start_time; + std::vector slice_decisions_sorted_by_start_time; // In order to support colocated buffer calculations, we need to add a // BufferInterval-Chunk pair to pending_chunks_, such that: @@ -2334,7 +1044,7 @@ class AlternateMemoryBestFitHeap // Parameters to Prefetch(). const AllocationRequest* request; - MemorySpaceAssignment::Allocation* prev_allocation_in_default_mem; + Allocation* prev_allocation_in_default_mem; // Intermediate calculations common to both the sliced and unsliced // solutions. @@ -2348,8 +1058,8 @@ class AlternateMemoryBestFitHeap std::optional exclusive_out_of_mem_start = std::nullopt; // Data structures used to compute and store the sliced solution. - std::optional - slice_proposal_collection = std::nullopt; + std::optional slice_proposal_collection = + std::nullopt; WorkingIntervals sliced_solution_intervals; std::optional sliced_solution; @@ -2388,7 +1098,9 @@ class AlternateMemoryBestFitHeap kFailRequiresUncommit = 64, // For prefetching, indicates that all slices have the same start time, in // which case, we fallback to an unsliced solution. - kAllSlicesHaveTheSameStartTime = 128 + kAllSlicesHaveTheSameStartTime = 128, + // There were conflicting preferred offsets. + kFailConflictingPreferredOffsets = 256 }; // Return true if the result belongs to a failure. @@ -2432,22 +1144,19 @@ class AlternateMemoryBestFitHeap void AllocateReservedScopedAllocations(); // Returns the AliasedOffset object associated with the allocation. - AliasedOffset* GetAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation); + AliasedOffset* GetAliasedOffset(const Allocation& allocation); // If aliased_offset is non-null, this method adds the allocation to // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds // the allocation to this new AliasedOffset. - void CreateOrAddToAliasedOffset( - const MemorySpaceAssignment::Allocation& allocation, - AliasedOffset* aliased_offset); + void CreateOrAddToAliasedOffset(const Allocation& allocation, + AliasedOffset* aliased_offset); // Given an allocation sequence, returns the live allocation at time with a // preference towards allocations in alternate memory. Returns nullptr if no // allocation is alive at that time. - static MemorySpaceAssignment::Allocation* GetLiveAllocationAt( - const MemorySpaceAssignment::AllocationSequence& allocations, - int64_t time); + static Allocation* GetLiveAllocationAt(const AllocationSequence& allocations, + int64_t time); // Returns true if the use is allowed in the alternate memory. bool IsUseAllowedInAlternateMemory(const AllocationValue& value, @@ -2457,7 +1166,7 @@ class AlternateMemoryBestFitHeap // All of the allocation values have a must-alias relationship with each // other. Returns either kSuccess if all of the sites could be placed in the // alternate memory or a bitwise OR of failure reasons why they couldn't - Result AllocateAllocationValues( + absl::StatusOr AllocateAllocationValues( absl::Span allocation_values); // Finds an allocation for an allocation request for a segment (see the @@ -2491,9 +1200,8 @@ class AlternateMemoryBestFitHeap int64_t earliest_prefetch_time) const; // Try prefetching to alternate memory space. - Result Prefetch( - const AllocationRequest& request, - MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); + Result Prefetch(const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem); // Helper methods used to implement Prefetch(). // @@ -2548,9 +1256,9 @@ class AlternateMemoryBestFitHeap colocated_intervals); // Propagates aliased required assignment for a given position. - void AddAliasedRequiredAssignment( - const HloInstruction* instruction, ShapeIndex index, - const MemorySpaceAssignment::Allocation* aliased_allocation); + void AddAliasedRequiredAssignment(const HloInstruction* instruction, + ShapeIndex index, + const Allocation* aliased_allocation); // This sets a required assignment. CHECK fails if there is a conflicting // required assignment at the same time. @@ -2578,7 +1286,7 @@ class AlternateMemoryBestFitHeap // allocations all share a common allocation site (a use or position) with // each other. This can be used to determine if a group of linked allocations // are considered efficient or not. - std::vector> + std::vector> GetLinkedAllocationsInAlternateMemory( absl::Span allocation_values) const; @@ -2624,8 +1332,7 @@ class AlternateMemoryBestFitHeap // Exports the allocations for repacking and puts them into the vector in the // parameter. void ExportAllocationsForRepacking( - std::vector& - allocations); + std::vector& allocations); // Update reserved scoped allocation size for instructions when their // operand/output has been allocated in alternate memory by invoking @@ -2638,26 +1345,24 @@ class AlternateMemoryBestFitHeap // Helper functions to implement ImportRepackedAllocations. void ImportRepackedNonSlicedAllocation(RepackAllocationBlock& block); void ImportRepackedSlicedAllocation(RepackAllocationBlock& block); + Status AreRepackedSlicesValid(const RepackAllocationBlock& block); // Adds an asynchronous copy to allocations. void AddAsyncCopy( - MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpace memory_space, std::optional chunk, - int64_t exclusive_start_time, int64_t end_time, - int64_t copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations, - AliasedOffset* aliased_offset, float resource, + Allocation& prev_allocation, MemorySpace memory_space, + std::optional chunk, int64_t exclusive_start_time, + int64_t end_time, int64_t copy_done_schedule_before_time, + AllocationSequence* allocations, AliasedOffset* aliased_offset, + float resource, std::optional cross_program_prefetch_index = std::nullopt); // For prefetching, adds a SlicedCopyAllocation to allocations. Also updates // asynchronous copy data structures, prefetch_interval_tree_, and aliasing // data structures void AddAsyncSlicesForPrefetch( - const MemorySpaceAssignment::Allocation& prev_allocation, - MemorySpaceAssignment::AllocationSequence* allocations, + const Allocation& prev_allocation, AllocationSequence* allocations, AliasedOffset* aliased_offset, - const std::vector& - slice_decisions_sorted_by_start_time, + const std::vector& slice_decisions_sorted_by_start_time, int64_t prefetch_end_time, int64_t allocation_end_time); // This method is used for committing the chunk candidate but adding it to @@ -2684,9 +1389,8 @@ class AlternateMemoryBestFitHeap void AppendScopedAllocationBufferInfoDebugString( const HloInstruction* instruction, int64_t time, int64_t size, std::string& debug_str) const; - void AppendAllocationInfoDebugString( - const MemorySpaceAssignment::Allocation& allocation, - std::string& debug_str) const; + void AppendAllocationInfoDebugString(const Allocation& allocation, + std::string& debug_str) const; void DumpDebugStringsIfEnabled() const; // Returns the available heap size in the alternate memory. @@ -2703,8 +1407,7 @@ class AlternateMemoryBestFitHeap // Creates and returns a RepackAllocationBlock. static RepackAllocationBlock MakeRepackAllocationBlock( int64_t start_time, int64_t end_time, int64_t size, - int64_t initial_offset, int64_t id, - MemorySpaceAssignment::Allocation* allocation) { + int64_t initial_offset, int64_t id, Allocation* allocation) { RepackAllocationBlock allocation_block; allocation_block.inclusive_start_time = start_time; allocation_block.end_time = end_time; @@ -2712,7 +1415,7 @@ class AlternateMemoryBestFitHeap allocation_block.offset = -1; allocation_block.initial_offset = initial_offset; allocation_block.id = id; - allocation_block.colocations = {}; + allocation_block.next_colocated = nullptr; allocation_block.allocation = allocation; return allocation_block; } @@ -2722,7 +1425,11 @@ class AlternateMemoryBestFitHeap const std::vector* GetRepeatedInstructionList( const HloInstruction* instruction) const; - MemorySpaceAssignment::AllocationSequence* allocations_; + // Returns true if the interval is pinned in the alternate memory. Buffers are + // pinned when their layout has the alternate memory space before MSA runs. + bool IsIntervalPinnedToAlternateMemory(const BufferInterval& interval) const; + + AllocationSequence* allocations_; const Options& options_; const HloAliasAnalysis& alias_analysis_; const HloLiveRange& hlo_live_range_; @@ -2752,8 +1459,7 @@ class AlternateMemoryBestFitHeap // The data structure that contains AliasedOffset objects and Allocation to // AliasedOffset map for efficient lookup. std::list aliased_offsets_; - absl::flat_hash_map - aliased_offset_map_; + absl::flat_hash_map aliased_offset_map_; // This map contains required memory assignments for HloValues (e.g., input // and outputs). absl::flat_hash_map> @@ -2776,8 +1482,7 @@ class AlternateMemoryBestFitHeap // allocation objects describe the allocations for one iteration of the loop, // so we translate them into the program-level Allocation objects in // allocations_. - std::vector - loop_optimized_allocations_; + std::vector loop_optimized_allocations_; // A map to look up the loop-optimized allocation info by use. absl::flat_hash_map loop_optimized_allocations_map_; @@ -2795,6 +1500,7 @@ class AlternateMemoryBestFitHeap std::string allocation_info_str_; std::string instruction_schedule_str_; }; + } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/memory_space_assignment.proto b/xla/service/memory_space_assignment/memory_space_assignment.proto index 426e4a154ff38..47a89e74bebb0 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,14 @@ message SlicedPrefetchOptions { // size that is not a multiple of the required hardware alignment. Otherwise, // we will choose not to slice such situations, which is always safe. bool fail_on_non_alignment_boundary_slice_proposal = 3; + + // The threshold for max_slices after which we limit the permutations of slice + // times that we try when placing a sliced allocation. + uint32 all_slice_time_permutations_threshold = 4; + + // The preferred slize size for MSA sliced prefetches. 0 means there is no + // preferred slice size, in which case, we'll try to slice into max_slices. + uint64 preferred_slice_size = 5; } // Options for memory-bound loop optimizations in memory space assignment. If @@ -63,3 +71,96 @@ message MemoryBoundLoopOptimizerOptions { // memory-bound loop optimizer to kick in. optional float min_num_iterations = 4; } + +message TupleShapeIndex { + repeated int64 index = 1; +} + +// A message to filter operands in an HLO schedule, that can be used to override +// compiler behaviour like altering schedule etc. +message HloOperandFilter { + // Regex to match instruction name. + optional string instruction_name_regex = 1; + // Set if filtering operands of an instruction. + optional int64 operand_number = 2; + // If filtering operands based on size in bytes. + optional int64 size_gte = 3; + // If filtering operands based on size in bytes. + optional int64 size_lte = 4; + // If operand of an instruction is a tuple and indexing into the tuple is + // required. + optional TupleShapeIndex tuple_index = 5; +} + +// Options to override preferred prefetch time for an operand. +message PreferredPrefetchOverrideOptions { + oneof options { + // A value X in [0, 1] that tells us the preferred prefetch time is the + // fraction X through the live range. For example, .5 will set the + // preferred prefetch time to the middle of live range. + float prefetch_eagerness = 1; + // Preferred prefetch time is set to after the instruction with instruction + // name. + string after_instruction_name = 2; + // Preferred prefetch time is set to before the instruction with instruction + // name. + string before_instruction_name = 3; + } +} + +// Filters operands in an HLO schedule and overrides preferred prefetch times +// for those operands according to an override strategy specified in +// override_options. +message PreferredPrefetchOverride { + optional HloOperandFilter hlo_operand_filter = 1; + optional xla.memory_space_assignment.PreferredPrefetchOverrideOptions + override_options = 2; +} + +// Encloses chained override configs. The first config has highest precedence +// and so on. +message PreferredPrefetchOverrides { + repeated PreferredPrefetchOverride overrides = 1; +} + +// A message that identifies one or more HloPositions. +message HloPositionMatcher { + // Regex to match the entire instruction HLO. The HLO string is constructed + // using default HloPrintOptions. Refer to the HloPrintOptions class in + // hlo_instruction.h to know more about the format of the HLO string used for + // matching. + optional string instruction_regex = 1; + // Regex to match instruction name. + optional string instruction_name_regex = 2; + // If output of an instruction is a tuple and indexing into the + // tuple is required. + optional TupleShapeIndex tuple_index = 3; +} + +// Options to override preferred prefetch time for an operand. +message MsaSortOrderOverrideOptions { + oneof options { + // Assign alternate memory to the filtered buffer before other buffers. If + // multiple buffers are to be assigned first (within the same override + // config) other tie breakers and stable sort order will take effect. + bool assign_first = 1; + // Assign alternate memory to the filtered buffer after other buffers. If + // multiple buffers are to be assigned last (within the same override + // config) other tie breakers and stable sort order will take effect. + bool assign_last = 2; + } +} + +// Specifies details on how to override the sort order for matching +// HloPositions. +message MsaSortOrderOverride { + optional HloPositionMatcher hlo_position_matcher = 1; + optional xla.memory_space_assignment.MsaSortOrderOverrideOptions + override_options = 2; +} + +// Encloses chained override configs. The first config has highest precedence +// and so on. +message MsaSortOrderOverrides { + repeated MsaSortOrderOverride overrides = 1; +} diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index c54a1d03aebf7..5c9cc1e1321ba 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -38,65 +37,57 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/log/log.h" -#include "absl/status/statusor.h" -#include "absl/strings/ascii.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/instruction_hoister.h" +#include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/service/memory_space_assignment/testing_utils.h" #include "xla/service/time_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" +#include "xla/statusor.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/status.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { +namespace memory_space_assignment { namespace { namespace op = xla::testing::opcode_matchers; using Chunk = HeapSimulator::Chunk; -using memory_space_assignment::AsynchronousCopy; -using memory_space_assignment::AsynchronousCopyOrdering; -using memory_space_assignment::AsynchronousCopyResource; -using memory_space_assignment::CostAnalysisPrefetchIntervalPicker; -using memory_space_assignment::InstructionCountPrefetchIntervalPicker; -using memory_space_assignment::MemoryBoundLoopOptimizer; -using memory_space_assignment::MemoryBoundLoopOptimizerOptions; -using memory_space_assignment::MemorySpaceAssignment; -using memory_space_assignment::MemorySpaceAssignmentCostAnalysis; -using memory_space_assignment::MemorySpaceAssignmentRepacker; -using memory_space_assignment::Options; -using memory_space_assignment::PrefetchIntervalPicker; -using memory_space_assignment::PresetAssignments; -using memory_space_assignment::SlicedPrefetchOptions; -using SliceParam = memory_space_assignment::MemorySpaceAssignment::SliceParam; -using SliceProposal = - memory_space_assignment::MemorySpaceAssignment::SliceProposal; -using SliceProposalCollection = - memory_space_assignment::MemorySpaceAssignment::SliceProposalCollection; -using MSA = memory_space_assignment::MemorySpaceAssignment; using ::testing::_; using ::testing::Return; +using ::testing::UnorderedElementsAre; constexpr int64_t kPointerSize = 8; constexpr float kAsyncCopyBandwidth = 100; @@ -113,24 +104,44 @@ int64_t SizeFunction(const BufferValue& value) { return ShapeSize(value.shape()); } -class TestBufferIntervalComparator - : public MemorySpaceAssignment::BufferIntervalComparator { +int64_t ReservedScopedMemoryFn( + const HloInstruction* instruction, + const absl::flat_hash_set>& + operands_in_alternate_memory, + const absl::flat_hash_set& outputs_in_alternate_memory) { + return 0; +} + +template +StatusOr ParseTextProto(const std::string& text_proto) { + tsl::protobuf::TextFormat::Parser parser; + MessageType parsed_proto; + tsl::protobuf::io::ArrayInputStream input_stream(text_proto.data(), + text_proto.size()); + if (!parser.Parse(&input_stream, &parsed_proto)) { + return absl::InvalidArgumentError("Could not parse text proto"); + } + return parsed_proto; +} + +class TestBufferIntervalComparator : public BufferIntervalComparator { public: explicit TestBufferIntervalComparator( GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare compare_method) - : MemorySpaceAssignment::BufferIntervalComparator(), - compare_method_(compare_method) {} + : BufferIntervalComparator(), compare_method_(compare_method) {} ~TestBufferIntervalComparator() override = default; std::string DescribeComparisonCriteria() const override { return "internal to test"; } - std::string CriteriaToString(const BufferInterval& buffer_interval) override { + std::string CriteriaToString( + const MsaBufferInterval& buffer_interval) override { return "internal to test"; } - bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override { + bool LessThan(const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) override { return compare_method_(lhs, rhs); } @@ -160,8 +171,6 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { Options DefaultMemorySpaceOptions() { Options options; - options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth; - options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; options.max_size_in_bytes = 128; options.alignment_in_bytes = 8; options.verify = true; @@ -174,6 +183,13 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return options; } + CostAnalysisOptions DefaultCostAnalysisOptions() { + CostAnalysisOptions options; + options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth; + options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; + return options; + } + Options UpdateMaxAsyncCopies(Options options, int64_t max_async_copies) { options.max_outstanding_prefetches = max_async_copies; options.max_outstanding_evictions = max_async_copies; @@ -184,14 +200,18 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { std::unique_ptr AssignMemorySpaceUsingCostAnalysis( HloModule* module, std::optional memory_space_options_override = std::nullopt, - std::optional cost_options_override = + std::optional cost_analysis_options_override = + std::nullopt, + std::optional hlo_cost_options_override = + std::nullopt, + std::optional optional_msa_sort_order_overrides = std::nullopt) { - HloCostAnalysis::Options cost_options = DefaultHloCostAnalysisOptions(); - if (cost_options_override) { - cost_options = *cost_options_override; + HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); + if (hlo_cost_options_override) { + hlo_cost_options = *hlo_cost_options_override; } - HloCostAnalysis hlo_cost_analysis(cost_options); + HloCostAnalysis hlo_cost_analysis(hlo_cost_options); for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); } @@ -201,10 +221,14 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { if (memory_space_options_override) { memory_space_options = *memory_space_options_override; } + CostAnalysisOptions cost_analysis_options = DefaultCostAnalysisOptions(); + if (cost_analysis_options_override) { + cost_analysis_options = *cost_analysis_options_override; + } - auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, memory_space_options, *module) - .value(); + auto cost_analysis = + CostAnalysis::Create(hlo_cost_analysis, cost_analysis_options, *module) + .value(); memory_space_options.cost_analysis = cost_analysis.get(); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( @@ -212,12 +236,16 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { /*preferred_overlap_to_async_copy_ratio=*/1.5, /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, /*mem_size_bytes=*/memory_space_options.max_size_in_bytes)); - memory_space_assignment::MemoryBoundednessBufferIntervalComparator - comparator(*cost_analysis, &cache_); + MsaSortOrderOverrides msa_sort_order_overrides; + if (optional_msa_sort_order_overrides.has_value()) { + msa_sort_order_overrides = optional_msa_sort_order_overrides.value(); + } + MemoryBoundednessBufferIntervalComparator comparator( + *cost_analysis, &cache_, msa_sort_order_overrides); return AssignMemorySpace( module, memory_space_options, - [&comparator](const MemorySpaceAssignment::BufferInterval& lhs, - const MemorySpaceAssignment::BufferInterval& rhs) { + [&comparator](const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) { return comparator.LessThan(lhs, rhs); }, &prefetch_interval_picker); @@ -237,8 +265,19 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { std::unique_ptr AssignMemorySpace( HloModule* module, std::optional options_override, - std::optional - buffer_interval_compare, + std::optional buffer_interval_compare, + PrefetchIntervalPicker* prefetch_interval_picker) { + auto status_or = AssignMemorySpaceAndReturnStatus(module, options_override, + buffer_interval_compare, + prefetch_interval_picker); + TF_EXPECT_OK(status_or.status()); + return std::move(status_or.value()); + } + + absl::StatusOr> + AssignMemorySpaceAndReturnStatus( + HloModule* module, std::optional options_override, + std::optional buffer_interval_compare, PrefetchIntervalPicker* prefetch_interval_picker) { auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); @@ -288,16 +327,14 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; } - auto alias_analysis = HloAliasAnalysis::Run(module).value(); - std::unique_ptr hlo_live_range = - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation()) - .value(); + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(module->schedule(), *alias_analysis, + module->entry_computation())); - std::unique_ptr preset_assignments = - MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis, - options) - .value(); + TF_ASSIGN_OR_RETURN(std::unique_ptr preset_assignments, + MemorySpaceAssignment::Run(module, *hlo_live_range, + *alias_analysis, options)); if (check_parameters_in_default_memory) { CheckParametersInDefaultMemory(module); } @@ -469,7 +506,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return module; } - MemorySpaceAssignmentCostAnalysis::Cache cache_; + CostAnalysis::Cache cache_; }; class MemorySpaceAssignmentTest : public MemorySpaceAssignmentTestBase, @@ -478,93 +515,6 @@ class MemorySpaceAssignmentTest : public MemorySpaceAssignmentTestBase, bool allocate_across_sequential_calls() const override { return GetParam(); } }; -// For testing purposes, we define a cost analysis where we can control the -// elapsed times of each HLO and asynchronous copy. -class FakeMemorySpaceAssignmentCostAnalysis - : public MemorySpaceAssignmentCostAnalysis { - public: - static StatusOr> - Create(const HloCostAnalysis& cost_analysis, const HloModule& module, - const Options& options) { - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); - TF_ASSIGN_OR_RETURN(auto hlo_live_range, - HloLiveRange::Run(module.schedule(), *alias_analysis, - module.entry_computation())); - auto call_graph = CallGraph::Build(&module); - return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph))); - } - - float GetInstructionElapsed( - const HloInstruction& instruction) const override { - if (get_instruction_elapsed_override_) { - return get_instruction_elapsed_override_(instruction); - } - return 1.0; - } - - float GetInstructionElapsedInAlternateMemory( - const HloInstruction& instruction, - absl::Span> - operands_in_alternate_mem, - absl::Span outputs_in_alternate_mem) const override { - if (get_instruction_elapsed_in_alternate_memory_override_) { - return get_instruction_elapsed_in_alternate_memory_override_( - instruction, operands_in_alternate_mem, outputs_in_alternate_mem); - } - if (!operands_in_alternate_mem.empty()) { - return 0.5; - } else { - return 1.0; - } - } - - float GetAsyncCopyElapsed(const Shape& shape) const override { - if (get_async_copy_elapsed_override_) { - return get_async_copy_elapsed_override_(shape); - } - return 3.0; - } - - // The following methods can be used to override what the above API calls - // return. - void SetOverrideForGetInstructionElapsed( - std::function function) { - get_instruction_elapsed_override_ = function; - } - void SetOverrideForGetInstructionElapsedInAlternateMemory( - std::function>, - absl::Span)> - function) { - get_instruction_elapsed_in_alternate_memory_override_ = function; - } - void SetOverrideForGetAsyncCopyElapsed( - std::function function) { - get_async_copy_elapsed_override_ = function; - } - - protected: - FakeMemorySpaceAssignmentCostAnalysis( - const HloCostAnalysis& cost_analysis, const Options& options, - std::unique_ptr alias_analysis, - std::unique_ptr hlo_live_range, - std::unique_ptr call_graph) - : MemorySpaceAssignmentCostAnalysis( - cost_analysis, options, std::move(alias_analysis), - std::move(hlo_live_range), std::move(call_graph)) {} - - private: - std::function - get_instruction_elapsed_override_ = nullptr; - std::function>, - absl::Span)> - get_instruction_elapsed_in_alternate_memory_override_ = nullptr; - std::function get_async_copy_elapsed_override_ = nullptr; -}; - TEST_P(MemorySpaceAssignmentTest, ParameterOnly) { // A module consisting of a single parameter. Inputs/outputs are currently // excluded from memory space assignment. @@ -617,6 +567,7 @@ TEST_P(MemorySpaceAssignmentTest, Simple) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(p0, op::ShapeWithLayout(shape)); EXPECT_THAT(p1, op::ShapeWithLayout(shape)); @@ -679,6 +630,7 @@ TEST_P(MemorySpaceAssignmentTest, NegateChain) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -940,11 +892,16 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdatePreferredPrefetchTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = "op_size_gte:24:op_size_lte:24:prefetch_eagerness:0.5"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { size_lte: 24 size_gte: 24 } + override_options { prefetch_eagerness: 0.5 } + })pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -958,6 +915,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdatePreferredPrefetchTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1010,13 +968,16 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchBeforeTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = - "instruction_name_exact:add:op_number_exact:1:put_before_instruction:" - "negate.3"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { instruction_name_regex: "add" operand_number: 1 } + override_options { before_instruction_name: "negate.3" } + })pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -1030,6 +991,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchBeforeTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1082,13 +1044,16 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchAfterTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = - "instruction_name_exact:add:op_number_exact:1:put_after_instruction:" - "negate.1"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { instruction_name_regex: "add" operand_number: 1 } + override_options { after_instruction_name: "negate.1" } + })pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -1102,6 +1067,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchAfterTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1154,13 +1120,16 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchTooLateTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = - "instruction_name_exact:add:op_number_exact:1:put_after_instruction:" - "negate.5"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { instruction_name_regex: "add" operand_number: 1 } + override_options { after_instruction_name: "negate.5" } + })pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); // Ensure the Async copy is not scheduled. @@ -1173,6 +1142,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchTooLateTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1218,13 +1188,20 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigPrecedenceTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = - "op_size_gte:24:op_size_lte:24:prefetch_eagerness:0.5;instruction_" - "name_exact:add:op_number_exact:1:put_after_instruction:negate.1"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { size_lte: 24 size_gte: 24 } + override_options { prefetch_eagerness: 0.5 } + } + overrides { + hlo_operand_filter { instruction_name_regex: "add" operand_number: 1 } + override_options { after_instruction_name: "negate.1" } + })pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -1238,6 +1215,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigPrecedenceTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1290,13 +1268,21 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchPrecedenceTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = - "instruction_name_exact:add:op_number_exact:1:put_after_instruction:" - "negate.1;op_size_gte:24:op_size_lte:24:prefetch_eagerness:0.5"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { instruction_name_regex: "add" operand_number: 1 } + override_options { after_instruction_name: "negate.1" } + } + overrides { + hlo_operand_filter { size_lte: 24 size_gte: 24 } + override_options { prefetch_eagerness: 0.5 } + } + )pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -1310,6 +1296,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdateConfigExactMatchPrecedenceTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1362,11 +1349,17 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdatePreferredPrefetchNoMatchTest) { TF_CHECK_OK(module->set_schedule(schedule)); Options options = DefaultMemorySpaceOptions(); - auto config = "op_size_gte:25:op_size_lte:24:prefetch_eagerness:0.5"; + + const std::string text_proto = R"pb( + overrides { + hlo_operand_filter { size_lte: 24 size_gte: 25 } + override_options { prefetch_eagerness: 0.5 } + } + )pb"; TF_ASSERT_OK_AND_ASSIGN( - options.filter_update_preferred_prefetches, - memory_space_assignment::FilterUpdatePreferredPrefetch:: - ParseFilterUpdatePreferredPrefetches(config)); + options.preferred_prefetch_overrides, + ParseTextProto(text_proto)); + AssignMemorySpace(module.get(), options); EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, @@ -1380,6 +1373,7 @@ TEST_P(MemorySpaceAssignmentTest, FilterUpdatePreferredPrefetchNoMatchTest) { F32, {2, 3}, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -1700,7 +1694,8 @@ TEST_P(MemorySpaceAssignmentTest, While) { } Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {2, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem)); } @@ -2039,7 +2034,8 @@ TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) { AssignMemorySpace(module.get()); Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {2, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape)); EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem)); @@ -2094,7 +2090,8 @@ TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) { AssignMemorySpace(module.get()); Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {2, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape)); EXPECT_THAT(fusion->operand(0)->operand(0), @@ -2253,9 +2250,8 @@ TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) { } )"; - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& a, - const MemorySpaceAssignment::BufferInterval& b) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { bool a_is_mul = a.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply; bool b_is_mul = @@ -4003,6 +3999,7 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) { LayoutUtil::MakeLayout( /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*dim_unique=*/{}, /*dim_ordered=*/{}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*element_size_in_bits=*/0, kAlternateMemorySpace); @@ -4011,6 +4008,7 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) { LayoutUtil::MakeLayout( /*minor_to_major=*/{}, /*dim_level_types=*/{}, /*dim_unique=*/{}, /*dim_ordered=*/{}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*element_size_in_bits=*/0, kDefaultMemorySpace); @@ -4019,6 +4017,7 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) { LayoutUtil::MakeLayout( /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*dim_unique=*/{}, /*dim_ordered=*/{}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*element_size_in_bits=*/0, kDefaultMemorySpace); @@ -4454,7 +4453,8 @@ TEST_P(MemorySpaceAssignmentTest, CostAnalysis) { // Negate instructions are in the alternate memory space (1). Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {2, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -4524,7 +4524,8 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) { EXPECT_THAT(p1, op::ShapeWithLayout(shape)); Shape shape_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {4, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kDefaultMemorySpace); // Expect only negates to be in alternate memory space. Not all might fit but // make sure at least one does. @@ -4543,6 +4544,143 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) { EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem)); } +TEST_P(MemorySpaceAssignmentTest, + MemoryBoundednessOverrideSortOrderAssignFirst) { + // Override MSA sort order and try to assign all negates to alternate memory + // first. + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[3,4]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + tanh0 = f32[3,4]{1,0} tanh(p0) + negate0 = f32[3,4]{1,0} negate(p1) + tanh1 = f32[3,4]{1,0} tanh(tanh0) + negate1 = f32[3,4]{1,0} negate(negate0) + tanh2 = f32[3,4]{1,0} tanh(tanh1) + negate2 = f32[3,4]{1,0} negate(negate1) + tanh3 = f32[3,4]{1,0} tanh(tanh2) + negate3 = f32[3,4]{1,0} negate(negate2) + tanh4 = f32[3,4]{1,0} tanh(tanh3) + negate4 = f32[3,4]{1,0} negate(negate3) + ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + const std::string text_proto = R"pb( + overrides { + hlo_position_matcher { instruction_name_regex: "negate(.*)" } + override_options { assign_first: true } + })pb"; + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, + ParseTextProto(text_proto)); + + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); + // Parameters are in the default memory space. + const HloInstruction* p0 = FindInstruction(module.get(), "p0"); + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* p1 = FindInstruction(module.get(), "p1"); + EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace); + // All negates are in alternate memory space except negate4. + HloInstruction* negate0 = FindInstruction(module.get(), "negate0"); + EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate1 = FindInstruction(module.get(), "negate1"); + EXPECT_EQ(negate1->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate2 = FindInstruction(module.get(), "negate2"); + EXPECT_EQ(negate2->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate3 = FindInstruction(module.get(), "negate3"); + EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate4 = FindInstruction(module.get(), "negate4"); + EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0"); + EXPECT_EQ(tanh0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1"); + EXPECT_EQ(tanh1->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2"); + EXPECT_EQ(tanh2->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3"); + EXPECT_EQ(tanh3->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4"); + EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); +} + +TEST_P(MemorySpaceAssignmentTest, + MemoryBoundednessOverrideSortOrderAssignLast) { + // Override MSA sort order and try to assign all negates to alternate memory + // last. + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[3,4]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + tanh0 = f32[3,4]{1,0} tanh(p0) + negate0 = f32[3,4]{1,0} negate(p1) + tanh1 = f32[3,4]{1,0} tanh(tanh0) + negate1 = f32[3,4]{1,0} negate(negate0) + tanh2 = f32[3,4]{1,0} tanh(tanh1) + negate2 = f32[3,4]{1,0} negate(negate1) + tanh3 = f32[3,4]{1,0} tanh(tanh2) + negate3 = f32[3,4]{1,0} negate(negate2) + tanh4 = f32[3,4]{1,0} tanh(tanh3) + negate4 = f32[3,4]{1,0} negate(negate3) + ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + const std::string text_proto = R"pb( + overrides { + hlo_position_matcher { instruction_name_regex: "negate(.*)" } + override_options { assign_last: true } + } + )pb"; + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, + ParseTextProto(text_proto)); + + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); + // Parameters are in the default memory space. + const HloInstruction* p0 = FindInstruction(module.get(), "p0"); + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* p1 = FindInstruction(module.get(), "p1"); + EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace); + // All negates are in default memory space except negate3. + HloInstruction* negate0 = FindInstruction(module.get(), "negate0"); + EXPECT_EQ(negate0->shape().layout().memory_space(), kDefaultMemorySpace); + HloInstruction* negate1 = FindInstruction(module.get(), "negate1"); + EXPECT_EQ(negate1->shape().layout().memory_space(), kDefaultMemorySpace); + HloInstruction* negate2 = FindInstruction(module.get(), "negate2"); + EXPECT_EQ(negate2->shape().layout().memory_space(), kDefaultMemorySpace); + HloInstruction* negate3 = FindInstruction(module.get(), "negate3"); + EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate4 = FindInstruction(module.get(), "negate4"); + EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0"); + EXPECT_EQ(tanh0->shape().layout().memory_space(), kAlternateMemorySpace); + const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1"); + EXPECT_EQ(tanh1->shape().layout().memory_space(), kAlternateMemorySpace); + const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2"); + EXPECT_EQ(tanh2->shape().layout().memory_space(), kAlternateMemorySpace); + const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3"); + EXPECT_EQ(tanh3->shape().layout().memory_space(), kAlternateMemorySpace); + const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4"); + EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); +} + TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); Shape f32v1 = ShapeUtil::MakeShape(F32, {1}); @@ -4632,15 +4770,18 @@ TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { // Ensure all parameters and while are placed in default memory. Shape shape_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {4, 6}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kDefaultMemorySpace); Shape s32_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout( xla::S32, {}, - /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kDefaultMemorySpace); Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {1}, - /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kDefaultMemorySpace); Shape t_s32_f32v1_in_default_mem = ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem}); @@ -4751,7 +4892,8 @@ TEST_P(MemorySpaceAssignmentTest, Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithDenseLayout( F32, {2, 3}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); // p0 is in the default memory space. HloInstruction* p0 = @@ -4880,9 +5022,8 @@ TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) { } )"; - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& a, - const MemorySpaceAssignment::BufferInterval& b) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { auto get_opcode_priority = [](const HloOpcode& opcode) { switch (opcode) { case HloOpcode::kSin: @@ -4979,9 +5120,8 @@ TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) { } )"; - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& a, - const MemorySpaceAssignment::BufferInterval& b) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { auto get_opcode_priority = [](const HloOpcode& opcode) { switch (opcode) { case HloOpcode::kSin: @@ -5831,6 +5971,145 @@ ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s3 root->shape().layout().memory_space() == kDefaultMemorySpace); } +TEST_P(MemorySpaceAssignmentTest, PrecoloredBuffer) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[8,3] parameter(0) + param1 = f32[2,4] parameter(1) + a = f32[8,3]{1,0:S(1)} cosine(param0) + b = f32[2,4] negate(param1) + d = f32[8,3] negate(a) + c = f32[2,4] negate(b) + e = f32[2,4] negate(c) + f = f32[8,3] negate(d) + g = f32[2,4] negate(e) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[8,3] negate(f) + p = f32[2,4] negate(n) + q = f32[8,3] add(f, o) + r = f32[8,3] add(q, a) + ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r) + } + )"; + + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kNegate: + return 0; + case HloOpcode::kAdd: + return 1; + case HloOpcode::kCos: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + Options options = DefaultMemorySpaceOptions(); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + + const HloInstruction* r = FindInstruction(module.get(), "r"); + const HloInstruction* d = FindInstruction(module.get(), "d"); + const HloInstruction* a = FindInstruction(module.get(), "a"); + // Make sure the r and d operands aren't prefetched. + EXPECT_EQ(r->operand(1), a); + EXPECT_EQ(d->operand(0), a); + // Make sure they are allocated in the alternate memory. + EXPECT_EQ(a->shape().layout().memory_space(), kAlternateMemorySpace); + // Make sure the a buffer has an entry in the preset assignments. + auto a_entry = std::find_if( + preset_assignments->chunks().begin(), preset_assignments->chunks().end(), + [&](std::pair position_and_chunk) { + return position_and_chunk.first.instruction == a; + }); + EXPECT_NE(a_entry, preset_assignments->chunks().end()); +} + +TEST_P(MemorySpaceAssignmentTest, PrecoloredBufferOOM) { + // Same as above but there are two 96-byte values that are pinned to the + // alternate memory (the size of the alternate memory is 128 bytes), which is + // unsatisfiable. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[8,3] parameter(0) + param1 = f32[2,4] parameter(1) + a = f32[8,3]{1,0:S(1)} cosine(param0) + b = f32[2,4] negate(param1) + d = f32[8,3] negate(a) + c = f32[2,4] negate(b) + e = f32[2,4] negate(c) + f = f32[8,3] negate(d) + g = f32[2,4] negate(e) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[8,3]{1,0:S(1)} negate(f) + p = f32[2,4] negate(n) + q = f32[8,3] add(f, o) + r = f32[8,3] add(q, a) + ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r) + } + )"; + + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kNegate: + return 0; + case HloOpcode::kAdd: + return 1; + case HloOpcode::kCos: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + Options options = DefaultMemorySpaceOptions(); + auto status_or = AssignMemorySpaceAndReturnStatus(module.get(), options, + buffer_interval_compare, + &prefetch_interval_picker); + EXPECT_THAT( + status_or.status(), + tsl::testing::StatusIs( + tsl::error::FAILED_PRECONDITION, + ::testing::HasSubstr("requires allocation in the alternate memory, " + "which could not be satisfied"))); +} + TEST_P(MemorySpaceAssignmentTest, AsyncOpShortLiveRange) { absl::string_view hlo_string = R"( HloModule module, is_scheduled=true @@ -6418,12 +6697,13 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { check_fun_(check_fun), always_return_modified_(always_return_modified) {} - StatusOr Repack(absl::Span allocations) override { + absl::StatusOr Repack( + absl::Span allocations) override { bool modified = false; for (AllocationBlock* block : allocations) { absl::flat_hash_set colocations; std::string colocations_str; - for (const AllocationBlock* colocation : block->colocations) { + for (const AllocationBlock* colocation : block->GetColocations()) { absl::StrAppend(&colocations_str, colocation->id, ", "); colocations.insert(colocation->id); } @@ -6440,7 +6720,7 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { } else { block->offset = block->initial_offset; } - for (AllocationBlock* colocation : block->colocations) { + for (AllocationBlock* colocation : block->GetColocations()) { if (it != repack_map_.end()) { colocation->offset = it->second; } else { @@ -6537,9 +6817,8 @@ TEST_P(MemorySpaceAssignmentTest, Repack) { } )"; - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& a, - const MemorySpaceAssignment::BufferInterval& b) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { auto get_opcode_priority = [](const HloOpcode& opcode) { switch (opcode) { case HloOpcode::kSin: @@ -6644,9 +6923,8 @@ TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) { } )"; - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& a, - const MemorySpaceAssignment::BufferInterval& b) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { auto get_opcode_priority = [](const HloOpcode& opcode) { switch (opcode) { case HloOpcode::kSin: @@ -6671,16 +6949,14 @@ TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) { // Expect that of the four separate allocations for the "a" buffer, the first // and the next three are in separate colocations. - auto check_fun = - [](absl::Span - allocations) { - EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 || - allocations.at(0)->colocations.size() == 3); - EXPECT_EQ(allocations.at(1)->colocations.size(), 3); - EXPECT_EQ(allocations.at(2)->colocations.size(), 3); - EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 || - allocations.at(3)->colocations.size() == 3); - }; + auto check_fun = [](absl::Span allocations) { + EXPECT_TRUE(allocations.at(0)->GetColocationsCount() == 1 || + allocations.at(0)->GetColocationsCount() == 3); + EXPECT_EQ(allocations.at(1)->GetColocationsCount(), 3); + EXPECT_EQ(allocations.at(2)->GetColocationsCount(), 3); + EXPECT_TRUE(allocations.at(3)->GetColocationsCount() == 1 || + allocations.at(3)->GetColocationsCount() == 3); + }; FakeMemorySpaceAssignmentRepacker repacker = FakeMemorySpaceAssignmentRepacker(repack_map, check_fun); Options options = DefaultMemorySpaceOptions(); @@ -6726,13 +7002,11 @@ ENTRY entry { // Expect that the first two value to repack has a colocations size of 2, // corresponding to the scoped allocations. - auto check_fun = - [&](absl::Span - allocations) { - EXPECT_EQ(allocations.at(0)->colocations.size(), 2); - EXPECT_EQ(allocations.at(1)->colocations.size(), 2); - repacker_ran = true; - }; + auto check_fun = [&](absl::Span allocations) { + EXPECT_EQ(allocations.at(0)->GetColocationsCount(), 2); + EXPECT_EQ(allocations.at(1)->GetColocationsCount(), 2); + repacker_ran = true; + }; FakeMemorySpaceAssignmentRepacker repacker = FakeMemorySpaceAssignmentRepacker(repack_map, check_fun); options.repacker = &repacker; @@ -6912,16 +7186,13 @@ TEST_P(MemorySpaceAssignmentTest, ScopedAllocationWithDifferentOffset) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - auto check_fun = - [](absl::Span - allocations) { - for (MemorySpaceAssignmentRepacker::AllocationBlock* block : - allocations) { - if (block->inclusive_start_time == block->end_time) { - EXPECT_GT(block->colocations.size(), 0); - } - } - }; + auto check_fun = [](absl::Span allocations) { + for (AllocationBlock* block : allocations) { + if (block->inclusive_start_time == block->end_time) { + EXPECT_GT(block->GetColocationsCount(), 0); + } + } + }; absl::flat_hash_map, int64_t> repack_map; FakeMemorySpaceAssignmentRepacker repacker = FakeMemorySpaceAssignmentRepacker(repack_map, check_fun); @@ -7108,8 +7379,8 @@ ENTRY entry { negate2 = f32[2,3] negate(negate1) negate3 = f32[2,3] negate(negate2) negate4 = f32[2,3] negate(negate3) - async-start = ((f32[2,3]), f32[2,3], f32[2]) async-start(negate1), async_group_id=0, async_execution_thread="foobar", calls=async_comp - async-done = f32[2,3] async-done(async-start), async_group_id=0, async_execution_thread="foobar", calls=async_comp + async-start = ((f32[2,3]), f32[2,3], f32[2]) async-start(negate1), async_execution_thread="foobar", calls=async_comp + async-done = f32[2,3] async-done(async-start), async_execution_thread="foobar", calls=async_comp add0 = f32[2,3] add(negate0, async-done) negate5 = f32[2,3] negate(add0) negate6 = f32[2,3] negate(negate5) @@ -7228,7 +7499,8 @@ ENTRY entry { // Disable inefficiency check. Expect that the fusion output and operand are // in the alternate memory. options.inefficient_use_to_copy_ratio = 0.0; - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis(module.get(), + /*memory_space_options_override=*/options); if (allocate_across_sequential_calls()) { EXPECT_THAT( module->entry_computation()->root_instruction(), @@ -7244,7 +7516,8 @@ ENTRY entry { // f32[2,3]), so this should be considered inefficient (8/48 < 0.5). TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo_string)); options.inefficient_use_to_copy_ratio = 0.5; - AssignMemorySpaceUsingCostAnalysis(module.get(), options); + AssignMemorySpaceUsingCostAnalysis(module.get(), + /*memory_space_options_override=*/options); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Fusion(op::Parameter()), op::Negate())); } @@ -7314,10 +7587,13 @@ ENTRY entry { Options options = DefaultMemorySpaceOptions(); options.enable_cross_program_prefetch = false; options.inefficient_use_to_copy_ratio = 0.5; - HloCostAnalysis::Options cost_options = DefaultHloCostAnalysisOptions(); - cost_options.set_transcendentals_per_second(0.4); + HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); + hlo_cost_options.set_transcendentals_per_second(0.4); - AssignMemorySpaceUsingCostAnalysis(module.get(), options, cost_options); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/hlo_cost_options); } TEST_P(MemorySpaceAssignmentTest, AsyncOpElapsedTime) { @@ -7346,66 +7622,390 @@ ENTRY entry { op::Parameter(1)); } -INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, - MemorySpaceAssignmentTest, - ::testing::Values(false, true)); - -using AsynchronousCopyOrderingTest = ::testing::Test; +TEST_P(MemorySpaceAssignmentTest, AliasedOperandBug) { + // Test for a case where two aliased operands into the same instruction + // (param0 and custom_call2) cause a violation of the required assignment. + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true -TEST_F(AsynchronousCopyOrderingTest, Simple) { - // Given asynchronous copies like the following, ensure the pipelining order - // is maintained (earlier start time must have earlier end time). - // 3,11 +-------+ OK - // 1,8 +------+ OK - // 5,14 +--------+ OK - // 7,14 +------+ OK - // 2,16 +-------------+ Violate - // 9,12 +--+ Violate - // 6,17 +----------+ Violate - // 5,13 +-------+ OK (same start as 5,14) - // 5,14 +--------+ OK (same as 5,14) - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; - AsynchronousCopyOrdering ordering; - EXPECT_FALSE(ordering.ViolatesOrdering(3, 11)); - ordering.AddCopy({3, 11, 1, alternate_mem_space, 0}); - EXPECT_FALSE(ordering.ViolatesOrdering(1, 8)); - ordering.AddCopy({1, 8, 1, alternate_mem_space, 1}); - EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); - ordering.AddCopy({5, 14, 1, alternate_mem_space, 2}); - EXPECT_FALSE(ordering.ViolatesOrdering(7, 14)); - ordering.AddCopy({7, 14, 1, alternate_mem_space, 3}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 16)); - EXPECT_TRUE(ordering.ViolatesOrdering(9, 12)); - EXPECT_TRUE(ordering.ViolatesOrdering(6, 17)); - EXPECT_FALSE(ordering.ViolatesOrdering(5, 13)); - ordering.AddCopy({5, 13, 1, alternate_mem_space, 4}); - EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); - ordering.AddCopy({5, 14, 1, alternate_mem_space, 5}); +ENTRY entry { + param0 = f32[4,4]{0,1} parameter(0) + param1 = f32[4]{0} parameter(1) + param2 = f32[4,4]{0,1} parameter(2) + negate0 = f32[4]{0} negate(param1) + negate1 = f32[4]{0} negate(negate0) + negate2 = f32[4]{0} negate(negate1) + negate3 = f32[4]{0} negate(negate2) + negate4 = f32[4]{0} negate(negate3) + negate5 = f32[4]{0} negate(negate4) + custom_call1 = f32[4,4]{0,1} custom-call(param0), custom_call_target="FooBar", output_to_operand_aliasing={{}: (0, {})} + tanh = f32[4,4]{0,1} tanh(param2) + negate6 = f32[4]{0} negate(negate5) + negate7 = f32[4]{0} negate(negate6) + negate8 = f32[4]{0} negate(negate7) + negate9 = f32[4]{0} negate(negate8) + negate10 = f32[4]{0} negate(negate9) + negate11 = f32[4]{0} negate(negate10) + negate12 = f32[4]{0} negate(negate11) + negate13 = f32[4]{0} negate(negate12) + negate14 = f32[4]{0} negate(negate13) + negate15 = f32[4]{0} negate(negate14) + negate16 = f32[4]{0} negate(negate15) + custom_call2 = f32[4,4]{0,1} custom-call(custom_call1), custom_call_target="FooBar", output_to_operand_aliasing={{}: (0, {})} + custom_call3 = f32[4,4]{0,1} custom-call(param0, custom_call2), custom_call_target="FooBar", output_to_operand_aliasing={{}: (0, {})} + ROOT root = f32[4,4]{0,1} add(tanh, custom_call2) } + )"; -TEST_F(AsynchronousCopyOrderingTest, SameInterval) { - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; - AsynchronousCopyOrdering ordering; - EXPECT_FALSE(ordering.ViolatesOrdering(1, 5)); - EXPECT_FALSE(ordering.ViolatesOrdering(2, 4)); - ordering.AddCopy({1, 5, 1, alternate_mem_space, 0}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); - ordering.AddCopy({1, 5, 1, alternate_mem_space, 1}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); - ordering.AddCopy({1, 5, 1, alternate_mem_space, 2}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); - ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 1}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); - ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 2}); - EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); - ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 0}); - EXPECT_FALSE(ordering.ViolatesOrdering(2, 4)); -} + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& a, const MsaBufferInterval& b) { + auto get_inst_priority = [](const HloInstruction* instruction) { + if (instruction->name() == "param2") { + return 0; + } + if (instruction->name() == "param0") { + return 1; + } + return 2; + }; -using AsynchronousCopyResourceTest = ::testing::Test; + return get_inst_priority(a.buffer->defining_instruction()) < + get_inst_priority(b.buffer->defining_instruction()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); -TEST_F(AsynchronousCopyResourceTest, Simple) { - // time: 0 1 2 3 4 5 6 7 8 9 + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + Options options = DefaultMemorySpaceOptions(); + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); +} + +TEST_P(MemorySpaceAssignmentTest, AsyncOpCustomFusionShortLiveRange) { + absl::string_view hlo_string = R"( +HloModule Module, is_scheduled=true + +fused_computation_start { + param0 = f32[2,1] parameter(0) + negate = f32[2,1] negate(param0) + ROOT custom-call = (f32[2,1], f32[2,1], u32[], u32[]) custom-call(negate), custom_call_target="AsyncOpStart" +} + +fused_computation_update { + param0 = f32[2,1] parameter(0) + param1 = f32[2,1] parameter(1) + param2 = f32[2,1] parameter(2) + param3 = f32[2,1] parameter(3) + param4 = u32[] parameter(4) + param5 = u32[] parameter(5) + add = f32[2,1] add(param0, param1) + negate = f32[2,1] negate(param2) + ROOT tuple = (f32[2,1], f32[2,1], f32[2,1], f32[2,1], u32[], u32[]) tuple(add, param2, param3, negate, param4, param5) +} + +fused_computation_done { + param0 = f32[2,1] parameter(0) + param1 = f32[2,1] parameter(1) + param2 = u32[] parameter(2) + param3 = u32[] parameter(3) + negate = f32[2,1] negate(param0) + ROOT custom-call = f32[2,1] custom-call(param0, param1, negate, param2, param3), custom_call_target="AsyncOpDone" +} + +ENTRY main { + param = f32[2,1] parameter(0) + negate1 = f32[2,1] negate(param) + negate2 = f32[2,1] negate(negate1) + fusion1 = (f32[2,1], f32[2,1], u32[], u32[]) fusion(negate1), kind=kCustom, output_to_operand_aliasing={{0}: (0, {})}, calls=fused_computation_start + negate3 = f32[2,1] negate(negate2) + negate4 = f32[2,1] negate(negate3) + gte0 = f32[2,1] get-tuple-element(fusion1), index=0 + gte1 = f32[2,1] get-tuple-element(fusion1), index=1 + gte2 = u32[] get-tuple-element(fusion1), index=2 + gte3 = u32[] get-tuple-element(fusion1), index=3 + fusion2 = (f32[2,1], f32[2,1], f32[2,1], f32[2,1], u32[], u32[]) fusion(negate4, negate2, gte0, gte1, gte2, gte3), kind=kLoop, output_to_operand_aliasing={{1}: (2, {}), {2}: (3, {}), {3}: (3, {}), {4}: (4, {}), {5}: (5, {})}, calls=fused_computation_update + gte4 = f32[2,1] get-tuple-element(fusion2), index=0 + negate5 = f32[2,1] negate(gte4) + gte5 = f32[2,1] get-tuple-element(fusion2), index=1 + gte6 = f32[2,1] get-tuple-element(fusion2), index=2 + gte7 = u32[] get-tuple-element(fusion2), index=4 + gte8 = u32[] get-tuple-element(fusion2), index=5 + fusion3 = f32[2,1] fusion(gte5, gte6, gte7, gte8), kind=kCustom, output_to_operand_aliasing={{}: (1, {})}, calls=fused_computation_done + ROOT add = f32[2,1] add(negate5, fusion3) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.position_requires_contiguous_allocation_fn = + [](const HloPosition& position) { + std::string_view inst_name = position.instruction->name(); + if (inst_name == "fusion1" || + (inst_name == "fusion2" && position.index != ShapeIndex({0}))) { + return true; + } + return false; + }; + AssignMemorySpace(module.get(), options); + + HloInstruction* fusion1 = + module->entry_computation()->GetInstructionWithName("fusion1"); + HloInstruction* fusion2 = + module->entry_computation()->GetInstructionWithName("fusion2"); + HloInstruction* fusion3 = + module->entry_computation()->GetInstructionWithName("fusion3"); + + EXPECT_THAT(fusion2->operand(2), op::GetTupleElement(fusion1, 0)); + EXPECT_THAT(fusion2->operand(3), op::GetTupleElement(fusion1, 1)); + EXPECT_THAT(fusion3->operand(0), op::GetTupleElement(fusion2, 1)); + EXPECT_THAT(fusion3->operand(1), op::GetTupleElement(fusion2, 2)); + if (allocate_across_sequential_calls()) { + EXPECT_THAT(fusion2->operand(2)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT(fusion2->operand(3)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT(fusion3->operand(0)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT(fusion3->operand(1)->shape().layout().memory_space(), + kAlternateMemorySpace); + } + // Operand 0 and 1 should get alternate memory allocations and so is the + // output {0}. + EXPECT_THAT(fusion2->operand(0)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT(fusion2->operand(1)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT( + ShapeUtil::GetSubshape(fusion2->shape(), {0}).layout().memory_space(), + kAlternateMemorySpace); +} + +TEST_P(MemorySpaceAssignmentTest, AsyncOpCustomFusionLongLiveRange) { + absl::string_view hlo_string = R"( +HloModule Module, is_scheduled=true + +fused_computation_start { + param0 = f32[2,1] parameter(0) + negate = f32[2,1] negate(param0) + ROOT custom-call = (f32[2,1], f32[2,1], u32[], u32[]) custom-call(negate), custom_call_target="AsyncOpStart" +} + +fused_computation_update { + param0 = f32[2,1] parameter(0) + param1 = f32[2,1] parameter(1) + param2 = f32[2,1] parameter(2) + param3 = f32[2,1] parameter(3) + param4 = u32[] parameter(4) + param5 = u32[] parameter(5) + add = f32[2,1] add(param0, param1) + negate = f32[2,1] negate(param2) + ROOT tuple = (f32[2,1], f32[2,1], f32[2,1], f32[2,1], u32[], u32[]) tuple(add, param2, param3, negate, param4, param5) +} + +fused_computation_done { + param0 = f32[2,1] parameter(0) + param1 = f32[2,1] parameter(1) + param2 = u32[] parameter(2) + param3 = u32[] parameter(3) + negate = f32[2,1] negate(param0) + ROOT custom-call = f32[2,1] custom-call(param0, param1, negate, param2, param3), custom_call_target="AsyncOpDone" +} + +ENTRY main { + param = f32[2,1] parameter(0) + negate1 = f32[2,1] negate(param) + negate2 = f32[2,1] negate(negate1) + fusion1 = (f32[2,1], f32[2,1], u32[], u32[]) fusion(negate1), kind=kCustom, output_to_operand_aliasing={{0}: (0, {})}, calls=fused_computation_start + negate3 = f32[2,1] negate(negate2) + negate4 = f32[2,1] negate(negate3) + negate5 = f32[2,1] negate(negate4) + negate6 = f32[2,1] negate(negate5) + negate7 = f32[2,1] negate(negate6) + negate8 = f32[2,1] negate(negate7) + negate9 = f32[2,1] negate(negate8) + negate10 = f32[2,1] negate(negate9) + negate11 = f32[2,1] negate(negate10) + negate12 = f32[2,1] negate(negate11) + gte0 = f32[2,1] get-tuple-element(fusion1), index=0 + gte1 = f32[2,1] get-tuple-element(fusion1), index=1 + gte2 = u32[] get-tuple-element(fusion1), index=2 + gte3 = u32[] get-tuple-element(fusion1), index=3 + fusion2 = (f32[2,1], f32[2,1], f32[2,1], f32[2,1], u32[], u32[]) fusion(negate12, negate2, gte0, gte1, gte2, gte3), kind=kLoop, output_to_operand_aliasing={{1}: (2, {}), {2}: (3, {}), {3}: (3, {}), {4}: (4, {}), {5}: (5, {})}, calls=fused_computation_update + gte4 = f32[2,1] get-tuple-element(fusion2), index=0 + negate13 = f32[2,1] negate(gte4) + gte5 = f32[2,1] get-tuple-element(fusion2), index=1 + gte6 = f32[2,1] get-tuple-element(fusion2), index=2 + gte7 = u32[] get-tuple-element(fusion2), index=4 + gte8 = u32[] get-tuple-element(fusion2), index=5 + fusion3 = f32[2,1] fusion(gte5, gte6, gte7, gte8), kind=kCustom, output_to_operand_aliasing={{}: (1, {})}, calls=fused_computation_done + ROOT add = f32[2,1] add(negate13, fusion3) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.position_requires_contiguous_allocation_fn = + [](const HloPosition& position) { + std::string_view inst_name = position.instruction->name(); + if (inst_name == "fusion1" || + (inst_name == "fusion2" && position.index != ShapeIndex({0}))) { + return true; + } + return false; + }; + AssignMemorySpace(module.get(), options); + + HloInstruction* fusion1 = + module->entry_computation()->GetInstructionWithName("fusion1"); + HloInstruction* fusion2 = + module->entry_computation()->GetInstructionWithName("fusion2"); + HloInstruction* fusion3 = + module->entry_computation()->GetInstructionWithName("fusion3"); + EXPECT_THAT(fusion2->operand(2), op::GetTupleElement(fusion1, 0)); + EXPECT_THAT(fusion2->operand(2)->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_THAT(fusion2->operand(3), op::GetTupleElement(fusion1, 1)); + EXPECT_THAT(fusion2->operand(3)->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_THAT(fusion3->operand(0), op::GetTupleElement(fusion2, 1)); + EXPECT_THAT(fusion3->operand(0)->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_THAT(fusion3->operand(1), op::GetTupleElement(fusion2, 2)); + EXPECT_THAT(fusion3->operand(1)->shape().layout().memory_space(), + kDefaultMemorySpace); + // Operand 0 and 1 should get alternate memory allocations and so is the + // output {0}. + EXPECT_THAT(fusion2->operand(0)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT(fusion2->operand(1)->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_THAT( + ShapeUtil::GetSubshape(fusion2->shape(), {0}).layout().memory_space(), + kAlternateMemorySpace); +} + +// This test seeks to test that MSA will schedule async copy operations with +// schedule_after=-1 at the very beginning of the program. +// +// The machinery for this is a little opaque from the public API, so we attempt +// to get MSA to self-assign an async copies with schedule_after=-1 by +// exploiting how the hidden algorithm works. This is brittle and subject to +// inadvertent breakage in the future. +TEST_P(MemorySpaceAssignmentTest, HoistCopyStart) { + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY cross_program_prefetch { + p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0) + get-tuple-element.0 = f32[8,8]{1,0} get-tuple-element(p0), index=0 + add.0 = f32[8,8]{1,0} add(get-tuple-element.0, get-tuple-element.0) + get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1 + dot.0 = f32[8,2]{1,0} dot(add.0, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot.0) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT dot.1 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.enable_cross_program_prefetch = true; + AssignMemorySpace(module.get(), options); + + // Ensure that get-tuple-element.1 is chosen for cross-program prefetch. + auto cross_program_prefetches = module->CrossProgramPrefetches(); + ASSERT_EQ(cross_program_prefetches.size(), 1); + ASSERT_EQ(cross_program_prefetches[0].parameter, 0); + ASSERT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); + + // Check that the async copy-start for get-tuple-element.1 is hoisted + // after MSA (get-tuple-element.1 was initially the third operation of the + // original schedule). + // + // We expect the only instructions before it are declaring parameter(0) and + // get-tuple-element.1. + for (auto* instruction : module->schedule() + .sequence(module->entry_computation()) + .instructions()) { + auto p0 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(p0, 1); + auto copy_start = op::CopyStart(get_tuple_element_1); + EXPECT_THAT(instruction, AnyOf(p0, get_tuple_element_1, copy_start)); + if (::testing::Matches(copy_start)(instruction)) { + EXPECT_TRUE(instruction->cross_program_prefetch_index().has_value()); + break; + } + } +} + +INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, + MemorySpaceAssignmentTest, + ::testing::Values(false, true)); + +using AsynchronousCopyOrderingTest = ::testing::Test; + +TEST_F(AsynchronousCopyOrderingTest, Simple) { + // Given asynchronous copies like the following, ensure the pipelining order + // is maintained (earlier start time must have earlier end time). + // 3,11 +-------+ OK + // 1,8 +------+ OK + // 5,14 +--------+ OK + // 7,14 +------+ OK + // 2,16 +-------------+ Violate + // 9,12 +--+ Violate + // 6,17 +----------+ Violate + // 5,13 +-------+ OK (same start as 5,14) + // 5,14 +--------+ OK (same as 5,14) + auto alternate_mem_space = MemorySpace::kAlternate; + AsynchronousCopyOrdering ordering; + EXPECT_FALSE(ordering.ViolatesOrdering(3, 11)); + ordering.AddCopy({3, 11, 1, alternate_mem_space, 0}); + EXPECT_FALSE(ordering.ViolatesOrdering(1, 8)); + ordering.AddCopy({1, 8, 1, alternate_mem_space, 1}); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); + ordering.AddCopy({5, 14, 1, alternate_mem_space, 2}); + EXPECT_FALSE(ordering.ViolatesOrdering(7, 14)); + ordering.AddCopy({7, 14, 1, alternate_mem_space, 3}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 16)); + EXPECT_TRUE(ordering.ViolatesOrdering(9, 12)); + EXPECT_TRUE(ordering.ViolatesOrdering(6, 17)); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 13)); + ordering.AddCopy({5, 13, 1, alternate_mem_space, 4}); + EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); + ordering.AddCopy({5, 14, 1, alternate_mem_space, 5}); +} + +TEST_F(AsynchronousCopyOrderingTest, SameInterval) { + auto alternate_mem_space = MemorySpace::kAlternate; + AsynchronousCopyOrdering ordering; + EXPECT_FALSE(ordering.ViolatesOrdering(1, 5)); + EXPECT_FALSE(ordering.ViolatesOrdering(2, 4)); + ordering.AddCopy({1, 5, 1, alternate_mem_space, 0}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); + ordering.AddCopy({1, 5, 1, alternate_mem_space, 1}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); + ordering.AddCopy({1, 5, 1, alternate_mem_space, 2}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); + ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 1}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); + ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 2}); + EXPECT_TRUE(ordering.ViolatesOrdering(2, 4)); + ordering.RemoveCopy({1, 5, 1, alternate_mem_space, 0}); + EXPECT_FALSE(ordering.ViolatesOrdering(2, 4)); +} + +using AsynchronousCopyResourceTest = ::testing::Test; + +TEST_F(AsynchronousCopyResourceTest, Simple) { + // time: 0 1 2 3 4 5 6 7 8 9 // resource: 2 3 1 6 7 1 7 2 2 4 // -1,3,5 +-----+ OK // resource: 0 0 1 6 7 1 7 2 2 4 @@ -7416,7 +8016,7 @@ TEST_F(AsynchronousCopyResourceTest, Simple) { // 4,9,3 +-------+ Violate // 4,8,2 +-----+ OK; The 5,9 copy shifts resource to right. // resource: 0 0 0 3 7 0 0 0 0 4 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 3.0, 1.0, 6.0, 7.0, 1.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -7450,7 +8050,7 @@ TEST_F(AsynchronousCopyResourceTest, Propagate) { // 0,4,3 +-----+ OK // resource: 2 0 0 0 0 0 0 0 0 0 // 0,4,1 +-----+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0)); @@ -7494,7 +8094,7 @@ TEST_F(AsynchronousCopyResourceTest, CantPropagate) { // 4,8,4 +-----+ OK // resource: 2 2 2 2 2 0 0 0 0 2 // 3,6,4 +---+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(5, 10, 2.0)); @@ -7521,7 +8121,7 @@ TEST_F(AsynchronousCopyResourceTest, Nested) { // 1,3,2 +-+ OK // resource: 2 2 0 2 2 // 0,4,4 +-----+ Violate - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0)); resource.AddCopy({1, 3, 2.0, alternate_mem_space, 0}); @@ -7545,7 +8145,7 @@ TEST_F(AsynchronousCopyResourceTest, Remove) { // resource: 0 1 2 2 2 // rem:-1,2,3+---+ // resource: 2 2 2 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{2, 5, 2.0, alternate_mem_space, 0}; AsynchronousCopy copy2{-1, 2, 3.0, alternate_mem_space, 1}; @@ -7588,7 +8188,7 @@ TEST_F(AsynchronousCopyResourceTest, NestedRemove) { // resource: 2 2 2 2 2 // add:1,3,2 +-+ OK // resource: 2 2 0 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{1, 3, 2.0, alternate_mem_space, 0}; AsynchronousCopy copy2{0, 4, 4.0, alternate_mem_space, 1}; @@ -7635,7 +8235,7 @@ TEST_F(AsynchronousCopyResourceTest, PropagateRemove) { // resource: 2 0 0 0 0 0 0 0 1 2 // rem:0,4,3 +-----+ // resource: 2 2 0 0 0 0 0 0 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0)); @@ -7687,7 +8287,7 @@ TEST_F(AsynchronousCopyResourceTest, StartAtZeroAndRemove) { // resource: 0 0 1 1 2 // add:0,4,2 +-----+ OK // resource: 0 0 0 0 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({0.0, 0.0, 1.0, 1.0, 2.0}); AsynchronousCopy copy1{0, 4, 2.0, alternate_mem_space, 0}; EXPECT_TRUE(resource.HasEnoughResource(0, 4, 2.0)); @@ -7730,7 +8330,7 @@ TEST_F(AsynchronousCopyResourceTest, OutOfOrderRemovalSameStartTime) { // resource: 2 2 1 2 2 // rem:1,5,1 +-----+ // resource: 2 2 2 2 2 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0}); AsynchronousCopy copy1{1, 3, 1.0, alternate_mem_space, 0}; AsynchronousCopy copy2{1, 4, 2.0, alternate_mem_space, 1}; @@ -7795,7 +8395,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckSuccess) { // 0,6,4 +-----------+ // 4,6,3 +-+ 2 copies OK; The 1,10 copy shifts. // resource: 0 0 0 0 6 0 7 2 2 4 - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 1.0, 3.0, 6.0, 7.0, 3.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -7823,7 +8423,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckFailure) { // resource: 0 0 0 3 7 3 7 2 2 4 // 0,6,4 +-----------+ // 4,6,4 +-+ Not-OK - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource( {2.0, 1.0, 3.0, 6.0, 7.0, 3.0, 7.0, 2.0, 2.0, 4.0}); EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0)); @@ -7840,7 +8440,7 @@ TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckFailure) { TEST_F(AsynchronousCopyResourceTest, HasEnoughResourceMultiCheckRegressionTest) { - auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; + auto alternate_mem_space = MemorySpace::kAlternate; AsynchronousCopyResource resource({/*0:*/ 24.0f, /*1:*/ 0.0f, /*2:*/ 6.0f, @@ -8354,7 +8954,8 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) { auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); auto rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {kFeature, kOutput}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); HloInstruction* lhs = builder.AddInstruction( @@ -8396,7 +8997,8 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTupleTest) { auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); auto rhs_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {kFeature, kOutput}, - /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); @@ -8454,6 +9056,35 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupMayAlias) { op::Parameter(0)); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDusFusionMayAlias) { + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) } + fused_computation { + fused_p0 = s32[2,2] parameter(0) + fused_p1 = s32[1,2] parameter(1) + fused_p2 = s32[] parameter(2) + fused_p3 = s32[] parameter(3) + ROOT dus = s32[2,2] dynamic-update-slice(fused_p0, fused_p1, fused_p2, fused_p3) + } + + ENTRY CrossProgramPrefetch { + p0 = s32[2,2] parameter(0) + c0 = s32[1,2] constant({{77, 77}}) + c1 = s32[] constant(0) + bitcast1 = s32[2,2] bitcast(p0) + ROOT fusion = s32[2,2] fusion(bitcast1, c0, c1, c1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpace( + module.get(), DefaultMemorySpaceOptions(), + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDup) { absl::string_view hlo_string = R"( HloModule cross_program_prefetch, is_scheduled=true @@ -8518,19 +9149,15 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDotMayAlias) { /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches.size(), 0); EXPECT_THAT(FindInstruction(module.get(), "dot")->operand(1), - op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, - op::Parameter(0))); + op::Parameter(0)); } TEST_P(MemorySpaceAssignmentTest, CrossProgramRootLiveOutBug) { - // An in-place fusion that lives out should not be included as a use to the - // cross-program prefetch allocation. Due to a bug, we considered in-place - // update that feeds the ROOT of the entry computation as a valid use of the - // cross-program prefetch. This then would cause this live-out buffer to be - // placed in the alternate memory. We expect p0 to be cross-program prefetched - // but only for the dot operand and not the fusion operand. + // Input-output aliased buffers should not be cross-program prefetched since + // the update on the buffer will not be reflected on the next program + // execution (the data in the alternate memory would be stale). absl::string_view hlo_string = R"( HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias) } fused_computation { @@ -8556,12 +9183,7 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramRootLiveOutBug) { /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 1); - EXPECT_THAT(FindInstruction(module.get(), "dot")->operand(1), - op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, - op::Parameter(0))); - EXPECT_THAT(FindInstruction(module.get(), "fusion")->operand(0), - op::Parameter(0)); + EXPECT_EQ(cross_program_prefetches.size(), 0); } TEST_P(MemorySpaceAssignmentTest, CrossProgramRootParameter) { @@ -8822,1832 +9444,215 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleReuse) { ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0} } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AssignMemorySpace(module.get(), DefaultMemorySpaceOptions(), - /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); - - auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 0); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); - } - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr dataflow_analysis, - HloDataflowAnalysis::Run(*module)); - const HloValue& cross_program_prefetched_value = - dataflow_analysis->GetValueDefinedAt( - module->entry_computation()->parameter_instruction(0), {1}); - // Expect that there is one prefetch that use this value, the cross-program - // prefetch. There shouldn't be an end-of-program prefetch. - auto is_cross_program_prefetch = [](const HloUse& use) { - return use.instruction->opcode() == HloOpcode::kCopyStart && - use.instruction->cross_program_prefetch_index().has_value(); - }; - EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), - is_cross_program_prefetch), - 1); - auto is_end_of_program_prefetch = [](const HloUse& use) { - return use.instruction->opcode() == HloOpcode::kCopyStart && - !use.instruction->cross_program_prefetch_index().has_value(); - }; - EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), - is_end_of_program_prefetch), - 0); -} - -TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBufferUnused) { - absl::string_view hlo_string = R"( -HloModule module, is_scheduled=true - -%fused_computation { - %param_0.2 = f32[32]{0} parameter(0) - %param_1.4 = s32[100]{0} parameter(1) - %custom-call.1 = s32[100]{0} custom-call(s32[100]{0} %param_1.4), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[100]{0}} - %slice.1 = s32[32]{0} slice(s32[100]{0} %custom-call.1), slice={[0:32]} - %reshape.7 = s32[32]{0} reshape(s32[32]{0} %slice.1) - %transpose.5 = s32[32]{0} transpose(s32[32]{0} %reshape.7), dimensions={0} - %gather.1 = f32[32]{0} gather(f32[32]{0} %param_0.2, s32[32]{0} %transpose.5), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1} - %transpose.4 = f32[32]{0} transpose(f32[32]{0} %gather.1), dimensions={0} - ROOT %reshape.6 = f32[32]{0} reshape(f32[32]{0} %transpose.4) -} - -%i.reduce_sub_computation { - %rhs = s32[] parameter(1) - %lhs = s32[] parameter(0) - ROOT %add = s32[] add(s32[] %lhs, s32[] %rhs) -} - -%fused_computation.1 { - %constant.4 = s32[] constant(0) - %broadcast.4 = s32[100]{0} broadcast(s32[] %constant.4), dimensions={} - %param_0.4 = s32[32]{0} parameter(0) - %pad.1 = s32[100]{0} pad(s32[32]{0} %param_0.4, s32[] %constant.4), padding=0_68 - %constant.3 = s32[] constant(76031) - %broadcast.3 = s32[100]{0} broadcast(s32[] %constant.3), dimensions={} - ROOT %clamp.1 = s32[100]{0} clamp(s32[100]{0} %broadcast.4, s32[100]{0} %pad.1, s32[100]{0} %broadcast.3) -} - -ENTRY %main { - %constant = s32[] constant(0) - %i = s32[32,1]{0,1} parameter(1) - %o = f32[32]{0} parameter(0) - %reduce = s32[32]{0} reduce(s32[32,1]{0,1} %i, s32[] %constant), dimensions={1}, to_apply=%i.reduce_sub_computation - %fusion.1 = s32[100]{0} fusion(s32[32]{0} %reduce), kind=kLoop, calls=%fused_computation.1 - ROOT %fusion = f32[32]{0} fusion(f32[32]{0} %o, s32[100]{0} %fusion.1), kind=kCustom, calls=%fused_computation -} - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - AssignMemorySpace(module.get()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Fusion(op::AsyncCopy(kAlternateMemorySpace, - kDefaultMemorySpace, op::Parameter(0)), - op::Fusion())); -} - -// Test description: -// - Setup: Make sure p1 can not be prefetched to alternate memory until after -// instruction c. We do this by causing p0 to be prefetched to alternate -// memory for use in c. Since p0 is larger than 1/2 of alternate memory, we -// will not be able to prefetch p1 until after p0 is unallocated. -// - Test: prefetch p1, after p0 is unallocated from alternate memory (after -// instruction c). -TEST_P(MemorySpaceAssignmentTest, CopyResourceIntegration) { - std::string_view hlo_string = R"( -HloModule module, is_scheduled=true - -ENTRY main { - p0 = s32[8,8] parameter(0) - p1 = s32[8,8] parameter(1) - p2 = s32[] parameter(2) - a = negate(p2) - b = negate(a) - c = add(p0, p0) - d = negate(b) - e = negate(d) - f = add(p1, p1) - - ROOT result = tuple(e,c,f) -} - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - Options options = DefaultMemorySpaceOptions(); - options.max_size_in_bytes = 300; - - // Setup cost analysis so it takes 2 instructions to prefetch anything. - HloCostAnalysis hlo_cost_analysis(ShapeSize); - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - cost_analysis->SetOverrideForGetInstructionElapsed( - [](const HloInstruction& instruction) -> float { return 10.0; }); - cost_analysis->SetOverrideForGetAsyncCopyElapsed( - [](const Shape& shape) -> float { return 20.0; }); - options.cost_analysis = cost_analysis.get(); - CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( - CostAnalysisPrefetchIntervalPicker( - *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8, - /*preferred_overlap_to_async_copy_ratio=*/1.5, - /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, - /*mem_size_bytes=*/options.max_size_in_bytes)); - - // p0 has the highest priority, followed by p1, followed by everything else. - MemorySpaceAssignment::BufferIntervalCompare compare = - [](const MemorySpaceAssignment::BufferInterval& lhs, - const MemorySpaceAssignment::BufferInterval& rhs) -> bool { - auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) { - // An arbitrary value that is greater than that for p0 and p1. - int priority = 100; - if (x.buffer->instruction()->name() == "p0") { - priority = 0; - } else if (x.buffer->instruction()->name() == "p1") { - priority = 1; - } - return std::make_tuple(priority, x.buffer->instruction()->name()); - }; - - return lookup(lhs) < lookup(rhs); - }; - - // Run test. - AssignMemorySpace(module.get(), options, compare, &prefetch_interval_picker); - - // - Make sure the setup occurred, i.e., that p0 is prefetched to alternate - // memory for use by c. - // - Make sure p1 is prefetched. - ASSERT_THAT( - module->entry_computation()->root_instruction(), - op::Tuple(_, - // p0 is prefetched to alternate memory for use by c. - op::Add(op::AsyncCopy(kAlternateMemorySpace, - kDefaultMemorySpace, op::Parameter(0)), - op::AsyncCopy(kAlternateMemorySpace, - kDefaultMemorySpace, op::Parameter(0))), - // p1 is prefetched to alternate memory for use by f. - op::Add(op::AsyncCopy(kAlternateMemorySpace, - kDefaultMemorySpace, op::Parameter(1)), - op::AsyncCopy(kAlternateMemorySpace, - kDefaultMemorySpace, op::Parameter(1))))); - - // Check the schedule - const std::vector& schedule = - module->schedule().sequence(module->entry_computation()).instructions(); - auto find_schedule_index = [&schedule](std::string_view name) -> int { - for (int i = 0; i < schedule.size(); ++i) { - if (schedule[i]->name() == name) { - return i; - } - } - LOG(FATAL) << "Unable to find index of instruction with name " << name; - }; - int c_index = find_schedule_index("c"); - int p1_copy_start = find_schedule_index(module->entry_computation() - ->root_instruction() // result - ->operand(2) // f - ->operand(0) // copy done - ->operand(0) // copy start - ->name()); - int d_index = find_schedule_index("d"); - int e_index = find_schedule_index("e"); - int p1_copy_end = find_schedule_index(module->entry_computation() - ->root_instruction() // result - ->operand(2) // f - ->operand(0) // copy done - ->name()); - int f_index = find_schedule_index("f"); - // We expect to start copying p1 after c. - EXPECT_EQ(p1_copy_start, c_index + 1); - // d and e should follow come between p1's copy start and end. - EXPECT_EQ(d_index, p1_copy_start + 1); - EXPECT_EQ(e_index, d_index + 1); - EXPECT_EQ(p1_copy_end, e_index + 1); - // f should immediately follow the end of p1's copy. - EXPECT_EQ(f_index, p1_copy_end + 1); -} - -using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - a = f32[2,4] negate(param0) - b = f32[2,4] negate(a) - c = f32[2,4] negate(b) - d = f32[2,4] negate(c) - e = f32[2,4] negate(d) - f = f32[2,4] negate(e) - g = f32[2,4] negate(f) - h = f32[2,4] negate(g) - i = f32[2,4] negate(h) - j = f32[2,4] negate(i) - k = f32[2,4] negate(j) - l = f32[2,4] negate(k) - m = f32[2,4] negate(l) - n = f32[2,4] negate(m) - o = f32[2,4] negate(n) - p = f32[2,4] negate(o) - q = f32[2,4] negate(p) - r = f32[2,4] negate(q) - s = f32[2,4] negate(r) - t = f32[2,4] negate(s) - u = f32[2,4] negate(t) - ROOT v = f32[2,4] add(u, param0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/4.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22, std::nullopt); - - // Expect that the first interval is (15, 22), which has elapsed time of 6.0, - // twice of the async copy elased (3.0). Then we expect that intervals will be - // visited in alternating increasing and decreasing orders until hitting the - // min and max async copy overlap ratios, which are the intervals (18, 22) - // and (9, 22) respectively. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 15); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 16); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 14); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 17); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 13); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 12); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 11); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 10); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); - - // Expect that if the time between start_time and end_time is too short, there - // won't be any available intervals. - interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22, std::nullopt); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - while_condition { - param1 = (f32[2,4]) parameter(0) // 19 - ROOT cond = pred[] constant(true) // 20 - } - - while_body { - param2 = (f32[2,4]) parameter(0) // 21 - gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 - add = f32[2,4] add(gte2, gte2) // 23 - ROOT tuple2 = (f32[2,4]) tuple(add) // 24 - } - - ENTRY Entry { - param0 = f32[2,4] parameter(0) // 0 - a = f32[2,4] negate(param0) // 1 - b = f32[2,4] negate(a) // 2 - c = f32[2,4] negate(b) // 3 - d = f32[2,4] negate(c) // 4 - e = f32[2,4] negate(d) // 5 - f = f32[2,4] negate(e) // 6 - g = f32[2,4] negate(f) // 7 - h = f32[2,4] negate(g) // 8 - i = f32[2,4] negate(h) // 9 - j = f32[2,4] negate(i) // 10 - k = f32[2,4] negate(j) // 11 - l = f32[2,4] negate(k) // 12 - m = f32[2,4] negate(l) // 13 - n = f32[2,4] negate(m) // 14 - o = f32[2,4] negate(n) // 15 - p = f32[2,4] negate(o) // 16 - q = f32[2,4] negate(p) // 17 - tuple = (f32[2,4]) tuple(q) // 18 - while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 - gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 - r = f32[2,4] negate(gte1) // 27 - s = f32[2,4] negate(r) // 28 - t = f32[2,4] negate(s) // 29 - u = f32[2,4] negate(t) // 30 - ROOT v = f32[2,4] add(u, param0) // 31 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - EXPECT_EQ(cost_analysis->options() - .xla_tpu_memory_space_assignment_while_execution_count, - 5); - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31, std::nullopt); - - // Because there are while loop computations between [19, 24], we ensure that - // the interval picker avoids this interval. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 25); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 26); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 18); - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_TRUE(interval_picker.Done()); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { - // This test is to check against a bug where we didn't assign - // while_nest_level_ for while instructions, and defaulting to 0. This could - // cause the prefetch interval logic to think a nested while instruction is - // the same level as the outermost computation. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - while_condition.2 { - param1 = (f32[2,4]) parameter(0) // 11 - ROOT cond = pred[] constant(true) // 12 - } - - while_body.2 { - param2 = (f32[2,4]) parameter(0) // 13 - gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14 - add = f32[2,4] add(gte2, gte2) // 15 - ROOT tuple2 = (f32[2,4]) tuple(add) // 16 - } - - while_condition.1 { - param3 = (f32[2,4]) parameter(0) // 5 - ROOT cond = pred[] constant(true) // 6 - } - - while_body.1 { - param4 = (f32[2,4]) parameter(0) // 7 - gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8 - add1 = f32[2,4] add(gte1, gte1) // 9 - tuple1 = (f32[2,4]) tuple(add1) // 10 - while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17 - gte2 = f32[2,4] get-tuple-element(while), index=0 // 18 - add2 = f32[2,4] add(gte2, gte2) // 19 - ROOT tuple2 = (f32[2,4]) tuple(add2) // 20 - } - - ENTRY Entry { - param0 = f32[2,4] parameter(0) // 0 - a = f32[2,4] negate(param0) // 1 - b = f32[2,4] negate(a) // 2 - c = f32[2,4] negate(b) // 3 - tuple = (f32[2,4]) tuple(c) // 4 - while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21 - gte1 = f32[2,4] get-tuple-element(while), index=0 // 22 - ROOT root = f32[2,4] add(gte1, param0) // 23 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - const Shape& shape = root->operand(1)->shape(); - - // We expect the root's latest prefetch start time to be before the while loop - // (logical time 4). - EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, - /*end_time=*/23, &use), - 4); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { - // This is a test for b/170668492, where prefetching for consecutive - // conditionals can cause the prefetch to start in the conditional's - // computation. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - true_computation.0 { - p0 = (f32[3]{0}) parameter(0) // 5 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 - ROOT neg1 = f32[3]{0} negate(gte) // 7 - } - - false_computation.0 { - p0 = (f32[3]{0}) parameter(0) // 8 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 - ROOT neg2 = f32[3]{0} negate(gte) // 10 - } - - true_computation.1 { - p0 = (f32[3]{0}) parameter(0) // 12 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 - ROOT neg1 = f32[3]{0} negate(gte) // 14 - } - - false_computation.1 { - p0 = (f32[3]{0}) parameter(0) // 15 - gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 - ROOT neg2 = f32[3]{0} negate(gte) // 17 - } - - ENTRY entry { - p0 = f32[3]{0} parameter(0) // 0 - p1 = f32[3]{0} parameter(1) // 1 - p2 = pred[] parameter(2) // 2 - tuple0 = (f32[3]{0}) tuple(p0) // 3 - tuple1 = (f32[3]{0}) tuple(p1) // 4 - conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 - conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 - ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - LOG(INFO) << module->ToString(); - - HloInstruction* conditional1 = - module->entry_computation()->GetInstructionWithName("conditional1"); - const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; - const Shape& shape = - module->entry_computation()->parameter_instruction(0)->shape(); - - // Expect that the prefetch to start before conditional0's called - // computations. - EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, - /*end_time=*/11, &use), - 5); -} - -TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { - // This tests the scenario where there is an op that takes a long time (tanh - // in this example) and as a result the earliest and latest times both fall - // inside this long-running op. In this case, we should still return a valid - // prefetch interval just before the long-running op. - absl::string_view hlo_string = R"( - HloModule bug, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - negate = f32[2,4] negate(param0) - tanh = f32[2,4] tanh(param0) - ROOT add = f32[2,4] add(tanh, negate) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - HloCostAnalysis hlo_cost_analysis(ShapeSize); - Options options; - TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, - FakeMemorySpaceAssignmentCostAnalysis::Create( - hlo_cost_analysis, *module, options)); - cost_analysis->SetOverrideForGetInstructionElapsed( - [](const HloInstruction& hlo) { - if (hlo.opcode() == HloOpcode::kTanh) { - return 20.0; - } - return 1.0; - }); - CostAnalysisPrefetchIntervalPicker interval_picker( - *cost_analysis, - /*min_overlap_to_async_copy_ratio=*/1.0, - /*preferred_overlap_to_async_copy_ratio=*/2.0, - /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, - /*mem_size_bytes=*/32); - - HloInstruction* root = module->entry_computation()->root_instruction(); - const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; - interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3, std::nullopt); - - LOG(INFO) << interval_picker.ToDebugString(); - EXPECT_FALSE(interval_picker.Done()); - EXPECT_EQ(interval_picker.Next(), 1); - EXPECT_TRUE(interval_picker.Done()); -} - -class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { - protected: - Status Initialize(const HloModule* module, - float pipeline_overhead_window_size_mib = 0.0) { - HloCostAnalysis::Options options; - options_.alternate_mem_bandwidth_bytes_per_second = 128; - options_.async_copy_bandwidth_bytes_per_second = 32; - options_.pipeline_overhead_window_size_mib = - pipeline_overhead_window_size_mib; - options.shape_size = ShapeSize; - options.set_flops_per_second(8); - options.set_bytes_per_second(32); - options.set_transcendentals_per_second(16); - hlo_cost_analysis_ = std::make_unique(options); - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); - TF_ASSIGN_OR_RETURN(cost_analysis_, - MemorySpaceAssignmentCostAnalysis::Create( - *hlo_cost_analysis_, options_, *module)); - return OkStatus(); - } - - Options options_; - std::unique_ptr hlo_cost_analysis_; - std::unique_ptr cost_analysis_; -}; - -TEST_F(MemorySpaceAssignmentCostAnalysisTest, NoPipelineOverhead) { - absl::string_view hlo_string = R"( - HloModule module, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - param1 = f32[2,4] parameter(1) - ROOT add = f32[2,4] add(param0, param1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK(Initialize(module.get())); - - const HloInstruction* add = module->entry_computation()->root_instruction(); - const float expected_compute_elapsed = - /*num_flops=*/8 / /*flops_per_second=*/8.0; - LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), - expected_compute_elapsed); - float expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), - expected_memory_elapsed); - - // This HLO is memory-bound. - EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), - expected_memory_elapsed); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), - expected_memory_elapsed); - - // Put operand 0 in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {}), - expected_memory_elapsed); - - // Put operand 0 and output in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {{}}), - expected_memory_elapsed); - - // Put everything in alternate memory. We're now compute bound. - expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_compute_elapsed); -} - -TEST_F(MemorySpaceAssignmentCostAnalysisTest, PipelineOverhead) { - absl::string_view hlo_string = R"( - HloModule module, is_scheduled=true - - ENTRY Entry { - param0 = f32[2,4] parameter(0) - param1 = f32[2,4] parameter(1) - ROOT add = f32[2,4] add(param0, param1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - // Set the window size 64B. - TF_ASSERT_OK( - Initialize(module.get(), - /*pipeline_overhead_window_size_mib=*/(64.0 / 1024 / 1024))); - - const HloInstruction* add = module->entry_computation()->root_instruction(); - const float expected_compute_elapsed = - /*num_flops=*/8 / /*flops_per_second=*/8.0; - LOG(INFO) << "Expected compute elapsed = " << expected_compute_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToCompute(*add), - expected_compute_elapsed); - float expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/32.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add), - expected_memory_elapsed); - - float expected_overhead = expected_compute_elapsed * 2 / 3; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add), - expected_overhead); - // This HLO is memory-bound. - EXPECT_EQ(cost_analysis_->GetInstructionElapsed(*add), - expected_memory_elapsed + expected_overhead); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedInAlternateMemory(*add, {}, {}), - expected_memory_elapsed + expected_overhead); - - // Put operand 0 in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}), - expected_overhead); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {}), - expected_memory_elapsed + expected_overhead); - - // Put operand 0 and output in alternate memory. Still memory bound. - expected_memory_elapsed = - (/*bytes_accessed=*/(4 * 8) / /*bytes_per_second=*/32.0) + - (/*bytes_accessed=*/(2 * 4 * 8) / /*bytes_per_second=*/128.0); - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - expected_overhead = expected_compute_elapsed / 3; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ( - cost_analysis_->GetDefaultMemoryAccessOverhead(*add, {{0, {}}}, {{}}), - expected_overhead); - EXPECT_EQ( - cost_analysis_->GetInstructionElapsedDueToMemory(*add, {{0, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}}, {{}}), - expected_memory_elapsed + expected_overhead); - - // Put everything in alternate memory. We're now compute bound. - expected_memory_elapsed = - /*bytes_accessed=*/(3 * 4 * 8) / /*bytes_per_second=*/128.0; - LOG(INFO) << "Expected memory elapsed = " << expected_memory_elapsed; - expected_overhead = 0; - LOG(INFO) << "Expected overhead = " << expected_overhead; - EXPECT_EQ(cost_analysis_->GetDefaultMemoryAccessOverhead( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_overhead); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedDueToMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_memory_elapsed); - EXPECT_EQ(cost_analysis_->GetInstructionElapsedInAlternateMemory( - *add, {{0, {}}, {1, {}}}, {{}}), - expected_compute_elapsed); -} - -class MemoryBoundLoopOptimizerTest : public HloTestBase { - public: - MemoryBoundLoopOptimizerTest() = default; - - protected: - const int64_t kAlternateMemorySpace = 1; - const int64_t kDefaultMemorySpace = 0; - - Status Initialize(const HloModule* module, - uint64_t alternate_memory_size = 256) { - HloCostAnalysis::Options options; - MemoryBoundLoopOptimizerOptions optimizer_options; - optimizer_options.set_enabled(true); - optimizer_options.set_desired_copy_ratio(0.7); - optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); - optimizer_options.set_min_num_iterations(3.0); - options_.memory_bound_loop_optimizer_options = optimizer_options; - options_.alternate_mem_bandwidth_bytes_per_second = 128; - options_.async_copy_bandwidth_bytes_per_second = 32; - options_.pipeline_overhead_window_size_mib = 1; - options.shape_size = ShapeSize; - options.set_flops_per_second(16); - options.set_bytes_per_second(32); - options.set_transcendentals_per_second(16); - hlo_cost_analysis_ = std::make_unique(options); - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); - TF_ASSIGN_OR_RETURN(cost_analysis_, - MemorySpaceAssignmentCostAnalysis::Create( - *hlo_cost_analysis_, options_, *module)); - TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); - TF_ASSIGN_OR_RETURN(live_range_, - HloLiveRange::Run(module->schedule(), *alias_analysis_, - module->entry_computation())); - return OkStatus(); - } - - StatusOr CreateOptimizer( - int loop_start, int loop_end, const HloModule* module, - uint64_t alternate_memory_size = 256) { - TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size)); - MemoryBoundLoopOptimizerOptions optimizer_options; - optimizer_options.set_enabled(true); - optimizer_options.set_desired_copy_ratio(0.7); - optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); - TF_ASSIGN_OR_RETURN( - optimizer_, - MemoryBoundLoopOptimizer::Create( - loop_start, loop_end, alternate_memory_size, optimizer_options, - *live_range_, *alias_analysis_, *cost_analysis_, SizeFunction)); - return optimizer_.get(); - } - - StatusOr> ParseAndCreateOptimizer( - absl::string_view hlo_loop_str, uint64_t alternate_memory_size, - int& loop_start_idx, MemoryBoundLoopOptimizer** optimizer) { - int loop_end_idx; - TF_ASSIGN_OR_RETURN( - std::string module_str, - ParseAndCreateModuleString(hlo_loop_str, loop_start_idx, loop_end_idx)); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); - TF_ASSIGN_OR_RETURN( - *optimizer, CreateOptimizer(loop_start_idx, loop_end_idx, module.get(), - alternate_memory_size)); - return std::move(module); - } - - // Parse a loop string description like the following: - // $op0 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $prev_op4) - // $op1 = f32[8,4] add(f32[8,4] $param1, f32[8,4] $prev_op3) - // $op2 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op0) - // $op3 = f32[8,4] add(f32[8,4] $param3, f32[8,4] $op1) - // $op4 = f32[1,4] add(f32[1,4] $param4, f32[1,4] $op2) - StatusOr ParseAndCreateModuleString( - absl::string_view hlo_loop_str, int& loop_start_idx, int& loop_end_idx) { - // Parse op name and types first. - RE2 op_re("\\$op([0-9]+) += +(\\S+).*"); - std::vector ops; - std::vector op_types; - int begin_pos = 0; - absl::string_view submatch[3]; - while (op_re.Match(hlo_loop_str, begin_pos, hlo_loop_str.size(), - RE2::UNANCHORED, submatch, /*nsubmatch=*/3)) { - for (int i = 0; i < 3; ++i) { - if (submatch[i].data() == nullptr) { - VLOG(4) << "Submatch[" << i << "] = nullptr"; - } else { - VLOG(4) << "Submatch[" << i << "] = " << submatch[i] - << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) - << ")"; - } - } - int op_num; - if (!absl::SimpleAtoi(submatch[1], &op_num)) { - return InvalidArgument("Op name expects to contain a number, found %s.", - submatch[1]); - } - if (op_num != ops.size()) { - return InvalidArgument("Op number expected to be %d found %d.", - op_types.size(), op_num); - } - ops.push_back(submatch[0]); - op_types.push_back(submatch[2]); - begin_pos = submatch[0].data() - hlo_loop_str.data() + submatch[0].size(); - } - - RE2 param_re("([[:alnum:]]+\\[\\S*\\]) +\\$param([0-9]+)"); - std::vector param_types; - begin_pos = 0; - while (param_re.Match(hlo_loop_str, begin_pos, hlo_loop_str.size(), - RE2::UNANCHORED, submatch, /*nsubmatch=*/3)) { - for (int i = 0; i < 3; ++i) { - if (submatch[i].data() == nullptr) { - VLOG(4) << "Submatch[" << i << "] = nullptr"; - } else { - VLOG(4) << "Submatch[" << i << "] = " << submatch[i] - << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) - << ")"; - } - } - int param_num; - if (!absl::SimpleAtoi(submatch[2], ¶m_num)) { - return InvalidArgument( - "Param name expects to contain a number, found %s.", submatch[2]); - } - while (param_num >= param_types.size()) { - param_types.push_back({}); - } - param_types[param_num] = submatch[1]; - - begin_pos = submatch[0].data() - hlo_loop_str.data() + submatch[0].size(); - } - - RE2 root_re("ROOT \\$root += +tuple\\((.*)\\)"); - absl::string_view root_values; - if (root_re.Match(hlo_loop_str, 0, hlo_loop_str.size(), RE2::UNANCHORED, - submatch, /*nsubmatch=*/2)) { - for (int i = 0; i < 2; ++i) { - if (submatch[i].data() == nullptr) { - VLOG(4) << "Submatch[" << i << "] = nullptr"; - } else { - VLOG(4) << "Submatch[" << i << "] = " << submatch[i] - << " (idx: " << (submatch[i].data() - hlo_loop_str.data()) - << ")"; - } - } - root_values = submatch[1]; - } - - for (absl::string_view op_type : op_types) { - VLOG(4) << "op_type: " << op_type; - } - for (absl::string_view param_type : param_types) { - VLOG(4) << "param_type: " << param_type; - } - - std::string hlo_string = R"( -HloModule module, is_scheduled=true - -ENTRY Entry { -)"; - int total_instructions = 0; - for (absl::string_view param_prefix : {"prev_", "", "next_"}) { - for (int i = 0; i < param_types.size(); ++i) { - int parameter_number = total_instructions; - absl::StrAppend(&hlo_string, " ", param_prefix, "param", i, " = ", - param_types[i], " parameter(", parameter_number, - ") // ", total_instructions++, "\n"); - } - } - - for (int i = 0; i < op_types.size(); ++i) { - int parameter_number = total_instructions; - absl::StrAppend(&hlo_string, " ", "prev_prev_op", i, " = ", op_types[i], - " parameter(", parameter_number, ") // ", - total_instructions++, "\n"); - } - - std::string new_root_values; - auto print_ops = - [&](const std::vector>& - replacements) { - for (int i = 0; i < ops.size(); ++i) { - absl::StrAppend(&hlo_string, " ", - absl::StrReplaceAll(ops[i], replacements), " // ", - total_instructions++, "\n"); - } - if (!root_values.empty()) { - absl::StrAppend(&new_root_values, - new_root_values.empty() ? "" : ", ", - absl::StrReplaceAll(root_values, replacements)); - } - }; - - std::vector> - prev_replacements; - prev_replacements.push_back({"$prev_op", "prev_prev_op"}); - prev_replacements.push_back({"$op", "prev_op"}); - prev_replacements.push_back({"$param", "prev_param"}); - absl::StrAppend(&hlo_string, " // Prev iteration body:\n"); - print_ops(prev_replacements); - - loop_start_idx = total_instructions; - std::vector> replacements; - replacements.push_back({"$", ""}); - absl::StrAppend(&hlo_string, " // Loop body:\n"); - print_ops(replacements); - loop_end_idx = total_instructions; - - std::vector> - next_replacements; - next_replacements.push_back({"$prev_op", "op"}); - next_replacements.push_back({"$op", "next_op"}); - next_replacements.push_back({"$param", "next_param"}); - absl::StrAppend(&hlo_string, " // Next iteration body:\n"); - print_ops(next_replacements); - - absl::StrAppend(&hlo_string, " ROOT root = tuple(", new_root_values, - ")\n"); - absl::StrAppend(&hlo_string, "}"); - - VLOG(1) << hlo_string; - return hlo_string; - } - - StatusOr> RunMsa( - HloModule* module, uint64_t alternate_memory_size = 256) { - options_.max_size_in_bytes = alternate_memory_size; - options_.alignment_in_bytes = 8; - options_.verify = true; - - options_.alternate_memory_space = kAlternateMemorySpace; - - if (!cost_analysis_) { - TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size)); - } - MemorySpaceAssignmentCostAnalysis::Cache cache; - memory_space_assignment::MemoryBoundednessBufferIntervalComparator - comparator(*cost_analysis_, &cache); - options_.buffer_interval_comparator = &comparator; - CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( - CostAnalysisPrefetchIntervalPicker( - *cost_analysis_, /*min_overlap_to_async_copy_ratio=*/0.8, - /*preferred_overlap_to_async_copy_ratio=*/1.5, - /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, - /*mem_size_bytes=*/alternate_memory_size)); - options_.prefetch_interval_picker = &prefetch_interval_picker; - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - options_.size_fn = size_fn; - - auto is_allowed_in_alternate_mem = [](const HloValue& value) { - // Check if the value belongs to the entry computation. - HloInstruction* instruction = value.instruction(); - HloComputation* computation = instruction->parent(); - bool in_entry_computation = - (computation == computation->parent()->entry_computation()); - if (in_entry_computation && - instruction->opcode() == HloOpcode::kParameter) { - return false; - } - return true; - }; - options_.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; - options_.max_outstanding_prefetches = -1; - options_.max_outstanding_evictions = -1; - options_.allocate_across_sequential_calls = true; - options_.cost_analysis = cost_analysis_.get(); - - std::unique_ptr preset_assignments = - MemorySpaceAssignment::Run(module, *live_range_, *alias_analysis_, - options_) - .value(); - return preset_assignments; - } - - Status VerifyMsaEquivalence(HloModule* module, - bool expect_unsupported_allocations = false) { - // Create a map indexed by instruction number and operand number. - absl::flat_hash_map, - const MemorySpaceAssignment::Allocation*> - allocation_map; - for (const MemoryBoundLoopOptimizer::LoopValue& value : - optimizer_->loop_values()) { - // Skip verification for unsupported allocations as they will go through - // the usual MSA algorithm and may actually get an alternate memory - // allocation. - if (!value.IsAllocationTypeSupported()) { - continue; - } - for (const auto& allocation : value.allocations) { - for (const HloUse& use : allocation->uses()) { - absl::string_view inst_name = use.instruction->name(); - TF_RET_CHECK(absl::StartsWith(inst_name, "op")); - int inst_number; - TF_RET_CHECK(absl::SimpleAtoi(inst_name.substr(2), &inst_number)); - allocation_map[{inst_number, use.operand_number}] = allocation.get(); - } - } - } - - auto get_inst_prefix_in_iter = [](int iteration) { - switch (iteration) { - case 0: - return "prev_"; - case 1: - return ""; - case 2: - return "next_"; - default: - LOG(FATAL) << "Invalid iteration " << iteration; - return "INVALID"; - } - }; - - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr live_range, - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation())); - const auto& flattened_instructions = - live_range->flattened_instruction_sequence().instructions(); - for (int iteration = 1; iteration < 3; ++iteration) { - for (int inst_number = 0; inst_number < optimizer_->loop_size(); - ++inst_number) { - HloInstruction* inst = FindInstruction( - module, absl::StrCat(get_inst_prefix_in_iter(iteration), "op", - inst_number)); - for (int operand_number = 0; operand_number < 2; ++operand_number) { - const HloInstruction* operand = inst->operand(operand_number); - LOG(INFO) << inst->name() << ", operand " << operand_number; - if (!allocation_map.contains({inst_number, operand_number})) { - TF_RET_CHECK(expect_unsupported_allocations); - continue; - } - const MemorySpaceAssignment::Allocation* allocation = - allocation_map.at({inst_number, operand_number}); - if (!allocation->is_copy_allocation()) { - // We don't expect a prefetch here. - EXPECT_NE(operand->opcode(), HloOpcode::kCopyDone); - int expected_memory_space = - allocation->memory_space() == - MemorySpaceAssignment::MemorySpace::kDefault - ? kDefaultMemorySpace - : kAlternateMemorySpace; - EXPECT_EQ(operand->shape().layout().memory_space(), - expected_memory_space); - } else { - EXPECT_EQ(allocation->memory_space(), - MemorySpaceAssignment::MemorySpace::kAlternate); - TF_RET_CHECK(operand->opcode() == HloOpcode::kCopyDone); - const MemorySpaceAssignment::CopyAllocation* copy_allocation = - static_cast( - allocation); - if (copy_allocation->copy_done_schedule_before() != inst_number) { - // The only case where the copy done schedule before is not the - // same as this use would be that this use is not the first use of - // the copy allocation. - EXPECT_NE(allocation->uses().front(), - (HloUse{inst, operand_number})); - continue; - } - int expected_copy_start_iteration = iteration; - if (copy_allocation->copy_start_schedule_after() == - optimizer_->loop_size() && - copy_allocation->copy_done_schedule_before() == 0) { - expected_copy_start_iteration -= 2; - } else if (copy_allocation->copy_start_schedule_after() + 1 >= - copy_allocation->copy_done_schedule_before()) { - expected_copy_start_iteration -= 1; - } - - if (expected_copy_start_iteration >= 0) { - const HloInstruction* expected_copy_start_schedule_after = - FindInstruction( - module, - absl::StrCat( - get_inst_prefix_in_iter( - expected_copy_start_iteration), - "op", copy_allocation->copy_start_schedule_after())); - LOG(INFO) << "Expected copy start schedule after: " - << expected_copy_start_schedule_after->name(); - const HloInstruction* copy_start = operand->operand(0); - TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart); - // Find the instruction before this copy start that is not an - // async copy or gte or parameter. - int copy_start_idx = - live_range->instruction_schedule().at(copy_start); - const HloInstruction* copy_start_schedule_after = nullptr; - for (int i = copy_start_idx - 1; i >= 0; --i) { - HloOpcode opcode = flattened_instructions.at(i)->opcode(); - if (opcode != HloOpcode::kCopyStart && - opcode != HloOpcode::kCopyDone && - opcode != HloOpcode::kGetTupleElement && - opcode != HloOpcode::kParameter) { - copy_start_schedule_after = flattened_instructions.at(i); - break; - } - } - TF_RET_CHECK(copy_start_schedule_after != nullptr); - EXPECT_EQ(copy_start_schedule_after, - expected_copy_start_schedule_after); - } - } - } - } - } - return OkStatus(); - } - - private: - Options options_; - std::unique_ptr hlo_cost_analysis_; - std::unique_ptr cost_analysis_; - std::unique_ptr alias_analysis_; - std::unique_ptr live_range_; - std::unique_ptr optimizer_; -}; - -TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) - $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) - $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) - $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) - ROOT $root = tuple($op4, $param0) - )"; - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/128, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - absl::flat_hash_set seen_uses; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - LOG(INFO) << loop_value.ToString(); - if (loop_value.hlo_values.front() - ->defining_position() - .instruction->name() == "param0") { - EXPECT_TRUE(loop_value.allocations.back()->is_copy_allocation()); - } - for (const auto& allocation : loop_value.allocations) { - for (const HloUse& use : allocation->uses()) { - EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); - seen_uses.insert(use); - } - } - } - - // Ensure all of the uses in the loop have an associated use. - for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { - HloInstruction* inst = - module->entry_computation()->GetInstructionWithName(inst_name); - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; - } -} - -TEST_F(MemoryBoundLoopOptimizerTest, NoAlternateMem) { - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) - $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) - $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) - $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) - ROOT $root = tuple($op4, $param0) - )"; - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - // Set alternate memory size to zero so nothing should be in the alternate - // memory. We still expect to find an allocation for all uses. - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/0, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - absl::flat_hash_set seen_uses; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - LOG(INFO) << loop_value.ToString(); - for (const auto& allocation : loop_value.allocations) { - EXPECT_EQ(allocation->memory_space(), - MemorySpaceAssignment::MemorySpace::kDefault); - for (const HloUse& use : allocation->uses()) { - EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); - seen_uses.insert(use); - } - } - } - - // Ensure all of the uses in the loop have an associated use. - for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { - HloInstruction* inst = - module->entry_computation()->GetInstructionWithName(inst_name); - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; - EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; - } -} - -TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { - // Test for enforcing FIFO order of prefetches. There are three parameters - // that will be prefetched (param0, param1, and param2). param2 is one eighth - // the size of the other parameters and is scheduled later in the loop. So, we - // expect the allocation algorithm to initially allocate param2's prefetch - // with a short live range (since copying it doesn't take very long), but then - // as we try to prefetch param0 and param1, we will wrap around into the - // previous iterations and would need to "early force" param2's prefetch to be - // scheduled earlier to enforce the FIFO order. - // - // alternate_mem_bytes_per_second = 128 - // default_mem_bytes_per_second = 32 - // flops_per_second = 16 - // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 - // - All default memory elapsed: 1.5 - // - All alternate memory elapsed: 0.375 - // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 - // - All default memory elapsed: 12 - // - All alternate memory elapsed: 3 - // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 - // f32[8,4] copy: bytes: 128, memory elapsed: 4 - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) - $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) - $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) - $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) - $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) - $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) - $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) - $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) - $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) - $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) - $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) - $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) - $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) - $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) - )"; - - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/512, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - // We expect the prefetches to be scheduled this way: - // - // - // param0 or param1: - // ===========> =====================================> - // param1 or param0: - // ===========> === - // ==============================================> - // param2: - // =====> ========================================> === - // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 - // prev | loop | next - // - // Temporaries: - // +======+ - // +=========+ - // +=========+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +======+ - // +===+ - // +======+ - // +=========+ - // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 - // prev | loop | next - std::vector prefetches; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - if (!loop_value.allocations.empty() && - loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); - } - } - EXPECT_EQ(prefetches.size(), 3); - bool seen_overlap = false; - bool seen_nonoverlap = false; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { - const HloUse& use = *prefetch->uses().begin(); - if (use.instruction->name() == "op14") { - EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); - EXPECT_EQ(prefetch->copy_start_schedule_after(), 0); - } else { - ASSERT_EQ(use.instruction->name(), "op1"); - EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); - if (prefetch->copy_start_schedule_after() == 0) { - EXPECT_FALSE(seen_overlap); - seen_overlap = true; - } else { - EXPECT_GT(prefetch->copy_start_schedule_after(), 1); - EXPECT_FALSE(seen_nonoverlap); - seen_nonoverlap = true; - } - } - } - // We expect to fully saturate the default memory bandwidth. Total default - // memory accesses: - // param0 (128 B) + param1 (128 B) + op1 (128 B) + param2 (16 B) = 400 B - // execution time: - // 400 B / 32 B/s = 12.5 s. - EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); - - // Check the memory used at each point of the loop. - const std::vector& remaining_memory = optimizer->remaining_memory(); - // Time 0: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) - EXPECT_EQ(remaining_memory.at(0), 512 - (3 * 16 + 128 + 128)); - // Time 1: 2 temporaries (16 B) + 2*param0 (128 B) + param1 (128 B) - // + param2 (16 B) - EXPECT_EQ(remaining_memory.at(1), 512 - (2 * 16 + 2 * 128 + 128 + 16)); - // Times 2 and 3: 3 temporaries (16 B) + param0 (128 B) + param2 (16 B) - EXPECT_EQ(remaining_memory.at(2), 512 - (3 * 16 + 128 + 16)); - EXPECT_EQ(remaining_memory.at(3), 512 - (3 * 16 + 128 + 16)); - // Times 4 to 13: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) - // + param2 (16 B) - for (int i = 4; i <= 13; ++i) { - EXPECT_EQ(remaining_memory.at(i), 512 - (3 * 16 + 128 + 128 + 16)); - } - // Time 14: 2 temporaries (16 B) + param0 (128 B) + param1 (128 B) - // + param2 (16 B) - EXPECT_EQ(remaining_memory.at(14), 512 - (2 * 16 + 128 + 128 + 16)); -} - -TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { - // Same as the test above, except the size of alternate memory is less than - // 384, which is the minimum amount needed to keep the three 128-byte sized - // parameters alive (one of the parameters would need to be overlapped with - // the previous iteration, so counts 2X). In that case, we won't be able to - // fully saturate the bandwidth. - // - // alternate_mem_bytes_per_second = 128 - // default_mem_bytes_per_second = 32 - // flops_per_second = 16 - // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 - // - All default memory elapsed: 1.5 - // - All alternate memory elapsed: 0.375 - // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 - // - All default memory elapsed: 12 - // - All alternate memory elapsed: 3 - // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 - // f32[8,4] copy: bytes: 128, memory elapsed: 4 - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) - $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) - $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) - $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) - $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) - $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) - $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) - $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) - $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) - $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) - $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) - $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) - $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) - $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) - )"; - - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/350, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - // We expect the prefetches to be scheduled this way: - // - // - // param0 or param1: - // ===========> =====================================> - // param2: - // =====> ===============================> - // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 - // prev | loop | next - std::vector prefetches; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - if (!loop_value.allocations.empty() && - loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); - } - } - EXPECT_EQ(prefetches.size(), 2); - std::optional expected_op14_copy_start_time; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { - const HloUse& use = *prefetch->uses().begin(); - if (use.instruction->name() == "op1") { - EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); - EXPECT_GT(prefetch->copy_start_schedule_after(), 1); - expected_op14_copy_start_time = prefetch->copy_start_schedule_after(); - } - } - EXPECT_TRUE(expected_op14_copy_start_time.has_value()); - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { - const HloUse& use = *prefetch->uses().begin(); - if (use.instruction->name() == "op14") { - EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); - EXPECT_EQ(prefetch->copy_start_schedule_after(), - *expected_op14_copy_start_time); - } - } - // We expect not to fully saturate the default memory bandwidth. - EXPECT_GT(optimizer->CalculateExecutionTime(), 12.5); -} - -TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { - // Same as PrefetchFifoOrderWithOverlap, except the instructions are shifted - // earlier by one such that param0 and param1 are used by op0. This tests that - // we are accounting for overlaps for prefetches that span three iterations. - // - // alternate_mem_bytes_per_second = 128 - // default_mem_bytes_per_second = 32 - // flops_per_second = 16 - // f32[1,4] add: flops: 4, bytes: 48, compute elapsed: 0.25 - // - All default memory elapsed: 1.5 - // - All alternate memory elapsed: 0.375 - // f32[8,4] add: flops: 32, bytes: 384, compute elapsed: 2 - // - All default memory elapsed: 12 - // - All alternate memory elapsed: 3 - // f32[1,4] copy: bytes: 16, memory elapsed: 0.5 - // f32[8,4] copy: bytes: 128, memory elapsed: 4 - absl::string_view hlo_loop_str = R"( - $op0 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) - $op1 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) - $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op1) - $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) - $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) - $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) - $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) - $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) - $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) - $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) - $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) - $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) - $op13 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op12) - $op14 = f32[1,4] add(f32[1,4] $op12, f32[1,4] $op13) - )"; - - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/512, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - // We expect the prefetches to be scheduled this way: - // - // - // param0 or param1: - // ========> =====================================> === - // param1 or param0: - // ========> ====== - // ==============================================> - // param2: - // ==> ========================================> ====== - // 13 14| 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14| 0 1 - // prev | loop | next - std::vector prefetches; - for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : - optimizer->loop_values()) { - if (!loop_value.allocations.empty() && - loop_value.allocations.back()->is_copy_allocation()) { - prefetches.push_back( - static_cast( - loop_value.allocations.back().get())); - } - } - EXPECT_EQ(prefetches.size(), 3); - bool seen_overlap = false; - bool seen_nonoverlap = false; - for (const MemorySpaceAssignment::CopyAllocation* prefetch : prefetches) { - const HloUse& use = *prefetch->uses().begin(); - if (use.instruction->name() == "op13") { - EXPECT_EQ(prefetch->copy_done_schedule_before(), 13); - EXPECT_EQ(prefetch->copy_start_schedule_after(), 14); - } else { - ASSERT_EQ(use.instruction->name(), "op0"); - EXPECT_EQ(prefetch->copy_done_schedule_before(), 0); - if (prefetch->copy_start_schedule_after() == 14) { - EXPECT_FALSE(seen_overlap); - seen_overlap = true; - } else { - EXPECT_LT(prefetch->copy_start_schedule_after(), 14); - EXPECT_FALSE(seen_nonoverlap); - seen_nonoverlap = true; - } - } - } - // We expect to fully saturate the default memory bandwidth. Total default - // memory accesses: - // param0 (128 B) + param1 (128 B) + op1 (128 B) + param2 (16 B) = 400 B - // execution time: - // 400 B / 32 B/s = 12.5 s. - EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); -} - -TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) { - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) - $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) - $op2 = f32[1,4] add(f32[1,4] $prev_op14, f32[1,4] $op0) - $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) - $op5 = f32[1,4] add(f32[1,4] $op3, f32[1,4] $op4) - $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $op5) - $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op6) - $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) - $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) - $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) - $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) - $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) - $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) - $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) - ROOT $root = tuple($op1, $op14) - )"; - - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/1024, - loop_start_idx, &optimizer)); - - optimizer->Optimize(); - TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, - RunMsa(module.get(), /*alternate_memory_size=*/1024)); - - TF_ASSERT_OK(VerifyMsaEquivalence(module.get())); -} - -TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndUnsupportedAllocation) { - // op2 is a loop-carried dependency, which is currently not supported. But the - // usual MSA algorithm should still be able to give it an alternate memory - // allocation. - absl::string_view hlo_loop_str = R"( - $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) - $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) - $op2 = f32[1,4] add(f32[1,4] $prev_op2, f32[1,4] $op0) - $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) - $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) - ROOT $root = tuple($op1, $op4) - )"; - - int loop_start_idx; - MemoryBoundLoopOptimizer* optimizer; - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndCreateOptimizer(hlo_loop_str, - /*alternate_memory_size=*/1024, - loop_start_idx, &optimizer)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - optimizer->Optimize(); - TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, - RunMsa(module.get(), /*alternate_memory_size=*/1024)); + AssignMemorySpace(module.get(), DefaultMemorySpaceOptions(), + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); - TF_ASSERT_OK(VerifyMsaEquivalence(module.get(), - /*expect_unsupported_allocations=*/true)); + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); + } - const HloInstruction* op2 = FindInstruction(module.get(), "op2"); - EXPECT_EQ(op2->shape().layout().memory_space(), kAlternateMemorySpace); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {1}); + // Expect that there is one prefetch that use this value, the cross-program + // prefetch. There shouldn't be an end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_end_of_program_prefetch), + 0); } -TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndWhileLoop) { - absl::string_view hlo_str = R"( +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBufferUnused) { + absl::string_view hlo_string = R"( HloModule module, is_scheduled=true -while_cond { - while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) - ROOT p = pred[] get-tuple-element(while_cond_param), index=6 -} - -while_body { - while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) - prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0 - param0 = f32[1,4] get-tuple-element(while_body_param), index=1 - next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2 - prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3 - prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4 - prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) - prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) - prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) - prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) - prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_op3) - op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) - op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) - op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) - op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) - op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] op3) - next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) - next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) - next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) - next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) - next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_op3) - p = pred[] get-tuple-element(while_body_param), index=6 - ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, p) +%fused_computation { + %param_0.2 = f32[32]{0} parameter(0) + %param_1.4 = s32[100]{0} parameter(1) + %custom-call.1 = s32[100]{0} custom-call(s32[100]{0} %param_1.4), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[100]{0}} + %slice.1 = s32[32]{0} slice(s32[100]{0} %custom-call.1), slice={[0:32]} + %reshape.7 = s32[32]{0} reshape(s32[32]{0} %slice.1) + %transpose.5 = s32[32]{0} transpose(s32[32]{0} %reshape.7), dimensions={0} + %gather.1 = f32[32]{0} gather(f32[32]{0} %param_0.2, s32[32]{0} %transpose.5), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1} + %transpose.4 = f32[32]{0} transpose(f32[32]{0} %gather.1), dimensions={0} + ROOT %reshape.6 = f32[32]{0} reshape(f32[32]{0} %transpose.4) } -ENTRY entry { - p0 = f32[1,4] parameter(0) - p1 = f32[1,4] parameter(1) - p2 = f32[1,4] parameter(2) - p3 = f32[1,4] parameter(3) - p4 = f32[1,4] parameter(4) - p5 = pred[] parameter(5) - copy = f32[1,4] copy(p4) - tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5) - while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body - ROOT root = f32[1,4] get-tuple-element(while), index=5 +%i.reduce_sub_computation { + %rhs = s32[] parameter(1) + %lhs = s32[] parameter(0) + ROOT %add = s32[] add(s32[] %lhs, s32[] %rhs) } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); +%fused_computation.1 { + %constant.4 = s32[] constant(0) + %broadcast.4 = s32[100]{0} broadcast(s32[] %constant.4), dimensions={} + %param_0.4 = s32[32]{0} parameter(0) + %pad.1 = s32[100]{0} pad(s32[32]{0} %param_0.4, s32[] %constant.4), padding=0_68 + %constant.3 = s32[] constant(76031) + %broadcast.3 = s32[100]{0} broadcast(s32[] %constant.3), dimensions={} + ROOT %clamp.1 = s32[100]{0} clamp(s32[100]{0} %broadcast.4, s32[100]{0} %pad.1, s32[100]{0} %broadcast.3) +} - TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, - RunMsa(module.get(), /*alternate_memory_size=*/512)); +ENTRY %main { + %constant = s32[] constant(0) + %i = s32[32,1]{0,1} parameter(1) + %o = f32[32]{0} parameter(0) + %reduce = s32[32]{0} reduce(s32[32,1]{0,1} %i, s32[] %constant), dimensions={1}, to_apply=%i.reduce_sub_computation + %fusion.1 = s32[100]{0} fusion(s32[32]{0} %reduce), kind=kLoop, calls=%fused_computation.1 + ROOT %fusion = f32[32]{0} fusion(f32[32]{0} %o, s32[100]{0} %fusion.1), kind=kCustom, calls=%fused_computation +} + )"; - // We expect operand 0 of prev_op4, op4, and next_op4 to all be prefetches of - // same distance from the user. - TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, - HloAliasAnalysis::Run(module.get())); - TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range, - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation())); - const HloInstruction* prev_copy_done = - FindInstruction(module.get(), "prev_op4")->operand(0); - const HloInstruction* copy_done = - FindInstruction(module.get(), "op4")->operand(0); - const HloInstruction* next_copy_done = - FindInstruction(module.get(), "next_op4")->operand(0); - ASSERT_EQ(prev_copy_done->opcode(), HloOpcode::kCopyDone); - ASSERT_EQ(copy_done->opcode(), HloOpcode::kCopyDone); - ASSERT_EQ(next_copy_done->opcode(), HloOpcode::kCopyDone); - EXPECT_EQ(prev_copy_done->shape().layout().memory_space(), - kAlternateMemorySpace); - EXPECT_EQ(copy_done->shape().layout().memory_space(), kAlternateMemorySpace); - EXPECT_EQ(next_copy_done->shape().layout().memory_space(), - kAlternateMemorySpace); - auto prefetch_distance = [&](const HloInstruction* copy_done) { - return hlo_live_range->instruction_schedule().at(copy_done) - - hlo_live_range->instruction_schedule().at(copy_done->operand(0)); - }; - EXPECT_EQ(prefetch_distance(prev_copy_done), prefetch_distance(copy_done)); - EXPECT_EQ(prefetch_distance(next_copy_done), prefetch_distance(copy_done)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(0)), + op::Fusion())); } -TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndNestedWhileLoopBug) { - absl::string_view hlo_str = R"( +// Test description: +// - Setup: Make sure p1 can not be prefetched to alternate memory until after +// instruction c. We do this by causing p0 to be prefetched to alternate +// memory for use in c. Since p0 is larger than 1/2 of alternate memory, we +// will not be able to prefetch p1 until after p0 is unallocated. +// - Test: prefetch p1, after p0 is unallocated from alternate memory (after +// instruction c). +TEST_P(MemorySpaceAssignmentTest, CopyResourceIntegration) { + std::string_view hlo_string = R"( HloModule module, is_scheduled=true -prev_while_cond { - prev_while_cond_param = (f32[1,4], pred[]) parameter(0) - ROOT p = pred[] get-tuple-element(prev_while_cond_param), index=1 -} - -prev_while_body { - prev_while_body_param = (f32[1,4], pred[]) parameter(0) - prev_while_body_gte = f32[1,4] get-tuple-element(prev_while_body_param), index=0 - prev_while_body_pred = pred[] get-tuple-element(prev_while_body_param), index=1 - prev_while_body_op = f32[1,4] negate(prev_while_body_gte) - ROOT prev_while_body_root = (f32[1,4], pred[]) tuple(prev_while_body_op, prev_while_body_pred) -} - -current_while_cond { - current_while_cond_param = (f32[1,4], pred[]) parameter(0) - ROOT p = pred[] get-tuple-element(current_while_cond_param), index=1 -} - -current_while_body { - current_while_body_param = (f32[1,4], pred[]) parameter(0) - current_while_body_gte = f32[1,4] get-tuple-element(current_while_body_param), index=0 - current_while_body_pred = pred[] get-tuple-element(current_while_body_param), index=1 - current_while_body_op = f32[1,4] negate(current_while_body_gte) - ROOT current_while_body_root = (f32[1,4], pred[]) tuple(current_while_body_op, current_while_body_pred) -} - -next_while_cond { - next_while_cond_param = (f32[1,4], pred[]) parameter(0) - ROOT p = pred[] get-tuple-element(next_while_cond_param), index=1 -} - -next_while_body { - next_while_body_param = (f32[1,4], pred[]) parameter(0) - next_while_body_gte = f32[1,4] get-tuple-element(next_while_body_param), index=0 - next_while_body_pred = pred[] get-tuple-element(next_while_body_param), index=1 - next_while_body_op = f32[1,4] negate(next_while_body_gte) - ROOT next_while_body_root = (f32[1,4], pred[]) tuple(next_while_body_op, next_while_body_pred) -} - -while_cond { - while_cond_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) - ROOT p = pred[] get-tuple-element(while_cond_param), index=6 -} - -while_body { - while_body_param = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) parameter(0) - prev_param0 = f32[1,4] get-tuple-element(while_body_param), index=0 - param0 = f32[1,4] get-tuple-element(while_body_param), index=1 - next_param0 = f32[1,4] get-tuple-element(while_body_param), index=2 - prev_prev_op3 = f32[1,4] get-tuple-element(while_body_param), index=3 - prev_prev_op4 = f32[1,4] get-tuple-element(while_body_param), index=4 - while_pred = pred[] get-tuple-element(while_body_param), index=6 - prev_op0 = f32[1,4] add(f32[1,4] prev_prev_op3, f32[1,4] prev_prev_op4) - prev_op1 = f32[1,4] add(f32[1,4] prev_prev_op4, f32[1,4] prev_op0) - prev_op2 = f32[1,4] add(f32[1,4] prev_op0, f32[1,4] prev_op1) - prev_op3 = f32[1,4] add(f32[1,4] prev_op1, f32[1,4] prev_op2) - prev_tuple = (f32[1,4], pred[]) tuple(prev_op3, while_pred) - prev_while = (f32[1,4], pred[]) while(prev_tuple), condition=prev_while_cond, body=prev_while_body - prev_gte = f32[1,4] get-tuple-element(prev_while), index=0 - prev_op4 = f32[1,4] multiply(f32[1,4] prev_param0, f32[1,4] prev_gte) - op0 = f32[1,4] add(f32[1,4] prev_op3, f32[1,4] prev_op4) - op1 = f32[1,4] add(f32[1,4] prev_op4, f32[1,4] op0) - op2 = f32[1,4] add(f32[1,4] op0, f32[1,4] op1) - op3 = f32[1,4] add(f32[1,4] op1, f32[1,4] op2) - current_tuple = (f32[1,4], pred[]) tuple(op3, while_pred) - current_while = (f32[1,4], pred[]) while(current_tuple), condition=current_while_cond, body=current_while_body - current_gte = f32[1,4] get-tuple-element(current_while), index=0 - op4 = f32[1,4] multiply(f32[1,4] param0, f32[1,4] current_gte) - next_op0 = f32[1,4] add(f32[1,4] op3, f32[1,4] op4) - next_op1 = f32[1,4] add(f32[1,4] op4, f32[1,4] next_op0) - next_op2 = f32[1,4] add(f32[1,4] next_op0, f32[1,4] next_op1) - next_op3 = f32[1,4] add(f32[1,4] next_op1, f32[1,4] next_op2) - next_tuple = (f32[1,4], pred[]) tuple(next_op3, while_pred) - next_while = (f32[1,4], pred[]) while(next_tuple), condition=next_while_cond, body=next_while_body - next_gte = f32[1,4] get-tuple-element(next_while), index=0 - next_op4 = f32[1,4] multiply(f32[1,4] next_param0, f32[1,4] next_gte) - ROOT root = tuple(prev_param0, param0, next_param0, prev_prev_op3, prev_prev_op4, next_op4, while_pred) -} +ENTRY main { + p0 = s32[8,8] parameter(0) + p1 = s32[8,8] parameter(1) + p2 = s32[] parameter(2) + a = negate(p2) + b = negate(a) + c = add(p0, p0) + d = negate(b) + e = negate(d) + f = add(p1, p1) -ENTRY entry { - p0 = f32[1,4] parameter(0) - p1 = f32[1,4] parameter(1) - p2 = f32[1,4] parameter(2) - p3 = f32[1,4] parameter(3) - p4 = f32[1,4] parameter(4) - p5 = pred[] parameter(5) - copy = f32[1,4] copy(p4) - tuple = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) tuple(p0, p1, p2, p3, p4, copy, p5) - while = (f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], f32[1,4], pred[]) while(tuple), condition=while_cond, body=while_body - ROOT root = f32[1,4] get-tuple-element(while), index=5 + ROOT result = tuple(e,c,f) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 300; + + // Setup cost analysis so it takes 2 instructions to prefetch anything. + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions cost_analysis_options; + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, + cost_analysis_options)); + cost_analysis->SetOverrideForGetInstructionElapsed( + [](const HloInstruction& instruction) -> float { return 10.0; }); + cost_analysis->SetOverrideForGetAsyncCopyElapsed( + [](const Shape& shape) -> float { return 20.0; }); + options.cost_analysis = cost_analysis.get(); + CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( + CostAnalysisPrefetchIntervalPicker( + *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8, + /*preferred_overlap_to_async_copy_ratio=*/1.5, + /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, + /*mem_size_bytes=*/options.max_size_in_bytes)); + + // p0 has the highest priority, followed by p1, followed by everything else. + MsaBufferIntervalCompare compare = [](const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) -> bool { + auto lookup = [](const MsaBufferInterval& x) { + // An arbitrary value that is greater than that for p0 and p1. + int priority = 100; + if (x.buffer->instruction()->name() == "p0") { + priority = 0; + } else if (x.buffer->instruction()->name() == "p1") { + priority = 1; + } + return std::make_tuple(priority, x.buffer->instruction()->name()); + }; + + return lookup(lhs) < lookup(rhs); + }; + + // Run test. + AssignMemorySpace(module.get(), options, compare, &prefetch_interval_picker); + + // - Make sure the setup occurred, i.e., that p0 is prefetched to alternate + // memory for use by c. + // - Make sure p1 is prefetched. + ASSERT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(_, + // p0 is prefetched to alternate memory for use by c. + op::Add(op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(0)), + op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(0))), + // p1 is prefetched to alternate memory for use by f. + op::Add(op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(1)), + op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(1))))); - TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, - RunMsa(module.get(), /*alternate_memory_size=*/512)); + // Check the schedule + const std::vector& schedule = + module->schedule().sequence(module->entry_computation()).instructions(); + auto find_schedule_index = [&schedule](std::string_view name) -> int { + for (int i = 0; i < schedule.size(); ++i) { + if (schedule[i]->name() == name) { + return i; + } + } + LOG(FATAL) << "Unable to find index of instruction with name " << name; + }; + int c_index = find_schedule_index("c"); + int p1_copy_start = find_schedule_index(module->entry_computation() + ->root_instruction() // result + ->operand(2) // f + ->operand(0) // copy done + ->operand(0) // copy start + ->name()); + int d_index = find_schedule_index("d"); + int e_index = find_schedule_index("e"); + int p1_copy_end = find_schedule_index(module->entry_computation() + ->root_instruction() // result + ->operand(2) // f + ->operand(0) // copy done + ->name()); + int f_index = find_schedule_index("f"); + // We expect to start copying p1 after c. + EXPECT_EQ(p1_copy_start, c_index + 1); + // d and e should follow come between p1's copy start and end. + EXPECT_EQ(d_index, p1_copy_start + 1); + EXPECT_EQ(e_index, d_index + 1); + EXPECT_EQ(p1_copy_end, e_index + 1); + // f should immediately follow the end of p1's copy. + EXPECT_EQ(f_index, p1_copy_end + 1); } class SlicedPrefetchStartTimePickerTest : public ::testing::Test { @@ -10660,7 +9665,7 @@ class SlicedPrefetchStartTimePickerTest : public ::testing::Test { std::vector Pick( const std::vector& schedule_data, int64_t num_slices, int64_t prefetch_start_time, int64_t prefetch_end_time) { - return memory_space_assignment::SlicedPrefetchStartTimePicker::Pick( + return SlicedPrefetchStartTimePicker::Pick( num_slices, prefetch_start_time, prefetch_end_time, [&schedule_data](int64_t exclusive_start_time, int64_t exclusive_end_time) { @@ -10883,20 +9888,20 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } // A class that can be mocked to set expectations on slice proposals. To do - // that, we set memory_space_assignment::Options::propose_slice_fn to a lambda - // that calls our mocks ProposeSlices() method. + // that, we set Options::propose_slice_fn to a lambda that calls our mocks + // ProposeSlices() method. class SliceProposer { public: SliceProposer() = default; virtual ~SliceProposer() = default; - virtual StatusOr ProposeSlices( + virtual absl::StatusOr ProposeSlices( const Shape& shape, const SlicedPrefetchOptions& options) = 0; }; class MockSliceProposer : public SliceProposer { public: - MOCK_METHOD(StatusOr, ProposeSlices, + MOCK_METHOD(absl::StatusOr, ProposeSlices, (const Shape& shape, const SlicedPrefetchOptions& options), (override)); }; @@ -10904,9 +9909,11 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // An HloInstruction* matcher for matching the asynchronous sliced copies // produced by MSA. In particular, the matcher performs the following // checks: - // - The copy is concluded with a concat-bitcast custom call + // - The copy is concluded with a concat-bitcast custom call, or a + // bitcast of a concat-bitcast custom call if expect_bitcasted_io is true // - The operands to the concat-bitcast are asynchronous slices of the - // expected operand + // expected operand, or asynchronous slices of a bitcast of the expected + // operand if expect_bitcasted_io is true // - The number of slices is as expected (i.e., // expected_slice_params_per_slice_in_spatial_order_.size()) // - The copy is from and to the correct memory spaces @@ -10924,47 +9931,57 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { AsyncSlicedCopy(int64_t to_space, int64_t from_space, std::vector> expected_slice_params_per_slice_in_spatial_order, - ::testing::Matcher operand) + ::testing::Matcher operand, + bool expect_bitcasted_io) : to_space_(to_space), from_space_(from_space), expected_slice_params_per_slice_in_spatial_order_( std::move(expected_slice_params_per_slice_in_spatial_order)), - custom_call_matcher_( - memory_space_assignment::kConcatBitcastCustomCall, - std::vector<::testing::Matcher>( - expected_slice_params_per_slice_in_spatial_order_.size(), - op::AsyncDone(op::AsyncStart(operand)))) {} + base_hlo_matcher_(CreateBaseHloMatcher( + operand, expected_slice_params_per_slice_in_spatial_order_.size(), + expect_bitcasted_io)), + expect_bitcasted_io_(expect_bitcasted_io) {} bool MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const override { - // Match the custom call. - if (!custom_call_matcher_.MatchAndExplain(instruction, listener)) { + // Match opcodes and number of operands. + if (!base_hlo_matcher_.MatchAndExplain(instruction, listener)) { return false; } - // Check if the custom call has the proper memory space. - const HloInstruction* concat_bitcast = instruction; - if (!MatchMemorySpace(concat_bitcast, to_space_, "concat-bitcast", - listener)) { + // Check if the copied result has the proper memory space. + if (!MatchMemorySpace(instruction, to_space_, "copy result", listener)) { return false; } - // Check if the copied tensor has the proper memory space. + // Find some instructions in the async copy. + const HloInstruction* concat_bitcast = + (expect_bitcasted_io_ ? instruction->operand(0) : instruction); + VLOG(2) << "AsyncSlicedCopy identified the concat-bitcast as " + << concat_bitcast->name(); const HloInstruction* copy_operand = concat_bitcast->operand(0)->operand(0)->operand(0); - if (!MatchMemorySpace(copy_operand, from_space_, "copy operand", + const HloInstruction* original_copy_operand = + (expect_bitcasted_io_ ? copy_operand->operand(0) : copy_operand); + VLOG(2) << "AsyncSlicedCopy identified the copy operand as " + << copy_operand->name() << ", and the original copy operand as " + << original_copy_operand->name(); + + // Check if the copied tensor has the proper memory space. + if (!MatchMemorySpace(original_copy_operand, from_space_, "copy operand", listener)) { return false; } // Check if the copied tensor retains its shape. - if (!Shape::Equal().IgnoreMemorySpaceInLayout()(concat_bitcast->shape(), - copy_operand->shape())) { + if (!Shape::Equal().IgnoreMemorySpaceInLayout()( + instruction->shape(), original_copy_operand->shape())) { *listener << " has a shape of " - << copy_operand->shape().ToString(/*print_layout=*/true) + << original_copy_operand->shape().ToString( + /*print_layout=*/true) << " before copying but a shape of " - << concat_bitcast->shape().ToString(/*print_layout=*/true) + << instruction->shape().ToString(/*print_layout=*/true) << " after copying (ignoring memory space)"; return false; @@ -11040,7 +10057,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } void DescribeTo(std::ostream* os) const override { - custom_call_matcher_.DescribeTo(os); + base_hlo_matcher_.DescribeTo(os); std::vector slice_parameters_per_operand; for (int op_idx = 0; op_idx < expected_slice_params_per_slice_in_spatial_order_.size(); @@ -11068,6 +10085,22 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } private: + static ::testing::Matcher CreateBaseHloMatcher( + ::testing::Matcher operand, int64_t num_slices, + bool expect_bitcasted_io) { + if (expect_bitcasted_io) { + return op::Bitcast(op::CustomCall( + kConcatBitcastCustomCall, + std::vector<::testing::Matcher>( + num_slices, + op::AsyncDone(op::AsyncStart(op::Bitcast(operand)))))); + } + return op::CustomCall( + kConcatBitcastCustomCall, + std::vector<::testing::Matcher>( + num_slices, op::AsyncDone(op::AsyncStart(operand)))); + } + static bool MatchMemorySpace(const HloInstruction* instruction, int64_t expected_memory_space, std::string_view error_message_identifier, @@ -11095,7 +10128,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { int64_t from_space_; std::vector> expected_slice_params_per_slice_in_spatial_order_; - ::xla::testing::HloCustomCallMatcher custom_call_matcher_; + ::testing::Matcher base_hlo_matcher_; + bool expect_bitcasted_io_; }; // Returns an AsyncSlicedCopy matcher. @@ -11103,10 +10137,11 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { int64_t to_space, int64_t from_space, std::vector> expected_slice_params_per_slice_in_spatial_order, - ::testing::Matcher operand_matcher) { + ::testing::Matcher operand_matcher, + bool expect_bitcasted_io = false) { return ::testing::MakeMatcher(new AsyncSlicedCopy( to_space, from_space, expected_slice_params_per_slice_in_spatial_order, - operand_matcher)); + operand_matcher, expect_bitcasted_io)); } // We make our own matcher for SlicedPrefetchOptions to work around the fact @@ -11220,12 +10255,11 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // Returns true if instruction is a concat-bitcast. static bool IsConcatBitcast(const HloInstruction* instruction) { - return instruction->IsCustomCall( - memory_space_assignment::kConcatBitcastCustomCall); + return instruction->IsCustomCall(kConcatBitcastCustomCall); } // Returns the index of the first instruction with the given name. - static StatusOr FindScheduleIndexOfInstruction( + static absl::StatusOr FindScheduleIndexOfInstruction( const std::vector& schedule, std::string_view name, InstructionClass c) { for (int i = 0; i < schedule.size(); ++i) { @@ -11251,7 +10285,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { return nullptr; } - static StatusOr> GetSliceStartIndicies( + static absl::StatusOr> GetSliceStartIndicies( const std::vector& schedule, const HloInstruction* concat_bitcast) { std::vector indicies; @@ -11500,6 +10534,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { std::string_view slices_start_after_instruction_name, std::string_view slices_done_before_instruction_name, bool expect_slices_started_at_different_times) { + CHECK(concat_bitcast->IsCustomCall(kConcatBitcastCustomCall)); + // Get the schedule. auto entry_schedule = module.schedule().sequence(module.entry_computation()).instructions(); @@ -11573,29 +10609,38 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { } // Returns OkStatus iff: - // - When the slices of concat_bitcast are sorted in expected spatial order, - // they are assigned chunks that spatially fall in the same order AND - // - The slices of concat_bitcast are assigned contiguous memory chunks AND - // - The concat_bitcast is assigned a chunk that is the concatenation of the - // slice chunks AND - // - The size of the chunk assigned to the concat_bitcast has the same size - // as the instruction's shape + // - Each slice is assigned a chunk that is the same size as the slice + // instruction's shape. + // - When the slices of sliced_copy_result are sorted in expected spatial + // order, they are assigned chunks that spatially fall in the same order AND + // - The slices of sliced_copy_result are assigned contiguous memory chunks + // AND + // - The sliced_copy_result is assigned a chunk that is the concatenation of + // the slice chunks AND + // - The size of the chunk assigned to the sliced_copy_result has the same + // size as the instruction's shape static Status CheckSliceChunks(const PresetAssignments& assignments, - const HloInstruction* concat_bitcast) { + const HloInstruction* sliced_copy_result, + bool expect_bitcasted_io = false) { + const HloInstruction* concat_bitcast = + (expect_bitcasted_io ? sliced_copy_result->operand(0) + : sliced_copy_result); + CHECK(concat_bitcast->IsCustomCall(kConcatBitcastCustomCall)); + absl::flat_hash_map slices_to_chunks; - std::optional concat_bitcast_chunk = std::nullopt; + std::optional result_chunk = std::nullopt; for (const std::pair& position_chunk_pair : assignments.chunks()) { - if (position_chunk_pair.first.instruction == concat_bitcast) { - if (concat_bitcast_chunk.has_value()) { + if (position_chunk_pair.first.instruction == sliced_copy_result) { + if (result_chunk.has_value()) { return FailedPrecondition( - "%s", absl::StrCat("Concat-bitcast ", concat_bitcast->name(), + "%s", absl::StrCat("Sliced copy ", sliced_copy_result->name(), " is assigned more than one chunk: ", - concat_bitcast_chunk->ToString(), " and ", + result_chunk->ToString(), " and ", position_chunk_pair.second.ToString())); } - concat_bitcast_chunk = position_chunk_pair.second; + result_chunk = position_chunk_pair.second; } for (const HloInstruction* slice : concat_bitcast->operands()) { if (position_chunk_pair.first.instruction == slice) { @@ -11614,7 +10659,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { std::vector sorted_slices = SortSlicesInExpectedSpatialOrder(concat_bitcast); - VLOG(1) << "Chunk assignments for " << concat_bitcast->name() << ":\n" + VLOG(1) << "Chunk assignments for " << sliced_copy_result->name() << ":\n" << absl::StrJoin( sorted_slices, "\n", [&](std::string* out, const HloInstruction* slice) { @@ -11626,16 +10671,16 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { absl::StrAppend(out, " slice ", slice->name(), ": ", chunk); }) - << "\n concat-bitcast " << concat_bitcast->name() << ": " - << (concat_bitcast_chunk.has_value() - ? concat_bitcast_chunk->ToString() - : "no chunk assigned"); + << "\n sliced copy result " << sliced_copy_result->name() << ": " + << (result_chunk.has_value() ? result_chunk->ToString() + : "no chunk assigned"); if (sorted_slices.empty()) { return OkStatus(); } // Check that slices are assigned contiguous chunks that are spatially - // ordered according to sorted_slices. + // ordered according to sorted_slices. Also make sure that slices are + // assigned chunks with sizes that match their shape. int64_t previous_end = -1; int64_t min_offset = std::numeric_limits::max(); int64_t max_limit = std::numeric_limits::min(); @@ -11647,6 +10692,16 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { absl::StrCat("Slice ", slice->name(), " is not assigned a chunk")); } const Chunk& chunk = it->second; + + if (chunk.size != ShapeSize(slice->shape())) { + return FailedPrecondition( + "%s", + absl::StrCat("Slice ", slice->name(), " is assigned chunk ", + chunk.ToString(), " with size ", chunk.size, + ". Expected a size of ", ShapeSize(slice->shape()), + ", to match its shape.")); + } + if (previous_end != -1 && chunk.offset != previous_end) { return FailedPrecondition( "%s", absl::StrCat( @@ -11659,31 +10714,29 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { max_limit = std::max(max_limit, chunk.chunk_end()); } - // Check that the concat_bitcast is assigned a chunk that is the + // Check that the sliced copy result is assigned a chunk that is the // concatenation of the slice chunks. - if (!concat_bitcast_chunk.has_value()) { + if (!result_chunk.has_value()) { return FailedPrecondition( - "%s", absl::StrCat("Concat-bitcast ", concat_bitcast->name(), + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), " is not assigned a chunk.")); } - Chunk expected_concat_bitcast_chunk = - Chunk::FromOffsetEnd(min_offset, max_limit); - if (!(*concat_bitcast_chunk == expected_concat_bitcast_chunk)) { + Chunk expected_result_chunk = Chunk::FromOffsetEnd(min_offset, max_limit); + if (!(*result_chunk == expected_result_chunk)) { return FailedPrecondition( - "%s", - absl::StrCat("Concat-bitcast ", concat_bitcast->name(), - " is assigned chunk ", concat_bitcast_chunk->ToString(), - " but its expected to be assigned chunk ", - expected_concat_bitcast_chunk.ToString())); + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), + " is assigned chunk ", result_chunk->ToString(), + ", but it's expected to be assigned chunk ", + expected_result_chunk.ToString())); } - if (concat_bitcast_chunk->size != ShapeSize(concat_bitcast->shape())) { + if (result_chunk->size != ShapeSize(sliced_copy_result->shape())) { return FailedPrecondition( - "%s", - absl::StrCat( - "Concat-bitcast ", concat_bitcast->name(), " is assigned chunk ", - concat_bitcast_chunk->ToString(), " with size ", - concat_bitcast_chunk->size, ". Expected a size of ", - ShapeSize(concat_bitcast->shape()), ", to match its shape.")); + "%s", absl::StrCat("Sliced copy result ", sliced_copy_result->name(), + " is assigned chunk ", result_chunk->ToString(), + " with size ", result_chunk->size, + ". Expected a size of ", + ShapeSize(sliced_copy_result->shape()), + ", to match its shape.")); } return OkStatus(); @@ -11696,12 +10749,13 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { options_.max_size_in_bytes = 1024; options_.sliced_prefetch_options.set_max_slices(2); options_.sliced_prefetch_options.set_min_bytes(8); - options_.propose_slice_fn = - [&](const Shape& shape, - const memory_space_assignment::SlicedPrefetchOptions& options) { - return slice_proposer_.ProposeSlices(shape, options); - }; - options_.update_layout_fn = [](Shape* shape) {}; + options_.propose_slice_fn = [&](const Shape& shape, + const SlicedPrefetchOptions& options) { + return slice_proposer_.ProposeSlices(shape, options); + }; + options_.get_equivalent_s8_shape_fn = [](const Shape& original_shape) { + return ShapeUtil::MakeShape(S8, {ShapeSize(original_shape)}); + }; } bool allocate_across_sequential_calls() const override { return true; } @@ -11931,7 +10985,7 @@ ENTRY main { EXPECT_CALL(slice_proposer_, ProposeSlices(f32_8_8_, EqualsSlicedPrefetchOptions( options_.sliced_prefetch_options))) - .WillRepeatedly(Return(StatusOr( + .WillRepeatedly(Return(absl::StatusOr( FailedPrecondition("%s", "Cannot slice.")))); TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); @@ -11978,7 +11032,8 @@ ENTRY main { << module->ToString(HloPrintOptions::ShortParsable()); std::unique_ptr assignments = - AssignMemorySpaceUsingCostAnalysis(module.get(), options_); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options_); VLOG(1) << "Post-MSA module:\n" << module->ToString(HloPrintOptions::ShortParsable()); @@ -12054,7 +11109,8 @@ ENTRY main { << module->ToString(HloPrintOptions::ShortParsable()); std::unique_ptr assignments = - AssignMemorySpaceUsingCostAnalysis(module.get(), options_); + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/options_); VLOG(1) << "Post-MSA module:\n" << module->ToString(HloPrintOptions::ShortParsable()); @@ -12122,7 +11178,7 @@ class MockRepacker : public MemorySpaceAssignmentRepacker { MockRepacker() : MemorySpaceAssignmentRepacker(std::numeric_limits::max(), 1) {} - MOCK_METHOD(StatusOr, Repack, (absl::Span), + MOCK_METHOD(absl::StatusOr, Repack, (absl::Span), (override)); }; @@ -12172,10 +11228,6 @@ ENTRY main { ROOT z5 = f32[32,16] add(z4, d) })"; - using Slice = MemorySpaceAssignmentRepacker::Slice; - using SlicedAllocationData = - MemorySpaceAssignmentRepacker::SlicedAllocationData; - // Create 2 copies of the module, one to run without repacking and one to run // with repacking. TF_ASSERT_OK_AND_ASSIGN(auto module_no_repacking, @@ -12206,10 +11258,9 @@ ENTRY main { // Force MSA to prefer prefetching (in order) p1, p2, p3, p4, and then // anything else. - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& lhs, - const MemorySpaceAssignment::BufferInterval& rhs) { - auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { + auto lookup = [](const MsaBufferInterval& x) { // An arbitrary value that is greater than that for p1, p2, p3, and // p4. int priority = 100; @@ -12250,11 +11301,11 @@ ENTRY main { MockRepacker repacker; absl::flat_hash_map, int64_t> repack_map; EXPECT_CALL(repacker, Repack(_)) - .WillRepeatedly([](absl::Span allocations) - -> StatusOr { + .WillRepeatedly([](absl::Span allocations) + -> absl::StatusOr { bool found_p2 = false; bool found_p3 = false; - for (MockRepacker::AllocationBlock* block : allocations) { + for (AllocationBlock* block : allocations) { VLOG(1) << "Allocation block: " << block->ToString(); if (block->inclusive_start_time == 3 && @@ -12267,15 +11318,15 @@ ENTRY main { EXPECT_TRUE(block->original_slice_data.has_value()); if (block->original_slice_data.has_value()) { SlicedAllocationData expected( - {{Slice{1024, 1024, /*inclusive_start_time=*/3}, - Slice{1024, 2048, /*inclusive_start_time=*/7}}}); + {{AllocatedSlice{1024, 1024, /*inclusive_start_time=*/3}, + AllocatedSlice{1024, 2048, /*inclusive_start_time=*/7}}}); EXPECT_EQ(*block->original_slice_data, expected) << "\nExpected: " << expected.ToString() << "\nGot: " << block->original_slice_data->ToString(); // Set the first slice for p2 to be place at the larger offset. block->repacked_slice_data = SlicedAllocationData( - {{Slice{1024, 2048, /*inclusive_start_time=*/7}, - Slice{1024, 3072, /*inclusive_start_time=*/3}}}); + {{AllocatedSlice{1024, 2048, /*inclusive_start_time=*/7}, + AllocatedSlice{1024, 3072, /*inclusive_start_time=*/3}}}); } } else if (block->inclusive_start_time == 4 && block->initial_offset == 3072 && block->size == 1024) { @@ -12451,10 +11502,9 @@ ENTRY main { // Configure MSA. SetupProposeSlicesToExpect2SlicesOfF32x8x8(); // Force MSA to prefer prefetching 'prefetch'. - MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = - [](const MemorySpaceAssignment::BufferInterval& lhs, - const MemorySpaceAssignment::BufferInterval& rhs) { - auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) { + MsaBufferIntervalCompare buffer_interval_compare = + [](const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { + auto lookup = [](const MsaBufferInterval& x) { // An arbitrary value that is greater than that used for 'prefetch'. int priority = 100; if (x.buffer->instruction()->name() == "prefetch") { @@ -12473,7 +11523,7 @@ ENTRY main { // Define a lambda for running MSA on the specified HLO, with the // configuration above. auto run_msa = - [&](std::string_view hlo_text) -> StatusOr { + [&](std::string_view hlo_text) -> absl::StatusOr { ModuleAndAssignments module_and_assignments; TF_ASSIGN_OR_RETURN(module_and_assignments.module, ParseAndReturnVerifiedModule(hlo_text)); @@ -12570,5 +11620,130 @@ ENTRY main { {copy_start, first_while, second_while, copy_done})); } +using RepackingTest = ::testing::Test; + +TEST_F(RepackingTest, Colocations) { + AllocationBlock a{10, 20, 100, 0, 1000, 0}; + AllocationBlock b{15, 25, 150, 0, 2000, 1}; + AllocationBlock c{18, 22, 50, 0, 500, 2}; + AllocationBlock d{5, 9, 20, 0, 3000, 3}; + AllocationBlock e{17, 22, 100, 0, 1500, 4}; + AllocationBlock f{25, 27, 150, 0, 2500, 5}; + + // a doesn't have other colocations. + a.next_colocated = &a; + // b and c are colocated. + b.next_colocated = &c; + c.next_colocated = &b; + // d, e, and f are colocated. + d.next_colocated = &f; + e.next_colocated = &d; + f.next_colocated = &e; + + EXPECT_EQ(a.GetColocationsCount(), 1); + EXPECT_THAT(a.GetColocations(), UnorderedElementsAre(&a)); + EXPECT_EQ(b.GetColocationsCount(), 2); + EXPECT_THAT(b.GetColocations(), UnorderedElementsAre(&b, &c)); + EXPECT_EQ(c.GetColocationsCount(), 2); + EXPECT_THAT(c.GetColocations(), UnorderedElementsAre(&b, &c)); + EXPECT_EQ(d.GetColocationsCount(), 3); + EXPECT_THAT(d.GetColocations(), UnorderedElementsAre(&d, &e, &f)); + EXPECT_EQ(e.GetColocationsCount(), 3); + EXPECT_THAT(e.GetColocations(), UnorderedElementsAre(&d, &e, &f)); + EXPECT_EQ(f.GetColocationsCount(), 3); + EXPECT_THAT(f.GetColocations(), UnorderedElementsAre(&d, &e, &f)); +} + +TEST_F(SlicedPrefetchTest, UniformSizedSlicing) { + std::string hlo_text = R"zz( +HloModule Slice, is_scheduled=true + +ENTRY main { + p0 = f32[8,8] parameter(0) + p1 = f32[8,8] parameter(1) + p2 = f32[8,16] parameter(2) + constant1 = f32[] constant(1.1) + + a = f32[8,8] tanh(p0) + b = f32[8,8] tanh(a) + c = f32[8,8] tanh(b) + d = f32[8,8] tanh(c) + e = f32[8,8] tanh(d) + f = f32[8,8] tanh(e) + g = f32[8,8] tanh(f) + h = f32[8,8] tanh(g) + + x = f32[8,8] add(p1, h) + padded_x = f32[8,16] pad(x, constant1), padding=0_0x0_8 + ROOT r = f32[8,16] add(padded_x, p2) +})zz"; + const Shape f32_8_16 = ShapeUtil::MakeShape(F32, {8, 16}); + const Shape s8_128 = ShapeUtil::MakeShape(S8, {128}); + + options_.sliced_prefetch_options.set_max_slices(100000); + options_.sliced_prefetch_options.set_preferred_slice_size(4 * 8 * 4); + + EXPECT_CALL(slice_proposer_, + ProposeSlices(f32_8_8_, EqualsSlicedPrefetchOptions( + options_.sliced_prefetch_options))) + .WillRepeatedly(Return(SliceProposalCollection({ + SliceProposal( + {s8_128, std::vector({{0, 128}}), ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{128, 256}}), + ShapeSize(s8_128)}), + }))); + + EXPECT_CALL(slice_proposer_, + ProposeSlices(f32_8_16, EqualsSlicedPrefetchOptions( + options_.sliced_prefetch_options))) + .WillRepeatedly(Return(SliceProposalCollection({ + SliceProposal( + {s8_128, std::vector({{0, 128}}), ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{128, 256}}), + ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{256, 384}}), + ShapeSize(s8_128)}), + SliceProposal({s8_128, std::vector({{384, 512}}), + ShapeSize(s8_128)}), + }))); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + VLOG(1) << "Original module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + std::unique_ptr assignments = AssignMemorySpace( + module.get(), options_, + /*max_prefetch_interval=*/100, /*min_prefetch_interval=*/1); + + VLOG(1) << "Post-MSA module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + auto root = module->entry_computation()->root_instruction(); + + // Expect p1 to be asynchronously copied via 2 slices, and p2 to be + // asynchronously copied via 4 slices. We expect p1 and p2 to be bitcast + // before slicing and after slicing. + EXPECT_THAT( + root, + op::Add(op::Pad(op::Add(IsAsyncSlicedCopy( + kAlternateMemorySpace, kDefaultMemorySpace, + {{{0, 128}}, {{128, 256}}}, op::Parameter(1), + /*expect_bitcasted_io=*/true), + /*don't care*/ _), + /*padding constant*/ _), + IsAsyncSlicedCopy( + kAlternateMemorySpace, kDefaultMemorySpace, + {{{0, 128}}, {{128, 256}}, {{256, 384}}, {{384, 512}}}, + op::Parameter(2), /*expect_bitcasted_io=*/true))); + + // Check expectations on the chunks assigned to the asynchronous sliced copy. + TF_EXPECT_OK(CheckSliceChunks(*assignments, root->operand(1), + /*expect_bitcasted_io=*/true)); + TF_EXPECT_OK(CheckSliceChunks(*assignments, + root->operand(0)->operand(0)->operand(0), + /*expect_bitcasted_io=*/true)); +} + } // namespace +} // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/options.h b/xla/service/memory_space_assignment/options.h new file mode 100644 index 0000000000000..b49f30e08cedd --- /dev/null +++ b/xla/service/memory_space_assignment/options.h @@ -0,0 +1,270 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_OPTIONS_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_OPTIONS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_value.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" +#include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/memory_space_assignment/slice.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +using IsAllowedInAlternateMemoryFunction = std::function; +using IsUseAllowedInAlternateMemoryFunction = + std::function; +using IsPositionAllowedInAlternateMemoryFunction = + std::function; +using ReservedScopedMemoryFunction = std::function>& /*operands_in_alternate_memory*/, + const absl::flat_hash_set& /*outputs_in_alternate_memory*/)>; +using MsaBufferInterval = + GlobalDecreasingSizeBestFitHeap::BufferInterval; +using MsaBufferIntervalCompare = + GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; +using PositionRequiresContiguousAllocationFunction = + std::function; + +// The BufferInterval sorting interface that MemorySpaceAssignment expects. +class BufferIntervalComparator { + public: + virtual ~BufferIntervalComparator() = default; + + // A logging string explaining the sorting criteria. E.g., [ -size, offset ] + // indicates we sort (desc) size, then (asc) offset. + virtual std::string DescribeComparisonCriteria() const = 0; + + // A logging string containing the values used to sort buffer_interval. + // E.g., we might return [ -1024, 100 ], if the criteria is [ -size, + // offset ]. + virtual std::string CriteriaToString( + const MsaBufferInterval& buffer_interval) = 0; + + // comparator.LessThan(lhs, rhs) will be used for BufferIntervalCompare. + virtual bool LessThan(const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) = 0; + + // Used to create a functor that can be passed to a method like std::sort. + // E.g., absl::c_sort(v, comparator.GetComparisonFunctor()); + MsaBufferIntervalCompare GetComparisonFunctor() { + return [this](const MsaBufferInterval& lhs, const MsaBufferInterval& rhs) { + return LessThan(lhs, rhs); + }; + } + + protected: + BufferIntervalComparator() = default; +}; + +// The different options to be passed to the Run() API. +struct Options { + // Backend-specific integer value that describes the alternate memory. + int64_t alternate_memory_space = 0; + + // Maximum size of the alternate memory space. + int64_t max_size_in_bytes = 0; + + // Memory alignment of the alternate memory space. + int64_t alignment_in_bytes = 1; + + // If provided, we sort the buffers using this comparator. Otherwise, we use + // GlobalDecreasingSizeBestFitHeap::kSpatial. + BufferIntervalComparator* buffer_interval_comparator = nullptr; + + // This object determines how early and how late prefetches can occur. + PrefetchIntervalPicker* prefetch_interval_picker = nullptr; + + // This object is used to determine the benefit of a particular allocation. + CostAnalysis* cost_analysis = nullptr; + + // Size function for buffer values. + BufferValue::SizeFunction size_fn; + + std::function get_equivalent_s8_shape_fn; + + // This function can be used to prevent certain HloValues (e.g., based on + // the opcode) to be placed on the alternate memory. + IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn; + + // This function can be used to prevent certain HloUses (e.g., based on + // the opcode) to be placed on the alternate memory. + IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn = + [](const HloUse&) { return true; }; + + // Specifies if the given position is allowed in the alternate memory. + IsPositionAllowedInAlternateMemoryFunction + is_position_allowed_in_alternate_mem_fn = + [](const HloPosition&) { return true; }; + + // This function returns the amount of scoped memory in bytes that should be + // reserved during the execution of this instruction. + ReservedScopedMemoryFunction reserved_scoped_memory_fn = + [](const HloInstruction*, + const absl::flat_hash_set< + std::pair>& /*operands_in_alternate_memory*/, + const absl::flat_hash_set< + ShapeIndex>& /*outputs_in_alternate_memory*/) { return 0; }; + + PositionRequiresContiguousAllocationFunction + position_requires_contiguous_allocation_fn = + [](const HloPosition&) { return false; }; + + // If true, we will try to reduce scoped allocation buffer size for all + // instructions if their operand/output has been allocated in alternate + // memory. + bool reduce_scoped_memory_limit = false; + + // If true, we allocate the reserved scoped memory at the same offset. This + // is useful to enable more deduplication between HLOs that have reserved + // scoped memories, but may result in less efficient memory packing. + bool allocate_reserved_scoped_memory_at_same_offset = true; + + // Specifies the upper bound for number of outstanding prefetches and + // evictions, -1 for unlimited. + int64_t max_outstanding_prefetches = -1; + int64_t max_outstanding_evictions = -1; + + // Extra outstanding prefetch limit for while uses (in addition to + // max_outstanding_prefetches). + int64_t while_use_extra_outstanding_prefetch_limit = 0; + + // Specifies the maximum number of retries that will be performed for each + // value in case prefetching failed due to running out of asynchronous + // copies or asynchronous copy resource. + int64_t max_retries = 1; + + // The maximum number of repacks that we are willing to perform in case we + // can't allocate a buffer due to running out of memory. If this value is + // greater than 0, repacker must be non-nullptr. + int64_t max_repacks = 0; + + // The repacking algorithm to reduce fragmentation. Must be non-null if + // max_repacks is greater than 0. + MemorySpaceAssignmentRepacker* repacker = nullptr; + + // This is only useful for testing, repack after every allocation. + bool repack_after_every_allocation = false; + + // If true, tries allocating buffers across (e.g., before and inside a while + // loop body) sequential calls (kWhile, kCall, and kConditional). + bool allocate_across_sequential_calls = false; + + // If true, verifies the memory space assignment against overlapping + // buffers. + bool verify = false; + + // If not nullptr, this function is called to dump debugging information. + // The first argument is appended to the file name and the second argument + // is the contents of the file. + std::function dump_fn = nullptr; + + // Enable prefetching buffers into preferred memory across program + // boundaries + bool enable_cross_program_prefetch = true; + + // If true, use buffer_interval_compare to determine which buffers to + // prefetch across program boundaries. + bool default_cross_program_prefetch_heuristic = false; + + // Enable cross-program prefetch freeing optimization where the + // cross-program-prefetched buffer can be reused. + bool enable_cross_program_prefetch_freeing = true; + + // The maximum number of cross program prefetches. + // TODO(tjablin): Use a heuristic to determine this automatically. + int max_cross_program_prefetches = 1; + + // Enable redundant eviction optimization in/around while loops. If enabled, + // this optimization would keep a copy of the buffer in the default memory in + // addition to alternate memory to eliminate redundant evictions. + bool enable_while_redundant_eviction_elimination = true; + + // An optional memory space assignment autotuning config, which is used + // to sort allocated buffers. + std::optional> autotuning_config = std::nullopt; + + // If true, uses the earlier instance of the same instruction to use as + // preferred prefetch start time. + bool use_repeated_instance_for_preferred_prefetch_time = false; + + // If true, enforces the FIFO order for prefetches. + bool enforce_prefetch_fifo_order = false; + + // The ratio of use bytes to copy bytes for a given allocation site below + // which we consider the site to be inefficient. A value of 0 would treat all + // sites as efficient and a value of 1 would require the amount of bytes used + // at the site to be at least as much as the async copy bytes. There are two + // factors that determine the copy and use bytes: + // - Some uses don't actually access the entire tensor, e.g. in + // dynamic-update-slice. + // - copy_bytes may be larger than the size of the tensor as well. An + // example is a tensor may be prefetched, used, and then evicted. In that + // case copy_bytes would be twice the size of the tensor. + float inefficient_use_to_copy_ratio = 0.0; + + // This is mostly used for testing, it allows a test case to inject its own + // logic for AlternateMemoryBestFitHeap::GetInefficientAllocationSites. + std::function>( + absl::Span)> + get_inefficient_allocation_sites_fn = nullptr; + + // Config to filter prefetches and update preferred prefetch times for the + // filtered prefetches. + PreferredPrefetchOverrides preferred_prefetch_overrides; + + // Options for slicing prefetches into smaller asynchronously copied pieces. + SlicedPrefetchOptions sliced_prefetch_options; + + // Options for the memory-bound loop optimizer feature. + MemoryBoundLoopOptimizerOptions memory_bound_loop_optimizer_options; + + SliceProposalFunction propose_slice_fn = [](const Shape&, + const SlicedPrefetchOptions&) + -> absl::StatusOr { + return UnimplementedStrCat("Generation of SliceProposals unimplemented"); + }; + + // Option to always spill buffers from alternate memory to default memory + // and prefetching back to alternate memory(if needed) just in time for use. + bool always_spill_to_default_memory = false; +}; +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_OPTIONS_H_ diff --git a/xla/service/memory_space_assignment/prefetch_interval_picker.cc b/xla/service/memory_space_assignment/prefetch_interval_picker.cc new file mode 100644 index 0000000000000..ad63a509dc4fd --- /dev/null +++ b/xla/service/memory_space_assignment/prefetch_interval_picker.cc @@ -0,0 +1,553 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace memory_space_assignment { +namespace { + +// Each time we retry compilation, increase the preferred eviction end time by +// this amount multiplied by preferred overlap to async copy ratio. +const float kEvictionRetryMultiplier = 2.0; + +// The number of decreasing intervals for CostAnalysisPrefetchIntervalPicker to +// return when it runs out of increasing intervals. Increasing this number may +// hurt compilation time. +const int kNumExploredDecreasingIntervals = 100; + +} // namespace + +bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( + const Shape& shape, int64_t start_time, int64_t end_time) const { + return end_time - start_time <= max_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64_t start_time, int64_t latest_end_time) const { + return std::min(start_time + min_overlap_count_, latest_end_time); +} + +int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( + const Shape& shape, int64_t start_time, int64_t end_time, + const HloUse* use) const { + return end_time - min_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { + return std::max(earliest_prefetch_start_time, + prefetch_end_time - max_overlap_count_); +} + +int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime( + const Shape& shape, int64_t start_time, int64_t end_time) const { + // For testing, assume the end time is the estimated prefetch end time. + return end_time; +} + +float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed( + int64_t start_time, int64_t end_time) const { + // For testing, just assume every HLO takes 1 second. + return static_cast(end_time - start_time - 1); +} + +void InstructionCountPrefetchIntervalPicker::Begin( + const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) { + end_time_ = end_time; + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + if (preferred_time) { + current_prefetch_time_ = *preferred_time; + } else { + current_prefetch_time_ = + PreferredPrefetchStartTime(shape, start_time, end_time, end_time); + } +} + +int64_t InstructionCountPrefetchIntervalPicker::Next() { + CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " + "Done() is false"; + return current_prefetch_time_++; +} + +bool InstructionCountPrefetchIntervalPicker::Done() const { + return end_time_ - current_prefetch_time_ <= min_overlap_count_; +} + +int64_t InstructionCountPrefetchIntervalPicker::latest_time() const { + return end_time_ - min_overlap_count_ - 1; +} + +std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { + return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); +} + +std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( + const Shape& shape, int64_t start_time, int64_t end_time) const { + return absl::StrCat("Overlapped HLOs = ", end_time - start_time); +} + +CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( + const CostAnalysis& cost_analysis, float min_overlap_to_async_copy_ratio, + float preferred_overlap_to_async_copy_ratio, + float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, + const Shape* shape_override) + : while_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), + computation_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0), + cost_analysis_(cost_analysis), + min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio), + preferred_overlap_to_async_copy_ratio_( + preferred_overlap_to_async_copy_ratio), + max_async_copy_elapsed_( + cost_analysis_.GetAsyncCopyElapsed( + ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) * + max_overlap_to_mem_size_async_copy_ratio), + shape_override_(shape_override ? std::optional(*shape_override) + : std::nullopt) { + instruction_schedule_ = + &cost_analysis_.hlo_live_range().instruction_schedule(); + + // Create a vector of elapsed times and while nesting levels of HLO + // instructions. The elapsed times are multiplied by + // pow(while_execution_count, nest_level) to account for executing the HLOs + // multiple times in while loops. + std::vector instructions_elapsed_time( + instruction_schedule_->size() + 1, 0.0); + int max_while_nest_level = 0; + for (const auto& instruction_and_logical_time : *instruction_schedule_) { + // To avoid double counting, don't include the elapsed time of while and + // conditional HLOs. + const HloInstruction* instruction = instruction_and_logical_time.first; + int64_t logical_time = instruction_and_logical_time.second; + if (logical_time >= instructions_elapsed_time.size()) { + instructions_elapsed_time.resize(logical_time + 1, 0.0); + while_nest_level_.resize(logical_time + 1, 0); + } + int while_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/true); + while_nest_level_[logical_time] = while_nest_level; + max_while_nest_level = std::max(max_while_nest_level, while_nest_level); + int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/false); + computation_nest_level_[logical_time] = computation_nest_level; + if (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float elapsed_time = cost_analysis_.GetInstructionElapsed( + *instruction_and_logical_time.first); + instructions_elapsed_time[logical_time] = + elapsed_time * cost_analysis_.GetWhileNestMultiplier(while_nest_level); + } + // As an optimization, create a cumulative sum vector of elapsed time. + float cumsum = 0.0; + elapsed_time_cumsum_.reserve(instructions_elapsed_time.size()); + for (float elapsed_time : instructions_elapsed_time) { + cumsum += elapsed_time; + elapsed_time_cumsum_.push_back(cumsum); + } + // To be able to accurately determine the minimum nest level between a start + // time and an end time efficiently, populate a data structure that stores the + // closest 'smaller' nest level change index. + const int64_t size = instructions_elapsed_time.size(); + CHECK_EQ(size, while_nest_level_.size()); + std::vector most_recent_by_level(while_nest_level_.size(), -1); + int prev_nest_level = 0; + int change_idx = -1; + while_nest_level_change_.reserve(size); + for (int i = 0; i < size; ++i) { + int nest_level = while_nest_level_[i]; + if (nest_level != prev_nest_level) { + prev_nest_level = nest_level; + // Compute last change index by choosing the most recent instruction index + // with smaller nesting level. Note that it may happen that even though + // there were few different regions with other nest levels before, all of + // then are same or bigger than this one, in which case we'll end up with + // -1, e.g. if you got nest level 0 no need checking anything else. + change_idx = -1; + for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) { + change_idx = std::max(change_idx, most_recent_by_level[smaller_level]); + } + } + most_recent_by_level[nest_level] = i; + while_nest_level_change_.push_back(change_idx); + } + for (int i = 0; i <= max_while_nest_level; ++i) { + while_execution_counts_.push_back(cost_analysis_.GetWhileNestMultiplier(i)); + } +} + +float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory( + float async_copy_elapsed) const { + return max_async_copy_elapsed_; +} + +bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( + const Shape& shape, int64_t start_time, int64_t end_time) const { + // Even though this method returns if we allow the buffer in alternate memory + // _without_ asynchronous copies, calculate how long it would have taken to + // copy it and compare it to the elapsed time in the logical interval. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + return GetMaxElapsedInAlternateMemory(async_copy_elapsed) > + logical_interval_elapsed; +} + +int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( + const Shape& shape, int64_t start_time, int64_t latest_end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t end_time; + for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + if (logical_interval_elapsed >= + (1 + kEvictionRetryMultiplier * retry_number_) * + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) { + break; + } + } + return end_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( + const Shape& shape, int64_t start_time, int64_t end_time, + const HloUse* use) const { + // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + // If there is a use, estimate the time we would save by having this op in + // alternate memory. + float inst_elapsed_reduction = 0.0f; + if (use) { + float elapsed_time = + cost_analysis_.GetInstructionElapsed(*use->instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use->instruction, + /*operands_in_alternate_mem=*/ + {std::make_pair(use->operand_number, use->operand_index)}, + /*outputs_in_alternate_mem=*/{}); + inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + } + int end_nest_level = computation_nest_level_[end_time]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed; + int latest_prefetch_time; + for (latest_prefetch_time = end_time - 1; + latest_prefetch_time >= start_time && + (computation_nest_level_[latest_prefetch_time] != end_nest_level || + min_interval > + GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + + inst_elapsed_reduction); + --latest_prefetch_time) { + } + + return latest_prefetch_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const { + // Between the earliest and latest prefetch interval, find the interval + // closest to the preferred interval and start iterating from there. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t preferred_prefetch_start_time = earliest_prefetch_start_time; + float preferred_interval = + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed; + float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, + prefetch_end_time); + int end_nest_level = computation_nest_level_[prefetch_end_time]; + for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1; + prefetch_start_time <= latest_prefetch_start_time; + ++prefetch_start_time) { + float interval = + GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); + if (computation_nest_level_[prefetch_start_time] == end_nest_level && + std::abs(preferred_interval - interval) < + std::abs(preferred_interval - best_interval)) { + best_interval = interval; + preferred_prefetch_start_time = prefetch_start_time; + } + } + return preferred_prefetch_start_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const { + // Iterate towards the beginning until we find a suitable end time that is the + // same while nest level as the original prefetch end time. + int64_t original_nest_level = + computation_nest_level_[original_prefetch_end_time]; + int64_t new_prefetch_end_time; + for (new_prefetch_end_time = proposed_prefetch_end_time; + computation_nest_level_[new_prefetch_end_time] != original_nest_level; + --new_prefetch_end_time) { + } + return new_prefetch_end_time; +} + +int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime( + const Shape& shape, int64_t start_time, int64_t end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + int64_t estimated_end_time; + for (estimated_end_time = start_time + 1; estimated_end_time < end_time; + ++estimated_end_time) { + float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time); + if (interval >= async_copy_elapsed) { + break; + } + } + return estimated_end_time; +} + +void CostAnalysisPrefetchIntervalPicker::Begin( + const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) { + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_. + async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + // Estimate the time we would save by having this op in alternate memory. + float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use.instruction, /*operands_in_alternate_mem=*/ + {std::make_pair(use.operand_number, use.operand_index)}, + /*outputs_in_alternate_mem=*/{}); + inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; + end_logical_time_ = end_time; + int end_nest_level = computation_nest_level_[end_logical_time_]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_; + latest_prefetch_time_ = + LatestPrefetchStartTime(shape, start_time, end_time, &use); + + // Find the earliest time we're allowed to start prefetching. + float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_); + for (earliest_prefetch_time_ = start_time; + earliest_prefetch_time_ < latest_prefetch_time_ && + (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || + max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, + end_logical_time_)); + ++earliest_prefetch_time_) { + } + if (earliest_prefetch_time_ > latest_prefetch_time_) { + // There is no available prefetch interval for the given start and end + // times. Set the iterators accordingly to ensure Done() returns true. + increasing_prefetch_time_iterator_ = earliest_prefetch_time_; + decreasing_prefetch_time_iterator_ = latest_prefetch_time_; + CHECK(Done()); + return; + } + + int64_t starting_prefetch_time; + if (preferred_time && *preferred_time <= latest_prefetch_time_) { + starting_prefetch_time = *preferred_time; + } else { + starting_prefetch_time = + PreferredPrefetchStartTime(shape, earliest_prefetch_time_, + latest_prefetch_time_, end_logical_time_); + } + float preferred_interval = + preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_; + VLOG(4) << "Interval min/max/preferred = " << min_interval << " " + << max_interval << " " << preferred_interval + << " prefetch time earliest/latest/starting = " + << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " " + << starting_prefetch_time; + + increasing_prefetch_time_iterator_ = starting_prefetch_time; + decreasing_prefetch_time_iterator_ = starting_prefetch_time; + using_increasing_prefetch_time_iterator_ = true; + // Since both iterators start at the same position, call Next() once to + // advance one of the iterators. + Next(); +} + +int64_t CostAnalysisPrefetchIntervalPicker::Next() { + CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " + "Done() is false"; + if (using_increasing_prefetch_time_iterator_) { + int64_t prefetch_time = increasing_prefetch_time_iterator_++; + while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && + computation_nest_level_[increasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { + ++increasing_prefetch_time_iterator_; + } + if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = false; + } + return prefetch_time; + } else { + int64_t prefetch_time = decreasing_prefetch_time_iterator_--; + // As a compilation time optimization, reduce the number of intervals that + // this prefetch interval picker returns. When we run out of the increasing + // prefetch time iterator, only explore up to + // kNumExploredDecreasingIntervals intervals. To do that, calculate the + // 1/kNumExploredDecreasingIntervals of the elapsed time between the + // earliest prefetch time and the use, and decrement the iterator until the + // prefetch elapsed time is at least as large as this target value. This + // allows us to reduce the number of expensive heap fit and resource checks + // when the graph consists of a large number of fast-executing HLOs. + // + // Shown pictorially, assuming kNumExploredDecreasingIntervals = 3 and the + // numbers indicating the elapsed time of the HLOs, only the indicated + // options for prefetch start time would be explored: + // + // ---1---1---3---1---1---1---1---0---0---0---0---1---5---X + // ^ ^ ^ ^ + // Option3 Option2 Option1 Use + // (Earliest) + float next_target_interval_elapsed = 0; + if (increasing_prefetch_time_iterator_ > latest_prefetch_time_) { + next_target_interval_elapsed = + GetLogicalIntervalElapsed(prefetch_time, end_logical_time_) + + (GetLogicalIntervalElapsed(earliest_prefetch_time_, + end_logical_time_) / + kNumExploredDecreasingIntervals); + VLOG(3) << "Next target interval elapsed: " + << next_target_interval_elapsed; + } + while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && + (computation_nest_level_[decreasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_] || + GetLogicalIntervalElapsed(decreasing_prefetch_time_iterator_, + end_logical_time_) < + next_target_interval_elapsed)) { + --decreasing_prefetch_time_iterator_; + } + if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = true; + } + return prefetch_time; + } +} + +bool CostAnalysisPrefetchIntervalPicker::Done() const { + return increasing_prefetch_time_iterator_ > latest_prefetch_time_ && + decreasing_prefetch_time_iterator_ < earliest_prefetch_time_; +} + +int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const { + return latest_prefetch_time_; +} + +void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { + retry_number_ = retry_number; +} + +int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( + int64_t start_time, int64_t end_time) const { + int min_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + int change_idx = while_nest_level_change_[end_time]; + while (change_idx >= start_time) { + min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]); + change_idx = while_nest_level_change_[change_idx]; + } + return min_nest_level; +} + +float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( + int64_t start_time, int64_t end_time) const { + CHECK_LE(start_time, end_time); + if (start_time == end_time) { + return 0.0; + } + if (start_time < 0) { + start_time = 0; + } + // Since elapsed_time_cumsum_ is already weighed by the while loop nesting + // level, normalize the elapsed time by dividing with the nesting factor of + // the interval (start and end times). + int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); + return (elapsed_time_cumsum_[end_time - 1] - + elapsed_time_cumsum_[start_time]) / + while_execution_counts_[interval_while_nest_level]; +} + +std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { + int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ + ? increasing_prefetch_time_iterator_ + : decreasing_prefetch_time_iterator_; + float logical_interval_elapsed = GetLogicalIntervalElapsed( + current_logical_prefetch_time, end_logical_time_); + return absl::StrCat( + "Async copy elapsed (s) = ", async_copy_elapsed_, + ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, + ", logical interval elapsed (s) = ", logical_interval_elapsed, + ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, + ")"); +} + +std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( + const Shape& shape, int64_t start_time, int64_t end_time) const { + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed( + shape_override_ ? *shape_override_ : shape); + float logical_interval_elapsed = + GetLogicalIntervalElapsed(start_time, end_time); + return absl::StrCat( + "Async copy elapsed (s) = ", async_copy_elapsed, + ", logical interval elapsed (s) = ", logical_interval_elapsed); +} + +std::optional +CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { + return cost_analysis_.GetMemoryBoundedness(interval); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/xla/service/memory_space_assignment/prefetch_interval_picker.h b/xla/service/memory_space_assignment/prefetch_interval_picker.h new file mode 100644 index 0000000000000..0ae8af5307128 --- /dev/null +++ b/xla/service/memory_space_assignment/prefetch_interval_picker.h @@ -0,0 +1,292 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/util.h" + +namespace xla { +namespace memory_space_assignment { + +// Abstract base class that memory space assignment uses to pick prefetch +// intervals. +class PrefetchIntervalPicker { + public: + PrefetchIntervalPicker() = default; + virtual ~PrefetchIntervalPicker() = default; + + // Returns true if the buffer can be allocated in alternate memory space + // without any copies (prefetches). + virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Returns the preferred end time for an eviction that starts at a given time + // and must end by the given end time. + virtual int64_t PreferredEvictionEndTime(const Shape& shape, + int64_t start_time, + int64_t latest_end_time) const = 0; + + // Returns the latest time that a prefetch can start. + virtual int64_t LatestPrefetchStartTime(const Shape& shape, + int64_t start_time, int64_t end_time, + const HloUse* use) const = 0; + + // Returns the preferred time that a prefetch can start. + virtual int64_t PreferredPrefetchStartTime( + const Shape& shape, int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0; + + // Returns the latest time that a prefetch can end that is less than or equal + // to proposed_prefetch_end_time. + virtual int64_t LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const { + return proposed_prefetch_end_time; + } + + // Returns the estimated end time of a prefetch that starts at the given time. + virtual int64_t EstimatedPrefetchEndTime(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Returns the elapsed time in seconds between the logical interval that + // corresponds to the instruction schedule. + virtual float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const = 0; + + // Begins the iterator for the first start time of the prefetch. + virtual void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) = 0; + + // Advances the start time of the prefetch and returns that value. + virtual int64_t Next() = 0; + + // Returns true if the available prefetch intervals have been exhausted. + virtual bool Done() const = 0; + + // Returns the latest time the prefetch interval picker will have pick. + virtual int64_t latest_time() const = 0; + + // The retry number can be used to modify the interval picking policies. The + // first attempt will have a retry_number of 0, then 1, etc. + virtual void SetRetryNumber(int retry_number) { + retry_number_ = retry_number; + } + int retry_number() const { return retry_number_; } + + // Returns a debug string for the current state of the prefetch interval + // picker. + virtual std::string ToDebugString() const = 0; + + // Returns a debug string for no-copy allocation. + virtual std::string ToNoCopyDebugString(const Shape& shape, + int64_t start_time, + int64_t end_time) const = 0; + + // Prefetch interval pickers may return a value corresponding to the benefit + // of placing the BufferInterval in the alternate memory. The larger value, + // the more beneficial. + virtual std::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { + return std::nullopt; + } + + protected: + const absl::flat_hash_map* + instruction_schedule_ = nullptr; + int retry_number_ = 0; +}; + +// Prefetch interval picker that uses instruction count to overlap asynchronous +// copies with independent computation. The min and max overlap counts describe +// the number of independent HLOs overlapped while a value is being prefetched +// into the alternate memory (between CopyStart and CopyDone HLO instructions). +// max_overlap_count attempts to prevent bringing tensors into the alternate +// memory too eagerly and hence occupying the space for other tensors which +// might use it. min_overlap_count attempts to prevent cases where tensors are +// prefetched into the alternate memory without sufficient time for the copy to +// take place. In those cases, it's just better to keep the tensor in the +// default memory instead of hurting the critical path with this copy that +// likely won't finish in time. +class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { + public: + InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count, + int64_t max_overlap_count) + : min_overlap_count_(min_overlap_count), + max_overlap_count_(max_overlap_count) {} + + bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const override; + + int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, + int64_t latest_end_time) const override; + + int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, + int64_t end_time, + const HloUse* use) const override; + + int64_t PreferredPrefetchStartTime(const Shape& shape, + int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, + int64_t prefetch_end_time) const override; + + int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const override; + + void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) override; + + int64_t Next() override; + bool Done() const override; + + int64_t latest_time() const override; + + std::string ToDebugString() const override; + std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + + private: + int64_t min_overlap_count_; + int64_t max_overlap_count_; + int64_t end_time_; + int64_t current_prefetch_time_; +}; + +// Prefetch interval picker that uses cost analysis to overlap asynchronous +// copies with independent computation. It uses min (independent computation +// duration) / (asynchronous copy duration) ratio to guide whether the prefetch +// is within the lower bound. For the upper bound, it restricts the maximum +// duration that a buffer may occupy the alternate memory space as a multiple of +// the time it would take to copy a buffer that is the size of the alternate +// memory. It starts with the preferred ratio in Begin() and works its way for +// alternately earlier and later prefetches until hitting min and max ratios. +// The value for buffer size for max async copy is a mechanism to prevent +// copying small buffers between the two memories unnecessarily. For calculating +// the max time that the buffer can reside in alternate memory, we use the +// larger of this value and the actual size of the buffer. A shape override can +// also be provided which causes the interval picker to use that shape for async +// copy durations instead of the actual shape of the copy. +class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { + public: + CostAnalysisPrefetchIntervalPicker( + const CostAnalysis& cost_analysis, float min_overlap_to_async_copy_ratio, + float preferred_overlap_to_async_copy_ratio, + float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes, + const Shape* shape_override = nullptr); + + bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, + int64_t start_time, + int64_t end_time) const override; + + int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, + int64_t latest_end_time) const override; + + int64_t LatestPrefetchEndTime( + int64_t original_prefetch_end_time, + int64_t proposed_prefetch_end_time) const override; + + int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, + int64_t end_time, + const HloUse* use) const override; + + int64_t PreferredPrefetchStartTime(const Shape& shape, + int64_t earliest_prefetch_start_time, + int64_t latest_prefetch_start_time, + int64_t prefetch_end_time) const override; + + int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + float GetLogicalIntervalElapsed(int64_t start_time, + int64_t end_time) const override; + + void Begin(const HloUse& use, int64_t start_time, int64_t end_time, + std::optional preferred_time) override; + + int64_t Next() override; + bool Done() const override; + + int64_t latest_time() const override; + + void SetRetryNumber(int retry_number) override; + + std::string ToDebugString() const override; + std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, + int64_t end_time) const override; + + std::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const override; + + private: + // Finds the minimum nest level in the given interval. + int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const; + + // Given the elapsed time to copy this buffer to the alternate memory, returns + // the longest time that this buffer may reside in the alternate memory space. + float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const; + + // For each instruction in the flattened schedule, maintain their elapsed time + // (in cumulative sum) and while nesting level. + std::vector elapsed_time_cumsum_; + std::vector while_nest_level_; + std::vector computation_nest_level_; + // Maintain the index of the most recent (before this instruction) nest level + // change in order to efficiently determine the minimum nest level in an + // interval. + std::vector while_nest_level_change_; + + const CostAnalysis& cost_analysis_; + float min_overlap_to_async_copy_ratio_; + float preferred_overlap_to_async_copy_ratio_; + float max_async_copy_elapsed_; + float async_copy_elapsed_; + float inst_elapsed_reduction_; + int64_t end_logical_time_; + int64_t earliest_prefetch_time_; + int64_t latest_prefetch_time_; + bool using_increasing_prefetch_time_iterator_ = true; + int64_t increasing_prefetch_time_iterator_; + int64_t decreasing_prefetch_time_iterator_; + + std::vector while_execution_counts_; + // Shape override is used to override the shape of the shape of the async copy + // to treat all async copies the same duration. Having an override forces + // prefetches to be scheduled roughly in FIFO order. + std::optional shape_override_; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_PREFETCH_INTERVAL_PICKER_H_ diff --git a/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc b/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc new file mode 100644 index 0000000000000..7b8cac3fcab70 --- /dev/null +++ b/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc @@ -0,0 +1,406 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/testing_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { +namespace { + +constexpr int64_t kPointerSize = 8; + +int64_t ShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + a = f32[2,4] negate(param0) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[2,4] negate(n) + p = f32[2,4] negate(o) + q = f32[2,4] negate(p) + r = f32[2,4] negate(q) + s = f32[2,4] negate(r) + t = f32[2,4] negate(s) + u = f32[2,4] negate(t) + ROOT v = f32[2,4] add(u, param0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/4.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22, std::nullopt); + + // Expect that the first interval is (15, 22), which has elapsed time of 6.0, + // twice of the async copy elased (3.0). Then we expect that intervals will be + // visited in alternating increasing and decreasing orders until hitting the + // min and max async copy overlap ratios, which are the intervals (18, 22) + // and (9, 22) respectively. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 15); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 16); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 14); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 13); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 12); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 11); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 10); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); + + // Expect that if the time between start_time and end_time is too short, there + // won't be any available intervals. + interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22, std::nullopt); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition { + param1 = (f32[2,4]) parameter(0) // 19 + ROOT cond = pred[] constant(true) // 20 + } + + while_body { + param2 = (f32[2,4]) parameter(0) // 21 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 + add = f32[2,4] add(gte2, gte2) // 23 + ROOT tuple2 = (f32[2,4]) tuple(add) // 24 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + d = f32[2,4] negate(c) // 4 + e = f32[2,4] negate(d) // 5 + f = f32[2,4] negate(e) // 6 + g = f32[2,4] negate(f) // 7 + h = f32[2,4] negate(g) // 8 + i = f32[2,4] negate(h) // 9 + j = f32[2,4] negate(i) // 10 + k = f32[2,4] negate(j) // 11 + l = f32[2,4] negate(k) // 12 + m = f32[2,4] negate(l) // 13 + n = f32[2,4] negate(m) // 14 + o = f32[2,4] negate(n) // 15 + p = f32[2,4] negate(o) // 16 + q = f32[2,4] negate(p) // 17 + tuple = (f32[2,4]) tuple(q) // 18 + while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 + r = f32[2,4] negate(gte1) // 27 + s = f32[2,4] negate(r) // 28 + t = f32[2,4] negate(s) // 29 + u = f32[2,4] negate(t) // 30 + ROOT v = f32[2,4] add(u, param0) // 31 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + EXPECT_EQ(cost_analysis->GetWhileNestMultiplier(1), 5.0); + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31, std::nullopt); + + // Because there are while loop computations between [19, 24], we ensure that + // the interval picker avoids this interval. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 25); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 26); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { + // This test is to check against a bug where we didn't assign + // while_nest_level_ for while instructions, and defaulting to 0. This could + // cause the prefetch interval logic to think a nested while instruction is + // the same level as the outermost computation. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition.2 { + param1 = (f32[2,4]) parameter(0) // 11 + ROOT cond = pred[] constant(true) // 12 + } + + while_body.2 { + param2 = (f32[2,4]) parameter(0) // 13 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14 + add = f32[2,4] add(gte2, gte2) // 15 + ROOT tuple2 = (f32[2,4]) tuple(add) // 16 + } + + while_condition.1 { + param3 = (f32[2,4]) parameter(0) // 5 + ROOT cond = pred[] constant(true) // 6 + } + + while_body.1 { + param4 = (f32[2,4]) parameter(0) // 7 + gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8 + add1 = f32[2,4] add(gte1, gte1) // 9 + tuple1 = (f32[2,4]) tuple(add1) // 10 + while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17 + gte2 = f32[2,4] get-tuple-element(while), index=0 // 18 + add2 = f32[2,4] add(gte2, gte2) // 19 + ROOT tuple2 = (f32[2,4]) tuple(add2) // 20 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + tuple = (f32[2,4]) tuple(c) // 4 + while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 22 + ROOT root = f32[2,4] add(gte1, param0) // 23 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + const Shape& shape = root->operand(1)->shape(); + + // We expect the root's latest prefetch start time to be before the while loop + // (logical time 4). + EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/23, &use), + 4); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { + // This is a test for b/170668492, where prefetching for consecutive + // conditionals can cause the prefetch to start in the conditional's + // computation. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + true_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 5 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 + ROOT neg1 = f32[3]{0} negate(gte) // 7 + } + + false_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 8 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 + ROOT neg2 = f32[3]{0} negate(gte) // 10 + } + + true_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 12 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 + ROOT neg1 = f32[3]{0} negate(gte) // 14 + } + + false_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 15 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 + ROOT neg2 = f32[3]{0} negate(gte) // 17 + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) // 0 + p1 = f32[3]{0} parameter(1) // 1 + p2 = pred[] parameter(2) // 2 + tuple0 = (f32[3]{0}) tuple(p0) // 3 + tuple1 = (f32[3]{0}) tuple(p1) // 4 + conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 + conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 + ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + LOG(INFO) << module->ToString(); + + HloInstruction* conditional1 = + module->entry_computation()->GetInstructionWithName("conditional1"); + const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; + const Shape& shape = + module->entry_computation()->parameter_instruction(0)->shape(); + + // Expect that the prefetch to start before conditional0's called + // computations. + EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/11, &use), + 5); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { + // This tests the scenario where there is an op that takes a long time (tanh + // in this example) and as a result the earliest and latest times both fall + // inside this long-running op. In this case, we should still return a valid + // prefetch interval just before the long-running op. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + negate = f32[2,4] negate(param0) + tanh = f32[2,4] tanh(param0) + ROOT add = f32[2,4] add(tanh, negate) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + CostAnalysisOptions options; + TF_ASSERT_OK_AND_ASSIGN( + auto cost_analysis, + FakeCostAnalysis::Create(hlo_cost_analysis, *module, options)); + cost_analysis->SetOverrideForGetInstructionElapsed( + [](const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kTanh) { + return 20.0; + } + return 1.0; + }); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_overlap_to_async_copy_ratio=*/1.0, + /*preferred_overlap_to_async_copy_ratio=*/2.0, + /*max_overlap_to_mem_size_async_copy_ratio=*/12.0, + /*mem_size_bytes=*/32); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3, std::nullopt); + + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_FALSE(interval_picker.Done()); + EXPECT_EQ(interval_picker.Next(), 1); + EXPECT_TRUE(interval_picker.Done()); +} + +} // namespace +} // namespace memory_space_assignment +} // namespace xla diff --git a/xla/service/memory_space_assignment/repacking.h b/xla/service/memory_space_assignment/repacking.h index 1556bd390c170..095fd8ded056e 100644 --- a/xla/service/memory_space_assignment/repacking.h +++ b/xla/service/memory_space_assignment/repacking.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,17 +17,10 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ #include -#include -#include -#include -#include -#include "absl/algorithm/container.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/service/heap_simulator/allocation_block.h" #include "xla/statusor.h" -#include "xla/types.h" namespace xla { namespace memory_space_assignment { @@ -39,121 +32,11 @@ class MemorySpaceAssignmentRepacker { : max_size_(max_size), alignment_(alignment) {} virtual ~MemorySpaceAssignmentRepacker() = default; - // Data about a slice in a sliced allocation. - struct Slice { - int64_t size; - int64_t offset; - int64_t inclusive_start_time; - - std::string ToString() const { - return absl::StrCat("{ size: ", size, ", offset: ", offset, - ", inclusive_start_time: ", inclusive_start_time, - " }"); - } - - std::tuple ToTuple() const { - return std::make_tuple(size, offset, inclusive_start_time); - } - - bool operator==(const Slice& rhs) const { - return ToTuple() == rhs.ToTuple(); - } - }; - - // Slice data about a sliced allocation. - struct SlicedAllocationData { - std::vector slices_sorted_by_offset; - - std::vector SizesSortedByOffset() const { - std::vector sizes_sorted_by_offset; - sizes_sorted_by_offset.reserve(slices_sorted_by_offset.size()); - absl::c_for_each(slices_sorted_by_offset, - [&sizes_sorted_by_offset](const Slice& slice) { - sizes_sorted_by_offset.push_back(slice.size); - }); - return sizes_sorted_by_offset; - } - - std::vector SortedInclusiveStartTimes() const { - std::vector sorted_inclusive_start_times; - sorted_inclusive_start_times.reserve(slices_sorted_by_offset.size()); - absl::c_for_each(slices_sorted_by_offset, [&sorted_inclusive_start_times]( - const Slice& slice) { - sorted_inclusive_start_times.push_back(slice.inclusive_start_time); - }); - absl::c_sort(sorted_inclusive_start_times); - return sorted_inclusive_start_times; - } - - std::string ToString() const { - return absl::StrCat( - "{ slices_sorted_by_offset: [ ", - absl::StrJoin(slices_sorted_by_offset, ", ", - [](std::string* out, const Slice& slice) { - absl::StrAppend(out, slice.ToString()); - }), - " ] }"); - } - - bool operator==(const SlicedAllocationData& rhs) const { - return slices_sorted_by_offset == rhs.slices_sorted_by_offset; - } - }; - - // A contiguous block of allocation consisting of start and end (logical) - // times, size, and the initial offset. After repacking, if the repacking was - // successful and the allocations were modified, the offset field holds the - // new offset. To support aliased allocations, AllocationBlock also includes a - // vector of AllocationBlock pointers, called colocations. All AllocationBlock - // objects within the colocations must get the same offset. The id should be - // unique and is used to ensure determinism for comparison tie-breaker. - // - // Each AllocationBlock can be treated as an allocation that requires size - // space from start_time to end_time. However, some allocations are really - // composed of slices. In such cases, the repacker can utilize - // the information in the original_slice_data field to achieve an even more - // efficient repacking. - struct AllocationBlock { - int64_t inclusive_start_time; - int64_t end_time; - int64_t size; - int64_t offset; - int64_t initial_offset; - int64_t id; - std::vector colocations; - - // Optional data structures that are used to improve repacking, when an - // allocation is sliced, e.g., from a sliced prefetch. - std::optional original_slice_data; - std::optional repacked_slice_data; - - std::string ToString() const { - std::string original_slicing_str; - if (original_slice_data.has_value()) { - original_slicing_str = absl::StrCat("; original_slice_data: ", - original_slice_data->ToString()); - } - std::string repacked_slicing_str; - if (repacked_slice_data.has_value()) { - repacked_slicing_str = absl::StrCat("; repacked_slice_data: ", - repacked_slice_data->ToString()); - } - return absl::StrCat("[", inclusive_start_time, ", ", end_time, - "]; size: ", size, "; offset: ", offset, - "; initial offset: ", initial_offset, - "; # colocations: ", colocations.size(), - original_slicing_str, repacked_slicing_str); - } - - // This is required by BufferIntervalCompare as a tie breaker. Use a unique - // and deterministic id. - bool operator<(const AllocationBlock& other) const { return id < other.id; } - }; - // Repack the AllocationBlocks provided in the parameter. Returns true if // allocations have been modified and false if not. Returns a non-ok status if // there was an error. - virtual StatusOr Repack(absl::Span allocations) = 0; + virtual absl::StatusOr Repack( + absl::Span allocations) = 0; protected: int64_t max_size_; diff --git a/xla/service/memory_space_assignment/slice.cc b/xla/service/memory_space_assignment/slice.cc new file mode 100644 index 0000000000000..e550e965f804f --- /dev/null +++ b/xla/service/memory_space_assignment/slice.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/slice.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/shape.h" + +namespace xla::memory_space_assignment { + +std::tuple +SliceDecisionToTuple(const SliceDecision& decision) { + return std::make_tuple( + std::ref(decision.chunk), decision.exclusive_start_time, + std::ref(decision.sizing), decision.copy_resource_consumed); +} + +std::string SliceDecision::ToString() const { + return absl::StrCat("{ chunk: ", chunk.ToString(), + ", (exclusive) start_time: ", exclusive_start_time, + ", sizing: ", sizing.ToString(), + ", copy_resource_consumed: ", copy_resource_consumed, + " }"); +} + +bool SliceDecision::operator==(const SliceDecision& other) const { + return SliceDecisionToTuple(*this) == SliceDecisionToTuple(other); +} + +std::string SliceProposal::ToString() const { + return absl::StrCat( + "{ slice_shape: ", slice_shape.ToString(true), ", slice_params: { ", + absl::StrJoin(slice_params, ", ", + [](std::string* out, const SliceParam& param) { + absl::StrAppend(out, param.ToString()); + }), + " }, slice_size: ", slice_size, " }"); +} + +std::ostream& operator<<(std::ostream& os, const SliceProposal& proposal) { + os << proposal.ToString(); + return os; +} + +std::tuple&, int64_t> +SliceProposal::ToTuple() const { + return std::make_tuple(std::ref(slice_shape), std::ref(slice_params), + slice_size); +} + +bool SliceProposal::operator==(const SliceProposal& other) const { + return ToTuple() == other.ToTuple(); +} + +std::string SliceParam::ToString() const { + return absl::StrCat("[", start_inclusive, ",", end_exclusive, ")"); +} + +bool SliceParam::operator==(const SliceParam& other) const { + return start_inclusive == other.start_inclusive && + end_exclusive == other.end_exclusive; +} + +bool IsUniformSliceSizingEnabled(const SlicedPrefetchOptions& options) { + return options.max_slices() > 0 && options.preferred_slice_size() > 0; +} + +} // namespace xla::memory_space_assignment diff --git a/xla/service/memory_space_assignment/slice.h b/xla/service/memory_space_assignment/slice.h new file mode 100644 index 0000000000000..3d1fe279e36fe --- /dev/null +++ b/xla/service/memory_space_assignment/slice.h @@ -0,0 +1,120 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains definitions for MSA slicing. Slicing is an allocation +// technique in which we allocate a buffer in slices that can start at different +// times, but once allocated, form a contiguous buffer. When copying buffers, we +// may want to allocate a buffer in slices, so that we delay allocating memory +// that would otherwise not be in use, due to copy bandwidth constraints. +// +// The following illustrates a buffer that is fully allocated at time t2, via +// slices. +// +// space +// ^ +// p3 | +-----------+ +// | | s2 | +// p2 | +---+-----------+ +// | | s1 | +// p1 | +-------+-------+ +// | | s0 | +// p0 | +-------+ +// +---|---|---|---|---|----> time +// t0 t1 t2 t3 t4 + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ + +#include +#include +#include +#include +#include + +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla::memory_space_assignment { + +// The target of a custom call that slicing uses to concatenate slices +// that are already contiguous in memory, into a larger buffer. +inline constexpr char kConcatBitcastCustomCall[] = "ConcatBitcast"; + +// The parameters for slicing a single dimension of a tensor. +struct SliceParam { + std::string ToString() const; + bool operator==(const SliceParam& other) const; + + int64_t start_inclusive; + int64_t end_exclusive; +}; + +// A proposed way to slice a buffer. +struct SliceProposal { + std::string ToString() const; + friend std::ostream& operator<<(std::ostream& os, + const SliceProposal& proposal); + std::tuple&, int64_t> ToTuple() + const; + bool operator==(const SliceProposal& other) const; + + // Shape resulting from the slice. + Shape slice_shape; + + // slice_params map to the parameters that would be passed to a slice + // instruction. Thus: + // * There should be a slice parameter for every dimension in the shape of + // the tensor being sliced. + // * The ith slice_param applies to the ith logical dimension in the shape + // being sliced. + // * If a dimension is not being sliced, it should have a SliceParam of + // {0, dim size}. + std::vector slice_params; + + // The size to be allocated for the slice. Note, this may be > the size of + // the slice shape, due to additional padding that may occur when the slices + // are concatenated back together. + int64_t slice_size; +}; + +// A SliceProposalCollection proposes a way to to slice an AllocationRequest. +// A SliceProposalCollection is generated from a SliceProposalFunction and is +// used when we want to slice a prefetch. +using SliceProposalCollection = std::vector; +using SliceProposalFunction = + std::function( + const Shape& shape, const SlicedPrefetchOptions& options)>; + +// A SliceDecision is a SliceProposal that we've determined where and when to +// allocate. +struct SliceDecision { + std::string ToString() const; + bool operator==(const SliceDecision& other) const; + + HeapSimulator::Chunk chunk; + int64_t exclusive_start_time; + SliceProposal sizing; + float copy_resource_consumed; +}; + +// Returns true if the options indicates that there is a preferred slice +// size. +bool IsUniformSliceSizingEnabled(const SlicedPrefetchOptions& options); + +} // namespace xla::memory_space_assignment + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ diff --git a/xla/service/memory_space_assignment/testing_utils.h b/xla/service/memory_space_assignment/testing_utils.h new file mode 100644 index 0000000000000..ccea37b88b470 --- /dev/null +++ b/xla/service/memory_space_assignment/testing_utils.h @@ -0,0 +1,128 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace memory_space_assignment { + +// For testing purposes, we define a cost analysis where we can control the +// elapsed times of each HLO and asynchronous copy. +class FakeCostAnalysis : public CostAnalysis { + public: + static absl::StatusOr> Create( + const HloCostAnalysis& cost_analysis, const HloModule& module, + const CostAnalysisOptions& options) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + return absl::WrapUnique( + new FakeCostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph))); + } + + float GetInstructionElapsed( + const HloInstruction& instruction) const override { + if (get_instruction_elapsed_override_) { + return get_instruction_elapsed_override_(instruction); + } + return 1.0; + } + + float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::Span> + operands_in_alternate_mem, + absl::Span outputs_in_alternate_mem) const override { + if (get_instruction_elapsed_in_alternate_memory_override_) { + return get_instruction_elapsed_in_alternate_memory_override_( + instruction, operands_in_alternate_mem, outputs_in_alternate_mem); + } + if (!operands_in_alternate_mem.empty()) { + return 0.5; + } else { + return 1.0; + } + } + + float GetAsyncCopyElapsed(const Shape& shape) const override { + if (get_async_copy_elapsed_override_) { + return get_async_copy_elapsed_override_(shape); + } + return 3.0; + } + + // The following methods can be used to override what the above API calls + // return. + void SetOverrideForGetInstructionElapsed( + std::function function) { + get_instruction_elapsed_override_ = function; + } + void SetOverrideForGetInstructionElapsedInAlternateMemory( + std::function>, + absl::Span)> + function) { + get_instruction_elapsed_in_alternate_memory_override_ = function; + } + void SetOverrideForGetAsyncCopyElapsed( + std::function function) { + get_async_copy_elapsed_override_ = function; + } + + protected: + FakeCostAnalysis(const HloCostAnalysis& cost_analysis, + const CostAnalysisOptions& options, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : CostAnalysis(cost_analysis, options, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph)) {} + + private: + std::function + get_instruction_elapsed_override_ = nullptr; + std::function>, + absl::Span)> + get_instruction_elapsed_in_alternate_memory_override_ = nullptr; + std::function get_async_copy_elapsed_override_ = nullptr; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TESTING_UTILS_H_ diff --git a/xla/service/memory_space_assignment/tuning_utils.cc b/xla/service/memory_space_assignment/tuning_utils.cc index b24409dcd4789..f039b7199a283 100644 --- a/xla/service/memory_space_assignment/tuning_utils.cc +++ b/xla/service/memory_space_assignment/tuning_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/memory_space_assignment/tuning_utils.h b/xla/service/memory_space_assignment/tuning_utils.h index 749b4445e4be9..86354591562af 100644 --- a/xla/service/memory_space_assignment/tuning_utils.h +++ b/xla/service/memory_space_assignment/tuning_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_TUNING_UTILS_H_ #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" namespace xla { diff --git a/xla/service/memory_space_assignment/utils.cc b/xla/service/memory_space_assignment/utils.cc index 2b1eff9a07486..1f7321d0f1888 100644 --- a/xla/service/memory_space_assignment/utils.cc +++ b/xla/service/memory_space_assignment/utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/hlo_value.h" namespace xla { namespace memory_space_assignment { bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( - const HloValue* value) { + const HloValue* value, int64_t alternate_memory_space) { // If the buffer is a tuple, don't use this algorithm for now. The buffers // that are pointed to by the tuple will still use this algorithm. Because // tuples are cheap to place in the alternate memory (they are just pointers) @@ -66,19 +67,19 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( return false; } - // WARNING (b/259460539): output_to_operand_aliasing was moved from - // HloCustomCallInstruction to HloCallableInstruction so that fusions can - // also be annotated with this aliasing. This feature might not be complete. - if (auto* callable = - DynCast(position.instruction)) { - for (const auto& pair : callable->output_to_operand_aliasing()) { - if (position.index == pair.first) { - VLOG(4) << "Keeping value " << value->ToShortString() - << " in default mem because it is a custom-call/fusion output" - " that aliases an operand buffer."; - return false; - } - } + // If the tensor is pre-colored to a memory space that is neither the + // default (0) nor the alternate, disallow it from the alternate memory + // space. + int64_t memory_space = 0; + if (position.shape().has_layout()) { + memory_space = position.shape().layout().memory_space(); + } + if (memory_space != 0 && memory_space != alternate_memory_space) { + VLOG(4) << "Value " << value->ToShortString() + << " not allowed in the alternate memory space due to existing " + "memory space: " + << memory_space; + return false; } } @@ -86,9 +87,15 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( } bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { - return IsValueAllowedInAlternateMemory(interval.buffer) && - absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + int64_t alternate_memory_space) { + return IsValueAllowedInAlternateMemory(interval.buffer, + alternate_memory_space) && + absl::c_all_of(interval.colocations, + [alternate_memory_space](const HloValue* value) { + return IsValueAllowedInAlternateMemory( + value, alternate_memory_space); + }); } } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/utils.h b/xla/service/memory_space_assignment/utils.h index c2d76208a28e1..eced50e6dcab3 100644 --- a/xla/service/memory_space_assignment/utils.h +++ b/xla/service/memory_space_assignment/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ -#include "xla/service/heap_simulator.h" +#include "xla/service/heap_simulator/heap_simulator.h" namespace xla { namespace memory_space_assignment { @@ -27,11 +27,12 @@ class MemorySpaceAssignmentUtils { // Returns true if this buffer is allowed to be placed in the alternate // memory. static bool IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& - interval); + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + int64_t alternate_memory_space); // Returns true if the HloValue is allowed to be placed in alternate memory. - static bool IsValueAllowedInAlternateMemory(const HloValue* value); + static bool IsValueAllowedInAlternateMemory(const HloValue* value, + int64_t alternate_memory_space); }; } // namespace memory_space_assignment diff --git a/xla/service/memory_space_propagation.cc b/xla/service/memory_space_propagation.cc index e203a1ddb3765..fa8bf710118c2 100644 --- a/xla/service/memory_space_propagation.cc +++ b/xla/service/memory_space_propagation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,14 @@ limitations under the License. #include "xla/service/memory_space_propagation.h" +#include + +#include "xla/shape.h" +#include "xla/shape_util.h" + namespace xla { -StatusOr MemorySpacePropagation::Run( +absl::StatusOr MemorySpacePropagation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool modified = false; @@ -37,24 +42,24 @@ StatusOr MemorySpacePropagation::Run( // Propagate the operand subshapes. for (int operand_idx = 0; operand_idx < instruction->operand_count(); ++operand_idx) { - for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes( - instruction->operand(operand_idx)->shape())) { - int64_t memory_space = indexed_shape.shape.layout().memory_space(); - modified |= Propagate(indexed_shape.index, - instruction->fused_parameter(operand_idx), - memory_space); - } + ShapeUtil::ForEachLeafShape( + instruction->operand(operand_idx)->shape(), + [&](const Shape& sub_shape, const ShapeIndex& index) { + int64_t memory_space = sub_shape.layout().memory_space(); + modified |= + Propagate(index, instruction->fused_parameter(operand_idx), + memory_space); + }); } // Propagate output subshapes. - for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes(instruction->shape())) { - int64_t memory_space = indexed_shape.shape.layout().memory_space(); - modified |= - Propagate(indexed_shape.index, - instruction->fused_expression_root(), memory_space); - } + ShapeUtil::ForEachLeafShape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& index) { + int64_t memory_space = sub_shape.layout().memory_space(); + modified |= Propagate(index, instruction->fused_expression_root(), + memory_space); + }); } } } diff --git a/xla/service/memory_space_propagation.h b/xla/service/memory_space_propagation.h index 1038ac146ec3e..4de4b2e47ca1a 100644 --- a/xla/service/memory_space_propagation.h +++ b/xla/service/memory_space_propagation.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ class MemorySpacePropagation : public HloModulePass { ~MemorySpacePropagation() override = default; absl::string_view name() const override { return "memory-space-propagation"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/memory_space_propagation_test.cc b/xla/service/memory_space_propagation_test.cc index dbb87c67b51cb..da2059817cd4d 100644 --- a/xla/service/memory_space_propagation_test.cc +++ b/xla/service/memory_space_propagation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/metrics.proto b/xla/service/metrics.proto index 910616b35dcff..90325b70fcc6f 100644 --- a/xla/service/metrics.proto +++ b/xla/service/metrics.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package xla; +import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; @@ -13,6 +14,10 @@ message PassMetrics { string pass_name = 2; // Duration of the pass. google.protobuf.Duration pass_duration = 3; + // Custom pass metrics. This is kept opaque, via `google.protobuf.Any`, in + // order to decouple pass agnostic compilation logs from possibly proprietary + // compiler passes. + google.protobuf.Any custom_metrics = 4; } // Defines XLA compilation metrics. diff --git a/xla/service/metrics_hook_interface.h b/xla/service/metrics_hook_interface.h index bf36461137376..d8c0a45d0a402 100644 --- a/xla/service/metrics_hook_interface.h +++ b/xla/service/metrics_hook_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/multi_output_fusion.cc b/xla/service/multi_output_fusion.cc index 3af2e97fc9660..779a292ac4334 100644 --- a/xla/service/multi_output_fusion.cc +++ b/xla/service/multi_output_fusion.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,13 +29,18 @@ limitations under the License. namespace xla { -StatusOr MultiOutputFusion::Run( +absl::StatusOr MultiOutputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; for (auto* computation : module->MakeNonfusionComputations(execution_threads)) { + // Do not operate over async computations (computations of async + // instructions). + if (computation->IsAsyncComputation()) { + continue; + } computation_ = computation; candidates_.clear(); candidates_index_.clear(); diff --git a/xla/service/multi_output_fusion.h b/xla/service/multi_output_fusion.h index 8e120c2953f92..8add1bb17cf9c 100644 --- a/xla/service/multi_output_fusion.h +++ b/xla/service/multi_output_fusion.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,7 +56,7 @@ class MultiOutputFusion : public HloModulePass { // Run multi-output fusion on the given module. Returns whether the module // was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/name_uniquer.cc b/xla/service/name_uniquer.cc index 04eb2a9cd597a..6fb7351251b57 100644 --- a/xla/service/name_uniquer.cc +++ b/xla/service/name_uniquer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/name_uniquer.h b/xla/service/name_uniquer.h index 88adc5ce0b6a3..5f0a83f148920 100644 --- a/xla/service/name_uniquer.h +++ b/xla/service/name_uniquer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/name_uniquer_test.cc b/xla/service/name_uniquer_test.cc index 6fb93523155ab..6ebdfffedb73d 100644 --- a/xla/service/name_uniquer_test.cc +++ b/xla/service/name_uniquer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/op_expander_pass.cc b/xla/service/op_expander_pass.cc index 4e557dfc15eef..318211dce1f08 100644 --- a/xla/service/op_expander_pass.cc +++ b/xla/service/op_expander_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,21 @@ limitations under the License. #include "xla/service/op_expander_pass.h" -#include +#include +#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_creation_utils.h" -#include "xla/statusor.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { -StatusOr OpExpanderPass::Run( +absl::StatusOr OpExpanderPass::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::vector matching_instructions; @@ -45,7 +49,11 @@ StatusOr OpExpanderPass::Run( if (expanded_root == nullptr) { continue; } - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + TF_ASSIGN_OR_RETURN(bool changed, + inst->parent()->ReplaceInstruction( + inst, expanded_root, preserve_sharding_, + relay_control_dependency_)); + DCHECK(changed); } return !matching_instructions.empty(); diff --git a/xla/service/op_expander_pass.h b/xla/service/op_expander_pass.h index ff99b1dd51b91..c86c3a44f5563 100644 --- a/xla/service/op_expander_pass.h +++ b/xla/service/op_expander_pass.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,14 +28,20 @@ namespace xla { class OpExpanderPass : public HloModulePass { public: using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; // extra_filter: Optional extra filtering criteria for matching instructions, // used in conjunction with InstructionMatchesPattern. - explicit OpExpanderPass(HloPredicate extra_filter = nullptr) - : extra_filter_(std::move(extra_filter)) {} + // preserve_sharding and relay_control_dependency: If we preserve sharding and + // relay control dependency when replacing the matched instructions. + explicit OpExpanderPass(HloPredicate extra_filter = nullptr, + bool preserve_sharding = false, + bool relay_control_dependency = false) + : extra_filter_(std::move(extra_filter)), + preserve_sharding_(preserve_sharding), + relay_control_dependency_(relay_control_dependency) {} protected: // Returns `true` if `instruction` should be expanded by this pass. @@ -44,10 +50,12 @@ class OpExpanderPass : public HloModulePass { // Returns a replacement for `instruction`, or nullptr if no replacement is // needed (e.g. only the to_apply subcomputation of the instruction was // modified). - virtual StatusOr ExpandInstruction( + virtual absl::StatusOr ExpandInstruction( HloInstruction* instruction) = 0; HloPredicate extra_filter_; + const bool preserve_sharding_; + const bool relay_control_dependency_; }; } // namespace xla diff --git a/xla/service/operand_upcaster.cc b/xla/service/operand_upcaster.cc index e954629b4efec..81c30c88fc2fc 100644 --- a/xla/service/operand_upcaster.cc +++ b/xla/service/operand_upcaster.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,22 +17,32 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -StatusOr> MaybeInferShape( +absl::StatusOr> MaybeInferShape( const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kDot: return ShapeInference::InferDotOpShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), instruction->dot_dimension_numbers(), - /*preferred_element_type=*/std::nullopt); + /*preferred_element_type=*/std::nullopt, + Cast(instruction)->sparsity()); case HloOpcode::kConvolution: return ShapeInference::InferConvolveShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), @@ -59,19 +69,17 @@ bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) { return true; } - const Shape& inferred_shape = status_or_inferred_shape.value().value(); - if (inferred_shape.element_type() == instruction->shape().element_type() && - absl::c_all_of(instruction->operands(), - [&](const HloInstruction* operand) { - return operand->shape().element_type() == - inferred_shape.element_type(); - })) { + PrimitiveType inferred_type = (*status_or_inferred_shape)->element_type(); + if (instruction->shape().element_type() == inferred_type && + instruction->operand(0)->shape().element_type() == inferred_type && + instruction->operand(1)->shape().element_type() == inferred_type) { return false; } - return ShapeUtil::ElementCanUpcast(inferred_shape, instruction->shape()); + return ShapeUtil::ElementCanUpcast(**status_or_inferred_shape, + instruction->shape()); } -StatusOr OperandUpcaster::ExpandInstruction( +absl::StatusOr OperandUpcaster::ExpandInstruction( HloInstruction* instruction) { const bool packed_nibble = absl::c_count(instruction->precision_config().operand_precision(), @@ -123,7 +131,7 @@ StatusOr OperandUpcaster::ExpandInstruction( return MakeBinaryHlo(HloOpcode::kAdd, linear_n0, linear_n1); } - for (int i = 0; i < instruction->operand_count(); ++i) { + for (int i = 0; i < HloDotInstruction::kOperands; ++i) { auto* operand = instruction->mutable_operand(i); if (operand->shape().element_type() == type) { continue; diff --git a/xla/service/operand_upcaster.h b/xla/service/operand_upcaster.h index 8261ae2987f39..d89daf4e415d0 100644 --- a/xla/service/operand_upcaster.h +++ b/xla/service/operand_upcaster.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_OPERAND_UPCASTER_H_ #define XLA_SERVICE_OPERAND_UPCASTER_H_ -#include "xla/hlo/ir/hlo_module.h" +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/op_expander_pass.h" +#include "xla/util.h" namespace xla { @@ -33,7 +38,7 @@ class OperandUpcaster : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/operand_upcaster_test.cc b/xla/service/operand_upcaster_test.cc index 0364f135533ae..37a8b0657c894 100644 --- a/xla/service/operand_upcaster_test.cc +++ b/xla/service/operand_upcaster_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,16 @@ limitations under the License. #include "xla/service/operand_upcaster.h" +#include +#include + +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -109,6 +115,32 @@ INSTANTIATE_TEST_SUITE_P(NoUpcast, OperandUpcasterTest, ::testing::Values(std::make_tuple(F32, F32, BF16), std::make_tuple(S32, S32, U32))); +TEST_F(OperandUpcasterTest, SparseDot) { + absl::string_view kHlo = R"( + HloModule module + + ENTRY main { + p0 = bf16[2,16]{1,0} parameter(0) + p1 = bf16[32,2]{1,0} parameter(1) + meta = u16[2,2]{1,0} parameter(2) + ROOT dot = f32[2,2]{1,0} dot(p0, p1, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(bool upcasted, OperandUpcaster().Run(module.get())); + EXPECT_TRUE(upcasted); + auto upcasted_lhs = + AllOf(op::Convert(op::Parameter(0)), op::Shape("f32[2,16]{1,0}")); + auto upcasted_rhs = + AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[32,2]{1,0}")); + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(::testing::MakeMatcher(new ::xla::testing::HloMatcher( + HloOpcode::kDot, + {upcasted_lhs, upcasted_rhs, op::Parameter(2)})), + op::Shape("f32[2,2]{1,0}"))); +} + } // namespace } // namespace xla diff --git a/xla/service/optimization_barrier_expander.cc b/xla/service/optimization_barrier_expander.cc index 87f122c79bfa0..877fcb1670236 100644 --- a/xla/service/optimization_barrier_expander.cc +++ b/xla/service/optimization_barrier_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ limitations under the License. namespace xla { -StatusOr OptimizationBarrierExpander::Run( +absl::StatusOr OptimizationBarrierExpander::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::vector barriers; diff --git a/xla/service/optimization_barrier_expander.h b/xla/service/optimization_barrier_expander.h index 50c148dfc0da7..b614b80d8f3e4 100644 --- a/xla/service/optimization_barrier_expander.h +++ b/xla/service/optimization_barrier_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,9 +27,8 @@ class OptimizationBarrierExpander : public HloModulePass { absl::string_view name() const override { return "cse_barrier_expander"; } - protected: using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/optimize_input_output_buffer_alias.cc b/xla/service/optimize_input_output_buffer_alias.cc index 697b288d3ed59..c2a16f372df5a 100644 --- a/xla/service/optimize_input_output_buffer_alias.cc +++ b/xla/service/optimize_input_output_buffer_alias.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ limitations under the License. namespace xla { -StatusOr OptimizeInputOutputBufferAlias::Build( +absl::StatusOr OptimizeInputOutputBufferAlias::Build( absl::Span input_shapes, const Shape& output_shape, HloInputOutputAliasConfig* alias_config, HloBufferDonorConfig* buffer_donor_config) { @@ -130,7 +130,7 @@ StatusOr OptimizeInputOutputBufferAlias::Build( return changed; } -StatusOr OptimizeInputOutputBufferAlias::Run( +absl::StatusOr OptimizeInputOutputBufferAlias::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // We exactly follow HloInputOutputAliasConfig::Verify to create input_shapes diff --git a/xla/service/optimize_input_output_buffer_alias.h b/xla/service/optimize_input_output_buffer_alias.h index 989e173a712c3..d8d618aff6930 100644 --- a/xla/service/optimize_input_output_buffer_alias.h +++ b/xla/service/optimize_input_output_buffer_alias.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,11 +55,11 @@ class OptimizeInputOutputBufferAlias : public HloModulePass { ~OptimizeInputOutputBufferAlias() override = default; absl::string_view name() const override { - return "optimize_input_output_buffer_alias.h"; + return "optimize_input_output_buffer_alias"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -74,10 +74,10 @@ class OptimizeInputOutputBufferAlias : public HloModulePass { // Match buffer donors and donees and save the matched paired in the // alias_config. The availability of buffer donors is controlled by the flag // registered_buffer_donor_only_. - StatusOr Build(absl::Span input_shapes, - const Shape& output_shape, - HloInputOutputAliasConfig* alias_config, - HloBufferDonorConfig* buffer_donor_config); + absl::StatusOr Build(absl::Span input_shapes, + const Shape& output_shape, + HloInputOutputAliasConfig* alias_config, + HloBufferDonorConfig* buffer_donor_config); std::function shape_size_fn_ = [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); diff --git a/xla/service/optimize_input_output_buffer_alias_test.cc b/xla/service/optimize_input_output_buffer_alias_test.cc index c8df5b9b08d68..fd63ae727e66e 100644 --- a/xla/service/optimize_input_output_buffer_alias_test.cc +++ b/xla/service/optimize_input_output_buffer_alias_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/p2p_schedule_preparation.cc b/xla/service/p2p_schedule_preparation.cc index 41319120b60c0..782b807763f82 100644 --- a/xla/service/p2p_schedule_preparation.cc +++ b/xla/service/p2p_schedule_preparation.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -31,8 +32,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/collective_ops_utils.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -47,15 +48,10 @@ namespace { // (2) we need to exclude host P2P operations when looking for a nested chain // of non-host P2P operations. bool IsP2POp(const HloInstruction* op) { - auto p2p = DynCastOrNull(op); + auto p2p = DynCast(op); return p2p != nullptr && !p2p->is_host_transfer(); } -bool IsP2PDoneOp(const HloInstruction* op) { - return IsP2POp(op) && (op->opcode() == HloOpcode::kRecvDone || - op->opcode() == HloOpcode::kSendDone); -} - // Returns whether the instruction is a collective operation, for the purpose // of detecting whether the computation directly invokes collective // operations. As such, we only need to detect one of the instructions for a @@ -71,11 +67,9 @@ bool IsCollectiveOp(const HloInstruction* op) { return true; } - return hlo_query::IsAsyncCollectiveDoneOp(opcode, - /*include_send_recv=*/true) || + return hlo_query::IsAsyncCollectiveDoneOp(op, /*include_send_recv=*/true) || (hlo_query::IsCollectiveCommunicationOp(opcode) && - !hlo_query::IsAsyncCollectiveStartOp(opcode, - /*include_send_recv=*/true)); + !hlo_query::IsAsyncCollectiveStartOp(op, /*include_send_recv=*/true)); } // Returns the corresponding Done op if the input is a Start op. Otherwise, @@ -95,6 +89,8 @@ HloInstruction* GetStartOpForDoneOp(HloInstruction* op) { enum P2PGroupKind { kUnpipelined = 0, kPipelined = 1, kUnrecognized = 2 }; +enum P2PRuntimeStream { kUnknown = 0, kStream0 = 1, kStream1 = 2 }; + // A P2P group node represents the P2P instructions that are in the same // computation and have the same channel ID. This includes one Send/SendDone // and one Recv/RecvDone. If the P2P instructions for the given channel ID are @@ -110,21 +106,38 @@ struct P2PGroupNode { return computation == parent; } - bool RecordDoneOp(HloSendRecvInstruction* p2p) { + bool RecordP2POp(HloSendRecvInstruction* p2p) { if (!RecordParentComputation(p2p->parent())) { return false; } - if (p2p->opcode() == HloOpcode::kRecvDone) { - if (recv_done == nullptr) { - recv_done = Cast(p2p); - return true; - } - } else if (p2p->opcode() == HloOpcode::kSendDone) { - if (send_done == nullptr) { - send_done = Cast(p2p); - return true; - } + switch (p2p->opcode()) { + case HloOpcode::kRecvDone: + if (recv_done == nullptr) { + recv_done = Cast(p2p); + return true; + } + break; + case HloOpcode::kSendDone: + if (send_done == nullptr) { + send_done = Cast(p2p); + return true; + } + break; + case HloOpcode::kRecv: + if (recv == nullptr) { + recv = Cast(p2p); + return true; + } + break; + case HloOpcode::kSend: + if (send == nullptr) { + send = Cast(p2p); + return true; + } + break; + default: + break; } return false; } @@ -141,15 +154,45 @@ struct P2PGroupNode { } bool Incomplete() const { - return recv_done == nullptr || send_done == nullptr; + return recv_done == nullptr || send_done == nullptr || recv == nullptr || + send == nullptr; } bool IncompletePipelinedParent() const { return Incomplete() || while_loop == nullptr; } + // Returns the pipeline stream used to execute the P2P instructions in the + // group. + P2PRuntimeStream GetRuntimeStream(const HloInstruction* start) const { + auto it = start->frontend_attributes().map().find(kSendRecvPipelineAttr); + if (it != start->frontend_attributes().map().end()) { + if (it->second == "0") { + return kStream0; + } + if (it->second == "1") { + return kStream1; + } + } + return kUnknown; + } + + // Finds the pipeline stream from the frontend attribute of the Send/Recv in + // the pipeline group node, verifies they both have the same value and returns + // the stream. + P2PRuntimeStream GetRuntimeStream() const { + P2PRuntimeStream send_stream = GetRuntimeStream(send); + P2PRuntimeStream recv_stream = GetRuntimeStream(recv); + if (send_stream != recv_stream) { + return kUnknown; + } + return send_stream; + } + HloRecvDoneInstruction* recv_done = nullptr; HloSendDoneInstruction* send_done = nullptr; + HloRecvInstruction* recv = nullptr; + HloSendInstruction* send = nullptr; // The computation that contains the Send and Recv instructions. HloComputation* computation = nullptr; // The while-loop instruction that calls the while-body with the pipelined @@ -157,6 +200,26 @@ struct P2PGroupNode { HloInstruction* while_loop = nullptr; }; +// Maps a channel ID to the corresponding P2P operation group. +struct P2PGroup; +using P2PGroupMap = absl::flat_hash_map; + +// Maps a computation to the channel IDs used by the computation for P2P +// operations. We use std::set instead of hash set for deterministic +// iterators. +using P2PInComputation = + absl::flat_hash_map>; + +// Maps a computation to a boolean that indicates whether the computation +// invokes collective operations directly or indirectly. +using CollectiveInComputation = + absl::flat_hash_map; + +// Represents the start and end of a region marked by an ordered P2P instruction +// chain. +using ChainStartEnd = + std::pair; + static constexpr int kUnpipelinedNodeIdx = 0; static constexpr int kPipelinedChildNodeIdx = 0; static constexpr int kPipelinedParentNodeIdx = 1; @@ -166,35 +229,37 @@ static constexpr int kPipelinedParentNodeIdx = 1; // A kUnpipelined P2P group contains only one P2PGroupNode while a kPipelined // P2P group contains a P2PGroupNode for the while-body and a P2PGroupNode // for the computation with the while-loop instruction calling the while-body. +// If a group forms a cycle with another group, records the other group as a +// complement group. struct P2PGroup { - Status RecordDoneOpForUnpipelinedGroup(HloSendRecvInstruction* p2p) { + Status RecordP2POpForUnpipelinedGroup(HloSendRecvInstruction* p2p) { if (kind == kUnrecognized) { // Leave unrecognized P2P groups alone. return OkStatus(); } if (kind != kUnpipelined) { - return InternalError("Expected unpipelined group"); + return Internal("Expected unpipelined group"); } P2PGroupNode& node = nodes[kUnpipelinedNodeIdx]; - if (!node.RecordDoneOp(p2p)) { + if (!node.RecordP2POp(p2p)) { kind = kUnrecognized; } return OkStatus(); } - Status RecordDoneOpForPipelinedGroup(HloSendRecvInstruction* p2p) { + Status RecordP2POpForPipelinedGroup(HloSendRecvInstruction* p2p) { if (kind == kUnrecognized) { // Leave unrecognized P2P groups alone. return OkStatus(); } if (kind == kUnpipelined) { if (nodes[kPipelinedParentNodeIdx].computation != nullptr) { - return InternalError("Expected unpipelined group"); + return Internal("Expected unpipelined group"); } kind = kPipelined; } P2PGroupNode& node = nodes[kPipelinedParentNodeIdx]; - if (!node.RecordDoneOp(p2p)) { + if (!node.RecordP2POp(p2p)) { kind = kUnrecognized; } return OkStatus(); @@ -206,7 +271,7 @@ struct P2PGroup { return OkStatus(); } if (kind == kUnpipelined) { - return InternalError("Expected pipelined group"); + return Internal("Expected pipelined group"); } P2PGroupNode& node = nodes[kPipelinedParentNodeIdx]; if (!node.RecordWhileOp(while_op)) { @@ -215,23 +280,108 @@ struct P2PGroup { return OkStatus(); } - P2PGroupKind kind = kUnpipelined; - P2PGroupNode nodes[2]; -}; + // Finds the pipeline stream from the frontend attribute of the Send/Recv in + // the pipeline group, verifies they all have the same value and records + // the stream. + bool RecordRuntimeStream() { + P2PRuntimeStream child_stream = + nodes[kPipelinedChildNodeIdx].GetRuntimeStream(); + if (kind == kPipelined) { + P2PRuntimeStream parent_stream = + nodes[kPipelinedParentNodeIdx].GetRuntimeStream(); + if (child_stream != parent_stream || child_stream == kUnknown) { + return false; + } + } + // Record the stream. + runtime_stream = child_stream; + return true; + } -// Maps a channel ID to the corresponding P2P operation group. -using P2PGroupMap = absl::flat_hash_map; + // Records the other group that forms a cycle with this group, assuming that + // we handle only two groups that form a cycle. + Status RecordComplementGroup(P2PGroupMap& p2p_group_map) { + CHECK(complement_group == nullptr && runtime_stream == kStream1); + for (auto& [channel, p2p_group] : p2p_group_map) { + if (&p2p_group == this || + p2p_group.ChildComputation() != ChildComputation()) { + continue; + } + if (p2p_group.kind == kPipelined && + p2p_group.ParentComputation() == ParentComputation()) { + // Found two pipelined group for the same while loop, verify that they + // have different valid pipeline stream. + if (p2p_group.runtime_stream != kStream0) { + return Internal( + "Expected different pipeline stream for complement group"); + } + complement_group = &p2p_group; + p2p_group.complement_group = this; + } else if (p2p_group.kind == kUnpipelined && + p2p_group.runtime_stream != kStream1) { + complement_group = &p2p_group; + p2p_group.complement_group = this; + } + } + return OkStatus(); + } -// Maps a computation to the channel IDs used by the computation for P2P -// operations. We use std::set instead of hash set for deterministic -// iterators. -using P2PInComputation = - absl::flat_hash_map>; + // Returns the parent computation assuming this is a kPipelined group. + HloComputation* ParentComputation() const { return GetParent().computation; } -// Maps a computation to a boolean that indicates whether the computation -// invokes collective operations directly or indirectly. -using CollectiveInComputation = - absl::flat_hash_map; + // Returns the child computation for the group. + HloComputation* ChildComputation() const { return GetChild().computation; } + + P2PGroupNode& GetChild() { return nodes[kPipelinedChildNodeIdx]; } + P2PGroupNode& GetParent() { return nodes[kPipelinedParentNodeIdx]; } + const P2PGroupNode& GetChild() const { return nodes[kPipelinedChildNodeIdx]; } + const P2PGroupNode& GetParent() const { + return nodes[kPipelinedParentNodeIdx]; + } + + // Returns the start and end of a region marked by a pipelined chain in the + // given computation, which is the region with the pipelined P2P instructions. + ChainStartEnd GetChainStartEnd(HloComputation* computation) const { + if (kind == kUnpipelined) { + if (!InCycle()) { + return std::make_pair(GetChild().recv, GetChild().send_done); + } + CHECK(runtime_stream == kStream1); + return std::make_pair(complement_group->GetChild().recv, + GetChild().send_done); + } + + CHECK(kind == kPipelined); + if (computation == ChildComputation()) { + if (!InCycle()) { + return std::make_pair(GetChild().recv, GetChild().send_done); + } + CHECK(runtime_stream == kStream1); + return std::make_pair(complement_group->GetChild().recv, + GetChild().send_done); + } + + CHECK(computation == ParentComputation()); + if (!InCycle()) { + return std::make_pair(GetParent().recv, GetParent().send_done); + } + CHECK(runtime_stream == kStream1); + return std::make_pair(complement_group->GetParent().recv, + GetParent().send_done); + } + + HloInstruction* GetWhileOp() const { + return nodes[kPipelinedParentNodeIdx].while_loop; + } + + bool InCycle() const { return complement_group != nullptr; } + + P2PGroupKind kind = kUnpipelined; + P2PGroupNode nodes[2]; + P2PRuntimeStream runtime_stream = kUnknown; + // Another P2PGroup that forms a cycle with this group. + P2PGroup* complement_group = nullptr; +}; bool MayInvokeCollectiveOp( const HloInstruction* hlo, @@ -249,136 +399,169 @@ bool MayInvokeCollectiveOp( return false; } -// If the while-body contains a P2P chain that use the same channel as another -// P2P chain in the caller computation, assume these two P2P chain belong to -// the same pipelined P2P sequence. Adds the WhileOp to the pipelined group +// If the while-body contains a P2P group that uses the same channel as any +// Send operand of the while-op, we assume these two P2P groups belong to the +// same pipelined P2P sequence. Adds the WhileOp to the pipelined group // representation in this case. Status MayAddWhileOpToPipelinedGroup(HloInstruction* while_op, P2PInComputation& p2p_in_computation, P2PGroupMap& p2p_group_map) { + if (while_op->while_init()->opcode() != HloOpcode::kTuple) { + // A while-init should contain the loop index variable. So if a while-init + // is not a tuple, it only contains the loop index variable and shouldn't + // contain any pipelined Send operand. + return OkStatus(); + } HloComputation* body = while_op->called_computations()[0]; - auto p2p = p2p_in_computation.find(body); - if (p2p == p2p_in_computation.end()) { + auto p2p_in_while = p2p_in_computation.find(body); + if (p2p_in_while == p2p_in_computation.end()) { return OkStatus(); } int pipelined_group = 0; - for (auto channel : p2p->second) { - auto p2p_group = p2p_group_map.find(channel); - if (p2p_group == p2p_group_map.end() || - p2p_group->second.kind != kPipelined) { + // Check whether the while-op init contains a token from a Send result. + for (auto hlo : while_op->while_init()->operands()) { + if (hlo->opcode() != HloOpcode::kSendDone) { + continue; + } + int64_t channel_id = hlo->channel_id().value(); + if (p2p_in_while->second.find(channel_id) == p2p_in_while->second.end()) { + continue; + } + auto group = p2p_group_map.find(channel_id); + if (group == p2p_group_map.end() || group->second.kind != kPipelined) { continue; } pipelined_group++; - if (pipelined_group > 1) { - return InternalError( - "Expecting only one pipelined P2P group for each while-loop"); + if (pipelined_group > 2) { + return Internal( + "Expecting up to two pipelined P2P groups for each while-loop"); } - TF_RETURN_IF_ERROR( - p2p_group->second.RecordWhileOpToPipelinedGroup(while_op)); + TF_RETURN_IF_ERROR(group->second.RecordWhileOpToPipelinedGroup(while_op)); } return OkStatus(); } -// For an unpipelined Send-Recv chain, add control dependence to enforce this -// ordering: +Status OrderBefore(HloInstruction* i1, HloInstruction* i2) { + TF_RETURN_IF_ERROR(i1->AddControlDependencyTo(i2)); + VLOG(10) << "Add control predecessor " << i2->ToString(); + return OkStatus(); +} + +// Adds control dependence to enforce this ordering: // recv => send => recv-done => send-done. -Status ConnectUnpipelinedP2P(const P2PGroupNode& node) { - HloSendRecvInstruction* recv_done = node.recv_done; - HloRecvInstruction* recv = - DynCast(recv_done->mutable_operand(0)); - HloSendRecvInstruction* send_done = node.send_done; - HloSendInstruction* send = - DynCast(send_done->mutable_operand(0)); - // We want the Recv to be scheduled before the Send. - TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); - VLOG(10) << "Add control predecessor " << send->ToString(); - // We want the Send to be scheduled before RecvDone to prevent the scheduler - // from interleaving two Send-Recv sequences. - TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); - VLOG(10) << "Add control predecessor " << recv_done->ToString(); - - // We want the RecvDone to be scheduled before the SendDone. - TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done)); - VLOG(10) << "Add control predecessor " << send_done->ToString(); +Status ConnectP2P1NodeChain(const P2PGroupNode& node) { + HloRecvDoneInstruction* recv_done = node.recv_done; + HloRecvInstruction* recv = node.recv; + HloSendDoneInstruction* send_done = node.send_done; + HloSendInstruction* send = node.send; + TF_RETURN_IF_ERROR(OrderBefore(recv, send)); + TF_RETURN_IF_ERROR(OrderBefore(send, recv_done)); + TF_RETURN_IF_ERROR(OrderBefore(recv_done, send_done)); return OkStatus(); } -// For the pipelined Send-Recv chain in a while-body, we need to make sure -// that the Send is scheduled before Recv as the Send release the unrolled -// Recv before entering the transformed loop. We let the scheduler to decide -// where to schedule send-done and recv-done. For example, if the while-body -// has other unpiplined Send-Recv chains, it may produce this ordering: -// send => send-done => other Send-Recv chains => recv => recv-done -// If the while-body doesn't have other Send-Recv chains, it may produce this +// For an unpipelined Send-Recv chain, adds control dependence to enforce this // ordering: -// send => recv => recv-done => send-done -Status ConnectPipelinedP2PChild(const P2PGroupNode& node) { - HloSendRecvInstruction* recv_done = node.recv_done; - HloRecvInstruction* recv = - DynCast(recv_done->mutable_operand(0)); - HloSendRecvInstruction* send_done = node.send_done; - HloSendInstruction* send = - DynCast(send_done->mutable_operand(0)); - // We want the Send to be scheduled before the Recv. - TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv)); - VLOG(10) << "Add control predecessor " << recv->ToString(); +// recv => send => recv-done => send-done. +Status ConnectUnpipelinedP2P(const P2PGroup& p2p_group) { + return ConnectP2P1NodeChain(p2p_group.GetChild()); +} + +// For a single pipelined Send-Recv chain in a while-body, adds control +// dependence toenforce this ordering: +// recv => send => recv-done => send-done +Status ConnectPipelined1P2PChild(const P2PGroup& p2p_group) { + return ConnectP2P1NodeChain(p2p_group.GetChild()); +} + +// For aSend-Recv chain involving two channels, adds control dependence to +// enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectP2P2NodeChain(const P2PGroupNode& node0, + const P2PGroupNode& node1) { + HloSendRecvInstruction* recv_done0 = node0.recv_done; + HloRecvInstruction* recv0 = node0.recv; + HloSendRecvInstruction* send_done0 = node0.send_done; + HloSendInstruction* send0 = node0.send; + HloSendRecvInstruction* recv_done1 = node1.recv_done; + HloRecvInstruction* recv1 = node1.recv; + HloSendRecvInstruction* send_done1 = node1.send_done; + HloSendInstruction* send1 = node1.send; + + TF_RETURN_IF_ERROR(OrderBefore(recv_done0, recv_done1)); + TF_RETURN_IF_ERROR(OrderBefore(recv_done1, send_done0)); + TF_RETURN_IF_ERROR(OrderBefore(send_done0, send_done1)); + + TF_RETURN_IF_ERROR(OrderBefore(recv0, send0)); + TF_RETURN_IF_ERROR(OrderBefore(send0, recv1)); + TF_RETURN_IF_ERROR(OrderBefore(recv1, send1)); + + TF_RETURN_IF_ERROR(OrderBefore(send1, recv_done0)); + return OkStatus(); } -// Returns a boolean to indicate whether there are any operation in the range -// [start, end] that contains non-host P2P transfer that are reachable from -// the given instruction. -bool OperationChainHasP2P( - P2PInComputation& p2p_in_computation, - const std::vector::const_iterator& start, - const std::vector::const_iterator& end, - const HloReachabilityMap* reachability, const HloInstruction* instr) { - for (auto it_op = start; it_op != end; ++it_op) { - const HloInstruction* op = *it_op; - if (!reachability->IsReachable(instr, op)) continue; - - if (IsP2POp(op)) { - return true; - } +// For a pipelined Send-Recv chain with two channel groups forming a cycle in a +// while-body computation, we enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) { + return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(), + p2p_group.GetChild()); +} - for (const HloComputation* called_comp : op->called_computations()) { - auto p2p_in_comp = p2p_in_computation.find(called_comp); - if (p2p_in_comp != p2p_in_computation.end()) { - return true; - } - } - } - return false; +// For a pipelined Send-Recv chain with one group in the while-body calling +// computation, we enforce this ordering: +// recv => send => recv-done => send-done +Status ConnectPipelined1P2PParent(const P2PGroup& p2p_group) { + return ConnectP2P1NodeChain(p2p_group.GetParent()); +} + +// For a pipelined Send-Recv chain with two channel groups forming a cycle +// in the while-body calling computation, we enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group) { + return ConnectP2P2NodeChain(p2p_group.complement_group->GetParent(), + p2p_group.GetParent()); } -// Collects P2P send-done and recv-done instructions from the computation and -// group them by channel IDs. Also records whether the computation invokes -// collective operation directly or indirectly. +// For a Send-Recv chain with two channel groups forming a cycle in a while-body +// annotated for pipelining but not pipelined (due to skip pipelining pass), we +// enforece this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group) { + CHECK(p2p_group.runtime_stream == kStream1); + return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(), + p2p_group.GetChild()); +} + +// Collects P2P send-done and recv-done instructions from the computation, +// groups them by channel IDs, records pipeline decision for groups and connects +// groups that form a cycle for pipelining. Also records whether the computation +// invokes collective operation directly or indirectly. Status GatherP2PGroupsAndCollectiveInfo( const HloComputation* computation, P2PInComputation& p2p_in_computation, P2PGroupMap& p2p_group_map, CollectiveInComputation& collective_in_computation) { collective_in_computation[computation] = false; + std::vector while_ops; for (auto hlo : computation->MakeInstructionPostOrder()) { // Record the use of collective operations. - if (IsCollectiveOp(hlo)) { + if (MayInvokeCollectiveOp(hlo, collective_in_computation)) { collective_in_computation[computation] = true; - } else { - // Propagate CollectiveInComputation from callees to callers. - for (auto callee : hlo->called_computations()) { - auto collective_in_comp = collective_in_computation.find(callee); - if (collective_in_comp != collective_in_computation.end()) { - collective_in_computation[computation] |= collective_in_comp->second; - } - } } if (hlo->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(MayAddWhileOpToPipelinedGroup(hlo, p2p_in_computation, - p2p_group_map)); + // The pipelined Recv-done/Send-done appears after the while-op. As + // such, the pipelined group hasn't been constructed at this point. + // Keep the while-op and add to the pipelined group later. + while_ops.push_back(hlo); continue; } - if (!IsP2PDoneOp(hlo)) { + if (!IsP2POp(hlo)) { continue; } HloSendRecvInstruction* p2p = Cast(hlo); @@ -389,15 +572,15 @@ Status GatherP2PGroupsAndCollectiveInfo( // P2P group and may turn it into a kPipelined group or kUnrecognized // group. P2PGroup group; - TF_RETURN_IF_ERROR(group.RecordDoneOpForUnpipelinedGroup(p2p)); + TF_RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); p2p_group_map[channel] = group; } else { P2PGroup& group = p2p_group->second; - if (group.nodes[kUnpipelinedNodeIdx].computation == computation) { - TF_RETURN_IF_ERROR(group.RecordDoneOpForUnpipelinedGroup(p2p)); + if (group.ChildComputation() == computation) { + TF_RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); } else { // We are at the parent computation for a pipelined P2P group. - TF_RETURN_IF_ERROR(group.RecordDoneOpForPipelinedGroup(p2p)); + TF_RETURN_IF_ERROR(group.RecordP2POpForPipelinedGroup(p2p)); } } // We can't rely on the operation on p2p_group_map above to find out @@ -414,36 +597,61 @@ Status GatherP2PGroupsAndCollectiveInfo( } } + for (auto hlo : while_ops) { + TF_RETURN_IF_ERROR( + MayAddWhileOpToPipelinedGroup(hlo, p2p_in_computation, p2p_group_map)); + } + // Now finalize each group, in particular, if a kPipelined or kUnpipelined - // group is missing some instructions, change the group to kUnrecognized. + // group is missing some instructions, a kPipelined group missing a pipeline + // stream or have inconsistent pipeline streams, change the group to + // kUnrecognized. for (auto& [channel, p2p_group] : p2p_group_map) { if (p2p_group.kind == kUnpipelined) { - if (p2p_group.nodes[kUnpipelinedNodeIdx].Incomplete()) { + if (p2p_group.nodes[kUnpipelinedNodeIdx].Incomplete() || + !p2p_group.RecordRuntimeStream()) { p2p_group.kind = kUnrecognized; } } else if (p2p_group.kind == kPipelined) { if (p2p_group.nodes[kPipelinedChildNodeIdx].Incomplete() || p2p_group.nodes[kPipelinedParentNodeIdx] - .IncompletePipelinedParent()) { + .IncompletePipelinedParent() || + !p2p_group.RecordRuntimeStream()) { p2p_group.kind = kUnrecognized; } } } + // Erase kUnrecognized groups. absl::erase_if(p2p_group_map, [](const auto& p2p_group) { return p2p_group.second.kind == kUnrecognized; }); + // Connect two groups that form a cycle, both for pipelined and unpipelined + // cases for the current computation. We only build such a connection when we + // are processing the group for kStream1 stream, and for parent computation + // for a pipelined group. + for (auto& [channel, p2p_group] : p2p_group_map) { + if ((p2p_group.kind == kPipelined && + p2p_group.ParentComputation() != computation) || + p2p_group.complement_group != nullptr || + p2p_group.runtime_stream != kStream1) { + continue; + } + + TF_RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map)); + } + return OkStatus(); } -// For a given computation, adds control dependence to chain the recognized -// pipelined or unpipelined P2P group in the computation. Returns the total -// number of P2P chains and if the computation is a while-body with a pipelined -// P2P group, returns such a group or a nullptr. -StatusOr ConnectP2PChain(HloComputation* computation, - const P2PGroupMap& p2p_group_map, - const std::set& p2p_channels) { +// For a given computation, adds control dependence to chain a pipelined or +// unpipelined P2P group in the computation. Returns the total number of such +// chains. If the computation is a while-body, verifies that at most one group +// or two groups forming a cycle are pipelined and returns the pipelined group. +absl::StatusOr> ConnectP2PChain( + HloComputation* computation, const P2PGroupMap& p2p_group_map, + const std::set& p2p_channels) { // If the current computation is a while-body and has a pipelined P2P chain, // record such a P2P group. const P2PGroup* pipelined_group = nullptr; @@ -451,7 +659,7 @@ StatusOr ConnectP2PChain(HloComputation* computation, for (int64_t channel : p2p_channels) { auto it = p2p_group_map.find(channel); if (it == p2p_group_map.end()) { - // The instructions that use the channel don't form an interested P2P + // The instructions that use the channel don't form an interesting P2P // group, do nothing. continue; } @@ -459,28 +667,52 @@ StatusOr ConnectP2PChain(HloComputation* computation, const P2PGroup& p2p_group = it->second; P2PGroupKind kind = p2p_group.kind; if (kind == P2PGroupKind::kUnpipelined) { - TF_RETURN_IF_ERROR( - ConnectUnpipelinedP2P(p2p_group.nodes[kUnpipelinedNodeIdx])); + if (!p2p_group.InCycle()) { + TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group)); + } else if (p2p_group.runtime_stream == kStream1) { + TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group)); + } continue; } - // For Pipelined group. - if (computation != p2p_group.nodes[kPipelinedParentNodeIdx].computation) { - // We are at the computation for the while-body of the pipelined group. - TF_RETURN_IF_ERROR( - ConnectPipelinedP2PChild(p2p_group.nodes[kPipelinedChildNodeIdx])); + if (p2p_group.complement_group == nullptr) { + if (computation == p2p_group.ParentComputation()) { + TF_RETURN_IF_ERROR(ConnectPipelined1P2PParent(p2p_group)); + } else { + // A pipeline of one group. + if (pipelined_group != nullptr) { + return Internal("Expected <=1 pipelined group in a while-body"); + } + pipelined_group = &p2p_group; + TF_RETURN_IF_ERROR(ConnectPipelined1P2PChild(p2p_group)); + } + continue; + } + + // A pipeline of two groups that form a cycle. We process the pipeline when + // we see the group with kStream1. + if (p2p_group.runtime_stream != kStream1) { + continue; + } + + if (computation == p2p_group.ParentComputation()) { + TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group)); + } else { if (pipelined_group != nullptr) { - return InternalError("Expected <=1 pipelined group in a while-body"); + return Internal( + "Expected only two pipelined groups forming a cycle in a " + "while-body"); } pipelined_group = &p2p_group; - } else { + TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group)); } } - return num_p2p_chains; + return std::make_pair(num_p2p_chains, pipelined_group); } Status OrderBefore(HloReachabilityMap* reachability, HloInstruction* a, HloInstruction* b) { + VLOG(10) << "OrderBefore " << a->ToString() << " " << b->ToString(); if (!reachability->IsReachable(a, b)) { TF_RETURN_IF_ERROR(a->AddControlDependencyTo(b)); VLOG(10) << "add control predecessor " << b->ToString(); @@ -490,153 +722,93 @@ Status OrderBefore(HloReachabilityMap* reachability, HloInstruction* a, } // Adds control dependence to linearize other collective ops with respect to -// the given unpipelined P2P chain which is ordered as follows: -// Recv => Send => Recv-Done => Send-Done -// We intend to schedule collective ops ordered before Recv-Done before Recv -// and collective ops ordered after Recv-Done after Send-Done. -Status ChainCollectivesWithUnpipelinedP2P( - const P2PGroupMap& p2p_group_map, const P2PGroupNode& node, - const std::vector::iterator& begin, - const std::vector::iterator& recv_done_iter, - const std::vector::iterator& end, +// the given P2P chain, which is either an unpipelined P2P chain, or a pipelined +// P2P chain in the while-loop calling computation. The P2P chain can be one of +// the following: +// Recv => Send => Recv-Done => Send-Done (unpipelined, or pipelined 1) +// Recv.0 => Send.0 => Recv.1 => Send.1 => Recv-Done.0 => Send-Done.0 +// Recv-Done.1 => Send-Done.1 (pipelined 2) +// We intend to schedule collective ops ordered before the beginning of such a +// chain or after the ending of such a chain. +Status LinearizeCollectivesWithOtherP2P( + const P2PGroupMap& p2p_group_map, const P2PGroup& group, const CollectiveInComputation& collective_in_computation, + const std::vector::iterator& chain_start_iter, + const std::vector::iterator& begin_iter, + const std::vector::iterator& end_iter, HloReachabilityMap* reachability) { - HloSendRecvInstruction* send_done = - DynCast(node.send_done); - HloInstruction* recv = (*recv_done_iter)->mutable_operand(0); - auto in_current_p2p_chain = [&](const HloInstruction* hlo) { - const HloSendRecvInstruction* p2p = - DynCastOrNull(hlo); - return p2p != nullptr && p2p->channel_id() == send_done->channel_id(); - }; - - for (auto it = begin; it != end; ++it) { - HloInstruction* hlo = *it; - if (!MayInvokeCollectiveOp(hlo, collective_in_computation) || - in_current_p2p_chain(hlo)) { - continue; - } + HloComputation* computation = (*chain_start_iter)->parent(); + ChainStartEnd start_end = group.GetChainStartEnd(computation); - HloOpcode opcode = hlo->opcode(); - // Handle a P2P chain when we see its Send-Done. - if (opcode == HloOpcode::kRecvDone) { - continue; - } - if (opcode == HloOpcode::kSendDone) { + // We refer to the P2P chain represented by `group` chain A. + for (auto it = begin_iter; it != end_iter; ++it) { + HloInstruction* hlo = *it; + if (IsP2POp(hlo)) { auto group_it = p2p_group_map.find(hlo->channel_id().value()); if (group_it == p2p_group_map.end()) { - LOG(INFO) << "Warn unhandled P2P " << hlo->ToString(); continue; } - const P2PGroup& p2p_group = group_it->second; - P2PGroupKind kind = p2p_group.kind; + const P2PGroup& cur_group = group_it->second; + P2PGroupKind kind = cur_group.kind; + // May linearize chain A with chain B represented by `cur_group`. if (kind == P2PGroupKind::kPipelined && - recv->parent() != - p2p_group.nodes[kPipelinedParentNodeIdx].computation) { - // The pipelined P2P in the "child" is already ordered with respected to - // other P2P chains. + computation == cur_group.ChildComputation()) { + // Chain B a pipelined P2P chain with `computation` as a while-body. We + // already linearize the two chains in + // LinearizeCollectivesWithPipelinedP2PChild. continue; } - if (reachability->IsReachable(recv, hlo)) { - HloInstruction* recv2 = p2p_group - .nodes[kind == P2PGroupKind::kUnpipelined - ? kUnpipelinedNodeIdx - : kPipelinedParentNodeIdx] - .recv_done->mutable_operand(0); - TF_RETURN_IF_ERROR(OrderBefore(reachability, send_done, recv2)); + + ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation); + if (cur_start_end.first != hlo) { + // We will linearize the two chains when we see the first instruction in + // chain B. + continue; + } + if (it <= chain_start_iter) { + // We already linearize the two chains when we call this routine for + // `cur_group`. + continue; + } + + if (reachability->IsReachable(start_end.first, cur_start_end.second)) { + // Order chain A before chain B. + TF_RETURN_IF_ERROR( + OrderBefore(reachability, start_end.second, cur_start_end.first)); } else { - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); + // Order chain B before chain A. + TF_RETURN_IF_ERROR( + OrderBefore(reachability, cur_start_end.second, start_end.first)); } continue; } - // The hlo is not a Send/Recv instruction. - if (reachability->IsReachable(hlo, send_done)) { - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); - } else { - // TODO(b/309639264): Remove kCustomCall when the NVIDIA bug is fixed. - TF_RETURN_IF_ERROR(OrderBefore( - reachability, send_done, - opcode == HloOpcode::kCustomCall ? hlo : GetStartOpForDoneOp(hlo))); - } - } - - return OkStatus(); -} - -// Adds control dependence to linearize other collective ops with respect to -// the given pipelined P2P chain in the computation containing the pipelined -// while-loop, which is ordered as follows: -// Recv => Recv-Done => While-loop => Send => SendDone -// We intend to schedule collective ops ordered before the while-loop before -// Recv and collective ops ordered after the while-loop after Send-Done. -Status ChainCollectivesWithPipelinedP2PParent( - const P2PGroupMap& p2p_group_map, const P2PGroupNode& node, - const std::vector::iterator& begin, - const std::vector::iterator& while_loop_iter, - const std::vector::iterator& end, - const CollectiveInComputation& collective_in_computation, - HloReachabilityMap* reachability) { - HloInstruction* recv = node.recv_done->mutable_operand(0); - HloSendRecvInstruction* send_done = - DynCast(node.send_done); - auto in_current_p2p_chain = [&](const HloInstruction* hlo) { - if (hlo->opcode() == HloOpcode::kWhile) { - return node.while_loop == hlo; - } - const HloSendRecvInstruction* p2p = - DynCastOrNull(hlo); - return p2p != nullptr && p2p->channel_id() == send_done->channel_id(); - }; - - for (auto it = begin; it != end; ++it) { - HloInstruction* hlo = *it; - if (!MayInvokeCollectiveOp(hlo, collective_in_computation) || - in_current_p2p_chain(hlo)) { + if (!MayInvokeCollectiveOp(hlo, collective_in_computation)) { continue; } - - HloOpcode opcode = hlo->opcode(); - // Handle a P2P chain when we see its Send-done. - if (opcode == HloOpcode::kRecvDone) { + if (hlo->opcode() == HloOpcode::kWhile && + group.kind == P2PGroupKind::kPipelined && group.GetWhileOp() == hlo) { + // This is the while-op for chain A. No need to add control dependence. continue; } - if (opcode == HloOpcode::kSendDone) { - auto group_it = p2p_group_map.find(hlo->channel_id().value()); - if (group_it == p2p_group_map.end()) { - LOG(INFO) << "Warn unhandled P2P " << hlo->ToString(); - continue; - } - const P2PGroup& p2p_group = group_it->second; - P2PGroupKind kind = p2p_group.kind; - if (kind == P2PGroupKind::kPipelined && - recv->parent() != - p2p_group.nodes[kPipelinedParentNodeIdx].computation) { - // The pipelined P2P in the "child" is already ordered with respected to - // other P2P chains. - continue; - } - if (reachability->IsReachable(recv, hlo)) { - HloInstruction* recv2 = p2p_group - .nodes[kind == P2PGroupKind::kUnpipelined - ? kUnpipelinedNodeIdx - : kPipelinedParentNodeIdx] - .recv_done->mutable_operand(0); - TF_RETURN_IF_ERROR(OrderBefore(reachability, send_done, recv2)); + + if (hlo_query::IsAsyncCollectiveDoneOp(hlo, /*include_send_recv=*/false)) { + if (reachability->IsReachable(start_end.first, hlo)) { + // Order chain A before the async op. + TF_RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, + GetStartOpForDoneOp(hlo))); } else { - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); + // Order the async op before chain A. + TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } - continue; } - - // The hlo is not a Send/Recv instruction. - if (reachability->IsReachable(hlo, send_done)) { - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); + // CustomCall or other op that indirectly invoke collectives. + if (reachability->IsReachable(start_end.first, hlo)) { + // Order chain A before the op. + TF_RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, hlo)); } else { - // TODO(b/309639264): Remove kCustomCall when the NVIDIA bug is fixed. - TF_RETURN_IF_ERROR(OrderBefore( - reachability, send_done, - opcode == HloOpcode::kCustomCall ? hlo : GetStartOpForDoneOp(hlo))); + // Order the op before chain A. + TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } } @@ -645,77 +817,50 @@ Status ChainCollectivesWithPipelinedP2PParent( // Adds control dependence to linearize other collective ops with respect to // the given pipelined P2P chain in the computation for the pipelined -// while-loop, which is ordered as follows: -// Send => Send-Done -// Recv => Recv-Done -// Send => Recv -// All collective ops should be scheduled after Send-Done and Before -Status ChainCollectivesWithPipelinedP2PChild( - const P2PGroupMap& p2p_group_map, const P2PGroupNode& node, - const std::vector::iterator& begin, - const std::vector::iterator& end, +// while-loop. All Collective ops should be scheduled before the chain. +Status LinearizeCollectivesWithPipelinedP2PChild( + const P2PGroupMap& p2p_group_map, const P2PGroup& group, const CollectiveInComputation& collective_in_computation, - HloReachabilityMap* reachability) { - HloInstruction* send_done = node.send_done; - HloSendRecvInstruction* recv = - DynCast(node.recv_done->mutable_operand(0)); - auto in_current_p2p_chain = [&](const HloInstruction* hlo) { - const HloSendRecvInstruction* p2p = - DynCastOrNull(hlo); - return p2p != nullptr && p2p->channel_id() == recv->channel_id(); - }; - - // If an hlo may invoke collective operation and is ordered before the - // Send, checks that it is not reachable to Send-Done and adds control - // dependence to make sure it is scheduled after Send-Done and before Recv. - for (auto it = begin; it != end; ++it) { - HloInstruction* hlo = *it; - if (!MayInvokeCollectiveOp(hlo, collective_in_computation) || - in_current_p2p_chain(hlo)) { + HloComputation* computation, HloReachabilityMap* reachability) { + ChainStartEnd start_end = group.GetChainStartEnd(computation); + + // If an hlo may invoke collective operation, we add control dependence to + // make sure that the hlo is scheduled before the pipelined chain starts. + for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { + // For async collective ops, only the done version of the op passes this + // check, to avoid handling async ops twice. + if (!MayInvokeCollectiveOp(hlo, collective_in_computation)) { continue; } - if (reachability->IsReachable(hlo, send_done) || - reachability->IsReachable(recv, hlo)) { - return InternalError("Detect deadlock in input HLO"); - } HloOpcode opcode = hlo->opcode(); - // Handle a P2P chain when we see its Send-done. - if (opcode == HloOpcode::kRecvDone) { + // Handle a P2P group when we see its Send-done. + if (IsP2POp(hlo) && opcode != HloOpcode::kSendDone) { continue; } - if (opcode == HloOpcode::kSendDone) { + if (hlo->opcode() == HloOpcode::kSendDone) { auto group_it = p2p_group_map.find(hlo->channel_id().value()); if (group_it == p2p_group_map.end()) { continue; } - const P2PGroup& p2p_group = group_it->second; - P2PGroupKind kind = p2p_group.kind; + const P2PGroup& cur_group = group_it->second; + P2PGroupKind kind = cur_group.kind; if (kind == P2PGroupKind::kPipelined && - recv->parent() != - p2p_group.nodes[kPipelinedParentNodeIdx].computation) { - // The pipelined P2P in the "child" is already ordered with respected to - // other P2P chains. + computation == cur_group.ChildComputation()) { + // This is a P2P group for the pipelined in the current while-body. + // We are looking for other collective ops outside this group. continue; } - HloInstruction* recv2 = p2p_group - .nodes[kind == P2PGroupKind::kUnpipelined - ? kUnpipelinedNodeIdx - : kPipelinedParentNodeIdx] - .recv_done->mutable_operand(0); - TF_RETURN_IF_ERROR(OrderBefore(reachability, send_done, recv2)); - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); + ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation); + TF_RETURN_IF_ERROR( + OrderBefore(reachability, cur_start_end.second, start_end.first)); continue; } - // The hlo is not a Send/Recv instruction. - // TODO(b/309639264): Remove kCustomCall when the NVIDIA bug is fixed. - TF_RETURN_IF_ERROR(OrderBefore( - reachability, send_done, - opcode == HloOpcode::kCustomCall ? hlo : GetStartOpForDoneOp(hlo))); - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, recv)); + // Async done, CustomCall, or other ops that indirectly invoke collectives. + TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } return OkStatus(); @@ -723,7 +868,7 @@ Status ChainCollectivesWithPipelinedP2PChild( } // namespace -StatusOr P2PSchedulePreparation::Run( +absl::StatusOr P2PSchedulePreparation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { P2PGroupMap p2p_group_map; @@ -739,6 +884,8 @@ StatusOr P2PSchedulePreparation::Run( // already have information for the while-body. for (auto iter = all_computations.begin(); iter != all_computations.end(); ++iter) { + VLOG(10) << "Gathering P2P groups and collective info for computation " + << (*iter)->name(); TF_RETURN_IF_ERROR(GatherP2PGroupsAndCollectiveInfo( *iter, p2p_in_computation, p2p_group_map, collective_in_computation)); } @@ -757,23 +904,36 @@ StatusOr P2PSchedulePreparation::Run( HloComputation* computation = *iter; auto p2p_in_comp = p2p_in_computation.find(computation); if (p2p_in_comp == p2p_in_computation.end()) { - // No recognized P2P chains in the computation, do nothing. + // No recognized P2P groups in the computation, do nothing. continue; } std::set& p2p_channels = p2p_in_comp->second; + // Connect P2P chains and return the number of chains and the P2P group + // representation for pipelined P2P in the current computation as a + // while-body. TF_ASSIGN_OR_RETURN( - int num_p2p_chains, - ConnectP2PChain(computation, p2p_group_map, p2p_channels)); - if (num_p2p_chains == 0) { + auto result, ConnectP2PChain(computation, p2p_group_map, p2p_channels)); + if (result.first == 0) { continue; } - VLOG(10) << "processing computation " << computation->name() - << " num_p2p_chains " << num_p2p_chains; + VLOG(10) << "Processing computation " << computation->name() + << " num_p2p_chains " << result.first; + + std::unique_ptr reachability = + HloReachabilityMap::Build(computation); + if (result.second != nullptr) { + // The current computation is a while-body with pipelined P2P chain. + // Order all other collectives in a pipelined while-body before the + // pipelined P2P chain. + TF_RETURN_IF_ERROR(LinearizeCollectivesWithPipelinedP2PChild( + p2p_group_map, *result.second, collective_in_computation, computation, + reachability.get())); + } + // Add control dependence to linearize collective operations with respect to - // each P2P chain. - std::unique_ptr reachability; + // other P2P chains. std::vector all_instructions = computation->MakeInstructionPostOrder(); std::vector::iterator begin = all_instructions.begin(); @@ -789,58 +949,35 @@ StatusOr P2PSchedulePreparation::Run( if (group_it == p2p_group_map.end()) { continue; } - - if (reachability == nullptr) { - reachability = HloReachabilityMap::Build(computation); + P2PGroup& group = group_it->second; + P2PGroupKind kind = group.kind; + if (kind == P2PGroupKind::kPipelined && + computation == group.ChildComputation()) { + // We already linearize pipelined P2P chains in while-body with respect + // to other collectives. + continue; } - P2PGroup& p2p_group = group_it->second; - P2PGroupKind kind = p2p_group.kind; - VLOG(10) << "connect other collectives with channel " << channel - << " kind " << (int)kind; - - if (kind == P2PGroupKind::kUnpipelined && - hlo->opcode() == HloOpcode::kRecvDone) { - // Case 1: Unpipelined P2P chain - // Send => Recv => Recv-Done => Send-Done - // We intend to schedule collective ops ordered before Recv-Done before - // Send and collective ops ordered after Recv-Done after Send-Done. - TF_RETURN_IF_ERROR(ChainCollectivesWithUnpipelinedP2P( - p2p_group_map, p2p_group.nodes[kUnpipelinedNodeIdx], begin, - instr_it, end, collective_in_computation, reachability.get())); - } else if (kind == P2PGroupKind::kPipelined && - computation == - p2p_group.nodes[kPipelinedParentNodeIdx].computation && - hlo->opcode() == HloOpcode::kRecvDone) { - // Case 2: Pipelined P2P chain in the "parent", that is, the computation - // containing the while-loop: - // Recv => Recv-Done => While-loop => Send => SendDone - // We intend to schedule collective ops ordered before the while-loop - // before Recv and collective ops ordered after the while-loop after - // Send-Done. - const HloInstruction* while_loop = - p2p_group.nodes[kPipelinedParentNodeIdx].while_loop; - std::vector::iterator while_loop_it = instr_it + 1; - while ((*while_loop_it) != while_loop) while_loop_it++; - TF_RETURN_IF_ERROR(ChainCollectivesWithPipelinedP2PParent( - p2p_group_map, p2p_group.nodes[kPipelinedParentNodeIdx], begin, - while_loop_it, end, collective_in_computation, reachability.get())); - } else if (kind == P2PGroupKind::kPipelined && - computation != - p2p_group.nodes[kPipelinedParentNodeIdx].computation && - hlo->opcode() == HloOpcode::kSend) { - // Case 3: Pipelined P2P chain in the "child", that is, the computation - // for the while-body: - // Send => Send-Done. ... Recv => Recv-Done - // All collective ops should be scheduled after Send-Done and Before - // Recv. - TF_RETURN_IF_ERROR(ChainCollectivesWithPipelinedP2PChild( - p2p_group_map, p2p_group.nodes[kPipelinedChildNodeIdx], begin, end, - collective_in_computation, reachability.get())); + if (group.InCycle() && group.runtime_stream != kStream1) { + // We process a chain with two groups when we see the group for + // kStream1. + continue; } - VLOG(10) << "finish connect other collectives with channel " << channel - << " kind " << (int)kind; + ChainStartEnd start_end = group.GetChainStartEnd(computation); + + // Handle the group when we see the beginning of the chain. + if (start_end.first != hlo) { + continue; + } + VLOG(10) << "linearize other collectives with respect to channel " + << hlo->ToString(); + + TF_RETURN_IF_ERROR(LinearizeCollectivesWithOtherP2P( + p2p_group_map, group, collective_in_computation, instr_it, begin, end, + reachability.get())); + VLOG(10) << "finish connect other collectives with channel "; } } + return true; } diff --git a/xla/service/p2p_schedule_preparation.h b/xla/service/p2p_schedule_preparation.h index 84c5b87dcc6b3..d28a73f2ab3b4 100644 --- a/xla/service/p2p_schedule_preparation.h +++ b/xla/service/p2p_schedule_preparation.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,81 +21,94 @@ limitations under the License. namespace xla { -// P2PSchedulePreparation is a pass to linearize point-to-point operation chain +// P2PSchedulePreparation is a pass to linearize point-to-point operation chains // to prepare for any HLO scheduler. In particular, this pass currently does the // following: // (1) For an unpipelined P2P Send-Recv chain, add control dependence to // express this ordering: // recv => send => recv-done => send-done // -// (2) For a pipelined P2P Send-Recv chain, add control dependence to the -// while-body to express this ordering: -// send => recv -// In the computation with such a while-loop, the data dependence already -// expresses this ordering: -// recv => recv-done => while-loop => send => send-done +// (2.1) For a single pipelined P2P Send-Recv chain, add control dependence to +// the while-body to express this ordering: +// recv-done => send-done => recv => send +// In the computation with such a while-loop, add control dependence to +// express this ordering: +// recv => send +// recv-done => send-done +// The data dependence already express this dependence: +// recv, send => while-loop => recv-done, send-done +// +// (2.2) For two pipelined P2P Send-Recv chain together forms a cycle, add +// control dependence to the while-body to express this ordering: +// recv-done.0 => send-done.0 => recv-done.1 => send-done.1 => recv.0 => +// send.0 => recv.1 => send.1 +// In the computation with such a while-loop, add control dependence to +// express this ordering: +// recv.0 => send.0 => recv.1 => send.1 +// recv-done.0 => send-done.0 => recv-done.1 => send-done.1 +// The data dependence already express this dependence: +// recv.0/1, send.0/1 => while-loop => recv-done.0/1, send-done.0/1 // // (3) For a pipelined P2P Send-Recv chain, if the while-body has other // collective ops, we add control dependence to ensure that the pipelined -// Send-done is ordered before other P2P chains while the pipelined -// Recv is ordered after other P2P chains. For example, if the other collective -// op is another Send-Recv chain, we make the pipelined Send-done the control -// predecessor of the other Recv and the pipelined Recv the control successor of -// the other other Send. Here is an example to illustrate the problem we -// address: +// Send-done (or Send-done.1 in the cyclic case) is ordered before other P2P +// chains while the pipelined Recv ( or Recv.1 in the cyclic case) is ordered +// after other P2P chains. For example, if the other collective op is another +// Send-Recv chain, we make the pipelined Send-done the control predecessor of +// the other Recv and the pipelined Recv the control successor of the other +// other Send. Here is an example to illustrate the problem we address: // // Assume a while-body with the following HLO collective-permute operations: -// collective-permute-start-1 = (u32[2], u32[2]) +// collective-permute-start.1 = (u32[2], u32[2]) // collective-permute-start(data), channel_id=1... -// collective-permute-done-1 = u32[2], channel_id=1 -// use of collective-permute-done-1 result -// collective-permute-start-2 = (u32[2], u32[2]) +// collective-permute-done.1 = u32[2], channel_id=1 +// use of collective-permute-done.1 result +// collective-permute-start.2 = (u32[2], u32[2]) // collective-permute-start(data), channel_id=2... -// collective-permute-done-2 = u32[2], channel_id=2 -// use of collective-permute-done-2 result +// collective-permute-done.2 = u32[2], channel_id=2 +// use of collective-permute-don.2 result // // Now assume we transform the collective-permute operations into two P2P // Send-Recv chains, the block of code will become something like this: -// after-all-1 = token[] after-all() -// recv-1 = (u32[2], token[]) recv(after-all-1), channel_id=1 ... -// send-1 = (u32[2], token[]) send(data, after-all-1), channel_id=1 ... -// recv-done-1 = (u32[2], token[]) recv-done(recv-1), channel_id=1 ... -// send-done-1 = token[] send-done(send-1), channel_id=1 ... -// use of recv-done-1 result -// after-all-2 = token[] after-all() -// recv-2 = (u32[2], token[]) recv(after-all-2), channel_id=2 ... -// send-2 = (u32[2], token[]) send(data, after-all-2), channel_id=2 ... -// recv-done-2 = (u32[2], token[]) recv-done(recv-2), channel_id=2 ... -// send-done-2 = token[] send-done(send-2), channel_id=2 ... -// use of recv-done-2 result +// after-all.1 = token[] after-all() +// recv.1 = (u32[2], token[]) recv(after-all.1), channel_id=1 ... +// send.1 = (u32[2], token[]) send(data, after-all.1), channel_id=1 ... +// recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=1 ... +// send-done.1 = token[] send-done(send.1), channel_id=1 ... +// use of recv-done.1 result +// after-all.2 = token[] after-all() +// recv.2 = (u32[2], token[]) recv(after-all.2), channel_id=2 ... +// send.2 = (u32[2], token[]) send(data, after-all.2), channel_id=2 ... +// recv-done.2 = (u32[2], token[]) recv-done(recv.2), channel_id=2 ... +// send-done.2 = token[] send-done(send.2), channel_id=2 ... +// use of recv-done.2 result // // If the while-loop is not pipelined, this pass adds control dependence to // make sure the first Send-Recv chain finish before the second Send-Recv // starts. // // If the while-loop is pipelined for the first Send-Recv chain, then the -// first Recv and the last Send of the chain are moved to the computation -// that calls the while-loop, and the block of code in the while-body will -// become something like this: - -// after-all-1 = token[] after-all() -// send-1 = (u32[2], token[]) send(data, after-all-1), channel_id=1 ... -// send-done-1 = token[] send-done(send-1), channel_id=1 ... -// use of recv-done-1 result from the previous iteration or the computation -// that calls the while-loop (for the first iteration) -// -// after-all-2 = token[] after-all() -// recv-2 = (u32[2], token[]) recv(after-all-2), channel_id=2 ... -// send-2 = (u32[2], token[]) send(data, after-all-2), channel_id=2 ... -// recv-done-2 = (u32[2], token[]) recv-done(recv-2), channel_id=2 ... -// send-done-2 = token[] send-done(send-2), channel_id=2 ... -// use of recv-done-2 result -// -// recv-1 = (u32[2], token[]) recv(after-all-1), channel_id=1 ... -// recv-done-1 = (u32[2], token[]) recv-done(recv-1), channel_id=1 ... +// first Recv/Send and the last Recv-done/Send-done of the chain are moved to +// the computation that calls the while-loop, and the block of code in the +// while-body will become something like this: +// recv.1 = (u32[2], u32[], token[]) get-tuple-element(param), index=1 +// recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=1 +// send.1 = (u32[2], u32[], token[]) get-tuple-element(param), index=4 +// send-done.1 = token[] send-done(send.1), channel_id=1 +// use of recv-done.1 result +// after-all.2 = token[] after-all() +// recv.2 = (u32[2], token[]) recv(after-all.2), channel_id=2 ... +// send.2 = (u32[2], token[]) send(data, after-all.2), channel_id=2 ... +// recv-done.2 = (u32[2], token[]) recv-done(recv.2), channel_id=2 ... +// send-done.2 = token[] send-done(send.2), channel_id=2 ... +// use of recv-done.2 result +// after-all.1.n = token[] after-all() +// recv.1.n = (u32[2], u32[], token[]) recv(after-all.1.n), channel_id=1 +// send.1.n = (u32[2], u32[], token[]) send(new-data, after-all.1.n), +// channel_id=1 // // In this case, we make send-done-1 the control predecessor of recv-2 and -// send-done-2 the control predecessor of recv-1 to ensure that the second +// send-done-2 the control predecessor of recv-1.n to ensure that the second // Send-Recv chain is executed after the Send for the first chain finishes and // before the Recv for the first chain starts. // @@ -183,7 +196,7 @@ class P2PSchedulePreparation : public HloModulePass { using HloPassInterface::Run; // Runs P2PSchedulePreparation pass on computations in 'module'. // Returns whether the 'module' was changed. - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/p2p_schedule_preparation_test.cc b/xla/service/p2p_schedule_preparation_test.cc index 83a9fe56085f7..bcd2bedef7fd0 100644 --- a/xla/service/p2p_schedule_preparation_test.cc +++ b/xla/service/p2p_schedule_preparation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/algorithm/container.h" #include "absl/log/log.h" @@ -36,8 +35,9 @@ namespace { class P2PSchedulePreparationTest : public HloTestBase { public: - // Verifies that no control dependence enforces are added to the P2P chain. - void VerifyP2PNotTransformed(HloModule* module, std::string suffix = "") { + // Verifies that no control dependence is added to the P2P group. + void VerifyP2PNotTransformed(HloModule* module, + const std::string& suffix = "") { HloInstruction* recv = FindInstruction(module, "recv" + suffix); HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); HloInstruction* send_done = FindInstruction(module, "send-done" + suffix); @@ -46,10 +46,9 @@ class P2PSchedulePreparationTest : public HloTestBase { EXPECT_EQ(send_done->control_predecessors().size(), 0); } - // Verifies that the control dependence enforces this ordering for an - // unpipelined Send-Recv chain: - // recv => send => recv-done => send-done. - void VerifyUnpipelinedP2P(HloModule* module, std::string suffix = "") { + // Verifies that the control dependence enforces this ordering: + // recv => send => recv-done => send-done + void VerifyP2P1GroupChain(HloModule* module, const std::string& suffix) { HloInstruction* send = FindInstruction(module, "send" + suffix); HloInstruction* recv = FindInstruction(module, "recv" + suffix); HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); @@ -59,27 +58,70 @@ class P2PSchedulePreparationTest : public HloTestBase { EXPECT_EQ(send_done->control_predecessors()[0], recv_done); } + // Verifies that the control dependence enforces this ordering for an + // unpipelined Send-Recv chain: + // recv => send => recv-done => send-done + void VerifyUnpipelinedP2P(HloModule* module, const std::string& suffix = "") { + VerifyP2P1GroupChain(module, suffix); + } + // Verifies that the control dependence enforces this ordering for a pipelined // Send-Recv chain in the while-body: - // send => recv. - void VerifyPipelinedP2PChild(HloModule* module, std::string suffix = "") { - HloInstruction* send = FindInstruction(module, "send" + suffix); - HloInstruction* recv = FindInstruction(module, "recv" + suffix); - HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); - HloInstruction* send_done = FindInstruction(module, "send-done" + suffix); - // If the while-body has other P2P, the pipelined Recv should also the - // Send-done of the other P2P as control predecessors. - EXPECT_EQ(1, absl::c_count(recv->control_predecessors(), send)); - EXPECT_EQ(recv_done->control_predecessors().size(), 0); - EXPECT_EQ(send_done->control_predecessors().size(), 0); + // recv => send => recv-done => send-done + void VerifyPipelinedP2PChild(HloModule* module, + const std::string& suffix = "") { + VerifyP2P1GroupChain(module, suffix); } - // Verifies that no control dependence are added to a pipelined Send-Recv - // in the computation with the while-loop as the data dependence already - // expresses this ordering: - // recv => recv-done => while-loop => send => send-done. - void VerifyPipelinedP2PParent(HloModule* module, std::string suffix = "") { - VerifyP2PNotTransformed(module, suffix); + // Verifies that the control dependence enforces this ordering for a pipelined + // Send-Recv chain in the while-loop calling computation: + // recv => send => while-loop => recv-done => send-done. + void VerifyPipelinedP2PParent(HloModule* module, + const std::string& suffix = "") { + VerifyP2P1GroupChain(module, suffix); + } + + // Verifies that the control dependence enforces this ordering: + // recv.0 => send.0 => recv.1 => send.1 => + // recv-done.0 => recv-done.1 => send-done.0 => send-done.1 + void VerifyP2P2GroupChain(HloModule* module, const std::string& suffix0, + const std::string& suffix1) { + HloInstruction* send0 = FindInstruction(module, "send" + suffix0); + HloInstruction* recv0 = FindInstruction(module, "recv" + suffix0); + HloInstruction* recv_done0 = FindInstruction(module, "recv-done" + suffix0); + HloInstruction* send_done0 = FindInstruction(module, "send-done" + suffix0); + HloInstruction* send1 = FindInstruction(module, "send" + suffix1); + HloInstruction* recv1 = FindInstruction(module, "recv" + suffix1); + HloInstruction* recv_done1 = FindInstruction(module, "recv-done" + suffix1); + HloInstruction* send_done1 = FindInstruction(module, "send-done" + suffix1); + + EXPECT_EQ(recv_done1->control_predecessors()[0], recv_done0); + EXPECT_EQ(send_done0->control_predecessors()[0], recv_done1); + EXPECT_EQ(send_done1->control_predecessors()[0], send_done0); + + EXPECT_EQ(send0->control_predecessors()[0], recv0); + EXPECT_EQ(recv1->control_predecessors()[0], send0); + EXPECT_EQ(send1->control_predecessors()[0], recv1); + + EXPECT_EQ(recv_done0->control_predecessors()[0], send1); + } + + // Verifies that the control dependence enforces this ordering for a pipelined + // chain with two Send-Recv groups in a while-body: + // recv.0 => send.0 => recv.1 => send.1 => + // recv-done.0 => send-done.0 => recv-done.1 => send-done.1 + void VerifyPipelined2P2PChild(HloModule* module, const std::string& suffix0, + const std::string& suffix1) { + VerifyP2P2GroupChain(module, suffix0, suffix1); + } + + // Verifies that the control dependence enforces this ordering for a pipelined + // chain with two Send-Recv groups in the while-loop calling computation: + // recv.0 => send.0 => recv.1 => send.1 => + // => recv-done.0 => send-done.0 => recv-done.1 => send-done.1 + void VerifyPipelined2P2PParent(HloModule* module, const std::string& suffix0, + const std::string& suffix1) { + VerifyP2P2GroupChain(module, suffix0, suffix1); } }; @@ -288,7 +330,6 @@ TEST_F(P2PSchedulePreparationTest, NestedP2PChainTransformed) { // the purpose of testing its ordering with respect to P2P chain. std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, bool other_p2p_in_while = false, - bool deadlock_in_while = false, bool test_custom_call = false) { // This is to support the while-loop with nested P2P chains called from the // main computation. @@ -327,44 +368,38 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, // while-loop with nested P2P chains. constexpr char kUnnestedResult[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - ROOT collective-permute.2 = f32[1, 1024, 1024] collective-permute(while-result-1), + collective-permute.2 = f32[1, 1024, 1024] collective-permute(init), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, collective-permute.2) )"; // Similar to the above, but for test_custom_call = true. constexpr char kUnnestedResultWithCustomCall[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - ROOT custom-call = f32[1, 1024, 1024] custom-call(while-result-1), custom_call_target="my_custom_call" + custom-call = f32[1, 1024, 1024] custom-call(init), + custom_call_target="my_custom_call" + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, custom-call) )"; // This is the result for the main computation, if it has another while-loop // with nested P2P chains. constexpr char kNestedResult[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - while-init-2 = (u32[], f32[1, 1024, 1024]) tuple(c0, while-result-1) - while-result-2 = (u32[], f32[1, 1024, 1024]) while(while-init-2), + while-init-2 = (u32[], f32[1, 1024, 1024]) tuple(c0, init) + while-2 = (u32[], f32[1, 1024, 1024]) while(while-init-2), body=while-body-2, condition=while-cond-2, backend_config={"known_trip_count":{"n":"25"}} - ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(while-result-2), index=1 + while-result-2 = f32[1, 1024, 1024] get-tuple-element(while-2), index=1 + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, while-result-2) )"; constexpr char kPipelinedWhileBodyWithoutOtherP2P[] = R"( while-body { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) + param = (u32[], (f32[1, 1024, 1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - send-data = get-tuple-element(param), index=1 - recv-data = get-tuple-element(param), index=2 - after-all.1 = token[] after-all() - send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), - channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.1 = token[] send-done(send.1), channel_id=1 - recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + recv-done.1.q = (f32[1, 1024, 1024], token[]) get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 c1 = u32[] constant(1) new-count = u32[] add(count, c1) @@ -383,30 +418,37 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} new-data = f32[1, 1024, 1024] add(c, collective-permute.1) - recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1 - new-recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + after-all.1 = token[] after-all() + send.1 = (f32[1, 1024, 1024], token[]) send(new-data, after-all.1), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv.1 = (f32[1, 1024, 1024], token[]) recv(after-all.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } - ROOT body-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(new-count, new-data, new-recv-data) + ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) } )"; constexpr char kPipelinedWhileBodyWithOtherP2P[] = R"( while-body { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) + param = (u32[], (f32[1, 1024, 1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - send-data = get-tuple-element(param), index=1 - recv-data = get-tuple-element(param), index=2 - - after-all.1 = token[] after-all() - send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), - channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.1 = token[] send-done(send.1), channel_id=1 - recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + recv-done.1.q = (f32[1, 1024, 1024], token[])get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 c1 = u32[] constant(1) new-count = u32[] add(count, c1) @@ -421,68 +463,44 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, d = f32[1, 1024, 1024] tan(c) s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} - new-data-0 = f32[1, 1024, 1024] add(c, s) - - recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1 - new-recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + collective-permute.1 = f32[1, 1024, 1024] collective-permute(s), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + send-data = f32[1, 1024, 1024] add(c, collective-permute.1) after-all.4 = token[] after-all() send.4 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.4), channel_id=4, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + } send-done.4 = token[] send-done(send.4), channel_id=4 recv.4 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.4), channel_id=4, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + } recv-done.4 = (f32[1, 1024, 1024], token[]) recv-done(recv.4), channel_id=4 - recv-data-4 = f32[1, 1024, 1024] get-tuple-element(recv-done.4), index=0 - new-data = f32[1, 1024, 1024] add(new-data-0, recv-data-4) - - ROOT body-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(new-count, new-data, new-recv-data) - } -)"; + new-data = f32[1, 1024, 1024] get-tuple-element(recv-done.4), index=0 - constexpr char kPipelinedWhileBodyDeadlock[] = R"( - while-body { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) - count = get-tuple-element(param), index=0 - send-data = get-tuple-element(param), index=1 - recv-data = get-tuple-element(param), index=2 - - collective-permute.1 = f32[1, 1024, 1024] collective-permute(send-data), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} after-all.1 = token[] after-all() - send.1 = (f32[1, 1024, 1024], u32[], token[]) send(collective-permute.1, after-all.1), + send.1 = (f32[1, 1024, 1024], token[]) send(new-data, after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.1 = token[] send-done(send.1), channel_id=1 - recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv.1 = (f32[1, 1024, 1024], token[]) recv(after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" } - - c1 = u32[] constant(1) - new-count = u32[] add(count, c1) - replica = u32[] replica-id() - c10 = u32[] constant(10) - sum = u32[] add(replica, c10) - sum2 = u32[] add(sum, count) - conv = f32[] convert(sum2) - p = f32[1, 1024, 1024] broadcast(conv), dimensions={} - b = f32[1, 1024, 1024] add(p, recv-data) - c = f32[1, 1024, 1024] multiply(b, b) - d = f32[1, 1024, 1024] tan(c) - s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, - lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} - new-data = f32[1, 1024, 1024] add(c, s) - - recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1 - new-recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 - - ROOT body-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(new-count, new-data, new-recv-data) + recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) } )"; @@ -490,16 +508,16 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, HloModule test while-cond { - param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0) + param = (u32[], (f32[1, 1024, 1024], u32[], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 ub = u32[] constant(25) ROOT cond-result = pred[] compare(count, ub), direction=LT } - // The pipelined while-body goes here. + // The code that support the while-loop with nested P2P chains goes here. %s - // The code that support the while-loop with nested P2P chains goes here. + // The pipelined while-body goes here. %s ENTRY test-computation { @@ -508,24 +526,34 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, init = f32[1, 1024, 1024] broadcast(f0), dimensions={} after-all.2 = token[] after-all() - recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1, + recv.2 = (f32[1, 1024, 1024], token[]) recv(after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2), channel_id=1 - recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0 + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send.2 = (f32[1, 1024, 1024], token[]) send(init, after-all.2), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } - while-init = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(c0, init, recv-data) - while-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) while(while-init), + while-init = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(c0, recv-done.2, send-done.2) + while-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + while(while-init), body=while-body, condition=while-cond, backend_config={"known_trip_count":{"n":"25"}} - send-data = f32[1, 1024, 1024] get-tuple-element(while-result), index=2 - send.2 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.2), - channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } - send-done.2 = token[] send-done(send.2), channel_id=1 + recv-done.2.q = (f32[1, 1024, 1024], token[]) get-tuple-element(while-result), index=1 + recv-data.2.q = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 // The code for the computation result goes here. %s @@ -534,10 +562,8 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, const char* while_str = nested_p2p_in_main ? kWhileForMain : kEmpty; const char* pipelined_while_body_str = - deadlock_in_while - ? kPipelinedWhileBodyDeadlock - : (other_p2p_in_while ? kPipelinedWhileBodyWithOtherP2P - : kPipelinedWhileBodyWithoutOtherP2P); + other_p2p_in_while ? kPipelinedWhileBodyWithOtherP2P + : kPipelinedWhileBodyWithoutOtherP2P; const char* result_str = nested_p2p_in_main ? kNestedResult : (test_custom_call ? kUnnestedResultWithCustomCall @@ -546,19 +572,6 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, result_str); } -TEST_F(P2PSchedulePreparationTest, PipelinedP2PChainDeadlocked) { - std::string kModuleStr = GetPipelinedP2PModuleString( - /*nested_p2p_in_main=*/false, /*other_p2p_in_while=*/false, - /*deadlock_in_while=*/true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - P2PSchedulePreparation preparation; - auto status = preparation.Run(module.get()); - EXPECT_EQ(status.ok(), false); - EXPECT_THAT(status.status().message(), - ::testing::HasSubstr("deadlock in input HLO")); -} - TEST_F(P2PSchedulePreparationTest, UnnestedPipelinedP2PChainTransformed) { std::string kModuleStr = GetPipelinedP2PModuleString(); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -568,26 +581,27 @@ TEST_F(P2PSchedulePreparationTest, UnnestedPipelinedP2PChainTransformed) { EXPECT_TRUE(changed); VLOG(10) << module->ToString(); - // Verify the pipelined P2P chain in the whild-body. + // Verify the pipelined P2P chain in the while-body. VerifyPipelinedP2PChild(module.get(), ".1"); // Verify the pipelined P2P chain in the main computation. VerifyPipelinedP2PParent(module.get(), ".2"); - // Verify in the while-body collective-permute is scheduled after Send-done - // and before Recv. - HloInstruction* send_done_1 = FindInstruction(module.get(), "send-done.1"); - HloInstruction* recv = FindInstruction(module.get(), "recv.1"); + // Verify in the while-body collective-permute is scheduled before recv. + HloInstruction* recv_1 = FindInstruction(module.get(), "recv.1"); HloInstruction* collective_1 = FindInstruction(module.get(), "collective-permute.1"); - EXPECT_EQ(collective_1->control_predecessors()[0], send_done_1); - EXPECT_EQ(1, absl::c_count(recv->control_predecessors(), collective_1)); + EXPECT_EQ(recv_1->control_predecessors()[0], collective_1); - // Verify in the main computation collective-permute is scheduled after the - // Send-done for the pipelined while-loop. + // Verify in the main computation collective-permute is either scheduled + // after send-done or before recv of the pipelined P2P chain. HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); HloInstruction* collective_2 = FindInstruction(module.get(), "collective-permute.2"); - EXPECT_EQ(collective_2->control_predecessors()[0], send_done_2); + EXPECT_TRUE((!collective_2->control_predecessors().empty() && + collective_2->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == collective_2)); } TEST_F(P2PSchedulePreparationTest, NestedPipelinedP2PChainTransformed) { @@ -600,18 +614,22 @@ TEST_F(P2PSchedulePreparationTest, NestedPipelinedP2PChainTransformed) { EXPECT_TRUE(changed); VLOG(10) << module->ToString(); - // Verify the pipelined P2P chain in the whild-body. + // Verify the pipelined P2P chain in the while-body. VerifyPipelinedP2PChild(module.get(), ".1"); // Verify the pipelined P2P chain in the main computation. VerifyPipelinedP2PParent(module.get(), ".2"); // Verify the unpipelined P2P chain in the other while-body. VerifyUnpipelinedP2P(module.get(), ".3"); - // Verify that the while-loop with nested P2P is schedule after the last - // Send-done of the pipeline P2P chain. - HloInstruction* send_done = FindInstruction(module.get(), "send-done.2"); - HloInstruction* while_user = FindInstruction(module.get(), "while-result-2"); - EXPECT_EQ(while_user->control_predecessors()[0], send_done); + // Verify in the while-loop with nested P2P is either scheduled after + // end-done or before recv of the pipelined P2P chain. + HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); + HloInstruction* while_2 = FindInstruction(module.get(), "while-2"); + EXPECT_TRUE((!while_2->control_predecessors().empty() && + while_2->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == while_2)); } TEST_F(P2PSchedulePreparationTest, @@ -625,23 +643,18 @@ TEST_F(P2PSchedulePreparationTest, EXPECT_TRUE(changed); VLOG(10) << module->ToString(); - // Verify the pipelined P2P chain in the whild-body. + // Verify the pipelined P2P chain in the while-body. VerifyPipelinedP2PChild(module.get(), ".1"); // Verify the pipelined P2P chain in the main computation. VerifyPipelinedP2PParent(module.get(), ".2"); // Verify the other unpipelined P2P chain in the while-body. VerifyUnpipelinedP2P(module.get(), ".4"); - // Verify that in the pipelined while-body, the pipelined Send is ordered - // before other P2P while the pipelined Recv is ordered after other P2P. - HloInstruction* pipelined_send_done = - FindInstruction(module.get(), "send-done.1"); + // Verify that in the pipelined while-body, the pipelined recv is ordered + // after other P2P. HloInstruction* pipelined_recv = FindInstruction(module.get(), "recv.1"); - HloInstruction* other_recv = FindInstruction(module.get(), "recv.4"); HloInstruction* other_send_done = FindInstruction(module.get(), "send-done.4"); - EXPECT_EQ(1, absl::c_count(other_recv->control_predecessors(), - pipelined_send_done)); EXPECT_EQ(1, absl::c_count(pipelined_recv->control_predecessors(), other_send_done)); } @@ -650,18 +663,289 @@ TEST_F(P2PSchedulePreparationTest, UnnestedPipelinedP2PChainWithCustomCallTransformed) { std::string kModuleStr = GetPipelinedP2PModuleString( /*nested_p2p_in_main=*/false, /*other_p2p_in_while=*/false, - /*deadlock_in_while=*/false, /*test_custom_call=*/true); + /*test_custom_call=*/true); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kModuleStr))); P2PSchedulePreparation preparation; TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); EXPECT_TRUE(changed); - // Verify in the main computation custom-call is scheduled after the - // Send-done for the pipelined while-loop. + // Verify in the main computation, custom-call is either scheduled after + // end-done or before recv of the pipelined P2P chain. HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); HloInstruction* custom_call = FindInstruction(module.get(), "custom-call"); - EXPECT_EQ(custom_call->control_predecessors()[0], send_done_2); + EXPECT_TRUE((!custom_call->control_predecessors().empty() && + custom_call->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == custom_call)); +} + +TEST_F(P2PSchedulePreparationTest, PipelinedP2PChain2Transformed) { + const char* const kModuleStr = R"( + HloModule test + +cond { + param = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(10) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) parameter(0) + count = get-tuple-element(param), index=0 + + recv-done.0.f = (u32[2], token[]) get-tuple-element(param), index=1 + recv-data.0 = u32[2] get-tuple-element(recv-done.0.f), index=0 + recv-done.1.f = (u32[2], token[]) get-tuple-element(param), index=2 + recv-data.1 = u32[2] get-tuple-element(recv-done.1.f), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + // The Recv "rotated" from the beginning of the loop to the end of the loop. + after-all.0.n = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.0 = (u32[2], u32[], token[]) send(s, after-all.0.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.1.n = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1.n), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + send.1 = (u32[2], u32[], token[]) send(s, after-all.1.n), + channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + ROOT result = (u32[], (u32[2], token[]), (u32[2], token[]), token[], token[]) + tuple(new_count, recv-done.0, recv-done.1, send-done.0, send-done.1) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + + // Peel off both Recv. + after-all.0.p = token[] after-all() + recv.2 = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.2 = (u32[2], u32[], token[]) send(init, after-all.0.p), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (u32[2], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + after-all.1.p = token[] after-all() + recv.3 = (u32[2], u32[], token[]) recv(after-all.1.p), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + send.3 = (u32[2], u32[], token[]) send(init, after-all.1.p), + channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.3 = (u32[2], token[]) recv-done(recv.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.3 = token[] send-done(send.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + // This is the pipelined loop. + while_init = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) tuple(c0, recv-done.2, recv-done.3, send-done.2, send-done.3) + while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), + token[], token[]) while(while_init), body=body, condition=cond, + backend_config={"known_trip_count":{"n":"10"}} + + // This is the remaining Send/Send-done/Recv-done for the pipeline. + // Use .q as suffix for HLO name. + recv-done.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 + recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 + + recv-done.1.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=2 + recv-data.1.q = u32[2] get-tuple-element(recv-done.1.q), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0.q, recv-data.1.q) + + s = u32[2] add(c1, recv-data) + + ROOT result = u32[2] add(s, recv-data) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + P2PSchedulePreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + + // Verify the pipelined P2P chain in the while-body. + VerifyPipelined2P2PChild(module.get(), ".0", ".1"); + // Verify the pipelined P2P chain in the main computation. + VerifyPipelined2P2PParent(module.get(), ".2", ".3"); +} + +TEST_F(P2PSchedulePreparationTest, UnpipelinedP2PChain2Transformed) { + const char* const kModuleStr = R"( + HloModule test + +cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(11) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = u32[2] get-tuple-element(param), index=1 + + after-all.0.n = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.1 = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + send.1 = (u32[2], u32[], token[]) send(send-data, after-all.1), + channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond, + backend_config={"known_trip_count":{"n":"11"}} + ROOT recv-data = u32[2] get-tuple-element(while_result), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + P2PSchedulePreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + EXPECT_TRUE(changed); + + // Verify the unpipelined P2P chain with two channels in the while-body. + VerifyP2P2GroupChain(module.get(), ".0", ".1"); } } // namespace diff --git a/xla/service/pattern_matcher.h b/xla/service/pattern_matcher.h index 041d500006ad5..242f2eff0d634 100644 --- a/xla/service/pattern_matcher.h +++ b/xla/service/pattern_matcher.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -39,8 +40,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/layout_util.h" -#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" @@ -96,6 +97,9 @@ namespace xla { // contracting dimensions. // - WithReplicaGroups: Collective instruction's replica groups matches the // given pattern. +// - WithSharding: Instruction's sharding is equal to the given sharding. +// - WithControlDeps: Instruction's control predecessors/successors match +// the given list of instructions. // // Shape(): // - EqualTo @@ -1867,6 +1871,53 @@ class HloInstructionPatternOneUserImpl } }; +class HloInstructionPatternNumUserImpl { + public: + explicit constexpr HloInstructionPatternNumUserImpl(int64_t user_num) + : user_num_(user_num) {} + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != user_num_) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly " << user_num_ << " users."; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64_t indent = 0) const { + *os << "which has exactly " << user_num_ + << " users (but possibly is used multiple times by " + "same instruction)"; + } + + private: + int64_t user_num_; +}; + +class HloInstructionPatternAtMostNumUserImpl { + public: + explicit constexpr HloInstructionPatternAtMostNumUserImpl(int64_t user_num) + : user_num_(user_num) {} + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->user_count() > user_num_) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected less than or equal " << user_num_ + << " users."; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64_t indent = 0) const { + *os << "which has less than or equal " << user_num_ + << " users (but possibly is used multiple times by " + "same instruction)"; + } + + private: + int64_t user_num_; +}; + class HloInstructionPatternComparisonDirectionImpl { public: explicit constexpr HloInstructionPatternComparisonDirectionImpl( @@ -2136,6 +2187,63 @@ class HloInstructionShardingImpl { std::optional sharding_; }; +class HloInstructionControlDepsImpl { + public: + explicit HloInstructionControlDepsImpl( + absl::Span preds, + absl::Span succs) + : preds_(preds), succs_(succs) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64_t indent = 0) const { + auto print_deps = [os](absl::Span deps, + absl::string_view type) { + if (deps.empty()) { + *os << "no control " << type; + } else { + *os << "control " << type << " {" << absl::StrJoin(deps, ",", fmt) + << "}"; + } + }; + + *os << "with "; + print_deps(preds_, "predecessors"); + *os << " and "; + print_deps(succs_, "successors"); + } + + private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + auto match_deps = [&](absl::Span expected_deps, + const PtrVec& actual_deps, + absl::string_view type) { + if (!absl::c_equal(expected_deps, actual_deps)) { + EXPLAIN << "HloInstruction expected to have control " << type << " {" + << absl::StrJoin(expected_deps, ",", fmt) << "} but has {" + << absl::StrJoin(actual_deps, ",", fmt) << "}"; + return false; + } + return true; + }; + return match_deps(preds_, inst->control_predecessors(), "predecessors") && + match_deps(succs_, inst->control_successors(), "successors"); + } + + static void fmt(std::string* out, const HloInstruction* inst) { + absl::StrAppend(out, inst->name()); + }; + + absl::Span preds_, succs_; +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -2418,6 +2526,18 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOneUserImpl()); } + // Modifies the pattern to match if the instruction is used by exactly + // user_num times by other instruction. + constexpr auto WithNumUser(int64_t user_num) const { + return AppendImpl(HloInstructionPatternNumUserImpl(user_num)); + } + + // Modifies the pattern to match if the instruction is used by less than + // user_num times by other instruction. + constexpr auto WithAtMostNumUser(int64_t user_num) const { + return AppendImpl(HloInstructionPatternAtMostNumUserImpl(user_num)); + } + // Modifies the pattern to match only if the instruction has the given // comparison direction. auto WithComparisonDirection(ComparisonDirection direction) const { @@ -2453,6 +2573,11 @@ class HloInstructionPattern { HloInstructionShardingImpl(ParseSharding(sharding).value())); } + auto WithControlDeps(absl::Span preds, + absl::Span succs) { + return AppendImpl(HloInstructionControlDepsImpl(preds, succs)); + } + void DescribeTo(std::ostream* os, int64_t indent = 0) const { impl_.DescribeTo(os, indent); } @@ -2525,6 +2650,11 @@ XLA_NULLOP_PATTERN(ReplicaId) #define XLA_UNOP_PATTERN(NAME) \ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ \ + template \ + inline auto NAME(HloInstructionType** matched_inst) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ + } \ + \ template \ inline auto NAME(Arg&& arg) { \ return Op() \ @@ -2550,6 +2680,7 @@ XLA_UNOP_PATTERN(Cos) XLA_UNOP_PATTERN(AllReduceStart) XLA_UNOP_PATTERN(AllReduceDone) XLA_UNOP_PATTERN(AllToAll) +XLA_UNOP_PATTERN(CollectiveBroadcast) XLA_UNOP_PATTERN(CollectivePermute) XLA_UNOP_PATTERN(CollectivePermuteStart) XLA_UNOP_PATTERN(CollectivePermuteDone) @@ -2564,6 +2695,7 @@ XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) +XLA_UNOP_PATTERN(OptimizationBarrier) XLA_UNOP_PATTERN(Real) XLA_UNOP_PATTERN(Recv) XLA_UNOP_PATTERN(RecvDone) @@ -2579,6 +2711,7 @@ XLA_UNOP_PATTERN(Sqrt) XLA_UNOP_PATTERN(Tan) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) +XLA_UNOP_PATTERN(While) #undef XLA_UNOP_PATTERN // Helpers for binary instructions. diff --git a/xla/service/pattern_matcher_gmock.h b/xla/service/pattern_matcher_gmock.h index 51ac1cac0a40e..e183211d645d5 100644 --- a/xla/service/pattern_matcher_gmock.h +++ b/xla/service/pattern_matcher_gmock.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/pattern_matcher_gmock_test.cc b/xla/service/pattern_matcher_gmock_test.cc index 93c1575c78f4d..81cff291024fe 100644 --- a/xla/service/pattern_matcher_gmock_test.cc +++ b/xla/service/pattern_matcher_gmock_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/pattern_matcher_test.cc b/xla/service/pattern_matcher_test.cc index 4511aa99cd08e..cd020c821b0c0 100644 --- a/xla/service/pattern_matcher_test.cc +++ b/xla/service/pattern_matcher_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1453,5 +1453,35 @@ TEST_F(PatternMatcherTest, TestWithSharding) { "sharding={devices=[1,2,2,1]0,1,2,3}"); } +TEST_F(PatternMatcherTest, TestWithControlDeps) { + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + add = f32[4] add(p0, p1) + mul = f32[4] multiply(p0, p1), control-predecessors={add} + div = f32[4] divide(p0, p1), control-predecessors={mul} + ROOT t = (f32[4], f32[4], f32[4]) tuple(add, mul, div) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto* add = FindInstruction(hlo_module.get(), "add"); + auto* mul = FindInstruction(hlo_module.get(), "mul"); + auto* div = FindInstruction(hlo_module.get(), "div"); + + EXPECT_TRUE(Match(add, m::Op().WithControlDeps({}, {mul}))); + EXPECT_TRUE(Match(mul, m::Op().WithControlDeps({add}, {div}))); + EXPECT_TRUE(Match(div, m::Op().WithControlDeps({mul}, {}))); + EXPECT_FALSE(Match(div, m::Op().WithControlDeps({mul}, {div}))); + EXPECT_DESC_AND_EXPLANATION( + div, m::Op().WithControlDeps({mul}, {div}), + "an HloInstruction with control predecessors {mul} and control " + "successors {div}", + "HloInstruction expected to have control successors {div} but has {}\n" + "in div = f32[4]{0} divide(f32[4]{0} p0, f32[4]{0} p1), " + "control-predecessors={mul}"); +} + } // namespace } // namespace xla diff --git a/xla/service/platform_util.cc b/xla/service/platform_util.cc index 2e588545a4c35..1d7cb0c2088d2 100644 --- a/xla/service/platform_util.cc +++ b/xla/service/platform_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" @@ -63,8 +64,8 @@ std::string CanonicalPlatformName(const std::string& platform_name) { return lowercase_platform_name; } -StatusOr> GetSupportedPlatforms() { - return se::MultiPlatformManager::PlatformsWithFilter( +absl::StatusOr> GetSupportedPlatforms() { + return se::PlatformManager::PlatformsWithFilter( [](const se::Platform* platform) { auto compiler_status = Compiler::GetForPlatform(platform); bool supported = compiler_status.ok(); @@ -79,18 +80,18 @@ StatusOr> GetSupportedPlatforms() { } // namespace -/*static */ StatusOr PlatformUtil::CanonicalPlatformName( +/*static */ absl::StatusOr PlatformUtil::CanonicalPlatformName( const std::string& platform_name) { return xla::CanonicalPlatformName(platform_name); } -/* static */ StatusOr> +/* static */ absl::StatusOr> PlatformUtil::GetSupportedPlatforms() { // Gather all platforms which have an XLA compiler. return xla::GetSupportedPlatforms(); } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ absl::StatusOr PlatformUtil::GetDefaultPlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); se::Platform* platform = nullptr; @@ -121,10 +122,10 @@ PlatformUtil::GetSupportedPlatforms() { platforms_string); } -/*static*/ StatusOr PlatformUtil::GetPlatform( +/*static*/ absl::StatusOr PlatformUtil::GetPlatform( const std::string& platform_name) { TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::MultiPlatformManager::PlatformWithName( + se::PlatformManager::PlatformWithName( xla::CanonicalPlatformName(platform_name))); TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform).status()); return platform; @@ -161,7 +162,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { return true; } -/* static */ StatusOr> +/* static */ absl::StatusOr> PlatformUtil::GetStreamExecutors( se::Platform* platform, const std::optional>& allowed_devices) { @@ -237,7 +238,7 @@ PlatformUtil::GetStreamExecutors( } } if (out.empty()) { - return InternalError("no supported devices found for platform %s", + return Internal("no supported devices found for platform %s", platform->Name()); } return out; diff --git a/xla/service/platform_util.h b/xla/service/platform_util.h index 57695f6449948..0e012ef4b7b8c 100644 --- a/xla/service/platform_util.h +++ b/xla/service/platform_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,25 +34,26 @@ class PlatformUtil { // This is needed to differentiate if for given platform like GPU or CPU // there are multiple implementations. For example, GPU platform may be // cuda(Nvidia) or rocm(AMD) - static StatusOr CanonicalPlatformName( + static absl::StatusOr CanonicalPlatformName( const std::string& platform_name); // Returns the platforms present on the system and supported by XLA. // // Note that, even if a platform is present with zero devices, if we *do* have // compilation support for it, it will be returned in this sequence. - static StatusOr> GetSupportedPlatforms(); + static absl::StatusOr> GetSupportedPlatforms(); // Convenience function which returns the default supported platform for // tests. If exactly one supported platform is present, then this platform is // the default platform. If exactly two platforms are present and one of them // is the interpreter platform, then the other platform is the default // platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static absl::StatusOr GetDefaultPlatform(); // Returns the platform according to the given name. Returns error if there is // no such platform. - static StatusOr GetPlatform(const std::string& platform_name); + static absl::StatusOr GetPlatform( + const std::string& platform_name); // Returns a vector of StreamExecutors for the given platform. // If populated, only the devices in allowed_devices will have @@ -60,7 +61,7 @@ class PlatformUtil { // initialized and returned. // // If the platform has no visible devices, a not-found error is returned. - static StatusOr> GetStreamExecutors( + static absl::StatusOr> GetStreamExecutors( se::Platform* platform, const std::optional>& allowed_devices = std::nullopt); diff --git a/xla/service/profile_guided_latency_estimator.cc b/xla/service/profile_guided_latency_estimator.cc index 106aedfce1cd0..50d250a102a20 100644 --- a/xla/service/profile_guided_latency_estimator.cc +++ b/xla/service/profile_guided_latency_estimator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,10 +37,27 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( } auto it = instr_map_.find(from.GetInstr().name()); + if (it == instr_map_.end() && + (from.GetInstr().opcode() == HloOpcode::kAsyncStart || + from.GetInstr().opcode() == HloOpcode::kAsyncDone)) { + absl::string_view wrapped_inst_name = + from.GetInstr().async_wrapped_instruction()->name(); + VLOG(10) << "PGLE found async wrapped instruction: " << wrapped_inst_name + << " in " << from.GetInstr().name(); + it = instr_map_.find(wrapped_inst_name); + } + if (it == instr_map_.end()) { return latency_estimator_->GetLatencyBetween(from, target); } + auto it2 = it->second.latencies.find(target.GetInstr().name()); + if (it2 == it->second.latencies.end() && + (target.GetInstr().opcode() == HloOpcode::kAsyncStart || + target.GetInstr().opcode() == HloOpcode::kAsyncDone)) { + it2 = it->second.latencies.find( + target.GetInstr().async_wrapped_instruction()->name()); + } if (it2 != it->second.latencies.end()) { VLOG(10) << "PGLE found latency between " << from.GetInstr().name() << " and " << target.GetInstr().name() << " in latency info"; @@ -61,9 +78,8 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::NodeCost( const HloInstruction* instr) const { - const HloOpcode opcode = instr->opcode(); - if (hlo_query::IsAsyncCollectiveStartOp(opcode, /*include_send_recv=*/true) || - hlo_query::IsAsyncCollectiveDoneOp(opcode, /*include_send_recv=*/true)) { + if (hlo_query::IsAsyncCollectiveStartOp(instr, /*include_send_recv=*/true) || + hlo_query::IsAsyncCollectiveDoneOp(instr, /*include_send_recv=*/true)) { static constexpr TimeCost kLowCost = 1.0; return kLowCost; } diff --git a/xla/service/profile_guided_latency_estimator.h b/xla/service/profile_guided_latency_estimator.h index 5d71e85751317..a3b939ce77ec8 100644 --- a/xla/service/profile_guided_latency_estimator.h +++ b/xla/service/profile_guided_latency_estimator.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/profile_guided_latency_estimator_test.cc b/xla/service/profile_guided_latency_estimator_test.cc index 183ce5ce5cd99..74100cbd1c19c 100644 --- a/xla/service/profile_guided_latency_estimator_test.cc +++ b/xla/service/profile_guided_latency_estimator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ SchedulerConfig GetDefaultSchedConfig() { return sched_cfg; } -StatusOr RunScheduler( +absl::StatusOr RunScheduler( HloModule* module, const SchedulerConfig& sched_config, std::unique_ptr latency_estimator = std::make_unique()) { @@ -82,7 +82,7 @@ StatusOr RunScheduler( class LatencyHidingSchedulerTest : public HloTestBase, public ::testing::WithParamInterface { public: - StatusOr> ParseHloText( + absl::StatusOr> ParseHloText( absl::string_view hlo_string) { return ParseAndReturnVerifiedModule(hlo_string, GetModuleConfigForTest()); } @@ -158,4 +158,54 @@ ENTRY entry { INSTANTIATE_TEST_SUITE_P(LatencyHidingSchedulerTest, LatencyHidingSchedulerTest, ::testing::Bool()); +using ProfileGuidedLatencyEstimatorTest = HloTestBase; + +TEST_F(ProfileGuidedLatencyEstimatorTest, + TestProfileGuidedLatencyEstimatorWithAsyncInstruction) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +add.1 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) +} + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[16,64,256]{2,1,0} parameter(1) + reduce-scatter-start = ((f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}), (f32[4,64,256]{2,1,0}, f32[4,64,256]{2,1,0})) reduce-scatter-start(p0, p1), channel_id=1, replica_groups={}, dimensions={0}, to_apply=add.1 + reduce-scatter-done = (f32[4,64,256]{2,1,0}, f32[4,64,256]{2,1,0}) reduce-scatter-done(reduce-scatter-start) + ROOT gte = f32[4,64,256]{2,1,0} get-tuple-element(reduce-scatter-done), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::string profiled_instructions_text_proto = R"pb( + costs { name: "reduce-scatter" cost_us: 120.0 } + )pb"; + ; + tensorflow::profiler::ProfiledInstructionsProto profiled_instructions_proto; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + profiled_instructions_text_proto, &profiled_instructions_proto)); + + auto sched_config = GetDefaultSchedConfig(); + auto latency_estimator = std::make_unique( + sched_config, std::make_unique(), + profiled_instructions_proto); + HloInstruction* rs_start = + FindInstruction(hlo_module.get(), "reduce-scatter-start"); + HloInstruction* rs_done = + FindInstruction(hlo_module.get(), "reduce-scatter-done"); + HloGraphNode rs_start_node = HloGraphNode(rs_start, 0); + HloGraphNode rs_done_node = HloGraphNode(rs_done, 1); + + double latency = + latency_estimator->GetLatencyBetween(rs_start_node, rs_done_node); + EXPECT_EQ(latency, 120.0); +} + } // namespace xla diff --git a/xla/service/qr_expander.cc b/xla/service/qr_expander.cc index 4ed4d9ba0122d..e817b66b61d2c 100644 --- a/xla/service/qr_expander.cc +++ b/xla/service/qr_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -196,7 +196,7 @@ Status House(XlaOp x, XlaOp k, absl::Span batch_dims, // a[j+1:, j] = v[j+1:] // taus[j] = tau // return (a, taus) -StatusOr QrExpander::QrBlock( +absl::StatusOr QrExpander::QrBlock( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -219,8 +219,9 @@ StatusOr QrExpander::QrBlock( std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = [&](XlaOp j, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + auto qr_body_fn = + [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> absl::StatusOr> { auto a = values[0]; auto taus = values[1]; @@ -315,7 +316,7 @@ StatusOr QrExpander::QrBlock( // for i in range(1, n): // t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i]) // return t -StatusOr QrExpander::CompactWYRepresentation( +absl::StatusOr QrExpander::CompactWYRepresentation( PrimitiveType type, absl::Span batch_dims, XlaOp vs, XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision) { XlaBuilder* builder = vs.builder(); @@ -324,8 +325,9 @@ StatusOr QrExpander::CompactWYRepresentation( std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64_t n_index = batch_dims.size() + 1; - auto body_fn = [&](XlaOp j, absl::Span values, - XlaBuilder* builder) -> StatusOr> { + auto body_fn = + [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> absl::StatusOr> { // w has shape [..., m, n] auto t = values[0]; const auto vtv = values[1]; @@ -370,7 +372,7 @@ StatusOr QrExpander::CompactWYRepresentation( // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:]) // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T) // return (q, a) -StatusOr QrExpander::BuildQrDecomposition( +absl::StatusOr QrExpander::BuildQrDecomposition( XlaOp a, int64_t block_size, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -433,7 +435,7 @@ StatusOr QrExpander::BuildQrDecomposition( return Tuple(builder, {a, taus}); } -StatusOr QrExpander::ProductOfElementaryHouseholderReflectors( +absl::StatusOr QrExpander::ProductOfElementaryHouseholderReflectors( XlaOp a, XlaOp taus, int64_t block_size, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); @@ -505,11 +507,14 @@ bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) { kHouseholderProductCustomCallName); } -StatusOr QrExpander::ExpandInstruction( +absl::StatusOr QrExpander::ExpandInstruction( HloInstruction* instruction) { - const std::string name = + std::string name = absl::StrFormat("xla.%s_%s", instruction->custom_call_target(), instruction->operand(0)->shape().ToString()); + if (instruction->custom_call_target() == kHouseholderProductCustomCallName) { + name += "_" + instruction->operand(1)->shape().ToString(); + } HloModule* module = instruction->GetModule(); diff --git a/xla/service/qr_expander.h b/xla/service/qr_expander.h index ab381bc57433d..d4818f644d137 100644 --- a/xla/service/qr_expander.h +++ b/xla/service/qr_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,21 +30,21 @@ class QrExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; - virtual StatusOr QrBlock( + virtual absl::StatusOr QrBlock( XlaOp a, PrecisionConfig::Precision precision); - virtual StatusOr CompactWYRepresentation( + virtual absl::StatusOr CompactWYRepresentation( PrimitiveType type, absl::Span batch_dims, XlaOp vs, XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision); private: - StatusOr BuildQrDecomposition(XlaOp a, int64_t block_size, - PrecisionConfig::Precision precision); + absl::StatusOr BuildQrDecomposition( + XlaOp a, int64_t block_size, PrecisionConfig::Precision precision); - StatusOr ProductOfElementaryHouseholderReflectors( + absl::StatusOr ProductOfElementaryHouseholderReflectors( XlaOp a, XlaOp taus, int64_t block_size, PrecisionConfig::Precision precision); diff --git a/xla/service/real_imag_expander.cc b/xla/service/real_imag_expander.cc index d00c5583f1eea..6f50e250ec114 100644 --- a/xla/service/real_imag_expander.cc +++ b/xla/service/real_imag_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ bool RealImagExpander::InstructionMatchesPattern(HloInstruction* inst) { !ShapeUtil::ElementIsComplex(inst->operand(0)->shape()); } -StatusOr RealImagExpander::ExpandInstruction( +absl::StatusOr RealImagExpander::ExpandInstruction( HloInstruction* inst) { if (inst->opcode() == HloOpcode::kReal) { // real with a non-complex input is just a copy. diff --git a/xla/service/real_imag_expander.h b/xla/service/real_imag_expander.h index e5337197e44de..2c2bd9e08eb08 100644 --- a/xla/service/real_imag_expander.h +++ b/xla/service/real_imag_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,8 @@ class RealImagExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* inst) override; - StatusOr ExpandInstruction(HloInstruction* inst) override; + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; }; } // namespace xla diff --git a/xla/service/real_imag_expander_test.cc b/xla/service/real_imag_expander_test.cc index 8e3e25ddc2dde..a7349a64011d6 100644 --- a/xla/service/real_imag_expander_test.cc +++ b/xla/service/real_imag_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/reduce_decomposer.cc b/xla/service/reduce_decomposer.cc index e67f6ff730b8d..0e43381063f1b 100644 --- a/xla/service/reduce_decomposer.cc +++ b/xla/service/reduce_decomposer.cc @@ -1,5 +1,5 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -123,7 +123,7 @@ class ReduceDecomposerVisitor : public DfsHloRewriteVisitor { } private: - StatusOr GetOutput(HloInstruction* instr, int idx) { + absl::StatusOr GetOutput(HloInstruction* instr, int idx) { if (instr->shape().IsTuple()) { return MakeGetTupleElementHlo(instr, idx); } else { @@ -147,7 +147,7 @@ class ReduceDecomposerVisitor : public DfsHloRewriteVisitor { } // namespace -StatusOr ReduceDecomposer::Run( +absl::StatusOr ReduceDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { TF_ASSIGN_OR_RETURN(bool changed1, diff --git a/xla/service/reduce_decomposer.h b/xla/service/reduce_decomposer.h index 1486a3249deee..0907527cc7244 100644 --- a/xla/service/reduce_decomposer.h +++ b/xla/service/reduce_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,7 @@ class ReduceDecomposer : public HloModulePass { absl::string_view name() const override { return "reduce-decomposer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/reduce_decomposer_test.cc b/xla/service/reduce_decomposer_test.cc index 1cdf7341363d1..54d290ec9e441 100644 --- a/xla/service/reduce_decomposer_test.cc +++ b/xla/service/reduce_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/reduce_scatter_combiner.cc b/xla/service/reduce_scatter_combiner.cc index 19e454486f104..7bef0b9f9fdf7 100644 --- a/xla/service/reduce_scatter_combiner.cc +++ b/xla/service/reduce_scatter_combiner.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -48,22 +49,26 @@ limitations under the License. namespace xla { namespace { -int64_t FindMostFrequentGatherDim( +// Returns the most frequent scatter dim if it can be a valid scatter dim +// for all shapes involved, else returns 0. +int64_t FindMostFrequentScatterDim( absl::Span to_combine) { assert(!to_combine.empty()); // Count frequencies. + int64_t min_rank = std::numeric_limits::max(); std::vector frequency; for (const HloInstruction* it : to_combine) { int64_t dim = Cast(it)->scatter_dimension(); frequency.resize(std::max(dim + 1, static_cast(frequency.size())), 0); frequency[dim]++; + min_rank = std::min(min_rank, it->shape().rank()); } int64_t most_frequent_dim = std::distance( frequency.begin(), std::max_element(frequency.begin(), frequency.end())); - return most_frequent_dim; + return most_frequent_dim < min_rank ? most_frequent_dim : 0; } using ReduceScatterKey = @@ -90,8 +95,8 @@ Status CombineReduceScatters(absl::Span to_combine) { std::vector>> operand_permutations; std::vector output_shapes; - // Find the most frequent all-gather dimension. - int64_t most_frequent_dim = FindMostFrequentGatherDim(to_combine); + // Find the most frequent reduce-scatter dimension. + int64_t most_frequent_dim = FindMostFrequentScatterDim(to_combine); VLOG(1) << "Combining set"; for (HloInstruction* hlo : to_combine) { @@ -173,7 +178,7 @@ ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes, combine_threshold_count_(combine_threshold_count), combine_by_dim_(combine_by_dim) {} -StatusOr ReduceScatterCombiner::Run( +absl::StatusOr ReduceScatterCombiner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(1) << "Running ReduceScatterCombiner with threshold of " diff --git a/xla/service/reduce_scatter_combiner.h b/xla/service/reduce_scatter_combiner.h index 97bb562602416..6f1647591c07e 100644 --- a/xla/service/reduce_scatter_combiner.h +++ b/xla/service/reduce_scatter_combiner.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ class ReduceScatterCombiner : public HloModulePass { absl::string_view name() const override { return "reduce-scatter-combiner"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/reduce_scatter_combiner_test.cc b/xla/service/reduce_scatter_combiner_test.cc index 9427ca4fce0ca..9610e28e4a53a 100644 --- a/xla/service/reduce_scatter_combiner_test.cc +++ b/xla/service/reduce_scatter_combiner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ constexpr int64_t kMaxByteCount = 10 * 1024 * 1024; class ReduceScatterCombinerTest : public HloTestBase { public: - StatusOr> RunPass( + absl::StatusOr> RunPass( absl::string_view hlo_module, bool expect_change, int64_t byte_threshold = kMaxByteCount, int64_t count_threshold = kMaxCombineCount, bool combine_by_dim = true) { @@ -54,7 +54,7 @@ class ReduceScatterCombinerTest : public HloTestBase { << ReduceScatterCount(module.get()) << " reduce-scatter ops"; EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t ReduceScatterCount(HloModule *module) { @@ -153,6 +153,35 @@ ENTRY main { EXPECT_EQ(ReduceScatterCount(module.get()), 1); } +TEST_F(ReduceScatterCombinerTest, DifferentDimensionsAndRanks) { + absl::string_view hlo_string = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +ENTRY main { + p0 = f32[8, 8] parameter(0) + p1 = f32[8] parameter(1) + rs0 = f32[8, 4] reduce-scatter(p0), replica_groups={{0,1}}, dimensions={1}, + to_apply=sum + rs1 = f32[8, 4] reduce-scatter(p0), replica_groups={{0,1}}, dimensions={1}, + to_apply=sum + rs2 = f32[4] reduce-scatter(p1), replica_groups={{0,1}}, dimensions={0}, + to_apply=sum + ROOT t = (f32[8, 4], f32[8, 4], f32[4]) + tuple(rs0, rs1, rs2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunPass(hlo_string, /*expect_change=*/true, kMaxByteCount, + kMaxCombineCount, /*combine_by_dim=*/false)); + EXPECT_EQ(ReduceScatterCount(module.get()), 1); +} + // Test that dependent reduce-scatter do not get combined. TEST_F(ReduceScatterCombinerTest, DependentReduceScatter) { absl::string_view hlo_string = R"( diff --git a/xla/service/reduce_scatter_decomposer.cc b/xla/service/reduce_scatter_decomposer.cc index 59366639b84e2..da2fed224a53f 100644 --- a/xla/service/reduce_scatter_decomposer.cc +++ b/xla/service/reduce_scatter_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ limitations under the License. namespace xla { -StatusOr ReduceScatterDecomposer::Run( +absl::StatusOr ReduceScatterDecomposer::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { bool changed = false; @@ -53,7 +53,11 @@ StatusOr ReduceScatterDecomposer::Run( if (rs->channel_id()) { channel_id = next_channel_id++; } + if (should_decompose_ && !should_decompose_(rs)) { + continue; + } + VLOG(2) << "Decompose: " << rs->ToString(); // Create an all-reduce HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds( rs->to_apply()->Clone(), /*is_entry=*/false); diff --git a/xla/service/reduce_scatter_decomposer.h b/xla/service/reduce_scatter_decomposer.h index 2717baa31a18a..1ee1f603c09f2 100644 --- a/xla/service/reduce_scatter_decomposer.h +++ b/xla/service/reduce_scatter_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,17 +29,19 @@ namespace xla { class ReduceScatterDecomposer : public HloModulePass { public: explicit ReduceScatterDecomposer( - std::function update_layout = nullptr) - : update_layout_(update_layout) {} + std::function update_layout = nullptr, + std::function should_decompose = nullptr) + : update_layout_(update_layout), should_decompose_(should_decompose) {} absl::string_view name() const override { return "reduce-scatter-decomposer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; std::function update_layout_; + std::function should_decompose_; }; } // namespace xla diff --git a/xla/service/reduce_scatter_decomposer_test.cc b/xla/service/reduce_scatter_decomposer_test.cc index a2ffb5071c4d7..d7f8360fbdc91 100644 --- a/xla/service/reduce_scatter_decomposer_test.cc +++ b/xla/service/reduce_scatter_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,13 +41,18 @@ class ReduceScatterDecomposerTest : public HloTestBase { absl::string_view hlo_module, PassAction action, CollectiveOpGroupMode mode = CollectiveOpGroupMode::kCrossReplica, int64_t shard_size = 0, int64_t shard_dimension = 0, - int64_t replica_count = 2) { + int64_t replica_count = 2, + std::function should_decompose = + [](const HloInstruction *) { return true; }) { const int64_t partition_count = 2; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(hlo_module, replica_count, partition_count)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - ReduceScatterDecomposer().Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ReduceScatterDecomposer(/*update_layout=*/nullptr, + /*should_decompose=*/should_decompose) + .Run(module.get())); if (action == PassAction::kNoChange) { ASSERT_FALSE(changed); return; @@ -222,5 +227,26 @@ ENTRY main { RunPass(hlo_string, PassAction::kNoChange); } +TEST_F(ReduceScatterDecomposerTest, NoChangeWithShouldDecompose) { + absl::string_view hlo_string = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +ENTRY main { + p0 = f32[4, 8] parameter(0) + ROOT rs = f32[4, 4] reduce-scatter(p0), replica_groups={{0,1}, {2,3}}, channel_id=1, dimensions={1}, to_apply=sum, use_global_device_ids=true +} +)"; + RunPass(hlo_string, PassAction::kNoChange, + CollectiveOpGroupMode::kCrossReplica, + /*shard_size=*/0, /*shard_dimension=*/0, + /*replica_count=*/2, [](const HloInstruction *) { return false; }); +} + } // namespace } // namespace xla diff --git a/xla/service/reduce_scatter_reassociate.cc b/xla/service/reduce_scatter_reassociate.cc index b63e6b5de9c85..88b03f0f3e196 100644 --- a/xla/service/reduce_scatter_reassociate.cc +++ b/xla/service/reduce_scatter_reassociate.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ bool AreCompatible(const HloReduceScatterInstruction *rs0, } // namespace -StatusOr ReduceScatterReassociate::Run( +absl::StatusOr ReduceScatterReassociate::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { if (hlo_query::ContainsLayoutConstrainedCollective( diff --git a/xla/service/reduce_scatter_reassociate.h b/xla/service/reduce_scatter_reassociate.h index bc7c92532a417..ebdaa3c43250b 100644 --- a/xla/service/reduce_scatter_reassociate.h +++ b/xla/service/reduce_scatter_reassociate.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ class ReduceScatterReassociate : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/reduce_scatter_reassociate_test.cc b/xla/service/reduce_scatter_reassociate_test.cc index bac680be01cde..ed68774a8496f 100644 --- a/xla/service/reduce_scatter_reassociate_test.cc +++ b/xla/service/reduce_scatter_reassociate_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,15 +27,15 @@ namespace m = xla::testing::opcode_matchers; class ReduceScatterReassociateTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module, - bool expect_change) { + absl::StatusOr> RunPass( + absl::string_view hlo_module, bool expect_change) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); auto changed = ReduceScatterReassociate().Run(module.get()); if (!changed.ok()) { return changed.status(); } EXPECT_EQ(changed.value(), expect_change); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } size_t ReduceScatterCount(std::unique_ptr& module) { diff --git a/xla/service/rendezvous.cc b/xla/service/rendezvous.cc index 663ec1445bed0..ff37bebca30ac 100644 --- a/xla/service/rendezvous.cc +++ b/xla/service/rendezvous.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,33 +15,103 @@ limitations under the License. #include "xla/service/rendezvous.h" -#include "absl/synchronization/mutex.h" +#include +#include +#include +#include +#include + +#include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "tsl/platform/logging.h" namespace xla { +namespace internal { -void AwaitAndLogIfStuck(absl::Mutex& mutex, const absl::Condition& condition, - absl::Duration warn_stuck_timeout, +void AwaitAndLogIfStuck(absl::Notification& ready, std::string_view name, + size_t num_threads, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout) { - if (mutex.AwaitWithTimeout(condition, warn_stuck_timeout)) { + if (ready.WaitForNotificationWithTimeout(warn_stuck_timeout)) { return; } - LOG(ERROR) << "This thread has been waiting for " + LOG(ERROR) << "This thread has been waiting for `" << name << "` for " << absl::ToInt64Seconds(warn_stuck_timeout) - << " seconds and may be stuck:"; + << " seconds and may be stuck. Expected " << num_threads + << " threads to join the rendezvous, but not all of them arrived" + << " on time."; - if (mutex.AwaitWithTimeout(condition, terminate_timeout)) { + if (ready.WaitForNotificationWithTimeout(terminate_timeout)) { LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. " "Perhaps the timeout is too short."; return; } - LOG(ERROR) - << "Termination timeout of " << absl::ToInt64Seconds(terminate_timeout) - << " seconds exceeded. Exiting to ensure a consistent program state."; + LOG(ERROR) << "Termination timeout for `" << name << "` of " + << absl::ToInt64Seconds(terminate_timeout) + << " seconds exceeded. Exiting to ensure a consistent program" + << " state. Expected " << num_threads + << " threads to join the rendezvous, but not all of them arrived" + << " on time."; std::exit(42); } +} // namespace internal + +namespace { +inline constexpr int32_t kPending = 0; +inline constexpr int32_t kCompleted = std::numeric_limits::max(); +} // namespace + +RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {} + +RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous( + RendezvousSingleFlag* flag) + : flag_(flag) {} + +RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { + if (flag_ == nullptr) return; + + // Reload state and use CAS to decide if we are the one who + // should mark rendezvous flag completed. + int32_t state = flag_->state_.load(); + + CHECK(state != kPending && state != kCompleted) // NOLINT + << "rendezvous can't be in pending or completed state"; + + // Exit the critical section and maybe mark rendezvous as completed. + while (!flag_->state_.compare_exchange_weak( + state, state == 1 ? kCompleted : state - 1)) { + // Check state after CAS failure: while we are in this function no one + // should complete rendezvous without us or switch it back to pending. + CHECK(state != kPending && state != kCompleted); // NOLINT + } +} + +RendezvousSingleFlag::InFlightRendezvous::operator bool() const { + return flag_ != nullptr; +} + +RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { + // If `state_` is `kCompleted` it means that we have at least one completed + // rendezvous for this flag and can skip it. + if (state_.load() == kCompleted) return InFlightRendezvous(nullptr); + + // Try to increment a state in a CAS loop to signal all other participants + // that we joined an in-flight rendezvous. + int32_t state = state_.load(); + while (state != kCompleted && + !state_.compare_exchange_weak(state, state + 1)) { + } + + // Someone else completed the rendezvous and we don't need to join. + if (state == kCompleted) return InFlightRendezvous(nullptr); + + return InFlightRendezvous(this); +} + +bool RendezvousSingleFlag::IsCompleted() const { + return state_.load() == kCompleted; +} + } // namespace xla diff --git a/xla/service/rendezvous.h b/xla/service/rendezvous.h index a22b491ead581..bee54f44fc9bd 100644 --- a/xla/service/rendezvous.h +++ b/xla/service/rendezvous.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,124 +16,345 @@ limitations under the License. #ifndef XLA_SERVICE_RENDEZVOUS_H_ #define XLA_SERVICE_RENDEZVOUS_H_ -#include +#include +#include +#include #include +#include +#include +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "tsl/platform/logging.h" namespace xla { -template -class ThreadSafeMap { +//===----------------------------------------------------------------------===// +// A rendezvous for a group of threads. +//===----------------------------------------------------------------------===// + +// A little bit of compile time metaprogramming to simplify the rendezvous +// return type for functions returning `absl::StatusOr`. If we detect that +// rendezvous callback returns `absl::StatusOr` we swap the order of a shared +// pointer and status container. + +template +struct RendezvousResult { + using Type = std::shared_ptr; + + static Type Wrap(R result) { return std::make_shared(std::move(result)); } + + static Type Empty() { return std::shared_ptr(); } +}; + +template +struct RendezvousResult> { + using Type = absl::StatusOr>; + + static Type Wrap(absl::StatusOr result) { + if (!result.ok()) return result.status(); + return std::make_shared(std::move(*result)); + } + + static Type Empty() { return {std::shared_ptr()}; } +}; + +template +using RendezvousResultType = typename RendezvousResult::Type; + +// The group of threads identifies itself with a key that must be unique to +// the the group. When all threads have arrived at the rendezvous, one thread +// executes the given function with the values supplied by each thread, and +// all threads receive the result. Rendezvous must have a human readable name to +// make easy to debug stuck and timed out attempts. +template +RendezvousResultType RendezvousSingle( + std::string_view name, const K& key, const V& value, size_t num_threads, + Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +// A rendezvous for a group of threads that do not have any value arguments. +template +RendezvousResultType RendezvousSingle( + std::string_view name, const K& key, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +// A rendezvous for a group of threads that do not have any computation to run +// and simply acts as a barrier for a group of thread. +template +void RendezvousSingle( + std::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +// An `std::once_flag`-like primitive for executing RendezvousSingle operations. +// +// RendezvousSingleFlag guarantees that all or none participants in a rendezvous +// join the rendezvous process and once rendezvous is completed flag marked as +// `completed` and all further rendezvous using this flag will be skipped. It +// has a weaker than exactly-once guarantee and multiple racing rendezvous can +// execute in parallel, and the last completed rendezvous will switch flag to +// `completed` state. +// +// In XLA rendezvous are rare and used to guard costly shared state +// initialization, so in practice we do not expect to see many racing rendezvous +// and prefer simpler implementation with weaker guarantees. +// +// See: https://en.cppreference.com/w/cpp/thread/once_flag +class RendezvousSingleFlag { + public: + RendezvousSingleFlag(); + + RendezvousSingleFlag(const RendezvousSingleFlag&) = delete; + RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete; + + // RAII wrapper to exit from in-flight rendezvous when destructed. + class InFlightRendezvous { + public: + explicit InFlightRendezvous(RendezvousSingleFlag* flag); + ~InFlightRendezvous(); + + InFlightRendezvous(const InFlightRendezvous&) = delete; + InFlightRendezvous& operator=(const InFlightRendezvous&) = delete; + + operator bool() const; // NOLINT + + private: + RendezvousSingleFlag* flag_; + }; + + // Returns InFlightRendezvous convertible to `true` if the caller should join + // the rendezvous process. If result conversion to bool is `false` it means + // that the rendezvous is already completed. + InFlightRendezvous TryJoin(); + + bool IsCompleted() const; + + private: + friend class InFlightRendezvous; + + std::atomic state_; +}; + +// A rendezvous for a group of threads that will be executed only if the flag is +// not in `completed` state and will switch it to `completed` after finishing a +// rendezvous. If rendezvous will not be executed it will return empty shared +// pointer result. +template +RendezvousResultType RendezvousSingle( + RendezvousSingleFlag& flag, std::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +// A rendezvous for a group of threads that will be executed only if the flag is +// not in `completed` state and will switch it to `completed` after finishing a +// rendezvous. +template +void RendezvousSingle( + RendezvousSingleFlag& flag, std::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); + +//===----------------------------------------------------------------------===// +// Internal implementation details. +//===----------------------------------------------------------------------===// + +namespace internal { + +// A state for a single round of rendezvous. We expect exactly `num_treads` to +// arrive to a rendezvous and update corresponding slots in `values`. We +// pre-allocate storage for values so at run time each participant doesn't have +// to grab a lock and can simple write to the destination storage. +template +struct RendezvousState { + explicit RendezvousState(size_t num_threads) + : ack(0), rel(0), values(num_threads, nullptr), result(nullptr) {} + + std::atomic ack; + std::atomic rel; + std::vector values; + + absl::Notification ready; // signals availability of `result` + RendezvousResultType result; +}; + +// A container for in-progress rendezvous. +// +// Rendezvous state ownership: +// +// (1) When rendezvous participant initiates a rendezvous with a particular key +// we create a new state for it, keep it in a map for tracking and return a +// shared pointer to the caller. +// +// (2) When rendezvous participant joins in-progress rendezvous it gets back +// a shared pointer that is copied from a tracking map. +// +// (3) When the last rendezvous participant computes the result it completes the +// rendezvous and removes a shared pointer to a state. Remaining shared +// pointers destructed when all participants are notified. +// +// This process guarantees that all completed rendezvous are removed from a map +// and a map has records only for rendezvous in progress. +template +class RendezvousMap { public: - V& operator[](const K& key) { + using State = RendezvousState; + + std::shared_ptr Join(const K& key, size_t num_threads) { absl::MutexLock lock(&mutex_); - std::unique_ptr& value = map_[key]; - if (value == nullptr) value = std::make_unique(); - return *value; + std::shared_ptr& state = state_[key]; + + // Join an in-progress rendezvous. + if (state) return state; + + // Join a newly created rendezvous. + return state = std::make_shared(num_threads); } - void ForEachValue(absl::FunctionRef fn) { - absl::MutexLock lock(&mutex_); - for (const auto& [_, value] : map_) fn(*value); + void Complete(const K& key, RendezvousResultType result) { + std::shared_ptr state = [&] { + absl::MutexLock lock(&mutex_); + + // Extract state from the map so we can immediately start a new round of + // rendezvous with the same key. A state for previous rendezvous will be + // destructed with the last copy of a shared pointer. + std::shared_ptr state = state_.extract(key).mapped(); + + // Check that we have have exactly the number of participants we expected: + // +1 reference for all participants and a +1 reference we extracted. + CHECK_EQ(state.use_count(), 1 + state->values.size()); // NOLINT + + return state; + }(); + + // Notify awaiting participants without holding a lock. + state->result = std::move(result); + state->ready.Notify(); } private: absl::Mutex mutex_; - absl::flat_hash_map> map_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> state_ ABSL_GUARDED_BY(mutex_); }; -void AwaitAndLogIfStuck(absl::Mutex& mutex, const absl::Condition& condition, - absl::Duration warn_stuck_timeout, +void AwaitAndLogIfStuck(absl::Notification& ready, std::string_view name, + size_t num_threads, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout); +} // namespace internal + +//===----------------------------------------------------------------------===// +// Rendezvous implemenetation. +//===----------------------------------------------------------------------===// + +template +RendezvousResultType RendezvousSingle(std::string_view name, const K& key, + const V& value, size_t num_threads, + Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + // Check that `fn` is callable with a span of values and returns `R`. + static_assert(std::is_invocable_r_v>, + "invalid rendezvous function signature"); -// A rendezvous for a group of threads. -// -// The group of threads identifies itself with a key that must be unique to the -// the group. When all threads have arrived at the rendezvous, one thread -// executes the given function with the values supplied by each thread, and all -// threads receive the result. -// TODO(cjfj): Replace XLA rendezvous code with this simpler implementation. -template -std::shared_ptr RendezvousSingle( - const K& key, const V& value, size_t num_threads, - absl::FunctionRef)> fn, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()) { // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). - if (num_threads == 1) return std::make_shared(fn({&value})); + if (num_threads == 1) { + const V* ptr = &value; + return RendezvousResult::Wrap(fn(absl::MakeSpan(&ptr, 1))); + } - struct State { - absl::Mutex mutex; - std::vector values ABSL_GUARDED_BY(mutex); - std::shared_ptr result ABSL_GUARDED_BY(mutex); - }; + using State = internal::RendezvousState; + static auto& rendezvous = *new internal::RendezvousMap; + std::shared_ptr state = rendezvous.Join(key, num_threads); - static auto& states = *new ThreadSafeMap; - State& state = states[key]; + // If we got an id larger than `num_threads` it means that we have multiple + // rendezvous sharing the same key running concurrently. + int64_t id = state->ack.fetch_add(1); + CHECK_LT(id, num_threads) // NOLINT + << "Id can't be larger than the number of participating threads" + << "; id=" << id << "; num_threads=" << num_threads; - absl::MutexLock lock(&state.mutex); - state.values.push_back(&value); + // std::vector::operator[] creates data races, so we rely on data pointer + // here and when we create an absl::Span below. + *(state->values.data() + id) = &value; - std::shared_ptr result; - if (state.values.size() == num_threads) { - // Last thread to arrive executes the function. - CHECK(state.result == nullptr); - result = std::make_shared(fn(state.values)); - state.result = result; - state.values.clear(); + // Use a second atomic to safely publish values without data races. + if constexpr (!std::is_same_v) { + id = state->rel.fetch_add(1); + } + + if (id < num_threads - 1) { + // Threads arriving before the last one wait for a result to be computed by + // the last joining thread. + internal::AwaitAndLogIfStuck(state->ready, name, num_threads, + warn_stuck_timeout, terminate_timeout); } else { - absl::Condition result_ready( - +[](std::shared_ptr* ptr) { return ptr->get() != nullptr; }, - &state.result); - AwaitAndLogIfStuck(state.mutex, result_ready, warn_stuck_timeout, - terminate_timeout); - - // There is one use of the result in the shared state, plus one use for each - // thread that has already retrieved the result. - if (state.result.use_count() < num_threads) { - result = state.result; - } else { - // Last thread to retrieve the result takes the result from the state, - // allowing the other threads to exit the function. - return std::move(state.result); - } + // Last thread to arrive executes the function and completes rendezvous by + // making result available to all participants. All other participants will + // be notified via `state->ready` notification when result is ready, and we + // rely on the notification to create a memory barrier that makes access to + // `state->result` safe without any extra synchronization. + absl::Span values(state->values.data(), num_threads); + rendezvous.Complete(key, RendezvousResult::Wrap(fn(values))); } - // Wait for all threads to have retrieved the result. Without this, a thread - // could duplicate or delete its copy of the result, invalidating the use - // count logic above. - absl::Condition result_taken( - +[](std::shared_ptr* ptr) { return ptr->get() == nullptr; }, - &state.result); - AwaitAndLogIfStuck(state.mutex, result_taken, warn_stuck_timeout, - terminate_timeout); - return result; + return state->result; } -// A rendezvous for a group of threads. -// -// The group of threads identifies itself with a key that must be unique to the -// the group. When all threads have arrived at the rendezvous, one thread -// executes the given function and all threads receive the result. -// TODO(cjfj): Replace XLA rendezvous code with this simpler implementation. -template -std::shared_ptr RendezvousSingle( - const K& key, size_t num_threads, absl::FunctionRef fn, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()) { - // Pass an arbitrary value that is ignored. - return RendezvousSingle( - key, 0, num_threads, [fn](absl::Span) { return fn(); }, +template +RendezvousResultType RendezvousSingle(std::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + return RendezvousSingle( + name, key, std::nullopt, num_threads, [fn](auto) { return fn(); }, warn_stuck_timeout, terminate_timeout); } +template +void RendezvousSingle(std::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + RendezvousSingle( + name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; }, + warn_stuck_timeout, terminate_timeout); +} + +template +RendezvousResultType RendezvousSingle(RendezvousSingleFlag& flag, + std::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + if (auto in_flight_rendezvous = flag.TryJoin()) { + return RendezvousSingle(name, key, num_threads, std::move(fn), + warn_stuck_timeout, terminate_timeout); + } else { + return RendezvousResult::Empty(); + } +} + +template +void RendezvousSingle(RendezvousSingleFlag& flag, std::string_view name, + const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + if (auto in_flight_rendezvous = flag.TryJoin()) { + RendezvousSingle(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); + } +} + } // namespace xla #endif // XLA_SERVICE_RENDEZVOUS_H_ diff --git a/xla/service/rendezvous_test.cc b/xla/service/rendezvous_test.cc new file mode 100644 index 0000000000000..f7fda89bdccdd --- /dev/null +++ b/xla/service/rendezvous_test.cc @@ -0,0 +1,282 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/rendezvous.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "tsl/platform/env.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace { + +absl::Duration Timeout() { return absl::Seconds(10); } +absl::Duration Terminate() { return absl::Seconds(10); } + +tsl::thread::ThreadPool CreateThreadPool(int32_t size) { + return tsl::thread::ThreadPool(tsl::Env::Default(), "rendezvous_test", size); +} + +TEST(RendezvousTest, OneParticipant) { + auto result = + RendezvousSingle("rendezvous_test", 0, 1, [] { return 42; }); + ASSERT_EQ(*result, 42); +} + +TEST(RendezvousTest, TwoParticipants) { + absl::BlockingCounter counter(2); + std::vector> results(2); + + auto task = [&](int32_t id) { + return [&, id] { + results[id] = + RendezvousSingle("rendezvous_test", 0, 2, [] { return 42; }); + counter.DecrementCount(); + }; + }; + + auto thread_pool = CreateThreadPool(2); + thread_pool.Schedule(task(0)); + thread_pool.Schedule(task(1)); + counter.Wait(); + + ASSERT_EQ(results.size(), 2); + ASSERT_EQ(*results[0], 42); + ASSERT_EQ(*results[1], 42); +} + +TEST(RendezvousTest, TwoParticipantsWithValues) { + absl::BlockingCounter counter(2); + std::vector> results(2); + + auto accumulate = [](absl::Span values) { + int32_t result = 0; + for (const int32_t* value : values) result += *value; + return result; + }; + + auto task = [&](int32_t id) { + return [&, id] { + results[id] = + RendezvousSingle("rendezvous_test", 0, id, 2, accumulate); + counter.DecrementCount(); + }; + }; + + auto thread_pool = CreateThreadPool(2); + thread_pool.Schedule(task(0)); + thread_pool.Schedule(task(1)); + counter.Wait(); + + ASSERT_EQ(results.size(), 2); + ASSERT_EQ(*results[0], 1); + ASSERT_EQ(*results[1], 1); +} + +TEST(RendezvousTest, RepeatRendezvous) { + auto thread_pool = CreateThreadPool(2); + + for (int32_t i = 0; i < 10; ++i) { + absl::BlockingCounter counter(2); + + auto task = [&] { + RendezvousSingle("rendezvous_test", i, 2, [] { return 42; }); + counter.DecrementCount(); + }; + + thread_pool.Schedule(task); + thread_pool.Schedule(task); + counter.Wait(); + } +} + +TEST(RendezvousTest, ReturningStatusOr) { + absl::BlockingCounter counter(2); + std::vector>> results(2); + + auto task = [&](int32_t id) { + return [&, id] { + results[id] = RendezvousSingle>( + "rendezvous_test", 0, 2, [] { return 42; }); + counter.DecrementCount(); + }; + }; + + auto thread_pool = CreateThreadPool(2); + thread_pool.Schedule(task(0)); + thread_pool.Schedule(task(1)); + counter.Wait(); + + ASSERT_EQ(results.size(), 2); + ASSERT_EQ(**results[0], 42); + ASSERT_EQ(**results[1], 42); +} + +TEST(RendezvousTest, RendezvousSingleFlag) { + RendezvousSingleFlag flag; + + auto thread_pool = CreateThreadPool(2); + int32_t num_executed = 0; + + absl::BlockingCounter round_0(2); + absl::BlockingCounter round_1(2); + + auto task = [&](absl::BlockingCounter& counter) { + return [&] { + RendezvousSingle( + flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, + Timeout(), Terminate()); + counter.DecrementCount(); + }; + }; + + // Execute rendezvous a first time. + thread_pool.Schedule(task(round_0)); + thread_pool.Schedule(task(round_0)); + round_0.Wait(); + + ASSERT_EQ(num_executed, 1); + + // Execute rendezvous a second time. + thread_pool.Schedule(task(round_1)); + thread_pool.Schedule(task(round_1)); + round_1.Wait(); + + // Check that we did not execute it second time. + ASSERT_EQ(num_executed, 1); +} + +TEST(RendezvousTest, RendezvousSingleFlagRace) { + RendezvousSingleFlag flag; + + static constexpr int32_t kNumRendezvous = 16; + static constexpr int32_t kNumThreads = 8; + + auto thread_pool = CreateThreadPool(kNumRendezvous * kNumThreads); + + auto task = [&](int32_t key) { + return [&, key] { + RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); + }; + }; + + for (int32_t key = 0; key < kNumRendezvous; ++key) { + for (int32_t thread = 0; thread < kNumThreads; ++thread) { + thread_pool.Schedule(task(key)); + } + } +} + +TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { + RendezvousSingleFlag flag; + + static constexpr int32_t kNumRendezvous = 16; + static constexpr int32_t kNumThreads = 8; + + auto thread_pool = CreateThreadPool(kNumRendezvous * kNumThreads); + + // We use barriers and notifications to make sure all 128 threads start + // rendezvous at the same time to detect potential deadlocks and data races. + absl::BlockingCounter participants_ready(kNumRendezvous * kNumThreads); + absl::Notification participants_notification; + absl::BlockingCounter participants_done(kNumRendezvous * kNumThreads); + + auto task = [&](int32_t key) { + return [&, key] { + participants_ready.DecrementCount(); + participants_notification.WaitForNotification(); + RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); + participants_done.DecrementCount(); + }; + }; + + for (int32_t key = 0; key < kNumRendezvous; ++key) { + for (int32_t thread = 0; thread < kNumThreads; ++thread) { + thread_pool.Schedule(task(key)); + } + } + + participants_notification.Notify(); + participants_ready.Wait(); + participants_done.Wait(); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks below +//===----------------------------------------------------------------------===// + +static void BM_Rendezvous(benchmark::State& state) { + int64_t num_threads = state.range(0); + auto thread_pool = CreateThreadPool(num_threads); + + for (auto _ : state) { + absl::BlockingCounter counter(num_threads); + for (int64_t i = 0; i < num_threads; ++i) { + thread_pool.Schedule([&] { + RendezvousSingle("rendezvous_test", 0, num_threads, + [] { return 42; }); + counter.DecrementCount(); + }); + } + counter.Wait(); + } +} + +static void BM_RendezvousWithValues(benchmark::State& state) { + int64_t num_threads = state.range(0); + auto thread_pool = CreateThreadPool(num_threads); + + for (auto _ : state) { + absl::BlockingCounter counter(num_threads); + for (int64_t i = 0; i < num_threads; ++i) { + thread_pool.Schedule([&] { + int32_t value = i; + RendezvousSingle("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; }); + counter.DecrementCount(); + }); + } + counter.Wait(); + } +} + +BENCHMARK(BM_Rendezvous) + ->MeasureProcessCPUTime() + ->Arg(2) + ->Arg(4) + ->Arg(8) + ->Arg(16); + +BENCHMARK(BM_RendezvousWithValues) + ->MeasureProcessCPUTime() + ->Arg(2) + ->Arg(4) + ->Arg(8) + ->Arg(16); + +} // namespace +} // namespace xla diff --git a/xla/service/reshape_decomposer.cc b/xla/service/reshape_decomposer.cc index ebd4bab886db9..2e3ab87569abc 100644 --- a/xla/service/reshape_decomposer.cc +++ b/xla/service/reshape_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ class ReshapeDecomposerVisitor : public DfsHloRewriteVisitor { } // namespace -StatusOr ReshapeDecomposer::Run( +absl::StatusOr ReshapeDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return ReshapeDecomposerVisitor{}.RunOnModule(module, execution_threads); diff --git a/xla/service/reshape_decomposer.h b/xla/service/reshape_decomposer.h index 4b3a8bf9d6aad..2d9cb5fee68b5 100644 --- a/xla/service/reshape_decomposer.h +++ b/xla/service/reshape_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class ReshapeDecomposer : public HloModulePass { absl::string_view name() const override { return "reshape-decomposer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/reshape_decomposer_test.cc b/xla/service/reshape_decomposer_test.cc index 2fa5c7fe16018..94d135e8b15b0 100644 --- a/xla/service/reshape_decomposer_test.cc +++ b/xla/service/reshape_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/reshape_mover.cc b/xla/service/reshape_mover.cc index 5463b77e4bf50..8040a6eb544f2 100644 --- a/xla/service/reshape_mover.cc +++ b/xla/service/reshape_mover.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -226,7 +226,7 @@ bool ReshapeMover::IsReshapeMoveCandidate(HloInstruction* instruction) { // This will often create redundant operations that we expect to be eliminated // by algsimp. For example, if we have an operand rearrange(x), this will // produce rearrange'(rearrange(x)), which can be simplified to x. -StatusOr ReshapeMover::ApplyInverseRearrange( +absl::StatusOr ReshapeMover::ApplyInverseRearrange( const HloInstruction* rearrange, HloInstruction* operand) { switch (rearrange->opcode()) { case HloOpcode::kReshape: { @@ -255,7 +255,7 @@ StatusOr ReshapeMover::ApplyInverseRearrange( // Actually performs the reshape-move transformation -- that is, sinks the // reshape or transpose operands of `instruction` across it. -StatusOr ReshapeMover::SinkRearrangeOperands( +absl::StatusOr ReshapeMover::SinkRearrangeOperands( HloInstruction* instruction) { auto print_no_metadata = HloPrintOptions().set_print_metadata(false); @@ -331,7 +331,7 @@ StatusOr ReshapeMover::SinkRearrangeOperands( // remaining rearrange operands have users outside `candidates`. In the later // case, all the remaining instructions in `candidates` are reshape-moved and // the routine returns true. -StatusOr ReshapeMover::TryReshapeMoveOnCandidates( +absl::StatusOr ReshapeMover::TryReshapeMoveOnCandidates( HloInstructionSet* candidates) { bool removed = true; while (!candidates->empty() && removed) { @@ -378,7 +378,7 @@ StatusOr ReshapeMover::TryReshapeMoveOnCandidates( return true; } -StatusOr ReshapeMover::Run( +absl::StatusOr ReshapeMover::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/reshape_mover.h b/xla/service/reshape_mover.h index 444457bc49f96..da0e3a55134bc 100644 --- a/xla/service/reshape_mover.h +++ b/xla/service/reshape_mover.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,14 +51,15 @@ class ReshapeMover : public HloModulePass { absl::string_view name() const override { return "reshape-mover"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr TryReshapeMoveOnCandidates(HloInstructionSet* candidates); - StatusOr SinkRearrangeOperands(HloInstruction* instruction); - StatusOr ApplyInverseRearrange( + absl::StatusOr TryReshapeMoveOnCandidates( + HloInstructionSet* candidates); + absl::StatusOr SinkRearrangeOperands(HloInstruction* instruction); + absl::StatusOr ApplyInverseRearrange( const HloInstruction* rearrange, HloInstruction* operand); bool IsReshapeMoveCandidate(HloInstruction* instruction); const HloInstruction* FirstNontrivialRearrange( diff --git a/xla/service/reshape_mover_test.cc b/xla/service/reshape_mover_test.cc index 061b4ecef81d0..02867a54a3b90 100644 --- a/xla/service/reshape_mover_test.cc +++ b/xla/service/reshape_mover_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/result_caster.cc b/xla/service/result_caster.cc index 432f7faff1484..fed07ff2cba87 100644 --- a/xla/service/result_caster.cc +++ b/xla/service/result_caster.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,28 @@ limitations under the License. #include "xla/service/result_caster.h" +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" namespace xla { namespace { -StatusOr> MaybeInferShape( +absl::StatusOr> MaybeInferShape( const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kDot: return ShapeInference::InferDotOpShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), instruction->dot_dimension_numbers(), - /*preferred_element_type=*/std::nullopt); + /*preferred_element_type=*/std::nullopt, + Cast(instruction)->sparsity()); case HloOpcode::kConvolution: return ShapeInference::InferConvolveShape( instruction->operand(0)->shape(), instruction->operand(1)->shape(), @@ -51,7 +60,7 @@ bool ResultCaster::InstructionMatchesPattern(HloInstruction* instruction) { return inferred_shape.element_type() != instruction->shape().element_type(); } -StatusOr ResultCaster::ExpandInstruction( +absl::StatusOr ResultCaster::ExpandInstruction( HloInstruction* instruction) { auto* computation = instruction->parent(); Shape inferred_shape = MaybeInferShape(instruction).value().value(); diff --git a/xla/service/result_caster.h b/xla/service/result_caster.h index c5fd47d2c8fa1..9482ec4152e80 100644 --- a/xla/service/result_caster.h +++ b/xla/service/result_caster.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_RESULT_CASTER_H_ #define XLA_SERVICE_RESULT_CASTER_H_ -#include "xla/hlo/ir/hlo_module.h" +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/op_expander_pass.h" +#include "xla/util.h" namespace xla { @@ -34,7 +39,7 @@ class ResultCaster : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/result_caster_test.cc b/xla/service/result_caster_test.cc index 54dcdf9adfab4..9fdd4049422f4 100644 --- a/xla/service/result_caster_test.cc +++ b/xla/service/result_caster_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,10 +15,16 @@ limitations under the License. #include "xla/service/result_caster.h" +#include +#include + +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -76,5 +82,26 @@ INSTANTIATE_TEST_SUITE_P(All, ResultCasterTest, std::make_tuple(F32, F32, S32), std::make_tuple(F32, BF16, F32))); +TEST_F(ResultCasterTest, SparseDot) { + absl::string_view kHlo = R"( + HloModule module + + ENTRY main { + p0 = bf16[2,16]{1,0} parameter(0) + p1 = bf16[32,2]{1,0} parameter(1) + meta = u16[2,2]{1,0} parameter(2) + ROOT dot = f32[2,2]{1,0} dot(p0, p1, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHlo)); + TF_ASSERT_OK_AND_ASSIGN(bool casted, ResultCaster().Run(module.get())); + EXPECT_TRUE(casted); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Convert(::testing::MakeMatcher(new ::xla::testing::HloMatcher( + HloOpcode::kDot, + {op::Parameter(0), op::Parameter(1), op::Parameter(2)})))); +} + } // namespace } // namespace xla diff --git a/xla/service/rng_bit_generator_expander.cc b/xla/service/rng_bit_generator_expander.cc index 2d4fb779fe745..cbabc2ebef6d5 100644 --- a/xla/service/rng_bit_generator_expander.cc +++ b/xla/service/rng_bit_generator_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,9 +52,11 @@ bool RngBitGeneratorExpander::InstructionMatchesPattern( return instruction->opcode() == HloOpcode::kRngBitGenerator; } -StatusOr RngBitGeneratorExpander::GetGeneratorComputation( - const Shape& data_shape, const Shape& state_shape, - RandomAlgorithm algorithm, HloModule* module) { +absl::StatusOr +RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, + const Shape& state_shape, + RandomAlgorithm algorithm, + HloModule* module) { RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module}; auto it = computation_cache_.find(cache_key); if (it != computation_cache_.end()) { @@ -97,7 +99,7 @@ StatusOr RngBitGeneratorExpander::GetGeneratorComputation( return new_computation; } -StatusOr RngBitGeneratorExpander::ExpandInstruction( +absl::StatusOr RngBitGeneratorExpander::ExpandInstruction( HloInstruction* hlo) { HloRngBitGeneratorInstruction* rng = Cast(hlo); RandomAlgorithm algorithm = rng->algorithm(); diff --git a/xla/service/rng_bit_generator_expander.h b/xla/service/rng_bit_generator_expander.h index b297e8c9b1dc4..46b8264ba5fd0 100644 --- a/xla/service/rng_bit_generator_expander.h +++ b/xla/service/rng_bit_generator_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,11 +57,11 @@ class RngBitGeneratorExpander : public OpExpanderPass { }; bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction(HloInstruction* hlo) override; - StatusOr GetGeneratorComputation(const Shape& data_shape, - const Shape& state_shape, - RandomAlgorithm algorithm, - HloModule* module); + absl::StatusOr ExpandInstruction( + HloInstruction* hlo) override; + absl::StatusOr GetGeneratorComputation( + const Shape& data_shape, const Shape& state_shape, + RandomAlgorithm algorithm, HloModule* module); const RandomAlgorithm default_algorithm_; absl::flat_hash_map computation_cache_; diff --git a/xla/service/rng_expander.cc b/xla/service/rng_expander.cc index 7b3232c519e65..cbc5a1d4549db 100644 --- a/xla/service/rng_expander.cc +++ b/xla/service/rng_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ int64_t GetNumberOf32bitUnits(const Shape& shape) { return num_elems * (bit_width / 32); } -StatusOr ConvertSmallFpRngToF32Rng(HloInstruction* rng) { +absl::StatusOr ConvertSmallFpRngToF32Rng(HloInstruction* rng) { CHECK_EQ(rng->opcode(), HloOpcode::kRng); PrimitiveType primitive_type = rng->shape().element_type(); CHECK(primitive_type == F16 || primitive_type == BF16); @@ -71,7 +71,7 @@ StatusOr ConvertSmallFpRngToF32Rng(HloInstruction* rng) { return new_rng; } -StatusOr GetComputationForRng(HloInstruction* rng) { +absl::StatusOr GetComputationForRng(HloInstruction* rng) { XlaBuilder builder("rng"); const Shape u64_shape = ShapeUtil::MakeShape(xla::U64, {}); const Shape u128_shape = ShapeUtil::MakeShape(xla::U64, {2}); @@ -129,7 +129,8 @@ bool RngExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kRng; } -StatusOr RngExpander::ExpandInstruction(HloInstruction* rng) { +absl::StatusOr RngExpander::ExpandInstruction( + HloInstruction* rng) { VLOG(2) << "Expand rng instruction " << rng->ToString(); PrimitiveType old_primitive_type = rng->shape().element_type(); if (primitive_util::BitWidth(old_primitive_type) < 32) { diff --git a/xla/service/rng_expander.h b/xla/service/rng_expander.h index 8ba3ca11cfac8..dd41a2a94838e 100644 --- a/xla/service/rng_expander.h +++ b/xla/service/rng_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ class RngExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction(HloInstruction* rng) override; + absl::StatusOr ExpandInstruction( + HloInstruction* rng) override; private: // Cache RNG computations based on the distribution, output shape and shapes diff --git a/xla/service/root_instruction_sinker.cc b/xla/service/root_instruction_sinker.cc index 07d27d0967ab6..007e9914499b9 100644 --- a/xla/service/root_instruction_sinker.cc +++ b/xla/service/root_instruction_sinker.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ void SinkNontupleRoot(HloComputation* computation) { } // namespace -StatusOr RootInstructionSinker::Run( +absl::StatusOr RootInstructionSinker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { TF_RET_CHECK(module->has_schedule()); diff --git a/xla/service/root_instruction_sinker.h b/xla/service/root_instruction_sinker.h index 9dcaef16891b0..81672c59539a2 100644 --- a/xla/service/root_instruction_sinker.h +++ b/xla/service/root_instruction_sinker.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ class RootInstructionSinker : public HloModulePass { ~RootInstructionSinker() override = default; absl::string_view name() const override { return "root-instruction-sinker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/root_instruction_sinker_test.cc b/xla/service/root_instruction_sinker_test.cc index 23761d8eee943..1be67c96c61ed 100644 --- a/xla/service/root_instruction_sinker_test.cc +++ b/xla/service/root_instruction_sinker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/scatter_expander.cc b/xla/service/scatter_expander.cc index 054a72b343c1b..f76c88cabda01 100644 --- a/xla/service/scatter_expander.cc +++ b/xla/service/scatter_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ namespace xla { // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. -static StatusOr TransposeIndexVectorDimToLast( +static absl::StatusOr TransposeIndexVectorDimToLast( HloInstruction* scatter_indices, int64_t index_vector_dim) { const Shape& scatter_indices_shape = scatter_indices->shape(); @@ -57,7 +57,7 @@ static StatusOr TransposeIndexVectorDimToLast( // Canonicalizes the scatter_indices tensor in order to keep them uniform while // performing the scatter operation. -static StatusOr CanonicalizeScatterIndices( +static absl::StatusOr CanonicalizeScatterIndices( HloInstruction* scatter_indices, int64_t index_vector_dim) { // Transpose the non-index-vector dimensions to the front. TF_ASSIGN_OR_RETURN( @@ -95,7 +95,7 @@ static StatusOr CanonicalizeScatterIndices( // Permutes the `updates` tensor such that all the scatter dims appear in the // major dimensions and all the window dimensions appear in the minor // dimensions. -static StatusOr PermuteScatterAndWindowDims( +static absl::StatusOr PermuteScatterAndWindowDims( HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; const int64_t updates_rank = updates->shape().rank(); @@ -115,7 +115,7 @@ static StatusOr PermuteScatterAndWindowDims( } // Expands or contracts the scatter indices in the updates tensor. -static StatusOr AdjustScatterDims( +static absl::StatusOr AdjustScatterDims( const Shape& scatter_indices_shape, HloInstruction* updates, int64_t index_vector_dim) { int64_t num_scatter_dims = scatter_indices_shape.dimensions_size(); @@ -133,7 +133,7 @@ static StatusOr AdjustScatterDims( // Expands an index vector from the scatter_indices tensor into a vector that // can be used to dynamic-update-slice to perform the scatter update. -static StatusOr ExpandIndexVectorIntoOperandSpace( +static absl::StatusOr ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, int64_t operand_rank) { HloComputation* computation = index_vector->parent(); @@ -172,7 +172,7 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } -static StatusOr CheckIndexValidity( +static absl::StatusOr CheckIndexValidity( HloComputation* computation, HloInstruction* index, absl::Span operand_dims, absl::Span window_sizes, HloModule* module) { @@ -218,8 +218,8 @@ static StatusOr CheckIndexValidity( return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); } -static StatusOr CallAndGetOutput(HloComputation* original, - int output_index) { +static absl::StatusOr CallAndGetOutput( + HloComputation* original, int output_index) { HloInstruction* original_root = original->root_instruction(); if (!original_root->shape().IsTuple()) { return original; @@ -246,7 +246,7 @@ static StatusOr CallAndGetOutput(HloComputation* original, } // Body of the while loop that performs the scatter operation using other HLOs. -static StatusOr> ScatterLoopBody( +static absl::StatusOr> ScatterLoopBody( HloScatterInstruction* scatter, HloInstruction* induction_var, absl::Span loop_state) { const ScatterDimensionNumbers& dim_numbers = @@ -410,7 +410,7 @@ static int64_t ScatterTripCount(const HloScatterInstruction* scatter) { // from c. and d. using the update_computation of scatter. // f. Write the updated value of the slice into the operand tensor. -StatusOr ScatterExpander::ExpandInstruction( +absl::StatusOr ScatterExpander::ExpandInstruction( HloInstruction* inst) { auto* scatter = Cast(inst); auto scatter_operands = scatter->scatter_operands(); @@ -470,7 +470,7 @@ StatusOr ScatterExpander::ExpandInstruction( absl::c_copy(scatter_operands, std::back_inserter(loop_state)); loop_state.push_back(canonical_scatter_indices); absl::c_copy(adjusted_canonical_updates, std::back_inserter(loop_state)); - StatusOr> scatter_loop_result_status = + absl::StatusOr> scatter_loop_result_status = WhileUtil::MakeCountedLoop( scatter->parent(), scatter_loop_trip_count, loop_state, [scatter](HloInstruction* induction_var, @@ -497,7 +497,7 @@ bool IsCombinerAssociative(const HloComputation* combiner) { case HloOpcode::kMinimum: case HloOpcode::kMaximum: return true; - // Other common combiners are associative at least for interger arithmetic. + // Other common combiners are associative at least for integer arithmetic. case HloOpcode::kAdd: case HloOpcode::kMultiply: case HloOpcode::kOr: diff --git a/xla/service/scatter_expander.h b/xla/service/scatter_expander.h index 7e6ae841b0082..d87bb52fc0738 100644 --- a/xla/service/scatter_expander.h +++ b/xla/service/scatter_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,7 +56,8 @@ class ScatterExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* inst) override; - StatusOr ExpandInstruction(HloInstruction* inst) override; + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; private: Mode mode_; diff --git a/xla/service/scatter_expander_test.cc b/xla/service/scatter_expander_test.cc index b3ea60b7d0031..71c76374f0d08 100644 --- a/xla/service/scatter_expander_test.cc +++ b/xla/service/scatter_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/scatter_simplifier.cc b/xla/service/scatter_simplifier.cc index 200d1c2f161bd..b6b12a3dc3b9a 100644 --- a/xla/service/scatter_simplifier.cc +++ b/xla/service/scatter_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ limitations under the License. namespace xla { namespace { -StatusOr FlattenAndTransposeUpdates( +absl::StatusOr FlattenAndTransposeUpdates( HloInstruction* updates, absl::Span update_window_dims, absl::Span inserted_window_dims, int64_t scatter_indices_size) { @@ -92,7 +92,7 @@ std::vector MakeUpdatePermutation( // Transforms the scatter_updates field of scatter. scatter_indices_size is the // size of the scatter dimension in scatter_indices. -StatusOr> TransformScatterUpdates( +absl::StatusOr> TransformScatterUpdates( HloScatterInstruction* scatter, const std::vector& update_permutation, int64_t scatter_indices_size) { @@ -128,7 +128,7 @@ ScatterDimensionNumbers MakeScatterDimensionNumbers( } // namespace -StatusOr ScatterSimplifier::ExpandInstruction( +absl::StatusOr ScatterSimplifier::ExpandInstruction( HloInstruction* inst) { auto* scatter = Cast(inst); @@ -202,26 +202,28 @@ StatusOr ScatterSimplifier::ExpandInstruction( return MaybeMakeTuple(result_items); } -bool ScatterSimplifier::InstructionMatchesPattern(HloInstruction* inst) { - if (auto* scatter = DynCast(inst)) { - const auto& dims = scatter->scatter_dimension_numbers(); - - bool nonstandard_index_vector_dim = - dims.index_vector_dim() != - scatter->scatter_indices()->shape().rank() - 1; - int64_t num_scatter_dims = - scatter->scatter_updates().front()->shape().rank() - - dims.update_window_dims().size(); - bool scatter_indices_reordered = - !IsIdentityPermutation(dims.scatter_dims_to_operand_dims()); - bool scatter_dim_not_first = - absl::c_linear_search(dims.update_window_dims(), 0); - - return nonstandard_index_vector_dim || num_scatter_dims > 1 || +bool ScatterSimplifier::IsSimplifiedScatter( + const HloScatterInstruction* scatter) { + const auto& dims = scatter->scatter_dimension_numbers(); + + bool nonstandard_index_vector_dim = + dims.index_vector_dim() != scatter->scatter_indices()->shape().rank() - 1; + int64_t num_scatter_dims = + scatter->scatter_updates().front()->shape().rank() - + dims.update_window_dims().size(); + bool scatter_indices_reordered = + !IsIdentityPermutation(dims.scatter_dims_to_operand_dims()); + bool scatter_dim_not_first = + absl::c_linear_search(dims.update_window_dims(), 0); + + return !(nonstandard_index_vector_dim || num_scatter_dims > 1 || scatter_indices_reordered || scatter_dim_not_first || - !dims.inserted_window_dims().empty(); - } - return false; + !dims.inserted_window_dims().empty()); +} + +bool ScatterSimplifier::InstructionMatchesPattern(HloInstruction* inst) { + auto* scatter = DynCast(inst); + return scatter && !IsSimplifiedScatter(scatter); } } // namespace xla diff --git a/xla/service/scatter_simplifier.h b/xla/service/scatter_simplifier.h index c993782cd8f42..8b14e16abc9ff 100644 --- a/xla/service/scatter_simplifier.h +++ b/xla/service/scatter_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SCATTER_SIMPLIFIER_H_ #define XLA_SERVICE_SCATTER_SIMPLIFIER_H_ +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/op_expander_pass.h" namespace xla { @@ -42,10 +43,13 @@ class ScatterSimplifier : public OpExpanderPass { public: absl::string_view name() const override { return "scatter_simplifier"; } + static bool IsSimplifiedScatter(const HloScatterInstruction* scatter); + protected: bool InstructionMatchesPattern(HloInstruction* inst) override; - StatusOr ExpandInstruction(HloInstruction* inst) override; + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; }; } // namespace xla diff --git a/xla/service/scatter_simplifier_test.cc b/xla/service/scatter_simplifier_test.cc index 39a21d4edaeba..36d82b4b487ee 100644 --- a/xla/service/scatter_simplifier_test.cc +++ b/xla/service/scatter_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/select_and_scatter_expander.cc b/xla/service/select_and_scatter_expander.cc index d7cdb2783d84b..10437d0d5e966 100644 --- a/xla/service/select_and_scatter_expander.cc +++ b/xla/service/select_and_scatter_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ limitations under the License. namespace xla { -StatusOr SelectAndScatterExpander::ExpandInstruction( +absl::StatusOr SelectAndScatterExpander::ExpandInstruction( HloInstruction* instruction) { // Prepare the original values auto* computation = instruction->parent(); diff --git a/xla/service/select_and_scatter_expander.h b/xla/service/select_and_scatter_expander.h index 96161a3f6a8d2..9e544972b3fec 100644 --- a/xla/service/select_and_scatter_expander.h +++ b/xla/service/select_and_scatter_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,8 @@ class SelectAndScatterExpander : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* inst) override; - StatusOr ExpandInstruction(HloInstruction* inst) override; + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; }; } // namespace xla diff --git a/xla/service/select_and_scatter_expander_test.cc b/xla/service/select_and_scatter_expander_test.cc index 2162c1b546dc3..0daf6a7fa586a 100644 --- a/xla/service/select_and_scatter_expander_test.cc +++ b/xla/service/select_and_scatter_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/service.cc b/xla/service/service.cc index 48b98e197190d..9017774ce2e51 100644 --- a/xla/service/service.cc +++ b/xla/service/service.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -62,6 +62,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla { namespace { @@ -203,7 +204,7 @@ Status Service::ValidateResultShape(const Shape& client_shape, return OkStatus(); } -StatusOr>> +absl::StatusOr>> Service::ResolveAndValidateArguments( absl::Span arguments, absl::Span stream_executors) const { @@ -228,7 +229,7 @@ Service::ResolveAndValidateArguments( return replicated_arguments; } -StatusOr> Service::CreateModuleConfig( +absl::StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, @@ -245,7 +246,7 @@ StatusOr> Service::CreateModuleConfig( num_threads, aot_options); } -StatusOr> Service::CreateModuleConfig( +absl::StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, absl::Span arguments, const ExecutionOptions& execution_options, @@ -258,7 +259,8 @@ StatusOr> Service::CreateModuleConfig( aot_options); } -StatusOr>> Service::BuildExecutables( +absl::StatusOr>> +Service::BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, @@ -305,7 +307,7 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } -StatusOr>> +absl::StatusOr>> Service::BuildAotResults( const std::vector& module_protos, std::vector> module_configs, @@ -327,14 +329,7 @@ Service::BuildAotResults( TF_ASSIGN_OR_RETURN( auto module, CreateModuleFromProto(*proto, config, run_backend_only)); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); - if (run_backend_only) { - module_group->push_back(std::move(module)); - } else { - TF_ASSIGN_OR_RETURN(auto module_after_opt, - backend->compiler()->RunHloPasses( - std::move(module), executors[0][0], options)); - module_group->push_back(std::move(module_after_opt)); - } + module_group->push_back(std::move(module)); } AotCompilationOptions aot_options(backend->compiler()->PlatformId()); @@ -349,7 +344,7 @@ Service::BuildAotResults( return std::move(aot_results); } -StatusOr> +absl::StatusOr> Service::ExecuteParallelAndRegisterResult( absl::Span executables, absl::Span>> arguments, @@ -428,15 +423,15 @@ Service::ExecuteParallelAndRegisterResult( for (int64_t i = 0, end = streams.size(); i < end; ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %d: %s", i, - block_status.message()); + return Internal("failed to complete execution for stream %d: %s", i, + block_status.message()); } } return result_handles; } -StatusOr Service::ExecuteAndRegisterResult( +absl::StatusOr Service::ExecuteAndRegisterResult( Executable* executable, absl::Span> arguments, Backend* backend, const DeviceHandle& device_handle, @@ -475,7 +470,7 @@ StatusOr Service::ExecuteAndRegisterResult( if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper( - &run_options[0], arguments[0])); + run_options.data(), arguments[0])); return allocation_tracker_.Register(std::move(result), result_tag); } @@ -493,7 +488,7 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -StatusOr> Service::GetExecutors( +absl::StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64_t requests_size, int64_t request_index) const { if (execution_options.device_handles().empty()) { @@ -518,7 +513,8 @@ StatusOr> Service::GetExecutors( return executors; } -StatusOr>> Service::GetArguments( +absl::StatusOr>> +Service::GetArguments( const ExecutionOptions& execution_options, absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create @@ -662,16 +658,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, Status execution_status = OkStatus(); if (executable_ptrs.size() == 1) { - StatusOr output_or_status = ExecuteAndRegisterResult( - executable_ptrs[0], all_arguments[0], execute_backend_.get(), - device_handles[0], computation_names[0], &profile); + absl::StatusOr output_or_status = + ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0], + execute_backend_.get(), device_handles[0], + computation_names[0], &profile); if (output_or_status.ok()) { outputs.push_back(std::move(output_or_status).value()); } else { execution_status = output_or_status.status(); } } else { - StatusOr> outputs_or_status = + absl::StatusOr> outputs_or_status = ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, execute_backend_.get(), device_handles, computation_names, &profile); @@ -741,7 +738,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return OkStatus(); } -StatusOr> Service::BuildExecutable( +absl::StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, const Compiler::CompileOptions& options, @@ -750,6 +747,10 @@ StatusOr> Service::BuildExecutable( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name()); + tsl::profiler::ScopedAnnotation annotation{[&] { + return absl::StrCat("XlaCompile:#module=", module_proto.name(), "#"); + }}; + TF_ASSIGN_OR_RETURN( std::unique_ptr module, CreateModuleFromProto(module_proto, *module_config, run_backend_only)); @@ -770,6 +771,9 @@ StatusOr> Service::BuildExecutable( std::move(module), executor, options)); } + tsl::profiler::ScopedAnnotation backend_annotation{[&] { + return absl::StrCat("XlaCompileBackend:#module=", module_proto.name(), "#"); + }}; TF_ASSIGN_OR_RETURN( std::unique_ptr executable, backend->compiler()->RunBackend(std::move(module), executor, options)); @@ -786,10 +790,10 @@ StatusOr> Service::BuildExecutable( buffer_assignment_proto_after_opt != nullptr) { CHECK(DumpingEnabledForHloModule(executable->module())); *hlo_proto_before_opt->mutable_buffer_assignment() = - std::move(*buffer_assignment_proto_after_opt); + *buffer_assignment_proto_after_opt; executable->set_hlo_proto(std::move(hlo_proto_before_opt)); } - return std::move(executable); + return executable; } Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { @@ -1089,7 +1093,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); evaluator.set_custom_call_handler( [](const HloInstruction* custom_call, - absl::Span operands) -> StatusOr { + absl::Span operands) -> absl::StatusOr { if (custom_call->custom_call_target() == "SliceToDynamic") { auto result = operands[0]->Clone(); for (int64_t i = 0; i < result.shape().rank(); ++i) { @@ -1161,7 +1165,7 @@ DeviceHandle Service::SingleComputationDeviceHandle() const { return device_handle; } -StatusOr> Service::Replicas( +absl::StatusOr> Service::Replicas( const Backend& backend, const DeviceHandle& device_handle) const { std::vector replicas; for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { diff --git a/xla/service/service.h b/xla/service/service.h index dedf1445bdd90..f606698b1a858 100644 --- a/xla/service/service.h +++ b/xla/service/service.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -171,7 +171,7 @@ class Service : public ServiceInterface { // Create a Hlo module config for the given program shape and arguments. // aot_options is optional; if not given a default is used. - StatusOr> CreateModuleConfig( + absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, @@ -186,19 +186,19 @@ class Service : public ServiceInterface { private: // A private overload for Service itself, used by other methods within this // class. - StatusOr> CreateModuleConfig( + absl::StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span arguments, const ExecutionOptions& execution_options, const AotCompilationOptions* aot_options = nullptr); // Prepare the executors for executing parallel. - StatusOr> GetExecutors( + absl::StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64_t requests_size, int64_t request_index) const; // Prepare the arguments for executing parallel. - StatusOr>> GetArguments( + absl::StatusOr>> GetArguments( const ExecutionOptions& execution_options, absl::Span arguments) const; @@ -214,17 +214,18 @@ class Service : public ServiceInterface { // the corresponding allocations for every replica. The function also verifies // that each allocation matches the execution platform and device ordinal of // the corresponding replica. - StatusOr>> + absl::StatusOr>> ResolveAndValidateArguments( absl::Span arguments, absl::Span stream_executors) const; + public: // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( + absl::StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, const Compiler::CompileOptions& options, @@ -232,26 +233,29 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( + absl::StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, const Compiler::CompileOptions& options, bool run_backend_only = false); + protected: // Same as BuildExecutable() above, but builds a list of // AotCompilationResult(s), which can be persisted to later load Executable // objects. - StatusOr>> BuildAotResults( - const std::vector& module_protos, - std::vector> module_configs, - Backend* backend, std::vector> executors, - const Compiler::CompileOptions& options, bool run_backend_only = false); + absl::StatusOr>> + BuildAotResults(const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + const Compiler::CompileOptions& options, + bool run_backend_only = false); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an // ExecutionProfile object which will be filled in with profile data. - StatusOr ExecuteAndRegisterResult( + absl::StatusOr ExecuteAndRegisterResult( Executable* executable, absl::Span> arguments, Backend* backend, const DeviceHandle& device_handle, @@ -260,7 +264,8 @@ class Service : public ServiceInterface { // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. - StatusOr> ExecuteParallelAndRegisterResult( + absl::StatusOr> + ExecuteParallelAndRegisterResult( absl::Span executables, absl::Span>> arguments, Backend* backend, absl::Span device_handles, @@ -269,7 +274,7 @@ class Service : public ServiceInterface { // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that // represents a set of physical devices for the replicas. - StatusOr> Replicas( + absl::StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; // Returns the device handle that represents the replicated device for a diff --git a/xla/service/service_executable_run_options.h b/xla/service/service_executable_run_options.h index a8c5dc0745f9a..a2f57e6bbae63 100644 --- a/xla/service/service_executable_run_options.h +++ b/xla/service/service_executable_run_options.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,15 +35,16 @@ class ServiceExecutableRunOptions { // with the first argument being the device ordinal, the second // argument being the number of streams to borrow, and the third // argument being the priority of the streams. - using StreamBorrower = std::function>( - int, int, se::StreamPriority)>; + using StreamBorrower = + std::function>( + int, int, se::StreamPriority)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} - explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, + explicit ServiceExecutableRunOptions(const ExecutableRunOptions& run_options, StreamBorrower stream_borrower = nullptr) - : run_options_(std::move(run_options)), + : run_options_(run_options), stream_borrower_(std::move(stream_borrower)) {} // Returns reference or pointer to `ExecutableRunOptions` member. @@ -59,7 +60,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr BorrowStream( + absl::StatusOr BorrowStream( int device_ordinal, se::StreamPriority priority = se::StreamPriority::Default) const { if (!stream_borrower_) { @@ -73,7 +74,7 @@ class ServiceExecutableRunOptions { return stream; } - StatusOr> BorrowStreams( + absl::StatusOr> BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority = se::StreamPriority::Default) const { return stream_borrower_ diff --git a/xla/service/shape_inference.cc b/xla/service/shape_inference.cc index bf324e880c396..5e439128c1ade 100644 --- a/xla/service/shape_inference.cc +++ b/xla/service/shape_inference.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/shape_inference.h" #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" @@ -44,7 +46,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status.h" #include "xla/status_macros.h" -#include "xla/statusor.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" @@ -65,6 +66,18 @@ bool AllUnique(absl::Span slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } +// Checks whether the given dimension size `size` is unbounded dynamic size. +bool IsUnboundedDynamicSize(int64_t size) { + return size == Shape::kUnboundedSize; +} + +// Returns success if the given two dimension sizes 'size_a' and 'size_b' are +// compatible: at least one is dynamic or both are equal. +bool CompatibleDimensionSizes(int64_t size_a, int64_t size_b) { + return IsUnboundedDynamicSize(size_a) || IsUnboundedDynamicSize(size_b) || + size_a == size_b; +} + Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!shape.IsArray()) { return InvalidArgument("Expected array argument for %s, but got %s.", @@ -172,9 +185,9 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return OkStatus(); } -StatusOr InferWindowOutputShape(const Shape& base_shape, - const Window& window, - PrimitiveType element_type) { +absl::StatusOr InferWindowOutputShape(const Shape& base_shape, + const Window& window, + PrimitiveType element_type) { if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", @@ -204,15 +217,19 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } - const int64_t dilated_base = window_util::DilatedBound( - ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); - const int64_t padded_dilated_base = - dim.padding_low() + dilated_base + dim.padding_high(); - const int64_t dilated_window = - window_util::DilatedBound(dim.size(), dim.window_dilation()); - - output_dimensions[i] = window_util::StridedBound( - padded_dilated_base, dilated_window, dim.stride()); + if (IsUnboundedDynamicSize(ShapeUtil::GetDimension(base_shape, i))) { + output_dimensions[i] = Shape::kUnboundedSize; + } else { + const int64_t dilated_base = window_util::DilatedBound( + ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); + const int64_t padded_dilated_base = + dim.padding_low() + dilated_base + dim.padding_high(); + const int64_t dilated_window = + window_util::DilatedBound(dim.size(), dim.window_dilation()); + + output_dimensions[i] = window_util::StridedBound( + padded_dilated_base, dilated_window, dim.stride()); + } output_is_dynamic[i] = base_shape.is_dynamic_dimension(i); } @@ -220,14 +237,106 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, output_is_dynamic); } +// Encapsulates inferred dimension size and bound size. +struct DimAndBound { + int64_t dimension, bound; +}; + +// Inference rules to concat dimensions with bounds (lhs/rhs are commutative): +// Dim of lhs Dim of rhs Infer +// c0: X Y X+Y +// c1: X ? ? +// c2: X <=B <=X+B +// c3: ? ? ? +//. c4: ? <=B ? +// c5: <=B <=C <=B+C +// Note: +// A HLO static dimension size `X` is expressed as size=X, and bound=? +// A bounded dynamic dimension size `<=X` is be expressed as size=X, and bound=? +// A unbounded dynamic dimension size, `?`, is expressed as size=?, and bound=? +DimAndBound InferConcatenatedDimAndBound(int64_t left_size, int64_t right_size, + int64_t left_bound, + int64_t right_bound) { + bool is_left_static_dim = !IsUnboundedDynamicSize(left_size); + bool is_right_static_dim = !IsUnboundedDynamicSize(right_size); + bool is_left_static_bound = !IsUnboundedDynamicSize(left_bound); + bool is_right_static_bound = !IsUnboundedDynamicSize(right_bound); + int64_t inferred_size = Shape::kUnboundedSize; + int64_t inferred_bound = Shape::kUnboundedSize; + + if (is_left_static_dim && is_right_static_dim) { + inferred_size = left_size + right_size; + } + if (is_left_static_bound || is_right_static_bound) { + int64_t leftBoundOrSize = is_left_static_bound ? left_bound : left_size; + int64_t rightBoundOrSize = is_right_static_bound ? right_bound : right_size; + if (!IsUnboundedDynamicSize(leftBoundOrSize) && + !IsUnboundedDynamicSize(rightBoundOrSize)) { + inferred_bound = leftBoundOrSize + rightBoundOrSize; + } + } + return {inferred_size, inferred_bound}; +} + +// Inference rules to merge dimensions with bounds (lhs/rhs are commutative): +// Dim of lhs Dim of rhs Infer +// c0: X X X +// c1: X ? X +// c2: X <=X <=X +// c3: ? ? ? +// c4: ? <=B <=B +// c5: <=B <=C Error, mismatched bound sizes +// c6: X Y Error, mismatched dimension sizes +// Note: +// A HLO static dimension size `X` is expressed as size=X, and bound=? +// A bounded dynamic dimension size `<=X` is be expressed as size=X, and bound=? +// A unbounded dynamic dimension size, `?`, is expressed as size=?, and bound=? +absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, + int64_t left_size, + int64_t right_size, + int64_t left_bound, + int64_t right_bound) { + bool is_left_static_dim = !IsUnboundedDynamicSize(left_size); + bool is_right_static_dim = !IsUnboundedDynamicSize(right_size); + bool is_left_static_bound = !IsUnboundedDynamicSize(left_bound); + bool is_right_static_bound = !IsUnboundedDynamicSize(right_bound); + int64_t inferred_size = Shape::kUnboundedSize; + int64_t inferred_bound = Shape::kUnboundedSize; + + if (is_left_static_bound || is_right_static_bound) { + if (is_left_static_bound && is_right_static_bound && + left_bound != right_bound) { + return InvalidArgument("Mismatched bound sizes %d and %d in dimension %d", + left_bound, right_bound, dim); + } + inferred_bound = is_left_static_bound ? left_bound : right_bound; + } + if (is_left_static_dim || is_right_static_dim) { + if (is_left_static_dim && is_right_static_dim && left_size != right_size) { + return InvalidArgument( + "Mismatched dimension sizes %d and %d in dimension %d", left_size, + right_size, dim); + } + inferred_size = is_left_static_dim ? left_size : right_size; + if (!IsUnboundedDynamicSize(inferred_bound) && + inferred_size != inferred_bound) { + return InvalidArgument( + "Mismatched dimension size %d and bound %d in dimension %d", + inferred_size, inferred_bound, dim); + } + } + DimAndBound dim_and_bound = {inferred_size, inferred_bound}; + return dim_and_bound; +} + } // namespace -/* static */ StatusOr ShapeInference::InferUnaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { return InferUnaryOpShape(opcode, operand->shape()); } -/* static */ StatusOr ShapeInference::InferUnaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. // A domain shape is the same as the input one. @@ -241,6 +350,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kErf: case HloOpcode::kRoundNearestAfz: case HloOpcode::kRoundNearestEven: if (!ShapeUtil::ElementIsFloating(shape)) { @@ -356,7 +466,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } } -/* static */ StatusOr ShapeInference::InferTopKShape( +/* static */ absl::StatusOr ShapeInference::InferTopKShape( const Shape& operand_shape, int64_t k) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of top-k operation")); int64_t last_dim = operand_shape.rank() - 1; @@ -378,7 +488,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeTupleShape({out, idxs_shape}); } -/* static */ StatusOr ShapeInference::InferConcatOpShape( +/* static */ absl::StatusOr ShapeInference::InferConcatOpShape( absl::Span arg_shapes, const int64_t dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); @@ -411,8 +521,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } for (int64_t dimension_number = 0; dimension_number < arg_shape->rank(); ++dimension_number) { - if (arg_shape->dimensions(dimension_number) != - shape->dimensions(dimension_number)) { + if (!CompatibleDimensionSizes(arg_shape->dimensions(dimension_number), + shape->dimensions(dimension_number))) { if (dimension_number == dimension) { continue; // It's okay to differ in the dimension we're // concatenating. @@ -420,7 +530,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated. Dimension %d in both shapes must be " - "equal: %s vs %s.", + "equal (or compatible): %s vs %s.", dimension_number, ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape)); } @@ -428,26 +538,50 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); } - std::vector new_dimensions(arg_shape->dimensions().begin(), - arg_shape->dimensions().end()); - for (size_t i = 1; i < arg_shapes.size(); ++i) { - new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); - } - - Shape result = ShapeUtil::MakeShape(element_type, new_dimensions); + // Infer the most specific (size, bound) of all dimensions of the return type + int64_t rank = arg_shape->rank(); + std::vector inferred_sizes(rank, Shape::kUnboundedSize); + std::vector inferred_bounds(rank, Shape::kUnboundedSize); + // Note: for the concatenate dimension, 0 should be the identity element: + // Any dim size can keep unchanged when concatenated with 0 + inferred_sizes[dimension] = 0; - // Set dynamic dimensions if any input has dynamic dimension. for (const Shape* shape : arg_shapes) { - for (int64_t i = 0; i < shape->dimensions_size(); ++i) { - if (shape->is_dynamic_dimension(i)) { - result.set_dynamic_dimension(i, true); + for (int dim = 0; dim < rank; ++dim) { + DimAndBound inferred_dim_and_bound; + + int64_t dimension_size = shape->dimensions(dim); + int64_t leftSize = inferred_sizes[dim]; + int64_t rightSize = dimension_size; + int64_t leftBound = inferred_bounds[dim]; + int64_t rightBound = shape->is_dynamic_dimension(dim) + ? dimension_size + : Shape::kUnboundedSize; + if (dim == dimension) { + inferred_dim_and_bound = InferConcatenatedDimAndBound( + leftSize, rightSize, leftBound, rightBound); + } else { + TF_ASSIGN_OR_RETURN( + inferred_dim_and_bound, + InferMostSpecificDimAndBound(dim, leftSize, rightSize, leftBound, + rightBound)); } + inferred_sizes[dim] = inferred_dim_and_bound.dimension; + inferred_bounds[dim] = inferred_dim_and_bound.bound; + } + } + + Shape result = ShapeUtil::MakeShape(element_type, inferred_sizes); + for (int64_t i = 0; i < inferred_bounds.size(); ++i) { + if (!IsUnboundedDynamicSize(inferred_bounds[i]) || + IsUnboundedDynamicSize(inferred_sizes[i])) { + result.set_dynamic_dimension(i, true); } } return result; } -/* static */ StatusOr ShapeInference::InferConvertShape( +/* static */ absl::StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { @@ -463,7 +597,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } -/* static */ StatusOr ShapeInference::InferBitcastConvertShape( +/* static */ absl::StatusOr ShapeInference::InferBitcastConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) != @@ -517,7 +651,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return new_shape; } -/* static */ StatusOr ShapeInference::InferStochasticConvertShape( +/* static */ absl::StatusOr ShapeInference::InferStochasticConvertShape( const Shape& operand_shape, const Shape& random_shape, PrimitiveType new_element_type) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); @@ -561,7 +695,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } -/* static */ StatusOr ShapeInference::InferReducePrecisionShape( +/* static */ absl::StatusOr ShapeInference::InferReducePrecisionShape( const Shape& operand_shape, const int exponent_bits, const int mantissa_bits) { if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -585,7 +719,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return operand_shape; } -/* static */ StatusOr ShapeInference::InferPadShape( +/* static */ absl::StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { if (!operand_shape.IsArray()) { @@ -624,13 +758,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, std::vector is_dynamic(operand_shape.rank()); for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); - dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + - p.edge_padding_high() + - std::max(operand_shape.dimensions(i) - 1, 0LL) * - p.interior_padding(); - if (dimensions[i] < 0) { - return InvalidArgument("Padding result in negative size for dimension %d", - i); + if (operand_shape.is_unbounded_dynamic_dimension(i)) { + dimensions[i] = Shape::kUnboundedSize; + } else { + dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + + p.edge_padding_high() + + std::max(operand_shape.dimensions(i) - 1, 0LL) * + p.interior_padding(); + if (dimensions[i] < 0) { + return InvalidArgument( + "Padding result in negative size for dimension %d", i); + } } is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } @@ -704,10 +842,11 @@ Status ValidateDotDimensionNumbers( } // namespace -/* static */ StatusOr ShapeInference::InferDotOpShape( +/* static */ absl::StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, - std::optional preferred_element_type) { + std::optional preferred_element_type, + absl::Span sparsity) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); @@ -724,6 +863,23 @@ Status ValidateDotDimensionNumbers( // Validate basic properties of dot dimension numbers. TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + // Sparsity is only supported for contracting dimensions. + // With N:M sparsity, the contracting dimension sizes have N/M ratio. + const int kSize = HloDotInstruction::kOperands; + std::array, kSize> sparsity_nm = {{{1, 1}, {1, 1}}}; + std::array sparsity_dim = {-1, -1}; + for (const auto& descriptor : sparsity) { + TF_RET_CHECK(descriptor.index() == 0 || descriptor.index() == 1); + sparsity_dim[descriptor.index()] = descriptor.dimension(); + switch (descriptor.type()) { + case SPARSITY_STRUCTURED_N_M: + sparsity_nm[descriptor.index()] = {descriptor.n(), descriptor.m()}; + break; + default: + LOG(FATAL) << "Unsupported sparsity type: " << descriptor.type(); + } + } + // Check that number of contracting dimensions match. if (dimension_numbers.lhs_contracting_dimensions_size() != dimension_numbers.rhs_contracting_dimensions_size()) { @@ -738,9 +894,24 @@ Status ValidateDotDimensionNumbers( dimension_numbers.lhs_contracting_dimensions(i); const int64_t rhs_contracting_dimension = dimension_numbers.rhs_contracting_dimensions(i); - if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension)) { - return fail("Contracting dimension sizes do not match."); + int64_t lhs_size = lhs.dimensions(lhs_contracting_dimension); + int64_t rhs_size = rhs.dimensions(rhs_contracting_dimension); + bool is_sparse = false; + if (lhs_contracting_dimension == sparsity_dim[0]) { + lhs_size *= sparsity_nm[0].second; + rhs_size *= sparsity_nm[0].first; + is_sparse = true; + } + if (rhs_contracting_dimension == sparsity_dim[1]) { + lhs_size *= sparsity_nm[1].first; + rhs_size *= sparsity_nm[1].second; + is_sparse = true; + } + if (!CompatibleDimensionSizes(lhs_size, rhs_size)) { + return fail( + !is_sparse + ? "Contracting dimension sizes are not compatible." + : "Sparse dimension size ratio doesn't match the descriptor."); } } @@ -752,9 +923,10 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { - if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("Batch dimension sizes must match for lhs/rhs."); + if (!CompatibleDimensionSizes( + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)), + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)))) { + return fail("Batch dimension sizes are not compatible."); } } @@ -802,9 +974,56 @@ Status ValidateDotDimensionNumbers( return result; } -/* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, - const Shape& lhs, +/* static */ absl::StatusOr ShapeInference::InferSparseDotMetadataShape( + const Shape& operand_shape, const DotDimensionNumbers& dimension_numbers, + const SparsityDescriptor& sparsity, PrimitiveType element_type) { + CHECK(primitive_util::IsUnsignedIntegralType(element_type)); + + // Metadata includes contracting and non-contracting dimensions + // (i.e. excludes batch) of the sparse operand shape. The sparse dimension + // must be contracting. + bool sparse_lhs = sparsity.index() == 0; + auto& contracting_dimensions = + sparse_lhs ? dimension_numbers.lhs_contracting_dimensions() + : dimension_numbers.rhs_contracting_dimensions(); + TF_RET_CHECK( + absl::c_linear_search(contracting_dimensions, sparsity.dimension())); + + // Calculate the number of elements needed to encode the sparsity metadata + // in the sparse dimension. + int64_t metadata_dimension_size = 0; + switch (sparsity.type()) { + case SPARSITY_STRUCTURED_N_M: { + // For 2:4 sparsity, each group of 4 elements has 2 values defined. + // Each 16-bit metadata element contains the data for 4 groups. + int bits_per_value = Log2Ceiling(static_cast(sparsity.m())); + int bits_per_group = sparsity.n() * bits_per_value; + int groups_per_element = + CeilOfRatio(primitive_util::BitWidth(element_type), bits_per_group); + int64_t group_count = + CeilOfRatio(operand_shape.dimensions(sparsity.dimension()), + static_cast(sparsity.n())); + metadata_dimension_size = + CeilOfRatio(group_count, static_cast(groups_per_element)); + break; + } + default: + LOG(FATAL) << "Unsupported sparsity type: " << sparsity.type(); + } + + // Build the resulting shape dimensions. + std::vector dimensions; + std::vector is_dynamic; + for (int64_t i = 0; i < operand_shape.rank(); ++i) { + dimensions.push_back(i != sparsity.dimension() ? operand_shape.dimensions(i) + : metadata_dimension_size); + is_dynamic.push_back(operand_shape.is_dynamic_dimension(i)); + } + return ShapeUtil::MakeShape(element_type, dimensions, is_dynamic); +} + +/* static */ absl::StatusOr +ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, const Shape& rhs) { TF_RET_CHECK(lhs.rank() == rhs.rank()); @@ -857,10 +1076,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); } else { - return InvalidArgument( - "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation), ShapeUtil::HumanString(lhs), - ShapeUtil::HumanString(rhs)); + return InvalidArgument("Binary op with incompatible shapes: %s and %s.", + ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } @@ -868,17 +1086,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, output_dimensions, output_dimensions_is_dynamic); } -/* static */ StatusOr ShapeInference::InferInDimBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, absl::Span broadcast_dimensions) { - if (smaller_shape.is_unbounded_dynamic() || - larger_shape.is_unbounded_dynamic()) { - return InvalidArgumentError(StrFormat( - "Unbounded dynamic shapes not supported, but we have %s and %s", - ShapeUtil::HumanString(smaller_shape), - ShapeUtil::HumanString(larger_shape))); - } - if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. @@ -887,7 +1097,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, StrFormat("Shapes must be equal rank, but are %s and %s", ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape))); - } else if (broadcast_dimensions.size() != smaller_shape.rank()) { + } + + if (broadcast_dimensions.size() != smaller_shape.rank()) { return InvalidArgumentError(StrFormat( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " @@ -971,7 +1183,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size == large_dimension_size || (small_dimension_size == 1 && !small_is_dynamic) || (large_dimension_size == 1 && !large_is_dynamic)) { - // Do nothing. It's OK when the size-1 dimension is not static. + // Do nothing. It's OK when the size-1 dimension is not static or when + // it is unbounded dynamic. } else { return InvalidArgumentError( StrFormat("Broadcast dimension %d dynamism mismatch: %s and %s.", i, @@ -994,7 +1207,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } -/* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( +/* static */ absl::StatusOr +ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); @@ -1034,7 +1248,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } else if (lhs.rank() == rhs.rank()) { - return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); + return InferDegenerateDimensionBroadcastShape(lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. @@ -1046,18 +1260,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, InferInDimBroadcastShape(smaller_shape, larger_shape, broadcast_dimensions)); - return InferDegenerateDimensionBroadcastShape( - operation, indim_broadcast_shape, larger_shape); + return InferDegenerateDimensionBroadcastShape(indim_broadcast_shape, + larger_shape); + } +} + +/* static */ absl::StatusOr> +ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { + // The shape is not scalar, it may have unbounded/bounded dynamic + // dimensions. Inferring the proper shape per op is out of scope of this + // function. + std::optional broadcasted_shape; + for (const Shape& shape : shapes) { + if (!shape.IsArray() || shape.rank() == 0) continue; + if (!broadcasted_shape.has_value()) { + broadcasted_shape = shape; + } + // TODO(jpienaar): The case where we need to compute the broadcasted + // shape by considering multiple of the shapes is not implemented. + // Consider reusing "getBroadcastedType" from mlir/Dialect/Traits.h. + TF_RET_CHECK(ShapeUtil::SameDimensions(broadcasted_shape.value(), shape)) + << "Unimplemented implicit broadcast."; } + return broadcasted_shape; } -/* static */ StatusOr ShapeInference::InferBinaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), /*broadcast_dimensions=*/{}); } -/* static */ StatusOr ShapeInference::InferBinaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions) { VLOG(2) << StrFormat( @@ -1142,13 +1376,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferTernaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); } -/* static */ StatusOr ShapeInference::InferTernaryOpShape( +/* static */ absl::StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1163,7 +1397,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferVariadicOpShape( +/* static */ absl::StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); @@ -1173,7 +1407,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(opcode, operand_shapes); } -/* static */ StatusOr ShapeInference::InferVariadicOpShape( +/* static */ absl::StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); @@ -1210,7 +1444,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferMapShape( +/* static */ absl::StatusOr ShapeInference::InferMapShape( absl::Span arg_shapes, const ProgramShape& to_apply, absl::Span dimensions) { if (arg_shapes.empty()) { @@ -1298,11 +1532,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(output_shape.element_type(), - arg_shape->dimensions()); + return ShapeUtil::MakeShape( + output_shape.element_type(), arg_shape->dimensions(), + /*dynamic_dimensions=*/ + std::vector(arg_shape->dynamic_dimensions().begin(), + arg_shape->dynamic_dimensions().end())); } -/* static */ StatusOr ShapeInference::InferBatchNormTrainingShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64_t feature_index) { TF_RETURN_IF_ERROR( @@ -1387,17 +1624,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, Shape output_shape_for_mean_and_var = ShapeUtil::MakeShape( operand_shape.element_type(), {feature_count}, {dynamic_feature}); - if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(offset_shape, 0), + feature_count)) { return InvalidArgument( - "The size of offset factor should be the same as feature count," + "The size of offset factor should be compatible with feature count, " "but the size of offset factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(scale_shape, 0), + feature_count)) { return InvalidArgument( - "The size of scale factor should be the same as feature count," + "The size of scale factor should be compatible with feature count, " "but the size of scale factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); @@ -1408,7 +1647,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, &output_shape_for_mean_and_var}); } -/* static */ StatusOr ShapeInference::InferBatchNormInferenceShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormInferenceShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64_t feature_index) { @@ -1514,36 +1753,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int64_t feature_count = operand_shape.dimensions(feature_index); - Shape output_shape_for_mean_and_var = - ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); - if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(offset_shape, 0), + feature_count)) { return InvalidArgument( - "The size of offset factor should be the same as feature count," + "The size of offset factor should be compatible with feature count, " "but the size of offset factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(scale_shape, 0), + feature_count)) { return InvalidArgument( - "The size of scale factor should be the same as feature count," + "The size of scale factor should be compatible with feature count, " "but the size of scale factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(mean_shape, 0), + feature_count)) { return InvalidArgument( - "The size of mean should be the same as feature count," + "The size of mean should be compatible with feature count, " "but the size of mean is %d " "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(variance_shape, 0), + feature_count)) { return InvalidArgument( - "The size of variance should be the same as feature count," + "The size of variance should be compatible with feature count, " "but the size of variance is %d " "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); @@ -1552,7 +1793,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferBatchNormGradShape( +/* static */ absl::StatusOr ShapeInference::InferBatchNormGradShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64_t feature_index) { @@ -1663,29 +1904,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int64_t feature_count = operand_shape.dimensions(feature_index); + bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); + Shape feature_shape = ShapeUtil::MakeShape( + operand_shape.element_type(), {feature_count}, {dynamic_feature}); - Shape feature_shape = - ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); - - if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(mean_shape, 0), + feature_count)) { return InvalidArgument( - "The size of mean should be the same as feature count," + "The size of mean should be compatible with feature count, " "but the size of offset factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(scale_shape, 0), + feature_count)) { return InvalidArgument( - "The size of scale factor should be the same as feature count," + "The size of scale factor should be compatible with feature count, " "but the size of scale factor is %d " "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } - if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { + if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(var_shape, 0), + feature_count)) { return InvalidArgument( - "The size of variance should be the same as feature count," + "The size of variance should be compatible with feature count, " "but the size of variance is %d " "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); @@ -1693,11 +1937,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // Verify operand_shape and output_grad_shape have same bounds. for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (ShapeUtil::GetDimension(operand_shape, i) != - ShapeUtil::GetDimension(output_grad_shape, i)) { + if (!CompatibleDimensionSizes( + ShapeUtil::GetDimension(operand_shape, i), + ShapeUtil::GetDimension(output_grad_shape, i))) { return InvalidArgument( - "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %d is %d " + "The bounds of operand shape should be compatible with " + "output_grad's, but the bound of operand_shape at dimension %d is %d " "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); @@ -1708,7 +1953,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, {&operand_shape, &feature_shape, &feature_shape}); } -/* static */ StatusOr ShapeInference::InferConvolveShape( +/* static */ absl::StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dnums, @@ -1909,8 +2154,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dnums.ShortDebugString()); } - Shape base_shape = - ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims); + std::vector dynamic_dimensions(input_spatial_dims.size()); + for (auto it = input_spatial_dims.begin(); it != input_spatial_dims.end(); + ++it) { + dynamic_dimensions[it - input_spatial_dims.begin()] = + IsUnboundedDynamicSize(*it); + } + Shape base_shape = ShapeUtil::MakeShape( + lhs.element_type(), input_spatial_dims, dynamic_dimensions); TF_ASSIGN_OR_RETURN( Shape window_output_shape, InferWindowOutputShape(base_shape, window, lhs.element_type())); @@ -1966,7 +2217,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(type, dimensions, is_dynamic); } -/* static */ StatusOr ShapeInference::InferFftShape( +/* static */ absl::StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, const absl::Span fft_length) { const int64_t fft_rank = fft_length.size(); @@ -2055,7 +2306,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } -/* static */ StatusOr ShapeInference::InferTriangularSolveShape( +/* static */ absl::StatusOr ShapeInference::InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options) { if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) || a.element_type() != b.element_type()) { @@ -2075,13 +2326,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Arguments to triangular solve must have equal rank; got %s and %s.", b.ToString(), a.ToString()); } - if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + if (!CompatibleDimensionSizes(a.dimensions(a.rank() - 2), + a.dimensions(a.rank() - 1))) { return InvalidArgument( "The two minor dimensions of 'a' must have equal size, got %s.", a.ToString()); } - if (a.dimensions(a.rank() - 1) != - b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) { + if (!CompatibleDimensionSizes( + a.dimensions(a.rank() - 1), + b.dimensions(b.rank() - (options.left_side() ? 2 : 1)))) { return InvalidArgument( "The shared dimension of 'a' and 'b' does not match, got shapes %s and " "%s", @@ -2106,7 +2359,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return b; } -/* static */ StatusOr ShapeInference::InferCholeskyShape( +/* static */ absl::StatusOr ShapeInference::InferCholeskyShape( const Shape& a) { if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) { return InvalidArgument( @@ -2119,15 +2372,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The 'a' argument to Cholesky must have rank >= 2, got shape %s", a.ToString()); } - if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + if (!CompatibleDimensionSizes(a.dimensions(a.rank() - 2), + a.dimensions(a.rank() - 1))) { return InvalidArgument( - "The two minor dimensions of 'a' must have equal size, got %s.", + "The two minor dimensions of 'a' must have compatible size, got %s.", a.ToString()); } return a; } -/* static */ StatusOr ShapeInference::InferAllGatherShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count) { TF_RET_CHECK(all_gather_dimension >= 0); @@ -2151,7 +2405,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(output_shapes); } -/* static */ StatusOr ShapeInference::InferAllGatherStartShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherStartShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count) { TF_ASSIGN_OR_RETURN( @@ -2166,12 +2420,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs({&input_shape, &ag_shape}); } -/* static */ StatusOr ShapeInference::InferAllGatherDoneShape( +/* static */ absl::StatusOr ShapeInference::InferAllGatherDoneShape( const Shape& all_gather_start_shape) { return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1); } -/* static */ StatusOr ShapeInference::InferAllReduceShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( @@ -2183,7 +2437,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes); } -/* static */ StatusOr ShapeInference::InferReduceScatterShape( +/* static */ absl::StatusOr ShapeInference::InferReduceScatterShape( absl::Span operand_shapes, int64_t scatter_dimension, int64_t shard_count) { TF_RET_CHECK(scatter_dimension >= 0); @@ -2218,18 +2472,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(output_shapes); } -/* static */ StatusOr ShapeInference::InferAllReduceStartShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceStartShape( absl::Span operand_shapes) { return InferAllReduceShape(operand_shapes); } -/* static */ StatusOr ShapeInference::InferAllReduceDoneShape( +/* static */ absl::StatusOr ShapeInference::InferAllReduceDoneShape( const Shape& operand_shape) { // The returned value from AllReduceDone is the operand forwarded. return operand_shape; } -/* static */ StatusOr ShapeInference::InferAllToAllShape( +/* static */ absl::StatusOr ShapeInference::InferAllToAllShape( const Shape& shape, int64_t split_dimension, int64_t concat_dimension, int64_t split_count) { TF_RET_CHECK(split_count > 0); @@ -2256,7 +2510,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); } -/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( +/* static */ absl::StatusOr ShapeInference::InferAllToAllTupleShape( absl::Span operand_shapes) { // An Alltoall HLO instruction receives N operands (with the same shape) and // returns a tuple that contains N array shapes. @@ -2274,8 +2528,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ absl::StatusOr +ShapeInference::InferCollectiveBroadcastShape( + absl::Span operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectArray(*(operand_shapes[0]), "operand of collective-broadcast")); + return *(operand_shapes[0]); +} -/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( +/* static */ absl::StatusOr ShapeInference::InferCollectivePermuteShape( absl::Span operand_shapes) { if (operand_shapes.size() == 1) { TF_RETURN_IF_ERROR( @@ -2287,7 +2548,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferCollectivePermuteStartShape( +/* static */ absl::StatusOr +ShapeInference::InferCollectivePermuteStartShape( absl::Span operand_shapes, absl::Span context_shapes) { absl::InlinedVector shapes; @@ -2304,13 +2566,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShapeWithPtrs(shapes); } -/* static */ StatusOr ShapeInference::InferCollectivePermuteDoneShape( - const Shape& operand_shape) { +/* static */ absl::StatusOr +ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { TF_RET_CHECK(operand_shape.IsTuple()); return ShapeUtil::GetTupleElementShape(operand_shape, 1); } -/* static */ StatusOr ShapeInference::InferReduceShape( +/* static */ absl::StatusOr ShapeInference::InferReduceShape( absl::Span arg_shapes, absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { @@ -2329,8 +2591,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64_t i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( - "All reduced tensors must have the same dimension. Tensor 0 has " - "shape %s, Tensor %d has shape %s", + "All reduced tensors must have compatible dimension. Tensor at index " + "0 has shape %s, and tensor at index %d has shape %s.", ShapeUtil::HumanString(*reduced_args[0]), i, ShapeUtil::HumanString(*reduced_args[i])); } @@ -2387,7 +2649,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, @@ -2396,7 +2658,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferReduceWindowShape(operand_shape, init_value_shape, window); } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window, const ProgramShape& to_apply_shape) { @@ -2437,7 +2699,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } -/* static */ StatusOr ShapeInference::InferReduceWindowShape( +/* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); @@ -2445,7 +2707,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, init_value_shape.element_type()); } -/* static */ StatusOr ShapeInference::InferSelectAndScatterShape( +/* static */ absl::StatusOr ShapeInference::InferSelectAndScatterShape( const Shape& operand_shape, const ProgramShape& select_shape, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { @@ -2504,7 +2766,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( +/* static */ absl::StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64_t dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", @@ -2523,7 +2785,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(S32, {}); } -/* static */ StatusOr ShapeInference::InferSetDimensionSizeShape( +/* static */ absl::StatusOr ShapeInference::InferSetDimensionSizeShape( const Shape& shape, const Shape& val_shape, int64_t dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("SetDimensionSize dimension out of bounds: %d.", @@ -2549,14 +2811,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferWindowFromDimensions( +/* static */ absl::StatusOr ShapeInference::InferWindowFromDimensions( absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, std::optional> window_reversal) { - const auto verify_size = [&](const size_t x, const char* x_name) { + const auto verify_size = [&](const size_t x, + const char* x_name) -> absl::Status { if (x == 0 || x == window_dimensions.size()) { return OkStatus(); } else { @@ -2611,7 +2874,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return window; } -/* static */ StatusOr ShapeInference::InferSliceShape( +/* static */ absl::StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { auto error = [&](const std::string& message) { @@ -2652,11 +2915,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (start_index < 0) { return InvalidArgument("Negative start index to slice: %d.", start_index); } - if (limit_index > arg.dimensions(dimension)) { + int64_t dimension_size = arg.dimensions(dimension); + if (!arg.is_unbounded_dynamic_dimension(dimension) && + limit_index > dimension_size) { return error( StrFormat("limit index (%d) must be less than or equal to dimension " "size (%d)", - limit_index, arg.dimensions(dimension))); + limit_index, dimension_size)); } VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); @@ -2678,13 +2943,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (sizes[i] == 1) { continue; } - is_dynamic[i] = arg.is_dynamic_dimension(i); + is_dynamic[i] = arg.is_bounded_dynamic_dimension(i); } return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); } -/* static */ StatusOr ShapeInference::InferDynamicSliceShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); @@ -2775,7 +3040,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } - if (slice_dim_size > input_dim_size) { + if (!IsUnboundedDynamicSize(input_dim_size) && + slice_dim_size > input_dim_size) { return InvalidArgument( "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); @@ -2797,7 +3063,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( @@ -2934,7 +3200,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result_shape; } -/*static */ StatusOr ShapeInference::InferReverseShape( +/*static */ absl::StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { @@ -2950,7 +3216,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } -/* static */ StatusOr ShapeInference::InferGetTupleElementShape( +/* static */ absl::StatusOr ShapeInference::InferGetTupleElementShape( const Shape& arg, int64_t index) { if (!arg.IsTuple()) { return InvalidArgument( @@ -2968,7 +3234,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return arg.tuple_shapes(index); } -/* static */ StatusOr ShapeInference::InferWhileShape( +/* static */ absl::StatusOr ShapeInference::InferWhileShape( const ProgramShape& condition, const ProgramShape& body, const Shape& init) { // Check the number of parameters for given computations. @@ -3005,7 +3271,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return init; } -/* static */ StatusOr ShapeInference::InferConditionalShape( +/* static */ absl::StatusOr ShapeInference::InferConditionalShape( const Shape& branch_index, absl::Span branch_computations, absl::Span branch_operands) { @@ -3080,10 +3346,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes) { + // This method is used to infer shape for xla::BroadcastInDim. TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); + TF_RET_CHECK(!operand.is_unbounded_dynamic()); for (int64_t size : broadcast_sizes) { + if (size == Shape::kUnboundedSize) { + return InvalidArgument("Non-broadcast dimensions must not be dynamic."); + } if (size < 0) { return InvalidArgument("Broadcast with negative dimension size %d.", size); @@ -3105,11 +3376,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } -/* static */ StatusOr ShapeInference::InferBroadcastShape( +/* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand_shape, const Shape& output_shape, absl::Span broadcast_dimensions) { + // This method is used to infer shape for xla::BroadcastInDim. TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); + TF_RET_CHECK(!output_shape.is_unbounded_dynamic()); const int64_t operand_rank = operand_shape.rank(); const int64_t output_rank = output_shape.rank(); if (operand_rank > output_rank) { @@ -3129,7 +3402,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Broadcast dimension %lld is out of bound", broadcast_dimensions[i]); } - if (operand_shape.dimensions(i) != + if (!operand_shape.is_unbounded_dynamic_dimension(i) && + operand_shape.dimensions(i) != output_shape.dimensions(broadcast_dimensions[i]) && operand_shape.dimensions(i) != 1) { return InvalidArgument( @@ -3139,8 +3413,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); } - if (operand_shape.is_dynamic_dimension(i) != - output_shape.is_dynamic_dimension(broadcast_dimensions[i])) { + if (!operand_shape.is_unbounded_dynamic_dimension(i) && + operand_shape.is_bounded_dynamic_dimension(i) != + output_shape.is_bounded_dynamic_dimension( + broadcast_dimensions[i])) { return InvalidArgument( "Broadcast input and output dynamism mismatch: %s and %s", operand_shape.ToString(), output_shape.ToString()); @@ -3157,7 +3433,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } -/* static */ StatusOr ShapeInference::InferDynamicReshapeShape( +/* static */ absl::StatusOr ShapeInference::InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic) { @@ -3189,16 +3465,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return inferred_shape; } -/* static */ StatusOr ShapeInference::InferReshapeShape( +/* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, absl::Span new_sizes, int64_t inferred_dimension) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); - Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); VLOG(3) << "Reshape inferred shape: " << ShapeUtil::HumanString(inferred_shape); + TF_RET_CHECK(!inferred_shape.is_unbounded_dynamic()) + << "Reshaping with unbounded result shape is not supported."; + if (operand.is_unbounded_dynamic()) { + TF_RET_CHECK(!operand.is_bounded_dynamic()) + << "Reshape operand with bounded and unbounded dynamism not supported."; + return inferred_shape; + } + if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( "Reshape operation has mismatched element counts: from=%d (%s) " @@ -3341,7 +3624,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return inferred_shape; } -/* static */ StatusOr ShapeInference::InferTransposeShape( +/* static */ absl::StatusOr ShapeInference::InferTransposeShape( const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); @@ -3358,22 +3641,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::PermuteDimensions(dimensions, operand); } -/* static */ StatusOr ShapeInference::InferClampShape( +/* static */ absl::StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || - !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) { + // min, operand, and max must have compatible element types. + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(min, max)) { return InvalidArgument( - "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min), - ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max)); + "Clamp with incompatible element types: %s, %s, % s.", + ShapeUtil::HumanString(min), ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); + } + + if ((!ShapeUtil::IsScalar(min) && + !ShapeUtil::CompatibleIgnoringFpPrecision(min, operand)) || + (!ShapeUtil::IsScalar(max) && + !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) || + (!ShapeUtil::IsScalar(min) && !ShapeUtil::IsScalar(max) && + !ShapeUtil::CompatibleIgnoringFpPrecision(min, max))) { + return InvalidArgument("Clamp with incompatible shapes: %s, %s, %s.", + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } return operand; } -/* static */ StatusOr ShapeInference::InferSelectShape( +/* static */ absl::StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred")); TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true")); @@ -3384,33 +3682,52 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } + if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred)); } - if (!Shape::Equal() - .IgnoreElementType() - .IgnoreLayout() - .IgnoreDynamicDimension()(pred, on_true)) { + + // If pred is not scalar, it must be compatible with on_true and on_false + if ((!ShapeUtil::IsScalar(pred) && + (!ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + !ShapeUtil::CompatibleIgnoringElementType(pred, on_false))) || + !ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select and predicate must be the same shape; got %s and " - "%s.", - ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred)); + "%s and %s.", + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false), + ShapeUtil::HumanString(pred)); } + Shape full_rank_shape = ShapeUtil::IsScalar(pred) ? on_true : pred; Shape result = ShapeUtil::ChangeElementType( - pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - for (int64_t dimension = 0; dimension < pred.rank(); ++dimension) { - result.set_dynamic_dimension(dimension, - pred.is_dynamic_dimension(dimension) || - on_true.is_dynamic_dimension(dimension) || - on_false.is_dynamic_dimension(dimension)); + full_rank_shape, + ShapeUtil::HigherPrecisionElementType(on_true, on_false)); + for (int64_t dimension = 0; dimension < full_rank_shape.rank(); ++dimension) { + if (on_true.is_unbounded_dynamic_dimension(dimension) || + on_false.is_unbounded_dynamic_dimension(dimension)) { + absl::StatusOr inferred = InferMostSpecificDimAndBound( + dimension, on_true.dimensions(dimension), + on_false.dimensions(dimension), on_true.dimensions(dimension), + on_false.dimensions(dimension)); + result.set_dimensions(dimension, (*inferred).dimension); + result.set_dynamic_dimension( + dimension, on_true.is_dynamic_dimension(dimension) && + on_false.is_dynamic_dimension(dimension)); + } else { + result.set_dynamic_dimension( + dimension, (!ShapeUtil::IsScalar(pred) && + pred.is_dynamic_dimension(dimension)) || + on_true.is_dynamic_dimension(dimension) || + on_false.is_dynamic_dimension(dimension)); + } } return std::move(result); } -/* static */ StatusOr ShapeInference::InferCallShape( +/* static */ absl::StatusOr ShapeInference::InferCallShape( absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { @@ -3474,8 +3791,10 @@ static Status ValidateGatherDimensionNumbers( } } - if (dim_numbers.start_index_map_size() != - start_indices_shape[dim_numbers.index_vector_dim()]) { + if (!IsUnboundedDynamicSize( + start_indices_shape[dim_numbers.index_vector_dim()]) && + dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Gather op has %d elements in start_index_map and the " "bound of dimension index_vector_dim=%d of start_indices is " @@ -3534,7 +3853,7 @@ static Status ValidateGatherDimensionNumbers( return OkStatus(); } -/*static*/ StatusOr ShapeInference::InferGatherShape( +/*static*/ absl::StatusOr ShapeInference::InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes) { @@ -3603,6 +3922,7 @@ static Status ValidateGatherDimensionNumbers( } for (int i = 0; i < slice_sizes.size(); i++) { + if (input_shape.is_unbounded_dynamic_dimension(i)) continue; int64_t slice_size = slice_sizes[i]; int64_t corresponding_input_size = input_shape.dimensions(i); if (slice_size < 0 || slice_size > corresponding_input_size) { @@ -3732,8 +4052,9 @@ Status ValidateScatterDimensionNumbers( } // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. - if (dim_numbers.scatter_dims_to_operand_dims_size() != - scatter_indices_shape[dim_numbers.index_vector_dim()]) { + if (!CompatibleDimensionSizes( + dim_numbers.scatter_dims_to_operand_dims_size(), + scatter_indices_shape[dim_numbers.index_vector_dim()])) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " "bound of dimension index_vector_dim=%d of scatter_indices is %d. " @@ -3770,7 +4091,7 @@ Status ValidateScatterDimensionNumbers( } // namespace -/*static*/ StatusOr ShapeInference::InferScatterShape( +/*static*/ absl::StatusOr ShapeInference::InferScatterShape( absl::Span arg_shapes, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers) { @@ -3865,8 +4186,9 @@ Status ValidateScatterDimensionNumbers( if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { ++scatter_dims_seen; } - if (updates_shape.dimensions(i) != - expanded_scatter_indices_shape[scatter_dims_seen]) { + if (!CompatibleDimensionSizes( + updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen])) { return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. For " diff --git a/xla/service/shape_inference.h b/xla/service/shape_inference.h index e04a59a2bf53c..f233bb045fc4c 100644 --- a/xla/service/shape_inference.h +++ b/xla/service/shape_inference.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,13 +19,16 @@ limitations under the License. #ifndef XLA_SERVICE_SHAPE_INFERENCE_H_ #define XLA_SERVICE_SHAPE_INFERENCE_H_ +#include +#include +#include #include #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" #include "xla/statusor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { @@ -44,142 +47,152 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(HloOpcode opcode, - const Shape& shape); - static StatusOr InferUnaryOpShape(HloOpcode opcode, - const HloInstruction* operand); + static absl::StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); + static absl::StatusOr InferUnaryOpShape(HloOpcode opcode, + const HloInstruction* operand); + + // For ternary ops, only scalar broadcasting is supported. + // Return the non-scalar shape that all scalars should be broadcasted too + // Returns status if non-scalar operands do not match. + // Returns first shape when all shapes are scalar. + static absl::StatusOr> InferScalarBroadcastShape( + absl::Span shapes); // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( + static absl::StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions); - static StatusOr InferBinaryOpShape(HloOpcode opcode, - const HloInstruction* lhs, - const HloInstruction* rhs); + static absl::StatusOr InferBinaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs); // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, - const Shape& rhs, - const Shape& ehs); - static StatusOr InferTernaryOpShape(HloOpcode opcode, - const HloInstruction* lhs, - const HloInstruction* rhs, - const HloInstruction* ehs); + static absl::StatusOr InferTernaryOpShape(HloOpcode opcode, + const Shape& lhs, + const Shape& rhs, + const Shape& ehs); + static absl::StatusOr InferTernaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs, + const HloInstruction* ehs); // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( + static absl::StatusOr InferVariadicOpShape( HloOpcode opcode, absl::Span operand_shapes); - static StatusOr InferVariadicOpShape( + static absl::StatusOr InferVariadicOpShape( HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. - static StatusOr InferMapShape( + static absl::StatusOr InferMapShape( absl::Span arg_shapes, const ProgramShape& to_apply, absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. - static StatusOr InferBatchNormTrainingShape(const Shape& operand_shape, - const Shape& scale_shape, - const Shape& offset_shape, - int64_t feature_index); + static absl::StatusOr InferBatchNormTrainingShape( + const Shape& operand_shape, const Shape& scale_shape, + const Shape& offset_shape, int64_t feature_index); // Infers the shape produced by InferBatchNormInference with the given // operands. - static StatusOr InferBatchNormInferenceShape( + static absl::StatusOr InferBatchNormInferenceShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64_t feature_index); // Infers the shape produced by InferBatchNormGrad with the given operands. - static StatusOr InferBatchNormGradShape(const Shape& operand_shape, - const Shape& scale_shape, - const Shape& mean_shape, - const Shape& var_shape, - const Shape& output_grad_shape, - int64_t feature_index); + static absl::StatusOr InferBatchNormGradShape( + const Shape& operand_shape, const Shape& scale_shape, + const Shape& mean_shape, const Shape& var_shape, + const Shape& output_grad_shape, int64_t feature_index); // Infers the shape produced by applying the given convolutional filter (rhs) // to lhs in the way specified by the fields on window. An optional // preferred_element_type can be specified to upcast the element type. - static StatusOr InferConvolveShape( + static absl::StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, int64_t feature_group_count, int64_t batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, std::optional preferred_element_type); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr InferFftShape(const Shape& in, FftType fft_type, - absl::Span fft_length); + static absl::StatusOr InferFftShape( + const Shape& in, FftType fft_type, absl::Span fft_length); // Infers the shape produced by the given triangular solve operation. - static StatusOr InferTriangularSolveShape( + static absl::StatusOr InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options); // Infers the shape produced by the given triangular solve operation. - static StatusOr InferCholeskyShape(const Shape& a); + static absl::StatusOr InferCholeskyShape(const Shape& a); // Infers the shape produced by an all-gather with the given operand shape, // concat dimension, and shard count. - static StatusOr InferAllGatherShape( + static absl::StatusOr InferAllGatherShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count); // Infers the shape produced by an all-gather-start with the given operand // shape, concat dimension, and shard count. - static StatusOr InferAllGatherStartShape( + static absl::StatusOr InferAllGatherStartShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count); // Infers the shape produced by an all-gather-done given a certain // all-gather-start shape. - static StatusOr InferAllGatherDoneShape( + static absl::StatusOr InferAllGatherDoneShape( const Shape& all_gather_start_shape); // Infers the shape produced by a cross replica sum with the given operand // shapes. - static StatusOr InferAllReduceShape( + static absl::StatusOr InferAllReduceShape( absl::Span operand_shapes); // Infers the shape produced by a reduce-scatter with the given operand // shape, scatter dimension, and shard count. - static StatusOr InferReduceScatterShape( + static absl::StatusOr InferReduceScatterShape( absl::Span operand_shapes, int64_t scatter_dimension, int64_t shard_count); // Infers the shape produced by a cross replica sum start. - static StatusOr InferAllReduceStartShape( + static absl::StatusOr InferAllReduceStartShape( absl::Span operand_shapes); // Infers the shape produced by a cross replica sum done. - static StatusOr InferAllReduceDoneShape(const Shape& operand_shape); + static absl::StatusOr InferAllReduceDoneShape( + const Shape& operand_shape); // Infers final shape of an Alltoall operation that is created by the xla // builder. - static StatusOr InferAllToAllShape(const Shape& shape, - int64_t split_dimension, - int64_t concat_dimension, - int64_t split_count); + static absl::StatusOr InferAllToAllShape(const Shape& shape, + int64_t split_dimension, + int64_t concat_dimension, + int64_t split_count); // Infers the shape of an HLO all-to-all instruction. - static StatusOr InferAllToAllTupleShape( + static absl::StatusOr InferAllToAllTupleShape( + absl::Span operand_shapes); + + // Infers the shape of a collective broadcast operation. + static absl::StatusOr InferCollectiveBroadcastShape( absl::Span operand_shapes); // Infers the shape of a collective permute operation. - static StatusOr InferCollectivePermuteShape( + static absl::StatusOr InferCollectivePermuteShape( absl::Span operand_shapes); // Infers the shape of a collective permute start operation. - static StatusOr InferCollectivePermuteStartShape( + static absl::StatusOr InferCollectivePermuteStartShape( absl::Span operand_shapes, absl::Span context_shapes); // Infers the shape of a collective permute operation. - static StatusOr InferCollectivePermuteDoneShape( + static absl::StatusOr InferCollectivePermuteDoneShape( const Shape& operand_shape); // Infers the shape produced by applying the given reduction computation @@ -188,58 +201,57 @@ class ShapeInference { // If pass_index is true, the reduce function is invoked with the element // index as the leading parameter, and the program shape should match // accordingly (or an error will result). - static StatusOr InferReduceShape( + static absl::StatusOr InferReduceShape( absl::Span arg_shapes, absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand // shape with the given window and stride dimensions. - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value, const Window& window, const ProgramShape& to_apply_shape); - static StatusOr InferReduceWindowShape(const Shape& operand_shape, - const Shape& init_value, - const Window& window); - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( + const Shape& operand_shape, const Shape& init_value, + const Window& window); + static absl::StatusOr InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window, const ProgramShape& to_apply_shape); - static StatusOr InferReduceWindowShape( + static absl::StatusOr InferReduceWindowShape( absl::Span operands, absl::Span init_values, const Window& window); // Infers the shape produced by scattering the given source shape to the // selected indices of each window on the operand shape. - static StatusOr InferSelectAndScatterShape( + static absl::StatusOr InferSelectAndScatterShape( const Shape& operand_shape, const ProgramShape& select_shape, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape); // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr InferReverseShape( + static absl::StatusOr InferReverseShape( const Shape& operand_shape, absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape(const Shape& arg, - absl::Span starts, - absl::Span limits, - absl::Span strides); + static absl::StatusOr InferSliceShape( + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. - static StatusOr InferDynamicSliceShape( + static absl::StatusOr InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. - static StatusOr InferDynamicUpdateSliceShape( + static absl::StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, absl::Span start_index_shapes, bool allow_scalar_indices = true); @@ -248,108 +260,113 @@ class ShapeInference { // the given input shape. This is essential for operations on tuples, because // it is impossible to infer the type that comes out of the tuple indexing if // it is not a compile time constant. - static StatusOr InferGetTupleElementShape(const Shape& arg, - int64_t index); + static absl::StatusOr InferGetTupleElementShape(const Shape& arg, + int64_t index); // Infers the shape produced from a while node. condition and body are the // shapes of computations for the condition and the body of a while node, and // init is the shape of data initially passed in to the body as an argument. // The shapes must match; condition: T -> PRED, body: T -> T, init: T - static StatusOr InferWhileShape(const ProgramShape& condition, - const ProgramShape& body, - const Shape& init); + static absl::StatusOr InferWhileShape(const ProgramShape& condition, + const ProgramShape& body, + const Shape& init); // Infers the shape produced by a predicated or indexed conditional operation. - static StatusOr InferConditionalShape( + static absl::StatusOr InferConditionalShape( const Shape& branch_index, absl::Span branch_computations, absl::Span branch_operands); // Infers the shape produced by a broadcast operation. - static StatusOr InferBroadcastShape( + static absl::StatusOr InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes); // Checks whether the given parameters can form a broadcast. Returns the same // output_shape if it's legal. - static StatusOr InferBroadcastShape( + static absl::StatusOr InferBroadcastShape( const Shape& operand_shape, const Shape& output_shape, absl::Span broadcast_dimensions); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape(const Shape& operand, - absl::Span dimensions, - absl::Span new_sizes, - int64_t inferred_dimension); + static absl::StatusOr InferReshapeShape( + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes, int64_t inferred_dimension); // Infers the shape produced by a dynamic reshape operation from the element // type of its operand and the new dimension sizes specified. The result shape // will have dynamic dimensions as specific in `dim_is_dynamic` and bound // `new_size_bounds`. - static StatusOr InferDynamicReshapeShape( + static absl::StatusOr InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, const std::vector& dims_are_dynamic); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. - static StatusOr InferTransposeShape( + static absl::StatusOr InferTransposeShape( const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. - static StatusOr InferConcatOpShape( + static absl::StatusOr InferConcatOpShape( absl::Span arg_shapes, int64_t dimension); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. - static StatusOr InferConvertShape(const Shape& operand_shape, - PrimitiveType new_element_type); + static absl::StatusOr InferConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); // Helper that validates the given operand shape can be bitcast converted to // the target output_shape via a bitcast convert instruction -- the // requirement is that the shape is identical except for the element type and // the element types have identical bit-widths. - static StatusOr InferBitcastConvertShape( + static absl::StatusOr InferBitcastConvertShape( const Shape& operand_shape, PrimitiveType new_element_type); // Helper that validates the given operand shape can be converted to the // target output_shape via a stochastic convert instruction -- the requirement // is that the shape is identical except for the element type. - static StatusOr InferStochasticConvertShape( + static absl::StatusOr InferStochasticConvertShape( const Shape& operand_shape, const Shape& random_shape, PrimitiveType new_element_type); // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. - static StatusOr InferReducePrecisionShape(const Shape& operand_shape, - const int exponent_bits, - const int mantissa_bits); + static absl::StatusOr InferReducePrecisionShape( + const Shape& operand_shape, const int exponent_bits, + const int mantissa_bits); // Helper that infers the shape produced by a pad operation based on the // padding configuration. - static StatusOr InferPadShape(const Shape& operand_shape, - const Shape& padding_value_shape, - const PaddingConfig& padding_config); + static absl::StatusOr InferPadShape( + const Shape& operand_shape, const Shape& padding_value_shape, + const PaddingConfig& padding_config); // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. - static StatusOr InferCallShape( + static absl::StatusOr InferCallShape( absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. An optional preferred_element_type can be // specified to upcast the element type. - static StatusOr InferDotOpShape( + static absl::StatusOr InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, - std::optional preferred_element_type); + std::optional preferred_element_type, + absl::Span sparsity = {}); + + // Helper that infers the shape of the sparse dot metadata. + static absl::StatusOr InferSparseDotMetadataShape( + const Shape& operand_shape, const DotDimensionNumbers& dimension_numbers, + const SparsityDescriptor& sparsity, PrimitiveType element_type = U16); // Helper that infers the shape of the tensor produced by a gather operation // with the given input shape, gather indices shape and gather dimension // numbers. - static StatusOr InferGatherShape( + static absl::StatusOr InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes); @@ -357,25 +374,25 @@ class ShapeInference { // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, // and returns the result shape of the scatter operation. - static StatusOr InferScatterShape( + static absl::StatusOr InferScatterShape( absl::Span arg_shapes, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); // Helper that validates the given input shape to GetDimensionSize. - static StatusOr InferGetDimensionSizeShape(const Shape& shape, - int64_t dimension); + static absl::StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64_t dimension); // Helper that validates the given input shape to SetDimensionSize. - static StatusOr InferSetDimensionSizeShape(const Shape& operand_shape, - const Shape& val_shape, - int64_t dimension); + static absl::StatusOr InferSetDimensionSizeShape( + const Shape& operand_shape, const Shape& val_shape, int64_t dimension); - static StatusOr InferTopKShape(const Shape& operand_shape, int64_t k); + static absl::StatusOr InferTopKShape(const Shape& operand_shape, + int64_t k); // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - static StatusOr InferWindowFromDimensions( + static absl::StatusOr InferWindowFromDimensions( absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, @@ -389,31 +406,35 @@ class ShapeInference { // Note: By "element-wise" we mean operations that look at a single element in // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. - static StatusOr InferElementwiseBinaryOpShape( + static absl::StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. - static StatusOr InferClampShape(const Shape& min, const Shape& operand, - const Shape& max); + static absl::StatusOr InferClampShape(const Shape& min, + const Shape& operand, + const Shape& max); // Helper for inferring the shape of Select ops. - static StatusOr InferSelectShape(const Shape& pred, - const Shape& on_true, - const Shape& on_false); + static absl::StatusOr InferSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). - static StatusOr InferDegenerateDimensionBroadcastShape( - HloOpcode operation, const Shape& lhs, const Shape& rhs); + static absl::StatusOr InferDegenerateDimensionBroadcastShape( + const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations // (for example ComputationBuilder::AddInDim). smaller_shape must be a // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. - static StatusOr InferInDimBroadcastShape( + // + // Since this method is only used by InferBinaryOpShape transitively, this + // method also supports inference of unbounded dynamic dimensions. + static absl::StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, absl::Span broadcast_dimensions); diff --git a/xla/service/shape_inference_test.cc b/xla/service/shape_inference_test.cc index 2068bb0ec3661..cec0ea13048f7 100644 --- a/xla/service/shape_inference_test.cc +++ b/xla/service/shape_inference_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,18 @@ limitations under the License. #include "xla/service/shape_inference.h" +#include +#include +#include +#include #include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" @@ -29,11 +36,11 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -41,6 +48,12 @@ namespace { using ::testing::ContainsRegex; using ::testing::HasSubstr; +constexpr absl::string_view kBroadcastDimensionMismatchErrorMessage = + "Broadcast dimension 0 mismatch"; +constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = + "Binary op with incompatible shapes"; +std::array zero_array = {0}; + class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. @@ -72,11 +85,11 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { const Shape& expected_inferred_shape, const Shape& arg, absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); - auto inferred_status = ShapeInference::InferReduceShape( - {&arg, &f32_}, dimensions_to_reduce, to_apply); - EXPECT_IS_OK(inferred_status.status()); - EXPECT_TRUE( - ShapeUtil::Equal(expected_inferred_shape, inferred_status.value())); + const absl::StatusOr inferred_shape = + ShapeInference::InferReduceShape({&arg, &f32_}, dimensions_to_reduce, + to_apply); + EXPECT_IS_OK(inferred_shape.status()); + EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, *inferred_shape)); } }; @@ -110,139 +123,186 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { ProgramShape scatter_program_shape_; }; +struct BinaryOpTestCase { + std::string lhs; + std::string rhs; + absl::Span broadcast_dimensions; + std::string expected; + std::optional error_message; +}; + +// Subclass for testing unbounded dynamic logical ops +class UnboundedLogicalOpShapeInferenceTest + : public ::testing::TestWithParam {}; + // Subclass for testing unbounded dynamic binary ops class UnboundedBinaryOpShapeInferenceTest + : public ::testing::TestWithParam {}; + +// Subclass for testing unbounded dynamic compare op +class UnboundedCompareOpShapeInferenceTest + : public ::testing::TestWithParam {}; + +// Subclass for testing unbounded dynamic complex op +class UnboundedComplexOpShapeInferenceTest + : public ::testing::TestWithParam {}; + +// Subclass for testing unbounded dynamic concatenate op +class UnboundedConcatenateOpShapeInferenceTest : public ::testing::TestWithParam> {}; +struct UnaryOpTestCase { + std::string operand; + std::string expected; + HloOpcode opcode; +}; + // Subclass for testing unbounded dynamic unary ops class UnboundedUnaryOpShapeInferenceTest + : public ::testing::TestWithParam {}; + +// Subclass for testing unbounded dynamic clamp op +class UnboundedClampOpShapeInferenceTest + : public ::testing::TestWithParam> {}; + +// Subclass for testing unbounded dynamic select op +class UnboundedSelectOpShapeInferenceTest : public ::testing::TestWithParam> {}; TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.value())); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { - Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, pred_, tuple, tuple); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + const Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, tuple, + tuple); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Expected array argument for select")); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT( - inferred_status.status().message(), - HasSubstr("Operands to select and predicate must be the same shape")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, + matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { - auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.value())); + const Shape predarray = ShapeUtil::MakeShape(PRED, {64, 48}); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, predarray, + matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectBadShapes) { - auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, + matrix_64_48_, matrix_32_64_); + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Operands to select must be the same shape")); - auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, s32_, + matrix_64_48_, matrix_64_48_); + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("pred operand must have PRED")); - auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, - matrix_64_48_); - ASSERT_FALSE(inferred_status_error3.ok()); + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, + ShapeUtil::MakeShape(PRED, {64}), + matrix_64_48_, matrix_64_48_); + ASSERT_FALSE(inferred_shape_error3.ok()); ASSERT_THAT( - inferred_status_error3.status().message(), + inferred_shape_error3.status().message(), HasSubstr("Operands to select and predicate must be the same shape")); // Tuples have a TUPLE element type and cannot be the pred of a select. - auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), - ShapeUtil::MakeTupleShape({f32_, f32_}), - ShapeUtil::MakeTupleShape({f32_, f32_})); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferTernaryOpShape( + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), + ShapeUtil::MakeTupleShape({f32_, f32_}), + ShapeUtil::MakeTupleShape({f32_, f32_})); + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("Expected array argument for select pred")); } TEST_F(ShapeInferenceTest, ClampAllMatrix) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.value())); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, + matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.value())); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampMinScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, + matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, + f32_, matrix_64_48_); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), + HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, matrix_64_48_, f32_, f32_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, + f32_, f32_); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), + HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, f32_, f32_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, + matrix_64_48_); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), + HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - HloOpcode::kClamp, f32_, matrix_64_48_, f32_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Clamp with different shapes")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampBadShapes) { @@ -279,8 +339,8 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { } TEST_F(ShapeInferenceTest, Complex) { - auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - absl::Span bcast) { + const auto complex_shape = [&](const Shape& lhs, const Shape& rhs, + absl::Span bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -292,7 +352,7 @@ TEST_F(ShapeInferenceTest, Complex) { // Only F32->C64 and F64->C128 supported. ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok()); // Validate correct uses. - Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); + const Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {}))); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); @@ -302,7 +362,7 @@ TEST_F(ShapeInferenceTest, Complex) { TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); - Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64}); + const Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64}); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_64_, matrix_32_64_, {1})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); @@ -320,15 +380,15 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = + const absl::StatusOr result = ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); - ASSERT_TRUE(ShapeUtil::Equal(result.value(), - ShapeUtil::MakeTupleShape({s32_, f32_}))); + ASSERT_TRUE( + ShapeUtil::Equal(*result, ShapeUtil::MakeTupleShape({s32_, f32_}))); } TEST_F(ShapeInferenceTest, ReduceWindowInHalf) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8}); + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8}); Window window; WindowDimension dim; dim.set_size(2); @@ -339,79 +399,85 @@ TEST_F(ShapeInferenceTest, ReduceWindowInHalf) { dim.set_base_dilation(1); *window.add_dimensions() = dim; *window.add_dimensions() = dim; - Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2}); - Shape init_value_shape = ShapeUtil::MakeShape(F32, {}); - Shape float_scalar = ShapeUtil::MakeShape(F32, {}); + const Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2}); + const Shape init_value_shape = ShapeUtil::MakeShape(F32, {}); + const Shape float_scalar = ShapeUtil::MakeShape(F32, {}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); - auto inferred_status = ShapeInference::InferReduceWindowShape( - matrix_shape, init_value_shape, window, to_apply); + const absl::StatusOr inferred_shape = + ShapeInference::InferReduceWindowShape(matrix_shape, init_value_shape, + window, to_apply); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), *inferred_shape)); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) { - auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_, window_, source_shape_, - init_value_shape_, scatter_program_shape_); - ASSERT_IS_OK(inferred_status_ok.status()); - Shape inferred = inferred_status_ok.value(); - ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred)); + const absl::StatusOr inferred_shape_ok = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_IS_OK(inferred_shape_ok.status()); + ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, *inferred_shape_ok)); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { - Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6}); - auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_, window_, source_shape_fail, - init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + const Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6}); + const absl::StatusOr inferred_shape_fail = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_, window_, source_shape_fail, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_); - auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_fail, window_, source_shape_, - init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + const absl::StatusOr inferred_shape_fail = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); - auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_fail, window_, source_shape_, - init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + const absl::StatusOr inferred_shape_fail = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); - auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_fail, window_, source_shape_, - init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + const absl::StatusOr inferred_shape_fail = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_); - auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape( - operand_shape_, select_program_shape_fail, window_, source_shape_, - init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + const absl::StatusOr inferred_shape_fail = + ShapeInference::InferSelectAndScatterShape( + operand_shape_, select_program_shape_fail, window_, source_shape_, + init_value_shape_, scatter_program_shape_); + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function's second parameter")); } @@ -420,10 +486,11 @@ TEST_F(ShapeInferenceTest, AllGatherStart) { const Shape expected_shape = ShapeUtil::MakeTupleShape( {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})}); - auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape( - {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8); + const absl::StatusOr inferred_ag_shape = + ShapeInference::InferAllGatherStartShape( + {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8); EXPECT_TRUE(inferred_ag_shape.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_shape.value(), expected_shape)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_ag_shape, expected_shape)); } TEST_F(ShapeInferenceTest, AllGatherStartMultiOperand) { @@ -438,10 +505,12 @@ TEST_F(ShapeInferenceTest, AllGatherStartMultiOperand) { ShapeUtil::MakeTupleShape( {expected_output0_shape, expected_output1_shape})}); - auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape( - {&operand0, &operand1}, /*all_gather_dimension=*/0, /*shard_count=*/8); + const absl::StatusOr inferred_ag_shape = + ShapeInference::InferAllGatherStartShape({&operand0, &operand1}, + /*all_gather_dimension=*/0, + /*shard_count=*/8); EXPECT_TRUE(inferred_ag_shape.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_shape.value(), expected_shape)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_ag_shape, expected_shape)); } TEST_F(ShapeInferenceTest, AllGatherDone) { @@ -450,10 +519,10 @@ TEST_F(ShapeInferenceTest, AllGatherDone) { ShapeUtil::MakeShape(F32, {8, 8, 4})}); const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4}); - auto inferred_ag_done_shape = + const absl::StatusOr inferred_ag_done_shape = ShapeInference::InferAllGatherDoneShape(input_shape); EXPECT_TRUE(inferred_ag_done_shape.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_done_shape.value(), expected_shape)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_ag_done_shape, expected_shape)); } TEST_F(ShapeInferenceTest, AllGatherDoneMultiOperand) { @@ -471,17 +540,17 @@ TEST_F(ShapeInferenceTest, AllGatherDoneMultiOperand) { const Shape expected_shape = ShapeUtil::MakeTupleShape( {expected_output0_shape, expected_output1_shape}); - auto inferred_ag_done_shape = + const absl::StatusOr inferred_ag_done_shape = ShapeInference::InferAllGatherDoneShape(input_shape); EXPECT_TRUE(inferred_ag_done_shape.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_done_shape.value(), expected_shape)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_ag_done_shape, expected_shape)); } TEST_F(ShapeInferenceTest, Convolve) { ConvolutionDimensionNumbers dnums; // Dimension order: batch, feature, x0, x1 - Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); @@ -492,15 +561,15 @@ TEST_F(ShapeInferenceTest, Convolve) { dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(1); dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(0); Window window; - auto dim0 = window.add_dimensions(); - auto dim1 = window.add_dimensions(); + const auto dim0 = window.add_dimensions(); + const auto dim1 = window.add_dimensions(); dim0->set_size(3); dim0->set_stride(2); dim0->set_padding_low(1); @@ -513,20 +582,21 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); + const absl::StatusOr inferred_shape = + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/std::nullopt); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), - inferred_shape)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { ConvolutionDimensionNumbers dnums; // Dimension order: batch, feature, x0, x1 - Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); @@ -537,14 +607,14 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(1); dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(0); Window window; - auto dim0 = window.add_dimensions(); + const auto dim0 = window.add_dimensions(); dim0->set_size(3); dim0->set_stride(3); dim0->set_padding_low(0); @@ -552,27 +622,28 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim0->set_window_dilation(6); dim0->set_base_dilation(1); - auto dim1 = window.add_dimensions(); + const auto dim1 = window.add_dimensions(); dim1->set_size(2); dim1->set_stride(1); dim1->set_padding_low(2); dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); + const absl::StatusOr inferred_shape = + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/std::nullopt); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), - inferred_shape)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { ConvolutionDimensionNumbers dnums; // Dimension order: batch, feature, x0, x1 - Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); @@ -583,14 +654,14 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(1); dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(0); Window window; - auto dim0 = window.add_dimensions(); + const auto dim0 = window.add_dimensions(); dim0->set_size(4); dim0->set_stride(3); dim0->set_padding_low(0); @@ -598,26 +669,27 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim0->set_window_dilation(1); dim0->set_base_dilation(6); - auto dim1 = window.add_dimensions(); + const auto dim1 = window.add_dimensions(); dim1->set_size(2); dim1->set_stride(1); dim1->set_padding_low(2); dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); + const absl::StatusOr inferred_shape = + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/std::nullopt); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), - inferred_shape)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { // Dimension order for this test: batch, feature, x0, x1 - Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(3); @@ -634,8 +706,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dnums.add_kernel_spatial_dimensions(1); Window window; - auto dim0 = window.add_dimensions(); - auto dim1 = window.add_dimensions(); + const auto dim0 = window.add_dimensions(); + const auto dim1 = window.add_dimensions(); dim0->set_size(2); dim0->set_stride(1); dim0->set_padding_low(0); @@ -644,11 +716,13 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, - window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/std::nullopt); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("each dimension exactly once")); } @@ -666,11 +740,11 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) { dnums.set_output_feature_dimension(1); dnums.add_output_spatial_dimensions(2); dnums.add_output_spatial_dimensions(3); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4}); Window window; - auto dim0 = window.add_dimensions(); - auto dim1 = window.add_dimensions(); + const auto dim0 = window.add_dimensions(); + const auto dim1 = window.add_dimensions(); dim0->set_size(4); dim1->set_size(4); dim0->set_padding_low(0); @@ -681,11 +755,13 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) { dim1->set_stride(1); dim0->set_window_dilation(3); dim1->set_window_dilation(2); - auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, - window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, + /*batch_group_count=*/6, window, dnums, + /*preferred_element_type=*/std::nullopt); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("to be a multiple of batch group count")); } @@ -738,7 +814,7 @@ ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) { TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) { ConvolveArgs args = MakeConvolveArgs(BF16, F16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -750,7 +826,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) { TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) { ConvolveArgs args = MakeConvolveArgs(F16, BF16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -762,7 +838,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) { TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) { ConvolveArgs args = MakeConvolveArgs(S32, U32); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -774,7 +850,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) { TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) { ConvolveArgs args = MakeConvolveArgs(U32, S32); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -786,7 +862,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) { TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) { ConvolveArgs args = MakeConvolveArgs(S8, S16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -798,7 +874,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) { TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) { ConvolveArgs args = MakeConvolveArgs(S8, S16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -811,7 +887,7 @@ TEST_F(ShapeInferenceTest, FloatingPointConvolveWithNarrowerPreferredElementType) { ConvolveArgs args = MakeConvolveArgs(F32, F32); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -824,7 +900,7 @@ TEST_F(ShapeInferenceTest, FloatingPointConvolveWithIntegralPreferredElementType) { ConvolveArgs args = MakeConvolveArgs(BF16, BF16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -837,7 +913,7 @@ TEST_F(ShapeInferenceTest, IntegralConvolveWithFloatingPointPreferredElementType) { ConvolveArgs args = MakeConvolveArgs(S8, S16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -850,7 +926,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeWithDifferentSignedness) { ConvolveArgs args = MakeConvolveArgs(S8, S16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -862,7 +938,7 @@ TEST_F(ShapeInferenceTest, TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) { ConvolveArgs args = MakeConvolveArgs(S8, S16); TF_ASSERT_OK_AND_ASSIGN( - Shape inferred_shape, + const Shape inferred_shape, ShapeInference::InferConvolveShape( args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, args.window, args.dnums, @@ -884,17 +960,18 @@ static const char* innermost_dimension_matches = static void Pass(const Shape& shape, FftType type, absl::Span length, const Shape& expected_shape) { - auto inferred_status = ShapeInference::InferFftShape(shape, type, length); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape)); + const absl::StatusOr inferred_shape = + ShapeInference::InferFftShape(shape, type, length); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(expected_shape, *inferred_shape)); } static void Fail(const Shape& shape, FftType type, absl::Span length, absl::string_view message) { - auto inferred_status = ShapeInference::InferFftShape(shape, type, length); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = + ShapeInference::InferFftShape(shape, type, length); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr(std::string(message))); } @@ -902,7 +979,7 @@ static void Fail(const Shape& shape, FftType type, TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) { FftType type = FftType::FFT; - Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); + const Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); fft::Fail(shape, type, {}, fft::unsupported_rank); fft::Pass(shape, type, {8}, shape); fft::Pass(shape, type, {16, 8}, shape); @@ -912,15 +989,15 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) { TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) { FftType type = FftType::FFT; - Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + const Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) { FftType type = FftType::IFFT; - Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); + const Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); fft::Fail(shape, type, {}, fft::unsupported_rank); fft::Pass(shape, type, {8}, shape); fft::Pass(shape, type, {16, 8}, shape); @@ -930,16 +1007,16 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) { TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) { FftType type = FftType::IFFT; - Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + const Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) { FftType type = FftType::RFFT; - Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); + const Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); fft::Fail(shape_in, type, {}, fft::unsupported_rank); fft::Pass(shape_in, type, {8}, shape_out); fft::Pass(shape_in, type, {16, 8}, shape_out); @@ -949,36 +1026,36 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) { TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) { FftType type = FftType::RFFT; - Shape shape = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape = ShapeUtil::MakeShape(F32, {16, 8}); fft::Fail(shape, type, {4}, fft::dimensions_match); fft::Fail(shape, type, {16, 4}, fft::dimensions_match); fft::Fail(shape, type, {8, 8}, fft::dimensions_match); fft::Fail(shape, type, {8, 16}, fft::dimensions_match); - Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0}); - Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0}); + const Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0}); + const Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0}); fft::Pass(zero_shape_in, type, {0}, zero_shape_out); fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out); - Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8}); - Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9}); - Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); + const Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9}); + const Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); fft::Pass(even_shape_in, type, {16, 8}, shape_out); fft::Pass(odd_shape_in, type, {16, 9}, shape_out); } TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) { FftType type = FftType::RFFT; - Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + const Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8}); + const Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input); fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input); } TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) { FftType type = FftType::IRFFT; - Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5}); - Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5}); + const Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8}); fft::Fail(shape_in, type, {}, fft::unsupported_rank); fft::Pass(shape_in, type, {8}, shape_out); fft::Pass(shape_in, type, {16, 8}, shape_out); @@ -988,143 +1065,151 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) { TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) { FftType type = FftType::IRFFT; - Shape shape = ShapeUtil::MakeShape(C64, {16, 5}); + const Shape shape = ShapeUtil::MakeShape(C64, {16, 5}); fft::Fail(shape, type, {5}, fft::innermost_dimension_matches); fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches); fft::Fail(shape, type, {8, 8}, fft::dimensions_match); fft::Fail(shape, type, {8, 9}, fft::dimensions_match); - Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0}); - Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0}); + const Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0}); + const Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0}); fft::Pass(zero_shape_in, type, {0}, zero_shape_out); fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out); - Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8}); - Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9}); + const Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9}); fft::Pass(shape, type, {16, 8}, even_shape_out); fft::Pass(shape, type, {16, 9}, odd_shape_out); } TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) { FftType type = FftType::IRFFT; - Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5}); - Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8}); + const Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + const Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5}); + const Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); fft::Pass(shape_c128, type, {16, 8}, shape_f64_out); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { - Shape arg = ShapeUtil::MakeShape(F32, {20}); + const Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); - Shape expected = ShapeUtil::MakeShape(S32, {20}); - EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.value())); + const absl::StatusOr inferred_shape = + ShapeInference::InferMapShape({&arg}, to_apply, {0}); + EXPECT_IS_OK(inferred_shape.status()); + const Shape expected = ShapeUtil::MakeShape(S32, {20}); + EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_shape)); } TEST_F(ShapeInferenceTest, Map) { - auto inferred_status_r1f32 = ShapeInference::InferMapShape( - {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); - EXPECT_IS_OK(inferred_status_r1f32.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.value())); + const absl::StatusOr inferred_shape_r1f32 = + ShapeInference::InferMapShape( + {&vector_32_, &vector_32_}, + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); + EXPECT_IS_OK(inferred_shape_r1f32.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape_r1f32)); // It's OK to provide a single argument, as long as the applied arity matches // (this degenerates to a Map). - auto inferred_status_r1f32_one = ShapeInference::InferMapShape( - {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); - EXPECT_IS_OK(inferred_status_r1f32_one.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.value())); - - auto inferred_status_r2s32 = ShapeInference::InferMapShape( - {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, - ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); - EXPECT_IS_OK(inferred_status_r2s32.status()); - EXPECT_TRUE( - ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.value())); - - auto no_args_error = ShapeInference::InferMapShape( + const absl::StatusOr inferred_shape_r1f32_one = + ShapeInference::InferMapShape( + {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); + EXPECT_IS_OK(inferred_shape_r1f32_one.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape_r1f32_one)); + + const absl::StatusOr inferred_shape_r2s32 = + ShapeInference::InferMapShape( + {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, + ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); + EXPECT_IS_OK(inferred_shape_r2s32.status()); + EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_shape_r2s32)); + + const auto no_args_error = ShapeInference::InferMapShape( {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); ASSERT_FALSE(no_args_error.ok()); ASSERT_THAT(no_args_error.status().message(), HasSubstr("expects at least one argument")); - auto args_diff_shapes_error = ShapeInference::InferMapShape( + const auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(args_diff_shapes_error.ok()); ASSERT_THAT(args_diff_shapes_error.status().message(), HasSubstr("requires all operands to have the same shape")); - auto arity_error = ShapeInference::InferMapShape( + const auto arity_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); ASSERT_FALSE(arity_error.ok()); ASSERT_THAT(arity_error.status().message(), HasSubstr("function arity must match")); - auto output_shape_error = ShapeInference::InferMapShape( + const auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0}); ASSERT_FALSE(output_shape_error.ok()); ASSERT_THAT(output_shape_error.status().message(), HasSubstr("result has to be a scalar")); - auto param_shape_error = ShapeInference::InferMapShape( + const auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0}); ASSERT_FALSE(param_shape_error.ok()); ASSERT_THAT(param_shape_error.status().message(), HasSubstr("parameter has to be a scalar")); - auto param_element_type_error = ShapeInference::InferMapShape( + const auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0}); ASSERT_FALSE(param_element_type_error.ok()); ASSERT_THAT(param_element_type_error.status().message(), HasSubstr("parameter type has to match argument")); - Shape arg = ShapeUtil::MakeShape(F32, {20}); + const Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); - EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.value())); - - auto inferred_status_error1 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + const absl::StatusOr inferred_shape = + ShapeInference::InferMapShape({&arg}, to_apply, {0}); + EXPECT_IS_OK(inferred_shape.status()); + EXPECT_TRUE(ShapeUtil::Equal(arg, *inferred_shape)); + + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("arity must match number of arguments")); - auto inferred_status_error2 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("has to be a scalar")); - auto inferred_status_error3 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("has to be a scalar")); - auto inferred_status_error5 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + const absl::StatusOr inferred_shape_error5 = + ShapeInference::InferMapShape( + {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), HasSubstr("parameter type has to match argument")); } TEST_F(ShapeInferenceTest, MapWithDifferentInputTypes) { - Shape arg0 = ShapeUtil::MakeShape(F32, {20}); - Shape arg1 = ShapeUtil::MakeShape(S32, {20}); + const Shape arg0 = ShapeUtil::MakeShape(F32, {20}); + const Shape arg1 = ShapeUtil::MakeShape(S32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_}, s32_); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferMapShape({&arg0, &arg1}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); - Shape expected = ShapeUtil::MakeShape(S32, {20}); - EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.value())); + EXPECT_IS_OK(inferred_shape.status()); + const Shape expected = ShapeUtil::MakeShape(S32, {20}); + EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -1173,20 +1258,20 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { } TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_IS_OK(inferred_status.status()); + EXPECT_IS_OK(inferred_shape.status()); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), - inferred_status.value())); + *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); std::vector args = {&f32_arg_shape, &s32_arg_shape}; std::vector inits = {&f32_, &s32_}; ProgramShape to_apply = ShapeUtil::MakeProgramShape( @@ -1197,42 +1282,43 @@ TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { MakePadding(f32_arg_shape.dimensions(), window_dimensions, window_strides, Padding::kValid); TF_ASSERT_OK_AND_ASSIGN( - Window window, + const Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding_values, {}, {})); - auto inferred_status = ShapeInference::InferReduceWindowShape( - absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); - VLOG(2) << inferred_status.value().ToString() << "\n"; - EXPECT_IS_OK(inferred_status.status()); + const absl::StatusOr inferred_shape = + ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + VLOG(2) << inferred_shape->ToString() << "\n"; + EXPECT_IS_OK(inferred_shape.status()); EXPECT_TRUE(ShapeUtil::Equal( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}), ShapeUtil::MakeShape(S32, {5, 2, 0})}), - inferred_status.value())); + *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr( "parameter shape differs from the result shape: s32[] vs f32[]")); } @@ -1240,15 +1326,16 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = + ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("must have at least 2 arguments, has 0")); } TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); std::vector args = {&f32_arg_shape, &s32_arg_shape}; std::vector inits = {&f32_, &s32_}; ProgramShape to_apply = ShapeUtil::MakeProgramShape( @@ -1259,177 +1346,177 @@ TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) { MakePadding(f32_arg_shape.dimensions(), window_dimensions, window_strides, Padding::kValid); TF_ASSERT_OK_AND_ASSIGN( - Window window, + const Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding_values, {}, {})); - auto inferred_status = ShapeInference::InferReduceWindowShape( - absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); - EXPECT_FALSE(inferred_status.status().ok()); - EXPECT_THAT(inferred_status.status().message(), HasSubstr("f32[] vs s32[]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + EXPECT_FALSE(inferred_shape.status().ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("f32[] vs s32[]")); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { - Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + const Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); - auto inferred_status = ShapeInference::InferReduceShape( + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("accumulator shape at index 0 differs from the " "init_value shape: s32[] vs f32[]")); } TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - auto inferred_status = ShapeInference::InferReduceShape( + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&arg_shape, &f32_}, /*dimensions_to_reduce=*/{3, 4}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - auto inferred_status = + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - auto inferred_status = + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("0-th parameter shape differs")); } TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - auto inferred_status = ShapeInference::InferReduceShape( + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( {&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0, 0}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Duplicate reduction dimension: 0")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal( - ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), inferred)); + ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferInvalidStride) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1}); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_status.status().code()); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_shape.status().code()); } TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = + const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1}); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_status.status().code()); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_shape.status().code()); } TEST_F(ShapeInferenceTest, InferSliceShapeRank1) { - Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); - auto inferred_status = + const Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1}); - ASSERT_TRUE(inferred_status.ok()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2}))); + ASSERT_TRUE(inferred_shape.ok()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferConstIndexShape) { - Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); - auto inferred0_status = + const Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + const absl::StatusOr inferred0_status = ShapeInference::InferGetTupleElementShape(tuple_shape, 0); - auto inferred1_status = + const absl::StatusOr inferred1_status = ShapeInference::InferGetTupleElementShape(tuple_shape, 1); ASSERT_IS_OK(inferred0_status.status()); ASSERT_IS_OK(inferred1_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.value())); - ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.value())); + ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred0_status)); + ASSERT_TRUE(ShapeUtil::Equal(s32_, *inferred1_status)); } TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { - Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); - auto inferredNegative_status = + const Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + const absl::StatusOr inferredNegative_status = ShapeInference::InferGetTupleElementShape(tuple_shape, -1); - auto inferred2_status = + const absl::StatusOr inferred2_status = ShapeInference::InferGetTupleElementShape(tuple_shape, 2); ASSERT_FALSE(inferredNegative_status.ok()); ASSERT_FALSE(inferred2_status.ok()); @@ -1440,20 +1527,22 @@ TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { } TEST_F(ShapeInferenceTest, InferPowShape) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kPower, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.value())); + const Shape ten_floats = ShapeUtil::MakeShape(F32, {10}); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kPower, ten_floats, f32_, + {}); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(ten_floats, *inferred_shape)); } TEST_F(ShapeInferenceTest, InferCompareShape) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kCompare, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.value())); + const Shape ten_floats = ShapeUtil::MakeShape(F32, {10}); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, ten_floats, f32_, + {}); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) { @@ -1462,10 +1551,11 @@ TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) { // [<=1] // // Both output dimension can be dynamic, use inferred_dimension to tie-break. - auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true}); - auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1}, - /*inferred_dimension=*/-1); - ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.value()); + const Shape operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true}); + const auto status = + ShapeInference::InferReshapeShape(operand, {1, 0}, {1}, + /*inferred_dimension=*/-1); + ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), *status); } TEST_F(ShapeInferenceTest, InferReshapeSplit) { @@ -1474,74 +1564,79 @@ TEST_F(ShapeInferenceTest, InferReshapeSplit) { // [1, 10] // // Both output dimension can be dynamic, use inferred_dimension to tie-break. - auto operand = ShapeUtil::MakeShape(F32, {10}, {true}); - auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10}, - /*inferred_dimension=*/0); - ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}), status.value()); + const Shape operand = ShapeUtil::MakeShape(F32, {10}, {true}); + const auto status = + ShapeInference::InferReshapeShape(operand, {0}, {1, 10}, + /*inferred_dimension=*/0); + ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}), *status); } TEST_F(ShapeInferenceTest, InferReshapeCombine) { // [6, <=10] // | reshape // [<=60] - auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true}); - auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60}, - /*inferred_dimension=*/-11); - ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.value()); + const Shape operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true}); + const auto status = + ShapeInference::InferReshapeShape(operand, {1, 0}, {60}, + /*inferred_dimension=*/-11); + ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), *status); } TEST_F(ShapeInferenceTest, UnchangedDimension) { // [6, <=10] // | reshape // [2, 3, <=10] - auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true}); - auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10}, - /*inferred_dimension=*/-11); + const Shape operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true}); + const auto status = + ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10}, + /*inferred_dimension=*/-11); ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}), - status.value()); + *status); } TEST_F(ShapeInferenceTest, InferDynamicBroadcast) { // CHECK: // %broadcast = s32[15,<=15]{1,0} broadcast(s32[<=15]{0}), dimensions={1} - auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true}); - auto inferred_status = + const Shape operand_shape = ShapeUtil::MakeShape(F32, {15}, {true}); + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape(operand_shape, {15}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), + *inferred_shape); } TEST_F(ShapeInferenceTest, BroadcastScalar) { for (auto element_type : {F32, U32, S8}) { const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {}); { // no-op scalar broadcast - auto status = ShapeInference::InferBroadcastShape(scalar_shape, {}); + const auto status = ShapeInference::InferBroadcastShape(scalar_shape, {}); ASSERT_IS_OK(status.status()); - ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.value())); + ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, *status)); } const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3}); { // scalar -> 1d broadcast - auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3}); + const auto status = + ShapeInference::InferBroadcastShape(scalar_shape, {3}); ASSERT_IS_OK(status.status()); - ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.value())); + ASSERT_TRUE(ShapeUtil::Equal(oned_shape, *status)); } { // no-op 1d broadcast - auto status = ShapeInference::InferBroadcastShape(oned_shape, {}); + const auto status = ShapeInference::InferBroadcastShape(oned_shape, {}); ASSERT_IS_OK(status.status()); - ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.value())); + ASSERT_TRUE(ShapeUtil::Equal(oned_shape, *status)); } const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3}); { // scalar -> 2d broadcast - auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3}); + const auto status = + ShapeInference::InferBroadcastShape(scalar_shape, {2, 3}); ASSERT_IS_OK(status.status()); - ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.value())); + ASSERT_TRUE(ShapeUtil::Equal(twod_shape, *status)); } { // 1d -> 2d broadcast - auto status = ShapeInference::InferBroadcastShape(oned_shape, {2}); + const auto status = ShapeInference::InferBroadcastShape(oned_shape, {2}); ASSERT_IS_OK(status.status()); - ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.value())); + ASSERT_TRUE(ShapeUtil::Equal(twod_shape, *status)); } } } @@ -1549,10 +1644,10 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: ok TEST_F(ShapeInferenceTest, ScalarDotVector) { DotDimensionNumbers dot_dnums; - auto inferred_status = ShapeInference::InferDotOpShape( + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( f32_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_EQ(inferred_status.value(), vector_32_); + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_EQ(*inferred_shape, vector_32_); } // 3D 2D: error @@ -1560,11 +1655,11 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto inferred_status = ShapeInference::InferDotOpShape( + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_status.value(), + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, ShapeUtil::MakeShape(F32, {32, 32, 64}))); } @@ -1573,15 +1668,15 @@ TEST_F(ShapeInferenceTest, VectorDotVector) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.value())); - auto inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // matrix vector -> vector @@ -1589,15 +1684,15 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), vector_32_)); - auto inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, vector_32_)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // vector matrix -> vector @@ -1605,15 +1700,15 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), vector_64_)); - auto inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, vector_64_)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // matrix matrix -> matrix @@ -1621,24 +1716,24 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto inferred_status_match = + const absl::StatusOr inferred_shape_match = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), matrix_32_48_)) - << "inferred: " << ShapeUtil::HumanString(inferred_status_match.value()) + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, matrix_32_48_)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape_match) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // BatchMatMul with two batch dimensions and one contracting dimension. TEST_F(ShapeInferenceTest, DotGeneral) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); - Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + const Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(3); @@ -1649,19 +1744,19 @@ TEST_F(ShapeInferenceTest, DotGeneral) { dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - auto inferred_status_match = + const absl::StatusOr inferred_shape_match = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), output_shape)) - << "inferred: " << ShapeUtil::HumanString(inferred_status_match.value()) + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, output_shape)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape_match) << " expected: " << ShapeUtil::HumanString(output_shape); } // BatchMatMul with two contracting dimensions fails. TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1671,19 +1766,19 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Must specify the same number of contracting " "dimensions for lhs and rhs.")); } TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); - Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); + const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1694,39 +1789,41 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { dot_dnums.add_rhs_contracting_dimensions(2); dot_dnums.add_rhs_batch_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_status.value(), output_shape)); + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, output_shape)); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) { - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape val_shape = ShapeUtil::MakeShape(S32, {1}); - auto inferred_status = ShapeInference::InferSetDimensionSizeShape( - arg_shape, val_shape, /*dimension=*/0); - - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape val_shape = ShapeUtil::MakeShape(S32, {1}); + const absl::StatusOr inferred_shape = + ShapeInference::InferSetDimensionSizeShape(arg_shape, val_shape, + /*dimension=*/0); + + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("value has to be S32 scalar")); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) { - Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - Shape val_shape = ShapeUtil::MakeShape(U32, {}); - auto inferred_status = ShapeInference::InferSetDimensionSizeShape( - arg_shape, val_shape, /*dimension=*/0); - - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + const Shape val_shape = ShapeUtil::MakeShape(U32, {}); + const absl::StatusOr inferred_shape = + ShapeInference::InferSetDimensionSizeShape(arg_shape, val_shape, + /*dimension=*/0); + + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("value has to be S32 scalar")); } // BatchMatMul with different batch dimension sizes fails. TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1735,18 +1832,18 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), - HasSubstr("Batch dimension sizes must match")); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), + HasSubstr("Batch dimension sizes are not compatible")); } // BatchMatMul with different batch dimension numbers passes TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1755,18 +1852,18 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_TRUE(inferred_status.ok()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), + ASSERT_TRUE(inferred_shape.ok()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails. TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(3); @@ -1775,18 +1872,18 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("A dimension number is out of range")); } // BatchMatMul with non-unique dimension numbers fails. TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { - Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1795,11 +1892,11 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("A dimension number is not unique")); } @@ -1807,7 +1904,7 @@ TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, @@ -1820,7 +1917,7 @@ TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, @@ -1833,7 +1930,7 @@ TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, @@ -1846,7 +1943,7 @@ TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums, @@ -1859,7 +1956,7 @@ TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, @@ -1872,7 +1969,7 @@ TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, @@ -1885,7 +1982,7 @@ TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, @@ -1894,6 +1991,116 @@ TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) { ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S8, {32, 32}))); } +TEST_F(ShapeInferenceTest, DotWithSparseLhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(0); + sparsity_descriptor.set_dimension(1); + + std::vector sparsity = {sparsity_descriptor}; + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {10, 16}), + ShapeUtil::MakeShape(F32, {32, 20}), dot_dnums, + /*preferred_element_type=*/std::nullopt, absl::MakeSpan(sparsity))); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {10, 20}))); +} + +TEST_F(ShapeInferenceTest, DotWithSparseRhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(1); + sparsity_descriptor.set_dimension(0); + + std::vector sparsity = {sparsity_descriptor}; + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {10, 32}), + ShapeUtil::MakeShape(F32, {16, 20}), dot_dnums, + /*preferred_element_type=*/std::nullopt, absl::MakeSpan(sparsity))); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {10, 20}))); +} + +TEST_F(ShapeInferenceTest, DotWithSparseBothOperands) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_lhs; + sparsity_lhs.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_lhs.set_n(2); + sparsity_lhs.set_m(4); + sparsity_lhs.set_index(0); + sparsity_lhs.set_dimension(1); + SparsityDescriptor sparsity_rhs = sparsity_lhs; + sparsity_rhs.set_index(1); + sparsity_rhs.set_dimension(0); + + std::vector sparsity = {sparsity_lhs, sparsity_rhs}; + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {10, 16}), + ShapeUtil::MakeShape(F32, {16, 20}), dot_dnums, + /*preferred_element_type=*/std::nullopt, absl::MakeSpan(sparsity))); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {10, 20}))); +} + +TEST_F(ShapeInferenceTest, DotWithIncorrectSparseDimensionSizeRatio) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(0); + sparsity_descriptor.set_dimension(1); + + std::vector sparsity = {sparsity_descriptor}; + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {10, 32}), ShapeUtil::MakeShape(F32, {32, 20}), + dot_dnums, /*preferred_element_type=*/std::nullopt, + absl::MakeSpan(sparsity)); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT( + inferred_shape.status().message(), + HasSubstr("Sparse dimension size ratio doesn't match the descriptor")); +} + +TEST_F(ShapeInferenceTest, SparseDotMetadata) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_contracting_dimensions(2); + SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(SparsityType::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_n(2); + sparsity_descriptor.set_m(4); + sparsity_descriptor.set_index(0); + sparsity_descriptor.set_dimension(2); + + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferSparseDotMetadataShape( + ShapeUtil::MakeShape(F32, {5, 10, 16}), dot_dnums, + sparsity_descriptor)); + EXPECT_TRUE( + ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(U16, {5, 10, 2}))); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. @@ -1901,23 +2108,23 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = + absl::StatusOr inferred_shape_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), mat)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, mat)); - auto inferred_status_mismatch = + absl::StatusOr inferred_shape_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); - inferred_status_match = + inferred_shape_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), mat)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, mat)); - inferred_status_mismatch = + inferred_shape_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { @@ -1927,20 +2134,21 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4}); const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, cube, matrix8_4, {1, 2}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), cube)); + absl::StatusOr inferred_shape_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, matrix8_4, + {1, 2}); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); - inferred_status_match = ShapeInference::InferBinaryOpShape( + inferred_shape_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_4, {0, 2}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), cube)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); - inferred_status_match = ShapeInference::InferBinaryOpShape( + inferred_shape_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_8, {0, 1}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.value(), cube)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); } TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { @@ -1952,216 +2160,228 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Shapes must be equal rank")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long - auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, + {0, 1, 2}); + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor - auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + const absl::StatusOr inferred_shape_error5 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, + {3, 0}); + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order - auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); - ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_THAT(inferred_status_error6.status().message(), + const absl::StatusOr inferred_shape_error6 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, + {2, 1}); + ASSERT_FALSE(inferred_shape_error6.ok()); + ASSERT_THAT(inferred_shape_error6.status().message(), HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. - auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); - ASSERT_FALSE(inferred_status_error7.ok()); - ASSERT_THAT(inferred_status_error7.status().message(), + const absl::StatusOr inferred_shape_error7 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor8_8_8, + matrix8_8, {0, 0}); + ASSERT_FALSE(inferred_shape_error7.ok()); + ASSERT_THAT(inferred_shape_error7.status().message(), HasSubstr("dimensions order is wrong")); - auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); - ASSERT_FALSE(inferred_status_error8.ok()); - ASSERT_THAT(inferred_status_error8.status().message(), + const absl::StatusOr inferred_shape_error8 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor8_8_8, + matrix8_8, {1, 0}); + ASSERT_FALSE(inferred_shape_error8.ok()); + ASSERT_THAT(inferred_shape_error8.status().message(), HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { - Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); + const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); - auto inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferWhileShape(cond, body, result_shape); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(result_shape, *inferred_shape)); } // Tests for the while instruction with wrong shapes. TEST_F(ShapeInferenceTest, WhileWithBadShapes) { - Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); - ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); - ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); - - auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_); - auto inferred_status_error1 = - ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + const Shape inferred_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); + ProgramShape cond = ShapeUtil::MakeProgramShape({inferred_shape}, pred_); + ProgramShape body = + ShapeUtil::MakeProgramShape({inferred_shape}, inferred_shape); + + const auto bad_shape_1 = + ShapeUtil::MakeProgramShape({s32_, inferred_shape}, pred_); + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferWhileShape(bad_shape_1, body, inferred_shape); + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Condition must take 1 arguments")); - auto bad_shape_2 = - ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); - auto inferred_status_error2 = - ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + const auto bad_shape_2 = + ShapeUtil::MakeProgramShape({s32_, inferred_shape}, inferred_shape); + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferWhileShape(cond, bad_shape_2, inferred_shape); + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("Body must take 1 arguments")); - auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); - auto inferred_status_error3 = - ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + const auto bad_shape_3 = ShapeUtil::MakeProgramShape({inferred_shape}, s32_); + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferWhileShape(bad_shape_3, body, inferred_shape); + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Condition must return a boolean")); - auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); - auto inferred_status_error4 = - ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + const auto bad_shape_4 = + ShapeUtil::MakeProgramShape({inferred_shape}, vector_32_); + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferWhileShape(cond, bad_shape_4, inferred_shape); + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("parameter of condition and body")); } // Tests for the concatenate instruction with dynamic shapes. TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) { - auto dynamic_shape_1 = + const auto dynamic_shape_1 = ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false}); - auto dynamic_shape_2 = + const auto dynamic_shape_2 = ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false}); - auto inferred_status = ShapeInference::InferConcatOpShape( - {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.value(); + const absl::StatusOr inferred_shape = + ShapeInference::InferConcatOpShape({&dynamic_shape_1, &dynamic_shape_2}, + /*dimension=*/0); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal( - ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred)); + ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), + *inferred_shape)); } // Tests for the concatenate instruction with proper shapes. TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) { - auto inferred_status_1 = ShapeInference::InferConcatOpShape( - {&vector_32_, &vector_64_}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status_1.status()); - Shape inferred_1 = inferred_status_1.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1)); - - auto inferred_status_2 = ShapeInference::InferConcatOpShape( - {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status_2.status()); - Shape inferred_2 = inferred_status_2.value(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2)); - - auto inferred_status_3 = ShapeInference::InferConcatOpShape( - {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); - ASSERT_IS_OK(inferred_status_3.status()); - Shape inferred_3 = inferred_status_3.value(); + const absl::StatusOr inferred_shape_1 = + ShapeInference::InferConcatOpShape({&vector_32_, &vector_64_}, + /*dimension=*/0); + ASSERT_IS_OK(inferred_shape_1.status()); + ASSERT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), *inferred_shape_1)); + + const absl::StatusOr inferred_shape_2 = + ShapeInference::InferConcatOpShape( + {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); + ASSERT_IS_OK(inferred_shape_2.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), *inferred_shape_2)); + + const absl::StatusOr inferred_shape_3 = + ShapeInference::InferConcatOpShape( + {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); + ASSERT_IS_OK(inferred_shape_3.status()); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), + *inferred_shape_3)); } // Tests for the concatenate instruction with wrong shapes. TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { - auto inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Concatenate expects at least one argument")); - auto inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("dimension out of bounds: -1")); - auto inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("dimension out of bounds: 1")); - Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); - auto inferred_status_error4 = ShapeInference::InferConcatOpShape( - {&vector_32_, &tuple}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error4.ok()); + const Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferConcatOpShape({&vector_32_, &tuple}, + /*dimension=*/0); + ASSERT_FALSE(inferred_shape_error4.ok()); ASSERT_THAT( - inferred_status_error4.status().message(), + inferred_shape_error4.status().message(), HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); - auto inferred_status_error5 = ShapeInference::InferConcatOpShape( - {&vector_32_, &vector_s32}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + const absl::StatusOr inferred_shape_error5 = + ShapeInference::InferConcatOpShape({&vector_32_, &vector_s32}, + /*dimension=*/0); + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), HasSubstr("concatenate arrays with different element types")); - auto inferred_status_error6 = ShapeInference::InferConcatOpShape( - {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_THAT(inferred_status_error6.status().message(), + const absl::StatusOr inferred_shape_error6 = + ShapeInference::InferConcatOpShape({&matrix_32_48_, &matrix_32_64_}, + /*dimension=*/0); + ASSERT_FALSE(inferred_shape_error6.ok()); + ASSERT_THAT(inferred_shape_error6.status().message(), HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { - Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); - Shape padding_value_shape = ShapeUtil::MakeShape(F32, {}); + const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + const Shape padding_value_shape = ShapeUtil::MakeShape(F32, {}); // Padding for dimension 0: {low: 0, high: 2, interior: 3} // Padding for dimension 1: {low: 1, high: 5, interior: 0} PaddingConfig padding_config; - auto dimension0 = padding_config.add_dimensions(); + const auto dimension0 = padding_config.add_dimensions(); dimension0->set_edge_padding_low(0); dimension0->set_edge_padding_high(2); dimension0->set_interior_padding(3); - auto dimension1 = padding_config.add_dimensions(); + const auto dimension1 = padding_config.add_dimensions(); dimension1->set_edge_padding_low(1); dimension1->set_edge_padding_high(5); dimension1->set_interior_padding(0); - auto inferred_status = ShapeInference::InferPadShape( + const absl::StatusOr inferred_shape = ShapeInference::InferPadShape( input_shape, padding_value_shape, padding_config); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), *inferred_shape)); dimension1->set_edge_padding_low(-20); dimension1->set_edge_padding_high(-10); - auto negative_dimension_size = ShapeInference::InferPadShape( + const auto negative_dimension_size = ShapeInference::InferPadShape( input_shape, padding_value_shape, padding_config); ASSERT_FALSE(negative_dimension_size.ok()); ASSERT_THAT(negative_dimension_size.status().message(), @@ -2169,283 +2389,303 @@ TEST_F(ShapeInferenceTest, Pad) { } TEST_F(ShapeInferenceTest, Reverse) { - Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); - auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred_shape = inferred_status.value(); - ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape)); + const absl::StatusOr inferred_shape = + ShapeInference::InferReverseShape(input_shape, {0, 1}); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(input_shape, *inferred_shape)); } TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { - Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); + const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); - auto inferred_status_error0 = + const absl::StatusOr inferred_shape_error0 = ShapeInference::InferReverseShape(input_shape, {0, 2}); - ASSERT_FALSE(inferred_status_error0.ok()); - ASSERT_THAT(inferred_status_error0.status().message(), + ASSERT_FALSE(inferred_shape_error0.ok()); + ASSERT_THAT(inferred_shape_error0.status().message(), HasSubstr("out-of-bounds")); - auto inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("out-of-bounds")); - auto inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("duplicated")); - Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); - auto inferred_status_error3 = + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({input_shape, input_shape}); + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { - auto inferred_status0 = + const absl::StatusOr inferred_shape0 = ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_)); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.value())); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); - auto inferred_status1 = ShapeInference::InferCallShape( + const absl::StatusOr inferred_shape1 = ShapeInference::InferCallShape( {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_}, ShapeUtil::MakeProgramShape( {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_)); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.value())); - - auto inferred_status_error0 = ShapeInference::InferCallShape( - {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_shape1)); + + const absl::StatusOr inferred_shape_error0 = + ShapeInference::InferCallShape({}, + ShapeUtil::MakeProgramShape({f32_}, f32_)); + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("arity must match")); - auto inferred_status_error1 = ShapeInference::InferCallShape( - {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferCallShape({&f32_}, + ShapeUtil::MakeProgramShape({}, f32_)); + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("arity must match")); - auto inferred_status_error2 = ShapeInference::InferCallShape( - {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferCallShape({&f32_}, + ShapeUtil::MakeProgramShape({s32_}, f32_)); + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { - Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); - auto inferred_shape_and_status = + const Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); + const absl::StatusOr inferred_shape_and_status = ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0}); EXPECT_IS_OK(inferred_shape_and_status); - Shape inferred_shape = inferred_shape_and_status.value(); - EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape, - ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {3, 4, 5, 2}), + *inferred_shape_and_status)); } TEST_F(ShapeInferenceTest, Rank1Transpose) { - Shape a_shape = ShapeUtil::MakeShape(F32, {5}); - auto inferred_shape_and_status = + const Shape a_shape = ShapeUtil::MakeShape(F32, {5}); + const absl::StatusOr inferred_shape_and_status = ShapeInference::InferTransposeShape(a_shape, {0}); EXPECT_IS_OK(inferred_shape_and_status); - Shape inferred_shape = inferred_shape_and_status.value(); - EXPECT_TRUE( - ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {5}), + *inferred_shape_and_status)); } TEST_F(ShapeInferenceTest, ConditionalPred) { - auto inferred_status0 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, - {vector_32_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.value())); - - auto inferred_status1 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), - ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, - {matrix_32_48_, vector_32_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.value())); - - auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); - auto inferred_status2 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), - ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, - {matrix_32_48_, tuple_f32_v32}); - EXPECT_IS_OK(inferred_status2.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.value())); - - auto inferred_status_error0 = ShapeInference::InferConditionalShape( - f32_, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, - {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + const absl::StatusOr inferred_shape0 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); + + const absl::StatusOr inferred_shape1 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, + {matrix_32_48_, vector_32_}); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_shape1)); + + const auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + const absl::StatusOr inferred_shape2 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, + {matrix_32_48_, tuple_f32_v32}); + EXPECT_IS_OK(inferred_shape2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape2)); + + const absl::StatusOr inferred_shape_error0 = + ShapeInference::InferConditionalShape( + f32_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("must be bool or int32_t")); - auto inferred_status_error1 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, - {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, + {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("branch computation 0 must take 1 argument")); - auto inferred_status_error2 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({vector_64_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, - {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("branch operand 0 must match the shape of the only " "parameter of branch computation 0")); - auto inferred_status_error3 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), - ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, - {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); - EXPECT_FALSE(inferred_status_error3.ok()); - EXPECT_THAT(inferred_status_error3.status().message(), + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, + {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); + EXPECT_FALSE(inferred_shape_error3.ok()); + EXPECT_THAT(inferred_shape_error3.status().message(), HasSubstr("branch computation 1 must take 1 argument")); - auto inferred_status_error4 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, - {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error4.ok()); - EXPECT_THAT(inferred_status_error4.status().message(), + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, + {vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error4.ok()); + EXPECT_THAT(inferred_shape_error4.status().message(), HasSubstr("branch operand 1 must match the shape of the only " "parameter of branch computation 1")); - auto inferred_status_error5 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, - {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error5.ok()); - EXPECT_THAT(inferred_status_error5.status().message(), + const absl::StatusOr inferred_shape_error5 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, + {vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error5.ok()); + EXPECT_THAT(inferred_shape_error5.status().message(), HasSubstr("the result of branch 0 computation and branch 1 " "computation must have the same shape")); } TEST_F(ShapeInferenceTest, ConditionalIndexed) { - auto r0s32 = ShapeUtil::MakeShape(S32, {}); - auto inferred_status0 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, - {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.value())); - - auto inferred_status1 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), - ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, - {matrix_32_48_, vector_32_, matrix_32_48_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.value())); - - auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); - auto inferred_status2 = ShapeInference::InferConditionalShape( - r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, - {tuple_f32_v32}); - EXPECT_IS_OK(inferred_status2.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.value())); - - auto inferred_status_error0 = ShapeInference::InferConditionalShape( - pred_, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, - {vector_32_, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + const absl::StatusOr inferred_shape0 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_64_, vector_64_}); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); + + const absl::StatusOr inferred_shape1 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, + {matrix_32_48_, vector_32_, matrix_32_48_}); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_shape1)); + + const auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + const absl::StatusOr inferred_shape2 = + ShapeInference::InferConditionalShape( + r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, + {tuple_f32_v32}); + EXPECT_IS_OK(inferred_shape2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape2)); + + const absl::StatusOr inferred_shape_error0 = + ShapeInference::InferConditionalShape( + pred_, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, + {vector_32_, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("2 == branch_computations.size()")); - auto inferred_status_error1 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), - ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), - ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, - {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), - matrix_32_48_}); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, + {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + matrix_32_48_}); + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("branch computation 1 must take 1 argument")); - auto inferred_status_error2 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({r0s32}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, - {r0s32, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({r0s32}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, + {r0s32, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("branch operand 2 must match the shape of the only " "parameter of branch computation 2")); - auto inferred_status_error3 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_32_}, f32_), - ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, - {vector_32_, vector_32_, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error3.ok()); - EXPECT_THAT(inferred_status_error3.status().message(), + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, + {vector_32_, vector_32_, vector_32_, vector_64_}); + EXPECT_FALSE(inferred_shape_error3.ok()); + EXPECT_THAT(inferred_shape_error3.status().message(), HasSubstr("the result of branch 0 computation and branch 3 " "computation must have the same shape")); - auto inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferConditionalShape(r0s32, {}, {}); - EXPECT_FALSE(inferred_status_error4.ok()); - EXPECT_THAT(inferred_status_error4.status().message(), + EXPECT_FALSE(inferred_shape_error4.ok()); + EXPECT_THAT(inferred_shape_error4.status().message(), HasSubstr("!branch_computations.empty()")); } TEST_F(ShapeInferenceTest, ConditionalDynamic) { - auto r0s32 = ShapeUtil::MakeShape(S32, {}); - auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false}); - auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true}); - auto inferred_status0 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({vector_32_}, static_shape), - ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape), - ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, - {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status0.value())); - - auto inferred_status1 = ShapeInference::InferConditionalShape( - r0s32, - {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape), - ShapeUtil::MakeProgramShape({vector_64_}, static_shape), - ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, - {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status1.value())); + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + const Shape static_shape = ShapeUtil::MakeShape(S32, {4}, {false}); + const Shape dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true}); + const absl::StatusOr inferred_shape0 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, static_shape), + ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape), + ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, + {vector_32_, vector_64_, vector_64_}); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_shape0)); + + const absl::StatusOr inferred_shape1 = + ShapeInference::InferConditionalShape( + r0s32, + {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape), + ShapeUtil::MakeProgramShape({vector_64_}, static_shape), + ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, + {vector_32_, vector_64_, vector_64_}); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_shape1)); } TEST_F(ShapeInferenceTest, BadSlice) { - auto arg = ShapeUtil::MakeShape(F32, {4}); - StatusOr statusor = + const Shape arg = ShapeUtil::MakeShape(F32, {4}); + const absl::StatusOr statusor = ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); ASSERT_FALSE(statusor.ok()); @@ -2459,9 +2699,9 @@ TEST_F(ShapeInferenceTest, BadSlice) { } TEST_F(ShapeInferenceTest, BadSort) { - auto keys = ShapeUtil::MakeShape(F32, {4}); - auto values = ShapeUtil::MakeShape(F32, {5}); - StatusOr statusor = + const Shape keys = ShapeUtil::MakeShape(F32, {4}); + const Shape values = ShapeUtil::MakeShape(F32, {5}); + const absl::StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("dimensions must match")) @@ -2469,10 +2709,10 @@ TEST_F(ShapeInferenceTest, BadSort) { } TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { - auto keys = ShapeUtil::MakeShape(F32, {4}); - auto values_good = ShapeUtil::MakeShape(F32, {4}); - auto values_bad = ShapeUtil::MakeShape(F32, {5}); - StatusOr statusor = ShapeInference::InferVariadicOpShape( + const Shape keys = ShapeUtil::MakeShape(F32, {4}); + const Shape values_good = ShapeUtil::MakeShape(F32, {4}); + const Shape values_bad = ShapeUtil::MakeShape(F32, {5}); + const absl::StatusOr statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_good, &values_bad}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().message(), HasSubstr("dimensions must match")) @@ -2480,21 +2720,22 @@ TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { } TEST_F(ShapeInferenceTest, SortManyValues) { - auto keys = ShapeUtil::MakeShape(F32, {4}); - auto values_s32 = ShapeUtil::MakeShape(S32, {4}); - auto values_u32 = ShapeUtil::MakeShape(U32, {4}); - StatusOr statusor = ShapeInference::InferVariadicOpShape( + const Shape keys = ShapeUtil::MakeShape(F32, {4}); + const Shape values_s32 = ShapeUtil::MakeShape(S32, {4}); + const Shape values_u32 = ShapeUtil::MakeShape(U32, {4}); + const absl::StatusOr statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_s32, &values_u32}); EXPECT_IS_OK(statusor); - Shape inferred_shape = statusor.value(); + const Shape inferred_shape = *statusor; EXPECT_TRUE(ShapeUtil::Compatible( inferred_shape, ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); } TEST_F(ShapeInferenceTest, GoodTopK) { - auto input = ShapeUtil::MakeShape(F32, {3, 4, 5}); - StatusOr s = ShapeInference::InferTopKShape(input, /*k=*/2); + const Shape input = ShapeUtil::MakeShape(F32, {3, 4, 5}); + const absl::StatusOr s = + ShapeInference::InferTopKShape(input, /*k=*/2); ASSERT_IS_OK(s.status()); ASSERT_TRUE(ShapeUtil::Equal( *s, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 4, 2}), @@ -2502,8 +2743,9 @@ TEST_F(ShapeInferenceTest, GoodTopK) { } TEST_F(ShapeInferenceTest, FailTopKLargeK) { - auto input = ShapeUtil::MakeShape(F32, {3, 4, 5}); - StatusOr statusor = ShapeInference::InferTopKShape(input, /*k=*/10); + const Shape input = ShapeUtil::MakeShape(F32, {3, 4, 5}); + const absl::StatusOr statusor = + ShapeInference::InferTopKShape(input, /*k=*/10); EXPECT_FALSE(statusor.ok()); } @@ -2512,10 +2754,10 @@ TEST_F(ShapeInferenceTest, InferStochasticConvertShape) { const Shape random = ShapeUtil::MakeShape(U32, {4, 3}); const Shape expected_shape = ShapeUtil::MakeShape(S8, {4, 3}); - auto inferred_sr_shape = + const absl::StatusOr inferred_sr_shape = ShapeInference::InferStochasticConvertShape(operand, random, S8); EXPECT_TRUE(inferred_sr_shape.ok()); - EXPECT_TRUE(ShapeUtil::Equal(inferred_sr_shape.value(), expected_shape)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_sr_shape, expected_shape)); } TEST_F(ShapeInferenceTest, InvalidStochasticConvert_MismatchRandomElementType) { @@ -2523,7 +2765,7 @@ TEST_F(ShapeInferenceTest, InvalidStochasticConvert_MismatchRandomElementType) { const Shape random = ShapeUtil::MakeShape(U16, {4, 3}); const Shape expected_shape = ShapeUtil::MakeShape(S8, {4, 3}); - auto status_or = + const auto status_or = ShapeInference::InferStochasticConvertShape(operand, random, S8); ASSERT_FALSE(status_or.ok()); EXPECT_THAT( @@ -2538,7 +2780,7 @@ TEST_F(ShapeInferenceTest, const Shape random = ShapeUtil::MakeShape(S32, {4, 3}); const Shape expected_shape = ShapeUtil::MakeShape(S8, {4, 3}); - auto status_or = + const auto status_or = ShapeInference::InferStochasticConvertShape(operand, random, S8); ASSERT_FALSE(status_or.ok()); EXPECT_THAT( @@ -2567,7 +2809,7 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { }; TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2582,7 +2824,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { } TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2597,7 +2839,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { } TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2613,7 +2855,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2630,7 +2872,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TEST_F(GatherShapeInferenceTest, DynamicGatherEntireDimension) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}), ShapeUtil::MakeShape(S64, {}), @@ -2647,7 +2889,7 @@ TEST_F(GatherShapeInferenceTest, DynamicGatherEntireDimension) { TEST_F(GatherShapeInferenceTest, DynamicGatherCollapsedDimension) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}), ShapeUtil::MakeShape(S64, {}), @@ -2664,7 +2906,7 @@ TEST_F(GatherShapeInferenceTest, DynamicGatherCollapsedDimension) { TEST_F(GatherShapeInferenceTest, DynamicIndices) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 2}), ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}), @@ -2682,7 +2924,7 @@ TEST_F(GatherShapeInferenceTest, DynamicIndices) { TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2700,7 +2942,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, + const Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2718,7 +2960,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. - TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2736,7 +2978,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. - TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( @@ -2752,7 +2994,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { } TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2767,7 +3009,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { } TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2782,7 +3024,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { } TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, @@ -2798,7 +3040,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 8, 7}, @@ -2815,7 +3057,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 7}, @@ -2832,7 +3074,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 99, 100, 101}, @@ -2848,7 +3090,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 9}, @@ -2864,7 +3106,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2882,7 +3124,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2899,7 +3141,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2916,7 +3158,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2934,7 +3176,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2950,7 +3192,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2967,7 +3209,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -2982,7 +3224,7 @@ TEST_F(GatherShapeInferenceTest, } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, @@ -2999,7 +3241,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -3016,7 +3258,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, @@ -3033,7 +3275,7 @@ TEST_F(GatherShapeInferenceTest, } TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { - StatusOr statusor = ShapeInference::InferGatherShape( + const absl::StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, @@ -3118,8 +3360,8 @@ class ScatterShapeInferenceTest }; TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdates) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 32}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 32}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3132,8 +3374,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdates) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdatesV2) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 48}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 48}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3146,8 +3388,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdatesV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdates) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {10, 32}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {10, 32}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3160,8 +3402,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdates) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 8}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 8}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3174,8 +3416,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3191,8 +3433,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, @@ -3208,8 +3450,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3226,8 +3468,8 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) { - auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, @@ -3244,9 +3486,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdates) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {10, 9, 8, 7, 48}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {10, 9, 8, 7, 48}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3259,9 +3501,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdates) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {10, 9, 8, 7, 64}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {10, 9, 8, 7, 64}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3274,9 +3516,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdates) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {10, 9, 8, 7, 10}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {10, 9, 8, 7, 10}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3289,9 +3531,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdates) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {10, 9, 8, 7, 12}, types()); - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {10, 9, 8, 7, 12}, types()); + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3304,9 +3546,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {10, 9, 8, 7, 65}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {10, 9, 8, 7, 65}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4}, @@ -3322,9 +3564,9 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { } TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) { - auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), - {9, 9, 8, 7, 64}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}), + {9, 9, 8, 7, 64}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4}, @@ -3341,10 +3583,11 @@ TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) { } TEST_P(ScatterShapeInferenceTest, TfBatchDynamicUpdateSlice) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); TF_ASSERT_OK_AND_ASSIGN( - Shape scatter_shape, + const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3358,10 +3601,11 @@ TEST_P(ScatterShapeInferenceTest, TfBatchDynamicUpdateSlice) { } TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 5, 7, 6}), - {10, 9, 7, 6, 30, 29, 28, 27, 26}, types()); + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 5, 7, 6}), + {10, 9, 7, 6, 30, 29, 28, 27, 26}, types()); TF_ASSERT_OK_AND_ASSIGN( - Shape scatter_shape, + const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3376,10 +3620,11 @@ TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { } TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({5, 10, 9, 7, 6}), - {10, 9, 7, 6, 30, 29, 28, 27, 26}, types()); + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({5, 10, 9, 7, 6}), + {10, 9, 7, 6, 30, 29, 28, 27, 26}, types()); TF_ASSERT_OK_AND_ASSIGN( - Shape scatter_shape, + const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3394,11 +3639,11 @@ TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { } TEST_P(ScatterShapeInferenceTest, NoUpdateScatterDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_vector(5), - {30, 29, 28, 27, 26}, types()); + const auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_vector(5), + {30, 29, 28, 27, 26}, types()); // This is equivalent to a dynamic update slice. TF_ASSERT_OK_AND_ASSIGN( - Shape scatter_shape, + const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3413,11 +3658,11 @@ TEST_P(ScatterShapeInferenceTest, NoUpdateScatterDims) { } TEST_P(ScatterShapeInferenceTest, ScalarScatterIndices) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64), - {30, 29, 28, 27}, types()); + const auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64), + {30, 29, 28, 27}, types()); // The scalar indices "tensor" is a scalar S here that's used to update a // [30,29,28,27] shaped tensor within the operand at position S. - TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( @@ -3432,11 +3677,11 @@ TEST_P(ScatterShapeInferenceTest, ScalarScatterIndices) { } TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) { - Shape tuple_shape = + const Shape tuple_shape = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); - Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + const Shape s64_vector_32 = s64_vector(32); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( {&tuple_shape, &s64_vector_32, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3450,11 +3695,11 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) { } TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) { - Shape tuple_shape = + const Shape tuple_shape = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); - Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + const Shape s64_vector_32 = s64_vector(32); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &tuple_shape, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3468,11 +3713,11 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) { } TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { - Shape tuple_shape = + const Shape tuple_shape = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}), ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})}); - Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + const Shape s64_vector_32 = s64_vector(32); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &s64_vector_32, &tuple_shape}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3486,8 +3731,8 @@ TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { } TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) { - Shape s64_vector_32 = s64_vector(32); - StatusOr statusor = ShapeInference::InferScatterShape( + const Shape s64_vector_32 = s64_vector(32); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( {&s64_vector_32, &vector_32_, &s64_vector_32}, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, @@ -3501,9 +3746,10 @@ TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) { } TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3518,9 +3764,10 @@ TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { } TEST_P(ScatterShapeInferenceTest, InvalidUpdates) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28, 50}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28, 50}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3536,9 +3783,10 @@ TEST_P(ScatterShapeInferenceTest, InvalidUpdates) { TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) { const ProgramShape invalid_update_computation = ShapeUtil::MakeProgramShape({f32_}, f32_); - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, invalid_update_computation, HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3555,9 +3803,10 @@ TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) { TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 8, 7}, @@ -3572,9 +3821,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 7, 7}, @@ -3589,9 +3839,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 7, 9}, @@ -3607,9 +3858,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3624,9 +3876,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3641,9 +3894,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3659,9 +3913,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3679,9 +3934,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3697,9 +3953,10 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), - {10, 9, 8, 7, 30, 29, 28}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = + CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), + {10, 9, 8, 7, 30, 29, 28}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, @@ -3716,9 +3973,9 @@ TEST_P(ScatterShapeInferenceTest, TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_InsufficientWindowDims) { - auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64), - {30, 29, 28, 27}, types()); - StatusOr statusor = ShapeInference::InferScatterShape( + const auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64), + {30, 29, 28, 27}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0, 1, 2, 3}, @@ -3746,187 +4003,1285 @@ INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest, BF16}), ScatterTestName()); -TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedAbs) { - StatusOr operand = ParseShape(GetParam()[0]); - StatusOr expected = ParseShape(GetParam()[1]); - ASSERT_IS_OK(operand.status()); - StatusOr inferred_status = - ShapeInference::InferUnaryOpShape(HloOpcode::kExp, operand.value()); - ASSERT_IS_OK(expected.status()); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); +TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedUnaryOps) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape(GetParam().operand)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred, + ShapeInference::InferUnaryOpShape(GetParam().opcode, operand)); + EXPECT_TRUE(ShapeUtil::Equal(inferred, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred) + << " expected: " << ShapeUtil::HumanString(expected); } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { - StatusOr lhs = ParseShape(GetParam()[0]); - StatusOr rhs = ParseShape(GetParam()[1]); - StatusOr expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - StatusOr inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kAnd, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAtan2) { + TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kAtan2, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_F(ShapeInferenceTest, UnboundedBitcastConvert) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferBitcastConvertShape(operand, PrimitiveType::F16)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f16[?, 10, 2]")); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedBatchNormGrad) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape mean, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape variance, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_scale, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_offset, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape grad_output, ParseShape("f32[5, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferBatchNormGradShape( + operand, scale, mean, variance, grad_output, 1)); + const Shape expected_tuple_shape = + ShapeUtil::MakeTupleShape({grad_operand, grad_scale, grad_offset}); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected_tuple_shape)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected_tuple_shape); +} + +TEST_F(ShapeInferenceTest, UnboundedBatchNormInference) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape offset, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape mean, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape variance, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferBatchNormInferenceShape( + operand, scale, offset, mean, variance, 1)); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 7]")); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedBatchNormTraining) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape output, ParseShape("f32[?, ?, 7]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape scale, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape offset, ParseShape("f32[5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape batch_mean, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape batch_var, ParseShape("f32[?]")); + const Shape expected_tuple_shape = + ShapeUtil::MakeTupleShape({output, batch_mean, batch_var}); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferBatchNormTrainingShape(operand, scale, offset, 1)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected_tuple_shape)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected_tuple_shape); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedOperand) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, ?]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferBroadcastShape(operand, /*broadcast_sizes=*/{1}); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("is_unbounded_dynamic")); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedBroadcastSize) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, 4]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferBroadcastShape( + operand, /*broadcast_sizes=*/{Shape::kUnboundedSize}); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("Non-broadcast dimensions must not be dynamic.")); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastInDim) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, 4]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferBroadcastShape(operand, expected, + /*broadcast_dimensions=*/{0, 2})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimToBounded) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, <=4]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferBroadcastShape(operand, expected, + /*broadcast_dimensions=*/{0, 2})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupportedOutput) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, ?]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferBroadcastShape(operand, expected, + /*broadcast_dimensions=*/{0, 2}); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("is_unbounded_dynamic")); +} + +TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupported) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, 4]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferBroadcastShape( + operand, /*broadcast_sizes=*/{2, Shape::kUnboundedSize, 4}); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("Non-broadcast dimensions must not be dynamic.")); +} + +TEST_F(ShapeInferenceTest, UnboundedCholesky) { + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferCholeskyShape(a)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedClampOpShapeInferenceTest, UnboundedClamp) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape(GetParam()[2])); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[3])); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + EXPECT_EQ(inferred_shape.status().message(), GetParam()[4]); + } +} + +TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("(f32[2], f32[?])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("(f32[?], f32[2])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("(f32[2], f32[?])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("(f32[?], f32[2])")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); + EXPECT_THAT( + inferred_shape.status().message(), + HasSubstr( + "Expected array argument for clamp min, but got (f32[2], f32[?]).")); +} + +TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedComplexOpShapeInferenceTest, UnboundedComplex) { + TF_ASSERT_OK_AND_ASSIGN(const Shape real, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape imag, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, real, imag, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedConcatenateOpShapeInferenceTest, UnboundedConcatenate) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape(GetParam()[0])); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape(GetParam()[1])); + const absl::StatusOr inferred_shape = + ShapeInference::InferConcatOpShape({&operand1, &operand2}, + /*dimension=*/0); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[2])); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + EXPECT_EQ(inferred_shape.status().message(), GetParam()[3]); + } +} + +TEST_F(UnboundedConcatenateOpShapeInferenceTest, + UnboundedConcatenateMismatchedDimensions) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape("f32[2, 3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand3, ParseShape("f32[2, 4]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, + /*dimension=*/0); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("Mismatched dimension sizes 3 and 4 in dimension 1")); +} + +TEST_F(UnboundedConcatenateOpShapeInferenceTest, + UnboundedConcatenateMismatchedBoundSizes) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[2, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape("f32[2, <=3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand3, ParseShape("f32[2, <=4]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, + /*dimension=*/0); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("Mismatched bound sizes 3 and 4 in dimension 1")); +} + +TEST_F(ShapeInferenceTest, UnboundedConvert) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f64[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape result, ShapeInference::InferConvertShape( + operand, PrimitiveType::F64)); + EXPECT_TRUE(ShapeUtil::Equal(result, expected)) + << "inferred: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedConvolution) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 2, ?, 128]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2, 2, <=128, 8]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 1, ?, 8]")); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + TF_ASSERT_OK_AND_ASSIGN( + const Window window, + ShapeInference::InferWindowFromDimensions( + /*window_dimensions=*/{2, 2}, /*window_strides=*/{1, 1}, + MakePadding(/*input_dimensions=*/{2, Shape::kUnboundedSize}, + /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid), + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferConvolveShape( + lhs, rhs, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + /*preferred_element_type=*/std::nullopt)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kDivide, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_F(ShapeInferenceTest, UnboundedDot) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(0); + + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDotOpShape(lhs, rhs, dnums, + /*preferred_element_type=*/std::nullopt)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedDotGeneral) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("f32[?, <=3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("f32[2, 4, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, <=3, 5]")); + + DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDotOpShape(lhs, rhs, dnums, + /*preferred_element_type=*/std::nullopt)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_index, ParseShape("s32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, 2]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferDynamicSliceShape( + operand, /*start_index_shapes=*/{start_index, start_index}, + /*slice_sizes=*/{2, 2}, /*allow_scalar_indices=*/true)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedGather) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, 4, 2]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape start_indices, + ParseShape("s32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 2, 2]")); + + GatherDimensionNumbers dimension_numbers; + dimension_numbers.add_offset_dims(2); + dimension_numbers.add_offset_dims(3); + dimension_numbers.add_collapsed_slice_dims(0); + dimension_numbers.add_start_index_map(1); + dimension_numbers.add_start_index_map(0); + dimension_numbers.set_index_vector_dim(2); + + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferGatherShape( + operand, start_indices, dimension_numbers, + /*slice_sizes=*/{1, 2, 2})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST(XlaBuilderTest, UnboundedGetTupleElement) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferGetTupleElementShape( + ShapeUtil::MakeTupleShape({operand}), /*index=*/0)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedMap) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand0, ParseShape("f32[2, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2, ?, ?]")); + + const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferMapShape(/*arg_shapes=*/{&operand0, &operand1}, + to_apply, /*dimensions=*/{0, 1, 2})); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kMaximum, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMin) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kMinimum, lhs, rhs, + GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op add with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } -TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { - auto lhs = ParseShape(GetParam()[0]); - auto rhs = ParseShape(GetParam()[1]); - auto expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kDivide, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedOr) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kOr, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_F(ShapeInferenceTest, UnboundedPad) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape padding_value, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 21]")); + + PaddingConfig padding_config; + for (int i = 0; i < 2; i++) { + const auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(1); + dimension->set_interior_padding(1); + } + + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferPadShape(operand, padding_value, padding_config)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kPower, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_F(ShapeInferenceTest, UnboundedReduce) { + TF_ASSERT_OK_AND_ASSIGN(const Shape input0, ParseShape("f32[7, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input1, ParseShape("f32[?, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input2, ParseShape("f32[7, ?]")); + + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, f32_, f32_, f32_, f32_, f32_}, + ShapeUtil::MakeTupleShape({f32_, f32_, f32_})); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferReduceShape( + {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply)); + const Shape shape = ShapeUtil::MakeShape(F32, {7}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedReduceInvalidReduceDimension) { + TF_ASSERT_OK_AND_ASSIGN(const Shape input0, ParseShape("f32[7, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input1, ParseShape("f32[?, 5]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape input2, ParseShape("f32[5, ?]")); + + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, f32_, f32_, f32_, f32_, f32_}, + ShapeUtil::MakeTupleShape({f32_, f32_, f32_})); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("All reduced tensors must have compatible dimension")); +} + +TEST_F(ShapeInferenceTest, UnboundedReducePrecision) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred, + ShapeInference::InferReducePrecisionShape(operand, /*exponent_bits=*/2, + /*mantissa_bits=*/2)); + ASSERT_TRUE(ShapeUtil::Equal(inferred, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedReduceWindow) { + TF_ASSERT_OK_AND_ASSIGN(const Shape input, ParseShape("f32[?, 4, 8]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 3, 5]")); + + Window window; + WindowDimension dim0, dim1, dim2; + dim0.set_stride(1); + dim0.set_padding_low(0); + dim0.set_padding_high(0); + dim0.set_window_dilation(1); + dim0.set_base_dilation(1); + dim1 = dim2 = dim0; + dim0.set_size(1); + dim1.set_size(2); + dim2.set_size(4); + *window.add_dimensions() = dim0; + *window.add_dimensions() = dim1; + *window.add_dimensions() = dim2; + + ProgramShape body = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferReduceWindowShape( + input, /*init_value=*/f32_, window, body)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedRemainder) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kRemainder, lhs, rhs, + GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op divide with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } -TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedExp) { - auto operand = ParseShape(GetParam()[0]); - auto expected = ParseShape(GetParam()[1]); - ASSERT_IS_OK(operand.status()); - auto inferred_status = - ShapeInference::InferUnaryOpShape(HloOpcode::kExp, operand.value()); - ASSERT_IS_OK(expected.status()); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); +TEST_F(ShapeInferenceTest, UnboundedReshape) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[2,3]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred, + ShapeInference::InferReshapeShape(operand, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}, -1)); + ASSERT_TRUE(ShapeUtil::Equal(inferred, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedOutputShape) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[6]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferReshapeShape( + operand, /*dimensions=*/{0}, + /*new_sizes=*/{Shape::kUnboundedSize, Shape::kUnboundedSize}, -1); + EXPECT_THAT( + inferred_shape.status().message(), + HasSubstr("Reshaping with unbounded result shape is not supported.")); +} + +TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedMixOfDynamism) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, <=3]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=3]")); + const absl::StatusOr inferred_shape = + ShapeInference::InferReshapeShape(operand, /*dimensions=*/{0}, + /*new_sizes=*/{3}, -1); + ASSERT_THAT(inferred_shape.status().message(), + HasSubstr("Reshape operand with bounded and unbounded dynamism " + "not supported.")); } -TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { - auto lhs = ParseShape(GetParam()[0]); - auto rhs = ParseShape(GetParam()[1]); - auto expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kMaximum, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); +TEST_F(ShapeInferenceTest, UnboundedReverse) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferReverseShape(operand, /*dimensions=*/{0, 1})); + ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedScatter) { + TF_ASSERT_OK_AND_ASSIGN(Shape input, ParseShape("f32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_indices, ParseShape("s32[?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape updates, ParseShape("f32[?, ?, ?, ?]")); + TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape("f32[?, ?, ?]")); + + const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + + ScatterDimensionNumbers dimension_numbers; + dimension_numbers.add_update_window_dims(2); + dimension_numbers.add_update_window_dims(3); + dimension_numbers.add_inserted_window_dims(0); + dimension_numbers.add_scatter_dims_to_operand_dims(1); + dimension_numbers.add_scatter_dims_to_operand_dims(0); + dimension_numbers.set_index_vector_dim(2); + + TF_ASSERT_OK_AND_ASSIGN( + Shape result, + ShapeInference::InferScatterShape({&input, &scatter_indices, &updates}, + to_apply, dimension_numbers)); + EXPECT_TRUE(ShapeUtil::Equal(result, expected)) + << "inferred: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedSelectOpShapeInferenceTest, UnboundedSelect) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape(GetParam()[2])); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[3])); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + EXPECT_EQ(inferred_shape.status().message(), GetParam()[4]); + } +} + +TEST_F(ShapeInferenceTest, UnboundedSelectWithTupleUnsupported) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape("(pred[2], pred[?])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("(f32[?], f32[2])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("(f32[2], f32[?])")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("(f32[?], f32[2])")); + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr("Expected array argument for select pred, but got " + "(pred[2], pred[?]).")); +} + +TEST_F(ShapeInferenceTest, UnboundedSelectAndScatter) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape source, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape init_value, ParseShape("f32[]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + + Window window; + WindowDimension dim0; + dim0.set_base_dilation(1); + dim0.set_size(3); + dim0.set_stride(2); + dim0.set_padding_low(0); + dim0.set_padding_high(1); + dim0.set_window_dilation(1); + + WindowDimension dim1; + dim1.set_base_dilation(1); + dim1.set_size(1); + dim1.set_stride(1); + dim1.set_padding_low(0); + dim1.set_padding_high(0); + dim1.set_window_dilation(1); + + *window.add_dimensions() = dim0; + *window.add_dimensions() = dim1; + + TF_ASSERT_OK_AND_ASSIGN( + Shape result, + ShapeInference::InferSelectAndScatterShape( + operand, + /*select_shape=*/ShapeUtil::MakeProgramShape({f32_, f32_}, pred_), + window, source, init_value, + /*scatter_shape=*/ + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_))); + + EXPECT_TRUE(ShapeUtil::Equal(result, expected)) + << "inferred: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftLeft) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftLeft, lhs, rhs, + GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op maximum with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } -TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { - auto lhs = ParseShape(GetParam()[0]); - auto rhs = ParseShape(GetParam()[1]); - auto expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftRightArithmetic) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftRightArithmetic, lhs, + rhs, GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op multiply with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } -TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { - auto lhs = ParseShape(GetParam()[0]); - auto rhs = ParseShape(GetParam()[1]); - auto expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kPower, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); +TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedShiftRightLogical) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kShiftRightLogical, lhs, + rhs, GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op power with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } +TEST_F(ShapeInferenceTest, UnboundedSlice) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, 3]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferSliceShape(operand, /*starts=*/{0, 1, 2}, + /*limits=*/{1, 3, 5}, + /*strides=*/{1, 1, 1})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedSort) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&operand})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { - auto lhs = ParseShape(GetParam()[0]); - auto rhs = ParseShape(GetParam()[1]); - auto expected = ParseShape(GetParam()[2]); - ASSERT_IS_OK(lhs.status()); - ASSERT_IS_OK(rhs.status()); - auto inferred_status = ShapeInference::InferBinaryOpShape( - HloOpcode::kSubtract, lhs.value(), rhs.value(), - /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kSubtract, lhs, rhs, + GetParam().broadcast_dimensions); + if (inferred_shape.ok()) { + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); + } else { + ASSERT_TRUE(GetParam().error_message.has_value()); + EXPECT_THAT(inferred_shape.status().message(), + HasSubstr(*GetParam().error_message)); + } +} + +TEST_F(ShapeInferenceTest, UnboundedTranspose) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, + ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}")); + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, + ShapeInference::InferTransposeShape( + operand, /*dimensions=*/{4, 0, 3, 2, 1})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedTransposeRank1) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferTransposeShape(operand, /*dimensions=*/{0})); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedTriangularSolve) { + TF_ASSERT_OK_AND_ASSIGN(const Shape a, ParseShape("f32[?, 3, ?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape b, ParseShape("f32[?, ?, 4]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 4]")); + TriangularSolveOptions options; + options.set_left_side(true); + options.set_lower(true); + options.set_unit_diagonal(false); + options.set_transpose_a(TriangularSolveOptions::TRANSPOSE); + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferTriangularSolveShape(a, b, options)); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedTuple) { + TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); + const Shape expected = ShapeUtil::MakeTupleShape({operand}); + TF_ASSERT_OK_AND_ASSIGN( + const Shape result_shape, + ShapeInference::InferVariadicOpShape( + HloOpcode::kTuple, std::vector({&operand}))); + EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(result_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_F(ShapeInferenceTest, UnboundedWhile) { + TF_ASSERT_OK_AND_ASSIGN(const Shape init, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); + TF_ASSERT_OK_AND_ASSIGN( + const Shape inferred_shape, + ShapeInference::InferWhileShape( + /*condition=*/ShapeUtil::MakeProgramShape({result_shape}, pred_), + /*body=*/ShapeUtil::MakeProgramShape({result_shape}, result_shape), + /*init=*/init)); + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) + << " expected: " << ShapeUtil::HumanString(expected); +} + +TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedXor) { + TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); + TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); + const absl::StatusOr inferred_status = + ShapeInference::InferBinaryOpShape(HloOpcode::kXor, lhs, rhs, + GetParam().broadcast_dimensions); if (inferred_status.ok()) { - ASSERT_IS_OK(expected.status()); - ASSERT_TRUE(ShapeUtil::Equal(inferred_status.value(), expected.value())) - << "inferred: " << ShapeUtil::HumanString(inferred_status.value()) - << " expected: " << ShapeUtil::HumanString(expected.value()); + TF_ASSERT_OK_AND_ASSIGN(const Shape expected, + ParseShape(GetParam().expected)); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_status) + << " expected: " << ShapeUtil::HumanString(expected); } else { + ASSERT_TRUE(GetParam().error_message.has_value()); EXPECT_THAT(inferred_status.status().message(), - HasSubstr("Binary op subtract with incompatible shapes")); + HasSubstr(*GetParam().error_message)); } } +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, + UnboundedLogicalOpShapeInferenceTest, + ::testing::ValuesIn( + {// LHS | RHS | bdims | Res + // 1 | ? | [] | ? + {"s32[1]", "s32[?]", {}, "s32[?]"}, + // ? | 1 | [] | ? + {"s32[?]", "s32[1]", {}, "s32[?]"}, + // 2 | ? | [] | 2 + {"s32[2]", "s32[?]", {}, "s32[2]"}, + // ? | 2 | [] | 2 + {"s32[?]", "s32[2]", {}, "s32[2]"}, + // <=2 | ? | [] | <=2 + {"s32[<=2]", "s32[?]", {}, "s32[<=2]"}, + // ? | <=2 | [] | <=2 + {"s32[?]", "s32[<=2]", {}, "s32[<=2]"}, + // ? | ? | [] | ? + {"s32[?]", "s32[?]", {}, "s32[?]"}, + // 1 | ?,3 | [0] | ?,3 + {"s32[1]", "s32[?,3]", zero_array, "s32[?,3]"}, + // 2 | ?,3 | [0] | err + {"s32[2]", "s32[?,3]", zero_array, "", + kBroadcastDimensionMismatchErrorMessage}, + // ?,2 | ?,3 | [] | err + {"s32[?,2]", + "s32[?,3]", + {}, + "", + kIncompatibleBinaryOpShapeErrorMessage}})); + +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedBinaryOpShapeInferenceTest, + ::testing::ValuesIn( + {// LHS | RHS | bdims | Res + // 1 | ? | [] | ? + {"f32[1]", "f32[?]", {}, "f32[?]"}, + // ? | 1 | [] | ? + {"f32[?]", "f32[1]", {}, "f32[?]"}, + // 2 | ? | [] | 2 + {"f32[2]", "f32[?]", {}, "f32[2]"}, + // ? | 2 | [] | 2 + {"f32[?]", "f32[2]", {}, "f32[2]"}, + // <=2 | ? | [] | <=2 + {"f32[<=2]", "f32[?]", {}, "f32[<=2]"}, + // ? | <=2 | [] | <=2 + {"f32[?]", "f32[<=2]", {}, "f32[<=2]"}, + // ? | ? | [] | ? + {"f32[?]", "f32[?]", {}, "f32[?]"}, + // 1 | ?,3 | [0] | ?,3 + {"f32[1]", "f32[?,3]", zero_array, "f32[?,3]"}, + // 2 | ?,3 | [0] | err + {"f32[2]", "f32[?,3]", zero_array, "", + kBroadcastDimensionMismatchErrorMessage}, + // ?,2 | ?,3 | [] | err + {"f32[?,2]", + "f32[?,3]", + {}, + "", + kIncompatibleBinaryOpShapeErrorMessage}})); + +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, + UnboundedCompareOpShapeInferenceTest, + ::testing::ValuesIn( + {// LHS | RHS | bdims | Res + // 1 | ? | [] | ? + {"f32[1]", "f32[?]", {}, "pred[?]"}, + // ? | 1 | [] | ? + {"f32[?]", "f32[1]", {}, "pred[?]"}, + // 2 | ? | [] | 2 + {"f32[2]", "f32[?]", {}, "pred[2]"}, + // ? | 2 | [] | 2 + {"f32[?]", "f32[2]", {}, "pred[2]"}, + // <=2 | ? | [] | <=2 + {"f32[<=2]", "f32[?]", {}, "pred[<=2]"}, + // ? | <=2 | [] | <=2 + {"f32[?]", "f32[<=2]", {}, "pred[<=2]"}, + // ? | ? | [] | ? + {"f32[?]", "f32[?]", {}, "pred[?]"}, + // 1 | ?,3 | [0] | ?,3 + {"f32[1]", "f32[?,3]", zero_array, "pred[?,3]"}, + // 2 | ?,3 | [0] | err + {"f32[2]", "f32[?,3]", zero_array, "", + kBroadcastDimensionMismatchErrorMessage}, + // ?,2 | ?,3 | [] | err + {"f32[?,2]", + "f32[?,3]", + {}, + "", + kIncompatibleBinaryOpShapeErrorMessage}})); + +INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, + UnboundedComplexOpShapeInferenceTest, + ::testing::ValuesIn( + {// LHS | RHS | bdims | Res + // 1 | ? | [] | ? + {"f32[1]", "f32[?]", {}, "c64[?]"}, + // ? | 1 | [] | ? + {"f32[?]", "f32[1]", {}, "c64[?]"}, + // 2 | ? | [] | 2 + {"f32[2]", "f32[?]", {}, "c64[2]"}, + // ? | 2 | [] | 2 + {"f32[?]", "f32[2]", {}, "c64[2]"}, + // <=2 | ? | [] | <=2 + {"f32[<=2]", "f32[?]", {}, "c64[<=2]"}, + // ? | <=2 | [] | <=2 + {"f32[?]", "f32[<=2]", {}, "c64[<=2]"}, + // ? | ? | [] | ? + {"f32[?]", "f32[?]", {}, "c64[?]"}, + // 1 | ?,3 | [0] | ?,3 + {"f32[1]", "f32[?,3]", zero_array, "c64[?,3]"}, + // 2 | ?,3 | [0] | err + {"f32[2]", "f32[?,3]", zero_array, "", + kBroadcastDimensionMismatchErrorMessage}, + // ?,2 | ?,3 | [] | err + {"f32[?,2]", + "f32[?,3]", + {}, + "", + kIncompatibleBinaryOpShapeErrorMessage}})); + +INSTANTIATE_TEST_SUITE_P( + UnboundedDynamism, UnboundedConcatenateOpShapeInferenceTest, + ::testing::Values( + // LHS shape | RHS shape | Result shape (Concat dim is 0) + // [X1, Y] | [X2, Y] | [X1+X2, Y] + std::vector({"f32[2, 3]", "f32[4, 3]", "f32[6, 3]", ""}), + // [X, Y] | [?, ?] | [?, Y] + std::vector({"f32[2, 3]", "f32[?, ?]", "f32[?, 3]", ""}), + // [X1, Y] | [<=X2, <=Y] | [<=X1+X2, <=Y] + std::vector({"f32[4, 3]", "f32[<=2, <=3]", "f32[<=6, <=3]", + ""}), + // [?, ?] | [?, ?] | [?, ?] + std::vector({"f32[?, ?]", "f32[?, ?]", "f32[?, ?]", ""}), + // [?, ?] | [<=B1, <=B2]| [?, <=B2] + std::vector({"f32[?, ?]", "f32[<=2, <=3]", "f32[?, <=3]", + ""}), + // [<=B1, ?] | [<=B2, X] | [<=B1+B2, X] + std::vector({"f32[<=2, ?]", "f32[<=4, 3]", "f32[<=6, 3]", + ""}), + // [X, <=B1] | [X, <=B2] | Error, mismatched + // bound sizes + std::vector( + {"f32[2, <=3]", "f32[2, <=4]", "", + "Cannot concatenate arrays that differ in dimensions other than " + "the one being concatenated. Dimension 1 in both shapes must be " + "equal (or compatible): f32[2,<=3] vs f32[2,<=4]."}), + // [X, Y1] | [X, Y2] | Error, mismatched + // dimension sizes + std::vector( + {"f32[2, 3]", "f32[2, 4]", "", + "Cannot concatenate arrays that differ in dimensions other than " + "the one being concatenated. Dimension 1 in both shapes must be " + "equal (or compatible): f32[2,3] vs f32[2,4]."}))); + +INSTANTIATE_TEST_SUITE_P( + UnboundedDynamism, UnboundedClampOpShapeInferenceTest, + ::testing::Values( + // MIN shape | OPERAND shape | MAX shape | Result + // [] | [?] | [] | [?] + std::vector({"f32[]", "f32[?]", "f32[]", "f32[?]", ""}), + // [] | [?] | [X] | [?] + std::vector({"f32[]", "f32[?]", "f32[2]", "f32[?]", ""}), + // [] | [?] | [<=B] | [?] + std::vector({"f32[]", "f32[?]", "f32[<=2]", "f32[?]", ""}), + // [X] | [?] | [X] | [?] + std::vector({"f32[2]", "f32[?]", "f32[2]", "f32[?]", ""}), + // [?] | [X] | [X] | [X] + std::vector({"f32[?]", "f32[2]", "f32[2]", "f32[2]", ""}), + // [?] | [<=B] | [?] | [<=B] + std::vector({"f32[?]", "f32[<=2]", "f32[?]", "f32[<=2]", + ""}), + // [<=B] | [?] | [<=B] | [?] + std::vector({"f32[<=2]", "f32[?]", "f32[<=2]", "f32[?]", + ""}), + // [?] | [?] | [?] | [?] + std::vector({"f32[?]", "f32[?]", "f32[?]", "f32[?]", ""}), + // [?] | [] | [?] | error + std::vector( + {"f32[?]", "f32[]", "f32[?]", "", + "Clamp with incompatible shapes: f32[?], f32[], f32[?]."}), + // A[] | B[?] | B[?] | error + std::vector( + {"s32[]", "f32[?]", "f32[?]", "", + "Clamp with incompatible element types: s32[], f32[?], f32[?]."}), + // [X] | [<=B] | [X] | error + std::vector( + {"f32[3]", "f32[<=2]", "f32[3]", "", + "Clamp with incompatible shapes: f32[3], f32[<=2], f32[3]."}), + // [X] | [?] | [Y] | error + std::vector( + {"f32[2]", "f32[?]", "f32[3]", "", + "Clamp with incompatible shapes: f32[2], f32[?], f32[3]."}))); + INSTANTIATE_TEST_SUITE_P( - UnboundedDynamism, UnboundedBinaryOpShapeInferenceTest, + UnboundedDynamism, UnboundedSelectOpShapeInferenceTest, ::testing::Values( - // LHS | RHS | Result - // 1 | ? | ? - std::vector({"f32[1]", "f32[?]", "f32[?]"}), - // ? | 1 | ? - std::vector({"f32[?]", "f32[1]", "f32[?]"}), - // 2 | ? | 2 - std::vector({"f32[2]", "f32[?]", "f32[2]"}), - // ? | 2 | 2 - std::vector({"f32[?]", "f32[2]", "f32[2]"}), - // <=2 | ? | <=2 - std::vector({"f32[<=2]", "f32[?]", "f32[<=2]"}), - // ? | <=2 | <=2 - std::vector({"f32[?]", "f32[<=2]", "f32[<=2]"}), - // ? | ? | ? - std::vector({"f32[?]", "f32[?]", "f32[?]"}), - // ?,2 | ?,3 | error - std::vector({"f32[?,2]", "f32[?,3]", ""}))); + // PRED shape | ON_TRUE shape | ON_FALSE shape | Result + // [] | [?] | [X] | [X] + std::vector({"pred[]", "f32[?]", "f32[2]", "f32[2]", ""}), + // [] | [?] | [<=B] | [<=B] + std::vector({"pred[]", "f32[?]", "f32[<=2]", "f32[<=2]", + ""}), + // [X] | [?] | [X] | [X] + std::vector({"pred[2]", "f32[?]", "f32[2]", "f32[2]", ""}), + // [?] | [X] | [X] | [X] + std::vector({"pred[?]", "f32[2]", "f32[?]", "f32[2]", ""}), + // [?] | [<=B] | [?] | [<=B] + std::vector({"pred[?]", "f32[<=2]", "f32[?]", "f32[<=2]", + ""}), + // [<=B] | [?] | [<=B] | [<=B] + std::vector({"pred[<=2]", "f32[?]", "f32[<=2]", "f32[<=2]", + ""}), + // [?] | [?] | [?] | [?] + std::vector({"pred[?]", "f32[?]", "f32[?]", "f32[?]", ""}), + // [X] | A[X] | B[X] | error + std::vector({"pred[3]", "s32[3]", "f32[3]", "", + "Operands to select must be the same shape; " + "got s32[3] and f32[3]."}), + // [X] | [?] | [<=B] | error + std::vector( + {"pred[3]", "f32[?]", "f32[<=2]", "", + "Operands to select and predicate must be the same shape; got " + "f32[?] and f32[<=2] and pred[3]."}), + // [X] | [<=B] | [X] | error + std::vector({"pred[3]", "f32[<=2]", "f32[3]", "", + "Operands to select must be the same shape; " + "got f32[<=2] and f32[3]."}), + // [X] | [?] | [Y] | error + std::vector( + {"pred[2]", "f32[?]", "f32[3]", "f32[3]", + "Operands to select and predicate must be the same shape; got " + "f32[?] and f32[3] and pred[2]."}), + // [?] | [] | [] | error + std::vector( + {"pred[?]", "f32[]", "f32[]", "", + "Operands to select and predicate must be the same shape; got " + "f32[] and f32[] and pred[?]."}), + // [] | [?] | [] | error + std::vector({"pred[]", "f32[?]", "f32[]", "", + "Operands to select must be the same shape; " + "got f32[?] and f32[]."}))); INSTANTIATE_TEST_SUITE_P(UnboundedDynamism, UnboundedUnaryOpShapeInferenceTest, - ::testing::Values( - // OPERAND | Result - // 1 | 1 - std::vector({"f32[1]", "f32[1]"}), - // 2 | 2 - std::vector({"f32[2]", "f32[2]"}), - // <=2 | <=2 - std::vector({"f32[<=2]", "f32[<=2]"}), - // ? | ? - std::vector({"f32[?]", "f32[?]"}), - // ?,3 | ?,3 - std::vector({"f32[?,3]", - "f32[?,3]"}))); + ::testing::ValuesIn( + {{"f32[?]", "f32[?]", HloOpcode::kAbs}, + {"f32[?]", "f32[?]", HloOpcode::kCbrt}, + {"f32[?]", "f32[?]", HloOpcode::kCeil}, + {"u32[?]", "u32[?]", HloOpcode::kClz}, + {"f32[?]", "f32[?]", HloOpcode::kCos}, + {"f32[?]", "f32[?]", HloOpcode::kErf}, + {"f32[?]", "f32[?]", HloOpcode::kExp}, + {"f32[?]", "f32[?]", HloOpcode::kExpm1}, + {"f32[?]", "f32[?]", HloOpcode::kFloor}, + {"f32[?]", "f32[?]", HloOpcode::kImag}, + {"f32[?]", "pred[?]", HloOpcode::kIsFinite}, + {"f32[?]", "f32[?]", HloOpcode::kLog}, + {"f32[?]", "f32[?]", HloOpcode::kLog1p}, + {"f32[?]", "f32[?]", HloOpcode::kLogistic}, + {"f32[?]", "f32[?]", HloOpcode::kNegate}, + {"s32[?]", "s32[?]", HloOpcode::kNot}, + {"u32[?]", "u32[?]", HloOpcode::kPopulationCount}, + {"f32[?]", "f32[?]", HloOpcode::kReal}, + {"f32[?]", "f32[?]", HloOpcode::kRoundNearestAfz}, + {"f32[?]", "f32[?]", + HloOpcode::kRoundNearestEven}, + {"f32[?]", "f32[?]", HloOpcode::kRsqrt}, + {"f32[?]", "f32[?]", HloOpcode::kSign}, + {"f32[?]", "f32[?]", HloOpcode::kSin}, + {"f32[?]", "f32[?]", HloOpcode::kSqrt}, + {"f32[?]", "f32[?]", HloOpcode::kTanh}})); } // namespace } // namespace xla diff --git a/xla/service/shaped_buffer.cc b/xla/service/shaped_buffer.cc index e01132ed8d24c..a429ff8b3819a 100644 --- a/xla/service/shaped_buffer.cc +++ b/xla/service/shaped_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,7 @@ ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { ShapedBuffer::~ShapedBuffer() {} -StatusOr ShapedBuffer::SubShapedBuffer( +absl::StatusOr ShapedBuffer::SubShapedBuffer( const ShapeIndex& index) const { TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, ShapeUtil::TryGetSubshape(on_device_shape(), index)); diff --git a/xla/service/shaped_buffer.h b/xla/service/shaped_buffer.h index e3e32001d10cd..3882241e74670 100644 --- a/xla/service/shaped_buffer.h +++ b/xla/service/shaped_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -115,7 +115,7 @@ class ShapedBuffer { const ShapeTree& buffers() const { return buffers_; } ShapeTree& buffers() { return buffers_; } - StatusOr SubShapedBuffer(const ShapeIndex& index) const; + absl::StatusOr SubShapedBuffer(const ShapeIndex& index) const; // Set all device memory pointers in the object to null. void clear(); diff --git a/xla/service/shaped_buffer_test.cc b/xla/service/shaped_buffer_test.cc index b40a5ab3148f0..08ebed2707760 100644 --- a/xla/service/shaped_buffer_test.cc +++ b/xla/service/shaped_buffer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,9 +56,9 @@ class TestAllocator : public se::DeviceMemoryAllocator { // Pull in two-arg overload of Allocate. using se::DeviceMemoryAllocator::Allocate; - StatusOr Allocate(int device_ordinal, uint64_t size, - bool /*retry_on_failure*/, - int64_t /*memory_space*/) override { + absl::StatusOr Allocate( + int device_ordinal, uint64_t size, bool /*retry_on_failure*/, + int64_t /*memory_space*/) override { // By contract, we must return null if size == 0. if (size == 0) { return se::OwningDeviceMemory(); @@ -86,7 +86,7 @@ class TestAllocator : public se::DeviceMemoryAllocator { bool AllowsAsynchronousDeallocation() const override { return false; } - StatusOr GetStream(int device_ordinal) override { + absl::StatusOr GetStream(int device_ordinal) override { LOG(FATAL) << "Not implemented"; } diff --git a/xla/service/sharding_format_picker.cc b/xla/service/sharding_format_picker.cc index 5fd1219075881..13e0a244853de 100644 --- a/xla/service/sharding_format_picker.cc +++ b/xla/service/sharding_format_picker.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -164,7 +164,7 @@ std::unique_ptr MaybeConvertToV1(const HloSharding& sharding) { } // namespace -StatusOr ShardingFormatPicker::Run( +absl::StatusOr ShardingFormatPicker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/sharding_format_picker.h b/xla/service/sharding_format_picker.h index 6cb2ac1938b28..1ebd96520c97b 100644 --- a/xla/service/sharding_format_picker.h +++ b/xla/service/sharding_format_picker.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ class ShardingFormatPicker : public HloModulePass { : sharding_type_(sharding_type) {} absl::string_view name() const override { return "sharding-format-picker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index 052c2eba7615a..f149da7505860 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -73,31 +73,49 @@ bool IsSpatiallyPartitioned(const HloInstruction* hlo) { return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding()); } -// We think manual shardings are strictly better than tile maximal shardings. -bool IsShardingStrictlyBetter(const HloSharding& lhs, const HloSharding& rhs) { - CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs; - if (lhs.IsTuple()) { - // For tuples we consider lhs to have a better sharding if none of the - // elements are worse and at least one element is better then in rhs - // sharding. - const auto& lhs_shardings = lhs.tuple_elements(); - const auto& rhs_shardings = rhs.tuple_elements(); - CHECK_EQ(lhs_shardings.size(), rhs_shardings.size()); - bool is_better = false; - for (int64_t i = 0; i < lhs_shardings.size(); ++i) { - if (IsShardingStrictlyBetter(rhs_shardings[i], lhs_shardings[i])) { - return false; +// Returns +// - 1, iff `lhs` is strictly better than `rhs`. +// - 2, iff `rhs` is strictly better than `lhs`. +// - 0 or 3, otherwise. +// +// Notes: +// - We think manual shardings are strictly better than tile maximal shardings. +// - For tuples we consider lhs to have a better sharding if none of the +// elements are worse and at least one element is better then in rhs +// sharding. +int MaskTupleShardingStrictlyBetter(const HloSharding& lhs, + const HloSharding& rhs) { + DCHECK(lhs.IsTuple()); + DCHECK(rhs.IsTuple()); + const auto& lhs_shardings = lhs.tuple_elements(); + const auto& rhs_shardings = rhs.tuple_elements(); + CHECK_EQ(lhs_shardings.size(), rhs_shardings.size()); + int mask = 0; + for (int64_t i = 0; i < lhs_shardings.size(); ++i) { + const auto& lhs_shard = lhs_shardings[i]; + const auto& rhs_shard = rhs_shardings[i]; + CHECK_EQ(lhs_shard.IsTuple(), rhs_shard.IsTuple()); + if (lhs_shard.IsTuple()) { + mask |= MaskTupleShardingStrictlyBetter(lhs_shard, rhs_shard); + } else { + if (lhs_shard.IsManualLeaf() && rhs_shard.IsTileMaximalLeaf()) { + mask |= 1; } - if (IsShardingStrictlyBetter(lhs_shardings[i], rhs_shardings[i])) { - is_better = true; + if (rhs_shard.IsManualLeaf() && lhs_shard.IsTileMaximalLeaf()) { + mask |= 2; } } - return is_better; + if (mask == 3) break; } - if (lhs.IsManual() && rhs.IsTileMaximal()) { - return true; + return mask; +} + +bool IsShardingStrictlyBetter(const HloSharding& lhs, const HloSharding& rhs) { + CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs; + if (lhs.IsTuple()) { + return MaskTupleShardingStrictlyBetter(lhs, rhs) == 1; } - return false; + return lhs.IsManualLeaf() && rhs.IsTileMaximalLeaf(); } // Implementation for returning a improved sharding from another sharding. @@ -107,15 +125,15 @@ std::optional ReturnImprovedShardingImpl( bool allow_aggressive_resharding = false) { // Always allow improve the sharding if it's straightly better. if (to_improved != nullptr && IsShardingStrictlyBetter(from, *to_improved)) { - return from; + return std::move(from); } // We don't want to propagate tile maximal shardings. if (!IsSpatiallyPartitioned(from)) { return std::nullopt; } - // Any sharding is better then no sharding. + // Any sharding is better than no sharding. if (to_improved == nullptr) { - return from; + return std::move(from); } // We don't want to propagate manual shardings. if (from.IsManual()) { @@ -137,7 +155,7 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } } - return from; + return std::move(from); } return std::nullopt; } @@ -225,7 +243,7 @@ bool MaybeImproveInstructionSubSharding( } // We consider a convolution kernel to be small iff it is smaller along all -// spatial dimensions then the output of the convolution. The rational is that +// spatial dimensions than the output of the convolution. The rational is that // we can either shard the kernel or the output and we want to shard the larger // one for better efficiency. bool IsConvolutionKernelSmall(const HloInstruction* instruction) { @@ -303,8 +321,10 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kAllReduce: case HloOpcode::kReduceScatter: case HloOpcode::kAllToAll: + case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: + case HloOpcode::kErf: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -522,7 +542,8 @@ std::optional LookaheadUserSharding(HloInstruction* instr, HloInstruction* current = users_chain[i - 1]; CHECK(user->has_sharding()); sharding = ShardingPropagation::GetShardingFromUser( - *current, *user, INT64_MAX, is_spmd, call_graph); + *current, *user, INT64_MAX, is_spmd, call_graph, + /*sharding_helper=*/nullptr); // We need to set the sharding to the instruction, because // GetShardingFromUser() interface uses sharding from the instruction // itself. It will be cleared out later. @@ -1120,7 +1141,8 @@ bool InferUnspecifiedDimsFromOneUser(HloInstruction* annotate_op, std::optional user_sharding = ShardingPropagation::GetShardingFromUser( man_conversion_op == nullptr ? *annotate_op : *man_conversion_op, - *user, aggressiveness, is_spmd, call_graph); + *user, aggressiveness, is_spmd, call_graph, + /*sharding_helper=*/nullptr); if (!user_sharding.has_value() || user_sharding->IsTileMaximal()) { return false; } @@ -1415,11 +1437,12 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, instruction, may_combine_partial_sharding); } - // If the kernel is large (e.g backward convolution) then we only support - // replicated output. + // If the kernel is large (e.g., backward convolution) then we only support + // replicated output. We intend to keep the sharding along the batch dimension + // between lhs and output. return MaybeImproveInstructionSharding( - hlo_sharding_util::ReplicateAllDataDims(lhs->sharding(), - instruction->shape().rank()), + hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( + lhs->sharding(), {dnums.input_batch_dimension()}), instruction, may_combine_partial_sharding); } @@ -1513,7 +1536,7 @@ bool InferReduceShardingFromOperand(HloInstruction* instruction, // copy node for reshard. // `unspecified_dims` will be populated with the converted copies if the custom // call is partially specified. -StatusOr ProcessShardingInstruction( +absl::StatusOr ProcessShardingInstruction( HloModule* module, const absl::flat_hash_set& execution_threads, bool replace_sharding_with_copy, @@ -1526,14 +1549,16 @@ StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_as_group, absl::flat_hash_map>* - shard_group_id_to_shard_like_group) { + shard_group_id_to_shard_like_group, + const std::vector* + allow_spmd_sharding_propagation_to_parameters_vector) { bool changed = false; const bool use_shard_group = instruction_to_shard_group_id && shard_group_id_to_shard_as_group && shard_group_id_to_shard_like_group; auto process_shard_group_instruction = [&](HloInstruction* instruction, - const HloSharding& sharding) { + HloSharding sharding) { if (use_shard_group && sharding.IsShardGroup()) { // Store shard group relations. const int64_t shard_group_id = sharding.GetShardGroup().shard_group_id; @@ -1542,7 +1567,8 @@ StatusOr ProcessShardingInstruction( auto& shard_as_group = (*shard_group_id_to_shard_as_group)[shard_group_id]; if (!shard_as_group.empty()) { - CHECK_EQ(instruction->shape(), (*shard_as_group.begin())->shape()) + CHECK(ShapeUtil::SameDimensions(instruction->shape(), + (*shard_as_group.begin())->shape())) << "Instruction: " << instruction->ToString() << " has different shape from the shapes of the other " "instructions within the same shard_as group: " @@ -1553,7 +1579,8 @@ StatusOr ProcessShardingInstruction( auto& shard_like_group = (*shard_group_id_to_shard_like_group)[shard_group_id]; if (!shard_like_group.empty()) { - CHECK_EQ(instruction->shape(), (*shard_like_group.begin())->shape()) + CHECK(ShapeUtil::SameDimensions(instruction->shape(), + (*shard_like_group.begin())->shape())) << "Instruction: " << instruction->ToString() << " has different shape from the shapes of the other " "instructions within the same shard_like group: " @@ -1561,29 +1588,25 @@ StatusOr ProcessShardingInstruction( } shard_like_group.insert(instruction); } + sharding.ClearShardGroup(); } + return sharding; }; for (HloComputation* computation : module->computations(execution_threads)) { auto instructions = computation->MakeInstructionPostOrder(); for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { HloInstruction* instruction = *it; if (instruction->IsCustomCall("Sharding")) { + HloSharding original_sharding = instruction->sharding(); TF_RET_CHECK(instruction->has_sharding()) << "Sharding instruction must have a sharding attribute"; VLOG(3) << "ProcessShardingInstruction: " << instruction->ToString(); - HloSharding sharding = instruction->sharding(); std::vector unspec_dims; TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes( Cast(instruction)->opaque(), &unspec_dims)); - // Add operand(i.e. the annotated op) into shard group. - process_shard_group_instruction(instruction->mutable_operand(0), - sharding); - // Strip the sharding of the shard group related annotations. - sharding.ClearShardGroup(); - // Replace it with a copy node so that it does not need special // handling. if (replace_sharding_with_copy) { @@ -1592,21 +1615,35 @@ StatusOr ProcessShardingInstruction( instruction->mutable_operand(0))); TF_RETURN_IF_ERROR( computation->ReplaceInstruction(instruction, copy)); + // Add into shard group. + HloSharding sharding = + process_shard_group_instruction(copy, original_sharding); copy->set_sharding(sharding); instruction = copy; changed = true; } + // Strip the sharding of the shard group related annotations. if (!unspec_dims.empty()) { absl::c_sort(unspec_dims); unspecified_dims->emplace(instruction, std::move(unspec_dims)); } else if (!instruction->operand(0)->has_sharding()) { + HloSharding sharding = original_sharding; + if (instruction->operand(0)->opcode() != HloOpcode::kParameter || + (allow_spmd_sharding_propagation_to_parameters_vector && + allow_spmd_sharding_propagation_to_parameters_vector->size() == + module->entry_computation()->num_parameters() && + allow_spmd_sharding_propagation_to_parameters_vector->at( + instruction->operand(0)->parameter_number()))) { + // Add operand(i.e. the annotated op) into shard group. + sharding = process_shard_group_instruction( + instruction->mutable_operand(0), sharding); + } instruction->mutable_operand(0)->set_sharding(std::move(sharding)); } } else if (instruction->has_sharding()) { // Handle shard group in parameters/outputs. - process_shard_group_instruction(instruction, instruction->sharding()); - HloSharding sharding = instruction->sharding(); - sharding.ClearShardGroup(); + HloSharding sharding = process_shard_group_instruction( + instruction, instruction->sharding()); instruction->set_sharding(std::move(sharding)); } } @@ -1685,7 +1722,8 @@ int64_t ComputeNonRootUsers(const HloInstruction* instr) { // Return the sharding that should be propagated from user to instruction. std::optional ShardingPropagation::GetShardingFromUser( const HloInstruction& instruction, const HloInstruction& user, - int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph) { + int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph, + const CustomCallShardingHelper* sharding_helper) { if (!CanPropagateThroughAtAggressiveLevel(user, aggressiveness)) { return std::nullopt; } @@ -2039,6 +2077,23 @@ std::optional ShardingPropagation::GetShardingFromUser( } return std::nullopt; } + case HloOpcode::kCustomCall: { + bool compatible_shapes = ShapeUtil::CompatibleIgnoringElementType( + instruction.shape(), user.shape()); + if (!compatible_shapes) { + // Incompatible shapes, we will not propagate sharding. + return std::nullopt; + } + if (!sharding_helper) { + // No available sharding helper and shapes are compatible, we will + // propagate sharding. + return user.sharding(); + } + if (sharding_helper->CanPropagateShardingToOperands(&user)) { + return user.sharding(); + } + return std::nullopt; + } default: { // If the user output shape is compatible with the current instruction // shape excluding element type and the current instruction is supported @@ -2146,13 +2201,6 @@ bool ShardingPropagation::InferShardingFromShardGroup( return true; } } - if (!SupportSpatialPartitioning( - instruction, computation_map, is_spmd_, - allow_spmd_sharding_propagation_to_output_, - allow_spmd_sharding_propagation_to_parameters_, - sharding_helper_.get())) { - return false; - } const bool may_combine_partial_sharding = is_spmd_ && aggressiveness > 0; bool changed = false; @@ -2773,7 +2821,8 @@ bool ShardingPropagation::InferShardingFromUsers( } else { std::optional user_sharding = ShardingPropagation::GetShardingFromUser( - *instruction, *user, aggressiveness, is_spmd, call_graph); + *instruction, *user, aggressiveness, is_spmd, call_graph, + sharding_helper); if (user_sharding && user_sharding->IsManual()) { instruction->set_sharding(std::move(*user_sharding)); return true; @@ -2792,8 +2841,9 @@ bool ShardingPropagation::InferShardingFromUsers( const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { std::optional user_sharding = - ShardingPropagation::GetShardingFromUser( - *instruction, *user, aggressiveness, is_spmd, call_graph); + ShardingPropagation::GetShardingFromUser(*instruction, *user, + aggressiveness, is_spmd, + call_graph, sharding_helper); if (user_sharding && instruction->opcode() == HloOpcode::kCustomCall) { if (auto* partitioner = GetCustomCallPartitioner(instruction->custom_call_target())) { @@ -2814,23 +2864,27 @@ bool ShardingPropagation::InferShardingFromUsers( return improved_sharding; } -Status ShardingPropagation::CanonicalizeLayouts(HloModule* module) { - if (!allow_spmd_sharding_propagation_to_output_) { - return OkStatus(); - } - if (!module->layout_canonicalization_callback()) { - LOG(INFO) << "There is no registered layout_canonicalization_callback."; - return OkStatus(); +Status SetParameterShapes( + HloModule* module, const std::vector& parameter_shapes, + const std::vector& + allow_spmd_sharding_propagation_to_parameters_vector) { + for (int64_t i = 0; i < module->entry_computation()->num_parameters(); ++i) { + if (!allow_spmd_sharding_propagation_to_parameters_vector[i]) { + continue; + } + TF_RETURN_IF_ERROR(module->mutable_config() + .mutable_entry_computation_layout() + ->mutable_parameter_layout(i) + ->CopyLayoutFromShape(parameter_shapes[i])); } - // If the result layout is automatically set, allow layout assignment to - // choose the layout. + return OkStatus(); +} + +Status SetResultShape(HloModule* module, const Shape& result_shape) { if (!module->entry_computation_layout().LayoutIsSet() || !module->entry_computation_layout().result_layout().LayoutIsSet()) { return OkStatus(); } - TF_ASSIGN_OR_RETURN(auto layouts, - module->layout_canonicalization_callback()(*module)); - Shape& result_shape = layouts.second; TF_RETURN_IF_ERROR(module->mutable_config() .mutable_entry_computation_layout() ->mutable_result_layout() @@ -2838,7 +2892,29 @@ Status ShardingPropagation::CanonicalizeLayouts(HloModule* module) { return OkStatus(); } -StatusOr ShardingPropagation::Run( +Status ShardingPropagation::CanonicalizeLayouts(HloModule* module) { + if (!allow_spmd_sharding_propagation_to_output_ && + !allow_spmd_sharding_propagation_to_parameters_) { + return OkStatus(); + } + if (!module->layout_canonicalization_callback()) { + LOG(INFO) << "There is no registered layout_canonicalization_callback."; + return OkStatus(); + } + TF_ASSIGN_OR_RETURN(auto shapes_with_layout, + module->layout_canonicalization_callback()(*module)); + if (allow_spmd_sharding_propagation_to_parameters_) { + TF_RETURN_IF_ERROR(SetParameterShapes( + module, shapes_with_layout.first, + allow_spmd_sharding_propagation_to_parameters_vector_)); + } + if (allow_spmd_sharding_propagation_to_output_) { + TF_RETURN_IF_ERROR(SetResultShape(module, shapes_with_layout.second)); + } + return OkStatus(); +} + +absl::StatusOr ShardingPropagation::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { std::optional> @@ -2891,7 +2967,8 @@ StatusOr ShardingPropagation::Run( ? &saved_parameter_shardings : nullptr, &instruction_to_shard_group_id, &shard_group_id_to_shard_as_group, - &shard_group_id_to_shard_like_group)); + &shard_group_id_to_shard_like_group, + &allow_spmd_sharding_propagation_to_parameters_vector_)); any_changed |= changed; // Check sizes of the given allow_spmd_sharding_propagation vectors if (allow_spmd_sharding_propagation_to_output_) { @@ -2908,8 +2985,18 @@ StatusOr ShardingPropagation::Run( "computation."; } if (allow_spmd_sharding_propagation_to_parameters_) { + auto is_same_sized_tuple = [](HloModule* module, int64_t size) { + if (module->entry_computation()->num_parameters() != 1) { + return false; + } + HloInstruction* param = + module->entry_computation()->parameter_instruction(0); + return param->shape().IsTuple() && + size == param->shape().tuple_shapes_size(); + }; auto size = allow_spmd_sharding_propagation_to_parameters_vector_.size(); - CHECK(size == 1 || size == module->entry_computation()->num_parameters()) + CHECK(size == 1 || size == module->entry_computation()->num_parameters() || + is_same_sized_tuple(module, size)) << "allow-spmd-sharding-propagation-to-parameters-vector's size can be " "either 1 or the number of parameters in the entry computation."; } @@ -3047,7 +3134,12 @@ StatusOr ShardingPropagation::Run( } } - if (!allow_spmd_sharding_propagation_to_output_) { + if (!allow_spmd_sharding_propagation_to_output_ && + (!module->entry_computation()->root_instruction()->has_sharding() || + !module->entry_computation() + ->root_instruction() + ->sharding() + .IsUnknown())) { // Consider the root instruction of the entry module as one with provided // sharding as its sharding have to match with the one expected by the host. provided_shardings.insert(module->entry_computation()->root_instruction()); @@ -3055,7 +3147,7 @@ StatusOr ShardingPropagation::Run( if (!allow_spmd_sharding_propagation_to_parameters_) { for (auto param : module->entry_computation()->parameter_instructions()) { - if (param->has_sharding()) { + if (param->has_sharding() && !param->sharding().IsUnknown()) { provided_shardings.insert(param); } } @@ -3127,8 +3219,8 @@ StatusOr ShardingPropagation::Run( } } }; - // Firstly, iterate the shard groups to take shardings from instructions - // of the same group. + // 1. Iterate the shard groups to take shardings from instructions of + // the same group. for (HloInstruction* instruction : instructions) { if (already_inferred_from_shard_group.contains(instruction)) { continue; @@ -3161,7 +3253,7 @@ StatusOr ShardingPropagation::Run( changed_last_iter = true; } } - // Secondly, iterate the HLO graph in post order taking shardings from + // 2. Iterate the HLO graph in post order taking shardings from // operands. for (HloInstruction* instruction : instructions) { if (already_inferred_from_operands.contains(instruction)) { @@ -3202,8 +3294,8 @@ StatusOr ShardingPropagation::Run( changed_last_iter = true; } } - // Then iterate the HLO graph in reverse post order taking shardings - // from users. + // 3. Iterate the HLO graph in reverse post order taking shardings from + // users. for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { if ((*it)->IsCustomCall("SPMDFullToShardShape") || (*it)->IsCustomCall("SPMDShardToFullShape")) { @@ -3283,10 +3375,9 @@ StatusOr ShardingPropagation::Run( // get the most specific sharding. If some of them are not compatible, then // it will just choose the a random sharding among them(say the first one). std::vector shardings; - absl::c_transform(shard_as_group, std::back_inserter(shardings), - [](const HloInstruction* instruction) { - return instruction->sharding(); - }); + for (HloInstruction* instruction : shard_as_group) { + shardings.push_back(instruction->sharding()); + } HloSharding common_sharding = hlo_sharding_util::FindCommonSharding(shardings); VLOG(2) << "Aligning shard group: " << shard_as_group_id @@ -3337,18 +3428,116 @@ StatusOr ShardingPropagation::Run( root_instruction->set_sharding(std::move(root_sharding)); } auto params = module->entry_computation()->parameter_instructions(); - if (allow_spmd_sharding_propagation_to_parameters_ && - allow_spmd_sharding_propagation_to_parameters_vector_.size() == - params.size()) { - for (int64_t i = 0; i < params.size(); ++i) { - if (!allow_spmd_sharding_propagation_to_parameters_vector_[i]) { - if (saved_parameter_shardings.contains(i) && - !saved_parameter_shardings.at(i).IsUnknown()) { - params[i]->set_sharding(saved_parameter_shardings.at(i)); - } else { - params[i]->clear_sharding(); + if (allow_spmd_sharding_propagation_to_parameters_) { + if (allow_spmd_sharding_propagation_to_parameters_vector_.size() == + params.size()) { + for (int64_t i = 0; i < params.size(); ++i) { + if (!allow_spmd_sharding_propagation_to_parameters_vector_[i]) { + if (saved_parameter_shardings.contains(i) && + !saved_parameter_shardings.at(i).IsUnknown()) { + params[i]->set_sharding(saved_parameter_shardings.at(i)); + } else { + params[i]->clear_sharding(); + } + } + } + } else if (params.size() == 1 && saved_parameter_shardings.size() == 1 && + params[0]->shape().IsTuple() && + params[0]->shape().tuple_shapes_size() == + allow_spmd_sharding_propagation_to_parameters_vector_ + .size()) { + // There is a single parameter which is a tuple with many elements. + HloSharding param_sharding = params[0]->sharding(); + for (int64_t i = 0; i < params[0]->shape().tuple_shapes_size(); ++i) { + HloSharding saved_subsharding = + saved_parameter_shardings.at(0).GetSubSharding(params[0]->shape(), + {i}); + if (!allow_spmd_sharding_propagation_to_parameters_vector_[i] && + !saved_subsharding.IsUnknown()) { + param_sharding.tuple_elements()[i] = saved_subsharding; + } + } + params[0]->set_sharding(std::move(param_sharding)); + } + } + // Replicate the parameter/output sharding if the propagated sharding does not + // evenly partition the parameter/output. + std::function evenly_partitions = + [&evenly_partitions](const Shape& shape, + const HloSharding& sharding) -> bool { + if (!sharding.IsTiled()) { + return true; + } + if (sharding.IsTileMaximal()) { + return sharding.IsReplicated(); + } + if (sharding.IsTuple()) { + for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (!evenly_partitions(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))) { + return false; + } + } + } + for (int64_t i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } + } + return true; + }; + if (allow_spmd_sharding_propagation_to_output_ && + root_instruction->has_sharding()) { + if (root_instruction->shape().IsTuple() && + allow_spmd_sharding_propagation_to_output_vector_.size() == + root_instruction->shape().tuple_shapes_size()) { + // The output shape is a tuple and sharding propagation is allowed for at + // least one of its elements. + HloSharding root_sharding = root_instruction->sharding(); + for (int64_t i = 0; i < root_instruction->shape().tuple_shapes_size(); + ++i) { + if (allow_spmd_sharding_propagation_to_output_vector_[i] && + !evenly_partitions(root_instruction->shape().tuple_shapes(i), + root_sharding.tuple_elements()[i])) { + root_sharding.tuple_elements()[i] = HloSharding::Replicate(); + } + } + root_instruction->set_sharding(std::move(root_sharding)); + } else if (!root_instruction->shape().IsTuple()) { + // The output shape is not tuple and sharding propagation is allowed. + if (!evenly_partitions(root_instruction->shape(), + root_instruction->sharding())) { + root_instruction->set_sharding(HloSharding::Replicate()); + } + } + } + if (allow_spmd_sharding_propagation_to_parameters_) { + // Sharding propagation is allowed for at least one parameter. + if (allow_spmd_sharding_propagation_to_parameters_vector_.size() == + params.size()) { + for (int64_t i = 0; i < params.size(); ++i) { + if (params[i]->has_sharding() && + allow_spmd_sharding_propagation_to_parameters_vector_[i] && + !evenly_partitions(params[i]->shape(), params[i]->sharding())) { + params[i]->set_sharding(HloSharding::Replicate()); + } + } + } else if (params.size() == 1 && params[0]->shape().IsTuple() && + params[0]->has_sharding() && + params[0]->shape().tuple_shapes_size() == + allow_spmd_sharding_propagation_to_parameters_vector_ + .size()) { + HloSharding param_sharding = params[0]->sharding(); + for (int64_t i = 0; i < params[0]->shape().tuple_shapes_size(); ++i) { + if (allow_spmd_sharding_propagation_to_parameters_vector_[i] && + !evenly_partitions( + ShapeUtil::GetSubshapeOneIndex(params[0]->shape(), i), + params[0]->sharding().GetSubSharding(params[0]->shape(), + {i}))) { + param_sharding.tuple_elements()[i] = HloSharding::Replicate(); } } + params[0]->set_sharding(std::move(param_sharding)); } } TF_RETURN_IF_ERROR(CanonicalizeLayouts(module)); diff --git a/xla/service/sharding_propagation.h b/xla/service/sharding_propagation.h index 82aa10c6deccc..cdede31734542 100644 --- a/xla/service/sharding_propagation.h +++ b/xla/service/sharding_propagation.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, // operand's existing sharding. // unspecified_dims will be populated with the converted copies if the custom // call is partially specified. -StatusOr ProcessShardingInstruction( +absl::StatusOr ProcessShardingInstruction( HloModule* module, const absl::flat_hash_set& execution_threads, bool replace_sharding_with_copy, @@ -70,7 +70,9 @@ StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_as_group = nullptr, absl::flat_hash_map>* - shard_group_id_to_shard_like_group = nullptr); + shard_group_id_to_shard_like_group = nullptr, + const std::vector* + allow_spmd_sharding_propagation_to_parameters_vector = nullptr); int64_t ComputeNonRootUsers(const HloInstruction* instr); @@ -120,7 +122,7 @@ class ShardingPropagation : public HloModulePass { } absl::string_view name() const override { return "sharding-propagation"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -133,7 +135,8 @@ class ShardingPropagation : public HloModulePass { static std::optional GetShardingFromUser( const HloInstruction& instruction, const HloInstruction& user, - int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph); + int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph, + const CustomCallShardingHelper* sharding_helper); // Canonicalizes entry_computation_layouts by calling // module.layout_canonicalization_callback(), which gives canolicalized diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index e3cdf176be513..f4a97ff8c1dd5 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,17 +19,28 @@ limitations under the License. #include #include +#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_op_metadata.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/transforms/hlo_constant_splitter.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/protobuf_util.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; @@ -3108,6 +3119,41 @@ ENTRY entry { } } +TEST_P(ParameterizedMetadataTest, ConvolutionDataParallelism) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[256,512,16,32] parameter(0), sharding={devices=[2,2,2,2]<=[16] metadata={op_name="lhs_sharding"}} + p1 = f32[512,1,12,28] parameter(1), sharding={replicated metadata={op_name="rhs_sharding"}} + conv = f32[256,512,5,5] convolution(p0, p1), window={size=12x28}, dim_labels=bf01_oi01->bf01, feature_group_count=512 + ROOT copy = f32[256,512,5,5] copy(conv) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + auto* instruction = FindInstruction(module.get(), "conv"); + ASSERT_NE(instruction, nullptr); + EXPECT_THAT( + instruction, + op::Sharding("{devices=[2,1,1,1,8]<=[16] last_tile_dim_replicate}")); + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(instruction->sharding(), + ShardingMetadata({CreateMetadata("lhs_sharding")})); + } else { + EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); + } +} + TEST_P(ParameterizedMetadataTest, ConcatFromUserUnshardedDim) { const char* const hlo_string = R"( HloModule module @@ -9260,6 +9306,7 @@ ENTRY %entry { HloConstantSplitter(/*split_expressions=*/true).Run(module.get())); EXPECT_TRUE(is_split); TF_ASSERT_OK_AND_ASSIGN(auto _, HloDCE().Run(module.get())); + (void)_; // Suppress unused variable warning in OSS TF_ASSERT_OK_AND_ASSIGN( bool changed, ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true) @@ -10072,6 +10119,81 @@ ENTRY %entry { op::Sharding("{devices=[4]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PropagateToTupleParameter_WithoutSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param = (f32[4], f32[4]) parameter(0) + %gte0 = f32[4] get-tuple-element(%param), index=0 + %gte1 = f32[4] get-tuple-element(%param), index=1 + ROOT %add = f32[4] add(%gte0, %gte1), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{{devices=[4]0,1,2,3}, {devices=[4]0,1,2,3}}")); +} + +TEST_F(ShardingPropagationTest, PropagateToTupleParameter_WithSharding1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param = (f32[4], f32[4]) parameter(0), sharding={{replicated}, {replicated}} + %gte0 = f32[4] get-tuple-element(%param), index=0 + %gte1 = f32[4] get-tuple-element(%param), index=1 + ROOT %add = f32[4] add(%gte0, %gte1), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{{replicated}, {devices=[4]0,1,2,3}}")); +} + +TEST_F(ShardingPropagationTest, PropagateToTupleParameter_WithSharding2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param = (f32[4], f32[4]) parameter(0), sharding={{replicated}, {replicated}} + %gte0 = f32[4] get-tuple-element(%param), index=0 + %gte1 = f32[4] get-tuple-element(%param), index=1 + ROOT %add = f32[4] add(%gte0, %gte1), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true, false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{{devices=[4]0,1,2,3}, {replicated}}")); +} + TEST_F(ShardingPropagationTest, PropagateManualOutfeed) { const char* const hlo_string = R"( HloModule module @@ -10138,6 +10260,211 @@ ENTRY %entry { op::Sharding("{devices=[4]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, + DoNotPropagateToParameterIfNotDivisible_WithSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[3] parameter(1), sharding={replicated} + %pad_value = f32[] constant(0) + %pad = f32[4] pad(%param1, %pad_value), padding=0_1 + ROOT %add = f32[4] add(%param0, %pad), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{replicated}")); + // Replicate the input since the propagated sharding does not evenly partition + // it. + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, + DoNotPropagateToParameterIfNotDivisible_WithoutSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[3] parameter(1) + %pad_value = f32[] constant(0) + %pad = f32[4] pad(%param1, %pad_value), padding=0_1 + ROOT %add = f32[4] add(%param0, %pad), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{replicated}")); + // Replicate the input since the propagated sharding does not evenly partition + // it. + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, DoNotPropagateToTupleParameterIfNotDivisible) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = (f32[4], f32[3]) parameter(0), sharding={{replicated}, {replicated}} + %gte0 = f32[4] get-tuple-element(%param0), index=0 + %gte1 = f32[3] get-tuple-element(%param0), index=1 + %pad_value = f32[] constant(0) + %pad = f32[4] pad(%gte1, %pad_value), padding=0_1 + ROOT %add = f32[4] add(%gte0, %pad), sharding={devices=[4]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + // Replicate the second element of parameter since the propagated sharding + // does not evenly partition it. + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{{replicated}, {replicated}}")); +} + +TEST_F(ShardingPropagationTest, + DoNotPropagateToOutputIfNotDivisible_WithSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[4] parameter(1), sharding={replicated} + %add = f32[4] add(%param0, %param1), sharding={devices=[4]0,1,2,3} + ROOT %slice = f32[3] slice(%add), slice={[0:3:1]}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + // Replicate the output since the propagated sharding does not evenly + // partition it. + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, + DoNotPropagateToOutputIfNotDivisible_WithoutSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[4] parameter(1), sharding={replicated} + %add = f32[4] add(%param0, %param1), sharding={devices=[4]0,1,2,3} + ROOT %slice = f32[3] slice(%add), slice={[0:3:1]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + // Replicate the output since the propagated sharding does not evenly + // partition it. + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, + DoNotPropagateToOutputTupleIfNotDivisible_WithSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[4] parameter(1), sharding={replicated} + %add = f32[4] add(%param0, %param1), sharding={devices=[4]0,1,2,3} + %slice = f32[3] slice(%add), slice={[0:3:1]} + ROOT %tuple = (f32[4], f32[3]) tuple(%add, %slice), sharding={{replicated}, {replicated}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false, true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + // Replicate the output tuple element since the propagated sharding does not + // evenly partition it. + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{{replicated}, {replicated}}")); +} + +TEST_F(ShardingPropagationTest, + DoNotPropagateToOutputTupleIfNotDivisible_WithoutSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + %param0 = f32[4] parameter(0), sharding={replicated} + %param1 = f32[4] parameter(1), sharding={replicated} + %add = f32[4] add(%param0, %param1), sharding={devices=[4]0,1,2,3} + %slice = f32[3] slice(%add), slice={[0:3:1]} + ROOT %tuple = (f32[4], f32[3]) tuple(%add, %slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true, true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + // Replicate the output tuple element since the propagated sharding does not + // evenly partition it. + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{{devices=[4]0,1,2,3}, {replicated}}")); +} + TEST_F(ShardingPropagationTest, PropagateShardLikeDifferentSharding) { const char* const hlo_string = R"( HloModule module @@ -10338,6 +10665,63 @@ ENTRY %entry { EXPECT_EQ(add_1->sharding(), output->sharding()); } +TEST_F(ShardingPropagationTest, PropagateShardAsBetweenInputOutput) { + const char* const hlo_string = R"( +HloModule jit_zeros_like + +ENTRY main.6 { + Arg_0.1 = s64[8,2]{1,0} parameter(0), sharding={devices=[4,2]<=[8]} + custom-call.4 = s64[8,2]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={unknown shard_as 0} + constant.2 = s64[] constant(0) + broadcast.3 = s64[8,2]{1,0} broadcast(constant.2), dimensions={} + ROOT custom-call.5 = s64[8,2]{1,0} custom-call(broadcast.3), custom_call_target="Sharding", sharding={unknown shard_as 0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false, false}) + .Run(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[4,2]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, PropagateShardAsBetweenInputOutput2) { + const char* const hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[8]{0:T(256)})->(f32[8]{0:T(256)}, f32[8]{0:T(256)})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4 + +ENTRY main.9 { + Arg_0.1 = f32[8]{0} parameter(0), sharding={replicated} + custom-call.6 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0}, metadata={op_name="jit(f)/jit(main)/shard_alike" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=206} + custom-call.4 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4]<=[4]}, metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[4]<=[4]}) resource_env=ResourceEnv(mesh=Mesh(), ()) unconstrained_dims=set()]" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=204} + constant.0 = f32[] constant(2) + broadcast.0 = f32[8]{0} broadcast(constant.0), dimensions={} + multiply.5 = f32[8]{0} multiply(custom-call.4, broadcast.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=205} + custom-call.7 = f32[8]{0} custom-call(multiply.5), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0}, metadata={op_name="jit(f)/jit(main)/shard_alike" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=206} + ROOT tuple.8 = (f32[8]{0}, f32[8]{0}) tuple(custom-call.6, custom-call.7) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true, true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false}) + .Run(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}")); +} + TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) { const char* const hlo_string = R"( HloModule module @@ -10386,8 +10770,8 @@ called_computation { ENTRY entry_computation { p0 = s32[8] parameter(0), sharding={manual} p1 = s32[8] parameter(1), sharding={manual} - async-start = ((s32[8], s32[8]), s32[8], u32[]) call-start(p0, p1), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation - ROOT async-done = s32[8] call-done(async-start), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation + async-start = ((s32[8], s32[8]), s32[8], u32[]) call-start(p0, p1), async_execution_thread="thread_1", to_apply=called_computation + ROOT async-done = s32[8] call-done(async-start) }, execution_thread="thread_0" // entry_computation )"; @@ -10462,8 +10846,8 @@ called_computation { ENTRY entry_computation { p0 = s32[8] parameter(0), sharding={manual} p1 = s32[8] parameter(1), sharding={manual} - async-start = ((s32[8], s32[8]), (s32[8], s32[8]), u32[]) call-start(p0, p1), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation - ROOT async-done = (s32[8], s32[8]) call-done(async-start), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation + async-start = ((s32[8], s32[8]), (s32[8], s32[8]), u32[]) call-start(p0, p1), async_execution_thread="thread_1", to_apply=called_computation + ROOT async-done = (s32[8], s32[8]) call-done(async-start) }, execution_thread="thread_0" // entry_computation )"; diff --git a/xla/service/sharding_remover.cc b/xla/service/sharding_remover.cc index 874d503ebd69c..2dc0b6e6b1d40 100644 --- a/xla/service/sharding_remover.cc +++ b/xla/service/sharding_remover.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ namespace xla { // Remove Sharding custom-call instruction by assigning its users to // to its operand. -StatusOr ShardingRemover::Run( +absl::StatusOr ShardingRemover::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/sharding_remover.h b/xla/service/sharding_remover.h index 9a8d0c3f4f661..39acf378cad65 100644 --- a/xla/service/sharding_remover.h +++ b/xla/service/sharding_remover.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class ShardingRemover : public HloModulePass { public: absl::string_view name() const override { return "sharding-remover"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/sharding_remover_test.cc b/xla/service/sharding_remover_test.cc index 18e950d119a68..86b52d32013ca 100644 --- a/xla/service/sharding_remover_test.cc +++ b/xla/service/sharding_remover_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/simplify_fp_conversions.cc b/xla/service/simplify_fp_conversions.cc index 22fbf3961506c..8cabbb17a4da1 100644 --- a/xla/service/simplify_fp_conversions.cc +++ b/xla/service/simplify_fp_conversions.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/simplify_fp_conversions.h" #include -#include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -27,29 +27,27 @@ limitations under the License. #include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { // Simplifies floating-point conversions `A -> B -> C -> D` as `A -> D`. -StatusOr RunOnComputation(HloComputation& computation, - SimplifyFPConversions::Scope scope) { - const int minimum_logical_creation_pass_id = - (scope == SimplifyFPConversions::Scope::kSimplifyAllConversions) ? -1 : 0; +absl::StatusOr RunOnComputation(HloComputation& computation) { bool changed = false; for (HloInstruction* instruction : computation.MakeInstructionPostOrder()) { HloInstruction* input = instruction; size_t convert_chain_length = 0; - while ((input->opcode() == HloOpcode::kConvert) && - (input->metadata().logical_creation_pass_id() >= - minimum_logical_creation_pass_id) && + while (input->opcode() == HloOpcode::kConvert && primitive_util::IsFloatingPointType(input->shape().element_type())) { input = input->mutable_operand(0); ++convert_chain_length; } - if (convert_chain_length < 2) continue; + if (convert_chain_length < 2) { + continue; + } if (instruction->shape().element_type() == input->shape().element_type()) { TF_RETURN_IF_ERROR( @@ -64,36 +62,23 @@ StatusOr RunOnComputation(HloComputation& computation, return changed; } -std::string ToString(SimplifyFPConversions::Scope scope) { - using Scope = SimplifyFPConversions::Scope; - switch (scope) { - case Scope::kSimplifyAllConversions: - return "SimplifyAllConversions"; - case Scope::kOnlySimplifyCompilerGeneratedConversions: - return "OnlySimplifyCompilerGeneratedConversions"; - } -} - } // namespace -StatusOr SimplifyFPConversions::Run( +absl::StatusOr SimplifyFPConversions::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( - 2, - absl::StrFormat("SimplifyFPConversions::Run() with scope=%s, before:\n%s", - ToString(scope_), module->ToString())); + 2, absl::StrFormat("SimplifyFPConversions::Run() with before:\n%s", + module->ToString())); bool changed = false; for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool comp_changed, - RunOnComputation(*computation, scope_)); + TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(*computation)); changed |= comp_changed; } - XLA_VLOG_LINES( - 2, - absl::StrFormat("SimplifyFPConversions::Run() with scope=%s, after:\n%s", - ToString(scope_), module->ToString())); + XLA_VLOG_LINES(2, + absl::StrFormat("SimplifyFPConversions::Run() with after:\n%s", + module->ToString())); return changed; } diff --git a/xla/service/simplify_fp_conversions.h b/xla/service/simplify_fp_conversions.h index 3904fa37505a5..099b06a283cef 100644 --- a/xla/service/simplify_fp_conversions.h +++ b/xla/service/simplify_fp_conversions.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,31 +26,19 @@ namespace xla { // Simplifies chains of floating-point conversions. // // The algebraic simplifier will remove convert pairs of the form `X -> Y -> X`, -// only when they are a no-op (e.g. `bf16 -> f32 -> bf16`). This passes does -// similar, but has two scopes: -// - kSimplifyAllConversions: Simplify any chain of float conversions, possibly -// improving accuracy (e.g. `f32 -> bf16 -> f32` is removed). -// - kOnlySimplifyCompilerGeneratedConversions: Only simplify chains of float -// conversions generated by the compiler in one of the previous optimization -// passes. +// only when they are a no-op, e.g. `bf16 -> f32 -> bf16` or +// `f32 -> bf16 -> f32`. Note that the latter optimization might lead to +// increased precision. class SimplifyFPConversions : public HloModulePass { public: - enum class Scope { - kOnlySimplifyCompilerGeneratedConversions, - kSimplifyAllConversions - }; - - explicit SimplifyFPConversions(Scope scope) : scope_(scope) {} + explicit SimplifyFPConversions() = default; absl::string_view name() const override { return "simplify-fp-conversions"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; - - private: - Scope scope_; }; } // namespace xla diff --git a/xla/service/simplify_fp_conversions_test.cc b/xla/service/simplify_fp_conversions_test.cc index fae975963adb4..ad85bb873eb65 100644 --- a/xla/service/simplify_fp_conversions_test.cc +++ b/xla/service/simplify_fp_conversions_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,32 +36,6 @@ using ::tsl::testing::IsOkAndHolds; using SimplifyFPConversionsTest = HloTestBase; -// This marks all ops in `module` as user-provided, meaning the -// simplifier won't remove any of the converts -static void InitializeCreationPassIds(HloModule* module) { - constexpr int kUserSuppliedOpCreationPassId = -1; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - instruction->set_creation_pass_id(kUserSuppliedOpCreationPassId); - instruction->set_logical_creation_pass_id(kUserSuppliedOpCreationPassId); - } - } -} - -// This marks all converts ops in `module` as being created by the -// optimization pass `creation_pass_id`. -static void SetCreationPassIdInAllConvertOps(HloModule* module, - int creation_pass_id) { - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvert) { - instruction->set_creation_pass_id(creation_pass_id); - instruction->set_logical_creation_pass_id(creation_pass_id); - } - } - } -} - TEST_F(SimplifyFPConversionsTest, DoesNotChangeSingleConvert) { const absl::string_view kModuleStr = R"( HloModule test @@ -74,10 +48,8 @@ TEST_F(SimplifyFPConversionsTest, DoesNotChangeSingleConvert) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(false)); } @@ -94,60 +66,13 @@ TEST_F(SimplifyFPConversionsTest, SimplifiesF32ToBF16ToF32) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Parameter(0))); } -TEST_F(SimplifyFPConversionsTest, SimplifiesCompilerGeneratedF32ToBF16ToF32) { - const absl::string_view kModuleStr = R"( - HloModule test - - ENTRY entry { - p0 = f32[2,3] parameter(0) - c0 = bf16[2,3] convert(p0) - c1 = f32[2,3] convert(c0) - ROOT ret = (f32[2,3]) tuple(c1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - - constexpr int kRandomCreationPassId = 42; - SetCreationPassIdInAllConvertOps(module.get(), kRandomCreationPassId); - - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kOnlySimplifyCompilerGeneratedConversions}; - EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Parameter(0))); -} - -TEST_F(SimplifyFPConversionsTest, DoesNotChangeUserInsertedConverts) { - const absl::string_view kModuleStr = R"( - HloModule test - - ENTRY entry { - p0 = f32[2,3] parameter(0) - c0 = bf16[2,3] convert(p0) - c1 = f32[2,3] convert(c0) - ROOT ret = (f32[2,3]) tuple(c1) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kOnlySimplifyCompilerGeneratedConversions}; - EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(false)); -} - TEST_F(SimplifyFPConversionsTest, SimplifiesF64ToF16ToF32ToBF16) { const absl::string_view kModuleStr = R"( HloModule test @@ -162,10 +87,8 @@ TEST_F(SimplifyFPConversionsTest, SimplifiesF64ToF16ToF32ToBF16) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - InitializeCreationPassIds(module.get()); - SimplifyFPConversions simplifier{ - SimplifyFPConversions::Scope::kSimplifyAllConversions}; + SimplifyFPConversions simplifier; EXPECT_THAT(simplifier.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( module->entry_computation()->root_instruction(), diff --git a/xla/service/slice_sinker.cc b/xla/service/slice_sinker.cc index c2e9bf77d5c18..dc7559444436c 100644 --- a/xla/service/slice_sinker.cc +++ b/xla/service/slice_sinker.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -238,7 +238,7 @@ Status SinkSlices(const std::vector& slice_sources, // This pass currently doesn't transform non-elementwise instructions. We may // extend this pass to transform non-elementwise instructions, such as dot, // broadcast and reduce in the future. -StatusOr SliceSinker::Run( +absl::StatusOr SliceSinker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/slice_sinker.h b/xla/service/slice_sinker.h index 4a5e6aa6ac5e4..61805ca874211 100644 --- a/xla/service/slice_sinker.h +++ b/xla/service/slice_sinker.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ class SliceSinker : public HloModulePass { absl::string_view name() const override { return "slice-sinker"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/slice_sinker_test.cc b/xla/service/slice_sinker_test.cc index 4ead532939304..cbbdafc877cda 100644 --- a/xla/service/slice_sinker_test.cc +++ b/xla/service/slice_sinker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/slow_operation_alarm.cc b/xla/service/slow_operation_alarm.cc index d909b823a3448..01584663892d4 100644 --- a/xla/service/slow_operation_alarm.cc +++ b/xla/service/slow_operation_alarm.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/slow_operation_alarm.h b/xla/service/slow_operation_alarm.h index a9fa35ca9ed36..5c784a04cde81 100644 --- a/xla/service/slow_operation_alarm.h +++ b/xla/service/slow_operation_alarm.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/sort_simplifier.cc b/xla/service/sort_simplifier.cc index bc5b649acd6b4..99df7c6035beb 100644 --- a/xla/service/sort_simplifier.cc +++ b/xla/service/sort_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ namespace { // If the sort instruction has a tuple shape then looks for unused output // values and removes them from the sort instruction. Returns true if the // graph has been modified. -StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { +absl::StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { if (!sort->shape().IsTuple()) { return false; } @@ -135,7 +135,7 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { } } // namespace -StatusOr SortSimplifier::Run( +absl::StatusOr SortSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "HLO module before SortSimplifier:"; diff --git a/xla/service/sort_simplifier.h b/xla/service/sort_simplifier.h index 4150848ff73d0..2f02216168b93 100644 --- a/xla/service/sort_simplifier.h +++ b/xla/service/sort_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ class SortSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-sorts"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/sort_simplifier_test.cc b/xla/service/sort_simplifier_test.cc index 4f9ef3b25541d..ea8f208271a57 100644 --- a/xla/service/sort_simplifier_test.cc +++ b/xla/service/sort_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/source_map_util.h b/xla/service/source_map_util.h index 349a3025e923e..5fd9db9dca751 100644 --- a/xla/service/source_map_util.h +++ b/xla/service/source_map_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/space_to_batch_converter.cc b/xla/service/space_to_batch_converter.cc index e0c91f50355a6..f5015217115b5 100644 --- a/xla/service/space_to_batch_converter.cc +++ b/xla/service/space_to_batch_converter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -118,11 +118,12 @@ class ConvolutionVisitor { // Propagates space-to-batch on the op, and returns a bool that indicates if // the users of the op need to be propagated through. - StatusOr Propagate(HloInstruction* consumer, HloInstruction* producer); + absl::StatusOr Propagate(HloInstruction* consumer, + HloInstruction* producer); // Splits the given spatial dimension on the activations and returns the // new instructions, and the dimension permutation of the new shape. - StatusOr>> SplitSpace( + absl::StatusOr>> SplitSpace( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding, int64_t spatial_split_size, int64_t num_splits, @@ -130,7 +131,7 @@ class ConvolutionVisitor { bool is_backprop = false, bool is_rhs = false); // Performs the actual dimension splitting. - StatusOr PerformSplitSpace( + absl::StatusOr PerformSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t spatial_split_size, @@ -140,22 +141,22 @@ class ConvolutionVisitor { // merges the batch(es). // The input activations dimensions are ... B, B0, S0, B1, S1, ... Bn, Sn, ... // The output dimensions will be ..., B, S0, S1,.. Sn, ... - StatusOr TransposeAndMergeBatch( + absl::StatusOr TransposeAndMergeBatch( HloInstruction* activations, absl::Span final_split_spatial_dim_positioning, int64_t activations_batch_dim, int64_t old_batch_size); // Helper function for the SplitSpace function above. Handles padding and // reshaping to generate space-to-batched shape. - StatusOr PadAndSplitSpace( + absl::StatusOr PadAndSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding, int64_t spatial_split_size, int64_t num_splits); // Perform space-to-batch propagation on constants. - StatusOr PropagateOnConstant(HloInstruction* consumer, - HloInstruction* producer); + absl::StatusOr PropagateOnConstant(HloInstruction* consumer, + HloInstruction* producer); // Perform space-to-batch propagation on the convolution. Assumes the // activations were already space-to-batched. @@ -189,7 +190,7 @@ class ConvolutionVisitor { // Generates masked output with valid data. This is useful when larger shapes // are generated due to space-to-batch. - StatusOr SelectValidPortion( + absl::StatusOr SelectValidPortion( HloInstruction* new_instr, HloInstruction* old_instr, HloInstruction* select_val, int64_t new_batch_dim, absl::Span new_space_dims, int64_t old_batch_dim, @@ -201,7 +202,7 @@ class ConvolutionVisitor { }; // Performs tranposition so that space dimension follows the batch dimension. - StatusOr BringSpaceNextToBatch( + absl::StatusOr BringSpaceNextToBatch( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, std::vector* spatial_dimensions_to_split, @@ -209,29 +210,29 @@ class ConvolutionVisitor { // Decreases the spatial dimension size in an already space-to-batched shape // so that the new size is new_spatial_dim_size. - StatusOr ChangeSpatialSizeOnSpaceToBatchedShape( + absl::StatusOr ChangeSpatialSizeOnSpaceToBatchedShape( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions_to_split, int64_t new_spatial_dim_size, bool increase_spatial_size = false); // Turns B, S0, S1, ..., Sn into B, B0, S0, B1, S1,... Bn, Sn. - StatusOr SplitAndTransposeMergedBatch( + absl::StatusOr SplitAndTransposeMergedBatch( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions); // Function that converts spaced-to-batch shape back to the original. - StatusOr BatchToSpace(HloInstruction* old_instr); + absl::StatusOr BatchToSpace(HloInstruction* old_instr); // Duplicates elements at boundaries. - StatusOr HaloDuplicateWithSlice( + absl::StatusOr HaloDuplicateWithSlice( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size, HloInstruction* pad_val = nullptr); // Runs the visitor on a computation. - StatusOr Run(); + absl::StatusOr Run(); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -507,7 +508,7 @@ bool ConvolutionVisitor::IsThisBackPropFilterConv(HloInstruction* convolution) { return true; } -StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( +absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size, @@ -636,7 +637,7 @@ StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( return activations; } -StatusOr +absl::StatusOr ConvolutionVisitor::BringSpaceNextToBatch( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, @@ -741,7 +742,8 @@ ConvolutionVisitor::BringSpaceNextToBatch( return SpaceNextToBatchDetails{activations, transpose_dims}; } -StatusOr ConvolutionVisitor::SplitAndTransposeMergedBatch( +absl::StatusOr +ConvolutionVisitor::SplitAndTransposeMergedBatch( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions) { CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]); @@ -792,7 +794,7 @@ StatusOr ConvolutionVisitor::SplitAndTransposeMergedBatch( return batch_split_activations; } -StatusOr +absl::StatusOr ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( HloInstruction* activations, int64_t batch_dimension, int64_t old_batch_size, absl::Span spatial_dimensions, @@ -881,7 +883,7 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( return activations_new; } -StatusOr ConvolutionVisitor::Run() { +absl::StatusOr ConvolutionVisitor::Run() { for (auto conv : conv_visitor_list_) { // If we expect to see an unpropagatable op, space-to-batch may not be // beneficial. @@ -893,6 +895,7 @@ StatusOr ConvolutionVisitor::Run() { } if (convs_to_visit_.count(conv) > 0) { TF_CHECK_OK(PerformSpaceToBatchOnConvolution(conv)); + changed_ = true; } } conv_visitor_list_.clear(); @@ -1770,8 +1773,8 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, return false; } -StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, - HloInstruction* producer) { +absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, + HloInstruction* producer) { auto computation = consumer->parent(); if (IsTrivialElementwise(consumer)) { auto dim_map_val = instr_to_dim_map_[producer]; @@ -2325,7 +2328,7 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, return true; } -StatusOr ConvolutionVisitor::SelectValidPortion( +absl::StatusOr ConvolutionVisitor::SelectValidPortion( HloInstruction* new_instr, HloInstruction* old_instr, HloInstruction* select_val, int64_t new_batch_dim, absl::Span new_space_dims, int64_t old_batch_dim, @@ -2407,7 +2410,7 @@ StatusOr ConvolutionVisitor::SelectValidPortion( return new_instr; } -StatusOr ConvolutionVisitor::BatchToSpace( +absl::StatusOr ConvolutionVisitor::BatchToSpace( HloInstruction* old_instr) { if (batch_to_space_map_.count(old_instr)) { CHECK_NE(batch_to_space_map_[old_instr], nullptr); @@ -2885,7 +2888,7 @@ Status ConvolutionVisitor::PropagateOnSlice(HloInstruction* slice) { return OkStatus(); } -StatusOr ConvolutionVisitor::TransposeAndMergeBatch( +absl::StatusOr ConvolutionVisitor::TransposeAndMergeBatch( HloInstruction* activations, absl::Span final_split_spatial_dim_positioning, int64_t activations_batch_dim, int64_t old_batch_size) { @@ -2927,7 +2930,7 @@ StatusOr ConvolutionVisitor::TransposeAndMergeBatch( return batch_collapsed_reshape; } -StatusOr ConvolutionVisitor::PerformSplitSpace( +absl::StatusOr ConvolutionVisitor::PerformSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t spatial_split_size, @@ -2973,7 +2976,7 @@ StatusOr ConvolutionVisitor::PerformSplitSpace( activations_batch_dim, old_batch_size); } -StatusOr ConvolutionVisitor::PadAndSplitSpace( +absl::StatusOr ConvolutionVisitor::PadAndSplitSpace( HloInstruction* activations, absl::Span spatial_dimensions_to_split, int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding, @@ -3007,7 +3010,7 @@ StatusOr ConvolutionVisitor::PadAndSplitSpace( num_splits); } -StatusOr>> +absl::StatusOr>> ConvolutionVisitor::SplitSpace( HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding, @@ -3029,7 +3032,7 @@ ConvolutionVisitor::SplitSpace( return std::make_pair(new_activations, transpose_dims); } -StatusOr ConvolutionVisitor::PropagateOnConstant( +absl::StatusOr ConvolutionVisitor::PropagateOnConstant( HloInstruction* consumer, HloInstruction* producer) { CHECK(old_to_new_instrs_.contains(producer)); HloInstruction* new_producer = old_to_new_instrs_[producer]; @@ -3707,8 +3710,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( } VLOG(1) << "Handling conv " << convolution->ToString(); - changed_ = false; - ConvolutionDimensionNumbers dim_numbers = convolution->convolution_dimension_numbers(); @@ -3913,14 +3914,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( } TF_CHECK_OK(PropagateOnUsers(original_conv)); - changed_ = true; return OkStatus(); } } // namespace -StatusOr SpaceToBatchConverter::Run( +absl::StatusOr SpaceToBatchConverter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( diff --git a/xla/service/space_to_batch_converter.h b/xla/service/space_to_batch_converter.h index c017071c5aabd..2d9dba06a2b58 100644 --- a/xla/service/space_to_batch_converter.h +++ b/xla/service/space_to_batch_converter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ class SpaceToBatchConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/space_to_batch_converter_test.cc b/xla/service/space_to_batch_converter_test.cc index 4b7585aeb86ee..7198ae2114011 100644 --- a/xla/service/space_to_batch_converter_test.cc +++ b/xla/service/space_to_batch_converter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/sparse_util.cc b/xla/service/sparse_util.cc deleted file mode 100644 index 05ff8df5e6dd9..0000000000000 --- a/xla/service/sparse_util.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/sparse_util.h" - -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/layout_util.h" - -namespace xla { - -/*static*/ bool SparseUtil::HasSparseInOut(HloInstruction* instruction) { - // Tests sparse operands. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), [](HloInstruction* operand) { - return LayoutUtil::IsSparse(operand->shape().layout()); - })) { - return true; - } - // Tests sparse result. - return LayoutUtil::IsSparse(instruction->shape().layout()); -} - -} // namespace xla diff --git a/xla/service/sparse_util.h b/xla/service/sparse_util.h deleted file mode 100644 index bcac77f047db1..0000000000000 --- a/xla/service/sparse_util.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SPARSE_UTIL_H_ -#define XLA_SERVICE_SPARSE_UTIL_H_ - -namespace xla { -// Forward declarations. -class HloInstruction; - -// Namespaced collection of (static) Sparse utilities. -class SparseUtil { - public: - // Returns true if the instruction takes sparse operands or return sparse - // result. - static bool HasSparseInOut(HloInstruction* instruction); - - private: - SparseUtil(const SparseUtil&) = delete; - SparseUtil& operator=(const SparseUtil&) = delete; -}; -} // namespace xla - -#endif // XLA_SERVICE_SPARSE_UTIL_H_ diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 0e9459eccf47b..a72a2ed619a66 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -34,28 +34,36 @@ cc_library( "spmd_partitioner_util.h", ], deps = [ + "//xla:array", "//xla:comparison_util", "//xla:literal", "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:types", "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/client/lib:comparators", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_reachability", "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", + "//xla/service:collective_ops_utils", + "//xla/service:computation_layout", "//xla/service:custom_call_sharding_helper", "//xla/service:dot_as_convolution_util", "//xla/service:flatten_call_graph", "//xla/service:hlo_cse", "//xla/service:hlo_dce", "//xla/service:hlo_lexer", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", "//xla/service:hlo_pass_pipeline", "//xla/service:pattern_matcher", @@ -69,8 +77,13 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/utility", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:statusor", ], @@ -111,7 +124,6 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -141,7 +153,6 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -179,7 +190,6 @@ cc_library( srcs = ["spmd_prepare.cc"], hdrs = ["spmd_prepare.h"], deps = [ - ":spmd_partitioner", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:hlo_pass", @@ -196,7 +206,6 @@ cc_library( ":spmd_partitioner", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", - "@com_google_absl//absl/memory", ], ) @@ -217,7 +226,6 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -271,8 +279,6 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", ], ) @@ -307,3 +313,14 @@ xla_cc_test( "@tsl//tsl/lib/core:status_test_util", ], ) + +xla_cc_test( + name = "spmd_partitioner_util_test", + srcs = ["spmd_partitioner_util_test.cc"], + deps = [ + ":spmd_partitioner", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/xla/service/spmd/canonicalize_all_gather_for_cse.cc b/xla/service/spmd/canonicalize_all_gather_for_cse.cc index 6f7f441144073..68e13c5978b12 100644 --- a/xla/service/spmd/canonicalize_all_gather_for_cse.cc +++ b/xla/service/spmd/canonicalize_all_gather_for_cse.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ limitations under the License. namespace xla { -StatusOr CanonicalizeAllGatherForCSE::RunOnComputation( +absl::StatusOr CanonicalizeAllGatherForCSE::RunOnComputation( HloComputation* comp) { bool changed = false; // Helper to find the respective shape input dimension of an shape output @@ -92,7 +92,7 @@ StatusOr CanonicalizeAllGatherForCSE::RunOnComputation( return changed; } -StatusOr CanonicalizeAllGatherForCSE::Run( +absl::StatusOr CanonicalizeAllGatherForCSE::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/spmd/canonicalize_all_gather_for_cse.h b/xla/service/spmd/canonicalize_all_gather_for_cse.h index f139ebbe137ae..6e2581a51d9ea 100644 --- a/xla/service/spmd/canonicalize_all_gather_for_cse.h +++ b/xla/service/spmd/canonicalize_all_gather_for_cse.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,12 +30,12 @@ class CanonicalizeAllGatherForCSE : public HloModulePass { absl::string_view name() const override { return "canon-all-gather-for-cse"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr RunOnComputation(HloComputation* comp); + absl::StatusOr RunOnComputation(HloComputation* comp); int64_t NextChannelId() { return next_channel_id_++; } int64_t next_channel_id_; diff --git a/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc b/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc index 97c2636ac67e7..b7e3b2a8c7a8b 100644 --- a/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc +++ b/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,13 +34,14 @@ namespace op = xla::testing::opcode_matchers; class AllGatherCanonicalizeTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module) { + absl::StatusOr> RunPass( + absl::string_view hlo_module) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( hlo_module, GetModuleConfigForTest())); HloPassPipeline pipeline("all-gather-cse"); pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } Status RunPassOnModule(HloModule* module, int64_t distance_threshold = 100) { HloPassPipeline pipeline("all-gather-cse"); diff --git a/xla/service/spmd/collective_permute_motion.cc b/xla/service/spmd/collective_permute_motion.cc index b53d30eb26dc0..b6160f91df82e 100644 --- a/xla/service/spmd/collective_permute_motion.cc +++ b/xla/service/spmd/collective_permute_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -129,6 +129,9 @@ std::optional FindMovableClusterAtBodyRoot( } } } + if (cluster.collective_permute == nullptr) { + return std::nullopt; + } return cluster; } @@ -148,8 +151,8 @@ absl::flat_hash_set FindIndicesUnusedAfterLoop(HloInstruction* loop) { return indices; } -StatusOr MoveCollectivePermutes(HloComputation* computation, - HloInstruction* loop) { +absl::StatusOr MoveCollectivePermutes(HloComputation* computation, + HloInstruction* loop) { HloComputation* body = loop->while_body(); HloInstruction* root = body->root_instruction(); if (root->opcode() != HloOpcode::kTuple || @@ -295,7 +298,7 @@ StatusOr MoveCollectivePermutes(HloComputation* computation, return changed; } -StatusOr CollectivePermuteMotion::Run( +absl::StatusOr CollectivePermuteMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/spmd/collective_permute_motion.h b/xla/service/spmd/collective_permute_motion.h index 8b35b90b90a25..2b41a99eb2530 100644 --- a/xla/service/spmd/collective_permute_motion.h +++ b/xla/service/spmd/collective_permute_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class CollectivePermuteMotion : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/spmd/collective_permute_motion_test.cc b/xla/service/spmd/collective_permute_motion_test.cc index 9c149d17d9968..433a9e86bf583 100644 --- a/xla/service/spmd/collective_permute_motion_test.cc +++ b/xla/service/spmd/collective_permute_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,6 +68,40 @@ TEST_F(CollectivePermuteMotionTest, SimpleMove) { EXPECT_THAT(output, op::Multiply(select, select)); } +TEST_F(CollectivePermuteMotionTest, NoCollectivePermute) { + absl::string_view hlo_string = R"( + HloModule test + body { + loop_var = (s32[], f32[], f32[]) parameter(0) + constant.1 = s32[] constant(1) + gte0 = s32[] get-tuple-element(loop_var), index=0 + add = s32[] add(gte0, constant.1) + gte1 = f32[] get-tuple-element(loop_var), index=1 + constant.4 = f32[] constant(4.0) + ROOT tuple = (s32[], f32[], f32[]) tuple(add, constant.4, gte1) + } + cond { + loop_var = (s32[], f32[], f32[]) parameter(0) + gte.cond = s32[] get-tuple-element(loop_var), index=0 + constant.3 = s32[] constant(5) + ROOT lt = pred[] compare(gte.cond, constant.3), direction=LT + } + ENTRY main { + constant.2 = s32[] constant(0) + param = f32[] parameter(0) + param.1 = f32[] parameter(1) + tuple.1 = (s32[], f32[], f32[]) tuple(constant.2, param, param.1) + while = (s32[], f32[], f32[]) while(tuple.1), condition=cond, body=body + ROOT result = s32[] get-tuple-element(while), index=0 + } +)"; + // Test that the pass does not crash if there is no collective permute + // (but other conditions in FindMovableClusterAtBodyRoot are satisfied). + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + CollectivePermuteMotion pass; + ASSERT_FALSE(pass.Run(&*module).value()); +} + TEST_F(CollectivePermuteMotionTest, MoveWithElementwise) { absl::string_view hlo_string = R"( HloModule test diff --git a/xla/service/spmd/convolution_handler.cc b/xla/service/spmd/convolution_handler.cc index 8df49ca4b8608..2a88430331db9 100644 --- a/xla/service/spmd/convolution_handler.cc +++ b/xla/service/spmd/convolution_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,12 +42,12 @@ namespace spmd { namespace { // Partition convolution with batch group count. -StatusOr PartitionConvolutionWithBatchGroupCount( +absl::StatusOr PartitionConvolutionWithBatchGroupCount( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, int64_t num_partitions, SpmdBuilder* b) { @@ -135,12 +135,12 @@ StatusOr PartitionConvolutionWithBatchGroupCount( } // Partition convolution with feature group count. -StatusOr PartitionConvolutionWithFeatureGroupCount( +absl::StatusOr PartitionConvolutionWithFeatureGroupCount( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, int64_t num_partitions, SpmdBuilder* b) { @@ -229,13 +229,13 @@ StatusOr PartitionConvolutionWithFeatureGroupCount( // Partition convolution when both LHS and RHS are partitioned at spatial // dimensions. Halo exchange will happen on RHS only. -StatusOr +absl::StatusOr PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { @@ -516,13 +516,13 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( // Partition convolution when both LHS and RHS are partitioned at spatial // dimensions. Halo exchange will happen on LHS only. -StatusOr +absl::StatusOr PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { @@ -741,12 +741,12 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( // Partition convolution when output is sharded. Will shard LHS with replicated // RHS. -StatusOr PartitionConvolutionTiledOutput( +absl::StatusOr PartitionConvolutionTiledOutput( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) { TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); @@ -831,12 +831,12 @@ StatusOr PartitionConvolutionTiledOutput( } // Partition convolution with only one kind of dims partitioned. -StatusOr PartitionConvolutionBaseCase( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr PartitionConvolutionBaseCase( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, int64_t num_partitions, const SpmdPartitionerOptions& options, @@ -907,7 +907,7 @@ StatusOr PartitionConvolutionBaseCase( return nullptr; } -StatusOr> CreateShardedConvolution( +absl::StatusOr> CreateShardedConvolution( const HloInstruction& conv, const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo, @@ -992,12 +992,13 @@ StatusOr> CreateShardedConvolution( } // namespace // Partition convolution. -StatusOr PartitionConvolution( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr PartitionConvolution( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, int64_t num_partitions, const SpmdPartitionerOptions& options, @@ -1060,7 +1061,7 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { auto create_sharded_conv = [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, spmd::SpmdBuilder* b, - const Window& conv_window) -> StatusOr { + const Window& conv_window) -> absl::StatusOr { if (dims_info.conv_spatial_dims.empty() && hlo->feature_group_count() == 1 && hlo->batch_group_count() == 1) { TF_ASSIGN_OR_RETURN( diff --git a/xla/service/spmd/convolution_handler.h b/xla/service/spmd/convolution_handler.h index 617dbb8aeac6c..c38b4c3e96f87 100644 --- a/xla/service/spmd/convolution_handler.h +++ b/xla/service/spmd/convolution_handler.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,12 +26,13 @@ namespace xla { namespace spmd { // Partition convolution. -StatusOr PartitionConvolution( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr PartitionConvolution( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_conv, const Window& conv_window, HloInstruction* original_hlo, int64_t num_partitions, const SpmdPartitionerOptions& options, diff --git a/xla/service/spmd/custom_call_handler.cc b/xla/service/spmd/custom_call_handler.cc index 6b2a7cce3d429..e28aca78b12b4 100644 --- a/xla/service/spmd/custom_call_handler.cc +++ b/xla/service/spmd/custom_call_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,36 +15,51 @@ limitations under the License. #include "xla/service/spmd/custom_call_handler.h" +#include +#include +#include +#include +#include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/client/lib/comparators.h" #include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" #include "xla/service/custom_call_sharding_helper.h" #include "xla/service/hlo_lexer.h" -#include "xla/service/shape_inference.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" #include "xla/util.h" -#include "xla/window_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { namespace { -StatusOr> ParseOpaqueAsAttributes( - const HloInstruction* hlo) { +absl::StatusOr> +ParseOpaqueAsAttributes(const HloInstruction* hlo) { absl::string_view opaque = Cast(hlo)->opaque(); HloLexer lexer(opaque); absl::flat_hash_map result; @@ -82,8 +97,11 @@ Status SpmdPartitioningVisitor::HandleCustomCallTopK(HloInstruction* hlo) { const int64_t batch_dim = 0; const int64_t sort_dim = 1; + + CHECK(sharding.IsTiled()); const int64_t shard_count = sharding.tile_assignment().dim(sort_dim); const int64_t batch_dim_partition = sharding.tile_assignment().dim(batch_dim); + const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim); const int64_t batch_size = hlo->shape().tuple_shapes(0).dimensions(batch_dim); const int64_t k = hlo->shape().tuple_shapes(0).dimensions(sort_dim); @@ -403,10 +421,6 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { return OkStatus(); } - if (hlo->custom_call_target() == "TopK") { - return HandleCustomCallTopK(hlo); - } - if (hlo->custom_call_target() == kSPMDOpRotateRight) { return HandleCustomCallSPMDInternal_RotateRight(hlo); } @@ -439,6 +453,10 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { return OkStatus(); } + if (hlo->custom_call_target() == "TopK") { + return HandleCustomCallTopK(hlo); + } + return DefaultAction(hlo); } diff --git a/xla/service/spmd/custom_call_handler.h b/xla/service/spmd/custom_call_handler.h index 93c777626301c..ff3737279d43b 100644 --- a/xla/service/spmd/custom_call_handler.h +++ b/xla/service/spmd/custom_call_handler.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index 8f0ed21153b4e..8115a78794672 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include +#include +#include #include #include @@ -24,7 +27,11 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -34,6 +41,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" +#include "xla/service/call_graph.h" #include "xla/service/shape_inference.h" #include "xla/service/sharding_propagation.h" #include "xla/service/spmd/convolution_handler.h" @@ -42,9 +50,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -89,17 +100,26 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { mapping.rhs_non_contracting_dims.back().rhs = i; mapping.rhs_non_contracting_dims.back().output = next_output_dim++; } + + HloDotInstruction* dot = Cast(hlo); + std::vector sparsity(dot->sparsity().begin(), + dot->sparsity().end()); + std::vector resharded_meta(dot->sparse_operands()); + for (int i = 0; i < dot->sparse_operands(); ++i) { + resharded_meta[i] = + GetPartitionedHlo(dot->operand(HloDotInstruction::kOperands + i)).hlo(); + } auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, - const Window& conv_window) -> StatusOr { + const Window& conv_window) -> absl::StatusOr { TF_ASSIGN_OR_RETURN( auto sharded_dot_shape, ShapeInference::InferDotOpShape( l->shape(), r->shape(), hlo->dot_dimension_numbers(), - /*preferred_element_type=*/hlo->shape().element_type())); + /*preferred_element_type=*/hlo->shape().element_type(), sparsity)); return b->AddInstruction(HloInstruction::CreateDot( sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), - hlo->precision_config())); + hlo->precision_config(), sparsity, resharded_meta)); }; return HandleDotHelper(hlo, mapping, create_sharded_dot); } @@ -108,11 +128,34 @@ namespace { enum class WindowedEinsumOperand { LHS, RHS }; +enum class DotComponent { LHS, RHS, OUTPUT }; + +int64_t GetPartitionsForDims( + const HloSharding& sharding, + absl::Span dims, + DotComponent component) { + int64_t partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (component == DotComponent::LHS) { + partitions *= ShardCountAtDim(sharding, dim.lhs); + } else if (component == DotComponent::RHS) { + partitions *= ShardCountAtDim(sharding, dim.rhs); + } else { + partitions *= ShardCountAtDim(sharding, dim.output); + } + } + return partitions; +} + struct WindowedEinsumConfig { WindowedEinsumOperand windowed_op; bool windowed_at_contracting_dims; bool windowed_at_batch_dims; bool operands_sharded_at_contracting_dims; + bool is_ag_einsum; }; struct DotDimensionIndexMapping { @@ -452,11 +495,11 @@ std::optional GetWindowedEinsumConfiguration( const Window& conv_window, const DotConvDimsMapping& dims_mapping, const CallGraph& call_graph, int64_t max_iterations = INT64_MAX, const HloInstruction* original_hlo = nullptr, - PartitionedHlo* partitioned_lhs = nullptr, - PartitionedHlo* partitioned_rhs = nullptr, - std::optional(HloInstruction*, HloInstruction*, - SpmdBuilder*, const Window& conv_window)>> + const PartitionedHlo* const partitioned_lhs = nullptr, + const PartitionedHlo* const partitioned_rhs = nullptr, + std::optional( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)>> create_sharded_dot = std::nullopt, SpmdBuilder* b = nullptr, HloModule* module = nullptr, SpmdPartitioningVisitor* visitor = nullptr) { @@ -484,9 +527,9 @@ std::optional GetWindowedEinsumConfiguration( } constexpr int kAggressiveness = 3; std::optional original_ideal_sharding = - ShardingPropagation::GetShardingFromUser(*to_loop_over, *original_hlo, - kAggressiveness, - /*is_spmd=*/true, call_graph); + ShardingPropagation::GetShardingFromUser( + *to_loop_over, *original_hlo, kAggressiveness, + /*is_spmd=*/true, call_graph, /*sharding_helper=*/nullptr); // Default to perform collective matmul if GetShardingFromUser() couldn't // determine the sharding. if (!original_ideal_sharding) { @@ -499,7 +542,7 @@ std::optional GetWindowedEinsumConfiguration( std::optional from_user = ShardingPropagation::GetShardingFromUser( *to_loop_over, *user, kAggressiveness, - /*is_spmd=*/true, call_graph); + /*is_spmd=*/true, call_graph, /*sharding_helper=*/nullptr); // Could't determine sharding. Skip to next one and pretend it wouldn't // share the resharding. if (!from_user) { @@ -648,27 +691,31 @@ std::optional GetWindowedEinsumConfiguration( rhs_shape_size >= options.threshold_for_windowed_einsum_mib * 1024 * 1024 && (!rhs || check_users_sharding(rhs)) && - !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/true)) { + !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/true) && + options.enable_windowed_einsum_for_all_gather) { if (rhs_contracting_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::RHS, /*windowed_at_contracting_dims*/ true, /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } if (rhs_non_contracting_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::RHS, /*windowed_at_contracting_dims*/ false, /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } if (rhs_batch_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::RHS, /*windowed_at_contracting_dims*/ false, /*windowed_at_batch_dims=*/true, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } } if (output_rhs_non_contracting_partitions == num_partitions && @@ -676,27 +723,31 @@ std::optional GetWindowedEinsumConfiguration( lhs_shape_size >= options.threshold_for_windowed_einsum_mib * 1024 * 1024 && (!lhs || check_users_sharding(lhs)) && - !disable_windowed_einsum(/*lhs_needs_ag=*/true, /*rhs_needs_ag=*/false)) { + !disable_windowed_einsum(/*lhs_needs_ag=*/true, /*rhs_needs_ag=*/false) && + options.enable_windowed_einsum_for_all_gather) { if (lhs_contracting_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::LHS, /*windowed_at_contracting_dims*/ true, /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } if (lhs_non_contracting_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::LHS, /*windowed_at_contracting_dims*/ false, /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } if (lhs_batch_partitions == num_partitions) { return WindowedEinsumConfig{ /*windowed_op=*/WindowedEinsumOperand::LHS, /*windowed_at_contracting_dims*/ false, /*windowed_at_batch_dims=*/true, - /*operands_sharded_at_contracting_dims=*/false}; + /*operands_sharded_at_contracting_dims=*/false, + /*is_ag_einsum=*/true}; } } if (lhs_contracting_partitions == rhs_contracting_partitions && @@ -706,20 +757,21 @@ std::optional GetWindowedEinsumConfiguration( output_shape_size >= options.threshold_for_windowed_einsum_mib * 1024 * 1024 && !disable_windowed_einsum(/*lhs_needs_ag=*/false, - /*rhs_needs_ag=*/false)) { + /*rhs_needs_ag=*/false) && + options.enable_windowed_einsum_for_reduce_scatter) { if (output_lhs_non_contracting_partitions == num_partitions) { - return WindowedEinsumConfig{ - /*windowed_op=*/WindowedEinsumOperand::RHS, - /*windowed_at_contracting_dims*/ false, - /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/true}; + return WindowedEinsumConfig{/*windowed_op=*/WindowedEinsumOperand::RHS, + /*windowed_at_contracting_dims*/ false, + /*windowed_at_batch_dims=*/false, + /*operands_sharded_at_contracting_dims=*/true, + /*is_ag_einsum=*/false}; } if (output_rhs_non_contracting_partitions == num_partitions) { - return WindowedEinsumConfig{ - /*windowed_op=*/WindowedEinsumOperand::LHS, - /*windowed_at_contracting_dims*/ false, - /*windowed_at_batch_dims=*/false, - /*operands_sharded_at_contracting_dims=*/true}; + return WindowedEinsumConfig{/*windowed_op=*/WindowedEinsumOperand::LHS, + /*windowed_at_contracting_dims*/ false, + /*windowed_at_batch_dims=*/false, + /*operands_sharded_at_contracting_dims=*/true, + /*is_ag_einsum=*/false}; } } return std::nullopt; @@ -775,13 +827,13 @@ std::vector GetLoopReplicaGroups(HloInstruction* while_loop) { // is tiled in other dimensions. Or both operands are partitioned in the same // way along contracting dimensions, but the output is partitioned along // non-contracting dimensions. -StatusOr EmitWindowedDotGeneral( +absl::StatusOr EmitWindowedDotGeneral( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, const SpmdPartitionerOptions& options, SpmdBuilder* b, @@ -940,13 +992,15 @@ StatusOr EmitWindowedDotGeneral( // Create a while loop that computes one window per iteration. During each // iteration, each partition sends its input window to its neighbor using // collective-permute for the next iteration. - SpmdBuilder body_b("windowed_dot_general_body", original_hlo); + std::string body_name = "windowed_dot_general_body"; + body_name += (einsum_config.is_ag_einsum) ? "_ag" : "_rs"; + SpmdBuilder body_b(body_name, original_hlo); // Generate partial results used by bidirectional algorithm. auto get_partial_bid_results = [&](HloInstruction* l, HloInstruction* r, HloInstruction* o, HloInstruction* extra_inout, HloInstruction* cw_cp_output, - HloInstruction* i) -> StatusOr> { + HloInstruction* i) -> absl::StatusOr> { auto partition_id = lhs.state().collective_ops_creator.create_partition_id(&body_b); auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( @@ -1318,7 +1372,7 @@ StatusOr EmitWindowedDotGeneral( // Generate partial result used by unidirectional algorithm. auto get_partial_unid_result = [&](HloInstruction* l, HloInstruction* r, HloInstruction* o, - HloInstruction* i) -> StatusOr { + HloInstruction* i) -> absl::StatusOr { auto partition_id = lhs.state().collective_ops_creator.create_partition_id(&body_b); auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( @@ -1646,7 +1700,9 @@ StatusOr EmitWindowedDotGeneral( HloInstruction::CreateTuple({l, r, o, extra_inout, i})); } - SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo); + std::string cond_name = "windowed_dot_general_cond"; + cond_name += (einsum_config.is_ag_einsum) ? "_ag" : "_rs"; + SpmdBuilder cond_b(cond_name, original_hlo); auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeTupleShapeWithPtrs( @@ -1717,21 +1773,15 @@ StatusOr EmitWindowedDotGeneral( // recursion as we group devices together. So refer to the passed in shapes and // shardings for inputs and output, and do not use shape inference. -StatusOr PartitionBaseCase( +absl::StatusOr PartitionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, - int64_t lhs_batch_partitions, int64_t rhs_batch_partitions, - int64_t output_batch_partitions, int64_t lhs_contracting_partitions, - int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions, - int64_t rhs_non_contracting_partitions, - int64_t output_lhs_non_contracting_partitions, - int64_t output_rhs_non_contracting_partitions, const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops, @@ -1739,6 +1789,27 @@ StatusOr PartitionBaseCase( SpmdPartitioningVisitor* visitor) { const HloSharding& lhs_sharding = lhs.sharding(); const HloSharding& rhs_sharding = rhs.sharding(); + const int64_t lhs_batch_partitions = GetPartitionsForDims( + lhs_sharding, dims_mapping.batch_dims, DotComponent::LHS); + const int64_t rhs_batch_partitions = GetPartitionsForDims( + rhs_sharding, dims_mapping.batch_dims, DotComponent::RHS); + const int64_t output_batch_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.batch_dims, DotComponent::OUTPUT); + const int64_t lhs_contracting_partitions = GetPartitionsForDims( + lhs_sharding, dims_mapping.contracting_dims, DotComponent::LHS); + const int64_t rhs_contracting_partitions = GetPartitionsForDims( + rhs_sharding, dims_mapping.contracting_dims, DotComponent::RHS); + const int64_t lhs_non_contracting_partitions = GetPartitionsForDims( + lhs_sharding, dims_mapping.lhs_non_contracting_dims, DotComponent::LHS); + const int64_t rhs_non_contracting_partitions = GetPartitionsForDims( + rhs_sharding, dims_mapping.rhs_non_contracting_dims, DotComponent::RHS); + const int64_t output_lhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.lhs_non_contracting_dims, + DotComponent::OUTPUT); + const int64_t output_rhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.rhs_non_contracting_dims, + DotComponent::OUTPUT); + if (lhs_sharding.ReplicateOnLastTileDim() || rhs_sharding.ReplicateOnLastTileDim() || output_sharding.ReplicateOnLastTileDim()) { @@ -1790,7 +1861,7 @@ StatusOr PartitionBaseCase( // may_reshard_with_allreduce is false, reshard must be done using // all-to-all/collective-permute; otherwise this attempt fails. auto try_emit_output_batch_partitioned_einsum_with_reshard = - [&](bool may_reshard_with_allreduce) -> StatusOr { + [&](bool may_reshard_with_allreduce) -> absl::StatusOr { // LHS and output are batch partitioned in the same way. if (lhs_batch_partitions == num_partitions && output_batch_partitions == num_partitions && @@ -2019,13 +2090,13 @@ StatusOr PartitionBaseCase( return nullptr; } -StatusOr PartitionDot( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, - int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr PartitionDot( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, const SpmdPartitionerOptions& options, SpmdBuilder* b, @@ -2033,15 +2104,15 @@ StatusOr PartitionDot( windowed_dot_general_loops, SpmdPartitioningVisitor* visitor); -StatusOr PartitionDotGroupOnBatch( +absl::StatusOr PartitionDotGroupOnBatchImpl( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, int64_t num_partitions, int64_t lhs_contracting_partitions, int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions, int64_t rhs_non_contracting_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, @@ -2141,7 +2212,7 @@ StatusOr PartitionDotGroupOnBatch( auto per_group_partitioner_state = CreatePerGroupPartitioningState( lhs.state(), output_grouped.device_groups, b); auto reshard_to_output_batch = - [&](PartitionedHlo operand, absl::Span batch_dims, + [&](const PartitionedHlo& operand, absl::Span batch_dims, absl::Span contracting_dims, absl::Span non_contracting_dims, int64_t contracting_dim_partitions, @@ -2417,7 +2488,7 @@ GetNonContractingPartitionGroupedShardingForOtherOperand( return std::nullopt; } -StatusOr PartitionDotGroupOnNonContracting( +absl::StatusOr PartitionDotGroupOnNonContractingImpl( bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, int64_t matching_contracting_partitions, int64_t other_contracting_partitions, @@ -2427,9 +2498,9 @@ StatusOr PartitionDotGroupOnNonContracting( int64_t output_other_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, @@ -2691,7 +2762,7 @@ GetDotGroupPartitionContractingLhsRhsShardings( return std::make_pair(lhs_sharding, rhs_sharding); } -StatusOr PartitionDotGroupOnContracting( +absl::StatusOr PartitionDotGroupOnContractingImpl( PartitionedHlo lhs, PartitionedHlo rhs, absl::Span partitioned_contracting_dims, @@ -2700,9 +2771,9 @@ StatusOr PartitionDotGroupOnContracting( int64_t output_rhs_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, bool require_matching_devices_to_group, @@ -2803,7 +2874,7 @@ StatusOr PartitionDotGroupOnContracting( } auto inner_creator = [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, - const Window& conv_window) -> StatusOr { + const Window& conv_window) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(auto inner_dot, create_sharded_dot(l, r, b, conv_window)); HloInstruction* result = inner_dot; @@ -3122,9 +3193,9 @@ bool PrioritizeContractingDimensionsPartitioning( int64_t rhs_batch_partitions, int64_t output_batch_partitions, bool require_matching_devices_to_group, SpmdBuilder* b, const Window& conv_window, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, SpmdPartitioningVisitor* visitor) { const bool may_group_on_lhs_non_contracting = @@ -3358,9 +3429,9 @@ bool LhsIsBestMatchForNonContractingPartitioning( int64_t output_lhs_non_contracting_partitions, int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions, int64_t rhs_batch_partitions, SpmdBuilder* b, const Window& conv_window, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, SpmdPartitioningVisitor* visitor) { const bool may_group_on_lhs_non_contracting = @@ -3475,78 +3546,221 @@ bool LhsIsBestMatchForNonContractingPartitioning( return lhs_matching; } -// Recursive partitioning function. If there are partial dimensions matching -// in the operands and output, group the devices and recursively partition -// the in-group dot. -StatusOr PartitionDot( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, - int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr> +PartitionConvOnBatchOrFeatureGroupedDims( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, - bool require_matching_devices_to_group, const SpmdPartitionerOptions& options, SpmdBuilder* b, std::vector* windowed_dot_general_loops, - SpmdPartitioningVisitor* visitor) { - // If lhs‘ hlo and rhs' hlo are identical, make a copy for rhs. - if (lhs.hlo() == rhs.hlo()) { - auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary( - rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo())); - copy_hlo->copy_sharding(rhs.hlo()); - rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state()); - } - - // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. - auto get_partitions_for_dims = - [&](const HloSharding& sharding, - absl::Span dims, - int lhs_rhs_or_output) { - int64_t partitions = 1; - if (sharding.IsTileMaximal()) { - return partitions; - } - for (const auto& dim : dims) { - if (lhs_rhs_or_output == 0) { - partitions *= ShardCountAtDim(sharding, dim.lhs); - } else if (lhs_rhs_or_output == 1) { - partitions *= ShardCountAtDim(sharding, dim.rhs); - } else { - CHECK_EQ(lhs_rhs_or_output, 2); - partitions *= ShardCountAtDim(sharding, dim.output); - } - } - return partitions; - }; - const int64_t lhs_batch_partitions = - get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0); - const int64_t rhs_batch_partitions = - get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1); - const int64_t output_batch_partitions = - get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2); - const int64_t lhs_contracting_partitions = - get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0); - const int64_t rhs_contracting_partitions = - get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1); - const int64_t lhs_non_contracting_partitions = get_partitions_for_dims( - lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0); - const int64_t rhs_non_contracting_partitions = get_partitions_for_dims( - rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1); - const int64_t output_lhs_non_contracting_partitions = get_partitions_for_dims( - output_sharding, dims_mapping.lhs_non_contracting_dims, 2); - const int64_t output_rhs_non_contracting_partitions = get_partitions_for_dims( - output_sharding, dims_mapping.rhs_non_contracting_dims, 2); - const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims( - lhs.sharding(), dims_mapping.conv_spatial_dims, 0); - const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims( - rhs.sharding(), dims_mapping.conv_spatial_dims, 1); - const int64_t output_conv_spatial_partitions = get_partitions_for_dims( - output_sharding, dims_mapping.conv_spatial_dims, 2); - // Before we find partial matches along the dimensions, invoke base case - // again without may_reshard_without_detecting_match. + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { + if (original_hlo->feature_group_count() > 1 || + original_hlo->batch_group_count() > 1) { + const auto& dnums = original_hlo->convolution_dimension_numbers(); + std::optional new_dims_mapping; + if (original_hlo->feature_group_count() > 1) { + const int64_t input_feature_dim = dnums.input_feature_dimension(); + const int64_t kernel_output_feature_dim = + dnums.kernel_output_feature_dimension(); + // If the input and output feature dims are not equal, we require the + // feature_group_count to be evenly partitioned; otherwise, there will + // be different padding in the input/output. + // TODO(xla): Use halo exchange to solve this problem. Can be a + // preprocessing that uses padding/slicing to make the shape evenly + // shardable. + if (lhs.base_shape().dimensions(input_feature_dim) == + rhs.base_shape().dimensions(kernel_output_feature_dim) || + (lhs.sharding().IsTiled() && + original_hlo->feature_group_count() % + ShardCountAtDim(lhs.sharding(), input_feature_dim) == + 0)) { + new_dims_mapping = + ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo); + } + } + + if (original_hlo->batch_group_count() > 1) { + const int64_t input_batch_dim = dnums.input_batch_dimension(); + const int64_t kernel_output_feature_dim = + dnums.kernel_output_feature_dimension(); + if (lhs.base_shape().dimensions(input_batch_dim) == + rhs.base_shape().dimensions(kernel_output_feature_dim) || + (lhs.sharding().IsTiled() && + original_hlo->batch_group_count() % + ShardCountAtDim(lhs.sharding(), input_batch_dim) == + 0)) { + new_dims_mapping = + ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo); + } + } + if (!new_dims_mapping.has_value()) { + return nullptr; + } + + const int64_t conv_lhs_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), new_dims_mapping->contracting_dims, DotComponent::LHS); + const int64_t conv_rhs_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), new_dims_mapping->contracting_dims, DotComponent::RHS); + const int64_t conv_lhs_non_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), new_dims_mapping->lhs_non_contracting_dims, + DotComponent::LHS); + const int64_t conv_rhs_non_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), new_dims_mapping->rhs_non_contracting_dims, + DotComponent::RHS); + const int64_t conv_lhs_batch_partitions = GetPartitionsForDims( + lhs.sharding(), new_dims_mapping->batch_dims, DotComponent::LHS); + const int64_t conv_rhs_batch_partitions = GetPartitionsForDims( + rhs.sharding(), new_dims_mapping->batch_dims, DotComponent::RHS); + const int64_t conv_output_batch_partitions = GetPartitionsForDims( + output_sharding, new_dims_mapping->batch_dims, DotComponent::OUTPUT); + if ((conv_lhs_batch_partitions == conv_output_batch_partitions || + conv_rhs_batch_partitions == conv_output_batch_partitions) && + conv_output_batch_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto try_partitioned_conv, + PartitionDotGroupOnBatchImpl( + lhs, rhs, output_base_shape, output_sharding, *new_dims_mapping, + num_partitions, conv_lhs_contracting_partitions, + conv_rhs_contracting_partitions, + conv_lhs_non_contracting_partitions, + conv_rhs_non_contracting_partitions, create_sharded_dot, + conv_window, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops, visitor)); + if (try_partitioned_conv) { + return try_partitioned_conv; + } + } + // For batch/feature grouped convs, we try to at least partiton them on + // the batch dimensions and partially replicate other dimensions, instead + // of replicating everything. + const int64_t max_batch_partitions = + std::max(std::max(conv_lhs_batch_partitions, conv_rhs_batch_partitions), + conv_output_batch_partitions); + if (!require_matching_devices_to_group && max_batch_partitions > 1 && + ((original_hlo->batch_group_count() > 1 && + original_hlo->batch_group_count() % max_batch_partitions == 0) || + (original_hlo->feature_group_count() > 1 && + original_hlo->feature_group_count() % max_batch_partitions == 0))) { + const int64_t conv_lhs_batch_dim = original_hlo->batch_group_count() > 1 + ? dnums.input_batch_dimension() + : dnums.input_feature_dimension(); + const int64_t conv_rhs_batch_dim = + dnums.kernel_output_feature_dimension(); + const int64_t conv_output_batch_dim = dnums.output_feature_dimension(); + PartitionedHlo resharded_lhs = lhs; + PartitionedHlo resharded_rhs = rhs; + HloSharding aligned_output_sharding = HloSharding::Replicate(); + HloInstruction* sharded_conv = nullptr; + DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping( + *new_dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(), + output_base_shape.rank()); + if (max_batch_partitions == conv_lhs_batch_partitions) { + resharded_lhs = resharded_lhs.Reshard( + hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( + resharded_lhs.sharding(), {conv_lhs_batch_dim})); + auto lhs_sharding_transposed_to_match_rhs = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + resharded_lhs.sharding(), indices_map.lhs_to_rhs_indices, + indices_map.rhs_to_lhs_indices); + resharded_rhs = + resharded_rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + sharded_conv, + create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, + conv_window)); + auto lhs_sharding_transposed_to_match_output = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + resharded_lhs.sharding(), indices_map.lhs_to_output_indices, + indices_map.output_to_lhs_indices); + sharded_conv->set_sharding(*lhs_sharding_transposed_to_match_output); + } else if (max_batch_partitions == conv_rhs_batch_partitions) { + resharded_rhs = resharded_rhs.Reshard( + hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( + resharded_rhs.sharding(), {conv_rhs_batch_dim})); + auto rhs_sharding_transposed_to_match_lhs = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + resharded_rhs.sharding(), indices_map.rhs_to_lhs_indices, + indices_map.lhs_to_rhs_indices); + resharded_lhs = + resharded_lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + sharded_conv, + create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, + conv_window)); + auto rhs_sharding_transposed_to_match_output = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + resharded_rhs.sharding(), indices_map.rhs_to_output_indices, + indices_map.output_to_rhs_indices); + sharded_conv->set_sharding(*rhs_sharding_transposed_to_match_output); + } else { + CHECK_EQ(max_batch_partitions, conv_output_batch_partitions); + HloSharding target_output_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( + output_sharding, {conv_output_batch_dim}); + auto output_sharding_transposed_to_match_lhs = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + target_output_sharding, indices_map.output_to_lhs_indices, + indices_map.lhs_to_output_indices); + resharded_lhs = + resharded_lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto output_sharding_transposed_to_match_rhs = + hlo_sharding_util::TransposeShardingWithCollapsedDims( + target_output_sharding, indices_map.output_to_rhs_indices, + indices_map.rhs_to_output_indices); + resharded_rhs = + resharded_rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + sharded_conv, + create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, + conv_window)); + sharded_conv->set_sharding(target_output_sharding); + } + + return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); + } + return nullptr; + } + return std::nullopt; +} + +absl::StatusOr> PartitionConv( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { + const int64_t lhs_batch_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.batch_dims, DotComponent::LHS); + const int64_t rhs_batch_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.batch_dims, DotComponent::RHS); + const int64_t output_batch_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.batch_dims, DotComponent::OUTPUT); + const int64_t lhs_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.contracting_dims, DotComponent::LHS); + const int64_t rhs_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.contracting_dims, DotComponent::RHS); + const int64_t lhs_conv_spatial_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.conv_spatial_dims, DotComponent::LHS); + const int64_t rhs_conv_spatial_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.conv_spatial_dims, DotComponent::RHS); + const int64_t output_conv_spatial_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.conv_spatial_dims, DotComponent::OUTPUT); // Try partition the purely spatially-partitioned convolution with // convolution spatial dimension partitioned or depthwise parallel @@ -3574,7 +3788,7 @@ StatusOr PartitionDot( } TF_ASSIGN_OR_RETURN( - auto partitioned_conv, + HloInstruction * partitioned_conv, PartitionConvolution(lhs, rhs, output_base_shape, output_sharding, dims_mapping, create_sharded_dot, conv_window, original_hlo, num_partitions, options, @@ -3583,206 +3797,53 @@ StatusOr PartitionDot( if (partitioned_conv) { return partitioned_conv; } - - // Recursively partition on different types of dimensions for - // convolution. Case 0.a: Group partitions by feature group count. - if (original_hlo->feature_group_count() > 1 || - original_hlo->batch_group_count() > 1) { - const auto& dnums = original_hlo->convolution_dimension_numbers(); - std::optional new_dims_mapping; - if (original_hlo->feature_group_count() > 1) { - const int64_t input_feature_dim = dnums.input_feature_dimension(); - const int64_t kernel_output_feature_dim = - dnums.kernel_output_feature_dimension(); - // If the input and output feature dims are not equal, we require the - // feature_group_count to be evenly partitioned; otherwise, there will - // be different padding in the input/output. - // TODO(xla): Use halo exchange to solve this problem. Can be a - // preprocessing that uses padding/slicing to make the shape evenly - // shardable. - if (lhs.base_shape().dimensions(input_feature_dim) == - rhs.base_shape().dimensions(kernel_output_feature_dim) || - (lhs.sharding().IsTiled() && - original_hlo->feature_group_count() % - ShardCountAtDim(lhs.sharding(), input_feature_dim) == - 0)) { - new_dims_mapping = ConvertDimsMappingWithFeatureGroupCount( - dims_mapping, original_hlo); - } - } - - if (original_hlo->batch_group_count() > 1) { - const int64_t input_batch_dim = dnums.input_batch_dimension(); - const int64_t kernel_output_feature_dim = - dnums.kernel_output_feature_dimension(); - if (lhs.base_shape().dimensions(input_batch_dim) == - rhs.base_shape().dimensions(kernel_output_feature_dim) || - (lhs.sharding().IsTiled() && - original_hlo->batch_group_count() % - ShardCountAtDim(lhs.sharding(), input_batch_dim) == - 0)) { - new_dims_mapping = - ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo); - } - } - if (!new_dims_mapping.has_value()) { - return nullptr; - } - - const int64_t conv_lhs_contracting_partitions = get_partitions_for_dims( - lhs.sharding(), new_dims_mapping->contracting_dims, 0); - const int64_t conv_rhs_contracting_partitions = get_partitions_for_dims( - rhs.sharding(), new_dims_mapping->contracting_dims, 1); - const int64_t conv_lhs_non_contracting_partitions = - get_partitions_for_dims( - lhs.sharding(), new_dims_mapping->lhs_non_contracting_dims, 0); - const int64_t conv_rhs_non_contracting_partitions = - get_partitions_for_dims( - rhs.sharding(), new_dims_mapping->rhs_non_contracting_dims, 1); - const int64_t conv_lhs_batch_partitions = get_partitions_for_dims( - lhs.sharding(), new_dims_mapping->batch_dims, 0); - const int64_t conv_rhs_batch_partitions = get_partitions_for_dims( - rhs.sharding(), new_dims_mapping->batch_dims, 1); - const int64_t conv_output_batch_partitions = get_partitions_for_dims( - output_sharding, new_dims_mapping->batch_dims, 2); - if ((conv_lhs_batch_partitions == conv_output_batch_partitions || - conv_rhs_batch_partitions == conv_output_batch_partitions) && - conv_output_batch_partitions > 1) { - TF_ASSIGN_OR_RETURN( - auto try_partitioned_conv, - PartitionDotGroupOnBatch( - lhs, rhs, output_base_shape, output_sharding, *new_dims_mapping, - num_partitions, conv_lhs_contracting_partitions, - conv_rhs_contracting_partitions, - conv_lhs_non_contracting_partitions, - conv_rhs_non_contracting_partitions, create_sharded_dot, - conv_window, module, original_hlo, - require_matching_devices_to_group, options, b, - windowed_dot_general_loops, visitor)); - if (try_partitioned_conv) { - return try_partitioned_conv; - } - } - // For batch/feature grouped convs, we try to at least partiton them on - // the batch dimensions and partially replicate other dimensions, instead - // of replicating everything. - const int64_t max_batch_partitions = std::max( - std::max(conv_lhs_batch_partitions, conv_rhs_batch_partitions), - conv_output_batch_partitions); - if (!require_matching_devices_to_group && max_batch_partitions > 1 && - ((original_hlo->batch_group_count() > 1 && - original_hlo->batch_group_count() % max_batch_partitions == 0) || - (original_hlo->feature_group_count() > 1 && - original_hlo->feature_group_count() % max_batch_partitions == 0))) { - const int64_t conv_lhs_batch_dim = - original_hlo->batch_group_count() > 1 - ? dnums.input_batch_dimension() - : dnums.input_feature_dimension(); - const int64_t conv_rhs_batch_dim = - dnums.kernel_output_feature_dimension(); - const int64_t conv_output_batch_dim = dnums.output_feature_dimension(); - PartitionedHlo resharded_lhs = lhs; - PartitionedHlo resharded_rhs = rhs; - HloSharding aligned_output_sharding = HloSharding::Replicate(); - HloInstruction* sharded_conv = nullptr; - DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping( - *new_dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(), - output_base_shape.rank()); - if (max_batch_partitions == conv_lhs_batch_partitions) { - resharded_lhs = resharded_lhs.Reshard( - hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( - resharded_lhs.sharding(), {conv_lhs_batch_dim})); - auto lhs_sharding_transposed_to_match_rhs = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - resharded_lhs.sharding(), indices_map.lhs_to_rhs_indices, - indices_map.rhs_to_lhs_indices); - resharded_rhs = - resharded_rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN( - sharded_conv, - create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, - conv_window)); - auto lhs_sharding_transposed_to_match_output = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - resharded_lhs.sharding(), indices_map.lhs_to_output_indices, - indices_map.output_to_lhs_indices); - sharded_conv->set_sharding(*lhs_sharding_transposed_to_match_output); - } else if (max_batch_partitions == conv_rhs_batch_partitions) { - resharded_rhs = resharded_rhs.Reshard( - hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( - resharded_rhs.sharding(), {conv_rhs_batch_dim})); - auto rhs_sharding_transposed_to_match_lhs = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - resharded_rhs.sharding(), indices_map.rhs_to_lhs_indices, - indices_map.lhs_to_rhs_indices); - resharded_lhs = - resharded_lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); - TF_ASSIGN_OR_RETURN( - sharded_conv, - create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, - conv_window)); - auto rhs_sharding_transposed_to_match_output = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - resharded_rhs.sharding(), indices_map.rhs_to_output_indices, - indices_map.output_to_rhs_indices); - sharded_conv->set_sharding(*rhs_sharding_transposed_to_match_output); - } else { - // max_batch_partitions == conv_output_batch_partitions - HloSharding target_output_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept( - output_sharding, {conv_output_batch_dim}); - auto output_sharding_transposed_to_match_lhs = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - target_output_sharding, indices_map.output_to_lhs_indices, - indices_map.lhs_to_output_indices); - resharded_lhs = - resharded_lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto output_sharding_transposed_to_match_rhs = - hlo_sharding_util::TransposeShardingWithCollapsedDims( - target_output_sharding, indices_map.output_to_rhs_indices, - indices_map.rhs_to_output_indices); - resharded_rhs = - resharded_rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN( - sharded_conv, - create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(), b, - conv_window)); - sharded_conv->set_sharding(target_output_sharding); - } - - return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) - .Reshard(output_sharding) - .hlo(); - } - return nullptr; + TF_ASSIGN_OR_RETURN( + std::optional partitioned_conv_depthwise, + PartitionConvOnBatchOrFeatureGroupedDims( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_conv_depthwise.has_value()) { + return partitioned_conv_depthwise.value(); } } + return std::nullopt; +} - TF_ASSIGN_OR_RETURN( - auto try_partitioned_dot, - PartitionBaseCase( - lhs, rhs, output_base_shape, output_sharding, dims_mapping, - num_partitions, create_sharded_dot, conv_window, module, original_hlo, - lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, - lhs_contracting_partitions, rhs_contracting_partitions, - lhs_non_contracting_partitions, rhs_non_contracting_partitions, - output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, options, b, - windowed_dot_general_loops, - /*may_reshard_without_detecting_match=*/false, visitor)); - if (try_partitioned_dot) { - return try_partitioned_dot; - } - - // Recursively partition on different types of dimensions. - // - // Case 1: Group partitions by batch. +absl::StatusOr PartitionDotGroupOnBatchDims( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { + const int64_t lhs_batch_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.batch_dims, DotComponent::LHS); + const int64_t rhs_batch_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.batch_dims, DotComponent::RHS); + const int64_t output_batch_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.batch_dims, DotComponent::OUTPUT); + const int64_t lhs_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.contracting_dims, DotComponent::LHS); + const int64_t rhs_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.contracting_dims, DotComponent::RHS); + const int64_t lhs_non_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.lhs_non_contracting_dims, DotComponent::LHS); + const int64_t rhs_non_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.rhs_non_contracting_dims, DotComponent::RHS); if ((lhs_batch_partitions == output_batch_partitions || rhs_batch_partitions == output_batch_partitions) && output_batch_partitions > 1) { TF_ASSIGN_OR_RETURN( auto dot, - PartitionDotGroupOnBatch( + PartitionDotGroupOnBatchImpl( lhs, rhs, output_base_shape, output_sharding, dims_mapping, num_partitions, lhs_contracting_partitions, rhs_contracting_partitions, lhs_non_contracting_partitions, @@ -3793,8 +3854,42 @@ StatusOr PartitionDot( return dot; } } + return nullptr; +} - // Case 2: Group partitions by non-contracting dimensions. +absl::StatusOr PartitionDotGroupOnNonContractingDims( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { + const int64_t lhs_batch_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.batch_dims, DotComponent::LHS); + const int64_t rhs_batch_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.batch_dims, DotComponent::RHS); + const int64_t output_batch_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.batch_dims, DotComponent::OUTPUT); + const int64_t lhs_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.contracting_dims, DotComponent::LHS); + const int64_t rhs_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.contracting_dims, DotComponent::RHS); + const int64_t lhs_non_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.lhs_non_contracting_dims, DotComponent::LHS); + const int64_t rhs_non_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.rhs_non_contracting_dims, DotComponent::RHS); + const int64_t output_lhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.lhs_non_contracting_dims, + DotComponent::OUTPUT); + const int64_t output_rhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.rhs_non_contracting_dims, + DotComponent::OUTPUT); const bool may_group_on_lhs_non_contracting = lhs_non_contracting_partitions == output_lhs_non_contracting_partitions && lhs_non_contracting_partitions > 1; @@ -3849,7 +3944,7 @@ StatusOr PartitionDot( prioritize_contracting_for_faster_windowed_einsum)) { TF_ASSIGN_OR_RETURN( auto dot, - PartitionDotGroupOnNonContracting( + PartitionDotGroupOnNonContractingImpl( lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs, lhs_matching ? lhs_contracting_partitions : rhs_contracting_partitions, @@ -3868,13 +3963,39 @@ StatusOr PartitionDot( return dot; } } + return nullptr; +} - // Case 3: Group partitions by contracting dimensions. +absl::StatusOr PartitionDotGroupOnContractingDims( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { + const int64_t output_batch_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.batch_dims, DotComponent::OUTPUT); + const int64_t lhs_contracting_partitions = GetPartitionsForDims( + lhs.sharding(), dims_mapping.contracting_dims, DotComponent::LHS); + const int64_t rhs_contracting_partitions = GetPartitionsForDims( + rhs.sharding(), dims_mapping.contracting_dims, DotComponent::RHS); + const int64_t output_lhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.lhs_non_contracting_dims, + DotComponent::OUTPUT); + const int64_t output_rhs_non_contracting_partitions = GetPartitionsForDims( + output_sharding, dims_mapping.rhs_non_contracting_dims, + DotComponent::OUTPUT); if (lhs_contracting_partitions == rhs_contracting_partitions && lhs_contracting_partitions > 1) { TF_ASSIGN_OR_RETURN( auto dot, - PartitionDotGroupOnContracting( + PartitionDotGroupOnContractingImpl( lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions, output_lhs_non_contracting_partitions, output_rhs_non_contracting_partitions, output_base_shape, @@ -3898,7 +4019,7 @@ StatusOr PartitionDot( } if (!matching_dims.empty()) { TF_ASSIGN_OR_RETURN( - auto dot, PartitionDotGroupOnContracting( + auto dot, PartitionDotGroupOnContractingImpl( lhs, rhs, matching_dims, output_batch_partitions, output_lhs_non_contracting_partitions, output_rhs_non_contracting_partitions, @@ -3911,9 +4032,22 @@ StatusOr PartitionDot( } } } + return nullptr; +} - // Case 4: If operands are replicated but output is partially replicated, - // recursive call with partial replication removed. +absl::StatusOr PartitionDotRemovingOutputPartialReplication( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + bool require_matching_devices_to_group, SpmdPartitioningVisitor* visitor) { if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() && output_sharding.ReplicateOnLastTileDim()) { auto grouped_output = hlo_sharding_util::GroupShardingOnDims( @@ -3932,34 +4066,137 @@ StatusOr PartitionDot( return dot; } } + return nullptr; +} - // We failed to find partial matches, invoke base case again with - // may_reshard_without_detecting_match. +// Recursive partitioning function. If there are partial dimensions matching +// in the operands and output, group the devices and recursively partition +// the in-group dot. +absl::StatusOr PartitionDot( + const PartitionedHlo& lhs, const PartitionedHlo& raw_rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> + create_sharded_dot, + const Window& conv_window, HloModule* module, HloInstruction* original_hlo, + bool require_matching_devices_to_group, + const SpmdPartitionerOptions& options, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops, + SpmdPartitioningVisitor* visitor) { + // If lhs' hlo and rhs' hlo are identical, make a copy for rhs. + std::unique_ptr new_rhs; + if (lhs.hlo() == raw_rhs.hlo()) { + auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary( + raw_rhs.hlo()->shape(), HloOpcode::kCopy, raw_rhs.hlo())); + copy_hlo->copy_sharding(raw_rhs.hlo()); + new_rhs = std::make_unique(copy_hlo, raw_rhs.base_shape(), + raw_rhs.state()); + } + const PartitionedHlo& rhs = (lhs.hlo() == raw_rhs.hlo()) ? *new_rhs : raw_rhs; + + // Recursively partition on different types of dimensions. + + // Case 0: Try partition the purely spatially-partitioned convolution with + // convolution spatial dimension partitioned or depthwise parallel + // dimension partitioned. TF_ASSIGN_OR_RETURN( - auto dot, + std::optional partitioned_conv, + PartitionConv(lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, + original_hlo, options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_conv.has_value()) { + return partitioned_conv.value(); + } + + HloInstruction* partitioned_dot; + // Before we find partial matches along the dimensions, invoke base case + // again without may_reshard_without_detecting_match. + TF_ASSIGN_OR_RETURN( + partitioned_dot, PartitionBaseCase( lhs, rhs, output_base_shape, output_sharding, dims_mapping, num_partitions, create_sharded_dot, conv_window, module, original_hlo, - lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, - lhs_contracting_partitions, rhs_contracting_partitions, - lhs_non_contracting_partitions, rhs_non_contracting_partitions, - output_lhs_non_contracting_partitions, - output_rhs_non_contracting_partitions, options, b, - windowed_dot_general_loops, - /*may_reshard_without_detecting_match=*/true, visitor)); - if (dot) { - return dot; + options, b, windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/false, visitor)); + if (partitioned_dot) { + return partitioned_dot; + } + + // Case 1: Group partitions by batch. + TF_ASSIGN_OR_RETURN( + partitioned_dot, + PartitionDotGroupOnBatchDims( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_dot) { + return partitioned_dot; + } + + // Case 2: Group partitions by non-contracting dimensions. + TF_ASSIGN_OR_RETURN( + partitioned_dot, + PartitionDotGroupOnNonContractingDims( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_dot) { + return partitioned_dot; + } + + // Case 3: Group partitions by contracting dimensions. + TF_ASSIGN_OR_RETURN( + partitioned_dot, + PartitionDotGroupOnContractingDims( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_dot) { + return partitioned_dot; + } + + // Case 4: If operands are replicated but output is partially replicated, + // recursive call with partial replication removed. + TF_ASSIGN_OR_RETURN( + partitioned_dot, + PartitionDotRemovingOutputPartialReplication( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, + require_matching_devices_to_group, visitor)); + if (partitioned_dot) { + return partitioned_dot; + } + + // We failed to find partial matches, invoke base case again with + // may_reshard_without_detecting_match. + TF_ASSIGN_OR_RETURN( + partitioned_dot, + PartitionBaseCase(lhs, rhs, output_base_shape, output_sharding, + dims_mapping, num_partitions, create_sharded_dot, + conv_window, module, original_hlo, options, b, + windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/true, visitor)); + if (partitioned_dot) { + return partitioned_dot; } return nullptr; } -StatusOr PartitionDot( - PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, - const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, - int64_t num_partitions, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> +absl::StatusOr PartitionDot( + const PartitionedHlo& lhs, const PartitionedHlo& rhs, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotConvDimsMapping& dims_mapping, int64_t num_partitions, + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot, const Window& conv_window, HloModule* module, HloInstruction* original_hlo, const SpmdPartitionerOptions& options, SpmdBuilder* b, @@ -3994,9 +4231,9 @@ StatusOr PartitionDot( Status SpmdPartitioningVisitor::HandleDotHelper( HloInstruction* hlo, const DotConvDimsMapping& dims_mapping, - absl::FunctionRef(HloInstruction*, - HloInstruction*, SpmdBuilder*, - const Window& conv_window)> + absl::FunctionRef( + HloInstruction*, HloInstruction*, SpmdBuilder*, + const Window& conv_window)> create_sharded_dot) { if (hlo->sharding().HasUniqueDevice()) { return DefaultAction(hlo); @@ -4022,7 +4259,7 @@ namespace { // Finds a cluster of nodes that produce the inputs for `hlo` which only // depend on small operands, which means the cluster should start with // broadcasts, constants and iotas. All other internal nodes must be -// non-side-effecting elemntwise ops. Returns the set of nodes, and the small +// non-side-effecting elementwise ops. Returns the set of nodes, and the small // operands. E.g., for the following graph, // // a -> broadcast -> multiply @@ -4563,7 +4800,8 @@ Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( [&](absl::flat_hash_map& outside_to_inside, absl::Span slice_offsets, - HloInstruction* last_iter_result) -> StatusOr { + HloInstruction* last_iter_result) + -> absl::StatusOr { HloInstruction* operand0 = outside_to_inside[reduce_outside->operand(0)]; HloInstruction* operand1 = outside_to_inside[reduce_outside->operand(1)]; TF_ASSIGN_OR_RETURN( diff --git a/xla/service/spmd/fft_handler.cc b/xla/service/spmd/fft_handler.cc index 5035dcf42b8d3..76cf9abedd8ea 100644 --- a/xla/service/spmd/fft_handler.cc +++ b/xla/service/spmd/fft_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc index 951efe6f61204..0fcea7793414c 100644 --- a/xla/service/spmd/gather_scatter_handler.cc +++ b/xla/service/spmd/gather_scatter_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,14 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -68,14 +70,25 @@ PartitionedHlo PerGroupPartitionedHlo( } // Helper to get multiple per-group partitioned hlos. -absl::InlinedVector PerGroupPartitionedHlos( - absl::Span phlos, const GroupedSharding& grouped_sharding, +std::vector PerGroupPartitionedHlos( + std::vector& phlos, const GroupedSharding& grouped_sharding, SpmdBuilder* b, absl::InlinedVector, 3>& clean_ups) { - absl::InlinedVector per_group_phlos; - absl::c_transform( - phlos, std::back_inserter(per_group_phlos), [&](PartitionedHlo& phlo) { - return PerGroupPartitionedHlo(phlo, grouped_sharding, b, clean_ups); - }); + // Cache per-group partitioned hlos to avoid group-partitioning it more than + // once. + absl::flat_hash_map cached_per_group_hlos; + std::vector hlos; + absl::c_transform(phlos, std::back_inserter(hlos), + [&](PartitionedHlo phlo) { return phlo.hlo(); }); + + std::vector per_group_phlos; + for (int i = 0; i != hlos.size(); ++i) { + if (!cached_per_group_hlos.contains(hlos[i])) { + cached_per_group_hlos.emplace(std::make_pair( + hlos[i], + PerGroupPartitionedHlo(phlos[i], grouped_sharding, b, clean_ups))); + } + per_group_phlos.push_back(cached_per_group_hlos.at(hlos[i])); + } return per_group_phlos; } @@ -131,7 +144,9 @@ std::vector GatherIndexDimsByPriority( const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); std::vector priority_dims_for_indices; - priority_dims_for_indices.push_back(dnums.index_vector_dim()); + if (dnums.index_vector_dim() < indices.rank()) { + priority_dims_for_indices.push_back(dnums.index_vector_dim()); + } absl::InlinedVector index_passthrough_dims = hlo_sharding_util::GetGatherScatterIndexPassthroughIndexDims( indices.rank(), dnums.index_vector_dim()); @@ -233,19 +248,21 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( } // Function that tries to perform recursive partitioning of Gather. -StatusOr PartitionGather( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGather( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor); + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive = true); // Perform partitioning of Gather when the indices are partitioned on the // non-index vector dimension. -StatusOr PartitionGatherIndexPassthroughDimensions( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGatherIndexPassthroughDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -300,10 +317,12 @@ StatusOr PartitionGatherIndexPassthroughDimensions( TF_ASSIGN_OR_RETURN( HloInstruction * pgather, PartitionGather(gather, per_group_operand, per_group_indices, pshape, - output_grouped.sharding, batch_dims, slice_sizes, - visitor)); + output_grouped.sharding, batch_dims, slice_sizes, visitor, + allow_recursive)); pgather->set_sharding(passthrough_sharding); - VLOG(5) << "[Gather partitioning]: Partitioned as index only"; + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as index only"; + } return PartitionedHlo(pgather, gather->shape(), operand.state()) .Reshard(output_sharding) .hlo(); @@ -312,11 +331,12 @@ StatusOr PartitionGatherIndexPassthroughDimensions( // Perform partitioning of Gather when the operand is split in a offset // dimension that is passed through (slice size is the same size of the operand // dimension). -StatusOr PartitionGatherOperandPassthroughDimensions( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGatherOperandPassthroughDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { if (operand.sharding().IsTileMaximal()) { return nullptr; } @@ -379,10 +399,12 @@ StatusOr PartitionGatherOperandPassthroughDimensions( HloInstruction * pgather, PartitionGather(gather, per_group_operand, per_group_indices, pshape, output_grouped.sharding, batch_dims, pslice_sizes, - visitor)); + visitor, allow_recursive)); pgather->set_sharding(*maybe_passthrough); - VLOG(5) << "[Gather partitioning]: Partitioned as operand passthrough " - "offset_dim"; + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as operand passthrough " + "offset_dim"; + } return PartitionedHlo(pgather, output_shape, operand.state()) .Reshard(output_sharding) .hlo(); @@ -392,11 +414,12 @@ StatusOr PartitionGatherOperandPassthroughDimensions( // Partition a Gather when its sliced in a dimension in the operand that is // trivially sliced (sliced with slice size of 1). -StatusOr PartitionGatherTrivialSlicedOperandDimensions( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -492,7 +515,7 @@ StatusOr PartitionGatherTrivialSlicedOperandDimensions( HloInstruction * pgather, PartitionGather(gather, per_group_operand, per_group_new_indices, pshape, output_grouped.sharding, batch_dims, - slice_sizes, visitor)); + slice_sizes, visitor, allow_recursive)); // Mask out invalid results. auto filter = b->AddInstruction(HloInstruction::CreateCompare( ShapeUtil::ChangeElementType(indices.hlo()->shape(), PRED), @@ -543,8 +566,10 @@ StatusOr PartitionGatherTrivialSlicedOperandDimensions( *trivial_slice_dims, operand.state().collective_ops_creator, MakeBinaryAdd(filtered->shape().element_type(), operand.state().module)); - VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand " - "batch_dim slice"; + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand " + "batch_dim slice"; + } ar->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); return PartitionedHlo(ar, output_shape, operand.state()) .Reshard(output_sharding) @@ -557,11 +582,12 @@ StatusOr PartitionGatherTrivialSlicedOperandDimensions( // (which means that the indices access the operand in a monotonically // increasing way across the respective operand dimension referenced by the // index). -StatusOr PartitionGatherIndexParallelDimensions( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGatherIndexParallelDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -673,8 +699,10 @@ StatusOr PartitionGatherIndexParallelDimensions( HloInstruction * pgather, PartitionGather(gather, per_group_operand, per_group_new_indices, pshape, output_grouped.sharding, batch_dims, - slice_sizes, visitor)); - VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; + slice_sizes, visitor, allow_recursive)); + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; + } pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); return PartitionedHlo(pgather, output_shape, operand.state()) .Reshard(output_sharding) @@ -697,76 +725,24 @@ GatherPartitionMethods() { "PartitionGatherIndexPassthroughDimensions"}}; } -// Estimates the cost for each partitioning methods for gather. -int64_t GatherPartitionMethodCostModel( +// Estimates the memory and communication cost for each partitioning methods for +// gather. +std::pair GatherPartitionMethodCostModel( decltype(PartitionGather)* partition_method, - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, + const HloGatherInstruction* gather, const PartitionedHlo& operand, + const PartitionedHlo& indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { if (partition_method == PartitionGatherIndexParallelDimensions) { - // Always prioritize index parallel paritioning, and assume it has zero + // Always prioritize index parallel partitioning, and assume it has zero // cost. - return 0; - } - if (partition_method == PartitionGatherOperandPassthroughDimensions) { - auto operand_passthrough_sharding = hlo_sharding_util:: - GatherOutputShardingFromOperandOperandPassthroughDimensions( - operand.base_shape(), operand.sharding(), *gather, slice_sizes); - if (!operand_passthrough_sharding) { - return INT64_MAX; - } - // Consider the potential cost of having to rematerialize the output if the - // sharding is not compatible. - const int64_t max_potential_output_shape_size = - hlo_sharding_util::IsSubTilingOrEqualSharding( - output_shape, output_sharding, *operand_passthrough_sharding) || - hlo_sharding_util::IsSubTilingOrEqualSharding( - output_shape, *operand_passthrough_sharding, - output_sharding) - ? ShapeSizeInBytes(MakePartitionedShape( - output_shape, *operand_passthrough_sharding)) - : ShapeSizeInBytes(output_shape); - - return std::max(ShapeSizeInBytes(operand.hlo()->shape()) + - ShapeSizeInBytes(MakePartitionedShape( - output_shape, *operand_passthrough_sharding)) + - ShapeSizeInBytes(indices.base_shape()), - max_potential_output_shape_size); + return {0, 0}; } - if (partition_method == PartitionGatherTrivialSlicedOperandDimensions) { - auto trivial_slice_dims = GatherScatterOperandPartitionedOnTrivialSliceDims( - operand, gather->gather_dimension_numbers().start_index_map(), - slice_sizes); - return !trivial_slice_dims ? INT64_MAX - : ShapeSizeInBytes(operand.hlo()->shape()) + - ShapeSizeInBytes(output_shape) + - ShapeSizeInBytes(indices.base_shape()); - } - if (partition_method == PartitionGatherIndexPassthroughDimensions) { - const HloSharding index_passthrough_sharding = hlo_sharding_util:: - GatherOutputShardingFromIndexIndexPassthroughDimensions( - indices.sharding(), gather); - if (index_passthrough_sharding.IsTileMaximal()) { - return INT64_MAX; - } - // Consider the potential cost of having to rematerialize the output if the - // sharding is not compatible. - const int64_t max_potential_output_shape_size = - hlo_sharding_util::IsSubTilingOrEqualSharding( - output_shape, output_sharding, index_passthrough_sharding) || - hlo_sharding_util::IsSubTilingOrEqualSharding( - output_shape, index_passthrough_sharding, output_sharding) - ? ShapeSizeInBytes(MakePartitionedShape(output_shape, - index_passthrough_sharding)) - : ShapeSizeInBytes(output_shape); - return std::max(ShapeSizeInBytes(operand.base_shape()) + - ShapeSizeInBytes(MakePartitionedShape( - output_shape, index_passthrough_sharding)) + - ShapeSizeInBytes(indices.hlo()->shape()), - max_potential_output_shape_size); - } - return INT64_MAX; + return EvaluatePartitionCost(gather, partition_method, gather, operand, + indices, output_shape, output_sharding, + batch_dims, slice_sizes, visitor, + /*allow_recursive=*/false) + .value(); } // Returns a full list of partitioning methods for gather ordered by the @@ -774,24 +750,34 @@ int64_t GatherPartitionMethodCostModel( // TODO(b/245443033): Take recursion of gather/scatter partitioning into // consideration of the cost model. std::vector GatherPartitionMethodsOrderedByCost( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, + const HloGatherInstruction* gather, const PartitionedHlo& operand, + const PartitionedHlo& indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { std::vector ordered_partition_methods; - std::vector ordered_costs; + absl::flat_hash_map> + partition_method_costs; auto gather_partition_methods = GatherPartitionMethods(); - for (auto [partition_method, _] : gather_partition_methods) { - const int64_t cost = GatherPartitionMethodCostModel( + for (auto [partition_method, method_name] : gather_partition_methods) { + auto [memory_cost, communication_cost] = GatherPartitionMethodCostModel( partition_method, gather, operand, indices, output_shape, output_sharding, batch_dims, slice_sizes, visitor); - auto offset = std::distance(ordered_costs.begin(), - absl::c_upper_bound(ordered_costs, cost)); - ordered_costs.insert(ordered_costs.begin() + offset, cost); - ordered_partition_methods.insert(ordered_partition_methods.begin() + offset, - partition_method); + VLOG(5) << method_name << " has memory cost of " << memory_cost << " bytes" + << " and communication cost of " << communication_cost << " bytes"; + partition_method_costs.emplace( + partition_method, std::make_pair(memory_cost, communication_cost)); + ordered_partition_methods.push_back(partition_method); } - CHECK_EQ(ordered_partition_methods.size(), gather_partition_methods.size()); + absl::c_sort(ordered_partition_methods, [&](decltype(PartitionGather)* lhs, + decltype(PartitionGather)* rhs) { + auto [lhs_memory_cost, lhs_communication_cost] = + partition_method_costs[lhs]; + auto [rhs_memory_cost, rhs_communication_cost] = + partition_method_costs[rhs]; + return lhs_memory_cost != rhs_memory_cost + ? lhs_memory_cost < rhs_memory_cost + : lhs_communication_cost < rhs_communication_cost; + }); VLOG(5) << "Gather partitioning methods(ordered by cost):"; for (auto partition_method : ordered_partition_methods) { VLOG(5) << " " @@ -805,21 +791,25 @@ std::vector GatherPartitionMethodsOrderedByCost( return ordered_partition_methods; } -StatusOr PartitionGather( - const HloGatherInstruction* gather, PartitionedHlo& operand, - PartitionedHlo& indices, const Shape& output_shape, +absl::StatusOr PartitionGather( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { - HloInstruction* partitioned_gather; - for (auto partition_method : GatherPartitionMethodsOrderedByCost( - gather, operand, indices, output_shape, output_sharding, batch_dims, - slice_sizes, visitor)) { - TF_ASSIGN_OR_RETURN( - partitioned_gather, - partition_method(gather, operand, indices, output_shape, - output_sharding, batch_dims, slice_sizes, visitor)); - if (partitioned_gather) { - return partitioned_gather; + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + if (allow_recursive) { + HloInstruction* partitioned_gather; + for (auto partition_method : GatherPartitionMethodsOrderedByCost( + gather, operand, indices, output_shape, output_sharding, + batch_dims, slice_sizes, visitor)) { + TF_ASSIGN_OR_RETURN( + partitioned_gather, + partition_method(gather, operand, indices, output_shape, + output_sharding, batch_dims, slice_sizes, visitor, + allow_recursive)); + if (partitioned_gather) { + return partitioned_gather; + } } } HloInstruction* new_gather = @@ -964,7 +954,9 @@ std::vector ScatterIndexDimsByPriority( const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers(); std::vector priority_dims_for_indices; - priority_dims_for_indices.push_back(dnums.index_vector_dim()); + if (dnums.index_vector_dim() < indices.rank()) { + priority_dims_for_indices.push_back(dnums.index_vector_dim()); + } absl::InlinedVector index_passthrough_dims = hlo_sharding_util::GetGatherScatterIndexPassthroughIndexDims( indices.rank(), dnums.index_vector_dim()); @@ -999,21 +991,23 @@ std::vector ScatterUpdateDimsByPriority( return priority_dims_for_output; } -StatusOr PartitionScatter( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatter( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor); + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive = true); // Partition a scatter over a indices dimensions that are cosidered parallel // (which means that the indices access the operand in a monotonically // increasing way across the respective operand dimension referenced by the // index). -StatusOr PartitionScatterIndexParallelDimensions( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatterIndexParallelDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -1124,9 +1118,9 @@ StatusOr PartitionScatterIndexParallelDimensions( hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(), update_parallel_dims); const GroupedSharding& output_grouped = operand_grouped; - absl::InlinedVector per_group_operands = + std::vector per_group_operands = PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); - absl::InlinedVector per_group_updates = + std::vector per_group_updates = PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( new_indices, new_indices_grouped, b, clean_ups); @@ -1135,14 +1129,16 @@ StatusOr PartitionScatterIndexParallelDimensions( TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, PartitionScatter( - scatter, absl::MakeSpan(per_group_operands), - per_group_new_indices, absl::MakeSpan(per_group_updates), pshape, + scatter, per_group_operands, per_group_new_indices, + per_group_updates, pshape, HloSharding::Single(scatter->shape(), output_grouped.sharding), - slice_sizes, visitor)); + slice_sizes, visitor, allow_recursive)); pscatter->set_sharding(HloSharding::Single( pscatter->shape(), hlo_sharding_util::UngroupSharding(output_grouped))); - VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; + } return PartitionedHlo(pscatter, output_shape, operands[0].state()) .Reshard(output_sharding) .hlo(); @@ -1153,11 +1149,12 @@ StatusOr PartitionScatterIndexParallelDimensions( // Perform partitioning of Scatter when the operand is split in a update window // dimension that is passed through (slice size is the same size of the operand // dimension). -StatusOr PartitionScatterOperandPassthroughDimensions( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatterOperandPassthroughDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { if (operands[0].sharding().IsTileMaximal()) { return nullptr; } @@ -1214,9 +1211,9 @@ StatusOr PartitionScatterOperandPassthroughDimensions( ScatterIndexDimsByPriority(indices, scatter)), update_grouped); const GroupedSharding& output_grouped = operand_grouped; - absl::InlinedVector per_group_operands = + std::vector per_group_operands = PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); - absl::InlinedVector per_group_updates = + std::vector per_group_updates = PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); PartitionedHlo per_group_indices = PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups); @@ -1224,14 +1221,16 @@ StatusOr PartitionScatterOperandPassthroughDimensions( TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, PartitionScatter( - scatter, absl::MakeSpan(per_group_operands), per_group_indices, - absl::MakeSpan(per_group_updates), pshape, + scatter, per_group_operands, per_group_indices, per_group_updates, + pshape, HloSharding::Single(scatter->shape(), output_grouped.sharding), - pslice_sizes, visitor)); + pslice_sizes, visitor, allow_recursive)); pscatter->set_sharding(HloSharding::Single( pscatter->shape(), hlo_sharding_util::UngroupSharding(output_grouped))); - VLOG(5) << "[Scatter partitioning]: Partitioned as operand passthrough " - "update_window_dims"; + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as operand passthrough " + "update_window_dims"; + } return PartitionedHlo(pscatter, output_shape, operands[0].state()) .Reshard(output_sharding) .hlo(); @@ -1241,11 +1240,12 @@ StatusOr PartitionScatterOperandPassthroughDimensions( // Perform partitioning of Scatter when the indices are partitioned on the // non-index vector dimension. -StatusOr PartitionScatterIndexPassthroughDimensions( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatterIndexPassthroughDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -1342,9 +1342,8 @@ StatusOr PartitionScatterIndexPassthroughDimensions( per_group_operand.hlo())); PartitionedHlo new_operand = per_group_operand.CloneWithNewHlo(select_operand); - absl::InlinedVector per_group_new_operands = { - new_operand}; - absl::InlinedVector per_group_updates = { + std::vector per_group_new_operands = {new_operand}; + std::vector per_group_updates = { PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)}; PartitionedHlo per_group_indices = PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups); @@ -1352,10 +1351,10 @@ StatusOr PartitionScatterIndexPassthroughDimensions( TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, PartitionScatter( - scatter, absl::MakeSpan(per_group_new_operands), per_group_indices, - absl::MakeSpan(per_group_updates), pshape, + scatter, per_group_new_operands, per_group_indices, + per_group_updates, pshape, HloSharding::Single(scatter->shape(), output_grouped.sharding), - slice_sizes, visitor)); + slice_sizes, visitor, allow_recursive)); // All-reduce along all dims in operand sharding -- this is OK because the // operand is not sharded on index_vector_dim. std::vector all_dims(indices.rank()); @@ -1367,7 +1366,9 @@ StatusOr PartitionScatterIndexPassthroughDimensions( operands[0].state().collective_ops_creator, scatter->to_apply()); all_reduce->set_sharding( hlo_sharding_util::UngroupSharding(output_grouped)); - VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough"; + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough"; + } return PartitionedHlo(all_reduce, output_shape, operands[0].state()) .Reshard(output_sharding) .hlo(); @@ -1375,116 +1376,116 @@ StatusOr PartitionScatterIndexPassthroughDimensions( // Partition a Scatter when its sliced in a dimension in the operand that is // trivially sliced (sliced with slice size of 1). -StatusOr PartitionScatterTrivialSlicedOperandDimensions( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatterTrivialSlicedOperandDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { - // Perform clean up actions upon exiting function scope. - absl::InlinedVector, 3> clean_ups; - absl::Cleanup cleaner = [&clean_ups] { - for (auto& clean_up : clean_ups) { - clean_up(); - } - }; - - SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); - if (std::optional> trivial_slice_dims = - GatherScatterOperandPartitionedOnTrivialSliceDims( - operands[0], dnums.scatter_dims_to_operand_dims(), - slice_sizes)) { - // Operand is sharded on trivial slice dims (update slice size 1). We can - // adjust the indices on each partition by subtracting the offsets. Then - // we execute a scatter on full updated indices, and out-of-bound accesses - // will have no effect on the result as guaranteed by the scatter - // semantics. - const int64_t num_groups = - operands[0].sharding().NumTiles(*trivial_slice_dims); - const int64_t num_tiles = operands[0].sharding().TotalNumTiles(); - const GroupedSharding operand_grouped = - hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(), - *trivial_slice_dims); - // See if we can group partially replicated dimensions from the indices - // otherwise replicate it. - GroupedSharding indices_grouped = AlignGroupsWith( - hlo_sharding_util::GroupShardingOnReplicatedDim( - indices.sharding(), num_groups, num_tiles, indices.rank(), - ScatterIndexDimsByPriority(indices, scatter)), - operand_grouped); - // See if we can group partially replicated dimensions from the updates - // otherwise replicate it. - GroupedSharding update_grouped = AlignGroupsWith( + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + // Perform clean up actions upon exiting function scope. + absl::InlinedVector, 3> clean_ups; + absl::Cleanup cleaner = [&clean_ups] { + for (auto& clean_up : clean_ups) { + clean_up(); + } + }; + + SpmdBuilder* b = visitor->builder(); + auto dnums = scatter->scatter_dimension_numbers(); + if (std::optional> trivial_slice_dims = + GatherScatterOperandPartitionedOnTrivialSliceDims( + operands[0], dnums.scatter_dims_to_operand_dims(), slice_sizes)) { + // Operand is sharded on trivial slice dims (update slice size 1). We can + // adjust the indices on each partition by subtracting the offsets. Then + // we execute a scatter on full updated indices, and out-of-bound accesses + // will have no effect on the result as guaranteed by the scatter + // semantics. + const int64_t num_groups = + operands[0].sharding().NumTiles(*trivial_slice_dims); + const int64_t num_tiles = operands[0].sharding().TotalNumTiles(); + const GroupedSharding operand_grouped = + hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(), + *trivial_slice_dims); + // See if we can group partially replicated dimensions from the indices + // otherwise replicate it. + GroupedSharding indices_grouped = AlignGroupsWith( + hlo_sharding_util::GroupShardingOnReplicatedDim( + indices.sharding(), num_groups, num_tiles, indices.rank(), + ScatterIndexDimsByPriority(indices, scatter)), + operand_grouped); + // See if we can group partially replicated dimensions from the updates + // otherwise replicate it. + GroupedSharding update_grouped = AlignGroupsWith( + hlo_sharding_util::GroupShardingOnReplicatedDim( + updates[0].sharding(), num_groups, num_tiles, updates[0].rank(), + ScatterUpdateDimsByPriority(updates[0].base_shape(), operands[0], + scatter, slice_sizes)), + operand_grouped); + // For index and update sharding, if one is grouped partially but the + // other is replicated, pass through the partially grouped sharding to the + // other one. + if (!indices_grouped.sharding.IsTileMaximal() && + update_grouped.sharding.IsTileMaximal()) { + const HloSharding new_update_sharding = hlo_sharding_util:: + ScatterUpdateShardingFromIndexIndexPassthroughDimensions( + indices.sharding(), scatter); + update_grouped = AlignGroupsWith( hlo_sharding_util::GroupShardingOnReplicatedDim( - updates[0].sharding(), num_groups, num_tiles, updates[0].rank(), + new_update_sharding, num_groups, num_tiles, output_shape.rank(), ScatterUpdateDimsByPriority(updates[0].base_shape(), operands[0], scatter, slice_sizes)), operand_grouped); - // For index and update sharding, if one is grouped partially but the - // other is replicated, pass through the partially grouped sharding to the - // other one. - if (!indices_grouped.sharding.IsTileMaximal() && - update_grouped.sharding.IsTileMaximal()) { - const HloSharding new_update_sharding = hlo_sharding_util:: - ScatterUpdateShardingFromIndexIndexPassthroughDimensions( - indices.sharding(), scatter); - update_grouped = AlignGroupsWith( - hlo_sharding_util::GroupShardingOnReplicatedDim( - new_update_sharding, num_groups, num_tiles, output_shape.rank(), - ScatterUpdateDimsByPriority(updates[0].base_shape(), - operands[0], scatter, slice_sizes)), - operand_grouped); - } - if (indices_grouped.sharding.IsTileMaximal() && - !update_grouped.sharding.IsTileMaximal()) { - const HloSharding new_indices_sharding = hlo_sharding_util:: - ScatterIndexShardingFromUpdateIndexPassthroughDimensions( - updates[0].sharding(), scatter); - indices_grouped = AlignGroupsWith( - hlo_sharding_util::GroupShardingOnReplicatedDim( - new_indices_sharding, num_groups, num_tiles, indices.rank(), - ScatterIndexDimsByPriority(indices, scatter)), - operand_grouped); - } - const GroupedSharding& output_grouped = operand_grouped; - // Reshard indices to its intended sharding before adjusting. - indices = - indices.Reshard(hlo_sharding_util::UngroupSharding(indices_grouped)); - HloInstruction* indices_min; - std::tie(indices_min, std::ignore) = - IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( - operands[0], indices, operands[0].state().partition_id, - dnums.scatter_dims_to_operand_dims(), *trivial_slice_dims, - dnums.index_vector_dim(), b); - auto adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( - indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), - indices_min)); - PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); - absl::InlinedVector per_group_operands = - PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); - absl::InlinedVector per_group_updates = - PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); - PartitionedHlo per_group_new_indices = - PerGroupPartitionedHlo(new_indices, indices_grouped, b, clean_ups); - auto pshape = - MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); - TF_ASSIGN_OR_RETURN( - HloInstruction * pscatter, - PartitionScatter( - scatter, absl::MakeSpan(per_group_operands), - per_group_new_indices, absl::MakeSpan(per_group_updates), pshape, - HloSharding::Single(scatter->shape(), output_grouped.sharding), - slice_sizes, visitor)); - pscatter->set_sharding(HloSharding::Single( - pscatter->shape(), - hlo_sharding_util::UngroupSharding(output_grouped))); - VLOG(5) - << "[Scatter partitioning]: Partitioned as trivially sliced operand"; - return PartitionedHlo(pscatter, output_shape, operands[0].state()) - .Reshard(output_sharding) - .hlo(); } - return nullptr; + if (indices_grouped.sharding.IsTileMaximal() && + !update_grouped.sharding.IsTileMaximal()) { + const HloSharding new_indices_sharding = hlo_sharding_util:: + ScatterIndexShardingFromUpdateIndexPassthroughDimensions( + updates[0].sharding(), scatter); + indices_grouped = AlignGroupsWith( + hlo_sharding_util::GroupShardingOnReplicatedDim( + new_indices_sharding, num_groups, num_tiles, indices.rank(), + ScatterIndexDimsByPriority(indices, scatter)), + operand_grouped); + } + const GroupedSharding& output_grouped = operand_grouped; + // Reshard indices to its intended sharding before adjusting. + indices = + indices.Reshard(hlo_sharding_util::UngroupSharding(indices_grouped)); + HloInstruction* indices_min; + std::tie(indices_min, std::ignore) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operands[0], indices, operands[0].state().partition_id, + dnums.scatter_dims_to_operand_dims(), *trivial_slice_dims, + dnums.index_vector_dim(), b); + auto adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + indices_min)); + PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); + std::vector per_group_operands = + PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); + std::vector per_group_updates = + PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); + PartitionedHlo per_group_new_indices = + PerGroupPartitionedHlo(new_indices, indices_grouped, b, clean_ups); + auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); + TF_ASSIGN_OR_RETURN( + HloInstruction * pscatter, + PartitionScatter( + scatter, per_group_operands, per_group_new_indices, + per_group_updates, pshape, + HloSharding::Single(scatter->shape(), output_grouped.sharding), + slice_sizes, visitor, allow_recursive)); + pscatter->set_sharding(HloSharding::Single( + pscatter->shape(), hlo_sharding_util::UngroupSharding(output_grouped))); + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as trivially sliced " + "operand"; + } + return PartitionedHlo(pscatter, output_shape, operands[0].state()) + .Reshard(output_sharding) + .hlo(); + } + return nullptr; } // Returns a full list of partitioning methods used for scatter. @@ -1500,89 +1501,25 @@ ScatterPartitionMethods() { "PartitionScatterIndexPassthroughDimensions"}}; } -// Estimates the cost for each partitioning methods for scatter. -int64_t ScatterPartitionMethodCostModel( +// Estimates the memory and communication for each partitioning methods for +// scatter. +std::pair ScatterPartitionMethodCostModel( decltype(PartitionScatter)* partition_method, - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, - const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { - if (partition_method == PartitionScatterIndexParallelDimensions) { - // Always prioritize index parallel paritioning, and assume it has zero - // cost. - return 0; - } - if (partition_method == PartitionScatterOperandPassthroughDimensions) { - auto operand_passthrough_sharding = hlo_sharding_util:: - ScatterUpdateShardingFromOutputOperandPassthroughDimensions( - operands[0].base_shape(), operands[0].sharding(), *scatter, - slice_sizes); - if (!operand_passthrough_sharding) { - return INT64_MAX; - } - // Consider the possibility of having to fully rematerialize the update - // if the sharding is incompatible. - const int64_t max_potential_updates_shape_size = - absl::c_all_of( - updates, - [&operand_passthrough_sharding](const PartitionedHlo& phlo) { - return hlo_sharding_util::IsSubTilingOrEqualSharding( - phlo.base_shape(), phlo.sharding(), - *operand_passthrough_sharding) || - hlo_sharding_util::IsSubTilingOrEqualSharding( - phlo.base_shape(), *operand_passthrough_sharding, - phlo.sharding()); - }) - ? BaseShapeSizeSum(updates, *operand_passthrough_sharding) - : BaseShapeSizeSum(updates); - - return std::max( - ShapeSizeSum(operands) + - BaseShapeSizeSum(updates, *operand_passthrough_sharding) + - ShapeSizeInBytes(indices.base_shape()), - max_potential_updates_shape_size); - } - if (partition_method == PartitionScatterTrivialSlicedOperandDimensions) { - auto trivial_slice_dims = - GatherScatterOperandPartitionedOnTrivialSliceDims( - operands[0], - scatter->scatter_dimension_numbers() - .scatter_dims_to_operand_dims(), - slice_sizes); - return !trivial_slice_dims - ? INT64_MAX - : ShapeSizeSum(operands) + BaseShapeSizeSum(updates) + - ShapeSizeInBytes(indices.base_shape()); - } - if (partition_method == PartitionScatterIndexPassthroughDimensions) { - const HloSharding index_passthrough_sharding = hlo_sharding_util:: - ScatterUpdateShardingFromIndexIndexPassthroughDimensions( - indices.sharding(), scatter); - if (index_passthrough_sharding.IsTileMaximal()) { - return INT64_MAX; - } - // Consider the possibility of having to fully rematerialize the update - // if the sharding is incompatible. - const int64_t max_potential_updates_shape_size = - absl::c_all_of( - updates, - [&index_passthrough_sharding](const PartitionedHlo& phlo) { - return hlo_sharding_util::IsSubTilingOrEqualSharding( - phlo.base_shape(), phlo.sharding(), - index_passthrough_sharding) || - hlo_sharding_util::IsSubTilingOrEqualSharding( - phlo.base_shape(), index_passthrough_sharding, - phlo.sharding()); - }) - ? BaseShapeSizeSum(updates, index_passthrough_sharding) - : BaseShapeSizeSum(updates); - return std::max( - BaseShapeSizeSum(operands) + - BaseShapeSizeSum(updates, index_passthrough_sharding) + - ShapeSizeInBytes(indices.hlo()->shape()), - max_potential_updates_shape_size); - } - return INT64_MAX; + const HloScatterInstruction* scatter, + const std::vector& operands, const PartitionedHlo& indices, + const std::vector& updates, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span slice_sizes, + SpmdPartitioningVisitor* visitor) { + if (partition_method == PartitionScatterIndexParallelDimensions) { + // Always prioritize index parallel partitioning, and assume it has zero + // cost. + return {0, 0}; + } + return EvaluatePartitionCost(scatter, partition_method, scatter, operands, + indices, updates, output_shape, output_sharding, + slice_sizes, visitor, + /*allow_recursive=*/false) + .value(); } // Returns a full list of partitioning methods for scatter ordered by the @@ -1590,63 +1527,76 @@ int64_t ScatterPartitionMethodCostModel( // TODO(b/245443033): Take recursion of gather/scatter partitioning into // consideration of the cost model. std::vector ScatterPartitionMethodsOrderedByCost( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, - const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { - std::vector ordered_partition_methods; - std::vector ordered_costs; - auto scatter_partition_methods = ScatterPartitionMethods(); - for (auto [partition_method, _] : scatter_partition_methods) { - const int64_t cost = ScatterPartitionMethodCostModel( - partition_method, scatter, operands, indices, updates, output_shape, - output_sharding, slice_sizes, visitor); - auto offset = std::distance(ordered_costs.begin(), - absl::c_upper_bound(ordered_costs, cost)); - ordered_costs.insert(ordered_costs.begin() + offset, cost); - ordered_partition_methods.insert( - ordered_partition_methods.begin() + offset, partition_method); - } - CHECK_EQ(ordered_partition_methods.size(), - scatter_partition_methods.size()); - VLOG(5) << "Scatter partitioning methods(ordered by cost):"; - for (auto partition_method : ordered_partition_methods) { - VLOG(5) << " " - << absl::c_find_if( - scatter_partition_methods, - [&](const std::pair& p) { - return p.first == partition_method; - }) - ->second; - } - return ordered_partition_methods; + const HloScatterInstruction* scatter, + const std::vector& operands, const PartitionedHlo& indices, + const std::vector& updates, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span slice_sizes, + SpmdPartitioningVisitor* visitor) { + std::vector ordered_partition_methods; + absl::flat_hash_map> + partition_method_costs; + auto scatter_partition_methods = ScatterPartitionMethods(); + for (auto [partition_method, method_name] : scatter_partition_methods) { + auto [memory_cost, communication_cost] = ScatterPartitionMethodCostModel( + partition_method, scatter, operands, indices, updates, output_shape, + output_sharding, slice_sizes, visitor); + + VLOG(5) << method_name << " has memory cost of " << memory_cost + << " bytes and communication cost of " << communication_cost + << " bytes"; + partition_method_costs.emplace( + partition_method, std::make_pair(memory_cost, communication_cost)); + ordered_partition_methods.push_back(partition_method); + } + absl::c_sort(ordered_partition_methods, [&](decltype(PartitionScatter)* lhs, + decltype(PartitionScatter)* rhs) { + auto [lhs_memory_cost, lhs_communication_cost] = + partition_method_costs[lhs]; + auto [rhs_memory_cost, rhs_communication_cost] = + partition_method_costs[rhs]; + return lhs_memory_cost != rhs_memory_cost + ? lhs_memory_cost < rhs_memory_cost + : lhs_communication_cost < rhs_communication_cost; + }); + VLOG(5) << "Scatter partitioning methods(ordered by cost):"; + for (auto partition_method : ordered_partition_methods) { + VLOG(5) << " " + << absl::c_find_if(scatter_partition_methods, + [&](const std::pair& p) { + return p.first == partition_method; + }) + ->second; + } + return ordered_partition_methods; } -StatusOr PartitionScatter( - const HloScatterInstruction* scatter, absl::Span operands, - PartitionedHlo& indices, absl::Span updates, +absl::StatusOr PartitionScatter( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, - absl::Span slice_sizes, SpmdPartitioningVisitor* visitor) { + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { HloInstruction* partitioned_scatter; - for (auto partition_method : ScatterPartitionMethodsOrderedByCost( - scatter, operands, indices, updates, output_shape, output_sharding, - slice_sizes, visitor)) { + if (allow_recursive) { + for (auto partition_method : ScatterPartitionMethodsOrderedByCost( + scatter, operands, indices, updates, output_shape, output_sharding, + slice_sizes, visitor)) { TF_ASSIGN_OR_RETURN( partitioned_scatter, partition_method(scatter, operands, indices, updates, output_shape, - output_sharding, slice_sizes, visitor)); + output_sharding, slice_sizes, visitor, + allow_recursive)); if (partitioned_scatter) { return partitioned_scatter; } + } } - absl::InlinedVector operand_hlos, update_hlos; - absl::c_transform( - operands, std::back_inserter(operand_hlos), - [](PartitionedHlo& phlo) { return phlo.Replicate().hlo(); }); - absl::c_transform( - updates, std::back_inserter(update_hlos), - [](PartitionedHlo& phlo) { return phlo.Replicate().hlo(); }); + std::vector operand_hlos, update_hlos; + absl::c_transform(operands, std::back_inserter(operand_hlos), + [](PartitionedHlo phlo) { return phlo.Replicate().hlo(); }); + absl::c_transform(updates, std::back_inserter(update_hlos), + [](PartitionedHlo phlo) { return phlo.Replicate().hlo(); }); HloInstruction* new_scatter = visitor->builder()->AddInstruction(HloInstruction::CreateScatter( MaybeMakeTupleShape(operand_hlos), operand_hlos, @@ -1673,7 +1623,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { // Check all operands have the same shapes and shardings, and all updates have // the same shapes and shardings, and live with this assumption during scatter // partitioning. - absl::InlinedVector operands, updates; + std::vector operands, updates; absl::c_transform( scatter->scatter_operands(), std::back_inserter(operands), [this](HloInstruction* hlo) { return GetPartitionedHlo(hlo); }); @@ -1738,8 +1688,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, - PartitionScatter(scatter, absl::MakeSpan(operands), indices, - absl::MakeSpan(updates), scatter->shape(), + PartitionScatter(scatter, operands, indices, updates, scatter->shape(), scatter->sharding(), slice_sizes, this)); if (!pscatter) { return DefaultAction(hlo); diff --git a/xla/service/spmd/partition_assignment.cc b/xla/service/spmd/partition_assignment.cc index ceee9db79d5bb..ce28cc2aae0fa 100644 --- a/xla/service/spmd/partition_assignment.cc +++ b/xla/service/spmd/partition_assignment.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ NoopPartitioning::NoopPartitioning(int64_t num_partitions) << num_partitions; } -StatusOr NoopPartitioning::Run(HloModule* module) const { +absl::StatusOr NoopPartitioning::Run(HloModule* module) const { VLOG(2) << "No-op algorithm was called to partition module: " << module->name(); return false; @@ -86,7 +86,7 @@ PartitionAssignment::ChoosePartitioningAlgorithm( return PartitioningAlgorithm::CreateNoopPartitioning(num_partitions()); } -StatusOr PartitionAssignment::Run( +absl::StatusOr PartitionAssignment::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Running partition assignment on module " << module->name(); diff --git a/xla/service/spmd/partition_assignment.h b/xla/service/spmd/partition_assignment.h index 97a71eb1607cf..63a8fb21b7221 100644 --- a/xla/service/spmd/partition_assignment.h +++ b/xla/service/spmd/partition_assignment.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ class PartitioningAlgorithm { int64_t num_partitions() const; // Assigns shardings to the given module. - virtual StatusOr Run(HloModule* module) const = 0; + virtual absl::StatusOr Run(HloModule* module) const = 0; protected: // Internal constructor for a given algorithm kind. Other fields must be @@ -78,7 +78,7 @@ class NoopPartitioning : public PartitioningAlgorithm { explicit NoopPartitioning(int64_t num_partitions); // Assigns shardings to the given module. - StatusOr Run(HloModule* module) const override; + absl::StatusOr Run(HloModule* module) const override; }; // PartitionAssignment assigns sharding annotations to some HLOs in the given @@ -100,7 +100,7 @@ class PartitionAssignment : public HloModulePass { // Runs the pass. using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/spmd/partition_assignment_test.cc b/xla/service/spmd/partition_assignment_test.cc index 94041d62d3b3c..a6770c283a342 100644 --- a/xla/service/spmd/partition_assignment_test.cc +++ b/xla/service/spmd/partition_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/spmd/schedule_aware_collective_ops_cse.cc b/xla/service/spmd/schedule_aware_collective_ops_cse.cc index 8d2c96c478709..eabdbe2c33d63 100644 --- a/xla/service/spmd/schedule_aware_collective_ops_cse.cc +++ b/xla/service/spmd/schedule_aware_collective_ops_cse.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -85,8 +85,8 @@ HloInstruction* MayConsiderCollective(HloInstruction* hlo, bool for_replicas) { return nullptr; } -StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, - int64_t distance_threshold) { +absl::StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, + int64_t distance_threshold) { // We consider estimate the live ranges of all-gathers by comparing their // users' distance to the root, e.g., height. bool changed = false; @@ -156,7 +156,7 @@ StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, } // namespace -StatusOr ScheduleAwareCollectiveOpsCSE::Run( +absl::StatusOr ScheduleAwareCollectiveOpsCSE::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/spmd/schedule_aware_collective_ops_cse.h b/xla/service/spmd/schedule_aware_collective_ops_cse.h index 86a0779702af3..45cf248c845df 100644 --- a/xla/service/spmd/schedule_aware_collective_ops_cse.h +++ b/xla/service/spmd/schedule_aware_collective_ops_cse.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class ScheduleAwareCollectiveOpsCSE : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc b/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc index 63bc668aab3e8..2edbaf273ff7e 100644 --- a/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc +++ b/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ namespace { class CollectiveOpsCseTest : public HloTestBase { public: - StatusOr> RunPass( + absl::StatusOr> RunPass( absl::string_view hlo_module, int64_t distance_threshold = 100) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( hlo_module, GetModuleConfigForTest())); @@ -38,7 +38,7 @@ class CollectiveOpsCseTest : public HloTestBase { pipeline.AddPass(distance_threshold, /*for_replicas=*/false); TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } }; diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 068e94eac10fa..f812cceee299d 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,10 @@ limitations under the License. #include "xla/service/spmd/spmd_partitioner.h" #include +#include #include #include +#include #include #include #include @@ -27,13 +29,16 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/array.h" #include "xla/comparison_util.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -42,20 +47,33 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/protobuf_util.h" +#include "xla/service/call_graph.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_layout.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/shape_inference.h" #include "xla/service/spmd/custom_call_handler.h" #include "xla/service/spmd/spmd_partitioner_util.h" #include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -381,7 +399,7 @@ HloInstruction* SpmdBuilder::AddInstruction( } PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target, - std::optional pad_value) { + std::optional pad_value) const { if (sharding() == target) { return *this; } @@ -426,9 +444,9 @@ PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target, return resharded; } -PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target, - std::optional pad_value, - bool allow_full_replication) { +PartitionedHlo PartitionedHlo::ReshardNoCache( + const HloSharding& target, std::optional pad_value, + bool allow_full_replication) const { VLOG(2) << "Resharding " << hlo_->ToString() << " from " << hlo_->sharding().ToString() << " to " << target.ToString(); const Shape& shape = hlo_->shape(); @@ -1197,7 +1215,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, get_dynamic_slice_offset_on_output_if_needed()}); } -PartitionedHlo PartitionedHlo::Replicate() { +PartitionedHlo PartitionedHlo::Replicate() const { auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; if (state_.partitioner->options().cache_all_gather) { for (auto& entry : cache) { @@ -1245,7 +1263,7 @@ PartitionedHlo PartitionedHlo::Replicate() { } HloInstruction* PartitionedHlo::ReplicatePartial( - absl::Span dims) { + absl::Span dims) const { CHECK(!sharding().IsTileMaximal()); const Shape& shard_shape = hlo()->shape(); Shape target_shape = shard_shape; @@ -1373,7 +1391,7 @@ HloInstruction* PartitionedHlo::ReplicatePartial( std::optional PartitionedHlo::ReshardToPartialReplicateWithAllGather( - const HloSharding& target) { + const HloSharding& target) const { if (!target.ReplicateOnLastTileDim()) { return std::nullopt; } @@ -1439,7 +1457,7 @@ PartitionedHlo::ReshardToPartialReplicateWithAllGather( std::optional PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( - const HloSharding& target) { + const HloSharding& target) const { if (!sharding().ReplicateOnLastTileDim()) { return std::nullopt; } @@ -1766,19 +1784,21 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( namespace { -// Matching a pattern like [..,X,..,Y] -> [..,X*Y,..,1] or [..,X,..,Y] -> -// [..,1,..,X*Y]. +// Matching the following patterns, where X, Y, cannot be 1, Z can be 1. +// 1. [..,X,..,Y,..] -> [..,X*Y,..,1,..] +// 2. [..,Y,..,X,..] -> [..,1,..,X*Y,..] +// 3. [..,X*Y,..,Z,..] -> [..,X,..,Y*Z,..] +// 4. [..,Z,..,X*Y,..] -> [..,Y*Z,..,X,..] // Output tuple: -// - HloSharding: The original sharding with an extra dimension added of size 1. -// - HloSharding: The sharding with the dimension we want to merge moved in -// place of the dimension of size 1 we added. -// - int: Dimension in the input that is going to be merged with another -// dimension (becoming bigger). -// - int: Dimension in the input that is going to be merged into another -// dimension (becoming 1). -std::optional> -PatternMatchMergeSharding(const Shape& shape, const HloSharding& source, - const HloSharding& target) { +// - HloSharding: The original sharding with an extra dimension added of size 1 +// or Y. +// - HloSharding: The sharding with the new dimension added moved in the place +// where we expect the target dimension to be. +// - int64_t: The index of X. +std::optional> +PatternMatchMergeOrSplitSharding(const Shape& shape, const Shape& base_shape, + const HloSharding& source, + const HloSharding& target) { if (!source.IsTiled() || !target.IsTiled()) { return std::nullopt; } @@ -1791,170 +1811,103 @@ PatternMatchMergeSharding(const Shape& shape, const HloSharding& source, target.tile_assignment().dimensions()[target.TiledDataRank()])) { return std::nullopt; } - for (int i = 0; i < target.TiledDataRank(); ++i) { - if (source.tile_assignment().dim(i) < target.tile_assignment().dim(i) && - (target.tile_assignment().dim(i) % source.tile_assignment().dim(i)) == - 0) { - auto get_reshaped_sharding = - [&](int64_t target_idx) -> std::optional { - if (target.tile_assignment().dim(target_idx) != 1) { - return std::nullopt; - } - if (target.tile_assignment().dim(i) != - source.tile_assignment().dim(i) * - source.tile_assignment().dim(target_idx)) { - return std::nullopt; - } - if (shape.dimensions(i) % source.tile_assignment().dim(target_idx) != - 0) { - return std::nullopt; - } - return hlo_sharding_util::SplitShardingDimension( - source, i, source.tile_assignment().dim(i)); - }; - for (int j = i - 1; j >= 0; --j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Merge From Left"; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); - } - } - for (int j = i + 1; j < target.TiledDataRank(); ++j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Merge From Right"; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j + 1]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); - } - } - } - } - return std::nullopt; -} -// Matching a pattern like [..,X*Y,..,1] -> [..,X,..,Y] or [..,1,..,X*Y] -> -// [..,X,..,Y]. -// Output tuple: -// - HloSharding: The original sharding with an extra dimension added of size Y. -// - HloSharding: The sharding with the new dimension added moved in the place -// where we expect the target dimension to be. -// - int: Dimension in the input that is going to be unmerged (getting split). -// - int: Dimension in the input that is going to be the destination of the -// unmerged dimension. -std::optional> -PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, - const HloSharding& source, - const HloSharding& target) { - if (!source.IsTiled() || !target.IsTiled()) { - return std::nullopt; - } - if (source.TiledDataRank() != target.TiledDataRank()) { - return std::nullopt; + std::vector diff_index; + for (int64_t i = 0; i < target.TiledDataRank(); ++i) { + if (source.tile_assignment().dim(i) != target.tile_assignment().dim(i)) { + diff_index.push_back(i); + } } - if ((source.HasPartialReplication() ^ target.HasPartialReplication()) || - (source.HasPartialReplication() && - source.tile_assignment().dimensions()[source.TiledDataRank()] != - target.tile_assignment().dimensions()[target.TiledDataRank()])) { + if (diff_index.size() < 2) { return std::nullopt; } - for (int i = 0; i < target.TiledDataRank(); ++i) { - if (source.tile_assignment().dim(i) > target.tile_assignment().dim(i) && - target.tile_assignment().dim(i) != 1 && - base_shape.dimensions(i) % source.tile_assignment().dim(i) == 0 && - source.tile_assignment().dim(i) % target.tile_assignment().dim(i) == - 0) { - auto get_reshaped_sharding = - [&](int64_t target_dim) -> std::optional { - if (source.tile_assignment().dim(target_dim) == - target.tile_assignment().dim(target_dim) || - source.tile_assignment().dim(i) != - target.tile_assignment().dim(i) * - target.tile_assignment().dim(target_dim)) { - VLOG(10) << "Skipped for target dim different from dimension_size " - << target_dim - << " src size: " << source.tile_assignment().dim(i) - << " target size: " - << target.tile_assignment().dim(target_dim); - return std::nullopt; - } - return hlo_sharding_util::SplitShardingDimension( - source, i, target.tile_assignment().dim(i)); - }; - for (int j = i - 1; j >= 0; --j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Right i = " << i << ",j = " << j; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); + + // Iterate every pair of elements in diff_index. + for (int64_t diff_index_i = 0; diff_index_i < diff_index.size(); + ++diff_index_i) { + for (int64_t diff_index_j = diff_index_i + 1; + diff_index_j < diff_index.size(); ++diff_index_j) { + int64_t i = diff_index[diff_index_i]; + int64_t j = diff_index[diff_index_j]; + const std::vector is_one = {source.tile_assignment().dim(i) == 1, + source.tile_assignment().dim(j) == 1, + target.tile_assignment().dim(i) == 1, + target.tile_assignment().dim(j) == 1}; + int64_t new_dim_size; + switch (std::count(is_one.begin(), is_one.end(), true)) { + case 1: { + if (source.tile_assignment().dim(i) * + source.tile_assignment().dim(j) != + target.tile_assignment().dim(i) * + target.tile_assignment().dim(j)) { + continue; + } + if (source.tile_assignment().dim(i) == 1 || + target.tile_assignment().dim(i) == 1) { + std::swap(i, j); + // After the swap, we always have the following. + // i is the dimension without size 1 in either source or target + // j is the dimension with size 1 in either source or target + } + if (target.tile_assignment().dim(j) == 1) { + // dim of size 1 is in the target + if (shape.dimensions(i) % source.tile_assignment().dim(j) != 0) { + continue; + } + new_dim_size = source.tile_assignment().dim(i); + } else { + // dim of size 1 is in the source + if (base_shape.dimensions(i) % source.tile_assignment().dim(i) != + 0) { + continue; + } + new_dim_size = target.tile_assignment().dim(i); + } + break; } - } - for (int j = i + 1; j < target.TiledDataRank(); ++j) { - if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Left i = " << i << ",j = " << j; - std::vector dimensions( - reshaped_sharding->tile_assignment().dimensions().begin(), - reshaped_sharding->tile_assignment().dimensions().end()); - std::swap(dimensions[i + 1], dimensions[j + 1]); - auto target_tile_assignment = - target.tile_assignment().Reshape(dimensions); - auto new_sharding = - source.HasPartialReplication() - ? HloSharding::PartialTile(target_tile_assignment, - source.metadata()) - : HloSharding::Tile(target_tile_assignment, - source.metadata()); - VLOG(10) << "Reshaped sharding before: " - << reshaped_sharding->ToString(); - VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); - return std::make_tuple(std::move(*reshaped_sharding), - std::move(new_sharding), i, j); + case 0: { + if (source.tile_assignment().dim(i) < + target.tile_assignment().dim(i)) { + std::swap(i, j); + // After the swap, we always have the following. + // source.tile_assignment().dim(i) > target.tile_assignment().dim(i) + // source.tile_assignment().dim(j) < target.tile_assignment().dim(j) + } + if (source.tile_assignment().dim(i) != + target.tile_assignment().dim(i) * + target.tile_assignment().dim(j)) { + continue; + } + if (base_shape.dimensions(i) % source.tile_assignment().dim(i) != 0) { + continue; + } + new_dim_size = target.tile_assignment().dim(i); + break; } + default: + continue; } + + auto reshaped_sharding = + hlo_sharding_util::SplitShardingDimension(source, i, new_dim_size); + std::vector dimensions( + reshaped_sharding.tile_assignment().dimensions().begin(), + reshaped_sharding.tile_assignment().dimensions().end()); + std::swap(dimensions[i + 1], dimensions[j + (j > i ? 1 : 0)]); + auto target_tile_assignment = + target.tile_assignment().Reshape(dimensions); + auto new_sharding = + source.HasPartialReplication() + ? HloSharding::PartialTile(target_tile_assignment, + source.metadata()) + : HloSharding::Tile(target_tile_assignment, source.metadata()); + VLOG(10) << "Reshaped sharding before: " << reshaped_sharding.ToString(); + VLOG(10) << "Reshaped sharding: " << new_sharding.ToString(); + return std::make_tuple(std::move(reshaped_sharding), + std::move(new_sharding), i); } } + return std::nullopt; } @@ -2040,16 +1993,15 @@ PartitionedHlo MergeReshapeHelper(const PartitionedHlo& to_reshape, } // namespace std::optional PartitionedHlo::TryComplexReshardHandling( - const HloSharding& target) { + const HloSharding& target) const { VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString() << " to " << target.ToString(); const bool is_source_partially_replicated = sharding().ReplicateOnLastTileDim(); const bool is_target_partially_replicated = target.ReplicateOnLastTileDim(); - if (auto reshape = - PatternMatchMergeSharding(this->hlo()->shape(), sharding(), target)) { - auto& [before_sharding, new_reshaped_sharding, source_dim, target_dim] = - *reshape; + if (auto reshape = PatternMatchMergeOrSplitSharding( + this->hlo()->shape(), this->base_shape(), sharding(), target)) { + auto& [before_sharding, new_reshaped_sharding, source_dim] = *reshape; VLOG(10) << "Matched \"pattern_match_reshape()\": " << std::get<0>(*reshape).ToString(); VLOG(10) << "Original shape: " << hlo()->shape().ToString(); @@ -2077,39 +2029,6 @@ std::optional PartitionedHlo::TryComplexReshardHandling( } return reshaped; } - if (auto reshape = PatternMatchUnmergeSharding( - this->hlo()->shape(), this->base_shape(), sharding(), target)) { - auto& [before_sharding, new_reshaped_sharding, source_dim, target_dim] = - *reshape; - VLOG(10) << "Matched \"unmerge_sharding()\": " - << new_reshaped_sharding.ToString(); - VLOG(10) << "Original shape: " << hlo()->shape().ToString(); - VLOG(10) << "Base shape: " << base_shape().ToString(); - PartitionedHlo reshaped = SplitReshapeHelper( - *this, source_dim, this->hlo()->shape().dimensions(source_dim), - before_sharding); - VLOG(10) << "Reshaped shape: " << reshaped.hlo()->shape().ToString(); - VLOG(10) << "Reshaped base_shape: " << reshaped.base_shape().ToString(); - VLOG(10) << "Before sharding: " << before_sharding.ToString(); - VLOG(10) << "Reshaped: " << reshaped.hlo()->ToString(); - auto reshard = reshaped.ReshardNoCache(new_reshaped_sharding, - /*pad_value=*/std::nullopt, - /*allow_full_replication=*/false); - if (reshard.sharding() != new_reshaped_sharding) { - return std::nullopt; - } - auto reshaped_sharding = hlo_sharding_util::MergeShardingDimension( - reshard.sharding(), source_dim); - reshaped = MergeReshapeHelper(reshard, source_dim, reshaped_sharding); - if (reshaped.sharding() != target) { - reshaped = reshaped.ReshardNoCache(target, /*pad_value=*/std::nullopt, - /*allow_full_replication=*/false); - if (reshaped.sharding() != target) { - return std::nullopt; - } - } - return reshaped; - } if (auto intermediate_target = PatternMatchPartiallyReplicateDim(sharding(), target)) { VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": " @@ -2163,7 +2082,8 @@ std::optional PartitionedHlo::TryComplexReshardHandling( } std::optional -PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) { +PartitionedHlo::ReshardPartialReplicateWithAllToAll( + const HloSharding& target) const { bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim(); const auto& partial_replicate_sharding = source_is_partial_replicate ? sharding() : target; @@ -2308,6 +2228,27 @@ SpmdPartitioningVisitor::SpmdPartitioningVisitor( partitioner_(partitioner), call_graph_(call_graph) {} +SpmdPartitioningVisitor::SpmdPartitioningVisitor( + const SpmdPartitioningVisitor& src) + : changed_(src.changed_), + module_(src.module_), + num_partitions_(src.num_partitions_), + num_replicas_(src.num_replicas_), + collective_ops_creator_(src.collective_ops_creator_), + next_channel_id_(src.next_channel_id_), + b_(absl::StrCat(module_->entry_computation()->name(), "_spmd"), + /*hlo=*/nullptr), + partition_id_(collective_ops_creator_.create_partition_id(&b_)), + logger_(src.logger_), + options_(src.options_), + partitioner_(src.partitioner_), + call_graph_(src.call_graph_) {} + +std::unique_ptr SpmdPartitioningVisitor::Clone() + const { + return std::make_unique(*this); +} + PartitionedHlo::PartitioningState SpmdPartitioningVisitor::MakePartitioningState() { PartitionedHlo::PartitioningState state; @@ -2465,7 +2406,7 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { auto get_grouped_sharding = [&](const HloSharding& sharding, const Shape& shape, const GroupedSharding* ref = - nullptr) -> StatusOr { + nullptr) -> absl::StatusOr { if (!sharding.IsTuple()) { GroupedSharding grouped = hlo_sharding_util::GetManualSubgroupSharding(sharding); @@ -2886,7 +2827,8 @@ Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { int64_t first_nonsort_nonsharded_dim = -1; auto nshards = tile_assignment_dims[sort_dim]; for (int64_t dim = 0; dim < subshape.rank(); ++dim) { - if (dim == sort_dim || tile_assignment_dims[dim] != 1) { + if (dim == sort_dim || tile_assignment_dims[dim] != 1 || + subshape.dimensions(dim) == 1) { continue; } if (first_nonsort_nonsharded_dim == -1) { @@ -2901,40 +2843,59 @@ Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { if (picked_dim == -1) { picked_dim = first_nonsort_nonsharded_dim; } - VLOG(2) - << "Sort partitioning - picked target dimension to move the sharding: " - << picked_dim; - // The sharding cannot exist in the sort dimension if there are no free - // dimensions to move the sharding into. In other words, we propagated the - // operand sharding which is on the sort dimension only because we knew we - // could pick a free dimension to move it into now. - CHECK_NE(picked_dim, -1) - << "Sort partitioning - sharding cannot exist in the sort dimension if " - "there are no free dimensions to move it into"; - // Move the sharding to the picked dimension - std::vector permutation( - cur_sharding.tile_assignment().dimensions().begin(), - cur_sharding.tile_assignment().dimensions().end()); - absl::c_iota(permutation, 0); - std::swap(permutation[sort_dim], permutation[picked_dim]); - auto new_sharding = - hlo_sharding_util::TransposeSharding(cur_sharding, permutation); - VLOG(2) << "Sort partitioning - new sharding: " << new_sharding.ToString(); std::vector new_operands; std::vector new_shardings; - for (auto& operand : hlo->operands()) { - new_operands.push_back( - GetPartitionedHlo(operand).Reshard(new_sharding).hlo()); - new_shardings.push_back(new_sharding); - } - auto new_output_sharding = new_sharding; - if (sharding.IsTuple()) { - new_output_sharding = HloSharding::Tuple(sort->shape(), new_shardings); + std::optional new_output_sharding; + if (picked_dim != -1) { + VLOG(2) << "Sort partitioning - picked target dimension to move the " + "sharding: " + << picked_dim; + // The sharding cannot exist in the sort dimension if there are no free + // dimensions to move the sharding into. In other words, we propagated the + // operand sharding which is on the sort dimension only because we knew we + // could pick a free dimension to move it into now. + CHECK_NE(picked_dim, -1) + << "Sort partitioning - sharding cannot exist in the sort dimension " + "if " + "there are no free dimensions to move it into"; + // Move the sharding to the picked dimension + std::vector permutation( + cur_sharding.tile_assignment().dimensions().begin(), + cur_sharding.tile_assignment().dimensions().end()); + absl::c_iota(permutation, 0); + std::swap(permutation[sort_dim], permutation[picked_dim]); + auto new_sharding = + hlo_sharding_util::TransposeSharding(cur_sharding, permutation); + VLOG(2) << "Sort partitioning - new sharding: " + << new_sharding.ToString(); + for (auto& operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(new_sharding).hlo()); + new_shardings.push_back(new_sharding); + } + new_output_sharding = new_sharding; + if (sharding.IsTuple()) { + new_output_sharding = HloSharding::Tuple(sort->shape(), new_shardings); + } + } else { + // AllGather the sort dim. + auto new_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(cur_sharding, + {sort_dim}); + for (auto& operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(new_sharding).hlo()); + new_shardings.push_back(new_sharding); + } + new_output_sharding = new_sharding; + if (sharding.IsTuple()) { + new_output_sharding = HloSharding::Tuple(sort->shape(), new_shardings); + } } auto final_sort = b_.AddInstruction(hlo->CloneWithNewOperands( - MakePartitionedShape(sort->shape(), new_output_sharding), + MakePartitionedShape(sort->shape(), *new_output_sharding), new_operands)); - final_sort->set_sharding(new_output_sharding); + final_sort->set_sharding(*new_output_sharding); PartitionedHlo psort(final_sort, sort->shape(), MakePartitioningState()); SetPartitionedHlo(sort, psort.Reshard(sort->sharding())); return OkStatus(); @@ -3035,7 +2996,7 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { auto shard_reshape = [](PartitionedHlo& operand, const HloSharding& sharding, - const Shape& base_shape) -> StatusOr { + const Shape& base_shape) -> absl::StatusOr { auto replicate = [&] { HloInstruction* rep = operand.Replicate().hlo(); HloInstruction* reshape = operand.state().b->AddInstruction( @@ -3196,11 +3157,11 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { // Try to use PropagateShardingThroughReshape to find compatible dimensions, // then group them and recursively partition other dimensions. - std::function(PartitionedHlo&, const HloSharding&, - const Shape&)> + std::function( + PartitionedHlo&, const HloSharding&, const Shape&)> recursive_shard = [&](PartitionedHlo& operand, const HloSharding& sharding, - const Shape& base_shape) -> StatusOr { + const Shape& base_shape) -> absl::StatusOr { const Shape& operand_base_shape = operand.base_shape(); HloSharding propagated = hlo_sharding_util::PropagateShardingThroughReshape( operand_base_shape, base_shape, operand.sharding()); @@ -3383,7 +3344,8 @@ Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { TF_RET_CHECK(ar->use_global_device_ids()) << "Cross-partition allreduce in partial manual partitioning mode must " "use global device IDs."; - absl::flat_hash_map partition_to_group_id; + std::vector partition_to_group_id( + hlo->sharding().tile_assignment().num_elements()); hlo->sharding().tile_assignment().Each( [&](absl::Span indices, int64_t partition) { int64_t group_id = 0; @@ -4678,7 +4640,7 @@ Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { return OkStatus(); } -StatusOr SpmdPartitioningVisitor::DoPartition( +absl::StatusOr SpmdPartitioningVisitor::DoPartition( HloComputation* computation, const HloSharding& root_sharding, const SpmdPartitionerOptions& options) { VLOG(2) << "Partitioning computation " << computation->name() << " for " @@ -4723,34 +4685,28 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, const std::vector>& partition_subgroups, int64_t channel_id) { - if (partition_subgroups.size() <= 1) { - std::vector groups(num_replicas); - // TODO(yuanzx): Unify subgroup definition with AllToAll. - for (int64_t i = 0; i < num_replicas; ++i) { - groups[i].add_replica_ids(i); - } - HloComputation* reduction_clone = - reduction->parent()->AddComputationAndUnifyNamesAndIds( - reduction->Clone(), false); - HloInstruction* all_reduce = - b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction_clone, groups, - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); - reduction_clone->SetCollectiveCallInstruction(all_reduce); - return all_reduce; - } - std::vector device_groups; - device_groups.reserve(partition_subgroups.size() * num_replicas); - for (int64_t i = 0; i < num_replicas; ++i) { - for (const auto& pgroup : partition_subgroups) { + if (partition_subgroups.size() <= 1) { + device_groups.reserve(num_replicas); + for (int64_t rid = 0; rid < num_replicas; ++rid) { device_groups.emplace_back(); - for (int64_t pid : pgroup) { - device_groups.back().add_replica_ids(i * num_partitions + pid); + for (int64_t pid = 0; pid < num_partitions; ++pid) { + device_groups.back().add_replica_ids(rid * num_partitions + pid); + } + } + } else { + device_groups.reserve(partition_subgroups.size() * num_replicas); + for (int64_t rid = 0; rid < num_replicas; ++rid) { + for (const auto& pgroup : partition_subgroups) { + device_groups.emplace_back(); + for (int64_t pid : pgroup) { + device_groups.back().add_replica_ids(rid * num_partitions + + pid); + } } } } + HloComputation* reduction_clone = reduction->parent()->AddComputationAndUnifyNamesAndIds( reduction->Clone(), false); @@ -4970,7 +4926,7 @@ HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal( return result; } -StatusOr SpmdPartitioner::PartitionComputation( +absl::StatusOr SpmdPartitioner::PartitionComputation( HloComputation* computation, const HloSharding& root_sharding, int64_t* next_channel_id, SpmdLogger* logger, const CallGraph& call_graph) { auto visitor = CreateVisitor(computation, num_partitions_, num_replicas_, @@ -4989,9 +4945,62 @@ std::unique_ptr SpmdPartitioner::CreateVisitor( next_channel_id, logger, std::move(options), this, call_graph); } -StatusOr SpmdPartitioner::Run( +int64_t SpmdPartitioner::MemoryCostInBytes(HloInstruction* hlo) { + auto memory_cost_for_operands = [](HloInstruction* hlo) { + int64_t memory = 0; + for (const HloInstruction* operand : hlo->operands()) { + memory += ShapeSizeInBytes(operand->shape()); + } + return memory; + }; + switch (hlo->opcode()) { + // Calculate memory cost for operands only for ops that re-use input buffers + // for their output buffers. + case HloOpcode::kAllReduce: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kScatter: + case HloOpcode::kWhile: + case HloOpcode::kTuple: + return memory_cost_for_operands(hlo); + default: + // TODO(b/311194120): Consider fusion of element-wise ops and other ops + // which doesn't need the full buffer for all operands. + return memory_cost_for_operands(hlo) + ShapeSizeInBytes(hlo->shape()); + } +} + +int64_t SpmdPartitioner::CommunicationCostInBytes(HloInstruction* hlo) { + CHECK(IsCollective(hlo)); + switch (hlo->opcode()) { + case HloOpcode::kAllReduce: + return ShapeSizeInBytes(hlo->shape()) * 2; + case HloOpcode::kCollectivePermute: + return ShapeSizeInBytes(hlo->shape()); + case HloOpcode::kAllGather: { + HloAllGatherInstruction* ag = Cast(hlo); + int64_t group_size = + ag->shape().dimensions(ag->all_gather_dimension()) / + ag->operand(0)->shape().dimensions(ag->all_gather_dimension()); + return ShapeSizeInBytes(hlo->shape()) * (group_size - 1) / group_size; + } + case HloOpcode::kAllToAll: { + int64_t group_size; + if (!hlo->replica_groups().empty()) { + group_size = hlo->replica_groups()[0].replica_ids_size(); + } else { + group_size = hlo->channel_id() ? num_partitions_ : num_replicas_; + } + return ShapeSizeInBytes(hlo->shape()) * (group_size - 1) / group_size; + } + default: + return 0; + } +} + +absl::StatusOr SpmdPartitioner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + set_execution_threads(execution_threads); TF_RETURN_IF_ERROR(PreprocessSharding(module, execution_threads)); TF_RETURN_IF_ERROR(PreprocessHlos(module, execution_threads)); diff --git a/xla/service/spmd/spmd_partitioner.h b/xla/service/spmd/spmd_partitioner.h index de9063deab938..9877f60e83509 100644 --- a/xla/service/spmd/spmd_partitioner.h +++ b/xla/service/spmd/spmd_partitioner.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ #define XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ +#include #include #include #include @@ -81,6 +82,11 @@ struct SpmdPartitionerOptions { // Whether to skip checking the numbers and shardings of windowed einsum's // users. bool skip_checking_windowed_einsum_users = false; + + // Enables windowed einsum for operand all-gather. + bool enable_windowed_einsum_for_all_gather = true; + // Enables windowed einsum for result reduce-scatter. + bool enable_windowed_einsum_for_reduce_scatter = true; }; // Class to wrap the computation builder to capture information during SPMD @@ -222,17 +228,17 @@ class SpmdPartitioner : public HloModulePass { collective_ops_creator_(std::move(collective_ops_creator)) {} absl::string_view name() const override { return "spmd-partitioning"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; // Transforms the given computation with SPMD instructions, replacing it with // a new computation. - StatusOr PartitionComputation(HloComputation* computation, - const HloSharding& root_sharding, - int64_t* next_channel_id, - SpmdLogger* logger, - const CallGraph& call_graph); + absl::StatusOr PartitionComputation(HloComputation* computation, + const HloSharding& root_sharding, + int64_t* next_channel_id, + SpmdLogger* logger, + const CallGraph& call_graph); // Creates all-gather(s) based on HloSharding. Can be overridden to customize. // The default uses a single all-gather even if there are multiple sharded @@ -256,13 +262,25 @@ class SpmdPartitioner : public HloModulePass { const SpmdPartitionerOptions& options() { return options_; } - protected: virtual std::unique_ptr CreateVisitor( HloComputation* computation, int64_t num_partitions, int64_t num_replicas, const SPMDCollectiveOpsCreator& collective_ops_creator, int64_t* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, const CallGraph& call_graph); + // Estimate the memory cost for an op, override this for target-specific + // op buffer implementation. + virtual int64_t MemoryCostInBytes(HloInstruction* hlo); + + // Estimate the communication cost for a collective op, override this for + // target-specific collective implementation. + virtual int64_t CommunicationCostInBytes(HloInstruction* hlo); + + const absl::flat_hash_set& execution_threads() const { + return execution_threads_; + } + + protected: // This is the internal implementation for AllGatherShards(), returns a pair // of hlo instructions whose first element is the result of the all-gather // shard(which might not be the all-gather itself and it could go through @@ -306,12 +324,18 @@ class SpmdPartitioner : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads); + void set_execution_threads( + const absl::flat_hash_set& execution_threads) { + execution_threads_ = execution_threads; + } + const int64_t num_partitions_; const int64_t num_replicas_; SpmdPartitionerOptions options_; SPMDCollectiveOpsCreator collective_ops_creator_; std::vector> device_groups_; + absl::flat_hash_set execution_threads_; }; // Class describes partition state of the data represented by an HLO created @@ -373,7 +397,7 @@ class PartitionedHlo { // specified pad value used during resharding. Could only modify the reshard // cache. PartitionedHlo Reshard(const HloSharding& target, - std::optional pad_value = std::nullopt); + std::optional pad_value = std::nullopt) const; // Pads the garbage area of the output with the provided value. Normally, // unevenly partitioned dimensions are padded on the right, but this function @@ -418,10 +442,10 @@ class PartitionedHlo { // Helper function to replicate the data on all devices. Could only modify // the reshard cache. - PartitionedHlo Replicate(); + PartitionedHlo Replicate() const; // Helper function to replicate the data for partitions along the given dims. - HloInstruction* ReplicatePartial(absl::Span dims); + HloInstruction* ReplicatePartial(absl::Span dims) const; // Set state of the partitoned HLO. void set_state(PartitioningState state) { state_ = std::move(state); } @@ -431,7 +455,7 @@ class PartitionedHlo { // cache, although it would indirectly modify by calling Replicate(). PartitionedHlo ReshardNoCache(const HloSharding& target, std::optional pad_value = std::nullopt, - bool allow_full_replication = true); + bool allow_full_replication = true) const; // Helper function to broadcast data from a single device to all devices. PartitionedHlo Broadcast() const; @@ -439,7 +463,7 @@ class PartitionedHlo { // Try to perform complicated reshard handling by splitting a big reshard into // multiple reshards using that can be handled directly. std::optional TryComplexReshardHandling( - const HloSharding& target); + const HloSharding& target) const; // Helper function to reshard the tensor using AllToAll (instead of the // default of Replicate followed by Slice). @@ -452,15 +476,15 @@ class PartitionedHlo { // Helper function to reshard to partial replicate using AllGather. std::optional ReshardToPartialReplicateWithAllGather( - const HloSharding& target); + const HloSharding& target) const; // Helper function to reshard from partial replicate using DynamicSlice. std::optional ReshardFromPartialReplicateWithDynamicSlice( - const HloSharding& target); + const HloSharding& target) const; // Helper function to reshard from partial replicate using AllToAll. std::optional ReshardPartialReplicateWithAllToAll( - const HloSharding& target); + const HloSharding& target) const; // SPMD instruction. HloInstruction* hlo_; @@ -499,6 +523,8 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { SpmdPartitionerOptions options, SpmdPartitioner* partitioner, const CallGraph& call_graph); + SpmdPartitioningVisitor(const SpmdPartitioningVisitor& src); + Status DefaultAction(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* hlo) override; Status HandleBroadcast(HloInstruction* hlo) override; @@ -536,7 +562,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { // Implementation of dot partitioning given DotGeneralDimsMapping. Status HandleDotHelper(HloInstruction* hlo, const DotConvDimsMapping& dims_mapping, - absl::FunctionRef( + absl::FunctionRef( HloInstruction*, HloInstruction*, SpmdBuilder*, const Window& conv_window)> create_sharded_dot); @@ -552,6 +578,8 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { // Convenient custom ops defined by the partitioner itself. Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo); + virtual std::unique_ptr Clone() const; + // Returns the PartitionedHlo that corresponds to the original hlo. PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { CHECK_EQ(partitioned_instructions_.count(hlo), 1); @@ -583,9 +611,9 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { SpmdBuilder* builder() { return &b_; } - virtual StatusOr DoPartition(HloComputation* computation, - const HloSharding& root_sharding, - const SpmdPartitionerOptions& options); + virtual absl::StatusOr DoPartition( + HloComputation* computation, const HloSharding& root_sharding, + const SpmdPartitionerOptions& options); virtual double GetComputationTimeInMilliSec(HloInstruction* hlo) { return 0.0; @@ -606,6 +634,21 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { const CallGraph& call_graph() { return call_graph_; } int64_t num_partitions() const { return num_partitions_; } + int64_t num_replicas() const { return num_replicas_; } + SpmdLogger* logger() { return logger_; } + const SpmdLogger* logger() const { return logger_; } + const SpmdPartitionerOptions& options() const { return options_; } + SpmdPartitioner* partitioner() { return partitioner_; } + const SpmdPartitioner* partitioner() const { return partitioner_; } + SPMDCollectiveOpsCreator& collective_ops_creator() { + return collective_ops_creator_; + } + const SPMDCollectiveOpsCreator& collective_ops_creator() const { + return collective_ops_creator_; + } + HloModule* module() { return module_; } + const HloModule* module() const { return module_; } + void set_module(HloModule* module) { module_ = module; } // Information about a loop created for windowed dot-general. Used when // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 4a1c45cc6e855..357f7d2841bdd 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ class SpmdPartitioningTest : public HloTestBase, public ::testing::WithParamInterface { public: - StatusOr> PartitionComputation( + absl::StatusOr> PartitionComputation( absl::string_view hlo_module, int64_t num_devices, bool conv_halo_exchange_always_on_lhs = true, bool choose_faster_windowed_einsum = false, @@ -99,7 +99,7 @@ class SpmdPartitioningTest TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); VerifyNoShardingOnCollectives(module.get()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } void VerifyNoShardingOnCollectives(HloModule* module) { @@ -3215,6 +3215,40 @@ ENTRY entry { } } +TEST_P(SpmdPartitioningTest, SortShardedOnSortDim_TwoOperands_FreeDimOfSize1) { + absl::string_view hlo_string = R"( +HloModule module + +compare { + p.0.lhs = f32[] parameter(0), sharding={replicated} + p.0.rhs = f32[] parameter(1), sharding={replicated} + p.1.lhs = s32[] parameter(2), sharding={replicated} + p.1.rhs = s32[] parameter(3), sharding={replicated} + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT, sharding={replicated} +} + +ENTRY entry { + param.0 = f32[1,1024]{1,0} parameter(0) + negate.0 = f32[1,1024]{1,0} negate(param.0), sharding={devices=[1,8]<=[8]} + iota.0 = s32[1,1024]{1,0} iota(), iota_dimension=1, sharding={devices=[1,8]<=[8]} + ROOT sort.0 = (f32[1,1024]{1,0}, s32[1,1024]{1,0}) sort(negate.0, iota.0), dimensions={1}, is_stable=true, to_apply=compare, sharding={{devices=[1,8]<=[8]},{devices=[1,8]<=[8]}} + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + for (HloInstruction* inst : module->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kSort) { + for (HloInstruction* operand : inst->operands()) { + EXPECT_EQ(operand->shape().dimensions(0), 1); + EXPECT_EQ(operand->shape().dimensions(1), 1024); + } + EXPECT_THAT(inst, op::Sort(op::AllReduce(), op::AllReduce())); + } + EXPECT_NE(inst->opcode(), HloOpcode::kAllToAll); + } +} + TEST_P(SpmdPartitioningTest, SortShardedOnSortDim_ThreeOperands) { absl::string_view hlo_string = R"( HloModule module, entry_computation_layout={(f32[1024,1024]{1,0})->(f32[1024,1024]{1,0},s32[1024,1024]{1,0},s32[1024,1024]{1,0})} @@ -7510,6 +7544,43 @@ ENTRY entry { op::Shape("(f32[2,3],f32[2,3])"))); } +TEST_P(SpmdPartitioningTest, VariadicScatterSharedOperands) { + absl::string_view hlo_string = R"( +HloModule module + +add (lhs.0: f32[], lhs.1: f32[], rhs.0: f32[], rhs.1: f32[]) -> (f32[], f32[]) { + lhs.0 = f32[] parameter(0) + lhs.1 = f32[] parameter(1) + rhs.0 = f32[] parameter(2) + rhs.1 = f32[] parameter(3) + sum.0 = f32[] add(lhs.0, rhs.0) + sum.1 = f32[] add(lhs.1, rhs.1) + ROOT tuple = tuple(sum.0, sum.1) +} + +ENTRY entry { + %input.0 = f32[8,16,32] parameter(0), sharding={devices=[4,1,1,2]<=[8] last_tile_dim_replicate} + %indices = s32[16,1] parameter(1), sharding={replicated} + %updates.0 = f32[8,16,16] parameter(2), sharding={devices=[4,1,1,2]<=[8] last_tile_dim_replicate} + %updates.1 = f32[8,16,16] parameter(3), sharding={devices=[4,1,1,2]<=[8] last_tile_dim_replicate} + ROOT %scatter = (f32[8,16,32], f32[8,16,32]) scatter(%input.0, %input.0, %indices, %updates.0, %updates.1), + to_apply=add, + update_window_dims={0,1}, + inserted_window_dims={2}, + scatter_dims_to_operand_dims={2}, + index_vector_dim=1, + indices_are_sorted=true, + unique_indices=true, + sharding={{devices=[4,1,1,2]<=[8] last_tile_dim_replicate}, {devices=[4,1,1,2]<=[8] last_tile_dim_replicate}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(), op::Shape("(f32[2,16,32],f32[2,16,32])"))); +} + TEST_P(SpmdPartitioningTest, PassthroughScatter) { absl::string_view hlo_string = R"( HloModule module @@ -8418,6 +8489,37 @@ ENTRY entry { EXPECT_THAT(root, dot); } +TEST_P(SpmdPartitioningTest, SimpleSparseDot) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[2,24,128] parameter(0), + sharding={devices=[2,2,1]<=[4]} + %rhs = f32[2,32,256] parameter(1), + sharding={devices=[2,1,1,2]<=[4] last_tile_dim_replicate} + %meta = u16[2,24,16] parameter(2), + sharding={devices=[2,2,1]<=[4]} + ROOT %dot = f32[2,24,32] dot(%lhs, %rhs, %meta), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, sparsity=L.2@2:4, + sharding={devices=[2,2,1]<=[4]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("f32[1,12,128]"), op::Parameter(0)); + const auto rhs = AllOf(op::Shape("f32[1,32,256]"), op::Parameter(1)); + const auto meta = AllOf(op::Shape("u16[1,12,16]"), op::Parameter(2)); + auto dot = AllOf(op::Shape("f32[1,12,32]"), + ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + HloOpcode::kDot, {lhs, rhs, meta}))); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot); +} + TEST_P(SpmdPartitioningTest, DotPartialContracting) { absl::string_view hlo_string = R"( HloModule module @@ -10756,11 +10858,10 @@ ENTRY %module { auto operand = AllOf(op::Shape("s32[8,4,1,2]"), op::AllReduce()); auto indices = AllOf(op::Shape("s32[2,4,2]"), op::Parameter()); auto gather = AllOf(op::Shape("s32[4,2,1,2]"), op::Gather(operand, indices)); - EXPECT_THAT(root, op::AllReduce(op::DynamicUpdateSlice( - _, - op::AllReduce(op::AllReduce( - op::DynamicUpdateSlice(_, gather, _, _, _, _))), - _, _, _, _))); + EXPECT_THAT( + root, op::AllReduce(op::AllReduce(op::DynamicUpdateSlice( + _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)), + _, _, _, _)))); } TEST_P(SpmdPartitioningTest, @@ -11677,12 +11778,8 @@ ENTRY %module { auto update = AllOf(op::Shape("s32[4,2,1,2]"), op::DynamicSlice()); auto scatter = AllOf(op::Shape("s32[8,4,1,2]"), op::Scatter(operand, indices, update)); - EXPECT_THAT( - root, - op::AllReduce(op::AllReduce(op::AllReduce(op::DynamicUpdateSlice( - _, - op::DynamicSlice(op::AllReduce(op::AllReduce(scatter)), _, _, _, _), - _, _, _, _))))); + EXPECT_THAT(root, op::AllReduce(op::DynamicUpdateSlice( + _, op::AllReduce(op::AllReduce(scatter)), _, _, _, _))); } TEST_P(SpmdPartitioningTest, @@ -13953,6 +14050,36 @@ ENTRY %entry { EXPECT_THAT(topk_operand, op::Shape("bf16[64,128000]{1,0}")); } +TEST_P(SpmdPartitioningTest, TopKCustomCallManualSharding) { + absl::string_view hlo_string = R"( +HloModule module + +region { + Arg_2.22549 = s32[] parameter(2) + Arg_3.22550 = s32[] parameter(3) + Arg_0.22547 = bf16[] parameter(0) + Arg_1.22548 = bf16[] parameter(1) + ROOT compare.22551 = pred[] compare(Arg_0.22547, Arg_1.22548), direction=GT, type=TOTALORDER +} + +ENTRY %entry { + %p0 = bf16[64,256000]{1,0} parameter(0), sharding={manual} + %custom-call = (bf16[64,40]{1,0}, s32[64,40]{1,0}) custom-call(bf16[64,256000]{1,0} %p0), custom_call_target="TopK", called_computations={%region}, sharding={{manual}, {manual}} + %get-tuple-element.336 = bf16[64,40]{1,0} get-tuple-element((bf16[64,40]{1,0}, s32[64,40]{1,0}) %custom-call), index=0, sharding={manual} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kSort), nullptr); + + auto topk_instruction = FindInstruction(module.get(), HloOpcode::kCustomCall); + EXPECT_EQ(topk_instruction->custom_call_target(), "TopK"); + EXPECT_THAT(topk_instruction->operand(0), op::Shape("bf16[64,256000]{1,0}")); + EXPECT_THAT(topk_instruction, + op::Shape("(bf16[64,40]{1,0}, s32[64,40]{1,0})")); +} + TEST_P(SpmdPartitioningTest, WindowedEinsumShouldMatchLhs_b305313406) { absl::string_view hlo_string = R"( HloModule module @@ -14014,6 +14141,46 @@ ENTRY %extracted_computation (param: f32[13,128,312,16,312]) -> f32[13,39936,499 EXPECT_NE(all_to_all, nullptr); } +TEST_P(SpmdPartitioningTest, SortAllGatherNonMovableDimension) { + const char* const hlo_string = R"( +HloModule module + +top_k_gt_f32_comparator_64.35303 { + Arg_2.35306 = s32[] parameter(2) + Arg_3.35307 = s32[] parameter(3) + Arg_0.35304 = f32[] parameter(0) + Arg_1.35305 = f32[] parameter(1) + ROOT compare.35308 = pred[] compare(Arg_0.35304, Arg_1.35305), direction=GT +} + +ENTRY entry { + param.0 = f32[4,16384,4096]{2,1,0} parameter(0), sharding={devices=[4,4,4]<=[64]} + param.1 = s32[4,16384,4096]{2,1,0} parameter(1), sharding={devices=[4,4,4]<=[64]} + ROOT sort.209 = (f32[4,16384,4096]{2,1,0}, s32[4,16384,4096]{2,1,0}) sort(param.0, param.1), dimensions={2}, to_apply=top_k_gt_f32_comparator_64.35303, sharding={{devices=[4,4,4]<=[64]}, {devices=[4,4,4]<=[64]}} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation( + hlo_string, /*num_devices=*/64, + /*conv_halo_exchange_always_on_lhs=*/true, + /*xla_tpu_enable_log_recorder_partitioned_logging=*/true)); + XLA_VLOG_LINES(1, module->ToString()); + + auto* root = module->entry_computation()->root_instruction(); + auto* sort = FindInstruction(module.get(), HloOpcode::kSort); + EXPECT_THAT( + root, + AllOf(op::Tuple(), + op::Shape("(f32[1,4096,1024]{2,1,0}, s32[1,4096,1024]{2,1,0})"))); + EXPECT_THAT( + sort, + AllOf(op::Sort( + AllOf(op::AllReduce(), op::Shape("f32[1,4096,4096]{2,1,0}")), + AllOf(op::AllReduce(), op::Shape("s32[1,4096,4096]{2,1,0}"))), + op::Shape("(f32[1,4096,4096]{2,1,0}, s32[1,4096,4096]{2,1,0})"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/xla/service/spmd/spmd_partitioner_util.cc b/xla/service/spmd/spmd_partitioner_util.cc index e84bfcc28ba75..7bb7927b3f835 100644 --- a/xla/service/spmd/spmd_partitioner_util.cc +++ b/xla/service/spmd/spmd_partitioner_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,21 @@ limitations under the License. #include #include +#include +#include #include #include +#include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -112,6 +119,13 @@ Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { } int64_t ShapeSizeInBytes(const Shape& shape) { + if (shape.IsTuple()) { + int64_t total_size = 0; + for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + total_size += ShapeSizeInBytes(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total_size; + } return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) * ShapeUtil::ElementsIn(shape); } @@ -162,28 +176,26 @@ std::vector MakePartitionOffsets( absl::Span dims) { CHECK(!shape.IsTuple()); - std::vector> offset_arrays(shape.rank()); - for (int64_t i = 0; i < shape.rank(); ++i) { - offset_arrays[i].resize(sharding.tile_assignment().num_elements()); - } auto shard_shape = MakePartitionedShape(shape, sharding); - sharding.tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - for (int64_t i = 0; i < shape.rank(); ++i) { - offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i); - } - }); std::vector offsets; + for (int64_t i = 0; i < shape.rank(); ++i) { if (sharding.tile_assignment().dim(i) == 1 || (!dims.empty() && !absl::c_linear_search(dims, i))) { offsets.push_back(b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); } else { + std::vector offset_array( + sharding.tile_assignment().num_elements()); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64_t device) { + offset_array[device] = indices[i] * shard_shape.dimensions(i); + }); offsets.push_back( - TableLookup(offset_arrays[i], S32, partition_id, b)); + TableLookup(offset_array, S32, partition_id, b)); } } + return offsets; } @@ -234,13 +246,16 @@ SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator( const SPMDCollectiveOpsCreator& creator, const std::vector>& device_groups) { SPMDCollectiveOpsCreator result; - result.create_partition_id = [creator, device_groups](SpmdBuilder* b) { - return GetInGroupPartitionId(creator.create_partition_id(b), device_groups, - b); + auto device_groups_ptr = + std::make_shared>>(device_groups); + result.create_partition_id = [creator, device_groups_ptr](SpmdBuilder* b) { + return GetInGroupPartitionId(creator.create_partition_id(b), + *device_groups_ptr, b); }; auto expand_partition_groups = - [device_groups]( + [device_groups_ptr]( const std::vector>& partition_subgroups) { + auto& device_groups = *device_groups_ptr; if (partition_subgroups.empty()) { return device_groups; } @@ -268,10 +283,11 @@ SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator( channel_id); }; result.create_cross_partition_collective_permute = - [creator, device_groups]( + [creator, device_groups_ptr]( SpmdBuilder* b, HloInstruction* operand, std::vector>& src_dst_pairs, int64_t next_channel_id) { + auto& device_groups = *device_groups_ptr; std::vector> expanded_pairs( src_dst_pairs.size() * device_groups.size()); for (int64_t g = 0; g < device_groups.size(); ++g) { @@ -316,23 +332,14 @@ std::optional PartialReplicateReshardCompatibleSharding( if (!partial_sharding.ReplicateOnLastTileDim()) { return std::nullopt; } - int64_t rank = partial_sharding.tile_assignment().num_dimensions() - 1; - int64_t target_rank = target_sharding.tile_assignment().num_dimensions() - - (target_sharding.ReplicateOnLastTileDim() ? 1 : 0); - if (target_rank != rank) { + if (partial_sharding.tile_assignment().num_elements() != + target_sharding.tile_assignment().num_elements()) { + return std::nullopt; + } + const int64_t rank = partial_sharding.TiledDataRank(); + if (rank != target_sharding.TiledDataRank()) { return std::nullopt; } - - absl::flat_hash_map device_to_replication_group; - partial_sharding.tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - int64_t gid = 0; - for (int64_t i = 0; i < rank; ++i) { - gid *= partial_sharding.tile_assignment().dim(i); - gid += indices[i]; - } - device_to_replication_group[device] = gid; - }); // A dimension is expanded when target_tile_size > partial_tile_size and // target_tile_size % partial_tile_size == 0. @@ -340,12 +347,11 @@ std::optional PartialReplicateReshardCompatibleSharding( std::vector expand_tile_dims_indices(rank, -1); // expand_tile_size = target_tile_size / partial_tile_size. std::vector expand_tile_sizes; - int num_expand_dims = 0; + int64_t num_expand_dims = 0; for (int64_t dim = 0; dim < rank; dim++) { int64_t partial_tile_size = partial_sharding.tile_assignment().dim(dim); int64_t target_tile_size = target_sharding.tile_assignment().dim(dim); - if (target_tile_size % partial_tile_size != 0 || - target_tile_size < partial_tile_size) { + if (target_tile_size % partial_tile_size != 0) { return std::nullopt; } @@ -355,34 +361,25 @@ std::optional PartialReplicateReshardCompatibleSharding( } } - // Reshape the partial replicate tile_dimensions. - int64_t num_target_replication = 1; - if (target_sharding.ReplicateOnLastTileDim()) { - num_target_replication = - target_sharding.tile_assignment().dimensions().back(); + const std::vector shape_dims( + target_sharding.tile_assignment().dimensions().begin(), + target_sharding.tile_assignment().dimensions().begin() + rank); + if (hlo_sharding_util::IsSubTilingOrEqualSharding( + ShapeUtil::MakeShape(F32, shape_dims), target_sharding, + partial_sharding)) { + return target_sharding; } + + // Now that target_sharding is not a subtiling of partial_sharding, we + // decompose partial_sharding on the last tile dimension (replicated one) and + // move the decomposed tile dimensions to the expanded tile dimensions. std::vector reshape_dimensions( partial_sharding.tile_assignment().dimensions().begin(), - partial_sharding.tile_assignment().dimensions().end()); - int64_t num_replication = reshape_dimensions.back(); - if (num_replication / num_target_replication != Product(expand_tile_sizes) || - num_replication % num_target_replication != 0) { - return std::nullopt; - } - - reshape_dimensions.pop_back(); + partial_sharding.tile_assignment().dimensions().begin() + rank); reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(), expand_tile_sizes.end()); - if (target_sharding.ReplicateOnLastTileDim()) { - reshape_dimensions.push_back(num_target_replication); - } - - auto reshape_tile_assignment = - partial_sharding.tile_assignment().Reshape(reshape_dimensions); - - // Transpose. - std::vector perm; + std::vector perm; perm.reserve(rank + expand_tile_sizes.size()); for (int64_t dim = 0; dim < rank; dim++) { perm.emplace_back(dim); @@ -390,28 +387,19 @@ std::optional PartialReplicateReshardCompatibleSharding( perm.emplace_back(expand_tile_dims_indices[dim] + rank); } } - auto transpose_sharding = hlo_sharding_util::TransposeSharding( - target_sharding.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(reshape_tile_assignment) - : HloSharding::Tile(reshape_tile_assignment), - perm); - // Reshape to target shape - auto transpose_tile_assignment = transpose_sharding.tile_assignment().Reshape( - target_sharding.tile_assignment().dimensions()); + if (target_sharding.ReplicateOnLastTileDim()) { + reshape_dimensions.push_back( + target_sharding.tile_assignment().dimensions().back()); + perm.push_back(reshape_dimensions.size() - 1); + } - bool groups_matching = true; - target_sharding.tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - if (device_to_replication_group[device] != - device_to_replication_group[transpose_tile_assignment(indices)]) { - groups_matching = false; - } - }); + auto transpose_tile_assignment = + partial_sharding.tile_assignment() + .Reshape(reshape_dimensions) + .Transpose(perm) + .Reshape(target_sharding.tile_assignment().dimensions()); - if (groups_matching) { - return target_sharding; - } return target_sharding.ReplicateOnLastTileDim() ? HloSharding::PartialTile(transpose_tile_assignment) : HloSharding::Tile(transpose_tile_assignment); @@ -1858,7 +1846,7 @@ std::optional AlignGroupsWithInternal( auto get_permutation = [](absl::Span src, absl::Span dst) { CHECK_EQ(src.size(), dst.size()); - absl::flat_hash_map dst_reverse_map; + absl::flat_hash_map dst_reverse_map(dst.size()); for (int64_t i = 0; i < dst.size(); ++i) { dst_reverse_map[dst[i]] = i; } @@ -1872,7 +1860,8 @@ std::optional AlignGroupsWithInternal( }; CHECK_EQ(grouped_sharding.device_groups.size(), reference.device_groups.size()); - absl::flat_hash_map device_to_ref_group; + std::vector device_to_ref_group(reference.device_groups.size() * + reference.device_groups[0].size()); for (int64_t g = 0; g < reference.device_groups.size(); ++g) { for (int64_t device : reference.device_groups[g]) { device_to_ref_group[device] = g; @@ -2043,6 +2032,10 @@ std::optional> FindMatchingPartitionedDimsForGrouping( if (sharding.IsTileMaximal() || device_groups.size() < 2) { return std::nullopt; } + const int64_t num_devices = sharding.tile_assignment().num_elements(); + if (num_devices != device_groups.size() * device_groups[0].size()) { + return std::nullopt; + } std::vector dims; if (device_groups[0].size() < 2) { // Trivial case: single member groups @@ -2053,15 +2046,16 @@ std::optional> FindMatchingPartitionedDimsForGrouping( } return dims; } - int64_t rank = sharding.tile_assignment().num_dimensions(); - absl::flat_hash_map> device_to_index; + + std::vector> device_to_index( + num_devices, + std::vector(sharding.tile_assignment().num_dimensions())); sharding.tile_assignment().Each( [&](absl::Span index, int64_t device) { - device_to_index[device] = - std::vector(index.begin(), index.begin() + rank); + device_to_index[device].assign(index.begin(), index.end()); }); int64_t group_count = 1; - for (int64_t i = 0; i < rank; ++i) { + for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { if (device_to_index[device_groups[0][0]][i] == device_to_index[device_groups[0][1]][i]) { dims.push_back(i); diff --git a/xla/service/spmd/spmd_partitioner_util.h b/xla/service/spmd/spmd_partitioner_util.h index 0bb05c9657739..ae59b9f873722 100644 --- a/xla/service/spmd/spmd_partitioner_util.h +++ b/xla/service/spmd/spmd_partitioner_util.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,23 +16,41 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ #define XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ +#include +#include +#include +#include #include #include #include #include +#include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_replace.h" +#include "absl/utility/utility.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/hlo_dce.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -472,7 +490,7 @@ std::optional PadFromPartialReplicateShape( // {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} // Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding // will be sharding={devices=[2,2]0,2,1,3}. -// If patial replicate sharding is not partial replicate or can't reshard to +// If partial_sharding is not partial replicate or can't reshard to // target_tile_dims by dynamic slice, return std::nullopt. // If target_sharding is already compatible, returns it. std::optional PartialReplicateReshardCompatibleSharding( @@ -562,6 +580,355 @@ HloInstruction* PadDataFromWindowReshard( const PartitionedHlo::WindowedInputShardReturnValue& reshard_operand, HloInstruction* pad_value, SpmdBuilder* b); +namespace detail { + +// Check if a type is SpmdPartitioningVisitor* type. +template +struct IsSpmdPartitioningVisitorPointerType : std::false_type {}; + +template +struct IsSpmdPartitioningVisitorPointerType< + T, std::enable_if_t, + SpmdPartitioningVisitor*>>> + : std::true_type {}; + +template +constexpr bool IsSpmdPartitioningVisitorPointerType_v = + IsSpmdPartitioningVisitorPointerType::value; + +template +using IsSpmdPartitioningVisitorPointer = + std::enable_if_t, int>; + +template +using IsNotSpmdPartitioningVisitorPointer = + std::enable_if_t, int>; + +// Check if a type is SpmdBuilder* type. +template +struct IsSpmdBuilderPointerType : std::false_type {}; + +template +struct IsSpmdBuilderPointerType< + T, + std::enable_if_t, SpmdBuilder*>>> + : std::true_type {}; + +template +constexpr bool IsSpmdBuilderPointerType_v = IsSpmdBuilderPointerType::value; + +template +using IsSpmdBuilderPointer = + std::enable_if_t, int>; + +template +using IsNotSpmdBuilderPointer = + std::enable_if_t, int>; + +// Check if a type is HloModule* type. +template +struct IsHloModulePointerType : std::false_type {}; + +template +struct IsHloModulePointerType< + T, std::enable_if_t, HloModule*>>> + : std::true_type {}; + +template +constexpr bool IsHloModulePointerType_v = IsHloModulePointerType::value; + +template +using IsHloModulePointer = std::enable_if_t, int>; + +template +using IsNotHloModulePointer = + std::enable_if_t, int>; + +// Check if a type is PartitionedHlo type. +template +struct IsPartitionedHloType : std::false_type {}; + +template +struct IsPartitionedHloType< + T, std::enable_if_t, PartitionedHlo>>> + : std::true_type {}; + +template +constexpr bool IsPartitionedHloType_v = IsPartitionedHloType::value; + +template +using IsPartitionedHlo = std::enable_if_t, int>; + +template +using IsNotPartitionedHlo = std::enable_if_t, int>; + +// Check if a type is iterable type. +template +struct is_iterable : std::false_type {}; + +template +struct is_iterable().begin()), + decltype(std::declval().end())>> + : std::true_type {}; + +template +constexpr bool is_iterable_v = is_iterable::value; + +template +using iterable_element_type = + std::decay_t().begin())>; + +// Check if a type is iterable container type of PartitionedHlo. +template +struct IsIterablePartitionedHloContainerType : std::false_type {}; + +template +struct IsIterablePartitionedHloContainerType< + T, + std::enable_if_t && + std::is_same_v, PartitionedHlo>>> + : std::true_type {}; + +template +constexpr bool IsIterablePartitionedHloContainerType_v = + IsIterablePartitionedHloContainerType::value; + +template +using IsIterablePartitionedHloContainer = + std::enable_if_t, int>; + +template +using IsNotIterablePartitionedHloContainer = + std::enable_if_t, int>; + +// Create a fake PartitionedHlo object in a fake builder/module as a new +// parameter. +template = 0> +std::decay_t FakePartitionedHlo(Arg&& phlo, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + HloInstruction* param = + fake_visitor->builder() + ->AddParameter(HloInstruction::CreateParameter( + *parameter_count, phlo.hlo()->shape(), + "fake_parameter." + std::to_string(*parameter_count))) + .value(); + *parameter_count = *parameter_count + 1; + PartitionedHlo fake_phlo = phlo.CloneWithNewHlo(param); + PartitionedHlo::PartitioningState fake_state = + fake_visitor->MakePartitioningState(); + fake_state.module = module; + fake_phlo.set_state(fake_state); + return fake_phlo; +} + +// Create a fake PartitionedHlo container object in a fake builder/module as a +// number new parameters. +template = 0> +std::decay_t FakeIterablePartitionedHloContainer( + Arg&& phlo_container, HloModule* module, int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + std::vector> phlos; + phlos.reserve(phlo_container.size()); + for (const PartitionedHlo& phlo : phlo_container) { + phlos.push_back(std::move( + FakePartitionedHlo(phlo, module, parameter_count, fake_visitor))); + } + bool is_constructible_from_iterators = + std::is_constructible_v, decltype(phlos.begin()), + decltype(phlos.end())>; + CHECK(is_constructible_from_iterators); + return std::decay_t(phlos.begin(), phlos.end()); +} + +// Create a fake SpmdPartitioningVisitor*. +template = 0> +std::decay_t FakeSpmdPartitioningVisitor( + Arg&& visitor, SpmdPartitioningVisitor* fake_visitor) { + return fake_visitor; +} + +// Create a fake SpmdBuilder*. +template = 0> +std::decay_t FakeSpmdBuilder(Arg&& builder, + SpmdPartitioningVisitor* fake_visitor) { + return fake_visitor->builder(); +} +// Create a fake HloModule*. +template = 0> +std::decay_t FakeHloModule(Arg&& module, HloModule* fake_module) { + return fake_module; +} +template +using decay_rvalue_reference_t = + std::conditional_t::value, std::decay_t, T>; + +// Modifies SpmdPartitioningVisitor* type objects. +template = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Faking argument type: " << typeid(arg).name(); + return FakeSpmdPartitioningVisitor(std::forward(arg), fake_visitor); +} + +// Modifies SpmdBuilder* type objects. +template = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Faking argument type: " << typeid(arg).name(); + return FakeSpmdBuilder(std::forward(arg), fake_visitor); +} + +// Modifies SpmdPartitioningVisitor* type objects. +template = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Faking argument type: " << typeid(arg).name(); + return FakeHloModule(std::forward(arg), module); +} + +// Modifies PartitionedHlo type objects. +template = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Faking argument type: " << typeid(arg).name(); + return FakePartitionedHlo(std::forward(arg), module, parameter_count, + fake_visitor); +} + +// Modifies PartitionedHlo container type objects. +template = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Faking argument type: " << typeid(arg).name(); + return FakeIterablePartitionedHloContainer(std::forward(arg), module, + parameter_count, fake_visitor); +} + +// Modifies nothing, equivalent to no-op. +template = 0, + IsNotSpmdBuilderPointer = 0, IsNotHloModulePointer = 0, + IsNotIterablePartitionedHloContainer = 0, + IsNotPartitionedHlo = 0> +std::decay_t ArgModifier(Arg&& arg, HloModule* module, + int* parameter_count, + SpmdPartitioningVisitor* fake_visitor) { + VLOG(5) << "Passing through argument type: " << typeid(arg).name(); + return arg; +} + +// Finds SpmdPartitioningVisitor* object in an arg list. +template = 0> +absl::StatusOr FindSpmdPartitioningVisitor( + Arg&& arg) { + return arg; +} + +template = 0> +absl::StatusOr FindSpmdPartitioningVisitor( + Arg&& arg) { + return absl::InvalidArgumentError("No SpmdPartitioningVisitor found."); +} + +template = 0> +absl::StatusOr FindSpmdPartitioningVisitor( + Arg&& arg, Args&&... args) { + return arg; +} + +template = 0> +absl::StatusOr FindSpmdPartitioningVisitor( + Arg&& arg, Args&&... args) { + return FindSpmdPartitioningVisitor(std::forward(args)...); +} + +} // namespace detail + +// Evaluate the memory and communication cost for any arbitrary partitioning +// methods. +template +absl::StatusOr> EvaluatePartitionCost( + const HloInstruction* original_hlo, F partition_method, + Args&&... partition_method_args) { + HloModule* module = original_hlo->GetModule(); + auto comp_env = + std::make_unique(module->comp_envs()); + // Create a fake module and run partitioning with this fake module later. + HloModule fake_module("fake_module", module->config(), std::move(comp_env)); + auto temp_b = HloComputation::Builder("temp_entry"); + auto temp_p = temp_b.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "input")); + HloComputation* temp_entry = fake_module.AddEntryComputation(temp_b.Build()); + + TF_ASSIGN_OR_RETURN(SpmdPartitioningVisitor * visitor, + detail::FindSpmdPartitioningVisitor( + std::forward(partition_method_args)...)); + SpmdPartitioner* partitioner = visitor->partitioner(); + std::unique_ptr fake_visitor = visitor->Clone(); + fake_visitor->set_module(&fake_module); + auto* fake_b = fake_visitor->builder(); + fake_b->set_visiting_hlo(temp_p); + auto parameter_count = std::make_unique(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_hlo, + partition_method(detail::ArgModifier( + std::forward(partition_method_args), &fake_module, + parameter_count.get(), fake_visitor.get())...)); + + if (new_hlo == nullptr) { + return std::make_pair(INT64_MAX, INT64_MAX); + } + auto new_entry = fake_module.AddEmbeddedComputation(fake_b->Build(new_hlo)); + // Replace the original computation with the new SPMD computation. + absl::flat_hash_map replacement; + replacement[temp_entry] = new_entry; + for (HloInstruction* hlo : new_entry->instructions()) { + for (HloComputation* comp : hlo->called_computations()) { + if (comp->parent() != &fake_module) { + replacement[comp] = fake_module.AddEmbeddedComputation(comp->Clone()); + } + } + } + fake_module.ReplaceComputations(replacement); + + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN( + auto _, hlo_dce.Run(&fake_module, partitioner->execution_threads())); + (void)_; // Suppress unused variable warning in OSS + VLOG(5) << "Dry-run partitioning for op: " << original_hlo->ToString() << "\n" + << fake_module.ToString(); + + int64_t max_memory = 0; + int64_t total_communication = 0; + for (HloComputation* computation : fake_module.computations()) { + for (HloInstruction* hlo : computation->instructions()) { + // Check the memory cost for the partitioned hlo op, as well as the + // memory cost for collectives for potential overhead from full remat. + if (hlo->opcode() == original_hlo->opcode() || IsCollective(hlo)) { + int64_t memory_cost = partitioner->MemoryCostInBytes(hlo); + if (memory_cost > max_memory) { + VLOG(5) << hlo->ToString() << " has memory cost of " << memory_cost; + max_memory = memory_cost; + } + } + if (IsCollective(hlo)) { + total_communication += partitioner->CommunicationCostInBytes(hlo); + } + } + } + if (max_memory != 0) { + return std::make_pair(max_memory, total_communication); + } + return std::make_pair(INT64_MAX, INT64_MAX); +} + } // namespace spmd } // namespace xla diff --git a/xla/service/spmd/spmd_partitioner_util_test.cc b/xla/service/spmd/spmd_partitioner_util_test.cc new file mode 100644 index 0000000000000..663bb45e21cb0 --- /dev/null +++ b/xla/service/spmd/spmd_partitioner_util_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/spmd_partitioner_util.h" + +#include + +#include +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/tile_assignment.h" + +namespace xla { +namespace spmd { +namespace { + +TEST(SPMDPartitionerUtilTest, PartialReplicateReshardCompatibleSharding1) { + HloSharding partial_sharding = + HloSharding::PartialTile(TileAssignment({1, 2, 2})); + const std::vector target_shardings = { + HloSharding::IotaTile({2, 2}), + HloSharding::IotaTile({2, 2}, {2, 2}, {1, 0})}; + for (const auto& target_sharding : target_shardings) { + auto result = PartialReplicateReshardCompatibleSharding(partial_sharding, + target_sharding); + EXPECT_EQ(result, target_shardings[1]); + } + + partial_sharding = + HloSharding::PartialTile(TileAssignment({1, 2, 2}, {2, 2}, {1, 0})); + for (const auto& target_sharding : target_shardings) { + auto result = PartialReplicateReshardCompatibleSharding(partial_sharding, + target_sharding); + EXPECT_EQ(result, target_shardings[0]); + } +} + +TEST(SPMDPartitionerUtilTest, PartialReplicateReshardCompatibleSharding2) { + HloSharding partial_sharding = + HloSharding::PartialTile(TileAssignment({2, 2, 8})); + const std::vector target_shardings = { + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 2, 1, 3, 4})), + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 2, 1, 4, 3})), + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 3, 1, 2, 4})), + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 3, 1, 4, 2})), + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 4, 1, 2, 3})), + HloSharding::PartialTile( + TileAssignment({4, 4, 2}, {2, 2, 2, 2, 2}, {0, 4, 1, 3, 2}))}; + for (const auto& target_sharding : target_shardings) { + auto result = PartialReplicateReshardCompatibleSharding(partial_sharding, + target_sharding); + EXPECT_EQ(result, target_sharding); + } +} + +} // namespace +} // namespace spmd +} // namespace xla diff --git a/xla/service/spmd/spmd_prepare.cc b/xla/service/spmd/spmd_prepare.cc index 4dd0f42628556..f5863c8c9d0ad 100644 --- a/xla/service/spmd/spmd_prepare.cc +++ b/xla/service/spmd/spmd_prepare.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,8 +31,8 @@ namespace xla { namespace spmd { namespace { -StatusOr ProcessScatter(HloInstruction* hlo, - const CallGraph& call_graph) { +absl::StatusOr ProcessScatter(HloInstruction* hlo, + const CallGraph& call_graph) { if (hlo->opcode() != HloOpcode::kScatter) { return false; } @@ -153,8 +153,8 @@ StatusOr ProcessScatter(HloInstruction* hlo, return true; } -StatusOr RunOnComputation(HloComputation* computation, - const CallGraph& call_graph) { +absl::StatusOr RunOnComputation(HloComputation* computation, + const CallGraph& call_graph) { bool changed = false; for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (!hlo->has_sharding()) { @@ -170,7 +170,7 @@ StatusOr RunOnComputation(HloComputation* computation, } } // namespace -StatusOr SpmdPrepare::Run( +absl::StatusOr SpmdPrepare::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/spmd/spmd_prepare.h b/xla/service/spmd/spmd_prepare.h index 5fd69173c2978..c5232f647199e 100644 --- a/xla/service/spmd/spmd_prepare.h +++ b/xla/service/spmd/spmd_prepare.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class SpmdPrepare : public HloModulePass { absl::string_view name() const override { return "spmd-prepare"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/spmd/spmd_prepare_test.cc b/xla/service/spmd/spmd_prepare_test.cc index a6895948c30a4..7ae4108b7e3ef 100644 --- a/xla/service/spmd/spmd_prepare_test.cc +++ b/xla/service/spmd/spmd_prepare_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,14 +35,14 @@ namespace op = xla::testing::opcode_matchers; class SpmdPrepareTest : public HloTestBase { public: - StatusOr> RunPass( + absl::StatusOr> RunPass( absl::string_view hlo_module, int64_t distance_threshold = 100) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( hlo_module, GetModuleConfigForTest())); HloPassPipeline pipeline("spmd-prepare"); pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } }; diff --git a/xla/service/spmd/stateful_rng_spmd_partitioner.cc b/xla/service/spmd/stateful_rng_spmd_partitioner.cc index c4789c60bd345..2d575fa6a627b 100644 --- a/xla/service/spmd/stateful_rng_spmd_partitioner.cc +++ b/xla/service/spmd/stateful_rng_spmd_partitioner.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/spmd/stateful_rng_spmd_partitioner.h b/xla/service/spmd/stateful_rng_spmd_partitioner.h index 8c51f549bc087..5d9170e67e627 100644 --- a/xla/service/spmd/stateful_rng_spmd_partitioner.h +++ b/xla/service/spmd/stateful_rng_spmd_partitioner.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,10 +46,12 @@ class StatefulRngSpmdPartitioningVisitor class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { public: StatefulRngSpmdPartitioner(int64_t num_partitions, int64_t num_replicas, - int64_t threshold_for_windowed_einsum_mib = 100000) + int64_t threshold_for_windowed_einsum_mib = 100000, + bool windowed_einsum_use_multiple_streams = false) : spmd::SpmdPartitioner( num_partitions, num_replicas, - GetSpmdPartitionerOptions(threshold_for_windowed_einsum_mib)) {} + GetSpmdPartitionerOptions(threshold_for_windowed_einsum_mib, + windowed_einsum_use_multiple_streams)) {} protected: std::unique_ptr CreateVisitor( @@ -67,11 +69,13 @@ class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { private: static spmd::SpmdPartitionerOptions GetSpmdPartitionerOptions( - int64_t threshold_for_windowed_einsum_mib) { + int64_t threshold_for_windowed_einsum_mib, + bool windowed_einsum_use_multiple_streams = false) { spmd::SpmdPartitionerOptions options; options.allow_module_signature_change = true; options.threshold_for_windowed_einsum_mib = threshold_for_windowed_einsum_mib; + options.unroll_windowed_einsum = windowed_einsum_use_multiple_streams; return options; } }; diff --git a/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc b/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc index 7907e533b7861..d301f0ccdcdd7 100644 --- a/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc +++ b/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ namespace { class StatefulRngSpmdPartitionerTest : public HloTestBase { public: - StatusOr> PartitionComputation( + absl::StatusOr> PartitionComputation( absl::string_view hlo_module, int64_t num_partitions, std::function add_passes = nullptr) { TF_ASSIGN_OR_RETURN( @@ -56,7 +56,7 @@ class StatefulRngSpmdPartitionerTest : public HloTestBase { pass.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } void VerifyNoAllReduce(HloModule *module) { @@ -120,11 +120,15 @@ TEST_F(StatefulRngSpmdPartitionerTest, VerifyThresholdSetCorrectly) { auto debug_options = HloTestBase::GetDebugOptionsForTest(); int64_t threshold = 400; debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(threshold); + debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true); + StatefulRngSpmdPartitioner rng_spmd_partitioner( /*num_partitions=*/2, /*num_replicas*/ 1, - debug_options.xla_gpu_threshold_for_windowed_einsum_mib()); + debug_options.xla_gpu_threshold_for_windowed_einsum_mib(), + debug_options.xla_gpu_multi_streamed_windowed_einsum()); EXPECT_EQ(rng_spmd_partitioner.options().threshold_for_windowed_einsum_mib, threshold); + EXPECT_EQ(rng_spmd_partitioner.options().unroll_windowed_einsum, true); } } // namespace } // namespace spmd diff --git a/xla/service/spmd/whole_graph_manual_pass.cc b/xla/service/spmd/whole_graph_manual_pass.cc index d36ba063b50f0..cf26fd21e397a 100644 --- a/xla/service/spmd/whole_graph_manual_pass.cc +++ b/xla/service/spmd/whole_graph_manual_pass.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ bool ShouldClearInstruction(HloInstruction* inst) { !inst->HasSideEffectNoRecurse(); } -StatusOr RunOnComputation(HloComputation* computation) { +absl::StatusOr RunOnComputation(HloComputation* computation) { bool changed = false; for (HloInstruction* inst : computation->instructions()) { if (ShouldClearInstruction(inst)) { @@ -59,7 +59,7 @@ StatusOr RunOnComputation(HloComputation* computation) { } // namespace -StatusOr WholeGraphManualPass::Run( +absl::StatusOr WholeGraphManualPass::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/spmd/whole_graph_manual_pass.h b/xla/service/spmd/whole_graph_manual_pass.h index a453e2bde766e..93707a59db608 100644 --- a/xla/service/spmd/whole_graph_manual_pass.h +++ b/xla/service/spmd/whole_graph_manual_pass.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ class WholeGraphManualPass : public HloModulePass { absl::string_view name() const override { return "whole-graph-manual-pass"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/spmd/whole_graph_manual_pass_test.cc b/xla/service/spmd/whole_graph_manual_pass_test.cc index ac656ade04e78..db4a7838edb9d 100644 --- a/xla/service/spmd/whole_graph_manual_pass_test.cc +++ b/xla/service/spmd/whole_graph_manual_pass_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,8 @@ namespace op = xla::testing::opcode_matchers; class WholeGraphManualPassTest : public HloTestBase { public: - StatusOr> RunPass(absl::string_view hlo_module) { + absl::StatusOr> RunPass( + absl::string_view hlo_module) { TF_ASSIGN_OR_RETURN( auto module, ParseAndReturnVerifiedModule( @@ -42,7 +43,7 @@ class WholeGraphManualPassTest : public HloTestBase { HloPassPipeline pipeline("whole-graph-manual-pass"); pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status()); - return StatusOr>(std::move(module)); + return absl::StatusOr>(std::move(module)); } Status RunPassOnModule(HloModule* module, int64_t distance_threshold = 100) { HloPassPipeline pipeline("all-gather-cse"); diff --git a/xla/service/stable_sort_expander.cc b/xla/service/stable_sort_expander.cc index 40bd4c3f47a2a..7a92ab7656c80 100644 --- a/xla/service/stable_sort_expander.cc +++ b/xla/service/stable_sort_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace xla { // If no matching iota operand is found, a iota operand is added to Sort. The // comparison computation is adjusted to break ties using the values from the // iota operand. -StatusOr StableSortExpander::ExpandInstruction( +absl::StatusOr StableSortExpander::ExpandInstruction( HloInstruction* instruction) { auto* sort = Cast(instruction); HloComputation* computation = sort->parent(); diff --git a/xla/service/stable_sort_expander.h b/xla/service/stable_sort_expander.h index e5896a8e2ab49..b5213b22ab466 100644 --- a/xla/service/stable_sort_expander.h +++ b/xla/service/stable_sort_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class StableSortExpander : public OpExpanderPass { private: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; }; diff --git a/xla/service/stable_sort_expander_test.cc b/xla/service/stable_sort_expander_test.cc index abb1b35198827..f2b5c41eee4f1 100644 --- a/xla/service/stable_sort_expander_test.cc +++ b/xla/service/stable_sort_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/stochastic_convert_decomposer.cc b/xla/service/stochastic_convert_decomposer.cc index 83c7928986be6..8829b42ed8645 100644 --- a/xla/service/stochastic_convert_decomposer.cc +++ b/xla/service/stochastic_convert_decomposer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -131,12 +131,12 @@ Status DecomposeStochasticConvert(HloComputation* comp, } // TODO(b/232442915): Add support for converting to floats. - return InternalError("Unsupported stochastic convert: from %s to %s", + return Internal("Unsupported stochastic convert: from %s to %s", PrimitiveType_Name(from_type), PrimitiveType_Name(to_type)); } -StatusOr StochasticConvertDecomposer::Run( +absl::StatusOr StochasticConvertDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/stochastic_convert_decomposer.h b/xla/service/stochastic_convert_decomposer.h index 1292e625caa5a..a51421b438b28 100644 --- a/xla/service/stochastic_convert_decomposer.h +++ b/xla/service/stochastic_convert_decomposer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ class StochasticConvertDecomposer : public HloModulePass { return "stochastic_convert_decomposer"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/stochastic_convert_decomposer_test.cc b/xla/service/stochastic_convert_decomposer_test.cc index 246ea404c537f..2dd3564884d76 100644 --- a/xla/service/stochastic_convert_decomposer_test.cc +++ b/xla/service/stochastic_convert_decomposer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/stream_pool.cc b/xla/service/stream_pool.cc index 4e9ea0ec5343e..54f5c773e7613 100644 --- a/xla/service/stream_pool.cc +++ b/xla/service/stream_pool.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,11 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" + namespace xla { -StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor, - se::StreamPriority priority) { +StreamPool::Ptr StreamPool::BorrowStream(se::StreamPriority priority) { std::unique_ptr stream; { @@ -34,13 +35,13 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor, stream = std::move(streams_with_pri_[priority].back()); streams_with_pri_[priority].pop_back(); if (stream->ok()) { - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool reusing existing stream with priority: " - << se::StreamPriorityToString(priority); + VLOG(1) << absl::StrFormat( + "StreamPool reusing existing stream (%p) with priority: %s", + stream.get(), se::StreamPriorityToString(priority)); } else { - VLOG(1) << stream->DebugStreamPointers() - << " stream was not ok, StreamPool deleting with priority: " - << se::StreamPriorityToString(priority); + VLOG(1) << absl::StrFormat( + "Stream (%p) was not ok, deleting with : %s", stream.get(), + se::StreamPriorityToString(priority)); stream = nullptr; } } @@ -49,13 +50,10 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor, if (!stream) { // Create a new stream. - stream = std::make_unique(executor); - stream->SetPriority(priority); - VLOG(1) << "Set stream priority to: " - << se::StreamPriorityToString(priority); - stream->Init(); - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool created new stream"; + stream = executor_->CreateStream(priority).value(); + VLOG(1) << absl::StrFormat("Created new stream (%p) with priority = %s", + stream.get(), + se::StreamPriorityToString(priority)); } // Return the stream wrapped in Ptr, which has our special deleter semantics. @@ -65,8 +63,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor, void StreamPool::ReturnStream(se::Stream* stream) { if (stream->ok()) { - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool returning ok stream"; + VLOG(1) << absl::StrFormat("StreamPool returning ok stream (%p)", stream); absl::MutexLock lock(&mu_); auto priority = std::get(stream->priority()); streams_with_pri_[priority].emplace_back(stream); @@ -74,8 +71,7 @@ void StreamPool::ReturnStream(se::Stream* stream) { // If the stream has encountered any errors, all subsequent operations on it // will fail. So just delete the stream, and rely on new streams to be // created in the future. - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool deleting !ok stream"; + VLOG(1) << absl::StrFormat("StreamPool deleting !ok stream (%p)", stream); delete stream; } } diff --git a/xla/service/stream_pool.h b/xla/service/stream_pool.h index 32204daa3dcc7..1610071de6dcf 100644 --- a/xla/service/stream_pool.h +++ b/xla/service/stream_pool.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,15 +38,14 @@ class StreamPool { // stream to the pool on destruction. using Ptr = std::unique_ptr; - StreamPool() = default; + explicit StreamPool(se::StreamExecutor* executor) : executor_(executor) {} // Returns a pointer to a stream in the pool, creating a new stream // if none are available in the pool. The returned smart pointer // returns the stream to the pool on destruction. // // This method is thread-safe. - Ptr BorrowStream(se::StreamExecutor* executor, - se::StreamPriority priority = se::StreamPriority::Default); + Ptr BorrowStream(se::StreamPriority priority = se::StreamPriority::Default); private: // Puts a pointer to a stream back into the pool, leaving it free @@ -61,6 +60,7 @@ class StreamPool { std::unordered_map>> streams_with_pri_ ABSL_GUARDED_BY(mu_); + se::StreamExecutor* executor_; }; } // namespace xla diff --git a/xla/service/stream_pool_test.cc b/xla/service/stream_pool_test.cc index 7e57719bf330e..fd0a05e5d2f23 100644 --- a/xla/service/stream_pool_test.cc +++ b/xla/service/stream_pool_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" @@ -27,26 +28,29 @@ class StreamPoolTest : public ::testing::Test { protected: std::unique_ptr NewStreamExecutor() { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } }; -TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; } +TEST_F(StreamPoolTest, EmptyPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool(executor.get()); +} TEST_F(StreamPoolTest, OneStreamPool) { std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; + StreamPool pool(executor.get()); // Borrow and return a stream. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream1 = pool.BorrowStream(); se::Stream* stream1_ptr = stream1.get(); EXPECT_TRUE(stream1->ok()); stream1 = nullptr; // Borrow and return another stream. - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream2 = pool.BorrowStream(); se::Stream* stream2_ptr = stream2.get(); EXPECT_TRUE(stream2->ok()); stream2 = nullptr; @@ -58,13 +62,13 @@ TEST_F(StreamPoolTest, OneStreamPool) { TEST_F(StreamPoolTest, TwoStreamPool) { std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; + StreamPool pool(executor.get()); // Borrow two streams. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream1 = pool.BorrowStream(); se::Stream* stream1_ptr = stream1.get(); EXPECT_TRUE(stream1->ok()); - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream2 = pool.BorrowStream(); se::Stream* stream2_ptr = stream2.get(); EXPECT_TRUE(stream2->ok()); @@ -74,7 +78,7 @@ TEST_F(StreamPoolTest, TwoStreamPool) { // Return stream1 and borrow stream3. stream1 = nullptr; - StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream3 = pool.BorrowStream(); se::Stream* stream3_ptr = stream3.get(); EXPECT_TRUE(stream3->ok()); @@ -84,7 +88,7 @@ TEST_F(StreamPoolTest, TwoStreamPool) { // Return stream2, and borrow stream4. stream2 = nullptr; - StreamPool::Ptr stream4 = pool.BorrowStream(executor.get()); + StreamPool::Ptr stream4 = pool.BorrowStream(); se::Stream* stream4_ptr = stream4.get(); EXPECT_TRUE(stream4->ok()); @@ -93,78 +97,5 @@ TEST_F(StreamPoolTest, TwoStreamPool) { EXPECT_NE(stream3_ptr, stream4_ptr); } -TEST_F(StreamPoolTest, BadStreamDiscarded) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; - - // Borrow a stream. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream1->ok()); - - // Force an error on the stream; here we call a method that requires - // DNN support, which we know the Host platform doesn't support. - stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(stream1->ok()); - - // Return stream1 and borrow stream2. - stream1 = nullptr; - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); - se::Stream* stream2_ptr = stream2.get(); - EXPECT_TRUE(stream2->ok()); - - // The underlying streams should be different. They would have been - // the same, but since we forced an error on stream1, it cannot be - // put back into the pool. Sadly we can't just check: - // EXPECT_NE(stream1_ptr, stream2_ptr); - // - // The above should hold logically, but it may fail if the new - // stream instance allocated for stream2 happens to reside in the - // same memory address as stream1, which has been deleted. - // - // The check that stream2->ok() serves as a good-enough check. - - // Return stream2 and borrow stream3. The previous error on stream1 - // has no effect on these streams, and they are the same. - stream2 = nullptr; - StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); - se::Stream* stream3_ptr = stream3.get(); - EXPECT_TRUE(stream3->ok()); - EXPECT_EQ(stream2_ptr, stream3_ptr); -} - -TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool; - - // Borrow a stream. - StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream1->ok()); - - // Return the stream, but hold a handle to it. - se::Stream* stream1_ptr = stream1.get(); - stream1 = nullptr; - - // Now stream1 is back in the pool, force an error on the stream. Here we call - // a method that requires DNN support, which we know the Host platform doesn't - // support. - stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(stream1_ptr->ok()); - - // Borrow stream2. - StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); - EXPECT_TRUE(stream2->ok()); - - // The underlying streams should be different. They would have been - // the same, but since we forced an error on stream1, it cannot be - // put back into the pool. Sadly we can't just check: - // EXPECT_NE(stream1_ptr, stream2_ptr); - // - // The above should hold logically, but it may fail if the new - // stream instance allocated for stream2 happens to reside in the - // same memory address as stream1, which has been deleted. - // - // The check that stream2->ok() serves as a good-enough check. -} - } // namespace } // namespace xla diff --git a/xla/service/sub_byte_normalization.cc b/xla/service/sub_byte_normalization.cc index 60a8c79d83f7f..b113de71c0780 100644 --- a/xla/service/sub_byte_normalization.cc +++ b/xla/service/sub_byte_normalization.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ bool ProcessInputOrOutputLayout(ShapeLayout* shape_layout, } // namespace -StatusOr SubByteNormalization::Run( +absl::StatusOr SubByteNormalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -93,8 +93,11 @@ StatusOr SubByteNormalization::Run( changed |= UpdateShape(shape, mode_); return OkStatus(); }); - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { + for (HloComputation* computation : module->computations()) { + // We rewrite all computations instead of non-fusion computations, despite + // element_size_in_bits within fusions being meaningless, because HloVerfier + // checks for the correct use of element_size_in_bits even in fusion + // computations. TF_RETURN_IF_ERROR(computation->Accept(&visitor)); } auto* computation_layout = module->mutable_entry_computation_layout(); diff --git a/xla/service/sub_byte_normalization.h b/xla/service/sub_byte_normalization.h index 39d4b608a1d70..ea70be74a19b4 100644 --- a/xla/service/sub_byte_normalization.h +++ b/xla/service/sub_byte_normalization.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class SubByteNormalization : public HloModulePass { } } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/symbol_repository.h b/xla/service/symbol_repository.h index dd88a8b777dd9..6cce3bc497ab1 100644 --- a/xla/service/symbol_repository.h +++ b/xla/service/symbol_repository.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -104,9 +104,10 @@ inline SymbolRepositoryRegistry& GetGlobalSymbolRepositoryRegistry() { // Entry points start here. -inline StatusOr> LookupSymbolInRepository( - absl::string_view repository, absl::string_view symbol_reference, - BackendType backend) { +inline absl::StatusOr> +LookupSymbolInRepository(absl::string_view repository, + absl::string_view symbol_reference, + BackendType backend) { if (SymbolRepository* repo = GetGlobalSymbolRepositoryRegistry().repo(repository); repo != nullptr) { diff --git a/xla/service/test_compilation_environment.proto b/xla/service/test_compilation_environment.proto index 008560debe6cd..8aaaa61c9f27d 100644 --- a/xla/service/test_compilation_environment.proto +++ b/xla/service/test_compilation_environment.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/time_utils.cc b/xla/service/time_utils.cc index 227193f2c6ebf..677cccc6c58f8 100644 --- a/xla/service/time_utils.cc +++ b/xla/service/time_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/time_utils.h b/xla/service/time_utils.h index c3ea7099d634c..e632c0653c819 100644 --- a/xla/service/time_utils.h +++ b/xla/service/time_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc index 784d5f3554100..257163c0faf7a 100644 --- a/xla/service/topk_rewriter.cc +++ b/xla/service/topk_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include +#include #include #include #include @@ -29,8 +30,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -39,7 +42,7 @@ namespace xla { namespace m = match; // TODO(cheshire): Avoid duplication w/ cudnn_vectorize_convolutions. -static StatusOr BuilderToHloComputation( +static absl::StatusOr BuilderToHloComputation( XlaComputation& comp, HloComputation* sibling_computation) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); HloModuleConfig config(program_shape); @@ -120,6 +123,36 @@ static bool IsNanSafeGt(HloComputation* comp) { param_s32); }; + auto match_generic_iec559 = [](int64_t parameter_number, + PrimitiveType fp_type, + PrimitiveType int_type) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(fp_type)); + auto signed_value = m::BitcastConvert(param).WithShape( + m::Shape().WithElementType(int_type)); + int64_t bit_width = primitive_util::BitWidth(fp_type); + auto max_value = m::ConstantScalar(LsbMask(bit_width - 1)); + auto flipped_value = m::XorAnyOrder(max_value, signed_value); + auto is_negative = m::Lt(signed_value, m::ConstantScalar(0)); + return m::Select(is_negative, flipped_value, signed_value); + }; + + auto match_generic_iec559_with_convert = + [](int64_t parameter_number, PrimitiveType param_type, + PrimitiveType fp_type, PrimitiveType int_type) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(param_type)); + auto convert = + m::Convert(param).WithShape(m::Shape().WithElementType(fp_type)); + auto signed_value = m::BitcastConvert(convert).WithShape( + m::Shape().WithElementType(int_type)); + int64_t bit_width = primitive_util::BitWidth(fp_type); + auto max_value = m::ConstantScalar(LsbMask(bit_width - 1)); + auto flipped_value = m::XorAnyOrder(max_value, signed_value); + auto is_negative = m::Lt(signed_value, m::ConstantScalar(0)); + return m::Select(is_negative, flipped_value, signed_value); + }; + auto match_s32 = [](int64_t parameter_number) { auto param = m::Parameter(parameter_number) .WithShape(m::Shape().WithElementType(S32)); @@ -155,6 +188,15 @@ static bool IsNanSafeGt(HloComputation* comp) { }; return Match(comp->root_instruction(), + m::Gt(match_generic_iec559(0, F32, S32), + match_generic_iec559(1, F32, S32))) || + Match(comp->root_instruction(), + m::Gt(match_generic_iec559(0, BF16, S16), + match_generic_iec559(1, BF16, S16))) || + Match(comp->root_instruction(), + m::Gt(match_generic_iec559_with_convert(0, BF16, F32, S32), + match_generic_iec559_with_convert(1, BF16, F32, S32))) || + Match(comp->root_instruction(), m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) || Match(comp->root_instruction(), m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))) || @@ -332,7 +374,7 @@ TopKCustomCall CreateTopKCustomCall(HloInstruction* input, return {topk, value_gte, index_gte}; } -StatusOr TopkRewriter::TransformPatternToCustomCall( +absl::StatusOr TopkRewriter::TransformPatternToCustomCall( HloInstruction* inst) { // Check if sort is in TopK. std::optional k = SortIsInTopK(inst); @@ -386,7 +428,7 @@ StatusOr TopkRewriter::TransformPatternToCustomCall( return topkcc.topk; } -StatusOr TopkRewriter::TransformToCustomCall( +absl::StatusOr TopkRewriter::TransformToCustomCall( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -403,7 +445,7 @@ StatusOr TopkRewriter::TransformToCustomCall( return changed; } -StatusOr TopkRewriter::Run( +absl::StatusOr TopkRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; @@ -440,11 +482,20 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { } private: - StatusOr CreateVariadicComparator(HloInstruction* inst) { + bool HasSingleUserReadingOnlyTheValueOutput(HloInstruction* inst) { + return inst->user_count() == 1 && inst->users().front()->tuple_index() == 0; + } + + absl::StatusOr CreateVariadicComparator( + HloInstruction* inst) { HloTopKInstruction* topk = DynCast(inst); XlaBuilder b(absl::StrCat("comparator_", topk->name())); std::vector ptypes = { - topk->operand(0)->shape().element_type(), PrimitiveType::S32}; + topk->operand(0)->shape().element_type()}; + + if (!HasSingleUserReadingOnlyTheValueOutput(inst)) { + ptypes.emplace_back(PrimitiveType::S32); + } XlaComputation comparison = topk->largest() ? CreateScalarGtComputation(ptypes, &b) @@ -474,10 +525,10 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { }; CHECK_NE(variadic_comparator, nullptr); // If only the topk values are necessary, skip the iota. - if (call->user_count() == 1 && call->users().front()->tuple_index() == 0 && - call->to_apply()->num_parameters() == 2) { + if (HasSingleUserReadingOnlyTheValueOutput(call) && + variadic_comparator->num_parameters() == 2) { HloInstruction* sort = comp->AddInstruction(HloInstruction::CreateSort( - {input->shape()}, sort_dimension, {input}, call->to_apply(), + {input->shape()}, sort_dimension, {input}, variadic_comparator, /*is_stable=*/true)); TF_RETURN_IF_ERROR(ReplaceInstruction( call->users().front(), @@ -504,7 +555,7 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { HloPredicate should_decompose_; }; -StatusOr TopkDecomposer::Run( +absl::StatusOr TopkDecomposer::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return TopkDecomposerVisitor(should_decompose_) diff --git a/xla/service/topk_rewriter.h b/xla/service/topk_rewriter.h index f32aa83166fee..8400c505053a6 100644 --- a/xla/service/topk_rewriter.h +++ b/xla/service/topk_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class TopkRewriter : public HloModulePass { absl::string_view name() const override { return "topk-rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -50,7 +50,7 @@ class TopkRewriter : public HloModulePass { std::optional SortIsInTopK(HloInstruction* inst); // Transform to CustomCall. - StatusOr TransformToCustomCall( + absl::StatusOr TransformToCustomCall( HloModule* module, const absl::flat_hash_set& execution_threads); @@ -62,7 +62,8 @@ class TopkRewriter : public HloModulePass { // Matches the input to the sort+iota+slice pattern and converts to custom // call if profitable. Returns the custom call if one was created. - StatusOr TransformPatternToCustomCall(HloInstruction* inst); + absl::StatusOr TransformPatternToCustomCall( + HloInstruction* inst); }; class TopkDecomposer : public HloModulePass { @@ -73,7 +74,7 @@ class TopkDecomposer : public HloModulePass { : should_decompose_(should_decompose) {} using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc index 655a728cc8b0e..a1dbb7a5f59b0 100644 --- a/xla/service/topk_rewriter_test.cc +++ b/xla/service/topk_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/tpu_computation_placer.cc b/xla/service/tpu_computation_placer.cc index d1609ed46febe..23e849ca13536 100644 --- a/xla/service/tpu_computation_placer.cc +++ b/xla/service/tpu_computation_placer.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/tpu_computation_placer.h b/xla/service/tpu_computation_placer.h index a4b50d045319c..60345010683a2 100644 --- a/xla/service/tpu_computation_placer.h +++ b/xla/service/tpu_computation_placer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ namespace tpu { class TpuComputationPlacer : public xla::ComputationPlacer { public: template - using StatusOr = xla::StatusOr; + using StatusOr = absl::StatusOr; TpuComputationPlacer(); ~TpuComputationPlacer() override; diff --git a/xla/service/transfer_manager.cc b/xla/service/transfer_manager.cc index 2f9cd018e375f..4335351bc8db2 100644 --- a/xla/service/transfer_manager.cc +++ b/xla/service/transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,31 @@ limitations under the License. #include "xla/service/transfer_manager.h" +#include #include #include -#include #include +#include +#include "absl/base/const_init.h" #include "absl/cleanup/cleanup.h" -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" #include "xla/service/compiler.h" #include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/notification.h" - -using absl::StrCat; +#include "tsl/platform/statusor.h" namespace xla { @@ -45,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr TransferManager::TransferLiteralFromDevice( +absl::StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const TransferMetadata* transfer_metadata) { Literal literal(device_buffer.on_host_shape()); @@ -58,8 +66,8 @@ Status TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const MutableBorrowingLiteral& literal, const TransferMetadata* transfer_metadata) { - se::Stream* substream = stream->GetOrCreateSubStream(); - substream->ThenWaitFor(stream); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + TF_RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; Status ret; @@ -82,15 +90,15 @@ Status TransferManager::TransferLiteralToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - se::Stream* substream = stream->GetOrCreateSubStream(); - substream->ThenWaitFor(stream); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + TF_RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync( substream, literal, device_buffer, transfer_metadata)); return substream->BlockHostUntilDone(); } -StatusOr TransferManager::TransferArrayFromDevice( +absl::StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, const TransferMetadata* transfer_metadata) { TF_RET_CHECK(shape.IsArray()); @@ -111,8 +119,8 @@ Status TransferManager::TransferArrayToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - se::Stream* substream = stream->GetOrCreateSubStream(); - substream->ThenWaitFor(stream); + TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + TF_RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; TF_RETURN_IF_ERROR( TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata)); @@ -141,7 +149,8 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream, TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(stream->parent()->platform())); TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachElementWithStatus( - [&](const ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + [&](const ShapeIndex& index, + const se::DeviceMemoryBase& buffer) -> absl::Status { const Shape& buffer_shape = ShapeUtil::GetSubshape(*device_shape, index); if (buffer_shape.IsTuple()) { @@ -163,8 +172,7 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream, return InvalidArgument("Dynamic shape metadata size should not be 0"); } auto buffer_8 = se::DeviceMemory(buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + auto metadata_buffer = buffer_8.GetSlice(offset, metadata_size); TF_ASSIGN_OR_RETURN( auto metadata, TransferArrayFromDevice( @@ -194,7 +202,7 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream, (*managers)[platform_id].creation_function = creation_function; } -/* static */ StatusOr TransferManager::GetForPlatform( +/* static */ absl::StatusOr TransferManager::GetForPlatform( const se::Platform* platform) { absl::MutexLock lock(&TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); @@ -289,7 +297,7 @@ Status TransferManager::WriteRootTupleIndexTable( &device_memory); } -StatusOr TransferManager::AllocateScopedShapedBuffer( +absl::StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn) { if (!LayoutUtil::HasLayout(on_host_shape)) { @@ -324,7 +332,7 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( return std::move(shaped_buffer); } -StatusOr TransferManager::ChooseCompactLayoutForShape( +absl::StatusOr TransferManager::ChooseCompactLayoutForShape( const Shape& host_shape) const { return LayoutUtil::GetWithDefaultLayout(host_shape); } diff --git a/xla/service/transfer_manager.h b/xla/service/transfer_manager.h index 8b1b9b8bde1ac..b29a93ee6eb55 100644 --- a/xla/service/transfer_manager.h +++ b/xla/service/transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,24 @@ limitations under the License. #ifndef XLA_SERVICE_TRANSFER_MANAGER_H_ #define XLA_SERVICE_TRANSFER_MANAGER_H_ -#include -#include -#include +#include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" -#include "xla/service/executable.h" +#include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla { @@ -72,7 +77,7 @@ class TransferManager { // // Optionally caller can specify platform-specific transfer metadata that // tells the actual implementation to do something special. - StatusOr TransferLiteralFromDevice( + absl::StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const TransferMetadata* transfer_metadata = nullptr); @@ -166,7 +171,7 @@ class TransferManager { const se::DeviceMemoryBase& dest, const TransferMetadata* transfer_metadata = nullptr); - StatusOr TransferArrayFromDevice( + absl::StatusOr TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, const TransferMetadata* transfer_metadata = nullptr); @@ -223,7 +228,7 @@ class TransferManager { // devices that have tiled memory architectures. // The default implementation always picks a default (major-to-minor) layout. // Fails if 'shape' cannot be represented by the device. - virtual StatusOr ChooseCompactLayoutForShape( + virtual absl::StatusOr ChooseCompactLayoutForShape( const Shape& host_shape) const; // For the given shape, chooses a layout for infeed. The returned shape @@ -236,7 +241,7 @@ class TransferManager { // Allocates a ScopedShapedBuffer which can hold data with the given on-host // shape. The on-device shape may be different as indicated by // HostShapeToDeviceShape. - StatusOr AllocateScopedShapedBuffer( + absl::StatusOr AllocateScopedShapedBuffer( const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn = nullptr); @@ -283,7 +288,7 @@ class TransferManager { // Returns the transfer manager singleton pointer if it is available for the // given platform, or an error status if it is not. - static StatusOr GetForPlatform( + static absl::StatusOr GetForPlatform( const se::Platform* platform); // Writes the given device-memory pointers in 'elements' to the given region @@ -293,6 +298,16 @@ class TransferManager { se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; + // Returns whether subbyte types (types less than 1 byte, e.g. U4) should + // have multiple values packed into a single byte on the device. Subbyte + // bytes are never packed on the host. By default, returns false, so a byte + // can only hold one value, but subclasses can override this. + // + // If overridden to return true, subclasses should pack and unpack in their + // overridden implementations of TransferLiteralToDeviceAsync and + // TransferLiteralFromDevice respectively. + virtual bool PackSubbyteTypes() const { return false; } + private: // The mutex that guards the platform-to-transfer manager map. static absl::Mutex platform_transfer_manager_mutex_; diff --git a/xla/service/transpose_folding.cc b/xla/service/transpose_folding.cc index dcfbf6883728f..b954f307f82d4 100644 --- a/xla/service/transpose_folding.cc +++ b/xla/service/transpose_folding.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -188,7 +188,7 @@ TransposeFolding::TransposeFolding( std::move(dot_can_fold_transpose_operand)), transposable_conv_operands_(std::move(transposable_conv_operands)) {} -StatusOr TransposeFolding::Run( +absl::StatusOr TransposeFolding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Modifying the graph while traversing is dangerous, so we find all folding @@ -249,8 +249,9 @@ StatusOr TransposeFolding::Run( return changed; } -/*static*/ StatusOr TransposeFolding::IsRowColumnTransposeDotOperand( - const HloInstruction& dot, int64_t operand_idx) { +/*static*/ absl::StatusOr +TransposeFolding::IsRowColumnTransposeDotOperand(const HloInstruction& dot, + int64_t operand_idx) { TF_RET_CHECK(dot.opcode() == HloOpcode::kDot); TF_RET_CHECK(dot.operand_count() > operand_idx); diff --git a/xla/service/transpose_folding.h b/xla/service/transpose_folding.h index c83c959f9201d..eacf99ae3ced4 100644 --- a/xla/service/transpose_folding.h +++ b/xla/service/transpose_folding.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ class TransposeFolding : public HloModulePass { using TransposableConvOperandsFn = std::function; - using CanFoldTransposeOperand = std::function( + using CanFoldTransposeOperand = std::function( const HloInstruction&, int64_t /*operand_idx*/)>; // Helper function to explicitly not fold transposes. @@ -63,11 +63,11 @@ class TransposeFolding : public HloModulePass { absl::string_view name() const override { return "transpose-folding"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; - static StatusOr IsRowColumnTransposeDotOperand( + static absl::StatusOr IsRowColumnTransposeDotOperand( const HloInstruction& dot, int64_t operand_idx); private: diff --git a/xla/service/transpose_folding_test.cc b/xla/service/transpose_folding_test.cc index d329fdf12d052..1dd4f0c361bad 100644 --- a/xla/service/transpose_folding_test.cc +++ b/xla/service/transpose_folding_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -246,7 +246,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { dim->set_size( transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } - StatusOr conv_shape = ShapeInference::InferConvolveShape( + absl::StatusOr conv_shape = ShapeInference::InferConvolveShape( x->shape(), transpose_y->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); @@ -304,7 +304,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { dim->set_size( transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } - StatusOr conv_shape = ShapeInference::InferConvolveShape( + absl::StatusOr conv_shape = ShapeInference::InferConvolveShape( x->shape(), transpose_y->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); @@ -367,7 +367,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_stride(1); dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } - StatusOr conv_shape = ShapeInference::InferConvolveShape( + absl::StatusOr conv_shape = ShapeInference::InferConvolveShape( transpose_x->shape(), y->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); @@ -436,7 +436,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_stride(1); dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } - StatusOr conv_shape = ShapeInference::InferConvolveShape( + absl::StatusOr conv_shape = ShapeInference::InferConvolveShape( transpose_x->shape(), y->shape(), /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); diff --git a/xla/service/tree_reduction_rewriter.cc b/xla/service/tree_reduction_rewriter.cc index a294285aec396..1c505358648bd 100644 --- a/xla/service/tree_reduction_rewriter.cc +++ b/xla/service/tree_reduction_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -109,7 +109,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { int64_t reduce_window_size_; }; -StatusOr TreeReductionRewriter::Run( +absl::StatusOr TreeReductionRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { ReductionRewriterVisitor visitor(reduce_window_size_); diff --git a/xla/service/tree_reduction_rewriter.h b/xla/service/tree_reduction_rewriter.h index f66d113ca47e1..d1284d3041703 100644 --- a/xla/service/tree_reduction_rewriter.h +++ b/xla/service/tree_reduction_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,7 +48,7 @@ class TreeReductionRewriter : public HloModulePass { absl::string_view name() const override { return "tree_reduction_rewriter"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/triangular_solve_expander.cc b/xla/service/triangular_solve_expander.cc index da54b04253628..80a3e2c655726 100644 --- a/xla/service/triangular_solve_expander.cc +++ b/xla/service/triangular_solve_expander.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ namespace { // Get the diagonal blocks of the coefficient matrix XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); int ndims = shape.rank(); int64_t n = ShapeUtil::GetDimension(shape, -1); @@ -129,7 +129,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, bool transpose_a, bool conjugate_a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); int64_t block_size = ShapeUtil::GetDimension(blocks_shape, -1); @@ -238,7 +238,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( XlaOp diag_blocks, bool lower_triangular, PrecisionConfig::Precision precision) { XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); @@ -366,7 +366,7 @@ XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks( bool conjugate_a, bool unit_diagonal, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t ndims = a_shape.rank(); int64_t k = ShapeUtil::GetDimension(a_shape, -1); @@ -410,7 +410,7 @@ XlaOp TriangularSolveExpander::SolveDirectly( bool conjugate_a, bool unit_diagonal, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); int64_t m = ShapeUtil::GetDimension(b_shape, -2); @@ -467,7 +467,7 @@ XlaOp TriangularSolveExpander::BuildTriangularSolve( bool conjugate_a, bool unit_diagonal, int64_t block_size, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); if (a_shape.rank() != b_shape.rank()) { @@ -551,7 +551,7 @@ bool TriangularSolveExpander::InstructionMatchesPattern( return instruction->opcode() == HloOpcode::kTriangularSolve; } -StatusOr TriangularSolveExpander::ExpandInstruction( +absl::StatusOr TriangularSolveExpander::ExpandInstruction( HloInstruction* instruction) { const TriangularSolveOptions& options = instruction->triangular_solve_options(); diff --git a/xla/service/triangular_solve_expander.h b/xla/service/triangular_solve_expander.h index 1767e1b57d3da..0ccbcf1cf7cea 100644 --- a/xla/service/triangular_solve_expander.h +++ b/xla/service/triangular_solve_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class TriangularSolveExpander : public OpExpanderPass { bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; // Performs a triangular solve using an algorithm from MAGMA, which inverts diff --git a/xla/service/triangular_solve_expander_test.cc b/xla/service/triangular_solve_expander_test.cc index 566f14e426d1d..777f1258eb1ce 100644 --- a/xla/service/triangular_solve_expander_test.cc +++ b/xla/service/triangular_solve_expander_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/tuple_points_to_analysis.cc b/xla/service/tuple_points_to_analysis.cc index 612fb488b2605..c90e8836a12ff 100644 --- a/xla/service/tuple_points_to_analysis.cc +++ b/xla/service/tuple_points_to_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -137,7 +137,7 @@ void GatherFusionInstructions( } // namespace -/* static */ StatusOr> +/* static */ absl::StatusOr> TuplePointsToAnalysis::Run(const HloModule* module) { auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); std::unique_ptr analysis(new TuplePointsToAnalysis( @@ -326,7 +326,7 @@ Status TuplePointsToAnalysis::HandleAsyncStart(HloInstruction* async_start) { [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) { if (target_index.size() >= 2 && target_index.front() == 0) { const PointsToSet& operand_points_to_set = - GetPointsToSet(async_start->operand(target_index.at(1))); + GetPointsToSet(async_start->operand(target_index[1])); ShapeIndex source_index(target_index.begin() + 2, target_index.end()); *buffers = operand_points_to_set.element(source_index); for (HloInstruction* tuple : @@ -632,7 +632,7 @@ const LogicalBuffer& TuplePointsToAnalysis::GetBuffer( return logical_buffer_analysis_->GetBuffer(id); } -StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( +absl::StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const { const auto& buffers = GetPointsToSet(instruction).element(index); if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { @@ -645,7 +645,7 @@ StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( const TuplePointsToAnalysis::BufferAliasVector& TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const { - return logical_buffer_aliases_.at(buffer.id()); + return logical_buffer_aliases_[buffer.id()]; } const TuplePointsToAnalysis::BufferDefinitionVector& @@ -719,7 +719,7 @@ std::string TuplePointsToAnalysis::ToString() const { absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); - for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { + for (const BufferAlias& alias : logical_buffer_aliases_[b->id()]) { absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } diff --git a/xla/service/tuple_points_to_analysis.h b/xla/service/tuple_points_to_analysis.h index 9985b2ae17c72..2b58b1c2e97ac 100644 --- a/xla/service/tuple_points_to_analysis.h +++ b/xla/service/tuple_points_to_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -183,7 +183,7 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: // Runs points-to analysis on 'module'. - static StatusOr> Run( + static absl::StatusOr> Run( const HloModule* module); // Return the points-to set of an instruction. This describes the potential @@ -196,7 +196,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // Returns the buffer defined at the given instruction and index. An error is // returned if no buffer is defined at that point. - StatusOr GetBufferDefinedAt( + absl::StatusOr GetBufferDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const; // Return a (possibly empty) vector containing all BufferAliases of the given diff --git a/xla/service/tuple_points_to_analysis_test.cc b/xla/service/tuple_points_to_analysis_test.cc index 9a47d4395c5fb..3b0cf4d3b9feb 100644 --- a/xla/service/tuple_points_to_analysis_test.cc +++ b/xla/service/tuple_points_to_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/tuple_simplifier.cc b/xla/service/tuple_simplifier.cc index d47807fcd6868..3163e82161f4c 100644 --- a/xla/service/tuple_simplifier.cc +++ b/xla/service/tuple_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ namespace xla { TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : exclude_entry_computation_(exclude_entry_computation) {} -StatusOr TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) { +absl::StatusOr TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) { HloInstruction* top_tuple = nullptr; for (int64_t operand_number = 0; operand_number < tuple->operand_count(); ++operand_number) { @@ -53,7 +53,7 @@ StatusOr TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) { return changed; } -StatusOr TupleSimplifier::Run( +absl::StatusOr TupleSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { // Initially add all GTE and Tuple instructions to the worklist. diff --git a/xla/service/tuple_simplifier.h b/xla/service/tuple_simplifier.h index b5b029bec95cb..bd9d2d850c7ea 100644 --- a/xla/service/tuple_simplifier.h +++ b/xla/service/tuple_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class TupleSimplifier : public HloModulePass { // computation was changed. using HloPassInterface::Run; using HloPassInterface::RunOnModuleGroup; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -59,7 +59,7 @@ class TupleSimplifier : public HloModulePass { // | // Tuple // - StatusOr RemoveWholeTuple(HloInstruction* tuple); + absl::StatusOr RemoveWholeTuple(HloInstruction* tuple); }; } // namespace xla diff --git a/xla/service/tuple_simplifier_test.cc b/xla/service/tuple_simplifier_test.cc index 167584b82c6e9..5ee4d4929641e 100644 --- a/xla/service/tuple_simplifier_test.cc +++ b/xla/service/tuple_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/tuple_util.cc b/xla/service/tuple_util.cc index 01f523cad6a67..2ca0b08bb2d60 100644 --- a/xla/service/tuple_util.cc +++ b/xla/service/tuple_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ namespace xla { HloInstruction::CreateTuple(tuple_elements)); } -/*static*/ StatusOr TupleUtil::ReplaceTupleWith( +/*static*/ absl::StatusOr TupleUtil::ReplaceTupleWith( HloInstruction* new_instruction, HloInstruction* tuple, ShapeIndex shape_index, bool insert_bitcast_if_different_shape) { const Shape& tuple_shape = tuple->shape(); diff --git a/xla/service/tuple_util.h b/xla/service/tuple_util.h index 5fd434fca1596..95c5b44e19d9c 100644 --- a/xla/service/tuple_util.h +++ b/xla/service/tuple_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ class TupleUtil { // new_instruction. If the replacement instruction has a different shape than // the old one, we insert a bitcast if insert_bitcast_if_different_shape is // set to true. - static StatusOr ReplaceTupleWith( + static absl::StatusOr ReplaceTupleWith( HloInstruction* new_instruction, HloInstruction* tuple, ShapeIndex shape_index, bool insert_bitcast_if_different_shape = true); diff --git a/xla/service/tuple_util_test.cc b/xla/service/tuple_util_test.cc index fcf66535d098c..2956e6edfbac8 100644 --- a/xla/service/tuple_util_test.cc +++ b/xla/service/tuple_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/value_range.cc b/xla/service/value_range.cc index bdf81daa7594a..178ebcb33f31d 100644 --- a/xla/service/value_range.cc +++ b/xla/service/value_range.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/value_range.h b/xla/service/value_range.h index abc471e2f8d48..daf55ce2331e7 100644 --- a/xla/service/value_range.h +++ b/xla/service/value_range.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/value_range_test.cc b/xla/service/value_range_test.cc index fabbb35880bfb..5fc8cd48dcfec 100644 --- a/xla/service/value_range_test.cc +++ b/xla/service/value_range_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/while_loop_all_reduce_code_motion.cc b/xla/service/while_loop_all_reduce_code_motion.cc index a24008014915e..19fa63daca07b 100644 --- a/xla/service/while_loop_all_reduce_code_motion.cc +++ b/xla/service/while_loop_all_reduce_code_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -932,7 +932,7 @@ Status AddSinkedAllReducesAndReplaceWhile( } // namespace -StatusOr WhileLoopAllReduceCodeMotion::Run( +absl::StatusOr WhileLoopAllReduceCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool is_changed = false; diff --git a/xla/service/while_loop_all_reduce_code_motion.h b/xla/service/while_loop_all_reduce_code_motion.h index 98903c4c5efc9..141d51f407275 100644 --- a/xla/service/while_loop_all_reduce_code_motion.h +++ b/xla/service/while_loop_all_reduce_code_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,7 +52,7 @@ class WhileLoopAllReduceCodeMotion : public HloModulePass { return "while-loop-all-reduce-code-motion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/while_loop_all_reduce_code_motion_test.cc b/xla/service/while_loop_all_reduce_code_motion_test.cc index fecae53c7bc62..22928f05f1cac 100644 --- a/xla/service/while_loop_all_reduce_code_motion_test.cc +++ b/xla/service/while_loop_all_reduce_code_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/while_loop_analysis.cc b/xla/service/while_loop_analysis.cc index ffbde2332fd18..2445489d5eca2 100644 --- a/xla/service/while_loop_analysis.cc +++ b/xla/service/while_loop_analysis.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,13 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" +#include +#include +#include + #include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" +#include "xla/comparison_util.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -25,6 +30,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape_util.h" namespace xla { @@ -348,7 +354,7 @@ optional CheckedSubtract(int64_t a, int64_t b) { optional MatchTrivialLoopTripCount(const HloInstruction* while_op, int64_t indvar_tuple_idx, const Literal& indvar_init) { - // First, find the scalar constant K that `i` is initialized to. + // First, find the scalar constant init that `i` is initialized to. optional indvar_init_val = LiteralUtil::LiteralAsScalarInt64(indvar_init); if (!indvar_init_val) { @@ -358,21 +364,51 @@ optional MatchTrivialLoopTripCount(const HloInstruction* while_op, return nullopt; } - // Check that `i` goes as `i++` in the while body. - // - // TODO(jlebar): We could also handle i-- and other idioms. + // Check that `i` goes as `i += k` in the while body where k is a natural + // number. auto* while_body = while_op->while_body(); auto* while_body_indvar_update = - while_body->root_instruction()->operand(indvar_tuple_idx); + while_body->root_instruction()->mutable_operand(indvar_tuple_idx); auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + HloInstruction* trip_count_increase_step_instr = nullptr; + int64_t trip_count_step = 0; if (!Match(while_body_indvar_update, m::AddAnyOrder(m::Op().Is(while_body_indvar), - m::ConstantEffectiveScalar(1)))) { - VLOG(2) << "Pattern-match failed: induction variable does not go as i++: " - << while_body_indvar_update->ToString(); - return nullopt; + m::Op(&trip_count_increase_step_instr)))) { + if (trip_count_increase_step_instr == nullptr) { + VLOG(2) << "Pattern-match failed: induction variable is not getting " + "updated by an add operation: " + << while_body_indvar_update->ToString(); + return nullopt; + } + if (!trip_count_increase_step_instr->IsConstant() || + !ShapeUtil::IsEffectiveScalar( + trip_count_increase_step_instr->shape())) { + VLOG(2) << "Pattern-match failed: induction variable is not getting " + "incremented by constant: " + << while_body_indvar_update->ToString(); + return nullopt; + } + if (!LiteralUtil::LiteralAsScalarInt64( + trip_count_increase_step_instr->literal()) + .has_value()) { + VLOG(2) + << "Pattern-match failed: trip count step is not an integral type: " + << trip_count_increase_step_instr->shape().ToString(); + return nullopt; + } + VLOG(2) << "Pattern-match for trip count step failed: " + << trip_count_increase_step_instr->ToString(); } + trip_count_step = LiteralUtil::LiteralAsScalarInt64( + trip_count_increase_step_instr->literal()) + .value(); + if (trip_count_step <= 0) { + VLOG(2) << "Pattern-match failed: trip count step is not a natural number: " + << trip_count_step; + return nullopt; + } // Check that we do op(i, N) or op(N, i) as the while condition. Capture the // value N. auto* while_cond = while_op->while_condition(); @@ -397,7 +433,7 @@ optional MatchTrivialLoopTripCount(const HloInstruction* while_op, return nullopt; } - // Handle `i = K; i < N; ++i`. + // Handle `i = init; i < N; i+=k`. if (Match(while_cond_root, m::Op() .WithComparisonDirection(ComparisonDirection::kLt) @@ -407,14 +443,26 @@ optional MatchTrivialLoopTripCount(const HloInstruction* while_op, optional trips = CheckedSubtract(*while_cond_bound_val, *indvar_init_val); if (trips) { - return std::max(int64_t{0}, *trips); - } else { - VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX."; - return nullopt; + const int64_t remainder = std::remainder(*trips, trip_count_step); + const int64_t div = std::floor(*trips / trip_count_step); + if (remainder == 0) { + return std::max(int64_t{0}, div); + } + trips = CheckedAdd(div, 1); + if (!trips) { + VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX."; + return nullopt; + } + if (*trips < *while_cond_bound_val) { + return std::max(int64_t{0}, *trips); + } + return std::max(int64_t{0}, div); } + VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX."; + return nullopt; } - // Handle `i = K; i <= N; ++i`. + // Handle `i = init; i <= N; i+=k`. if (Match(while_cond_root, m::Op() .WithComparisonDirection(ComparisonDirection::kLe) @@ -427,7 +475,7 @@ optional MatchTrivialLoopTripCount(const HloInstruction* while_op, VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX"; return nullopt; } - trips = CheckedAdd(*trips, 1); + trips = CheckedAdd(std::floor(*trips / trip_count_step), 1); if (!trips) { VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX"; return nullopt; @@ -460,7 +508,7 @@ optional ComputeWhileLoopTripCount(const HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->operand(0); auto* indvar_init = while_init->operand(*indvar_tuple_idx); - StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); + absl::StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init, " << indvar_init_result.status() << ", " << indvar_init->ToString(); @@ -486,7 +534,7 @@ optional ComputeWhileLoopTripCount(const HloInstruction* while_op, for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1; ++trip_count) { - StatusOr result = evaluator.EvaluateWithSubstitutions( + absl::StatusOr result = evaluator.EvaluateWithSubstitutions( while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); @@ -499,8 +547,9 @@ optional ComputeWhileLoopTripCount(const HloInstruction* while_op, // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); + absl::StatusOr indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); @@ -596,7 +645,7 @@ optional ComputeWhileLoopTripCountUpperBound( TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), /*dest_shape_index=*/{0}, /*src_shape_index=*/{})); - StatusOr eval_result = + absl::StatusOr eval_result = evaluator.Evaluate(*new_computation, {std::move(fake_input)}); if (!eval_result.ok()) { diff --git a/xla/service/while_loop_analysis.h b/xla/service/while_loop_analysis.h index a9e1d074f850c..5fe4038ab6d0b 100644 --- a/xla/service/while_loop_analysis.h +++ b/xla/service/while_loop_analysis.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/while_loop_analysis_test.cc b/xla/service/while_loop_analysis_test.cc index 1dee0b36c5c1b..924222fbc11a6 100644 --- a/xla/service/while_loop_analysis_test.cc +++ b/xla/service/while_loop_analysis_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,58 +15,98 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" +#include +#include +#include +#include +#include + #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -class WhileLoopAnalysisTest : public HloTestBase {}; +class WhileLoopAnalysisTest : public HloTestBase { + protected: + [[nodiscard]] absl::StatusOr MakeWhileLoopAndGetTripCount( + int init, int limit, int step, ComparisonDirection dir); +}; -TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { - const char* const kHloModule = R"( - HloModule ModuleWithWhile +absl::StatusOr WhileLoopAnalysisTest::MakeWhileLoopAndGetTripCount( + int init, int limit, int step, ComparisonDirection dir) { + std::string hlo_string_template = R"( + HloModule ModuleWithWhile body { p_body = (f32[2], s32[]) parameter(0) val = f32[2] get-tuple-element(p_body), index=0 - const = s32[] constant(-1) - ROOT root = (f32[2], s32[]) tuple(val, const) + index = s32[] get-tuple-element(p_body), index=1 + one = s32[] constant({{STEP}}) + inc = s32[] add(index, one) + ROOT root = (f32[2], s32[]) tuple(val, inc) } condition { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 - const = s32[] constant(42) - ROOT result = pred[] compare(gte, const), direction=EQ + const = s32[] constant({{LIMIT}}) + ROOT result = pred[] compare(gte, const), direction={{COMP_DIR}} } ENTRY entry { param.0 = f32[2] parameter(0) - param.1 = s32[] parameter(1) + param.1 = s32[] constant({{INIT}}) while_init = (f32[2], s32[]) tuple(param.0, param.1) ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body - })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloModule)); + } + )"; + + std::string hlo_string = + absl::StrReplaceAll(hlo_string_template, + {{"{{INIT}}", absl::StrCat(init)}, + {"{{LIMIT}}", absl::StrCat(limit)}, + {"{{STEP}}", absl::StrCat(step)}, + {"{{COMP_DIR}}", ComparisonDirectionToString(dir)}}); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* while_op = module->entry_computation()->root_instruction(); - EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); + std::optional trip_count = MatchTrivialLoopTripCount( + while_op, 1, + Cast( + module->GetComputationWithName("entry")->GetInstructionWithName( + "param.1")) + ->literal()); + + CHECK(trip_count.has_value()); + + return *trip_count; } -TEST_F(WhileLoopAnalysisTest, NoUpperBound) { +TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { const char* const kHloModule = R"( HloModule ModuleWithWhile body { p_body = (f32[2], s32[]) parameter(0) val = f32[2] get-tuple-element(p_body), index=0 - const = s32[] constant(42) + const = s32[] constant(-1) ROOT root = (f32[2], s32[]) tuple(val, const) } @@ -87,32 +127,30 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) { ParseAndReturnVerifiedModule(kHloModule)); HloInstruction* while_op = module->entry_computation()->root_instruction(); - EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), std::nullopt); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); } -TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialTripCount) { +TEST_F(WhileLoopAnalysisTest, NoUpperBound) { const char* const kHloModule = R"( HloModule ModuleWithWhile body { p_body = (f32[2], s32[]) parameter(0) val = f32[2] get-tuple-element(p_body), index=0 - index = s32[] get-tuple-element(p_body), index=1 - one = s32[] constant(1) - inc = s32[] add(index, one) - ROOT root = (f32[2], s32[]) tuple(val, inc) + const = s32[] constant(42) + ROOT root = (f32[2], s32[]) tuple(val, const) } condition { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] compare(gte, const), direction=LT + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { param.0 = f32[2] parameter(0) - param.1 = s32[] constant(0) + param.1 = s32[] parameter(1) while_init = (f32[2], s32[]) tuple(param.0, param.1) ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body })"; @@ -120,47 +158,54 @@ TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialTripCount) { ParseAndReturnVerifiedModule(kHloModule)); HloInstruction* while_op = module->entry_computation()->root_instruction(); - - EXPECT_EQ( - *MatchTrivialLoopTripCount( - while_op, 1, - Cast(module->GetComputationWithName("entry") - ->GetInstructionWithName("param.1")) - ->literal()), - 42); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), std::nullopt); } -TEST_F(WhileLoopAnalysisTest, ExactBound) { - const char* const kHloModule = R"( - HloModule ModuleWithWhile - - body { - p_body = (f32[2], s32[]) parameter(0) - val = f32[2] get-tuple-element(p_body), index=0 - index = s32[] get-tuple-element(p_body), index=1 - one = s32[] constant(1) - inc = s32[] add(index, one) - ROOT root = (f32[2], s32[]) tuple(val, inc) +int CalculateTripCount(int init, int limit, int step, ComparisonDirection dir) { + int trip_count = 0; + if (dir == ComparisonDirection::kLt) { + for (int i = init; i < limit; i += step) { + trip_count++; } - - condition { - p_cond = (f32[2], s32[]) parameter(0) - gte = s32[] get-tuple-element(p_cond), index=1 - const = s32[] constant(42) - ROOT result = pred[] compare(gte, const), direction=LT + } else if (dir == ComparisonDirection::kLe) { + for (int i = init; i <= limit; i += step) { + trip_count++; } + } else { + LOG(FATAL) << "Unknown comparison direction: " + << ComparisonDirectionToString(dir); + } + return trip_count; +} - ENTRY entry { - param.0 = f32[2] parameter(0) - param.1 = s32[] constant(0) - while_init = (f32[2], s32[]) tuple(param.0, param.1) - ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body - })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloModule)); +TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialTripCount) { + // LT cases + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 1, ComparisonDirection::kLt).value(), + CalculateTripCount(0, 42, 1, ComparisonDirection::kLt)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 2, ComparisonDirection::kLt).value(), + CalculateTripCount(0, 42, 2, ComparisonDirection::kLt)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 5, ComparisonDirection::kLt).value(), + CalculateTripCount(0, 42, 5, ComparisonDirection::kLt)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 40, 5, ComparisonDirection::kLt).value(), + CalculateTripCount(0, 40, 5, ComparisonDirection::kLt)); - HloInstruction* while_op = module->entry_computation()->root_instruction(); - EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); + // LE cases + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 1, ComparisonDirection::kLe).value(), + CalculateTripCount(0, 42, 1, ComparisonDirection::kLe)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 2, ComparisonDirection::kLe).value(), + CalculateTripCount(0, 42, 2, ComparisonDirection::kLe)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 42, 5, ComparisonDirection::kLe).value(), + CalculateTripCount(0, 42, 5, ComparisonDirection::kLe)); + EXPECT_EQ( + MakeWhileLoopAndGetTripCount(0, 40, 5, ComparisonDirection::kLe).value(), + CalculateTripCount(0, 40, 5, ComparisonDirection::kLe)); } TEST_F(WhileLoopAnalysisTest, NoAIVNoConstChain) { diff --git a/xla/service/while_loop_concat_code_motion.cc b/xla/service/while_loop_concat_code_motion.cc index 300cb906a75a2..d7cf929ae7c54 100644 --- a/xla/service/while_loop_concat_code_motion.cc +++ b/xla/service/while_loop_concat_code_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -944,8 +944,8 @@ Status RewriteLoopWithConcatGroups(HloInstruction* loop, return OkStatus(); } -StatusOr RunOnLoop(HloInstruction* loop, - int64_t min_operand_count_to_optimize) { +absl::StatusOr RunOnLoop(HloInstruction* loop, + int64_t min_operand_count_to_optimize) { auto body = loop->while_body(); auto param = body->parameter_instruction(0); auto root = body->root_instruction(); @@ -1019,7 +1019,7 @@ StatusOr RunOnLoop(HloInstruction* loop, } // namespace -StatusOr WhileLoopConcatCodeMotion::Run( +absl::StatusOr WhileLoopConcatCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/while_loop_concat_code_motion.h b/xla/service/while_loop_concat_code_motion.h index c32744b7c85b9..dfb82ae5a009a 100644 --- a/xla/service/while_loop_concat_code_motion.h +++ b/xla/service/while_loop_concat_code_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ class WhileLoopConcatCodeMotion : public HloModulePass { return kName; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/while_loop_concat_code_motion_test.cc b/xla/service/while_loop_concat_code_motion_test.cc index 8738c968608b4..a4baa5bbe4c1e 100644 --- a/xla/service/while_loop_concat_code_motion_test.cc +++ b/xla/service/while_loop_concat_code_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/while_loop_constant_sinking.cc b/xla/service/while_loop_constant_sinking.cc index 4821fe5c66960..d714cdb25a021 100644 --- a/xla/service/while_loop_constant_sinking.cc +++ b/xla/service/while_loop_constant_sinking.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "xla/service/while_util.h" +#include "xla/shape_util.h" #include "xla/util.h" namespace xla { @@ -64,7 +65,7 @@ HloInstruction* CloneHelper(const HloInstruction* instruction, } // namespace -StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( +absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( HloInstruction* while_instr) { HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); @@ -94,6 +95,12 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( continue; } + if (sink_only_scalar_constants_) { + if (!ShapeUtil::IsScalar(init_value.operand(index)->shape())) { + continue; + } + } + // Sink into the while_body. // Should have at least one user that's not while_body_root. if (invariant_body_gte->user_count() > 1) { @@ -126,7 +133,7 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( return changed; } -StatusOr WhileLoopConstantSinking::Run( +absl::StatusOr WhileLoopConstantSinking::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "HLO module before WhileLoopConstantSinking:"; diff --git a/xla/service/while_loop_constant_sinking.h b/xla/service/while_loop_constant_sinking.h index 7bdc186768fd5..fd556c6c65aac 100644 --- a/xla/service/while_loop_constant_sinking.h +++ b/xla/service/while_loop_constant_sinking.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -48,8 +48,10 @@ namespace xla { // class WhileLoopConstantSinking : public HloModulePass { public: - explicit WhileLoopConstantSinking(bool sink_broadcast_of_constants = false) - : sink_broadcast_of_constants_(sink_broadcast_of_constants) {} + explicit WhileLoopConstantSinking(bool sink_broadcast_of_constants = false, + bool sink_only_scalar_constants = false) + : sink_broadcast_of_constants_(sink_broadcast_of_constants), + sink_only_scalar_constants_(sink_only_scalar_constants) {} ~WhileLoopConstantSinking() override = default; @@ -58,14 +60,16 @@ class WhileLoopConstantSinking : public HloModulePass { } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr TrySinkingConstantsIntoWhileLoop(HloInstruction* while_instr); + absl::StatusOr TrySinkingConstantsIntoWhileLoop( + HloInstruction* while_instr); const bool sink_broadcast_of_constants_; + const bool sink_only_scalar_constants_; }; } // namespace xla diff --git a/xla/service/while_loop_constant_sinking_test.cc b/xla/service/while_loop_constant_sinking_test.cc index 20e60881e6285..3597686e9b9cc 100644 --- a/xla/service/while_loop_constant_sinking_test.cc +++ b/xla/service/while_loop_constant_sinking_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { @@ -56,8 +55,17 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - WhileLoopConstantSinking{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + WhileLoopConstantSinking(/*sink_broadcast_of_constants=*/false, + /*sink_only_scalar_constants=*/true) + .Run(module.get())); + ASSERT_FALSE(changed); + + TF_ASSERT_OK_AND_ASSIGN( + changed, WhileLoopConstantSinking(/*sink_broadcast_of_constants=*/false, + /*sink_only_scalar_constants=*/false) + .Run(module.get())); ASSERT_TRUE(changed); auto* while_body = module->GetComputationWithName("body"); diff --git a/xla/service/while_loop_expensive_invariant_code_motion.cc b/xla/service/while_loop_expensive_invariant_code_motion.cc index 5a08f2ea47728..399fd7c88c333 100644 --- a/xla/service/while_loop_expensive_invariant_code_motion.cc +++ b/xla/service/while_loop_expensive_invariant_code_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -116,7 +116,7 @@ static void CreateLoopInvariantCopy( } } // namespace -StatusOr WhileLoopExpensiveInvariantCodeMotion:: +absl::StatusOr WhileLoopExpensiveInvariantCodeMotion:: TryHoistingInvariantInstructionsFromWhileBody(HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -337,7 +337,7 @@ StatusOr WhileLoopExpensiveInvariantCodeMotion:: return true; } -StatusOr WhileLoopExpensiveInvariantCodeMotion::Run( +absl::StatusOr WhileLoopExpensiveInvariantCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "HLO module before WhileLoopExpensiveInvariantCodeMotion:"; diff --git a/xla/service/while_loop_expensive_invariant_code_motion.h b/xla/service/while_loop_expensive_invariant_code_motion.h index 1505e883223dd..ac26df8f5b398 100644 --- a/xla/service/while_loop_expensive_invariant_code_motion.h +++ b/xla/service/while_loop_expensive_invariant_code_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,12 +44,12 @@ class WhileLoopExpensiveInvariantCodeMotion : public HloModulePass { return "while-loop-expensive-invariant-code-motion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: - StatusOr TryHoistingInvariantInstructionsFromWhileBody( + absl::StatusOr TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr); ShapeSizeFunction shape_size_function_; diff --git a/xla/service/while_loop_expensive_invariant_code_motion_test.cc b/xla/service/while_loop_expensive_invariant_code_motion_test.cc index 28f2cb4c95b25..33d3d0e414933 100644 --- a/xla/service/while_loop_expensive_invariant_code_motion_test.cc +++ b/xla/service/while_loop_expensive_invariant_code_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/while_loop_fusible_sinking.cc b/xla/service/while_loop_fusible_sinking.cc new file mode 100644 index 0000000000000..07b0943d9304e --- /dev/null +++ b/xla/service/while_loop_fusible_sinking.cc @@ -0,0 +1,295 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_fusible_sinking.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/while_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla { + +namespace { +// Constant and Iota have no operands and an output and broadcasts add +// dimensions to the output so we are looking fusions that have much smaller +// operand sizes compared to output sizes to avoid materialization +bool IsPurelyExpanding(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBroadcast || + (instr->opcode() == HloOpcode::kConstant && + instr->shape().rank() == 0) || + instr->opcode() == HloOpcode::kIota; +} + +bool IsFusionCandidate(const HloInstruction* instr) { + return instr->opcode() != HloOpcode::kRng && + (instr->IsElementwise() || instr->opcode() == HloOpcode::kReshape || + instr->opcode() == HloOpcode::kTranspose); +} +} // namespace + +bool WhileLoopFusibleSinking::IsSinkableFusion(HloInstruction* while_operand) { + absl::InlinedVector worklist; + absl::flat_hash_set visited; + worklist.push_back(while_operand); + while (!worklist.empty()) { + HloInstruction* to_process = worklist.back(); + worklist.pop_back(); + if (!to_process->IsFusible()) { + return false; + } + if (!visited.insert(to_process->unique_id()).second) { + // Do not sink extremely large subgraphs as they will be expensive to + // recompute in the loop. + if (visited.size() > 100) { + return false; + } + continue; + } + if (IsPurelyExpanding(to_process)) { + continue; + } + if (IsFusionCandidate(to_process)) { + for (auto* op : to_process->operands()) { + worklist.push_back(op); + } + continue; + } + return false; + } + return true; +} + +HloInstruction* WhileLoopFusibleSinking::CreateSinkableFusion( + HloInstruction* while_operand) { + HloInstruction* fusion = + while_operand->AddInstruction(while_operand->CreateFusion( + while_operand->shape(), HloInstruction::FusionKind::kLoop, + while_operand)); + bool did_fuse = IsFusionCandidate(while_operand); + // Fuse up to broadcasts, this function expects that IsSinkableFusion is true + // and does not verify that + while (did_fuse) { + did_fuse = false; + for (int64_t i = fusion->operand_count() - 1; i >= 0; --i) { + HloInstruction* op = fusion->mutable_operand(i); + if (IsPurelyExpanding(op)) { + continue; + } + fusion->FuseInstruction(op); + did_fuse = true; + break; + } + } + // Fuse the broadcasts, constants and iota at the terminals. + did_fuse = true; + while (did_fuse) { + did_fuse = false; + for (int64_t i = fusion->operand_count() - 1; i >= 0; --i) { + HloInstruction* op = fusion->mutable_operand(i); + if (IsPurelyExpanding(op)) { + fusion->FuseInstruction(op); + did_fuse = true; + break; + } + } + } + return fusion; +} + +absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( + HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); + HloComputation* while_body = while_instr->while_body(); + + // Don't try to mutate unflattened while loop computations. + if (call_counts_[while_body] > 1 || call_counts_[while_cond] > 1) { + return false; + } + HloInstruction* init_value = while_instr->mutable_operand(0); + if (init_value->opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + std::vector tuple_indices; + std::vector new_operands; + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64_t index = invariant_body_gte->tuple_index(); + if (while_instr->operand_count() == 0 || init_value->operand_count() == 0) { + // This is the case when each of tuple elements in the operand tuple of + // the while loop was an invariant value and each of the usages has been + // replaced. + CHECK_EQ(while_instr->user_count(), 0); + VLOG(3) << "Each element in the operand tuple of the while instruction '" + << while_instr->name() + << "' was an invariant value, whose usage has been replaced " + " directly by the value."; + break; + } + + HloInstruction* invariant_value = init_value->mutable_operand(index); + + // If a while operand is used by a slicing instruction, avoid fusing + // invariant value into the loop. + if (absl::c_any_of(invariant_body_gte->users(), + [](const HloInstruction* use) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: + case HloOpcode::kSlice: + return true; + default: + return false; + } + })) { + continue; + } + + if (init_value->IsRoot() || init_value->user_count() > 1) { + init_value = init_value->AddInstruction(init_value->Clone()); + TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(0, init_value)); + } + // Original value should be a fusible subgraph. + if (!IsSinkableFusion(invariant_value)) { + continue; + } + HloInstruction* fusion = CreateSinkableFusion(invariant_value); + changed = true; + if (fusion->operand_count() > 0 && + (while_instr->IsRoot() || + absl::c_any_of(while_instr->users(), [&](HloInstruction* use) { + return use->opcode() != HloOpcode::kGetTupleElement; + }))) { + // This really only occurs in unit tests or toy programs. Copy the current + // users for later replacement. + auto uses = while_instr->users(); + std::vector gtes(init_value->operand_count()); + for (int64_t i = 0; i < gtes.size(); ++i) { + gtes[i] = while_instr->AddInstruction( + HloInstruction::CreateGetTupleElement(while_instr, i)); + } + HloInstruction* tuple = + while_instr->AddInstruction(HloInstruction::CreateTuple(gtes)); + if (while_instr->IsRoot()) { + while_instr->parent()->set_root_instruction(tuple); + } + if (!uses.empty()) { + TF_RETURN_IF_ERROR(while_instr->ReplaceUsesWith(uses, tuple)); + } + } + + absl::InlinedVector invariant_output_uses; + for (auto use : while_instr->users()) { + if (use->opcode() == HloOpcode::kGetTupleElement && + use->tuple_index() == index) { + invariant_output_uses.push_back(use); + } + } + for (auto use : invariant_output_uses) { + TF_RETURN_IF_ERROR( + while_instr->parent()->ReplaceInstruction(use, invariant_value)); + } + + HloInstruction* root = while_body->root_instruction(); + HloInstruction* parameter = while_body->parameter_instruction(0); + tuple_indices.resize(fusion->operand_count()); + int64_t next_index = init_value->operand_count(); + new_operands.resize(fusion->operand_count()); + for (int64_t i = 0; i < fusion->operand_count(); ++i) { + init_value->AppendOperand(fusion->mutable_operand(i)); + parameter->mutable_shape()->mutable_tuple_shapes()->push_back( + fusion->mutable_operand(i)->shape()); + new_operands[i] = root->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, next_index++)); + root->AppendOperand(new_operands[i]); + } + *(init_value->mutable_shape()) = parameter->shape(); + *(while_instr->mutable_shape()) = parameter->shape(); + *(while_cond->parameter_instruction(0)->mutable_shape()) = + parameter->shape(); + *(root->mutable_shape()) = parameter->shape(); + auto cloned_fusion = while_body->AddInstruction( + fusion->CloneWithNewOperands(fusion->shape(), new_operands)); + TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstruction(invariant_body_gte, cloned_fusion)); + TF_RETURN_IF_ERROR(cloned_fusion->Defuse()); + } + + return changed; +} + +absl::StatusOr WhileLoopFusibleSinking::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + call_counts_.clear(); + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { + // Right now we don't particularly care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(fusible, ...), body=while_0_body, ...) + // } + // + // This will let us sink the fusible into the outer while first and then + // into the inner while in a single run of this pass. + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + HloPredicateIsOp); + } + + for (HloInstruction* while_instr : while_instrs) { + call_counts_[while_instr->while_body()]++; + call_counts_[while_instr->while_condition()]++; + } + + for (HloInstruction* while_instr : while_instrs) { + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingFusiblesIntoWhileLoop(while_instr)); + changed |= result; + } + return changed; +} +} // namespace xla diff --git a/xla/service/while_loop_fusible_sinking.h b/xla/service/while_loop_fusible_sinking.h new file mode 100644 index 0000000000000..a8ac53ec46d7c --- /dev/null +++ b/xla/service/while_loop_fusible_sinking.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ +#define XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be fusibles into the while +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like fusible folding. +// +// state = (..., fusible_graph, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., fusbile_graph, ..., fusible_graph_operands) +// while (pred(state)) { +// (..., v, ...) = state +// use(fusibile_graph) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +class WhileLoopFusibleSinking : public HloModulePass { + public: + WhileLoopFusibleSinking() = default; + + ~WhileLoopFusibleSinking() override = default; + + absl::string_view name() const override { + return "while-loop-fusible-sinking"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Sink a fusible subgraph into a while loop. + absl::StatusOr TrySinkingFusiblesIntoWhileLoop( + HloInstruction* while_instr); + + // Creates a loop fusion instruction containing the computation to move into + // the while loop to avoid conflicts with actual instruction fusion, the loop + // fusion will be defused. + bool IsSinkableFusion(HloInstruction* while_operand); + HloInstruction* CreateSinkableFusion(HloInstruction* while_operand); + + absl::flat_hash_map call_counts_; +}; +} // namespace xla + +#endif // XLA_SERVICE_WHILE_LOOP_FUSIBLE_SINKING_H_ diff --git a/xla/service/while_loop_fusible_sinking_test.cc b/xla/service/while_loop_fusible_sinking_test.cc new file mode 100644 index 0000000000000..fc457f290ff89 --- /dev/null +++ b/xla/service/while_loop_fusible_sinking_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_fusible_sinking.h" + +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; +using WhileLoopFusibleSinkingTest = HloTestBase; + +TEST_F(WhileLoopFusibleSinkingTest, SinkOneFusible) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] parameter(0) + const_1 = f32[2] iota(), iota_dimension=0 + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Iota()), _)); +} + +TEST_F(WhileLoopFusibleSinkingTest, SinkMask) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[5,7],f32[5,7]) parameter(0) + p_body.0 = get-tuple-element(p_body), index=0 + p_body.1 = get-tuple-element(p_body), index=1 + + add.0 = add(p_body.0, p_body.1) + ROOT root = tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[5,7],f32[5,7]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[5,7] parameter(0) + p = f32[5] parameter(1) + a = f32[5,7] iota(), iota_dimension=0 + b = f32[5,7] iota(), iota_dimension=1 + c = add(a, b) + d = f32[5,7] broadcast(p), dimensions={0} + mask = multiply(c,d) + while_init = tuple(const_0, mask) + ROOT while = while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Multiply(op::Add(op::Iota(), op::Iota()), + op::Broadcast())), + _, _)); +} + +TEST_F(WhileLoopFusibleSinkingTest, NoSinkSlicedMask) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[5,7],f32[5,7]) parameter(0) + p_body.0 = get-tuple-element(p_body), index=0 + p_body.1 = get-tuple-element(p_body), index=1 + z = s32[] constant(0) + j = s32[] constant(3) + ds = f32[1,7] dynamic-slice(p_body.1, j, z), dynamic_slice_sizes={1,7} + r = f32[7] reshape(ds) + b = f32[5,7] broadcast(r), dimensions={1} + a = add(b, p_body.0) + add.0 = add(a, p_body.1) + ROOT root = tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[5,7],f32[5,7]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[5,7] parameter(0) + p = f32[5] parameter(1) + a = f32[5,7] iota(), iota_dimension=0 + b = f32[5,7] iota(), iota_dimension=1 + c = add(a, b) + d = f32[5,7] broadcast(p), dimensions={0} + mask = multiply(c,d) + while_init = tuple(const_0, mask) + ROOT while = while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module.get())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/xla/service/while_loop_invariant_code_motion.cc b/xla/service/while_loop_invariant_code_motion.cc index 7d583c4f7bac9..1f670664434c0 100644 --- a/xla/service/while_loop_invariant_code_motion.cc +++ b/xla/service/while_loop_invariant_code_motion.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,13 +19,23 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/map_util.h" #include "xla/service/compile_time_cap.h" #include "xla/service/hlo_dce.h" -#include "xla/service/tuple_util.h" #include "xla/service/while_loop_analysis.h" #include "xla/service/while_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -127,7 +137,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( } } -StatusOr +absl::StatusOr WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr, BoundNonLinearCompilerAnalysis* allowance) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -318,7 +328,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return true; } -StatusOr WhileLoopInvariantCodeMotion::Run( +absl::StatusOr WhileLoopInvariantCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; diff --git a/xla/service/while_loop_invariant_code_motion.h b/xla/service/while_loop_invariant_code_motion.h index 7f96acc6f75d5..249605bacf473 100644 --- a/xla/service/while_loop_invariant_code_motion.h +++ b/xla/service/while_loop_invariant_code_motion.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ #define XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compile_time_cap.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/statusor.h" namespace xla { @@ -67,13 +72,13 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { return "while-loop-invariant-code-motion"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: bool NotWorthHoistingIndividually(const HloInstruction& instruction); - StatusOr TryHoistingInvariantInstructionsFromWhileBody( + absl::StatusOr TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr, BoundNonLinearCompilerAnalysis* allowance); bool hoist_constants_; diff --git a/xla/service/while_loop_invariant_code_motion_test.cc b/xla/service/while_loop_invariant_code_motion_test.cc index cf16f9c522f13..5a9a35e31fbe9 100644 --- a/xla/service/while_loop_invariant_code_motion_test.cc +++ b/xla/service/while_loop_invariant_code_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,22 @@ limitations under the License. #include "xla/service/while_loop_invariant_code_motion.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/while_loop_simplifier.cc b/xla/service/while_loop_simplifier.cc index bf0e5118c3d55..190ccaba46184 100644 --- a/xla/service/while_loop_simplifier.cc +++ b/xla/service/while_loop_simplifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,8 +22,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -38,8 +41,13 @@ limitations under the License. #include "xla/service/hlo_dce.h" #include "xla/service/pattern_matcher.h" #include "xla/service/while_loop_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/union_find.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -55,7 +63,7 @@ using std::optional; // if: // 1) x is a constant and x >= k + c. // 2) x is a constant x <= c. -static StatusOr TryRemoveTrivialCompare(HloInstruction* while_op) { +static absl::StatusOr TryRemoveTrivialCompare(HloInstruction* while_op) { std::optional indvar_index = GetLoopInductionVarTupleIdx(while_op); if (indvar_index.has_value()) { if (while_op->operand(0)->operand(*indvar_index)->IsConstant()) { @@ -127,7 +135,7 @@ void CopyFrontendAttributes(HloInstruction* old_while_op, // while loop init, body, and condition. The final shape returned is still the // same as before. If set index_for_replaced will replace any use of the removed // indices in the final shape with a copy of the removed index. -static StatusOr RemoveDeadTupleIndices( +static absl::StatusOr RemoveDeadTupleIndices( HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices, int64_t index_for_replaced = -1) { // Build up maps from the old/new to the new/old tuple indices. @@ -304,7 +312,7 @@ static StatusOr RemoveDeadTupleIndices( // that tuple that is not used by the loop condition and is not used by the loop // body except to pass it to the next iteration of the loop, then we can remove // that element from the loop's tuples. -static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { +static absl::StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); // Don't try this transformation if the while loop isn't removable, since if @@ -549,7 +557,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes // duplicates by replacing them with tuple_index, followed by a call to // RemoveDeadTupleIndices. -static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( +static absl::StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( HloInstruction* while_op, const int64_t tuple_index, bool replace_with_init, absl::flat_hash_set& duplicates) { HloComputation* while_cond = while_op->while_condition(); @@ -609,7 +617,7 @@ static bool IsDynamicUpdateSliceWhileInsertion( // If the while loop init passes the same values to several tuple indices, and // if the body keeps on passing them through, we can remove the duplicates. -static StatusOr TryRemoveRepeatedWhileTupleIndices( +static absl::StatusOr TryRemoveRepeatedWhileTupleIndices( HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); @@ -757,7 +765,7 @@ static StatusOr TryRemoveRepeatedWhileTupleIndices( // Removes each loop parameter (i.e. member of the while loop tuple) that is a // constant and is the same in the while loop body and the while loop init. -static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { +static absl::StatusOr TryRemoveConstantParams(HloInstruction* while_op) { HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); auto* while_init = while_op->mutable_operand(0); @@ -905,7 +913,7 @@ static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { // loop itself removed. // // Returns true if it made a change to the graph. -static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { +static absl::StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // Cowardly refuse to remove loops that are not removable. In practice, this // means that we can't remove loops that have control predecessors/successors. if (!while_op->parent()->IsSafelyRemovable(while_op)) { @@ -990,7 +998,7 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { return false; } -static StatusOr TryPropagateConstant(HloInstruction* while_op) { +static absl::StatusOr TryPropagateConstant(HloInstruction* while_op) { auto while_init = while_op->operand(0); if (while_init->opcode() != HloOpcode::kTuple) { return false; @@ -1029,7 +1037,8 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // Replace the use of each constant tuple element in the loop_condition and // loop_body with the corresponding constant value. - auto propagate_constant = [&](HloComputation* computation) -> StatusOr { + auto propagate_constant = + [&](HloComputation* computation) -> absl::StatusOr { HloInstruction* param = computation->parameter_instruction(0); bool changed = false; for (auto instr : param->users()) { @@ -1126,7 +1135,7 @@ static std::vector GetFlatTupleElems( return elems; } -static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { +static absl::StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); auto* while_init = while_op->mutable_operand(0); @@ -1256,7 +1265,7 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { // need to be wrapped in a tuple that changes its shape. We return the loop // itself so that you can call TryMergeInductionVariables in a loop, once for // each integral type elem_ty. -static StatusOr TryMergeInductionVariables( +static absl::StatusOr TryMergeInductionVariables( HloInstruction* while_op, PrimitiveType elem_ty) { CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty); HloModule* module = while_op->GetModule(); @@ -1471,7 +1480,7 @@ static StatusOr TryMergeInductionVariables( return new_while; } -StatusOr WhileLoopSimplifier::Run( +absl::StatusOr WhileLoopSimplifier::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES(3, diff --git a/xla/service/while_loop_simplifier.h b/xla/service/while_loop_simplifier.h index 9064fda41cadb..3aacd3b0c70ef 100644 --- a/xla/service/while_loop_simplifier.h +++ b/xla/service/while_loop_simplifier.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ #define XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" @@ -55,7 +57,7 @@ class WhileLoopSimplifier : public HloModulePass { ~WhileLoopSimplifier() override = default; absl::string_view name() const override { return "simplify-while-loops"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/xla/service/while_loop_simplifier_test.cc b/xla/service/while_loop_simplifier_test.cc index 94d87e754815b..2733e1cc69d98 100644 --- a/xla/service/while_loop_simplifier_test.cc +++ b/xla/service/while_loop_simplifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/service/hlo_dce.h" @@ -30,6 +32,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" namespace xla { diff --git a/xla/service/while_loop_trip_count_annotator.cc b/xla/service/while_loop_trip_count_annotator.cc index d7b86663ca58e..ca92eb6fc7897 100644 --- a/xla/service/while_loop_trip_count_annotator.cc +++ b/xla/service/while_loop_trip_count_annotator.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,12 +14,21 @@ limitations under the License. ==============================================================================*/ #include "xla/service/while_loop_trip_count_annotator.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/while_loop_analysis.h" +#include "xla/statusor.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace xla { -StatusOr WhileLoopTripCountAnnotator::Run( +absl::StatusOr WhileLoopTripCountAnnotator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/while_loop_trip_count_annotator.h b/xla/service/while_loop_trip_count_annotator.h index 1c8374a98712e..440e5e6d6184b 100644 --- a/xla/service/while_loop_trip_count_annotator.h +++ b/xla/service/while_loop_trip_count_annotator.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ #define XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" @@ -41,7 +43,7 @@ class WhileLoopTripCountAnnotator : public HloModulePass { return "while-loop-trip-count-annotator"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; }; diff --git a/xla/service/while_loop_trip_count_annotator_test.cc b/xla/service/while_loop_trip_count_annotator_test.cc index b0e22e491a9b5..1b12f3178f4b0 100644 --- a/xla/service/while_loop_trip_count_annotator_test.cc +++ b/xla/service/while_loop_trip_count_annotator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,12 +15,10 @@ limitations under the License. #include "xla/service/while_loop_trip_count_annotator.h" -#include "xla/service/pattern_matcher.h" -#include "xla/service/while_loop_simplifier.h" -#include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index 886bff152641b..57fe2cdba2de2 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/while_loop_unroller.h" -#include #include #include #include @@ -28,8 +27,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -40,8 +41,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/overflow_util.h" #include "xla/primitive_util.h" -#include "xla/service/async_op_canonicalizer.h" #include "xla/service/call_inliner.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_pass_fix.h" @@ -49,6 +50,7 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -57,16 +59,123 @@ limitations under the License. namespace xla { namespace { + using hlo_query::ContainsInstrWithOpcode; +// Parameters for the unroller that can be adjusted. const int kUnrollTripCountThreshold = 64; const int kUnrollInstructionCountThreshold = 800; const int kUnrollExpandFactorThreshold = 10000; -}; // namespace + +// A utility function that decides whether a loop is unrollable or not. +std::optional IsLoopUnrollable(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // TODO(b/300668690): Add support for unrolling loops with control dependency. + // For now, we bail. + // + // Finding all the while loops where other instructions have explicit control + // dependencies on them. + std::vector while_dependees; + for (HloComputation* comp : while_op->GetModule()->computations()) { + for (HloInstruction* instr : comp->instructions()) { + for (HloInstruction* control_dep : instr->control_predecessors()) { + if (control_dep->opcode() == HloOpcode::kWhile) { + while_dependees.push_back(control_dep); + } + } + } + } + if (absl::linear_search(while_dependees.begin(), while_dependees.end(), + while_op)) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " due to control dependency: " << while_op->ToShortString(); + return std::nullopt; + } + + // We can't remove while loops that contain send/recv nodes, because we + // rely on the particular loop structure around the node matching on the + // send and recv sides. + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " because it contains a send/recv node: " + << while_op->ToShortString(); + return std::nullopt; + } + + if (while_op->operand(0)->opcode() != HloOpcode::kTuple) { + VLOG(2) << "Not attempting to unroll " << while_op->name() + << " because the operand is not a tuple: " + << while_op->ToShortString(); + return std::nullopt; + } + + // We cannot unroll loops that have side effecting condition because the + // condition will be removed after unrolling. This might be relaxed + // later when we add partial unrolling. + if (while_op->while_condition()->HasSideEffect()) { + VLOG(2) << "Not attempting to remove while loop whose condition contains " + "side-effecting instructions: " + << while_op->ToShortString(); + return std::nullopt; + } + + std::optional indvar_tuple_idx = + GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx.has_value()) { + return std::nullopt; + } + + HloEvaluator evaluator(/*max_loop_iterations=*/0); + const HloInstruction* while_init = while_op->operand(0); + const HloInstruction* indvar_init = while_init->operand(*indvar_tuple_idx); + absl::StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init, " + << indvar_init_result.status() << ", " << indvar_init->ToString(); + return std::nullopt; + } + Literal indvar_iter_val = std::move(indvar_init_result).value(); + + std::optional trip_count = + MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val); + if (!trip_count.has_value()) { + return std::nullopt; + } + + VLOG(3) << "Loop trip count " << trip_count.value(); + + WhileLoopConfig config; + config.init = + LiteralUtil::LiteralAsScalarInt64(std::move(indvar_iter_val)).value(); + config.trip_count = trip_count.value(); + config.induction_var_idx = *indvar_tuple_idx; + + return config; +} + +std::unique_ptr GetConstantWithPrimitiveType(PrimitiveType type, + int64_t value) { + return primitive_util::PrimitiveTypeSwitch>( + [&](auto literal_constant) -> std::unique_ptr { + if constexpr (primitive_util::IsIntegralType(literal_constant)) { + using NativeT = primitive_util::NativeTypeOf; + return HloInstruction::CreateConstant( + LiteralUtil::CreateR0(static_cast(value))); + } + LOG(FATAL) << "literal is of non-integral type"; + }, + type); +} // Helper function that replaces a single iteration of a while loop with // induction variable equal to induction_value. -static StatusOr> +absl::StatusOr> UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, const int64_t indvar_idx, const int64_t induction_value) { @@ -74,13 +183,28 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, std::unique_ptr while_body_clone = while_op->while_body()->Clone(absl::StrCat(induction_value)); - const HloInstruction* induction_var_hlo = - while_op->operand(0)->operand(indvar_idx); + HloInstruction* induction_var_hlo = + while_op->mutable_operand(0)->mutable_operand(indvar_idx); + + // We record the next channel id to utilize when unrolling loops with + // collective communication instructions. During unrolling a single iteration + // of the body, we can reuse the same unique_channel_id. For the later + // iterations, we obtain it again. + int64_t unique_channel_id = hlo_query::NextChannelId(*while_op->GetModule()); // Go through the instructions in while body to get the instruction that // points to the induction var. Then replace it everywhere with the concrete // value. for (HloInstruction* body_inst : while_body_clone->instructions()) { + // We need to assign a unique channel_id for the collective ops that are + // unrolled within the while loop body or fusions containing collectives. + if (IsCollectiveWithChannelId(body_inst)) { + // To obtain the channel_id for the collective ops we only need to + // increment the `unique_channel_id` since it records the next available + // channel_id across the module. + body_inst->set_channel_id(unique_channel_id++); + } + if (body_inst->opcode() != HloOpcode::kGetTupleElement) { continue; } @@ -104,22 +228,8 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, // Found the induction var as an operand of body instruction. if (indvar_use_operand == body_inst) { std::unique_ptr constant = - primitive_util::PrimitiveTypeSwitch< - std::unique_ptr>( - [&](auto literal_constant) - -> std::unique_ptr { - if constexpr (primitive_util::IsIntegralType( - literal_constant)) { - using NativeT = - primitive_util::NativeTypeOf; - return HloInstruction::CreateConstant( - LiteralUtil::CreateR0( - static_cast(induction_value))); - } - LOG(FATAL) << "literal is of non-integral type"; - }, - induction_var_hlo->shape().element_type()); - + GetConstantWithPrimitiveType( + induction_var_hlo->shape().element_type(), induction_value); // Assign the same shape of the old instruction to the new // instruction. *constant->mutable_shape() = body_inst->shape(); @@ -133,28 +243,192 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, return while_body_clone; } -StatusOr WhileLoopUnroller::Run( +// Helper function to create a condition for a single iteration while loop in +// the form of 'i <= init_value' where i is the induction variable. +std::unique_ptr MakeSingleIterWhileCond( + HloInstruction* while_op, int64_t induction_idx, int64_t init_value) { + auto condition_builder = + HloComputation::Builder(absl::StrCat("unrolled-cond-", while_op->name())); + + auto param_instruction = condition_builder.AddParameter( + while_op->while_condition()->parameter_instruction(0)->Clone()); + + CHECK_OK(param_instruction); + + HloInstruction* indvar_instruction = condition_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(*param_instruction, induction_idx)); + + auto init_value_constant = + condition_builder.AddInstruction(GetConstantWithPrimitiveType( + indvar_instruction->shape().element_type(), init_value)); + + return condition_builder.Build( + condition_builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PrimitiveType::PRED, {}), indvar_instruction, + init_value_constant, ComparisonDirection::kLe))); +} + +absl::Status InitialFeasibilityCheck(HloInstruction* while_op, + WhileLoopConfig config, + int64_t unroll_factor) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // While loop must have a single tuple operand. + CHECK_EQ(while_op->operands().size(), 1); + if (while_op->operands().size() != 1) { + return FailedPrecondition( + "%s", + absl::StrCat("Cannot unroll while loop. While loop must have a single " + "tuple operand, instead has more than one operand: ", + while_op->operands().size())); + } + + VLOG(5) << "Trying to unroll " << while_op->ToShortString(); + + // TODO(b/288130138): For now, we only support full unrolling. Will add + // partial unrolling if needed. + if (unroll_factor != -1) { + return UnimplementedStrCat( + "Currently, only full unrolling is supported, unroll factor: ", + unroll_factor); + } + + // TODO(b/291628533): Extract this parameter to the unroller config. We don't + // attempt to unroll loops where the body has more than + // kUnrollInstructionCountThreshold instructions. + if (while_op->while_body()->instruction_count() > + kUnrollInstructionCountThreshold) { + return FailedPrecondition( + "%s", + absl::StrCat( + "Cannot unroll while loop. Too many instructions in the body: ", + while_op->while_body()->instruction_count())); + } + + // TODO(b/291628533): Extract this parameter to the an unroller config. We + // only unroll loops up to a threshold. + if (config.trip_count > kUnrollTripCountThreshold) { + return FailedPrecondition( + "%s", + absl::StrCat("Cannot unroll while loop. The tip count is greater " + "than the threshold: ", + config.trip_count, " vs ", kUnrollTripCountThreshold)); + } + + // TODO(b/291628533): Extract this parameter to the unroller config. We don't + // unroll loops that increase the instruction count by more than + // kUnrollExpandFactorThreshold. + if (config.trip_count * while_op->while_body()->instruction_count() > + kUnrollExpandFactorThreshold) { + return FailedPrecondition( + "%s", absl::StrCat("Not attempting to unroll due to instruction count " + "increase explosion. New instruction count: ", + config.trip_count * + while_op->while_body()->instruction_count(), + " vs ", kUnrollExpandFactorThreshold)); + } + return absl::OkStatus(); +} + +absl::StatusOr UnrollInternal(HloInstruction* while_op, + WhileLoopConfig config, + int64_t unroll_factor) { + TF_RETURN_IF_ERROR(InitialFeasibilityCheck(while_op, config, unroll_factor)); + + VLOG(3) << "Unrolling while instruction " << while_op->ToShortString() + << " with body instruction count " + << while_op->while_body()->instruction_count(); + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* unrolled_body_call_op; + std::vector call_operands = {while_op->operands().at(0)}; + for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { + CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); + + HloComputation* unrolled_body = module->AddEmbeddedComputation( + UnrollSingleIterationOfTrivialLoop(while_op, config.induction_var_idx, + i) + .value()); + unrolled_body_call_op = + computation->AddInstruction(HloInstruction::CreateCall( + while_op->shape(), call_operands, unrolled_body)); + call_operands.clear(); + call_operands.emplace_back(unrolled_body_call_op); + } + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(while_op, unrolled_body_call_op)); + + // Needed for the nested while loops in which the outer loop has been + // unrolled which leaves the call graph non-flat. + TF_RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); + return true; +} + +absl::StatusOr UnrollInternalWrapped(HloInstruction* while_op, + WhileLoopConfig config, + int64_t unroll_factor) { + TF_RETURN_IF_ERROR(InitialFeasibilityCheck(while_op, config, unroll_factor)); + + VLOG(3) << "Unrolling (wrapped) while instruction " + << while_op->ToShortString() << " with body instruction count " + << while_op->while_body()->instruction_count(); + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* unrolled_body_call_op; + + auto body_builder = + HloComputation::Builder(absl::StrCat("unrolled-body-", while_op->name())); + absl::StatusOr p = body_builder.AddParameter( + while_op->while_body()->parameter_instruction(0)->Clone()); + + std::vector call_operands = {p.value()}; + for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { + CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); + + HloComputation* unrolled_body = module->AddEmbeddedComputation( + UnrollSingleIterationOfTrivialLoop(while_op, config.induction_var_idx, + i) + .value()); + unrolled_body_call_op = + body_builder.AddInstruction(HloInstruction::CreateCall( + while_op->shape(), call_operands, unrolled_body)); + call_operands.clear(); + call_operands.emplace_back(unrolled_body_call_op); + } + HloComputation* new_body = + module->AddEmbeddedComputation(body_builder.Build(unrolled_body_call_op)); + HloComputation* new_cond = module->AddEmbeddedComputation( + MakeSingleIterWhileCond(while_op, config.induction_var_idx, config.init)); + + HloInstruction* new_while_op = + computation->AddInstruction(HloInstruction::CreateWhile( + while_op->shape(), new_cond, new_body, while_op->mutable_operand(0))); + + CHECK_OK(computation->ReplaceInstruction(while_op, new_while_op)); + + // Needed for the nested while loops in which the outer loop has been + // unrolled which leaves the call graph non-flat. + TF_RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); + return true; +} + +}; // namespace + +absl::StatusOr PrepareModuleForUnrolling( HloModule* module, const absl::flat_hash_set& execution_threads) { - // TODO(b/288130138) For now, we only support full unrolling. Will add partial - // unrolling if needed. - if (unroll_factor_ != -1) { - return false; - } - XLA_VLOG_LINES(3, "WhileLoopUnroller::Run(), before:\n" + module->ToString()); bool changed = false; - - // The following sequence of passes are necessary to prepare loops for - // unrolling. Instead of placing these passes in compiler, they are placed - // here to indicate explicit dependency to these passes. TF_ASSIGN_OR_RETURN( bool applied_cse, - HloCSE{/*is_layout_sensitive=*/true}.Run(module, execution_threads)); + HloCSE(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/false, + /*ignore_control_dependencies=*/false, /*only_scalars=*/true) + .Run(module, execution_threads)); if (applied_cse) { changed = true; VLOG(3) << "Applied hlo cse to module " << module->name(); } - TF_ASSIGN_OR_RETURN(bool applied_tuple_simplifier, TupleSimplifier{}.Run(module, execution_threads)); if (applied_tuple_simplifier) { @@ -164,14 +438,21 @@ StatusOr WhileLoopUnroller::Run( // We apply constant sinking to fix point. HloPassFix constant_sinking( - /*sink_broadcast_of_constants=*/true); + /*sink_broadcast_of_constants=*/true, + /*sink_only_scalar_constants=*/true); TF_ASSIGN_OR_RETURN(bool applied_constant_sinking, constant_sinking.Run(module, execution_threads)); if (applied_constant_sinking) { + changed = true; VLOG(3) << "Applied constant sinking to module " << module->name(); } + return changed; +} - // Processing the while loops in the reverse of topological order. If the body +std::vector> GetUnrollableLoops( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Processing the while loops in the reverse topological order. If the body // of while loop A calls while loop B, B comes before A. std::vector all_while_ops; for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { @@ -179,187 +460,97 @@ StatusOr WhileLoopUnroller::Run( HloPredicateIsOp); } - // Finding all the while loops where other instructions have explicit control - // dependencies on them. - std::vector while_with_deps; - for (HloComputation* comp : module->computations(execution_threads)) { - for (HloInstruction* instr : comp->instructions()) { - for (HloInstruction* control_dep : instr->control_predecessors()) { - if (control_dep->opcode() == HloOpcode::kWhile) { - if (std::find(all_while_ops.begin(), all_while_ops.end(), - control_dep) != all_while_ops.end()) { - while_with_deps.push_back(control_dep); - } - } - } - } - } - - // Gather a preliminary vector of all the while ops that we think we can - // unroll. We only consider while loops that take a tuple as an argument. We - // do this ahead of time so we don't have to worry about mutating the lists of - // computations or instructions while we iterate. - std::vector while_ops; + std::vector> while_loop_configs; for (HloInstruction* instr : all_while_ops) { - // TODO(b/300668690): Check control dependencies to the while - // instruction - if (absl::linear_search(while_with_deps.begin(), while_with_deps.end(), - instr)) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " due to control dependency: " << instr->ToShortString(); - continue; - } - - // We can't remove while loops that contain send/recv nodes, because we - // rely on the particular loop structure around the node matching on the - // send and recv sides. - if (ContainsInstrWithOpcode(instr->while_body(), - {HloOpcode::kSend, HloOpcode::kSendDone, - HloOpcode::kRecv, HloOpcode::kRecvDone}) || - ContainsInstrWithOpcode(instr->while_condition(), - {HloOpcode::kSend, HloOpcode::kSendDone, - HloOpcode::kRecv, HloOpcode::kRecvDone})) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " because it contains a send/recv node: " - << instr->ToShortString(); - continue; - } - // TODO(b/291146216): Handle this case later - if (ContainsInstrWithOpcode(instr->while_body(), {HloOpcode::kAllReduce, - HloOpcode::kAllGather})) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " for now because it contains an all-reduce or an all-gather: " - << instr->ToShortString(); - continue; - } - if (instr->operand(0)->opcode() != HloOpcode::kTuple) { - VLOG(2) << "Not attempting to unroll " << instr->name() - << " because the operand is not a tuple: " - << instr->ToShortString(); - continue; + std::optional config = IsLoopUnrollable(instr); + if (config.has_value()) { + while_loop_configs.emplace_back(instr, config.value()); } - // We cannot unroll loops that have side effecting condition because the - // condition will be removed after unrolling. This might be relaxed - // later when we add partial unrolling. - if (instr->while_condition()->HasSideEffect()) { - VLOG(2) << "Not attempting to remove while loop whose condition contains " - "side-effecting instructions: " - << instr->ToShortString(); - return false; - } - // TODO(b/291628533): Extract this to the unroller config - if (instr->while_body()->instruction_count() > - kUnrollInstructionCountThreshold) { - continue; - } - while_ops.push_back(instr); } + return while_loop_configs; +} - VLOG(3) << "Number of while instructions in the module to unroll: " - << while_ops.size(); +absl::StatusOr Unroll(HloInstruction* while_op, int64_t unroll_factor, + bool wrap_in_trivial_loop) { + bool changed = false; + HloModule* module = while_op->GetModule(); - for (HloInstruction* while_op : while_ops) { - VLOG(3) << "Trying to unroll " << while_op->ToShortString(); - bool unrolled_current_loop = false; - int64_t unroll_factor_current_loop = unroll_factor_; + // Make sure all the necessary passes are executed before unrolling in order + // to unroll every possible loop. + TF_ASSIGN_OR_RETURN( + changed, PrepareModuleForUnrolling(module, /*execution_threads=*/{})); - // TODO(b/288130138) For now, we only support full unrolling. Will add - // partial unrolling if needed. - CHECK_EQ(unroll_factor_current_loop, -1); + // Construct the loop config + std::optional config = IsLoopUnrollable(while_op); + if (!config.has_value()) { + return false; + } - std::optional indvar_tuple_idx = - GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx.has_value()) { - continue; - } + bool unrolled = false; + if (wrap_in_trivial_loop) { + TF_ASSIGN_OR_RETURN(unrolled, UnrollInternalWrapped( + while_op, config.value(), unroll_factor)); + } else { + TF_ASSIGN_OR_RETURN( + unrolled, UnrollInternal(while_op, config.value(), unroll_factor)); + } - HloEvaluator evaluator(/*max_loop_iterations=*/0); - const HloInstruction* while_init = while_op->operand(0); - const HloInstruction* indvar_init = while_init->operand(*indvar_tuple_idx); - StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init, " - << indvar_init_result.status() << ", " << indvar_init->ToString(); - continue; - } - Literal indvar_iter_val = std::move(indvar_init_result).value(); + // We need to inline the calls created for unrolling since later passes rely + // on the calls to be inlined. + if (unrolled) { + TF_RETURN_IF_ERROR(CallInliner().Run(module).status()); + } + return unrolled; +} - // TODO(b/288907795): Try using ComputeWhileLoopTripCount - std::optional trip_count = - MatchTrivialLoopTripCount(while_op, *indvar_tuple_idx, indvar_iter_val); - if (!trip_count.has_value()) { - continue; - } +absl::StatusOr WhileLoopUnroller::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // TODO(b/288130138) For now, we only support full unrolling. Will add partial + // unrolling if needed. + if (unroll_factor_ != -1) { + return false; + } + XLA_VLOG_LINES(3, "WhileLoopUnroller::Run(), before:\n" + module->ToString()); + bool changed = false; - VLOG(3) << "Loop trip count " << trip_count.value(); + // Make sure all the necessary passes are executed before unrolling in order + // to unroll every possible loop. + TF_ASSIGN_OR_RETURN(changed, + PrepareModuleForUnrolling(module, execution_threads)); - // TODO(b/291628533): Extract this to the unroller config. We only unroll - // loops up to a threshold. - if (trip_count > kUnrollTripCountThreshold) { - continue; - } + // Processing the while loops in the reverse of topological order. If the body + // of while loop A calls while loop B, B comes before A. + std::vector all_while_ops; + for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { + absl::c_copy_if(comp->instructions(), std::back_inserter(all_while_ops), + HloPredicateIsOp); + } - unroll_factor_current_loop = trip_count.value(); - - // TODO(b/291628533): Extract this to the unroller config. We don't unroll - // loops that increase the instruction count by more than - // kUnrollExpandFactorThreshold. - if (trip_count.value() * while_op->while_body()->instruction_count() > - kUnrollExpandFactorThreshold) { - VLOG(3) << "Not attempting to unroll due to instruction count increase " - "explosion."; - VLOG(3) << "New instruction count: " - << trip_count.value() * - while_op->while_body()->instruction_count(); - continue; - } + // Gather a preliminary vector of all the while ops that we think we can + // unroll. We do this ahead of time so we don't have to worry about mutating + // the lists of computations or instructions while we iterate. + std::vector> + unrollable_while_ops = GetUnrollableLoops(module, execution_threads); - std::optional init_value = - LiteralUtil::LiteralAsScalarInt64(indvar_iter_val); - // Init value must be int64_t at this point since we found the trip count. - CHECK(init_value.has_value()); - - unrolled_current_loop = true; - VLOG(3) << "Unrolling while instruction " << while_op->ToShortString() - << " with body instruction count " - << while_op->while_body()->instruction_count(); - HloComputation* computation = while_op->parent(); - HloInstruction* unrolled_body_call_op; - std::vector call_operands; - // We assume while has only one tuple parameter - call_operands.emplace_back(while_op->operands().at(0)); - for (int64_t i = init_value.value(); - i < unroll_factor_current_loop + init_value.value(); ++i) { - CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); - - HloComputation* unrolled_body = module->AddEmbeddedComputation( - UnrollSingleIterationOfTrivialLoop(while_op, *indvar_tuple_idx, i) - .value()); - unrolled_body_call_op = - computation->AddInstruction(HloInstruction::CreateCall( - while_op->shape(), call_operands, unrolled_body)); - call_operands.clear(); - call_operands.emplace_back(unrolled_body_call_op); - } - CHECK_OK(computation->ReplaceInstruction(while_op, unrolled_body_call_op)); - - // Need to perform following passes only if the current while loop has been - // unrolled. - if (unrolled_current_loop) { - // Since Flattening call graph relies on AsyncOpCanonicalizer - TF_RETURN_IF_ERROR( - AsyncOpCanonicalizer().Run(module, execution_threads).status()); - // Needed for the nested while loops in which the outer loop has been - // unrolled which leaves the call graph non-flat. - TF_RETURN_IF_ERROR( - FlattenCallGraph().Run(module, execution_threads).status()); + VLOG(3) << "Number of while instructions in the module to unroll: " + << unrollable_while_ops.size(); + + bool unrolled = false; + for (auto& [while_op, config] : unrollable_while_ops) { + if (wrap_in_trivial_loop_) { + TF_ASSIGN_OR_RETURN( + unrolled, UnrollInternalWrapped(while_op, config, unroll_factor_)); + } else { + TF_ASSIGN_OR_RETURN(unrolled, + UnrollInternal(while_op, config, unroll_factor_)); } - changed |= unrolled_current_loop; + changed |= unrolled; } + // We need to inline the calls created for unrolling since later passes rely + // on the calls to be inlined. if (changed) { - // We need to inline the calls created for unrolling since later passes rely - // on the calls to be inlined. TF_RETURN_IF_ERROR(CallInliner().Run(module, execution_threads).status()); } diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index 2adb7df6ea9bd..337ec9d348a8d 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,16 +17,49 @@ limitations under the License. #define XLA_SERVICE_WHILE_LOOP_UNROLLER_H_ #include +#include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/statusor.h" namespace xla { +// Config for unrollable while loops. +struct WhileLoopConfig { + // The initial value of the induction variable of the while loop. + int64_t init; + // The number of iterations the loop executes. + int64_t trip_count; + // The index of the induction variable in the input tuple of the while loop. + int64_t induction_var_idx; +}; + +// Runs a sequence of passes that are necessary to prepare loops for unrolling. +// Failure to run these passes will prevent unroller from unrolling loops that +// would have been otherwise unrollable. +absl::StatusOr PrepareModuleForUnrolling( + HloModule* module, + const absl::flat_hash_set& execution_threads); + +// Returns the list of unrollable loops in the given module + +std::vector> GetUnrollableLoops( + HloModule* module, + const absl::flat_hash_set& execution_threads); + +// Unrolls the given while loop with the default behaviour set to full unroll. +// If wrap_in_trivial_loop is set, the unrolled body of the loop will be wrapped +// in a loop with trip count of one. +absl::StatusOr Unroll(HloInstruction* while_op, + int64_t unroll_factor = -1, + bool wrap_in_trivial_loop = false); + // This pass unrolls while loops with the given unrolling factor. The value of // unroll_factor = -1 will fully unroll the loop. // @@ -42,18 +75,22 @@ class WhileLoopUnroller : public HloModulePass { ~WhileLoopUnroller() override = default; // Default unroll_factor of -1 indicates full unrolling - explicit WhileLoopUnroller(int64_t unroll_factor = -1) - : unroll_factor_(unroll_factor) {} + explicit WhileLoopUnroller(int64_t unroll_factor = -1, + bool wrap_in_trivial_loop = false) + : unroll_factor_(unroll_factor), + wrap_in_trivial_loop_(wrap_in_trivial_loop) {} absl::string_view name() const override { return "while_loop_unroller"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; private: int64_t unroll_factor_; + // Whether to wrap the unrolled computation in a loop with trip count of one. + bool wrap_in_trivial_loop_; }; } // namespace xla diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index 4276e19c1c9db..02a6db1be5fe8 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,18 +19,21 @@ limitations under the License. #include #include #include -#include +#include #include +#include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -47,15 +50,19 @@ class WhileLoopUnrollerTest : public HloTestBase { MakeModuleWithLoopBodyNestedCopyIndVar(int num_iters); [[nodiscard]] std::unique_ptr MakeModuleWithWhileFeedingAnotherWhile(int num_iters); + [[nodiscard]] std::unique_ptr + MakeModuleWithSimpleLoopAllReduce(int num_iters); public: void UnrollAndCompare(std::unique_ptr module, absl::Span arguments, - int64_t unroll_factor = -1) { + int64_t unroll_factor = -1, bool wrap_in_loop = false) { Literal before_unroll = ExecuteAndTransfer(module->Clone(), arguments); - VLOG(2) << "after unroll value: " << before_unroll.ToString(); + VLOG(2) << "before unroll value: " << before_unroll.ToString(); - EXPECT_TRUE(WhileLoopUnroller(unroll_factor).Run(module.get()).value()); + EXPECT_TRUE(WhileLoopUnroller(unroll_factor, wrap_in_loop) + .Run(module.get()) + .value()); Literal after_unroll = ExecuteAndTransfer(std::move(module), arguments); VLOG(2) << "after unroll value: " << after_unroll.ToString(); @@ -293,8 +300,131 @@ WhileLoopUnrollerTest::MakeModuleWithWhileFeedingAnotherWhile(int num_iters) { return ParseAndReturnVerifiedModule(hlo_string).value(); } +std::unique_ptr +WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) { + std::string hlo_string_template = R"( + HloModule SimpleLoop + + %reduction { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) + } + + SimpleLoop.body { + loop_var.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = f32[1024, 1024] get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = f32[1024, 1024] get-tuple-element(loop_var.1), index=2 + + %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] get-tuple-element.2), channel_id=1, replica_groups={{0}}, to_apply=%reduction + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] get-tuple-element.3) + + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(add, get-tuple-element.2, %accumulation) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant({{LOOP_BOUND}}) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + %param.1 = f32[1024, 1024] parameter(0) + constant.3 = s32[] constant(0) + + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + + tuple.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(constant.3, %param.1, %accumulation_buffer) + ROOT while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(tuple.1), condition=SimpleLoop.condition, body=SimpleLoop.body + } + )"; + std::string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(num_iters)}}); + return ParseAndReturnVerifiedModule(hlo_string).value(); +} + TEST_F(WhileLoopUnrollerTest, SimpleLoopUnroll) { - UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}); + UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, false); + UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, true); +} + +// This test passes because we run WhileLoopConstantSinking before unrolling. +TEST_F(WhileLoopUnrollerTest, SimpleLoopUnrollNeedPrepare) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}, s64[]) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s64[] get-tuple-element(loop_var.1), index=2 + add = s64[] add(get-tuple-element.1, get-tuple-element.3) + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}, s64[]) tuple(add, multiply, get-tuple-element.3) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}, s64[]) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + one = s64[] constant(1) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}, s64[]) tuple(constant.3, constant.4, one) + while = (s64[], s32[3]{0}, s64[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT result = s32[3]{0} get-tuple-element(while), index=1 + } + )"; + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); +} + +// This test passes because we run TupleSimplifier before unrolling. +TEST_F(WhileLoopUnrollerTest, SimpleLoopUnrollNeedPrepare2) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}, s64[]) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s64[] get-tuple-element(loop_var.1), index=2 + add = s64[] add(get-tuple-element.1, get-tuple-element.3) + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}, s64[]) tuple(add, multiply, get-tuple-element.3) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}, s64[]) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + one = s64[] constant(1) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}, s64[]) tuple(constant.3, constant.4, one) + gte1 = s64[] get-tuple-element(tuple.1), index=0 + gte2 = s32[3]{0} get-tuple-element(tuple.1), index=1 + gte3 = s64[] get-tuple-element(tuple.1), index=2 + tuple = (s64[], s32[3]{0}, s64[]) tuple(gte1, gte2, gte3) + while = (s64[], s32[3]{0}, s64[]) while(tuple), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT result = s32[3]{0} get-tuple-element(while), index=1 + } + )"; + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, SimpleLoopNotRoot) { @@ -325,10 +455,13 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNotRoot) { ROOT result = s32[3]{0} get-tuple-element(while), index=1 } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } -TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { +TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) { std::string hlo_string = R"( HloModule SimpleLoop SimpleLoop.body { @@ -347,6 +480,160 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { constant.2 = s64[] constant(10) ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } + SimpleLoop.body.2 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.2 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + SimpleLoop.body.3 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] multiply(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.3 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}) tuple(constant.3, constant.4) + while1 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + while3 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition.3, body=SimpleLoop.body.3 + while2 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition.2, body=SimpleLoop.body.2 + o1 = s32[3]{0} get-tuple-element(while1), index=1 + o2 = s32[3]{0} get-tuple-element(while2), index=1 + ROOT result = (s32[3]{0}, s32[3]{0}) tuple(o1,o2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto unrollable_loops = GetUnrollableLoops(module.get(), {}); + // Only while1 and while2 are unrollable + EXPECT_EQ(unrollable_loops.size(), 2); +} + +TEST_F(WhileLoopUnrollerTest, UnrollMutipleLoops) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + SimpleLoop.body.2 { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition.2 { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + /* number of iterations is 10 */ + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s64[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s64[], s32[3]{0}) tuple(constant.3, constant.4) + while1 = (s64[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + input = s32[3]{0} get-tuple-element(while1), index=1 + tuple.2 = (s64[], s32[3]{0}) tuple(constant.3, input) + while2 = (s64[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition.2, body=SimpleLoop.body.2 + o1 = s32[3]{0} get-tuple-element(while1), index=1 + o2 = s32[3]{0} get-tuple-element(while2), index=1 + ROOT result = (s32[3]{0}, s32[3]{0}) tuple(o1,o2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Unroll the first loop + TF_ASSERT_OK_AND_ASSIGN( + bool unrolled1, + Unroll(module->entry_computation()->GetInstructionWithName("while1"))); + EXPECT_TRUE(unrolled1); + + // There should be no call instructions after unrolling either loops since we + // inline all the calls after unrolling. + std::vector call_instrs_1; + for (auto* comp : module->MakeComputationPostOrder()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(call_instrs_1), + HloPredicateIsOp); + } + EXPECT_EQ(call_instrs_1.size(), 0); + + // Unroll the second loop + TF_ASSERT_OK_AND_ASSIGN( + bool unrolled2, + Unroll(module->entry_computation()->GetInstructionWithName("while2"))); + EXPECT_TRUE(unrolled2); + std::vector call_instrs_2; + for (auto* comp : module->MakeComputationPostOrder()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(call_instrs_2), + HloPredicateIsOp); + } + EXPECT_EQ(call_instrs_2.size(), 0); +} + +TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { + std::string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s64[] get-tuple-element(loop_var.1), index=0 + constant.1 = s64[] constant(1) + add = s64[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s64[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s64[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s64[] get-tuple-element(loop_var.2), index=0 + constant.2 = s64[] constant(10) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } ENTRY SimpleLoop { constant.3 = s64[] constant(4) constant.4 = s32[3]{0} constant({0, 1, 2}) @@ -356,7 +643,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopNonZeroInit) { ROOT result = s32[3]{0} get-tuple-element(while), index=1 } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, SimpleLoopS16IndVar) { @@ -386,7 +676,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopS16IndVar) { SimpleLoop.condition, body=SimpleLoop.body } )"; - UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + false); + UnrollAndCompare(ParseAndReturnVerifiedModule(hlo_string).value(), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, LoopWithControlDep) { @@ -431,17 +724,244 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopPartialUnroll) { TEST_F(WhileLoopUnrollerTest, IndirectBodyInc) { std::unique_ptr module = MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5); - UnrollAndCompare(std::move(module), {}); + UnrollAndCompare(MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5), {}, -1, + false); + UnrollAndCompare(MakeModuleWithLoopBodyIndirectInc(/*num_iters=*/5), {}, -1, + true); } TEST_F(WhileLoopUnrollerTest, NestedIndirectBodyInc) { std::unique_ptr module = MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5); - UnrollAndCompare(std::move(module), {}); + UnrollAndCompare(MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5), {}, + -1, false); + UnrollAndCompare(MakeModuleWithNestedLoopBodyIndirectInc(/*num_iters=*/5), {}, + -1, true); } TEST_F(WhileLoopUnrollerTest, WhileFeedingWhile) { - UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}); + UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}, + -1, false); + UnrollAndCompare(MakeModuleWithWhileFeedingAnotherWhile(/*num_iters=*/5), {}, + -1, true); +} + +TEST_F(WhileLoopUnrollerTest, LoopWithCollective) { + int64_t num_iters = 5; + auto module = MakeModuleWithSimpleLoopAllReduce(num_iters); + + EXPECT_TRUE( + WhileLoopUnroller(/*unroll_factor=*/-1).Run(module.get()).value()); + + EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() == + HloOpcode::kAllReduce; + }), + num_iters); +} + +TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { + std::string hlo_string = R"( + HloModule module, entry_computation_layout={(s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)})->(s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)})} + + fused_computation.70.clone.clone.clone { + param_0.10545 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + ROOT bitcast.7213 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} bitcast(param_0.10545) + } + + fused_computation.68.clone.clone.clone { + param_1.12561 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + constant.26622 = s8[]{:T(512)} constant(0) + pad.3783 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12561, constant.26622), padding=0_0x0_0x0_1x0_0 + constant.26621 = s32[]{:T(128)} constant(0) + param_2.10214 = s32[]{:T(128)S(6)} parameter(2) + dynamic-slice.5474 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3783, constant.26621, constant.26621, constant.26621, param_2.10214), dynamic_slice_sizes={1,2048,2,256} + pad.3782 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12561, constant.26622), padding=0_0x0_0x1_0x0_0 + param_0.10544 = s32[]{:T(128)S(6)} parameter(0) + dynamic-slice.5473 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3782, constant.26621, constant.26621, constant.26621, param_0.10544), dynamic_slice_sizes={1,2048,2,256} + add.10207 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} add(dynamic-slice.5474, dynamic-slice.5473) + ROOT bitcast.7212 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} bitcast(add.10207) + } + + fused_computation.71.clone { + param_3.7588 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(3) + fusion.4288 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} fusion(param_3.7588), kind=kLoop, calls=fused_computation.70.clone.clone.clone + param_0.10546 = s32[]{:T(128)S(6)} parameter(0) + param_1.12562 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + param_2.10215 = s32[]{:T(128)S(6)} parameter(2) + fusion.4287 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} fusion(param_0.10546, param_1.12562, param_2.10215), kind=kLoop, calls=fused_computation.68.clone.clone.clone + convolution.802 = s32[32,2,256]{2,0,1:T(8,128)} convolution(fusion.4288, fusion.4287), window={size=2 pad=1_1 rhs_reversal=1}, dim_labels=bf0_i0o->b0f + ROOT bitcast.7214 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} bitcast(convolution.802) + } + + fused_computation.76.clone { + param_0.10547 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(0) + param_1.12563 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} parameter(1) + slice.12606 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12563), slice={[0:1], [0:32], [1:2], [0:256]} + bitcast.7215 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12606) + add.10208 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_0.10547, bitcast.7215) + param_2.10216 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + slice.12000.clone.2 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12563), slice={[0:1], [0:32], [0:1], [0:256]} + bitcast.1776.clone.2 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12000.clone.2) + add.6006.clone.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_2.10216, bitcast.1776.clone.2) + ROOT tuple.2892 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) tuple(add.10208, add.6006.clone.2) + } + + fused_computation.69.clone.clone.clone { + param_0.10549 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + ROOT bitcast.7217 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} bitcast(param_0.10549) + } + + fused_computation.66.clone.clone.clone { + param_1.12564 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + constant.26625 = s8[]{:T(512)} constant(0) + pad.3785 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12564, constant.26625), padding=0_0x0_0x0_1x0_0 + constant.26624 = s32[]{:T(128)} constant(0) + param_2.10217 = s32[]{:T(128)S(6)} parameter(2) + dynamic-slice.5476 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3785, constant.26624, constant.26624, constant.26624, param_2.10217), dynamic_slice_sizes={1,2048,2,256} + pad.3784 = s8[1,2048,2,4096]{3,1,2,0:T(8,128)(4,1)} pad(param_1.12564, constant.26625), padding=0_0x0_0x1_0x0_0 + param_0.10548 = s32[]{:T(128)S(6)} parameter(0) + dynamic-slice.5475 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} dynamic-slice(pad.3784, constant.26624, constant.26624, constant.26624, param_0.10548), dynamic_slice_sizes={1,2048,2,256} + add.10212 = s8[1,2048,2,256]{3,1,2,0:T(8,128)(4,1)} add(dynamic-slice.5476, dynamic-slice.5475) + ROOT bitcast.7216 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} bitcast(add.10212) + } + + fused_computation.72.clone { + param_3.7589 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(3) + fusion.4292 = s8[32,2048,1]{1,0,2:T(8,128)(4,1)} fusion(param_3.7589), kind=kLoop, calls=fused_computation.69.clone.clone.clone + param_0.10550 = s32[]{:T(128)S(6)} parameter(0) + param_1.12565 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} parameter(1) + param_2.10218 = s32[]{:T(128)S(6)} parameter(2) + fusion.4291 = s8[2048,2,256]{2,0,1:T(8,128)(4,1)} fusion(param_0.10550, param_1.12565, param_2.10218), kind=kLoop, calls=fused_computation.66.clone.clone.clone + convolution.803 = s32[32,2,256]{2,0,1:T(8,128)} convolution(fusion.4292, fusion.4291), window={size=2 pad=1_1 rhs_reversal=1}, dim_labels=bf0_i0o->b0f + ROOT bitcast.7218 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} bitcast(convolution.803) + } + + fused_computation.74.clone { + param_0.10551 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(0) + param_1.12566 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} parameter(1) + slice.12607 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12566), slice={[0:1], [0:32], [1:2], [0:256]} + bitcast.7219 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.12607) + add.10213 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_0.10551, bitcast.7219) + param_2.10219 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + slice.11997.clone.2 = s32[1,32,1,256]{3,1,2,0:T(8,128)} slice(param_1.12566), slice={[0:1], [0:32], [0:1], [0:256]} + bitcast.1773.clone.2 = s32[1,32,256]{2,1,0:T(8,128)} bitcast(slice.11997.clone.2) + add.6005.clone.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} add(param_2.10219, bitcast.1773.clone.2) + ROOT tuple.2893 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) tuple(add.10213, add.6005.clone.2) + } + + wide.windowed_dot_general_body { + wide_param.41 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) parameter(0) + get-tuple-element.29000 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=0 + get-tuple-element.29001 = s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=1 + get-tuple-element.28990 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=3 + collective-permute-start = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28990), channel_id=18, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} + collective-permute-done = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start) + get-tuple-element.29005 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=5 + get-tuple-element.29006 = u32[256]{0:T(256)} get-tuple-element(wide_param.41), index=6 + partition-id.101 = u32[] partition-id() + dynamic-slice.5472 = u32[1]{0:T(128)} dynamic-slice(get-tuple-element.29006, partition-id.101), dynamic_slice_sizes={1} + bitcast.7210 = u32[]{:T(128)} bitcast(dynamic-slice.5472) + get-tuple-element.29007 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=7 + add.10204 = u32[]{:T(128)S(6)} add(bitcast.7210, get-tuple-element.29007) + get-tuple-element.28991 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=4 + subtract.2863 = u32[]{:T(128)S(6)} subtract(add.10204, get-tuple-element.28991) + get-tuple-element.29008 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=8 + and.400 = u32[]{:T(128)S(6)} and(subtract.2863, get-tuple-element.29008) + clamp.1712 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.400, get-tuple-element.29008) + convert.8615 = s32[]{:T(128)S(6)} convert(clamp.1712) + get-tuple-element.29009 = s32[]{:T(128)} get-tuple-element(wide_param.41), index=9 + multiply.14830 = s32[]{:T(128)S(6)} multiply(convert.8615, get-tuple-element.29009) + bitcast.8823 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} bitcast(get-tuple-element.29001) + add.10205 = u32[]{:T(128)S(6)} add(get-tuple-element.28991, bitcast.7210) + get-tuple-element.29010 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=10 + add.10206 = u32[]{:T(128)S(6)} add(add.10205, get-tuple-element.29010) + and.401 = u32[]{:T(128)S(6)} and(add.10206, get-tuple-element.29008) + clamp.1713 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.401, get-tuple-element.29008) + convert.8616 = s32[]{:T(128)S(6)} convert(clamp.1713) + multiply.14831 = s32[]{:T(128)S(6)} multiply(convert.8616, get-tuple-element.29009) + fusion.4289 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14830, bitcast.8823, multiply.14831, get-tuple-element.29000), kind=kOutput, calls=fused_computation.71.clone + get-tuple-element.28989 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=2 + collective-permute-start.1 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28989), channel_id=17, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} + collective-permute-done.1 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.1) + fusion.4290 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done, fusion.4289, collective-permute-done.1), kind=kLoop, calls=fused_computation.76.clone + get-tuple-element.22079 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=0 + collective-permute-start.2 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22079), channel_id=20, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} + collective-permute-done.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.2) + get-tuple-element.29011 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=11 + add.10209 = u32[]{:T(128)S(6)} add(get-tuple-element.28991, get-tuple-element.29011) + subtract.2864 = u32[]{:T(128)S(6)} subtract(add.10204, add.10209) + and.402 = u32[]{:T(128)S(6)} and(subtract.2864, get-tuple-element.29008) + clamp.1714 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.402, get-tuple-element.29008) + convert.8617 = s32[]{:T(128)S(6)} convert(clamp.1714) + multiply.14832 = s32[]{:T(128)S(6)} multiply(convert.8617, get-tuple-element.29009) + bitcast.8824 = s8[1,2048,1,4096]{3,1,2,0:T(8,128)(4,1)S(1)} bitcast(get-tuple-element.29001) + add.10210 = u32[]{:T(128)S(6)} add(add.10209, bitcast.7210) + add.10211 = u32[]{:T(128)S(6)} add(add.10210, get-tuple-element.29010) + and.403 = u32[]{:T(128)S(6)} and(add.10211, get-tuple-element.29008) + clamp.1715 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.403, get-tuple-element.29008) + convert.8618 = s32[]{:T(128)S(6)} convert(clamp.1715) + multiply.14833 = s32[]{:T(128)S(6)} multiply(convert.8618, get-tuple-element.29009) + fusion.4293 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14832, bitcast.8824, multiply.14833, get-tuple-element.29000), kind=kOutput, calls=fused_computation.72.clone + get-tuple-element.22080 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=1 + collective-permute-start.3 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22080), channel_id=19, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} + collective-permute-done.3 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.3) + fusion.4294 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done.2, fusion.4293, collective-permute-done.3), kind=kLoop, calls=fused_computation.74.clone + get-tuple-element.29002 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=1 + get-tuple-element.29003 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=0 + get-tuple-element.29012 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=12 + constant.28871 = u32[]{:T(128)} constant(2) + add.10214 = u32[]{:T(128)} add(get-tuple-element.28991, constant.28871) + ROOT tuple.3341 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) tuple(get-tuple-element.29000, get-tuple-element.29001, get-tuple-element.29002, get-tuple-element.29003, add.10214, get-tuple-element.29005, get-tuple-element.29006, get-tuple-element.29007, get-tuple-element.29008, get-tuple-element.29009, get-tuple-element.29010, get-tuple-element.29011, get-tuple-element.29012) + } + + wide.windowed_dot_general_cond { + wide_param.40 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) parameter(0) + get-tuple-element.22055 = u32[]{:T(128)} get-tuple-element(wide_param.40), index=4 + constant.26614 = u32[]{:T(128)} constant(8) + ROOT compare.2683 = pred[]{:T(512)} compare(get-tuple-element.22055, constant.26614), direction=LT + } + + ENTRY test { + fusion.4456 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} parameter(0) + fusion.4457 = s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)} parameter(1) + broadcast.26239 = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(2) + broadcast.26239.clone = s32[1,32,256]{2,1,0:T(8,128)S(1)} parameter(3) + constant.28863 = u32[]{:T(128)} constant(0) + constant.28864 = u32[]{:T(128)} constant(0) + constant.28865 = u32[256]{0:T(256)} constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255}) + constant.28866 = u32[]{:T(128)} constant(8) + constant.28867 = u32[]{:T(128)} constant(15) + constant.28868 = s32[]{:T(128)} constant(256) + constant.28869 = u32[]{:T(128)} constant(9) + constant.28870 = u32[]{:T(128)} constant(1) + constant.28871 = u32[]{:T(128)} constant(2) + tuple.3339 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) tuple(fusion.4456, fusion.4457, broadcast.26239, broadcast.26239.clone, constant.28863, constant.28864, constant.28865, constant.28866, constant.28867, constant.28868, constant.28869, constant.28870, constant.28871) + ROOT while.636 = (s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)}, s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:T(128)}, /*index=5*/u32[]{:T(128)}, u32[256]{0:T(256)}, u32[]{:T(128)}, u32[]{:T(128)}, s32[]{:T(128)}, /*index=10*/u32[]{:T(128)}, u32[]{:T(128)}, u32[]{:T(128)}) while(tuple.3339), condition=wide.windowed_dot_general_cond, body=wide.windowed_dot_general_body + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + + int64_t fusion_instr_count = absl::c_count_if( + module->GetComputationWithName("wide.windowed_dot_general_body") + ->instructions(), + [](const HloInstruction* instr) { + return (instr->IsLoopFusion() || instr->IsOutputFusion()); + }); + + // Fully unroll the specific loop (trip count is 4) + EXPECT_TRUE( + WhileLoopUnroller(/*unroll_factor=*/-1).Run(module.get()).value()); + + int64_t fusion_instr_count_after_unroll = absl::c_count_if( + module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return (instr->IsLoopFusion() || instr->IsOutputFusion()); + }); + + // The total number of fusions in the unrolled version in the entry must be + // equal to loop_trip_count * fusion_instr_count + EXPECT_EQ(fusion_instr_count * 4, fusion_instr_count_after_unroll); } } // namespace diff --git a/xla/service/while_util.cc b/xla/service/while_util.cc index e3c7a61b9cd40..2f944bf472423 100644 --- a/xla/service/while_util.cc +++ b/xla/service/while_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,8 @@ namespace xla { using absl::StrCat; -static StatusOr> +static absl::StatusOr< + std::pair> WidenWhileCondition(HloComputation* narrow_condition, const Shape& wide_shape) { const Shape& narrow_shape = narrow_condition->parameter_instruction(0)->shape(); @@ -86,7 +87,8 @@ WidenWhileCondition(HloComputation* narrow_condition, const Shape& wide_shape) { return {{wide_while_cond, std::move(inlined_instructions_map)}}; } -static StatusOr> +static absl::StatusOr< + std::pair> WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); @@ -125,7 +127,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { return {{wide_while_body, std::move(inlined_instructions_map)}}; } -/*static*/ StatusOr +/*static*/ absl::StatusOr WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, absl::Span instructions) { @@ -188,7 +190,7 @@ WhileUtil::MakeInstructionsLiveIn( return std::move(result); } -static StatusOr> +static absl::StatusOr> MakeCountedLoopConditionComputation(const Shape& loop_state_shape, int32_t trip_count) { Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); @@ -212,9 +214,10 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, return std::move(cond_computation); } -static StatusOr> MakeCountedLoopBodyComputation( +static absl::StatusOr> +MakeCountedLoopBodyComputation( const Shape& loop_state_shape, - absl::FunctionRef( + absl::FunctionRef( HloInstruction*, const WhileUtil::LoopStateTy&)> loop_body_generator) { TF_ASSIGN_OR_RETURN(std::unique_ptr body_computation, @@ -277,11 +280,11 @@ static Shape MakeLoopStateShapeWithLayout( return ShapeUtil::MakeTupleShape(loop_state_shape_components); } -/*static*/ StatusOr WhileUtil::MakeCountedLoop( - HloModule* module, int32_t trip_count, - const WhileUtil::LoopStateTy& init_values, - WhileUtil::LoopBodyGeneratorTy loop_body_generator, - const OpMetadata& metadata) { +/*static*/ absl::StatusOr +WhileUtil::MakeCountedLoop(HloModule* module, int32_t trip_count, + const WhileUtil::LoopStateTy& init_values, + WhileUtil::LoopBodyGeneratorTy loop_body_generator, + const OpMetadata& metadata) { CHECK_GE(trip_count, 0); // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation @@ -319,7 +322,7 @@ static Shape MakeLoopStateShapeWithLayout( return WhileUtil::OwningLoopStateTy{std::move(owned), while_results}; } -/*static*/ StatusOr WhileUtil::MakeCountedLoop( +/*static*/ absl::StatusOr WhileUtil::MakeCountedLoop( HloComputation* computation, int32_t trip_count, const WhileUtil::LoopStateTy& init_values, WhileUtil::LoopBodyGeneratorTy loop_body_generator, diff --git a/xla/service/while_util.h b/xla/service/while_util.h index ac6aac8b26765..03cd85605cf14 100644 --- a/xla/service/while_util.h +++ b/xla/service/while_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/call_inliner.h" #include "xla/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { class WhileUtil { @@ -71,12 +72,12 @@ class WhileUtil { // // Every instruction in `instructions` must be contained in the computation // that contains `while_instr`. - static StatusOr MakeInstructionsLiveIn( + static absl::StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, absl::Span instructions); using LoopStateTy = std::vector; - using LoopBodyGeneratorTy = absl::FunctionRef( + using LoopBodyGeneratorTy = absl::FunctionRef( HloInstruction* /*induction_var*/, const LoopStateTy& /*current_values*/)>; @@ -92,7 +93,7 @@ class WhileUtil { // } // return loop_state; // } - static StatusOr MakeCountedLoop( + static absl::StatusOr MakeCountedLoop( HloComputation* computation, int32_t trip_count, const LoopStateTy& init_values, LoopBodyGeneratorTy loop_body_generator, const OpMetadata& metadata); @@ -104,7 +105,7 @@ class WhileUtil { // As above but does not add the while loop or other instructions created // around it in any particular computation. The caller can instead add it to a // computation of their choosing. - static StatusOr MakeCountedLoop( + static absl::StatusOr MakeCountedLoop( HloModule* module, int32_t trip_count, const WhileUtil::LoopStateTy& init_values, WhileUtil::LoopBodyGeneratorTy loop_body_generator, diff --git a/xla/service/while_util_test.cc b/xla/service/while_util_test.cc index 96e59f56e6052..e18975f7bdd01 100644 --- a/xla/service/while_util_test.cc +++ b/xla/service/while_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,11 +18,15 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -31,7 +35,7 @@ namespace op = ::xla::testing::opcode_matchers; class WhileUtilTest : public HloTestBase { protected: - StatusOr> GetParsedModule( + absl::StatusOr> GetParsedModule( HloComputation** entry_computation, HloInstruction** param0, HloInstruction** param1, HloInstruction** param2) { const char* const hlo_string = R"( diff --git a/xla/service/xla_aot_compile_cpu_test.cc b/xla/service/xla_aot_compile_cpu_test.cc index 02c41cf0d580c..16ae69b5bf0e1 100644 --- a/xla/service/xla_aot_compile_cpu_test.cc +++ b/xla/service/xla_aot_compile_cpu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/executable_run_options.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" +#include "xla/service/shaped_buffer.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" @@ -69,7 +73,9 @@ TEST(XlaCompileTest, LoadCpuExecutable) { executable_run_options.set_allocator(client->backend().memory_allocator()); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer result, - local_executable->Run({&array1, &array2}, executable_run_options)); + local_executable->Run( + absl::Span{&array1, &array2}, + executable_run_options)); TF_ASSERT_OK_AND_ASSIGN(Literal output, client->ShapedBufferToLiteral(result)); diff --git a/xla/service/xla_aot_compile_gpu_test.cc b/xla/service/xla_aot_compile_gpu_test.cc index 2e1283f6d13f0..a3720b3b39f15 100644 --- a/xla/service/xla_aot_compile_gpu_test.cc +++ b/xla/service/xla_aot_compile_gpu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/executable_run_options.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" +#include "xla/service/shaped_buffer.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" @@ -58,7 +65,7 @@ TEST_P(XlaAotCompileTest, LoadGpuExecutable) { std::unique_ptr local_executable, client->Load(serialized_aot_result, executable_build_options)); - // Run loaded excutable. + // Run loaded executable. Literal input1 = LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}); Literal input2 = LiteralUtil::CreateR1({1.0f, 2.0f, 4.0f}); TF_ASSERT_OK_AND_ASSIGN( @@ -109,7 +116,7 @@ TEST(XlaCompileTest, LoadGpuExecutableWithConstant) { std::unique_ptr local_executable, client->Load(serialized_aot_result, executable_build_options)); - // Run loaded excutable. + // Run loaded executable. Literal input = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer array, @@ -126,115 +133,7 @@ TEST(XlaCompileTest, LoadGpuExecutableWithConstant) { EXPECT_EQ(expected, output); } -TEST(XlaCompileTest, LoadGpuExecutableWithGemm) { - std::string path = - tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", - "xla_aot_compile_test_gpu_executable_gemm"); - std::string serialized_aot_result; - TF_ASSERT_OK( - tsl::ReadFileToString(tsl::Env::Default(), path, &serialized_aot_result)); - - // Check that GemmAlgorithmPicker successfully loaded autotune results. - EXPECT_TRUE(absl::StrContains(serialized_aot_result, "algorithm = 13 : i64")) - << serialized_aot_result; - - // Get a LocalClient - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - PlatformUtil::GetPlatform("CUDA")); - ASSERT_GT(platform->VisibleDeviceCount(), 0); - - LocalClientOptions local_client_options; - local_client_options.set_platform(platform); - TF_ASSERT_OK_AND_ASSIGN( - LocalClient * client, - ClientLibrary::GetOrCreateLocalClient(local_client_options)); - - // Load from AOT result. - ExecutableBuildOptions executable_build_options; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr local_executable, - client->Load(serialized_aot_result, executable_build_options)); - - // Run loaded excutable. - Literal input1 = LiteralUtil::CreateR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); - Literal input2 = LiteralUtil::CreateR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); - - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array1, - client->LiteralToShapedBuffer(input1, client->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array2, - client->LiteralToShapedBuffer(input2, client->default_device_ordinal())); - - ExecutableRunOptions executable_run_options; - executable_run_options.set_allocator(client->backend().memory_allocator()); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer result, - local_executable->Run({&array1, &array2}, executable_run_options)); - - TF_ASSERT_OK_AND_ASSIGN(Literal output, - client->ShapedBufferToLiteral(result)); - Literal expected = LiteralUtil::CreateR2( - {{30.0f, 36.0f, 42.0f}, {66.0, 81.0, 96.0}, {102.0, 126.0, 150.0}}); - EXPECT_EQ(expected, output); -} - -TEST(XlaCompileTest, LoadGpuExecutableWithGemmRuntimeAutotuning) { - std::string path = tsl::io::JoinPath( - tsl::testing::XlaSrcRoot(), "service", - "xla_aot_compile_test_gpu_executable_gemm_runtime_autotuning"); - std::string serialized_aot_result; - TF_ASSERT_OK( - tsl::ReadFileToString(tsl::Env::Default(), path, &serialized_aot_result)); - - // Check that runtime autotuning is enabled. - EXPECT_TRUE(absl::StrContains(serialized_aot_result, "algorithm = -5 : i64")); - - // Get a LocalClient - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - PlatformUtil::GetPlatform("CUDA")); - ASSERT_GT(platform->VisibleDeviceCount(), 0); - - LocalClientOptions local_client_options; - local_client_options.set_platform(platform); - TF_ASSERT_OK_AND_ASSIGN( - LocalClient * client, - ClientLibrary::GetOrCreateLocalClient(local_client_options)); - - // Load from AOT result. - ExecutableBuildOptions executable_build_options; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr local_executable, - client->Load(serialized_aot_result, executable_build_options)); - - // Run loaded excutable. - Literal input1 = LiteralUtil::CreateR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); - Literal input2 = LiteralUtil::CreateR2( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); - - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array1, - client->LiteralToShapedBuffer(input1, client->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array2, - client->LiteralToShapedBuffer(input2, client->default_device_ordinal())); - - ExecutableRunOptions executable_run_options; - executable_run_options.set_allocator(client->backend().memory_allocator()); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer result, - local_executable->Run({&array1, &array2}, executable_run_options)); - - TF_ASSERT_OK_AND_ASSIGN(Literal output, - client->ShapedBufferToLiteral(result)); - Literal expected = LiteralUtil::CreateR2( - {{30.0f, 36.0f, 42.0f}, {66.0, 81.0, 96.0}, {102.0, 126.0, 150.0}}); - EXPECT_EQ(expected, output); -} - +// Should also cover the case of loading a GPU executable with a GEMM. TEST(XlaCompileTest, LoadGpuExecutableWithConvolution) { std::string path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", @@ -244,7 +143,7 @@ TEST(XlaCompileTest, LoadGpuExecutableWithConvolution) { tsl::ReadFileToString(tsl::Env::Default(), path, &serialized_aot_result)); // Check that GpuConvAlgorithmPicker successfully loaded autotune results. - EXPECT_TRUE(absl::StrContains(serialized_aot_result, "\"algo_id\":\"3\"")) + EXPECT_TRUE(absl::StrContains(serialized_aot_result, "\"algo_id\":\"28\"")) << serialized_aot_result; // Get a LocalClient @@ -252,69 +151,6 @@ TEST(XlaCompileTest, LoadGpuExecutableWithConvolution) { PlatformUtil::GetPlatform("CUDA")); ASSERT_GT(platform->VisibleDeviceCount(), 0); - LocalClientOptions local_client_options; - local_client_options.set_platform(platform); - TF_ASSERT_OK_AND_ASSIGN( - LocalClient * client, - ClientLibrary::GetOrCreateLocalClient(local_client_options)); - - // Load from AOT result. - ExecutableBuildOptions executable_build_options; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr local_executable, - client->Load(serialized_aot_result, executable_build_options)); - - // Run loaded excutable. - Literal input1 = LiteralUtil::CreateR4( - {{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}}, - {{11.0, 12.0}, {13.0, 14.0}, {15.0, 16.0}, {17.0, 18.0}}, - {{21.0, 22.0}, {23.0, 24.0}, {25.0, 26.0}, {27.0, 28.0}}, - {{31.0, 32.0}, {33.0, 34.0}, {35.0, 36.0}, {37.0, 38.0}}}}); - Literal input2 = - LiteralUtil::CreateR4({{{{1.0}, {2.0}}, {{3.0}, {4.0}}}, - {{{5.0}, {6.0}}, {{7.0}, {8.0}}}, - {{{9.0}, {10.0}}, {{11.0}, {12.0}}}}); - - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array1, - client->LiteralToShapedBuffer(input1, client->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer array2, - client->LiteralToShapedBuffer(input2, client->default_device_ordinal())); - - ExecutableRunOptions executable_run_options; - executable_run_options.set_allocator(client->backend().memory_allocator()); - TF_ASSERT_OK_AND_ASSIGN( - ScopedShapedBuffer result, - local_executable->Run({&array1, &array2}, executable_run_options)); - - TF_ASSERT_OK_AND_ASSIGN(Literal output, - client->ShapedBufferToLiteral(result)); - Literal expected = LiteralUtil::CreateR4({{ - {{1310.0}, {1466.0}, {1622.0}}, - {{2090.0}, {2246.0}, {2402.0}}, - }}); - EXPECT_EQ(expected, output); -} - -// Run an AOT compiled executable in which the algorithm of convolution is set -// to -1. -TEST(XlaCompileTest, LoadGpuExecutableWithConvolutionRuntimeAutotuning) { - std::string path = tsl::io::JoinPath( - tsl::testing::XlaSrcRoot(), "service", - "xla_aot_compile_test_gpu_executable_convolution_runtime_autotuning"); - std::string serialized_aot_result; - TF_ASSERT_OK( - tsl::ReadFileToString(tsl::Env::Default(), path, &serialized_aot_result)); - - // Check that runtime autotuning is enabled. - EXPECT_TRUE(absl::StrContains(serialized_aot_result, "algorithm = -1")); - - // Get a LocalClient - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - PlatformUtil::GetPlatform("CUDA")); - ASSERT_GT(platform->VisibleDeviceCount(), 0); - LocalClientOptions local_client_options; local_client_options.set_platform(platform); TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/service/xla_aot_compile_stablehlo_cpu_test.cc b/xla/service/xla_aot_compile_stablehlo_cpu_test.cc index 0bb6534c2ebc6..7526cd401c71c 100644 --- a/xla/service/xla_aot_compile_stablehlo_cpu_test.cc +++ b/xla/service/xla_aot_compile_stablehlo_cpu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,14 @@ limitations under the License. #include #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/executable_run_options.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" +#include "xla/service/shaped_buffer.h" #include "xla/tests/literal_test_util.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" diff --git a/xla/service/xla_aot_compile_stablehlo_test.mlir b/xla/service/xla_aot_compile_stablehlo_test.mlir index 9925ef2a232ce..cc409ec78737b 100644 --- a/xla/service/xla_aot_compile_stablehlo_test.mlir +++ b/xla/service/xla_aot_compile_stablehlo_test.mlir @@ -1,7 +1,7 @@ module @axpy { func.func public @main(%alpha: tensor, %x: tensor<4 x f32>, %y: tensor<4 x f32>) -> tensor<4 x f32> { %a = "stablehlo.broadcast_in_dim" (%alpha) { - broadcast_dimensions = dense<[]> : tensor<0 x i64> + broadcast_dimensions = array } : (tensor) -> tensor<4 x f32> %ax = stablehlo.multiply %a, %x : tensor<4 x f32> %result = stablehlo.add %ax, %y : tensor<4 x f32> diff --git a/xla/service/xla_aot_compile_test_autotune_results.prototxt b/xla/service/xla_aot_compile_test_autotune_results.prototxt index f6b15b1d5f799..592ea2a9e185f 100644 --- a/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -1,4 +1,4 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -version: 2 +version: 3 results { device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" - hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false}" + hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}" result { gemm { algorithm: 13 @@ -24,13 +24,22 @@ results { } results { device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" - hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[2,1,4,4]{3,2,1,0}, f32[2,1,3,2]{3,2,1,0}), window={size=2x3}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config={\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0}" + hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false}" result { run_time { - nanos: 45408 + nanos: 8192 } algorithm { - algo_id: 3 + algo_id: 28 + tuning_knobs { + key: 2 + value: 4 + } + tuning_knobs { + key: 3 + value: 0 + } + is_cudnn_frontend: true workspace_size { } } diff --git a/xla/service/xla_aot_compile_test_gemm.mlir b/xla/service/xla_aot_compile_test_gemm.mlir deleted file mode 100644 index 07fb00e7e50e0..0000000000000 --- a/xla/service/xla_aot_compile_test_gemm.mlir +++ /dev/null @@ -1,12 +0,0 @@ -module @foo { - func.func public @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #mhlo.dot< - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [0] - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> - } -} \ No newline at end of file diff --git a/xla/service/xla_aot_compile_test_gpu_target_config.prototxt b/xla/service/xla_aot_compile_test_gpu_target_config.prototxt index a8c387bd47586..4ac588f89ff29 100644 --- a/xla/service/xla_aot_compile_test_gpu_target_config.prototxt +++ b/xla/service/xla_aot_compile_test_gpu_target_config.prototxt @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,4 +36,4 @@ dnn_version_info { minor: 3 patch: 2 } -device_description_str: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" \ No newline at end of file +device_description_str: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" diff --git a/xla/service/xla_compile_main.cc b/xla/service/xla_compile_main.cc index f35c308221556..761412c554d23 100644 --- a/xla/service/xla_compile_main.cc +++ b/xla/service/xla_compile_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,49 +14,17 @@ limitations under the License. ==============================================================================*/ #include -#include -#include #include -#include -#include #include -#include "absl/cleanup/cleanup.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "xla/autotune_results.pb.h" -#include "xla/debug_options_flags.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/pjrt/mlir_to_hlo.h" -#include "xla/service/compiler.h" -#include "xla/service/export_hlo.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/symbol_repository.h" -#include "xla/service/xla_compile_result.pb.h" -#include "xla/statusor.h" -#include "xla/tools/hlo_module_loader.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/status.h" #include "xla/tools/xla_compile_lib.h" -#include "xla/util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" -#include "tsl/platform/path.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/types.h" -#include "tsl/util/command_line_flags.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/gpu_symbol_repository.h" -#endif namespace xla { namespace xla_compile { @@ -72,7 +40,7 @@ const char kUsageHeader[] = "\n" "For GPU, either the attached GPU or a simulated one may be used. To use " "a simulated device, set --gpu_target_config to a textproto file " - "containing a GpuTargetConfigProto forthe device you wish to simulate. To " + "containing a GpuTargetConfigProto for the device you wish to simulate. To " "use the attached GPU, do not set this flag. When compiling with the " "attached device, --output_file will contain a text-format HLO module " "instead of an AotCompilationResult." @@ -83,129 +51,6 @@ const char kUsageHeader[] = "understood by that repository." "\n"; -xla::StatusOr> LoadModule( - const std::string& module_path) { - auto format = std::string(tsl::io::Extension(module_path)); - if (format == "hlo" || format == "txt") { - return LoadModuleFromFile( - module_path, hlo_module_loader_details::Config(), - /*format=*/"hlo", [&](HloModuleConfig* c) {}, nullptr); - } - std::string module_string; - TF_RETURN_IF_ERROR( - tsl::ReadFileToString(tsl::Env::Default(), module_path, &module_string)); - - mlir::DialectRegistry dialects; - // TODO(b/248362914): Register all required dialects. - dialects.insert(); - dialects.insert(); - dialects.insert(); - mlir::stablehlo::registerAllDialects(dialects); - - // Parse MHLO module. - auto threading = mlir::MLIRContext::Threading::DISABLED; - auto ctx = std::make_unique(dialects, threading); - mlir::OwningOpRef module = - mlir::parseSourceString(module_string, ctx.get()); - - // Convert Mhlo to Hlo Module. - XlaComputation xla_computation; - TF_RETURN_IF_ERROR( - MlirToXlaComputation(*module, xla_computation, false, false)); - HloModuleProto hlo_module_proto = xla_computation.proto(); - - TF_ASSIGN_OR_RETURN(ProgramShape shape, xla_computation.GetProgramShape()); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - HloModuleConfig config(shape); - config.set_debug_options(debug_options); - return HloModule::CreateFromProto(hlo_module_proto, config); -} - -Status XlaCompileMain( - const std::string& module_path, const std::string& output_path, - const std::string& platform, const std::string& gpu_target_config_path, - const std::string& autotune_results_path, const std::string& symbol_repo, - const std::string& symbol_id, const bool use_attached_device, - const bool wait_for_uploads, const std::string& result_output_file) { - std::unique_ptr hlo_module; - std::unique_ptr target_config; - if (!symbol_id.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr mod, - LookupSymbolInRepository(symbol_repo, symbol_id, BackendType::kGpu)); - if (mod == nullptr) { - return absl::NotFoundError( - absl::StrCat("Could not find ", symbol_id, " in ", symbol_repo)); - } -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (auto* data = static_cast( - mod->backend_specific_data.get()); - data != nullptr) { - target_config = std::move(mod->target_config); - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - hlo_module = std::move(mod->hlo_module); - } else { - TF_ASSIGN_OR_RETURN(hlo_module, LoadModule(module_path)); - } - - xla::TimerStats stats; - xla::ScopedLoggingTimer timer("compilation", true, "xla_compile_main.cc", 1, - &stats); - CompilationResult compilation_result; - absl::Cleanup cleanup([&] { - // Make sure we stop the timer if compilation failed. - timer.StopAndLog(); - if (!result_output_file.empty()) { - TF_QCHECK_OK( - WriteResultFile(result_output_file, stats, compilation_result)); - } - }); - // Run AOT compilation. - std::optional cfg = std::nullopt; - if (platform == "gpu") { - if (!gpu_target_config_path.empty()) { - // Parse GpuTargetConfig. - std::string gpu_target_config_string; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), - gpu_target_config_path, - &gpu_target_config_string)); - stream_executor::GpuTargetConfigProto gpu_target_config_proto; - - if (!tsl::protobuf::TextFormat::ParseFromString( - gpu_target_config_string, &gpu_target_config_proto)) { - return FailedPrecondition("Failed to parse GpuTargetConfigProto"); - } - - target_config = - std::make_unique(gpu_target_config_proto); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - if (!autotune_results_path.empty()) { - TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile( - autotune_results_path)); - } -#endif - } - - cfg = (use_attached_device) ? std::nullopt - : std::make_optional(*std::move(target_config)); - } - auto result = CompileExecutable(std::move(hlo_module), platform, cfg); - if (!result.ok()) { - *compilation_result.mutable_status() = tsl::StatusToProto(result.status()); - return result.status(); - } - - TF_RETURN_IF_ERROR( - tsl::WriteStringToFile(tsl::Env::Default(), output_path, *result)); - - if (wait_for_uploads) { - MaybeWaitForUploads(); - } - return OkStatus(); -} - } // end namespace xla_compile } // end namespace xla @@ -265,7 +110,7 @@ int main(int argc, char* argv[]) { tsl::port::InitMain(usage.c_str(), &argc, &argv); - xla::Status result = xla::xla_compile::XlaCompileMain( + absl::Status result = xla::XlaCompileMain( module_path, output_path, platform, gpu_target_config_path, autotune_results_path, symbol_repository, symbol_id, use_attached_device, wait_for_uploads, result_output_file); diff --git a/xla/service/xla_compile_result.proto b/xla/service/xla_compile_result.proto index b6b102b740da3..ed5982e270f9b 100644 --- a/xla/service/xla_compile_result.proto +++ b/xla/service/xla_compile_result.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/service/xla_debug_info_manager.cc b/xla/service/xla_debug_info_manager.cc index 92147af717386..b6d5e5ff90d13 100644 --- a/xla/service/xla_debug_info_manager.cc +++ b/xla/service/xla_debug_info_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,10 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_proto_util.h" namespace xla { diff --git a/xla/service/xla_debug_info_manager.h b/xla/service/xla_debug_info_manager.h index 7d7cf094911e3..0d3ce1ca18a42 100644 --- a/xla/service/xla_debug_info_manager.h +++ b/xla/service/xla_debug_info_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" #include "tsl/platform/status.h" diff --git a/xla/service/xla_debug_info_manager_test.cc b/xla/service/xla_debug_info_manager_test.cc index a33e6ae5b6ed6..f3aaa7c47a41f 100644 --- a/xla/service/xla_debug_info_manager_test.cc +++ b/xla/service/xla_debug_info_manager_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,13 @@ limitations under the License. #include #include +#include +#include #include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/xla/service/zero_sized_hlo_elimination.cc b/xla/service/zero_sized_hlo_elimination.cc index cf49f49a6cbdb..22c7cf0f00beb 100644 --- a/xla/service/zero_sized_hlo_elimination.cc +++ b/xla/service/zero_sized_hlo_elimination.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,22 @@ limitations under the License. #include "xla/service/zero_sized_hlo_elimination.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" +#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace xla { -StatusOr ZeroSizedHloElimination::Run( +absl::StatusOr ZeroSizedHloElimination::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/xla/service/zero_sized_hlo_elimination.h b/xla/service/zero_sized_hlo_elimination.h index 4d2ac76299b66..f45b9134f6d34 100644 --- a/xla/service/zero_sized_hlo_elimination.h +++ b/xla/service/zero_sized_hlo_elimination.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,18 @@ limitations under the License. #ifndef XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ #define XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" // HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloModulePass { public: using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; absl::string_view name() const override { diff --git a/xla/service/zero_sized_hlo_elimination_test.cc b/xla/service/zero_sized_hlo_elimination_test.cc index eb06b07fd2392..9da305fb978cd 100644 --- a/xla/service/zero_sized_hlo_elimination_test.cc +++ b/xla/service/zero_sized_hlo_elimination_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,17 +20,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal.h" -#include "xla/service/shape_inference.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -43,7 +41,7 @@ class ZeroSizedHloEliminationTest : public HloTestBase { builder_.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} - StatusOr RunZeroSizedElimination() { + absl::StatusOr RunZeroSizedElimination() { auto module = CreateNewVerifiedModule("zero_sized_elimination_test_module"); module->AddEntryComputation(builder_.Build()); return ZeroSizedHloElimination{}.Run(module.get()); diff --git a/xla/service_interface.h b/xla/service_interface.h index 6186bae11bbc7..c1bf58dc5c2a2 100644 --- a/xla/service_interface.h +++ b/xla/service_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape.cc b/xla/shape.cc index 0ad4897f320e9..7831898b4a1c0 100644 --- a/xla/shape.cc +++ b/xla/shape.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ Shape::~Shape() = default; Shape::Shape(const Shape&) = default; Shape::Shape(Shape&&) = default; Shape& Shape::operator=(const Shape&) = default; +Shape& Shape::operator=(Shape&&) = default; Shape::Shape(const ShapeProto& shape_proto) { set_element_type(shape_proto.element_type()); @@ -116,37 +117,43 @@ std::string Shape::ToString(bool print_layout) const { } bool Shape::IsInteger() const { - if (primitive_util::IsIntegralType(element_type())) { - return true; - } if (IsTuple()) { - return absl::c_any_of(tuple_shapes_, + return absl::c_all_of(tuple_shapes_, [](const Shape& s) { return s.IsInteger(); }); } - return false; + return primitive_util::IsIntegralType(element_type()); } bool Shape::is_static() const { if (IsTuple()) { - for (const Shape& subshape : tuple_shapes_) { - if (!subshape.is_static()) { - return false; - } - } + return absl::c_all_of(tuple_shapes_, + [](const Shape& s) { return s.is_static(); }); } return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } bool Shape::is_unbounded_dynamic() const { - if (IsTuple() && absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { - return subshape.is_unbounded_dynamic(); - })) { - return true; + if (IsTuple()) { + return absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + return subshape.is_unbounded_dynamic(); + }); } return absl::c_any_of(dimensions_, [](int64_t dim) { return dim == kUnboundedSize; }); } +bool Shape::is_bounded_dynamic() const { + if (IsTuple()) { + return absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + return subshape.is_bounded_dynamic(); + }); + } + for (auto i = 0; i < dimensions_.size(); ++i) { + if (is_bounded_dynamic_dimension(i)) return true; + } + return false; +} + void Shape::DeleteDimension(int64_t dim_to_delete) { CHECK(IsArray()); CHECK_GE(dim_to_delete, 0); @@ -159,7 +166,7 @@ void Shape::DeleteDimension(int64_t dim_to_delete) { } const Shape& Shape::tuple_shapes(int index) const { - return tuple_shapes_.at(index); + return tuple_shapes_[index]; } Shape* Shape::add_tuple_shapes() { @@ -193,10 +200,21 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { } if (!ignore_dimensions_) { - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + if (!ShapeUtil::SameRank(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs rank != rhs rank"; return false; } + for (int i = 0; i < lhs.rank(); ++i) { + if (ignore_dynamic_dimension_ && + (lhs.is_unbounded_dynamic_dimension(i) || + rhs.is_unbounded_dynamic_dimension(i))) { + continue; + } + if (lhs.dimensions(i) != rhs.dimensions(i)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + } } else { if (!ShapeUtil::SameRank(lhs, rhs)) { VLOG(3) << "CompareShapes: lhs rank != rhs rank"; @@ -221,6 +239,9 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { if (ignore_memory_space_in_layout_) { equal.IgnoreMemorySpace(); } + if (ignore_tail_padding_alignment_in_elements_in_layout_) { + equal.IgnoreTailPaddingAlignmentInElements(); + } if (!equal(lhs.layout(), rhs.layout())) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; @@ -251,6 +272,7 @@ ProgramShape::~ProgramShape() = default; ProgramShape::ProgramShape(const ProgramShape&) = default; ProgramShape::ProgramShape(ProgramShape&&) = default; ProgramShape& ProgramShape::operator=(const ProgramShape&) = default; +ProgramShape& ProgramShape::operator=(ProgramShape&&) = default; ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) { for (const ShapeProto& shape_proto : program_shape_proto.parameters()) { diff --git a/xla/shape.h b/xla/shape.h index b2828b429e586..d9d6b63a11def 100644 --- a/xla/shape.h +++ b/xla/shape.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ class Shape { Shape(const Shape&); Shape(Shape&&); Shape& operator=(const Shape&); + Shape& operator=(Shape&&); // Construct a shape from a ShapeProto. explicit Shape(const ShapeProto& shape_proto); @@ -99,23 +100,40 @@ class Shape { static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); // Returns true if the shape has one or more dimensions with unbounded sizes. - // Tuple shapes are traversed recursively. + // Tuple shapes are traversed recursively, returns true if any element is + // unbounded dynamic. bool is_unbounded_dynamic() const; // Returns true if the given dimension is unbounded dynamic. bool is_unbounded_dynamic_dimension(int dimension) const { - return dimensions_.at(dimension) == kUnboundedSize; + return dimensions_[dimension] == kUnboundedSize; } // Sets a given dimension as unbounded dynamic. void set_unbounded_dynamic_dimension(int dimension) { dynamic_dimensions_[dimension] = true; - dimensions_.at(dimension) = kUnboundedSize; + dimensions_[dimension] = kUnboundedSize; + } + + // Returns true if the shape has one or more dimensions with bounded sizes. + // Tuple shapes are traversed recursively, returns true if any element is + // bounded dynamic. + bool is_bounded_dynamic() const; + + // Returns true if the given dimension is bounded dynamic. + bool is_bounded_dynamic_dimension(int dimension) const { + return is_dynamic_dimension(dimension) && + !is_unbounded_dynamic_dimension(dimension); } // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { - return dynamic_dimensions_.at(dimension); + return dynamic_dimensions_[dimension]; + } + + // Returns true if the given dimension is statically-sized. + bool is_static_dimension(int dimension) const { + return !dynamic_dimensions_[dimension]; } // Sets whether or not the given dimension is dynamically-sized. @@ -149,18 +167,16 @@ class Shape { // Methods for accessing the dimensions array. int dimensions_size() const { return dimensions_.size(); } - int64_t dimensions(int index) const { return dimensions_.at(index); } + int64_t dimensions(int index) const { return dimensions_[index]; } int64_t dimensions_minor(int index) const { CHECK(has_layout()); - return dimensions_.at(layout_->minor_to_major(index)); - } - void set_dimensions(int index, int64_t value) { - dimensions_.at(index) = value; + return dimensions_[layout_->minor_to_major(index)]; } + void set_dimensions(int index, int64_t value) { dimensions_[index] = value; } void set_dimensions_minor(int index, int64_t value) { CHECK(has_layout()); - dimensions_.at(layout_->minor_to_major(index)) = value; + dimensions_[layout_->minor_to_major(index)] = value; } void add_dimensions(int64_t value) { dimensions_.push_back(value); @@ -179,7 +195,7 @@ class Shape { // tuple shapes. int tuple_shapes_size() const { return tuple_shapes_.size(); } const Shape& tuple_shapes(int index) const; - Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } + Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_[index]; } Shape* add_tuple_shapes(); void clear_tuple_shapes() { tuple_shapes_.clear(); } const std::vector& tuple_shapes() const { return tuple_shapes_; } @@ -200,7 +216,8 @@ class Shape { } void clear_layout() { layout_ = std::nullopt; } - // Recursively clear dynamic dimension of a shape. + // Recursively clear all dynamic dimension of a shape, including bounded and + // unbounded dynamic dimensions. void clear_dynamic_dimensions() { if (!IsTuple()) { if (is_dynamic()) { @@ -269,6 +286,7 @@ class Shape { ignore_tiles_in_layout_ = true; ignore_element_size_in_layout_ = true; ignore_memory_space_in_layout_ = true; + ignore_tail_padding_alignment_in_elements_in_layout_ = true; return *this; } Equal& IgnoreElementType() { @@ -287,6 +305,10 @@ class Shape { ignore_dimensions_ = true; return *this; } + Equal& IgnoreTailPaddingAlignmentInElements() { + ignore_tail_padding_alignment_in_elements_in_layout_ = true; + return *this; + } private: bool ignore_layout_ = false; @@ -297,6 +319,7 @@ class Shape { bool ignore_fp_precision_ = false; bool ignore_dynamic_dimension_ = false; bool ignore_dimensions_ = false; + bool ignore_tail_padding_alignment_in_elements_in_layout_ = false; }; // Test that all fields of the shape are the same, equivalent to Equal(). @@ -353,6 +376,7 @@ class ProgramShape { ProgramShape(const ProgramShape&); ProgramShape(ProgramShape&&); ProgramShape& operator=(const ProgramShape&); + ProgramShape& operator=(ProgramShape&&); // Creates a ProgramShape from a ProgramShapeProto protobuf. explicit ProgramShape(const ProgramShapeProto& program_shape_proto); @@ -372,8 +396,8 @@ class ProgramShape { // Methods for accessing and manipulating the Shape of the parameters. int parameters_size() const { return parameters_.size(); } - const Shape& parameters(int index) const { return parameters_.at(index); } - Shape* mutable_parameters(int index) { return ¶meters_.at(index); } + const Shape& parameters(int index) const { return parameters_[index]; } + Shape* mutable_parameters(int index) { return ¶meters_[index]; } Shape* add_parameters() { parameters_.emplace_back(); return ¶meters_.back(); @@ -389,13 +413,13 @@ class ProgramShape { // Methods for accessing and manipulating the names of the parameters. int parameter_names_size() const { return parameter_names_.size(); } const std::string& parameter_names(int index) const { - return parameter_names_.at(index); + return parameter_names_[index]; } void set_parameter_names(int index, const std::string& value) { - parameter_names_.at(index) = value; + parameter_names_[index] = value; } std::string* mutable_parameter_names(int index) { - return ¶meter_names_.at(index); + return ¶meter_names_[index]; } void add_parameter_names(const std::string& value) { parameter_names_.push_back(value); diff --git a/xla/shape_layout.cc b/xla/shape_layout.cc index 3a4d14e527c02..7969aca2d3563 100644 --- a/xla/shape_layout.cc +++ b/xla/shape_layout.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape_layout.h b/xla/shape_layout.h index 1d203d35deca2..6348e7230a778 100644 --- a/xla/shape_layout.h +++ b/xla/shape_layout.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape_test.cc b/xla/shape_test.cc index 322f02e4773f6..3d23b421718ed 100644 --- a/xla/shape_test.cc +++ b/xla/shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -108,6 +108,26 @@ TEST_F(ShapeTest, EqualityTest) { ShapeUtil::MakeShapeWithDenseLayout(F32, {23, 44}, {1, 0})); } +TEST_F(ShapeTest, IsInteger) { + EXPECT_FALSE(opaque_.IsInteger()); + EXPECT_FALSE(token_.IsInteger()); + EXPECT_TRUE(matrix_.IsInteger()); + EXPECT_FALSE(tuple_.IsInteger()); + EXPECT_FALSE(nested_tuple_.IsInteger()); + + Shape u32_shape = ShapeUtil::MakeShape(U32, {1}); + EXPECT_TRUE(u32_shape.IsInteger()); + + Shape f32_shape = ShapeUtil::MakeShape(F32, {1}); + EXPECT_FALSE(f32_shape.IsInteger()); + + Shape integer_tuple = ShapeUtil::MakeTupleShape({u32_shape, u32_shape}); + EXPECT_TRUE(integer_tuple.IsInteger()); + + Shape mixed_type_tuple = ShapeUtil::MakeTupleShape({u32_shape, f32_shape}); + EXPECT_FALSE(mixed_type_tuple.IsInteger()); +} + TEST_F(ShapeTest, IsStatic) { EXPECT_TRUE(opaque_.is_static()); EXPECT_TRUE(token_.is_static()); @@ -165,6 +185,15 @@ TEST_F(ShapeTest, IsDynamicDimension) { EXPECT_FALSE(unbounded_.is_dynamic_dimension(1)); } +TEST_F(ShapeTest, IsStaticDimension) { + Shape dynamic_matrix = matrix_; + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_TRUE(dynamic_matrix.is_static_dimension(0)); + EXPECT_FALSE(dynamic_matrix.is_static_dimension(1)); + EXPECT_FALSE(unbounded_.is_static_dimension(0)); + EXPECT_TRUE(unbounded_.is_static_dimension(1)); +} + TEST_F(ShapeTest, ProgramShapeToFromProto) { ProgramShape program_shape; *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); diff --git a/xla/shape_tree.cc b/xla/shape_tree.cc index 28b1ffc1c0d27..bc83698a02851 100644 --- a/xla/shape_tree.cc +++ b/xla/shape_tree.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape_tree.h b/xla/shape_tree.h index 5acf9de493e03..4e6a0bf649560 100644 --- a/xla/shape_tree.h +++ b/xla/shape_tree.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape_tree_test.cc b/xla/shape_tree_test.cc index b1aed006598c8..5e29d719eb27d 100644 --- a/xla/shape_tree_test.cc +++ b/xla/shape_tree_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 628b8011dee63..4d7c88588388f 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -116,14 +117,16 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { namespace { // Constructs and returns the new shape with the given minor_to_major order in // its Layout. -StatusOr MakeShapeWithLayoutInternal( +absl::StatusOr MakeShapeWithLayoutInternal( PrimitiveType element_type, absl::Span dimensions, absl::Span minor_to_major, absl::Span dim_level_types, absl::Span dim_unique, absl::Span dim_ordered, - absl::Span tiles, PrimitiveType index_primitive_type, - PrimitiveType pointer_primitive_type, int64_t element_size_in_bits, - int64_t memory_space, std::optional physical_shape) { + absl::Span tiles, int64_t tail_padding_alignment_in_elements, + PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, + int64_t element_size_in_bits, int64_t memory_space, + absl::Span split_configs, + std::optional physical_shape) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); @@ -142,10 +145,11 @@ StatusOr MakeShapeWithLayoutInternal( } *shape.mutable_layout() = LayoutUtil::MakeLayout( minor_to_major, dim_level_types, dim_unique, dim_ordered, tiles, - index_primitive_type, pointer_primitive_type, element_size_in_bits, - memory_space, std::move(physical_shape)); + tail_padding_alignment_in_elements, index_primitive_type, + pointer_primitive_type, element_size_in_bits, memory_space, split_configs, + std::move(physical_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return shape; + return std::move(shape); } template @@ -235,10 +239,7 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { Shape* shape) { int64_t dense_shape_size = primitive_util::IsArrayType(element_type) ? primitive_util::ByteWidth(element_type) - : 0; - if (dense_shape_size <= 0) { - return false; - } + : -1; // Verify that array-based lookup is consistent with public API. DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type)) @@ -248,23 +249,23 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { const int ndims = dimensions.size(); auto layout = shape->mutable_layout(); auto* minor_to_major = layout->mutable_minor_to_major(); - auto is_unbounded_dynamic = absl::c_any_of( - dimensions, [](int64_t dim) { return dim == Shape::kUnboundedSize; }); + int64_t static_extent_product = dense_shape_size; + bool any_overflows = false; for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; - if (d < 0 && d != Shape::kUnboundedSize) { - return false; - } - if (!is_unbounded_dynamic) { - dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); - if (dense_shape_size < 0) { - return false; - } + if (d != Shape::kUnboundedSize) { + bool overflow; + std::tie(static_extent_product, overflow) = + OverflowSafeMultiply(static_extent_product, d); + any_overflows |= overflow; } shape->add_dimensions(d); minor_to_major->push_back(ndims - 1 - i); } + if (any_overflows) { + return false; + } return true; } @@ -303,7 +304,7 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { return output; } -/* static */ StatusOr ShapeUtil::MakeValidatedShape( +/* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { Shape shape; if (!FillNewShape(element_type, dimensions, &shape)) { @@ -311,10 +312,10 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { static_cast(element_type), absl::StrJoin(dimensions, ",")); } - return shape; + return std::move(shape); } -/* static */ StatusOr ShapeUtil::MakeValidatedShape( +/* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions) { if (dynamic_dimensions.size() != dimensions.size()) { @@ -337,19 +338,21 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { "Cannot mark a dynamic dimension at dim=%d as static", i); } } - return shape; + return std::move(shape); } /* static */ Shape ShapeUtil::MakeShapeWithDenseLayout( PrimitiveType element_type, absl::Span dimensions, absl::Span minor_to_major, absl::Span tiles, - int64_t element_size_in_bits, int64_t memory_space) { + int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits, + int64_t memory_space, absl::Span split_configs) { auto ret = MakeShapeWithLayoutInternal( element_type, dimensions, minor_to_major, /*dim_level_types=*/{}, /*dim_unique=*/{}, /*dim_ordered=*/{}, tiles, + tail_padding_alignment_in_elements, /*index_primitive_type=*/PRIMITIVE_TYPE_INVALID, /*pointer_primitive_type=*/PRIMITIVE_TYPE_INVALID, element_size_in_bits, - memory_space, + memory_space, split_configs, /*physical_shape=*/std::nullopt); TF_CHECK_OK(ret.status()); return *ret; @@ -361,12 +364,13 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { absl::Span dim_level_types, absl::Span dim_unique, absl::Span dim_ordered, PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, - int64_t element_size_in_bits, int64_t memory_space, - std::optional physical_shape) { + int64_t tail_padding_alignment_in_elements, int64_t element_size_in_bits, + int64_t memory_space, std::optional physical_shape) { auto ret = MakeShapeWithLayoutInternal( element_type, dimensions, minor_to_major, dim_level_types, dim_unique, - dim_ordered, /*tiles=*/{}, index_primitive_type, pointer_primitive_type, - element_size_in_bits, memory_space, std::move(physical_shape)); + dim_ordered, /*tiles=*/{}, tail_padding_alignment_in_elements, + index_primitive_type, pointer_primitive_type, element_size_in_bits, + memory_space, /*split_configs=*/{}, std::move(physical_shape)); TF_CHECK_OK(ret.status()); return *ret; } @@ -423,6 +427,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( shape.layout().tiles().begin(), shape.layout().tiles().end()); new_shape.mutable_layout()->set_element_size_in_bits( shape.layout().element_size_in_bits()); + new_shape.mutable_layout()->set_tail_padding_alignment_in_elements( + shape.layout().tail_padding_alignment_in_elements()); } for (int i = 0; i < shape.dimensions_size(); ++i) { new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i)); @@ -683,8 +689,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { - return shape.IsArray() && - absl::c_any_of(shape.dimensions(), [](int64_t d) { return d == 0; }); + return shape.IsArray() && absl::c_linear_search(shape.dimensions(), 0); } /* static */ bool ShapeUtil::IsScalarWithElementType( @@ -790,7 +795,16 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { const Shape& rhs) { CHECK(lhs.IsArray()); CHECK(rhs.IsArray()); - return absl::c_equal(lhs.dimensions(), rhs.dimensions()); + if (!SameRank(lhs, rhs)) return false; + for (int i = 0; i < lhs.rank(); ++i) { + if (!lhs.is_unbounded_dynamic_dimension(i) && + !rhs.is_unbounded_dynamic_dimension(i) && + lhs.dimensions(i) != rhs.dimensions(i)) { + return false; + } + } + + return true; } /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) { @@ -898,6 +912,47 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ absl::StatusOr ShapeUtil::SerializedSize( + const Shape& shape) { + return SerializedSizeWithProto(shape, shape.ToProto()); +} + +/* static */ absl::StatusOr ShapeUtil::SerializedSizeWithProto( + const Shape& shape, const ShapeProto& proto) { + // The size computed here must be kept in sync with the serialized format as + // described in the comments for LiteralBase::SerializeWithShapeProto in + // literal.h. + TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayout(shape)); + int64_t size = sizeof(int64_t) + proto.ByteSizeLong(); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, + [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { + if (subshape.IsTuple()) { + return OkStatus(); + } + if (!subshape.IsArray()) { + return InvalidArgument("Shape cannot be serialiized: %s", + shape.ToString()); + } + if (subshape.is_dynamic()) { + size += sizeof(DynamicSizeType) * subshape.rank(); + } + if (subshape.element_type() == PRED) { + // PRED is packed 8 elements per byte. + size += CeilOfRatio(ElementsIn(subshape), 8); + } else if (primitive_util::Is4BitType(subshape.element_type())) { + // 4-bit types are packed 2 elements per byte. + size += CeilOfRatio(ElementsIn(subshape), 2); + } else { + size += ByteSizeOfElements(subshape); + } + return OkStatus(); + })); + + return size; +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == PRIMITIVE_TYPE_INVALID || @@ -921,7 +976,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - // Tokens and opaques can should not have layout or dimensions. + // Tokens and opaques should not have layout or dimensions. if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) { if (shape.dimensions_size() != 0) { return InvalidArgument( @@ -938,13 +993,25 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return OkStatus(); } + bool any_overflows = false; + int64_t product = 1; for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dimension = shape.dimensions(i); - if (dimension < 0 && dimension != Shape::kUnboundedSize) { + if (dimension == Shape::kUnboundedSize) { + continue; + } + if (dimension < 0) { return InvalidArgument( "shape's dimensions must not be < 0; dimension at index %d was %d", i, dimension); } + bool overflow; + std::tie(product, overflow) = OverflowSafeMultiply(product, dimension); + any_overflows |= overflow; + } + if (any_overflows) { + return InvalidArgument("shape's dimensions overflow: %s", + shape.ShortDebugString()); } TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); @@ -958,34 +1025,17 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return OkStatus(); } - if (shape.is_unbounded_dynamic()) { - return OkStatus(); - } + auto [extent_product, extent_overflow] = + ExtentProduct(shape); + auto [dense_shape_size, byte_width_overflow] = OverflowSafeMultiply( + extent_product, ByteSizeOfPrimitiveType(shape.element_type())); - int64_t shape_size = [&]() { - int64_t dense_shape_size = 1; - if (shape.dimensions().empty()) { - return dense_shape_size; - } - - absl::Span shape_max_dimensions = shape.dimensions(); - for (int64_t dim : shape_max_dimensions) { - dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); - if (dense_shape_size < 0) { - return dense_shape_size; - } - } - dense_shape_size = MultiplyWithoutOverflow( - dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type())); - return dense_shape_size; - }(); - - if (shape_size < 0) { + if (extent_overflow || byte_width_overflow) { return InvalidArgument("Shape %s size may overflow int64_t.", ShapeUtil::HumanString(shape)); } - VLOG(3) << "Shape size is valid: " << shape_size; + VLOG(3) << "Shape size is valid: " << dense_shape_size; return OkStatus(); } @@ -1042,7 +1092,16 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return *return_shape; } -/* static */ StatusOr ShapeUtil::TryGetSubshape( +/* static */ const Shape& ShapeUtil::GetSubshapeOneIndex(const Shape& shape, + int64_t index) { + const Shape* return_shape = &shape; + CHECK(return_shape->IsTuple()) + << "Invalid index " << index << " for shape " << shape; + return_shape = &return_shape->tuple_shapes(index); + return *return_shape; +} + +/* static */ absl::StatusOr ShapeUtil::TryGetSubshape( const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { @@ -1072,17 +1131,26 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !GetSubshape(shape, index).IsTuple(); } -/* static */ int64_t ShapeUtil::GetLeafCount(const Shape& shape) { - if (!shape.IsTuple()) { - return 1; - } +/* static */ int64_t ShapeUtil::GetLeafCountTuple(const Shape& shape) { + DCHECK(shape.IsTuple()); int64_t count = 0; for (const Shape& subshape : shape.tuple_shapes()) { - count += GetLeafCount(subshape); + if (subshape.IsTuple()) { + count += GetLeafCount(subshape); + } else { + ++count; + } } return count; } +/* static */ int64_t ShapeUtil::GetLeafCount(const Shape& shape) { + if (!shape.IsTuple()) { + return 1; + } + return GetLeafCountTuple(shape); +} + /* static */ std::vector ShapeUtil::GetLeafShapes( const Shape& shape) { std::vector leaves; @@ -1470,9 +1538,8 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, input_unit_index); // output_index has the same logical linear index as input_unit_index. - std::vector output_index = - IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major, - logical_linear_index); + auto output_index = IndexUtil::LinearIndexToMultidimensionalIndex( + output_shape_dim0_major, logical_linear_index); // Check input_unit_index and output_index have the same physical linear // index. if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape, @@ -1865,7 +1932,7 @@ struct ParallelState { auto indexes_copy = s.indexes; pstate.pool->Schedule([indexes_copy, &visitor_function, &pstate] { const int thread_id = pstate.pool->CurrentThreadId(); - StatusOr result = visitor_function(indexes_copy, thread_id); + absl::StatusOr result = visitor_function(indexes_copy, thread_id); if (!result.ok()) { absl::MutexLock lock(&pstate.mu); if (pstate.status.ok()) { @@ -2052,6 +2119,7 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace); subshape->mutable_layout()->clear_physical_shape(); subshape->mutable_layout()->set_element_size_in_bits(0); + subshape->mutable_layout()->set_tail_padding_alignment_in_elements(1); subshape->mutable_layout()->set_dynamic_shape_metadata_prefix_bytes(0); } }); @@ -2077,6 +2145,16 @@ Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span strides) { return OkStatus(); } +/*static*/ +std::optional> ShapeUtil::ByteStrides( + const Shape& shape) { + absl::InlinedVector strides(shape.dimensions_size()); + if (!ByteStrides(shape, absl::MakeSpan(strides)).ok()) { + return std::nullopt; + } + return strides; +} + /*static*/ int64_t ShapeUtil::ArraySize(const Shape& shape) { CHECK(LayoutUtil::IsDenseArray(shape)); if (shape.layout().tiles().empty()) { @@ -2084,9 +2162,8 @@ Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span strides) { } auto tile_dimensions = shape.layout().tiles(0).dimensions(); - auto shape_dimensions = shape.dimensions(); auto minor_to_major = shape.layout().minor_to_major(); - int64_t shape_dim_size = shape_dimensions.size(); + int64_t shape_dim_size = shape.dimensions().size(); int64_t tile_dim_size = tile_dimensions.size(); // Use the top-level tile for shape size calculation. We assume the @@ -2094,13 +2171,14 @@ Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span strides) { int64_t num_of_elements = 1; int64_t dim = 0; for (dim = 0; dim < tile_dim_size; dim++) { - int64_t dim_size = - dim < shape_dim_size ? shape_dimensions[minor_to_major[dim]] : 1; + int64_t dim_size = dim < shape_dim_size ? LayoutUtil::MaxSplitSize( + shape, minor_to_major[dim]) + : 1; num_of_elements *= RoundUpTo(dim_size, tile_dimensions[tile_dim_size - dim - 1]); } for (; dim < shape_dim_size; dim++) { - int64_t dim_size = shape_dimensions[minor_to_major[dim]]; + int64_t dim_size = LayoutUtil::MaxSplitSize(shape, minor_to_major[dim]); num_of_elements *= dim_size; } @@ -2109,6 +2187,11 @@ Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span strides) { num_of_elements * shape.layout().element_size_in_bits(); return CeilOfRatio(num_bits, CHAR_BIT); } + + if (shape.layout().tail_padding_alignment_in_elements() != 1) { + num_of_elements = RoundUpTo( + num_of_elements, shape.layout().tail_padding_alignment_in_elements()); + } return num_of_elements * ByteSizeOfPrimitiveType(shape.element_type()); } diff --git a/xla/shape_util.h b/xla/shape_util.h index 897aa29d525fd..548d39961b9ce 100644 --- a/xla/shape_util.h +++ b/xla/shape_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/overflow_util.h" #include "xla/primitive_util.h" #include "xla/printer.h" #include "xla/shape.h" @@ -97,6 +99,8 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + using DynamicSizeType = int32_t; + // Data structure which describes the coordinates and the shape, of a tuple // shaped sub-shape. struct IndexedShape { @@ -107,18 +111,44 @@ class ShapeUtil { Shape shape; }; + // Returns the product of the statically bound dimensions. + template + static inline std::pair ExtentProduct(const Shape& shape) { + DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), shape.rank()); + int64_t product = 1; + bool any_overflows = false; + for (int dim = 0; dim < shape.dimensions_size(); ++dim) { + if constexpr (kBoundedDynamicOk) { + if (shape.is_unbounded_dynamic_dimension(dim)) { + continue; + } + } else { + DCHECK(!shape.is_unbounded_dynamic_dimension(dim)); + } + bool overflow; + std::tie(product, overflow) = + OverflowSafeMultiply(product, shape.dimensions(dim)); + any_overflows |= overflow; + } + return {product, any_overflows}; + } + + // Returns the product of the statically bound dimensions. + static inline int64_t StaticExtentProduct(const Shape& shape) { + auto [product, overflow] = ExtentProduct(shape); + DCHECK(!overflow); + return product; + } + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. // Precondition: shape.IsArray() static inline int64_t ElementsIn(const Shape& shape) { - DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); - DCHECK_EQ(shape.dimensions_size(), shape.rank()); - if (shape.dimensions().empty()) { - return 1LL; - } - auto begin = shape.dimensions().begin(); - return std::accumulate(std::next(begin), shape.dimensions().end(), *begin, - std::multiplies()); + auto [product, overflow] = + ExtentProduct(shape); + DCHECK(!overflow); + return product; } // As ElementsIn(), but recurses through tuples. @@ -156,6 +186,18 @@ class ShapeUtil { // size also includes padding if present in the layout. static int64_t ByteSizeOfElements(const Shape& shape); + // Returns the size in bytes for the serialized form of this shape. + // This serialized size includes the header of the serialized format, and so + // should not be used for subshapes. Use SerializedSizeOfData for that + // purpose. + static absl::StatusOr SerializedSize(const Shape& shape); + + // As above, but assumes the given ShapeProto is the result of + // shape.ToProto(). This can be used to avoid converting the shape to a + // protobuf multiple times. + static absl::StatusOr SerializedSizeWithProto( + const Shape& shape, const ShapeProto& proto); + // Prints a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". static void PrintHumanString(xla::Printer* printer, const Shape& shape); @@ -178,8 +220,8 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static std::string HumanString(const ProgramShape& program_shape); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does - // not check element type. + // Returns whether the LHS and RHS shapes have the same dimensions, ignoring + // the unbounded dimension sizes; note: does not check element type. // Precondition: IsArray(lhs) && IsArray(rhs) static bool SameDimensions(const Shape& lhs, const Shape& rhs); @@ -375,9 +417,9 @@ class ShapeUtil { // dimensions. Method checks if the element type is valid, the shape's // size fits in std::numeric_limits::max(), and dynamic size is not // marked static. - static StatusOr MakeValidatedShape( + static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions); - static StatusOr MakeValidatedShape( + static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, const std::vector& dynamic_dimensions); @@ -394,8 +436,10 @@ class ShapeUtil { static Shape MakeShapeWithDenseLayout( PrimitiveType element_type, absl::Span dimensions, absl::Span minor_to_major, - absl::Span tiles = {}, int64_t element_size_in_bits = 0, - int64_t memory_space = 0); + absl::Span tiles = {}, + int64_t tail_padding_alignment_in_elements = 1, + int64_t element_size_in_bits = 0, int64_t memory_space = 0, + absl::Span split_configs = {}); // Constructs a new sparse array shape with the given minor_to_major order and // dim_level_types in its Layout. Returns a value shape such that @@ -408,6 +452,7 @@ class ShapeUtil { absl::Span dim_ordered = {}, PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID, + int64_t tail_padding_alignment_in_elements = 1, int64_t element_size_in_bits = 0, int64_t memory_space = 0, std::optional physical_shape = std::nullopt); @@ -505,8 +550,12 @@ class ShapeUtil { // the given Shape argument. The non-Try variants check fail if index is // invalid. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); - static StatusOr TryGetSubshape(const Shape& shape, - ShapeIndexView index); + + // Faster version for one index. + static const Shape& GetSubshapeOneIndex(const Shape& shape, int64_t index); + + static absl::StatusOr TryGetSubshape(const Shape& shape, + ShapeIndexView index); static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); // Returns whether the given index in the given shape is a leaf element of the @@ -515,6 +564,7 @@ class ShapeUtil { // Returns the number of leaves in the shape. static int64_t GetLeafCount(const Shape& shape); + static int64_t GetLeafCountTuple(const Shape& shape); // Retrieves all the leaf shapes and their indexes, in the order walked by // the ForEachSubshape() API. @@ -545,6 +595,23 @@ class ShapeUtil { }).IgnoreError(); } + // Calls the given visitor function for each leaf subshape of the given shape. + // Subshapes are visited in DFS pre-order starting with the entire shape + // (index {}). + // + // The visitor function must have the signature + // + // void fn(const Shape& subshape, const ShapeIndex& index) + template + static void ForEachLeafShape(const Shape& shape, Fn&& fn) { + ForEachSubshape(shape, + [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + fn(sub_shape, index); + } + }); + } + // Variants of ForEach(Mutable)Subshape which propagate Status from the // visitor function. // @@ -804,7 +871,7 @@ class ShapeUtil { const xla::Shape& bounded_shape); using ForEachVisitorFunction = - absl::FunctionRef(absl::Span)>; + absl::FunctionRef(absl::Span)>; using ForEachVisitorFunctionNoStatus = absl::FunctionRef)>; @@ -869,12 +936,12 @@ class ShapeUtil { static void ForEachIndex(const Shape& shape, const ForEachVisitorFunction& visitor_function) { ForEachIndexWithStatus(shape, [&](absl::Span indices) { - return StatusOr(visitor_function(indices)); + return absl::StatusOr(visitor_function(indices)); }).IgnoreError(); } using ForEachParallelVisitorFunction = - absl::FunctionRef(absl::Span, int)>; + absl::FunctionRef(absl::Span, int)>; // A parallel version of ForEachIndex(WithStatus). This can only be used if // the visitor_function is thread-safe and the order of iteration does not @@ -946,6 +1013,9 @@ class ShapeUtil { // layout. Ignores tiling. `strides` must have size equal to the number of // dimensions of `shape`. static Status ByteStrides(const Shape& shape, absl::Span strides); + // Same as above but returns the stride array, or std::nullopt if error. + static std::optional> ByteStrides( + const Shape& shape); // Returns the array size in bytes (layout/tiling required), all paddings are // included. @@ -956,7 +1026,7 @@ class ShapeUtil { static int64_t ArrayDataSize(const Shape& shape); private: - // Fills *shape. Returns true on success. + // Fills *shape ignoring dynamic dimensions. Returns true on success. // REQUIRES: *shape is empty. static bool FillNewShape(PrimitiveType element_type, absl::Span dimensions, Shape* shape); @@ -1060,7 +1130,7 @@ inline ShapeUtil::ForEachState::ForEachState(const Shape& s, minor_to_major(shape.layout().minor_to_major().data()), rank(LayoutUtil::MinorToMajor(shape).size()), indexes(b.begin(), b.end()), - indexes_ptr((rank == 0) ? nullptr : &indexes[0]), + indexes_ptr((rank == 0) ? nullptr : indexes.data()), indexes_span(indexes) { CHECK_EQ(shape.rank(), b.size()); CHECK_EQ(i.size(), b.size()); diff --git a/xla/shape_util_test.cc b/xla/shape_util_test.cc index f18eb0bbcc094..cee69daab1ee4 100644 --- a/xla/shape_util_test.cc +++ b/xla/shape_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -216,6 +216,9 @@ TEST(ShapeUtilTest, EqualDynamicShapes) { EXPECT_FALSE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), ShapeUtil::MakeShape(F32, {4, 3}, {false, false}))); + EXPECT_FALSE(ShapeUtil::Equal( + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize}, {true}), + ShapeUtil::MakeShape(F32, {2}, {true}))); } TEST(ShapeUtilTest, CompatibleDynamicShapes) { @@ -335,9 +338,14 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); +} + +TEST(ShapeUtilTest, ByteStrides) { + Shape shape1 = ShapeUtil::MakeShape(F32, {3, 5, 7}); + Shape shape2 = ShapeUtil::MakeShape(F16, {5, 7, 9}); - EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); - EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); + EXPECT_THAT(*ShapeUtil::ByteStrides(shape1), ElementsAre(140, 28, 4)); + EXPECT_THAT(*ShapeUtil::ByteStrides(shape2), ElementsAre(126, 18, 2)); } TEST(ShapeUtilTest, NilShape) { @@ -628,7 +636,8 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) { // Increments at every invocation. int invocations = 0; auto increment_func = - [&invocations](absl::Span indexes) -> StatusOr { + [&invocations]( + absl::Span indexes) -> absl::StatusOr { if (++invocations == 5) { return Unimplemented("Cannot increment beyond 5."); } @@ -650,7 +659,7 @@ TEST(ShapeUtilTest, GetForEachIndexParallelThreadCount) { Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); auto check_func = [kThreadCount](absl::Span /*indexes*/, - int thread_id) -> StatusOr { + int thread_id) -> absl::StatusOr { EXPECT_GE(thread_id, -1); EXPECT_LT(thread_id, kThreadCount); return true; @@ -667,7 +676,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) { int64_t output[10][10]; int init = 5; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; return true; }; @@ -685,7 +694,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel_Rank0) { Shape shape = ShapeUtil::MakeShape(F32, {}); int64_t output = -1; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output = indexes.size(); return true; }; @@ -700,7 +709,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel_Empty) { Shape shape = ShapeUtil::MakeShape(F32, {2, 0}); bool called = false; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { called = true; return true; }; @@ -719,7 +728,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel_DimensionPinnedWithZeros) { int64_t output[2][2] = {}; int init = 5; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; return true; }; @@ -743,7 +752,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel_WithSkips) { int64_t output[10][10] = {}; int init = 5; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; return true; }; @@ -767,13 +776,13 @@ TEST(ShapeUtilTest, ForEachIndexParallel_CalledTwice) { int64_t output[10][10]; int init = 5; auto set_func = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; return true; }; int init2 = 15; auto set_func2 = [&](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + int /*thread_id*/) -> absl::StatusOr { output[indexes[0]][indexes[1]] = init2 + indexes[0] + indexes[1]; return true; }; @@ -803,8 +812,9 @@ TEST(ShapeUtilTest, ForEachIndexParallel_CalledFromMultipleThreads) { kCallingThreads); for (int t = 0; t < kCallingThreads; ++t) { pool.Schedule([&output, &kShape, t] { - auto set_func = [&output, t](absl::Span indexes, - int /*thread_id*/) -> StatusOr { + auto set_func = [&output, t]( + absl::Span indexes, + int /*thread_id*/) -> absl::StatusOr { output[t][indexes[0]][indexes[1]] = kInit + indexes[0] + indexes[1]; return true; }; @@ -967,7 +977,7 @@ TEST(ShapeUtilTest, UpdateDynamicDimensions) { } TEST(ShapeUtilTest, InvalidDynamicDimension) { - StatusOr error_status = ShapeUtil::MakeValidatedShape( + absl::StatusOr error_status = ShapeUtil::MakeValidatedShape( F32, {Shape::kUnboundedSize, Shape::kUnboundedSize}, {true, false}); EXPECT_FALSE(error_status.ok()); @@ -1615,7 +1625,7 @@ void BM_ForEachIndex(::testing::benchmark::State& state) { for (auto s : state) { int count = 0; auto increment_func = - [&count](absl::Span indexes) -> StatusOr { + [&count](absl::Span indexes) -> absl::StatusOr { count++; return true; }; diff --git a/xla/sharding_op_util.cc b/xla/sharding_op_util.cc index 638402879796a..d3c6f0aa14167 100644 --- a/xla/sharding_op_util.cc +++ b/xla/sharding_op_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/sharding_op_util.h b/xla/sharding_op_util.h index cd90b16baf028..be71cab8464ff 100644 --- a/xla/sharding_op_util.h +++ b/xla/sharding_op_util.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/side_effect_util.cc b/xla/side_effect_util.cc index 5bbe7c30239ae..f7a7f198f840e 100644 --- a/xla/side_effect_util.cc +++ b/xla/side_effect_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/side_effect_util.h b/xla/side_effect_util.h index 1bb1397d58988..756ecf82f6b93 100644 --- a/xla/side_effect_util.h +++ b/xla/side_effect_util.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/status.h b/xla/status.h index b19cfe0ff2262..818bfdf4b1ba2 100644 --- a/xla/status.h +++ b/xla/status.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,14 @@ limitations under the License. #ifndef XLA_STATUS_H_ #define XLA_STATUS_H_ -#include "tsl/platform/status.h" // IWYU pragma: export +#include "absl/log/check.h" // IWYU pragma: export +#include "absl/status/status.h" +#include "absl/status/statusor.h" namespace xla { // NOLINTBEGIN(misc-unused-using-decls) -using tsl::OkStatus; -using tsl::Status; // TENSORFLOW_STATUS_OK +using absl::OkStatus; +using absl::Status; // NOLINTEND(misc-unused-using-decls) } // namespace xla diff --git a/xla/status_macros.cc b/xla/status_macros.cc index e0b4341df664f..3587586a0efc9 100644 --- a/xla/status_macros.cc +++ b/xla/status_macros.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/status_macros.h b/xla/status_macros.h index bdc37673f39d2..392d829bb6b22 100644 --- a/xla/status_macros.h +++ b/xla/status_macros.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/status_macros_test.cc b/xla/status_macros_test.cc index ecbd7e89fe213..fe09a008143db 100644 --- a/xla/status_macros_test.cc +++ b/xla/status_macros_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,9 +59,9 @@ TEST(StatusMacros, RetCheckSucceeding) { EXPECT_IS_OK(status); } -StatusOr CreateIntSuccessfully() { return 42; } +absl::StatusOr CreateIntSuccessfully() { return 42; } -StatusOr CreateIntUnsuccessfully() { +absl::StatusOr CreateIntUnsuccessfully() { return tsl::errors::Internal("foobar"); } @@ -76,19 +76,20 @@ Status ReturnStatusError() { return (tsl::errors::Internal("foobar")); } using StatusReturningFunction = std::function; -StatusOr CallStatusReturningFunction(const StatusReturningFunction& func) { +absl::StatusOr CallStatusReturningFunction( + const StatusReturningFunction& func) { TF_RETURN_IF_ERROR(func()); return 42; } TEST(StatusMacros, ReturnIfErrorOnOK) { - StatusOr rc = CallStatusReturningFunction(ReturnStatusOK); + absl::StatusOr rc = CallStatusReturningFunction(ReturnStatusOK); EXPECT_IS_OK(rc); EXPECT_EQ(42, std::move(rc).value()); } TEST(StatusMacros, ReturnIfErrorOnError) { - StatusOr rc = CallStatusReturningFunction(ReturnStatusError); + absl::StatusOr rc = CallStatusReturningFunction(ReturnStatusError); EXPECT_FALSE(rc.ok()); EXPECT_EQ(rc.status().code(), tsl::error::INTERNAL); } diff --git a/xla/statusor.h b/xla/statusor.h index 2bb07c7f1e06a..a8704bfd45d06 100644 --- a/xla/statusor.h +++ b/xla/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/BUILD b/xla/stream_executor/BUILD index e35441efd34ce..a9264371d8633 100644 --- a/xla/stream_executor/BUILD +++ b/xla/stream_executor/BUILD @@ -1,14 +1,14 @@ -load("//xla:xla.bzl", "xla_cc_test") -load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends", "stream_executor_internal") -load("@tsl//tsl:tsl.bzl", "set_external_visibility", "transitive_hdrs") +load("@tsl//tsl:tsl.bzl", "internal_visibility", "transitive_hdrs") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:build_config_root.bzl", "if_static") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends", "stream_executor_internal") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -51,7 +51,7 @@ package_group( # an implementation detail of StreamExecutor and has internal visibility. # # TODO(ezhulenev): Remove from public API headers that are exported via standalone public libraries, -# e.g. `platform` and `multi_platform_manager` should be added with an explicit dependency. +# e.g. `platform` and `platform_manager` should be added with an explicit dependency. filegroup( name = "stream_executor_api_headers", srcs = [ @@ -59,26 +59,22 @@ filegroup( "command_buffer.h", "data_type.h", "device_description.h", - "device_id_utils.h", "device_memory.h", "device_memory_allocator.h", - "device_options.h", "event.h", "executor_cache.h", + "host_memory_allocation.h", "kernel.h", "kernel_spec.h", "launch_dim.h", + "memory_allocation.h", "module_spec.h", - "multi_platform_manager.h", "numeric_options.h", "platform.h", + "platform_manager.h", "scratch_allocator.h", "stream.h", "stream_executor.h", - "stream_executor_internal.h", # TODO(ezhulenev): Remove private header - "temporary_device_memory.h", - "temporary_memory_manager.h", - "trace_listener.h", ], visibility = ["//visibility:private"], ) @@ -100,14 +96,12 @@ filegroup( STREAM_EXECUTOR_DEPENDENCIES = [ ":device_description_proto_cc", ":host_or_device_scalar", - ":multi_platform_manager", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -116,9 +110,10 @@ STREAM_EXECUTOR_DEPENDENCIES = [ "//xla/stream_executor/platform", "@tsl//tsl/framework:device_id", "@tsl//tsl/framework:device_type", + "@tsl//tsl/lib/gtl:int_type", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -134,6 +129,9 @@ cc_library( ], deps = STREAM_EXECUTOR_DEPENDENCIES + [ ":stream_executor_pimpl", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:thread_annotations", ] + if_static([ ":stream_executor_impl", "@com_google_protobuf//:protobuf", # indirectly-used by dnn.h @@ -164,22 +162,29 @@ cc_library( deps = [ ":device_description_proto_cc", ":launch_dim", - "//xla/stream_executor/platform", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@tsl//tsl/lib/math:math_util", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) cc_library( name = "device_memory", hdrs = ["device_memory.h"], - deps = ["//xla/stream_executor/platform"], + deps = ["@tsl//tsl/platform:logging"], +) + +cc_library( + name = "data_type", + hdrs = ["data_type.h"], + visibility = [":internal"], + deps = [ + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/protobuf:dnn_proto_cc", + ], ) -# TODO(ezhulenev): Merge this target into `stream_executor`. cc_library( name = "device_memory_allocator", hdrs = ["device_memory_allocator.h"], @@ -187,39 +192,33 @@ cc_library( ":device_memory", ":platform", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:types", - ], -) - -cc_library( - name = "device_options", - hdrs = ["device_options.h"], - deps = [ - "//xla/stream_executor/platform", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", ], ) cc_library( - name = "device_id_utils", - hdrs = ["device_id_utils.h"], + name = "host_memory_allocation", + srcs = ["host_memory_allocation.cc"], + hdrs = ["host_memory_allocation.h"], deps = [ - ":platform", - ":stream_executor", - "@tsl//tsl/framework:device_id_impl", + ":memory_allocation", + ":stream_executor_internal", # TODO(b/323534971): Remove dependency on Interface. ], ) cc_library( name = "host_or_device_scalar", hdrs = ["host_or_device_scalar.h"], - deps = [":device_memory"], + deps = [ + ":device_memory", + "@com_google_absl//absl/log:check", + ], ) cc_library( @@ -229,19 +228,25 @@ cc_library( ) cc_library( - name = "multi_platform_manager", - srcs = ["multi_platform_manager.cc"], - hdrs = ["multi_platform_manager.h"], + name = "memory_allocation", + hdrs = ["memory_allocation.h"], +) + +cc_library( + name = "platform_manager", + srcs = ["platform_manager.cc"], + hdrs = ["platform_manager.h"], deps = [ ":platform", - "//xla/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) @@ -257,12 +262,8 @@ cc_library( hdrs = ["platform.h"], deps = [ ":device_description", - ":device_options", - "//xla/stream_executor/platform", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", ], ) @@ -279,11 +280,15 @@ cc_library( srcs = ["blas.cc"], hdrs = ["blas.h"], deps = [ - ":stream_executor_headers", + ":data_type", + ":device_memory", + ":numeric_options", "//xla/stream_executor/platform", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:errors", "@tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -293,19 +298,25 @@ cc_library( srcs = ["dnn.cc"], hdrs = ["dnn.h"], deps = [ + ":data_type", ":device_description_proto_cc", ":device_memory", ":numeric_options", - ":stream_executor_headers", "//xla/stream_executor/platform", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", # buildcleaner: keep "@tsl//tsl/lib/strings:proto_serialization", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/protobuf:dnn_proto_cc", @@ -324,9 +335,12 @@ cc_library( name = "lazy_op_runner", hdrs = ["lazy_op_runner.h"], deps = [ - ":dnn", ":stream_executor_headers", "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -347,16 +361,14 @@ exports_files(["lazy_op_runner.h"]) cc_library( name = "stream_executor_internal", hdrs = ["stream_executor_internal.h"], - visibility = [":internal"], + visibility = internal_visibility([":internal"]), deps = [ ":stream_executor_headers", "//xla/stream_executor/platform", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) @@ -377,7 +389,11 @@ cc_library( visibility = [":internal"], deps = STREAM_EXECUTOR_DEPENDENCIES + if_static([ "@com_google_protobuf//:protobuf", # indirectly-used by dnn.h - ]), + ]) + [ + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:thread_annotations", + ], ) cc_library( @@ -391,11 +407,11 @@ cc_library( ":fft", ":platform", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) @@ -413,32 +429,23 @@ cc_library( srcs = ["allocator_stats.cc"], hdrs = ["allocator_stats.h"], visibility = ["//visibility:private"], - deps = [ - "//xla/stream_executor/platform", - "@com_google_absl//absl/strings:str_format", - ], + deps = ["@com_google_absl//absl/strings:str_format"], ) cc_library( name = "command_buffer", srcs = ["command_buffer.cc"], hdrs = ["command_buffer.h"], - local_defines = select({ - "//xla/stream_executor/cuda:graph_conditional_enabled": [ - "STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1", - ], - "//conditions:default": [], - }), visibility = ["//visibility:private"], deps = [ ":stream_executor_headers", ":stream_executor_internal", - "//xla/stream_executor/platform", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@tsl//tsl/lib/gtl:int_type", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) @@ -451,8 +458,8 @@ cc_library( deps = [ ":stream_executor_headers", ":stream_executor_internal", - "//xla/stream_executor/platform", - "@tsl//tsl/platform:status", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", ], ) @@ -464,15 +471,13 @@ cc_library( deps = [ ":platform", ":stream_executor_headers", - "//xla/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:logging", ], ) @@ -482,13 +487,11 @@ cc_library( hdrs = ["kernel_spec.h"], visibility = ["//visibility:private"], deps = [ - "//xla/stream_executor/platform", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) @@ -499,69 +502,33 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":device_memory", + ":kernel_spec", ":platform", ":stream_executor_headers", ":stream_executor_internal", - "//xla/stream_executor/platform", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], ) cc_library( name = "scratch_allocator", - srcs = ["scratch_allocator.cc"], hdrs = ["scratch_allocator.h"], visibility = ["//visibility:private"], deps = [ + ":device_memory_allocator", ":stream_executor_headers", - ":temporary_device_memory", - "//xla/stream_executor/platform", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "temporary_device_memory", - srcs = ["temporary_device_memory.cc"], - hdrs = ["temporary_device_memory.h"], - visibility = ["//visibility:private"], - deps = [ - ":stream_executor_headers", - "//xla/stream_executor/platform", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "temporary_memory_manager", - srcs = ["temporary_memory_manager.cc"], - hdrs = ["temporary_memory_manager.h"], - visibility = ["//visibility:private"], - deps = [ - ":stream_executor_headers", - ":temporary_device_memory", - "//xla/stream_executor/platform", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:statusor", ], ) @@ -570,6 +537,10 @@ cc_library( transitive_hdrs( name = "stream_executor_install_hdrs", + tags = [ + "alt_dep=:stream_executor_headers", + "avoid_dep", + ], deps = [":stream_executor_headers"], ) @@ -583,43 +554,35 @@ cc_library( "stream_executor_pimpl.cc", ], hdrs = ["stream_executor_pimpl.h"], + tags = ["avoid_dep"], visibility = ["//visibility:private"], deps = [ - ":blas", - ":command_buffer", - ":device_memory", - ":dnn", - ":event", - ":executor_cache", + ":blas", # build_cleaner: keep + ":command_buffer", # build_cleaner: keep + ":dnn", # build_cleaner: keep ":fft", - ":host_or_device_scalar", - ":kernel", + ":host_memory_allocation", # build_cleaner: keep ":kernel_spec", - ":launch_dim", ":platform", - ":plugin_registry", - ":scratch_allocator", ":stream_executor_headers", ":stream_executor_internal", - ":temporary_device_memory", - ":temporary_memory_manager", - "//xla/stream_executor/platform", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:stacktrace", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - "@tsl//tsl/protobuf:dnn_proto_cc", - "@tsl//tsl/util:env_var", ], ) @@ -640,13 +603,11 @@ cc_library( ":kernel", ":kernel_spec", ":launch_dim", - ":multi_platform_manager", ":platform", + ":platform_manager", ":scratch_allocator", ":stream_executor_headers", ":stream_executor_pimpl", - ":temporary_device_memory", - ":temporary_memory_manager", "@tsl//tsl/protobuf:dnn_proto_cc_impl", ], ) @@ -662,14 +623,24 @@ xla_cc_test( ":device_memory", ":stream_executor", "//xla/stream_executor/host:host_platform", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", ], ) +xla_cc_test( + name = "stream_executor_test", + srcs = ["stream_executor_test.cc"], + deps = [ + ":stream_executor", + "//xla/stream_executor/host:host_platform", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "stream_test", size = "small", @@ -677,6 +648,8 @@ xla_cc_test( deps = [ ":stream_executor", "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -714,7 +687,7 @@ cc_library( deps = [ ":dnn", ":event", - ":multi_platform_manager", + ":platform_manager", ":scratch_allocator", ":stream_executor", "//xla/stream_executor/cuda:cuda_platform_id", diff --git a/xla/stream_executor/allocator_stats.cc b/xla/stream_executor/allocator_stats.cc index d37cb9f45e2dc..de6432b29d7bf 100644 --- a/xla/stream_executor/allocator_stats.cc +++ b/xla/stream_executor/allocator_stats.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/stream_executor/allocator_stats.h" +#include + #include "absl/strings/str_format.h" namespace stream_executor { diff --git a/xla/stream_executor/allocator_stats.h b/xla/stream_executor/allocator_stats.h index 1e15d0c51e64b..c6d185cbcd777 100644 --- a/xla/stream_executor/allocator_stats.h +++ b/xla/stream_executor/allocator_stats.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ALLOCATOR_STATS_H_ #define XLA_STREAM_EXECUTOR_ALLOCATOR_STATS_H_ +#include +#include #include -#include "xla/stream_executor/platform/port.h" namespace stream_executor { diff --git a/xla/stream_executor/blas.cc b/xla/stream_executor/blas.cc index 594acf41adaca..ec6c40d59f74c 100644 --- a/xla/stream_executor/blas.cc +++ b/xla/stream_executor/blas.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,10 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include +#include +#include +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "xla/stream_executor/device_memory.h" diff --git a/xla/stream_executor/blas.h b/xla/stream_executor/blas.h index 0cdb0ddb854f3..3f78c95648008 100644 --- a/xla/stream_executor/blas.h +++ b/xla/stream_executor/blas.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,40 +17,25 @@ limitations under the License. // use in conjunction with the StreamExecutor abstraction. // // Note that this interface is optionally supported by platforms. -// -// This abstraction makes it simple to entrain BLAS operations on GPU data into -// a Stream -- users typically will not use this API directly, but will use the -// Stream builder methods to entrain these operations "under the hood". For -// example: -// -// DeviceMemory x = stream_exec->AllocateArray(1024); -// DeviceMemory y = stream_exec->AllocateArray(1024); -// // ... populate x and y ... -// Stream stream{stream_exec}; -// stream -// .Init() -// .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1); -// TF_CHECK_OK(stream.BlockHostUntilDone()); -// -// By using stream operations in this manner the user can easily intermix custom -// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS -// routines. #ifndef XLA_STREAM_EXECUTOR_BLAS_H_ #define XLA_STREAM_EXECUTOR_BLAS_H_ #include +#include #include #include #include +#include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/port.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/errors.h" #include "tsl/protobuf/dnn.pb.h" namespace Eigen { @@ -61,7 +46,9 @@ namespace stream_executor { namespace gpu { struct BlasLt; -} +struct MatrixDescriptor; +struct OutputMatrixDescriptor; +} // namespace gpu class Stream; class ScratchAllocator; @@ -171,13 +158,15 @@ class ProfileResult { public: bool is_valid() const { return is_valid_; } void set_is_valid(bool val) { is_valid_ = val; } + bool warmup_run_executed() const { return warmup_run_executed_; } + void set_warmup_run_executed(bool val) { warmup_run_executed_ = val; } AlgorithmType algorithm() const { return algorithm_; } void set_algorithm(AlgorithmType val) { algorithm_ = val; } float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } private: - bool is_valid_ = false; + bool is_valid_ = false, warmup_run_executed_ = false; AlgorithmType algorithm_ = kDefaultAlgorithm; float elapsed_time_in_ms_ = std::numeric_limits::max(); }; @@ -204,6 +193,21 @@ class AlgorithmConfig { typedef int64_t ComputePrecision; constexpr ComputePrecision kDefaultComputePrecision = 0; +namespace detail { + +// Helper to return if `T` is the same type as `First` or any or `Rest`. +template +constexpr bool is_any_of() { + return false; +} + +template +constexpr bool is_any_of() { + return std::is_same_v || is_any_of(); +} + +} // namespace detail + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -222,31 +226,11 @@ class BlasSupport { virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; - virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) = 0; - virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) = 0; - virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) = 0; // Copies vector to another vector: y <- x. virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; - virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) = 0; - virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) = 0; - virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) = 0; // Computes the product of a vector by a scalar: x <- a*x. virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, @@ -307,11 +291,6 @@ class BlasSupport { uint64_t k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; - virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, - uint64_t k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy) = 0; // Computes a matrix-matrix product with general matrices: // @@ -327,18 +306,19 @@ class BlasSupport { // // Alpha/beta type matches `dtype`, unless `dtype` is `Eigen::half`, in that // case the expected alpha/beta type is `float`. - virtual tsl::Status DoBlasGemm(Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, DataType dtype, const void *alpha, - const DeviceMemoryBase &a, int lda, - const DeviceMemoryBase &b, int ldb, - const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options, - blas::CallContext context) = 0; + virtual absl::Status DoBlasGemm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, DataType dtype, const void *alpha, + const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, + const void *beta, DeviceMemoryBase *c, int ldc, + const NumericOptions &numeric_options, blas::CallContext context) = 0; // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) = 0; + Stream *stream, const gpu::MatrixDescriptor &a, + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, + const void *alpha, const void *beta, + std::vector *out_algorithms) = 0; // Like DoBlasGemm, but accepts an algorithm and an compute type. // @@ -351,7 +331,7 @@ class BlasSupport { // output_profile_result->is_valid(). This lets you use this function for // choosing the best algorithm among many (some of which may fail) without // creating a new Stream for each attempt. - virtual tsl::Status DoBlasGemmWithAlgorithm( + virtual absl::Status DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, DataType type_a, int lda, @@ -360,7 +340,7 @@ class BlasSupport { ComputationType computation_type, AlgorithmType algorithm, const NumericOptions &numeric_options, ProfileResult *output_profile_result, blas::CallContext context) = 0; - virtual tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( + virtual absl::Status DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, DataType type_a, int lda, int64_t stride_a, @@ -423,7 +403,7 @@ class BlasSupport { int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; // Batched gemm with strides instead of pointer arrays. - virtual tsl::Status DoBlasGemmStridedBatched( + virtual absl::Status DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, @@ -431,6 +411,170 @@ class BlasSupport { DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, const NumericOptions &numeric_options, blas::CallContext context) = 0; + template + absl::Status BlasGemmStridedBatchedWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, int64_t stride_a, + const DeviceMemory &b, int ldb, int64_t stride_b, + ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, + int batch_count, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + TF_RETURN_IF_ERROR( + CheckTypesForExtendedBlas( + computation_type)); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + absl::Status status = DoBlasGemmStridedBatchedWithAlgorithm( + stream, transa, transb, m, n, k, alpha_ptr, a, + blas::ToDataType::value, lda, stride_a, b, + blas::ToDataType::value, ldb, stride_b, beta_ptr, c, + blas::ToDataType::value, ldc, stride_c, batch_count, + computation_type, algorithm, numeric_options, output_profile_result, + context); + if (output_profile_result) { + // The error is recorded in the profile. + return absl::OkStatus(); + } + return status; + } + + template + absl::Status BlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64_t m, uint64 n, uint64 k, + ConstantType alpha, const DeviceMemory &a, + int lda, const DeviceMemory &b, int ldb, + ConstantType beta, DeviceMemory *c, int ldc, + const NumericOptions &numeric_options, + blas::CallContext context) { + static_assert( + detail::is_any_of, + std::complex>(), + "Input can be int8_t, half, bf16, float, double, std::complex " + "or " + "std::complex"); + static_assert(!std::is_same_v || + detail::is_any_of(), + "If input is Eigen::half, constant has to be either " + "Eigen::half or float"); + static_assert(detail::is_any_of(), + "If input is not int8_t, Eigen::half, constant and input " + "types have to match"); + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return DoBlasGemm(stream, transa, transb, m, n, k, + blas::ToDataType::value, alpha_ptr, a, lda, b, + ldb, beta_ptr, c, ldc, numeric_options, context); + } + + template + absl::Status BlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64_t m, uint64 n, uint64 k, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + const NumericOptions &numeric_options, + blas::CallContext context) { + InputType alpha{1.0}; + InputType beta{0.0}; + return BlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, numeric_options, context); + } + + template + absl::Status BlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, ConstantType beta, + DeviceMemory *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + const NumericOptions &numeric_options, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + TF_RETURN_IF_ERROR( + CheckTypesForExtendedBlas( + computation_type)); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + absl::Status st = DoBlasGemmWithAlgorithm( + stream, transa, transb, m, n, k, alpha_ptr, a, + blas::ToDataType::value, lda, b, + blas::ToDataType::value, ldb, beta_ptr, c, + blas::ToDataType::value, ldc, computation_type, algorithm, + numeric_options, output_profile_result, context); + + if (output_profile_result) { + // The error is recorded in the profile. + return absl::OkStatus(); + } + return st; + } + + template + absl::Status BlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, const DeviceMemory &a, + int lda, const DeviceMemory &b, int ldb, + DeviceMemory *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result, blas::CallContext context) { + OutputType alpha{1}; + OutputType beta{0}; + + return BlasGemmWithAlgorithm(stream, transa, transb, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, computation_type, + algorithm, NumericOptions{}, + output_profile_result, context); + } + + template + absl::Status BlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, int64_t stride_a, + const DeviceMemory &b, int ldb, int64_t stride_b, + ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, + int batch_count, const NumericOptions &numeric_options, + blas::CallContext context) { + static_assert( + detail::is_any_of, + std::complex>(), + "Unsupported input type"); + static_assert(std::is_same_v || + (detail::is_any_of() && + std::is_same_v), + "Mismatched input and alpha/beta types"); + + void *alpha_ptr = α + void *beta_ptr = β + float alpha_storage, beta_storage; + UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, + &beta_storage); + + return DoBlasGemmStridedBatched( + stream, transa, transb, m, n, k, blas::ToDataType::value, + alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, + stride_c, batch_count, numeric_options, context); + } + // Solves a triangular matrix equation. // // op(a) * x = alpha * b, @@ -511,7 +655,7 @@ class BlasSupport { BlasSupport *blas_; }; - virtual tsl::Status GetVersion(std::string *version) = 0; + virtual absl::Status GetVersion(std::string *version) = 0; protected: DeviceMemoryBase *GetWorkspace(); @@ -532,6 +676,71 @@ class BlasSupport { // own memory pool for allocating workspace. void ResetWorkspace(); + // Checks whether types match before a call to extended BLAS version. + template + absl::Status CheckTypesForExtendedBlas( + blas::ComputationType computation_type) { + static_assert( + detail::is_any_of, std::complex>(), + "The only buffer types supported are: Eigen::half, float, " + "double, int8, std::complex and std::complex"); + static_assert( + std::is_same_v || + (std::is_same_v && + detail::is_any_of()), + "Mismatched alpha/beta and output types"); + + bool valid_computation_type = [computation_type] { + switch (computation_type) { + case blas::ComputationType::kF16: + return std::is_same_v; + case blas::ComputationType::kF32: + return detail::is_any_of>(); + case blas::ComputationType::kF64: + return detail::is_any_of>(); + case blas::ComputationType::kI32: + return std::is_same_v; + case blas::ComputationType::kF16AsF32: // fall-through + case blas::ComputationType::kBF16AsF32: // fall-through + case blas::ComputationType::kTF32AsF32: + return detail::is_any_of>(); + } + }(); + + if (!valid_computation_type) { + return absl::InternalError(absl::StrCat( + "Invalid computation type ", + blas::ComputationTypeString(computation_type), " for output type: ", + blas::DataTypeString(blas::ToDataType::value))); + } + return absl::OkStatus(); + } + + // Non-extended BLAS interface requires alpha/beta to be floats when input + // type is Eigen::half. However, for consistency purposes it is convenient + // for the interface to accept Eigen::half. + template + void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, + float *alpha_storage, float *beta_storage) { + if (std::is_same::value) { + *alpha_storage = + static_cast(*reinterpret_cast(*alpha_ptr)); + *beta_storage = + static_cast(*reinterpret_cast(*beta_ptr)); + *alpha_ptr = alpha_storage; + *beta_ptr = beta_storage; + } else if (std::is_same::value) { + *alpha_storage = + static_cast(*reinterpret_cast(*alpha_ptr)); + *beta_storage = + static_cast(*reinterpret_cast(*beta_ptr)); + *alpha_ptr = alpha_storage; + *beta_ptr = beta_storage; + } + } + BlasSupport(const BlasSupport &) = delete; void operator=(const BlasSupport &) = delete; }; @@ -542,29 +751,9 @@ class BlasSupport { bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ - bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, \ - const DeviceMemory &x, int incx, \ - DeviceMemory *y, int incy) override; \ - bool DoBlasAxpy(Stream *stream, uint64_t elem_count, \ - std::complex alpha, \ - const DeviceMemory> &x, int incx, \ - DeviceMemory> *y, int incy) override; \ - bool DoBlasAxpy(Stream *stream, uint64_t elem_count, \ - std::complex alpha, \ - const DeviceMemory> &x, int incx, \ - DeviceMemory> *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ - bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ - const DeviceMemory &x, int incx, \ - DeviceMemory *y, int incy) override; \ - bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ - const DeviceMemory> &x, int incx, \ - DeviceMemory> *y, int incy) override; \ - bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ - const DeviceMemory> &x, int incx, \ - DeviceMemory> *y, int incy) override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \ DeviceMemory *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \ @@ -603,21 +792,19 @@ class BlasSupport { float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ - bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, uint64 k, \ - double alpha, const DeviceMemory &a, int lda, \ - const DeviceMemory &x, int incx, double beta, \ - DeviceMemory *y, int incy) override; \ - tsl::Status DoBlasGemm( \ + absl::Status DoBlasGemm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ const void *beta, DeviceMemoryBase *c, int ldc, \ const NumericOptions &numeric_options, blas::CallContext context) \ override; \ - bool GetBlasGemmAlgorithms(Stream *stream, \ - std::vector *out_algorithms) \ - override; \ - tsl::Status DoBlasGemmWithAlgorithm( \ + bool GetBlasGemmAlgorithms( \ + Stream *stream, const gpu::MatrixDescriptor &a, \ + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, \ + const void *alpha, const void *beta, \ + std::vector *out_algorithms) override; \ + absl::Status DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, const void *alpha, \ const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ @@ -679,7 +866,7 @@ class BlasSupport { int ldc, int batch_count, const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator, blas::CallContext context) \ override; \ - tsl::Status DoBlasGemmStridedBatched( \ + absl::Status DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, int64_t stride_a, \ @@ -687,7 +874,7 @@ class BlasSupport { DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, \ const NumericOptions &numeric_options, blas::CallContext context) \ override; \ - tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( \ + absl::Status DoBlasGemmStridedBatchedWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, const void *alpha, \ const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ @@ -740,7 +927,7 @@ class BlasSupport { const DeviceMemory *> &as, \ int lda, DeviceMemory *> *bs, \ int ldb, int batch_count) override; \ - tsl::Status GetVersion(std::string *version) override; + absl::Status GetVersion(std::string *version) override; } // namespace blas } // namespace stream_executor diff --git a/xla/stream_executor/build_defs.bzl b/xla/stream_executor/build_defs.bzl index 960502a5c6ac4..a937a57edc31d 100644 --- a/xla/stream_executor/build_defs.bzl +++ b/xla/stream_executor/build_defs.bzl @@ -1,7 +1,11 @@ """Configurations for StreamExecutor builds""" load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_config_rocm//rocm:build_defs.bzl", _if_gpu_is_configured = "if_gpu_is_configured") +load( + "@tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) def stream_executor_friends(): return ["//..."] @@ -12,16 +16,12 @@ def stream_executor_internal(): def tf_additional_cuda_platform_deps(): return [] -def tf_additional_cudnn_plugin_deps(): - return [] - def tf_additional_cudnn_plugin_copts(): - # TODO(timshen): remove TF_ENABLE_CUDNN_FRONTEND once cudnn-frontend is imported. - return ["-DNV_CUDNN_DISABLE_EXCEPTION", "-DTF_ENABLE_CUDNN_FRONTEND"] + return ["-DNV_CUDNN_DISABLE_EXCEPTION"] -# Returns whether any GPU backend is configuered. -def if_gpu_is_configured(x): - return if_cuda_is_configured(x) + if_rocm_is_configured(x) +# Returns whether any GPU backend is configured. +def if_gpu_is_configured(if_true, if_false = []): + return _if_gpu_is_configured(if_true, if_false) def if_cuda_or_rocm(x): return if_gpu_is_configured(x) @@ -30,3 +30,61 @@ def if_cuda_or_rocm(x): # unnecessary dependency def tf_additional_gpu_compilation_copts(): return ["-DTF_DISABLE_NVLINK_BY_DEFAULT"] + +def gpu_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when GPU is configured, otherwise it's an empty target. + + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("gpu_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_gpu" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_gpu_only" % name, + tags = tags + ["manual"], + **kwargs + ) + native.alias( + name = name, + actual = if_gpu_is_configured(":%s_gpu_only" % name, ":%s_non_gpu" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) + +def cuda_only_cc_library(name, tags = [], **kwargs): + """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. + + Args: + name: Name of the target + tags: Tags being applied to the implementation target + **kwargs: Accepts all arguments that a `cc_library` would also accept + """ + if not native.package_name().startswith("xla/stream_executor"): + fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") + + cc_library( + name = "%s_non_cuda" % name, + tags = ["manual"], + ) + cc_library( + name = "%s_cuda_only" % name, + tags = tags + ["manual"], + **kwargs + ) + native.alias( + name = name, + actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), + visibility = kwargs.get("visibility"), + compatible_with = kwargs.get("compatible_with"), + restricted_to = kwargs.get("restricted_to"), + target_compatible_with = kwargs.get("target_compatible_with"), + ) diff --git a/xla/stream_executor/command_buffer.cc b/xla/stream_executor/command_buffer.cc index 551a667843d81..e2172a0e7dc37 100644 --- a/xla/stream_executor/command_buffer.cc +++ b/xla/stream_executor/command_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,175 +15,50 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" -#include -#include #include #include -#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace stream_executor { -CommandBuffer::~CommandBuffer() = default; -CommandBuffer::CommandBuffer(CommandBuffer&&) = default; -CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; - -void CommandBuffer::Deleter::operator()( - internal::CommandBufferInterface* impl) { - if (owned) delete impl; -} - -/*static*/ tsl::StatusOr CommandBuffer::Create( +absl::StatusOr> CommandBuffer::Create( StreamExecutor* executor, Mode mode) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr command_buffer, - executor->implementation()->GetCommandBufferImplementation(mode)); - - CommandBuffer cmd(std::move(command_buffer)); - return cmd; + return executor->implementation()->CreateCommandBuffer(mode); } -/*static*/ tsl::StatusOr CommandBuffer::Trace( - StreamExecutor* executor, absl::AnyInvocable function, - Mode mode) { - Stream stream(executor); +absl::StatusOr> CommandBuffer::Trace( + StreamExecutor* executor, + absl::AnyInvocable function, Mode mode) { + TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); + return Trace(executor, stream.get(), std::move(function), mode); +} - // TODO(ezhulenev): Keep a dedicated stream for command buffer tracing in the - // StreamExecutor itself, and maybe add a StreamPool argument to the traced - // function arguments to be able to trace multiple stream simultaneously. - stream.Init(); - if (!stream.ok()) - return absl::InternalError( - "Failed to initialize stream for command buffer tracing"); +absl::StatusOr> CommandBuffer::Trace( + StreamExecutor* executor, Stream* stream, + absl::AnyInvocable function, Mode mode) { + if (stream == nullptr) + return absl::InvalidArgumentError( + "Can't trace command buffer on a null stream"); // Prepare an empty command buffer instance. - TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer, + TF_ASSIGN_OR_RETURN(std::unique_ptr command_buffer, CommandBuffer::Create(executor, mode)); // Trace and finalize the command buffer. - TF_RETURN_IF_ERROR(command_buffer.implementation()->Trace( - &stream, [&]() { return function(&stream); })); - TF_RETURN_IF_ERROR(command_buffer.implementation()->Finalize()); + TF_RETURN_IF_ERROR( + command_buffer->Trace(stream, [&]() { return function(stream); })); + TF_RETURN_IF_ERROR(command_buffer->Finalize()); return command_buffer; } -/*static*/ bool CommandBuffer::SupportsConditionalCommands( - const Platform* platform) { - // TODO(ezhulenev): We should extend a Platform with a way to query - // implemented StreamExecutor features, for now we know that only CUDA - // platform supports conditional commands in command buffers. -#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) - return platform->Name() == "CUDA"; -#endif - return false; -} - -const internal::CommandBufferInterface* CommandBuffer::implementation() const { - return implementation_.get(); -} - -internal::CommandBufferInterface* CommandBuffer::implementation() { - return implementation_.get(); -} - -/*static*/ CommandBuffer CommandBuffer::Create( - std::unique_ptr implementation) { - return CommandBuffer(std::move(implementation)); -} - -/*static*/ tsl::Status CommandBuffer::Build( - internal::CommandBufferInterface* implementation, - const CommandBuffer::Builder& builder) { - CommandBuffer command_buffer(implementation); - return builder(&command_buffer); -} - -CommandBuffer::CommandBuffer( - std::unique_ptr implementation) - : implementation_(implementation.release(), {/*owned=*/true}) {} - -CommandBuffer::CommandBuffer(internal::CommandBufferInterface* implementation) - : implementation_(implementation, {/*owned=*/false}) {} - -tsl::Status CommandBuffer::Launch(const ThreadDim& threads, - const BlockDim& blocks, const Kernel& kernel, - const KernelArgs& args) { - return implementation_->Launch(threads, blocks, kernel, args); -} - -tsl::Status CommandBuffer::AddNestedCommandBuffer(const CommandBuffer& nested) { - return implementation_->AddNestedCommandBuffer(nested); -} - -tsl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) { - return implementation_->MemcpyDeviceToDevice(dst, src, size); -} - -tsl::Status CommandBuffer::Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, - size_t num_elements) { - return implementation_->Memset(dst, bit_pattern, num_elements); -} - -tsl::StatusOr CommandBuffer::Allocate(size_t bytes) { - return implementation_->Allocate(bytes); -} - -tsl::Status CommandBuffer::If(StreamExecutor* executor, DeviceMemory pred, - Builder then_builder) { - return implementation_->If(executor, pred, std::move(then_builder)); -} - -tsl::Status CommandBuffer::IfElse(StreamExecutor* executor, - DeviceMemory pred, Builder then_builder, - Builder else_builder) { - return implementation_->IfElse(executor, pred, std::move(then_builder), - std::move(else_builder)); -} - -tsl::Status CommandBuffer::Case(StreamExecutor* executor, - DeviceMemory index, - std::vector branches) { - return implementation_->Case(executor, index, std::move(branches)); -} - -tsl::Status CommandBuffer::For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_counter, - Builder body_builder) { - return implementation_->For(executor, num_iteration, loop_counter, - std::move(body_builder)); -} - -tsl::Status CommandBuffer::While(StreamExecutor* executor, - DeviceMemory pred, Builder cond_builder, - Builder body_builder) { - return implementation_->While(executor, pred, std::move(cond_builder), - std::move(body_builder)); -} - -CommandBuffer::Mode CommandBuffer::mode() const { - return implementation_->mode(); -} - -CommandBuffer::State CommandBuffer::state() const { - return implementation_->state(); -} - -tsl::Status CommandBuffer::Finalize() { return implementation_->Finalize(); } - -tsl::Status CommandBuffer::Update() { return implementation_->Update(); } - } // namespace stream_executor diff --git a/xla/stream_executor/command_buffer.h b/xla/stream_executor/command_buffer.h index 3ac378acbdf17..1a6c745ef8475 100644 --- a/xla/stream_executor/command_buffer.h +++ b/xla/stream_executor/command_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,23 +24,21 @@ limitations under the License. #include #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" +#include "tsl/lib/gtl/int_type.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { class Stream; class StreamExecutor; -namespace internal { -class CommandBufferInterface; -} - //===----------------------------------------------------------------------===// // CommandBuffer //===----------------------------------------------------------------------===// @@ -53,12 +51,93 @@ class CommandBufferInterface; // device. class CommandBuffer { public: + // Execution scope enables fine-grained synchronization scopes inside + // commands buffers. Implementation is very backend-specific and for CUDA/ROCM + // backends it's implemented as DAG edges. By default all commands launched in + // the `kDefaulExecutionScope` execution scope. + // + // Example #1: independent execution scopes and independent barriers + // + // ExecutionScope #0 ExecutionScope #1 + // + // A D + // B E + // ----- barrier ----- ----- barrier ----- + // C F + // + // (1) Commands A and B can run concurrently and must complete before C. + // (2) Commands D and E can run concurrently and must complete before F. + // (3) There is no syncrhonization between execution scopes, and commands + // from different execution scopes can execute concurrently with each + // other as long as they satisfy constraints of their respective + // execution scopes. + // + // + // + // Example #2: dependencies between scopes and inter-scope barriers + // + // ExecutionScope #0 ExecutionScope #1 + // + // A D + // B E + // ----------------- barrier ------------------ + // C F + // + // (1) Commands A and B can run concurrently and must complete before + // C and F. + // (2) Commands D and E can run concurrently and must complete before + // C and F. + // (3) Commands C and F can run concurrently. + // (4) All commands before a shared barrier (in both excecution scopes) + // should complete before any command after a berrier starts execution. + // + // + // + // Example #3: one-directional barriers between execution scopes + // + // ExecutionScope #0 ExecutionScope #1 + // + // A + // B + // ----- barrier ----- D + // C \ E + // ----- barrier ----- + // F + // + // (1) Commands A and B can run concurrently and must complete before + // C and F. + // (2) Commands D and E can run concurrently and must complete before + // F (does not synchronize with C). + // (3) Commands C and F can run concurrently. + // + // This is a more fine-grained barrier than in example #2: it enforces + // synchronization from execution scope #0 to execution scope #1 but no + // synchronization in other direction. For CUDA/ROCM backend it has the same + // semantics as stream wait operation. + // + TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionScopeId, int64_t); + static constexpr auto kDefaulExecutionScope = ExecutionScopeId(0); + // Builder constructs nested command buffers owned by a parent command buffer. - using Builder = std::function; + // + // Builder can use arbitrary number of nested execution scopes, the only + // requirement is that after builder constructed all commands, they all must + // be synchronized with a default execution scope. + using Builder = std::function; + + // An extension of a `Builder` defined above that builds a nested command + // buffer in a given execution scope. Builder can use arbitrary number of + // nested execution scopes, the only requirement is that after builder + // constructed all commands, they all must be synchronized with an execution + // scope passed as an argument. + using ExecutionScopeBuilder = + std::function; + + CommandBuffer() = default; + virtual ~CommandBuffer() = default; - ~CommandBuffer(); - CommandBuffer(CommandBuffer&&); - CommandBuffer& operator=(CommandBuffer&&); + CommandBuffer(const CommandBuffer&) = delete; + void operator=(const CommandBuffer&) = delete; // Command buffer state: // @@ -86,9 +165,13 @@ class CommandBuffer { // Command buffer constructors //===--------------------------------------------------------------------===// + // TODO(b/323534971): Command buffer constructors should be moved to + // StreamExecutor or a dedicated CommandBufferFactory accessible via + // StreamExecutor. + // Creates a new empty command buffer on the given executor. - static tsl::StatusOr Create(StreamExecutor* executor, - Mode mode = Mode::kPrimary); + static absl::StatusOr> Create( + StreamExecutor* executor, Mode mode = Mode::kPrimary); // Creates a new command buffer on the given executor by tracing `function` // invocation. All StreamExecutor operations on a Stream argument will be @@ -96,39 +179,133 @@ class CommandBuffer { // can't be updated. // // Command buffer tracing should be used only when it is impossible to use - // explicit construction APIs, e.g. when calling external libraries. - static tsl::StatusOr Trace( + // explicit construction APIs, e.g. when calling external libraries. By + // default we construct traced command buffers in nested mode because the + // primary use case for traced command buffers is to be inserted into primary + // command buffers constructed with explicit APIs. + static absl::StatusOr> Trace( StreamExecutor* executor, - absl::AnyInvocable function, - Mode mode = Mode::kPrimary); + absl::AnyInvocable function, + Mode mode = Mode::kNested); - //===--------------------------------------------------------------------===// - // Command buffer properties - //===--------------------------------------------------------------------===// - - // Returns true if command buffer on a given platform supports conditional - // commands (If, IfThen, While). - static bool SupportsConditionalCommands(const Platform* platform); + // Creates a new command buffer on the given executor by tracing `function` + // invocation using a user provided stream that will be passed to `function`. + static absl::StatusOr> Trace( + StreamExecutor* executor, Stream* stream, + absl::AnyInvocable function, + Mode mode = Mode::kNested); //===--------------------------------------------------------------------===// // Command buffer API //===--------------------------------------------------------------------===// - // Adds a kernel launch command to the command buffer. - tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgs& args); + // Adds an execution barrier to a given execution scope: all commands added + // before a barrier in a the execution scope will complete before any of the + // commands added after a barrier in the same execution scope. + virtual absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) = 0; - // Adds a nested command buffer to the command buffer. - tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested); + // Adds an execution barrier that synchronizes commands across multiple + // execution scopes. See example #2 in execution scope id documentation. + virtual absl::Status Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) = 0; + + // Adds an execution barrier from execution scope `from_execution_scope_id` to + // execution scope `to_execution_scope_id`. See example #3 for details. + virtual absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) = 0; + + // Adds an execution barrier to the default execution scope. + absl::Status Barrier(StreamExecutor* executor) { + return Barrier(executor, kDefaulExecutionScope); + } + + // Adds a kernel launch command. + virtual absl::Status Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, + const Kernel& kernel, const KernelArgs& args) = 0; + + // Adds a kernel launch command to the default execution scope. + absl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, + const Kernel& kernel, const KernelArgs& args) { + return Launch(kDefaulExecutionScope, threads, blocks, kernel, args); + } - // Adds a device-to-device memory copy to the command buffer. - tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, uint64_t size); + // Type-safe wrapper for launching typed kernels. Notice that the order of + // arguments is different do disambiguate from the regular launch API. + template + absl::Status Launch(const TypedKernel& kernel, + ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, + Args... args); - // Adds a memset node to the command buffer. + // Type-safe wrapper for launching typed kernels in default execution scope. + template + absl::Status Launch(const TypedKernel& kernel, + const ThreadDim& threads, const BlockDim& blocks, + Args... args) { + return Launch(kernel, kDefaulExecutionScope, threads, blocks, args...); + } + + // Adds a nested command buffer. + virtual absl::Status AddNestedCommandBuffer( + ExecutionScopeId execution_scope_id, const CommandBuffer& nested) = 0; + + // Adds a nested command buffer to the default execution scope. + absl::Status AddNestedCommandBuffer(const CommandBuffer& nested) { + return AddNestedCommandBuffer(kDefaulExecutionScope, nested); + } + + // Adds a device-to-device memory copy. + virtual absl::Status MemcpyDeviceToDevice(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) = 0; + + // Adds a device-to-device memory copy to the default execution scope. + absl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) { + return MemcpyDeviceToDevice(kDefaulExecutionScope, dst, src, size); + } + + // Supported bit patterns for memset commands. using BitPattern = std::variant; - tsl::Status Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, - size_t num_elements); + + // Adds a memset command. + virtual absl::Status Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements) = 0; + + // Adds a memset command to the default execution scope. + absl::Status Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements) { + return Memset(kDefaulExecutionScope, dst, bit_pattern, num_elements); + } + + //--------------------------------------------------------------------------// + // Command buffer memory allocation API + //--------------------------------------------------------------------------// + + // Adds a device memory allocation command. + virtual absl::StatusOr Allocate( + ExecutionScopeId execution_scope_id, size_t bytes) = 0; + + // Adds a device memory allocation command to the default execution scope. + absl::StatusOr Allocate(size_t bytes) { + return Allocate(kDefaulExecutionScope, bytes); + } + + // Adds a device memory free command. + virtual absl::Status Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) = 0; + + // Adds a device memory free command to the default execution scope. + absl::Status Free(DeviceMemoryBase dst) { + return Free(kDefaulExecutionScope, dst); + } //--------------------------------------------------------------------------// // Command buffer condtitional commands API @@ -136,34 +313,66 @@ class CommandBuffer { // Adds a conditional operation that will execute a command buffer constructed // by `then_builder` if `pred` value is `true`. - tsl::Status If(StreamExecutor* executor, DeviceMemory pred, - Builder then_builder); + virtual absl::Status If(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder then_builder) = 0; + + // Adds a conditional If operation to default execution scope. + absl::Status If(StreamExecutor* executor, DeviceMemory pred, + Builder then_builder) { + return If(kDefaulExecutionScope, executor, pred, then_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by `then_builder` if `pred` value is `true`, or a command buffer // constructed by `else_builder` if `pred` is `false`. - tsl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, - Builder then_builder, Builder else_builder); + virtual absl::Status IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + Builder then_builder, Builder else_builder) = 0; + + // Adds a conditional IfElse operation to default execution scope. + absl::Status IfElse(StreamExecutor* executor, DeviceMemory pred, + Builder then_builder, Builder else_builder) { + return IfElse(kDefaulExecutionScope, executor, pred, then_builder, + else_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by the `branches` builder at `index`. If `index` is out of range, then it // will run a conditional command buffer constructed by the last builder. // // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case - tsl::Status Case(StreamExecutor* executor, DeviceMemory index, - std::vector branches); + virtual absl::Status Case(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory index, + std::vector branches) = 0; + + // Adds a conditional Case operation to default execution scope. + absl::Status Case(StreamExecutor* executor, DeviceMemory index, + std::vector branches) { + return Case(kDefaulExecutionScope, executor, index, branches); + } // Adds a conditional operation that will execute a command buffer constructed // by the `body_builder` exactly `num_iteration` times. This means the // condition is known at compile time (`num_iteration` < `loop_counter`), and // does not require a `cond_builder`. - tsl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_counter, Builder body_builder); + virtual absl::Status For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_counter, + Builder body_builder) = 0; + + // Adds a conditional For operation to default execution scope. + absl::Status For(StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_counter, Builder body_builder) { + return For(kDefaulExecutionScope, executor, num_iteration, loop_counter, + body_builder); + } // Adds a conditional operation that will execute a command buffer constructed // by the `cond_builder` that must update `pred` value, and then depending on // the value might execute command buffer constructed by `body_builder` and - // `cond_builder`. Will continue while `pred` value (which is continously + // `cond_builder`. Will continue while `pred` value (which is continuously // updated by `cond_builder`) is `true`. // // In pseudocode: @@ -173,71 +382,52 @@ class CommandBuffer { // body_builder() // cond_builder() // - tsl::Status While(StreamExecutor* executor, DeviceMemory pred, - Builder cond_builder, Builder body_builder); + // We use execution scope builder for the condition because we have to build + // condition twice: (1) before the conditional node in the scope defined by + // `execution_scope_id` (2) inside the loop body with default execution scope. + virtual absl::Status While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + ExecutionScopeBuilder cond_builder, + Builder body_builder) = 0; + + // Adds a conditional While operation to default execution scope. + absl::Status While(StreamExecutor* executor, DeviceMemory pred, + ExecutionScopeBuilder cond_builder, Builder body_builder) { + return While(kDefaulExecutionScope, executor, pred, cond_builder, + body_builder); + } //--------------------------------------------------------------------------// - - // Adds a device memory allocation command to the command buffer. - tsl::StatusOr Allocate(size_t bytes); + // Command buffer state management API + //--------------------------------------------------------------------------// // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. - tsl::Status Finalize(); + virtual absl::Status Finalize() = 0; // Begins command buffer update. Command buffer update should be finalized // before it can be executed. - tsl::Status Update(); - - // Type-safe wrapper for launching typed kernels. Notice that the order of - // arguments is different do disambiguate from the regular launch API. - template - tsl::Status Launch(const TypedKernel& kernel, - const ThreadDim& threads, const BlockDim& blocks, - Args... args); + virtual absl::Status Update() = 0; // Returns command buffer execution mode. - Mode mode() const; + virtual Mode mode() const = 0; // Returns command buffer state. - State state() const; - - //===--------------------------------------------------------------------===// - // Semi-internal APIs - //===--------------------------------------------------------------------===// - - // Following APIs are public, but considered to be implementation detail and - // discouraged from uses outside of StreamExecutor package. - const internal::CommandBufferInterface* implementation() const; - internal::CommandBufferInterface* implementation(); - - // Creates a command buffer from a platform-specific command buffer - // implementation. - static CommandBuffer Create( - std::unique_ptr implementation); - - // An adaptor for a command buffer builder that records commands into the - // platform-specific implementation - static tsl::Status Build(internal::CommandBufferInterface* implementation, - const CommandBuffer::Builder& builder); + virtual State state() const = 0; + //--------------------------------------------------------------------------// + // Command buffer tracing API + //--------------------------------------------------------------------------// private: - explicit CommandBuffer( - std::unique_ptr implementation); - - explicit CommandBuffer(internal::CommandBufferInterface* implementation); - - // A custom deleter to be able to construct command buffer that doesn't own - // underlying implementation (behaves like std::weak_ptr for implementation). - struct Deleter { - void operator()(internal::CommandBufferInterface*); - bool owned = true; - }; - - std::unique_ptr implementation_; - - CommandBuffer(const CommandBuffer&) = delete; - void operator=(const CommandBuffer&) = delete; + // Tracing APIs are private because they do not compose with command buffer + // updates. Instead of tracing directly into the command buffer users should + // create traced command buffers using factory methods and add them to primary + // command buffers as nested operations. + + // Traces `function` invocation by recording all operations on the `stream` + // into the command buffer. Command buffer must be empty. + virtual absl::Status Trace(Stream* stream, + absl::AnyInvocable function) = 0; }; //===----------------------------------------------------------------------===// @@ -245,12 +435,15 @@ class CommandBuffer { //===----------------------------------------------------------------------===// template -inline tsl::Status CommandBuffer::Launch(const TypedKernel& kernel, - const ThreadDim& threads, - const BlockDim& blocks, Args... args) { +inline absl::Status CommandBuffer::Launch(const TypedKernel& kernel, + ExecutionScopeId execution_scope_id, + const ThreadDim& threads, + const BlockDim& blocks, + Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); - TF_RETURN_IF_ERROR(Launch(threads, blocks, kernel, *kernel_args)); - return tsl::OkStatus(); + TF_RETURN_IF_ERROR( + Launch(execution_scope_id, threads, blocks, *kernel, *kernel_args)); + return absl::OkStatus(); } } // namespace stream_executor diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 4239022fc75e4..75dc1cba3402f 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -1,19 +1,5 @@ -load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -load( - "//xla:xla.bzl", - "xla_cc_test", -) -load( - "//xla/stream_executor:build_defs.bzl", - "stream_executor_friends", - "tf_additional_cuda_platform_deps", - "tf_additional_cudnn_plugin_copts", - "tf_additional_cudnn_plugin_deps", - "tf_additional_gpu_compilation_copts", -) -load("@tsl//tsl:tsl.bzl", "if_google", "set_external_visibility", "tsl_copts") +load("@tsl//tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility", "tsl_copts") load( "@tsl//tsl/platform:build_config_root.bzl", "if_static", @@ -27,10 +13,22 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load( + "//xla/stream_executor:build_defs.bzl", + "cuda_only_cc_library", + "stream_executor_friends", + "tf_additional_cuda_platform_deps", + "tf_additional_cudnn_plugin_copts", + "tf_additional_gpu_compilation_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -39,18 +37,19 @@ package_group( packages = stream_executor_friends(), ) -# Add `--//third_party/tensorflow/compiler/xla/stream_executor/cuda:enable_graph_conditional` to -# build command to enable CUDA graph conditional nodes support. Requires CUDA >=12.3. -# -# See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#conditional-graph-nodes bool_flag( - name = "enable_graph_conditional", - build_setting_default = False, + name = "enable_libnvptxcompiler_support", + build_setting_default = if_google( + True, + oss_value = False, + ), ) config_setting( - name = "graph_conditional_enabled", - flag_values = {":enable_graph_conditional": "True"}, + name = "libnvptxcompiler_support_enabled", + flag_values = { + ":enable_libnvptxcompiler_support": "True", + }, ) cc_library( @@ -60,51 +59,56 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cc_library( +cuda_only_cc_library( name = "cuda_platform", - srcs = if_cuda_is_configured(["cuda_platform.cc"]), - hdrs = if_cuda_is_configured(["cuda_platform.h"]), + srcs = ["cuda_platform.cc"], + hdrs = ["cuda_platform.h"], visibility = ["//visibility:public"], - deps = if_cuda_is_configured( + deps = [ ":cuda_activation", + ":cuda_collectives", ":cuda_driver", - ":cuda_runtime", ":cuda_executor", ":cuda_platform_id", + ":cuda_runtime", + "//xla/stream_executor", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/platform", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/platform", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - ], - ) + tf_additional_cuda_platform_deps(), - alwayslink = True, # Registers itself with the MultiPlatformManager. + ] + tf_additional_cuda_platform_deps(), + alwayslink = True, # Registers itself with the PlatformManager. ) -cc_library( +cuda_only_cc_library( name = "cuda_diagnostics", - srcs = if_cuda_is_configured(["cuda_diagnostics.cc"]), - hdrs = if_cuda_is_configured(["cuda_diagnostics.h"]), - deps = if_cuda_is_configured([ + srcs = ["cuda_diagnostics.cc"], + hdrs = ["cuda_diagnostics.h"], + deps = [ + "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/platform", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:status", - ]), + ], ) # Buildozer can not remove dependencies inside select guards, so we have to use @@ -117,190 +121,162 @@ cc_library(name = "nvlink_wrapper") # an intermediate target. cc_library(name = "fatbinary_wrapper") -cc_library( +cuda_only_cc_library( name = "cuda_driver", - srcs = if_cuda_is_configured(["cuda_driver.cc"]), - hdrs = if_cuda_is_configured(["cuda_driver.h"]), - deps = if_cuda_is_configured([ - ":cuda_diagnostics", + srcs = ["cuda_driver.cc"], + hdrs = ["cuda_driver.h"], + deps = [ + ":cuda_diagnostics", # buildcleaner: keep + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_diagnostics_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tsl/cuda", + "//xla/tsl/cuda:cudart", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/debugging:leak_check", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor:device_options", - "//xla/stream_executor:stream_executor_headers", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", - "@tsl//tsl/cuda", - "@tsl//tsl/cuda:cudart", "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:stacktrace", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_runtime", - srcs = if_cuda_is_configured(["cuda_runtime.cc"]), + srcs = ["cuda_runtime.cc"], deps = [ "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", - ]), + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], ) -xla_cc_test( - name = "stream_search_test", - size = "small", - srcs = ["stream_search_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = tf_cuda_tests_tags(), +cuda_only_cc_library( + name = "cuda_collectives", + srcs = ["cuda_collectives.cc"], + defines = if_nccl(["STREAM_EXECUTOR_GPU_ENABLE_XCCL"]), deps = [ - ":cuda_platform", - "//xla/stream_executor", - "//xla/stream_executor/host:host_platform", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], + ":cuda_driver", + "//xla/stream_executor/gpu:gpu_collectives_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:statusor", + ] + if_nccl(["@local_config_nccl//:nccl"]), ) xla_cc_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), tags = tf_cuda_tests_tags() + [ - "no_cuda_asan", # TODO(b/171512140): re-enable. "no_rocm", + "requires-gpu-nvidia", ], deps = [ ":cuda_driver", - "@com_google_absl//absl/memory", + "//xla/stream_executor/gpu:gpu_driver_header", + "@com_google_absl//absl/log", "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) -xla_cc_test( - name = "memcpy_test", - srcs = ["memcpy_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = tf_cuda_tests_tags() + [ - "no_cuda_asan", # TODO(b/171512140): re-enable. - ], - deps = [ - ":cuda_platform", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:multi_platform_manager", - "@com_google_absl//absl/memory", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], -) - # The activation library is tightly coupled to the executor library. # TODO(leary) split up cuda_executor.cc so that this can stand alone. cc_library( name = "cuda_activation_header", hdrs = ["cuda_activation.h"], visibility = ["//visibility:public"], - deps = [ - "//xla/stream_executor/gpu:gpu_activation_header", - "//xla/stream_executor/platform", - ], + deps = ["//xla/stream_executor/gpu:gpu_activation_header"], ) -cc_library( +cuda_only_cc_library( name = "cuda_activation", srcs = [], - hdrs = if_cuda_is_configured(["cuda_activation.h"]), - deps = if_cuda_is_configured([ + hdrs = ["cuda_activation.h"], + deps = [ ":cuda_driver", - "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/platform", - ]), + "@local_config_cuda//cuda:cuda_headers", + ], ) -cc_library( +cuda_only_cc_library( name = "cublas_lt_header", - hdrs = if_cuda_is_configured([ + hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", - ]), + ], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", + deps = [ "//xla:types", - "//xla/stream_executor:host_or_device_scalar", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cublas_plugin", - srcs = if_cuda_is_configured([ + srcs = [ "cuda_blas.cc", "cuda_blas_lt.cc", - ]), - hdrs = if_cuda_is_configured([ + ], + hdrs = [ "cuda_blas.h", "cuda_blas_lt.h", - ]), + ], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation", ":cuda_blas_utils", ":cuda_executor", ":cuda_helpers", ":cuda_platform_id", ":cuda_stream", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@eigen_archive//:eigen3", - "@local_config_cuda//cuda:cuda_headers", "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/stream_executor", - "//xla/stream_executor:plugin_registry", "//xla/stream_executor:device_memory", "//xla/stream_executor:host_or_device_scalar", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor:plugin_registry", "//xla/stream_executor/gpu:gpu_activation_header", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_executor_header", @@ -309,57 +285,70 @@ cc_library( "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/platform", - "@tsl//tsl/cuda:cublas", - "@tsl//tsl/cuda:cublas_lt", + "//xla/tsl/cuda:cublas", + "//xla/tsl/cuda:cublas_lt", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@eigen_archive//:eigen3", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:tensor_float_32_hdr_lib", - ]) + if_static([ + "@tsl//tsl/protobuf:dnn_proto_cc", + ] + if_static([ "@tsl//tsl/platform:tensor_float_32_utils", ]), alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_blas_utils", - srcs = if_cuda_is_configured(["cuda_blas_utils.cc"]), - hdrs = if_cuda_is_configured(["cuda_blas_utils.h"]), - deps = if_cuda_is_configured([ + srcs = ["cuda_blas_utils.cc"], + hdrs = ["cuda_blas_utils.h"], + deps = [ + "//xla/stream_executor", + "//xla/tsl/cuda:cublas", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor:stream_executor_headers", - "@tsl//tsl/cuda:cublas", "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cufft_plugin", - srcs = if_cuda_is_configured(["cuda_fft.cc"]), - hdrs = if_cuda_is_configured(["cuda_fft.h"]), + srcs = ["cuda_fft.cc"], + hdrs = ["cuda_fft.h"], visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation_header", - ":cuda_helpers", ":cuda_platform_id", - ":cuda_stream", - "@com_google_absl//absl/strings", - "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor:fft", "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", + "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", - "@tsl//tsl/cuda:cufft", - "@tsl//tsl/platform:errors", + "//xla/tsl/cuda:cufft", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - ]), + "@tsl//tsl/platform:statusor", + ], alwayslink = True, ) @@ -370,120 +359,92 @@ cc_library( ":cuda_activation_header", "//xla/stream_executor:dnn", "//xla/stream_executor:plugin_registry", - ]), + ]) + [ + "//xla/stream_executor", # build_cleaner: keep + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:status", + "@tsl//tsl/protobuf:dnn_proto_cc", + ], ) -cc_library( +cuda_only_cc_library( name = "cudnn_plugin", - srcs = if_cuda_is_configured(["cuda_dnn.cc"]), - hdrs = if_cuda_is_configured(["cuda_dnn.h"]), + srcs = ["cuda_dnn.cc"], + hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), visibility = ["//visibility:public"], - deps = if_cuda_is_configured([ + deps = [ ":cuda_activation", ":cuda_diagnostics", ":cuda_driver", ":cuda_executor", ":cuda_platform_id", ":cuda_stream", + ":cudnn_frontend_helpers", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:stream_executor_internal", + "//xla/stream_executor/gpu:gpu_activation_header", + "//xla/stream_executor/gpu:gpu_diagnostics_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/platform", + "//xla/tsl/cuda:cudnn", + "//xla/tsl/util:env_var", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@cudnn_frontend_archive//:cudnn_frontend", "@eigen_archive//:eigen3", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", - "//xla/stream_executor:dnn", - "//xla/stream_executor:plugin_registry", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_timer_header", - "//xla/stream_executor/platform", - "@tsl//tsl/cuda:cudnn", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:tensor_float_32_hdr_lib", "@tsl//tsl/platform:tensor_float_32_utils", - "@tsl//tsl/util:env_var", - ]) + tf_additional_cudnn_plugin_deps(), + "@tsl//tsl/protobuf:dnn_proto_cc", + ], alwayslink = True, ) -cc_library( +cuda_only_cc_library( name = "cuda_kernel", - srcs = if_cuda_is_configured(["cuda_kernel.cc"]), - hdrs = if_cuda_is_configured(["cuda_kernel.h"]), - deps = if_cuda_is_configured([ + srcs = ["cuda_kernel.cc"], + hdrs = ["cuda_kernel.h"], + deps = [ ":cuda_driver", - "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/platform", - ]), -) - -cuda_library( - name = "cuda_test_kernels", - testonly = 1, - srcs = if_cuda_is_configured(["cuda_test_kernels.cu.cc"]), - hdrs = if_cuda_is_configured(["cuda_test_kernels.h"]), - deps = ["@local_config_cuda//cuda:cuda_headers"], -) - -cuda_library( - name = "cuda_conditional_kernels", - srcs = if_cuda_is_configured(["cuda_conditional_kernels.cu.cc"]), - local_defines = select({ - ":graph_conditional_enabled": ["STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL=1"], - "//conditions:default": [], - }), - deps = ["@local_config_cuda//cuda:cuda_headers"], -) - -xla_test( - name = "cuda_kernel_test", - srcs = if_cuda_is_configured(["cuda_kernel_test.cc"]), - backends = ["gpu"], - deps = [ - ":cuda_platform", - ":cuda_test_kernels", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "@com_google_googletest//:gtest", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:statusor", ], ) -xla_test( - name = "cuda_command_buffer_test", - srcs = if_cuda_is_configured(["cuda_command_buffer_test.cc"]), - backends = ["gpu"], - deps = [ - ":cuda_platform", - ":cuda_runtime", - ":cuda_test_kernels", - "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", - "//xla/stream_executor:platform", - "@com_google_absl//absl/log:check", - "@local_config_cuda//cuda:cuda_headers", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", - ], +cc_library( + name = "cuda_conditional_kernels", + srcs = ["cuda_conditional_kernels.cc"], ) # TODO(leary) we likely need to canonicalize/eliminate this. @@ -495,91 +456,192 @@ cc_library( ]), ) -cc_library( +cuda_only_cc_library( name = "cuda_event", - srcs = if_cuda_is_configured(["cuda_event.cc"]), - hdrs = if_cuda_is_configured(["cuda_event.h"]), - deps = if_cuda_is_configured([ - ":cuda_driver", - ":cuda_stream", - "//xla/stream_executor:stream_executor_headers", + srcs = ["cuda_event.cc"], + hdrs = ["cuda_event.h"], + deps = [ + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_stream_header", - "@tsl//tsl/platform:statusor", - ]), + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_stream", srcs = [], - hdrs = if_cuda_is_configured(["cuda_stream.h"]), - deps = if_cuda_is_configured([ - ":cuda_driver", - "//xla/stream_executor:stream_executor_headers", + hdrs = ["cuda_stream.h"], + deps = [ + "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/platform", - ]), + ], ) cc_library( - name = "cuda_asm_compiler", - srcs = if_cuda_is_configured(["cuda_asm_compiler.cc"]), - copts = tf_additional_gpu_compilation_copts(), - deps = if_cuda_is_configured([ - "@com_google_absl//absl/base", + name = "ptx_compiler_support", + srcs = ["ptx_compiler_support.cc"], + hdrs = ["ptx_compiler_support.h"], + local_defines = select({ + ":libnvptxcompiler_support_enabled": [ + "LIBNVPTXCOMPILER_SUPPORT=true", + ], + "//conditions:default": [ + "LIBNVPTXCOMPILER_SUPPORT=false", + ], + }), +) + +cc_library( + name = "ptx_compiler_stub", + srcs = [ + "ptx_compiler_stub.cc", + ], + deps = [ + "//xla/stream_executor/gpu:gpu_asm_opts", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "ptx_compiler_impl", + srcs = [ + "ptx_compiler_impl.cc", + ], + tags = ["manual"], + deps = [ + "//xla/stream_executor/gpu:gpu_asm_opts", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:nvptxcompiler", + ], +) + +cc_library( + name = "ptx_compiler", + hdrs = ["ptx_compiler.h"], + deps = select({ + ":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"], + "//conditions:default": [":ptx_compiler_stub"], + }) + [ + "//xla/stream_executor/gpu:gpu_asm_opts", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_cc_test( + name = "ptx_compiler_test", + srcs = ["ptx_compiler_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":ptx_compiler", + ":ptx_compiler_support", + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:gpu_asm_opts", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cuda_only_cc_library( + name = "cuda_asm_compiler", + hdrs = ["cuda_asm_compiler.h"], + visibility = internal_visibility([ + "//third_party/py/jax:__subpackages__", + "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", + "//xla/service/gpu:__subpackages__", + "//xla/stream_executor:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", + ]), + deps = [ + ":cuda_asm_compiler_legacy", + "//xla/stream_executor/gpu:asm_compiler", + ], +) + +cuda_only_cc_library( + name = "cuda_asm_compiler_legacy", + srcs = ["cuda_asm_compiler.cc"], + copts = tf_additional_gpu_compilation_copts(), + deps = [ "//xla:status_macros", + "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:asm_compiler_header", - "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:subprocess", - ]), + ], ) -cc_library( +cuda_only_cc_library( name = "cuda_executor", - srcs = if_cuda_is_configured(["cuda_executor.cc"]), - deps = if_cuda_is_configured([ - ":cuda_activation", - ":cuda_asm_compiler", + srcs = ["cuda_executor.cc"], + deps = [ + ":cuda_collectives", # buildcleaner: keep ":cuda_diagnostics", ":cuda_driver", - ":cuda_event", - ":cuda_kernel", + ":cuda_event", # buildcleaner: keep + ":cuda_kernel", # buildcleaner: keep ":cuda_platform_id", - ":cuda_runtime", - ":cuda_stream", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/strings:str_format", + ":cuda_runtime", # buildcleaner: keep "//xla/stream_executor", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/gpu:asm_compiler", + "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:gpu_command_buffer", + "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_event_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", + "//xla/stream_executor/gpu:gpu_runtime_header", + "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/integrations:device_mem_allocator", + "@com_google_absl//absl/base", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - ]), + ], alwayslink = True, ) @@ -593,9 +655,9 @@ cc_library( ":cuda_platform", ":cudnn_plugin", ":cufft_plugin", - "@tsl//tsl/cuda:cusolver", - "@tsl//tsl/cuda:cusparse", - "@tsl//tsl/cuda:tensorrt_rpath", + "//xla/tsl/cuda:cusolver", + "//xla/tsl/cuda:cusparse", + "//xla/tsl/cuda:tensorrt_rpath", ], alwayslink = 1, ) @@ -621,7 +683,7 @@ cc_library( ], }), [ - "@tsl//tsl/cuda:cudart", + "//xla/tsl/cuda:cudart", ] + select({ "@tsl//tsl:macos": ["IOKit"], "//conditions:default": [], @@ -629,22 +691,12 @@ cc_library( ), ) -xla_cc_test( - name = "redzone_allocator_test", - srcs = ["redzone_allocator_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = tf_cuda_tests_tags() + [ - "no_cuda_asan", # TODO(b/171512140): re-enable. - ], - deps = [ - ":cuda_activation", - ":cuda_executor", - "//xla/stream_executor", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor/gpu:gpu_asm_opts", - "//xla/stream_executor/gpu:redzone_allocator", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], +cc_library( + name = "cudnn_frontend_helpers", + srcs = ["cudnn_frontend_helpers.cc"], + hdrs = ["cudnn_frontend_helpers.h"], +) + +cc_library( + name = "cuda_nvptxcompiler", ) diff --git a/xla/stream_executor/cuda/cuda_activation.h b/xla/stream_executor/cuda/cuda_activation.h index bcdc8f5e178ef..f43737c4f226c 100644 --- a/xla/stream_executor/cuda/cuda_activation.h +++ b/xla/stream_executor/cuda/cuda_activation.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.cc b/xla/stream_executor/cuda/cuda_asm_compiler.cc index 50df8ba769964..36cae5b75685a 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include +#include #include -#include "absl/base/attributes.h" -#include "absl/base/call_once.h" +#include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/status_macros.h" #include "xla/stream_executor/gpu/asm_compiler.h" -#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/subprocess.h" namespace stream_executor { @@ -39,20 +53,36 @@ namespace stream_executor { std::ostringstream oss; \ oss << error_string << "\nin " << __FILE__ << "(" << __LINE__ << "): '" \ << #expr << "'"; \ - return tsl::Status(absl::StatusCode::kUnknown, oss.str().c_str()); \ + return absl::UnknownError(oss.str().c_str()); \ } \ } while (false) -tsl::StatusOr> LinkUsingNvlink( +static absl::StatusOr FindNvlinkExecutable( + std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kMinimumNvlinkVersion{11, 8, 0}; + static constexpr absl::Span kNoExcludedVersions{}; + static constexpr std::string_view kNvLinkBinaryName = "nvlink"; + + return FindCudaExecutable(kNvLinkBinaryName, preferred_cuda_dir, + kMinimumNvlinkVersion, kNoExcludedVersions); +} + +absl::StatusOr GetNvLinkVersion( + std::string_view preferred_cuda_dir) { + // Make sure nvlink exists and is executable. + TF_ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); + + return GetToolVersion(bin_path); +} + +absl::StatusOr> LinkUsingNvlink( absl::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images) { - { - static absl::once_flag log_once; - absl::call_once(log_once, - [] { LOG(INFO) << "Using nvlink for parallel linking"; }); - } - const std::string bin_path = - FindCudaExecutable("nvlink", std::string(preferred_cuda_dir)); + LOG_FIRST_N(INFO, 1) << "Using nvlink for parallel linking"; + + TF_ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); if (images.empty()) { return std::vector(); @@ -101,14 +131,14 @@ tsl::StatusOr> LinkUsingNvlink( tsl::SubProcess process; process.SetProgram(bin_path, args); process.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); - + VLOG(5)<<"subprocess running:"<> LinkUsingNvlink( return cubin_vector; } -tsl::StatusOr> LinkGpuAsm( +absl::StatusOr> LinkGpuAsm( gpu::GpuContext* context, std::vector images) { gpu::ScopedActivateContext activation(context); diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.h b/xla/stream_executor/cuda/cuda_asm_compiler.h new file mode 100644 index 0000000000000..d565fc22ba5b4 --- /dev/null +++ b/xla/stream_executor/cuda/cuda_asm_compiler.h @@ -0,0 +1,22 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_ASM_COMPILER_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_ASM_COMPILER_H_ + +// TODO(hebecker): Move CUDA-specific functions from asm_compiler.h into here +#include "xla/stream_executor/gpu/asm_compiler.h" // IWYU pragma: export + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_ASM_COMPILER_H_ diff --git a/xla/stream_executor/cuda/cuda_blas.cc b/xla/stream_executor/cuda/cuda_blas.cc index a03f609398c1f..ca2bda8a6cecc 100644 --- a/xla/stream_executor/cuda/cuda_blas.cc +++ b/xla/stream_executor/cuda/cuda_blas.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,22 +16,32 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas.h" #include +#include #include +#include +#include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "Eigen/Core" // from @eigen_archive #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/cuda_activation.h" +#include "third_party/gpus/cuda/include/cuda_bf16.h" +#include "third_party/gpus/cuda/include/cuda_fp16.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#include "third_party/gpus/cuda/include/library_types.h" +#include "third_party/gpus/cuda/include/vector_types.h" +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" -#include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -43,9 +53,11 @@ limitations under the License. #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" +#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace cuda { @@ -354,14 +366,15 @@ struct CUDADataType> { } // namespace template -tsl::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, - bool pointer_mode_host, - cublasMath_t math_type, Args... args) { +absl::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, + bool pointer_mode_host, + cublasMath_t math_type, + Args... args) { absl::MutexLock lock(&mu_); CHECK(blas_ != nullptr); if (!SetStream(stream)) { - return tsl::errors::Internal("Failed setting stream"); + return absl::InternalError("Failed setting stream"); } // Set workspace to a user-owned buffer, otherwise cuBlas will use its own @@ -384,7 +397,7 @@ tsl::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, if (math_type == CUBLAS_TENSOR_OP_MATH) { #endif if (!math_mode.Init(math_type)) { - return tsl::errors::Internal("Failed initializing math mode"); + return absl::InternalError("Failed initializing math mode"); } } @@ -392,13 +405,13 @@ tsl::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, ScopedCublasPointerMode pointer_mode{blas_}; if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST : CUBLAS_POINTER_MODE_DEVICE)) { - return tsl::errors::Internal("Failed setting error mode"); + return absl::InternalError("Failed setting error mode"); } cublasStatus_t ret = cublas_func(blas_, args...); if (ret == CUBLAS_STATUS_SUCCESS) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } - return tsl::errors::Internal(ToString(ret)); + return absl::InternalError(ToString(ret)); } // cublas_func may be overloaded, so we need to figure out which one we really @@ -417,36 +430,6 @@ bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, GpuMemoryMutable(y), incy); } -bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */, - elem_count, &alpha, GpuMemory(x), incx, - GpuMemoryMutable(y), incy); -} - -bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - auto cb_alpha = GpuComplexValue(alpha); - return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */, - elem_count, GpuComplex(&cb_alpha), - GpuComplex(GpuMemory(x)), incx, - GpuComplex(GpuMemoryMutable(y)), incy); -} - -bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - auto cb_alpha = GpuComplexValue(alpha); - return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */, - elem_count, GpuComplex(&cb_alpha), - GpuComplex(GpuMemory(x)), incx, - GpuComplex(GpuMemoryMutable(y)), incy); -} - bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { @@ -455,30 +438,6 @@ bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count, incy); } -bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */, - elem_count, GpuMemory(x), incx, GpuMemoryMutable(y), - incy); -} - -bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */, - elem_count, GpuComplex(GpuMemory(x)), incx, - GpuComplex(GpuMemoryMutable(y)), incy); -} - -bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */, - elem_count, GpuComplex(GpuMemory(x)), incx, - GpuComplex(GpuMemoryMutable(y)), incy); -} - bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, DeviceMemory *x, int incx) { return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */, @@ -584,25 +543,12 @@ bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, incy); } -bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, - uint64_t k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy) { - return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */, - CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), - lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y), - incy); -} - -tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, blas::DataType dtype, - const void *alpha, const DeviceMemoryBase &a, - int lda, const DeviceMemoryBase &b, int ldb, - const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options, - blas::CallContext context) { +absl::Status CUDABlas::DoBlasGemm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, + uint64 n, uint64_t k, blas::DataType dtype, const void *alpha, + const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, + const void *beta, DeviceMemoryBase *c, int ldc, + const NumericOptions &numeric_options, blas::CallContext context) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH; #if CUDA_VERSION < 11000 @@ -714,8 +660,8 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, static_cast(c->opaque()), ldc); } default: - return tsl::errors::Internal("Unsupported datatype for GEMM: ", - blas::DataTypeString(dtype)); + return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", + blas::DataTypeString(dtype))); } } @@ -724,39 +670,39 @@ static bool UsesTensorOps(blas::AlgorithmType algo) { return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP; } -static tsl::StatusOr GetMathTypeForGemmEx( +static absl::StatusOr GetMathTypeForGemmEx( Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a, blas::DataType type_b, const NumericOptions &numeric_options) { if (type_a != type_b) { - return tsl::errors::Internal("Types of inputs mismatch"); + return absl::InternalError("Types of inputs mismatch"); } // GPUs < sm_50 don't support cublasGemmEx. CudaComputeCapability cc = stream->GetCudaComputeCapability(); if (cc.major < 5) { - return tsl::errors::Internal("sm_", cc.major, - " does not support explicit gemm algorithms."); + return absl::InternalError(absl::StrCat( + "sm_", cc.major, " does not support explicit gemm algorithms.")); } bool algo_uses_tensor_ops = UsesTensorOps(algorithm); cublasMath_t math_type = CUBLAS_DEFAULT_MATH; if (algo_uses_tensor_ops) { if (cc.major < 7) { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Algorithm ", algorithm, " uses tensor ops, but tensor ops are not available in sm", cc.major, - "X devices."); + "X devices.")); } else if (type_a == blas::DataType::kFloat) { #if CUDA_VERSION < 11000 - return tsl::errors::Internal( + return absl::InternalError( "Algorithm ", algorithm, " uses tensor ops, but tensor ops are not available for fp32"); #else if (cc.major < 8) { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Algorithm ", algorithm, " uses tensor ops, but tensor ops are not available in sm", - cc.major, "X devices for float input types."); + cc.major, "X devices for float input types.")); } math_type = CUBLAS_TF32_TENSOR_OP_MATH; #endif @@ -765,9 +711,9 @@ static tsl::StatusOr GetMathTypeForGemmEx( math_type = CUBLAS_TENSOR_OP_MATH; #endif } else { - return tsl::errors::Internal( - "Algorithm ", algorithm, - " uses tensor ops which are not supported for input"); + return absl::InternalError( + absl::StrCat("Algorithm ", algorithm, + " uses tensor ops which are not supported for input")); } } if (!numeric_options.allow_tf32) { @@ -777,7 +723,7 @@ static tsl::StatusOr GetMathTypeForGemmEx( return math_type; } -static tsl::Status PopulateProfileFromTimer( +static absl::Status PopulateProfileFromTimer( std::optional &timer, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { if (output_profile_result) { @@ -787,10 +733,10 @@ static tsl::Status PopulateProfileFromTimer( output_profile_result->set_elapsed_time_in_ms( absl::ToDoubleMilliseconds(duration)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CUDABlas::DoBlasGemmWithAlgorithm( +absl::Status CUDABlas::DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, const DeviceMemoryBase &b, @@ -804,8 +750,10 @@ tsl::Status CUDABlas::DoBlasGemmWithAlgorithm( TF_ASSIGN_OR_RETURN( std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), - output_profile_result != nullptr)); + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + output_profile_result != nullptr)); // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, // we do the following compile-time check on the default value: @@ -820,10 +768,10 @@ tsl::Status CUDABlas::DoBlasGemmWithAlgorithm( static_cast(algorithm))); TF_RETURN_IF_ERROR( PopulateProfileFromTimer(timer, algorithm, output_profile_result)); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( +absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b, @@ -837,8 +785,10 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); TF_ASSIGN_OR_RETURN( std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), - output_profile_result != nullptr)); + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + output_profile_result != nullptr)); cudaDataType_t cuda_in_type = AsCudaDataType(type_a); #if CUDA_VERSION >= 11000 @@ -874,14 +824,14 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( AsCublasComputeType(computation_type), static_cast(algorithm))); } else { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Unsupported type combination for GEMM: %s and %s", - blas::DataTypeString(type_a), blas::DataTypeString(type_c)); + blas::DataTypeString(type_a), blas::DataTypeString(type_c))); } } TF_RETURN_IF_ERROR( PopulateProfileFromTimer(timer, algorithm, output_profile_result)); - return tsl::OkStatus(); + return absl::OkStatus(); } #endif @@ -894,11 +844,13 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( static_cast(algorithm))); TF_RETURN_IF_ERROR( PopulateProfileFromTimer(timer, algorithm, output_profile_result)); - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool CUDABlas::GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) { + Stream *stream, const gpu::MatrixDescriptor &, + const gpu::MatrixDescriptor &, gpu::OutputMatrixDescriptor *, const void *, + const void *, std::vector *out_algorithms) { // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) // were first introduced in CUDA 8. // @@ -988,7 +940,7 @@ T inline GpuComplexValue(T v) { } // namespace template -tsl::Status CUDABlas::DoBlasGemmBatchedInternal( +absl::Status CUDABlas::DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha, const DeviceMemorySlice &a_ptrs_to_wrappers, int lda, @@ -1007,53 +959,22 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( const size_t size = batch_count * sizeof(CUDA_T *); - // Device-side copy of pointers to matrices. - DeviceMemory a; - DeviceMemory b; - DeviceMemory c; - - // If temporary space is allocated for device-side copies of pointers to - // matrices, that temporary space should not be freed until this function - // returns. Although the values for these unique_ptrs are not set here, they - // are declared at this scope so they will be destroyed when the function - // returns. - // - // If a scratch allocator is provided, these pointers will not be used at all. - std::unique_ptr> a_temporary; - std::unique_ptr> b_temporary; - std::unique_ptr> c_temporary; - - // Decide how to allocate device-side copy of pointers to matrices based on - // whether a scratch allocator was passed. - if (scratch_allocator != nullptr) { - TF_ASSIGN_OR_RETURN(DeviceMemory a_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceMemory b_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceMemory c_bytes, - scratch_allocator->AllocateBytes(size)); - a = DeviceMemory(a_bytes); - b = DeviceMemory(b_bytes); - c = DeviceMemory(c_bytes); - } else { - TF_ASSIGN_OR_RETURN(a_temporary, - stream->AllocateTemporaryArray(batch_count)); - TF_ASSIGN_OR_RETURN(b_temporary, - stream->AllocateTemporaryArray(batch_count)); - TF_ASSIGN_OR_RETURN(c_temporary, - stream->AllocateTemporaryArray(batch_count)); - a = DeviceMemory(*a_temporary->mutable_device_memory()); - b = DeviceMemory(*b_temporary->mutable_device_memory()); - c = DeviceMemory(*c_temporary->mutable_device_memory()); - } - - if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() || - !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() || - !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) { - return tsl::Status(absl::StatusCode::kInternal, - "failed to copy memory from host to device in " - "CUDABlas::DoBlasGemmBatched"); + if (scratch_allocator == nullptr) { + return absl::InternalError("scratch_allocator is null"); } + TF_ASSIGN_OR_RETURN(DeviceMemory a_bytes, + scratch_allocator->AllocateBytes(size)); + TF_ASSIGN_OR_RETURN(DeviceMemory b_bytes, + scratch_allocator->AllocateBytes(size)); + TF_ASSIGN_OR_RETURN(DeviceMemory c_bytes, + scratch_allocator->AllocateBytes(size)); + DeviceMemory a(a_bytes); + DeviceMemory b(b_bytes); + DeviceMemory c(c_bytes); + + TF_RETURN_IF_ERROR(stream->Memcpy(&a, a_raw_ptrs.data(), size)); + TF_RETURN_IF_ERROR(stream->Memcpy(&b, b_raw_ptrs.data(), size)); + TF_RETURN_IF_ERROR(stream->Memcpy(&c, c_raw_ptrs.data(), size)); cudaDataType_t data_type = CUDADataType::type; @@ -1113,10 +1034,9 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( const_cast(GpuMemory(b)), ldb, GpuComplex(&cb_beta), const_cast(GpuMemory(c)), ldc, batch_count); if (ok) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } - return tsl::Status(absl::StatusCode::kInternal, - "failed BLAS call, see log for details"); + return absl::InternalError("failed BLAS call, see log for details"); } else { // Fall back to a loop for fp16 for (int b = 0; b < batch_count; ++b) { @@ -1128,7 +1048,7 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, numeric_options, blas::CallContext::kNone)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } } @@ -1141,7 +1061,7 @@ bool CUDABlas::DoBlasGemmBatched( blas::CallContext context) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of fp16 inside DoBlasGemmBatchedInternal. - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, scratch_allocator); @@ -1161,7 +1081,7 @@ bool CUDABlas::DoBlasGemmBatched( blas::CallContext context) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of bf16 inside DoBlasGemmBatchedInternal. - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, scratch_allocator); @@ -1178,7 +1098,7 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, scratch_allocator); @@ -1195,7 +1115,7 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1214,7 +1134,7 @@ bool CUDABlas::DoBlasGemmBatched( std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, @@ -1233,7 +1153,7 @@ bool CUDABlas::DoBlasGemmBatched( std::complex beta, DeviceMemorySlice> c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - tsl::Status status = DoBlasGemmBatchedInternal( + absl::Status status = DoBlasGemmBatchedInternal( cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, scratch_allocator); @@ -1243,7 +1163,7 @@ bool CUDABlas::DoBlasGemmBatched( return status.ok(); } -tsl::Status CUDABlas::DoBlasGemmStridedBatched( +absl::Status CUDABlas::DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, blas::DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, @@ -1295,7 +1215,7 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched( b_matrix, CUDA_R_16BF, ldb, static_cast(beta), c_matrix, CUDA_R_16BF, ldc)); } - return tsl::OkStatus(); + return absl::OkStatus(); } #endif case dnn::kHalf: { @@ -1327,7 +1247,7 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched( b_matrix, CUDA_R_16F, ldb, static_cast(beta), c_matrix, CUDA_R_16F, ldc)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } case dnn::kFloat: { return DoBlasInternalImpl( @@ -1378,8 +1298,8 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched( batch_count); } default: - return tsl::errors::Internal("Unsupported datatype for GEMM: ", - blas::DataTypeString(dtype)); + return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", + blas::DataTypeString(dtype))); } } @@ -1493,20 +1413,20 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, reinterpret_cast(GpuMemoryMutable(bs)), ldb, batch_count); } -tsl::Status CUDABlas::GetVersion(std::string *version) { +absl::Status CUDABlas::GetVersion(std::string *version) { absl::MutexLock lock(&mu_); int v; auto status = cublasGetVersion(blas_, &v); if (status != CUBLAS_STATUS_SUCCESS) { - return tsl::errors::Internal(ToString(status)); + return absl::InternalError(ToString(status)); } *version = std::to_string(v); - return ::tsl::OkStatus(); + return absl::OkStatus(); } void initialize_cublas() { - tsl::Status status = + absl::Status status = PluginRegistry::Instance()->RegisterFactory( kCudaPlatformId, "cuBLAS", [](::stream_executor::internal::StreamExecutorInterface *parent) @@ -1537,5 +1457,6 @@ void initialize_cublas() { } // namespace cuda } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_cublas, - { stream_executor::cuda::initialize_cublas(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cublas, { + stream_executor::cuda::initialize_cublas(); +}); diff --git a/xla/stream_executor/cuda/cuda_blas.h b/xla/stream_executor/cuda/cuda_blas.h index 3dc54f976edc0..5f69e8b04765a 100644 --- a/xla/stream_executor/cuda/cuda_blas.h +++ b/xla/stream_executor/cuda/cuda_blas.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,14 +20,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ +#include + #include "absl/base/thread_annotations.h" +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/driver_types.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/plugin_registry.h" namespace stream_executor { @@ -84,9 +87,9 @@ class CUDABlas : public blas::BlasSupport { // (true) or device (false). // args: Arguments of cuBLAS function. template - tsl::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream, - bool pointer_mode_host, cublasMath_t math_type, - Args... args); + absl::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream, + bool pointer_mode_host, + cublasMath_t math_type, Args... args); // Convenience functions that call DoBlasInternalImpl with err_on_failure=true // and math_type=CUBLAS_DEFAULT_MATH. @@ -101,7 +104,7 @@ class CUDABlas : public blas::BlasSupport { // A helper function to implement DoBlasGemmBatched interfaces for generic // types. template - tsl::Status DoBlasGemmBatchedInternal( + absl::Status DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha, const DeviceMemorySlice &a_array, int lda, diff --git a/xla/stream_executor/cuda/cuda_blas_lt.cc b/xla/stream_executor/cuda/cuda_blas_lt.cc index af0487f6931e1..859e298112fef 100644 --- a/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,19 +15,27 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas_lt.h" +#include #include #include #include +#include #include #include #include -#include #include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/primitive_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/blas.h" @@ -40,13 +48,18 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/types.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" #define SET_ATTR(setter, handle, attr, value) \ ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter) #define GET_ATTR(getter, handle, attr, ValueT) \ - [&]() -> tsl::StatusOr { \ + [&]() -> absl::StatusOr { \ ValueT value; \ TF_RETURN_IF_ERROR(ToStatus( \ getter(handle, attr, &value, sizeof(ValueT), nullptr), #getter)); \ @@ -62,32 +75,32 @@ using ::xla::complex64; namespace { template -tsl::Status SetAttr(cublasLtMatrixLayout_t handle, - cublasLtMatrixLayoutAttribute_t attr, T value) { +absl::Status SetAttr(cublasLtMatrixLayout_t handle, + cublasLtMatrixLayoutAttribute_t attr, T value) { return SET_ATTR(cublasLtMatrixLayoutSetAttribute, handle, attr, value); } template -tsl::StatusOr GetAttr(cublasLtMatrixLayout_t handle, - cublasLtMatrixLayoutAttribute_t attr) { +absl::StatusOr GetAttr(cublasLtMatrixLayout_t handle, + cublasLtMatrixLayoutAttribute_t attr) { return GET_ATTR(cublasLtMatrixLayoutGetAttribute, handle, attr, T); } template -tsl::Status SetAttr(cublasLtMatmulDesc_t handle, - cublasLtMatmulDescAttributes_t attr, T value) { +absl::Status SetAttr(cublasLtMatmulDesc_t handle, + cublasLtMatmulDescAttributes_t attr, T value) { return SET_ATTR(cublasLtMatmulDescSetAttribute, handle, attr, value); } template -tsl::StatusOr GetAttr(cublasLtMatmulDesc_t handle, - cublasLtMatmulDescAttributes_t attr) { +absl::StatusOr GetAttr(cublasLtMatmulDesc_t handle, + cublasLtMatmulDescAttributes_t attr) { return GET_ATTR(cublasLtMatmulDescGetAttribute, handle, attr, T); } template -tsl::Status SetAttr(cublasLtMatmulPreference_t handle, - cublasLtMatmulPreferenceAttributes_t attr, T value) { +absl::Status SetAttr(cublasLtMatmulPreference_t handle, + cublasLtMatmulPreferenceAttributes_t attr, T value) { return SET_ATTR(cublasLtMatmulPreferenceSetAttribute, handle, attr, value); } @@ -101,7 +114,7 @@ cublasLtPointerMode_t AsCublasLtPointerMode( } } -tsl::StatusOr AsCublasLtEpilogue( +absl::StatusOr AsCublasLtEpilogue( gpu::BlasLt::Epilogue epilogue) { switch (epilogue) { case gpu::BlasLt::Epilogue::kDefault: @@ -126,36 +139,29 @@ tsl::StatusOr AsCublasLtEpilogue( case gpu::BlasLt::Epilogue::kGELUWithAux: case gpu::BlasLt::Epilogue::kBiasThenGELU: case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux: - return tsl::errors::Internal("GELU epilogues require cublasLt >= 11.4"); + return absl::InternalError("GELU epilogues require cublasLt >= 11.4"); #endif } } } // namespace -tsl::Status BlasLt::Init() { +absl::Status BlasLt::Init() { cublasLtHandle_t blas_lt; SE_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&blas_lt)); absl::MutexLock lock(&mu_); blas_lt_.reset(blas_lt); - return tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::StatusOr BlasLt::MatrixLayout::Create( +/*static*/ absl::StatusOr BlasLt::MatrixLayout::Create( const gpu::MatrixLayout& m) { TF_ASSIGN_OR_RETURN(auto type, gpu::AsBlasDataType(m.dtype)); - auto leading_dim_stride = m.leading_dim_stride; - if (!leading_dim_stride) { - leading_dim_stride = (m.order == gpu::MatrixLayout::Order::kRowMajor) - ? m.num_cols - : m.num_rows; - } - cublasLtMatrixLayout_t cu_layout; SE_CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&cu_layout, AsCudaDataType(type), m.num_rows, - m.num_cols, *leading_dim_stride)); + m.num_cols, m.leading_dim_stride)); // Wrap cublas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatrixLayout layout(cu_layout); TF_RETURN_IF_ERROR( @@ -166,20 +172,14 @@ tsl::Status BlasLt::Init() { TF_RETURN_IF_ERROR(SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, static_cast(m.batch_size))); - auto batch_stride = m.batch_stride; - if (!batch_stride) { - batch_stride = (m.batch_size > 1) ? m.num_rows * m.num_cols : 0; - } - VLOG(2) << "MatrixLayout::Create: num_rows: " << m.num_rows << " num_cols:" << (int)m.num_cols << ", order: " << (int)m.order - << "," - << " batchsz " << m.batch_size - << " leaddimstride: " << *leading_dim_stride - << " batch_stride: " << *batch_stride; + << "," << " batchsz " << m.batch_size + << " leaddimstride: " << m.leading_dim_stride + << " batch_stride: " << m.batch_stride; TF_RETURN_IF_ERROR(SetAttr( - cu_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, *batch_stride)); + cu_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, m.batch_stride)); return std::move(layout); } @@ -188,7 +188,7 @@ cudaDataType_t BlasLt::MatrixLayout::type() const { GetAttr(handle_.get(), CUBLASLT_MATRIX_LAYOUT_TYPE).value()); } -/*static*/ tsl::StatusOr BlasLt::MatmulDesc::Create( +/*static*/ absl::StatusOr BlasLt::MatmulDesc::Create( blas::ComputationType compute_type, blas::DataType scale_type, blas::Transpose trans_a, blas::Transpose trans_b, gpu::BlasLt::Epilogue epilogue, bool enable_fast_accum, @@ -238,7 +238,7 @@ cublasLtPointerMode_t BlasLt::MatmulDesc::pointer_mode() const { auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, size_t max_workspace_size) const - -> tsl::StatusOr> { + -> absl::StatusOr> { max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX}); std::vector results(max_algorithm_count); { @@ -276,9 +276,27 @@ auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, return std::move(algorithms); } +namespace { + +bool IsFastAccumEnabled(const xla::PrecisionConfig::Algorithm algorithm, + xla::PrimitiveType lhs_type, + xla::PrimitiveType rhs_type, + int64_t compute_precision) { + if (algorithm == xla::PrecisionConfig::ALG_UNSET) { + return (xla::primitive_util::IsF8Type(lhs_type) || + xla::primitive_util::IsF8Type(rhs_type)) && + compute_precision == 0; + } + + return algorithm == + xla::PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM; +} + +} // namespace + auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, gpu::BlasLt::Epilogue epilogue) const - -> tsl::StatusOr { + -> absl::StatusOr { auto lhs_layout = cfg.lhs_layout, rhs_layout = cfg.rhs_layout, output_layout = cfg.output_layout, c_layout = cfg.c_layout; // cublasLt matmul requires batch sizes to be equal. If only one operand has a @@ -297,18 +315,15 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, // *not* be transposed, and if B is row-major, B must be transposed. We never // transpose A or B, and expect the caller to ensure A is row-major and B is // column when A and B are FP8. - auto trans_a = lhs_layout.transpose ? *lhs_layout.transpose - : blas::Transpose::kNoTranspose; - auto trans_b = rhs_layout.transpose ? *rhs_layout.transpose - : blas::Transpose::kNoTranspose; + auto trans_a = lhs_layout.transpose, trans_b = rhs_layout.transpose; if (xla::primitive_util::IsF8Type(lhs_layout.dtype) && lhs_layout.order == gpu::MatrixLayout::Order::kColumnMajor) { - return xla::InternalError("The F8 LHS must be column-major"); + return xla::Internal("The F8 LHS must be column-major"); } if (xla::primitive_util::IsF8Type(rhs_layout.dtype) && rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { - return xla::InternalError("The F8 RHS must be row-major"); + return xla::Internal("The F8 RHS must be row-major"); } TF_ASSIGN_OR_RETURN(auto output_dtype, @@ -316,17 +331,18 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, auto compute_type = cfg.compute_type; if (!compute_type) { // obtain compute_type unless provided by the user - TF_ASSIGN_OR_RETURN(compute_type, gpu::GetBlasComputationType( - lhs_layout.dtype, output_layout.dtype, - cfg.compute_precision)); + TF_ASSIGN_OR_RETURN(compute_type, + gpu::GetBlasComputationType( + cfg.precision_algorithm, lhs_layout.dtype, + output_layout.dtype, cfg.compute_precision)); } // FP8 matmuls have a fast accumulation mode that is less precise than the // default accumulation mode. Use the fast accumulation mode if the compute // precision is DEFAULT. - bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) || - xla::primitive_util::IsF8Type(rhs_layout.dtype)) && - cfg.compute_precision == 0; + bool enable_fast_accum = + IsFastAccumEnabled(cfg.precision_algorithm, lhs_layout.dtype, + rhs_layout.dtype, cfg.compute_precision); TF_ASSIGN_OR_RETURN( auto op_desc, MatmulDesc::Create(*compute_type, @@ -344,45 +360,45 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, cfg.alpha, cfg.beta, must_swap_operands); } -tsl::Status BlasLt::MatmulPlan::ValidateInputs( +absl::Status BlasLt::MatmulPlan::ValidateInputs( blas::DataType scale_type, bool alpha_on_device, bool beta_on_device, blas::DataType A_type, blas::DataType B_type, blas::DataType C_type, blas::DataType D_type) const { if (AsCudaDataType(scale_type) != op_desc_.scale_type()) { - return tsl::errors::InvalidArgument("mismatched scale types"); + return absl::InvalidArgumentError("mismatched scale types"); } bool expect_scale_factor_on_device = (op_desc_.pointer_mode() == CUBLASLT_POINTER_MODE_DEVICE); if (alpha_on_device != expect_scale_factor_on_device) { - return tsl::errors::InvalidArgument("wrong location for alpha"); + return absl::InvalidArgumentError("wrong location for alpha"); } if (beta_on_device != expect_scale_factor_on_device) { - return tsl::errors::InvalidArgument("wrong location for beta"); + return absl::InvalidArgumentError("wrong location for beta"); } if (AsCudaDataType(A_type) != a_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched A matrix types"); + return absl::InvalidArgumentError("mismatched A matrix types"); } if (AsCudaDataType(B_type) != b_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched B matrix types"); + return absl::InvalidArgumentError("mismatched B matrix types"); } if (AsCudaDataType(C_type) != c_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched C matrix types"); + return absl::InvalidArgumentError("mismatched C matrix types"); } if (AsCudaDataType(D_type) != d_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched D matrix types"); + return absl::InvalidArgumentError("mismatched D matrix types"); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status BlasLt::MatmulPlan::DoMatmul( +absl::Status BlasLt::MatmulPlan::DoMatmul( Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, DeviceMemoryBase d, const MatmulAlgorithm& algorithm, ScratchAllocator& scratch_allocator, @@ -392,7 +408,9 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( blas::ProfileResult* profile_result) const { TF_ASSIGN_OR_RETURN( std::optional timer, - gpu::GpuTimer::CreateIfNeeded(gpu::AsGpuStream(stream), profile_result)); + gpu::GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + profile_result != nullptr)); void* workspace = nullptr; if (algorithm.workspace_size > 0) { @@ -441,7 +459,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( #else if (a_scale != nullptr || b_scale != nullptr || c_scale != nullptr || d_scale != nullptr || d_amax != nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "A/B/C/D scales and amax require cublasLt >= 11.8"); } #endif @@ -471,7 +489,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE, output_batch_stride)); #else - return tsl::errors::Internal( + return absl::InternalError( "Auxiliary inputs / outputs require cublasLt >= 11.4"); #endif } @@ -485,7 +503,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( c_desc_.get(), d.opaque(), d_desc_.get(), palgo, workspace, algorithm.workspace_size, gpu::AsGpuStreamValue(stream))); } else { - return tsl::errors::Internal("cublaslt: Invalid algorithm type"); + return absl::InternalError("cublaslt: Invalid algorithm type"); } } @@ -496,7 +514,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } - return tsl::OkStatus(); + return absl::OkStatus(); } namespace { @@ -542,7 +560,7 @@ struct CudaToNativeT { } // namespace -tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( +absl::Status BlasLt::MatmulPlan::ExecuteOnStream( Stream* stream, DeviceMemoryBase a, DeviceMemoryBase b, DeviceMemoryBase c, DeviceMemoryBase d, DeviceMemoryBase bias, DeviceMemoryBase aux, DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, @@ -613,7 +631,7 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( #undef TYPED_MATMUL - return xla::InternalError("Unexpected dtype"); + return xla::Internal("Unexpected dtype"); } } // namespace cuda diff --git a/xla/stream_executor/cuda/cuda_blas_lt.h b/xla/stream_executor/cuda/cuda_blas_lt.h index 7a758f40e2205..2a0a5611b81ce 100644 --- a/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/xla/stream_executor/cuda/cuda_blas_lt.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,22 +16,23 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ -#include +#include #include -#include -#include +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "third_party/gpus/cuda/include/library_types.h" +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/host_or_device_scalar.h" #include "xla/types.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -52,7 +53,7 @@ class BlasLt : public gpu::BlasLt { // - `num_rows` if `order == kColumnMajor`. // If `batch_stride` is not specified, it defaults to `num_rows * num_cols` // if `batch_size > 1`, otherwise `0`. - static tsl::StatusOr Create(const gpu::MatrixLayout& m); + static absl::StatusOr Create(const gpu::MatrixLayout& m); cudaDataType_t type() const; cublasLtMatrixLayout_t get() const { return handle_.get(); } @@ -66,7 +67,7 @@ class BlasLt : public gpu::BlasLt { class MatmulDesc { public: - static tsl::StatusOr Create( + static absl::StatusOr Create( blas::ComputationType compute_type, blas::DataType scale_type, blas::Transpose trans_a = blas::Transpose::kNoTranspose, blas::Transpose trans_b = blas::Transpose::kNoTranspose, @@ -103,7 +104,7 @@ class BlasLt : public gpu::BlasLt { ~MatmulPlan() override = default; - tsl::Status ExecuteOnStream( + absl::Status ExecuteOnStream( Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, DeviceMemoryBase bias_buffer, // may be null @@ -114,25 +115,25 @@ class BlasLt : public gpu::BlasLt { ScratchAllocator& scratch_allocator, blas::ProfileResult* profile_result = nullptr) const override; - tsl::StatusOr> GetAlgorithms( + absl::StatusOr> GetAlgorithms( size_t max_algorithm_count, size_t max_workspace_size) const override; protected: - tsl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device, - bool beta_on_device, blas::DataType A_type, - blas::DataType B_type, blas::DataType C_type, - blas::DataType D_type) const override; - - tsl::Status DoMatmul(Stream* stream, const void* alpha, DeviceMemoryBase a, - DeviceMemoryBase b, const void* beta, - DeviceMemoryBase c, DeviceMemoryBase d, - const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - DeviceMemoryBase bias, DeviceMemoryBase aux, - DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, - DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, - DeviceMemoryBase d_amax, - blas::ProfileResult* profile_result) const override; + absl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device, + bool beta_on_device, blas::DataType A_type, + blas::DataType B_type, blas::DataType C_type, + blas::DataType D_type) const override; + + absl::Status DoMatmul(Stream* stream, const void* alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const void* beta, + DeviceMemoryBase c, DeviceMemoryBase d, + const MatmulAlgorithm& algorithm, + ScratchAllocator& scratch_allocator, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + blas::ProfileResult* profile_result) const override; private: const BlasLt& blas_lt_ref_; @@ -150,10 +151,10 @@ class BlasLt : public gpu::BlasLt { explicit BlasLt(gpu::GpuExecutor* parent) : parent_(parent), blas_lt_(nullptr, cublasLtDestroy) {} - tsl::Status Init() override; + absl::Status Init() override; - tsl::StatusOr GetMatmulPlan(const gpu::GemmConfig& cfg, - Epilogue epilogue) const override; + absl::StatusOr GetMatmulPlan(const gpu::GemmConfig& cfg, + Epilogue epilogue) const override; ~BlasLt() override = default; diff --git a/xla/stream_executor/cuda/cuda_blas_utils.cc b/xla/stream_executor/cuda/cuda_blas_utils.cc index e42192a03adb4..ec700d6f3168f 100644 --- a/xla/stream_executor/cuda/cuda_blas_utils.cc +++ b/xla/stream_executor/cuda/cuda_blas_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,12 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/stream_executor/blas.h" namespace stream_executor { @@ -31,12 +34,11 @@ const char* ToString(cublasStatus_t status) { #endif // CUDA_VERSION >= 11050 } -tsl::Status ToStatus(cublasStatus_t status, const char* prefix) { +absl::Status ToStatus(cublasStatus_t status, const char* prefix) { if (status != CUBLAS_STATUS_SUCCESS) { - return tsl::Status(absl::StatusCode::kInternal, - absl::StrCat(prefix, ": ", ToString(status))); + return absl::InternalError(absl::StrCat(prefix, ": ", ToString(status))); } - return tsl::OkStatus(); + return absl::OkStatus(); } cudaDataType_t AsCudaDataType(blas::DataType type) { diff --git a/xla/stream_executor/cuda/cuda_blas_utils.h b/xla/stream_executor/cuda/cuda_blas_utils.h index 2e8789b1f1bcd..aaaf4257f4f5b 100644 --- a/xla/stream_executor/cuda/cuda_blas_utils.h +++ b/xla/stream_executor/cuda/cuda_blas_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ -#include +#include "absl/status/status.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "xla/stream_executor/blas.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #define SE_CUBLAS_RETURN_IF_ERROR(expr) \ TF_RETURN_IF_ERROR(::stream_executor::cuda::ToStatus(expr, #expr)) @@ -30,7 +30,7 @@ namespace stream_executor { namespace cuda { const char* ToString(cublasStatus_t status); -tsl::Status ToStatus(cublasStatus_t status, const char* prefix = "cublasLt"); +absl::Status ToStatus(cublasStatus_t status, const char* prefix = "cublasLt"); cudaDataType_t AsCudaDataType(blas::DataType type); cublasComputeType_t AsCublasComputeType(blas::ComputationType type); cublasOperation_t AsCublasOperation(blas::Transpose trans); diff --git a/xla/stream_executor/cuda/cuda_collectives.cc b/xla/stream_executor/cuda/cuda_collectives.cc new file mode 100644 index 0000000000000..79118a54034af --- /dev/null +++ b/xla/stream_executor/cuda/cuda_collectives.cc @@ -0,0 +1,78 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/stream_executor/cuda/cuda_driver.h" +#include "xla/stream_executor/gpu/gpu_collectives.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/numbers.h" + +#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL +#include "third_party/nccl/nccl.h" +#endif // STREAM_EXECUTOR_GPU_ENABLE_XCCL + +namespace stream_executor::gpu { + +/* static */ absl::StatusOr GpuCollectives::CollectiveMemoryAllocate( + GpuContext* context, uint64_t bytes) { + if (bytes == 0) return nullptr; + + ScopedActivateContext activated(context); + +#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL + void* ptr = nullptr; + ncclResult_t res = ncclMemAlloc(&ptr, bytes); + if (res != ncclSuccess) { + return absl::InternalError(absl::StrFormat( + "failed to allocate %s (%llu bytes) from device collective memory: %s, " + "Last NCCL warning(error) log entry (may be unrelated): %s", + tsl::strings::HumanReadableNumBytes(bytes), bytes, + ncclGetErrorString(res), ncclGetLastError(nullptr))); + } + VLOG(2) << "Allocated collective memory " << ptr << " for context " + << context->context() << " of " << bytes << " bytes"; + return ptr; +#else + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); +#endif +} + +/* static */ absl::Status GpuCollectives::CollectiveMemoryDeallocate( + GpuContext* context, void* location) { + ScopedActivateContext activation(context); + +#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL + ncclResult_t res = ncclMemFree(location); + if (res != ncclSuccess) { + return absl::InternalError(absl::StrFormat( + "failed to free device collective memory at %p; result: %s, Last NCCL " + "warning(error) log entry (may be unrelated): %s", + location, ncclGetErrorString(res), ncclGetLastError(nullptr))); + } + + VLOG(2) << "Deallocated collective memory " << location << " for context " + << context->context(); + return absl::OkStatus(); +#else + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); +#endif +} + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/xla/stream_executor/cuda/cuda_command_buffer_test.cc deleted file mode 100644 index d3064b4bdde44..0000000000000 --- a/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ /dev/null @@ -1,821 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/log/check.h" -#include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_test_kernels.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" - -namespace stream_executor::cuda { - -using AddI32Kernel = TypedKernel, DeviceMemory, - DeviceMemory>; -using MulI32Kernel = TypedKernel, DeviceMemory, - DeviceMemory>; -using IncAndCmpKernel = - TypedKernel, DeviceMemory, int32_t>; - -using AddI32Ptrs3 = TypedKernel>; - -static constexpr auto nested = CommandBuffer::Mode::kNested; // NOLINT -static constexpr auto primary = CommandBuffer::Mode::kPrimary; // NOLINT - -TEST(CudaCommandBufferTest, LaunchSingleKernel) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - - AddI32Kernel add(executor); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); - - // Create a command buffer with a single kernel launch. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), a, b, c)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `c` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); - - // Prepare argument for graph update: d = 0 - DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); - - // Update command buffer to write into `d` buffer. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), a, b, d)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, TraceSingleKernel) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Ptrs3 add(executor); - - // Register a kernel with a custom arguments packing function that packs - // device memory arguments into a struct with pointers. - MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgs& args) { - auto bufs = Cast(&args)->device_memory_args(); - auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; - return PackKernelArgs(add, internal::Ptrs3{ - cast(bufs[0]), - cast(bufs[1]), - cast(bufs[2]), - }); - }); - spec.AddInProcessSymbol(internal::GetAddI32Ptrs3CudaKernel(), "add"); - - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); - - // Use an array of device memory base pointers as argument to test packing. - KernelArgsDeviceMemoryArray args({a, b, c}, 0); - - // Create a command buffer by tracing kernel launch operations. - auto cmd_buffer = CommandBuffer::Trace(executor, [&](Stream* stream) { - return executor->Launch(stream, ThreadDim(), BlockDim(4), add, args); - }); - - TF_ASSERT_OK(cmd_buffer.status()); - TF_ASSERT_OK(executor->Submit(&stream, *cmd_buffer)); - - // Copy data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); - - AddI32Kernel add(executor); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); - - // Create a command buffer with a single kernel launch. - auto primary_cmd = CommandBuffer::Create(executor).value(); - auto nested_cmd = CommandBuffer::Create(executor, nested).value(); - TF_ASSERT_OK(nested_cmd.Launch(add, ThreadDim(), BlockDim(4), a, b, c)); - TF_ASSERT_OK(primary_cmd.AddNestedCommandBuffer(nested_cmd)); - TF_ASSERT_OK(primary_cmd.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, primary_cmd)); - - // Copy `c` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); - - // Prepare argument for graph update: d = 0 - DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); - - // Update command buffer to write into `d` buffer by creating a new nested - // command buffer. - nested_cmd = CommandBuffer::Create(executor, nested).value(); - TF_ASSERT_OK(nested_cmd.Launch(add, ThreadDim(), BlockDim(4), a, b, d)); - TF_ASSERT_OK(primary_cmd.Update()); - TF_ASSERT_OK(primary_cmd.AddNestedCommandBuffer(nested_cmd)); - TF_ASSERT_OK(primary_cmd.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, primary_cmd)); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, MemcpyDeviceToDevice) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=42, b=uninitialized - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 42, byte_length); - - // Create a command buffer with a single a to b memcpy command. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&b, a, byte_length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `b` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); - - std::vector expected = {42, 42, 42, 42}; - ASSERT_EQ(dst, expected); - - // Update command buffer to swap the memcpy direction. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.MemcpyDeviceToDevice(&a, b, byte_length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - // Clear destination to test that command buffer actually copied memory. - stream.ThenMemset32(&a, 0, byte_length); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `a` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), a, byte_length); - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, Memset) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - DeviceMemory a = executor->AllocateArray(length, 0); - - // Create a command buffer with a single memset command. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{42}, length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `a` data back to host. - std::vector dst(4, 0); - stream.ThenMemcpy(dst.data(), a, byte_length); - - std::vector expected = {42, 42, 42, 42}; - ASSERT_EQ(dst, expected); - - // Update command buffer to use a new bit pattern. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.Memset(&a, uint32_t{43}, length)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 0); - stream.ThenMemcpy(dst.data(), a, byte_length); - - expected = {43, 43, 43, 43}; - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, ConditionalIf) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0, pred=true - DeviceMemory pred = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); - - // if (pred == true) c = a + b - CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { - return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); - }; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `c` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); - - // Reset predicate to false and clear output buffer. - constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); - stream.ThenMemZero(&c, byte_length); - - // Submit the same command buffer, but this time it should not execute - // conditional branch as conditional handle should be updated to false. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - stream.ThenMemcpy(dst.data(), c, byte_length); - std::vector zeroes = {0, 0, 0, 0}; - ASSERT_EQ(dst, zeroes); - - // Prepare argument for graph update: d = 0 - DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); - - // Set predicate buffer to true to run conditional command buffer. - stream.ThenMemcpy(&pred, &kTrue, 1); - - // if (pred == true) d = a + b (write to a new location). - then_builder = [&](CommandBuffer* then_cmd) { - return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, d); - }; - - // Update command buffer with a conditional to use new builder. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.If(executor, pred, then_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, ConditionalIfElse) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); - MulI32Kernel mul(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } - - { // Load multiplication kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetMulI32CudaKernel(), "mul"); - TF_ASSERT_OK(executor->GetKernel(spec, &mul)); - } - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=2, b=3, c=0, pred=true - DeviceMemory pred = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - constexpr bool kTrue = true; - stream.ThenMemcpy(&pred, &kTrue, 1); - stream.ThenMemset32(&a, 2, byte_length); - stream.ThenMemset32(&b, 3, byte_length); - stream.ThenMemZero(&c, byte_length); - - // if (pred == true) c = a + b - CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { - return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); - }; - - // if (pred == false) c = a * b - CommandBuffer::Builder else_builder = [&](CommandBuffer* else_cmd) { - return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); - }; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - // Copy `c` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected_add = {5, 5, 5, 5}; - ASSERT_EQ(dst, expected_add); - - // Reset predicate to false. - constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); - - // Submit the same command buffer, but this time it should execute `else` - // branch and multiply inputs. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - stream.ThenMemcpy(dst.data(), c, byte_length); - std::vector expected_mul = {6, 6, 6, 6}; - ASSERT_EQ(dst, expected_mul); - - // Prepare argument for graph update: d = 0 - DeviceMemory d = executor->AllocateArray(length, 0); - stream.ThenMemZero(&d, byte_length); - - // if (pred == false) d = a * b (write to a new location). - else_builder = [&](CommandBuffer* else_cmd) { - return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, d); - }; - - // Update command buffer with a conditional to use new `else` builder. - TF_ASSERT_OK(cmd_buffer.Update()); - TF_ASSERT_OK(cmd_buffer.IfElse(executor, pred, then_builder, else_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - // Copy `d` data back to host. - std::fill(dst.begin(), dst.end(), 42); - stream.ThenMemcpy(dst.data(), d, byte_length); - ASSERT_EQ(dst, expected_mul); -} - -TEST(CudaCommandBufferTest, ConditionalCase) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); - MulI32Kernel mul(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } - - { // Load multiplication kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetMulI32CudaKernel(), "mul"); - TF_ASSERT_OK(executor->GetKernel(spec, &mul)); - } - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=2, b=3, c=0, index=0 - DeviceMemory index = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&index, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 2, byte_length); - stream.ThenMemset32(&b, 3, byte_length); - stream.ThenMemZero(&c, byte_length); - - // if (index == 0) c = a + b - CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { - return branch0_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); - }; - - // if (index == 1) c = a * b - CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { - return branch1_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); - }; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.Case(executor, index, {branch0, branch1})); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - // Copy `c` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected_add = {5, 5, 5, 5}; - ASSERT_EQ(dst, expected_add); - - // Set index to `1` - stream.ThenMemset32(&index, 1, sizeof(int32_t)); - - // Submit the same command buffer, but this time it should multiply inputs. - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - stream.ThenMemcpy(dst.data(), c, byte_length); - std::vector expected_mul = {6, 6, 6, 6}; - ASSERT_EQ(dst, expected_mul); - - // Set index to `-1` (out of bound index value). - stream.ThenMemset32(&index, -1, sizeof(int32_t)); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - stream.ThenMemcpy(dst.data(), c, byte_length); - ASSERT_EQ(dst, expected_mul); - - // Set index to `2` (out of bound index value). - stream.ThenMemset32(&index, 2, sizeof(int32_t)); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - TF_ASSERT_OK(stream.BlockHostUntilDone()); - - stream.ThenMemcpy(dst.data(), c, byte_length); - ASSERT_EQ(dst, expected_mul); -} - -TEST(CudaCommandBufferTest, ConditionalFor) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=0, loop_counter=100 - DeviceMemory loop_counter = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - - // Set loop counter to 100 to check that command buffer resets it. - stream.ThenMemset32(&loop_counter, 100, sizeof(int32_t)); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemZero(&b, byte_length); - - // Loop body: b = a + b - CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { - return body_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, b); - }; - - int32_t num_iters = 10; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.For(executor, num_iters, loop_counter, body_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `b` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), b, byte_length); - - std::vector expected = {10, 10, 10, 10}; - ASSERT_EQ(dst, expected); -} - -TEST(CudaCommandBufferTest, ConditionalWhile) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - if (!CommandBuffer::SupportsConditionalCommands(platform)) { - GTEST_SKIP() << "CUDA graph conditionals are not supported"; - } - - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - AddI32Kernel add(executor); - IncAndCmpKernel inc_and_cmp(executor); - - { // Load addition kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); - TF_ASSERT_OK(executor->GetKernel(spec, &add)); - } - - { // Load inc_and_cmp kernel. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetIncAndCmpCudaKernel(), "inc_and_cmp"); - TF_ASSERT_OK(executor->GetKernel(spec, &inc_and_cmp)); - } - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=0, loop_counter=0, pred=false - // Value of `pred` is not important, as it will be updated by `cond_builder` - // below. - DeviceMemory pred = executor->AllocateArray(1, 0); - DeviceMemory loop_counter = executor->AllocateArray(1, 0); - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - - static constexpr bool kFalse = false; - stream.ThenMemcpy(&pred, &kFalse, 1); - stream.ThenMemset32(&loop_counter, 0, sizeof(int32_t)); - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemZero(&b, byte_length); - - int32_t num_iters = 10; - - // Loop cond: loop_counter++ < num_iters; - CommandBuffer::Builder cond_builder = [&](CommandBuffer* cond_cmd) { - return cond_cmd->Launch(inc_and_cmp, ThreadDim(), BlockDim(), loop_counter, - pred, num_iters); - }; - - // Loop body: b = a + b - CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { - return body_cmd->Launch(add, ThreadDim(), BlockDim(length), a, b, b); - }; - - // Create a command buffer with a single conditional operation. - auto cmd_buffer = CommandBuffer::Create(executor).value(); - TF_ASSERT_OK(cmd_buffer.While(executor, pred, cond_builder, body_builder)); - TF_ASSERT_OK(cmd_buffer.Finalize()); - - TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); - - // Copy `b` data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), b, byte_length); - - std::vector expected = {10, 10, 10, 10}; - ASSERT_EQ(dst, expected); -} - -//===----------------------------------------------------------------------===// -// Performance benchmarks below -//===----------------------------------------------------------------------===// - -#define BENCHMARK_SIZES(NAME) \ - BENCHMARK(NAME)->Arg(8)->Arg(32)->Arg(128)->Arg(512)->Arg(1024); - -// In benchmarks we construct command buffers in nested mode when we -// do not want to measure graph executable instantiation overhead. -static void BM_CreateCommandBuffer(benchmark::State& state) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); - - DeviceMemory b = executor->AllocateArray(1, 0); - - for (auto s : state) { - auto cmd_buffer = CommandBuffer::Create(executor, nested).value(); - for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); - } - CHECK_OK(cmd_buffer.Finalize()); - } -} - -BENCHMARK_SIZES(BM_CreateCommandBuffer); - -static void BM_TraceCommandBuffer(benchmark::State& state) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - CHECK(stream.ok()); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); - - DeviceMemory b = executor->AllocateArray(1, 0); - - for (auto s : state) { - auto launch_kernels = [&](Stream* stream) { - for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, b, b, b)); - } - return tsl::OkStatus(); - }; - - CHECK_OK(CommandBuffer::Trace(executor, launch_kernels, nested)); - } -} - -BENCHMARK_SIZES(BM_TraceCommandBuffer); - -static void BM_UpdateCommandBuffer(benchmark::State& state) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); - - AddI32Kernel add(executor); - CHECK_OK(executor->GetKernel(spec, &add)); - - DeviceMemory b = executor->AllocateArray(1, 0); - - auto cmd_buffer = CommandBuffer::Create(executor, primary).value(); - for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); - } - CHECK_OK(cmd_buffer.Finalize()); - - for (auto s : state) { - CHECK_OK(cmd_buffer.Update()); - for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(cmd_buffer.Launch(add, ThreadDim(), BlockDim(4), b, b, b)); - } - CHECK_OK(cmd_buffer.Finalize()); - } -} - -BENCHMARK_SIZES(BM_UpdateCommandBuffer); - -} // namespace stream_executor::cuda diff --git a/xla/stream_executor/cuda/cuda_conditional_kernels.cc b/xla/stream_executor/cuda/cuda_conditional_kernels.cc new file mode 100644 index 0000000000000..005889c540587 --- /dev/null +++ b/xla/stream_executor/cuda/cuda_conditional_kernels.cc @@ -0,0 +1,744 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +namespace stream_executor::gpu { + +// Collection of helper kernels required by command buffers on CUDA backends. We +// use pre-compiled PTX instead of a CUDA C++ because conditional nodes require +// CUDA 12.3+ and trying to run with earlier CUDA versions leads to run time +// errors as all CUDA C++ kernels registered in a global static registry and a +// failure to load ONE kernel leads to failure to load ANY kernel at all. We +// should be able to switch to CUDA C++ once the minimum supported CUDA version +// will be larger than 12.3. + +// In all kernels defined below we set conditional handle value to `1` when we +// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the +// graph will keep being executed until the conditional handle becomes `0`. + +// PTX kernel compiled from: +// +// __global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, +// bool* predicate) { +// if (*predicate) { +// cudaGraphSetConditional(then_handle, 1); +// } else { +// cudaGraphSetConditional(then_handle, 0); +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetIfConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_if_condition( + .param .u64 set_if_condition_param_0, + .param .u64 set_if_condition_param_1 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_if_condition_param_0]; + ld.param.u64 %rd2, [set_if_condition_param_1]; + .loc 1 3 3 + cvta.to.global.u64 %rd3, %rd2; + ld.global.u8 %rs1, [%rd3]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 4 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_3: + .loc 1 8 1 + ret; + +})"; +} + +// PTX kernel compiled from: +// +// __global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, +// cudaGraphConditionalHandle else_handle, +// bool* predicate) { +// if (*predicate) { +// cudaGraphSetConditional(then_handle, 1); +// cudaGraphSetConditional(else_handle, 0); +// } else { +// cudaGraphSetConditional(then_handle, 0); +// cudaGraphSetConditional(else_handle, 1); +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetIfElseConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_if_else_condition( + .param .u64 set_if_else_condition_param_0, + .param .u64 set_if_else_condition_param_1, + .param .u64 set_if_else_condition_param_2 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<5>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_if_else_condition_param_0]; + ld.param.u64 %rd2, [set_if_else_condition_param_1]; + ld.param.u64 %rd3, [set_if_else_condition_param_2]; + .loc 1 4 3 + cvta.to.global.u64 %rd4, %rd3; + ld.global.u8 %rs1, [%rd4]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 5 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 8 5 + { // callseq 2, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 2 + .loc 1 9 5 + { // callseq 3, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 3 + +$L__BB0_3: + .loc 1 11 1 + ret; + +})"; +} + +// PTX kernel compiled from: +// +// __global__ void SetCaseCondition( +// cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1, +// cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3, +// cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5, +// cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7, +// int32_t* index, int32_t num_handles) { +// // Only handles in [0, num_handles) range are valid. +// // +// // We can't define a device function with dynamic number of handle +// // arguments, so we always pass 8 handles, but only some of them are valid. +// // Size 8 picked as a reasonable (but random) upper bound for what we see +// // in XLA uses. +// std::array handles = {h0, h1, h2, h3, +// h4, h5, h6, h7}; + +// // If branch index is out of range activate the last valid handle. +// int32_t branch_index = *index; +// if (branch_index < 0 || branch_index >= num_handles) { +// branch_index = num_handles - 1; +// } + +// for (int32_t i = 0; i < num_handles; ++i) { +// if (branch_index == i) { +// cudaGraphSetConditional(handles[i], 1); +// } else { +// cudaGraphSetConditional(handles[i], 0); +// } +// } +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetCaseConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_case_condition( + .param .u64 set_case_condition_param_0, + .param .u64 set_case_condition_param_1, + .param .u64 set_case_condition_param_2, + .param .u64 set_case_condition_param_3, + .param .u64 set_case_condition_param_4, + .param .u64 set_case_condition_param_5, + .param .u64 set_case_condition_param_6, + .param .u64 set_case_condition_param_7, + .param .u64 set_case_condition_param_8, + .param .u32 set_case_condition_param_9 +) +{ + .local .align 16 .b8 __local_depot0[64]; + .reg .b64 %SP; + .reg .b64 %SPL; + .reg .pred %p<14>; + .reg .b32 %r<31>; + .reg .b64 %rd<27>; + .loc 1 4 0 + + mov.u64 %SPL, __local_depot0; + ld.param.u64 %rd13, [set_case_condition_param_8]; + ld.param.u32 %r18, [set_case_condition_param_9]; + cvta.to.global.u64 %rd14, %rd13; + .loc 1 15 3 + add.u64 %rd1, %SPL, 0; + ld.param.u64 %rd16, [set_case_condition_param_1]; + ld.param.u64 %rd17, [set_case_condition_param_0]; + st.local.v2.u64 [%rd1], {%rd17, %rd16}; + ld.param.u64 %rd18, [set_case_condition_param_3]; + ld.param.u64 %rd19, [set_case_condition_param_2]; + st.local.v2.u64 [%rd1+16], {%rd19, %rd18}; + ld.param.u64 %rd20, [set_case_condition_param_5]; + ld.param.u64 %rd21, [set_case_condition_param_4]; + .loc 1 16 60 + st.local.v2.u64 [%rd1+32], {%rd21, %rd20}; + ld.param.u64 %rd22, [set_case_condition_param_7]; + ld.param.u64 %rd23, [set_case_condition_param_6]; + .loc 1 16 68 + st.local.v2.u64 [%rd1+48], {%rd23, %rd22}; + .loc 1 19 3 + ld.global.u32 %r19, [%rd14]; + .loc 1 20 3 + setp.lt.s32 %p1, %r19, 0; + setp.ge.s32 %p2, %r19, %r18; + or.pred %p3, %p1, %p2; + .loc 1 21 5 + add.s32 %r1, %r18, -1; + .loc 1 20 3 + selp.b32 %r2, %r1, %r19, %p3; + .loc 1 24 3 + setp.lt.s32 %p4, %r18, 1; + @%p4 bra $L__BB0_22; + + .loc 1 25 5 + and.b32 %r30, %r18, 3; + setp.lt.u32 %p5, %r1, 3; + mov.u32 %r28, 0; + @%p5 bra $L__BB0_16; + + sub.s32 %r27, %r18, %r30; + neg.s32 %r25, %r2; + mov.u32 %r28, 0; + mov.u64 %rd25, %rd1; + +$L__BB0_3: + .loc 1 0 0 + ld.local.u64 %rd4, [%rd25]; + .loc 1 25 5 + setp.eq.s32 %p6, %r25, 0; + @%p6 bra $L__BB0_5; + + .loc 1 28 7 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd4; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_6; + +$L__BB0_5: + .loc 1 26 7 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd4; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_6: + .loc 1 24 40 + add.s32 %r22, %r28, 1; + .loc 1 25 5 + setp.eq.s32 %p7, %r2, %r22; + .loc 1 0 0 + ld.local.u64 %rd5, [%rd25+8]; + .loc 1 25 5 + @%p7 bra $L__BB0_8; + bra.uni $L__BB0_7; + +$L__BB0_8: + .loc 1 26 7 + { // callseq 3, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd5; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 3 + bra.uni $L__BB0_9; + +$L__BB0_7: + .loc 1 28 7 + { // callseq 2, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd5; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 2 + +$L__BB0_9: + .loc 1 24 40 + add.s32 %r23, %r28, 2; + .loc 1 25 5 + setp.eq.s32 %p8, %r2, %r23; + .loc 1 0 0 + ld.local.u64 %rd6, [%rd25+16]; + .loc 1 25 5 + @%p8 bra $L__BB0_11; + bra.uni $L__BB0_10; + +$L__BB0_11: + .loc 1 26 7 + { // callseq 5, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd6; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 5 + bra.uni $L__BB0_12; + +$L__BB0_10: + .loc 1 28 7 + { // callseq 4, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd6; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 4 + +$L__BB0_12: + .loc 1 24 40 + add.s32 %r24, %r28, 3; + .loc 1 25 5 + setp.eq.s32 %p9, %r2, %r24; + .loc 1 0 0 + ld.local.u64 %rd7, [%rd25+24]; + .loc 1 25 5 + @%p9 bra $L__BB0_14; + bra.uni $L__BB0_13; + +$L__BB0_14: + .loc 1 26 7 + { // callseq 7, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd7; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 7 + bra.uni $L__BB0_15; + +$L__BB0_13: + .loc 1 28 7 + { // callseq 6, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd7; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 6 + +$L__BB0_15: + .loc 1 24 40 + add.s64 %rd25, %rd25, 32; + add.s32 %r28, %r28, 4; + .loc 1 24 3 + add.s32 %r25, %r25, 4; + add.s32 %r27, %r27, -4; + setp.ne.s32 %p10, %r27, 0; + @%p10 bra $L__BB0_3; + +$L__BB0_16: + .loc 1 25 5 + setp.eq.s32 %p11, %r30, 0; + @%p11 bra $L__BB0_22; + + mul.wide.s32 %rd24, %r28, 8; + add.s64 %rd26, %rd1, %rd24; + sub.s32 %r29, %r28, %r2; + +$L__BB0_18: + .pragma "nounroll"; + .loc 1 0 0 + ld.local.u64 %rd11, [%rd26]; + .loc 1 25 5 + setp.eq.s32 %p12, %r29, 0; + @%p12 bra $L__BB0_20; + + .loc 1 28 7 + { // callseq 8, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd11; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 8 + bra.uni $L__BB0_21; + +$L__BB0_20: + .loc 1 26 7 + { // callseq 9, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd11; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 9 + +$L__BB0_21: + .loc 1 24 3 + add.s32 %r30, %r30, -1; + add.s64 %rd26, %rd26, 8; + add.s32 %r29, %r29, 1; + setp.ne.s32 %p13, %r30, 0; + @%p13 bra $L__BB0_18; + +$L__BB0_22: + .loc 1 31 1 + ret; + +})"; +} + +// PTX kernel compiled from: +// +// __global__ void SetForCondition(cudaGraphConditionalHandle handle, +// int32_t* loop_index, +// int32_t num_iterations) { +// if (*loop_index < num_iterations) { +// cudaGraphSetConditional(handle, 1); +// } else { +// cudaGraphSetConditional(handle, 0); +// } +// *loop_index += 1; +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +std::string_view GetSetForConditionKernel() { + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_for_condition( + .param .u64 set_for_condition_param_0, + .param .u64 set_for_condition_param_1, + .param .u32 set_for_condition_param_2 +) +{ + .reg .pred %p<2>; + .reg .b32 %r<5>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd2, [set_for_condition_param_0]; + ld.param.u64 %rd3, [set_for_condition_param_1]; + ld.param.u32 %r1, [set_for_condition_param_2]; + .loc 1 3 3 + cvta.to.global.u64 %rd1, %rd3; + ld.global.u32 %r2, [%rd1]; + setp.lt.s32 %p1, %r2, %r1; + @%p1 bra $L__BB0_2; + bra.uni $L__BB0_1; + +$L__BB0_2: + .loc 1 4 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + bra.uni $L__BB0_3; + +$L__BB0_1: + .loc 1 6 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd2; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + +$L__BB0_3: + .loc 1 8 3 + ld.global.u32 %r3, [%rd1]; + add.s32 %r4, %r3, 1; + st.global.u32 [%rd1], %r4; + .loc 1 9 1 + ret; + +})"; +} + +std::string_view GetSetWhileConditionKernel() { + // While condition kernel is the same as an `If` with a single branch. + return R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.extern .func cudaGraphSetConditional +( + .param .b64 cudaGraphSetConditional_param_0, + .param .b32 cudaGraphSetConditional_param_1 +) + +.visible .entry set_while_condition( + .param .u64 set_while_condition_param_0, + .param .u64 set_while_condition_param_1 +) +{ + .reg .pred %p<2>; + .reg .b16 %rs<2>; + .reg .b64 %rd<4>; + .loc 1 1 0 + + ld.param.u64 %rd1, [set_while_condition_param_0]; + ld.param.u64 %rd2, [set_while_condition_param_1]; + .loc 1 3 3 + cvta.to.global.u64 %rd3, %rd2; + ld.global.u8 %rs1, [%rd3]; + setp.eq.s16 %p1, %rs1, 0; + @%p1 bra $L__BB0_2; + + .loc 1 4 5 + { // callseq 0, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 1; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 0 + bra.uni $L__BB0_3; + +$L__BB0_2: + .loc 1 6 5 + { // callseq 1, 0 + .reg .b32 temp_param_reg; + .param .b64 param0; + st.param.b64 [param0+0], %rd1; + .param .b32 param1; + st.param.b32 [param1+0], 0; + call.uni + cudaGraphSetConditional, + ( + param0, + param1 + ); + } // callseq 1 + +$L__BB0_3: + .loc 1 8 1 + ret; + +})"; +} + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc deleted file mode 100644 index 0ce8dbd4a4fa3..0000000000000 --- a/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "third_party/gpus/cuda/include/cuda.h" - -namespace stream_executor { -namespace cuda { -namespace { - -// In all kernels defined below we set conditional handle value to `1` when we -// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the -// graph will keep being executed until the conditional handle becomes `0`. - -#if defined(STREAM_EXECUTOR_CUDA_ENABLE_GRAPH_CONDITIONAL) && \ - CUDA_VERSION >= 12030 - -__global__ void SetIfCondition(cudaGraphConditionalHandle then_handle, - bool* predicate) { - if (*predicate) { - cudaGraphSetConditional(then_handle, 1); - } else { - cudaGraphSetConditional(then_handle, 0); - } -} - -__global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle, - cudaGraphConditionalHandle else_handle, - bool* predicate) { - if (*predicate) { - cudaGraphSetConditional(then_handle, 1); - cudaGraphSetConditional(else_handle, 0); - } else { - cudaGraphSetConditional(then_handle, 0); - cudaGraphSetConditional(else_handle, 1); - } -} - -__global__ void SetCaseCondition( - cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1, - cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3, - cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5, - cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7, - int32_t* index, int32_t num_handles) { - // Only handles in [0, num_handles) range are valid. - // - // We can't define a device function with dynamic number of handle arguments, - // so we always pass 8 handles, but only some of them are valid. Size 8 picked - // as a reasonable (but random) upper bound for what we see in XLA uses. - std::array handles = {h0, h1, h2, h3, - h4, h5, h6, h7}; - - // If branch index is out of range activate the last valid handle. - int32_t branch_index = *index; - if (branch_index < 0 || branch_index >= num_handles) { - branch_index = num_handles - 1; - } - - for (int32_t i = 0; i < num_handles; ++i) { - if (branch_index == i) { - cudaGraphSetConditional(handles[i], 1); - } else { - cudaGraphSetConditional(handles[i], 0); - } - } -} - -__global__ void SetForCondition(cudaGraphConditionalHandle handle, - int32_t* loop_index, int32_t num_iterations) { - if (*loop_index < num_iterations) { - cudaGraphSetConditional(handle, 1); - } else { - cudaGraphSetConditional(handle, 0); - } - *loop_index += 1; -} - -#else // CUDA graph conditionals are not available - -__global__ void SetIfCondition() {} -__global__ void SetIfElseCondition() {} -__global__ void SetCaseCondition() {} -__global__ void SetForCondition() {} - -#endif - -} // namespace -} // namespace cuda - -namespace gpu { - -void* GetSetIfConditionKernel() { - return reinterpret_cast(&cuda::SetIfCondition); -} - -void* GetSetIfElseConditionKernel() { - return reinterpret_cast(&cuda::SetIfElseCondition); -} - -void* GetSetCaseConditionKernel() { - return reinterpret_cast(&cuda::SetCaseCondition); -} - -void* GetSetForConditionKernel() { - return reinterpret_cast(&cuda::SetForCondition); -} - -void* GetSetWhileConditionKernel() { - // While condition kernel is the same as an `If` with a single branch. - return reinterpret_cast(&cuda::SetIfCondition); -} - -} // namespace gpu - -} // namespace stream_executor diff --git a/xla/stream_executor/cuda/cuda_diagnostics.cc b/xla/stream_executor/cuda/cuda_diagnostics.cc index e7bbd32f13487..561ac0d401e2f 100644 --- a/xla/stream_executor/cuda/cuda_diagnostics.cc +++ b/xla/stream_executor/cuda/cuda_diagnostics.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,25 +24,28 @@ limitations under the License. #include #include #include + #if !defined(PLATFORM_WINDOWS) #include #include #include #endif + #include -#include -#include + +#include #include #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace stream_executor { namespace cuda { @@ -52,7 +55,7 @@ std::string DriverVersionToString(DriverVersion version) { std::get<2>(version)); } -std::string DriverVersionStatusToString(tsl::StatusOr version) { +std::string DriverVersionStatusToString(absl::StatusOr version) { if (!version.ok()) { return version.status().ToString(); } @@ -60,37 +63,32 @@ std::string DriverVersionStatusToString(tsl::StatusOr version) { return DriverVersionToString(version.value()); } -tsl::StatusOr StringToDriverVersion(const std::string &value) { +absl::StatusOr StringToDriverVersion(const std::string &value) { std::vector pieces = absl::StrSplit(value, '.'); if (pieces.size() < 2 || pieces.size() > 4) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "expected %%d.%%d, %%d.%%d.%%d, or %%d.%%d.%%d.%%d form " - "for driver version; got \"%s\"", - value.c_str())); + return absl::InvalidArgumentError(absl::StrFormat( + "expected %%d.%%d, %%d.%%d.%%d, or %%d.%%d.%%d.%%d form " + "for driver version; got \"%s\"", + value.c_str())); } int major; int minor; int patch = 0; if (!absl::SimpleAtoi(pieces[0], &major)) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( absl::StrFormat("could not parse major version number \"%s\" as an " "integer from string \"%s\"", pieces[0], value)); } if (!absl::SimpleAtoi(pieces[1], &minor)) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( absl::StrFormat("could not parse minor version number \"%s\" as an " "integer from string \"%s\"", pieces[1].c_str(), value.c_str())); } if (pieces.size() == 3 && !absl::SimpleAtoi(pieces[2], &patch)) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( absl::StrFormat("could not parse patch version number \"%s\" as an " "integer from string \"%s\"", pieces[2], value)); @@ -164,11 +162,11 @@ void Diagnostician::LogDiagnosticInformation() { closedir(dir); } } - tsl::StatusOr dso_version = FindDsoVersion(); + absl::StatusOr dso_version = FindDsoVersion(); LOG(INFO) << "libcuda reported version is: " << cuda::DriverVersionStatusToString(dso_version); - tsl::StatusOr kernel_version = FindKernelDriverVersion(); + absl::StatusOr kernel_version = FindKernelDriverVersion(); LOG(INFO) << "kernel reported version is: " << cuda::DriverVersionStatusToString(kernel_version); #endif @@ -182,9 +180,8 @@ void Diagnostician::LogDiagnosticInformation() { // Iterates through loaded DSOs with DlIteratePhdrCallback to find the // driver-interfacing DSO version number. Returns it as a string. -tsl::StatusOr Diagnostician::FindDsoVersion() { - tsl::StatusOr result(tsl::Status( - absl::StatusCode::kNotFound, +absl::StatusOr Diagnostician::FindDsoVersion() { + absl::StatusOr result(absl::NotFoundError( "was unable to find libcuda.so DSO loaded into this program")); #if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA) @@ -211,7 +208,7 @@ tsl::StatusOr Diagnostician::FindDsoVersion() { std::string dso_version = dot + strlen(so_suffix); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_dso_version = absl::StripSuffix(dso_version, ".ld64"); - auto result = static_cast *>(data); + auto result = static_cast *>(data); *result = cuda::StringToDriverVersion(std::string(stripped_dso_version)); return 1; } @@ -224,13 +221,12 @@ tsl::StatusOr Diagnostician::FindDsoVersion() { return result; } -tsl::StatusOr Diagnostician::FindKernelModuleVersion( +absl::StatusOr Diagnostician::FindKernelModuleVersion( const std::string &driver_version_file_contents) { static const char *kDriverFilePrelude = "Kernel Module "; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); if (offset == std::string::npos) { - return tsl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( absl::StrCat("could not find kernel module information in " "driver version file contents: \"", driver_version_file_contents, "\"")); @@ -246,8 +242,8 @@ tsl::StatusOr Diagnostician::FindKernelModuleVersion( } void Diagnostician::WarnOnDsoKernelMismatch( - tsl::StatusOr dso_version, - tsl::StatusOr kernel_version) { + absl::StatusOr dso_version, + absl::StatusOr kernel_version) { if (kernel_version.ok() && dso_version.ok() && dso_version.value() == kernel_version.value()) { LOG(INFO) << "kernel version seems to match DSO: " @@ -261,11 +257,10 @@ void Diagnostician::WarnOnDsoKernelMismatch( } } -tsl::StatusOr Diagnostician::FindKernelDriverVersion() { +absl::StatusOr Diagnostician::FindKernelDriverVersion() { FILE *driver_version_file = fopen(kDriverVersionPath, "r"); if (driver_version_file == nullptr) { - return tsl::Status( - absl::StatusCode::kPermissionDenied, + return absl::PermissionDeniedError( absl::StrCat("could not open driver version path for reading: ", kDriverVersionPath)); } @@ -286,11 +281,9 @@ tsl::StatusOr Diagnostician::FindKernelDriverVersion() { return FindKernelModuleVersion(contents.begin()); } - auto status = tsl::Status( - absl::StatusCode::kInternal, - absl::StrCat( - "failed to read driver version file contents: ", kDriverVersionPath, - "; ferror: ", ferror(driver_version_file))); + auto status = absl::InternalError(absl::StrCat( + "failed to read driver version file contents: ", kDriverVersionPath, + "; ferror: ", ferror(driver_version_file))); fclose(driver_version_file); return status; } diff --git a/xla/stream_executor/cuda/cuda_diagnostics.h b/xla/stream_executor/cuda/cuda_diagnostics.h index 46d1a001051c6..ea1fa0cfc51a7 100644 --- a/xla/stream_executor/cuda/cuda_diagnostics.h +++ b/xla/stream_executor/cuda/cuda_diagnostics.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ +#include + +#include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" namespace stream_executor { @@ -28,10 +31,10 @@ using DriverVersion = gpu::DriverVersion; std::string DriverVersionToString(DriverVersion version); // Converts a parsed driver version or status value to natural string form. -std::string DriverVersionStatusToString(tsl::StatusOr version); +std::string DriverVersionStatusToString(absl::StatusOr version); // Converts a string of a form like "331.79" to a DriverVersion{331, 79}. -tsl::StatusOr StringToDriverVersion(const std::string& value); +absl::StatusOr StringToDriverVersion(const std::string& value); using Diagnostician = gpu::Diagnostician; diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 41a5440cc46f7..2e1e02aa1f2c1 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,34 +17,50 @@ limitations under the License. #include #include +#include #include #include #include #include #include #include +#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/memory/memory.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" #include "xla/stream_executor/cuda/cuda_activation.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" -#include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/stream_executor/data_type.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" @@ -53,21 +69,49 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/util/env_var.h" +#include "tsl/protobuf/dnn.pb.h" // clang-format off #include "third_party/gpus/cuda/include/library_types.h" -#include "third_party/gpus/cudnn/cudnn.h" #include "third_party/gpus/cudnn/cudnn_version.h" -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND + +#if CUDNN_VERSION >= 90000 +#include "third_party/gpus/cudnn/cudnn_adv.h" +#include "third_party/gpus/cudnn/cudnn_cnn.h" +#include "third_party/gpus/cudnn/cudnn_ops.h" +#elif CUDNN_VERSION >= 8100 +#include "third_party/gpus/cudnn/cudnn_adv_infer.h" +#include "third_party/gpus/cudnn/cudnn_adv_train.h" +#include "third_party/gpus/cudnn/cudnn_cnn_infer.h" +#include "third_party/gpus/cudnn/cudnn_cnn_train.h" +#include "third_party/gpus/cudnn/cudnn_ops_infer.h" +#include "third_party/gpus/cudnn/cudnn_ops_train.h" +#endif + +#include "third_party/gpus/cudnn/cudnn_backend.h" + +#if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_utils.h" -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND -#include "absl/strings/string_view.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_EngineConfig.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Errata.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_ExecutionPlan.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Filters.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Heuristics.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_MatMulDesc.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Operation.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_OperationGraph.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_PointWiseDesc.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Rng.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_Tensor.h" +#include "third_party/cudnn_frontend/include/cudnn_frontend_VariantPack.h" +#endif // CUDNN_VERSION >= 8100 // clang-format on #ifdef __clang__ @@ -89,7 +133,7 @@ static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS) // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current -// function with a non-successful tsl::Status. +// function with a non-successful absl::Status. #define RETURN_IF_CUDNN_ERROR(expr) \ do { \ cudnnStatus_t _status = (expr); \ @@ -97,7 +141,7 @@ static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); std::ostringstream oss; \ oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \ << __LINE__ << "): '" << #expr << "'"; \ - return tsl::Status(absl::StatusCode::kUnknown, oss.str()); \ + return absl::UnknownError(oss.str()); \ } \ } while (false) @@ -108,7 +152,7 @@ static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher"); std::ostringstream oss; \ oss << CudnnStatusToString(_status) << "\nin " << __FILE__ << "(" \ << __LINE__ << "): '" << #expr << "' " << (expr).get_error(); \ - return tsl::Status(absl::StatusCode::kUnknown, oss.str()); \ + return absl::UnknownError(oss.str()); \ } \ } while (false) @@ -185,6 +229,19 @@ class CudnnHandle { cudnnHandle_t handle_; // Not owned. }; +// RAII wrapper for temporary cuDNN handles that are used for multithreaded +// compilation. Unlike with CudnnAccess these are not associated +// with GPU devices and are not locked. +class LocalCuDnnHandle { + public: + explicit LocalCuDnnHandle(cudnnHandle_t handle) : handle_(handle) {} + ~LocalCuDnnHandle() { cudnnDestroy(handle_); } + cudnnHandle_t handle() { return handle_; } + + private: + cudnnHandle_t handle_; +}; + // Major version is neither forward or backward compatible and therefore major // versions needs to match between source and library. // @@ -244,6 +301,14 @@ class CudnnAccess { return CudnnHandle(executor, std::move(lock), handle_); } + absl::StatusOr> GetLocalHandle() { + cudnnHandle_t handle = nullptr; + if (cudnnCreate(&handle) != CUDNN_STATUS_SUCCESS) { + return absl::InternalError("Creation of local cudnn handle failed."); + } + return std::make_unique(handle); + } + void NotifyStreamDestroyed(Stream* stream) { CUstream cu_stream = AsGpuStreamValue(stream); absl::MutexLock lock(&mutex_); @@ -330,7 +395,7 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( } } -tsl::StatusOr GetCudnnProperty(libraryPropertyType type) { +absl::StatusOr GetCudnnProperty(libraryPropertyType type) { int value; RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value)); return value; @@ -351,7 +416,7 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(std::optional algorithm) { } } -tsl::StatusOr GetLoadedCudnnVersion() { +absl::StatusOr GetLoadedCudnnVersion() { TF_ASSIGN_OR_RETURN(int major, GetCudnnProperty(MAJOR_VERSION)); TF_ASSIGN_OR_RETURN(int minor, GetCudnnProperty(MINOR_VERSION)); TF_ASSIGN_OR_RETURN(int patch_level, GetCudnnProperty(PATCH_LEVEL)); @@ -363,28 +428,38 @@ enum class PreloadCudnnType { ConvFwd, ConvBwdFilter, ConvBwdData, Rnn }; // Preload sub libs for cudnn 8.0.4+ to make sure that the loading time isn't // measured in the autotuning. void PreloadCudnnSubLibs(PreloadCudnnType type) { -#if CUDNN_VERSION >= 8004 switch (type) { case PreloadCudnnType::ConvBwdFilter: case PreloadCudnnType::ConvBwdData: { +#if CUDNN_VERSION >= 8004 && CUDNN_VERSION < 90000 cudnnOpsTrainVersionCheck(); cudnnCnnTrainVersionCheck(); +#endif // CUDNN_VERSION >= 8004 && CUDNN_VERSION < 90000 [[clang::fallthrough]]; } case PreloadCudnnType::ConvFwd: { +#if CUDNN_VERSION >= 90000 + cudnnGraphVersionCheck(); + cudnnOpsVersionCheck(); +#elif CUDNN_VERSION >= 8004 cudnnOpsInferVersionCheck(); cudnnCnnInferVersionCheck(); +#endif // CUDNN_VERSION >= 90000 break; } case PreloadCudnnType::Rnn: { +#if CUDNN_VERSION >= 90000 + cudnnOpsVersionCheck(); + cudnnAdvVersionCheck(); +#elif CUDNN_VERSION >= 8004 cudnnOpsInferVersionCheck(); cudnnAdvInferVersionCheck(); cudnnOpsTrainVersionCheck(); cudnnAdvTrainVersionCheck(); +#endif // CUDNN_VERSION >= 90000 break; } } -#endif // CUDNN_VERSION >= 8004 } void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) { @@ -414,7 +489,7 @@ void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) { CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {} -tsl::Status CudnnSupport::Init() { +absl::Status CudnnSupport::Init() { ScopedActivateExecutorContext context(parent_); // Peek at the last error to give more information in cases of errors. @@ -427,7 +502,7 @@ tsl::Status CudnnSupport::Init() { cuda_error, "): ", cudaGetErrorName(cuda_error), " : ", cudaGetErrorString(cuda_error)); LOG(ERROR) << error; - return tsl::Status(absl::StatusCode::kInternal, error); + return absl::InternalError(error); } cudnnHandle_t cudnn_handle = nullptr; @@ -448,13 +523,13 @@ tsl::Status CudnnSupport::Init() { "configuration."); LOG(ERROR) << error; cudnnDestroy(cudnn_handle); - return tsl::Status(absl::StatusCode::kInternal, error); + return absl::InternalError(error); } cudnn_ = std::make_unique(cudnn_handle); LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion(); - return ::tsl::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(cudnn_handle, nullptr); @@ -477,16 +552,16 @@ tsl::Status CudnnSupport::Init() { } } - return tsl::Status(absl::StatusCode::kInternal, - absl::StrCat("cudnn library could not create a handle: ", - CudnnStatusToString(status))); + return absl::InternalError( + absl::StrCat("cudnn library could not create a handle: ", + CudnnStatusToString(status))); } void CudnnSupport::NotifyStreamDestroyed(Stream* stream) /* override */ { cudnn_->NotifyStreamDestroyed(stream); } -tsl::StatusOr CudnnSupport::GetVersion() { +absl::StatusOr CudnnSupport::GetVersion() { return GetLoadedCudnnVersion(); } @@ -539,11 +614,13 @@ struct RnnDescriptorDeleter { CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor)); } }; +#if CUDNN_VERSION < 8100 struct PersistentRnnPlanDeleter { void operator()(cudnnPersistentRNNPlan_t plan) const { CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan)); } }; +#endif // CUDNN_VERSION < 8100 #if CUDNN_VERSION >= 7603 struct CtcLossDescriptorDeleter { void operator()(cudnnCTCLossDescriptor_t descriptor) const { @@ -569,8 +646,13 @@ using ActivationDescriptor = using DropoutDescriptor = std::unique_ptr; using RnnDescriptor = std::unique_ptr; +#if CUDNN_VERSION >= 8100 +struct DummyType {}; +using PersistentRnnPlan = std::unique_ptr; +#else using PersistentRnnPlan = std::unique_ptr; +#endif // CUDNN_VERSION >= 8100 #if CUDNN_VERSION >= 7603 using CtcLossDescriptor = std::unique_ptr; @@ -630,13 +712,15 @@ CtcLossDescriptor CreateCtcLossDescriptor() { } #endif -tsl::StatusOr CreatePersistentRnnPlan( +#if CUDNN_VERSION < 8100 +absl::StatusOr CreatePersistentRnnPlan( cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) { cudnnPersistentRNNPlan_t result; RETURN_IF_CUDNN_ERROR( cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result)); - return tsl::StatusOr(PersistentRnnPlan(result)); + return absl::StatusOr(PersistentRnnPlan(result)); } +#endif // CUDNN_VERSION < 8100 // Turns a BatchDescriptor structure into a cudnn tensor handle within a // scope. @@ -747,7 +831,7 @@ class CudnnFilterDescriptor { FilterDescriptor handle_; // Owned. }; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 // The errata sheet (JSON format) for marking the cudnn engines that might be // buggy. For example, we don't want the engine 999 of forward convolution: // R"({ "version" : 1, @@ -843,7 +927,7 @@ const json* CudnnExecutionPlanEngineFilterRuntime() { return json_handle; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 // A helper function to decide whether to use // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in @@ -1123,56 +1207,79 @@ class CudnnActivationDescriptor { ActivationDescriptor handle_; // Owned. }; -cudnnDataType_t ToCudnnDataType( +cudnn_frontend::DataType_t ToCudnnFrontendDataType( dnn::DataType data_type, dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { switch (data_type) { case dnn::DataType::kFloat: - return CUDNN_DATA_FLOAT; + return cudnn_frontend::DataType_t::FLOAT; case dnn::DataType::kDouble: - return CUDNN_DATA_DOUBLE; + return cudnn_frontend::DataType_t::DOUBLE; case dnn::DataType::kHalf: - return CUDNN_DATA_HALF; + return cudnn_frontend::DataType_t::HALF; case dnn::DataType::kInt8: switch (data_layout) { case dnn::DataLayout::kBatchDepthYX4: - return CUDNN_DATA_INT8x4; + return cudnn_frontend::DataType_t::INT8x4; case dnn::DataLayout::kBatchDepthYX32: - return CUDNN_DATA_INT8x32; + return cudnn_frontend::DataType_t::INT8x32; default: - return CUDNN_DATA_INT8; + return cudnn_frontend::DataType_t::INT8; } case dnn::DataType::kInt32: - return CUDNN_DATA_INT32; + return cudnn_frontend::DataType_t::INT32; case dnn::DataType::kInt64: - return CUDNN_DATA_INT64; + return cudnn_frontend::DataType_t::INT64; #if CUDNN_VERSION >= 8200 case dnn::DataType::kBF16: - return CUDNN_DATA_BFLOAT16; + return cudnn_frontend::DataType_t::BFLOAT16; #endif #if CUDNN_VERSION >= 8900 case dnn::DataType::kF8E4M3FN: - return CUDNN_DATA_FP8_E4M3; + return cudnn_frontend::DataType_t::FP8_E4M3; case dnn::DataType::kF8E5M2: - return CUDNN_DATA_FP8_E5M2; + return cudnn_frontend::DataType_t::FP8_E5M2; #endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } } -cudnnDataType_t ToCudnnDataType(dnn::DataType data_type, - dnn::FilterLayout filter_layout) { +cudnnDataType_t ToCudnnDataType( + dnn::DataType data_type, + dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { + cudnnDataType_t type; + CHECK_CUDNN_OK(cudnn_frontend::detail::convert_to_cudnn_type( + ToCudnnFrontendDataType(data_type, data_layout), type)); + return type; +} + +cudnn_frontend::DataType_t ToCudnnFrontendDataType( + dnn::DataType data_type, dnn::FilterLayout filter_layout) { if (data_type == dnn::DataType::kInt8 && filter_layout == dnn::FilterLayout::kOutputInputYX4) { - return CUDNN_DATA_INT8x4; + return cudnn_frontend::DataType_t::INT8x4; } if (data_type == dnn::DataType::kInt8 && (filter_layout == dnn::FilterLayout::kOutputInputYX32 || filter_layout == dnn::FilterLayout::kOutputInputYX32_CudnnReordered)) { - return CUDNN_DATA_INT8x32; + return cudnn_frontend::DataType_t::INT8x32; } - return ToCudnnDataType(data_type); + return ToCudnnFrontendDataType(data_type); +} + +cudnnDataType_t ToCudnnDataType(dnn::DataType data_type, + dnn::FilterLayout filter_layout) { + cudnnDataType_t type; + CHECK_CUDNN_OK(cudnn_frontend::detail::convert_to_cudnn_type( + ToCudnnFrontendDataType(data_type, filter_layout), type)); + return type; +} + +template +cudnn_frontend::DataType_t GetCudnnFrontendDataType( + dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { + return ToCudnnFrontendDataType(dnn::ToDataType::value, data_layout); } template @@ -1181,6 +1288,12 @@ cudnnDataType_t GetCudnnDataType( return ToCudnnDataType(dnn::ToDataType::value, data_layout); } +template +cudnn_frontend::DataType_t GetCudnnFrontendDataType( + dnn::FilterLayout filter_layout) { + return ToCudnnFrontendDataType(dnn::ToDataType::value, filter_layout); +} + template cudnnDataType_t GetCudnnDataType(dnn::FilterLayout filter_layout) { return ToCudnnDataType(dnn::ToDataType::value, filter_layout); @@ -1242,7 +1355,7 @@ class CudnnDropoutDescriptor { public: CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default; - static tsl::StatusOr Create( + static absl::StatusOr Create( const CudnnHandle& cudnn, float dropout, uint64_t seed, ScratchAllocator* state_allocator) { DropoutDescriptor handle = CreateDropoutDescriptor(); @@ -1289,7 +1402,7 @@ class CudnnRnnParamsDescriptor { public: CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default; - static tsl::StatusOr Create( + static absl::StatusOr Create( const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, cudnnDirectionMode_t direction_mode, int num_layers); @@ -1341,7 +1454,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { public: CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default; - static tsl::StatusOr Create( + static absl::StatusOr Create( const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, @@ -1377,8 +1490,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { ? algorithm_config.algorithm()->tensor_ops_enabled() : allow_tensor_ops; if (use_tensor_ops && !allow_tensor_ops) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Algo requests disallowed tensor op evaluation."); + return absl::InvalidArgumentError( + "Algo requests disallowed tensor op evaluation."); } #if CUDNN_VERSION >= 8000 @@ -1428,19 +1541,24 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { } #endif - tsl::StatusOr rnn_plan_wrapper; + absl::StatusOr rnn_plan_wrapper; PersistentRnnPlan rnn_plan; if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { CHECK_GE(batch_size, 0); +#if CUDNN_VERSION >= 8100 + RETURN_IF_CUDNN_ERROR( + cudnnBuildRNNDynamic(cudnn.handle(), rnn_desc.get(), batch_size)); +#else rnn_plan_wrapper = CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type); if (!rnn_plan_wrapper.ok()) { - return tsl::StatusOr(rnn_plan_wrapper.status()); + return absl::StatusOr(rnn_plan_wrapper.status()); } else { rnn_plan = std::move(rnn_plan_wrapper).value(); RETURN_IF_CUDNN_ERROR( cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get())); } +#endif // CUDNN_VERSION >= 8100 } // Create the params handle. @@ -1544,10 +1662,10 @@ namespace { // Check if the LSTM projection is used. If yes, an additional weight matrix // (projection matrix) will be fetched to the 'weights'. Otherwise, nothing will // be done. -tsl::Status CheckAndFetchProjectionWeights( +absl::Status CheckAndFetchProjectionWeights( const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer, const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc, - const FilterDescriptor& region_desc_handle, + int64_t params_size_in_bytes, const FilterDescriptor& region_desc_handle, dnn::RnnDescriptor::ParamsRegions* weights) { int hidden_size_v; int num_layers_v; @@ -1557,17 +1675,24 @@ tsl::Status CheckAndFetchProjectionWeights( cudnnRNNMode_t mode; cudnnRNNAlgo_t algo; cudnnDataType_t data_type; -#if CUDNN_VERSION >= 8000 - RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + int rec_proj_size_v; +#if CUDNN_VERSION >= 8100 + RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v8( + /*rnnDesc=*/rnn_desc, + /*algo=*/&algo, + /*cellMode=*/&mode, + /*biasMode=*/nullptr, + /*dirMode=*/&direction, + /*inputMode=*/&input_mode, + /*dataType=*/nullptr, + /*mathPrec=*/&data_type, + /*mathType=*/nullptr, + /*inputSize=*/nullptr, /*hiddenSize=*/&hidden_size_v, + /*projSize=*/&rec_proj_size_v, /*numLayers=*/&num_layers_v, /*dropoutDesc=*/&dropout_desc, - /*inputMode=*/&input_mode, - /*direction=*/&direction, - /*mode=*/&mode, - /*algo=*/&algo, - /*mathPrec=*/&data_type)); + /*auxFlags=*/nullptr)); #else RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, @@ -1579,17 +1704,48 @@ tsl::Status CheckAndFetchProjectionWeights( /*mode=*/&mode, /*algo=*/&algo, /*mathPrec=*/&data_type)); -#endif - int rec_proj_size_v; int out_proj_size_v; RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, /*recProjSize*/ &rec_proj_size_v, /*outProjSize*/ &out_proj_size_v)); +#endif // CUDNN_VERSION >= 8100 if (rec_proj_size_v != hidden_size_v) { - void* offset = nullptr; int region_id = 8; +#if CUDNN_VERSION >= 8100 + void* b_ptr = nullptr; + void* m_ptr = nullptr; + void* w_ptr = nullptr; + TensorDescriptor m_region_desc_handle = CreateTensorDescriptor(); + TensorDescriptor b_region_desc_handle = CreateTensorDescriptor(); + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightParams( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc, + /*pseudoLayer=*/layer, + /*weightSpaceSize=*/params_size_in_bytes, + /*weightSpace=*/w_ptr, + /*linLayerID=*/region_id, + /*mDesc=*/m_region_desc_handle.get(), + /*mAddr=*/&m_ptr, + /*bDesc=*/b_region_desc_handle.get(), + /*bAddr=*/&b_ptr)); + int dims[] = {1, 1, 1}; + int strides[] = {1, 1, 1}; + cudnnDataType_t data_type; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetTensorNdDescriptor( + /*tensorDesc=*/m_region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, + /*nbDims=*/&n_dims, + /*dimA=*/dims, + /*strideA*/ strides)); + int64_t size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + int64_t offset = static_cast(m_ptr) - static_cast(w_ptr); +#else + void* offset = nullptr; RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, /*layer=*/layer, /*xDesc=*/input_desc.get(), @@ -1608,14 +1764,15 @@ tsl::Status CheckAndFetchProjectionWeights( /*nbDims=*/&n_dims, /*filterDimA=*/dims)); int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); +#endif // CUDNN_VERSION >= 8100 dnn::RnnDescriptor::ParamsRegion region = { reinterpret_cast(offset), size}; weights->push_back(region); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr CudnnRnnParamsDescriptor::Create( +absl::StatusOr CudnnRnnParamsDescriptor::Create( const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, cudnnDirectionMode_t direction_mode, int num_layers) { @@ -1630,10 +1787,16 @@ tsl::StatusOr CudnnRnnParamsDescriptor::Create( /*strideA=*/strides)); size_t params_size = 0; +#if CUDNN_VERSION >= 8100 + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*weightSpaceSize=*/¶ms_size)); +#else RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size, /*dataType=*/data_type)); +#endif // CUDNN_VERSION >= 8100 int64_t params_size_in_bytes = static_cast(params_size); FilterDescriptor filter_desc = CreateFilterDescriptor(); @@ -1671,6 +1834,51 @@ tsl::StatusOr CudnnRnnParamsDescriptor::Create( for (int layer = 0; layer < layer_count; layer++) { for (int region = 0; region < region_count_per_layer; region++) { +#if CUDNN_VERSION >= 8100 + void* m_ptr = nullptr; + void* b_ptr = nullptr; + void* w_ptr = nullptr; + TensorDescriptor m_region_desc_handle = CreateTensorDescriptor(); + TensorDescriptor b_region_desc_handle = CreateTensorDescriptor(); + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightParams( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc, + /*pseudoLayer=*/layer, + /*weightsSize=*/params_size_in_bytes, + /*weights=*/&w_ptr, + /*linID=*/region, + /*mDesc=*/m_region_desc_handle.get(), + /*mAddr=*/&m_ptr, + /*bDesc=*/b_region_desc_handle.get(), + /*bAddr=*/&b_ptr)); + + int dims[] = {1, 1, 1}; + int strides[] = {1, 1, 1}; + cudnnDataType_t data_type; + int n_dims; + auto get_size = + [&](const TensorDescriptor& tensor_desc) -> absl::StatusOr { + RETURN_IF_CUDNN_ERROR(cudnnGetTensorNdDescriptor( + /*tensorDesc=*/m_region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, + /*nbDims=*/&n_dims, + /*dimA=*/dims, + /*strideA*/ strides)); + int64_t size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + return size; + }; + TF_ASSIGN_OR_RETURN(int64_t m_size, get_size(m_region_desc_handle)); + int64_t m_offset = static_cast(m_ptr) - static_cast(w_ptr); + dnn::RnnDescriptor::ParamsRegion m_region = {m_offset, m_size}; + weights.push_back(m_region); + + TF_ASSIGN_OR_RETURN(int64_t b_size, get_size(b_region_desc_handle)); + int64_t b_offset = static_cast(b_ptr) - static_cast(w_ptr); + dnn::RnnDescriptor::ParamsRegion b_region = {b_offset, b_size}; + biases.push_back(b_region); +#else for (int type = 0; type < 2; type++) { void* offset = nullptr; RETURN_IF_CUDNN_ERROR( @@ -1703,10 +1911,11 @@ tsl::StatusOr CudnnRnnParamsDescriptor::Create( reinterpret_cast(offset), size}; (type == 0 ? weights : biases).push_back(region); } +#endif // CUDNN_VERSION >= 8100 } TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights( - cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle, - &weights)); + cudnn, rnn_desc, layer, input_desc, filter_desc, params_size_in_bytes, + region_desc_handle, &weights)); } return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, @@ -1732,12 +1941,11 @@ class CudnnRnnSequenceTensorDescriptor CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) = default; - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { if (max_seq_length <= 0) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "max_seq_length <= 0"); + return absl::InvalidArgumentError("max_seq_length <= 0"); } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; @@ -1751,13 +1959,12 @@ class CudnnRnnSequenceTensorDescriptor std::move(tensor_desc)); } - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, absl::Span seq_lengths, bool time_major, cudnnDataType_t data_type) { if (max_seq_length <= 0) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "max_seq_length <= 0"); + return absl::InvalidArgumentError("max_seq_length <= 0"); } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; @@ -1854,7 +2061,7 @@ struct RnnModelDims { }; template -tsl::StatusOr ExtractAndCheckRnnForward( +absl::StatusOr ExtractAndCheckRnnForward( const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1884,45 +2091,40 @@ tsl::StatusOr ExtractAndCheckRnnForward( model_dims.num_layers * model_dims.dir_count && input_h_desc.batch_size() == model_dims.batch_size && input_h_desc.data_size() == model_dims.hidden_size)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Invalid input_h shape"); + return absl::InvalidArgumentError("Invalid input_h shape"); } // The LSTM projection will be used if input_h_desc.data_size() < // input_c_desc.data_size() if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && input_h_desc.data_size() <= input_c_desc.data_size())) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Invalid input_c shape"); + return absl::InvalidArgumentError("Invalid input_c shape"); } if (!(output_desc.max_seq_length() == model_dims.max_seq_length && output_desc.batch_size() == model_dims.batch_size && output_desc.data_size() == model_dims.hidden_size * model_dims.dir_count)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Invalid output shape"); + return absl::InvalidArgumentError("Invalid output shape"); } if (!(input_h_desc.num_layers() == output_h_desc.num_layers() && input_h_desc.batch_size() == output_h_desc.batch_size() && input_h_desc.data_size() == output_h_desc.data_size())) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Invalid output_h shape"); + return absl::InvalidArgumentError("Invalid output_h shape"); } if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && input_h_desc.batch_size() == output_c_desc.batch_size() && input_h_desc.data_size() <= output_c_desc.data_size())) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Invalid output_c shape"); + return absl::InvalidArgumentError("Invalid output_c shape"); } return model_dims; } -tsl::Status CheckRNNParameterSize( +absl::Status CheckRNNParameterSize( const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*sizeInBytes=*/¶ms_size_in_bytes)); @@ -1934,32 +2136,71 @@ tsl::Status CheckRNNParameterSize( #endif if (static_cast(params_size_in_bytes) != rnn_desc.ParamsSizeInBytes()) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Mismatching RNN parameter size"); + return absl::InvalidArgumentError("Mismatching RNN parameter size"); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr> CreateRnnWorkspace( +absl::Status CreateRnnTempSpace( Stream* stream, const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnDescriptor& rnn_desc, RnnModelDims model_dims, const CudnnRnnSequenceTensorDescriptor& input_desc, - ScratchAllocator* workspace_allocator) { - // Query the workspace size. + ScratchAllocator* workspace_allocator, + ScratchAllocator* reserve_space_allocator, bool is_fwd_training, + DeviceMemory* workspace, DeviceMemory* reserve_space) { + size_t reserve_space_size_in_bytes = 0; size_t workspace_size_in_bytes = 0; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&workspace_size_in_bytes)); - // Allocate the workspace. - if (workspace_size_in_bytes == 0) { - return DeviceMemory(); + if (input_desc.is_var_seq_lengths()) { +#if CUDNN_VERSION >= 8100 + auto rnn_fwd_mode = + is_fwd_training ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*fMode=*/rnn_fwd_mode, + /*xDesc=*/input_desc.data_handle(), + /*workSpaceSize=*/&workspace_size_in_bytes, + /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); +#else + return tsl::errors::Internal( + "Sequence lengths for RNN are supported from CUDNN 8.1+"); +#endif // CUDNN_VERSION >= 8100 + } else { +#if CUDNN_VERSION >= 90000 + return tsl::errors::Internal( + "Sequence lengths for RNN are required from CUDNN 9.0+"); +#else + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*seqLength=*/input_desc.max_seq_length(), + /*xDesc=*/input_desc.handles(), + /*sizeInBytes=*/&workspace_size_in_bytes)); + if (is_fwd_training) { + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*seqLength=*/model_dims.max_seq_length, + /*xDesc=*/input_desc.handles(), + /*sizeInBytes=*/&reserve_space_size_in_bytes)); + } +#endif // CUDNN_VERSION >= 90000 } - return workspace_allocator->AllocateBytes(workspace_size_in_bytes); + + if (workspace_size_in_bytes > 0) { + TF_ASSIGN_OR_RETURN(*workspace, workspace_allocator->AllocateBytes( + workspace_size_in_bytes)); + } + if (reserve_space_allocator != nullptr && is_fwd_training && + reserve_space_size_in_bytes > 0) { + TF_ASSIGN_OR_RETURN(*reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); + } + return absl::OkStatus(); } #if CUDNN_VERSION >= 7402 -tsl::StatusOr> CreateBatchNormForwardWorkspace( +absl::StatusOr> CreateBatchNormForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode, const cudnnBatchNormOps_t& bn_ops, const cudnnActivationDescriptor_t& activation_desc, @@ -1983,7 +2224,7 @@ tsl::StatusOr> CreateBatchNormForwardWorkspace( return workspace_allocator->AllocateBytes(workspace_size_in_bytes); } -tsl::StatusOr> CreateBatchNormBackwardWorkspace( +absl::StatusOr> CreateBatchNormBackwardWorkspace( Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode, const cudnnBatchNormOps_t& bn_ops, const cudnnActivationDescriptor_t& activation_desc, @@ -2014,7 +2255,7 @@ tsl::StatusOr> CreateBatchNormBackwardWorkspace( } // namespace // Populates the profile result if not empty. -static tsl::Status PopulateProfileFromTimer( +static absl::Status PopulateProfileFromTimer( std::optional& timer, const dnn::AlgorithmDesc& algorithm, dnn::ProfileResult* profile_result, std::optional scratch_size = std::nullopt) { @@ -2027,11 +2268,11 @@ static tsl::Status PopulateProfileFromTimer( profile_result->set_scratch_size(*scratch_size); } } - return tsl::OkStatus(); + return absl::OkStatus(); } template -tsl::Status CudnnSupport::DoRnnForwardImpl( +absl::Status CudnnSupport::DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -2060,42 +2301,27 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); - // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been - // deprecated. Instead, we use the cudnnRNNForward which requires the - // sequence_lengths parameter. For more info, - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802. -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND - if (input_desc.is_var_seq_lengths()) { - DeviceMemory workspace; - DeviceMemory reserve_space; - cudnnForwardMode_t rnn_fwd_mode; - if (is_training) { - rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING; - } else { - rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE; - } - size_t reserve_space_size_in_bytes = 0; - size_t workspace_size_in_bytes = 0; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(), - /*workSpaceSize=*/&workspace_size_in_bytes, - /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); - - if (workspace_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( - workspace_size_in_bytes)); - } - if (reserve_space_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( - reserve_space_size_in_bytes)); - } + DeviceMemory reserve_space; + DeviceMemory workspace; + TF_RETURN_IF_ERROR(CreateRnnTempSpace( + stream, cudnn, rnn_desc, model_dims, input_desc, workspace_allocator, + reserve_space_allocator, is_training, &workspace, &reserve_space)); - const bool is_profiling = output_profile_result != nullptr; - TF_ASSIGN_OR_RETURN( - std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); + const bool is_profiling = output_profile_result != nullptr; + TF_ASSIGN_OR_RETURN( + std::optional timer, + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + is_profiling)); + if (input_desc.is_var_seq_lengths()) { + // In CUDNN v8, the cudnnRNNForward*** and cudnnRNNForward***Ex have been + // deprecated. Instead, we use the cudnnRNNForward which requires the + // sequence_lengths parameter. +#if CUDNN_VERSION >= 8100 + auto rnn_fwd_mode = + is_training ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE; RETURN_IF_CUDNN_ERROR(cudnnRNNForward( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*fwdMode=*/rnn_fwd_mode, @@ -2112,42 +2338,8 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(), /*reserveSpaceSizeInBytes=*/reserve_space.size(), /*reserveSpace=*/reserve_space.opaque())); - - if (is_profiling) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), - output_profile_result)); - } - return tsl::OkStatus(); - } -#endif - TF_ASSIGN_OR_RETURN(DeviceMemory workspace, - CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator)); - - // query the reserve space size - // allocate the reserve space - DeviceMemory reserve_space; - if (is_training) { - size_t reserve_space_size_in_bytes = 0; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&reserve_space_size_in_bytes)); - - if (reserve_space_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( - reserve_space_size_in_bytes)); - } - } - - const bool is_profiling = output_profile_result != nullptr; - TF_ASSIGN_OR_RETURN( - std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); - - if (!is_training) { - if (input_desc.is_var_seq_lengths()) { +#else + if (!is_training) { RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), @@ -2163,21 +2355,6 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( /*workspace=*/workspace.opaque(), /*workSpaceSizeInBytes=*/workspace.size())); } else { - RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*seqLength=*/model_dims.max_seq_length, - /*xDesc=*/input_desc.handles(), - /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), - /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(), - /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(), - /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(), - /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(), - /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(), - /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size())); - } - } else { - if (input_desc.is_var_seq_lengths()) { RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), @@ -2194,6 +2371,26 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space.opaque(), /*reserveSpaceSizeInBytes=*/reserve_space.size())); + } +#endif // CUDNN_VERSION >= 8100 + } else { +#if CUDNN_VERSION >= 90000 + return tsl::errors::Internal( + "Sequence lengths for RNN are required from CUDNN 9.0+"); +#else + if (!is_training) { + RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*seqLength=*/model_dims.max_seq_length, + /*xDesc=*/input_desc.handles(), + /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), + /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(), + /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(), + /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(), + /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(), + /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(), + /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), + /*workSpaceSizeInBytes=*/workspace.size())); } else { RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), @@ -2210,6 +2407,7 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( /*reserveSpace=*/reserve_space.opaque(), /*reserveSpaceSizeInBytes=*/reserve_space.size())); } +#endif // CUDNN_VERSION >= 90000 } if (is_profiling) { @@ -2218,11 +2416,11 @@ tsl::Status CudnnSupport::DoRnnForwardImpl( output_profile_result)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } template -tsl::Status CudnnSupport::DoRnnBackwardImpl( +absl::Status CudnnSupport::DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -2258,29 +2456,24 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); - // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been - // deprecated. Instead, we use the cudnnRNNForward which requires the - // sequence_lengths parameter. For more info, - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#release-802. -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND - if (input_desc.is_var_seq_lengths()) { - DeviceMemory workspace; - size_t workspace_size_in_bytes = 0; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(), - /*workSpaceSize=*/&workspace_size_in_bytes, - /*reserveSpaceSize=*/NULL)); - if (workspace_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( - workspace_size_in_bytes)); - } + DeviceMemory workspace; + TF_RETURN_IF_ERROR(CreateRnnTempSpace(stream, cudnn, rnn_desc, model_dims, + input_desc, workspace_allocator, + nullptr, true, &workspace, nullptr)); - const bool is_profiling = output_profile_result != nullptr; - TF_ASSIGN_OR_RETURN( - std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); + const bool is_profiling = output_profile_result != nullptr; + TF_ASSIGN_OR_RETURN( + std::optional timer, + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + is_profiling)); + if (input_desc.is_var_seq_lengths()) { + // In CUDNN v8, the cudnnRNNBackward*** and cudnnRNNBackward***Ex have + // been deprecated. Instead, we use the cudnnRNNBackward***_v8 which + // requires the sequence_lengths parameter. +#if CUDNN_VERSION >= 8100 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*devSeqLengths=*/ @@ -2300,48 +2493,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(), /*reserveSpaceSize=*/reserve_space_data->size(), /*reserveSpace=*/reserve_space_data->opaque())); - - if (params_backprop_data != nullptr) { - // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); - RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( - /*handle=*/cudnn.handle(), - /*rnnDesc=*/rnn_desc.handle(), - /*addGrad=*/CUDNN_WGRAD_MODE_ADD, - /*devSeqLengths=*/ - reinterpret_cast(seq_lengths_data.opaque()), - /*xDesc=*/input_desc.data_handle(), - /*x=*/input_data.opaque(), - /*hDesc=*/input_h_desc.handle(), - /*hx=*/input_h_data.opaque(), - /*yDesc=*/output_desc.data_handle(), - /*y=*/output_data.opaque(), - /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), - /*dweightSpace=*/params_backprop_data->opaque(), - /*workSpaceSize=*/workspace.size(), - /*workSpace=*/workspace.opaque(), - /*reserveSpaceSize=*/reserve_space_data->size(), - /*reserveSpace=*/reserve_space_data->opaque())); - } - - if (is_profiling) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), - output_profile_result)); - } - return tsl::OkStatus(); - } -#endif - TF_ASSIGN_OR_RETURN(DeviceMemory workspace, - CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator)); - - const bool is_profiling = output_profile_result != nullptr; - TF_ASSIGN_OR_RETURN( - std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); - - if (input_desc.is_var_seq_lengths()) { +#else RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(), @@ -2364,7 +2516,51 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space_data->opaque(), /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); +#endif // CUDNN_VERSION >= 8100 + + if (params_backprop_data != nullptr) { + // Clear the dw to zeros. + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); +#if CUDNN_VERSION >= 8100 + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*addGrad=*/CUDNN_WGRAD_MODE_ADD, + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*xDesc=*/input_desc.data_handle(), + /*x=*/input_data.opaque(), + /*hDesc=*/input_h_desc.handle(), + /*hx=*/input_h_data.opaque(), + /*yDesc=*/output_desc.data_handle(), + /*y=*/output_data.opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*dweightSpace=*/params_backprop_data->opaque(), + /*workSpaceSize=*/workspace.size(), + /*workSpace=*/workspace.opaque(), + /*reserveSpaceSize=*/reserve_space_data->size(), + /*reserveSpace=*/reserve_space_data->opaque())); +#else + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), + /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), + /*yDesc=*/output_desc.data_handle(), + /*y=*/output_data.opaque(), + /*workspace=*/workspace.opaque(), + /*workSpaceSizeInBytes=*/workspace.size(), + /*dwDesc=*/rnn_desc.params_handle(), + /*dw=*/params_backprop_data->opaque(), + /*reserveSpace=*/reserve_space_data->opaque(), + /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); +#endif // CUDNN_VERSION >= 8100 + } } else { +#if CUDNN_VERSION >= 90000 + return tsl::errors::Internal( + "Sequence lengths for RNN are required from CUDNN 9.0+"); +#else RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.max_seq_length, @@ -2387,25 +2583,11 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space_data->opaque(), /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); - } - if (params_backprop_data != nullptr) { - // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); - if (input_desc.is_var_seq_lengths()) { - RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), - /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), - /*yDesc=*/output_desc.data_handle(), - /*y=*/output_data.opaque(), - /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size(), - /*dwDesc=*/rnn_desc.params_handle(), - /*dw=*/params_backprop_data->opaque(), - /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); - } else { + if (params_backprop_data != nullptr) { + // Clear the dw to zeros. + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), @@ -2420,6 +2602,7 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( /*reserveSpace=*/reserve_space_data->opaque(), /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); } +#endif // CUDNN_VERSION >= 90000 } if (is_profiling) { @@ -2428,10 +2611,10 @@ tsl::Status CudnnSupport::DoRnnBackwardImpl( output_profile_result)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CudnnSupport::DoCtcLossImpl( +absl::Status CudnnSupport::DoCtcLossImpl( Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -2462,16 +2645,16 @@ tsl::Status CudnnSupport::DoCtcLossImpl( /*workspace=*/scratch_memory.opaque(), /*workSpaceSizeInBytes=*/scratch_memory.size())); #else - return tsl::Status(absl::StatusCode::kInvalidArgument, - "No supported cudnnCTCLoss when " - "CUDNN_VERSION < 7.6.3"); + return absl::InvalidArgumentError( + "No supported cudnnCTCLoss when " + "CUDNN_VERSION < 7.6.3"); #endif - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr> -CudnnSupport::createRnnDescriptor( +absl::StatusOr> +CudnnSupport::CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -2494,8 +2677,8 @@ CudnnSupport::createRnnDescriptor( new CudnnRnnDescriptor(std::move(rnn_desc))); } -tsl::StatusOr> -CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, +absl::StatusOr> +CudnnSupport::CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) { TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, @@ -2506,8 +2689,8 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } -tsl::StatusOr> -CudnnSupport::createRnnSequenceTensorDescriptor( +absl::StatusOr> +CudnnSupport::CreateRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) { @@ -2519,8 +2702,8 @@ CudnnSupport::createRnnSequenceTensorDescriptor( new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } -tsl::StatusOr> -CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, +absl::StatusOr> +CudnnSupport::CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { return std::unique_ptr( @@ -2822,7 +3005,7 @@ namespace { // TODO(csigg): Merge a lot of duplicate code below for forward, backward data, // and backward filter. -tsl::StatusOr GetCudnnConvolutionForwardAlgo( +absl::StatusOr GetCudnnConvolutionForwardAlgo( const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit, @@ -2845,9 +3028,9 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgo( return perf_results[r].algo; } } - return tsl::Status(absl::StatusCode::kInternal, - "cudnnGetConvolutionForwardAlgorithm_v7 returned " - "no suitable algorithms. This could be a cudnn bug."); + return absl::InternalError( + "cudnnGetConvolutionForwardAlgorithm_v7 returned " + "no suitable algorithms. This could be a cudnn bug."); #else cudnnConvolutionFwdPreference_t preference = specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT @@ -2860,7 +3043,7 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgo( #endif } -tsl::StatusOr +absl::StatusOr GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -2887,9 +3070,9 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, return perf_results[r].algo; } } - return tsl::Status(absl::StatusCode::kInternal, - "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned " - "no suitable algorithms. This could be a cudnn bug."); + return absl::InternalError( + "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned " + "no suitable algorithms. This could be a cudnn bug."); #else cudnnConvolutionBwdDataPreference_t preference = specify_workspace_limit @@ -2903,7 +3086,7 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, #endif } -tsl::StatusOr +absl::StatusOr GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -2929,9 +3112,9 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, return perf_results[r].algo; } } - return tsl::Status(absl::StatusCode::kInternal, - "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned " - "no suitable algorithms. This could be a cudnn bug."); + return absl::InternalError( + "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned " + "no suitable algorithms. This could be a cudnn bug."); #else cudnnConvolutionBwdFilterPreference_t preference = specify_workspace_limit @@ -2945,7 +3128,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, #endif } -tsl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( +absl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, @@ -2953,8 +3136,7 @@ tsl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "Mismatch between cudnn conv and algorithm descriptors."); } @@ -2975,8 +3157,7 @@ tsl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( int64_t size_in_bytes_int64_t = size_in_bytes; if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "cudnnGetConvolutionForwardWorkspaceSize() returned " "negative sizeInBytes value. This could be a cudnn bug."); } @@ -2986,14 +3167,13 @@ tsl::StatusOr> AllocateCudnnConvolutionForwardWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "No scratch allocator provided"); + return absl::InvalidArgumentError("No scratch allocator provided"); } return scratch_allocator->AllocateBytes(size_in_bytes); } -tsl::StatusOr> +absl::StatusOr> AllocateCudnnConvolutionBackwardDataWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -3002,8 +3182,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "Mismatch between cudnn conv and algorithm descriptors."); } @@ -3025,8 +3204,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( int64_t size_in_bytes_int64_t = size_in_bytes; if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " "negative sizeInBytes value. This could be a cudnn bug."); } @@ -3036,14 +3214,13 @@ AllocateCudnnConvolutionBackwardDataWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "No scratch allocator provided"); + return absl::InvalidArgumentError("No scratch allocator provided"); } return scratch_allocator->AllocateBytes(size_in_bytes); } -tsl::StatusOr> +absl::StatusOr> AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -3052,8 +3229,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "Mismatch between cudnn conv and algorithm descriptors."); } @@ -3075,8 +3251,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( int64_t size_in_bytes_int64_t = size_in_bytes; if (ABSL_PREDICT_FALSE(size_in_bytes_int64_t < 0)) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " "negative sizeInBytes value. This could be a cudnn bug."); } @@ -3086,8 +3261,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( } if (ABSL_PREDICT_FALSE(!scratch_allocator)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "No scratch allocator provided"); + return absl::InvalidArgumentError("No scratch allocator provided"); } return scratch_allocator->AllocateBytes(size_in_bytes); @@ -3108,7 +3282,7 @@ bool UseTensorOps(dnn::DataType input_type, cudnnDataType_t GetRnnComputeType(dnn::DataType data_type); dnn::DataType GetConvAccumulatorType(dnn::DataType data_type); -tsl::StatusOr GetCudnnConvolutionForwardAlgorithm( +absl::StatusOr GetCudnnConvolutionForwardAlgorithm( Stream* stream, const CudnnHandle& cudnn, const dnn::AlgorithmConfig& algorithm_config, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -3153,7 +3327,7 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgorithm( // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. if (!algo_desc.has_value()) { - return tsl::Status( + return absl::Status( scratch_or.status().code(), absl::StrCat("The primary convolution algorithm failed, ", "while a secondary algorithm is not provided. ", @@ -3168,7 +3342,7 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgorithm( return *algo_desc; } -tsl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( +absl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( Stream* stream, const CudnnHandle& cudnn, const dnn::AlgorithmConfig& algorithm_config, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -3212,8 +3386,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. if (!algo_desc.has_value()) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided."); } @@ -3226,7 +3399,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( return *algo_desc; } -tsl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( +absl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( Stream* stream, const CudnnHandle& cudnn, const dnn::AlgorithmConfig& algorithm_config, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, @@ -3256,7 +3429,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops); } - tsl::StatusOr> scratch_or = + absl::StatusOr> scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator); @@ -3271,12 +3444,10 @@ tsl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. if (!algo_desc.has_value()) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, - absl::StrCat( - "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided. Actual error: ", - scratch_or.status().ToString())); + return absl::InvalidArgumentError(absl::StrCat( + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided. Actual error: ", + scratch_or.status().ToString())); } use_tensor_ops = UseTensorOps(element_type, algo_desc); @@ -3367,7 +3538,7 @@ struct RnnDoFP32ComputationFP16Input { namespace { -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, bool disable_winograd, bool disable_nondeterminism, @@ -3394,7 +3565,7 @@ bool GenericEngineFilter(cudnnBackendDescriptor_t engine_config, return ret; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } // namespace @@ -3469,7 +3640,7 @@ dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) { } } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 namespace { static bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { @@ -3538,8 +3709,8 @@ std::tuple GetTensorVectorSizeAndDim( return std::make_tuple(vector_size, vector_dim); } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) -tsl::StatusOr CreateCudnnTensor( +#if CUDNN_VERSION >= 8800 +absl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim, bool is_virtual = false, @@ -3561,7 +3732,7 @@ tsl::StatusOr CreateCudnnTensor( return tensor; } #else -tsl::StatusOr CreateCudnnTensor( +absl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, int64_t uid, dnn::DataType dtype, int64_t vec_count, int64_t vec_dim, bool is_virtual = false, bool is_reordered_nchw_vect = false) { @@ -3591,10 +3762,10 @@ tsl::StatusOr CreateCudnnTensor( } #endif -tsl::StatusOr CreateCudnnTensor( +absl::StatusOr CreateCudnnTensor( const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, bool is_virtual = false) { -#if (CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8900 auto tensor = cudnn_frontend::TensorBuilder() .cloneFrom(original, uid) .setAlignment(32) @@ -3605,10 +3776,10 @@ tsl::StatusOr CreateCudnnTensor( return tensor; #else return tsl::errors::Internal("Not implemented."); -#endif // CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8900 } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 enum CudnnfMHAUid { Q_ID = 400, K_ID, @@ -3639,7 +3810,7 @@ enum CudnnfMHAUid { VIRTUAL_ID = 34857 }; -tsl::StatusOr CreatePwDesc( +absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder() .setMode(mode) @@ -3649,7 +3820,7 @@ tsl::StatusOr CreatePwDesc( return pw_desc_created; } -tsl::StatusOr CreateUnaryPwOp( +absl::StatusOr CreateUnaryPwOp( cudnn_frontend::Tensor const& xDesc, cudnn_frontend::Tensor const& yDesc, cudnn_frontend::PointWiseDesc const& pwDesc) { auto pw_op_created = cudnn_frontend::OperationBuilder( @@ -3662,7 +3833,7 @@ tsl::StatusOr CreateUnaryPwOp( return pw_op_created; } -tsl::StatusOr CreateBinaryPwOp( +absl::StatusOr CreateBinaryPwOp( cudnn_frontend::Tensor const& xDesc, cudnn_frontend::Tensor const& bDesc, cudnn_frontend::Tensor const& yDesc, cudnn_frontend::PointWiseDesc const& pwDesc) { @@ -3677,7 +3848,7 @@ tsl::StatusOr CreateBinaryPwOp( return pw_op_created; } -tsl::StatusOr CreateTernaryPwOp( +absl::StatusOr CreateTernaryPwOp( cudnn_frontend::Tensor const& xDesc, cudnn_frontend::Tensor const& bDesc, cudnn_frontend::Tensor const& tDesc, cudnn_frontend::Tensor const& yDesc, cudnn_frontend::PointWiseDesc const& pwDesc) { @@ -3694,7 +3865,7 @@ tsl::StatusOr CreateTernaryPwOp( } // Returns a cudnn tensor that's the output of the mask op -tsl::StatusOr CreateCudnnMaskFwdTensor( +absl::StatusOr CreateCudnnMaskFwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor) { @@ -3737,7 +3908,7 @@ tsl::StatusOr CreateCudnnMaskFwdTensor( } // Returns a cudnn tensor that's the output of the alpha scale -tsl::StatusOr CreateCudnnScaleTensor( +absl::StatusOr CreateCudnnScaleTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor) { @@ -3768,7 +3939,7 @@ tsl::StatusOr CreateCudnnScaleTensor( } // Returns a cudnn tensor that's the output of the bias addition op -tsl::StatusOr CreateCudnnBiasTensor( +absl::StatusOr CreateCudnnBiasTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor, bool use_mask) { @@ -3807,7 +3978,7 @@ tsl::StatusOr CreateCudnnBiasTensor( } // Returns a cudnn tensor that's the output of the softmax op -tsl::StatusOr CreateCudnnSoftmaxFwdTensor( +absl::StatusOr CreateCudnnSoftmaxFwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor, bool is_virtual = false) { @@ -3937,7 +4108,7 @@ tsl::StatusOr CreateCudnnSoftmaxFwdTensor( } // Returns a cudnn tensor that's the output of the dropout op -tsl::StatusOr CreateCudnnDropoutFwdTensor( +absl::StatusOr CreateCudnnDropoutFwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor, double dropout_rate, int64_t seed, @@ -4035,9 +4206,9 @@ tsl::StatusOr CreateCudnnDropoutFwdTensor( return dropout_scale_out_tensor; } -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 -tsl::StatusOr> +absl::StatusOr> GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, @@ -4082,7 +4253,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered @@ -4094,7 +4265,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::FilterLayout::kOutputInputYX32_CudnnReordered; #endif -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, @@ -4163,7 +4334,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, return std::make_unique(std::move(opGraph)); } -tsl::StatusOr PrimitiveTypeStringToDnnType( +absl::StatusOr PrimitiveTypeStringToDnnType( std::string data_type_string) { if (data_type_string == "f8e4m3fn") { return dnn::DataType::kF8E4M3FN; @@ -4185,7 +4356,7 @@ using OpMode = std::variant> +absl::StatusOr> OpNameStringToOperandKindAndMode(std::string opstring) { #define KINDS_AND_MODE_FROM_OP_STRING(OPSTRING, BINARYOPERANDKIND, \ AUXOUTPUTKIND, PWMODE) \ @@ -4229,9 +4400,9 @@ class OpGraph { public: OpGraph() = default; - tsl::Status AddOp(int uid, std::optional operand_uid, OpMode mode, - TensorKind operand_kind, TensorKind result_kind, - dnn::DataType result_type) { + absl::Status AddOp(int uid, std::optional operand_uid, OpMode mode, + TensorKind operand_kind, TensorKind result_kind, + dnn::DataType result_type) { ops_.emplace_back(OpDescriptor({uid, operand_uid, mode, operand_kind, result_kind, result_type, false, -1})); // If it exists, the operand is virtual. @@ -4244,10 +4415,10 @@ class OpGraph { } it->is_virtual = true; } - return tsl::OkStatus(); + return absl::OkStatus(); } - tsl::StatusOr FindOpDescriptor(int uid) const { + absl::StatusOr FindOpDescriptor(int uid) const { auto it = std::find_if(ops_.begin(), ops_.end(), [uid](OpDescriptor op) { return op.uid == uid; }); if (it == ops_.end()) { @@ -4256,21 +4427,21 @@ class OpGraph { return *it; } - tsl::StatusOr OpDescriptorAt(int index) const { + absl::StatusOr OpDescriptorAt(int index) const { if (index >= Size()) { return tsl::errors::Internal("Index exceeds bounds."); } return ops_[index]; } - tsl::Status SetSequenceIndex(int uid, int index) { + absl::Status SetSequenceIndex(int uid, int index) { auto it = std::find_if(ops_.begin(), ops_.end(), [uid](OpDescriptor op) { return op.uid == uid; }); if (it == ops_.end()) { return tsl::errors::Internal("Unknown ID."); } it->sequence_index = index; - return tsl::OkStatus(); + return absl::OkStatus(); } bool Empty() const { return ops_.empty(); } @@ -4287,8 +4458,8 @@ class OpGraph { // Returns a generic cuDNN OperationGraph for ForwardGraph convolutions with the // fused ops listed in serialized_graph and the associated set of UIDs of // non-virtual cuDNN tensors. -tsl::StatusOr, - std::vector>> +absl::StatusOr, + std::vector>> GetGenericCudnnOperationGraph( dnn::ConvolutionKind kind, dnn::DataType input_type, const dnn::BatchDescriptor& input_descriptor, @@ -4304,7 +4475,7 @@ GetGenericCudnnOperationGraph( // UID);UID:[output_type]op_name(operand UID);..." with the convolution // assumed to be the first op in the graph. Operand UIDs identifying ops // outside the serialized graph are elided. - auto deserialize_cudnn_graph = [&]() -> tsl::StatusOr { + auto deserialize_cudnn_graph = [&]() -> absl::StatusOr { OpGraph op_graph; std::string::size_type pos = 0; while (pos < serialized_graph.size()) { @@ -4623,7 +4794,7 @@ bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale, return check_activation || check_scale; } -tsl::StatusOr> +absl::StatusOr> GetCudnnFusedOperationGraph( dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, double alpha, @@ -4680,7 +4851,7 @@ GetCudnnFusedOperationGraph( std::vector filter_strides = filter_descriptor.vectorized_strides( dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 cudnnBackendTensorReordering_t tensor_ordering_type = filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX32_CudnnReordered @@ -4692,7 +4863,7 @@ GetCudnnFusedOperationGraph( dnn::FilterLayout::kOutputInputYX32_CudnnReordered; #endif -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 TF_ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, @@ -4743,7 +4914,7 @@ GetCudnnFusedOperationGraph( auto maybe_tensor_b = CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type, vector_size, vector_dim, /*is_virtual=*/false, -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 tensor_ordering_type #else is_reordered_nchw_vect @@ -4937,7 +5108,7 @@ GetCudnnFusedOperationGraph( return std::make_unique(std::move(op_graph)); } -tsl::StatusOr> +absl::StatusOr> GetCudnnFusedMatmulGraph(dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, bool trans_a, bool trans_b, uint64_t m_u, uint64_t n_u, uint64_t k_u, int64_t lda, @@ -5080,8 +5251,8 @@ GetCudnnFusedMatmulGraph(dnn::DataType input_type, dnn::DataType bias_type, return std::make_unique(std::move(op_graph)); } -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) -tsl::StatusOr> +#if CUDNN_VERSION >= 8800 +absl::StatusOr> GetCudnnFusedMHAOperationGraph( const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, @@ -5307,7 +5478,7 @@ GetCudnnFusedMHAOperationGraph( return std::make_unique(std::move(op_graph)); } -tsl::StatusOr CreateCudnnDropoutBwdTensor( +absl::StatusOr CreateCudnnDropoutBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor const& tensor_dropout_scale, @@ -5385,7 +5556,7 @@ tsl::StatusOr CreateCudnnDropoutBwdTensor( } // Returns a cudnn tensor that's the output of the softmax backward op -tsl::StatusOr CreateCudnnSoftmaxBwdTensor( +absl::StatusOr CreateCudnnSoftmaxBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor const& tensor_y, @@ -5496,7 +5667,7 @@ tsl::StatusOr CreateCudnnSoftmaxBwdTensor( return tensor_mul_reduction_sub_mul_alpha_scale; } -tsl::StatusOr CreateCudnnMaskBwdTensor( +absl::StatusOr CreateCudnnMaskBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor const& input_tensor, bool use_mask) { @@ -5609,8 +5780,8 @@ tsl::StatusOr CreateCudnnMaskBwdTensor( return dummy_mask_out_tensor; } } -#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) -tsl::StatusOr CreateCudnnBiasBwdTensor( +#if CUDNN_VERSION >= 8901 +absl::StatusOr CreateCudnnBiasBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor const& input_tensor) { @@ -5685,8 +5856,8 @@ tsl::StatusOr CreateCudnnBiasBwdTensor( return dbias_tensor; } -#endif // (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) -tsl::StatusOr> +#endif // CUDNN_VERSION >= 8901 +absl::StatusOr> GetCudnnFusedMHABackwardOperationGraph( const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, @@ -5978,7 +6149,7 @@ GetCudnnFusedMHABackwardOperationGraph( // bias backward if (use_bias) { -#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8901 TF_ASSIGN_OR_RETURN( auto tensor_dbias, CreateCudnnBiasBwdTensor(intermediate_ops, p_dims, p_strides, dtype, @@ -6061,7 +6232,7 @@ GetCudnnFusedMHABackwardOperationGraph( } // Returns a cudnn tensor that's the output of the bias addition op -tsl::StatusOr CreateCudnnFlashAttentionBiasFwdTensor( +absl::StatusOr CreateCudnnFlashAttentionBiasFwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor) { @@ -6097,7 +6268,8 @@ tsl::StatusOr CreateCudnnFlashAttentionBiasFwdTensor( return bias_out_tensor; } -tsl::StatusOr CreateCudnnFlashAttentionCausalMaskTensor( +absl::StatusOr +CreateCudnnFlashAttentionCausalMaskTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor) { @@ -6192,492 +6364,150 @@ tsl::StatusOr CreateCudnnFlashAttentionCausalMaskTensor( mask_out_tensor, mask_desc)); // Add mask to op list - ops.push_back(std::move(gen_index_row_op)); - ops.push_back(std::move(gen_index_column_op)); - ops.push_back(std::move(row_greater_than_column_op)); - ops.push_back(std::move(mask_op)); - - return mask_out_tensor; -} - -tsl::StatusOr CreateCudnnFlashAttentionSoftmaxFwdTensor( - std::vector& ops, absl::Span dims, - absl::Span strides, dnn::DataType dtype, - cudnn_frontend::Tensor& input_tensor, bool is_virtual = false) { - // softmax's typical computation is: - // exp(input - reduce_max(input)) / reduce_sum(exp(input - reduce_max(input))) - // We need to create each op and add it to the op list sequentially. - - // Copy all dims except the last dim since it's reduced to 1. - std::vector reduction_output_dim(dims.begin(), dims.end() - 1); - reduction_output_dim.push_back(1); - - // Divide every stride by the last dim value. - std::vector reduction_output_stride; - int64_t reduced_dim_len = dims.back(); - for (auto stride : strides) { - reduction_output_stride.push_back(stride / reduced_dim_len); - } - - // Create output tensor of the first max reduction. - TF_ASSIGN_OR_RETURN( - auto max_reduction_output_tensor, - CreateCudnnTensor(reduction_output_dim, reduction_output_stride, - CudnnfMHAUid::VIRTUAL_ID + 500, dnn::DataType::kFloat, - 1, -1, /*is_virtual=*/true)); - - // Create the reduction descriptor - auto max_reduction_desc = - cudnn_frontend::ReductionDescBuilder() - .setComputeType(ToCudnnDataType(dnn::DataType::kFloat)) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(max_reduction_desc); - // Create a reduction max node. - auto max_reduction_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(input_tensor) - .setyDesc(max_reduction_output_tensor) - .setreductionDesc(max_reduction_desc) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(max_reduction_op); - - // Create output tensor of the subtraction op. - TF_ASSIGN_OR_RETURN( - auto subtract_output_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 501, - dnn::DataType::kFloat, 1, -1, - /*is_virtual=*/true)); - // Create the subtraction descriptor - TF_ASSIGN_OR_RETURN(auto subtract_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_SUB)); - - // Create a subtraction node. - TF_ASSIGN_OR_RETURN( - auto subtract_op, - CreateBinaryPwOp(input_tensor, max_reduction_output_tensor, - subtract_output_tensor, subtract_desc)); - // Create output tensor of the exp op. - TF_ASSIGN_OR_RETURN( - auto exp_output_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 502, - dnn::DataType::kFloat, 1, -1, - /*is_virtual=*/true)); - // Create the exponetial descriptor - TF_ASSIGN_OR_RETURN(auto exp_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_EXP)); - - // Create a exponetial node. - TF_ASSIGN_OR_RETURN( - auto exp_op, - CreateUnaryPwOp(subtract_output_tensor, exp_output_tensor, exp_desc)); - - // Create output tensor of the sum reduction. - TF_ASSIGN_OR_RETURN( - auto sum_reduction_output_tensor, - CreateCudnnTensor(reduction_output_dim, reduction_output_stride, - CudnnfMHAUid::VIRTUAL_ID + 503, dnn::DataType::kFloat, - 1, -1, /*is_virtual=*/true)); - // Create the reduction descriptor - auto sum_reduction_desc = - cudnn_frontend::ReductionDescBuilder() - .setComputeType(ToCudnnDataType(dnn::DataType::kFloat)) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(sum_reduction_desc); - // Create a reduction sum node. - auto sum_reduction_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(exp_output_tensor) - .setyDesc(sum_reduction_output_tensor) - .setreductionDesc(sum_reduction_desc) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(sum_reduction_op); - - // Create output tensor of the log op. - TF_ASSIGN_OR_RETURN( - auto log_tensor, - CreateCudnnTensor(reduction_output_dim, reduction_output_stride, - CudnnfMHAUid::VIRTUAL_ID + 504, dnn::DataType::kFloat, - 1, -1, - /*is_virtual*/ true)); - - // Create the log descriptor - TF_ASSIGN_OR_RETURN(auto log_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_LOG)); - - // Create a log node. - TF_ASSIGN_OR_RETURN(auto log_op, CreateUnaryPwOp(sum_reduction_output_tensor, - log_tensor, log_desc)); - - // Create output tensor of the add op. - auto ID = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 505 : CudnnfMHAUid::P_ID; - TF_ASSIGN_OR_RETURN( - auto softmax_stats_tensor, - CreateCudnnTensor(reduction_output_dim, reduction_output_stride, ID, - dnn::DataType::kFloat, 1, -1, - /*is_virtual*/ is_virtual)); - - // Create the add descriptor - TF_ASSIGN_OR_RETURN(auto add_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_ADD)); - - // Create a add node. - TF_ASSIGN_OR_RETURN(auto add_op, - CreateBinaryPwOp(max_reduction_output_tensor, log_tensor, - softmax_stats_tensor, add_desc)); - - // Create output tensor of the divide op. - TF_ASSIGN_OR_RETURN( - auto divide_output_tensor, - CreateCudnnTensor( - dims, strides, CudnnfMHAUid::VIRTUAL_ID + 506, dnn::DataType::kFloat, - 1, -1, - /*is_virtual*/ true, - /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_F16x16)); - // Create the divide descriptor - TF_ASSIGN_OR_RETURN(auto divide_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_DIV)); - - // Create a divide node. - TF_ASSIGN_OR_RETURN( - auto divide_op, - CreateBinaryPwOp(exp_output_tensor, sum_reduction_output_tensor, - divide_output_tensor, divide_desc)); - - // Add max reduction to op list - ops.push_back(std::move(max_reduction_op)); - // Add subtract to op list - ops.push_back(std::move(subtract_op)); - // Add exponetial to op list - ops.push_back(std::move(exp_op)); - // Add sum reduction to op list - ops.push_back(std::move(sum_reduction_op)); - // Add Log to op list - ops.push_back(std::move(log_op)); - // Add Add to op list - ops.push_back(std::move(add_op)); - // Add divide to op list - ops.push_back(std::move(divide_op)); - return divide_output_tensor; -} - -tsl::StatusOr CreateCudnnFlashAttentionDropoutFwdTensor( - std::vector& ops, absl::Span dims, - absl::Span strides, dnn::DataType dtype, - cudnn_frontend::Tensor& input_tensor, double dropout_rate) { - // Create scale tensor - std::vector scale_dims(dims.size(), 1); - std::vector scale_strides(strides.size(), 1); - - // Create tensor for dropout's mask. - TF_ASSIGN_OR_RETURN( - auto mask_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 600, - dnn::DataType::kFloat, 1, -1, - /*is_virtual*/ true)); - // Create output tensor of dropout node - // it is different from regular attention, the dropout output is always - // virtual we compute mask in the bwd instead of storing the mask - TF_ASSIGN_OR_RETURN( - auto dropout_out_tensor, - CreateCudnnTensor( - dims, strides, CudnnfMHAUid::VIRTUAL_ID + 601, dtype, 1, -1, - /*is_virtual*/ true, - /*cudnn_tensor_order_type*/ - cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16)); - - // Create offset tensor of dropout node - TF_ASSIGN_OR_RETURN( - auto dropout_offset_tensor, - CreateCudnnTensor( - scale_dims, scale_strides, CudnnfMHAUid::D_OFFSET_ID, - dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false, - /*cudnn_tensor_order_type*/ - cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE, - /*is_value*/ CUDNN_VERSION < 8903 ? false : true)); - - // Create seed tensor of dropout node - TF_ASSIGN_OR_RETURN( - auto dropout_seed_tensor, - CreateCudnnTensor( - scale_dims, scale_strides, CudnnfMHAUid::D_SEED_ID, - dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false, - /*cudnn_tensor_order_type*/ - cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE, - /*is_value*/ CUDNN_VERSION < 8903 ? false : true)); - - // Create description for rng node - auto rng_desc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - dropout_rate) - .build(); - - // Create the rng Node. - auto rng_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(mask_tensor) - .setSeedDesc(dropout_seed_tensor) - .setOffsetDesc(dropout_offset_tensor) - .setRngDesc(rng_desc) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(rng_op); - - // Create the masking node desc after mask tensor - TF_ASSIGN_OR_RETURN(auto masking_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL)); - - // Create the scaling op - TF_ASSIGN_OR_RETURN(auto masking_op, - CreateBinaryPwOp(input_tensor, mask_tensor, - dropout_out_tensor, masking_desc)); - - TF_ASSIGN_OR_RETURN( - auto dropout_scale_tensor, - CreateCudnnTensor( - scale_dims, scale_strides, CudnnfMHAUid::DROPOUT_SCALE_ID, - dnn::DataType::kFloat, 1, -1, - /*is_virtual*/ false, - /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE, - /*is_value*/ true)); - - // Create output of scale node - TF_ASSIGN_OR_RETURN( - auto dropout_scale_out_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 602, dtype, 1, - -1, /*is_virtual*/ true)); - // Create the scaling desc - TF_ASSIGN_OR_RETURN(auto scale_desc, - CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL)); - - // Create the scaling op - TF_ASSIGN_OR_RETURN(auto scale_op, - CreateBinaryPwOp(dropout_out_tensor, dropout_scale_tensor, - dropout_scale_out_tensor, scale_desc)); - // Add rng op to op list - ops.push_back(std::move(rng_op)); - // Add masking op to op list - ops.push_back(std::move(masking_op)); - // Add scaling op to op list - ops.push_back(std::move(scale_op)); - - return dropout_scale_out_tensor; -} - -tsl::StatusOr> -GetCudnnFlashAttentionOperationGraph( - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional mask_descriptor, - std::optional bias_descriptor, - std::optional activation_descriptor, - dnn::FusedMHAKind kind, std::optional dropout_rate, - std::optional seed, CudnnHandle& cudnn, double scale, - std::vector& intermediate_shape, bool use_dropout = false, - bool use_mask = false, bool use_bias = false, - bool use_causal_mask = false) { - if (VLOG_IS_ON(4)) { - VLOG(4) << "\n bmm1_lhs(q): " << bmm1_lhs_descriptor.ToString() - << "\n bmm1_rhs(k): " << bmm1_rhs_descriptor.ToString() - << "\n bmm2_lhs(s): " << intermediate_bmm2_lhs_descriptor.ToString() - << "\n bmm2_rhs(v): " << bmm2_rhs_descriptor.ToString() - << "\n out(o): " << output_descriptor.ToString(); - if (activation_descriptor) { - VLOG(4) << "\n activation(s): " << (*activation_descriptor).ToString(); - } - } - - // cnn_infer needs to be preloaded for fMHA as well. Reusing the function - // created for convolution for fMHA. - PreloadCudnnSubLibsHelper(dnn::ConvolutionKind::FORWARD); - - std::vector ops; - std::vector intermediate_ops; - - // Batched Matmul: bmm1_lhs: tensor_q, bmm1_rhs:tensor_k; output: tensor_s - // (virtual) - // Batched Matmul: bmm2_lhs: tensor_s, bmm2_rhs:tensor_v; output: tensor_o - std::vector bmm1_lhs_dims = - bmm1_lhs_descriptor.GetCudnnCompatibleDimensions(true); - std::vector bmm1_lhs_strides = - bmm1_lhs_descriptor.GetCudnnCompatibleStrides(true); - - VLOG(2) << "\n cuDNN compatible bmm1_lhs_dims: " - << absl::StrJoin(bmm1_lhs_dims, ",") - << "\n cuDNN compatible bmm1_lhs_strides: " - << absl::StrJoin(bmm1_lhs_strides, ","); - - TF_ASSIGN_OR_RETURN( - auto tensor_q, - CreateCudnnTensor(bmm1_lhs_dims, bmm1_lhs_strides, CudnnfMHAUid::Q_ID, - bmm1_lhs_descriptor.type(), 1, -1)); - - std::vector bmm1_rhs_dims = - bmm1_rhs_descriptor.GetCudnnCompatibleDimensions(false); - std::vector bmm1_rhs_strides = - bmm1_rhs_descriptor.GetCudnnCompatibleStrides(false); - - VLOG(2) << "\n cuDNN compatible bmm1_rhs_dims: " - << absl::StrJoin(bmm1_rhs_dims, ",") - << "\n cuDNN compatible bmm1_rhs_strides: " - << absl::StrJoin(bmm1_rhs_strides, ","); - - TF_ASSIGN_OR_RETURN( - auto tensor_k, - CreateCudnnTensor(bmm1_rhs_dims, bmm1_rhs_strides, CudnnfMHAUid::K_ID, - bmm1_rhs_descriptor.type(), 1, -1)); - - std::vector intermediate_bmm2_lhs_dims = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); - std::vector intermediate_bmm2_lhs_strides = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleStrides(true); - - VLOG(2) << "\n cuDNN compatible intermediate_bmm2_lhs_dims: " - << absl::StrJoin(intermediate_bmm2_lhs_dims, ",") - << "\n cuDNN compatible intermediate_bmm2_lhs_strides: " - << absl::StrJoin(intermediate_bmm2_lhs_strides, ","); - intermediate_shape = intermediate_bmm2_lhs_dims; - bool has_activation = activation_descriptor != std::nullopt; - - TF_ASSIGN_OR_RETURN(auto tensor_s, - CreateCudnnTensor(intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - CudnnfMHAUid::VIRTUAL_ID + 100, - dnn::DataType::kFloat, 1, -1, - /*is_virtual=*/true)); - - auto bmm1_desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(bmm1_desc); - auto bmm1_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(tensor_q) - .setbMatDesc(tensor_k) - .setcMatDesc(tensor_s) - .setmatmulDesc(bmm1_desc) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(bmm1_op); - intermediate_ops.push_back(std::move(bmm1_op)); - - // Create scale op and tensor - TF_ASSIGN_OR_RETURN( - auto alpha_scale_out, - CreateCudnnScaleTensor(intermediate_ops, intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - dnn::DataType::kFloat, tensor_s)); - - auto bmm2_input_tensor = std::move(alpha_scale_out); - - if (use_bias) { - // Create bias op and tensor - TF_ASSIGN_OR_RETURN(auto bias_out, - CreateCudnnFlashAttentionBiasFwdTensor( - intermediate_ops, intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - (*bias_descriptor).type(), bmm2_input_tensor)); - bmm2_input_tensor = std::move(bias_out); - } - - if (use_causal_mask) { - // Create mask op and tensor - TF_ASSIGN_OR_RETURN( - auto mask_out, - CreateCudnnFlashAttentionCausalMaskTensor( - intermediate_ops, intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - intermediate_bmm2_lhs_descriptor.type(), bmm2_input_tensor)); - bmm2_input_tensor = std::move(mask_out); - } - - // Create Softmax tensor - // The output is always a virtual for inference mode. - // The output is always non virtual for training mode. cuz we recompute - // dropout in bwd.; - bool should_output_softmax = has_activation; - TF_ASSIGN_OR_RETURN(auto softmax_fwd_out, - CreateCudnnFlashAttentionSoftmaxFwdTensor( - intermediate_ops, intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - intermediate_bmm2_lhs_descriptor.type(), - /*input_tensor*/ bmm2_input_tensor, - /*is_virtual*/ !should_output_softmax)); - bmm2_input_tensor = std::move(softmax_fwd_out); - - // Create dropout tensor - // dropout is always virtual in inference or training for flash attention - TF_ASSIGN_OR_RETURN(auto dropout_out, - CreateCudnnFlashAttentionDropoutFwdTensor( - intermediate_ops, intermediate_bmm2_lhs_dims, - intermediate_bmm2_lhs_strides, - intermediate_bmm2_lhs_descriptor.type(), - /*input_tensor*/ softmax_fwd_out, *dropout_rate)); - bmm2_input_tensor = std::move(dropout_out); - - std::vector bmm2_rhs_dims = - bmm2_rhs_descriptor.GetCudnnCompatibleDimensions(false); - std::vector bmm2_rhs_strides = - bmm2_rhs_descriptor.GetCudnnCompatibleStrides(false); - - VLOG(2) << "\n cuDNN compatible bmm2_rhs_dims: " - << absl::StrJoin(bmm2_rhs_dims, ",") - << "\n cuDNN compatible bmm2_rhs_strides: " - << absl::StrJoin(bmm2_rhs_strides, ","); + ops.push_back(std::move(gen_index_row_op)); + ops.push_back(std::move(gen_index_column_op)); + ops.push_back(std::move(row_greater_than_column_op)); + ops.push_back(std::move(mask_op)); - TF_ASSIGN_OR_RETURN( - auto tensor_v, - CreateCudnnTensor(bmm2_rhs_dims, bmm2_rhs_strides, CudnnfMHAUid::V_ID, - bmm2_rhs_descriptor.type(), 1, -1)); + return mask_out_tensor; +} - std::vector output_dims = output_descriptor.dimensions(); - std::vector output_strides = output_descriptor.GetLogicalStrides(); +absl::StatusOr GetCudnnFlashAttentionOperationGraph( + dnn::DnnSupport& dnn_support, + const dnn::MatmulTensorDescriptor& q_descriptor, + const dnn::MatmulTensorDescriptor& k_descriptor, + const dnn::MatmulTensorDescriptor& v_descriptor, + const dnn::TensorDescriptor& o_descriptor, + const std::optional bias_descriptor, + const std::optional mask_descriptor, + const std::optional stats_descriptor, + const float scale, const bool use_dropout, + const std::optional dropout_rate, const bool is_causal_mask) { + using cudnn_frontend::graph::Tensor_attributes; - VLOG(2) << "\n Out Dims: " << absl::StrJoin(output_dims, ",") - << "\n Out Strides: " << absl::StrJoin(output_strides, ","); + if (VLOG_IS_ON(4)) { + VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString() + << "\n bmm1_rhs(k): " << k_descriptor.ToString() + << "\n bmm2_rhs(v): " << v_descriptor.ToString() + << "\n out(o): " << o_descriptor.ToString(); + if (bias_descriptor) { + VLOG(4) << "\n bias(b): " << bias_descriptor->ToString(); + } + if (mask_descriptor) { + VLOG(4) << "\n mask(m): " << mask_descriptor->ToString(); + } + if (stats_descriptor) { + VLOG(4) << "\n activation(s): " << stats_descriptor->ToString(); + } + } + + cudnn_frontend::graph::Graph graph; + dnn::DataType q_type = q_descriptor.type(); + dnn::DataType k_type = k_descriptor.type(); + dnn::DataType v_type = v_descriptor.type(); + dnn::DataType o_type = o_descriptor.type(); + if (!(q_type == k_type && k_type == v_type && v_type == o_type)) { + return absl::InternalError("Input datatypes do not match"); + } + cudnn_frontend::DataType_t ioDataType = ToCudnnFrontendDataType(q_type); + + graph.set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_io_data_type(ioDataType) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + std::shared_ptr q_tensor = + graph.tensor(Tensor_attributes() + .set_name("Q") + .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) + .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) + .set_uid(CudnnfMHAUid::Q_ID)); + std::shared_ptr k_tensor = + graph.tensor(Tensor_attributes() + .set_name("K") + .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) + .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) + .set_uid(CudnnfMHAUid::K_ID)); + std::shared_ptr v_tensor = graph.tensor( + Tensor_attributes() + .set_name("V") + .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) + .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) + .set_uid(CudnnfMHAUid::V_ID)); + + // Setting sdpa, and is_inference + cudnn_frontend::graph::SDPA_attributes sdpa_options; + sdpa_options.set_name("flash_attention") + .set_is_inference(stats_descriptor == std::nullopt) + .set_causal_mask(is_causal_mask) + .set_attn_scale(scale); + + // Setting bias + std::shared_ptr bias = nullptr; + if (bias_descriptor.has_value()) { + auto bias_tensor = + graph.tensor(Tensor_attributes() + .set_name("bias") + .set_dim(bias_descriptor->dimensions()) + .set_stride(bias_descriptor->GetLogicalStrides()) + .set_uid(CudnnfMHAUid::BIAS_ID)); + sdpa_options.set_bias(bias_tensor); + } + // Setting seed and bias + if (use_dropout && dropout_rate.has_value() && *dropout_rate > 0.0) { + auto seed_tensor = + graph.tensor(Tensor_attributes() + .set_name("seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::INT64) + .set_is_pass_by_value(true) + .set_uid(CudnnfMHAUid::D_SEED_ID)); + auto offset_tensor = + graph.tensor(Tensor_attributes() + .set_name("offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::INT64) + .set_is_pass_by_value(true) + .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor, + offset_tensor); + } + + // Add SDPA to the graph. + auto [o_tensor, stats_tensor] = + graph.sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + + // Set output attributes. + o_tensor->set_name("O") + .set_output(true) + .set_dim(o_descriptor.dimensions()) + .set_stride(o_descriptor.GetLogicalStrides()) + .set_uid(CudnnfMHAUid::O_ID); + if (stats_descriptor.has_value()) { + cudnn_frontend::DataType_t statsType = + ToCudnnFrontendDataType(stats_descriptor->type()); + stats_tensor->set_name("stats") + .set_output(true) + .set_data_type(statsType) + .set_uid(CudnnfMHAUid::P_ID); + } + + CudnnGraph cudnnGraph(std::move(graph)); + TF_ASSIGN_OR_RETURN(bool supported, cudnnGraph.Prepare(dnn_support)); + if (!supported) { + return absl::InternalError("cuDNN graph is not supported."); + } + TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/0)); - TF_ASSIGN_OR_RETURN( - auto tensor_o, - CreateCudnnTensor(output_dims, output_strides, CudnnfMHAUid::O_ID, - output_descriptor.type(), 1, -1)); - auto bmm2_desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(bmm2_desc); - auto bmm2_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(bmm2_input_tensor) - .setbMatDesc(tensor_v) - .setcMatDesc(tensor_o) - .setmatmulDesc(bmm2_desc) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(bmm2_op); - // Create an Operation Graph. In this case it is gemm-gemm - intermediate_ops.push_back(std::move(bmm2_op)); - ops.reserve(intermediate_ops.size()); - for (auto& intermediate_op : intermediate_ops) { - ops.emplace_back(&intermediate_op); + if (VLOG_IS_ON(4)) { + VLOG(4) << "\b flash attention operation graph: " << graph; } - - auto op_graph = cudnn_frontend::OperationGraphBuilder() - .setHandle(cudnn.handle()) - .setOperationGraph(ops.size(), ops.data()) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(op_graph); - VLOG(4) << "\nTensor_q: " << tensor_q.describe() - << "\nTensor_k: " << tensor_k.describe() - << "\nTensor_s: " << tensor_s.describe() - << "\nTensor_v: " << tensor_v.describe() - << "\nTensor_o: " << tensor_o.describe() - << "\nBMM1: " << bmm1_desc.describe() - << "\nBMM2: " << bmm2_desc.describe() - << "\nOpGraph: " << op_graph.describe(); - return std::make_unique(std::move(op_graph)); + return cudnnGraph; } -tsl::StatusOr CreateCudnnFlashAttentionDropoutBwdTensor( +absl::StatusOr +CreateCudnnFlashAttentionDropoutBwdTensor( std::vector& ops, absl::Span dims, absl::Span strides, dnn::DataType dtype, cudnn_frontend::Tensor& input_tensor, cudnn_frontend::Tensor& mask_tensor, @@ -6776,7 +6606,7 @@ tsl::StatusOr CreateCudnnFlashAttentionDropoutBwdTensor( return dropout_scale_out_tensor; } -tsl::StatusOr> +absl::StatusOr> GetCudnnFlashAttentionBackwardOperationGraph( const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, @@ -7053,7 +6883,15 @@ GetCudnnFlashAttentionBackwardOperationGraph( tensor_p_after_alpha_scale)); tensor_p_after_alpha_scale = std::move(tensor_p_after_bias); } - if (use_causal_mask) { + + if (use_mask) { + // masking -> p_after_mask + TF_ASSIGN_OR_RETURN( + auto tensor_p_after_mask, + CreateCudnnMaskFwdTensor(intermediate_ops, p_dims, p_strides, dtype, + tensor_p_after_alpha_scale)); + tensor_p_after_alpha_scale = std::move(tensor_p_after_mask); + } else if (use_causal_mask) { // Causal masking -> p_after_mask TF_ASSIGN_OR_RETURN(auto tensor_p_after_causal_mask, CreateCudnnFlashAttentionCausalMaskTensor( @@ -7068,9 +6906,21 @@ GetCudnnFlashAttentionBackwardOperationGraph( CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 104, dnn::DataType::kFloat, 1, -1, /*is_virtual*/ true)); + + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. + std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + TF_ASSIGN_OR_RETURN( auto tensor_softmax_stats, - CreateCudnnTensor(do_reduction_dims, do_reduction_strides, + CreateCudnnTensor(p_reduction_dims, p_reduction_strides, CudnnfMHAUid::P_ID, dnn::DataType::kFloat, 1, -1)); TF_ASSIGN_OR_RETURN(auto sub_desc, @@ -7106,7 +6956,7 @@ GetCudnnFlashAttentionBackwardOperationGraph( auto tensor_p_after_scale_dropout, CreateCudnnFlashAttentionDropoutBwdTensor( intermediate_ops, p_dims, p_strides, dtype, tensor_p_after_softmax, - tensor_dropout_mask, *dropout_rate)); + tensor_dropout_mask, use_dropout ? *dropout_rate : 0)); // after_scale_dropout -> s_transpose auto p_transpose_dims = p_dims; @@ -7347,11 +7197,11 @@ GetCudnnFlashAttentionBackwardOperationGraph( return std::make_unique(std::move(op_graph)); } -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } // namespace -static tsl::StatusOr GetExecPlanFromHeuristics( +static absl::StatusOr GetExecPlanFromHeuristics( cudnn_frontend::OperationGraph&& opGraph, const CudnnHandle& cudnn, bool include_fallback_heuristics = false) { #if (CUDNN_VERSION >= 8800) @@ -7398,7 +7248,7 @@ static tsl::StatusOr GetExecPlanFromHeuristics( #endif } -static tsl::StatusOr RebuildExecutionPlan( +static absl::StatusOr RebuildExecutionPlan( const CudnnHandle& cudnn, const dnn::AlgorithmDesc& desc, const cudnn_frontend::OperationGraph& op_graph) { if (!desc.is_cudnn_frontend()) { @@ -7450,11 +7300,11 @@ static tsl::StatusOr RebuildExecutionPlan( return {std::move(plan)}; } -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } // namespace -tsl::Status CudnnSupport::DoPrepareForConvolution( +absl::Status CudnnSupport::DoPrepareForConvolution( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -7506,13 +7356,13 @@ tsl::Status CudnnSupport::DoPrepareForConvolution( static_cast(kind)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } class CudnnLegacyConvRunner : public dnn::ConvRunner { public: // Queries the workspace size and constructs a 'CudnnLegacyConvRunner'. - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn, const dnn::AlgorithmDesc& algo, dnn::DataType input_type, dnn::DataType output_type, dnn::ConvolutionKind kind, @@ -7574,15 +7424,15 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { size_t GetWorkspaceSize() const override { return workspace_size_; } - tsl::StatusOr ToAlgorithmDesc() const override { + absl::StatusOr ToAlgorithmDesc() const override { return MakeAlgorithmDesc(); } - tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase input_data, - DeviceMemoryBase filter_data, - DeviceMemoryBase output_data) const override { + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase input_data, + DeviceMemoryBase filter_data, + DeviceMemoryBase output_data) const override { auto algo = MakeAlgorithmDesc(); if (static_cast(parent_) != @@ -7608,27 +7458,28 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { const bool is_profiling = profile_result != nullptr; TF_ASSIGN_OR_RETURN( std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); + GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + is_profiling)); - const auto get_fwd_bugs = [&]() -> tsl::Status { + const auto get_fwd_bugs = [&]() -> absl::Status { #if CUDNN_VERSION < 8000 if (algo_id_ == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM && ToCudnnDataType(input_type_) == CUDNN_DATA_INT8 && ToCudnnDataType(output_type_) == CUDNN_DATA_FLOAT) { - return tsl::Status( - absl::StatusCode::kFailedPrecondition, + return absl::FailedPreconditionError( "This configuration potentially produces incorrect results."); } #else (void)output_type_; // To stop clang-tidy saying it's unused. #endif - return ::tsl::OkStatus(); + return absl::OkStatus(); }; - auto get_bwd_data_bugs = [&]() -> tsl::Status { return ::tsl::OkStatus(); }; + auto get_bwd_data_bugs = [&]() -> absl::Status { return absl::OkStatus(); }; - const auto get_bwd_filter_bugs = [&]() -> tsl::Status { - return ::tsl::OkStatus(); + const auto get_bwd_filter_bugs = [&]() -> absl::Status { + return absl::OkStatus(); }; switch (kind_) { @@ -7691,7 +7542,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { scratch_memory.size())); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } private: @@ -7737,7 +7588,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { CudnnConvolutionDescriptor conv_; }; -tsl::Status CudnnSupport::DoConvolve( +absl::Status CudnnSupport::DoConvolve( dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -7834,7 +7685,7 @@ class ScalingParam { dnn::DataType default_target_dtype_; }; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 struct BackendDescriptorDeleter { void operator()(cudnnBackendDescriptor_t desc) { cudnnBackendDestroyDescriptor(desc); @@ -7843,7 +7694,7 @@ struct BackendDescriptorDeleter { using BackendDescriptor = std::unique_ptr; -tsl::StatusOr CreateBackendDesc( +absl::StatusOr CreateBackendDesc( cudnnBackendDescriptorType_t type) { void* result; RETURN_IF_CUDNN_ERROR(cudnnBackendCreateDescriptor(type, &result)); @@ -7856,7 +7707,7 @@ tsl::StatusOr CreateBackendDesc( // opposed to a sequence of multiple attributes. The distinction is a bit // meaningless, but this is the presentation the cuDNN docs use, so it may as // well be consistent. -tsl::StatusOr> GetDescriptorAttribute( +absl::StatusOr> GetDescriptorAttribute( cudnnBackendDescriptor_t desc, cudnnBackendAttributeName_t name, cudnnBackendDescriptorType_t type) { int64_t n; @@ -7885,7 +7736,7 @@ tsl::StatusOr> GetDescriptorAttribute( // Extract the engine ID and tuning knobs from the ExecutionPlan, and return // them in the form of an AlgorithmDesc for use with RebuildExecutionPlan. -tsl::StatusOr ExecutionPlanToAlgorithmDesc( +absl::StatusOr ExecutionPlanToAlgorithmDesc( const cudnn_frontend::ExecutionPlan& plan, size_t workspace_size) { TF_ASSIGN_OR_RETURN( auto engine_cfgs, @@ -7986,13 +7837,13 @@ class CudnnExecutionPlanRunner size_t GetWorkspaceSize() const override { return workspace_size_; } - tsl::StatusOr ToAlgorithmDesc() const override { + absl::StatusOr ToAlgorithmDesc() const override { return ExecutionPlanToAlgorithmDesc(plan_, workspace_size_); } - tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... inputs) const override { + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + Args... inputs) const override { if (static_cast(parent_) != stream->parent()->implementation()) { return tsl::errors::Internal( @@ -8036,7 +7887,7 @@ class CudnnExecutionPlanRunner data_ptrs_vec.pop_back(); } - if (sizeof...(Args) == 7 || sizeof...(Args) == 15) { + if (sizeof...(Args) == 9 || sizeof...(Args) == 17) { // is attention fwd or bwd data_ptrs_vec.erase( std::remove(data_ptrs_vec.begin(), data_ptrs_vec.end(), nullptr), @@ -8054,25 +7905,18 @@ class CudnnExecutionPlanRunner } } if (offset_increment_ > 0) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 initial_offset_ += offset_increment_; data_uids_vec.push_back(CudnnfMHAUid::D_SEED_ID); data_uids_vec.push_back(CudnnfMHAUid::D_OFFSET_ID); - if (is_flash_attention_ && CUDNN_VERSION < 8903) { - // flash attention for cuDNN < 8.9.3 only supports dev pointer for seed - // and offset - data_ptrs_vec.push_back(scratch_memory.opaque()); - data_ptrs_vec.push_back(static_cast( - static_cast(scratch_memory.opaque()) + 1)); - } else { - data_ptrs_vec.push_back((void*)(&rng_seed_)); - data_ptrs_vec.push_back((void*)(&initial_offset_)); - } + data_ptrs_vec.push_back( + static_cast(const_cast(&rng_seed_))); + data_ptrs_vec.push_back(static_cast(&initial_offset_)); #else return absl::UnimplementedError( "Cudnn dropout offset and seed are only supported with Cudnn >= " "8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } auto variantPack = cudnn_frontend::VariantPackBuilder() @@ -8088,17 +7932,9 @@ class CudnnExecutionPlanRunner const bool is_profiling = profile_result != nullptr; TF_ASSIGN_OR_RETURN( std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); - - if (sizeof...(Args) == 15) { - // is training - if (is_flash_attention_) { - // should memset dq_accum because it is being atomic added - std::vector dev_mem{inputs...}; - DeviceMemoryBase* dev_dq_accum = &(dev_mem[10]); - stream->ThenMemZero(dev_dq_accum, dev_dq_accum->size()); - } - } + GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + is_profiling)); cudnnStatus_t status = cudnnBackendExecute( cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc()); @@ -8118,7 +7954,7 @@ class CudnnExecutionPlanRunner return tsl::OkStatus(); } - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, CudnnAccess* cudnn, cudnn_frontend::ExecutionPlan plan, absl::Span uids, bool need_side_input) { @@ -8134,23 +7970,21 @@ class CudnnExecutionPlanRunner {}, {}, 0, - 0, - false}}; + 0}}; } - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, CudnnAccess* cudnn, cudnn_frontend::ExecutionPlan plan, absl::Span uids, bool need_side_input, bool has_activation_output, std::vector scalar_input_uids, std::vector scalar_input_values, int64_t dropout_rng_seed, - int64_t dropout_rng_offset, bool is_flash_attention) { + int64_t dropout_rng_offset) { auto workspace_size = static_cast(plan.getWorkspaceSize()); RETURN_MSG_IF_CUDNN_ERROR(plan); return {{parent, cudnn, std::move(plan), workspace_size, uids, need_side_input, has_activation_output, scalar_input_uids, - scalar_input_values, dropout_rng_seed, dropout_rng_offset, - is_flash_attention}}; + scalar_input_values, dropout_rng_seed, dropout_rng_offset}}; } private: @@ -8161,8 +7995,7 @@ class CudnnExecutionPlanRunner bool has_activation_output, std::vector scalar_input_uids, std::vector scalar_input_values, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - bool is_flash_attention) + int64_t dropout_rng_seed, int64_t dropout_rng_offset) : parent_(parent), cudnn_(cudnn), plan_(std::move(plan)), @@ -8173,8 +8006,7 @@ class CudnnExecutionPlanRunner scalar_input_uids_(scalar_input_uids), scalar_input_values_(scalar_input_values), offset_increment_(dropout_rng_offset), - rng_seed_(dropout_rng_seed), - is_flash_attention_(is_flash_attention) {} + rng_seed_(dropout_rng_seed) {} GpuExecutor* parent_; CudnnAccess* cudnn_; cudnn_frontend::ExecutionPlan plan_; @@ -8188,15 +8020,102 @@ class CudnnExecutionPlanRunner mutable int64_t initial_offset_ = 0; int64_t offset_increment_ = 0; int64_t rng_seed_; - bool is_flash_attention_; }; -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 + +template +class CudnnGraphRunner; +// An OpRunner implemented by a cuDNN frontend graph. +// +// This is the class holding the implementation of ToString, GetWorkspaceSize, +// and operator() for use by the cudnn frontend op runners. +template +class CudnnGraphRunner : public dnn::OpRunner { + private: + using Graph = cudnn_frontend::graph::Graph; + using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes; + + public: + std::string ToString() const override { return graph_.Graph().print(); } + + size_t GetWorkspaceSize() const override { + return graph_.Graph().get_workspace_size(); + } + + absl::StatusOr ToAlgorithmDesc() const override { + return absl::InternalError( + "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc"); + } + + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + Args... inputs) const override { + if (static_cast(parent_) != + stream->parent()->implementation()) { + return tsl::errors::Internal( + "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); + } + CudnnHandle handle = cudnn_->GetHandle(parent_, stream); + std::unordered_map variant_pack; + std::vector vec = {inputs.opaque()...}; + for (int i = 0; i < uids_.size(); ++i) { + if (uids_[i].has_value()) { + variant_pack[*uids_[i]] = vec[i]; + } + } + + if (dropout_rng_offset_increment_ > 0) { +#if CUDNN_VERSION >= 8800 + variant_pack[D_SEED_ID] = (void*)&dropout_rng_seed_; + current_dropout_rng_offset_ += dropout_rng_offset_increment_; + variant_pack[D_OFFSET_ID] = (void*)¤t_dropout_rng_offset_; +#else + return absl::UnimplementedError( + "Cudnn dropout offset and seed are only supported with Cudnn >= " + "8.8.0"); +#endif // CUDNN_VERSION >= 8800 + } + + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute( + handle.handle(), variant_pack, scratch_memory.opaque())); + + return tsl::OkStatus(); + } + + static absl::StatusOr Create( + GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, + int64_t dropout_rng_seed, int64_t dropout_rng_offset, + std::vector> uids) { + return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed, + dropout_rng_offset, uids); + } + + private: + CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, + int64_t dropout_rng_seed, int64_t dropout_rng_offset, + std::vector> uids) + : parent_(parent), + cudnn_(cudnn), + graph_(std::move(graph)), + dropout_rng_seed_(dropout_rng_seed), + current_dropout_rng_offset_(0), + dropout_rng_offset_increment_(dropout_rng_offset), + uids_(uids) {} + GpuExecutor* parent_; + CudnnAccess* cudnn_; + Stream* stream_; + CudnnGraph graph_; + int64_t dropout_rng_seed_; + mutable int64_t current_dropout_rng_offset_; + int64_t dropout_rng_offset_increment_; + std::vector> uids_; +}; -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 namespace { template -tsl::Status CreateOpRunners( +absl::Status CreateOpRunners( Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor, CudnnAccess* cudnn_access, std::unique_ptr op_graph, @@ -8301,13 +8220,13 @@ tsl::Status CreateOpRunners( VLOG(4) << "\nReturned execution plans size: " << out_runners->size(); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 -tsl::Status CudnnSupport::GetConvolveRunners( +absl::Status CudnnSupport::GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, @@ -8372,8 +8291,7 @@ tsl::Status CudnnSupport::GetConvolveRunners( break; } if (!got_algos) { - return tsl::Status( - absl::StatusCode::kUnknown, + return absl::UnknownError( absl::StrFormat("Listing algorithms failed for kind %d", kind)); } @@ -8396,10 +8314,10 @@ tsl::Status CudnnSupport::GetConvolveRunners( out_exec_plans->push_back(std::move(runner_or).value()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( auto op_graph, @@ -8414,10 +8332,10 @@ tsl::Status CudnnSupport::GetConvolveRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } -tsl::Status CudnnSupport::GetGraphConvolveRunners( +absl::Status CudnnSupport::GetGraphConvolveRunners( dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, @@ -8439,7 +8357,7 @@ tsl::Status CudnnSupport::GetGraphConvolveRunners( /*need_side_input=*/false, numeric_options); } -tsl::StatusOr> +absl::StatusOr> CudnnSupport::ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -8481,7 +8399,7 @@ CudnnSupport::ConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -8506,7 +8424,7 @@ CudnnSupport::ConvolveRunnerFromDesc( #endif } -tsl::StatusOr> +absl::StatusOr> CudnnSupport::GraphConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -8520,7 +8438,7 @@ CudnnSupport::GraphConvolveRunnerFromDesc( "cuDNN graph execution requires the use of the cuDNN frontend."); } -#if CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8900 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( @@ -8549,7 +8467,7 @@ CudnnSupport::GraphConvolveRunnerFromDesc( class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { public: // Queries the workspace size and constructs a 'CudnnLegacyFusedConvRunner'. - static tsl::StatusOr Create( + static absl::StatusOr Create( GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn, const dnn::AlgorithmDesc& algo, dnn::DataType input_type, double conv_scale, double side_input_scale, @@ -8584,17 +8502,17 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { uint64_t GetWorkspaceSize() const override { return workspace_size_; } - tsl::StatusOr ToAlgorithmDesc() const override { + absl::StatusOr ToAlgorithmDesc() const override { return MakeAlgorithmDesc(); } - tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase input_data, - DeviceMemoryBase filter_data, - DeviceMemoryBase side_input_data, - DeviceMemoryBase bias_data, - DeviceMemoryBase output_data) const override { + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase input_data, + DeviceMemoryBase filter_data, + DeviceMemoryBase side_input_data, + DeviceMemoryBase bias_data, + DeviceMemoryBase output_data) const override { if (static_cast(parent_) != stream->parent()->implementation()) { return tsl::errors::Internal( @@ -8604,9 +8522,11 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { auto algo = MakeAlgorithmDesc(); - TF_ASSIGN_OR_RETURN(std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), - profile_result != nullptr)); + TF_ASSIGN_OR_RETURN( + std::optional timer, + GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + profile_result != nullptr)); auto side_input_data_ptr = (side_input_scale_ == 0) ? output_data.opaque() : side_input_data.opaque(); @@ -8632,9 +8552,9 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { << "\noutput_data.opaque() = " << output_data.opaque(); if (IsTensorMathOpSet(conv_) != tensor_ops_enabled_) { - return tsl::Status(absl::StatusCode::kFailedPrecondition, - "Tensor op math type in dnn::AlgorithmDesc does not " - "match that of the CudnnConvolutionDescriptor"); + return absl::FailedPreconditionError( + "Tensor op math type in dnn::AlgorithmDesc does not " + "match that of the CudnnConvolutionDescriptor"); } // N.B. the scaling parameters alpha1 and alpha2 are pointers to @@ -8670,7 +8590,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { << profile_result->elapsed_time_in_ms() << "ms"; } - return ::tsl::OkStatus(); + return absl::OkStatus(); } private: @@ -8721,7 +8641,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { CudnnActivationDescriptor activation_desc_; }; -tsl::StatusOr> +absl::StatusOr> CudnnSupport::FusedConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -8776,7 +8696,7 @@ CudnnSupport::FusedConvolveRunnerFromDesc( return {std::make_unique(std::move(runner))}; } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN(auto op_graph, @@ -8803,7 +8723,7 @@ CudnnSupport::FusedConvolveRunnerFromDesc( #endif } -tsl::Status CudnnSupport::GetFusedConvolveRunners( +absl::Status CudnnSupport::GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, double side_input_scale, @@ -8881,9 +8801,9 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( activation_mode != dnn::ActivationMode::kElu && activation_mode != dnn::ActivationMode::kLeakyRelu && activation_mode != dnn::ActivationMode::kNone) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "CuDNN fusion only supports activations of " - "{Relu, Relu6, Elu, }."); + return absl::InvalidArgumentError( + "CuDNN fusion only supports activations of " + "{Relu, Relu6, Elu, }."); } if (!actually_use_cudnn_frontend) { @@ -8892,8 +8812,7 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( auto cuda_compute_capability = stream->GetCudaComputeCapability(); if (!GetConvolveAlgorithms(cuda_compute_capability, input_type, numeric_options, &algorithms)) { - return tsl::Status(absl::StatusCode::kUnknown, - "Listing fused convolve algorithms failed."); + return absl::UnknownError("Listing fused convolve algorithms failed."); } for (const auto& algo : algorithms) { @@ -8917,19 +8836,18 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( } out_exec_plans->push_back(std::move(runner_or).value()); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8100 auto cudnn = cudnn_->GetHandle(parent_, stream); auto op_graph_status = GetCudnnFusedOperationGraph( kind, input_type, bias_type, output_type, conv_scale, side_input_scale, leakyrelu_alpha, input_descriptor, filter_descriptor, bias_descriptor, output_descriptor, convolution_descriptor, activation_mode, cudnn); if (!op_graph_status.status().ok()) { - return tsl::Status(absl::StatusCode::kInternal, - absl::StrCat("Cudnn graph failed to build: ", - op_graph_status.status().ToString())); + return absl::InternalError(absl::StrCat( + "Cudnn graph failed to build: ", op_graph_status.status().ToString())); } auto op_graph = std::move(op_graph_status).value(); @@ -8942,10 +8860,10 @@ tsl::Status CudnnSupport::GetFusedConvolveRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans are only supported with Cudnn >= 8.1."); -#endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8100 } -tsl::Status CudnnSupport::GetFusedMatmulRunners( +absl::Status CudnnSupport::GetFusedMatmulRunners( bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, @@ -8953,7 +8871,7 @@ tsl::Status CudnnSupport::GetFusedMatmulRunners( const NumericOptions& numeric_options, std::vector>* out_exec_plans) { -#if CUDNN_VERSION >= 8400 && TF_ENABLE_CUDNN_FRONTEND +#if CUDNN_VERSION >= 8400 if (!use_cudnn_frontend) { return tsl::errors::Unimplemented( "Cudnn execution plans for matmul are only supported with cudnn " @@ -8965,9 +8883,8 @@ tsl::Status CudnnSupport::GetFusedMatmulRunners( input_type, bias_type, output_type, trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode, cudnn); if (!op_graph_status.status().ok()) { - return tsl::Status(absl::StatusCode::kInternal, - absl::StrCat("Cudnn graph failed to build: ", - op_graph_status.status().ToString())); + return absl::InternalError(absl::StrCat( + "Cudnn graph failed to build: ", op_graph_status.status().ToString())); } auto op_graph = std::move(op_graph_status).value(); @@ -8982,7 +8899,7 @@ tsl::Status CudnnSupport::GetFusedMatmulRunners( #else return tsl::errors::Unimplemented( "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); -#endif // CUDNN_VERSION >= 8400 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8400 } bool CudnnSupport::GetConvolveAlgorithms( @@ -9025,16 +8942,20 @@ bool CudnnSupport::GetConvolveAlgorithms( return true; } -tsl::StatusOr> +absl::StatusOr> CudnnSupport::NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { -#if (CUDNN_VERSION >= 8905 && TF_ENABLE_CUDNN_FRONTEND) + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) { +#if (CUDNN_VERSION >= 8905) auto cudnn = cudnn_->GetHandle(parent_, stream); std::vector uids; @@ -9045,45 +8966,48 @@ CudnnSupport::NormRunnerFromDesc( return uids.emplace_back(uids.back() + 1); }; - TF_ASSIGN_OR_RETURN( - auto xTensor, - CreateCudnnTensor(input_descriptor.dimensions(), - input_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), input_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto scaleTensor, - CreateCudnnTensor(scale_descriptor.dimensions(), - scale_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), scale_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto biasTensor, - CreateCudnnTensor(bias_descriptor.dimensions(), - bias_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), bias_descriptor.type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - auto yTensor, - CreateCudnnTensor(output_descriptor.dimensions(), - output_descriptor.GetPhysicalStridesMajorToMinor(), - next_uid(), output_descriptor.type(), 1, -1)); - std::optional expectation_tensor, norm_factor_tensor; - if (expectation_descriptor) { - TF_ASSIGN_OR_RETURN( - expectation_tensor, - CreateCudnnTensor( - expectation_descriptor->dimensions(), - expectation_descriptor->GetPhysicalStridesMajorToMinor(), - next_uid(), expectation_descriptor->type(), 1, -1)); - TF_ASSIGN_OR_RETURN( - norm_factor_tensor, - CreateCudnnTensor( - norm_factor_descriptor->dimensions(), - norm_factor_descriptor->GetPhysicalStridesMajorToMinor(), - next_uid(), norm_factor_descriptor->type(), 1, -1)); + auto create_cudnn_tensor = [next_uid](dnn::TensorDescriptor tensor_descriptor) + -> tsl::StatusOr { + return CreateCudnnTensor(tensor_descriptor.dimensions(), + tensor_descriptor.GetPhysicalStridesMajorToMinor(), + next_uid(), tensor_descriptor.type(), 1, -1); + }; + + TF_ASSIGN_OR_RETURN(auto x_tensor, create_cudnn_tensor(x_descriptor)); + TF_ASSIGN_OR_RETURN(auto scale_tensor, create_cudnn_tensor(scale_descriptor)); + TF_ASSIGN_OR_RETURN(auto y_or_dx_tensor, + create_cudnn_tensor(y_or_dx_descriptor)); + + std::optional bias_tensor, expectation_tensor, + norm_factor_tensor, dy_tensor, dscale_tensor, dbias_tensor; + if (kind == dnn::NormKind::LAYER_FWD_INFER || + kind == dnn::NormKind::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(bias_tensor, + create_cudnn_tensor(bias_descriptor.value())); + } + + if (kind == dnn::LAYER_FWD_TRAIN) { + TF_ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + TF_ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); + } + + if (kind == dnn::LAYER_BWD) { + TF_ASSIGN_OR_RETURN(dy_tensor, create_cudnn_tensor(dy_descriptor.value())); + TF_ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + TF_ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); + TF_ASSIGN_OR_RETURN(dscale_tensor, + create_cudnn_tensor(dscale_descriptor.value())); + TF_ASSIGN_OR_RETURN(dbias_tensor, + create_cudnn_tensor(dbias_descriptor.value())); } std::vector scale_dim(4, 1), scalar_uids; TF_ASSIGN_OR_RETURN( - auto epsilonTensor, + auto epsilon_tensor, CreateCudnnTensor(scale_dim, scale_dim, scalar_uids.emplace_back(uids.back() + 1), dnn::DataType::kDouble, 1, -1, /*is_virtual=*/false, @@ -9093,30 +9017,47 @@ CudnnSupport::NormRunnerFromDesc( cudnnBackendNormMode_t normalizationMode = CUDNN_LAYER_NORM; std::optional norm_op; - if (!expectation_descriptor) { - cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_INFERENCE; - norm_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setEpsilonTensor(epsilonTensor) - .setyDesc(yTensor) - .build(); - } else { - cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; - norm_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) - .setNormalizationMode(normalizationMode) - .setNormFwdPhase(phase) - .setxDesc(xTensor) - .setScaleAndBias(scaleTensor, biasTensor) - .setEpsilonTensor(epsilonTensor) - .setSavedMeanAndInvVar(expectation_tensor.value(), - norm_factor_tensor.value()) - .setyDesc(yTensor) - .build(); + switch (kind) { + case dnn::LAYER_FWD_INFER: + norm_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(CUDNN_NORM_FWD_INFERENCE) + .setxDesc(x_tensor) + .setScaleAndBias(scale_tensor, bias_tensor.value()) + .setEpsilonTensor(epsilon_tensor) + .setyDesc(y_or_dx_tensor) + .build(); + break; + case dnn::LAYER_FWD_TRAIN: + norm_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setNormFwdPhase(CUDNN_NORM_FWD_TRAINING) + .setxDesc(x_tensor) + .setScaleAndBias(scale_tensor, bias_tensor.value()) + .setEpsilonTensor(epsilon_tensor) + .setSavedMeanAndInvVar(expectation_tensor.value(), + norm_factor_tensor.value()) + .setyDesc(y_or_dx_tensor) + .build(); + break; + case dnn::LAYER_BWD: + norm_op = + cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(normalizationMode) + .setxDesc(x_tensor) + .setScale(scale_tensor) + .setSavedMeanAndInvVar(expectation_tensor.value(), + norm_factor_tensor.value()) + .setDScaleAndDBias(dscale_tensor.value(), dbias_tensor.value()) + .setdyDesc(dy_tensor.value()) + .setdxDesc(y_or_dx_tensor) + .build(); + break; + default: + break; } std::array ops = {&norm_op.value()}; @@ -9138,14 +9079,14 @@ CudnnSupport::NormRunnerFromDesc( parent_, cudnn_.get(), std::move(execution_plan), uids, /*need_side_input=*/false, /*has_activation_output=*/false, scalar_uids, scalar_input_values, /*dropout_rng_seed=*/0, - /*dropout_rng_offset=*/0, /*is_flash_attention=*/false)); + /*dropout_rng_offset=*/0)); return {std::make_unique>( std::move(runner))}; #else return absl::UnimplementedError( "Layer norm kernels require cuDNN 8.9.5 or higher."); -#endif // CUDNN_VERSION >= 8905 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8905 } // Returns the offset to increment for the dropout rng. @@ -9161,7 +9102,7 @@ int64_t GetDropoutRngOffset(std::vector& intermediate_shape) { return max_seq_len * max_seq_len / cudnn_mha_num_threads; } -tsl::StatusOr> +absl::StatusOr> CudnnSupport::FusedMHARunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -9175,29 +9116,56 @@ CudnnSupport::FusedMHARunnerFromDesc( std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); bool use_dropout = dropout_rate && *dropout_rate > 0.0; std::vector intermediate_shape; + + if (is_flash_attention) { + TF_ASSIGN_OR_RETURN( + auto graph, + GetCudnnFlashAttentionOperationGraph( + *this, /*q_descriptor=*/bmm1_lhs_descriptor, + /*k_descriptor=*/bmm1_rhs_descriptor, + /*v_descriptor=*/bmm2_rhs_descriptor, + /*o_descriptor=*/output_descriptor, bias_descriptor, + mask_descriptor, /*stats_descriptor=*/activation_descriptor, + /*scale=*/static_cast(scale), use_dropout, dropout_rate, + is_causal_mask)); + + std::vector intermediate_bmm2_lhs_dims = + intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); + intermediate_shape = intermediate_bmm2_lhs_dims; + int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); + int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; + std::vector> uids = { + CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID, + CudnnfMHAUid::O_ID, + /*mask=*/std::nullopt}; + uids.emplace_back(bias_descriptor.has_value() + ? std::optional(CudnnfMHAUid::BIAS_ID) + : std::nullopt); + uids.emplace_back(activation_descriptor.has_value() + ? std::optional(CudnnfMHAUid::P_ID) + : std::nullopt); + TF_ASSIGN_OR_RETURN(auto runner, + CudnnGraphRunner::Create( + parent_, cudnn_.get(), std::move(graph), + dropout_rng_seed, dropout_rng_offset, uids)); + + return {std::make_unique>( + std::move(runner))}; + } + TF_ASSIGN_OR_RETURN( auto op_graph, - is_flash_attention - ? GetCudnnFlashAttentionOperationGraph( - bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor, - intermediate_bmm2_lhs_descriptor, output_descriptor, - mask_descriptor, bias_descriptor, activation_descriptor, kind, - dropout_rate, seed, cudnn, scale, intermediate_shape, - use_dropout, - /*use_mask*/ mask_descriptor != std::nullopt, - /*use_bias*/ bias_descriptor != std::nullopt, is_causal_mask) - : GetCudnnFusedMHAOperationGraph( - bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor, - intermediate_bmm2_lhs_descriptor, output_descriptor, - mask_descriptor, bias_descriptor, activation_descriptor, kind, - dropout_rate, seed, cudnn, scale, intermediate_shape, - use_dropout, - /*use_mask*/ mask_descriptor != std::nullopt, - /*use_bias*/ bias_descriptor != std::nullopt)); + GetCudnnFusedMHAOperationGraph( + bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor, + intermediate_bmm2_lhs_descriptor, output_descriptor, mask_descriptor, + bias_descriptor, activation_descriptor, kind, dropout_rate, seed, + cudnn, scale, intermediate_shape, use_dropout, + /*use_mask*/ mask_descriptor != std::nullopt, + /*use_bias*/ bias_descriptor != std::nullopt)); TF_ASSIGN_OR_RETURN(auto execution_plan, GetExecPlanFromHeuristics(std::move(*op_graph), cudnn)); @@ -9221,41 +9189,16 @@ CudnnSupport::FusedMHARunnerFromDesc( int64_t dropout_rng_seed = seed == std::nullopt ? 0 : *seed; int64_t dropout_rng_offset = 0; - if (is_flash_attention) { - ScalingParam alpha_scale(scale, dnn::DataType::kFloat); - scalar_input_values = {alpha_scale}; - scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID}; + ScalingParam alpha_scale(scale, bmm1_lhs_descriptor.type()); + scalar_input_values = {alpha_scale}; + scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID}; + if (use_dropout) { scalar_input_uids.push_back(CudnnfMHAUid::DROPOUT_SCALE_ID); - // before 8.9.3 it should be half/bf16, after 8.9.3, it could be any type, - // use fp32 here - double dropout_scale_value = - use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f; - ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat); + double dropout_scale_value = 1.0f / (1.0f - *dropout_rate); + ScalingParam dropout_scale(dropout_scale_value, bmm1_lhs_descriptor.type()); scalar_input_values.push_back(dropout_scale); dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - - if (bias_descriptor == std::nullopt) { - // push negative infinity here - scalar_input_uids.push_back(CudnnfMHAUid::NEG_INFINITY_ID); - double negative_infinity_value = -std::numeric_limits::infinity(); - ScalingParam negative_infinity(negative_infinity_value, - dnn::DataType::kFloat); - scalar_input_values.push_back(negative_infinity); - } - } else { - ScalingParam alpha_scale(scale, bmm1_lhs_descriptor.type()); - scalar_input_values = {alpha_scale}; - scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID}; - if (use_dropout) { - scalar_input_uids.push_back(CudnnfMHAUid::DROPOUT_SCALE_ID); - double dropout_scale_value = 1.0f / (1.0f - *dropout_rate); - ScalingParam dropout_scale(dropout_scale_value, - bmm1_lhs_descriptor.type()); - scalar_input_values.push_back(dropout_scale); - dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - } } - TF_ASSIGN_OR_RETURN( auto runner, CudnnExecutionPlanRunner::Create( @@ -9263,16 +9206,16 @@ CudnnSupport::FusedMHARunnerFromDesc( /*need_side_input*/ true, /*has_activation_output*/ (activation_descriptor != std::nullopt), scalar_input_uids, scalar_input_values, dropout_rng_seed, - dropout_rng_offset, is_flash_attention)); + dropout_rng_offset)); return {std::make_unique>( std::move(runner))}; #else return absl::UnimplementedError( "Cudnn execution plans are only supported with Cudnn >= 8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } -tsl::StatusOr> +absl::StatusOr> CudnnSupport::FusedMHABackwardRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -9291,7 +9234,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask) { -#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) +#if CUDNN_VERSION >= 8800 auto cudnn = cudnn_->GetHandle(parent_, stream); bool use_dropout = dropout_rate && *dropout_rate > 0.0; @@ -9344,19 +9287,24 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f; ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat); // scale prob - double scale_prob_value = 1.0 - *dropout_rate; + double scale_prob_value = use_dropout ? 1.0 - *dropout_rate : 1.0f; ScalingParam scale_prob(scale_prob_value, dnn::DataType::kFloat); scalar_values = {alpha_scale, dropout_scale, scale_prob}; // push dropout seed and offset here dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - uids = { - CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID, - CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, - CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, CudnnfMHAUid::S_SUM_ID, - CudnnfMHAUid::d_Q_accum_ID, CudnnfMHAUid::O_ID}; + uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, + CudnnfMHAUid::P_ID, CudnnfMHAUid::V_ID, + CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, + CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, + CudnnfMHAUid::S_SUM_ID, CudnnfMHAUid::d_Q_accum_ID}; + if (mask_descriptor != std::nullopt) { + uids.push_back(CudnnfMHAUid::MASK_ID); + } + uids.push_back(CudnnfMHAUid::O_ID); if (bias_descriptor != std::nullopt) { uids.push_back(CudnnfMHAUid::BIAS_ID); - } else { + } + if (is_causal_mask) { // is causal mask // negative infinity double negative_infinity_value = -std::numeric_limits::infinity(); @@ -9396,8 +9344,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( parent_, cudnn_.get(), std::move(execution_plan), uids, /*need_side_input*/ true, /*has_activation_output*/ false, scalar_uids, scalar_values, dropout_rng_seed, - /*dropout_rng_offset*/ dropout_rng_offset, - /*is_flash_attention*/ is_flash_attention)); + /*dropout_rng_offset*/ dropout_rng_offset)); return {std::make_unique< CudnnExecutionPlanRunner>( std::move(runner))}; @@ -9405,7 +9352,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( return absl::UnimplementedError( "Cudnn execution plans with dbias calculation in bwd are only " "supported with Cudnn >= 8.8."); -#endif // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND +#endif // CUDNN_VERSION >= 8800 } bool CudnnSupport::GetRnnAlgorithms( @@ -9577,7 +9524,7 @@ bool CudnnSupport::DoBatchNormalizationForward( } template -tsl::Status CudnnSupport::DoBatchNormalizationForwardImpl( +absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -9642,15 +9589,15 @@ tsl::Status CudnnSupport::DoBatchNormalizationForwardImpl( } #endif - auto check_no_side_input_or_activation = [&]() -> tsl::Status { + auto check_no_side_input_or_activation = [&]() -> absl::Status { if (activation_mode != dnn::ActivationMode::kNone || !side_input.is_null()) { - return tsl::Status(absl::StatusCode::kInternal, - absl::StrCat("Side input and activation are not " - "supported by cuDNN version: ", - CUDNN_VERSION)); + return absl::InternalError( + absl::StrCat("Side input and activation are not " + "supported by cuDNN version: ", + CUDNN_VERSION)); } else { - return ::tsl::OkStatus(); + return absl::OkStatus(); } }; @@ -9662,8 +9609,8 @@ tsl::Status CudnnSupport::DoBatchNormalizationForwardImpl( void* batch_var_opaque; if (!batch_mean->is_null() && !batch_var->is_null()) { if (exponential_average_factor == 1.0) { - stream->ThenMemZero(batch_mean, batch_mean->size()); - stream->ThenMemZero(batch_var, batch_var->size()); + TF_RETURN_IF_ERROR(stream->MemZero(batch_mean, batch_mean->size())); + TF_RETURN_IF_ERROR(stream->MemZero(batch_var, batch_var->size())); } batch_mean_opaque = batch_mean->opaque(); batch_var_opaque = batch_var->opaque(); @@ -9722,7 +9669,7 @@ tsl::Status CudnnSupport::DoBatchNormalizationForwardImpl( scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var, epsilon)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool CudnnSupport::DoBatchNormalizationBackward( @@ -9790,7 +9737,7 @@ bool CudnnSupport::DoBatchNormalizationBackward( } template -tsl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( +absl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -9872,14 +9819,14 @@ tsl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); } #endif - auto check_no_side_input_or_activation = [&]() -> tsl::Status { + auto check_no_side_input_or_activation = [&]() -> absl::Status { if (activation_mode != dnn::ActivationMode::kNone || !side_input_backprop->is_null()) { return tsl::errors::Internal( "Side input and activation are not supported by cuDNN version: ", CUDNN_VERSION); } else { - return ::tsl::OkStatus(); + return absl::OkStatus(); } }; @@ -9893,10 +9840,10 @@ tsl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( mean.opaque(), inv_var.opaque())); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CudnnSupport::DoFusedConvolve( +absl::Status CudnnSupport::DoFusedConvolve( Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type, dnn::DataType bias_type, dnn::DataType output_type, const dnn::BatchDescriptor& conv_input_descriptor, @@ -9929,9 +9876,9 @@ tsl::Status CudnnSupport::DoFusedConvolve( if (activation_mode != dnn::ActivationMode::kRelu && activation_mode != dnn::ActivationMode::kNone) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "cudnnConvolutionBiasActivationForward() only supports " - "Relu or None activation."); + return absl::InvalidArgumentError( + "cudnnConvolutionBiasActivationForward() only supports " + "Relu or None activation."); } CudnnTensorDescriptor conv_input_nd( @@ -9980,7 +9927,7 @@ tsl::Status CudnnSupport::DoFusedConvolve( filter_data, side_input_data, biases, output_data); } -tsl::Status CudnnSupport::CudnnReorderConvolutionFilterAndBias( +absl::Status CudnnSupport::CudnnReorderConvolutionFilterAndBias( Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_input, DeviceMemory* filter_output, @@ -10002,10 +9949,10 @@ tsl::Status CudnnSupport::CudnnReorderConvolutionFilterAndBias( /*reorderedBiasData=*/has_bias ? bias_output->opaque() : nullptr); RETURN_IF_CUDNN_ERROR(status); - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CudnnSupport::DoPrepareForCtcLoss( +absl::Status CudnnSupport::DoPrepareForCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const dnn::RnnStateTensorDescriptor& grads_desc, @@ -10057,26 +10004,26 @@ tsl::Status CudnnSupport::DoPrepareForCtcLoss( } *ctc_loss_algo_id = algo; #else - return tsl::Status(absl::StatusCode::kInvalidArgument, - "No supported cudnnGetCTCLossWorkspaceSize when " - "CUDNN_VERSION < 7.6.3"); + return absl::InvalidArgumentError( + "No supported cudnnGetCTCLossWorkspaceSize when " + "CUDNN_VERSION < 7.6.3"); #endif // Allocate the workspace. if (workspace_size_in_bytes == 0) { *scratch_memory = DeviceMemory(); - return ::tsl::OkStatus(); + return absl::OkStatus(); } const auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes); if (scratch_or.ok()) { *scratch_memory = scratch_or.value(); - return ::tsl::OkStatus(); + return absl::OkStatus(); } return tsl::errors::Internal( "Failed to allocate scratch memory for the CuDNN CTC Loss"); } -tsl::Status CudnnSupport::DoCtcLoss( +absl::Status CudnnSupport::DoCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, @@ -10087,9 +10034,9 @@ tsl::Status CudnnSupport::DoCtcLoss( int ctc_loss_algo_id) { // Current cuDNN CTC Loss only supports the float datatype if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "CudnnCtcLossDescriptor is supported only when the " - "CUDNN_VERSION >= 7.6.3 and DataType is float"); + return absl::InvalidArgumentError( + "CudnnCtcLossDescriptor is supported only when the " + "CUDNN_VERSION >= 7.6.3 and DataType is float"); } CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type)); const CudnnRnnStateTensorDescriptor& cudnn_probs_desc = @@ -10119,214 +10066,7 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, RETURN_IF_CUDNN_ERROR(cudnnTransformTensor( cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), &beta, output_tensor_desc.handle(), output_data->opaque())); - return ::tsl::OkStatus(); - }(); - return IsStatusOk(status, /*report_error=*/true); -} - -bool CudnnSupport::DoMatMul(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& weights, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { - if (input_dimensions.count() != output_dimensions.count()) { - LOG(ERROR) << "MatMul input and output dimensions are not compatible."; - return false; - } - - // We do not permute the input or output, instead we just - // reinterpret the layout. We are working with row-major matrices - // and the rows of the input and output correspond to batch, so - // batch has to be outermost in both the input and output. - // - // By adding transposes to the BLAS gemm call we could perhaps make - // the kYXDepthBatch layout work as well, but there has been no need - // for that so far. - if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "Unsupported MatMul input layout."; - return false; - } - if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "Unsupported MatMul output layout."; - return false; - } - - if (output_dimensions.width() == 1 && output_dimensions.height() == 1) { - // This is a fast path that also supports the kBatchYXDepth layout. - - // The matrices here are in row-major format while BLAS expects - // column-major, i.e. our matrices are transposed as far as BLAS - // is concerned. So we need to compute output^T = - // input^T*weights^T. There is no parameter for transposing the - // output in BLAS gemm, but instead we can transpose both sides of - // the equality to see that this is equivalent to - // output=weights*input. So we only need to swap the order of - // weights and input in the matrix product to correct for the - // row-major versus column-major difference. - const int64_t m = output_dimensions.NodesAcrossFeatureMaps(); - const int64_t n = input_dimensions.count(); - const int64_t k = input_dimensions.NodesAcrossFeatureMaps(); - if (!stream - ->ThenBlasGemm(blas::Transpose::kNoTranspose, - blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, NumericOptions{}, - blas::CallContext::kNone) - - .ok()) { - return false; - } - } else { - // This is a slower and more complex path that supports output - // width() * height() > 1, though it only supports the - // kBatchYXDepth layout. Does support kBatchDepthYX if output - // feature_map_count() == 1, as then there is no difference - // between the two layouts. - // - // The operation here is the same as above, except that we have to - // do the matrix multiplication for each (y,x) output coordinate - // separately. We then interpret weights as containing K = width() - // * height() different matrices, which we all multiply onto the - // matrix from input_data, yielding K matrix products. We then - // combine these together into one matrix by concatenating all the - // first rows of these matrices, then all the seconds rows and so - // on. We can do this with a batched matrix multiplication, where - // the result is written to a different submatrix of the output - // for each matrix multiplication. - // - // The reason that we only support the kBatchYXDepth output layout - // is that we have to do something in the depth for each (y,x) - // coordinate. The kBatchYXDepth layout has the depth information - // for each point (y,x) in contiguous memory while the - // kBatchDepthYX layout does not. - // - // TODO(broune): Consider a special case for when output depth == - // 1, as then possibly this could all be done as one matrix - // multiplication instead of a batched one, which should be - // faster. Another possibility would be to add a weights layout - // parameter and then support kBatchDepthYX for a different - // weights layout. - if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX && - output_dimensions.feature_map_count() == 1)) { - LOG(ERROR) << "Unsupported MatMul output layout."; - return false; - } - - const float alpha = 1.0f; // Take the matrix product without scaling it. - const float beta = 0.0f; // Ignore the original values in output_data. - const uint64_t m = output_dimensions.feature_map_count(); - const uint64_t n = input_dimensions.count(); - const uint64_t k = input_dimensions.NodesAcrossFeatureMaps(); - const int lda = m; - const int ldb = k; - const int ldc = output_dimensions.NodesAcrossFeatureMaps(); - const int batch_count = output_dimensions.NodesPerFeatureMap(); - - std::vector> a(batch_count); - std::vector> b(batch_count); - std::vector> c(batch_count); - for (int i = 0; i < batch_count; ++i) { - const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() * - output_dimensions.feature_map_count(); - a[i] = DeviceMemory::MakeFromByteSize( - const_cast(reinterpret_cast(weights.opaque())) + - weights_offset, - weights.ElementCount() - weights_offset); - - b[i] = input_data; - - const int output_offset = i * output_dimensions.feature_map_count(); - c[i] = DeviceMemory::MakeFromByteSize( - const_cast( - reinterpret_cast(output_data->opaque())) + - output_offset, - output_data->ElementCount() - output_offset); - } - const auto toPtrs = [](std::vector>& v) { - std::vector*> ptrs; - ptrs.reserve(v.size()); - for (auto& mem : v) { - ptrs.push_back(&mem); - } - return ptrs; - }; - - stream->ThenBlasGemmBatched( - blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, - alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), ldc, - batch_count, NumericOptions{}, blas::CallContext::kNone); - } - - return stream->ok(); -} - -bool CudnnSupport::DoBiasAdd(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& biases, - const dnn::BatchDescriptor& dimensions, - DeviceMemory* output_data) { - CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT); - - dnn::BatchDescriptor bias_dimensions; - bias_dimensions.set_count(1) - .set_feature_map_count(dimensions.feature_map_count()) - .set_height(1) - .set_width(1) - .set_layout(dnn::DataLayout::kBatchYXDepth); - CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT); - - // cudnnAddTensor after R3 is in-place, so we need to copy input_data to - // output_data before doing the addition, unless the input and - // output are at the same address. - if (input_data.opaque() != output_data->opaque()) { - stream->ThenMemcpy(output_data, input_data, - dimensions.ElementCount() * sizeof(float)); - if (!stream->ok()) { - LOG(ERROR) - << "stream " << stream - << " could not enqueue a tensor copy as part of bias addition."; - return false; - } - } - - const float alpha = 1.0f; - const float beta = 1.0f; - - auto cudnn = cudnn_->GetHandle(parent_, stream); - - const auto status = [&] { - RETURN_IF_CUDNN_ERROR(cudnnAddTensor( - cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), - &beta, input_descriptor.handle(), output_data->opaque())); - return ::tsl::OkStatus(); - }(); - return IsStatusOk(status, /*report_error=*/true); -} - -bool CudnnSupport::DoActivate(Stream* stream, - dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - DeviceMemory* output_data, - uint64_t options) { - CudnnActivationDescriptor activation_desc( - activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()); - - CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT); - // Alpha is the input scaling factor. - float alpha = 1.0; - // Beta is the output scaling factor. - float beta = 0.0; - - auto cudnn = cudnn_->GetHandle(parent_, stream); - const auto status = [&] { - RETURN_IF_CUDNN_ERROR(cudnnActivationForward( - cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(), - input_data.opaque(), &beta, input_nd.handle(), output_data->opaque())); - return ::tsl::OkStatus(); + return absl::OkStatus(); }(); return IsStatusOk(status, /*report_error=*/true); } @@ -10343,7 +10083,7 @@ struct PoolingSplitsSpec { int64_t output_offset_in_bytes; }; -tsl::StatusOr> GetTensorSplits( +absl::StatusOr> GetTensorSplits( const dnn::BatchDescriptor& input_descriptor, const dnn::BatchDescriptor& output_descriptor, dnn::DataType element_type) { std::vector out; @@ -10369,12 +10109,9 @@ tsl::StatusOr> GetTensorSplits( std::numeric_limits::max() / elements_per_batch_input; if (max_batches_per_split == 0) { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrCat( - "Tensor has too many elements for int32 indexing: batches=", - num_batches, " elements_per_batch=", elements_per_batch_input, - ".")); + return absl::InternalError(absl::StrCat( + "Tensor has too many elements for int32 indexing: batches=", + num_batches, " elements_per_batch=", elements_per_batch_input, ".")); } int64_t processed_batches = 0; @@ -10392,7 +10129,7 @@ tsl::StatusOr> GetTensorSplits( } } // namespace -tsl::Status CudnnSupport::DoPoolForward( +absl::Status CudnnSupport::DoPoolForward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, @@ -10403,7 +10140,7 @@ tsl::Status CudnnSupport::DoPoolForward( output_dimensions, output_data, workspace_allocator); } -tsl::Status CudnnSupport::DoPoolForward( +absl::Status CudnnSupport::DoPoolForward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, @@ -10436,14 +10173,13 @@ tsl::Status CudnnSupport::DoPoolForward( RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( cudnn.handle(), pooling_desc.handle(), alpha, src_desc.handle(), input_ptr, beta, dest_desc.handle(), output_ptr)); - return ::tsl::OkStatus(); + return absl::OkStatus(); }; auto splits_or = GetTensorSplits(input_dimensions, output_dimensions, element_type); if (!splits_or.ok()) { - return tsl::Status(absl::StatusCode::kInternal, - "Cudnn pooling failed to split"); + return absl::InternalError("Cudnn pooling failed to split"); } auto splits = std::move(splits_or.value()); @@ -10468,10 +10204,10 @@ tsl::Status CudnnSupport::DoPoolForward( return status; } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CudnnSupport::DoPoolBackward( +absl::Status CudnnSupport::DoPoolBackward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, @@ -10484,7 +10220,7 @@ tsl::Status CudnnSupport::DoPoolBackward( output_diff_data, workspace_allocator); } -tsl::Status CudnnSupport::DoPoolBackward( +absl::Status CudnnSupport::DoPoolBackward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, @@ -10520,14 +10256,13 @@ tsl::Status CudnnSupport::DoPoolBackward( cudnn.handle(), pooling_desc.handle(), alpha, dest_desc.handle(), output_ptr, dest_desc.handle(), input_diff_ptr, src_desc.handle(), input_ptr, beta, src_desc.handle(), output_diff_ptr)); - return ::tsl::OkStatus(); + return absl::OkStatus(); }; auto splits_or = GetTensorSplits(input_dimensions, output_dimensions, element_type); if (!splits_or.ok()) { - return tsl::Status(absl::StatusCode::kInternal, - "Cudnn pooling failed to split"); + return absl::InternalError("Cudnn pooling failed to split"); } auto splits = std::move(splits_or.value()); @@ -10557,7 +10292,7 @@ tsl::Status CudnnSupport::DoPoolBackward( return status; } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool CudnnSupport::DoNormalizeWithDimensions( @@ -10590,7 +10325,7 @@ bool CudnnSupport::DoNormalizeWithDimensions( cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(), output_data->opaque())); - return ::tsl::OkStatus(); + return absl::OkStatus(); }(); return IsStatusOk(status, /*report_error=*/true); } @@ -10625,113 +10360,11 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( &alpha, dims.handle(), normalized_data.opaque(), dims.handle(), normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), &beta, dims.handle(), raw_variable_gradient->opaque())); - return ::tsl::OkStatus(); + return absl::OkStatus(); }(); return IsStatusOk(status, /*report_error=*/true); } -bool CudnnSupport::DoDepthConcatenate(Stream* stream, - BatchDescriptorSlice input_dimensions, - DeviceMemorySlice input_data, - DeviceMemory* output_data) { - CHECK_EQ(input_dimensions.size(), input_data.size()); - - for (const auto& dimensions : input_dimensions) { - if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only " - "supports the kBatchDepthYX layout."; - return false; - } - } - - if (input_dimensions.empty()) { - return true; // Nothing to do. - } - - dnn::BatchDescriptor output_dimensions = - dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions); - - const int64_t area = output_dimensions.width() * output_dimensions.height(); - const auto index = [area](int64_t batch, int64_t depth, int64_t yx, - int64_t max_depth) { - return (batch * max_depth + depth) * area + yx; - }; - - std::vector output_host(output_dimensions.ElementCount()); - std::vector tmp; - int64_t depth_sum = 0; - for (size_t i = 0; i < input_data.size(); ++i) { - const auto& dimensions = input_dimensions[i]; - tmp.resize(dimensions.ElementCount()); - stream->ThenMemcpyD2H(*input_data[i], absl::MakeSpan(tmp)); - tsl::Status block_status = stream->BlockHostUntilDone(); - if (!block_status.ok()) { - LOG(ERROR) << "BlockHostUntilDone failed: " << block_status; - return false; - } - - for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) { - for (int64_t yx = 0; yx < area; ++yx) { - for (int64_t depth = 0; depth < dimensions.feature_map_count(); - ++depth) { - LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' ' - << yx << ' ' << depth; - output_host[index(batch, depth + depth_sum, yx, - output_dimensions.feature_map_count())] = - tmp[index(batch, depth, yx, dimensions.feature_map_count())]; - } - } - } - depth_sum += dimensions.feature_map_count(); - } - stream->ThenMemcpyH2D(output_host, output_data); - return true; -} - -bool CudnnSupport::DoElementwiseOperate(Stream*, dnn::ElementwiseOperation, - BatchDescriptorSlice, - DeviceMemorySlice, - const dnn::BatchDescriptor&, - DeviceMemory*) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool CudnnSupport::DoXYPad(Stream* stream, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t left_pad, int64_t right_pad, int64_t top_pad, - int64_t bottom_pad, - DeviceMemory* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool CudnnSupport::DoXYSlice(Stream* stream, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t left_trim, int64_t right_trim, - int64_t top_trim, int64_t bottom_trim, - DeviceMemory* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool CudnnSupport::DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory& gpu_unquantized_src, - dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) { - LOG(ERROR) << "quantized memcpy not supported by cuDNN"; - return false; -} - -bool CudnnSupport::DoMemcpyH2DQuantized( - Stream* stream, const void* host_src, int64_t size, - dnn::QuantizedActivationMode mode, - DeviceMemory* gpu_unquantized_dst) { - LOG(ERROR) << "quantized memcpy not supported by cuDNN"; - return false; -} - bool CudnnSupport::DeriveOutputBatchDescriptor( const dnn::BatchDescriptor& batch_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -10754,15 +10387,78 @@ bool CudnnSupport::DeriveOutputBatchDescriptor( output_batch_descriptor->set_spatial_dim(static_cast(i), dims.rbegin()[i]); } - return ::tsl::OkStatus(); + return absl::OkStatus(); }(); return IsStatusOk(status, /*report_error=*/true); } +#if CUDNN_VERSION >= 8100 + +absl::StatusOr> CudnnSupport::DeserializeGraph( + absl::string_view serialized_data) const { + TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_->GetLocalHandle()); + cudnn_frontend::graph::Graph graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.deserialize( + cudnn->handle(), + std::vector(serialized_data.data(), + serialized_data.data() + serialized_data.size()))); + return std::make_unique(std::move(graph)); +} + +absl::StatusOr CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) { + const CudnnSupport& cudnn_support = static_cast(dnn_support); + TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle()); + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate()); + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn->handle())); + RETURN_IF_CUDNN_FRONTEND_ERROR( + graph_.create_execution_plans({cudnn_frontend::HeurMode_t::A})); + if (auto result = graph_.check_support(cudnn->handle()); result.is_bad()) { + VLOG(3) << result.get_message(); + return false; + } + return true; +} + +absl::Status CudnnGraph::Build(dnn::DnnSupport& dnn_support, + const int64_t plan_id) { + const CudnnSupport& cudnn_support = static_cast(dnn_support); + TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle()); + RETURN_IF_CUDNN_FRONTEND_ERROR( + graph_.build_plan_at_index(cudnn->handle(), plan_id)); + return absl::OkStatus(); +} + +absl::Status CudnnGraph::Execute(Stream& stream, + absl::Span operands) const { + std::unordered_map, + void*> + tensor_to_ptr_map; + int operand_number = 0; + CHECK_EQ(graph_.get_workspace_size(), 0); + for (DeviceMemoryBase operand : operands) { + const cudnn_frontend::graph::Tensor_attributes attr = + cudnn_frontend::graph::Tensor_attributes().set_uid( + CuDnnTensorUID(operand_number)); + ++operand_number; + tensor_to_ptr_map + [std::make_shared(attr)] = + operand.opaque(); + } + const CudnnSupport& dnn_support = + static_cast(*stream.parent()->AsDnn()); + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( + dnn_support.cudnn_ + ->GetHandle(ExtractGpuExecutor(stream.parent()), &stream) + .handle(), + tensor_to_ptr_map, /*workspace=*/nullptr)); + return absl::OkStatus(); +} + +#endif // CUDNN_VERSION >= 8100 } // namespace gpu void initialize_cudnn() { - tsl::Status status = + absl::Status status = PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuDNN", [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* { @@ -10794,5 +10490,6 @@ void initialize_cudnn() { #pragma clang diagnostic pop #endif -REGISTER_MODULE_INITIALIZER(register_cudnn, - { stream_executor::initialize_cudnn(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cudnn, { + stream_executor::initialize_cudnn(); +}); diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h index 6c26300c999fd..06cc5fec2dd9b 100644 --- a/xla/stream_executor/cuda/cuda_dnn.h +++ b/xla/stream_executor/cuda/cuda_dnn.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,18 +19,26 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_ +#include #include +#include #include #include #include -#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/stream_executor/cuda/cuda_activation.h" +#include "third_party/gpus/cudnn/cudnn_version.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" -#include "tsl/platform/status.h" +#include "xla/stream_executor/numeric_options.h" +#include "tsl/protobuf/dnn.pb.h" + +#if CUDNN_VERSION >= 8100 +#include "third_party/cudnn_frontend/include/cudnn_frontend.h" +#endif // CUDNN_VERSION >= 8100 namespace stream_executor { namespace gpu { @@ -46,16 +54,35 @@ using BatchDescriptorSlice = absl::Span; template using DeviceMemorySlice = absl::Span* const>; +#if CUDNN_VERSION >= 8100 +class CudnnGraph : public dnn::DnnGraph { + public: + explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph) + : graph_(std::move(graph)) {} + // Prepares a graph and checks whether it is generally supported. + absl::StatusOr Prepare(dnn::DnnSupport&) override; + // Builds single plan of the graph with given ID. + absl::Status Build(dnn::DnnSupport&, int64_t plan_id) override; + // Builds all the plans + absl::Status Execute(Stream& stream, + absl::Span operands) const override; + const cudnn_frontend::graph::Graph& Graph() const { return graph_; } + + private: + cudnn_frontend::graph::Graph graph_; +}; +#endif // CUDNN_VERSION >= 8100 + // cudnn-library based DNN support. For details on overridden interface // functions, see dnn.h. class CudnnSupport : public dnn::DnnSupport { public: explicit CudnnSupport(GpuExecutor* parent); - tsl::Status Init() override; - tsl::StatusOr GetVersion() override; + absl::Status Init() override; + absl::StatusOr GetVersion() override; - tsl::StatusOr> createRnnDescriptor( + absl::StatusOr> CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -63,20 +90,20 @@ class CudnnSupport : public dnn::DnnSupport { const NumericOptions& numeric_options, float dropout, uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) override; - tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) override; - tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) override; - tsl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + absl::StatusOr> + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) override; bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, @@ -214,7 +241,7 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) override; - tsl::Status GetConvolveRunners( + absl::Status GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -228,7 +255,7 @@ class CudnnSupport : public dnn::DnnSupport { std::vector>* out_exec_plans) override; - tsl::StatusOr> ConvolveRunnerFromDesc( + absl::StatusOr> ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, @@ -236,7 +263,7 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) override; - tsl::Status GetGraphConvolveRunners( + absl::Status GetGraphConvolveRunners( dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, @@ -247,7 +274,7 @@ class CudnnSupport : public dnn::DnnSupport { std::vector>* out_exec_plans, std::string serialized_graph) override; - tsl::StatusOr> + absl::StatusOr> GraphConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -257,7 +284,7 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor, std::string serialized_graph) override; - tsl::Status GetFusedConvolveRunners( + absl::Status GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, double side_input_scale, @@ -272,7 +299,7 @@ class CudnnSupport : public dnn::DnnSupport { std::vector>* out_exec_plans) override; - tsl::Status GetFusedMatmulRunners( + absl::Status GetFusedMatmulRunners( bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, @@ -282,7 +309,7 @@ class CudnnSupport : public dnn::DnnSupport { std::vector>* out_exec_plans) override; - tsl::StatusOr> + absl::StatusOr> FusedConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -295,16 +322,20 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::ActivationMode activation_mode) override; - tsl::StatusOr> NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + absl::StatusOr> NormRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) override; + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) override; - tsl::StatusOr> + absl::StatusOr> FusedMHARunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -319,7 +350,7 @@ class CudnnSupport : public dnn::DnnSupport { std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask) override; - tsl::StatusOr> + absl::StatusOr> FusedMHABackwardRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -427,7 +458,7 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) override; - tsl::Status DoConvolve( + absl::Status DoConvolve( dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -439,11 +470,11 @@ class CudnnSupport : public dnn::DnnSupport { dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, dnn::ProfileResult* output_profile_result) override; - tsl::Status DoFusedConvolve( + absl::Status DoFusedConvolve( Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type, dnn::DataType bias_type, dnn::DataType output_type, const dnn::BatchDescriptor& conv_input_descriptor, - DeviceMemoryBase conv_input_data, double conv_input_scale, + DeviceMemoryBase conv_input_data, double conv_scale, const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, @@ -455,100 +486,51 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) override; - tsl::Status CudnnReorderConvolutionFilterAndBias( + absl::Status CudnnReorderConvolutionFilterAndBias( Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_input, DeviceMemory* filter_output, std::optional> bias_input, std::optional> bias_output) override; - bool DoSeparableConvolve( - Stream* stream, const dnn::BatchDescriptor& batch_descriptor, - const DeviceMemory& input_data, - const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier, - const DeviceMemory& first_weights, - const DeviceMemory& second_weights, - const dnn::ConvolutionDescriptor& convolution_descriptor, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data) override { - LOG(ERROR) << "separable convolution not supported by CUDNN"; - return false; - } - - bool DoMatMul(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& weights, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; - - bool DoMatMulQuantized(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override { - LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN"; - return false; - } - - bool DoMatMulQuantized(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override { - LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN"; - return false; - } - - bool DoBiasAdd(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& biases, - const dnn::BatchDescriptor& dimensions, - DeviceMemory* output_data) override; - - bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - DeviceMemory* output_data, uint64_t options) override; - - tsl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, - DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, - ScratchAllocator* workspace_allocator) override; - - tsl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const NumericOptions& numeric_options, - const dnn::BatchDescriptor& input_dimensions, - DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, - ScratchAllocator* workspace_allocator) override; - - tsl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + absl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, - DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) override; - tsl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + absl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, - DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) override; + absl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + DeviceMemoryBase input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, + DeviceMemoryBase output_diff_data, + ScratchAllocator* workspace_allocator) override; + + absl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const NumericOptions& numeric_options, + const dnn::BatchDescriptor& input_dimensions, + DeviceMemoryBase input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, + DeviceMemoryBase output_diff_data, + ScratchAllocator* workspace_allocator) override; + bool DoNormalizeWithDimensions( Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const dnn::BatchDescriptor& dimensions, @@ -564,36 +546,6 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* raw_variable_gradient, ScratchAllocator* workspace_allocator) override; - bool DoDepthConcatenate(Stream* stream, BatchDescriptorSlice input_dimensions, - DeviceMemorySlice input_data, - DeviceMemory* output_data) override; - - bool DoElementwiseOperate(Stream* stream, dnn::ElementwiseOperation operation, - BatchDescriptorSlice input_dimensions, - DeviceMemorySlice input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; - - bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, int64_t left_pad, - int64_t right_pad, int64_t top_pad, int64_t bottom_pad, - DeviceMemory* output_data) override; - - bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, int64_t left_trim, - int64_t right_trim, int64_t top_trim, int64_t bottom_trim, - DeviceMemory* output_data) override; - - bool DoMemcpyD2HQuantized(Stream* stream, - const DeviceMemory& device_unquantized_src, - dnn::QuantizedActivationMode mode, void* host_dst, - int64_t size) override; - - bool DoMemcpyH2DQuantized( - Stream* stream, const void* host_src, int64_t size, - dnn::QuantizedActivationMode mode, - DeviceMemory* device_unquantized_dst) override; - // Derives an output batch descriptor from an input batch and convolution // descriptors. bool DeriveOutputBatchDescriptor( @@ -602,17 +554,17 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::BatchDescriptor* output_batch_descriptor); - tsl::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, - const dnn::RnnStateTensorDescriptor& probs_desc, - const DeviceMemoryBase probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - DeviceMemoryBase costs_data, - const dnn::RnnStateTensorDescriptor& grads_desc, - DeviceMemoryBase grads_data, - DeviceMemory scratch_memory, - int ctc_loss_algo_id) override; + absl::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, + absl::Span labels_data, + absl::Span labels_lengths_data, + absl::Span input_lengths_data, + DeviceMemoryBase costs_data, + const dnn::RnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, + DeviceMemory scratch_memory, + int ctc_loss_algo_id) override; bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, dnn::DataType input_type, @@ -623,7 +575,16 @@ class CudnnSupport : public dnn::DnnSupport { void NotifyStreamDestroyed(Stream* stream) override; +#if CUDNN_VERSION >= 8100 + // Loads complete graph from its serialized representation. + absl::StatusOr> DeserializeGraph( + absl::string_view serialized_data) const override; +#endif // CUDNN_VERSION >= 8100 + private: + // Uses cuDNN handle for execution. + friend class CudnnGraph; + GpuExecutor* parent_; // Parent executor object. Not owned. // Provides access to the cuDNN handle. @@ -645,7 +606,7 @@ class CudnnSupport : public dnn::DnnSupport { std::vector* out_algorithms); template - tsl::Status DoBatchNormalizationForwardImpl( + absl::Status DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -661,7 +622,7 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator); template - tsl::Status DoBatchNormalizationBackwardImpl( + absl::Status DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -675,7 +636,7 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator); template - tsl::Status DoRnnForwardImpl( + absl::Status DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -695,7 +656,7 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - tsl::Status DoRnnBackwardImpl( + absl::Status DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -721,7 +682,7 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result); - tsl::Status DoCtcLossImpl( + absl::Status DoCtcLossImpl( Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -731,7 +692,7 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory scratch_memory, int ctc_loss_algo_id); private: - tsl::Status DoPrepareForConvolution( + absl::Status DoPrepareForConvolution( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -743,7 +704,7 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc, DeviceMemory* scratch_memory) override; - tsl::Status DoPrepareForCtcLoss( + absl::Status DoPrepareForCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const dnn::RnnStateTensorDescriptor& grads_desc, diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index e60e17d641ce6..9d9054f09a40c 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ limitations under the License. #include "absl/debugging/leak_check.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -42,25 +43,17 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" #include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" -static constexpr bool FLAGS_gpuexec_cuda_driver_inject_init_error = false; -static constexpr bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false; -static constexpr bool FLAGS_gpuexec_cuda_device_0_only = false; - #define RETURN_IF_CUDA_RES_ERROR(expr, ...) \ do { \ CUresult _res = (expr); \ @@ -151,8 +144,6 @@ thread_local struct ThreadLocalData { } // namespace ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { - if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie(); - auto* tls = &tls_data; // If this is an outermost scope, we must not assume that the CUDA context has @@ -190,8 +181,6 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { } ScopedActivateContext::~ScopedActivateContext() { - if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie(); - auto* tls = &tls_data; if (kVerifyGpuContext) { @@ -266,16 +255,11 @@ std::string CUDAPointersToCanAccessString(CUdeviceptr from, CUdeviceptr to) { // Actually performs the work of CUDA initialization. Wrapped up in one-time // execution guard. -static tsl::Status InternalInit() { - CUresult res = CUDA_ERROR_NO_DEVICE; - if (FLAGS_gpuexec_cuda_driver_inject_init_error) { - LOG(ERROR) << "injecting CUDA init error; initialization will fail"; - } else { - res = cuInit(0 /* = flags */); - } +static absl::Status InternalInit() { + CUresult res = cuInit(0 /* = flags */); if (res == CUDA_SUCCESS) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else if (res == CUDA_ERROR_SHARED_OBJECT_INIT_FAILED) { VLOG(1) << "failed call to cuInit: " << ToString(res); } else { @@ -283,30 +267,63 @@ static tsl::Status InternalInit() { } Diagnostician::LogDiagnosticInformation(); - return tsl::Status(absl::StatusCode::kAborted, - absl::StrCat("failed call to cuInit: ", ToString(res))); + return absl::AbortedError( + absl::StrCat("failed call to cuInit: ", ToString(res))); +} + +// Synchronize with spinlocks. +const char kScheduleSpinString[] = "spin"; +// Synchronize with spinlocks that also call CPU yield instructions. +const char kScheduleYieldString[] = "yield"; +// Synchronize with a "synchronization primitive" (e.g. mutex). +const char kScheduleBlockingSyncString[] = "blocking_sync"; + +int GetFlagsFromEnv() { + const char* gpu_schedule_string = + std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE"); + + if (gpu_schedule_string == nullptr) { + return 0; + } + + unsigned device_flags = 0; + if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_SPIN; + } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_YIELD; + } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_BLOCKING_SYNC; + } else { + LOG(QFATAL) << "Unknown option for environment variable " + "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE " + << gpu_schedule_string << " should be one of {" + << kScheduleBlockingSyncString << ", " << kScheduleSpinString + << ", " << kScheduleYieldString << "}"; + } + + return device_flags; } } // namespace -/* static */ tsl::Status GpuDriver::Init() { +/* static */ absl::Status GpuDriver::Init() { // Cached return value from calling InternalInit(), as cuInit need only be // called once, but GpuDriver::Init may be called many times. - static tsl::Status* init_retval = [] { - return new tsl::Status(InternalInit()); + static absl::Status* init_retval = [] { + return new absl::Status(InternalInit()); }(); return *init_retval; } -/* static */ tsl::Status GpuDriver::GetDevice(int device_ordinal, - CUdevice* device) { +/* static */ absl::Status GpuDriver::GetDevice(int device_ordinal, + CUdevice* device) { RETURN_IF_CUDA_RES_ERROR(cuDeviceGet(device, device_ordinal), "Failed call to cuDeviceGet"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GetDeviceName(CUdevice device, - std::string* device_name) { +/* static */ absl::Status GpuDriver::GetDeviceName(CUdevice device, + std::string* device_name) { static const size_t kCharLimit = 64; absl::InlinedVector chars(kCharLimit); RETURN_IF_CUDA_RES_ERROR( @@ -314,42 +331,15 @@ static tsl::Status InternalInit() { "Failed to get device name"); chars[kCharLimit - 1] = '\0'; *device_name = chars.begin(); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, - int* flags) { - static_assert(DeviceOptions::kMask == 0xf, - "needs update for new device options"); - - if (device_options.flags() & DeviceOptions::kDoNotReclaimStackAllocation) { - *flags |= CU_CTX_LMEM_RESIZE_TO_MAX; - } - - // If no flags are set the default is CU_CTX_SCHED_AUTO, which - // in Google environments is very likely to mean SPIN. - if (device_options.flags() & DeviceOptions::kScheduleSpin) { - *flags |= CU_CTX_SCHED_SPIN; - } - if (device_options.flags() & DeviceOptions::kScheduleYield) { - *flags |= CU_CTX_SCHED_YIELD; - } - if (device_options.flags() & DeviceOptions::kScheduleBlockingSync) { - *flags |= CU_CTX_SCHED_BLOCKING_SYNC; - } - - return true; -} - -/* static */ tsl::Status GpuDriver::CreateContext( - int device_ordinal, CUdevice device, const DeviceOptions& device_options, - GpuContext** context) { +/* static */ absl::Status GpuDriver::CreateContext(int device_ordinal, + CUdevice device, + GpuContext** context) { *context = nullptr; - int flags = 0; - if (!DeviceOptionsToContextFlags(device_options, &flags)) { - LOG(WARNING) << "could not convert all device options into context flags"; - } + int flags = GetFlagsFromEnv(); CUresult res; CUcontext former_context; @@ -402,7 +392,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, << "success in this call must entail non-null result"; VLOG(2) << "created or reused context " << new_context << " for this thread"; - return ::tsl::OkStatus(); + return absl::OkStatus(); } std::string message = @@ -416,7 +406,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, } } - return tsl::Status(absl::StatusCode::kInternal, message); + return absl::InternalError(message); } /* static */ void GpuDriver::DestroyContext(GpuContext* context) { @@ -441,22 +431,22 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, return context->context(); } -/* static */ tsl::Status GpuDriver::FuncGetAttribute( +/* static */ absl::Status GpuDriver::FuncGetAttribute( CUfunction_attribute attribute, CUfunction func, int* attribute_value) { RETURN_IF_CUDA_RES_ERROR(cuFuncGetAttribute(attribute_value, attribute, func), "Failed to query kernel attribute: ", attribute); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::FuncSetCacheConfig( +/* static */ absl::Status GpuDriver::FuncSetCacheConfig( CUfunction function, CUfunc_cache cache_config) { RETURN_IF_CUDA_RES_ERROR(cuFuncSetCacheConfig(function, cache_config), "Failed to set CUDA kernel cache config"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::ContextGetSharedMemConfig( - GpuContext* context) { +/* static */ absl::StatusOr +GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { CUsharedconfig shared_mem_config; ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR(cuCtxGetSharedMemConfig(&shared_mem_config), @@ -464,27 +454,27 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, return shared_mem_config; } -/* static */ tsl::Status GpuDriver::ContextSetSharedMemConfig( +/* static */ absl::Status GpuDriver::ContextSetSharedMemConfig( GpuContext* context, CUsharedconfig shared_mem_config) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR(cuCtxSetSharedMemConfig(shared_mem_config), "Failed to set shared memory config"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::CreateGraph(CUgraph* graph) { +/* static */ absl::Status GpuDriver::CreateGraph(CUgraph* graph) { VLOG(2) << "Create new CUDA graph"; RETURN_IF_CUDA_RES_ERROR(cuGraphCreate(graph, /*flags=*/0), "Failed to create CUDA graph"); VLOG(2) << "Created CUDA graph " << *graph; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::DestroyGraph(CUgraph graph) { +/* static */ absl::Status GpuDriver::DestroyGraph(CUgraph graph) { VLOG(2) << "Destroy CUDA graph " << graph; RETURN_IF_CUDA_RES_ERROR(cuGraphDestroy(graph), "Failed to destroy CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } static std::string_view StreamCaptureModeToString( @@ -499,8 +489,8 @@ static std::string_view StreamCaptureModeToString( } } -/* static */ tsl::Status GpuDriver::StreamBeginCapture(CUstream stream, - StreamCaptureMode mode) { +/* static */ absl::Status GpuDriver::StreamBeginCapture( + CUstream stream, StreamCaptureMode mode) { CUstreamCaptureMode cu_mode; switch (mode) { case StreamCaptureMode::kGlobal: @@ -514,26 +504,57 @@ static std::string_view StreamCaptureModeToString( break; } - VLOG(2) << "Beging stream " << stream << " capture in " + VLOG(2) << "Beginning stream " << stream << " capture in " << StreamCaptureModeToString(mode) << " mode"; RETURN_IF_CUDA_RES_ERROR(cuStreamBeginCapture(stream, cu_mode), "Failed to begin stream capture"); - return ::tsl::OkStatus(); + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::StreamBeginCaptureToGraph( + CUstream stream, CUgraph graph, StreamCaptureMode mode) { + CUstreamCaptureMode cu_mode; + switch (mode) { + case StreamCaptureMode::kGlobal: + cu_mode = CU_STREAM_CAPTURE_MODE_GLOBAL; + break; + case StreamCaptureMode::kThreadLocal: + cu_mode = CU_STREAM_CAPTURE_MODE_THREAD_LOCAL; + break; + case StreamCaptureMode::kRelaxed: + cu_mode = CU_STREAM_CAPTURE_MODE_RELAXED; + break; + } + +#if CUDA_VERSION >= 12030 + VLOG(2) << "Beginning stream " << stream << " capture in " + << StreamCaptureModeToString(mode) << " mode to graph " << graph; + RETURN_IF_CUDA_RES_ERROR( + cuStreamBeginCaptureToGraph(stream, graph, + /*dependencies=*/nullptr, + /*dependencyData=*/nullptr, + /*numDependencies=*/0, cu_mode), + "Failed to begin stream capture to graph"); + return absl::OkStatus(); +#else + return absl::UnimplementedError( + "StreamBeginCaptureToGraph is not implemented"); +#endif // CUDA_VERSION >= 12030 } -/* static */ tsl::Status GpuDriver::StreamEndCapture(CUstream stream, - CUgraph* graph) { +/* static */ absl::Status GpuDriver::StreamEndCapture(CUstream stream, + CUgraph* graph) { VLOG(2) << "End stream " << stream << " capture"; RETURN_IF_CUDA_RES_ERROR(cuStreamEndCapture(stream, graph), "Failed to end stream capture"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphInstantiate( +/* static */ absl::Status GpuDriver::GraphInstantiate( CUgraphExec* exec, CUgraph graph, const GraphInstantiateFlags& flags) { - VLOG(2) << "Instante CUDA executable graph from graph " << graph << " (" + VLOG(2) << "Instantiate CUDA executable graph from graph " << graph << " (" << "auto_free_on_launch=" << flags.auto_free_on_launch << ", " << "device_launch=" << flags.device_launch << ", " << "use_node_priority=" << flags.use_node_prirotiy << ", " @@ -556,19 +577,31 @@ static std::string_view StreamCaptureModeToString( "Failed to instantiate CUDA graph"); #endif // CUDA_VERSION >= 12000 - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphLaunch(CUgraphExec exec, - CUstream stream) { +/* static */ absl::Status GpuDriver::GraphLaunch(CUgraphExec exec, + CUstream stream) { VLOG(2) << "Launching CUDA executable graph " << exec << " on a stream " << stream; RETURN_IF_CUDA_RES_ERROR(cuGraphLaunch(exec, stream), "Failed to launch CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::GraphNodeSetEnabled(CUgraphExec exec, + CUgraphNode node, + bool enabled) { + // Node is enabled if value != 0, otherwise the node is disabled. + unsigned value = enabled ? 1 : 0; + VLOG(2) << "Set CUDA executable graph " << exec << " node " << node + << " enabled flag to " << value; + RETURN_IF_CUDA_RES_ERROR(cuGraphNodeSetEnabled(exec, node, value), + "Failed to set CUDA graph node enabled flag"); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphExecUpdate( +/* static */ absl::Status GpuDriver::GraphExecUpdate( CUgraphExec exec, CUgraph graph, GraphExecUpdateResultInfo* result) { VLOG(2) << "Update CUDA graph executable " << exec << " with graph " << graph; @@ -620,14 +653,14 @@ static std::string_view StreamCaptureModeToString( break; #endif // CUDA_VERSION >= 12000 default: - return tsl::errors::Internal("Unknown graph update result"); + return absl::InternalError("Unknown graph update result"); } RETURN_IF_CUDA_RES_ERROR(err_code, "Failed to update CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::GraphNodeGetType(CUgraphNode node) { CUgraphNodeType cu_node_type; memset(&cu_node_type, 0, sizeof(cu_node_type)); @@ -664,22 +697,40 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return GraphNodeType::kBatchMemOp; #endif // CUDA_VERSION >= 12000 default: - return tsl::errors::Internal("Unknown graph node type"); + return absl::InternalError("Unknown graph node type"); } - return tsl::Status(absl::StatusCode::kInternal, - "Invalid CUDA graph node type"); + return absl::InternalError("Invalid CUDA graph node type"); +} + +absl::StatusOr> +GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { + VLOG(2) << "Get CUDA graph node " << node << " dependencies"; + + std::vector dependencies; + + size_t num_dependencies = 0; + RETURN_IF_CUDA_RES_ERROR( + cuGraphNodeGetDependencies(node, nullptr, &num_dependencies), + "Failed to get CUDA graph node depedencies size"); + + dependencies.resize(num_dependencies, nullptr); + RETURN_IF_CUDA_RES_ERROR( + cuGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies), + "Failed to get CUDA graph node depedencies"); + + return dependencies; } -/* static */ tsl::Status GpuDriver::DestroyGraphExec(CUgraphExec exec) { +/* static */ absl::Status GpuDriver::DestroyGraphExec(CUgraphExec exec) { VLOG(2) << "Destroying CUDA executable graph " << exec; RETURN_IF_CUDA_RES_ERROR(cuGraphExecDestroy(exec), "Failed to destroy CUDA executable graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphDebugDotPrint(CUgraph graph, - const char* path) { +/* static */ absl::StatusOr GpuDriver::GraphDebugDotPrint( + CUgraph graph, const char* path, bool return_printed_graph) { #if CUDA_VERSION >= 12000 VLOG(2) << "Print CUDA graph " << graph << " debug dot file to " << path; @@ -687,27 +738,28 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { RETURN_IF_CUDA_RES_ERROR(cuGraphDebugDotPrint(graph, path, flags), "Failed to print gpu graph debug file"); - if (VLOG_IS_ON(100)) { + if (return_printed_graph) { std::string data; if (tsl::ReadFileToString(tsl::Env::Default(), path, &data).ok()) { - VLOG(200) << "CUDA graph " << graph << " debug file:\n" << data; + return data; } else { LOG(WARNING) << "failed to read gpu graph debug file " << path; } } #endif // CUDA_VERSION >= 12000 - return ::tsl::OkStatus(); + return std::string(path); } -/* static */ tsl::Status GpuDriver::DeviceGraphMemTrim(CUdevice device) { +/* static */ absl::Status GpuDriver::DeviceGraphMemTrim(CUdevice device) { VLOG(2) << "Trim CUDA device graph memory " << device; RETURN_IF_CUDA_RES_ERROR(cuDeviceGraphMemTrim(device), "Failed to trim device graph memory"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::StreamIsCapturing(CUstream stream) { +/* static */ absl::StatusOr GpuDriver::StreamIsCapturing( + CUstream stream) { VLOG(2) << "Checking if stream " << stream << " is capturing"; CUstreamCaptureStatus status; @@ -717,7 +769,7 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; } -/* static */ tsl::Status GpuDriver::GraphConditionalHandleCreate( +/* static */ absl::Status GpuDriver::GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, CUgraph graph, GpuContext* context, unsigned int default_launch_value, unsigned int flags) { VLOG(2) << "Create conditional handle for a graph " << graph @@ -734,7 +786,7 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return absl::UnimplementedError( "CUDA graph conditional nodes are not implemented"); #endif // CUDA_VERSION >= 12030 - return ::tsl::OkStatus(); + return absl::OkStatus(); } static std::string ConditionalTypeToString( @@ -747,9 +799,9 @@ static std::string ConditionalTypeToString( } } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, - absl::Span deps, + absl::Span deps, const GpuGraphNodeParams& params) { #if CUDA_VERSION >= 12030 // Add conditional node to a graph. @@ -790,19 +842,19 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, return absl::UnimplementedError("unsupported node type"); } -/* static */ tsl::Status GpuDriver::GraphAddEmptyNode( - CUgraphNode* node, CUgraph graph, absl::Span deps) { +/* static */ absl::Status GpuDriver::GraphAddEmptyNode( + CUgraphNode* node, CUgraph graph, absl::Span deps) { VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); RETURN_IF_CUDA_RES_ERROR( cuGraphAddEmptyNode(node, graph, deps.data(), deps.size()), "Failed to add empty node to a CUDA graph"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphAddKernelNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, +/* static */ absl::Status GpuDriver::GraphAddKernelNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, @@ -828,6 +880,9 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, params.kernelParams = kernel_params; params.extra = extra; + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { RETURN_IF_CUDA_RES_ERROR( cuFuncSetAttribute(function, @@ -840,10 +895,10 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, cuGraphAddKernelNode(node, graph, deps.data(), deps.size(), ¶ms), "Failed to add kernel node to a CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::Status GpuDriver::GraphExecKernelNodeSetParams( +/*static*/ absl::Status GpuDriver::GraphExecKernelNodeSetParams( CUgraphExec exec, CUgraphNode node, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, @@ -869,6 +924,9 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, params.kernelParams = kernel_params; params.extra = extra; + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { RETURN_IF_CUDA_RES_ERROR( cuFuncSetAttribute(function, @@ -880,7 +938,7 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, RETURN_IF_CUDA_RES_ERROR(cuGraphExecKernelNodeSetParams(exec, node, ¶ms), "Failed to set CUDA graph kernel node params"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } static CUmemAccess_flags ToCudaMemAccessFlags( @@ -902,7 +960,7 @@ static CUmemLocationType ToCudaLocationType( return CU_MEM_LOCATION_TYPE_INVALID; case GpuDriver::MemLocationType::kDevice: return CU_MEM_LOCATION_TYPE_DEVICE; -#if CUDA_VERSION >= 12000 +#if CUDA_VERSION >= 12030 case GpuDriver::MemLocationType::kHost: return CU_MEM_LOCATION_TYPE_HOST; case GpuDriver::MemLocationType::kHostNuma: @@ -914,7 +972,7 @@ static CUmemLocationType ToCudaLocationType( case GpuDriver::MemLocationType::kHostNuma: case GpuDriver::MemLocationType::kHostNumaCurrent: return CU_MEM_LOCATION_TYPE_INVALID; -#endif // CUDA_VERSION >= 12000 +#endif // CUDA_VERSION >= 12030 } } @@ -928,8 +986,8 @@ static CUmemAllocationType ToCudaAllocationType( } } -/*static*/ tsl::Status GpuDriver::GraphAddMemAllocNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, +/*static*/ absl::Status GpuDriver::GraphAddMemAllocNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, GpuDriver::MemAccessFlags access_flags, GpuDriver::MemLocationType location_type, int device_id, GpuDriver::MemAllocationType allocation_type, uint64_t size, @@ -949,9 +1007,9 @@ static CUmemAllocationType ToCudaAllocationType( mem_pool_props.allocType = ToCudaAllocationType(allocation_type); mem_pool_props.handleTypes = CU_MEM_HANDLE_TYPE_NONE; mem_pool_props.location = mem_location; -#if CUDA_VERSION >= 12000 +#if CUDA_VERSION >= 12030 mem_pool_props.maxSize = max_pool_size; -#endif // CUDA_VERSION >= 12000 +#endif // CUDA_VERSION >= 12030 // cuda graph requires reserved space initialized to 0 memset(mem_pool_props.reserved, 0, sizeof(mem_pool_props.reserved)); @@ -968,10 +1026,10 @@ static CUmemAllocationType ToCudaAllocationType( << " address " << reinterpret_cast(params.dptr); *d_ptr = params.dptr; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::StatusOr> +/*static*/ absl::StatusOr> GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { CUDA_MEM_ALLOC_NODE_PARAMS params; RETURN_IF_CUDA_RES_ERROR(cuGraphMemAllocNodeGetParams(node, ¶ms), @@ -979,10 +1037,19 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { return std::pair{params.dptr, params.bytesize}; } -/* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( +/*static*/ absl::Status GpuDriver::GraphAddMemFreeNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, + CUdeviceptr gpu_dst) { + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddMemFreeNode(node, graph, deps.data(), deps.size(), gpu_dst), + "Failed to add memory free node to a CUDA graph"); + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, CUgraphNode* node, CUgraph graph, - absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, - uint64_t size) { + absl::Span deps, CUdeviceptr gpu_dst, + CUdeviceptr gpu_src, uint64_t size) { VLOG(2) << "Add memcpy d2d node to a graph " << graph << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size @@ -1004,10 +1071,10 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { context->context()), "Failed to add memcpy d2d node to a CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( +/* static */ absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " @@ -1030,7 +1097,7 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { cuGraphExecMemcpyNodeSetParams(exec, node, ¶ms, context->context()), "Failed to set memcpy d2d node params"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } namespace { @@ -1065,9 +1132,9 @@ struct BitPatternToValue { } // namespace -/* static */ tsl::Status GpuDriver::GraphAddMemsetNode( +/* static */ absl::Status GpuDriver::GraphAddMemsetNode( GpuContext* context, CUgraphNode* node, GpuGraphHandle graph, - absl::Span deps, CUdeviceptr dst, + absl::Span deps, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { VLOG(2) << "Add memset node to a graph " << graph @@ -1093,10 +1160,10 @@ struct BitPatternToValue { context->context()), "Failed to add memset node to a CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphExecMemsetNodeSetParams( +/* static */ absl::Status GpuDriver::GraphExecMemsetNodeSetParams( GpuContext* context, CUgraphExec exec, CUgraphNode node, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { @@ -1122,11 +1189,11 @@ struct BitPatternToValue { cuGraphExecMemsetNodeSetParams(exec, node, ¶ms, context->context()), "Failed to set memset node params"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphAddChildNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, +/* static */ absl::Status GpuDriver::GraphAddChildNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, CUgraph child) { VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); @@ -1135,12 +1202,12 @@ struct BitPatternToValue { cuGraphAddChildGraphNode(node, graph, deps.data(), deps.size(), child), "Failed to create a child graph node and add it to a CUDA graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::Status GpuDriver::GraphExecChildNodeSetParams(CUgraphExec exec, - CUgraphNode node, - CUgraph child) { +/*static*/ absl::Status GpuDriver::GraphExecChildNodeSetParams(CUgraphExec exec, + CUgraphNode node, + CUgraph child) { VLOG(2) << "Set child node params " << node << " in graph executable " << exec << "to params contained in " << child; @@ -1148,10 +1215,10 @@ struct BitPatternToValue { cuGraphExecChildGraphNodeSetParams(exec, node, child), "Failed to set CUDA graph child node params"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::LaunchKernel( +/* static */ absl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, @@ -1163,6 +1230,10 @@ struct BitPatternToValue { << " bdx: " << block_dim_x << " bdy: " << block_dim_y << " bdz: " << block_dim_z << "; shared_mem_bytes: " << shared_mem_bytes; + + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { RETURN_IF_CUDA_RES_ERROR( cuFuncSetAttribute(function, @@ -1170,32 +1241,93 @@ struct BitPatternToValue { shared_mem_bytes), "Failed to set shared memory size"); } + RETURN_IF_CUDA_RES_ERROR( cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem_bytes, stream, kernel_params, extra), "Failed to launch CUDA kernel: ", kernel_name, - " with block dimensions: ", block_dim_x, "x", block_dim_y, "x", - block_dim_z, " and grid dimensions: ", grid_dim_x, "x", grid_dim_y, "x", - grid_dim_z, " and shared memory size: ", shared_mem_bytes); - return ::tsl::OkStatus(); + "; block dims: ", block_dim_x, "x", block_dim_y, "x", block_dim_z, + "; grid dims: ", grid_dim_x, "x", grid_dim_y, "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes); + + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::LaunchKernel( + GpuContext* context, absl::string_view kernel_name, + GpuFunctionHandle function, unsigned int cluster_dim_x, + unsigned int cluster_dim_y, unsigned int cluster_dim_z, + unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, + GpuStreamHandle stream, void** kernel_params, void** extra) { + ScopedActivateContext activation(context); + VLOG(2) << "launching kernel: " << kernel_name << "; cdx: " << cluster_dim_x + << " cdy: " << cluster_dim_y << " cdz: " << cluster_dim_z + << " gdx: " << grid_dim_x << " gdy: " << grid_dim_y + << " gdz: " << grid_dim_z << " bdx: " << block_dim_x + << " bdy: " << block_dim_y << " bdz: " << block_dim_z + << "; shared_mem_bytes: " << shared_mem_bytes; + + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). + if (shared_mem_bytes != 0) { + RETURN_IF_CUDA_RES_ERROR( + cuFuncSetAttribute(function, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_mem_bytes), + "Failed to set shared memory size"); + } + + CUlaunchConfig launch_config; + memset(&launch_config, 0, sizeof(launch_config)); + launch_config.blockDimX = block_dim_x; + launch_config.blockDimY = block_dim_y; + launch_config.blockDimZ = block_dim_z; + launch_config.gridDimX = grid_dim_x; + launch_config.gridDimY = grid_dim_y; + launch_config.gridDimZ = grid_dim_z; + launch_config.hStream = stream; + launch_config.sharedMemBytes = shared_mem_bytes; + + CUlaunchAttribute cluster_dims; + memset(&cluster_dims, 0, sizeof(cluster_dims)); + cluster_dims.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + cluster_dims.value.clusterDim.x = cluster_dim_x; + cluster_dims.value.clusterDim.y = cluster_dim_y; + cluster_dims.value.clusterDim.z = cluster_dim_z; + + launch_config.attrs = &cluster_dims; + launch_config.numAttrs = 1; + + RETURN_IF_CUDA_RES_ERROR( + cuLaunchKernelEx(&launch_config, function, kernel_params, extra), + "Failed to launch CUDA kernel: ", kernel_name, + "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x", cluster_dim_z, + "; block dims: ", block_dim_x, "x", block_dim_y, "x", block_dim_z, + "; grid dims: ", grid_dim_x, "x", grid_dim_y, "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes); + + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::LoadCubin(GpuContext* context, - const char* cubin_bytes, - CUmodule* module) { +/* static */ absl::Status GpuDriver::LoadCubin(GpuContext* context, + const char* cubin_bytes, + CUmodule* module) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR( cuModuleLoadFatBinary(module, cubin_bytes), "Failed to load in-memory CUBIN (compiled for a different GPU?)."); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::LoadPtx(GpuContext* context, - const char* ptx_contents, - CUmodule* module) { +/* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context, + const char* ptx_contents, + CUmodule* module) { absl::Notification notification; - tsl::Status ret = ::tsl::OkStatus(); + absl::Status ret = absl::OkStatus(); GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret, ¬ification]() { ScopedActivateContext activation(context); @@ -1256,6 +1388,7 @@ struct BitPatternToValue { "Failed to load PTX text as a module: %s", ToString(res))); } notification.Notify(); + return; } VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes @@ -1270,50 +1403,48 @@ struct BitPatternToValue { return ret; } -/* static */ tsl::Status GpuDriver::LoadHsaco(GpuContext* context, - const char* hsaco_contents, - CUmodule* module) { - return tsl::errors::Internal( +/* static */ absl::Status GpuDriver::LoadHsaco(GpuContext* context, + const char* hsaco_contents, + CUmodule* module) { + return absl::InternalError( "Feature not supported on CUDA platform (LoadHsaco)"); } -/* static */ tsl::Status GpuDriver::SynchronousMemsetUint8(GpuContext* context, - CUdeviceptr location, - uint8_t value, - size_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemsetUint8( + GpuContext* context, CUdeviceptr location, uint8_t value, size_t size) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR(cuMemsetD8(location, value, size), "Failed to memset memory"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemsetUint32( +/* static */ absl::Status GpuDriver::SynchronousMemsetUint32( GpuContext* context, CUdeviceptr location, uint32_t value, size_t uint32_count) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR(cuMemsetD32(location, value, uint32_count), "Failed to memset memory"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::AsynchronousMemsetUint8( +/* static */ absl::Status GpuDriver::AsynchronousMemsetUint8( GpuContext* context, CUdeviceptr location, uint8_t value, size_t uint32_count, CUstream stream) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR( cuMemsetD8Async(location, value, uint32_count, stream), "Failed to enqueue async memset operation"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::AsynchronousMemsetUint32( +/* static */ absl::Status GpuDriver::AsynchronousMemsetUint32( GpuContext* context, CUdeviceptr location, uint32_t value, size_t uint32_count, CUstream stream) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR( cuMemsetD32Async(location, value, uint32_count, stream), "Failed to enqueue async memset operation"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::AddStreamCallback(GpuContext* context, @@ -1329,23 +1460,22 @@ struct BitPatternToValue { return true; } -/* static */ tsl::Status GpuDriver::GetModuleFunction(GpuContext* context, - CUmodule module, - const char* kernel_name, - CUfunction* function) { +/* static */ absl::Status GpuDriver::GetModuleFunction(GpuContext* context, + CUmodule module, + const char* kernel_name, + CUfunction* function) { ScopedActivateContext activated{context}; CHECK(module != nullptr && kernel_name != nullptr); cudaError_t cuda_error = cudaPeekAtLastError(); if (cuda_error != cudaSuccess) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrCat("There was an error before calling cuModuleGetFunction (", cuda_error, "): ", cudaGetErrorName(cuda_error), " : ", cudaGetErrorString(cuda_error))); } RETURN_IF_CUDA_RES_ERROR(cuModuleGetFunction(function, module, kernel_name), "Failed to get module function"); - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::GetModuleSymbol(GpuContext* context, @@ -1377,7 +1507,7 @@ struct BitPatternToValue { } } -/* static */ tsl::StatusOr GpuDriver::DeviceFromContext( +/* static */ absl::StatusOr GpuDriver::DeviceFromContext( GpuContext* context) { ScopedActivateContext activated{context}; CUdevice device = -1; @@ -1386,8 +1516,7 @@ struct BitPatternToValue { return device; } - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrCat("failed to get device for context: ", ToString(result))); } @@ -1566,15 +1695,14 @@ struct BitPatternToValue { : lowest; } -#if CUDA_VERSION >= 10020 -/* static */ tsl::StatusOr GpuDriver::ReserveVirtualMemory( - GpuContext* context, uint64_t bytes) { +/* static */ absl::StatusOr +GpuDriver::ReserveVirtualMemory(GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); CUdeviceptr base; CUresult res = cuMemAddressReserve(&base, bytes, /*alignment=*/0, /*addr=*/0, /*flags=*/0); if (res != CUDA_SUCCESS) { - return tsl::errors::Internal( + return absl::InternalError( absl::StrFormat("error reserving %d bytes of virtual GPU memory: %s", bytes, ToString(res))); } @@ -1591,7 +1719,7 @@ struct BitPatternToValue { } } -/* static */ tsl::StatusOr GpuDriver::GetMinAllocationGranularity( +/* static */ absl::StatusOr GpuDriver::GetMinAllocationGranularity( GpuDeviceHandle device) { CUmemAllocationProp props = {}; props.type = CU_MEM_ALLOCATION_TYPE_PINNED; @@ -1602,13 +1730,13 @@ struct BitPatternToValue { CUresult res = cuMemGetAllocationGranularity( &granularity, &props, CU_MEM_ALLOC_GRANULARITY_MINIMUM); if (res != CUDA_SUCCESS) { - return tsl::errors::Internal("failed to get min allocation granularity: ", - ToString(res)); + return absl::InternalError(absl::StrCat( + "failed to get min allocation granularity: ", ToString(res))); } return granularity; } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); auto device = DeviceFromContext(context); @@ -1625,7 +1753,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CUmemGenericAllocationHandle mem_handle; CUresult res = cuMemCreate(&mem_handle, bytes, &props, 0); if (res != CUDA_SUCCESS) { - return tsl::errors::Internal( + return absl::InternalError( absl::StrFormat("failed to create memory allocation of size %d: %s", bytes, ToString(res))); } @@ -1643,7 +1771,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { } } -/* static */ tsl::Status GpuDriver::MapMemory( +/* static */ absl::Status GpuDriver::MapMemory( GpuContext* context, CUdeviceptr va, const GpuDriver::GenericMemoryHandle& handle, const std::vector& device_handles) { @@ -1658,7 +1786,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CUresult res = cuMemMap(va, handle.bytes, /*offset=*/0, handle.handle, /*flags=*/0); if (res != CUDA_SUCCESS) { - return tsl::errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Failed to map %d bytes at %d: %s", handle.bytes, va, ToString(res))); } @@ -1677,11 +1805,11 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { LOG(ERROR) << "Failed to unmap memory in GpuDriver::MapMemory error path."; } - return tsl::errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Failed to set read/write access on memory mapped at %d: %s", va, ToString(res))); } - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ void GpuDriver::UnmapMemory(GpuContext* context, CUdeviceptr va, @@ -1695,37 +1823,33 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { } } -#endif - -/* static */ tsl::Status GpuDriver::DestroyEvent(GpuContext* context, - CUevent* event) { +/* static */ absl::Status GpuDriver::DestroyEvent(GpuContext* context, + CUevent* event) { if (*event == nullptr) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "input event cannot be null"); + return absl::InvalidArgumentError("input event cannot be null"); } ScopedActivateContext activated{context}; RETURN_IF_CUDA_RES_ERROR(cuEventDestroy(*event), "Error destroying CUDA event"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::RecordEvent(GpuContext* context, - CUevent event, - CUstream stream) { +/* static */ absl::Status GpuDriver::RecordEvent(GpuContext* context, + CUevent event, + CUstream stream) { ScopedActivateContext activated{context}; RETURN_IF_CUDA_RES_ERROR(cuEventRecord(event, stream), "Error recording CUDA event"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::QueryEvent(GpuContext* context, - CUevent event) { +/* static */ absl::StatusOr GpuDriver::QueryEvent(GpuContext* context, + CUevent event) { ScopedActivateContext activated{context}; CUresult res = cuEventQuery(event); if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrFormat("failed to query event: %s", ToString(res))); } @@ -1777,13 +1901,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ tsl::Status GpuDriver::SynchronizeStream(GpuContext* context, - CUstream stream) { +/* static */ absl::Status GpuDriver::SynchronizeStream(GpuContext* context, + CUstream stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); RETURN_IF_CUDA_RES_ERROR(cuStreamSynchronize(stream), "Could not synchronize CUDA stream"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::IsStreamIdle(GpuContext* context, @@ -1801,10 +1925,10 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return false; } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyD2H(GpuContext* context, - void* host_dst, - CUdeviceptr gpu_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyD2H(GpuContext* context, + void* host_dst, + CUdeviceptr gpu_src, + uint64_t size) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR( cuMemcpyDtoH(host_dst, gpu_src, size), @@ -1813,13 +1937,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { host_dst, absl::bit_cast(gpu_src), size, size)); VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " << host_dst; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyH2D(GpuContext* context, - CUdeviceptr gpu_dst, - const void* host_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyH2D(GpuContext* context, + CUdeviceptr gpu_dst, + const void* host_src, + uint64_t size) { ScopedActivateContext activation(context); RETURN_IF_CUDA_RES_ERROR( cuMemcpyHtoD(gpu_dst, host_src, size), @@ -1828,13 +1952,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { " host src: %p; size: %u=0x%x", absl::bit_cast(gpu_dst), host_src, size, size)); VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes"; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyD2D(GpuContext* context, - CUdeviceptr gpu_dst, - CUdeviceptr gpu_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyD2D(GpuContext* context, + CUdeviceptr gpu_dst, + CUdeviceptr gpu_src, + uint64_t size) { ScopedActivateContext activation(context); CUresult result; @@ -1850,14 +1974,14 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CreatedContexts::GetAnyContext(absl::bit_cast(gpu_src)); if (static_cast(dst_context) == nullptr) { - tsl::StatusOr tmp_context = GetPointerContext(gpu_dst); + absl::StatusOr tmp_context = GetPointerContext(gpu_dst); if (tmp_context.ok()) { dst_context = tmp_context.value()->context(); } } if (static_cast(src_context) == nullptr) { - tsl::StatusOr tmp_context = GetPointerContext(gpu_src); + absl::StatusOr tmp_context = GetPointerContext(gpu_src); if (tmp_context.ok()) { src_context = tmp_context.value()->context(); } @@ -1874,7 +1998,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { absl::bit_cast(gpu_dst), absl::bit_cast(gpu_src), size, size)); VLOG(2) << "successfully sync memcpy'd d2d of " << size << " bytes"; - return ::tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::AsynchronousMemcpyD2H(GpuContext* context, @@ -1927,7 +2051,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { // In graph capture mode we never have operations that access peer memory, so // we can always make a call to cuMemcpyDtoDAsync. - tsl::StatusOr is_capturing = StreamIsCapturing(stream); + absl::StatusOr is_capturing = StreamIsCapturing(stream); if (!is_capturing.ok()) { LOG(ERROR) << is_capturing.status().message(); return false; @@ -1945,14 +2069,14 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CreatedContexts::GetAnyContext(absl::bit_cast(gpu_src)); if (static_cast(dst_context) == nullptr) { - tsl::StatusOr tmp_context = GetPointerContext(gpu_dst); + absl::StatusOr tmp_context = GetPointerContext(gpu_dst); if (tmp_context.ok()) { dst_context = tmp_context.value()->context(); } } if (static_cast(src_context) == nullptr) { - tsl::StatusOr tmp_context = GetPointerContext(gpu_src); + absl::StatusOr tmp_context = GetPointerContext(gpu_src); if (tmp_context.ok()) { src_context = tmp_context.value()->context(); } @@ -1988,9 +2112,9 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ tsl::Status GpuDriver::InitEvent(GpuContext* context, - CUevent* result, - EventFlags flags) { +/* static */ absl::Status GpuDriver::InitEvent(GpuContext* context, + CUevent* result, + EventFlags flags) { int cuflags; switch (flags) { case EventFlags::kDefault: @@ -2007,13 +2131,12 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { CUresult res = cuEventCreate(result, cuflags); if (res == CUDA_SUCCESS) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else if (res == CUDA_ERROR_OUT_OF_MEMORY) { - return tsl::Status(absl::StatusCode::kResourceExhausted, - "could not create CUDA event: out of device memory"); + return absl::ResourceExhaustedError( + "could not create CUDA event: out of device memory"); } else { - return tsl::Status( - absl::StatusCode::kFailedPrecondition, + return absl::FailedPreconditionError( absl::StrCat("could not create CUDA event: ", ToString(res))); } } @@ -2026,13 +2149,10 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return 0; } - if (FLAGS_gpuexec_cuda_device_0_only && device_count > 1) { - device_count = 1; - } return device_count; } -/* static */ tsl::StatusOr GpuDriver::GetPointerContext( +/* static */ absl::StatusOr GpuDriver::GetPointerContext( CUdeviceptr pointer) { GpuContext* context = nullptr; CUresult result = @@ -2044,20 +2164,17 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { // handling. So all is working fine, but TF have a different // error then the original one. if (context == nullptr) { - return tsl::Status( - absl::StatusCode::kUnavailable, + return absl::UnavailableError( "Empty context returned while querying context for device pointer"); } return context; } - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrCat("failed to query context for device pointer: ", - ToString(result))); + return absl::InternalError(absl::StrCat( + "failed to query context for device pointer: ", ToString(result))); } -/* static */ tsl::StatusOr GpuDriver::GetPointerMemorySpace( +/* static */ absl::StatusOr GpuDriver::GetPointerMemorySpace( CUdeviceptr pointer) { unsigned int value; CUresult result = @@ -2069,41 +2186,36 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { case CU_MEMORYTYPE_HOST: return MemorySpace::kHost; default: - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrCat("unknown memory space provided by CUDA API: ", value)); } } - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrCat("failed to query device pointer for memory space: ", - ToString(result))); + return absl::InternalError(absl::StrCat( + "failed to query device pointer for memory space: ", ToString(result))); } -/* static */ tsl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr, - CUdeviceptr* base, - size_t* size) { +/* static */ absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr, + CUdeviceptr* base, + size_t* size) { CUresult result = cuMemGetAddressRange(base, size, dptr); if (result == CUDA_SUCCESS) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else if (result == CUDA_ERROR_NOT_FOUND) { // We differentiate between "this pointer is unknown" (return here) and // "there was an internal error while performing this operation" (return // below). - return tsl::Status( - absl::StatusCode::kNotFound, - absl::StrFormat("not a device pointer %p; %s", - reinterpret_cast(dptr), ToString(result))); + return absl::NotFoundError(absl::StrFormat("not a device pointer %p; %s", + reinterpret_cast(dptr), + ToString(result))); } - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrFormat("failed to get pointer into for device pointer %p; %s", reinterpret_cast(dptr), ToString(result))); } -/* static */ tsl::StatusOr GpuDriver::GetPointerDevice( +/* static */ absl::StatusOr GpuDriver::GetPointerDevice( CUdeviceptr pointer) { auto result = GetPointerContext(pointer); if (!result.ok()) { @@ -2113,44 +2225,40 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return DeviceFromContext(result.value()); } -/* static */ tsl::Status GpuDriver::GetComputeCapability(int* cc_major, - int* cc_minor, - CUdevice device) { +/* static */ absl::Status GpuDriver::GetComputeCapability(int* cc_major, + int* cc_minor, + CUdevice device) { *cc_major = 0; *cc_minor = 0; CUresult res = cuDeviceGetAttribute( cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); if (res != CUDA_SUCCESS) { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat( - "failed to get compute capability major for device: %s; %d", - ToString(res), device)); + return absl::InternalError(absl::StrFormat( + "failed to get compute capability major for device: %s; %d", + ToString(res), device)); } res = cuDeviceGetAttribute( cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); if (res != CUDA_SUCCESS) { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat( - "failed to get compute capability minor for device: %s; %d", - ToString(res), device)); + return absl::InternalError(absl::StrFormat( + "failed to get compute capability minor for device: %s; %d", + ToString(res), device)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GetGpuISAVersion(int* version, - CUdevice device) { - return tsl::Status{ +/* static */ absl::Status GpuDriver::GetGpuISAVersion(int* version, + CUdevice device) { + return absl::Status{ absl::StatusCode::kInternal, "Feature not supported on CUDA platform (GetGpuISAVersion)"}; } -/* static */ tsl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { - return tsl::Status{ +/* static */ absl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { + return absl::Status{ absl::StatusCode::kInternal, "Feature not supported on CUDA platform (GetGpuGCNArchName)"}; } @@ -2158,8 +2266,8 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { // Helper function that turns the integer output of cuDeviceGetAttribute to type // T and wraps it in a StatusOr. template -static tsl::StatusOr GetSimpleAttribute(CUdevice device, - CUdevice_attribute attribute) { +static absl::StatusOr GetSimpleAttribute(CUdevice device, + CUdevice_attribute attribute) { int value = -1; RETURN_IF_CUDA_RES_ERROR(cuDeviceGetAttribute(&value, attribute, device), "Could not retrieve CUDA device attribute (", @@ -2168,55 +2276,55 @@ static tsl::StatusOr GetSimpleAttribute(CUdevice device, return converted; } -/* static */ tsl::StatusOr GpuDriver::GetMultiprocessorCount( +/* static */ absl::StatusOr GpuDriver::GetMultiprocessorCount( CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); } -/* static */ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( +/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR); } -/* static */ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK); } -tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( +absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN); } -/* static */ tsl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( +/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR); } -/* static */ tsl::StatusOr GpuDriver::GetMaxThreadsPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerBlock( CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK); } -/* static */ tsl::StatusOr GpuDriver::GetMaxRegistersPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxRegistersPerBlock( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK); } -/* static */ tsl::StatusOr GpuDriver::GetThreadsPerWarp( +/* static */ absl::StatusOr GpuDriver::GetThreadsPerWarp( CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE); } -/* static */ tsl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, - CUdevice device) { +/* static */ absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, + CUdevice device) { int value; RETURN_IF_CUDA_RES_ERROR( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device), @@ -2232,17 +2340,14 @@ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device), "Could not get device attribute"); *z = value; - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ bool GpuDriver::GetDriverVersion(int* driver_version) { - CUresult res = cuDriverGetVersion(driver_version); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query driver version: " << ToString(res); - return false; - } - - return true; +/* static */ absl::StatusOr GpuDriver::GetDriverVersion() { + int32_t version; + RETURN_IF_CUDA_RES_ERROR(cuDriverGetVersion(&version), + "Could not get driver version"); + return version; } /* static */ bool GpuDriver::GetDeviceProperties(CUdevprop* device_properties, @@ -2256,13 +2361,12 @@ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return true; } -/* static */ tsl::StatusOr GpuDriver::GetDeviceAttribute( +/* static */ absl::StatusOr GpuDriver::GetDeviceAttribute( CUdevice_attribute attribute, CUdevice device) { int val; CUresult res = cuDeviceGetAttribute(&val, attribute, device); if (res != CUDA_SUCCESS) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrFormat("failed to get device attribute %d for device %d: %s", attribute, device, ToString(res))); } @@ -2301,7 +2405,7 @@ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( /* static */ bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) { - size_t value = -1; + size_t value{}; CUresult res = cuDeviceTotalMem(&value, device); if (res != CUDA_SUCCESS) { LOG(ERROR) << "failed to query total available memory: " << ToString(res); @@ -2359,35 +2463,35 @@ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return can_access_peer; } -/* static */ tsl::Status GpuDriver::EnablePeerAccess(GpuContext* from, - GpuContext* to) { +/* static */ absl::Status GpuDriver::EnablePeerAccess(GpuContext* from, + GpuContext* to) { if (from == to) { - return ::tsl::OkStatus(); // A context can always access its own - // memory. + return absl::OkStatus(); // A context can always access its own + // memory. } ScopedActivateContext activated{from}; CUresult result = cuCtxEnablePeerAccess(to->context(), 0 /* = flags */); if (result != CUDA_SUCCESS && result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrFormat("failed to enable peer access from %p to %p: %s", from, to, ToString(result))); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( +/* static */ absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( GpuContext* context, CUfunction kernel, int threads_per_block, size_t dynamic_shared_memory_bytes) { ScopedActivateContext activation(context); int max_blocks; RETURN_IF_CUDA_RES_ERROR( - cuOccupancyMaxActiveBlocksPerMultiprocessor( - &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes), + cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes, + CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE), absl::StrFormat("Failed to calculate occupancy of kernel %p", kernel)); return max_blocks; } diff --git a/xla/stream_executor/cuda/cuda_driver.h b/xla/stream_executor/cuda/cuda_driver.h index 7f3314fbff585..a72740ef4ead8 100644 --- a/xla/stream_executor/cuda/cuda_driver.h +++ b/xla/stream_executor/cuda/cuda_driver.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,19 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ +#include +#include +#include +#include +#include +#include + #include "absl/container/node_hash_map.h" -#include "absl/memory/memory.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/gpu/gpu_driver.h" namespace stream_executor { @@ -109,9 +118,8 @@ class CreatedContexts { } } - // Return the context associated to that ptr. - static CUcontext GetAnyContext(void* ptr) { - absl::ReaderMutexLock lock(&mu_); + // Find device id from cuda pointer value. + static int GetDeviceOrdinal(void* ptr) { int device_ordinal; CUresult result = cuPointerGetAttribute(static_cast(&device_ordinal), CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, @@ -120,6 +128,13 @@ class CreatedContexts { LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr << ". Error: " << ToString(result); } + return device_ordinal; + } + + // Return the context associated to that ptr. + static CUcontext GetAnyContext(void* ptr) { + absl::ReaderMutexLock lock(&mu_); + int device_ordinal = GetDeviceOrdinal(ptr); CHECK_EQ(LiveOrdinal()->count(device_ordinal), 1); CHECK(!LiveOrdinal()->at(device_ordinal).empty()) << "Need at least one context."; diff --git a/xla/stream_executor/cuda/cuda_driver_test.cc b/xla/stream_executor/cuda/cuda_driver_test.cc index 90c9f708e23cc..da4d78118f51d 100644 --- a/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/xla/stream_executor/cuda/cuda_driver_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/cuda/cuda_driver.h" -#include "absl/memory/memory.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tsl/platform/test.h" @@ -72,5 +74,3 @@ TEST(CudaDriverTest, ScopedActivateContextTest) { } // namespace gpu } // namespace stream_executor - -#endif // GOOGLE_CUDA diff --git a/xla/stream_executor/cuda/cuda_event.cc b/xla/stream_executor/cuda/cuda_event.cc index 1a61f500af285..f42cf47e86e1a 100644 --- a/xla/stream_executor/cuda/cuda_event.cc +++ b/xla/stream_executor/cuda/cuda_event.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_event.h" - -#include "xla/stream_executor/cuda/cuda_stream.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" -#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { Event::Status GpuEvent::PollForStatus() { - tsl::StatusOr status = + absl::StatusOr status = GpuDriver::QueryEvent(parent_->gpu_context(), gpu_event_); if (!status.ok()) { LOG(ERROR) << "Error polling for event status: " diff --git a/xla/stream_executor/cuda/cuda_event.h b/xla/stream_executor/cuda/cuda_event.h index b4aa70392a46c..d806b9f992e6f 100644 --- a/xla/stream_executor/cuda/cuda_event.h +++ b/xla/stream_executor/cuda/cuda_event.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 500ac1789e495..d0b239cf06eb3 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include +#include +#include +#include #include +#include + +#include "absl/base/casts.h" +#include "absl/numeric/int128.h" +#include "absl/strings/str_join.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #if defined(PLATFORM_WINDOWS) #include @@ -23,28 +42,36 @@ limitations under the License. #else #include #endif + #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_runtime.h" +#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/stream.h" @@ -52,8 +79,8 @@ limitations under the License. #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" // LOG(ERROR) uses a const named ERROR, so a macro with the same name is @@ -109,17 +136,16 @@ GpuExecutor::~GpuExecutor() { } } -tsl::Status GpuExecutor::Init(int device_ordinal, - DeviceOptions device_options) { +absl::Status GpuExecutor::Init(int device_ordinal) { device_ordinal_ = device_ordinal; TF_RETURN_IF_ERROR(GpuDriver::Init()); TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal_, &device_)); - TF_RETURN_IF_ERROR(GpuDriver::CreateContext(device_ordinal_, device_, - device_options, &context_)); + TF_RETURN_IF_ERROR( + GpuDriver::CreateContext(device_ordinal_, device_, &context_)); TF_RETURN_IF_ERROR( GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_)); - return tsl::OkStatus(); + return absl::OkStatus(); } // Returns the path to the running executable. @@ -138,8 +164,8 @@ static std::string GetBinaryDir(bool strip_exe) { return exe_path; } -tsl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, - CUmodule* module) { +absl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, + CUmodule* module) { uint64_t module_refcount; std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin]; @@ -154,10 +180,10 @@ tsl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, << " is already loaded as module " << *module; } gpu_binary_to_module_[cubin] = {*module, module_refcount}; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, CUmodule* module) { +absl::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, CUmodule* module) { uint64_t module_refcount; std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx]; @@ -172,17 +198,17 @@ tsl::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, CUmodule* module) { << " is already loaded as module " << module; } gpu_binary_to_module_[ptx] = {*module, module_refcount}; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, - CUmodule* module) { - return tsl::errors::Internal( +absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, + CUmodule* module) { + return absl::InternalError( "Feature not supported on CUDA platform (LoadModuleFromHsaco)"); } -tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { +absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) { GpuKernel* cuda_kernel = AsGpuKernel(kernel); CUmodule module; const std::string* kernel_name; @@ -192,7 +218,8 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{&in_memory_modules_mu_}; kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); - const char* cubin = spec.cuda_cubin_in_memory().bytes(); + const char* cubin = reinterpret_cast( + spec.cuda_cubin_in_memory().cubin_bytes().data()); TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module)); kernel_to_gpu_binary_[kernel] = cubin; @@ -200,7 +227,7 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, kernel_name = &spec.cuda_ptx_in_memory().kernel_name(); if (cc_major_ == 0 && cc_minor_ == 0) { - return tsl::errors::Internal("Compute capability not set"); + return absl::InternalError("Compute capability not set"); } const char* ptx = spec.cuda_ptx_in_memory().text(cc_major_, cc_minor_); @@ -227,7 +254,7 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, *cuda_kernel->gpu_function_ptr() = function; } else { - return tsl::errors::Internal("No method of loading CUDA kernel provided"); + return absl::InternalError("No method of loading CUDA kernel provided"); } // If we resolved kernel from a symbol pointer, there is no need to load it @@ -239,6 +266,10 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, cuda_kernel->gpu_function_ptr())); } + // Update CUDA kernel properties after it was loaded in the CUDA context. + cuda_kernel->set_name(*kernel_name); + cuda_kernel->set_gpu_context(context_); + // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); @@ -247,8 +278,8 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel, &kernel_metadata)); kernel->set_metadata(kernel_metadata); kernel->set_name(*kernel_name); - kernel->set_kernel_args_packing(spec.kernel_args_packing()); - return ::tsl::OkStatus(); + kernel->set_args_packing(spec.kernel_args_packing()); + return absl::OkStatus(); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { @@ -286,8 +317,8 @@ void GpuExecutor::UnloadKernel(const Kernel* kernel) { kernel_to_gpu_binary_.erase(gpu_binary_it); } -tsl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { +absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle) { // In GpuExecutor we store the pointer to the GPU binary (PTX or CUBIN) as // ModuleHandle::id(). CUmodule cu_module; @@ -298,14 +329,14 @@ tsl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, &cu_module)); *module_handle = ModuleHandle(const_cast( static_cast(spec.cuda_cubin_in_memory().data()))); - return ::tsl::OkStatus(); + return absl::OkStatus(); } else if (spec.has_cuda_ptx_in_memory()) { if (cc_major_ == 0 && cc_minor_ == 0) { - return tsl::errors::Internal("Compute capability not set"); + return absl::InternalError("Compute capability not set"); } if (!spec.cuda_ptx_in_memory()) { - return tsl::errors::Internal("PTX not found in spec"); + return absl::InternalError("PTX not found in spec"); } absl::MutexLock lock{&in_memory_modules_mu_}; @@ -313,9 +344,9 @@ tsl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)); *module_handle = ModuleHandle( const_cast(static_cast(spec.cuda_ptx_in_memory()))); - return ::tsl::OkStatus(); + return absl::OkStatus(); } - return tsl::errors::Internal("No method of loading CUDA module provided"); + return absl::InternalError("No method of loading CUDA module provided"); } bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { @@ -345,7 +376,7 @@ int fpus_per_core(int cc_major, int cc_minor) { } // namespace -tsl::StatusOr> +absl::StatusOr> GpuExecutor::CreateOrShareConstant(Stream* stream, absl::Span content) { absl::MutexLock lock{&shared_constants_mu_}; @@ -372,16 +403,16 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, DeviceMemoryBase* new_constant = new DeviceMemoryBase(Allocate(content.size(), /*memory_space=*/0)); if (new_constant->opaque() == nullptr) { - return tsl::errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Failed to allocate %d bytes for new constant", content.size())); } - tsl::Status status = - stream->ThenMemcpy(new_constant, content.data(), content.size()) - .BlockHostUntilDone(); + TF_RETURN_IF_ERROR( + stream->Memcpy(new_constant, content.data(), content.size())); + absl::Status status = stream->BlockHostUntilDone(); if (!status.ok()) { Deallocate(new_constant); - status.Update(tsl::errors::Internal(absl::StrFormat( + status.Update(absl::InternalError(absl::StrFormat( "Memcpy to device address %p failed", new_constant->opaque()))); return status; } @@ -399,8 +430,8 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, return shared_constant; } -tsl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, - KernelMetadata* kernel_metadata) { +absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, + KernelMetadata* kernel_metadata) { int value; TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( CU_FUNC_ATTRIBUTE_NUM_REGS, *cuda_kernel->gpu_function_ptr(), &value)); @@ -410,12 +441,27 @@ tsl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, GpuDriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cuda_kernel->gpu_function_ptr(), &value)); kernel_metadata->set_shared_memory_bytes(value); - return ::tsl::OkStatus(); + return absl::OkStatus(); +} + +absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) { + return Launch(stream, thread_dims, block_dims, std::nullopt, kernel, args); +} + +absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + return Launch(stream, thread_dims, block_dims, + std::make_optional(cluster_dims), kernel, args); } -tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { +absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { CUstream custream = AsGpuStreamValue(stream); const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); @@ -426,29 +472,46 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, if (VLOG_IS_ON(2)) { absl::MutexLock lock(&launched_kernels_mu_); if (!launched_kernels_.count(cufunc)) { - VlogOccupancyInfo(kernel, thread_dims, block_dims); + VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, + thread_dims, block_dims); // TODO(rspringer): Remove elements from launched_kernels_...if we ever // expose a kernel/module deallocation method. launched_kernels_.insert(cufunc); } } - if (cuda_kernel->GetPreferredCacheConfig() != - KernelCacheConfig::kNoPreference) { + if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( cufunc, cuda_kernel->GetGpuCacheConfig())); } // Launch CUDA kernels with packed arguments. auto launch = [&](const KernelArgsPackedArrayBase& packed) { - CHECK_EQ(kernel.Arity() + (packed.number_of_shared_bytes() > 0), - packed.number_of_arguments()); + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + void** params = const_cast(packed.argument_addresses().data()); - return GpuDriver::LaunchKernel(context_, kernel.name(), cufunc, - block_dims.x, block_dims.y, block_dims.z, - thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), custream, - params, nullptr /* = extra */); + + if (cluster_dims.has_value()) { + return GpuDriver::LaunchKernel( + context_, kernel.name(), cufunc, cluster_dims->x, cluster_dims->y, + cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), custream, params, + /*extra=*/nullptr); + } else { + return GpuDriver::LaunchKernel( + context_, kernel.name(), cufunc, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), custream, params, + /*extra=*/nullptr); + } }; // If arguments are already packed we can just launch the kernel. @@ -458,36 +521,38 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, // For device memory array we rely on a custom kernel arguments packing. if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.kernel_args_packing(); - if (!pack) + auto& pack = kernel.args_packing(); + if (!pack) { return absl::InternalError( "Kernel is missing a custom arguments packing function for device " "memory arguments array"); + } - TF_ASSIGN_OR_RETURN(auto packed, pack(*device_mem)); + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); return launch(*packed); } return absl::InternalError("Unsupported kernel arguments type"); } -tsl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { +absl::Status GpuExecutor::Submit(Stream* stream, + const CommandBuffer& command_buffer) { if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { return absl::InvalidArgumentError( "Can't submit non-primary command buffer for execution"); } auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer execuable graph " << exec - << " on a stream: " << stream->DebugStreamPointers(); + VLOG(3) << "Launch command buffer executable graph " << exec + << " on a stream: " << stream; return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); } // This is a non-essential operation; if there's a failure, proceed without // logging an error. It's nearly certain that in case of failures, we'd never // get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, +void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, const ThreadDim& thread_dims, const BlockDim& block_dims) { VLOG(2) << "Computing kernel occupancy for kernel " @@ -502,9 +567,6 @@ void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, return; } - const DeviceDescription& device_description = - kernel.parent()->GetDeviceDescription(); - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); @@ -563,16 +625,20 @@ int GpuExecutor::CompareOccupancy(int* initial_blocks, } DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { + if (memory_space == 1) { + auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); + if (!result.ok()) { + LOG(ERROR) << result.status(); + } + return DeviceMemoryBase(*result, size); + } else if (memory_space == + static_cast(stream_executor::MemoryType::kHost)) { + return DeviceMemoryBase(GpuDriver::HostAllocate(context_, size), size); + } CHECK_EQ(memory_space, 0); return DeviceMemoryBase(GpuDriver::DeviceAllocate(context_, size), size); } -void* GpuExecutor::GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) { - // offset and size are in bytes, so char* works as the pointer type. - return reinterpret_cast(mem->opaque()) + offset_bytes; -} - void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } @@ -595,8 +661,8 @@ bool GpuExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(context_); } -tsl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) { +absl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) { if (reinterpret_cast(location->opaque()) % 4 == 0 && size % 4 == 0) { return GpuDriver::SynchronousMemsetUint32( @@ -606,8 +672,8 @@ tsl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location, 0x0, size); } -tsl::Status GpuExecutor::SynchronousMemSet(DeviceMemoryBase* location, - int value, uint64_t size) { +absl::Status GpuExecutor::SynchronousMemSet(DeviceMemoryBase* location, + int value, uint64_t size) { if (reinterpret_cast(location->opaque()) % 4 == 0 && size % 4 == 0) { // cudaMemset reinterprets "value" as a uint8_t. @@ -621,28 +687,28 @@ tsl::Status GpuExecutor::SynchronousMemSet(DeviceMemoryBase* location, value, size); } -tsl::Status GpuExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, - uint64_t size) { +absl::Status GpuExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, + uint64_t size) { return GpuDriver::SynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst), host_src, size); } -tsl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { +absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { return GpuDriver::SynchronousMemcpyD2H(context_, host_dst, AsCudaDevicePtr(gpu_src), size); } -tsl::Status GpuExecutor::SynchronousMemcpyDeviceToDevice( +absl::Status GpuExecutor::SynchronousMemcpyDeviceToDevice( DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { return GpuDriver::SynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst), AsCudaDevicePtr(gpu_src), size); } -tsl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { +absl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) { if (reinterpret_cast(location->opaque()) % 4 == 0 && size % 4 == 0) { return Memset32(stream, location, 0x0, size); @@ -651,8 +717,8 @@ tsl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, } } -tsl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) { +absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, + uint8_t pattern, uint64_t size) { VLOG(2) << "enqueueing memset8 operation onto stream " << stream << " at location " << location << " with size " << size << " and pattern " << std::hex << pattern; @@ -661,8 +727,8 @@ tsl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -tsl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) { +absl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) { VLOG(2) << "enqueueing memset32 operation onto stream " << stream << " at location " << location << " with size " << size << " and pattern " << std::hex << pattern; @@ -673,18 +739,29 @@ tsl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -bool GpuExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, - AsCudaDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); +absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, + AsCudaDevicePtr(gpu_src), size, + AsGpuStreamValue(stream)); + // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return Status. + if (!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); } -bool GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - return GpuDriver::AsynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst), - host_src, size, - AsGpuStreamValue(stream)); +absl::Status GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst), + host_src, size, + AsGpuStreamValue(stream)); + // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return Status. + if (!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); } bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, @@ -697,10 +774,10 @@ bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, } bool GpuExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { + absl::AnyInvocable callback) { auto callback_ptr = new absl::AnyInvocable([cb = std::move(callback)]() mutable { - tsl::Status s = std::move(cb)(); + absl::Status s = std::move(cb)(); if (!s.ok()) { LOG(WARNING) << "Host callback failed: " << s; } @@ -715,39 +792,37 @@ bool GpuExecutor::HostCallback(Stream* stream, delete callback; } -tsl::Status GpuExecutor::AllocateEvent(Event* event) { +absl::Status GpuExecutor::AllocateEvent(Event* event) { return AsGpuEvent(event)->Init(); } -tsl::Status GpuExecutor::DeallocateEvent(Event* event) { +absl::Status GpuExecutor::DeallocateEvent(Event* event) { return AsGpuEvent(event)->Destroy(); } -tsl::Status GpuExecutor::RecordEvent(Stream* stream, Event* event) { +absl::Status GpuExecutor::RecordEvent(Stream* stream, Event* event) { return AsGpuEvent(event)->Record(AsGpuStream(stream)); } -tsl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { +absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { if (GpuDriver::WaitStreamOnEvent(context_, AsGpuStream(stream)->gpu_stream(), AsGpuEvent(event)->gpu_event())) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat("error recording waiting for CUDA event on stream %p", - stream)); + return absl::InternalError(absl::StrFormat( + "error recording waiting for CUDA event on stream %p", stream)); } } -tsl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { +absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, + Event* event) { if (GpuDriver::WaitStreamOnEvent(context_, absl::bit_cast(stream), AsGpuEvent(event)->gpu_event())) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else { - return tsl::Status(absl::StatusCode::kInternal, - "error waiting for CUDA event on external stream"); + return absl::InternalError( + "error waiting for CUDA event on external stream"); } } @@ -787,13 +862,13 @@ bool GpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { other_completed_event); } -tsl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { +absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { return GpuDriver::SynchronizeStream(context_, AsGpuStreamValue(stream)); } blas::BlasSupport* GpuExecutor::CreateBlas() { PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = + absl::StatusOr status = registry->GetFactory(cuda::kCudaPlatformId); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve BLAS factory: " @@ -806,7 +881,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() { dnn::DnnSupport* GpuExecutor::CreateDnn() { PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = + absl::StatusOr status = registry->GetFactory(cuda::kCudaPlatformId); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve DNN factory: " @@ -819,7 +894,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() { fft::FftSupport* GpuExecutor::CreateFft() { PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = + absl::StatusOr status = registry->GetFactory(cuda::kCudaPlatformId); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve FFT factory: " @@ -835,7 +910,7 @@ bool GpuExecutor::CanEnablePeerAccessTo(StreamExecutorInterface* other) { return GpuDriver::CanEnablePeerAccess(context_, cuda_other->context_); } -tsl::Status GpuExecutor::EnablePeerAccessTo(StreamExecutorInterface* other) { +absl::Status GpuExecutor::EnablePeerAccessTo(StreamExecutorInterface* other) { GpuExecutor* cuda_other = static_cast(other); return GpuDriver::EnablePeerAccess(context_, cuda_other->context_); } @@ -867,8 +942,8 @@ bool GpuExecutor::GetSymbol(const std::string& symbol_name, return false; } -tsl::Status FillBlockDimLimit(GpuDeviceHandle device, - BlockDim* block_dim_limit) { +absl::Status FillBlockDimLimit(GpuDeviceHandle device, + BlockDim* block_dim_limit) { // The BlockDim name is a mismatch against these GRID_DIM_* queries because // we use BlockDims to express the dimensions of blocks within a grid // (as opposed to ThreadDim which expresses the dimensions of threads @@ -878,7 +953,7 @@ tsl::Status FillBlockDimLimit(GpuDeviceHandle device, block_dim_limit->x = x; block_dim_limit->y = y; block_dim_limit->z = z; - return tsl::OkStatus(); + return absl::OkStatus(); } std::unique_ptr @@ -886,36 +961,31 @@ GpuExecutor::CreateEventImplementation() { return std::unique_ptr(new GpuEvent(this)); } -std::unique_ptr -GpuExecutor::CreateKernelImplementation() { - return std::unique_ptr(new GpuKernel()); -} - std::unique_ptr GpuExecutor::GetStreamImplementation() { return std::unique_ptr(new GpuStream(this)); } -tsl::StatusOr> -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode) { +absl::StatusOr> GpuExecutor::CreateKernel() { + return std::make_unique(this); +} + +absl::StatusOr> GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode) { VLOG(2) << "Create CUDA command buffer (CUDA graph)"; GpuGraphHandle graph = nullptr; TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); return std::make_unique(mode, /*parent=*/this, graph); } -std::unique_ptr -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode, - GpuGraphHandle graph, - bool is_owned_graph) { +std::unique_ptr GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph) { VLOG(2) << "Create CUDA command buffer (CUDA graph) from existing graph " << graph << "; is_owned_graph=" << is_owned_graph; return std::make_unique(mode, /*parent=*/this, graph, is_owned_graph); } -void* GpuExecutor::platform_specific_context() { return context_; } - GpuContext* GpuExecutor::gpu_context() { return context_; } // Attempts to read the NUMA node corresponding to the GPU device's PCI bus out @@ -982,7 +1052,7 @@ static int TryToReadNumaNode(const std::string& pci_bus_id, #endif } -tsl::StatusOr> +absl::StatusOr> GpuExecutor::CreateDeviceDescription(int device_ordinal) { GpuDeviceHandle device; TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal, &device)); @@ -995,8 +1065,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { internal::DeviceDescriptionBuilder builder; { - int driver_version = 0; - (void)GpuDriver::GetDriverVersion(&driver_version); + int driver_version = GpuDriver::GetDriverVersion().value_or(0); std::string augmented_driver_version = absl::StrFormat( "%d (%s)", driver_version, cuda::DriverVersionStatusToString(Diagnostician::FindDsoVersion())); @@ -1054,9 +1123,9 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { .value(); builder.set_l2_cache_size(l2_cache_bytes); - tsl::StatusOr mem_clock_khz = GpuDriver::GetDeviceAttribute( + absl::StatusOr mem_clock_khz = GpuDriver::GetDeviceAttribute( CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device_ordinal); - tsl::StatusOr mem_bus_width_bits = GpuDriver::GetDeviceAttribute( + absl::StatusOr mem_bus_width_bits = GpuDriver::GetDeviceAttribute( CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device_ordinal); if (mem_clock_khz.ok() && mem_bus_width_bits.ok()) { // Times 2 because HBM is DDR memory; it gets two data bits per each data @@ -1127,7 +1196,4 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } } // namespace gpu - } // namespace stream_executor - -REGISTER_MODULE_INITIALIZER(cuda_executor, {}); diff --git a/xla/stream_executor/cuda/cuda_fft.cc b/xla/stream_executor/cuda/cuda_fft.cc index 1d84b9a1819e1..60408566b087e 100644 --- a/xla/stream_executor/cuda/cuda_fft.cc +++ b/xla/stream_executor/cuda/cuda_fft.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,23 +17,31 @@ limitations under the License. #include #include +#include #include +#include +#include +#include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/cuda/cuda_activation.h" -#include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" -#include "xla/stream_executor/cuda/cuda_stream.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_helpers.h" +#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" +#include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -75,13 +83,13 @@ bool SetStream(GpuExecutor *parent, cufftHandle plan, Stream *stream) { // Populates array of 32b integers from 64b integers, or an error if the // numbers don't fit in 32b (signed). -tsl::StatusOr> Downsize64bArray( +absl::StatusOr> Downsize64bArray( std::array source, int32_t rank) { // NOLINT std::array downsized = {0}; for (int32_t i = 0; i < rank; ++i) { if (source[i] > std::numeric_limits::max()) { - return tsl::errors::InvalidArgument( - source[i], " exceeds max 32b signed integer. Conversion failed."); + return absl::InvalidArgumentError(absl::StrCat( + source[i], " exceeds max 32b signed integer. Conversion failed.")); } downsized[i] = static_cast(source[i]); } @@ -90,13 +98,13 @@ tsl::StatusOr> Downsize64bArray( } // namespace -tsl::Status CUDAFftPlan::Initialize( +absl::Status CUDAFftPlan::Initialize( GpuExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, uint64 input_stride, uint64 input_distance, uint64_t *output_embed, uint64 output_stride, uint64 output_distance, fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) { if (IsInitialized()) { - return tsl::errors::Internal("cuFFT is already initialized."); + return absl::InternalError("cuFFT is already initialized."); } is_initialized_ = true; scratch_allocator_ = scratch_allocator; @@ -127,44 +135,44 @@ tsl::Status CUDAFftPlan::Initialize( 1 /* = batch */); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT 1d plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT 1d plan."); + return absl::InternalError("Failed to create cuFFT 1d plan."); } - return ::tsl::OkStatus(); + return absl::OkStatus(); case 2: // cufftPlan2d ret = cufftPlan2d(&plan_, elem_count_[0], elem_count_[1], CUDAFftType(type)); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT 2d plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT 2d plan."); + return absl::InternalError("Failed to create cuFFT 2d plan."); } - return ::tsl::OkStatus(); + return absl::OkStatus(); case 3: // cufftPlan3d ret = cufftPlan3d(&plan_, elem_count_[0], elem_count_[1], elem_count_[2], CUDAFftType(type)); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT 3d plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT 3d plan."); + return absl::InternalError("Failed to create cuFFT 3d plan."); } - return ::tsl::OkStatus(); + return absl::OkStatus(); default: LOG(ERROR) << "Invalid rank value for cufftPlan. " "Requested 1, 2, or 3, given: " << rank; - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError( "cufftPlan only takes rank 1, 2, or 3."); } } else { ret = cufftCreate(&plan_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT plan."); + return absl::InternalError("Failed to create cuFFT plan."); } ret = cufftSetAutoAllocation(plan_, 0); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to set auto allocation for cuFFT plan: " << ret; - return tsl::errors::Internal( + return absl::InternalError( "Failed to set auto allocation for cuFFT plan."); } switch (rank) { @@ -173,7 +181,7 @@ tsl::Status CUDAFftPlan::Initialize( /*batch=*/1, &scratch_size_bytes_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to make cuFFT 1d plan: " << ret; - return tsl::errors::Internal("Failed to make cuFFT 1d plan."); + return absl::InternalError("Failed to make cuFFT 1d plan."); } break; case 2: @@ -181,7 +189,7 @@ tsl::Status CUDAFftPlan::Initialize( CUDAFftType(type), &scratch_size_bytes_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to make cuFFT 2d plan: " << ret; - return tsl::errors::Internal("Failed to make cuFFT 2d plan."); + return absl::InternalError("Failed to make cuFFT 2d plan."); } break; case 3: @@ -190,14 +198,14 @@ tsl::Status CUDAFftPlan::Initialize( &scratch_size_bytes_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to make cuFFT 3d plan: " << ret; - return tsl::errors::Internal("Failed to make cuFFT 3d plan."); + return absl::InternalError("Failed to make cuFFT 3d plan."); } break; default: LOG(ERROR) << "Invalid rank value for cufftPlan. " "Requested 1, 2, or 3, given: " << rank; - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError( "cufftPlan only takes rank 1, 2, or 3."); } return UpdateScratchAllocator(stream, scratch_allocator); @@ -219,19 +227,19 @@ tsl::Status CUDAFftPlan::Initialize( output_stride, output_distance, CUDAFftType(type), batch_count); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT batched plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT batched plan."); + return absl::InternalError("Failed to create cuFFT batched plan."); } } else { auto ret = cufftCreate(&plan_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to create cuFFT batched plan: " << ret; - return tsl::errors::Internal("Failed to create cuFFT batched plan."); + return absl::InternalError("Failed to create cuFFT batched plan."); } ret = cufftSetAutoAllocation(plan_, 0); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to set auto allocation for cuFFT batched plan: " << ret; - return tsl::errors::Internal( + return absl::InternalError( "Failed to set auto allocation for cuFFT batched plan."); } ret = cufftMakePlanMany64( @@ -242,18 +250,18 @@ tsl::Status CUDAFftPlan::Initialize( &scratch_size_bytes_); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to make cuFFT batched plan: " << ret; - return tsl::errors::Internal("Failed to make cuFFT batched plan."); + return absl::InternalError("Failed to make cuFFT batched plan."); } return UpdateScratchAllocator(stream, scratch_allocator); } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream, - int rank, uint64_t *elem_count, - fft::Type type, - ScratchAllocator *scratch_allocator) { +absl::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream, + int rank, uint64_t *elem_count, + fft::Type type, + ScratchAllocator *scratch_allocator) { return Initialize(parent_, stream, rank, elem_count, /*input_embed=*/nullptr, /*input_stride=*/0, /*input_distance=*/0, @@ -261,7 +269,7 @@ tsl::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream, /*output_distance=*/0, type, 1, scratch_allocator); } -tsl::Status CUDAFftPlan::UpdateScratchAllocator( +absl::Status CUDAFftPlan::UpdateScratchAllocator( Stream *stream, ScratchAllocator *scratch_allocator) { scratch_allocator_ = scratch_allocator; @@ -277,9 +285,9 @@ tsl::Status CUDAFftPlan::UpdateScratchAllocator( cufftResult_t ret = cufftSetWorkArea(plan_, scratch_.opaque()); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to set work area for cuFFT plan: " << ret; - return tsl::errors::Internal("Failed to set work area for cuFFT plan."); + return absl::InternalError("Failed to set work area for cuFFT plan."); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } CUDAFftPlan::~CUDAFftPlan() { @@ -308,143 +316,13 @@ int CUDAFftPlan::GetFftDirection() const { } } -std::unique_ptr CUDAFft::Create1dPlan(Stream *stream, uint64_t num_x, - fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[1] = {num_x}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, type, - /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x; - LOG(ERROR) << "Failed to initialize cufft 1d plan: " << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::Create1dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, fft::Type type, bool in_place_fft, - ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[1] = {num_x}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x; - LOG(ERROR) - << "Failed to initialize cufft 1d plan with customized allocator: " - << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::Create2dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[2] = {num_x, num_y}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, type, - /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y; - LOG(ERROR) << "Failed to initialize cufft 2d plan: " << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::Create2dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, fft::Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[2] = {num_x, num_y}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y; - LOG(ERROR) - << "Failed to initialize cufft 2d plan with customized allocator: " - << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::Create3dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, uint64 num_z, - fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[3] = {num_x, num_y, num_z}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, type, - /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y - << " num_z: " << num_z; - LOG(ERROR) << "Failed to initialize cufft 3d plan: " << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::Create3dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, fft::Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - uint64_t elem_count[3] = {num_x, num_y, num_z}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y - << " num_z: " << num_z; - LOG(ERROR) - << "Failed to initialize cufft 3d plan with customized allocator: " - << status.message(); - return nullptr; - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr CUDAFft::CreateBatchedPlan( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, fft::Type type, - bool in_place_fft, int batch_count) { - std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - tsl::Status status = fft_plan_ptr->Initialize( - parent_, stream, rank, elem_count, input_embed, input_stride, - input_distance, output_embed, output_stride, output_distance, type, - batch_count, /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(ERROR) << "Initialize Params: rank: " << rank - << " elem_count: " << *elem_count - << " input_embed: " << *input_embed - << " input_stride: " << input_stride - << " input_distance: " << input_distance - << " output_embed: " << *output_embed - << " output_stride: " << output_stride - << " output_distance: " << output_distance - << " batch_count: " << batch_count; - LOG(ERROR) << "Failed to initialize batched cufft plan: " - << status.message(); - return nullptr; - } - - return std::move(fft_plan_ptr); -} - std::unique_ptr CUDAFft::CreateBatchedPlanWithScratchAllocator( Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, uint64_t input_stride, uint64 input_distance, uint64 *output_embed, uint64_t output_stride, uint64 output_distance, fft::Type type, bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) { std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; - tsl::Status status = fft_plan_ptr->Initialize( + absl::Status status = fft_plan_ptr->Initialize( parent_, stream, rank, elem_count, input_embed, input_stride, input_distance, output_embed, output_stride, output_distance, type, batch_count, scratch_allocator); @@ -469,7 +347,7 @@ std::unique_ptr CUDAFft::CreateBatchedPlanWithScratchAllocator( void CUDAFft::UpdatePlanWithScratchAllocator( Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) { CUDAFftPlan *cuda_fft_plan = dynamic_cast(plan); - tsl::Status status = + absl::Status status = cuda_fft_plan->UpdateScratchAllocator(stream, scratch_allocator); if (!status.ok()) { LOG(FATAL) << "Failed to update custom allocator for cufft plan: " @@ -508,7 +386,7 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec, if (allocator) { auto allocated = allocator->AllocateBytes(input.size()); if (allocated.ok()) { - if (stream->ThenMemcpy(&allocated.value(), input, input.size()).ok()) { + if (stream->Memcpy(&allocated.value(), input, input.size()).ok()) { input_maybe_copy = DeviceMemory(allocated.value()); } } @@ -589,7 +467,7 @@ STREAM_EXECUTOR_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D) } // namespace gpu void initialize_cufft() { - tsl::Status status = + absl::Status status = PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuFFT", [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * { @@ -610,5 +488,6 @@ void initialize_cufft() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_cufft, - { stream_executor::initialize_cufft(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_cufft, { + stream_executor::initialize_cufft(); +}); diff --git a/xla/stream_executor/cuda/cuda_fft.h b/xla/stream_executor/cuda/cuda_fft.h index e5ce88163e4bb..111f47903b2fc 100644 --- a/xla/stream_executor/cuda/cuda_fft.h +++ b/xla/stream_executor/cuda/cuda_fft.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,12 +20,14 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ +#include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" namespace stream_executor { @@ -64,20 +66,20 @@ class CUDAFftPlan : public fft::Plan { } // Initialize function for batched plan - tsl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank, - uint64_t* elem_count, uint64_t* input_embed, - uint64_t input_stride, uint64 input_distance, - uint64_t* output_embed, uint64_t output_stride, - uint64_t output_distance, fft::Type type, - int batch_count, ScratchAllocator* scratch_allocator); + absl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank, + uint64_t* elem_count, uint64_t* input_embed, + uint64_t input_stride, uint64 input_distance, + uint64_t* output_embed, uint64_t output_stride, + uint64_t output_distance, fft::Type type, + int batch_count, ScratchAllocator* scratch_allocator); // Initialize function for 1d,2d, and 3d plan - tsl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank, - uint64_t* elem_count, fft::Type type, - ScratchAllocator* scratch_allocator); + absl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank, + uint64_t* elem_count, fft::Type type, + ScratchAllocator* scratch_allocator); - tsl::Status UpdateScratchAllocator(Stream* stream, - ScratchAllocator* scratch_allocator); + absl::Status UpdateScratchAllocator(Stream* stream, + ScratchAllocator* scratch_allocator); ScratchAllocator* GetScratchAllocator() const { return scratch_allocator_; } diff --git a/xla/stream_executor/cuda/cuda_helpers.h b/xla/stream_executor/cuda/cuda_helpers.h index 5ff0ed0d2af5d..82ac8abb04132 100644 --- a/xla/stream_executor/cuda/cuda_helpers.h +++ b/xla/stream_executor/cuda/cuda_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/cuda/cuda_kernel.cc b/xla/stream_executor/cuda/cuda_kernel.cc index 2840c0f8165e8..a28efca7097ef 100644 --- a/xla/stream_executor/cuda/cuda_kernel.cc +++ b/xla/stream_executor/cuda/cuda_kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_kernel.h" +#include +#include + +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" namespace stream_executor { namespace gpu { CUfunc_cache GpuKernel::GetGpuCacheConfig() const { - switch (preferred_cache_config_) { + switch (cache_config()) { case KernelCacheConfig::kNoPreference: return CU_FUNC_CACHE_PREFER_NONE; case KernelCacheConfig::kPreferShared: @@ -30,9 +39,21 @@ CUfunc_cache GpuKernel::GetGpuCacheConfig() const { return CU_FUNC_CACHE_PREFER_EQUAL; default: LOG(FATAL) << "Unknown KernelCacheConfig" - << static_cast(preferred_cache_config_); + << static_cast(cache_config()); } } +absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + int32_t threads_per_block = threads.x * threads.y * threads.z; + VLOG(3) << "Get kernel block occupancy: " << name_ + << "; threads_per_block: " << threads_per_block + << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + + return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, + threads_per_block, + dynamic_shared_memory_bytes); +} + } // namespace gpu } // namespace stream_executor diff --git a/xla/stream_executor/cuda/cuda_kernel.h b/xla/stream_executor/cuda/cuda_kernel.h index 7077ee952ec9e..69cf73bcf5b77 100644 --- a/xla/stream_executor/cuda/cuda_kernel.h +++ b/xla/stream_executor/cuda/cuda_kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/cuda/cuda_kernel_test.cc b/xla/stream_executor/cuda/cuda_kernel_test.cc deleted file mode 100644 index 4808389cde9d5..0000000000000 --- a/xla/stream_executor/cuda/cuda_kernel_test.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include -#include "xla/stream_executor/cuda/cuda_test_kernels.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/test.h" - -namespace stream_executor::cuda { - -TEST(CudaKernelTest, Add) { - using AddI32Kernel = TypedKernel, DeviceMemory, - DeviceMemory>; - - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); - - AddI32Kernel add(executor); - ASSERT_TRUE(executor->GetKernel(spec, &add).ok()); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - stream.ThenMemset32(&a, 1, byte_length); - stream.ThenMemset32(&b, 2, byte_length); - stream.ThenMemZero(&c, byte_length); - - // Launch kernel. - ASSERT_TRUE(stream.ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); - - // Copy data back to host. - std::vector dst(4, 42); - stream.ThenMemcpy(dst.data(), c, byte_length); - - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); -} - -} // namespace stream_executor::cuda diff --git a/xla/stream_executor/cuda/cuda_platform.cc b/xla/stream_executor/cuda/cuda_platform.cc index c4153854b0b03..e73fbc621d444 100644 --- a/xla/stream_executor/cuda/cuda_platform.cc +++ b/xla/stream_executor/cuda/cuda_platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,56 +15,30 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform.h" +#include +#include +#include +#include +#include +#include + #include "absl/base/call_once.h" -#include "absl/base/const_init.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/errors.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/status.h" namespace stream_executor { namespace gpu { -namespace { - -// Synchronize with spinlocks. -const char kScheduleSpinString[] = "spin"; -// Synchronize with spinlocks that also call CPU yield instructions. -const char kScheduleYieldString[] = "yield"; -// Synchronize with a "synchronization primitive" (e.g. mutex). -const char kScheduleBlockingSyncString[] = "blocking_sync"; - -const DeviceOptions GetDeviceOptionsFromEnv() { - const char* gpu_schedule_string = - std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE"); - - if (gpu_schedule_string == nullptr) { - return DeviceOptions::Default(); - } - - unsigned device_flags = 0; - if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) { - device_flags = DeviceOptions::kScheduleSpin; - } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) { - device_flags = DeviceOptions::kScheduleYield; - } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) { - device_flags = DeviceOptions::kScheduleBlockingSync; - } else { - LOG(QFATAL) << "Unknown option for environment variable " - "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE " - << gpu_schedule_string << " should be one of {" - << kScheduleBlockingSyncString << ", " << kScheduleSpinString - << ", " << kScheduleYieldString << "}"; - } - - return DeviceOptions(device_flags); -} - -} // namespace CudaPlatform::CudaPlatform() : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {} @@ -106,7 +80,7 @@ int CudaPlatform::DeviceToBus(int device_ordinal) { return exec->GetDeviceDescription().numa_node() - min_numa_node_; } -tsl::StatusOr CudaPlatform::FirstExecutorForBus( +absl::StatusOr CudaPlatform::FirstExecutorForBus( int bus_ordinal) { InspectNumaNodes(); CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; @@ -116,8 +90,7 @@ tsl::StatusOr CudaPlatform::FirstExecutorForBus( } } - return tsl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); } @@ -134,19 +107,18 @@ int CudaPlatform::VisibleDeviceCount() const { const std::string& CudaPlatform::Name() const { return name_; } -tsl::StatusOr> +absl::StatusOr> CudaPlatform::DescriptionForDevice(int ordinal) const { return GpuExecutor::CreateDeviceDescription(ordinal); } -tsl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { +absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; - config.device_options = GetDeviceOptionsFromEnv(); return GetExecutor(config); } -tsl::StatusOr CudaPlatform::GetExecutor( +absl::StatusOr CudaPlatform::GetExecutor( const StreamExecutorConfig& config) { if (config.gpu_stream) { // If the GPU stream was provided, it's not possible to get-or-create a @@ -158,17 +130,15 @@ tsl::StatusOr CudaPlatform::GetExecutor( config, [&]() { return GetUncachedExecutor(config); }); } -tsl::StatusOr> +absl::StatusOr> CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique( this, std::make_unique(), config.ordinal); - auto init_status = executor->Init(config.device_options); + auto init_status = executor->Init(); if (!init_status.ok()) { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat( - "failed initializing StreamExecutor for CUDA device ordinal %d: %s", - config.ordinal, init_status.ToString())); + return absl::InternalError(absl::StrFormat( + "failed initializing StreamExecutor for CUDA device ordinal %d: %s", + config.ordinal, init_status.ToString())); } return std::move(executor); @@ -177,20 +147,14 @@ CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { } // namespace gpu static void InitializeCudaPlatform() { - // Disabling leak checking, MultiPlatformManager does not destroy its + // Disabling leak checking, PlatformManager does not destroy its // registered platforms. std::unique_ptr platform(new gpu::CudaPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(cuda_platform, - stream_executor::InitializeCudaPlatform()); - -// Note that module initialization sequencing is not supported in the -// open-source project, so this will be a no-op there. -REGISTER_MODULE_INITIALIZER_SEQUENCE(cuda_platform, multi_platform_manager); -REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, - cuda_platform); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + cuda_platform, stream_executor::InitializeCudaPlatform()); diff --git a/xla/stream_executor/cuda/cuda_platform.h b/xla/stream_executor/cuda/cuda_platform.h index 8b26b8c76cefb..153282b26507e 100644 --- a/xla/stream_executor/cuda/cuda_platform.h +++ b/xla/stream_executor/cuda/cuda_platform.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,17 +17,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_ #include -#include +#include -#include "absl/base/thread_annotations.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "xla/stream_executor/trace_listener.h" -#include "tsl/platform/statusor.h" namespace stream_executor { namespace cuda { @@ -53,7 +49,7 @@ class CudaPlatform : public Platform { int DeviceToBus(int device_ordinal); // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - tsl::StatusOr FirstExecutorForBus(int bus_ordinal); + absl::StatusOr FirstExecutorForBus(int bus_ordinal); // Platform interface implementation: // Returns the same value as kCudaPlatform above. @@ -64,15 +60,15 @@ class CudaPlatform : public Platform { const std::string& Name() const override; - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; private: diff --git a/xla/stream_executor/cuda/cuda_platform_id.cc b/xla/stream_executor/cuda/cuda_platform_id.cc index bd3a77d007627..c8754155d6d51 100644 --- a/xla/stream_executor/cuda/cuda_platform_id.cc +++ b/xla/stream_executor/cuda/cuda_platform_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/platform.h" + namespace stream_executor { namespace cuda { diff --git a/xla/stream_executor/cuda/cuda_platform_id.h b/xla/stream_executor/cuda/cuda_platform_id.h index 83af581d6f209..b41404993c522 100644 --- a/xla/stream_executor/cuda/cuda_platform_id.h +++ b/xla/stream_executor/cuda/cuda_platform_id.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/cuda/cuda_runtime.cc b/xla/stream_executor/cuda/cuda_runtime.cc index 23a15491877d5..bf355cf9b7b1d 100644 --- a/xla/stream_executor/cuda/cuda_runtime.cc +++ b/xla/stream_executor/cuda/cuda_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/base/optimization.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" #include "xla/stream_executor/gpu/gpu_runtime.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/logging.h" namespace stream_executor::gpu { @@ -38,11 +42,20 @@ static const char* ToString(cudaError_t error) { } \ } while (0) -tsl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { +absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { + VLOG(2) << "Get CUDA function from a symbol: " << symbol; cudaFunction_t func; RETURN_IF_CUDA_RES_ERROR(cudaGetFuncBySymbol(&func, symbol), "Failed call to cudaGetFuncBySymbol"); return reinterpret_cast(func); } +absl::StatusOr GpuRuntime::GetRuntimeVersion() { + VLOG(2) << "Get CUDA runtime version"; + int32_t version; + RETURN_IF_CUDA_RES_ERROR(cudaRuntimeGetVersion(&version), + "Failed call to cudaGetRuntimeVersion"); + return version; +} + } // namespace stream_executor::gpu diff --git a/xla/stream_executor/cuda/cuda_stream.h b/xla/stream_executor/cuda/cuda_stream.h index a99fe6d5eecb8..7e651b45d0e6f 100644 --- a/xla/stream_executor/cuda/cuda_stream.h +++ b/xla/stream_executor/cuda/cuda_stream.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu/gpu_stream.h" namespace stream_executor { diff --git a/xla/stream_executor/cuda/cuda_test_kernels.cu.cc b/xla/stream_executor/cuda/cuda_test_kernels.cu.cc deleted file mode 100644 index 84b4e0a8d4d5c..0000000000000 --- a/xla/stream_executor/cuda/cuda_test_kernels.cu.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/cuda/cuda_test_kernels.h" - -#include - -namespace stream_executor::cuda::internal { - -__global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - c[index] = a[index] + b[index]; -} - -__global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - c[index] = a[index] * b[index]; -} - -__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t value) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - pred[index] = counter[index] < value; - counter[index] += 1; -} - -__global__ void AddI32Ptrs3(Ptrs3 ptrs) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; -} - -void* GetAddI32CudaKernel() { return reinterpret_cast(&AddI32); } - -void* GetMulI32CudaKernel() { return reinterpret_cast(&MulI32); } - -void* GetIncAndCmpCudaKernel() { return reinterpret_cast(&IncAndCmp); } - -void* GetAddI32Ptrs3CudaKernel() { - return reinterpret_cast(&AddI32Ptrs3); -} - -} // namespace stream_executor::cuda::internal diff --git a/xla/stream_executor/cuda/cuda_test_kernels.h b/xla/stream_executor/cuda/cuda_test_kernels.h deleted file mode 100644 index 94014f5d76092..0000000000000 --- a/xla/stream_executor/cuda/cuda_test_kernels.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_TEST_KERNELS_H_ -#define XLA_STREAM_EXECUTOR_CUDA_CUDA_TEST_KERNELS_H_ - -#include - -namespace stream_executor::cuda::internal { - -// This is a collection of CUDA kernels for writing simple StreamExecutor tests. -// -// Some of the kernels available as pre-compiled PTX blobs (can be loaded with -// CUDA driver API), and some of the kernels are written directly in CUDA C++ -// and can be loaded from a symbol pointer (to test StreamExecutor CUDA runtime -// integration). - -// PTX kernel compiled from: -// -// __global__ void add(int* a, int* b, int* c) { -// int index = threadIdx.x + blockIdx.x * blockDim.x; -// c[index] = a[index] + b[index]; -// } -// -// Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kAddI32Kernel = R"( -.version 8.0 -.target sm_50 -.address_size 64 - -.visible .entry add( - .param .u64 add_param_0, - .param .u64 add_param_1, - .param .u64 add_param_2 -) -{ - .reg .b32 %r<8>; - .reg .b64 %rd<11>; - .loc 1 1 0 - - ld.param.u64 %rd1, [add_param_0]; - ld.param.u64 %rd2, [add_param_1]; - ld.param.u64 %rd3, [add_param_2]; - .loc 1 3 3 - cvta.to.global.u64 %rd4, %rd3; - cvta.to.global.u64 %rd5, %rd2; - cvta.to.global.u64 %rd6, %rd1; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r2, %r3, %r1; - .loc 1 4 3 - mul.wide.s32 %rd7, %r4, 4; - add.s64 %rd8, %rd6, %rd7; - ld.global.u32 %r5, [%rd8]; - add.s64 %rd9, %rd5, %rd7; - ld.global.u32 %r6, [%rd9]; - add.s32 %r7, %r6, %r5; - add.s64 %rd10, %rd4, %rd7; - st.global.u32 [%rd10], %r7; - .loc 1 5 1 - ret; - -})"; - -template -struct Ptrs3 { - T* a; - T* b; - T* c; -}; - -// Returns a pointer to device kernel compiled from the CUDA C++ code above. -void* GetAddI32CudaKernel(); - -// Returns a pointer to device kernel doing multiplication instead of addition. -void* GetMulI32CudaKernel(); - -// Returns a pointer to device kernel doing increment and compare, intended for -// testing on-device while loops. -void* GetIncAndCmpCudaKernel(); - -// Returns a pointer to device kernel compiled from the CUDA C++ but with all -// three pointers passed to argument as an instance of `Ptr3` template to test -// StreamExecutor arguments packing for custom C++ types. -void* GetAddI32Ptrs3CudaKernel(); - -} // namespace stream_executor::cuda::internal - -#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_TEST_KERNELS_H_ diff --git a/xla/stream_executor/cuda/cudnn_frontend_helpers.cc b/xla/stream_executor/cuda/cudnn_frontend_helpers.cc new file mode 100644 index 0000000000000..51087e8d600e0 --- /dev/null +++ b/xla/stream_executor/cuda/cudnn_frontend_helpers.cc @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" + +namespace stream_executor { +namespace gpu { + +int CuDnnTensorUID(int offset) { + constexpr int kFirstUid = 1; + return kFirstUid + offset; +} + +} // namespace gpu +} // namespace stream_executor diff --git a/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/xla/stream_executor/cuda/cudnn_frontend_helpers.h new file mode 100644 index 0000000000000..aa59af500ba7a --- /dev/null +++ b/xla/stream_executor/cuda/cudnn_frontend_helpers.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDNN_FRONTEND_HELPERS_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDNN_FRONTEND_HELPERS_H_ + +namespace stream_executor { +namespace gpu { + +#define RETURN_IF_CUDNN_FRONTEND_ERROR(expr) \ + do { \ + if (ABSL_PREDICT_TRUE((expr).is_bad())) { \ + std::ostringstream oss; \ + oss << (expr).get_message() << "\nin " << __FILE__ << "(" << __LINE__ \ + << "): '" << #expr << "' "; \ + return absl::InternalError(oss.str()); \ + } \ + } while (false) + +int CuDnnTensorUID(int offset); + +} // namespace gpu +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDNN_FRONTEND_HELPERS_H_ diff --git a/xla/stream_executor/cuda/memcpy_test.cc b/xla/stream_executor/cuda/memcpy_test.cc deleted file mode 100644 index 1ebc5b048ffbd..0000000000000 --- a/xla/stream_executor/cuda/memcpy_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#if GOOGLE_CUDA -#include "absl/memory/memory.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/multi_platform_manager.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/test.h" - -namespace stream_executor { - -TEST(MemcpyTest, PinnedHostMemory) { - Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - Stream stream(executor); - stream.Init(); - ASSERT_TRUE(stream.ok()); - - void* d_ptr = executor->HostMemoryAllocate(sizeof(int)); - DeviceMemoryBase d_mem(d_ptr, sizeof(int)); - int h_ptr; - stream.ThenMemcpy(&h_ptr, d_mem, d_mem.size()); - EXPECT_TRUE(stream.BlockHostUntilDone().ok()); -} - -} // namespace stream_executor - -#endif // GOOGLE_CUDA diff --git a/xla/stream_executor/cuda/ptx_compiler.h b/xla/stream_executor/cuda/ptx_compiler.h new file mode 100644 index 0000000000000..867f857ff8f4a --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_H_ +#define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" + +namespace stream_executor { + +// Takes PTX as a null-terminated string and compiles it to SASS (CUBIN) +// targeting the sm_. NVIDIA GPU architecture. +absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( + int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, + bool cancel_if_reg_spill); + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_H_ diff --git a/xla/stream_executor/cuda/ptx_compiler_impl.cc b/xla/stream_executor/cuda/ptx_compiler_impl.cc new file mode 100644 index 0000000000000..99cbe6ccd2c5c --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler_impl.cc @@ -0,0 +1,176 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/nvPTXCompiler.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" + +namespace stream_executor { + +static std::string_view ToString(nvPTXCompileResult status) { + switch (status) { + case NVPTXCOMPILE_SUCCESS: + return "SUCCESS"; + case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE: + return "INVALID_COMPILER_HANDLE"; + case NVPTXCOMPILE_ERROR_INVALID_INPUT: + return "INVALID_INPUT"; + case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE: + return "COMPILATION_FAILURE"; + case NVPTXCOMPILE_ERROR_INTERNAL: + return "INTERNAL"; + case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY: + return "OUT_OF_MEMORY"; + case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE: + return "COMPILER_INVOCATION_INCOMPLETE"; + case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION: + return "UNSUPPORTED_PTX_VERSION"; +#if CUDA_VERSION > 12000 + case NVPTXCOMPILE_ERROR_UNSUPPORTED_DEVSIDE_SYNC: + return "UNSUPPORTED_DEVSIDE_SYNC"; +#endif + default: + return "UNKNOWN"; + } +} + +#define RETURN_IF_NVPTXCOMPILER_ERROR(expr) \ + do { \ + nvPTXCompileResult _status = expr; \ + if (!ABSL_PREDICT_TRUE(_status == NVPTXCOMPILE_SUCCESS)) { \ + std::ostringstream oss; \ + oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \ + << "): '" << #expr << "'"; \ + return absl::UnknownError(oss.str()); \ + } \ + } while (false) + +absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( + int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, + bool cancel_if_reg_spill) { + nvPTXCompilerHandle compiler_handle{}; + RETURN_IF_NVPTXCOMPILER_ERROR(nvPTXCompilerCreate( + &compiler_handle, std::strlen(ptx_contents), ptx_contents)); + absl::Cleanup compiler_cleaner = [&compiler_handle] { + nvPTXCompilerDestroy(&compiler_handle); + }; + + // If the target is sm_90, hard code it to sm_90a so that all instructions + // can be used. We don't need the portability that sm_90 gives. + std::string_view extension = (cc_major == 9 && cc_minor == 0) ? "a" : ""; + std::string architecture = absl::StrCat("sm_", cc_major, cc_minor, extension); + + options.extra_flags.emplace_back(absl::StrCat("-arch=", architecture)); + options.extra_flags.emplace_back("--warn-on-spills"); + + if (VLOG_IS_ON(2)) { + options.extra_flags.emplace_back("-v"); + } + if (options.disable_gpuasm_optimizations) { + options.extra_flags.emplace_back("-O0"); + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << absl::StrJoin(options.extra_flags, " "); + } + + std::vector cmdline_options_ptrs{}; + absl::c_transform(options.extra_flags, + std::back_inserter(cmdline_options_ptrs), + [](const std::string& s) { return s.c_str(); }); + + nvPTXCompileResult compile_result = + nvPTXCompilerCompile(compiler_handle, cmdline_options_ptrs.size(), + cmdline_options_ptrs.data()); + + if (compile_result != NVPTXCOMPILE_SUCCESS) { + size_t error_log_size{}; + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetErrorLogSize(compiler_handle, &error_log_size)); + + std::string error_log(error_log_size, '\0'); + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetErrorLog(compiler_handle, error_log.data())); + + // It happens when the linked version of ntvptxcompiler is too old for the + // current GPU. Example error message associated with this error code: + // ptxas fatal : Value 'sm_80' is not defined for option 'gpu-name' + if (absl::StrContains(error_log, "ptxas fatal : Value '") && + absl::StrContains(error_log, "is not defined for option 'gpu-name'")) { + return absl::UnimplementedError(absl::StrFormat( + "Linked libnvptxcompiler is too old for %s.", architecture)); + } + if (absl::StrContains(error_log, "ptxas fatal") && + absl::StrContains(error_log, "Register allocation failed")) { + return absl::ResourceExhaustedError("Register allocation failed"); + } + + return absl::InternalError( + absl::StrFormat("PTX compilation failed with error code %d, output: %s", + compile_result, error_log)); + } + + size_t info_log_size{}; + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetInfoLogSize(compiler_handle, &info_log_size)); + + std::string info_log(info_log_size, '\0'); + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetInfoLog(compiler_handle, info_log.data())); + + // Print the verbose output of ptxas. + if (!info_log.empty()) { + if (absl::StrContains(info_log, "warning")) { + if (cancel_if_reg_spill && + absl::StrContains(info_log, "Registers are spilled")) { + return absl::CancelledError( + "Compilation result discarded due to register spilling"); + } + } else { + VLOG(2) << info_log; + } + } + + size_t cubinSize{}; + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetCompiledProgramSize(compiler_handle, &cubinSize)); + + std::vector cubin(cubinSize); + RETURN_IF_NVPTXCOMPILER_ERROR( + nvPTXCompilerGetCompiledProgram(compiler_handle, (char*)cubin.data())); + + return cubin; +} + +} // namespace stream_executor diff --git a/xla/stream_executor/cuda/ptx_compiler_stub.cc b/xla/stream_executor/cuda/ptx_compiler_stub.cc new file mode 100644 index 0000000000000..f0e69529352aa --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler_stub.cc @@ -0,0 +1,30 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" + +namespace stream_executor { +absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( + int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, + bool cancel_if_reg_spill) { + return absl::UnimplementedError( + "XLA was built without libnvptxcompiler support."); +} +} // namespace stream_executor diff --git a/xla/stream_executor/cuda/ptx_compiler_support.cc b/xla/stream_executor/cuda/ptx_compiler_support.cc new file mode 100644 index 0000000000000..ea45f1b102c27 --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler_support.cc @@ -0,0 +1,18 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace stream_executor { +bool IsLibNvPtxCompilerSupported() { return LIBNVPTXCOMPILER_SUPPORT; } +} // namespace stream_executor diff --git a/xla/stream_executor/cuda/ptx_compiler_support.h b/xla/stream_executor/cuda/ptx_compiler_support.h new file mode 100644 index 0000000000000..37f28f8a45c9a --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler_support.h @@ -0,0 +1,25 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ +#define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ + +namespace stream_executor { +// Returns true if XLA was built with libnvptxcompiler support. Otherwise false +// is returned. +bool IsLibNvPtxCompilerSupported(); +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ diff --git a/xla/stream_executor/cuda/ptx_compiler_test.cc b/xla/stream_executor/cuda/ptx_compiler_test.cc new file mode 100644 index 0000000000000..a42394f9988ee --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compiler_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/ptx_compiler.h" + +#include + +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace { + +// Generated by the following command: +// +// echo "__global__ void kernel(int* in) { for (int i=0; i < 16; i++) \ +// { in[i] += i; } for (int i=0; i < 16; i++) { in[15-i] += i; }}" \ +// | nvcc -o - -rdc true --ptx --x cu - -O0 +// +// The `.maxnreg` directive was added manually afterwards. +constexpr const char kSpillingPtx[] = R"( +// +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-32267302 +// Cuda compilation tools, release 12.0, V12.0.140 +// Based on NVVM 7.0.1 +// + +.version 8.0 +.target sm_52 +.address_size 64 + + // .globl _Z6kernelPi + +.visible .entry _Z6kernelPi( + .param .u64 _Z6kernelPi_param_0 +) + .maxnreg 16 +{ + .reg .b32 %r<33>; + .reg .b64 %rd<3>; + + + ld.param.u64 %rd1, [_Z6kernelPi_param_0]; + cvta.to.global.u64 %rd2, %rd1; + ld.global.u32 %r1, [%rd2+4]; + ld.global.u32 %r2, [%rd2+8]; + ld.global.u32 %r3, [%rd2+12]; + ld.global.u32 %r4, [%rd2+16]; + ld.global.u32 %r5, [%rd2+20]; + ld.global.u32 %r6, [%rd2+24]; + ld.global.u32 %r7, [%rd2+28]; + ld.global.u32 %r8, [%rd2+32]; + ld.global.u32 %r9, [%rd2+36]; + ld.global.u32 %r10, [%rd2+40]; + ld.global.u32 %r11, [%rd2+44]; + ld.global.u32 %r12, [%rd2+48]; + ld.global.u32 %r13, [%rd2+52]; + ld.global.u32 %r14, [%rd2+56]; + ld.global.u32 %r15, [%rd2+60]; + add.s32 %r16, %r15, 15; + st.global.u32 [%rd2+60], %r16; + add.s32 %r17, %r14, 15; + st.global.u32 [%rd2+56], %r17; + add.s32 %r18, %r13, 15; + st.global.u32 [%rd2+52], %r18; + add.s32 %r19, %r12, 15; + st.global.u32 [%rd2+48], %r19; + add.s32 %r20, %r11, 15; + st.global.u32 [%rd2+44], %r20; + add.s32 %r21, %r10, 15; + st.global.u32 [%rd2+40], %r21; + add.s32 %r22, %r9, 15; + st.global.u32 [%rd2+36], %r22; + add.s32 %r23, %r8, 15; + st.global.u32 [%rd2+32], %r23; + add.s32 %r24, %r7, 15; + st.global.u32 [%rd2+28], %r24; + add.s32 %r25, %r6, 15; + st.global.u32 [%rd2+24], %r25; + add.s32 %r26, %r5, 15; + st.global.u32 [%rd2+20], %r26; + add.s32 %r27, %r4, 15; + st.global.u32 [%rd2+16], %r27; + add.s32 %r28, %r3, 15; + st.global.u32 [%rd2+12], %r28; + add.s32 %r29, %r2, 15; + st.global.u32 [%rd2+8], %r29; + add.s32 %r30, %r1, 15; + st.global.u32 [%rd2+4], %r30; + ld.global.u32 %r31, [%rd2]; + add.s32 %r32, %r31, 15; + st.global.u32 [%rd2], %r32; + ret; +} +)"; + +// Generated by the following command: +// +// echo "__global__ void kernel(int* output) { *output = 42; }" | +// nvcc -o - -rdc true --ptx --x cu - +// +constexpr const char kSimplePtx[] = R"( +.version 8.0 +.target sm_52 +.address_size 64 + + // .globl _Z6kernelPi + +.visible .entry _Z6kernelPi ( + .param .u64 _Z6kernelPi_param_0 +) +{ + .reg .b32 %r<16>; + .reg .b64 %rd<3>; + + + ld.param.u64 %rd1, [_Z6kernelPi_param_0]; + cvta.to.global.u64 %rd2, %rd1; + mov.u32 %r1, 42; + st.global.u32 [%rd2], %r15; + ret; + +})"; + +constexpr stream_executor::CudaComputeCapability kDefaultComputeCapability{5, + 2}; + +absl::StatusOr> CompileHelper( + stream_executor::CudaComputeCapability cc, const char* const ptx_input, + bool disable_gpuasm_optimizations = false, bool cancel_if_reg_spill = false, + std::vector extra_flags = {}) { + stream_executor::GpuAsmOpts options{}; + options.disable_gpuasm_optimizations = disable_gpuasm_optimizations; + options.extra_flags = std::move(extra_flags); + + return stream_executor::CompileGpuAsmUsingLibNvPtxCompiler( + cc.major, cc.minor, ptx_input, options, cancel_if_reg_spill); +} + +class PtxCompilerTest : public ::testing::Test { + void SetUp() override { + // This can't be in the constructor because `GTEST_SKIP` can't be called + // from constructors. + if (!stream_executor::IsLibNvPtxCompilerSupported()) { + // We skip these tests if this is a build without libnvptxcompiler + // support. + GTEST_SKIP(); + } + } +}; + +TEST_F(PtxCompilerTest, IdentifiesUnsupportedArchitecture) { + stream_executor::GpuAsmOpts options{}; + EXPECT_THAT( + CompileHelper(stream_executor::CudaComputeCapability{100, 0}, kSimplePtx), + tsl::testing::StatusIs(absl::StatusCode::kUnimplemented)); +} + +TEST_F(PtxCompilerTest, CanCompileSingleCompilationUnit) { + stream_executor::GpuAsmOpts options{}; + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSimplePtx), + tsl::testing::IsOk()); +} + +TEST_F(PtxCompilerTest, CancelsOnRegSpill) { + // We have to disable optimization here, otherwise PTXAS will optimize our + // trivial register usages away and we don't spill as intended. + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSpillingPtx, + /*disable_gpuasm_optimizations=*/true, + /*cancel_if_reg_spill=*/true), + tsl::testing::StatusIs(absl::StatusCode::kCancelled)); + + // We also test the converse to ensure our test case isn't broken. + EXPECT_THAT(CompileHelper(kDefaultComputeCapability, kSpillingPtx, + /*disable_gpuasm_optimizations=*/true, + /*cancel_if_reg_spill=*/false), + tsl::testing::IsOk()); +} + +TEST_F(PtxCompilerTest, AcceptsExtraArguments) { + // It's tricky to test whether `extra_arguments` works without depending on + // too much nvptx internals. So we pass the `--generate-line-info` flags and + // expect strictly larger outputs than without the flag. + auto reference_cubin = CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {}); + auto cubin_with_line_info = + CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {"--generate-line-info"}); + + EXPECT_THAT(reference_cubin, tsl::testing::IsOk()); + EXPECT_THAT(cubin_with_line_info, tsl::testing::IsOk()); + EXPECT_GT(cubin_with_line_info->size(), reference_cubin->size()); + + // We also test whether invalid flags lead to a compilation error. + EXPECT_THAT( + CompileHelper(kDefaultComputeCapability, kSimplePtx, + /*disable_gpuasm_optimizations=*/false, + /*cancel_if_reg_spill=*/false, {"--flag-does-not-exist"}), + tsl::testing::StatusIs(absl::StatusCode::kInternal)); +} + +} // namespace diff --git a/xla/stream_executor/cuda/stream_search_test.cc b/xla/stream_executor/cuda/stream_search_test.cc deleted file mode 100644 index 05fb96f97a725..0000000000000 --- a/xla/stream_executor/cuda/stream_search_test.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/test.h" - -namespace stream_executor { -namespace { - -#if GOOGLE_CUDA - -class StreamSearchTest : public ::testing::Test { - public: - Platform* GetPlatform() { - return *MultiPlatformManager::PlatformWithName("CUDA"); - } -}; - -TEST_F(StreamSearchTest, NoMatchBadPtr) { - void* bad_ptr = reinterpret_cast(0xdeadbeef); - - StreamExecutorConfig config; - config.gpu_stream = bad_ptr; - - tsl::StatusOr found_executor = - GetPlatform()->GetExecutor(config); - - // No executor found. - EXPECT_FALSE(found_executor.ok()); -} - -TEST_F(StreamSearchTest, FoundPrevExecutor) { - tsl::StatusOr executor = GetPlatform()->ExecutorForDevice(0); - EXPECT_TRUE(executor.ok()); - - Stream s(*executor); - s.Init(); - - Stream s2(*executor); - s2.Init(); - - void* gpu_ptr = s.platform_specific_handle().stream; - void* gpu_ptr_2 = s2.platform_specific_handle().stream; - - StreamExecutorConfig c; - c.gpu_stream = gpu_ptr; - - tsl::StatusOr found_executor = GetPlatform()->GetExecutor(c); - EXPECT_TRUE(found_executor.ok()); - EXPECT_EQ(*found_executor, *executor); - - Stream* found1 = (*found_executor)->FindAllocatedStream(gpu_ptr); - EXPECT_EQ(found1, &s); - - Stream* found2 = (*found_executor)->FindAllocatedStream(gpu_ptr_2); - EXPECT_EQ(found2, &s2); -} - -#endif // GOOGLE_CUDA - -} // namespace -} // namespace stream_executor diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index c0f900d4794e4..ebac59ba7c4ea 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/protobuf/dnn.pb.h" namespace Eigen { diff --git a/xla/stream_executor/device_description.cc b/xla/stream_executor/device_description.cc index ed13e63ddbbe5..7361a799b56bb 100644 --- a/xla/stream_executor/device_description.cc +++ b/xla/stream_executor/device_description.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include -#include #include +#include +#include "xla/stream_executor/launch_dim.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 85cac5a9ee075..b2ada3ece94f5 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,12 +28,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform/port.h" namespace stream_executor { namespace internal { @@ -53,8 +53,8 @@ struct CudaComputeCapability { HOPPER = 9 }; - CudaComputeCapability() = default; - CudaComputeCapability(int major, int minor) { + constexpr CudaComputeCapability() = default; + constexpr CudaComputeCapability(int major, int minor) { this->major = major; this->minor = minor; } @@ -171,6 +171,11 @@ class RocmComputeCapability { return absl::c_count(kList, gfx_version()) != 0; } + bool gfx9_mi300() const { + static constexpr absl::string_view kList[] = {"gfx940", "gfx941", "gfx942"}; + return absl::c_count(kList, gfx_version()) != 0; + } + bool navi21() const { return gfx_version() == "gfx1030"; } bool navi31() const { return gfx_version() == "gfx1100"; } @@ -196,6 +201,8 @@ class RocmComputeCapability { bool has_hipblaslt() const { return gfx9_mi200_or_later(); } + bool has_fp8_support() const { return gfx9_mi300(); } + RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; proto.set_gcn_arch_name(gcn_arch_name_); @@ -210,13 +217,13 @@ class RocmComputeCapability { std::string gcn_arch_name_ = "gfx000"; // default to invalid arch. static constexpr absl::string_view kSupportedGfxVersions[]{ - "gfx900", // MI25 - "gfx906", // MI50 / MI60 - "gfx908", // MI100 - "gfx90a", // MI200 - "gfx940", "gfx941", "gfx942", - "gfx1030", // Navi21 - "gfx1100" // Navi31 + "gfx900", // MI25 + "gfx906", // MI50 / MI60 + "gfx908", // MI100 + "gfx90a", // MI200 + "gfx940", "gfx941", "gfx942", // MI300 + "gfx1030", // Navi21 + "gfx1100" // Navi31 }; }; diff --git a/xla/stream_executor/device_description.proto b/xla/stream_executor/device_description.proto index 21e66edd4c8be..c01d365b78f49 100644 --- a/xla/stream_executor/device_description.proto +++ b/xla/stream_executor/device_description.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/device_id_utils.h b/xla/stream_executor/device_id_utils.h deleted file mode 100644 index 76aada1dcfec0..0000000000000 --- a/xla/stream_executor/device_id_utils.h +++ /dev/null @@ -1,47 +0,0 @@ - -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_DEVICE_ID_UTILS_H_ -#define XLA_STREAM_EXECUTOR_DEVICE_ID_UTILS_H_ - -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/framework/device_id.h" -#include "tsl/framework/device_id_manager.h" - -namespace stream_executor { - -// Utility methods for getting the associated executor given a TfDeviceId -// or PlatformDeviceId. -class DeviceIdUtil { - public: - static tsl::StatusOr ExecutorForPlatformDeviceId( - Platform* device_manager, tsl::PlatformDeviceId platform_device_id) { - return device_manager->ExecutorForDevice(platform_device_id.value()); - } - static tsl::StatusOr ExecutorForTfDeviceId( - const tsl::DeviceType& type, Platform* device_manager, - tsl::TfDeviceId tf_device_id) { - tsl::PlatformDeviceId platform_device_id; - TF_RETURN_IF_ERROR(tsl::DeviceIdManager::TfToPlatformDeviceId( - type, tf_device_id, &platform_device_id)); - return ExecutorForPlatformDeviceId(device_manager, platform_device_id); - } -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_DEVICE_ID_UTILS_H_ diff --git a/xla/stream_executor/device_memory.h b/xla/stream_executor/device_memory.h index fe3df0067f337..82c3a782094c9 100644 --- a/xla/stream_executor/device_memory.h +++ b/xla/stream_executor/device_memory.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,8 +28,9 @@ limitations under the License. #include #include +#include -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/logging.h" namespace stream_executor { @@ -61,12 +62,16 @@ class DeviceMemoryBase { bool operator==(std::nullptr_t other) const { return is_null(); } bool operator!=(std::nullptr_t other) const { return !is_null(); } + bool operator==(const DeviceMemoryBase &other) const { + return opaque_ == other.opaque_ && size_ == other.size_; + } + // Provides a partial order between device memory values. // // This operator is provided so that this object can be used as a key in an // ordered map. bool operator<(const DeviceMemoryBase &other) const { - return opaque() < other.opaque(); + return std::tie(opaque_, size_) < std::tie(other.opaque_, other.size_); } // Returns the size, in bytes, for the backing memory. @@ -89,6 +94,19 @@ class DeviceMemoryBase { return opaque() == other.opaque() && size() == other.size(); } + // Creates a memory region (slice) inside another allocated memory region. + // Offset and size are in bytes. + DeviceMemoryBase GetByteSlice(uint64_t offset_bytes, + uint64_t size_bytes) const { + DCHECK(offset_bytes + size_bytes <= size_) + << "requested slice allocation (offset + size) is greater " + << "than parent allocation size: (" << offset_bytes << " + " + << size_bytes << ") vs. (" << size_ << ")"; + + return DeviceMemoryBase( + reinterpret_cast(opaque_) + offset_bytes, size_bytes); + } + protected: friend class StreamExecutor; @@ -139,13 +157,21 @@ class DeviceMemory final : public DeviceMemoryBase { // Returns whether this is a single-element allocation. bool IsScalar() const { return ElementCount() == 1; } - // Create a typed area of DeviceMemory with a given opaque pointer and the + // Creates a typed area of DeviceMemory with a given opaque pointer and the // quantity of bytes in the allocation. This function is broken out to // distinguish bytes from an element count. static DeviceMemory MakeFromByteSize(void *opaque, uint64_t bytes) { return DeviceMemory(opaque, bytes); } + // Creates a memory region (slice) inside another allocated memory region. + // Offset and size are specified in terms of ElemT elements. + DeviceMemory GetSlice(uint64_t element_offset, + uint64_t element_count) { + return DeviceMemory(GetByteSlice(sizeof(ElemT) * element_offset, + sizeof(ElemT) * element_count)); + } + // Resets the DeviceMemory data, in MakeFromByteSize fashion. // This simply clobbers the prior values. void ResetFromByteSize(void *opaque, uint64_t bytes) { @@ -167,26 +193,6 @@ class DeviceMemory final : public DeviceMemoryBase { DeviceMemory(void *opaque, uint64_t size) : DeviceMemoryBase(opaque, size) {} }; -// A class to encapsulate the type and size of a dynamic shared memory -// buffer. Because the buffer exists solely on the device and is not copyable -// to the host, memory objects of this type do not maintain buffer pointers -// on the host. -template -class SharedDeviceMemory final : public DeviceMemoryBase { - public: - explicit SharedDeviceMemory(uint64_t elem_count) - : DeviceMemoryBase(nullptr, elem_count * kElemSize) {} - - static constexpr size_t kElemSize = sizeof(ElemT); - - // Returns the number of elements of type ElemT that constitute this - // allocation. - uint64_t ElementCount() const { return size() / kElemSize; } - - // Returns whether this is a single-element allocation. - bool IsScalar() const { return ElementCount() == 1; } -}; - // Host-side representation of packed-and-aligned vector datatypes on the device // side. Since these can appear in device kernel signatures, we support // launching them with these datatypes in launch signatures. diff --git a/xla/stream_executor/device_memory_allocator.h b/xla/stream_executor/device_memory_allocator.h index 500428b0ed24b..d42066591481c 100644 --- a/xla/stream_executor/device_memory_allocator.h +++ b/xla/stream_executor/device_memory_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,18 +16,22 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_DEVICE_MEMORY_ALLOCATOR_H_ +#include #include +#include #include #include #include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -82,7 +86,7 @@ class ScopedDeviceMemory { // object. // // Postcondition: other == nullptr. - ScopedDeviceMemory(ScopedDeviceMemory &&other) + ScopedDeviceMemory(ScopedDeviceMemory &&other) noexcept : wrapped_(other.Release()), device_ordinal_(other.device_ordinal_), allocator_(other.allocator_) {} @@ -94,7 +98,7 @@ class ScopedDeviceMemory { // Moves ownership of the memory from other to this object. // // Postcondition: other == nullptr. - ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) { + ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) noexcept { TF_CHECK_OK(Free()); wrapped_ = other.Release(); allocator_ = other.allocator_; @@ -141,7 +145,7 @@ class ScopedDeviceMemory { int device_ordinal() const { return device_ordinal_; } // Frees the existing memory, resets the wrapped memory to null. - tsl::Status Free(); + absl::Status Free(); private: DeviceMemory wrapped_; // Value we wrap with scoped-release. @@ -176,10 +180,10 @@ class DeviceMemoryAllocator { // fails, the allocation should return immediately without retrying. An // example use case is optional scratch spaces where a failure has only // performance impact. - virtual tsl::StatusOr Allocate(int device_ordinal, - uint64_t size, - bool retry_on_failure, - int64_t memory_space) = 0; + virtual absl::StatusOr Allocate(int device_ordinal, + uint64_t size, + bool retry_on_failure, + int64_t memory_space) = 0; // Two-arg version of Allocate(), which sets retry-on-failure to true and // memory_space to default (0). @@ -187,22 +191,22 @@ class DeviceMemoryAllocator { // (We don't simply use a default argument on the virtual Allocate function // because default args on virtual functions are disallowed by the Google // style guide.) - tsl::StatusOr Allocate(int device_ordinal, - uint64_t size) { + absl::StatusOr Allocate(int device_ordinal, + uint64_t size) { return Allocate(device_ordinal, size, /*retry_on_failure=*/true, /*memory_space=*/0); } // Three-arg version of Allocate(), which sets memory_space to default (0). - tsl::StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure) { + absl::StatusOr Allocate(int device_ordinal, uint64_t size, + bool retry_on_failure) { return Allocate(device_ordinal, size, retry_on_failure, /*memory_space=*/0); } // Typed version of the allocation, returning typed memory. template - tsl::StatusOr> Allocate( + absl::StatusOr> Allocate( int device_ordinal, uint64_t size, bool retry_on_failure = true, int64_t memory_space = 0) { return Allocate(device_ordinal, size, retry_on_failure, memory_space); @@ -211,7 +215,7 @@ class DeviceMemoryAllocator { // Must be a nop for null pointers. Should not be used. // // TODO(cheshire): Add deprecation notice. - virtual tsl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0; + virtual absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. const Platform *platform() const { return platform_; } @@ -224,7 +228,7 @@ class DeviceMemoryAllocator { // allocated by this allocator. It is not necessary to use the returned stream // though, as clients may have additional information letting them safely use // a different stream. - virtual tsl::StatusOr GetStream(int device_ordinal) = 0; + virtual absl::StatusOr GetStream(int device_ordinal) = 0; protected: const Platform *platform_; @@ -245,23 +249,23 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { const Platform *platform, absl::Span stream_executors); - tsl::StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override; + absl::StatusOr Allocate(int device_ordinal, uint64_t size, + bool retry_on_failure, + int64_t memory_space) override; // Pull in two-arg overload that sets retry_on_failure to true. using DeviceMemoryAllocator::Allocate; - tsl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; + absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; // Gets-or-creates a stream for a given `device_ordinal` from an appropriate // stream executor. - tsl::StatusOr GetStream(int device_ordinal) override; + absl::StatusOr GetStream(int device_ordinal) override; // Gets the stream executor for given device ordinal. - tsl::StatusOr GetStreamExecutor(int device_ordinal) const; + absl::StatusOr GetStreamExecutor(int device_ordinal) const; private: // Available stream executors. Each stream executor has a different device @@ -275,13 +279,13 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { }; template -tsl::Status ScopedDeviceMemory::Free() { +absl::Status ScopedDeviceMemory::Free() { if (!wrapped_.is_null()) { CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state"; TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_)); } wrapped_ = DeviceMemory{}; - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace stream_executor diff --git a/xla/stream_executor/device_options.h b/xla/stream_executor/device_options.h deleted file mode 100644 index 776fa4220813c..0000000000000 --- a/xla/stream_executor/device_options.h +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Contains device-level options that can be specified at a platform level. -// Example usage: -// auto device_options = DeviceOptions::Default(); - -#ifndef XLA_STREAM_EXECUTOR_DEVICE_OPTIONS_H_ -#define XLA_STREAM_EXECUTOR_DEVICE_OPTIONS_H_ - -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/strings/str_join.h" - -namespace stream_executor { - -// Indicates a set of options for a device's usage, which generally must be -// provided at StreamExecutor device-initialization time. -// -// These are intended to be useful-but-not-mandatorily-supported options for -// using devices on the underlying platform. Presently, if the option requested -// is not available on the target platform, a warning will be emitted. -struct DeviceOptions { - public: - // When it is observed that more memory has to be allocated for thread stacks, - // this flag prevents it from ever being deallocated. Potentially saves - // thrashing the thread stack memory allocation, but at the potential cost of - // some memory space. - static constexpr unsigned kDoNotReclaimStackAllocation = 0x1; - - // The following options refer to synchronization options when - // using SynchronizeStream or SynchronizeContext. - - // Synchronize with spinlocks. - static constexpr unsigned kScheduleSpin = 0x02; - // Synchronize with spinlocks that also call CPU yield instructions. - static constexpr unsigned kScheduleYield = 0x04; - // Synchronize with a "synchronization primitive" (e.g. mutex). - static constexpr unsigned kScheduleBlockingSync = 0x08; - - static constexpr unsigned kMask = 0xf; // Mask of all available flags. - - // Constructs an or-d together set of device options. - explicit DeviceOptions(unsigned flags) : flags_(flags) { - CHECK((flags & kMask) == flags); - } - - // Factory for the default set of device options. - static DeviceOptions Default() { return DeviceOptions(0); } - - unsigned flags() const { return flags_; } - - bool operator==(const DeviceOptions& other) const { - return flags_ == other.flags_ && - non_portable_tags == other.non_portable_tags; - } - - bool operator!=(const DeviceOptions& other) const { - return !(*this == other); - } - - std::string ToString() const { - std::vector flags_on; - if (flags_ & kDoNotReclaimStackAllocation) { - flags_on.push_back("kDoNotReclaimStackAllocation"); - } - if (flags_ & kScheduleSpin) { - flags_on.push_back("kScheduleSpin"); - } - if (flags_ & kScheduleYield) { - flags_on.push_back("kScheduleYield"); - } - if (flags_ & kScheduleBlockingSync) { - flags_on.push_back("kScheduleBlockingSync"); - } - return flags_on.empty() ? "none" : absl::StrJoin(flags_on, "|"); - } - - // Platform-specific device options. Expressed as key-value pairs to avoid - // DeviceOptions subclass proliferation. - std::map non_portable_tags; - - private: - unsigned flags_; -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_DEVICE_OPTIONS_H_ diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 8dcc5264a0091..f812dca34fc9c 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,34 @@ limitations under the License. #include "xla/stream_executor/dnn.h" +#include #include +#include +#include #include #include +#include +#include +#include +#include +#include +#include #include "absl/algorithm/container.h" #include "absl/container/btree_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/stream_executor/data_type.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" #include "tsl/lib/strings/proto_serialization.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/protobuf/dnn.pb.h" namespace stream_executor { @@ -122,7 +138,7 @@ std::vector> AlgorithmDesc::TuningKnobs() const { return result; } -tsl::Status DnnSupport::GetConvolveRunners( +absl::Status DnnSupport::GetConvolveRunners( bool /* use_cudnn_frontend */, dnn::ConvolutionKind /*kind*/, dnn::DataType /*input_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/, @@ -135,10 +151,10 @@ tsl::Status DnnSupport::GetConvolveRunners( bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/, const NumericOptions& /*numeric_options*/, std::vector>* /*exec_plans*/) { - return tsl::errors::Unimplemented("GetConvolveRunners not implemented."); + return absl::UnimplementedError("GetConvolveRunners not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType element_type, @@ -146,10 +162,10 @@ DnnSupport::ConvolveRunnerFromDesc( const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) { - return tsl::errors::Unimplemented("ConvolveRunnerFromDesc not implemented."); + return absl::UnimplementedError("ConvolveRunnerFromDesc not implemented."); } -tsl::Status DnnSupport::GetGraphConvolveRunners( +absl::Status DnnSupport::GetGraphConvolveRunners( dnn::ConvolutionKind /*kind*/, dnn::DataType /*input_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/, @@ -159,10 +175,10 @@ tsl::Status DnnSupport::GetGraphConvolveRunners( bool /*use_fallback*/, const NumericOptions& /*numeric_options*/, std::vector>* /*exec_plans*/, std::string /*serialized_graph*/) { - return tsl::errors::Unimplemented("GetGraphConvolveRunners not implemented."); + return absl::UnimplementedError("GetGraphConvolveRunners not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::GraphConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType element_type, @@ -171,11 +187,11 @@ DnnSupport::GraphConvolveRunnerFromDesc( const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor, std::string serialized_graph) { - return tsl::errors::Unimplemented( + return absl::UnimplementedError( "GraphConvolveRunnerFromDesc not implemented."); } -tsl::Status DnnSupport::GetFusedConvolveRunners( +absl::Status DnnSupport::GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType bias_type, dnn::DataType output_type, double conv_input_scale, double side_input_scale, @@ -187,10 +203,10 @@ tsl::Status DnnSupport::GetFusedConvolveRunners( const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, dnn::ActivationMode activation_mode, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { - return tsl::errors::Unimplemented("GetFusedConvolveRunners not implemented."); + return absl::UnimplementedError("GetFusedConvolveRunners not implemented."); } -tsl::Status DnnSupport::GetFusedMatmulRunners( +absl::Status DnnSupport::GetFusedMatmulRunners( bool use_cudnn_frontend, dnn::DataType element_type, dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, int64_t lda, @@ -198,10 +214,10 @@ tsl::Status DnnSupport::GetFusedMatmulRunners( bool use_fallback, const NumericOptions& numeric_options, std::vector>* out_exec_plans) { - return tsl::errors::Unimplemented("GetFusedMatmulRunners not implemented."); + return absl::UnimplementedError("GetFusedMatmulRunners not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::FusedConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType element_type, @@ -213,23 +229,27 @@ DnnSupport::FusedConvolveRunnerFromDesc( const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::ActivationMode activation_mode) { - return tsl::errors::Unimplemented( + return absl::UnimplementedError( "FusedConvolveRunnerFromDesc not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor) { return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::FusedMHARunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -246,7 +266,7 @@ DnnSupport::FusedMHARunnerFromDesc( return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); } -tsl::StatusOr> +absl::StatusOr> DnnSupport::FusedMHABackwardRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::FusedMHAKind kind, @@ -287,7 +307,7 @@ bool DnnSupport::GetRnnAlgorithms(std::vector* out_algorithms) { return false; } -tsl::Status DnnSupport::DoPoolForward( +absl::Status DnnSupport::DoPoolForward( DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, @@ -300,7 +320,7 @@ tsl::Status DnnSupport::DoPoolForward( output_data, workspace_allocator); } -tsl::Status DnnSupport::DoPoolBackward( +absl::Status DnnSupport::DoPoolBackward( DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, @@ -585,7 +605,7 @@ int TensorDescriptor::ndims() const { return dimensions_.size(); } -tsl::StatusOr> +absl::StatusOr> TensorDescriptor::GetPhysicalDimensionsMajorToMinor() const { std::vector logical_to_physical(minor_to_major_.size()); for (int64_t physical = 0; physical < logical_to_physical.size(); @@ -594,8 +614,7 @@ TensorDescriptor::GetPhysicalDimensionsMajorToMinor() const { logical_to_physical[logical] = physical; } if (dimensions_.size() != minor_to_major_.size()) - return tsl::errors::Internal( - "Dimensions size should match the layout size."); + return absl::InternalError("Dimensions size should match the layout size."); std::vector physical_dims(dimensions_.size()); for (int64_t i = 0; i < physical_dims.size(); ++i) { @@ -645,14 +664,14 @@ std::string TensorDescriptor::ToString() const { // -- MatmulTensorDescriptor -tsl::StatusOr> +absl::StatusOr> MatmulTensorDescriptor::GetNonContractingDims() const { std::vector non_contracting_dims; for (int64_t dim = 0; dim < tensor_.dimensions().size(); ++dim) { bool is_batch = absl::c_count(batch_dimension_numbers_, dim) != 0; bool is_contracting = absl::c_count(contracting_dim_, dim) != 0; if (is_batch && is_contracting) - return tsl::errors::Internal( + return absl::InternalError( "A dimension cannot be both a batch dimension and a contracting " "dimension."); if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim); @@ -661,14 +680,15 @@ MatmulTensorDescriptor::GetNonContractingDims() const { if (batch_dimension_numbers_.size() + contracting_dim_.size() + non_contracting_dims.size() != tensor_.dimensions().size()) - return tsl::errors::Internal( + return absl::InternalError( "Batch_dimension_numbers, contracting_dim and non_contracting_dims " "should sum up to the total number of dimensions."); return non_contracting_dims; } -tsl::StatusOr> MatmulTensorDescriptor::MakeCudnnCompatible( - const std::vector& vec, bool is_lhs) const { +absl::StatusOr> +MatmulTensorDescriptor::MakeCudnnCompatible(const std::vector& vec, + bool is_lhs) const { std::vector cudnn_compatible(vec.size()); int batch_dim_size = batch_dimension_numbers_.size(); CHECK_LT(batch_dim_size, vec.size()); @@ -679,7 +699,7 @@ tsl::StatusOr> MatmulTensorDescriptor::MakeCudnnCompatible( if (batch_dimension_numbers_.size() + contracting_dim_.size() + non_contracting_dims.size() != vec.size()) - return tsl::errors::Internal( + return absl::InternalError( "Batch_dimension_numbers, contracting_dim and non_contracting_dims " "should sum up to the total number of dimensions."); if (is_lhs) /* lhs -> {b0, b1,....bk, m, k} */ { @@ -1148,7 +1168,7 @@ std::string NormalizeDescriptor::ToShortString() const { "_size:", segment_size_); } -bool DnnSupport::IsStatusOk(const tsl::Status& status, bool report_error) { +bool DnnSupport::IsStatusOk(const absl::Status& status, bool report_error) { if (status.ok()) { return true; } @@ -1158,7 +1178,7 @@ bool DnnSupport::IsStatusOk(const tsl::Status& status, bool report_error) { return false; } -tsl::Status DnnSupport::DoCtcLoss( +absl::Status DnnSupport::DoCtcLoss( Stream* stream, dnn::DataType element_type, const RnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, @@ -1166,7 +1186,7 @@ tsl::Status DnnSupport::DoCtcLoss( absl::Span input_lengths_data, DeviceMemoryBase costs_data, const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data, DeviceMemory scratch_memory, int ctc_loss_algo_id) { - return tsl::errors::Unimplemented("CtcLoss not implemented"); + return absl::UnimplementedError("CtcLoss not implemented"); } } // namespace dnn diff --git a/xla/stream_executor/dnn.h b/xla/stream_executor/dnn.h index a181516e9b26e..961def24f07cd 100644 --- a/xla/stream_executor/dnn.h +++ b/xla/stream_executor/dnn.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_DNN_H_ #define XLA_STREAM_EXECUTOR_DNN_H_ +#include #include -#include #include #include #include @@ -35,16 +35,16 @@ limitations under the License. #include #include "google/protobuf/wrappers.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/data_type.h" -#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" -#include "xla/stream_executor/platform/port.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "tsl/protobuf/dnn.pb.h" namespace Eigen { @@ -149,7 +149,8 @@ enum class RnnDirectionMode { class TensorDescriptor { public: TensorDescriptor() = default; - tsl::StatusOr> GetPhysicalDimensionsMajorToMinor() const; + absl::StatusOr> GetPhysicalDimensionsMajorToMinor() + const; std::vector GetPhysicalStridesMajorToMinor() const; std::vector GetLogicalStrides() const; @@ -178,14 +179,14 @@ class TensorDescriptor { class MatmulTensorDescriptor { public: MatmulTensorDescriptor() = default; - tsl::StatusOr> GetNonContractingDims() const; + absl::StatusOr> GetNonContractingDims() const; std::vector GetCudnnCompatibleDimensions( bool is_lhs /*if not lhs, then rhs*/) const; std::vector GetCudnnCompatibleStrides( bool is_lhs /*if not lhs, then rhs*/) const; - tsl::StatusOr> MakeCudnnCompatible( + absl::StatusOr> MakeCudnnCompatible( const std::vector&, bool is_lhs) const; static MatmulTensorDescriptor For(DataType type, @@ -218,7 +219,7 @@ class MatmulTensorDescriptor { // Specifies the descriptor for a RNN model. // // An example use case: -// * The user first creates a model through createRnnDescriptor. +// * The user first creates a model through CreateRnnDescriptor. // * The user queries the size of the underlying opaque parameter buffer. // * The user creates and initializes a parameter buffer of the proper size. // * The user runs forward and backward operations using this RNN descriptor. @@ -417,7 +418,7 @@ class BatchDescriptor { // dimensions, except possibly for feature_map_count(), though this // function does not verify that. static BatchDescriptor DepthConcatenateOutputDescriptor( - absl::Span inputs); + absl::Span inputs); private: absl::Span spatial_size() const { @@ -568,7 +569,7 @@ enum class PadAlignment : int64_t { std::string PadAlignmentString(PadAlignment alignment); // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments. -std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment); +std::ostream& operator<<(std::ostream& str, PadAlignment alignment); // Describes a convolution. // @@ -909,6 +910,8 @@ class ProfileResult { return algorithm_.has_value() && elapsed_time_in_ms() != std::numeric_limits::max(); } + bool warmup_run_executed() const { return warmup_run_executed_; } + void set_warmup_run_executed(bool val) { warmup_run_executed_ = val; } AlgorithmDesc algorithm() const { return *algorithm_; } void set_algorithm(AlgorithmDesc val) { algorithm_ = val; } @@ -925,6 +928,7 @@ class ProfileResult { // The scratch size algorithm_ requires. Currently it's only populated by // convolutions. size_t scratch_size_ = 0; + bool warmup_run_executed_ = false; }; // Backend-specific data shared between repeated launches of the same @@ -957,12 +961,12 @@ class OpRunner { virtual size_t GetWorkspaceSize() const = 0; // Convert to an AlgorithmDesc for AoT compilation or autotuning. - virtual tsl::StatusOr ToAlgorithmDesc() const = 0; + virtual absl::StatusOr ToAlgorithmDesc() const = 0; // Launch the operation, with the signature determined by `Sig`. - virtual tsl::Status operator()(Stream*, ProfileResult*, - DeviceMemoryBase scratch_memory, - Args... args) const = 0; + virtual absl::Status operator()(Stream*, ProfileResult*, + DeviceMemoryBase scratch_memory, + Args... args) const = 0; }; using ConvSignature = void(DeviceMemoryBase /* input_data */, @@ -995,7 +999,9 @@ using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/, DeviceMemoryBase /* output_data */, DeviceMemoryBase /* mask_data */, DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* activation_data */); + DeviceMemoryBase /* activation_data */, + DeviceMemoryBase /* seqlen_q_data */, + DeviceMemoryBase /* seqlen_k_data */); using FusedMHARunner = OpRunner; using FusedMHABackwardSignature = void( @@ -1010,7 +1016,8 @@ using FusedMHABackwardSignature = void( DeviceMemoryBase /* softmax_sum_data */, DeviceMemoryBase /* d_Q_accum_data */, DeviceMemoryBase /* mask_data */, DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */, - DeviceMemoryBase /* bias_data */); + DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */, + DeviceMemoryBase /* seqlen_k_data */); using FusedMHABackwardRunner = OpRunner; // Describes the configuration for the algorithms that will used. @@ -1246,6 +1253,25 @@ class VersionInfo { int patch_; }; +class DnnSupport; + +class DnnGraph { + public: + DnnGraph() = default; + virtual ~DnnGraph() = default; + + // Returns non-OK status on hard failures (incorrectly constructed graph, + // anything else unexpected), + // false on expected ones (graph is valid but not supported), + // true on success. + virtual absl::StatusOr Prepare(DnnSupport&) = 0; + virtual absl::Status Build(DnnSupport&, int64_t plan_id) = 0; + virtual absl::Status Execute(Stream& stream, + absl::Span operands) const = 0; +}; + +using LazyDnnGraph = std::unique_ptr; + // Suite of operations typically used for implementing Deep/Convolutional Neural // Nets. Note: A false return value of an operation indicates the // implementation is not available. @@ -1259,7 +1285,7 @@ class VersionInfo { // functions are actually implemented by both backends, the rest are // actually backend-specific. The massive interface creates extra mental // burden. -// * Poor error handling: the API should return tsl::Status objects. +// * Poor error handling: the API should return absl::Status objects. // // PrepareForConvolution is an example for how new APIs should be written. class DnnSupport { @@ -1267,11 +1293,11 @@ class DnnSupport { DnnSupport() = default; virtual ~DnnSupport() = default; - virtual tsl::Status Init() = 0; + virtual absl::Status Init() = 0; // Gets the version of the backing library, as a VersionInfo object. - virtual tsl::StatusOr GetVersion() { - return tsl::errors::Unimplemented( + virtual absl::StatusOr GetVersion() { + return absl::UnimplementedError( "DnnSupport::GetVersion not implemented on this platform."); } @@ -1310,12 +1336,11 @@ class DnnSupport { const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, - const DeviceMemory& side_input, const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, - DeviceMemory* batch_mean, DeviceMemory* batch_var, - DeviceMemory* reserve_space_1, + const DeviceMemory& side_input, const BatchDescriptor& x_desc, + const BatchDescriptor& scale_offset_desc, const double epsilon, + const double exponential_average_factor, ActivationMode activation_mode, + DeviceMemory* y, DeviceMemory* batch_mean, + DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator) { @@ -1330,10 +1355,9 @@ class DnnSupport { const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, const DeviceMemory& side_input, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, const double exponential_average_factor, + ActivationMode activation_mode, DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, @@ -1350,10 +1374,9 @@ class DnnSupport { const DeviceMemory& estimated_mean, const DeviceMemory& estimated_variance, const DeviceMemory& side_input, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory* y, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, const double exponential_average_factor, + ActivationMode activation_mode, DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, DeviceMemory* reserve_space_1, DeviceMemory* reserve_space_2, bool is_training, @@ -1384,10 +1407,10 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory* x_backprop, - DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, ActivationMode activation_mode, + DeviceMemory* x_backprop, DeviceMemory* scale_backprop, + DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) { @@ -1402,9 +1425,8 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, + const BatchDescriptor& x_desc, const BatchDescriptor& scale_offset_desc, + const double epsilon, ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, @@ -1421,11 +1443,9 @@ class DnnSupport { const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, const DeviceMemory& mean, const DeviceMemory& inv_var, - const DeviceMemory& y, - const dnn::BatchDescriptor& x_desc, - const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory* x_backprop, + const DeviceMemory& y, const BatchDescriptor& x_desc, + const BatchDescriptor& scale_offset_desc, const double epsilon, + ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, DeviceMemory* reserve_space_data, @@ -1483,27 +1503,49 @@ class DnnSupport { // that if the inverse of the filter is applied to the output in VALID mode // the result is the same size as the input - this requires even more // padding of the input. - virtual tsl::Status DoFusedConvolve( + virtual absl::Status DoFusedConvolve( Stream* stream, DataType input_type, DataType side_input_type, DataType bias_type, DataType output_type, - const dnn::BatchDescriptor& conv_input_descriptor, + const BatchDescriptor& conv_input_descriptor, DeviceMemoryBase conv_input_data, double conv_input_scale, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const ConvolutionDescriptor& convolution_descriptor, DeviceMemoryBase side_input_data, double side_input_scale, - const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases, - dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& output_descriptor, + const BatchDescriptor& bias_descriptor, DeviceMemoryBase biases, + ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator, - const dnn::AlgorithmConfig& algorithm_config, - dnn::ProfileResult* output_profile_result) { - return tsl::errors::Unimplemented( + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return absl::UnimplementedError( "DnnSupport::DoFusedConvolve not implemented on this platform."); } + template + absl::Status FusedConvolveWithAlgorithm( + Stream* stream, const BatchDescriptor& conv_input_descriptor, + const DeviceMemory& conv_input_data, ScaleT conv_input_scale, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& side_input_data, ScaleT side_input_scale, + const BatchDescriptor& bias_descriptor, const DeviceMemory& biases, + ActivationMode activation_mode, const BatchDescriptor& output_descriptor, + DeviceMemory* output, ScratchAllocator* scratch_allocator, + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return DoFusedConvolve( + stream, ToDataType::value, ToDataType::value, + ToDataType::value, ToDataType::value, + conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, side_input_data, + side_input_scale, bias_descriptor, biases, activation_mode, + output_descriptor, *output, scratch_allocator, algorithm_config, + output_profile_result); + } + template - tsl::Status PrepareForConvolution( + absl::Status PrepareForConvolution( ConvolutionKind kind, Stream* stream, const BatchDescriptor& batch_descriptor, DeviceMemory input_data, @@ -1524,13 +1566,13 @@ class DnnSupport { // cuDNN-specific input transformation that allows running int8x32 // convolutions faster using Tensor Core IMMA instruction. - virtual tsl::Status CudnnReorderConvolutionFilterAndBias( + virtual absl::Status CudnnReorderConvolutionFilterAndBias( Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_input, DeviceMemory* filter_output, std::optional> bias_input, std::optional> bias_output) { - return tsl::errors::Unimplemented( + return absl::UnimplementedError( "DnnSupport::CudnnReorderConvolutionFilterAndBias is specific to CUDA " "convolution implementation."); } @@ -1569,7 +1611,7 @@ class DnnSupport { // that if the inverse of the filter is applied to the output in VALID mode // the result is the same size as the input - this requires even more // padding of the input. - virtual tsl::Status DoConvolve( + virtual absl::Status DoConvolve( ConvolutionKind kind, DataType element_type, DataType output_type, Stream* stream, const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor, @@ -1579,115 +1621,140 @@ class DnnSupport { AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, ProfileResult* output_profile_result) = 0; - virtual tsl::Status GetConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, + template + absl::Status ConvolveWithAlgorithm( + Stream* stream, ConvolutionKind kind, + const BatchDescriptor& input_descriptor, + DeviceMemory input_data, + const FilterDescriptor& filter_descriptor, + DeviceMemory filter_data, + const BatchDescriptor& output_descriptor, + DeviceMemory output_data, + const ConvolutionDescriptor& convolution_descriptor, + ScratchAllocator* scratch_allocator, + const AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + DeviceMemory scratch_memory; + AlgorithmDesc algorithm_desc; + TF_RETURN_IF_ERROR(PrepareForConvolution( + kind, stream, input_descriptor, input_data, filter_descriptor, + filter_data, output_descriptor, output_data, convolution_descriptor, + algorithm_config, scratch_allocator, &algorithm_desc, &scratch_memory)); + return DoConvolve(kind, ToDataType::value, + ToDataType::value, stream, input_descriptor, + input_data, filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, + algorithm_desc, scratch_memory, output_profile_result); + } + + virtual absl::Status GetConvolveRunners( + bool use_cudnn_frontend, ConvolutionKind kind, DataType input_type, + DataType output_type, Stream* stream, + const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, - std::vector>* out_exec_plans); - - virtual tsl::StatusOr> - ConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor); - - virtual tsl::Status GetGraphConvolveRunners( - dnn::ConvolutionKind kind, dnn::DataType input_type, - dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + std::vector>* out_exec_plans); + + virtual absl::StatusOr> + ConvolveRunnerFromDesc(Stream* stream, const AlgorithmDesc& algorithm_desc, + ConvolutionKind kind, DataType element_type, + DataType output_type, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor); + + virtual absl::Status GetGraphConvolveRunners( + ConvolutionKind kind, DataType input_type, DataType output_type, + Stream* stream, const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, std::string serialized_graph); - virtual tsl::StatusOr> + virtual absl::StatusOr> GraphConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, + Stream* stream, const AlgorithmDesc& algorithm_desc, ConvolutionKind kind, + DataType element_type, DataType output_type, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, std::string serialized_graph); - virtual tsl::Status GetFusedConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, + virtual absl::Status GetFusedConvolveRunners( + bool use_cudnn_frontend, ConvolutionKind kind, DataType element_type, + DataType bias_type, DataType output_type, double conv_input_scale, double side_input_scale, double leakyrelu_alpha, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, dnn::ActivationMode activation_mode, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans); - - virtual tsl::Status GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType element_type, - dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, - bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, - int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& bias_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ActivationMode activation_mode, const NumericOptions& numeric_options, + std::vector>* out_exec_plans); + + virtual absl::Status GetFusedMatmulRunners( + bool use_cudnn_frontend, DataType element_type, DataType bias_type, + DataType output_type, Stream* stream, bool trans_a, bool trans_b, + uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, + ActivationMode activation_mode, bool use_fallback, const NumericOptions& numeric_options, - std::vector>* - out_exec_plans); + std::vector>* out_exec_plans); - virtual tsl::StatusOr> + virtual absl::StatusOr> FusedConvolveRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::ConvolutionKind kind, dnn::DataType element_type, - dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, - double side_input_scale, double leakyrelu_alpha, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - dnn::ActivationMode activation_mode); - - virtual tsl::StatusOr> + Stream* stream, const AlgorithmDesc& algorithm_desc, ConvolutionKind kind, + DataType element_type, DataType bias_type, DataType output_type, + double conv_scale, double side_input_scale, double leakyrelu_alpha, + const BatchDescriptor& input_descriptor, + const FilterDescriptor& filter_descriptor, + const BatchDescriptor& bias_descriptor, + const BatchDescriptor& output_descriptor, + const ConvolutionDescriptor& convolution_descriptor, + ActivationMode activation_mode); + + virtual absl::StatusOr> NormRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, double epsilon, - const dnn::TensorDescriptor& input_descriptor, + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::NormKind kind, double epsilon, + const dnn::TensorDescriptor& x_descriptor, const dnn::TensorDescriptor& scale_descriptor, - const dnn::TensorDescriptor& bias_descriptor, - const dnn::TensorDescriptor& output_descriptor, + const dnn::TensorDescriptor& y_or_dx_descriptor, + std::optional bias_descriptor, + std::optional dy_descriptor, std::optional expectation_descriptor, - std::optional norm_factor_descriptor); + std::optional norm_factor_descriptor, + std::optional dscale_descriptor, + std::optional dbias_descriptor); + + virtual absl::StatusOr> DeserializeGraph( + absl::string_view) const { + return absl::UnimplementedError("Graph support requires cuDNN >= 8.1."); + }; - virtual tsl::StatusOr> + virtual absl::StatusOr> FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional mask_descriptor, - std::optional bias_descriptor, double scale, + Stream* stream, const AlgorithmDesc& algorithm_desc, FusedMHAKind kind, + const MatmulTensorDescriptor& bmm1_lhs_descriptor, + const MatmulTensorDescriptor& bmm1_rhs_descriptor, + const MatmulTensorDescriptor& bmm2_rhs_descriptor, + const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, + const TensorDescriptor& output_descriptor, + std::optional activation_descriptor, + std::optional mask_descriptor, + std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask); - virtual tsl::StatusOr> + virtual absl::StatusOr> FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - dnn::FusedMHAKind kind, + Stream* stream, const AlgorithmDesc& algorithm_desc, FusedMHAKind kind, const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, @@ -1696,152 +1763,57 @@ class DnnSupport { const TensorDescriptor& d_bmm1_lhs_descriptor, const TensorDescriptor& d_bmm1_rhs_descriptor, const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional mask_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, + std::optional d_s_descriptor, + std::optional mask_descriptor, + std::optional d_bias_descriptor, + std::optional fwd_output_descriptor, + std::optional bias_descriptor, double scale, std::optional dropout_rate, std::optional seed, bool is_flash_attention, bool is_causal_mask); virtual bool GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, + ConvolutionKind kind, DataType element_type, Stream* stream, + const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, + const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, + const ConvolutionDescriptor& convolution_descriptor, ScratchAllocator* scratch_allocator, std::vector* out_algorithms); // Returns a list of supported rnn algorithms. virtual bool GetRnnAlgorithms(std::vector* out_algorithms); - // Variation of the above with the weight matrix split into two matrices. - // first_weights: Coefficients of the first matrix. - // second_weights: Coefficients of the second matrix. - // depth_multiplier: specifies the columns of the first matrix and rows - // of the second one - first_weights columns = depth_multiplier, - // second_weights rows = depth_multiplier * - // filter_descriptor.input_feature_map_count(). - // see go/separable for documentation on separable convolutions. - virtual bool DoSeparableConvolve( - Stream* stream, const BatchDescriptor& input_descriptor, - const DeviceMemory& input_data, - const FilterDescriptor& filter_descriptor, int depth_multiplier, - const DeviceMemory& first_weights, - const DeviceMemory& second_weights, - const ConvolutionDescriptor& convolution_descriptor, - const BatchDescriptor& output_descriptor, - DeviceMemory* output_data) = 0; + template + absl::Status PoolForward(Stream* stream, + const PoolingDescriptor& pooling_dimensions, + const NumericOptions& numeric_options, + const BatchDescriptor& input_dimensions, + const DeviceMemory& input_data, + const BatchDescriptor& output_dimensions, + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator = nullptr) { + return DoPoolForward(ToDataType::value, stream, + pooling_dimensions, numeric_options, input_dimensions, + input_data, output_dimensions, *output_data, + workspace_allocator); + } - // Fully connects the "nodes" (float values) in input_data with - // shape input_dimensions to output_data with output_dimensions - // using provided weights. This is equivalent to computing a matrix - // product, hence the name MatMul. - // - // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products - // happen in two dimensions. To get down to two dimensions, we consider the - // input y, x and depth dimension as one combined dimension T. For now, - // assume that the output height and width are 1 and let OD be the output - // depth. - // - // There are three device memory buffers passed in to this - // function. We can now view all three as matrices: - // - // input_data: A batch x T matrix - // weights: A T x OD matrix - // output_data: A batch x OD matrix - // - // This function then computes the matrix product of input_data and - // weights and writes the result into output_data. - // - // Here the weights buffer is in row major order, i.e. the first OD - // entries in weights are the first row, the second OD entries in - // weights are the second row and so on. - // - // The case for output width*height > 1 is more complicated. Let K = - // OY * OX where OY is the output height and OX is the output - // width. Then weights is divided into K sub-arrays W_i, for - // i=0,...,k-1, that each represent a T x OD matrix. This function - // then computes the K matrix multiplications of input_data with - // each W_i. This creates K matrices with dimensions batch x - // OD. These K matrices are concatenated horizontally to form one - // larger matrix with dimensions batch x (K*OD); note that this is - // not the same as concatenating the bytes of the matrices. The - // combined matrix can then be interpreted as a tensor with - // dimensions (batch, OY, OX, OD). If the output tensor format is - // not kBatchYXDepth, this function would then need to arrange for - // the output to be in the requested layout, if that is - // supported. Note that the case K=1 is equivalent to the - // description above. It is recommended to prefer the case K=1. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'fully connect' operation - // should be enqueued onto. - // output_data: un-owned device memory region in which to place the - // fully connected result. - virtual bool DoMatMul(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& weights, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) = 0; - - // Version of DoMatMul that uses pre-quantized 8 bit weights. - // weight_scales specifies the scaling of each column of weights: - // original float weight[row * num_columns + column] = - // quantized_weight[row * nnum_columns + column] * weight_scales[column]. - virtual bool DoMatMulQuantized(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) = 0; - - // Version of DoMatMul that uses pre-quantized 16 bit weights. - // weight_scales specifies the scaling of each column of weights: - // original float weight[row * num_columns + column] = - // quantized_weight[row * nnum_columns + column] * weight_scales[column]. - virtual bool DoMatMulQuantized(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) = 0; - - // Adds biases to the feature maps in input_data producing - // output_data. input_data can equal output_data, but must not - // partially overlap it. - // - // Let K = count() * height() * width() and N = feature_map_count() - // on dimensions. Then input_value contains K*N values and biases - // contains N values. We can thus logically consider input_value to - // contain K vectors of N elements each. This function adds biases - // to each of those N vectors. - // - // TODO(broune): This works differently when width() * height() > 1 - // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In - // that case there should be width() * height() * - // feature_map_count() biases, but this is not implemented on all - // StreamExecutors. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'bias add' operation - // should be enqueued onto. - // input_data: un-owned device memory region containing the input. - // biases: un-owned device memory region containing biases to add to the - // input. - // dimensions: dimensions of input_data and output_data. - // output_data: un-owned device memory region in which to place the result. - virtual bool DoBiasAdd(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& biases, - const dnn::BatchDescriptor& dimensions, - DeviceMemory* output_data) = 0; - - // Performs a forward pooling operation on input_data, writing to + template + absl::Status PoolBackward(Stream* stream, + const PoolingDescriptor& pooling_dimensions, + const NumericOptions& numeric_options, + const BatchDescriptor& input_dimensions, + const DeviceMemory& input_data, + const BatchDescriptor& output_dimensions, + const DeviceMemory& output_data, + const DeviceMemory& input_diff_data, + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator = nullptr) { + return DoPoolBackward( + ToDataType::value, stream, pooling_dimensions, + numeric_options, input_dimensions, input_data, output_dimensions, + output_data, input_diff_data, *output_diff_data, workspace_allocator); + } // Performs a forward pooling operation on input_data, writing to // output_data. See PoolingDescriptor for how to configure the // pooling operation. // @@ -1854,39 +1826,38 @@ class DnnSupport { // the input. The output width and height can be different. // // See PoolingDescriptor for how to configure the pooling operation. - virtual tsl::Status DoPoolForward( + virtual absl::Status DoPoolForward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) = 0; + const PoolingDescriptor& pooling_dimensions, + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + ScratchAllocator* workspace_allocator) = 0; - virtual tsl::Status DoPoolForward( + virtual absl::Status DoPoolForward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, + const PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator); + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + ScratchAllocator* workspace_allocator); // Performs differentiation of the pooling operation. - virtual tsl::Status DoPoolBackward( + virtual absl::Status DoPoolBackward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, + const PoolingDescriptor& pooling_dimensions, + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) = 0; - virtual tsl::Status DoPoolBackward( + virtual absl::Status DoPoolBackward( DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, + const PoolingDescriptor& pooling_dimensions, const NumericOptions& numeric_options, - const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator); + const BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, + const BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, + ScratchAllocator* workspace_allocator); // Applies local response normalization to the values from input_data and // writes the result to output_data. @@ -1894,9 +1865,9 @@ class DnnSupport { // See comments on NormalizeDescriptor for a description of local response // normalization. virtual bool DoNormalizeWithDimensions( - Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, DeviceMemory* output_data) { + Stream* stream, const NormalizeDescriptor& normalize_descriptor, + const BatchDescriptor& dimensions, const DeviceMemory& input_data, + DeviceMemory* output_data) { return false; } @@ -1913,9 +1884,8 @@ class DnnSupport { // See comments on NormalizeDescriptor for a description of local response // normalization. virtual bool DoNormalizeBackwardWithDimensions( - Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& raw_data, + Stream* stream, const NormalizeDescriptor& normalize_descriptor, + const BatchDescriptor& dimensions, const DeviceMemory& raw_data, const DeviceMemory& normalized_data, const DeviceMemory& normalized_variable_gradient, DeviceMemory* raw_variable_gradient, @@ -1923,219 +1893,6 @@ class DnnSupport { return false; } - // Applies an activation function (see ActivationMode) to all of the values - // held on the device in 'input_data', whose dimensions are described by - // 'dimensions'. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'activate' operation - // should be enqueued onto. - // activation_mode: Type of activation to perform. - // input_data: un-owned device memory region which contains the - // activate input. - // output_data: un-owned device memory region in which to place the - // activate result. - virtual bool DoActivate(Stream* stream, ActivationMode activation_mode, - const BatchDescriptor& dimensions, - const DeviceMemory& input_data, - DeviceMemory* output_data, uint64_t options) { - return false; - } - - // Concatenates several layers into one, by concatenating the depth of each - // layer at matching x and y coordinates. - // The inputs must all have the same width and height, the output will have - // the same width and height as the inputs and its depth will be the sum of - // the input depths. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'depth concatenate' - // operation should be enqueued onto. - // input_dimensions: The dimensions of each input. - // input_data: un-owned device memory region which contains the - // input data for each input layer. - // output_data: un-owned device memory region in which to place the - // depth concatenate result. - virtual bool DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) = 0; - - // Computes the specified operation (e.g. addition or multiplication) - // between corresponding elements in the inputs and stores the result in the - // output element. - // The inputs and output must all have the same dimensions, but may have - // different quantization parameters (min_value and max_value). - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'elementwise operation' - // should be enqueued onto. - // operation: The operation to perform. - // input_dimensions: The dimensions of each input. - // input_data: un-owned device memory region which contains the - // input data for each input layer. - // output_dimensions: The dimensions of the output. - // output_data: un-owned device memory region in which to place the - // operation result. - virtual bool DoElementwiseOperate( - Stream* stream, ElementwiseOperation operation, - absl::Span input_dimensions, - absl::Span* const> input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) = 0; - - // Computes the specified operation (e.g. addition or multiplication) - // between corresponding elements in the inputs and stores the result in the - // output element. Each input is multiplied by a scalar constant and the - // result is divided by a scalar constant. - // e.g. To perform Z = 0.9*X + 1.1*Y, set the input multiplicands to 9 and 11 - // and the output divisor to 10. - // The inputs and output must all have the same dimensions, but may have - // different quantization parameters (min_value and max_value). - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'elementwise operation' - // should be enqueued onto. - // operation: The operation to perform. - // input_multiplicands: Amount to scale each input. - // output_divisor: Amount to divide the output. - // input_dimensions: The dimensions of each input. - // input_data: un-owned device memory region which contains the - // input data for each input layer. - // output_dimensions: The dimensions of the output. - // output_data: un-owned device memory region in which to place the - // operation result. - virtual bool DoElementwiseOperateScaledQuantized( - Stream* stream, ElementwiseOperation operation, - absl::Span input_multiplicands, int output_divisor, - absl::Span input_dimensions, - absl::Span* const> input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { - return false; - } - - // Pads the input with zeros in the X and Y dimensions. The feature_map - // dimension is unchanged. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'elementwise operation' - // should be enqueued onto. - // dimensions: The dimensions of the input. - // input_data: un-owned device memory region which contains the - // input data for the input layer. - // left_pad: Amount to pad the input on the left. - // right_pad: Amount to pad the input on the right. - // top_pad: Amount to pad the input at the top (low Y). - // bottom_pad: Amount to pad the input at the bottom (high Y). - // output_data: un-owned device memory region in which to place the - // padded result. - virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, int64_t left_pad, - int64_t right_pad, int64_t top_pad, int64_t bottom_pad, - DeviceMemory* output_data) = 0; - - // Extracts a slice of the input in the X and Y dimensions. The feature_map - // dimension is unchanged. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'elementwise operation' - // should be enqueued onto. - // dimensions: The dimensions of the input. - // input_data: un-owned device memory region which contains the - // input data for the input layer. - // left_trim: Amount to cut off the input on the left. - // right_trim: Amount to cut off the input on the right. - // top_trim: Amount to cut off the input at the top (low y). - // bottom_trim: Amount to cut off the input at the bottom (high Y). - // output_data: un-owned device memory region in which to place the - // padded result. - virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t left_trim, int64_t right_trim, - int64_t top_trim, int64_t bottom_trim, - DeviceMemory* output_data) = 0; - - // Grows the input tensor by replicating the X and Y dimensions. The batch and - // depth/feature_map dimensions are unchanged. Currently, the input tensor is - // limited to X=1 and Y=1. - // - // For example, the input has dimensions x=2, y=3, and replicate_x=3, - // replicate_y=2. The diagonal elements of the output would be: [x0y0, x1y1, - // x0y2, x1y0, x0y1, x1y2]. - // Here is the example as a picture. input: - // AB - // CD - // EF - // broadcast result: - // ABABAB - // CDCDCD - // EFEFEF - // ABABAB - // CDCDCD - // EFEFEF - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'elementwise operation' - // should be enqueued onto. - // dimensions: The dimensions of the input. - // input_data: un-owned device memory region which contains the - // input data for the input layer. - // replicate_x: Amount to replicate the input's X dimension. - // replicate_y: Amount to replicate the input's Y dimension. - // output_data: un-owned device memory region in which to place the - // padded result. - virtual bool DoXYBroadcast(Stream* stream, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t replicate_x, int64_t replicate_y, - DeviceMemory* output_data) { - return false; - } - - // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that - // is, bytes instead of scaled floats) into 'host_dst' if they are available - // for the underlying DNN implementation. If this quantized output is not - // available, false is returned, which will place 'stream' into an error - // state. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'quantized memcpy' - // operation should be enqueued onto. - // gpu_unquantized_src: the device memory that contains the unquantized data - // -- this data should also have a corresponding quantized representation - // on the device for this operation to succeed. - // mode: Type of quantization of the data to write into host_dst. - // host_dst: un-owned host memory region that is mutated in place, - // it is clobbered by the values in 'gpu_unquantized_src' when the enqueued - // (asynchronous) memcpy operation is performed. - // size: size in bytes of the host_dst host memory region. - virtual bool DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory& gpu_unquantized_src, - QuantizedActivationMode mode, void* host_dst, int64_t size) = 0; - - // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input - // of a layer (that is, bytes instead of scaled floats) if they are supported - // by the underlying DNN implementation. If this quantized input is not - // supported, false is returned, which will place 'stream' into an error - // state. - // - // Arguments (all borrowed): - // stream: borrowed pointer to the stream that the 'quantized memcpy' - // operation should be enqueued onto. - // host_src: un-owned host memory region that contains the quantized data. - // size: size in bytes of the host_src host memory region. - // mode: Type of quantization of the data to read from host_src. - // gpu_unquantized_dst: the device memory that is clobbered by the values in - // 'host_src' when the enqueued (asynchronous) memcpy operation is - // performed. -- this data should also have a corresponding quantized - // representation on the device for this operation to - // succeed. - virtual bool DoMemcpyH2DQuantized( - Stream* stream, const void* host_src, int64_t size, - QuantizedActivationMode mode, - DeviceMemory* gpu_unquantized_dst) = 0; - // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. // @@ -2158,18 +1915,14 @@ class DnnSupport { // for dropout layer. The user has to maintain the memory until the model // is no longer in use. // use_padded_io: a bool to specify whether the input is using padded IO. - virtual tsl::StatusOr> - createRnnDescriptor(int num_layers, int hidden_size, int input_size, - int cell_size, int batch_size, - dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, - dnn::RnnMode rnn_mode, dnn::DataType data_type, - const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, - uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) { - return tsl::Status(absl::StatusCode::kUnimplemented, - "createRnnDescriptor is unimplemented"); + virtual absl::StatusOr> CreateRnnDescriptor( + int num_layers, int hidden_size, int input_size, int cell_size, + int batch_size, RnnInputMode input_mode, RnnDirectionMode direction_mode, + RnnMode rnn_mode, DataType data_type, + const AlgorithmConfig& algorithm_config, + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { + return absl::UnimplementedError("CreateRnnDescriptor is unimplemented"); } // Create a RNN sequence descriptor that specifies either the input or output @@ -2181,36 +1934,36 @@ class DnnSupport { // data_size: the size of the state. // seq_lengths: the lengths of sequences in a batch. // data_type: an enum to specify the type for the underlying data. - virtual tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, dnn::DataType data_type) { - return tsl::Status(absl::StatusCode::kUnimplemented, - "createRnnSequenceTensorDescriptor is unimplemented"); + virtual absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + int data_size, DataType data_type) { + return absl::UnimplementedError( + "CreateRnnSequenceTensorDescriptor is unimplemented"); } - virtual tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, + virtual absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, - bool time_major, dnn::DataType data_type) { - return tsl::Status(absl::StatusCode::kUnimplemented, - "createRnnSequenceTensorDescriptor is unimplemented"); + bool time_major, DataType data_type) { + return absl::UnimplementedError( + "CreateRnnSequenceTensorDescriptor is unimplemented"); } // Create an RNN state descriptor that specifies the input or hidden state. // The caller retains the ownership of the returned descriptor. - virtual tsl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, - dnn::DataType data_type) { - return tsl::Status(absl::StatusCode::kUnimplemented, - "createRnnStateTensorDescriptor is unimplemented"); + virtual absl::StatusOr> + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + DataType data_type) { + return absl::UnimplementedError( + "CreateRnnStateTensorDescriptor is unimplemented"); } // Enqueue a forward operation of the RNN model onto the stream. // // Arguments: // stream: pointer to the stream where this operation should be enqueued to. - // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // rnn_desc: a RNN descriptor created by CreateRnnDescriptor. // input_desc: descriptor for the input sequence. // input_data: the device memory region that contains the input data. // input_h_desc: descriptor for the input "h" state. @@ -2235,76 +1988,76 @@ class DnnSupport { // workspace_allocator: an allocator to create temporary workspace used in // this kernel. The caller is responsible for retaining the memory long // enough for the lifespan of this operation, and recycles afterwards. - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + virtual bool DoRnnForward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, DeviceMemory* output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, DeviceMemory* output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, DeviceMemory* output_c_data, bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } // Enqueue a backward operation of the RNN model onto the stream. // // Arguments: // stream: pointer to the stream where this operation should be enqueued to. - // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // rnn_desc: a RNN descriptor created by CreateRnnDescriptor. // input_desc: descriptor for the input sequence. // input_data: the device memory region that contains the input data. // input_h_desc: descriptor for the input "h" state. @@ -2342,20 +2095,20 @@ class DnnSupport { // keeping the memory alive long enough for this operation, and recylces // afterwards. virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, + Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, + const RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, + const RnnStateTensorDescriptor& input_c_desc, const DeviceMemory& input_c_data, const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, + const RnnSequenceTensorDescriptor& output_desc, const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, + const RnnStateTensorDescriptor& output_h_desc, const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, + const RnnStateTensorDescriptor& output_c_desc, const DeviceMemory& output_c_data, const DeviceMemory& output_backprop_data, const DeviceMemory& output_h_backprop_data, @@ -2366,80 +2119,78 @@ class DnnSupport { DeviceMemory* params_backprop_data, DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + virtual bool DoRnnBackward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, + const RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, + const DeviceMemory& params, + const RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + ProfileResult* output_profile_result) { return false; } - virtual bool DoRnnBackward( - Stream* stream, const dnn::RnnDescriptor& rnn_desc, - const dnn::RnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const DeviceMemory& seq_lengths_data, - const dnn::RnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const dnn::RnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const dnn::RnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const dnn::RnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const dnn::RnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result) { + virtual bool DoRnnBackward(Stream* stream, const RnnDescriptor& rnn_desc, + const RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, + const RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, + const DeviceMemory& params, + const RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + ProfileResult* output_profile_result) { return false; } template - tsl::Status PrepareForCtcLoss(Stream* stream, - const RnnStateTensorDescriptor& probs_desc, - DeviceMemory probs_data, - const RnnStateTensorDescriptor& grads_desc, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - const NumericOptions& numeric_options, - ScratchAllocator* workspace_allocator, - DeviceMemory* scratch_memory, - int* ctc_loss_algo_id) { + absl::Status PrepareForCtcLoss(Stream* stream, + const RnnStateTensorDescriptor& probs_desc, + DeviceMemory probs_data, + const RnnStateTensorDescriptor& grads_desc, + absl::Span labels_data, + absl::Span labels_lengths_data, + absl::Span input_lengths_data, + const NumericOptions& numeric_options, + ScratchAllocator* workspace_allocator, + DeviceMemory* scratch_memory, + int* ctc_loss_algo_id) { return DoPrepareForCtcLoss( stream, ToDataType::value, probs_desc, grads_desc, labels_data, labels_lengths_data, input_lengths_data, numeric_options, @@ -2467,8 +2218,8 @@ class DnnSupport { // workspace memory used by this operation. The caller is responsible for // keeping the memory alive long enough for this operation, and recylces // afterwards. - virtual tsl::Status DoCtcLoss( - Stream* stream, dnn::DataType element_type, + virtual absl::Status DoCtcLoss( + Stream* stream, DataType element_type, const RnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -2477,14 +2228,13 @@ class DnnSupport { DeviceMemory scratch_memory, int ctc_loss_algo_id); template - bool DoCtcLoss(Stream* stream, - const dnn::RnnStateTensorDescriptor& probs_desc, + bool DoCtcLoss(Stream* stream, const RnnStateTensorDescriptor& probs_desc, const DeviceMemory& probs_data, absl::Span labels_data, absl::Span labels_lengths_data, absl::Span input_lengths_data, DeviceMemory* costs_data, - const dnn::RnnStateTensorDescriptor& grads_desc, + const RnnStateTensorDescriptor& grads_desc, DeviceMemory* grads_data, DeviceMemory* scratch_memory, int ctc_loss_algo_id) { return IsStatusOk( @@ -2507,13 +2257,10 @@ class DnnSupport { // output_type: the data type of the output tensor. // scale: an element-wise scaling factor to apply. // output_data: the device memory region that contains the output tensor. - virtual bool DoTransformTensor(Stream* stream, - const dnn::BatchDescriptor& input_desc, - dnn::DataType input_type, - const DeviceMemoryBase& input_data, - const dnn::BatchDescriptor& output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase* output_data) { + virtual bool DoTransformTensor( + Stream* stream, const BatchDescriptor& input_desc, DataType input_type, + const DeviceMemoryBase& input_data, const BatchDescriptor& output_desc, + DataType output_type, float scale, DeviceMemoryBase* output_data) { return false; } @@ -2525,10 +2272,10 @@ class DnnSupport { protected: // Returns whether status is 'ok', and potentially logs the error. - static bool IsStatusOk(const tsl::Status& status, bool report_error); + static bool IsStatusOk(const absl::Status& status, bool report_error); private: - virtual tsl::Status DoPrepareForConvolution( + virtual absl::Status DoPrepareForConvolution( ConvolutionKind kind, DataType element_type, Stream* stream, const BatchDescriptor& batch_descriptor, DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, @@ -2539,10 +2286,10 @@ class DnnSupport { DeviceMemory* scratch_memory) { *algorithm_desc = {}; *scratch_memory = {}; - return ::tsl::OkStatus(); + return absl::OkStatus(); } - virtual tsl::Status DoPrepareForCtcLoss( + virtual absl::Status DoPrepareForCtcLoss( Stream* stream, DataType element_type, const RnnStateTensorDescriptor& probs_desc, const RnnStateTensorDescriptor& grads_desc, @@ -2553,7 +2300,7 @@ class DnnSupport { ScratchAllocator* scratch_allocator, DeviceMemory* scratch_memory, int* ctc_loss_algo_id) { *scratch_memory = {}; - return ::tsl::OkStatus(); + return absl::OkStatus(); } DnnSupport(const DnnSupport&) = delete; diff --git a/xla/stream_executor/dnn_test.cc b/xla/stream_executor/dnn_test.cc index 62f0dee764674..b159b333ce4e4 100644 --- a/xla/stream_executor/dnn_test.cc +++ b/xla/stream_executor/dnn_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/event.cc b/xla/stream_executor/event.cc index 602b86ccdaa9f..6d634640d5dc2 100644 --- a/xla/stream_executor/event.cc +++ b/xla/stream_executor/event.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,10 @@ limitations under the License. #include "xla/stream_executor/event.h" +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" @@ -52,7 +56,7 @@ Event::Status Event::PollForStatus() { return stream_exec_->PollForEventStatus(this); } -tsl::Status Event::WaitForEventOnExternalStream(std::intptr_t stream) { +absl::Status Event::WaitForEventOnExternalStream(std::intptr_t stream) { return stream_exec_->WaitForEventOnExternalStream(stream, this); } diff --git a/xla/stream_executor/event.h b/xla/stream_executor/event.h index 8721c9d53d559..3be6b53a406a1 100644 --- a/xla/stream_executor/event.h +++ b/xla/stream_executor/event.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_EVENT_H_ #define XLA_STREAM_EXECUTOR_EVENT_H_ +#include #include -#include "xla/stream_executor/platform/port.h" -#include "tsl/platform/status.h" +#include "absl/status/status.h" namespace stream_executor { @@ -32,7 +32,7 @@ class StreamExecutor; // The Event class, when supported by a platform, enables low-overhead status // reporting for a Stream. An Event is inserted at a location in a stream via -// the Stream::ThenRecordEvent() API. From then on, the Event's status can be +// the Stream::RecordEvent() API. From then on, the Event's status can be // monitored via the nonblocking Event::PollForStatus() call. class Event { public: @@ -61,7 +61,7 @@ class Event { // Blocks `stream` on this event. `stream` is a raw platform-specific // stream (e.g. GpuStreamHandle). - tsl::Status WaitForEventOnExternalStream(std::intptr_t stream); + absl::Status WaitForEventOnExternalStream(std::intptr_t stream); // Returns a pointer to the underlying platform-specific implementation. internal::EventInterface* implementation() { return implementation_.get(); } diff --git a/xla/stream_executor/executor_cache.cc b/xla/stream_executor/executor_cache.cc index 994a3e4417935..eae72060f0c04 100644 --- a/xla/stream_executor/executor_cache.cc +++ b/xla/stream_executor/executor_cache.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,20 +18,20 @@ limitations under the License. #include #include -#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/logging.h" namespace stream_executor { ExecutorCache::ExecutorCache() = default; ExecutorCache::~ExecutorCache() { DestroyAllExecutors(); } -tsl::StatusOr ExecutorCache::GetOrCreate( +absl::StatusOr ExecutorCache::GetOrCreate( const StreamExecutorConfig& config, const ExecutorFactory& factory) { // In the fast path case, the cache already has an entry and we can just // return after Get() which only takes a shared lock and not a unique lock. @@ -53,14 +53,12 @@ tsl::StatusOr ExecutorCache::GetOrCreate( // initialization of different entries. absl::MutexLock lock{&entry->configurations_mutex}; for (const auto& iter : entry->configurations) { - if (iter.first.device_options == config.device_options) { - VLOG(2) << "hit in cache"; - return iter.second.get(); - } + VLOG(2) << "hit in cache"; + return iter.second.get(); } VLOG(2) << "building executor"; - tsl::StatusOr> result = factory(); + absl::StatusOr> result = factory(); if (!result.ok()) { VLOG(2) << "failed to get build executor: " << result.status(); // If construction failed, leave the cache Entry around, but with a null @@ -71,7 +69,7 @@ tsl::StatusOr ExecutorCache::GetOrCreate( return entry->configurations.back().second.get(); } -tsl::StatusOr ExecutorCache::Get( +absl::StatusOr ExecutorCache::Get( const StreamExecutorConfig& config) { Entry* entry = nullptr; { @@ -107,10 +105,7 @@ tsl::StatusOr ExecutorCache::Get( } for (auto& [entry_config, entry_executor] : entry->configurations) { - if (entry_config.device_options == config.device_options) { - VLOG(2) << "hit in cache for device ordinal " << config.ordinal; - return entry_executor.get(); - } + return entry_executor.get(); } return absl::NotFoundError("No executor found with a matching config."); diff --git a/xla/stream_executor/executor_cache.h b/xla/stream_executor/executor_cache.h index 6e9668afdbf39..6e7f32e487cd1 100644 --- a/xla/stream_executor/executor_cache.h +++ b/xla/stream_executor/executor_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,10 +23,9 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -38,7 +37,7 @@ class StreamExecutor; class ExecutorCache { public: using ExecutorFactory = - std::function>()>; + std::function>()>; ExecutorCache(); ~ExecutorCache(); @@ -46,12 +45,12 @@ class ExecutorCache { // Looks up 'config' in the cache. Returns a pointer to the existing executor, // if already present, or creates it using 'factory', if it does not. // Factories may be executed concurrently for different device ordinals. - tsl::StatusOr GetOrCreate(const StreamExecutorConfig& config, - const ExecutorFactory& factory); + absl::StatusOr GetOrCreate( + const StreamExecutorConfig& config, const ExecutorFactory& factory); // Returns a pointer to the described executor (if one with a matching config // has been created), or a NOT_FOUND status. - tsl::StatusOr Get(const StreamExecutorConfig& config); + absl::StatusOr Get(const StreamExecutorConfig& config); // Destroys all Executors and clears the cache. // Performs no synchronization with the executors - undefined behavior may diff --git a/xla/stream_executor/fft.h b/xla/stream_executor/fft.h index 88a383b84cb1e..d88834c72e074 100644 --- a/xla/stream_executor/fft.h +++ b/xla/stream_executor/fft.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_FFT_H_ #include +#include #include #include "xla/stream_executor/platform/port.h" @@ -90,58 +91,6 @@ class FftSupport { public: virtual ~FftSupport() {} - // Creates a 1d FFT plan. - virtual std::unique_ptr Create1dPlan(Stream *stream, uint64_t num_x, - Type type, bool in_place_fft) = 0; - - // Creates a 2d FFT plan. - virtual std::unique_ptr Create2dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, Type type, - bool in_place_fft) = 0; - - // Creates a 3d FFT plan. - virtual std::unique_ptr Create3dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, uint64 num_z, - Type type, bool in_place_fft) = 0; - - // Creates a 1d FFT plan with scratch allocator. - virtual std::unique_ptr Create1dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, Type type, bool in_place_fft, - ScratchAllocator *scratch_allocator) = 0; - - // Creates a 2d FFT plan with scratch allocator. - virtual std::unique_ptr Create2dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) = 0; - - // Creates a 3d FFT plan with scratch allocator. - virtual std::unique_ptr Create3dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) = 0; - - // Creates a batched FFT plan. - // - // stream: The GPU stream in which the FFT runs. - // rank: Dimensionality of the transform (1, 2, or 3). - // elem_count: Array of size rank, describing the size of each dimension. - // input_embed, output_embed: - // Pointer of size rank that indicates the storage dimensions - // of the input/output data in memory. If set to null_ptr all - // other advanced data layout parameters are ignored. - // input_stride: Indicates the distance (number of elements; same below) - // between two successive input elements. - // input_distance: Indicates the distance between the first element of two - // consecutive signals in a batch of the input data. - // output_stride: Indicates the distance between two successive output - // elements. - // output_distance: Indicates the distance between the first element of two - // consecutive signals in a batch of the output data. - virtual std::unique_ptr CreateBatchedPlan( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, Type type, - bool in_place_fft, int batch_count) = 0; - // Creates a batched FFT plan with scratch allocator. // // stream: The GPU stream in which the FFT runs. @@ -212,30 +161,6 @@ class FftSupport { // fft::FftSupport base class. Assumes that it's emitted somewhere inside the // ::stream_executor namespace. #define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ - std::unique_ptr Create1dPlan(Stream *stream, uint64_t num_x, \ - fft::Type type, bool in_place_fft) \ - override; \ - std::unique_ptr Create2dPlan(Stream *stream, uint64_t num_x, \ - uint64_t num_y, fft::Type type, \ - bool in_place_fft) override; \ - std::unique_ptr Create3dPlan( \ - Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, \ - fft::Type type, bool in_place_fft) override; \ - std::unique_ptr Create1dPlanWithScratchAllocator( \ - Stream *stream, uint64_t num_x, fft::Type type, bool in_place_fft, \ - ScratchAllocator *scratch_allocator) override; \ - std::unique_ptr Create2dPlanWithScratchAllocator( \ - Stream *stream, uint64_t num_x, uint64 num_y, fft::Type type, \ - bool in_place_fft, ScratchAllocator *scratch_allocator) override; \ - std::unique_ptr Create3dPlanWithScratchAllocator( \ - Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, \ - fft::Type type, bool in_place_fft, ScratchAllocator *scratch_allocator) \ - override; \ - std::unique_ptr CreateBatchedPlan( \ - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, \ - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, \ - uint64_t output_stride, uint64 output_distance, fft::Type type, \ - bool in_place_fft, int batch_count) override; \ std::unique_ptr CreateBatchedPlanWithScratchAllocator( \ Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, \ uint64_t input_stride, uint64 input_distance, uint64 *output_embed, \ diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 7fc898b32e2a4..b637cd42b04cf 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -2,23 +2,25 @@ # GPU-platform specific StreamExecutor support code. load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", + "@local_config_cuda//cuda:build_defs.bzl", + "if_cuda", ) load( "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm", "if_rocm_is_configured", ) load( "@tsl//tsl:tsl.bzl", "if_libtpu", - "set_external_visibility", + "internal_visibility", "tsl_copts", "tsl_gpu_library", ) load( "@tsl//tsl/platform:build_config_root.bzl", "if_static", + "tf_gpu_tests_tags", ) load( "@tsl//tsl/platform:rules_cc.bzl", @@ -28,10 +30,27 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load( + "//xla/service/gpu:build_defs.bzl", + "gpu_kernel_library", +) +load( + "//xla/stream_executor:build_defs.bzl", + "gpu_only_cc_library", + "if_gpu_is_configured", +) +load( + "//xla/tests:build_defs.bzl", + "xla_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//tensorflow/compiler/tf2xla:__subpackages__", "//xla:__subpackages__", "//tensorflow/core/kernels:__subpackages__", @@ -47,36 +66,39 @@ package( cc_library( name = "gpu_activation_header", hdrs = ["gpu_activation.h"], - deps = ["//xla/stream_executor/platform"], ) -cc_library( +gpu_only_cc_library( name = "gpu_activation", - srcs = if_gpu_is_configured(["gpu_activation.cc"]), - hdrs = if_gpu_is_configured(["gpu_activation.h"]), - deps = if_gpu_is_configured([ - ":gpu_executor_header", + srcs = ["gpu_activation.cc"], + hdrs = ["gpu_activation.h"], + deps = [ ":gpu_activation_header", ":gpu_driver_header", + ":gpu_executor_header", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/platform", - ]), + ], ) -cc_library( +gpu_only_cc_library( name = "gpu_diagnostics_header", - hdrs = if_gpu_is_configured(["gpu_diagnostics.h"]), + hdrs = ["gpu_diagnostics.h"], + deps = ["@com_google_absl//absl/status:statusor"], +) + +gpu_only_cc_library( + name = "gpu_collectives_header", + hdrs = ["gpu_collectives.h"], deps = [ - "//xla/stream_executor/platform", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_driver_header", - hdrs = if_gpu_is_configured(["gpu_driver.h"]), - visibility = set_external_visibility([ + hdrs = ["gpu_driver.h"], + visibility = internal_visibility([ "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", "//tensorflow/core/common_runtime/gpu:__subpackages__", @@ -84,8 +106,10 @@ cc_library( ]), deps = [ ":gpu_types_header", - "//xla/stream_executor:device_options", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -95,39 +119,54 @@ cc_library( ), ) -cc_library( +gpu_only_cc_library( name = "gpu_runtime_header", - hdrs = if_gpu_is_configured(["gpu_runtime.h"]), - visibility = set_external_visibility([ + hdrs = ["gpu_runtime.h"], + visibility = internal_visibility([ + "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", ]), deps = [ ":gpu_types_header", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", ], ) -cc_library( +gpu_only_cc_library( + name = "gpu_kernels", + hdrs = ["gpu_kernels.h"], +) + +gpu_only_cc_library( name = "gpu_command_buffer", - srcs = if_gpu_is_configured(["gpu_command_buffer.cc"]), - hdrs = if_gpu_is_configured(["gpu_command_buffer.h"]), + srcs = ["gpu_command_buffer.cc"], + hdrs = ["gpu_command_buffer.h"], + local_defines = if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":gpu_driver_header", ":gpu_executor_header", ":gpu_kernel_header", + ":gpu_kernels", ":gpu_stream", ":gpu_types_header", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:path", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ @@ -137,55 +176,56 @@ cc_library( ]), ) -cc_library( +gpu_only_cc_library( name = "gpu_event_header", - hdrs = if_gpu_is_configured(["gpu_event.h"]), - deps = if_gpu_is_configured([ - ":gpu_driver_header", + hdrs = ["gpu_event.h"], + deps = [ ":gpu_stream_header", - "//xla/stream_executor:stream_executor_headers", - "@tsl//tsl/platform:status", - ]), + ":gpu_types_header", + "//xla/stream_executor", + "@com_google_absl//absl/status", + ], ) -cc_library( +gpu_only_cc_library( name = "gpu_event", - srcs = if_gpu_is_configured(["gpu_event.cc"]), - hdrs = if_gpu_is_configured(["gpu_event.h"]), + srcs = ["gpu_event.cc"], + hdrs = ["gpu_event.h"], deps = [ ":gpu_driver_header", ":gpu_executor_header", ":gpu_stream", - "//xla/stream_executor:stream_executor_headers", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", + ":gpu_types_header", + "//xla/stream_executor", + "@com_google_absl//absl/status", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_executor_header", - hdrs = if_gpu_is_configured(["gpu_executor.h"]), + hdrs = ["gpu_executor.h"], deps = [ - ":gpu_kernel_header", + ":gpu_collectives_header", + ":gpu_driver_header", ":gpu_types_header", + "//xla/stream_executor", "//xla/stream_executor:platform", - "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:fingerprint", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:thread_annotations", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_helpers_header", - hdrs = if_gpu_is_configured(["gpu_helpers.h"]), + hdrs = ["gpu_helpers.h"], deps = [ ":gpu_types_header", "@tsl//tsl/platform:logging", @@ -197,10 +237,11 @@ tsl_gpu_library( hdrs = [ "gpu_init.h", ], - visibility = set_external_visibility([ + visibility = internal_visibility([ "@tsl//tsl:internal", ]), deps = [ + "@com_google_absl//absl/status", "@tsl//tsl/platform:status", ] + if_static( [":gpu_init_impl"], @@ -217,93 +258,130 @@ tsl_gpu_library( ], copts = tsl_copts(), linkstatic = True, - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/compiler/tf2xla:__subpackages__", "//xla:__subpackages__", "//tensorflow/core/common_runtime/gpu:__subpackages__", "//tensorflow/stream_executor:__subpackages__", ]), deps = [ - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], alwayslink = True, ) -cc_library( +gpu_only_cc_library( name = "gpu_kernel_header", - hdrs = if_gpu_is_configured(["gpu_kernel.h"]), + hdrs = ["gpu_kernel.h"], deps = [ ":gpu_driver_header", - "//xla/stream_executor:stream_executor_headers", + ":gpu_executor_header", + ":gpu_types_header", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", - "//xla/stream_executor/platform", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_stream_header", - hdrs = if_gpu_is_configured(["gpu_stream.h"]), + hdrs = ["gpu_stream.h"], deps = [ ":gpu_types_header", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", - "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_stream", - srcs = if_gpu_is_configured(["gpu_stream.cc"]), - hdrs = if_gpu_is_configured(["gpu_stream.h"]), + srcs = ["gpu_stream.cc"], + hdrs = ["gpu_stream.h"], deps = [ + ":gpu_driver_header", ":gpu_executor_header", ":gpu_types_header", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", - "@tsl//tsl/platform:status", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", ], ) -cc_library( +gpu_only_cc_library( + name = "gpu_timer_kernel_header", + hdrs = ["gpu_timer_kernel.h"], +) + +gpu_kernel_library( + name = "gpu_timer_kernel", + srcs = if_gpu_is_configured(["gpu_timer_kernel.cu.cc"]), + deps = [ + ":gpu_timer_kernel_header", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + +gpu_only_cc_library( name = "gpu_timer_header", - hdrs = if_gpu_is_configured(["gpu_timer.h"]), + hdrs = ["gpu_timer.h"], deps = [ - ":gpu_driver_header", ":gpu_executor_header", - "//xla/stream_executor:stream_executor_headers", - "//xla/stream_executor:stream_executor_internal", + ":gpu_timer_kernel_header", + ":gpu_types_header", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", ], ) -cc_library( +gpu_only_cc_library( name = "gpu_timer", - srcs = if_gpu_is_configured(["gpu_timer.cc"]), - hdrs = if_gpu_is_configured(["gpu_timer.h"]), + srcs = ["gpu_timer.cc"], + hdrs = ["gpu_timer.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":gpu_driver_header", ":gpu_executor_header", ":gpu_stream", - "//xla/stream_executor:stream_executor_headers", + ":gpu_timer_kernel_header", + ":gpu_types_header", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/utility", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", - ] + if_cuda_is_configured([ + "@tsl//tsl/platform:statusor", + ] + if_gpu_is_configured([ + ":gpu_timer_kernel", + ]) + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:rocm_driver", ]), ) -cc_library( +gpu_only_cc_library( name = "gpu_types_header", - hdrs = if_gpu_is_configured(["gpu_types.h"]), + hdrs = ["gpu_types.h"], deps = [ "//xla/stream_executor/platform", ] + if_cuda_is_configured([ @@ -316,7 +394,7 @@ cc_library( cc_library( name = "gpu_asm_opts", hdrs = ["gpu_asm_opts.h"], - visibility = set_external_visibility([ + visibility = internal_visibility([ "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", "//tensorflow/core/kernels:__subpackages__", @@ -327,133 +405,199 @@ cc_library( ], ) -cc_library( +gpu_only_cc_library( name = "asm_compiler_header", - hdrs = if_gpu_is_configured(["asm_compiler.h"]), + hdrs = ["asm_compiler.h"], copts = tsl_copts(), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", "//tensorflow/core/kernels:__subpackages__", ]), - deps = if_gpu_is_configured([ + deps = [ ":gpu_asm_opts", ":gpu_driver_header", ":gpu_helpers_header", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/base:core_headers", - "@tsl//tsl/platform:regexp", - "@tsl//tsl/platform:mutex", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:subprocess", - "@tsl//tsl/platform:path", - "@tsl//tsl/platform:cuda_libdevice_path", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor/platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - ]) + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:cuda_libdevice_path", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:regexp", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:subprocess", + ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:rocm_driver", - ]) + ["@tsl//tsl/platform:statusor"], + ]), ) -cc_library( +gpu_only_cc_library( name = "asm_compiler", - srcs = if_gpu_is_configured(["asm_compiler.cc"]), - hdrs = if_gpu_is_configured(["asm_compiler.h"]), + srcs = ["asm_compiler.cc"], + hdrs = ["asm_compiler.h"], copts = tsl_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", "//tensorflow/core/kernels:__subpackages__", ]), - deps = if_gpu_is_configured([ + deps = [ ":gpu_asm_opts", ":gpu_driver_header", ":gpu_helpers_header", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/base:core_headers", - "@tsl//tsl/platform:regexp", - "@tsl//tsl/platform:mutex", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:subprocess", - "@tsl//tsl/platform:path", - "@tsl//tsl/platform:cuda_libdevice_path", - "//xla/stream_executor:stream_executor_headers", - "//xla/stream_executor/platform", + ":gpu_types_header", "//xla:util", + "//xla/stream_executor", + "//xla/stream_executor/cuda:ptx_compiler", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/platform", + "//xla/tsl/util:env_var", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_asm_compiler", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:cuda_libdevice_path", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:regexp", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:subprocess", + ] + if_cuda_is_configured([ + "//xla/stream_executor/cuda:cuda_asm_compiler_legacy", "//xla/stream_executor/cuda:cuda_driver", "//xla/stream_executor/cuda:ptxas_wrapper", "//xla/stream_executor/cuda:nvlink_wrapper", "//xla/stream_executor/cuda:fatbinary_wrapper", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:rocm_driver", - ]) + ["@tsl//tsl/platform:statusor"], + ]), ) -cc_library( - name = "redzone_allocator", - srcs = if_gpu_is_configured(["redzone_allocator.cc"]), +gpu_kernel_library( + name = "redzone_allocator_kernel", + srcs = if_gpu_is_configured(["redzone_allocator.cu.cc"]), hdrs = if_gpu_is_configured(["redzone_allocator.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = if_gpu_is_configured([ + ":gpu_asm_opts", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ]) + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + +gpu_only_cc_library( + name = "redzone_allocator", + srcs = ["redzone_allocator.cc"], + hdrs = ["redzone_allocator.h"], copts = tsl_copts(), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - visibility = set_external_visibility([ + visibility = internal_visibility([ "//xla/service/gpu:__subpackages__", "//xla/stream_executor:__subpackages__", "//tensorflow/core/kernels:__subpackages__", ]), - deps = if_gpu_is_configured([ + deps = [ ":asm_compiler", ":gpu_asm_opts", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/base", "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@tsl//tsl/framework:allocator", "@tsl//tsl/lib/math:math_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/framework:allocator", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:stream_executor_headers", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + ] + if_rocm_is_configured([ + ":redzone_allocator_kernel", ]), ) +xla_cc_test( + name = "redzone_allocator_test", + srcs = if_gpu_is_configured(["redzone_allocator_test.cc"]), + tags = tf_gpu_tests_tags() + [ + "no_cuda_asan", # TODO(b/171512140): re-enable. + ], + deps = [ + ":gpu_asm_opts", + ":gpu_init", + ":redzone_allocator", + "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + # TODO(tlongeri): Remove gpu_cudamallocasync_allocator header/impl split tsl_gpu_library( name = "gpu_cudamallocasync_allocator_header", hdrs = ["gpu_cudamallocasync_allocator.h"], deps = [ - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:device_id", - "@tsl//tsl/platform:macros", "@tsl//tsl/platform:mutex", ], ) @@ -470,48 +614,18 @@ tsl_gpu_library( ], deps = [ ":gpu_init_impl", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_headers", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:device_id", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:macros", "@tsl//tsl/platform:mutex", - "@tsl//tsl/util:env_var", ], ) -cc_library( - name = "gpu_graph", - srcs = if_gpu_is_configured(["gpu_graph.cc"]), - hdrs = if_gpu_is_configured(["gpu_graph.h"]), - deps = if_gpu_is_configured([ - ":gpu_driver_header", - ":gpu_kernel_header", - ":gpu_types_header", - ":gpu_executor_header", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/functional:any_invocable", - "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:path", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ]) + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor/cuda:cuda_driver", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_driver", - ]) + ["@com_google_absl//absl/strings"], -) - cc_library( name = "gpu_blas_lt", srcs = ["gpu_blas_lt.cc"], @@ -527,14 +641,134 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/service:algorithm_util", + "//xla/stream_executor", "//xla/stream_executor:host_or_device_scalar", - "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", + "@tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@tsl//tsl/platform:tensor_float_32_hdr_lib", ]) + if_static([ "@tsl//tsl/platform:tensor_float_32_utils", ]), ) + +gpu_kernel_library( + name = "gpu_test_kernels", + testonly = 1, + srcs = if_gpu_is_configured(["gpu_test_kernels.cu.cc"]), + hdrs = if_gpu_is_configured(["gpu_test_kernels.h"]), + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + "//xla/stream_executor/rocm:add_i32_kernel", + ]), +) + +xla_test( + name = "gpu_kernel_test", + srcs = if_gpu_is_configured(["gpu_kernel_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deps = [ + ":gpu_test_kernels", + "//xla/service:platform_util", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda([ + "//xla/stream_executor/cuda:cuda_platform", + ]) + if_rocm([ + "//xla/stream_executor/rocm:rocm_platform", + ]), +) + +xla_test( + name = "gpu_command_buffer_test", + srcs = if_gpu_is_configured(["gpu_command_buffer_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deps = [ + ":gpu_command_buffer", + ":gpu_test_kernels", + ":gpu_types_header", + "//xla/service:platform_util", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_driver_header", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ] + if_cuda([ + "//xla/stream_executor/cuda:cuda_platform", + ]) + if_rocm([ + "//xla/stream_executor/rocm:rocm_platform", + ]), +) + +xla_cc_test( + name = "memcpy_test", + srcs = ["memcpy_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags(), + deps = [ + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform_manager", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda([ + "//xla/stream_executor/cuda:cuda_platform", + ]) + if_rocm([ + "//xla/stream_executor/rocm:rocm_platform", + ]), +) + +xla_cc_test( + name = "stream_search_test", + size = "small", + srcs = ["stream_search_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = tf_gpu_tests_tags() + [ + "requires-gpu-nvidia", + ], + deps = [ + "//xla/stream_executor", + "//xla/stream_executor/host:host_platform", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_cuda([ + "//xla/stream_executor/cuda:cuda_platform", + ]) + if_rocm([ + "//xla/stream_executor/rocm:rocm_platform", + ]), +) diff --git a/xla/stream_executor/gpu/asm_compiler.cc b/xla/stream_executor/gpu/asm_compiler.cc index 4cdd13f44f75d..cd88402de05db 100644 --- a/xla/stream_executor/gpu/asm_compiler.cc +++ b/xla/stream_executor/gpu/asm_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,46 +16,62 @@ limitations under the License. #include "xla/stream_executor/gpu/asm_compiler.h" #include +#include #include +#include +#include #include +#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/base/const_init.h" #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/stream_executor/cuda/ptx_compiler.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/util.h" #include "tsl/platform/cuda_libdevice_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/regexp.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/subprocess.h" namespace stream_executor { -static tsl::StatusOr GetToolVersionString( - absl::string_view binary_path) { - static absl::Mutex mu(absl::kConstInit); - static auto* seen_binary_paths ABSL_GUARDED_BY(mu) = - new absl::flat_hash_map(); - - absl::MutexLock lock(&mu); - auto it = seen_binary_paths->find(binary_path); - if (it != seen_binary_paths->end()) { - // Already checked this binary, nothing to do. - return absl::string_view(it->second); +static absl::StatusOr GetToolVersionString( + std::string_view binary_path) { + // If binary_path doesn't exist, then tsl::SubProcess will log a bunch of + // error messages that have confused users in the past. Therefore we first + // check whether the binary_path exists and error out early if not. + tsl::Env* env = tsl::Env::Default(); + if (absl::Status file_exists = env->FileExists(std::string{binary_path}); + !file_exists.ok()) { + return file_exists; } tsl::SubProcess binary; @@ -63,7 +79,7 @@ static tsl::StatusOr GetToolVersionString( binary.SetProgram(binary_path_str, {binary_path_str, "--version"}); binary.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE); if (!binary.Start()) { - return tsl::errors::Internal( + return absl::InternalError( absl::StrFormat("Couldn't invoke %s --version", binary_path)); } @@ -71,62 +87,58 @@ static tsl::StatusOr GetToolVersionString( int exit_code = binary.Communicate(/*stdin_input=*/nullptr, &out, /*stderr_output=*/nullptr); if (exit_code != 0) { - return tsl::errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Running %s --version returned %d", binary_path, exit_code)); } - auto emplace_it = seen_binary_paths->emplace(binary_path, std::move(out)); - return absl::string_view(emplace_it.first->second); + + return out; } -tsl::StatusOr> GetToolVersion( - absl::string_view tool_path) { - tsl::StatusOr tool_version = - GetToolVersionString(tool_path); +static absl::StatusOr GetToolVersionImpl( + std::string_view tool_path) { + absl::StatusOr tool_version = GetToolVersionString(tool_path); if (!tool_version.ok()) { - return tsl::errors::FailedPrecondition( - "Couldn't get ptxas/nvlink version string: ", tool_version.status()); + return absl::FailedPreconditionError( + absl::StrCat("Couldn't get ptxas/nvlink version string: ", + tool_version.status().ToString())); } static constexpr LazyRE2 kVersionRegex = {R"(\bV(\d+)\.(\d+)\.(\d+)\b)"}; - std::array version; - absl::string_view vmaj_str, vmin_str, vdot_str; + ToolVersion version{}; + std::string_view vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(tool_version.value(), *kVersionRegex, &vmaj_str, &vmin_str, &vdot_str) || !absl::SimpleAtoi(vmaj_str, &version[0]) || !absl::SimpleAtoi(vmin_str, &version[1]) || !absl::SimpleAtoi(vdot_str, &version[2])) { - return tsl::errors::FailedPrecondition( - "Couldn't parse ptxas/nvlink version in output of ", tool_path, - " --version:\n", tool_version.value()); + return absl::FailedPreconditionError( + absl::StrCat("Couldn't parse ptxas/nvlink version in output of ", + tool_path, " --version:\n", tool_version.value())); } return version; } -// Prints a warning if the ptxas at ptxas_path has known bugs. -// -// Only prints a warning the first time it's called for a particular value of -// ptxas_path. -// -// Locks on entry.˝ -static void WarnIfBadPtxasVersion(absl::string_view ptxas_path) { - tsl::StatusOr> version = GetToolVersion(ptxas_path); - if (!version.ok()) { - LOG(WARNING) << "Couldn't get ptxas version : " << version.status(); - return; - } - - if (std::make_tuple((*version)[0], (*version)[1]) < std::make_tuple(11, 1)) { - LOG(ERROR) << "*** WARNING *** You are using ptxas " << (*version)[0] << "." - << (*version)[1] << "." << (*version)[2] - << ", which is older than 11.1. ptxas before 11.1 is known to " - "miscompile XLA code, leading to incorrect results or " - "invalid-address errors.\n"; +absl::StatusOr GetToolVersion(std::string_view tool_path) { + // This is only implementing a static cache. `GetToolVersionImpl` has the + // actual business logic. + static absl::Mutex mutex(absl::kConstInit); + static auto cache = + new absl::flat_hash_map> + ABSL_GUARDED_BY(mutex); + + absl::MutexLock lock(&mutex); + auto it = cache->find(tool_path); + if (it != cache->end()) { + return it->second; } + + return cache->try_emplace(tool_path, GetToolVersionImpl(tool_path)) + .first->second; } -tsl::StatusOr> CompileGpuAsmOrGetCached( +absl::StatusOr> CompileGpuAsmOrGetCached( int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) { using PtxCacheKey = std::tuple; - using PtxCompilerResult = tsl::StatusOr>; + using PtxCompilerResult = absl::StatusOr>; static absl::Mutex ptx_cache_mutex(absl::kConstInit); static auto& ptx_cache ABSL_GUARDED_BY(ptx_cache_mutex) = *new absl::flat_hash_map(); @@ -155,9 +167,9 @@ tsl::StatusOr> CompileGpuAsmOrGetCached( return absl::MakeSpan(compiled); } -tsl::StatusOr> CompileGpuAsm(int device_ordinal, - const char* ptx_contents, - GpuAsmOpts options) { +absl::StatusOr> CompileGpuAsm(int device_ordinal, + const char* ptx_contents, + GpuAsmOpts options) { gpu::GpuDeviceHandle handle; TF_RETURN_IF_ERROR(gpu::GpuDriver::GetDevice(device_ordinal, &handle)); int cc_major; @@ -167,57 +179,75 @@ tsl::StatusOr> CompileGpuAsm(int device_ordinal, return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options); } -std::string FindCudaExecutable(const std::string& binary_name, - const std::string& preferred_cuda_dir) { - static absl::Mutex mu(absl::kConstInit); - static auto* seen_binary_paths ABSL_GUARDED_BY(mu) = - new absl::flat_hash_map, - std::string>(); +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir, + ToolVersion minimum_version, + absl::Span excluded_versions) { + std::string binary_filename = std::string{binary_name}; + tsl::io::AppendDotExeIfWindows(binary_filename); + + std::vector candidates{}; + + // #1 - Check the preferred CUDA directory + candidates.emplace_back( + tsl::io::JoinPath(preferred_cuda_dir, "bin", binary_filename)); + + // #2 - Check the PATH environment variable + std::string_view path_env = std::getenv("PATH"); #if defined(PLATFORM_WINDOWS) - const std::string binary_filename = binary_name + ".exe"; + constexpr char kSearchPathSeparator = ';'; #else - const std::string& binary_filename = binary_name; + constexpr char kSearchPathSeparator = ':'; #endif - auto cache_key = std::make_pair(binary_name, preferred_cuda_dir); - - absl::MutexLock lock(&mu); - auto it = seen_binary_paths->find(cache_key); - if (it != seen_binary_paths->end()) { - return it->second; + for (std::string_view path : absl::StrSplit(path_env, kSearchPathSeparator)) { + candidates.emplace_back(tsl::io::JoinPath(path, binary_filename)); } - // Try searching in the default PATH first if applicable. - if (tsl::PreferPtxasFromPath() && - GetToolVersionString(binary_filename).ok()) { - VLOG(2) << "Using " << binary_filename; - seen_binary_paths->emplace(std::move(cache_key), binary_filename); - return binary_filename; + // #3 - Check generic CUDA locations + for (std::string_view path : tsl::CandidateCudaRoots()) { + candidates.emplace_back(tsl::io::JoinPath(path, "bin", binary_filename)); } - // Search in cuda root candidates. - auto env = tsl::Env::Default(); - std::string binary_path; - for (const std::string& cuda_root : - tsl::CandidateCudaRoots(preferred_cuda_dir)) { - binary_path = tsl::io::JoinPath(cuda_root, "bin", binary_filename); - VLOG(2) << "Looking for " << binary_filename << " at " << binary_path; - if (env->FileExists(binary_path).ok() && - GetToolVersionString(binary_path).ok()) { - break; + for (const auto& candidate : candidates) { + VLOG(2) << "Looking for " << candidate; + auto candidate_version = GetToolVersion(candidate); + if (!candidate_version.ok()) { + continue; + } + + if (candidate_version.value() < minimum_version) { + VLOG(2) << candidate << " with version " + << absl::StrJoin(minimum_version, ".") << " is too old."; + continue; + } + + if (absl::c_find(excluded_versions, candidate_version.value()) != + excluded_versions.end()) { + VLOG(2) << candidate << " has version " + << absl::StrJoin(candidate_version.value(), ".") + << " which was explicitly excluded."; + continue; } + + VLOG(2) << "Using " << candidate << " with version " + << absl::StrJoin(candidate_version.value(), "."); + return candidate; } - if (!env->FileExists(binary_path).ok()) { - // Give up and just rely on subprocess invocation to find the correct - // binary. This won't work, in all probability, given we already tried that - // above, but it's the best we can do. - VLOG(2) << "Unable to find " << binary_name; - binary_path = binary_filename; - } - VLOG(2) << "Using " << binary_filename << " at " << binary_path; - seen_binary_paths->emplace(std::move(cache_key), binary_path); - return binary_path; + + return absl::NotFoundError( + absl::StrCat("Couldn't find a suitable version of ", binary_name, + ". The following locations were considered: ", + absl::StrJoin(candidates, ", "))); +} + +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kNoMinimumVersion{0, 0, 0}; + static constexpr absl::Span kNoExcludedVersions{}; + return FindCudaExecutable(binary_name, preferred_cuda_dir, kNoMinimumVersion, + kNoExcludedVersions); } static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major, @@ -248,28 +278,38 @@ static void AppendArgsFromOptions(GpuAsmOpts options, options.extra_flags.end()); } -tsl::StatusOr> GetAsmCompilerVersion( - const std::string& preferred_cuda_dir) { - std::string ptxas_path = FindCudaExecutable("ptxas", preferred_cuda_dir); - return GetToolVersion(ptxas_path); +static absl::StatusOr FindPtxAsExecutable( + std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kMinimumSupportedPtxAsVersion{11, 8, 0}; + static constexpr ToolVersion kBuggyPtxAsVersions[] = {{12, 3, 103}}; + static constexpr std::string_view kPtxAsBinaryName = "ptxas"; + + return FindCudaExecutable(kPtxAsBinaryName, preferred_cuda_dir, + kMinimumSupportedPtxAsVersion, kBuggyPtxAsVersions); } -tsl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, - const char* ptx_contents, - GpuAsmOpts options, - bool cancel_if_reg_spill) { - std::string ptxas_path = - FindCudaExecutable("ptxas", options.preferred_cuda_dir); +absl::StatusOr GetAsmCompilerVersion( + std::string_view preferred_cuda_dir) { + TF_ASSIGN_OR_RETURN(std::string ptxas_path, + FindPtxAsExecutable(preferred_cuda_dir)); + return GetToolVersion(ptxas_path); +} - WarnIfBadPtxasVersion(ptxas_path); +absl::StatusOr> CompileGpuAsmUsingPtxAs( + int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, + bool cancel_if_reg_spill) { + TF_ASSIGN_OR_RETURN(std::string ptxas_path, + FindPtxAsExecutable(options.preferred_cuda_dir)); // Write ptx into a temporary file. std::string ptx_path; auto env = tsl::Env::Default(); if (!env->LocalTempFilename(&ptx_path)) { - return tsl::errors::Internal("couldn't get temp PTX file name"); + return absl::InternalError("couldn't get temp PTX file name"); } - TF_RETURN_IF_ERROR(tsl::WriteStringToFile(env, ptx_path, ptx_contents)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tsl::WriteStringToFile(env, ptx_path, ptx_contents), + "Unable to write PTX contents to: ", ptx_path); VLOG(2) << "ptx written to: " << ptx_path; absl::Cleanup ptx_cleaner = [&ptx_path] { @@ -279,7 +319,7 @@ tsl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, // Invoke ptxas and collect its output. std::string cubin_path; if (!env->LocalTempFilename(&cubin_path)) { - return tsl::errors::Internal("couldn't get temp CUBIN file name"); + return absl::InternalError("couldn't get temp CUBIN file name"); } absl::Cleanup cubin_cleaner = [&cubin_path] { // CUBIN file may never be created, so the failure to delete it should not @@ -307,8 +347,9 @@ tsl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); + VLOG(5)<<"subprocess running:"<> CompileGpuAsm(int cc_major, int cc_minor, return absl::ResourceExhaustedError("Register allocation failed"); } - return tsl::errors::Internal( + return absl::InternalError( absl::StrFormat("ptxas exited with non-zero error code %d, output: %s", exit_status, stderr_output)); } @@ -358,10 +399,11 @@ tsl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, return cubin_vector; } -tsl::StatusOr> BundleGpuAsm( +absl::StatusOr> BundleGpuAsm( std::vector images, GpuAsmOpts options) { - std::string fatbinary_path = - FindCudaExecutable("fatbinary", options.preferred_cuda_dir); + TF_ASSIGN_OR_RETURN( + std::string fatbinary_path, + FindCudaExecutable("fatbinary", options.preferred_cuda_dir)); // Write images to temporary files. std::vector image_paths; @@ -369,7 +411,7 @@ tsl::StatusOr> BundleGpuAsm( for (const CubinOrPTXImage& img : images) { std::string img_path; if (!env->LocalTempFilename(&img_path)) { - return tsl::errors::Internal( + return absl::InternalError( "Could not get temporary filenames for images."); } TF_RETURN_IF_ERROR(tsl::WriteStringToFile( @@ -386,7 +428,7 @@ tsl::StatusOr> BundleGpuAsm( // Prepare temorary result file. std::string result_path; if (!env->LocalTempFilename(&result_path)) { - return tsl::errors::Internal( + return absl::InternalError( "Could not get temporary filename for fatbin result."); } absl::Cleanup result_file_cleaner = [&result_path] { @@ -418,14 +460,15 @@ tsl::StatusOr> BundleGpuAsm( } fatbinary.SetProgram(fatbinary_path, fatbinary_args); fatbinary.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); + VLOG(5)<<"subprocess running:"<> BundleGpuAsm( +absl::StatusOr> BundleGpuAsm( std::vector images, const std::string rocm_root_dir) { std::string clang_offload_bundler_path = findRocmExecutable("llvm/bin/clang-offload-bundler", rocm_root_dir); @@ -535,4 +578,20 @@ tsl::StatusOr> BundleGpuAsm( return std::vector(result_blob.begin(), result_blob.end()); } +absl::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, + const char* ptx_contents, + GpuAsmOpts options, + bool cancel_if_reg_spill) { + if (IsLibNvPtxCompilerSupported()) { + VLOG(3) << "Compiling GPU ASM with libnvptxcompiler"; + return CompileGpuAsmUsingLibNvPtxCompiler(cc_major, cc_minor, ptx_contents, + options, cancel_if_reg_spill); + } + + VLOG(3) << "Compiling GPU ASM with PTXAS. Libnvptxcompiler compilation " + "not supported."; + return CompileGpuAsmUsingPtxAs(cc_major, cc_minor, ptx_contents, options, + cancel_if_reg_spill); +} + } // namespace stream_executor diff --git a/xla/stream_executor/gpu/asm_compiler.h b/xla/stream_executor/gpu/asm_compiler.h index 0e15e326ca23e..5933a218baca1 100644 --- a/xla/stream_executor/gpu/asm_compiler.h +++ b/xla/stream_executor/gpu/asm_compiler.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,20 +18,24 @@ limitations under the License. #include #include -#include +#include +#include #include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_driver.h" #endif // GOOGLE_CUDA @@ -46,9 +50,9 @@ class GpuContext; // // 'options' is used to query for the CUDA location in case it is // customized in a passed flag, and for controlling ptxas optimizations. -tsl::StatusOr> CompileGpuAsm(int device_ordinal, - const char* ptx_contents, - GpuAsmOpts options); +absl::StatusOr> CompileGpuAsm(int device_ordinal, + const char* ptx_contents, + GpuAsmOpts options); // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. The generated cubin matches the compute @@ -56,7 +60,11 @@ tsl::StatusOr> CompileGpuAsm(int device_ordinal, // // 'options' is used to query for the CUDA location in case it is // customized in a passed flag, and for controlling ptxas optimizations. -tsl::StatusOr> CompileGpuAsm( +absl::StatusOr> CompileGpuAsm( + int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, + bool cancel_if_reg_spill = false); + +absl::StatusOr> CompileGpuAsmUsingPtxAs( int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill = false); @@ -64,7 +72,7 @@ tsl::StatusOr> CompileGpuAsm( // the compiled binary. // // A copy of the string provided in ptx will be made. -tsl::StatusOr> CompileGpuAsmOrGetCached( +absl::StatusOr> CompileGpuAsmOrGetCached( int device_ordinal, const char* ptx, GpuAsmOpts compilation_options); struct CubinOrPTXImage { @@ -74,7 +82,7 @@ struct CubinOrPTXImage { // Bundles the GPU machine code (cubins) and PTX if requested and returns the // resulting binary (i.e. a fatbin) as a byte array. -tsl::StatusOr> BundleGpuAsm( +absl::StatusOr> BundleGpuAsm( std::vector images, GpuAsmOpts options); struct HsacoImage { @@ -84,33 +92,42 @@ struct HsacoImage { // Bundles the GPU machine code (HSA Code Object) and returns the resulting // binary (i.e. a fatbin) as a byte array. -tsl::StatusOr> BundleGpuAsm( +absl::StatusOr> BundleGpuAsm( std::vector images, const std::string rocm_root_dir); // Links multiple relocatable GPU images (e.g. results of ptxas -c) into a // single image. -tsl::StatusOr> LinkGpuAsm( +absl::StatusOr> LinkGpuAsm( gpu::GpuContext* context, std::vector images); -tsl::StatusOr> LinkUsingNvlink( +absl::StatusOr> LinkUsingNvlink( absl::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images); -std::string FindCudaExecutable(const std::string& binary_name, - const std::string& preferred_cuda_dir); +using ToolVersion = std::array; +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir, + ToolVersion minimum_version, + absl::Span excluded_versions); + +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir); // Runs tool --version and parses its version string. -tsl::StatusOr> GetToolVersion( - absl::string_view tool_path); +absl::StatusOr GetToolVersion(std::string_view tool_path); + +// On NVIDIA GPUs, returns the version of the ptxas command line tool. +absl::StatusOr GetAsmCompilerVersion( + std::string_view preferred_cuda_dir); -// On NVIDIA GPUs, returns the CUDA toolkit version supported by the driver, -tsl::StatusOr> GetAsmCompilerVersion( - const std::string& preferred_cuda_dir); +// On NVIDIA GPUs, returns the version of the nvlink command line tool. +absl::StatusOr GetNvLinkVersion( + std::string_view preferred_cuda_dir); #if GOOGLE_CUDA // Maintains a cache of pointers to loaded kernels template -tsl::StatusOr>> LoadKernelOrGetPtr( +absl::StatusOr*> LoadKernelOrGetPtr( StreamExecutor* executor, absl::string_view kernel_name, absl::string_view ptx, absl::Span cubin_data) { using KernelPtrCacheKey = @@ -118,8 +135,7 @@ tsl::StatusOr>> LoadKernelOrGetPtr( static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = - *new absl::flat_hash_map>>(); + *new absl::node_hash_map>(); CUcontext current_context = cuda::CurrentContextOrDie(); KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx}; absl::MutexLock lock(&kernel_ptr_cache_mutex); @@ -127,14 +143,14 @@ tsl::StatusOr>> LoadKernelOrGetPtr( auto it = kernel_ptr_cache.find(kernel_ptr_cache_key); if (it == kernel_ptr_cache.end()) { TF_ASSIGN_OR_RETURN( - std::shared_ptr> loaded, - executor->CreateTypedKernel(kernel_name, ptx, cubin_data)); + TypedKernel loaded, + (TypedKernel::Create(executor, kernel_name, ptx, cubin_data))); it = kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first; } CHECK(it != kernel_ptr_cache.end()); - return it->second; + return &it->second; } #endif // GOOGLE_CUDA diff --git a/xla/stream_executor/gpu/gpu_activation.cc b/xla/stream_executor/gpu/gpu_activation.cc index 4a4704fb539b0..c40182cccf169 100644 --- a/xla/stream_executor/gpu/gpu_activation.cc +++ b/xla/stream_executor/gpu/gpu_activation.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" namespace stream_executor { namespace gpu { diff --git a/xla/stream_executor/gpu/gpu_activation.h b/xla/stream_executor/gpu/gpu_activation.h index 385644d7ac197..a28bef2e5da83 100644 --- a/xla/stream_executor/gpu/gpu_activation.h +++ b/xla/stream_executor/gpu/gpu_activation.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ -#include "xla/stream_executor/platform/port.h" namespace stream_executor { diff --git a/xla/stream_executor/gpu/gpu_asm_opts.h b/xla/stream_executor/gpu/gpu_asm_opts.h index 6a10c2653e906..7a34e6bbc1e1a 100644 --- a/xla/stream_executor/gpu/gpu_asm_opts.h +++ b/xla/stream_executor/gpu/gpu_asm_opts.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 8818808029bf4..17b2562bd5d5a 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,20 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include #include +#include #include +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" +#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "tsl/platform/tensor_float_32_utils.h" #endif @@ -35,12 +41,16 @@ using blas::ComputationType; using blas::DataType; using xla::PrimitiveType; -tsl::StatusOr AsBlasDataType(PrimitiveType dtype) { +absl::StatusOr AsBlasDataType(PrimitiveType dtype) { switch (dtype) { case PrimitiveType::F8E5M2: return DataType::kF8E5M2; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; + case PrimitiveType::F8E5M2FNUZ: + return DataType::kF8E5M2FNUZ; + case PrimitiveType::F8E4M3FNUZ: + return DataType::kF8E4M3FNUZ; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -58,18 +68,22 @@ tsl::StatusOr AsBlasDataType(PrimitiveType dtype) { case PrimitiveType::C128: return DataType::kComplexDouble; default: - return xla::InternalError( + return xla::Internal( "AsBlasDataType: unsupported type: %s", xla::primitive_util::LowercasePrimitiveTypeName(dtype)); } } -tsl::StatusOr AsXlaPrimitiveType(DataType dtype) { +absl::StatusOr AsXlaPrimitiveType(DataType dtype) { switch (dtype) { case DataType::kF8E5M2: return PrimitiveType::F8E5M2; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; + case DataType::kF8E5M2FNUZ: + return PrimitiveType::F8E5M2FNUZ; + case DataType::kF8E4M3FNUZ: + return PrimitiveType::F8E4M3FNUZ; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -87,53 +101,90 @@ tsl::StatusOr AsXlaPrimitiveType(DataType dtype) { case DataType::kComplexDouble: return PrimitiveType::C128; default: - return xla::InternalError("AsXlaPrimitiveType: unsupported dtype"); + return xla::Internal("AsXlaPrimitiveType: unsupported dtype"); } } -tsl::StatusOr GetBlasComputationType( - PrimitiveType lhs_dtype, PrimitiveType output_dtype, - int64_t compute_precision) { - switch (output_dtype) { - case PrimitiveType::F8E5M2: // fall-through - case PrimitiveType::F8E4M3FN: // fall-through - case PrimitiveType::F16: // fall-through - case PrimitiveType::BF16: - // Accumulate in f32 precision. - return ComputationType::kF32; - case PrimitiveType::F32: // fall-through - case PrimitiveType::C64: +MatrixLayout::MatrixLayout(xla::PrimitiveType dtype_, int64_t num_rows_, + int64_t num_cols_, MatrixLayout::Order order_, + int64_t batch_size_, + std::optional leading_dim_stride_, + std::optional batch_stride_, + std::optional transpose_) + : dtype(dtype_), + num_rows(num_rows_), + num_cols(num_cols_), + order(order_), + batch_size(batch_size_) { + if (!leading_dim_stride_) { + leading_dim_stride = order == Order::kRowMajor ? num_cols : num_rows; + } else { + leading_dim_stride = *leading_dim_stride_; + } + if (!batch_stride_) { + batch_stride = (batch_size > 1) ? num_rows * num_cols : 0; + } else { + batch_stride = *batch_stride_; + } + transpose = transpose_ ? *transpose_ : blas::Transpose::kNoTranspose; +} + +void MatrixLayout::Transpose() { + std::swap(num_rows, num_cols); + order = (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; +} + +absl::StatusOr GetBlasComputationType( + xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype, + xla::PrimitiveType output_dtype, int64_t compute_precision) { + if (algorithm == xla::PrecisionConfig::ALG_UNSET) { + switch (output_dtype) { + case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3FN: // fall-through + case PrimitiveType::F8E5M2FNUZ: // fall-through + case PrimitiveType::F8E4M3FNUZ: // fall-through + case PrimitiveType::F16: // fall-through + case PrimitiveType::BF16: + // Accumulate in f32 precision. + return ComputationType::kF32; + case PrimitiveType::F32: // fall-through + case PrimitiveType::C64: #if GOOGLE_CUDA - if (tsl::tensor_float_32_execution_enabled() && compute_precision <= 1 && - lhs_dtype == output_dtype) { - // CublasLt requires compute type to be F32 for F8 matmul. - // TF32 should only be chosen for FP32 or C64 gemm - return ComputationType::kTF32AsF32; - } + if (tsl::tensor_float_32_execution_enabled() && + compute_precision <= 1 && lhs_dtype == output_dtype) { + // CublasLt requires compute type to be F32 for F8 matmul. + // TF32 should only be chosen for FP32 or C64 gemm + return ComputationType::kTF32AsF32; + } #endif - return ComputationType::kF32; - case PrimitiveType::F64: // fall-through - case PrimitiveType::C128: - return ComputationType::kF64; - case PrimitiveType::S32: - return ComputationType::kI32; - default: - return xla::InternalError("GetBlasComputationType: unsupported type"); + return ComputationType::kF32; + case PrimitiveType::F64: // fall-through + case PrimitiveType::C128: + return ComputationType::kF64; + case PrimitiveType::S32: + return ComputationType::kI32; + default: + return xla::Internal("GetBlasComputationType: unsupported type"); + } } + + return xla::algorithm_util::GetBlasComputationType(algorithm); } // BLAS GeMM's output is column-major. If we require row-major, use identity: // C^T = (A @ B)^T = B^T @ A^T. bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output, MatrixLayout* pC) { + MatrixLayout& output, MatrixLayout* c) { bool swap_operands = output.order != MatrixLayout::Order::kColumnMajor; if (swap_operands) { std::swap(lhs, rhs); rhs.Transpose(); - lhs.Transpose(); - // prevent pC and output from being swapped two times if they are equal! - if (pC != nullptr && pC != &output) { - pC->Transpose(); + // prevent layouts from being swapped two times if they are equal + if (&lhs != &rhs) { + lhs.Transpose(); + } + if (c != nullptr && c != &output) { + c->Transpose(); } output.Transpose(); } @@ -142,10 +193,10 @@ bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, /*static*/ auto BlasLt::GetMatmulPlan(const Stream* stream, const GemmConfig& cfg, Epilogue epilogue) - -> tsl::StatusOr { + -> absl::StatusOr { auto blas = Get(stream); if (blas == nullptr) { - return xla::InternalError("BlasLt is unavailable"); + return xla::Internal("BlasLt is unavailable"); } return blas->GetMatmulPlan(cfg, epilogue); } diff --git a/xla/stream_executor/gpu/gpu_blas_lt.h b/xla/stream_executor/gpu/gpu_blas_lt.h index 40adeec908c9c..f6c1138878745 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/xla/stream_executor/gpu/gpu_blas_lt.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,28 +17,30 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_H_ #include +#include #include +#include #include #include #include -#include "xla/shape.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host_or_device_scalar.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace stream_executor::gpu { -tsl::StatusOr AsBlasDataType(xla::PrimitiveType dtype); +absl::StatusOr AsBlasDataType(xla::PrimitiveType dtype); -tsl::StatusOr AsXlaPrimitiveType(blas::DataType dtype); +absl::StatusOr AsXlaPrimitiveType(blas::DataType dtype); -tsl::StatusOr GetBlasComputationType( - xla::PrimitiveType lhs_dtype, xla::PrimitiveType output_dtype, - int64_t compute_precision); +absl::StatusOr GetBlasComputationType( + xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype, + xla::PrimitiveType output_dtype, int64_t compute_precision); // Returns the type for the alpha and beta scalars. blas::DataType GetScaleType(blas::DataType c_type, @@ -51,11 +53,13 @@ struct MatrixLayout { // plain MatrixLayout which is extended with create kColumnMajor, // Elements in the same column are contiguous in memory. }; - void Transpose() { - std::swap(num_rows, num_cols); - order = - (order == Order::kRowMajor) ? Order::kColumnMajor : Order::kRowMajor; - } + MatrixLayout(xla::PrimitiveType dtype_, int64_t num_rows_, int64_t num_cols_, + Order order_, int64_t batch_size_ = 1, + std::optional leading_dim_stride_ = {}, + std::optional batch_stride_ = {}, + std::optional transpose_ = {}); + + void Transpose(); xla::PrimitiveType dtype; // `num_rows` / `num_cols` are for the "logical" matrix shape: @@ -65,16 +69,39 @@ struct MatrixLayout { // plain MatrixLayout which is extended with create int64_t num_cols; Order order; int64_t batch_size; - std::optional leading_dim_stride; + int64_t leading_dim_stride; // `batch_stride` is set to `0` when `batch_size == 1`. - std::optional batch_stride; - std::optional transpose; + int64_t batch_stride; + blas::Transpose transpose; +}; + +// compact version of the matrix layout to be used to pass matrices +// to underlying blas API +struct MatrixDescriptor { + DeviceMemoryBase data; + int64_t leading_dim_stride = 0; + int64_t batch_stride = 0; + blas::DataType type{}; + blas::Transpose transpose{}; + + template + DeviceMemory cast() const { + return DeviceMemory(data); + } +}; + +struct OutputMatrixDescriptor : public MatrixDescriptor { + OutputMatrixDescriptor(MatrixDescriptor&& parent) noexcept + : MatrixDescriptor(std::move(parent)) {} + int64_t batch_size = 0; + int64_t m = 0, n = 0, k = 0; + blas::ComputationType compute_type{}; }; // BLAS GeMM's output is column-major. If we require row-major, use identity: // C^T = (A @ B)^T = B^T @ A^T. bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, - MatrixLayout& output, MatrixLayout* pC = nullptr); + MatrixLayout& output, MatrixLayout* c = nullptr); struct GemmConfig { // plain GemmConfig which is extended with create functions // in matmul_utils.h @@ -85,23 +112,15 @@ struct GemmConfig { // plain GemmConfig which is extended with create functions xla::complex128 alpha; double beta; int64_t compute_precision; + // PrecisionConfig-level algorithm + xla::PrecisionConfig::Algorithm precision_algorithm; + // BLAS-library-level algorithm. std::optional algorithm; bool grad_x; bool grad_y; std::optional compute_type; }; -// template < cudaDataType_t What, cudaDataType_t SrcT, class Z, class... T> -// struct ChooseType { -// using type = std::conditional_t< What == SrcT, Z, -// typename ChooseType< What, T...>::type>; -// }; - -// template < cudaDataType_t What > -// using CudaToNativeT = typename ChooseType< What, CUDA_R_8F_E4M3, -// tsl::float8_e4m3fn, -// CUDA_R_8F_E5M2, tsl::float8_e5m2, ... >::type; - struct BlasLt { enum class Epilogue { kDefault = 1, // No special postprocessing @@ -127,20 +146,21 @@ struct BlasLt { struct MatmulPlan { template - tsl::Status DoMatmul(Stream* stream, const HostOrDeviceScalar& alpha, - const DeviceMemory
& a, const DeviceMemory& b, - const HostOrDeviceScalar& beta, - const DeviceMemory& c, DeviceMemory& d, - const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - const DeviceMemory& bias = {}, - const DeviceMemoryBase& aux = DeviceMemory{}, - const DeviceMemory& a_scale = {}, - const DeviceMemory& b_scale = {}, - const DeviceMemory& c_scale = {}, - const DeviceMemory& d_scale = {}, - const DeviceMemory& d_amax = {}, - blas::ProfileResult* profile_result = nullptr) const { + absl::Status DoMatmul(Stream* stream, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, const DeviceMemory& b, + const HostOrDeviceScalar& beta, + const DeviceMemory& c, DeviceMemory& d, + const MatmulAlgorithm& algorithm, + ScratchAllocator& scratch_allocator, + const DeviceMemory& bias = {}, + const DeviceMemoryBase& aux = DeviceMemory{}, + const DeviceMemory& a_scale = {}, + const DeviceMemory& b_scale = {}, + const DeviceMemory& c_scale = {}, + const DeviceMemory& d_scale = {}, + const DeviceMemory& d_amax = {}, + blas::ProfileResult* profile_result = nullptr) const { TF_RETURN_IF_ERROR(ValidateInputs( blas::ToDataType::value, alpha.on_device(), beta.on_device(), blas::ToDataType::value, blas::ToDataType::value, @@ -152,21 +172,22 @@ struct BlasLt { } template - tsl::Status DoMatmul(Stream* stream, const HostOrDeviceScalar& alpha, - const DeviceMemory& a, const DeviceMemory& b, - const HostOrDeviceScalar& beta, - const DeviceMemory& c, DeviceMemory& d, - const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - const DeviceMemory& bias = {}, - const DeviceMemoryBase& aux = DeviceMemory{}, - blas::ProfileResult* profile_result = nullptr) const { + absl::Status DoMatmul(Stream* stream, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, const DeviceMemory& b, + const HostOrDeviceScalar& beta, + const DeviceMemory& c, DeviceMemory& d, + const MatmulAlgorithm& algorithm, + ScratchAllocator& scratch_allocator, + const DeviceMemory& bias = {}, + const DeviceMemoryBase& aux = DeviceMemory{}, + blas::ProfileResult* profile_result = nullptr) const { return DoMatmul(stream, alpha, a, b, beta, c, d, algorithm, scratch_allocator, bias, aux, {}, {}, {}, {}, {}, profile_result); } - virtual tsl::Status ExecuteOnStream( + virtual absl::Status ExecuteOnStream( Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, DeviceMemoryBase bias_buffer, // may be null @@ -180,7 +201,7 @@ struct BlasLt { // Returns a list of supported algorithms for DoMatmul. The algorithms are // returned in the order of increasing estimated compute time according to // an internal heuristic. - virtual tsl::StatusOr> GetAlgorithms( + virtual absl::StatusOr> GetAlgorithms( size_t max_algorithm_count = 128, size_t max_workspace_size = 1ll << 32) const = 0; @@ -190,16 +211,16 @@ struct BlasLt { // might be used internally by ExecuteOnStream in derived classes template - tsl::Status DoMatmul(Stream* stream, xla::complex128 alpha, - DeviceMemoryBase a, DeviceMemoryBase b, double beta, - DeviceMemoryBase c, DeviceMemoryBase d, - DeviceMemoryBase bias, DeviceMemoryBase aux, - DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, - DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, - DeviceMemoryBase d_amax, - const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - blas::ProfileResult* profile_result) const { + absl::Status DoMatmul(Stream* stream, xla::complex128 alpha, + DeviceMemoryBase a, DeviceMemoryBase b, double beta, + DeviceMemoryBase c, DeviceMemoryBase d, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + const MatmulAlgorithm& algorithm, + ScratchAllocator& scratch_allocator, + blas::ProfileResult* profile_result) const { Scale salpha; if constexpr (std::is_same_v || std::is_same_v) { @@ -221,12 +242,12 @@ struct BlasLt { } // used internally by template DoMatmul function to validate inputs - virtual tsl::Status ValidateInputs( + virtual absl::Status ValidateInputs( blas::DataType scale_type, bool alpha_on_device, bool beta_on_device, blas::DataType A_type, blas::DataType B_type, blas::DataType C_type, blas::DataType D_type) const = 0; - virtual tsl::Status DoMatmul( + virtual absl::Status DoMatmul( Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, DeviceMemoryBase d, const MatmulAlgorithm& algorithm, @@ -239,17 +260,17 @@ struct BlasLt { using MatmulPlanPtr = std::unique_ptr; - virtual tsl::Status Init() = 0; + virtual absl::Status Init() = 0; - virtual tsl::StatusOr GetMatmulPlan( + virtual absl::StatusOr GetMatmulPlan( const GemmConfig& cfg, Epilogue epilogue) const = 0; static BlasLt* Get(const Stream* stream); // convenience function to create MatmulPlan directly using stream - static tsl::StatusOr GetMatmulPlan(const Stream* stream, - const GemmConfig& cfg, - Epilogue epilogue); + static absl::StatusOr GetMatmulPlan(const Stream* stream, + const GemmConfig& cfg, + Epilogue epilogue); virtual ~BlasLt() {} }; // class BlasLt diff --git a/xla/stream_executor/gpu/gpu_collectives.h b/xla/stream_executor/gpu/gpu_collectives.h new file mode 100644 index 0000000000000..e9e0d724c2cfe --- /dev/null +++ b/xla/stream_executor/gpu/gpu_collectives.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace stream_executor::gpu { + +// Forward declaration. +class GpuContext; + +struct GpuCollectives { + // Allocates a collective device memory space of size bytes associated with + // the given context. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclmemalloc + static absl::StatusOr CollectiveMemoryAllocate(GpuContext* context, + uint64_t bytes); + + // Deallocates a collective device memory space of size bytes associated with + // the given context. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclmemfree + static absl::Status CollectiveMemoryDeallocate(GpuContext* context, + void* location); +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ diff --git a/xla/stream_executor/gpu/gpu_command_buffer.cc b/xla/stream_executor/gpu/gpu_command_buffer.cc index 643ae4879df36..d9dd03fa667f4 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ limitations under the License. #include #include #include +#include +#include #include #include #include @@ -28,21 +30,26 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/gpu/gpu_kernels.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" #include "tsl/platform/statusor.h" namespace stream_executor::gpu { @@ -61,7 +68,7 @@ std::string_view to_string(State state) { } } -tsl::Status UnsupportedStateError(State state) { +absl::Status UnsupportedStateError(State state) { return absl::InternalError( absl::StrCat("Unsupported command buffer state: ", to_string(state))); } @@ -95,24 +102,40 @@ static int64_t NotifyExecDestroyed() { // GpuCommandBuffer implementation //===----------------------------------------------------------------------===// +static std::string_view ModeToString(CommandBuffer::Mode mode) { + switch (mode) { + case CommandBuffer::Mode::kPrimary: + return "primary"; + case CommandBuffer::Mode::kNested: + return "nested"; + } +} + GpuCommandBuffer::GpuCommandBuffer(Mode mode, GpuExecutor* parent, GpuGraphHandle graph, bool is_owned_graph) : mode_(mode), parent_(parent), graph_(graph), - is_owned_graph_(is_owned_graph) {} + is_owned_graph_(is_owned_graph) { + VLOG(5) << "Created command buffer for graph " << graph_ + << "; mode=" << ModeToString(mode) + << "; is_owned_graph=" << is_owned_graph_; + execution_scopes_.try_emplace(kDefaulExecutionScope); +} GpuCommandBuffer::~GpuCommandBuffer() { if (exec_ != nullptr && is_owned_graph_exec_) { VLOG(5) << "Destroy GPU command buffer executable graph " << exec_ << " " << "(remaining alive executable graphs: " << NotifyExecDestroyed() << ")"; - auto st = GpuDriver::DestroyGraphExec(exec_); - CHECK(st.ok()) << "Failed to destroy GPU graph exec: " << st.message(); + if (auto status = GpuDriver::DestroyGraphExec(exec_); !status.ok()) { + LOG(ERROR) << "Failed to destroy GPU graph exec: " << status.message(); + } } if (graph_ != nullptr && is_owned_graph_) { - auto st = GpuDriver::DestroyGraph(graph_); - CHECK(st.ok()) << "Failed to destroy GPU graph: " << st.message(); + if (auto status = GpuDriver::DestroyGraph(graph_); !status.ok()) { + LOG(ERROR) << "Failed to destroy GPU graph: " << status.message(); + } } } @@ -134,26 +157,41 @@ static GpuDevicePtr AsDevicePtr(const DeviceMemoryBase& mem) { return reinterpret_cast(const_cast(mem.opaque())); } -tsl::Status GpuCommandBuffer::Trace( - Stream* stream, absl::AnyInvocable function) { - // TODO(ezhulenev): Check that graph is empty, because we should not be mixing - // graph tracing with explicit graph construction. +absl::Status GpuCommandBuffer::Trace( + Stream* stream, absl::AnyInvocable function) { TF_RETURN_IF_ERROR(CheckNotFinalized()); +#if defined(TENSORFLOW_USE_ROCM) + TF_ASSIGN_OR_RETURN(size_t count, GpuDriver::GraphGetNodeCount(graph_)); + if (count != 0 || !is_owned_graph_) + return absl::InternalError( + "Stream can't be traced on non empty command buffer"); +#endif // TENSORFLOW_USE_ROCM VLOG(5) << "Trace into GPU command buffer graph " << graph_ - << " on a stream: " << stream->DebugStreamPointers(); + << " on a stream: " << stream; auto gpu_stream = AsGpuStreamValue(stream); // Switch stream into the capture mode. uint64_t start_nanos = tsl::Env::Default()->NowNanos(); +#if !defined(TENSORFLOW_USE_ROCM) + TF_RETURN_IF_ERROR(GpuDriver::StreamBeginCaptureToGraph( + gpu_stream, graph_, GpuDriver::StreamCaptureMode::kThreadLocal)); +#else TF_RETURN_IF_ERROR(GpuDriver::StreamBeginCapture( gpu_stream, GpuDriver::StreamCaptureMode::kThreadLocal)); - +#endif // TENSORFLOW_USE_ROCM auto traced = function(); // Always stop capturing the stream before checking `traced` result. - TF_RETURN_IF_ERROR(GpuDriver::StreamEndCapture(gpu_stream, &graph_)); + GpuGraphHandle captured_graph; + TF_RETURN_IF_ERROR(GpuDriver::StreamEndCapture(gpu_stream, &captured_graph)); +#if !defined(TENSORFLOW_USE_ROCM) + DCHECK(captured_graph == graph_) << "Stream capture should update graph_"; +#else + TF_RETURN_IF_ERROR( + GpuDriver::DestroyGraph(std::exchange(graph_, captured_graph))); +#endif // TENSORFLOW_USE_ROCM uint64_t end_nanos = tsl::Env::Default()->NowNanos(); if (!traced.ok()) @@ -163,114 +201,432 @@ tsl::Status GpuCommandBuffer::Trace( VLOG(5) << "Traced into the GPU command buffer graph " << graph_ << " (took " << (end_nanos - start_nanos) / 1000 << " μs)"; - return tsl::OkStatus(); + return absl::OkStatus(); } -GpuCommandBuffer::Dependencies GpuCommandBuffer::GetDependencies() { - return nodes_.empty() ? Dependencies() : Dependencies{nodes_.back()}; +GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier( + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + return execution_scope.barriers.empty() + ? Dependencies{} + : Dependencies{execution_scope.barriers.back().handle}; } -tsl::Status GpuCommandBuffer::CheckNotFinalized() { - if (state_ == State::kFinalized) - return absl::InternalError( - "Command can't be added to a command buffer after it was finalized"); - return tsl::OkStatus(); +absl::StatusOr +GpuCommandBuffer::GetSetIfConditionKernel(StreamExecutor* executor) { + if (!set_if_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddCudaPtxInMemory(gpu::GetSetIfConditionKernel(), "set_if_condition"); + TF_ASSIGN_OR_RETURN(set_if_condition_kernel_, + SetIfConditionKernel::Create(executor, spec)); + } + return &set_if_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetIfElseConditionKernel(StreamExecutor* executor) { + if (!set_if_else_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddCudaPtxInMemory(gpu::GetSetIfElseConditionKernel(), + "set_if_else_condition"); + TF_ASSIGN_OR_RETURN(set_if_else_condition_kernel_, + SetIfElseConditionKernel::Create(executor, spec)); + } + return &set_if_else_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetCaseConditionKernel(StreamExecutor* executor) { + if (!set_case_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/10); + spec.AddCudaPtxInMemory(gpu::GetSetCaseConditionKernel(), + "set_case_condition"); + TF_ASSIGN_OR_RETURN(set_case_condition_kernel_, + SetCaseConditionKernel::Create(executor, spec)); + } + return &set_case_condition_kernel_; +} + +absl::StatusOr +GpuCommandBuffer::GetSetForConditionKernel(StreamExecutor* executor) { + if (!set_for_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddCudaPtxInMemory(gpu::GetSetForConditionKernel(), + "set_for_condition"); + TF_ASSIGN_OR_RETURN(set_for_condition_kernel_, + SetForConditionKernel::Create(executor, spec)); + } + return &set_for_condition_kernel_; } -tsl::Status GpuCommandBuffer::CheckPrimary() { - if (mode_ != Mode::kPrimary) +absl::StatusOr +GpuCommandBuffer::GetSetWhileConditionKernel(StreamExecutor* executor) { + if (!set_while_condition_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/2); + spec.AddCudaPtxInMemory(gpu::GetSetWhileConditionKernel(), + "set_while_condition"); + TF_ASSIGN_OR_RETURN(set_while_condition_kernel_, + SetWhileConditionKernel::Create(executor, spec)); + } + return &set_while_condition_kernel_; +} + +absl::StatusOr GpuCommandBuffer::GetNoOpKernel( + StreamExecutor* executor) { +#if !defined(TENSORFLOW_USE_ROCM) + if (!noop_kernel_) { + MultiKernelLoaderSpec spec(/*arity=*/0); + spec.AddCudaPtxInMemory(gpu::kNoOpKernel, "noop"); + TF_ASSIGN_OR_RETURN(noop_kernel_, NoOpKernel::Create(executor, spec)); + } + return &noop_kernel_; +#else + return absl::UnimplementedError( + "GpuCommandBuffer::GetNoOpKernel is not implemented."); +#endif // TENSORFLOW_USE_ROCM +} + +absl::Status GpuCommandBuffer::DisableBarriersExecution( + GpuGraphExecHandle exec) { +#if !defined(TENSORFLOW_USE_ROCM) + ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope]; + + for (GpuGraphBarrierInfo& barrier : execution_scope.barriers) { + if (barrier.is_barrier_node) { + TF_RETURN_IF_ERROR( + GpuDriver::GraphNodeSetEnabled(exec, barrier.handle, false)); + } + } + for (ConditionalCommandBuffers& cmd_buffers : + execution_scope.conditional_command_buffers) { + for (auto& cmd_buffer : cmd_buffers.command_buffers) { + TF_RETURN_IF_ERROR(cmd_buffer->DisableBarriersExecution(exec)); + } + } +#endif // TENSORFLOW_USE_ROCM + return absl::OkStatus(); +} + +absl::Status GpuCommandBuffer::CheckNotFinalized() { + if (state_ == State::kFinalized) return absl::InternalError( - "Command can't be added to a non-primary command buffer"); - return tsl::OkStatus(); + "Command can't be added to a command buffer after it was finalized"); + return absl::OkStatus(); } -tsl::Status GpuCommandBuffer::CheckNumCommandBuffers( +absl::Status GpuCommandBuffer::CheckNumCommandBuffers( const ConditionalCommandBuffers& cmd_buffers, size_t num_cmd_buffers) { if (cmd_buffers.handles.size() != num_cmd_buffers) { return absl::InternalError(absl::StrCat( "Expected to have ", num_cmd_buffers, " conditional command buffers, got ", cmd_buffers.handles.size())); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, - const BlockDim& blocks, - const Kernel& kernel, - const KernelArgs& args) { - TF_RETURN_IF_ERROR(CheckNotFinalized()); +absl::StatusOr GpuCommandBuffer::CreateBarrierNode( + StreamExecutor* executor, const Dependencies& dependencies) { + GpuGraphNodeHandle barrier_handle = nullptr; +#if !defined(TENSORFLOW_USE_ROCM) + // TODO(b/316343054): Instead of empty nodes we create no-op kernel nodes as + // barriers because CUDA 12.3 does not support empty nodes inside + // conditional command buffers. This should be fixed in CUDA 12.4. + TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel(executor)); + + TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( + &barrier_handle, graph_, dependencies, "noop", + AsGpuKernel(&**noop)->AsGpuFunctionHandle(), 1, 1, 1, 1, 1, 1, 0, + /*kernel_params=*/nullptr, /*extra=*/nullptr)); +#else + TF_RETURN_IF_ERROR( + GpuDriver::GraphAddEmptyNode(&barrier_handle, graph_, dependencies)); +#endif // TENSORFLOW_USE_ROCM + + return barrier_handle; +} + +GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrierDependencies( + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + auto& barriers = execution_scope.barriers; + + // Collect nodes that will become a new barrier dependencies starting from + // the first command node added after the last barrier in the scope. + Dependencies dependencies; + for (size_t i = barriers.empty() ? 0 : barriers.back().nodes_offset; + i < execution_scope.nodes.size(); ++i) { + dependencies.push_back(execution_scope.nodes[i].handle); + } + return dependencies; +} + +absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + if (state_ == State::kCreate) { + // Nodes offset for a newly created barrier. + size_t nodes_offset = execution_scope.nodes.size(); + + // Collect nodes that will become a new barrier dependencies starting from + // the first command node added after the last barrier. + Dependencies dependencies = GetBarrierDependencies(execution_scope_id); + + // If there are no new dependencies and we have an existing barrier simply + // copy information from the last barrier to a new one. + if (dependencies.empty() && !execution_scope.barriers.empty()) { + execution_scope.barriers.push_back({execution_scope.barriers.back()}); + return absl::OkStatus(); + } + + // If we have only one node added after the last barrier simply reuse the + // last node corresponding to a command as a barrier. + if (dependencies.size() == 1) { + execution_scope.barriers.push_back( + {execution_scope.nodes.back().handle, false, nodes_offset}); + return absl::OkStatus(); + } + + // If we have multiple dependencies or no existing barriers we have to + // create a new empty node acting as an execution barrier. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); + return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + execution_scope_id.value())); + } + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) { + // Nothing to synchronize here. + if (execution_scope_ids.empty()) return absl::OkStatus(); + + // Do not create two-level barriers for single execution scope. + if (execution_scope_ids.size() == 1) { + return Barrier(executor, execution_scope_ids[0]); + } + + // Add a new barrier to every synchronized execution scope. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); + } + + if (state_ == State::kCreate) { + // Collect barriers from each scope as a dependencies. + Dependencies dependencies; + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + dependencies.push_back(execution_scope.barriers.back().handle); + } + + // Create a new barrier that joins all per-scope barriers together. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + + // Broadcast new barrier to all participating execution scopes. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + size_t nodes_offset = execution_scope.nodes.size(); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); + } + + return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + for (ExecutionScopeId execution_scope_id : execution_scope_ids) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + execution_scope_id.value())); + } + } + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) { + // If scopes are the same simply add a barrier to it. + if (from_execution_scope_id == to_execution_scope_id) { + return Barrier(executor, from_execution_scope_id); + } + + // Create new barriers in both execution scopes. + TF_RETURN_IF_ERROR(Barrier(executor, from_execution_scope_id)); + TF_RETURN_IF_ERROR(Barrier(executor, to_execution_scope_id)); + + if (state_ == State::kCreate) { + // Collect barriers from each scope as dependencies. + Dependencies dependencies = { + execution_scopes_[from_execution_scope_id].barriers.back().handle, + execution_scopes_[to_execution_scope_id].barriers.back().handle}; + + // Create a new barrier that joins `from` and `to` scopes. + TF_ASSIGN_OR_RETURN(auto barrier_handle, + CreateBarrierNode(executor, dependencies)); + + // Add a new barrier only to the `to_execution_scope_id`. + ExecutionScope& execution_scope = execution_scopes_[to_execution_scope_id]; + size_t nodes_offset = execution_scope.nodes.size(); + execution_scope.barriers.push_back({barrier_handle, true, nodes_offset}); + + return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // Command buffer updates can't change the structure of the underlying gpu + // graph (add or delete barriers). We simply do a sanity check that at + // update time we didn't try to add more barriers than we had originally. + ExecutionScope& execution_scope = execution_scopes_[to_execution_scope_id]; + if (execution_scope.update_state.barrier_idx++ >= + execution_scope.barriers.size()) { + return absl::InternalError( + absl::StrFormat("Execution scope %d barrier index out of range", + to_execution_scope_id.value())); + } + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + +absl::Status GpuCommandBuffer::LaunchWithPackedArgs( + ExecutionScopeId execution_scope_id, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, + const KernelArgsPackedArrayBase& packed_args) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + CHECK_EQ(kernel.Arity() + (packed_args.number_of_shared_bytes() > 0), + packed_args.number_of_arguments()); const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); - auto* packed_args = DynCast(&args); - if (!packed_args) - return absl::InternalError("Unsupported kernel arguments type"); - void** kernel_params = - const_cast(packed_args->argument_addresses().data()); + const_cast(packed_args.argument_addresses().data()); // Adds a new kernel node to the graph under construction. if (state_ == State::kCreate) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddKernelNode( - node, graph_, absl::MakeSpan(deps), kernel.name(), gpu_func, blocks.x, + &node_info.handle, graph_, barrier, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, - args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr); + packed_args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr); } // Updates kernel node in the executable graph. if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++]; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecKernelNodeSetParams( exec_, node, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z, - threads.x, threads.y, threads.z, args.number_of_shared_bytes(), + threads.x, threads.y, threads.z, packed_args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr); } return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::AddNestedCommandBuffer( - const CommandBuffer& nested) { +absl::Status GpuCommandBuffer::Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, + const BlockDim& blocks, + const Kernel& kernel, + const KernelArgs& args) { + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return LaunchWithPackedArgs(execution_scope_id, threads, blocks, kernel, + *packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); + return LaunchWithPackedArgs(execution_scope_id, threads, blocks, kernel, + *packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + +absl::Status GpuCommandBuffer::AddNestedCommandBuffer( + ExecutionScopeId execution_scope_id, const CommandBuffer& nested) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); - TF_RETURN_IF_ERROR(CheckPrimary()); GpuGraphHandle child_graph = GpuCommandBuffer::Cast(&nested)->graph(); // Adds a child graph node to the graph under construction. if (state_ == State::kCreate) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); - return GpuDriver::GraphAddChildNode(node, graph_, absl::MakeSpan(deps), + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); + return GpuDriver::GraphAddChildNode(&node_info.handle, graph_, barrier, child_graph); } // Updates child graph node in the executable graph. if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++]; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecChildNodeSetParams(exec_, node, child_graph); } return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) { +absl::Status GpuCommandBuffer::MemcpyDeviceToDevice( + ExecutionScopeId execution_scope_id, DeviceMemoryBase* dst, + const DeviceMemoryBase& src, uint64_t size) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); if (state_ == State::kCreate) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddMemcpyD2DNode( - parent_->gpu_context(), node, graph_, absl::MakeSpan(deps), + parent_->gpu_context(), &node_info.handle, graph_, barrier, AsDevicePtr(*dst), AsDevicePtr(src), size); } if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++]; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecMemcpyD2DNodeSetParams( parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), AsDevicePtr(src), size); @@ -279,21 +635,25 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, - CommandBuffer::BitPattern bit_pattern, - size_t num_elements) { +absl::Status GpuCommandBuffer::Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, + CommandBuffer::BitPattern bit_pattern, + size_t num_elements) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); if (state_ == State::kCreate) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); return GpuDriver::GraphAddMemsetNode( - parent_->gpu_context(), node, graph_, absl::MakeSpan(deps), + parent_->gpu_context(), &node_info.handle, graph_, barrier, AsDevicePtr(*dst), bit_pattern, num_elements); } if (state_ == State::kUpdate) { - GpuGraphNodeHandle node = nodes_[update_state_.node_idx++]; + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; return GpuDriver::GraphExecMemsetNodeSetParams( parent_->gpu_context(), exec_, node, AsDevicePtr(*dst), bit_pattern, num_elements); @@ -302,16 +662,20 @@ tsl::Status GpuCommandBuffer::Memset(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } -tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { +absl::StatusOr GpuCommandBuffer::Allocate( + ExecutionScopeId execution_scope_id, size_t bytes) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); + // Adds a new memory allocation node to the graph under construction. if (state_ == State::kCreate) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); GpuDevicePtr ptr; TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemAllocNode( - node, graph_, absl::MakeSpan(deps), + &node_info.handle, graph_, barrier, GpuDriver::MemAccessFlags::kReadWrite, GpuDriver::MemLocationType::kDevice, parent_->device_ordinal(), GpuDriver::MemAllocationType::kPinned, bytes, &ptr)); @@ -328,9 +692,10 @@ tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { // Memory allocation node implemented through CUDA graph does not allocate // new memory region on update, just return the memory region allocated // during the create step. + GpuGraphNodeHandle node = + execution_scope.nodes[execution_scope.update_state.node_idx++].handle; TF_ASSIGN_OR_RETURN(AllocationResult params, - GpuDriver::GraphGetMemAllocNodeParams( - nodes_[update_state_.node_idx++])); + GpuDriver::GraphGetMemAllocNodeParams(node)); return DeviceMemoryBase(reinterpret_cast(params.first), params.second); } @@ -338,19 +703,48 @@ tsl::StatusOr GpuCommandBuffer::Allocate(size_t bytes) { return UnsupportedStateError(state_); } +absl::Status GpuCommandBuffer::Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + TF_RETURN_IF_ERROR(CheckNotFinalized()); + + // Adds a new memfree node to the graph under construction. + if (state_ == State::kCreate) { + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); + GpuDevicePtr gpu_dptr = AsDevicePtr(dst); + TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemFreeNode(&node_info.handle, graph_, + barrier, gpu_dptr)); + return absl::OkStatus(); + } + + if (state_ == State::kUpdate) { + // memfree node implemented through CUDA graph only free buffers that is + // allocated through memory alloc node, so buffer address will not change, + // no update is required. + execution_scope.update_state.node_idx++; + return absl::OkStatus(); + } + + return UnsupportedStateError(state_); +} + //--------------------------------------------------------------------------// // Command buffer condtitional commands API //--------------------------------------------------------------------------// +using ConditionalHandles = absl::Span; + /*static*/ GpuCommandBuffer::ConditionBuilder -GpuCommandBuffer::ToConditionBuilder(CommandBuffer::Builder builder) { +GpuCommandBuffer::ToConditionBuilder(Builder builder) { return [builder = std::move(builder)](CommandBuffer* cmd_buffer, GpuGraphConditionalHandle) { return builder(cmd_buffer); }; } -tsl::StatusOr> +absl::StatusOr> GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { std::vector handles; for (size_t i = 0; i < num_handles; ++i) { @@ -360,39 +754,12 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) { return handles; } -tsl::StatusOr> -GpuCommandBuffer::CreateConditionalNodes( - ConditionType type, absl::Span handles) { - std::vector conditional_graphs; - - using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; - using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; - - for (GpuGraphConditionalHandle handle : handles) { - Dependencies deps = GetDependencies(); - GpuGraphNodeHandle* node = &nodes_.emplace_back(); - - ConditionalParams params; - params.type = type; - params.handle = handle; - params.context = parent_->gpu_context(); - - TF_ASSIGN_OR_RETURN( - GpuDriver::GpuGraphNodeResult result, - GpuDriver::GraphAddNode(node, graph_, absl::MakeSpan(deps), params)); - - conditional_graphs.push_back(std::get(result).graph); - } - - return conditional_graphs; -} - -tsl::StatusOr> +absl::StatusOr>> GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders) { - std::vector cmd_buffers; + std::vector> cmd_buffers; // Conditional command buffers always created in nested mode and with // underlying graphs owned by a conditional node. @@ -400,13 +767,10 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( bool is_owned_graph = false; for (size_t i = 0; i < handles.size(); ++i) { - auto command_buffer_impl = parent_->GetCommandBufferImplementation( - nested, graphs[i], is_owned_graph); - - auto command_buffer = CommandBuffer::Create(std::move(command_buffer_impl)); - - TF_RETURN_IF_ERROR(builders[i](&command_buffer, handles[i])); - TF_RETURN_IF_ERROR(command_buffer.Finalize()); + auto command_buffer = + parent_->CreateCommandBuffer(nested, graphs[i], is_owned_graph); + TF_RETURN_IF_ERROR(builders[i](command_buffer.get(), handles[i])); + TF_RETURN_IF_ERROR(command_buffer->Finalize()); cmd_buffers.push_back(std::move(command_buffer)); } @@ -414,25 +778,58 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( return cmd_buffers; } -tsl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( +absl::Status GpuCommandBuffer::UpdateConditionalCommandBuffers( absl::Span handles, - absl::Span command_buffers, + absl::Span> command_buffers, absl::Span builders) { for (size_t i = 0; i < command_buffers.size(); ++i) { // Use parent graph executable for conditional command buffer update. - ScopedGpuGraphExec scoped_exec(Cast(&command_buffers[i]), exec_); + ScopedGpuGraphExec scoped_exec(command_buffers[i].get(), exec_); // Update command buffer using user-provided builder callback. - TF_RETURN_IF_ERROR(command_buffers[i].Update()); - TF_RETURN_IF_ERROR(builders[i](&command_buffers[i], handles[i])); - TF_RETURN_IF_ERROR(command_buffers[i].Finalize()); + TF_RETURN_IF_ERROR(command_buffers[i]->Update()); + TF_RETURN_IF_ERROR(builders[i](command_buffers[i].get(), handles[i])); + TF_RETURN_IF_ERROR(command_buffers[i]->Finalize()); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status GpuCommandBuffer::CreateConditionalCommand( +absl::StatusOr> +GpuCommandBuffer::CreateConditionalNodes( + ExecutionScopeId execution_scope_id, ConditionType type, + absl::Span handles) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + + std::vector conditional_graphs; + + using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; + using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; + + for (GpuGraphConditionalHandle handle : handles) { + Dependencies barrier = GetBarrier(execution_scope_id); + GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back(); + + ConditionalParams params; + params.type = type; + params.handle = handle; + params.context = parent_->gpu_context(); + + TF_ASSIGN_OR_RETURN( + GpuDriver::GpuGraphNodeResult result, + GpuDriver::GraphAddNode(&node_info.handle, graph_, barrier, params)); + + conditional_graphs.push_back(std::get(result).graph); + } + + return conditional_graphs; +} + +absl::Status GpuCommandBuffer::CreateConditionalCommand( + ExecutionScopeId execution_scope_id, StreamExecutor* executor, ConditionType type, SetConditionFn set_condition, absl::Span builders) { + ExecutionScope& execution_scope = execution_scopes_[execution_scope_id]; + TF_RETURN_IF_ERROR(CheckNotFinalized()); // Every conditional command buffer is controlled by its own handle. @@ -442,32 +839,41 @@ tsl::Status GpuCommandBuffer::CreateConditionalCommand( TF_ASSIGN_OR_RETURN(auto handles, CreateConditionalHandles(num_handles)); // Add a kernel to update conditional handles values. - TF_RETURN_IF_ERROR(set_condition(handles)); + TF_RETURN_IF_ERROR(set_condition(execution_scope_id, handles)); + + // Add a barrier between conditional handles and conditional nodes. + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); // Create conditional command buffer for each builder. - TF_ASSIGN_OR_RETURN(auto graphs, CreateConditionalNodes(type, handles)); + TF_ASSIGN_OR_RETURN( + auto graphs, CreateConditionalNodes(execution_scope_id, type, handles)); TF_ASSIGN_OR_RETURN(auto cmd_buffers, CreateConditionalCommandBuffers( handles, graphs, builders)); // Keep track of created conditional handles and command buffers. - conditional_command_buffers_.emplace_back(std::move(handles), - std::move(cmd_buffers)); + execution_scope.conditional_command_buffers.push_back( + {std::move(handles), std::move(cmd_buffers)}); - return tsl::OkStatus(); + return absl::OkStatus(); } if (state_ == State::kUpdate) { ConditionalCommandBuffers& cond_cmd_buffers = - conditional_command_buffers_[update_state_.conditional_idx++]; + execution_scope.conditional_command_buffers[execution_scope.update_state + .conditional_idx++]; // Sanity check that we got the correct conditional command buffers. TF_RETURN_IF_ERROR(CheckNumCommandBuffers(cond_cmd_buffers, num_handles)); // Update a kernel that updates conditional handles values. - TF_RETURN_IF_ERROR(set_condition(cond_cmd_buffers.handles)); + TF_RETURN_IF_ERROR( + set_condition(execution_scope_id, cond_cmd_buffers.handles)); + + // Update a barrier between conditional handles and conditional nodes. + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); // Skip updating conditional nodes. - update_state_.node_idx += num_handles; + execution_scope.update_state.node_idx += num_handles; return UpdateConditionalCommandBuffers( cond_cmd_buffers.handles, @@ -477,64 +883,54 @@ tsl::Status GpuCommandBuffer::CreateConditionalCommand( return UnsupportedStateError(state_); } -tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder) { +absl::Status GpuCommandBuffer::If(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory predicate, + Builder then_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `If`. - SetIfConditionKernel set_if_condition(executor); + TF_ASSIGN_OR_RETURN(SetIfConditionKernel * set_if_condition, + GetSetIfConditionKernel(executor)); - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddInProcessSymbol(gpu::GetSetIfConditionKernel(), "set_if_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_condition)); - } - - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_if_condition, ThreadDim(), BlockDim(), handles[0], - predicate); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_if_condition, id, ThreadDim(), BlockDim(), + handles[0], predicate); }; std::array builders = { ToConditionBuilder(std::move(then_builder))}; - return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } -tsl::Status GpuCommandBuffer::IfElse(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder, - CommandBuffer::Builder else_builder) { +absl::Status GpuCommandBuffer::IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory predicate, + Builder then_builder, + Builder else_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `IfElse`. - SetIfElseConditionKernel set_if_else_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(gpu::GetSetIfElseConditionKernel(), - "set_if_else_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_if_else_condition)); - } + TF_ASSIGN_OR_RETURN(SetIfElseConditionKernel * set_if_else_condition, + GetSetIfElseConditionKernel(executor)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_if_else_condition, ThreadDim(), BlockDim(), handles[0], - handles[1], predicate); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_if_else_condition, id, ThreadDim(), + BlockDim(), handles[0], handles[1], predicate); }; std::array builders = { ToConditionBuilder(std::move(then_builder)), ToConditionBuilder(std::move(else_builder))}; - return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } -tsl::Status GpuCommandBuffer::Case( - StreamExecutor* executor, DeviceMemory index, - std::vector branches) { +absl::Status GpuCommandBuffer::Case(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory index, + std::vector branches) { DCHECK(executor->implementation() == parent_); // TODO(ezhulenev): Relax this constraint, we can launch multiple back to back @@ -544,18 +940,10 @@ tsl::Status GpuCommandBuffer::Case( "Case command supports only up to 8 branches, got: ", branches.size())); } - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `Case`. - SetCaseConditionKernel set_case_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/10); - spec.AddInProcessSymbol(gpu::GetSetCaseConditionKernel(), - "set_case_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_case_condition)); - } + TF_ASSIGN_OR_RETURN(SetCaseConditionKernel * set_case_condition, + GetSetCaseConditionKernel(executor)); - auto set_cond_fn = [&](absl::Span handles) { + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { int32_t num_handles = handles.size(); // Pad handles up to size 8 with a default initialized handle. @@ -563,10 +951,11 @@ tsl::Status GpuCommandBuffer::Case( handles.end()); padded_handles.resize(8); - return Launch(set_case_condition, ThreadDim(), BlockDim(), - padded_handles[0], padded_handles[1], padded_handles[2], - padded_handles[3], padded_handles[4], padded_handles[5], - padded_handles[6], padded_handles[7], index, num_handles); + return CommandBuffer::Launch( + *set_case_condition, id, ThreadDim(), BlockDim(), padded_handles[0], + padded_handles[1], padded_handles[2], padded_handles[3], + padded_handles[4], padded_handles[5], padded_handles[6], + padded_handles[7], index, num_handles); }; // Wrap all branches into conditional command buffer builders. @@ -576,100 +965,144 @@ tsl::Status GpuCommandBuffer::Case( builders.push_back(ToConditionBuilder(std::move(branch))); } - return CreateConditionalCommand(ConditionType::kIf, set_cond_fn, builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kIf, set_cond_fn, builders); } -tsl::Status GpuCommandBuffer::For(StreamExecutor* executor, - int32_t num_iteration, - DeviceMemory loop_counter, - CommandBuffer::Builder body_builder) { +absl::Status GpuCommandBuffer::For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + int32_t num_iteration, + DeviceMemory loop_counter, + Builder body_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `For`. - SetForConditionKernel set_for_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(gpu::GetSetForConditionKernel(), - "set_for_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_for_condition)); - } + TF_ASSIGN_OR_RETURN(SetForConditionKernel * set_for_condition, + GetSetForConditionKernel(executor)); // Reset loop counter to zero. - TF_RETURN_IF_ERROR(Memset(&loop_counter, uint32_t{0}, 1)); + TF_RETURN_IF_ERROR(Memset(execution_scope_id, &loop_counter, uint32_t{0}, 1)); + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_for_condition, ThreadDim(), BlockDim(), handles[0], - loop_counter, num_iteration); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_for_condition, id, ThreadDim(), + BlockDim(), handles[0], loop_counter, + num_iteration); }; auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { TF_RETURN_IF_ERROR(body_builder(body)); + TF_RETURN_IF_ERROR(body->Barrier(executor)); // Decide if we want to continue loop iteration. - return body->Launch(set_for_condition, ThreadDim(), BlockDim(), handle, + return body->Launch(*set_for_condition, ThreadDim(), BlockDim(), handle, loop_counter, num_iteration); }; std::array builders = {std::move(body)}; - return CreateConditionalCommand(ConditionType::kWhile, set_cond_fn, builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kWhile, set_cond_fn, builders); } -tsl::Status GpuCommandBuffer::While(StreamExecutor* executor, - DeviceMemory pred, - CommandBuffer::Builder cond_builder, - CommandBuffer::Builder body_builder) { +absl::Status GpuCommandBuffer::While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, + DeviceMemory pred, + ExecutionScopeBuilder cond_builder, + Builder body_builder) { DCHECK(executor->implementation() == parent_); - // TODO(ezhulenev): Keep kernel in `GpuCommandBuffer` to avoid loading it on - // every call to `While`. - SetWhileConditionKernel set_while_condition(executor); - - { // Load kernels that updates condition handle value. - MultiKernelLoaderSpec spec(/*arity=*/2); - spec.AddInProcessSymbol(gpu::GetSetWhileConditionKernel(), - "set_while_condition"); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_while_condition)); - } + TF_ASSIGN_OR_RETURN(SetWhileConditionKernel * set_while_condition, + GetSetWhileConditionKernel(executor)); // Record condition commands into the parent command buffer. - TF_RETURN_IF_ERROR(CommandBuffer::Build(this, cond_builder)); + TF_RETURN_IF_ERROR(cond_builder(execution_scope_id, this)); + TF_RETURN_IF_ERROR(Barrier(executor, execution_scope_id)); - auto set_cond_fn = [&](absl::Span handles) { - return Launch(set_while_condition, ThreadDim(), BlockDim(), handles[0], - pred); + auto set_cond_fn = [&](ExecutionScopeId id, ConditionalHandles handles) { + return CommandBuffer::Launch(*set_while_condition, id, ThreadDim(), + BlockDim(), handles[0], pred); }; auto body = [&](CommandBuffer* body, GpuGraphConditionalHandle handle) { TF_RETURN_IF_ERROR(body_builder(body)); - TF_RETURN_IF_ERROR(cond_builder(body)); - return body->Launch(set_while_condition, ThreadDim(), BlockDim(), handle, + TF_RETURN_IF_ERROR(body->Barrier(executor)); + TF_RETURN_IF_ERROR(cond_builder(kDefaulExecutionScope, body)); + TF_RETURN_IF_ERROR(body->Barrier(executor)); + return body->Launch(*set_while_condition, ThreadDim(), BlockDim(), handle, pred); }; std::array builders = {std::move(body)}; - return CreateConditionalCommand(ConditionType::kWhile, set_cond_fn, builders); + return CreateConditionalCommand(execution_scope_id, executor, + ConditionType::kWhile, set_cond_fn, builders); } -tsl::Status GpuCommandBuffer::Finalize() { +absl::Status GpuCommandBuffer::Finalize() { TF_RETURN_IF_ERROR(CheckNotFinalized()); + // Maybe dump created CUDA graph to a dot file for debugging. + if (state_ == State::kCreate && VLOG_IS_ON(10)) { + std::string path = tsl::io::GetTempFilename(/*extension=*/"dot"); + auto printed = GpuDriver::GraphDebugDotPrint( + graph_, path.c_str(), /*return_printed_graph=*/VLOG_IS_ON(100)); + if (VLOG_IS_ON(100) && printed.ok()) { + VLOG(100) << "Printed Gpu graph " << graph_ << " to: " << path << "\n" + << *printed; + } + } + + // Collect number of nodes and conditionals for logging below. + size_t num_nodes = 0, num_cond_cmd_buffers = 0; + for (auto& [_, execution_scope] : execution_scopes_) { + num_nodes += execution_scope.nodes.size(); + num_cond_cmd_buffers += execution_scope.conditional_command_buffers.size(); + } + if (mode_ == Mode::kPrimary && state_ == State::kCreate) { // If this is the first time we finalize command buffer after construction, // we need to instantiate it to an executable graph. GpuDriver::GraphInstantiateFlags flags; uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - TF_RETURN_IF_ERROR(GpuDriver::GraphInstantiate(&exec_, graph_, flags)); + + // If we get a "resource exhausted error" we retry instantiating Gpu graph + // one more time after releasing unused device memory allocated for graphs. + auto instantiated = GpuDriver::GraphInstantiate(&exec_, graph_, flags); + if (instantiated.code() == absl::StatusCode::kResourceExhausted) { + LOG(WARNING) << "Retry CUDA graph instantiation after OOM error" + << "; execution_scopes: " << execution_scopes_.size() + << "; nodes: " << num_nodes + << "; conditionals: " << num_cond_cmd_buffers + << "; alive executable graphs: " << AliveExecs(); + + TF_RETURN_IF_ERROR(GpuDriver::DeviceGraphMemTrim(parent_->device())); + + auto retry = GpuDriver::GraphInstantiate(&exec_, graph_, flags); + if (retry.code() == absl::StatusCode::kResourceExhausted) { + return absl::ResourceExhaustedError(absl::StrFormat( + "CUDA driver ran out of memory trying to instantiate CUDA graph " + "with %d nodes and %d conditionals (total of %d alive CUDA graphs " + "in the process). You can try to (a) Give more memory to CUDA " + "driver by reducing XLA_PYTHON_CLIENT_MEM_FRACTION (b) Disable " + "CUDA graph with 'XLA_FLAGS=--xla_gpu_enable_command_buffer=' " + "(empty set). Original error: %s", + num_nodes, num_cond_cmd_buffers, AliveExecs(), retry.message())); + } else { + TF_RETURN_IF_ERROR(retry); + } + } + uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - VLOG(5) << "Instantiated executable graph " << exec_ << " in " - << (end_nanos - start_nanos) / 1000 << " μs (" - << "#" << NotifyExecCreated() << ", " - << "alive executable graphs: " << AliveExecs() << ")"; + VLOG(5) << "Instantiated executable graph #" << NotifyExecCreated() << " " + << exec_ << " in " << (end_nanos - start_nanos) / 1000 << " μs" + << "; execution_scopes: " << execution_scopes_.size() + << "; nodes: " << num_nodes + << "; conditionals: " << num_cond_cmd_buffers + << "; alive executable graphs: " << AliveExecs(); + + TF_RETURN_IF_ERROR(DisableBarriersExecution(exec_)); } else if (mode_ == Mode::kPrimary && state_ == State::kUpdate) { // If this is a finalization after update, we don't have to do anything as @@ -685,26 +1118,42 @@ tsl::Status GpuCommandBuffer::Finalize() { } state_ = State::kFinalized; - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status GpuCommandBuffer::Update() { - if (state_ != State::kFinalized) { +absl::Status GpuCommandBuffer::Update() { + if (exec_ == nullptr) { return absl::InternalError( - "Command buffer has to be finalized first before it can be updated"); + "Command buffer has to have a graph executable to be updated"); } - if (exec_ == nullptr) { + if (state_ != State::kFinalized) { return absl::InternalError( - "Command buffer has to have a graph executable to be updated"); + "Command buffer has to be finalized first before it can be updated"); } - VLOG(5) << "Begin primary command buffer update for executable graph " - << exec_; + VLOG(5) << "Begin " << (mode_ == Mode::kPrimary ? "primary" : "nested") + << " command buffer update for executable graph " << exec_; state_ = State::kUpdate; - update_state_ = UpdateState(); - return tsl::OkStatus(); + for (auto& [_, execution_scope] : execution_scopes_) { + execution_scope.update_state = ExecutionScope::UpdateState(); + } + return absl::OkStatus(); +} + +absl::Span GpuCommandBuffer::nodes( + ExecutionScopeId id) const { + if (auto it = execution_scopes_.find(id); it != execution_scopes_.end()) + return it->second.nodes; + return {}; +} + +absl::Span +GpuCommandBuffer::barriers(ExecutionScopeId id) const { + if (auto it = execution_scopes_.find(id); it != execution_scopes_.end()) + return it->second.barriers; + return {}; } } // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_command_buffer.h b/xla/stream_executor/gpu/gpu_command_buffer.h index 5223e9d24e7cb..0172c62ba0c8d 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/xla/stream_executor/gpu/gpu_command_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,11 +19,17 @@ limitations under the License. #include #include #include +#include +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -33,96 +39,143 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor::gpu { -// GpuCommandBuffer provides platform-specific CommandBufferInterface -// implementation (it's backed by CUDA or HIP graphs on NVIDIA and AMD devices). -class GpuCommandBuffer : public internal::CommandBufferInterface { +// GpuCommandBuffer provides platform-specific CommandBuffer implementation +// (it's backed by CUDA or HIP graphs on NVIDIA and AMD devices). +class GpuCommandBuffer : public CommandBuffer { public: - GpuCommandBuffer(CommandBuffer::Mode mode, GpuExecutor* parent, - GpuGraphHandle graph, bool is_owned_graph = true); + // A handle to a Gpu graph node and a metadata describing its properties. Each + // command (launch, memcpy, etc.) creates one or more graph nodes. + struct GpuGraphNodeInfo { + // A handle to the gpu graph node corresponding to a command. + GpuGraphNodeHandle handle = nullptr; + }; + + // A handle to Gpu graph barrier and metadata describing its properties. Each + // call to `Barrier` creates a new barrier record. + struct GpuGraphBarrierInfo { + // A handle to graph node acting as a barrier that defines execution order. + // It can be a handle to a `GpuGraphNodeInfo` node or a handle to an empty + // node created to be a barrier. We try to reuse existing nodes as barriers + // if possible to reduce the size of constructed gpu graphs. + GpuGraphNodeHandle handle = nullptr; + + // If `true` it means `handle` corresponds to an empty node specifically + // created to act as an execution barrier, otherwise `handle` points to one + // of the nodes created for recorded commands. + bool is_barrier_node = true; + + // Nodes with index smaller than `nodes_offset` are synchronized with this + // barrier. We use this offset to find nodes added after the last barrier + // that should be added as dependencies to the next barrier. + size_t nodes_offset = 0; + }; + + GpuCommandBuffer(Mode mode, GpuExecutor* parent, GpuGraphHandle graph, + bool is_owned_graph = true); ~GpuCommandBuffer() override; - tsl::Status Trace(Stream* stream, - absl::AnyInvocable function) override; + absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId execution_scope_id) override; + + absl::Status Barrier( + StreamExecutor* executor, + absl::Span execution_scope_ids) override; - tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgs& args) override; + absl::Status Barrier(StreamExecutor* executor, + ExecutionScopeId from_execution_scope_id, + ExecutionScopeId to_execution_scope_id) override; - tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested) override; + absl::Status Launch(ExecutionScopeId execution_scope_id, + const ThreadDim& threads, const BlockDim& blocks, + const Kernel& kernel, const KernelArgs& args) override; - tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) override; + absl::Status AddNestedCommandBuffer(ExecutionScopeId execution_scope_id, + const CommandBuffer& nested) override; - tsl::Status Memset(DeviceMemoryBase* dst, - CommandBuffer::BitPattern bit_pattern, - size_t num_elements) override; + absl::Status MemcpyDeviceToDevice(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, + const DeviceMemoryBase& src, + uint64_t size) override; - tsl::StatusOr Allocate(size_t bytes) override; + absl::Status Memset(ExecutionScopeId execution_scope_id, + DeviceMemoryBase* dst, BitPattern bit_pattern, + size_t num_elements) override; - tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, - CommandBuffer::Builder then_builder) override; + absl::StatusOr Allocate(ExecutionScopeId execution_scope_id, + size_t bytes) override; - tsl::Status IfElse(StreamExecutor* executor, DeviceMemory predicate, - CommandBuffer::Builder then_builder, - CommandBuffer::Builder else_builder) override; + absl::Status Free(ExecutionScopeId execution_scope_id, + DeviceMemoryBase dst) override; - tsl::Status Case(StreamExecutor* executor, DeviceMemory index, - std::vector branches) override; + absl::Status If(ExecutionScopeId execution_scope_id, StreamExecutor* executor, + DeviceMemory predicate, Builder then_builder) override; - tsl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_counter, - CommandBuffer::Builder body_builder) override; + absl::Status IfElse(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory predicate, + Builder then_builder, Builder else_builder) override; - tsl::Status While(StreamExecutor* executor, DeviceMemory pred, - CommandBuffer::Builder cond_builder, - CommandBuffer::Builder body_builder) override; + absl::Status Case(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory index, + std::vector branches) override; - tsl::Status Finalize() override; - tsl::Status Update() override; + absl::Status For(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, int32_t num_iteration, + DeviceMemory loop_counter, + Builder body_builder) override; + + absl::Status While(ExecutionScopeId execution_scope_id, + StreamExecutor* executor, DeviceMemory pred, + ExecutionScopeBuilder cond_builder, + Builder body_builder) override; + + absl::Status Finalize() override; + absl::Status Update() override; GpuGraphExecHandle executable() const { return exec_; } GpuGraphHandle graph() const { return graph_; } - CommandBuffer::Mode mode() const override { return mode_; } - CommandBuffer::State state() const override { return state_; } + Mode mode() const override { return mode_; } + State state() const override { return state_; } + + static GpuCommandBuffer* Cast(CommandBuffer* command_buffer) { + return static_cast(command_buffer); + } + + static const GpuCommandBuffer* Cast(const CommandBuffer* command_buffer) { + return static_cast(command_buffer); + } + + absl::Span nodes(ExecutionScopeId id) const; + absl::Span barriers(ExecutionScopeId id) const; + + absl::Span nodes() const { + return nodes(kDefaulExecutionScope); + } - // A helper template for launching typed kernels. - template - tsl::Status Launch(const TypedKernel& kernel, - const ThreadDim& threads, const BlockDim& blocks, - Args... args); + absl::Span barriers() const { + return barriers(kDefaulExecutionScope); + } + + private: + absl::Status Trace(Stream* stream, + absl::AnyInvocable function) override; // We track the total number of allocated and alive executable graphs in the // process to track the command buffers resource usage. Executable graph // allocates resources on a GPU devices (rule of thumb is ~8kb per node), so // we have to be careful not to keep too many of them alive for too long, or // we have a higher risk of OOM errors. - // - // TODO(ezhulenev): We need to have a policy for how to evict unused - // executable graph instances from a device, currently lifetime of an - // executable graph is tied to a parent command buffer, and we can have - // thousands of command buffers alive at the same time. static int64_t AllocatedExecs(); static int64_t AliveExecs(); - static GpuCommandBuffer* Cast(CommandBuffer* command_buffer) { - return static_cast(command_buffer->implementation()); - } - - static const GpuCommandBuffer* Cast(const CommandBuffer* command_buffer) { - return static_cast( - command_buffer->implementation()); - } - private: using Dependencies = absl::InlinedVector; + using NoOpKernel = TypedKernel<>; + // A signature of a device kernels updating conditional handle(s). using SetIfConditionKernel = TypedKernel>; @@ -145,16 +198,16 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { TypedKernel>; // A callback to launch a kernel that updates conditional handles state. - using SetConditionFn = - std::function)>; + using SetConditionFn = std::function)>; - // An extension of `CommandBuffer::Builder` for building conditional command - // buffers tied to conditional handles. + // An extension of `Builder` for building conditional command buffers tied to + // conditional handles. using ConditionBuilder = - std::function; + std::function; // Wraps a regular command buffer builder into condition builder. - static ConditionBuilder ToConditionBuilder(CommandBuffer::Builder builder); + static ConditionBuilder ToConditionBuilder(Builder builder); using ConditionType = typename GpuDriver::GpuGraphConditionalNodeParams::Type; @@ -175,55 +228,79 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // For each conditional node in the Gpu graph we keep a record of conditional // command buffers attached to a node, so we can apply updates to them. struct ConditionalCommandBuffers { - ConditionalCommandBuffers(std::vector handles, - std::vector command_buffers) - : handles(std::move(handles)), - command_buffers(std::move(command_buffers)) {} - std::vector handles; - std::vector command_buffers; + std::vector> command_buffers; }; using AllocationResult = std::pair; - tsl::StatusOr> + absl::StatusOr> CreateConditionalHandles(size_t num_handles); - tsl::StatusOr> CreateConditionalNodes( - ConditionType type, absl::Span handles); - - tsl::StatusOr> CreateConditionalCommandBuffers( + absl::StatusOr>> + CreateConditionalCommandBuffers( absl::Span handles, absl::Span graphs, absl::Span builders); - tsl::Status UpdateConditionalCommandBuffers( + absl::Status UpdateConditionalCommandBuffers( absl::Span handles, - absl::Span command_buffers, + absl::Span> command_buffers, absl::Span builders); - tsl::Status CreateConditionalCommand( + absl::StatusOr> CreateConditionalNodes( + ExecutionScopeId execution_scope_id, ConditionType type, + absl::Span handles); + + absl::Status CreateConditionalCommand( + ExecutionScopeId execution_scope_id, StreamExecutor* executor, ConditionType type, SetConditionFn set_condition, absl::Span builders); - // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a - // dependency between all nodes added to a command buffer. We need a - // concept of a barrier at a command buffer level. - Dependencies GetDependencies(); + Dependencies GetBarrier(ExecutionScopeId execution_scope_id); + + // Returns loaded auxiliary kernels, or loads them on a given stream executor. + // Loaded kernels owned by a current command buffer. + absl::StatusOr GetSetIfConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetIfElseConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetCaseConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetForConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetSetWhileConditionKernel( + StreamExecutor* executor); + absl::StatusOr GetNoOpKernel(StreamExecutor* executor); + + // Recursively disable all nodes corresponding to barriers (including nested + // conditional command buffers). This is work around the fact that we can't + // use empty nodes inside conditional CUDA graphs and instead we add no-op + // kernel nodes, however large number of no-op kernels impacts performance. + absl::Status DisableBarriersExecution(GpuGraphExecHandle exec); + + // Launches CUDA kernels with packed arguments. + absl::Status LaunchWithPackedArgs( + ExecutionScopeId execution_scope_id, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, + const KernelArgsPackedArrayBase& packed_args); // Returns OK status if command buffer is not finalized and it is still // possible to add new commands to it, otherwise returns internal error. - tsl::Status CheckNotFinalized(); - - // Returns OK status if command buffer is primary, otherwise returns internal - // error. - tsl::Status CheckPrimary(); + absl::Status CheckNotFinalized(); // Returns OK status if the number of command buffers is equal to the expected // one, otherwise returns internal error. - tsl::Status CheckNumCommandBuffers( + absl::Status CheckNumCommandBuffers( const ConditionalCommandBuffers& cmd_buffers, size_t num_cmd_buffers); + // Creates a new no-op node acting as a barrier. + absl::StatusOr CreateBarrierNode( + StreamExecutor* executor, const Dependencies& dependencies); + + // Collects a set of dependencies for a new barrier. + Dependencies GetBarrierDependencies(ExecutionScopeId execution_scope_id); + static_assert(std::is_pointer_v, "GpuGraphHandle must be a pointer"); static_assert(std::is_pointer_v, @@ -231,8 +308,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { static_assert(std::is_pointer_v, "GpuGraphNodeHandle must be a pointer"); - CommandBuffer::Mode mode_; - CommandBuffer::State state_ = CommandBuffer::State::kCreate; + Mode mode_; + State state_ = State::kCreate; GpuExecutor* parent_; // not owned, must outlive *this @@ -242,53 +319,72 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { GpuGraphExecHandle exec_ = nullptr; // owned if `is_owned_graph_exec_` bool is_owned_graph_exec_ = true; // ownership of `is_owned_graph_exec_` - // Handles to graph nodes corresponding to command buffer commands. Owned by - // the `graph_` instance. - std::vector nodes_; + // ExecutionScope holds the state of an underlying CUDA graph (nodes an + // barriers added to a graph) for a single execution scope. + struct ExecutionScope { + // Tracks indices into data structures during command buffer updates. + struct UpdateState { + // Index points to the graph node inside `nodes` that will be updated + // next. + int64_t node_idx = 0; + + // Index points to the barrier node inside `barriers` that will be updated + // on a next call to `Barrier(...)`. + int64_t barrier_idx = 0; + + // Index points to the conditional command buffers that will be updated + // next when we'll be updating next conditional command (If, Case, While). + int64_t conditional_idx = 0; + }; + + // Gpu graph nodes corresponding to recorded commands (launch, memcpy, + // etc.). + std::vector nodes; + + // Gpu graph barriers that define recorded commands execution order. + std::vector barriers; + + // Command buffers for conditional nodes in the Gpu graph. Underlying Gpu + // graphs owned by the `graph_` instance. + std::vector conditional_command_buffers; + + // Tracks execution scope update state. + UpdateState update_state; + }; - // Command buffers for conditional nodes in the Gpu graph. Underlying Gpu - // graphs owned by the `graph_` instance. - std::vector conditional_command_buffers_; + // Execution scopes recorded into the command buffer. + absl::flat_hash_map execution_scopes_; // Track the number of command buffer updates for debugging. int64_t num_updates_ = 0; - // Tracks indices into internal data structures during command buffer updates. - struct UpdateState { - // Index points to the graph node inside `nodes_` that will be updated next. - int64_t node_idx = 0; - - // Index points to the conditional command buffers that will be updated next - // when we'll be updating next conditional command (If, Case, While). - int64_t conditional_idx = 0; - }; - - UpdateState update_state_; + // Lazy loaded auxiliary kernels required for building CUDA graphs (no-op + // barriers, updating conditional handles, etc.). + SetIfConditionKernel set_if_condition_kernel_; + SetIfElseConditionKernel set_if_else_condition_kernel_; + SetCaseConditionKernel set_case_condition_kernel_; + SetForConditionKernel set_for_condition_kernel_; + SetWhileConditionKernel set_while_condition_kernel_; + NoOpKernel noop_kernel_; }; -template -inline tsl::Status GpuCommandBuffer::Launch( - const TypedKernel& kernel, const ThreadDim& threads, - const BlockDim& blocks, Args... args) { - auto kernel_args = PackKernelArgs(kernel, args...); - TF_RETURN_IF_ERROR(Launch(threads, blocks, kernel, *kernel_args)); - return tsl::OkStatus(); -} - //===----------------------------------------------------------------------===// // Implementation details device kernels required by GpuCommandBuffer. //===----------------------------------------------------------------------===// -// See `cuda_conditional_kernels.cu.cc` for CUDA implementations. These are +// A no-op kernel required for creating barriers inside command buffers because +// empty nodes are not supported within conditional CUDA graphs (in CUDA 12.3). +void* GetNoOpKernel(); + +// See `cuda_conditional_kernels.cc` for CUDA implementation. These are // various kernels that update Gpu conditionals based on the device memory // values, and allow implementing on-device control flow via conditional command // buffers. - -void* GetSetIfConditionKernel(); -void* GetSetIfElseConditionKernel(); -void* GetSetCaseConditionKernel(); -void* GetSetForConditionKernel(); -void* GetSetWhileConditionKernel(); +std::string_view GetSetIfConditionKernel(); +std::string_view GetSetIfElseConditionKernel(); +std::string_view GetSetCaseConditionKernel(); +std::string_view GetSetForConditionKernel(); +std::string_view GetSetWhileConditionKernel(); } // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/xla/stream_executor/gpu/gpu_command_buffer_test.cc new file mode 100644 index 0000000000000..7ff757aab9aa5 --- /dev/null +++ b/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -0,0 +1,1358 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_command_buffer.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + +namespace stream_executor::gpu { + +using ExecutionScopeId = CommandBuffer::ExecutionScopeId; + +static Platform* GpuPlatform() { + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); + return PlatformManager::PlatformWithName(name).value(); +} + +static MultiKernelLoaderSpec GetAddI32KernelSpec() { + MultiKernelLoaderSpec spec(/*arity=*/3); +#if defined(GOOGLE_CUDA) + spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); +#elif defined(TENSORFLOW_USE_ROCM) + spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); +#endif + return spec; +} + +using AddI32Kernel = TypedKernel, DeviceMemory, + DeviceMemory>; +using MulI32Kernel = TypedKernel, DeviceMemory, + DeviceMemory>; +using IncAndCmpKernel = + TypedKernel, DeviceMemory, int32_t>; + +using AddI32Ptrs3 = TypedKernel>; + +static constexpr auto nested = CommandBuffer::Mode::kNested; // NOLINT +static constexpr auto primary = CommandBuffer::Mode::kPrimary; // NOLINT + +template +static std::vector Deps(Info info) { + if (auto deps = GpuDriver::GraphNodeGetDependencies(info.handle); deps.ok()) { + return *deps; + } + return {GpuGraphNodeHandle(0xDEADBEEF)}; +} + +template +static std::vector ExpectedDeps(Infos... info) { + return {info.handle...}; +} + +// Some of the tests rely on CUDA 12.3+ features. +static bool IsAtLeastCuda12300() { +#if defined(TENSORFLOW_USE_ROCM) + return false; +#endif +#if CUDA_VERSION >= 12030 + return true; +#endif + return false; +} + +TEST(GpuCommandBufferTest, LaunchSingleKernel) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Create a command buffer with a single kernel launch. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `c` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&d, byte_length)); + + // Update command buffer to write into `d` buffer. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length)); + ASSERT_EQ(dst, expected); +} + +TEST(CudaCommandBufferTest, TraceSingleKernel) { +#if defined(TENSORFLOW_USE_ROCM) + GTEST_SKIP() << "Not supported on ROCM"; +#endif +#if CUDA_VERSION < 12030 + GTEST_SKIP() << "Command buffer tracing is not supported"; +#endif + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Register a kernel with a custom arguments packing function that packs + // device memory arguments into a struct with pointers. + MultiKernelLoaderSpec spec(/*arity=*/1, [&](const Kernel& kernel, + const KernelArgs& args) { + auto bufs = Cast(&args)->device_memory_args(); + auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; + return PackKernelArgs(/*shmem_bytes=*/0, internal::Ptrs3{ + cast(bufs[0]), + cast(bufs[1]), + cast(bufs[2]), + }); + }); + spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "add"); + + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Ptrs3::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Use an array of device memory base pointers as argument to test packing. + KernelArgsDeviceMemoryArray args({a, b, c}, 0); + + // Create a command buffer by tracing kernel launch operations. + auto cmd_buffer = CommandBuffer::Trace( + executor, + [&](Stream* stream) { + return executor->Launch(stream, ThreadDim(), BlockDim(4), *add, args); + }, + primary); + + TF_ASSERT_OK(cmd_buffer.status()); + TF_ASSERT_OK(executor->Submit(stream.get(), **cmd_buffer)); + + // Copy data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec = GetAddI32KernelSpec(); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Create a command buffer with a single kernel launch. + auto primary_cmd = CommandBuffer::Create(executor).value(); + auto nested_cmd = CommandBuffer::Create(executor, nested).value(); + TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); + TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); + TF_ASSERT_OK(primary_cmd->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + + // Copy `c` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&d, byte_length)); + + // Update command buffer to write into `d` buffer by creating a new nested + // command buffer. + nested_cmd = CommandBuffer::Create(executor, nested).value(); + TF_ASSERT_OK(nested_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); + TF_ASSERT_OK(primary_cmd->Update()); + TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); + TF_ASSERT_OK(primary_cmd->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length)); + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=uninitialized + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + // Create a command buffer with a single a to b memcpy command. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&b, a, byte_length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + std::vector expected = {42, 42, 42, 42}; + ASSERT_EQ(dst, expected); + + // Update command buffer to swap the memcpy direction. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&a, b, byte_length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + // Clear destination to test that command buffer actually copied memory. + TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length)); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `a` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, Memset) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + DeviceMemory a = executor->AllocateArray(length, 0); + + // Create a command buffer with a single memset command. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{42}, length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `a` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + std::vector expected = {42, 42, 42, 42}; + ASSERT_EQ(dst, expected); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{43}, length)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + expected = {43, 43, 43, 43}; + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, Barriers) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 6; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream->Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + // Check that root barrier ignored. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[0], bit_pattern + 0, 1)); + // Check barrier after a single command. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[1], bit_pattern + 1, 1)); + // Check that repeated barriers are no-op. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[3], bit_pattern + 3, 1)); + // Check that barrier can have multiple dependencies. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(&buffers[5], bit_pattern + 5, 1)); + // Check that barrier can be that last command. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + ASSERT_EQ(gpu_cmd_buffer->nodes().size(), 6); + ASSERT_EQ(gpu_cmd_buffer->barriers().size(), 6); + + auto nodes = gpu_cmd_buffer->nodes(); + auto barriers = gpu_cmd_buffer->barriers(); + + // First barrier does not have any dependencies. + EXPECT_TRUE(barriers[0].is_barrier_node); + EXPECT_TRUE(Deps(barriers[0]).empty()); + + // Second barrier reuses first memset node. + EXPECT_FALSE(barriers[1].is_barrier_node); + EXPECT_EQ(barriers[1].handle, nodes[0].handle); + + // Third and fourth barriers reuse second memset node. + EXPECT_FALSE(barriers[2].is_barrier_node); + EXPECT_FALSE(barriers[3].is_barrier_node); + EXPECT_EQ(barriers[2].handle, nodes[1].handle); + EXPECT_EQ(barriers[3].handle, nodes[1].handle); + + // Fifth and sixth barriers are barrier nodes. + EXPECT_TRUE(barriers[4].is_barrier_node); + EXPECT_TRUE(barriers[5].is_barrier_node); + + EXPECT_EQ(Deps(barriers[4]), ExpectedDeps(nodes[2], nodes[3])); + EXPECT_EQ(Deps(barriers[5]), ExpectedDeps(nodes[4], nodes[5])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + expected = {43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, IndependentExecutionScopes) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 4; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream->Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0)); + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 2); + ASSERT_EQ(nodes1.size(), 2); + ASSERT_EQ(barriers0.size(), 1); + ASSERT_EQ(barriers1.size(), 1); + + EXPECT_TRUE(barriers0[0].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + expected = {43, 44, 45, 46}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + CommandBuffer::ExecutionScopeId s2 = CommandBuffer::ExecutionScopeId(2); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 7; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream->Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + // This will synchronize scopes 0 and 1 and also create an empty scope 2. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1, s2})); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[5], bit_pattern + 5, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s2, &buffers[6], bit_pattern + 6, 1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto nodes2 = gpu_cmd_buffer->nodes(s2); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + auto barriers2 = gpu_cmd_buffer->barriers(s2); + + ASSERT_EQ(nodes0.size(), 3); + ASSERT_EQ(nodes1.size(), 3); + ASSERT_EQ(nodes2.size(), 1); + ASSERT_EQ(barriers0.size(), 2); + ASSERT_EQ(barriers1.size(), 2); + ASSERT_EQ(barriers2.size(), 2); + + // All barriers are real barrier nodes. + EXPECT_TRUE(barriers0[0].is_barrier_node && barriers0[1].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node && barriers1[1].is_barrier_node); + EXPECT_TRUE(barriers2[0].is_barrier_node && barriers2[1].is_barrier_node); + + // All scopes share a broadcasted barrier. + EXPECT_TRUE(barriers0[1].handle == barriers1[1].handle); + EXPECT_TRUE(barriers1[1].handle == barriers2[1].handle); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + + EXPECT_TRUE(Deps(barriers2[0]).empty()); + EXPECT_EQ(Deps(barriers2[1]), + ExpectedDeps(barriers0[0], barriers1[0], barriers2[0])); + + EXPECT_EQ(Deps(nodes0[2]), ExpectedDeps(barriers0[1])); + EXPECT_EQ(Deps(nodes1[2]), ExpectedDeps(barriers1[1])); + EXPECT_EQ(Deps(nodes2[0]), ExpectedDeps(barriers2[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + expected = {43, 44, 45, 46, 47, 48, 49}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 6; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers data back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + TF_CHECK_OK(stream->Memcpy(dst.data() + i, buffers[i], sizeof(int32_t))); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[2], bit_pattern + 2, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[3], bit_pattern + 3, 1)); + // This will synchronize scopes 0 and 1. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0, s1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[4], bit_pattern + 4, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s1, &buffers[5], bit_pattern + 5, 1)); + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + std::vector expected = {42, 43, 44, 45, 46, 47}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 3); + ASSERT_EQ(nodes1.size(), 3); + ASSERT_EQ(barriers0.size(), 1); + ASSERT_EQ(barriers1.size(), 2); + + // All barriers are real barrier nodes. + EXPECT_TRUE(barriers0[0].is_barrier_node); + EXPECT_TRUE(barriers1[0].is_barrier_node && barriers1[1].is_barrier_node); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(Deps(barriers1[0]), ExpectedDeps(nodes1[0], nodes1[1])); + EXPECT_EQ(Deps(barriers1[1]), ExpectedDeps(barriers0[0], barriers1[0])); + EXPECT_EQ(Deps(nodes0[2]), ExpectedDeps(barriers0[0])); + EXPECT_EQ(Deps(nodes1[2]), ExpectedDeps(barriers1[1])); + + // Update command buffer to use a new bit pattern. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + expected = {43, 44, 45, 46, 47, 48}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ConditionalIf) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // if (pred == true) c = a + b + CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->If(executor, pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `c` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + + // Reset predicate to false and clear output buffer. + constexpr bool kFalse = false; + TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Submit the same command buffer, but this time it should not execute + // conditional branch as conditional handle should be updated to false. + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + std::vector zeroes = {0, 0, 0, 0}; + ASSERT_EQ(dst, zeroes); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&d, byte_length)); + + // Set predicate buffer to true to run conditional command buffer. + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + + // if (pred == true) d = a + b (write to a new location). + then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, d); + }; + + // Update command buffer with a conditional to use new builder. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->If(executor, pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length)); + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, ConditionalIfElse) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); + + // Load multiplication kernel. + MultiKernelLoaderSpec mul_spec(/*arity=*/3); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=2, b=3, c=0, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream->Memset32(&a, 2, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 3, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // if (pred == true) c = a + b + CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // if (pred == false) c = a * b + CommandBuffer::Builder else_builder = [&](CommandBuffer* else_cmd) { + return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected_add = {5, 5, 5, 5}; + ASSERT_EQ(dst, expected_add); + + // Reset predicate to false. + constexpr bool kFalse = false; + TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); + + // Submit the same command buffer, but this time it should execute `else` + // branch and multiply inputs. + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + std::vector expected_mul = {6, 6, 6, 6}; + ASSERT_EQ(dst, expected_mul); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&d, byte_length)); + + // if (pred == false) d = a * b (write to a new location). + else_builder = [&](CommandBuffer* else_cmd) { + return else_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, d); + }; + + // Update command buffer with a conditional to use new `else` builder. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->IfElse(executor, pred, then_builder, else_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), d, byte_length)); + ASSERT_EQ(dst, expected_mul); +} + +TEST(GpuCommandBufferTest, ConditionalCase) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); + + // Load multiplication kernel. + MultiKernelLoaderSpec mul_spec(/*arity=*/3); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=2, b=3, c=0, index=0 + DeviceMemory index = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&index, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 2, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 3, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // if (index == 0) c = a + b + CommandBuffer::Builder branch0 = [&](CommandBuffer* branch0_cmd) { + return branch0_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // if (index == 1) c = a * b + CommandBuffer::Builder branch1 = [&](CommandBuffer* branch1_cmd) { + return branch1_cmd->Launch(mul, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->Case(executor, index, {branch0, branch1})); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `c` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected_add = {5, 5, 5, 5}; + ASSERT_EQ(dst, expected_add); + + // Set index to `1` + TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t))); + + // Submit the same command buffer, but this time it should multiply inputs. + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + std::vector expected_mul = {6, 6, 6, 6}; + ASSERT_EQ(dst, expected_mul); + + // Set index to `-1` (out of bound index value). + TF_ASSERT_OK(stream->Memset32(&index, -1, sizeof(int32_t))); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + ASSERT_EQ(dst, expected_mul); + + // Set index to `2` (out of bound index value). + TF_ASSERT_OK(stream->Memset32(&index, 2, sizeof(int32_t))); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + ASSERT_EQ(dst, expected_mul); +} + +TEST(GpuCommandBufferTest, ConditionalFor) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=0, loop_counter=100 + DeviceMemory loop_counter = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + // Set loop counter to 100 to check that command buffer resets it. + TF_ASSERT_OK(stream->Memset32(&loop_counter, 100, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + // Loop body: b = a + b + CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, b); + }; + + int32_t num_iters = 10; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK( + cmd_buffer->For(executor, num_iters, loop_counter, body_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + std::vector expected = {10, 10, 10, 10}; + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, ConditionalWhile) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); + + // Load inc_and_cmp kernel. + MultiKernelLoaderSpec icmp_spec(/*arity=*/3); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, + IncAndCmpKernel::Create(executor, icmp_spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=0, loop_counter=0, pred=false + // Value of `pred` is not important, as it will be updated by `cond_builder` + // below. + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory loop_counter = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + + static constexpr bool kFalse = false; + TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream->Memset32(&loop_counter, 0, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->MemZero(&b, byte_length)); + + int32_t num_iters = 10; + + // Loop cond: loop_counter++ < num_iters; + CommandBuffer::ExecutionScopeBuilder cond_builder = + [&](ExecutionScopeId id, CommandBuffer* cond_cmd) { + return cond_cmd->Launch(inc_and_cmp, id, ThreadDim(), BlockDim(), + loop_counter, pred, num_iters); + }; + + // Loop body: b = a + b + CommandBuffer::Builder body_builder = [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(length), a, b, b); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer->While(executor, pred, cond_builder, body_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `b` data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + + std::vector expected = {10, 10, 10, 10}; + ASSERT_EQ(dst, expected); +} + +TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + DeviceMemory pred = executor->AllocateArray(1, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + + // Allocate device buffers for memset operations. + std::vector> buffers; + for (size_t i = 0; i < 3; ++i) { + buffers.push_back(executor->AllocateArray(1, 0)); + } + + // Transfer buffers back to host. + auto transfer_buffers = [&]() -> std::vector { + std::vector dst(buffers.size(), 0); + for (size_t i = 0; i < buffers.size(); ++i) { + stream->Memcpy(dst.data() + i, buffers[i], sizeof(int32_t)).IgnoreError(); + } + return dst; + }; + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern) { + // Record memsets in execution scope #0 + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[0], bit_pattern + 0, 1)); + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &buffers[1], bit_pattern + 1, 1)); + + // Record If in execution scope #1 + TF_RETURN_IF_ERROR( + cmd_buffer->If(s1, executor, pred, [&](CommandBuffer* then_cmd) { + return then_cmd->Memset(&buffers[2], bit_pattern + 2, 1); + })); + + // Create a barrier in execution scope #0. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, s0)); + + // Create a barrier between two execution scopes. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1})); + + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a DAG of memset commands. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + std::vector expected = {42, 43, 44}; + ASSERT_EQ(transfer_buffers(), expected); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + ASSERT_EQ(nodes0.size(), 2); + ASSERT_EQ(nodes1.size(), 2); + ASSERT_EQ(barriers0.size(), 3); + ASSERT_EQ(barriers1.size(), 3); + + EXPECT_EQ(Deps(barriers0[0]), ExpectedDeps(nodes0[0], nodes0[1])); + EXPECT_EQ(barriers0[0].handle, barriers0[1].handle); + + EXPECT_EQ(barriers1[0].handle, nodes1[0].handle); + EXPECT_EQ(barriers1[1].handle, nodes1[1].handle); + + // s0 and s1 share broadcasted barrier. + EXPECT_TRUE(barriers0[2].handle == barriers1[2].handle); + EXPECT_EQ(Deps(barriers0[2]), ExpectedDeps(barriers0[1], nodes1[1])); + + // TODO(b/326284532): Add a test for bit pattern update. + + // Disable conditional branch. + constexpr bool kFalse = false; + TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); + TF_ASSERT_OK(stream->MemZero(&buffers[2], sizeof(int32_t))); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + expected = {42, 43, 0}; + ASSERT_EQ(transfer_buffers(), expected); +} + +TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph conditionals are not supported"; + } + + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + CommandBuffer::ExecutionScopeId s0 = CommandBuffer::ExecutionScopeId(0); + CommandBuffer::ExecutionScopeId s1 = CommandBuffer::ExecutionScopeId(1); + + // Load addition kernel. + MultiKernelLoaderSpec add_spec(/*arity=*/3); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); + + // Load inc_and_cmp kernel. + MultiKernelLoaderSpec icmp_spec(/*arity=*/3); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, + IncAndCmpKernel::Create(executor, icmp_spec)); + + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory loop_counter = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(1, 0); + DeviceMemory b = executor->AllocateArray(1, 0); + DeviceMemory c = executor->AllocateArray(1, 0); + + TF_ASSERT_OK(stream->MemZero(&loop_counter, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memset32(&a, 1, sizeof(int32_t))); + TF_ASSERT_OK(stream->MemZero(&b, sizeof(int32_t))); + + auto record = [&](CommandBuffer* cmd_buffer, uint32_t bit_pattern, + int32_t num_iters) { + // Record memset in execution scope #0 + TF_RETURN_IF_ERROR(cmd_buffer->Memset(s0, &c, bit_pattern, 1)); + + // Record While in execution scope #1 + TF_RETURN_IF_ERROR(cmd_buffer->While( + s1, executor, pred, + // Loop cond: loop_counter++ < num_iters; + [&](ExecutionScopeId id, CommandBuffer* cond_cmd) { + return cond_cmd->Launch(inc_and_cmp, id, ThreadDim(), BlockDim(), + loop_counter, pred, num_iters); + }, + // Loop body: b = a + b + [&](CommandBuffer* body_cmd) { + return body_cmd->Launch(add, ThreadDim(), BlockDim(), a, b, b); + })); + + // Create a barrier between two execution scopes. + TF_RETURN_IF_ERROR(cmd_buffer->Barrier(executor, {s0, s1})); + + return cmd_buffer->Finalize(); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(record(cmd_buffer.get(), 42, 10)); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + // Copy `b` and `c` data back to host. + int32_t b_dst, c_dst; + TF_ASSERT_OK(stream->Memcpy(&b_dst, b, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memcpy(&c_dst, c, sizeof(int32_t))); + + EXPECT_EQ(b_dst, 10); + EXPECT_EQ(c_dst, 42); + + // Check the command buffer structure. + GpuCommandBuffer* gpu_cmd_buffer = GpuCommandBuffer::Cast(cmd_buffer.get()); + + auto nodes0 = gpu_cmd_buffer->nodes(s0); + auto nodes1 = gpu_cmd_buffer->nodes(s1); + auto barriers0 = gpu_cmd_buffer->barriers(s0); + auto barriers1 = gpu_cmd_buffer->barriers(s1); + + // s0 should have only one real barrier joining while op and memset. + ASSERT_EQ(nodes0.size(), 1); + ASSERT_EQ(nodes1.size(), 3); + ASSERT_EQ(barriers0.size(), 2); + ASSERT_EQ(barriers1.size(), 4); + + // The final barrier that joins while and memset. + EXPECT_EQ(Deps(barriers0[1]), ExpectedDeps(nodes0[0], nodes1[2])); + + // Update bit pattern and number of iterations. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(record(cmd_buffer.get(), 43, 20)); + + TF_ASSERT_OK(stream->MemZero(&loop_counter, sizeof(int32_t))); + TF_ASSERT_OK(stream->MemZero(&b, sizeof(int32_t))); + TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + + TF_ASSERT_OK(stream->Memcpy(&b_dst, b, sizeof(int32_t))); + TF_ASSERT_OK(stream->Memcpy(&c_dst, c, sizeof(int32_t))); + + EXPECT_EQ(b_dst, 20); + EXPECT_EQ(c_dst, 43); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks below +//===----------------------------------------------------------------------===// + +#define BENCHMARK_SIZES(NAME) \ + BENCHMARK(NAME)->Arg(8)->Arg(32)->Arg(128)->Arg(512)->Arg(1024); + +// In benchmarks we construct command buffers in nested mode when we +// do not want to measure graph executable instantiation overhead. +static void BM_CreateCommandBuffer(benchmark::State& state) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + DeviceMemory b = executor->AllocateArray(1, 0); + + for (auto s : state) { + auto cmd_buffer = CommandBuffer::Create(executor, nested).value(); + for (int i = 1; i < state.range(0); ++i) { + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + } + CHECK_OK(cmd_buffer->Finalize()); + } +} + +BENCHMARK_SIZES(BM_CreateCommandBuffer); + +static void BM_TraceCommandBuffer(benchmark::State& state) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + DeviceMemory b = executor->AllocateArray(1, 0); + + for (auto s : state) { + auto launch_kernels = [&](Stream* stream) { + for (int i = 1; i < state.range(0); ++i) { + CHECK_OK(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, b, b, b)); + } + return absl::OkStatus(); + }; + + CHECK_OK(CommandBuffer::Trace(executor, launch_kernels, nested)); + } +} + +BENCHMARK_SIZES(BM_TraceCommandBuffer); + +static void BM_UpdateCommandBuffer(benchmark::State& state) { + Platform* platform = GpuPlatform(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + DeviceMemory b = executor->AllocateArray(1, 0); + + auto cmd_buffer = CommandBuffer::Create(executor, primary).value(); + for (int i = 1; i < state.range(0); ++i) { + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + } + CHECK_OK(cmd_buffer->Finalize()); + + for (auto s : state) { + CHECK_OK(cmd_buffer->Update()); + for (int i = 1; i < state.range(0); ++i) { + CHECK_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), b, b, b)); + } + CHECK_OK(cmd_buffer->Finalize()); + } +} + +BENCHMARK_SIZES(BM_UpdateCommandBuffer); + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index 7735ae3be2c4d..c7a7ad403e3aa 100644 --- a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h" +#include +#include #include #include #include @@ -28,14 +30,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "xla/stream_executor/device_id_utils.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/gpu/gpu_init.h" // IWYU pragma: keep +#include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep +#include "xla/tsl/util/env_var.h" // IWYU pragma: keep #include "tsl/framework/allocator.h" #include "tsl/framework/device_id.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" -#include "tsl/util/env_var.h" namespace stream_executor { @@ -102,11 +103,6 @@ void GpuCudaMallocAsyncAllocator::PrintAllocatorStatisticsNoLock() { #endif } -void GpuCudaMallocAsyncAllocator::PrintAllocatorStatistics() { - tsl::mutex_lock lock(lock_); - PrintAllocatorStatisticsNoLock(); -} - std::atomic GpuCudaMallocAsyncAllocator::number_instantiated_(0); GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( @@ -121,8 +117,8 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( (void)reserve_memory_; #if TF_CUDA_MALLOC_ASYNC_SUPPORTED - stream_exec_ = DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(), - platform_device_id) + stream_exec_ = GPUMachineManager() + ->ExecutorForDevice(platform_device_id.value()) .value(); // Initialized here as it only exist if compiled with a recent // enough CUDA. @@ -447,8 +443,7 @@ void GpuCudaMallocAsyncAllocator::SetStreamAndPreallocateMemory(void* stream) { void* ptr = AllocateRaw(0, prealloc_size); DeallocateRaw(ptr); VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator reserved the pool for " - << prealloc_size << " bytes" - << ". First ptr: " << ptr; + << prealloc_size << " bytes" << ". First ptr: " << ptr; ClearStats(); } #endif diff --git a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h index c428132d74aa9..7e9d274163228 100644 --- a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h +++ b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_CUDAMALLOCASYNC_ALLOCATOR_H_ +#include +#include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep #include "tsl/framework/allocator.h" #include "tsl/framework/device_id.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #if GOOGLE_CUDA @@ -34,7 +35,6 @@ limitations under the License. #define TF_CUDA_MALLOC_ASYNC_SUPPORTED CUDA_VERSION >= 11020 #endif // GOOGLE_CUDA - namespace stream_executor { // An allocator that wraps cudaMallocAsync. It has fewer fragmentation @@ -88,12 +88,6 @@ class GpuCudaMallocAsyncAllocator : public tsl::Allocator { void SetStreamAndPreallocateMemory(void* stream) override; - // With the right VLOG set, it prints: - // - the number of ptr currently allocated per size (histogram). - // - each ptr value and its size. - // - If CUDA_VERSION >= 11030, print cudaMallocAsync statistics. - void PrintAllocatorStatistics(); - static int GetInstantiatedCountTestOnly() { return number_instantiated_; } tsl::AllocatorMemoryType GetMemoryType() const override { diff --git a/xla/stream_executor/gpu/gpu_diagnostics.h b/xla/stream_executor/gpu/gpu_diagnostics.h index 7892a3bcf4f50..678a34e50a40e 100644 --- a/xla/stream_executor/gpu/gpu_diagnostics.h +++ b/xla/stream_executor/gpu/gpu_diagnostics.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_DIAGNOSTICS_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_DIAGNOSTICS_H_ +#include #include -#include "xla/stream_executor/platform/port.h" -#include "tsl/platform/statusor.h" +#include "absl/status/statusor.h" namespace stream_executor { namespace gpu { @@ -34,10 +34,10 @@ using DriverVersion = std::tuple; // string DriverVersionToString(DriverVersion version); // //// Converts a parsed driver version or status value to natural string form. -// string DriverVersionStatusToString(tsl::StatusOr version); +// string DriverVersionStatusToString(absl::StatusOr version); // //// Converts a string of a form like "331.79" to a DriverVersion{331, 79}. -// tsl::StatusOr StringToDriverVersion(const string& value); +// absl::StatusOr StringToDriverVersion(const string& value); class Diagnostician { public: @@ -58,15 +58,15 @@ class Diagnostician { // // This is solely used for more informative log messages when the user is // running on a machine that happens to have a libcuda/kernel driver mismatch. - static tsl::StatusOr FindKernelModuleVersion( + static absl::StatusOr FindKernelModuleVersion( const std::string& driver_version_file_contents); // Extracts the kernel driver version from the current host. - static tsl::StatusOr FindKernelDriverVersion(); + static absl::StatusOr FindKernelDriverVersion(); // Iterates through loaded DSOs with DlIteratePhdrCallback to find the // driver-interfacing DSO version number. Returns it as a string. - static tsl::StatusOr FindDsoVersion(); + static absl::StatusOr FindDsoVersion(); // Logs information about the kernel driver version and userspace driver // library version. @@ -80,12 +80,8 @@ class Diagnostician { // This is solely used for more informative log messages when the user is // running on a machine that happens to have a libcuda/kernel driver mismatch. static void WarnOnDsoKernelMismatch( - tsl::StatusOr dso_version, - tsl::StatusOr kernel_version); - - // Logs information about the dev nodes present on this machine: their - // existence, permissions, accessibility from this uid/gid. - static void LogDevNodeDiagnosticInformation(); + absl::StatusOr dso_version, + absl::StatusOr kernel_version); static std::string GetDevNodePath(int dev_node_ordinal); diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h index 2dff75b93e1ce..91690f69649ce 100644 --- a/xla/stream_executor/gpu/gpu_driver.h +++ b/xla/stream_executor/gpu/gpu_driver.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,15 +21,17 @@ limitations under the License. #include #include +#include #include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -67,12 +69,12 @@ class GpuDriver { // all calls after the first. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#initialization - static tsl::Status Init(); + static absl::Status Init(); // Returns the device associated with the given context. // device is an outparam owned by the caller, must not be null. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g4e84b109eba36cdaaade167f34ae881e - static tsl::StatusOr DeviceFromContext(GpuContext* context); + static absl::StatusOr DeviceFromContext(GpuContext* context); // Creates a new CUDA/HIP stream associated with the given context via // cuStreamCreate/hipStreamCreateWithFlags. @@ -99,14 +101,14 @@ class GpuDriver { // result is an outparam owned by the caller and must not be null. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types - static tsl::Status InitEvent(GpuContext* context, GpuEventHandle* result, - EventFlags flags); + static absl::Status InitEvent(GpuContext* context, GpuEventHandle* result, + EventFlags flags); // Destroys *event and turns it into a nullptr. event may not be null, but // *event may be, via cuEventDestroy/hipEventDestroy // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#event-management - static tsl::Status DestroyEvent(GpuContext* context, GpuEventHandle* event); + static absl::Status DestroyEvent(GpuContext* context, GpuEventHandle* event); // Allocates a GPU memory space of size bytes associated with the given // context via cuMemAlloc/hipMalloc. @@ -168,9 +170,6 @@ class GpuDriver { static int GetGpuStreamPriority( GpuContext* context, stream_executor::StreamPriority stream_priority); - // Virtual memory support was added to CUDA in 10.2 -#if CUDA_VERSION >= 10020 - // Reserves a range of virtual device memory addresses via // cuMemAddressReserve. bytes must be a multiple of the host page size. // Returns nullptr base address in VmemSpan if the reservation fails. @@ -180,8 +179,8 @@ class GpuDriver { // Size in bytes. uint64_t size_bytes; }; - static tsl::StatusOr ReserveVirtualMemory(GpuContext* context, - uint64_t bytes); + static absl::StatusOr ReserveVirtualMemory(GpuContext* context, + uint64_t bytes); // Frees a range of virtual addresses that were previously reserved through // ReserveVirtualMemory via cuMemAddressFree. @@ -191,7 +190,7 @@ class GpuDriver { // Calculates the minimum alignment for memory allocations done through // cuMemCreate via cuMemGetAllocationGranularity. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g30ee906c2cf66a0347b3dfec3d7eb31a - static tsl::StatusOr GetMinAllocationGranularity( + static absl::StatusOr GetMinAllocationGranularity( GpuDeviceHandle device); // Allocates physical memory and returns a handle that can be mapped to @@ -202,7 +201,7 @@ class GpuDriver { uint64_t handle; uint64_t bytes; }; - static tsl::StatusOr CreateMemoryHandle( + static absl::StatusOr CreateMemoryHandle( GpuContext* context, uint64_t bytes); // Frees memory represented by the provided MemoryHandle via cuMemRelease. @@ -214,7 +213,7 @@ class GpuDriver { // cuMemMap and sets the appropriate access settings via cuMemSetAccess. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gff1d395423af5c5c75375516959dae56 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g1b6b12b10e8324bf462ecab4e7ef30e1 - static tsl::Status MapMemory( + static absl::Status MapMemory( GpuContext* context, GpuDevicePtr va, const GenericMemoryHandle& handle, const std::vector& device_handles); @@ -224,19 +223,17 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gfb50aac00c848fd7087e858f59bf7e2a static void UnmapMemory(GpuContext* context, GpuDevicePtr va, uint64_t bytes); -#endif // CUDA_VERSION >= 10200 - // Given a device ordinal, returns a device handle into the device outparam, // which must not be null. // // N.B. these device handles do not have a corresponding destroy function in // the CUDA/HIP driver API. - static tsl::Status GetDevice(int device_ordinal, GpuDeviceHandle* device); + static absl::Status GetDevice(int device_ordinal, GpuDeviceHandle* device); // Given a device handle, returns the name reported by the driver for the // device. - static tsl::Status GetDeviceName(GpuDeviceHandle device, - std::string* device_name); + static absl::Status GetDeviceName(GpuDeviceHandle device, + std::string* device_name); // Given a device to create a context for, returns a context handle into the // context outparam, which must not be null. @@ -245,9 +242,8 @@ class GpuDriver { // calling thread. Current documentation on contexts and their influence on // userspace processes is given here: // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf - static tsl::Status CreateContext(int device_ordinal, GpuDeviceHandle device, - const DeviceOptions& device_options, - GpuContext** context); + static absl::Status CreateContext(int device_ordinal, GpuDeviceHandle device, + GpuContext** context); // Destroys the provided context via cuCtxDestroy. // Don't do this while clients could still be using the context, per the docs @@ -264,35 +260,33 @@ class GpuDriver { // in terms of integer-sized values, so there's no potential for overrun (as // of CUDA 5.5). // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b - static tsl::Status FuncGetAttribute(GpuFunctionAttribute attribute, - GpuFunctionHandle function, - int* attribute_value); + static absl::Status FuncGetAttribute(GpuFunctionAttribute attribute, + GpuFunctionHandle function, + int* attribute_value); // Sets the preferred cache configuration for the specified function. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g40f8c11e81def95dc0072a375f965681 - static tsl::Status FuncSetCacheConfig(GpuFunctionHandle function, - GpuFuncCachePreference cache_config); + static absl::Status FuncSetCacheConfig(GpuFunctionHandle function, + GpuFuncCachePreference cache_config); // Gets the preferred shared memory bank configuration for the specified // CONTEXT (not function!), either default or four- or eight-byte bank size. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g17153a1b8b8c756f7ab8505686a4ad74 // https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html - static tsl::StatusOr ContextGetSharedMemConfig( + static absl::StatusOr ContextGetSharedMemConfig( GpuContext* context); // Sets the preferred shared memory bank configuration for the specified // CONTEXT (not function!), either default or four- or eight-byte bank size. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g2574235fa643f8f251bf7bc28fac3692 // https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html - static tsl::Status ContextSetSharedMemConfig( + static absl::Status ContextSetSharedMemConfig( GpuContext* context, GpuSharedMemConfig shared_mem_config); // Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel. - // TODO(leary) describe the structure of kernel_params and extra in a readable - // way. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control - static tsl::Status LaunchKernel( + static absl::Status LaunchKernel( GpuContext* context, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, @@ -300,28 +294,46 @@ class GpuDriver { unsigned int block_dim_z, unsigned int shared_mem_bytes, GpuStreamHandle stream, void** kernel_params, void** extra); + // Launches a CUDA/ROCm kernel via cuLaunchKernelEx/hipModuleLaunchKernelEx. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb9c891eb6bb8f4089758e64c9c976db9 + static absl::Status LaunchKernel( + GpuContext* context, absl::string_view kernel_name, + GpuFunctionHandle function, unsigned int cluster_dim_x, + unsigned int cluster_dim_y, unsigned int cluster_dim_z, + unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, + GpuStreamHandle stream, void** kernel_params, void** extra); + // Creates a new GPU graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd885f719186010727b75c3315f865fdf // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status CreateGraph(GpuGraphHandle* graph); + static absl::Status CreateGraph(GpuGraphHandle* graph); // Destroys GPU graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g718cfd9681f078693d4be2426fd689c8 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status DestroyGraph(GpuGraphHandle graph); + static absl::Status DestroyGraph(GpuGraphHandle graph); // Begins graph capture on a stream. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g767167da0bbf07157dc20b6c258a2143 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management enum class StreamCaptureMode { kGlobal, kThreadLocal, kRelaxed }; - static tsl::Status StreamBeginCapture(GpuStreamHandle stream, - StreamCaptureMode mode); + static absl::Status StreamBeginCapture(GpuStreamHandle stream, + StreamCaptureMode mode); + + // Begins graph capture on a stream to an existing graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1gac495e0527d1dd6437f95ee482f61865 + // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management + static absl::Status StreamBeginCaptureToGraph(GpuStreamHandle stream, + GpuGraphHandle graph, + StreamCaptureMode mode); // Ends capture on a stream, returning the captured graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g03dab8b2ba76b00718955177a929970c // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status StreamEndCapture(GpuStreamHandle stream, - GpuGraphHandle* graph); + static absl::Status StreamEndCapture(GpuStreamHandle stream, + GpuGraphHandle* graph); // Graph instantiation flags. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g070bf5517d3a7915667c256eefce4956 @@ -341,15 +353,22 @@ class GpuDriver { // Creates an executable graph from a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gb53b435e178cccfa37ac87285d2c3fa1 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status GraphInstantiate(GpuGraphExecHandle* exec, - GpuGraphHandle graph, - const GraphInstantiateFlags& flags); + static absl::Status GraphInstantiate(GpuGraphExecHandle* exec, + GpuGraphHandle graph, + const GraphInstantiateFlags& flags); // Launches an executable graph in a stream. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g6b2dceb3901e71a390d2bd8b0491e471 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status GraphLaunch(GpuGraphExecHandle exec, - GpuStreamHandle stream); + static absl::Status GraphLaunch(GpuGraphExecHandle exec, + GpuStreamHandle stream); + + // Enables or disables the specified node in the given exec. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g371b20eb0c0658731e38db7e68f12c78 + // https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___graph.html#ga8902200d9fed1df7644fc7a51c4d327b + static absl::Status GraphNodeSetEnabled(GpuGraphExecHandle exec, + GpuGraphNodeHandle node, + bool enabled); // Graph update result. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g8edc8969ff6ae00b7cd5d7292f812c3c @@ -379,9 +398,9 @@ class GpuDriver { // the update if possible. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g96efefc56df46927da7297f122adfb9f // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status GraphExecUpdate(GpuGraphExecHandle exec, - GpuGraphHandle graph, - GraphExecUpdateResultInfo* result); + static absl::Status GraphExecUpdate(GpuGraphExecHandle exec, + GpuGraphHandle graph, + GraphExecUpdateResultInfo* result); // Graph node type. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g0731a28f826922120d783d8444e154dc @@ -405,31 +424,40 @@ class GpuDriver { // Return the node type of the graph node. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gdb1776d97aa1c9d5144774b29e4b8c3e // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/group___graph.html#ga87c68ae9408a6438d4a1101560ceea11 - static tsl::StatusOr GraphNodeGetType(GpuGraphNodeHandle node); + static absl::StatusOr GraphNodeGetType( + GpuGraphNodeHandle node); + + // Returns a node's dependencies. + // + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g048f4c0babcbba64a933fc277cd45083 + static absl::StatusOr> + GraphNodeGetDependencies(GpuGraphNodeHandle node); // Destroys an executable graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1ga32ad4944cc5d408158207c978bc43a7 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status DestroyGraphExec(GpuGraphExecHandle exec); + static absl::Status DestroyGraphExec(GpuGraphExecHandle exec); // Write a DOT file describing graph structure. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g0fb0c4d319477a0a98da005fcb0dacc4 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status GraphDebugDotPrint(GpuGraphHandle graph, const char* path); + static absl::StatusOr GraphDebugDotPrint( + GpuGraphHandle graph, const char* path, + bool return_printed_graph = false); // Returns a stream's capture status. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g37823c49206e3704ae23c7ad78560bca // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management - static tsl::StatusOr StreamIsCapturing(GpuStreamHandle stream); + static absl::StatusOr StreamIsCapturing(GpuStreamHandle stream); // Free unused memory that was cached on the specified device for use with // graphs back to the OS. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g57c87f4ba6af41825627cdd4e5a8c52b - static tsl::Status DeviceGraphMemTrim(GpuDeviceHandle device); + static absl::Status DeviceGraphMemTrim(GpuDeviceHandle device); // Creates a conditional handle. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gece6f3b9e85d0edb8484d625fe567376 - static tsl::Status GraphConditionalHandleCreate( + static absl::Status GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, GpuGraphHandle graph, GpuContext* context, unsigned int default_launch_value, unsigned int flags); @@ -459,32 +487,38 @@ class GpuDriver { // Adds a node of arbitrary type to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g4210c258cbba352040a26d1b4e658f9d - static tsl::StatusOr GraphAddNode( + static absl::StatusOr GraphAddNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, const GpuGraphNodeParams& params); + absl::Span deps, + const GpuGraphNodeParams& params); // Creates an empty node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g14b625984430cb2d574c63f29c9b9223 - static tsl::Status GraphAddEmptyNode(GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps); + static absl::Status GraphAddEmptyNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps); // Creates a kernel execution node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management - static tsl::Status GraphAddKernelNode( + static absl::Status GraphAddKernelNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, absl::string_view kernel_name, + absl::Span deps, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, unsigned int shared_mem_bytes, void** kernel_params, void** extra); + // Counts number of nodes in the graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gfa35a8e2d2fc32f48dbd67ba27cf27e5 + // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/group___graph.html#gaf006701d98164ed3492755bbb19bab83 + static absl::StatusOr GraphGetNodeCount(GpuGraphHandle graph); + // Sets the parameters for a kernel node in the given graph exec. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd84243569e4c3d6356b9f2eea20ed48c // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/group___graph.html#ga5b1918dae65224863b7370e6d4ad3f2a - static tsl::Status GraphExecKernelNodeSetParams( + static absl::Status GraphExecKernelNodeSetParams( GpuGraphExecHandle exec, GpuGraphNodeHandle node, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, @@ -519,86 +553,89 @@ class GpuDriver { // Creates a memory allocation node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g73a351cb71b2945a0bcb913a93f69ec9 - static tsl::Status GraphAddMemAllocNode( + static absl::Status GraphAddMemAllocNode( GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, MemAccessFlags access_flags, + absl::Span deps, MemAccessFlags access_flags, MemLocationType location_type, int device_id, MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, uint64_t max_pool_size = 0); // Fetch memory allocation node's allocated address; // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gee2c7d66d3d96b1470c1d1a769f250a2 - static tsl::StatusOr> + static absl::StatusOr> GraphGetMemAllocNodeParams(GpuGraphNodeHandle node); + // Create a memfree node and adds it to a graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1geb7cdce5d9be2d28d9428e74eb00fa53 + static absl::Status GraphAddMemFreeNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst); + // Creates a memcpy node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g674da6ab54a677f13e0e0e8206ff5073 - static tsl::Status GraphAddMemcpyD2DNode(GpuContext* context, - GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps, - GpuDevicePtr gpu_dst, - GpuDevicePtr gpu_src, uint64_t size); + static absl::Status GraphAddMemcpyD2DNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst, + GpuDevicePtr gpu_src, uint64_t size); // Sets the parameters for a memcpy node in the given graphExec. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g26186d58858ab32ccc7425b53786cce5 - static tsl::Status GraphExecMemcpyD2DNodeSetParams( + static absl::Status GraphExecMemcpyD2DNodeSetParams( GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size); // Creates a memset node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g89dc8fc3743392777c0daa2c4aca40d3 - static tsl::Status GraphAddMemsetNode( + static absl::Status GraphAddMemsetNode( GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr dst, + absl::Span deps, GpuDevicePtr dst, std::variant bit_pattern, uint64_t num_elements); // Sets the parameters for a memset node in the given graph exec. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g5df5be09a0b7b3513e740ebbbcd59739 - static tsl::Status GraphExecMemsetNodeSetParams( + static absl::Status GraphExecMemsetNodeSetParams( GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuDevicePtr dst, std::variant bit_pattern, uint64_t num_elements); // Creates a child graph node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gde52afbcf91a8c79d4d7efbe0e3b6844 - static tsl::Status GraphAddChildNode(GpuGraphNodeHandle* node, - GpuGraphHandle graph, - absl::Span deps, - GpuGraphHandle child); + static absl::Status GraphAddChildNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuGraphHandle child); // Sets the parameters for a child graph node in the given graph exec. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g8f2d9893f6b899f992db1a2942ec03ff - static tsl::Status GraphExecChildNodeSetParams(GpuGraphExecHandle exec, - GpuGraphNodeHandle node, - GpuGraphHandle child); + static absl::Status GraphExecChildNodeSetParams(GpuGraphExecHandle exec, + GpuGraphNodeHandle node, + GpuGraphHandle child); // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting // handle in "module". Any error logs that are produced are logged internally. // (supported on CUDA only) - static tsl::Status LoadPtx(GpuContext* context, const char* ptx_contents, - GpuModuleHandle* module); + static absl::Status LoadPtx(GpuContext* context, const char* ptx_contents, + GpuModuleHandle* module); // Loads cubin_bytes with the CUDA driver's blob loading interface and stores // the resulting handle in "module". // (supported on CUDA only) - static tsl::Status LoadCubin(GpuContext* context, const char* cubin_bytes, - GpuModuleHandle* module); + static absl::Status LoadCubin(GpuContext* context, const char* cubin_bytes, + GpuModuleHandle* module); // Loads HSACO with the ROCM runtime and stores the resulting handle in // "module". Any error logs that are produced are logged internally. // (supported on ROCm only) - static tsl::Status LoadHsaco(GpuContext* context, const char* hsaco_contents, - GpuModuleHandle* module); + static absl::Status LoadHsaco(GpuContext* context, const char* hsaco_contents, + GpuModuleHandle* module); // Retrieves a named kernel from a loaded module, and places the resulting // handle into function (outparam) on success. Neither kernel_name nor // function may be null. No ownership is taken of kernel_name. - static tsl::Status GetModuleFunction(GpuContext* context, - GpuModuleHandle module, - const char* kernel_name, - GpuFunctionHandle* function); + static absl::Status GetModuleFunction(GpuContext* context, + GpuModuleHandle module, + const char* kernel_name, + GpuFunctionHandle* function); // Retrieves a named global/constant symbol from a loaded module, and returns // a device pointer and size of the symbol on success. symbol_name may not be @@ -616,45 +653,46 @@ class GpuDriver { // Performs a synchronous memset of the device memory segment via cuMemsetD8. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6e582bf866e9e2fb014297bfaf354d7b - static tsl::Status SynchronousMemsetUint8(GpuContext* context, - GpuDevicePtr location, - uint8_t value, size_t size); + static absl::Status SynchronousMemsetUint8(GpuContext* context, + GpuDevicePtr location, + uint8_t value, size_t size); // Performs a synchronous memset of the device memory segment via cuMemsetD32. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g983e8d8759acd1b64326317481fbf132 - static tsl::Status SynchronousMemsetUint32(GpuContext* context, - GpuDevicePtr location, - uint32_t value, - size_t uint32_count); + static absl::Status SynchronousMemsetUint32(GpuContext* context, + GpuDevicePtr location, + uint32_t value, + size_t uint32_count); // Performs an asynchronous memset of the device memory segment via // cuMemsetD8Async. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627 - static tsl::Status AsynchronousMemsetUint8(GpuContext* context, - GpuDevicePtr location, - uint8_t value, size_t uint32_count, - GpuStreamHandle stream); + static absl::Status AsynchronousMemsetUint8(GpuContext* context, + GpuDevicePtr location, + uint8_t value, + size_t uint32_count, + GpuStreamHandle stream); // Performs an asynchronous memset of the device memory segment via // cuMemsetD32Async. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5 - static tsl::Status AsynchronousMemsetUint32(GpuContext* context, - GpuDevicePtr location, - uint32_t value, - size_t uint32_count, - GpuStreamHandle stream); + static absl::Status AsynchronousMemsetUint32(GpuContext* context, + GpuDevicePtr location, + uint32_t value, + size_t uint32_count, + GpuStreamHandle stream); // -- Synchronous memcopies. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169 - static tsl::Status SynchronousMemcpyD2H(GpuContext* context, void* host_dst, - GpuDevicePtr gpu_src, uint64_t size); - static tsl::Status SynchronousMemcpyH2D(GpuContext* context, - GpuDevicePtr gpu_dst, - const void* host_src, uint64_t size); - static tsl::Status SynchronousMemcpyD2D(GpuContext* context, - GpuDevicePtr gpu_dst, - GpuDevicePtr gpu_src, uint64_t size); + static absl::Status SynchronousMemcpyD2H(GpuContext* context, void* host_dst, + GpuDevicePtr gpu_src, uint64_t size); + static absl::Status SynchronousMemcpyH2D(GpuContext* context, + GpuDevicePtr gpu_dst, + const void* host_src, uint64_t size); + static absl::Status SynchronousMemcpyD2D(GpuContext* context, + GpuDevicePtr gpu_dst, + GpuDevicePtr gpu_src, uint64_t size); // -- Asynchronous memcopies. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g56f30236c7c5247f8e061b59d3268362 @@ -700,8 +738,8 @@ class GpuDriver { // amount of time? // // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad - static tsl::Status SynchronizeStream(GpuContext* context, - GpuStreamHandle stream); + static absl::Status SynchronizeStream(GpuContext* context, + GpuStreamHandle stream); // Blocks the calling thread until the operations associated with the context // have been completed, via cuCtxSynchronize. @@ -728,7 +766,7 @@ class GpuDriver { // Enables peer access per CanEnablePeerAccess, via cuCtxEnablePeerAccess. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g0889ec6728e61c05ed359551d67b3f5a - static tsl::Status EnablePeerAccess(GpuContext* from, GpuContext* to); + static absl::Status EnablePeerAccess(GpuContext* from, GpuContext* to); // Returns the elapsed milliseconds between start and stop via // cuEventElapsedTime. @@ -740,29 +778,30 @@ class GpuDriver { // Records that an event occurred when execution reaches the current point in // thestream via cuEventRecord. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1 - static tsl::Status RecordEvent(GpuContext* context, GpuEventHandle event, - GpuStreamHandle stream); + static absl::Status RecordEvent(GpuContext* context, GpuEventHandle event, + GpuStreamHandle stream); // Polls (without blocking) to determine the status of an event - pending or // complete (or an error status). // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g6f0704d755066b0ee705749ae911deef - static tsl::StatusOr QueryEvent(GpuContext* context, - GpuEventHandle event); + static absl::StatusOr QueryEvent(GpuContext* context, + GpuEventHandle event); // -- Pointer-specific calls. // Returns the context in which pointer was allocated or registered. - static tsl::StatusOr GetPointerContext(GpuDevicePtr pointer); + static absl::StatusOr GetPointerContext(GpuDevicePtr pointer); // Returns the device associated with the context from GetPointerContext(). - static tsl::StatusOr GetPointerDevice(GpuDevicePtr pointer); + static absl::StatusOr GetPointerDevice(GpuDevicePtr pointer); // Returns the memory space addressed by pointer. - static tsl::StatusOr GetPointerMemorySpace(GpuDevicePtr pointer); + static absl::StatusOr GetPointerMemorySpace( + GpuDevicePtr pointer); // Returns the base address and size of the device pointer dptr. - static tsl::Status GetPointerAddressRange(GpuDevicePtr dptr, - GpuDevicePtr* base, size_t* size); + static absl::Status GetPointerAddressRange(GpuDevicePtr dptr, + GpuDevicePtr* base, size_t* size); // -- Device-specific calls. @@ -770,62 +809,63 @@ class GpuDriver { // This is currently done via the deprecated device API. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1ge2091bbac7e1fb18c2821612115607ea // (supported on CUDA only) - static tsl::Status GetComputeCapability(int* cc_major, int* cc_minor, - GpuDeviceHandle device); + static absl::Status GetComputeCapability(int* cc_major, int* cc_minor, + GpuDeviceHandle device); // Returns Gpu ISA version for the device; i.e 803, 900. // (supported on ROCm only) - static tsl::Status GetGpuISAVersion(int* version, GpuDeviceHandle device); + static absl::Status GetGpuISAVersion(int* version, GpuDeviceHandle device); // Return the full GCN Architecture Name for the device // for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack- // (supported on ROCm only) - static tsl::Status GetGpuGCNArchName(GpuDeviceHandle device, - std::string* gcnArchName); + static absl::Status GetGpuGCNArchName(GpuDeviceHandle device, + std::string* gcnArchName); #if TENSORFLOW_USE_ROCM // tests the current device for MFMA insn support (ROCm only) - static tsl::StatusOr GetMFMASupport(); + static absl::StatusOr GetMFMASupport(); #endif // Returns the number of multiprocessors on the device (note that the device // may be multi-GPU-per-board). - static tsl::StatusOr GetMultiprocessorCount(GpuDeviceHandle device); + static absl::StatusOr GetMultiprocessorCount(GpuDeviceHandle device); // Returns the limit on number of threads that can be resident in a single // multiprocessor. - static tsl::StatusOr GetMaxThreadsPerMultiprocessor( + static absl::StatusOr GetMaxThreadsPerMultiprocessor( GpuDeviceHandle device); // Returns the limit on number of threads which may be resident for a single // block (cooperative thread array). - static tsl::StatusOr GetMaxThreadsPerBlock(GpuDeviceHandle device); + static absl::StatusOr GetMaxThreadsPerBlock(GpuDeviceHandle device); // Returns the amount of shared memory available on a single GPU core (i.e. // SM on NVIDIA devices). - static tsl::StatusOr GetMaxSharedMemoryPerCore( + static absl::StatusOr GetMaxSharedMemoryPerCore( GpuDeviceHandle device); // Returns the amount of static shared memory available for a single block // (cooperative thread array). - static tsl::StatusOr GetMaxSharedMemoryPerBlock( + static absl::StatusOr GetMaxSharedMemoryPerBlock( GpuDeviceHandle device); // Returns the total amount of shared memory available for a single block // (cooperative thread array). - static tsl::StatusOr GetMaxSharedMemoryPerBlockOptin( + static absl::StatusOr GetMaxSharedMemoryPerBlockOptin( GpuDeviceHandle device); // Returns the maximum supported number of registers per block. - static tsl::StatusOr GetMaxRegistersPerBlock(GpuDeviceHandle device); + static absl::StatusOr GetMaxRegistersPerBlock( + GpuDeviceHandle device); // Returns the number of threads per warp. - static tsl::StatusOr GetThreadsPerWarp(GpuDeviceHandle device); + static absl::StatusOr GetThreadsPerWarp(GpuDeviceHandle device); // Queries the grid limits for device with cuDeviceGetAttribute calls. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 - static tsl::Status GetGridLimits(int* x, int* y, int* z, - GpuDeviceHandle device); + static absl::Status GetGridLimits(int* x, int* y, int* z, + GpuDeviceHandle device); // Returns a grab-bag of device properties in a caller-owned device_properties // structure for device_ordinal via cuDeviceGetProperties. @@ -840,8 +880,8 @@ class GpuDriver { // Gets a specific integer-valued property about the given device. // // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 - static tsl::StatusOr GetDeviceAttribute(GpuDeviceAttribute attribute, - GpuDeviceHandle device); + static absl::StatusOr GetDeviceAttribute(GpuDeviceAttribute attribute, + GpuDeviceHandle device); // Returns whether ECC is enabled for the given GpuDeviceHandle via // cuDeviceGetattribute with CU_DEVICE_ATTRIBUTE_ECC_ENABLED. @@ -877,7 +917,7 @@ class GpuDriver { // compatible driver). // // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71 - static bool GetDriverVersion(int* driver_version); + static absl::StatusOr GetDriverVersion(); // -- Other calls @@ -885,7 +925,7 @@ class GpuDriver { // specified kernel/GpuFunctionHandle when launched with the specified // parameters. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__OCCUPANCY.html#group__CUDA__OCCUPANCY_1gcc6e1094d05cba2cee17fe33ddd04a98 - static tsl::StatusOr GetMaxOccupiedBlocksPerCore( + static absl::StatusOr GetMaxOccupiedBlocksPerCore( GpuContext* context, GpuFunctionHandle kernel, int threads_per_block, size_t dynamic_shared_memory_bytes); diff --git a/xla/stream_executor/gpu/gpu_event.cc b/xla/stream_executor/gpu/gpu_event.cc index 78b1663ae0ea1..bc714a519343c 100644 --- a/xla/stream_executor/gpu/gpu_event.cc +++ b/xla/stream_executor/gpu/gpu_event.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,11 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" +#include "absl/status/status.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "tsl/platform/statusor.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace stream_executor { namespace gpu { @@ -27,16 +29,16 @@ GpuEvent::GpuEvent(GpuExecutor* parent) GpuEvent::~GpuEvent() {} -tsl::Status GpuEvent::Init() { +absl::Status GpuEvent::Init() { return GpuDriver::InitEvent(parent_->gpu_context(), &gpu_event_, GpuDriver::EventFlags::kDisableTiming); } -tsl::Status GpuEvent::Destroy() { +absl::Status GpuEvent::Destroy() { return GpuDriver::DestroyEvent(parent_->gpu_context(), &gpu_event_); } -tsl::Status GpuEvent::Record(GpuStream* stream) { +absl::Status GpuEvent::Record(GpuStream* stream) { return GpuDriver::RecordEvent(parent_->gpu_context(), gpu_event_, stream->gpu_stream()); } diff --git a/xla/stream_executor/gpu/gpu_event.h b/xla/stream_executor/gpu/gpu_event.h index 5653e309c675f..2c8b588dab76c 100644 --- a/xla/stream_executor/gpu/gpu_event.h +++ b/xla/stream_executor/gpu/gpu_event.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ +#include "absl/status/status.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "tsl/platform/status.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace stream_executor { namespace gpu { @@ -33,14 +33,14 @@ class GpuEvent : public internal::EventInterface { ~GpuEvent() override; // Populates the CUDA-platform-specific elements of this object. - tsl::Status Init(); + absl::Status Init(); // Deallocates any platform-specific elements of this object. This is broken // out (not part of the destructor) to allow for error reporting. - tsl::Status Destroy(); + absl::Status Destroy(); // Inserts the event at the current position into the specified stream. - tsl::Status Record(GpuStream* stream); + absl::Status Record(GpuStream* stream); // Polls the CUDA platform for the event's current status. Event::Status PollForStatus(); diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 1055c167c9d94..3c6787575454b 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,30 +22,40 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_ +#include #include +#include #include #include #include #include -#include #include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" -#include "absl/strings/string_view.h" +#include "absl/numeric/int128.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_collectives.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/fingerprint.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -53,6 +63,9 @@ class StreamExecutor; namespace gpu { +class GpuKernel; +class GpuCommandBuffer; + // CUDA-platform implementation of the platform-agnostic // StreamExecutorInterface. class GpuExecutor : public internal::StreamExecutorInterface { @@ -99,39 +112,42 @@ class GpuExecutor : public internal::StreamExecutorInterface { ~GpuExecutor() override; - tsl::Status Init(int device_ordinal, DeviceOptions device_options) override; + absl::Status Init(int device_ordinal) override; int device_ordinal() const override { return device_ordinal_; }; - tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override; + absl::Status GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) override; // (supported on CUDA only) void UnloadKernel(const Kernel* kernel) override; - tsl::Status LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) override; + absl::Status LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle) override; bool UnloadModule(ModuleHandle module_handle) override; // Allocates and initializes a new constant on the device with the given // content. Or, if a device with identical content is already on-device, // returns a pointer to that buffer with shared ownership. - tsl::StatusOr> CreateOrShareConstant( + absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content) override; - tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args) override; + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& kernel, + const KernelArgs& args) override; - tsl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) override; + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, const Kernel& kernel, + const KernelArgs& args) override; + + absl::Status Submit(Stream* stream, + const CommandBuffer& command_buffer) override; - // (supported on CUDA only) int CalculateOccupancy(const DeviceDescription& device_description, uint64_t registers_per_thread, uint64_t shared_memory_per_block, const ThreadDim& thread_dims, GpuFunctionHandle func); - // (supported on CUDA only) int CompareOccupancy(int* initial_blocks, const DeviceDescription& device_description, uint64_t registers_per_thread, @@ -140,9 +156,6 @@ class GpuExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void* GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) override; - void Deallocate(DeviceMemoryBase* mem) override; void* UnifiedMemoryAllocate(uint64_t size) override { @@ -153,6 +166,14 @@ class GpuExecutor : public internal::StreamExecutorInterface { return GpuDriver::UnifiedMemoryDeallocate(context_, location); } + absl::StatusOr CollectiveMemoryAllocate(uint64_t size) override { + return GpuCollectives::CollectiveMemoryAllocate(context_, size); + } + + absl::Status CollectiveMemoryDeallocate(void* location) override { + return GpuCollectives::CollectiveMemoryDeallocate(context_, location); + } + // CUDA allocation/registration functions are necessary because the driver // internally sets up buffers for DMA operations (and page locks them). // There's no external interface for us to otherwise control these DMA @@ -171,41 +192,42 @@ class GpuExecutor : public internal::StreamExecutorInterface { bool SynchronizeAllActivity() override; - tsl::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) override; + absl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) override; - tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64_t size) override; + absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64_t size) override; - tsl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) override; + absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) override; - tsl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override; + absl::Status SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) override; - tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) override; + absl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) override; - tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override; - tsl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override; + absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) override; + absl::Status Memset(Stream* stream, DeviceMemoryBase* location, + uint8_t pattern, uint64_t size) override; + absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) override; - bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override; + absl::Status Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override; - bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) override; + absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; + absl::AnyInvocable callback) override; bool AllocateStream(Stream* stream) override; @@ -213,22 +235,22 @@ class GpuExecutor : public internal::StreamExecutorInterface { bool CreateStreamDependency(Stream* dependent, Stream* other) override; - tsl::Status AllocateEvent(Event* event) override; + absl::Status AllocateEvent(Event* event) override; - tsl::Status DeallocateEvent(Event* event) override; + absl::Status DeallocateEvent(Event* event) override; - tsl::Status RecordEvent(Stream* stream, Event* event) override; + absl::Status RecordEvent(Stream* stream, Event* event) override; - tsl::Status WaitForEvent(Stream* stream, Event* event) override; + absl::Status WaitForEvent(Stream* stream, Event* event) override; - tsl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) override; + absl::Status WaitForEventOnExternalStream(std::intptr_t stream, + Event* event) override; Event::Status PollForEventStatus(Event* event) override; - tsl::Status BlockHostUntilDone(Stream* stream) override; + absl::Status BlockHostUntilDone(Stream* stream) override; - tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override; + absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override; bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override; @@ -240,12 +262,12 @@ class GpuExecutor : public internal::StreamExecutorInterface { bool GetSymbol(const std::string& symbol_name, ModuleHandle module_handle, void** mem, size_t* bytes) override; - tsl::StatusOr> CreateDeviceDescription() + absl::StatusOr> CreateDeviceDescription() const override { return CreateDeviceDescription(device_ordinal_); } - static tsl::StatusOr> + static absl::StatusOr> CreateDeviceDescription(int device_ordinal); blas::BlasSupport* CreateBlas() override; @@ -257,22 +279,18 @@ class GpuExecutor : public internal::StreamExecutorInterface { std::unique_ptr CreateEventImplementation() override; - std::unique_ptr CreateKernelImplementation() - override; - std::unique_ptr GetStreamImplementation() override; - tsl::StatusOr> - GetCommandBufferImplementation(CommandBuffer::Mode mode) override; + absl::StatusOr> CreateKernel() override; + + absl::StatusOr> CreateCommandBuffer( + CommandBuffer::Mode mode) override; // Wraps existing Gpu graph handle into an instance of Gpu command buffer. // This is required for wrapping nested graphs constructed for conditional // nodes and owned by a parent graph executable. - std::unique_ptr - GetCommandBufferImplementation(CommandBuffer::Mode mode, GpuGraphHandle graph, - bool is_owned_graph); - - void* platform_specific_context() override; + std::unique_ptr CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph); GpuContext* gpu_context(); @@ -311,27 +329,33 @@ class GpuExecutor : public internal::StreamExecutorInterface { static void InternalHostCallback(void* data); // Collects metadata for the specified kernel. - tsl::Status GetKernelMetadata(GpuKernel* cuda_kernel, - KernelMetadata* kernel_metadata); + absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, + KernelMetadata* kernel_metadata); // Prints to VLOG(2) information about the kernel's occupancy and how it might // be improved. - void VlogOccupancyInfo(const Kernel& kernel, const ThreadDim& thread_dims, + void VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, const ThreadDim& thread_dims, const BlockDim& block_dims); // (supported on CUDA only) - tsl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module) + absl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated. // (supported on CUDA only) - tsl::Status LoadModuleFromPtx(const char* ptx, GpuModuleHandle* module) + absl::Status LoadModuleFromPtx(const char* ptx, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // (supported on ROCm only) - tsl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) + absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args); + bool UnloadGpuBinary(const void* gpu_binary) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); diff --git a/xla/stream_executor/gpu/gpu_graph.cc b/xla/stream_executor/gpu/gpu_graph.cc deleted file mode 100644 index 5d9df50df7c85..0000000000000 --- a/xla/stream_executor/gpu/gpu_graph.cc +++ /dev/null @@ -1,315 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_graph.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/path.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { -namespace gpu { - -//===----------------------------------------------------------------------===// -// RAII helpers for gpu graph types. -//===----------------------------------------------------------------------===// - -std::atomic GpuGraphSupport::allocated_gpu_graph_execs_; -std::atomic GpuGraphSupport::alive_gpu_graph_execs_; - -/*static*/ void GpuGraphSupport::TrimDeviceMemory(StreamExecutor* executor) { - auto* gpu_executor = ExtractGpuExecutor(executor); - auto st = GpuDriver::DeviceGraphMemTrim(gpu_executor->device()); - if (!st.ok()) { - LOG(ERROR) << "Failed to trim Gpu device graph memory: " << st.message(); - } -} - -/*static*/ size_t GpuGraphSupport::NotifyGraphExecCreated() { - alive_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); - return allocated_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); -} - -/*static*/ size_t GpuGraphSupport::NotifyGraphExecDestroyed() { - return alive_gpu_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; -} - -/*static*/ size_t GpuGraphSupport::allocated_gpu_graph_execs() { - return allocated_gpu_graph_execs_.load(std::memory_order_relaxed); -} - -/*static*/ size_t GpuGraphSupport::alive_gpu_graph_execs() { - return alive_gpu_graph_execs_.load(std::memory_order_relaxed); -} - -void GpuGraphSupport::DestroyGraph::operator()(GpuGraphHandle graph) { - auto st = GpuDriver::DestroyGraph(graph); - CHECK(st.ok()) << "Failed to destroy gpu graph: " << st.message(); -} - -void GpuGraphSupport::DestroyGraphExec::operator()(GpuGraphExecHandle exec) { - auto st = GpuDriver::DestroyGraphExec(exec); - CHECK(st.ok()) << "Failed to destroy executable gpu graph: " << st.message(); -} - -tsl::StatusOr GraphExecUpdateResultToString( - GpuDriver::GraphExecUpdateResult result) { - switch (result) { - case GpuDriver::GraphExecUpdateResult::kSuccess: - return "kSuccess"; - case GpuDriver::GraphExecUpdateResult::kError: - return "kFailure"; - case GpuDriver::GraphExecUpdateResult::kTopologyChanged: - return "kTopologyChanged"; - case GpuDriver::GraphExecUpdateResult::kAttributesChanged: - return "kAttributesChanged"; - case GpuDriver::GraphExecUpdateResult::kFunctionChanged: - return "kFunctionChanged"; - case GpuDriver::GraphExecUpdateResult::kParametersChanged: - return "kParametersChanged"; - case GpuDriver::GraphExecUpdateResult::kUnsupportedFunctionChange: - return "kUnsupportedFunctionChange"; - case GpuDriver::GraphExecUpdateResult::kNodeTypeChanged: - return "kNodeTypeChanged"; - case GpuDriver::GraphExecUpdateResult::kNotSupported: - return "kNotSupported"; - } - return tsl::errors::Internal("Unexpected value for GraphExecUpdateResult"); -} - -tsl::StatusOr GraphNodeTypeToString( - GpuDriver::GraphNodeType node_type) { - switch (node_type) { - case GpuDriver::GraphNodeType::kKernel: - return "kKernel"; - case GpuDriver::GraphNodeType::kMemcpy: - return "kMemcpy"; - case GpuDriver::GraphNodeType::kMemset: - return "kMemset"; - case GpuDriver::GraphNodeType::kHost: - return "kHost"; - case GpuDriver::GraphNodeType::kGraph: - return "kGraph"; - case GpuDriver::GraphNodeType::kEmpty: - return "kEmpty"; - case GpuDriver::GraphNodeType::kWaitEvent: - return "kWaitEvent"; - case GpuDriver::GraphNodeType::kEventRecord: - return "kEventRecord"; - case GpuDriver::GraphNodeType::kExtSemasSignal: - return "kExtSemasSignal"; - case GpuDriver::GraphNodeType::kExtSemasWait: - return "kExtSemasWait"; - case GpuDriver::GraphNodeType::kMemAlloc: - return "kMemAlloc"; - case GpuDriver::GraphNodeType::kMemFree: - return "kMemFree"; - case GpuDriver::GraphNodeType::kBatchMemOp: - return "kBatchMemOp"; - } - return tsl::errors::Internal("Unexpected value for GraphNodeType"); -} - -tsl::Status OwnedGpuGraphExec::Update(OwnedGpuGraph graph) { - VLOG(3) << "Update gpu graph exec with a new graph after " << num_launches_ - << " launches since last update" - << " #" << num_updates_++; - - num_launches_ = 0; - - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - GpuDriver::GraphExecUpdateResultInfo result; - memset(&result, 0, sizeof(result)); - auto st = GpuDriver::GraphExecUpdate(get(), graph.get(), &result); - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - - if (!st.ok()) { - TF_ASSIGN_OR_RETURN(std::string result_str, - GraphExecUpdateResultToString(result.result)); - std::string error_message = absl::StrCat( - "Failed to update gpu graph: Graph update result=", result_str); - - if (result.error_node) { - TF_ASSIGN_OR_RETURN(GpuDriver::GraphNodeType node_type, - GpuDriver::GraphNodeGetType(result.error_node)); - TF_ASSIGN_OR_RETURN(std::string node_type_str, - GraphNodeTypeToString(node_type)); - absl::StrAppend(&error_message, ", Error node name=", node_type_str); - } - - if (result.error_from_node) { - TF_ASSIGN_OR_RETURN(GpuDriver::GraphNodeType node_type, - GpuDriver::GraphNodeGetType(result.error_from_node)); - TF_ASSIGN_OR_RETURN(std::string node_type_str, - GraphNodeTypeToString(node_type)); - absl::StrAppend(&error_message, ", Error from node name=", node_type_str); - } - - absl::StrAppend(&error_message, ": ", st.message()); - return tsl::errors::Internal(error_message); - } - - VLOG(5) << "Updated gpu graph exec #" << id_ << " (took " - << (end_nanos - start_nanos) / 1000 << " us)"; - - return tsl::OkStatus(); -} - -tsl::Status OwnedGpuGraphExec::Launch(stream_executor::Stream* stream) { - VLOG(3) << "Launch gpu graph " << get() - << " on a stream: " << stream->DebugStreamPointers() << " #" - << ++num_launches_; - - return GpuDriver::GraphLaunch(get(), AsGpuStreamValue(stream)); -} - -OwnedGpuGraphExec::~OwnedGpuGraphExec() { - if (*this) // do not log for moved-from instances - VLOG(5) << "Destroy GPU graph exec #" << id_ - << " (remaining alive instances: " - << GpuGraphSupport::NotifyGraphExecDestroyed() << ")"; -} - -//===----------------------------------------------------------------------===// -// GPU Graph Helpers. -//===----------------------------------------------------------------------===// - -tsl::StatusOr CreateGpuGraph() { - GpuGraphHandle graph; - TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); - return OwnedGpuGraph(graph); -} - -tsl::StatusOr AddKernelNode( - GpuGraphHandle graph, absl::Span deps, - ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgs& args) { - const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); - GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); - - auto* packed_args = DynCast(&args); - if (!packed_args) - return absl::InternalError("Unsupported kernel arguments type"); - - void** kernel_params = - const_cast(packed_args->argument_addresses().data()); - - GpuGraphNodeHandle node; - TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( - &node, graph, deps, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z, - threads.x, threads.y, threads.z, args.number_of_shared_bytes(), - kernel_params, /*extra=*/nullptr)); - - return node; -} - -static GpuDevicePtr AsDevicePtr(const DeviceMemoryBase& mem) { - return reinterpret_cast(const_cast(mem.opaque())); -} - -tsl::StatusOr AddMemcpyD2DNode( - GpuContext* context, GpuGraphHandle graph, - absl::Span deps, const DeviceMemoryBase& dst, - const DeviceMemoryBase& src) { - GpuGraphNodeHandle node; - TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemcpyD2DNode( - context, &node, graph, deps, AsDevicePtr(dst), AsDevicePtr(src), - dst.size())); - return node; -} - -tsl::StatusOr CaptureGpuGraph( - stream_executor::Stream* stream, - absl::AnyInvocable capture) { - VLOG(3) << "Capture gpu graph on a stream: " << stream->DebugStreamPointers(); - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - - GpuGraphHandle graph; - - // Get the underlying stream for passing to GPU runtime APIs. - auto gpu_stream = AsGpuStreamValue(stream); - - // Capture graph constructed by the exported graph capture function. - TF_RETURN_IF_ERROR(GpuDriver::StreamBeginCapture( - gpu_stream, GpuDriver::StreamCaptureMode::kThreadLocal)); - - // Call into graph capture function. - auto captured = capture(); - - // Always stop capturing the stream before checking `captured` result. - TF_RETURN_IF_ERROR(GpuDriver::StreamEndCapture(gpu_stream, &graph)); - - if (!captured.ok()) - return tsl::errors::Internal("failed to capture gpu graph: ", - captured.message()); - - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - VLOG(5) << "Captured XLA:GPU operations into the graph " << graph << " (took " - << (end_nanos - start_nanos) / 1000 << " us)"; - - if (const char* path = getenv("XLA_GPU_GRAPH_DEBUG_DIRECTORY"); path) { - std::string file = tsl::io::JoinPath(std::string(path), "/gpu-graph-"); - - if (tsl::Env::Default()->CreateUniqueFileName(&file, ".dot")) { - VLOG(100) << "Print gpu graph " << graph - << " debug dot file to: " << file; - auto printed = GpuDriver::GraphDebugDotPrint(graph, file.c_str()); - printed.IgnoreError(); // warning will be printed by GpuDriver - } else { - LOG(WARNING) << "Cannot create unique filename, won't enable gpu " - "graph debugging"; - } - } - - return OwnedGpuGraph(graph); -} - -tsl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph) { - GpuGraphExecHandle exec; - - uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - GpuDriver::GraphInstantiateFlags flags; - TF_RETURN_IF_ERROR(GpuDriver::GraphInstantiate(&exec, graph.get(), flags)); - uint64_t end_nanos = tsl::Env::Default()->NowNanos(); - - size_t id = GpuGraphSupport::NotifyGraphExecCreated(); - VLOG(5) << "Instantiated gpu graph exec instance #" << id << " in " - << (end_nanos - start_nanos) / 1000 << " us (alive instances: " - << GpuGraphSupport::alive_gpu_graph_execs() << ")"; - return OwnedGpuGraphExec(id, exec); -} - -tsl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { - return GpuDriver::StreamIsCapturing(AsGpuStreamValue(stream)); -} - -} // namespace gpu -} // namespace stream_executor diff --git a/xla/stream_executor/gpu/gpu_graph.h b/xla/stream_executor/gpu/gpu_graph.h deleted file mode 100644 index 28abf986049e8..0000000000000 --- a/xla/stream_executor/gpu/gpu_graph.h +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ - -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/types/span.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { -namespace gpu { - -// Forward declare. -class GpuContext; - -class GpuGraphSupport { - public: - // Deleters for gpu graph and graph exec instance that check the returned - // status and terminate on error. - struct DestroyGraph { - void operator()(GpuGraphHandle); - }; - struct DestroyGraphExec { - void operator()(GpuGraphExecHandle); - }; - - static size_t NotifyGraphExecCreated(); - static size_t NotifyGraphExecDestroyed(); - - static size_t allocated_gpu_graph_execs(); - static size_t alive_gpu_graph_execs(); - - static void TrimDeviceMemory(StreamExecutor* executor); - - private: - // Global counters for the total number of allocated and alive gpu graph - // execs to track the resource usage at run time. - static std::atomic allocated_gpu_graph_execs_; - static std::atomic alive_gpu_graph_execs_; -}; - -//===----------------------------------------------------------------------===// -// RAII helpers for gpu graph types. -//===----------------------------------------------------------------------===// - -class OwnedGpuGraph - : public std::unique_ptr, - GpuGraphSupport::DestroyGraph> { - // Bring std::unique_ptr constructors in scope. - using std::unique_ptr, - GpuGraphSupport::DestroyGraph>::unique_ptr; -}; - -class OwnedGpuGraphExec - : public std::unique_ptr, - GpuGraphSupport::DestroyGraphExec> { - using Base = std::unique_ptr, - GpuGraphSupport::DestroyGraphExec>; - - public: - OwnedGpuGraphExec(uint64_t id, GpuGraphExecHandle exec) - : Base(exec), id_(id) {} - ~OwnedGpuGraphExec(); - - OwnedGpuGraphExec(OwnedGpuGraphExec&&) = default; - OwnedGpuGraphExec& operator=(OwnedGpuGraphExec&&) = default; - - // Updates executable graph instance with a newly captured graph. Returns an - // error if the new graph is not compatible (see `cudaGraphExecUpdate`). - tsl::Status Update(OwnedGpuGraph graph); - - // Launches captured graph on a given stream. - tsl::Status Launch(stream_executor::Stream* stream); - - uint64_t id() const { return id_; } - - private: - uint64_t id_; - uint64_t num_updates_ = 0; - uint64_t num_launches_ = 0; -}; - -//===----------------------------------------------------------------------===// -// Gpu Graph Helpers. -//===----------------------------------------------------------------------===// - -// Creates new empty Gpu graph. -tsl::StatusOr CreateGpuGraph(); - -// Adds a kernel node to the graph. -tsl::StatusOr AddKernelNode( - GpuGraphHandle graph, absl::Span deps, - ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgs& args); - -// Adds a memory copy node to the graph. -tsl::StatusOr AddMemcpyD2DNode( - GpuContext* context, GpuGraphHandle graph, - absl::Span deps, const DeviceMemoryBase& dst, - const DeviceMemoryBase& src); - -// Captures all operations added to a `stream` by the `capture` function into -// the gpu graph instance. -tsl::StatusOr CaptureGpuGraph( - stream_executor::Stream* stream, absl::AnyInvocable capture); - -// Instantiates a captured gpu graph instance into a gpu graph executable. -tsl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph); - -// Returns true if the stream is in graph capture mode -tsl::StatusOr IsStreamCapturing(stream_executor ::Stream* stream); - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ diff --git a/xla/stream_executor/gpu/gpu_helpers.h b/xla/stream_executor/gpu/gpu_helpers.h index cefd0a6af15df..c86f49140a521 100644 --- a/xla/stream_executor/gpu/gpu_helpers.h +++ b/xla/stream_executor/gpu/gpu_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include #include "xla/stream_executor/gpu/gpu_types.h" #include "tsl/platform/logging.h" diff --git a/xla/stream_executor/gpu/gpu_init.cc b/xla/stream_executor/gpu/gpu_init.cc index 8c0237ec06184..a0f8e5919a5ea 100644 --- a/xla/stream_executor/gpu/gpu_init.cc +++ b/xla/stream_executor/gpu/gpu_init.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,24 +17,24 @@ limitations under the License. #include -#include "xla/stream_executor/multi_platform_manager.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { -tsl::Status ValidateGPUMachineManager() { - return MultiPlatformManager::PlatformWithName(GpuPlatformName()).status(); +absl::Status ValidateGPUMachineManager() { + return PlatformManager::PlatformWithName(GpuPlatformName()).status(); } Platform* GPUMachineManager() { // Cache this result, it's on the critical path for light outside compilation // (and probably other things as well). static Platform* platform = [&] { - tsl::StatusOr p = - MultiPlatformManager::PlatformWithName(GpuPlatformName()); + absl::StatusOr p = + PlatformManager::PlatformWithName(GpuPlatformName()); if (!p.ok()) { LOG(FATAL) << "Could not find Platform with name " << GpuPlatformName(); } diff --git a/xla/stream_executor/gpu/gpu_init.h b/xla/stream_executor/gpu/gpu_init.h index bcdce71f9ffce..ea51473c20a7f 100644 --- a/xla/stream_executor/gpu/gpu_init.h +++ b/xla/stream_executor/gpu/gpu_init.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,14 +18,14 @@ limitations under the License. #include -#include "tsl/platform/status.h" +#include "absl/status/status.h" namespace stream_executor { class Platform; // Initializes the GPU platform and returns OK if the GPU // platform could be initialized. -tsl::Status ValidateGPUMachineManager(); +absl::Status ValidateGPUMachineManager(); // Returns the GPU machine manager singleton, creating it and // initializing the GPUs on the machine if needed the first time it is diff --git a/xla/stream_executor/gpu/gpu_kernel.h b/xla/stream_executor/gpu/gpu_kernel.h index 09443a23259b5..81c8e687d8120 100644 --- a/xla/stream_executor/gpu/gpu_kernel.h +++ b/xla/stream_executor/gpu/gpu_kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,33 +22,37 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ +#include +#include +#include +#include + +#include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/stream_executor/launch_dim.h" #include "tsl/platform/logging.h" -namespace stream_executor { -namespace gpu { +namespace stream_executor::gpu { -// Wraps a GpuFunctionHandle to implement the platform-independent -// KernelInterface. -class GpuKernel : public internal::KernelInterface { +class GpuKernel : public Kernel { public: - GpuKernel() - : gpu_function_(nullptr), - arity_(0), - preferred_cache_config_(KernelCacheConfig::kNoPreference) {} + explicit GpuKernel(GpuExecutor* gpu_executor) : gpu_executor_(gpu_executor) {} // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. - ~GpuKernel() override {} + ~GpuKernel() override { gpu_executor_->UnloadKernel(this); } // As arity cannot be reflected upon using the CUDA API, the arity is // explicitly set during the GpuExecutor::GetKernel initialization process. void set_arity(unsigned arity) { arity_ = arity; } unsigned Arity() const override { return arity_; } + void set_name(std::string name) { name_ = std::move(name); } + void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; } + // Returns the GpuFunctionHandle value for passing to the CUDA API. GpuFunctionHandle AsGpuFunctionHandle() const { DCHECK(gpu_function_ != nullptr); @@ -59,47 +63,30 @@ class GpuKernel : public internal::KernelInterface { // object, for the CUDA API which wants to load into a GpuFunctionHandle*. GpuFunctionHandle* gpu_function_ptr() { return &gpu_function_; } - // CUDA supports setting the preferred cache configuration of a - // GpuFunctionHandle (more-or-less equivalent to a GpuKernel). We support this - // via the below functions; users can set a preference, and that is applied - // when the kernel is [lazy-]loaded (in GpuExecutor::Launch). The alternative - // would be to load the kernel & set the preference when the user calls the - // setter below; either approach is valid. Sets the current kernel cache - // configuration preference. - void SetPreferredCacheConfig(KernelCacheConfig config) override { - preferred_cache_config_ = config; - } - - // Returns the current kernel cache configuration preference. - KernelCacheConfig GetPreferredCacheConfig() const override { - return preferred_cache_config_; - } - // Returns the current kernel cache configuration preference as a - // CUfunc_cache. + // GpuFuncCachePreference. GpuFuncCachePreference GetGpuCacheConfig() const; + absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + private: - GpuFunctionHandle gpu_function_; // Wrapped CUDA kernel handle. - unsigned arity_; // Number of formal parameters the kernel takes. + GpuExecutor* gpu_executor_ = nullptr; + GpuContext* gpu_context_ = nullptr; // context where kernel is loaded + std::string name_; // kernel name - // Preferred (but not required) cache configuration for this kernel. - KernelCacheConfig preferred_cache_config_; + GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes }; -// Given a platform-independent kernel datatype, returns the (const) internal -// CUDA platform implementation pointer. inline const GpuKernel* AsGpuKernel(const Kernel* kernel) { - return static_cast(kernel->implementation()); + return static_cast(kernel); } -// Given a platform-independent kernel datatype, returns the (non-const) -// internal CUDA platform implementation pointer. inline GpuKernel* AsGpuKernel(Kernel* kernel) { - return static_cast(kernel->implementation()); + return static_cast(kernel); } -} // namespace gpu -} // namespace stream_executor +} // namespace stream_executor::gpu #endif // XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ diff --git a/xla/stream_executor/gpu/gpu_kernel_test.cc b/xla/stream_executor/gpu/gpu_kernel_test.cc new file mode 100644 index 0000000000000..591d417348776 --- /dev/null +++ b/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include "absl/strings/ascii.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { + +TEST(GpuKernelTest, Add) { + using AddI32Kernel = TypedKernel, DeviceMemory, + DeviceMemory>; + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); + Platform* platform = PlatformManager::PlatformWithName(name).value(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + MultiKernelLoaderSpec spec(/*arity=*/3); +#if defined(GOOGLE_CUDA) + spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); +#elif defined(TENSORFLOW_USE_ROCM) + spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); +#endif + + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Launch kernel. + ASSERT_TRUE(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + + // Copy data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); +} + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_kernels.h b/xla/stream_executor/gpu/gpu_kernels.h new file mode 100644 index 0000000000000..a2f14ec15f451 --- /dev/null +++ b/xla/stream_executor/gpu/gpu_kernels.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNELS_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_KERNELS_H_ + +#include + +namespace stream_executor::gpu { + +// Collection of helper kernels required by StreamExecutor Gpu backend. + +// PTX kernel compiled from: +// +// __global__ void noop() {} +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +inline constexpr std::string_view kNoOpKernel = R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.visible .entry noop() +{ + + .loc 1 1 0 + + .loc 1 4 1 + ret; + +})"; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_KERNELS_H_ diff --git a/xla/stream_executor/gpu/gpu_runtime.h b/xla/stream_executor/gpu/gpu_runtime.h index da9dc736be53b..6f36c7ceab1ea 100644 --- a/xla/stream_executor/gpu/gpu_runtime.h +++ b/xla/stream_executor/gpu/gpu_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ +#include + +#include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/platform/statusor.h" namespace stream_executor::gpu { @@ -50,7 +52,11 @@ class GpuRuntime { // current device (and create it if it doesn't exist yet). // // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html#group__CUDART__DRIVER_1gaba6f8d01e745f0c8d8776ceb18be617 - static tsl::StatusOr GetFuncBySymbol(void* symbol); + static absl::StatusOr GetFuncBySymbol(void* symbol); + + // Returns the Gpu Runtime version. + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION_1g0e3952c7802fd730432180f1f4a6cdc6 + static absl::StatusOr GetRuntimeVersion(); }; } // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index 5d35f76127821..a04f5410dbd61 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,9 +17,14 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/status.h" namespace stream_executor { namespace gpu { @@ -43,7 +48,7 @@ bool GpuStream::Init() { void GpuStream::Destroy() { if (completed_event_ != nullptr) { - tsl::Status status = + absl::Status status = GpuDriver::DestroyEvent(parent_->gpu_context(), &completed_event_); if (!status.ok()) { LOG(ERROR) << status.message(); diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 75667afa8e9f2..166be82299587 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ limitations under the License. #include +#include "absl/log/check.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_internal.h" namespace stream_executor { diff --git a/xla/stream_executor/gpu/gpu_test_kernels.cu.cc b/xla/stream_executor/gpu/gpu_test_kernels.cu.cc new file mode 100644 index 0000000000000..cab05701159ad --- /dev/null +++ b/xla/stream_executor/gpu/gpu_test_kernels.cu.cc @@ -0,0 +1,51 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_test_kernels.h" + +#include + +namespace stream_executor::gpu::internal { + +__global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + c[index] = a[index] + b[index]; +} + +__global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + c[index] = a[index] * b[index]; +} + +__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t value) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + pred[index] = counter[index] < value; + counter[index] += 1; +} + +__global__ void AddI32Ptrs3(Ptrs3 ptrs) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; +} + +void* GetAddI32Kernel() { return reinterpret_cast(&AddI32); } + +void* GetMulI32Kernel() { return reinterpret_cast(&MulI32); } + +void* GetIncAndCmpKernel() { return reinterpret_cast(&IncAndCmp); } + +void* GetAddI32Ptrs3Kernel() { return reinterpret_cast(&AddI32Ptrs3); } + +} // namespace stream_executor::gpu::internal diff --git a/xla/stream_executor/gpu/gpu_test_kernels.h b/xla/stream_executor/gpu/gpu_test_kernels.h new file mode 100644 index 0000000000000..74931452bb662 --- /dev/null +++ b/xla/stream_executor/gpu/gpu_test_kernels.h @@ -0,0 +1,106 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_H_ + +#include + +namespace stream_executor::gpu::internal { + +// This is a collection of gpu kernels for writing simple StreamExecutor tests. +// +// Some of the kernels available as pre-compiled PTX blobs (can be loaded with +// CUDA driver API) / HSACO modules (can be loaded with ROCM driver api), and +// some of the kernels are written directly in CUDA C++ and can be loaded from a +// symbol pointer (to test StreamExecutor CUDA runtime integration). + +#if !defined(TENSORFLOW_USE_ROCM) +// PTX kernel compiled from: +// +// __global__ void add(int* a, int* b, int* c) { +// int index = threadIdx.x + blockIdx.x * blockDim.x; +// c[index] = a[index] + b[index]; +// } +// +// Easiest way to get PTX from C++ is to use https://godbolt.org. +inline constexpr std::string_view kAddI32Kernel = R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.visible .entry add( + .param .u64 add_param_0, + .param .u64 add_param_1, + .param .u64 add_param_2 +) +{ + .reg .b32 %r<8>; + .reg .b64 %rd<11>; + .loc 1 1 0 + + ld.param.u64 %rd1, [add_param_0]; + ld.param.u64 %rd2, [add_param_1]; + ld.param.u64 %rd3, [add_param_2]; + .loc 1 3 3 + cvta.to.global.u64 %rd4, %rd3; + cvta.to.global.u64 %rd5, %rd2; + cvta.to.global.u64 %rd6, %rd1; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r2, %r3, %r1; + .loc 1 4 3 + mul.wide.s32 %rd7, %r4, 4; + add.s64 %rd8, %rd6, %rd7; + ld.global.u32 %r5, [%rd8]; + add.s64 %rd9, %rd5, %rd7; + ld.global.u32 %r6, [%rd9]; + add.s32 %r7, %r6, %r5; + add.s64 %rd10, %rd4, %rd7; + st.global.u32 [%rd10], %r7; + .loc 1 5 1 + ret; + +})"; +#else +#include "xla/stream_executor/rocm/add_i32_kernel.h" +#endif // !defined(TENSORFLOW_USE_ROCM) + +template +struct Ptrs3 { + T* a; + T* b; + T* c; +}; + +// Returns a pointer to device kernel compiled from the CUDA C++ code above. +void* GetAddI32Kernel(); + +// Returns a pointer to device kernel doing multiplication instead of addition. +void* GetMulI32Kernel(); + +// Returns a pointer to device kernel doing increment and compare, intended for +// testing on-device while loops. +void* GetIncAndCmpKernel(); + +// Returns a pointer to device kernel compiled from the CUDA C++ but with all +// three pointers passed to argument as an instance of `Ptr3` template to test +// StreamExecutor arguments packing for custom C++ types. +void* GetAddI32Ptrs3Kernel(); + +} // namespace stream_executor::gpu::internal + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_H_ diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index 6f557f9787c5f..e5cdab36a2457 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,18 +15,26 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" +#include #include #include #include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "absl/utility/utility.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "tsl/platform/status.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -43,9 +51,21 @@ absl::Duration RandomDuration() { return absl::Microseconds(distribution(rng)); } +bool ShouldLaunchDelayKernel() { + // Only launch the delay kernel if CUDA_LAUNCH_BLOCKING is not set to 1. + static bool value = [] { + const char* blocking = std::getenv("CUDA_LAUNCH_BLOCKING"); + return !blocking || std::string_view{blocking} != "1"; + }(); + return value; +} + } // namespace -/*static*/ tsl::StatusOr GpuTimer::Create(GpuStream* stream) { +/*deprecated*/ /*static*/ absl::StatusOr GpuTimer::Create( + GpuStream* stream) { + // This deprecated factory does not launch the delay kernel and may lead to + // reduced measurement accuracy. GpuExecutor* parent = stream->parent(); GpuContext* context = parent->gpu_context(); GpuEventHandle start_event; @@ -57,12 +77,14 @@ absl::Duration RandomDuration() { CHECK(start_event != nullptr && stop_event != nullptr); TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, stream->gpu_stream())); - return tsl::StatusOr{absl::in_place, parent, start_event, - stop_event, stream}; + return absl::StatusOr{absl::in_place, parent, start_event, + stop_event, stream}; } -/*static*/ tsl::StatusOr> GpuTimer::CreateIfNeeded( - GpuStream* stream, bool is_needed) { +/*deprecated*/ /*static*/ absl::StatusOr> +GpuTimer::CreateIfNeeded(GpuStream* stream, bool is_needed) { + // This deprecated factory does not launch the delay kernel and may lead to + // reduced measurement accuracy. if (is_needed) { TF_ASSIGN_OR_RETURN(GpuTimer t, GpuTimer::Create(stream)); return {std::make_optional(std::move(t))}; @@ -70,32 +92,134 @@ absl::Duration RandomDuration() { return std::nullopt; } +/*static*/ absl::StatusOr +GpuTimer::GpuSemaphore::Create(StreamExecutor* executor) { + // Allocate the value in pinned host memory that can be read from both + // host and device. + TF_ASSIGN_OR_RETURN(auto alloc, + executor->HostMemoryAllocate(sizeof(GpuSemaphoreState))); + return GpuSemaphore{std::move(alloc)}; +} + +DeviceMemory GpuTimer::GpuSemaphore::device() { + // This assumes unified addressing, as we do not explicitly translate the + // host pointer into a device pointer. + return DeviceMemory::MakeFromByteSize( + ptr_->opaque(), sizeof(GpuSemaphoreState)); +} + +/*static*/ absl::StatusOr GpuTimer::Create(Stream* real_stream, + bool use_delay_kernel) { + StreamExecutor* executor = real_stream->parent(); + GpuStream* stream = AsGpuStream(real_stream); + GpuExecutor* parent = stream->parent(); + GpuContext* context = parent->gpu_context(); + GpuEventHandle start_event; + TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, + GpuDriver::EventFlags::kDefault)); + GpuEventHandle stop_event; + TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, + GpuDriver::EventFlags::kDefault)); + CHECK(start_event != nullptr && stop_event != nullptr); + GpuSemaphore semaphore{}; + if (!use_delay_kernel) { + LOG(WARNING) + << "Skipping the delay kernel, measurement accuracy will be reduced"; + } + if (use_delay_kernel && ShouldLaunchDelayKernel()) { + // Check the assumption that this device supports unified addressing, + // otherwise skip the delay kernel + TF_ASSIGN_OR_RETURN(int status, GpuDriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, + parent->device())); + if (!status) { + LOG(WARNING) << "Skipping the delay kernel because the device does not " + "support unified addressing"; + } else { + // Allocate a semaphore value that will be used to signal to the delay + // kernel that it may exit. + TF_ASSIGN_OR_RETURN(semaphore, GpuSemaphore::Create(executor)); + *semaphore = GpuSemaphoreState::Hold; + // In principle the kernel could be loaded lazily and shared across + // multiple GpuTimer objects. + TF_ASSIGN_OR_RETURN( + auto kernel, + (TypedKernel, + GpuSemaphoreState>::Create(executor, "DelayKernel", + delay_kernel::kernel()))); + // Launch a delay kernel into this stream, which will spin until + // GetElapsedDuration() is called, the timer is destroyed, or the timeout + // in the kernel is reached. + TF_RETURN_IF_ERROR(real_stream->ThenLaunch( + ThreadDim(1, 1, 1), BlockDim(1, 1, 1), kernel, semaphore.device(), + GpuSemaphoreState::Release)); + } + } + // The start event goes after the delay kernel in the stream + TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, + stream->gpu_stream())); + return absl::StatusOr{absl::in_place, parent, start_event, + stop_event, stream, std::move(semaphore)}; +} + +/*static*/ absl::StatusOr> GpuTimer::CreateIfNeeded( + Stream* stream, bool use_delay_kernel, bool is_needed) { + if (is_needed) { + TF_ASSIGN_OR_RETURN(GpuTimer t, GpuTimer::Create(stream, use_delay_kernel)); + return {std::make_optional(std::move(t))}; + } + return std::nullopt; +} + /*static*/ void GpuTimer::ReturnRandomDurationsForTesting() { return_random_durations = true; } GpuTimer::~GpuTimer() { GpuContext* context = parent_->gpu_context(); + if (semaphore_ && !is_stopped_) { + // Signal the delay kernel that it can exit + *semaphore_ = GpuSemaphoreState::Release; + // Wait for the delay kernel to exit before destroying the value that it is + // watching. + absl::Status status = + GpuDriver::SynchronizeStream(context, stream_->gpu_stream()); + if (!status.ok()) { + LOG(ERROR) << status; + } + } if (start_event_ != nullptr) { - tsl::Status status = GpuDriver::DestroyEvent(context, &start_event_); + absl::Status status = GpuDriver::DestroyEvent(context, &start_event_); if (!status.ok()) { LOG(ERROR) << status; } } if (stop_event_ != nullptr) { - tsl::Status status = GpuDriver::DestroyEvent(context, &stop_event_); + absl::Status status = GpuDriver::DestroyEvent(context, &stop_event_); if (!status.ok()) { LOG(ERROR) << status; } } } -tsl::StatusOr GpuTimer::GetElapsedDuration() { +absl::StatusOr GpuTimer::GetElapsedDuration() { if (is_stopped_) { return absl::InternalError("Measuring inactive timer"); } TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent_->gpu_context(), stop_event_, stream_->gpu_stream())); + // If we launched the delay kernel then check if it already timed out. + if (semaphore_) { + if (*semaphore_ == GpuSemaphoreState::TimedOut) { + // The delay kernel did not achieve the intended result. + LOG(ERROR) << "Delay kernel timed out: measured time has sub-optimal " + "accuracy. There may be a missing warmup execution, please " + "investigate in Nsight Systems."; + } else { + // Signal that the kernel can exit + *semaphore_ = GpuSemaphoreState::Release; + } + } float elapsed_milliseconds = NAN; if (!GpuDriver::GetEventElapsedTime(parent_->gpu_context(), &elapsed_milliseconds, start_event_, diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index 5867b3ee93a41..da78988179792 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ limitations under the License. #include #include -#include "xla/stream_executor/gpu/gpu_driver.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" #include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" +#include "xla/stream_executor/gpu/gpu_timer_kernel.h" +#include "xla/stream_executor/gpu/gpu_types.h" namespace xla { namespace gpu { @@ -35,41 +37,81 @@ namespace gpu { class GpuExecutor; class GpuStream; -// Timer is started once it's created, and is stopped once read. +// When a timer is created it launches a delay kernel into the given stream and +// queues a start event immediately afterwards. This delay kernel blocks +// execution on the stream until GetElapsedDuration() is called, at which point +// an end event is queued and the delay kernel exits. This allows the device +// execution time of the tasks queued to the stream while the timer is active +// to be measured more accurately. class GpuTimer { public: - static tsl::StatusOr Create(GpuStream* stream); + class GpuSemaphore { + public: + GpuSemaphore() = default; + static absl::StatusOr Create(StreamExecutor* executor); + explicit operator bool() const { return bool{ptr_}; } + GpuSemaphoreState& operator*() { + return *static_cast(ptr_->opaque()); + } + DeviceMemory device(); + + private: + explicit GpuSemaphore(std::unique_ptr alloc) + : ptr_{std::move(alloc)} {} + std::unique_ptr ptr_; + }; + static absl::StatusOr Create(Stream* stream, bool use_delay_kernel); + [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr + Create(GpuStream* stream); // An ugly but a very convenient helper: creates a timer only when we need // one, but always returns an object. If `is_needed` is false, returns an // empty optional, acts like `Create` otherwise. - static tsl::StatusOr> CreateIfNeeded( - GpuStream* stream, bool is_needed); + static absl::StatusOr> CreateIfNeeded( + Stream* stream, bool use_delay_kernel, bool is_needed); + [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr< + std::optional> + CreateIfNeeded(GpuStream* stream, bool is_needed); explicit GpuTimer(GpuExecutor* parent, GpuEventHandle start_event, - GpuEventHandle stop_event, GpuStream* stream) + GpuEventHandle stop_event, GpuStream* stream, + GpuSemaphore semaphore = {}) : parent_(parent), start_event_(start_event), stop_event_(stop_event), - stream_(stream) {} + stream_(stream), + semaphore_(std::move(semaphore)) {} GpuTimer(GpuTimer&& other) : parent_(other.parent_), start_event_(std::exchange(other.start_event_, nullptr)), stop_event_(std::exchange(other.stop_event_, nullptr)), - stream_(other.stream_) {} + stream_(other.stream_), + semaphore_(std::move(other.semaphore_)) {} + + GpuTimer& operator=(GpuTimer&& other) { + if (this != &other) { + parent_ = other.parent_; + start_event_ = std::exchange(other.start_event_, nullptr); + stop_event_ = std::exchange(other.stop_event_, nullptr); + stream_ = other.stream_; + semaphore_ = std::move(other.semaphore_); + } + return *this; + } ~GpuTimer(); // Stops the timer on the first call and returns the elapsed duration. // Subsequent calls error out. - tsl::StatusOr GetElapsedDuration(); + absl::StatusOr GetElapsedDuration(); private: GpuExecutor* parent_; GpuEventHandle start_event_ = nullptr; GpuEventHandle stop_event_ = nullptr; GpuStream* stream_; + GpuSemaphore semaphore_; bool is_stopped_ = false; GpuTimer(const GpuTimer&) = delete; diff --git a/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc b/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc new file mode 100644 index 0000000000000..0ce4b1d9fbb32 --- /dev/null +++ b/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc @@ -0,0 +1,52 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/stream_executor/gpu/gpu_timer_kernel.h" + +#include + +namespace stream_executor::gpu { +namespace { +// Wait for the value pointed to by `semaphore` to have value `target`, timing +// out after approximately `APPROX_TIMEOUT_SECONDS` seconds if that value is +// not reached. This can happen if, for example, blocking launches are enabled +// via CUDA_LAUNCH_BLOCKING=1. It can also happen if launching a kernel after +// this delay kernel causes synchronisation, e.g. because of lazy loading. +__global__ void DelayKernel(volatile GpuSemaphoreState* semaphore, + GpuSemaphoreState target) { + constexpr int64_t WAIT_CYCLES{1024}; + constexpr int64_t TIMEOUT_CYCLES{200000000}; // 100ms at 2GHz + const int64_t tstart{clock64()}; + bool target_not_reached; + while ((target_not_reached = (*semaphore != target)) && + (clock64() - tstart) < TIMEOUT_CYCLES) { + int64_t elapsed{}; + const int64_t t0{clock64()}; + do { + elapsed = clock64() - t0; + } while (elapsed < WAIT_CYCLES); + } + if (target_not_reached) { + // We are exiting due to the timeout. Signal this back to the host so that + // we can emit a warning, as it probably indicates suboptimal usage. + *semaphore = GpuSemaphoreState::TimedOut; + } +} +} // namespace + +namespace delay_kernel { +void* kernel() { return reinterpret_cast(DelayKernel); } +} // namespace delay_kernel + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_timer_kernel.h b/xla/stream_executor/gpu/gpu_timer_kernel.h new file mode 100644 index 0000000000000..2ac358b4ee56c --- /dev/null +++ b/xla/stream_executor/gpu/gpu_timer_kernel.h @@ -0,0 +1,26 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ + +namespace stream_executor::gpu { +enum struct GpuSemaphoreState { Hold, Release, TimedOut }; +namespace delay_kernel { +void* kernel(); // returns a pointer to a CUDA C++ device function +} // namespace delay_kernel +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ diff --git a/xla/stream_executor/gpu/gpu_types.h b/xla/stream_executor/gpu/gpu_types.h index 18562fad63728..c8d6266b35dfe 100644 --- a/xla/stream_executor/gpu/gpu_types.h +++ b/xla/stream_executor/gpu/gpu_types.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/gpu/memcpy_test.cc b/xla/stream_executor/gpu/memcpy_test.cc new file mode 100644 index 0000000000000..96b7700ce3353 --- /dev/null +++ b/xla/stream_executor/gpu/memcpy_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { + +TEST(MemcpyTest, PinnedHostMemory) { +#if GOOGLE_CUDA + Platform* platform = PlatformManager::PlatformWithName("CUDA").value(); +#elif TENSORFLOW_USE_ROCM + Platform* platform = PlatformManager::PlatformWithName("ROCM").value(); +#endif + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + auto stream = executor->CreateStream().value(); + + TF_ASSERT_OK_AND_ASSIGN(auto d_ptr, + executor->HostMemoryAllocate(sizeof(int))); + DeviceMemoryBase d_mem(d_ptr->opaque(), sizeof(int)); + + int h_ptr; + TF_ASSERT_OK(stream->Memcpy(&h_ptr, d_mem, d_mem.size())); + EXPECT_TRUE(stream->BlockHostUntilDone().ok()); +} + +} // namespace stream_executor diff --git a/xla/stream_executor/gpu/redzone_allocator.cc b/xla/stream_executor/gpu/redzone_allocator.cc index ceee5503e29b7..5f08f84c2d86b 100644 --- a/xla/stream_executor/gpu/redzone_allocator.cc +++ b/xla/stream_executor/gpu/redzone_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,25 +15,35 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" +#include #include #include +#include #include +#include +#include #include "absl/base/call_once.h" #include "absl/container/fixed_array.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/asm_compiler.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/framework/allocator.h" +#include "tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -52,7 +62,7 @@ using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; RedzoneAllocator::RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts gpu_compilation_opts, + const GpuAsmOpts& gpu_compilation_opts, int64_t memory_limit, int64_t redzone_size, uint8_t redzone_pattern) : device_ordinal_(stream->parent()->device_ordinal()), @@ -65,15 +75,13 @@ RedzoneAllocator::RedzoneAllocator(Stream* stream, memory_allocator_(memory_allocator), gpu_compilation_opts_(gpu_compilation_opts) {} -tsl::StatusOr> RedzoneAllocator::AllocateBytes( +absl::StatusOr> RedzoneAllocator::AllocateBytes( int64_t byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes()) { - return tsl::Status( - absl::StatusCode::kResourceExhausted, - absl::StrFormat( - "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes())); + return absl::ResourceExhaustedError(absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, + GetMemoryLimitInBytes())); } int64_t rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size; @@ -87,33 +95,34 @@ tsl::StatusOr> RedzoneAllocator::AllocateBytes( static_assert(sizeof(uint8_t) == 1, "Unexpected size"); DeviceMemory allocated_buffer_memory(*allocated_buffer); - DeviceMemory lhs_redzone = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, 0, redzone_size_); + DeviceMemory lhs_redzone = + allocated_buffer_memory.GetSlice(0, redzone_size_); - DeviceMemory data_chunk = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_, byte_size); + DeviceMemory data_chunk = + allocated_buffer_memory.GetSlice(redzone_size_, byte_size); // Split up the RHS redzone into two pieces: // - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by // - redzone_size_ bytes. - // We do this because Stream::ThenMemset32 requires the buffer address and + // We do this because Stream::Memset32 requires the buffer address and // size to be aligned to 4 bytes. - DeviceMemory rhs_redzone_slop = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop); + DeviceMemory rhs_redzone_slop = + allocated_buffer_memory.GetSlice(redzone_size_ + byte_size, rhs_slop); - DeviceMemory rhs_redzone_nonslop = stream_->parent()->GetSubBuffer( - &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop, - redzone_size_); + DeviceMemory rhs_redzone_nonslop = allocated_buffer_memory.GetSlice( + redzone_size_ + byte_size + rhs_slop, redzone_size_); uint8_t pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_, redzone_pattern_}; uint32_t pattern32; std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); - stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); + TF_RETURN_IF_ERROR(stream_->Memset32(&lhs_redzone, pattern32, redzone_size_)); if (rhs_slop != 0) { - stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop); + TF_RETURN_IF_ERROR( + stream_->Memcpy(&rhs_redzone_slop, &pattern32, rhs_slop)); } - stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_); + TF_RETURN_IF_ERROR( + stream_->Memset32(&rhs_redzone_nonslop, pattern32, redzone_size_)); allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size); return data_chunk; @@ -183,13 +192,13 @@ using ComparisonKernelT = TypedKernel, uint8_t, uint64_t, // Check that redzones weren't overwritten on a host. // // Slower, but gives a more useful error message. -static tsl::StatusOr CheckRedzoneHost( +static absl::StatusOr CheckRedzoneHost( DeviceMemoryBase redzone, DeviceMemoryBase user_allocation, absl::string_view name, Stream* stream, uint8_t redzone_pattern) { uint64_t size = redzone.size(); auto redzone_data = std::make_unique(size); - TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size) - .BlockHostUntilDone()); + TF_RETURN_IF_ERROR(stream->Memcpy(redzone_data.get(), redzone, size)); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); std::array pattern_arr; pattern_arr.fill(redzone_pattern); @@ -217,14 +226,14 @@ static tsl::StatusOr CheckRedzoneHost( // Run the redzone checker on the provided buffer redzone. // // Increment out_param if mismatch occurs. -static tsl::Status RunRedzoneChecker( +static absl::Status RunRedzoneChecker( Stream* stream, const DeviceMemory& redzone, uint8_t redzone_pattern, const DeviceMemory& out_param, const ComparisonKernelT& comparison_kernel) { StreamExecutor* executor = stream->parent(); if (redzone.size() == 0) { - return tsl::OkStatus(); + return absl::OkStatus(); } int64_t num_elements = redzone.size(); @@ -236,31 +245,32 @@ static tsl::Status RunRedzoneChecker( TF_RETURN_IF_ERROR(stream->ThenLaunch( ThreadDim(threads_per_block), BlockDim(block_count), comparison_kernel, redzone, redzone_pattern, redzone.size(), out_param)); - return ::tsl::OkStatus(); + return absl::OkStatus(); } // Since we reuse the same buffer for multiple checks, we re-initialize redzone // with a NaN pattern after a failed check. // // This function is blocking, since redzone failing is a rare event. -static tsl::Status ReinitializeRedzone(Stream* stream, DeviceMemoryBase redzone, - uint8_t redzone_pattern) { +static absl::Status ReinitializeRedzone(Stream* stream, + DeviceMemoryBase redzone, + uint8_t redzone_pattern) { absl::FixedArray redzone_array(redzone.size()); redzone_array.fill(redzone_pattern); - stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size()); + TF_RETURN_IF_ERROR( + stream->Memcpy(&redzone, redzone_array.data(), redzone.size())); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - return ::tsl::OkStatus(); + return absl::OkStatus(); } // Check redzones around the user allocation. // // Precondition: the memory pointed out by out_param is zeroed. -static tsl::StatusOr CheckRedzonesForBuffer( +static absl::StatusOr CheckRedzonesForBuffer( Stream* stream, DeviceMemoryBase memory, const DeviceMemory& out_param, const ComparisonKernelT& comparison_kernel, int64_t user_allocation_size, uint64_t redzone_size, uint8_t redzone_pattern) { - StreamExecutor* executor = stream->parent(); int64_t rhs_slop = RoundUpToNearest(user_allocation_size, kRhsRedzoneAlign) - user_allocation_size; @@ -268,14 +278,14 @@ static tsl::StatusOr CheckRedzonesForBuffer( DeviceMemory buffer_uint8(memory); DeviceMemory lhs_redzone = - executor->GetSubBuffer(&buffer_uint8, 0, - /*element_count=*/redzone_size); + buffer_uint8.GetSlice(0, + /*element_count=*/redzone_size); DeviceMemory user_allocation = - executor->GetSubBuffer(&buffer_uint8, redzone_size, - /*element_count=*/user_allocation_size); + buffer_uint8.GetSlice(redzone_size, + /*element_count=*/user_allocation_size); DeviceMemory rhs_redzone = - executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size, - /*element_count=*/redzone_size + rhs_slop); + buffer_uint8.GetSlice(redzone_size + user_allocation_size, + /*element_count=*/redzone_size + rhs_slop); TF_RETURN_IF_ERROR(RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param, comparison_kernel)); @@ -283,7 +293,7 @@ static tsl::StatusOr CheckRedzonesForBuffer( out_param, comparison_kernel)); int64_t result; CHECK_EQ(out_param.size(), sizeof(result)); - stream->ThenMemcpy(&result, out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->Memcpy(&result, out_param, sizeof(result))); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); if (result != 0) { @@ -307,11 +317,12 @@ static tsl::StatusOr CheckRedzonesForBuffer( return RedzoneCheckStatus::OK(); } -tsl::StatusOr RedzoneAllocator::CheckRedzones() const { +absl::StatusOr RedzoneAllocator::CheckRedzones() const { StreamExecutor* executor = stream_->parent(); +#if GOOGLE_CUDA absl::Span compiled_ptx = {}; - tsl::StatusOr> compiled_ptx_or = + absl::StatusOr> compiled_ptx_or = CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx, gpu_compilation_opts_); if (compiled_ptx_or.ok()) { @@ -326,30 +337,30 @@ tsl::StatusOr RedzoneAllocator::CheckRedzones() const { }); } - ScopedDeviceMemory out_param = - executor->AllocateOwnedScalar(); - stream_->ThenMemZero(out_param.ptr(), sizeof(uint64_t)); - -#if GOOGLE_CUDA TF_ASSIGN_OR_RETURN( - std::shared_ptr loaded_kernel, + ComparisonKernelT * kernel_ptr, (LoadKernelOrGetPtr, uint8_t, uint64_t, DeviceMemory>( executor, "redzone_checker", redzone_checker_ptx, compiled_ptx))); #elif TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( - std::unique_ptr loaded_kernel, - (executor->CreateTypedKernel, uint8, uint64_t, - DeviceMemory>( - "redzone_checker", redzone_checker_ptx, compiled_ptx))); + ComparisonKernelT loaded_kernel, + (TypedKernel, uint8, uint64_t, + DeviceMemory>::Create(executor, "redzone_checker", + kernel_symbol()))); + // CUDA side returns a pointer => hence get a pointer to the loaded kernel + auto* kernel_ptr = &loaded_kernel; #endif // GOOGLE_CUDA + auto out_param = executor->AllocateOwnedScalar(); + TF_RETURN_IF_ERROR(stream_->MemZero(out_param.ptr(), sizeof(uint64_t))); + for (const auto& buf_and_size : allocated_buffers_) { TF_ASSIGN_OR_RETURN( RedzoneCheckStatus redzone_status, CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(), - *loaded_kernel, buf_and_size.second, - redzone_size_, redzone_pattern_)); + *kernel_ptr, buf_and_size.second, redzone_size_, + redzone_pattern_)); if (!redzone_status.ok()) { return redzone_status; } diff --git a/xla/stream_executor/gpu/redzone_allocator.cu.cc b/xla/stream_executor/gpu/redzone_allocator.cu.cc new file mode 100644 index 0000000000000..d6a5108ef37ca --- /dev/null +++ b/xla/stream_executor/gpu/redzone_allocator.cu.cc @@ -0,0 +1,44 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/redzone_allocator.h" + + +namespace stream_executor { + +namespace { +#if TENSORFLOW_USE_ROCM + +__global__ void redzone_checker_kernel(uint8_t* input_buffer, + uint8_t redzone_pattern, + uint64_t buffer_length, + int* out_mismatched_ptr) { + uint64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= buffer_length) return; + if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +} + +#endif +} // namespace + +void* RedzoneAllocator::kernel_symbol() const { +#if TENSORFLOW_USE_ROCM + return reinterpret_cast(&redzone_checker_kernel); +#else + return nullptr; +#endif +} + +} // namespace stream_executor diff --git a/xla/stream_executor/gpu/redzone_allocator.h b/xla/stream_executor/gpu/redzone_allocator.h index 43694d3295c38..8fcac1ea2b767 100644 --- a/xla/stream_executor/gpu/redzone_allocator.h +++ b/xla/stream_executor/gpu/redzone_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,14 +17,16 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_REDZONE_ALLOCATOR_H_ #include +#include +#include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/asm_compiler.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/math/math_util.h" namespace stream_executor { @@ -45,7 +47,7 @@ class RedzoneAllocator : public ScratchAllocator { 1LL << 23; // 8MiB per side, 16MiB total. static constexpr uint8_t kDefaultRedzonePattern = -1; // NOLINT RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, - GpuAsmOpts gpu_compilation_opts_, + const GpuAsmOpts& gpu_compilation_opts_, int64_t memory_limit = (1LL << 32), // 4GB int64_t redzone_size = kDefaultRedzoneSize, uint8_t redzone_pattern = kDefaultRedzonePattern); @@ -57,7 +59,7 @@ class RedzoneAllocator : public ScratchAllocator { return allocated_bytes_excluding_redzones_; } - tsl::StatusOr> AllocateBytes(int64_t byte_size) override; + absl::StatusOr> AllocateBytes(int64_t byte_size) override; // Non-empty redzone check status implies that there was a write into a // redzone, with a string communicating the location of the write. @@ -97,10 +99,13 @@ class RedzoneAllocator : public ScratchAllocator { // - RedzoneCheckStatus with a non-empty error message iff a write into a // redzone has been detected. // - A stream error, if loading or launching the kernel has failed. - tsl::StatusOr CheckRedzones() const; + absl::StatusOr CheckRedzones() const; Stream* stream() const { return stream_; } + // Return a pointer to in-process kernel symbol (used to check redzones). + void* kernel_symbol() const; + private: const int device_ordinal_; Stream* stream_; diff --git a/xla/stream_executor/cuda/redzone_allocator_test.cc b/xla/stream_executor/gpu/redzone_allocator_test.cc similarity index 79% rename from xla/stream_executor/cuda/redzone_allocator_test.cc rename to xla/stream_executor/gpu/redzone_allocator_test.cc index 8b496b35a1cff..6c5ee154d1502 100644 --- a/xla/stream_executor/cuda/redzone_allocator_test.cc +++ b/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,31 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef GOOGLE_CUDA - #include "xla/stream_executor/gpu/redzone_allocator.h" #include +#include +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/multi_platform_manager.h" +#include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor { -namespace cuda { -namespace { +namespace gpu { using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; -static void EXPECT_REDZONE_OK(tsl::StatusOr status) { +static void EXPECT_REDZONE_OK(absl::StatusOr status) { EXPECT_TRUE(status.ok()); EXPECT_TRUE(status.value().ok()); } -static void EXPECT_REDZONE_VIOLATION(tsl::StatusOr status) { +static void EXPECT_REDZONE_VIOLATION( + absl::StatusOr status) { EXPECT_TRUE(status.ok()); EXPECT_FALSE(status.value().ok()); } @@ -51,14 +56,14 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { // Allocate 32MiB + 1 byte (to make things misaligned) constexpr int64_t kAllocSize = (1 << 25) + 1; - Platform* platform = MultiPlatformManager::PlatformWithName("cuda").value(); + Platform* platform = + PlatformManager::PlatformWithName(GpuPlatformName()).value(); StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); GpuAsmOpts opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); - Stream stream(stream_exec); - stream.Init(); - RedzoneAllocator allocator(&stream, &se_allocator, opts, + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream()); + RedzoneAllocator allocator(stream.get(), &se_allocator, opts, /*memory_limit=*/(1LL << 32), /*redzone_size=*/kRedzoneSize, /*redzone_pattern=*/kRedzonePattern); @@ -73,8 +78,8 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { // Check that the redzones are in fact filled with kRedzonePattern. auto check_redzone = [&](DeviceMemoryBase redzone, absl::string_view name) { std::vector host_buf(kRedzoneSize); - TF_ASSERT_OK(stream.ThenMemcpy(host_buf.data(), redzone, kRedzoneSize) - .BlockHostUntilDone()); + TF_ASSERT_OK(stream->Memcpy(host_buf.data(), redzone, kRedzoneSize)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); const int64_t kMaxMismatches = 16; int64_t mismatches = 0; for (int64_t i = 0; i < host_buf.size(); ++i) { @@ -102,8 +107,8 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { reinterpret_cast(redzone.opaque()) + offset, 1); char old_redzone_value = 0; { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } - stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1) - .ThenMemZero(&redzone_at_offset, 1); + TF_ASSERT_OK(stream->Memcpy(&old_redzone_value, redzone_at_offset, 1)); + TF_ASSERT_OK(stream->MemZero(&redzone_at_offset, 1)); EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones()); // Checking reinitializes the redzone. @@ -124,13 +129,13 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { TEST(RedzoneAllocatorTest, VeryLargeRedzone) { // Make sure the redzone size would require grid dimension > 65535. constexpr int64_t kRedzoneSize = 65535 * 1024 + 1; - Platform* platform = MultiPlatformManager::PlatformWithName("cuda").value(); + Platform* platform = + PlatformManager::PlatformWithName(GpuPlatformName()).value(); StreamExecutor* stream_exec = platform->ExecutorForDevice(0).value(); GpuAsmOpts opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); - Stream stream(stream_exec); - stream.Init(); - RedzoneAllocator allocator(&stream, &se_allocator, opts, + TF_ASSERT_OK_AND_ASSIGN(auto stream, stream_exec->CreateStream()); + RedzoneAllocator allocator(stream.get(), &se_allocator, opts, /*memory_limit=*/(1LL << 32), /*redzone_size=*/kRedzoneSize, /*redzone_pattern=*/-1); @@ -138,8 +143,5 @@ TEST(RedzoneAllocatorTest, VeryLargeRedzone) { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } -} // namespace -} // namespace cuda +} // namespace gpu } // namespace stream_executor - -#endif // GOOGLE_CUDA diff --git a/xla/stream_executor/gpu/stream_search_test.cc b/xla/stream_executor/gpu/stream_search_test.cc new file mode 100644 index 0000000000000..9f91c63ae0d97 --- /dev/null +++ b/xla/stream_executor/gpu/stream_search_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/status/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace { + +class StreamSearchTest : public ::testing::Test { + public: + Platform* GetPlatform() { +#if GOOGLE_CUDA + return *PlatformManager::PlatformWithName("CUDA"); +#elif TENSORFLOW_USE_ROCM + return *PlatformManager::PlatformWithName("ROCM"); +#endif + } +}; + +TEST_F(StreamSearchTest, NoMatchBadPtr) { + void* bad_ptr = reinterpret_cast(0xdeadbeef); + + StreamExecutorConfig config; + config.gpu_stream = bad_ptr; + + absl::StatusOr found_executor = + GetPlatform()->GetExecutor(config); + + // No executor found. + EXPECT_FALSE(found_executor.ok()); +} + +TEST_F(StreamSearchTest, FoundPrevExecutor) { + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + GetPlatform()->ExecutorForDevice(0)); + + TF_ASSERT_OK_AND_ASSIGN(auto s, executor->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto s2, executor->CreateStream()); + + void* gpu_ptr = s->platform_specific_handle().stream; + void* gpu_ptr_2 = s2->platform_specific_handle().stream; + + StreamExecutorConfig c; + c.gpu_stream = gpu_ptr; + + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * found_executor, + GetPlatform()->GetExecutor(c)); + EXPECT_EQ(found_executor, executor); + + Stream* found1 = found_executor->FindAllocatedStream(gpu_ptr); + EXPECT_EQ(found1, s.get()); + + Stream* found2 = found_executor->FindAllocatedStream(gpu_ptr_2); + EXPECT_EQ(found2, s2.get()); +} + +} // namespace +} // namespace stream_executor diff --git a/xla/stream_executor/host/BUILD b/xla/stream_executor/host/BUILD index 16ef8bb031d55..1915abdc4a501 100644 --- a/xla/stream_executor/host/BUILD +++ b/xla/stream_executor/host/BUILD @@ -1,14 +1,14 @@ # Description: # Host-platform specific StreamExecutor support code. -load("//xla:xla.bzl", "xla_cc_test") -load("@tsl//tsl:tsl.bzl", "set_external_visibility") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -40,16 +40,18 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":host_gpu_executor", + ":host_executor", ":host_platform_id", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/platform", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", ], - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( @@ -61,35 +63,71 @@ cc_library( "host_stream.h", ], deps = [ - "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:denormal", "@tsl//tsl/platform:env", "@tsl//tsl/platform:setround", + "@tsl//tsl/platform:thread_annotations", ], ) -# TODO(22689637): Rename this target. cc_library( - name = "host_gpu_executor", + name = "host_kernel_c_api", + hdrs = ["host_kernel_c_api.h"], +) + +cc_library( + name = "host_kernel", + srcs = ["host_kernel.cc"], + hdrs = ["host_kernel.h"], + deps = [ + ":host_kernel_c_api", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "host_kernel_test", + srcs = ["host_kernel_test.cc"], + deps = [ + ":host_kernel", + ":host_kernel_c_api", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_executor", srcs = [ - "host_gpu_executor.cc", + "host_executor.cc", ], hdrs = [ - "host_gpu_executor.h", + "host_executor.h", ], deps = [ - ":host_platform_id", ":host_stream", "//xla/stream_executor", - "//xla/stream_executor:plugin_registry", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform/profile_utils:profile_utils_cpu_utils", ], @@ -102,8 +140,9 @@ xla_cc_test( deps = [ ":host_platform", "//xla/stream_executor", - "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", diff --git a/xla/stream_executor/host/host_executor.cc b/xla/stream_executor/host/host_executor.cc new file mode 100644 index 0000000000000..33d54a16f3fee --- /dev/null +++ b/xla/stream_executor/host/host_executor.cc @@ -0,0 +1,273 @@ +/* Copyright 2016 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of HostExecutor class [of those methods not defined in the +// class declaration]. +#include "xla/stream_executor/host/host_executor.h" + +#include +#include + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/notification.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/host/host_stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/mem.h" +#include "tsl/platform/profile_utils/cpu_utils.h" + +namespace stream_executor { +namespace host { + +HostStream* AsHostStream(Stream* stream) { + DCHECK(stream != nullptr); + return dynamic_cast(stream->implementation()); +} + +absl::Status HostExecutor::Init(int device_ordinal) { return absl::OkStatus(); } + +bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { + tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo(); + *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1; + *total = (mem_info.total != INT64_MAX) ? mem_info.total : -1; + return true; +} + +DeviceMemoryBase HostExecutor::Allocate(uint64_t size, int64_t memory_space) { + CHECK_EQ(memory_space, 0); + // Use a minimum alignment of 64 bytes to be friendly to AVX512 code. + // This should probably be kept in sync with + // tsl::Allocator::kAllocatorAlignment. + return DeviceMemoryBase( + tsl::port::AlignedMalloc(size, /*minimum_alignment=*/64), size); +} + +void HostExecutor::Deallocate(DeviceMemoryBase* mem) { + tsl::port::AlignedFree(mem->opaque()); +} + +absl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) { + memset(location->opaque(), 0, size); + return absl::OkStatus(); +} + +absl::Status HostExecutor::SynchronousMemSet(DeviceMemoryBase* location, + int value, uint64_t size) { + memset(location->opaque(), value, size); + return absl::OkStatus(); +} + +absl::Status HostExecutor::Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated + // with the HostExecutor. + void* src_mem = const_cast(gpu_src.opaque()); + AsHostStream(stream)->EnqueueTask( + [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); + return absl::OkStatus(); +} + +absl::Status HostExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) { + void* dst_mem = gpu_dst->opaque(); + // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated + // with the HostExecutor. + AsHostStream(stream)->EnqueueTask( + [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); + return absl::OkStatus(); +} + +bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, + DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + void* dst_mem = gpu_dst->opaque(); + void* src_mem = const_cast(gpu_src.opaque()); + // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given + // the nature of the HostExecutor) memcpy on the stream (HostStream) + // associated with the HostExecutor. + AsHostStream(stream)->EnqueueTask( + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); + return true; +} + +absl::Status HostExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) { + void* gpu_mem = location->opaque(); + // Enqueue the [asynchronous] memzero on the stream (HostStream) associated + // with the HostExecutor. + AsHostStream(stream)->EnqueueTask( + [gpu_mem, size]() { memset(gpu_mem, 0, size); }); + return absl::OkStatus(); +} + +absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, + uint8 pattern, uint64_t size) { + void* gpu_mem = location->opaque(); + // Enqueue the [asynchronous] memzero on the stream (HostStream) associated + // with the HostExecutor. + AsHostStream(stream)->EnqueueTask( + [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); + return absl::OkStatus(); +} + +absl::Status HostExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) { + void* gpu_mem = location->opaque(); + // Enqueue the [asynchronous] memzero on the stream (HostStream) associated + // with the HostExecutor. + AsHostStream(stream)->EnqueueTask( + [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); + return absl::OkStatus(); +} + +absl::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, + uint64_t size) { + memcpy(gpu_dst->opaque(), host_src, size); + return absl::OkStatus(); +} + +absl::Status HostExecutor::SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + memcpy(host_dst, gpu_src.opaque(), size); + return absl::OkStatus(); +} + +absl::Status HostExecutor::SynchronousMemcpyDeviceToDevice( + DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { + memcpy(gpu_dst->opaque(), gpu_src.opaque(), size); + return absl::OkStatus(); +} + +bool HostExecutor::HostCallback( + Stream* stream, absl::AnyInvocable callback) { + AsHostStream(stream)->EnqueueTaskWithStatus(std::move(callback)); + return true; +} + +bool HostExecutor::AllocateStream(Stream* stream) { return true; } + +void HostExecutor::DeallocateStream(Stream* stream) {} + +bool HostExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { + auto event = std::make_shared(); + AsHostStream(other)->EnqueueTask([event]() { event->Notify(); }); + AsHostStream(dependent)->EnqueueTask( + [event]() { event->WaitForNotification(); }); + return true; +} + +class HostEvent : public internal::EventInterface { + public: + HostEvent() : notification_(std::make_shared()) {} + + std::shared_ptr& notification() { return notification_; } + + private: + // We use a std::shared_ptr here because the client may delete the HostEvent + // object while there are still RecordEvent and WaitForEvent callbacks pending + // on a stream. + std::shared_ptr notification_; +}; + +std::unique_ptr +HostExecutor::CreateEventImplementation() { + return std::unique_ptr(new HostEvent()); +} + +static HostEvent* AsHostEvent(Event* event) { + DCHECK(event != nullptr); + return static_cast(event->implementation()); +} + +absl::Status HostExecutor::AllocateEvent(Event* /*event*/) { + return absl::OkStatus(); +} + +absl::Status HostExecutor::DeallocateEvent(Event* /*event*/) { + return absl::OkStatus(); +} + +absl::Status HostExecutor::RecordEvent(Stream* stream, Event* event) { + std::shared_ptr notification = + AsHostEvent(event)->notification(); + AsHostStream(stream)->EnqueueTask([notification]() { + CHECK(!notification->HasBeenNotified()); + notification->Notify(); + }); + return absl::OkStatus(); +} + +absl::Status HostExecutor::WaitForEvent(Stream* stream, Event* event) { + std::shared_ptr notification = + AsHostEvent(event)->notification(); + AsHostStream(stream)->EnqueueTask( + [notification]() { notification->WaitForNotification(); }); + return absl::OkStatus(); +} + +Event::Status HostExecutor::PollForEventStatus(Event* event) { + absl::Notification& notification = *AsHostEvent(event)->notification(); + return notification.HasBeenNotified() ? Event::Status::kComplete + : Event::Status::kPending; +} + +absl::Status HostExecutor::BlockHostUntilDone(Stream* stream) { + return AsHostStream(stream)->BlockUntilDone(); +} + +absl::StatusOr> +HostExecutor::CreateDeviceDescription(int device_ordinal) { + internal::DeviceDescriptionBuilder builder; + + builder.set_device_address_bits(64); + + // TODO(rspringer): How to report a value that's based in reality but that + // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily. + builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); + + float cycle_counter_frequency = static_cast( + tsl::profile_utils::CpuUtils::GetCycleCounterFrequency()); + builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9); + + builder.set_name("Host"); + builder.set_platform_version("Default Version"); + + return builder.Build(); +} + +std::unique_ptr +HostExecutor::GetStreamImplementation() { + return std::make_unique(); +} + +} // namespace host +} // namespace stream_executor diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h new file mode 100644 index 0000000000000..6123e227591fa --- /dev/null +++ b/xla/stream_executor/host/host_executor.h @@ -0,0 +1,150 @@ +/* Copyright 2016 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declares the HostExecutor class, which is a CPU-only implementation of +// the StreamExecutor interface. For now, this is used for testing and to +// examine the performance of host-based StreamExecutor code. +#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ +#define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_internal.h" + +namespace stream_executor { +namespace host { + +// An implementation of StreamExecutor that does no communication or interaction +// with a device, but DOES perform memory operations backed by the host. +// Kernel invocations will fail, but host callbacks may be enqueued on this +// executor and its associated stream, and should follow standard ordering +// semantics. +// +// This is useful for evaluating the performance of host-based or fallback +// routines executed under the context of a GPU executor. +// See stream_executor.h for description of the below operations. +class HostExecutor : public internal::StreamExecutorInterface { + public: + HostExecutor() = default; + + absl::Status Init(int device_ordinal) override; + + absl::Status GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) override { + return absl::UnimplementedError("Not Implemented"); + } + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& kernel, + const KernelArgs& args) override { + return absl::UnimplementedError("Not Implemented"); + } + + DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; + void Deallocate(DeviceMemoryBase* mem) override; + + void* HostMemoryAllocate(uint64_t size) override { return new char[size]; } + void HostMemoryDeallocate(void* mem) override { + delete[] static_cast(mem); + } + bool HostMemoryRegister(void* mem, uint64_t size) override { return true; } + bool HostMemoryUnregister(void* mem) override { return true; } + + absl::Status Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override; + absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) override; + bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) override; + + absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) override; + absl::Status Memset(Stream* stream, DeviceMemoryBase* location, + uint8_t pattern, uint64_t size) override; + absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) override; + + // No "synchronize all activity" implemented for this platform at the moment. + bool SynchronizeAllActivity() override { return true; } + absl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) override; + + absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64_t size) override; + + absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) override; + absl::Status SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) override; + absl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) override; + + bool HostCallback(Stream* stream, + absl::AnyInvocable callback) override; + + absl::Status AllocateEvent(Event* event) override; + absl::Status DeallocateEvent(Event* event) override; + absl::Status RecordEvent(Stream* stream, Event* event) override; + absl::Status WaitForEvent(Stream* stream, Event* event) override; + Event::Status PollForEventStatus(Event* event) override; + + bool AllocateStream(Stream* stream) override; + void DeallocateStream(Stream* stream) override; + bool CreateStreamDependency(Stream* dependent, Stream* other) override; + + absl::Status BlockHostUntilDone(Stream* stream) override; + + bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; + + absl::StatusOr> CreateDeviceDescription() + const override { + return CreateDeviceDescription(0); + } + + static absl::StatusOr> + CreateDeviceDescription(int device_ordinal); + + absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { + return absl::OkStatus(); + } + + bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { + return true; + } + + std::unique_ptr CreateEventImplementation() + override; + + std::unique_ptr GetStreamImplementation() override; +}; + +} // namespace host +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ diff --git a/xla/stream_executor/host/host_gpu_executor.cc b/xla/stream_executor/host/host_gpu_executor.cc deleted file mode 100644 index 254d2ac02c4d3..0000000000000 --- a/xla/stream_executor/host/host_gpu_executor.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Implementation of HostExecutor class [of those methods not defined in the -// class declaration]. -#include "xla/stream_executor/host/host_gpu_executor.h" - -#include -#include - -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/notification.h" -#include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/host/host_stream.h" -#include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/mem.h" -#include "tsl/platform/profile_utils/cpu_utils.h" - -namespace stream_executor { -namespace host { - -HostStream* AsHostStream(Stream* stream) { - DCHECK(stream != nullptr); - return dynamic_cast(stream->implementation()); -} - -tsl::Status HostExecutor::Init(int device_ordinal, - DeviceOptions device_options) { - auto it = - device_options.non_portable_tags.find("host_thread_stack_size_in_bytes"); - if (it != device_options.non_portable_tags.end()) { - if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) { - return tsl::errors::InvalidArgument( - "Unable to parse host_thread_stack_size_in_bytes as an integer: ", - it->second); - } - } - return ::tsl::OkStatus(); -} - -bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo(); - *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1; - *total = (mem_info.total != INT64_MAX) ? mem_info.total : -1; - return true; -} - -DeviceMemoryBase HostExecutor::Allocate(uint64_t size, int64_t memory_space) { - CHECK_EQ(memory_space, 0); - // Use a minimum alignment of 64 bytes to be friendly to AVX512 code. - // This should probably be kept in sync with - // tsl::Allocator::kAllocatorAlignment. - return DeviceMemoryBase( - tsl::port::AlignedMalloc(size, /*minimum_alignment=*/64), size); -} - -void* HostExecutor::GetSubBuffer(DeviceMemoryBase* parent, - uint64_t offset_bytes, uint64_t size_bytes) { - return reinterpret_cast(parent->opaque()) + offset_bytes; -} - -void HostExecutor::Deallocate(DeviceMemoryBase* mem) { - tsl::port::AlignedFree(mem->opaque()); -} - -tsl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) { - memset(location->opaque(), 0, size); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::SynchronousMemSet(DeviceMemoryBase* location, - int value, uint64_t size) { - memset(location->opaque(), value, size); - return ::tsl::OkStatus(); -} - -bool HostExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) { - // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated - // with the HostExecutor. - void* src_mem = const_cast(gpu_src.opaque()); - AsHostStream(stream)->EnqueueTask( - [host_dst, src_mem, size]() { memcpy(host_dst, src_mem, size); }); - return true; -} - -bool HostExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - void* dst_mem = gpu_dst->opaque(); - // Enqueue the [asynchronous] memcpy on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [dst_mem, host_src, size]() { memcpy(dst_mem, host_src, size); }); - return true; -} - -bool HostExecutor::MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - void* dst_mem = gpu_dst->opaque(); - void* src_mem = const_cast(gpu_src.opaque()); - // Enqueue this [asynchronous] "device-to-device" (i.e., host-to-host, given - // the nature of the HostExecutor) memcpy on the stream (HostStream) - // associated with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); - return true; -} - -tsl::Status HostExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size]() { memset(gpu_mem, 0, size); }); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, - uint64_t size) { - memcpy(gpu_dst->opaque(), host_src, size); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::SynchronousMemcpy(void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - memcpy(host_dst, gpu_src.opaque(), size); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::SynchronousMemcpyDeviceToDevice( - DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { - memcpy(gpu_dst->opaque(), gpu_src.opaque(), size); - return ::tsl::OkStatus(); -} - -bool HostExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { - AsHostStream(stream)->EnqueueTaskWithStatus(std::move(callback)); - return true; -} - -bool HostExecutor::AllocateStream(Stream* stream) { return true; } - -void HostExecutor::DeallocateStream(Stream* stream) {} - -bool HostExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { - auto event = std::make_shared(); - AsHostStream(other)->EnqueueTask([event]() { event->Notify(); }); - AsHostStream(dependent)->EnqueueTask( - [event]() { event->WaitForNotification(); }); - return true; -} - -class HostEvent : public internal::EventInterface { - public: - HostEvent() : notification_(std::make_shared()) {} - - std::shared_ptr& notification() { return notification_; } - - private: - // We use a std::shared_ptr here because the client may delete the HostEvent - // object while there are still RecordEvent and WaitForEvent callbacks pending - // on a stream. - std::shared_ptr notification_; -}; - -std::unique_ptr -HostExecutor::CreateEventImplementation() { - return std::unique_ptr(new HostEvent()); -} - -static HostEvent* AsHostEvent(Event* event) { - DCHECK(event != nullptr); - return static_cast(event->implementation()); -} - -tsl::Status HostExecutor::AllocateEvent(Event* /*event*/) { - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::DeallocateEvent(Event* /*event*/) { - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::RecordEvent(Stream* stream, Event* event) { - std::shared_ptr notification = - AsHostEvent(event)->notification(); - AsHostStream(stream)->EnqueueTask([notification]() { - CHECK(!notification->HasBeenNotified()); - notification->Notify(); - }); - return ::tsl::OkStatus(); -} - -tsl::Status HostExecutor::WaitForEvent(Stream* stream, Event* event) { - std::shared_ptr notification = - AsHostEvent(event)->notification(); - AsHostStream(stream)->EnqueueTask( - [notification]() { notification->WaitForNotification(); }); - return ::tsl::OkStatus(); -} - -Event::Status HostExecutor::PollForEventStatus(Event* event) { - absl::Notification& notification = *AsHostEvent(event)->notification(); - return notification.HasBeenNotified() ? Event::Status::kComplete - : Event::Status::kPending; -} - -tsl::Status HostExecutor::BlockHostUntilDone(Stream* stream) { - return AsHostStream(stream)->BlockUntilDone(); -} - -tsl::StatusOr> -HostExecutor::CreateDeviceDescription(int device_ordinal) { - internal::DeviceDescriptionBuilder builder; - - builder.set_device_address_bits(64); - - // TODO(rspringer): How to report a value that's based in reality but that - // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily. - builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); - - float cycle_counter_frequency = static_cast( - tsl::profile_utils::CpuUtils::GetCycleCounterFrequency()); - builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9); - - builder.set_name("Host"); - builder.set_platform_version("Default Version"); - - return builder.Build(); -} - -blas::BlasSupport* HostExecutor::CreateBlas() { - PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = - registry->GetFactory(kHostPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve BLAS factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -fft::FftSupport* HostExecutor::CreateFft() { - PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = - registry->GetFactory(kHostPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve FFT factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -std::unique_ptr -HostExecutor::GetStreamImplementation() { - return std::unique_ptr( - new HostStream(thread_stack_size_in_bytes_)); -} - -} // namespace host -} // namespace stream_executor diff --git a/xla/stream_executor/host/host_gpu_executor.h b/xla/stream_executor/host/host_gpu_executor.h deleted file mode 100644 index 6ca6d0bb6594c..0000000000000 --- a/xla/stream_executor/host/host_gpu_executor.h +++ /dev/null @@ -1,162 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Declares the HostExecutor class, which is a CPU-only implementation of -// the StreamExecutor interface. For now, this is used for testing and to -// examine the performance of host-based StreamExecutor code. -#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ -#define XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ - -#include - -#include "absl/functional/any_invocable.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/host/host_stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/errors.h" - -namespace stream_executor { -namespace host { - -// An implementation of StreamExecutor that does no communication or interaction -// with a device, but DOES perform memory operations backed by the host. -// Plugin routines (BLAS) are also supported and functional. -// Kernel invocations will fail, but host callbacks may be enqueued on this -// executor and its associated stream, and should follow standard ordering -// semantics. -// -// This is useful for evaluating the performance of host-based or fallback -// routines executed under the context of a GPU executor. -// See stream_executor.h for description of the below operations. -class HostExecutor : public internal::StreamExecutorInterface { - public: - HostExecutor() = default; - - // The stack size used for host streams can be set via - // device_options.non_portable_tags["host_stack_size"]. - tsl::Status Init(int device_ordinal, DeviceOptions device_options) override; - - tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override { - return tsl::errors::Unimplemented("Not Implemented"); - } - tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) override { - return tsl::errors::Unimplemented("Not Implemented"); - } - - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset_bytes, - uint64_t size_bytes) override; - void Deallocate(DeviceMemoryBase* mem) override; - - void* HostMemoryAllocate(uint64_t size) override { return new char[size]; } - void HostMemoryDeallocate(void* mem) override { - delete[] static_cast(mem); - } - bool HostMemoryRegister(void* mem, uint64_t size) override { return true; } - bool HostMemoryUnregister(void* mem) override { return true; } - - bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override; - bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) override; - bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) override; - - tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override; - tsl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override; - - // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return true; } - tsl::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) override; - - tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64_t size) override; - - tsl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) override; - tsl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override; - tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) override; - - bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; - - tsl::Status AllocateEvent(Event* event) override; - tsl::Status DeallocateEvent(Event* event) override; - tsl::Status RecordEvent(Stream* stream, Event* event) override; - tsl::Status WaitForEvent(Stream* stream, Event* event) override; - Event::Status PollForEventStatus(Event* event) override; - - bool AllocateStream(Stream* stream) override; - void DeallocateStream(Stream* stream) override; - bool CreateStreamDependency(Stream* dependent, Stream* other) override; - - tsl::Status BlockHostUntilDone(Stream* stream) override; - - bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; - - tsl::StatusOr> CreateDeviceDescription() - const override { - return CreateDeviceDescription(0); - } - - static tsl::StatusOr> - CreateDeviceDescription(int device_ordinal); - - tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { - return ::tsl::OkStatus(); - } - - bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { - return true; - } - - blas::BlasSupport* CreateBlas() override; - - dnn::DnnSupport* CreateDnn() override { return nullptr; } - - fft::FftSupport* CreateFft() override; - - std::unique_ptr CreateEventImplementation() - override; - - std::unique_ptr CreateKernelImplementation() - override { - return nullptr; - } - - std::unique_ptr GetStreamImplementation() override; - - private: - // Size of thread stacks for streams in bytes. '0' means "the default size". - size_t thread_stack_size_in_bytes_ = 0; -}; - -} // namespace host -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc new file mode 100644 index 0000000000000..e2d37085a1c75 --- /dev/null +++ b/xla/stream_executor/host/host_kernel.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/host/host_kernel.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" + +namespace stream_executor::host { + +HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel) + : arity_(arity), kernel_(kernel) {} + +absl::Status HostKernel::Launch(const ThreadDim& thread_dims, + absl::Span buffers) { + SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, + thread_dims.z}; + + // Convert buffers to kernel arguments. + std::vector args(buffers.size()); + for (int32_t i = 0; i < buffers.size(); ++i) { + args[i].data = const_cast(buffers[i].opaque()); + args[i].size = buffers[i].size(); + } + + // TODO(b/331430625): We should be using thread pool to call kernel function + // for different threads (blocks) concurrently. For now it's the most trivial + // implementation that runs tasks sequentially. + + for (uint64_t z = 0; z < thread_dims.z; ++z) { + for (uint64_t y = 0; y < thread_dims.y; ++y) { + for (uint64_t x = 0; x < thread_dims.x; ++x) { + SE_HOST_KernelThread kernel_thread = {x, y, z}; + + SE_HOST_KernelCallFrame call_frame = { + &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; + + SE_HOST_KernelError* error = (*kernel_)(&call_frame); + + if (error != nullptr) { + return absl::InternalError("Failed to call host kernel"); + } + } + } + } + + return absl::OkStatus(); +} + +} // namespace stream_executor::host diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h new file mode 100644 index 0000000000000..ee8f67738bf08 --- /dev/null +++ b/xla/stream_executor/host/host_kernel.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" + +namespace stream_executor::host { + +class HostKernel : public Kernel { + public: + HostKernel(unsigned arity, SE_HOST_Kernel* kernel); + + // TODO(b/331430625): Connect this API to Launch API defined at StreamExecutor + // level, which requires refactoring how arguments passed to kernels, as + // current KernelArgs structure tied to the GPU kernel ABI. + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span buffers); + + // For host platform, we assume that a core is a thread, and we can run at + // most one instance of a kernel on a given thread. + absl::StatusOr GetMaxOccupiedBlocksPerCore(ThreadDim, + size_t) const override { + return 1; + }; + + unsigned Arity() const override { return arity_; }; + + private: + unsigned arity_; + SE_HOST_Kernel* kernel_ = nullptr; +}; + +} // namespace stream_executor::host + +#endif // XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ diff --git a/xla/stream_executor/host/host_kernel_c_api.h b/xla/stream_executor/host/host_kernel_c_api.h new file mode 100644 index 0000000000000..6768706abc280 --- /dev/null +++ b/xla/stream_executor/host/host_kernel_c_api.h @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ +#define XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ + +#include +#include + +//===----------------------------------------------------------------------===// +// StreamExecutor Host Kernel API +//===----------------------------------------------------------------------===// + +#ifdef __cplusplus +extern "C" { +#endif + +// StreamExecutor host kernel API is an integration point between a codegen +// backend and a runtime. XLA:CPU backend compiles fusion regions to native +// functions (via LLVM backend) that are compatible with a kernel API (and ABI), +// and the runtime is simply invoking them with user buffers and orchestrates +// multi-threaded execution. + +// WARNING: This API does not provide any backward compatibility guarantees as +// today XLA:CPU backend is statically linked and we do not plan to load +// kernels from dynamic libraries. It's defined as C API because we have to +// match it in the codegen backend (built on top of LLVM) and C structs have +// trivial layout that can be expressed as llvm stuct (*). +// +// (*) https://llvm.org/docs/LangRef.html#structure-types + +// Similar to a Gpu backend an XLA:CPU compiler generates a tiled function from +// an HLO fusion where each tile is responsible for computing a part of the +// output. It's up to compiler to chose the tiling strategy, from StreamExecutor +// perspective it's simply an iteration space where each task is independent and +// can be executed concurrently. +typedef struct SE_HOST_KernelDim3 { + uint64_t x; + uint64_t y; + uint64_t z; +} SE_HOST_KernelDim3; + +// Kernel grid size roughly corresponds to a CUDA block size. +typedef struct SE_HOST_KernelDim3 SE_HOST_KernelThreadDim; + +// Kernel grid coordinate roughly corresponds to a CUDA block, with an +// assumption that all kernel invocations can run concurrently. +typedef struct SE_HOST_KernelDim3 SE_HOST_KernelThread; + +// A CPU kernel argument that corresponds to se::DeviceMemoryBase. +typedef struct SE_HOST_KernelArg { + void* data; + size_t size; +} SE_HOST_KernelArg; + +// A CPU kernel call frame. +typedef struct SE_HOST_KernelCallFrame { + SE_HOST_KernelThreadDim* thread_dims; + SE_HOST_KernelThread* thread; + + size_t num_args; + SE_HOST_KernelArg* args; +} SE_HOST_KernelCallFrame; + +// Error reporting for host kernels. NULL means success. +typedef struct SE_HOST_KernelError SE_HOST_KernelError; + +// Host kernel API. +typedef SE_HOST_KernelError* SE_HOST_Kernel( + const SE_HOST_KernelCallFrame* call_frame); + +#ifdef __cplusplus +} +#endif + +#endif // XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ diff --git a/xla/stream_executor/host/host_kernel_test.cc b/xla/stream_executor/host/host_kernel_test.cc new file mode 100644 index 0000000000000..6bf3439d2e95e --- /dev/null +++ b/xla/stream_executor/host/host_kernel_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/host/host_kernel.h" + +#include +#include + +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/test.h" + +namespace stream_executor::host { + +static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) { + SE_HOST_KernelArg& lhs = call_frame->args[0]; + SE_HOST_KernelArg& rhs = call_frame->args[1]; + SE_HOST_KernelArg& out = call_frame->args[2]; + + int32_t* lhs_ptr = reinterpret_cast(lhs.data); + int32_t* rhs_ptr = reinterpret_cast(rhs.data); + int32_t* out_ptr = reinterpret_cast(out.data); + + uint64_t x = call_frame->thread->x; + *(out_ptr + x) = *(lhs_ptr + x) + *(rhs_ptr + x); + + return nullptr; +} + +TEST(HostKernelTest, Addition) { + HostKernel kernel(/*arity=*/3, AddI32); + + std::vector lhs = {1, 2, 3, 4}; + std::vector rhs = {5, 6, 7, 8}; + std::vector out = {0, 0, 0, 0}; + + DeviceMemoryBase lhs_mem(lhs.data(), lhs.size() * sizeof(int32_t)); + DeviceMemoryBase rhs_mem(rhs.data(), rhs.size() * sizeof(int32_t)); + DeviceMemoryBase out_mem(out.data(), out.size() * sizeof(int32_t)); + std::vector args = {lhs_mem, rhs_mem, out_mem}; + + TF_ASSERT_OK(kernel.Launch(ThreadDim(4), args)); + + std::vector expected = {6, 8, 10, 12}; + EXPECT_EQ(out, expected); +} + +} // namespace stream_executor::host diff --git a/xla/stream_executor/host/host_platform.cc b/xla/stream_executor/host/host_platform.cc index 75dece9a8a42d..23112fbecd51a 100644 --- a/xla/stream_executor/host/host_platform.cc +++ b/xla/stream_executor/host/host_platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,22 @@ limitations under the License. #include "xla/stream_executor/host/host_platform.h" -#include +#include +#include +#include // NOLINT +#include -#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "xla/stream_executor/host/host_gpu_executor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/host/host_executor.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/errors.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor_pimpl.h" +#include "tsl/platform/status.h" namespace stream_executor { namespace host { @@ -39,35 +47,32 @@ int HostPlatform::VisibleDeviceCount() const { const std::string& HostPlatform::Name() const { return name_; } -tsl::StatusOr> +absl::StatusOr> HostPlatform::DescriptionForDevice(int ordinal) const { return HostExecutor::CreateDeviceDescription(ordinal); } -tsl::StatusOr HostPlatform::ExecutorForDevice(int ordinal) { +absl::StatusOr HostPlatform::ExecutorForDevice(int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; - config.device_options = DeviceOptions::Default(); return GetExecutor(config); } -tsl::StatusOr HostPlatform::GetExecutor( +absl::StatusOr HostPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } -tsl::StatusOr> +absl::StatusOr> HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique( this, std::make_unique(), config.ordinal); - auto init_status = executor->Init(config.device_options); + auto init_status = executor->Init(); if (!init_status.ok()) { - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat( - "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())); + return absl::InternalError(absl::StrFormat( + "failed initializing StreamExecutor for device ordinal %d: %s", + config.ordinal, init_status.ToString().c_str())); } return std::move(executor); @@ -75,17 +80,11 @@ HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { static void InitializeHostPlatform() { std::unique_ptr platform(new host::HostPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } // namespace host } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(host_platform, - stream_executor::host::InitializeHostPlatform()); - -// Note that module initialization sequencing is not supported in the -// open-source project, so this will be a no-op there. -REGISTER_MODULE_INITIALIZER_SEQUENCE(host_platform, multi_platform_manager); -REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, - host_platform); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + host_platform, stream_executor::host::InitializeHostPlatform()); diff --git a/xla/stream_executor/host/host_platform.h b/xla/stream_executor/host/host_platform.h index 4461669b95bbd..25c1179dcd756 100644 --- a/xla/stream_executor/host/host_platform.h +++ b/xla/stream_executor/host/host_platform.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,14 +21,11 @@ limitations under the License. #include #include -#include +#include "absl/status/statusor.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/trace_listener.h" namespace stream_executor { namespace host { @@ -49,15 +46,15 @@ class HostPlatform : public Platform { const std::string& Name() const override; - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; private: diff --git a/xla/stream_executor/host/host_platform_id.cc b/xla/stream_executor/host/host_platform_id.cc index acf9f456be568..96a13097d289f 100644 --- a/xla/stream_executor/host/host_platform_id.cc +++ b/xla/stream_executor/host/host_platform_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/host/host_platform_id.h b/xla/stream_executor/host/host_platform_id.h index e29fb5a5188b5..f9d85aeb32d19 100644 --- a/xla/stream_executor/host/host_platform_id.h +++ b/xla/stream_executor/host/host_platform_id.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/host/host_stream.cc b/xla/stream_executor/host/host_stream.cc index 892b2b7d03774..dfe091680db92 100644 --- a/xla/stream_executor/host/host_stream.cc +++ b/xla/stream_executor/host/host_stream.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,10 +17,15 @@ limitations under the License. // the HostExecutor implementation. #include "xla/stream_executor/host/host_stream.h" +#include // NOLINT +#include #include #include #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tsl/platform/denormal.h" #include "tsl/platform/env.h" @@ -29,20 +34,9 @@ limitations under the License. namespace stream_executor { namespace host { -namespace { - -tsl::ThreadOptions GetThreadOptions(size_t stack_size_in_bytes) { - tsl::ThreadOptions options; - options.stack_size = stack_size_in_bytes; - return options; -} - -} // namespace - -HostStream::HostStream(size_t stack_size_in_bytes) - : thread_(tsl::Env::Default()->StartThread( - GetThreadOptions(stack_size_in_bytes), "host_executor", - [this]() { WorkLoop(); })) {} +HostStream::HostStream() + : thread_(tsl::Env::Default()->StartThread({}, "host_executor", + [this]() { WorkLoop(); })) {} HostStream::~HostStream() { { @@ -56,12 +50,12 @@ HostStream::~HostStream() { bool HostStream::EnqueueTask(absl::AnyInvocable task) { return EnqueueTaskWithStatus([task = std::move(task)]() mutable { std::move(task)(); - return ::tsl::OkStatus(); + return absl::OkStatus(); }); } bool HostStream::EnqueueTaskWithStatus( - absl::AnyInvocable task) { + absl::AnyInvocable task) { CHECK(task != nullptr); absl::MutexLock lock(&mu_); work_queue_.push(std::move(task)); @@ -77,14 +71,14 @@ void HostStream::WorkLoop() { tsl::port::ScopedFlushDenormal flush; tsl::port::ScopedSetRound round(FE_TONEAREST); while (true) { - std::queue> queue; + std::queue> queue; { absl::MutexLock lock(&mu_); mu_.Await(absl::Condition(this, &HostStream::WorkAvailable)); std::swap(queue, work_queue_); } while (!queue.empty()) { - absl::AnyInvocable& fn = queue.front(); + absl::AnyInvocable& fn = queue.front(); if (!fn) { return; } @@ -94,15 +88,15 @@ void HostStream::WorkLoop() { } } -tsl::Status HostStream::BlockUntilDone() { +absl::Status HostStream::BlockUntilDone() { absl::Notification done; - tsl::Status status; + absl::Status status; EnqueueTask([&done, &status, this]() { // This task is always executed synchronously before 'status_' is updated // with the result of the task (always OK() in this case), so we don't need // to worry about locking access to 'status_'. status = status_; - status_ = ::tsl::OkStatus(); + status_ = absl::OkStatus(); done.Notify(); }); done.WaitForNotification(); diff --git a/xla/stream_executor/host/host_stream.h b/xla/stream_executor/host/host_stream.h index ed5bc0b056fc1..47afb45f6c647 100644 --- a/xla/stream_executor/host/host_stream.h +++ b/xla/stream_executor/host/host_stream.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,46 +18,47 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ #define XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ -#include +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/env.h" +#include "tsl/platform/thread_annotations.h" namespace stream_executor { namespace host { class HostStream : public internal::StreamInterface { public: - // stack_size_in_bytes may be '0', meaning "use the default thread stack - // size". - explicit HostStream(size_t stack_size_in_bytes); + HostStream(); ~HostStream() override; // Enqueue a task that reports a status when finished. Tasks that fail do not // stop the stream or block any other tasks from executing; rather, the stream // will remember the first error encountered and return it from // 'BlockUntilDone'. - bool EnqueueTaskWithStatus(absl::AnyInvocable task); + bool EnqueueTaskWithStatus(absl::AnyInvocable task); // Enqueue a task that doesn't report any status. bool EnqueueTask(absl::AnyInvocable task); // Blocks until all tasks are done, returns the first error reported by a task // (if any) and clears the error status. - tsl::Status BlockUntilDone(); + absl::Status BlockUntilDone(); private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); void WorkLoop(); absl::Mutex mu_; - std::queue> work_queue_ + std::queue> work_queue_ ABSL_GUARDED_BY(mu_); std::unique_ptr thread_; - tsl::Status status_; + absl::Status status_; }; } // namespace host diff --git a/xla/stream_executor/host/host_stream_test.cc b/xla/stream_executor/host/host_stream_test.cc index 3453bd0b3b95d..dfce1582fa7e7 100644 --- a/xla/stream_executor/host/host_stream_test.cc +++ b/xla/stream_executor/host/host_stream_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/lib/core/status_test_util.h" @@ -26,55 +27,49 @@ namespace se = stream_executor; TEST(HostStream, EnforcesFIFOOrder) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - se::Stream stream(executor); - stream.Init(); - + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); absl::Mutex mu; int expected = 0; bool ok = true; for (int i = 0; i < 2000; ++i) { - stream.ThenDoHostCallback([i, &mu, &expected, &ok]() { + TF_ASSERT_OK(stream->DoHostCallback([i, &mu, &expected, &ok]() { absl::MutexLock lock(&mu); if (expected != i) { ok = false; } ++expected; - }); + })); } - TF_ASSERT_OK(stream.BlockHostUntilDone()); + TF_ASSERT_OK(stream->BlockHostUntilDone()); absl::MutexLock lock(&mu); EXPECT_TRUE(ok); } TEST(HostStream, ReportsHostCallbackError) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - se::Stream stream(executor); - stream.Init(); - - stream.ThenDoHostCallbackWithStatus( - []() { return tsl::errors::Internal("error!"); }); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK(stream->DoHostCallbackWithStatus( + []() { return absl::InternalError("error!"); })); - auto status = stream.BlockHostUntilDone(); + auto status = stream->BlockHostUntilDone(); ASSERT_EQ(status.code(), tsl::error::INTERNAL); ASSERT_EQ(status.message(), "error!"); } TEST(HostStream, ReportsFirstHostCallbackError) { se::Platform* platform = - se::MultiPlatformManager::PlatformWithName("Host").value(); + se::PlatformManager::PlatformWithName("Host").value(); se::StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - se::Stream stream(executor); - stream.Init(); - - stream.ThenDoHostCallbackWithStatus( - []() { return tsl::errors::Internal("error 1"); }); - stream.ThenDoHostCallbackWithStatus( - []() { return tsl::errors::Internal("error 2"); }); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + TF_ASSERT_OK(stream->DoHostCallbackWithStatus( + []() { return absl::InternalError("error 1"); })); + TF_ASSERT_OK(stream->DoHostCallbackWithStatus( + []() { return absl::InternalError("error 2"); })); // "error 2" is just lost. - ASSERT_EQ(stream.BlockHostUntilDone().message(), "error 1"); + ASSERT_EQ(stream->BlockHostUntilDone().message(), "error 1"); } diff --git a/xla/stream_executor/host_memory_allocation.cc b/xla/stream_executor/host_memory_allocation.cc new file mode 100644 index 0000000000000..12affb0b3c68b --- /dev/null +++ b/xla/stream_executor/host_memory_allocation.cc @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/host_memory_allocation.h" + +#include + +#include "xla/stream_executor/stream_executor_internal.h" + +namespace stream_executor { + +HostMemoryAllocation::HostMemoryAllocation( + void* ptr, uint64_t size, internal::StreamExecutorInterface* executor) + : ptr_(ptr), size_(size), executor_(executor) {} + +HostMemoryAllocation::~HostMemoryAllocation() { + if (ptr_ != nullptr && executor_ != nullptr) { + executor_->HostMemoryDeallocate(ptr_); + } +} + +} // namespace stream_executor diff --git a/xla/stream_executor/host_memory_allocation.h b/xla/stream_executor/host_memory_allocation.h new file mode 100644 index 0000000000000..974eb63fb8daa --- /dev/null +++ b/xla/stream_executor/host_memory_allocation.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ +#define XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ + +#include + +#include "xla/stream_executor/memory_allocation.h" + +namespace stream_executor { + +namespace internal { +class StreamExecutorInterface; +} + +// RAII container for pinned host memory allocation allocated on an underlying +// device owned by `*this`. +class HostMemoryAllocation final : public MemoryAllocation { + public: + HostMemoryAllocation(void* ptr, uint64_t size, + internal::StreamExecutorInterface* executor); + ~HostMemoryAllocation() final; + + void* opaque() const final { return ptr_; } + uint64_t size() const final { return size_; } + + private: + void* ptr_ = nullptr; + uint64_t size_ = 0; + internal::StreamExecutorInterface* executor_ = nullptr; +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_HOST_MEMORY_ALLOCATION_H_ diff --git a/xla/stream_executor/host_or_device_scalar.h b/xla/stream_executor/host_or_device_scalar.h index a7d428d05e565..81e07e2d194ce 100644 --- a/xla/stream_executor/host_or_device_scalar.h +++ b/xla/stream_executor/host_or_device_scalar.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "xla/stream_executor/device_memory.h" namespace stream_executor { diff --git a/xla/stream_executor/integrations/BUILD b/xla/stream_executor/integrations/BUILD index 58bb3f447ae34..5a3130a64aa00 100644 --- a/xla/stream_executor/integrations/BUILD +++ b/xla/stream_executor/integrations/BUILD @@ -1,11 +1,12 @@ -load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") -load("@tsl//tsl:tsl.bzl", "set_external_visibility") +load("@tsl//tsl:tsl.bzl", "if_google", "internal_visibility") load("@tsl//tsl:tsl.default.bzl", "filegroup") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -28,7 +29,7 @@ filegroup( "device_host_allocator.h", "device_mem_allocator.h", ], - visibility = ["//tensorflow/core:__pkg__"], + visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) #===--------------------------------------------------------------------------------------------===# @@ -47,8 +48,11 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@tsl//tsl/framework:allocator", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -63,8 +67,32 @@ cc_library( ], deps = [ "//xla/stream_executor", + "//xla/stream_executor:memory_allocation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:device_id", + "@tsl//tsl/platform:logging", "@tsl//tsl/profiler/lib:traceme", ], ) + +xla_cc_test( + name = "tf_allocator_adapter_test", + srcs = ["tf_allocator_adapter_test.cc"], + deps = [ + ":tf_allocator_adapter", + "//xla/service:cpu_plugin", + "//xla/service:platform_util", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/log:check", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] + if_google([ + "@tsl//tsl/framework:allocator", + ]), +) diff --git a/xla/stream_executor/integrations/device_host_allocator.h b/xla/stream_executor/integrations/device_host_allocator.h index b39ffe52a502f..90292674a5456 100644 --- a/xla/stream_executor/integrations/device_host_allocator.h +++ b/xla/stream_executor/integrations/device_host_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,22 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_INTEGRATIONS_DEVICE_HOST_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_INTEGRATIONS_DEVICE_HOST_ALLOCATOR_H_ +#include +#include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/framework/allocator.h" +#include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" namespace stream_executor { + // Allocator for pinned CPU RAM that is made known to a StreamExecutor-based // device for the purpose of efficient DMA with the device. class DeviceHostAllocator : public tsl::SubAllocator { @@ -45,15 +54,22 @@ class DeviceHostAllocator : public tsl::SubAllocator { void* ptr = nullptr; *bytes_received = num_bytes; + if (num_bytes > 0) { - ptr = stream_exec_->HostMemoryAllocate(num_bytes); - if (ptr == nullptr) { + auto allocation = stream_exec_->HostMemoryAllocate(num_bytes); + if (!allocation.ok()) { LOG(WARNING) << "could not allocate pinned host memory of size: " << num_bytes; - return ptr; + return nullptr; } + + ptr = (*allocation)->opaque(); VisitAlloc(ptr, numa_node_, num_bytes); + + absl::MutexLock lock(&mutex_); + allocs_[ptr] = std::move(*allocation); } + return ptr; } @@ -62,7 +78,8 @@ class DeviceHostAllocator : public tsl::SubAllocator { if (ptr != nullptr) { VisitFree(ptr, numa_node_, num_bytes); - stream_exec_->HostMemoryDeallocate(ptr); + absl::MutexLock lock(&mutex_); + allocs_.erase(ptr); } } @@ -78,6 +95,10 @@ class DeviceHostAllocator : public tsl::SubAllocator { DeviceHostAllocator(const DeviceHostAllocator&) = delete; void operator=(const DeviceHostAllocator&) = delete; + + absl::Mutex mutex_; + absl::flat_hash_map> allocs_ + ABSL_GUARDED_BY(mutex_); }; } // namespace stream_executor diff --git a/xla/stream_executor/integrations/device_mem_allocator.h b/xla/stream_executor/integrations/device_mem_allocator.h index d3327fc8fa0d7..ee7767ce5aadf 100644 --- a/xla/stream_executor/integrations/device_mem_allocator.h +++ b/xla/stream_executor/integrations/device_mem_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,6 +25,9 @@ limitations under the License. namespace stream_executor { +// The type of memory that the allocator will use. +enum class MemoryType { kDevice = 0, kUnified, kCollective, kHost = 5 }; + // Suballocator for StreamExecutor-based device memory. class DeviceMemAllocator : public tsl::SubAllocator { public: @@ -33,13 +36,13 @@ class DeviceMemAllocator : public tsl::SubAllocator { // Note: stream_exec cannot be null. explicit DeviceMemAllocator(StreamExecutor* stream_exec, tsl::PlatformDeviceId device_id, - bool use_unified_memory, + MemoryType memory_type, const std::vector& alloc_visitors, const std::vector& free_visitors) : SubAllocator(alloc_visitors, free_visitors), stream_exec_(stream_exec), device_id_(device_id), - use_unified_memory_(use_unified_memory) { + memory_type_(memory_type) { CHECK(stream_exec_ != nullptr); } @@ -52,8 +55,17 @@ class DeviceMemAllocator : public tsl::SubAllocator { void* ptr = nullptr; *bytes_received = num_bytes; if (num_bytes > 0) { - if (use_unified_memory_) { + if (memory_type_ == MemoryType::kUnified) { ptr = stream_exec_->UnifiedMemoryAllocate(num_bytes); + } else if (memory_type_ == MemoryType::kCollective) { + auto status_or = stream_exec_->CollectiveMemoryAllocate(num_bytes); + CHECK(status_or.ok()) << status_or.status().message(); + ptr = status_or.value(); + } else if (memory_type_ == MemoryType::kHost) { + // Convert size_t to long unsigned int + long unsigned int value = static_cast(num_bytes); + auto status_or = stream_exec_->HostMemoryAllocate(value); + CHECK(status_or.ok()) << status_or.status().message(); } else { ptr = stream_exec_->AllocateArray(num_bytes).opaque(); } @@ -67,8 +79,13 @@ class DeviceMemAllocator : public tsl::SubAllocator { if (ptr != nullptr) { VisitFree(ptr, device_id_.value(), num_bytes); - if (use_unified_memory_) { + if (memory_type_ == MemoryType::kUnified) { stream_exec_->UnifiedMemoryDeallocate(ptr); + } else if (memory_type_ == MemoryType::kCollective) { + auto status = stream_exec_->CollectiveMemoryDeallocate(ptr); + CHECK(status.ok()) << status.message(); + } else if (memory_type_ == MemoryType::kHost) { + stream_exec_->HostMemoryDeallocate(ptr, num_bytes); } else { DeviceMemoryBase device_ptr(ptr); stream_exec_->Deallocate(&device_ptr); @@ -85,7 +102,7 @@ class DeviceMemAllocator : public tsl::SubAllocator { private: StreamExecutor* stream_exec_; // not owned, non-null const tsl::PlatformDeviceId device_id_; - const bool use_unified_memory_ = false; + const MemoryType memory_type_ = MemoryType::kDevice; DeviceMemAllocator(const DeviceMemAllocator&) = delete; void operator=(const DeviceMemAllocator&) = delete; diff --git a/xla/stream_executor/integrations/tf_allocator_adapter.cc b/xla/stream_executor/integrations/tf_allocator_adapter.cc index ebbaa220ecc7d..97a6c7c7b09e3 100644 --- a/xla/stream_executor/integrations/tf_allocator_adapter.cc +++ b/xla/stream_executor/integrations/tf_allocator_adapter.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ limitations under the License. #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/errors.h" @@ -34,10 +36,9 @@ TfAllocatorAdapter::TfAllocatorAdapter(tsl::Allocator *wrapped, TfAllocatorAdapter::~TfAllocatorAdapter() {} -tsl::StatusOr TfAllocatorAdapter::Allocate( +absl::StatusOr TfAllocatorAdapter::Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) { - CHECK_EQ(memory_space, 0); tsl::AllocationAttributes attrs; attrs.retry_on_failure = retry_on_failure; void *data = nullptr; @@ -45,25 +46,25 @@ tsl::StatusOr TfAllocatorAdapter::Allocate( data = wrapped_->AllocateRaw(tsl::Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { - return tsl::errors::ResourceExhausted( - "Out of memory while trying to allocate ", size, " bytes."); + return absl::ResourceExhaustedError(absl::StrCat( + "Out of memory while trying to allocate ", size, " bytes.")); } } return OwningDeviceMemory(DeviceMemoryBase(data, size), device_ordinal, this); } -tsl::Status TfAllocatorAdapter::Deallocate(int device_ordinal, - DeviceMemoryBase mem) { +absl::Status TfAllocatorAdapter::Deallocate(int device_ordinal, + DeviceMemoryBase mem) { wrapped_->DeallocateRaw(mem.opaque()); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr TfAllocatorAdapter::GetStream(int device_ordinal) { +absl::StatusOr TfAllocatorAdapter::GetStream(int device_ordinal) { CHECK_EQ(stream_->parent()->device_ordinal(), device_ordinal); return stream_; } -tsl::StatusOr TfAllocatorAdapter::GetAllocator( +absl::StatusOr TfAllocatorAdapter::GetAllocator( int device_ordinal) { if (stream_ == nullptr) { return absl::UnavailableError("stream_ is null for TfAllocatorAdapter."); diff --git a/xla/stream_executor/integrations/tf_allocator_adapter.h b/xla/stream_executor/integrations/tf_allocator_adapter.h index 7e63f4b4e70a3..0a1b2bbf37d4e 100644 --- a/xla/stream_executor/integrations/tf_allocator_adapter.h +++ b/xla/stream_executor/integrations/tf_allocator_adapter.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" @@ -46,11 +50,11 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator { ~TfAllocatorAdapter() override; - tsl::StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override; + absl::StatusOr Allocate(int device_ordinal, uint64_t size, + bool retry_on_failure, + int64_t memory_space) override; - tsl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; + absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place. Tensorflow uses the ordering of the main @@ -61,9 +65,9 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator { // (This attribute has no effect on CPU.) bool AllowsAsynchronousDeallocation() const override { return true; } - tsl::StatusOr GetStream(int device_ordinal) override; + absl::StatusOr GetStream(int device_ordinal) override; - tsl::StatusOr GetAllocator(int device_ordinal); + absl::StatusOr GetAllocator(int device_ordinal); private: tsl::Allocator *wrapped_; @@ -75,56 +79,80 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator { // asynchronous deallocation; see comment on `AllowsAsynchronousDeallocation()`. class MultiDeviceAdapter : public DeviceMemoryAllocator { public: - using AllocatorWithStream = - std::pair, Stream *>; - using AllocatorWithLogicalIdAndStream = - std::tuple, int, Stream *>; + struct AllocatorInfo { + std::unique_ptr allocator; + Stream *stream; + int64_t memory_space; + std::optional device_ordinal = std::nullopt; + + AllocatorInfo(std::unique_ptr allocator, Stream *stream, + int64_t memory_space, + std::optional device_ordinal = std::nullopt) + : allocator(std::move(allocator)), + stream(stream), + memory_space(memory_space), + device_ordinal(device_ordinal) {} + }; MultiDeviceAdapter(const Platform *platform, - std::vector tf_allocators) + std::vector tf_allocators) : DeviceMemoryAllocator(platform) { tf_allocators_.reserve(tf_allocators.size()); - for (AllocatorWithStream &p : tf_allocators) { - int device_ordinal = p.second->parent()->device_ordinal(); - if (per_device_allocators_.size() <= device_ordinal) { - per_device_allocators_.resize(device_ordinal + 1); + for (AllocatorInfo &info : tf_allocators) { + auto &per_device_allocators = + memory_space_to_per_device_allocators_[info.memory_space]; + int device_ordinal = info.device_ordinal.has_value() + ? *info.device_ordinal + : info.stream->parent()->device_ordinal(); + if (per_device_allocators.size() <= device_ordinal) { + per_device_allocators.resize(device_ordinal + 1); } - CHECK(!per_device_allocators_[device_ordinal]); - per_device_allocators_[device_ordinal] = - std::make_unique(p.first.get(), p.second); - tf_allocators_.push_back(std::move(p.first)); + CHECK(!per_device_allocators[device_ordinal]); + per_device_allocators[device_ordinal] = + std::make_unique(info.allocator.get(), + info.stream); + tf_allocators_.push_back(std::move(info.allocator)); } } - MultiDeviceAdapter(const Platform *platform, - std::vector tf_allocators) - : DeviceMemoryAllocator(platform) { - tf_allocators_.reserve(tf_allocators.size()); - for (AllocatorWithLogicalIdAndStream &t : tf_allocators) { - const int device_ordinal = std::get<1>(t); - Stream *stream = std::get<2>(t); - if (per_device_allocators_.size() <= device_ordinal) { - per_device_allocators_.resize(device_ordinal + 1); - } - CHECK(!per_device_allocators_[device_ordinal]); - per_device_allocators_[device_ordinal] = - std::make_unique(std::get<0>(t).get(), stream); - tf_allocators_.push_back(std::move(std::get<0>(t))); - } + absl::StatusOr Allocate(int device_ordinal, uint64_t size, + bool retry_on_failure, + int64_t memory_space) override { + // memory_space is used here to select allocator. This isn't a need to pass + // it any lower to TfAllocatorAdapter. + auto it = memory_space_to_per_device_allocators_.find(memory_space); + CHECK(it != memory_space_to_per_device_allocators_.end()); + CHECK_LT(device_ordinal, it->second.size()); + TF_ASSIGN_OR_RETURN( + auto result, it->second[device_ordinal]->Allocate( + device_ordinal, size, retry_on_failure, memory_space)); + + absl::MutexLock lock(&mu_); + buffer_memory_spaces_[{device_ordinal, result->opaque()}] = memory_space; + return result; } - tsl::StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override { - CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal]->Allocate( - device_ordinal, size, retry_on_failure, memory_space); - } + absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override { + if (mem.opaque() == nullptr) return absl::OkStatus(); + // Memory space is not passed to deallocate, look up in + // buffer_memory_spaces_. + int64_t memory_space; + { + absl::MutexLock lock(&mu_); + auto it = buffer_memory_spaces_.find({device_ordinal, mem.opaque()}); + if (it == buffer_memory_spaces_.end()) { + return absl::InternalError( + absl::StrFormat("Memory %p was not allocated on device %d.", + mem.opaque(), device_ordinal)); + } + memory_space = it->second; + buffer_memory_spaces_.erase(it); + } - tsl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override { - CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal]->Deallocate(device_ordinal, - mem); + auto it = memory_space_to_per_device_allocators_.find(memory_space); + CHECK(it != memory_space_to_per_device_allocators_.end()); + CHECK_LT(device_ordinal, it->second.size()); + return it->second[device_ordinal]->Deallocate(device_ordinal, mem); } // The Tensorflow BFC allocator used on GPU allows host-side deallocation @@ -136,16 +164,26 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { // (This attribute has no effect on CPU.) bool AllowsAsynchronousDeallocation() const override { return true; } - tsl::StatusOr GetStream(int device_ordinal) override { - return per_device_allocators_[device_ordinal]->GetStream(device_ordinal); + absl::StatusOr GetStream(int device_ordinal) override { + // Both allocators should use the same stream, so just use 0. + return memory_space_to_per_device_allocators_[0][device_ordinal]->GetStream( + device_ordinal); } - tsl::StatusOr GetAllocator(int device_ordinal) { - return per_device_allocators_[device_ordinal]->GetAllocator(device_ordinal); + absl::StatusOr GetAllocator(int device_ordinal) { + // GetAllocator is used for memory stats. Currently we will only see stats + // for main device memory allocator. + return memory_space_to_per_device_allocators_[0][device_ordinal] + ->GetAllocator(device_ordinal); } private: - std::vector> per_device_allocators_; + absl::flat_hash_map>> + memory_space_to_per_device_allocators_; + // Map of device ordinal, buffer to which memory space it resides in. + absl::Mutex mu_; + absl::flat_hash_map, int64_t> buffer_memory_spaces_ + ABSL_GUARDED_BY(mu_); // The wrapped TF allocators backing per_device_allocators_ // (TfAllocatorAdapter does not take ownership of its underlying Allocator). std::vector> tf_allocators_; diff --git a/xla/stream_executor/integrations/tf_allocator_adapter_test.cc b/xla/stream_executor/integrations/tf_allocator_adapter_test.cc new file mode 100644 index 0000000000000..e98569cdbf405 --- /dev/null +++ b/xla/stream_executor/integrations/tf_allocator_adapter_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/log/check.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/framework/allocator.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace se = stream_executor; + +// Each allocatotion will have an incrementing address. +class TestAllocator : public tsl::Allocator { + public: + explicit TestAllocator(size_t start_address) + : start_address_(start_address) {} + + std::string Name() override { return "test"; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = reinterpret_cast(++start_address_); + allocations_.insert(ptr); + return ptr; + } + + void DeallocateRaw(void* ptr) override { + auto it = allocations_.find(ptr); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + allocations_.erase(it); + } + } + + private: + absl::flat_hash_set allocations_; + size_t start_address_; +}; + +TEST(MultiDeviceAdapter, UsesCorrectAllocator) { + TF_ASSERT_OK_AND_ASSIGN(auto* platform, + xla::PlatformUtil::GetDefaultPlatform()); + TF_ASSERT_OK_AND_ASSIGN(std::vector executors, + xla::PlatformUtil::GetStreamExecutors(platform)) + TF_ASSERT_OK_AND_ASSIGN(auto stream, executors[0]->CreateStream()); + + std::vector infos; + infos.emplace_back(std::make_unique(0x1000), stream.get(), + /*memory_space=*/0, /*device_ordinal=*/0); + infos.emplace_back(std::make_unique(0x2000), stream.get(), + /*memory_space=*/0, /*device_ordinal=*/1); + infos.emplace_back(std::make_unique(0x3000), stream.get(), + /*memory_space=*/1, /*device_ordinal=*/0); + infos.emplace_back(std::make_unique(0x4000), stream.get(), + /*memory_space=*/1, /*device_ordinal=*/1); + std::unique_ptr allocator = + std::make_unique(platform, std::move(infos)); + + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory buff0, + allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/0)); + CHECK_EQ(reinterpret_cast(buff0->opaque()), 0x1001); + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory buff1, + allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/0)); + CHECK_EQ(reinterpret_cast(buff1->opaque()), 0x1002); + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory buff2, + allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/1)); + CHECK_EQ(reinterpret_cast(buff2->opaque()), 0x3001); + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory buff3, + allocator->Allocate(/*device_ordinal=*/1, 4, false, /*memory_space=*/0)); + CHECK_EQ(reinterpret_cast(buff3->opaque()), 0x2001); + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory buff4, + allocator->Allocate(/*device_ordinal=*/1, 4, false, /*memory_space=*/1)); + CHECK_EQ(reinterpret_cast(buff4->opaque()), 0x4001); +} diff --git a/xla/stream_executor/kernel.cc b/xla/stream_executor/kernel.cc index 3b06a32fc31c5..f51257aa049c3 100644 --- a/xla/stream_executor/kernel.cc +++ b/xla/stream_executor/kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,20 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include +#include #include #include -#include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/demangle.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -45,33 +49,15 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { shared_memory_bytes_ = shared_memory_bytes; } -Kernel::Kernel(Kernel &&from) - : parent_(from.parent_), - implementation_(std::move(from.implementation_)), - name_(std::move(from.name_)), - demangled_name_(std::move(from.demangled_name_)), - metadata_(from.metadata_) { - from.parent_ = nullptr; -} - -Kernel::Kernel(StreamExecutor *parent) - : parent_(parent), - implementation_(parent->implementation()->CreateKernelImplementation()) {} - -Kernel::~Kernel() { - if (parent_) { - parent_->UnloadKernel(this); - } -} - -unsigned Kernel::Arity() const { return implementation_->Arity(); } - -void Kernel::SetPreferredCacheConfig(KernelCacheConfig config) { - return implementation_->SetPreferredCacheConfig(config); -} +//===----------------------------------------------------------------------===// +// Kernel +//===----------------------------------------------------------------------===// -KernelCacheConfig Kernel::GetPreferredCacheConfig() const { - return implementation_->GetPreferredCacheConfig(); +absl::StatusOr> Kernel::Create( + StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { + TF_ASSIGN_OR_RETURN(auto kernel, executor->implementation()->CreateKernel()); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get())); + return kernel; } void Kernel::set_name(absl::string_view name) { diff --git a/xla/stream_executor/kernel.h b/xla/stream_executor/kernel.h index 9077f80955ed8..edf0e24b31a11 100644 --- a/xla/stream_executor/kernel.h +++ b/xla/stream_executor/kernel.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -81,23 +81,23 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" -#include "absl/log/check.h" #include "absl/meta/type_traits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace stream_executor { +class Kernel; class StreamExecutor; -namespace internal { -class KernelInterface; -} // namespace internal - //===----------------------------------------------------------------------===// // Kernel cache config //===----------------------------------------------------------------------===// @@ -225,91 +225,116 @@ class Kernel { // registering custom CUDA C++ kernels with non-trivial C++ API with a // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = - std::function>( - const KernelArgs &args)>; + std::function>( + const Kernel &kernel, const KernelArgs &args)>; - Kernel(Kernel &&from); + // TODO(b/323534971): Kernel constructor should be moved to StreamExecutor or + // a dedicated KernelFactory accessible via StreamExecutor. - // Constructs an "empty" (not-yet-loaded) kernel instance. - // - // parent is the StreamExecutor that will be responsible for loading the - // implementation of this kernel. It must not be null. - explicit Kernel(StreamExecutor *parent); + // Creates kernel on a given executor from a given kernel specification. + static absl::StatusOr> Create( + StreamExecutor *executor, const MultiKernelLoaderSpec &spec); - // Releases resources associated with the kernel instance (i.e. - // platform-specific implementation). - ~Kernel(); + Kernel() = default; + virtual ~Kernel() = default; + + Kernel(const Kernel &) = delete; + void operator=(const Kernel &) = delete; // Returns the number of parameters that this kernel accepts. (Arity refers to // nullary, unary, ...). - unsigned Arity() const; + virtual unsigned Arity() const = 0; - // Returns the StreamExecutor that represents the platform this kernel - // executes upon. - StreamExecutor *parent() const { return parent_; } + // Returns the maximum number of blocks (per multiprocessor) occupied by the + // kernel given the number of threads per block and shared memory size. + virtual absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const = 0; - // Returns a const pointer to the (opaque) platform-dependent implementation. - const internal::KernelInterface *implementation() const { - return implementation_.get(); + KernelCacheConfig cache_config() const { return cache_config_; } + void set_cache_config(KernelCacheConfig cache_config) { + cache_config_ = std::move(cache_config); } - // Returns a non-const pointer to the (opaque) platform-dependent - // implementation. - internal::KernelInterface *implementation() { return implementation_.get(); } - - void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; } - const KernelMetadata &metadata() const { return metadata_; } - - // Sets the preferred cache configuration for a kernel. This is just a - // suggestion to the runtime, and may not be honored during execution. - void SetPreferredCacheConfig(KernelCacheConfig config); - - // Gets the preferred cache configuration for a kernel. - KernelCacheConfig GetPreferredCacheConfig() const; - - // Sets custom kernels arguments packing function for a kernel. - void set_kernel_args_packing(KernelArgsPacking kernel_args_packing) { - kernel_args_packing_ = std::move(kernel_args_packing); + void set_metadata(KernelMetadata metadata) { + metadata_ = std::move(metadata); } - const KernelArgsPacking &kernel_args_packing() const { - return kernel_args_packing_; + const KernelArgsPacking &args_packing() const { return args_packing_; } + void set_args_packing(KernelArgsPacking args_packing) { + args_packing_ = std::move(args_packing); } + std::string_view name() const { return name_; } void set_name(absl::string_view name); - const std::string &name() const { return name_; } - const std::string &demangled_name() const { return demangled_name_; } - - private: - // The StreamExecutor that loads this kernel object. - StreamExecutor *parent_; - // Implementation delegated to for platform-specific functionality. - std::unique_ptr implementation_; + std::string_view demangled_name() const { return demangled_name_; } + private: std::string name_; std::string demangled_name_; + KernelCacheConfig cache_config_ = KernelCacheConfig::kNoPreference; KernelMetadata metadata_; - - KernelArgsPacking kernel_args_packing_; - - Kernel(const Kernel &) = delete; - void operator=(const Kernel &) = delete; + KernelArgsPacking args_packing_; }; //===----------------------------------------------------------------------===// // Typed kernel //===----------------------------------------------------------------------===// -// Typed variant of Kernel, like a typed device function pointer. +// Typed kernel is a typed smart-pointer-like wrapper around untyped Kernel. template -class TypedKernel : public Kernel { +class TypedKernel { public: static constexpr size_t kNumberOfParameters = sizeof...(Params); - explicit TypedKernel(StreamExecutor *parent) : Kernel(parent) {} + // Creates a typed kernel on a given executor from a kernel specification. + static absl::StatusOr Create(StreamExecutor *executor, + const MultiKernelLoaderSpec &spec) { + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + Kernel::Create(executor, spec)); + return TypedKernel(std::move(kernel)); + } + + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a + // PTX (and optional CUBIN), such that the types of the arguments provided for + // launch would have to match types of the arguments provided at creation + // time. The canonical storage for both ptx and cubin_data should outlive the + // lifetime of the kernel. + static absl::StatusOr Create( + StreamExecutor *executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data); + + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from + // an in-process symbol pointer. + static absl::StatusOr Create(StreamExecutor *executor, + absl::string_view kernel_name, + void *symbol); + + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from + // an LLVM IR. + static absl::StatusOr Create(StreamExecutor *executor, + absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); + + TypedKernel() = default; + + Kernel &operator*() { return *kernel_; } + const Kernel &operator*() const { return *kernel_; } + + Kernel *operator->() { return kernel_.get(); } + const Kernel *operator->() const { return kernel_.get(); } + + operator bool() const { return static_cast(kernel_); } // NOLINT + + private: + explicit TypedKernel(std::unique_ptr kernel) + : kernel_(std::move(kernel)) {} + + std::unique_ptr kernel_; }; //===----------------------------------------------------------------------===// @@ -520,8 +545,9 @@ std::unique_ptr PackKernelArgs( } } // namespace internal -inline tsl::StatusOr> PackKernelArgs( - absl::Span args, uint32_t shared_mem_bytes) { +inline absl::StatusOr> +PackKernelArgs(absl::Span args, + uint32_t shared_mem_bytes) { static constexpr int kKernelArgsLimit = 1024; if (args.size() > kKernelArgsLimit) @@ -551,8 +577,9 @@ inline tsl::StatusOr> PackKernelArgs( return internal::PackKernelArgs(args, shared_mem_bytes); } -inline tsl::StatusOr> PackKernelArgs( - absl::Span args, const KernelMetadata &metadata) { +inline absl::StatusOr> +PackKernelArgs(absl::Span args, + const KernelMetadata &metadata) { return PackKernelArgs(args, metadata.shared_memory_bytes().value_or(0)); } @@ -711,10 +738,44 @@ std::unique_ptr PackKernelArgs( PackedParams::template CheckCompatibleStaticAssert(); - int64_t shmem_bytes = kernel.metadata().shared_memory_bytes().value_or(0); + int64_t shmem_bytes = kernel->metadata().shared_memory_bytes().value_or(0); return std::make_unique(std::forward(args)..., shmem_bytes); } +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddCudaPtxInMemory(ptx, kernel_name); + + if (!cubin_data.empty()) { + loader_spec.AddCudaCubinInMemory(cubin_data, kernel_name); + } + + return TypedKernel::Create(executor, loader_spec); +} + +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view kernel_name, void *symbol) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddInProcessSymbol(symbol, kernel_name); + + return TypedKernel::Create(executor, loader_spec); +} + +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view ir, + absl::string_view entrypoint, absl::string_view kernel_name, + absl::Span options) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddLlvmHostKernel(ir, entrypoint, kernel_name, options); + + return TypedKernel::Create(executor, loader_spec); +} + } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_KERNEL_H_ diff --git a/xla/stream_executor/kernel_spec.cc b/xla/stream_executor/kernel_spec.cc index 53e5b6687e71d..5f7077e991bbb 100644 --- a/xla/stream_executor/kernel_spec.cc +++ b/xla/stream_executor/kernel_spec.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,10 +23,9 @@ limitations under the License. #include #include -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "tsl/platform/logging.h" namespace stream_executor { @@ -36,100 +35,44 @@ KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernel_name) InProcessSymbol::InProcessSymbol(void *symbol, std::string kernel_name) : KernelLoaderSpec(std::move(kernel_name)), symbol_(symbol) {} -OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename, - absl::string_view kernel_name) - : KernelLoaderSpec(kernel_name), filename_(std::string(filename)) {} - -CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename, - absl::string_view kernel_name) - : OnDiskKernelLoaderSpec(filename, kernel_name) {} - -CudaCubinOnDisk::CudaCubinOnDisk(absl::string_view filename, - absl::string_view kernel_name) - : OnDiskKernelLoaderSpec(filename, kernel_name) {} - -CudaCubinInMemory::CudaCubinInMemory(const char *bytes, +CudaCubinInMemory::CudaCubinInMemory(absl::Span cubin_bytes, absl::string_view kernel_name) - : KernelLoaderSpec(kernel_name), bytes_(bytes) {} - -bool CompareComputeCapability(const std::tuple &lhs, - const std::tuple &rhs) { - return std::get<0>(lhs) < std::get<0>(rhs) || - (std::get<0>(lhs) == std::get<0>(rhs) && - std::get<1>(lhs) < std::get<1>(rhs)); -} + : KernelLoaderSpec(kernel_name), cubin_bytes_(cubin_bytes) {} const std::tuple CudaPtxInMemory::kMinimumCapability{1, 0}; CudaPtxInMemory::CudaPtxInMemory(absl::string_view ptx, - absl::string_view kernel_name, - bool ptx_compressed) - : KernelLoaderSpec(kernel_name), - ptx_by_compute_capability_(CompareComputeCapability) { - if (ptx_compressed) { - // Lazy decompression. Put an empty string in decompressed_ptx_ showing that - // the original ptx is compressed. - decompressed_ptx_[ptx.data()] = ""; - } + absl::string_view kernel_name) + : KernelLoaderSpec(kernel_name) { ptx_by_compute_capability_[kMinimumCapability] = ptx.data(); } CudaPtxInMemory::CudaPtxInMemory( const std::initializer_list &spec_list, - absl::string_view kernel_name, bool ptx_compressed) - : KernelLoaderSpec(kernel_name), - ptx_by_compute_capability_(CompareComputeCapability) { + absl::string_view kernel_name) + : KernelLoaderSpec(kernel_name) { for (const auto &spec : spec_list) { int major, minor; absl::string_view ptx; std::tie(major, minor, ptx) = spec; - if (ptx_compressed) { - // Lazy decompression. Put an empty string in decompressed_ptx_ showing - // that the original ptx is compressed. - decompressed_ptx_[ptx.data()] = ""; - } ptx_by_compute_capability_[std::tuple{major, minor}] = ptx.data(); } } -std::string CudaPtxInMemory::DecompressPtx(const char *ptx) { - // Get the length of the PTX string from the beginning of the buffer. - uint64_t ptx_length = *reinterpret_cast(ptx); - // Get the PTX string from the buffer with offset and length. - std::string compressed_ptx(ptx + sizeof(uint64_t), - ptx + sizeof(uint64_t) + ptx_length); - std::string decompressed_ptx; - // Decompress the PTX string with bzip2. - LOG(FATAL) << "bzip2 decompression is not supported yet."; - return decompressed_ptx; -} +LlvmHostKernel::LlvmHostKernel(absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options) + : KernelLoaderSpec(std::move(kernel_name)), + ir_(ir), + entrypoint_(entrypoint), + options_(options.cbegin(), options.cend()) {} const char *CudaPtxInMemory::default_text() const { if (ptx_by_compute_capability_.empty()) { return nullptr; } - absl::MutexLock lock(&mu_); - - auto ptx = ptx_by_compute_capability_.begin()->second; - // Check if there is an entry in decompressed ptx table. - auto decompressed_ptx_iter = decompressed_ptx_.find(ptx); - if (decompressed_ptx_iter != decompressed_ptx_.end()) { - // If the decompressed string is empty, which means the ptx hasn't been - // decompressed, decompress it here. - if (decompressed_ptx_iter->second.empty()) { - decompressed_ptx_iter->second = DecompressPtx(ptx); - } - return decompressed_ptx_iter->second.c_str(); - } - return ptx; -} - -const char *CudaPtxInMemory::original_default_text() const { - if (ptx_by_compute_capability_.empty()) { - return nullptr; - } - return ptx_by_compute_capability_.begin()->second; } @@ -143,31 +86,6 @@ const char *CudaPtxInMemory::text(int compute_capability_major, return nullptr; } - absl::MutexLock lock(&mu_); - - // Check if there is an entry in decompressed ptx table. - auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second); - if (decompressed_ptx_iter != decompressed_ptx_.end()) { - // If the decompressed string is empty, which means the ptx hasn't been - // decompressed, decompress it here. - if (decompressed_ptx_iter->second.empty()) { - decompressed_ptx_iter->second = DecompressPtx(ptx_iter->second); - } - return decompressed_ptx_iter->second.c_str(); - } - return ptx_iter->second; -} - -const char *CudaPtxInMemory::original_text(int compute_capability_major, - int compute_capability_minor) const { - std::tuple capability{compute_capability_major, - compute_capability_minor}; - - auto ptx_iter = ptx_by_compute_capability_.find(capability); - if (ptx_iter == ptx_by_compute_capability_.end()) { - return nullptr; - } - return ptx_iter->second; } @@ -175,62 +93,30 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddInProcessSymbol( void *symbol, absl::string_view kernel_name) { CHECK(in_process_symbol_ == nullptr); in_process_symbol_ = - std::make_unique(symbol, std::string(kernel_name)); - return this; -} - -MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxOnDisk( - absl::string_view filename, absl::string_view kernel_name) { - CHECK(cuda_ptx_on_disk_ == nullptr); - cuda_ptx_on_disk_.reset(new CudaPtxOnDisk{filename, kernel_name}); + std::make_shared(symbol, std::string(kernel_name)); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinInMemory( - const char *bytes, absl::string_view kernel_name) { + absl::Span cubin_bytes, absl::string_view kernel_name) { CHECK(cuda_cubin_in_memory_ == nullptr); - cuda_cubin_in_memory_.reset(new CudaCubinInMemory{bytes, kernel_name}); - return this; -} - -MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinOnDisk( - absl::string_view filename, absl::string_view kernel_name) { - CHECK(cuda_cubin_on_disk_ == nullptr); - cuda_cubin_on_disk_.reset(new CudaCubinOnDisk{filename, kernel_name}); + cuda_cubin_in_memory_.reset(new CudaCubinInMemory{cubin_bytes, kernel_name}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( absl::string_view ptx, absl::string_view kernel_name) { CHECK(cuda_ptx_in_memory_ == nullptr); - cuda_ptx_in_memory_.reset( - new CudaPtxInMemory{ptx, kernel_name, false /* ptx_compressed */}); + cuda_ptx_in_memory_.reset(new CudaPtxInMemory{ptx, kernel_name}); return this; } -MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( - absl::string_view ptx, absl::string_view kernel_name) { - CHECK(cuda_ptx_in_memory_ == nullptr); - cuda_ptx_in_memory_.reset( - new CudaPtxInMemory{ptx, kernel_name, true /* ptx_compressed */}); - return this; -} - -MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( - std::initializer_list spec_list, - absl::string_view kernel_name) { - CHECK(cuda_ptx_in_memory_ == nullptr); - cuda_ptx_in_memory_.reset( - new CudaPtxInMemory{spec_list, kernel_name, false /* ptx_compressed */}); - return this; -} - -MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( - std::initializer_list spec_list, - absl::string_view kernel_name) { - CHECK(cuda_ptx_in_memory_ == nullptr); - cuda_ptx_in_memory_.reset( - new CudaPtxInMemory{spec_list, kernel_name, true /* ptx_compressed */}); +MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddLlvmHostKernel( + absl::string_view ir, absl::string_view entrypoint, + absl::string_view kernel_name, absl::Span options) { + CHECK(llvm_host_kernel_ == nullptr); + llvm_host_kernel_ = + std::make_shared(ir, entrypoint, kernel_name, options); return this; } diff --git a/xla/stream_executor/kernel_spec.h b/xla/stream_executor/kernel_spec.h index 6144944306bef..d50ac23713dc5 100644 --- a/xla/stream_executor/kernel_spec.h +++ b/xla/stream_executor/kernel_spec.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ limitations under the License. // static const MultiKernelLoaderSpec &SaxpySpec() { // static auto *mkls = // (new MultiKernelLoaderSpec{4 /* = arity */}) -// ->AddCudaPtxOnDisk(ptx_file_path, ptx_kernel_name); +// ->AddCudaPtxInMemory(ptx_bytes, ptx_kernel_name); // }; // // return *mkls; @@ -34,7 +34,7 @@ limitations under the License. // // This lazily instantiates an object that describes how to load CUDA PTX // present on disk that implements saxpy for the CUDA platform. The -// CudaPtxOnDisk object is a subtype of KernelLoaderSpec -- KernelLoaderSpec +// CudaPtxInMemory object is a subtype of KernelLoaderSpec -- KernelLoaderSpec // describes how to load a kernel for subsequent launching on a single platform. // // For the loader functionality that accepts these KernelLoaderSpecs in order @@ -45,6 +45,7 @@ limitations under the License. #include +#include #include #include #include @@ -52,15 +53,14 @@ limitations under the License. #include #include -#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform/port.h" +#include "absl/types/span.h" #include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace stream_executor { +class Kernel; // defined in kernel.h class KernelArgs; // defined in kernel.h class KernelArgsPackedArrayBase; // defined in kernel.h @@ -76,7 +76,7 @@ class KernelArgsPackedArrayBase; // defined in kernel.h // files at build time, but can also be specified manually. class KernelLoaderSpec { public: - virtual ~KernelLoaderSpec() {} + virtual ~KernelLoaderSpec() = default; // Returns the kernel name to load out of the program. const std::string &kernel_name() const { return kernel_name_; } @@ -105,61 +105,6 @@ class InProcessSymbol : public KernelLoaderSpec { void *symbol_; }; -// An abstract kernel loader spec that has an associated file path, where -// there's a canonical suffix for the filename; e.g. see CudaPtxOnDisk whose -// canonical filename suffix is ".ptx". -class OnDiskKernelLoaderSpec : public KernelLoaderSpec { - public: - ~OnDiskKernelLoaderSpec() override {} - - // Returns the path to the on-disk loadable kernel file. - const std::string &filename() const { return filename_; } - - // Returns the canonical suffix for this on-disk kernel loader spec format; - // e.g. PTX files on disk have a canonical suffix of ".ptx". - virtual const char *CanonicalSuffix() const = 0; - - protected: - OnDiskKernelLoaderSpec(absl::string_view filename, - absl::string_view kernel_name); - - std::string filename_; - - private: - OnDiskKernelLoaderSpec(const OnDiskKernelLoaderSpec &) = delete; - void operator=(const OnDiskKernelLoaderSpec &) = delete; -}; - -// Kernel loader specification for PTX text that resides on disk. -class CudaPtxOnDisk : public OnDiskKernelLoaderSpec { - public: - CudaPtxOnDisk(absl::string_view filename, absl::string_view kernel_name); - ~CudaPtxOnDisk() override {} - - const char *CanonicalSuffix() const override { return ".ptx"; } - - private: - CudaPtxOnDisk(const CudaPtxOnDisk &) = delete; - void operator=(const CudaPtxOnDisk &) = delete; -}; - -// Kernel loader specification for CUBIN binary that resides on disk. -class CudaCubinOnDisk : public OnDiskKernelLoaderSpec { - public: - CudaCubinOnDisk(absl::string_view filename, absl::string_view kernel_name); - ~CudaCubinOnDisk() override {} - - const std::string &filename() const { return filename_; } - - const char *CanonicalSuffix() const override { return ".cubin"; } - - private: - std::string filename_; - - CudaCubinOnDisk(const CudaCubinOnDisk &) = delete; - void operator=(const CudaCubinOnDisk &) = delete; -}; - // Kernel loader specification for PTX text that resides in memory. class CudaPtxInMemory : public KernelLoaderSpec { public: @@ -175,15 +120,13 @@ class CudaPtxInMemory : public KernelLoaderSpec { // // Warning: the string backing the provided absl::string_view ptx must outlive // this instance. - CudaPtxInMemory(absl::string_view ptx, absl::string_view kernel_name, - bool ptx_compressed = false); + CudaPtxInMemory(absl::string_view ptx, absl::string_view kernel_name); // Multiple-PTX-version constructor. Adds each item in spec_list to this // object. Note that the PTX can be compressed, which is indicated by the // argument ptx_compressed. CudaPtxInMemory(const std::initializer_list &spec_list, - absl::string_view kernel_name, bool ptx_compressed = false); - ~CudaPtxInMemory() override {} + absl::string_view kernel_name); // Add the PTX implementation described by ptx_spec to this object. On // collision (i.e., if a version with the same compute_capability already @@ -194,41 +137,20 @@ class CudaPtxInMemory : public KernelLoaderSpec { // lowest-valued compute capability. For example, if PTX written to CC2.0, // 3.0, and 3.5 are all available, the version for CC2.0 will be set. Returns // nullptr on failed lookup (if any version is not available). - // When the ptx is compressed, returns the decompressed ptx. const char *default_text() const; - // Similar to default_text(). - // When the ptx is compressed, returns the decompressed ptx. - const char *original_default_text() const; - // Returns pointer to the ptx for the requested compute capability. // Returns nullptr on failed lookup (if the requested version is not // available). - // When the ptx is compressed, returns the decompressed ptx. const char *text(int compute_capability_major, int compute_capability_minor) const; - // Similar to text(). - // When the ptx is compressed, returns the original compressed ptx. - const char *original_text(int compute_capability_major, - int compute_capability_minor) const; - - // Decompresses the PTX string using bzip2. - static std::string DecompressPtx(const char *ptx); - private: // PTX translation unit text contents in memory. The key is of as a tuple // ",", i.e., "2,0", "3,0", "3,5". Because CC's // represented in this way have a clear sorting order, map::begin() will give // the lowest-numbered version available, i.e. the default. - std::map, const char *, - bool (*)(const std::tuple &, const std::tuple &)> - ptx_by_compute_capability_; - - // Stores all decompressed ptx strings, with original ptx string as keys. - // It is marked as mutable for lazy decompression. - mutable std::map decompressed_ptx_; - mutable absl::Mutex mu_; + std::map, const char *> ptx_by_compute_capability_; // Defines the minimum compute capability possible. Used when PTX has no // compute capability specified (in the single-PTX constructor). @@ -241,18 +163,37 @@ class CudaPtxInMemory : public KernelLoaderSpec { // Kernel loader specification for a CUBIN blob that resides in memory. class CudaCubinInMemory : public KernelLoaderSpec { public: - CudaCubinInMemory(const char *bytes, absl::string_view kernel_name); - ~CudaCubinInMemory() override {} + CudaCubinInMemory(absl::Span cubin_bytes, + absl::string_view kernel_name); - const char *bytes() const { return bytes_; } + absl::Span cubin_bytes() const { return cubin_bytes_; } private: - const char *bytes_; + absl::Span cubin_bytes_; CudaCubinInMemory(const CudaCubinInMemory &) = delete; void operator=(const CudaCubinInMemory &) = delete; }; +class LlvmHostKernel : public KernelLoaderSpec { + public: + LlvmHostKernel(absl::string_view ir, absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); + + absl::string_view ir() const { return ir_; } + absl::string_view entrypoint() const { return entrypoint_; } + absl::Span options() const { return options_; } + + private: + std::string ir_; + std::string entrypoint_; + std::vector options_; + + LlvmHostKernel(const LlvmHostKernel &) = delete; + void operator=(const LlvmHostKernel &) = delete; +}; + // Describes how to load a kernel on any subset of a number of target platforms. class MultiKernelLoaderSpec { public: @@ -261,8 +202,8 @@ class MultiKernelLoaderSpec { // registering custom CUDA C++ kernels with non-trivial C++ API with a // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = - std::function>( - const KernelArgs &args)>; + std::function>( + const Kernel &kernel, const KernelArgs &args)>; explicit MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing = nullptr); @@ -273,12 +214,11 @@ class MultiKernelLoaderSpec { // Convenience getters for testing whether these platform variants have // kernel loader specifications available. bool has_in_process_symbol() const { return in_process_symbol_ != nullptr; } - bool has_cuda_ptx_on_disk() const { return cuda_ptx_on_disk_ != nullptr; } - bool has_cuda_cubin_on_disk() const { return cuda_cubin_on_disk_ != nullptr; } bool has_cuda_cubin_in_memory() const { return cuda_cubin_in_memory_ != nullptr; } bool has_cuda_ptx_in_memory() const { return cuda_ptx_in_memory_ != nullptr; } + bool has_llvm_host_kernel() const { return llvm_host_kernel_ != nullptr; } // Accessors for platform variant kernel load specifications. // Precondition: corresponding has_* is true. @@ -286,14 +226,6 @@ class MultiKernelLoaderSpec { CHECK(has_in_process_symbol()); return *in_process_symbol_; } - const CudaPtxOnDisk &cuda_ptx_on_disk() const { - CHECK(has_cuda_ptx_on_disk()); - return *cuda_ptx_on_disk_; - } - const CudaCubinOnDisk &cuda_cubin_on_disk() const { - CHECK(has_cuda_cubin_on_disk()); - return *cuda_cubin_on_disk_; - } const CudaCubinInMemory &cuda_cubin_in_memory() const { CHECK(has_cuda_cubin_in_memory()); return *cuda_cubin_in_memory_; @@ -302,6 +234,10 @@ class MultiKernelLoaderSpec { CHECK(has_cuda_ptx_in_memory()); return *cuda_ptx_in_memory_; } + const LlvmHostKernel &llvm_host_kernel() const { + CHECK(has_llvm_host_kernel()); + return *llvm_host_kernel_; + } // Builder-pattern-like methods for use in initializing a // MultiKernelLoaderSpec. Each of these should be used at most once for a // single MultiKernelLoaderSpec object. See file comment for example usage. @@ -311,22 +247,14 @@ class MultiKernelLoaderSpec { // mangled by the compiler if it is not declared in an extern "C" scope. MultiKernelLoaderSpec *AddInProcessSymbol(void *symbol, absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaPtxOnDisk(absl::string_view filename, - absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaCubinOnDisk(absl::string_view filename, - absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaCubinInMemory(const char *cubin_bytes, - absl::string_view kernel_name); + MultiKernelLoaderSpec *AddCudaCubinInMemory( + absl::Span cubin_bytes, absl::string_view kernel_name); MultiKernelLoaderSpec *AddCudaPtxInMemory(absl::string_view ptx, absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory( - absl::string_view ptx, absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaPtxInMemory( - std::initializer_list spec_list, - absl::string_view kernel_name); - MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory( - std::initializer_list spec_list, - absl::string_view kernel_name); + MultiKernelLoaderSpec *AddLlvmHostKernel(absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); const KernelArgsPacking &kernel_args_packing() const { return kernel_args_packing_; @@ -335,14 +263,12 @@ class MultiKernelLoaderSpec { private: std::shared_ptr in_process_symbol_; // In process symbol pointer. - std::shared_ptr - cuda_ptx_on_disk_; // PTX text that resides in a file. - std::shared_ptr - cuda_cubin_on_disk_; // Binary CUDA program in a file. std::shared_ptr cuda_cubin_in_memory_; // Binary CUDA program in memory. std::shared_ptr cuda_ptx_in_memory_; // PTX text that resides in memory. + std::shared_ptr + llvm_host_kernel_; // LLVM kernel for host execution. // Number of parameters that the kernel takes. (This is nicer to have in a // constexpr than having to determine it from the types via template diff --git a/xla/stream_executor/kernel_test.cc b/xla/stream_executor/kernel_test.cc index 0397a4db08b51..205c559170e72 100644 --- a/xla/stream_executor/kernel_test.cc +++ b/xla/stream_executor/kernel_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,9 @@ limitations under the License. #include #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -63,7 +66,7 @@ static_assert( std::tuple>); static std::unique_ptr NewStreamExecutor() { - Platform* platform = MultiPlatformManager::PlatformWithName("Host").value(); + Platform* platform = PlatformManager::PlatformWithName("Host").value(); StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } @@ -103,11 +106,8 @@ TEST(KernelTest, PackPodArguments) { ASSERT_EQ(f64, 3.0); } -TEST(KernelTest, PackTypedKernelArguments) { - auto executor = NewStreamExecutor(); - TypedKernel kernel(executor.get()); - - auto args = PackKernelArgs(kernel, 1, 2.0f, 3.0); +TEST(KernelTest, PackTupleArguments) { + auto args = PackKernelArgs(/*shmem_bytes=*/0, 1, 2.0f, 3.0); ASSERT_EQ(args->number_of_arguments(), 3); auto packed = args->argument_addresses(); @@ -120,6 +120,14 @@ TEST(KernelTest, PackTypedKernelArguments) { ASSERT_EQ(f64, 3.0); } +TEST(KernelTest, FailToCreateTypedKernelFromEmptySpec) { + MultiKernelLoaderSpec empty_spec(/*arity=*/0); + + auto executor = NewStreamExecutor(); + auto kernel = TypedKernel<>::Create(executor.get(), empty_spec); + EXPECT_FALSE(kernel.ok()); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/xla/stream_executor/launch_dim.h b/xla/stream_executor/launch_dim.h index ef7d03311bb62..59b935c1ac757 100644 --- a/xla/stream_executor/launch_dim.h +++ b/xla/stream_executor/launch_dim.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Types to express dimensionality of a kernel launch. Blocks and threads -// are (up to) 3-dimensional. -// -// A thread is conceptually like a SIMD lane. Some number, typically 32 -// (though that fact should not be relied on) SIMD lanes are tied together with -// a single PC in a unit called a warp. There is a maximum number of threads -// that can execute in a shared-context entity called a block. Presently, that -// number is 1024 -- again, something that should not be relied on from this -// comment, but checked via stream_executor::DeviceDescription. -// -// For additional information, see -// http://docs.nvidia.com/cuda/kepler-tuning-guide/#device-utilization-and-occupancy -// -// Because of that modest thread-per-block limit, a kernel can be launched with -// multiple blocks. Each block is indivisibly scheduled onto a single core. -// Blocks can also be used in a multi-dimensional configuration, and the block -// count has much less modest limits -- typically they're similar to the maximum -// amount of addressable memory. - #ifndef XLA_STREAM_EXECUTOR_LAUNCH_DIM_H_ #define XLA_STREAM_EXECUTOR_LAUNCH_DIM_H_ @@ -41,33 +22,56 @@ limitations under the License. #include "absl/strings/str_cat.h" namespace stream_executor { +namespace internal { + +struct Dim3D { + uint64_t x, y, z; + + bool operator==(const Dim3D& other) const { + return x == other.x && y == other.y && z == other.z; + } + + bool operator!=(const Dim3D& other) const { return !(*this == other); } +}; -// Thread dimensionality for use in a kernel launch. See file comment for +} // namespace internal + +// Types to express dimensionality of a kernel launch. Blocks, threads and +// clusters are (up to) 3-dimensional. +// +// See NVIDIA documentation for a thread hierarchy: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#thread-hierarchy + +// Thread dimensionality for use in a kernel launch. // details. -struct ThreadDim { +struct ThreadDim : internal::Dim3D { explicit ThreadDim(uint64_t x = 1, uint64_t y = 1, uint64_t z = 1) - : x(x), y(y), z(z) {} + : internal::Dim3D({x, y, z}) {} - // Returns a string representation of the thread dimensionality. std::string ToString() const { return absl::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}"); } - - uint64_t x, y, z; }; -// Block dimensionality for use in a kernel launch. See file comment for +// Block dimensionality for use in a kernel launch. // details. -struct BlockDim { +struct BlockDim : internal::Dim3D { explicit BlockDim(uint64_t x = 1, uint64_t y = 1, uint64_t z = 1) - : x(x), y(y), z(z) {} + : internal::Dim3D({x, y, z}) {} - // Returns a string representation of the block dimensionality. std::string ToString() const { return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}"); } +}; - uint64_t x, y, z; +// Cluster dimensionality for use in a kernel launch. +struct ClusterDim : internal::Dim3D { + explicit ClusterDim(uint64_t x = 1, uint64_t y = 1, uint64_t z = 1) + : internal::Dim3D({x, y, z}) {} + + std::string ToString() const { + return absl::StrCat("ClusterDim{", x, ", ", y, ", ", z, "}"); + } }; } // namespace stream_executor diff --git a/xla/stream_executor/lazy_op_runner.h b/xla/stream_executor/lazy_op_runner.h index db7ba0b487b30..c54a9bfdac005 100644 --- a/xla/stream_executor/lazy_op_runner.h +++ b/xla/stream_executor/lazy_op_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,18 +17,35 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_LAZY_OP_RUNNER_H_ #include +#include #include #include #include #include #include "absl/base/call_once.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/blas.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream.h" +#include "tsl/platform/statusor.h" +#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace dnn { +namespace internal { +// Returns the DnnSupport object for the given stream. +inline absl::StatusOr GetDnnFromStream(Stream* stream) { + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN support for stream"); + } + return dnn; +} +} // namespace internal + // A lazily-initialized OpRunner from an AlgorithmDesc. // // This exists to hold a choice of conv algorithm for a particular config, @@ -57,10 +74,10 @@ class LazyOpRunner { public: // Construct from a pre-initialized OpRunner; all calls to GetOrCreateRunner // will return a pointer to exactly this runner. - static tsl::StatusOr> FromOpRunner( + static absl::StatusOr> FromOpRunner( std::unique_ptr> runner) { if (!runner) { - return tsl::errors::Internal("Null runner argument to FromOpRunner"); + return absl::InternalError("Null runner argument to FromOpRunner"); } TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); // Private constructor cannot be called by make_unique :( @@ -82,7 +99,7 @@ class LazyOpRunner { // executor will be errors. // // The result is owned by LazyOpRunner. - tsl::StatusOr*> GetOrCreateRunner( + absl::StatusOr*> GetOrCreateRunner( typename Op::Config config, Stream* stream) { absl::call_once(once_flag_, [&] { if (runner_) return; // runner was passed via constructor argument @@ -100,11 +117,11 @@ class LazyOpRunner { } // Get the contained runner with the invariant that it's already initialized. - tsl::StatusOr*> GetRunner() { + absl::StatusOr*> GetRunner() { if (auto* runner = runner_ptr_.load(std::memory_order_acquire)) { return runner; } - return tsl::errors::Internal("LazyOpRunner::GetRunner: not initialized"); + return absl::InternalError("LazyOpRunner::GetRunner: not initialized"); } bool operator==(const LazyOpRunner& other) const { @@ -119,7 +136,7 @@ class LazyOpRunner { LazyOpRunner(AlgorithmDesc desc, std::unique_ptr> runner) : desc_(std::move(desc)), - error_(tsl::OkStatus()), + error_(absl::OkStatus()), runner_(std::move(runner)), runner_ptr_(runner_.get()) {} @@ -127,7 +144,7 @@ class LazyOpRunner { // We use absl::call_once to lazily initialize `runner_` (or `error_`). absl::once_flag once_flag_; - tsl::Status error_; // holds error if runner can't be initialized + absl::Status error_; // holds error if runner can't be initialized std::unique_ptr> runner_; // Once we initialize `runner_` we publish a pointer through atomic so that @@ -148,11 +165,12 @@ struct ConvOp { const ConvolutionDescriptor& convolution_descriptor; }; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->ConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.output_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->ConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.output_type, config.input_descriptor, config.filter_descriptor, config.output_descriptor, config.convolution_descriptor); } @@ -173,11 +191,12 @@ struct GraphConvOp { std::string serialized_graph; }; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->GraphConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.output_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->GraphConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.output_type, config.input_descriptor, config.filter_descriptor, config.output_descriptor, config.convolution_descriptor, config.serialized_graph); @@ -200,11 +219,12 @@ struct FusedConvOp { ActivationMode activation_mode; }; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedConvolveRunnerFromDesc( - desc, config.kind, config.input_type, config.bias_type, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedConvolveRunnerFromDesc( + stream, desc, config.kind, config.input_type, config.bias_type, config.output_type, config.conv_scale, config.side_input_scale, config.leakyrelu_alpha, config.input_descriptor, config.filter_descriptor, config.bias_descriptor, @@ -218,22 +238,29 @@ struct NormOp { using Signature = NormSignature; struct Config { + NormKind kind; double epsilon; - const TensorDescriptor& input_descriptor; + const TensorDescriptor& x_descriptor; const TensorDescriptor& scale_descriptor; - const TensorDescriptor& bias_descriptor; - const TensorDescriptor& output_descriptor; - std::optional expectation_descriptor; - std::optional norm_factor_descriptor; + const TensorDescriptor& y_or_dx_descriptor; + std::optional bias_descriptor; + std::optional dy_descriptor; + std::optional expectation_descriptor; + std::optional norm_factor_descriptor; + std::optional dscale_descriptor; + std::optional dbias_descriptor; }; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->NormRunnerFromDesc( - desc, config.epsilon, config.input_descriptor, config.scale_descriptor, - config.bias_descriptor, config.output_descriptor, - config.expectation_descriptor, config.norm_factor_descriptor); + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->NormRunnerFromDesc( + stream, desc, config.kind, config.epsilon, config.x_descriptor, + config.scale_descriptor, config.y_or_dx_descriptor, + config.bias_descriptor, config.dy_descriptor, + config.expectation_descriptor, config.norm_factor_descriptor, + config.dscale_descriptor, config.dbias_descriptor); } }; @@ -246,10 +273,10 @@ struct FusedMatmulOp { // this feature. struct Config {}; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return tsl::errors::Unimplemented("Unimplemented"); + return absl::UnimplementedError("Unimplemented"); } }; @@ -272,11 +299,12 @@ struct FusedMHAOp { bool is_causal_mask; }; - static tsl::StatusOr>> + static absl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedMHARunnerFromDesc( - desc, config.kind, config.bmm1_lhs_descriptor, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedMHARunnerFromDesc( + stream, desc, config.kind, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, config.output_descriptor, config.activation_descriptor, config.mask_descriptor, @@ -310,12 +338,13 @@ struct FusedMHABackwardOp { bool is_causal_mask; }; - static tsl::StatusOr< + static absl::StatusOr< std::unique_ptr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { - return stream->FusedMHABackwardRunnerFromDesc( - desc, config.kind, config.bmm1_grad_gemm1_rhs_descriptor, + TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); + return dnn->FusedMHABackwardRunnerFromDesc( + stream, desc, config.kind, config.bmm1_grad_gemm1_rhs_descriptor, config.bmm1_grad_gemm2_rhs_descriptor, config.bmm2_grad_gemm1_lhs_descriptor, config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, diff --git a/xla/stream_executor/memory_allocation.h b/xla/stream_executor/memory_allocation.h new file mode 100644 index 0000000000000..0e0df2442001e --- /dev/null +++ b/xla/stream_executor/memory_allocation.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ +#define XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ + +#include + +namespace stream_executor { + +// An RAII handle for a memory allocated for a device. It can be pinned host +// memory, unified memory, device memory, etc. depending on what kinds of +// memories are supported by underlying device. +class MemoryAllocation { + public: + MemoryAllocation() = default; + virtual ~MemoryAllocation() = default; + + MemoryAllocation(MemoryAllocation&&) = delete; + MemoryAllocation& operator=(MemoryAllocation&&) = delete; + + virtual void* opaque() const = 0; + virtual uint64_t size() const = 0; +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MEMORY_ALLOCATION_H_ diff --git a/xla/stream_executor/module_spec.h b/xla/stream_executor/module_spec.h index cfee31e514c14..eb5c54a939bef 100644 --- a/xla/stream_executor/module_spec.h +++ b/xla/stream_executor/module_spec.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "xla/stream_executor/platform/port.h" #include "tsl/platform/logging.h" namespace stream_executor { diff --git a/xla/stream_executor/multi_platform_manager.cc b/xla/stream_executor/multi_platform_manager.cc deleted file mode 100644 index 8ae90a9cd4f31..0000000000000 --- a/xla/stream_executor/multi_platform_manager.cc +++ /dev/null @@ -1,303 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/multi_platform_manager.h" - -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/errors.h" - -namespace stream_executor { -namespace { - -class MultiPlatformManagerImpl { - public: - tsl::Status RegisterPlatform(std::unique_ptr platform) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr PlatformWithName(absl::string_view target) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr PlatformWithId(const Platform::Id& id) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr PlatformWithName(absl::string_view target, - bool initialize_platform) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr PlatformWithId(const Platform::Id& id, - bool initialize_platform) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr InitializePlatformWithName( - absl::string_view target, - const std::map& options) - ABSL_LOCKS_EXCLUDED(mu_); - tsl::StatusOr InitializePlatformWithId( - const Platform::Id& id, const std::map& options) - ABSL_LOCKS_EXCLUDED(mu_); - - tsl::StatusOr> PlatformsWithFilter( - const std::function& filter, - bool initialize_platform) ABSL_LOCKS_EXCLUDED(mu_); - - private: - // Looks up the platform object with the given name. Assumes the Platforms - // mutex is held. - tsl::StatusOr LookupByNameLocked(absl::string_view target) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Looks up the platform object with the given id. Assumes the Platforms - // mutex is held. - tsl::StatusOr LookupByIdLocked(const Platform::Id& id) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Returns the names of the initialized platforms satisfying the given filter. - // By default, it will return all initialized platform names. - std::vector InitializedPlatformNamesWithFilter( - const std::function& filter = [](const Platform*) { - return true; - }) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - absl::Mutex mu_; - absl::flat_hash_map id_map_ ABSL_GUARDED_BY(mu_); - absl::flat_hash_map name_map_ ABSL_GUARDED_BY(mu_); -}; - -tsl::Status MultiPlatformManagerImpl::RegisterPlatform( - std::unique_ptr platform) { - CHECK(platform != nullptr); - std::string key = absl::AsciiStrToLower(platform->Name()); - absl::MutexLock lock(&mu_); - if (name_map_.find(key) != name_map_.end()) { - return tsl::Status(absl::StatusCode::kInternal, - "platform is already registered with name: \"" + - platform->Name() + "\""); - } - Platform* platform_ptr = platform.get(); - CHECK(id_map_.emplace(platform->id(), platform_ptr).second); - // Release ownership/uniqueness to prevent destruction on program exit. - // This avoids Platforms "cleaning up" on program exit, because otherwise, - // there are _very_ tricky races between StreamExecutor and underlying - // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per - // program, these are deemed acceptable. - name_map_[key] = platform.release(); - return ::tsl::OkStatus(); -} - -tsl::StatusOr MultiPlatformManagerImpl::PlatformWithName( - absl::string_view target) { - return PlatformWithName(target, /*initialize_platform=*/true); -} - -tsl::StatusOr MultiPlatformManagerImpl::PlatformWithId( - const Platform::Id& id) { - return PlatformWithId(id, /*initialize_platform=*/true); -} - -tsl::StatusOr MultiPlatformManagerImpl::PlatformWithName( - absl::string_view target, bool initialize_platform) { - absl::MutexLock lock(&mu_); - - TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); - if (initialize_platform && !platform->Initialized()) { - TF_RETURN_IF_ERROR(platform->Initialize({})); - } - - return platform; -} - -tsl::StatusOr MultiPlatformManagerImpl::PlatformWithId( - const Platform::Id& id, bool initialize_platform) { - absl::MutexLock lock(&mu_); - - TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); - if (initialize_platform && !platform->Initialized()) { - TF_RETURN_IF_ERROR(platform->Initialize({})); - } - - return platform; -} - -tsl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithName( - absl::string_view target, - const std::map& options) { - absl::MutexLock lock(&mu_); - - TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); - if (platform->Initialized()) { - return tsl::Status( - absl::StatusCode::kFailedPrecondition, - absl::StrCat("platform \"", target, "\" is already initialized")); - } - - TF_RETURN_IF_ERROR(platform->Initialize(options)); - - return platform; -} - -tsl::StatusOr MultiPlatformManagerImpl::InitializePlatformWithId( - const Platform::Id& id, const std::map& options) { - absl::MutexLock lock(&mu_); - - TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); - if (platform->Initialized()) { - return tsl::Status( - absl::StatusCode::kFailedPrecondition, - absl::StrFormat("platform with id %p is already initialized", id)); - } - - TF_RETURN_IF_ERROR(platform->Initialize(options)); - - return platform; -} - -tsl::StatusOr> -MultiPlatformManagerImpl::PlatformsWithFilter( - const std::function& filter, - bool initialize_platform) { - absl::MutexLock lock(&mu_); - CHECK_EQ(id_map_.size(), name_map_.size()); - std::vector platforms; - platforms.reserve(id_map_.size()); - for (const auto& entry : id_map_) { - Platform* platform = entry.second; - if (filter(platform)) { - if (initialize_platform && !platform->Initialized()) { - TF_RETURN_IF_ERROR(platform->Initialize({})); - } - platforms.push_back(platform); - } - } - return platforms; -} - -std::vector -MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter( - const std::function& filter) { - CHECK_EQ(id_map_.size(), name_map_.size()); - std::vector initialized_platforms_names; - initialized_platforms_names.reserve(id_map_.size()); - for (const auto& entry : id_map_) { - Platform* platform = entry.second; - if (filter(platform)) { - if (platform->Initialized()) { - initialized_platforms_names.push_back(platform->Name()); - } - } - } - return initialized_platforms_names; -} - -tsl::StatusOr MultiPlatformManagerImpl::LookupByNameLocked( - absl::string_view target) { - auto it = name_map_.find(absl::AsciiStrToLower(target)); - if (it == name_map_.end()) { - return tsl::Status( - absl::StatusCode::kNotFound, - absl::StrCat("Could not find registered platform with name: \"", target, - "\". Available platform names are: ", - absl::StrJoin(InitializedPlatformNamesWithFilter(), " "))); - } - return it->second; -} - -tsl::StatusOr MultiPlatformManagerImpl::LookupByIdLocked( - const Platform::Id& id) { - auto it = id_map_.find(id); - if (it == id_map_.end()) { - return tsl::Status( - absl::StatusCode::kNotFound, - absl::StrFormat("could not find registered platform with id: %p", id)); - } - return it->second; -} - -MultiPlatformManagerImpl& Impl() { - static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl; - return *impl; -} - -} // namespace - -/*static*/ tsl::Status MultiPlatformManager::RegisterPlatform( - std::unique_ptr platform) { - return Impl().RegisterPlatform(std::move(platform)); -} - -/*static*/ tsl::StatusOr MultiPlatformManager::PlatformWithName( - absl::string_view target) { - return Impl().PlatformWithName(target); -} - -/*static*/ tsl::StatusOr MultiPlatformManager::PlatformWithId( - const Platform::Id& id) { - return Impl().PlatformWithId(id); -} - -/*static*/ tsl::StatusOr MultiPlatformManager::PlatformWithName( - absl::string_view target, bool initialize_platform) { - return Impl().PlatformWithName(target, initialize_platform); -} - -/*static*/ tsl::StatusOr -MultiPlatformManager::InitializePlatformWithId( - const Platform::Id& id, const std::map& options) { - return Impl().InitializePlatformWithId(id, options); -} - -/*static*/ tsl::StatusOr> -MultiPlatformManager::PlatformsWithFilter( - const std::function& filter) { - return PlatformsWithFilter(filter, /*initialize_platform=*/true); -} - -/*static*/ tsl::StatusOr> -MultiPlatformManager::PlatformsWithFilter( - const std::function& filter, - bool initialize_platform) { - return Impl().PlatformsWithFilter(filter, initialize_platform); -} - -} // namespace stream_executor - -REGISTER_MODULE_INITIALIZER( - multi_platform_manager, - { - // Nothing -- this is just a module initializer - // definition to reference for sequencing - // purposes from Platform subclasses that register - // themselves with the MultiPlatformManager. - }); - -REGISTER_MODULE_INITIALIZER( - multi_platform_manager_listener, - { - // Nothing -- this is just a module initializer definition to reference - // for sequencing registration of listeners with the - // MultiPlatformManager. - }); - -// Listener registration should happen before platform registration. -REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, - multi_platform_manager); diff --git a/xla/stream_executor/multi_platform_manager.h b/xla/stream_executor/multi_platform_manager.h deleted file mode 100644 index b326983404d9b..0000000000000 --- a/xla/stream_executor/multi_platform_manager.h +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is a registration-oriented interface for multiple platforms. -// -// Usage: -// -// In your BUILD rule, add a dependency on a platform plugin that you'd like -// to use, such as: -// -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cuda_platform -// //third_party/tensorflow/compiler/xla/stream_executor/opencl:opencl_platform -// -// This will register platform plugins that can be discovered via this -// interface. Sample API usage: -// -// tsl::StatusOr platform_status = -// se::MultiPlatformManager::PlatformWithName("OpenCL"); -// if (!platform_status.ok()) { ... } -// Platform* platform = platform_status.value(); -// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible"; -// if (platform->VisibleDeviceCount() <= 0) { return; } -// -// for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { -// tsl::StatusOr executor_status = -// platform->ExecutorForDevice(i); -// if (!executor_status.ok()) { -// LOG(INFO) << "could not retrieve executor for device ordinal " << i -// << ": " << executor_status.status(); -// continue; -// } -// LOG(INFO) << "found usable executor: " << executor_status.value(); -// } -// -// A few things to note: -// - There is no standard formatting/practice for identifying the name of a -// platform. Ideally, a platform will list its registered name in its header -// or in other associated documentation. -// - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even -// ("OpEnCl") would work correctly in the above example. -// -// And similarly, for standard interfaces (BLAS, etc.) you can add -// dependencies on support libraries, e.g.: -// -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:pluton_blas_plugin -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin -// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin - -#ifndef XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ -#define XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ - -#include -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { - -// Manages multiple platforms that may be present on the current machine. -class MultiPlatformManager { - public: - // Registers a platform object, returns an error status if the platform is - // already registered. The associated listener, if not null, will be used to - // trace events for ALL executors for that platform. - // Takes ownership of platform. - static tsl::Status RegisterPlatform(std::unique_ptr platform); - - // Retrieves the platform registered with the given platform name (e.g. - // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the - // Platform's Id() method). - // - // If the platform has not already been initialized, it will be initialized - // with a default set of parameters. - // - // If the requested platform is not registered, an error status is returned. - // Ownership of the platform is NOT transferred to the caller -- - // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static tsl::StatusOr PlatformWithName(absl::string_view target); - static tsl::StatusOr PlatformWithId(const Platform::Id& id); - - // Same functions as above, but allows platforms to be returned without - // initialization if initialize_platform == false. - static tsl::StatusOr PlatformWithName(absl::string_view target, - bool initialize_platform); - - // Retrieves the platform registered with the given platform id (an opaque, - // comparable value provided by the Platform's Id() method). - // - // The platform will be initialized with the given options. If the platform - // was already initialized, an error will be returned. - // - // If the requested platform is not registered, an error status is returned. - // Ownership of the platform is NOT transferred to the caller -- - // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static tsl::StatusOr InitializePlatformWithId( - const Platform::Id& id, - const std::map& options); - - // Retrieves the platforms satisfying the given filter, i.e. returns true. - // Returned Platforms are always initialized. - static tsl::StatusOr> PlatformsWithFilter( - const std::function& filter); - - static tsl::StatusOr> PlatformsWithFilter( - const std::function& filter, - bool initialize_platform); - - // Although the MultiPlatformManager "owns" its platforms, it holds them as - // undecorated pointers to prevent races during program exit (between this - // object's data and the underlying platforms (e.g., CUDA, OpenCL). - // Because certain platforms have unpredictable deinitialization - // times/sequences, it is not possible to strucure a safe deinitialization - // sequence. Thus, we intentionally "leak" allocated platforms to defer - // cleanup to the OS. This should be acceptable, as these are one-time - // allocations per program invocation. - // The MultiPlatformManager should be considered the owner - // of any platforms registered with it, and leak checking should be disabled - // during allocation of such Platforms, to avoid spurious reporting at program - // exit. -}; - -} // namespace stream_executor - -// multi_platform_manager.cc will define these instances. -// -// Registering a platform: -// REGISTER_MODULE_INITIALIZER_SEQUENCE(my_platform, multi_platform_manager); -// REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, -// my_platform); -// -// Registering a listener: -// REGISTER_MODULE_INITIALIZER_SEQUENCE(my_listener, -// multi_platform_manager_listener); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); -DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener); - -#endif // XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ diff --git a/xla/stream_executor/numeric_options.h b/xla/stream_executor/numeric_options.h index f75dfb54ff65b..5620d3ad45def 100644 --- a/xla/stream_executor/numeric_options.h +++ b/xla/stream_executor/numeric_options.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/platform.cc b/xla/stream_executor/platform.cc index 801e381ffa817..b1c2680868cc0 100644 --- a/xla/stream_executor/platform.cc +++ b/xla/stream_executor/platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,8 +19,6 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/stream_executor/device_options.h" -#include "tsl/platform/status.h" namespace stream_executor { @@ -35,23 +33,22 @@ std::string StreamPriorityToString(StreamPriority priority) { } } -StreamExecutorConfig::StreamExecutorConfig() - : ordinal(-1), device_options(DeviceOptions::Default()) {} +StreamExecutorConfig::StreamExecutorConfig() : ordinal(-1) {} StreamExecutorConfig::StreamExecutorConfig(int ordinal_in) - : ordinal(ordinal_in), device_options(DeviceOptions::Default()) {} + : ordinal(ordinal_in) {} Platform::~Platform() {} bool Platform::Initialized() const { return true; } -tsl::Status Platform::Initialize( +absl::Status Platform::Initialize( const std::map &platform_options) { if (!platform_options.empty()) { return absl::UnimplementedError( "this platform does not support custom initialization"); } - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace stream_executor diff --git a/xla/stream_executor/platform.h b/xla/stream_executor/platform.h index b03d046220bbb..e9e27bba23598 100644 --- a/xla/stream_executor/platform.h +++ b/xla/stream_executor/platform.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,11 +23,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_options.h" -#include "xla/stream_executor/platform/port.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -44,8 +42,7 @@ std::string StreamPriorityToString(StreamPriority priority); // StreamExecutorConfig encapsulates the set of options for constructing a // StreamExecutor for a given platform. struct StreamExecutorConfig { - // Sets members to defaults: -1 for ordinal (must be changed), and default - // PluginConfig and DeviceOptions. + // Sets members to defaults: -1 for ordinal (must be changed). StreamExecutorConfig(); // Simple ordinal-setting constructor. @@ -57,12 +54,9 @@ struct StreamExecutorConfig { // The ordinal of the device to be managed by the returned StreamExecutor. int ordinal; - - // The DeviceOptions for the returned StreamExecutor. - DeviceOptions device_options; }; -// Abstract base class for a platform registered with the MultiPlatformManager. +// Abstract base class for a platform registered with the PlatformManager. class Platform { public: virtual ~Platform(); @@ -103,9 +97,9 @@ class Platform { // initialized before obtaining StreamExecutor objects. The interpretation of // the platform_options argument is implementation specific. This method may // return an error if unrecognized options are provided. If using - // MultiPlatformManager, this method will be called automatically by + // PlatformManager, this method will be called automatically by // InitializePlatformWithId/InitializePlatformWithName. - virtual tsl::Status Initialize( + virtual absl::Status Initialize( const std::map& platform_options); // Returns a populated DeviceDescription for the device at the given ordinal. @@ -114,7 +108,7 @@ class Platform { // // Alternatively callers may call GetDeviceDescription() on the StreamExecutor // which returns a cached instance specific to the initialized StreamExecutor. - virtual tsl::StatusOr> + virtual absl::StatusOr> DescriptionForDevice(int ordinal) const = 0; // Returns a device with the given ordinal on this platform with a default @@ -124,17 +118,17 @@ class Platform { // // Ownership of the executor is NOT transferred to the caller -- // the Platform owns the executors in a singleton-like fashion. - virtual tsl::StatusOr ExecutorForDevice(int ordinal) = 0; + virtual absl::StatusOr ExecutorForDevice(int ordinal) = 0; // Returns a device constructed with the options specified in "config". // Ownership of the executor is NOT transferred to the caller. - virtual tsl::StatusOr GetExecutor( + virtual absl::StatusOr GetExecutor( const StreamExecutorConfig& config) = 0; // Returns a device constructed with the options specified in "config" without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. - virtual tsl::StatusOr> GetUncachedExecutor( + virtual absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) = 0; protected: diff --git a/xla/stream_executor/platform/BUILD b/xla/stream_executor/platform/BUILD index 6ac277f9a52ac..0833c9dd84112 100644 --- a/xla/stream_executor/platform/BUILD +++ b/xla/stream_executor/platform/BUILD @@ -1,11 +1,11 @@ -load("@tsl//tsl:tsl.bzl", "set_external_visibility") -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:build_config.bzl", "tf_stream_executor_deps") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) diff --git a/xla/stream_executor/platform/default/BUILD b/xla/stream_executor/platform/default/BUILD index 010fa41fc3314..908e58d5dd241 100644 --- a/xla/stream_executor/platform/default/BUILD +++ b/xla/stream_executor/platform/default/BUILD @@ -1,5 +1,5 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@tsl//tsl:tsl.bzl", "tsl_copts") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) @@ -27,5 +27,7 @@ cc_library( deps = [ "@com_google_absl//absl/strings", "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/stream_executor/platform/default/dso_loader.h b/xla/stream_executor/platform/default/dso_loader.h index baf9de18876c6..93f95747615a1 100644 --- a/xla/stream_executor/platform/default/dso_loader.h +++ b/xla/stream_executor/platform/default/dso_loader.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,10 +21,10 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { namespace internal { diff --git a/xla/stream_executor/platform/default/initialize.h b/xla/stream_executor/platform/default/initialize.h index 6928d6adbb0c2..cb951ed8b0611 100644 --- a/xla/stream_executor/platform/default/initialize.h +++ b/xla/stream_executor/platform/default/initialize.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "third_party/tensorflow/compiler/xla/stream_executor/platform/initialize.h" + #ifndef XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ #define XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ -#undef REGISTER_MODULE_INITIALIZER -#undef DECLARE_MODULE_INITIALIZER -#undef REGISTER_MODULE_INITIALIZER_SEQUENCE - namespace stream_executor { namespace port { @@ -44,19 +42,20 @@ class Initializer { } // namespace port } // namespace stream_executor -#define REGISTER_INITIALIZER(type, name, body) \ +#define STREAM_EXECUTOR_REGISTER_INITIALIZER(type, name, body) \ static void google_init_##type##_##name() { body; } \ ::stream_executor::port::Initializer google_initializer_##type##_##name( \ google_init_##type##_##name) -#define REGISTER_MODULE_INITIALIZER(name, body) \ - REGISTER_INITIALIZER(module, name, body) +#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(name, body) \ + STREAM_EXECUTOR_REGISTER_INITIALIZER(module, name, body) -#define DECLARE_INITIALIZER(type, name) \ +#define STREAM_EXECUTOR_DECLARE_INITIALIZER(type, name) \ extern ::stream_executor::port::Initializer google_initializer_##type##_##name -#define DECLARE_MODULE_INITIALIZER(name) DECLARE_INITIALIZER(module, name) +#define STREAM_EXECUTOR_DECLARE_MODULE_INITIALIZER(name) \ + STREAM_EXECUTOR_DECLARE_INITIALIZER(module, name) -#define REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) +#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) #endif // XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ diff --git a/xla/stream_executor/platform/dso_loader.h b/xla/stream_executor/platform/dso_loader.h index 41538c0ef067a..bfd1e061f9d82 100644 --- a/xla/stream_executor/platform/dso_loader.h +++ b/xla/stream_executor/platform/dso_loader.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/platform/initialize.h b/xla/stream_executor/platform/initialize.h index d56bda8383ae2..910b011634318 100644 --- a/xla/stream_executor/platform/initialize.h +++ b/xla/stream_executor/platform/initialize.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/platform/platform.h b/xla/stream_executor/platform/platform.h index 3ea6fa66d362f..3b00ab8cc6476 100644 --- a/xla/stream_executor/platform/platform.h +++ b/xla/stream_executor/platform/platform.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/platform/port.h b/xla/stream_executor/platform/port.h index dfb9c1997baf3..6cd6654061501 100644 --- a/xla/stream_executor/platform/port.h +++ b/xla/stream_executor/platform/port.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/platform_manager.cc b/xla/stream_executor/platform_manager.cc new file mode 100644 index 0000000000000..4773b85bf2c1c --- /dev/null +++ b/xla/stream_executor/platform_manager.cc @@ -0,0 +1,285 @@ +/* Copyright 2015 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/platform_manager.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/stream_executor/platform.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor { +namespace { + +class PlatformManagerImpl { + public: + absl::Status RegisterPlatform(std::unique_ptr platform) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr PlatformWithName(absl::string_view target) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr PlatformWithId(const Platform::Id& id) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr PlatformWithName(absl::string_view target, + bool initialize_platform) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr PlatformWithId(const Platform::Id& id, + bool initialize_platform) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr InitializePlatformWithName( + absl::string_view target, + const std::map& options) + ABSL_LOCKS_EXCLUDED(mu_); + absl::StatusOr InitializePlatformWithId( + const Platform::Id& id, const std::map& options) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::StatusOr> PlatformsWithFilter( + const std::function& filter, + bool initialize_platform) ABSL_LOCKS_EXCLUDED(mu_); + + private: + // Looks up the platform object with the given name. Assumes the Platforms + // mutex is held. + absl::StatusOr LookupByNameLocked(absl::string_view target) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Looks up the platform object with the given id. Assumes the Platforms + // mutex is held. + absl::StatusOr LookupByIdLocked(const Platform::Id& id) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the names of the initialized platforms satisfying the given filter. + // By default, it will return all initialized platform names. + std::vector InitializedPlatformNamesWithFilter( + const std::function& filter = [](const Platform*) { + return true; + }) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Mutex mu_; + absl::flat_hash_map id_map_ ABSL_GUARDED_BY(mu_); + absl::flat_hash_map name_map_ ABSL_GUARDED_BY(mu_); +}; + +absl::Status PlatformManagerImpl::RegisterPlatform( + std::unique_ptr platform) { + CHECK(platform != nullptr); + std::string key = absl::AsciiStrToLower(platform->Name()); + absl::MutexLock lock(&mu_); + if (name_map_.find(key) != name_map_.end()) { + return absl::InternalError("platform is already registered with name: \"" + + platform->Name() + "\""); + } + Platform* platform_ptr = platform.get(); + CHECK(id_map_.emplace(platform->id(), platform_ptr).second); + // Release ownership/uniqueness to prevent destruction on program exit. + // This avoids Platforms "cleaning up" on program exit, because otherwise, + // there are _very_ tricky races between StreamExecutor and underlying + // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per + // program, these are deemed acceptable. + name_map_[key] = platform.release(); + return absl::OkStatus(); +} + +absl::StatusOr PlatformManagerImpl::PlatformWithName( + absl::string_view target) { + return PlatformWithName(target, /*initialize_platform=*/true); +} + +absl::StatusOr PlatformManagerImpl::PlatformWithId( + const Platform::Id& id) { + return PlatformWithId(id, /*initialize_platform=*/true); +} + +absl::StatusOr PlatformManagerImpl::PlatformWithName( + absl::string_view target, bool initialize_platform) { + absl::MutexLock lock(&mu_); + + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); + if (initialize_platform && !platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + + return platform; +} + +absl::StatusOr PlatformManagerImpl::PlatformWithId( + const Platform::Id& id, bool initialize_platform) { + absl::MutexLock lock(&mu_); + + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + if (initialize_platform && !platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + + return platform; +} + +absl::StatusOr PlatformManagerImpl::InitializePlatformWithName( + absl::string_view target, + const std::map& options) { + absl::MutexLock lock(&mu_); + + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); + if (platform->Initialized()) { + return absl::FailedPreconditionError( + absl::StrCat("platform \"", target, "\" is already initialized")); + } + + TF_RETURN_IF_ERROR(platform->Initialize(options)); + + return platform; +} + +absl::StatusOr PlatformManagerImpl::InitializePlatformWithId( + const Platform::Id& id, const std::map& options) { + absl::MutexLock lock(&mu_); + + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + if (platform->Initialized()) { + return absl::FailedPreconditionError( + absl::StrFormat("platform with id %p is already initialized", id)); + } + + TF_RETURN_IF_ERROR(platform->Initialize(options)); + + return platform; +} + +absl::StatusOr> PlatformManagerImpl::PlatformsWithFilter( + const std::function& filter, + bool initialize_platform) { + absl::MutexLock lock(&mu_); + CHECK_EQ(id_map_.size(), name_map_.size()); + std::vector platforms; + platforms.reserve(id_map_.size()); + for (const auto& entry : id_map_) { + Platform* platform = entry.second; + if (filter(platform)) { + if (initialize_platform && !platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + platforms.push_back(platform); + } + } + return platforms; +} + +std::vector +PlatformManagerImpl::InitializedPlatformNamesWithFilter( + const std::function& filter) { + CHECK_EQ(id_map_.size(), name_map_.size()); + std::vector initialized_platforms_names; + initialized_platforms_names.reserve(id_map_.size()); + for (const auto& entry : id_map_) { + Platform* platform = entry.second; + if (filter(platform)) { + if (platform->Initialized()) { + initialized_platforms_names.push_back(platform->Name()); + } + } + } + return initialized_platforms_names; +} + +absl::StatusOr PlatformManagerImpl::LookupByNameLocked( + absl::string_view target) { + auto it = name_map_.find(absl::AsciiStrToLower(target)); + if (it == name_map_.end()) { + return absl::NotFoundError( + absl::StrCat("Could not find registered platform with name: \"", target, + "\". Available platform names are: ", + absl::StrJoin(InitializedPlatformNamesWithFilter(), " "))); + } + return it->second; +} + +absl::StatusOr PlatformManagerImpl::LookupByIdLocked( + const Platform::Id& id) { + auto it = id_map_.find(id); + if (it == id_map_.end()) { + return absl::NotFoundError( + absl::StrFormat("could not find registered platform with id: %p", id)); + } + return it->second; +} + +PlatformManagerImpl& Impl() { + static PlatformManagerImpl* impl = new PlatformManagerImpl; + return *impl; +} + +} // namespace + +/*static*/ absl::Status PlatformManager::RegisterPlatform( + std::unique_ptr platform) { + return Impl().RegisterPlatform(std::move(platform)); +} + +/*static*/ absl::StatusOr PlatformManager::PlatformWithName( + absl::string_view target) { + return Impl().PlatformWithName(target); +} + +/*static*/ absl::StatusOr PlatformManager::PlatformWithId( + const Platform::Id& id) { + return Impl().PlatformWithId(id); +} + +/*static*/ absl::StatusOr PlatformManager::PlatformWithName( + absl::string_view target, bool initialize_platform) { + return Impl().PlatformWithName(target, initialize_platform); +} + +/*static*/ absl::StatusOr PlatformManager::InitializePlatformWithId( + const Platform::Id& id, const std::map& options) { + return Impl().InitializePlatformWithId(id, options); +} + +/*static*/ absl::StatusOr> +PlatformManager::PlatformsWithFilter( + const std::function& filter) { + return PlatformsWithFilter(filter, /*initialize_platform=*/true); +} + +/*static*/ absl::StatusOr> +PlatformManager::PlatformsWithFilter( + const std::function& filter, + bool initialize_platform) { + return Impl().PlatformsWithFilter(filter, initialize_platform); +} + +} // namespace stream_executor diff --git a/xla/stream_executor/platform_manager.h b/xla/stream_executor/platform_manager.h new file mode 100644 index 0000000000000..4f88d2a79f2d2 --- /dev/null +++ b/xla/stream_executor/platform_manager.h @@ -0,0 +1,142 @@ +/* Copyright 2015 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a registration-oriented interface for multiple platforms. +// +// Usage: +// +// In your BUILD rule, add a dependency on a platform plugin that you'd like +// to use, such as: +// +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cuda_platform +// //third_party/tensorflow/compiler/xla/stream_executor/opencl:opencl_platform +// +// This will register platform plugins that can be discovered via this +// interface. Sample API usage: +// +// absl::StatusOr platform_status = +// se::PlatformManager::PlatformWithName("OpenCL"); +// if (!platform_status.ok()) { ... } +// Platform* platform = platform_status.value(); +// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible"; +// if (platform->VisibleDeviceCount() <= 0) { return; } +// +// for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { +// absl::StatusOr executor_status = +// platform->ExecutorForDevice(i); +// if (!executor_status.ok()) { +// LOG(INFO) << "could not retrieve executor for device ordinal " << i +// << ": " << executor_status.status(); +// continue; +// } +// LOG(INFO) << "found usable executor: " << executor_status.value(); +// } +// +// A few things to note: +// - There is no standard formatting/practice for identifying the name of a +// platform. Ideally, a platform will list its registered name in its header +// or in other associated documentation. +// - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even +// ("OpEnCl") would work correctly in the above example. +// +// And similarly, for standard interfaces (BLAS, etc.) you can add +// dependencies on support libraries, e.g.: +// +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:pluton_blas_plugin +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin +// //third_party/tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin + +#ifndef XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ +#define XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/platform.h" + +namespace stream_executor { + +// Manages multiple platforms that may be present on the current machine. +class PlatformManager { + public: + // Registers a platform object, returns an error status if the platform is + // already registered. The associated listener, if not null, will be used to + // trace events for ALL executors for that platform. + // Takes ownership of platform. + static absl::Status RegisterPlatform(std::unique_ptr platform); + + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // If the platform has not already been initialized, it will be initialized + // with a default set of parameters. + // + // If the requested platform is not registered, an error status is returned. + // Ownership of the platform is NOT transferred to the caller -- + // the PlatformManager owns the platforms in a singleton-like fashion. + static absl::StatusOr PlatformWithName(absl::string_view target); + static absl::StatusOr PlatformWithId(const Platform::Id& id); + + // Same functions as above, but allows platforms to be returned without + // initialization if initialize_platform == false. + static absl::StatusOr PlatformWithName(absl::string_view target, + bool initialize_platform); + + // Retrieves the platform registered with the given platform id (an opaque, + // comparable value provided by the Platform's Id() method). + // + // The platform will be initialized with the given options. If the platform + // was already initialized, an error will be returned. + // + // If the requested platform is not registered, an error status is returned. + // Ownership of the platform is NOT transferred to the caller -- + // the PlatformManager owns the platforms in a singleton-like fashion. + static absl::StatusOr InitializePlatformWithId( + const Platform::Id& id, + const std::map& options); + + // Retrieves the platforms satisfying the given filter, i.e. returns true. + // Returned Platforms are always initialized. + static absl::StatusOr> PlatformsWithFilter( + const std::function& filter); + + static absl::StatusOr> PlatformsWithFilter( + const std::function& filter, + bool initialize_platform); + + // Although the PlatformManager "owns" its platforms, it holds them as + // undecorated pointers to prevent races during program exit (between this + // object's data and the underlying platforms (e.g., CUDA, OpenCL). + // Because certain platforms have unpredictable deinitialization + // times/sequences, it is not possible to strucure a safe deinitialization + // sequence. Thus, we intentionally "leak" allocated platforms to defer + // cleanup to the OS. This should be acceptable, as these are one-time + // allocations per program invocation. + // The PlatformManager should be considered the owner + // of any platforms registered with it, and leak checking should be disabled + // during allocation of such Platforms, to avoid spurious reporting at program + // exit. +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_ diff --git a/xla/stream_executor/plugin_registry.cc b/xla/stream_executor/plugin_registry.cc index 1007722e1af06..8e5f772dbf603 100644 --- a/xla/stream_executor/plugin_registry.cc +++ b/xla/stream_executor/plugin_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,9 +15,16 @@ limitations under the License. #include "xla/stream_executor/plugin_registry.h" +#include +#include + #include "absl/base/const_init.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "xla/stream_executor/platform.h" namespace stream_executor { @@ -54,37 +61,38 @@ PluginRegistry::PluginRegistry() {} } template -tsl::Status PluginRegistry::RegisterFactoryInternal( +absl::Status PluginRegistry::RegisterFactoryInternal( const std::string& plugin_name, FACTORY_TYPE factory, std::optional* factories) { absl::MutexLock lock{&GetPluginRegistryMutex()}; if (factories->has_value()) { - return tsl::Status( - absl::StatusCode::kAlreadyExists, + return absl::AlreadyExistsError( absl::StrFormat("Attempting to register factory for plugin %s when " "one has already been registered", plugin_name)); } (*factories) = factory; - return ::tsl::OkStatus(); + return absl::OkStatus(); } bool PluginRegistry::HasFactory(Platform::Id platform_id, PluginKind plugin_kind) const { auto iter = factories_.find(platform_id); - if (iter != factories_.end()) { - switch (plugin_kind) { - case PluginKind::kBlas: - return iter->second.blas.has_value(); - case PluginKind::kDnn: - return iter->second.dnn.has_value(); - case PluginKind::kFft: - return iter->second.fft.has_value(); - default: - break; - } + if (iter == factories_.end()) { + return false; + } + + switch (plugin_kind) { + case PluginKind::kBlas: + return iter->second.blas.has_value(); + case PluginKind::kDnn: + return iter->second.dnn.has_value(); + case PluginKind::kFft: + return iter->second.fft.has_value(); + default: + break; } LOG(ERROR) << "Invalid plugin kind specified: " @@ -95,13 +103,13 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, // Explicit instantiations to support types exposed in user/public API. #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \ \ - template tsl::Status \ + template absl::Status \ PluginRegistry::RegisterFactoryInternal( \ const std::string& plugin_name, PluginRegistry::FACTORY_TYPE factory, \ std::optional* factories); \ \ template <> \ - tsl::Status PluginRegistry::RegisterFactory( \ + absl::Status PluginRegistry::RegisterFactory( \ Platform::Id platform_id, const std::string& name, \ PluginRegistry::FACTORY_TYPE factory) { \ return RegisterFactoryInternal(name, factory, \ @@ -109,13 +117,12 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, } \ \ template <> \ - tsl::StatusOr PluginRegistry::GetFactory( \ + absl::StatusOr PluginRegistry::GetFactory( \ Platform::Id platform_id) { \ auto plugin_id = factories_[platform_id].FACTORY_VAR; \ \ if (!plugin_id.has_value()) { \ - return tsl::Status( \ - absl::StatusCode::kFailedPrecondition, \ + return absl::FailedPreconditionError( \ "No suitable " PLUGIN_STRING \ " plugin registered. Have you linked in a " PLUGIN_STRING \ "-providing plugin?"); \ diff --git a/xla/stream_executor/plugin_registry.h b/xla/stream_executor/plugin_registry.h index ab106599e2e61..aa3372fc04966 100644 --- a/xla/stream_executor/plugin_registry.h +++ b/xla/stream_executor/plugin_registry.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,15 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_ #include +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -65,17 +67,17 @@ class PluginRegistry { // Returns a non-successful status if the factory has already been registered // with that platform (but execution should be otherwise unaffected). template - tsl::Status RegisterFactory(Platform::Id platform_id, const std::string& name, - FactoryT factory); + absl::Status RegisterFactory(Platform::Id platform_id, + const std::string& name, FactoryT factory); // Return true if the factory/kind has been registered for the // specified platform and plugin kind and false otherwise. bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind) const; // Retrieves the factory registered for the specified kind, - // or a tsl::Status on error. + // or a absl::Status on error. template - tsl::StatusOr GetFactory(Platform::Id platform_id); + absl::StatusOr GetFactory(Platform::Id platform_id); private: // Containers for the sets of registered factories, by plugin kind. @@ -89,9 +91,9 @@ class PluginRegistry { // Actually performs the work of registration. template - tsl::Status RegisterFactoryInternal(const std::string& plugin_name, - FactoryT factory, - std::optional* factories); + absl::Status RegisterFactoryInternal(const std::string& plugin_name, + FactoryT factory, + std::optional* factories); // Returns true if the specified plugin has been registered with the specified // platform factories. Unlike the other overload of this method, this does @@ -109,13 +111,13 @@ class PluginRegistry { }; // Explicit specializations are defined in plugin_registry.cc. -#define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE) \ - template <> \ - tsl::Status PluginRegistry::RegisterFactory( \ - Platform::Id platform_id, const std::string& name, \ - PluginRegistry::FACTORY_TYPE factory); \ - template <> \ - tsl::StatusOr PluginRegistry::GetFactory( \ +#define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE) \ + template <> \ + absl::Status PluginRegistry::RegisterFactory( \ + Platform::Id platform_id, const std::string& name, \ + PluginRegistry::FACTORY_TYPE factory); \ + template <> \ + absl::StatusOr PluginRegistry::GetFactory( \ Platform::Id platform_id) DECLARE_PLUGIN_SPECIALIZATIONS(BlasFactory); diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 8d62e4a4a1939..05121d3748a5a 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -1,23 +1,31 @@ # Description: # ROCm-platform specific StreamExecutor support code. +# buildifier: disable=out-of-order-load + +# buildifier: disable=out-of-order-load load( "//xla/stream_executor:build_defs.bzl", "stream_executor_friends", ) + +# copybara:comment_begin(oss-only) +load("//xla/stream_executor/rocm:build_defs.bzl", "rocm_embedded_test_modules") + +# copybara:comment_end load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", "if_rocm_is_configured", "rocm_library", ) -load("@tsl//tsl:tsl.bzl", "set_external_visibility", "tsl_copts") +load("@tsl//tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@tsl//tsl/platform:build_config_root.bzl", "if_static") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([":friends"]), + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -53,8 +61,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", - "//xla/stream_executor:stream_executor_headers", - "//xla/stream_executor:device_options", + "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", @@ -65,6 +72,40 @@ cc_library( ]), ) +cc_library( + name = "rocm_runtime", + srcs = if_rocm_is_configured(["rocm_runtime.cc"]), + hdrs = if_rocm_is_configured([ + "rocm_driver_wrapper.h", + "rocm_driver.h", + ]), + deps = if_rocm_is_configured([ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_runtime_header", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/platform", + "//xla/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:statusor", + ]), +) + +cc_library( + name = "rocm_collectives", + srcs = if_rocm_is_configured(["rocm_collectives.cc"]), + deps = if_rocm_is_configured([ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "//xla/stream_executor/gpu:gpu_collectives_header", + "//xla/stream_executor/gpu:gpu_driver_header", + ]), +) + cc_library( name = "rocm_activation", srcs = [], @@ -84,7 +125,7 @@ cc_library( srcs = if_rocm_is_configured(["rocm_event.cc"]), deps = if_rocm_is_configured([ ":rocm_driver", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", @@ -92,14 +133,16 @@ cc_library( ) cc_library( - name = "rocm_gpu_executor", - srcs = if_rocm_is_configured(["rocm_gpu_executor.cc"]), + name = "rocm_executor", + srcs = if_rocm_is_configured(["rocm_executor.cc"]), deps = if_rocm_is_configured([ ":rocm_diagnostics", ":rocm_driver", ":rocm_event", ":rocm_kernel", ":rocm_platform_id", + ":rocm_runtime", + ":rocm_collectives", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", "//xla/stream_executor", @@ -113,7 +156,9 @@ cc_library( "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/integrations:device_mem_allocator", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:fingerprint", ]), alwayslink = True, ) @@ -141,14 +186,16 @@ cc_library( visibility = ["//visibility:public"], deps = if_rocm_is_configured([ ":rocm_driver", - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", + ":rocm_runtime", + ":rocm_collectives", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "//xla/stream_executor", # buildcleaner: keep "//xla/stream_executor/platform", ]), - alwayslink = True, # Registers itself with the MultiPlatformManager. + alwayslink = True, # Registers itself with the PlatformManager. ) cc_library( @@ -178,13 +225,13 @@ cc_library( hdrs = if_rocm_is_configured(["rocblas_wrapper.h"]), deps = if_rocm_is_configured([ ":rocblas_if_static", - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@tsl//tsl/platform:env", - "@tsl//tsl/util:determinism_for_kernels", + "//xla/tsl/util:determinism_for_kernels", ]), alwayslink = True, ) @@ -198,7 +245,7 @@ cc_library( ":rocblas_if_static", ":rocblas_wrapper", ":hipblas_lt_header", - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", "@eigen_archive//:eigen3", "//xla/stream_executor", @@ -285,7 +332,7 @@ cc_library( ":miopen_if_static", ":rocm_diagnostics", ":rocm_driver", - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", "@eigen_archive//:eigen3", "//xla/stream_executor", @@ -295,6 +342,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/algorithm:container", @@ -302,8 +350,8 @@ cc_library( "@com_google_absl//absl/types:span", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:env", - "@tsl//tsl/util:env_var", - "@tsl//tsl/util:determinism_for_kernels", + "//xla/tsl/util:env_var", + "//xla/tsl/util:determinism_for_kernels", ]), alwayslink = True, ) @@ -342,7 +390,7 @@ cc_library( hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]), deps = if_rocm_is_configured([ ":hipsparse_if_static", - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", "@local_config_rocm//rocm:rocm_headers", "//xla/stream_executor/platform", @@ -371,7 +419,7 @@ cc_library( srcs = if_rocm_is_configured(["rocsolver_wrapper.h"]), hdrs = if_rocm_is_configured(["rocsolver_wrapper.h"]), deps = if_rocm_is_configured([ - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", ":rocsolver_if_static", "@local_config_rocm//rocm:rocm_headers", @@ -401,7 +449,7 @@ cc_library( srcs = if_rocm_is_configured(["hipsolver_wrapper.h"]), hdrs = if_rocm_is_configured(["hipsolver_wrapper.h"]), deps = if_rocm_is_configured([ - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", ":hipsolver_if_static", "@local_config_rocm//rocm:rocm_headers", @@ -420,7 +468,7 @@ cc_library( ) cc_library( - name = "hipblaslt_plugin", + name = "amdhipblaslt_plugin", srcs = if_rocm_is_configured(["hip_blas_lt.cc"]), hdrs = if_rocm_is_configured([ "hip_blas_lt.h", @@ -428,7 +476,7 @@ cc_library( "hip_blas_utils.h", ]), deps = if_rocm_is_configured([ - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", ":rocblas_plugin", ":hip_blas_utils", @@ -463,7 +511,7 @@ cc_library( "@tsl//tsl/platform:errors", "//xla:status", "//xla/stream_executor:host_or_device_scalar", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor/platform", ]), ) @@ -479,7 +527,7 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:status", "@tsl//tsl/platform:errors", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", ]), ) @@ -502,7 +550,7 @@ cc_library( srcs = if_rocm_is_configured(["roctracer_wrapper.h"]), hdrs = if_rocm_is_configured(["roctracer_wrapper.h"]), deps = if_rocm_is_configured([ - ":rocm_gpu_executor", + ":rocm_executor", ":rocm_platform_id", ":roctracer_if_static", "@local_config_rocm//rocm:rocm_headers", @@ -533,7 +581,7 @@ cc_library( ":rocm_driver", ":rocm_platform", ":rocm_helpers", - ":hipblaslt_plugin", + ":amdhipblaslt_plugin", ]), alwayslink = 1, ) @@ -558,3 +606,10 @@ cc_library( [":all_runtime"], ), ) + +# copybara:comment_begin(oss-only) +rocm_embedded_test_modules( + name = "add_i32_kernel", + srcs = if_rocm_is_configured(["add_i32_kernel.cu.cc"]), +) +# copybara:comment_end diff --git a/xla/stream_executor/rocm/add_i32_kernel.cu.cc b/xla/stream_executor/rocm/add_i32_kernel.cu.cc new file mode 100644 index 0000000000000..8a6406fe05e5f --- /dev/null +++ b/xla/stream_executor/rocm/add_i32_kernel.cu.cc @@ -0,0 +1,21 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +extern "C" __global__ void add(int32_t* a, int32_t* b, int32_t* c) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + c[index] = a[index] + b[index]; +} diff --git a/xla/stream_executor/rocm/build_defs.bzl b/xla/stream_executor/rocm/build_defs.bzl new file mode 100644 index 0000000000000..0be87739c8469 --- /dev/null +++ b/xla/stream_executor/rocm/build_defs.bzl @@ -0,0 +1,68 @@ +""" ROCM-specific build macros. +""" + +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_gpu_architectures") + +def rocm_embedded_test_modules(name, srcs, testonly = True, **kwargs): + """Compile srcs into hsaco files and create a header only cc_library. + + Binary files are embedded as constant data. + + Args: + name: name for the generated cc_library target, and the base name for + generated header file + srcs: source files for input modules + testonly: If True, the target can only be used with tests. + **kwargs: keyword arguments passed onto the generated cc_library() rule. + """ + + # Lets piggyback this on top crosstool wrapper for now + hipcc_tool = "@local_config_rocm//crosstool:crosstool_wrapper_driver_is_not_gcc" + target_opts = " ".join(["--amdgpu-target=" + + arch for arch in rocm_gpu_architectures()]) + + header_file = "%s.h" % name + + native.genrule( + name = name + "_header_file", + srcs = srcs, + outs = [header_file], + cmd = """ + tmp_name_for_xxd() { + local filename=$$(basename $$1) + local name="k" + for word in $$(echo $${filename%%%%.*} | tr '_' ' '); do + name="$$name$${word^}" + done + echo "$${name}Module" + } + + echo '#pragma once' > $@ + echo '#include ' >> $@ + for src in $(SRCS); do + tmp=$$(tmp_name_for_xxd $$src); + $(location %s) -x rocm %s --genco -c $$src -o $$tmp && xxd -i $$tmp | sed \ + -e 's/unsigned char/inline constexpr uint8_t/g' \ + -e '$$d' >> $@; + rm -f $$tmp + done + """ % (hipcc_tool, target_opts), + tools = [hipcc_tool], + testonly = testonly, + target_compatible_with = select({ + "@local_config_rocm//rocm:using_hipcc": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + ) + + native.cc_library( + name = name, + srcs = [], + hdrs = [header_file], + testonly = testonly, + target_compatible_with = select({ + "@local_config_rocm//rocm:using_hipcc": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + **kwargs + ) diff --git a/xla/stream_executor/rocm/hip_blas_lt.cc b/xla/stream_executor/rocm/hip_blas_lt.cc index 262a3a5c3122f..b09cfefef56a1 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/xla/stream_executor/rocm/hip_blas_lt.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -39,7 +39,7 @@ limitations under the License. // hipblasLtMatmulDescGetAttribute does not allow nullptr for the last // argument (size_t* sizeWritten) #define GET_ATTR(getter, handle, attr, ValueT) \ - [&]() -> tsl::StatusOr { \ + [&]() -> absl::StatusOr { \ ValueT value; \ size_t size; \ TF_RETURN_IF_ERROR(ToStatus( \ @@ -57,32 +57,32 @@ using ::xla::complex64; namespace { template -tsl::Status SetAttr(hipblasLtMatrixLayout_t handle, - hipblasLtMatrixLayoutAttribute_t attr, T value) { +absl::Status SetAttr(hipblasLtMatrixLayout_t handle, + hipblasLtMatrixLayoutAttribute_t attr, T value) { return SET_ATTR(wrap::hipblasLtMatrixLayoutSetAttribute, handle, attr, value); } template -tsl::StatusOr GetAttr(hipblasLtMatrixLayout_t handle, - hipblasLtMatrixLayoutAttribute_t attr) { +absl::StatusOr GetAttr(hipblasLtMatrixLayout_t handle, + hipblasLtMatrixLayoutAttribute_t attr) { return GET_ATTR(wrap::hipblasLtMatrixLayoutGetAttribute, handle, attr, T); } template -tsl::Status SetAttr(hipblasLtMatmulDesc_t handle, - hipblasLtMatmulDescAttributes_t attr, T value) { +absl::Status SetAttr(hipblasLtMatmulDesc_t handle, + hipblasLtMatmulDescAttributes_t attr, T value) { return SET_ATTR(wrap::hipblasLtMatmulDescSetAttribute, handle, attr, value); } template -tsl::StatusOr GetAttr(hipblasLtMatmulDesc_t handle, - hipblasLtMatmulDescAttributes_t attr) { +absl::StatusOr GetAttr(hipblasLtMatmulDesc_t handle, + hipblasLtMatmulDescAttributes_t attr) { return GET_ATTR(wrap::hipblasLtMatmulDescGetAttribute, handle, attr, T); } template -tsl::Status SetAttr(hipblasLtMatmulPreference_t handle, - hipblasLtMatmulPreferenceAttributes_t attr, T value) { +absl::Status SetAttr(hipblasLtMatmulPreference_t handle, + hipblasLtMatmulPreferenceAttributes_t attr, T value) { return SET_ATTR(wrap::hipblasLtMatmulPreferenceSetAttribute, handle, attr, value); } @@ -97,7 +97,7 @@ static hipblasPointerMode_t AsHipblasLtPointerMode( } } -static tsl::StatusOr AsHipblasLtEpilogue( +static absl::StatusOr AsHipblasLtEpilogue( gpu::BlasLt::Epilogue epilogue) { switch (epilogue) { case gpu::BlasLt::Epilogue::kDefault: @@ -110,61 +110,68 @@ static tsl::StatusOr AsHipblasLtEpilogue( return HIPBLASLT_EPILOGUE_RELU_BIAS; case gpu::BlasLt::Epilogue::kGELU: return HIPBLASLT_EPILOGUE_GELU; +#if TF_ROCM_VERSION >= 60000 + case gpu::BlasLt::Epilogue::kGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX; + case gpu::BlasLt::Epilogue::kBiasThenGELU: + return HIPBLASLT_EPILOGUE_GELU_BIAS; + case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux: + return HIPBLASLT_EPILOGUE_GELU_AUX_BIAS; +#endif default: - return tsl::errors::Internal("Unsupported epilogue"); + return absl::InternalError("Unsupported epilogue: " + + std::to_string((int)epilogue)); } } } // namespace -tsl::Status BlasLt::Init() { +absl::Status BlasLt::Init() { hipblasLtHandle_t blas_lt; SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtCreate(&blas_lt)); absl::MutexLock lock(&mu_); blas_lt_.reset(blas_lt); - return tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::StatusOr BlasLt::MatrixLayout::Create( +/*static*/ absl::StatusOr BlasLt::MatrixLayout::Create( const gpu::MatrixLayout& m) { TF_ASSIGN_OR_RETURN(auto type, gpu::AsBlasDataType(m.dtype)); - auto leading_dim_stride = m.leading_dim_stride; - if (!leading_dim_stride) { - leading_dim_stride = (m.order == gpu::MatrixLayout::Order::kRowMajor) - ? m.num_cols - : m.num_rows; - } auto hipblas_data_type_ = AsHipblasDataType(type); hipblasLtMatrixLayout_t hip_layout; SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatrixLayoutCreate( &hip_layout, hipblas_data_type_, m.num_rows, m.num_cols, - *leading_dim_stride)); + m.leading_dim_stride)); // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatrixLayout layout(hip_layout, hipblas_data_type_); if (m.order != gpu::MatrixLayout::Order::kColumnMajor) - return tsl::errors::Internal( - "HipblasLT does not support row-major matrices"); + return absl::InternalError("HipblasLT does not support row-major matrices"); TF_RETURN_IF_ERROR(SetAttr(hip_layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, static_cast(m.batch_size))); - auto batch_stride = m.batch_stride; - if (!batch_stride) { - batch_stride = (m.batch_size > 1) ? m.num_rows * m.num_cols : 0; - } - TF_RETURN_IF_ERROR(SetAttr( - hip_layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, *batch_stride)); + VLOG(2) << "BlasLt::MatrixLayout::Create type: " << (int)type + << " rows: " << m.num_rows << " cols: " << m.num_cols + << " batch_size: " << m.batch_size + << " leading_dim_stride: " << m.leading_dim_stride + << " batch_stride: " << m.batch_stride; + + TF_RETURN_IF_ERROR(SetAttr(hip_layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + m.batch_stride)); return std::move(layout); } -/*static*/ tsl::StatusOr BlasLt::MatmulDesc::Create( +/*static*/ absl::StatusOr BlasLt::MatmulDesc::Create( blas::ComputationType compute_type, blas::DataType scale_type, blas::Transpose trans_a, blas::Transpose trans_b, Epilogue epilogue, PointerMode pointer_mode) { hipblasLtMatmulDesc_t hip_desc; - VLOG(2) << "BlasLt::MatmulDesc::Create compute_type" << int(compute_type) - << " scale_type " << int(scale_type) << " epilogue " << int(epilogue) - << " pointer_mode " << int(pointer_mode); + VLOG(2) << "BlasLt::MatmulDesc::Create compute_type: " << int(compute_type) + << " scale_type: " << int(scale_type) + << " epilogue: " << int(epilogue) << " trans_a: " << int(trans_a) + << " trans_b: " << int(trans_b) << " pointer_mode " + << int(pointer_mode); auto hip_scale_type = AsHipblasDataType(scale_type); auto hip_compute_type = AsHipblasComputeType(compute_type); SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmulDescCreate( @@ -172,7 +179,7 @@ tsl::Status BlasLt::Init() { // Wrap hipblas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatmulDesc desc(hip_desc, hip_compute_type, hip_scale_type); if (pointer_mode != PointerMode::kHost) { - return tsl::errors::Internal("hipblaslt does not support device pointers"); + return absl::InternalError("hipblaslt does not support device pointers"); } TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_TRANSA, @@ -186,7 +193,7 @@ tsl::Status BlasLt::Init() { auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, size_t max_workspace_size) const - -> tsl::StatusOr> { + -> absl::StatusOr> { max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX}); std::vector results(max_algorithm_count); @@ -244,7 +251,7 @@ auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, } auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const - -> tsl::StatusOr { + -> absl::StatusOr { auto lhs_layout = cfg.lhs_layout, rhs_layout = cfg.rhs_layout, output_layout = cfg.output_layout, c_layout = cfg.c_layout; @@ -264,18 +271,15 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const // *not* be transposed, and if B is row-major, B must be transposed. We never // transpose A or B, and expect the caller to ensure A is row-major and B is // column when A and B are FP8. - auto trans_a = lhs_layout.transpose ? *lhs_layout.transpose - : blas::Transpose::kNoTranspose; - auto trans_b = rhs_layout.transpose ? *rhs_layout.transpose - : blas::Transpose::kNoTranspose; + auto trans_a = lhs_layout.transpose, trans_b = rhs_layout.transpose; if (xla::primitive_util::IsF8Type(lhs_layout.dtype) && lhs_layout.order == gpu::MatrixLayout::Order::kColumnMajor) { - return xla::InternalError("The F8 LHS must be column-major"); + return xla::Internal("The F8 LHS must be column-major"); } if (xla::primitive_util::IsF8Type(rhs_layout.dtype) && rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { - return xla::InternalError("The F8 RHS must be row-major"); + return xla::Internal("The F8 RHS must be row-major"); } TF_ASSIGN_OR_RETURN(auto output_dtype, @@ -283,9 +287,10 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const auto compute_type = cfg.compute_type; if (!compute_type) { // obtain compute_type unless provided by the user - TF_ASSIGN_OR_RETURN(compute_type, gpu::GetBlasComputationType( - lhs_layout.dtype, output_layout.dtype, - cfg.compute_precision)); + TF_ASSIGN_OR_RETURN(compute_type, + gpu::GetBlasComputationType( + cfg.precision_algorithm, lhs_layout.dtype, + output_layout.dtype, cfg.compute_precision)); } if (lhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { @@ -308,6 +313,30 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const TF_ASSIGN_OR_RETURN(auto c_desc, MatrixLayout::Create(c_layout)); TF_ASSIGN_OR_RETURN(auto d_desc, MatrixLayout::Create(output_layout)); +#if TF_ROCM_VERSION >= 60000 + // Currently, the default bias data type in hipblasLt is the same with output + // data type for fp8 matmul, which is different from cublasLt. This is a + // workaround to match cublasLt behavior. + if (epilogue == gpu::BlasLt::Epilogue::kBias) { + auto a_dtype = a_desc.type(); + auto b_dtype = b_desc.type(); + + auto bias_dtype = d_desc.type(); + if ((a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) && + (b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ)) { + auto d_dtype = d_desc.type(); + if (d_dtype == HIP_R_32F) { + bias_dtype = HIP_R_16BF; + } + + if (bias_dtype != d_dtype) { + TF_RETURN_IF_ERROR(SetAttr( + op_desc.get(), HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_dtype)); + } + } + } +#endif // TF_ROCM_VERSION >= 60000 + // std::make_unique won't work with brace initialization in C++17 ;( return std::make_unique(*this, std::move(op_desc), std::move(a_desc), std::move(b_desc), @@ -315,45 +344,45 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const cfg.alpha, cfg.beta, must_swap_operands); } -tsl::Status BlasLt::MatmulPlan::ValidateInputs( +absl::Status BlasLt::MatmulPlan::ValidateInputs( blas::DataType scale_type, bool alpha_on_device, bool beta_on_device, blas::DataType A_type, blas::DataType B_type, blas::DataType C_type, blas::DataType D_type) const { if (AsHipblasDataType(scale_type) != op_desc_.scale_type()) { - return tsl::errors::InvalidArgument("mismatched scale types"); + return absl::InvalidArgumentError("mismatched scale types"); } bool expect_scale_factor_on_device = (op_desc_.pointer_mode() == HIPBLAS_POINTER_MODE_DEVICE); if (alpha_on_device != expect_scale_factor_on_device) { - return tsl::errors::InvalidArgument("wrong location for alpha"); + return absl::InvalidArgumentError("wrong location for alpha"); } if (beta_on_device != expect_scale_factor_on_device) { - return tsl::errors::InvalidArgument("wrong location for beta"); + return absl::InvalidArgumentError("wrong location for beta"); } if (AsHipblasDataType(A_type) != a_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched A matrix types"); + return absl::InvalidArgumentError("mismatched A matrix types"); } if (AsHipblasDataType(B_type) != b_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched B matrix types"); + return absl::InvalidArgumentError("mismatched B matrix types"); } if (AsHipblasDataType(C_type) != c_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched C matrix types"); + return absl::InvalidArgumentError("mismatched C matrix types"); } if (AsHipblasDataType(D_type) != d_desc_.type()) { - return tsl::errors::InvalidArgument("mismatched D matrix types"); + return absl::InvalidArgumentError("mismatched D matrix types"); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status BlasLt::MatmulPlan::DoMatmul( +absl::Status BlasLt::MatmulPlan::DoMatmul( Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b, const void* beta, DeviceMemoryBase c, DeviceMemoryBase d, const MatmulAlgorithm& algorithm, ScratchAllocator& scratch_allocator, @@ -363,7 +392,9 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( blas::ProfileResult* profile_result) const { TF_ASSIGN_OR_RETURN( std::optional timer, - gpu::GpuTimer::CreateIfNeeded(gpu::AsGpuStream(stream), profile_result)); + gpu::GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + profile_result)); void* workspace = nullptr; if (algorithm.workspace_size > 0) { @@ -384,17 +415,34 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias.opaque())); } +#if TF_ROCM_VERSION >= 60000 + if (a_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, + a_scale.opaque())); + } + if (b_scale != nullptr) { + TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), + HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, + b_scale.opaque())); + } + if (c_scale != nullptr || d_scale != nullptr) { + return absl::InternalError( + "hipblaslt does not support c_scale or d_scale."); + } +#else if ((a_scale != nullptr) || (b_scale != nullptr) || (c_scale != nullptr) || (d_scale != nullptr)) { - return tsl::errors::Internal("hipblaslt does not support scale"); + return absl::InternalError("hipblaslt does not support scale"); } +#endif if (d_amax != nullptr) { - return tsl::errors::Internal("hipblaslt does not support amax"); + return absl::InternalError("hipblaslt does not support amax"); } if (aux != nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "hipblaslt does not support auxiliary inputs / outputs"); } @@ -407,7 +455,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( c_desc_.get(), d.opaque(), d_desc_.get(), palgo, workspace, algorithm.workspace_size, gpu::AsGpuStreamValue(stream))); } else { - return tsl::errors::Internal("hipblaslt: Invalid algorithm type"); + return absl::InternalError("hipblaslt: Invalid algorithm type"); } } @@ -418,7 +466,7 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } - return tsl::OkStatus(); + return absl::OkStatus(); } namespace { @@ -426,6 +474,17 @@ namespace { template struct HipToNativeT; +#if TF_ROCM_VERSION >= 60000 +template <> +struct HipToNativeT { + using type = tsl::float8_e4m3fnuz; +}; +template <> +struct HipToNativeT { + using type = tsl::float8_e5m2fnuz; +}; +#endif // TF_ROCM_VERSION >= 60000 + template <> struct HipToNativeT { using type = Eigen::bfloat16; @@ -453,7 +512,7 @@ struct HipToNativeT { } // namespace -tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( +absl::Status BlasLt::MatmulPlan::ExecuteOnStream( Stream* stream, DeviceMemoryBase a, DeviceMemoryBase b, DeviceMemoryBase c, DeviceMemoryBase d, DeviceMemoryBase bias, DeviceMemoryBase aux, DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, @@ -477,6 +536,23 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( profile_result); \ } +#if TF_ROCM_VERSION >= 60000 + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_32F, + HIP_R_32F) + + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, HIP_R_32F, + HIP_R_32F) + + TYPED_MATMUL(float, HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_8F_E5M2_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_32F, + HIP_R_32F) +#endif + // Other data types: TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF) TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F) @@ -489,7 +565,7 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( #undef TYPED_MATMUL - return xla::InternalError("Unexpected dtype"); + return xla::Internal("Unexpected dtype"); } } // namespace rocm diff --git a/xla/stream_executor/rocm/hip_blas_lt.h b/xla/stream_executor/rocm/hip_blas_lt.h index 0ab58918a66f4..54c5b024c60f3 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.h +++ b/xla/stream_executor/rocm/hip_blas_lt.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -13,13 +13,13 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ #define XLA_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ +#include "absl/status/status.h" #include "rocm/rocm_config.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/host_or_device_scalar.h" #include "xla/types.h" -#include "tsl/platform/status.h" #if TF_HIPBLASLT @@ -40,7 +40,7 @@ class BlasLt : public gpu::BlasLt { public: struct MatrixLayout { - static tsl::StatusOr Create(const gpu::MatrixLayout& m); + static absl::StatusOr Create(const gpu::MatrixLayout& m); hipDataType type() const { return datatype_; } hipblasLtMatrixLayout_t get() const { return handle_.get(); } @@ -56,7 +56,7 @@ class BlasLt : public gpu::BlasLt { class MatmulDesc { public: - static tsl::StatusOr Create( + static absl::StatusOr Create( blas::ComputationType compute_type, blas::DataType scale_type, blas::Transpose trans_a = blas::Transpose::kNoTranspose, blas::Transpose trans_b = blas::Transpose::kNoTranspose, @@ -99,7 +99,7 @@ class BlasLt : public gpu::BlasLt { ~MatmulPlan() override = default; - tsl::Status ExecuteOnStream( + absl::Status ExecuteOnStream( Stream* stream, DeviceMemoryBase a_buffer, DeviceMemoryBase b_buffer, DeviceMemoryBase c_buffer, DeviceMemoryBase d_buffer, DeviceMemoryBase bias_buffer, // may be null @@ -110,25 +110,25 @@ class BlasLt : public gpu::BlasLt { ScratchAllocator& scratch_allocator, blas::ProfileResult* profile_result = nullptr) const override; - tsl::StatusOr> GetAlgorithms( + absl::StatusOr> GetAlgorithms( size_t max_algorithm_count, size_t max_workspace_size) const override; protected: - tsl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device, - bool beta_on_device, blas::DataType A_type, - blas::DataType B_type, blas::DataType C_type, - blas::DataType D_type) const override; - - tsl::Status DoMatmul(Stream* stream, const void* alpha, DeviceMemoryBase a, - DeviceMemoryBase b, const void* beta, - DeviceMemoryBase c, DeviceMemoryBase d, - const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - DeviceMemoryBase bias, DeviceMemoryBase aux, - DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, - DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, - DeviceMemoryBase d_amax, - blas::ProfileResult* profile_result) const override; + absl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device, + bool beta_on_device, blas::DataType A_type, + blas::DataType B_type, blas::DataType C_type, + blas::DataType D_type) const override; + + absl::Status DoMatmul(Stream* stream, const void* alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const void* beta, + DeviceMemoryBase c, DeviceMemoryBase d, + const MatmulAlgorithm& algorithm, + ScratchAllocator& scratch_allocator, + DeviceMemoryBase bias, DeviceMemoryBase aux, + DeviceMemoryBase a_scale, DeviceMemoryBase b_scale, + DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, + DeviceMemoryBase d_amax, + blas::ProfileResult* profile_result) const override; private: const BlasLt& blas_lt_ref_; @@ -146,10 +146,10 @@ class BlasLt : public gpu::BlasLt { explicit BlasLt(gpu::GpuExecutor* parent) : parent_(parent), blas_lt_(nullptr, wrap::hipblasLtDestroy) {} - tsl::Status Init() override; + absl::Status Init() override; - tsl::StatusOr GetMatmulPlan(const gpu::GemmConfig& cfg, - Epilogue epilogue) const override; + absl::StatusOr GetMatmulPlan(const gpu::GemmConfig& cfg, + Epilogue epilogue) const override; ~BlasLt() override = default; diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index 8bd0be07c5346..a59c935614cd8 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,20 +23,30 @@ limitations under the License. namespace stream_executor { namespace rocm { -tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) { +absl::Status ToStatus(hipblasStatus_t status, const char* prefix) { if (status != HIPBLAS_STATUS_SUCCESS) { - return tsl::errors::Internal(absl::StrCat( + return absl::InternalError(absl::StrCat( prefix, ": ", "HipblasLt error " + std::to_string(static_cast(status)))); } - return tsl::OkStatus(); + return absl::OkStatus(); } hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8 yet"; + LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; +#if TF_ROCM_VERSION >= 60000 + case blas::DataType::kF8E5M2FNUZ: + return HIP_R_8F_E5M2_FNUZ; + case blas::DataType::kF8E4M3FNUZ: + return HIP_R_8F_E4M3_FNUZ; +#else + case blas::DataType::kF8E5M2FNUZ: + case blas::DataType::kF8E4M3FNUZ: + LOG(FATAL) << "hipblaslt only supports F8 in ROCm 6.0 and above"; +#endif case blas::DataType::kHalf: return HIP_R_16F; case blas::DataType::kBF16: diff --git a/xla/stream_executor/rocm/hip_blas_utils.h b/xla/stream_executor/rocm/hip_blas_utils.h index 726386a1bb6f2..267e2a31050ce 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.h +++ b/xla/stream_executor/rocm/hip_blas_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/rocm/hipblaslt_wrapper.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #if TF_HIPBLASLT @@ -48,7 +48,7 @@ namespace rocm { #define SE_HIPBLAS_RETURN_IF_ERROR(expr) \ TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr)) -tsl::Status ToStatus(hipblasStatus_t status, const char* prefix); +absl::Status ToStatus(hipblasStatus_t status, const char* prefix); hipDataType AsHipblasDataType(blas::DataType type); hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type); hipblasOperation_t AsHipblasOperation(blas::Transpose trans); diff --git a/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc b/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc index 654b9f02c4879..d88ae44f240d3 100644 --- a/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc +++ b/xla/stream_executor/rocm/hip_conditional_kernels.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,33 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include -namespace stream_executor { -namespace rocm { -namespace { +namespace stream_executor::gpu { -__global__ void SetCondition() {} +std::string_view GetSetIfConditionKernel() { return ""; } +std::string_view GetSetIfElseConditionKernel() { return ""; } +std::string_view GetSetCaseConditionKernel() { return ""; } +std::string_view GetSetForConditionKernel() { return ""; } +std::string_view GetSetWhileConditionKernel() { return ""; } -} // namespace -} // namespace rocm - -namespace gpu { -void* GetSetIfConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetIfElseConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetCaseConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetForConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -void* GetSetWhileConditionKernel() { - return reinterpret_cast(&rocm::SetCondition); -} -} // namespace gpu - -} // namespace stream_executor +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/hipblaslt_wrapper.h b/xla/stream_executor/rocm/hipblaslt_wrapper.h index 326280b9ada9f..c53cff6a93391 100644 --- a/xla/stream_executor/rocm/hipblaslt_wrapper.h +++ b/xla/stream_executor/rocm/hipblaslt_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/xla/stream_executor/rocm/hipsolver_wrapper.h b/xla/stream_executor/rocm/hipsolver_wrapper.h index 67a4d48dd03c4..8434ae03b9668 100644 --- a/xla/stream_executor/rocm/hipsolver_wrapper.h +++ b/xla/stream_executor/rocm/hipsolver_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/rocm/hipsparse_wrapper.h b/xla/stream_executor/rocm/hipsparse_wrapper.h index d38bda28476ae..b4bcc7d8f3944 100644 --- a/xla/stream_executor/rocm/hipsparse_wrapper.h +++ b/xla/stream_executor/rocm/hipsparse_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,8 +46,8 @@ namespace wrap { #else #define HIPSPARSE_API_WRAPPER(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ + static struct DynLoadShim__##__name { \ + constexpr static const char* kName = #__name; \ using FuncPtrT = std::add_pointer::type; \ static void* GetDsoHandle() { \ auto s = \ @@ -56,8 +56,8 @@ namespace wrap { } \ static FuncPtrT LoadOrDie() { \ void* f; \ - auto s = tsl::Env::Default() \ - -> GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ CHECK(s.ok()) << "could not find " << kName \ << " in miopen DSO; dlerror: " << s.message(); \ return reinterpret_cast(f); \ @@ -70,8 +70,7 @@ namespace wrap { hipsparseStatus_t operator()(Args... args) { \ return DynLoad()(args...); \ } \ - } __name; \ - const char* DynLoadShim__##__name::kName = #__name; + } __name; #endif @@ -128,7 +127,7 @@ namespace wrap { __macro(hipsparseDcsru2csr_bufferSizeExt) \ __macro(hipsparseDcsru2csr) \ __macro(hipsparseScsru2csr_bufferSizeExt) \ - __macro(hipsparseScsru2csr) \ + __macro(hipsparseScsru2csr) \ __macro(hipsparseSpMM_bufferSize) \ __macro(hipsparseSpMM) \ __macro(hipsparseZcsru2csr_bufferSizeExt) \ diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 497a61d035fba..3d444ab83a0ee 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCBLAS_WRAPPER_H_ +// needed for rocblas_gemm_ex_get_solutions* functionality +#define ROCBLAS_BETA_FEATURES_API #include "rocm/include/rocblas/rocblas.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/platform/dso_loader.h" @@ -32,44 +34,42 @@ namespace wrap { using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; #ifdef PLATFORM_GOOGLE -#define ROCBLAS_API_WRAPPER(__name) \ - struct WrapperShim__##__name { \ - static const char* kName; \ - template \ - rocblas_status operator()(Args... args) { \ - return ::__name(args...); \ - } \ - } __name; \ - const char* WrapperShim__##__name::kName = #__name; +#define ROCBLAS_API_WRAPPER(__name) \ + struct WrapperShim__##__name { \ + constexpr static const char* kName = #__name; \ + template \ + rocblas_status operator()(Args... args) { \ + return ::__name(args...); \ + } \ + } __name; #else -#define ROCBLAS_API_WRAPPER(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = GetRocblasDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default() \ - -> GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocblas DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - rocblas_status operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; \ - const char* DynLoadShim__##__name::kName = #__name; +#define ROCBLAS_API_WRAPPER(__name) \ + static struct DynLoadShim__##__name { \ + constexpr static const char* kName = #__name; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = GetRocblasDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocblas DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + rocblas_status operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ + } __name; #endif @@ -257,6 +257,11 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_zgemm_strided_batched) \ __macro(rocblas_gemm_ex) \ __macro(rocblas_gemm_strided_batched_ex) \ + __macro(rocblas_gemm_ex_get_solutions) \ + __macro(rocblas_gemm_ex_get_solutions_by_type) \ + __macro(rocblas_gemm_batched_ex_get_solutions) \ + __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ + __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ __macro(rocblas_strsm_batched) \ __macro(rocblas_dtrsm_batched) \ __macro(rocblas_ctrsm_batched) \ diff --git a/xla/stream_executor/rocm/rocm_activation.h b/xla/stream_executor/rocm/rocm_activation.h index 9b1b58064b3a5..31a19ae00e116 100644 --- a/xla/stream_executor/rocm/rocm_activation.h +++ b/xla/stream_executor/rocm/rocm_activation.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ limitations under the License. // It reaches into the ROCM implementation to activate an underlying ROCM // context. // -// Having this file separate from rocm/rocm_gpu_executor.h means that dependent +// Having this file separate from rocm/rocm_executor.h means that dependent // code does not also have to depend on rocm.h. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_ACTIVATION_H_ diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index c9c6bebcda6c3..c1915c7d7b835 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,8 +40,8 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/determinism.h" #include "tsl/platform/logging.h" -#include "tsl/util/determinism.h" using tsl::OpDeterminismRequired; namespace stream_executor { @@ -52,48 +52,56 @@ extern void rocm_Broadcast_fp32(void *stream, float *dst, int dst_stride, int size); template -const typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - const DeviceMemory &a) { - return reinterpret_cast< - const typename RocBlasTypeConversionHelper::mapped_type *>( - GpuMemory(a)); +const RocBlasType_t *const *complex_cast(const DeviceMemory &a) { + return reinterpret_cast *const *>(GpuMemory(a)); } template -const typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - const T &a) { - return reinterpret_cast< - const typename RocBlasTypeConversionHelper::mapped_type *>(&a); +RocBlasType_t *const *complex_cast(DeviceMemory &a) { + return reinterpret_cast *const *>(GpuMemory(a)); } + template -typename RocBlasTypeConversionHelper::mapped_type *complex_cast( - DeviceMemory *a) { - return reinterpret_cast< - typename RocBlasTypeConversionHelper::mapped_type *>( - GpuMemoryMutable(a)); +const RocBlasType_t *complex_cast(const DeviceMemory &a) { + return reinterpret_cast *>(GpuMemory(a)); } -static void blas_log(const char *c) {} +template +const RocBlasType_t *complex_cast(const T &a) { + return reinterpret_cast *>(&a); +} +template +RocBlasType_t *complex_cast(DeviceMemory *a) { + return reinterpret_cast *>(GpuMemoryMutable(a)); +} static string ToString(rocblas_status status) { +#define XVAL(x) \ + case x: \ + return #x switch (status) { - case rocblas_status_success: - return "rocblas_status_success"; - case rocblas_status_invalid_handle: - return "rocblas_status_invalid_handle"; - case rocblas_status_not_implemented: - return "rocblas_status_not_implemented"; - case rocblas_status_invalid_pointer: - return "rocblas_status_invalid_pointer"; - case rocblas_status_invalid_size: - return "rocblas_status_invalid_size"; - case rocblas_status_memory_error: - return "rocblas_status_memory_error"; - case rocblas_status_internal_error: - return "rocblas_status_internal_error"; + XVAL(rocblas_status_success); + XVAL(rocblas_status_invalid_handle); + XVAL(rocblas_status_not_implemented); + XVAL(rocblas_status_invalid_pointer); + XVAL(rocblas_status_invalid_size); + XVAL(rocblas_status_memory_error); + XVAL(rocblas_status_internal_error); +#if TF_ROCM_VERSION >= 60000 + XVAL(rocblas_status_perf_degraded); + XVAL(rocblas_status_size_query_mismatch); + XVAL(rocblas_status_size_increased); + XVAL(rocblas_status_size_unchanged); + XVAL(rocblas_status_invalid_value); + XVAL(rocblas_status_continue); + XVAL(rocblas_status_check_numerics_fail); + XVAL(rocblas_status_excluded_from_build); + XVAL(rocblas_status_arch_mismatch); +#endif default: return absl::StrCat(""); } +#undef XVAL } bool ROCMBlas::Init() { @@ -110,6 +118,17 @@ bool ROCMBlas::Init() { return false; } #endif + + int dev = 0; + hipError_t result = hipGetDevice(&dev); + hipDeviceProp_t props; + result = hipGetDeviceProperties(&props, dev); + if (result == hipSuccess) { + auto cap = RocmComputeCapability(props.gcnArchName); + has_mfma_ = cap.has_mfma_instr_support(); + use_hgemm_alt_impl_ = (cap.gfx_version() == "gfx90a"); + } + return true; } @@ -203,17 +222,113 @@ rocblas_side ROCMBlasSide(blas::Side side) { } } +absl::StatusOr AsRocBlasType(blas::DataType type) { + switch (type) { + case blas::DataType::kHalf: + return rocblas_datatype_f16_r; + case blas::DataType::kBF16: + return rocblas_datatype_bf16_r; + case blas::DataType::kFloat: + return rocblas_datatype_f32_r; + case blas::DataType::kDouble: + return rocblas_datatype_f64_r; + case blas::DataType::kInt8: + return rocblas_datatype_i8_r; + case blas::DataType::kInt32: + return rocblas_datatype_i32_r; + case blas::DataType::kComplexFloat: + return rocblas_datatype_f32_c; + case blas::DataType::kComplexDouble: + return rocblas_datatype_f64_c; + default: + return absl::InternalError( + absl::StrFormat("Unsupported blas data type: %d", (int)type)); + } +} + +absl::StatusOr AsRocBlasComputeType( + blas::ComputationType type) { + switch (type) { + case blas::ComputationType::kF16: + return rocblas_datatype_f16_r; + case blas::ComputationType::kF32: + return rocblas_datatype_f32_r; + case blas::ComputationType::kF64: + return rocblas_datatype_f64_r; + case blas::ComputationType::kI32: + return rocblas_datatype_i32_r; + case blas::ComputationType::kF16AsF32: + case blas::ComputationType::kBF16AsF32: + case blas::ComputationType::kTF32AsF32: + default: + return absl::InternalError( + absl::StrFormat("Unsupported compute type: %d", (int)type)); + } +} + +void CheckPreconditions(blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64_t n, uint64_t k, + blas::DataType dtype, int lda, int ldb) { + if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) { + if (transa == blas::Transpose::kNoTranspose) { + if (lda < static_cast(m)) { + LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " + "precondition violation"; + } + } else { + if (lda < static_cast(k)) { + LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k + << ") (transpose case); precondition violation"; + } + } + if (transb == blas::Transpose::kNoTranspose) { + if (ldb < static_cast(k)) { + LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k + << ") (no transpose case); precondition violation"; + } + } else { + if (ldb < static_cast(n)) { + LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " + "precondition violation"; + } + } + } +} + +uint32_t GemmFloat16Flags(blas::DataType dtype, blas::CallContext context, + bool use_alt_impl) { + bool is_backprop = (context == blas::CallContext::kBackpropInput1 || + context == blas::CallContext::kBackpropInput2); + + return ((dtype == blas::DataType::kHalf) && is_backprop && use_alt_impl) + ? rocblas_gemm_flags_fp16_alt_impl + : rocblas_gemm_flags_none; +} + +absl::Status PopulateProfileFromTimer( + std::optional &timer, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + if (output_profile_result) { + TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); + output_profile_result->set_is_valid(true); + output_profile_result->set_algorithm(algorithm); + output_profile_result->set_elapsed_time_in_ms( + absl::ToDoubleMilliseconds(duration)); + } + return absl::OkStatus(); +} + } // namespace template -bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, bool err_on_failure, - Args... args) { +absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, + bool err_on_failure, Args &&...args) { absl::MutexLock lock{&mu_}; CHECK(blas_ != nullptr); if (!SetStream(stream)) { - return false; + return absl::InternalError("Setting stream failed"); } gpu::ScopedActivateExecutorContext sac{parent_}; @@ -224,55 +339,31 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, if (!allow_atomics) { ret = wrap::rocblas_set_atomics_mode(blas_, rocblas_atomics_not_allowed); if (err_on_failure && ret != rocblas_status_success) { - LOG(ERROR) << "failed to to set atomics mode before " - << rocblas_func.kName << ": " << ToString(ret); + LOG(ERROR) << "failed to to set atomics mode before " << FuncT::kName + << ": " << ToString(ret); } } - ret = rocblas_func(blas_, args...); - if (err_on_failure && ret != rocblas_status_success) { - LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " - << ToString(ret); + ret = rocblas_func(blas_, std::forward(args)...); + if (ret != rocblas_status_success) { + auto err_str = + absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret)); + if (err_on_failure) { + LOG(ERROR) << err_str; + } + return absl::InternalError(err_str); } - return ret == rocblas_status_success; + return absl::OkStatus(); } bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - blas_log("DoBlasAxpy"); return DoBlasInternal(wrap::rocblas_saxpy, stream, /* pointer_mode_host = */ true, elem_count, &alpha, GpuMemory(x), incx, GpuMemoryMutable(y), incy); } -bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - blas_log("DoBlasAxpy"); - return DoBlasInternal(wrap::rocblas_daxpy, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - GpuMemory(x), incx, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal( - wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy); -} - -bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, - std::complex alpha, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal( - wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy); -} - bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { @@ -281,310 +372,355 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, GpuMemory(x), incx, GpuMemoryMutable(y), incy); } -bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - return DoBlasInternal(wrap::rocblas_dcopy, stream, - /* pointer_mode_host = */ true, elem_count, - GpuMemory(x), incx, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal(wrap::rocblas_ccopy, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(x), incx, complex_cast(y), incy); -} - -bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count, - const DeviceMemory> &x, int incx, - DeviceMemory> *y, int incy) { - return DoBlasInternal(wrap::rocblas_zcopy, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(x), incx, complex_cast(y), incy); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, - DeviceMemory *x, int incx) { - blas_log("DoBlasScal"); - return DoBlasInternal(wrap::rocblas_sscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - GpuMemoryMutable(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, - DeviceMemory *x, int incx) { - return DoBlasInternal(wrap::rocblas_dscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - GpuMemoryMutable(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_csscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_zdscal, stream, - /* pointer_mode_host = */ true, elem_count, &alpha, - complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, - std::complex alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_cscal, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, - std::complex alpha, - DeviceMemory> *x, int incx) { - return DoBlasInternal(wrap::rocblas_zscal, stream, - /* pointer_mode_host = */ true, elem_count, - complex_cast(alpha), complex_cast(x), incx); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, float alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - float beta, DeviceMemory *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), - incx, &beta, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), - incx, &beta, GpuMemoryMutable(y), incy); -} - -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy) { - blas_log("DoBlasGemv"); - return DoBlasInternal( - wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, - complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); -} +#define Impl_DoBlasScal(Fun, T, Ta) \ + bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, Ta alpha, \ + DeviceMemory *x, int incx) { \ + return DoBlasInternal(Fun, stream, /* pointer_mode_host = */ true, \ + elem_count, complex_cast(alpha), complex_cast(x), \ + incx); \ + } -bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy) { - blas_log("DoBlasGemv\n"); - return DoBlasInternal( - wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, - complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); -} +Impl_DoBlasScal(wrap::rocblas_sscal, float, + float) Impl_DoBlasScal(wrap::rocblas_dscal, double, double) + Impl_DoBlasScal(wrap::rocblas_csscal, std::complex, float) + Impl_DoBlasScal(wrap::rocblas_zdscal, std::complex, double) + Impl_DoBlasScal(wrap::rocblas_cscal, std::complex, + std::complex) + Impl_DoBlasScal(wrap::rocblas_zscal, std::complex, + std::complex) +#define Impl_DoBlasGemv(fun, T) \ + bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, T alpha, const DeviceMemory &a, \ + int lda, const DeviceMemory &x, int incx, \ + T beta, DeviceMemory *y, int incy) { \ + return DoBlasInternal(fun, stream, /* pointer_mode_host = */ true, \ + ROCMBlasTranspose(trans), m, n, complex_cast(alpha), \ + complex_cast(a), lda, complex_cast(x), incx, \ + complex_cast(beta), complex_cast(y), incy); \ + } -bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, - uint64_t k, float alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - float beta, DeviceMemory *y, int incy) { + Impl_DoBlasGemv(wrap::rocblas_sgemv, float) + Impl_DoBlasGemv(wrap::rocblas_dgemv, double) + Impl_DoBlasGemv(wrap::rocblas_cgemv, + std::complex) + Impl_DoBlasGemv(wrap::rocblas_zgemv, + std::complex) + + bool ROCMBlas::DoBlasSbmv( + Stream *stream, blas::UpperLower uplo, + uint64_t n, uint64_t k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &x, int incx, + float beta, DeviceMemory *y, + int incy) { return DoBlasInternal( wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true, ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy); } -bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, - uint64_t k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy) { - return DoBlasInternal( - wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true, - ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x), - incx, &beta, GpuMemoryMutable(y), incy); -} - -tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, blas::DataType dtype, - const void *alpha, const DeviceMemoryBase &a, - int lda, const DeviceMemoryBase &b, int ldb, - const void *beta, DeviceMemoryBase *c, int ldc, - const NumericOptions &numeric_options, - blas::CallContext context) { - blas_log("DoBlasGemm"); +/** + * + * ALPHA/BETA TYPES + * + * For half and bf16, alpha and beta point to floats. + * For all other types, alpha and beta point to values of the same type as + *a/b/c. + * + * On the rocblas side, non-ex functions expect the same type as a/b/c + * (this seems to be a deviation from the blas standard); + * and ex functions expect the same type as the compute type (i.e. floats.) + * + **/ + +absl::Status ROCMBlas::DoBlasGemm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, + uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha, + const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, + const void *beta, DeviceMemoryBase *c, int ldc, + const NumericOptions &numeric_options, blas::CallContext context) { VLOG(1) << absl::StreamFormat( "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " "c=%p ldc=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); - if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) { - if (transa == blas::Transpose::kNoTranspose) { - if (lda < static_cast(m)) { - LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " - "precondition violation"; - } - } else { - if (lda < static_cast(k)) { - LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k - << ") (transpose case); precondition violation"; - } - } - if (transb == blas::Transpose::kNoTranspose) { - if (ldb < static_cast(k)) { - LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k - << ") (no transpose case); precondition violation"; - } - } else { - if (ldb < static_cast(n)) { - LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " - "precondition violation"; - } - } + + CheckPreconditions(transa, transb, m, n, k, dtype, lda, ldb); + + absl::Status status; + uint32_t gemm_ex_flags = rocblas_gemm_flags_none; + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + if (is_backprop && use_hgemm_alt_impl_) + gemm_ex_flags = rocblas_gemm_flags_fp16_alt_impl; + + Eigen::half alpha_half, beta_half; + + const void *alpha_downcast = alpha, *beta_downcast = beta; + if (dtype == blas::DataType::kHalf) { + alpha_half = Eigen::half(*static_cast(alpha)); + beta_half = Eigen::half(*static_cast(beta)); + alpha_downcast = &alpha_half; + beta_downcast = &beta_half; } + /* I would like to specify the type with a template parameter: + * + * auto call_gemm = [&](auto func) { ... } + * ... + * status = call_gemm(wrap::rocblas_sgemm); + * + * but that's a C++20 extension and can't be enabled (the compiler does + * support it, but enabling it causes compilation errors inside Eigen.) */ + auto call_gemm = [&](auto func, auto type) { + return DoBlasInternalStatus( + func, stream, /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), m, n, k, + reinterpret_cast(alpha_downcast), + reinterpret_cast(a.opaque()), lda, + reinterpret_cast(b.opaque()), ldb, + reinterpret_cast(beta_downcast), + reinterpret_cast(c->opaque()), ldc); + }; + + auto call_gemm_ex = [&](rocblas_datatype dt) { + return DoBlasInternalStatus( + wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m, + (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), dt, lda, b.opaque(), + dt, ldb, beta, c->opaque(), dt, ldc, c->opaque(), dt, ldc, + rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, gemm_ex_flags); + }; + switch (dtype) { - case blas::DataType::kHalf: { - tsl::StatusOr maybe_hasXDLOPS = GpuDriver::GetMFMASupport(); - if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.value()) { - VLOG(1) << "Using rocblas_gemm_ex"; - bool is_backprop = (context == blas::CallContext::kBackpropInput1) || - (context == blas::CallContext::kBackpropInput2); - - uint32_t flags = rocblas_gemm_flags_none; -#if TF_ROCM_VERSION >= 50000 - if (is_backprop) { - flags = rocblas_gemm_flags_fp16_alt_impl; - } -#endif - return DoBlasInternalStatus( - wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), - (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), - rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r, - ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(), - rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, flags); - } else { - VLOG(1) << "Using rocblas_hgemm"; - const Eigen::half alpha_half(*static_cast(alpha)); - const Eigen::half beta_half(*static_cast(beta)); - return DoBlasInternalStatus( - wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(&alpha_half), - reinterpret_cast(a.opaque()), lda, - reinterpret_cast(b.opaque()), ldb, - reinterpret_cast(&beta_half), - reinterpret_cast(c->opaque()), ldc); - } - } + case blas::DataType::kHalf: + if (has_mfma_) + return call_gemm_ex(rocblas_datatype_f16_r); + else + return call_gemm(wrap::rocblas_hgemm, rocblas_half()); case blas::DataType::kBF16: - return DoBlasInternalStatus( - wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m, - (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(), - rocblas_datatype_bf16_r, lda, b.opaque(), rocblas_datatype_bf16_r, - ldb, beta, c->opaque(), rocblas_datatype_bf16_r, ldc, c->opaque(), - rocblas_datatype_bf16_r, ldc, rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, 0, 0); + return call_gemm_ex(rocblas_datatype_bf16_r); case blas::DataType::kFloat: - return DoBlasInternalStatus( - wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - static_cast(alpha), - static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, - static_cast(beta), static_cast(c->opaque()), - ldc); + return call_gemm(wrap::rocblas_sgemm, 1.0f); case blas::DataType::kDouble: - return DoBlasInternalStatus( - wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - static_cast(alpha), - static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, - static_cast(beta), static_cast(c->opaque()), - ldc); - case blas::DataType::kComplexFloat: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - cb_alpha, static_cast(a.opaque()), lda, - static_cast(b.opaque()), ldb, cb_beta, - static_cast(c->opaque()), ldc); - } - case blas::DataType::kComplexDouble: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - cb_alpha, static_cast(a.opaque()), - lda, static_cast(b.opaque()), ldb, - cb_beta, static_cast(c->opaque()), ldc); - } + return call_gemm(wrap::rocblas_dgemm, 1.0); + case blas::DataType::kComplexFloat: + return call_gemm(wrap::rocblas_cgemm, rocblas_float_complex()); + case blas::DataType::kComplexDouble: + return call_gemm(wrap::rocblas_zgemm, rocblas_double_complex()); default: - return tsl::errors::Internal("Unsupported datatype for GEMM: ", - blas::DataTypeString(dtype)); + return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", + blas::DataTypeString(dtype))); } } -tsl::Status ROCMBlas::DoBlasGemmWithAlgorithm( +absl::Status ROCMBlas::DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return tsl::errors::Internal("DoBlasGemmWithAlgorithm ", - "is not implemented on ROCm yet"); + blas::ProfileResult *profile_result, blas::CallContext context) { + if (type_a != type_b) { + return absl::InternalError(absl::StrFormat( + "DoBlasGemmWithAlgorithm: different " + "datatypes for the inputs a (%d) and b (%d) are unsupported", + static_cast(type_a), static_cast(type_b))); + } + TF_ASSIGN_OR_RETURN( + auto timer, + GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + profile_result != nullptr)); + + // fall back to the default implementation + if (algorithm == blas::kDefaultAlgorithm && type_a == type_c) { + TF_RETURN_IF_ERROR(DoBlasGemm(stream, transa, transb, m, n, k, type_a, + alpha, a, lda, b, ldb, beta, c, ldc, + numeric_options, context)); + + } else { + CheckPreconditions(transa, transb, m, n, k, type_a, lda, ldb); + TF_ASSIGN_OR_RETURN(auto roc_type_a, AsRocBlasType(type_a)); + TF_ASSIGN_OR_RETURN(auto roc_type_c, AsRocBlasType(type_c)); + TF_ASSIGN_OR_RETURN(auto roc_comp_type, + AsRocBlasComputeType(computation_type)); + + VLOG(1) << absl::StreamFormat( + "doing rocBLAS GEMM with Algorithm: at=%d bt=%d m=%u n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d algorithm=%d type_a/b=%d type_c=%d comp_type=%d", + static_cast(transa), static_cast(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, algorithm, + static_cast(roc_type_a), static_cast(roc_type_c), + static_cast(roc_comp_type)); + + TF_RETURN_IF_ERROR(DoBlasInternalImpl( + wrap::rocblas_gemm_ex, stream, + /* pointer_mode_host = */ true, + /* error_on_failure = */ false, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), (rocblas_int)m, (rocblas_int)n, + (rocblas_int)k, alpha, a.opaque(), roc_type_a, lda, b.opaque(), + roc_type_a, ldb, beta, c->opaque(), roc_type_c, ldc, c->opaque(), + roc_type_c, ldc, roc_comp_type, rocblas_gemm_algo_solution_index, + algorithm, GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); + } + TF_RETURN_IF_ERROR( + PopulateProfileFromTimer(timer, algorithm, profile_result)); + + return absl::OkStatus(); } -tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( +absl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - // ROCM TODO: properly implement the interface - return tsl::errors::Internal("DoBlasGemmStridedBatchedWithAlgorithm ", - "is not implemented on ROCm yet"); + blas::ProfileResult *profile_result, blas::CallContext context) { + if (type_a != type_b) { + return absl::InternalError(absl::StrFormat( + "DoBlasGemmStridedBatchedWithAlgorithm: different " + "datatypes for the inputs a (%d) and b (%d) are unsupported", + static_cast(type_a), static_cast(type_b))); + } + TF_ASSIGN_OR_RETURN( + auto timer, + GpuTimer::CreateIfNeeded( + stream, profile_result && profile_result->warmup_run_executed(), + profile_result != nullptr)); + + // fall back to the default implementation + if (algorithm == blas::kDefaultAlgorithm && type_a == type_c) { + TF_RETURN_IF_ERROR(DoBlasGemmStridedBatched( + stream, transa, transb, m, n, k, type_a, alpha, a, lda, stride_a, b, + ldb, stride_b, beta, c, ldc, stride_c, batch_count, numeric_options, + context)); + } else { + VLOG(1) << absl::StreamFormat( + "doing rocBLAS GEMM strided batched with Algorithm: at=%d bt=%d m=%u " + "n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d algorithm=%d type_a/b=%d type_c=%d stride_a/b/c=%d/%d/%d " + "batch_count=%d", + static_cast(transa), static_cast(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, algorithm, + static_cast(type_a), static_cast(type_c), stride_a, stride_b, + stride_c, batch_count); + + TF_ASSIGN_OR_RETURN(auto roc_type_a, AsRocBlasType(type_a)); + TF_ASSIGN_OR_RETURN(auto roc_type_c, AsRocBlasType(type_c)); + TF_ASSIGN_OR_RETURN(auto roc_comp_type, + AsRocBlasComputeType(computation_type)); + + TF_RETURN_IF_ERROR(DoBlasInternalImpl( + wrap::rocblas_gemm_strided_batched_ex, stream, + /* pointer_mode_host = */ true, + /* error_on_failure = */ false, ROCMBlasTranspose(transa), + ROCMBlasTranspose(transb), (rocblas_int)m, (rocblas_int)n, + (rocblas_int)k, alpha, a.opaque(), roc_type_a, lda, stride_a, + b.opaque(), roc_type_a, ldb, stride_b, beta, c->opaque(), roc_type_c, + ldc, stride_c, c->opaque(), roc_type_c, ldc, stride_c, batch_count, + roc_comp_type, rocblas_gemm_algo_solution_index, algorithm, + GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); + } + TF_RETURN_IF_ERROR( + PopulateProfileFromTimer(timer, algorithm, profile_result)); + + return absl::OkStatus(); } +template +struct NameWrap : Lambda { + using Lambda::operator(); + constexpr static const char *kName = "rocblas_gemm_ex_get_solutions"; +}; +template +NameWrap(Func) -> NameWrap; + +#define ASSIGN_OR_FALSE(lhs, rexpr) \ + result = (rexpr); \ + if (TF_PREDICT_FALSE(!result.ok())) return false; \ + lhs = std::move(result).value() + bool ROCMBlas::GetBlasGemmAlgorithms( - Stream *stream, std::vector *out_algorithms) { - // ROCM TODO: properly implement the interface - return true; + Stream *stream, const gpu::MatrixDescriptor &a, + const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, + const void *alpha, const void *beta, + std::vector *out_algorithms) { + out_algorithms->clear(); + auto blas_lambda = [this, out_algorithms](auto handle, auto &&blas_func, + auto &&...rest) { + rocblas_int num_sols = 0; + + if (auto ret = blas_func(handle, std::forward(rest)..., + nullptr, &num_sols); + ret != rocblas_status_success) { + return ret; + } + solutions_.resize(num_sols); + if (auto ret = blas_func(handle, std::forward(rest)..., + solutions_.data(), &num_sols); + ret != rocblas_status_success) { + return ret; + } + out_algorithms->resize(num_sols); + for (rocblas_int i = 0; i < num_sols; i++) { + (*out_algorithms)[i] = solutions_[i]; + } + return rocblas_status_success; + }; + + VLOG(1) << absl::StreamFormat( + "GetBlasAlgorithms: at=%d bt=%d m=%u n=%u " + "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " + "c=%p ldc=%d type_a/b=%d type_c=%d stride_a/b/c=%d/%d/%d " + "batch_count=%d", + static_cast(a.transpose), static_cast(b.transpose), c->m, c->n, + c->k, alpha, a.data.opaque(), a.leading_dim_stride, b.data.opaque(), + b.leading_dim_stride, beta, c->data.opaque(), c->leading_dim_stride, + static_cast(a.type), static_cast(c->type), a.batch_stride, + b.batch_stride, c->batch_stride, c->batch_size); + + if (a.type != b.type) { + LOG(ERROR) << "Gemm arguments types differ: no feasible solutions!"; + return false; + } + absl::StatusOr result; + ASSIGN_OR_FALSE(auto roc_type_a, AsRocBlasType(a.type)); + ASSIGN_OR_FALSE(auto roc_type_c, AsRocBlasType(c->type)); + ASSIGN_OR_FALSE(auto roc_comp_type, AsRocBlasComputeType(c->compute_type)); + + if (c->batch_size == 1) { + // TODO: we should possibly use GemmFloat16Flags(type_a, context) here.. + return DoBlasInternalFailureOK( + NameWrap{blas_lambda}, stream, true, + wrap::rocblas_gemm_ex_get_solutions, ROCMBlasTranspose(a.transpose), + ROCMBlasTranspose(b.transpose), c->m, c->n, c->k, alpha, + a.data.opaque(), roc_type_a, a.leading_dim_stride, b.data.opaque(), + roc_type_a, b.leading_dim_stride, beta, c->data.opaque(), roc_type_c, + c->leading_dim_stride, c->data.opaque(), roc_type_c, + c->leading_dim_stride, roc_comp_type, rocblas_gemm_algo_solution_index, + 0); + } + return DoBlasInternalFailureOK( + NameWrap{blas_lambda}, stream, true, + wrap::rocblas_gemm_strided_batched_ex_get_solutions, + ROCMBlasTranspose(a.transpose), ROCMBlasTranspose(b.transpose), c->m, + c->n, c->k, alpha, a.data.opaque(), roc_type_a, a.leading_dim_stride, + a.batch_stride, b.data.opaque(), roc_type_a, b.leading_dim_stride, + b.batch_stride, beta, c->data.opaque(), roc_type_c, c->leading_dim_stride, + c->batch_stride, c->data.opaque(), roc_type_c, c->leading_dim_stride, + c->batch_stride, c->batch_size, roc_comp_type, + rocblas_gemm_algo_solution_index, 0); } +#undef ASSIGN_OR_FALSE + +namespace { struct MemoryCopyOp { char *src_ptr; @@ -597,7 +733,7 @@ struct MemoryCopyOp { // Check whether two Memory Copy Ops can be fold together. // If it's true, fold it. Otherwise, return false. -static bool MemCopyOpsFold(MemoryCopyOp &y, const MemoryCopyOp &x) { +bool MemCopyOpsFold(MemoryCopyOp &y, const MemoryCopyOp &x) { bool misaligned = (x.size & 3) || (reinterpret_cast(x.dst_ptr) & 3) || (reinterpret_cast(x.src_ptr) & 3) || @@ -635,14 +771,13 @@ static bool MemCopyOpsFold(MemoryCopyOp &y, const MemoryCopyOp &x) { // The below algorithm tries to minimize the number of memcpy by consolidating // neighboring memcpy into a single request. template -tsl::Status ReorganizeMemory(Stream *stream, - DeviceMemory *device_memory, - const std::vector &raw_ptrs, - int batch_count, uint64_t batch_stride, - bool gather) { +absl::Status ReorganizeMemory(Stream *stream, + DeviceMemory *device_memory, + const std::vector &raw_ptrs, + int batch_count, uint64_t batch_stride, + bool gather) { if (gather == false) { - return tsl::Status(absl::StatusCode::kUnimplemented, - "gather=false is unsupported"); + return absl::UnimplementedError("gather=false is unsupported"); } assert(batch_count > 0); @@ -687,32 +822,29 @@ tsl::Status ReorganizeMemory(Stream *stream, } else { DeviceMemoryBase src_mem = DeviceMemoryBase(x.src_ptr, x.size); DeviceMemoryBase target_mem = DeviceMemoryBase(x.dst_ptr, x.size); - bool a_status = stream->ThenMemcpy(&target_mem, src_mem, x.size).ok(); - if (!a_status) { - return tsl::Status( - absl::StatusCode::kInternal, - "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); - } + TF_RETURN_IF_ERROR(stream->Memcpy(&target_mem, src_mem, x.size)); } i++; } - return tsl::OkStatus(); + return absl::OkStatus(); } template -tsl::Status ROCMBlas::AllocateStridedBuffer( - const std::vector::mapped_type *> - &raw_ptrs, - int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator, - Stream *stream, - std::unique_ptr::mapped_type>> *temp_memory, - DeviceMemory::mapped_type> - *device_memory, - bool copy_data, bool &reallocated) { - assert(device_memory != nullptr); - - using MAPPED_T = typename RocBlasTypeConversionHelper::mapped_type; +struct AllocateStridedResult { + using Type = RocBlasType_t; + DeviceMemory device_mem; + bool reallocated; +}; + +// A helper allocation function to convert raw pointers memory layout to +// strided flavor +template +absl::StatusOr> AllocateStridedBuffer( + const std::vector *> &raw_ptrs, int batch_count, + uint64_t batch_stride, ScratchAllocator *scratch_allocator, Stream *stream, + bool copy_data) { + using MAPPED_T = RocBlasType_t; + AllocateStridedResult res; bool needs_allocate_strided = false; for (int i = 1; i < batch_count; ++i) { @@ -728,42 +860,37 @@ tsl::Status ROCMBlas::AllocateStridedBuffer( // No need to do re-allocation, take the short cut and return if (!needs_allocate_strided) { - *device_memory = DeviceMemory( + res.device_mem = DeviceMemory( DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size)); - reallocated = false; - return tsl::OkStatus(); + res.reallocated = false; + return res; } - if (scratch_allocator != nullptr) { - TF_ASSIGN_OR_RETURN( - DeviceMemory batch_matrix_bytes, - scratch_allocator->AllocateBytes(matrix_batch_byte_size)); - *device_memory = DeviceMemory(batch_matrix_bytes); - } else { - assert(temp_memory != nullptr); - TF_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray( - matrix_batch_byte_size)); - *device_memory = - DeviceMemory(*(*temp_memory)->mutable_device_memory()); + if (scratch_allocator == nullptr) { + return absl::InternalError("scratch_allocator is null"); } - - reallocated = true; - - if (copy_data) - return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count, - batch_stride, true); - return tsl::OkStatus(); + TF_ASSIGN_OR_RETURN(DeviceMemory batch_matrix_bytes, + scratch_allocator->AllocateBytes(matrix_batch_byte_size)); + res.device_mem = DeviceMemory(batch_matrix_bytes); + res.reallocated = true; + if (copy_data) { + TF_RETURN_IF_ERROR(ReorganizeMemory(stream, &res.device_mem, raw_ptrs, + batch_count, batch_stride, true)); + } + return res; } +} // namespace + template -tsl::Status ROCMBlas::DoBlasGemmBatchedInternal( +absl::Status ROCMBlas::DoBlasGemmBatchedInternal( FuncT rocblas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, T alpha, DeviceMemorySlice a_ptrs_to_wrappers, int lda, DeviceMemorySlice b_ptrs_to_wrappers, int ldb, T beta, DeviceMemorySlice c_ptrs_to_wrappers, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { - using MAPPED_T = typename RocBlasTypeConversionHelper::mapped_type; + using MAPPED_T = RocBlasType_t; // Sanity checks before making any further progress uint64_t batch_stride_a = 0; @@ -790,96 +917,137 @@ tsl::Status ROCMBlas::DoBlasGemmBatchedInternal( } // Allocate local vectors to hold device pointers to matrices - std::vector a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; + std::vector a_raw_ptrs(batch_count), b_raw_ptrs(batch_count), + c_raw_ptrs(batch_count); for (int i = 0; i < batch_count; ++i) { // static_cast does work when converting Eigen::half* to rocblas_half*, // hence the use of reinterpret_cast - a_raw_ptrs.push_back( - reinterpret_cast(a_ptrs_to_wrappers[i]->opaque())); - b_raw_ptrs.push_back( - reinterpret_cast(b_ptrs_to_wrappers[i]->opaque())); - c_raw_ptrs.push_back( - reinterpret_cast(c_ptrs_to_wrappers[i]->opaque())); + a_raw_ptrs[i] = + reinterpret_cast(a_ptrs_to_wrappers[i]->opaque()); + b_raw_ptrs[i] = + reinterpret_cast(b_ptrs_to_wrappers[i]->opaque()); + c_raw_ptrs[i] = + reinterpret_cast(c_ptrs_to_wrappers[i]->opaque()); } - DeviceMemory a; // Make sure the temporary memory are in-scope before the function returns - std::unique_ptr> a_temp; - bool reallocated_a, reallocated_b, reallocated_c; - tsl::Status a_allocation_status = AllocateStridedBuffer( - a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream, - &a_temp, &a, true, reallocated_a); - if (a_allocation_status != tsl::OkStatus()) { - return a_allocation_status; + TF_ASSIGN_OR_RETURN( + auto a, AllocateStridedBuffer(a_raw_ptrs, batch_count, batch_stride_a, + scratch_allocator, stream, true)); + + TF_ASSIGN_OR_RETURN( + auto b, AllocateStridedBuffer(b_raw_ptrs, batch_count, batch_stride_b, + scratch_allocator, stream, true)); + + TF_ASSIGN_OR_RETURN( + auto c, AllocateStridedBuffer(c_raw_ptrs, batch_count, batch_stride_c, + scratch_allocator, stream, + true)); // can disable copy if beta=0 + + MAPPED_T *alpha_ptr = reinterpret_cast(&alpha); + MAPPED_T *beta_ptr = reinterpret_cast(&beta); + bool ok = DoBlasInternal( + rocblas_func, stream, /* pointer_mode_host = */ true, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, + GpuComplex(alpha_ptr), GpuMemory(a.device_mem), lda, batch_stride_a, + GpuMemory(b.device_mem), ldb, batch_stride_b, GpuComplex(beta_ptr), + GpuMemoryMutable(&c.device_mem), ldc, batch_stride_c, batch_count); + + if (!ok) { + return absl::Status(absl::StatusCode::kInternal, + "failed BLAS call, see log for details"); } - - DeviceMemory b; - std::unique_ptr> b_temp; - tsl::Status b_allocation_status = AllocateStridedBuffer( - b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream, - &b_temp, &b, true, reallocated_b); - if (b_allocation_status != tsl::OkStatus()) { - return b_allocation_status; + if (c.reallocated) { + return ReorganizeMemory(stream, &c.device_mem, c_raw_ptrs, batch_count, + batch_stride_c, false); } + return absl::OkStatus(); +} - DeviceMemory c; - std::unique_ptr> c_temp; - tsl::Status c_allocation_status = AllocateStridedBuffer( - c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream, - &c_temp, &c, true, reallocated_c); // can disable copy if beta=0 - if (c_allocation_status != tsl::OkStatus()) { - return c_allocation_status; +class rocblas_hgemm_strided_batched_mfma { + int ALT_; + + public: + rocblas_hgemm_strided_batched_mfma(int ALT) : ALT_(ALT) {} + static const char *kName; + rocblas_status operator()(rocblas_handle handle, rocblas_operation transA, + rocblas_operation transB, rocblas_int m, + rocblas_int n, rocblas_int k, + const rocblas_half *alpha, const rocblas_half *A, + rocblas_int lda, rocblas_stride stride_a, + const rocblas_half *B, rocblas_int ldb, + rocblas_stride stride_b, const rocblas_half *beta, + rocblas_half *C, rocblas_int ldc, + rocblas_stride stride_c, rocblas_int batch_count) { + float alpha32 = static_cast(*(const __half *)alpha); + float beta32 = static_cast(*(const __half *)beta); + uint32_t flags = rocblas_gemm_flags_none; + if (ALT_) flags = rocblas_gemm_flags_fp16_alt_impl; + return wrap::rocblas_gemm_strided_batched_ex( + handle, transA, transB, m, n, k, &alpha32, A, rocblas_datatype_f16_r, + lda, stride_a, B, rocblas_datatype_f16_r, ldb, stride_b, &beta32, C, + rocblas_datatype_f16_r, ldc, stride_c, C, rocblas_datatype_f16_r, ldc, + stride_c, batch_count, rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, flags); } +}; - bool ok; - if constexpr (std::is_same_v) { - float alpha_ = static_cast(alpha); - float beta_ = static_cast(beta); - const void *alpha_ptr = reinterpret_cast(&alpha_); - const void *beta_ptr = reinterpret_cast(&beta_); - - ok = DoBlasInternal( - rocblas_func, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - alpha_ptr, a.opaque(), rocblas_datatype_bf16_r, lda, batch_stride_a, - b.opaque(), rocblas_datatype_bf16_r, ldb, batch_stride_b, beta_ptr, - c.opaque(), rocblas_datatype_bf16_r, ldc, batch_stride_c, c.opaque(), - rocblas_datatype_bf16_r, ldc, batch_stride_c, batch_count, - rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0); - } else { - MAPPED_T *alpha_ptr = reinterpret_cast(&alpha); - MAPPED_T *beta_ptr = reinterpret_cast(&beta); - ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true, - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, - n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda, - batch_stride_a, GpuMemory(b), ldb, batch_stride_b, - GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc, - batch_stride_c, batch_count); +const char *rocblas_hgemm_strided_batched_mfma::kName = + "rocblas_hgemm_strided_batched_mfma"; + +class rocblas_gemm_strided_batched_bf16 { + public: + static const char *kName; + rocblas_status operator()(rocblas_handle handle, rocblas_operation transA, + rocblas_operation transB, rocblas_int m, + rocblas_int n, rocblas_int k, + const rocblas_bfloat16 *alpha, + const rocblas_bfloat16 *A, rocblas_int lda, + rocblas_stride stride_a, const rocblas_bfloat16 *B, + rocblas_int ldb, rocblas_stride stride_b, + const rocblas_bfloat16 *beta, rocblas_bfloat16 *C, + rocblas_int ldc, rocblas_stride stride_c, + rocblas_int batch_count) { + float alpha32 = static_cast(*(const Eigen::bfloat16 *)alpha); + float beta32 = static_cast(*(const Eigen::bfloat16 *)beta); + uint32_t flags = rocblas_gemm_flags_none; + return wrap::rocblas_gemm_strided_batched_ex( + handle, transA, transB, m, n, k, &alpha32, A, rocblas_datatype_bf16_r, + lda, stride_a, B, rocblas_datatype_bf16_r, ldb, stride_b, &beta32, C, + rocblas_datatype_bf16_r, ldc, stride_c, C, rocblas_datatype_bf16_r, ldc, + stride_c, batch_count, rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, flags); } - if (!ok) - return tsl::Status(absl::StatusCode::kInternal, - "failed BLAS call, see log for details"); - if (reallocated_c) - return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c, - false); - return tsl::OkStatus(); -} +}; +const char *rocblas_gemm_strided_batched_bf16::kName = + "rocblas_gemm_strided_batched_bf16"; bool ROCMBlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); const Eigen::half alpha_half(alpha); const Eigen::half beta_half(beta); + absl::Status status; + + auto call_gemm = [&](auto x) { + return DoBlasGemmBatchedInternal(x, stream, transa, transb, m, n, k, + alpha_half, a, lda, b, ldb, beta_half, c, + ldc, batch_count, scratch_allocator); + }; + + if (has_mfma_) { + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + status = call_gemm( + rocblas_hgemm_strided_batched_mfma(is_backprop && use_hgemm_alt_impl_)); + } else { + status = call_gemm(wrap::rocblas_hgemm_strided_batched); + } - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k, - alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count, - scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -889,18 +1057,17 @@ bool ROCMBlas::DoBlasGemmBatched( bool ROCMBlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); const Eigen::bfloat16 alpha_bf16(alpha); const Eigen::bfloat16 beta_bf16(beta); - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_gemm_strided_batched_ex, stream, transa, transb, m, n, k, + absl::Status status = DoBlasGemmBatchedInternal( + rocblas_gemm_strided_batched_bf16(), stream, transa, transb, m, n, k, alpha_bf16, a_array, lda, b_array, ldb, beta_bf16, c_array, ldc, batch_count, scratch_allocator); if (!status.ok()) { @@ -909,288 +1076,154 @@ bool ROCMBlas::DoBlasGemmBatched( return status.ok(); } -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a_array, - int lda, DeviceMemorySlice b_array, int ldb, float beta, - DeviceMemorySlice c_array, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; +#define IMPL_DoBlasGemmBatched(T, Fun) \ + bool ROCMBlas::DoBlasGemmBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64_t m, uint64_t n, uint64 k, T alpha, DeviceMemorySlice a_array, \ + int lda, DeviceMemorySlice b_array, int ldb, T beta, \ + DeviceMemorySlice c_array, int ldc, int batch_count, \ + const NumericOptions &numeric_options, \ + ScratchAllocator *scratch_allocator, blas::CallContext context) { \ + absl::Status status = DoBlasGemmBatchedInternal( \ + Fun, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, \ + ldb, beta, c_array, ldc, batch_count, scratch_allocator); \ + if (!status.ok()) { \ + LOG(ERROR) << status; \ + } \ + return status.ok(); \ } - return status.ok(); -} -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, double alpha, DeviceMemorySlice a_array, - int lda, DeviceMemorySlice b_array, int ldb, double beta, - DeviceMemorySlice c_array, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; +IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) + IMPL_DoBlasGemmBatched(double, wrap::rocblas_dgemm_strided_batched) + IMPL_DoBlasGemmBatched(std::complex, + wrap::rocblas_cgemm_strided_batched) + IMPL_DoBlasGemmBatched(std::complex, + wrap::rocblas_zgemm_strided_batched) +#define IMPL_DoBlasTrsm(T, Fun, Fun2) \ + bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, \ + blas::UpperLower uplo, blas::Transpose transa, \ + blas::Diagonal diag, uint64_t m, uint64 n, \ + T alpha, const DeviceMemory &a, int lda, \ + DeviceMemory *b, int ldb) { \ + return DoBlasInternal(Fun, stream, /* pointer_mode_host = */ true, \ + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), \ + ROCMBlasTranspose(transa), ROCMBlasDiagonal(diag), \ + m, n, complex_cast(alpha), complex_cast(a), lda, \ + complex_cast(b), ldb); \ + } \ + \ + bool ROCMBlas::DoBlasTrsmBatched( \ + Stream *stream, blas::Side side, blas::UpperLower uplo, \ + blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ + T alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, \ + int ldb, int batch_count) { \ + return DoBlasInternal(Fun2, stream, true /* = pointer_mode_host */, \ + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), \ + ROCMBlasTranspose(transa), ROCMBlasDiagonal(diag), \ + m, n, complex_cast(alpha), complex_cast(as), lda, \ + complex_cast(*bs), ldb, batch_count); \ } - return status.ok(); -} - -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, - DeviceMemorySlice> a_array, int lda, - DeviceMemorySlice> b_array, int ldb, - std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; - } - return status.ok(); -} - -bool ROCMBlas::DoBlasGemmBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, - DeviceMemorySlice> a_array, int lda, - DeviceMemorySlice> b_array, int ldb, - std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - blas_log("DoBlasGemmBatched"); - tsl::Status status = DoBlasGemmBatchedInternal( - wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, - alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, - scratch_allocator); - if (!status.ok()) { - LOG(ERROR) << status; - } - return status.ok(); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - float alpha, const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - blas_log("DoBlasTrsm"); - return DoBlasInternal(wrap::rocblas_strsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, - GpuMemoryMutable(b), ldb); -} -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - blas_log("DoBlasTrsm"); - return DoBlasInternal(wrap::rocblas_dtrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda, - GpuMemoryMutable(b), ldb); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb) { - return DoBlasInternal(wrap::rocblas_ctrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - complex_cast(a), lda, complex_cast(b), ldb); -} - -bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb) { - return DoBlasInternal(wrap::rocblas_ztrsm, stream, - /* pointer_mode_host = */ true, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - complex_cast(a), lda, complex_cast(b), ldb); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - float alpha, const DeviceMemory &as, - int lda, DeviceMemory *bs, int ldb, - int batch_count) { - return DoBlasInternal(wrap::rocblas_strsm_batched, stream, - true /* = pointer_mode_host */, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as), - lda, GpuMemoryMutable(bs), ldb, batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - double alpha, const DeviceMemory &as, - int lda, DeviceMemory *bs, int ldb, - int batch_count) { - return DoBlasInternal(wrap::rocblas_dtrsm_batched, stream, - true /* = pointer_mode_host */, ROCMBlasSide(side), - ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as), - lda, GpuMemoryMutable(bs), ldb, batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory *> &as, - int lda, - DeviceMemory *> *bs, - int ldb, int batch_count) { - return DoBlasInternal( - wrap::rocblas_ctrsm_batched, stream, true /* = pointer_mode_host */, - ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - static_cast(as.opaque()), lda, - static_cast(bs->opaque()), ldb, - batch_count); -} - -bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, - blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory *> &as, - int lda, - DeviceMemory *> *bs, - int ldb, int batch_count) { - return DoBlasInternal( - wrap::rocblas_ztrsm_batched, stream, true /* = pointer_mode_host */, - ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), - ROCMBlasDiagonal(diag), m, n, complex_cast(alpha), - static_cast(as.opaque()), lda, - static_cast(bs->opaque()), ldb, - batch_count); -} - -tsl::Status ROCMBlas::DoBlasGemmStridedBatched( - Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, blas::DataType dtype, const void *alpha, - const DeviceMemoryBase &a, int lda, int64_t stride_a, - const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, - DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - const NumericOptions &numeric_options, blas::CallContext context) { + IMPL_DoBlasTrsm(float, wrap::rocblas_strsm, + wrap::rocblas_strsm_batched) + IMPL_DoBlasTrsm(double, wrap::rocblas_dtrsm, + wrap::rocblas_dtrsm_batched) + IMPL_DoBlasTrsm(std::complex, + wrap::rocblas_ctrsm, + wrap::rocblas_ctrsm_batched) + IMPL_DoBlasTrsm(std::complex, + wrap::rocblas_ztrsm, + wrap::rocblas_ztrsm_batched) + + absl::Status + ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, + uint64_t m, uint64_t n, uint64_t k, blas::DataType dtype, + const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, + const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, + DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, + const NumericOptions &numeric_options, blas::CallContext context) { VLOG(1) << absl::StreamFormat( - "doing rocBLAS SGEMM Strided Batched: at=%d bt=%d m=%u n=%u " + "doing rocBLAS GEMM Strided Batched: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " - "c=%p ldc=%d", + "c=%p ldc=%d stride_a/b/c=%d/%d/%d batch_count=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, - a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc, stride_a, + stride_b, stride_c, batch_count); + + absl::Status status; + auto call_gemm = [&](auto func, auto type) { + return DoBlasInternalStatus( + func, stream, false, /* pointer_mode_host */ + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(a.opaque()), lda, stride_a, + reinterpret_cast(b.opaque()), ldb, stride_b, + reinterpret_cast(beta), + reinterpret_cast(c->opaque()), ldc, stride_c, + batch_count); + }; switch (dtype) { case blas::DataType::kHalf: { - const Eigen::half alpha_half(*static_cast(alpha)); - const Eigen::half beta_half(*static_cast(beta)); - return DoBlasInternalStatus( - wrap::rocblas_hgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(&alpha_half), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(&beta_half), - reinterpret_cast(c->opaque()), ldc, stride_c, - batch_count); + bool is_backprop = (context == blas::CallContext::kBackpropInput1) || + (context == blas::CallContext::kBackpropInput2); + Eigen::half alpha_half = Eigen::half(*static_cast(alpha)); + Eigen::half beta_half = Eigen::half(*static_cast(beta)); + alpha = &alpha_half; + beta = &beta_half; + if (has_mfma_) { + return call_gemm(rocblas_hgemm_strided_batched_mfma( + is_backprop && use_hgemm_alt_impl_), + rocblas_half()); + } else { + return call_gemm(wrap::rocblas_hgemm_strided_batched, rocblas_half()); + } + } + case blas::DataType::kBF16: { + Eigen::bfloat16 alpha_bf16, beta_bf16; + alpha_bf16 = Eigen::bfloat16(*static_cast(alpha)); + beta_bf16 = Eigen::bfloat16(*static_cast(beta)); + alpha = &alpha_bf16; + beta = &beta_bf16; + return call_gemm(rocblas_gemm_strided_batched_bf16(), rocblas_bfloat16()); } - case blas::DataType::kBF16: - return DoBlasInternalStatus( - wrap::rocblas_gemm_strided_batched_ex, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, alpha, - a.opaque(), rocblas_datatype_bf16_r, lda, stride_a, b.opaque(), - rocblas_datatype_bf16_r, ldb, stride_b, beta, c->opaque(), - rocblas_datatype_bf16_r, ldc, stride_c, c->opaque(), - rocblas_datatype_bf16_r, ldc, stride_c, batch_count, - rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0); case blas::DataType::kFloat: - return DoBlasInternalStatus( - wrap::rocblas_sgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(beta), - reinterpret_cast(c->opaque()), ldc, stride_c, batch_count); + return call_gemm(wrap::rocblas_sgemm_strided_batched, 1.0f); case blas::DataType::kDouble: - return DoBlasInternalStatus( - wrap::rocblas_dgemm_strided_batched, stream, - false, /* pointer_mode_host */ - ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(a.opaque()), lda, stride_a, - reinterpret_cast(b.opaque()), ldb, stride_b, - reinterpret_cast(beta), - reinterpret_cast(c->opaque()), ldc, stride_c, batch_count); - case blas::DataType::kComplexFloat: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_cgemm_strided_batched, stream, - /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), - ROCMBlasTranspose(transb), m, n, k, cb_alpha, - static_cast(a.opaque()), lda, stride_a, - static_cast(b.opaque()), ldb, stride_b, - cb_beta, static_cast(c->opaque()), ldc, - stride_c, batch_count); - } - case blas::DataType::kComplexDouble: { - auto cb_alpha = - complex_cast(*static_cast *>(alpha)); - auto cb_beta = - complex_cast(*static_cast *>(beta)); - return DoBlasInternalStatus( - wrap::rocblas_zgemm_strided_batched, stream, - /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), - ROCMBlasTranspose(transb), m, n, k, cb_alpha, - static_cast(a.opaque()), lda, - stride_a, static_cast(b.opaque()), - ldb, stride_b, cb_beta, - static_cast(c->opaque()), ldc, stride_c, - batch_count); - } + return call_gemm(wrap::rocblas_dgemm_strided_batched, 1.0); + case blas::DataType::kComplexFloat: + return call_gemm(wrap::rocblas_cgemm_strided_batched, + rocblas_float_complex()); + case blas::DataType::kComplexDouble: + return call_gemm(wrap::rocblas_zgemm_strided_batched, + rocblas_double_complex()); default: - return tsl::errors::Internal(absl::StrCat( - "Unsupported datatype for GEMM: ", blas::DataTypeString(dtype))); + return absl::InternalError(absl::StrCat("Unsupported datatype for GEMM: ", + blas::DataTypeString(dtype))); } } -tsl::Status ROCMBlas::GetVersion(string *version) { - return tsl::errors::Unimplemented(""); +absl::Status ROCMBlas::GetVersion(string *version) { +#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1 + absl::MutexLock lock{&mu_}; + size_t len = 0; + if (auto res = rocblas_get_version_string_size(&len); + res != rocblas_status_success) { + return absl::InternalError( + absl::StrCat("GetVersion failed with: ", ToString(res))); + } + std::vector buf(len + 1); + if (auto res = rocblas_get_version_string(buf.data(), len); + res != rocblas_status_success) { + return absl::InternalError( + absl::StrCat("GetVersion failed with: ", ToString(res))); + } + *version = string(buf.begin(), buf.end()); + return absl::OkStatus(); +#else + return absl::UnimplementedError(""); +#endif } } // namespace gpu @@ -1200,7 +1233,7 @@ void initialize_rocblas() { rocm::kROCmPlatformId, PluginKind::kBlas); if (!rocBlasAlreadyRegistered) { - tsl::Status status = + absl::Status status = PluginRegistry::Instance() ->RegisterFactory( rocm::kROCmPlatformId, "rocBLAS", @@ -1233,5 +1266,6 @@ void initialize_rocblas() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_rocblas, - { stream_executor::initialize_rocblas(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_rocblas, { + stream_executor::initialize_rocblas(); +}); diff --git a/xla/stream_executor/rocm/rocm_blas.h b/xla/stream_executor/rocm/rocm_blas.h index 7c8621f685181..537a3a7a46f07 100644 --- a/xla/stream_executor/rocm/rocm_blas.h +++ b/xla/stream_executor/rocm/rocm_blas.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,15 +24,17 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "rocm/rocm_config.h" + +#define ROCBLAS_BETA_FEATURES_API #if TF_ROCM_VERSION >= 50600 #include "rocm/include/rocblas/rocblas.h" #else #include "rocm/include/rocblas.h" #endif #include "xla/stream_executor/blas.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" #if TF_HIPBLASLT #include "xla/stream_executor/rocm/hip_blas_lt.h" #endif @@ -43,32 +45,34 @@ class Stream; namespace gpu { -// Type conversion helper that helps to map non-rocblas types to rocblas types -// Right now, it only converts the Eigen::half type to rocblas_half type -template -struct RocBlasTypeConversionHelper { - using mapped_type = T; -}; - -template <> -struct RocBlasTypeConversionHelper { - using mapped_type = rocblas_half; +template +struct ChooseType { + using type = std::conditional_t< + std::is_same_v, B, + typename ChooseType::type>; }; -template <> -struct RocBlasTypeConversionHelper { - using mapped_type = rocblas_bfloat16; +template +struct ChooseType { + // default case: return the same type Target if there is no recursive match + using type = std::conditional_t, B, Target>; }; -template <> -struct RocBlasTypeConversionHelper> { - using mapped_type = rocblas_float_complex; +template +struct ChooseType { + // default case: return compile error if type is not found + static_assert(std::is_same_v, + "ChooseType: the target type is not found!"); + using type = B; }; -template <> -struct RocBlasTypeConversionHelper> { - using mapped_type = rocblas_double_complex; -}; +// Type conversion helper that helps to map non-rocblas types to rocblas types +template +using RocBlasType_t = + typename ChooseType, + rocblas_float_complex, std::complex, + rocblas_double_complex>::type; class GpuExecutor; @@ -124,49 +128,39 @@ class ROCMBlas : public blas::BlasSupport { // err_on_failure: Whether to print an error if the rocBLAS function // fails. args: Arguments of rocBLAS function. template - bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, bool err_on_failure, - Args... args); + absl::Status DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args &&...args); // Convenience functions that call DoBlasInternalImpl with different values // for err_on_failure. template bool DoBlasInternal(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, - /*err_on_failure=*/true, args...); + bool pointer_mode_host, Args &&...args) { + auto ret = DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, + std::forward(args)...); + return ret.ok(); } - // Same as above, but returns tsl::Status. - template - tsl::Status DoBlasInternalStatus(Args... args) { - if (!DoBlasInternal(args...)) { - return tsl::errors::Internal("Failed calling rocBLAS"); - } - return tsl::OkStatus(); + // Same as above, but returns absl::Status. + template + absl::Status DoBlasInternalStatus(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, Args &&...args) { + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, + std::forward(args)...); } template bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, - bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, - /*err_on_failure=*/false, args...); + bool pointer_mode_host, Args &&...args) { + auto ret = DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/false, + std::forward(args)...); + return ret.ok(); } - // A helper allocation function to convert raw pointers memory layout to - // strided flavor - template - tsl::Status AllocateStridedBuffer( - const std::vector::mapped_type *> - &raw_ptrs, - int batch_count, uint64_t batch_stride, - ScratchAllocator *scratch_allocator, Stream *stream, - std::unique_ptr::mapped_type>> *temp_memory, - DeviceMemory::mapped_type> - *device_memory, - bool copy_data, bool &reallocated); - // A helper function to implement DoBlasGemmBatched interfaces for generic // types. // @@ -184,9 +178,9 @@ class ROCMBlas : public blas::BlasSupport { // It will take advantage of the AllocateStridedBuffer subroutine to // reallocate the memory layout to be strided batched. template - tsl::Status DoBlasGemmBatchedInternal( + absl::Status DoBlasGemmBatchedInternal( FuncT rocblas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, T alpha, DeviceMemorySlice a_ptrs_to_wrappers, int lda, DeviceMemorySlice b_ptrs_to_wrappers, int ldb, T beta, DeviceMemorySlice c_ptrs_to_wrappers, int ldc, int batch_count, @@ -202,12 +196,18 @@ class ROCMBlas : public blas::BlasSupport { // rocBLAS library handle on the device. rocblas_handle blas_ ABSL_GUARDED_BY(mu_); + // container holding solutions vector (to avoid reallocating it each time) + std::vector solutions_; + #if TF_HIPBLASLT rocm::BlasLt blas_lt_; #endif ROCMBlas(const ROCMBlas &) = delete; void operator=(const ROCMBlas &) = delete; + + bool has_mfma_ = false; + bool use_hgemm_alt_impl_ = false; }; } // namespace gpu diff --git a/xla/stream_executor/rocm/rocm_collectives.cc b/xla/stream_executor/rocm/rocm_collectives.cc new file mode 100644 index 0000000000000..44a6b39759ed5 --- /dev/null +++ b/xla/stream_executor/rocm/rocm_collectives.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_collectives.h" +#include "xla/stream_executor/gpu/gpu_driver.h" + +namespace stream_executor::gpu { + +absl::StatusOr GpuCollectives::CollectiveMemoryAllocate( + GpuContext* context, uint64_t bytes) { + return absl::UnimplementedError( + "Feature not supported on ROCm platform (CollectiveMemoryAllocate)"); +} + +absl::Status GpuCollectives::CollectiveMemoryDeallocate(GpuContext* context, + void* location) { + return absl::UnimplementedError( + "Feature not supported on ROCm platform (CollectiveMemoryDeallocate)"); +} + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_diagnostics.cc b/xla/stream_executor/rocm/rocm_diagnostics.cc index 91ef8cf92c011..73aad489e6fba 100644 --- a/xla/stream_executor/rocm/rocm_diagnostics.cc +++ b/xla/stream_executor/rocm/rocm_diagnostics.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,19 +35,18 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" -#include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" namespace stream_executor { namespace rocm { -string DriverVersionToString(DriverVersion version) { +std::string DriverVersionToString(DriverVersion version) { return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version), std::get<2>(version)); } -string DriverVersionStatusToString(tsl::StatusOr version) { +std::string DriverVersionStatusToString(absl::StatusOr version) { if (!version.ok()) { return version.status().ToString(); } @@ -55,34 +54,34 @@ string DriverVersionStatusToString(tsl::StatusOr version) { return DriverVersionToString(version.value()); } -tsl::StatusOr StringToDriverVersion(const string& value) { - std::vector pieces = absl::StrSplit(value, '.'); +absl::StatusOr StringToDriverVersion(const std::string& value) { + std::vector pieces = absl::StrSplit(value, '.'); if (pieces.size() != 2 && pieces.size() != 3) { - return tsl::Status{absl::StatusCode::kInvalidArgument, - absl::StrFormat("expected %%d.%%d or %%d.%%d.%%d form " - "for driver version; got \"%s\"", - value.c_str())}; + return absl::Status{absl::StatusCode::kInvalidArgument, + absl::StrFormat("expected %%d.%%d or %%d.%%d.%%d form " + "for driver version; got \"%s\"", + value.c_str())}; } int major; int minor; int patch = 0; if (!absl::SimpleAtoi(pieces[0], &major)) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInvalidArgument, absl::StrFormat("could not parse major version number \"%s\" as an " "integer from string \"%s\"", pieces[0].c_str(), value.c_str())}; } if (!absl::SimpleAtoi(pieces[1], &minor)) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInvalidArgument, absl::StrFormat("could not parse minor version number \"%s\" as an " "integer from string \"%s\"", pieces[1].c_str(), value.c_str())}; } if (pieces.size() == 3 && !absl::SimpleAtoi(pieces[2], &patch)) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInvalidArgument, absl::StrFormat("could not parse patch version number \"%s\" as an " "integer from string \"%s\"", @@ -103,7 +102,7 @@ namespace gpu { // -- class Diagnostician -string Diagnostician::GetDevNodePath(int dev_node_ordinal) { +std::string Diagnostician::GetDevNodePath(int dev_node_ordinal) { return absl::StrCat("/dev/kfd", dev_node_ordinal); } @@ -118,10 +117,10 @@ void Diagnostician::LogDiagnosticInformation() { LOG(INFO) << "hostname: " << tsl::port::Hostname(); if (VLOG_IS_ON(1)) { const char* value = getenv("LD_LIBRARY_PATH"); - string library_path = value == nullptr ? "" : value; + std::string library_path = value == nullptr ? "" : value; VLOG(1) << "LD_LIBRARY_PATH is: \"" << library_path << "\""; - std::vector pieces = absl::StrSplit(library_path, ':'); + std::vector pieces = absl::StrSplit(library_path, ':'); for (const auto& piece : pieces) { if (piece.empty()) { continue; @@ -137,11 +136,11 @@ void Diagnostician::LogDiagnosticInformation() { closedir(dir); } } - tsl::StatusOr dso_version = FindDsoVersion(); + absl::StatusOr dso_version = FindDsoVersion(); LOG(INFO) << "librocm reported version is: " << rocm::DriverVersionStatusToString(dso_version); - tsl::StatusOr kernel_version = FindKernelDriverVersion(); + absl::StatusOr kernel_version = FindKernelDriverVersion(); LOG(INFO) << "kernel reported version is: " << rocm::DriverVersionStatusToString(kernel_version); @@ -152,8 +151,8 @@ void Diagnostician::LogDiagnosticInformation() { // Iterates through loaded DSOs with DlIteratePhdrCallback to find the // driver-interfacing DSO version number. Returns it as a string. -tsl::StatusOr Diagnostician::FindDsoVersion() { - tsl::StatusOr result{tsl::Status{ +absl::StatusOr Diagnostician::FindDsoVersion() { + absl::StatusOr result{absl::Status{ absl::StatusCode::kNotFound, "was unable to find librocm.so DSO loaded into this program"}}; @@ -177,11 +176,11 @@ tsl::StatusOr Diagnostician::FindDsoVersion() { if (dot == nullptr) { return 0; } - string dso_version = dot + strlen(so_suffix); + std::string dso_version = dot + strlen(so_suffix); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_dso_version = absl::StripSuffix(dso_version, ".ld64"); - auto result = static_cast*>(data); - *result = rocm::StringToDriverVersion(string(stripped_dso_version)); + auto result = static_cast*>(data); + *result = rocm::StringToDriverVersion(std::string(stripped_dso_version)); return 1; } return 0; @@ -192,30 +191,30 @@ tsl::StatusOr Diagnostician::FindDsoVersion() { return result; } -tsl::StatusOr Diagnostician::FindKernelModuleVersion( - const string& driver_version_file_contents) { +absl::StatusOr Diagnostician::FindKernelModuleVersion( + const std::string& driver_version_file_contents) { static const char* kDriverFilePrelude = "Kernel Module "; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); - if (offset == string::npos) { - return tsl::Status{ + if (offset == std::string::npos) { + return absl::Status{ absl::StatusCode::kNotFound, absl::StrCat("could not find kernel module information in " "driver version file contents: \"", driver_version_file_contents, "\"")}; } - string version_and_rest = driver_version_file_contents.substr( - offset + strlen(kDriverFilePrelude), string::npos); + std::string version_and_rest = driver_version_file_contents.substr( + offset + strlen(kDriverFilePrelude), std::string::npos); size_t space_index = version_and_rest.find(" "); auto kernel_version = version_and_rest.substr(0, space_index); // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_kernel_version = absl::StripSuffix(kernel_version, ".ld64"); - return rocm::StringToDriverVersion(string(stripped_kernel_version)); + return rocm::StringToDriverVersion(std::string(stripped_kernel_version)); } void Diagnostician::WarnOnDsoKernelMismatch( - tsl::StatusOr dso_version, - tsl::StatusOr kernel_version) { + absl::StatusOr dso_version, + absl::StatusOr kernel_version) { if (kernel_version.ok() && dso_version.ok() && dso_version.value() == kernel_version.value()) { LOG(INFO) << "kernel version seems to match DSO: " @@ -229,9 +228,9 @@ void Diagnostician::WarnOnDsoKernelMismatch( } } -tsl::StatusOr Diagnostician::FindKernelDriverVersion() { - auto status = tsl::Status{absl::StatusCode::kUnimplemented, - "kernel reported driver version not implemented"}; +absl::StatusOr Diagnostician::FindKernelDriverVersion() { + auto status = absl::Status{absl::StatusCode::kUnimplemented, + "kernel reported driver version not implemented"}; return status; } diff --git a/xla/stream_executor/rocm/rocm_diagnostics.h b/xla/stream_executor/rocm/rocm_diagnostics.h index 16d07121a4d99..f9bc2c2c484b5 100644 --- a/xla/stream_executor/rocm/rocm_diagnostics.h +++ b/xla/stream_executor/rocm/rocm_diagnostics.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,13 +25,13 @@ namespace rocm { using DriverVersion = gpu::DriverVersion; // Converts a parsed driver version to string form. -string DriverVersionToString(DriverVersion version); +std::string DriverVersionToString(DriverVersion version); // Converts a parsed driver version or status value to natural string form. -string DriverVersionStatusToString(tsl::StatusOr version); +std::string DriverVersionStatusToString(absl::StatusOr version); // Converts a string of a form like "331.79" to a DriverVersion{331, 79}. -tsl::StatusOr StringToDriverVersion(const string& value); +absl::StatusOr StringToDriverVersion(const std::string& value); using Diagnostician = gpu::Diagnostician; diff --git a/xla/stream_executor/rocm/rocm_dnn.cc b/xla/stream_executor/rocm/rocm_dnn.cc index 9a200ccc8882c..c4c60b9c01f05 100644 --- a/xla/stream_executor/rocm/rocm_dnn.cc +++ b/xla/stream_executor/rocm/rocm_dnn.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,12 +43,12 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/determinism.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" -#include "tsl/util/determinism.h" -#include "tsl/util/env_var.h" namespace { @@ -77,6 +77,23 @@ using dnn::PoolingDescriptor; namespace gpu { +// Populates the profile result if not empty. +static absl::Status PopulateProfileFromTimer( + std::optional& timer, const dnn::AlgorithmDesc& algorithm, + dnn::ProfileResult* profile_result, + std::optional scratch_size = std::nullopt) { + if (profile_result) { + TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); + profile_result->set_algorithm(algorithm); + profile_result->set_elapsed_time_in_ms( + absl::ToDoubleMilliseconds(duration)); + if (scratch_size.has_value()) { + profile_result->set_scratch_size(*scratch_size); + } + } + return absl::OkStatus(); +} + string ToString(miopenStatus_t status) { switch (status) { case miopenStatusSuccess: @@ -224,31 +241,31 @@ namespace wrap { #else -#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default() \ - -> GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - miopenStatus_t operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ - } __name; \ +#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char* kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in miopen DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + miopenStatus_t operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ + } __name; \ const char* DynLoadShim__##__name::kName = #__name; #endif @@ -727,14 +744,14 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) { if (enable_pooling_cache) m_pooling_cache_allowed = true; } -tsl::Status MIOpenSupport::Init() { +absl::Status MIOpenSupport::Init() { ScopedActivateExecutorContext context(parent_); miopenHandle_t miopen_handle = nullptr; auto status = wrap::miopenCreateWithStream( reinterpret_cast(&miopen_handle), (hipStream_t)(0)); if (status == miopenStatusSuccess) { miopen_.reset(new MIOpenAccess(miopen_handle)); - return tsl::OkStatus(); + return absl::OkStatus(); } CHECK_EQ(miopen_handle, nullptr); @@ -751,12 +768,12 @@ tsl::Status MIOpenSupport::Init() { } } - return tsl::Status{absl::StatusCode::kInternal, - absl::StrCat("miopen library could not create a handle: ", - ToString(status))}; + return absl::Status{absl::StatusCode::kInternal, + absl::StrCat("miopen library could not create a handle: ", + ToString(status))}; } -tsl::StatusOr MIOpenSupport::GetVersion() { +absl::StatusOr MIOpenSupport::GetVersion() { // ROCM TODO: retrieve MIOpen version with its API return stream_executor::dnn::VersionInfo(1, 3, 0); } @@ -1900,7 +1917,7 @@ class MixinBase {}; #define RETURN_IF_MIOPEN_ERROR(STATUS, ...) \ if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) { \ string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ - SetFailure(::tsl::Status(absl::StatusCode::kUnknown, error_msg)); \ + SetFailure(::absl::UnknownError(error_msg)); \ LOG(ERROR) << error_msg; \ return; \ } @@ -1909,11 +1926,11 @@ template class MIOpenDescriptorCommon : public MixinBase { public: bool ok() const { return status_.ok(); } - tsl::Status Status() const { return status_; } + absl::Status Status() const { return status_; } protected: - void SetFailure(const tsl::Status& status) { status_.Update(status); } - tsl::Status status_; + void SetFailure(const absl::Status& status) { status_.Update(status); } + absl::Status status_; }; class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon { @@ -1947,7 +1964,7 @@ class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon { int64_t params_size_in_bytes_; ParamsRegions weights_; ParamsRegions biases_; - tsl::Status status_; + absl::Status status_; MIOpenRnnParamsDescriptor(const MIOpenRnnParamsDescriptor&) = delete; void operator=(const MIOpenRnnParamsDescriptor&) = delete; }; @@ -2035,7 +2052,7 @@ class MIOpenRnnDescriptor : public MIOpenDescriptorCommon { miopenRNNMode_t rnn_mode_; miopenDataType_t data_type_; dnn::AlgorithmConfig algorithm_config_; - tsl::Status status_; + absl::Status status_; // no dropout in MIOpen. // std::unique_ptr miopen_dropout_desc_; std::unique_ptr miopen_params_desc_; @@ -2074,7 +2091,7 @@ class MIOpenRnnSequenceTensorDescriptor string error_msg = absl::StrCat("sequence length must be positive: ", seq_length); LOG(ERROR) << error_msg; - SetFailure(tsl::Status(absl::StatusCode::kUnknown, error_msg)); + SetFailure(absl::UnknownError(error_msg)); return; } auto status = wrap::miopenCreateTensorDescriptor(&handle); @@ -2111,7 +2128,7 @@ class MIOpenRnnSequenceTensorDescriptor int data_size_; miopenDataType_t data_type_; std::vector handles_; - tsl::Status status_; + absl::Status status_; MIOpenRnnSequenceTensorDescriptor(const MIOpenRnnSequenceTensorDescriptor&) = delete; void operator=(const MIOpenRnnSequenceTensorDescriptor&) = delete; @@ -2156,7 +2173,7 @@ class MIOpenRnnStateTensorDescriptor int num_layers_; int batch_size_; int data_size_; - tsl::Status status_; + absl::Status status_; miopenDataType_t data_type_; MIOpenRnnStateTensorDescriptor(const MIOpenRnnStateTensorDescriptor&) = delete; @@ -2275,7 +2292,9 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, return false; } - stream->ThenMemZero(workspace, workspace_size_in_bytes); + if (!stream->MemZero(workspace, workspace_size_in_bytes).ok()) { + return false; + } } else { *workspace = DeviceMemory(); } @@ -2285,7 +2304,7 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, } // namespace template -bool MIOpenSupport::DoRnnForwardImpl( +absl::Status MIOpenSupport::DoRnnForwardImpl( Stream* stream, const MIOpenRnnDescriptor& rnn_desc, const MIOpenRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -2311,7 +2330,7 @@ bool MIOpenSupport::DoRnnForwardImpl( &model_dims); if (!res) { LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; + return absl::InternalError("ExtractAndCheckRnnForward returned false"); } auto miopen = miopen_->GetHandle(parent_, stream); @@ -2320,7 +2339,7 @@ bool MIOpenSupport::DoRnnForwardImpl( if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) { LOG(ERROR) << "Invalid parameters"; - return false; + return absl::InternalError("CheckRNNParameterSize returned false"); } // create the workspace @@ -2328,8 +2347,7 @@ bool MIOpenSupport::DoRnnForwardImpl( if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc, workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; - - return false; + return absl::InternalError("CreateRnnWorkspace returned false"); } // query the reserve space size @@ -2343,7 +2361,8 @@ bool MIOpenSupport::DoRnnForwardImpl( &reserve_space_size_in_bytes /*sizeInBytes*/); if (status != miopenStatusSuccess) { LOG(ERROR) << "Unable to query reserve space size: " << ToString(status); - return false; + return absl::InternalError( + "miopenGetRNNTrainingReserveSize returned failure"); } if (reserve_space_size_in_bytes > 0) { @@ -2351,23 +2370,21 @@ bool MIOpenSupport::DoRnnForwardImpl( reserve_space_allocator->AllocateBytes(reserve_space_size_in_bytes); if (!allocated.ok() || (reserve_space = allocated.value()) == nullptr) { LOG(ERROR) << "Fail to allocate RNN reserve space"; - return false; + return absl::InternalError("AllocateBytes for RNN failed"); } - stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes); + TF_RETURN_IF_ERROR( + stream->MemZero(&reserve_space, reserve_space_size_in_bytes)); } } - std::optional timer; const bool is_profiling = output_profile_result != nullptr; - if (is_profiling) { - auto timer_or_status = GpuTimer::Create(AsGpuStream(stream)); - if (!timer_or_status.ok()) { - LOG(ERROR) << "Failed to create timer"; - return false; - } - timer.emplace(std::move(*timer_or_status)); - } + TF_ASSIGN_OR_RETURN( + std::optional timer, + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + is_profiling)); // make the forward call if (!is_training) { @@ -2386,7 +2403,7 @@ bool MIOpenSupport::DoRnnForwardImpl( if (status != miopenStatusSuccess) { LOG(ERROR) << "Failed to call miopenRNNForwardInference: " << ToString(status); - return false; + return absl::InternalError("miopenRNNForwardInference failed"); } } else { auto status = wrap::miopenRNNForwardTraining( @@ -2405,27 +2422,21 @@ bool MIOpenSupport::DoRnnForwardImpl( if (status != miopenStatusSuccess) { LOG(ERROR) << "Failed to call miopenRNNForwardTraining" << ToString(status); - return false; + return absl::InternalError("miopenRNNForwardTraining failed"); } } if (is_profiling) { - tsl::StatusOr elapsed = timer->GetElapsedDuration(); - if (!elapsed.ok()) { - LOG(ERROR) << "Failed to get elapsed duration"; - return false; - } - auto algo_desc = *rnn_desc.algorithm_config().algorithm(); - output_profile_result->set_algorithm(algo_desc); - output_profile_result->set_elapsed_time_in_ms( - absl::ToDoubleMilliseconds(*elapsed)); + TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + timer, *rnn_desc.algorithm_config().algorithm(), + output_profile_result)); } - return true; + return absl::OkStatus(); } template -bool MIOpenSupport::DoRnnBackwardImpl( +absl::Status MIOpenSupport::DoRnnBackwardImpl( Stream* stream, const MIOpenRnnDescriptor& rnn_desc, const MIOpenRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -2457,7 +2468,7 @@ bool MIOpenSupport::DoRnnBackwardImpl( output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims); if (!res) { LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; + return absl::InternalError("ExtractAndCheckRnnForward failed"); } auto miopen = miopen_->GetHandle(parent_, stream); @@ -2466,7 +2477,7 @@ bool MIOpenSupport::DoRnnBackwardImpl( if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) { LOG(ERROR) << "Invalid parameters"; - return false; + return absl::InternalError("CheckRNNParameterSize failed"); } // create the workspace @@ -2474,7 +2485,7 @@ bool MIOpenSupport::DoRnnBackwardImpl( if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc, workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; - return false; + return absl::InternalError("CreateRnnWorkspace failed"); } // workaround for missing initialization support in MIOpen. @@ -2483,29 +2494,29 @@ bool MIOpenSupport::DoRnnBackwardImpl( auto size_data = input_desc.seq_length() * input_desc.batch_size() * input_desc.data_size(); if ((size_data > 0) && (input_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_backprop_data, size_data * type_size)); size_data = input_h_desc.num_layers() * input_h_desc.batch_size() * input_h_desc.data_size(); if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_h_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_h_backprop_data, size_data * type_size)); size_data = input_c_desc.num_layers() * input_c_desc.batch_size() * input_c_desc.data_size(); if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr)) - stream->ThenMemZero(input_c_backprop_data, size_data * type_size); + TF_RETURN_IF_ERROR( + stream->MemZero(input_c_backprop_data, size_data * type_size)); - std::optional timer; const bool is_profiling = output_profile_result != nullptr; - if (is_profiling) { - auto timer_or_status = GpuTimer::Create(AsGpuStream(stream)); - if (!timer_or_status.ok()) { - LOG(ERROR) << "Failed to create timer"; - return false; - } - timer.emplace(std::move(*timer_or_status)); - } + TF_ASSIGN_OR_RETURN( + std::optional timer, + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && output_profile_result->warmup_run_executed(), + is_profiling)); // make the backward data call auto status = wrap::miopenRNNBackwardData( @@ -2529,12 +2540,13 @@ bool MIOpenSupport::DoRnnBackwardImpl( reserve_space_data->size() /*reserveSpaceSizeInBytes*/); if (status != miopenStatusSuccess) { LOG(ERROR) << "Failed to call miopenRNNBackwardData: " << ToString(status); - return false; + return absl::InternalError("miopenRNNBackwardData failed"); } if (params_backprop_data != nullptr) { // Clear the dw to zeros. - stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + TF_RETURN_IF_ERROR( + stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call status = wrap::miopenRNNBackwardWeights( miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, @@ -2549,23 +2561,17 @@ bool MIOpenSupport::DoRnnBackwardImpl( if (status != miopenStatusSuccess) { LOG(ERROR) << "Failed to call miopenRNNBackwardWeights: " << ToString(status); - return false; + return absl::InternalError("miopenRNNBackwardWeights failed"); } } if (is_profiling) { - tsl::StatusOr elapsed = timer->GetElapsedDuration(); - if (!elapsed.ok()) { - LOG(ERROR) << "Failed to get elapsed duration"; - return false; - } - auto algo_desc = *rnn_desc.algorithm_config().algorithm(); - output_profile_result->set_algorithm(algo_desc); - output_profile_result->set_elapsed_time_in_ms( - absl::ToDoubleMilliseconds(*elapsed)); + TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + timer, *rnn_desc.algorithm_config().algorithm(), + output_profile_result)); } - return true; + return absl::OkStatus(); } MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor( @@ -2644,7 +2650,7 @@ class MIOpenCTCLossDescriptor { void operator=(const MIOpenCTCLossDescriptor&) = delete; }; -tsl::Status MIOpenSupport::DoPrepareForCtcLoss( +absl::Status MIOpenSupport::DoPrepareForCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const dnn::RnnStateTensorDescriptor& grads_desc, @@ -2675,7 +2681,7 @@ tsl::Status MIOpenSupport::DoPrepareForCtcLoss( if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: " << ToString(status); - return tsl::errors::Internal( + return absl::InternalError( "Failed to determine scratch memory size for MIOpen CTC Loss"); } @@ -2684,7 +2690,7 @@ tsl::Status MIOpenSupport::DoPrepareForCtcLoss( // Allocate the workspace. if (workspace_size_in_bytes != 0) { if (scratch_allocator == nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "An allocator must be specified when scratch memory is needed"); } auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes); @@ -2698,16 +2704,16 @@ tsl::Status MIOpenSupport::DoPrepareForCtcLoss( "larger number (e.g. 8192) to increase the max memory limit.\n" << "\tIncreasing the max memory limit might help resolve this " "error"; - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ", - workspace_size_in_bytes); + workspace_size_in_bytes)); } } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status MIOpenSupport::DoCtcLossImpl( +absl::Status MIOpenSupport::DoCtcLossImpl( Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -2731,13 +2737,13 @@ tsl::Status MIOpenSupport::DoCtcLossImpl( scratch_memory.opaque(), scratch_memory.size()); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status); - return tsl::errors::Internal("Failure during MIOpen CTC Loss"); + return absl::InternalError("Failure during MIOpen CTC Loss"); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status MIOpenSupport::DoCtcLoss( +absl::Status MIOpenSupport::DoCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, @@ -2748,9 +2754,9 @@ tsl::Status MIOpenSupport::DoCtcLoss( int ctc_loss_algo_id) { // Current MIOPen CTC Loss only supports the float datatype if (element_type != dnn::DataType::kFloat) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "MIOpenCTCLossDescriptor is supported only when the " - "DataType is float"); + return absl::InvalidArgumentError( + "MIOpenCTCLossDescriptor is supported only when the " + "DataType is float"); } MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type)); @@ -2767,8 +2773,8 @@ tsl::Status MIOpenSupport::DoCtcLoss( scratch_memory, ctc_loss_algo_id); } -tsl::StatusOr> -MIOpenSupport::createRnnDescriptor( +absl::StatusOr> +MIOpenSupport::CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -2778,14 +2784,13 @@ MIOpenSupport::createRnnDescriptor( // ROCM TODO: batch_size is used in dynamic persistent RNN algorithm and is // not supported by MIOpen now. if (use_padded_io) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "ROCm MIOpen only supports packed input output."); + return absl::InvalidArgumentError( + "ROCm MIOpen only supports packed input output."); } bool use_projection = cell_size != 0 && hidden_size < cell_size; if (use_projection) { - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( "ROCm MIOpen does not support RNN ProjectionLayers yet."); } @@ -2799,12 +2804,12 @@ MIOpenSupport::createRnnDescriptor( if (!rnn_desc->ok()) { return rnn_desc->Status(); } - return tsl::StatusOr>( + return absl::StatusOr>( std::move(rnn_desc)); } -tsl::StatusOr> -MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, +absl::StatusOr> +MIOpenSupport::CreateRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) { std::unique_ptr seq_desc( @@ -2813,12 +2818,12 @@ MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, if (!seq_desc->ok()) { return seq_desc->Status(); } - return tsl::StatusOr>( + return absl::StatusOr>( std::move(seq_desc)); } -tsl::StatusOr> -MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, +absl::StatusOr> +MIOpenSupport::CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { std::unique_ptr state_desc( @@ -2827,7 +2832,7 @@ MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, if (!state_desc->ok()) { return state_desc->Status(); } - return tsl::StatusOr>( + return absl::StatusOr>( std::move(state_desc)); } @@ -2865,12 +2870,14 @@ bool MIOpenSupport::DoRnnForward( const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, miopen_rnn_desc, miopen_input_desc, input_data, - miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, - params, miopen_output_desc, output_data, miopen_output_h_desc, - output_h_data, miopen_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnForwardImpl( + stream, miopen_rnn_desc, miopen_input_desc, input_data, + miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, + params, miopen_output_desc, output_data, miopen_output_h_desc, + output_h_data, miopen_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool MIOpenSupport::DoRnnForward( @@ -2906,12 +2913,14 @@ bool MIOpenSupport::DoRnnForward( const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, miopen_rnn_desc, miopen_input_desc, input_data, - miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, - params, miopen_output_desc, output_data, miopen_output_h_desc, - output_h_data, miopen_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnForwardImpl( + stream, miopen_rnn_desc, miopen_input_desc, input_data, + miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, + params, miopen_output_desc, output_data, miopen_output_h_desc, + output_h_data, miopen_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool MIOpenSupport::DoRnnForward( @@ -2978,14 +2987,17 @@ bool MIOpenSupport::DoRnnBackward( const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, miopen_rnn_desc, miopen_input_desc, input_data, - miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, - params, miopen_output_desc, output_data, miopen_output_h_desc, - output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, input_backprop_data, - input_h_backprop_data, input_c_backprop_data, params_backprop_data, - reserve_space_data, workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnBackwardImpl( + stream, miopen_rnn_desc, miopen_input_desc, input_data, + miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, + params, miopen_output_desc, output_data, miopen_output_h_desc, + output_h_data, miopen_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/true); } bool MIOpenSupport::DoRnnBackward( @@ -3028,14 +3040,17 @@ bool MIOpenSupport::DoRnnBackward( const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, miopen_rnn_desc, miopen_input_desc, input_data, - miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, - params, miopen_output_desc, output_data, miopen_output_h_desc, - output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, input_backprop_data, - input_h_backprop_data, input_c_backprop_data, params_backprop_data, - reserve_space_data, workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnBackwardImpl( + stream, miopen_rnn_desc, miopen_input_desc, input_data, + miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data, + params, miopen_output_desc, output_data, miopen_output_h_desc, + output_h_data, miopen_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/true); } bool MIOpenSupport::DoRnnBackward( @@ -3095,7 +3110,7 @@ void MIOpenDeallocatorCallback(void* ctx, void* mem) { // reclaim the memory } -tsl::Status MIOpenSupport::DoPrepareForConvolution( +absl::Status MIOpenSupport::DoPrepareForConvolution( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -3120,7 +3135,7 @@ tsl::Status MIOpenSupport::DoPrepareForConvolution( // allocate scratch memory if (scratch_memory_size != 0) { if (scratch_allocator == nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "An allocator must be specified when scratch memory is needed"); } auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size); @@ -3134,12 +3149,12 @@ tsl::Status MIOpenSupport::DoPrepareForConvolution( "larger number (e.g. 8192) to increase the max memory limit.\n" << "\tIncreasing the max memory limit might help resolve this " "error"; - return tsl::errors::Internal( - "Failed to allocate scratch memory of size: ", scratch_memory_size); + return absl::InternalError(absl::StrCat( + "Failed to allocate scratch memory of size: ", scratch_memory_size)); } } - return tsl::OkStatus(); + return absl::OkStatus(); } class RocmConvRunner : public dnn::ConvRunner { @@ -3177,15 +3192,15 @@ class RocmConvRunner : public dnn::ConvRunner { size_t GetWorkspaceSize() const override { return workspace_size_; } - tsl::StatusOr ToAlgorithmDesc() const override { + absl::StatusOr ToAlgorithmDesc() const override { return {{algo_id_, false, workspace_size_}}; } - tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase input_data, - DeviceMemoryBase filter_data, - DeviceMemoryBase output_data) const override { + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase input_data, + DeviceMemoryBase filter_data, + DeviceMemoryBase output_data) const override { auto miopen = miopen_->GetHandle(parent_, stream); // Alpha is the scaling factor for input. float alpha = 1.0; @@ -3193,9 +3208,12 @@ class RocmConvRunner : public dnn::ConvRunner { float beta = 0.0; const bool is_profiling = profile_result != nullptr; - TF_ASSIGN_OR_RETURN( - std::optional timer, - GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling)); + TF_ASSIGN_OR_RETURN(std::optional timer, + GpuTimer::CreateIfNeeded( + stream, + output_profile_result && + output_profile_result->warmup_run_executed(), + is_profiling)); miopenStatus_t status = miopenStatusSuccess; switch (kind_) { @@ -3258,8 +3276,8 @@ class RocmConvRunner : public dnn::ConvRunner { break; } default: - return tsl::errors::Internal("Unexpected convolution kind ", - static_cast(kind_)); + return absl::InternalError(absl::StrCat("Unexpected convolution kind ", + static_cast(kind_))); } if (is_profiling) { @@ -3275,11 +3293,12 @@ class RocmConvRunner : public dnn::ConvRunner { } if (status != miopenStatusSuccess) { - return tsl::errors::Internal("Failed to enqueue convolution on stream: ", - ::stream_executor::gpu::ToString(status)); + return absl::InternalError( + absl::StrCat("Failed to enqueue convolution on stream: ", + ::stream_executor::gpu::ToString(status))); } - return tsl::OkStatus(); + return absl::OkStatus(); } private: @@ -3296,7 +3315,7 @@ class RocmConvRunner : public dnn::ConvRunner { ScopedConvolutionDescriptor conv_desc_; }; -tsl::Status MIOpenSupport::DoConvolve( +absl::Status MIOpenSupport::DoConvolve( dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -3316,7 +3335,7 @@ tsl::Status MIOpenSupport::DoConvolve( filter_data, output_data); } -tsl::Status MIOpenSupport::GetConvolveRunners( +absl::Status MIOpenSupport::GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -3327,7 +3346,7 @@ tsl::Status MIOpenSupport::GetConvolveRunners( ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, std::vector>* out_runners) { if (input_type != output_type) { - return tsl::errors::Unimplemented( + return absl::UnimplementedError( absl::StrFormat("MIOpen backend does not support different input and " "output types: %d != %d", input_type, output_type)); @@ -3338,8 +3357,7 @@ tsl::Status MIOpenSupport::GetConvolveRunners( kind, input_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, scratch_allocator, &profile_results)) { - return tsl::Status( - absl::StatusCode::kUnknown, + return absl::UnknownError( "GetConvolveRunners: GetMIOpenConvolveAlgorithms failed"); } @@ -3352,10 +3370,10 @@ tsl::Status MIOpenSupport::GetConvolveRunners( out_runners->push_back(std::move(runner)); } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr> +absl::StatusOr> MIOpenSupport::ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -3364,7 +3382,7 @@ MIOpenSupport::ConvolveRunnerFromDesc( const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) { if (input_type != output_type) { - return tsl::errors::Unimplemented( + return absl::UnimplementedError( absl::StrFormat("MIOpen backend does not support different input and " "output types: %d != %d", input_type, output_type)); @@ -3372,7 +3390,7 @@ MIOpenSupport::ConvolveRunnerFromDesc( auto workspace_size = algorithm_desc.workspace_size(); if (!workspace_size) { - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError( "MIOpenSupport::ConvolveRunnerFromDesc requires " "AlgorithmProto.workspace_size, but it was missing."); } @@ -4017,9 +4035,9 @@ void launchInplaceBiasActivation(hipStream_t stream, void* c_data, class ROCmFusedMatmulRunner : public dnn::FusedMatmulRunner { template - tsl::Status gemm(Stream*, DeviceMemoryBase /* a_data */, - DeviceMemoryBase /* b_data */, - DeviceMemoryBase /* c_data */) const; + absl::Status gemm(Stream*, DeviceMemoryBase /* a_data */, + DeviceMemoryBase /* b_data */, + DeviceMemoryBase /* c_data */) const; Stream* _stream; dnn::DataType _input_type, _bias_type, _output_type; @@ -4032,14 +4050,14 @@ class ROCmFusedMatmulRunner : public dnn::FusedMatmulRunner { std::string ToString() const override; size_t GetWorkspaceSize() const override { return 0; } // Convert to an AlgorithmDesc for AoT compilation or autotuning - tsl::StatusOr ToAlgorithmDesc() const override; + absl::StatusOr ToAlgorithmDesc() const override; // Launch the operation, with the signature determined by `Sig`. - tsl::Status operator()(Stream*, dnn::ProfileResult*, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase /* a_data */, - DeviceMemoryBase /* b_data */, - DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* c_data */) const override; + absl::Status operator()(Stream*, dnn::ProfileResult*, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase /* a_data */, + DeviceMemoryBase /* b_data */, + DeviceMemoryBase /* bias_data */, + DeviceMemoryBase /* c_data */) const override; ROCmFusedMatmulRunner(Stream* stream, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, @@ -4067,7 +4085,7 @@ ROCmFusedMatmulRunner::ROCmFusedMatmulRunner( _ldc(ldc), _activation_mode(activation_mode) {} -tsl::StatusOr ROCmFusedMatmulRunner::ToAlgorithmDesc() const { +absl::StatusOr ROCmFusedMatmulRunner::ToAlgorithmDesc() const { std::vector> knobs; knobs.emplace_back(0, static_cast(_input_type)); knobs.emplace_back(1, static_cast(_bias_type)); @@ -4088,27 +4106,32 @@ std::string ROCmFusedMatmulRunner::ToString() const { } template -tsl::Status ROCmFusedMatmulRunner::gemm(Stream* stream, DeviceMemoryBase a_data, - DeviceMemoryBase b_data, - DeviceMemoryBase c_data) const { +absl::Status ROCmFusedMatmulRunner::gemm(Stream* stream, + DeviceMemoryBase a_data, + DeviceMemoryBase b_data, + DeviceMemoryBase c_data) const { blas::Transpose ta = _trans_a ? blas::Transpose::kTranspose : blas::Transpose::kNoTranspose; blas::Transpose tb = _trans_b ? blas::Transpose::kTranspose : blas::Transpose::kNoTranspose; - return stream->ThenBlasGemm( - tb, ta, _n, _m, _k, static_cast>(b_data), _ldb, - static_cast>(a_data), _lda, - static_cast*>(&c_data), _ldc, NumericOptions{}, - blas::CallContext::kNone); + auto* blas = stream->parent()->AsBlas(); + if (blas == nullptr) { + return absl::InternalError("No Blas support for stream"); + } + return blas->BlasGemm(stream, tb, ta, _n, _m, _k, + static_cast>(b_data), _ldb, + static_cast>(a_data), _lda, + static_cast*>(&c_data), _ldc, + NumericOptions{}, blas::CallContext::kNone); } template -tsl::Status InplaceBiasActivation(Stream* stream, DeviceMemoryBase c_data, - DeviceMemoryBase bias_data, - dnn::ActivationMode activation_mode, - uint64_t m, uint64_t n, int64_t ldc, - float param) { +absl::Status InplaceBiasActivation(Stream* stream, DeviceMemoryBase c_data, + DeviceMemoryBase bias_data, + dnn::ActivationMode activation_mode, + uint64_t m, uint64_t n, int64_t ldc, + float param) { typedef typename std::conditional< std::is_same_v, __half, typename std::conditional, @@ -4118,22 +4141,22 @@ tsl::Status InplaceBiasActivation(Stream* stream, DeviceMemoryBase c_data, activation_mode == dnn::ActivationMode::kBandPass || activation_mode == dnn::ActivationMode::kLeakyRelu) - return tsl::Status(absl::StatusCode::kInvalidArgument, - "ROCm InplaceBiasActivation can't be used with " - "parametric activations yet"); + return absl::InvalidArgumentError( + "ROCm InplaceBiasActivation can't be used with " + "parametric activations yet"); launchInplaceBiasActivation( AsGpuStreamValue(stream), c_data.opaque(), bias_data.opaque(), static_cast(activation_mode), m, n, ldc, param); - return tsl::OkStatus(); + return absl::OkStatus(); } // Launch the operation, with the signature determined by `Sig`. -tsl::Status ROCmFusedMatmulRunner::operator()( +absl::Status ROCmFusedMatmulRunner::operator()( Stream* stream, dnn::ProfileResult* prof, DeviceMemoryBase scratch_memory, DeviceMemoryBase a_data, DeviceMemoryBase b_data, DeviceMemoryBase bias_data, DeviceMemoryBase c_data) const { - tsl::Status status; + absl::Status status; if (_input_type == dnn::DataType::kFloat) status = gemm(stream, a_data, b_data, c_data); else if (_input_type == dnn::DataType::kHalf) @@ -4143,8 +4166,7 @@ tsl::Status ROCmFusedMatmulRunner::operator()( else if (_input_type == dnn::DataType::kDouble) status = gemm(stream, a_data, b_data, c_data); else - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Unsupported input type"); + return absl::InvalidArgumentError("Unsupported input type"); if (!status.ok()) return status; if (_input_type == dnn::DataType::kFloat) @@ -4160,11 +4182,10 @@ tsl::Status ROCmFusedMatmulRunner::operator()( return InplaceBiasActivation(stream, c_data, bias_data, _activation_mode, _m, _n, _ldc, 0.0f); else - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Unsupported input type"); + return absl::InvalidArgumentError("Unsupported input type"); } -tsl::Status MIOpenSupport::GetFusedMatmulRunners( +absl::Status MIOpenSupport::GetFusedMatmulRunners( bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, @@ -4174,22 +4195,20 @@ tsl::Status MIOpenSupport::GetFusedMatmulRunners( out_exec_plans) { out_exec_plans->clear(); if (input_type != output_type) - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( "ROCm fused matmul does not support input/output type mismatch"); if (input_type != bias_type) - return tsl::Status( - absl::StatusCode::kInvalidArgument, + return absl::InvalidArgumentError( "ROCm fused matmul does not support input/bias type mismatch"); auto runner_ptr = new ROCmFusedMatmulRunner( stream, input_type, bias_type, output_type, trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode); out_exec_plans->push_back( std::unique_ptr(runner_ptr)); - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status MIOpenSupport::DoFusedConvolve( +absl::Status MIOpenSupport::DoFusedConvolve( Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type, dnn::DataType bias_type, dnn::DataType output_type, const dnn::BatchDescriptor& conv_input_descriptor, @@ -4204,7 +4223,7 @@ tsl::Status MIOpenSupport::DoFusedConvolve( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return tsl::errors::Unimplemented("fused convolve not implemented yet"); + return absl::UnimplementedError("fused convolve not implemented yet"); } bool MIOpenSupport::DoTransformTensor(Stream* stream, @@ -4219,209 +4238,15 @@ bool MIOpenSupport::DoTransformTensor(Stream* stream, return false; } -bool MIOpenSupport::DoMatMul(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& weights, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { - if (input_dimensions.count() != output_dimensions.count()) { - LOG(ERROR) << "MatMul input and output dimensions are not compatible."; - return false; - } - - // We do not permute the input or output, instead we just - // reinterpret the layout. We are working with row-major matrices - // and the rows of the input and output correspond to batch, so - // batch has to be outermost in both the input and output. - // - // By adding transposes to the BLAS gemm call we could perhaps make - // the kYXDepthBatch layout work as well, but there has been no need - // for that so far. - if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "Unsupported MatMul input layout."; - return false; - } - if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "Unsupported MatMul output layout."; - return false; - } - - if (output_dimensions.width() == 1 && output_dimensions.height() == 1) { - // This is a fast path that also supports the kBatchYXDepth layout. - - // The matrices here are in row-major format while BLAS expects - // column-major, i.e. our matrices are transposed as far as BLAS - // is concerned. So we need to compute output^T = - // input^T*weights^T. There is no parameter for transposing the - // output in BLAS gemm, but instead we can transpose both sides of - // the equality to see that this is equivalent to - // output=weights*input. So we only need to swap the order of - // weights and input in the matrix product to correct for the - // row-major versus column-major difference. - const int64_t m = output_dimensions.NodesAcrossFeatureMaps(); - const int64_t n = input_dimensions.count(); - const int64_t k = input_dimensions.NodesAcrossFeatureMaps(); - if (!stream - ->ThenBlasGemm(blas::Transpose::kNoTranspose, - blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, NumericOptions{}, - blas::CallContext::kNone) - - .ok()) { - return false; - } - } else { - // This is a slower and more complex path that supports output - // width() * height() > 1, though it only supports the - // kBatchYXDepth layout. Does support kBatchDepthYX if output - // feature_map_count() == 1, as then there is no difference - // between the two layouts. - // - // The operation here is the same as above, except that we have to - // do the matrix multiplication for each (y,x) output coordinate - // separately. We then interpret weights as containing K = width() - // * height() different matrices, which we all multiply onto the - // matrix from input_data, yielding K matrix products. We then - // combine these together into one matrix by concatenating all the - // first rows of these matrices, then all the seconds rows and so - // on. We can do this with a batched matrix multiplication, where - // the result is written to a different submatrix of the output - // for each matrix multiplication. - // - // The reason that we only support the kBatchYXDepth output layout - // is that we have to do something in the depth for each (y,x) - // coordinate. The kBatchYXDepth layout has the depth information - // for each point (y,x) in contiguous memory while the - // kBatchDepthYX layout does not. - // - // TODO(broune): Consider a special case for when output depth == - // 1, as then possibly this could all be done as one matrix - // multiplication instead of a batched one, which should be - // faster. Another possibility would be to add a weights layout - // parameter and then support kBatchDepthYX for a different - // weights layout. - if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth && - !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX && - output_dimensions.feature_map_count() == 1)) { - LOG(ERROR) << "Unsupported MatMul output layout."; - return false; - } - - const float alpha = 1.0f; // Take the matrix product without scaling it. - const float beta = 0.0f; // Ignore the original values in output_data. - const uint64_t m = output_dimensions.feature_map_count(); - const uint64_t n = input_dimensions.count(); - const uint64_t k = input_dimensions.NodesAcrossFeatureMaps(); - const int lda = m; - const int ldb = k; - const int ldc = output_dimensions.NodesAcrossFeatureMaps(); - const int batch_count = output_dimensions.NodesPerFeatureMap(); - - std::vector> a(batch_count); - std::vector> b(batch_count); - std::vector> c(batch_count); - for (int i = 0; i < batch_count; ++i) { - const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() * - output_dimensions.feature_map_count(); - a[i] = DeviceMemory::MakeFromByteSize( - const_cast(reinterpret_cast(weights.opaque())) + - weights_offset, - weights.ElementCount() - weights_offset); - - b[i] = input_data; - - const int output_offset = i * output_dimensions.feature_map_count(); - c[i] = DeviceMemory::MakeFromByteSize( - const_cast( - reinterpret_cast(output_data->opaque())) + - output_offset, - output_data->ElementCount() - output_offset); - } - const auto toPtrs = [](std::vector>& v) { - std::vector*> ptrs; - ptrs.reserve(v.size()); - for (auto& mem : v) { - ptrs.push_back(&mem); - } - return ptrs; - }; - - stream->ThenBlasGemmBatched( - blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, - alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), ldc, - batch_count, NumericOptions{}, blas::CallContext::kNone); - } - - return stream->ok(); -} - -bool MIOpenSupport::DoBiasAdd(Stream* stream, - const DeviceMemory& input_data, - const DeviceMemory& biases, - const dnn::BatchDescriptor& dimensions, - DeviceMemory* output_data) { - ScopedTensorDescriptor input_descriptor{dimensions, miopenFloat}; - - BatchDescriptor bias_dimensions; - bias_dimensions.set_count(1) - .set_feature_map_count(dimensions.feature_map_count()) - .set_height(1) - .set_width(1) - .set_layout(dnn::DataLayout::kBatchYXDepth); - ScopedTensorDescriptor bias_descriptor{bias_dimensions, miopenFloat}; - - if (input_data.opaque() != output_data->opaque()) { - stream->ThenMemcpy(output_data, input_data, - dimensions.ElementCount() * sizeof(float)); - if (!stream->ok()) { - LOG(ERROR) - << "stream " << stream - << " could not enqueue a tensor copy as part of bias addition."; - return false; - } - } - - auto miopen = miopen_->GetHandle(parent_, stream); - - const float alpha1 = 1.0f; - const float alpha2 = 0.0f; - const float beta = 1.0f; - - auto status = wrap::miopenOpTensor( - miopen.handle(), miopenTensorOpAdd, &alpha1, bias_descriptor.handle(), - biases.opaque(), &alpha2, bias_descriptor.handle(), biases.opaque(), - &beta, input_descriptor.handle(), output_data->opaque()); - - if (status != miopenStatusSuccess) { - LOG(ERROR) << "stream " << stream << " could not enqueue bias addition."; - return false; - } - - return true; -} - -bool MIOpenSupport::DoActivate(Stream* stream, - dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - DeviceMemory* output_data, - uint64_t options) { - LOG(ERROR) << "miopen does not support activation yet"; - return false; -} - -tsl::Status MIOpenSupport::DoPoolForward( +absl::Status MIOpenSupport::DoPoolForward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) { if (element_type == dnn::DataType::kDouble) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "MIOpen does not support pooling for double type yet"); + return absl::InvalidArgumentError( + "MIOpen does not support pooling for double type yet"); } auto miopen = miopen_->GetHandle(parent_, stream); @@ -4440,15 +4265,15 @@ tsl::Status MIOpenSupport::DoPoolForward( bool do_backward = false; uint8* workspace = nullptr; size_t workspace_size = 0; - std::unique_ptr> wsp_mem; + ScopedDeviceMemory wsp_mem; if (m_pooling_cache_enabled && element_type == dnn::DataType::kFloat) { do_backward = true; auto status = wrap::miopenPoolingGetWorkSpaceSizeV2( pooling_desc.handle(), dest_desc.handle(), &workspace_size); if (status != miopenStatusSuccess) { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Failed to obtain workspace size for backward pooling on stream: ", - ToString(status)); + ToString(status))); } if (workspace_size != 0) { PoolingWorkspaceDescriptor* pdesc = 0; @@ -4459,12 +4284,10 @@ tsl::Status MIOpenSupport::DoPoolForward( miopenFloat, pdesc); if (cache_hit) { // reusing the same buffer - workspace = reinterpret_cast( - pdesc->workspace->mutable_device_memory()->opaque()); + workspace = reinterpret_cast(pdesc->workspace.ptr()->opaque()); } else { - wsp_mem = stream->AllocateTemporaryArray(workspace_size).value(); - workspace = reinterpret_cast( - wsp_mem->mutable_device_memory()->opaque()); + wsp_mem = stream->parent()->AllocateOwnedArray(workspace_size); + workspace = reinterpret_cast(wsp_mem.ptr()->opaque()); m_pooling_cache.insert(input_data.opaque(), input_dimensions, output_dimensions, pooling_dimensions, miopenFloat, wsp_mem, workspace_size, @@ -4478,10 +4301,10 @@ tsl::Status MIOpenSupport::DoPoolForward( input_data.opaque(), &beta, dest_desc.handle(), output_data.opaque(), do_backward, workspace, workspace_size); if (status != miopenStatusSuccess) { - return tsl::errors::Internal( - "Failed to enqueue forward pooling on stream: ", ToString(status)); + return absl::InternalError(absl::StrCat( + "Failed to enqueue forward pooling on stream: ", ToString(status))); } - return tsl::OkStatus(); + return absl::OkStatus(); } bool PoolingWorkspaceDescriptor::IsSame( @@ -4521,7 +4344,7 @@ void PoolingWorkspaceCache::insert( const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - std::unique_ptr>& workspace, size_t wsp_size, + ScopedDeviceMemory& workspace, size_t wsp_size, hipStream_t hip_stream) { PoolingWorkspaceDescriptor* desc = 0; auto it = cache.find(p); @@ -4571,7 +4394,7 @@ void PoolingWorkspaceCache::trim(hipStream_t hip_stream) { } } -tsl::Status MIOpenSupport::DoPoolBackward( +absl::Status MIOpenSupport::DoPoolBackward( dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, @@ -4579,8 +4402,8 @@ tsl::Status MIOpenSupport::DoPoolBackward( DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) { if (element_type == dnn::DataType::kDouble) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "MIOpen does not support pooling for double type yet"); + return absl::InvalidArgumentError( + "MIOpen does not support pooling for double type yet"); } auto miopen = miopen_->GetHandle(parent_, stream); @@ -4605,9 +4428,9 @@ tsl::Status MIOpenSupport::DoPoolBackward( auto status = wrap::miopenPoolingGetWorkSpaceSizeV2( pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes); if (status != miopenStatusSuccess) { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Failed to obtain workspace size for backward pooling on stream: ", - ToString(status)); + ToString(status))); } // Allocate the workspace. @@ -4618,8 +4441,8 @@ tsl::Status MIOpenSupport::DoPoolBackward( miopen_dtype, pdesc); if (cache_hit) { assert(pdesc != 0); - workspace_ptr = reinterpret_cast( - pdesc->workspace->mutable_device_memory()->opaque()); + workspace_ptr = + reinterpret_cast(pdesc->workspace.ptr()->opaque()); VLOG(1) << "Pooling cache hit"; } else { VLOG(1) << "Pooling cache miss"; @@ -4627,7 +4450,7 @@ tsl::Status MIOpenSupport::DoPoolBackward( auto allocated = workspace_allocator->AllocateBytes(workspace_size_in_bytes); if (!allocated.ok() || (workspace = allocated.value()) == nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "Failed to allocate backward pooling workspace"); } DeviceMemory dest2; // duplicated dest from forward: @@ -4648,7 +4471,7 @@ tsl::Status MIOpenSupport::DoPoolBackward( assert(workspace_allocator); auto allocated = workspace_allocator->AllocateBytes(dest2_size); if (!allocated.ok() || (dest2 = allocated.value()) == nullptr) { - return tsl::errors::Internal( + return absl::InternalError( "Failed to allocate backward pooling workspace"); } } else { @@ -4662,9 +4485,9 @@ tsl::Status MIOpenSupport::DoPoolBackward( workspace.opaque(), workspace_size_in_bytes); if (status != miopenStatusSuccess) { - return tsl::errors::Internal( + return absl::InternalError(absl::StrCat( "Failed to enqueue forward pooling (before backward) on stream: ", - ToString(status)); + ToString(status))); } workspace_ptr = reinterpret_cast(workspace.opaque()); } @@ -4677,10 +4500,10 @@ tsl::Status MIOpenSupport::DoPoolBackward( output_diff_data.opaque(), workspace_ptr); if (status != miopenStatusSuccess) { - return tsl::errors::Internal( - "Failed to enqueue backward pooling on stream: ", ToString(status)); + return absl::InternalError(absl::StrCat( + "Failed to enqueue backward pooling on stream: ", ToString(status))); } - return tsl::OkStatus(); + return absl::OkStatus(); } bool MIOpenSupport::DoNormalizeWithDimensions( @@ -4818,109 +4641,6 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( return true; } -bool MIOpenSupport::DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) { - CHECK_EQ(input_dimensions.size(), input_data.size()); - - for (const auto& dimensions : input_dimensions) { - if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { - LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only " - "supports the kBatchDepthYX layout."; - return false; - } - } - - if (input_dimensions.empty()) { - return true; // Nothing to do. - } - - dnn::BatchDescriptor output_dimensions = - dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions); - - const int64_t area = output_dimensions.width() * output_dimensions.height(); - const auto index = [area](int64_t batch, int64_t depth, int64_t yx, - int64_t max_depth) { - return (batch * max_depth + depth) * area + yx; - }; - - std::vector output_host(output_dimensions.ElementCount()); - std::vector tmp; - int64_t depth_sum = 0; - for (size_t i = 0; i < input_data.size(); ++i) { - const auto& dimensions = input_dimensions[i]; - tmp.resize(dimensions.ElementCount()); - stream->ThenMemcpyD2H(*input_data[i], absl::MakeSpan(tmp)); - tsl::Status block_status = stream->BlockHostUntilDone(); - if (!block_status.ok()) { - LOG(ERROR) << "BlockHostUntilDone failed: " << block_status; - return false; - } - - for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) { - for (int64_t yx = 0; yx < area; ++yx) { - for (int64_t depth = 0; depth < dimensions.feature_map_count(); - ++depth) { - LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' ' - << yx << ' ' << depth; - output_host[index(batch, depth + depth_sum, yx, - output_dimensions.feature_map_count())] = - tmp[index(batch, depth, yx, dimensions.feature_map_count())]; - } - } - } - depth_sum += dimensions.feature_map_count(); - } - stream->ThenMemcpyH2D(output_host, output_data); - return true; -} - -bool MIOpenSupport::DoElementwiseOperate( - Stream* stream, dnn::ElementwiseOperation operation, - absl::Span input_dimensions, - absl::Span* const> input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool MIOpenSupport::DoXYPad(Stream* stream, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t left_pad, int64_t right_pad, - int64_t top_pad, int64_t bottom_pad, - DeviceMemory* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool MIOpenSupport::DoXYSlice(Stream* stream, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - int64_t left_trim, int64_t right_trim, - int64_t top_trim, int64_t bottom_trim, - DeviceMemory* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) - return false; -} - -bool MIOpenSupport::DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory& gpu_unquantized_src, - dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) { - LOG(ERROR) << "quantized memcpy not supported by MIOpen"; - return false; -} - -bool MIOpenSupport::DoMemcpyH2DQuantized( - Stream* stream, const void* host_src, int64_t size, - dnn::QuantizedActivationMode mode, - DeviceMemory* gpu_unquantized_dst) { - LOG(ERROR) << "quantized memcpy not supported by MIOpen"; - return false; -} - bool MIOpenSupport::DeriveOutputBatchDescriptor( const BatchDescriptor& batch_descriptor, const FilterDescriptor& filter_descriptor, @@ -4973,7 +4693,7 @@ void initialize_miopen() { rocm::kROCmPlatformId, PluginKind::kDnn); if (!miopenAlreadyRegistered) { - tsl::Status status = + absl::Status status = PluginRegistry::Instance()->RegisterFactory( rocm::kROCmPlatformId, "MIOpen", [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* { @@ -5003,5 +4723,6 @@ void initialize_miopen() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_miopen, - { stream_executor::initialize_miopen(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_miopen, { + stream_executor::initialize_miopen(); +}); diff --git a/xla/stream_executor/rocm/rocm_dnn.h b/xla/stream_executor/rocm/rocm_dnn.h index 8cd6f43ee292d..ecaffd3cad392 100644 --- a/xla/stream_executor/rocm/rocm_dnn.h +++ b/xla/stream_executor/rocm/rocm_dnn.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "rocm/include/miopen/miopen.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/temporary_device_memory.h" namespace stream_executor { namespace gpu { @@ -41,7 +41,7 @@ struct PoolingWorkspaceDescriptor { dnn::PoolingDescriptor op; int dtype; uint64_t timestamp; - std::unique_ptr> workspace; + ScopedDeviceMemory workspace; size_t workspace_size; bool IsSame(const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, @@ -61,8 +61,8 @@ struct PoolingWorkspaceCache { void insert(const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - std::unique_ptr>& workspace, - size_t wsp_size, hipStream_t hip_stream); + ScopedDeviceMemory& workspace, size_t wsp_size, + hipStream_t hip_stream); private: void trim(hipStream_t hip_stream); @@ -74,10 +74,10 @@ class MIOpenSupport : public dnn::DnnSupport { public: explicit MIOpenSupport(GpuExecutor* parent); - tsl::Status Init() override; - tsl::StatusOr GetVersion() override; + absl::Status Init() override; + absl::StatusOr GetVersion() override; - tsl::StatusOr> createRnnDescriptor( + absl::StatusOr> CreateRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, @@ -85,13 +85,13 @@ class MIOpenSupport : public dnn::DnnSupport { const NumericOptions& numeric_options, float dropout, uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) override; - tsl::StatusOr> - createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + absl::StatusOr> + CreateRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) override; - tsl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + absl::StatusOr> + CreateRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) override; bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, @@ -229,7 +229,7 @@ class MIOpenSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) override; - tsl::Status GetConvolveRunners( + absl::Status GetConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -243,7 +243,7 @@ class MIOpenSupport : public dnn::DnnSupport { std::vector>* out_runners) override; - tsl::StatusOr> ConvolveRunnerFromDesc( + absl::StatusOr> ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, @@ -351,7 +351,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) override; - tsl::Status DoConvolve( + absl::Status DoConvolve( dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, @@ -363,7 +363,7 @@ class MIOpenSupport : public dnn::DnnSupport { dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, dnn::ProfileResult* output_profile_result) override; - tsl::Status DoFusedConvolve( + absl::Status DoFusedConvolve( Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type, dnn::DataType bias_type, dnn::DataType output_type, const dnn::BatchDescriptor& conv_input_descriptor, @@ -379,46 +379,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) override; - bool DoSeparableConvolve( - Stream* stream, const dnn::BatchDescriptor& batch_descriptor, - const DeviceMemory& input_data, - const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier, - const DeviceMemory& first_weights, - const DeviceMemory& second_weights, - const dnn::ConvolutionDescriptor& convolution_descriptor, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data) override { - LOG(ERROR) << "separable convolution not supported by MIOpen"; - return false; - } - - bool DoMatMul(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& weights, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; - - bool DoMatMulQuantized(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override { - LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen"; - return false; - } - - bool DoMatMulQuantized(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& quantized_weights, - const DeviceMemory& weight_scales, - const dnn::BatchDescriptor& input_dimensions, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override { - LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen"; - return false; - } - - tsl::Status GetFusedMatmulRunners( + absl::Status GetFusedMatmulRunners( bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, @@ -428,34 +389,24 @@ class MIOpenSupport : public dnn::DnnSupport { std::vector>* out_exec_plans) override; - bool DoBiasAdd(Stream* stream, const DeviceMemory& input_data, - const DeviceMemory& biases, - const dnn::BatchDescriptor& dimensions, - DeviceMemory* output_data) override; - - bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, - DeviceMemory* output_data, uint64_t options) override; - - tsl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, - const dnn::PoolingDescriptor& pooling_dimensions, - const dnn::BatchDescriptor& input_dimensions, - DeviceMemoryBase input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemoryBase output_data, - ScratchAllocator* workspace_allocator) override; - - tsl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + absl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, - DeviceMemoryBase input_diff_data, - DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) override; + absl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + DeviceMemoryBase input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemoryBase output_data, + DeviceMemoryBase input_diff_data, + DeviceMemoryBase output_diff_data, + ScratchAllocator* workspace_allocator) override; + bool DoNormalizeWithDimensions( Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const dnn::BatchDescriptor& dimensions, @@ -471,38 +422,6 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* raw_variable_gradient, ScratchAllocator* workspace_allocator = nullptr) override; - bool DoDepthConcatenate( - Stream* stream, absl::Span input_dimensions, - absl::Span* const> input_data, - DeviceMemory* output_data) override; - - bool DoElementwiseOperate( - Stream* stream, dnn::ElementwiseOperation operation, - absl::Span input_dimensions, - absl::Span* const> input_data, - const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; - - bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, int64_t left_pad, - int64_t right_pad, int64_t top_pad, int64_t bottom_pad, - DeviceMemory* output_data) override; - - bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions, - const DeviceMemory& input_data, int64_t left_trim, - int64_t right_trim, int64_t top_trim, int64_t bottom_trim, - DeviceMemory* output_data) override; - - bool DoMemcpyD2HQuantized(Stream* stream, - const DeviceMemory& device_unquantized_src, - dnn::QuantizedActivationMode mode, void* host_dst, - int64_t size) override; - - bool DoMemcpyH2DQuantized( - Stream* stream, const void* host_src, int64_t size, - dnn::QuantizedActivationMode mode, - DeviceMemory* device_unquantized_dst) override; - // Derives an output batch descriptor from an input batch and convolution // descriptors. bool DeriveOutputBatchDescriptor( @@ -520,17 +439,17 @@ class MIOpenSupport : public dnn::DnnSupport { GpuExecutor* GetParentExecutor() { return parent_; } - tsl::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, - const dnn::RnnStateTensorDescriptor& probs_desc, - const DeviceMemoryBase probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - DeviceMemoryBase costs_data, - const dnn::RnnStateTensorDescriptor& grads_desc, - DeviceMemoryBase grads_data, - DeviceMemory scratch_memory, - int ctc_loss_algo_id) override; + absl::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, + absl::Span labels_data, + absl::Span labels_lengths_data, + absl::Span input_lengths_data, + DeviceMemoryBase costs_data, + const dnn::RnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, + DeviceMemory scratch_memory, + int ctc_loss_algo_id) override; private: GpuExecutor* parent_; // Parent executor object. Not owned. @@ -575,50 +494,50 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* offset_backprop); template - bool DoRnnForwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc, - const MIOpenRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const MIOpenRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const MIOpenRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const MIOpenRnnSequenceTensorDescriptor& output_desc, - DeviceMemory* output_data, - const MIOpenRnnStateTensorDescriptor& output_h_desc, - DeviceMemory* output_h_data, - const MIOpenRnnStateTensorDescriptor& output_c_desc, - DeviceMemory* output_c_data, bool is_training, - ScratchAllocator* reserve_space_allocator, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); + absl::Status DoRnnForwardImpl( + Stream* stream, const MIOpenRnnDescriptor& rnn_desc, + const MIOpenRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const MIOpenRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const MIOpenRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const MIOpenRnnSequenceTensorDescriptor& output_desc, + DeviceMemory* output_data, + const MIOpenRnnStateTensorDescriptor& output_h_desc, + DeviceMemory* output_h_data, + const MIOpenRnnStateTensorDescriptor& output_c_desc, + DeviceMemory* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); template - bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc, - const MIOpenRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const MIOpenRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const MIOpenRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const MIOpenRnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const MIOpenRnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const MIOpenRnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); - - tsl::Status DoPrepareForConvolution( + absl::Status DoRnnBackwardImpl( + Stream* stream, const MIOpenRnnDescriptor& rnn_desc, + const MIOpenRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const MIOpenRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const MIOpenRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const MIOpenRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const MIOpenRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const MIOpenRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); + + absl::Status DoPrepareForConvolution( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -630,7 +549,7 @@ class MIOpenSupport : public dnn::DnnSupport { ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc, DeviceMemory* scratch_memory) override; - tsl::Status DoCtcLossImpl( + absl::Status DoCtcLossImpl( Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc, const DeviceMemoryBase probs_data, absl::Span labels_data, absl::Span labels_lengths_data, @@ -639,7 +558,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc, DeviceMemory scratch_memory, int ctc_loss_algo_id); - tsl::Status DoPrepareForCtcLoss( + absl::Status DoPrepareForCtcLoss( Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, const dnn::RnnStateTensorDescriptor& grads_desc, diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index bf20011441504..409546aafd7ab 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -75,6 +75,46 @@ namespace gpu { /* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit}; /* static */ int64_t CreatedContexts::next_id_ = 1; // 0 means "no context" + +// Formats hipError_t to output prettified values into a log stream. +// Error summaries taken from: +string ToString(hipError_t result) { +#define OSTREAM_ROCM_ERROR(__name) \ + case hipError##__name: \ + return "HIP_ERROR_" #__name; + + switch (result) { + OSTREAM_ROCM_ERROR(InvalidValue) + OSTREAM_ROCM_ERROR(OutOfMemory) + OSTREAM_ROCM_ERROR(NotInitialized) + OSTREAM_ROCM_ERROR(Deinitialized) + OSTREAM_ROCM_ERROR(NoDevice) + OSTREAM_ROCM_ERROR(InvalidDevice) + OSTREAM_ROCM_ERROR(InvalidImage) + OSTREAM_ROCM_ERROR(InvalidContext) + OSTREAM_ROCM_ERROR(InvalidHandle) + OSTREAM_ROCM_ERROR(NotFound) + OSTREAM_ROCM_ERROR(NotReady) + OSTREAM_ROCM_ERROR(NoBinaryForGpu) + + // Encountered an uncorrectable ECC error during execution. + OSTREAM_ROCM_ERROR(ECCNotCorrectable) + + // Load/store on an invalid address. Must reboot all context. + case 700: + return "ROCM_ERROR_ILLEGAL_ADDRESS"; + // Passed too many / wrong arguments, too many threads for register count. + case 701: + return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; + + OSTREAM_ROCM_ERROR(ContextAlreadyInUse) + OSTREAM_ROCM_ERROR(PeerAccessUnsupported) + OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. + default: + return absl::StrCat("hipError_t(", static_cast(result), ")"); + } +} + namespace { // Returns the current context and checks that it is in the set of HIP contexts @@ -142,7 +182,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) { if (tls->depth == 0) { VLOG(3) << "ScopedActivateContext switching to " << hip_context->device_ordinal(); - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()), "Failed setting context"); tls->depth = 1; tls->current_device_ordinal = hip_context->device_ordinal(); @@ -165,7 +205,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* hip_context) { to_restore_ = tls->context; // Set the device and update thread local. - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(hip_context->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(hip_context->context()), "Failed setting context"); tls->current_device_ordinal = hip_context->device_ordinal(); tls->context = hip_context; @@ -189,7 +229,7 @@ ScopedActivateContext::~ScopedActivateContext() { } // Set context and update thread local. - FAIL_IF_ROCM_ERROR(hipCtxSetCurrent(to_restore_->context()), + FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(to_restore_->context()), "Failed setting context"); tls->current_device_ordinal = to_restore_->device_ordinal(); tls->context = to_restore_; @@ -246,7 +286,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) { // Actually performs the work of ROCM initialization. Wrapped up in one-time // execution guard. -static tsl::Status InternalInit() { +static absl::Status InternalInit() { hipError_t res = hipErrorNoDevice; if (FLAGS_gpuexec_rocm_driver_inject_init_error) { LOG(ERROR) << "injecting ROCM init error; initialization will fail"; @@ -255,40 +295,40 @@ static tsl::Status InternalInit() { } if (res == hipSuccess) { - return tsl::OkStatus(); + return absl::OkStatus(); } LOG(ERROR) << "failed call to hipInit: " << ToString(res); Diagnostician::LogDiagnosticInformation(); - return tsl::Status{absl::StatusCode::kAborted, - absl::StrCat("failed call to hipInit: ", ToString(res))}; + return absl::Status{absl::StatusCode::kAborted, + absl::StrCat("failed call to hipInit: ", ToString(res))}; } } // namespace -/* static */ tsl::Status GpuDriver::Init() { +/* static */ absl::Status GpuDriver::Init() { // Cached return value from calling InternalInit(), as hipInit need only be // called once, but GpuDriver::Init may be called many times. - static tsl::Status* init_retval = [] { - return new tsl::Status(InternalInit()); + static absl::Status* init_retval = [] { + return new absl::Status(InternalInit()); }(); return *init_retval; } -/* static */ tsl::Status GpuDriver::GetDevice(int device_ordinal, - hipDevice_t* device) { +/* static */ absl::Status GpuDriver::GetDevice(int device_ordinal, + hipDevice_t* device) { hipError_t res = wrap::hipDeviceGet(device, device_ordinal); if (res == hipSuccess) { - return tsl::OkStatus(); + return absl::OkStatus(); } - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrCat("failed call to hipDeviceGet: ", ToString(res))}; } -/* static */ tsl::Status GpuDriver::GetDeviceName(hipDevice_t device, - string* device_name) { +/* static */ absl::Status GpuDriver::GetDeviceName(hipDevice_t device, + string* device_name) { static const size_t kCharLimit = 64; absl::InlinedVector chars(kCharLimit); RETURN_IF_ROCM_ERROR( @@ -296,40 +336,15 @@ static tsl::Status InternalInit() { "Failed to get device name"); chars[kCharLimit - 1] = '\0'; *device_name = chars.begin(); - return tsl::OkStatus(); -} - -bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, - int* flags) { - static_assert(DeviceOptions::kMask == 0xf, - "needs update for new device options"); - - if (device_options.flags() & DeviceOptions::kDoNotReclaimStackAllocation) { - *flags |= hipDeviceLmemResizeToMax; - } - - if (device_options.flags() & DeviceOptions::kScheduleSpin) { - *flags |= hipDeviceScheduleSpin; - } - if (device_options.flags() & DeviceOptions::kScheduleYield) { - *flags |= hipDeviceScheduleYield; - } - if (device_options.flags() & DeviceOptions::kScheduleBlockingSync) { - *flags |= hipDeviceScheduleBlockingSync; - } - - return true; + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::CreateContext( - int device_ordinal, hipDevice_t device, const DeviceOptions& device_options, - GpuContext** context) { +/* static */ absl::Status GpuDriver::CreateContext(int device_ordinal, + hipDevice_t device, + GpuContext** context) { *context = nullptr; int flags = 0; - if (!DeviceOptionsToContextFlags(device_options, &flags)) { - LOG(WARNING) << "could not convert all device options into context flags"; - } hipError_t res; hipCtx_t former_context; @@ -382,7 +397,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, << "success in this call must entail non-null result"; VLOG(2) << "created or reused context " << new_context << " for this thread"; - return ::tsl::OkStatus(); + return absl::OkStatus(); } std::string message = @@ -396,7 +411,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, } } - return tsl::Status(absl::StatusCode::kInternal, message); + return absl::InternalError(message); } /* static */ void GpuDriver::DestroyContext(GpuContext* context) { @@ -422,15 +437,15 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, return context->context(); } -/* static */ tsl::Status GpuDriver::FuncGetAttribute( +/* static */ absl::Status GpuDriver::FuncGetAttribute( hipFunction_attribute attribute, hipFunction_t func, int* attribute_value) { RETURN_IF_ROCM_ERROR( wrap::hipFuncGetAttribute(attribute_value, attribute, func), "Failed to query kernel attribute: ", attribute); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::FuncSetCacheConfig( +/* static */ absl::Status GpuDriver::FuncSetCacheConfig( hipFunction_t function, hipFuncCache_t cache_config) { // NOTE: this function is only available for in-process GPU kernels: // https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html#gafdb33ef569eb89808fc5178d04b508ba @@ -438,10 +453,10 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, RETURN_IF_ROCM_ERROR( wrap::hipFuncSetCacheConfig((const void*)function, cache_config), "Failed to set ROCM kernel cache config."); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipSharedMemConfig shared_mem_config; ScopedActivateContext activation{context}; @@ -450,27 +465,27 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { return shared_mem_config; } -/* static */ tsl::Status GpuDriver::ContextSetSharedMemConfig( +/* static */ absl::Status GpuDriver::ContextSetSharedMemConfig( GpuContext* context, hipSharedMemConfig shared_mem_config) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR(wrap::hipDeviceSetSharedMemConfig(shared_mem_config), "Failed to set ROCM device shared memory config"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::CreateGraph(hipGraph_t* graph) { +/* static */ absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) { VLOG(2) << "Create new HIP graph"; RETURN_IF_ROCM_ERROR(wrap::hipGraphCreate(graph, /*flags=*/0), "Failed to create HIP graph"); VLOG(2) << "Created HIP graph " << *graph; - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::DestroyGraph(hipGraph_t graph) { +/* static */ absl::Status GpuDriver::DestroyGraph(hipGraph_t graph) { VLOG(2) << "Destroy HIP graph " << graph; RETURN_IF_ROCM_ERROR(wrap::hipGraphDestroy(graph), "Failed to destroy HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } static std::string_view StreamCaptureModeToString( @@ -485,8 +500,8 @@ static std::string_view StreamCaptureModeToString( } } -/* static */ tsl::Status GpuDriver::StreamBeginCapture(GpuStreamHandle stream, - StreamCaptureMode mode) { +/* static */ absl::Status GpuDriver::StreamBeginCapture( + GpuStreamHandle stream, StreamCaptureMode mode) { hipStreamCaptureMode hip_mode; switch (mode) { case StreamCaptureMode::kGlobal: @@ -504,23 +519,29 @@ static std::string_view StreamCaptureModeToString( << StreamCaptureModeToString(mode) << " mode"; RETURN_IF_ROCM_ERROR(wrap::hipStreamBeginCapture(stream, hip_mode), "Failed to begin stream capture"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::StreamEndCapture(GpuStreamHandle stream, - hipGraph_t* graph) { +/* static */ absl::Status GpuDriver::StreamBeginCaptureToGraph( + GpuStreamHandle stream, GpuGraphHandle graph, StreamCaptureMode mode) { + return absl::UnimplementedError( + "StreamBeginCaptureToGraph is not implemented"); +} + +/* static */ absl::Status GpuDriver::StreamEndCapture(GpuStreamHandle stream, + hipGraph_t* graph) { VLOG(2) << "End stream " << stream << " capture"; RETURN_IF_ROCM_ERROR(wrap::hipStreamEndCapture(stream, graph), "Failed to end stream capture"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphInstantiate( +/* static */ absl::Status GpuDriver::GraphInstantiate( hipGraphExec_t* exec, hipGraph_t graph, const GraphInstantiateFlags& flags) { - VLOG(2) << "Instante HIP executable graph from graph " << graph << " (" + VLOG(2) << "Instantiate HIP executable graph from graph " << graph << " (" << "auto_free_on_launch=" << flags.auto_free_on_launch << ", " << "device_launch=" << flags.device_launch << ", " << "use_node_priority=" << flags.use_node_prirotiy << ", " @@ -528,19 +549,31 @@ static std::string_view StreamCaptureModeToString( RETURN_IF_ROCM_ERROR( wrap::hipGraphInstantiate(exec, graph, nullptr, nullptr, 0), "Failed to instantiate HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphLaunch(hipGraphExec_t exec, - GpuStreamHandle stream) { +/* static */ absl::Status GpuDriver::GraphLaunch(hipGraphExec_t exec, + GpuStreamHandle stream) { VLOG(2) << "Launching HIP executable graph " << exec << " on a stream " << stream; RETURN_IF_ROCM_ERROR(wrap::hipGraphLaunch(exec, stream), "Failed to launch HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::GraphNodeSetEnabled(hipGraphExec_t exec, + hipGraphNode_t node, + bool enabled) { + // Node is enabled if value != 0, otherwise the node is disabled. + unsigned value = enabled ? 1 : 0; + VLOG(2) << "Set HIP executable graph " << exec << " node " << node + << " enabled flag to " << value; + RETURN_IF_ROCM_ERROR(wrap::hipGraphNodeSetEnabled(exec, node, value), + "Failed to set HIP graph node enabled flag"); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphExecUpdate( +/* static */ absl::Status GpuDriver::GraphExecUpdate( hipGraphExec_t exec, hipGraph_t graph, GraphExecUpdateResultInfo* result) { VLOG(2) << "Update HIP graph executable " << exec << " with graph " << graph; @@ -582,17 +615,36 @@ static std::string_view StreamCaptureModeToString( } RETURN_IF_ROCM_ERROR(hip_error, "Failed to update HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::DestroyGraphExec(hipGraphExec_t exec) { +absl::StatusOr> +GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { + VLOG(2) << "Get HIP graph node " << node << " dependencies"; + + std::vector dependencies; + + size_t num_dependencies = 0; + RETURN_IF_ROCM_ERROR( + hipGraphNodeGetDependencies(node, nullptr, &num_dependencies), + "Failed to get HIP graph node depedencies size"); + + dependencies.resize(num_dependencies, nullptr); + RETURN_IF_ROCM_ERROR( + hipGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies), + "Failed to get HIP graph node depedencies"); + + return dependencies; +} + +/* static */ absl::Status GpuDriver::DestroyGraphExec(hipGraphExec_t exec) { VLOG(2) << "Destroying HIP executable graph" << exec; RETURN_IF_ROCM_ERROR(wrap::hipGraphExecDestroy(exec), "Failed to destroy HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::GraphNodeGetType(hipGraphNode_t node) { hipGraphNodeType node_type = hipGraphNodeTypeCount; RETURN_IF_ROCM_ERROR(hipGraphNodeGetType(node, &node_type), @@ -629,38 +681,38 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) { return GraphNodeType::kMemFree; } - return tsl::Status(absl::StatusCode::kInternal, - "Invalid HIP graph node type"); + return absl::InternalError("Invalid HIP graph node type"); } -/* static */ tsl::Status GpuDriver::GraphDebugDotPrint(hipGraph_t graph, - const char* path) { +/* static */ absl::StatusOr GpuDriver::GraphDebugDotPrint( + hipGraph_t graph, const char* path, bool return_printed_graph) { VLOG(2) << "Print HIP graph " << graph << " debug dot file to " << path; int flags = hipGraphDebugDotFlagsVerbose; RETURN_IF_ROCM_ERROR(wrap::hipGraphDebugDotPrint(graph, path, flags), "Failed to print gpu graph debug file"); - if (VLOG_IS_ON(100)) { + if (return_printed_graph) { std::string data; if (tsl::ReadFileToString(tsl::Env::Default(), path, &data).ok()) { - VLOG(200) << "HIP graph " << graph << " debug file:\n" << data; + return data; } else { LOG(WARNING) << "failed to read gpu graph debug file " << path; } } - return ::tsl::OkStatus(); + return std::string(path); } -/* static */ tsl::Status GpuDriver::DeviceGraphMemTrim(GpuDeviceHandle device) { +/* static */ absl::Status GpuDriver::DeviceGraphMemTrim( + GpuDeviceHandle device) { VLOG(2) << "Trim ROCM device graph memory " << device; RETURN_IF_ROCM_ERROR(wrap::hipDeviceGraphMemTrim(device), "Failed to trim device graph memory"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::StreamIsCapturing( +/* static */ absl::StatusOr GpuDriver::StreamIsCapturing( GpuStreamHandle stream) { VLOG(2) << "Checking if stream " << stream << " is capturing"; @@ -671,7 +723,7 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) { return status == hipStreamCaptureStatusActive; } -/* static */ tsl::Status GpuDriver::GraphConditionalHandleCreate( +/* static */ absl::Status GpuDriver::GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, hipGraph_t graph, GpuContext* context, unsigned int default_launch_value, unsigned int flags) { VLOG(2) << "Create conditional handle for a graph " << graph @@ -683,18 +735,30 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) { "HIP graph conditional nodes are not implemented yet"); } -/* static */ tsl::StatusOr +/* static */ absl::StatusOr GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, - absl::Span deps, + absl::Span deps, const GpuGraphNodeParams& params) { return absl::UnimplementedError("unsupported node type"); } -/* static */ tsl::Status GpuDriver::GraphAddKernelNode( - hipGraphNode_t* node, hipGraph_t graph, absl::Span deps, - absl::string_view kernel_name, hipFunction_t function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, +/* static */ absl::Status GpuDriver::GraphAddEmptyNode( + hipGraphNode_t* node, hipGraph_t graph, + absl::Span deps) { + VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphAddEmptyNode(node, graph, deps.data(), deps.size()), + "Failed to add empty node to a HIP graph"); + + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::GraphAddKernelNode( + hipGraphNode_t* node, hipGraph_t graph, + absl::Span deps, absl::string_view kernel_name, + hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, unsigned int shared_mem_bytes, void** kernel_params, void** extra) { VLOG(2) << "Add kernel node to a graph " << graph @@ -729,10 +793,20 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, deps.size(), ¶ms), "Failed to add kernel node to a HIP graph"); - return ::tsl::OkStatus(); + return absl::OkStatus(); +} + +/* static */ absl::StatusOr GpuDriver::GraphGetNodeCount( + hipGraph_t graph) { + VLOG(2) << "Get node count in graph " << graph; + size_t numNodes; + RETURN_IF_ROCM_ERROR(wrap::hipGraphGetNodes(graph, nullptr, &numNodes), + "Failed to get HIP graph node count"); + + return numNodes; } -/*static*/ tsl::Status GpuDriver::GraphExecKernelNodeSetParams( +/*static*/ absl::Status GpuDriver::GraphExecKernelNodeSetParams( GpuGraphExecHandle exec, GpuGraphNodeHandle node, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, @@ -767,33 +841,16 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, "Failed to set shared memory size"); } - RETURN_IF_ROCM_ERROR(hipGraphExecKernelNodeSetParams(exec, node, ¶ms), - "Failed to set HIP graph kernel node params"); - - return ::tsl::OkStatus(); -} - -/* static */ tsl::Status GpuDriver::GraphAddMemcpyD2DNode( - GpuContext* context, hipGraphNode_t* node, hipGraph_t graph, - absl::Span deps, hipDeviceptr_t gpu_dst, - hipDeviceptr_t gpu_src, uint64_t size) { - VLOG(2) << "Add memcpy d2d node to a graph " << graph - << "; dst: " << reinterpret_cast(gpu_dst) - << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size - << "; context: " << context->context() << "; deps: " << deps.size(); - RETURN_IF_ROCM_ERROR( - wrap::hipGraphAddMemcpyNode1D( - node, graph, deps.data(), deps.size(), - reinterpret_cast(gpu_dst), reinterpret_cast(gpu_src), - static_cast(size), hipMemcpyDeviceToDevice), - "Failed to add memcpy d2d node to a HIP graph"); - return tsl::OkStatus(); + wrap::hipGraphExecKernelNodeSetParams(exec, node, ¶ms), + "Failed to set HIP graph kernel node params"); + + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::GraphAddChildNode( - hipGraphNode_t* node, hipGraph_t graph, absl::Span deps, - hipGraph_t child) { +/* static */ absl::Status GpuDriver::GraphAddChildNode( + hipGraphNode_t* node, hipGraph_t graph, + absl::Span deps, hipGraph_t child) { VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); @@ -801,22 +858,237 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, wrap::hipGraphAddChildGraphNode(node, graph, deps.data(), deps.size(), child), "Failed to create a child graph node and add it to a HIP graph"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/*static*/ tsl::Status GpuDriver::GraphExecChildNodeSetParams( +/*static*/ absl::Status GpuDriver::GraphExecChildNodeSetParams( GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuGraphHandle child) { VLOG(2) << "Set child node params " << node << " in graph executable " << exec << "to params contained in " << child; RETURN_IF_ROCM_ERROR( wrap::hipGraphExecChildGraphNodeSetParams(exec, node, child), - "Failed to set ROCm graph child node params"); + "Failed to set HIP graph child node params"); + + return absl::OkStatus(); +} + +static hipMemAccessFlags ToHipMemAccessFlags( + GpuDriver::MemAccessFlags access_flags) { + switch (access_flags) { + case GpuDriver::MemAccessFlags::kNone: + return hipMemAccessFlagsProtNone; + case GpuDriver::MemAccessFlags::kRead: + return hipMemAccessFlagsProtRead; + case GpuDriver::MemAccessFlags::kReadWrite: + return hipMemAccessFlagsProtReadWrite; + } +} + +static hipMemLocationType ToHipLocationType( + GpuDriver::MemLocationType location_type) { + switch (location_type) { + case GpuDriver::MemLocationType::kInvalid: + return hipMemLocationTypeInvalid; + case GpuDriver::MemLocationType::kDevice: + return hipMemLocationTypeDevice; + case GpuDriver::MemLocationType::kHost: + case GpuDriver::MemLocationType::kHostNuma: + case GpuDriver::MemLocationType::kHostNumaCurrent: + return hipMemLocationTypeInvalid; + } +} + +static hipMemAllocationType ToHipAllocationType( + GpuDriver::MemAllocationType allocation_type) { + switch (allocation_type) { + case GpuDriver::MemAllocationType::kInvalid: + return hipMemAllocationTypeInvalid; + case GpuDriver::MemAllocationType::kPinned: + return hipMemAllocationTypePinned; + } +} + +/*static*/ absl::Status GpuDriver::GraphAddMemFreeNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst) { + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemFreeNode(node, graph, deps.data(), + deps.size(), gpu_dst), + "Failed to add memory free node to a HIP graph"); + return absl::OkStatus(); +} + +/*static*/ absl::Status GpuDriver::GraphAddMemAllocNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, MemAccessFlags access_flags, + MemLocationType location_type, int device_id, + MemAllocationType allocation_type, uint64_t size, GpuDevicePtr* d_ptr, + uint64_t max_pool_size) { + hipMemLocation mem_loc = { + .type = ToHipLocationType(location_type), + .id = device_id, + }; + + hipMemPoolProps props{}; + props.allocType = ToHipAllocationType(allocation_type); + props.handleTypes = hipMemHandleTypeNone; + props.location = mem_loc; + + hipMemAccessDesc mem_desc = { + .location = mem_loc, + .flags = ToHipMemAccessFlags(access_flags), + }; + + hipMemAllocNodeParams params{ + .poolProps = props, + .accessDescs = &mem_desc, + .accessDescCount = 1, + .bytesize = size, + .dptr = nullptr, + }; + + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemAllocNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memory allocation node to a HIP graph"); + + VLOG(2) << "Add MemAllocNode to a graph " << graph << " size " << size + << " address " << reinterpret_cast(params.dptr); + + *d_ptr = params.dptr; + return absl::OkStatus(); +} + +/*static*/ absl::StatusOr> +GpuDriver::GraphGetMemAllocNodeParams(GpuGraphNodeHandle node) { + hipMemAllocNodeParams params; + RETURN_IF_ROCM_ERROR(wrap::hipGraphMemAllocNodeGetParams(node, ¶ms), + "Failed to get memory allocation node parameter"); + return std::pair{params.dptr, params.bytesize}; +} + +/* static */ absl::Status GpuDriver::GraphAddMemcpyD2DNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr gpu_dst, + GpuDevicePtr gpu_src, uint64_t size) { + VLOG(2) << "Add memcpy d2d node to a graph " << graph + << "; dst: " << reinterpret_cast(gpu_dst) + << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size + << "; context: " << context->context() << "; deps: " << deps.size(); - return tsl::OkStatus(); + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode1D( + node, graph, deps.data(), deps.size(), gpu_dst, + gpu_src, size, hipMemcpyDeviceToDevice), + "Failed to add memcpy d2d node to a HIP graph"); + + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::LaunchKernel( +/* static */ absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { + VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(gpu_dst) + << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size + << "; context: " << context->context(); + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecMemcpyNodeSetParams1D(exec, node, gpu_dst, gpu_src, + size, hipMemcpyDeviceToDevice), + "Failed to set memcpy d2d node params"); + + return absl::OkStatus(); +} + +namespace { + +struct BitPatternToString { + std::string operator()(uint8_t pattern) { + return absl::StrCat("u8:", pattern); + } + std::string operator()(uint16_t pattern) { + return absl::StrCat("u16:", pattern); + } + std::string operator()(uint32_t pattern) { + return absl::StrCat("u32:", pattern); + } +}; + +// Broadcasts a pattern value of 1/2/4 bytes to a 4 byte value. +struct BitPatternToValue { + std::pair operator()(uint8_t pattern) { + unsigned value = pattern; + return {(value << 24) | (value << 16) | (value << 8) | value, + /*element_size=*/1}; + } + std::pair operator()(uint16_t pattern) { + unsigned value = pattern; + return {(value << 16) | value, /*element_size=*/2}; + } + std::pair operator()(uint32_t pattern) { + return {pattern, /*element_size=*/4}; + } +}; + +} // namespace + +/* static */ absl::Status GpuDriver::GraphAddMemsetNode( + GpuContext* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, GpuDevicePtr dst, + std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Add memset node to a graph " << graph + << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context() << "; deps: " << deps.size(); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + hipMemsetParams params{ + .dst = dst, + .elementSize = element_size, + .height = 1, + .pitch = 0, // unused if height is 1 + .value = value, + .width = num_elements, + }; + + RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemsetNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memset node to a HIP graph"); + + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::GraphExecMemsetNodeSetParams( + GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, + GpuDevicePtr dst, std::variant bit_pattern, + uint64_t num_elements) { + VLOG(2) << "Set memset node params " << node << " in graph executable " + << exec << "; dst: " << reinterpret_cast(dst) + << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) + << "; num_elements: " << num_elements + << "; context: " << context->context(); + + auto [value, element_size] = std::visit(BitPatternToValue(), bit_pattern); + + hipMemsetParams params{ + .dst = dst, + .elementSize = element_size, + .height = 1, + .pitch = 0, // unused if height is 1 + .value = value, + .width = num_elements, + }; + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecMemsetNodeSetParams(exec, node, ¶ms), + "Failed to set memset node params"); + + return absl::OkStatus(); +} + +/* static */ absl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, @@ -848,28 +1120,28 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, block_dim_y, "x", block_dim_z); VLOG(2) << "successfully launched kernel"; - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::LoadPtx(GpuContext* context, - const char* ptx_contents, - hipModule_t* module) { +/* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context, + const char* ptx_contents, + hipModule_t* module) { LOG(ERROR) << "Feature not supported on ROCm platform (LoadPtx)"; - return tsl::errors::Internal("Not Implemented"); + return absl::InternalError("Not Implemented"); } -/* static */ tsl::Status GpuDriver::LoadCubin(GpuContext* context, - const char* cubin_bytes, - hipModule_t* module) { - return tsl::Status{absl::StatusCode::kInternal, - "Feature not supported on ROCm platform (LoadCubin)"}; +/* static */ absl::Status GpuDriver::LoadCubin(GpuContext* context, + const char* cubin_bytes, + hipModule_t* module) { + return absl::Status{absl::StatusCode::kInternal, + "Feature not supported on ROCm platform (LoadCubin)"}; } -/* static */ tsl::Status GpuDriver::LoadHsaco(GpuContext* context, - const char* hsaco_contents, - hipModule_t* module) { +/* static */ absl::Status GpuDriver::LoadHsaco(GpuContext* context, + const char* hsaco_contents, + hipModule_t* module) { absl::Notification notification; - tsl::Status ret = tsl::OkStatus(); + absl::Status ret = absl::OkStatus(); GetDriverExecutor()->Schedule( [context, hsaco_contents, module, &ret, ¬ification]() { ScopedActivateContext activation{context}; @@ -878,7 +1150,8 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, hipError_t res = wrap::hipModuleLoadData(module, hsaco_data); if (res != hipSuccess) { - ret = tsl::errors::Internal("Failed to load HSACO: ", ToString(res)); + ret = absl::InternalError( + absl::StrCat("Failed to load HSACO: ", ToString(res))); notification.Notify(); } @@ -890,35 +1163,35 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return ret; } -/* static */ tsl::Status GpuDriver::SynchronousMemsetUint8( +/* static */ absl::Status GpuDriver::SynchronousMemsetUint8( GpuContext* context, hipDeviceptr_t location, uint8 value, size_t size) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR(wrap::hipMemsetD8(location, value, size), "Failed to memset memory"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemsetUint32( +/* static */ absl::Status GpuDriver::SynchronousMemsetUint32( GpuContext* context, hipDeviceptr_t location, uint32 value, size_t uint32_count) { ScopedActivateContext activation{context}; void* pointer = absl::bit_cast(location); RETURN_IF_ROCM_ERROR(wrap::hipMemsetD32(pointer, value, uint32_count), "Failed to memset memory"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::AsynchronousMemsetUint8( +/* static */ absl::Status GpuDriver::AsynchronousMemsetUint8( GpuContext* context, hipDeviceptr_t location, uint8 value, size_t uint32_count, GpuStreamHandle stream) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( wrap::hipMemsetAsync(location, value, uint32_count, stream), "Failed to enqueue async memset operation"); - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::AsynchronousMemsetUint32( +/* static */ absl::Status GpuDriver::AsynchronousMemsetUint32( GpuContext* context, hipDeviceptr_t location, uint32 value, size_t uint32_count, GpuStreamHandle stream) { ScopedActivateContext activation{context}; @@ -927,7 +1200,7 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, wrap::hipMemsetD32Async(pointer, value, uint32_count, stream), "Failed to enqueue async memset operation"); VLOG(2) << "successfully enqueued async memset operation"; - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::AddStreamCallback(GpuContext* context, @@ -942,16 +1215,15 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return true; } -/* static */ tsl::Status GpuDriver::GetModuleFunction(GpuContext* context, - hipModule_t module, - const char* kernel_name, - hipFunction_t* function) { +/* static */ absl::Status GpuDriver::GetModuleFunction( + GpuContext* context, hipModule_t module, const char* kernel_name, + hipFunction_t* function) { ScopedActivateContext activated{context}; CHECK(module != nullptr && kernel_name != nullptr); RETURN_IF_ROCM_ERROR( wrap::hipModuleGetFunction(function, module, kernel_name), "Failed to get kernel"); - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::GetModuleSymbol(GpuContext* context, @@ -984,15 +1256,14 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, } } -/* static */ tsl::StatusOr GpuDriver::DeviceFromContext( +/* static */ absl::StatusOr GpuDriver::DeviceFromContext( GpuContext* context) { ScopedActivateContext activated{context}; hipDevice_t device = -1; hipError_t result = wrap::hipCtxGetDevice(&device); if (result == hipSuccess) return device; - return tsl::Status( - absl::StatusCode::kInternal, + return absl::InternalError( absl::StrCat("failed to get device for context: ", ToString(result))); } @@ -1149,11 +1420,11 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, : lowest; } -/* static */ tsl::Status GpuDriver::DestroyEvent(GpuContext* context, - GpuEventHandle* event) { +/* static */ absl::Status GpuDriver::DestroyEvent(GpuContext* context, + GpuEventHandle* event) { if (*event == nullptr) { - return tsl::Status{absl::StatusCode::kInvalidArgument, - "input event cannot be null"}; + return absl::Status{absl::StatusCode::kInvalidArgument, + "input event cannot be null"}; } ScopedActivateContext activated{context}; @@ -1162,49 +1433,49 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, switch (res) { case hipSuccess: - return tsl::OkStatus(); + return absl::OkStatus(); case hipErrorDeinitialized: case hipErrorNotInitialized: - return tsl::Status{ + return absl::Status{ absl::StatusCode::kFailedPrecondition, absl::StrFormat("error destroying ROCM event in device %d: %s", context->device_ordinal(), ToString(res).c_str())}; default: - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("error destroying ROCM event in device %d: %s", context->device_ordinal(), ToString(res).c_str())}; } } -/* static */ tsl::Status GpuDriver::RecordEvent(GpuContext* context, - GpuEventHandle event, - GpuStreamHandle stream) { +/* static */ absl::Status GpuDriver::RecordEvent(GpuContext* context, + GpuEventHandle event, + GpuStreamHandle stream) { ScopedActivateContext activated{context}; hipError_t res = wrap::hipEventRecord(event, stream); switch (res) { case hipSuccess: - return tsl::OkStatus(); + return absl::OkStatus(); case hipErrorDeinitialized: case hipErrorNotInitialized: - return tsl::Status{ + return absl::Status{ absl::StatusCode::kFailedPrecondition, absl::StrFormat("error recording ROCM event on stream %p: %s", stream, ToString(res).c_str())}; default: - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInvalidArgument, absl::StrFormat("error recording ROCM event on stream %p: %s", stream, ToString(res).c_str())}; } } -/* static */ tsl::StatusOr GpuDriver::QueryEvent( +/* static */ absl::StatusOr GpuDriver::QueryEvent( GpuContext* context, GpuEventHandle event) { ScopedActivateContext activated{context}; hipError_t res = wrap::hipEventQuery(event); if (res != hipSuccess && res != hipErrorNotReady) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to query event: %s", ToString(res).c_str())}; } @@ -1259,15 +1530,15 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return true; } -/* static */ tsl::Status GpuDriver::SynchronizeStream(GpuContext* context, - GpuStreamHandle stream) { +/* static */ absl::Status GpuDriver::SynchronizeStream(GpuContext* context, + GpuStreamHandle stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); RETURN_IF_ROCM_ERROR(wrap::hipStreamSynchronize(stream), "Could not synchronize on ROCM stream"); VLOG(2) << "successfully synchronized stream " << stream << " on device " << context->device_ordinal(); - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::IsStreamIdle(GpuContext* context, @@ -1285,10 +1556,9 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return false; } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyD2H(GpuContext* context, - void* host_dst, - hipDeviceptr_t gpu_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyD2H( + GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, + uint64_t size) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( wrap::hipMemcpyDtoH(host_dst, gpu_src, size), @@ -1297,13 +1567,12 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, host_dst, absl::bit_cast(gpu_src), size, size)); VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " << host_dst; - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyH2D(GpuContext* context, - hipDeviceptr_t gpu_dst, - const void* host_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyH2D( + GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src, + uint64_t size) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( wrap::hipMemcpyHtoD(gpu_dst, const_cast(host_src), size), @@ -1311,14 +1580,13 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, "failed to synchronous memcpy from host to device: Gpu dst: %p;" " host src: %p; size: %llu=0x%llx", absl::bit_cast(gpu_dst), host_src, size, size)); - VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes"; - return tsl::OkStatus(); + VLOG(2) << "successfully sync memcpy'd h2d of " << size << " bytes"; + return absl::OkStatus(); } -/* static */ tsl::Status GpuDriver::SynchronousMemcpyD2D(GpuContext* context, - hipDeviceptr_t gpu_dst, - hipDeviceptr_t gpu_src, - uint64_t size) { +/* static */ absl::Status GpuDriver::SynchronousMemcpyD2D( + GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src, + uint64_t size) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size), @@ -1328,7 +1596,7 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, absl::bit_cast(gpu_dst), absl::bit_cast(gpu_src), size, size)); VLOG(2) << "successfully sync memcpy'd d2d of " << size << " bytes"; - return tsl::OkStatus(); + return absl::OkStatus(); } /* static */ bool GpuDriver::AsynchronousMemcpyD2H(GpuContext* context, @@ -1348,7 +1616,8 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, } VLOG(2) << "successfully enqueued async memcpy d2h of " << size << " bytes from " << absl::bit_cast(gpu_src) << " to " - << host_dst << " on stream " << stream; + << host_dst << " on stream " << stream + << " device: " << context->device_ordinal(); return true; } @@ -1368,8 +1637,10 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, size); return false; } - VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes" - << " on stream " << stream; + VLOG(2) << "successfully enqueued async memcpy h2d of " << size + << " bytes from " << host_src << " to " + << absl::bit_cast(gpu_dst) << " on stream " << stream + << " device: " << context->device_ordinal(); return true; } @@ -1396,13 +1667,17 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return false; } - VLOG(2) << "successfully enqueued async memcpy d2d of " << size << " bytes"; + + VLOG(2) << "successfully enqueued async memcpy d2d of " << size + << " bytes from " << absl::bit_cast(gpu_src) << " to " + << absl::bit_cast(gpu_dst) << " on stream " << stream + << " device: " << context->device_ordinal(); return true; } -/* static */ tsl::Status GpuDriver::InitEvent(GpuContext* context, - GpuEventHandle* event, - EventFlags flags) { +/* static */ absl::Status GpuDriver::InitEvent(GpuContext* context, + GpuEventHandle* event, + EventFlags flags) { int hipflags; switch (flags) { case EventFlags::kDefault: @@ -1419,12 +1694,12 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, hipError_t res = wrap::hipEventCreateWithFlags(event, hipflags); if (res == hipSuccess) { - return tsl::OkStatus(); + return absl::OkStatus(); } else if (res == hipErrorMemoryAllocation) { - return tsl::Status{absl::StatusCode::kResourceExhausted, - "could not create ROCM event: out of device memory"}; + return absl::Status{absl::StatusCode::kResourceExhausted, + "could not create ROCM event: out of device memory"}; } else { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kFailedPrecondition, absl::StrCat("could not create ROCM event: ", ToString(res))}; } @@ -1444,59 +1719,54 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return device_count; } -/* static */ tsl::Status GpuDriver::GetComputeCapability(int* cc_major, - int* cc_minor, - hipDevice_t device) { - return tsl::Status( - absl::StatusCode::kInternal, +/* static */ absl::Status GpuDriver::GetComputeCapability(int* cc_major, + int* cc_minor, + hipDevice_t device) { + return absl::InternalError( absl::StrFormat("failed to get compute capability for device: %d " "(unsupported API on AMD Gpus)", device)); } -/* static */ tsl::Status GpuDriver::GetPointerAddressRange(hipDeviceptr_t dptr, - hipDeviceptr_t* base, - size_t* size) { +/* static */ absl::Status GpuDriver::GetPointerAddressRange( + hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) { hipError_t result = wrap::hipMemGetAddressRange(base, size, dptr); if (result == hipSuccess) { - return tsl::OkStatus(); + return absl::OkStatus(); } else if (result == hipErrorNotFound) { // We differentiate between "this pointer is unknown" (return here) and // "there was an internal error while performing this operation" (return // below). - return tsl::Status{absl::StatusCode::kNotFound, - absl::StrFormat("not a device pointer %p; %s", - reinterpret_cast(dptr), - ToString(result).c_str())}; + return absl::Status{absl::StatusCode::kNotFound, + absl::StrFormat("not a device pointer %p; %s", + reinterpret_cast(dptr), + ToString(result).c_str())}; } - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to get pointer into for device pointer %p; %s", reinterpret_cast(dptr), ToString(result).c_str())}; } -/* static */ tsl::StatusOr GpuDriver::GetPointerContext( +/* static */ absl::StatusOr GpuDriver::GetPointerContext( hipDeviceptr_t pointer) { GpuContext* context = nullptr; hipError_t result = wrap::hipPointerGetAttribute( &context, HIP_POINTER_ATTRIBUTE_CONTEXT, pointer); if (result == hipSuccess) { if (context == nullptr) { - return tsl::Status( - absl::StatusCode::kUnavailable, + return absl::UnavailableError( "Empty context returned while querying context for device pointer"); } return context; } - return tsl::Status( - absl::StatusCode::kInternal, - absl::StrCat("failed to query context for device pointer: ", - ToString(result))); + return absl::InternalError(absl::StrCat( + "failed to query context for device pointer: ", ToString(result))); } -/* static */ tsl::StatusOr GpuDriver::GetPointerMemorySpace( +/* static */ absl::StatusOr GpuDriver::GetPointerMemorySpace( hipDeviceptr_t pointer) { unsigned int value; hipError_t result = hipSuccess; @@ -1507,25 +1777,25 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, case hipMemoryTypeHost: return MemorySpace::kHost; default: - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrCat("unknown memory space provided by ROCM API: ", value)}; } } - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrCat("failed to query device pointer for memory space: ", ToString(result))}; } -/* static */ tsl::StatusOr GpuDriver::GetPointerDevice( +/* static */ absl::StatusOr GpuDriver::GetPointerDevice( hipDeviceptr_t pointer) { hipPointerAttribute_t pointerAttributes; hipError_t result = wrap::hipPointerGetAttributes(&pointerAttributes, pointer); if (result != hipSuccess) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrCat("failed to get device for pointer: ", ToString(result))}; } @@ -1533,7 +1803,7 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, hipDevice_t device; result = wrap::hipDeviceGet(&device, pointerAttributes.device); if (result != hipSuccess) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrCat("failed to get device for pointer: ", ToString(result))}; } @@ -1541,37 +1811,43 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, return device; } -/* static */ tsl::Status GpuDriver::GetGpuISAVersion(int* version, - hipDevice_t device) { +/* static */ absl::Status GpuDriver::GetGpuISAVersion(int* version, + hipDevice_t device) { hipDeviceProp_t props; hipError_t result = wrap::hipGetDeviceProperties(&props, device); if (result == hipSuccess) { - *version = props.gcnArch; - return tsl::OkStatus(); + std::string gcnName = props.gcnArchName; + std::vector tokens = absl::StrSplit(gcnName, ':'); + std::string amdgpu_version = gcnName; + if (!tokens.empty() && tokens[0].size() >= 3) { + amdgpu_version = tokens[0].substr(3); + } + *version = stoi(amdgpu_version); + return absl::OkStatus(); } *version = 0; - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to determine AMDGpu ISA version for device %d", device)}; } -/* static */ tsl::Status GpuDriver::GetGpuGCNArchName( +/* static */ absl::Status GpuDriver::GetGpuGCNArchName( hipDevice_t device, std::string* gcnArchName) { hipDeviceProp_t props; hipError_t result = wrap::hipGetDeviceProperties(&props, device); if (result == hipSuccess) { *gcnArchName = props.gcnArchName; - return tsl::OkStatus(); + return absl::OkStatus(); } *gcnArchName = ""; - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to determine AMDGpu GCN Arch Name for device %d", device)}; } -/* static */ tsl::StatusOr GpuDriver::GetMFMASupport() { +/* static */ absl::StatusOr GpuDriver::GetMFMASupport() { hipDeviceProp_t props; int dev = 0; hipError_t result = wrap::hipGetDevice(&dev); @@ -1579,15 +1855,11 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, if (result == hipSuccess) { std::string gcnArchName = props.gcnArchName; VLOG(3) << "GCN arch name " << gcnArchName; - auto pos = gcnArchName.find(":"); - if (pos != string::npos) gcnArchName = gcnArchName.substr(0, pos); - pos = gcnArchName.find("gfx"); - if (pos != string::npos) gcnArchName = gcnArchName.substr(pos + 3); - VLOG(3) << "GCN arch name (stripped) " << gcnArchName; - return ((gcnArchName == "908") || (gcnArchName == "909") || - (gcnArchName == "90a") || (gcnArchName == "940")); - } - return tsl::Status{ + auto compute_capability = RocmComputeCapability(gcnArchName); + VLOG(3) << "GCN arch name (stripped) " << compute_capability.gfx_version(); + return compute_capability.gfx9_mi100_or_later(); + } + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to determine AMDGpu GCN Arch Name for device %d", dev)}; @@ -1596,12 +1868,12 @@ GpuDriver::GraphAddNode(hipGraphNode_t* node, hipGraph_t graph, // Helper function that turns the integer output of hipDeviceGetAttribute to // type T and wraps it in a StatusOr. template -static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, - hipDeviceAttribute_t attribute) { +static absl::StatusOr GetSimpleAttribute(hipDevice_t device, + hipDeviceAttribute_t attribute) { int value = -1; hipError_t result = wrap::hipDeviceGetAttribute(&value, attribute, device); if (result != hipSuccess) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kNotFound, absl::StrCat("could not retrieve ROCM device attribute (", attribute, "): ", ToString(result))}; @@ -1610,48 +1882,48 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, return converted; } -/* static */ tsl::StatusOr GpuDriver::GetMultiprocessorCount( +/* static */ absl::StatusOr GpuDriver::GetMultiprocessorCount( hipDevice_t device) { return GetSimpleAttribute(device, hipDeviceAttributeMultiprocessorCount); } -/* static */ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( +/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( hipDevice_t device) { return GetSimpleAttribute( device, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor); } -/* static */ tsl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( hipDevice_t device) { return GetSimpleAttribute(device, hipDeviceAttributeMaxSharedMemoryPerBlock); } -/* static */ tsl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( +/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( hipDevice_t device) { return GetSimpleAttribute( device, hipDeviceAttributeMaxThreadsPerMultiProcessor); } -/* static */ tsl::StatusOr GpuDriver::GetMaxThreadsPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerBlock( hipDevice_t device) { return GetSimpleAttribute(device, hipDeviceAttributeMaxThreadsPerBlock); } -/* static */ tsl::StatusOr GpuDriver::GetMaxRegistersPerBlock( +/* static */ absl::StatusOr GpuDriver::GetMaxRegistersPerBlock( hipDevice_t device) { return GetSimpleAttribute(device, hipDeviceAttributeMaxRegistersPerBlock); } -/* static */ tsl::StatusOr GpuDriver::GetThreadsPerWarp( +/* static */ absl::StatusOr GpuDriver::GetThreadsPerWarp( hipDevice_t device) { return GetSimpleAttribute(device, hipDeviceAttributeWarpSize); } -/* static */ tsl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, - hipDevice_t device) { +/* static */ absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, + hipDevice_t device) { int value; RETURN_IF_ROCM_ERROR(wrap::hipDeviceGetAttribute( &value, hipDeviceAttributeMaxGridDimX, device), @@ -1667,17 +1939,14 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, &value, hipDeviceAttributeMaxGridDimZ, device), "failed to query max grid dim z"); *z = value; - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ bool GpuDriver::GetDriverVersion(int* driver_version) { - hipError_t res = wrap::hipDriverGetVersion(driver_version); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query driver version: " << ToString(res); - return false; - } - - return true; +/* static */ absl::StatusOr GpuDriver::GetDriverVersion() { + int32_t version; + RETURN_IF_ROCM_ERROR(wrap::hipDriverGetVersion(&version), + "Could not get driver version"); + return version; } /* static */ bool GpuDriver::GetDeviceProperties( @@ -1692,7 +1961,7 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, return true; } -/* static */ tsl::StatusOr GpuDriver::GetDeviceAttribute( +/* static */ absl::StatusOr GpuDriver::GetDeviceAttribute( hipDeviceAttribute_t attribute, hipDevice_t device) { return GetSimpleAttribute(device, attribute); } @@ -1726,17 +1995,21 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, } std::string gcnArchName = props.gcnArchName; + auto compute_capability = RocmComputeCapability(gcnArchName); // On gfx90a, we hide 1 GB of GPU memory (512MB for gfx908) from TF, // to allow for late allocations by internal ROCm libraries // (e.g. rocBLAS alone needs~200 MB to put its kernels as of ROCm 4.1) const uint64_t RESERVED_GFX908 = 1048576 * 512; const uint64_t RESERVED_GFX9_X = 1048576 * 1024; - if (gcnArchName.substr(0, 6) == "gfx908") { + const uint64_t RESERVED_GFX10_X = 1048576 * 512; + if (compute_capability.gfx_version() == "gfx908") { *reserve = RESERVED_GFX908; - } else if (gcnArchName.substr(0, 6) == "gfx90a" || - gcnArchName.substr(0, 6) == "gfx940") { + } else if (compute_capability.gfx9_mi200_or_later()) { *reserve = RESERVED_GFX9_X; + } else if (compute_capability.navi21() || compute_capability.navi31()) { + *reserve = RESERVED_GFX10_X; } + return true; } @@ -1838,41 +2111,36 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, return can_access_peer; } -/* static */ tsl::Status GpuDriver::EnablePeerAccess(GpuContext* from, - GpuContext* to) { +/* static */ absl::Status GpuDriver::EnablePeerAccess(GpuContext* from, + GpuContext* to) { if (from == to) { - return tsl::OkStatus(); // A device can always access its own memory. + return absl::OkStatus(); // A device can always access its own memory. } ScopedActivateContext activated{from}; hipError_t result = wrap::hipCtxEnablePeerAccess(to->context(), 0 /* = flags */); if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat("failed to enable peer access from %d to %d: %s", from->device_ordinal(), to->device_ordinal(), ToString(result).c_str())}; } - return tsl::OkStatus(); + return absl::OkStatus(); } -/* static */ tsl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( +/* static */ absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( GpuContext* context, hipFunction_t kernel, int threads_per_block, size_t dynamic_shared_memory_bytes) { ScopedActivateContext activation{context}; int max_blocks = 0; - hipError_t result = hipSuccess; - // TODO(ROCm) implement this feature in HIP - if (result != hipSuccess) { - return tsl::Status{ - absl::StatusCode::kInternal, - absl::StrFormat("failed to calculate occupancy of kernel %p: %s", - kernel, ToString(result).c_str())}; - } - + RETURN_IF_ROCM_ERROR( + wrap::hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( + &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes), + "Failed to calculate maximal active blocks per SM"); return max_blocks; } @@ -1880,6 +2148,17 @@ static tsl::StatusOr GetSimpleAttribute(hipDevice_t device, namespace rocm { +absl::Status OccupancyGetMaxPotentialBlockSize(int* gridSize, int* blockSize, + hipFunction_t kernel, + size_t dynSharedMemPerBlk, + int blockSizeLimit) { + RETURN_IF_ROCM_ERROR( + wrap::hipModuleOccupancyMaxPotentialBlockSize( + gridSize, blockSize, kernel, dynSharedMemPerBlk, blockSizeLimit), + "Failed to calculate maximal potential block size"); + return absl::OkStatus(); +} + hipCtx_t CurrentContextOrDie() { hipCtx_t current = nullptr; FAIL_IF_ROCM_ERROR(hipCtxGetCurrent(¤t), diff --git a/xla/stream_executor/rocm/rocm_driver.h b/xla/stream_executor/rocm/rocm_driver.h index 65eb87fc33707..80822446a58d8 100644 --- a/xla/stream_executor/rocm/rocm_driver.h +++ b/xla/stream_executor/rocm/rocm_driver.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,47 +24,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "tsl/platform/logging.h" namespace stream_executor { namespace gpu { // Formats hipError_t to output prettified values into a log stream. // Error summaries taken from: -string ToString(hipError_t result) { -#define OSTREAM_ROCM_ERROR(__name) \ - case hipError##__name: \ - return "HIP_ERROR_" #__name; - - switch (result) { - OSTREAM_ROCM_ERROR(InvalidValue) - OSTREAM_ROCM_ERROR(OutOfMemory) - OSTREAM_ROCM_ERROR(NotInitialized) - OSTREAM_ROCM_ERROR(Deinitialized) - OSTREAM_ROCM_ERROR(NoDevice) - OSTREAM_ROCM_ERROR(InvalidDevice) - OSTREAM_ROCM_ERROR(InvalidImage) - OSTREAM_ROCM_ERROR(InvalidContext) - OSTREAM_ROCM_ERROR(InvalidHandle) - OSTREAM_ROCM_ERROR(NotFound) - OSTREAM_ROCM_ERROR(NotReady) - OSTREAM_ROCM_ERROR(NoBinaryForGpu) - - // Encountered an uncorrectable ECC error during execution. - OSTREAM_ROCM_ERROR(ECCNotCorrectable) - - // Load/store on an invalid address. Must reboot all context. - case 700: - return "ROCM_ERROR_ILLEGAL_ADDRESS"; - // Passed too many / wrong arguments, too many threads for register count. - case 701: - return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; - - OSTREAM_ROCM_ERROR(ContextAlreadyInUse) - OSTREAM_ROCM_ERROR(PeerAccessUnsupported) - OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. - default: - return absl::StrCat("hipError_t(", static_cast(result), ")"); - } -} +std::string ToString(hipError_t result); // GpuContext wraps the device_ordinal and hipCtx_t handle. class GpuContext { @@ -177,6 +143,12 @@ namespace rocm { using MemorySpace = gpu::MemorySpace; using ScopedActivateContext = gpu::ScopedActivateContext; +// TODO: this function shall be added to the GpuDriver API as well +absl::Status OccupancyGetMaxPotentialBlockSize(int* gridSize, int* blockSize, + hipFunction_t func, + size_t dynSharedMemPerBlk, + int blockSizeLimit); + // Returns the current context set in ROCm. This is done by calling ROCm // driver (e.g., this value is not our cached view of the current context). hipCtx_t CurrentContextOrDie(); diff --git a/xla/stream_executor/rocm/rocm_driver_wrapper.h b/xla/stream_executor/rocm/rocm_driver_wrapper.h index 36bcfdd33e873..2a8274b527eca 100644 --- a/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ namespace wrap { static FuncPtrT loaded = []() -> FuncPtrT { \ static const char *kName = TO_STR(hipSymbolName); \ void *f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ .value(), \ kName, &f); \ @@ -79,6 +79,7 @@ namespace wrap { __macro(hipDeviceGetName) \ __macro(hipDeviceGetPCIBusId) \ __macro(hipDeviceGetSharedMemConfig) \ + __macro(hipDeviceGetStreamPriorityRange) \ __macro(hipDeviceGraphMemTrim) \ __macro(hipDevicePrimaryCtxGetState) \ __macro(hipDevicePrimaryCtxSetFlags) \ @@ -104,17 +105,26 @@ namespace wrap { __macro(hipGetErrorString) \ __macro(hipGraphAddKernelNode) \ __macro(hipGraphAddChildGraphNode) \ - __macro(hipGraphAddMemcpyNode) \ + __macro(hipGraphAddEmptyNode) \ + __macro(hipGraphAddMemAllocNode) \ __macro(hipGraphAddMemcpyNode1D) \ - __macro(hipGraphExecChildGraphNodeSetParams) \ + __macro(hipGraphAddMemsetNode) \ + __macro(hipGraphAddMemFreeNode) \ __macro(hipGraphCreate) \ __macro(hipGraphDebugDotPrint) \ __macro(hipGraphDestroy) \ + __macro(hipGraphGetNodes) \ + __macro(hipGraphExecChildGraphNodeSetParams) \ __macro(hipGraphExecDestroy) \ + __macro(hipGraphExecKernelNodeSetParams) \ + __macro(hipGraphExecMemcpyNodeSetParams1D) \ + __macro(hipGraphExecMemsetNodeSetParams) \ __macro(hipGraphExecUpdate) \ __macro(hipGraphInstantiate) \ + __macro(hipGraphMemAllocNodeGetParams) \ __macro(hipGraphLaunch) \ __macro(hipGraphNodeGetType) \ + __macro(hipGraphNodeSetEnabled) \ __macro(hipHostFree) \ __macro(hipHostMalloc) \ __macro(hipHostRegister) \ @@ -145,10 +155,12 @@ namespace wrap { __macro(hipModuleLaunchKernel) \ __macro(hipModuleLoadData) \ __macro(hipModuleUnload) \ + __macro(hipModuleOccupancyMaxActiveBlocksPerMultiprocessor) \ + __macro(hipModuleOccupancyMaxPotentialBlockSize) \ __macro(hipPointerGetAttribute) \ __macro(hipPointerGetAttributes) \ + __macro(hipRuntimeGetVersion) \ __macro(hipSetDevice) \ - __macro(hipDeviceGetStreamPriorityRange) \ __macro(hipStreamAddCallback) \ __macro(hipStreamBeginCapture) \ __macro(hipStreamCreateWithFlags) \ diff --git a/xla/stream_executor/rocm/rocm_event.cc b/xla/stream_executor/rocm/rocm_event.cc index af8a1fc987240..497dee272cc1c 100644 --- a/xla/stream_executor/rocm/rocm_event.cc +++ b/xla/stream_executor/rocm/rocm_event.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace stream_executor { namespace gpu { Event::Status GpuEvent::PollForStatus() { - tsl::StatusOr status = + absl::StatusOr status = GpuDriver::QueryEvent(parent_->gpu_context(), gpu_event_); if (!status.ok()) { LOG(ERROR) << "Error polling for event status: " diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc new file mode 100644 index 0000000000000..722f2c800acef --- /dev/null +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -0,0 +1,1090 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include + +#include "absl/base/casts.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/gpu/gpu_collectives.h" +#include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_event.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/gpu/gpu_runtime.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_timer.h" +#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform/dso_loader.h" +#include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/plugin_registry.h" +#include "xla/stream_executor/rocm/rocm_diagnostics.h" +#include "xla/stream_executor/rocm/rocm_driver.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +#ifdef PLATFORMS_GPUS_ROCM_DYNAMIC_LIBROCM_DYNAMIC_LIBROCM_H_ +#error \ + "No driver calls in this file, wrap driver functionality in rocm_driver.cc." +#endif + +#ifdef __ROCM_RUNTIME_H__ +#error \ + "ROCM runtime being included into ROCM GPU executor; should be driver only." +#endif + +namespace stream_executor { +namespace gpu { + +static GpuEvent* AsGpuEvent(Event* event) { + DCHECK(event != nullptr); + return static_cast(event->implementation()); +} + +// Given const GPU memory, returns a librocm device pointer datatype, suitable +// for passing directly to librocm APIs. +// +// N.B. we must lose constness in order to pass a suitable type to the existing +// librocm APIs, so the caller should take care to only pass the result of const +// GPU memory conversions to librocm functions which will honor constness. +static hipDeviceptr_t AsROCmDevicePtr(const DeviceMemoryBase& gpu_mem) { + return const_cast(gpu_mem.opaque()); +} + +// See description on const version above. +static hipDeviceptr_t AsROCmDevicePtr(DeviceMemoryBase* gpu_mem) { + return AsROCmDevicePtr(*gpu_mem); +} + +static GpuContext* GetGpuContext(Stream* stream) { + return static_cast(stream->parent()->implementation()) + ->gpu_context(); +} + +GpuContext* ExtractGpuContext(GpuExecutor* rocm_exec) { + CHECK(rocm_exec != nullptr); + return rocm_exec->gpu_context(); +} + +GpuExecutor::~GpuExecutor() { + for (auto& it : disk_modules_) { + GpuDriver::UnloadModule(context_, it.second); + } + for (auto& it : in_memory_modules_) { + GpuDriver::UnloadModule(context_, it.second); + } + if (context_ != nullptr) { + GpuDriver::DestroyContext(context_); + } + CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels."; + CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules."; +} +bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { + const char* gpu_binary = reinterpret_cast(module_handle.id()); + absl::MutexLock lock{&in_memory_modules_mu_}; + return UnloadGpuBinary(gpu_binary); +} + +namespace { +absl::uint128 Fingerprint128(const absl::string_view s) { + auto fp = tsl::Fingerprint128(s); + return absl::MakeUint128(fp.high64, fp.low64); +} + +int fpus_per_core(std::string gcn_arch_name) { + // Source: + // https://www.amd.com/content/dam/amd/en/documents/instinct-business-docs/white-papers/amd-cdna2-white-paper.pdf + int n = 128; // gfx90a and gfx908 -> 128 + if (gcn_arch_name.substr(0, 6) == "gfx906") { + n = 64; + } + return n; +} +} // namespace + +absl::StatusOr> +GpuExecutor::CreateOrShareConstant(Stream* stream, + absl::Span content) { + absl::MutexLock lock{&shared_constants_mu_}; + // We assume all constants are uniquely identified by this hash. In the + // (highly unlikely) event of a hash collision, the program will likely crash + // (because the cached constant that will be returned by mistake is unlikely + // to have the correct size). + absl::uint128 fingerprint = Fingerprint128(absl::string_view( + reinterpret_cast(content.data()), content.size())); + // Must insert nullptr first to get an iterator to the insertion point. + auto insert_result = shared_constants_.insert( + {fingerprint, std::weak_ptr()}); + auto it = insert_result.first; + bool was_already_in_cache = !insert_result.second; + std::shared_ptr shared_constant; + + if (was_already_in_cache) { + shared_constant = it->second.lock(); + } + + if (shared_constant == nullptr) { + // Either the constant wasn't found in the cache, or it was but its + // weak_ptr had expired. + DeviceMemoryBase* new_constant = + new DeviceMemoryBase(Allocate(content.size(), /*memory_space=*/0)); + if (new_constant->opaque() == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to allocate %d bytes for new constant", content.size())); + } + + TF_RETURN_IF_ERROR( + stream->Memcpy(new_constant, content.data(), content.size())); + absl::Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + Deallocate(new_constant); + status.Update(absl::InternalError(absl::StrFormat( + "Memcpy to device address %p failed", new_constant->opaque()))); + return status; + } + + // Capturing 'this' in the custom deleter means this executor must + // outlive all shared uses of this constant. + shared_constant = std::shared_ptr( + new_constant, [this](DeviceMemoryBase* p) { + Deallocate(p); + delete p; + }); + it->second = std::weak_ptr(shared_constant); + } + + return shared_constant; +} + +bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { + auto module_it = gpu_binary_to_module_.find(gpu_binary); + if (gpu_binary_to_module_.end() == module_it) { + VLOG(3) << "No loaded HSACO module for " << gpu_binary; + return false; + } + auto& module = module_it->second.first; + auto& refcount = module_it->second.second; + VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount; + if (--refcount == 0) { + VLOG(3) << "Unloading HSACO module " << module; + GpuDriver::UnloadModule(context_, module); + gpu_binary_to_module_.erase(module_it); + const char* mem_it = nullptr; + for (auto x : in_memory_modules_) { + if (x.second == module) mem_it = x.first; + } + if (mem_it != nullptr) in_memory_modules_.erase(mem_it); + } + return true; +} + +void GpuExecutor::UnloadKernel(const Kernel* kernel) { + VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name(); + + absl::MutexLock lock{&in_memory_modules_mu_}; + auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel); + if (kernel_to_gpu_binary_.end() == gpu_binary_it) { + VLOG(3) << "Kernel " << kernel << " : " << kernel->name() + << " has never been loaded."; + return; // We've never seen this kernel. + } + VLOG(3) << "Kernel " << kernel << " : " << kernel->name() + << " has loaded GPU code " << gpu_binary_it->second; + UnloadGpuBinary(gpu_binary_it->second); + kernel_to_gpu_binary_.erase(gpu_binary_it); +} + +absl::Status GpuExecutor::Init(int device_ordinal) { + device_ordinal_ = device_ordinal; + + auto status = GpuDriver::Init(); + if (!status.ok()) { + return status; + } + + status = GpuDriver::GetDevice(device_ordinal_, &device_); + if (!status.ok()) { + return status; + } + + status = GpuDriver::CreateContext(device_ordinal_, device_, &context_); + if (!status.ok()) { + return status; + } + + return GpuDriver::GetGpuISAVersion(&version_, device_); +} + +// Returns the path to the running executable. +// N.B. Derived from //knowledge/smalltalk/background_kb.cc +// Arg: strip_exe: if true, remove the name of the executable itself from the +// returned string. Example: calling this from /usr/bin/foo +// would return /usr/bin. +static string GetBinaryDir(bool strip_exe) { + char exe_path[PATH_MAX] = {0}; + CHECK_NE(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1), -1); + // Make sure it's null-terminated: + exe_path[sizeof(exe_path) - 1] = 0; + + if (strip_exe) { + // The exe is the last component of the path, so remove one component. + string ret = exe_path; + std::vector components = absl::StrSplit(exe_path, '/'); + components.pop_back(); + return absl::StrJoin(components, "/"); + } + return exe_path; +} + +absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) { + GpuKernel* rocm_kernel = AsGpuKernel(kernel); + hipModule_t module = nullptr; + const string* kernel_name; + + if (spec.has_cuda_cubin_in_memory()) { + kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); + + const char* hsaco = reinterpret_cast( + spec.cuda_cubin_in_memory().cubin_bytes().data()); + absl::MutexLock lock{&in_memory_modules_mu_}; + module = in_memory_modules_[hsaco]; + + if (module == nullptr) { + TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); + } + kernel_to_gpu_binary_[kernel] = hsaco; + } else if (spec.has_in_process_symbol()) { + kernel_name = &spec.in_process_symbol().kernel_name(); + void* symbol = spec.in_process_symbol().symbol(); + + VLOG(1) << "Resolve ROCM kernel " << *kernel_name + << " from symbol pointer: " << symbol; + + *rocm_kernel->gpu_function_ptr() = + static_cast(spec.in_process_symbol().symbol()); + } else { + return absl::InternalError("No method of loading ROCM kernel provided"); + } + + // If we resolved kernel from a symbol pointer, there is no need to load it + // from a module, as ROCm runtime did that automatically for us. + if (!spec.has_in_process_symbol()) { + VLOG(2) << "getting function " << *kernel_name << " from module " << module; + TF_RETURN_IF_ERROR( + GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), + rocm_kernel->gpu_function_ptr())); + } + + // We have to trust the kernel loader spec arity because there doesn't appear + // to be a way to reflect on the number of expected arguments w/the ROCM API. + rocm_kernel->set_arity(spec.arity()); + + // unable to get kernel metadata for in-process kernel + if (!spec.has_in_process_symbol()) { + KernelMetadata kernel_metadata; + TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); + kernel->set_metadata(kernel_metadata); + } + kernel->set_name(*kernel_name); + kernel->set_args_packing(spec.kernel_args_packing()); + return absl::OkStatus(); +} + +absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, + KernelMetadata* kernel_metadata) { + int value = 0; + TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( + HIP_FUNC_ATTRIBUTE_NUM_REGS, *rocm_kernel->gpu_function_ptr(), &value)); + kernel_metadata->set_registers_per_thread(value); + + TF_RETURN_IF_ERROR( + GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + *rocm_kernel->gpu_function_ptr(), &value)); + kernel_metadata->set_shared_memory_bytes(value); + return absl::OkStatus(); +} + +absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) { + GpuStreamHandle hipstream = AsGpuStreamValue(stream); + const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); + hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); + + // Only perform/print the occupancy check once. Even just checking to see + // whether we've done an occupancy check on this kernel before isn't free + // (because we have to synchronize), so we only do this at -v 2+. + if (VLOG_IS_ON(2)) { + absl::MutexLock lock(&launched_kernels_mu_); + if (!launched_kernels_.count(hipfunc)) { + VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel, + thread_dims, block_dims); + // TODO(rspringer): Remove elements from launched_kernels_...if we ever + // expose a kernel/module deallocation method. + launched_kernels_.insert(hipfunc); + } + } + + if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { + TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( + hipfunc, rocm_kernel->GetGpuCacheConfig())); + } + + auto launch = [&](const KernelArgsPackedArrayBase& packed) { + CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0), + packed.number_of_arguments()); + + void** kernel_params = + const_cast(packed.argument_addresses().data()); + + return GpuDriver::LaunchKernel( + GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, + block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + args.number_of_shared_bytes(), hipstream, kernel_params, nullptr); + }; + + auto* packed_args = DynCast(&args); + if (packed_args) return launch(*packed_args); + + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed_args, pack(kernel, *device_mem)); + return launch(*packed_args); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + +absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + if (cluster_dims.x != 1 || cluster_dims.y != 1 || cluster_dims.z != 1) + return absl::UnimplementedError("Not implemented for ROCm"); + return Launch(stream, thread_dims, block_dims, kernel, args); +} + +absl::Status GpuExecutor::Submit(Stream* stream, + const CommandBuffer& command_buffer) { + if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { + return absl::InvalidArgumentError( + "Can't submit non-primary command buffer for execution"); + } + + auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); + VLOG(3) << "Launch command buffer execuable graph " << exec + << " on a stream: " << stream; + return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); +} + +absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle) { + // In GpuExecutor we store the pointer to the HSACO binary as + // ModuleHandle::id(). + hipModule_t hip_module = nullptr; + // TODO(ROCm): Need generic term instead of cubin/cuda/ptx + if (spec.has_cuda_cubin_in_memory()) { + absl::MutexLock lock{&in_memory_modules_mu_}; + TF_RETURN_IF_ERROR(LoadModuleFromHsaco( + reinterpret_cast(spec.cuda_cubin_in_memory().data()), + &hip_module)); + *module_handle = ModuleHandle(const_cast( + static_cast(spec.cuda_cubin_in_memory().data()))); + return absl::OkStatus(); + } else { + return absl::InternalError("No HASCO binary found"); + } +} + +absl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, + hipModule_t* module) { + LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromCuBin)"; +} + +absl::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, + hipModule_t* module) { + LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromPtx)"; +} + +absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, + hipModule_t* module) { + uint64_t module_refcount; + std::tie(*module, module_refcount) = gpu_binary_to_module_[hsaco]; + + if (*module == nullptr) { + TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, module)); + module_refcount = 1; + in_memory_modules_[hsaco] = *module; + VLOG(3) << "Loaded HSACO " << static_cast(hsaco) + << " as module " << *module; + } else { + ++module_refcount; + VLOG(3) << "HSACO " << static_cast(hsaco) + << " is already loaded as module " << *module; + } + gpu_binary_to_module_[hsaco] = {*module, module_refcount}; + return absl::OkStatus(); +} + +// This is a non-essential operation; if there's a failure, proceed without +// logging an error. It's nearly certain that in case of failures, we'd never +// get here in the first place; these are very low-impact routines. +void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description, + const Kernel& kernel, + const ThreadDim& thread_dims, + const BlockDim& block_dims) { + VLOG(2) << "Computing kernel occupancy for kernel " + << kernel.demangled_name(); + VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y + << ", " << thread_dims.z << ")"; + + auto regs_per_thread = kernel.metadata().registers_per_thread(); + auto smem_per_block = kernel.metadata().shared_memory_bytes(); + + if (!regs_per_thread && !smem_per_block) { + return; + } + + const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); + auto hipfunc = rocm_kernel->AsGpuFunctionHandle(); + + int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread, + *smem_per_block, thread_dims, hipfunc); + VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; + + int suggested_threads = + CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread, + *smem_per_block, thread_dims, hipfunc); + if (suggested_threads != 0) { + VLOG(2) << "The rocm occupancy calculator recommends using " + << suggested_threads + << " threads per block to achieve an occupancy of " << blocks_per_sm + << " blocks per SM."; + } +} + +// Compute and return maximum blocks per core (occupancy) based on the +// device description, some kernel characteristics and the number of threads per +// block. If unable to compute occupancy, zero is returned. +int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, + uint64_t registers_per_thread, + uint64_t shared_memory_per_block, + const ThreadDim& thread_dims, + GpuFunctionHandle func) { + int suggested_blocks = 0; + int suggested_threads = 0; + (void)rocm::OccupancyGetMaxPotentialBlockSize( + &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); + return suggested_blocks; +} + +// Compute and return the suggested thread count to achieve ideal occupancy. +// If the provided thread dimensions match this number, zero is returned. +int GpuExecutor::CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, + uint64_t registers_per_thread, + uint64_t shared_memory_per_block, + const ThreadDim& thread_dims, + GpuFunctionHandle func) { + int suggested_blocks = 0; + int suggested_threads = 0; + (void)rocm::OccupancyGetMaxPotentialBlockSize( + &suggested_blocks, &suggested_threads, func, shared_memory_per_block, 0); + if (suggested_blocks > *initial_blocks) { + *initial_blocks = suggested_blocks; + return suggested_threads; + } else { + return 0; + } +} + +DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { + if (memory_space == + static_cast(stream_executor::MemoryType::kHost)) { + return DeviceMemoryBase(GpuDriver::HostAllocate(context_, size), size); + } + CHECK_EQ(memory_space, 0); + return DeviceMemoryBase(GpuDriver::DeviceAllocate(context_, size), size); +} + +void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { + GpuDriver::DeviceDeallocate(context_, mem->opaque()); +} + +bool GpuExecutor::HostMemoryRegister(void* location, uint64_t size) { + if (location == nullptr || size == 0) { + LOG(WARNING) << "attempting to register null or zero-sized memory: " + << location << "; size " << size; + } + VLOG(2) << "registering " << location << " size " << size; + return GpuDriver::HostRegister(context_, location, size); +} + +bool GpuExecutor::HostMemoryUnregister(void* location) { + VLOG(2) << "unregistering " << location; + return GpuDriver::HostUnregister(context_, location); +} + +bool GpuExecutor::SynchronizeAllActivity() { + return GpuDriver::SynchronizeContext(context_); +} + +absl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) { + if (reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0) { + return GpuDriver::SynchronousMemsetUint32( + context_, AsROCmDevicePtr(location), 0x0, size / 4); + } + return GpuDriver::SynchronousMemsetUint8(context_, AsROCmDevicePtr(location), + 0x0, size); +} + +absl::Status GpuExecutor::SynchronousMemSet(DeviceMemoryBase* location, + int value, uint64_t size) { + if (reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0) { + // hipMemset reinterprets "value" as a uint8. + uint8 byte_value = static_cast(value); + uint32 pattern = (byte_value << 24) | (byte_value << 16) | + (byte_value << 8) | byte_value; + return GpuDriver::SynchronousMemsetUint32( + context_, AsROCmDevicePtr(location), pattern, size / 4); + } + return GpuDriver::SynchronousMemsetUint8(context_, AsROCmDevicePtr(location), + value, size); +} + +absl::Status GpuExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, + uint64_t size) { + return GpuDriver::SynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), + host_src, size); +} + +absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + return GpuDriver::SynchronousMemcpyD2H(context_, host_dst, + AsROCmDevicePtr(gpu_src), size); +} + +absl::Status GpuExecutor::SynchronousMemcpyDeviceToDevice( + DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { + return GpuDriver::SynchronousMemcpyD2D(context_, AsROCmDevicePtr(gpu_dst), + AsROCmDevicePtr(gpu_src), size); +} + +absl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) { + if (reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0) { + return Memset32(stream, location, 0x0, size); + } else { + return Memset(stream, location, 0x0, size); + } +} + +absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, + uint8 pattern, uint64_t size) { + VLOG(2) << "enqueueing memset8 operation onto stream " << stream + << " at location " << location << " with size " << size + << " and pattern " << std::hex << pattern; + return GpuDriver::AsynchronousMemsetUint8(context_, AsROCmDevicePtr(location), + pattern, size, + AsGpuStreamValue(stream)); +} + +absl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, + uint32 pattern, uint64_t size) { + VLOG(2) << "enqueueing memset32 operation onto stream " << stream + << " at location " << location << " with size " << size + << " and pattern " << std::hex << pattern; + CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && + size % 4 == 0); + return GpuDriver::AsynchronousMemsetUint32( + context_, AsROCmDevicePtr(location), pattern, size / 4, + AsGpuStreamValue(stream)); +} + +absl::Status GpuExecutor::Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, + AsROCmDevicePtr(gpu_src), size, + AsGpuStreamValue(stream)); + + // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return Status. + if (!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); +} + +absl::Status GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) { + bool ok = GpuDriver::AsynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), + host_src, size, + AsGpuStreamValue(stream)); + // TODO(b/326130105): Change AsynchronousMemcpyD2H calls to return Status. + if (!ok) { + return absl::InternalError("Failed to memcpy from device to host."); + } + return absl::OkStatus(); +} + +bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, + DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + return GpuDriver::AsynchronousMemcpyD2D(context_, AsROCmDevicePtr(gpu_dst), + AsROCmDevicePtr(gpu_src), size, + AsGpuStreamValue(stream)); +} + +bool GpuExecutor::HostCallback(Stream* stream, + absl::AnyInvocable callback) { + auto callback_ptr = + new absl::AnyInvocable([cb = std::move(callback)]() mutable { + absl::Status s = std::move(cb)(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); + return GpuDriver::AddStreamCallback(context_, AsGpuStreamValue(stream), + InternalHostCallback, callback_ptr); +} + +/* static */ void GpuExecutor::InternalHostCallback(void* data) { + auto* callback = reinterpret_cast*>(data); + std::move (*callback)(); + delete callback; +} + +absl::Status GpuExecutor::AllocateEvent(Event* event) { + return AsGpuEvent(event)->Init(); +} + +absl::Status GpuExecutor::DeallocateEvent(Event* event) { + return AsGpuEvent(event)->Destroy(); +} + +absl::Status GpuExecutor::RecordEvent(Stream* stream, Event* event) { + return AsGpuEvent(event)->Record(AsGpuStream(stream)); +} + +absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { + if (GpuDriver::WaitStreamOnEvent(context_, AsGpuStream(stream)->gpu_stream(), + AsGpuEvent(event)->gpu_event())) { + return absl::OkStatus(); + } else { + return absl::Status{ + absl::StatusCode::kInternal, + absl::StrFormat("error recording waiting for ROCM event on stream %p", + stream)}; + } +} + +absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, + Event* event) { + if (GpuDriver::WaitStreamOnEvent(context_, + absl::bit_cast(stream), + AsGpuEvent(event)->gpu_event())) { + return absl::OkStatus(); + } else { + return absl::InternalError( + "error waiting for ROCM event on external stream"); + } +} + +Event::Status GpuExecutor::PollForEventStatus(Event* event) { + return AsGpuEvent(event)->PollForStatus(); +} + +bool GpuExecutor::AllocateStream(Stream* stream) { + absl::MutexLock l(&alive_gpu_streams_mu_); + bool out = AsGpuStream(stream)->Init(); + alive_gpu_streams_[stream->platform_specific_handle().stream] = stream; + return out; +} + +void GpuExecutor::DeallocateStream(Stream* stream) { + GpuStream* rocm_stream = AsGpuStream(stream); + absl::MutexLock l(&alive_gpu_streams_mu_); + alive_gpu_streams_.erase(rocm_stream->platform_specific_stream()); + if (!rocm_stream->IsIdle()) { + LOG(ERROR) << "Deallocating stream with pending work"; + } + rocm_stream->Destroy(); +} + +bool GpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { + GpuEventHandle other_completed_event = *AsGpuStream(other)->completed_event(); + bool ok = GpuDriver::RecordEvent(context_, other_completed_event, + AsGpuStreamValue(other)) + .ok(); + if (!ok) { + LOG(ERROR) << "failed to record completion event; " + "therefore, failed to create inter-stream dependency"; + return false; + } + + return GpuDriver::WaitStreamOnEvent(context_, AsGpuStreamValue(dependent), + other_completed_event); +} + +absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { + return GpuDriver::SynchronizeStream(context_, AsGpuStreamValue(stream)); +} + +blas::BlasSupport* GpuExecutor::CreateBlas() { + PluginRegistry* registry = PluginRegistry::Instance(); + absl::StatusOr status = + registry->GetFactory(rocm::kROCmPlatformId); + if (!status.ok()) { + LOG(ERROR) << "Unable to retrieve BLAS factory: " + << status.status().message(); + return nullptr; + } + + return status.value()(this); +} + +dnn::DnnSupport* GpuExecutor::CreateDnn() { + PluginRegistry* registry = PluginRegistry::Instance(); + absl::StatusOr status = + registry->GetFactory(rocm::kROCmPlatformId); + if (!status.ok()) { + LOG(ERROR) << "Unable to retrieve DNN factory: " + << status.status().message(); + return nullptr; + } + + return status.value()(this); +} + +fft::FftSupport* GpuExecutor::CreateFft() { + PluginRegistry* registry = PluginRegistry::Instance(); + absl::StatusOr status = + registry->GetFactory(rocm::kROCmPlatformId); + if (!status.ok()) { + LOG(ERROR) << "Unable to retrieve FFT factory: " + << status.status().message(); + return nullptr; + } + + return status.value()(this); +} + +bool GpuExecutor::CanEnablePeerAccessTo(StreamExecutorInterface* other) { + GpuExecutor* rocm_other = static_cast(other); + return GpuDriver::CanEnablePeerAccess(context_, rocm_other->context_); +} + +absl::Status GpuExecutor::EnablePeerAccessTo(StreamExecutorInterface* other) { + GpuExecutor* rocm_other = static_cast(other); + return GpuDriver::EnablePeerAccess(context_, rocm_other->context_); +} + +bool GpuExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { + return GpuDriver::GetDeviceMemoryInfo(context_, free, total); +} + +bool GpuExecutor::GetSymbol(const string& symbol_name, + ModuleHandle module_handle, void** mem, + size_t* bytes) { + absl::MutexLock lock{&in_memory_modules_mu_}; + if (static_cast(module_handle)) { + auto it = gpu_binary_to_module_.find(module_handle.id()); + CHECK(it != gpu_binary_to_module_.end()); + if (GpuDriver::GetModuleSymbol( + context_, it->second.first, symbol_name.c_str(), + reinterpret_cast(mem), bytes)) { + return true; + } + } + + for (auto& it : gpu_binary_to_module_) { + if (GpuDriver::GetModuleSymbol( + context_, it.second.first, symbol_name.c_str(), + reinterpret_cast(mem), bytes)) { + return true; + } + } + + LOG(INFO) << "Falied to find symbol in any modules: " << symbol_name; + return false; +} + +absl::Status FillBlockDimLimit(GpuDeviceHandle device, + BlockDim* block_dim_limit) { + // The BlockDim name is a mismatch against these GRID_DIM_* queries because + // we use BlockDims to express the dimensions of blocks within a grid + // (as opposed to ThreadDim which expresses the dimensions of threads + // within a block). + int x, y, z; + TF_RETURN_IF_ERROR(GpuDriver::GetGridLimits(&x, &y, &z, device)); + + block_dim_limit->x = x; + block_dim_limit->y = y; + block_dim_limit->z = z; + return absl::OkStatus(); +} + +std::unique_ptr +GpuExecutor::CreateEventImplementation() { + return std::unique_ptr(new GpuEvent(this)); +} + +std::unique_ptr +GpuExecutor::GetStreamImplementation() { + return std::unique_ptr(new GpuStream(this)); +} + +absl::StatusOr> GpuExecutor::CreateKernel() { + return std::make_unique(this); +} + +absl::StatusOr> GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode) { + VLOG(2) << "Create ROCm command buffer (ROCm graph)"; + GpuGraphHandle graph = nullptr; + TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); + return std::make_unique(mode, /*parent=*/this, graph); +} + +std::unique_ptr GpuExecutor::CreateCommandBuffer( + CommandBuffer::Mode mode, GpuGraphHandle graph, bool is_owned_graph) { + VLOG(2) << "Create HIP command buffer (HIP graph) from existing graph " + << graph << "; is_owned_graph=" << is_owned_graph; + return std::make_unique(mode, /*parent=*/this, graph, + is_owned_graph); +} + +GpuContext* GpuExecutor::gpu_context() { return context_; } + +// Attempts to read the NUMA node corresponding to the GPU device's PCI bus out +// of SysFS. Returns -1 if it cannot. +// +// For anything more complicated/prod-focused than this, you'll likely want to +// turn to gsys' topology modeling. +static int TryToReadNumaNode(const string& pci_bus_id, int device_ordinal) { + VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal; + static const int kUnknownNumaNode = -1; + + if (pci_bus_id.empty()) { + LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal; + return kUnknownNumaNode; + } + + std::string filename = + absl::StrFormat("/sys/bus/pci/devices/%s/numa_node", pci_bus_id); + + // We have to use fopen/fread here so that the device properties can be + // populated before InitGoogle procedure has been completed (at which point we + // could use the file::* utilities). + FILE* file = fopen(filename.c_str(), "r"); + if (file == nullptr) { + LOG(INFO) << "could not open file to read NUMA node: " << filename + << "\nYour kernel may have been built without NUMA support."; + return kUnknownNumaNode; + } + + std::string content; + char buf[32]; + size_t did_read = fread(buf, sizeof(buf[0]), sizeof(buf) - 1, file); + buf[did_read] = '\0'; + content = buf; + + int32_t value; + if (absl::SimpleAtoi(content, &value)) { + if (value < 0) { // See http://b/18228951 for details on this path. + LOG(INFO) << "successful NUMA node read from SysFS had negative value (" + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero"; + fclose(file); + return 0; + } + fclose(file); + return value; + } + + LOG(WARNING) + << "could not convert SysFS file contents to integral NUMA node value: " + << content; + + fclose(file); + return kUnknownNumaNode; +} + +absl::StatusOr> +GpuExecutor::CreateDeviceDescription(int device_ordinal) { + GpuDeviceHandle device; + auto status = GpuDriver::GetDevice(device_ordinal, &device); + if (!status.ok()) { + return status; + } + + int version; + status = GpuDriver::GetGpuISAVersion(&version, device); + if (!status.ok()) { + return status; + } + + std::string gcn_arch_name; + status = GpuDriver::GetGpuGCNArchName(device, &gcn_arch_name); + if (!status.ok()) { + return status; + } + + internal::DeviceDescriptionBuilder builder; + + { + int version = GpuDriver::GetDriverVersion().value_or(-1); + string augmented_driver_version = absl::StrFormat( + "%d (%s)", version, + rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()) + .c_str()); + builder.set_driver_version(augmented_driver_version); + } + + { + string pci_bus_id = GpuDriver::GetPCIBusID(device); + + // Lower the hex characters to match sysfs. + pci_bus_id = absl::AsciiStrToLower(pci_bus_id); + builder.set_pci_bus_id(pci_bus_id); + + // Read the NUMA node corresponding to the PCI bus ID out of sysfs. + int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal); + builder.set_numa_node(numa_node); + } + + hipDeviceProp_t prop; + if (GpuDriver::GetDeviceProperties(&prop, device_ordinal)) { + builder.set_threads_per_block_limit(prop.maxThreadsPerBlock); + + ThreadDim thread_dim_limit; + thread_dim_limit.x = prop.maxThreadsDim[0]; + thread_dim_limit.y = prop.maxThreadsDim[1]; + thread_dim_limit.z = prop.maxThreadsDim[2]; + builder.set_thread_dim_limit(thread_dim_limit); + + float clock_rate_ghz = static_cast(prop.clockRate) / 1e6; + builder.set_clock_rate_ghz(clock_rate_ghz); + + // mem_bandwidth = 2 * mem_bus_width_in_bytes * mem_clock_rate_in_hz + int64_t memory_bandwidth = 2 * (int64_t(prop.memoryBusWidth) / 8) * + (int64_t(prop.memoryClockRate) * 1000); + builder.set_memory_bandwidth(memory_bandwidth); + + builder.set_l2_cache_size(prop.l2CacheSize); + } + + { + bool ecc_enabled = false; + (void)GpuDriver::IsEccEnabled(device, &ecc_enabled); + builder.set_ecc_enabled(ecc_enabled); + } + + uint64_t device_memory_size = -1; + (void)GpuDriver::GetDeviceTotalMemory(device, &device_memory_size); + builder.set_device_memory_size(device_memory_size); + + { + BlockDim block_dim_limit; + TF_RETURN_IF_ERROR(FillBlockDimLimit(device, &block_dim_limit)); + builder.set_block_dim_limit(block_dim_limit); + } + + { + string device_name; + TF_RETURN_IF_ERROR(GpuDriver::GetDeviceName(device, &device_name)); + builder.set_name(device_name); + } + + builder.set_platform_version( + absl::StrCat("AMDGPU ISA version: ", gcn_arch_name)); + + // TODO(leary) should be a way to query this from the driver, but this is + // unlikely to change for us any time soon. + builder.set_device_address_bits(64); + + builder.set_device_vendor("Advanced Micro Devices, Inc"); + builder.set_rocm_compute_capability(gcn_arch_name); + + builder.set_shared_memory_per_core( + GpuDriver::GetMaxSharedMemoryPerCore(device).value()); + builder.set_shared_memory_per_block( + GpuDriver::GetMaxSharedMemoryPerBlock(device).value()); + int core_count = GpuDriver::GetMultiprocessorCount(device).value(); + builder.set_core_count(core_count); + builder.set_fpus_per_core(fpus_per_core(gcn_arch_name)); + builder.set_threads_per_core_limit( + GpuDriver::GetMaxThreadsPerMultiprocessor(device).value()); + builder.set_registers_per_block_limit( + GpuDriver::GetMaxRegistersPerBlock(device).value()); + builder.set_threads_per_warp(GpuDriver::GetThreadsPerWarp(device).value()); + builder.set_registers_per_core_limit(64 * 1024); + + int cc_major = 0; + int cc_minor = 0; + GpuDriver::GetComputeCapability(&cc_major, &cc_minor, device).IgnoreError(); + + // It would be better to use the PCI device ID or some other truly unique + // identifier for the GPU model. But getting this requires using NVML or + // other hacks, which we don't have access to in OSS TensorFlow. + // + // Alternatively you might be tempted to use GpuDriver::GetDeviceName as a + // unique identifier, but this is not stable across GPU VBIOS versions. + // + // TODO(jlebar): This really should be more unique. In CUDA land, we mix in + // the clock speed and L2 cache size. + builder.set_model_str(absl::StrFormat("cc_%d.%d with %dB RAM, %d cores", + cc_major, cc_minor, device_memory_size, + core_count)); + + return builder.Build(); +} + +} // namespace gpu + +} // namespace stream_executor + +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(rocm_executor, {}); diff --git a/xla/stream_executor/rocm/rocm_fft.cc b/xla/stream_executor/rocm/rocm_fft.cc index 5dc336e996e30..9d2f9b30899ce 100644 --- a/xla/stream_executor/rocm/rocm_fft.cc +++ b/xla/stream_executor/rocm/rocm_fft.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -151,7 +151,7 @@ bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) { } // namespace -tsl::Status ROCMFftPlan::Initialize( +absl::Status ROCMFftPlan::Initialize( GpuExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, uint64 input_stride, uint64 input_distance, uint64_t *output_embed, uint64 output_stride, uint64 output_distance, @@ -183,20 +183,20 @@ tsl::Status ROCMFftPlan::Initialize( ROCMFftType(type), 1 /* = batch */); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT 1d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT 1d plan."}; } - return tsl::OkStatus(); + return absl::OkStatus(); case 2: // hipfftPlan2d ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0], elem_count_[1], ROCMFftType(type)); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT 2d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT 2d plan."}; } - return tsl::OkStatus(); + return absl::OkStatus(); case 3: // hipfftPlan3d ret = @@ -204,29 +204,29 @@ tsl::Status ROCMFftPlan::Initialize( elem_count_[2], ROCMFftType(type)); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT 3d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT 3d plan."}; } - return tsl::OkStatus(); + return absl::OkStatus(); default: LOG(ERROR) << "Invalid rank value for hipfftPlan. " "Requested 1, 2, or 3, given: " << rank; - return tsl::Status{absl::StatusCode::kInvalidArgument, - "hipfftPlan only takes rank 1, 2, or 3."}; + return absl::Status{absl::StatusCode::kInvalidArgument, + "hipfftPlan only takes rank 1, 2, or 3."}; } } else { ret = wrap::hipfftCreate(parent, &plan_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT plan."}; } ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to set auto allocation for rocFFT plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to set auto allocation for rocFFT plan."}; } switch (rank) { case 1: @@ -235,8 +235,8 @@ tsl::Status ROCMFftPlan::Initialize( &scratch_size_bytes_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to make rocFFT 1d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to make rocFFT 1d plan."}; } break; case 2: @@ -245,8 +245,8 @@ tsl::Status ROCMFftPlan::Initialize( &scratch_size_bytes_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to make rocFFT 2d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to make rocFFT 2d plan."}; } break; case 3: @@ -255,16 +255,16 @@ tsl::Status ROCMFftPlan::Initialize( ROCMFftType(type), &scratch_size_bytes_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to make rocFFT 3d plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to make rocFFT 3d plan."}; } break; default: LOG(ERROR) << "Invalid rank value for hipfftPlan. " "Requested 1, 2, or 3, given: " << rank; - return tsl::Status{absl::StatusCode::kInvalidArgument, - "hipfftPlan only takes rank 1, 2, or 3."}; + return absl::Status{absl::StatusCode::kInvalidArgument, + "hipfftPlan only takes rank 1, 2, or 3."}; } return UpdateScratchAllocator(stream, scratch_allocator); } @@ -278,21 +278,21 @@ tsl::Status ROCMFftPlan::Initialize( output_distance, ROCMFftType(type), batch_count); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT batched plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT batched plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT batched plan."}; } } else { auto ret = wrap::hipfftCreate(parent, &plan_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to create rocFFT batched plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to create rocFFT batched plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to create rocFFT batched plan."}; } ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:" << ret; - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, "Failed to set auto allocation for rocFFT batched plan."}; } @@ -304,19 +304,19 @@ tsl::Status ROCMFftPlan::Initialize( &scratch_size_bytes_); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to make rocFFT batched plan:" << ret; - return tsl::Status{absl::StatusCode::kInternal, - "Failed to make rocFFT batched plan."}; + return absl::Status{absl::StatusCode::kInternal, + "Failed to make rocFFT batched plan."}; } return UpdateScratchAllocator(stream, scratch_allocator); } } - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream, - int rank, uint64_t *elem_count, - fft::Type type, - ScratchAllocator *scratch_allocator) { +absl::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream, + int rank, uint64_t *elem_count, + fft::Type type, + ScratchAllocator *scratch_allocator) { return Initialize(parent_, stream, rank, elem_count, /*input_embed=*/nullptr, /*input_stride=*/0, /*input_distance=*/0, @@ -324,7 +324,7 @@ tsl::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream, /*output_distance=*/0, type, 1, scratch_allocator); } -tsl::Status ROCMFftPlan::UpdateScratchAllocator( +absl::Status ROCMFftPlan::UpdateScratchAllocator( Stream *stream, ScratchAllocator *scratch_allocator) { scratch_allocator_ = scratch_allocator; if (scratch_size_bytes_ != 0) { @@ -338,10 +338,9 @@ tsl::Status ROCMFftPlan::UpdateScratchAllocator( auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque()); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret; - return tsl::Status(absl::StatusCode::kInternal, - "Failed to set work area for rocFFT plan."); + return absl::InternalError("Failed to set work area for rocFFT plan."); } - return tsl::OkStatus(); + return absl::OkStatus(); } ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); } @@ -367,121 +366,13 @@ int ROCMFftPlan::GetFftDirection() const { } } -std::unique_ptr ROCMFft::Create1dPlan(Stream *stream, uint64_t num_x, - fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[1] = {num_x}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, type, - /*scratch_allocator=*/nullptr); - // TODO(yangzihao): In the future, send error msg back to TensorFlow - // so it can fail gracefully, - if (!status.ok()) { - LOG(FATAL) << "failed to initialize hipfft 1d plan: " << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::Create1dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, fft::Type type, bool in_place_fft, - ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[1] = {num_x}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(FATAL) - << "failed to initialize hipfft 1d plan with customized allocator: " - << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::Create2dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[2] = {num_x, num_y}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, type, - /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(FATAL) << "failed to initialize hipfft 2d plan: " << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::Create2dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, fft::Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[2] = {num_x, num_y}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(FATAL) - << "failed to initialize hipfft 2d plan with customized allocator: " - << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::Create3dPlan(Stream *stream, uint64_t num_x, - uint64_t num_y, uint64 num_z, - fft::Type type, - bool in_place_fft) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[3] = {num_x, num_y, num_z}; - tsl::Status status = - fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, type, - /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(FATAL) << "failed to initialize hipfft 3d plan: " << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::Create3dPlanWithScratchAllocator( - Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, fft::Type type, - bool in_place_fft, ScratchAllocator *scratch_allocator) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - uint64_t elem_count[3] = {num_x, num_y, num_z}; - tsl::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, - type, scratch_allocator); - if (!status.ok()) { - LOG(FATAL) - << "failed to initialize hipfft 3d plan with customized allocator: " - << status.message(); - } - return std::move(fft_plan_ptr); -} - -std::unique_ptr ROCMFft::CreateBatchedPlan( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, fft::Type type, - bool in_place_fft, int batch_count) { - std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - tsl::Status status = fft_plan_ptr->Initialize( - parent_, stream, rank, elem_count, input_embed, input_stride, - input_distance, output_embed, output_stride, output_distance, type, - batch_count, /*scratch_allocator=*/nullptr); - if (!status.ok()) { - LOG(FATAL) << "failed to initialize batched hipfft plan: " - << status.message(); - } - - return std::move(fft_plan_ptr); -} - std::unique_ptr ROCMFft::CreateBatchedPlanWithScratchAllocator( Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, uint64_t input_stride, uint64 input_distance, uint64 *output_embed, uint64_t output_stride, uint64 output_distance, fft::Type type, bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) { std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; - tsl::Status status = fft_plan_ptr->Initialize( + absl::Status status = fft_plan_ptr->Initialize( parent_, stream, rank, elem_count, input_embed, input_stride, input_distance, output_embed, output_stride, output_distance, type, batch_count, scratch_allocator); @@ -496,7 +387,7 @@ std::unique_ptr ROCMFft::CreateBatchedPlanWithScratchAllocator( void ROCMFft::UpdatePlanWithScratchAllocator( Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) { ROCMFftPlan *rocm_fft_plan = dynamic_cast(plan); - tsl::Status status = + absl::Status status = rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator); if (!status.ok()) { LOG(FATAL) << "failed to update custom allocator for hipfft plan: " @@ -533,7 +424,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec, if (allocator) { auto allocated = allocator->AllocateBytes(input.size()); if (allocated.ok()) { - if (stream->ThenMemcpy(&allocated.value(), input, input.size()).ok()) { + if (stream->Memcpy(&allocated.value(), input, input.size()).ok()) { input_maybe_copy = DeviceMemory(allocated.value()); } else { LOG(ERROR) << "failed to copy input buffer for rocFFT."; @@ -615,7 +506,7 @@ void initialize_rocfft() { rocm::kROCmPlatformId, PluginKind::kFft); if (!rocFftAlreadyRegistered) { - tsl::Status status = + absl::Status status = PluginRegistry::Instance()->RegisterFactory( rocm::kROCmPlatformId, "rocFFT", [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * { @@ -638,5 +529,6 @@ void initialize_rocfft() { } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_rocfft, - { stream_executor::initialize_rocfft(); }); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_rocfft, { + stream_executor::initialize_rocfft(); +}); diff --git a/xla/stream_executor/rocm/rocm_fft.h b/xla/stream_executor/rocm/rocm_fft.h index 330237656f2f7..8b55b4b4446e2 100644 --- a/xla/stream_executor/rocm/rocm_fft.h +++ b/xla/stream_executor/rocm/rocm_fft.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,20 +72,20 @@ class ROCMFftPlan : public fft::Plan { } // Initialize function for batched plan - tsl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, - uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, - uint64_t *output_embed, uint64 output_stride, - uint64_t output_distance, fft::Type type, - int batch_count, ScratchAllocator *scratch_allocator); + absl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, + uint64_t *elem_count, uint64 *input_embed, + uint64_t input_stride, uint64 input_distance, + uint64_t *output_embed, uint64 output_stride, + uint64_t output_distance, fft::Type type, + int batch_count, ScratchAllocator *scratch_allocator); // Initialize function for 1d,2d, and 3d plan - tsl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, - uint64_t *elem_count, fft::Type type, - ScratchAllocator *scratch_allocator); + absl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, + uint64_t *elem_count, fft::Type type, + ScratchAllocator *scratch_allocator); - tsl::Status UpdateScratchAllocator(Stream *stream, - ScratchAllocator *scratch_allocator); + absl::Status UpdateScratchAllocator(Stream *stream, + ScratchAllocator *scratch_allocator); ScratchAllocator *GetScratchAllocator() const { return scratch_allocator_; } diff --git a/xla/stream_executor/rocm/rocm_gpu_executor.cc b/xla/stream_executor/rocm/rocm_gpu_executor.cc deleted file mode 100644 index 82e1c71186bd7..0000000000000 --- a/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ /dev/null @@ -1,963 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include - -#include "absl/base/casts.h" -#include "absl/functional/any_invocable.h" -#include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "xla/stream_executor/gpu/gpu_command_buffer.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/rocm/rocm_diagnostics.h" -#include "xla/stream_executor/rocm/rocm_platform_id.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -#ifdef PLATFORMS_GPUS_ROCM_DYNAMIC_LIBROCM_DYNAMIC_LIBROCM_H_ -#error \ - "No driver calls in this file, wrap driver functionality in rocm_driver.cc." -#endif - -#ifdef __ROCM_RUNTIME_H__ -#error \ - "ROCM runtime being included into ROCM GPU executor; should be driver only." -#endif - -namespace stream_executor { -namespace gpu { - -static GpuEvent* AsGpuEvent(Event* event) { - DCHECK(event != nullptr); - return static_cast(event->implementation()); -} - -// Given const GPU memory, returns a librocm device pointer datatype, suitable -// for passing directly to librocm APIs. -// -// N.B. we must lose constness in order to pass a suitable type to the existing -// librocm APIs, so the caller should take care to only pass the result of const -// GPU memory conversions to librocm functions which will honor constness. -static hipDeviceptr_t AsROCmDevicePtr(const DeviceMemoryBase& gpu_mem) { - return const_cast(gpu_mem.opaque()); -} - -// See description on const version above. -static hipDeviceptr_t AsROCmDevicePtr(DeviceMemoryBase* gpu_mem) { - return AsROCmDevicePtr(*gpu_mem); -} - -static GpuContext* GetGpuContext(Stream* stream) { - return static_cast(stream->parent()->implementation()) - ->gpu_context(); -} - -GpuContext* ExtractGpuContext(GpuExecutor* rocm_exec) { - CHECK(rocm_exec != nullptr); - return rocm_exec->gpu_context(); -} - -GpuExecutor::~GpuExecutor() { - for (auto& it : disk_modules_) { - GpuDriver::UnloadModule(context_, it.second); - } - for (auto& it : in_memory_modules_) { - GpuDriver::UnloadModule(context_, it.second); - } - if (context_ != nullptr) { - GpuDriver::DestroyContext(context_); - } - CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels."; - CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules."; -} -bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { - const char* gpu_binary = reinterpret_cast(module_handle.id()); - absl::MutexLock lock{&in_memory_modules_mu_}; - return UnloadGpuBinary(gpu_binary); -} - -namespace { -int fpus_per_core(std::string gcn_arch_name) { - // Source: - // https://www.amd.com/content/dam/amd/en/documents/instinct-business-docs/white-papers/amd-cdna2-white-paper.pdf - int n = 128; // gfx90a and gfx908 -> 128 - if (gcn_arch_name.substr(0, 6) == "gfx906") { - n = 64; - } - return n; -} -} // namespace - -tsl::StatusOr> -GpuExecutor::CreateOrShareConstant(Stream* stream, - absl::Span content) { - return tsl::errors::Unimplemented("Not implemented for ROCm"); -} - -bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { - auto module_it = gpu_binary_to_module_.find(gpu_binary); - if (gpu_binary_to_module_.end() == module_it) { - VLOG(3) << "No loaded HSACO module for " << gpu_binary; - return false; - } - auto& module = module_it->second.first; - auto& refcount = module_it->second.second; - VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount; - if (--refcount == 0) { - VLOG(3) << "Unloading HSACO module " << module; - GpuDriver::UnloadModule(context_, module); - gpu_binary_to_module_.erase(module_it); - const char* mem_it = nullptr; - for (auto x : in_memory_modules_) { - if (x.second == module) mem_it = x.first; - } - if (mem_it != nullptr) in_memory_modules_.erase(mem_it); - } - return true; -} - -void GpuExecutor::UnloadKernel(const Kernel* kernel) { - VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name(); - - absl::MutexLock lock{&in_memory_modules_mu_}; - auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel); - if (kernel_to_gpu_binary_.end() == gpu_binary_it) { - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has never been loaded."; - return; // We've never seen this kernel. - } - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has loaded GPU code " << gpu_binary_it->second; - UnloadGpuBinary(gpu_binary_it->second); - kernel_to_gpu_binary_.erase(gpu_binary_it); -} - -tsl::Status GpuExecutor::Init(int device_ordinal, - DeviceOptions device_options) { - device_ordinal_ = device_ordinal; - - auto status = GpuDriver::Init(); - if (!status.ok()) { - return status; - } - - status = GpuDriver::GetDevice(device_ordinal_, &device_); - if (!status.ok()) { - return status; - } - - status = GpuDriver::CreateContext(device_ordinal_, device_, device_options, - &context_); - if (!status.ok()) { - return status; - } - - return GpuDriver::GetGpuISAVersion(&version_, device_); -} - -// Returns the path to the running executable. -// N.B. Derived from //knowledge/smalltalk/background_kb.cc -// Arg: strip_exe: if true, remove the name of the executable itself from the -// returned string. Example: calling this from /usr/bin/foo -// would return /usr/bin. -static string GetBinaryDir(bool strip_exe) { - char exe_path[PATH_MAX] = {0}; - CHECK_NE(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1), -1); - // Make sure it's null-terminated: - exe_path[sizeof(exe_path) - 1] = 0; - - if (strip_exe) { - // The exe is the last component of the path, so remove one component. - string ret = exe_path; - std::vector components = absl::StrSplit(exe_path, '/'); - components.pop_back(); - return absl::StrJoin(components, "/"); - } - return exe_path; -} - -tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - GpuKernel* rocm_kernel = AsGpuKernel(kernel); - hipModule_t module = nullptr; - const string* kernel_name; - - if (spec.has_cuda_cubin_on_disk()) { - return tsl::errors::Internal( - "Loading ROCM kernel from disk is not supported"); - } else if (spec.has_cuda_cubin_in_memory()) { - kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); - - const char* hsaco = spec.cuda_cubin_in_memory().bytes(); - absl::MutexLock lock{&in_memory_modules_mu_}; - module = in_memory_modules_[hsaco]; - - if (module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); - } - kernel_to_gpu_binary_[kernel] = hsaco; - } else if (spec.has_in_process_symbol()) { - kernel_name = &spec.in_process_symbol().kernel_name(); - void* symbol = spec.in_process_symbol().symbol(); - - VLOG(1) << "Resolve ROCM kernel " << *kernel_name - << " from symbol pointer: " << symbol; - - *rocm_kernel->gpu_function_ptr() = - static_cast(spec.in_process_symbol().symbol()); - } else { - return tsl::errors::Internal("No method of loading ROCM kernel provided"); - } - - // If we resolved kernel from a symbol pointer, there is no need to load it - // from a module, as ROCm runtime did that automatically for us. - if (!spec.has_in_process_symbol()) { - VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR( - GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), - rocm_kernel->gpu_function_ptr())); - } - - // We have to trust the kernel loader spec arity because there doesn't appear - // to be a way to reflect on the number of expected arguments w/the ROCM API. - rocm_kernel->set_arity(spec.arity()); - - // unable to get kernel metadata for in-process kernel - if (!spec.has_in_process_symbol()) { - KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); - } - kernel->set_name(*kernel_name); - kernel->set_kernel_args_packing(spec.kernel_args_packing()); - return tsl::OkStatus(); -} - -tsl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, - KernelMetadata* kernel_metadata) { - int value = 0; - TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - HIP_FUNC_ATTRIBUTE_NUM_REGS, *rocm_kernel->gpu_function_ptr(), &value)); - kernel_metadata->set_registers_per_thread(value); - - TF_RETURN_IF_ERROR( - GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - *rocm_kernel->gpu_function_ptr(), &value)); - kernel_metadata->set_shared_memory_bytes(value); - return ::tsl::OkStatus(); -} - -tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { - CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0), - args.number_of_arguments()); - GpuStreamHandle hipstream = AsGpuStreamValue(stream); - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); - - // Only perform/print the occupancy check once. Even just checking to see - // whether we've done an occupancy check on this kernel before isn't free - // (because we have to synchronize), so we only do this at -v 2+. - if (VLOG_IS_ON(2)) { - absl::MutexLock lock(&launched_kernels_mu_); - if (!launched_kernels_.count(hipfunc)) { - VlogOccupancyInfo(kernel, thread_dims, block_dims); - // TODO(rspringer): Remove elements from launched_kernels_...if we ever - // expose a kernel/module deallocation method. - launched_kernels_.insert(hipfunc); - } - } - - if (rocm_kernel->GetPreferredCacheConfig() != - KernelCacheConfig::kNoPreference) { - TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( - hipfunc, rocm_kernel->GetGpuCacheConfig())); - } - - auto* packed_args = DynCast(&args); - if (!packed_args) - return absl::InternalError("Unsupported kernel arguments type"); - - void** kernel_params = - const_cast(packed_args->argument_addresses().data()); - - return GpuDriver::LaunchKernel( - GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), hipstream, kernel_params, nullptr); -} - -tsl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { - return absl::InvalidArgumentError( - "Can't submit non-primary command buffer for execution"); - } - - auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer execuable graph " << exec - << " on a stream: " << stream->DebugStreamPointers(); - return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); -} - -int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - LOG(FATAL) << "Feature not supported on ROCM platform (CalculateOccupancy)"; - return 0; -} - -int GpuExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64_t registers_per_thread, - uint64_t shared_memory_per_block, - const ThreadDim& thread_dims, - GpuFunctionHandle func) { - LOG(FATAL) << "Feature not supported on ROCM platform (CompareOccupancy)"; - return 0; -} - -tsl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { - // In GpuExecutor we store the pointer to the HSACO binary as - // ModuleHandle::id(). - hipModule_t hip_module = nullptr; - // TODO(ROCm): Need generic term instead of cubin/cuda/ptx - if (spec.has_cuda_cubin_in_memory()) { - absl::MutexLock lock{&in_memory_modules_mu_}; - TF_RETURN_IF_ERROR(LoadModuleFromHsaco( - reinterpret_cast(spec.cuda_cubin_in_memory().data()), - &hip_module)); - *module_handle = ModuleHandle(const_cast( - static_cast(spec.cuda_cubin_in_memory().data()))); - return tsl::OkStatus(); - } else { - return tsl::errors::Internal("No HASCO binary found"); - } -} - -tsl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, - hipModule_t* module) { - LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromCuBin)"; -} - -tsl::Status GpuExecutor::LoadModuleFromPtx(const char* ptx, - hipModule_t* module) { - LOG(FATAL) << "Feature not supported on ROCM platform (LoadModuleFromPtx)"; -} - -tsl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, - hipModule_t* module) { - uint64_t module_refcount; - std::tie(*module, module_refcount) = gpu_binary_to_module_[hsaco]; - - if (*module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, module)); - module_refcount = 1; - in_memory_modules_[hsaco] = *module; - VLOG(3) << "Loaded HSACO " << static_cast(hsaco) - << " as module " << *module; - } else { - ++module_refcount; - VLOG(3) << "HSACO " << static_cast(hsaco) - << " is already loaded as module " << *module; - } - gpu_binary_to_module_[hsaco] = {*module, module_refcount}; - return tsl::OkStatus(); -} - -// This is a non-essential operation; if there's a failure, proceed without -// logging an error. It's nearly certain that in case of failures, we'd never -// get here in the first place; these are very low-impact routines. -void GpuExecutor::VlogOccupancyInfo(const Kernel& kernel, - const ThreadDim& thread_dims, - const BlockDim& block_dims) { - // TODO(ROCm) implement this feature in HIP -} - -DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { - CHECK_EQ(memory_space, 0); - return DeviceMemoryBase(GpuDriver::DeviceAllocate(context_, size), size); -} - -void* GpuExecutor::GetSubBuffer(DeviceMemoryBase* mem, uint64_t offset_bytes, - uint64_t size_bytes) { - // offset and size are in bytes, so char* works as the pointer type. - return reinterpret_cast(mem->opaque()) + offset_bytes; -} - -void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { - GpuDriver::DeviceDeallocate(context_, mem->opaque()); -} - -bool GpuExecutor::HostMemoryRegister(void* location, uint64_t size) { - if (location == nullptr || size == 0) { - LOG(WARNING) << "attempting to register null or zero-sized memory: " - << location << "; size " << size; - } - VLOG(2) << "registering " << location << " size " << size; - return GpuDriver::HostRegister(context_, location, size); -} - -bool GpuExecutor::HostMemoryUnregister(void* location) { - VLOG(2) << "unregistering " << location; - return GpuDriver::HostUnregister(context_, location); -} - -bool GpuExecutor::SynchronizeAllActivity() { - return GpuDriver::SynchronizeContext(context_); -} - -tsl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return GpuDriver::SynchronousMemsetUint32( - context_, AsROCmDevicePtr(location), 0x0, size / 4); - } - return GpuDriver::SynchronousMemsetUint8(context_, AsROCmDevicePtr(location), - 0x0, size); -} - -tsl::Status GpuExecutor::SynchronousMemSet(DeviceMemoryBase* location, - int value, uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - // hipMemset reinterprets "value" as a uint8. - uint8 byte_value = static_cast(value); - uint32 pattern = (byte_value << 24) | (byte_value << 16) | - (byte_value << 8) | byte_value; - return GpuDriver::SynchronousMemsetUint32( - context_, AsROCmDevicePtr(location), pattern, size / 4); - } - return GpuDriver::SynchronousMemsetUint8(context_, AsROCmDevicePtr(location), - value, size); -} - -tsl::Status GpuExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, - uint64_t size) { - return GpuDriver::SynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), - host_src, size); -} - -tsl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - return GpuDriver::SynchronousMemcpyD2H(context_, host_dst, - AsROCmDevicePtr(gpu_src), size); -} - -tsl::Status GpuExecutor::SynchronousMemcpyDeviceToDevice( - DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::SynchronousMemcpyD2D(context_, AsROCmDevicePtr(gpu_dst), - AsROCmDevicePtr(gpu_src), size); -} - -tsl::Status GpuExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return Memset32(stream, location, 0x0, size); - } else { - return Memset(stream, location, 0x0, size); - } -} - -tsl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - VLOG(2) << "enqueueing memset8 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - return GpuDriver::AsynchronousMemsetUint8(context_, AsROCmDevicePtr(location), - pattern, size, - AsGpuStreamValue(stream)); -} - -tsl::Status GpuExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32 pattern, uint64_t size) { - VLOG(2) << "enqueueing memset32 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0); - return GpuDriver::AsynchronousMemsetUint32( - context_, AsROCmDevicePtr(location), pattern, size / 4, - AsGpuStreamValue(stream)); -} - -bool GpuExecutor::Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2H(context_, host_dst, - AsROCmDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); -} - -bool GpuExecutor::Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) { - return GpuDriver::AsynchronousMemcpyH2D(context_, AsROCmDevicePtr(gpu_dst), - host_src, size, - AsGpuStreamValue(stream)); -} - -bool GpuExecutor::MemcpyDeviceToDevice(Stream* stream, - DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2D(context_, AsROCmDevicePtr(gpu_dst), - AsROCmDevicePtr(gpu_src), size, - AsGpuStreamValue(stream)); -} - -bool GpuExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { - auto callback_ptr = - new absl::AnyInvocable([cb = std::move(callback)]() mutable { - tsl::Status s = std::move(cb)(); - if (!s.ok()) { - LOG(WARNING) << "Host callback failed: " << s; - } - }); - return GpuDriver::AddStreamCallback(context_, AsGpuStreamValue(stream), - InternalHostCallback, callback_ptr); -} - -/* static */ void GpuExecutor::InternalHostCallback(void* data) { - auto* callback = reinterpret_cast*>(data); - std::move (*callback)(); - delete callback; -} - -tsl::Status GpuExecutor::AllocateEvent(Event* event) { - return AsGpuEvent(event)->Init(); -} - -tsl::Status GpuExecutor::DeallocateEvent(Event* event) { - return AsGpuEvent(event)->Destroy(); -} - -tsl::Status GpuExecutor::RecordEvent(Stream* stream, Event* event) { - return AsGpuEvent(event)->Record(AsGpuStream(stream)); -} - -tsl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, AsGpuStream(stream)->gpu_stream(), - AsGpuEvent(event)->gpu_event())) { - return tsl::OkStatus(); - } else { - return tsl::Status{ - absl::StatusCode::kInternal, - absl::StrFormat("error recording waiting for ROCM event on stream %p", - stream)}; - } -} - -tsl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, - absl::bit_cast(stream), - AsGpuEvent(event)->gpu_event())) { - return ::tsl::OkStatus(); - } else { - return tsl::Status(absl::StatusCode::kInternal, - "error waiting for ROCM event on external stream"); - } -} - -Event::Status GpuExecutor::PollForEventStatus(Event* event) { - return AsGpuEvent(event)->PollForStatus(); -} - -bool GpuExecutor::AllocateStream(Stream* stream) { - absl::MutexLock l(&alive_gpu_streams_mu_); - bool out = AsGpuStream(stream)->Init(); - alive_gpu_streams_[stream->platform_specific_handle().stream] = stream; - return out; -} - -void GpuExecutor::DeallocateStream(Stream* stream) { - GpuStream* rocm_stream = AsGpuStream(stream); - absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(rocm_stream->platform_specific_stream()); - if (!rocm_stream->IsIdle()) { - LOG(ERROR) << "Deallocating stream with pending work"; - } - rocm_stream->Destroy(); -} - -bool GpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { - GpuEventHandle other_completed_event = *AsGpuStream(other)->completed_event(); - bool ok = GpuDriver::RecordEvent(context_, other_completed_event, - AsGpuStreamValue(other)) - .ok(); - if (!ok) { - LOG(ERROR) << "failed to record completion event; " - "therefore, failed to create inter-stream dependency"; - return false; - } - - return GpuDriver::WaitStreamOnEvent(context_, AsGpuStreamValue(dependent), - other_completed_event); -} - -tsl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { - return GpuDriver::SynchronizeStream(context_, AsGpuStreamValue(stream)); -} - -blas::BlasSupport* GpuExecutor::CreateBlas() { - PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = - registry->GetFactory(rocm::kROCmPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve BLAS factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -dnn::DnnSupport* GpuExecutor::CreateDnn() { - PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = - registry->GetFactory(rocm::kROCmPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve DNN factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -fft::FftSupport* GpuExecutor::CreateFft() { - PluginRegistry* registry = PluginRegistry::Instance(); - tsl::StatusOr status = - registry->GetFactory(rocm::kROCmPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve FFT factory: " - << status.status().message(); - return nullptr; - } - - return status.value()(this); -} - -bool GpuExecutor::CanEnablePeerAccessTo(StreamExecutorInterface* other) { - GpuExecutor* rocm_other = static_cast(other); - return GpuDriver::CanEnablePeerAccess(context_, rocm_other->context_); -} - -tsl::Status GpuExecutor::EnablePeerAccessTo(StreamExecutorInterface* other) { - GpuExecutor* rocm_other = static_cast(other); - return GpuDriver::EnablePeerAccess(context_, rocm_other->context_); -} - -bool GpuExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - return GpuDriver::GetDeviceMemoryInfo(context_, free, total); -} - -bool GpuExecutor::GetSymbol(const string& symbol_name, - ModuleHandle module_handle, void** mem, - size_t* bytes) { - absl::MutexLock lock{&in_memory_modules_mu_}; - if (static_cast(module_handle)) { - auto it = gpu_binary_to_module_.find(module_handle.id()); - CHECK(it != gpu_binary_to_module_.end()); - if (GpuDriver::GetModuleSymbol( - context_, it->second.first, symbol_name.c_str(), - reinterpret_cast(mem), bytes)) { - return true; - } - } - - for (auto& it : gpu_binary_to_module_) { - if (GpuDriver::GetModuleSymbol( - context_, it.second.first, symbol_name.c_str(), - reinterpret_cast(mem), bytes)) { - return true; - } - } - - LOG(INFO) << "Falied to find symbol in any modules: " << symbol_name; - return false; -} - -tsl::Status FillBlockDimLimit(GpuDeviceHandle device, - BlockDim* block_dim_limit) { - // The BlockDim name is a mismatch against these GRID_DIM_* queries because - // we use BlockDims to express the dimensions of blocks within a grid - // (as opposed to ThreadDim which expresses the dimensions of threads - // within a block). - int x, y, z; - TF_RETURN_IF_ERROR(GpuDriver::GetGridLimits(&x, &y, &z, device)); - - block_dim_limit->x = x; - block_dim_limit->y = y; - block_dim_limit->z = z; - return tsl::OkStatus(); -} - -std::unique_ptr -GpuExecutor::CreateEventImplementation() { - return std::unique_ptr(new GpuEvent(this)); -} - -std::unique_ptr -GpuExecutor::CreateKernelImplementation() { - return std::unique_ptr(new GpuKernel()); -} - -std::unique_ptr -GpuExecutor::GetStreamImplementation() { - return std::unique_ptr(new GpuStream(this)); -} - -tsl::StatusOr> -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode) { - VLOG(2) << "Create ROCm command buffer (ROCm graph)"; - GpuGraphHandle graph = nullptr; - TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); - return std::make_unique(mode, /*parent=*/this, graph); -} - -std::unique_ptr -GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode, - GpuGraphHandle graph, - bool is_owned_graph) { - VLOG(2) << "Create HIP command buffer (HIP graph) from existing graph " - << graph << "; is_owned_graph=" << is_owned_graph; - return std::make_unique(mode, /*parent=*/this, graph, - is_owned_graph); -} - -void* GpuExecutor::platform_specific_context() { return context_; } - -GpuContext* GpuExecutor::gpu_context() { return context_; } - -// Attempts to read the NUMA node corresponding to the GPU device's PCI bus out -// of SysFS. Returns -1 if it cannot. -// -// For anything more complicated/prod-focused than this, you'll likely want to -// turn to gsys' topology modeling. -static int TryToReadNumaNode(const string& pci_bus_id, int device_ordinal) { - VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal; - static const int kUnknownNumaNode = -1; - - if (pci_bus_id.empty()) { - LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal; - return kUnknownNumaNode; - } - - std::string filename = - absl::StrFormat("/sys/bus/pci/devices/%s/numa_node", pci_bus_id); - - // We have to use fopen/fread here so that the device properties can be - // populated before InitGoogle procedure has been completed (at which point we - // could use the file::* utilities). - FILE* file = fopen(filename.c_str(), "r"); - if (file == nullptr) { - LOG(INFO) << "could not open file to read NUMA node: " << filename - << "\nYour kernel may have been built without NUMA support."; - return kUnknownNumaNode; - } - - std::string content; - char buf[32]; - size_t did_read = fread(buf, sizeof(buf[0]), sizeof(buf) - 1, file); - buf[did_read] = '\0'; - content = buf; - - int32_t value; - if (absl::SimpleAtoi(content, &value)) { - if (value < 0) { // See http://b/18228951 for details on this path. - LOG(INFO) << "successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; - fclose(file); - return 0; - } - fclose(file); - return value; - } - - LOG(WARNING) - << "could not convert SysFS file contents to integral NUMA node value: " - << content; - - fclose(file); - return kUnknownNumaNode; -} - -tsl::StatusOr> -GpuExecutor::CreateDeviceDescription(int device_ordinal) { - GpuDeviceHandle device; - auto status = GpuDriver::GetDevice(device_ordinal, &device); - if (!status.ok()) { - return status; - } - - int version; - status = GpuDriver::GetGpuISAVersion(&version, device); - if (!status.ok()) { - return status; - } - - std::string gcn_arch_name; - status = GpuDriver::GetGpuGCNArchName(device, &gcn_arch_name); - if (!status.ok()) { - return status; - } - - internal::DeviceDescriptionBuilder builder; - - { - int driver_version = 0; - (void)GpuDriver::GetDriverVersion(&driver_version); - string augmented_driver_version = absl::StrFormat( - "%d (%s)", driver_version, - rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()) - .c_str()); - builder.set_driver_version(augmented_driver_version); - } - - { - string pci_bus_id = GpuDriver::GetPCIBusID(device); - - // Lower the hex characters to match sysfs. - pci_bus_id = absl::AsciiStrToLower(pci_bus_id); - builder.set_pci_bus_id(pci_bus_id); - - // Read the NUMA node corresponding to the PCI bus ID out of sysfs. - int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal); - builder.set_numa_node(numa_node); - } - - hipDeviceProp_t prop; - if (GpuDriver::GetDeviceProperties(&prop, device_ordinal)) { - builder.set_threads_per_block_limit(prop.maxThreadsPerBlock); - - ThreadDim thread_dim_limit; - thread_dim_limit.x = prop.maxThreadsDim[0]; - thread_dim_limit.y = prop.maxThreadsDim[1]; - thread_dim_limit.z = prop.maxThreadsDim[2]; - builder.set_thread_dim_limit(thread_dim_limit); - - float clock_rate_ghz = static_cast(prop.clockRate) / 1e6; - builder.set_clock_rate_ghz(clock_rate_ghz); - - // mem_bandwidth = 2 * mem_bus_width_in_bytes * mem_clock_rate_in_hz - int64_t memory_bandwidth = 2 * (int64_t(prop.memoryBusWidth) / 8) * - (int64_t(prop.memoryClockRate) * 1000); - builder.set_memory_bandwidth(memory_bandwidth); - - builder.set_l2_cache_size(prop.l2CacheSize); - } - - { - bool ecc_enabled = false; - (void)GpuDriver::IsEccEnabled(device, &ecc_enabled); - builder.set_ecc_enabled(ecc_enabled); - } - - uint64_t device_memory_size = -1; - (void)GpuDriver::GetDeviceTotalMemory(device, &device_memory_size); - builder.set_device_memory_size(device_memory_size); - - { - BlockDim block_dim_limit; - TF_RETURN_IF_ERROR(FillBlockDimLimit(device, &block_dim_limit)); - builder.set_block_dim_limit(block_dim_limit); - } - - { - string device_name; - TF_RETURN_IF_ERROR(GpuDriver::GetDeviceName(device, &device_name)); - builder.set_name(device_name); - } - - builder.set_platform_version( - absl::StrCat("AMDGPU ISA version: ", gcn_arch_name)); - - // TODO(leary) should be a way to query this from the driver, but this is - // unlikely to change for us any time soon. - builder.set_device_address_bits(64); - - builder.set_device_vendor("Advanced Micro Devices, Inc"); - builder.set_rocm_compute_capability(gcn_arch_name); - - builder.set_shared_memory_per_core( - GpuDriver::GetMaxSharedMemoryPerCore(device).value()); - builder.set_shared_memory_per_block( - GpuDriver::GetMaxSharedMemoryPerBlock(device).value()); - int core_count = GpuDriver::GetMultiprocessorCount(device).value(); - builder.set_core_count(core_count); - builder.set_fpus_per_core(fpus_per_core(gcn_arch_name)); - builder.set_threads_per_core_limit( - GpuDriver::GetMaxThreadsPerMultiprocessor(device).value()); - builder.set_registers_per_block_limit( - GpuDriver::GetMaxRegistersPerBlock(device).value()); - builder.set_threads_per_warp(GpuDriver::GetThreadsPerWarp(device).value()); - builder.set_registers_per_core_limit(64 * 1024); - - int cc_major = 0; - int cc_minor = 0; - GpuDriver::GetComputeCapability(&cc_major, &cc_minor, device).IgnoreError(); - - // It would be better to use the PCI device ID or some other truly unique - // identifier for the GPU model. But getting this requires using NVML or - // other hacks, which we don't have access to in OSS TensorFlow. - // - // Alternatively you might be tempted to use GpuDriver::GetDeviceName as a - // unique identifier, but this is not stable across GPU VBIOS versions. - // - // TODO(jlebar): This really should be more unique. In CUDA land, we mix in - // the clock speed and L2 cache size. - builder.set_model_str(absl::StrFormat("cc_%d.%d with %dB RAM, %d cores", - cc_major, cc_minor, device_memory_size, - core_count)); - - return builder.Build(); -} - -} // namespace gpu - -} // namespace stream_executor - -REGISTER_MODULE_INITIALIZER(rocm_gpu_executor, {}); diff --git a/xla/stream_executor/rocm/rocm_helpers.cu.cc b/xla/stream_executor/rocm/rocm_helpers.cu.cc index 9c27fcb3a1861..cf12ffdaeb47c 100644 --- a/xla/stream_executor/rocm/rocm_helpers.cu.cc +++ b/xla/stream_executor/rocm/rocm_helpers.cu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/rocm/rocm_kernel.cc b/xla/stream_executor/rocm/rocm_kernel.cc index 5ebad6db18bb5..74f2def14e22b 100644 --- a/xla/stream_executor/rocm/rocm_kernel.cc +++ b/xla/stream_executor/rocm/rocm_kernel.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ namespace stream_executor { namespace gpu { hipFuncCache_t GpuKernel::GetGpuCacheConfig() const { - switch (preferred_cache_config_) { + switch (cache_config()) { case KernelCacheConfig::kNoPreference: return hipFuncCachePreferNone; case KernelCacheConfig::kPreferShared: @@ -30,9 +30,21 @@ hipFuncCache_t GpuKernel::GetGpuCacheConfig() const { return hipFuncCachePreferEqual; default: LOG(FATAL) << "Unknown KernelCacheConfig" - << static_cast(preferred_cache_config_); + << static_cast(cache_config()); } } +absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + int32_t threads_per_block = threads.x * threads.y * threads.z; + VLOG(0) << "Get kernel block occupancy: " << name_ + << "; threads_per_block: " << threads_per_block + << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + + return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, + threads_per_block, + dynamic_shared_memory_bytes); +} + } // namespace gpu } // namespace stream_executor diff --git a/xla/stream_executor/rocm/rocm_platform.cc b/xla/stream_executor/rocm/rocm_platform.cc index 44f19ed53cd70..b7f4f9ff2fc64 100644 --- a/xla/stream_executor/rocm/rocm_platform.cc +++ b/xla/stream_executor/rocm/rocm_platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ int ROCmPlatform::DeviceToBus(int device_ordinal) { return exec->GetDeviceDescription().numa_node() - min_numa_node_; } -tsl::StatusOr ROCmPlatform::FirstExecutorForBus( +absl::StatusOr ROCmPlatform::FirstExecutorForBus( int bus_ordinal) { InspectNumaNodes(); CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; @@ -84,7 +84,7 @@ tsl::StatusOr ROCmPlatform::FirstExecutorForBus( } } - return tsl::Status{ + return absl::Status{ absl::StatusCode::kNotFound, absl::StrFormat("Executor for bus %d not found.", bus_ordinal)}; } @@ -104,19 +104,18 @@ int ROCmPlatform::VisibleDeviceCount() const { const string& ROCmPlatform::Name() const { return name_; } -tsl::StatusOr> +absl::StatusOr> ROCmPlatform::DescriptionForDevice(int ordinal) const { return GpuExecutor::CreateDeviceDescription(ordinal); } -tsl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { +absl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; - config.device_options = DeviceOptions::Default(); return GetExecutor(config); } -tsl::StatusOr ROCmPlatform::GetExecutor( +absl::StatusOr ROCmPlatform::GetExecutor( const StreamExecutorConfig& config) { if (config.gpu_stream) { // If the GPU stream was provided, it's not possible to get-or-create a @@ -128,13 +127,13 @@ tsl::StatusOr ROCmPlatform::GetExecutor( config, [&]() { return GetUncachedExecutor(config); }); } -tsl::StatusOr> +absl::StatusOr> ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique( this, std::make_unique(), config.ordinal); - auto init_status = executor->Init(config.device_options); + auto init_status = executor->Init(); if (!init_status.ok()) { - return tsl::Status{ + return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for ROCM device ordinal %d: %s", @@ -147,21 +146,16 @@ ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { } // namespace gpu static void InitializeROCmPlatform() { - // Disabling leak checking, MultiPlatformManager does not destroy its + // Disabling leak checking, PlatformManager does not destroy its // registered platforms. - auto status = MultiPlatformManager::PlatformWithName("ROCM"); + auto status = PlatformManager::PlatformWithName("ROCM"); if (!status.ok()) { std::unique_ptr platform(new gpu::ROCmPlatform); - TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); } } } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(rocm_platform, - stream_executor::InitializeROCmPlatform()); - -DECLARE_MODULE_INITIALIZER(multi_platform_manager); -// Note that module initialization sequencing is not supported in the -// open-source project, so this will be a no-op there. -REGISTER_MODULE_INITIALIZER_SEQUENCE(rocm_platform, multi_platform_manager); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + rocm_platform, stream_executor::InitializeROCmPlatform()); diff --git a/xla/stream_executor/rocm/rocm_platform.h b/xla/stream_executor/rocm/rocm_platform.h index bd861cceb993c..7c9f503743549 100644 --- a/xla/stream_executor/rocm/rocm_platform.h +++ b/xla/stream_executor/rocm/rocm_platform.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,12 +21,11 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/executor_cache.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "xla/stream_executor/trace_listener.h" namespace stream_executor { namespace gpu { @@ -51,7 +50,7 @@ class ROCmPlatform : public Platform { int DeviceToBus(int device_ordinal); // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - tsl::StatusOr FirstExecutorForBus(int bus_ordinal); + absl::StatusOr FirstExecutorForBus(int bus_ordinal); // Platform interface implementation: // Returns the same value as kROCmPlatform above. @@ -62,15 +61,15 @@ class ROCmPlatform : public Platform { const string& Name() const override; - tsl::StatusOr> DescriptionForDevice( + absl::StatusOr> DescriptionForDevice( int ordinal) const override; - tsl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr ExecutorForDevice(int ordinal) override; - tsl::StatusOr GetExecutor( + absl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - tsl::StatusOr> GetUncachedExecutor( + absl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; private: diff --git a/xla/stream_executor/rocm/rocm_platform_id.cc b/xla/stream_executor/rocm/rocm_platform_id.cc index 6f0df65430a63..82e7d26729fb7 100644 --- a/xla/stream_executor/rocm/rocm_platform_id.cc +++ b/xla/stream_executor/rocm/rocm_platform_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/rocm/rocm_platform_id.h b/xla/stream_executor/rocm/rocm_platform_id.h index 2ee37660ddea0..b11b377c01028 100644 --- a/xla/stream_executor/rocm/rocm_platform_id.h +++ b/xla/stream_executor/rocm/rocm_platform_id.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/rocm/rocm_runtime.cc b/xla/stream_executor/rocm/rocm_runtime.cc new file mode 100644 index 0000000000000..ff1ef83bba39c --- /dev/null +++ b/xla/stream_executor/rocm/rocm_runtime.cc @@ -0,0 +1,52 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/base/optimization.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/stream_executor/gpu/gpu_runtime.h" +#include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/rocm/rocm_driver.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" + +#define RETURN_IF_ROCM_ERROR(expr, ...) \ + if (auto res = (expr); TF_PREDICT_FALSE(res != hipSuccess)) { \ + return absl::InternalError(absl::StrCat( \ + __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(res))); \ + } + +namespace stream_executor { +namespace gpu { + +absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { + VLOG(2) << "Get ROCM function from a symbol: " << symbol; + return absl::UnimplementedError("GetFuncBySymbol is not implemented"); +} + +absl::StatusOr GpuRuntime::GetRuntimeVersion() { + VLOG(2) << "Get ROCM runtime version"; + int32_t version; + RETURN_IF_ROCM_ERROR(wrap::hipRuntimeGetVersion(&version), + "Failed call to hipRuntimeGetVersion"); + return version; +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/xla/stream_executor/rocm/rocsolver_wrapper.h b/xla/stream_executor/rocm/rocsolver_wrapper.h index 2d6018615a1c0..23b6c4b99dabb 100644 --- a/xla/stream_executor/rocm/rocsolver_wrapper.h +++ b/xla/stream_executor/rocm/rocsolver_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/rocm/roctracer_wrapper.h b/xla/stream_executor/rocm/roctracer_wrapper.h index 61bfd00b65a88..cbfe9e2e2cfe6 100644 --- a/xla/stream_executor/rocm/roctracer_wrapper.h +++ b/xla/stream_executor/rocm/roctracer_wrapper.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/scratch_allocator.cc b/xla/stream_executor/scratch_allocator.cc deleted file mode 100644 index cd02336b9a4d1..0000000000000 --- a/xla/stream_executor/scratch_allocator.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/scratch_allocator.h" - -#include - -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/stream.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { - -tsl::StatusOr> OneTimeScratchAllocator::AllocateBytes( - int64_t byte_size) { - CHECK(temporary_ == nullptr); - TF_ASSIGN_OR_RETURN(temporary_, - stream_->AllocateTemporaryArray(byte_size)); - return temporary_->device_memory(); -} - -} // namespace stream_executor diff --git a/xla/stream_executor/scratch_allocator.h b/xla/stream_executor/scratch_allocator.h index c3a321e9e5e8a..546f382b643b4 100644 --- a/xla/stream_executor/scratch_allocator.h +++ b/xla/stream_executor/scratch_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,15 +16,14 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ #define XLA_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ +#include #include -#include #include #include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/temporary_device_memory.h" #include "tsl/platform/statusor.h" namespace stream_executor { @@ -51,34 +50,10 @@ class ScratchAllocator { // // This is a temporary allocation, and the caller is responsible for // deallocating at some known-safe point. See the class comment above. - virtual tsl::StatusOr> AllocateBytes( + virtual absl::StatusOr> AllocateBytes( int64_t byte_size) = 0; }; -// Allocates a single temporary memory allocation -- this memory is deallocated -// at the next stream synchronization point after this object has gone out of -// scope. This satisfies the lifetime and deallocation properties given in the -// class comment above. -// -// Thread-compatible, but not thread-safe (use in scenarios where only one -// thread will request the scratch allocation). -class OneTimeScratchAllocator : public ScratchAllocator { - public: - explicit OneTimeScratchAllocator(Stream* stream) : stream_(stream) {} - - int64_t GetMemoryLimitInBytes() override { return -1; } - - tsl::StatusOr> AllocateBytes( - int64_t byte_size) override; - - private: - std::unique_ptr> temporary_; - Stream* stream_; - - OneTimeScratchAllocator(const OneTimeScratchAllocator&) = delete; - void operator=(const OneTimeScratchAllocator&) = delete; -}; - // Can allocate several times -- this memory is deallocated when the scratch // allocator is destroyed. // @@ -90,9 +65,12 @@ class OwningScratchAllocator : public ScratchAllocator { OwningScratchAllocator(int device_ordinal, DeviceMemoryAllocator* allocator) : device_ordinal_(device_ordinal), allocator_(allocator) {} + OwningScratchAllocator(OwningScratchAllocator&&) = default; + OwningScratchAllocator& operator=(OwningScratchAllocator&&) = default; + int64_t GetMemoryLimitInBytes() override { return -1; } - tsl::StatusOr> AllocateBytes( + absl::StatusOr> AllocateBytes( int64_t byte_size) override { TF_ASSIGN_OR_RETURN(OwningDeviceMemory buffer, allocator_->Allocate(device_ordinal_, byte_size, @@ -105,9 +83,6 @@ class OwningScratchAllocator : public ScratchAllocator { int device_ordinal_; DeviceMemoryAllocator* allocator_; absl::InlinedVector buffers_; - - OwningScratchAllocator(const OwningScratchAllocator&) = delete; - void operator=(const OwningScratchAllocator&) = delete; }; } // namespace stream_executor diff --git a/xla/stream_executor/stream.cc b/xla/stream_executor/stream.cc index d607aaa45eebc..fd1e2a6c993d9 100644 --- a/xla/stream_executor/stream.cc +++ b/xla/stream_executor/stream.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,275 +15,72 @@ limitations under the License. #include "xla/stream_executor/stream.h" +#include #include #include -#include #include +#include +#include +#include #include #include #include #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "Eigen/Core" // from @eigen_archive +#include "absl/synchronization/mutex.h" #include "xla/stream_executor/blas.h" -#include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/stacktrace.h" namespace stream_executor { -namespace { -// Code to turn parameters to functions on stream into strings that -// will be VLOG'ed. We need overloads, instead of -// e.g. BatchDescriptorToVlogString(), as the code that calls these -// functions does not know what the type of the parameter is. -std::string ToVlogString(const dnn::BatchDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::FilterDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { - return descriptor.ToShortString(); -} - -std::string ToVlogString(dnn::ActivationMode mode) { - return dnn::ActivationModeString(mode); -} - -std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) { - return algo_config.ToString(); -} - -std::string ToVlogString(dnn::ElementwiseOperation op) { - return dnn::ElementwiseOperationString(op); -} - -std::string ToVlogString(dnn::QuantizedActivationMode mode) { - return dnn::QuantizedActivationModeString(mode); -} - -std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } - -std::string ToVlogString(blas::UpperLower ul) { - return blas::UpperLowerString(ul); -} - -std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } - -std::string ToVlogString(blas::Side s) { return blas::SideString(s); } - -std::string ToVlogString(blas::ComputationType ty) { - return blas::ComputationTypeString(ty); -} - -std::string ToVlogString(const void *ptr) { - if (ptr == nullptr) { - return "null"; - } - - // StrCat does not convert pointers to text. - std::ostringstream out; - out << ptr; - return out.str(); -} - -template -std::string ToVlogString(const std::complex &c) { - // StrCat does not convert std::complex to text. - std::ostringstream out; - out << c; - return out.str(); -} - -template -std::string ToVlogString(const std::function &f) { - return f == nullptr ? "null" : ""; -} - -template -std::string ToVlogString(const absl::AnyInvocable &f) { - return f == nullptr ? "null" : ""; -} - -std::string ToVlogString(const DeviceMemoryBase &memory) { - return ToVlogString(memory.opaque()); -} - -std::string ToVlogString(const DeviceMemoryBase *memory) { - return memory == nullptr ? "null" : ToVlogString(*memory); -} - -std::string ToVlogString(const Eigen::half &h) { - return absl::StrCat(static_cast(h)); -} - -std::string ToVlogString(const Eigen::bfloat16 &bf) { // NOLINT - return absl::StrCat(static_cast(bf)); -} - -std::string ToVlogString(int i) { return absl::StrCat(i); } - -std::string ToVlogString(uint32_t i) { return absl::StrCat(i); } - -std::string ToVlogString(uint64_t i) { return absl::StrCat(i); } - -std::string ToVlogString(int64_t i) { return absl::StrCat(i); } - -std::string ToVlogString(float f) { return absl::StrCat(f); } - -std::string ToVlogString(double d) { return absl::StrCat(d); } +Stream::Stream(StreamExecutor *parent) + : parent_(parent), implementation_(nullptr), status_(absl::OkStatus()) {} -template -std::string ToVlogString(absl::Span elements) { - std::string str = absl::StrCat( - ToVlogString(reinterpret_cast(elements.data())), "[", - elements.size(), "]{"); - const char *separator = ""; - size_t max_to_show = std::numeric_limits::max(); - if (!VLOG_IS_ON(2)) { - max_to_show = 5; - } else if (!VLOG_IS_ON(3)) { - max_to_show = 20; - } else if (!VLOG_IS_ON(11)) { - max_to_show = 1000; - } - for (size_t i = 0; i < elements.size(); ++i) { - if (i == max_to_show) { - str += ", ..."; - break; +absl::Status Stream::Initialize( + std::optional> priority) { + absl::MutexLock lock(&mu_); + if (implementation_ != nullptr) { + return absl::InternalError( + "stream appears to already have been initialized"); + } + implementation_ = parent_->implementation()->GetStreamImplementation(); + if (priority.has_value()) { + if (std::holds_alternative(*priority)) { + implementation_->SetPriority(std::get(*priority)); + } else { + implementation_->SetPriority(std::get(*priority)); } - absl::StrAppend(&str, separator, ToVlogString(elements[i])); - separator = ", "; } - str += "}"; - return str; -} - -template -std::string ToVlogString(absl::Span elements) { - return ToVlogString(absl::Span(elements)); -} -std::string ToVlogString(dnn::DataType data_type) { - switch (data_type) { - case dnn::DataType::kFloat: - return "dnn::DataType::kFloat"; - case dnn::DataType::kDouble: - return "dnn::DataType::kDouble"; - case dnn::DataType::kHalf: - return "dnn::DataType::kHalf"; - case dnn::DataType::kInt8: - return "dnn::DataType::kInt8"; - case dnn::DataType::kInt32: - return "dnn::DataType::kInt32"; - case dnn::DataType::kBF16: - return "dnn::DataType::kBF16"; - default: - return "unknown DataType"; + if (parent_->AllocateStream(this)) { + // Successful initialization! + return absl::OkStatus(); } -} -// Used together with PARAM to VLOG calls made to the stream. Intended -// to be used like this: -// -// VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)}); -// -// where a and b are the parameters to MyFunction. -// -// See VLOG_CALL for a short-hand for this. This way of doing it saves -// a tremendous amount of boilerplate code given how many functions -// there are on Stream and how many parameters they each have. -std::string CallStr(const char *function_name, Stream *stream, - std::vector> params) { - // Do not call this function unless VLOG is on since just - // constructing all the strings in params is expensive. - CHECK(VLOG_IS_ON(1)); - - std::string str = absl::StrCat(stream->DebugStreamPointers(), - " Called Stream::", function_name, "("); - const char *separator = ""; - for (const auto ¶m : params) { - absl::StrAppend(&str, separator, param.first, "=", param.second); - separator = ", "; - } - absl::StrAppend(&str, ")"); - if (VLOG_IS_ON(10)) { - absl::StrAppend(&str, " ", tsl::CurrentStackTrace(), "\n"); - } - return str; -} - -// Use this macro to avoid having to type every parameter twice to log -// it with VLOG and CallStr. -#define PARAM(parameter) \ - { #parameter, ToVlogString(parameter) } - -// Use this macro to avoid having to type out the name of each -// function and to save some boilerplate. Intended to be used like this: -// -// VLOG_CALL(PARAM(a), PARAM(b)) -// -// This saves a tremendous amount of boilerplate compared to the alternative: -// -// VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a) -// << ", b=" << ToVlogString(b); -// -// Note here that most of the parameter names are not short and that -// most of the functions take many more than 2 parameters. -#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__}) - -} // namespace - -Stream::Stream(StreamExecutor *parent) - : parent_(parent), - implementation_(parent->implementation()->GetStreamImplementation()), - allocated_(false), - status_(tsl::errors::Internal("Uninitialized stream")), - temporary_memory_manager_(this) { - VLOG_CALL(PARAM(parent)); + return absl::InternalError("failed to allocate stream during initialization"); } Stream::~Stream() { - VLOG_CALL(); - // Ensure the stream is completed. auto status = BlockHostUntilDone(); if (!status.ok()) { LOG(WARNING) << "Error blocking host until done in stream destructor: " << status; } - temporary_memory_manager_.ForceDeallocateAll(); - RunAfterBlockHostUntilDoneCallbacks(); - if (allocated_) { + if (implementation_ != nullptr) { parent_->DeallocateStream(this); } } -void Stream::SetPriority(StreamPriority priority) { - implementation_->SetPriority(priority); -} - -void Stream::SetPriority(int priority) { - implementation_->SetPriority(priority); -} - std::variant Stream::priority() const { return implementation_->priority(); } @@ -294,526 +91,22 @@ Stream::PlatformSpecificHandle Stream::platform_specific_handle() const { return handle; } -tsl::Status Stream::RefreshStatus() { - tsl::Status status = parent_->GetStatus(this); +absl::Status Stream::RefreshStatus() { + absl::Status status = parent_->GetStatus(this); // We should not put the stream in an error state, just because the GetStatus // method is unimplemented. - if (status != tsl::Status(absl::StatusCode::kUnimplemented, - "GetStatus is not supported on this executor.")) { + if (status != absl::UnimplementedError( + "GetStatus is not supported on this executor.")) { CheckStatus(status); } return status; } -Stream &Stream::Init() { - VLOG_CALL(); - - absl::MutexLock lock(&mu_); - CHECK_EQ(false, allocated_) - << "stream appears to already have been initialized"; - CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization"; - - if (parent_->AllocateStream(this)) { - // Successful initialization! - allocated_ = true; - status_ = ::tsl::OkStatus(); - } else { - LOG(ERROR) << "failed to allocate stream during initialization"; - } - - return *this; -} - -Stream &Stream::ThenRecordEvent(Event *event) { - VLOG_CALL(PARAM(event)); - - tsl::Status status = parent_->RecordEvent(this, event); - if (!status.ok()) { - LOG(ERROR) << "Error recording event in stream: " << status.message() - << "; not marking stream as bad, as the Event object may be " - << "at fault. Monitor for further errors."; - } - - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, const DeviceMemory &x, - const DeviceMemory &scale, const DeviceMemory &offset, - const DeviceMemory &mean, const DeviceMemory &inv_var, - const DeviceMemory &y, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationForward( - this, x, scale, offset, estimated_mean, estimated_variance, side_input, - x_desc, scale_offset_desc, epsilon, exponential_average_factor, - activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var, - is_training, reserve_space_allocator, workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), - PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), - PARAM(scale_backprop), PARAM(offset_backprop)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoBatchNormalizationBackward( - this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc, - scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop, - offset_backprop, side_input_backprop, reserve_space_data, - workspace_allocator)); - - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenConvolve( - const dnn::BatchDescriptor &input_descriptor, - const DeviceMemory &input_data, - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output) { - if (ok()) { - CheckError(ConvolveWithAlgorithm( - dnn::ConvolutionKind::FORWARD, input_descriptor, input_data, - filter_descriptor, filter_data, output_descriptor, *output, - convolution_descriptor, - /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(), - /*output_profile_result=*/nullptr) - .ok()); - } - return *this; -} - -Stream &Stream::ThenSeparableConvolve( - const dnn::BatchDescriptor &batch_descriptor, - const DeviceMemory &input_data, - const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier, - const DeviceMemory &first_weights, - const DeviceMemory &second_weights, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output) { - VLOG_CALL( - PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor), - PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights), - PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoSeparableConvolve( - this, batch_descriptor, input_data, filter_descriptor, depth_multiplier, - first_weights, second_weights, convolution_descriptor, - output_descriptor, output)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenMatMul(const DeviceMemory &input_data, - const DeviceMemory &weights, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions), - PARAM(output_dimensions), PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions, - output_dimensions, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenMatMulQuantized( - const DeviceMemory &input_data, const DeviceMemory &weights, - const DeviceMemory &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales), - PARAM(input_dimensions), PARAM(output_dimensions), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales, - input_dimensions, output_dimensions, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; +absl::Status Stream::RecordEvent(Event *event) { + return parent_->RecordEvent(this, event); } -Stream &Stream::ThenMatMulQuantized( - const DeviceMemory &input_data, const DeviceMemory &weights, - const DeviceMemory &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales), - PARAM(input_dimensions), PARAM(output_dimensions), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales, - input_dimensions, output_dimensions, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenBiasAdd(const DeviceMemory &input_data, - const DeviceMemory &biases, - const dnn::BatchDescriptor &dimensions, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError( - dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenNormalizeWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, DeviceMemory *output_data) { - VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoNormalizeWithDimensions( - this, normalize_descriptor, dimensions, input_data, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenNormalizeBackwardWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, const DeviceMemory &raw_data, - const DeviceMemory &normalized_data, - const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient, - ScratchAllocator *workspace_allocator) { - VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), - PARAM(normalized_data), PARAM(normalized_variable_gradient), - PARAM(raw_variable_gradient), PARAM(workspace_allocator)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoNormalizeBackwardWithDimensions( - this, normalize_descriptor, dimensions, raw_data, normalized_data, - normalized_variable_gradient, raw_variable_gradient, - workspace_allocator)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - DeviceMemory *output_data) { - return ThenActivateWithOptions(activation_mode, dimensions, input_data, - output_data, /*options=*/0); -} - -Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - DeviceMemory *output_data, - uint64_t options) { - VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data), - PARAM(output_data), PARAM(options)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data, - output_data, options)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenDepthConcatenate( - absl::Span input_dimensions, - absl::Span *const> input_data, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); - - for (size_t i = 1; i < input_dimensions.size(); ++i) { - if (input_dimensions[i].count() != input_dimensions[0].count() || - input_dimensions[i].height() != input_dimensions[0].height() || - input_dimensions[i].width() != input_dimensions[0].width()) { - SetError(); - LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n" - << "input_dimensions[0]: " << input_dimensions[0].ToString() - << "input_dimensions[" << i - << "]: " << input_dimensions[i].ToString(); - return *this; - } - } - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenElementwiseOperate( - dnn::ElementwiseOperation operation, - absl::Span input_dimensions, - absl::Span *const> input_data, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data), - PARAM(output_dimensions), PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions, - input_data, output_dimensions, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - int64_t left_pad, int64_t right_pad, int64_t top_pad, - int64_t bottom_pad, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad), - PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad, - top_pad, bottom_pad, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - int64_t left_trim, int64_t right_trim, - int64_t top_trim, int64_t bottom_trim, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim), - PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim), - PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim, - right_trim, top_trim, bottom_trim, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - int64_t replicate_x, int64_t replicate_y, - DeviceMemory *output_data) { - VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x), - PARAM(replicate_y), PARAM(output_data)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x, - replicate_y, output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenMemcpyD2HQuantized( - const DeviceMemory &gpu_unquantized_src, - dnn::QuantizedActivationMode mode, void *host_dst, uint64_t size) { - VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst), - PARAM(size)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode, - host_dst, size)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenMemcpyH2DQuantized( - const void *host_src, uint64_t size, dnn::QuantizedActivationMode mode, - DeviceMemory *gpu_unquantized_dst) { - VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode), - PARAM(gpu_unquantized_dst)); - - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode, - gpu_unquantized_dst)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream *Stream::GetOrCreateSubStream() { +absl::StatusOr Stream::GetOrCreateSubStream() { // Do not destroy bad streams when holding mu_ because ~Stream() may // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_. std::vector> bad_streams; @@ -828,8 +121,7 @@ Stream *Stream::GetOrCreateSubStream() { // The sub_stream is reusable. Stream *sub_stream = pair.first.get(); if (sub_stream->ok()) { - VLOG(1) << DebugStreamPointers() << " reusing sub_stream " - << sub_stream->DebugStreamPointers(); + VLOG(1) << "stream=" << this << " reusing sub_stream=" << sub_stream; pair.second = false; return sub_stream; } @@ -843,8 +135,7 @@ Stream *Stream::GetOrCreateSubStream() { } bad_streams.push_back(std::move(sub_streams_.back().first)); sub_streams_.pop_back(); - VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " - << sub_stream->DebugStreamPointers(); + VLOG(1) << "stream=" << this << " dropped !ok sub_stream=" << sub_stream; } else { // The sub_stream is not reusable, move on to the next one. ++index; @@ -852,15 +143,10 @@ Stream *Stream::GetOrCreateSubStream() { } // No streams are reusable; create a new stream. - sub_streams_.emplace_back(std::unique_ptr{new Stream{parent_}}, - false); + sub_streams_.emplace_back(std::make_unique(parent_), false); Stream *sub_stream = sub_streams_.back().first.get(); - sub_stream->Init(); - if (!sub_stream->ok()) { - LOG(ERROR) << "sub-stream failed to be initialized"; - } - VLOG(1) << DebugStreamPointers() << " created new sub_stream " - << sub_stream->DebugStreamPointers(); + TF_RETURN_IF_ERROR(sub_stream->Initialize()); + VLOG(1) << "stream=" << this << " created new sub_stream=" << sub_stream; return sub_stream; } @@ -881,15 +167,13 @@ void Stream::ReturnSubStream(Stream *sub_stream) { // Found the sub_stream. if (sub_stream->ok()) { - VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " - << sub_stream->DebugStreamPointers(); + VLOG(1) << "stream=" << this << " returned ok sub_stream=" << sub_stream; pair.second = true; } else { // The returned stream is not ok. Streams have a monotonic state // machine; the stream will remain in !ok forever. Swap it with the last // stream and pop it off. - VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " - << sub_stream->DebugStreamPointers(); + VLOG(1) << "stream=" << this << " returned !ok sub_stream=" << sub_stream; const int64_t last = sub_streams_.size() - 1; if (index != last) { std::swap(pair, sub_streams_[last]); @@ -900,914 +184,70 @@ void Stream::ReturnSubStream(Stream *sub_stream) { return; } - LOG(FATAL) << DebugStreamPointers() - << " did not create the returned sub-stream " - << sub_stream->DebugStreamPointers(); + LOG(FATAL) << "stream=" << this << " did not create the returned sub-stream " + << sub_stream; } -Stream &Stream::ThenWaitFor(Stream *other) { - VLOG_CALL(PARAM(other)); - - CHECK(this != other) << "stream cannot wait for itself"; - if (ok() && other->ok()) { - CheckError(parent_->CreateStreamDependency(this, other)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() << " did not wait for " - << other->DebugStreamPointers(); - } - return *this; -} - -Stream &Stream::ThenWaitFor(Event *event) { - VLOG_CALL(PARAM(event)); - - if (ok()) { - tsl::Status status = parent_->WaitForEvent(this, event); - if (!status.ok()) { - LOG(ERROR) << "Error waiting for event in stream: " << status.message() - << "; not marking stream as bad, as the Event object may be " - << "at fault. Monitor for further errors."; - } - } else { - LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; +absl::Status Stream::WaitFor(Stream *other) { + if (this == other) { + return absl::InternalError("stream cannot wait for itself"); } - return *this; -} - -// A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX -// functions and logs for errors. -template -struct ThenBlasImpl { - // blas_func is the DoBlasXXX member function pointer, and args are its - // arguments except the first one of Stream* type. - Stream &operator()(Stream *stream, - bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - Args... args) { - return Run(stream, blas_func, /*record_error=*/true, args...); - } - - // Like operator(), but only calls stream->CheckError() if record_error is - // true. - Stream &Run(Stream *stream, - bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - bool record_error, Args... args); -}; - -template -Stream &ThenBlasImpl::Run( - Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - bool record_error, Args... args) { - if (stream->ok()) { - bool ok; - if (blas::BlasSupport *blas = stream->parent_->AsBlas()) { - ok = (blas->*blas_func)(stream, args...); - } else { - LOG(WARNING) - << "attempting to perform BLAS operation using StreamExecutor " - "without BLAS support"; - ok = false; - } - if (record_error) { - stream->CheckError(ok); - } + if (parent_->CreateStreamDependency(this, other)) { + return absl::OkStatus(); } - return *stream; -} - -Stream &Stream::ThenBlasAxpy(uint64_t elem_count, float alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, - y, incy); -} - -Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory &x, - int incx, DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, - incy); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl> *, int> impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); -} - -Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx) { - VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); - - ThenBlasImpl, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); + return absl::InternalError("stream cannot wait for other"); } -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - float alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, float, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); +absl::Status Stream::WaitFor(Event *event) { + return parent_->WaitForEvent(this, event); } -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, - int lda, const DeviceMemory &x, int incx, - double beta, DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, double, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, - int lda, - const DeviceMemory> &x, - int incx, std::complex beta, - DeviceMemory> *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl, - const DeviceMemory> &, int, - const DeviceMemory> &, int, - std::complex, DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, - int lda, - const DeviceMemory> &x, - int incx, std::complex beta, - DeviceMemory> *y, int incy) { - VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), - PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), - PARAM(incy)); - - ThenBlasImpl, - const DeviceMemory> &, int, - const DeviceMemory> &, int, - std::complex, DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, - x, incx, beta, y, incy); -} - -Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, - float alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy) { - VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), - PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); - - ThenBlasImpl &, int, const DeviceMemory &, - int, float, DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, - x, incx, beta, y, incy); -} - -namespace { -// Like ThenBlasImpl, except this expects the last argument of blas_func to be a -// blas::ProfileResult*. This functor doesn't put the stream into an error -// state if the op fails and the profile result is non-null. Instead, the -// error-ness is returned in the profile result itself. -template -struct ThenBlasWithProfileImpl { - Stream &operator()(Stream *stream, - bool (blas::BlasSupport::*blas_func)( - Stream *, Args..., blas::ProfileResult *), - Args... args, blas::ProfileResult *profile_result) { - ThenBlasImpl Runner; - bool record_error = profile_result == nullptr; - return Runner.Run(stream, blas_func, record_error, args..., profile_result); +absl::Status Stream::Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, + uint64_t size) { + if (parent_->Memcpy(this, host_dst, gpu_src, size)) { + return absl::OkStatus(); } -}; -} // anonymous namespace - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &a, int lda, - DeviceMemory *b, int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl &, int, - DeviceMemory *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory> &a, - int lda, DeviceMemory> *b, - int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl, - const DeviceMemory> &, int, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); + return absl::InternalError("failed to memcpy"); } -Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory> &a, - int lda, DeviceMemory> *b, - int ldb) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); - - ThenBlasImpl, - const DeviceMemory> &, int, - DeviceMemory> *, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, - n, alpha, a, lda, b, ldb); -} - -Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl &, int, - DeviceMemory *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl &, int, - DeviceMemory *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched( - blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, int lda, - DeviceMemory *> *bs, int ldb, int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl, - const DeviceMemory *> &, int, - DeviceMemory *> *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasTrsmBatched( - blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, int lda, - DeviceMemory *> *bs, int ldb, int batch_count) { - VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), - PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs), - PARAM(ldb), PARAM(batch_count)); - - ThenBlasImpl, - const DeviceMemory *> &, int, - DeviceMemory *> *, int, int> - impl; - return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa, - diag, m, n, alpha, as, lda, bs, ldb, batch_count); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, const NumericOptions &numeric_options, - blas::CallContext context) { - return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_count, - numeric_options, - /*scratch_allocator=*/nullptr, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, DeviceMemorySlice, - int, float, DeviceMemorySlice, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, double alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, double beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, int, - DeviceMemorySlice, int, double, - DeviceMemorySlice, int, int, const NumericOptions &, - ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, std::complex beta, - DeviceMemorySlice> c, int ldc, int batch_count, - const NumericOptions &numeric_options, blas::CallContext context) { - return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_count, - numeric_options, - /*scratch_allocator=*/nullptr, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, std::complex beta, - DeviceMemorySlice> c, int ldc, int batch_count, - const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, - blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, DeviceMemorySlice>, int, - DeviceMemorySlice>, int, std::complex, - DeviceMemorySlice>, int, int, - const NumericOptions &, ScratchAllocator *, blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context) { - VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), - PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), - PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - - ThenBlasImpl, DeviceMemorySlice>, - int, DeviceMemorySlice>, int, - std::complex, DeviceMemorySlice>, - int, int, const NumericOptions &, ScratchAllocator *, - blas::CallContext> - impl; - return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, - k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - numeric_options, scratch_allocator, context); -} - -Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, - uint64_t size) { - VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size)); - - CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); - return *this; -} - -Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, - uint64_t size) { - VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size)); - - CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); - return *this; -} - -Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, uint64_t size) { - VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size)); - - CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); - return *this; -} - -Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64_t size) { - VLOG_CALL(PARAM(location), PARAM(size)); - - CheckStatus(parent_->MemZero(this, location, size)); - return *this; -} - -Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size) { - VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size)); - - CheckStatus(parent_->Memset32(this, location, pattern, size)); - return *this; -} - -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); +absl::Status Stream::Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, + uint64_t size) { + if (parent_->Memcpy(this, gpu_dst, host_src, size)) { + return absl::OkStatus(); } - return *this; + return absl::InternalError("failed to memcpy"); } -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); +absl::Status Stream::Memcpy(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size) { + if (parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)) { + return absl::OkStatus(); } - return *this; + return absl::InternalError("failed to memcpy"); } -Stream &Stream::ThenRnnForward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; -} - -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } - return *this; -} - -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } - return *this; -} - -Stream &Stream::ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result) { - // TODO(zhengxq): add VLOG PARAM calls. - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, - input_h_data, input_c_desc, input_c_data, params, output_desc, - output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result); - if (!status && !output_profile_result) { - SetError(); - } - } else { - SetError(); - LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; - } - return *this; -} - -Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, - const DeviceMemory &probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - const NumericOptions &numeric_options, - DeviceMemory *costs_data, - const dnn::RnnStateTensorDescriptor &grads_desc, - DeviceMemory *grads_data, - ScratchAllocator *workspace_allocator) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - DeviceMemory scratch_memory; - int ctc_loss_algo_id; - auto status = - dnn->PrepareForCtcLoss( - this, probs_desc, probs_data, grads_desc, labels_data, - labels_lengths_data, input_lengths_data, numeric_options, - workspace_allocator, &scratch_memory, &ctc_loss_algo_id) - .ok(); - if (status) { - status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data, - labels_lengths_data, input_lengths_data, - costs_data, grads_desc, grads_data, - &scratch_memory, ctc_loss_algo_id); - } - if (!status) { - SetError(); - } - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; +absl::Status Stream::MemZero(DeviceMemoryBase *location, uint64_t size) { + return parent_->MemZero(this, location, size); } -Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, - dnn::DataType input_type, - const DeviceMemoryBase &input_data, - const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase *output_data) { - VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data), - PARAM(output_desc), PARAM(output_type), PARAM(scale), - PARAM(output_data)); - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data, - output_desc, output_type, scale, - output_data)); - } else { - SetErrorAndLogNoDnnSupport(); - } - return *this; +absl::Status Stream::Memset32(DeviceMemoryBase *location, uint32_t pattern, + uint64_t size) { + return parent_->Memset32(this, location, pattern, size); } -Stream &Stream::ThenDoHostCallback(absl::AnyInvocable callback) { - return ThenDoHostCallbackWithStatus([cb = std::move(callback)]() mutable { +absl::Status Stream::DoHostCallback(absl::AnyInvocable callback) { + return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { std::move(cb)(); - return ::tsl::OkStatus(); + return absl::OkStatus(); }); } -Stream &Stream::ThenDoHostCallbackWithStatus( - absl::AnyInvocable callback) { - VLOG_CALL(PARAM(callback)); - - if (!ok()) { - LOG(INFO) << DebugStreamPointers() - << " was in error state before adding host callback"; +absl::Status Stream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + if (parent_->HostCallback(this, std::move(callback))) { + return absl::OkStatus(); } - CheckError(parent_->HostCallback(this, std::move(callback))); - return *this; + return absl::InternalError("failed to host callback"); } void Stream::CheckError(bool operation_retcode) { @@ -1815,157 +255,26 @@ void Stream::CheckError(bool operation_retcode) { return; } absl::MutexLock lock(&mu_); - status_ = tsl::errors::Internal("Unknown error"); -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; + status_ = absl::InternalError("Unknown error"); } -Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -Stream &Stream::ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output) { - VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); - - if (fft::FftSupport *fft = parent_->AsFft()) { - CheckError(fft->DoFft(this, plan, input, output)); - } else { - SetError(); - LOG(INFO) << DebugStreamPointers() - << " attempting to perform FFT operation using StreamExecutor" - " without FFT support"; - } - return *this; -} - -// It looks confusing, but all this is doing is inserting a callback at the -// present point in the stream to then enqueue a task on the host executor. -Stream &Stream::ThenEnqueueOnBackgroundThread( - std::function task) { - VLOG_CALL(PARAM(task)); - - StreamExecutor *stream_executor = this->parent_; - std::function bound_task = std::bind(task, stream_executor); - - return ThenDoHostCallback([stream_executor, bound_task]() { - stream_executor->EnqueueOnBackgroundThread(bound_task); - }); -} - -tsl::Status Stream::BlockHostUntilDone() { - VLOG_CALL(); - +absl::Status Stream::BlockHostUntilDone() { if (!ok()) { absl::MutexLock lock(&mu_); LOG(INFO) << status_.ToString(); - tsl::Status status = tsl::Status( - absl::StatusCode::kInternal, + absl::Status status = absl::InternalError( "stream did not block host until done; was already in an error state"); - LOG(INFO) << DebugStreamPointers() << " " << status; + LOG(INFO) << "stream = " << this << " " << status; return status; } - temporary_memory_manager_.DeallocateFinalizedTemporaries(); - - tsl::Status error = parent_->BlockHostUntilDone(this); + absl::Status error = parent_->BlockHostUntilDone(this); CheckError(error.ok()); - RunAfterBlockHostUntilDoneCallbacks(); return error; } -void Stream::RunAfterBlockHostUntilDoneCallbacks() { - std::vector> callbacks; - { - absl::MutexLock lock(&mu_); - std::swap(callbacks, after_block_host_until_done_callbacks_); - } - for (auto &fn : callbacks) { - std::move(fn)(); - } -} - -std::string Stream::DebugStreamPointers() const { - // Relies on the ToVlogString(const void*) overload above. - return absl::StrCat("[stream=", ToVlogString(this), - ",impl=", ToVlogString(implementation_.get()), "]"); -} - -void Stream::CheckStatus(tsl::Status status) { +void Stream::CheckStatus(absl::Status status) { if (status.ok()) { return; } diff --git a/xla/stream_executor/stream.h b/xla/stream_executor/stream.h index 3021720d85980..274009b06d5f8 100644 --- a/xla/stream_executor/stream.h +++ b/xla/stream_executor/stream.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,32 +21,32 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_STREAM_H_ #define XLA_STREAM_EXECUTOR_STREAM_H_ -#include #include -#include #include #include #include -#include #include #include #include +#include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/blas.h" +#include "absl/types/span.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/stream_executor_pimpl.h" -#include "xla/stream_executor/temporary_memory_manager.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -58,35 +58,7 @@ class DeviceMemoryBase; template class DeviceMemory; -namespace dnn { -class BatchDescriptor; -class FilterDescriptor; -class ConvolutionDescriptor; -class ProfileResult; -class AlgorithmDesc; -} // namespace dnn - class StreamExecutor; -class ScratchAllocator; - -namespace detail { - -// Helper to return if `T` is the same type as `First` or any or `Rest`. -template -constexpr bool is_any_of() { - return false; -} - -template -constexpr bool is_any_of() { - return std::is_same_v || is_any_of(); -} - -} // namespace detail - -// Convert a type to the corresponding QuantizedActivationMode. -template -struct Quantization; // Represents a stream of dependent computations on a GPU device. // @@ -135,18 +107,19 @@ class Stream { // devices should also override AllowsSyncOnCompletion to return false.) For // these devices, this method can be used after work is finished to retrieve // execution status. - tsl::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_); + absl::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_); // Initialize the stream. This must be performed before entraining any other // operations. - Stream &Init() TF_LOCKS_EXCLUDED(mu_); + absl::Status Initialize( + std::optional> priority = std::nullopt); // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. // // TODO(b/112196569): The semantics of failed sub-streams is error-prone. - Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); + absl::StatusOr GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused // later. Sub-streams that are !ok() will not be reused. @@ -154,12 +127,6 @@ class Stream { // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_); - // Allocate temporary memories. The stream will deallocate them when blocked - // or destroyed. - template - tsl::StatusOr>> - AllocateTemporaryArray(uint64_t element_count); - // Entrains onto the stream of operations: a kernel launch with the given // (variadic) parameters for the invocation. These arguments can be things // like DeviceMemory or primitive types such as int. What arguments you may @@ -179,14 +146,24 @@ class Stream { // spit out helpful static_assert error traces with information as to the // argument number and types that were mismatched. template - tsl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, - const TypedKernel &kernel, Args... args); + absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, + const TypedKernel &kernel, Args... args); + + template + absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, + ClusterDim cluster_dims, + const TypedKernel &kernel, Args... args); // Same as above, with an explicit argument for shared memory size in bytes. template - tsl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, - int32_t shmem_bytes, - const TypedKernel &kernel, Args... args); + absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, + int32_t shmem_bytes, + const TypedKernel &kernel, Args... args); + + template + absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, + ClusterDim cluster_dims, int32_t shmem_bytes, + const TypedKernel &kernel, Args... args); // Create a dependency for this stream's next work on the other stream // completing. Does not take ownership of other, and other must not be @@ -195,1139 +172,80 @@ class Stream { // Checks that a stream does not wait for itself, and it is up to the // user to guarantee that a stream does not come to wait on itself in a // cyclic manner; in that case, behavior is undefined. - // - // N.B. Base recursion case for the variadic ThenWaitFor. - Stream &ThenWaitFor(Stream *other); - - // Waits for all streams values in others. - // Checks that there is no shallow circular wait (i.e. that "this" is not in - // others) - template - Stream &ThenWaitFor(P others) { - for (auto &stream : *others) { - CHECK_NE(stream.get(), this); - ThenWaitFor(stream.get()); - } - return *this; - } + absl::Status WaitFor(Stream *other); // Waits for an event object to be set. - // Note that ThenRecordEvent must have been called on the event before + // Note that RecordEvent must have been called on the event before // you call this function; otherwise the event will be considered complete // and this wait will do nothing. - Stream &ThenWaitFor(Event *event); + absl::Status WaitFor(Event *event); // Inserts the specified event into the end of this stream. Once the stream // has processed all events prior to the insertion point, the event will be // marked as completed. // The stream does not take ownership of event - meaning that event's lifetime // must extend past the point at which it is marked complete! - Stream &ThenRecordEvent(Event *event); - - //////////////// - // DNN support - // - // See DnnSupport::* for comments on the following methods. - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, const DeviceMemory &x, - const DeviceMemory &scale, const DeviceMemory &offset, - const DeviceMemory &mean, const DeviceMemory &inv_var, - const DeviceMemory &y, const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationForward( - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, - const DeviceMemory &estimated_mean, - const DeviceMemory &estimated_variance, - const DeviceMemory &side_input, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - const double exponential_average_factor, - dnn::ActivationMode activation_mode, DeviceMemory *y, - DeviceMemory *batch_mean, DeviceMemory *batch_var, - DeviceMemory *saved_mean, DeviceMemory *saved_inv_var, - bool is_training, ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator); - - Stream &ThenBatchNormalizationBackward( - const DeviceMemory &y_backprop, - const DeviceMemory &x, const DeviceMemory &scale, - const DeviceMemory &offset, const DeviceMemory &mean, - const DeviceMemory &inv_var, - const DeviceMemory &y, - const dnn::BatchDescriptor &x_desc, - const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, - dnn::ActivationMode activation_mode, - DeviceMemory *x_backprop, - DeviceMemory *scale_backprop, DeviceMemory *offset_backprop, - DeviceMemory *side_input_backprop, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator); - - Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, - const DeviceMemory &input_data, - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output); - - template - tsl::Status ConvolveWithAlgorithm( - dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor, - DeviceMemory input_data, - const dnn::FilterDescriptor &filter_descriptor, - DeviceMemory filter_data, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory output_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - ScratchAllocator *scratch_allocator, - const dnn::AlgorithmConfig &algorithm_config, - dnn::ProfileResult *output_profile_result) { - DeviceMemory scratch_memory; - dnn::AlgorithmDesc algorithm_desc; - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - TF_RETURN_IF_ERROR(dnn->PrepareForConvolution( - kind, this, input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, convolution_descriptor, - algorithm_config, scratch_allocator, &algorithm_desc, - &scratch_memory)); - return dnn->DoConvolve(kind, dnn::ToDataType::value, - dnn::ToDataType::value, this, - input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, algorithm_desc, - scratch_memory, output_profile_result); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - template - tsl::Status FusedConvolveWithAlgorithm( - const dnn::BatchDescriptor &conv_input_descriptor, - const DeviceMemory &conv_input_data, ScaleT conv_input_scale, - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_data, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const DeviceMemory &side_input_data, ScaleT side_input_scale, - const dnn::BatchDescriptor &bias_descriptor, - const DeviceMemory &biases, dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output, ScratchAllocator *scratch_allocator, - const dnn::AlgorithmConfig &algorithm_config, - dnn::ProfileResult *output_profile_result) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoFusedConvolve( - this, dnn::ToDataType::value, - dnn::ToDataType::value, dnn::ToDataType::value, - dnn::ToDataType::value, conv_input_descriptor, - conv_input_data, conv_input_scale, filter_descriptor, filter_data, - convolution_descriptor, side_input_data, side_input_scale, - bias_descriptor, biases, activation_mode, output_descriptor, *output, - scratch_allocator, algorithm_config, output_profile_result); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - tsl::Status CudnnReorderConvolutionFilterAndBias( - const dnn::FilterDescriptor &filter_descriptor, - const DeviceMemory &filter_input, - DeviceMemory *filter_output, - std::optional> bias_input, - std::optional> bias_output) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->CudnnReorderConvolutionFilterAndBias( - this, filter_descriptor, filter_input, filter_output, - std::move(bias_input), std::move(bias_output)); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - tsl::StatusOr> ConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType output_type, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->ConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor); - } - - tsl::StatusOr> - GraphConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType output_type, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor, - std::string serialized_graph) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->GraphConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, output_type, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor, - serialized_graph); - } - - tsl::StatusOr> - FusedConvolveRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, - dnn::DataType element_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, - double side_input_scale, double leakyrelu_alpha, - const dnn::BatchDescriptor &input_descriptor, - const dnn::FilterDescriptor &filter_descriptor, - const dnn::BatchDescriptor &bias_descriptor, - const dnn::BatchDescriptor &output_descriptor, - const dnn::ConvolutionDescriptor &convolution_descriptor, - dnn::ActivationMode activation_mode) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->FusedConvolveRunnerFromDesc( - this, algorithm_desc, kind, element_type, bias_type, output_type, - conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor, - filter_descriptor, bias_descriptor, output_descriptor, - convolution_descriptor, activation_mode); - } - - tsl::StatusOr> NormRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, double epsilon, - const dnn::TensorDescriptor &input_descriptor, - const dnn::TensorDescriptor &scale_descriptor, - const dnn::TensorDescriptor &bias_descriptor, - const dnn::TensorDescriptor &output_descriptor, - std::optional expectation_descriptor, - std::optional norm_factor_descriptor) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->NormRunnerFromDesc( - this, algorithm_desc, epsilon, input_descriptor, scale_descriptor, - bias_descriptor, output_descriptor, expectation_descriptor, - norm_factor_descriptor); - } - - tsl::StatusOr> - FusedMHARunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor &bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor &output_descriptor, - std::optional activation_descriptor, - std::optional mask_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->FusedMHARunnerFromDesc( - this, algorithm_desc, kind, bmm1_lhs_descriptor, bmm1_rhs_descriptor, - bmm2_rhs_descriptor, intermediate_bmm2_lhs_descriptor, - output_descriptor, activation_descriptor, mask_descriptor, - bias_descriptor, scale, dropout_rate, seed, is_flash_attention, - is_causal_mask); - } - - tsl::StatusOr> - FusedMHABackwardRunnerFromDesc( - const dnn::AlgorithmDesc &algorithm_desc, dnn::FusedMHAKind kind, - const dnn::MatmulTensorDescriptor &bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor &bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor &d_output_descriptor, - const dnn::TensorDescriptor &d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor &d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor &d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional mask_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - bool is_flash_attention, bool is_causal_mask) { - dnn::DnnSupport *dnn_support = parent_->AsDnn(); - if (!dnn_support) { - return absl::UnimplementedError("DNN library is not found."); - } - return dnn_support->FusedMHABackwardRunnerFromDesc( - this, algorithm_desc, kind, bmm1_grad_gemm1_rhs_descriptor, - bmm1_grad_gemm2_rhs_descriptor, bmm2_grad_gemm1_lhs_descriptor, - bmm2_grad_gemm2_rhs_descriptor, d_output_descriptor, - d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, - d_s_descriptor, mask_descriptor, d_bias_descriptor, - fwd_output_descriptor, bias_descriptor, scale, dropout_rate, seed, - is_flash_attention, is_causal_mask); - } - - Stream &ThenSeparableConvolve( - const dnn::BatchDescriptor &input_descriptor, - const DeviceMemory &input_data, - const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier, - const DeviceMemory &first_weights, - const DeviceMemory &second_weights, - const dnn::ConvolutionDescriptor &convolution_descriptor, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output); - - Stream &ThenMatMul(const DeviceMemory &input_data, - const DeviceMemory &weights, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); - - Stream &ThenMatMulQuantized(const DeviceMemory &input_data, - const DeviceMemory &weights, - const DeviceMemory &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); - - Stream &ThenMatMulQuantized(const DeviceMemory &input_data, - const DeviceMemory &weights, - const DeviceMemory &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); - - Stream &ThenBiasAdd(const DeviceMemory &input_data, - const DeviceMemory &biases, - const dnn::BatchDescriptor &dimensions, - DeviceMemory *output_data); - - template - tsl::Status ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, - const dnn::BatchDescriptor &input_dimensions, - const DeviceMemory &input_data, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data, - ScratchAllocator *workspace_allocator = nullptr) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoPoolForward(dnn::ToDataType::value, this, - pooling_dimensions, input_dimensions, - input_data, output_dimensions, *output_data, - workspace_allocator); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - template - tsl::Status ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, - const NumericOptions &numeric_options, - const dnn::BatchDescriptor &input_dimensions, - const DeviceMemory &input_data, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data, - ScratchAllocator *workspace_allocator = nullptr) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoPoolForward(dnn::ToDataType::value, this, - pooling_dimensions, numeric_options, - input_dimensions, input_data, output_dimensions, - *output_data, workspace_allocator); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - template - tsl::Status ThenPoolBackward( - const dnn::PoolingDescriptor &pooling_dimensions, - const NumericOptions &numeric_options, - const dnn::BatchDescriptor &input_dimensions, - const DeviceMemory &input_data, - const dnn::BatchDescriptor &output_dimensions, - const DeviceMemory &output_data, - const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data, - ScratchAllocator *workspace_allocator = nullptr) { - if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - return dnn->DoPoolBackward( - dnn::ToDataType::value, this, pooling_dimensions, - numeric_options, input_dimensions, input_data, output_dimensions, - output_data, input_diff_data, *output_diff_data, workspace_allocator); - } - return tsl::errors::Unimplemented("DNN library is not found."); - } - - Stream &ThenNormalizeWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, DeviceMemory *output_data); - - Stream &ThenNormalizeBackwardWithDimensions( - const dnn::NormalizeDescriptor &normalize_descriptor, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &raw_data, - const DeviceMemory &normalized_data, - const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient, - ScratchAllocator *workspace_allocator = nullptr); - - Stream &ThenActivate(dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - DeviceMemory *output_data); - - // Same as ThenActivate, but also takes an options argument that can be used - // for platform-specific option flags. - Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode, - const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - DeviceMemory *output_data, - uint64_t options); - - Stream &ThenDepthConcatenate( - absl::Span input_dimensions, - absl::Span *const> input_data, - DeviceMemory *output_data); - - Stream &ThenElementwiseOperate( - dnn::ElementwiseOperation operation, - absl::Span input_dimensions, - absl::Span *const> input_data, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); - - Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, int64_t left_pad, - int64_t right_pad, int64_t top_pad, int64_t bottom_pad, - DeviceMemory *output_data); - - Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, int64_t left_trim, - int64_t right_trim, int64_t top_trim, int64_t bottom_trim, - DeviceMemory *output_data); - - // Grows the input tensor by replicating the X and Y dimensions. The batch and - // depth/feature_map dimensions are unchanged. Currently, the input tensor is - // limited to X=1 and Y=1. - Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, - const DeviceMemory &input_data, - int64_t replicate_x, int64_t replicate_y, - DeviceMemory *output_data); - - // See DnnSupport::DoMemcpyD2HQuantized. - Stream &ThenMemcpyD2HQuantized(const DeviceMemory &gpu_unquantized_src, - dnn::QuantizedActivationMode mode, - void *host_dst, uint64_t size); - - // Template version of ThenMemcpyD2HQuantized that takes a mutable span and - // uses the Quantization trait to call the generic version of - // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode. - template - Stream &ThenMemcpyD2HQuantized(const DeviceMemory &gpu_unquantized_src, - absl::Span host_dst) { - return ThenMemcpyD2HQuantized( - gpu_unquantized_src, Quantization::kModeId, - host_dst.data(), host_dst.size() * sizeof(ElementType)); - } - - // See DnnSupport::DoMemcpyH2DQuantized. - Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size, - dnn::QuantizedActivationMode mode, - DeviceMemory *gpu_unquantized_dst); - - // Template version of ThenMemcpyH2DQuantized that takes an array slice - // and uses the Quantization trait to call the generic version of - // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode. - template - Stream &ThenMemcpyH2DQuantized(absl::Span host_src, - DeviceMemory *gpu_unquantized_dst) { - return ThenMemcpyH2DQuantized( - host_src.data(), host_src.size() * sizeof(ElementType), - Quantization::kModeId, gpu_unquantized_dst); - } - - ///////////////// - // BLAS support - - // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is - // present in DeviceMemory, it must be an execution-time constant (i.e. a - // value - // that the stream does not change or populate during the course of - // execution). The value is effectively captured at stream-enqueue time. - Stream &ThenBlasAxpy(uint64_t elem_count, float alpha, - const DeviceMemory &x, int incx, - DeviceMemory *y, int incy); - - // See BlasSupport::DoBlasCopy. - Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory &x, - int incx, DeviceMemory *y, int incy); - - // See BlasSupport::DoBlasScal. - Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory *x, - int incx); - Stream &ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, float alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, double alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx); - Stream &ThenBlasScal(uint64_t elem_count, std::complex alpha, - DeviceMemory> *x, int incx); - - // See BlasSupport::DoBlasGemv. - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - double alpha, const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, double beta, - DeviceMemory *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy); - Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, - std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &x, int incx, - std::complex beta, - DeviceMemory> *y, int incy); - - // See BlasSupport::DoBlasSbmv. - Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &x, int incx, float beta, - DeviceMemory *y, int incy); - - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - DeviceMemory *c, int ldc, - const NumericOptions &numeric_options, - blas::CallContext context) { - InputType alpha{1.0}; - InputType beta{0.0}; - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, numeric_options, context); - } - - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, const NumericOptions &numeric_options, - blas::CallContext context) { - static_assert( - detail::is_any_of, - std::complex>(), - "Input can be int8_t, half, bf16, float, double, std::complex " - "or " - "std::complex"); - static_assert(!std::is_same_v || - detail::is_any_of(), - "If input is Eigen::half, constant has to be either " - "Eigen::half or float"); - static_assert(detail::is_any_of(), - "If input is not int8_t, Eigen::half, constant and input " - "types have to match"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemm( - this, transa, transb, m, n, k, blas::ToDataType::value, - alpha_ptr, a, lda, b, ldb, beta_ptr, c, ldc, numeric_options, context); - } - - // TODO(reedwm): Update all callers to pass correct NumericOptions. - template - tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, ConstantType alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, - ConstantType beta, DeviceMemory *c, - int ldc, blas::CallContext context) { - return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, NumericOptions{}, context); - } - - template - tsl::Status ThenBlasGemmWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, DeviceMemory *c, - int ldc, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result, - blas::CallContext context) { - OutputType alpha{1}; - OutputType beta{0}; - return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, computation_type, - algorithm, NumericOptions{}, - output_profile_result, context); - } - - template - tsl::Status ThenBlasGemmWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, ConstantType beta, - DeviceMemory *c, int ldc, - blas::ComputationType computation_type, blas::AlgorithmType algorithm, - const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - TF_RETURN_IF_ERROR( - CheckTypesForExtendedBlas( - computation_type)); - - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - tsl::Status st = blas->DoBlasGemmWithAlgorithm( - this, transa, transb, m, n, k, alpha_ptr, a, - blas::ToDataType::value, lda, b, - blas::ToDataType::value, ldb, beta_ptr, c, - blas::ToDataType::value, ldc, computation_type, algorithm, - numeric_options, output_profile_result, context); - - if (output_profile_result) { - // The error is recorded in the profile. - return ::tsl::OkStatus(); - } - return st; - } - - template - tsl::Status ThenBlasGemmStridedBatchedWithAlgorithm( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - int64_t stride_a, const DeviceMemory &b, int ldb, - int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, - int64_t stride_c, int batch_count, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, const NumericOptions &numeric_options, - blas::ProfileResult *output_profile_result, blas::CallContext context) { - TF_RETURN_IF_ERROR( - CheckTypesForExtendedBlas( - computation_type)); - - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - tsl::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm( - this, transa, transb, m, n, k, alpha_ptr, a, - blas::ToDataType::value, lda, stride_a, b, - blas::ToDataType::value, ldb, stride_b, beta_ptr, c, - blas::ToDataType::value, ldc, stride_c, batch_count, - computation_type, algorithm, numeric_options, output_profile_result, - context); - if (output_profile_result) { - // The error is recorded in the profile. - return ::tsl::OkStatus(); - } - return st; - } - - template - using DeviceMemorySlice = absl::Span *const>; - - // See BlasSupport::DoBlasGemmBatched. - Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64 k, float alpha, - DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, - int batch_count, - const NumericOptions &numeric_options, - blas::CallContext context); - - Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, - std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, - DeviceMemorySlice> c, int ldc, - int batch_count, - const NumericOptions &numeric_options, - blas::CallContext context); - Stream &ThenBlasGemmBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, float alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, double alpha, DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, double beta, - DeviceMemorySlice c, int ldc, int batch_count, - const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - Stream &ThenBlasGemmBatchedWithScratch( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, std::complex alpha, - DeviceMemorySlice> a, int lda, - DeviceMemorySlice> b, int ldb, - std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, const NumericOptions &numeric_options, - ScratchAllocator *scratch_allocator, blas::CallContext context); - - template - tsl::Status ThenBlasGemmStridedBatched( - blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, - uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, - int64_t stride_a, const DeviceMemory &b, int ldb, - int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, - int64_t stride_c, int batch_count, const NumericOptions &numeric_options, - blas::CallContext context) { - static_assert( - detail::is_any_of, - std::complex>(), - "Unsupported input type"); - static_assert(std::is_same_v || - (detail::is_any_of() && - std::is_same_v), - "Mismatched input and alpha/beta types"); - blas::BlasSupport *blas = parent()->AsBlas(); - if (!blas) { - return tsl::errors::Internal( - "Attempting to perform BLAS operation using " - "StreamExecutor without BLAS support"); - } - - void *alpha_ptr = α - void *beta_ptr = β - float alpha_storage, beta_storage; - UpcastHalfToFloat(&alpha_ptr, &beta_ptr, &alpha_storage, - &beta_storage); - - return blas->DoBlasGemmStridedBatched( - this, transa, transb, m, n, k, blas::ToDataType::value, - alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, - stride_c, batch_count, numeric_options, context); - } - - // See BlasSupport::DoBlasTrsm. - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, float alpha, const DeviceMemory &a, - int lda, DeviceMemory *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, double alpha, const DeviceMemory &a, - int lda, DeviceMemory *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb); - Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, uint64_t m, - uint64_t n, std::complex alpha, - const DeviceMemory> &a, int lda, - DeviceMemory> *b, int ldb); - - // See BlasSupport::DoBlasTrsmBatched. - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, float alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, double alpha, - const DeviceMemory &as, int lda, - DeviceMemory *bs, int ldb, - int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, - int lda, DeviceMemory *> *bs, - int ldb, int batch_count); - Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, - blas::Transpose transa, blas::Diagonal diag, - uint64_t m, uint64 n, std::complex alpha, - const DeviceMemory *> &as, - int lda, DeviceMemory *> *bs, - int ldb, int batch_count); - - // See FftSupport::DoFft. - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, const DeviceMemory &input, - DeviceMemory> *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output); - Stream &ThenFft(fft::Plan *plan, - const DeviceMemory> &input, - DeviceMemory *output); + absl::Status RecordEvent(Event *event); // Entrain onto the stream: a memcpy to a host destination from a GPU source // of the given target size. host_dst must be a pointer to host memory // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and // then registered with StreamExecutor::HostMemoryRegister. - Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, - uint64_t size); + absl::Status Memcpy(void *host_dst, const DeviceMemoryBase &gpu_src, + uint64_t size); // Entrain onto the stream: a memcpy to a GPU destination from a host source // of the given target size. host_src must be a pointer to host memory // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and // then registered with StreamExecutor::HostMemoryRegister. - Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, - uint64_t size); + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, const void *host_src, + uint64_t size); // Alternative interface for memcpying from device to host that takes an // array slice. Checks that the destination size can accommodate the host // slice size. template - Stream &ThenMemcpyD2H(const DeviceMemory &gpu_src, - absl::Span host_dst) { + absl::Status MemcpyD2H(const DeviceMemory &gpu_src, + absl::Span host_dst) { auto host_size = host_dst.size() * sizeof(T); - CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size()); - return ThenMemcpy(host_dst.begin(), gpu_src, host_size); + if (gpu_src.size() == 0 || host_size >= gpu_src.size()) { + return Memcpy(host_dst.begin(), gpu_src, host_size); + } + return absl::InternalError("Bad source size."); } // Alternative interface for memcpying from host to device that takes an // array slice. Checks that the destination size can accommodate the host // slice size. template - Stream &ThenMemcpyH2D(absl::Span host_src, - DeviceMemory *gpu_dst) { + absl::Status MemcpyH2D(absl::Span host_src, + DeviceMemory *gpu_dst) { auto host_size = host_src.size() * sizeof(T); - CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size); - return ThenMemcpy(gpu_dst, host_src.begin(), host_size); + if (gpu_dst->size() == 0 || gpu_dst->size() >= host_size) { + return Memcpy(gpu_dst, host_src.begin(), host_size); + } + return absl::InternalError("Bad destination size."); } // Entrain onto the stream: a memcpy to a GPU destination from a GPU source // of the given target size. gpu_src/dst must be pointers to GPU memory and // peer access must be enabled between their owning StreamExecutors. - Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, - uint64_t size); - - // Calls to the device-to-device copy overload of ThenMemcpy -- useful for - // ensuring that the host pointer isn't getting confused accidentally with a - // device pointer if you're not doing metaprogramming against the API. - Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &gpu_src, uint64_t size) { - return ThenMemcpy(gpu_dst, gpu_src, size); + absl::Status Memcpy(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size); + absl::Status MemcpyD2D(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, uint64_t size) { + return Memcpy(gpu_dst, gpu_src, size); } // Entrain onto the stream: a memset of zero at a GPU location of size bytes. // The location must not be null. - Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size); + absl::Status MemZero(DeviceMemoryBase *location, uint64_t size); // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible // by 4). The location must not be null. - Stream &ThenMemset32(DeviceMemoryBase *location, uint32_t pattern, - uint64_t size); - - // Enqueue a forward operation of the RNN model onto the stream. - // See DnnSupport::DoRnnForward for more details. - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, - bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - DeviceMemory *output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - DeviceMemory *output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - DeviceMemory *output_c_data, bool is_training, - ScratchAllocator *reserve_space_allocator, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - // Enqueue a backward operation of the RNN model onto the stream. - // See DnnSupport::DoRnnBackward for more details. - Stream &ThenRnnBackward( - const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, - const dnn::RnnSequenceTensorDescriptor &input_desc, - const DeviceMemory &input_data, - const DeviceMemory &seq_lengths_data, - const dnn::RnnStateTensorDescriptor &input_h_desc, - const DeviceMemory &input_h_data, - const dnn::RnnStateTensorDescriptor &input_c_desc, - const DeviceMemory &input_c_data, - const DeviceMemory ¶ms, - const dnn::RnnSequenceTensorDescriptor &output_desc, - const DeviceMemory &output_data, - const dnn::RnnStateTensorDescriptor &output_h_desc, - const DeviceMemory &output_h_data, - const dnn::RnnStateTensorDescriptor &output_c_desc, - const DeviceMemory &output_c_data, - const DeviceMemory &output_backprop_data, - const DeviceMemory &output_h_backprop_data, - const DeviceMemory &output_c_backprop_data, - DeviceMemory *input_backprop_data, - DeviceMemory *input_h_backprop_data, - DeviceMemory *input_c_backprop_data, - DeviceMemory *params_backprop_data, - DeviceMemory *reserve_space_data, - ScratchAllocator *workspace_allocator, - dnn::ProfileResult *output_profile_result); - - // Enqueue a CTCLoss operation onto the stream. - // See DnnSupport::DoCtcLoss for more details. - Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, - const DeviceMemory &probs_data, - absl::Span labels_data, - absl::Span labels_lengths_data, - absl::Span input_lengths_data, - const NumericOptions &numeric_options, - DeviceMemory *costs_data, - const dnn::RnnStateTensorDescriptor &grads_desc, - DeviceMemory *grads_data, - ScratchAllocator *workspace_allocator); - - // Enqueue onto the stream a operation that transforms a tensor. - // See DnnSupport::DoTransformTensor for more details. - Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, - dnn::DataType input_type, - const DeviceMemoryBase &input_data, - const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, float scale, - DeviceMemoryBase *output_data); + absl::Status Memset32(DeviceMemoryBase *location, uint32_t pattern, + uint64_t size); // (Synchronously) block the host code waiting for the operations // entrained on the stream (enqueued to this point in program @@ -1335,39 +253,20 @@ class Stream { // // Returns an OK status if the blocking was successful and the stream is ok(). // Otherwise returns an error describing why the blocking failed. - tsl::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_); - - // Warning! This method interacts with internal threads in - // sometimes-unpredictable ways and is intended for GPU-Executor-internal - // use - // only. Please check with a member of the FASTR team before making use of - // this method. - // - // Entrains onto the stream a function to be executed on the host at some - // point in the future. - // Async host callbacks DO NOT block the stream as device functions (or as - // synchronous host callbacks). No synchronization is possible with - // asynchronous callbacks; they are strictly fire-and-forget. - // This method is private due to the potential for undefined behavior with - // synchronization using OpenCL user events. - // The ONLY lifetime guarantee in these calls is that the StreamExecutor - // parameter will still be valid - this Stream may not be! - // Any callbacks requiring device API calls must use this method. - Stream &ThenEnqueueOnBackgroundThread( - std::function task); + absl::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_); // Returns the (opaque) platform-specific backing object. Ownership is not // transferred to the caller. internal::StreamInterface *implementation() { return implementation_.get(); } // Entrains onto the stream a callback to the host (from the device). - // Behaves as ThenDoHostCallbackWithStatus below, but the callback should + // Behaves as DoHostCallbackWithStatus below, but the callback should // never fail or its failure is inconsequential. // // This is kept for backward compatibility. Future code should use - // ThenDoHostCallbackWithStatus and explicitly return a success status. + // DoHostCallbackWithStatus and explicitly return a success status. // TODO(b/112125301): Eventually remove this method. - Stream &ThenDoHostCallback(absl::AnyInvocable callback); + absl::Status DoHostCallback(absl::AnyInvocable callback); // Entrains onto the stream a callback to the host (from the device). // Host callbacks block/occupy the stream just as device functions @@ -1375,14 +274,10 @@ class Stream { // Whether the callback return status affects the result of BlockHostUntilDone // is platform-dependent. // - // Behavior is undefined when synchronizing using OpenCL user events. - // Behavior is undefined if host callbacks call device routines or insert - // them into any stream. - // - // On certain platforms, ThenDoHostCallback is expected to have significant + // On certain platforms, DoHostCallback is expected to have significant // negative effects on performance. - Stream &ThenDoHostCallbackWithStatus( - absl::AnyInvocable callback); + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback); // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { @@ -1398,64 +293,10 @@ class Stream { RocmComputeCapability GetRocmComputeCapability() const { return parent()->GetDeviceDescription().rocm_compute_capability(); } - // Returns the (internal usage) temporary-memory-allocation manager associated - // with this stream. - internal::TemporaryMemoryManager *temporary_memory_manager(); - - // Returns a debugging string "[stream=0x...,impl=0x...]". - std::string DebugStreamPointers() const; - - void SetPriority(StreamPriority priority); - void SetPriority(int priority); std::variant priority() const; private: - template - friend struct ThenBlasImpl; // for implementing ThenBlasXXX. - - // Checks whether types match before a call to extended BLAS version. - template - tsl::Status CheckTypesForExtendedBlas( - blas::ComputationType computation_type) { - static_assert( - detail::is_any_of, std::complex>(), - "The only buffer types supported are: Eigen::half, float, " - "double, int8, std::complex and std::complex"); - static_assert( - std::is_same_v || - (std::is_same_v && - detail::is_any_of()), - "Mismatched alpha/beta and output types"); - - bool valid_computation_type = [computation_type] { - switch (computation_type) { - case blas::ComputationType::kF16: - return std::is_same_v; - case blas::ComputationType::kF32: - return detail::is_any_of>(); - case blas::ComputationType::kF64: - return detail::is_any_of>(); - case blas::ComputationType::kI32: - return std::is_same_v; - case blas::ComputationType::kF16AsF32: // fall-through - case blas::ComputationType::kBF16AsF32: // fall-through - case blas::ComputationType::kTF32AsF32: - return detail::is_any_of>(); - } - }(); - - if (!valid_computation_type) { - return tsl::errors::Internal( - "Invalid computation type ", - blas::ComputationTypeString(computation_type), " for output type: ", - blas::DataTypeString(blas::ToDataType::value)); - } - return ::tsl::OkStatus(); - } - bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); return !status_.ok(); @@ -1466,20 +307,10 @@ class Stream { void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_); // Checks the status and logs the error message, if any. - void CheckStatus(tsl::Status status) TF_LOCKS_EXCLUDED(mu_); + void CheckStatus(absl::Status status) TF_LOCKS_EXCLUDED(mu_); void SetError() { CheckError(false /* = operation_retcode */); } - void SetErrorAndLogNoDnnSupport() { - SetError(); - LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " - "without DNN support"; - } - - // Runs the set of callbacks that are intended to run after - // BlockHostUntilDone. - void RunAfterBlockHostUntilDoneCallbacks(); - // The StreamExecutor that supports the operation of this stream. StreamExecutor *parent_; @@ -1491,13 +322,8 @@ class Stream { // Mutable so that it can be obtained via const reader lock. mutable absl::Mutex mu_; - // Whether Init() was successfully called to allocate this stream on the - // underlying platform. It simply flips from 0 to 1 with a sanity check. - // See StreamExecutor::AllocateStream. - bool allocated_ ABSL_GUARDED_BY(mu_); - // The last error (if any) of all method calls. - tsl::Status status_ ABSL_GUARDED_BY(mu_); + absl::Status status_ ABSL_GUARDED_BY(mu_); // Sub-streams that are generated from this stream. Each element has a pointer // to sub-stream and a boolean value indicating if this substream is ready to @@ -1505,39 +331,6 @@ class Stream { std::vector, bool>> sub_streams_ ABSL_GUARDED_BY(mu_); - // Streams can allocate temporary memories to help with work they enqueue - // (e.g. for scratch memory spaces). This member tracks those allocations and - // notes when they can be reclaimed -- reclamation is attempted when - // BlockHostUntilDone() is called. - internal::TemporaryMemoryManager temporary_memory_manager_; - - // Callbacks enqueued to be run after the next call to BlockHostUntilDone(). - std::vector> - after_block_host_until_done_callbacks_ ABSL_GUARDED_BY(mu_); - - // Non-extended BLAS interface requires alpha/beta to be floats when input - // type is Eigen::half. However, for consistency purposes it is convenient - // for the interface to accept Eigen::half. - template - void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, - float *alpha_storage, float *beta_storage) { - if (std::is_same::value) { - *alpha_storage = - static_cast(*reinterpret_cast(*alpha_ptr)); - *beta_storage = - static_cast(*reinterpret_cast(*beta_ptr)); - *alpha_ptr = alpha_storage; - *beta_ptr = beta_storage; - } else if (std::is_same::value) { - *alpha_storage = - static_cast(*reinterpret_cast(*alpha_ptr)); - *beta_storage = - static_cast(*reinterpret_cast(*beta_ptr)); - *alpha_ptr = alpha_storage; - *beta_ptr = beta_storage; - } - } - Stream(const Stream &) = delete; void operator=(const Stream &) = delete; }; @@ -1546,55 +339,49 @@ class Stream { // Inlines template -inline tsl::Status Stream::ThenLaunch(ThreadDim thread_dims, - BlockDim block_dims, - const TypedKernel &kernel, - Args... args) { +inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, + BlockDim block_dims, + const TypedKernel &kernel, + Args... args) { auto kernel_args = PackKernelArgs(kernel, args...); TF_RETURN_IF_ERROR( - parent_->Launch(this, thread_dims, block_dims, kernel, *kernel_args)); - return ::tsl::OkStatus(); + parent_->Launch(this, thread_dims, block_dims, *kernel, *kernel_args)); + return absl::OkStatus(); } template -inline tsl::Status Stream::ThenLaunch(ThreadDim thread_dims, - BlockDim block_dims, int32_t shmem_bytes, - const TypedKernel &kernel, - Args... args) { +inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, + BlockDim block_dims, int32_t shmem_bytes, + const TypedKernel &kernel, + Args... args) { auto kernel_args = PackKernelArgs(shmem_bytes, args...); TF_RETURN_IF_ERROR( - parent_->Launch(this, thread_dims, block_dims, kernel, *kernel_args)); - return ::tsl::OkStatus(); + parent_->Launch(this, thread_dims, block_dims, *kernel, *kernel_args)); + return absl::OkStatus(); } -template -inline tsl::StatusOr>> -Stream::AllocateTemporaryArray(uint64_t element_count) { - return temporary_memory_manager_.AllocateArray(element_count); +template +inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, + BlockDim block_dims, + ClusterDim cluster_dims, + const TypedKernel &kernel, + Args... args) { + auto kernel_args = PackKernelArgs(kernel, args...); + TF_RETURN_IF_ERROR(parent_->Launch(this, thread_dims, block_dims, + cluster_dims, *kernel, *kernel_args)); + return absl::OkStatus(); } -inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() { - return &temporary_memory_manager_; +template +inline absl::Status Stream::ThenLaunch( + ThreadDim thread_dims, BlockDim block_dims, ClusterDim cluster_dims, + int32_t shmem_bytes, const TypedKernel &kernel, Args... args) { + auto kernel_args = PackKernelArgs(shmem_bytes, args...); + TF_RETURN_IF_ERROR(parent_->Launch(this, thread_dims, block_dims, + cluster_dims, *kernel, *kernel_args)); + return absl::OkStatus(); } -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k8Bit; -}; - -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k16Bit; -}; - -template <> -struct Quantization { - static constexpr dnn::QuantizedActivationMode kModeId = - dnn::QuantizedActivationMode::k32Bit; -}; - } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_STREAM_H_ diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index 76b9921b24052..4cfe8cf327d29 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,16 +22,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ -#include "xla/stream_executor/device_description.h" // IWYU pragma: export -#include "xla/stream_executor/device_memory.h" // IWYU pragma: export -#include "xla/stream_executor/device_options.h" // IWYU pragma: export -#include "xla/stream_executor/event.h" // IWYU pragma: export -#include "xla/stream_executor/kernel.h" // IWYU pragma: export -#include "xla/stream_executor/kernel_spec.h" // IWYU pragma: export -#include "xla/stream_executor/launch_dim.h" // IWYU pragma: export -#include "xla/stream_executor/multi_platform_manager.h" // IWYU pragma: export -#include "xla/stream_executor/platform.h" // IWYU pragma: export -#include "xla/stream_executor/stream.h" // IWYU pragma: export +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export #endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/xla/stream_executor/stream_executor_internal.h b/xla/stream_executor/stream_executor_internal.h index 8c17dbdc15560..c837a5ce6e7a1 100644 --- a/xla/stream_executor/stream_executor_internal.h +++ b/xla/stream_executor/stream_executor_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,17 +27,16 @@ limitations under the License. #include #include #include -#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" @@ -47,10 +46,6 @@ limitations under the License. #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/trace_listener.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -74,148 +69,6 @@ class EventInterface { void operator=(const EventInterface&) = delete; }; -//===----------------------------------------------------------------------===// -// KernelInterface -//===----------------------------------------------------------------------===// - -// Pointer-to-implementation object type (i.e. the Kernel class delegates to -// this interface) with virtual destruction. This class exists for the -// platform-dependent code to hang any kernel data/resource info/functionality -// off of. -class KernelInterface { - public: - // Default constructor for the abstract interface. - KernelInterface() = default; - - // Default destructor for the abstract interface. - virtual ~KernelInterface() = default; - - // Returns the number of formal parameters that this kernel accepts. - virtual unsigned Arity() const = 0; - - // Sets the preferred cache configuration. - virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; - - // Gets the preferred cache configuration. - virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; - - private: - KernelInterface(const KernelInterface&) = delete; - void operator=(const KernelInterface&) = delete; -}; - -//===----------------------------------------------------------------------===// -// CommandBufferInterface -//===----------------------------------------------------------------------===// - -// Platform-dependent interface class for implementing generic CommandBuffer. -// -// TODO(ezhulenev): Currently we assume that all operations between barriers -// can execute concurrently, and it's up to the caller to insert barriers to -// guarantee correctness. Consider adding finer grained synchronization -// mechanism between different commands. -// -// TODO(ezhulenev): Currently command buffers do no support updates, and once -// finalized can be executed as recorded. We need to support cheap command -// buffer updates that in GPU backend will be mapped to CUDA/HIP graph node -// updates. -class CommandBufferInterface { - public: - CommandBufferInterface() = default; - virtual ~CommandBufferInterface() = default; - - // Traces `function` invocation by recording all operations on the `stream` - // into the command buffer. Command buffer must be empty. - virtual tsl::Status Trace(Stream* stream, - absl::AnyInvocable function) = 0; - - // Adds a kernel launch command to the command buffer. - virtual tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgs& args) = 0; - - // Adds a nested command buffer to the command buffer. - virtual tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested) = 0; - - // Adds a device-to-device memory copy to the command buffer. - virtual tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, - const DeviceMemoryBase& src, - uint64_t size) = 0; - - // Adds a memset node to the command buffer. - virtual tsl::Status Memset(DeviceMemoryBase* dst, - CommandBuffer::BitPattern bit_pattern, - size_t num_elements) = 0; - - // Adds a device memory allocation node to the command buffer. - virtual tsl::StatusOr Allocate(size_t bytes) = 0; - - // For all conditional command APIs defined below, nested command buffers - // constructed for conditional branches owned by *this and should never be - // finalized or updated inside builders. - - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`. - virtual tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, - CommandBuffer::Builder then_builder) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // `then_builder` if `predicate` value is `true`, or a command buffer - // constructed by `else_builder` if `predicate` is `false`. - virtual tsl::Status IfElse(StreamExecutor* executor, - DeviceMemory predicate, - CommandBuffer::Builder then_builder, - CommandBuffer::Builder else_builder) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // the `branches` builder at `index`. If `index` is out of range, then it will - // run a conditional command buffer constructed by the last builder. - // - // See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case - virtual tsl::Status Case(StreamExecutor* executor, - DeviceMemory index, - std::vector branches) = 0; - - // Adds a conditional operation that will run a command buffer constructed by - // the `body_builder` exactly `num_iteration` times. - virtual tsl::Status For(StreamExecutor* executor, int32_t num_iteration, - DeviceMemory loop_index, - CommandBuffer::Builder body_builder) = 0; - - // Adds a conditional operation that will execute a command buffer constructed - // by the `cond_builder` that must update `pred` value, and then depending on - // the value might execute command buffer constructed by `body_builder` and - // `cond_builder`. Will continue while `pred` value is `true`. - // - // In pseudocode: - // - // cond_builder() - // while(pred): - // body_builder() - // cond_builder() - // - virtual tsl::Status While(StreamExecutor* executor, DeviceMemory pred, - CommandBuffer::Builder cond_builder, - CommandBuffer::Builder body_builder) = 0; - - // Finalizes command buffer and makes it executable. Once command buffer is - // finalized no commands can be added to it. - virtual tsl::Status Finalize() = 0; - - // Begins command buffer update. Command buffer update should be finalized - // before it can be executed. - virtual tsl::Status Update() = 0; - - // Returns command buffer execution mode. - virtual CommandBuffer::Mode mode() const = 0; - - // Returns command buffer state. - virtual CommandBuffer::State state() const = 0; - - private: - CommandBufferInterface(const CommandBufferInterface&) = delete; - void operator=(const CommandBufferInterface&) = delete; -}; - //===----------------------------------------------------------------------===// // StreamInterface //===----------------------------------------------------------------------===// @@ -233,13 +86,9 @@ class StreamInterface { virtual ~StreamInterface() = default; // Sets priority for a stream. - virtual void SetPriority(StreamPriority priority) { - LOG(ERROR) << "SetPriority unimplemented for this stream."; - } + virtual void SetPriority(StreamPriority priority) {} - virtual void SetPriority(int priority) { - LOG(ERROR) << "SetPriority unimplemented for this stream."; - } + virtual void SetPriority(int priority) {} // Gets priority for a stream. virtual std::variant priority() const { @@ -277,8 +126,7 @@ class StreamExecutorInterface { virtual StreamExecutorInterface* GetUnderlyingExecutor() { return this; } // See the StreamExecutor interface for comments on the same-named methods. - virtual tsl::Status Init(int device_ordinal, - DeviceOptions device_options) = 0; + virtual absl::Status Init(int device_ordinal) = 0; // This value is cached by the wrapping StreamExecutor instance, so it's OK if // this function is slow. @@ -290,27 +138,34 @@ class StreamExecutorInterface { virtual int device_ordinal() const { return -1; } - virtual tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { + virtual absl::Status GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) { return absl::UnimplementedError("Not Implemented"); } virtual bool UnloadModule(ModuleHandle module_handle) { return false; } - virtual tsl::Status LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { + virtual absl::Status LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle) { return absl::UnimplementedError("Not Implemented"); } - virtual tsl::StatusOr> + virtual absl::StatusOr> CreateOrShareConstant(Stream* stream, absl::Span content) { return absl::UnimplementedError("Not Implemented"); } - virtual tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args) { + virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& k, + const KernelArgs& args) { + return absl::UnimplementedError("Not Implemented"); + } + + virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, const Kernel& k, + const KernelArgs& args) { return absl::UnimplementedError("Not Implemented"); } - virtual tsl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) { + virtual absl::Status Submit(Stream* stream, + const CommandBuffer& command_buffer) { return absl::UnimplementedError("Not Implemented"); } @@ -320,8 +175,6 @@ class StreamExecutorInterface { DeviceMemoryBase Allocate(uint64_t size) { return Allocate(size, /*memory_space=*/0); } - virtual void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size) = 0; virtual void Deallocate(DeviceMemoryBase* mem) = 0; // Allocates unified memory space of the given size, if supported. // See @@ -332,47 +185,54 @@ class StreamExecutorInterface { // Deallocates unified memory space previously allocated with // UnifiedMemoryAllocate. virtual void UnifiedMemoryDeallocate(void* mem) {} + virtual absl::StatusOr CollectiveMemoryAllocate(uint64_t size) { + return absl::UnimplementedError("Not implemented"); + } + virtual absl::Status CollectiveMemoryDeallocate(void* mem) { + return absl::UnimplementedError("Not implemented"); + } virtual void* HostMemoryAllocate(uint64_t size) = 0; virtual void HostMemoryDeallocate(void* mem) = 0; virtual bool HostMemoryRegister(void* mem, uint64_t size) = 0; virtual bool HostMemoryUnregister(void* mem) = 0; virtual bool SynchronizeAllActivity() = 0; - virtual tsl::Status SynchronousMemZero(DeviceMemoryBase* location, + virtual absl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) = 0; + virtual absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64_t size) = 0; + virtual absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) = 0; - virtual tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64_t size) = 0; - virtual tsl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, - uint64_t size) = 0; - virtual tsl::Status SynchronousMemcpy(void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64_t size) = 0; - virtual tsl::Status SynchronousMemcpyDeviceToDevice( + virtual absl::Status SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) = 0; + virtual absl::Status SynchronousMemcpyDeviceToDevice( DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) = 0; - virtual tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) = 0; - virtual tsl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - return tsl::errors::Internal("Not implemented"); + virtual absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) = 0; + virtual absl::Status Memset(Stream* stream, DeviceMemoryBase* location, + uint8 pattern, uint64_t size) { + return absl::InternalError("Not implemented"); } - virtual tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) = 0; - virtual bool Memcpy(Stream* stream, void* host_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) = 0; - virtual bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, - const void* host_src, uint64_t size) = 0; + virtual absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) = 0; + virtual absl::Status Memcpy(Stream* stream, void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) = 0; + virtual absl::Status Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, + const void* host_src, uint64_t size) = 0; virtual bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) = 0; virtual bool HostCallback(Stream* stream, - absl::AnyInvocable callback) = 0; - virtual tsl::Status AllocateEvent(Event* event) = 0; - virtual tsl::Status DeallocateEvent(Event* event) = 0; - virtual tsl::Status RecordEvent(Stream* stream, Event* event) = 0; - virtual tsl::Status WaitForEvent(Stream* stream, Event* event) = 0; - virtual tsl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { + absl::AnyInvocable callback) = 0; + virtual absl::Status AllocateEvent(Event* event) = 0; + virtual absl::Status DeallocateEvent(Event* event) = 0; + virtual absl::Status RecordEvent(Stream* stream, Event* event) = 0; + virtual absl::Status WaitForEvent(Stream* stream, Event* event) = 0; + virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream, + Event* event) { return absl::UnimplementedError( "WaitForEventOnExternalStream not supported on this executor."); } @@ -380,12 +240,12 @@ class StreamExecutorInterface { virtual bool AllocateStream(Stream* stream) = 0; virtual void DeallocateStream(Stream* stream) = 0; virtual bool CreateStreamDependency(Stream* dependent, Stream* other) = 0; - virtual tsl::Status BlockHostUntilDone(Stream* stream) = 0; - virtual tsl::Status GetStatus(Stream* stream) { + virtual absl::Status BlockHostUntilDone(Stream* stream) = 0; + virtual absl::Status GetStatus(Stream* stream) { return absl::UnimplementedError( "GetStatus is not supported on this executor."); } - virtual tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) = 0; + virtual absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) = 0; virtual bool CanEnablePeerAccessTo(StreamExecutorInterface* other) = 0; virtual int64_t GetDeviceLoad() { return -1; } @@ -411,25 +271,9 @@ class StreamExecutorInterface { // Creates a new DeviceDescription object. Ownership is transferred to the // caller. - virtual tsl::StatusOr> + virtual absl::StatusOr> CreateDeviceDescription() const = 0; - // Attempts to register the provided TraceListener with the device-specific - // Executor implementation. When this is called, the PIMPL interface has - // already taken ownership of the object and is managing the generic tracing - // events. The device-specific implementation must determine if the passed - // listener is of a type appropriate for it to trace during registration (and - // before dispatching events to it). - // Returns true if the listener was successfully registered, false otherwise. - // Does not take ownership of listener. - virtual bool RegisterTraceListener(TraceListener* listener) { return false; } - - // Unregisters the specified listener from the device-specific Executor. - // Returns true if the listener was successfully registered, false otherwise. - virtual bool UnregisterTraceListener(TraceListener* listener) { - return false; - } - // Creates a new BlasSupport object, ownership is transferred to the caller. // // This may return null if the BLAS initialization fails or this object does @@ -450,19 +294,16 @@ class StreamExecutorInterface { // Each call creates a new instance of the platform-specific implementation of // the corresponding interface type. virtual std::unique_ptr CreateEventImplementation() = 0; - virtual std::unique_ptr CreateKernelImplementation() = 0; virtual std::unique_ptr GetStreamImplementation() = 0; - virtual tsl::StatusOr> - GetCommandBufferImplementation(CommandBuffer::Mode mode) { - return absl::UnimplementedError("Command buffers are not implemented"); + virtual absl::StatusOr> CreateKernel() { + return absl::UnimplementedError("Kernels are not implemented"); } - // Returns a pointer to a platform specific context associated with this - // object if it exists, or nullptr otherwise. This is available via - // StreamExecutor public API as StreamExecuto::PlatformSpecificHandle, and - // should not be accessed directly outside of a StreamExecutor package. - virtual void* platform_specific_context() { return nullptr; } + virtual absl::StatusOr> CreateCommandBuffer( + CommandBuffer::Mode mode) { + return absl::UnimplementedError("Command buffers are not implemented"); + } // Return allocator statistics. virtual std::optional GetAllocatorStats() { @@ -479,7 +320,7 @@ class StreamExecutorInterface { // Clears the compilation cache from volatile memory. Returns OK if no // compilation cache exists or if clearing the compilation cache is // unsupported. Caches in non-volatile storage are unaffected. - virtual tsl::Status FlushCompilationCache() { return ::tsl::OkStatus(); } + virtual absl::Status FlushCompilationCache() { return absl::OkStatus(); } // Returns a stream allocated by this executor, or nullptr if not found. // Performs linear search over alive GPU streams. diff --git a/xla/stream_executor/stream_executor_pimpl.cc b/xla/stream_executor/stream_executor_pimpl.cc index a58beb89c6bab..8105dfcd80608 100644 --- a/xla/stream_executor/stream_executor_pimpl.cc +++ b/xla/stream_executor/stream_executor_pimpl.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,28 +20,46 @@ limitations under the License. #include "xla/stream_executor/stream_executor_pimpl.h" #include +#include #include #include +#include +#include +#include #include +#include -#include "absl/base/const_init.h" #include "absl/functional/any_invocable.h" -#include "absl/strings/ascii.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/notification.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/host_memory_allocation.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/module_spec.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" +#include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" -#include "tsl/util/env_var.h" namespace stream_executor { namespace { @@ -54,77 +72,8 @@ std::string StackTraceIfVLOG10() { } } -// Make sure the executor is done with its work; we know (because this isn't -// publicly visible) that all enqueued work is quick. -void BlockOnThreadExecutor(tsl::thread::ThreadPool* executor) { - absl::Notification n; - executor->Schedule([&n]() { n.Notify(); }); - n.WaitForNotification(); -} - -std::atomic_int_fast64_t correlation_id_generator(0); - } // namespace -template -class ScopedTracer { - public: - ScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call, - CompleteCallT complete_call, const ReturnT* result, - BeginArgsT... begin_args) - : stream_exec_(stream_exec), - complete_call_(complete_call), - result_(result) { - if (stream_exec_->tracing_enabled_) { - correlation_id_ = - correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1; - Trace(begin_call, begin_args...); - } - } - - ~ScopedTracer() { - if (stream_exec_->tracing_enabled_) { - Trace(complete_call_, result_); - } - } - - private: - template - void Trace(CallbackT callback, TraceArgsT... args) { - { - // Instance tracers held in a block to limit the lock lifetime. - absl::ReaderMutexLock lock{&stream_exec_->mu_}; - for (TraceListener* listener : stream_exec_->listeners_) { - (listener->*callback)(correlation_id_, - std::forward(args)...); - } - } - } - - StreamExecutor* stream_exec_; - CompleteCallT complete_call_; - const ReturnT* result_; - int64_t correlation_id_; -}; - -template -ScopedTracer -MakeScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call, - CompleteCallT complete_call, ReturnT* result, - BeginArgsT... begin_args) { - return ScopedTracer( - stream_exec, begin_call, complete_call, result, - std::forward(begin_args)...); -} - -#define SCOPED_TRACE(LOC, ...) \ - auto tracer = \ - MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__); - -/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit}; - // Get per-device memory limit in bytes. Returns 0 if // TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set. static int64_t GetMemoryLimitBytes() { @@ -141,16 +90,11 @@ StreamExecutor::StreamExecutor( : platform_(platform), implementation_(std::move(implementation)), device_ordinal_(device_ordinal), - background_threads_(new tsl::thread::ThreadPool( - tsl::Env::Default(), "stream_executor", kNumBackgroundThreads)), live_stream_count_(0), - tracing_enabled_(false), memory_limit_bytes_(GetMemoryLimitBytes()), allocator_(this) {} StreamExecutor::~StreamExecutor() { - BlockOnThreadExecutor(background_threads_.get()); - if (live_stream_count_.load() != 0) { LOG(WARNING) << "Not all streams were deallocated at executor destruction " << "time. This may lead to unexpected/bad behavior - " @@ -158,23 +102,13 @@ StreamExecutor::~StreamExecutor() { } } -StreamExecutor::PlatformSpecificHandle -StreamExecutor::platform_specific_handle() const { - PlatformSpecificHandle handle; - handle.context = implementation_->platform_specific_context(); - return handle; -} - -tsl::Status StreamExecutor::Init(DeviceOptions device_options) { - TF_RETURN_IF_ERROR( - implementation_->Init(device_ordinal_, std::move(device_options))); - return ::tsl::OkStatus(); +absl::Status StreamExecutor::Init() { + TF_RETURN_IF_ERROR(implementation_->Init(device_ordinal_)); + return absl::OkStatus(); } -tsl::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); } - -tsl::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { +absl::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec& spec, + Kernel* kernel) { return implementation_->GetKernel(spec, kernel); } @@ -182,8 +116,8 @@ void StreamExecutor::UnloadKernel(const Kernel* kernel) { implementation_->UnloadKernel(kernel); } -tsl::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { +absl::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle) { return implementation_->LoadModule(spec, module_handle); } @@ -191,7 +125,7 @@ bool StreamExecutor::UnloadModule(ModuleHandle module_handle) { return implementation_->UnloadModule(module_handle); } -tsl::StatusOr> +absl::StatusOr> StreamExecutor::CreateOrShareConstant(Stream* stream, absl::Span content) { return implementation_->CreateOrShareConstant(stream, content); @@ -209,7 +143,7 @@ bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor* other) { return implementation_->CanEnablePeerAccessTo(other->implementation_.get()); } -tsl::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor* other) { +absl::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor* other) { return implementation_->EnablePeerAccessTo(other->implementation_.get()); } @@ -227,187 +161,6 @@ int64_t StreamExecutor::GetDeviceLoad() const { return implementation_->GetDeviceLoad(); } -tsl::Status StreamExecutor::GetConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, - ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, - std::vector>* out_exec_plans) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->GetConvolveRunners( - use_cudnn_frontend, kind, input_type, output_type, stream, - input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, use_fallback, - scratch_allocator, numeric_options, out_exec_plans); -} - -tsl::Status StreamExecutor::GetGraphConvolveRunners( - dnn::ConvolutionKind kind, dnn::DataType input_type, - dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans, - std::string serialized_graph) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->GetGraphConvolveRunners( - kind, input_type, output_type, stream, input_descriptor, - filter_descriptor, output_descriptor, convolution_descriptor, - use_fallback, numeric_options, out_exec_plans, serialized_graph); -} - -tsl::Status StreamExecutor::GetFusedConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, double side_input_scale, - double leakyrelu_alpha, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, - dnn::ActivationMode activation_mode, const NumericOptions& numeric_options, - std::vector>* out_exec_plans) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - return dnn_support->GetFusedConvolveRunners( - use_cudnn_frontend, kind, input_type, bias_type, output_type, - conv_input_scale, side_input_scale, leakyrelu_alpha, stream, - input_descriptor, filter_descriptor, bias_descriptor, output_descriptor, - convolution_descriptor, use_fallback, activation_mode, numeric_options, - out_exec_plans); -} - -tsl::Status StreamExecutor::GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType input_type, dnn::DataType bias_type, - dnn::DataType output_type, Stream* stream, bool trans_a, bool trans_b, - uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, - const NumericOptions& numeric_options, - std::vector>* - out_exec_plans) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::errors::Unimplemented("DNN library is not found."); - } - - return dnn_support->GetFusedMatmulRunners( - use_cudnn_frontend, input_type, bias_type, output_type, stream, trans_a, - trans_b, m, n, k, lda, ldb, ldc, activation_mode, use_fallback, - numeric_options, out_exec_plans); -} - -bool StreamExecutor::GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return false; - } - return dnn_support->GetMIOpenConvolveAlgorithms( - kind, element_type, stream, input_descriptor, input_data, - filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, scratch_allocator, out_algorithms); -} - -bool StreamExecutor::GetRnnAlgorithms( - std::vector* out_algorithms) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return false; - } - return dnn_support->GetRnnAlgorithms(out_algorithms); -} - -bool StreamExecutor::GetBlasGemmAlgorithms( - Stream* stream, std::vector* out_algorithms) { - blas::BlasSupport* blas_support = AsBlas(); - if (!blas_support) { - return false; - } - return blas_support->GetBlasGemmAlgorithms(stream, out_algorithms); -} - -tsl::StatusOr> -StreamExecutor::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int cell_size, - int batch_size, dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, - dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, uint64_t seed, - ScratchAllocator* state_allocator, bool use_padded_io) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::Status(absl::StatusCode::kUnknown, - "Fail to find the dnn implementation."); - } - return dnn_support->createRnnDescriptor( - num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, - direction_mode, rnn_mode, data_type, algorithm_config, numeric_options, - dropout, seed, state_allocator, use_padded_io); -} - -tsl::StatusOr> -StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length, - int batch_size, int data_size, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::Status(absl::StatusCode::kUnknown, - "Fail to find the dnn implementation."); - } - return dnn_support->createRnnSequenceTensorDescriptor( - max_seq_length, batch_size, data_size, data_type); -} - -tsl::StatusOr> -StreamExecutor::createRnnSequenceTensorDescriptor( - int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, bool time_major, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::Status(absl::StatusCode::kUnknown, - "Fail to find the dnn implementation."); - } - return dnn_support->createRnnSequenceTensorDescriptor( - max_seq_length, batch_size, data_size, seq_lengths, time_major, - data_type); -} - -tsl::StatusOr> -StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size, - int data_size, - dnn::DataType data_type) { - dnn::DnnSupport* dnn_support = AsDnn(); - if (!dnn_support) { - return tsl::Status(absl::StatusCode::kUnknown, - "Fail to find the dnn implementation."); - } - return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size, - data_size, data_type); -} - dnn::DnnSupport* StreamExecutor::AsDnn() { absl::MutexLock lock(&mu_); if (dnn_ != nullptr) { @@ -438,30 +191,36 @@ fft::FftSupport* StreamExecutor::AsFft() { return fft_.get(); } -tsl::Status StreamExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, - const KernelArgs& args) { - SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims, - kernel, args); - +absl::Status StreamExecutor::Launch(Stream* stream, + const ThreadDim& thread_dims, + const BlockDim& block_dims, + const Kernel& kernel, + const KernelArgs& args) { return implementation_->Launch(stream, thread_dims, block_dims, kernel, args); } -tsl::Status StreamExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - return implementation_->Submit(stream, command_buffer); +absl::Status StreamExecutor::Launch(Stream* stream, + const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, + const Kernel& kernel, + const KernelArgs& args) { + return implementation_->Launch(stream, thread_dims, block_dims, cluster_dims, + kernel, args); } -tsl::Status StreamExecutor::BlockHostUntilDone(Stream* stream) { - tsl::Status result; - SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream); +absl::Status StreamExecutor::Submit(Stream* stream, + const CommandBuffer& command_buffer) { + return implementation_->Submit(stream, command_buffer); +} +absl::Status StreamExecutor::BlockHostUntilDone(Stream* stream) { + absl::Status result; result = implementation_->BlockHostUntilDone(stream); return result; } -tsl::Status StreamExecutor::GetStatus(Stream* stream) { +absl::Status StreamExecutor::GetStatus(Stream* stream) { return implementation_->GetStatus(stream); } @@ -482,12 +241,7 @@ DeviceMemoryBase StreamExecutor::Allocate(uint64_t size, int64_t memory_space) { return buf; } -void* StreamExecutor::GetUntypedSubBuffer(DeviceMemoryBase* parent, - uint64_t offset, uint64_t size) { - return implementation_->GetSubBuffer(parent, offset, size); -} - -tsl::StatusOr StreamExecutor::GetUntypedSymbol( +absl::StatusOr StreamExecutor::GetUntypedSymbol( const std::string& symbol_name, ModuleHandle module_handle) { // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to // be nullptr/0 for consistency with DeviceMemory semantics. @@ -497,8 +251,7 @@ tsl::StatusOr StreamExecutor::GetUntypedSymbol( return DeviceMemoryBase(opaque, bytes); } - return tsl::Status( - absl::StatusCode::kNotFound, + return absl::NotFoundError( absl::StrCat("Check if module containing symbol ", symbol_name, " is loaded (module_handle = ", reinterpret_cast(module_handle.id()), ")")); @@ -524,18 +277,38 @@ void StreamExecutor::UnifiedMemoryDeallocate(void* location) { return implementation_->UnifiedMemoryDeallocate(location); } -void* StreamExecutor::HostMemoryAllocate(uint64_t size) { +absl::StatusOr StreamExecutor::CollectiveMemoryAllocate(uint64_t bytes) { + TF_ASSIGN_OR_RETURN(void* buffer, + implementation_->CollectiveMemoryAllocate(bytes)); + VLOG(1) << "Called StreamExecutor::CollectiveMemoryAllocate(size=" << bytes + << ") returns " << buffer << StackTraceIfVLOG10(); + return buffer; +} + +absl::Status StreamExecutor::CollectiveMemoryDeallocate(void* location) { + VLOG(1) << "Called StreamExecutor::CollectiveMemoryDeallocate(location=" + << location << ")" << StackTraceIfVLOG10(); + + return implementation_->CollectiveMemoryDeallocate(location); +} + +absl::StatusOr> +StreamExecutor::HostMemoryAllocate(uint64_t size) { void* buffer = implementation_->HostMemoryAllocate(size); VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size << ") returns " << buffer << StackTraceIfVLOG10(); - return buffer; + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, implementation()); } -void StreamExecutor::HostMemoryDeallocate(void* location) { - VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location - << ")" << StackTraceIfVLOG10(); +void StreamExecutor::HostMemoryDeallocate(void* data, uint64_t size) { + VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(data=" << data << ")" + << StackTraceIfVLOG10(); - return implementation_->HostMemoryDeallocate(location); + return implementation_->HostMemoryDeallocate(data); } bool StreamExecutor::SynchronizeAllActivity() { @@ -543,15 +316,11 @@ bool StreamExecutor::SynchronizeAllActivity() { << StackTraceIfVLOG10(); bool ok = implementation_->SynchronizeAllActivity(); - // This should all be quick and infallible work, so we can perform the - // synchronization even in the case of failure. - BlockOnThreadExecutor(background_threads_.get()); - return ok; } -tsl::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) { +absl::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) { VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location << ", size=" << size << ")" << StackTraceIfVLOG10(); @@ -565,7 +334,7 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst, << device_dst->opaque() << ", device_src=" << device_src.opaque() << ", size=" << size << ") D2D" << StackTraceIfVLOG10(); - tsl::Status status = implementation_->SynchronousMemcpyDeviceToDevice( + absl::Status status = implementation_->SynchronousMemcpyDeviceToDevice( device_dst, device_src, size); if (!status.ok()) { LOG(ERROR) << "synchronous memcpy: " << status; @@ -573,48 +342,37 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst, return status.ok(); } -tsl::Status StreamExecutor::SynchronousMemcpyD2H( +absl::Status StreamExecutor::SynchronousMemcpyD2H( const DeviceMemoryBase& device_src, int64_t size, void* host_dst) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src=" << device_src.opaque() << ", size=" << size << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10(); - tsl::Status result; - SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size, - host_dst); - - result = implementation_->SynchronousMemcpy(host_dst, device_src, size); + absl::Status result = + implementation_->SynchronousMemcpy(host_dst, device_src, size); if (!result.ok()) { - result = tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat("failed to synchronously memcpy device-to-host: device " - "%p to host %p size %d: %s", - device_src.opaque(), host_dst, size, - result.ToString())); + result = absl::InternalError(absl::StrFormat( + "failed to synchronously memcpy device-to-host: device " + "%p to host %p size %d: %s", + device_src.opaque(), host_dst, size, result.ToString())); } return result; } -tsl::Status StreamExecutor::SynchronousMemcpyH2D(const void* host_src, - int64_t size, - DeviceMemoryBase* device_dst) { +absl::Status StreamExecutor::SynchronousMemcpyH2D( + const void* host_src, int64_t size, DeviceMemoryBase* device_dst) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")" << StackTraceIfVLOG10(); - tsl::Status result; - SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size, - device_dst); - - result = implementation_->SynchronousMemcpy(device_dst, host_src, size); + absl::Status result = + implementation_->SynchronousMemcpy(device_dst, host_src, size); if (!result.ok()) { - result = tsl::Status( - absl::StatusCode::kInternal, - absl::StrFormat("failed to synchronously memcpy host-to-device: host " - "%p to device %p size %d: %s", - host_src, device_dst->opaque(), size, - result.ToString())); + result = absl::InternalError(absl::StrFormat( + "failed to synchronously memcpy host-to-device: host " + "%p to device %p size %d: %s", + host_src, device_dst->opaque(), size, result.ToString())); } return result; @@ -622,12 +380,12 @@ tsl::Status StreamExecutor::SynchronousMemcpyH2D(const void* host_src, bool StreamExecutor::Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& device_src, uint64_t size) { - return implementation_->Memcpy(stream, host_dst, device_src, size); + return implementation_->Memcpy(stream, host_dst, device_src, size).ok(); } bool StreamExecutor::Memcpy(Stream* stream, DeviceMemoryBase* device_dst, const void* host_src, uint64_t size) { - return implementation_->Memcpy(stream, device_dst, host_src, size); + return implementation_->Memcpy(stream, device_dst, host_src, size).ok(); } bool StreamExecutor::MemcpyDeviceToDevice(Stream* stream, @@ -638,41 +396,42 @@ bool StreamExecutor::MemcpyDeviceToDevice(Stream* stream, size); } -tsl::Status StreamExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) { +absl::Status StreamExecutor::MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) { return implementation_->MemZero(stream, location, size); } -tsl::Status StreamExecutor::Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) { +absl::Status StreamExecutor::Memset32(Stream* stream, + DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) { CHECK_EQ(0, size % 4) << "need 32-bit multiple size to fill with 32-bit pattern"; return implementation_->Memset32(stream, location, pattern, size); } bool StreamExecutor::HostCallback( - Stream* stream, absl::AnyInvocable callback) { + Stream* stream, absl::AnyInvocable callback) { return implementation_->HostCallback(stream, std::move(callback)); } -tsl::Status StreamExecutor::AllocateEvent(Event* event) { +absl::Status StreamExecutor::AllocateEvent(Event* event) { return implementation_->AllocateEvent(event); } -tsl::Status StreamExecutor::DeallocateEvent(Event* event) { +absl::Status StreamExecutor::DeallocateEvent(Event* event) { return implementation_->DeallocateEvent(event); } -tsl::Status StreamExecutor::RecordEvent(Stream* stream, Event* event) { +absl::Status StreamExecutor::RecordEvent(Stream* stream, Event* event) { return implementation_->RecordEvent(stream, event); } -tsl::Status StreamExecutor::WaitForEvent(Stream* stream, Event* event) { +absl::Status StreamExecutor::WaitForEvent(Stream* stream, Event* event) { return implementation_->WaitForEvent(stream, event); } -tsl::Status StreamExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { +absl::Status StreamExecutor::WaitForEventOnExternalStream(std::intptr_t stream, + Event* event) { return implementation_->WaitForEventOnExternalStream(stream, event); } @@ -680,6 +439,13 @@ Event::Status StreamExecutor::PollForEventStatus(Event* event) { return implementation_->PollForEventStatus(event); } +absl::StatusOr> StreamExecutor::CreateStream( + std::optional> priority) { + auto stream = std::make_unique(this); + TF_RETURN_IF_ERROR(stream->Initialize(priority)); + return std::move(stream); +} + bool StreamExecutor::AllocateStream(Stream* stream) { live_stream_count_.fetch_add(1, std::memory_order_relaxed); if (!implementation_->AllocateStream(stream)) { @@ -719,38 +485,6 @@ bool StreamExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { return implementation_->DeviceMemoryUsage(free, total); } -void StreamExecutor::EnqueueOnBackgroundThread(std::function task) { - background_threads_->Schedule(std::move(task)); -} - -void StreamExecutor::RegisterTraceListener(TraceListener* listener) { - { - absl::MutexLock lock(&mu_); - if (listeners_.find(listener) != listeners_.end()) { - LOG(INFO) << "Attempt to register already-registered listener, " - << listener; - } else { - listeners_.insert(listener); - } - } - - implementation_->RegisterTraceListener(listener); -} - -bool StreamExecutor::UnregisterTraceListener(TraceListener* listener) { - { - absl::MutexLock lock(&mu_); - if (listeners_.find(listener) == listeners_.end()) { - LOG(INFO) << "Attempt to unregister unknown listener, " << listener; - return false; - } - listeners_.erase(listener); - } - - implementation_->UnregisterTraceListener(listener); - return true; -} - std::optional StreamExecutor::GetAllocatorStats() { return implementation_->GetAllocatorStats(); } @@ -763,19 +497,6 @@ Stream* StreamExecutor::FindAllocatedStream(void* gpu_stream) { return implementation_->FindAllocatedStream(gpu_stream); } -template -void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT&&... args) { - if (tracing_enabled_) { - { - // instance tracers held in a block to limit the lock lifetime. - absl::ReaderMutexLock lock(&mu_); - for (TraceListener* listener : listeners_) { - (listener->*trace_call)(std::forward(args)...); - } - } - } -} - internal::StreamExecutorInterface* StreamExecutor::implementation() { return implementation_->GetUnderlyingExecutor(); } @@ -792,7 +513,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -tsl::StatusOr StreamExecutorMemoryAllocator::Allocate( +absl::StatusOr StreamExecutorMemoryAllocator::Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) { TF_ASSIGN_OR_RETURN(StreamExecutor * executor, @@ -800,7 +521,7 @@ tsl::StatusOr StreamExecutorMemoryAllocator::Allocate( DeviceMemoryBase result = executor->AllocateArray(size, memory_space); if (size > 0 && result == nullptr) { - return tsl::errors::ResourceExhausted(absl::StrFormat( + return absl::ResourceExhaustedError(absl::StrFormat( "Failed to allocate request for %s (%uB) on device ordinal %d", tsl::strings::HumanReadableNumBytes(size), size, device_ordinal)); } @@ -810,8 +531,8 @@ tsl::StatusOr StreamExecutorMemoryAllocator::Allocate( return OwningDeviceMemory(result, device_ordinal, this); } -tsl::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, - DeviceMemoryBase mem) { +absl::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + DeviceMemoryBase mem) { if (!mem.is_null()) { TF_ASSIGN_OR_RETURN(StreamExecutor * executor, GetStreamExecutor(device_ordinal)); @@ -819,13 +540,13 @@ tsl::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, mem.opaque(), device_ordinal); executor->Deallocate(&mem); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -tsl::StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( - int device_ordinal) const { +absl::StatusOr +StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const { if (device_ordinal < 0) { - return tsl::errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "device ordinal value (%d) must be non-negative", device_ordinal)); } for (StreamExecutor* se : stream_executors_) { @@ -833,7 +554,7 @@ tsl::StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( return se; } } - return tsl::errors::NotFound( + return absl::NotFoundError( absl::StrFormat("Device %s:%d present but not supported", platform()->Name(), device_ordinal)); } @@ -842,24 +563,21 @@ bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const { return false; } -tsl::StatusOr StreamExecutorMemoryAllocator::GetStream( +absl::StatusOr StreamExecutorMemoryAllocator::GetStream( int device_ordinal) { CHECK(!AllowsAsynchronousDeallocation()) << "The logic below only works for synchronous allocators"; TF_ASSIGN_OR_RETURN(StreamExecutor * executor, GetStreamExecutor(device_ordinal)); - Stream* out = [&] { - absl::MutexLock lock(&mutex_); - if (!streams_.count(device_ordinal)) { - auto p = streams_.emplace(std::piecewise_construct, - std::forward_as_tuple(device_ordinal), - std::forward_as_tuple(executor)); - p.first->second.Init(); - return &p.first->second; - } - return &streams_.at(device_ordinal); - }(); - return out; + absl::MutexLock lock(&mutex_); + if (!streams_.count(device_ordinal)) { + auto p = streams_.emplace(std::piecewise_construct, + std::forward_as_tuple(device_ordinal), + std::forward_as_tuple(executor)); + TF_RETURN_IF_ERROR(p.first->second.Initialize()); + return &p.first->second; + } + return &streams_.at(device_ordinal); } } // namespace stream_executor diff --git a/xla/stream_executor/stream_executor_pimpl.h b/xla/stream_executor/stream_executor_pimpl.h index d7235077d44f4..2717d71eb3ebf 100644 --- a/xla/stream_executor/stream_executor_pimpl.h +++ b/xla/stream_executor/stream_executor_pimpl.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,35 +17,37 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ #include +#include #include +#include #include #include -#include #include +#include #include #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/fft.h" +#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/module_spec.h" -#include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/trace_listener.h" +#include "tsl/platform/logging.h" #include "tsl/platform/status.h" -#include "tsl/platform/threadpool.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { @@ -55,11 +57,6 @@ namespace internal { class StreamExecutorInterface; } // namespace internal -// Forward declaration of private friend class. -template -class ScopedTracer; - // A StreamExecutor manages a single device, in terms of executing work (kernel // launches) and memory management (allocation/deallocation, memory copies to // and from the device). It is conceptually the "handle" for a device -- Stream @@ -74,12 +71,6 @@ class ScopedTracer; // StreamExecutor interface should not be invoked from a signal handler. class StreamExecutor { public: - // Platform specific handle to the underlying resources behind an executor - // implementation (e.g. it gives access to CUcontext for CUDA platform). - struct PlatformSpecificHandle { - void* context = nullptr; // will be nullptr if not supported - }; - StreamExecutor( const Platform* platform, std::unique_ptr implementation, @@ -87,13 +78,7 @@ class StreamExecutor { ~StreamExecutor(); - // TODO(ezhulenev): Consider removing this platform-specific accessor and - // forward all users to platform-specific headers, however it requires careful - // build rules set up to avoid leaking even more implementation details. - PlatformSpecificHandle platform_specific_handle() const; - - tsl::Status Init(); - tsl::Status Init(DeviceOptions device_options); + absl::Status Init(); // Returns a reference to the platform that created this executor. const Platform* platform() const { return platform_; } @@ -111,7 +96,7 @@ class StreamExecutor { // // If an error occurs, or there is no kernel available for the StreamExecutor // platform, error status is returned. - tsl::Status GetKernel(const MultiKernelLoaderSpec& spec, Kernel* kernel); + absl::Status GetKernel(const MultiKernelLoaderSpec& spec, Kernel* kernel); // Releases any state associated with the previously loaded kernel. void UnloadKernel(const Kernel* kernel); @@ -121,13 +106,13 @@ class StreamExecutor { // `spec` describes the module to be loaded. On success writes the handle for // the loaded module to `module_handle` and returns OkStatus(). Otherwise, // returns the error which has occurred. - tsl::Status LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle); + absl::Status LoadModule(const MultiModuleLoaderSpec& spec, + ModuleHandle* module_handle); // Unloads the module with handle `module_handle`. bool UnloadModule(ModuleHandle module_handle); - tsl::StatusOr> CreateOrShareConstant( + absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content); // Synchronously allocates an array on the device of type T with element_count @@ -155,20 +140,8 @@ class StreamExecutor { return AllocateOwnedArray(1); } - // Allocate a memory region inside another allocated memory region. - // Offset and size are specified in terms of T elements. - // Warning: Do not free a parent buffer before its sub-buffers; this may cause - // use-after-free issues (the specific behavior is not consistent across - // platforms). - // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a - // sub-buffer after parent deallocation is expected to be safe. This will - // render your code non-platform-portable, however. - template - DeviceMemory GetSubBuffer(DeviceMemory* parent, uint64_t element_offset, - uint64_t element_count); - // An untyped version of GetSymbol. - tsl::StatusOr GetUntypedSymbol( + absl::StatusOr GetUntypedSymbol( const std::string& symbol_name, ModuleHandle module_handle); // Deallocate the DeviceMemory previously allocated via this interface. @@ -188,14 +161,22 @@ class StreamExecutor { // UnifiedMemoryAllocate. void UnifiedMemoryDeallocate(void* location); + // Allocate collective device memory using ncclMemAlloc. + // See + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html + // for more details on User Buffer Registration. + absl::StatusOr CollectiveMemoryAllocate(uint64_t bytes); + + // Deallocate collective device memory previously allocated with + // CollectiveMemoryAllocate. + absl::Status CollectiveMemoryDeallocate(void* location); + // Allocates a region of host memory and registers it with the platform API. // Memory allocated in this manner (or allocated and registered with // HostMemoryRegister() is required for use in asynchronous memcpy operations, // such as Stream::ThenMemcpy. - void* HostMemoryAllocate(uint64_t size); - - // Deallocates a region of host memory allocated by HostMemoryAllocate(). - void HostMemoryDeallocate(void* location); + absl::StatusOr> HostMemoryAllocate( + uint64_t size); // Synchronizes all activity occurring in the StreamExecutor's context (most // likely a whole device). @@ -203,27 +184,27 @@ class StreamExecutor { // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the // given location in device memory. - tsl::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) ABSL_MUST_USE_RESULT; + absl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) ABSL_MUST_USE_RESULT; // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above. - tsl::Status SynchronousMemcpyH2D(const void* host_src, int64_t size, - DeviceMemoryBase* device_dst); + absl::Status SynchronousMemcpyH2D(const void* host_src, int64_t size, + DeviceMemoryBase* device_dst); // Alternative interface for memcpying from host to device that takes an // array slice. Checks that the destination size can accommodate the host // slice size. template - tsl::Status SynchronousMemcpyH2D(absl::Span host_src, - DeviceMemoryBase* device_dst) { + absl::Status SynchronousMemcpyH2D(absl::Span host_src, + DeviceMemoryBase* device_dst) { auto host_size = host_src.size() * sizeof(T); CHECK(device_dst->size() == 0 || device_dst->size() >= host_size); return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst); } // Same as SynchronousMemcpy(void*, ...) above. - tsl::Status SynchronousMemcpyD2H(const DeviceMemoryBase& device_src, - int64_t size, void* host_dst); + absl::Status SynchronousMemcpyD2H(const DeviceMemoryBase& device_src, + int64_t size, void* host_dst); // Blocks the caller while a data segment of the given size is copied from the // device source to the device destination. @@ -234,15 +215,15 @@ class StreamExecutor { // Enqueues an operation onto stream to zero out size bytes at the given // device memory location. Neither stream nor location may be null. Returns // whether the operation was successfully enqueued onto the stream. - tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) ABSL_MUST_USE_RESULT; + absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) ABSL_MUST_USE_RESULT; // Enqueues an operation onto stream to set 32-bit patterns starting at // location, for byte count given by size. size must be 32-bit quantified // (i.e. evently divisible by 4). Returns whether the operation was // successfully enqueued onto the stream. - tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size); + absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size); // Enables peer access from this StreamExecutor to memory // allocated by other, such that launched device code, memcpies, etc may @@ -251,7 +232,7 @@ class StreamExecutor { // Both this StreamExecutor and other must be backed by the same platform (as // in // CUDA vs OpenCL) implementation. - tsl::Status EnablePeerAccessTo(StreamExecutor* other); + absl::Status EnablePeerAccessTo(StreamExecutor* other); // Returns whether it's possible to enable peer access from this // StreamExecutor @@ -279,103 +260,6 @@ class StreamExecutor { // will be reflected in "free". bool DeviceMemoryUsage(int64_t* free, int64_t* total) const; - // Returns the supported algorithms / execution plans for a convolution. - tsl::Status GetConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans); - - tsl::Status GetGraphConvolveRunners( - dnn::ConvolutionKind kind, dnn::DataType input_type, - dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, const NumericOptions& numeric_options, - std::vector>* out_exec_plans, - std::string serialized_graph); - - tsl::Status GetFusedConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType bias_type, - dnn::DataType output_type, double conv_input_scale, - double side_input_scale, double leakyrelu_alpha, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const dnn::FilterDescriptor& filter_descriptor, - const dnn::BatchDescriptor& bias_descriptor, - const dnn::BatchDescriptor& output_descriptor, - const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, dnn::ActivationMode activation_mode, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans); - - tsl::Status GetFusedMatmulRunners( - bool use_cudnn_frontend, dnn::DataType input_type, - dnn::DataType bias_type, dnn::DataType output_type, Stream* stream, - bool trans_a, bool trans_b, uint64_t m, uint64_t n, uint64_t k, - int64_t lda, int64_t ldb, int64_t ldc, - dnn::ActivationMode activation_mode, bool use_fallback, - const NumericOptions& numeric_options, - std::vector>* - out_exec_plans); - - // Returns the list of supported algorithms for the forward convolution - // operation. - bool GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms); - - // Returns the list of supported algorithms for rnn operation. - bool GetRnnAlgorithms(std::vector* out_algorithms); - - // Get the list of supported algorithms for BLAS gemm. - bool GetBlasGemmAlgorithms(Stream* stream, - std::vector* out_algorithms); - - // Create an RNN descriptor based on model shapes and configurations. - // The caller retains the ownership of the descriptor. - tsl::StatusOr> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int cell_size, - int batch_size, dnn::RnnInputMode input_mode, - dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, - dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - const NumericOptions& numeric_options, float dropout, uint64_t seed, - ScratchAllocator* state_allocator, bool use_padded_io); - - // Create a RNN sequence descriptor that specifies either the input or output - // sequence. The caller retains the ownership of the returned descriptor. - tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, dnn::DataType data_type); - - tsl::StatusOr> - createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, - int data_size, - const absl::Span& seq_lengths, - bool time_major, dnn::DataType data_type); - - // Create an RNN state descriptor that specifies the input or hidden state. - // The caller retains the ownership of the returned descriptor. - tsl::StatusOr> - createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, - dnn::DataType data_type); - // Returns the device ordinal that this StreamExecutor was initialized with. // Meaningless before initialization. int device_ordinal() const { return device_ordinal_; } @@ -383,22 +267,6 @@ class StreamExecutor { // Returns a borrowed pointer to the underlying StreamExecutor implementation. internal::StreamExecutorInterface* implementation(); - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a - // PTX (and optional CUBIN), such that the types of the arguments provided for - // launch would have to match types of the arguments provided at creation - // time. The canonical storage for both ptx and cubin_data should outlive the - // lifetime of the kernel. - template - tsl::StatusOr>> CreateTypedKernel( - absl::string_view kernel_name, absl::string_view ptx, - absl::Span cubin_data); - - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from - // an in-process symbol pointer. - template - tsl::StatusOr>> CreateTypedKernel( - absl::string_view kernel_name, void* symbol); - // Warning: use Stream::ThenLaunch instead, this method is not for general // consumption. However, this is the only way to launch a kernel for which // the type signature is only known at runtime; say, if an application @@ -412,12 +280,17 @@ class StreamExecutor { // // This is called by Stream::Launch() to delegate to the platform's launch // implementation in StreamExecutorInterface::Launch(). - tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args); + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& kernel, + const KernelArgs& args); + + absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, const Kernel& kernel, + const KernelArgs& args); // Submits command buffer for execution to the underlying platform driver. - tsl::Status Submit(Stream* stream, const CommandBuffer& command_buffer); + absl::Status Submit(Stream* stream, const CommandBuffer& command_buffer); // Gets-or-creates (creates with memoization) a FftSupport datatype that can // be used to execute FFT routines on the current platform. @@ -450,18 +323,6 @@ class StreamExecutor { // underlying platform. blas::BlasSupport* AsBlas(); - // Registers a trace listener to receive callbacks for only a single - // StreamExecutor instance. - // To register a listener for all executors for a given platform, see - // Platform::RegisterTraceListener(). - // Does not take ownership of listener. - void RegisterTraceListener(TraceListener* listener); - - // Removes a TraceListener from this StreamExecutor instance. - // Returns false (and logs) in cases where the argument listener was not - // previously registered. - bool UnregisterTraceListener(TraceListener* listener); - // Return allocator statistics. std::optional GetAllocatorStats(); @@ -477,32 +338,29 @@ class StreamExecutor { // Performs linear search over alive GPU streams. Stream* FindAllocatedStream(void* gpu_stream); + // Creates and initializes a Stream. + absl::StatusOr> CreateStream( + std::optional> priority = std::nullopt); + // Deallocates a region of host memory allocated by HostMemoryAllocate(). + void HostMemoryDeallocate(void* data, uint64_t size); + private: - template - friend class ScopedTracer; friend class Event; friend class Stream; - template - friend class TypedKernel; - template - friend struct ThenBlasImpl; + friend class HostMemoryAllocation; // Synchronously allocates size bytes on the underlying platform and returns // a DeviceMemoryBase representing that allocation. In the case of failure, // nullptr is returned. DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space); - void* GetUntypedSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size); - // Causes the host code to synchronously wait for operations entrained // onto stream to complete. Effectively a join on the asynchronous device // operations enqueued on the stream before this program point. - tsl::Status BlockHostUntilDone(Stream* stream); + absl::Status BlockHostUntilDone(Stream* stream); // Without blocking the device, retrieve the current stream status. - tsl::Status GetStatus(Stream* stream); + absl::Status GetStatus(Stream* stream); // Finds and retrieves device memory for the symbol on the underlying // platform. @@ -530,24 +388,24 @@ class StreamExecutor { // See Stream::ThenDoHostCallback for full details. // This is the preferred form for a callback that may return an error. bool HostCallback(Stream* stream, - absl::AnyInvocable callback); + absl::AnyInvocable callback); // Performs platform-specific allocation and initialization of an event. - tsl::Status AllocateEvent(Event* event); + absl::Status AllocateEvent(Event* event); // Performs platform-specific deallocation and cleanup of an event. - tsl::Status DeallocateEvent(Event* event); + absl::Status DeallocateEvent(Event* event); // Inserts the specified event at the end of the specified stream. - tsl::Status RecordEvent(Stream* stream, Event* event); + absl::Status RecordEvent(Stream* stream, Event* event); // Wait for the specified event at the end of the specified stream. - tsl::Status WaitForEvent(Stream* stream, Event* event); + absl::Status WaitForEvent(Stream* stream, Event* event); // Wait for the specified event at the end of the raw platform-specific // stream. Currently only implemented for GPU, where stream is a // GpuStreamHandle (e.g. cudaStream_t). - tsl::Status WaitForEventOnExternalStream(std::intptr_t stream, Event* event); + absl::Status WaitForEventOnExternalStream(std::intptr_t stream, Event* event); // Requests the current status of the event from the underlying platform. Event::Status PollForEventStatus(Event* event); @@ -567,21 +425,6 @@ class StreamExecutor { // ownership transfer to caller. std::unique_ptr CreateDeviceDescription() const; - // Adds a task to the tsl::thread::ThreadPool work queue. These tasks must be - // fire-and-forget and have no external data or timing dependencies; their - // execution order and completion time have no guarantees. - // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal; - // there, temporary internal buffers are freed using this method. - void EnqueueOnBackgroundThread(std::function task); - - // Calls the relevant TraceListener routine to begin tracing for the specified - // asynchronous method. - template - void SubmitTrace(TraceCallT trace_call, ArgsT&&... args); - - // Reader/writer lock for class-static StreamExecutor members. - static absl::Mutex static_mu_; - // Reader/writer lock for mutable data structures on this StreamExecutor. // // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.) @@ -618,16 +461,6 @@ class StreamExecutor { // Immutable post-initialization. int device_ordinal_; - // Executor for handling host callback work that cannot be performed - // by a host callback thread - for example, cleanup after a host BLAS routine - // (which may make device API calls). This work cannot block the host - // callback thread, will be completed asynchronously, and should be treated - // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued - // here. - // - // Immutable post-initialization. Object is thread-safe. - std::unique_ptr background_threads_; - // Counter for the current number of live streams. This is used to check // for accidentally-outstanding streams at StreamExecutor teardown time, as // well @@ -639,12 +472,6 @@ class StreamExecutor { // executor. static constexpr int kNumBackgroundThreads = 1; - // Indicates if StreamExecutor operation tracing should be performed. - bool tracing_enabled_; - - // The set of TraceListeners registered for this StreamExecutor. - std::set listeners_ ABSL_GUARDED_BY(mu_); - // Memory limit in bytes. Value less or equal to 0 indicates there is no // limit. int64_t memory_limit_bytes_; @@ -694,35 +521,6 @@ class ScopedModuleHandle { //////////// // Inlines -template -inline tsl::StatusOr>> -StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, - absl::string_view ptx, - absl::Span cubin_data) { - auto kernel_base = std::make_unique>(this); - MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); - loader_spec.AddCudaPtxInMemory(ptx, kernel_name); - - if (!cubin_data.empty()) { - loader_spec.AddCudaCubinInMemory( - reinterpret_cast(cubin_data.data()), kernel_name); - } - - TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get())); - return std::move(kernel_base); -} - -template -inline tsl::StatusOr>> -StreamExecutor::CreateTypedKernel(absl::string_view kernel_name, void* symbol) { - auto kernel_base = std::make_unique>(this); - MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); - loader_spec.AddInProcessSymbol(symbol, kernel_name); - - TF_RETURN_IF_ERROR(GetKernel(loader_spec, kernel_base.get())); - return std::move(kernel_base); -} - template inline DeviceMemory StreamExecutor::AllocateArray(uint64_t element_count, int64_t memory_space) { @@ -743,32 +541,12 @@ ScopedDeviceMemory::ScopedDeviceMemory( : ScopedDeviceMemory(parent, parent->AllocateArray(values.size())) { if (ptr() != nullptr) { std::vector local(values); - if (!parent->SynchronousMemcpy(ptr(), const_cast(&local[0]), - ptr()->size())) { + if (!parent->SynchronousMemcpy(ptr(), local.data(), ptr()->size())) { TF_CHECK_OK(Free()); } } } -template -DeviceMemory StreamExecutor::GetSubBuffer(DeviceMemory* parent, - uint64_t element_offset, - uint64_t element_count) { - if (element_offset + element_count > parent->ElementCount()) { - LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater " - << "than parent allocation size: (" << element_offset << " + " - << element_count << ") vs. (" << parent->ElementCount() << ")"; - return DeviceMemory{}; - } - - void* opaque = GetUntypedSubBuffer(parent, sizeof(T) * element_offset, - sizeof(T) * element_count); - if (opaque == nullptr) { - return DeviceMemory{}; - } - return DeviceMemory(DeviceMemoryBase(opaque, sizeof(T) * element_count)); -} - } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ diff --git a/xla/stream_executor/stream_executor_test.cc b/xla/stream_executor/stream_executor_test.cc new file mode 100644 index 0000000000000..8dd6c8b36e418 --- /dev/null +++ b/xla/stream_executor/stream_executor_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/stream_executor.h" + +#include + +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { + +static std::unique_ptr NewStreamExecutor() { + Platform* platform = PlatformManager::PlatformWithName("Host").value(); + StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).value(); +} + +TEST(StreamExecutorTest, HostMemoryAllocate) { + auto executor = NewStreamExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto allocation, executor->HostMemoryAllocate(1024)); + EXPECT_NE(allocation->opaque(), nullptr); + EXPECT_EQ(allocation->size(), 1024); +} + +} // namespace stream_executor diff --git a/xla/stream_executor/stream_test.cc b/xla/stream_executor/stream_test.cc index 3517c1dd3684e..7801abe629431 100644 --- a/xla/stream_executor/stream_test.cc +++ b/xla/stream_executor/stream_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/log/check.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor { @@ -22,40 +28,41 @@ namespace { class StreamTest : public ::testing::Test { protected: std::unique_ptr NewStreamExecutor() { - Platform* platform = MultiPlatformManager::PlatformWithName("Host").value(); + Platform* platform = PlatformManager::PlatformWithName("Host").value(); StreamExecutorConfig config(/*ordinal=*/0); return platform->GetUncachedExecutor(config).value(); } }; -TEST_F(StreamTest, NoInitNotOk) { +TEST_F(StreamTest, InitOk) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - EXPECT_FALSE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); } -TEST_F(StreamTest, InitOk) { +TEST_F(StreamTest, InitWithIntPriorityOk) { + std::unique_ptr executor = NewStreamExecutor(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(1)); +} + +TEST_F(StreamTest, InitWithStreamPriorityOk) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, + executor->CreateStream(StreamPriority::Highest)); } TEST_F(StreamTest, OneSubStream) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get and return a sub-stream. Sub-streams are always initialized. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream1, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream1->ok()); - stream.ReturnSubStream(sub_stream1); + stream->ReturnSubStream(sub_stream1); // Get and return another sub-stream. - Stream* sub_stream2 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream2, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream2->ok()); - stream.ReturnSubStream(sub_stream1); + stream->ReturnSubStream(sub_stream1); // The underlying sub-streams should be the same, since sub_stream1 // was returned before we tried to get sub_stream2. @@ -64,14 +71,12 @@ TEST_F(StreamTest, OneSubStream) { TEST_F(StreamTest, TwoSubStreams) { std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get two sub-streams. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream1, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream1->ok()); - Stream* sub_stream2 = stream.GetOrCreateSubStream(); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream2, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream2->ok()); // The underlying sub-streams should be different, since neither @@ -79,123 +84,19 @@ TEST_F(StreamTest, TwoSubStreams) { EXPECT_NE(sub_stream1, sub_stream2); // Return sub_stream1 and get sub_stream3, which should be the same. - stream.ReturnSubStream(sub_stream1); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); + stream->ReturnSubStream(sub_stream1); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream3, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream3->ok()); EXPECT_EQ(sub_stream1, sub_stream3); EXPECT_NE(sub_stream2, sub_stream3); // Return sub_stream2 and get sub_stream4, which should be the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream4 = stream.GetOrCreateSubStream(); + stream->ReturnSubStream(sub_stream2); + TF_ASSERT_OK_AND_ASSIGN(Stream * sub_stream4, stream->GetOrCreateSubStream()); EXPECT_TRUE(sub_stream4->ok()); EXPECT_EQ(sub_stream2, sub_stream4); EXPECT_NE(sub_stream3, sub_stream4); } -TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { - std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); - - // Get sub_stream1. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream1->ok()); - - // Force an error on sub_stream1; here we call a method that requires DNN - // support, which we know the Host platform doesn't support. - sub_stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(sub_stream1->ok()); - - // Return sub_stream1 and get sub_stream2. - stream.ReturnSubStream(sub_stream1); - Stream* sub_stream2 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream2->ok()); - - // The underlying sub_streams should be different. They would have been the - // same, but since we forced an error on sub_stream1, it will not be - // re-used. Sadly we can't just check: - // EXPECT_NE(sub_stream1, sub_stream2); - // - // The above should hold logically, but it may fail if the new Stream instance - // allocated for sub_stream2 happens to reside in the same memory address as - // sub_stream1. - // - // The check that sub_stream2->ok() serves as a good-enough check. - - // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 - // has no effect on these streams, and they are the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream3->ok()); - EXPECT_EQ(sub_stream2, sub_stream3); -} - -TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { - std::unique_ptr executor = NewStreamExecutor(); - Stream stream(executor.get()); - stream.Init(); - EXPECT_TRUE(stream.ok()); - - // Get and return sub_stream1. - Stream* sub_stream1 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream1->ok()); - stream.ReturnSubStream(sub_stream1); - - // Force an error on sub_stream1; here we call a method that requires DNN - // support, which we know the Host platform doesn't support. - // - // It is a bit weird to use sub_stream1 after it has already been returned. By - // doing this, we're simulating an asynchronous error that occurs during - // execution of the sub_stream, that occurs after the sub_stream is returned. - // - // E.g. the following is a common pattern of usage, where the execution of the - // operations enqueued onto the sub streams may occur after the streams have - // already been returned. - // - // void EnqueueOnSubStreams(Stream* stream) { - // Stream* sub_stream1 = stream.GetOrCreateSubStream(); - // Stream* sub_stream2 = stream.GetOrCreateSubStream(); - // // ... enqueue some operations on the sub streams ... - // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); - // stream.ReturnSubStream(sub_stream1); - // stream.ReturnSubStream(sub_stream2); - // } - // - // Stream* main_stream = ...; - // EnqueueOnSubStreams(main_stream); - // main_stream.BlockHostUntilDone(); - // - // TODO(b/112196569): The semantics of failed sub-streams is error-prone; - // GetOrCreateSubStream can still return a sub-stream that has not encountered - // an error yet, but will encounter one in the future, based on previously - // enqueued operations. - sub_stream1->ThenDepthConcatenate({}, {}, nullptr); - EXPECT_FALSE(sub_stream1->ok()); - - // Get and return sub_stream2. - Stream* sub_stream2 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream2->ok()); - - // The underlying streams should be different. They would have been the same, - // but since we forced an error on sub_stream1, it will not be re-used. Sadly - // we can't just check: - // EXPECT_NE(sub_stream1, sub_stream2); - // - // The above should hold logically, but it may fail if the new stream instance - // allocated for sub_stream2 happens to reside in the same memory address as - // sub_stream1. - // - // The check that sub_stream2->ok() serves as a good-enough check. - - // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 - // has no effect on these streams, and they are the same. - stream.ReturnSubStream(sub_stream2); - Stream* sub_stream3 = stream.GetOrCreateSubStream(); - EXPECT_TRUE(sub_stream3->ok()); - EXPECT_EQ(sub_stream2, sub_stream3); -} - } // namespace } // namespace stream_executor diff --git a/xla/stream_executor/temporary_device_memory.cc b/xla/stream_executor/temporary_device_memory.cc deleted file mode 100644 index 54f45cca06a92..0000000000000 --- a/xla/stream_executor/temporary_device_memory.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/temporary_device_memory.h" - -#include "xla/stream_executor/stream.h" - -namespace stream_executor { - -TemporaryDeviceMemoryBase::~TemporaryDeviceMemoryBase() { - parent_->temporary_memory_manager()->MarkFinalized(device_memory_, - allocation_generation_, - /*must_exist=*/false); -} - -DeviceMemoryBase* TemporaryDeviceMemoryBase::mutable_device_memory() { - DCHECK(!IsFinalized()) - << "should not access device memory after finalization"; - return &device_memory_; -} - -const DeviceMemoryBase& TemporaryDeviceMemoryBase::device_memory() const { - DCHECK(!IsFinalized()) - << "should not access device memory after finalization"; - return device_memory_; -} - -void TemporaryDeviceMemoryBase::Finalize() { - DCHECK(!IsFinalized()) << "should not finalize more than once"; - parent_->temporary_memory_manager()->MarkFinalized(device_memory_, - allocation_generation_, - /*must_exist=*/true); -} - -bool TemporaryDeviceMemoryBase::IsFinalized() const { - return parent_->temporary_memory_manager()->IsFinalized( - device_memory_, allocation_generation_); -} - -bool TemporaryDeviceMemoryBase::IsAllocated() const { - return parent_->temporary_memory_manager()->HasAllocated( - device_memory_, allocation_generation_); -} - -TemporaryDeviceMemoryBase::TemporaryDeviceMemoryBase( - Stream* parent, DeviceMemoryBase device_memory, - uint64_t allocation_generation) - : device_memory_(device_memory), - allocation_generation_(allocation_generation), - parent_(parent) { - DCHECK(IsAllocated()); -} - -} // namespace stream_executor diff --git a/xla/stream_executor/temporary_device_memory.h b/xla/stream_executor/temporary_device_memory.h deleted file mode 100644 index b12fd2efceb2c..0000000000000 --- a/xla/stream_executor/temporary_device_memory.h +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Temporary memories are used to allocate scratch space required by an -// operation about to be enqueued onto a stream. -// -// std::unique_ptr> temporary_memory = -// stream.AllocateTemporaryArray(1024).value(); -// // ... enqueue stuff onto the stream using the temporary memory ... -// // Note that the memory is accessible via -// // temporary_memory->device_memory() and similar. -// -// // Finalize the temporary memory. The underlying device memory may -// // be released any time after this program point, as another thread may -// // call Stream::BlockHostUntilDone, causing synchronization. This -// // finalization also happens automatically for the user if the unique_ptr -// // goes out of scope. -// temporary_memory.Finalize(); -// -// WARNING: do NOT hold onto the device memory associated with temporary_memory -// after finalization. If temporary_memory->device_memory() is used after the -// temporary memory is finalized, it will cause a DCHECK failure. -// -// Note that standard usage takes advantage of the type-safe wrapper, -// TemporaryDeviceMemory, defined below. -// -// Also see tests for executable sample usage. - -#ifndef XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ -#define XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ - -#include "xla/stream_executor/device_memory.h" - -namespace stream_executor { - -class Stream; -namespace internal { -class TemporaryMemoryManager; -} - -// Untyped base class (analogous to a void*) for temporary device memory -// allocations associated with a stream. -class TemporaryDeviceMemoryBase { - public: - // Marks the temporary memory as finalized if it is not already marked as - // such. - ~TemporaryDeviceMemoryBase(); - - // Precondition: !IsFinalized() - DeviceMemoryBase* mutable_device_memory(); - - // Precondition: !IsFinalized() - const DeviceMemoryBase& device_memory() const; - - // "Finalizes" this temporary memory, making it acceptable to release at the - // next stream synchronization point -- the device memory can be reclaimed at - // any time after the temporary memory is marked as finalized (e.g. if a - // separate thread is calls Stream::BlockHostUntilDone). This may only be - // called once -- see the precondition below. - // - // Precondition: !IsFinalized() - void Finalize(); - - // Returns true iff the temporary memory is finalized (that is, the user is - // done referring to the temporary device memory, and thus it can be released - // at the next stream synchronization point). - bool IsFinalized() const; - - // Returns true iff the temporary memory is still allocated. - // - // Note: this is a polling call, no guarantee is made that the temporary - // memory is still allocated after the call has completed. - bool IsAllocated() const; - - private: - friend class internal::TemporaryMemoryManager; - friend class TemporaryDeviceMemoryTest; - - // Note: construction DCHECKs that the memory is known-allocated in the - // stream's temporary-allocation-manager. - TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory, - uint64_t allocation_generation); - - // The device memory region that has allocated. - DeviceMemoryBase device_memory_; - - // The generation counter value for the temporary memory record in the - // temporary memory manager. - uint64_t allocation_generation_; - - // The stream that this temporary memory was allocated for. - Stream* parent_; -}; - -// Type-safe wrapper around the base type (which is analogous to a void*). -template -class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { - public: - // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory. - DeviceMemory* mutable_device_memory() { - StaticSlicingAssertionDummy(); - return reinterpret_cast*>( - TemporaryDeviceMemoryBase::mutable_device_memory()); - } - - // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory. - const DeviceMemory& device_memory() const { - StaticSlicingAssertionDummy(); - return reinterpret_cast&>( - TemporaryDeviceMemoryBase::device_memory()); - } - - private: - static void StaticSlicingAssertionDummy() { - static_assert( - sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase), - "derived class is simply a wrapper, no members may be added due to " - "slicing"); - } -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ diff --git a/xla/stream_executor/temporary_memory_manager.cc b/xla/stream_executor/temporary_memory_manager.cc deleted file mode 100644 index f5b2955a3d265..0000000000000 --- a/xla/stream_executor/temporary_memory_manager.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/temporary_memory_manager.h" - -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/logging.h" - -namespace stream_executor { -namespace internal { - -void TemporaryMemoryManager::ForceDeallocateAll() { - absl::MutexLock lock(&mutex_); - VLOG(1) << "force-deallocating " << records_.size() << " remaining records"; - for (auto it = records_.begin(); it != records_.end(); ++it) { - DeviceMemoryBase device_memory = it->first; - stream_->parent()->Deallocate(&device_memory); - } -} - -void TemporaryMemoryManager::MarkFinalized( - const DeviceMemoryBase& device_memory, uint64_t generation, - bool must_exist) { - absl::MutexLock lock(&mutex_); - auto it = records_.find(device_memory); - if (it == records_.end()) { - if (must_exist) { - LOG(FATAL) << "attempted to mark finalization for temporary " - "memory that does not exist"; - } - return; - } - it->second.finalized = true; -} - -void TemporaryMemoryManager::DeallocateFinalizedTemporaries() { - absl::MutexLock lock(&mutex_); - int deallocated_count = 0; - for (auto it = records_.begin(); it != records_.end();) { - if (it->second.finalized) { - DeviceMemoryBase device_memory = it->first; - stream_->parent()->Deallocate(&device_memory); - ++deallocated_count; - it = records_.erase(it); - } else { - ++it; - } - } - VLOG(1) << "deallocated " << deallocated_count << " finalized temporaries"; -} - -bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory, - uint64_t allocation_generation) const { - absl::MutexLock lock(&mutex_); - auto it = records_.find(device_memory); - if (it == records_.end()) { - return true; // If there's no record present it's vacuously finalized. - } - - if (it->second.allocation_generation == allocation_generation) { - return it->second.finalized; - } - - // If the allocation generation did not match, it's vacuously true. - return true; -} - -bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory, - uint64_t generation) const { - absl::MutexLock lock(&mutex_); - auto it = records_.find(device_memory); - if (it == records_.end()) { - return false; - } - return it->second.allocation_generation == generation; -} - -tsl::StatusOr> -TemporaryMemoryManager::AllocateArrayBase(uint64_t element_count, - uint64_t element_size) { - uint64_t byte_size = element_count * element_size; - DeviceMemoryBase device_memory = - stream_->parent()->AllocateArray(byte_size); - if (device_memory == nullptr) { - return tsl::Status(absl::StatusCode::kResourceExhausted, - absl::StrCat("could not allocate temporary memory of ", - byte_size, " bytes")); - } - - uint64_t generation; - - // Add the record before instantiating the device memory instance so we can - // check the allocation invariant at TemporaryDeviceMemory construction time. - { - absl::MutexLock lock(&mutex_); - generation = ++generation_; - DCHECK(records_.find(device_memory) == records_.end()); - records_[device_memory] = {generation, - /*finalized=*/false}; - } - - VLOG(1) << absl::StreamFormat( - "stream %p allocated temporary device memory at %p (size %u) in " - "generation %u", - stream_, device_memory.opaque(), byte_size, generation); - std::unique_ptr result( - new TemporaryDeviceMemoryBase(stream_, device_memory, generation)); - return std::move(result); -} - -} // namespace internal -} // namespace stream_executor diff --git a/xla/stream_executor/temporary_memory_manager.h b/xla/stream_executor/temporary_memory_manager.h deleted file mode 100644 index 6a82fe0667235..0000000000000 --- a/xla/stream_executor/temporary_memory_manager.h +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The temporary-memory-manager is a helper class for a Stream to keep track of -// temporary allocations. These allocations defer their deallocation to the next -// Stream::BlockHostUntilDone call for efficiency purposes (as deallocation -// itself generally forces synchronization to occur). - -#ifndef XLA_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ -#define XLA_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ - -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/temporary_device_memory.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { -namespace internal { - -// Record used inside the TemporaryMemoryManager as metadata for a given device -// memory region. -struct TemporaryMemoryRecord { - // What "generation" this record was allocated in. - // - // Currently the generation counter is bumped for every allocation, but this - // could be made coarser if necessary. - uint64_t allocation_generation; - - // Notes whether the temporary memory has been marked as finalized, such that - // we can release the DeviceMemory associated with this record at - // synchronization time. - bool finalized; -}; - -// Manages temporary memories associated with a stream -- keeps records of -// outstanding temporaries and their state, and can deallocate them -// appropriately at points in the Stream lifecycle (e.g. BlockHostUntilDone, -// destruction). -class TemporaryMemoryManager { - public: - explicit TemporaryMemoryManager(Stream* stream) : stream_(stream) {} - - // Allocates a temporary array that is then managed by this object. - template - tsl::StatusOr>> AllocateArray( - uint64_t element_count); - - // Forces deallocation of all managed temporary memory regions. - // - // Called, for example, when the Stream owning this temporary memory manager - // is destroyed. - // - // Note: These calls to Deallocate will likely force synchronization. - void ForceDeallocateAll(); - - // Marks the given memory region as finalized. - // - // If must_exist is set, this will check-fail if the temporary memory record - // is not found. - void MarkFinalized(const DeviceMemoryBase& device_memory, uint64_t generation, - bool must_exist); - - // Deallocates temporary memories that have been finalized. - // - // Note: These calls to Deallocate will likely force synchronization, so it is - // meant to be called before a "BlockHostUntilDone" is about to be performed. - void DeallocateFinalizedTemporaries(); - - // Returns whether the provided device_memory is finalized. - // - // In the vacuous case where the device memory doesn't appear in the temporary - // memory records, it is either not a temporary at all, or has already been - // deallocated, and thus returns true. - bool IsFinalized(const DeviceMemoryBase& device_memory, - uint64_t allocation_generation) const; - - // Returns whether the manager has a live allocation record for the given - // device memory pointer with the given generation counter. - // - // Note: this is a polling call -- there is no guarantee that the region is - // still allocated once the call has completed. - bool HasAllocated(const DeviceMemoryBase& device_memory, - uint64_t generation) const; - - private: - // Allocates an array without type parameterization, so that the - // implementation can live in the source file. Without this base allocation - // method, we incur a circular dependency between the StreamExecutor - // definition and this class' definition. - tsl::StatusOr> AllocateArrayBase( - uint64_t element_count, uint64 element_size); - - // Mutex to guard temporary record state. - mutable absl::Mutex mutex_; - - // Mapping from device memory to the current (live) temporary memory record. - // - // If a device memory is not in this mapping, it is not a temporary currently - // allocated and owned by this temporary memory manager. - std::map records_ - ABSL_GUARDED_BY(mutex_); - - // Allocation generation -- we bump this counter to distinguish temporary - // memory handles that have been deallocated and later reallocated at the same - // device memory address. - uint64_t generation_ ABSL_GUARDED_BY(mutex_); - - // The stream (parent object) for this temporary memory manager -- allocations - // are performed through this stream handle. - Stream* stream_; - - TemporaryMemoryManager(const TemporaryMemoryManager&) = delete; - void operator=(const TemporaryMemoryManager&) = delete; -}; - -//////////// -// Inlines - -template -tsl::StatusOr>> -TemporaryMemoryManager::AllocateArray(uint64_t element_count) { - tsl::StatusOr> temporary_memory = - AllocateArrayBase(element_count, sizeof(T)); - if (!temporary_memory.ok()) { - return temporary_memory.status(); - } - - return std::unique_ptr>( - reinterpret_cast*>(temporary_memory->release())); -} - -} // namespace internal -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ diff --git a/xla/stream_executor/tpu/BUILD b/xla/stream_executor/tpu/BUILD index 5dcf813a05a4c..0e7f62c4a0d87 100644 --- a/xla/stream_executor/tpu/BUILD +++ b/xla/stream_executor/tpu/BUILD @@ -1,17 +1,16 @@ # Description: StreamExecutor Interface for TPUs -load("//xla:xla.bzl", "xla_cc_test") -load("@tsl//tsl:tsl.bzl", "set_external_visibility") +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = set_external_visibility([ + default_visibility = internal_visibility([ "//learning/brain/experimental/dtensor:__subpackages__", "//learning/brain/google/xla/kernels:__subpackages__", "//learning/brain/research/pjrt:__subpackages__", "//learning/brain/tfrc/executor:__subpackages__", - "//learning/brain/tfrc/runtime/tpu_driver:__subpackages__", "//learning/brain/tfrt/tpu_plugin:__subpackages__", "//tensorflow/compiler/jit:__subpackages__", "//tensorflow/compiler/mlir:__subpackages__", @@ -33,7 +32,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":libtftpu_header", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", ], ) @@ -97,12 +96,12 @@ xla_cc_test( "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:protobuf", - "@tsl//tsl/platform:statusor", ], ) @@ -142,9 +141,9 @@ cc_library( hdrs = ["tsl_status_helper.h"], deps = [ ":c_api_decl", + "//xla/tsl/c:tsl_status", + "//xla/tsl/c:tsl_status_helper", "@com_google_absl//absl/status", - "@tsl//tsl/c:tsl_status", - "@tsl//tsl/c:tsl_status_helper", "@tsl//tsl/platform:status", ], ) @@ -202,15 +201,17 @@ cc_library( ":tpu_stream_interface", ":tpu_topology_external", "//xla/stream_executor", + "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/platform", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:types", ], alwayslink = True, @@ -230,12 +231,14 @@ cc_library( ":tpu_executor_c_api_hdrs", ":tpu_platform_interface", ":tpu_topology_external", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", + "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) @@ -256,16 +259,17 @@ cc_library( ":tpu_executor_api", ":tpu_executor_c_api_hdrs", ":tpu_topology_external", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:types", ], ) @@ -304,17 +308,18 @@ cc_library( "//xla:status", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", + "//xla/tsl/c:tsl_status_internal", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@tsl//tsl/c:tsl_status_internal", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:types", ], ) @@ -352,9 +357,9 @@ cc_library( "//xla/service:backend", "//xla/service:stream_pool", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:macros", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) @@ -381,7 +386,7 @@ cc_library( ":tpu_platform_interface", "//xla:status", "//xla/service:transfer_manager", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", ], ) @@ -420,11 +425,11 @@ cc_library( "//xla/service:transfer_manager", "//xla/stream_executor", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", ], ) @@ -464,10 +469,11 @@ cc_library( ":c_api_decl", ":tpu_topology_external", "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -479,10 +485,10 @@ cc_library( deps = [ ":tpu_platform_interface", ":tpu_topology_external", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) @@ -525,7 +531,7 @@ cc_library( "//xla/service:executable", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "@com_google_absl//absl/cleanup", ], alwayslink = True, @@ -536,9 +542,9 @@ cc_library( hdrs = ["tpu_stream_interface.h"], visibility = ["//visibility:public"], deps = [ - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", - "@tsl//tsl/platform:status", + "@com_google_absl//absl/status", ], ) @@ -560,8 +566,9 @@ cc_library( "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -590,8 +597,8 @@ cc_library( "//xla/service:executable", "//xla/service:hlo_execution_profile", "//xla/service:shaped_buffer", + "//xla/stream_executor", "//xla/stream_executor:device_memory", - "//xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", diff --git a/xla/stream_executor/tpu/c_api_conversions.cc b/xla/stream_executor/tpu/c_api_conversions.cc index 3348871f7f946..24709bb78ef79 100644 --- a/xla/stream_executor/tpu/c_api_conversions.cc +++ b/xla/stream_executor/tpu/c_api_conversions.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -323,9 +323,31 @@ void Destroy(XLA_Shape* c_shape) { void ToC(const xla::Layout& layout, XLA_Layout* c_layout) { CreateVector(layout.minor_to_major(), &c_layout->minor_to_major); - CreateVector(layout.dim_level_types(), &c_layout->dim_level_types); - CreateVector(layout.dim_unique(), &c_layout->dim_unique); - CreateVector(layout.dim_ordered(), &c_layout->dim_ordered); + { + const int n = layout.dim_level_types_size(); + absl::InlinedVector dim_level_types( + n); + for (int i = 0; i < n; i++) { + dim_level_types[i] = layout.dim_level_type(i); + } + CreateVector(dim_level_types, &c_layout->dim_level_types); + } + { + const int n = layout.dim_unique_size(); + absl::InlinedVector dim_unique(n); + for (int i = 0; i < n; i++) { + dim_unique[i] = layout.dim_unique(i); + } + CreateVector(dim_unique, &c_layout->dim_unique); + } + { + const int n = layout.dim_ordered_size(); + absl::InlinedVector dim_ordered(n); + for (int i = 0; i < n; i++) { + dim_ordered[i] = layout.dim_ordered(i); + } + CreateVector(dim_ordered, &c_layout->dim_ordered); + } c_layout->index_primitive_type = layout.index_primitive_type(); c_layout->pointer_primitive_type = layout.pointer_primitive_type(); c_layout->element_size_in_bits = layout.element_size_in_bits(); @@ -333,6 +355,8 @@ void ToC(const xla::Layout& layout, XLA_Layout* c_layout) { c_layout->dynamic_shape_metadata_prefix_bytes = layout.dynamic_shape_metadata_prefix_bytes(); CreateVector(layout.tiles(), &c_layout->tiles); + c_layout->tail_padding_alignment_in_elements = + layout.tail_padding_alignment_in_elements(); } xla::Layout FromC(const XLA_Layout* c_layout) { @@ -360,9 +384,11 @@ xla::Layout FromC(const XLA_Layout* c_layout) { } return xla::Layout( minor_to_major, dim_level_types, dim_unique, dim_ordered, tiles, + c_layout->tail_padding_alignment_in_elements, static_cast(c_layout->index_primitive_type), static_cast(c_layout->pointer_primitive_type), c_layout->element_size_in_bits, c_layout->memory_space, + /*split_configs=*/{}, /*physical_shape=*/nullptr, c_layout->dynamic_shape_metadata_prefix_bytes); } @@ -405,8 +431,7 @@ XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) { } xla::ShapeIndex FromC(XLA_ShapeIndex* c_shape) { - return xla::ShapeIndex(&c_shape->indices[0], - &c_shape->indices[c_shape->count]); + return xla::ShapeIndex(c_shape->indices, c_shape->indices + c_shape->count); } void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal) { @@ -474,7 +499,7 @@ XLA_HloModule ToC(const xla::HloModule& module) { return c_module; } -xla::StatusOr> FromC( +absl::StatusOr> FromC( const XLA_HloModule& c_module) { xla::HloModuleProto module_proto = stream_executor::tpu::DeserializeProto( @@ -509,6 +534,8 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) { hlo_config.num_partitions = config.num_partitions(); hlo_config.use_spmd_partitioning = config.use_spmd_partitioning(); hlo_config.use_auto_spmd_partitioning = config.use_auto_spmd_partitioning(); + CreateVector(config.allow_spmd_sharding_propagation_to_parameters(), + &hlo_config.allow_spmd_sharding_propagation_to_parameters); CreateVector(config.allow_spmd_sharding_propagation_to_output(), &hlo_config.allow_spmd_sharding_propagation_to_output); CreateVector(config.auto_spmd_partitioning_mesh_shape(), @@ -557,17 +584,18 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) { config.set_num_partitions(c_config.num_partitions); config.set_use_spmd_partitioning(c_config.use_spmd_partitioning); config.set_use_auto_spmd_partitioning(c_config.use_auto_spmd_partitioning); + config.set_allow_spmd_sharding_propagation_to_parameters( + MakeSpan(c_config.allow_spmd_sharding_propagation_to_parameters)); config.set_allow_spmd_sharding_propagation_to_output( MakeSpan(c_config.allow_spmd_sharding_propagation_to_output)); absl::Span mesh_shape_span = MakeSpan(c_config.auto_spmd_partitioning_mesh_shape); - std::vector mesh_shape(mesh_shape_span.begin(), - mesh_shape_span.end()); - config.set_auto_spmd_partitioning_mesh_shape(mesh_shape); + config.set_auto_spmd_partitioning_mesh_shape( + std::vector(mesh_shape_span.begin(), mesh_shape_span.end())); absl::Span mesh_ids_span = MakeSpan(c_config.auto_spmd_partitioning_mesh_ids); - std::vector mesh_ids(mesh_ids_span.begin(), mesh_ids_span.end()); - config.set_auto_spmd_partitioning_mesh_ids(mesh_ids); + config.set_auto_spmd_partitioning_mesh_ids( + std::vector(mesh_ids_span.begin(), mesh_ids_span.end())); if (c_config.has_static_device_assignment) { auto device_assignment = xla::DeviceAssignment::Deserialize( stream_executor::tpu::DeserializeProto( diff --git a/xla/stream_executor/tpu/c_api_conversions.h b/xla/stream_executor/tpu/c_api_conversions.h index 47b9f99551dcd..dfe5f2c7e5946 100644 --- a/xla/stream_executor/tpu/c_api_conversions.h +++ b/xla/stream_executor/tpu/c_api_conversions.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -128,7 +128,7 @@ SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased); // HloModule XLA_HloModule ToC(const xla::HloModule& module); -xla::StatusOr> FromC( +absl::StatusOr> FromC( const XLA_HloModule& c_module); void Destroy(XLA_HloModule* c_module); diff --git a/xla/stream_executor/tpu/c_api_conversions_test.cc b/xla/stream_executor/tpu/c_api_conversions_test.cc index 4a2bae3f0a8ba..4092cedce69af 100644 --- a/xla/stream_executor/tpu/c_api_conversions_test.cc +++ b/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -38,7 +39,6 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/statusor.h" namespace ApiConverter { @@ -130,26 +130,23 @@ void XlaLayout_ToC(const xla::Layout& cpp_layout) { MakeSpan(c_layout.minor_to_major); EXPECT_EQ(cpp_minor_to_major, c_minor_to_major); - absl::Span cpp_dim_level_types = - cpp_layout.dim_level_types(); absl::Span c_dim_level_types = MakeSpan(c_layout.dim_level_types); - EXPECT_EQ(cpp_dim_level_types.size(), c_dim_level_types.size()); + EXPECT_EQ(cpp_layout.dim_level_types_size(), c_dim_level_types.size()); for (int i = 0; i < c_dim_level_types.size(); ++i) { - EXPECT_EQ(static_cast(cpp_dim_level_types[i]), c_dim_level_types[i]); + EXPECT_EQ(static_cast(cpp_layout.dim_level_type(i)), + c_dim_level_types[i]); } - absl::Span cpp_dim_unique = cpp_layout.dim_unique(); absl::Span c_dim_unique = MakeSpan(c_layout.dim_unique); - EXPECT_EQ(cpp_dim_unique.size(), c_dim_unique.size()); + EXPECT_EQ(cpp_layout.dim_unique_size(), c_dim_unique.size()); for (int i = 0; i < c_dim_unique.size(); ++i) { - EXPECT_EQ(cpp_dim_unique[i], static_cast(c_dim_unique[i])); + EXPECT_EQ(cpp_layout.dim_unique(i), static_cast(c_dim_unique[i])); } - absl::Span cpp_dim_ordered = cpp_layout.dim_ordered(); absl::Span c_dim_ordered = MakeSpan(c_layout.dim_ordered); - EXPECT_EQ(cpp_dim_ordered.size(), c_dim_ordered.size()); + EXPECT_EQ(cpp_layout.dim_ordered_size(), c_dim_ordered.size()); for (int i = 0; i < c_dim_ordered.size(); ++i) { - EXPECT_EQ(cpp_dim_ordered[i], static_cast(c_dim_ordered[i])); + EXPECT_EQ(cpp_layout.dim_ordered(i), static_cast(c_dim_ordered[i])); } absl::Span cpp_tiles = cpp_layout.tiles(); @@ -298,7 +295,7 @@ TEST(XlaShape, FromCNested) { // TODO(b/290654348): xla::ShapeIndex, xla::Literal, xla::ShapedBuffer TEST(XlaHloModuleConfig, ToAndFromC) { - xla::StatusOr> hlo_module = + absl::StatusOr> hlo_module = xla::ParseAndReturnUnverifiedModule(kHloString); ASSERT_TRUE(hlo_module.ok()); xla::HloModule& cpp_module = *hlo_module.value(); @@ -321,13 +318,13 @@ TEST(XlaHloModuleConfig, ToAndFromC) { } TEST(XlaHloModule, ToAndFromC) { - xla::StatusOr> hlo_module = + absl::StatusOr> hlo_module = xla::ParseAndReturnUnverifiedModule(kHloString); ASSERT_TRUE(hlo_module.ok()); xla::HloModule& in_module = *hlo_module.value(); XLA_HloModule c_module = ToC(in_module); - xla::StatusOr> out_module_ptr = + absl::StatusOr> out_module_ptr = FromC(c_module); ASSERT_TRUE(out_module_ptr.ok()); xla::HloModule& out_module = *out_module_ptr.value(); diff --git a/xla/stream_executor/tpu/c_api_decl.h b/xla/stream_executor/tpu/c_api_decl.h index bbeb12472ea8b..c6aec20911ac0 100644 --- a/xla/stream_executor/tpu/c_api_decl.h +++ b/xla/stream_executor/tpu/c_api_decl.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,6 @@ typedef struct SE_PlatformId { void* id; // aka stream_executor::Platform::Id } SE_PlatformId; typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig; -typedef struct SE_DeviceOptions SE_DeviceOptions; typedef TF_Status* (*SE_StatusCallback)(void*); typedef struct SE_DeviceMemoryBase { @@ -246,6 +245,7 @@ typedef struct XLA_Layout { int64_t element_size_in_bits; int64_t memory_space; int64_t dynamic_shape_metadata_prefix_bytes; + int64_t tail_padding_alignment_in_elements; } XLA_Layout; // Represents an XLA shape tree. @@ -321,6 +321,7 @@ typedef struct XLA_HloModuleConfig { TpuSerializedProto static_device_assignment; bool has_entry_computation_layout; XLA_ComputationLayout entry_computation_layout; + BoolList allow_spmd_sharding_propagation_to_parameters; BoolList allow_spmd_sharding_propagation_to_output; } XLA_HloModuleConfig; diff --git a/xla/stream_executor/tpu/c_api_defn.h b/xla/stream_executor/tpu/c_api_defn.h index 8e4707fcbbda0..407d6c326707a 100644 --- a/xla/stream_executor/tpu/c_api_defn.h +++ b/xla/stream_executor/tpu/c_api_defn.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_TPU_C_API_DEFN_H_ #define XLA_STREAM_EXECUTOR_TPU_C_API_DEFN_H_ -#include "xla/stream_executor/device_options.h" +#include + #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -39,9 +40,8 @@ struct SE_StreamExecutor { }; struct SE_Stream { - explicit SE_Stream(stream_executor::StreamExecutor* parent) - : stream(parent) {} - stream_executor::Stream stream; + explicit SE_Stream(stream_executor::StreamExecutor* parent) {} + std::unique_ptr stream; }; struct SE_Event { @@ -53,10 +53,6 @@ struct SE_StreamExecutorConfig { stream_executor::StreamExecutorConfig config; }; -struct SE_DeviceOptions { - stream_executor::DeviceOptions options; -}; - // Ignored -- these are just used to enforce the interface types struct XLA_TransferManager {}; struct XLA_ComputationPlacer {}; diff --git a/xla/stream_executor/tpu/libtftpu.h b/xla/stream_executor/tpu/libtftpu.h index c5e3fb59a4a5d..3f20d6d3417af 100644 --- a/xla/stream_executor/tpu/libtftpu.h +++ b/xla/stream_executor/tpu/libtftpu.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/noncopyable_buffer.h b/xla/stream_executor/tpu/noncopyable_buffer.h index 997018149331f..8e0abf45c88af 100644 --- a/xla/stream_executor/tpu/noncopyable_buffer.h +++ b/xla/stream_executor/tpu/noncopyable_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/proto_helper.cc b/xla/stream_executor/tpu/proto_helper.cc index f5f66e1d67cb1..0e852880f542a 100644 --- a/xla/stream_executor/tpu/proto_helper.cc +++ b/xla/stream_executor/tpu/proto_helper.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/proto_helper.h b/xla/stream_executor/tpu/proto_helper.h index b4f505ed135d8..cda8b48910266 100644 --- a/xla/stream_executor/tpu/proto_helper.h +++ b/xla/stream_executor/tpu/proto_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/status_helper.h b/xla/stream_executor/tpu/status_helper.h index b49d801a64f53..64a87d9afa44c 100644 --- a/xla/stream_executor/tpu/status_helper.h +++ b/xla/stream_executor/tpu/status_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_executor_api.h" #include "xla/stream_executor/tpu/tpu_executor_c_api.h" -#include "tsl/platform/status.h" class StatusHelper { public: @@ -31,12 +30,12 @@ class StatusHelper { stream_executor::tpu::ExecutorApiFn()->TpuStatus_FreeFn(c_status); } - static tsl::Status FromC( // TENSORFLOW_STATUS_OK + static absl::Status FromC( // TENSORFLOW_STATUS_OK TF_Status* const c_status) { if (stream_executor::tpu::ExecutorApiFn()->TpuStatus_OkFn(c_status)) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } else { - return tsl::Status( // TENSORFLOW_STATUS_OK + return absl::Status( // TENSORFLOW_STATUS_OK absl::StatusCode( stream_executor::tpu::ExecutorApiFn()->TpuStatus_CodeFn( c_status)), @@ -48,7 +47,7 @@ class StatusHelper { return stream_executor::tpu::ExecutorApiFn()->TpuStatus_OkFn(c_status); } - tsl::Status status() const { // TENSORFLOW_STATUS_OK + absl::Status status() const { // TENSORFLOW_STATUS_OK return FromC(c_status); } diff --git a/xla/stream_executor/tpu/tpu_api.cc b/xla/stream_executor/tpu/tpu_api.cc index f69ab2ed1d327..b569d91f0a627 100644 --- a/xla/stream_executor/tpu/tpu_api.cc +++ b/xla/stream_executor/tpu/tpu_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_api.h b/xla/stream_executor/tpu/tpu_api.h index 9a3cf2b330abc..8338f0f6c4cb9 100644 --- a/xla/stream_executor/tpu/tpu_api.h +++ b/xla/stream_executor/tpu/tpu_api.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_api_dlsym_set_fn.h b/xla/stream_executor/tpu/tpu_api_dlsym_set_fn.h index 3298ce9b10455..5aaa164b65616 100644 --- a/xla/stream_executor/tpu/tpu_api_dlsym_set_fn.h +++ b/xla/stream_executor/tpu/tpu_api_dlsym_set_fn.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ limitations under the License. reinterpret_cast(dlsym(library_handle, #FnName)); \ if (!(Struct->FnName##Fn)) { \ LOG(FATAL) << #FnName " not available in this library."; \ - return tsl::errors::Unimplemented(#FnName \ - " not available in this library."); \ + return absl::UnimplementedError(#FnName \ + " not available in this library."); \ } #endif // XLA_STREAM_EXECUTOR_TPU_TPU_API_DLSYM_SET_FN_H_ diff --git a/xla/stream_executor/tpu/tpu_event.h b/xla/stream_executor/tpu/tpu_event.h index 004ac5fc5c387..a9daeea816077 100644 --- a/xla/stream_executor/tpu/tpu_event.h +++ b/xla/stream_executor/tpu/tpu_event.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_executable.cc b/xla/stream_executor/tpu/tpu_executable.cc index 2afe5798b8f79..abdd5f20c9172 100644 --- a/xla/stream_executor/tpu/tpu_executable.cc +++ b/xla/stream_executor/tpu/tpu_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -93,7 +93,7 @@ TpuExecutable::~TpuExecutable() { ExecutorApiFn()->TpuExecutable_FreeFn(se_executable_); } -StatusOr TpuExecutable::ExecuteAsyncOnStream( +absl::StatusOr TpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) { @@ -182,7 +182,7 @@ absl::string_view TpuExecutable::fingerprint() const { return absl::string_view(data, size); } -StatusOr TpuExecutable::Serialize() const { +absl::StatusOr TpuExecutable::Serialize() const { SE_ExecutableSerializationHandle* handle = nullptr; absl::Cleanup cleanup = [&handle]() { ExecutorApiFn()->TpuExecutableSerialize_FreeHandleFn(handle); @@ -210,7 +210,7 @@ StatusOr TpuExecutable::Serialize() const { return serialized; } -StatusOr> TpuExecutable::Deserialize( +absl::StatusOr> TpuExecutable::Deserialize( absl::string_view serialized) { SE_Executable* se_executable; StatusHelper status; diff --git a/xla/stream_executor/tpu/tpu_executable.h b/xla/stream_executor/tpu/tpu_executable.h index c391e26ab20d5..cd3d8132c82a2 100644 --- a/xla/stream_executor/tpu/tpu_executable.h +++ b/xla/stream_executor/tpu/tpu_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ class TpuExecutable : public xla::TpuExecutableInterface { ~TpuExecutable() override; - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; @@ -56,8 +56,8 @@ class TpuExecutable : public xla::TpuExecutableInterface { // The serialization is not guaranteed to be stable over time and has no // compatibility guarantees (i.e. this is not a suitable long-term storage // format). - StatusOr Serialize() const; - static StatusOr> Deserialize( + absl::StatusOr Serialize() const; + static absl::StatusOr> Deserialize( absl::string_view serialized); private: diff --git a/xla/stream_executor/tpu/tpu_executable_interface.cc b/xla/stream_executor/tpu/tpu_executable_interface.cc index 78de11aa795d3..69b9c4a0fc806 100644 --- a/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/layout_util.h" @@ -59,9 +60,9 @@ static Status PopulateResultTupleBuffers(const ShapedBuffer& result, TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( transfer_stream ? transfer_stream : stream, result)); if (transfer_stream && transfer_stream != stream) { - stream->ThenWaitFor(transfer_stream); + TF_RETURN_IF_ERROR(stream->WaitFor(transfer_stream)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } else { return transfer_manager->WriteTupleIndexTablesAsync(stream, result); } @@ -69,7 +70,7 @@ static Status PopulateResultTupleBuffers(const ShapedBuffer& result, } // namespace -StatusOr +absl::StatusOr TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, @@ -100,7 +101,8 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( [&](const ShapeIndex& output_index, - std::optional alias) { + std::optional alias) + -> absl::Status { if (alias && alias->must_alias()) { VLOG(1) << alias->ToString(); const MaybeOwningDeviceMemory& original_input = @@ -113,7 +115,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( alias->ToString()); } } - return ::tsl::OkStatus(); + return absl::OkStatus(); })); if (VLOG_IS_ON(3)) { @@ -141,8 +143,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( // Return an InternalError if result_index is invalid. This avoids failing // the CHECK when calling GetAliasedParameter if (!ShapeUtil::IndexIsValid(alias_config.shape(), result_index)) { - return InternalError("result_index is invalid: %s", - result_index.ToString()); + return Internal("result_index is invalid: %s", result_index.ToString()); } std::optional alias = @@ -204,7 +205,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( return std::move(result); } -StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( +absl::StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* /*hlo_execution_profile*/) { diff --git a/xla/stream_executor/tpu/tpu_executable_interface.h b/xla/stream_executor/tpu/tpu_executable_interface.h index e6a8052a7e383..ff037c07a532f 100644 --- a/xla/stream_executor/tpu/tpu_executable_interface.h +++ b/xla/stream_executor/tpu/tpu_executable_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ class TpuExecutableInterface : public Executable { : Executable(std::move(hlo_module)) {} ~TpuExecutableInterface() override = default; - StatusOr ExecuteAsyncOnStream( + absl::StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, std::vector arguments, HloExecutionProfile* hlo_execution_profile) override; @@ -63,7 +63,7 @@ class TpuExecutableInterface : public Executable { // // The optional 'transfer_stream' parameter enables transfers (for tuple // tables) to be performed on a separate stream to 'stream'. - StatusOr AllocateOutputMemoryWithInputReuse( + absl::StatusOr AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceMemoryAllocator* allocator, std::vector* arguments, se::Stream* stream, diff --git a/xla/stream_executor/tpu/tpu_executor.cc b/xla/stream_executor/tpu/tpu_executor.cc index 20f40f6c7bd1d..ea110a514e32b 100644 --- a/xla/stream_executor/tpu/tpu_executor.cc +++ b/xla/stream_executor/tpu/tpu_executor.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,12 +22,12 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/status.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream_executor_internal.h" #include "xla/stream_executor/tpu/c_api_conversions.h" @@ -37,7 +37,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_executor_api.h" #include "xla/stream_executor/tpu/tpu_stream.h" #include "xla/stream_executor/tpu/tpu_topology.h" -#include "tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -45,19 +45,15 @@ namespace stream_executor { namespace tpu { namespace { -using xla::Status; +using absl::Status; } // namespace TpuExecutor::~TpuExecutor() { ExecutorApiFn()->TpuExecutor_FreeFn(executor_); } -Status TpuExecutor::Init(int device_ordinal, - ::stream_executor::DeviceOptions device_options) { +Status TpuExecutor::Init(int device_ordinal) { StatusHelper status; - SE_DeviceOptions* options = - ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(device_options.flags()); - ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options, + ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, status.c_status); - ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options); return status.status(); } @@ -104,11 +100,11 @@ bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { get_stream(other->implementation())); } -Status TpuExecutor::AllocateEvent(Event* event) { return tsl::OkStatus(); } +Status TpuExecutor::AllocateEvent(Event* event) { return absl::OkStatus(); } Status TpuExecutor::DeallocateEvent(Event* event) { tpu_platform().EraseEvent(event->implementation()); - return tsl::OkStatus(); + return absl::OkStatus(); } stream_executor::Event::Status TpuExecutor::PollForEventStatus( @@ -234,22 +230,26 @@ Status TpuExecutor::EnqueueInfeed(int32_t infeed_queue_index, return status.status(); } -bool TpuExecutor::Memcpy(Stream* stream, void* host_dst, - const ::stream_executor::DeviceMemoryBase& device_src, - uint64_t size) { +absl::Status TpuExecutor::Memcpy( + Stream* stream, void* host_dst, + const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) { + StatusHelper status; SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src); - return ExecutorApiFn()->TpuExecutor_MemcpyToHostFn( - executor_, get_stream(stream->implementation()), host_dst, &se_base, - size); + ExecutorApiFn()->TpuExecutor_MemcpyToHostFn( + executor_, get_stream(stream->implementation()), host_dst, &se_base, size, + status.c_status); + return status.status(); } -bool TpuExecutor::Memcpy(Stream* stream, - ::stream_executor::DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) { +absl::Status TpuExecutor::Memcpy( + Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size) { + StatusHelper status; SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst); - return ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn( - executor_, get_stream(stream->implementation()), &se_base, host_src, - size); + ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn( + executor_, get_stream(stream->implementation()), &se_base, host_src, size, + status.c_status); + return status.status(); } Status TpuExecutor::SynchronousMemcpy( @@ -275,7 +275,7 @@ Status TpuExecutor::SynchronousMemcpy( Status TpuExecutor::SynchronousMemcpyDeviceToDevice( ::stream_executor::DeviceMemoryBase* device_dst, const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) { - return tsl::errors::Unimplemented("This operation not supported on TPU"); + return absl::UnimplementedError("This operation not supported on TPU"); } bool TpuExecutor::MemcpyDeviceToDevice( diff --git a/xla/stream_executor/tpu/tpu_executor.h b/xla/stream_executor/tpu/tpu_executor.h index 46ebe431a49fa..53226961c290f 100644 --- a/xla/stream_executor/tpu/tpu_executor.h +++ b/xla/stream_executor/tpu/tpu_executor.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,11 +23,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_options.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -40,8 +41,6 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_topology.h" #include "tsl/platform/casts.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/types.h" namespace stream_executor { @@ -50,8 +49,8 @@ namespace tpu { class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { public: template - using StatusOr = ::tsl::StatusOr; - using StatusCallback = std::function; + using StatusOr = ::absl::StatusOr; + using StatusCallback = std::function; using Stream = ::stream_executor::Stream; using Event = ::stream_executor::Event; using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase; @@ -65,16 +64,15 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { ~TpuExecutor() override; - tsl::Status Init(int device_ordinal, - ::stream_executor::DeviceOptions device_options) override; + absl::Status Init(int device_ordinal) override; DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - tsl::Status AllocateEvent(Event* event) override; + absl::Status AllocateEvent(Event* event) override; bool AllocateStream(Stream* stream) override; - tsl::Status BlockHostUntilDone(::stream_executor::Stream* stream) override; + absl::Status BlockHostUntilDone(::stream_executor::Stream* stream) override; StatusOr> CreateDeviceDescription() const override; @@ -87,22 +85,22 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { void Deallocate(DeviceMemoryBase* memory) override; - tsl::Status DeallocateEvent(Event* event) override; + absl::Status DeallocateEvent(Event* event) override; bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; void DequeueOutfeed(int32_t outfeed_queue_index, absl::Span bytes, StatusCallback done); - tsl::Status EnqueueInfeed(int32_t infeed_queue_index, - absl::Span bytes); + absl::Status EnqueueInfeed(int32_t infeed_queue_index, + absl::Span bytes); std::optional GetAllocatorStats() override; tensorflow::tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override; - tsl::Status GetStatus(Stream* stream) override; + absl::Status GetStatus(Stream* stream) override; std::unique_ptr<::stream_executor::internal::StreamInterface> GetStreamImplementation() override; @@ -111,14 +109,15 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { CreateEventImplementation() override; bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; + absl::AnyInvocable callback) override; - bool Memcpy(Stream* stream, void* host_dst, - const ::stream_executor::DeviceMemoryBase& device_src, - uint64_t size) override; + absl::Status Memcpy(Stream* stream, void* host_dst, + const ::stream_executor::DeviceMemoryBase& device_src, + uint64_t size) override; - bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) override; + absl::Status Memcpy(Stream* stream, + ::stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size) override; bool MemcpyDeviceToDevice(Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst, @@ -127,25 +126,26 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { bool SynchronizeAllActivity() override; - tsl::Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size) override; - tsl::Status SynchronousMemcpy( + absl::Status SynchronousMemcpy( + ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src, + uint64_t size) override; + absl::Status SynchronousMemcpy( void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) override; - tsl::Status SynchronousMemcpyDeviceToDevice( + absl::Status SynchronousMemcpyDeviceToDevice( ::stream_executor::DeviceMemoryBase* device_dst, const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) override; Event::Status PollForEventStatus(Event* event) override; - tsl::Status RecordEvent(Stream* stream, - ::stream_executor::Event* event) override; - tsl::Status WaitForEvent(Stream* stream, + absl::Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override; + absl::Status WaitForEvent(Stream* stream, + ::stream_executor::Event* event) override; - tsl::Status UnloadAllPrograms() override; + absl::Status UnloadAllPrograms() override; - tsl::Status EnqueueCompactionOnStreamForHbm( + absl::Status EnqueueCompactionOnStreamForHbm( Stream* compaction_stream) override; const ::tensorflow::tpu::TpuPlatformInterface& platform() const override { @@ -166,24 +166,16 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { } // -- Unimplemented (stubbed out) methods. - std::unique_ptr - CreateKernelImplementation() override { - LOG(FATAL) << "Not yet implemented"; - } - void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset, - uint64_t size) override { - LOG(FATAL) << "not yet implemented"; - } - tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64_t size) override { + absl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64_t size) override { LOG(FATAL) << "not yet implemented"; } - tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32_t pattern, uint64_t size) override { + absl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32_t pattern, uint64_t size) override { LOG(FATAL) << "not yet implemented"; } - tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { + absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { LOG(FATAL) << "not yet implemented"; } bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { @@ -202,12 +194,12 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { bool HostMemoryUnregister(void* mem) override { LOG(FATAL) << "not yet implemented"; } - tsl::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64_t size) override { + absl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64_t size) override { LOG(FATAL) << "not yet implemented"; } - tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64_t size) override { + absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64_t size) override { LOG(FATAL) << "not yet implemented"; } diff --git a/xla/stream_executor/tpu/tpu_executor_api.cc b/xla/stream_executor/tpu/tpu_executor_api.cc index a18ad3952d622..014c5cbd63d29 100644 --- a/xla/stream_executor/tpu/tpu_executor_api.cc +++ b/xla/stream_executor/tpu/tpu_executor_api.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_executor_api.h b/xla/stream_executor/tpu/tpu_executor_api.h index 34989709c31a6..c28d4d5cb45e9 100644 --- a/xla/stream_executor/tpu/tpu_executor_api.h +++ b/xla/stream_executor/tpu/tpu_executor_api.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_executor_c_api.h b/xla/stream_executor/tpu/tpu_executor_c_api.h index f2dc3e3918948..34f869e718849 100644 --- a/xla/stream_executor/tpu/tpu_executor_c_api.h +++ b/xla/stream_executor/tpu/tpu_executor_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform); TpuRuntimeVersion TpuPlatform_GetRuntimeVersion(SE_Platform* platform); void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal, - SE_DeviceOptions* device_options, TF_Status* status); + TF_Status* status); void TpuExecutor_Free(SE_StreamExecutor* executor); SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor, @@ -82,14 +82,15 @@ void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor, SE_DeviceMemoryBase* device_dst, const void* host_src, uint64_t size, TF_Status* status); -bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream, +void TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream, void* host_dst, const SE_DeviceMemoryBase* device_src, - uint64_t size); + uint64_t size, TF_Status* status); -bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream, +void TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream, SE_DeviceMemoryBase* device_dst, - const void* host_src, uint64_t size); + const void* host_src, uint64_t size, + TF_Status* status); void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor, int32_t infeed_queue_index, const uint8_t* data, @@ -157,9 +158,6 @@ void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor, SE_DeviceDescription* description, TF_Status* status); -SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags); -void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options); - bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream, SE_StatusCallback callback_fn, void* ctx); @@ -445,8 +443,6 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free); TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateDeviceDescription); - TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_NewDeviceOptions); - TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_FreeDeviceOptions); TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_HostCallback); TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_New); diff --git a/xla/stream_executor/tpu/tpu_executor_init_fns.inc b/xla/stream_executor/tpu/tpu_executor_init_fns.inc index 9f51f04b050a6..3221fe01c0076 100644 --- a/xla/stream_executor/tpu/tpu_executor_init_fns.inc +++ b/xla/stream_executor/tpu/tpu_executor_init_fns.inc @@ -1,6 +1,7 @@ namespace { -tsl::Status SetExecutorStructFn(void* library_handle) { // TENSORFLOW_STATUS_OK +absl::Status SetExecutorStructFn( + void* library_handle) { // TENSORFLOW_STATUS_OK auto* executor_fn = stream_executor::tpu::ExecutorApiFn(); TFTPU_SET_FN(executor_fn, TpuPlatform_New); @@ -70,8 +71,6 @@ tsl::Status SetExecutorStructFn(void* library_handle) { // TENSORFLOW_STATUS_OK TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free); TFTPU_SET_FN(executor_fn, TpuExecutor_CreateDeviceDescription); - TFTPU_SET_FN(executor_fn, TpuExecutor_NewDeviceOptions); - TFTPU_SET_FN(executor_fn, TpuExecutor_FreeDeviceOptions); TFTPU_SET_FN(executor_fn, TpuExecutor_HostCallback); TFTPU_SET_FN(executor_fn, TpuTransferManager_New); @@ -149,7 +148,7 @@ tsl::Status SetExecutorStructFn(void* library_handle) { // TENSORFLOW_STATUS_OK TFTPU_SET_FN(executor_fn, TpuAsyncCollectiveOffloadHelper_Init); TFTPU_SET_FN(executor_fn, TpuAsyncCollectiveOffloadHelper_Shutdown); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/xla/stream_executor/tpu/tpu_executor_interface.h b/xla/stream_executor/tpu/tpu_executor_interface.h index 9d07bb37e1d96..f716f1d53d063 100644 --- a/xla/stream_executor/tpu/tpu_executor_interface.h +++ b/xla/stream_executor/tpu/tpu_executor_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor_internal.h" #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "xla/stream_executor/tpu/tpu_topology.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tpu { class TpuCore; @@ -37,7 +37,7 @@ class TpuExecutorInterface : public stream_executor::internal::StreamExecutorInterface { public: template - using StatusOr = tsl::StatusOr; + using StatusOr = absl::StatusOr; class TemporaryDeviceMemory { public: @@ -61,9 +61,9 @@ class TpuExecutorInterface LOG(FATAL) << "Unimplemented."; } - virtual tsl::Status UnloadAllPrograms() { LOG(FATAL) << "Unimplemented."; } + virtual absl::Status UnloadAllPrograms() { LOG(FATAL) << "Unimplemented."; } - virtual tsl::Status EnqueueCompactionOnStreamForHbm( + virtual absl::Status EnqueueCompactionOnStreamForHbm( stream_executor::Stream* compaction_stream) { LOG(FATAL) << "Unimplemented."; } diff --git a/xla/stream_executor/tpu/tpu_initialize_util.cc b/xla/stream_executor/tpu/tpu_initialize_util.cc index ce5b4b342eb2d..0868ba5ead48a 100644 --- a/xla/stream_executor/tpu/tpu_initialize_util.cc +++ b/xla/stream_executor/tpu/tpu_initialize_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -98,7 +98,7 @@ bool IsTpuUsed(int64_t pid) { int64_t fd; if (!absl::SimpleAtoi(ent->d_name, &fd)) continue; path = absl::StrCat("/proc/", pid, "/fd/", fd); - if (!readlink(path.c_str(), &line[0], line.size())) continue; + if (!readlink(path.c_str(), line.data(), line.size())) continue; if (line != tpu_dev_path) continue; return true; } diff --git a/xla/stream_executor/tpu/tpu_initialize_util.h b/xla/stream_executor/tpu/tpu_initialize_util.h index 8656d0ec1fc75..9567f581e2e2d 100644 --- a/xla/stream_executor/tpu/tpu_initialize_util.h +++ b/xla/stream_executor/tpu/tpu_initialize_util.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_library_init_fns.inc b/xla/stream_executor/tpu/tpu_library_init_fns.inc index c037327281f21..baf503c6cf262 100644 --- a/xla/stream_executor/tpu/tpu_library_init_fns.inc +++ b/xla/stream_executor/tpu/tpu_library_init_fns.inc @@ -3,7 +3,7 @@ namespace { -tsl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK +absl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK // Constant cast so that we can initialize the functions. The functions are // mutable here because this is the only place where they are initialized. auto* ops_api_fn = @@ -112,15 +112,15 @@ tsl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK TFTPU_SET_FN(ops_api_fn, TpuEmbeddingEngine_DedupDataTupleMaskComputation); TFTPU_SET_FN(ops_api_fn, SparseCore_GetMaxIdsAndUniques); - return tsl::OkStatus(); + return absl::OkStatus(); } -tsl::Status InitializeTpuStructFns( // TENSORFLOW_STATUS_OK +absl::Status InitializeTpuStructFns( // TENSORFLOW_STATUS_OK void* library_handle) { TF_RETURN_IF_ERROR(SetTpuOpsStructFns(library_handle)); TF_RETURN_IF_ERROR(SetExecutorStructFn(library_handle)); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/xla/stream_executor/tpu/tpu_node_context.cc b/xla/stream_executor/tpu/tpu_node_context.cc index a7261a99c9b92..0299169fdc642 100644 --- a/xla/stream_executor/tpu/tpu_node_context.cc +++ b/xla/stream_executor/tpu/tpu_node_context.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace tpu { -using tsl::StatusOr; +using absl::StatusOr; /*static*/ StatusOr> TpuNodeContext::Create( @@ -46,7 +46,7 @@ TpuNodeContext::~TpuNodeContext() { } /* static */ -tsl::Status TpuNodeContext::CloseTpuHost() { +absl::Status TpuNodeContext::CloseTpuHost() { StatusHelper status; stream_executor::tpu::OpsApiFn()->TpuNodeContext_CloseTpuHostFn( status.c_status); @@ -54,7 +54,7 @@ tsl::Status TpuNodeContext::CloseTpuHost() { } /* static */ -tsl::Status TpuNodeContext::Initialize(int device_ordinal) { +absl::Status TpuNodeContext::Initialize(int device_ordinal) { StatusHelper status; stream_executor::tpu::OpsApiFn()->TpuNodeContext_InitializeFn( device_ordinal, status.c_status); diff --git a/xla/stream_executor/tpu/tpu_node_context.h b/xla/stream_executor/tpu/tpu_node_context.h index 1af41c5b9cc20..036ece761323d 100644 --- a/xla/stream_executor/tpu/tpu_node_context.h +++ b/xla/stream_executor/tpu/tpu_node_context.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,14 +18,14 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/service/backend.h" #include "xla/service/stream_pool.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { @@ -38,7 +38,7 @@ namespace tpu { class TpuNodeContext final { public: template - using StatusOr = tsl::StatusOr; + using StatusOr = absl::StatusOr; static StatusOr> Create(int device_ordinal); @@ -48,9 +48,9 @@ class TpuNodeContext final { } ~TpuNodeContext(); - static tsl::Status CloseTpuHost(); + static absl::Status CloseTpuHost(); - static tsl::Status Initialize(int device_ordinal); + static absl::Status Initialize(int device_ordinal); static TpuPlatformInterface* platform(); diff --git a/xla/stream_executor/tpu/tpu_on_demand_compiler.cc b/xla/stream_executor/tpu/tpu_on_demand_compiler.cc index 2dafd0ec3cbbd..ed9f7a7f413cb 100644 --- a/xla/stream_executor/tpu/tpu_on_demand_compiler.cc +++ b/xla/stream_executor/tpu/tpu_on_demand_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ class TpuCompiler : public Compiler { return tensorflow::tpu::GetTpuPlatformId(); } - StatusOr> RunHloPasses( + absl::StatusOr> RunHloPasses( std::unique_ptr module, stream_executor::StreamExecutor* executor, const CompileOptions& options) override { @@ -83,7 +83,7 @@ class TpuCompiler : public Compiler { return HloModule::CreateFromProto(result_proto, module->config()); } - StatusOr> RunBackend( + absl::StatusOr> RunBackend( std::unique_ptr module, stream_executor::StreamExecutor* executor, const CompileOptions& options) override { @@ -113,7 +113,7 @@ class TpuCompiler : public Compiler { return exec; } - StatusOr>> Compile( + absl::StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, const CompileOptions& options) override { @@ -188,7 +188,7 @@ class TpuCompiler : public Compiler { // Compiles the HLO module group for ahead-of-time execution. This is // intended for use in static compilation. - StatusOr>> + absl::StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) override { return Unimplemented("This compiler does not support CompileAheadOfTime."); diff --git a/xla/stream_executor/tpu/tpu_op_executable.cc b/xla/stream_executor/tpu/tpu_op_executable.cc index 5295ad2eea506..c4256230ac3bd 100644 --- a/xla/stream_executor/tpu/tpu_op_executable.cc +++ b/xla/stream_executor/tpu/tpu_op_executable.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ TpuOpExecutable::TpuOpExecutable( core_program_(core_program), outside_compilation_params_(outside_compilation_params) {} -xla::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( +absl::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( const xla::ServiceExecutableRunOptions& run_options, absl::Span arguments, se::DeviceMemoryBase result, diff --git a/xla/stream_executor/tpu/tpu_op_executable.h b/xla/stream_executor/tpu/tpu_op_executable.h index 346e15453f089..6ff5cced6c78e 100644 --- a/xla/stream_executor/tpu/tpu_op_executable.h +++ b/xla/stream_executor/tpu/tpu_op_executable.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ class TpuOpExecutable : public xla::TpuExecutableInterface { absl::string_view fingerprint() const override; private: - xla::Status LoadProgramAndEnqueueToStream( + absl::Status LoadProgramAndEnqueueToStream( const xla::ServiceExecutableRunOptions& run_options, absl::Span arguments, stream_executor::DeviceMemoryBase result, diff --git a/xla/stream_executor/tpu/tpu_ops_c_api.h b/xla/stream_executor/tpu/tpu_ops_c_api.h index 14494201e82d1..80365ebb046a2 100644 --- a/xla/stream_executor/tpu/tpu_ops_c_api.h +++ b/xla/stream_executor/tpu/tpu_ops_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -635,6 +635,9 @@ typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; XLA_Shape* deduplication_data_shape; TpuSerializedProto* op_sharding; @@ -652,6 +655,9 @@ typedef struct void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; TpuSerializedProto* op_sharding; // out TpuSerializedProto* xla_computation; @@ -669,6 +675,9 @@ typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params { int32_t num_inputs; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; XLA_Shape* learning_rate_tuple_shape; XLA_Shape* deduplication_data_shape; XLA_Shape* gradient_tuple_shape; @@ -686,6 +695,9 @@ typedef struct TpuEmbeddingEngine_DedupDataSizeComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; // out int32_t* num_elements; TF_Status* status; @@ -699,6 +711,9 @@ typedef struct TpuEmbeddingEngine_DedupDataTupleMaskComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; // out TpuSerializedProto* xla_computation; TF_Status* status; diff --git a/xla/stream_executor/tpu/tpu_platform.cc b/xla/stream_executor/tpu/tpu_platform.cc index 91274499259ec..d346c510d9006 100644 --- a/xla/stream_executor/tpu/tpu_platform.cc +++ b/xla/stream_executor/tpu/tpu_platform.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,9 +23,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -37,8 +39,6 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "xla/stream_executor/tpu/tpu_topology.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { @@ -47,7 +47,7 @@ const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId(); TpuPlatform* tpu_registered_platform = nullptr; template -using StatusOr = ::tsl::StatusOr; +using StatusOr = ::absl::StatusOr; TpuPlatform::TpuPlatform() : name_("TPU") { platform_ = stream_executor::tpu::ExecutorApiFn()->TpuPlatform_NewFn(); @@ -58,7 +58,7 @@ TpuPlatform* TpuPlatform::GetRegisteredPlatform() { return tpu_registered_platform; } -tsl::Status TpuPlatform::Initialize( +absl::Status TpuPlatform::Initialize( const std::map& platform_options) { StatusHelper status; @@ -172,11 +172,11 @@ void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) { event_map_.erase(key); } -tsl::Status TpuPlatform::TpusPerHost(int* tpus) { +absl::Status TpuPlatform::TpusPerHost(int* tpus) { if (stream_executor::tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn == nullptr) { *tpus = 0; - return tsl::OkStatus(); + return absl::OkStatus(); } StatusHelper status; @@ -185,11 +185,11 @@ tsl::Status TpuPlatform::TpusPerHost(int* tpus) { return status.status(); } -tsl::Status TpuPlatform::TpuMemoryLimit(int64_t* memory_limit) { +absl::Status TpuPlatform::TpuMemoryLimit(int64_t* memory_limit) { if (stream_executor::tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn == nullptr) { *memory_limit = 0; - return tsl::OkStatus(); + return absl::OkStatus(); } StatusHelper status; @@ -211,7 +211,7 @@ bool RegisterTpuPlatform() { tpu_registered_platform = new TpuPlatform(); std::unique_ptr platform( tpu_registered_platform); - TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + TF_CHECK_OK(stream_executor::PlatformManager::RegisterPlatform( std::move(platform))); tpu_platform_registered = true; } diff --git a/xla/stream_executor/tpu/tpu_platform.h b/xla/stream_executor/tpu/tpu_platform.h index abfe0b6fb4538..348acf135c5cf 100644 --- a/xla/stream_executor/tpu/tpu_platform.h +++ b/xla/stream_executor/tpu/tpu_platform.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,6 +22,9 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" @@ -30,10 +33,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_executor_c_api.h" // IWYU pragma: keep #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "xla/stream_executor/tpu/tpu_topology.h" -#include "xla/stream_executor/trace_listener.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { @@ -50,7 +50,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { static const ::stream_executor::Platform::Id kId; template - using StatusOr = ::tsl::StatusOr; + using StatusOr = ::absl::StatusOr; TpuPlatform(); @@ -75,10 +75,10 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { bool Initialized() const override; - tsl::Status Initialize( + absl::Status Initialize( const std::map& platform_options) override; - tsl::Status Reset(bool only_tear_down, absl::string_view reason) override { + absl::Status Reset(bool only_tear_down, absl::string_view reason) override { LOG(FATAL) << "Not yet implemented"; } @@ -117,10 +117,10 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { SE_Platform* se_platform() const { return platform_; } // Returns the number of TPUs per host. - static tsl::Status TpusPerHost(int* tpus); + static absl::Status TpusPerHost(int* tpus); // Returns the memory capacity of the TPUs on this host. - static tsl::Status TpuMemoryLimit(int64_t* memory_limit); + static absl::Status TpuMemoryLimit(int64_t* memory_limit); absl::Mutex& mutex() { return event_map_mu_; } diff --git a/xla/stream_executor/tpu/tpu_platform_id.cc b/xla/stream_executor/tpu/tpu_platform_id.cc index 0feb400ed25ab..3b3711f805dac 100644 --- a/xla/stream_executor/tpu/tpu_platform_id.cc +++ b/xla/stream_executor/tpu/tpu_platform_id.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_platform_id.h b/xla/stream_executor/tpu/tpu_platform_id.h index fa0b52342835a..53424463294eb 100644 --- a/xla/stream_executor/tpu/tpu_platform_id.h +++ b/xla/stream_executor/tpu/tpu_platform_id.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_platform_interface.cc b/xla/stream_executor/tpu/tpu_platform_interface.cc index ef9f5d6785259..c7df661934283 100644 --- a/xla/stream_executor/tpu/tpu_platform_interface.cc +++ b/xla/stream_executor/tpu/tpu_platform_interface.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/protobuf/error_codes.pb.h" @@ -32,8 +32,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, DCHECK_GT(tries_left, 0); // Prefer TpuPlatform if it's registered. auto status_or_tpu_platform = - stream_executor::MultiPlatformManager::PlatformWithName( - "TPU", initialize_platform); + stream_executor::PlatformManager::PlatformWithName("TPU", + initialize_platform); if (status_or_tpu_platform.ok()) { return static_cast(status_or_tpu_platform.value()); } @@ -45,7 +45,7 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, // Use any other registered TPU platform. auto status_or_other_tpu_platforms = - stream_executor::MultiPlatformManager::PlatformsWithFilter( + stream_executor::PlatformManager::PlatformsWithFilter( [](const stream_executor::Platform* platform) { return dynamic_cast(platform) != nullptr; diff --git a/xla/stream_executor/tpu/tpu_platform_interface.h b/xla/stream_executor/tpu/tpu_platform_interface.h index 99d532a7c3ca2..567f146228897 100644 --- a/xla/stream_executor/tpu/tpu_platform_interface.h +++ b/xla/stream_executor/tpu/tpu_platform_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,10 +18,11 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_topology.h" -#include "tsl/platform/status.h" namespace tensorflow { namespace tpu { @@ -43,11 +44,11 @@ class TpuPlatformInterface : public stream_executor::Platform { static TpuPlatformInterface* GetRegisteredPlatform( bool initialize_platform = true, int num_tries = 5); - virtual tsl::Status Reset(bool only_tear_down, absl::string_view reason) = 0; + virtual absl::Status Reset(bool only_tear_down, absl::string_view reason) = 0; - tsl::Status Reset(absl::string_view reason) { return Reset(false, reason); } + absl::Status Reset(absl::string_view reason) { return Reset(false, reason); } - tsl::Status Reset() { return Reset(false, {}); } + absl::Status Reset() { return Reset(false, {}); } virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0; diff --git a/xla/stream_executor/tpu/tpu_platform_registration.cc b/xla/stream_executor/tpu/tpu_platform_registration.cc index 4de704f2542c2..c0c15a98f6e6a 100644 --- a/xla/stream_executor/tpu/tpu_platform_registration.cc +++ b/xla/stream_executor/tpu/tpu_platform_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform.h" #if defined(PLATFORM_GOOGLE) -REGISTER_MODULE_INITIALIZER(tpu_platform, - tensorflow::tpu::RegisterTpuPlatform()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( + tpu_platform, tensorflow::tpu::RegisterTpuPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); -DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener); - -// Note that module initialization sequencing is not supported in the -// open-source project, so this will be a no-op there. -REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager); -REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, - tpu_platform); #endif diff --git a/xla/stream_executor/tpu/tpu_profiler_c_api.h b/xla/stream_executor/tpu/tpu_profiler_c_api.h index efb81d37acc14..3f0572a66132b 100644 --- a/xla/stream_executor/tpu/tpu_profiler_c_api.h +++ b/xla/stream_executor/tpu/tpu_profiler_c_api.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_profiler_init_fns.inc b/xla/stream_executor/tpu/tpu_profiler_init_fns.inc index 341c51f2d9677..99518ad0d1768 100644 --- a/xla/stream_executor/tpu/tpu_profiler_init_fns.inc +++ b/xla/stream_executor/tpu/tpu_profiler_init_fns.inc @@ -3,7 +3,8 @@ namespace { -tsl::Status SetTpuProfilerApiFns(void* library_handle) { // TENSORFLOW_STATUS_OK +absl::Status SetTpuProfilerApiFns( + void* library_handle) { // TENSORFLOW_STATUS_OK // Constant cast so that we can initialize the functions. The functions are // mutable here because this is the only place where they are initialized. auto* profiler_api_fn = @@ -18,7 +19,7 @@ tsl::Status SetTpuProfilerApiFns(void* library_handle) { // TENSORFLOW_STATUS_O TFTPU_SET_FN(profiler_api_fn, TpuStatus_Free); TFTPU_SET_FN(profiler_api_fn, TpuStatus_Message); TFTPU_SET_FN(profiler_api_fn, TpuStatus_Code); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/xla/stream_executor/tpu/tpu_stream.h b/xla/stream_executor/tpu/tpu_stream.h index 286a850c0a3f1..456ef878d6765 100644 --- a/xla/stream_executor/tpu/tpu_stream.h +++ b/xla/stream_executor/tpu/tpu_stream.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -25,7 +26,6 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_executor_api.h" #include "xla/stream_executor/tpu/tpu_executor_c_api.h" #include "xla/stream_executor/tpu/tpu_stream_interface.h" -#include "tsl/platform/status.h" namespace tensorflow { namespace tpu { @@ -44,7 +44,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { stream_, static_cast(other)->stream_); } - tsl::Status EnqueueTransferHostToDevice( + absl::Status EnqueueTransferHostToDevice( stream_executor::DeviceMemoryBase device_dst, const void* host_src, uint64_t size) { StatusHelper status; @@ -55,7 +55,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { return status.status(); } - tsl::Status EnqueueTransferDeviceToHost( + absl::Status EnqueueTransferDeviceToHost( stream_executor::DeviceMemoryBase device_src, void* host_dst, uint64_t size) { StatusHelper status; @@ -66,7 +66,7 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { return status.status(); } - tsl::Status EnqueueOnTpuDeviceSendRecvLocal( + absl::Status EnqueueOnTpuDeviceSendRecvLocal( stream_executor::DeviceMemoryBase send_buffer, stream_executor::DeviceMemoryBase recv_buffer) override { StatusHelper status; diff --git a/xla/stream_executor/tpu/tpu_stream_interface.h b/xla/stream_executor/tpu/tpu_stream_interface.h index a91c93ddcb26d..f5d21eb4f475c 100644 --- a/xla/stream_executor/tpu/tpu_stream_interface.h +++ b/xla/stream_executor/tpu/tpu_stream_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_ #define XLA_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_ +#include "absl/status/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/status.h" namespace tensorflow { namespace tpu { @@ -27,7 +27,7 @@ namespace tpu { class TpuStreamInterface : public stream_executor::internal::StreamInterface { public: virtual bool IsSameSharedMemoryLocation(TpuStreamInterface* other) = 0; - virtual tsl::Status EnqueueOnTpuDeviceSendRecvLocal( + virtual absl::Status EnqueueOnTpuDeviceSendRecvLocal( stream_executor::DeviceMemoryBase send_buffer, stream_executor::DeviceMemoryBase recv_buffer) = 0; }; diff --git a/xla/stream_executor/tpu/tpu_topology.cc b/xla/stream_executor/tpu/tpu_topology.cc index 2af388816e1b4..dab965a525e6b 100644 --- a/xla/stream_executor/tpu/tpu_topology.cc +++ b/xla/stream_executor/tpu/tpu_topology.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_topology.h b/xla/stream_executor/tpu/tpu_topology.h index 74e8d9e4d2eda..8c5f6ae19285a 100644 --- a/xla/stream_executor/tpu/tpu_topology.h +++ b/xla/stream_executor/tpu/tpu_topology.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_transfer_manager.cc b/xla/stream_executor/tpu/tpu_transfer_manager.cc index e44815bc4428b..c804af09c4d9f 100644 --- a/xla/stream_executor/tpu/tpu_transfer_manager.cc +++ b/xla/stream_executor/tpu/tpu_transfer_manager.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ limitations under the License. #include #include "absl/cleanup/cleanup.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/service/shaped_buffer.h" @@ -44,14 +46,12 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_platform_id.h" #include "tsl/platform/casts.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { template -using StatusOr = tsl::StatusOr; +using StatusOr = absl::StatusOr; TpuTransferManager::TpuTransferManager() { manager_ = stream_executor::tpu::ExecutorApiFn()->TpuTransferManager_NewFn(); @@ -81,7 +81,7 @@ xla::Shape TpuTransferManager::HostShapeToDeviceShape( return device_shape; } -tsl::Status TpuTransferManager::TransferLiteralToDeviceAsync( +absl::Status TpuTransferManager::TransferLiteralToDeviceAsync( stream_executor::Stream* stream, const xla::LiteralSlice& literal, const xla::ShapedBuffer& device_buffer, const TransferMetadata* transfer_metadata) { @@ -104,7 +104,7 @@ tsl::Status TpuTransferManager::TransferLiteralToDeviceAsync( return status.status(); } -tsl::Status TpuTransferManager::TransferLiteralToInfeed( +absl::Status TpuTransferManager::TransferLiteralToInfeed( stream_executor::StreamExecutor* executor, const xla::LiteralSlice& literal) { StatusHelper status; @@ -122,7 +122,7 @@ tsl::Status TpuTransferManager::TransferLiteralToInfeed( return status.status(); } -tsl::Status TpuTransferManager::TransferBuffersToInfeed( +absl::Status TpuTransferManager::TransferBuffersToInfeed( se::StreamExecutor* executor, const std::deque& buffers) { StatusHelper status; @@ -148,7 +148,7 @@ tsl::Status TpuTransferManager::TransferBuffersToInfeed( return status.status(); } -tsl::Status TpuTransferManager::TransferLiteralFromOutfeed( +absl::Status TpuTransferManager::TransferLiteralFromOutfeed( stream_executor::StreamExecutor* executor, xla::MutableBorrowingLiteral literal) { StatusHelper status; @@ -171,7 +171,7 @@ tsl::Status TpuTransferManager::TransferLiteralFromOutfeed( return status.status(); } -tsl::Status TpuTransferManager::ResetDevices( +absl::Status TpuTransferManager::ResetDevices( absl::Span executor) { StatusHelper status; std::vector se; @@ -191,7 +191,7 @@ struct TransferFromDeviceState { std::atomic remaining_transfers; TF_Status* overall_status = stream_executor::tpu::ExecutorApiFn() ->TpuStatus_NewFn(); // OK or the first error - std::function done; + std::function done; void TransferFinished(TF_Status* status) { if (!stream_executor::tpu::ExecutorApiFn()->TpuStatus_OkFn(status) && @@ -214,7 +214,8 @@ void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) { void TpuTransferManager::TransferLiteralFromDevice( stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer, - xla::MutableBorrowingLiteral literal, std::function done, + xla::MutableBorrowingLiteral literal, + std::function done, const TransferMetadata* transfer_metadata) { TransferFromDeviceState* state = new TransferFromDeviceState; state->remaining_transfers = 1; @@ -296,7 +297,7 @@ bool TpuTransferManager::CanBufferBeAccessedNow( manager_, tpu_executor->se_executor(), &c_device_buffer); } -tsl::Status TpuTransferManager::WriteSingleTupleIndexTable( +absl::Status TpuTransferManager::WriteSingleTupleIndexTable( stream_executor::Stream* stream, absl::Span elements, const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) { @@ -327,7 +328,7 @@ tsl::Status TpuTransferManager::WriteSingleTupleIndexTable( return status.status(); } -tsl::Status TpuTransferManager::LinearizeToBuffers( +absl::Status TpuTransferManager::LinearizeToBuffers( const xla::LiteralSlice& literal, const xla::Shape& device_shape, std::deque* buffers) { XLA_Literal c_literal; @@ -360,7 +361,7 @@ tsl::Status TpuTransferManager::LinearizeToBuffers( return status.status(); } -tsl::Status TpuTransferManager::ReadDynamicShapes( +absl::Status TpuTransferManager::ReadDynamicShapes( se::Stream* stream, const xla::ShapedBuffer* device_buffer, xla::Shape* device_shape) { XLA_ShapedBuffer c_device_buffer; @@ -380,7 +381,7 @@ tsl::Status TpuTransferManager::ReadDynamicShapes( } *device_shape = ApiConverter::FromC(&c_updated_shape); ApiConverter::Destroy(&c_updated_shape); - return tsl::OkStatus(); + return absl::OkStatus(); } } // namespace tpu diff --git a/xla/stream_executor/tpu/tpu_transfer_manager.h b/xla/stream_executor/tpu/tpu_transfer_manager.h index 5693964eb7500..83e024435f9df 100644 --- a/xla/stream_executor/tpu/tpu_transfer_manager.h +++ b/xla/stream_executor/tpu/tpu_transfer_manager.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/service/shaped_buffer.h" @@ -30,8 +32,6 @@ limitations under the License. #include "xla/stream_executor/tpu/noncopyable_buffer.h" #include "xla/stream_executor/tpu/tpu_executor_c_api.h" #include "xla/stream_executor/tpu/tpu_transfer_manager_interface.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { @@ -42,14 +42,14 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface { ~TpuTransferManager() override; template - using StatusOr = tsl::StatusOr; + using StatusOr = absl::StatusOr; stream_executor::Platform::Id PlatformId() const override; xla::Shape HostShapeToDeviceShape( const xla::Shape& host_shape) const override; - tsl::Status TransferLiteralToDeviceAsync( + absl::Status TransferLiteralToDeviceAsync( stream_executor::Stream* stream, const xla::LiteralSlice& literal, const xla::ShapedBuffer& device_buffer, const TransferMetadata* transfer_metadata) override; @@ -57,22 +57,22 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface { void TransferLiteralFromDevice( stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer, xla::MutableBorrowingLiteral literal, - std::function done, + std::function done, const TransferMetadata* transfer_metadata) override; - tsl::Status TransferLiteralToInfeed( + absl::Status TransferLiteralToInfeed( stream_executor::StreamExecutor* executor, const xla::LiteralSlice& literal) override; - tsl::Status TransferLiteralFromOutfeed( + absl::Status TransferLiteralFromOutfeed( stream_executor::StreamExecutor* executor, xla::MutableBorrowingLiteral literal) override; - tsl::Status TransferBuffersToInfeed( + absl::Status TransferBuffersToInfeed( se::StreamExecutor* executor, const std::deque& buffers) override; - tsl::Status ResetDevices( + absl::Status ResetDevices( absl::Span executor) override; int64_t GetByteSizeRequirement(const xla::Shape& shape) const override; @@ -88,19 +88,19 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface { se::StreamExecutor* executor, const se::DeviceMemoryBase& device_buffer) const override; - tsl::Status WriteSingleTupleIndexTable( + absl::Status WriteSingleTupleIndexTable( stream_executor::Stream* stream, absl::Span elements, const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) override; - tsl::Status LinearizeToBuffers( + absl::Status LinearizeToBuffers( const xla::LiteralSlice& literal, const xla::Shape& device_shape, std::deque* buffers) override; - tsl::Status ReadDynamicShapes(se::Stream* stream, - const xla::ShapedBuffer* device_buffer, - xla::Shape* device_shape) override; + absl::Status ReadDynamicShapes(se::Stream* stream, + const xla::ShapedBuffer* device_buffer, + xla::Shape* device_shape) override; private: XLA_TransferManager* manager_; diff --git a/xla/stream_executor/tpu/tpu_transfer_manager_interface.cc b/xla/stream_executor/tpu/tpu_transfer_manager_interface.cc index d9cd043d644be..f01165bb7b074 100644 --- a/xla/stream_executor/tpu/tpu_transfer_manager_interface.cc +++ b/xla/stream_executor/tpu/tpu_transfer_manager_interface.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tpu_transfer_manager_interface.h b/xla/stream_executor/tpu/tpu_transfer_manager_interface.h index e6eb07170163b..e3478e3a5ff78 100644 --- a/xla/stream_executor/tpu/tpu_transfer_manager_interface.h +++ b/xla/stream_executor/tpu/tpu_transfer_manager_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,8 @@ class TpuTransferManagerInterface : public xla::TransferManager { std::deque* buffers) = 0; static TpuTransferManagerInterface* GetRegisteredTpuTransferManager(); + + bool PackSubbyteTypes() const override { return true; } }; } // namespace xla diff --git a/xla/stream_executor/tpu/tpu_transfer_manager_registration.cc b/xla/stream_executor/tpu/tpu_transfer_manager_registration.cc index 2878cf68abff7..2995bd40c0c25 100644 --- a/xla/stream_executor/tpu/tpu_transfer_manager_registration.cc +++ b/xla/stream_executor/tpu/tpu_transfer_manager_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/stream_executor/tpu/tsl_status_helper.h b/xla/stream_executor/tpu/tsl_status_helper.h index 3abc184440860..074b7703d09e2 100644 --- a/xla/stream_executor/tpu/tsl_status_helper.h +++ b/xla/stream_executor/tpu/tsl_status_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "xla/stream_executor/tpu/c_api_decl.h" -#include "tsl/c/tsl_status.h" -#include "tsl/c/tsl_status_helper.h" +#include "xla/tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status_helper.h" #include "tsl/platform/status.h" class TslStatusHelper { @@ -28,7 +28,8 @@ class TslStatusHelper { ~TslStatusHelper() { TSL_DeleteStatus(c_status); } - static tsl::Status FromC(TF_Status* const c_status) { // TENSORFLOW_STATUS_OK + static absl::Status FromC( + TF_Status* const c_status) { // TENSORFLOW_STATUS_OK absl::StatusCode code = tsl::StatusCodeFromTSLCode(TSL_GetCode(c_status)); if (code == absl::StatusCode::kOk) { return tsl::OkStatus(); @@ -41,7 +42,7 @@ class TslStatusHelper { absl::StatusCode::kOk; } - tsl::Status status() const { // TENSORFLOW_STATUS_OK + absl::Status status() const { // TENSORFLOW_STATUS_OK return FromC(c_status); } diff --git a/xla/stream_executor/trace_listener.h b/xla/stream_executor/trace_listener.h deleted file mode 100644 index 79909261aca07..0000000000000 --- a/xla/stream_executor/trace_listener.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the StreamExecutor trace listener, used for inserting -// non-device-specific instrumentation into the StreamExecutor. -#ifndef XLA_STREAM_EXECUTOR_TRACE_LISTENER_H_ -#define XLA_STREAM_EXECUTOR_TRACE_LISTENER_H_ - -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "tsl/platform/status.h" - -namespace stream_executor { - -class Stream; - -// Traces StreamExecutor PIMPL-level events. -// The few StreamExecutor interfaces that are synchronous have both Begin and -// Complete versions of their trace calls. Asynchronous operations only have -// Submit calls, as execution of the underlying operations is device-specific. -// As all tracing calls mirror StreamExecutor routines, documentation here is -// minimal. -// -// All calls have default implementations that perform no work; subclasses -// should override functionality of interest. Keep in mind that these routines -// are not called on a dedicated thread, so callbacks should execute quickly. -// -// Note: This API is constructed on an as-needed basis. Users should add -// support for further StreamExecutor operations as required. By enforced -// convention (see SCOPED_TRACE in stream_executor_pimpl.cc), synchronous -// tracepoints should be named NameBegin and NameComplete. -class TraceListener { - public: - virtual ~TraceListener() {} - - virtual void LaunchSubmit(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) {} - - virtual void SynchronousMemcpyH2DBegin(int64_t correlation_id, - const void* host_src, int64_t size, - DeviceMemoryBase* gpu_dst) {} - virtual void SynchronousMemcpyH2DComplete(int64_t correlation_id, - const tsl::Status* result) {} - - virtual void SynchronousMemcpyD2HBegin(int64_t correlation_id, - const DeviceMemoryBase& gpu_src, - int64_t size, void* host_dst) {} - virtual void SynchronousMemcpyD2HComplete(int64_t correlation_id, - const tsl::Status* result) {} - - virtual void BlockHostUntilDoneBegin(int64_t correlation_id, Stream* stream) { - } - virtual void BlockHostUntilDoneComplete(int64_t correlation_id, - const tsl::Status* result) {} -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_TRACE_LISTENER_H_ diff --git a/xla/test.h b/xla/test.h index 39aced84821dc..5117b8fd41a1c 100644 --- a/xla/test.h +++ b/xla/test.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/test_helpers.h b/xla/test_helpers.h index 162f057cce2d2..11425fd51fa35 100644 --- a/xla/test_helpers.h +++ b/xla/test_helpers.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,8 @@ limitations under the License. #ifndef XLA_TEST_HELPERS_H_ #define XLA_TEST_HELPERS_H_ -#include -#include - -#include "absl/strings/string_view.h" +#include "xla/status.h" #include "xla/statusor.h" -#include "xla/types.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/regexp.h" #include "tsl/platform/test.h" // This module contains a minimal subset of gmock functionality just @@ -54,18 +48,18 @@ inline const Status& GetStatus(const StatusOr& status) { // Macros for testing the results of functions that return Status or // StatusOr (for any type T). #define EXPECT_IS_OK(expression) \ - EXPECT_EQ(::tsl::OkStatus(), \ + EXPECT_EQ(::absl::OkStatus(), \ xla::testing::internal_status::GetStatus(expression)) #define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(::tsl::OkStatus(), \ + EXPECT_NE(::absl::OkStatus(), \ xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK #define ASSERT_IS_OK(expression) \ - ASSERT_EQ(::tsl::OkStatus(), \ + ASSERT_EQ(::absl::OkStatus(), \ xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK #define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(::tsl::OkStatus(), \ + ASSERT_NE(::absl::OkStatus(), \ xla::testing::internal_status::GetStatus(expression)) #endif // XLA_TEST_HELPERS_H_ diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 3e9a1f5b51b9e..d3d3275be7dda 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1,17 +1,11 @@ # Description: # Base testing infrastructure for XLA. -load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") -load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("@tsl//tsl:tsl.bzl", "tsl_copts") +load("@tsl//tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@tsl//tsl:tsl.default.bzl", "filegroup") load( "@tsl//tsl/platform:build_config_root.bzl", @@ -22,10 +16,16 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test") +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) +load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = internal_visibility([":friends"]), licenses = ["notice"], ) @@ -57,6 +57,7 @@ cc_library( srcs = ["xla_internal_test_main.cc"], deps = [ "//xla:debug_options_flags", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", @@ -184,6 +185,7 @@ cc_library( ], deps = [ ":pjrt_client_registry", + "//xla/pjrt/gpu:gpu_helpers", "//xla/pjrt/gpu:se_gpu_pjrt_client", ], ) @@ -220,6 +222,7 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", @@ -273,6 +276,7 @@ cc_library( "//xla:status_macros", "//xla:statusor", "//xla:test_helpers", + "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", @@ -285,8 +289,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:bitmap", - "@tsl//tsl/platform:float8", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], ) @@ -327,6 +331,7 @@ cc_library( data = [ "@llvm-project//llvm:FileCheck", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ "//xla:statusor", "//xla:types", @@ -336,6 +341,7 @@ cc_library( "@tsl//tsl/platform:path", "@tsl//tsl/platform:resource_loader", "@tsl//tsl/platform:subprocess", + "@tsl//tsl/platform:test", ], ) @@ -694,6 +700,29 @@ xla_test( ], ) +xla_test( + name = "complex_unary_op_test", + srcs = [ + "complex_unary_op_samples.h", + "complex_unary_op_test.cc", + ], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:xla_data_proto_cc", + "//xla/client:global_data", + "//xla/client:local_client", + "//xla/client:xla_builder", + "@tsl//tsl/platform:test", + ], +) + xla_test( name = "scalar_computations_test", srcs = ["scalar_computations_test.cc"], @@ -774,6 +803,7 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:comparison_util", "//xla:literal", "//xla:shape_util", "//xla:statusor", @@ -784,6 +814,8 @@ xla_test( "//xla/client:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@ml_dtypes//:float8", + "@tsl//tsl/platform:ml_dtypes", ], ) @@ -992,10 +1024,13 @@ xla_test( ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", + "//xla:array2d", "//xla:error_spec", "//xla:literal", + "//xla:shape_util", "//xla:status_macros", "//xla:test", + "//xla:types", "@com_google_absl//absl/strings", ], ) @@ -1070,7 +1105,7 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:constants", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], ) @@ -1293,6 +1328,7 @@ xla_test( "requires-gpu-sm80", ]}, backends = ["gpu"], + tags = ["no_rocm"], # No int8 deps = [ ":hlo_test_base", ":xla_internal_test_main", @@ -1689,10 +1725,15 @@ xla_test( srcs = ["token_hlo_test.cc"], deps = [ ":hlo_test_base", + ":literal_test_util", ":test_macros_header", ":test_utils", ":xla_internal_test_main", - "//xla/service:hlo_verifier", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_runner", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -1727,17 +1768,24 @@ xla_test( ":test_macros_header", ":test_utils", ":xla_internal_test_main", # fixdeps: keep + "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", - "//xla/runtime/ffi:ffi_api", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -1775,6 +1823,7 @@ xla_test( "//xla:test", "//xla/client:local_client", "//xla/client:xla_builder", + "@com_google_absl//absl/strings:string_view", ], ) @@ -1874,6 +1923,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:test", ], @@ -1972,10 +2022,6 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - backend_tags = { - # TODO(b/305009066) - "ghostfish_grm": ["broken"], - }, deps = [ ":client_library_test_base", ":literal_test_util", @@ -1989,7 +2035,7 @@ xla_test( "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], ) @@ -2027,7 +2073,6 @@ xla_test( "multi_gpu", "no_oss", "notap", - "jitrt_executable", ], "cpu": [ "notsan", @@ -2066,7 +2111,6 @@ xla_test( "multi_gpu", "no_oss", "notap", - "jitrt_executable", ], }, backends = [ @@ -2147,7 +2191,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/stream_executor", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], ) @@ -2169,29 +2213,6 @@ xla_test( ], ) -xla_test( - name = "compilation_cache_test", - srcs = ["compilation_cache_test.cc"], - deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", - ":test_utils", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:statusor", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/client:global_data", - "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:test", - ], -) - xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], @@ -2349,7 +2370,7 @@ xla_test( xla_cc_test( name = "llvm_compiler_test", - srcs = if_gpu_is_configured(["llvm_compiler_test.cc"]), + srcs = ["llvm_compiler_test.cc"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM", ]), @@ -2510,6 +2531,7 @@ xla_test( "//xla/service:transfer_manager", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "@tsl//tsl/platform:env", @@ -2727,6 +2749,7 @@ xla_cc_test( "//xla/client:client_library", "//xla/client:xla_builder", "//xla/service:cpu_plugin", + "//xla/stream_executor:platform_manager", "@com_google_absl//absl/synchronization", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", @@ -2794,7 +2817,7 @@ xla_test( xla_test( name = "cholesky_test", srcs = ["cholesky_test.cc"], - real_hardware_only = True, + shard_count = 10, tags = [ "optonly", ], @@ -2847,6 +2870,7 @@ xla_test( "cpu", ], copts = tsl_copts(), + tags = ["no_oss"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -2855,6 +2879,63 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", + "//xla/hlo/utils:hlo_matchers", + "//xla/service/cpu:onednn_util", "@tsl//tsl/platform:platform_port", ], ) + +xla_test( + name = "onednn_layer_norm_test", + srcs = ["onednn_layer_norm_test.cc"], + backends = [ + "cpu", + ], + copts = tsl_copts(), + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + ], +) + +xla_test( + name = "onednn_softmax_test", + srcs = ["onednn_softmax_test.cc"], + backends = [ + "cpu", + ], + copts = tsl_copts(), + shard_count = 4, + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/service/cpu:onednn_util", + "@tsl//tsl/platform:platform_port", + ], +) + +xla_test( + name = "numerics_test", + srcs = ["numerics_test.cc"], + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:literal_util", + "//xla:statusor", + "//xla:test", + "//xla:types", + "//xla/hlo/ir:hlo", + "@tsl//tsl/platform:test", + ], +) diff --git a/xla/tests/all_reduce_test.cc b/xla/tests/all_reduce_test.cc index 5f9b4601b456a..c07ed4713baad 100644 --- a/xla/tests/all_reduce_test.cc +++ b/xla/tests/all_reduce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index 5a2e0add1eebb..d4fc025dc175a 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -27,20 +28,24 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/types/span.h" +#include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" +#include "xla/comparison_util.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/primitive_util.h" #include "xla/statusor.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -1293,14 +1298,140 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) { - SetFastMathDisabled(true); - XlaBuilder builder(TestName()); - auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN}); - EqTotalOrder(lhs, rhs); +template +class TotalOrderTest : public ClientLibraryTestBase { + public: + void DoIt(ComparisonDirection direction) { + this->SetFastMathDisabled(true); + XlaBuilder builder(this->TestName()); + std::vector values = { + static_cast(0.0f), + std::numeric_limits::min(), + static_cast(1.0f), + std::numeric_limits::max(), + }; + if constexpr (std::numeric_limits::has_denorm) { + auto denorm = static_cast(std::numeric_limits::denorm_min()); + if (denorm >= std::numeric_limits::min()) { + values.push_back(std::numeric_limits::denorm_min()); + } + } + if constexpr (std::is_same_v || std::is_same_v) { + values.push_back(std::fabs(std::numeric_limits::quiet_NaN())); + } + if constexpr (std::numeric_limits::has_infinity) { + values.push_back(std::numeric_limits::infinity()); + } +#if defined(XLA_TEST_BACKEND_CPU) || defined(XLA_TEST_BACKEND_GPU) || \ + defined(XLA_TEST_BACKEND_INTERPRETER) + if constexpr (std::numeric_limits::has_quiet_NaN) { + values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); + } +#endif + values.reserve(values.size() * 2); + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } + std::vector lhs_data; + std::vector rhs_data; + lhs_data.reserve(values.size() * values.size()); + rhs_data.reserve(values.size() * values.size()); + for (T lhs_value : values) { + for (T rhs_value : values) { + lhs_data.push_back(lhs_value); + rhs_data.push_back(rhs_value); + } + } + absl::InlinedVector results; + results.reserve(lhs_data.size()); + Comparison comparison(direction, primitive_util::NativeToPrimitiveType(), + Comparison::Order::kTotal); + for (size_t i = 0; i < lhs_data.size(); ++i) { + results.push_back(comparison.Compare(lhs_data[i], rhs_data[i])); + } + auto lhs = ConstantR1(&builder, lhs_data); + auto rhs = ConstantR1(&builder, rhs_data); + switch (direction) { + case ComparisonDirection::kEq: + EqTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kNe: + NeTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kGt: + GtTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kGe: + GeTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kLt: + LtTotalOrder(lhs, rhs); + break; + case ComparisonDirection::kLe: + LeTotalOrder(lhs, rhs); + break; + } + + this->ComputeAndCompareR1(&builder, results, {}); + } +}; - ComputeAndCompareR1(&builder, {false, false, true, true, false}, {}); +using Types = ::testing::Types; + +TYPED_TEST_SUITE(TotalOrderTest, Types); + +TYPED_TEST(TotalOrderTest, Eq) { this->DoIt(ComparisonDirection::kEq); } +TYPED_TEST(TotalOrderTest, Ne) { this->DoIt(ComparisonDirection::kNe); } +TYPED_TEST(TotalOrderTest, Le) { this->DoIt(ComparisonDirection::kLe); } +TYPED_TEST(TotalOrderTest, Lt) { this->DoIt(ComparisonDirection::kLt); } +TYPED_TEST(TotalOrderTest, Ge) { this->DoIt(ComparisonDirection::kGe); } +TYPED_TEST(TotalOrderTest, Gt) { this->DoIt(ComparisonDirection::kGt); } +TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { + using T = TypeParam; + if constexpr (!std::numeric_limits::has_quiet_NaN) { + GTEST_SKIP(); + } + this->SetFastMathDisabled(true); + + XlaBuilder builder(this->TestName()); + std::vector values = { + static_cast(0.0f), + std::numeric_limits::min(), + static_cast(1.0f), + std::numeric_limits::max(), + }; + if constexpr (std::numeric_limits::has_infinity) { + values.push_back(std::numeric_limits::infinity()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto value = values[i]; + auto neg = -value; + if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { + values.push_back(neg); + } + } + auto lhs = ConstantR1(&builder, values); + auto rhs = ConstantR1( + &builder, + std::vector(values.size(), std::numeric_limits::quiet_NaN())); + LtTotalOrder(lhs, rhs); + TF_ASSERT_OK_AND_ASSIGN(auto result, this->ComputeAndTransfer(&builder, {})); + EXPECT_TRUE(result.IsAll(0) || result.IsAll(1)) << result.ToString(); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { @@ -1322,23 +1453,6 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) { - SetFastMathDisabled(true); - XlaBuilder builder(TestName()); - // For portability, need to represent NAN using the following call. - // The C++ standard does not specify if quiet_NaN() sets the sign bit of - // its result. The call to std::fabs will ensure that it is not set. - auto kNaN = std::fabs(std::numeric_limits::quiet_NaN()); - auto lhs = - ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, kNaN, 6.0f, 6.0f}); - auto rhs = - ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, kNaN, -kNaN}); - GeTotalOrder(lhs, rhs); - - ComputeAndCompareR1(&builder, {false, true, true, true, false, true}, - {}); -} - XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); @@ -2009,6 +2123,121 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { error_spec_); } +using ScalarF32TestCase = std::tuple; + +class ScalarF32MinMaxTest + : public ArrayElementwiseOpTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ScalarF32MinMaxTest, Version_1) { + auto test_params = GetParam(); + XlaBuilder builder(TestName()); + SetFastMathDisabled(true); + float x = std::get<0>(test_params); + float y = std::get<1>(test_params); + auto lhs = ConstantR0(&builder, x); + auto rhs = ConstantR0(&builder, y); + Tuple(&builder, {Max(Max(lhs, rhs), rhs), Min(Min(lhs, rhs), rhs)}); + + float expected_min = std::min(x, y); + float expected_max = std::max(x, y); + if (std::isnan(x)) { + expected_min = x; + expected_max = x; + } else if (std::isnan(y)) { + expected_min = y; + expected_max = y; + } + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(expected_max), + LiteralUtil::CreateR0(expected_min)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); +} + +XLA_TEST_P(ScalarF32MinMaxTest, Version_2) { + auto test_params = GetParam(); + XlaBuilder builder(TestName()); + SetFastMathDisabled(true); + float x = std::get<0>(test_params); + float y = std::get<1>(test_params); + auto lhs = ConstantR0(&builder, x); + auto rhs = ConstantR0(&builder, y); + Tuple(&builder, {Max(Max(lhs, rhs), lhs), Min(Min(lhs, rhs), lhs)}); + + float expected_min = std::min(x, y); + float expected_max = std::max(x, y); + if (std::isnan(x)) { + expected_min = x; + expected_max = x; + } else if (std::isnan(y)) { + expected_min = y; + expected_max = y; + } + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(expected_max), + LiteralUtil::CreateR0(expected_min)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); +} + +XLA_TEST_P(ScalarF32MinMaxTest, Version_3) { + auto test_params = GetParam(); + XlaBuilder builder(TestName()); + SetFastMathDisabled(true); + float x = std::get<0>(test_params); + float y = std::get<1>(test_params); + auto lhs = ConstantR0(&builder, x); + auto rhs = ConstantR0(&builder, y); + Tuple(&builder, {Max(lhs, Max(lhs, rhs)), Min(lhs, Min(lhs, rhs))}); + + float expected_min = std::min(x, y); + float expected_max = std::max(x, y); + if (std::isnan(x)) { + expected_min = x; + expected_max = x; + } else if (std::isnan(y)) { + expected_min = y; + expected_max = y; + } + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(expected_max), + LiteralUtil::CreateR0(expected_min)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); +} + +XLA_TEST_P(ScalarF32MinMaxTest, Version_4) { + auto test_params = GetParam(); + XlaBuilder builder(TestName()); + SetFastMathDisabled(true); + float x = std::get<0>(test_params); + float y = std::get<1>(test_params); + auto lhs = ConstantR0(&builder, x); + auto rhs = ConstantR0(&builder, y); + Tuple(&builder, {Max(rhs, Max(lhs, rhs)), Min(rhs, Min(lhs, rhs))}); + + float expected_min = std::min(x, y); + float expected_max = std::max(x, y); + if (std::isnan(x)) { + expected_min = x; + expected_max = x; + } else if (std::isnan(y)) { + expected_min = y; + expected_max = y; + } + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(expected_max), + LiteralUtil::CreateR0(expected_min)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); +} + +INSTANTIATE_TEST_SUITE_P( + ScalarF32MinMaxTestInstance, ScalarF32MinMaxTest, + ::testing::Combine( + ::testing::Values(1.0, std::numeric_limits::infinity(), NAN, + -1.0, -std::numeric_limits::infinity(), -NAN), + ::testing::Values(1.0, std::numeric_limits::infinity(), NAN, + -1.0, -std::numeric_limits::infinity(), + -NAN))); + XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -2601,6 +2830,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) { ComputeAndCompare(&builder, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ErfF32s) { + XlaBuilder builder(TestName()); + auto kInf = std::numeric_limits::infinity(); + auto kNaN = std::numeric_limits::quiet_NaN(); + auto a = ConstantR1( + &builder, {-kInf, -2.5f, 3.14f, -0.0f, 0.0f, 2.25f, kInf, kNaN}); + + Erf(a); + + ErrorSpec error_spec{1e-5f, 1e-5f}; + ComputeAndCompare(&builder, {}, error_spec); +} + XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); auto kInf = std::numeric_limits::infinity(); diff --git a/xla/tests/axpy_simple_test.cc b/xla/tests/axpy_simple_test.cc index 5cd2907349ed5..b2274dc2ab144 100644 --- a/xla/tests/axpy_simple_test.cc +++ b/xla/tests/axpy_simple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/bad_rng_shape_validation_test.cc b/xla/tests/bad_rng_shape_validation_test.cc index 13c62390adbb0..90487f20c584f 100644 --- a/xla/tests/bad_rng_shape_validation_test.cc +++ b/xla/tests/bad_rng_shape_validation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { Shape default_constructed; RngUniform(zero, one, default_constructed); - StatusOr computation = builder.Build(); + absl::StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); LOG(INFO) << "status received: " << computation.status(); EXPECT_THAT(computation.status().message(), @@ -57,7 +57,7 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { RngUniform(zero, one, sans_layout); - StatusOr computation = builder.Build(); + absl::StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); LOG(INFO) << computation.status(); } diff --git a/xla/tests/batch_normalization_test.cc b/xla/tests/batch_normalization_test.cc index c66ff8beb6da4..b784ae527f84b 100644 --- a/xla/tests/batch_normalization_test.cc +++ b/xla/tests/batch_normalization_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/bfloat16_test.cc b/xla/tests/bfloat16_test.cc index d34fabbd1ef27..ffd84282fc544 100644 --- a/xla/tests/bfloat16_test.cc +++ b/xla/tests/bfloat16_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/binop_scaling_test.cc b/xla/tests/binop_scaling_test.cc index a967ae02b0ada..8205318b40f15 100644 --- a/xla/tests/binop_scaling_test.cc +++ b/xla/tests/binop_scaling_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/bitcast_convert_test.cc b/xla/tests/bitcast_convert_test.cc index 49c2266f7bc93..78c6b435fe5c9 100644 --- a/xla/tests/bitcast_convert_test.cc +++ b/xla/tests/bitcast_convert_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/broadcast_simple_test.cc b/xla/tests/broadcast_simple_test.cc index 55600173f4bb7..7e4154ac01b4d 100644 --- a/xla/tests/broadcast_simple_test.cc +++ b/xla/tests/broadcast_simple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" @@ -34,6 +35,9 @@ namespace { class BroadcastSimpleTest : public ClientLibraryTestBase { public: + static constexpr absl::string_view kIncompatibleBinaryOpShapeErrorMessage = + "Binary op with incompatible shapes"; + XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs, XlaBuilder* builder) { switch (op) { @@ -753,7 +757,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().message(), - HasSubstr("op add with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -766,7 +770,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().message(), - HasSubstr("op add with incompatible shapes")); + HasSubstr(kIncompatibleBinaryOpShapeErrorMessage)); } } // namespace diff --git a/xla/tests/broadcast_test.cc b/xla/tests/broadcast_test.cc index bf40968f7a23f..c46104f844319 100644 --- a/xla/tests/broadcast_test.cc +++ b/xla/tests/broadcast_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/buffer_donation_test.cc b/xla/tests/buffer_donation_test.cc index 5d02acedd0dad..55ac7f0ce2306 100644 --- a/xla/tests/buffer_donation_test.cc +++ b/xla/tests/buffer_donation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -78,13 +78,12 @@ class BufferDonationTest : public HloTestBase { backend_->compiler()->RunBackend(std::move(hlo_module), executor_, /*device_allocator=*/nullptr)); - se::Stream stream(executor_); - ASSERT_TRUE(stream.Init().ok()); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream()); se::StreamExecutorMemoryAllocator memory_allocator( platform_, backend_->stream_executors()); ExecutableRunOptions run_options; - run_options.set_stream(&stream); + run_options.set_stream(stream.get()); run_options.set_allocator(&memory_allocator); ServiceExecutableRunOptions service_run_options( run_options, backend_->StreamBorrowerWithPriority()); @@ -106,7 +105,7 @@ class BufferDonationTest : public HloTestBase { executor_->device_ordinal())); ShapedBuffer shaped_buffer = scoped_shaped_buffer.release(); TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice( - &stream, argument_literal, shaped_buffer)); + stream.get(), argument_literal, shaped_buffer)); ShapeTree input_buffers = shaped_buffer.buffers(); inputs_buffers.push_back(input_buffers); ShapeTree owned_buffers( @@ -125,7 +124,7 @@ class BufferDonationTest : public HloTestBase { args.emplace_back(ExecutionInput(std::move(owned_buffers))); } - StatusOr output_status = + absl::StatusOr output_status = executable->ExecuteAsyncOnStream(&service_run_options, std::move(args), /*hlo_execution_profile=*/nullptr); if (!expected_failure.empty()) { @@ -162,7 +161,7 @@ class BufferDonationTest : public HloTestBase { TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, backend_->transfer_manager()->TransferLiteralFromDevice( - &stream, output.Result())); + stream.get(), output.Result())); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result_literal)); // Memories are automatically deallocated. @@ -289,7 +288,7 @@ TEST_F(BufferDonationTest, TestNoCopyProtectionOnPassthroughParam) { HloModuleConfig config; config.set_alias_passthrough_params(true); - StatusOr> module = + absl::StatusOr> module = ParseAndReturnVerifiedModule(R"( HloModule module @@ -317,7 +316,7 @@ ENTRY entry { TEST_F(BufferDonationTest, TestMustAliasNotDonated) { HloModuleConfig config; - StatusOr> module = + absl::StatusOr> module = ParseAndReturnVerifiedModule(R"( HloModule module diff --git a/xla/tests/build_defs.bzl b/xla/tests/build_defs.bzl index 86a90ac26db6e..9cd17f5d41ca3 100644 --- a/xla/tests/build_defs.bzl +++ b/xla/tests/build_defs.bzl @@ -1,15 +1,15 @@ """Build rules for XLA testing.""" +load( + "@tsl//tsl/platform:build_config_root.bzl", + "tf_gpu_tests_tags", +) load("//xla:xla.bzl", "xla_cc_test") load( "//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured", ) load("//xla/tests:plugin.bzl", "plugins") -load( - "@tsl//tsl/platform:build_config_root.bzl", - "tf_gpu_tests_tags", -) all_backends = ["cpu", "gpu"] + plugins.keys() @@ -144,7 +144,16 @@ def xla_test( test_names.append(test_name) - native.test_suite(name = name, tags = tags, tests = test_names) + # Notably, a test_suite with `tests = []` is not empty: + # https://bazel.build/reference/be/general#test_suite_args and the default + # `tests = []` behavior doesn't respect `--build_tag_filters` due to + # b/317293391. For this reason, if we would create an empty `test_suite`, + # instead create a `cc_test` with no srcs that links against `main` to have + # more predictable behavior that avoids bugs. + if test_names: + native.test_suite(name = name, tags = tags, tests = test_names) + else: + native.cc_test(name = name, deps = ["@tsl//tsl/platform:test_main"]) def xla_test_library( name, diff --git a/xla/tests/call_test.cc b/xla/tests/call_test.cc index d6ddf1edc809a..4bc38c213b384 100644 --- a/xla/tests/call_test.cc +++ b/xla/tests/call_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/check_execution_arity_test.cc b/xla/tests/check_execution_arity_test.cc index f314694eaf355..1c9d4f40d910d 100644 --- a/xla/tests/check_execution_arity_test.cc +++ b/xla/tests/check_execution_arity_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/cholesky_test.cc b/xla/tests/cholesky_test.cc index 5d69994db6ada..5def722e43497 100644 --- a/xla/tests/cholesky_test.cc +++ b/xla/tests/cholesky_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/client_library_test_base.cc b/xla/tests/client_library_test_base.cc index 47950628ad639..86dc33b33d195 100644 --- a/xla/tests/client_library_test_base.cc +++ b/xla/tests/client_library_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter"; // value()) if the platform we intend to test is not available. LocalClient* GetOrCreateLocalClientOrDie( const LocalClientOptions& client_options) { - StatusOr result = + absl::StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.value(); @@ -101,14 +101,14 @@ std::string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -StatusOr> ClientLibraryTestBase::Execute( +absl::StatusOr> ClientLibraryTestBase::Execute( XlaBuilder* builder, absl::Span arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr ClientLibraryTestBase::ExecuteAndTransfer( +absl::StatusOr ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -120,7 +120,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr ClientLibraryTestBase::ExecuteAndTransfer( +absl::StatusOr ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. @@ -128,7 +128,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( +absl::StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -270,7 +270,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } -StatusOr ClientLibraryTestBase::ComputeAndTransfer( +absl::StatusOr ClientLibraryTestBase::ComputeAndTransfer( XlaBuilder* builder, absl::Span arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -481,7 +481,7 @@ void ClientLibraryTestBase::ComputeAndCompare( EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr> +absl::StatusOr> ClientLibraryTestBase::ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's @@ -607,7 +607,7 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, : LiteralSlice(literal)); } -StatusOr> +absl::StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { @@ -637,7 +637,7 @@ Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( return literal.Clone(); } -StatusOr> +absl::StatusOr> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, const DeviceHandle* device_handle, XlaBuilder* builder, diff --git a/xla/tests/client_library_test_base.h b/xla/tests/client_library_test_base.h index 765fd10340ade..b7be83b5c9c40 100644 --- a/xla/tests/client_library_test_base.h +++ b/xla/tests/client_library_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,9 +37,10 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/manifest_checking_test.h" #include "xla/tests/test_utils.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/bitmap.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" namespace xla { @@ -93,14 +94,14 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. - StatusOr> Execute( + absl::StatusOr> Execute( XlaBuilder* builder, absl::Span arguments); - StatusOr ExecuteAndTransfer( + absl::StatusOr ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr ExecuteAndTransfer( + absl::StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -108,7 +109,7 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected value of the // computation. - StatusOr ExecuteAndTransferReference( + absl::StatusOr ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -192,7 +193,7 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // Build and run the computation and return the result as a literal. // shape_with_layout indicates the result layout to request when calling // Execute. - StatusOr ComputeAndTransfer( + absl::StatusOr ComputeAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_layout = nullptr); @@ -275,14 +276,14 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - StatusOr> CreateParameterAndTransferLiteral( + absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. - StatusOr> CreateParameterAndTransferLiteral( + absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle); @@ -410,7 +411,7 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr> ComputeValueAndReference( + absl::StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); // Converts an f32 literal to bf16 if use_bfloat16_ is true. @@ -444,6 +445,10 @@ class ClientLibraryTestBase : public ManifestCheckingTest { // Arguments to be passed to the computation when it runs. std::vector arguments_; + + template + static constexpr inline bool is_floating_or_complex_v = + std::disjunction_v, is_complex>; }; template @@ -459,17 +464,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -489,17 +484,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, absl::Span expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -520,17 +505,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -552,17 +527,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -584,17 +549,7 @@ template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); @@ -615,17 +570,7 @@ template void ClientLibraryTestBase::ComputeAndCompare( XlaBuilder* builder, const Array& expected, absl::Span arguments, ErrorSpec error) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(is_floating_or_complex_v, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateFromArray(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, diff --git a/xla/tests/client_test.cc b/xla/tests/client_test.cc index 8dfbb7e959c5d..5cf7fff6404f8 100644 --- a/xla/tests/client_test.cc +++ b/xla/tests/client_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/codegen_test_base.cc b/xla/tests/codegen_test_base.cc index 785233794ff1e..aadb9480e43c0 100644 --- a/xla/tests/codegen_test_base.cc +++ b/xla/tests/codegen_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,8 +19,9 @@ limitations under the License. namespace xla { -StatusOr> CodegenTestBase::CompileToExecutable( - std::unique_ptr hlo_module, bool run_optimization_passes) { +absl::StatusOr> +CodegenTestBase::CompileToExecutable(std::unique_ptr hlo_module, + bool run_optimization_passes) { if (run_optimization_passes) { TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( std::move(hlo_module), @@ -32,7 +33,7 @@ StatusOr> CodegenTestBase::CompileToExecutable( /*device_allocator=*/nullptr); } -StatusOr> +absl::StatusOr> CodegenTestBase::CompileToAotCompilationResult( std::unique_ptr hlo_module, const AotCompilationOptions& options) { diff --git a/xla/tests/codegen_test_base.h b/xla/tests/codegen_test_base.h index 00dda51c8ca9d..7d71613aa874e 100644 --- a/xla/tests/codegen_test_base.h +++ b/xla/tests/codegen_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,14 +29,14 @@ namespace xla { class CodegenTestBase : public HloTestBase { protected: // Compiles hlo_module with the JIT compiler. - StatusOr> CompileToExecutable( + absl::StatusOr> CompileToExecutable( std::unique_ptr hlo_module, bool run_optimization_passes = true); // Compiles hlo_module with the AOT compiler. - StatusOr> CompileToAotCompilationResult( - std::unique_ptr hlo_module, - const AotCompilationOptions& options); + absl::StatusOr> + CompileToAotCompilationResult(std::unique_ptr hlo_module, + const AotCompilationOptions& options); }; } // namespace xla diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 29d7e5781bfa4..e366264a1d4c4 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -480,7 +480,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4, - /*use_threads=*/true)); + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), 4); @@ -522,7 +522,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), {}, /*num_replicas=*/kNumReplicas, - /*use_threads=*/true)); + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (int i = 0; i < kNumReplicas; ++i) { @@ -618,9 +618,10 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, num_devices_, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, num_devices_, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), num_devices_); for (uint32_t i = 0; i < num_devices_; ++i) { @@ -628,16 +629,22 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { } } -XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { const char* const kModuleStr = R"( HloModule test + + collective_broadcast { + p0 = u32[2] parameter(0) + ROOT result = u32[2] collective-broadcast(p0), replica_groups={{1, 0, 2, 3}} + } + ENTRY test_computation { replica = u32[] replica-id() ten = u32[] constant(10) sum = u32[] add(replica, ten) p = u32[2] broadcast(sum), dimensions={} - permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}} - ROOT copy = u32[2] copy(permute) + cb = ((u32[2]), u32[2]) async-start(u32[2] %p), calls=collective_broadcast + ROOT res = u32[2] async-done(cb), calls=collective_broadcast } )"; const int64_t kNumReplicas = 4; @@ -652,6 +659,41 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { ExecuteReplicated(std::move(module), {}, kNumReplicas, /*use_threads=*/true)); ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[1])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[2])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[3])); +} + +XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}} + ROOT copy = u32[2] copy(permute) + } + )"; + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), results[0])); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), @@ -683,9 +725,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), results[0])); @@ -717,9 +760,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NotDegenerate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), results[0])); @@ -752,9 +796,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({13, 13}), results[0])); @@ -829,9 +874,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16, 12, 17, 13, 18}, results[0]); @@ -873,9 +919,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({43, 48, 42, 47, 41, 46, 40, 45}, results[0]); @@ -911,9 +958,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({23, 28, 20, 25}, results[0]); LiteralTestUtil::ExpectR1Equal({22, 27, 21, 26}, results[1]); @@ -941,9 +989,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16, 12, 17, 13, 18}, results[0]); @@ -983,6 +1032,34 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) { } } +XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0_UseGlobalDevices) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + id = u32[] replica-id() + id2 = u32[1, 2] broadcast(id), dimensions={} + a0 = u32[1, 2] constant({{10, 15}}) + a1 = u32[1, 2] add(id2, a0) + allgather = u32[2, 2] all-gather(a1), dimensions={0}, use_global_device_ids=true, channel_id=7, replica_groups={{0, 1}} + ROOT out = u32[4] reshape(allgather) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16}, result); + } +} + XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) { const char* const kModuleStr = R"( HloModule test @@ -1844,16 +1921,16 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { %p = u32[2] broadcast(%sum), dimensions={} %after-all = token[] after-all() - %recv = (u32[2], u32[], token[]) recv(%after-all), channel_id=1, frontend_attributes={ + %recv = (u32[2], u32[], token[]) recv(%after-all), channel_id=0, frontend_attributes={ _xla_send_recv_source_target_pairs="{{1,0}}" } - %send = (u32[2], u32[], token[]) send(%p, %after-all), channel_id=1, control-predecessors={%recv}, frontend_attributes={ + %send = (u32[2], u32[], token[]) send(%p, %after-all), channel_id=0, control-predecessors={%recv}, frontend_attributes={ _xla_send_recv_source_target_pairs="{{1,0}}" } - %recv-done = (u32[2], token[]) recv-done(%recv), channel_id=1 + %recv-done = (u32[2], token[]) recv-done(%recv), channel_id=0 %recv-data = u32[2] get-tuple-element(%recv-done), index=0 - %send-done = token[] send-done(%send), channel_id=1, control-predecessors={%recv} + %send-done = token[] send-done(%send), channel_id=0, control-predecessors={%recv} ROOT copy = u32[2] copy(%recv-data) } )"; @@ -1876,5 +1953,226 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { results[1])); } +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) { + const char* const kModuleStr = R"( + HloModule test, is_scheduled=true + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + replica = u32[] replica-id() + a = u32[] add(c1, replica) + send-data = u32[2] broadcast(a), dimensions={} + + after-all.0 = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="1" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0), + channel_id=0, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_pipeline="1" + } + + after-all.1 = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + send.1 = (u32[2], u32[], token[]) send(send-data, after-all.1), + channel_id=0, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0 + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0 + recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + + compare0 = pred[] compare(replica, c0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + send-done.0 = token[] send-done(send.0), channel_id=0 + send-done.1 = token[] send-done(send.1), channel_id=0 + + c1b = u32[2] broadcast(c1), dimensions={} + ROOT result = u32[2] add(c1b, recv-data) + })"; + + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/false)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 3}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({2, 2}), + results[1])); +} + +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) { + const char* const kModuleStr = R"( + HloModule test, is_scheduled=true + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + replica = u32[] replica-id() + a = u32[] add(c1, replica) + send-data = u32[2] broadcast(a), dimensions={} + + after-all.0 = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_validation="invalid" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0), + channel_id=0, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_validation="invalid" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0 + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + send-done.0 = token[] send-done(send.0), channel_id=0 + + after-all.1 = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + send.1 = (u32[2], u32[], token[]) send(send-data, after-all.1), + channel_id=0, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0 + recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + + compare0 = pred[] compare(replica, c0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + send-done.1 = token[] send-done(send.1), channel_id=0 + + c1b = u32[2] broadcast(c1), dimensions={} + ROOT result = u32[2] add(c1b, recv-data) + })"; + + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/false)); + ASSERT_EQ(results.size(), kNumReplicas); + // Skip checking the result for device 0 as it has garabage value as the + // Recv operation is marked for skipping at runtime. + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({2, 2}), + results[1])); +} + +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr2)) { + const char* const kModuleStr = R"( + HloModule test, is_scheduled=true +cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(2) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + after-all.0 = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_validation="{{0,1}}" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0), + channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{1,0}}", + _xla_send_recv_validation="{{0,1}}" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=0 + recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + send-done.0 = token[] send-done(send.0), channel_id=0 + + after-all.1 = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + send.1 = (u32[2], u32[], token[]) send(send-data, after-all.1), + channel_id=0, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}}" + } + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=0 + recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + send-done.1 = token[] send-done(send.1), channel_id=0 + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + r = u32[] replica-id() + init = u32[2] broadcast(r), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond + ROOT result = u32[2] get-tuple-element(while_result), index=1 + })"; + + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {}, kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/false)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({2, 2}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 3}), + results[1])); +} + } // namespace } // namespace xla diff --git a/xla/tests/collective_ops_test_e2e.cc b/xla/tests/collective_ops_test_e2e.cc index 770e928481e90..55bf673c21783 100644 --- a/xla/tests/collective_ops_test_e2e.cc +++ b/xla/tests/collective_ops_test_e2e.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tests/test_utils.h" namespace xla { namespace { @@ -46,8 +47,8 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: - StatusOr> ExecuteReplicated(Executable* executable, - int64_t num_replicas) { + absl::StatusOr> ExecuteReplicated(Executable* executable, + int64_t num_replicas) { DeviceAssignment device_assignment = MakeDeviceAssn(num_replicas); return HloTestBase::ExecuteReplicated( /*executable_provider*/ [&](int64_t) { return executable; }, @@ -78,7 +79,9 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E, // Enable or disable all async collectives based on test parameter. const bool enable_async = GetParam(); + debug_options.set_xla_gpu_enable_async_collectives(false); debug_options.set_xla_gpu_enable_async_all_reduce(enable_async); + debug_options.set_xla_gpu_enable_async_collective_broadcast(enable_async); debug_options.set_xla_gpu_enable_async_collective_permute(enable_async); debug_options.set_xla_gpu_enable_async_all_gather(enable_async); debug_options.set_xla_gpu_enable_async_reduce_scatter(enable_async); @@ -88,7 +91,7 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E, return debug_options; } - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( absl::string_view hlo_string, int64_t num_replicas) { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/num_replicas); @@ -100,7 +103,10 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E, } bool IsAsync(const HloInstruction* inst) { - return !inst->backend_config()->is_sync(); + return !inst->backend_config() + .value() + .collective_backend_config() + .is_sync(); } const int64_t num_devices_; @@ -224,6 +230,38 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGatherMixedTypes) { } } +XLA_TEST_P(AsyncCollectiveOps, AsyncCollectiveBroadcast) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + bcast = u32[2] collective-broadcast(p), replica_groups={{1, 0}} + ROOT res = copy(bcast) + } + )"; + const int64_t kNumReplicas = 2; + const bool enable_async_collective_broadcast = GetParam(); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(kModuleStr, kNumReplicas)); + EXPECT_TRUE(executable->has_module()); + HloInstruction* cb_start = + FindInstruction(&executable->module(), HloOpcode::kAsyncStart); + HloInstruction* cb_done = + FindInstruction(&executable->module(), HloOpcode::kAsyncDone); + EXPECT_THAT(cb_start, NotNull()); + EXPECT_THAT(cb_done, NotNull()); + EXPECT_EQ(IsAsync(cb_start), enable_async_collective_broadcast); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(executable.get(), kNumReplicas)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({11, 11}, results[0]); + LiteralTestUtil::ExpectR1Equal({11, 11}, results[1]); +} + XLA_TEST_P(AsyncCollectiveOps, AsyncCollectivePermute) { const absl::string_view kModuleStr = R"( HloModule test @@ -378,6 +416,82 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { LiteralTestUtil::ExpectR1Equal({40, 60, 44, 64}, results[1]); } +TEST_P(AsyncCollectiveOps, MatmulReplicated) { + // collective_permute = f32[16,32]{1,0} collective-permute(x_unscaled), + // source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + absl::string_view kModuleReplicatedStr = R"( + HloModule test + + ENTRY test { + x_f32 = f32[16,32] parameter(0) + y_f32 = f32[16,32] parameter(1) + replica_id = u32[] replica-id() + addend = f32[] convert(replica_id) + addend_bcast = f32[16,32] broadcast(addend), dimensions={} + x_add = f32[16,32] add(addend_bcast, x_f32) + ROOT dot_a = f32[16,16] dot(x_add, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={1} + } + )"; + + absl::string_view kModuleSingleStr = R"( + HloModule test + + ENTRY test { + x_f32 = f32[16,32] parameter(0) + y_f32 = f32[16,32] parameter(1) + replica_id = u32[] parameter(2) + addend = f32[] convert(replica_id) + addend_bcast = f32[16,32] broadcast(addend), dimensions={} + x_add = f32[16,32] add(addend_bcast, x_f32) + ROOT dot_a = f32[16,16] dot(x_add, y_f32), lhs_contracting_dims={1}, rhs_contracting_dims={1} + } + )"; + const int64_t kNumReplicas = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + auto opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_enable_cublaslt(GetParam()); + VLOG(0) << "Running with CUBLAS enabled: " << opts.xla_gpu_enable_cublaslt(); + config.set_debug_options(opts); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + DeviceAssignment assn(/*replica_count=*/kNumReplicas, + /*computation_count=*/1); + for (int64_t i = 0; i < kNumReplicas; ++i) { + assn(i, 0) = i; + } + + auto fake_arguments = xla::MakeFakeArguments(module.get()).value(); + std::vector fake_ptrs(fake_arguments.size()); + for (int i = 0; i < fake_arguments.size(); i++) { + fake_ptrs[i] = &fake_arguments[i]; + } + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + HloTestBase::ExecuteReplicated( + std::move(module), fake_ptrs, kNumReplicas, &assn, + true /*run_hlo_passes*/, true /*use-threads*/)); + ASSERT_EQ(results.size(), kNumReplicas); + + auto& ref_runner = HloTestBase::reference_runner_; + TF_ASSERT_OK_AND_ASSIGN( + auto ref_module, ParseAndReturnVerifiedModule(kModuleSingleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true)); + + ErrorSpec error_spec{1e-5, 1e-5}; + fake_ptrs.push_back(nullptr); + for (int i = 0; i < kNumReplicas; i++) { + auto replica_id = + LiteralUtil::CreateFullWithDescendingLayout({}, i); + fake_ptrs.back() = &replica_id; + TF_ASSERT_OK_AND_ASSIGN( + auto res, ref_runner.ExecuteWithExecutable(ref_exec.get(), fake_ptrs)); + EXPECT_TRUE(LiteralTestUtil::Near(res, results[i], error_spec)); + } +} + INSTANTIATE_TEST_SUITE_P(AsyncCollectiveOps, AsyncCollectiveOps, ::testing::Bool()); diff --git a/xla/tests/collective_pipeliner_execution_test.cc b/xla/tests/collective_pipeliner_execution_test.cc index 536e83799204f..403f9b32854e7 100644 --- a/xla/tests/collective_pipeliner_execution_test.cc +++ b/xla/tests/collective_pipeliner_execution_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace { using CollectivePipelinerExecutionTest = HloTestBase; -StatusOr RunOptimizer( +absl::StatusOr RunOptimizer( HloModule* module, bool last_run, int64_t level_to_operate_on = 0, HloPredicate should_process = HloPredicateIsOp, CollectivePipeliner::PipeliningDirection pipelining_direction = diff --git a/xla/tests/compilation_cache_test.cc b/xla/tests/compilation_cache_test.cc deleted file mode 100644 index d947c209a2197..0000000000000 --- a/xla/tests/compilation_cache_test.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/global_data.h" -#include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/literal.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace { - -class CompilationCacheTest : public ClientLibraryTestBase { - public: - void ExecuteComputationR0F32(const XlaComputation& computation, - absl::Span arguments, - float expected_result, bool expect_cache_hit) { - ExecutionProfile execution_profile; - Literal result = - client_ - ->ExecuteAndTransfer(computation, arguments, - /*execution_options=*/&execution_options_, - &execution_profile) - .value(); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR0(expected_result), result, error_spec_)); - EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); - } - - void ExecuteComputationR2F32( - const XlaComputation& computation, - absl::Span arguments, - std::initializer_list> expected_result, - bool expect_cache_hit) { - ExecutionProfile execution_profile; - auto data_handle = client_ - ->Execute(computation, arguments, - &execution_options_, &execution_profile) - .value(); - Literal result = client_->Transfer(*data_handle).value(); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR2(expected_result), result, error_spec_)); - EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); - } - - ErrorSpec error_spec_{0.0001}; -}; - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { - XlaBuilder builder(TestName()); - Neg(ConstantR0(&builder, 42.0)); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, - DISABLED_ComputationCalledWithDifferentParameters) { - std::unique_ptr data_42 = - client_->TransferToServer(LiteralUtil::CreateR0(42.0f)).value(); - std::unique_ptr data_123 = - client_->TransferToServer(LiteralUtil::CreateR0(123.0f)).value(); - std::unique_ptr data_456 = - client_->TransferToServer(LiteralUtil::CreateR0(456.0f)).value(); - - XlaBuilder builder(TestName()); - Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param")); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {data_123.get()}, -123.0, - /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {data_456.get()}, -456.0, - /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, - /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { - XlaBuilder builder_neg(TestName() + "_neg"); - Neg(ConstantR0(&builder_neg, 42.0)); - XlaComputation computation_neg = builder_neg.Build().value(); - - XlaBuilder builder_exp(TestName() + "_exp"); - Exp(ConstantR0(&builder_exp, 1.0)); - XlaComputation computation_exp = builder_exp.Build().value(); - - XlaBuilder builder_add(TestName() + "_add"); - Add(ConstantR0(&builder_add, 2.0), - ConstantR0(&builder_add, 3.0)); - XlaComputation computation_add = builder_add.Build().value(); - - ExecuteComputationR0F32(computation_neg, {}, -42.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_exp, {}, 2.7182817, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_add, {}, 5.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_neg, {}, -42.0, - /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { - // Create two GlobalData arrays with the same shape but different - // layouts. Use these arrays as parameters to a simple computation. If the - // layout of the array changes then computation should be recompiled (cache - // miss). - auto rowmaj_array = LiteralUtil::CreateR2WithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); - auto rowmaj_handle = client_->TransferToServer(rowmaj_array).value(); - - auto colmaj_array = LiteralUtil::CreateR2WithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); - auto colmaj_handle = client_->TransferToServer(colmaj_array).value(); - - XlaBuilder builder(TestName()); - Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/false); - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); - ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/false); - ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); -} - -} // namespace -} // namespace xla diff --git a/xla/tests/complex_unary_op_samples.h b/xla/tests/complex_unary_op_samples.h new file mode 100644 index 0000000000000..edceb7829c44c --- /dev/null +++ b/xla/tests/complex_unary_op_samples.h @@ -0,0 +1,1448 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + This file is generated using xla/tests/generate_complex_unary_op_samples.py. + Do not edit! + */ + +#include +#include +#include +#include +#include + +#ifndef XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +#define XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ + +namespace complex_unary_op_samples { + +template +struct Log1p { + typedef std::complex InputType; + typedef std::complex OutputType; + typedef T FloatType; + using TableType = std::vector>; + static constexpr int dps_deficiency = default_dps_deficiency; + const TableType get() { + const T inf = std::numeric_limits::infinity(); + const T min = std::numeric_limits::min(); + const T max = std::numeric_limits::max(); + if constexpr (std::is_same_v) { + const T pi = 3.1415927f; + const T pi_4 = 0.7853982f; + const T pi_2 = 1.5707964f; + const T pi3_4 = 2.3561945f; + const T zero = 0.0f; + const TableType table{ + /* 0 */ {{-inf, -inf}, {inf, -pi3_4}, 1.e+00f}, + /* 1 */ {{-max, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 2 */ {{-6.14096e+25f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 3 */ {{-1.108238e+13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 4 */ {{-2.e+00f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 5 */ {{-3.609332e-13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 6 */ {{-6.513639e-26f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 7 */ {{-min, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 8 */ {{zero, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 9 */ {{min, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 10 */ {{6.513639e-26f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 11 */ {{3.609332e-13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 12 */ {{2.e+00f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 13 */ {{1.108238e+13f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 14 */ {{6.14096e+25f, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 15 */ {{max, -inf}, {inf, -pi_2}, 1.e+00f}, + /* 16 */ {{inf, -inf}, {inf, -pi_4}, 1.e+00f}, + /* 17 */ {{-inf, -max}, {inf, -pi}, 1.e+00f}, + /* 18 */ {{-max, -max}, {8.906941e+01f, -pi3_4}, 7.8125e-03f}, + /* 19 */ {{-6.14096e+25f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 20 */ + {{-1.108238e+13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 21 */ {{-2.e+00f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 22 */ + {{-3.609332e-13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 23 */ + {{-6.513639e-26f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 24 */ {{-min, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 25 */ {{zero, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 26 */ {{min, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 27 */ {{6.513639e-26f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 28 */ {{3.609332e-13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 29 */ {{2.e+00f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 30 */ {{1.108238e+13f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 31 */ {{6.14096e+25f, -max}, {8.872284e+01f, -pi_2}, 7.8125e-03f}, + /* 32 */ {{max, -max}, {8.906941e+01f, -pi_4}, 7.8125e-03f}, + /* 33 */ {{inf, -max}, {inf, zero}, 1.e+00f}, + /* 34 */ {{-inf, -6.14096e+25f}, {inf, -pi}, 1.e+00f}, + /* 35 */ {{-max, -6.14096e+25f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 36 */ + {{-6.14096e+25f, -6.14096e+25f}, + {5.972618e+01f, -pi3_4}, + 1.5625e-02f}, + /* 37 */ + {{-1.108238e+13f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 38 */ + {{-2.e+00f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 39 */ + {{-3.609332e-13f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 40 */ + {{-6.513639e-26f, -6.14096e+25f}, + {5.937961e+01f, -pi_2}, + 1.5625e-02f}, + /* 41 */ {{-min, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 42 */ {{zero, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 43 */ {{min, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 44 */ + {{6.513639e-26f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 45 */ + {{3.609332e-13f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 46 */ + {{2.e+00f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 47 */ + {{1.108238e+13f, -6.14096e+25f}, {5.937961e+01f, -pi_2}, 1.5625e-02f}, + /* 48 */ + {{6.14096e+25f, -6.14096e+25f}, {5.972618e+01f, -pi_4}, 1.5625e-02f}, + /* 49 */ + {{max, -6.14096e+25f}, {8.872284e+01f, -1.804666e-13f}, 7.8125e-03f}, + /* 50 */ {{inf, -6.14096e+25f}, {inf, zero}, 1.e+00f}, + /* 51 */ {{-inf, -1.108238e+13f}, {inf, -pi}, 1.e+00f}, + /* 52 */ {{-max, -1.108238e+13f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 53 */ + {{-6.14096e+25f, -1.108238e+13f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 54 */ + {{-1.108238e+13f, -1.108238e+13f}, + {3.038295e+01f, -pi3_4}, + 3.125e-02f}, + /* 55 */ + {{-2.e+00f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 56 */ + {{-3.609332e-13f, -1.108238e+13f}, + {3.003638e+01f, -pi_2}, + 3.125e-02f}, + /* 57 */ + {{-6.513639e-26f, -1.108238e+13f}, + {3.003638e+01f, -pi_2}, + 3.125e-02f}, + /* 58 */ {{-min, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 59 */ {{zero, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 60 */ {{min, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 61 */ + {{6.513639e-26f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 62 */ + {{3.609332e-13f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 63 */ + {{2.e+00f, -1.108238e+13f}, {3.003638e+01f, -pi_2}, 3.125e-02f}, + /* 64 */ + {{1.108238e+13f, -1.108238e+13f}, {3.038295e+01f, -pi_4}, 3.125e-02f}, + /* 65 */ + {{6.14096e+25f, -1.108238e+13f}, + {5.937961e+01f, -1.804666e-13f}, + 1.5625e-02f}, + /* 66 */ + {{max, -1.108238e+13f}, {8.872284e+01f, -3.25682e-26f}, 7.8125e-03f}, + /* 67 */ {{inf, -1.108238e+13f}, {inf, zero}, 1.e+00f}, + /* 68 */ {{-inf, -2.e+00f}, {inf, -pi}, 1.e+00f}, + /* 69 */ {{-max, -2.e+00f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 70 */ + {{-6.14096e+25f, -2.e+00f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 71 */ + {{-1.108238e+13f, -2.e+00f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 72 */ + {{-2.e+00f, -2.e+00f}, {8.04719e-01f, -2.034444e+00f}, 2.5e-01f}, + /* 73 */ + {{-3.609332e-13f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 74 */ + {{-6.513639e-26f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 75 */ {{-min, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 76 */ {{zero, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 77 */ {{min, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 78 */ + {{6.513639e-26f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 79 */ + {{3.609332e-13f, -2.e+00f}, {8.04719e-01f, -1.107149e+00f}, 5.e-01f}, + /* 80 */ + {{2.e+00f, -2.e+00f}, {1.282475e+00f, -5.880026e-01f}, 5.e-01f}, + /* 81 */ + {{1.108238e+13f, -2.e+00f}, + {3.003638e+01f, -1.804666e-13f}, + 3.125e-02f}, + /* 82 */ + {{6.14096e+25f, -2.e+00f}, + {5.937961e+01f, -3.25682e-26f}, + 1.5625e-02f}, + /* 83 */ + {{max, -2.e+00f}, {8.872284e+01f, -5.877472e-39f}, 7.8125e-03f}, + /* 84 */ {{inf, -2.e+00f}, {inf, zero}, 1.e+00f}, + /* 85 */ {{-inf, -3.609332e-13f}, {inf, -pi}, 1.e+00f}, + /* 86 */ {{-max, -3.609332e-13f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 87 */ + {{-6.14096e+25f, -3.609332e-13f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 88 */ + {{-1.108238e+13f, -3.609332e-13f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 89 */ {{-2.e+00f, -3.609332e-13f}, {6.513639e-26f, -pi}, 2.5e-01f}, + /* 90 */ + {{-3.609332e-13f, -3.609332e-13f}, + {-3.609332e-13f, -3.609332e-13f}, + 1.099512e+12f}, + /* 91 */ + {{-6.513639e-26f, -3.609332e-13f}, + {-2.843711e-33f, -3.609332e-13f}, + 2.199023e+12f}, + /* 92 */ + {{-min, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 93 */ + {{zero, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 94 */ + {{min, -3.609332e-13f}, + {6.513639e-26f, -3.609332e-13f}, + 2.199023e+12f}, + /* 95 */ + {{6.513639e-26f, -3.609332e-13f}, + {1.302728e-25f, -3.609332e-13f}, + 2.199023e+12f}, + /* 96 */ + {{3.609332e-13f, -3.609332e-13f}, + {3.609332e-13f, -3.609332e-13f}, + 1.099512e+12f}, + /* 97 */ + {{2.e+00f, -3.609332e-13f}, {1.098612e+00f, -1.203111e-13f}, 5.e-01f}, + /* 98 */ + {{1.108238e+13f, -3.609332e-13f}, + {3.003638e+01f, -3.256819e-26f}, + 3.125e-02f}, + /* 99 */ + {{6.14096e+25f, -3.609332e-13f}, + {5.937961e+01f, -5.877472e-39f}, + 1.5625e-02f}, + /* 100 */ {{max, -3.609332e-13f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 101 */ {{inf, -3.609332e-13f}, {inf, zero}, 1.e+00f}, + /* 102 */ {{-inf, -6.513639e-26f}, {inf, -pi}, 1.e+00f}, + /* 103 */ {{-max, -6.513639e-26f}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 104 */ + {{-6.14096e+25f, -6.513639e-26f}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 105 */ + {{-1.108238e+13f, -6.513639e-26f}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 106 */ {{-2.e+00f, -6.513639e-26f}, {zero, -pi}, 2.5e-01f}, + /* 107 */ + {{-3.609332e-13f, -6.513639e-26f}, + {-3.609332e-13f, -6.513639e-26f}, + 2.199023e+12f}, + /* 108 */ + {{-6.513639e-26f, -6.513639e-26f}, + {-6.513639e-26f, -6.513639e-26f}, + 9.671407e+24f}, + /* 109 */ + {{-min, -6.513639e-26f}, {-min, -6.513639e-26f}, 9.671407e+24f}, + /* 110 */ + {{zero, -6.513639e-26f}, {zero, -6.513639e-26f}, 9.671407e+24f}, + /* 111 */ + {{min, -6.513639e-26f}, {min, -6.513639e-26f}, 9.671407e+24f}, + /* 112 */ + {{6.513639e-26f, -6.513639e-26f}, + {6.513639e-26f, -6.513639e-26f}, + 9.671407e+24f}, + /* 113 */ + {{3.609332e-13f, -6.513639e-26f}, + {3.609332e-13f, -6.513639e-26f}, + 2.199023e+12f}, + /* 114 */ + {{2.e+00f, -6.513639e-26f}, {1.098612e+00f, -2.171213e-26f}, 5.e-01f}, + /* 115 */ + {{1.108238e+13f, -6.513639e-26f}, + {3.003638e+01f, -5.877472e-39f}, + 3.125e-02f}, + /* 116 */ + {{6.14096e+25f, -6.513639e-26f}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 117 */ {{max, -6.513639e-26f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 118 */ {{inf, -6.513639e-26f}, {inf, zero}, 1.e+00f}, + /* 119 */ {{-inf, -min}, {inf, -pi}, 1.e+00f}, + /* 120 */ {{-max, -min}, {8.872284e+01f, -pi}, 7.8125e-03f}, + /* 121 */ {{-6.14096e+25f, -min}, {5.937961e+01f, -pi}, 1.5625e-02f}, + /* 122 */ {{-1.108238e+13f, -min}, {3.003638e+01f, -pi}, 3.125e-02f}, + /* 123 */ {{-2.e+00f, -min}, {zero, -pi}, 2.5e-01f}, + /* 124 */ + {{-3.609332e-13f, -min}, {-3.609332e-13f, -min}, 2.199023e+12f}, + /* 125 */ + {{-6.513639e-26f, -min}, {-6.513639e-26f, -min}, 9.671407e+24f}, + /* 126 */ {{-min, -min}, {-min, -min}, 4.25353e+37f}, + /* 127 */ {{zero, -min}, {zero, -min}, 4.25353e+37f}, + /* 128 */ {{min, -min}, {min, -min}, 4.25353e+37f}, + /* 129 */ + {{6.513639e-26f, -min}, {6.513639e-26f, -min}, 9.671407e+24f}, + /* 130 */ + {{3.609332e-13f, -min}, {3.609332e-13f, -min}, 2.199023e+12f}, + /* 131 */ {{2.e+00f, -min}, {1.098612e+00f, zero}, 5.e-01f}, + /* 132 */ {{1.108238e+13f, -min}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 133 */ {{6.14096e+25f, -min}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 134 */ {{max, -min}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 135 */ {{inf, -min}, {inf, zero}, 1.e+00f}, + /* 136 */ {{-inf, zero}, {inf, pi}, 1.e+00f}, + /* 137 */ {{-max, zero}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 138 */ {{-6.14096e+25f, zero}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 139 */ {{-1.108238e+13f, zero}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 140 */ {{-2.e+00f, zero}, {zero, pi}, 2.5e-01f}, + /* 141 */ + {{-3.609332e-13f, zero}, {-3.609332e-13f, zero}, 2.199023e+12f}, + /* 142 */ + {{-6.513639e-26f, zero}, {-6.513639e-26f, zero}, 9.671407e+24f}, + /* 143 */ {{-min, zero}, {-min, zero}, 4.25353e+37f}, + /* 144 */ {{zero, zero}, {zero, zero}, 1.e+00f}, + /* 145 */ {{min, zero}, {min, zero}, 4.25353e+37f}, + /* 146 */ + {{6.513639e-26f, zero}, {6.513639e-26f, zero}, 9.671407e+24f}, + /* 147 */ + {{3.609332e-13f, zero}, {3.609332e-13f, zero}, 2.199023e+12f}, + /* 148 */ {{2.e+00f, zero}, {1.098612e+00f, zero}, 5.e-01f}, + /* 149 */ {{1.108238e+13f, zero}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 150 */ {{6.14096e+25f, zero}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 151 */ {{max, zero}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 152 */ {{inf, zero}, {inf, zero}, 1.e+00f}, + /* 153 */ {{-inf, min}, {inf, pi}, 1.e+00f}, + /* 154 */ {{-max, min}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 155 */ {{-6.14096e+25f, min}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 156 */ {{-1.108238e+13f, min}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 157 */ {{-2.e+00f, min}, {zero, pi}, 2.5e-01f}, + /* 158 */ + {{-3.609332e-13f, min}, {-3.609332e-13f, min}, 2.199023e+12f}, + /* 159 */ + {{-6.513639e-26f, min}, {-6.513639e-26f, min}, 9.671407e+24f}, + /* 160 */ {{-min, min}, {-min, min}, 4.25353e+37f}, + /* 161 */ {{zero, min}, {zero, min}, 4.25353e+37f}, + /* 162 */ {{min, min}, {min, min}, 4.25353e+37f}, + /* 163 */ {{6.513639e-26f, min}, {6.513639e-26f, min}, 9.671407e+24f}, + /* 164 */ {{3.609332e-13f, min}, {3.609332e-13f, min}, 2.199023e+12f}, + /* 165 */ {{2.e+00f, min}, {1.098612e+00f, zero}, 5.e-01f}, + /* 166 */ {{1.108238e+13f, min}, {3.003638e+01f, zero}, 3.125e-02f}, + /* 167 */ {{6.14096e+25f, min}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 168 */ {{max, min}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 169 */ {{inf, min}, {inf, zero}, 1.e+00f}, + /* 170 */ {{-inf, 6.513639e-26f}, {inf, pi}, 1.e+00f}, + /* 171 */ {{-max, 6.513639e-26f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 172 */ + {{-6.14096e+25f, 6.513639e-26f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 173 */ + {{-1.108238e+13f, 6.513639e-26f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 174 */ {{-2.e+00f, 6.513639e-26f}, {zero, pi}, 2.5e-01f}, + /* 175 */ + {{-3.609332e-13f, 6.513639e-26f}, + {-3.609332e-13f, 6.513639e-26f}, + 2.199023e+12f}, + /* 176 */ + {{-6.513639e-26f, 6.513639e-26f}, + {-6.513639e-26f, 6.513639e-26f}, + 9.671407e+24f}, + /* 177 */ + {{-min, 6.513639e-26f}, {-min, 6.513639e-26f}, 9.671407e+24f}, + /* 178 */ + {{zero, 6.513639e-26f}, {zero, 6.513639e-26f}, 9.671407e+24f}, + /* 179 */ {{min, 6.513639e-26f}, {min, 6.513639e-26f}, 9.671407e+24f}, + /* 180 */ + {{6.513639e-26f, 6.513639e-26f}, + {6.513639e-26f, 6.513639e-26f}, + 9.671407e+24f}, + /* 181 */ + {{3.609332e-13f, 6.513639e-26f}, + {3.609332e-13f, 6.513639e-26f}, + 2.199023e+12f}, + /* 182 */ + {{2.e+00f, 6.513639e-26f}, {1.098612e+00f, 2.171213e-26f}, 5.e-01f}, + /* 183 */ + {{1.108238e+13f, 6.513639e-26f}, + {3.003638e+01f, 5.877472e-39f}, + 3.125e-02f}, + /* 184 */ + {{6.14096e+25f, 6.513639e-26f}, {5.937961e+01f, zero}, 1.5625e-02f}, + /* 185 */ {{max, 6.513639e-26f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 186 */ {{inf, 6.513639e-26f}, {inf, zero}, 1.e+00f}, + /* 187 */ {{-inf, 3.609332e-13f}, {inf, pi}, 1.e+00f}, + /* 188 */ {{-max, 3.609332e-13f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 189 */ + {{-6.14096e+25f, 3.609332e-13f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 190 */ + {{-1.108238e+13f, 3.609332e-13f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 191 */ {{-2.e+00f, 3.609332e-13f}, {6.513639e-26f, pi}, 2.5e-01f}, + /* 192 */ + {{-3.609332e-13f, 3.609332e-13f}, + {-3.609332e-13f, 3.609332e-13f}, + 1.099512e+12f}, + /* 193 */ + {{-6.513639e-26f, 3.609332e-13f}, + {-2.843711e-33f, 3.609332e-13f}, + 2.199023e+12f}, + /* 194 */ + {{-min, 3.609332e-13f}, + {6.513639e-26f, 3.609332e-13f}, + 2.199023e+12f}, + /* 195 */ + {{zero, 3.609332e-13f}, + {6.513639e-26f, 3.609332e-13f}, + 2.199023e+12f}, + /* 196 */ + {{min, 3.609332e-13f}, {6.513639e-26f, 3.609332e-13f}, 2.199023e+12f}, + /* 197 */ + {{6.513639e-26f, 3.609332e-13f}, + {1.302728e-25f, 3.609332e-13f}, + 2.199023e+12f}, + /* 198 */ + {{3.609332e-13f, 3.609332e-13f}, + {3.609332e-13f, 3.609332e-13f}, + 1.099512e+12f}, + /* 199 */ + {{2.e+00f, 3.609332e-13f}, {1.098612e+00f, 1.203111e-13f}, 5.e-01f}, + /* 200 */ + {{1.108238e+13f, 3.609332e-13f}, + {3.003638e+01f, 3.256819e-26f}, + 3.125e-02f}, + /* 201 */ + {{6.14096e+25f, 3.609332e-13f}, + {5.937961e+01f, 5.877472e-39f}, + 1.5625e-02f}, + /* 202 */ {{max, 3.609332e-13f}, {8.872284e+01f, zero}, 7.8125e-03f}, + /* 203 */ {{inf, 3.609332e-13f}, {inf, zero}, 1.e+00f}, + /* 204 */ {{-inf, 2.e+00f}, {inf, pi}, 1.e+00f}, + /* 205 */ {{-max, 2.e+00f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 206 */ + {{-6.14096e+25f, 2.e+00f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 207 */ + {{-1.108238e+13f, 2.e+00f}, {3.003638e+01f, pi}, 3.125e-02f}, + /* 208 */ + {{-2.e+00f, 2.e+00f}, {8.04719e-01f, 2.034444e+00f}, 2.5e-01f}, + /* 209 */ + {{-3.609332e-13f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 210 */ + {{-6.513639e-26f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 211 */ {{-min, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 212 */ {{zero, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 213 */ {{min, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 214 */ + {{6.513639e-26f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 215 */ + {{3.609332e-13f, 2.e+00f}, {8.04719e-01f, 1.107149e+00f}, 5.e-01f}, + /* 216 */ + {{2.e+00f, 2.e+00f}, {1.282475e+00f, 5.880026e-01f}, 5.e-01f}, + /* 217 */ + {{1.108238e+13f, 2.e+00f}, + {3.003638e+01f, 1.804666e-13f}, + 3.125e-02f}, + /* 218 */ + {{6.14096e+25f, 2.e+00f}, {5.937961e+01f, 3.25682e-26f}, 1.5625e-02f}, + /* 219 */ + {{max, 2.e+00f}, {8.872284e+01f, 5.877472e-39f}, 7.8125e-03f}, + /* 220 */ {{inf, 2.e+00f}, {inf, zero}, 1.e+00f}, + /* 221 */ {{-inf, 1.108238e+13f}, {inf, pi}, 1.e+00f}, + /* 222 */ {{-max, 1.108238e+13f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 223 */ + {{-6.14096e+25f, 1.108238e+13f}, {5.937961e+01f, pi}, 1.5625e-02f}, + /* 224 */ + {{-1.108238e+13f, 1.108238e+13f}, {3.038295e+01f, pi3_4}, 3.125e-02f}, + /* 225 */ + {{-2.e+00f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 226 */ + {{-3.609332e-13f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 227 */ + {{-6.513639e-26f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 228 */ {{-min, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 229 */ {{zero, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 230 */ {{min, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 231 */ + {{6.513639e-26f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 232 */ + {{3.609332e-13f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 233 */ + {{2.e+00f, 1.108238e+13f}, {3.003638e+01f, pi_2}, 3.125e-02f}, + /* 234 */ + {{1.108238e+13f, 1.108238e+13f}, {3.038295e+01f, pi_4}, 3.125e-02f}, + /* 235 */ + {{6.14096e+25f, 1.108238e+13f}, + {5.937961e+01f, 1.804666e-13f}, + 1.5625e-02f}, + /* 236 */ + {{max, 1.108238e+13f}, {8.872284e+01f, 3.25682e-26f}, 7.8125e-03f}, + /* 237 */ {{inf, 1.108238e+13f}, {inf, zero}, 1.e+00f}, + /* 238 */ {{-inf, 6.14096e+25f}, {inf, pi}, 1.e+00f}, + /* 239 */ {{-max, 6.14096e+25f}, {8.872284e+01f, pi}, 7.8125e-03f}, + /* 240 */ + {{-6.14096e+25f, 6.14096e+25f}, {5.972618e+01f, pi3_4}, 1.5625e-02f}, + /* 241 */ + {{-1.108238e+13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 242 */ + {{-2.e+00f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 243 */ + {{-3.609332e-13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 244 */ + {{-6.513639e-26f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 245 */ {{-min, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 246 */ {{zero, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 247 */ {{min, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 248 */ + {{6.513639e-26f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 249 */ + {{3.609332e-13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 250 */ + {{2.e+00f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 251 */ + {{1.108238e+13f, 6.14096e+25f}, {5.937961e+01f, pi_2}, 1.5625e-02f}, + /* 252 */ + {{6.14096e+25f, 6.14096e+25f}, {5.972618e+01f, pi_4}, 1.5625e-02f}, + /* 253 */ + {{max, 6.14096e+25f}, {8.872284e+01f, 1.804666e-13f}, 7.8125e-03f}, + /* 254 */ {{inf, 6.14096e+25f}, {inf, zero}, 1.e+00f}, + /* 255 */ {{-inf, max}, {inf, pi}, 1.e+00f}, + /* 256 */ {{-max, max}, {8.906941e+01f, pi3_4}, 7.8125e-03f}, + /* 257 */ {{-6.14096e+25f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 258 */ {{-1.108238e+13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 259 */ {{-2.e+00f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 260 */ {{-3.609332e-13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 261 */ {{-6.513639e-26f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 262 */ {{-min, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 263 */ {{zero, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 264 */ {{min, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 265 */ {{6.513639e-26f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 266 */ {{3.609332e-13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 267 */ {{2.e+00f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 268 */ {{1.108238e+13f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 269 */ {{6.14096e+25f, max}, {8.872284e+01f, pi_2}, 7.8125e-03f}, + /* 270 */ {{max, max}, {8.906941e+01f, pi_4}, 7.8125e-03f}, + /* 271 */ {{inf, max}, {inf, zero}, 1.e+00f}, + /* 272 */ {{-inf, inf}, {inf, pi3_4}, 1.e+00f}, + /* 273 */ {{-max, inf}, {inf, pi_2}, 1.e+00f}, + /* 274 */ {{-6.14096e+25f, inf}, {inf, pi_2}, 1.e+00f}, + /* 275 */ {{-1.108238e+13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 276 */ {{-2.e+00f, inf}, {inf, pi_2}, 1.e+00f}, + /* 277 */ {{-3.609332e-13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 278 */ {{-6.513639e-26f, inf}, {inf, pi_2}, 1.e+00f}, + /* 279 */ {{-min, inf}, {inf, pi_2}, 1.e+00f}, + /* 280 */ {{zero, inf}, {inf, pi_2}, 1.e+00f}, + /* 281 */ {{min, inf}, {inf, pi_2}, 1.e+00f}, + /* 282 */ {{6.513639e-26f, inf}, {inf, pi_2}, 1.e+00f}, + /* 283 */ {{3.609332e-13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 284 */ {{2.e+00f, inf}, {inf, pi_2}, 1.e+00f}, + /* 285 */ {{1.108238e+13f, inf}, {inf, pi_2}, 1.e+00f}, + /* 286 */ {{6.14096e+25f, inf}, {inf, pi_2}, 1.e+00f}, + /* 287 */ {{max, inf}, {inf, pi_2}, 1.e+00f}, + /* 288 */ {{inf, inf}, {inf, pi_4}, 1.e+00f}}; + return table; + } else if constexpr (std::is_same_v) { + const T pi = 3.141592653589793; + const T pi_4 = 0.7853981633974483; + const T pi_2 = 1.5707963267948966; + const T pi3_4 = 2.356194490192345; + const T zero = 0.0; + const TableType table{ + /* 0 */ {{-inf, -inf}, {inf, -pi3_4}, 1.e+00}, + /* 1 */ {{-max, -inf}, {inf, -pi_2}, 1.e+00}, + /* 2 */ {{-4.013165208090075e+205, -inf}, {inf, -pi_2}, 1.e+00}, + /* 3 */ {{-8.958978968710456e+102, -inf}, {inf, -pi_2}, 1.e+00}, + /* 4 */ {{-1.999999999999869e+00, -inf}, {inf, -pi_2}, 1.e+00}, + /* 5 */ {{-4.464794497196183e-103, -inf}, {inf, -pi_2}, 1.e+00}, + /* 6 */ {{-9.967194951097309e-206, -inf}, {inf, -pi_2}, 1.e+00}, + /* 7 */ {{-min, -inf}, {inf, -pi_2}, 1.e+00}, + /* 8 */ {{zero, -inf}, {inf, -pi_2}, 1.e+00}, + /* 9 */ {{min, -inf}, {inf, -pi_2}, 1.e+00}, + /* 10 */ {{9.967194951097309e-206, -inf}, {inf, -pi_2}, 1.e+00}, + /* 11 */ {{4.464794497196183e-103, -inf}, {inf, -pi_2}, 1.e+00}, + /* 12 */ {{1.999999999999869e+00, -inf}, {inf, -pi_2}, 1.e+00}, + /* 13 */ {{8.958978968710456e+102, -inf}, {inf, -pi_2}, 1.e+00}, + /* 14 */ {{4.013165208090075e+205, -inf}, {inf, -pi_2}, 1.e+00}, + /* 15 */ {{max, -inf}, {inf, -pi_2}, 1.e+00}, + /* 16 */ {{inf, -inf}, {inf, -pi_4}, 1.e+00}, + /* 17 */ {{-inf, -max}, {inf, -pi}, 1.e+00}, + /* 18 */ + {{-max, -max}, {7.101292864836639e+02, -pi3_4}, 9.765625e-04}, + /* 19 */ + {{-4.013165208090075e+205, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 20 */ + {{-8.958978968710456e+102, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 21 */ + {{-1.999999999999869e+00, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 22 */ + {{-4.464794497196183e-103, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 23 */ + {{-9.967194951097309e-206, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 24 */ {{-min, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 25 */ {{zero, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 26 */ {{min, -max}, {7.09782712893384e+02, -pi_2}, 9.765625e-04}, + /* 27 */ + {{9.967194951097309e-206, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 28 */ + {{4.464794497196183e-103, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 29 */ + {{1.999999999999869e+00, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 30 */ + {{8.958978968710456e+102, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 31 */ + {{4.013165208090075e+205, -max}, + {7.09782712893384e+02, -pi_2}, + 9.765625e-04}, + /* 32 */ {{max, -max}, {7.101292864836639e+02, -pi_4}, 9.765625e-04}, + /* 33 */ {{inf, -max}, {inf, zero}, 1.e+00}, + /* 34 */ {{-inf, -4.013165208090075e+205}, {inf, -pi}, 1.e+00}, + /* 35 */ + {{-max, -4.013165208090075e+205}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 36 */ + {{-4.013165208090075e+205, -4.013165208090075e+205}, + {4.737660979127225e+02, -pi3_4}, + 1.953125e-03}, + /* 37 */ + {{-8.958978968710456e+102, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 38 */ + {{-1.999999999999869e+00, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 39 */ + {{-4.464794497196183e-103, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 40 */ + {{-9.967194951097309e-206, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 41 */ + {{-min, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 42 */ + {{zero, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 43 */ + {{min, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 44 */ + {{9.967194951097309e-206, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 45 */ + {{4.464794497196183e-103, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 46 */ + {{1.999999999999869e+00, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 47 */ + {{8.958978968710456e+102, -4.013165208090075e+205}, + {4.734195243224426e+02, -pi_2}, + 1.953125e-03}, + /* 48 */ + {{4.013165208090075e+205, -4.013165208090075e+205}, + {4.737660979127225e+02, -pi_4}, + 1.953125e-03}, + /* 49 */ + {{max, -4.013165208090075e+205}, + {7.09782712893384e+02, -2.23239724859796e-103}, + 9.765625e-04}, + /* 50 */ {{inf, -4.013165208090075e+205}, {inf, zero}, 1.e+00}, + /* 51 */ {{-inf, -8.958978968710456e+102}, {inf, -pi}, 1.e+00}, + /* 52 */ + {{-max, -8.958978968710456e+102}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 53 */ + {{-4.013165208090075e+205, -8.958978968710456e+102}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 54 */ + {{-8.958978968710456e+102, -8.958978968710456e+102}, + {2.374029093417812e+02, -pi3_4}, + 3.90625e-03}, + /* 55 */ + {{-1.999999999999869e+00, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 56 */ + {{-4.464794497196183e-103, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 57 */ + {{-9.967194951097309e-206, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 58 */ + {{-min, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 59 */ + {{zero, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 60 */ + {{min, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 61 */ + {{9.967194951097309e-206, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 62 */ + {{4.464794497196183e-103, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 63 */ + {{1.999999999999869e+00, -8.958978968710456e+102}, + {2.370563357515012e+02, -pi_2}, + 3.90625e-03}, + /* 64 */ + {{8.958978968710456e+102, -8.958978968710456e+102}, + {2.374029093417812e+02, -pi_4}, + 3.90625e-03}, + /* 65 */ + {{4.013165208090075e+205, -8.958978968710456e+102}, + {4.734195243224426e+02, -2.232397248598237e-103}, + 1.953125e-03}, + /* 66 */ + {{max, -8.958978968710456e+102}, + {7.09782712893384e+02, -4.983597475548361e-206}, + 9.765625e-04}, + /* 67 */ {{inf, -8.958978968710456e+102}, {inf, zero}, 1.e+00}, + /* 68 */ {{-inf, -1.999999999999869e+00}, {inf, -pi}, 1.e+00}, + /* 69 */ + {{-max, -1.999999999999869e+00}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 70 */ + {{-4.013165208090075e+205, -1.999999999999869e+00}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 71 */ + {{-8.958978968710456e+102, -1.999999999999869e+00}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 72 */ + {{-1.999999999999869e+00, -1.999999999999869e+00}, + {8.047189562169719e-01, -2.034443935795677e+00}, + 2.5e-01}, + /* 73 */ + {{-4.464794497196183e-103, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 74 */ + {{-9.967194951097309e-206, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 75 */ + {{-min, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 76 */ + {{zero, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 77 */ + {{min, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 78 */ + {{9.967194951097309e-206, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 79 */ + {{4.464794497196183e-103, -1.999999999999869e+00}, + {8.04718956216998e-01, -1.107148717794064e+00}, + 5.e-01}, + /* 80 */ + {{1.999999999999869e+00, -1.999999999999869e+00}, + {1.282474678730718e+00, -5.880026035475575e-01}, + 5.e-01}, + /* 81 */ + {{8.958978968710456e+102, -1.999999999999869e+00}, + {2.370563357515012e+02, -2.232397248598237e-103}, + 3.90625e-03}, + /* 82 */ + {{4.013165208090075e+205, -1.999999999999869e+00}, + {4.734195243224426e+02, -4.98359747554898e-206}, + 1.953125e-03}, + /* 83 */ + {{max, -1.999999999999869e+00}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 84 */ {{inf, -1.999999999999869e+00}, {inf, zero}, 1.e+00}, + /* 85 */ {{-inf, -4.464794497196183e-103}, {inf, -pi}, 1.e+00}, + /* 86 */ + {{-max, -4.464794497196183e-103}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 87 */ + {{-4.013165208090075e+205, -4.464794497196183e-103}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 88 */ + {{-8.958978968710456e+102, -4.464794497196183e-103}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 89 */ + {{-1.999999999999869e+00, -4.464794497196183e-103}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 90 */ + {{-4.464794497196183e-103, -4.464794497196183e-103}, + {-4.464794497196183e-103, -4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 91 */ + {{-9.967194951097309e-206, -4.464794497196183e-103}, + {-6.506695883473837e-219, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 92 */ + {{-min, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 93 */ + {{zero, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 94 */ + {{min, -4.464794497196183e-103}, + {9.967194951096658e-206, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 95 */ + {{9.967194951097309e-206, -4.464794497196183e-103}, + {1.993438990219397e-205, -4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 96 */ + {{4.464794497196183e-103, -4.464794497196183e-103}, + {4.464794497196183e-103, -4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 97 */ + {{1.999999999999869e+00, -4.464794497196183e-103}, + {1.098612288668066e+00, -1.488264832398792e-103}, + 5.e-01}, + /* 98 */ + {{8.958978968710456e+102, -4.464794497196183e-103}, + {2.370563357515012e+02, -4.98359747554898e-206}, + 3.90625e-03}, + /* 99 */ + {{4.013165208090075e+205, -4.464794497196183e-103}, + {4.734195243224426e+02, -1.112536929253666e-308}, + 1.953125e-03}, + /* 100 */ + {{max, -4.464794497196183e-103}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 101 */ {{inf, -4.464794497196183e-103}, {inf, zero}, 1.e+00}, + /* 102 */ {{-inf, -9.967194951097309e-206}, {inf, -pi}, 1.e+00}, + /* 103 */ + {{-max, -9.967194951097309e-206}, + {7.09782712893384e+02, -pi}, + 9.765625e-04}, + /* 104 */ + {{-4.013165208090075e+205, -9.967194951097309e-206}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 105 */ + {{-8.958978968710456e+102, -9.967194951097309e-206}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 106 */ + {{-1.999999999999869e+00, -9.967194951097309e-206}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 107 */ + {{-4.464794497196183e-103, -9.967194951097309e-206}, + {-4.464794497196183e-103, -9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 108 */ + {{-9.967194951097309e-206, -9.967194951097309e-206}, + {-9.967194951097309e-206, -9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 109 */ + {{-min, -9.967194951097309e-206}, + {-min, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 110 */ + {{zero, -9.967194951097309e-206}, + {zero, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 111 */ + {{min, -9.967194951097309e-206}, + {min, -9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 112 */ + {{9.967194951097309e-206, -9.967194951097309e-206}, + {9.967194951097309e-206, -9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 113 */ + {{4.464794497196183e-103, -9.967194951097309e-206}, + {4.464794497196183e-103, -9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 114 */ + {{1.999999999999869e+00, -9.967194951097309e-206}, + {1.098612288668066e+00, -3.322398317032581e-206}, + 5.e-01}, + /* 115 */ + {{8.958978968710456e+102, -9.967194951097309e-206}, + {2.370563357515012e+02, -1.112536929253666e-308}, + 3.90625e-03}, + /* 116 */ + {{4.013165208090075e+205, -9.967194951097309e-206}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 117 */ + {{max, -9.967194951097309e-206}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 118 */ {{inf, -9.967194951097309e-206}, {inf, zero}, 1.e+00}, + /* 119 */ {{-inf, -min}, {inf, -pi}, 1.e+00}, + /* 120 */ {{-max, -min}, {7.09782712893384e+02, -pi}, 9.765625e-04}, + /* 121 */ + {{-4.013165208090075e+205, -min}, + {4.734195243224426e+02, -pi}, + 1.953125e-03}, + /* 122 */ + {{-8.958978968710456e+102, -min}, + {2.370563357515012e+02, -pi}, + 3.90625e-03}, + /* 123 */ + {{-1.999999999999869e+00, -min}, + {-1.305622276959269e-13, -pi}, + 2.5e-01}, + /* 124 */ + {{-4.464794497196183e-103, -min}, + {-4.464794497196183e-103, -min}, + 2.239744742177804e+102}, + /* 125 */ + {{-9.967194951097309e-206, -min}, + {-9.967194951097309e-206, -min}, + 1.003291302022624e+205}, + /* 126 */ {{-min, -min}, {-min, -min}, 2.247116418577895e+307}, + /* 127 */ {{zero, -min}, {zero, -min}, 2.247116418577895e+307}, + /* 128 */ {{min, -min}, {min, -min}, 2.247116418577895e+307}, + /* 129 */ + {{9.967194951097309e-206, -min}, + {9.967194951097309e-206, -min}, + 1.003291302022624e+205}, + /* 130 */ + {{4.464794497196183e-103, -min}, + {4.464794497196183e-103, -min}, + 2.239744742177804e+102}, + /* 131 */ + {{1.999999999999869e+00, -min}, + {1.098612288668066e+00, zero}, + 5.e-01}, + /* 132 */ + {{8.958978968710456e+102, -min}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 133 */ + {{4.013165208090075e+205, -min}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 134 */ {{max, -min}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 135 */ {{inf, -min}, {inf, zero}, 1.e+00}, + /* 136 */ {{-inf, zero}, {inf, pi}, 1.e+00}, + /* 137 */ {{-max, zero}, {7.09782712893384e+02, pi}, 9.765625e-04}, + /* 138 */ + {{-4.013165208090075e+205, zero}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 139 */ + {{-8.958978968710456e+102, zero}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 140 */ + {{-1.999999999999869e+00, zero}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 141 */ + {{-4.464794497196183e-103, zero}, + {-4.464794497196183e-103, zero}, + 2.239744742177804e+102}, + /* 142 */ + {{-9.967194951097309e-206, zero}, + {-9.967194951097309e-206, zero}, + 1.003291302022624e+205}, + /* 143 */ {{-min, zero}, {-min, zero}, 2.247116418577895e+307}, + /* 144 */ {{zero, zero}, {zero, zero}, 1.e+00}, + /* 145 */ {{min, zero}, {min, zero}, 2.247116418577895e+307}, + /* 146 */ + {{9.967194951097309e-206, zero}, + {9.967194951097309e-206, zero}, + 1.003291302022624e+205}, + /* 147 */ + {{4.464794497196183e-103, zero}, + {4.464794497196183e-103, zero}, + 2.239744742177804e+102}, + /* 148 */ + {{1.999999999999869e+00, zero}, + {1.098612288668066e+00, zero}, + 5.e-01}, + /* 149 */ + {{8.958978968710456e+102, zero}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 150 */ + {{4.013165208090075e+205, zero}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 151 */ {{max, zero}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 152 */ {{inf, zero}, {inf, zero}, 1.e+00}, + /* 153 */ {{-inf, min}, {inf, pi}, 1.e+00}, + /* 154 */ {{-max, min}, {7.09782712893384e+02, pi}, 9.765625e-04}, + /* 155 */ + {{-4.013165208090075e+205, min}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 156 */ + {{-8.958978968710456e+102, min}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 157 */ + {{-1.999999999999869e+00, min}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 158 */ + {{-4.464794497196183e-103, min}, + {-4.464794497196183e-103, min}, + 2.239744742177804e+102}, + /* 159 */ + {{-9.967194951097309e-206, min}, + {-9.967194951097309e-206, min}, + 1.003291302022624e+205}, + /* 160 */ {{-min, min}, {-min, min}, 2.247116418577895e+307}, + /* 161 */ {{zero, min}, {zero, min}, 2.247116418577895e+307}, + /* 162 */ {{min, min}, {min, min}, 2.247116418577895e+307}, + /* 163 */ + {{9.967194951097309e-206, min}, + {9.967194951097309e-206, min}, + 1.003291302022624e+205}, + /* 164 */ + {{4.464794497196183e-103, min}, + {4.464794497196183e-103, min}, + 2.239744742177804e+102}, + /* 165 */ + {{1.999999999999869e+00, min}, {1.098612288668066e+00, zero}, 5.e-01}, + /* 166 */ + {{8.958978968710456e+102, min}, + {2.370563357515012e+02, zero}, + 3.90625e-03}, + /* 167 */ + {{4.013165208090075e+205, min}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 168 */ {{max, min}, {7.09782712893384e+02, zero}, 9.765625e-04}, + /* 169 */ {{inf, min}, {inf, zero}, 1.e+00}, + /* 170 */ {{-inf, 9.967194951097309e-206}, {inf, pi}, 1.e+00}, + /* 171 */ + {{-max, 9.967194951097309e-206}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 172 */ + {{-4.013165208090075e+205, 9.967194951097309e-206}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 173 */ + {{-8.958978968710456e+102, 9.967194951097309e-206}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 174 */ + {{-1.999999999999869e+00, 9.967194951097309e-206}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 175 */ + {{-4.464794497196183e-103, 9.967194951097309e-206}, + {-4.464794497196183e-103, 9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 176 */ + {{-9.967194951097309e-206, 9.967194951097309e-206}, + {-9.967194951097309e-206, 9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 177 */ + {{-min, 9.967194951097309e-206}, + {-min, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 178 */ + {{zero, 9.967194951097309e-206}, + {zero, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 179 */ + {{min, 9.967194951097309e-206}, + {min, 9.967194951097309e-206}, + 1.003291302022624e+205}, + /* 180 */ + {{9.967194951097309e-206, 9.967194951097309e-206}, + {9.967194951097309e-206, 9.967194951097309e-206}, + 5.016456510113119e+204}, + /* 181 */ + {{4.464794497196183e-103, 9.967194951097309e-206}, + {4.464794497196183e-103, 9.967194951097309e-206}, + 2.239744742177804e+102}, + /* 182 */ + {{1.999999999999869e+00, 9.967194951097309e-206}, + {1.098612288668066e+00, 3.322398317032581e-206}, + 5.e-01}, + /* 183 */ + {{8.958978968710456e+102, 9.967194951097309e-206}, + {2.370563357515012e+02, 1.112536929253666e-308}, + 3.90625e-03}, + /* 184 */ + {{4.013165208090075e+205, 9.967194951097309e-206}, + {4.734195243224426e+02, zero}, + 1.953125e-03}, + /* 185 */ + {{max, 9.967194951097309e-206}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 186 */ {{inf, 9.967194951097309e-206}, {inf, zero}, 1.e+00}, + /* 187 */ {{-inf, 4.464794497196183e-103}, {inf, pi}, 1.e+00}, + /* 188 */ + {{-max, 4.464794497196183e-103}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 189 */ + {{-4.013165208090075e+205, 4.464794497196183e-103}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 190 */ + {{-8.958978968710456e+102, 4.464794497196183e-103}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 191 */ + {{-1.999999999999869e+00, 4.464794497196183e-103}, + {-1.305622276959269e-13, pi}, + 2.5e-01}, + /* 192 */ + {{-4.464794497196183e-103, 4.464794497196183e-103}, + {-4.464794497196183e-103, 4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 193 */ + {{-9.967194951097309e-206, 4.464794497196183e-103}, + {-6.506695883473837e-219, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 194 */ + {{-min, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 195 */ + {{zero, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 196 */ + {{min, 4.464794497196183e-103}, + {9.967194951096658e-206, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 197 */ + {{9.967194951097309e-206, 4.464794497196183e-103}, + {1.993438990219397e-205, 4.464794497196183e-103}, + 2.239744742177804e+102}, + /* 198 */ + {{4.464794497196183e-103, 4.464794497196183e-103}, + {4.464794497196183e-103, 4.464794497196183e-103}, + 1.119872371088902e+102}, + /* 199 */ + {{1.999999999999869e+00, 4.464794497196183e-103}, + {1.098612288668066e+00, 1.488264832398792e-103}, + 5.e-01}, + /* 200 */ + {{8.958978968710456e+102, 4.464794497196183e-103}, + {2.370563357515012e+02, 4.98359747554898e-206}, + 3.90625e-03}, + /* 201 */ + {{4.013165208090075e+205, 4.464794497196183e-103}, + {4.734195243224426e+02, 1.112536929253666e-308}, + 1.953125e-03}, + /* 202 */ + {{max, 4.464794497196183e-103}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 203 */ {{inf, 4.464794497196183e-103}, {inf, zero}, 1.e+00}, + /* 204 */ {{-inf, 1.999999999999869e+00}, {inf, pi}, 1.e+00}, + /* 205 */ + {{-max, 1.999999999999869e+00}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 206 */ + {{-4.013165208090075e+205, 1.999999999999869e+00}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 207 */ + {{-8.958978968710456e+102, 1.999999999999869e+00}, + {2.370563357515012e+02, pi}, + 3.90625e-03}, + /* 208 */ + {{-1.999999999999869e+00, 1.999999999999869e+00}, + {8.047189562169719e-01, 2.034443935795677e+00}, + 2.5e-01}, + /* 209 */ + {{-4.464794497196183e-103, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 210 */ + {{-9.967194951097309e-206, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 211 */ + {{-min, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 212 */ + {{zero, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 213 */ + {{min, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 214 */ + {{9.967194951097309e-206, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 215 */ + {{4.464794497196183e-103, 1.999999999999869e+00}, + {8.04718956216998e-01, 1.107148717794064e+00}, + 5.e-01}, + /* 216 */ + {{1.999999999999869e+00, 1.999999999999869e+00}, + {1.282474678730718e+00, 5.880026035475575e-01}, + 5.e-01}, + /* 217 */ + {{8.958978968710456e+102, 1.999999999999869e+00}, + {2.370563357515012e+02, 2.232397248598237e-103}, + 3.90625e-03}, + /* 218 */ + {{4.013165208090075e+205, 1.999999999999869e+00}, + {4.734195243224426e+02, 4.98359747554898e-206}, + 1.953125e-03}, + /* 219 */ + {{max, 1.999999999999869e+00}, + {7.09782712893384e+02, zero}, + 9.765625e-04}, + /* 220 */ {{inf, 1.999999999999869e+00}, {inf, zero}, 1.e+00}, + /* 221 */ {{-inf, 8.958978968710456e+102}, {inf, pi}, 1.e+00}, + /* 222 */ + {{-max, 8.958978968710456e+102}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 223 */ + {{-4.013165208090075e+205, 8.958978968710456e+102}, + {4.734195243224426e+02, pi}, + 1.953125e-03}, + /* 224 */ + {{-8.958978968710456e+102, 8.958978968710456e+102}, + {2.374029093417812e+02, pi3_4}, + 3.90625e-03}, + /* 225 */ + {{-1.999999999999869e+00, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 226 */ + {{-4.464794497196183e-103, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 227 */ + {{-9.967194951097309e-206, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 228 */ + {{-min, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 229 */ + {{zero, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 230 */ + {{min, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 231 */ + {{9.967194951097309e-206, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 232 */ + {{4.464794497196183e-103, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 233 */ + {{1.999999999999869e+00, 8.958978968710456e+102}, + {2.370563357515012e+02, pi_2}, + 3.90625e-03}, + /* 234 */ + {{8.958978968710456e+102, 8.958978968710456e+102}, + {2.374029093417812e+02, pi_4}, + 3.90625e-03}, + /* 235 */ + {{4.013165208090075e+205, 8.958978968710456e+102}, + {4.734195243224426e+02, 2.232397248598237e-103}, + 1.953125e-03}, + /* 236 */ + {{max, 8.958978968710456e+102}, + {7.09782712893384e+02, 4.983597475548361e-206}, + 9.765625e-04}, + /* 237 */ {{inf, 8.958978968710456e+102}, {inf, zero}, 1.e+00}, + /* 238 */ {{-inf, 4.013165208090075e+205}, {inf, pi}, 1.e+00}, + /* 239 */ + {{-max, 4.013165208090075e+205}, + {7.09782712893384e+02, pi}, + 9.765625e-04}, + /* 240 */ + {{-4.013165208090075e+205, 4.013165208090075e+205}, + {4.737660979127225e+02, pi3_4}, + 1.953125e-03}, + /* 241 */ + {{-8.958978968710456e+102, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 242 */ + {{-1.999999999999869e+00, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 243 */ + {{-4.464794497196183e-103, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 244 */ + {{-9.967194951097309e-206, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 245 */ + {{-min, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 246 */ + {{zero, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 247 */ + {{min, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 248 */ + {{9.967194951097309e-206, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 249 */ + {{4.464794497196183e-103, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 250 */ + {{1.999999999999869e+00, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 251 */ + {{8.958978968710456e+102, 4.013165208090075e+205}, + {4.734195243224426e+02, pi_2}, + 1.953125e-03}, + /* 252 */ + {{4.013165208090075e+205, 4.013165208090075e+205}, + {4.737660979127225e+02, pi_4}, + 1.953125e-03}, + /* 253 */ + {{max, 4.013165208090075e+205}, + {7.09782712893384e+02, 2.23239724859796e-103}, + 9.765625e-04}, + /* 254 */ {{inf, 4.013165208090075e+205}, {inf, zero}, 1.e+00}, + /* 255 */ {{-inf, max}, {inf, pi}, 1.e+00}, + /* 256 */ {{-max, max}, {7.101292864836639e+02, pi3_4}, 9.765625e-04}, + /* 257 */ + {{-4.013165208090075e+205, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 258 */ + {{-8.958978968710456e+102, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 259 */ + {{-1.999999999999869e+00, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 260 */ + {{-4.464794497196183e-103, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 261 */ + {{-9.967194951097309e-206, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 262 */ {{-min, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 263 */ {{zero, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 264 */ {{min, max}, {7.09782712893384e+02, pi_2}, 9.765625e-04}, + /* 265 */ + {{9.967194951097309e-206, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 266 */ + {{4.464794497196183e-103, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 267 */ + {{1.999999999999869e+00, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 268 */ + {{8.958978968710456e+102, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 269 */ + {{4.013165208090075e+205, max}, + {7.09782712893384e+02, pi_2}, + 9.765625e-04}, + /* 270 */ {{max, max}, {7.101292864836639e+02, pi_4}, 9.765625e-04}, + /* 271 */ {{inf, max}, {inf, zero}, 1.e+00}, + /* 272 */ {{-inf, inf}, {inf, pi3_4}, 1.e+00}, + /* 273 */ {{-max, inf}, {inf, pi_2}, 1.e+00}, + /* 274 */ {{-4.013165208090075e+205, inf}, {inf, pi_2}, 1.e+00}, + /* 275 */ {{-8.958978968710456e+102, inf}, {inf, pi_2}, 1.e+00}, + /* 276 */ {{-1.999999999999869e+00, inf}, {inf, pi_2}, 1.e+00}, + /* 277 */ {{-4.464794497196183e-103, inf}, {inf, pi_2}, 1.e+00}, + /* 278 */ {{-9.967194951097309e-206, inf}, {inf, pi_2}, 1.e+00}, + /* 279 */ {{-min, inf}, {inf, pi_2}, 1.e+00}, + /* 280 */ {{zero, inf}, {inf, pi_2}, 1.e+00}, + /* 281 */ {{min, inf}, {inf, pi_2}, 1.e+00}, + /* 282 */ {{9.967194951097309e-206, inf}, {inf, pi_2}, 1.e+00}, + /* 283 */ {{4.464794497196183e-103, inf}, {inf, pi_2}, 1.e+00}, + /* 284 */ {{1.999999999999869e+00, inf}, {inf, pi_2}, 1.e+00}, + /* 285 */ {{8.958978968710456e+102, inf}, {inf, pi_2}, 1.e+00}, + /* 286 */ {{4.013165208090075e+205, inf}, {inf, pi_2}, 1.e+00}, + /* 287 */ {{max, inf}, {inf, pi_2}, 1.e+00}, + /* 288 */ {{inf, inf}, {inf, pi_4}, 1.e+00}}; + return table; + } else { + static_assert(false); /* unreachable */ + } + } +}; + +} // namespace complex_unary_op_samples + +#endif // XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ diff --git a/xla/tests/complex_unary_op_test.cc b/xla/tests/complex_unary_op_test.cc new file mode 100644 index 0000000000000..fefe96cb59c69 --- /dev/null +++ b/xla/tests/complex_unary_op_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/client/global_data.h" +#include "xla/client/local_client.h" +#include "xla/client/xla_builder.h" +#include "xla/tests/client_library_test_base.h" +#include "xla/tests/complex_unary_op_samples.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/test_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +class ComplexUnaryOpTest : public ClientLibraryTestBase { + protected: + template + std::vector get_column(const std::vector>& table) { + std::vector column; + std::transform( + table.cbegin(), table.cend(), std::back_inserter(column), + [](const auto& item) { return static_cast(std::get(item)); }); + return column; + } + + template + void scale_column(std::vector& column, const std::vector& scales) { + std::transform(column.begin(), column.end(), scales.begin(), column.begin(), + [](const T& lhs, const S& rhs) { return lhs * rhs; }); + } + + template + void UnaryTestHelper(XlaOp (*Op)(const XlaOp operand)) { + using InputType = typename C::InputType; + using OutputType = typename C::OutputType; + using FloatType = typename C::FloatType; + + float atol; + // log(10)/log(2) = 3.3219... + constexpr int precision_deficiency = + static_cast(C::dps_deficiency * 3.3219280948873626); + // precision_deficiency defines a slack allowed when comparing a + // result value against expected value that is known to be + // inaccurate to some extent. + if constexpr (std::is_same_v) { + atol = std::ldexp(1e-6f, precision_deficiency); + } else if constexpr (std::is_same_v) { + atol = std::ldexp(1e-15f, precision_deficiency); + } else { + static_assert(false); // unreachable + } + + XlaBuilder builder(TestName()); + auto table = C().get(); + auto inputs_vec = get_column(table); + auto expected_vec = get_column(table); + auto scales_vec = get_column(table); + scale_column(expected_vec, scales_vec); + + auto inputs = ConstantR1(&builder, inputs_vec); + auto scales = ConstantR1(&builder, scales_vec); + Literal expected = LiteralUtil::CreateR1(expected_vec); + + if constexpr (std::is_same_v) { + auto results = Op(inputs); + Mul(results, scales); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(atol)); + } else { + auto results = Op(inputs); + auto re = Mul(Real(results), scales); + auto im = Mul(Imag(results), scales); + Complex(re, im); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(atol)); + } + } +}; + +XLA_TEST_F(ComplexUnaryOpTest, Log1pTest) { + UnaryTestHelper>(Log1p); + UnaryTestHelper>(Log1p); +} + +} // namespace +} // namespace xla diff --git a/xla/tests/compute_constant_test.cc b/xla/tests/compute_constant_test.cc index c05e2b2d22deb..b37cf2d9b8395 100644 --- a/xla/tests/compute_constant_test.cc +++ b/xla/tests/compute_constant_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,13 +53,13 @@ class ComputeConstantTest : public ::testing::Test { Client* ClientOrDie(se::Platform* platform, ClientType client_type) { if (client_type == ClientType::kLocal) { - StatusOr result = + absl::StatusOr result = ClientLibrary::GetOrCreateLocalClient(platform); TF_CHECK_OK(result.status()) << "could not create LocalClient for testing"; return result.value(); } else if (client_type == ClientType::kCompileOnly) { - StatusOr result = + absl::StatusOr result = ClientLibrary::GetOrCreateCompileOnlyClient(platform); TF_CHECK_OK(result.status()) << "could not create CompileOnlyClient for testing"; @@ -68,9 +68,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr ComputeConstantLiteral(Client* client, const XlaOp operand, - XlaBuilder* builder, - Layout* output_layout = nullptr) { + absl::StatusOr ComputeConstantLiteral( + Client* client, const XlaOp operand, XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -86,7 +86,7 @@ class ComputeConstantTest : public ::testing::Test { } bool IsConstant(const XlaOp operand, XlaBuilder* builder) { - StatusOr result = builder->IsConstant(operand); + absl::StatusOr result = builder->IsConstant(operand); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.value() : false; } diff --git a/xla/tests/concat_test.cc b/xla/tests/concat_test.cc index 998ff8b2bb0d9..01bd3dd6f76a0 100644 --- a/xla/tests/concat_test.cc +++ b/xla/tests/concat_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ using ::testing::HasSubstr; XLA_TEST_F(ConcatTest, Concat_Nothing) { XlaBuilder builder(TestName()); ConcatInDim(&builder, {}, 0); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("Concatenate expects at least one argument")); @@ -75,7 +75,7 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { auto a = ConstantR0(&builder, 42.0); auto b = ConstantR0(&builder, 64.0); ConcatInDim(&builder, {a, b}, 0); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("out of bounds: 0")); @@ -416,7 +416,7 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { auto x = Parameter(&builder, 0, r1f32, "x"); auto y = Parameter(&builder, 1, opaque_shape, "y"); ConcatInDim(&builder, {x, y}, 0); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), @@ -431,7 +431,7 @@ XLA_TEST_F(ConcatTest, CannotConcatTokens) { auto x = Parameter(&builder, 0, r1f32, "x"); auto y = Parameter(&builder, 1, token_shape, "y"); ConcatInDim(&builder, {x, y}, 0); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), diff --git a/xla/tests/conditional_test.cc b/xla/tests/conditional_test.cc index 33e6fb441550f..7d13d06851077 100644 --- a/xla/tests/conditional_test.cc +++ b/xla/tests/conditional_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/constant_reduction_function_test.cc b/xla/tests/constant_reduction_function_test.cc index 9d212d1d3b7d9..57c603023610c 100644 --- a/xla/tests/constant_reduction_function_test.cc +++ b/xla/tests/constant_reduction_function_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 31201d8cab145..b487877633b0b 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/conv_depthwise_backprop_filter_test.cc b/xla/tests/conv_depthwise_backprop_filter_test.cc index 5a0556baf2ce8..55a540c7e9df1 100644 --- a/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/conv_depthwise_common.cc b/xla/tests/conv_depthwise_common.cc index 76d1b116735a8..07a7bbfa53b28 100644 --- a/xla/tests/conv_depthwise_common.cc +++ b/xla/tests/conv_depthwise_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/conv_depthwise_common.h b/xla/tests/conv_depthwise_common.h index 3163d3ff4568a..0deb41064ee0d 100644 --- a/xla/tests/conv_depthwise_common.h +++ b/xla/tests/conv_depthwise_common.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/conv_depthwise_test.cc b/xla/tests/conv_depthwise_test.cc index 3f60a27c5db8f..b44e089b596ce 100644 --- a/xla/tests/conv_depthwise_test.cc +++ b/xla/tests/conv_depthwise_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 3e1ae9c9a30f6..91d545023c637 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" namespace xla { @@ -43,7 +43,9 @@ class ConvertTest : public ClientLibraryTestBase { : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); - mutable_debug_options()->set_xla_gpu_simplify_all_fp_conversions(false); + mutable_debug_options()->add_xla_disable_hlo_passes( + "simplify-fp-conversions"); + mutable_debug_options()->set_xla_allow_excess_precision(false); } }; diff --git a/xla/tests/convolution_cudnn_test.cc b/xla/tests/convolution_cudnn_test.cc index 1bacf8cfa18f9..ad467b49b19d9 100644 --- a/xla/tests/convolution_cudnn_test.cc +++ b/xla/tests/convolution_cudnn_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -98,9 +98,9 @@ ENTRY TestComputation { convert.1 = f32[64]{0} convert(bias) cudnn-conv-bias-activation.3 = (s8[4,48,48,64]{3,2,1,0}, u8[0]{0}) custom-call(input, filter, convert.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convBiasActivationForward", - backend_config="{\"activation_mode\":\"2\",\"conv_result_scale\":1,\"side_input_scale\":0,\"algorithm\":{ + backend_config="{\"cudnn_conv_backend_config\":{\"activation_mode\":\"2\",\"conv_result_scale\":1,\"side_input_scale\":0,\"algorithm\":{ \"algo_id\":\"38\",\"math_type\":\"DEFAULT_MATH\",\"tuning_knobs\":{\"14\":\"5\",\"13\":\"1\",\"23\":\"0\",\"2\":\"1\"}, - \"is_cudnn_frontend\":true,\"workspace_size\":\"0\"}}" + \"is_cudnn_frontend\":true,\"workspace_size\":\"0\"}}}" ROOT get-tuple-element.1 = s8[4,48,48,64]{3,2,1,0} get-tuple-element(cudnn-conv-bias-activation.3), index=0 })"; constexpr char kHloVectorized[] = R"( @@ -121,9 +121,9 @@ ENTRY TestComputation { bitcast.37 = f32[64]{0} bitcast(transpose.6) cudnn-conv-bias-activation.4 = (s8[4,2,48,48,32]{4,3,2,1,0}, u8[51328]{0}) custom-call(transpose, bitcast.28, bitcast.37), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", - backend_config="{\"activation_mode\":\"2\",\"conv_result_scale\":1,\"side_input_scale\":0,\"algorithm\":{ + backend_config="{\"cudnn_conv_backend_config\":{\"activation_mode\":\"2\",\"conv_result_scale\":1,\"side_input_scale\":0,\"algorithm\":{ \"algo_id\":\"7\",\"math_type\":\"DEFAULT_MATH\",\"tuning_knobs\":{\"7\":\"3\",\"2\":\"0\",\"5\":\"4\",\"6\":\"4\",\"4\":\"2\",\"21\":\"0\"}, - \"is_cudnn_frontend\":true,\"workspace_size\":\"51328\"},\"reordered_int8_nchw_vect\":true}" + \"is_cudnn_frontend\":true,\"workspace_size\":\"51328\"},\"reordered_int8_nchw_vect\":true}}" get-tuple-element.6 = s8[4,2,48,48,32]{4,3,2,1,0} get-tuple-element(cudnn-conv-bias-activation.4), index=0 transpose.1 = s8[4,48,48,2,32]{4,3,2,1,0} transpose(get-tuple-element.6), dimensions={0,2,3,1,4} ROOT bitcast.1 = s8[4,48,48,64]{3,2,1,0} bitcast(transpose.1) @@ -161,7 +161,7 @@ ENTRY TestComputation { %conv = (s8[4,48,48,64], u8[0]) custom-call(%input, %filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward", - backend_config="{\"activation_mode\":\"0\",\"conv_result_scale\":1,\"side_input_scale\":0}" + backend_config="{\"cudnn_conv_backend_config\":{\"activation_mode\":\"0\",\"conv_result_scale\":1,\"side_input_scale\":0}}" ROOT %gte = s8[4,48,48,64] get-tuple-element(%conv), index=0 })"; constexpr char kHloVectorized[] = R"( @@ -175,7 +175,7 @@ ENTRY TestComputation { %conv = (s8[4,48,48,16,4], u8[0]) custom-call(%input.1, %filter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward", - backend_config="{\"activation_mode\":\"0\",\"conv_result_scale\":1,\"side_input_scale\":0}" + backend_config="{\"cudnn_conv_backend_config\":{\"activation_mode\":\"0\",\"conv_result_scale\":1,\"side_input_scale\":0}}" %gte = s8[4,48,48,16,4] get-tuple-element(%conv), index=0 ROOT reshape.3 = s8[4,48,48,64] reshape(%gte) })"; diff --git a/xla/tests/convolution_dimension_numbers_test.cc b/xla/tests/convolution_dimension_numbers_test.cc index 07b6e280b55a0..64ebe394db3cb 100644 --- a/xla/tests/convolution_dimension_numbers_test.cc +++ b/xla/tests/convolution_dimension_numbers_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ limitations under the License. namespace xla { namespace { -StatusOr CreateConvDimensionNumbers( +absl::StatusOr CreateConvDimensionNumbers( int64_t input_batch, int64_t input_feature, int64_t input_first_spatial, int64_t input_second_spatial, int64_t output_batch, int64_t output_feature, int64_t output_first_spatial, int64_t output_second_spatial, diff --git a/xla/tests/convolution_test.cc b/xla/tests/convolution_test.cc index 5f319913dc258..4691bf1d6bd99 100644 --- a/xla/tests/convolution_test.cc +++ b/xla/tests/convolution_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/convolution_test_1d.cc b/xla/tests/convolution_test_1d.cc index fbdf2fbbcbd60..35780649a4c15 100644 --- a/xla/tests/convolution_test_1d.cc +++ b/xla/tests/convolution_test_1d.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/convolution_variants_test.cc b/xla/tests/convolution_variants_test.cc index a6281d78142aa..6fb47d1e15640 100644 --- a/xla/tests/convolution_variants_test.cc +++ b/xla/tests/convolution_variants_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/copy_test.cc b/xla/tests/copy_test.cc index 72f5c23749454..45d94ab033383 100644 --- a/xla/tests/copy_test.cc +++ b/xla/tests/copy_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/cpu_gpu_fusion_test.cc b/xla/tests/cpu_gpu_fusion_test.cc index 452d1dc0a6c13..a504247923df4 100644 --- a/xla/tests/cpu_gpu_fusion_test.cc +++ b/xla/tests/cpu_gpu_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -872,6 +872,21 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { ComputeAndCompare(&b, {}); } +XLA_TEST_F(CpuGpuFusionTest, TransposeDiamondWithNonTrivialBranch) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f64[16,16]{1,0} parameter(0) + trans = f64[16,16]{1,0} transpose(p), dimensions={1,0} + rev = f64[16,16]{1,0} reverse(trans), dimensions={0,1} + sub = f64[16,16]{1,0} subtract(trans, trans) + ROOT add = f64[16,16]{1,0} add(rev, sub) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5})); +} + void BM_ParallelFusion(::testing::benchmark::State& state) { // Simple element-wise computation to benchmark parallel task partitioning. @@ -934,9 +949,7 @@ void BM_ParallelFusion(::testing::benchmark::State& state) { .value(); auto executable = std::move(executables[0]); - se::Stream stream(executors[device_ordinal]); - stream.Init(); - + auto stream = executors[device_ordinal]->CreateStream().value(); // Initialize thread pool. tsl::thread::ThreadPool pool(tsl::Env::Default(), "XLAEigen", intra_op_parallelism_threads); @@ -944,7 +957,7 @@ void BM_ParallelFusion(::testing::benchmark::State& state) { // Initialize ExecutableRunOptions. ExecutableRunOptions options; - options.set_allocator(&allocator).set_stream(&stream); + options.set_allocator(&allocator).set_stream(stream.get()); options.set_intra_op_thread_pool(&device); // Run some warm-up executions. diff --git a/xla/tests/custom_call_test.cc b/xla/tests/custom_call_test.cc index 25d7a31e8e44c..e376f90a6aaff 100644 --- a/xla/tests/custom_call_test.cc +++ b/xla/tests/custom_call_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -35,6 +47,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace { @@ -101,67 +114,69 @@ namespace { using ::testing::HasSubstr; class CustomCallTest : public HloTestBase { + public: + CustomCallTest() + : HloTestBase(), + module_(CreateNewVerifiedModule()), + builder_(TestName()) {} + protected: + // Call this function when builder_ is complete (i.e. when all instructions + // have been added). Note that module_ is empty after calling this function. + auto BuildAndExecute(absl::Span arguments) { + module_->AddEntryComputation(builder_.Build()); + return Execute(std::move(module_), arguments); + } + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); + + std::unique_ptr module_; + HloComputation::Builder builder_; }; XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - auto constant = builder.AddInstruction( + auto constant = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - builder.AddInstruction( + builder_.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); - module->AddEntryComputation(builder.Build()); - - Literal result = ExecuteAndTransfer(std::move(module), {}); + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - Array2D array(2, 2); array(0, 0) = 1.0f; array(0, 1) = 2.0f; array(1, 0) = 3.0f; array(1, 1) = 4.0f; - auto constant = builder.AddInstruction( + auto constant = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); - builder.AddInstruction( + builder_.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); - module->AddEntryComputation(builder.Build()); - - Literal result = ExecuteAndTransfer(std::move(module), {}); + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { - auto module = CreateNewVerifiedModule(); - auto b = HloComputation::Builder(TestName()); - - auto input = b.AddInstruction( + auto input = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); - auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( + auto incremented = builder_.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); - auto incremented_again = b.AddInstruction(HloInstruction::CreateCustomCall( - ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues")); + auto incremented_again = + builder_.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues")); // Concatenate the values along first dim. - b.AddInstruction( + builder_.AddInstruction( HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}), {incremented, incremented_again}, 0)); - module->AddEntryComputation(b.Build()); - - Literal result = ExecuteAndTransfer(std::move(module), {}); + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); LiteralTestUtil::ExpectR3EqualArray3D( Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } @@ -173,24 +188,22 @@ XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { GTEST_SKIP() << "Appears to test an XLA current implementation detail"; } - auto module = CreateNewVerifiedModule(); - auto b = HloComputation::Builder(TestName()); - auto input = - b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); - b.AddInstruction( + builder_.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + builder_.AddInstruction( HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); - module->AddEntryComputation(b.Build()); - ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); - ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + module_->AddEntryComputation(builder_.Build()); + ForceParameterLayout(module_.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module_.get(), LayoutUtil::MakeLayout({0, 1})); Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); // Note, the expected result is transposed! This is because the input and // output layouts of the custom call differ and the called function just // blindly adds one to each element. - Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module_), {&argument})); LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } @@ -198,26 +211,24 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. - auto module = CreateNewVerifiedModule(); - auto b = HloComputation::Builder(TestName()); - auto input = - b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + builder_.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); const Shape& r2f32_dim0_major = ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 2}, {1, 0}); - auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall( + auto custom_call = builder_.AddInstruction(HloInstruction::CreateCustomCall( r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); - b.AddInstruction( + builder_.AddInstruction( custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); - module->AddEntryComputation(b.Build()); - ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); - ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + module_->AddEntryComputation(builder_.Build()); + ForceParameterLayout(module_.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module_.get(), LayoutUtil::MakeLayout({0, 1})); Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); - Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module_), {&argument})); LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); } @@ -237,63 +248,49 @@ XLA_TEST_F(CustomCallTest, TupleOutput) { Literal arg1 = LiteralUtil::CreateR0(42.f); Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); - Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&arg0, &arg1})); EXPECT_EQ(result, expected); } XLA_TEST_F(CustomCallTest, ReportsSuccess) { - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - auto constant = builder.AddInstruction( + auto constant = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - builder.AddInstruction(HloInstruction::CreateCustomCall( + builder_.AddInstruction(HloInstruction::CreateCustomCall( r0f32_, {constant}, "R0F32Add2Succeed", /*opaque=*/"", CustomCallApiVersion::API_VERSION_STATUS_RETURNING)); - module->AddEntryComputation(builder.Build()); - - Literal result = ExecuteAndTransfer(std::move(module), {}); + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, ReportsFailure) { - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - auto constant = builder.AddInstruction( + auto constant = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - builder.AddInstruction(HloInstruction::CreateCustomCall( + builder_.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {}), {constant}, "CustomCallFail", /*opaque=*/"", CustomCallApiVersion::API_VERSION_STATUS_RETURNING)); - module->AddEntryComputation(builder.Build()); - - auto status = Execute(std::move(module), {}).status(); + auto status = BuildAndExecute({}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("Failed: 42.0")); } XLA_TEST_F(CustomCallTest, ReportsFirstFailure) { - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - auto constant_1 = builder.AddInstruction( + auto constant_1 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant_2 = builder.AddInstruction( + auto constant_2 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto res_1 = builder.AddInstruction(HloInstruction::CreateCustomCall( + auto res_1 = builder_.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {}), {constant_1}, "CustomCallFail", /*opaque=*/"", CustomCallApiVersion::API_VERSION_STATUS_RETURNING)); - auto res_2 = builder.AddInstruction(HloInstruction::CreateCustomCall( + auto res_2 = builder_.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {}), {constant_2}, "CustomCallFail", /*opaque=*/"", CustomCallApiVersion::API_VERSION_STATUS_RETURNING)); - builder.AddInstruction(HloInstruction::CreateBinary( + builder_.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, res_1, res_2)); - module->AddEntryComputation(builder.Build()); - - auto status = Execute(std::move(module), {}).status(); + auto status = BuildAndExecute({}).status(); EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), ::testing::HasSubstr("Failed: 1.0")); } @@ -354,10 +351,505 @@ XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { CustomCall(&builder, "$illegal", /*operands=*/{}, ShapeUtil::MakeShape(F32, {1})); - StatusOr> result = + absl::StatusOr> result = Execute(&builder, /*arguments=*/{}); EXPECT_FALSE(result.ok()); } +//===----------------------------------------------------------------------===// +// XLA runtime custom call provides type-safe custom call API +//===----------------------------------------------------------------------===// + +namespace { +// Helper function to get data pointer from buffer +template +static NativeType* DataPointer(BufferType& buffer) { + return reinterpret_cast(buffer.data.opaque()); +} + +using R0F32Buffer = typename ffi::BufferR0; +using F32Buffer = typename ffi::Buffer; + +static absl::Status AlwaysSucceed(ffi::BufferBase) { return absl::OkStatus(); } + +XLA_FFI_DEFINE_HANDLER( + kAlwaysSucceed, AlwaysSucceed, + ffi::Ffi::Bind().Arg() // unused out buffer +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_succeed", + "Host", kAlwaysSucceed); + +static absl::Status AlwaysFail(ffi::BufferBase, int32_t value) { + return absl::InternalError(absl::StrCat("Failed: ", value)); +} + +// TODO(abanas): When Result is supported, change output buffers in all +// bindings to use it (e.g. .Arg -> .Result) +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, + ffi::Ffi::Bind() + .Arg() // unused out buffer + .Attr("value") // value +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", "Host", + kAlwaysFail); + +static absl::Status FfiR0F32Add2(R0F32Buffer in, R0F32Buffer out) { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + *out_data = *in_data + 2.0f; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiR0F32Add2, FfiR0F32Add2, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiR0F32Add2", + "Host", kFfiR0F32Add2); + +// This represents a kernel that is valid only for F32 and F64 types +static absl::Status FfiR0FAdd2BufferBase(ffi::BufferBase in, + ffi::BufferBase out) { + if (in.dtype != out.dtype) { + return absl::InternalError("Input and output dtypes mismatch"); + } + + switch (in.dtype) { + case PrimitiveType::F32: { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + *out_data = *in_data + 2.0f; + break; + } + case PrimitiveType::F64: { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + *out_data = *in_data + 2.0f; + break; + } + default: + return absl::InternalError("Incorrect type"); + } + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiR0FAdd2BufferBase, FfiR0FAdd2BufferBase, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "__xla_test$$FfiR0FAdd2BufferBase", "Host", + kFfiR0FAdd2BufferBase); + +static absl::Status FfiR0F32AddN(R0F32Buffer in, R0F32Buffer out, float n) { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + *out_data = *in_data + n; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiR0F32AddN, FfiR0F32AddN, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out + .Attr("n")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiR0F32AddN", + "Host", kFfiR0F32AddN); + +static absl::Status FfiR0F32AddNPointer(R0F32Buffer in, R0F32Buffer out, + float* n) { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + *out_data = *in_data + *n; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiR0F32AddNPointer, FfiR0F32AddNPointer, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out + .Attr>("n")); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiR0F32AddNPointer", + "Host", kFfiR0F32AddNPointer); + +static absl::Status FfiF32ReduceSum(F32Buffer in, R0F32Buffer out) { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + + // Calculate the total size of the vector + const auto size = + absl::c_accumulate(in.dimensions, 1, std::multiplies()); + + // Calculate the sum of the vector + *out_data = absl::c_accumulate(absl::MakeSpan(in_data, size), 0.0f); + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiF32ReduceSum, FfiF32ReduceSum, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiF32ReduceSum", + "Host", kFfiF32ReduceSum); + +static absl::Status FfiF32Add1ToValues(F32Buffer in, F32Buffer out) { + auto in_data = DataPointer(in); + auto out_data = DataPointer(out); + + // Calculate and verify the total size of the vector + const auto in_size = + absl::c_accumulate(in.dimensions, 1, std::multiplies()); + const auto out_size = + absl::c_accumulate(out.dimensions, 1, std::multiplies()); + if (in_size != out_size) { + return absl::InternalError("Input and output sizes mismatch"); + } + + // Actual computations + std::transform(in_data, in_data + in_size, out_data, + [](float x) { return x + 1; }); + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiF32Add1ToValues, FfiF32Add1ToValues, + ffi::Ffi::Bind() + .Arg() // in + .Arg() // out +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiF32Add1ToValues", + "Host", kFfiF32Add1ToValues); + +static absl::Status FfiF32TupleSwap(R0F32Buffer in0, R0F32Buffer in1, + R0F32Buffer out0, R0F32Buffer out1) { + auto in_data0 = DataPointer(in0); + auto in_data1 = DataPointer(in1); + auto out_data0 = DataPointer(out0); + auto out_data1 = DataPointer(out1); + *out_data0 = *in_data1; + *out_data1 = *in_data0; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kFfiF32TupleSwap, FfiF32TupleSwap, + ffi::Ffi::Bind() + .Arg() // in0 + .Arg() // in1 + .Arg() // out0 + .Arg() // out1 +); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiF32TupleSwap", + "Host", kFfiF32TupleSwap); + +} // namespace + +// TODO(abanas): When #10056 (typed FFI support) is ready, this class can be +// replaced by a simple 'using FfiCustomCallTest = CustomCallTest;' +class FfiCustomCallTest : public CustomCallTest { + protected: + void SetUp() override { + GTEST_SKIP() << "Typed FFI is not supported yet on CPU"; + } +}; + +XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$always_succeed", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kOk); +} + +XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) { + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$unknown_target", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnimplemented); +} + +XLA_TEST_F(FfiCustomCallTest, FfiReportsFailure) { + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$always_fail", + /*opaque=*/"{value = 42 : i32}", + CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Failed: 42")); +} + +XLA_TEST_F(FfiCustomCallTest, FfiReportsFirstFailure) { + auto res_1 = builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$always_fail", + /*opaque=*/"{value = 1 : i32}", + CustomCallApiVersion::API_VERSION_TYPED_FFI)); + auto res_2 = builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$always_fail", + /*opaque=*/"{value = 2 : i32}", + CustomCallApiVersion::API_VERSION_TYPED_FFI)); + builder_.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, res_1, res_2)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Failed: 1")); +} + +XLA_TEST_F(FfiCustomCallTest, FfiTransitiveCustomCallReportsFirstFailure) { + const char* const kModuleStr = R"( + HloModule m + sub_2 { + ROOT custom-call = f32[] custom-call(), custom_call_target="__xla_test$$always_fail", api_version=API_VERSION_TYPED_FFI, backend_config="{value = 2 : i32}" + } + sub_3 { + ROOT custom-call = f32[] custom-call(), custom_call_target="__xla_test$$always_fail", api_version=API_VERSION_TYPED_FFI, backend_config="{value = 3 : i32}" + } + ENTRY test { + call0 = f32[] call(), to_apply=sub_2 + call1 = f32[] call(), to_apply=sub_3 + ROOT sum = f32[] add(%call0, %call1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto status = Execute(std::move(module), {}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), HasSubstr("Failed: 2")); +} + +XLA_TEST_F(FfiCustomCallTest, FfiWrongNumberOfArguments) { + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$FfiR0F32Add2", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +XLA_TEST_F(FfiCustomCallTest, FfiWrongTypeOfArguments) { + Array2D array(2, 2); + array(0, 0) = 1.0f; + array(0, 1) = 2.0f; + array(1, 0) = 3.0f; + array(1, 1) = 4.0f; + + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_, {constant}, "__xla_test$$FfiR0F32Add2", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + auto status = BuildAndExecute({}).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleTypedBuffers) { + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0F32Add2", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleInputAsParameters) { + auto constant = + builder_.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0F32Add2", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + Literal argument = LiteralUtil::CreateR0(42.0f); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({&argument})); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseFloat) { + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0FAdd2BufferBase", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseDouble) { + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0FAdd2BufferBase", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleAttr) { + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0F32AddN", + /*opaque=*/"{n = 3.0 : f32}", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(45.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleAttrPointer) { + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto n = 4.0f; + auto ptr = reinterpret_cast(&n); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiR0F32AddN", + /*opaque=*/absl::StrFormat("{n = %d : i64}", ptr), + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(46.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiHandleR2Vector) { + Array2D array(2, 2); + array(0, 0) = 1.0f; + array(0, 1) = 2.0f; + array(1, 0) = 3.0f; + array(1, 1) = 4.0f; + + auto constant = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {constant}, "__xla_test$$FfiF32ReduceSum", + /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); +} + +XLA_TEST_F(FfiCustomCallTest, FfiUsedInOtherComputations) { + auto input = builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); + auto incremented = builder_.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, + "__xla_test$$FfiF32Add1ToValues", + /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + auto incremented_again = + builder_.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, + "__xla_test$$FfiF32Add1ToValues", + /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + // Concatenate the values along first dim. + builder_.AddInstruction( + HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}), + {incremented, incremented_again}, 0)); + + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({})); + LiteralTestUtil::ExpectR3EqualArray3D( + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); +} + +XLA_TEST_F(FfiCustomCallTest, FfiInputAndOutputLayoutDiffer) { + if (IsMlirLoweringEnabled()) { + // The MLIR pipeline does /not/ transpose the output here, and there's no + // obvious reason why it should. + GTEST_SKIP() << "Appears to test an XLA current implementation detail"; + } + + auto input = + builder_.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + builder_.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_, {input}, "__xla_test$$FfiF32Add1ToValues", /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + module_->AddEntryComputation(builder_.Build()); + ForceParameterLayout(module_.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module_.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module_), {&argument})); + LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(FfiCustomCallTest, FfiLayoutConstrained) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto input = + builder_.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 2}, {1, 0}); + auto custom_call = builder_.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "__xla_test$$FfiF32Add1ToValues", + /*operand_shapes_with_layout=*/{r2f32_dim0_major}, + /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + builder_.AddInstruction( + custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); + + module_->AddEntryComputation(builder_.Build()); + ForceParameterLayout(module_.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module_.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module_), {&argument})); + LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); +} + +XLA_TEST_F(FfiCustomCallTest, FfiTupleOutput) { + auto input0 = + builder_.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p0")); + auto input1 = + builder_.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "p1")); + builder_.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTupleShape({r0f32_, r0f32_}), {input0, input1}, + "__xla_test$$FfiF32TupleSwap", /*opaque=*/"", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, BuildAndExecute({&arg0, &arg1})); + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla diff --git a/xla/tests/deallocation_test.cc b/xla/tests/deallocation_test.cc index 33cb373c5197b..38b81d9650080 100644 --- a/xla/tests/deallocation_test.cc +++ b/xla/tests/deallocation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/deconstruct_tuple_test.cc b/xla/tests/deconstruct_tuple_test.cc index bf5b25d719d96..48cd0418e0dff 100644 --- a/xla/tests/deconstruct_tuple_test.cc +++ b/xla/tests/deconstruct_tuple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/deep_graph_test.cc b/xla/tests/deep_graph_test.cc index 0cd1f35dbf132..eed8303b6e0a4 100644 --- a/xla/tests/deep_graph_test.cc +++ b/xla/tests/deep_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/dot_operation_test.cc b/xla/tests/dot_operation_test.cc index 89af28d11b093..6ed8d128ee277 100644 --- a/xla/tests/dot_operation_test.cc +++ b/xla/tests/dot_operation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -70,7 +70,12 @@ using TypesF16F32F64CF64 = ::testing::Types< #endif float>; +#if GOOGLE_CUDA using TypesF8 = ::testing::Types; +#endif +#if TF_HIPBLASLT && TF_ROCM_VERSION >= 60000 +using TypesF8 = ::testing::Types; +#endif // Check that we can safely pass an input tuple's elements to a dot operation. XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { @@ -289,6 +294,34 @@ std::string PrintDotTestParam( class ParametricDotTest : public DotOperationTest, public ::testing::WithParamInterface { protected: + // This method runs before each test runs. + void SetUp() override { + // Several F16 tests are subject to denormal issues on MI210 architecture. + // For that matter, we set propagate_grad_xy_ flag for these tests, which + // activates adapted GEMM algorithm on ROCM. Besides, the adapted algorithm + // does not work well with ROCBLAS autotuning, hence we also disable it. + // This also serves as a test that grad_x/y attributes are correctly + // propagated down to a GEMM routine. + const auto& gpu_comp = client_->backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + if (std::holds_alternative(gpu_comp)) { + std::string_view name( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + if (name.find("TestF16/270x270x520_MajorToMinor") != std::string::npos) { + execution_options_.mutable_debug_options()->set_xla_gpu_autotune_level( + 0); + DotTestParam param = GetParam(); + // In order to test both grad_x and grad_y attributes, we set + // propagate_grad_xy_ to 1 or 2 based on some alternating parameter + // to set it deterministically. + propagate_grad_xy_ = param.dot_lhs_row_major ? 1 : 2; + } + } + ManifestCheckingTest::SetUp(); + } + template void TestImpl(); @@ -296,6 +329,8 @@ class ParametricDotTest : public DotOperationTest, void ComputeAndCompareR2WithError(XlaBuilder* builder, const Array2D& expected, absl::Span arguments); + + int32_t propagate_grad_xy_ = 0; }; template @@ -356,6 +391,15 @@ void ParametricDotTest::TestImpl() { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); + + if (propagate_grad_xy_ != 0) { + FrontendAttributes attributes; + if (propagate_grad_xy_ == 1) + (*attributes.mutable_map())["grad_x"] = "true"; + else + (*attributes.mutable_map())["grad_y"] = "true"; + builder.SetFrontendAttributes(attributes); + } auto result = Dot(Parameter(&builder, 0, ShapeUtil::MakeShapeWithDenseLayout( @@ -367,6 +411,9 @@ void ParametricDotTest::TestImpl() { prim_type, {param.k, param.n}, MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), "dot_rhs")); + if (propagate_grad_xy_ != 0) { + builder.ClearFrontendAttributes(); + } if (param.has_addend) { result = @@ -689,7 +736,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -#if GOOGLE_CUDA || TF_HIPBLASLT +#if GOOGLE_CUDA || (TF_HIPBLASLT && TF_ROCM_VERSION >= 60000) template class DotOperationTestWithCublasLt_F16F32F64CF64 : public DotOperationTest { public: @@ -745,7 +792,7 @@ XLA_TYPED_TEST(DotOperationTestWithCublasLt_F16F32F64CF64, } #endif // GOOGLE_CUDA || TF_HIPBLASLT -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TF_HIPBLASLT template class DotOperationTestWithCublasLt_F8 : public DotOperationTest { public: @@ -1067,7 +1114,7 @@ XLA_TYPED_TEST(DotOperationTestWithCublasLt_F8, ScaledABScaledDWithDAmaxF8) { b_scale_data.get(), d_scale_data.get()}, this->error_spec_); } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TF_HIPBLASLT XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { using T = TypeParam; @@ -2295,8 +2342,8 @@ void DOT_ReorderContracting(::testing::benchmark::State& state) { ExecutableBuildOptions())); auto executable = std::move(executables[0]); - se::Stream stream(executors[device_ordinal]); - stream.Init(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, + executors[device_ordinal]->CreateStream()); ExecutableRunOptions options; options.set_allocator(&allocator); diff --git a/xla/tests/dynamic_ops_test.cc b/xla/tests/dynamic_ops_test.cc index bc89ea418e6cf..ae3c81f0c6db4 100644 --- a/xla/tests/dynamic_ops_test.cc +++ b/xla/tests/dynamic_ops_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index d5ab18cf5177e..7044fc90106ed 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -1,8 +1,8 @@ # Description: # Computationally expensive, exhaustive tests for XLA -load("//xla/tests:build_defs.bzl", "xla_test") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc b/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc index 7ab80b39d3f56..dcb25baca2e6e 100644 --- a/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc +++ b/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/exhaustive/exhaustive_binary_test_f32_f64.cc b/xla/tests/exhaustive/exhaustive_binary_test_f32_f64.cc index 68b7c1d3d3daf..0a1c6100a314e 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_f32_f64.cc +++ b/xla/tests/exhaustive/exhaustive_binary_test_f32_f64.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 7a6e1bb963ba7..881c394fe4610 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -321,10 +321,10 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { } // namespace template -void ExhaustiveOpTestBase::ExpectNear(const InputLiterals& input_literals, - const Literal& result_literal, - EvaluateOp evaluate_op, - ErrorSpecGen error_spec_gen) { +void ExhaustiveOpTestBase::ExpectNear( + const InputLiterals& input_literals, const Literal& result_literal, + EvaluateOp evaluate_op, ErrorSpecGen error_spec_gen, + OutputRangeCheck check_valid_range) { // Cache for when all components are subnormal testing values. std::vector pure_subnormal_cache; // Since we take the cross product of all possible test values, and each @@ -364,6 +364,16 @@ void ExhaustiveOpTestBase::ExpectNear(const InputLiterals& input_literals, static_cast(CallOperation(evaluate_op, inputs_ref_ty)); ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs); + if (check_valid_range != nullptr && !check_valid_range(actual)) { + PrintMismatch(&mismatches, [&] { + return absl::StrFormat( + "mismatch on input: %s. output: %s, output is not in valid range", + StringifyNum(inputs), + StringifyNum(actual)); + }); + continue; + } + if (IsClose(static_cast(expected), static_cast(actual), error_spec)) { continue; diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index ec52928c607f2..363ddb357e145 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include "xla/bit_cast.h" #include "xla/client/lib/constants.h" @@ -158,6 +159,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { using ErrorSpecGen = typename ErrorSpecGenWrapper::type; using EvaluateOp = typename EvaluateOpWrapper::type; using EnqueueOp = typename EnqueueOpWrapper::type; + using OutputRangeCheck = std::function; explicit ExhaustiveOpTestBase() : ty_(T), platform_(client_->platform()->Name()) { @@ -168,8 +170,10 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { mutable_debug_options()->clear_xla_disable_hlo_passes(); } - void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) { - Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator()); + void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, + OutputRangeCheck check_valid_range = nullptr) { + Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(), + check_valid_range); } // A helper for implementing the Run method for exhaustive op tests. It @@ -180,7 +184,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // called each time an output element is compared inside a loop in routine // ExpectNear. void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, - ErrorSpecGen error_spec_gen) { + ErrorSpecGen error_spec_gen, + OutputRangeCheck check_valid_range = nullptr) { InputLiterals input_literals = CreateInputLiterals(); FillInput(&input_literals); @@ -195,15 +200,16 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, RunComputationHelper(comp, input_literals)); - ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen); + ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen, + check_valid_range); } - StatusOr RunComputationHelper(const XlaComputation& comp, - const Literal& literal) { + absl::StatusOr RunComputationHelper(const XlaComputation& comp, + const Literal& literal) { return RunComputation(comp, {&literal}); } - StatusOr RunComputationHelper( + absl::StatusOr RunComputationHelper( const XlaComputation& comp, const std::array& literals) { std::array lit_ptrs; for (int i = 0; i < N; ++i) { @@ -220,15 +226,18 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // c) we need special handling of certain inputs. For example, we say that // a denormal input has multiple correct outputs (namely, f(x) and f(0)) // and just needs to be close to one of them. + // check_valid_range can be used to provide a function that is called with + // the result to check whether it is in the expected range. void ExpectNear(const InputLiterals& input_literals, const Literal& result_literal, EvaluateOp evaluate_op, - ErrorSpecGen error_spec_gen); + ErrorSpecGen error_spec_gen, + OutputRangeCheck check_valid_range = nullptr); // Builds and runs the computation using the LocalClient API, rather than the // plain Client API, which is used by ClientLibraryTestBase. This is because // the plain Client API results does more memcpys to/from Literals, and that's // slow given that we're touching a lot of data here. - StatusOr RunComputation( + absl::StatusOr RunComputation( const XlaComputation& computation, absl::Span input_literals) { // Copy debug options from ClientLibraryTestBase. In particular, we're diff --git a/xla/tests/exhaustive/exhaustive_unary_test_complex.cc b/xla/tests/exhaustive/exhaustive_unary_test_complex.cc index 604559d0a7862..5e6ab943bcb01 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_complex.cc +++ b/xla/tests/exhaustive/exhaustive_unary_test_complex.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/exhaustive/exhaustive_unary_test_f32_or_smaller.cc b/xla/tests/exhaustive/exhaustive_unary_test_f32_or_smaller.cc index 47a9906aa63e3..d908c4e167aee 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_f32_or_smaller.cc +++ b/xla/tests/exhaustive/exhaustive_unary_test_f32_or_smaller.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -176,17 +177,36 @@ template class Exhaustive32BitOrLessUnaryTest : public ExhaustiveUnaryTest, public ::testing::WithParamInterface> { + public: + static constexpr size_t kRandomInputSize = 2048; + + public: + Exhaustive32BitOrLessUnaryTest() + : input_lower_bounder_(0), + input_upper_bounder_(0), + special_input_bounder_(false) {} + public: // Sets error parameters appropriately for testing tan. void SetParamsForTan(); + void SetBounder(const float lower_bounder, const float upper_bounder) { + input_lower_bounder_ = lower_bounder; + input_upper_bounder_ = upper_bounder; + special_input_bounder_ = true; + } + protected: using typename ExhaustiveUnaryTest::NativeT; private: int64_t GetInputSize() override { int64_t begin, end; - std::tie(begin, end) = GetParam(); + if (special_input_bounder_) { + return kRandomInputSize; + } else { + std::tie(begin, end) = GetParam(); + } VLOG(2) << "Checking range [" << begin << ", " << end << ")"; return end - begin; } @@ -198,11 +218,21 @@ class Exhaustive32BitOrLessUnaryTest // the same bit as the type being tested, if needed, and then bitcasted to the // type being tested. void FillInput(std::array* input_literal) override { + int64_t begin, end; + if (special_input_bounder_) { + begin = input_lower_bounder_; + end = input_upper_bounder_; + FillRandomInput(input_literal, begin, end); + } else { + std::tie(begin, end) = GetParam(); + FillNormalInput(input_literal, begin, end); + } + } + void FillNormalInput(std::array* input_literal, + const int64_t begin, const int64_t end) { using IntegralT = typename ExhaustiveOpTestBase::ComponentIntegralNativeT; int64_t input_size = (*input_literal)[0].element_count(); - int64_t begin, end; - std::tie(begin, end) = GetParam(); VLOG(2) << "Checking range [" << begin << ", " << end << ")"; CHECK_EQ(input_size, end - begin); @@ -213,6 +243,25 @@ class Exhaustive32BitOrLessUnaryTest this->ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); } } + + void FillRandomInput(std::array* input_literal, + const int64_t begin, const int64_t end) { + absl::Span input_arr = (*input_literal)[0].data(); + + uint32_t size = kRandomInputSize; + NativeT inputs[kRandomInputSize]; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dist(static_cast(begin), + static_cast(end)); + for (uint32_t i = 0; i < size; ++i) { + inputs[i] = NativeT(dist(gen)); + input_arr[i] = inputs[i]; + } + } + float input_lower_bounder_; + float input_upper_bounder_; + bool special_input_bounder_; }; using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; @@ -353,6 +402,7 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Logistic, { // pow(x, 0.5), but this is not true for x == -inf. UNARY_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); }; + Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); }) @@ -468,7 +518,29 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { Run(Sinh, host_sinh); }) -UNARY_TEST_FLOAT_32_BITS_OR_LESS(Tanh, { +UNARY_TEST_FLOAT_32_BITS_OR_LESS(TanhBounderTestUpperBound, { + SetBounder(8, 9); + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ == "CUDA" || platform_ == "CPU") { + error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0}; }; + } + Run( + Tanh, +[](float) { return 1.0f; }, error_spec_gen, + [](NativeT actual) { return actual >= -1 && actual <= 1; }); +}) + +UNARY_TEST_FLOAT_32_BITS_OR_LESS(TanhBounderTestLowerBound, { + SetBounder(-9, -8); + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ == "CUDA" || platform_ == "CPU") { + error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0}; }; + } + Run( + Tanh, +[](float) { return -1.0f; }, error_spec_gen, + [](NativeT actual) { return actual >= -1 && actual <= 1; }); +}) + +UNARY_TEST_FLOAT_32_BITS_OR_LESS(TanhNormalTest, { ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "CUDA") { error_spec_gen = +[](NativeT x) { diff --git a/xla/tests/exhaustive/exhaustive_unary_test_f64.cc b/xla/tests/exhaustive/exhaustive_unary_test_f64.cc index 96ebfb7db8378..f95a03c209e87 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_f64.cc +++ b/xla/tests/exhaustive/exhaustive_unary_test_f64.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/fft_test.cc b/xla/tests/fft_test.cc index 40adafb8e9963..17ee3bea43f0d 100644 --- a/xla/tests/fft_test.cc +++ b/xla/tests/fft_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/filecheck.cc b/xla/tests/filecheck.cc index 5f261c2e7575d..5ef61bf653008 100644 --- a/xla/tests/filecheck.cc +++ b/xla/tests/filecheck.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/tests/filecheck.h" #include +#include #include "xla/types.h" #include "xla/util.h" @@ -24,11 +25,12 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/platform/resource_loader.h" #include "tsl/platform/subprocess.h" +#include "tsl/platform/test.h" namespace xla { -StatusOr RunFileCheck(const std::string& input, - absl::string_view pattern) { +absl::StatusOr RunFileCheck(const std::string& input, + absl::string_view pattern) { // Generate an input file for the FileCheck pattern. std::string pattern_path; auto env = tsl::Env::Default(); @@ -40,16 +42,34 @@ StatusOr RunFileCheck(const std::string& input, return RunFileCheckWithPatternFile(input, pattern_path); } -StatusOr RunFileCheckWithPatternFile(const std::string& input, - const std::string& pattern_file) { +absl::StatusOr RunFileCheckWithPatternFile( + const std::string& input, const std::string& pattern_file) { // Invoke FileCheck to check whether input matches `pattern`. + std::string binary_name = "FileCheck"; + tsl::io::AppendDotExeIfWindows(binary_name); std::string file_check_path = tsl::GetDataDependencyFilepath( - tsl::io::JoinPath("external", "llvm-project", "llvm", "FileCheck")); + tsl::testing::kIsOpenSource + ? tsl::io::JoinPath("external", "llvm-project", "llvm", binary_name) + : tsl::io::JoinPath("llvm", "llvm-project", "llvm", binary_name)); tsl::SubProcess file_check_process; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + std::string file_check_prefixes; +#if GOOGLE_CUDA + file_check_prefixes = "--check-prefixes=CHECK,CHECK-PTX"; +#endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM + file_check_prefixes = "--check-prefixes=CHECK,CHECK-GCN"; +#endif // TENSORFLOW_USE_ROCM + file_check_process.SetProgram( + file_check_path, + {file_check_path, "-v", "-dump-input=fail", "--dump-input-filter=all", + file_check_prefixes, "--allow-unused-prefixes", pattern_file}); +#else // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) file_check_process.SetProgram(file_check_path, {file_check_path, "-v", "-dump-input=fail", "--dump-input-filter=all", pattern_file}); +#endif file_check_process.SetChannelAction(tsl::CHAN_STDIN, tsl::ACTION_PIPE); file_check_process.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); if (!file_check_process.Start()) { diff --git a/xla/tests/filecheck.h b/xla/tests/filecheck.h index ee3f6e185e285..f03609f8bea4c 100644 --- a/xla/tests/filecheck.h +++ b/xla/tests/filecheck.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,14 +26,14 @@ namespace xla { // Runs FileCheck with the given pattern over given input string. Provided that // FileCheck can execute, returns true if and only if FileCheck succeeded in // matching the input. -StatusOr RunFileCheck(const std::string& input, - absl::string_view pattern); +absl::StatusOr RunFileCheck(const std::string& input, + absl::string_view pattern); // Runs FileCheck with the given pattern file over given input string. Provided // that FileCheck can execute, returns true if and only if FileCheck succeeded // in matching the input. -StatusOr RunFileCheckWithPatternFile(const std::string& input, - const std::string& pattern_file); +absl::StatusOr RunFileCheckWithPatternFile( + const std::string& input, const std::string& pattern_file); } // namespace xla diff --git a/xla/tests/float8_test.cc b/xla/tests/float8_test.cc index 7896eda569501..f770b495bb841 100644 --- a/xla/tests/float8_test.cc +++ b/xla/tests/float8_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/floor_ceil_test.cc b/xla/tests/floor_ceil_test.cc index f42c59f835a65..05543af7ef1e3 100644 --- a/xla/tests/floor_ceil_test.cc +++ b/xla/tests/floor_ceil_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/fmax_fmin_test.cc b/xla/tests/fmax_fmin_test.cc index 1839e54727eff..dbf34c8d699d5 100644 --- a/xla/tests/fmax_fmin_test.cc +++ b/xla/tests/fmax_fmin_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/fuzz/BUILD b/xla/tests/fuzz/BUILD index 15ea1ef5035a1..9f988e3c77143 100644 --- a/xla/tests/fuzz/BUILD +++ b/xla/tests/fuzz/BUILD @@ -16,6 +16,7 @@ cc_library( [hlo_test( name = hlo + "_test", hlo = hlo, + tags = (["no_rocm"] if hlo == "rand_000079.hlo" else []), # No int8 ) for hlo in glob( include = ["rand_*.hlo"], exclude = [ diff --git a/xla/tests/fuzz/hlo_test_lib.cc b/xla/tests/fuzz/hlo_test_lib.cc index da181e821712c..7b74e32306de2 100644 --- a/xla/tests/fuzz/hlo_test_lib.cc +++ b/xla/tests/fuzz/hlo_test_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/gather_operation_test.cc b/xla/tests/gather_operation_test.cc index 44d43141b8fe9..b80d1570a9b3b 100644 --- a/xla/tests/gather_operation_test.cc +++ b/xla/tests/gather_operation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/generate_complex_unary_op_samples.py b/xla/tests/generate_complex_unary_op_samples.py new file mode 100644 index 0000000000000..5dcb848f0eae3 --- /dev/null +++ b/xla/tests/generate_complex_unary_op_samples.py @@ -0,0 +1,231 @@ +"""A script to generate the complex_unary_op_samples.h file. + +The generated file contains samples and reference values of complex unary +functions used by the complex_unary_op_test program. + +Prerequisites: + jax version 0.4.26 or newer + mpmath 1.3 + numpy + +Usage: + Running + python /path/to/generate_complex_unary_op_samples.py + will create + /path/to/generate_complex_unary_op_samples.h +""" + +import os +import re +import sys +import jax._src.test_util as jtu +import mpmath +import numpy as np + + +def disable(op, real, imag): + del op, real, imag + # Return True to disable samples (real, imag) that are know to be + # problematic for the op. + return False + + +def main(): + default_size = 7 + nmp = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1) + blocks = [] + for opname in ['Log1p']: + mpmath_op = opname.lower() + size_re, size_im = dict(Log1p=(7, 7)).get( + opname, (default_size, default_size) + ) + ifblocks = [] + input_ttype = 'std::complex' + output_ttype = 'TBD' + for dtype in [np.complex64, np.complex128]: + float_dtype = {np.complex64: np.float32, np.complex128: np.float64}[dtype] + ctype = {np.float32: 'float', np.float64: 'double'}[float_dtype] + cnan = {np.float32: 'std::nanf("")', np.float64: 'std::nan("")'}[ + float_dtype + ] + pi = float_dtype(np.pi) + h_pi = float_dtype(np.pi / 2) + q_pi = float_dtype(np.pi / 4) + tq_pi = float_dtype(3 * np.pi / 4) + cfloat_suffix = 'f' if float_dtype == np.float32 else '' + cpi = str(pi) + cfloat_suffix + cpi_2 = str(h_pi) + cfloat_suffix + cpi_4 = str(q_pi) + cfloat_suffix + cpi3_4 = str(tq_pi) + cfloat_suffix + czero = str(float_dtype(0)) + cfloat_suffix + + sample = jtu.complex_plane_sample(dtype, size_re=size_re, size_im=size_im) + values = getattr(nmp, mpmath_op)(sample) + finfo = np.finfo(float_dtype) + + # pylint: disable=cell-var-from-loop + def _tostr(v): + if v == pi: + return 'pi' + if v == -pi: + return '-pi' + if v == h_pi: + return 'pi_2' + if v == -h_pi: + return '-pi_2' + if v == q_pi: + return 'pi_4' + if v == -q_pi: + return '-pi_4' + if v == tq_pi: + return 'pi3_4' + if v == -tq_pi: + return '-pi3_4' + if v == finfo.max: + return 'max' + if v == -finfo.max: + return '-max' + if v == finfo.tiny: + return 'min' + if v == -finfo.tiny: + return '-min' + if np.isnan(v): + return 'nan' + if np.isneginf(v): + return '-inf' + if np.isposinf(v): + return 'inf' + if v == 0.0: + return 'zero' + if float_dtype == np.float32: + s = f'{v:.6e}f' + elif float_dtype == np.float64: + s = f'{v:.15e}' + else: + assert 0 # unreachable + return re.sub(r'0+e', 'e', s) + + used_constants = set() + + def tostr(v): + r = _tostr(v) + used_constants.add(r.removeprefix('-')) + return r + + rows = [] + counter = 0 + for x, y in zip(sample.flatten(), values.flatten()): + re_x, im_x = tostr(x.real), tostr(x.imag) + if disable(opname, re_x, im_x): + prefix = '// ' + else: + # to ease tracking mismatching cases: + prefix = f'/* {counter} */ ' + counter += 1 + if values.dtype.kind == 'c': + output_ttype = 'std::complex' + re_y, im_y = tostr(y.real), tostr(y.imag) + scale = tostr(np.ldexp(1.0, -np.frexp(abs(y))[1])) + rows.append( + f'{prefix}{{ {{ {re_x}, {im_x} }}, {{ {re_y}, {im_y} }},' + f' {scale} }}' + ) + else: + assert values.dtype.kind == 'f' + output_ttype = 'T' + # Scale is power of 2 so that multiplication with + # it has minimal effect to the binary mantissa + # part of other operand. + scale = tostr(np.ldexp(1.0, -np.frexp(abs(y))[1])) + rows.append( + f'{prefix}{{ {{ {re_x}, {im_x} }}, {tostr(y)}, {scale} }}' + ) + rows = ',\n '.join(rows) + + constants = [] + for name, value in dict( + nan=cnan, + pi=cpi, + pi_4=cpi_4, + pi_2=cpi_2, + pi3_4=cpi3_4, + zero=czero, + ).items(): + if name in used_constants: + constants.append(f'const T {name} = {value};') + constants = '\n '.join(constants) + + ifblocks.append(f"""\ +if constexpr (std::is_same_v) {{ + {constants} + const TableType table{{ + {rows} + }}; + return table; + }}""") + ifblocks.append('{ static_assert(false); /* unreachable */ }') + ifblocks = ' else '.join(ifblocks) + blocks.append(f""" + template + struct {opname} {{ + typedef {input_ttype} InputType; + typedef {output_ttype} OutputType; + typedef T FloatType; + using TableType = std::vector>; + static constexpr int dps_deficiency = default_dps_deficiency; + const TableType get() {{ + const T inf = std::numeric_limits::infinity(); + const T min = std::numeric_limits::min(); + const T max = std::numeric_limits::max(); + {ifblocks} + }} + }}; +""") + blocks = '\n'.join(blocks) + + output_filename = os.path.join( + os.path.dirname(__file__), 'complex_unary_op_samples.h' + ) + output = open(output_filename, 'w') + + output.write(f"""\ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + This file is generated using xla/tests/{os.path.basename(__file__)}. Do not edit! + */ + +#include +#include +#include +#include +#include + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ + +namespace complex_unary_op_samples {{ +{blocks} +}} + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_COMPLEX_UNARY_OP_SAMPLES_H_ +""") + output.close() + sys.stdout.write(f'Created {output_filename}\n') + + +if __name__ == '__main__': + main() diff --git a/xla/tests/get_dimension_size_test.cc b/xla/tests/get_dimension_size_test.cc index 9d43d0d24469d..9538eb4d27263 100644 --- a/xla/tests/get_dimension_size_test.cc +++ b/xla/tests/get_dimension_size_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/grouped_convolution_test.cc b/xla/tests/grouped_convolution_test.cc index 3e5f9094bd4a4..45cc3fe4112eb 100644 --- a/xla/tests/grouped_convolution_test.cc +++ b/xla/tests/grouped_convolution_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/half_test.cc b/xla/tests/half_test.cc index c083da3a9eef4..0933a85a261c7 100644 --- a/xla/tests/half_test.cc +++ b/xla/tests/half_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/hlo_metadata_test.cc b/xla/tests/hlo_metadata_test.cc index d8d00d0442a5c..ed5260426044e 100644 --- a/xla/tests/hlo_metadata_test.cc +++ b/xla/tests/hlo_metadata_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,7 +60,6 @@ TEST_F(HloMetadataTest, MetadataPropagation) { ->root_instruction(); EXPECT_THAT(instruction->metadata().op_type(), StrEq("add")); EXPECT_THAT(instruction->metadata().op_name(), StrEq("my_sum_op")); - EXPECT_NE(instruction->metadata().logical_creation_pass_id(), 0); } TEST_F(HloMetadataTest, MetadataClearing) { diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index e9358e88f53be..06c75fec1fc0e 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/layout_util.h" @@ -97,6 +98,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform, verifier_layout_sensitive_(verifier_layout_sensitive), allow_mixed_precision_in_hlo_verifier_( allow_mixed_precision_in_hlo_verifier), + instruction_can_change_layout_func_(instruction_can_change_layout_func), test_platform_(test_platform) { hlo_verifier_ = std::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, @@ -127,28 +129,26 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( return std::make_unique( name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, - backend().compiler()->ShapeSizeBytesFunction()); + backend().compiler()->ShapeSizeBytesFunction(), + instruction_can_change_layout_func_); } -StatusOr> +absl::StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions) { - TF_ASSIGN_OR_RETURN( - auto module, - ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions))); - UpdateEntryComputationLayout(module.get()); - return module; + return ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); } -StatusOr> +absl::StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = std::make_unique( TestName(), config, verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, - backend().compiler()->ShapeSizeBytesFunction()); + backend().compiler()->ShapeSizeBytesFunction(), + instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); UpdateEntryComputationLayout(module.get()); return std::move(module); @@ -167,8 +167,8 @@ void HloTestBase::UpdateEntryComputationLayout(HloModule* module) { } /* static */ -StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, - HloModule* module) { +absl::StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, + HloModule* module) { const std::string module_str_before_run = module->ToProto().ShortDebugString(); const auto status_or = hlo_pass->Run(module); @@ -188,8 +188,8 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, } /* static */ -StatusOr HloTestBase::RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group) { +absl::StatusOr HloTestBase::RunHloPass(HloPassInterface&& hlo_pass, + HloModuleGroup* module_group) { const std::string module_group_str_before_run = module_group->ToProto().ShortDebugString(); const auto status_or = hlo_pass.RunOnModuleGroup(module_group); @@ -291,8 +291,8 @@ void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( } } -StatusOr HloTestBase::Execute(std::unique_ptr module, - absl::Span arguments) { +absl::StatusOr HloTestBase::Execute( + std::unique_ptr module, absl::Span arguments) { return runner_->Execute(std::move(module), arguments); } @@ -304,11 +304,12 @@ Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, .value(); } -StatusOr> HloTestBase::GetHloRunner() { +absl::StatusOr> +HloTestBase::GetHloRunner() { if (runner_ != nullptr) { return std::move(runner_); } - StatusOr> status_or_runner = + absl::StatusOr> status_or_runner = GetHloRunnerForTest(test_platform_); // Test for successful creation of PjRt based Hlo Runner. @@ -322,7 +323,7 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return runner_->Execute(std::move(module), arguments, true, nullptr).value(); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, bool use_threads, bool run_hlo_passes) { HloRunner::ReplicatedExecuteOptions options; @@ -336,7 +337,7 @@ StatusOr> HloTestBase::ExecuteReplicated( return runner_->ExecuteReplicated(std::move(module), options); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads) { @@ -351,7 +352,7 @@ StatusOr> HloTestBase::ExecuteReplicated( device_assignment); } -StatusOr> HloTestBase::ExecuteReplicated( +absl::StatusOr> HloTestBase::ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -366,7 +367,7 @@ StatusOr> HloTestBase::ExecuteReplicated( options, device_assignment); } -StatusOr> HloTestBase::MakeReferenceModule( +absl::StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { std::unique_ptr reference_module = test_module.Clone(); @@ -384,7 +385,7 @@ StatusOr> HloTestBase::MakeReferenceModule( return std::move(reference_module); } -StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( +absl::StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::unique_ptr module, const absl::Span arguments, const optional& error, bool run_hlo_passes, @@ -490,7 +491,7 @@ ::testing::AssertionResult HloTestBase::RunAndCompare( reference_preprocessor); } -StatusOr<::testing::AssertionResult> +absl::StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, const absl::Span arguments, @@ -523,7 +524,8 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( std::unique_ptr module_0, std::unique_ptr module_1, - const optional& error, bool run_hlo_passes) { + const optional& error, bool run_hlo_passes, + std::optional args_max_bits_of_precision) { const auto params_0 = module_0->entry_computation()->parameter_instructions(); const auto params_1 = module_1->entry_computation()->parameter_instructions(); for (int i = 0; i < params_0.size(); ++i) { @@ -559,11 +561,14 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( } } - auto fake_arguments = MakeFakeArguments(module_0.get()).value(); + absl::StatusOr> fake_arguments = MakeFakeArguments( + module_0.get(), /*pseudo_random=*/true, /*use_large_range=*/false, + /*treat_gte_as_data_formatting=*/false, args_max_bits_of_precision); + CHECK_OK(fake_arguments); std::vector fake_argument_ptrs; absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), + *fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareTwoModules(std::move(module_0), std::move(module_1), @@ -572,7 +577,8 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( string_view hlo_string_module_0, string_view hlo_string_module_1, - const std::optional& error, bool run_hlo_passes) { + const std::optional& error, bool run_hlo_passes, + std::optional args_max_bits_of_precision) { auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); if (!module_0_or_status.ok()) { return ::testing::AssertionFailure() @@ -588,21 +594,69 @@ ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( } return RunAndCompareTwoModules(std::move(module_0_or_status).value(), std::move(module_1_or_status).value(), error, - run_hlo_passes); + run_hlo_passes, args_max_bits_of_precision); +} + +::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( + string_view hlo_string_module_0, string_view hlo_string_module_1, + const HloModuleConfig& config_0, const HloModuleConfig& config_1, + const std::optional& error, bool run_hlo_passes, + std::optional args_max_bits_of_precision) { + auto module_0_or_status = + ParseAndReturnVerifiedModule(hlo_string_module_0, config_0); + if (!module_0_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0_or_status.status().ToString(); + } + + auto module_1_or_status = + ParseAndReturnVerifiedModule(hlo_string_module_1, config_1); + if (!module_1_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1_or_status.status().ToString(); + } + return RunAndCompareTwoModules(std::move(module_0_or_status).value(), + std::move(module_1_or_status).value(), error, + run_hlo_passes, args_max_bits_of_precision); +} + +::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + const absl::Span arguments, + const std::optional& error, bool run_hlo_passes) { + auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); + if (!module_0_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0_or_status.status().ToString(); + } + + auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1); + if (!module_1_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1_or_status.status().ToString(); + } + return RunAndCompareTwoModules(std::move(module_0_or_status).value(), + std::move(module_1_or_status).value(), + arguments, error, run_hlo_passes); } ::testing::AssertionResult HloTestBase::Run( string_view hlo_string, bool run_hlo_passes, ExecutionProfile* profile, - const tsl::protobuf::Message* backend_config) { + const tsl::protobuf::Message* backend_config, bool use_random_data) { auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); if (!module_or_status.ok()) { return ::testing::AssertionFailure() << "Error while parsing HLO text format: " << module_or_status.status().ToString(); } - std::unique_ptr module = std::move(module_or_status.value()); - const auto fake_arguments = MakeFakeArguments(module.get()).value(); + const auto fake_arguments = + MakeFakeArguments(module.get(), use_random_data).value(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), @@ -732,7 +786,7 @@ ::testing::AssertionResult HloTestBase::RunMultipleTimes( std::optional canonical_output; for (int i = 0; i < n; ++i) { - StatusOr output = + absl::StatusOr output = runner_->ExecuteWithExecutable(executables[i].get(), fake_arguments[i], /*profile=*/&((*profiles)[i])); if (!output.ok()) { @@ -845,13 +899,13 @@ void HloTestBase::MatchOptimizedHlo(absl::string_view hlo, GetOptimizedModule(hlo)); HloPrintOptions print_opts; print_opts.set_print_operand_shape(print_operand_shape); - StatusOr filecheck_result = + absl::StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(print_opts), pattern); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.value()); } -StatusOr> HloTestBase::GetOptimizedModule( +absl::StatusOr> HloTestBase::GetOptimizedModule( absl::string_view hlo) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, @@ -861,15 +915,15 @@ StatusOr> HloTestBase::GetOptimizedModule( backend().default_stream_executor()->GetAllocator()); } -StatusOr> HloTestBase::GetOptimizedModule( +absl::StatusOr> HloTestBase::GetOptimizedModule( std::unique_ptr hlo_module) { return backend().compiler()->RunHloPasses( std::move(hlo_module), backend().default_stream_executor(), backend().default_stream_executor()->GetAllocator()); } -StatusOr> HloTestBase::GetHloRunnerForTest( - se::Platform* test_platform) { +absl::StatusOr> +HloTestBase::GetHloRunnerForTest(se::Platform* test_platform) { if (ShouldUsePjRt()) { PjRtClientTestFactoryRegistry& pjrt_registry = GetGlobalPjRtClientTestFactory(); diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index b895be5a36fb2..f8f91a8df51b9 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -90,11 +90,13 @@ class HloTestBase : public ManifestCheckingTest { const std::string& name = TestName(), int64_t replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count = 1, - int64_t num_partitions = 1); - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count = 1, + int64_t num_partitions = 1); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -104,14 +106,14 @@ class HloTestBase : public ManifestCheckingTest { // `RunHloPass(MyPass(), module)` all in one line. The reason for the // overload that takes a pointer is that, at one point in the past, non-const // lvalue references were banned in Google code. - static StatusOr RunHloPass(HloPassInterface* hlo_pass, - HloModule* module); - static StatusOr RunHloPass(HloPassInterface& hlo_pass, - HloModule* module) { + static absl::StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); + static absl::StatusOr RunHloPass(HloPassInterface& hlo_pass, + HloModule* module) { return RunHloPass(&hlo_pass, module); } - static StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModule* module) { + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModule* module) { return RunHloPass(&hlo_pass, module); } @@ -119,8 +121,8 @@ class HloTestBase : public ManifestCheckingTest { // This method runs the input HLO module group pass for a `HloModuleGroup` and // it also verifies the module group remains unchanged when hlo_pass returns // false as the StatusOr value. - static StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group); + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModuleGroup* module_group); static PrecisionConfig DefaultPrecisionConfig(int operands); @@ -140,10 +142,10 @@ class HloTestBase : public ManifestCheckingTest { } // Compiles and returns module with optimizations from a given HLO. - StatusOr> GetOptimizedModule( + absl::StatusOr> GetOptimizedModule( absl::string_view hlo); - StatusOr> GetOptimizedModule( + absl::StatusOr> GetOptimizedModule( std::unique_ptr hlo_module); protected: @@ -202,8 +204,8 @@ class HloTestBase : public ManifestCheckingTest { } // Executes the given module and return the result as a Literal. - StatusOr Execute(std::unique_ptr module, - absl::Span arguments); + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. @@ -214,7 +216,7 @@ class HloTestBase : public ManifestCheckingTest { absl::Span arguments); // Compile the given module to an executable. - StatusOr> CreateExecutable( + absl::StatusOr> CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { return runner_->CreateExecutable(std::move(module), run_hlo_passes); } @@ -224,18 +226,18 @@ class HloTestBase : public ManifestCheckingTest { // use_threads indicates whether this replicated computation will be executed // with a thread-per-replica, vs using an implicitly async call such as // Executable::ExecuteOnStreams. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); // Same as above, but uses specified device assignment. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::unique_ptr module, absl::Span arguments, int64_t num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads); // Same as above, but allows passing different programs for replicas. - StatusOr> ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( std::function executable_provider, std::function argument_count_provider, std::function argument_provider, @@ -292,7 +294,8 @@ class HloTestBase : public ManifestCheckingTest { [[nodiscard]] ::testing::AssertionResult Run( const absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, - const tsl::protobuf::Message* backend_config = nullptr); + const tsl::protobuf::Message* backend_config = nullptr, + bool use_random_data = true); // Same as below, except requires passing fake arguments. ::testing::AssertionResult RunAndCompareTwoModules( @@ -303,7 +306,8 @@ class HloTestBase : public ManifestCheckingTest { // Same as below, except requires passing the modules. ::testing::AssertionResult RunAndCompareTwoModules( std::unique_ptr module_0, std::unique_ptr module_1, - const std::optional& error, bool run_hlo_passes = true); + const std::optional& error, bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); // Convenient wrapper for executing and comparing results of two hlo modules // with fake input. By default compares unoptimized modules. If the modules @@ -311,6 +315,22 @@ class HloTestBase : public ManifestCheckingTest { ::testing::AssertionResult RunAndCompareTwoModules( absl::string_view hlo_string_module_0, absl::string_view hlo_string_module_1, + const std::optional& error, bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + + // Same as above but allows running with different configs. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, const HloModuleConfig& config_0, + const HloModuleConfig& config_1, const std::optional& error, + bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + + // Same as above but requires explicit arguments. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + absl::Span arguments, const std::optional& error, bool run_hlo_passes = true); // Executes an hlo module with fake inputs on multiple replicas. @@ -393,6 +413,7 @@ class HloTestBase : public ManifestCheckingTest { bool verifier_layout_sensitive_; bool allow_mixed_precision_in_hlo_verifier_; + HloPredicate instruction_can_change_layout_func_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; @@ -401,7 +422,7 @@ class HloTestBase : public ManifestCheckingTest { HloModule*, std::unique_ptr computation); void UpdateEntryComputationLayout(HloModule* module); - StatusOr> GetHloRunner(); + absl::StatusOr> GetHloRunner(); protected: // Helper functions to get test and reference platforms. @@ -416,14 +437,14 @@ class HloTestBase : public ManifestCheckingTest { // Given the test module, makes a reference module that is ready to run on the // reference platform. This assumes that the given module is ready to run on // the test platform. - StatusOr> MakeReferenceModule( + absl::StatusOr> MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor); // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - StatusOr<::testing::AssertionResult> RunAndCompareInternal( + absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, const absl::Span arguments, const std::optional& error, bool run_hlo_passes, @@ -432,14 +453,14 @@ class HloTestBase : public ManifestCheckingTest { // Runs the two module on with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( + absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( std::unique_ptr module_0, std::unique_ptr module_1, const absl::Span arguments, const std::optional& error, bool run_hlo_passes); // Returns either an HloRunner or HloRunnerPjRt implementation depending if // there exists a registered PjRtClientFactory. - StatusOr> GetHloRunnerForTest( + absl::StatusOr> GetHloRunnerForTest( se::Platform* test_platform); }; diff --git a/xla/tests/int4_test.cc b/xla/tests/int4_test.cc index 84e2e0b98a91c..6d83d9489aee4 100644 --- a/xla/tests/int4_test.cc +++ b/xla/tests/int4_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,6 +61,20 @@ XLA_TEST_F(HloTestBase, Slice) { EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); } +XLA_TEST_F(HloTestBase, Add) { + const std::string hlo_text = R"( + HloModule HorizontalLoopFusion + + ENTRY main { + x = s4[5,5] parameter(0) + x8 = s8[5,5] convert(x) + y8 = add(x8, x8) + ROOT y = s4[5,5] convert(y8) + } +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); +} + XLA_TEST_F(HloTestBase, NonMajorToMinorLayout) { // Tests transposing a matrix with a non-major-to-minor layout. const std::string hlo_text = R"( @@ -121,5 +135,27 @@ XLA_TEST_F(HloTestBase, Scalar) { EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); } +XLA_TEST_F(HloTestBase, HorizontalLoopFusion) { + // Tests an HLO module where horizontal loop fusion can be done on GPUs + const std::string hlo_text = R"( + HloModule HorizontalLoopFusion + + ENTRY main { + x4 = s4[10] parameter(0) + x8 = s8[10] convert(x4) + y8 = s8[10] add(x8, x8) + y4 = s4[10] convert(y8) + + x4_b = s4[13] parameter(1) + x8_b = s8[13] convert(x4_b) + y8_b = s8[13] add(x8_b, x8_b) + y4_b = s4[13] convert(y8_b) + + ROOT t = (s4[10], s4[13]) tuple(y4, y4_b) + } +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); +} + } // namespace } // namespace xla diff --git a/xla/tests/iota_test.cc b/xla/tests/iota_test.cc index 78fb328cd2d00..a0f3fe37fe685 100644 --- a/xla/tests/iota_test.cc +++ b/xla/tests/iota_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/literal_test_util.cc b/xla/tests/literal_test_util.cc index f42659858422a..df898b1cedea8 100644 --- a/xla/tests/literal_test_util.cc +++ b/xla/tests/literal_test_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/literal_test_util.h b/xla/tests/literal_test_util.h index 0e4f355986d2e..01b2aa6433c30 100644 --- a/xla/tests/literal_test_util.h +++ b/xla/tests/literal_test_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/literal_test_util_test.cc b/xla/tests/literal_test_util_test.cc index 738e88eddcaa8..4912a37255d9d 100644 --- a/xla/tests/literal_test_util_test.cc +++ b/xla/tests/literal_test_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/llvm_compiler_test.cc b/xla/tests/llvm_compiler_test.cc index 39acb92460845..5f826254f2524 100644 --- a/xla/tests/llvm_compiler_test.cc +++ b/xla/tests/llvm_compiler_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -69,7 +69,7 @@ class GpuDummyCompiler : public GpuCompiler { return OkStatus(); } - StatusOr CompileTargetBinary( + absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override { @@ -88,7 +88,7 @@ class LLVMCompilerTest : public ::testing::Test { BackendOptions backend_options; backend_options.set_platform(platform); - StatusOr> backend_or_status = + absl::StatusOr> backend_or_status = Backend::CreateBackend(backend_options); ASSERT_IS_OK(backend_or_status.status()); backend_ = std::move(backend_or_status).value(); diff --git a/xla/tests/llvm_irgen_test_base.cc b/xla/tests/llvm_irgen_test_base.cc index 5df038b5b3a9a..59b16529b4df1 100644 --- a/xla/tests/llvm_irgen_test_base.cc +++ b/xla/tests/llvm_irgen_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,7 +56,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr( ResetIrHook(); TF_ASSERT_OK(status); - StatusOr filecheck_result = RunFileCheck(ir_, pattern); + absl::StatusOr filecheck_result = RunFileCheck(ir_, pattern); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.value()) << "Full IR: " << ir_; } @@ -82,7 +82,7 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( ResetIrHook(); TF_ASSERT_OK(status); - StatusOr filecheck_result = RunFileCheck(ir_, pattern); + absl::StatusOr filecheck_result = RunFileCheck(ir_, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.value()) << "Full IR: " << ir_; } diff --git a/xla/tests/llvm_irgen_test_base.h b/xla/tests/llvm_irgen_test_base.h index d6cd0cdd0febb..890e5bacb65ed 100644 --- a/xla/tests/llvm_irgen_test_base.h +++ b/xla/tests/llvm_irgen_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/local_client_allocation_test.cc b/xla/tests/local_client_allocation_test.cc index a6022b12419b6..5584c378f041a 100644 --- a/xla/tests/local_client_allocation_test.cc +++ b/xla/tests/local_client_allocation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/local_client_aot_test.cc b/xla/tests/local_client_aot_test.cc index fe2034d649efc..28b2b92427140 100644 --- a/xla/tests/local_client_aot_test.cc +++ b/xla/tests/local_client_aot_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/local_client_aot_test_helper.cc b/xla/tests/local_client_aot_test_helper.cc index 2fcd11a9cd55f..bb6fdc0cb0bff 100644 --- a/xla/tests/local_client_aot_test_helper.cc +++ b/xla/tests/local_client_aot_test_helper.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -65,6 +65,8 @@ int main(int argc, char** argv) { std::string target_cpu = argv[1]; if (target_cpu == "k8") { triple_string = "x86_64-none-linux-gnu"; + } else if (target_cpu == "darwin_arm64") { + triple_string = "arm64-apple-darwin"; } else if (target_cpu == "darwin") { triple_string = "x86_64-apple-macosx"; } else if ((target_cpu == "arm") || (target_cpu == "aarch64")) { diff --git a/xla/tests/local_client_execute_test.cc b/xla/tests/local_client_execute_test.cc index d6736f6152e7f..1c94d82ef387e 100644 --- a/xla/tests/local_client_execute_test.cc +++ b/xla/tests/local_client_execute_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" #include "xla/tests/literal_test_util.h" @@ -653,12 +654,11 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { } se::StreamExecutor* executor = local_client_->platform()->ExecutorForDevice(d).value(); - se::Stream stream(executor); - stream.Init(); + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - auto result = - ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(), - DefaultExecutableRunOptions().set_stream(&stream)); + auto result = ExecuteLocallyOrDie( + computation, {}, DefaultExecutableBuildOptions(), + DefaultExecutableRunOptions().set_stream(stream.get())); // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); @@ -673,16 +673,16 @@ XLA_TEST_F(LocalClientExecuteTest, // Try to run a computation on a stream for a platform (CPU) which does not // match the platform of the service (!= CPU). se::Platform* wrong_platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .value(); - se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).value()); - wrong_stream.Init(); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId).value(); + TF_ASSERT_OK_AND_ASSIGN( + auto wrong_stream, + wrong_platform->ExecutorForDevice(0).value()->CreateStream()); XlaBuilder builder(TestName()); ConstantR0(&builder, 42.0f); auto execute_status = ExecuteLocally( builder.Build().value(), {}, DefaultExecutableBuildOptions(), - DefaultExecutableRunOptions().set_stream(&wrong_stream)); + DefaultExecutableRunOptions().set_stream(wrong_stream.get())); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().message(), ContainsRegex("stream is for platform .*, but service targets")); @@ -691,8 +691,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) { se::Platform* wrong_platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .value(); + se::PlatformManager::PlatformWithId(se::host::kHostPlatformId).value(); TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); @@ -706,27 +705,6 @@ XLA_TEST_F(LocalClientExecuteTest, ContainsRegex("allocator platform .* does not match service")); } -XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { - // Try to run a computation on a stream that has not been initialized. - XlaBuilder builder(TestName()); - ConstantR0(&builder, 42.0f); - - LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); - se::StreamExecutor* executor = - local_client_->platform() - ->ExecutorForDevice(local_client_->default_device_ordinal()) - .value(); - se::Stream stream(executor); - // Don't call stream.Init(). - - auto execute_status = ExecuteLocally( - builder.Build().value(), {}, DefaultExecutableBuildOptions(), - DefaultExecutableRunOptions().set_stream(&stream)); - EXPECT_FALSE(execute_status.ok()); - EXPECT_THAT(execute_status.status().message(), - ContainsRegex("stream is uninitialized or in an error state")); -} - XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); diff --git a/xla/tests/local_client_test_base.cc b/xla/tests/local_client_test_base.cc index 09b41b21b2a1e..8df3e5eb11aca 100644 --- a/xla/tests/local_client_test_base.cc +++ b/xla/tests/local_client_test_base.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,10 +39,9 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64_t size, - bool retry_on_failure, - int64_t memory_space) { +absl::StatusOr TestAllocator::Allocate( + int device_ordinal, uint64_t size, bool retry_on_failure, + int64_t memory_space) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { absl::MutexLock lock(&count_mutex_); @@ -173,14 +172,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( .value(); } -StatusOr LocalClientTestBase::ExecuteLocally( +absl::StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } -StatusOr LocalClientTestBase::ExecuteLocally( +absl::StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, absl::Span arguments, const ExecutableBuildOptions& build_options, @@ -208,12 +207,12 @@ StatusOr LocalClientTestBase::ExecuteLocally( return std::move(ret); } -StatusOr> +absl::StatusOr> LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig()); } -StatusOr> +absl::StatusOr> LocalClientTestBase::ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config) { auto module = std::make_unique( diff --git a/xla/tests/local_client_test_base.h b/xla/tests/local_client_test_base.h index 7c7128326a966..7f99d8471442f 100644 --- a/xla/tests/local_client_test_base.h +++ b/xla/tests/local_client_test_base.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,9 +47,9 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { : se::StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).value()) {} - StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override; + absl::StatusOr Allocate( + int device_ordinal, uint64_t size, bool retry_on_failure, + int64_t memory_space) override; Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. @@ -93,10 +93,10 @@ class LocalClientTestBase : public ManifestCheckingTest { // Execute the given computation on the local client. With and without // options. - StatusOr ExecuteLocally( + absl::StatusOr ExecuteLocally( const XlaComputation& computation, absl::Span arguments); - StatusOr ExecuteLocally( + absl::StatusOr ExecuteLocally( const XlaComputation& computation, absl::Span arguments, const ExecutableBuildOptions& build_options, @@ -112,10 +112,11 @@ class LocalClientTestBase : public ManifestCheckingTest { const ExecutableRunOptions& run_options); // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text); - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config); // Returns a default set of execute options. ExecutableBuildOptions DefaultExecutableBuildOptions() const; diff --git a/xla/tests/log_test.cc b/xla/tests/log_test.cc index 861a30de0590c..f5bc530f0c1c8 100644 --- a/xla/tests/log_test.cc +++ b/xla/tests/log_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/manifest_checking_test.cc b/xla/tests/manifest_checking_test.cc index 686e2d5455ff4..ac4230a99ffbe 100644 --- a/xla/tests/manifest_checking_test.cc +++ b/xla/tests/manifest_checking_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/manifest_checking_test.h b/xla/tests/manifest_checking_test.h index d7405b2d5b748..66b4f74b4fc4a 100644 --- a/xla/tests/manifest_checking_test.h +++ b/xla/tests/manifest_checking_test.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/map_test.cc b/xla/tests/map_test.cc index be72acd88e567..e72605bc4249a 100644 --- a/xla/tests/map_test.cc +++ b/xla/tests/map_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -489,7 +489,7 @@ TEST_F(MapTest, MapOperationWithBuildError) { auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); - StatusOr computation_status = builder.Build(); + absl::StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " diff --git a/xla/tests/matmul_test.cc b/xla/tests/matmul_test.cc index 3f070d8cc3480..61e97d15a3882 100644 --- a/xla/tests/matmul_test.cc +++ b/xla/tests/matmul_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,19 @@ class MatmulTestWithCublas : public HloTestBase, debug_options.set_xla_gpu_enable_cublaslt(use_cublas_lt_); return debug_options; } + void SetUp() override { + auto dbg = GetDebugOptionsForTest(); + if (dbg.xla_gpu_enable_cublaslt()) { + const auto& gpu_cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + if (auto* rocm = std::get_if(&gpu_cc); + rocm != nullptr && !rocm->has_hipblaslt()) { + GTEST_SKIP() << "No hipblas-lt support on this architecture!"; + } + } + } private: const bool use_cublas_lt_{GetParam()}; diff --git a/xla/tests/matrix_ops_simple_test.cc b/xla/tests/matrix_ops_simple_test.cc index 236a722423cde..13a34e6effaa9 100644 --- a/xla/tests/matrix_ops_simple_test.cc +++ b/xla/tests/matrix_ops_simple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/multidimensional_slice_test.cc b/xla/tests/multidimensional_slice_test.cc index 3313c63df4e71..1c15cd2ac94fc 100644 --- a/xla/tests/multidimensional_slice_test.cc +++ b/xla/tests/multidimensional_slice_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/multioutput_fusion_test.cc b/xla/tests/multioutput_fusion_test.cc index 0b14df8dcbfb7..c2220217b12de 100644 --- a/xla/tests/multioutput_fusion_test.cc +++ b/xla/tests/multioutput_fusion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -190,36 +190,6 @@ XLA_TEST_F(MultiOutputFusionTest, DifferentTypesNoFusion) { } XLA_TEST_F(MultiOutputFusionTest, DifferentTypesFusion) { RunTest1D(true, 8); } -XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { - const char* testcase = R"( - HloModule m, is_scheduled=true - - fused_computation { - x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) - gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0 - gte.2 = (s32[]) get-tuple-element(gte.3), index=0 - gte.4 = s32[] get-tuple-element(gte.2), index=0 - copy = s32[] copy(gte.4) - ROOT tuple = (s32[]) tuple(copy) - } - - ENTRY thing.v3 { - x = (((s32[]), f32[]), (f32[], s32[])) parameter(0) - ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation - } - )"; - auto module = ParseAndReturnVerifiedModule(testcase).value(); - auto param = LiteralUtil::MakeTupleOwned( - LiteralUtil::MakeTupleOwned( - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), - LiteralUtil::CreateR0(1.0)), - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), - LiteralUtil::CreateR0(4))); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); -} - XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const char* testcase = R"( HloModule m, is_scheduled=true @@ -319,8 +289,7 @@ const char* const kScalarOps = R"( } )"; -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionMinor) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[32,32,32]{2,1,0} parameter(0) @@ -341,8 +310,7 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionMajor) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[32,32,32]{2,1,0} parameter(0) @@ -363,8 +331,7 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionScalar) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,32,32]{2,1,0} parameter(0) @@ -386,8 +353,7 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionMinorWithExtraOutput) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,32,32]{2,1,0} parameter(0) @@ -409,8 +375,7 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionMajorWithExtraOutput) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[32,32,2]{2,1,0} parameter(0) @@ -433,7 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest, } XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + MultiOutputReduceFusionScalarWithExtraOutput) { const std::string testcase = R"( HloModule m, is_scheduled=true @@ -463,8 +428,7 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } -XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceFusionNonConstInit) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,32,32]{2,1,0} parameter(0) @@ -487,7 +451,7 @@ XLA_TEST_F(MultiOutputFusionTest, } XLA_TEST_F(MultiOutputFusionTest, - DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { + MultiOutputReduceFusionDifferentElementTypes) { const std::string testcase = absl::StrCat(kScalarOps, R"( fused_reduce (p0: f16[2,32,32]) -> (f32[2,32], f32[2,32], f16[2,32,32]) { p0 = f16[2,32,32]{2,1,0} parameter(0) @@ -510,5 +474,112 @@ XLA_TEST_F(MultiOutputFusionTest, EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceCanonicalizationIsSame) { + const std::string testcase = absl::StrCat(kScalarOps, R"( +fused_computation { + param_0 = f32[64,128]{1,0} parameter(0) + neg = f32[64,128]{1,0} negate(param_0) + bitcast = f32[8,8,128]{2,1,0} bitcast(neg) + constant_0 = f32[] constant(0) + reduce.1 = f32[128]{0} reduce(param_0, constant_0), dimensions={0}, to_apply=Add + reduce.2 = f32[128]{0} reduce(bitcast, constant_0), dimensions={0,1}, to_apply=Add + ROOT tuple.12 = (f32[128]{0}, f32[128]{0}) tuple(reduce.1, reduce.2) +} + +ENTRY main { + Arg_2.1 = f32[64,128]{1,0} parameter(0) + ROOT fusion = (f32[128]{0}, f32[128]{0}) fusion(Arg_2.1), kind=kInput, calls=fused_computation +})"); + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceGeneralBitcastCompatible) { + const std::string testcase = absl::StrCat(kScalarOps, R"( +fused_computation { + param_0 = f32[64,128]{1,0} parameter(0) + neg = f32[64,128]{1,0} negate(param_0) + bitcast = f32[8,8,128]{2,1,0} bitcast(neg) + bitcast2 = f32[128,64]{0,1} bitcast(neg) + constant_0 = f32[] constant(0) + reduce.1 = f32[128]{0} reduce(bitcast, constant_0), dimensions={0,1}, to_apply=Add + ROOT tuple.12 = (f32[128]{0}, f32[64,128]{1,0}, f32[128,64]{0,1}) tuple(reduce.1, neg, bitcast2) +} + +ENTRY main { + Arg_2.1 = f32[64,128]{1,0} parameter(0) + ROOT fusion = (f32[128]{0}, f32[64,128]{1,0}, f32[128,64]{0,1}) fusion(Arg_2.1), kind=kInput, calls=fused_computation +})"); + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputReduceWithEpilogue) { + const std::string testcase = absl::StrCat(kScalarOps, R"( +fused_computation { + param_0 = f32[4,2]{1,0} parameter(0) + neg = f32[4,2]{1,0} negate(param_0) + constant_0 = f32[] constant(0) + reduce.1 = f32[4]{0} reduce(param_0, constant_0), dimensions={1}, to_apply=Add + bitcast.1 = f32[1,1,4]{2,1,0} bitcast(reduce.1) + sign.1 = f32[1,1,4]{2,1,0} sign(bitcast.1) + ROOT tuple.12 = (f32[4,2]{1,0}, f32[1,1,4]{2,1,0}, f32[1,1,4]{2,1,0}) tuple(neg, bitcast.1, sign.1) +} + +ENTRY main.7749 { + Arg_2.1 = f32[4,2]{1,0} parameter(0) + ROOT fusion = (f32[4,2]{1,0}, f32[1,1,4]{2,1,0}, f32[1,1,4]{2,1,0}) fusion(Arg_2.1), kind=kInput, calls=fused_computation +})"); + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +XLA_TEST_F(MultiOutputFusionTest, + MultiOutputReduceWithEpilogueHeroAlsoUsedAsNonHero) { + // reduce.8 is used by bitcast as reduce hero and by broadcast as non-hero. + const std::string testcase = absl::StrCat(kScalarOps, R"( +fused_computation { + %param_0.6 = f32[4]{0} parameter(0) + %zero = f32[] constant(0.0) + %reduce.8 = f32[] reduce(f32[4]{0} %param_0.6, f32[] %zero), dimensions={0}, to_apply=Add + %broadcast = f32[4]{0} broadcast(f32[] %reduce.8), dimensions={} + %compare = pred[4]{0} compare(f32[4]{0} %param_0.6, f32[4]{0} %broadcast), direction=EQ + %convert = f32[4]{0} convert(pred[4]{0} %compare) + %reduce.19.1 = f32[] reduce(f32[4]{0} %convert, f32[] %zero), dimensions={0}, to_apply=Add + %bitcast = f32[1]{0} bitcast(f32[] %reduce.8) + ROOT %tuple.1 = (f32[], f32[4]{0}, f32[1]{0}) tuple(f32[] %reduce.19.1, f32[4]{0} %convert, f32[1]{0} %bitcast) +} + +ENTRY main { + Arg0 = f32[4]{0} parameter(0) + ROOT fusion = (f32[], f32[4]{0}, f32[1]{0}) fusion(Arg0), kind=kInput, calls=fused_computation +})"); + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +XLA_TEST_F(MultiOutputFusionTest, + MultiOutputTransposeFusionHeroWithMultipleRootUsers) { + const std::string testcase = R"( + HloModule test + fused_transpose { + p0 = f32[128,64]{1,0} parameter(0) + p1 = f32[64,128]{1,0} parameter(1) + tr = f32[64,128]{1,0} transpose(p0), dimensions={1,0} + add = f32[64,128]{1,0} add(tr, p1) + neg = f32[64,128]{1,0} negate(add) + ROOT tuple = (f32[64,128]{1,0}, f32[64,128]{1,0}, f32[64,128]{1,0}) tuple(tr, add, neg) + } + + ENTRY main { + p = f32[128,64]{1,0} parameter(0) + p2 = f32[64,128]{1,0} parameter(1) + ROOT fusion = (f32[64,128]{1,0}, f32[64,128]{1,0}, f32[64,128]{1,0}) fusion(p, p2), + kind=kInput, calls=fused_transpose + })"; + auto module = ParseAndReturnVerifiedModule(testcase).value(); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + } // namespace } // namespace xla diff --git a/xla/tests/multiple_devices_on_host_test.cc b/xla/tests/multiple_devices_on_host_test.cc index 9e192c7404e98..ef69ea068d7fe 100644 --- a/xla/tests/multiple_devices_on_host_test.cc +++ b/xla/tests/multiple_devices_on_host_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,14 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/xla_builder.h" #include "xla/shape_util.h" +#include "xla/stream_executor/platform_manager.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" namespace xla { namespace { -StatusOr BuildComputation() { +absl::StatusOr BuildComputation() { XlaBuilder b("computation"); Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); @@ -45,7 +46,7 @@ void CompileAndExecute( xla::ClientLibrary::GetXlaService(client->platform()) ->backend() .memory_allocator()); - StatusOr result = + absl::StatusOr result = executable->Run(absl::Span(), execute_options); { absl::MutexLock lock(results_mutex); @@ -57,7 +58,7 @@ void TestWithDeviceCount(const int device_count) { // Run `device_count` copies of the XLA program built by BuildComputation. TF_ASSERT_OK_AND_ASSIGN( se::Platform* const platform, - stream_executor::MultiPlatformManager::PlatformWithName("Host")); + stream_executor::PlatformManager::PlatformWithName("Host")); xla::LocalClientOptions client_options; client_options.set_platform(platform); TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/tests/multithreaded_compilation_test.cc b/xla/tests/multithreaded_compilation_test.cc index 3ad17c8ad8894..530384d16e894 100644 --- a/xla/tests/multithreaded_compilation_test.cc +++ b/xla/tests/multithreaded_compilation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,7 +76,7 @@ XLA_TEST_F(MultithreadedCompilation, EightModuleCompilation) { absl::MutexLock lock(&mu); executables.push_back(std::move(executable)); VLOG(2) << "Adding executable obtained from thread: " << iteration; - return tsl::OkStatus(); + return absl::OkStatus(); }; { diff --git a/xla/tests/numerics_test.cc b/xla/tests/numerics_test.cc new file mode 100644 index 0000000000000..4ec6eeffa52a6 --- /dev/null +++ b/xla/tests/numerics_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal_util.h" +#include "xla/statusor.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "xla/types.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using NumericsTest = HloTestBase; + +XLA_TEST_F(NumericsTest, AbsOfLargeComplexNumber) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + x = c64[] parameter(0) + ROOT power = f32[] abs(x) +} +)"; + + auto abs_of_complex_x = [&hlo, this](float x) { + std::unique_ptr module = + ParseAndReturnVerifiedModule(hlo).value(); + auto x_lit = LiteralUtil::CreateR0(x); + return RunAndCompare(std::move(module), {&x_lit}, ErrorSpec{1e-5, 1e-5}); + }; + + EXPECT_TRUE(abs_of_complex_x(1e19)); + EXPECT_TRUE(abs_of_complex_x(1e25)); + EXPECT_TRUE(abs_of_complex_x(1e30)); +} + +XLA_TEST_F(NumericsTest, PowerOfLargeComplexNumber) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + large = c64[] parameter(0) + x = c64[] parameter(1) + ROOT power = c64[] power(large, x) +} +)"; + + auto complex_a_raised_to_complex_b = [&hlo, this](float num, float exp) { + std::unique_ptr module = + ParseAndReturnVerifiedModule(hlo).value(); + auto num_lit = LiteralUtil::CreateR0(num); + auto exp_lit = LiteralUtil::CreateR0(exp); + return RunAndCompare(std::move(module), {&num_lit, &exp_lit}, + ErrorSpec{1e-5, 1e-5}); + }; + + EXPECT_TRUE(complex_a_raised_to_complex_b(1e19, 0)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e19, 1)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e19, 1.2)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e19, 2)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e30, 0)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e30, 1)); + EXPECT_TRUE(complex_a_raised_to_complex_b(1e30, 1.2)); + EXPECT_TRUE( + complex_a_raised_to_complex_b(std::numeric_limits::infinity(), 0)); + EXPECT_TRUE(complex_a_raised_to_complex_b( + std::numeric_limits::quiet_NaN(), 0)); +} + +} // namespace +} // namespace xla diff --git a/xla/tests/onednn_layer_norm_test.cc b/xla/tests/onednn_layer_norm_test.cc new file mode 100644 index 0000000000000..beead4f62b5b4 --- /dev/null +++ b/xla/tests/onednn_layer_norm_test.cc @@ -0,0 +1,216 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class LayerNormTest : public HloTestBase { + protected: + const char* onednn_layer_norm_ = + R"( + ; CHECK: custom_call_target="__onednn$layernorm", + ; CHECK: backend_config={ + ; CHECK-DAG: "onednn_layer_norm_config":{ + ; CHECK-DAG: "fused_ops":"SCALE_AND_SHIFT" + ; CHECK-DAG: } + ; CHECK: } + )"; + std::string common_hlo_region_ = + R"( + + region_add { + Arg_0.7555 = f32[] parameter(0) + Arg_1.7556 = f32[] parameter(1) + ROOT add.7557 = f32[] add(Arg_0.7555, Arg_1.7556) + } +)"; + + std::string common_hlo_entry_computation_block_ = + R"( + Arg_0.2 = f32[768]{0} parameter(1), sharding={replicated} + Arg_0.3 = f32[768]{0} parameter(2), sharding={replicated} + + convert.290 = f32[84,197,768]{2,1,0} convert(Arg_0.1) + constant.291 = f32[] constant(0) + convert.292 = f32[] convert(constant.291) + reduce.297 = f32[84,197]{1,0} reduce(convert.290, convert.292), dimensions={2}, to_apply=region_add + constant.298 = s32[] constant(768) + convert.299 = f32[] convert(constant.298) + broadcast.300 = f32[84,197]{1,0} broadcast(convert.299), dimensions={} + divide.301 = f32[84,197]{1,0} divide(reduce.297, broadcast.300) + convert.302 = f32[84,197]{1,0} convert(divide.301) + reshape.303 = f32[84,197,1]{2,1,0} reshape(convert.302) + reshape.304 = f32[84,197]{1,0} reshape(reshape.303) + broadcast.305 = f32[84,197,768]{2,1,0} broadcast(reshape.304), dimensions={0,1} + subtract.306 = f32[84,197,768]{2,1,0} subtract(Arg_0.1, broadcast.305) + multiply.307 = f32[84,197,768]{2,1,0} multiply(subtract.306, subtract.306) + convert.308 = f32[84,197,768]{2,1,0} convert(multiply.307) + constant.309 = f32[] constant(0) + convert.310 = f32[] convert(constant.309) + reduce.315 = f32[84,197]{1,0} reduce(convert.308, convert.310), dimensions={2}, to_apply=region_add + constant.316 = s32[] constant(768) + convert.317 = f32[] convert(constant.316) + broadcast.318 = f32[84,197]{1,0} broadcast(convert.317), dimensions={} + divide.319 = f32[84,197]{1,0} divide(reduce.315, broadcast.318) + convert.320 = f32[84,197]{1,0} convert(divide.319) + reshape.321 = f32[84,197,1]{2,1,0} reshape(convert.320) + constant.322 = f32[] constant(1e-12) + broadcast.323 = f32[84,197,1]{2,1,0} broadcast(constant.322), dimensions={} + add.324 = f32[84,197,1]{2,1,0} add(reshape.321, broadcast.323) + rsqrt.325 = f32[84,197,1]{2,1,0} rsqrt(add.324) + reshape.328 = f32[84,197]{1,0} reshape(rsqrt.325) + broadcast.329 = f32[84,197,768]{2,1,0} broadcast(reshape.328), dimensions={0,1} + broadcast.327 = f32[84,197,768]{2,1,0} broadcast(Arg_0.2), dimensions={2} + multiply.330 = f32[84,197,768]{2,1,0} multiply(broadcast.329, broadcast.327) + multiply.331 = f32[84,197,768]{2,1,0} multiply(Arg_0.1, multiply.330) + broadcast.336 = f32[84,197,768]{2,1,0} broadcast(Arg_0.3), dimensions={2} + reshape.332 = f32[84,197]{1,0} reshape(reshape.303) + broadcast.333 = f32[84,197,768]{2,1,0} broadcast(reshape.332), dimensions={0,1} + multiply.334 = f32[84,197,768]{2,1,0} multiply(multiply.330, broadcast.333) + subtract.337 = f32[84,197,768]{2,1,0} subtract(broadcast.336, multiply.334) +)"; +}; + +TEST_F(LayerNormTest, LayerNormTest0_FP32) { + std::string layer_norm_module_str = + R"(HloModule layer_norm.test, entry_computation_layout={(f32[84,197,768]{2,1,0}, f32[768]{0}, f32[768]{0})->f32[84,197,768]{2,1,0}})" + + common_hlo_region_ + R"( + ENTRY main { + Arg_0.1 = f32[84,197,768]{2,1,0} parameter(0), sharding={replicated} + + )" + common_hlo_entry_computation_block_ + + R"( + ROOT add.338 = f32[84,197,768]{2,1,0} add(multiply.331, subtract.337) + } + )"; + + EXPECT_TRUE(RunAndCompare(layer_norm_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(layer_norm_module_str, onednn_layer_norm_); +} + +TEST_F(LayerNormTest, LayerNormTest0_BF16) { + std::string layer_norm_module_str = + R"(HloModule layer_norm.test, entry_computation_layout={(bf16[84,197,768]{2,1,0}, f32[768]{0}, f32[768]{0})->bf16[84,197,768]{2,1,0}})" + + common_hlo_region_ + R"( + ENTRY main { + Arg_0.1.0 = bf16[84,197,768]{2,1,0} parameter(0), sharding={replicated} + Arg_0.1 = f32[84,197,768]{2,1,0} convert(Arg_0.1.0) + )" + common_hlo_entry_computation_block_ + + R"( + add.338 = f32[84,197,768]{2,1,0} add(multiply.331, subtract.337) + ROOT convert.339 = bf16[84,197,768]{2,1,0} convert(add.338) + } + )"; + + EXPECT_TRUE(RunAndCompare(layer_norm_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(layer_norm_module_str, onednn_layer_norm_); +} + +TEST_F(LayerNormTest, LayerNormTest0_F16) { + std::string layer_norm_module_str = + R"(HloModule layer_norm.test, entry_computation_layout={(f16[84,197,768]{2,1,0}, f32[768]{0}, f32[768]{0})->f16[84,197,768]{2,1,0}})" + + common_hlo_region_ + R"( + ENTRY main { + Arg_0.1.0 = f16[84,197,768]{2,1,0} parameter(0), sharding={replicated} + Arg_0.1 = f32[84,197,768]{2,1,0} convert(Arg_0.1.0) + )" + common_hlo_entry_computation_block_ + + R"( + add.338 = f32[84,197,768]{2,1,0} add(multiply.331, subtract.337) + ROOT convert.339 = f16[84,197,768]{2,1,0} convert(add.338) + } + )"; + + EXPECT_TRUE(RunAndCompare(layer_norm_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(layer_norm_module_str, onednn_layer_norm_); +} + +// Test case encountered in models like TFViTForImageClassification in +// HuggingFace +// (https://huggingface.co/docs/transformers/model_doc/vit#transformers.TFViTForImageClassification) +TEST_F(LayerNormTest, LayerNormTest1_BF16) { + const char* layer_norm_module_str = R"( + HloModule layer_norm.test + region_add { + Arg_0.7555 = f32[] parameter(0) + Arg_1.7556 = f32[] parameter(1) + ROOT add.7557 = f32[] add(Arg_0.7555, Arg_1.7556) + } + ENTRY main { + Arg_0.1 = bf16[160,197,768] parameter(0), sharding={replicated} + Arg_0.2 = bf16[768] parameter(1), sharding={replicated} + Arg_0.3 = bf16[768] parameter(2), sharding={replicated} + convert.80 = f32[160,197,768] convert(Arg_0.1) + constant.81 = f32[] constant(0) + convert.82 = f32[] convert(constant.81) + reduce.87 = f32[160,197] reduce(convert.80, convert.82), dimensions={2}, to_apply=region_add + constant.88 = s32[] constant(768) + convert.89 = f32[] convert(constant.88) + broadcast.90 = f32[160,197] broadcast(convert.89), dimensions={} + divide.91 = f32[160,197] divide(reduce.87, broadcast.90) + convert.92 = bf16[160,197] convert(divide.91) + reshape.93 = bf16[160,197,1] reshape(convert.92) + reshape.94 = bf16[160,197] reshape(reshape.93) + broadcast.95 = bf16[160,197,768] broadcast(reshape.94), dimensions={0,1} + subtract.96 = bf16[160,197,768] subtract(Arg_0.1, broadcast.95) + multiply.97 = bf16[160,197,768] multiply(subtract.96, subtract.96) + convert.98 = f32[160,197,768] convert(multiply.97) + constant.99 = f32[] constant(0) + convert.100 = f32[] convert(constant.99) + reduce.105 = f32[160,197] reduce(convert.98, convert.100), dimensions={2}, to_apply=region_add + constant.106 = s32[] constant(768) + convert.107 = f32[] convert(constant.106) + broadcast.108 = f32[160,197] broadcast(convert.107), dimensions={} + divide.109 = f32[160,197] divide(reduce.105, broadcast.108) + convert.110 = bf16[160,197] convert(divide.109) + reshape.111 = bf16[160,197,1] reshape(convert.110) + constant.112 = bf16[] constant(1.002e-12) + broadcast.113 = bf16[160,197,1] broadcast(constant.112), dimensions={} + add.114 = bf16[160,197,1] add(reshape.111, broadcast.113) + rsqrt.115 = bf16[160,197,1] rsqrt(add.114) + reshape.118 = bf16[160,197] reshape(rsqrt.115) + broadcast.119 = bf16[160,197,768] broadcast(reshape.118), dimensions={0,1} + broadcast.117 = bf16[160,197,768] broadcast(Arg_0.2), dimensions={2} + multiply.120 = bf16[160,197,768] multiply(broadcast.119, broadcast.117) + multiply.121 = bf16[160,197,768] multiply(Arg_0.1, multiply.120) + broadcast.126 = bf16[160,197,768] broadcast(Arg_0.3), dimensions={2} + reshape.122 = bf16[160,197] reshape(reshape.93) + broadcast.123 = bf16[160,197,768] broadcast(reshape.122), dimensions={0,1} + multiply.124 = bf16[160,197,768] multiply(multiply.120, broadcast.123) + subtract.127 = bf16[160,197,768] subtract(broadcast.126, multiply.124) + ROOT add.128 = bf16[160,197,768] add(multiply.121, subtract.127) + } +)"; + + EXPECT_TRUE(RunAndCompare(layer_norm_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(layer_norm_module_str, + R"( + ; CHECK: custom_call_target="__onednn$layernorm", + ; CHECK: backend_config={ + ; CHECK-DAG: "onednn_layer_norm_config":{ + ; CHECK-DAG: "fused_ops":"SCALE_AND_SHIFT" + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +} // namespace +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/tests/onednn_matmul_test.cc b/xla/tests/onednn_matmul_test.cc index ce1f58fc67bee..64befeab768c2 100644 --- a/xla/tests/onednn_matmul_test.cc +++ b/xla/tests/onednn_matmul_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,64 +17,679 @@ limitations under the License. #include +#include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" +#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/cpu_info.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace cpu { -class MatmulTest : public HloTestBase {}; +class MatmulTest : public HloTestBase { + protected: + const char* fused_matmul_bias_ = R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["BIAS"] + ; CHECK-DAG: } + ; CHECK: } + )"; + const char* fused_matmul_binary_add_ = R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["BINARY_ADD"] + ; CHECK-DAG: } + ; CHECK: } + )"; + const char* matmul_rewrite_str_ = R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":[] + ; CHECK-DAG: } + ; CHECK: } + )"; +}; TEST_F(MatmulTest, SimpleTestF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[2,8,4,16]{3,2,1,0},f32[2,8,16,32]{3,2,1,0})->f32[2,8,4,32]{3,2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - arg.0 = f32[2,8,4,16]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = f32[2,8,16,32]{3,2,1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f32[2,8,4,32]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + arg.0 = f32[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = f32[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); } TEST_F(MatmulTest, SimpleTestBF16) { // TODO(penporn): Refactor IsBF16SupportedByOneDNNOnThisCPU() from // tensorflow/core/graph/mkl_graph_util.h and call the function instead. - using tsl::port::TestCPUFeature; - if (!TestCPUFeature(tsl::port::CPUFeature::AVX512_BF16) && - !TestCPUFeature(tsl::port::CPUFeature::AMX_BF16)) { + if (!IsSupportedType(PrimitiveType::BF16)) { GTEST_SKIP() << "CPU does not support BF16."; } const char* matmul_module_str = R"( - HloModule matmul.test.bf16, entry_computation_layout={(bf16[2,8,4,16]{3,2,1,0},bf16[2,8,16,32]{3,2,1,0})->bf16[2,8,4,32]{3,2,1,0}} + HloModule matmul.test.bf16 ENTRY matmul.test.bf16 { - arg.0 = bf16[2,8,4,16]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = bf16[2,8,16,32]{3,2,1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = bf16[2,8,4,32]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + arg.0 = bf16[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = bf16[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = bf16[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} })"; - EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestF16) { + if (!IsSupportedType(PrimitiveType::F16)) { + GTEST_SKIP() << "CPU does not support F16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.test.f16 + + ENTRY matmul.test.f16 { + arg.0 = f16[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = f16[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f16[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); } TEST_F(MatmulTest, SimpleTestF32TransposeB) { const char* matmul_module_str = R"( - HloModule matmul.test.1, entry_computation_layout={(f32[2,8,4,16]{3,1,2,0},f32[2,8,4,16]{3,1,2,0})->f32[2,8,4,4]{3,2,1,0}} + HloModule matmul.test.1 ENTRY matmul.test.1 { - arg.0 = f32[2,8,4,16]{3,1,2,0} parameter(0), parameter_replication={false} - arg.1 = f32[2,8,4,16]{3,1,2,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f32[2,8,4,4]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + arg.0 = f32[32,8,128,64]{3,1,2,0} parameter(0), parameter_replication={false} + arg.1 = f32[32,8,128,64]{3,1,2,0} parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion1) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + reshape.2 = f32[32,32,40,30] reshape(arg0.1) + constant.3 = f32[] constant(1) + broadcast.4 = f32[32,32,30,40] broadcast(constant.3), dimensions={} + dot.7 = f32[32,32,40,40] dot(reshape.2, broadcast.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + constant.5 = f32[] constant(15) + broadcast.6 = f32[40] broadcast(constant.5), dimensions={} + broadcast.9 = f32[32,32,40,40] broadcast(broadcast.6), dimensions={3} + add.10 = f32[32,32,40,40] add(dot.7, broadcast.9) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion2) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[400,300] parameter(0), parameter_replication={false} + reshape.2 = f32[400,300] reshape(arg0.1) + constant.3 = f32[] constant(1) + broadcast.4 = f32[300,400] broadcast(constant.3), dimensions={} + dot.7 = f32[400,400] dot(reshape.2, broadcast.4), lhs_batch_dims={}, lhs_contracting_dims={1}, rhs_batch_dims={}, rhs_contracting_dims={0} + reshape.1 = f32[400,1,400] reshape(dot.7) + constant.5 = f32[] constant(15) + broadcast.6 = f32[400] broadcast(constant.5), dimensions={} + broadcast.9 = f32[400,1,400] broadcast(broadcast.6), dimensions={2} + add.10 = f32[400,1,400] add(reshape.1, broadcast.9) + tuple.12 = (f32[400,1,400]) tuple(add.10) + ROOT get-tuple-element.13 = f32[400,1,400] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} + arg0.3 = f32[32,32,40,40] parameter(2), parameter_replication={false} + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + add.10 = f32[32,32,40,40] add(dot.7, arg0.3) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} + arg0.3 = f32[40]{0} parameter(2), parameter_replication={false} + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[32,32,40,40] broadcast(arg0.3), dimensions={3} + add.10 = f32[32,32,40,40] add(dot.7, broad.1) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[2,2,400,30] parameter(0), parameter_replication={false} + arg0.2 = f32[2,2,30,400] parameter(1), parameter_replication={false} + arg0.3 = f32[2,400] parameter(2), parameter_replication={false} + dot.7 = f32[2,2,400,400] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[2,2,400,400] broadcast(arg0.3), dimensions={0,3} + add.10 = f32[2,2,400,400] add(dot.7, broad.1) + reshape.11 = f32[2,2,400,400] reshape(add.10) + tuple.12 = (f32[2,2,400,400]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[2,2,400,400] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D1B) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[1,2,400,30] parameter(0), parameter_replication={false} + arg0.2 = f32[1,2,30,400] parameter(1), parameter_replication={false} + arg0.3 = f32[1,400] parameter(2), parameter_replication={false} + dot.7 = f32[1,2,400,400] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[1,2,400,400] broadcast(arg0.3), dimensions={0,3} + add.10 = f32[1,2,400,400] add(dot.7, broad.1) + reshape.11 = f32[1,2,400,400] reshape(add.10) + tuple.12 = (f32[1,2,400,400]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[1,2,400,400] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter3) { + const char* matmul_module_str = R"( + HloModule matmul.biasadd.test.f32 + + ENTRY matmul.biasadd.test.f32 { + arg0.1 = f32[16,128,768] parameter(0), sharding={replicated} + arg0.2 = f32[768,768] parameter(1), sharding={replicated} + dot.84 = f32[16,128,768] dot(arg0.1, arg0.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + arg0.3 = f32[768]{0} parameter(2), sharding={replicated} + reshape.85 = f32[1,1,768] reshape(arg0.3) + broadcast.86 = f32[1,1,768] broadcast(reshape.85), dimensions={0,1,2} + reshape.87 = f32[768]{0} reshape(broadcast.86) + broadcast.88 = f32[16,128,768] broadcast(reshape.87), dimensions={2} + ROOT add.89 = f32[16,128,768] add(dot.84, broadcast.88) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, SimpleTestF32TransposeBWithBiasAddFusion) { + const char* matmul_module_str = R"( + HloModule matmul.test.1 + + ENTRY matmul.test.1 { + arg.0 = f32[32,8,4,16]{3,1,2,0} parameter(0), parameter_replication={false} + arg.1 = f32[32,8,16,16]{3,1,2,0} parameter(1), parameter_replication={false} + dot.7 = f32[32,8,4,16]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + constant.5 = f32[] constant(15) + broadcast.6 = f32[16]{0} broadcast(constant.5), dimensions={} + broadcast.9 = f32[32,8,4,16]{3,2,1,0} broadcast(broadcast.6), dimensions={3} + add.10 = f32[32,8,4,16]{3,2,1,0} add(dot.7, broadcast.9) + reshape.11 = f32[32,8,4,16]{3,2,1,0} reshape(add.10) + tuple.12 = (f32[32,8,4,16]{3,2,1,0}) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,8,4,16]{3,2,1,0} get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_); +} + +TEST_F(MatmulTest, F32BiasAddFusionNonCompatibleBias) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.1 { + arg.0 = f32[12288,2] parameter(0), parameter_replication={false} + arg.1 = f32[2,1024] parameter(1), parameter_replication={false} + dot.0 = f32[12288,1024] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.0 = f32[32,384,1024] reshape(dot.0) + constant.0 = f32[1,384,1024] constant(15) + reshape.1 = f32[384,1024] reshape(constant.0) + broadcast.0 = f32[32,384,1024] broadcast(reshape.1), dimensions={1,2} + add.0 = f32[32,384,1024] add(reshape.0, broadcast.0) + tuple.0 = (f32[32,384,1024]) tuple(add.0) + ROOT get-tuple-element.0 = f32[32,384,1024] get-tuple-element(tuple.0), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, ApproxGELUTestF32) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.f32 { + arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} + arg.1 = f32[32,32,16,32] parameter(1), parameter_replication={false} + onednn.matmul.0 = f32[32,32,4,32] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + mul.0 = f32[32,32,4,32] multiply(onednn.matmul.0, onednn.matmul.0) + mul.1 = f32[32,32,4,32] multiply(onednn.matmul.0, mul.0) + const.0 = f32[] constant(0.044715) + bcast.0 = f32[32,32,4,32] broadcast(const.0), dimensions={} + mul.2 = f32[32,32,4,32] multiply(mul.1, bcast.0) + add.0 = f32[32,32,4,32] add(onednn.matmul.0, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[32,32,4,32] broadcast(const.1), dimensions={} + mul.3 = f32[32,32,4,32] multiply(add.0, bcast.1) + tanh = f32[32,32,4,32] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[32,32,4,32] broadcast(const.2), dimensions={} + add.2 = f32[32,32,4,32] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[32,32,4,32] broadcast(const.3), dimensions={} + mul.4 = f32[32,32,4,32] multiply(add.2, bcast.3) + ROOT out = f32[32,32,4,32] multiply(onednn.matmul.0, mul.4) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["GELU_TANH"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +// GPT-J Bias+GELU pattern with reduced sizes for test time: +// batch=32; seq_len=32; hidden_size=64; intermediate_size=256 +TEST_F(MatmulTest, BiasAndApproxGELUTestF32) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.f32 { + Arg_5.6 = f32[32,32,64] parameter(0), sharding={replicated} + Arg_7.8 = f32[64,256] parameter(1), sharding={replicated} + dot.232 = f32[32,32,256] dot(Arg_5.6, Arg_7.8), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_6.7 = f32[256] parameter(2), sharding={replicated} + reshape.233 = f32[1,1,256] reshape(Arg_6.7) + broadcast.234 = f32[1,1,256] broadcast(reshape.233), dimensions={0,1,2} + reshape.235 = f32[256] reshape(broadcast.234) + broadcast.236 = f32[32,32,256] broadcast(reshape.235), dimensions={2} + add.237 = f32[32,32,256] add(dot.232, broadcast.236) + multiply.238 = f32[32,32,256] multiply(add.237, add.237) + multiply.239 = f32[32,32,256] multiply(add.237, multiply.238) + constant.20 = f32[] constant(0.044715) + broadcast.21 = f32[32,32,256] broadcast(constant.20), dimensions={} + multiply.240 = f32[32,32,256] multiply(multiply.239, broadcast.21) + add.241 = f32[32,32,256] add(add.237, multiply.240) + constant.18 = f32[] constant(0.797884583) + broadcast.19 = f32[32,32,256] broadcast(constant.18), dimensions={} + multiply.242 = f32[32,32,256] multiply(add.241, broadcast.19) + tanh.243 = f32[32,32,256] tanh(multiply.242) + constant.16 = f32[] constant(1) + broadcast.17 = f32[32,32,256] broadcast(constant.16), dimensions={} + add.244 = f32[32,32,256] add(tanh.243, broadcast.17) + constant.14 = f32[] constant(0.5) + broadcast.15 = f32[32,32,256] broadcast(constant.14), dimensions={} + multiply.245 = f32[32,32,256] multiply(add.244, broadcast.15) + ROOT out = f32[32,32,256] multiply(add.237, multiply.245) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["BIAS","GELU_TANH"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +TEST_F(MatmulTest, ReLUTestF32) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + relu.1 { + Arg_0.3 = f32[32,32,4,32] parameter(0) + constant.4 = f32[] constant(0) + broadcast.5 = f32[32,32,4,32] broadcast(constant.4), dimensions={} + ROOT maximum.6 = f32[32,32,4,32] maximum(Arg_0.3, broadcast.5) + } + + ENTRY matmul.test.f32 { + arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} + arg.1 = f32[32,32,16,32] parameter(1), parameter_replication={false} + onednn.matmul.0 = f32[32,32,4,32] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT call.7 = f32[32,32,4,32] call(onednn.matmul.0), to_apply=relu.1 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["RELU"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_F32) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule jit_apply + + ENTRY matmul.test.bf16 { + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = bf16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f32[768,3072] parameter(1), sharding={replicated} + convert.5 = bf16[768,3072] convert(Arg_1.2) + dot.7 = bf16[16,128,3072] dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f32[3072] parameter(0), sharding={replicated} + convert.6 = bf16[3072] convert(Arg_0.1) + reshape.8 = bf16[1,1,3072] reshape(convert.6) + broadcast.9 = bf16[1,1,3072] broadcast(reshape.8), dimensions={0,1,2} + reshape.10 = bf16[3072] reshape(broadcast.9) + broadcast.11 = bf16[16,128,3072] broadcast(reshape.10), dimensions={2} + ROOT add.12 = bf16[16,128,3072] add(dot.7, broadcast.11) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_BF16) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule jit_apply + + ENTRY matmul.test.bf16 { + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = bf16[16,128,768] convert(Arg_2.3) + Arg_1.2 = bf16[768,3072] parameter(1), sharding={replicated} + dot.5 = bf16[16,128,3072] dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = bf16[3072] parameter(0), sharding={replicated} + reshape.6 = bf16[1,1,3072] reshape(Arg_0.1) + broadcast.7 = bf16[1,1,3072] broadcast(reshape.6), dimensions={0,1,2} + reshape.8 = bf16[3072] reshape(broadcast.7) + broadcast.9 = bf16[16,128,3072] broadcast(reshape.8), dimensions={2} + ROOT add.10 = bf16[16,128,3072] add(dot.5, broadcast.9) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, DivisionByConstantWithEltwiseLinearF32) { + const char* matmul_module_str = R"( + HloModule matmul.divide.test.1 + + ENTRY matmul.divide.test.f32 { + Arg_4.5 = f32[16,128,768] parameter(0), sharding={replicated} + Arg_2.3 = f32[768,12,64] parameter(1), sharding={replicated} + onednn.matmul.0 = f32[16,128,12,64] dot(Arg_4.5, Arg_2.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} + constant.8 = f32[] constant(8) + broadcast.9 = f32[16,128,12,64] broadcast(constant.8), dimensions={} + ROOT divide.16 = f32[16,128,12,64] divide(onednn.matmul.0, broadcast.9) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-4, 1e-4))); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["LINEAR"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + +TEST_F(MatmulTest, SimpleBiasTestF16_PARAM_F32) { + if (!IsSupportedType(PrimitiveType::F16)) { + GTEST_SKIP() << "CPU does not support F16."; + } + + const char* matmul_module_str = R"( + HloModule jit_apply + + ENTRY matmul.test.f16 { + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = f16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f32[768,3072] parameter(1), sharding={replicated} + convert.5 = f16[768,3072] convert(Arg_1.2) + dot.7 = f16[16,128,3072] dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f32[3072] parameter(0), sharding={replicated} + convert.6 = f16[3072] convert(Arg_0.1) + reshape.8 = f16[1,1,3072] reshape(convert.6) + broadcast.9 = f16[1,1,3072] broadcast(reshape.8), dimensions={0,1,2} + reshape.10 = f16[3072] reshape(broadcast.9) + broadcast.11 = f16[16,128,3072] broadcast(reshape.10), dimensions={2} + ROOT add.12 = f16[16,128,3072] add(dot.7, broadcast.11) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, SimpleBiasTestF16_PARAM_F16) { + if (!IsSupportedType(PrimitiveType::F16)) { + GTEST_SKIP() << "CPU does not support F16."; + } + const char* matmul_module_str = R"( + HloModule jit_apply + + ENTRY matmul.test.f16 { + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = f16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f16[768,3072] parameter(1), sharding={replicated} + dot.5 = f16[16,128,3072] dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f16[3072] parameter(0), sharding={replicated} + reshape.6 = f16[1,1,3072] reshape(Arg_0.1) + broadcast.7 = f16[1,1,3072] broadcast(reshape.6), dimensions={0,1,2} + reshape.8 = f16[3072] reshape(broadcast.7) + broadcast.9 = f16[16,128,3072] broadcast(reshape.8), dimensions={2} + ROOT add.10 = f16[16,128,3072] add(dot.5, broadcast.9) + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_); +} + +TEST_F(MatmulTest, TestF32NonConstantWeights) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.f32 { + arg.0 = f32[64,256,16] parameter(0), parameter_replication={false} + arg.1 = f32[16,32] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[64,256,32] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: %matmul.test.f32 + ; CHECK-NOT: custom_call_target="__onednn$matmul_reorder", + ; CHECK: custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %arg.1), custom_call_target="__onednn$matmul", + )"); +} + +TEST_F(MatmulTest, TestF32ConstantWeights) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.f32 { + arg.0 = f32[64,256,16] parameter(0), parameter_replication={false} + constant = f32[] constant(1) + arg.1 = f32[16,32] broadcast(constant), dimensions={} + ROOT onednn.matmul.0 = f32[64,256,32] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: %matmul.test.f32 + ; CHECK-NOT: custom_call_target="__onednn$matmul_reorder", + ; CHECK: custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %constant{{[a-z,A-Z,0-9,\.]*}}), custom_call_target="__onednn$matmul", + )"); +} + +TEST_F(MatmulTest, SimpleTestBF16Gemv1) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.test.bf16 + + ENTRY matmul.test.bf16 { + arg.0 = bf16[1000,10000] parameter(0) + arg.1 = bf16[10000] parameter(1) + ROOT onednn.matmul.0 = bf16[1000] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{2e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestBF16Gemv2) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.test.bf16 + + ENTRY matmul.test.bf16 { + arg.0 = bf16[100,300,300] parameter(0) + arg.1 = bf16[300] parameter(1) + ROOT onednn.matmul.0 = bf16[100,300] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{2e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, TestTransposeBNoRewriteF32) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + + ENTRY matmul.test.f32 { + arg.0 = f32[384,1024]{1,0} parameter(0), parameter_replication={false} + arg.1 = f32[2,1024]{1,0} parameter(1), parameter_replication={false} + ROOT dot.2 = f32[384,2]{1,0} dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: %matmul.test.f32 + ; CHECK-NOT: custom_call_target="__onednn$matmul", + ; CHECK: f32[384,2]{1,0} dot(%arg.0, %arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + )"); +} + +TEST_F(MatmulTest, SimpleTestF32WithMulAndAddFusion) { + const char* matmul_module_str = R"( + ENTRY matmul.mul.add.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + const.0 = f32[] constant(0.044715) + bcast.0 = f32[32,32,40,40] broadcast(const.0), dimensions={} + mul.0 = f32[32,32,40,40] multiply(dot.7,bcast.0) + const.1 = f32[] constant(0.65) + bcast.1 = f32[32,32,40,40] broadcast(const.1), dimensions={} + add.0 = f32[32,32,40,40] add(mul.0, bcast.1) + const.2 = f32[] constant(0.65) + bcast.2 = f32[32,32,40,40] broadcast(const.2), dimensions={} + add.1 = f32[32,32,40,40] add(bcast.2, bcast.1) + tuple.12 = (f32[32,32,40,40]) tuple(add.0) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: } + ; CHECK: } + )"); } } // namespace cpu diff --git a/xla/tests/onednn_softmax_test.cc b/xla/tests/onednn_softmax_test.cc new file mode 100644 index 0000000000000..4af19eafa732d --- /dev/null +++ b/xla/tests/onednn_softmax_test.cc @@ -0,0 +1,188 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +#include + +#include "xla/literal.h" +#include "xla/service/cpu/onednn_util.h" +#include "xla/shape_util.h" +#include "xla/test.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" + +namespace xla { +namespace cpu { + +class OneDnnSoftmaxTest : public HloTestBase {}; + +TEST_F(OneDnnSoftmaxTest, Softmaxtest) { + const std::string hlo_string = R"( + HloModule jit_softmax, entry_computation_layout={(f32[16,128,30522]{2,1,0})->f32[16,128,30522]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + region_0.4 { + Arg_0.5 = f32[] parameter(0) + Arg_1.6 = f32[] parameter(1) + ROOT maximum.7 = f32[] maximum(Arg_0.5, Arg_1.6) + } + region_1.15 { + Arg_0.16 = f32[] parameter(0) + Arg_1.17 = f32[] parameter(1) + ROOT add.18 = f32[] add(Arg_0.16, Arg_1.17) + } + ENTRY main.25 { + Arg_0.1 = f32[16,128,30522]{2,1,0} parameter(0), sharding={replicated} + constant.3 = f32[] constant(-inf) + reduce.8 = f32[16,128]{1,0} reduce(Arg_0.1, constant.3), dimensions={2}, to_apply=region_0.4 + reshape.9 = f32[16,128,1]{2,1,0} reshape(reduce.8) + broadcast.10 = f32[16,128,1]{2,1,0} broadcast(reshape.9), dimensions={0,1,2} + reshape.11 = f32[16,128]{1,0} reshape(broadcast.10) + broadcast.12 = f32[16,128,30522]{2,1,0} broadcast(reshape.11), dimensions={0,1} + subtract.13 = f32[16,128,30522]{2,1,0} subtract(Arg_0.1, broadcast.12) + exponential.14 = f32[16,128,30522]{2,1,0} exponential(subtract.13) + constant.2 = f32[] constant(0) + reduce.19 = f32[16,128]{1,0} reduce(exponential.14, constant.2), dimensions={2}, to_apply=region_1.15 + reshape.20 = f32[16,128,1]{2,1,0} reshape(reduce.19) + broadcast.21 = f32[16,128,1]{2,1,0} broadcast(reshape.20), dimensions={0,1,2} + reshape.22 = f32[16,128]{1,0} reshape(broadcast.21) + broadcast.23 = f32[16,128,30522]{2,1,0} broadcast(reshape.22), dimensions={0,1} + ROOT divide.24 = f32[16,128,30522]{2,1,0} divide(exponential.14, broadcast.23) + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4})); +} + +TEST_F(OneDnnSoftmaxTest, SoftmaxFP32) { + const std::string hlo_string = R"( + HloModule jit_softmax, entry_computation_layout={(f32[1,128,30522]{2,1,0})->f32[1,128,30522]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + region_0.4 { + Arg_0.5 = f32[] parameter(0) + Arg_1.6 = f32[] parameter(1) + ROOT maximum.7 = f32[] maximum(Arg_0.5, Arg_1.6) + } + region_1.15 { + Arg_0.16 = f32[] parameter(0) + Arg_1.17 = f32[] parameter(1) + ROOT add.18 = f32[] add(Arg_0.16, Arg_1.17) + } + ENTRY main.25 { + Arg_0.1 = f32[1,128,30522]{2,1,0} parameter(0), sharding={replicated} + constant.3 = f32[] constant(-inf) + reduce.8 = f32[1,128]{1,0} reduce(Arg_0.1, constant.3), dimensions={2}, to_apply=region_0.4 + reshape.9 = f32[1,128,1]{2,1,0} reshape(reduce.8) + broadcast.10 = f32[1,128,1]{2,1,0} broadcast(reshape.9), dimensions={0,1,2} + reshape.11 = f32[1,128]{1,0} reshape(broadcast.10) + broadcast.12 = f32[1,128,30522]{2,1,0} broadcast(reshape.11), dimensions={0,1} + subtract.13 = f32[1,128,30522]{2,1,0} subtract(Arg_0.1, broadcast.12) + exponential.14 = f32[1,128,30522]{2,1,0} exponential(subtract.13) + constant.2 = f32[] constant(0) + reduce.19 = f32[1,128]{1,0} reduce(exponential.14, constant.2), dimensions={2}, to_apply=region_1.15 + reshape.20 = f32[1,128,1]{2,1,0} reshape(reduce.19) + broadcast.21 = f32[1,128,1]{2,1,0} broadcast(reshape.20), dimensions={0,1,2} + reshape.22 = f32[1,128]{1,0} reshape(broadcast.21) + broadcast.23 = f32[1,128,30522]{2,1,0} broadcast(reshape.22), dimensions={0,1} + ROOT divide.24 = f32[1,128,30522]{2,1,0} divide(exponential.14, broadcast.23) + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4})); +} + +TEST_F(OneDnnSoftmaxTest, SoftmaxBF16) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const std::string hlo_string = R"( + HloModule jit_softmax, entry_computation_layout={(bf16[1,128,30522]{2,1,0})->bf16[1,128,30522]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + region_0.4 { + Arg_0.5 = bf16[] parameter(0) + Arg_1.6 = bf16[] parameter(1) + ROOT maximum.7 = bf16[] maximum(Arg_0.5, Arg_1.6) + } + region_1.15 { + Arg_0.16 = bf16[] parameter(0) + Arg_1.17 = bf16[] parameter(1) + ROOT add.18 = bf16[] add(Arg_0.16, Arg_1.17) + } + ENTRY main.25 { + Arg_0.1 = bf16[1,128,30522]{2,1,0} parameter(0), sharding={replicated} + constant.3 = bf16[] constant(-inf) + reduce.8 = bf16[1,128]{1,0} reduce(Arg_0.1, constant.3), dimensions={2}, to_apply=region_0.4 + reshape.9 = bf16[1,128,1]{2,1,0} reshape(reduce.8) + broadcast.10 = bf16[1,128,1]{2,1,0} broadcast(reshape.9), dimensions={0,1,2} + reshape.11 = bf16[1,128]{1,0} reshape(broadcast.10) + broadcast.12 = bf16[1,128,30522]{2,1,0} broadcast(reshape.11), dimensions={0,1} + subtract.13 = bf16[1,128,30522]{2,1,0} subtract(Arg_0.1, broadcast.12) + exponential.14 = bf16[1,128,30522]{2,1,0} exponential(subtract.13) + constant.2 = bf16[] constant(0) + reduce.19 = bf16[1,128]{1,0} reduce(exponential.14, constant.2), dimensions={2}, to_apply=region_1.15 + reshape.20 = bf16[1,128,1]{2,1,0} reshape(reduce.19) + broadcast.21 = bf16[1,128,1]{2,1,0} broadcast(reshape.20), dimensions={0,1,2} + reshape.22 = bf16[1,128]{1,0} reshape(broadcast.21) + broadcast.23 = bf16[1,128,30522]{2,1,0} broadcast(reshape.22), dimensions={0,1} + ROOT divide.24 = bf16[1,128,30522]{2,1,0} divide(exponential.14, broadcast.23) + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4})); +} + +TEST_F(OneDnnSoftmaxTest, SoftmaxF32toBF16) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const std::string hlo_string = R"( + HloModule jit_softmax, entry_computation_layout={(f32[16,128,30522]{2,1,0})->bf16[16,128,30522]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + region_0.4 { + Arg_0.5 = f32[] parameter(0) + Arg_1.6 = f32[] parameter(1) + ROOT maximum.7 = f32[] maximum(Arg_0.5, Arg_1.6) + } + region_1.15 { + Arg_0.16 = f32[] parameter(0) + Arg_1.17 = f32[] parameter(1) + ROOT add.18 = f32[] add(Arg_0.16, Arg_1.17) + } + ENTRY main.25 { + Arg_0.1 = f32[16,128,30522]{2,1,0} parameter(0), sharding={replicated} + constant.3 = f32[] constant(-inf) + reduce.8 = f32[16,128]{1,0} reduce(Arg_0.1, constant.3), dimensions={2}, to_apply=region_0.4 + reshape.9 = f32[16,128,1]{2,1,0} reshape(reduce.8) + broadcast.10 = f32[16,128,1]{2,1,0} broadcast(reshape.9), dimensions={0,1,2} + reshape.11 = f32[16,128]{1,0} reshape(broadcast.10) + broadcast.12 = f32[16,128,30522]{2,1,0} broadcast(reshape.11), dimensions={0,1} + subtract.13 = f32[16,128,30522]{2,1,0} subtract(Arg_0.1, broadcast.12) + exponential.14 = f32[16,128,30522]{2,1,0} exponential(subtract.13) + constant.2 = f32[] constant(0) + reduce.19 = f32[16,128]{1,0} reduce(exponential.14, constant.2), dimensions={2}, to_apply=region_1.15 + reshape.20 = f32[16,128,1]{2,1,0} reshape(reduce.19) + broadcast.21 = f32[16,128,1]{2,1,0} broadcast(reshape.20), dimensions={0,1,2} + reshape.22 = f32[16,128]{1,0} reshape(broadcast.21) + broadcast.23 = f32[16,128,30522]{2,1,0} broadcast(reshape.22), dimensions={0,1} + divide.24 = f32[16,128,30522]{2,1,0} divide(exponential.14, broadcast.23) + ROOT convert.1 = bf16[16,128,30522]{2,1,0} convert(divide.24) + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4})); +} + +} // namespace cpu +} // namespace xla + +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/xla/tests/outfeed_in_nested_computation_test.cc b/xla/tests/outfeed_in_nested_computation_test.cc index daa65e5da58b1..b111c620e96f2 100644 --- a/xla/tests/outfeed_in_nested_computation_test.cc +++ b/xla/tests/outfeed_in_nested_computation_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/pad_test.cc b/xla/tests/pad_test.cc index 919c54533fbf0..cd039d3daaf93 100644 --- a/xla/tests/pad_test.cc +++ b/xla/tests/pad_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/params_test.cc b/xla/tests/params_test.cc index 85baa6f22bfe1..8b55bb3203af8 100644 --- a/xla/tests/params_test.cc +++ b/xla/tests/params_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/pjrt_client_registry.cc b/xla/tests/pjrt_client_registry.cc index cd02928fbcbb5..c4f66923c229c 100644 --- a/xla/tests/pjrt_client_registry.cc +++ b/xla/tests/pjrt_client_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/pjrt_client_registry.h b/xla/tests/pjrt_client_registry.h index fd492df24ee95..7d82fddc058c3 100644 --- a/xla/tests/pjrt_client_registry.h +++ b/xla/tests/pjrt_client_registry.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,11 +32,11 @@ class PjRtClientTestFactoryRegistry { typedef std::function DeviceShapeRepresentationFn; typedef std::function DeviceShapeRepresentationFnFactory; - typedef std::function>()> + typedef std::function>()> PjRtClientFactory; static DeviceShapeRepresentationFn DefaultShapeRepresentationRegisteredFn( - StatusOr client) { + absl::StatusOr client) { return [](const Shape& host_shape) { return host_shape; }; } @@ -66,14 +66,14 @@ class PjRtClientTestFactoryRegistry { return factory_ != nullptr; } - std::function>()> Get() const { + std::function>()> Get() const { absl::MutexLock lock(&mu_); return factory_; } private: mutable absl::Mutex mu_; - std::function>()> factory_ + std::function>()> factory_ ABSL_GUARDED_BY(mu_); DeviceShapeRepresentationFnFactory registered_device_shape_representation_fn_; }; diff --git a/xla/tests/pjrt_cpu_client_registry.cc b/xla/tests/pjrt_cpu_client_registry.cc index d65884c15bbdb..540cf3d59ff6b 100644 --- a/xla/tests/pjrt_cpu_client_registry.cc +++ b/xla/tests/pjrt_cpu_client_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/pjrt_gpu_client_registry.cc b/xla/tests/pjrt_gpu_client_registry.cc index 11fe68be7fbba..3723b22935280 100644 --- a/xla/tests/pjrt_gpu_client_registry.cc +++ b/xla/tests/pjrt_gpu_client_registry.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/tests/pjrt_client_registry.h" @@ -20,17 +21,18 @@ namespace xla { namespace { // Register a GPU PjRt client for tests. -const bool kUnused = - (RegisterPjRtClientTestFactory([]() { - xla::GpuAllocatorConfig gpu_config; - gpu_config.kind = xla::GpuAllocatorConfig::Kind::kDefault; - gpu_config.preallocate = true; - gpu_config.memory_fraction = 0.08; - GpuClientOptions options; - options.allocator_config = gpu_config; - return GetStreamExecutorGpuClient(options); - }), - true); +const bool kUnused = (RegisterPjRtClientTestFactory([]() { + xla::GpuAllocatorConfig gpu_config; + gpu_config.kind = + xla::GpuAllocatorConfig::Kind::kDefault; + gpu_config.preallocate = true; + gpu_config.memory_fraction = 0.08; + gpu_config.collective_memory_size = 0; + GpuClientOptions options; + options.allocator_config = gpu_config; + return GetStreamExecutorGpuClient(options); + }), + true); } // namespace } // namespace xla diff --git a/xla/tests/plugin.bzl b/xla/tests/plugin.bzl index 107869fe59d43..4cd2568fa79ab 100644 --- a/xla/tests/plugin.bzl +++ b/xla/tests/plugin.bzl @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/xla/tests/pred_test.cc b/xla/tests/pred_test.cc index 0fecb57a67c69..89ba59cd70b35 100644 --- a/xla/tests/pred_test.cc +++ b/xla/tests/pred_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/prng_test.cc b/xla/tests/prng_test.cc index cf9e1dcbae618..5ba28ae38b73c 100644 --- a/xla/tests/prng_test.cc +++ b/xla/tests/prng_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include #include #include +#include +#include +#include #include "absl/types/span.h" +#include "unsupported/Eigen/SpecialFunctions" // from @eigen_archive #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/literal.h" @@ -43,8 +51,8 @@ class PrngTest : public ClientLibraryTestBase { // of the given range size. `expected_count` is the number of times each // possible value is expected to be generated. Thus, the sample size is // `range_size * expected_count`. - double UniformChiSquared(int32_t range_size, int32_t expected_count, - int64_t seed = 42); + void UniformChiSquared(int32_t range_size, int32_t expected_count, + int64_t seed = 42); }; template @@ -141,10 +149,30 @@ template T Square(T x) { return x * x; } + +// Calculates the p-value (probability) of a given chi-square value and degrees +// of freedom. +double ChiSquarePValue(double chi_square, int dof) { + // We are doing a right-tailed test so the p-value is calculated as 1 - CDF. + // + // The CDF can be computed using the regularized lower incomplete gamma + // function like so: + // gammainc(dof/2, chi_square/2). + // + // Seeing as we are interested in 1-CDF, we can compute this using the + // regularized upper incomplete gamma function like so: + // gammaincc(dof/2, chi_square/2). + // + // NIST/SEMATECH e-Handbook of Statistical Methods, 1.3.6.6.6. Chi-Square + // Distribution: Cumulative Distribution Function + // https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm#cdf + return Eigen::numext::igammac(0.5 * dof, 0.5 * chi_square); +} + } // namespace -double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, - int64_t seed) { +void PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, + int64_t seed) { int32_t sample_size = range_size * expected_count; XlaBuilder builder(TestName()); @@ -157,34 +185,48 @@ double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count, std::vector counts(range_size, 0); actual.EachCell( [&counts](absl::Span, int32_t value) { ++counts[value]; }); + LOG(INFO) << "sample_size = " << sample_size; + LOG(INFO) << "range_size = " << range_size; + LOG(INFO) << "expected_count = " << expected_count; + for (int32_t i = 0; i < range_size; ++i) { + LOG(INFO) << "counts[" << i << "] = " << counts[i]; + } int64_t sum = 0; for (int32_t i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); } - return static_cast(sum) / expected_count; + double chi_square = static_cast(sum) / expected_count; + int64_t dof = range_size - 1; + double p_value = ChiSquarePValue(chi_square, dof); + const double kLevelOfSignificance = 1e-5; + // We have two hypotheses: + // - null hypothesis: the distribution we sampled from cannot be distinguished + // from a uniform random distribution. + // - alternate hypothesis: the distribution we sampled from can be + // distinguished from a uniform random distribution. + // + // The lower our calculated p-value, the less likely we would get this result + // if the null hypothesis were true. If our p-value is greater than or equal + // to `kLevelOfSignificance`, we cannot reject the null hypothesis. + // + // Another way of saying this is that if our p-value is greater than or equal + // to `kLevelOfSignificance` then we can consider our data randomly + // distributed with a confidence of 1-kLevelOfSignificance; otherwise, if our + // p-value is less than `kLevelOfSignificance` then our data is non-random + // with a confidence of 1-kLevelOfSignificance. + EXPECT_GE(p_value, kLevelOfSignificance); } // We only test distribution of uniform discrete PRNG as other types are based // on it. // These range sizes are arbitrary but include prime numbers, powers of 2, and // other composite numbers. -// The level of significance in all these cases is 1/20. // TODO(b/35723038): Use parametrized tests where possible. -XLA_TEST_F(PrngTest, Uniformity7) { - EXPECT_LT(UniformChiSquared(7, 256), 12.5916); -} -XLA_TEST_F(PrngTest, Uniformity61) { - EXPECT_LT(UniformChiSquared(61, 256), 79.0819); -} -XLA_TEST_F(PrngTest, Uniformity64) { - EXPECT_LT(UniformChiSquared(64, 256), 82.5287); -} -XLA_TEST_F(PrngTest, Uniformity108) { - EXPECT_LT(UniformChiSquared(108, 256), 132.144); -} -XLA_TEST_F(PrngTest, Uniformity256) { - EXPECT_LT(UniformChiSquared(256, 512), 293.248); -} +XLA_TEST_F(PrngTest, Uniformity7) { UniformChiSquared(7, 256); } +XLA_TEST_F(PrngTest, Uniformity61) { UniformChiSquared(61, 256); } +XLA_TEST_F(PrngTest, Uniformity64) { UniformChiSquared(64, 256); } +XLA_TEST_F(PrngTest, Uniformity108) { UniformChiSquared(108, 256); } +XLA_TEST_F(PrngTest, Uniformity256) { UniformChiSquared(256, 256); } // TODO(b/134770669): May remove this test if we decide not to support map // computations with kRng instructions. diff --git a/xla/tests/ptxas_bug_120501638.cc b/xla/tests/ptxas_bug_120501638.cc index 9f56aac448beb..2c1217cf8be91 100644 --- a/xla/tests/ptxas_bug_120501638.cc +++ b/xla/tests/ptxas_bug_120501638.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/query_inferred_shape_test.cc b/xla/tests/query_inferred_shape_test.cc index d648358d211f2..673b736ca87aa 100644 --- a/xla/tests/query_inferred_shape_test.cc +++ b/xla/tests/query_inferred_shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); auto one = ConstantR0(&builder, 1.0); auto result = Add(one, one); - StatusOr shape_status = builder.GetShape(result); + absl::StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.value(); ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {}))); diff --git a/xla/tests/reduce_hlo_test.cc b/xla/tests/reduce_hlo_test.cc index dced103ed4c7c..60378a80daafa 100644 --- a/xla/tests/reduce_hlo_test.cc +++ b/xla/tests/reduce_hlo_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,7 +51,7 @@ class ReduceWithLayoutTest : public HloTestBase, public ::testing::WithParamInterface { public: - StatusOr> GetParsedModule() { + absl::StatusOr> GetParsedModule() { const char* const hlo_string = R"( HloModule BadReduce @@ -125,9 +125,7 @@ INSTANTIATE_TEST_CASE_P(ReduceWithLayoutTest_Instantiation, ReduceLayout{{3, 2, 1, 0}, {1, 0, 2}}, // ReduceLayout{{3, 2, 1, 0}, {2, 0, 1}}, // ReduceLayout{{3, 2, 1, 0}, {2, 1, 0}}, // - ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}, // - ReduceLayout{{1, 2, 3, 0}, {1, 0, 2}}, // - ReduceLayout{{0, 2, 1, 3}, {2, 0, 1}}), // + ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}), // PrintReduceLayout); } // namespace diff --git a/xla/tests/reduce_precision_test.cc b/xla/tests/reduce_precision_test.cc index 9b8c8f6595fa1..20002e53cebe0 100644 --- a/xla/tests/reduce_precision_test.cc +++ b/xla/tests/reduce_precision_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/reduce_test.cc b/xla/tests/reduce_test.cc index 843f276d82868..2e8bc921fa664 100644 --- a/xla/tests/reduce_test.cc +++ b/xla/tests/reduce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1044,6 +1044,50 @@ XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } +XLA_TEST_F(ReduceHloTest, ReduceAtomicF16) { + absl::string_view hlo_string = R"( +HloModule jit_reduce_axes12 + +region_0.3 { + Arg_0.4 = f16[] parameter(0) + Arg_1.5 = f16[] parameter(1) + ROOT minimum.6 = f16[] minimum(Arg_0.4, Arg_1.5) +} + +ENTRY main.8 { + constant.1 = f16[] constant(1) + Arg_0.1 = f16[2,16385,1]{2,1,0} broadcast(constant.1), dimensions={} + constant.2 = f16[] constant(inf) + ROOT reduce.7 = f16[2]{0} reduce(Arg_0.1, constant.2), dimensions={1,2}, to_apply=region_0.3 +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(ReduceHloTest, ReduceWithEpilogueMultiOutputFusion) { + absl::string_view hlo_string = R"( + HloModule test_module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + + ENTRY main { + %p0 = f32[1024] parameter(0) + %p1 = f32[] parameter(1) + %reduce = f32[] reduce(%p0, %p1), dimensions={0}, to_apply=add + %p2 = f32[1024] parameter(2) + %reduce2 = f32[] reduce(%p2, %p1), dimensions={0}, to_apply=add + %negate = f32[] negate(%reduce) + %log = f32[] log(%reduce) + ROOT %tuple = (f32[], f32[], f32[]) tuple(%negate, %reduce2, %log) + })"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + class VariadicReduceTest : public HloTestBase {}; XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) { diff --git a/xla/tests/reduce_window_test.cc b/xla/tests/reduce_window_test.cc index a6d7026d765b9..b76065f55649f 100644 --- a/xla/tests/reduce_window_test.cc +++ b/xla/tests/reduce_window_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/replay_test.cc b/xla/tests/replay_test.cc index 986cfd62515c3..65509462d7641 100644 --- a/xla/tests/replay_test.cc +++ b/xla/tests/replay_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/replicated_io_feed_test.cc b/xla/tests/replicated_io_feed_test.cc index fb32db39dc449..d3600e4602f13 100644 --- a/xla/tests/replicated_io_feed_test.cc +++ b/xla/tests/replicated_io_feed_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/reshape_motion_test.cc b/xla/tests/reshape_motion_test.cc index be9eae0ecb916..f826ed4aac7b6 100644 --- a/xla/tests/reshape_motion_test.cc +++ b/xla/tests/reshape_motion_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/reshape_test.cc b/xla/tests/reshape_test.cc index 263d45c266d68..6ea3b214a91f5 100644 --- a/xla/tests/reshape_test.cc +++ b/xla/tests/reshape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/reverse_test.cc b/xla/tests/reverse_test.cc index c7b6084693792..299ea416e3c9e 100644 --- a/xla/tests/reverse_test.cc +++ b/xla/tests/reverse_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/round_trip_packed_literal_test.cc b/xla/tests/round_trip_packed_literal_test.cc index 8fd57f12a7ba0..8a2fbf12d4e6c 100644 --- a/xla/tests/round_trip_packed_literal_test.cc +++ b/xla/tests/round_trip_packed_literal_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/round_trip_transfer_test.cc b/xla/tests/round_trip_transfer_test.cc index e7588016d3e77..368dd6c7dd283 100644 --- a/xla/tests/round_trip_transfer_test.cc +++ b/xla/tests/round_trip_transfer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/sample_file_test.cc b/xla/tests/sample_file_test.cc index 030aaac350377..367ef95880fb9 100644 --- a/xla/tests/sample_file_test.cc +++ b/xla/tests/sample_file_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/sample_text_test.cc b/xla/tests/sample_text_test.cc index d135fd34a75fe..576cdeffea8a8 100644 --- a/xla/tests/sample_text_test.cc +++ b/xla/tests/sample_text_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/scalar_computations_test.cc b/xla/tests/scalar_computations_test.cc index 7e6c96a596b67..8f45987e95aaa 100644 --- a/xla/tests/scalar_computations_test.cc +++ b/xla/tests/scalar_computations_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/str_cat.h" #include "absl/types/span.h" @@ -61,7 +62,13 @@ class ScalarComputationsTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); - op(lhs_op, rhs_op, {}); + XlaOp minmax_op = op(lhs_op, rhs_op, {}); + // Canonicalize NaNs so we can do a bitwise compare without caring about + // payloads. + if constexpr (std::is_floating_point_v) { + XlaOp isnan_op = Ne(minmax_op, minmax_op); + Select(isnan_op, ConstantR0(&builder, NAN), minmax_op); + } ComputeAndCompareR0(&builder, expected, {}); } }; diff --git a/xla/tests/scatter_test.cc b/xla/tests/scatter_test.cc index 513130f228004..0151b863e1a08 100644 --- a/xla/tests/scatter_test.cc +++ b/xla/tests/scatter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,13 +17,16 @@ limitations under the License. #include #include "absl/strings/substitute.h" +#include "xla/array2d.h" #include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" namespace xla { namespace { @@ -322,6 +325,39 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatter_F16) { + const std::string hlo_text = R"( +HloModule TensorFlowScatter_F16 + +add_f16 (lhs: f16[], rhs: f16[]) -> f16[] { + lhs = f16[] parameter(0) + rhs = f16[] parameter(1) + ROOT add = f16[] add(f16[] lhs, f16[] rhs) +} + +ENTRY main { + operand = f16[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f16[2,3] parameter(2) + ROOT scatter = f16[3,3] scatter(operand, indices, updates), + to_apply=add_f16, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Array2D operand_array( + {{1.1f, 2.2f, 3.3f}, {4.4f, 5.5f, 6.6f}, {7.7f, 8.8f, 9.9f}}); + Literal operand(ShapeUtil::MakeShape(F16, {3, 3})); + operand.PopulateR2FromArray2D(operand_array); + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Array2D updates_array({{0.4f, 1.1f, 0.7f}, {2.3f, 3.1f, 1.6f}}); + Literal updates(ShapeUtil::MakeShape(F16, {2, 3})); + updates.PopulateR2FromArray2D(updates_array); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter diff --git a/xla/tests/select_and_scatter_test.cc b/xla/tests/select_and_scatter_test.cc index 9822f9220bb8f..c4c9bb7b0a3fe 100644 --- a/xla/tests/select_and_scatter_test.cc +++ b/xla/tests/select_and_scatter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -86,7 +86,7 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { GetParam().window_strides, GetParam().padding_type, source, ConstantR0(&builder_, 0.0f), add_f32_); - ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5, 1e-5)); + ComputeAndCompare(&builder_, {}, ErrorSpec(3e-5, 3e-5)); } INSTANTIATE_TEST_CASE_P( diff --git a/xla/tests/select_test.cc b/xla/tests/select_test.cc index 39644bf7692a3..30eba6bdd783a 100644 --- a/xla/tests/select_test.cc +++ b/xla/tests/select_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/slice_test.cc b/xla/tests/slice_test.cc index d174a8a70572b..3800431fed8dc 100644 --- a/xla/tests/slice_test.cc +++ b/xla/tests/slice_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/test_macros.cc b/xla/tests/test_macros.cc index b4a4ba48a2214..13fddf2c15b8d 100644 --- a/xla/tests/test_macros.cc +++ b/xla/tests/test_macros.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/test_macros.h b/xla/tests/test_macros.h index c4ec6340d706b..91874e0636279 100644 --- a/xla/tests/test_macros.h +++ b/xla/tests/test_macros.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/test_utils.cc b/xla/tests/test_utils.cc index 80503956551a8..d2934f0395c6e 100644 --- a/xla/tests/test_utils.cc +++ b/xla/tests/test_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "xla/hlo/ir/hlo_casting_utils.h" @@ -125,14 +126,29 @@ void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { } template -void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, - bool no_duplicates, bool use_large_range) { +void PopulateWithFloatingPointData( + Literal* literal, std::minstd_rand0* engine, bool no_duplicates, + bool use_large_range, std::optional max_bits_of_precision) { using ComputeT = std::conditional_t; CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - if (no_duplicates) { + if (max_bits_of_precision.has_value()) { + CHECK(!use_large_range) << "Cannot set both use_large_range and " + "max_bits_of_precision for floating points."; + CHECK(!no_duplicates) << "Cannot set both no_duplicates and " + "max_bits_of_precision for floating points."; + std::uniform_int_distribution generator( + -(1 << *max_bits_of_precision), 1 << *max_bits_of_precision); + for (FloatT& value : literal->data()) { + int64_t temp = generator(*engine); + // We want to generate floating point numbers to a fixed precision, while + // keeping them between -1 and 1. This preserves their bits of precision + // while keeping the numbers small. + value = static_cast(temp * pow(2, -ceil(log2(abs(temp))))); + } + } else if (no_duplicates) { PopulateWithNoDuplicateData(literal, engine); } else if (use_large_range) { PopulateWithRandomFullRangeFloatingPointData(literal, engine); @@ -153,10 +169,12 @@ void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine, Literal real_lit(floating_point_shape); Literal imaginary_lit(floating_point_shape); - PopulateWithFloatingPointData(&real_lit, engine, no_duplicates, - use_large_range); - PopulateWithFloatingPointData(&imaginary_lit, engine, - no_duplicates, use_large_range); + PopulateWithFloatingPointData( + &real_lit, engine, no_duplicates, use_large_range, + /*max_bits_of_precision=*/std::nullopt); + PopulateWithFloatingPointData( + &imaginary_lit, engine, no_duplicates, use_large_range, + /*max_bits_of_precision=*/std::nullopt); absl::Span real_data = real_lit.data(); absl::Span imaginary_data = @@ -210,10 +228,13 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal, // the type. (floating point format only) // 'use_large_range' indicates the sampled data is from the full range of the // floating point format. (floating point format only) -StatusOr MakeFakeLiteralInternal( +// 'max_bits_of_precision' sets the data to have the given number of bits or +// less (integer or floating point formats only). +absl::StatusOr MakeFakeLiteralInternal( const Shape& shape, std::minstd_rand0* engine, std::optional> limit, bool is_sorted, - bool no_duplicates, bool use_large_range) { + bool no_duplicates, bool use_large_range, + std::optional max_bits_of_precision) { if (shape.IsTuple()) { std::vector elements; const auto& shape_tuple_shapes = shape.tuple_shapes(); @@ -222,7 +243,8 @@ StatusOr MakeFakeLiteralInternal( TF_ASSIGN_OR_RETURN( Literal element, MakeFakeLiteralInternal(element_shape, engine, limit, is_sorted, - no_duplicates, use_large_range)); + no_duplicates, use_large_range, + max_bits_of_precision)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -234,6 +256,7 @@ StatusOr MakeFakeLiteralInternal( // literal. Shape new_shape = shape; new_shape.mutable_layout()->clear_tiles(); + new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1); new_shape.mutable_layout()->set_element_size_in_bits(0); Literal literal(new_shape); @@ -244,7 +267,8 @@ StatusOr MakeFakeLiteralInternal( if constexpr (primitive_util::IsFloatingPointType( primitive_type_constant)) { PopulateWithFloatingPointData( - &literal, engine, no_duplicates, use_large_range); + &literal, engine, no_duplicates, use_large_range, + max_bits_of_precision); return OkStatus(); } if constexpr (primitive_type_constant == PRED) { @@ -263,6 +287,15 @@ StatusOr MakeFakeLiteralInternal( max = static_cast(limit->second); min = static_cast(limit->first); } + if (max_bits_of_precision.has_value()) { + max = std::min(max, + static_cast(1 << *max_bits_of_precision)); + if (primitive_util::IsSignedIntegralType( + primitive_type_constant)) { + min = std::max( + min, static_cast(-(1 << *max_bits_of_precision))); + } + } PopulateWithRandomIntegralDataWithBounds( &literal, engine, /*no_duplicate*/ no_duplicates, min, max); if (is_sorted) { @@ -436,10 +469,11 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr CreateLiteralForConstrainedUses( +absl::StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, const Shape& param_shape, - std::minstd_rand0* engine, bool use_large_range) { + std::minstd_rand0* engine, bool use_large_range, + std::optional max_bits_of_precision) { int64_t index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; @@ -516,9 +550,10 @@ StatusOr CreateLiteralForConstrainedUses( return Unimplemented("Conflicting operand generation constraints."); } if (index_bound != INT64_MAX) { - return MakeFakeLiteralInternal( - param_shape, engine, std::pair(0, index_bound), - needs_sorted_indices, no_duplicates, use_large_range); + return MakeFakeLiteralInternal(param_shape, engine, + std::pair(0, index_bound), + needs_sorted_indices, no_duplicates, + use_large_range, max_bits_of_precision); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: @@ -532,50 +567,54 @@ StatusOr CreateLiteralForConstrainedUses( return MakeFakeLiteralInternal( param_shape, engine, /*limit=*/std::nullopt, /*is_sorted=*/needs_sorted_indices, - /*no_duplicates=*/false, use_large_range); + /*no_duplicates=*/false, use_large_range, max_bits_of_precision); } } else { return MakeFakeLiteralInternal(param_shape, engine, /*limit=*/std::nullopt, /*is_sorted=*/needs_sorted_indices, - no_duplicates, use_large_range); + no_duplicates, use_large_range, + max_bits_of_precision); } } // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, - const HloInstruction& param, - const Shape& param_shape, - std::minstd_rand0* engine, - bool use_large_range, - bool treat_gte_as_data_formatting) { +absl::StatusOr MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param, + const Shape& param_shape, std::minstd_rand0* engine, bool use_large_range, + bool treat_gte_as_data_formatting, + std::optional max_bits_of_precision) { const auto constrained_uses = FindConstrainedUses(dataflow, param, treat_gte_as_data_formatting); return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape, - engine, use_large_range); + engine, use_large_range, + max_bits_of_precision); } } // namespace -StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, - bool use_large_range) { +absl::StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, + bool use_large_range) { auto engine = pseudo_random ? std::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*limit=*/std::nullopt, /*is_sorted=*/false, - /*no_duplicates=*/false, use_large_range); + /*no_duplicates=*/false, use_large_range, + /*max_bits_of_precision=*/std::nullopt); } -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, bool pseudo_random, bool use_large_range, - bool treat_gte_as_data_formatting) { + bool treat_gte_as_data_formatting, + std::optional max_bits_of_precision) { auto engine = pseudo_random ? std::make_unique() : nullptr; return MakeFakeArguments(module, engine.get(), use_large_range, - treat_gte_as_data_formatting); + treat_gte_as_data_formatting, max_bits_of_precision); } -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, std::minstd_rand0* engine, bool use_large_range, - bool treat_gte_as_data_formatting) { + bool treat_gte_as_data_formatting, + std::optional max_bits_of_precision) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); std::vector arguments(params.size()); @@ -594,7 +633,8 @@ StatusOr> MakeFakeArguments( TF_ASSIGN_OR_RETURN( arguments[i], MakeConstrainedArgument(*dataflow, *params[i], param_shape, engine, - use_large_range, treat_gte_as_data_formatting)); + use_large_range, treat_gte_as_data_formatting, + max_bits_of_precision)); } return std::move(arguments); } @@ -623,13 +663,6 @@ std::unique_ptr CreateCanonicalDot(const Shape& shape, shape, lhs, rhs, dot_dimension_numbers, precision_config); } -bool IsMlirLoweringEnabled() { - char* xla_flags = getenv("XLA_FLAGS"); - if (!xla_flags) { - return false; - } - return !absl::StrContains(xla_flags, "--xla_cpu_use_xla_runtime=false") && - (absl::StrContains(xla_flags, "--xla_cpu_use_xla_runtime")); -} +bool IsMlirLoweringEnabled() { return false; } } // namespace xla diff --git a/xla/tests/test_utils.h b/xla/tests/test_utils.h index ba6a03802f5df..fb101102d6381 100644 --- a/xla/tests/test_utils.h +++ b/xla/tests/test_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,8 +56,9 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random and use_large_range. -StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random = true, - bool use_large_range = false); +absl::StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true, + bool use_large_range = false); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -74,6 +75,10 @@ StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random = true, // // These constraints are best-effort only. // +// If max_bits_of_precision is set to a number, then floating point & integer +// types will be constrained to be represented in that number of bits. Setting +// it to 5 for integers would mean it only creates integers between -32 and 32. +// // If pseudo_random is true, the generated numbers will be generated // deterministically in a pseudo random way unless the values are constrated to // be e.g. init values as above. If pseudo_random is false, the returned values @@ -89,16 +94,18 @@ StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random = true, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, bool pseudo_random = true, - bool use_large_range = false, bool treat_gte_as_data_formatting = false); + bool use_large_range = false, bool treat_gte_as_data_formatting = false, + std::optional max_bits_of_precision = std::nullopt); // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. -StatusOr> MakeFakeArguments( +absl::StatusOr> MakeFakeArguments( const HloModule* module, std::minstd_rand0* engine, - bool use_large_range = false, bool treat_gte_as_data_formatting = false); + bool use_large_range = false, bool treat_gte_as_data_formatting = false, + std::optional max_bits_of_precision = std::nullopt); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/xla/tests/test_utils_test.cc b/xla/tests/test_utils_test.cc index 202062f62dfc8..82a95589e6b90 100644 --- a/xla/tests/test_utils_test.cc +++ b/xla/tests/test_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/tile_assignment_test.cc b/xla/tests/tile_assignment_test.cc index 89b42e68c4700..d28fb6e754964 100644 --- a/xla/tests/tile_assignment_test.cc +++ b/xla/tests/tile_assignment_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/token_hlo_test.cc b/xla/tests/token_hlo_test.cc index 7bdf8a46236bd..8991e97916a75 100644 --- a/xla/tests/token_hlo_test.cc +++ b/xla/tests/token_hlo_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,12 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "xla/service/hlo_verifier.h" +#include +#include +#include +#include + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_runner.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -78,47 +88,6 @@ XLA_TEST_F(TokenHloTest, TokenTree) { EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } -XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewUnverifiedModule(); - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction( - HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); - module->AddEntryComputation(builder.Build()); - - Status status = - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) - .Run(module.get()) - .status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT( - status.message(), - ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); -} - -XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewUnverifiedModule(); - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateParameter( - 0, - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), - "param")); - module->AddEntryComputation(builder.Build()); - - Status status = - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) - .Run(module.get()) - .status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT( - status.message(), - ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); -} - XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a // AfterAll instruction in the while body. diff --git a/xla/tests/topk_test.cc b/xla/tests/topk_test.cc index 9ce4450770b45..c8cc503f2c56f 100644 --- a/xla/tests/topk_test.cc +++ b/xla/tests/topk_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/transfer_manager_test.cc b/xla/tests/transfer_manager_test.cc index 6663022daca3a..eb7ac0fb4e032 100644 --- a/xla/tests/transfer_manager_test.cc +++ b/xla/tests/transfer_manager_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -328,7 +328,7 @@ XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; - auto stream2 = stream_->GetOrCreateSubStream(); + auto stream2 = stream_->GetOrCreateSubStream().value(); Literal result1, result2; diff --git a/xla/tests/transpose_test.cc b/xla/tests/transpose_test.cc index 80be54fa694e4..b936693ea52ab 100644 --- a/xla/tests/transpose_test.cc +++ b/xla/tests/transpose_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/triangular_solve_test.cc b/xla/tests/triangular_solve_test.cc index 49951a0d0719e..9e855c9aab6b9 100644 --- a/xla/tests/triangular_solve_test.cc +++ b/xla/tests/triangular_solve_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/tuple_test.cc b/xla/tests/tuple_test.cc index 741105b807701..d644ceb28959e 100644 --- a/xla/tests/tuple_test.cc +++ b/xla/tests/tuple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/unary_op_test.cc b/xla/tests/unary_op_test.cc index 174d4b16ba210..e8ea9ff1ae84c 100644 --- a/xla/tests/unary_op_test.cc +++ b/xla/tests/unary_op_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/value_inference_test.cc b/xla/tests/value_inference_test.cc index 532d83942ef81..819e7f87a90c9 100644 --- a/xla/tests/value_inference_test.cc +++ b/xla/tests/value_inference_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,8 +56,8 @@ class DynamismInferenceTest : public ValueInferenceTest { explicit DynamismInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + absl::StatusOr ComputeDynamismLiteral( + XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { TF_RETURN_IF_ERROR(builder->first_error()); ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal_slice, @@ -65,8 +65,8 @@ class DynamismInferenceTest : public ValueInferenceTest { return literal_slice.Clone(); } - StatusOr ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder, - ShapeIndex index = {}) { + absl::StatusOr ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder, + ShapeIndex index = {}) { TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(operand, builder, nullptr)); return literal.Get({}, index); @@ -558,7 +558,7 @@ class UpperBoundInferenceTest : public ValueInferenceTest { explicit UpperBoundInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeUpperBoundLiteral( + absl::StatusOr ComputeUpperBoundLiteral( XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal, @@ -715,7 +715,7 @@ class ConstValueInferenceTest : public ValueInferenceTest { explicit ConstValueInferenceTest(se::Platform* platform = nullptr) : platform_(platform) {} - StatusOr ComputeConstantValueLiteral( + absl::StatusOr ComputeConstantValueLiteral( XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { ValueInference value_inference(builder); TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant( diff --git a/xla/tests/vector_ops_reduce_test.cc b/xla/tests/vector_ops_reduce_test.cc index 17249e4765842..f0524ec0a6787 100644 --- a/xla/tests/vector_ops_reduce_test.cc +++ b/xla/tests/vector_ops_reduce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/vector_ops_simple_test.cc b/xla/tests/vector_ops_simple_test.cc index 9f092ef4b07cd..67e90d03f7162 100644 --- a/xla/tests/vector_ops_simple_test.cc +++ b/xla/tests/vector_ops_simple_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -321,10 +321,9 @@ XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { auto low = ConstantR1(&builder, {NAN, 1, 1}); auto high = ConstantR1(&builder, {3, NAN, 3}); auto x = ConstantR1(&builder, {2, 2, NAN}); - Clamp(low, x, high); - - std::vector expected = {NAN, NAN, NAN}; - ComputeAndCompareR1(&builder, expected, {}); + auto clamp = Clamp(low, x, high); + Eq(clamp, clamp); // Check for NaN. + ComputeAndCompareR1(&builder, {false, false, false}, {}); } XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { diff --git a/xla/tests/verified_hlo_module.cc b/xla/tests/verified_hlo_module.cc index 668aa6b42d88b..3eaacbc0ac7a1 100644 --- a/xla/tests/verified_hlo_module.cc +++ b/xla/tests/verified_hlo_module.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/verified_hlo_module.h b/xla/tests/verified_hlo_module.h index 3e8c64595610f..52fce003b9ccf 100644 --- a/xla/tests/verified_hlo_module.h +++ b/xla/tests/verified_hlo_module.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,11 +35,12 @@ class VerifiedHloModule : public HloModule { VerifiedHloModule(const std::string& name, const HloModuleConfig& config, bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, - std::function shape_size_function) + std::function shape_size_function, + HloPredicate instruction_can_change_layout_func = {}) : HloModule(name, config), - verifier_( - verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, - /*instruction_can_change_layout_func=*/{}, shape_size_function) {} + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func, shape_size_function) {} ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } diff --git a/xla/tests/while_test.cc b/xla/tests/while_test.cc index d94b3f7a6affc..9624a7f9079aa 100644 --- a/xla/tests/while_test.cc +++ b/xla/tests/while_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/xla_hlo_profile_test.cc b/xla/tests/xla_hlo_profile_test.cc index c6f41fb2a8f95..2806e63f94c07 100644 --- a/xla/tests/xla_hlo_profile_test.cc +++ b/xla/tests/xla_hlo_profile_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tests/xla_internal_test_main.cc b/xla/tests/xla_internal_test_main.cc index a95abc026d4dd..4786355349f53 100644 --- a/xla/tests/xla_internal_test_main.cc +++ b/xla/tests/xla_internal_test_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/text_literal_reader.cc b/xla/text_literal_reader.cc index 0838ad597da88..06773cc229f03 100644 --- a/xla/text_literal_reader.cc +++ b/xla/text_literal_reader.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ limitations under the License. namespace xla { -StatusOr TextLiteralReader::ReadPath(absl::string_view path) { +absl::StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -55,7 +55,7 @@ StatusOr TextLiteralReader::ReadPath(absl::string_view path) { TextLiteralReader::TextLiteralReader(tsl::RandomAccessFile* file) : file_(file) {} -StatusOr TextLiteralReader::ReadAllLines() { +absl::StatusOr TextLiteralReader::ReadAllLines() { tsl::io::RandomAccessInputStream stream(file_.get()); tsl::io::BufferedInputStream buf(&stream, 65536); std::string shape_string; diff --git a/xla/text_literal_reader.h b/xla/text_literal_reader.h index 0719bf3a03db3..a0d56611bac03 100644 --- a/xla/text_literal_reader.h +++ b/xla/text_literal_reader.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr ReadPath(absl::string_view path); + static absl::StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -48,7 +48,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr ReadAllLines(); + absl::StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/xla/text_literal_reader_test.cc b/xla/text_literal_reader_test.cc index 6ee3357d58855..afeed461c61be 100644 --- a/xla/text_literal_reader_test.cc +++ b/xla/text_literal_reader_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/text_literal_writer.cc b/xla/text_literal_writer.cc index 07e4ec9b7fdee..c07f66e6b5b73 100644 --- a/xla/text_literal_writer.cc +++ b/xla/text_literal_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/text_literal_writer.h b/xla/text_literal_writer.h index e0eca78642c3c..6ca94296cf73d 100644 --- a/xla/text_literal_writer.h +++ b/xla/text_literal_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/text_literal_writer_test.cc b/xla/text_literal_writer_test.cc index 5c52ed3ad7663..3c9c2d6161eef 100644 --- a/xla/text_literal_writer_test.cc +++ b/xla/text_literal_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 4fe8619b13af1..4f876c0d6ab4f 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -1,19 +1,12 @@ # Tools and utilities that aid in XLA development and usage. -load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load( - "//xla:xla.bzl", - "xla_cc_binary", - "xla_cc_test", - "xla_py_proto_library", -) -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load( "@tsl//tsl:tsl.bzl", "if_cuda_or_rocm", + "if_google", "tsl_gpu_library", ) load("@tsl//tsl:tsl.default.bzl", "filegroup") @@ -26,6 +19,14 @@ load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//xla:xla.bzl", + "xla_cc_binary", + "xla_cc_test", + "xla_py_proto_library", +) +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -55,6 +56,7 @@ xla_cc_binary( srcs = ["hex_floats_to_packed_literal.cc"], deps = [ "//xla:types", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@tsl//tsl/lib/io:buffered_inputstream", @@ -63,7 +65,6 @@ xla_cc_binary( "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", - "@tsl//tsl/util:command_line_flags", ], ) @@ -81,7 +82,6 @@ xla_cc_binary( "//xla:shape_util", "//xla:statusor", "//xla:types", - "//xla:xla_data_proto_cc", "//xla/client", "//xla/client:client_library", "//xla/client:local_client", @@ -171,7 +171,6 @@ xla_cc_binary( deps = [ "//xla:statusor", "//xla:types", - "//xla:xla_data_proto_cc", "//xla/client", "//xla/client:client_library", "//xla/client:local_client", @@ -230,11 +229,11 @@ xla_cc_binary( "//xla:statusor", "//xla:util", "//xla/service:hlo_proto_cc", + "//xla/tsl/util:command_line_flags", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", - "@tsl//tsl/util:command_line_flags", ], ) @@ -299,13 +298,13 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass_pipeline", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/util:command_line_flags", ], ) @@ -322,8 +321,10 @@ cc_library( "//xla/service:hlo_verifier", "//xla/service:rng_bit_generator_expander", "//xla/service:rng_expander", + "//xla/service:sharding_propagation", "//xla/service:triangular_solve_expander", - "@tsl//tsl/util:command_line_flags", + "//xla/service/spmd:stateful_rng_spmd_partitioner", + "//xla/tsl/util:command_line_flags", ], ) @@ -333,6 +334,7 @@ xla_cc_test( data = [ "tests/cholesky.hlo", "tests/invalid_concat.hlo", + "tests/spmd.hlo", ":hlo-expand", ], tags = [ @@ -396,6 +398,7 @@ xla_cc_binary( "//xla/service:hlo_runner", "//xla/service:local_service", "//xla/service:platform_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -404,7 +407,6 @@ xla_cc_binary( "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:subprocess", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -438,13 +440,13 @@ cc_library( "//xla:debug_options_flags", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "@com_google_absl//absl/strings", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", - "@tsl//tsl/platform:regexp", ], ) @@ -495,28 +497,55 @@ xla_py_proto_library( deps = [":run_hlo_module_proto"], ) +cc_library( + name = "hlo_decomposer_lib", + srcs = ["hlo_decomposer.cc"], + hdrs = ["hlo_decomposer.h"], + deps = [ + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "//xla/service:compilation_environments", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "run_hlo_module_lib", srcs = ["run_hlo_module.cc"], hdrs = ["run_hlo_module.h"], deps = [ ":hlo_control_flow_flattening", + ":hlo_decomposer_lib", ":hlo_module_loader", ":prepare_reference_module", ":run_hlo_module_proto_cc", "//xla:error_spec", "//xla:literal", "//xla:literal_comparison", + "//xla:status", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:hlo_runner", "//xla/service:hlo_verifier", "//xla/tests:test_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:path", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -534,13 +563,13 @@ xla_cc_binary( "//xla/service:hlo_runner", "//xla/service:interpreter_plugin", "//xla/service:platform_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", - "@tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -570,7 +599,12 @@ cc_library( srcs = ["hlo_control_flow_flattening.cc"], hdrs = ["hlo_control_flow_flattening.h"], deps = [ + "//xla:comparison_util", + "//xla:literal", "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:call_graph", "//xla/service:collective_ops_utils", @@ -578,7 +612,12 @@ cc_library( "//xla/service:hlo_pass", "//xla/service:tuple_util", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -592,7 +631,6 @@ xla_cc_test( "//xla/service:despecializer", "//xla/service:hlo_verifier", "//xla/service/spmd:spmd_partitioner", - "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/strings", @@ -626,30 +664,75 @@ xla_cc_binary( ], ) +xla_cc_binary( + name = "extract_collective_operations", + srcs = ["extract_collective_operations.cc"], + deps = [ + ":hlo_decomposer_lib", + ":hlo_module_loader", + "//xla:debug_options_flags", + "//xla:status", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:platform_port", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + ], +) + tsl_gpu_library( name = "xla_compile_lib", srcs = ["xla_compile_lib.cc"], hdrs = ["xla_compile_lib.h"], - cuda_deps = [ - ], defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + visibility = ["//visibility:public"], deps = [ + ":hlo_module_loader", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", "//xla:util", + "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/mlir_hlo", + "//xla/pjrt:mlir_to_hlo", "//xla/service:compiler", "//xla/service:executable", + "//xla/service:export_hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service:symbol_repository", "//xla/service:xla_compile_result_proto_cc_impl", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_executable", + "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@stablehlo//:register", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_time", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:status_to_from_proto", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "//xla/service/gpu:nvptx_compiler", @@ -658,9 +741,11 @@ tsl_gpu_library( "//xla/service/gpu:amdgpu_compiler", "//xla/service/gpu:amdgpu_compiler_impl", ]) + if_gpu_is_configured([ + "//xla/service/gpu:autotuner_util", "//xla/service/gpu:executable_proto_cc", "//xla/service/gpu:gpu_compiler", "//xla/stream_executor/gpu:gpu_init", + "//xla/service/gpu:gpu_symbol_repository", ]), ) @@ -668,10 +753,7 @@ xla_test( name = "xla_compile_lib_test", srcs = ["xla_compile_lib_test.cc"], backend_tags = { - "gpu": [ - "requires-gpu-nvidia", - "config-cuda-only", - ], + "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]), }, backends = [ "cpu", @@ -707,3 +789,18 @@ xla_test( "@tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) + +xla_test( + name = "hlo_decomposer_test", + srcs = ["hlo_decomposer_test.cc"], + deps = [ + ":hlo_decomposer_lib", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/tools/compute_cost.cc b/xla/tools/compute_cost.cc index f0b1ad29bb2a1..c5e0ddcda4e22 100644 --- a/xla/tools/compute_cost.cc +++ b/xla/tools/compute_cost.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,7 +56,7 @@ int main(int argc, char** argv) { return xla::ShapeUtil::ByteSizeOf(shape, 8); }); - TF_CHECK_OK(xla::LoadModuleFromFile(input, {}, format) + TF_CHECK_OK(xla::LoadModuleFromFile(input, format, {}) .value() ->entry_computation() ->root_instruction() diff --git a/xla/tools/convert_computation.cc b/xla/tools/convert_computation.cc index 652d8e5a7c4b4..fba823d4cedf6 100644 --- a/xla/tools/convert_computation.cc +++ b/xla/tools/convert_computation.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ limitations under the License. // Usage: convert_computation serialized_computation_proto // // bin2txt spits out the result to stdout. txt2bin modifies the file in place. - -#include +#ifndef _WIN32 #include +#endif +#include #include diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 27e03945949ac..351d01660aa2f 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/dumped_computation_to_operation_list.cc b/xla/tools/dumped_computation_to_operation_list.cc index a9a6ecdc5d11f..a916d796b0559 100644 --- a/xla/tools/dumped_computation_to_operation_list.cc +++ b/xla/tools/dumped_computation_to_operation_list.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/dumped_computation_to_text.cc b/xla/tools/dumped_computation_to_text.cc index eacc7831f0a7d..4a939100698e5 100644 --- a/xla/tools/dumped_computation_to_text.cc +++ b/xla/tools/dumped_computation_to_text.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/extract_collective_operations.cc b/xla/tools/extract_collective_operations.cc new file mode 100644 index 0000000000000..1a15fa8fbc1fe --- /dev/null +++ b/xla/tools/extract_collective_operations.cc @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo.pb.h" +#include "xla/status.h" +#include "xla/tools/hlo_decomposer.h" +#include "xla/tools/hlo_module_loader.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tsl/platform/env.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace { +const char* const kUsage = R"( +This tool extracts collective operations from HLO module and saves them together +to the separate module. + +Usage: +bazel run extract_collective_operations -- --input=path/to/hlo_module + --output=path/to/hlo_module +)"; +} // namespace + +namespace xla { +Status ExtractCollectiveOperations(const std::string& input, + const std::string& output) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr test_module, + LoadModuleFromFile(input, std::string(tsl::io::Extension(input)), + hlo_module_loader_details::Config(), nullptr)); + + std::vector collective_instructions; + for (const auto& op : test_module->computations()) { + for (const auto& instr : op->instructions()) { + if (absl::StartsWith(instr->name(), "all-")) { + collective_instructions.push_back(instr); + } + } + } + + if (collective_instructions.empty()) { + return absl::InternalError("No collective instructions found."); + } + auto collectives_module = + ExtractInstructionIntoNewModule(collective_instructions); + + QCHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), output, + collectives_module->ToString())) + << "Can't open or write output module at " << output; + return absl::OkStatus(); +} +} // namespace xla + +int main(int argc, char** argv) { + std::string input; + std::string output; + std::vector flag_list = { + tsl::Flag("input", &input, "input file"), + tsl::Flag("output", &output, "output file")}; + xla::AppendDebugOptionsFlags(&flag_list); + const std::string kUsageString = + absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list)); + bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list); + tsl::port::InitMain(kUsageString.c_str(), &argc, &argv); + if (!parse_ok) { + LOG(QFATAL) << kUsageString; + } + TF_CHECK_OK(xla::ExtractCollectiveOperations(input, output)); + return 0; +} diff --git a/xla/tools/hex_floats_to_packed_literal.cc b/xla/tools/hex_floats_to_packed_literal.cc index 609890ad231da..c4d591ba34928 100644 --- a/xla/tools/hex_floats_to_packed_literal.cc +++ b/xla/tools/hex_floats_to_packed_literal.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" #include "tsl/lib/io/buffered_inputstream.h" #include "tsl/lib/io/random_inputstream.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" using std::string; diff --git a/xla/tools/hlo_bisect/BUILD b/xla/tools/hlo_bisect/BUILD index 53208bb90b6ce..730e18da45425 100644 --- a/xla/tools/hlo_bisect/BUILD +++ b/xla/tools/hlo_bisect/BUILD @@ -1,3 +1,6 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + # Description: # A tool for reducing a HLO module that produces incorrect results. load( @@ -5,8 +8,6 @@ load( "xla_cc_binary", "xla_cc_test", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@bazel_skylib//rules:build_test.bzl", "build_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -31,8 +32,8 @@ xla_cc_binary( "//xla/service:cpu_plugin", "//xla/service:gpu_plugin", "//xla/service:interpreter_plugin", + "//xla/tsl/util:command_line_flags", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/util:command_line_flags", ] + if_cuda(["//xla/stream_executor/cuda:cublas_plugin"]), ) diff --git a/xla/tools/hlo_bisect/hlo_bisect.cc b/xla/tools/hlo_bisect/hlo_bisect.cc index fe8c4a5b04cf9..73b018323f34c 100644 --- a/xla/tools/hlo_bisect/hlo_bisect.cc +++ b/xla/tools/hlo_bisect/hlo_bisect.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tools/hlo_bisect/hlo_bisect_utils.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" -#include "tsl/util/command_line_flags.h" const char* const kUsage = R"( Given an HloModule that manifests an XLA bug, either crashes the compiler or diff --git a/xla/tools/hlo_bisect/hlo_bisect_state.cc b/xla/tools/hlo_bisect/hlo_bisect_state.cc index fa5de90f7fc87..19f1b29fd1b77 100644 --- a/xla/tools/hlo_bisect/hlo_bisect_state.cc +++ b/xla/tools/hlo_bisect/hlo_bisect_state.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ Status MorphModuleWithOutputs(HloModule* module, module->compute_computation_layout(); HloDCE dce; - StatusOr dce_result = dce.Run(module); + absl::StatusOr dce_result = dce.Run(module); return dce_result.status(); } @@ -127,7 +127,7 @@ Status MorphModuleWithLiterals( } xla::HloDCE dce; - StatusOr dce_status = dce.Run(module); + absl::StatusOr dce_status = dce.Run(module); return dce_status.status(); } @@ -144,12 +144,12 @@ bool InstructionNotReplaceableWithConstant(HloInstruction* instruction) { } // namespace -StatusOr HloBisectState::ShouldProcess() { +absl::StatusOr HloBisectState::ShouldProcess() { // Running the unmodified module should trigger the bug checker. return RunModule(*module_); } -StatusOr HloBisectState::TrimEntryComputation() { +absl::StatusOr HloBisectState::TrimEntryComputation() { bool changed_in_loop = false; bool changed = false; for (int iter = 0; changed || iter < 2; iter++) { @@ -172,11 +172,11 @@ std::unique_ptr&& HloBisectState::GetResult() { return std::move(module_); } -StatusOr HloBisectState::RunModule(const HloModule& module) { +absl::StatusOr HloBisectState::RunModule(const HloModule& module) { VLOG(3) << "Modified module: " << module.ToString(); // Run the modified module with the bug checker. - StatusOr bug_result = bug_checker_->Run(module); + absl::StatusOr bug_result = bug_checker_->Run(module); TF_RETURN_IF_ERROR(bug_result.status()); VLOG(3) << "Bug checker result: " << bug_result.value(); @@ -192,7 +192,7 @@ StatusOr HloBisectState::RunModule(const HloModule& module) { return bug_result; } -StatusOr HloBisectState::TrimByOutputs() { +absl::StatusOr HloBisectState::TrimByOutputs() { // Only available if the root instruction is a tuple. HloInstruction* root_instruction = module_->entry_computation()->root_instruction(); @@ -202,7 +202,7 @@ StatusOr HloBisectState::TrimByOutputs() { } // Run the modified module and return the error state. - auto run_modified = [&](int64_t start, int64_t end) -> StatusOr { + auto run_modified = [&](int64_t start, int64_t end) -> absl::StatusOr { std::unique_ptr new_module = module_->Clone(/*suffix=*/""); HloInstruction* const* new_operands = new_module->entry_computation()->root_instruction()->operands().begin(); @@ -245,7 +245,7 @@ StatusOr HloBisectState::TrimByOutputs() { return changed; } -StatusOr HloBisectState::TrimByInstructions() { +absl::StatusOr HloBisectState::TrimByInstructions() { HloComputation* computation = module_->entry_computation(); // If the root instruction is a tuple, exclude it from the bisect range. @@ -271,7 +271,7 @@ StatusOr HloBisectState::TrimByInstructions() { // Sanity check for the bug checker. if (bisect_high == computation->num_parameters()) { - return InternalError( + return Internal( "The checker fails on an empty computation! Something is not right. " "Can't bisect."); } @@ -285,7 +285,7 @@ StatusOr HloBisectState::TrimByInstructions() { return changed; } -StatusOr HloBisectState::TrimByUsingConstants() { +absl::StatusOr HloBisectState::TrimByUsingConstants() { // Use random literals for the instructions which do not trigger the bug // checker and also didn't get a definitive value from it. absl::flat_hash_map literal_map; @@ -298,7 +298,7 @@ StatusOr HloBisectState::TrimByUsingConstants() { auto it = foldable_instructions_values_.extract(instr->name()); literal_map.insert(std::move(it)); } else if (foldable_instructions_.contains(instr->name())) { - StatusOr literal_status = MakeFakeLiteral(instr->shape()); + absl::StatusOr literal_status = MakeFakeLiteral(instr->shape()); TF_RETURN_IF_ERROR(literal_status.status()); literal_map[instr->name()] = std::move(literal_status).value(); ++random_literals_count; @@ -337,11 +337,10 @@ Status HloBisectState::ExpectModuleIsBuggy() { } } if (bug_count != 0) { - return InternalErrorStrCat("The checker is non deterministic! (only ", - bug_count, " failures seen in ", - (retry_count + 1), " runs)"); + return InternalStrCat("The checker is non deterministic! (only ", bug_count, + " failures seen in ", (retry_count + 1), " runs)"); } - return InternalError("We \"lost\" the bug while bisecting!"); + return Internal("We \"lost\" the bug while bisecting!"); } } // namespace bisect diff --git a/xla/tools/hlo_bisect/hlo_bisect_state.h b/xla/tools/hlo_bisect/hlo_bisect_state.h index fdc152f5a7556..6cc19dd7bcbee 100644 --- a/xla/tools/hlo_bisect/hlo_bisect_state.h +++ b/xla/tools/hlo_bisect/hlo_bisect_state.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class BugCheckerInterface { virtual ~BugCheckerInterface() {} // Returns true if `module` has a bug we're interested in. - virtual StatusOr Run(const HloModule& module) = 0; + virtual absl::StatusOr Run(const HloModule& module) = 0; // Returns mapping of instruction names to their results after the run // (empty if this information is unavailable). @@ -54,11 +54,11 @@ class HloBisectState { : module_(std::move(module)), bug_checker_(bug_checker) {} // Returns true if the current module has a bug and should be processed. - StatusOr ShouldProcess(); + absl::StatusOr ShouldProcess(); // Trims the entry computation until no more reductions are possible. Returns // a boolean to indicate whether the computation has been reduced. - StatusOr TrimEntryComputation(); + absl::StatusOr TrimEntryComputation(); // Returns the resulting module. std::unique_ptr&& GetResult(); @@ -66,19 +66,19 @@ class HloBisectState { private: // Runs a modified module and updates the foldable instructions data, if // available. Returns true if `module` has a bug. - StatusOr RunModule(const HloModule& module); + absl::StatusOr RunModule(const HloModule& module); // Trims the entry computation by reducing the total number of outputs. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByOutputs(); + absl::StatusOr TrimByOutputs(); // Trims the entry computation by reducing the total number of instructions. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByInstructions(); + absl::StatusOr TrimByInstructions(); // Trims the given computation by replacing instructions with constant values. // Returns a boolean to indicate whether the computation has been reduced. - StatusOr TrimByUsingConstants(); + absl::StatusOr TrimByUsingConstants(); // Asserts that the module still has the bug. If negative, runs the bug // checker repeatedly to verify that it's deterministic. diff --git a/xla/tools/hlo_bisect/hlo_bisect_state_test.cc b/xla/tools/hlo_bisect/hlo_bisect_state_test.cc index 1e846071238c3..3e608c82ddb16 100644 --- a/xla/tools/hlo_bisect/hlo_bisect_state_test.cc +++ b/xla/tools/hlo_bisect/hlo_bisect_state_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ class TestBugSearch : public BugCheckerInterface { public: TestBugSearch(std::initializer_list opcodes) : opcodes_(opcodes) {} - StatusOr Run(const HloModule& module) override { + absl::StatusOr Run(const HloModule& module) override { auto has_opcode = [&](HloOpcode opcode) { return absl::c_any_of(module.entry_computation()->instructions(), [opcode](const HloInstruction* instr) { @@ -173,7 +173,7 @@ TEST_F(HloBisectStateTest, TrimByOutputsLostBug) { class CustomBugSearch : public TestBugSearch { public: CustomBugSearch() : TestBugSearch({HloOpcode::kConstant}) {} - StatusOr Run(const HloModule& module) override { + absl::StatusOr Run(const HloModule& module) override { TF_ASSIGN_OR_RETURN(bool has_constants, TestBugSearch::Run(module)); int program_size = module.entry_computation()->instruction_count(); return program_size == 5 && !has_constants; diff --git a/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/xla/tools/hlo_bisect/hlo_bisect_utils.cc index 94b7ca5e7e8a6..338d4755ce47c 100644 --- a/xla/tools/hlo_bisect/hlo_bisect_utils.cc +++ b/xla/tools/hlo_bisect/hlo_bisect_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,7 +56,7 @@ Literal ExecuteWithRunnerAndRetrieveResult(std::unique_ptr module, } // Loads the given HloProto as HloModule. -StatusOr> LoadModuleFromHloProto( +absl::StatusOr> LoadModuleFromHloProto( const HloProto& proto) { const HloModuleProto& module_proto = proto.hlo_module(); TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, @@ -65,8 +65,9 @@ StatusOr> LoadModuleFromHloProto( return CreateModuleFromProto(module_proto, module_config); } -StatusOr> LoadModuleAndInputDataFromHloSnapshot( - const HloSnapshot& snapshot, std::vector* input_data) { +absl::StatusOr> +LoadModuleAndInputDataFromHloSnapshot(const HloSnapshot& snapshot, + std::vector* input_data) { for (int64_t i = 0; i < snapshot.arguments_size(); ++i) { TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(snapshot.arguments(i))); @@ -79,7 +80,7 @@ StatusOr> LoadModuleAndInputDataFromHloSnapshot( return HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config); } -StatusOr GetModuleAndInputData( +absl::StatusOr GetModuleAndInputData( absl::string_view input_filename) { const std::string input_file(input_filename); tsl::Env* env = tsl::Env::Default(); @@ -95,7 +96,7 @@ StatusOr GetModuleAndInputData( } LOG(INFO) << input_file << " is not HloSnapshot. Trying HLO binary proto.\n"; HloProto hlo_proto; - StatusOr> module_or_status; + absl::StatusOr> module_or_status; if (tsl::ReadBinaryProto(env, input_file, &hlo_proto).ok()) { module_or_status = LoadModuleFromHloProto(hlo_proto); if (!module_or_status.ok()) { @@ -168,7 +169,7 @@ MiscompareChecker::MiscompareChecker(HloModule* module, // Generate input data and store the data for all the execution. std::minstd_rand0 rng_engine; if (input_data.empty()) { - StatusOr> input_status = + absl::StatusOr> input_status = MakeFakeArguments(module, &rng_engine); CHECK(input_status.ok()); input_data_ = std::move(input_status).value(); @@ -178,14 +179,14 @@ MiscompareChecker::MiscompareChecker(HloModule* module, } // Set up the reference platform. - StatusOr reference_platform_status = + absl::StatusOr reference_platform_status = PlatformUtil::GetPlatform(std::string(reference_platform)); CHECK(reference_platform_status.ok()); reference_runner_ = std::make_unique(reference_platform_status.value()); // Set up the test platform. - StatusOr test_platform_status = + absl::StatusOr test_platform_status = PlatformUtil::GetPlatform(std::string(test_platform)); CHECK(test_platform_status.ok()); test_runner_ = @@ -195,7 +196,7 @@ MiscompareChecker::MiscompareChecker(HloModule* module, // Executes the module with the test_runner and the reference_runner and // compares the results from the two runs. Returns true if the two results are // not near to indicate a bug exists. -StatusOr MiscompareChecker::Run(const HloModule& module) { +absl::StatusOr MiscompareChecker::Run(const HloModule& module) { std::unique_ptr test_module = module.Clone(/*suffix=*/""); // Make sure that the module config has a non-zero seed, which the CPU and GPU @@ -210,9 +211,8 @@ StatusOr MiscompareChecker::Run(const HloModule& module) { } // Prepare the reference module. - TF_ASSIGN_OR_RETURN( - std::unique_ptr reference_module, - PrepareReferenceModule(*test_module, test_runner_.get())); + TF_ASSIGN_OR_RETURN(std::unique_ptr reference_module, + PrepareReferenceModule(*test_module, test_runner_.get())); // Run the module on the reference platform. Literal reference_result = ExecuteWithRunnerAndRetrieveResult( @@ -225,7 +225,7 @@ StatusOr MiscompareChecker::Run(const HloModule& module) { /*run_hlo_passes=*/true); // Compare the results. - StatusOr<::testing::AssertionResult> status_or_result = + absl::StatusOr<::testing::AssertionResult> status_or_result = LiteralTestUtil::Near(/*expected=*/reference_result, /*actual=*/test_result, /*error_spec=*/error_spec_, @@ -241,18 +241,19 @@ absl::flat_hash_map MiscompareChecker::GetResults() { return {}; } -StatusOr> MiscompareChecker::PrepareReferenceModule( - const HloModule& hlo_module, HloRunnerInterface* hlo_runner) const { +absl::StatusOr> +MiscompareChecker::PrepareReferenceModule( + const HloModule& hlo_module, HloRunnerInterface* hlo_runner) const { // By default clone the test module (could be overridden). return xla::PrepareReferenceModule(hlo_module, hlo_runner); } -StatusOr ScriptChecker::Run(const HloModule& module) { +absl::StatusOr ScriptChecker::Run(const HloModule& module) { tsl::Env* env = tsl::Env::Default(); // Write hlo into a temporary file. std::string hlo_path; if (!env->LocalTempFilename(&hlo_path)) { - return InternalError("couldn't get temp HLO file name"); + return Internal("couldn't get temp HLO file name"); } absl::Cleanup hlo_cleaner = [&] { @@ -273,7 +274,7 @@ StatusOr ScriptChecker::Run(const HloModule& module) { script_subprocess.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE); script_subprocess.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); if (!script_subprocess.Start()) { - return InternalError("Failed to launch script"); + return Internal("Failed to launch script"); } std::string stderr_output; @@ -293,7 +294,7 @@ absl::flat_hash_map ScriptChecker::GetResults() { return {}; } -StatusOr> BisectRunner::RunEntry() { +absl::StatusOr> BisectRunner::RunEntry() { HloBisectState hlo_bisect(std::move(module_), bug_checker_.get()); TF_ASSIGN_OR_RETURN(bool has_bug, hlo_bisect.ShouldProcess()); if (!has_bug) { @@ -306,13 +307,13 @@ StatusOr> BisectRunner::RunEntry() { return hlo_bisect.GetResult(); } -StatusOr> BisectRunner::RunAll() { +absl::StatusOr> BisectRunner::RunAll() { std::unique_ptr original_module = std::move(module_); std::unique_ptr result; for (HloComputation* c : original_module->computations()) { LOG(INFO) << "Bisecting computation: " << c->name(); module_ = original_module->Clone(/*suffix=*/""); - StatusOr> new_result; + absl::StatusOr> new_result; if (c->IsEntryComputation()) { // Run on the entry computation with input data. new_result = RunEntry(); @@ -341,7 +342,7 @@ StatusOr> BisectRunner::RunAll() { void RunBisect(std::unique_ptr runner, bool all_computations, absl::string_view dump_path, absl::string_view output_format) { - StatusOr> bisect_status = + absl::StatusOr> bisect_status = all_computations ? runner->RunAll() : runner->RunEntry(); CHECK(bisect_status.ok()) << bisect_status.status().message(); @@ -352,7 +353,7 @@ void RunBisect(std::unique_ptr runner, bool all_computations, CHECK(dump_status.ok()) << dump_status.message(); } -StatusOr GetVerifiedModuleAndInputData( +absl::StatusOr GetVerifiedModuleAndInputData( absl::string_view input_filename) { std::unique_ptr module; std::vector input_data; diff --git a/xla/tools/hlo_bisect/hlo_bisect_utils.h b/xla/tools/hlo_bisect/hlo_bisect_utils.h index 9a0d34f5436fc..b0649a78579ae 100644 --- a/xla/tools/hlo_bisect/hlo_bisect_utils.h +++ b/xla/tools/hlo_bisect/hlo_bisect_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,10 +41,10 @@ class MiscompareChecker : public BugCheckerInterface { absl::string_view test_platform, absl::string_view reference_platform, ErrorSpec error_spec); - StatusOr Run(const HloModule& module) override; + absl::StatusOr Run(const HloModule& module) override; absl::flat_hash_map GetResults() override; - virtual StatusOr> PrepareReferenceModule( + virtual absl::StatusOr> PrepareReferenceModule( const HloModule& hlo_module, HloRunnerInterface* hlo_runner) const; private: @@ -61,7 +61,7 @@ class ScriptChecker : public BugCheckerInterface { public: explicit ScriptChecker(std::string path_to_script) : path_to_script_(std::move(path_to_script)) {} - StatusOr Run(const HloModule& module) override; + absl::StatusOr Run(const HloModule& module) override; absl::flat_hash_map GetResults() override; private: @@ -75,8 +75,8 @@ class BisectRunner { std::unique_ptr bug_checker) : module_(std::move(module)), bug_checker_(std::move(bug_checker)) {} - StatusOr> RunEntry(); - StatusOr> RunAll(); + absl::StatusOr> RunEntry(); + absl::StatusOr> RunAll(); protected: std::unique_ptr module_; @@ -90,7 +90,7 @@ void RunBisect(std::unique_ptr runner, bool all_computations, // Utility function for getting the verified module and optional inputs. using ModuleWithInputs = std::pair, std::vector>; -xla::StatusOr GetVerifiedModuleAndInputData( +absl::StatusOr GetVerifiedModuleAndInputData( absl::string_view input_filename); } // namespace bisect diff --git a/xla/tools/hlo_control_flow_flattening.cc b/xla/tools/hlo_control_flow_flattening.cc index 933c33b45f080..98e3d8584712e 100644 --- a/xla/tools/hlo_control_flow_flattening.cc +++ b/xla/tools/hlo_control_flow_flattening.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,35 @@ limitations under the License. #include "xla/tools/hlo_control_flow_flattening.h" #include -#include +#include #include +#include +#include #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/hlo_dce.h" #include "xla/service/tuple_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -155,7 +171,7 @@ Status HloControlFlowFlattening::FlattenWhileLoop( // non-get-tuple-element users with a new tuple instruction which has the // first N - 1 elements. auto replace_non_gte_users = - [](HloInstruction* new_tuple) -> StatusOr { + [](HloInstruction* new_tuple) -> absl::StatusOr { CHECK(new_tuple->shape().IsTuple()); HloInstruction* prefix = nullptr; std::vector users(new_tuple->users()); @@ -371,7 +387,8 @@ Status HloControlFlowFlattening::RemoveSendDone( return OkStatus(); } -Status HloControlFlowFlattening::RemoveCollective(HloInstruction* hlo) const { +absl::StatusOr HloControlFlowFlattening::RemoveCollective( + HloInstruction* hlo) const { HloComputation* computation = hlo->parent(); HloInstruction* custom_call = computation->AddInstruction(HloInstruction::CreateCustomCall( @@ -387,7 +404,7 @@ Status HloControlFlowFlattening::RemoveCollective(HloInstruction* hlo) const { std::string original_op_name(hlo->name()); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, custom_call)); custom_call->SetAndSanitizeName(original_op_name); - return OkStatus(); + return custom_call; } Status HloControlFlowFlattening::RemoveId(HloInstruction* hlo) const { @@ -399,13 +416,18 @@ Status HloControlFlowFlattening::RemoveId(HloInstruction* hlo) const { return OkStatus(); } -StatusOr HloControlFlowFlattening::Run( +absl::StatusOr HloControlFlowFlattening::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { auto call_graph = CallGraph::Build(module); bool changed = false; absl::flat_hash_set removed; for (HloComputation* computation : module->computations(execution_threads)) { + // Do not change computations that are wrapped by async calls. Instead we + // remove the async callers if needed. + if (computation->IsAsyncComputation()) { + continue; + } for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (removed.contains(instruction)) { @@ -447,9 +469,28 @@ StatusOr HloControlFlowFlattening::Run( changed = true; } } else if (remove_comm_ && IsCollective(instruction) && - !instruction->parent()->IsFusionComputation()) { - VLOG(1) << "Remove " << instruction->name(); - TF_RETURN_IF_ERROR(RemoveCollective(instruction)); + !instruction->parent()->IsFusionComputation() && + (instruction->opcode() != HloOpcode::kAsyncStart && + instruction->opcode() != HloOpcode::kAsyncUpdate)) { + // We do not remove kAsyncStart or kAsyncUpdate here since we expect + // them to be removed as a part of the async chain above. + // We should remove the async chain all together because the async + // wrapped computation is only associated with the AsyncStart. So we + // need to refer to the AsyncStart in order to determine whether + // the Done or the Update wraps a collective. + if (instruction->opcode() == HloOpcode::kAsyncDone) { + while (instruction->opcode() == HloOpcode::kAsyncDone || + instruction->opcode() == HloOpcode::kAsyncUpdate || + instruction->opcode() == HloOpcode::kAsyncStart) { + HloInstruction* operand = instruction->mutable_operand(0); + VLOG(1) << "Remove " << instruction->name(); + TF_RETURN_IF_ERROR(RemoveCollective(instruction).status()); + instruction = operand; + } + } else { + VLOG(1) << "Remove " << instruction->name(); + TF_RETURN_IF_ERROR(RemoveCollective(instruction).status()); + } changed = true; } else if (remove_comm_ && (instruction->opcode() == HloOpcode::kPartitionId || diff --git a/xla/tools/hlo_control_flow_flattening.h b/xla/tools/hlo_control_flow_flattening.h index 47a2d92a0bf3b..c70bc80a50407 100644 --- a/xla/tools/hlo_control_flow_flattening.h +++ b/xla/tools/hlo_control_flow_flattening.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,10 +19,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/call_graph.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/status.h" namespace xla { @@ -56,7 +60,7 @@ class HloControlFlowFlattening : public HloModulePass { ~HloControlFlowFlattening() override = default; absl::string_view name() const override { return "control-flow-flattening"; } using HloPassInterface::Run; - StatusOr Run( + absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; @@ -89,8 +93,9 @@ class HloControlFlowFlattening : public HloModulePass { bool remove_host_transfer_; protected: - // Replaces a collective op with a custom call. - Status RemoveCollective(HloInstruction* hlo) const; + // Replaces a collective op with a custom call and returns the custom call. + virtual absl::StatusOr RemoveCollective( + HloInstruction* hlo) const; bool remove_comm_; }; diff --git a/xla/tools/hlo_control_flow_flattening_test.cc b/xla/tools/hlo_control_flow_flattening_test.cc index a4d018da0ff88..a391a59ebdad3 100644 --- a/xla/tools/hlo_control_flow_flattening_test.cc +++ b/xla/tools/hlo_control_flow_flattening_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; class HloControlFlowFlatteningTest : public HloTestBase { public: - StatusOr> PartitionComputation( + absl::StatusOr> PartitionComputation( std::unique_ptr hlo_module, int64_t num_devices = 2) { spmd::SpmdPartitionerOptions options; auto collective_ops_creator = @@ -52,7 +52,7 @@ class HloControlFlowFlatteningTest : public HloTestBase { pass.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_RETURN_IF_ERROR(pass.Run(hlo_module.get()).status()); - return StatusOr>(std::move(hlo_module)); + return absl::StatusOr>(std::move(hlo_module)); } }; @@ -792,6 +792,27 @@ ENTRY main { EXPECT_EQ(module->entry_computation()->root_instruction()->name(), "fusion"); } +TEST_F(HloControlFlowFlatteningTest, AsyncAllToAll) { + absl::string_view hlo = R"( + + ENTRY main { + param = f32[4,8,128]{2,1,0} parameter(0) + all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}, u32[], u32[]) all-to-all-start(param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={1} + ROOT all-to-all-done = f32[4,8,128]{2,1,0} all-to-all-done(all-to-all-start) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + EXPECT_TRUE(IsCollective(module->entry_computation()->root_instruction())); + HloControlFlowFlattening flattening({}); + EXPECT_TRUE(flattening.Run(module.get()).value()); + TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::CustomCall(op::CustomCall(op::Parameter(0)))); +} + void CheckWhileBound(HloInstruction* while_op, int expected_bound) { auto* cond = while_op->while_condition(); ASSERT_NE(cond, nullptr); diff --git a/xla/tools/hlo_decomposer.cc b/xla/tools/hlo_decomposer.cc new file mode 100644 index 0000000000000..005355d671fe3 --- /dev/null +++ b/xla/tools/hlo_decomposer.cc @@ -0,0 +1,237 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/hlo_decomposer.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/compilation_environments.h" +#include "xla/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +// Returns whether it makes sense to run the given instruction in isolation +// (e.g. whether it can run without dependent instructions). +bool ShouldIsolateOpcode(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kConstant: + case HloOpcode::kGetTupleElement: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + return false; + default: + return true; + } +} + +absl::StatusOr>> Decompose( + const HloModule& module) { + std::vector> modules; + + absl::flat_hash_set computations_to_visit{ + module.entry_computation()}; + absl::flat_hash_set visited_computations; + + // Traverse the computation tree, starting from the entry computation, and + // recursing into the called computations. + while (!computations_to_visit.empty()) { + const HloComputation* computation = *computations_to_visit.begin(); + computations_to_visit.erase(computations_to_visit.begin()); + visited_computations.insert(computation); + + for (const HloInstruction* instruction : computation->instructions()) { + // Skip called computations in the embedded context (fusion, reduce, map, + // etc), as within these computations instructions are not lowered + // individually and it doesn't make sense to test them in isolation. + if (GetInstructionCallContext(instruction->opcode()) != + CallContext::kEmbedded) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + if (!visited_computations.contains(called_computation)) { + computations_to_visit.insert(called_computation); + } + } + } + if (ShouldIsolateOpcode(instruction->opcode())) { + modules.push_back(ExtractInstructionIntoNewModule(*instruction)); + } + } + } + + return modules; +} + +} // namespace + +absl::StatusOr>> DecomposeHloModule( + const HloModule& module, bool deduplicate_modules) { + std::vector> modules; + absl::flat_hash_set module_fingerprints; + + auto should_add_module = [&](const HloModule* module) { + if (!deduplicate_modules) { + return true; + } + const std::string fingerprint = module->GetFingerprint128(); + if (module_fingerprints.contains(fingerprint)) { + return false; + } + module_fingerprints.insert(fingerprint); + return true; + }; + + TF_ASSIGN_OR_RETURN(std::vector> isolated_modules, + Decompose(module)); + for (auto& module : isolated_modules) { + if (should_add_module(module.get())) { + modules.push_back(std::move(module)); + } + } + return modules; +} + +std::unique_ptr ExtractInstructionIntoNewModule( + const std::vector& instructions) { + CHECK(!instructions.empty()); + HloInstruction& first_instruction = *instructions[0]; + auto new_hlo_module = std::make_unique( + first_instruction.GetModule()->name() + "_collective_ops", + HloModuleConfig{}, + std::make_unique( + first_instruction.GetModule()->comp_envs())); + int parameter_number = 0; + HloComputation::Builder builder("entry_computation"); + HloCloneContext clone_context(new_hlo_module.get()); + std::vector new_instructions; + for (auto* hlo : instructions) { + std::vector new_operands; + for (const HloInstruction* operand : hlo->operands()) { + std::unique_ptr new_parameter = + HloInstruction::CreateParameter(parameter_number, operand->shape(), + operand->name()); + ++parameter_number; + new_operands.push_back(builder.AddInstruction(std::move(new_parameter))); + } + std::unique_ptr new_instruction = + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context); + new_instructions.push_back( + builder.AddInstruction(std::move(new_instruction))); + } + + std::unique_ptr tuple_instruction = + HloInstruction::CreateTuple(new_instructions); + builder.AddInstruction(std::move(tuple_instruction)); + new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); + return new_hlo_module; +} + +std::unique_ptr ExtractInstructionIntoNewModule( + const HloInstruction& hlo) { + auto new_hlo_module = std::make_unique( + std::string(hlo.name()), HloModuleConfig{}, + std::make_unique(hlo.GetModule()->comp_envs())); + int parameter_number = 0; + HloComputation::Builder builder("entry_computation"); + HloCloneContext clone_context(new_hlo_module.get()); + std::vector new_operands; + for (const HloInstruction* operand : hlo.operands()) { + std::unique_ptr new_parameter = + HloInstruction::CreateParameter(parameter_number, operand->shape(), + operand->name()); + ++parameter_number; + new_operands.push_back(builder.AddInstruction(std::move(new_parameter))); + } + std::unique_ptr new_instruction = + hlo.CloneWithNewOperands(hlo.shape(), new_operands, &clone_context); + builder.AddInstruction(std::move(new_instruction)); + new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); + return new_hlo_module; +} + +std::unique_ptr ExtractProducerConsumerIntoNewModule( + const HloInstruction& producer, const HloInstruction& consumer) { + auto new_hlo_module = + std::make_unique("extracted", HloModuleConfig{}, + std::make_unique( + consumer.GetModule()->comp_envs())); + int parameter_number = 0; + HloComputation::Builder builder("entry_computation"); + HloCloneContext clone_context(new_hlo_module.get()); + absl::InlinedVector producer_operands; + for (const HloInstruction* operand : producer.operands()) { + HloInstruction* new_parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + parameter_number, operand->shape(), operand->name())); + ++parameter_number; + + producer_operands.push_back(new_parameter); + } + + HloInstruction* new_producer = + builder.AddInstruction(producer.CloneWithNewOperands( + producer.shape(), producer_operands, &clone_context)); + + absl::flat_hash_map operand_map; + operand_map.emplace(&producer, new_producer); + + absl::InlinedVector consumer_operands; + for (const HloInstruction* operand : consumer.operands()) { + auto it = operand_map.find(operand); + if (it != operand_map.end()) { + consumer_operands.push_back(it->second); + } else { + HloInstruction* new_parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + parameter_number, operand->shape(), operand->name())); + ++parameter_number; + + consumer_operands.push_back(new_parameter); + } + } + builder.AddInstruction(consumer.CloneWithNewOperands( + consumer.shape(), consumer_operands, &clone_context)); + + new_hlo_module->AddEntryComputationWithLayouts(builder.Build()); + return new_hlo_module; +} + +std::unique_ptr ExtractComputationIntoNewModule( + const HloComputation& computation) { + auto new_hlo_module = + std::make_unique("extracted", HloModuleConfig{}, + std::make_unique( + computation.parent()->comp_envs())); + HloCloneContext clone_context(new_hlo_module.get()); + new_hlo_module->AddEntryComputationWithLayouts( + computation.CloneInContext(clone_context)); + return new_hlo_module; +} + +} // namespace xla diff --git a/xla/tools/hlo_decomposer.h b/xla/tools/hlo_decomposer.h new file mode 100644 index 0000000000000..d12b4d82216d1 --- /dev/null +++ b/xla/tools/hlo_decomposer.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TOOLS_HLO_DECOMPOSER_H_ +#define XLA_TOOLS_HLO_DECOMPOSER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { + +// Decomposes the `module` into individual ops and de-duplicates the decomposed +// op if `deduplicate_modules` is true. The modules are considered duplicate if +// if their computation graphs are isomorphic (i.e. computations and +// instructions are sorted, names are ignored etc). +absl::StatusOr>> DecomposeHloModule( + const HloModule& module, bool deduplicate_modules); + +// Extracts an HLO instruction into a new HLO module replacing its operands +// with parameter instructions. +std::unique_ptr ExtractInstructionIntoNewModule( + const HloInstruction& hlo); + +// Extracts HLO instructions into a new HLO module replacing all operands +// with parameter instructions even if the result of one instruction is used +// as a parameter to another. Combines results of all operations into the +// tuple and adds this tuple as a root instruction of the new module. +std::unique_ptr ExtractInstructionIntoNewModule( + const std::vector& instructions); + +// Extracts producer and consumer HLO instruction into a new HLO module +// replacing its operands with parameter instructions. +std::unique_ptr ExtractProducerConsumerIntoNewModule( + const HloInstruction& producer, const HloInstruction& consumer); + +// Extracts an HLO computation into a new HLO module, using its clone as the +// root computation. +std::unique_ptr ExtractComputationIntoNewModule( + const HloComputation& computation); + +} // namespace xla + +#endif // XLA_TOOLS_HLO_DECOMPOSER_H_ diff --git a/xla/tools/hlo_decomposer_test.cc b/xla/tools/hlo_decomposer_test.cc new file mode 100644 index 0000000000000..d60f94fdd26aa --- /dev/null +++ b/xla/tools/hlo_decomposer_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/hlo_decomposer.h" + +#include +#include + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +class HloDecomposerTest : public HloTestBase { + protected: + std::unique_ptr GetModule() { + absl::string_view kHlo = R"( +HloModule test_module, entry_computation_layout={(bf16[1024,8192]{1,0}, f32[8192]{0}, f32[16384]{0})->(bf16[1024]{0}, bf16[1024]{0}, f32[16384]{0}, f32[16384]{0})} + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add.1 = f32[] add(p0, p1) +} + +fused_computation.1 { + param_1.3 = f32[8192]{0} parameter(1) + broadcast.2 = f32[1024,8192]{1,0} broadcast(param_1.3), dimensions={1} + param_0.3 = bf16[1024,8192]{1,0} parameter(0) + convert.5 = f32[1024,8192]{1,0} convert(param_0.3) + multiply.2 = f32[1024,8192]{1,0} multiply(broadcast.2, convert.5) + c0_1 = f32[] constant(0) + reduce.1 = f32[1024]{0} reduce(multiply.2, c0_1), dimensions={1}, to_apply=add + ROOT convert.4 = bf16[1024]{0} convert(reduce.1) +} + +fused_computation.2 { + p0.0 = bf16[1024,8192]{1,0} parameter(0) + c.0 = f32[1024,8192]{1,0} convert(p0.0) + co0_1.1 = f32[] constant(0) + p.0 = f32[8192]{0} parameter(1) + b.0 = f32[1024,8192]{1,0} broadcast(p.0), dimensions={1} + m.0 = f32[1024,8192]{1,0} multiply(b.0, c.0) + r.0 = f32[1024]{0} reduce(m.0, co0_1.1), dimensions={1}, to_apply=add + ROOT c.1 = bf16[1024]{0} convert(r.0) +} + +exp { + param_0.5 = f32[16384]{0} parameter(0) + m.4 = f32[16384]{0} multiply(param_0.5, param_0.5) + e = f32[16384]{0} exponential(m.4) + l.clone.1 = f32[16384]{0} log(m.4) + ROOT tuple = (f32[16384]{0}, f32[16384]{0}) tuple(e, l.clone.1) +} + +ENTRY main { + p0.1 = bf16[1024,8192]{1,0} parameter(0) + p1.1 = f32[8192]{0} parameter(1) + fusion.1 = bf16[1024]{0} fusion(p0.1, p1.1), kind=kInput, calls=fused_computation.1 + fusion.2 = bf16[1024]{0} fusion(p0.1, p1.1), kind=kInput, calls=fused_computation.2 + p2 = f32[16384]{0} parameter(2) + e.1 = (f32[16384]{0}, f32[16384]{0}) fusion(p2), kind=kInput, calls=exp + get-tuple-element.1 = f32[16384]{0} get-tuple-element(e.1), index=1 + get-tuple-element = f32[16384]{0} get-tuple-element(e.1), index=0 + ROOT result = (bf16[1024]{0}, bf16[1024]{0}, f32[16384]{0}, f32[16384]{0}) tuple(fusion.1, fusion.2, get-tuple-element.1, get-tuple-element) +})"; + return ParseAndReturnVerifiedModule(kHlo).value(); + } + + void FindAndCompare(const std::vector>& modules, + absl::string_view module_name, + absl::string_view pattern) { + auto iter = + absl::c_find_if(modules, [&](const std::unique_ptr& module) { + return module->name() == module_name; + }); + EXPECT_NE(iter, modules.end()) << "No module named " << module_name; + if (iter == modules.end()) { + return; + } + EXPECT_TRUE(*RunFileCheck((*iter)->ToString(), pattern)); + } +}; + +TEST_F(HloDecomposerTest, DecomposeNoDedup) { + auto module = GetModule(); + TF_ASSERT_OK_AND_ASSIGN( + auto decomposed, + DecomposeHloModule(*module, /*deduplicate_modules=*/false)); + EXPECT_EQ(decomposed.size(), 3); + + FindAndCompare(decomposed, "fusion.1", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.1 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.1 +)"); + + FindAndCompare(decomposed, "fusion.2", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.2 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.2 +)"); + + FindAndCompare(decomposed, "e.1", R"( +CHECK: %exp{{.*}} { +CHECK: ENTRY +CHECK-THEN: %parameter.0 = f32[16384]{0} parameter(0) +CHECK-THEN: ROOT %e.1 +)"); +} + +TEST_F(HloDecomposerTest, DecomposeDedup) { + auto module = GetModule(); + TF_ASSERT_OK_AND_ASSIGN( + auto decomposed, + DecomposeHloModule(*module, /*deduplicate_modules=*/true)); + EXPECT_EQ(decomposed.size(), 2); + + FindAndCompare(decomposed, "fusion.1", R"( +CHECK: %add{{.*}} { +CHECK: %fused_computation.1 +CHECK: ENTRY +CHECK-THEN: %parameter.0 = bf16[1024,8192]{1,0} parameter(0) +CHECK-THEN: %parameter.1 = f32[8192]{0} parameter(1) +CHECK-THEN: ROOT %fusion.1 +)"); + + FindAndCompare(decomposed, "e.1", R"( +CHECK: %exp{{.*}} { +CHECK: ENTRY +CHECK-THEN: %parameter.0 = f32[16384]{0} parameter(0) +CHECK-THEN: ROOT %e.1 +)"); +} + +} // namespace +} // namespace xla diff --git a/xla/tools/hlo_expand.cc b/xla/tools/hlo_expand.cc index 95bc7e6b9eb1e..cd568564339d2 100644 --- a/xla/tools/hlo_expand.cc +++ b/xla/tools/hlo_expand.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,13 +24,16 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/service/rng_bit_generator_expander.h" #include "xla/service/rng_expander.h" +#include "xla/service/sharding_propagation.h" +#include "xla/service/spmd/stateful_rng_spmd_partitioner.h" #include "xla/service/triangular_solve_expander.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla_data.pb.h" -#include "tsl/util/command_line_flags.h" namespace xla { -void AddPassesToPipeline(HloExpandConfig& config, HloPassPipeline& pipeline) { +void AddPassesToPipeline(HloExpandConfig& config, HloPassPipeline& pipeline, + const HloModuleConfig& hlo_module_config) { if (config.batch_norm_grad_expander || config.batch_norm_inference_expander || config.batch_norm_training_expander) { pipeline.AddPass( @@ -55,6 +58,16 @@ void AddPassesToPipeline(HloExpandConfig& config, HloPassPipeline& pipeline) { if (config.triangular_solve_expander) { pipeline.AddPass(); } + if (config.spmd_expander) { + pipeline.AddPass( + /*is_spmd=*/true, /*propagate_metadata=*/false, + hlo_module_config.allow_spmd_sharding_propagation_to_output(), + hlo_module_config.allow_spmd_sharding_propagation_to_parameters()); + pipeline.AddPass( + hlo_module_config.num_partitions(), hlo_module_config.replica_count(), + hlo_module_config.debug_options() + .xla_gpu_threshold_for_windowed_einsum_mib()); + } if (config.verify_hlo) { pipeline.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); @@ -69,17 +82,17 @@ std::vector GetFlags(HloExpandConfig& config) { "input_format", &config.input_format, "The format of the input file. If this flag is not specified, it's" "inferred from the file extension instead. Valid values:\n " - "* hlo : HLO textual format\n" - "* pb : xla::HloProto in binary proto format\n" - "* pbtxt : xla::HloProto in text proto format"), + "* hlo|txt : HLO textual format\n" + "* pb : xla::HloProto in binary proto format\n" + "* pbtxt : xla::HloProto in text proto format"), tsl::Flag("o", &config.output_file, "Alias of --output_file="), tsl::Flag("output_file", &config.output_file, "Full output file path"), tsl::Flag("output_format", &config.output_format, "The format of the output file. Defaults to input_format. " "Valid values:\n" - "* hlo : HLO textual format\n" - "* pb : xla::HloProto in binary proto format\n" - "* pbtxt : xla::HloProto in text proto format"), + "* hlo|txt : HLO textual format\n" + "* pb : xla::HloProto in binary proto format\n" + "* pbtxt : xla::HloProto in text proto format"), tsl::Flag("batch_norm_expander", &config.batch_norm_expander, "Overrides and expands batch_norm_grad, batch_norm_inference, " "and batch_norm_training ops"), @@ -93,6 +106,8 @@ std::vector GetFlags(HloExpandConfig& config) { "Expands batch_norm_training_grad op"), tsl::Flag("cholesky_expander", &config.cholesky_expander, "Expands cholesky op"), + tsl::Flag("spmd_expander", &config.spmd_expander, + "Expands SPMD sharding"), tsl::Flag("expand_all", &config.expand_all, "Overrides and expands all supported passes below"), tsl::Flag("rng_expander", &config.rng_expander, "Expands rng op"), diff --git a/xla/tools/hlo_expand.h b/xla/tools/hlo_expand.h index 7019dc45e887f..5c1818e91a81d 100644 --- a/xla/tools/hlo_expand.h +++ b/xla/tools/hlo_expand.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ limitations under the License. #include #include "xla/service/hlo_pass_pipeline.h" -#include "tsl/util/command_line_flags.h" +#include "xla/tsl/util/command_line_flags.h" namespace xla { @@ -45,12 +45,14 @@ struct HloExpandConfig { bool rng_bit_generator_philox_expander{false}; bool rng_bit_generator_three_fry_expander{false}; bool triangular_solve_expander{false}; + bool spmd_expander{false}; bool verify_hlo{false}; }; // Adds passes to the `pipeline` for flags set in `config`. void AddPassesToPipeline(xla::HloExpandConfig& config, - xla::HloPassPipeline& pipeline); + xla::HloPassPipeline& pipeline, + const xla::HloModuleConfig& hlo_module_config); // Wraps `config` with flag descriptions and returns a vector of `tsl::Flag`s. std::vector GetFlags(xla::HloExpandConfig& config); diff --git a/xla/tools/hlo_expand_main.cc b/xla/tools/hlo_expand_main.cc index e0aea0aa92865..60a83c3b55837 100644 --- a/xla/tools/hlo_expand_main.cc +++ b/xla/tools/hlo_expand_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,11 +25,11 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/tools/hlo_expand.h" #include "xla/tools/hlo_module_loader.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" -#include "tsl/util/command_line_flags.h" namespace { @@ -97,9 +97,9 @@ int main(int argc, char** argv) { config.input_format = std::string(tsl::io::Extension(hlo_filename)); } if (config.input_format != "hlo" && config.input_format != "pb" && - config.input_format != "pbtxt") { + config.input_format != "pbtxt" && config.input_format != "txt") { std::cerr << absl::StrCat( - "input_format must be specified as [hlo|pb|pbtxt].\n", kHelpString); + "input_format must be specified as [hlo|pb|pbtxt|txt].\n", kHelpString); return 1; } @@ -117,7 +117,7 @@ int main(int argc, char** argv) { : std::string(tsl::io::Extension(output_filename)); } if (config.output_format != "hlo" && config.output_format != "pb" && - config.output_format != "pbtxt") { + config.output_format != "pbtxt" && config.output_format != "txt") { std::cerr << absl::StrCat( "output_format must be specified as [hlo|pb|pbtxt].\n", kHelpString); return 1; @@ -126,13 +126,12 @@ int main(int argc, char** argv) { // 1. Load HloModule from stdin or file. absl::StatusOr> status_or_module; if (hlo_filename == "-") { - std::string stdin; - std::getline(std::cin, stdin, static_cast(EOF)); - status_or_module = xla::LoadModuleFromData(stdin, config.input_format); + std::string input; + std::getline(std::cin, input, static_cast(EOF)); + status_or_module = xla::LoadModuleFromData(input, config.input_format); } else { - status_or_module = xla::LoadModuleFromFile( - hlo_filename, xla::hlo_module_loader_details::Config(), - config.input_format); + status_or_module = + xla::LoadModuleFromFile(hlo_filename, config.input_format); } if (!status_or_module.ok()) { std::cerr << status_or_module.status() << "\nTry: hlo-expand --help\n"; @@ -141,10 +140,10 @@ int main(int argc, char** argv) { // 2. Add a set of passes to the HloPassPipeline. xla::HloPassPipeline pipeline("expand_pass_pipeline"); - AddPassesToPipeline(config, pipeline); + auto& hlo_module = status_or_module.value(); + AddPassesToPipeline(config, pipeline, hlo_module->config()); // 3. Run a set of expander passes on the module. - auto& hlo_module = status_or_module.value(); auto pipeline_status = pipeline.Run(hlo_module.get()).status(); if (!pipeline_status.ok()) { std::cerr << pipeline_status; @@ -153,14 +152,14 @@ int main(int argc, char** argv) { // 4. Optionally print the output to stdout. if (output_filename == "-") { - if (config.output_format == "hlo") { + if (config.output_format == "hlo" || config.output_format == "txt") { std::cout << hlo_module->ToString(); } else if (config.output_format == "pbtxt") { std::cout << hlo_module->ToProto().DebugString(); } else { std::cerr << absl::StrCat( "Printing to stdout must specify supported " - "output_format=[hlo|pbtxt].\n", + "output_format=[hlo|pbtxt|txt].\n", kHelpString); return 1; } diff --git a/xla/tools/hlo_extractor.cc b/xla/tools/hlo_extractor.cc index 34f3e53e49636..f9f07153cd9c9 100644 --- a/xla/tools/hlo_extractor.cc +++ b/xla/tools/hlo_extractor.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,9 @@ limitations under the License. #include "xla/tools/hlo_extractor.h" +#ifndef _WIN32 #include +#endif #include #include @@ -187,7 +189,7 @@ class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { private: // Replace the `hlo` with Constant of the same shape. Status ReplaceWithConstant(const HloInstruction* hlo) { - StatusOr literal_status = MakeFakeLiteral(hlo->shape()); + absl::StatusOr literal_status = MakeFakeLiteral(hlo->shape()); TF_CHECK_OK(literal_status.status()); auto new_const = HloInstruction::CreateConstant(std::move(literal_status.value())); @@ -246,7 +248,8 @@ class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { builder->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(constant_shape.element_type()))); } else { - StatusOr literal_status = MakeFakeLiteral(constant_shape); + absl::StatusOr literal_status = + MakeFakeLiteral(constant_shape); TF_CHECK_OK(literal_status.status()); constant_instruction = builder->AddInstruction( HloInstruction::CreateConstant(std::move(literal_status.value()))); diff --git a/xla/tools/hlo_extractor.h b/xla/tools/hlo_extractor.h index 6dadb333a994d..cb1ea5914f7a0 100644 --- a/xla/tools/hlo_extractor.h +++ b/xla/tools/hlo_extractor.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/hlo_extractor_test.cc b/xla/tools/hlo_extractor_test.cc index 1246b390ab4fc..77e4d8b3a2760 100644 --- a/xla/tools/hlo_extractor_test.cc +++ b/xla/tools/hlo_extractor_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/hlo_module_loader.cc b/xla/tools/hlo_module_loader.cc index 3c72824634a32..7e3547e5ef53e 100644 --- a/xla/tools/hlo_module_loader.cc +++ b/xla/tools/hlo_module_loader.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -28,18 +29,18 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/regexp.h" namespace xla { namespace { -Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config, - HloModuleConfig* config) { +absl::Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config, + HloModuleConfig* config) { config->set_replica_count(ovr_config.num_replicas); config->set_num_partitions(ovr_config.num_partitions); return OkStatus(); @@ -47,12 +48,12 @@ Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config, } // namespace -std::string StripLogHeaders(const std::string& hlo_string) { +std::string StripLogHeaders(std::string_view hlo_string) { // I0521 12:04:45.883483 1509 service.cc:186] ... static RE2* matcher = new RE2( "[IWEF]\\d{4} " "\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)"); - absl::string_view matches[4]; + std::string_view matches[4]; std::vector lines = absl::StrSplit(hlo_string, '\n'); for (auto& line : lines) { if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) { @@ -65,9 +66,9 @@ std::string StripLogHeaders(const std::string& hlo_string) { }); } -StatusOr> LoadModuleFromData( - const std::string& data, const std::string& format, - hlo_module_loader_details::Config ovr_config, +absl::StatusOr> LoadModuleFromData( + const std::string& data, std::string_view format, + const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto) { DebugOptions debug_options = GetDebugOptionsFromFlags(); @@ -125,9 +126,9 @@ StatusOr> LoadModuleFromData( return std::move(module); } -StatusOr> LoadModuleFromFile( - const std::string& path, hlo_module_loader_details::Config ovr_config, - std::string format, +absl::StatusOr> LoadModuleFromFile( + const std::string& path, std::string format, + const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto) { std::string data; @@ -139,8 +140,8 @@ StatusOr> LoadModuleFromFile( buffer_assignment_proto); } -StatusOr> LoadInputFromData( - const std::string& data, absl::string_view format) { +absl::StatusOr> +LoadInputFromData(const std::string& data, std::string_view format) { HloSnapshot proto; if (format == "pb") { if (!proto.ParseFromString(data) && @@ -171,8 +172,8 @@ StatusOr> LoadInputFromData( return std::move(iteration_literals_proto); } -StatusOr> LoadInputFromFile( - const std::string& path, std::string format) { +absl::StatusOr> +LoadInputFromFile(const std::string& path, std::string format) { std::string data; if (format.empty()) { format = std::string(tsl::io::Extension(path)); diff --git a/xla/tools/hlo_module_loader.h b/xla/tools/hlo_module_loader.h index 201ee9fb5dd2b..29a4c7e8e5162 100644 --- a/xla/tools/hlo_module_loader.h +++ b/xla/tools/hlo_module_loader.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" @@ -38,7 +39,7 @@ struct Config { // Given a string composed by multiple lines, strip the log headers, if present // at the beginning of each line. -std::string StripLogHeaders(const std::string& hlo_string); +std::string StripLogHeaders(std::string_view hlo_string); // Loads an HLO module from a string. // The data can have the followings formats: @@ -54,9 +55,9 @@ std::string StripLogHeaders(const std::string& hlo_string); // modifications before use. If the buffer assignment proto pointer is not null // and the hlo module format is proto, it loads buffer assignment from the // proto. -StatusOr> LoadModuleFromData( - const std::string& data, const std::string& format, - hlo_module_loader_details::Config ovr_config = +absl::StatusOr> LoadModuleFromData( + const std::string& data, std::string_view format, + const hlo_module_loader_details::Config& ovr_config = hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr); @@ -76,11 +77,10 @@ StatusOr> LoadModuleFromData( // modifications before use. If the buffer assignment proto pointer is not null // and the hlo module format is proto, it loads buffer assignment from the // proto. -StatusOr> LoadModuleFromFile( - const std::string& path, - hlo_module_loader_details::Config ovr_config = +absl::StatusOr> LoadModuleFromFile( + const std::string& path, std::string format = "", + const hlo_module_loader_details::Config& ovr_config = hlo_module_loader_details::Config(), - std::string format = "", const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr); @@ -88,8 +88,8 @@ StatusOr> LoadModuleFromFile( // The data format must be one of the following: // 1) A binary proto (format "pb") // 2) A text proto (format "pbtxt") -StatusOr> LoadInputFromData( - const std::string& data, absl::string_view format); +absl::StatusOr> +LoadInputFromData(const std::string& data, std::string_view format); // Loads an HLO snapshot from file, only for its inputs // The file must be one of the following: @@ -97,8 +97,8 @@ StatusOr> LoadInputFromData( // 2) A text proto (with a .pbtxt extension) // If the format is specified (not empty), it overrides the one guessed from the // file extension. -StatusOr> LoadInputFromFile( - const std::string& path, std::string format = ""); +absl::StatusOr> +LoadInputFromFile(const std::string& path, std::string format = ""); } // namespace xla diff --git a/xla/tools/hlo_module_loader_test.cc b/xla/tools/hlo_module_loader_test.cc index 788472ce60b7d..e3916a0ec98ac 100644 --- a/xla/tools/hlo_module_loader_test.cc +++ b/xla/tools/hlo_module_loader_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/hlo_opt/BUILD b/xla/tools/hlo_opt/BUILD index a75f49637ef17..ffaa4ecb38e9f 100644 --- a/xla/tools/hlo_opt/BUILD +++ b/xla/tools/hlo_opt/BUILD @@ -1,20 +1,23 @@ -load("//xla:glob_lit_test.bzl", "glob_lit_tests") load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", ) load("@tsl//tsl:tsl.default.bzl", "filegroup") load( "@tsl//tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") # hlo-opt tool. load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,42 +36,78 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:hlo_graph_dumper", + "//xla/service:platform_util", + "//xla/stream_executor", "//xla/stream_executor:platform", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", ], ) cc_library( name = "gpu_opt", testonly = True, - srcs = if_cuda_is_configured(["gpu_opt.cc"]), + srcs = if_gpu_is_configured(["gpu_opt.cc"]), deps = [ ":opt_lib", "//xla:debug_options_flags", "//xla:statusor", "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_value", "//xla/service:compiler", "//xla/service:dump", + "//xla/service:executable", + "//xla/service:hlo_graph_dumper", "//xla/service:platform_util", + "//xla/service/gpu:buffer_sharing", + "//xla/service/gpu:compile_module_to_llvm_ir", "//xla/service/gpu:executable_proto_cc", - "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/service/gpu:gpu_compiler", + "//xla/service/gpu:gpu_hlo_schedule", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor", "//xla/stream_executor/platform", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ] + if_gpu_is_configured([ "//xla/service:gpu_plugin", "//xla/service/gpu:gpu_executable", ]) + if_cuda_is_configured([ "//xla/stream_executor:cuda_platform", + ]) + if_rocm_is_configured([ + "//xla/stream_executor:rocm_platform", ]), alwayslink = True, # Initializer needs to run. ) +cc_library( + name = "cpu_opt", + testonly = True, + srcs = ["cpu_opt.cc"], + deps = [ + ":opt_lib", + "//xla/service:cpu_plugin", + "//xla/service:hlo_graph_dumper", + "//xla/stream_executor/host:host_platform", + "//xla/stream_executor/platform", + ], + alwayslink = True, # Initializer needs to run. +) + cc_library( name = "opt_main", testonly = True, srcs = ["opt_main.cc"], deps = [ + "cpu_opt", ":opt_lib", "//xla:debug_options_flags", "//xla:status", @@ -78,7 +117,9 @@ cc_library( "//xla/service:platform_util", "//xla/tools:hlo_module_loader", "//xla/tools:run_hlo_module_lib", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", @@ -86,22 +127,50 @@ cc_library( "@tsl//tsl/platform:path", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", - "@tsl//tsl/util:command_line_flags", ] + if_gpu_is_configured([ ":gpu_opt", ]) + if_cuda_is_configured([ "//xla/stream_executor:cuda_platform", + ]) + if_rocm_is_configured([ + "//xla/stream_executor:rocm_platform", ]), ) -glob_lit_tests( - name = "gpu_opt_tests", +lit_test_suite( + name = "hlo_opt_tests", + srcs = enforce_glob( + [ + "cpu_hlo.hlo", + "gpu_hlo.hlo", + "gpu_hlo_backend.hlo", + "gpu_hlo_buffers.hlo", + "gpu_hlo_llvm.hlo", + "gpu_hlo_ptx.hlo", + "gpu_hlo_unoptimized_llvm.hlo", + ], + include = [ + "*.hlo", + ], + ), + args = if_cuda_is_configured([ + "--param=PTX=PTX", + "--param=GPU=a100_80", + ]) + if_rocm_is_configured([ + "--param=PTX=GCN", + "--param=GPU=mi200", + ]), + cfg = "//xla:lit.cfg.py", data = [":test_utilities"], - default_tags = tf_cuda_tests_tags() + [ + default_tags = tf_cuda_tests_tags(), + tags_override = { + "gpu_hlo_ptx.hlo": ["no_rocm"], + }, + tools = [ + "//xla/tools:hlo-opt", + "@llvm-project//llvm:FileCheck", ], - driver = "//xla:run_lit.sh", - test_file_exts = ["hlo"], ) # Bundle together all of the test utilities that are used by tests. @@ -109,9 +178,18 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "gpu_specs/a100.txtpb", + "gpu_specs/a100_80.txtpb", + "gpu_specs/mi200.txtpb", "//xla/tools:hlo-opt", "@llvm-project//llvm:FileCheck", - "@llvm-project//mlir:run_lit.sh", ], ) + +filegroup( + name = "all_gpu_specs", + data = glob(["gpu_specs/*.txtpb"]), +) + +exports_files(glob([ + "gpu_specs/*.txtpb", +])) diff --git a/xla/tools/hlo_opt/cpu_hlo.hlo b/xla/tools/hlo_opt/cpu_hlo.hlo new file mode 100644 index 0000000000000..1c9b2a81a9894 --- /dev/null +++ b/xla/tools/hlo_opt/cpu_hlo.hlo @@ -0,0 +1,12 @@ +// RUN: hlo-opt %s --platform=cpu --stage=hlo | FileCheck %s + +HloModule module + +ENTRY computation { +// CHECK: outer_dimension_partitions + p = f32[5000,6000]{1,0} parameter(0) + e = f32[5000,6000]{1,0} sqrt(p) + c = f32[6000,5000] transpose(p), dimensions={1,0} + r = f32[300,20,5000] reshape(c) + ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) +} diff --git a/xla/tools/hlo_opt/cpu_opt.cc b/xla/tools/hlo_opt/cpu_opt.cc new file mode 100644 index 0000000000000..9e7d5c2b72ace --- /dev/null +++ b/xla/tools/hlo_opt/cpu_opt.cc @@ -0,0 +1,37 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/stream_executor/platform/initialize.h" +#include "xla/tools/hlo_opt/opt_lib.h" + +namespace xla { + +namespace { + +class CpuOptProvider : public OptProvider { + public: + std::string GetPlatformName() override { return "cpu"; } +}; + +} // namespace +} // namespace xla + +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(cpu_opt_provider, { + xla::OptProvider::RegisterForPlatform( + "cpu", std::make_unique()); +}); diff --git a/xla/tools/hlo_opt/gpu_hlo.hlo b/xla/tools/hlo_opt/gpu_hlo.hlo old mode 100644 new mode 100755 index c170f21bb3194..0633b7e5ef7ce --- a/xla/tools/hlo_opt/gpu_hlo.hlo +++ b/xla/tools/hlo_opt/gpu_hlo.hlo @@ -1,7 +1,10 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=hlo --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=hlo --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s HloModule module +// --stage=hlo doesn't run scheduling. +// CHECK-NOT: is_scheduled + ENTRY computation { // CHECK: bitcast p = f32[5000,6000]{1,0} parameter(0) diff --git a/xla/tools/hlo_opt/gpu_hlo_backend.hlo b/xla/tools/hlo_opt/gpu_hlo_backend.hlo new file mode 100644 index 0000000000000..61b6b9aa778b9 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_hlo_backend.hlo @@ -0,0 +1,13 @@ +// RUN: hlo-opt %s --platform=gpu --stage=hlo-backend --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s + +HloModule module + +// CHECK: is_scheduled=true + +ENTRY computation { + p = f32[5000,6000]{1,0} parameter(0) + e = f32[5000,6000]{1,0} sqrt(p) + c = f32[6000,5000] transpose(p), dimensions={1,0} + r = f32[300,20,5000] reshape(c) + ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) +} diff --git a/xla/tools/hlo_opt/gpu_hlo_buffers.hlo b/xla/tools/hlo_opt/gpu_hlo_buffers.hlo index 434d6876f62c8..e7b8321cc6480 100644 --- a/xla/tools/hlo_opt/gpu_hlo_buffers.hlo +++ b/xla/tools/hlo_opt/gpu_hlo_buffers.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=buffer-assignment --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=buffer-assignment --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s HloModule m diff --git a/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/xla/tools/hlo_opt/gpu_hlo_llvm.hlo index b99bba81b2a39..59800a9d17056 100644 --- a/xla/tools/hlo_opt/gpu_hlo_llvm.hlo +++ b/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s HloModule m @@ -9,9 +9,29 @@ add { } -// CHECK: load half +// CHECK-LABEL: fusion +// CHECK: 2 x half ENTRY e { p1 = f16[1048576] parameter(0) i = f16[] constant(0) ROOT out = f16[] reduce(p1, i), dimensions={0}, to_apply=add } + +// ----- + +HloModule Test, is_scheduled=true + + +// CHECK-LABEL: fusion +// CHECK-PTX: call void @llvm.nvvm.barrier0 +// CHECK-GCN: call void @llvm.amdgcn.s.barrier +fused_computation { + param_0 = f32[100,200]{1,0} parameter(0) + ROOT b.1 = f32[100,200]{0,1} copy(f32[100,200]{1,0} param_0) +} + +ENTRY main { + a = f32[100, 200]{1,0} parameter(0) + ROOT wrapped_b = f32[100,200]{0,1} fusion(f32[100,200]{1,0} a), kind=kInput, calls=fused_computation +} + diff --git a/xla/tools/hlo_opt/gpu_hlo_ptx.hlo b/xla/tools/hlo_opt/gpu_hlo_ptx.hlo index 428bb07d95329..5c6485a57813a 100644 --- a/xla/tools/hlo_opt/gpu_hlo_ptx.hlo +++ b/xla/tools/hlo_opt/gpu_hlo_ptx.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s HloModule m @@ -9,9 +9,29 @@ add { } +// CHECK-LABEL: .target sm_80 // CHECK: shfl.sync.down ENTRY e { p1 = f16[1048576] parameter(0) i = f16[] constant(0) ROOT out = f16[] reduce(p1, i), dimensions={0}, to_apply=add } + +// ----- + +// CHECK-LABEL: .target sm_80 + +HloModule m + +add { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = f32[] add(a, b) +} + +// CHECK-NOT: cvt.u16.u32 +ENTRY e { + p1 = f32[3] parameter(0) + i = f32[] constant(0) + ROOT out = f32[] reduce(p1, i), dimensions={0}, to_apply=add +} diff --git a/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo b/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo new file mode 100644 index 0000000000000..6c8bc8bd54fe6 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo @@ -0,0 +1,51 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s + +// CHECK: fusion.in_bounds-true: +// CHECK: br label +// CHECK: concat_index_from_operand_id0: + + +HloModule module, is_scheduled=true + +%fused_computation (param_0.1: f32[1000], param_1.2: f32[1000], param_2.3: f32[1000], param_3.4: f32[1000], param_4.5: f32[1000], param_5.6: f32[1000], param_6.7: f32[1000], param_7.8: f32[1000], param_8.9: f32[1000], param_9.10: f32[1000], param_10.11: f32[1000]) -> f16[11000] { + %param_10.11 = f32[1000]{0} parameter(10) + %converted0.1 = f16[1000]{0} convert(f32[1000]{0} %param_10.11) + %param_9.10 = f32[1000]{0} parameter(9) + %converted1.1 = f16[1000]{0} convert(f32[1000]{0} %param_9.10) + %param_8.9 = f32[1000]{0} parameter(8) + %converted2.1 = f16[1000]{0} convert(f32[1000]{0} %param_8.9) + %param_7.8 = f32[1000]{0} parameter(7) + %converted3.1 = f16[1000]{0} convert(f32[1000]{0} %param_7.8) + %param_6.7 = f32[1000]{0} parameter(6) + %converted4.1 = f16[1000]{0} convert(f32[1000]{0} %param_6.7) + %param_5.6 = f32[1000]{0} parameter(5) + %converted5.1 = f16[1000]{0} convert(f32[1000]{0} %param_5.6) + %param_4.5 = f32[1000]{0} parameter(4) + %converted6.1 = f16[1000]{0} convert(f32[1000]{0} %param_4.5) + %param_3.4 = f32[1000]{0} parameter(3) + %converted7.1 = f16[1000]{0} convert(f32[1000]{0} %param_3.4) + %param_2.3 = f32[1000]{0} parameter(2) + %converted8.1 = f16[1000]{0} convert(f32[1000]{0} %param_2.3) + %param_1.2 = f32[1000]{0} parameter(1) + %converted9.1 = f16[1000]{0} convert(f32[1000]{0} %param_1.2) + %param_0.1 = f32[1000]{0} parameter(0) + %converted10.1 = f16[1000]{0} convert(f32[1000]{0} %param_0.1) + ROOT %out.1 = f16[11000]{0} concatenate(f16[1000]{0} %converted0.1, f16[1000]{0} %converted1.1, f16[1000]{0} %converted2.1, f16[1000]{0} %converted3.1, f16[1000]{0} %converted4.1, /*index=5*/f16[1000]{0} %converted5.1, f16[1000]{0} %converted6.1, f16[1000]{0} %converted7.1, f16[1000]{0} %converted8.1, f16[1000]{0} %converted9.1, /*index=10*/f16[1000]{0} %converted10.1), dimensions={0} +} + +ENTRY %computation (p0: f32[1000], p1: f32[1000], p2: f32[1000], p3: f32[1000], p4: f32[1000], p5: f32[1000], p6: f32[1000], p7: f32[1000], p8: f32[1000], p9: f32[1000], p10: f32[1000]) -> f16[11000] { + %p10 = f32[1000]{0} parameter(10) + %p9 = f32[1000]{0} parameter(9) + %p8 = f32[1000]{0} parameter(8) + %p7 = f32[1000]{0} parameter(7) + %p6 = f32[1000]{0} parameter(6) + %p5 = f32[1000]{0} parameter(5) + %p4 = f32[1000]{0} parameter(4) + %p3 = f32[1000]{0} parameter(3) + %p2 = f32[1000]{0} parameter(2) + %p1 = f32[1000]{0} parameter(1) + %p0 = f32[1000]{0} parameter(0) + ROOT %fusion = f16[11000]{0} fusion(f32[1000]{0} %p10, f32[1000]{0} %p9, f32[1000]{0} %p8, f32[1000]{0} %p7, f32[1000]{0} %p6, /*index=5*/f32[1000]{0} %p5, f32[1000]{0} %p4, f32[1000]{0} %p3, f32[1000]{0} %p2, f32[1000]{0} %p1, /*index=10*/f32[1000]{0} %p0), kind=kLoop, calls=%fused_computation +} + + diff --git a/xla/tools/hlo_opt/gpu_opt.cc b/xla/tools/hlo_opt/gpu_opt.cc index 9f6b72b8edd03..ee8f93aa7cb0e 100644 --- a/xla/tools/hlo_opt/gpu_opt.cc +++ b/xla/tools/hlo_opt/gpu_opt.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,111 +13,126 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include +#include #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "llvm/IR/LLVMContext.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/buffer_value.h" #include "xla/service/compiler.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/buffer_sharing.h" +#include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/executable.pb.h" +#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/platform_util.h" #include "xla/statusor.h" -#include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tools/hlo_opt/opt_lib.h" #include "xla/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -// TODO(cheshire): Switch CUDA/ROCM -static auto kGpuPlatformId = se::cuda::kCudaPlatformId; - -static StatusOr> ToGpuExecutable( - std::unique_ptr module, Compiler* compiler, - se::StreamExecutor* executor, const Compiler::CompileOptions& opts) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - DebugOptions d = optimized_module->config().debug_options(); - d.set_xla_embed_ir_in_executable(true); - optimized_module->mutable_config().set_debug_options(d); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - compiler->RunBackend(std::move(optimized_module), executor, opts)); - return executable; -} - -struct GpuOptProvider : public OptProvider { - StatusOr> GenerateStage( +class GpuOptProvider : public OptProvider { + public: + absl::StatusOr> GenerateStage( std::unique_ptr module, absl::string_view s) override { - TF_ASSIGN_OR_RETURN( - se::Platform * platform, - se::MultiPlatformManager::PlatformWithId(kGpuPlatformId)); - - TF_ASSIGN_OR_RETURN(Compiler * compiler, - Compiler::GetForPlatform(platform)); - DebugOptions debug_opts = GetDebugOptionsFromFlags(); - - Compiler::CompileOptions opts; - - se::StreamExecutor* executor = nullptr; - if (debug_opts.xla_gpu_target_config_filename().empty()) { - TF_ASSIGN_OR_RETURN(std::vector stream_executors, - PlatformUtil::GetStreamExecutors( - platform, /*allowed_devices=*/std::nullopt)); - executor = stream_executors[0]; - } - - if (s == "hlo") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - return optimized_module->ToString(); - } else if (s == "llvm") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + if (s == "llvm-before-optimizations") { + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + TF_ASSIGN_OR_RETURN(std::string llvm_ir, + LlvmIrBeforeOptimizations(optimized_module.get())); + return llvm_ir; + + } else if (s == "llvm" || s == "llvm-after-optimizations") { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get()) ->ir_module_string(); } else if (s == "ptx") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get())->text(); } else if (s == "buffer-assignment") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - ToGpuExecutable(std::move(module), compiler, executor, opts)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get()) ->buffer_assignment() ->ToVerboseString(9999); - } else if (s == "html") { - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimized_module, - compiler->RunHloPasses(std::move(module), executor, opts)); - return RenderGraph(optimized_module->name(), *optimized_module, - RenderedGraphFormat::kHtml, - /*show_fusion_subcomputations=*/false); + } else { + // Delegate to base class. + TF_ASSIGN_OR_RETURN(std::optional out, + OptProvider::GenerateStage(std::move(module), s)); + return out; } + } - // Unimplemented stage. - return std::nullopt; + std::string GetPlatformName() override { return "gpu"; } + + std::set SupportedStages() override { + std::set supported = OptProvider::SupportedStages(); + supported.insert({"ptx", "llvm", "buffer-assignment", + "llvm-before-optimizations", "llvm-after-optimizations"}); + return supported; } - std::vector SupportedStages() override { - return {"hlo", "llvm", "ptx", "buffer-assignment", "html"}; + private: + absl::StatusOr LlvmIrBeforeOptimizations( + HloModule* optimized_module) { + Compiler::CompileOptions opts; + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + TF_ASSIGN_OR_RETURN( + Compiler::TargetConfig target_config, + gpu::GpuCompiler::GetTargetConfig( + opts, optimized_module->config().debug_options(), executor)); + + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); + TF_ASSIGN_OR_RETURN(Compiler * compiler, + Compiler::GetForPlatform(platform)); + + auto* gpu_compiler = static_cast(compiler); + if (!optimized_module->has_schedule()) { + TF_ASSIGN_OR_RETURN(gpu::ScheduleMetadata schedule_metadata, + gpu::ScheduleGpuModule( + optimized_module, gpu_compiler->GetPointerSize(), + target_config.device_description)); + TF_RETURN_IF_ERROR(gpu_compiler->RunPostSchedulingPipelines( + optimized_module, schedule_metadata.scheduler_mem_limit, + target_config.device_description)); + } + + llvm::LLVMContext llvm_context; + TF_ASSIGN_OR_RETURN( + xla::gpu::CompileModuleResults results, + xla::gpu::CompileModuleToLlvmIr( + optimized_module, &llvm_context, gpu_compiler->GetTargetTriple(), + gpu_compiler->GetDataLayout(), platform->Name(), platform->id(), + target_config.device_description, gpu_compiler->GetCanShareBuffer(), + gpu_compiler->BufferSizeBytesFunction())); + return llvm_ir::DumpToString(results.llvm_module.get()); } }; } // namespace } // namespace xla -REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { xla::OptProvider::RegisterForPlatform( - xla::kGpuPlatformId, std::make_unique()); + "gpu", std::make_unique()); }); diff --git a/xla/tools/hlo_opt/gpu_specs/a100.txtpb b/xla/tools/hlo_opt/gpu_specs/a100.txtpb deleted file mode 100644 index 864125066c3ae..0000000000000 --- a/xla/tools/hlo_opt/gpu_specs/a100.txtpb +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -gpu_device_info { - cuda_compute_capability { - major: 8 - minor: 0 - } - threads_per_block_limit: 1024 - threads_per_warp: 32 - shared_memory_per_block: 65536 - shared_memory_per_block_optin: 65536 - shared_memory_per_core: 65536 - threads_per_core_limit: 2048 - core_count: 6192 - fpus_per_core: 64 - block_dim_limit_x: 2147483647 - block_dim_limit_y: 65535 - block_dim_limit_z: 65535 - memory_bandwidth: 2039000000000 - l2_cache_size: 4194304 - clock_rate_ghz: 1.1105 - device_memory_size: 79050250240 -} -platform_name: "CUDA" -dnn_version_info { - major: 8 - minor: 3 - patch: 2 -} -device_description_str: "A100 80GB" diff --git a/xla/tools/hlo_opt/gpu_specs/a100_40.txtpb b/xla/tools/hlo_opt/gpu_specs/a100_40.txtpb new file mode 100644 index 0000000000000..755bca3967114 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/a100_40.txtpb @@ -0,0 +1,41 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 167936 + threads_per_core_limit: 2048 + core_count: 108 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 1555200000000 + l2_cache_size: 41943040 + clock_rate_ghz: 1.41 + device_memory_size: 42331013120 + shared_memory_per_block_optin: 166912 + cuda_compute_capability { + major: 8 + } +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 9 + patch: 4 +} +device_description_str: "NVIDIA A100-SXM4-40GB" diff --git a/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb b/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb new file mode 100644 index 0000000000000..cf29fa306fdb9 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/a100_80.txtpb @@ -0,0 +1,41 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 167936 + threads_per_core_limit: 2048 + core_count: 108 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 2039000000000 + l2_cache_size: 41943040 + clock_rate_ghz: 1.1105 + device_memory_size: 79050250240 + shared_memory_per_block_optin: 166912 + cuda_compute_capability { + major: 8 + } +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 3 + patch: 2 +} +device_description_str: "A100 80GB" diff --git a/xla/tools/hlo_opt/gpu_specs/a6000.txtpb b/xla/tools/hlo_opt/gpu_specs/a6000.txtpb new file mode 100644 index 0000000000000..c4dafc26b15b7 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/a6000.txtpb @@ -0,0 +1,43 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 102400 + threads_per_core_limit: 1536 + core_count: 84 + fpus_per_core: 128 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 768096000000 + l2_cache_size: 6291456 + clock_rate_ghz: 1.8 + device_memory_size: 51041271808 + shared_memory_per_block_optin: 101376 + cuda_compute_capability { + major: 8 + minor: 6 + } +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 9 + patch: 4 +} +device_description_str: "NVIDIA RTX A6000" + diff --git a/xla/tools/hlo_opt/gpu_specs/h100.txtpb b/xla/tools/hlo_opt/gpu_specs/h100.txtpb new file mode 100644 index 0000000000000..cee3fd499eb06 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/h100.txtpb @@ -0,0 +1,41 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 233472 + threads_per_core_limit: 2048 + core_count: 132 + fpus_per_core: 128 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 3352320000000 + l2_cache_size: 52428800 + clock_rate_ghz: 1.98 + device_memory_size: 84978434048 + shared_memory_per_block_optin: 232448 + cuda_compute_capability { + major: 9 + } +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 9 + patch: 4 +} +device_description_str: "NVIDIA H100 80GB HBM3" diff --git a/xla/tools/hlo_opt/gpu_specs/h100_pcie.txtpb b/xla/tools/hlo_opt/gpu_specs/h100_pcie.txtpb new file mode 100644 index 0000000000000..b2213143b6dfb --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/h100_pcie.txtpb @@ -0,0 +1,42 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + cuda_compute_capability { + major: 9 + minor: 0 + } + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + clock_rate_ghz: 1.755 + core_count: 114 + device_memory_size: 84942979072 + fpus_per_core: 128 + l2_cache_size: 52428800 + memory_bandwidth: 2039040000000 + shared_memory_per_block_optin: 232448 + shared_memory_per_block: 49152 + shared_memory_per_core: 233472 + threads_per_block_limit: 1024 + threads_per_core_limit: 2048 + threads_per_warp: 32 +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 9 + patch: 2 +} +device_description_str: "NVIDIA H100 PCIe" diff --git a/xla/tools/hlo_opt/gpu_specs/mi200.txtpb b/xla/tools/hlo_opt/gpu_specs/mi200.txtpb new file mode 100644 index 0000000000000..3c882ac2dcbe5 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/mi200.txtpb @@ -0,0 +1,37 @@ +# Copyright 2023 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 64 + shared_memory_per_block: 65536 + shared_memory_per_core: 65536 + threads_per_core_limit: 2048 + core_count: 110 + fpus_per_core: 128 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 2147483647 + block_dim_limit_z: 2147483647 + memory_bandwidth: 1638400000000 + l2_cache_size: 8388608 + clock_rate_ghz: 1.7 + device_memory_size: 67628957696 + rocm_compute_capability { + gcn_arch_name: "gfx90a:sramecc+:xnack-" + } +} +platform_name: "ROCM" +dnn_version_info { +} +device_description_str: "MI250" diff --git a/xla/tools/hlo_opt/gpu_specs/p100.txtpb b/xla/tools/hlo_opt/gpu_specs/p100.txtpb new file mode 100644 index 0000000000000..c9fcfb0b5981a --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/p100.txtpb @@ -0,0 +1,41 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 65536 + threads_per_core_limit: 2048 + core_count: 56 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 732160000000 + l2_cache_size: 4194304 + clock_rate_ghz: 1.4805 + device_memory_size: 17066622976 + shared_memory_per_block_optin: 49152 + cuda_compute_capability { + major: 6 + } +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 9 + patch: 4 +} +device_description_str: "Tesla P100-SXM2-16GB" diff --git a/xla/tools/hlo_opt/gpu_specs/v100.txtpb b/xla/tools/hlo_opt/gpu_specs/v100.txtpb new file mode 100644 index 0000000000000..5a6526e88eda8 --- /dev/null +++ b/xla/tools/hlo_opt/gpu_specs/v100.txtpb @@ -0,0 +1,38 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 98304 + threads_per_core_limit: 2048 + core_count: 80 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 898048000000 + l2_cache_size: 6291456 + clock_rate_ghz: 1.53 + device_memory_size: 16935419904 + shared_memory_per_block_optin: 98304 + cuda_compute_capability { + major: 7 + } +} +platform_name: "CUDA" +dnn_version_info { +} +device_description_str: "Tesla V100-SXM2-16GB" diff --git a/xla/tools/hlo_opt/opt_lib.cc b/xla/tools/hlo_opt/opt_lib.cc index 3bf19f43004f2..5435c0a27259c 100644 --- a/xla/tools/hlo_opt/opt_lib.cc +++ b/xla/tools/hlo_opt/opt_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,33 @@ limitations under the License. #include "xla/tools/hlo_opt/opt_lib.h" +#include +#include +#include +#include +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/compiler.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/platform_util.h" +#include "xla/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/types.h" +#include "tsl/platform/statusor.h" namespace xla { using ProviderMap = - absl::flat_hash_map>; + absl::flat_hash_map>; static absl::Mutex provider_mu(absl::kConstInit); static ProviderMap& GetProviderMap() { @@ -30,22 +50,114 @@ static ProviderMap& GetProviderMap() { } /*static*/ void OptProvider::RegisterForPlatform( - se::Platform::Id platform, - std::unique_ptr translate_provider) { + std::string platform, std::unique_ptr translate_provider) { absl::MutexLock l(&provider_mu); CHECK(!GetProviderMap().contains(platform)); - GetProviderMap()[platform] = std::move(translate_provider); + absl::StatusOr canonical_name = + xla::PlatformUtil::CanonicalPlatformName(platform); + CHECK_OK(canonical_name); + GetProviderMap()[*canonical_name] = std::move(translate_provider); } -/*static*/ OptProvider* OptProvider::ProviderForPlatform( - se::Platform::Id platform) { +/*static*/ absl::StatusOr OptProvider::ProviderForPlatform( + std::string platform) { absl::MutexLock l(&provider_mu); - auto it = GetProviderMap().find(platform); + + TF_ASSIGN_OR_RETURN(std::string canonical_name, + xla::PlatformUtil::CanonicalPlatformName(platform)); + auto it = GetProviderMap().find(canonical_name); if (it == GetProviderMap().end()) { - return nullptr; + return absl::UnimplementedError(absl::StrCat( + "Provider not found for platform ", platform, "; canonical expansion: ", + canonical_name, "; supported platforms are: ", + absl::StrJoin(GetProviderMap(), ", ", + [&](std::string* s, const auto& p) { + absl::StrAppend(s, p.first); + }))); } return it->second.get(); } +absl::StatusOr OptProvider::GetExecutor() { + DebugOptions debug_opts = GetDebugOptionsFromFlags(); + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); + if (debug_opts.xla_gpu_target_config_filename().empty()) { + TF_ASSIGN_OR_RETURN(std::vector stream_executors, + PlatformUtil::GetStreamExecutors( + platform, /*allowed_devices=*/std::nullopt)); + return stream_executors[0]; + } + return nullptr; +} + +absl::StatusOr> OptProvider::GenerateStage( + std::unique_ptr module, absl::string_view stage) { + if (stage == "hlo") { + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + return optimized_module->ToString(); + } else if (stage == "html") { + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + TF_ASSIGN_OR_RETURN(std::string cmps, + RenderAllComputationsToHtml(*optimized_module)); + return cmps; + } else if (stage == "hlo-backend") { + TF_ASSIGN_OR_RETURN(auto executable, GetExecutable(std::move(module))); + return executable->module().ToString(); + } + + return std::nullopt; +} + +absl::StatusOr OptProvider::GetCompiler() { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); + + TF_ASSIGN_OR_RETURN(Compiler * compiler, Compiler::GetForPlatform(platform)); + return compiler; +} + +absl::StatusOr> OptProvider::GetOptimizedHlo( + std::unique_ptr input_module) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + + DebugOptions debug_opts = GetDebugOptionsFromFlags(); + Compiler::CompileOptions opts; + TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + DebugOptions d = input_module->config().debug_options(); + d.set_xla_embed_ir_in_executable(true); + input_module->mutable_config().set_debug_options(d); + + if (input_module->has_schedule()) { + return input_module; + } + + // But run-hlo-passes does not actually run the scheduling. + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimized_module, + compiler->RunHloPasses(std::move(input_module), executor, opts)); + + return optimized_module; +} + +absl::StatusOr> OptProvider::GetExecutable( + std::unique_ptr input_module) { + Compiler::CompileOptions opts; + TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(input_module))); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + TF_ASSIGN_OR_RETURN(Compiler * compiler, GetCompiler()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + compiler->RunBackend(std::move(optimized_module), executor, opts)); + return executable; +} + +std::set OptProvider::SupportedStages() { + return {"hlo", "html", "hlo-backend"}; +} + } // namespace xla diff --git a/xla/tools/hlo_opt/opt_lib.h b/xla/tools/hlo_opt/opt_lib.h index 29ae5ba73c96d..bed286a5815b1 100644 --- a/xla/tools/hlo_opt/opt_lib.h +++ b/xla/tools/hlo_opt/opt_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,34 +18,57 @@ limitations under the License. #include #include +#include #include -#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" +#include "xla/service/executable.h" #include "xla/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/types.h" namespace xla { -// Platform-specific provider of `hlo_translate` functionality. -struct OptProvider { +// Platform-specific provider of `hlo-opt` functionality. +class OptProvider { + public: // Generates textual output for a given stage on a given platform, returns // empty optional if the stage is not supported. - virtual StatusOr> GenerateStage( - std::unique_ptr module, absl::string_view stage) = 0; + virtual absl::StatusOr> GenerateStage( + std::unique_ptr module, absl::string_view stage); virtual ~OptProvider() = default; - virtual std::vector SupportedStages() = 0; + // Returns a set of stages supported by the opt provider. + virtual std::set SupportedStages(); + // Registers a given provider for a given platform. static void RegisterForPlatform( - se::Platform::Id platform, - std::unique_ptr translate_provider); + std::string platform, std::unique_ptr translate_provider); - static OptProvider* ProviderForPlatform(se::Platform::Id platform); + // Gets a provider for a given platform. + static absl::StatusOr ProviderForPlatform( + std::string platform); + + protected: + // Returns platform name associated with the provider. + virtual std::string GetPlatformName() = 0; + + // Returns a stream executor for the provider (could be nullptr). + virtual absl::StatusOr GetExecutor(); + + // Generates executable from a given input module. + absl::StatusOr> GetExecutable( + std::unique_ptr input_module); + + // Generates optimized HLO. + absl::StatusOr> GetOptimizedHlo( + std::unique_ptr input_module); + + // Gets a compiler associated with the provider. + virtual absl::StatusOr GetCompiler(); }; } // namespace xla diff --git a/xla/tools/hlo_opt/opt_main.cc b/xla/tools/hlo_opt/opt_main.cc index f402fe866bc30..ef24a0a4da240 100644 --- a/xla/tools/hlo_opt/opt_main.cc +++ b/xla/tools/hlo_opt/opt_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,45 +20,56 @@ limitations under the License. #include #include #include +#include #include #include +#include #include #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tools/hlo_opt/opt_lib.h" #include "xla/tools/run_hlo_module.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" +#include "tsl/platform/statusor.h" namespace { const char* const kUsage = R"( This tool lets you run a given HloModule from a file (or stdin) and convert it to expanded HLO, fully optimized HLO, or a binary depending on options. +HLO passes are always run, unless the HLO module is already scheduled (has +is_scheduled=True). + You can also pass in debug option flags for the HloModule. Usage: - bazel run opt -- --platform=[CUDA|CPU|Interpreter|...] path/to/hlo_module + bazel run opt -- --platform=[gpu|cpu|...] path/to/hlo_module )"; struct HloOptConfig { // Optional flags. bool help{false}; bool split_input_file{false}; - std::string platform{"cuda"}; + std::string platform{"gpu"}; std::string input_file{""}; std::string input_format{""}; std::string output_file{"-"}; @@ -72,6 +83,9 @@ namespace xla { namespace { +// Convention separator as set by mlir-opt tool. +const char* kOptSeparator = "// -----"; + std::string GetHloPath(const HloOptConfig& opts, int argc, char** argv) { if (!opts.input_file.empty()) { return opts.input_file; @@ -80,13 +94,13 @@ std::string GetHloPath(const HloOptConfig& opts, int argc, char** argv) { return argv[1]; } -StatusOr GetHloContents(const HloOptConfig& opts, int argc, - char** argv) { +absl::StatusOr GetHloContents(const HloOptConfig& opts, int argc, + char** argv) { std::string hlo_path = GetHloPath(opts, argc, argv); if (hlo_path == "-") { - std::string stdin; - std::getline(std::cin, stdin, static_cast(EOF)); - return stdin; + std::string input; + std::getline(std::cin, input, static_cast(EOF)); + return input; } std::string data; @@ -95,43 +109,68 @@ StatusOr GetHloContents(const HloOptConfig& opts, int argc, return data; } -StatusOr> GetModule(const HloOptConfig& opts, - int argc, char** argv) { +absl::StatusOr>> GetModules( + const HloOptConfig& opts, int argc, char** argv) { TF_ASSIGN_OR_RETURN(std::string module_data, GetHloContents(opts, argc, argv)); + std::vector hlos; + if (opts.split_input_file) { + hlos = absl::StrSplit(module_data, kOptSeparator); + } else { + hlos.push_back(module_data); + } + std::string format = opts.input_format; if (format.empty()) { format = std::string(tsl::io::Extension(GetHloPath(opts, argc, argv))); } - return LoadModuleFromData(module_data, format); -} -StatusOr TranslateToStage(int argc, char** argv, - const HloOptConfig& opts) { - se::Platform* platform = - xla::PlatformUtil::GetPlatform(opts.platform).value(); - - OptProvider* provider = OptProvider::ProviderForPlatform(platform->id()); - if (provider == nullptr) { - return absl::UnimplementedError( - absl::StrCat("Provider not found for platform: ", platform->Name())); + std::vector> out; + out.reserve(hlos.size()); + + for (const std::string& hlo : hlos) { + if (absl::StrContains(hlo, "// ---")) { + if (opts.split_input_file) { + return absl::InternalError( + "Unexpected separator found, expected exactly '// -----', found " + "'// ---'"); + } else { + return absl::InternalError( + "'// ---' separator found in input, but -split-input-file not " + "specified"); + } + } + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + LoadModuleFromData(hlo, format)); + out.push_back(std::move(module)); } + return out; +} + +absl::StatusOr TranslateToStage(int argc, char** argv, + const HloOptConfig& opts) { + TF_ASSIGN_OR_RETURN(OptProvider * provider, + OptProvider::ProviderForPlatform(opts.platform)); if (opts.list_stages) { return absl::StrJoin(provider->SupportedStages(), "\n"); } - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - GetModule(opts, argc, argv)); - - TF_ASSIGN_OR_RETURN(std::optional out, - provider->GenerateStage(std::move(module), opts.stage)); - - if (!out.has_value()) { - return absl::UnimplementedError("Stage not supported"); + TF_ASSIGN_OR_RETURN(std::vector> modules, + GetModules(opts, argc, argv)); + + std::string out_combined; + + for (std::unique_ptr& m : modules) { + TF_ASSIGN_OR_RETURN(std::optional out, + provider->GenerateStage(std::move(m), opts.stage)); + if (!out.has_value()) { + return absl::UnimplementedError("Stage not supported"); + } + absl::StrAppend(&out_combined, *out, "\n"); } - return *out; + return out_combined; } Status RunOpt(int argc, char** argv, const HloOptConfig& opts) { @@ -169,9 +208,14 @@ int main(int argc, char** argv) { "\t\t\t * hlo : HLO after all optimizations\n" "\t\t\t * llvm : LLVM IR\n" "\t\t\t * ptx : PTX dump\n" - "\t\t\t * buffer-assignment: Buffer Assignment\n"), + "\t\t\t * buffer-assignment: Buffer Assignment\n" + "\t\t\t * hlo-backend: HLO after backend passes\n" + "\t\t\t * html: HTML dump\n"), tsl::Flag("list-stages", &opts.list_stages, - "Print all supported stages for a given platform and exit")}; + "Print all supported stages for a given platform and exit"), + tsl::Flag("split-input-file", &opts.split_input_file, + "Splits the input file in pieces based on '// -----' " + "substring, and processes each chunk independently")}; // Modifies global DebugOptions, populates flags with every flag available // from xla.proto. xla::AppendDebugOptionsFlags(&flag_list); diff --git a/xla/tools/hlo_proto_to_json.cc b/xla/tools/hlo_proto_to_json.cc index 7d41882e90442..cdadd1e3ce1d7 100644 --- a/xla/tools/hlo_proto_to_json.cc +++ b/xla/tools/hlo_proto_to_json.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,19 +30,19 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/statusor.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" using std::string; namespace xla { namespace tools { -StatusOr ToJson(const tsl::protobuf::Message& message) { +absl::StatusOr ToJson(const tsl::protobuf::Message& message) { std::string json_output; tsl::protobuf::util::JsonPrintOptions json_options; json_options.add_whitespace = true; @@ -50,8 +50,8 @@ StatusOr ToJson(const tsl::protobuf::Message& message) { auto status = tsl::protobuf::util::MessageToJsonString(message, &json_output, json_options); if (!status.ok()) { - return InternalError("MessageToJsonString failed: %s", - std::string{status.message()}); + return Internal("MessageToJsonString failed: %s", + std::string{status.message()}); } return json_output; } diff --git a/xla/tools/hlo_slicer.cc b/xla/tools/hlo_slicer.cc index a407222f42d5b..2b1f1faa355c3 100644 --- a/xla/tools/hlo_slicer.cc +++ b/xla/tools/hlo_slicer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/hlo_slicer.h b/xla/tools/hlo_slicer.h index e621bf9ed21b7..d9d5db1ee6f09 100644 --- a/xla/tools/hlo_slicer.h +++ b/xla/tools/hlo_slicer.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/hlo_slicer_test.cc b/xla/tools/hlo_slicer_test.cc index 75bde371d4aad..b117ba6814969 100644 --- a/xla/tools/hlo_slicer_test.cc +++ b/xla/tools/hlo_slicer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/interactive_graphviz.cc b/xla/tools/interactive_graphviz.cc index 0eb18c59f7b52..10d68162d5462 100644 --- a/xla/tools/interactive_graphviz.cc +++ b/xla/tools/interactive_graphviz.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,8 +21,11 @@ limitations under the License. // Generated visualization is opened in a new default browser window using // /usr/bin/sensible-browser. -#include +#ifndef _WIN32 #include +#endif + +#include #include #include @@ -43,12 +46,12 @@ limitations under the License. #include "xla/service/local_service.h" #include "xla/service/platform_util.h" #include "xla/tools/hlo_extractor.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/subprocess.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/util/command_line_flags.h" #if defined(PLATFORM_GOOGLE) #include "util/readline/readline.h" #endif @@ -456,8 +459,9 @@ void OpenUrl(const Options& opts, absl::string_view url) { // URL format doesn't work out of the box; it requires you to register a plugin. void RenderAndDisplayGraph( const Options& opts, - const std::function(RenderedGraphFormat)>& renderer) { - StatusOr url_result = renderer(RenderedGraphFormat::kUrl); + const std::function(RenderedGraphFormat)>& + renderer) { + absl::StatusOr url_result = renderer(RenderedGraphFormat::kUrl); if (url_result.ok()) { std::string url = url_result.value(); OpenUrl(opts, url); @@ -473,7 +477,8 @@ void RenderAndDisplayGraph( } auto* env = tsl::Env::Default(); - StatusOr html_result = renderer(RenderedGraphFormat::kHtml); + absl::StatusOr html_result = + renderer(RenderedGraphFormat::kHtml); if (!html_result.ok()) { std::cerr << "Failed to render graph as HTML: " << html_result.status() << std::endl; diff --git a/xla/tools/interactive_graphviz_bin_test.cc b/xla/tools/interactive_graphviz_bin_test.cc index 262879bd6c0b8..8da5f014da335 100644 --- a/xla/tools/interactive_graphviz_bin_test.cc +++ b/xla/tools/interactive_graphviz_bin_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index 0f4cfc85469d3..5b919442c8c20 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -1,10 +1,10 @@ -load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("//xla:xla.bzl", "xla_cc_binary") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@tsl//tsl:tsl.bzl", "if_cuda_or_rocm") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_binary") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -15,6 +15,7 @@ package( build_test( name = "hlo_runner_main_build_test", tags = [ + "cpu", "gpu", ], targets = [ @@ -38,13 +39,14 @@ xla_cc_binary( "//xla:status", "//xla:statusor", "//xla/pjrt:pjrt_client", + "//xla/service:cpu_plugin", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - "@tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -70,6 +72,7 @@ cc_library( "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/pjrt/cpu:cpu_client", "//xla/pjrt/distributed:client", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/service:hlo_parser", diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 40efc7f613d7a..b424df91ec935 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" @@ -53,13 +54,13 @@ namespace xla { namespace { // Creates an HloModule from the given proto. -StatusOr> HloTextToModule( +absl::StatusOr> HloTextToModule( absl::string_view hlo_text) { return ParseAndReturnUnverifiedModule(hlo_text); } // Creates an HloModule from the given proto. -StatusOr> HloProtoToModule( +absl::StatusOr> HloProtoToModule( const HloModuleProto& proto) { TF_ASSIGN_OR_RETURN( HloModuleConfig config, @@ -76,7 +77,8 @@ void PopulateWithSameValue(Literal* literal, ElementType val) { } } -StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, int value) { +absl::StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, + int value) { if (!shape.IsArray()) { return InvalidArgument( "MakeFakeLiteralWithSameValue does not support non-array type"); @@ -84,7 +86,7 @@ StatusOr MakeFakeLiteralWithSameValue(const Shape& shape, int value) { Shape new_shape = shape; new_shape.mutable_layout()->clear_tiles(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto type) -> StatusOr { + [&](auto type) -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(type)) { using NativeT = primitive_util::NativeTypeOf; @@ -243,19 +245,26 @@ void AddShardingAnnotationsToSpmdPartitionedModule(HloModule* hlo_module) { set_manual_sharding(entry_root); } -StatusOr> FunctionalHloRunner::CreateGpuClient() { +absl::StatusOr> +FunctionalHloRunner::CreateHostClient() { + return GetTfrtCpuClient(CpuClientOptions()); +} + +absl::StatusOr> +FunctionalHloRunner::CreateGpuClient() { return GetStreamExecutorGpuClient(GpuClientOptions()); } -StatusOr> FunctionalHloRunner::CreateMockGpuClient( - int num_nodes) { +absl::StatusOr> +FunctionalHloRunner::CreateMockGpuClient(int num_nodes) { GpuClientOptions options; options.num_nodes = num_nodes; options.enable_mock_nccl = true; return GetStreamExecutorGpuClient(options); } -StatusOr> FunctionalHloRunner::CreateGpuClient( +absl::StatusOr> +FunctionalHloRunner::CreateGpuClient( std::shared_ptr distributed_client, int node_id, int num_nodes) { if (node_id < 0 || node_id >= num_nodes) { @@ -265,32 +274,15 @@ StatusOr> FunctionalHloRunner::CreateGpuClient( TF_RET_CHECK(distributed_client != nullptr); - // Use the plugin name as key prefix. - static constexpr absl::string_view kKeyPrefix = "gpu:"; - - xla::PjRtClient::KeyValueGetCallback kv_get = - [distributed_client]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet(absl::StrCat(kKeyPrefix, k), - timeout); - }; - - xla::PjRtClient::KeyValuePutCallback kv_put = - [distributed_client](std::string_view k, - std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(kKeyPrefix, k), v); - }; - GpuClientOptions options; options.node_id = node_id; options.num_nodes = num_nodes; - options.kv_get = kv_get; - options.kv_put = kv_put; + options.kv_store = + GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:"); return GetStreamExecutorGpuClient(options); } -StatusOr FunctionalHloRunner::LoadExecutionOptions( +absl::StatusOr FunctionalHloRunner::LoadExecutionOptions( absl::string_view path) { ExecutionOptions execution_options; TF_RETURN_IF_ERROR(tsl::ReadTextOrBinaryProto( @@ -298,7 +290,7 @@ StatusOr FunctionalHloRunner::LoadExecutionOptions( return execution_options; } -StatusOr FunctionalHloRunner::CreateCompileOptions( +absl::StatusOr FunctionalHloRunner::CreateCompileOptions( const PjRtClient& client, const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id) { CompileOptions compile_options; @@ -414,10 +406,12 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions( build_options.set_use_auto_spmd_partitioning( execution_options.use_auto_spmd_partitioning()); build_options.set_deduplicate_hlo(execution_options.deduplicate_hlo()); + build_options.set_allow_spmd_sharding_propagation_to_parameters( + execution_options.allow_spmd_sharding_propagation_to_parameters()); build_options.set_allow_spmd_sharding_propagation_to_output( execution_options.allow_spmd_sharding_propagation_to_output()); if (execution_options.has_device_assignment()) { - StatusOr> device_assignment = + absl::StatusOr> device_assignment = DeviceAssignment::Deserialize(execution_options.device_assignment()); TF_CHECK_OK(device_assignment.status()); build_options.set_device_assignment(**device_assignment); @@ -462,7 +456,7 @@ absl::Span FunctionalHloRunner::GetLocalDevices( return client.addressable_devices(); } -StatusOr +absl::StatusOr FunctionalHloRunner::LoadHloModuleAndArguments(absl::string_view hlo_file, InputFormat input_format) { HloModuleAndArguments hlo_module_and_arguments; @@ -511,7 +505,7 @@ Status FunctionalHloRunner::LoadAndRunAndDump( : FunctionalHloRunner::DumpOutput(output, dump_output_to, task_id); } -StatusOr +absl::StatusOr FunctionalHloRunner::LoadAndRun(PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -546,29 +540,6 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client, hlo_module_and_arguments.hlo_module.get(), loaded_arguments); } -StatusOr -FunctionalHloRunner::LoadAndRun( - PjRtClient& client, const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options, - const RunningOptions& running_options, - absl::Span hlo_files, InputFormat input_format, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& per_device_index_vec) { - CHECK(!hlo_files.empty()); - // We only support SPMD as of now, i.e., all devices are supposed - // to execute the same HLO module. - // TODO(tdanyluk): Consider revising this API which takes multiple HLOs, but - // uses only one. - HloModuleAndArguments hlo_module_and_arguments; - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments, - LoadHloModuleAndArguments(hlo_files[0], input_format)); - return CompileAndRun(client, debug_options, preproc_options, compile_options, - running_options, - hlo_module_and_arguments.hlo_module.get(), - argument_literals, per_device_index_vec); -} - Status FunctionalHloRunner::LoadAndCompile( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -602,7 +573,7 @@ Status FunctionalHloRunner::LoadAndCompile( return OkStatus(); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) { std::string hlo_string; TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), @@ -610,7 +581,7 @@ FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) { return ParseAndReturnUnverifiedModule(hlo_string); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) { HloProto proto; TF_RETURN_IF_ERROR( @@ -618,7 +589,7 @@ FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) { return HloProtoToModule(proto.hlo_module()); } -StatusOr> +absl::StatusOr> FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) { HloProto proto; TF_RETURN_IF_ERROR( @@ -626,7 +597,7 @@ FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) { return HloProtoToModule(proto.hlo_module()); } -StatusOr +absl::StatusOr FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile( absl::string_view hlo_file) { HloSnapshot proto; @@ -643,17 +614,17 @@ FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile( return hlo_module_and_arguments; } -StatusOr> FunctionalHloRunner::ReadModuleFromString( - absl::string_view hlo_text) { +absl::StatusOr> +FunctionalHloRunner::ReadModuleFromString(absl::string_view hlo_text) { return HloTextToModule(hlo_text); } -StatusOr> FunctionalHloRunner::ReadModuleFromProto( - const HloModuleProto& proto) { +absl::StatusOr> +FunctionalHloRunner::ReadModuleFromProto(const HloModuleProto& proto) { return HloProtoToModule(proto); } -StatusOr +absl::StatusOr FunctionalHloRunner::CompileAndRun(PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -668,21 +639,6 @@ FunctionalHloRunner::CompileAndRun(PjRtClient& client, return Run(client, executable.get(), arguments, running_options); } -StatusOr -FunctionalHloRunner::CompileAndRun( - PjRtClient& client, const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options, - const RunningOptions& running_options, HloModule* hlo_module, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& argument_indices) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - Compile(client, hlo_module, debug_options, - preproc_options, compile_options)); - return Run(client, executable.get(), argument_literals, argument_indices, - running_options); -} - namespace { // Argument buffers are created on device at the first time an HLO module @@ -796,11 +752,11 @@ CompileOptions FunctionalHloRunner::CompleteCompileOptions( return compile_options; } -StatusOr> FunctionalHloRunner::Compile( - PjRtClient& client, HloModule* hlo_module, - const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options) { +absl::StatusOr> +FunctionalHloRunner::Compile(PjRtClient& client, HloModule* hlo_module, + const DebugOptions& debug_options, + const PreprocessingOptions& preproc_options, + const CompileOptions& compile_options) { TF_RETURN_IF_ERROR(PrepareHloModuleForCompilation(hlo_module, debug_options, preproc_options)); CompileOptions modified_compile_options = @@ -813,7 +769,7 @@ StatusOr> FunctionalHloRunner::Compile( return executable; } -StatusOr> FunctionalHloRunner::Compile( +absl::StatusOr> FunctionalHloRunner::Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -835,11 +791,11 @@ StatusOr> FunctionalHloRunner::Compile( // Runs the executable and may repeat for multiple times. // Since the input buffers may be donated by the PjrtClient, we re-create the // input PjrtBuffers for each repetition. -StatusOr FunctionalHloRunner::Run( - PjRtClient& client, PjRtLoadedExecutable* executable, +absl::StatusOr +FunctionalHloRunner::Run(PjRtClient& client, PjRtLoadedExecutable* executable, - const PerDeviceLiteralVecType& arguments, - const RunningOptions& running_options) { + const PerDeviceLiteralVecType& arguments, + const RunningOptions& running_options) { auto create_argument_buffers_on_device = [&client, &executable, &arguments, &running_options]( bool flatten_tupled_arguments) { @@ -872,66 +828,6 @@ StatusOr FunctionalHloRunner::Run( running_options); } -// Runs the executable and may repeat for multiple times. -// Since the input buffers may be donated by the PjrtClient, we re-create the -// input PjrtBuffers for each repetition. -StatusOr FunctionalHloRunner::Run( - - PjRtClient& client, PjRtLoadedExecutable* executable, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& argument_indices, - const RunningOptions& running_options) { - auto create_argument_buffers_on_device = [&client, &executable, - &argument_literals, - &argument_indices, - &running_options]( - bool flatten_arguments) { - CHECK_GE(argument_literals.size(), 1); - bool arguments_can_be_flattened = absl::c_all_of( - argument_literals, - [](const Literal& literal) { return literal.shape().IsTuple(); }); - arguments_can_be_flattened &= absl::c_all_of( - argument_indices, [](PerDeviceIndexVecType::const_reference - device_id_and_argument_indices) { - return device_id_and_argument_indices.second.size() == 1; - }); - if (flatten_arguments && arguments_can_be_flattened) { - int tuple_shape_size = - argument_literals.front().shape().tuple_shapes_size(); - LiteralVec flattened_argument_literals; - for (const Literal& tupled_argument : argument_literals) { - LiteralVec flattened_arguments = - tupled_argument.Clone().DecomposeTuple(); - for (Literal& flattened_argument : flattened_arguments) { - flattened_argument_literals.push_back(std::move(flattened_argument)); - } - } - PerDeviceIndexVecType flattened_per_device_index_vec; - for (const auto& device_id_and_argument_indices : argument_indices) { - std::vector flattened_argument_indices(tuple_shape_size); - int tupled_argument_index = - device_id_and_argument_indices.second.front(); - for (int i = 0; i < tuple_shape_size; i++) { - flattened_argument_indices[i] = - tupled_argument_index * tuple_shape_size + i; - } - int device_id = device_id_and_argument_indices.first; - flattened_per_device_index_vec.insert( - {device_id, std::move(flattened_argument_indices)}); - } - return CopyArgumentsToDevice(client, executable->addressable_devices(), - flattened_argument_literals, - flattened_per_device_index_vec, - running_options.log_input_output()); - } - return CopyArgumentsToDevice(client, executable->addressable_devices(), - argument_literals, argument_indices, - running_options.log_input_output()); - }; - return RunInternal(client, executable, create_argument_buffers_on_device, - running_options); -} - namespace { std::vector> CreateArgumentPointersFromDeviceBuffers( @@ -1015,11 +911,11 @@ Status EnsureSingleTupleForFlattening(const HloModule& module) { } // namespace -StatusOr +absl::StatusOr FunctionalHloRunner::RunInternal( PjRtClient& client, PjRtLoadedExecutable* executable, - std::function< - StatusOr>>>(bool)> + std::function>>>(bool)> create_argument_buffers_on_device, const RunningOptions& running_options) { ExecuteOptions execute_options; @@ -1139,7 +1035,7 @@ FunctionalHloRunner::RunInternal( return results; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CreateArgumentsOnDevice( PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments) { @@ -1245,7 +1141,7 @@ FunctionalHloRunner::CreateArgumentsOnDevice( running_options.log_input_output()); } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CreateUninitializedArgumentsOnDevice( PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments) { @@ -1329,7 +1225,7 @@ FunctionalHloRunner::CreateUninitializedArgumentsOnDevice( return argument_buffers_per_device; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CopyArgumentsToDevice( PjRtClient& client, absl::Span addressable_devices, const PerDeviceLiteralVecType& arguments, bool log_input) { @@ -1374,7 +1270,7 @@ FunctionalHloRunner::CopyArgumentsToDevice( return argument_buffers; } -StatusOr>>> +absl::StatusOr>>> FunctionalHloRunner::CopyArgumentsToDevice( PjRtClient& client, absl::Span addressable_devices, const LiteralVec& argument_literals, @@ -1422,7 +1318,7 @@ FunctionalHloRunner::CopyArgumentsToDevice( return argument_buffers; } -StatusOr +absl::StatusOr FunctionalHloRunner::FetchAndLogOutput( PjRtClient& client, const std::vector>>& output_buffers, @@ -1464,7 +1360,7 @@ FunctionalHloRunner::FetchAndLogOutput( "same device"; output_slice.emplace_back( ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); - buffer->ToLiteral(&output_slice.back(), [&](Status s) { + buffer->ToLiteral(&output_slice.back()).OnReady([&](Status s) { absl::MutexLock lock(&mu); --num_pending_transfers; status.Update(s); diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 93d22bf94233b..a49e7d8050932 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -207,30 +207,33 @@ class FunctionalHloRunner { int partitions = 1; }; + // Create a PjRtClient which can run HLOs on Host CPU. + static absl::StatusOr> CreateHostClient(); + // Create a PjRtClient which can run HLOs on GPU. - static StatusOr> CreateGpuClient(); + static absl::StatusOr> CreateGpuClient(); // Create a PjRtClient which mocks multi-hosts GPU run - static StatusOr> CreateMockGpuClient( + static absl::StatusOr> CreateMockGpuClient( int num_nodes = 1); // Create a PjRtClient which can run HLOs on GPUs distributed across several // nodes. // The distributed client pointer passed as a parameter is expected to be // non-null, and 0 <= node_id < num_nodes must hold. - static StatusOr> CreateGpuClient( + static absl::StatusOr> CreateGpuClient( std::shared_ptr distributed_client, int node_id, int num_nodes); // Loads an ExecutionOptions proto (which can be used in RawCompileOptions). - static StatusOr LoadExecutionOptions( + static absl::StatusOr LoadExecutionOptions( absl::string_view path); // Creates the compilation options. // // If RawCompileOptions::num_slices is set, the // CompileOptions::device_assignment has to be set manually. - static StatusOr CreateCompileOptions( + static absl::StatusOr CreateCompileOptions( const PjRtClient& client, const FunctionalHloRunner::RawCompileOptions& raw_options, int task_id = 0); @@ -251,7 +254,7 @@ class FunctionalHloRunner { // not empty. Otherwise, use arguments from the HLO file or fake arguments. // The hlo file might be a HLO snapshot and thus contain arguments, otherwise // it is run with fake arguments. - static StatusOr LoadAndRun( + static absl::StatusOr LoadAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, @@ -259,21 +262,6 @@ class FunctionalHloRunner { absl::Span hlo_files, InputFormat input_format, const PerDeviceLiteralVecType& arguments = {}); - // Loads an HLO module from hlo_file according to input_format and run it. - // The module arguments are provided by `argument_literals`. The arguments per - // device is defined by the `per_device_index_vec`, which should contain a - // vector of indices for each local device. This means different device may - // use the same argument literals. This is essential to run HLO modules with - // large arguments (e.g., models with large weights). - static StatusOr LoadAndRun( - PjRtClient& client, const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options, - const RunningOptions& running_options, - absl::Span hlo_files, InputFormat input_format, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& per_device_index_vec); - // Loads and compiles an HLO for debugging purposes. // // This function allows compiling multi-device HLOs on machines with fewer @@ -288,29 +276,15 @@ class FunctionalHloRunner { // Compiles and runs the given HLO module with the given arguments for each // device. The given arguments is a map from device ID to a list of arguments. // If the arguments map is empty, the HLO module is run with fake arguments. - static StatusOr CompileAndRun( + static absl::StatusOr CompileAndRun( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, const RunningOptions& running_options, HloModule* hlo_module, const PerDeviceLiteralVecType& arguments = {}); - // Compiles and runs the given HLO module with the given arguments for each - // device. The module arguments are provided by `argument_literals`. The - // arguments per device is defined by the `per_device_index_vec`, which should - // contain a vector of indices for each local device. This means different - // devices may use the same argument literals. This is essential to run HLO - // modules with large arguments (e.g., models with large weights). - static StatusOr CompileAndRun( - PjRtClient& client, const DebugOptions& debug_options, - const PreprocessingOptions& preproc_options, - const CompileOptions& compile_options, - const RunningOptions& running_options, HloModule* hlo_module, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& argument_indices); - // Compiles the HLO module. - static StatusOr> Compile( + static absl::StatusOr> Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -319,7 +293,7 @@ class FunctionalHloRunner { // Ahead-of-time compilation using the PjRtTopologyDescription that's passed // instead of using the registered topology. This enables reproduction of // compilation based on captured information. - static StatusOr> Compile( + static absl::StatusOr> Compile( PjRtClient& client, HloModule* hlo_module, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, @@ -327,35 +301,27 @@ class FunctionalHloRunner { const PjRtTopologyDescription& topology); // Runs the executable. - static StatusOr Run( + static absl::StatusOr Run( PjRtClient& client, PjRtLoadedExecutable* executable, const PerDeviceLiteralVecType& arguments, const RunningOptions& running_options); - // Runs the executable, where the module arguments are provided through - // a shared literal vector and per-device indices. - static StatusOr Run( - PjRtClient& client, PjRtLoadedExecutable* executable, - const LiteralVec& argument_literals, - const PerDeviceIndexVecType& argument_indices, - const RunningOptions& running_options); - - static StatusOr> ReadModuleFromHloTextFile( + static absl::StatusOr> ReadModuleFromHloTextFile( absl::string_view hlo_file); - static StatusOr> ReadModuleFromBinaryProtoFile( - absl::string_view hlo_file); - static StatusOr> ReadModuleFromTextProtoFile( + static absl::StatusOr> + ReadModuleFromBinaryProtoFile(absl::string_view hlo_file); + static absl::StatusOr> ReadModuleFromTextProtoFile( absl::string_view hlo_file); - static StatusOr ReadModuleFromSnapshotBinaryProtoFile( - absl::string_view hlo_file); - static StatusOr LoadHloModuleAndArguments( + static absl::StatusOr + ReadModuleFromSnapshotBinaryProtoFile(absl::string_view hlo_file); + static absl::StatusOr LoadHloModuleAndArguments( absl::string_view hlo_file, InputFormat input_format); - static StatusOr> ReadModuleFromString( + static absl::StatusOr> ReadModuleFromString( absl::string_view hlo_text); - static StatusOr> ReadModuleFromProto( + static absl::StatusOr> ReadModuleFromProto( const HloModuleProto& proto); // This would ideally be private, but we need it for the implementation of @@ -394,14 +360,14 @@ class FunctionalHloRunner { const PjRtClient& client); // Creates fake arguments to run the given executable. - static StatusOr>>> + static absl::StatusOr>>> CreateArgumentsOnDevice(PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, bool flatten_arguments = false); // Creates uninitialized arguments to run the given executable. - static StatusOr>>> + static absl::StatusOr>>> CreateUninitializedArgumentsOnDevice(PjRtClient& client, const PjRtLoadedExecutable* executable, const RunningOptions& running_options, @@ -409,27 +375,27 @@ class FunctionalHloRunner { // Creates argument buffers based on the given arguments map. Note that the // arguments might be invalid when arguments are destructed. - static StatusOr>>> + static absl::StatusOr>>> CopyArgumentsToDevice(PjRtClient& client, absl::Span addressable_devices, const PerDeviceLiteralVecType& arguments, bool log_input); - static StatusOr>>> + static absl::StatusOr>>> CopyArgumentsToDevice(PjRtClient& client, absl::Span addressable_devices, const LiteralVec& argument_literals, const PerDeviceIndexVecType& argument_indices, bool log_input); - static StatusOr RunInternal( + static absl::StatusOr RunInternal( PjRtClient& client, PjRtLoadedExecutable* executable, - std::function< - StatusOr>>>(bool)> + std::function>>>(bool)> create_argument_buffers_on_device, const RunningOptions& running_options); - static StatusOr FetchAndLogOutput( + static absl::StatusOr FetchAndLogOutput( PjRtClient& client, const std::vector>>& output_buffers, diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 85db7711d0431..662bbdb7e46dc 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -211,7 +211,7 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) { std::string after_opt_hlo; TF_ASSERT_OK( tsl::ReadFileToString(env, after_opt_hlo_paths[0], &after_opt_hlo)); - StatusOr file_check_result = RunFileCheck(after_opt_hlo, R"( + absl::StatusOr file_check_result = RunFileCheck(after_opt_hlo, R"( // CHECK: param = f32[16,1]{1,0} // CHECK: add = f32[16,1]{1,0} )"); diff --git a/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc b/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc index 986d4914d7f68..600fffaf023b8 100644 --- a/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc +++ b/xla/tools/multihost_hlo_runner/hlo_runner_flags.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -129,7 +129,7 @@ bool MultiHostHloRunnerFlags::CreateOptionsFromFlags( ? FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning : FunctionalHloRunner::SpmdMode::kNotUseSpmdPartitioning; if (!flag_values_.execution_options_path.empty()) { - StatusOr execution_options = + absl::StatusOr execution_options = FunctionalHloRunner::LoadExecutionOptions( flag_values_.execution_options_path); if (!execution_options.ok()) { diff --git a/xla/tools/multihost_hlo_runner/hlo_runner_flags.h b/xla/tools/multihost_hlo_runner/hlo_runner_flags.h index 8f183e28ab1bd..6228555c9da68 100644 --- a/xla/tools/multihost_hlo_runner/hlo_runner_flags.h +++ b/xla/tools/multihost_hlo_runner/hlo_runner_flags.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index 18515ad05918f..b0666ff491fcb 100644 --- a/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,11 +28,11 @@ limitations under the License. #include "xla/statusor.h" #include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h" #include "xla/tools/multihost_hlo_runner/hlo_runner_flags.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( @@ -96,7 +96,7 @@ int main(int argc, char** argv) { "A path to which the HLO output will be dumped. " "Example: /a/b/literal.txt."), tsl::Flag("task_id", &task_id, "Borg task id."), - tsl::Flag("device_type", &device_type_str, "Device type: gpu"), + tsl::Flag("device_type", &device_type_str, "Device type: gpu, host"), tsl::Flag("num_nodes", &num_nodes, "Number of nodes (hosts)"), tsl::Flag( "enable_mock_nccl", &enable_mock_nccl, @@ -116,7 +116,8 @@ int main(int argc, char** argv) { bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list); parse_ok = parse_ok && xla::AbslParseFlag(input_format_str, &input_format, &parse_error); - parse_ok = parse_ok && device_type_str == "gpu"; + parse_ok = + parse_ok && (device_type_str == "gpu" || device_type_str == "host"); parse_ok = parse_ok && hlo_runner_flags.CreateOptionsFromFlags( &preproc_options, &raw_compile_options, &running_options, &parse_error); @@ -131,14 +132,23 @@ int main(int argc, char** argv) { } // The main logic: - xla::StatusOr> client; - if (enable_mock_nccl) { - CHECK_GT(num_nodes, 1); - client = xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes); - } else { - CHECK_EQ(num_nodes, 1); - client = xla::FunctionalHloRunner::CreateGpuClient(); - } + absl::StatusOr> client = [&] { + if (device_type_str == "host") { + CHECK_EQ(num_nodes, 1); + return xla::FunctionalHloRunner::CreateHostClient(); + } + + CHECK_EQ(device_type_str, "gpu"); + + if (enable_mock_nccl) { + CHECK_GT(num_nodes, 1); + return xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes); + } else { + CHECK_EQ(num_nodes, 1); + return xla::FunctionalHloRunner::CreateGpuClient(); + } + }(); + TF_QCHECK_OK(client.status()); if (should_run) { diff --git a/xla/tools/prepare_reference_module.cc b/xla/tools/prepare_reference_module.cc index 572aa2dcf9235..fc904d9fea1c6 100644 --- a/xla/tools/prepare_reference_module.cc +++ b/xla/tools/prepare_reference_module.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ limitations under the License. namespace xla { -StatusOr> PrepareReferenceModule( +absl::StatusOr> PrepareReferenceModule( const HloModule& test_module, HloRunnerInterface* test_runner, const std::function& config_modifier_hook, const std::function> PrepareReferenceModule( +absl::StatusOr> PrepareReferenceModule( const HloModule& test_module, HloRunnerInterface* test_runner, const std::function& config_modifier_hook = {}, const std::function +#include #include +#include #include +#include #include +#include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_comparison.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" +#include "xla/status.h" #include "xla/tests/test_utils.h" #include "xla/tools/hlo_control_flow_flattening.h" +#include "xla/tools/hlo_decomposer.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tools/prepare_reference_module.h" #include "xla/tools/run_hlo_module.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { +enum class ModuleResult { + kMatched, + kRan, + kSkipped, + kDidntRun, + kOtherError, + kCompilationError, + kRuntimeError, + kMismatch, +}; + +constexpr absl::string_view ModuleResultToString(ModuleResult result) { + switch (result) { + case ModuleResult::kMatched: + return "MATCHED"; + case ModuleResult::kRan: + return "RAN"; + case ModuleResult::kSkipped: + return "SKIPPED"; + case ModuleResult::kDidntRun: + return "DIDN'T RUN"; + case ModuleResult::kOtherError: + return "OTHER ERROR"; + case ModuleResult::kCompilationError: + return "COMPILATION ERROR"; + case ModuleResult::kRuntimeError: + return "RUNTIME ERROR"; + case ModuleResult::kMismatch: + return "MISMATCH"; + } +} // Writes the given literal to a file in the test temporary directory. void WriteLiteralToTempFile(const LiteralSlice& literal, @@ -85,7 +133,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, WriteLiteralToTempFile(mismatches, "mismatches"); } -StatusOr ExecuteWithRunner( +absl::StatusOr ExecuteWithRunner( std::unique_ptr module, const BufferAssignmentProto* buffer_assignment_proto, absl::Span args, HloRunnerInterface* runner, @@ -118,9 +166,8 @@ StatusOr ExecuteWithRunner( return std::move(result_status).value(); } -} // namespace -Status RunAndCompare( +Status RunAndCompareInternal( std::unique_ptr test_module, const BufferAssignmentProto* buffer_assignment_proto, HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, @@ -128,7 +175,16 @@ Status RunAndCompare( xla::RunHloModuleIterationLiterals* iteration_literals_proto, std::function reference_module_modifier_hook, - std::function config_modifier_hook) { + std::function config_modifier_hook, + ModuleResult* test_run_result, ModuleResult* reference_run_result) { + auto copy_result_on_failure = [](auto status, ModuleResult result, + ModuleResult* out_result) { + if (!status.ok() && out_result != nullptr) { + *out_result = result; + } + return status; + }; + if (!config_modifier_hook) { config_modifier_hook = [](HloModuleConfig* config) { config->set_seed(42); @@ -138,19 +194,27 @@ Status RunAndCompare( if (options.flatten_control_flow) { HloControlFlowFlattening control_flow_flattening( HloControlFlowFlattening::Options{/*while_execution_count=*/1}); - TF_RETURN_IF_ERROR(control_flow_flattening.Run(test_module.get()).status()); + TF_RETURN_IF_ERROR( + copy_result_on_failure(control_flow_flattening.Run(test_module.get()), + ModuleResult::kCompilationError, test_run_result) + .status()); } const HloModuleProto test_module_proto = test_module->ToProto(); - TF_ASSIGN_OR_RETURN(auto args, - MakeFakeArguments(test_module.get(), engine, - options.use_large_float_range, - options.treat_gte_as_data_formatting)); + TF_ASSIGN_OR_RETURN( + auto args, copy_result_on_failure( + MakeFakeArguments(test_module.get(), engine, + options.use_large_float_range, + options.treat_gte_as_data_formatting), + ModuleResult::kOtherError, test_run_result)); // Use provided input literals as arguments, if any. if (iteration_literals_proto != nullptr && iteration_literals_proto->arguments_size() != 0) { if (iteration_literals_proto->arguments_size() != args.size()) { + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kOtherError; + } return xla::InvalidArgument( "Failed to use input literals as arguments; mismatched " "number of expected arguments."); @@ -160,14 +224,19 @@ Status RunAndCompare( xla::Shape(args[i].shape()), xla::Shape(iteration_literals_proto->arguments(i).shape())) .ok()) { + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kOtherError; + } return xla::InvalidArgument( "Failed to use input literals for argument %d " "because of a shape mismatch.", i); } - TF_ASSIGN_OR_RETURN(args[i], - xla::Literal::CreateFromProto( - iteration_literals_proto->arguments(i))); + TF_ASSIGN_OR_RETURN( + args[i], + copy_result_on_failure(xla::Literal::CreateFromProto( + iteration_literals_proto->arguments(i)), + ModuleResult::kOtherError, test_run_result)); } } } @@ -190,14 +259,22 @@ Status RunAndCompare( // properly match the test runner's numerics. TF_ASSIGN_OR_RETURN( reference_module, - PrepareReferenceModule(*test_module, test_runner, config_modifier_hook, - reference_module_modifier_hook)); + copy_result_on_failure( + PrepareReferenceModule(*test_module, test_runner, + config_modifier_hook, + reference_module_modifier_hook), + ModuleResult::kCompilationError, reference_run_result)); } TF_ASSIGN_OR_RETURN( auto test_result, - ExecuteWithRunner(std::move(test_module), buffer_assignment_proto, args, - test_runner, options.run_test_hlo_passes)); + copy_result_on_failure( + ExecuteWithRunner(std::move(test_module), buffer_assignment_proto, + args, test_runner, options.run_test_hlo_passes), + ModuleResult::kRuntimeError, test_run_result)); + if (test_run_result != nullptr) { + *test_run_result = ModuleResult::kRan; + } if (options.print_literals) { std::cout << "\n** Result with test runner " << test_runner->Name() << " **\n" @@ -209,15 +286,31 @@ Status RunAndCompare( } if (reference_module == nullptr) { - std::cerr << "Skipping reference runner\n"; + std::cerr << "Skipping reference runner"; + return OkStatus(); + } + if (const HloInstruction* root_instruction = + reference_module->entry_computation()->root_instruction(); + root_instruction->opcode() == HloOpcode::kCustomCall) { + // TODO(b/323849999) Use original computation for the reference platform. + std::cerr << "Skipping reference runner for a custom call " + << root_instruction->custom_call_target() << "\n"; + if (reference_run_result != nullptr) { + *reference_run_result = ModuleResult::kSkipped; + } return OkStatus(); } TF_ASSIGN_OR_RETURN( auto reference_result, - ExecuteWithRunner(std::move(reference_module), - /*buffer_assignment_proto=*/nullptr, args, - reference_runner, options.run_reference_hlo_passes)); + copy_result_on_failure( + ExecuteWithRunner(std::move(reference_module), + /*buffer_assignment_proto=*/nullptr, args, + reference_runner, options.run_reference_hlo_passes), + ModuleResult::kRuntimeError, reference_run_result)); + if (reference_run_result != nullptr) { + *reference_run_result = ModuleResult::kRan; + } if (options.print_literals) { std::cout << "\n** Result with reference runner " @@ -231,10 +324,164 @@ Status RunAndCompare( } ErrorSpec error_spec(static_cast(options.abs_error_bound), static_cast(options.rel_error_bound)); - return literal_comparison::Near(/*expected=*/reference_result, - /*actual=*/test_result, - /*error=*/error_spec, - /*detailed_message=*/true, &OnMiscompare); + + Status comparison_status = + literal_comparison::Near(/*expected=*/reference_result, + /*actual=*/test_result, + /*error=*/error_spec, + /*detailed_message=*/true, &OnMiscompare); + const ModuleResult comparison_result = + comparison_status.ok() ? ModuleResult::kMatched : ModuleResult::kMismatch; + if (test_run_result != nullptr) { + *test_run_result = comparison_result; + } + if (reference_run_result != nullptr) { + *reference_run_result = comparison_result; + } + return comparison_status; +} + +struct ChunkResult { + std::string module_name; + ModuleResult test_result = ModuleResult::kDidntRun; + ModuleResult reference_result = ModuleResult::kDidntRun; + Status status; + + bool operator<(const ChunkResult& other) const { + if (test_result != other.test_result) { + return test_result < other.test_result; + } + return reference_result < other.reference_result; + } +}; + +std::string BuildResultsTable(absl::Span chunk_results, + size_t num_modules) { + constexpr int kStatusWidth = 21; + constexpr int kNameWidth = 30; + constexpr int kThreeColumnsWidth = 5 + 2 * kStatusWidth + kNameWidth; + constexpr int kTableWidth = kThreeColumnsWidth + 30; + + std::ostringstream strstr; + auto print_row = [&](absl::string_view reference, absl::string_view test, + absl::string_view module_name, absl::string_view error) { + std::string formatted_error = absl::StrReplaceAll( + error, {{"\n", absl::StrCat("\n", std::string(kThreeColumnsWidth, ' '), + "|")}}); + strstr << " " << std::left << std::setw(kStatusWidth) << reference << "| " + << std::setw(kStatusWidth) << test << "| " << std::setw(kNameWidth) + << module_name << "| " << formatted_error << "\n"; + }; + auto print_line = [&](int line_width) { + strstr << std::string(line_width, '-') << "\n"; + }; + + print_row("Reference", "Test", "Module", "Status"); + print_line(kTableWidth); + + std::map, int> result_counts; + + for (const ChunkResult& chunk_result : chunk_results) { + const std::pair result_pair( + chunk_result.reference_result, chunk_result.test_result); + + ++result_counts[result_pair]; + print_row(ModuleResultToString(chunk_result.reference_result), + ModuleResultToString(chunk_result.test_result), + chunk_result.module_name, chunk_result.status.ToString()); + } + print_line(kTableWidth); + print_row("Reference", "Test", "Module", "Status"); + print_line(kTableWidth); + + strstr << "\n\n"; + + // Summary table. + print_line(kThreeColumnsWidth); + print_row("Reference", "Test", "Total count", ""); + print_line(kThreeColumnsWidth); + for (const auto& [result, count] : result_counts) { + print_row(ModuleResultToString(result.first), + ModuleResultToString(result.second), absl::StrCat(count), ""); + } + print_line(kThreeColumnsWidth); + if (chunk_results.size() < num_modules) { + strstr << "\n(did not " << (num_modules - chunk_results.size()) + << " modules due to earlier failures)\n\n"; + } + return strstr.str(); +} + +Status RunIsolatedAndCompare( + std::unique_ptr test_module, + const BufferAssignmentProto* buffer_assignment_proto, + HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, + std::minstd_rand0* engine, const RunHloModuleOptions& options, + xla::RunHloModuleIterationLiterals* iteration_literals_proto, + std::function + reference_module_modifier_hook, + std::function config_modifier_hook) { + CHECK(test_module); + CHECK(iteration_literals_proto == nullptr) + << "Cannot run decomposed module if input literals are provided."; + if (options.run_test_hlo_passes || (options.run_reference_hlo_passes && + !options.reference_platform.empty())) { + LOG(WARNING) + << "!!! Warning !!! When running decomposed module, running HLO " + "passes is likely not what you want. If you have unoptimized " + "HLO, first convert it to the optimized e.g. using the " + "hlo-opt tool, and then isolate without HLO passes."; + } + + std::vector chunk_results; + + TF_ASSIGN_OR_RETURN( + std::vector> modules, + DecomposeHloModule(*test_module, /*deduplicate_modules=*/true)); + + Status status = OkStatus(); + for (std::unique_ptr& module : modules) { + const std::string module_name = module->name(); + ModuleResult test_module_result = ModuleResult::kDidntRun; + ModuleResult reference_module_result = ModuleResult::kDidntRun; + Status chunk_status = RunAndCompareInternal( + std::move(module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook, + &test_module_result, &reference_module_result); + chunk_results.push_back({std::move(module_name), test_module_result, + reference_module_result, chunk_status}); + status.Update(chunk_status); + if (!chunk_status.ok() && test_module_result != ModuleResult::kMismatch) { + break; + } + } + absl::c_sort(chunk_results); + std::cout << BuildResultsTable(chunk_results, modules.size()); + return status; +} + +} // namespace + +Status RunAndCompare( + std::unique_ptr test_module, + const BufferAssignmentProto* buffer_assignment_proto, + HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, + std::minstd_rand0* engine, const RunHloModuleOptions& options, + xla::RunHloModuleIterationLiterals* iteration_literals_proto, + std::function + reference_module_modifier_hook, + std::function config_modifier_hook) { + if (options.isolate_instructions) { + return RunIsolatedAndCompare( + std::move(test_module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook); + } + return RunAndCompareInternal( + std::move(test_module), buffer_assignment_proto, test_runner, + reference_runner, engine, options, iteration_literals_proto, + reference_module_modifier_hook, config_modifier_hook, nullptr, nullptr); } Status RunAndCompare( @@ -250,11 +497,11 @@ Status RunAndCompare( BufferAssignmentProto buffer_assignment_proto; TF_ASSIGN_OR_RETURN( auto test_module, - LoadModuleFromFile(hlo_filename, hlo_module_loader_details::Config(), - options.input_format, config_modifier_hook, - options.use_buffer_assignment_from_proto - ? &buffer_assignment_proto - : nullptr)); + LoadModuleFromFile( + hlo_filename, options.input_format, + hlo_module_loader_details::Config(), config_modifier_hook, + options.use_buffer_assignment_from_proto ? &buffer_assignment_proto + : nullptr)); HloVerifier verifier( HloVerifierOpts{}.WithLayoutSensitive(false).WithAllowMixedPrecision( true)); @@ -272,7 +519,7 @@ Status RunAndCompare( std::unique_ptr iteration_literals_proto_local; if (iteration_literals_proto == nullptr) { // User did not explicitly give input - if (!options.force_fake_data && + if (!options.force_fake_data && !options.isolate_instructions && (options.input_format == "pb" || options.input_format == "pbtxt")) { // User is giving a snapshot (which contains inputs) LOG(INFO) << "Using input data from the user-provided snapshot."; diff --git a/xla/tools/run_hlo_module.h b/xla/tools/run_hlo_module.h index 61ce7626a7bbe..80b16b277c7f0 100644 --- a/xla/tools/run_hlo_module.h +++ b/xla/tools/run_hlo_module.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_runner.h" +#include "xla/status.h" #include "xla/tools/run_hlo_module.pb.h" #include "tsl/platform/status.h" @@ -55,6 +56,7 @@ struct RunHloModuleOptions { std::string input_literals_file; bool random_init_input_literals{true}; bool force_fake_data{false}; + bool isolate_instructions{false}; }; // Runs test_module on the platform with the name diff --git a/xla/tools/run_hlo_module_bin_test.cc b/xla/tools/run_hlo_module_bin_test.cc index 633b07672fe66..079babc9424b4 100644 --- a/xla/tools/run_hlo_module_bin_test.cc +++ b/xla/tools/run_hlo_module_bin_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/run_hlo_module_main.cc b/xla/tools/run_hlo_module_main.cc index 89dc82cdd8db4..c19130df2c656 100644 --- a/xla/tools/run_hlo_module_main.cc +++ b/xla/tools/run_hlo_module_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,11 +26,11 @@ limitations under the License. #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/tools/run_hlo_module.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( @@ -127,6 +127,12 @@ int main(int argc, char** argv) { "iterations", &opts.iterations, "The number of times to run the module. Each iteration will be run " "with different input data."), + tsl::Flag( + "isolate_instructions", &opts.isolate_instructions, + "Rather than executing the entire module at once, run every " + "instruction individually, including the top-level and control-flow " + "dependent computations (e.g. inside conditions, calls). Skip " + "instructions inside fused computations etc."), tsl::Flag("different_random_seeds", &different_random_seeds, "Whether each iteration should use a different random seed for " "the HloModuleConfig."), @@ -194,8 +200,7 @@ int main(int argc, char** argv) { } if (!reference_platform_name.empty()) { - std::cerr << failure_count << "/" << iteration_count - << " runs miscompared.\n"; + std::cerr << failure_count << "/" << iteration_count << " runs failed.\n"; } return failure_count == 0 ? 0 : -1; diff --git a/xla/tools/show_literal.cc b/xla/tools/show_literal.cc index 3a4867521ee35..e21a94a2d443d 100644 --- a/xla/tools/show_literal.cc +++ b/xla/tools/show_literal.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/show_signature.cc b/xla/tools/show_signature.cc index 22ce2e26717e9..4cd64dfd86a19 100644 --- a/xla/tools/show_signature.cc +++ b/xla/tools/show_signature.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/show_text_literal.cc b/xla/tools/show_text_literal.cc index 0c075227c1682..c903565e8968b 100644 --- a/xla/tools/show_text_literal.cc +++ b/xla/tools/show_text_literal.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/tools/tests/hlo_expand_test.cc b/xla/tools/tests/hlo_expand_test.cc index 54a62fd8f18d0..1289f48a6ea67 100644 --- a/xla/tools/tests/hlo_expand_test.cc +++ b/xla/tools/tests/hlo_expand_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -69,6 +69,34 @@ ENTRY %main.3 () -> f64[3,3] { EXPECT_THAT(stdout_output_, testing::HasSubstr(expected_hlo_string)); } +TEST_F(HloExpandTest, SpmdHlo) { + std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", + "tests", "spmd.hlo"); + std::vector additional_flags = {"--spmd_expander", hlo_path}; + HloOpt(additional_flags); + + const std::string& expected_hlo_string = + R"(HloModule module, entry_computation_layout={(f32[24,64]{1,0}, f32[39296,64]{1,0})->f32[24,19648]{1,0}}, num_partitions=2 + +ENTRY %entry_spmd (param: f32[24,64], param.1: f32[39296,64]) -> f32[24,19648] { + %param = f32[24,64]{1,0} parameter(0), sharding={replicated} + %lhs.copy.1 = f32[24,64]{1,0} copy(f32[24,64]{1,0} %param) + %param.1 = f32[39296,64]{1,0} parameter(1), sharding={replicated} + %constant = s32[2]{0} constant({0, 19648}) + %partition-id = u32[] partition-id() + %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1} + %reshape = s32[] reshape(s32[1]{0} %dynamic-slice) + %constant.1 = s32[] constant(0) + %dynamic-slice.1 = f32[19648,64]{1,0} dynamic-slice(f32[39296,64]{1,0} %param.1, s32[] %reshape, s32[] %constant.1), dynamic_slice_sizes={19648,64} + %rhs.copy.1 = f32[19648,64]{1,0} copy(f32[19648,64]{1,0} %dynamic-slice.1) + ROOT %dot.1 = f32[24,19648]{1,0} dot(f32[24,64]{1,0} %lhs.copy.1, f32[19648,64]{1,0} %rhs.copy.1), lhs_contracting_dims={1}, rhs_contracting_dims={1} +})"; + + EXPECT_TRUE(exited_normally_); + EXPECT_EQ(exit_status_, 0); + EXPECT_THAT(stdout_output_, testing::HasSubstr(expected_hlo_string)); +} + TEST_F(HloExpandTest, CholeskyExpanderHlo) { std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "tests", "cholesky.hlo"); @@ -103,7 +131,7 @@ TEST_F(HloExpandTest, InvalidInputFileExtension) { HloOpt(additional_flags); const std::string& expected_string = - "input_format must be specified as [hlo|pb|pbtxt]."; + "input_format must be specified as [hlo|pb|pbtxt|txt]."; EXPECT_TRUE(exited_normally_); EXPECT_EQ(exit_status_, 1); @@ -115,7 +143,7 @@ TEST_F(HloExpandTest, InvalidInputFormat) { HloOpt(additional_flags); const std::string& expected_string = - "input_format must be specified as [hlo|pb|pbtxt]."; + "input_format must be specified as [hlo|pb|pbtxt|txt]."; EXPECT_TRUE(exited_normally_); EXPECT_EQ(exit_status_, 1); @@ -175,7 +203,8 @@ TEST_F(HloExpandTest, UnsupportedOutputFormat) { HloOpt(additional_flags); const std::string& expected_string = - "Printing to stdout must specify supported output_format=[hlo|pbtxt]."; + "Printing to stdout must specify supported " + "output_format=[hlo|pbtxt|txt]."; EXPECT_TRUE(exited_normally_); EXPECT_EQ(exit_status_, 1); diff --git a/xla/tools/tests/spmd.hlo b/xla/tools/tests/spmd.hlo new file mode 100644 index 0000000000000..f68c23734d535 --- /dev/null +++ b/xla/tools/tests/spmd.hlo @@ -0,0 +1,12 @@ +HloModule module, num_partitions=2 + +ENTRY entry { + %lhs = f32[24,64] parameter(0) + %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated} + %rhs = f32[39296,64] parameter(1) + %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2]0,1} +} diff --git a/xla/tools/xla_compile_lib.cc b/xla/tools/xla_compile_lib.cc index 4c26660a6b22f..852c83615854e 100644 --- a/xla/tools/xla_compile_lib.cc +++ b/xla/tools/xla_compile_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,27 +23,54 @@ limitations under the License. #include #include "google/protobuf/duration.pb.h" +#include "absl/cleanup/cleanup.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "xla/client/xla_computation.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/compiler.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/executable.h" +#include "xla/service/export_hlo.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" +#include "xla/shape.h" +#include "xla/status.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tools/hlo_module_loader.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" +#include "tsl/platform/path.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" +#include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/executable.pb.h" +#include "xla/service/gpu/gpu_symbol_repository.h" #include "xla/stream_executor/gpu/gpu_init.h" #endif #if GOOGLE_CUDA @@ -54,20 +81,22 @@ limitations under the License. namespace xla { -static StatusOr AotCompileCpuExecutable( +static absl::StatusOr AotCompileCpuExecutable( std::unique_ptr hlo_module) { cpu::CpuCompiler cpu_compiler; + auto module_group = std::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( - std::unique_ptr cpu_executable, - cpu_compiler.CompileXlaRuntimeCpuExecutable(std::move(hlo_module))); + std::vector> executables, + cpu_compiler.Compile(std::move(module_group), {{nullptr}}, {nullptr})); TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, - cpu_compiler.Export(cpu_executable.get())); + cpu_compiler.Export(executables[0].get())); return aot_result->SerializeAsString(); } -static StatusOr CompileGpuExecutable( +static absl::StatusOr CompileGpuExecutable( std::unique_ptr hlo_module, - std::optional target_config) { + std::optional target_config, + CompilationResult& result) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM const bool aot = target_config.has_value(); @@ -77,69 +106,63 @@ static StatusOr CompileGpuExecutable( auto gpu_compiler = gpu::AMDGPUCompiler(); #endif - Compiler::CompileOptions compile_options; - - stream_executor::StreamExecutor* stream_executor = nullptr; - std::unique_ptr allocator; - if (aot) { - compile_options.target_config = *target_config; - } else { - TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager()); - TF_ASSIGN_OR_RETURN( - stream_executor, - stream_executor::GPUMachineManager()->ExecutorForDevice(0)); - allocator = - std::make_unique( - stream_executor); - compile_options.device_allocator = allocator.get(); - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_after_opt, - gpu_compiler.RunHloPasses(std::move(hlo_module), stream_executor, - compile_options)); + auto module_group = std::make_unique(std::move(hlo_module)); if (aot) { - auto module_group = - std::make_unique(std::move(module_after_opt)); - AotCompilationOptions aot_options(gpu_compiler.PlatformId()); aot_options.set_target_config(*target_config); + // We need the optimized module, so we call RunHloPasses ourselves above. + aot_options.set_run_backend_only(true); TF_ASSIGN_OR_RETURN( std::vector> aot_results, gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options)); - TF_ASSIGN_OR_RETURN(std::string result, + TF_ASSIGN_OR_RETURN(std::string compile_result, aot_results[0]->SerializeAsString()); - return result; + *result.mutable_hlo_module() = + aot_results[0]->optimized_module()->ToProto(); + return compile_result; } + Compiler::CompileOptions compile_options; + TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager()); TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - gpu_compiler.RunBackend(std::move(module_after_opt), stream_executor, - compile_options)); - return executable->module().ToString(); + stream_executor::StreamExecutor * stream_executor, + stream_executor::GPUMachineManager()->ExecutorForDevice(0)); + auto allocator = + std::make_unique( + stream_executor); + compile_options.device_allocator = allocator.get(); + + TF_ASSIGN_OR_RETURN( + std::vector> executables, + gpu_compiler.Compile(std::move(module_group), {{stream_executor}}, + compile_options)); + *result.mutable_hlo_module() = executables[0]->module().ToProto(); + return executables[0]->module().ToString(); #else LOG(ERROR) << "Neither ROCm nor CUDA present; returning empty."; return ""; #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } -StatusOr CompileExecutable( +absl::StatusOr CompileExecutable( std::unique_ptr hlo_module, absl::string_view platform, - std::optional target_config) { + std::optional target_config, + CompilationResult& result) { if (platform == "cpu") { return AotCompileCpuExecutable(std::move(hlo_module)); } else if (platform == "gpu") { - return CompileGpuExecutable(std::move(hlo_module), target_config); + return CompileGpuExecutable(std::move(hlo_module), target_config, result); } return absl::UnimplementedError( absl::StrCat("platform", platform, " is not supported")); } -Status WriteResultFile(const std::string& result_output_file, TimerStats& stats, - CompilationResult& compilation_result) { +absl::Status WriteResultFile(const absl::string_view result_output_file, + TimerStats& stats, + CompilationResult& compilation_result) { if (result_output_file.empty()) { return absl::OkStatus(); } @@ -155,8 +178,132 @@ Status WriteResultFile(const std::string& result_output_file, TimerStats& stats, duration; *compilation_result.mutable_perf_stats()->mutable_total_duration() = duration; - return tsl::WriteBinaryProto(tsl::Env::Default(), result_output_file, - compilation_result); + return tsl::WriteBinaryProto( + tsl::Env::Default(), std::string(result_output_file), compilation_result); +} + +absl::StatusOr> LoadModule( + const absl::string_view module_path) { + auto format = std::string(tsl::io::Extension(module_path)); + if (format == "hlo" || format == "txt" || format == "pb") { + return LoadModuleFromFile( + std::string(module_path), format, hlo_module_loader_details::Config(), + [&](HloModuleConfig* c) {}, nullptr); + } + std::string module_string; + TF_RETURN_IF_ERROR(tsl::ReadFileToString( + tsl::Env::Default(), std::string(module_path), &module_string)); + + mlir::DialectRegistry dialects; + // TODO(b/248362914): Register all required dialects. + dialects.insert(); + dialects.insert(); + dialects.insert(); + mlir::stablehlo::registerAllDialects(dialects); + + // Parse MHLO module. + auto threading = mlir::MLIRContext::Threading::DISABLED; + auto ctx = std::make_unique(dialects, threading); + mlir::OwningOpRef module = + mlir::parseSourceString(module_string, ctx.get()); + + // Convert Mhlo to Hlo Module. + XlaComputation xla_computation; + TF_RETURN_IF_ERROR( + MlirToXlaComputation(*module, xla_computation, false, false)); + HloModuleProto hlo_module_proto = xla_computation.proto(); + + TF_ASSIGN_OR_RETURN(ProgramShape shape, xla_computation.GetProgramShape()); + DebugOptions debug_options = GetDebugOptionsFromFlags(); + HloModuleConfig config(shape); + config.set_debug_options(debug_options); + return HloModule::CreateFromProto(hlo_module_proto, config); +} + +absl::Status XlaCompileMain( + absl::string_view module_path, absl::string_view output_path, + absl::string_view platform, absl::string_view gpu_target_config_path, + absl::string_view autotune_results_path, absl::string_view symbol_repo, + absl::string_view symbol_id, const bool use_attached_device, + const bool wait_for_uploads, absl::string_view result_output_file) { + std::unique_ptr hlo_module; + std::unique_ptr target_config; + if (!symbol_id.empty()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr mod, + LookupSymbolInRepository(symbol_repo, symbol_id, BackendType::kGpu)); + if (mod == nullptr) { + return absl::NotFoundError( + absl::StrCat("Could not find ", symbol_id, " in ", symbol_repo)); + } +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (auto* data = static_cast( + mod->backend_specific_data.get()); + data != nullptr) { + target_config = std::move(mod->target_config); + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + hlo_module = std::move(mod->hlo_module); + } else { + TF_ASSIGN_OR_RETURN(hlo_module, LoadModule(module_path)); + } + + xla::TimerStats stats; + xla::ScopedLoggingTimer timer("compilation", true, "xla_compile_main.cc", 1, + &stats); + CompilationResult compilation_result; + absl::Cleanup cleanup([&] { + // Make sure we stop the timer if compilation failed. + timer.StopAndLog(); + if (!result_output_file.empty()) { + TF_QCHECK_OK( + WriteResultFile(result_output_file, stats, compilation_result)); + } + }); + // Run AOT compilation. + std::optional cfg = std::nullopt; + if (platform == "gpu") { + if (!gpu_target_config_path.empty()) { + // Parse GpuTargetConfig. + std::string gpu_target_config_string; + TF_RETURN_IF_ERROR(tsl::ReadFileToString( + tsl::Env::Default(), std::string(gpu_target_config_path), + &gpu_target_config_string)); + stream_executor::GpuTargetConfigProto gpu_target_config_proto; + + if (!tsl::protobuf::TextFormat::ParseFromString( + gpu_target_config_string, &gpu_target_config_proto)) { + return FailedPrecondition("Failed to parse GpuTargetConfigProto"); + } + + target_config = + std::make_unique(gpu_target_config_proto); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (!autotune_results_path.empty()) { + TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile( + autotune_results_path)); + } +#endif + } + + cfg = (use_attached_device) ? std::nullopt + : std::make_optional(*std::move(target_config)); + } + auto result = CompileExecutable(std::move(hlo_module), platform, cfg, + compilation_result); + if (!result.ok()) { + *compilation_result.mutable_status() = tsl::StatusToProto(result.status()); + return result.status(); + } + + TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), + std::string(output_path), *result)); + + if (wait_for_uploads) { + MaybeWaitForUploads(); + } + return OkStatus(); } } // namespace xla diff --git a/xla/tools/xla_compile_lib.h b/xla/tools/xla_compile_lib.h index 63176f10faf63..78b868d2a23ae 100644 --- a/xla/tools/xla_compile_lib.h +++ b/xla/tools/xla_compile_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,30 +20,46 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/xla_compile_result.pb.h" #include "xla/util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace xla { // Compiles the provided module for the given platform, either "cpu" or "gpu". // When compiling for GPU, if the target config is provided, the compilation // will be AOT. If it is not provided, an attached GPU will be used. When -// compiling for CPU, the compilation will always be AOT. +// compiling for CPU, the compilation will always be AOT. If a result is +// provided, the post-optimization module will be stored in it. // // This is the expected entry point to the compilation functionality. -StatusOr CompileExecutable( +absl::StatusOr CompileExecutable( std::unique_ptr hlo_module, absl::string_view platform, - std::optional target_config); + std::optional target_config, + CompilationResult& result); // Merges the measured duration into compilation_result and writes // compilation_result to result_output_file in the wire format. -Status WriteResultFile(const std::string& result_output_file, TimerStats& stats, - CompilationResult& compilation_result); +absl::Status WriteResultFile(absl::string_view result_output_file, + TimerStats& stats, + CompilationResult& compilation_result); + +// Loads the HLO, MHLO, or StableHLO module at the given file path. +absl::StatusOr> LoadModule( + absl::string_view module_path); + +// Full entry point if you want to wrap a binary around this functionality. +// See flag definitions in ../service/xla_compile_main.cc for semantics. +absl::Status XlaCompileMain( + absl::string_view module_path, absl::string_view output_path, + absl::string_view platform, absl::string_view gpu_target_config_path, + absl::string_view autotune_results_path, absl::string_view symbol_repo, + absl::string_view symbol_id, bool use_attached_device, + bool wait_for_uploads, absl::string_view result_output_file); } // namespace xla diff --git a/xla/tools/xla_compile_lib_test.cc b/xla/tools/xla_compile_lib_test.cc index dacd2ecec2682..415a5ccdb77ef 100644 --- a/xla/tools/xla_compile_lib_test.cc +++ b/xla/tools/xla_compile_lib_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ namespace xla { namespace { using ::testing::IsEmpty; +using ::testing::IsNull; using ::testing::Not; using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; @@ -78,13 +79,18 @@ class XlaCompileLibTest : public HloTestBase { }; TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) { - EXPECT_THAT(CompileExecutable(std::move(module_), "cpu", std::nullopt), - IsOkAndHolds(Not(IsEmpty()))); + CompilationResult result; + EXPECT_THAT( + CompileExecutable(std::move(module_), "cpu", std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); } TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { - EXPECT_THAT(CompileExecutable(std::move(module_), "gpu", std::nullopt), - IsOkAndHolds(Not(IsEmpty()))); + CompilationResult result; + EXPECT_THAT( + CompileExecutable(std::move(module_), "gpu", std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); + EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { @@ -94,12 +100,16 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { stream_executor::GpuTargetConfigProto target_config; TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path, &target_config)); - EXPECT_THAT(CompileExecutable(std::move(module_), "gpu", std::nullopt), - IsOkAndHolds(Not(IsEmpty()))); + CompilationResult result; + EXPECT_THAT( + CompileExecutable(std::move(module_), "gpu", std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); + EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); } TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) { - EXPECT_THAT(CompileExecutable(nullptr, "tpu", std::nullopt), + CompilationResult result; + EXPECT_THAT(CompileExecutable(nullptr, "tpu", std::nullopt, result), StatusIs(tsl::error::UNIMPLEMENTED)); } @@ -142,5 +152,58 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) { got_result.perf_stats().total_duration().nanos()); } +TEST_F(XlaCompileLibTest, LoadModuleErrors) { + EXPECT_THAT(LoadModule("/does/not/exist"), Not(IsOk())); +} + +TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull()))); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + const std::string output_path = + tsl::io::JoinPath(tsl::testing::TmpDir(), "output"); + const std::string result_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "result.pb"); + + TF_EXPECT_OK(XlaCompileMain(module_file, output_path, "cpu", + /* gpu_target_config_path= */ "", + /* autotune_results_path= */ "", + /* symbol_repo= */ "", /* symbol_id= */ "", + /* use_attached_device=*/false, + /* wait_for_uploads */ false, + /* result_output_file=*/result_file)); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + const std::string output_path = + tsl::io::JoinPath(tsl::testing::TmpDir(), "output"); + const std::string result_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "result.pb"); + + TF_EXPECT_OK(XlaCompileMain(module_file, output_path, "gpu", + /* gpu_target_config_path= */ "", + /* autotune_results_path= */ "", + /* symbol_repo= */ "", /* symbol_id= */ "", + /* use_attached_device=*/true, + /* wait_for_uploads */ false, + /* result_output_file=*/result_file)); +} + } // namespace } // namespace xla diff --git a/xla/translate/BUILD b/xla/translate/BUILD index f2d537757a157..79f2ab89e9db4 100644 --- a/xla/translate/BUILD +++ b/xla/translate/BUILD @@ -1,9 +1,13 @@ -load("//xla:xla.bzl", "xla_cc_binary") load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load("//xla:xla.bzl", "xla_cc_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -22,10 +26,8 @@ xla_cc_binary( "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_transfer_manager", "//xla/stream_executor/host:host_platform", - "//xla/stream_executor/platform", "//xla/translate/hlo_to_mhlo:translate_registration", "//xla/translate/mhlo_to_hlo:translate_registration", - "//xla/translate/mhlo_to_lhlo_with_xla:translate_registration", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -33,3 +35,29 @@ xla_cc_binary( "@tsl//tsl/platform:platform_port", ], ) + +build_test( + name = "xla-translate-opt_build_test", + targets = [ + ":xla-translate-opt", + ], +) + +xla_cc_binary( + name = "xla-translate-opt", + testonly = True, + srcs = ["xla_translate_opt_main.cc"], + deps = [ + "//xla/mlir/framework/ir:xla_framework", + "//xla/mlir/framework/transforms:passes", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:cpu_plugin", + "//xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", + "@stablehlo//:register", + "@tsl//tsl/platform:platform_port", + ], +) diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index 0f0c5e8428ea8..299dbba082a4c 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -1,9 +1,13 @@ +load("@tsl//tsl:tsl.bzl", "internal_visibility") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -13,12 +17,31 @@ cc_library( hdrs = ["attribute_importer.h"], deps = [ "//xla:shape_util", - "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", "//xla/mlir_hlo", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "custom_call_importer", + srcs = ["custom_call_importer.cc"], + hdrs = ["custom_call_importer.h"], + deps = [ + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", ], ) @@ -45,15 +68,15 @@ cc_library( ], deps = [ ":attribute_importer", + ":custom_call_importer", ":hlo_utils", ":location_importer", "//xla:comparison_util", - "//xla:permutation_util", - "//xla:printer", + "//xla:literal", "//xla:protobuf_util", + "//xla:shape_layout", "//xla:shape_util", "//xla:status", - "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", @@ -62,13 +85,19 @@ cc_library( "//xla/mlir_hlo", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) @@ -97,10 +126,9 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/mlir_hlo:convert_op_folder", - "//xla/mlir_hlo:lhlo", - "//xla/service/llvm_ir:llvm_util", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", @@ -121,7 +149,6 @@ xla_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:test_main", ], ) @@ -148,7 +175,6 @@ cc_library( "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/service/llvm_ir:llvm_util", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@llvm-project//mlir:IR", "@tsl//tsl/platform:protobuf", ], diff --git a/xla/translate/hlo_to_mhlo/attribute_importer.cc b/xla/translate/hlo_to_mhlo/attribute_importer.cc index 8633585ca4739..18fd20290ef10 100644 --- a/xla/translate/hlo_to_mhlo/attribute_importer.cc +++ b/xla/translate/hlo_to_mhlo/attribute_importer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/layout_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -127,7 +129,19 @@ mlir::ArrayAttr ConvertOutputOperandAliasing( return builder->getArrayAttr(attrs); } -StatusOr ConvertFftType(FftType type) { +absl::StatusOr ConvertSparsityDescriptor( + xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder) { + switch (sparsity_descriptor.type()) { + case SPARSITY_STRUCTURED_N_M: + return mlir::mhlo::SparsityDescriptorAttr::get( + builder->getContext(), sparsity_descriptor.dimension(), + sparsity_descriptor.n(), sparsity_descriptor.m()); + default: + return InvalidArgument("Unknown sparsity descriptor type"); + } +} + +absl::StatusOr ConvertFftType(FftType type) { switch (type) { case FftType::FFT: return mlir::mhlo::FftType::FFT; @@ -142,7 +156,7 @@ StatusOr ConvertFftType(FftType type) { } } -StatusOr ConvertTranspose( +absl::StatusOr ConvertTranspose( xla::TriangularSolveOptions_Transpose transpose) { switch (transpose) { case TriangularSolveOptions::NO_TRANSPOSE: @@ -158,7 +172,7 @@ StatusOr ConvertTranspose( } } -StatusOr ConvertCustomCallApiVersion( +absl::StatusOr ConvertCustomCallApiVersion( xla::CustomCallApiVersion api_version) { switch (api_version) { case xla::CustomCallApiVersion::API_VERSION_UNSPECIFIED: @@ -179,7 +193,7 @@ StatusOr ConvertCustomCallApiVersion( } } -StatusOr ExtractLayoutsFromShapes( +absl::StatusOr ExtractLayoutsFromShapes( const absl::Span shapes_with_layouts, mlir::Builder* builder) { std::vector layouts; for (auto& shape_and_layout : shapes_with_layouts) { @@ -216,8 +230,8 @@ StatusOr ExtractLayoutsFromShapes( return builder->getArrayAttr(layouts); } -StatusOr ExtractLayoutsFromTuple(const Shape shape, - mlir::Builder* builder) { +absl::StatusOr ExtractLayoutsFromTuple( + const Shape shape, mlir::Builder* builder) { if (!shape.IsTuple()) return InvalidArgument("Expected shape to be Tuple"); return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder); } diff --git a/xla/translate/hlo_to_mhlo/attribute_importer.h b/xla/translate/hlo_to_mhlo/attribute_importer.h index cd88cd95010e4..be291d59447e2 100644 --- a/xla/translate/hlo_to_mhlo/attribute_importer.h +++ b/xla/translate/hlo_to_mhlo/attribute_importer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,13 +19,12 @@ limitations under the License. #include #include -#include "mlir/IR/Attributes.h" // from @llvm-project +#include "absl/status/statusor.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { @@ -56,23 +55,27 @@ mlir::ArrayAttr ConvertOutputOperandAliasing( std::pair>>& aliaInfo, mlir::Builder* builder); -StatusOr ConvertFftType(FftType type); -StatusOr ConvertTranspose( +// Converts the sparsity descriptor to attributes. +absl::StatusOr ConvertSparsityDescriptor( + xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); + +absl::StatusOr ConvertFftType(FftType type); +absl::StatusOr ConvertTranspose( TriangularSolveOptions_Transpose transpose); -StatusOr ConvertCustomCallApiVersion( +absl::StatusOr ConvertCustomCallApiVersion( xla::CustomCallApiVersion api_version); // Extracts layouts from shapes and converts it into layout attributes (array of // rank-1 index tensors). Returns an error if any of the shapes is a tuple. -StatusOr ExtractLayoutsFromShapes( +absl::StatusOr ExtractLayoutsFromShapes( const absl::Span shapes_with_layouts, mlir::Builder* builder); // Extracts the layouts of each element from a tuple shape and returns them as // an array of rank-1 index tensors. Returns an error in presence of nested // tuple shapes. -StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, - mlir::Builder* builder); +absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, + mlir::Builder* builder); } // namespace xla diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc new file mode 100644 index 0000000000000..4b2217397eb98 --- /dev/null +++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc @@ -0,0 +1,210 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" + +#include +#include + +#include "absl/strings/match.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/status.h" +#include "xla/util.h" + +namespace xla { +namespace { + +absl::StatusOr ImportDynamicBroadcastInDimOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (backend_config.empty()) { + return Internal("backend_config attribute cannot be empty."); + } + + auto attr = mlir::parseAttribute(backend_config, builder->getContext()) + .dyn_cast(); + if (!attr) { + return Internal( + "Couldn't parse backend config into a dictionary attribute"); + } + + auto broadcast_dimensions_attr = + attr.get("broadcast_dimensions").dyn_cast_or_null(); + if (!broadcast_dimensions_attr) { + return Internal("broadcast_dimensions attribute is required."); + } + + std::vector broadcast_dimensions(broadcast_dimensions_attr.size()); + for (auto [i, broadcast_dimension] : + llvm::enumerate(broadcast_dimensions_attr)) { + broadcast_dimensions[i] = + broadcast_dimension.cast().getInt(); + } + + return builder + ->create( + loc, result_type, operands[0], operands[1], + builder->getI64TensorAttr(broadcast_dimensions)) + .getOperation(); +} + +absl::StatusOr ImportDynamicReshapeOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (!backend_config.empty()) { + return Internal("backend_config attribute must be empty."); + } + return builder + ->create(loc, result_type, operands) + .getOperation(); +} + +absl::StatusOr ImportRealDynamicSliceOp( + mlir::StringRef backend_config, mlir::Location loc, mlir::Type result_type, + mlir::ValueRange operands, mlir::OpBuilder* builder) { + if (!backend_config.empty()) { + return Internal("backend_config attribute must be empty."); + } + return builder + ->create(loc, result_type, operands) + .getOperation(); +} + +} // namespace + +mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) { + std::vector scales; + std::vector zero_points; + int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0; + mlir::Type storage_type, expressed_type; + + if (const mlir::Attribute scales_attr = backend_config.get("scale"); + scales_attr) { + for (auto scale_attr : scales_attr.cast()) { + scales.push_back(scale_attr.cast().getValueAsDouble()); + } + } + + auto zero_points_attr = backend_config.get("zero_point"); + if (zero_points_attr) { + for (auto zero_point_attr : zero_points_attr.cast()) { + zero_points.push_back(zero_point_attr.cast().getInt()); + } + } + + auto quantization_dimension_attr = + backend_config.get("quantization_dimension"); + if (quantization_dimension_attr) { + quantization_dimension = + quantization_dimension_attr.cast().getInt(); + } + + auto storage_max_attr = backend_config.get("storage_max"); + if (storage_max_attr) { + storage_max = storage_max_attr.cast().getInt(); + } + + auto storage_min_attr = backend_config.get("storage_min"); + if (storage_min_attr) { + storage_min = storage_min_attr.cast().getInt(); + } + + auto storage_type_attr = backend_config.get("storage_type"); + if (storage_type_attr) { + storage_type = storage_type_attr.cast().getValue(); + } + + auto expressed_type_attr = backend_config.get("expressed_type"); + if (expressed_type_attr) { + expressed_type = expressed_type_attr.cast().getValue(); + } + + auto is_signed = storage_type.cast().isSignless(); + + if (quantization_dimension != -1) { + return mlir::quant::UniformQuantizedPerAxisType::get( + is_signed, storage_type, expressed_type, scales, zero_points, + quantization_dimension, storage_min, storage_max); + } else { + return mlir::quant::UniformQuantizedType::get( + is_signed, storage_type, expressed_type, scales[0], zero_points[0], + storage_min, storage_max); + } +} + +absl::StatusOr ImportCustomCallAsOp( + const HloCustomCallInstruction* instruction, mlir::Location loc, + mlir::Type result_type, mlir::ValueRange operands, + mlir::OpBuilder* builder) { + const std::string& custom_call_target = instruction->custom_call_target(); + const std::string& backend_config_str = + instruction->raw_backend_config_string(); + if (custom_call_target == "mhlo.dynamic_broadcast_in_dim") { + return ImportDynamicBroadcastInDimOp(backend_config_str, loc, result_type, + operands, builder); + } + if (custom_call_target == "mhlo.dynamic_reshape") { + return ImportDynamicReshapeOp(backend_config_str, loc, result_type, + operands, builder); + } + if (custom_call_target == "mhlo.real_dynamic_slice") { + return ImportRealDynamicSliceOp(backend_config_str, loc, result_type, + operands, builder); + } + + auto backend_config = + mlir::parseAttribute(backend_config_str, builder->getContext()) + .dyn_cast(); + if (!backend_config) { + return Internal( + "Couldn't parse backend config into a dictionary attribute"); + } + + if (custom_call_target == "mhlo.uniform_quantize") { + return builder + ->create( + loc, + mlir::RankedTensorType::get( + result_type.cast().getShape(), + getQuantizedType(backend_config)), + operands) + .getOperation(); + } + + if (custom_call_target == "mhlo.uniform_dequantize") { + return builder + ->create(loc, result_type, operands) + .getOperation(); + } + return InvalidArgument("Unsupported MHLO op custom_call %s", + custom_call_target); +} + +bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction) { + return absl::StartsWith(instruction->custom_call_target(), "mhlo."); +} + +} // namespace xla diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.h b/xla/translate/hlo_to_mhlo/custom_call_importer.h new file mode 100644 index 0000000000000..4fbe4516b9320 --- /dev/null +++ b/xla/translate/hlo_to_mhlo/custom_call_importer.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#define XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/statusor.h" + +namespace xla { + +// Imports custom_calls prefixed with `mhlo.` from HLO to MHLO. +// This is used for ops in MHLO / StableHLO that don't exist in HLO. Many of +// these ops are needed for XlaBuilder clients that need to raise HLO to +// StableHLO. +absl::StatusOr ImportCustomCallAsOp( + const HloCustomCallInstruction* instruction, mlir::Location loc, + mlir::Type result_type, mlir::ValueRange operands, + mlir::OpBuilder* builder); + +// Indicates whether a custom call is an encoded MHLO op. +// Currently returns true for `mhlo.` prefixed custom calls. +bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction); + +} // namespace xla + +#endif // XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index cc7aa9e9ed6e4..0e234bcdb7c60 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include +#include +#include +#include +#include #include #include #include @@ -23,19 +27,32 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/AsmParser/AsmParser.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -43,17 +60,24 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/printer.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" -#include "xla/status_macros.h" +#include "xla/shape_layout.h" +#include "xla/shape_util.h" +#include "xla/status.h" #include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/hlo_to_mhlo/location_importer.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" using llvm::APInt; @@ -100,7 +124,7 @@ bool DotIsDefault(const HloInstruction* instruction) { default_dimension_numbers.add_lhs_contracting_dimensions( instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); default_dimension_numbers.add_rhs_contracting_dimensions(0); - return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers); + return protobuf_util::ProtobufEquals(dnums, default_dimension_numbers); } // Clean up the GetTupleElementOp, created during the flattening of @@ -139,7 +163,7 @@ mlir::TypeRange Untuple(const mlir::Type& type) { } template -StatusOr HloFunctionImporter::ImportOldStyleAsyncStart( +absl::StatusOr HloFunctionImporter::ImportOldStyleAsyncStart( llvm::SmallVectorImpl& attributes, const llvm::SmallVectorImpl& operands, mlir::Location loc, mlir::Type result_type, mlir::OpBuilder* func_builder, @@ -206,7 +230,7 @@ StatusOr HloFunctionImporter::ImportOldStyleAsyncStart( .getOperation(); } -StatusOr HloFunctionImporter::ImportOldStyleAsyncDone( +absl::StatusOr HloFunctionImporter::ImportOldStyleAsyncDone( llvm::SmallVectorImpl& attributes, const llvm::SmallVectorImpl& operands, mlir::Location loc, mlir::Type result_type, mlir::OpBuilder* func_builder) { @@ -290,16 +314,16 @@ static bool IsNestedTupleInData(Type type) { return false; } -static bool HasCustomLayout(const xla::Shape& shape) { +static bool HasCustomLayout(const Shape& shape) { if (shape.IsTuple()) { return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); } return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != xla::LayoutUtil::GetDefaultLayoutForShape(shape); + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); } static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, - const xla::Shape& shape) { + const Shape& shape) { if (shape.IsTuple()) { llvm::SmallVector element_attrs; for (const auto& tuple_shape : shape.tuple_shapes()) { @@ -320,8 +344,8 @@ static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, return b.getIndexTensorAttr(layout); } -mlir::Attribute GetFrontendAttributes( - mlir::Builder& b, const xla::FrontendAttributes& attributes) { +mlir::Attribute GetFrontendAttributes(mlir::Builder& b, + const FrontendAttributes& attributes) { llvm::SmallVector attrs; attrs.reserve(attributes.map_size()); for (const auto& [k, v] : attributes.map()) { @@ -380,7 +404,7 @@ Value HloFunctionImporter::CreateTupleValue( .getResult(); } -StatusOr HloFunctionImporter::ImportAsFunc( +absl::StatusOr HloFunctionImporter::ImportAsFunc( const HloComputation& computation, mlir::SymbolTable& symbol_table, std::unordered_map* function_map, mlir::Builder* builder, bool is_main) { @@ -388,15 +412,16 @@ StatusOr HloFunctionImporter::ImportAsFunc( return importer.ImportAsFunc(computation, is_main); } -Status HloFunctionImporter::ImportAsRegion( - const xla::HloComputation& computation, mlir::SymbolTable& symbol_table, - mlir::Region* region, mlir::Builder* builder, - bool flatten_region_arg_tuple) { +Status HloFunctionImporter::ImportAsRegion(const HloComputation& computation, + mlir::SymbolTable& symbol_table, + mlir::Region* region, + mlir::Builder* builder, + bool flatten_region_arg_tuple) { HloFunctionImporter importer(symbol_table, {}, builder); return importer.ImportAsRegion(computation, region, flatten_region_arg_tuple); } -StatusOr HloFunctionImporter::ImportAsFunc( +absl::StatusOr HloFunctionImporter::ImportAsFunc( const HloComputation& computation, bool is_main) { std::string computation_name = is_main ? "main" : SanitizeFunctionName(computation.name()); @@ -477,7 +502,7 @@ StatusOr HloFunctionImporter::ImportAsFunc( computation_layout.result_layout().shape())); } if (llvm::any_of(computation_layout.parameter_layouts(), - [](const xla::ShapeLayout& shape) { + [](const ShapeLayout& shape) { return HasCustomLayout(shape.shape()); })) { llvm::SmallVector parameter_layouts; @@ -533,8 +558,8 @@ Status HloFunctionImporter::ImportAsRegion(const HloComputation& computation, return ImportInstructions(computation, block, flatten_region_arg_tuple); } -StatusOr HloFunctionImporter::ImportInstructionsImpl( - const xla::HloComputation& computation, +absl::StatusOr HloFunctionImporter::ImportInstructionsImpl( + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { // Setup the input parameters. const int num_parameters = computation.num_parameters(); @@ -550,7 +575,12 @@ StatusOr HloFunctionImporter::ImportInstructionsImpl( auto new_operation, ImportInstructionWithLayout(instruction, operands, builder)); if (new_operation) { - instruction_value_map_[instruction] = new_operation->getResult(0); + unsigned int idx = + (instruction->opcode() == HloOpcode::kRngBitGenerator && + instruction->shape().IsArray()) + ? 1 + : 0; + instruction_value_map_[instruction] = new_operation->getResult(idx); } } @@ -634,11 +664,11 @@ Status HloFunctionImporter::ImportInstructions( CleanUpTupleOps(block, &builder); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -StatusOr HloFunctionImporter::ImportInstructions( - const xla::HloComputation& computation, +absl::StatusOr HloFunctionImporter::ImportInstructions( + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder) { mlir::Block* block = builder->getBlock(); @@ -650,8 +680,8 @@ StatusOr HloFunctionImporter::ImportInstructions( return importer.ImportInstructionsImpl(computation, arguments, builder); } -StatusOr HloFunctionImporter::ImportInstruction( - const xla::HloInstruction* instr, +absl::StatusOr HloFunctionImporter::ImportInstruction( + const HloInstruction* instr, const llvm::SmallVectorImpl& operands, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, DynamicShapeHandlingMode mode) { @@ -664,13 +694,13 @@ StatusOr HloFunctionImporter::ImportInstruction( return importer.ImportInstructionWithLayout(instr, operands, builder, mode); } -StatusOr HloFunctionImporter::ImportInstructionImpl( +absl::StatusOr HloFunctionImporter::ImportInstructionImpl( const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) { const Shape& instruction_shape = instruction->shape(); const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic - ? xla::ShapeUtil::MakeStaticShape(instruction_shape) + ? ShapeUtil::MakeStaticShape(instruction_shape) : instruction_shape; TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType(shape, *builder_)); @@ -733,11 +763,6 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "execution_thread", builder_->getStringAttr(execution_thread))); function->setAttr("execution_thread", builder_->getStringAttr(execution_thread)); - auto group_id = async_op->async_group_id(); - if (group_id) { - attributes.push_back(builder_->getNamedAttr( - "group_id", builder_->getI64IntegerAttr(*group_id))); - } if (instruction->opcode() == HloOpcode::kAsyncStart) { auto bundle_result_type = mlir::mhlo::AsyncBundleType::get( @@ -813,12 +838,13 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kDot: { + auto dot = Cast(instruction); attributes.push_back(builder_->getNamedAttr( "precision_config", ConvertPrecisionConfig(&instruction->precision_config(), builder_))); // Consider consolidating DotOps together. - if (DotIsDefault(instruction)) { + if (DotIsDefault(instruction) && !dot->sparse_operands()) { return func_builder ->create(loc, result_type, operands, attributes) .getOperation(); @@ -828,9 +854,23 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "dot_dimension_numbers", ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(), builder_))); + if (!dot->sparse_operands()) { + return func_builder + ->create(loc, result_type, operands, + attributes) + .getOperation(); + } + + for (const SparsityDescriptor& descriptor : dot->sparsity()) { + TF_ASSIGN_OR_RETURN(auto sparsity, + ConvertSparsityDescriptor(descriptor, builder_)); + attributes.push_back(builder_->getNamedAttr( + descriptor.index() == 0 ? "lhs_sparsity" : "rhs_sparsity", + sparsity)); + } return func_builder - ->create(loc, result_type, operands, - attributes) + ->create(loc, result_type, operands, + attributes) .getOperation(); } case HloOpcode::kCall: { @@ -844,6 +884,19 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } return new_operation; } + case HloOpcode::kCollectiveBroadcast: { + auto collective_broadcast = Cast(instruction); + attributes.push_back(ConvertReplicaGroups( + collective_broadcast->replica_groups(), builder_)); + if (collective_broadcast->channel_id().has_value()) + attributes.push_back( + ConvertChannelHandle(collective_broadcast->channel_id().value())); + return func_builder + ->create(loc, result_type, + operands, attributes) + .getOperation(); + } + case HloOpcode::kCollectivePermute: { auto collective_permute = Cast(instruction); attributes.push_back(ConvertSourceTargetPairs( @@ -861,7 +914,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( instruction->source_target_pairs(), builder_)); return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, - "collective_permute_", [&](auto) { return ::tsl::OkStatus(); }); + "collective_permute_", [&](auto) { return absl::OkStatus(); }); } case HloOpcode::kCollectivePermuteDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -869,6 +922,10 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kCustomCall: { auto custom_call = Cast(instruction); + if (IsOpEncodedCustomCall(custom_call)) { + return ImportCustomCallAsOp(custom_call, loc, result_type, operands, + func_builder); + } const auto& called_computations = custom_call->called_computations(); if (!called_computations.empty()) { llvm::SmallVector callees; @@ -1208,7 +1265,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, "copy_", - [](auto) { return ::tsl::OkStatus(); }); + [](auto) { return absl::OkStatus(); }); } case HloOpcode::kCopyDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -1233,11 +1290,11 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "is_host_transfer", builder_->getBoolAttr(send_op->is_host_transfer()))); if (send_op->channel_id().has_value()) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; channel_handle.set_handle(send_op->channel_id().value()); channel_handle.set_type(send_op->is_host_transfer() - ? xla::ChannelHandle::DEVICE_TO_HOST - : xla::ChannelHandle::DEVICE_TO_DEVICE); + ? ChannelHandle::DEVICE_TO_HOST + : ChannelHandle::DEVICE_TO_DEVICE); attributes.push_back(ConvertChannelHandle(channel_handle)); } return ImportOldStyleAsyncStart( @@ -1267,11 +1324,11 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "is_host_transfer", builder_->getBoolAttr(recv_op->is_host_transfer()))); if (recv_op->channel_id().has_value()) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; channel_handle.set_handle(recv_op->channel_id().value()); channel_handle.set_type(recv_op->is_host_transfer() - ? xla::ChannelHandle::HOST_TO_DEVICE - : xla::ChannelHandle::DEVICE_TO_DEVICE); + ? ChannelHandle::HOST_TO_DEVICE + : ChannelHandle::DEVICE_TO_DEVICE); attributes.push_back(ConvertChannelHandle(channel_handle)); } return ImportOldStyleAsyncStart( @@ -1368,6 +1425,12 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kAllGather: { auto all_gather = Cast(instruction); + auto result_tuple_ty = result_type.dyn_cast(); + + llvm::SmallVector result_types = {result_type}; + if (result_tuple_ty) { + result_types = llvm::to_vector(result_tuple_ty.getTypes()); + } attributes.push_back(builder_->getNamedAttr( "all_gather_dim", builder_->getI64IntegerAttr(all_gather->all_gather_dimension()))); @@ -1378,10 +1441,15 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertChannelHandle(all_gather->channel_id().value())); if (all_gather->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds()); - return func_builder - ->create(loc, result_type, operands, - attributes) - .getOperation(); + auto all_gather_op = func_builder->create( + loc, result_types, operands, attributes); + if (result_tuple_ty) { + return func_builder + ->create(loc, result_type, + all_gather_op.getResults()) + .getOperation(); + } + return all_gather_op.getOperation(); } case HloOpcode::kAllGatherStart: { auto all_gather_start = Cast(instruction); @@ -1395,10 +1463,13 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertChannelHandle(all_gather_start->channel_id().value())); if (all_gather_start->use_global_device_ids()) attributes.push_back(ConvertUseGlobalDeviceIds()); + if (all_gather_start->operands().size() > 1) + return InvalidArgument( + "Async tuple all-gather is not supported in MHLO"); return ImportOldStyleAsyncStart( attributes, operands, loc, result_type, func_builder, "all_gather_", - [](auto) { return ::tsl::OkStatus(); }); + [](auto) { return absl::OkStatus(); }); } case HloOpcode::kAllGatherDone: { return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, @@ -1452,7 +1523,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( TF_RETURN_IF_ERROR(ImportAsRegion( *instruction->to_apply(), &all_reduce_sync.getComputation(), /*flatten_region_arg_tuple=*/true)); - return ::tsl::OkStatus(); + return absl::OkStatus(); }); } case HloOpcode::kAllReduceDone: { @@ -1556,14 +1627,14 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( auto shape = func_builder->create( loc, Convert(result_type.cast().getShape())); switch (instruction->random_distribution()) { - case xla::RNG_UNIFORM: + case RNG_UNIFORM: return func_builder ->create( loc, result_type, operands[0], operands[1], shape, ::mlir::mhlo::RngDistribution::UNIFORM) .getOperation(); - case xla::RNG_NORMAL: + case RNG_NORMAL: return func_builder ->create(loc, result_type, operands[0], operands[1], shape, @@ -1577,18 +1648,44 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } } case HloOpcode::kRngBitGenerator: { + // HloRngBitGeneratorInstruction can have two kinds of shapes, (1) + // tuple(output_state, output_data), and (2) output_data. + // mhlo::RngBitGeneratorOp has only one shape, (output_state, + // output_data). auto rng_op = Cast(instruction); + auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + builder_->getContext(), + *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm())); + attributes.push_back( + builder_->getNamedAttr("rng_algorithm", algorithm_attr)); + // Flatten the return type if they are tuple-typed. llvm::SmallVector flattened_ret_types; FlattenTupleType(result_type, flattened_ret_types); + if (rng_op->shape().IsArray()) { + TF_ASSIGN_OR_RETURN(auto state_type, + ConvertShapeToType( + rng_op->operand(0)->shape(), *builder_)); + flattened_ret_types.insert(flattened_ret_types.begin(), state_type); + + if (instruction->has_sharding()) { + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {rng_op->operand(0)->shape(), instruction->shape()}); + HloSharding tuple_sharding = HloSharding::Tuple( + tuple_shape, {HloSharding::Replicate(), instruction->sharding()}); + CHECK_EQ(attributes.front().getName().str(), kShardingAttr); + attributes.front() = builder_->getNamedAttr( + kShardingAttr, ConvertSharding(tuple_sharding, builder_)); + } + } + CHECK_EQ(flattened_ret_types.size(), 2); - auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( - builder_->getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm())); auto op = func_builder->create( - loc, flattened_ret_types, algorithm_attr, operands[0]); - + loc, flattened_ret_types, operands[0], attributes); + if (rng_op->shape().IsArray()) { + return op.getOperation(); + } return CreateTupleFromOpResults(func_builder, loc, op.getOperation(), result_type); } @@ -1844,7 +1941,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( } // Return type is boolean, let's use `operand != 0` instead of Convert. - xla::Shape input_shape = instruction->operand(0)->shape(); + Shape input_shape = instruction->operand(0)->shape(); TF_ASSIGN_OR_RETURN(mlir::Type type, ConvertTensorShapeToType( input_shape, *func_builder)); @@ -1944,6 +2041,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp); NO_ATTRIBUTE_CASE(kStochasticConvert, StochasticConvertOp); NO_ATTRIBUTE_CASE(kLogistic, LogisticOp); + NO_ATTRIBUTE_CASE(kErf, ErfOp); // The dimensions attribute is not present on the HLO Reshape // instruction. If dimensions are non-default, the XLA builder // implements it as a separate transpose. @@ -1981,8 +2079,8 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( llvm::SmallVector flattened_ret_types; FlattenTupleType(result_type, flattened_ret_types); - auto fusion_kind = mlir::mhlo::symbolizeFusionKind( - xla::ToString(instruction->fusion_kind())); + auto fusion_kind = + mlir::mhlo::symbolizeFusionKind(ToString(instruction->fusion_kind())); attributes.push_back(builder_->getNamedAttr( "fusion_kind", mlir::mhlo::FusionKindAttr::get( func_builder->getContext(), fusion_kind.value()))); @@ -2038,7 +2136,8 @@ void SetXlaShape(mlir::Operation* op, const Shape& shape) { .getStringAttr(shape.ToString(/*print_layout=*/true))); } -StatusOr HloFunctionImporter::ImportInstructionWithLayout( +absl::StatusOr +HloFunctionImporter::ImportInstructionWithLayout( const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) { @@ -2062,8 +2161,8 @@ StatusOr HloFunctionImporter::ImportInstructionWithLayout( return op; } -StatusOr> HloFunctionImporter::GetOperands( - const HloInstruction* instruction) { +absl::StatusOr> +HloFunctionImporter::GetOperands(const HloInstruction* instruction) { llvm::SmallVector operands; for (const auto& operand : instruction->operands()) { auto input_it = instruction_value_map_.find(operand); @@ -2084,10 +2183,10 @@ Status HloFunctionImporter::GetMlirTypes( instruction->shape(), *builder_)); types->push_back(ret_type); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } -StatusOr HloFunctionImporter::GetMlirValue( +absl::StatusOr HloFunctionImporter::GetMlirValue( const HloInstruction* instruction) { auto lookup = instruction_value_map_.find(instruction); if (lookup != instruction_value_map_.end()) { @@ -2211,13 +2310,13 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups( mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( std::optional channel_id) { - xla::ChannelHandle channel_handle; + ChannelHandle channel_handle; if (channel_id) channel_handle.set_handle(*channel_id); return ConvertChannelHandle(channel_handle); } mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( - const xla::ChannelHandle& channel) { + const ChannelHandle& channel) { return builder_->getNamedAttr( "channel_handle", mlir::mhlo::ChannelHandleAttr::get( context_, channel.handle(), channel.type())); @@ -2236,10 +2335,10 @@ void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op, } Status HloFunctionImporter::ConvertShapeToMlirLayout( - const xla::Shape& shape, + const Shape& shape, llvm::SmallVectorImpl& flattened_attr) { if (shape.IsToken()) { - return ::tsl::OkStatus(); + return absl::OkStatus(); } if (shape.IsTuple()) { std::vector tuple_layouts; @@ -2247,29 +2346,29 @@ Status HloFunctionImporter::ConvertShapeToMlirLayout( TF_RETURN_IF_ERROR( ConvertShapeToMlirLayout(shape.tuple_shapes(i), flattened_attr)); } - return ::tsl::OkStatus(); + return absl::OkStatus(); } if (shape.IsArray()) { - const xla::Layout l = shape.layout(); + const Layout l = shape.layout(); std::vector minor_to_major; for (int64_t i : l.minor_to_major()) { minor_to_major.push_back(builder_->getI64IntegerAttr(i)); } llvm::ArrayRef array_ref(minor_to_major); flattened_attr.push_back(builder_->getArrayAttr(array_ref)); - return ::tsl::OkStatus(); + return absl::OkStatus(); } return Internal("Couldn't convert layout."); } -mlir::Attribute ConvertSharding(const xla::HloSharding& sharding, +mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); } -mlir::Attribute ConvertSharding(const xla::OpSharding& sharding, +mlir::Attribute ConvertSharding(const OpSharding& sharding, mlir::Builder* builder) { - auto hlo_sharding = xla::HloSharding::FromProto(sharding); + auto hlo_sharding = HloSharding::FromProto(sharding); if (!hlo_sharding.ok()) return {}; return ConvertSharding(hlo_sharding.value(), builder); } diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/xla/translate/hlo_to_mhlo/hlo_function_importer.h index 66794a4bf8048..760730271b5b0 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -54,29 +56,29 @@ class HloFunctionImporter { // Imports the given computation as a function in the given symbol table and // returns the FuncOp. This also imports any computations referred by // instructions in this computation. - static StatusOr ImportAsFunc( - const xla::HloComputation& computation, mlir::SymbolTable& symbol_table, - std::unordered_map* + static absl::StatusOr ImportAsFunc( + const HloComputation& computation, mlir::SymbolTable& symbol_table, + std::unordered_map* function_map, mlir::Builder* builder, bool is_main); // Imports the given hlo computation to the specified region. If // 'flatten_region_arg_tuple' is true, then flatten the tuple-typed region // argument(s) and return value(s). - static Status ImportAsRegion(const xla::HloComputation& computation, + static Status ImportAsRegion(const HloComputation& computation, mlir::SymbolTable& symbol_table, mlir::Region* region, mlir::Builder* builder, bool flatten_region_arg_tuple = false); // Imports the given computation to the given place specified by `builder`. // `arguments` contains values for all parameters. - static StatusOr ImportInstructions( - const xla::HloComputation& computation, + static absl::StatusOr ImportInstructions( + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder); - static StatusOr ImportInstruction( - const xla::HloInstruction* instr, + static absl::StatusOr ImportInstruction( + const HloInstruction* instr, const llvm::SmallVectorImpl& operands, mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); @@ -137,7 +139,7 @@ class HloFunctionImporter { private: HloFunctionImporter(mlir::SymbolTable& symbol_table, - std::unordered_map* function_map, mlir::Builder* builder) : context_(symbol_table.getOp()->getContext()), @@ -152,8 +154,8 @@ class HloFunctionImporter { // Imports the given computation as a new function, if it hasn't been already // imported. - StatusOr ImportAsFunc( - const xla::HloComputation& computation, bool is_main); + absl::StatusOr ImportAsFunc( + const HloComputation& computation, bool is_main); // Imports the given computation in the specified region. Status ImportAsRegion(const HloComputation& computation, mlir::Region* region, @@ -163,39 +165,39 @@ class HloFunctionImporter { // Assumes that the block already has correct arguments populated. Status ImportInstructions(const HloComputation& computation, mlir::Block* block, bool flatten_region_arg_tuple); - StatusOr ImportInstructionsImpl( - const xla::HloComputation& computation, + absl::StatusOr ImportInstructionsImpl( + const HloComputation& computation, const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder); // Imports an instruction. - StatusOr ImportInstructionWithLayout( - const xla::HloInstruction* instruction, + absl::StatusOr ImportInstructionWithLayout( + const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - StatusOr ImportInstructionImpl( + absl::StatusOr ImportInstructionImpl( const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); // Gets the MLIR operand values from an HLO Instruction. - StatusOr> GetOperands( - const xla::HloInstruction* instruction); + absl::StatusOr> GetOperands( + const HloInstruction* instruction); // Converts xla Tensor type to the corresponding MLIR type. - StatusOr ConvertTensorType(const xla::Shape& shape); + absl::StatusOr ConvertTensorType(const Shape& shape); // Converts an XLA shape/layout to the corresponding MLIR layout, in // flattened_attr, while flattening the tuple layout. Status ConvertShapeToMlirLayout( - const xla::Shape& shape, + const Shape& shape, llvm::SmallVectorImpl& flattened_attr); // Returns the output type of an HloInstruction. - StatusOr GetReturnType(const xla::HloInstruction* instruction); + absl::StatusOr GetReturnType(const HloInstruction* instruction); // Takes a list of HloInstructions and generates the list of types used for // input, bypassing tuples to subsets. @@ -203,7 +205,7 @@ class HloFunctionImporter { llvm::SmallVectorImpl* types); // Returns the Mlir Value for the corresponding HloInstruction. - StatusOr GetMlirValue(const xla::HloInstruction* instruction); + absl::StatusOr GetMlirValue(const HloInstruction* instruction); // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. mlir::NamedAttribute ConvertComparisonDirection( @@ -213,8 +215,7 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertCustomCallSchedule( - xla::CustomCallSchedule schedule); + mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); // Converts the dimensions of an HLO instruction into an MLIR attribute. mlir::DenseIntElementsAttr ConvertDimensions( @@ -237,7 +238,7 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertUseGlobalDeviceIds(); // Converts channel handle to attribute - mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel); + mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel); // ============ // Imports an old-style async start op. E.g. an HLO all-gather-start @@ -255,14 +256,14 @@ class HloFunctionImporter { // new-style async API. // ============ template - StatusOr ImportOldStyleAsyncStart( + absl::StatusOr ImportOldStyleAsyncStart( llvm::SmallVectorImpl& attributes, const llvm::SmallVectorImpl& operands, mlir::Location loc, mlir::Type result_type, mlir::OpBuilder* func_builder, std::string func_name, std::function mutate_op); // Imports an old-style async done op - StatusOr ImportOldStyleAsyncDone( + absl::StatusOr ImportOldStyleAsyncDone( llvm::SmallVectorImpl& attributes, const llvm::SmallVectorImpl& operands, mlir::Location loc, mlir::Type result_type, mlir::OpBuilder* func_builder); @@ -275,25 +276,23 @@ class HloFunctionImporter { mlir::Builder* builder_; // Mapping from HloComputation to the created MLIR function. - std::unordered_map* - function_map_; + std::unordered_map* function_map_; // Mapping from HloInstructions to the associative MLIR values. - std::unordered_map - instruction_value_map_; + std::unordered_map instruction_value_map_; }; // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertSharding(const xla::HloSharding& sharding, +mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder); // Returns a StringAttr that carries a prettyprinted representation of the // given HLO proto sharding. // Will fail and return an empty attribute if the proto sharding cannot be // converted to the C++ sharding. -mlir::Attribute ConvertSharding(const xla::OpSharding& sharding, +mlir::Attribute ConvertSharding(const OpSharding& sharding, mlir::Builder* builder); } // namespace xla diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 1494efd9ebc53..c1645030e9408 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,21 +15,29 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" -#include +#include #include -#include - +#include + +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/permutation_util.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/status.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -41,6 +49,8 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, module.getContext()->loadDialect(); module.getContext()->loadDialect(); module.getContext()->loadDialect(); + module.getContext()->loadDialect(); + module.getContext()->loadDialect(); } namespace { @@ -48,7 +58,7 @@ namespace { constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; mlir::ArrayAttr ConvertCrossProgramPrefetches( - const absl::Span prefetches, + const absl::Span prefetches, mlir::Builder* builder) { llvm::SmallVector shapes; for (auto [parameter, index, alt_memory_offset] : prefetches) { @@ -65,7 +75,7 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches( } } // namespace -Status HloModuleImporter::Import(const xla::HloModule& hlo_module) { +Status HloModuleImporter::Import(const HloModule& hlo_module) { auto module = llvm::cast(symbol_table_.getOp()); module.setName(hlo_module.name()); module->setAttr("mhlo.cross_program_prefetches", @@ -113,19 +123,19 @@ Status HloModuleImporter::Import(const xla::HloModule& hlo_module) { return OkStatus(); } -Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) { - xla::DebugOptions debug_options; +Status HloModuleImporter::Import(const HloModuleProto& module_proto) { + DebugOptions debug_options; TF_ASSIGN_OR_RETURN( auto module_config, - xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); - TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto( - module_proto, module_config)); + HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(module_proto, module_config)); return Import(*module); } -void HloModuleImporter::ImportFrontendAttributes( - const xla::HloModule& hlo_module, mlir::ModuleOp module) { +void HloModuleImporter::ImportFrontendAttributes(const HloModule& hlo_module, + mlir::ModuleOp module) { if (!hlo_module.frontend_attributes().map().empty()) { llvm::SmallVector frontend_attributes; for (const auto& [k, v] : hlo_module.frontend_attributes().map()) { diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.h b/xla/translate/hlo_to_mhlo/hlo_module_importer.h index ac42e28f3f212..10dcbb9e2646a 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index 3354495a52d76..37064723890ff 100644 --- a/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h index 21b3c1589c0bd..de5cda6d1dde2 100644 --- a/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ b/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.cc b/xla/translate/hlo_to_mhlo/hlo_utils.cc index 4c2f38a15a7a1..30c680b4ffc9f 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,13 +23,11 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/AffineMap.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "xla/mlir/utils/type_util.h" #include "xla/primitive_util.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -65,8 +63,8 @@ ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral( } } -StatusOr GetPermutationIfAvailable(const Shape& shape, - mlir::Builder builder) { +absl::StatusOr GetPermutationIfAvailable(const Shape& shape, + mlir::Builder builder) { // N.B. IsMonotonicWithDim0Major ignores tiling, and I can't change it because // some XLA code relies on it treating tiled layouts as equivalent to untiled // layouts, so the check to rule out tiling has to come /before/ the @@ -95,10 +93,10 @@ StatusOr GetPermutationIfAvailable(const Shape& shape, } } // namespace -StatusOr ConvertTensorShapeToMemRefType( +absl::StatusOr ConvertTensorShapeToMemRefType( const Shape& shape, mlir::Builder builder) { auto element_type_or = - ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder); + ConvertPrimitiveTypeToMlirType(shape.element_type(), builder); if (!element_type_or.ok()) return element_type_or.status(); using mlir::MemRefType; @@ -110,7 +108,7 @@ StatusOr ConvertTensorShapeToMemRefType( permutation_or.value()); } -StatusOr CreateDenseElementsAttrFromLiteral( +absl::StatusOr CreateDenseElementsAttrFromLiteral( const LiteralBase& literal, Builder builder) { TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType( @@ -119,7 +117,8 @@ StatusOr CreateDenseElementsAttrFromLiteral( // TODO(hinsu): Support remaining XLA primitive types. auto element_type = literal.shape().element_type(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) + -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { return CreateDenseAttrFromLiteral< primitive_util::NativeTypeOf>(type, @@ -131,20 +130,6 @@ StatusOr CreateDenseElementsAttrFromLiteral( element_type); } -StatusOr GetElementTypeBytes(mlir::Type type) { - if (type.isInteger(1)) { - return 1; - } - if (auto complex_type = type.dyn_cast()) { - TF_ASSIGN_OR_RETURN(int bytes, - GetElementTypeBytes(complex_type.getElementType())); - return bytes * 2; - } - int width = type.getIntOrFloatBitWidth(); - TF_RET_CHECK(width % 8 == 0); - return width / 8; -} - mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( const llvm::ArrayRef vector, mlir::Builder builder, llvm::ArrayRef shape) { @@ -154,275 +139,4 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( vector); } -StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, - mlir::Builder builder) { - switch (element_type) { - case PrimitiveType::PRED: - return builder.getI1Type(); - case PrimitiveType::F8E5M2: - return builder.getFloat8E5M2Type(); - case PrimitiveType::F8E4M3FN: - return builder.getFloat8E4M3FNType(); - case PrimitiveType::F8E4M3B11FNUZ: - return builder.getFloat8E4M3B11FNUZType(); - case PrimitiveType::F8E5M2FNUZ: - return builder.getFloat8E5M2FNUZType(); - case PrimitiveType::F8E4M3FNUZ: - return builder.getFloat8E4M3FNUZType(); - case PrimitiveType::F16: - return builder.getF16Type(); - case PrimitiveType::BF16: - return builder.getBF16Type(); - case PrimitiveType::F32: - return builder.getF32Type(); - case PrimitiveType::F64: - return builder.getF64Type(); - // TODO(b/130356985): Support unsigned primitive types. - default: - if (primitive_util::IsIntegralType(element_type)) { - return mlir::IntegerType::get( - builder.getContext(), - /*width=*/primitive_util::BitWidth(element_type), - /*signed=*/ - primitive_util::IsUnsignedIntegralType(element_type) - ? mlir::IntegerType::Unsigned - : mlir::IntegerType::Signless); - } - if (primitive_util::IsComplexType(element_type)) { - TF_ASSIGN_OR_RETURN( - mlir::Type component_type, - ConvertPrimitiveTypeToMLIRType( - primitive_util::ComplexComponentType(element_type), builder)); - return mlir::ComplexType::get(component_type); - } - return Internal("Unsupported type: %s", PrimitiveType_Name(element_type)); - } -} - -mlir::mhlo::GatherDimensionNumbersAttr CreateGatherDimensionNumbers( - const GatherDimensionNumbers& input, mlir::Builder builder) { - auto get_i64_array = [](absl::Span container) { - return llvm::ArrayRef{container.data(), container.size()}; - }; - return mlir::mhlo::GatherDimensionNumbersAttr::get( - builder.getContext(), get_i64_array(input.offset_dims()), - get_i64_array(input.collapsed_slice_dims()), - get_i64_array(input.start_index_map()), input.index_vector_dim()); -} - -StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { - using mlir::isa; - - if (isa(op)) { - return xla::HloOpcode::kConstant; - } else if (isa(op)) { - return xla::HloOpcode::kIota; - } else if (isa(op)) { - return xla::HloOpcode::kConvert; - } else if (isa(op)) { - return xla::HloOpcode::kAdd; - } else if (isa(op)) { - return xla::HloOpcode::kAtan2; - } else if (isa(op)) { - return xla::HloOpcode::kDivide; - } else if (isa(op)) { - return xla::HloOpcode::kMaximum; - } else if (isa(op)) { - return xla::HloOpcode::kMinimum; - } else if (isa(op)) { - return xla::HloOpcode::kMultiply; - } else if (isa(op)) { - return xla::HloOpcode::kPower; - } else if (isa(op)) { - return xla::HloOpcode::kRemainder; - } else if (isa(op)) { - return xla::HloOpcode::kShiftLeft; - } else if (isa(op)) { - return xla::HloOpcode::kShiftRightArithmetic; - } else if (isa(op)) { - return xla::HloOpcode::kShiftRightLogical; - } else if (isa(op)) { - return xla::HloOpcode::kSubtract; - } else if (isa(op)) { - return xla::HloOpcode::kXor; - } else if (isa(op)) { - return xla::HloOpcode::kInfeed; - } else if (isa(op)) { - return xla::HloOpcode::kOutfeed; - } else if (isa(op)) { - return xla::HloOpcode::kSend; - } else if (isa(op)) { - return xla::HloOpcode::kRecv; - } else if (isa(op)) { - return xla::HloOpcode::kReplicaId; - } else if (isa(op)) { - return xla::HloOpcode::kAfterAll; - } else if (isa(op)) { - return xla::HloOpcode::kAllReduce; - } else if (isa(op)) { - return xla::HloOpcode::kAllToAll; - } else if (isa(op)) { - return xla::HloOpcode::kTuple; - } else if (isa( - op)) { - return xla::HloOpcode::kBatchNormGrad; - } else if (isa(op)) { - return xla::HloOpcode::kBatchNormInference; - } else if (isa(op)) { - return xla::HloOpcode::kBatchNormTraining; - } else if (isa( - op)) { - return xla::HloOpcode::kBitcastConvert; - } else if (isa(op)) { - return xla::HloOpcode::kBroadcast; - } else if (isa(op)) { - return xla::HloOpcode::kCholesky; - } else if (isa(op)) { - return xla::HloOpcode::kClamp; - } else if (isa(op)) { - return xla::HloOpcode::kConcatenate; - } else if (isa(op)) { - return xla::HloOpcode::kConvolution; - } else if (isa(op)) { - return xla::HloOpcode::kSort; - } else if (isa(op)) { - return xla::HloOpcode::kTopK; - } else if (isa(op)) { - return xla::HloOpcode::kRngBitGenerator; - } else if (isa(op)) { - return xla::HloOpcode::kRngGetAndUpdateState; - } else if (isa(op)) { - return xla::HloOpcode::kFusion; - } else if (isa(op)) { - return xla::HloOpcode::kBitcast; - } else if (isa(op)) { - return xla::HloOpcode::kAbs; - } else if (isa(op)) { - return xla::HloOpcode::kCbrt; - } else if (isa(op)) { - return xla::HloOpcode::kCeil; - } else if (isa(op)) { - return xla::HloOpcode::kClz; - } else if (isa(op)) { - return xla::HloOpcode::kCos; - } else if (isa(op)) { - return xla::HloOpcode::kExp; - } else if (isa(op)) { - return xla::HloOpcode::kExpm1; - } else if (isa(op)) { - return xla::HloOpcode::kFloor; - } else if (isa(op)) { - return xla::HloOpcode::kImag; - } else if (isa(op)) { - return xla::HloOpcode::kIsFinite; - } else if (isa(op)) { - return xla::HloOpcode::kLog; - } else if (isa(op)) { - return xla::HloOpcode::kLog1p; - } else if (isa(op)) { - return xla::HloOpcode::kLogistic; - } else if (isa(op)) { - return xla::HloOpcode::kNot; - } else if (isa(op)) { - return xla::HloOpcode::kNegate; - } else if (isa( - op)) { - return xla::HloOpcode::kPopulationCount; - } else if (isa(op)) { - return xla::HloOpcode::kReal; - } else if (isa(op)) { - return xla::HloOpcode::kRoundNearestAfz; - } else if (isa(op)) { - return xla::HloOpcode::kRoundNearestEven; - } else if (isa(op)) { - return xla::HloOpcode::kRsqrt; - } else if (isa(op)) { - return xla::HloOpcode::kSign; - } else if (isa(op)) { - return xla::HloOpcode::kSin; - } else if (isa(op)) { - return xla::HloOpcode::kSqrt; - } else if (isa(op)) { - return xla::HloOpcode::kTan; - } else if (isa(op)) { - return xla::HloOpcode::kTanh; - } else if (isa(op)) { - return xla::HloOpcode::kComplex; - } else if (isa(op)) { - return xla::HloOpcode::kAnd; - } else if (isa(op)) { - return xla::HloOpcode::kOr; - } else if (isa(op)) { - return xla::HloOpcode::kWhile; - } else if (isa(op)) { - return xla::HloOpcode::kReduce; - } else if (isa(op)) { - return xla::HloOpcode::kGetTupleElement; - } else if (isa(op)) { - return xla::HloOpcode::kCompare; - } else if (isa(op)) { - return xla::HloOpcode::kSlice; - } else if (isa(op)) { - return xla::HloOpcode::kDynamicSlice; - } else if (isa(op)) { - return xla::HloOpcode::kDynamicUpdateSlice; - } else if (isa(op)) { - return xla::HloOpcode::kCollectivePermute; - } else if (isa(op)) { - return xla::HloOpcode::kCopy; - } else if (isa(op)) { - return xla::HloOpcode::kCustomCall; - } else if (isa(op)) { - return xla::HloOpcode::kDot; - } else if (isa(op)) { - return xla::HloOpcode::kFft; - } else if (isa(op)) { - return xla::HloOpcode::kGather; - } else if (isa(op)) { - return xla::HloOpcode::kGetDimensionSize; - } else if (isa(op)) { - return xla::HloOpcode::kMap; - } else if (isa(op)) { - return xla::HloOpcode::kReshape; - } else if (isa(op)) { - return xla::HloOpcode::kDynamicReshape; - } else if (isa(op)) { - return xla::HloOpcode::kScatter; - } else if (isa(op)) { - return xla::HloOpcode::kSelect; - } else if (isa(op)) { - return xla::HloOpcode::kSelectAndScatter; - } else if (isa(op)) { - return xla::HloOpcode::kSetDimensionSize; - } else if (isa(op)) { - return xla::HloOpcode::kReverse; - } else if (isa(op)) { - return xla::HloOpcode::kPad; - } else if (isa(op)) { - return xla::HloOpcode::kTranspose; - } else if (isa( - op)) { - return xla::HloOpcode::kTriangularSolve; - } else if (isa(op)) { - return xla::HloOpcode::kReduceWindow; - } else if (isa( - op)) { - return xla::HloOpcode::kReducePrecision; - } else if (isa(op)) { - return xla::HloOpcode::kDot; - } else if (isa( - op)) { - return xla::HloOpcode::kBroadcast; - } else { - return Unimplemented("Unimplemented MHLO -> HloOpcode: %s", - llvm_ir::DumpToString(op)); - } -} } // namespace xla diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.h b/xla/translate/hlo_to_mhlo/hlo_utils.h index b5f44584f1abc..b570125c654da 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,35 +25,28 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/util.h" namespace xla { -StatusOr CreateDenseElementsAttrFromLiteral( +absl::StatusOr CreateDenseElementsAttrFromLiteral( const LiteralBase& literal, mlir::Builder builder); -StatusOr GetElementTypeBytes(mlir::Type type); - // Creates an DenseIntElementsAttr using the elements of the vector and the // optional shape. mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( const llvm::ArrayRef vector, mlir::Builder builder, llvm::ArrayRef shape = {}); -StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, - mlir::Builder builder); - -mlir::mhlo::GatherDimensionNumbersAttr CreateGatherDimensionNumbers( - const GatherDimensionNumbers& input, mlir::Builder builder); - // Converts the given XLA shape for tensors to the template MLIR type. template static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, mlir::Builder builder) { auto element_type_or = - ConvertPrimitiveTypeToMLIRType(xla_ty.element_type(), builder); + ConvertPrimitiveTypeToMlirType(xla_ty.element_type(), builder); if (!element_type_or.ok()) return element_type_or.status(); bool is_bounded_dynamic = false; @@ -92,12 +85,12 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, return Unimplemented( "MHLO doesn't support bounded dynamic shapes for sparse tensors"); llvm::SmallVector lts; - for (size_t i = 0, e = layout.dim_level_types().size(); i < e; ++i) { - auto dlt = layout.dim_level_types()[i]; + for (size_t i = 0, e = layout.dim_level_types_size(); i < e; ++i) { + auto dlt = layout.dim_level_type(i); bool ordered = - i < layout.dim_ordered().size() ? layout.dim_ordered()[i] : true; + i < layout.dim_ordered_size() ? layout.dim_ordered(i) : true; bool unique = - i < layout.dim_unique().size() ? layout.dim_unique()[i] : true; + i < layout.dim_unique_size() ? layout.dim_unique(i) : true; switch (dlt) { case DimLevelType::DIM_DENSE: lts.push_back(*mlir::sparse_tensor::buildLevelType( @@ -133,11 +126,11 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, return TypeT::get(shape, element_type_or.value(), encoding); } -StatusOr ConvertTensorShapeToMemRefType( +absl::StatusOr ConvertTensorShapeToMemRefType( const Shape& shape, mlir::Builder builder); template <> -inline StatusOr ConvertTensorShapeToType( +inline absl::StatusOr ConvertTensorShapeToType( const Shape& shape, mlir::Builder builder) { if (shape.is_dynamic()) { return FailedPrecondition( // NOLINT @@ -148,8 +141,8 @@ inline StatusOr ConvertTensorShapeToType( // Converts the given XLA shape to the template MLIR type. template -static StatusOr ConvertShapeToType(const Shape& shape, - mlir::Builder builder) { +static absl::StatusOr ConvertShapeToType(const Shape& shape, + mlir::Builder builder) { if (shape.IsTuple()) { llvm::SmallVector contents; contents.reserve(shape.tuple_shapes_size()); @@ -166,8 +159,6 @@ static StatusOr ConvertShapeToType(const Shape& shape, return ConvertTensorShapeToType(shape, builder); } -::xla::StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op); - } // namespace xla #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/xla/translate/hlo_to_mhlo/hlo_utils_test.cc index 295d656275d88..b0d8b48ff7649 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/xla/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/location_importer.cc b/xla/translate/hlo_to_mhlo/location_importer.cc index b2ee13ece9405..431fae74cc6b9 100644 --- a/xla/translate/hlo_to_mhlo/location_importer.cc +++ b/xla/translate/hlo_to_mhlo/location_importer.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/location_importer.h b/xla/translate/hlo_to_mhlo/location_importer.h index 1cfc4ab39bfde..164754e8281e0 100644 --- a/xla/translate/hlo_to_mhlo/location_importer.h +++ b/xla/translate/hlo_to_mhlo/location_importer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/stack_location_utils.cc b/xla/translate/hlo_to_mhlo/stack_location_utils.cc index afbf3aa2c64ad..d09b8ff1d56f1 100644 --- a/xla/translate/hlo_to_mhlo/stack_location_utils.cc +++ b/xla/translate/hlo_to_mhlo/stack_location_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/stack_location_utils.h b/xla/translate/hlo_to_mhlo/stack_location_utils.h index 298f651ad8653..96f6ac86cdc46 100644 --- a/xla/translate/hlo_to_mhlo/stack_location_utils.h +++ b/xla/translate/hlo_to_mhlo/stack_location_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/tests/BUILD b/xla/translate/hlo_to_mhlo/tests/BUILD index fc792d50600ba..e2961242bb96b 100644 --- a/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/xla/translate/hlo_to_mhlo/tests/BUILD @@ -1,26 +1,42 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) -glob_lit_tests( +lit_test_suite( name = "all_tests", - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = [ - "hlo", - "hlotxt", - ], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ + srcs = enforce_glob( + [ + "bool_compare.hlotxt", + "case_conditional.hlotxt", + "custom_call.hlotxt", + "dynamic_param.hlo", + "entry_computation_layout.hlotxt", + "frontend_attributes.hlotxt", + "fully_connected_reference_model.hlotxt", + "fusion.hlotxt", + "if_conditional.hlotxt", + "import.hlotxt", + "import_async.hlotxt", + "layouts_and_names.hlotxt", + "location.hlotxt", + "module_attributes.hlo", + "send_recv.hlotxt", + "simple.hlo", + "spmd_module_sharding.hlo", + "stacktrace_to_location.hlo", + "types.hlotxt", + "while.hlotxt", + ], + include = [ + "*.hlotxt", + "*.hlo", + ], + ), + cfg = "//xla:lit.cfg.py", + tools = [ "//xla/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", diff --git a/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt b/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt new file mode 100644 index 0000000000000..ad40f5dac6669 --- /dev/null +++ b/xla/translate/hlo_to_mhlo/tests/custom_call.hlotxt @@ -0,0 +1,51 @@ +// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s + +// CHECK: module @foobar +HloModule foobar + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { + ROOT %Arg_0.1 = f32[] parameter(0) +} + +// CHECK-LABEL: func private @test_custom_call_dynamic_broadcast_in_dim +// CHECK-SAME: [[ARG_0:%.*]]: tensor<1x?xf32>, [[ARG_1:%.*]]: tensor<3xi64>) -> tensor<2x?x2xf32> +%test_custom_call_dynamic_broadcast_in_dim (arg1: f32[1,?], arg2: s64[3]) -> f32[2,?,2] { + %arg1 = f32[1,?] parameter(0) + %arg2 = s64[3] parameter(1) + // CHECK: "mhlo.dynamic_broadcast_in_dim"([[ARG_0]], [[ARG_1]]) <{ + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> + // CHECK-SAME: (tensor<1x?xf32>, tensor<3xi64>) -> tensor<2x?x2xf32> + ROOT %custom-call = f32[2,?,2] custom-call(f32[1,?] %arg1, s64[3] %arg2), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]} +} + +// CHECK-LABEL: func private @test_custom_call_dynamic_reshape +// CHECK-SAME: [[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor<2xi64>) -> tensor +%test_custom_call_dynamic_reshape (arg1: f32[?], arg2: s64[2]) -> f32[?,?] { + %arg1 = f32[?] parameter(0) + %arg2 = s64[2] parameter(1) + // CHECK: mhlo.dynamic_reshape [[ARG_0]], [[ARG_1]] : (tensor, tensor<2xi64>) -> tensor + ROOT %custom-call = f32[?,?] custom-call(f32[?] %arg1, s64[2] %arg2), custom_call_target="mhlo.dynamic_reshape" +} + +// CHECK-LABEL: func private @test_custom_call_real_dynamic_slice +// CHECK-SAME: ([[ARG_0:%.*]]: tensor, [[ARG_1:%.*]]: tensor<4xi32>, [[ARG_2:%.*]]: tensor<4xi32>, [[ARG_3:%.*]]: tensor<4xi32>) -> tensor +%test_custom_call_real_dynamic_slice(arg1: f32[?,3,224,224], arg2: s32[4], arg3: s32[4], arg4: s32[4]) -> f32[?,3,224,224] { + %Arg_0.1 = f32[?,3,224,224] parameter(0) + %Arg_1.2 = s32[4] parameter(1) + %Arg_2.3 = s32[4] parameter(2) + %Arg_3.4 = s32[4] parameter(3) + + // CHECK: mhlo.real_dynamic_slice [[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]] : (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor + ROOT %custom-call.12 = f32[?,3,224,224] custom-call(f32[?,3,224,224] %Arg_0.1, s32[4] %Arg_1.2, s32[4] %Arg_2.3, s32[4] %Arg_3.4), custom_call_target="mhlo.real_dynamic_slice" +} + +// Test HLO->MHLO converter for quantize/dequantize +%test_custom_call_for_quant_dequant (p0:f32[1,3]) -> f32[1,3] { + %p0 = f32[1,3] parameter(0) + %custom-call.1 = s8[1,3] custom-call(f32[1,3] %p0), custom_call_target="mhlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=i8,expressed_type=f32,storage_min=-128,storage_max=127} + ROOT %custom-call.2 = f32[1,3] custom-call(s8[1,3] %custom-call.1), custom_call_target="mhlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=i8,expressed_type=f32,storage_min=-128,storage_max=127} +} +// CHECK-LABEL: func private @test_custom_call_for_quant_dequant +// CHECK: mhlo.uniform_quantize {{.*}} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> +// CHECK: mhlo.uniform_dequantize {{.*}} : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> diff --git a/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlotxt b/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlotxt index ad705919886e8..d99796976505e 100644 --- a/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlotxt @@ -14,7 +14,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %[[VAL_2:.*]] = mhlo.reshape %[[VAL_0]] : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} // CHECK-NEXT: %[[VAL_4:.*]] = mhlo.reshape %[[VAL_3]] : (tensor<300x1xf32>) -> tensor<300x1x1xf32> @@ -23,13 +23,13 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %[[VAL_5:.*]] = mhlo.reshape %[[VAL_4]] : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} // CHECK-NEXT: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} // CHECK-NEXT: %[[VAL_9:.*]] = mhlo.multiply %[[VAL_6]], %[[VAL_8]] : tensor<300x1x5xf32> @@ -38,7 +38,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %[[VAL_10:.*]] = mhlo.constant dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} // CHECK-NEXT: %[[VAL_12:.*]] = mhlo.compare GT, %[[VAL_9]], %[[VAL_11]] : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> @@ -47,13 +47,13 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} // CHECK-NEXT: %[[VAL_15:.*]] = mhlo.constant dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} // CHECK-NEXT: %[[VAL_17:.*]] = mhlo.copy %[[VAL_1]] : tensor<1x300x3x1xf32> @@ -65,7 +65,7 @@ ENTRY %tfcompile.48 { // CHECK-NEXT: %[[VAL_19:.*]] = mhlo.reshape %[[VAL_18]] : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) <{permutation = dense<[1, 0, 2]> : tensor<3xi64>}> : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} // CHECK-NEXT: %[[VAL_21:.*]] = mhlo.reshape %[[VAL_20]] : (tensor<300x1x3xf32>) -> tensor<300x3xf32> @@ -75,13 +75,13 @@ ENTRY %tfcompile.48 { %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. - // CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) {precision_config = [#mhlo, #mhlo]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) <{precision_config = [#mhlo, #mhlo]}> : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} // CHECK-NEXT: %[[VAL_24:.*]] = mhlo.constant dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} // CHECK-NEXT: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_25]] : tensor<300x5xf32> diff --git a/xla/translate/hlo_to_mhlo/tests/fusion.hlotxt b/xla/translate/hlo_to_mhlo/tests/fusion.hlotxt index 6ba31aafffe5c..310e95513523c 100644 --- a/xla/translate/hlo_to_mhlo/tests/fusion.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/fusion.hlotxt @@ -3,18 +3,22 @@ HloModule main.17 // CHECK: func @main(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor { -// CHECK: %[[F0:.+]] = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) ({ +// CHECK: %[[F0:.+]] = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) +// CHECK: <{fusion_kind = #mhlo}> ({ // CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): -// CHECK: }) {fusion_kind = #mhlo, output_operand_aliasing = []} : (tensor, tensor) -> tensor -// CHECK: %[[F1:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]]) ({ +// CHECK: }) {output_operand_aliasing = []} : (tensor, tensor) -> tensor +// CHECK: %[[F1:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]]) +// CHECK: <{fusion_kind = #mhlo}> ({ // CHECK: ^bb0(%[[ARG2:.*]]: tensor): -// CHECK: }) {fusion_kind = #mhlo, output_operand_aliasing = []} : (tensor) -> (tensor, tensor) -// CHECK: %[[F2:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) ({ +// CHECK: }) {output_operand_aliasing = []} : (tensor) -> (tensor, tensor) +// CHECK: %[[F2:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) +// CHECK: <{fusion_kind = #mhlo}> ({ // CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): -// CHECK: }) {fusion_kind = #mhlo, output_operand_aliasing = []} : (tensor, tensor) -> (tensor, tensor) -// CHECK: %[[F3:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) ({ +// CHECK: }) {output_operand_aliasing = []} : (tensor, tensor) -> (tensor, tensor) +// CHECK: %[[F3:.+]]:2 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) +// CHECK: <{fusion_kind = #mhlo}> ({ // CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): -// CHECK: }) {fusion_kind = #mhlo, output_operand_aliasing = []} : (tensor, tensor) -> (tensor, tensor) +// CHECK: }) {output_operand_aliasing = []} : (tensor, tensor) -> (tensor, tensor) // CHECK: } %region_0 (Arg_0.4: f32[], Arg_1.5: f32[]) -> f32[] { diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/xla/translate/hlo_to_mhlo/tests/import.hlotxt index e18fc9926588f..454c93111a198 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -24,7 +24,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<4xf32>, tensor<4xf32>) -> tensor ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } @@ -72,19 +72,31 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT ag = f32[128,128] all-gather(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true } +// CHECK-LABEL: func private @test_all_gather_variadic +%test_all_gather_variadic { + input.0 = f32[128,8] parameter(0) + input.1 = f32[128,16] parameter(1) + // CHECK-NEXT: "mhlo.all_gather"(%arg0, %arg1) + // CHECK-SAME: all_gather_dim = 1 : i64 + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + // CHECK-SAME: use_global_device_ids + ROOT ag = (f32[128,32], f32[128,64]) all-gather(f32[128,8] input.0, f32[128,16] input.1), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true +} + // Test all-to-all // CHECK-LABEL: func private @test_all_to_all // CHECK-SAME: ([[ARG:%.*]]: tensor<2x2xi32>) %test_all_to_all { %parameter = s32[2,2]{1,0} parameter(0) - // CHECK-NEXT: "mhlo.all_to_all"([[ARG]]) { + // CHECK-NEXT: "mhlo.all_to_all"([[ARG]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-SAME: concat_dimension = 1 : i64, // CHECK-SAME{LITERAL}: replica_groups = dense<[[1, 2], [3, 0]]> : tensor<2x2xi64>, // CHECK-SAME: split_count = 2 : i64, // CHECK-SAME: split_dimension = 1 : i64 - // CHECK-SAME: } : (tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK-SAME: }> : (tensor<2x2xi32>) -> tensor<2x2xi32> ROOT %all-to-all = s32[2,2]{1,0} all-to-all(s32[2,2]{1,0} %parameter), channel_id=1, replica_groups={{1,2}, {3,0}}, dimensions={1} } @@ -94,9 +106,9 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %test_tuple_all_to_all { %p0 = f32[128,4]{0,1} parameter(0) %p1 = f32[128,4]{1,0} parameter(1) - // CHECK-NEXT: "mhlo.all_to_all"([[ARG0]], [[ARG1]]) { + // CHECK-NEXT: "mhlo.all_to_all"([[ARG0]], [[ARG1]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME: } : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) + // CHECK-SAME: }> : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) ROOT %all-to-all = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(%p0, %p1), channel_id=1, replica_groups={{0,1}} } @@ -111,15 +123,15 @@ add { // CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>) %test_all_reduce { input = f32[8] parameter(0) - // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) { + // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-NOT: use_global_device_ids // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> // CHECK-NOT: use_global_device_ids + // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): + // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] + // CHECK: mhlo.return [[ADD]] : tensor + // CHECK: }) // CHECK-SAME: : ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add } @@ -128,14 +140,14 @@ add { // CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>) %test_all_reduce_global { input = f32[8] parameter(0) - // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) { + // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> // CHECK-SAME: use_global_device_ids + // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): + // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] + // CHECK: mhlo.return [[ADD]] : tensor + // CHECK: }) ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, use_global_device_ids=true, to_apply=add } @@ -180,10 +192,10 @@ add { %test_broadcast_in_dim { %Arg_0.1 = f32[1, 2] parameter(0) - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2xf32>) -> tensor<1x2x3xf32> %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<1x2xf32>) -> tensor<3x1x2xf32> ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} } @@ -234,7 +246,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> %test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { %a = f32[1,291,291] parameter(0) - // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) <{lower = true}> : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true } @@ -249,14 +261,25 @@ add { ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) } +// CHECK-LABEL: func private @test_collective_broadcast +// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> +%test_collective_broadcast (input: f32[128,32]) -> f32[128,32] { + %input = f32[128,32]{1,0} parameter(0) + // CHECK-NEXT: "mhlo.collective_broadcast"([[ARG]]) <{ + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME{LITERAL} replica_groups = {{0,1}{2,3}} + // CHECK-SAME: }> : (tensor<128x32xf32>) -> tensor<128x32xf32> + ROOT root = f32[128,32]{1,0} collective-broadcast(%input), replica_groups={{0,1},{2,3}}, channel_id=1 +} + // CHECK-LABEL: func private @test_collective_permute // CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> %test_collective_permute (input: f32[128,32]) -> f32[128,32] { %input = f32[128,32]{1,0} parameter(0) - // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) { + // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME: source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> - // CHECK-SAME: } : (tensor<128x32xf32>) -> tensor<128x32xf32> + // CHECK-SAME: }> : (tensor<128x32xf32>) -> tensor<128x32xf32> ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}, channel_id=1 } @@ -291,7 +314,7 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[4, 2] parameter(1) - // CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) <{dimension = 1 : i64}> : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1} } @@ -610,25 +633,25 @@ add { %Arg_0.1 = f32[1, 4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> dot.3 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> + // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> dot.4 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> + // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> %dot.5 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // CHECK-NEXT: %3 = mhlo.reshape %arg1 : (tensor<4x1xf32>) -> tensor<4xf32> - // CHECK-NEXT: %4 = "mhlo.dot"(%3, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<4xf32>, tensor<4x1xf32>) -> tensor<1xf32> + // CHECK-NEXT: %4 = "mhlo.dot"(%3, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<4xf32>, tensor<4x1xf32>) -> tensor<1xf32> reshape.0 = f32[4]{0} reshape(f32[4, 1] Arg_1.2) %dot.6 = f32[1] dot(reshape.0, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={default,default} - // CHECK-NEXT: %5 = "mhlo.dot"(%arg0, %3) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1xf32> + // CHECK-NEXT: %5 = "mhlo.dot"(%arg0, %3) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1xf32> %dot.7 = f32[1] dot(%Arg_0.1, reshape.0), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> + // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32> ROOT %dot.8 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -637,7 +660,7 @@ add { %Arg_0.1 = s4[1, 4] parameter(0) %Arg_1.2 = s4[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xi4>, tensor<4x1xi4>) -> tensor<1x1xi8> + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xi4>, tensor<4x1xi4>) -> tensor<1x1xi8> ROOT dot.3 = s8[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} } @@ -646,7 +669,7 @@ add { %Arg_0.1 = u4[1, 4] parameter(0) %Arg_1.2 = u4[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x4xui4>, tensor<4x1xui4>) -> tensor<1x1xui8> + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<1x4xui4>, tensor<4x1xui4>) -> tensor<1x1xui8> ROOT dot.3 = u8[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} } @@ -700,6 +723,18 @@ add { ROOT %dot.1 = s16[14,1,12,16,3] dot(%Arg_0.1, %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } +// CHECK-LABEL: @test_dot_sparse +// CHECK-SAME: [[LHS:%.*]]: tensor<10x16xf32>, [[RHS:%.*]]: tensor<32x20xf32>, [[META:%.*]]: tensor<10x2xui16> +%test_dot_sparse (Arg_0: f32[10,16], Arg_1: f32[32,20], Arg_2: u16[10,2]) -> f32[10,20] { + %lhs = f32[10,16] parameter(0) + %rhs = f32[32,20] parameter(1) + %meta = u16[10,2] parameter(2) + // CHECK: "mhlo.sparse_dot"([[LHS]], [[RHS]], [[META]]) + // CHECK-SAME: lhs_sparsity = #mhlo.sparsity + ROOT %dot = f32[10,20] dot(%lhs, %rhs, %meta), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +} + + // CHECK-LABEL: func private @test_dynamic_slice // CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor, [[START_IDX_2:%.*]]: tensor, [[START_IDX_3:%.*]]: tensor %test_dynamic_slice (operand: s32[2,2,258], start_indices: s32[3]) -> s32[1,1,32] { @@ -752,7 +787,7 @@ add { // CHECK-LABEL: func private @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> %test_fft { %arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} - // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo + // CHECK: "mhlo.fft"(%arg0) <{fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo ROOT %fft.2 = c64[3,5]{1,0} fft(%arg0.1), fft_type=RFFT, fft_length={9}, metadata={op_type="RFFT" op_name="rfft"} } @@ -791,7 +826,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { %Arg_0.1 = f32[4,2] parameter(0) - // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor<4x2xf32>) -> tensor + // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) <{dimension = 1 : i64}> : (tensor<4x2xf32>) -> tensor ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1} } @@ -841,13 +876,13 @@ add { // CHECK-LABEL: func private @test_iota_1() -> tensor<4xf32> %test_iota_1 () -> f32[4] { - // CHECK-NEXT: "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xf32> ROOT %iota.0 = f32[4] iota(), iota_dimension=0 } // CHECK-LABEL: func private @test_iota_2() -> tensor<4x5xf32> %test_iota_2 () -> f32[4, 5] { - // CHECK-NEXT: "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> + // CHECK-NEXT: "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<4x5xf32> ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1 } @@ -879,11 +914,11 @@ add { %test_map { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) -// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({ +// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) <{dimensions = dense<0> : tensor<1xi64>}> ({ // CHECK: ^bb0([[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor): // CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]] // CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation } @@ -900,11 +935,11 @@ add { %test_map_with_reducer_returning_tuple { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) -// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({ +// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) <{dimensions = dense<0> : tensor<1xi64>}> ({ // CHECK: ^bb0([[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor): // CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]] // CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation_returning_tuple } @@ -922,10 +957,10 @@ add { ROOT %map.6 = f32[4] map(f32[4] %Arg_0.1, s32[4] %Arg_1.2), dimensions={0}, to_apply=%map_computation_take_left } -// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({ +// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) <{dimensions = dense<0> : tensor<1xi64>}> ({ // CHECK: ^bb0([[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor): // CHECK: mhlo.return [[ARG_2]] : tensor -// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32> +// CHECK: }) : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32> // CHECK-LABEL: func private @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> %test_maximum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { @@ -1003,8 +1038,8 @@ add { %Arg_0.1 = s32[3] parameter(0) %Arg_1.2 = token[] parameter(1) // CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]]) - // CHECK-SAME: mhlo.sharding = "{maximal device=0}" // CHECK-SAME: outfeed_config = "foobar" + // CHECK-SAME: mhlo.sharding = "{maximal device=0}" ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar", sharding={maximal device=0} } @@ -1025,7 +1060,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) <{edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0 } @@ -1034,7 +1069,7 @@ add { %Arg_0.1 = f32[4, 4, 4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) <{edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>}> : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6 } @@ -1043,7 +1078,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<10xf32> + // CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) <{edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<10xf32> ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2 } @@ -1070,7 +1105,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution} + // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) <{rng_distribution = #mhlo.rng_distribution}> ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal } @@ -1080,7 +1115,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution} + // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) <{rng_distribution = #mhlo.rng_distribution}> ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform } @@ -1130,16 +1165,12 @@ add { // CHECK: mhlo.tuple %0#0, %0#1 {xla_shape = {{.*}}} : tuple, tensor> %reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1 - // CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]]) - // CHECK: mhlo.add{{.*}} : tensor + // CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<4x4xf32>, tensor) -> tensor %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3 // CHECK: [[VAL3:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG1]]) - // CHECK- // CHECK: mhlo.add{{.*}} : tensor<4xf32> %reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2 - // CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]]) - // CHECK-SAME: dimensions = [0] - // CHECK: mhlo.add{{.*}} : tensor + // CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]]) applies mhlo.add across dimensions = [0] : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 // CHECK: %5 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor @@ -1167,14 +1198,14 @@ add { // CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>) %test_reduce_scatter { input = f32[4,8] parameter(0) - // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) ({ + // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) <{ + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // CHECK-SAME: scatter_dimension = 1 : i64 + // CHECK-SAME: }> ({ // CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor): // CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor // CHECK-NEXT: mhlo.return [[ADD]] - // CHECK-NEXT: }) { - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - // CHECK-SAME: scatter_dimension = 1 : i64 - // CHECK-SAME: } : (tensor<4x8xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: }) : (tensor<4x8xf32>) -> tensor<4x4xf32> ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add } @@ -1192,18 +1223,18 @@ add { // CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>) %test_reduce_scatter_with_channel { input = f32[4,8] parameter(0) - // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) ({ - // CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor): - // CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor - // CHECK-NEXT: mhlo.return [[ADD]] - // CHECK-NEXT: }) { + // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-NOT: use_global_device_ids // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK-NOT: use_global_device_ids // CHECK-SAME: scatter_dimension = 1 : i64 // CHECK-NOT: use_global_device_ids - // CHECK-SAME: } : (tensor<4x8xf32>) -> tensor<4x4xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor): + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor + // CHECK-NEXT: mhlo.return [[ADD]] + // CHECK-NEXT: }) : (tensor<4x8xf32>) -> tensor<4x4xf32> ROOT ars = f32[4,4] reduce-scatter(input), channel_id=1, replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add } @@ -1211,16 +1242,16 @@ add { // CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>) %test_reduce_scatter_global { input = f32[4,8] parameter(0) - // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) ({ - // CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor): - // CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor - // CHECK-NEXT: mhlo.return [[ADD]] - // CHECK-NEXT: }) { + // CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) <{ // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> // CHECK-SAME: scatter_dimension = 1 : i64 // CHECK-SAME: use_global_device_ids - // CHECK-SAME: } : (tensor<4x8xf32>) -> tensor<4x4xf32> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor): + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor + // CHECK-NEXT: mhlo.return [[ADD]] + // CHECK-NEXT: }) : (tensor<4x8xf32>) -> tensor<4x4xf32> ROOT ars = f32[4,4] reduce-scatter(input), channel_id=1, replica_groups={{0,1}}, dimensions={1}, use_global_device_ids=true, to_apply=reduce_helper_add } @@ -1231,15 +1262,15 @@ add { %Arg_0.1 = f32[2,17,31,7] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) ({ - // CHECK: mhlo.add {{.*}} : tensor - // CHECK: }) { + // CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) <{ // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> // CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64> // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64> - // CHECK-SAME: } + // CHECK-SAME: }> ({ + // CHECK: mhlo.add {{.*}} : tensor + // CHECK: }) ROOT %reduce-window.1 = f32[2,5,8,7] reduce-window(f32[2,17,31,7] %Arg_0.1, f32[] %Arg_1.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%reduce_helper.3 } @@ -1267,19 +1298,18 @@ add { %constant.3 = pred[] constant(false) %constant.4 = f32[] constant(-inf) %constant.5 = f32[] constant(0) - // CHECK: [[REDUCE_WINDOW:%.*]]:2 = "mhlo.reduce_window"([[ARG1]], [[ARG0]], [[CONST_INF]], [[CONST_ZERO]]) ({ - // CHECK: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor, [[BARG2:%.*]]: tensor, [[BARG3:%.*]]: tensor) - // CHECK: [[COMPARE:%.*]] = mhlo.compare GE, [[BARG0]], [[BARG2]] - // CHECK: [[SELECT_0:%.*]] = mhlo.select [[COMPARE]], [[BARG0]], [[BARG2]] - // CHECK: [[SELECT_1:%.*]] = mhlo.select [[COMPARE]], [[BARG1]], [[BARG3]] - // CHECK: }) { + // CHECK: [[REDUCE_WINDOW:%.*]]:2 = "mhlo.reduce_window"([[ARG1]], [[ARG0]], [[CONST_INF]], [[CONST_ZERO]]) <{ // CHECK-SAME: base_dilations = dense<1> : tensor<2xi64> // CHECK-SAME: padding = dense<0> : tensor<2x2xi64> // CHECK-SAME: window_dilations = dense<1> : tensor<2xi64> // CHECK-SAME: window_dimensions = dense<[1, 2]> : tensor<2xi64> // CHECK-SAME: window_strides = dense<[1, 2]> : tensor<2xi64> - // CHECK-SAME: } - // CHECK-SAME: : (tensor<4x6xf32>, tensor<4x6xf32>, tensor, tensor) -> (tensor<4x3xf32>, tensor<4x3xf32>) + // CHECK-SAME: }> ({ + // CHECK: ^bb0([[BARG0:%.*]]: tensor, [[BARG1:%.*]]: tensor, [[BARG2:%.*]]: tensor, [[BARG3:%.*]]: tensor) + // CHECK: [[COMPARE:%.*]] = mhlo.compare GE, [[BARG0]], [[BARG2]] + // CHECK: [[SELECT_0:%.*]] = mhlo.select [[COMPARE]], [[BARG0]], [[BARG2]] + // CHECK: [[SELECT_1:%.*]] = mhlo.select [[COMPARE]], [[BARG1]], [[BARG3]] + // CHECK: }) : (tensor<4x6xf32>, tensor<4x6xf32>, tensor, tensor) -> (tensor<4x3xf32>, tensor<4x3xf32>) %reduce-window.15 = (f32[4,3], f32[4,3]) reduce-window(f32[4,6] %Arg1, f32[4,6] %Arg0, f32[] %constant.4, f32[] %constant.5), window={size=1x2 stride=1x2}, to_apply=%reducer_window_helper // CHECK: return [[REDUCE_WINDOW]]#1 : tensor<4x3xf32> ROOT %get-tuple-element.16 = f32[4,3] get-tuple-element((f32[4,3], f32[4,3]) %reduce-window.15), index=1 @@ -1298,7 +1328,7 @@ add { %test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) - // CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.reverse"(%arg0) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0} } @@ -1306,7 +1336,7 @@ add { %test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] { %Arg_0.1 = f32[4, 4] parameter(0) - // CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<4x4xf32> ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1} } @@ -1341,11 +1371,7 @@ add { // CHECK-LABEL: func private @test_scatter // CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> -// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ({ -// CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] -// CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) +// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers = // CHECK-SAME: update_window_dims = [1] @@ -1353,6 +1379,10 @@ add { // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: unique_indices = false +// CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): +// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] +// CHECK: mhlo.return [[ADD]] : tensor +// CHECK: }) %wide_update_computation { %lhs = f32[] parameter(0) @@ -1371,10 +1401,7 @@ add { // CHECK-LABEL: func.func private @test_variadic_scatter // CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tuple, tensor<200x100x300xf32>> -// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_2]]) ({ -// CHECK: ^bb0([[LHS:%.*]]: tensor, [[UPD:%.*]]: tensor, [[RHS:%.*]]: tensor, [[UPP:%.*]]: tensor): -// CHECK: mhlo.return [[LHS]], [[RHS]] : tensor, tensor -// CHECK: }) +// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_2]]) // CHECK-SAME: indices_are_sorted = false // CHECK-SAME: scatter_dimension_numbers = // CHECK-SAME: update_window_dims = [1] @@ -1382,6 +1409,9 @@ add { // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: unique_indices = false +// CHECK: ^bb0([[LHS:%.*]]: tensor, [[UPD:%.*]]: tensor, [[RHS:%.*]]: tensor, [[UPP:%.*]]: tensor): +// CHECK: mhlo.return [[LHS]], [[RHS]] : tensor, tensor +// CHECK: }) %update_computation_returning_tuple { %lhs = f32[] parameter(0) @@ -1399,7 +1429,7 @@ add { // CHECK-LABEL: func private @test_scatter_with_reducer_returning_tuple // CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> -// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ({ +// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] // CHECK: mhlo.return [[ADD]] : tensor @@ -1419,7 +1449,7 @@ add { // CHECK-LABEL: func private @test_all_reduce_with_reducer_returning_tuple // CHECK-SAME: [[ARG_0:%.*]]: tensor<4xf32>) -> tensor<4xf32> -// CHECK: "mhlo.all_reduce"([[ARG_0]]) ({ +// CHECK: "mhlo.all_reduce"([[ARG_0]]) // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] // CHECK: mhlo.return [[ADD]] : tensor @@ -1457,7 +1487,11 @@ add { ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select, scatter=%add_gather } -// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ({ +// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) <{ +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64> +// CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64> +// CHECK-SAME: window_strides = dense<[2, 3]> : tensor<2xi64> +// CHECK-SAME: }> ({ // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): // CHECK: [[CMP:%.*]] = mhlo.compare GE, [[LHS]], [[RHS]] // CHECK: mhlo.return [[CMP]] : tensor @@ -1465,11 +1499,7 @@ add { // CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] // CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) { -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64> -// CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64> -// CHECK-SAME: window_strides = dense<[2, 3]> : tensor<2xi64> -// CHECK-SAME: } +// CHECK: }) // CHECK: return [[RESULT:%.*]] : tensor<4x5xf32> // Test SelectAndScatter with tuple returns from computations. @@ -1508,7 +1538,7 @@ add { %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = s32[] parameter(1) - // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i64} : (tensor<4x4xf32>, tensor) + // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) <{dimension = 1 : i64}> : (tensor<4x4xf32>, tensor) // CHECK-SAME: tensor<4x?xf32, #mhlo.type_extensions> ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} } @@ -1534,11 +1564,11 @@ add { } // CHECK-LABEL: func private @test_sort // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: "mhlo.sort"([[ARG]]) ({ +// CHECK: "mhlo.sort"([[ARG]]) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ARG0]], [[ARG1]] : (tensor, tensor) -> tensor // CHECK: mhlo.return [[CMP]] : tensor -// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: }) : (tensor<1024xf32>) -> tensor<1024xf32> // CHECK-LABEL: func private @test_subtract %test_subtract (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { @@ -1579,7 +1609,7 @@ add { %test_transpose { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK: "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } @@ -1692,14 +1722,29 @@ add { ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } -// CHECK-LABEL: func private @rngbitgen -// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) -%rngbitgen (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) { +// CHECK-LABEL: func private @rngbitgen_tuple_shape +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) -> (tuple, tensor<2x2xui32>> {mhlo.sharding = "{{\{}}{maximal device=0}, {maximal device=1}}"}) +%rngbitgen_tuple_shape (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) { %Arg_0.1 = u64[3] parameter(0) - // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]]) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) + // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]]) + // CHECK-SAME: rng_algorithm = #mhlo.rng_algorithm + // CHECK-SAME: mhlo.sharding = "{{\{}}{maximal device=0}, {maximal device=1}}" + // CHECK-SAME: (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) // CHECK: %[[TUPLE:.+]] = mhlo.tuple %[[RNG0]], %[[RNG1]] {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : tuple, tensor<2x2xui32>> // CHECK: return %[[TUPLE]] - ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox + ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox, sharding={{maximal device=0}, {maximal device=1}} +} + +// CHECK-LABEL: func private @rngbitgen_array_shape +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) -> (tensor<2x2xui32> {mhlo.sharding = "{maximal device=0}"}) +%rngbitgen_array_shape (Arg_0.1: u64[3]) -> u32[2,2] { + %Arg_0.1 = u64[3] parameter(0) + // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]]) + // CHECK-SAME: rng_algorithm = #mhlo.rng_algorithm + // CHECK-SAME: mhlo.sharding = "{{\{}}{replicated}, {maximal device=0}}" + // CHECK-SAME: (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) + // CHECK: return %[[RNG1]] + ROOT %rng-bit-generator.2 = u32[2,2] rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_default, sharding={maximal device=0} } // CHECK-LABEL: func private @cbrt @@ -1797,23 +1842,18 @@ add { // CHECK: %[[START:.*]] = "mhlo.async_start"(%[[ARG0]]) // CHECK-SAME: called_computation = @AsyncOp // CHECK-SAME: execution_thread = "main" - // CHECK-SAME: group_id = 1 - %async-start = (f32[10]{0}, f32[20]{0}, s32[]) async-start(%p0), - calls=%AsyncOp, - async_group_id=1 + %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(%p0), + calls=%AsyncOp // CHECK: %[[UPDATE:.*]] = "mhlo.async_update"(%[[START]]) // CHECK-SAME: called_computation = @AsyncOp // CHECK-SAME: execution_thread = "main" - // CHECK-SAME: group_id = 1 - %async-update = (f32[10]{0}, f32[20]{0}, s32[]) async-update( + %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update( %async-start), - calls=%AsyncOp, - async_group_id=1 + calls=%AsyncOp // CHECK: "mhlo.async_done"(%[[UPDATE]]) // CHECK-SAME: called_computation = @AsyncOp // CHECK-SAME: execution_thread = "main" - // CHECK-SAME: group_id = 1 - ROOT %async-done = f32[20]{0} async-done(%async-update), calls=%AsyncOp, async_group_id=1 + ROOT %async-done = f32[20]{0} async-done(%async-update), calls=%AsyncOp } // CHECK-LABEL: func private @test_args_and_result_with_sharding diff --git a/xla/translate/hlo_to_mhlo/tests/import_async.hlotxt b/xla/translate/hlo_to_mhlo/tests/import_async.hlotxt index 3ac32e0b300aa..8040ac89faabb 100644 --- a/xla/translate/hlo_to_mhlo/tests/import_async.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/import_async.hlotxt @@ -26,22 +26,22 @@ HloModule foobar // CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]]) - // CHECK-SAME{LITERAL}:{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> + // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> // CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + // CHECK-SAME: use_global_device_ids // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): // CHECK: mhlo.add [[LHS]], [[RHS]] - // CHECK: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK: use_global_device_ids // CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { // CHECK-NEXT: "mhlo.all_gather"([[INPUT]]) // CHECK-SAME: all_gather_dim = 1 : i64 // CHECK-SAME: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK: use_global_device_ids + // CHECK-SAME: use_global_device_ids // CHECK: func @main(%arg0: tensor) -> tensor { ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { diff --git a/xla/translate/hlo_to_mhlo/tests/simple.hlo b/xla/translate/hlo_to_mhlo/tests/simple.hlo index 40502568a7d98..bc542ead303c4 100644 --- a/xla/translate/hlo_to_mhlo/tests/simple.hlo +++ b/xla/translate/hlo_to_mhlo/tests/simple.hlo @@ -141,6 +141,6 @@ dynamic_parameter_binding { # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { # CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. -# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) <{precision_config = [#mhlo, #mhlo]}> : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/xla/translate/hlo_to_mhlo/tests/while.hlotxt b/xla/translate/hlo_to_mhlo/tests/while.hlotxt index 2ed4016368f99..0b6c252a8338d 100644 --- a/xla/translate/hlo_to_mhlo/tests/while.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/while.hlotxt @@ -143,7 +143,7 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] { // CHECK: %[[CMP_7:.*]] = mhlo.compare LT, %[[RED_5]], %[[RED_6]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[CMP_7]] : tensor -// CHECK: %[[BDCAST_4:.*]] = "mhlo.broadcast_in_dim"(%[[ARG_3]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> +// CHECK: %[[BDCAST_4:.*]] = "mhlo.broadcast_in_dim"(%[[ARG_3]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> // CHECK-NEXT: %[[ADD_5:.*]] = mhlo.add %[[ARG_4]], %[[BDCAST_4]] : tensor<3xf32> // CHECK-NEXT: mhlo.return %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ADD_5]] : tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32> @@ -214,7 +214,7 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] { // CHECK: %[[CMP_7:.*]] = mhlo.compare LT, %[[RED_5]], %[[RED_6]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[CMP_7]] : tensor -// CHECK: %[[BDCAST_4:.*]] = "mhlo.broadcast_in_dim"(%[[ARG_3]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> +// CHECK: %[[BDCAST_4:.*]] = "mhlo.broadcast_in_dim"(%[[ARG_3]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> // CHECK: %[[ADD_5:.*]] = mhlo.add %[[ARG_4]], %[[BDCAST_4]] : tensor<3xf32> // CHECK: mhlo.return %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ADD_5]] : tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32> // CHECK: return %[[WHILE]]#3 : tensor<3xf32> @@ -334,7 +334,7 @@ region_cond5 { // CHECK-NEXT: } do { // CHECK-NEXT: %[[CST_2:.*]] = mhlo.constant dense : tensor // CHECK-NEXT: %[[CST_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor -// CHECK-NEXT: %[[BDCAST:.*]] = "mhlo.broadcast_in_dim"(%[[CST_3]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> +// CHECK-NEXT: %[[BDCAST:.*]] = "mhlo.broadcast_in_dim"(%[[CST_3]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<3x3xf32> // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ITER_ARG]], %[[BDCAST]] : tensor<3x3xf32> // CHECK-NEXT: mhlo.return %[[ADD]] : tensor<3x3xf32> // CHECK: return %[[WHILE]] : tensor<3x3xf32> diff --git a/xla/translate/hlo_to_mhlo/translate.cc b/xla/translate/hlo_to_mhlo/translate.cc index 1cc5ab0b75ee6..addba009d0f95 100644 --- a/xla/translate/hlo_to_mhlo/translate.cc +++ b/xla/translate/hlo_to_mhlo/translate.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/translate.h b/xla/translate/hlo_to_mhlo/translate.h index cc74b838e6e7b..96d8bcca16491 100644 --- a/xla/translate/hlo_to_mhlo/translate.h +++ b/xla/translate/hlo_to_mhlo/translate.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/hlo_to_mhlo/translate_registration.cc b/xla/translate/hlo_to_mhlo/translate_registration.cc index 1a5f8c2df6fb6..505961aac0c7d 100644 --- a/xla/translate/hlo_to_mhlo/translate_registration.cc +++ b/xla/translate/hlo_to_mhlo/translate_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/mhlo_to_hlo/BUILD b/xla/translate/mhlo_to_hlo/BUILD index 88ec58fd3964d..34ffe1ee8b1fc 100644 --- a/xla/translate/mhlo_to_hlo/BUILD +++ b/xla/translate/mhlo_to_hlo/BUILD @@ -1,12 +1,16 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@tsl//tsl:tsl.bzl", "internal_visibility") +load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") load("@tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") load("//xla:xla.bzl", "xla_cc_test") -load("@tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") -load("@bazel_skylib//rules:build_test.bzl", "build_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), licenses = ["notice"], ) @@ -21,7 +25,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo_gpu", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/stream_executor:dnn", @@ -49,8 +52,10 @@ cc_library( deps = [ ":stack_frame_index_builder", "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -60,7 +65,6 @@ cc_library( hdrs = ["stack_frame_index_builder.h"], deps = [ "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//mlir:IR", ], ) @@ -96,6 +100,7 @@ cc_library( "//xla/client/lib:slicing", "//xla/hlo/ir:hlo", "//xla/mlir/utils:error_util", + "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_parser", @@ -113,7 +118,7 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@stablehlo//:stablehlo_ops", - "@tsl//tsl/platform:float8", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", ], ) @@ -187,8 +192,8 @@ cc_library( "//xla:shape_util", "//xla:statusor", "//xla:xla_data_proto_cc", + "//xla/mlir/utils:type_util", "//xla/mlir_hlo", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/xla/translate/mhlo_to_hlo/attribute_exporter.cc index 93d07fc3453d1..21dfb491ec6e9 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ limitations under the License. #include -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" @@ -57,41 +56,15 @@ ConvolutionDimensionNumbers ConvertConvDimensionNumbers( return output; } -StatusOr ConvertConvActivationMode( - mlir::lmhlo_gpu::Activation activation) { - switch (activation) { - case mlir::lmhlo_gpu::Activation::None: - return stream_executor::dnn::kNone; - case mlir::lmhlo_gpu::Activation::Sigmoid: - return stream_executor::dnn::kSigmoid; - case mlir::lmhlo_gpu::Activation::Tanh: - return stream_executor::dnn::kTanh; - case mlir::lmhlo_gpu::Activation::Relu: - return stream_executor::dnn::kRelu; - case mlir::lmhlo_gpu::Activation::Relu6: - return stream_executor::dnn::kRelu6; - case mlir::lmhlo_gpu::Activation::ReluX: - return stream_executor::dnn::kReluX; - case mlir::lmhlo_gpu::Activation::BandPass: - return stream_executor::dnn::kBandPass; - case mlir::lmhlo_gpu::Activation::Elu: - return stream_executor::dnn::kElu; - case mlir::lmhlo_gpu::Activation::LeakyRelu: - return stream_executor::dnn::kLeakyRelu; - default: - return InternalError("Unexpected activation"); - } -} - // Convert replica group from MLIR encoding to HLO. // See HloFunctionImporter::ConvertReplicaGroups for the MLIR encoding. -StatusOr> ConvertReplicaGroups( +absl::StatusOr> ConvertReplicaGroups( mlir::DenseIntElementsAttr input) { mlir::RankedTensorType type = input.getType().dyn_cast(); if (!type || type.getRank() != 2 || !type.getElementType().isInteger(/*width=*/64)) { - return InternalError("Execpted replica group to be a rank 2 tensor of i64"); + return Internal("Execpted replica group to be a rank 2 tensor of i64"); } // rank 0 is num_groups, rank 1 is group size. auto replica_group_values_it = input.getValues().begin(); @@ -112,14 +85,14 @@ StatusOr> ConvertReplicaGroups( // Convert a (N, 2) dense attribute to a list of tuples. This is the way padding // and source-target pairs are defined in HLO. -StatusOr>> ConvertNx2Attribute( +absl::StatusOr>> ConvertNx2Attribute( std::optional optional_attr) { if (!optional_attr.has_value()) return std::vector>{}; mlir::DenseIntElementsAttr attr = *optional_attr; auto type = attr.getType().dyn_cast(); if (!type || type.getRank() != 2 || type.getShape()[1] != 2) - return InternalError("expected Nx2 attribute to be a tensor of shape Nx2"); + return Internal("expected Nx2 attribute to be a tensor of shape Nx2"); auto it = attr.getValues().begin(); std::vector> out(attr.getNumElements() / 2); for (auto& item : out) { @@ -132,7 +105,7 @@ StatusOr>> ConvertNx2Attribute( return out; } -StatusOr ConvertFftType(llvm::StringRef type_string) { +absl::StatusOr ConvertFftType(llvm::StringRef type_string) { std::optional type = mlir::mhlo::symbolizeEnum(type_string); if (!type) return InvalidArgument("Unknown FFT type %s", type_string.str()); @@ -151,7 +124,7 @@ StatusOr ConvertFftType(llvm::StringRef type_string) { } } -StatusOr ConvertTranspose( +absl::StatusOr ConvertTranspose( llvm::StringRef transpose_string) { std::optional transpose = mlir::mhlo::symbolizeTranspose(transpose_string); @@ -172,7 +145,7 @@ StatusOr ConvertTranspose( } } -StatusOr ConvertCustomCallSchedule( +absl::StatusOr ConvertCustomCallSchedule( mlir::mhlo::CustomCallSchedule schedule) { switch (schedule) { case mlir::mhlo::CustomCallSchedule::NONE: @@ -187,7 +160,7 @@ StatusOr ConvertCustomCallSchedule( } } -StatusOr ConvertCustomCallApiVersion( +absl::StatusOr ConvertCustomCallApiVersion( mlir::mhlo::CustomCallApiVersion api_version) { switch (api_version) { case mlir::mhlo::CustomCallApiVersion::API_VERSION_UNSPECIFIED: @@ -206,7 +179,8 @@ StatusOr ConvertCustomCallApiVersion( } } -StatusOr>>> +absl::StatusOr< + std::vector>>> ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr) { std::vector>> aliasInfo; for (auto attr : aliasArrayAttr.getValue()) { @@ -223,7 +197,8 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr) { std::optional ConvertSharding(llvm::StringRef sharding) { xla::OpSharding sharding_proto; if (sharding_proto.ParseFromString(sharding.str())) return sharding_proto; - StatusOr sharding_cpp = xla::ParseSharding(sharding.str()); + absl::StatusOr sharding_cpp = + xla::ParseSharding(sharding.str()); if (sharding_cpp.ok()) return sharding_cpp->ToProto(); return std::nullopt; } @@ -275,14 +250,14 @@ DotDimensionNumbers ConvertDotDimensionNumbers( return output; } -StatusOr> ConvertMlirArrayAttrToInt64Array( +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( const mlir::ArrayAttr& array) { int rank = array.size(); std::vector converted_array(rank); for (int i = 0; i < rank; i++) { mlir::IntegerAttr attr = array[i].dyn_cast(); if (!attr) { - return InternalError("Type Error: Expected layout integer attribute"); + return Internal("Type Error: Expected layout integer attribute"); } converted_array[i] = attr.getInt(); } diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.h b/xla/translate/mhlo_to_hlo/attribute_exporter.h index 3d58247b6d631..96aa716b58862 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include #include "mlir/IR/Attributes.h" // from @llvm-project -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -34,28 +33,26 @@ namespace xla { ConvolutionDimensionNumbers ConvertConvDimensionNumbers( mlir::mhlo::ConvDimensionNumbersAttr input); -StatusOr ConvertConvActivationMode( - mlir::lmhlo_gpu::Activation input); - -StatusOr> ConvertReplicaGroups( +absl::StatusOr> ConvertReplicaGroups( mlir::DenseIntElementsAttr input); // Convert a (N, 2) dense attribute to a list of tuples. This is the way padding // and source-target pairs are defined in HLO. -StatusOr>> ConvertNx2Attribute( +absl::StatusOr>> ConvertNx2Attribute( std::optional optional_attr); -StatusOr ConvertFftType(llvm::StringRef type_string); -StatusOr ConvertTranspose( +absl::StatusOr ConvertFftType(llvm::StringRef type_string); +absl::StatusOr ConvertTranspose( llvm::StringRef transpose_string); -StatusOr ConvertCustomCallSchedule( +absl::StatusOr ConvertCustomCallSchedule( mlir::mhlo::CustomCallSchedule schedule); -StatusOr ConvertCustomCallApiVersion( +absl::StatusOr ConvertCustomCallApiVersion( mlir::mhlo::CustomCallApiVersion api_version); -StatusOr>>> +absl::StatusOr< + std::vector>>> ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Returns an OpSharding that represents the result of parsing the given string: @@ -71,7 +68,7 @@ DotDimensionNumbers ConvertDotDimensionNumbers( absl::Span rhs_batch, absl::Span rhs_contract); -StatusOr> ConvertMlirArrayAttrToInt64Array( +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( const mlir::ArrayAttr& array); } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/xla/translate/mhlo_to_hlo/layout_util.cc b/xla/translate/mhlo_to_hlo/layout_util.cc index 4c8ed5d211e1b..88261a847874f 100644 --- a/xla/translate/mhlo_to_hlo/layout_util.cc +++ b/xla/translate/mhlo_to_hlo/layout_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ limitations under the License. namespace mlir { // Rewrites the layout of xla_shape if there is tiled sharding. -xla::Status RewriteLayoutWithShardedShape( +absl::Status RewriteLayoutWithShardedShape( const std::optional& sharding, bool use_fast_memory, const LayoutPreferenceFn& layout_preference_fn, const ShapeRepresentationFn& shape_representation_fn, @@ -63,7 +63,7 @@ xla::Status RewriteLayoutWithShardedShape( // There is a shape_representation_fn or sharding for an output, this function // uses a reshape to fix the layout. -xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, const LayoutPreferenceFn& layout_preference_fn, const ShapeRepresentationFn& shape_representation_fn, @@ -105,6 +105,7 @@ xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); } } + xla::XlaScopedShardingAssignment scoped_sharding(builder, sharding); return xla::Reshape(to_shape, original); } diff --git a/xla/translate/mhlo_to_hlo/layout_util.h b/xla/translate/mhlo_to_hlo/layout_util.h index f8f175ba282c4..f68e3ed94fe85 100644 --- a/xla/translate/mhlo_to_hlo/layout_util.h +++ b/xla/translate/mhlo_to_hlo/layout_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,11 +52,11 @@ enum class XlaLayoutPreference { // The following defines the layout preference of an xla tensor. // The return value of LayoutPreferenceFn can be used in // ShapeRepresentationFn. -typedef std::function( +typedef std::function( const xla::Shape& shape)> LayoutPreferenceFn; -typedef std::function( +typedef std::function( const xla::Shape& shape, bool fast_mem, XlaLayoutPreference layout_preference)> ShapeRepresentationFn; @@ -65,7 +65,7 @@ typedef std::function( LayoutPreferenceFn UseNoPreferenceLayoutFn(); // Rewrites the layout of xla_shape if there is tiled sharding. -xla::Status RewriteLayoutWithShardedShape( +absl::Status RewriteLayoutWithShardedShape( const std::optional& sharding, bool use_fast_memory, const LayoutPreferenceFn& layout_preference_fn, const ShapeRepresentationFn& shape_representation_fn, @@ -73,7 +73,7 @@ xla::Status RewriteLayoutWithShardedShape( // Adds reshapes to fix the layout of an output, if a shape_representation_fn or // sharding is present. -xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, const LayoutPreferenceFn& layout_preference_fn, const ShapeRepresentationFn& shape_representation_fn, diff --git a/xla/translate/mhlo_to_hlo/location_exporter.cc b/xla/translate/mhlo_to_hlo/location_exporter.cc index 06e316eac5cce..0d8969116e70e 100644 --- a/xla/translate/mhlo_to_hlo/location_exporter.cc +++ b/xla/translate/mhlo_to_hlo/location_exporter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,8 +17,17 @@ limitations under the License. #include +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" namespace mlir { namespace mhlo { @@ -36,24 +45,18 @@ static std::string GetNameFromLocImpl(Location loc) { // in functions where the op's name is first. auto name = name_loc.getName().strref().split('@').first; // Skip if the name is for op type. - if (!name.endswith(":")) { + if (!name.ends_with(":")) { loc_names.push_back(name); } - continue; } else if (auto call_loc = curr_loc.dyn_cast()) { // Use location of the Callee to generate the name. locs.push_back(call_loc.getCallee()); - continue; } else if (auto fused_loc = curr_loc.dyn_cast()) { // Push all locations in FusedLoc in reverse order, so locations are // visited based on order in FusedLoc. auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end()); - continue; } - - // Location is not a supported, so an empty StringRef is added. - loc_names.push_back(llvm::StringRef()); } return llvm::join(loc_names.begin(), loc_names.end(), ";"); @@ -71,29 +74,34 @@ static std::string GetOpTypeFromLoc(Location loc) { // Add name in NameLoc. For NameLoc we also account for names due to ops // in functions where the op's name is first. auto op_type = name_loc.getName().strref().split('@').first; - if (op_type.endswith(":")) { + if (op_type.ends_with(":")) { op_type = op_type.substr(0, op_type.size() - 1); loc_op_types.push_back(op_type); } - continue; } else if (auto call_loc = curr_loc.dyn_cast()) { // Use location of the Callee to generate the name. locs.push_back(call_loc.getCallee()); - continue; } else if (auto fused_loc = curr_loc.dyn_cast()) { // The first location is reserved for op_type. if (!fused_loc.getLocations().empty()) locs.push_back(fused_loc.getLocations()[0]); - continue; } - - // Location is not a supported, so an empty StringRef is added. - loc_op_types.push_back(llvm::StringRef()); } return llvm::join(loc_op_types.begin(), loc_op_types.end(), ";"); } +static void SetSourceFileAndLine(Location loc, xla::OpMetadata& metadata) { + if (auto file_line_col_loc = loc.dyn_cast()) { + metadata.set_source_file(file_line_col_loc.getFilename().str()); + metadata.set_source_line(file_line_col_loc.getLine()); + } else if (auto fused_loc = loc.dyn_cast()) { + for (Location it : fused_loc.getLocations()) { + SetSourceFileAndLine(it, metadata); + } + } +} + xla::OpMetadata CreateOpMetadataFromLocation( mlir::Operation* op, mlir::StackFrameIndexBuilder* frame_index_builder) { xla::OpMetadata metadata; @@ -105,21 +113,20 @@ xla::OpMetadata CreateOpMetadataFromLocation( std::string op_type = GetOpTypeFromLoc(loc); metadata.set_op_type(op_type); - if (auto name_loc = dyn_cast(op->getLoc())) { + if (auto name_loc = loc.dyn_cast()) { loc = name_loc.getChildLoc(); if (isa(loc)) return metadata; if (frame_index_builder != nullptr) { - int frameId = frame_index_builder->AddCallStackAndGetFirstFrameId(loc); - metadata.set_stack_frame_id(frameId); + auto result = frame_index_builder->AddCallStackAndGetFirstFrameId(loc); + metadata.set_stack_frame_id(result.last_frame_id); + // TODO(b/311155137): Remove when profiler will support stack traces. + metadata.set_source_file(result.last_frame_file); + metadata.set_source_line(result.last_frame_line); } } - if (auto file_line_col_loc = loc.dyn_cast()) { - metadata.set_source_file(file_line_col_loc.getFilename().str()); - metadata.set_source_line(file_line_col_loc.getLine()); - } - + SetSourceFileAndLine(loc, metadata); return metadata; } diff --git a/xla/translate/mhlo_to_hlo/location_exporter.h b/xla/translate/mhlo_to_hlo/location_exporter.h index 69efa6dbc4f47..8b7d30379247f 100644 --- a/xla/translate/mhlo_to_hlo/location_exporter.h +++ b/xla/translate/mhlo_to_hlo/location_exporter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 4b8b0d2569900..39d8313472d2a 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include +#include #include #include #include @@ -72,6 +73,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/mlir/utils/error_util.h" +#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/primitive_util.h" @@ -88,7 +90,7 @@ limitations under the License. #include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/float8.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" using ::int64_t; @@ -183,8 +185,8 @@ xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { return array; } -StatusOr CreateArrayLiteralFromAttr(mlir::ElementsAttr attr, - xla::Layout layout) { +absl::StatusOr CreateArrayLiteralFromAttr(mlir::ElementsAttr attr, + xla::Layout layout) { auto dense_attr = attr.dyn_cast(); if (!dense_attr) return tsl::errors::Unimplemented("Only dense elements attr are supported"); @@ -192,7 +194,7 @@ StatusOr CreateArrayLiteralFromAttr(mlir::ElementsAttr attr, xla::Shape shape = xla::TypeToShape(dense_attr.getType()); return xla::primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (xla::primitive_util::IsArrayType( primitive_type_constant)) { using cpp_type = @@ -482,6 +484,17 @@ static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( return xla::ConvertConvDimensionNumbers(input); } +static xla::SparsityDescriptor Convert_sparsity_descriptor( + mlir::mhlo::SparsityDescriptorAttr sparsity_attr, bool is_lhs) { + xla::SparsityDescriptor sparsity_descriptor; + sparsity_descriptor.set_type(xla::SPARSITY_STRUCTURED_N_M); + sparsity_descriptor.set_index(is_lhs ? 0 : 1); + sparsity_descriptor.set_dimension(sparsity_attr.getDimension()); + sparsity_descriptor.set_n(sparsity_attr.getN()); + sparsity_descriptor.set_m(sparsity_attr.getM()); + return sparsity_descriptor; +} + xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandleAttr attr) { xla::ChannelHandle channel_handle; channel_handle.set_handle(attr.getHandle()); @@ -598,15 +611,6 @@ static void ExtractFrontendAttributesFromFunction( } } -// Checks if all shardings are set. -static bool AllOptionalShardingsAreSet( - llvm::ArrayRef> shardings) { - return llvm::all_of(shardings, - [](const std::optional& sharding) { - return sharding.has_value(); - }); -} - static bool SomeOptionalShardingsAreSet( llvm::ArrayRef> shardings) { return llvm::any_of(shardings, @@ -829,11 +833,46 @@ bool SimplyReturnedOp(mlir::Operation* op) { return false; } +void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple, + OpLoweringContext ctx) { + const std::optional& tuple_sharding = + ctx.builder->sharding(); + if (tuple_sharding.has_value()) { + assert(op->getNumResults() == tuple_sharding->tuple_shardings_size()); + for (auto [index, result] : llvm::enumerate(op->getResults())) { + xla::XlaScopedShardingAssignment scoped_sharding( + ctx.builder, tuple_sharding->tuple_shardings(index)); + (*ctx.values)[result] = xla::GetTupleElement(tuple, index); + } + } else { + xla::XlaScopedShardingAssignment scoped_sharding(ctx.builder, std::nullopt); + for (auto [index, result] : llvm::enumerate(op->getResults())) { + (*ctx.values)[result] = xla::GetTupleElement(tuple, index); + } + } +} + } // namespace namespace mlir { namespace mhlo { namespace { +LogicalResult ExportXlaOp(CollectiveBroadcastOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) + return failure(); + value_map[op->getResult(0)] = xla::CollectiveBroadcast( + operand, Convert_replica_groups(op.getReplicaGroups()), + Convert_channel_handle(op.getChannelHandle())); + + return success(); +} + +LogicalResult ExportXlaOp(CompositeOp, OpLoweringContext) { + // TODO: b/328526226 - Implement MHLO export for CompositeOp. + return failure(); +} LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) { // This op should've been removed during PrepareForExport. @@ -944,21 +983,47 @@ LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - xla::XlaOp operand; - if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) - return failure(); - TensorType operand_type = op.getOperand().getType().cast(); - TensorType result_type = op.getType(); - if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) + + SmallVector operands; + if (failed(GetTuple(op.getOperation(), op.getOperands(), ctx, operands))) { return failure(); + } + + mlir::FailureOr shape_or = ExtractXlaShape(op.getOperation()); + if (failed(shape_or)) return failure(); + auto all_gather_dim = op.getAllGatherDim(); - int64_t shard_count = result_type.getDimSize(all_gather_dim) / - operand_type.getDimSize(all_gather_dim); - value_map[op] = xla::AllGather( - operand, all_gather_dim, shard_count, - Convert_replica_groups(op.getReplicaGroups()), - Convert_channel_handle(op.getChannelHandle()), std::nullopt, - Convert_use_global_device_ids(op.getUseGlobalDeviceIds())); + int64_t shard_count = 0; + for (size_t i = 0; i < operands.size(); ++i) { + TensorType operand_type = op.getOperand(i).getType().cast(); + TensorType result_type = op.getType(i).cast(); + if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) + return failure(); + if (i == 0) { + shard_count = result_type.getDimSize(all_gather_dim) / + operand_type.getDimSize(all_gather_dim); + } + } + + if (shape_or->IsTuple()) { + std::optional layout = std::nullopt; + if (shape_or->has_layout()) { + layout = shape_or->layout(); + } + auto tuple = xla::AllGatherTuple( + operands, all_gather_dim, shard_count, + Convert_replica_groups(op.getReplicaGroups()), + Convert_channel_handle(op.getChannelHandle()), layout, + Convert_use_global_device_ids(op.getUseGlobalDeviceIds())); + BuildGetTupleElementsForTupleResults(op, tuple, ctx); + } else { + value_map[op->getResults()[0]] = xla::AllGather( + operands[0], all_gather_dim, shard_count, + Convert_replica_groups(op.getReplicaGroups()), + Convert_channel_handle(op.getChannelHandle()), std::nullopt, + Convert_use_global_device_ids(op.getUseGlobalDeviceIds())); + } + return success(); } @@ -983,9 +1048,7 @@ LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { operands, computation, Convert_replica_groups(op.getReplicaGroups()), Convert_channel_handle(op.getChannelHandle()), shape_with_layout, Convert_use_global_device_ids(op.getUseGlobalDeviceIds())); - for (auto [index, result] : llvm::enumerate(op.getResults())) { - value_map[result] = xla::GetTupleElement(tuple, index); - } + BuildGetTupleElementsForTupleResults(op, tuple, ctx); } else { value_map[op->getResults()[0]] = xla::AllReduce( operands[0], computation, Convert_replica_groups(op.getReplicaGroups()), @@ -1014,9 +1077,7 @@ LogicalResult ExportXlaOp(AllToAllOp op, OpLoweringContext ctx) { auto tuple = xla::AllToAllTuple( operands, Convert_replica_groups(op.getReplicaGroups()), layout, Convert_channel_handle(op.getChannelHandle())); - for (auto [index, result] : llvm::enumerate(op.getResults())) { - value_map[result] = xla::GetTupleElement(tuple, index); - } + BuildGetTupleElementsForTupleResults(op, tuple, ctx); } else { // ArrayAllToAll always has exactly one operand (checked in the verifier). value_map[op->getResults()[0]] = xla::AllToAll( @@ -1058,18 +1119,16 @@ LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) { for (auto* user : op.getResult().getUsers()) { if (auto asyncOp = dyn_cast_or_null(user)) { - if (asyncOp.getGroupId() != op.getGroupId() || - asyncOp.getCalledComputation() != op.getCalledComputation()) { + if (asyncOp.getCalledComputation() != op.getCalledComputation()) { return op.emitOpError() << "Users of AsyncStart's return value must have " - "the same group_id and called_computation"; + "the same called_computation"; } } else if (auto asyncOp = dyn_cast_or_null(user)) { - if (asyncOp.getGroupId() != op.getGroupId() || - asyncOp.getCalledComputation() != op.getCalledComputation()) { + if (asyncOp.getCalledComputation() != op.getCalledComputation()) { return op.emitOpError() << "Users of AsyncStart's return value must have " - "the same group_id and called_computation"; + "the same called_computation"; } } else { return op.emitOpError() << "Users of AsyncStart's return value must be " @@ -1090,8 +1149,8 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) { dyn_cast_or_null(callee.getBody().front().front()); if (all_gather_op && SimplyReturnedOp(all_gather_op)) { TensorType operand_type = - all_gather_op.getOperand().getType().cast(); - TensorType result_type = all_gather_op.getType(); + all_gather_op.getOperand(0).getType().cast(); + TensorType result_type = all_gather_op.getType(0).cast(); if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) return failure(); if (operands.size() != 1) return failure(); @@ -1182,23 +1241,12 @@ LogicalResult ExportXlaOp(AsyncStartOp op, OpLoweringContext ctx) { ctx.converter->GetLoweredComputation(callee); computation.mutable_proto()->mutable_computations(0)->set_execution_thread( op.getExecutionThread().str()); - if (op.getGroupId()) { - auto [xla_op, computation_id] = - xla::internal::XlaBuilderFriend::BuildAsyncStart( - ctx.builder, operands, op.getExecutionThread().str(), - *op.getGroupId(), computation, xla::TypeToShape(result.getType())); - value_map[result] = xla_op; - computation.mutable_proto()->mutable_computations(0)->set_id( - computation_id); - } else { - auto [xla_op, computation_id] = - xla::internal::XlaBuilderFriend::BuildAsyncStart( - ctx.builder, operands, op.getExecutionThread().str(), computation, - xla::TypeToShape(result.getType())); - value_map[result] = xla_op; - computation.mutable_proto()->mutable_computations(0)->set_id( - computation_id); - } + auto [xla_op, computation_id] = + xla::internal::XlaBuilderFriend::BuildAsyncStart( + ctx.builder, operands, op.getExecutionThread().str(), computation, + xla::TypeToShape(result.getType())); + value_map[result] = xla_op; + computation.mutable_proto()->mutable_computations(0)->set_id(computation_id); return success(); } @@ -1217,15 +1265,13 @@ LogicalResult ExportXlaOp(AsyncUpdateOp op, OpLoweringContext ctx) { for (auto* user : op.getResult().getUsers()) { if (auto asyncOp = dyn_cast_or_null(user)) { - if (asyncOp.getGroupId() != op.getGroupId() || - asyncOp.getCalledComputation() != op.getCalledComputation()) { + if (asyncOp.getCalledComputation() != op.getCalledComputation()) { return op.emitOpError() << "Users of AsyncUpdate's return value must have " "the same group_id and called_computation"; } } else if (auto asyncOp = dyn_cast_or_null(user)) { - if (asyncOp.getGroupId() != op.getGroupId() || - asyncOp.getCalledComputation() != op.getCalledComputation()) { + if (asyncOp.getCalledComputation() != op.getCalledComputation()) { return op.emitOpError() << "Users of AsyncUpdate's return value must have " "the same group_id and called_computation"; @@ -1246,17 +1292,10 @@ LogicalResult ExportXlaOp(AsyncUpdateOp op, OpLoweringContext ctx) { FlatSymbolRefAttr::get(op->getContext(), op.getCalledComputation())); xla::XlaComputation& computation = ctx.converter->GetLoweredComputation(callee); - if (op.getGroupId()) { - value_map[result] = xla::internal::XlaBuilderFriend::BuildAsyncUpdate( - ctx.builder, operand, op.getExecutionThread().str(), *op.getGroupId(), - computation.proto().computations(0).id(), - xla::TypeToShape(result.getType())); - } else { - value_map[result] = xla::internal::XlaBuilderFriend::BuildAsyncUpdate( - ctx.builder, operand, op.getExecutionThread().str(), - computation.proto().computations(0).id(), - xla::TypeToShape(result.getType())); - } + value_map[result] = xla::internal::XlaBuilderFriend::BuildAsyncUpdate( + ctx.builder, operand, op.getExecutionThread().str(), + computation.proto().computations(0).id(), + xla::TypeToShape(result.getType())); return success(); } @@ -1285,7 +1324,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) { if (all_gather_op && SimplyReturnedOp(all_gather_op)) { value_map[op.getResult(0)] = xla::internal::XlaBuilderFriend::BuildAllGatherDone( - ctx.builder, operand, xla::TypeToShape(all_gather_op.getType())); + ctx.builder, operand, xla::TypeToShape(all_gather_op.getType(0))); return success(); } auto all_reduce_op = @@ -1336,11 +1375,7 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { value_map[op.getResult(0)] = xla_recv; } else { - xla::XlaScopedShardingAssignment scoped_sharding(ctx.builder, - std::nullopt); - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(xla_recv, item.index()); - } + BuildGetTupleElementsForTupleResults(op, xla_recv, ctx); } return success(); } @@ -1353,22 +1388,13 @@ LogicalResult ExportXlaOp(AsyncDoneOp op, OpLoweringContext ctx) { } xla::Shape data_shape = xla::ShapeUtil::MakeTupleShape(subshapes); - xla::XlaOp exportedOp; - if (op.getGroupId()) { - exportedOp = xla::internal::XlaBuilderFriend::BuildAsyncDone( - ctx.builder, operand, op.getExecutionThread().str(), *op.getGroupId(), - computation.proto().computations(0).id(), data_shape); - } else { - exportedOp = xla::internal::XlaBuilderFriend::BuildAsyncDone( - ctx.builder, operand, op.getExecutionThread().str(), - computation.proto().computations(0).id(), data_shape); - } + xla::XlaOp exportedOp = xla::internal::XlaBuilderFriend::BuildAsyncDone( + ctx.builder, operand, op.getExecutionThread().str(), + computation.proto().computations(0).id(), data_shape); if (op.getNumResults() == 1) { value_map[op.getResult(0)] = exportedOp; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(exportedOp, item.index()); - } + BuildGetTupleElementsForTupleResults(op, exportedOp, ctx); } return success(); } @@ -1380,7 +1406,8 @@ LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) { return failure(); value_map[op] = xla::BitcastConvertType( - operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + operand, + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); return success(); } @@ -1408,7 +1435,7 @@ LogicalResult ExportXlaOp(StochasticConvertOp op, OpLoweringContext ctx) { value_map[op] = xla::StochasticConvertType( operand, random, - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); return success(); } @@ -1442,7 +1469,7 @@ LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) { if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) return mlir::failure(); xla::PrimitiveType preferred_element_type = - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType())); value_map[op] = xla::Dot( lhs, rhs, Unwrap(Convert_precision_config(op.getPrecisionConfig())), preferred_element_type); @@ -1457,7 +1484,7 @@ LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) return mlir::failure(); xla::PrimitiveType preferred_element_type = - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType())); value_map[op] = xla::DotGeneral( lhs, rhs, Convert_dot_dimension_numbers(op.getDotDimensionNumbers()), Unwrap(Convert_precision_config(op.getPrecisionConfig())), @@ -1465,6 +1492,36 @@ LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { return mlir::success(); } +LogicalResult ExportXlaOp(SparseDotOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp lhs, rhs; + if (failed(GetXlaOp(op.getLhs(), value_map, &lhs, op))) + return mlir::failure(); + if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) + return mlir::failure(); + xla::PrimitiveType preferred_element_type = + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType())); + + llvm::SmallVector sparse_meta; + if (failed(GetTuple(op, op.getMeta(), ctx, sparse_meta))) return failure(); + std::vector sparsity; + if (op.getLhsSparsity().has_value()) { + sparsity.push_back( + Convert_sparsity_descriptor(*op.getLhsSparsity(), /*is_lhs=*/true)); + } + if (op.getRhsSparsity().has_value()) { + sparsity.push_back( + Convert_sparsity_descriptor(*op.getRhsSparsity(), /*is_lhs=*/false)); + } + + value_map[op] = + xla::SparseDot(lhs, rhs, absl::MakeSpan(sparse_meta), sparsity, + Convert_dot_dimension_numbers(op.getDotDimensionNumbers()), + Unwrap(Convert_precision_config(op.getPrecisionConfig())), + preferred_element_type); + return mlir::success(); +} + LogicalResult ExportXlaOp(DomainOp op, OpLoweringContext ctx) { auto& valueMap = *ctx.values; @@ -1551,9 +1608,7 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { value_map[op.getResult(0)] = ifop; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(ifop, item.index()); - } + BuildGetTupleElementsForTupleResults(op, ifop, ctx); } return success(); @@ -1608,9 +1663,7 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { value_map[op.getResult(0)] = caseop; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(caseop, item.index()); - } + BuildGetTupleElementsForTupleResults(op, caseop, ctx); } return success(); } @@ -1653,7 +1706,7 @@ LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) { if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) return mlir::failure(); xla::PrimitiveType preferred_element_type = - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType())); xla::XlaOp xla_result = xla::ConvGeneralDilated( lhs, rhs, Convert_window_strides(op.getWindowStrides()), Convert_padding(op.getPadding()), @@ -1675,7 +1728,8 @@ LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { return failure(); value_map[op] = xla::ConvertElementType( - operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + operand, + xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); return success(); } @@ -1963,9 +2017,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { reduction_dim, comparator, recall_target, aggregate_to_topk, reduction_input_size_override); } - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(cc_op, item.index()); - } + BuildGetTupleElementsForTupleResults(op, cc_op, ctx); return success(); } @@ -2006,7 +2058,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { } } - StatusOr literal; + absl::StatusOr literal; const xla::Literal* literal_ptr = nullptr; auto literal_attr = op->getAttrOfType(kLiteralAttr); if (literal_attr) { @@ -2065,9 +2117,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { if (op->getNumResults() == 1) { value_map[op.getResult(0)] = custom_call; } else { - for (auto [index, result] : llvm::enumerate(op.getResults())) { - value_map[result] = xla::GetTupleElement(custom_call, index); - } + BuildGetTupleElementsForTupleResults(op, custom_call, ctx); } return success(); @@ -2262,9 +2312,7 @@ LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { value_map[op.getResult(0)] = result; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(result, item.index()); - } + BuildGetTupleElementsForTupleResults(op, result, ctx); } return success(); } @@ -2292,9 +2340,7 @@ LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { value_map[op.getResult(0)] = result; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - value_map[item.value()] = xla::GetTupleElement(result, item.index()); - } + BuildGetTupleElementsForTupleResults(op, result, ctx); } return success(); } @@ -2324,9 +2370,7 @@ LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) { static_cast(op.getRngAlgorithm()), Unwrap(xla_arg_1), xla::TypeToShape(results[1].getType())); - for (const auto& item : llvm::enumerate(results)) - value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); - + BuildGetTupleElementsForTupleResults(op, xla_result, ctx); return mlir::success(); } @@ -2341,7 +2385,6 @@ LogicalResult ExportXlaOp(XlaRngGetAndUpdateStateOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(BatchNormGradOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - auto results = op.getResults(); xla::XlaOp operand, scale, mean, variance, grad_output; if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) @@ -2357,15 +2400,13 @@ LogicalResult ExportXlaOp(BatchNormGradOp op, OpLoweringContext ctx) { xla::BatchNormGrad(operand, scale, mean, variance, grad_output, ConvertAPFloat(op.getEpsilon()), op.getFeatureIndex()); - for (const auto& item : llvm::enumerate(results)) - value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); + BuildGetTupleElementsForTupleResults(op, xla_result, ctx); return mlir::success(); } LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - auto results = op.getResults(); xla::XlaOp operand, scale, offset; if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) @@ -2378,8 +2419,7 @@ LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) { ConvertAPFloat(op.getEpsilon()), op.getFeatureIndex()); - for (const auto& item : llvm::enumerate(results)) - value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); + BuildGetTupleElementsForTupleResults(op, xla_result, ctx); return mlir::success(); } @@ -2428,9 +2468,7 @@ LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) { } // mhlo.ScatterOp supports multiple returns, untuple all the results of XLA's. - for (const auto& it : llvm::enumerate(op.getResults())) { - value_map[it.value()] = xla::GetTupleElement(scatter_op, it.index()); - } + BuildGetTupleElementsForTupleResults(op, scatter_op, ctx); return success(); } @@ -2557,9 +2595,7 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { } // MLIR's sort supports multiple returns, untuple all the results of XLA's. - for (const auto& it : llvm::enumerate(op.getResults())) { - value_map[it.value()] = xla::GetTupleElement(sorted, it.index()); - } + BuildGetTupleElementsForTupleResults(op, sorted, ctx); return success(); } @@ -2622,9 +2658,7 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { } // mhlo.WhileOp supports multiple returns, untuple all the results of XLA's. - for (const auto& it : llvm::enumerate(op.getResults())) { - value_map[it.value()] = xla::GetTupleElement(whileop, it.index()); - } + BuildGetTupleElementsForTupleResults(op, whileop, ctx); return success(); } @@ -2643,10 +2677,7 @@ LogicalResult ExportXlaOp(OptimizationBarrierOp op, OpLoweringContext ctx) { xla::OptimizationBarrier(operands[0]); } else { auto result = xla::OptimizationBarrier(Tuple(ctx.builder, operands)); - - for (const auto& it : llvm::enumerate(op.getResults())) { - value_map[it.value()] = xla::GetTupleElement(result, it.index()); - } + BuildGetTupleElementsForTupleResults(op, result, ctx); } return success(); @@ -2679,9 +2710,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { if (op.getNumResults() == 1) { values[op.getResult(0)] = fusion; } else { - for (const auto& item : llvm::enumerate(op.getResults())) { - values[item.value()] = xla::GetTupleElement(fusion, item.index()); - } + BuildGetTupleElementsForTupleResults(op, fusion, ctx); } return success(); } @@ -2740,9 +2769,7 @@ LogicalResult ExportXlaOp(TopKOp op, OpLoweringContext ctx) { auto topk = xla::TopK(operand, op.getK(), op.getLargest()); // Untuple the two results of XLA's topk. - for (const auto& [index, value] : llvm::enumerate(op.getResults())) { - value_map[value] = xla::GetTupleElement(topk, index); - } + BuildGetTupleElementsForTupleResults(op, topk, ctx); return success(); } @@ -3114,7 +3141,7 @@ LogicalResult ConvertToHloModule::Lower( if (!is_entry_function || !has_ret_shardings) continue; xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); - StatusOr reshape = + absl::StatusOr reshape = ReshapeWithCorrectRepresentationAndSharding( builder, returns[index], return_shape, options_.layout_preference_fn, options_.shape_representation_fn, @@ -3228,6 +3255,15 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { any_arg_replicated |= entry_args_same_across_replicas.back(); // Pass the alias info to the builder so that it will build the alias info // into the resulting HloModule. + auto buffer_donor = + f.getArgAttrOfType(i, "jax.buffer_donor"); + if (buffer_donor) { + if (use_tuple_args_) { + builder.AddBufferDonor(/*param_number=*/0, /*param_index=*/{i}); + } else { + builder.AddBufferDonor(/*param_number=*/i, /*param_index=*/{}); + } + } auto aliasing_output = f.getArgAttrOfType(i, "tf.aliasing_output"); if (!aliasing_output) continue; @@ -3325,23 +3361,29 @@ LogicalResult ConvertToHloModule::SetEntryTupleShardings( Block* block, xla::XlaBuilder* builder, llvm::ArrayRef> arg_shardings, llvm::SmallVectorImpl* arg_shapes) { - if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { + if (!arg_shardings.empty() && SomeOptionalShardingsAreSet(arg_shardings)) { xla::OpSharding sharding; sharding.set_type(xla::OpSharding::TUPLE); for (const auto& arg_sharding : llvm::enumerate(arg_shardings)) { - auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value()); - if (!hlo_sharding.ok()) - return block->getParentOp()->emitError() - << hlo_sharding.status().message(); - - auto status = RewriteLayoutWithShardedShape( - hlo_sharding.value(), /*use_fast_memory=*/false, - options_.layout_preference_fn, options_.shape_representation_fn, - &(*arg_shapes)[arg_sharding.index()]); - if (!status.ok()) - return block->getParentOp()->emitError() << status.message(); - - *sharding.add_tuple_shardings() = *arg_sharding.value(); + if (arg_sharding.value().has_value()) { + auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value()); + if (!hlo_sharding.ok()) + return block->getParentOp()->emitError() + << hlo_sharding.status().message(); + + auto status = RewriteLayoutWithShardedShape( + hlo_sharding.value(), /*use_fast_memory=*/false, + options_.layout_preference_fn, options_.shape_representation_fn, + &(*arg_shapes)[arg_sharding.index()]); + if (!status.ok()) + return block->getParentOp()->emitError() << status.message(); + + *sharding.add_tuple_shardings() = *arg_sharding.value(); + } else { + xla::OpSharding fallback_sharding; + fallback_sharding.set_type(xla::OpSharding::REPLICATED); + *sharding.add_tuple_shardings() = fallback_sharding; + } } builder->SetSharding(sharding); @@ -3382,13 +3424,15 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( builder->ClearSharding(); bool set_tuple_element_sharding = - !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings); + !arg_shardings.empty() && SomeOptionalShardingsAreSet(arg_shardings); for (BlockArgument& arg : block->getArguments()) { - if (set_tuple_element_sharding) + if (set_tuple_element_sharding && + arg_shardings[arg.getArgNumber()].has_value()) { builder->SetSharding(*arg_shardings[arg.getArgNumber()]); + } lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber()); + builder->ClearSharding(); } - builder->ClearSharding(); } else { if (ensure_single_arg) { // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp. @@ -3494,7 +3538,7 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation( } // Runs the PrepareForExport pass on the ModuleOp. -xla::Status PrepareForExport(mlir::ModuleOp module) { +absl::Status PrepareForExport(mlir::ModuleOp module) { bool hasShapeOps = false; module.walk([&](Operation* op) { hasShapeOps |= isa(op->getDialect()); @@ -3513,14 +3557,15 @@ xla::Status PrepareForExport(mlir::ModuleOp module) { } if (failed(pm.run(module))) return tsl::errors::Internal("Unable to prepare for XLA export"); - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace -xla::Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple, - MlirToHloConversionOptions options) { +absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, + xla::HloProto* hlo_proto, bool use_tuple_args, + bool return_tuple, + MlirToHloConversionOptions options) { // To support the ongoing migration of XLA's compiler interface from MHLO // to StableHLO, we've inserted this fallback to provide support for backends // which are converting incoming ModuleOps directly to HLO. @@ -3590,13 +3635,13 @@ xla::Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, converter.BuildStackFramesIndexProto(); hlo_module.mutable_stack_frame_index()->Swap(&stack_frame_index); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); - return ::tsl::OkStatus(); + return absl::OkStatus(); } -xla::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, - llvm::ArrayRef xla_params, - std::vector& returns, - MlirToHloConversionOptions options) { +absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options) { auto module = block.getParentOp()->getParentOfType(); TF_RETURN_IF_ERROR(PrepareForExport(module)); ConvertToHloModule converter(module, builder, @@ -3635,7 +3680,7 @@ xla::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, } } - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace mlir diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h index 5cbdfdd398dab..de1360e2d3164 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -54,18 +54,18 @@ struct MlirToHloConversionOptions { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. -xla::Status ConvertMlirHloToHlo(mlir::ModuleOp module, - ::xla::HloProto* hlo_proto, bool use_tuple_args, - bool return_tuple, - MlirToHloConversionOptions options = {}); +absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, + ::xla::HloProto* hlo_proto, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options = {}); // Transforms a Block into HLO, where the HLO is represented as calls into an // XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. // xla_params are inputs to block. returns are the returned XlaOps. -xla::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, - llvm::ArrayRef xla_params, - std::vector& returns, - MlirToHloConversionOptions options = {}); +absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options = {}); } // namespace mlir diff --git a/xla/translate/mhlo_to_hlo/operator_writer_gen.cc b/xla/translate/mhlo_to_hlo/operator_writer_gen.cc index 511a5b4a72500..5b6ac2104bbb3 100644 --- a/xla/translate/mhlo_to_hlo/operator_writer_gen.cc +++ b/xla/translate/mhlo_to_hlo/operator_writer_gen.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,10 +43,10 @@ static std::string GetDefaultAttrExport( Attribute attr = named_attr.attr; StringRef storage_type = attr.getStorageType(); // For some attribute types we have a general conversion, so use that. - if (!attr.isEnumAttr() && (storage_type.endswith("BoolAttr") || - storage_type.endswith("FloatAttr") || - storage_type.endswith("IntegerAttr") || - storage_type.endswith("StringAttr"))) { + if (!attr.isEnumAttr() && (storage_type.ends_with("BoolAttr") || + storage_type.ends_with("FloatAttr") || + storage_type.ends_with("IntegerAttr") || + storage_type.ends_with("StringAttr"))) { // The return type may contains qualified namespaces. Split to remove them. std::pair splits = attr.getReturnType().rsplit("::"); StringRef symbol = splits.second; diff --git a/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc b/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc index 27de242099776..49fa8fd245b4b 100644 --- a/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc +++ b/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -101,7 +101,8 @@ bool IsFrameNameLocation(mlir::Location location) { isa(cast(location).getChildLoc()); } -int StackFrameIndexBuilder::AddCallStackAndGetFirstFrameId( +StackFrameIndexBuilder::AddStackFrameResult +StackFrameIndexBuilder::AddCallStackAndGetFirstFrameId( const mlir::Location &root_loc) { std::stack locations; mlir::CallSiteLoc call_site; @@ -130,7 +131,16 @@ int StackFrameIndexBuilder::AddCallStackAndGetFirstFrameId( parent_frame_id = AddStackFrameLocation(name_location, parent_frame_id); } - return parent_frame_id; + if (parent_frame_id == StackFrameIndexBuilder::kInvalidIndex) { + return {StackFrameIndexBuilder::kInvalidIndex, "", 0}; + } + + auto stack_frame = indexes_.stack_frames(parent_frame_id - 1); + auto file_location = + indexes_.file_locations(stack_frame.file_location_id() - 1); + return {parent_frame_id, + indexes_.file_names(file_location.file_name_id() - 1), + file_location.line()}; } xla::StackFrameIndexProto StackFrameIndexBuilder::Build() const { diff --git a/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h b/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h index 804795eb83f62..c6c9eb2900ea4 100644 --- a/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h +++ b/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #define XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ #include +#include #include #include @@ -30,7 +31,14 @@ class StackFrameIndexBuilder { xla::StackFrameIndexProto Build() const; - int AddCallStackAndGetFirstFrameId(const mlir::Location &root_loc); + struct AddStackFrameResult { + int last_frame_id; + std::string last_frame_file; + int last_frame_line; + }; + + AddStackFrameResult AddCallStackAndGetFirstFrameId( + const mlir::Location &root_loc); private: int AddStackFrameLocation(const mlir::NameLoc &name_location, diff --git a/xla/translate/mhlo_to_hlo/tests/BUILD b/xla/translate/mhlo_to_hlo/tests/BUILD index 405acd1a2da11..d0bcc19787163 100644 --- a/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/xla/translate/mhlo_to_hlo/tests/BUILD @@ -1,17 +1,51 @@ load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) -glob_lit_tests( +lit_test_suite( name = "all_tests", + srcs = enforce_glob( + [ + "add.mlir", + "case.mlir", + "dynamic.mlir", + "export-with-layouts.mlir", + "export.mlir", + "export_and_check_layouts.mlir", + "export_large_constants.mlir", + "export_replicas.mlir", + "frontend_attributes.mlir", + "fusion.mlir", + "if.mlir", + "input_output_aliasing.mlir", + "int4.mlir", + "layouts_and_names.mlir", + "location_to_op_metadata.mlir", + "location_to_stacktrace.mlir", + "missing_main.mlir", + "module_attributes.mlir", + "multiple_return_tuple.mlir", + "opaque_elements_attr.mlir", + "rng_get_and_update_state.mlir", + "sharding.mlir", + "simple.mlir", + "unsupported_type.mlir", + "while.mlir", + ], + include = [ + "*.mlir", + ], + ), + cfg = "//xla:lit.cfg.py", data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = [ - "mlir", + tools = [ + "//xla/translate:xla-translate", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) @@ -19,9 +53,4 @@ glob_lit_tests( filegroup( name = "test_utilities", testonly = True, - data = [ - "//xla/translate:xla-translate", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ], ) diff --git a/xla/translate/mhlo_to_hlo/tests/dynamic.mlir b/xla/translate/mhlo_to_hlo/tests/dynamic.mlir index bc369324bb586..bb83b10cc3f5f 100644 --- a/xla/translate/mhlo_to_hlo/tests/dynamic.mlir +++ b/xla/translate/mhlo_to_hlo/tests/dynamic.mlir @@ -3,9 +3,9 @@ // CHECK: HloModule main, entry_computation_layout={(s64[<=4,1]{1,0})->s64[1,<=4]{1,0}} func.func @main(%arg0: tensor>) -> tensor<1x?xi64, #mhlo.type_extensions> { %0 = mhlo.constant dense<1> : tensor<1xi32> - %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor>) -> tensor + %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor>) -> tensor %2 = mhlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = "mhlo.concatenate"(%0, %2) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %3 = "mhlo.concatenate"(%0, %2) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> %4 = mhlo.dynamic_reshape %arg0, %3 : (tensor>, tensor<2xi32>) -> tensor<1x?xi64, #mhlo.type_extensions> func.return %4 : tensor<1x?xi64, #mhlo.type_extensions> // CHECK: %[[ARG0:.*]] = s64[<=4,1] parameter(0) diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir index 40efe506d8c1c..bbf29e9daa742 100644 --- a/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -187,6 +187,44 @@ func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, // ----- +// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}} +func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} { + %0:2 = "mhlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) + func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32> +} + +func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) { + %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>> + %1:2 = "mhlo.async_done"(%0) {called_computation = @all_gather_0, execution_thread = "main"} : (!mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>) + return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32> +} + +// ----- + +func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple, tensor<8x16xf32>> { + // CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0) + // CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1) + // CHECK-NEXT: %[[TUPLE:.*]] = (f32[8,2], f32[8,4]) tuple + // CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=0 + // CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=1 + // CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(f32[8,2] %[[TUPLE_ARG0]], f32[8,4] %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1} + %0:2 = "mhlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x8xf32>, tensor<8x16xf32>) + %1 = mhlo.tuple %0#0, %0#1 {xla_shape = "(f32[8,8]{0,1}, f32[8,16]{0,1})"} : tuple, tensor<8x16xf32>> + return %1 : tuple, tensor<8x16xf32>> +} + +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { %0 = "mhlo.all_reduce"(%arg0) ({ @@ -438,7 +476,7 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { func.func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK: [[ARG:%.*]] = s32[4] parameter(0) // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>}> : (tensor<4xi32>) -> tensor<1x2x3x4xi32> func.return %0 : tensor<1x2x3x4xi32> } @@ -547,6 +585,20 @@ func.func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, // ----- +// CHECK: HloModule +func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %0 = "mhlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + func.return %0 : tensor<128x32xf32> +} +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = f32[128,32] parameter(0) +// CHECK: ROOT [[RESULT:%.*]] = f32[128,32] collective-broadcast(f32[128,32] [[ARG]]), channel_id=1 +// CHECK-SAME{LITERAL}: replica_groups={{0,1},{2,3}} +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { %0 = "mhlo.collective_permute"(%arg0) { @@ -1698,11 +1750,26 @@ func.func @main(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>) -> tensor<2x2x // ----- +// CHECK: HloModule +func.func @main(%arg0: tensor<10x16xbf16>, %arg1: tensor<32x20xbf16>, %meta: tensor<10x2xui16>) -> tensor<10x20xf32> { + // CHECK: dot(bf16[10,16] %{{.*}}, bf16[32,20] %{{.*}}, u16[10,2] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 + %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { + lhs_sparsity = #mhlo.sparsity, + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + >, + precision_config = []} : (tensor<10x16xbf16>, tensor<32x20xbf16>, tensor<10x2xui16>) -> tensor<10x20xf32> + func.return %0 : tensor<10x20xf32> +} + +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // Simple einsum is lowered to HLO dot op. // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} - %0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> + %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ab,bc->ac"}> : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> func.return %0 : tensor<3x5xi32> } @@ -1710,7 +1777,7 @@ func.func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi3 // CHECK: HloModule func.func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { - %0 = "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo} : (tensor<3x9xf32>) -> tensor<3x5xcomplex> + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<3x9xf32>) -> tensor<3x5xcomplex> func.return %0 : tensor<3x5xcomplex> } @@ -1731,7 +1798,7 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tens // CHECK-SAME: index_vector_dim=1 // CHECK-SAME: slice_sizes={1,1,300} // CHECK-SAME: indices_are_sorted=true - %0 = "mhlo.gather"(%arg0, %arg1) { + %0 = "mhlo.gather"(%arg0, %arg1) <{ dimension_numbers = #mhlo.gather< collapsed_slice_dims = [0, 1], index_vector_dim = 1, @@ -1740,7 +1807,7 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tens >, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> - } : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + }> : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> func.return %0 : tensor<10x300xf32> } @@ -1748,8 +1815,8 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tens // CHECK: HloModule func.func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { - %0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i64} : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> - %1 = "mhlo.get_dimension_size"(%0) {dimension = 1 : i64} : (tensor<4x2xf32>) -> tensor + %0 = "mhlo.set_dimension_size"(%arg, %size) <{dimension = 1 : i64}> : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> + %1 = "mhlo.get_dimension_size"(%0) <{dimension = 1 : i64}> : (tensor<4x2xf32>) -> tensor func.return %1 : tensor } @@ -1765,7 +1832,7 @@ func.func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { // CHECK: HloModule func.func @main(%arg: tensor>) -> tensor<8x4xf32> { %size = mhlo.constant dense<8> : tensor - %1 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<8x4xf32> + %1 = "mhlo.set_dimension_size"(%arg, %size) <{dimension = 0 : i64}> : (tensor>, tensor) -> tensor<8x4xf32> func.return %1 : tensor<8x4xf32> } @@ -1778,7 +1845,7 @@ func.func @main(%arg: tensor>) - // CHECK: HloModule func.func @main(%arg0: tuple, tensor>) -> tensor { - %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + %0 = "mhlo.get_tuple_element"(%arg0) <{index = 0 : i32}> : (tuple, tensor>) -> tensor func.return %0 : tensor } @@ -1790,7 +1857,7 @@ func.func @main(%arg0: tuple, tensor>) -> tensor { // CHECK: HloModule func.func @main(%arg0: !mhlo.token) -> tuple, tensor>, !mhlo.token> { - %0:3 = "mhlo.infeed"(%arg0) {infeed_config = "foobar", layout=[[0, 1], [0]]} : (!mhlo.token) -> (tensor<3x3xi32>, tensor, !mhlo.token) + %0:3 = "mhlo.infeed"(%arg0) <{infeed_config = "foobar", layout=[[0, 1], [0]]}> : (!mhlo.token) -> (tensor<3x3xi32>, tensor, !mhlo.token) %1 = "mhlo.tuple"(%0#0, %0#1) : (tensor<3x3xi32>, tensor) -> tuple, tensor> %2 = "mhlo.tuple"(%1, %0#2) : (tuple, tensor>, !mhlo.token) -> tuple, tensor>, !mhlo.token> @@ -1809,7 +1876,7 @@ func.func @main(%arg0: !mhlo.token) -> tuple, tensor>, // CHECK: HloModule func.func @main(%arg0: !mhlo.token) -> tensor<3x3xi32> { - %0:2 = "mhlo.infeed"(%arg0) {infeed_config = "foobar", layout=[[0,1]]} : (!mhlo.token) -> (tensor<3x3xi32>, !mhlo.token) + %0:2 = "mhlo.infeed"(%arg0) <{infeed_config = "foobar", layout=[[0,1]]}> : (!mhlo.token) -> (tensor<3x3xi32>, !mhlo.token) func.return %0#0 : tensor<3x3xi32> } @@ -1825,7 +1892,7 @@ func.func @main(%arg0: !mhlo.token) -> tensor<3x3xi32> { // CHECK: HloModule func.func @main(%arg0: !mhlo.token) -> !mhlo.token { - %0 = "mhlo.infeed"(%arg0) {infeed_config = "foobar", layout = [], xla_shape = "((), token[])"} : (!mhlo.token) -> !mhlo.token + %0 = "mhlo.infeed"(%arg0) <{infeed_config = "foobar", layout = [], xla_shape = "((), token[])"}> : (!mhlo.token) -> !mhlo.token func.return %0 : !mhlo.token } @@ -1978,7 +2045,7 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // CHECK: HloModule func.func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { - %0 = "mhlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> + %0 = "mhlo.pad"(%arg, %pad) <{edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>, tensor) -> tensor<13x19xf32> func.return %0 : tensor<13x19xf32> } @@ -2106,11 +2173,9 @@ func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) { // CHECK-SAME: sharding={ // CHECK-SAME: {maximal device=0}, {maximal device=0} // CHECK-SAME: } -// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0 -// CHECK-NOT: sharding= -// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1 -// CHECK-NOT: sharding= -// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) +// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0} +// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0} +// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) // ----- @@ -2199,7 +2264,7 @@ func.func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { // CHECK: HloModule func.func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2215,7 +2280,7 @@ func.func @main() -> tensor<2x3x5xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = mhlo.constant dense<1.000000e+00> : tensor %2 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %3 = "mhlo.rng"(%0, %1, %2) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %3 : tensor<2x3x5xf32> } @@ -2457,7 +2522,7 @@ func.func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { // CHECK: HloModule func.func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { - %0 = "mhlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + %0 = "mhlo.slice"(%arg) <{start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>}> : (tensor<3x4xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -2470,7 +2535,7 @@ func.func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { // CHECK: HloModule func.func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor) -> tensor<1x4xi32> { - %0 = "mhlo.dynamic_slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + %0 = "mhlo.dynamic_slice"(%arg, %start1, %start2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -2488,7 +2553,7 @@ func.func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0) // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2} - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0 : tensor<2x1x4x3xi32> } @@ -2747,7 +2812,7 @@ func.func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> // CHECK: [[GTE1:%.*]] = u32[2,2] get-tuple-element((u64[3], u32[2,2]) [[RNG]]), index=1 // CHECK: ROOT // CHECK-SAME: [[RES:%.*]] = (u64[3], u32[2,2]) tuple(u64[3] [[GTE0]], u32[2,2] [[GTE1]]) - %0:2 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) + %0:2 = "mhlo.rng_bit_generator"(%arg) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) %1 = "mhlo.tuple"(%0#0, %0#1) : (tensor<3xui64>, tensor<2x2xui32>) -> tuple, tensor<2x2xui32>> func.return %1 : tuple, tensor<2x2xui32>> } @@ -2788,9 +2853,12 @@ func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> { func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) { // CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0) // CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1) -// CHECK: %[[ARGS:.*]] = (f32[4,4], f32[3,4]) tuple(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]) -// CHECK: %[[RESULT:.*]] = (f32[4,4], f32[3,4]) opt-barrier((f32[4,4], f32[3,4]) %[[ARGS]]) - %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) +// CHECK: %[[ARGS:.*]] = (f32[4,4], f32[3,4]) tuple(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} +// CHECK: %[[OPT:.*]] = (f32[4,4], f32[3,4]) opt-barrier((f32[4,4], f32[3,4]) %[[ARGS]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} +// CHECK: %[[GTE0:.*]] = f32[4,4] get-tuple-element((f32[4,4], f32[3,4]) %[[OPT]]), index=0, sharding={replicated} +// CHECK: %[[GTE1:.*]] = f32[3,4] get-tuple-element((f32[4,4], f32[3,4]) %[[OPT]]), index=1, sharding={devices=[1,2]<=[2]} +// CHECK: ROOT %[[ROOT:.*]] = (f32[4,4], f32[3,4]) tuple(f32[4,4] %[[GTE0]], f32[3,4] %[[GTE1]]) + %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) {mhlo.sharding = "{{replicated}, {devices=[1,2]<=[2]}}"} : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xf32> } @@ -2869,15 +2937,13 @@ func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> // CHECK: ENTRY func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = (f32[10], f32[20], s32[]) async-start(f32[10] %[[ARG0]]) - // CHECK-SAME: calls=[[CALLED_COMPUTATION]] - %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle, tensor<20xf32>, tensor> - // CHECK: %[[UPDATE:.*]] = (f32[10], f32[20], s32[]) async-update((f32[10], f32[20], s32[]) %[[START]]) + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]) // CHECK-SAME: calls=[[CALLED_COMPUTATION]] - %1 = "mhlo.async_update"(%0) {called_computation = @AsyncOp, execution_thread = "thread"} : (!mhlo.async_bundle, tensor<20xf32>, tensor>) -> !mhlo.async_bundle, tensor<20xf32>, tensor> - // CHECK: ROOT %{{.*}} = (f32[20]) async-done((f32[10], f32[20], s32[]) %[[UPDATE]]) - // CHECK-SAME: calls=[[CALLED_COMPUTATION]] - %2 = "mhlo.async_done"(%1) {called_computation = @AsyncOp, execution_thread = "thread"} : (!mhlo.async_bundle, tensor<20xf32>, tensor>) -> tensor<20xf32> + %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) + %1 = "mhlo.async_update"(%0) {called_computation = @AsyncOp, execution_thread = "thread"} : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) + %2 = "mhlo.async_done"(%1) {called_computation = @AsyncOp, execution_thread = "thread"} : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> return %2 : tensor<20xf32> } @@ -2896,14 +2962,14 @@ func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> // CHECK: ENTRY func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = (f32[10], f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_group_id=1, async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], - // CHECK: %[[UPDATE:.*]] = (f32[10], f32[20], s32[]) async-update((f32[10], f32[20], s32[]) %[[START]]), async_group_id=1, async_execution_thread="thread", calls=[[CALLED_COMPUTATION]] + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) // CHECK: ROOT - // CHECK-SAME: (f32[20]) async-done((f32[10], f32[20], s32[]) %[[UPDATE]]), async_group_id=1, async_execution_thread="thread", calls=[[CALLED_COMPUTATION]] + // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) - %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread", group_id = 1} : (tensor<10xf32>) -> !mhlo.async_bundle, tensor<20xf32>, tensor> - %1 = "mhlo.async_update"(%0) {called_computation = @AsyncOp, execution_thread="thread", group_id=1} : (!mhlo.async_bundle, tensor<20xf32>, tensor>) -> !mhlo.async_bundle, tensor<20xf32>, tensor> - %2 = "mhlo.async_done"(%1) {called_computation = @AsyncOp, execution_thread="thread", group_id=1} : (!mhlo.async_bundle, tensor<20xf32>, tensor>) -> tensor<20xf32> + %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + %1 = "mhlo.async_update"(%0) {called_computation = @AsyncOp, execution_thread="thread"} : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + %2 = "mhlo.async_done"(%1) {called_computation = @AsyncOp, execution_thread="thread"} : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> return %2 : tensor<20xf32> } @@ -3072,3 +3138,166 @@ func.func @main(%operand: tensor) -> tensor { // CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) // CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} // CHECK-NEXT: } + +// ----- + +// reduce multiple implicit captures test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: f32[] constant(1) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: {{.*}} reduce{{.*}} to_apply=[[REG0]] +// CHECK: ROOT +func.func @main(%arg0: tensor<2x2xf32>) -> tuple> { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [0, 1] : (tensor<2x2xf32>, tensor) -> tensor + reducer(%arg1: tensor, %arg2: tensor) { + %5 = mhlo.compare NE, %arg1, %1 : (tensor, tensor) -> tensor + %6 = mhlo.compare NE, %arg2, %1 : (tensor, tensor) -> tensor + %7 = mhlo.or %5, %6 : tensor + %8 = mhlo.select %7, %0, %1 : tensor, tensor + mhlo.return %8 : tensor + } + %3 = mhlo.compare NE, %2, %1 : (tensor, tensor) -> tensor + %4 = mhlo.tuple %3 {xla_shape = "(pred[])"} : tuple> + return %4 : tuple> +} + +// ----- + +// all_reduce implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} all-reduce{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor) -> tensor { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = mhlo.add %arg1, %c : tensor + mhlo.return %1 : tensor + }) {replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// reduce_scatter implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} reduce-scatter{{.*}} to_apply=[[REG0]] +func.func @main(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %c : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + +// reduce_window implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} reduce-window{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + %c = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.maximum %arg2, %c : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + return %0 : tensor<2x16x30x7xf32> + } + +// ----- + +// Scatter implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: s32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} scatter{{.*}} to_apply=[[REG0]] +func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, + %arg2: tensor<1xi32>) -> tensor<3xi32> { + %c = mhlo.constant dense<0> : tensor + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %x = mhlo.add %arg4, %c : tensor + "mhlo.return"(%x) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> + func.return %0 : tensor<3xi32> +} + +// ----- + +// select_and_scatter implicit capture test +// CHECK: HloModule +// CHECK: [[SEL_REG:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: [[SCAT_REG:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: ROOT {{.*}} select-and-scatter{{.*}} select=[[SEL_REG]], scatter=[[SCAT_REG]] +func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + %c1 = mhlo.constant dense<0.0> : tensor + %c2 = mhlo.constant dense<0.0> : tensor + %0 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = mhlo.compare GE, %arg3, %c1, TOTALORDER : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = mhlo.add %arg4, %c2 : tensor + mhlo.return %1 : tensor + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + return %0 : tensor<10x24x24x64xf32> + } + +// ----- + +// sort implicit capture test +// CHECK: HloModule +// CHECK: [[REG0:%region.*]] ({{.*}} { +// CHECK: f32[] constant(0) +// CHECK: ROOT +// CHECK: ENTRY +// CHECK: {{.*}} sort{{.*}} to_apply=[[REG0]] +// CHECK: ROOT +func.func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + %c = mhlo.constant dense<0.0> : tensor + %0:2 = "mhlo.sort"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "mhlo.compare"(%arg0, %c) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + func.return +} diff --git a/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir b/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir index ec40749815ac4..c1f9de94661eb 100644 --- a/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir +++ b/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir @@ -63,3 +63,36 @@ func.func @main(%arg0: !mhlo.token) -> !mhlo.token { // CHECK: after-all // CHECK-SAME: metadata={op_name="name(anothername)" source_file="file_name" source_line=2} + +// ----- + +// CHECK-LABEL: %main +func.func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(fused["fused/location/file", "source.txt":42:5]) + func.return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="fused/location/file" source_file="source.txt" source_line=42} + +// ----- + +// CHECK-LABEL: %main +func.func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(fused["name1", fused["nested_fusion":5:42, "name2"]]) + func.return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="name1;name2" source_file="nested_fusion" source_line=5} + +// ----- + +// CHECK-LABEL: %main +func.func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(fused["multiple_sources", "source1":1:2, "source2":3:4]) + func.return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="multiple_sources" source_file="source2" source_line=3} diff --git a/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir b/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir index 5eb7cb1c18de0..53e6e3b8fdcb6 100644 --- a/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir +++ b/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir @@ -33,6 +33,8 @@ module @main attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = f // CHECK-NEXT: } // CHECK-NEXT: metadata { // CHECK-NEXT: op_name: "name(anothername) +// CHECK-NEXT: source_file: "file_name" +// CHECK-NEXT: source_line: 2 // CHECK-NEXT: stack_frame_id: 1 // CHECK-NEXT: } @@ -71,6 +73,8 @@ module @main attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = f // CHECK-NEXT: } // CHECK-NEXT: metadata { // CHECK-NEXT: op_name: "name(anothername) +// CHECK-NEXT: source_file: "file_name_2" +// CHECK-NEXT: source_line: 3 // CHECK-NEXT: stack_frame_id: 2 // CHECK-NEXT: } diff --git a/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir b/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir index b8e5612c81e0f..878285fd21d4d 100644 --- a/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir +++ b/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir @@ -9,6 +9,6 @@ func.func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) { // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} - %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>}> : (tensor<4xi32>) -> tensor<1x2x3x4xi32> func.return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32> } diff --git a/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/xla/translate/mhlo_to_hlo/tests/sharding.mlir index ed82732606e17..e6109118d5a30 100644 --- a/xla/translate/mhlo_to_hlo/tests/sharding.mlir +++ b/xla/translate/mhlo_to_hlo/tests/sharding.mlir @@ -4,9 +4,9 @@ func.func public @main(%arg0: tensor {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "\08\03\1A\01\02\22\02\00\01"}) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) { // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1} // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated} - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4xf32> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> %1 = mhlo.multiply %arg1, %0 : tensor<4xf32> - %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32> + %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4x4xf32> // CHECK: ROOT {{.*}}, sharding={devices=[2,1]0,1} func.return %2 : tensor<4x4xf32> } @@ -31,9 +31,39 @@ func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\ // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[4,4]) -> (f32[4,4], f32[4,4]) func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\03\02\01\02\22\04\00\01\02\03B\01\00"}, tensor<4x4xf32>) { // CHECK-NEXT: %Arg_0.1 = f32[4,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1) + // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} // CHECK-NEXT: [[RESHAPE_1:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1) + // CHECK-NOT: sharding // CHECK-NEXT: ROOT {{%.*}} = (f32[4,4], f32[4,4]) tuple(f32[4,4] [[RESHAPE_0]], f32[4,4] [[RESHAPE_1]]) // CHECK-SAME: sharding={{\{}}{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {replicated}} return %arg0, %arg0 : tensor<4x4xf32>, tensor<4x4xf32> } + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[12,24,36] +func.func @main() -> (tensor<12x24x36xf32>) { + // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) + // CHECK-NEXT: %broadcast.2 = f32[12,24,36] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: ROOT %add.3 = f32[12,24,36] add(f32[12,24,36] %broadcast.2, f32[12,24,36] %broadcast.2) + %0 = mhlo.constant {mhlo.sharding = "{devices=[1,2,1]0,1}"} dense<3.1415926> : tensor<12x24x36xf32> + %1 = mhlo.add %0, %0 : tensor<12x24x36xf32> + return %1 : tensor<12x24x36xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) +func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{devices=[2,16]<=[32] last_tile_dim_replicate}"}, tensor<512x4xui32> {mhlo.sharding = "{devices=[4,8]<=[32]}"}) { + // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {devices=[8,4]<=[32]}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) + // CHECK-NEXT: %reshape.6 = u64[2] reshape(u64[2] %add.5) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={devices=[8,4]<=[32]} + // CHECK-NEXT: %reshape.7 = u32[512,4] reshape(u32[512,4] %get-tuple-element.4) + // CHECK-NEXT: ROOT %tuple.8 = (u64[2], u32[512,4]) tuple(u64[2] %reshape.6, u32[512,4] %reshape.7), sharding={{\{}}{devices=[2,16]<=[32] last_tile_dim_replicate}, {devices=[4,8]<=[32]}} + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{{replicated}, {devices=[8,4]<=[32]}}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) + %0 = mhlo.add %output_state, %output_state : tensor<2xui64> + return %0, %output : tensor<2xui64>, tensor<512x4xui32> +} diff --git a/xla/translate/mhlo_to_hlo/tests/while.mlir b/xla/translate/mhlo_to_hlo/tests/while.mlir index b5d5c98217f03..fddcb2bc61d7e 100644 --- a/xla/translate/mhlo_to_hlo/tests/while.mlir +++ b/xla/translate/mhlo_to_hlo/tests/while.mlir @@ -143,7 +143,7 @@ func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { "mhlo.return"(%7) : (tensor) -> () }, { ^bb0(%arg1: tensor<1xi32>, %arg2: tensor<2xi32>, %arg3: tensor<1xf32>, %arg4: tensor<3xf32>): - %4 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<3xf32> + %4 = "mhlo.broadcast_in_dim"(%arg3) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<3xf32> %5 = mhlo.add %arg4, %4 : tensor<3xf32> "mhlo.return"(%arg1, %arg2, %arg3, %5) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> () }) : (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) -> (tensor<1xi32>, tensor<2xi32>, tensor<1xf32>, tensor<3xf32>) @@ -259,7 +259,7 @@ func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { ^bb0(%arg1: tensor<3x3xf32>): %2 = mhlo.constant dense : tensor %3 = mhlo.constant dense<2.000000e+00> : tensor - %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> + %4 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<3x3xf32> %5 = mhlo.add %arg1, %4 : tensor<3x3xf32> "mhlo.return"(%5) : (tensor<3x3xf32>) -> () }) : (tensor<3x3xf32>) -> tensor<3x3xf32> diff --git a/xla/translate/mhlo_to_hlo/translate.cc b/xla/translate/mhlo_to_hlo/translate.cc index 18152f0fc3813..409bab6c722ce 100644 --- a/xla/translate/mhlo_to_hlo/translate.cc +++ b/xla/translate/mhlo_to_hlo/translate.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, return mlir::success(); } -StatusOr> HloModuleFromProto( +absl::StatusOr> HloModuleFromProto( const HloProto& hlo_proto) { const HloModuleProto& module_proto = hlo_proto.hlo_module(); TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, diff --git a/xla/translate/mhlo_to_hlo/translate.h b/xla/translate/mhlo_to_hlo/translate.h index 358691b2dda68..a199b92b5ca78 100644 --- a/xla/translate/mhlo_to_hlo/translate.h +++ b/xla/translate/mhlo_to_hlo/translate.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/mhlo_to_hlo/translate_registration.cc b/xla/translate/mhlo_to_hlo/translate_registration.cc index 27bd4a741b4f1..8500e1f6f2a58 100644 --- a/xla/translate/mhlo_to_hlo/translate_registration.cc +++ b/xla/translate/mhlo_to_hlo/translate_registration.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/translate/mhlo_to_hlo/type_to_shape.cc b/xla/translate/mhlo_to_hlo/type_to_shape.cc index e10d8b7cff00f..235115596a076 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" @@ -46,40 +47,6 @@ using xla::ShapeUtil; namespace xla { -PrimitiveType TypeToPrimitiveType(mlir::Type type) { - if (type.isFloat8E5M2()) { - return PrimitiveType::F8E5M2; - } else if (type.isFloat8E4M3FN()) { - return PrimitiveType::F8E4M3FN; - } else if (type.isFloat8E4M3B11FNUZ()) { - return PrimitiveType::F8E4M3B11FNUZ; - } else if (type.isFloat8E4M3FNUZ()) { - return PrimitiveType::F8E4M3FNUZ; - } else if (type.isFloat8E5M2FNUZ()) { - return PrimitiveType::F8E5M2FNUZ; - } else if (type.isBF16()) { - return PrimitiveType::BF16; - } else if (type.isF16()) { - return PrimitiveType::F16; - } else if (type.isF32()) { - return PrimitiveType::F32; - } else if (type.isF64()) { - return PrimitiveType::F64; - } else if (auto complex_type = type.dyn_cast()) { - mlir::Type element_ty = complex_type.getElementType(); - return primitive_util::ComplexType(TypeToPrimitiveType(element_ty)); - } else if (auto integer_type = type.dyn_cast()) { - bool is_unsigned = integer_type.isUnsigned(); - if (integer_type.getWidth() == 1) { - return PrimitiveType::PRED; - } - return is_unsigned ? primitive_util::UnsignedIntegralTypeForBitWidth( - integer_type.getWidth()) - : primitive_util::SignedIntegralTypeForBitWidth( - integer_type.getWidth()); - } - return PrimitiveType::PRIMITIVE_TYPE_INVALID; -} std::optional> ConvertDimLevelType( mlir::sparse_tensor::LevelType lt) { @@ -104,7 +71,7 @@ std::optional> ConvertDimLevelType( } Shape TypeToShape(mlir::Type type) { - PrimitiveType ptype = TypeToPrimitiveType(type); + PrimitiveType ptype = ConvertMlirTypeToPrimitiveType(type); if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(ptype, {}); @@ -117,7 +84,7 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector span(v.getShape().begin(), v.getShape().end()); mlir::Type element_type = v.getElementType(); - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + PrimitiveType primitive_type = ConvertMlirTypeToPrimitiveType(element_type); if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(primitive_type, span); } else if (auto m = type.dyn_cast()) { @@ -130,7 +97,7 @@ Shape TypeToShape(mlir::Type type) { element_type = v.getElementType(); span.insert(span.end(), v.getShape().begin(), v.getShape().end()); } - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + PrimitiveType primitive_type = ConvertMlirTypeToPrimitiveType(element_type); if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {}; // For the primitive type case, the shape of the memref is similar to the // vector type case (i.e., it is, modulo the layout, the same dimensions @@ -190,7 +157,8 @@ Shape TypeToShape(mlir::Type type) { } } - PrimitiveType primitive_type = TypeToPrimitiveType(t.getElementType()); + PrimitiveType primitive_type = + ConvertMlirTypeToPrimitiveType(t.getElementType()); if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {}; if (auto sparse = mlir::sparse_tensor::getSparseTensorEncoding(type)) { diff --git a/xla/translate/mhlo_to_hlo/type_to_shape.h b/xla/translate/mhlo_to_hlo/type_to_shape.h index 39bfb6c5b4091..a56d8ff2347d1 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape.h +++ b/xla/translate/mhlo_to_hlo/type_to_shape.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,10 +26,6 @@ namespace xla { // Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. Shape TypeToShape(mlir::Type type); -// Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a -// primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID. -PrimitiveType TypeToPrimitiveType(mlir::Type type); - } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ diff --git a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc index e38dbc355d042..36ef41a55ac78 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -63,16 +63,6 @@ inline ::testing::PolymorphicMatcher EqualsProto( return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); } -TEST(TypeToShapeTest, ConvertPrimitiveTypes) { - MLIRContext context; - Builder b(&context); - - EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32); - EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(1)), PrimitiveType::PRED); - EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(17)), - PrimitiveType::PRIMITIVE_TYPE_INVALID); -} - TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { MLIRContext context; Builder b(&context); @@ -164,7 +154,7 @@ TEST(TypeToShapeTest, ConvertMemRefToShape) { MLIRContext context; mlir::Builder builder(&context); - StatusOr mlir_type = + absl::StatusOr mlir_type = ConvertShapeToType(shape, builder); ASSERT_TRUE(mlir_type.ok()); mlir::Type type = std::move(mlir_type).value(); @@ -181,7 +171,7 @@ TEST(TypeToShapeTest, ConvertMemRefToShape2) { MLIRContext context; mlir::Builder builder(&context); - StatusOr mlir_type = + absl::StatusOr mlir_type = ConvertShapeToType(shape, builder); ASSERT_TRUE(mlir_type.ok()); mlir::Type type = std::move(mlir_type).value(); diff --git a/xla/translate/mhlo_to_lhlo_with_xla/BUILD b/xla/translate/mhlo_to_lhlo_with_xla/BUILD deleted file mode 100644 index e5f9fff1d9162..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/BUILD +++ /dev/null @@ -1,128 +0,0 @@ -load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla:xla.bzl", "xla_cc_binary") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("@bazel_skylib//rules:build_test.bzl", "build_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -cc_library( - name = "mhlo_to_lhlo_with_xla", - srcs = ["mhlo_to_lhlo_with_xla.cc"], - hdrs = ["mhlo_to_lhlo_with_xla.h"], - deps = [ - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:statusor", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir/utils:error_util", - "//xla/mlir_hlo:lhlo", - "//xla/mlir_hlo:lhlo_gpu", - "//xla/service:backend", - "//xla/service:buffer_assignment", - "//xla/service:hlo_parser", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:matmul_utils", - "//xla/service/llvm_ir:buffer_assignment_util", - "//xla/service/llvm_ir:llvm_util", - "//xla/translate/hlo_to_mhlo:attribute_importer", - "//xla/translate/hlo_to_mhlo:hlo_module_importer", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/translate/mhlo_to_hlo:type_to_shape", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "translate_registration", - testonly = True, - srcs = ["translate_registration.cc"], - deps = [ - ":mhlo_to_lhlo_with_xla", - "@llvm-project//mlir:TranslateLib", - ], - alwayslink = 1, -) - -build_test( - name = "xla-translate-opt_build_test", - targets = [ - ":xla-translate-opt", - ], -) - -xla_cc_binary( - name = "xla-translate-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:cpu_plugin", - "//xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@stablehlo//:register", - "@tsl//tsl/platform:platform_port", - ], -) - -build_test( - name = "xla-translate-gpu-opt_build_test", - targets = [ - ":xla-translate-gpu-opt", - ], -) - -xla_cc_binary( - name = "xla-translate-gpu-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:gpu_plugin", - "//xla/stream_executor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@stablehlo//:register", - "@tsl//tsl/platform:platform_port", - ] + if_cuda(["//xla/stream_executor/cuda:cublas_plugin"]) + if_rocm([ - "//xla/stream_executor/rocm:rocblas_plugin", - ]), -) diff --git a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc deleted file mode 100644 index 53289bcf35742..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ /dev/null @@ -1,2562 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/cleanup/cleanup.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/types/optional.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/AsmParser/AsmParser.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/IR/Verifier.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "xla/debug_options_flags.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir/utils/error_util.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/backend.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/hlo_parser.h" -#include "xla/service/llvm_ir/buffer_assignment_util.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/util.h" -#include "xla/window_util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -using xla::BufferAllocation; -using xla::BufferAssignment; -using xla::HloComputation; -using xla::HloCustomCallInstruction; -using xla::HloInfeedInstruction; -using xla::HloInstruction; -using xla::HloModule; -using xla::HloModuleProto; -using xla::HloOutfeedInstruction; -using xla::HloProto; -using xla::Shape; - -namespace mlir { -namespace { - -absl::string_view StringRefToView(llvm::StringRef ref) { - return {ref.data(), ref.size()}; -} - -tsl::StatusOr> HloModuleFromProto( - const HloProto& hlo_proto) { - const HloModuleProto& module_proto = hlo_proto.hlo_module(); - TF_ASSIGN_OR_RETURN(const xla::HloModuleConfig module_config, - HloModule::CreateModuleConfigFromProto( - module_proto, xla::GetDebugOptionsFromFlags())); - return HloModule::CreateFromProto(module_proto, module_config); -} - -bool IsSyncCollective(const HloInstruction* instr) { - auto backend_config = - instr->backend_config().value(); - return backend_config.is_sync(); -} - -bool NoParallelCustomCallCollective(const HloInstruction* instr) { - auto backend_config = - instr->backend_config().value(); - return backend_config.no_parallel_custom_call(); -} - -// Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the -// given platform. -tsl::Status ConvertHloToLmhlo(std::unique_ptr hlo_module, - ModuleOp module, StringRef platform_name) { - auto platform = xla::se::MultiPlatformManager::PlatformWithName( - StringRefToView(platform_name)); - if (!platform.ok()) { - std::string error_msg; - llvm::raw_string_ostream os(error_msg); - os << "failed to get platform: " << platform.status().ToString() - << " (available Platform: "; - std::vector available_platforms; - (void)xla::se::MultiPlatformManager::PlatformsWithFilter( - [&](const stream_executor::Platform* p) { - available_platforms.push_back(p->Name()); - return false; - }); - llvm::interleaveComma(available_platforms, os); - os << ")"; - return tsl::errors::InvalidArgument("%s", os.str().c_str()); - } - - xla::BackendOptions backend_options; - backend_options.set_platform(platform.value()); - auto backend_or_err = xla::Backend::CreateBackend(backend_options); - TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(), - "failed to create XLA Backend "); - auto backend = std::move(backend_or_err.value()); - - tsl::StatusOr> assignment = - backend->compiler()->AssignBuffers(hlo_module.get(), - backend->default_stream_executor()); - TF_RETURN_WITH_CONTEXT_IF_ERROR(assignment.status(), - "running XLA buffer assigment"); - - // Clear the module before populating it back with the result of the - // conversion. - module.getBody()->clear(); - OpBuilder builder(module); - - std::vector ordered_allocations; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - HloToLhloModule(**assignment, *hlo_module, module, &ordered_allocations), - "converting HLO to LHLO"); - - return ::tsl::OkStatus(); -} - -} // namespace - -// Creates MLIR operands corresponding to operands and results of the XLA HLO -// instruction. If `num_operands` is valid, then only the first `num_operands` -// operands of the HLO instruction will be considered. -tsl::Status LhloDialectEmitter::CreateOperands( - const HloInstruction* instr, std::optional num_operands, - TokenLoweringMode token_mode, llvm::SmallVectorImpl& operands, - size_t& num_arguments, size_t& num_results) { - if (num_operands.value_or(0) > instr->operand_count()) - return tsl::errors::InvalidArgument( - "num_operands must be <= operand count"); - for (int64_t i = 0; i < num_operands.value_or(instr->operand_count()); ++i) { - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands, - /*result_subset=*/{}, token_mode)); - } - num_arguments = operands.size(); - TF_RETURN_IF_ERROR( - GetOrCreateView(instr, &operands, /*result_subset=*/{}, token_mode)); - num_results = operands.size() - num_arguments; - return ::tsl::OkStatus(); -} - -template -OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr, - ValueRange operands) { - Location loc = getLocation(instr); - return builder_.create(loc, std::nullopt, operands, - llvm::ArrayRef{}); -} - -template -tsl::StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( - const HloInstruction* instr, size_t& num_arguments, size_t& num_results, - std::optional num_operands) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, - TokenLoweringMode::kFailToLower, operands, - num_arguments, num_results)); - return CreateOpWithoutAttrs(instr, operands); -} - -tsl::StatusOr LhloDialectEmitter::EmitOp( - const HloInstruction* instr) { - using xla::HloOpcode; - switch (instr->opcode()) { - case HloOpcode::kAddDependency: - return nullptr; - case HloOpcode::kAfterAll: - // LMHLO is already ordered. This assumption may be broken after - // introducing async regions and partial orders. - return nullptr; - case HloOpcode::kAllGatherStart: - return EmitAllGatherStartOp(instr); - case HloOpcode::kAllGatherDone: - return EmitAllGatherDoneOp(instr); - case HloOpcode::kAllReduceStart: - return EmitAllReduceStartOp(instr); - case HloOpcode::kAllReduceDone: - return EmitAllReduceDoneOp(instr); - case HloOpcode::kAsyncStart: - return EmitAsyncStartOp(instr); - case HloOpcode::kAsyncDone: - return EmitAsyncDoneOp(instr); - case HloOpcode::kBitcast: - return EmitBitcast(instr); - case HloOpcode::kCollectivePermuteStart: - return EmitCollectivePermuteStartOp(instr); - case HloOpcode::kCollectivePermuteDone: - return EmitCollectivePermuteDoneOp(instr); - case HloOpcode::kConditional: - return EmitCaseOp(instr); - case HloOpcode::kFft: - return EmitFftOp(instr); - case HloOpcode::kGetTupleElement: - return nullptr; - case HloOpcode::kInfeed: - return EmitInfeedOp(instr); - case HloOpcode::kOutfeed: - return EmitOutfeedOp(instr); - case HloOpcode::kPartitionId: - return CreateOpWithoutAttrs(instr); - case HloOpcode::kReplicaId: - return CreateOpWithoutAttrs(instr); - case HloOpcode::kTriangularSolve: - return EmitTriangularSolveOp(instr); - case HloOpcode::kTuple: - return nullptr; - case HloOpcode::kSort: - return EmitSortOp(instr); - case HloOpcode::kFusion: - return EmitFusionOp(instr); - case HloOpcode::kScatter: - return EmitScatterOp(instr); - case HloOpcode::kSelectAndScatter: - return EmitSelectAndScatterOp(instr); - case HloOpcode::kCustomCall: - return EmitCustomCallOp(instr); - case HloOpcode::kConstant: - return EmitConstant(instr); - case HloOpcode::kRngGetAndUpdateState: - return EmitRngGetAndUpdateStateOp(instr); - case HloOpcode::kWhile: - return EmitWhileOp(instr); - case HloOpcode::kSend: - return EmitSendOp(instr); - case HloOpcode::kSendDone: - return EmitSendDoneOp(instr); - case HloOpcode::kRecv: - return EmitRecvOp(instr); - case HloOpcode::kRecvDone: - return EmitRecvDoneOp(instr); - // TODO(b/302038092): Currently the command buffer call is represented by - // a kCall. We need to be able to differentiate it from a regular kCall. - case HloOpcode::kCall: - return EmitCommandBufferOp(instr); - default: - llvm::errs() << instr->ToString(); - llvm::errs() << "\n\nModule:\n" - << instr->GetModule()->ToString() << "\n\n"; - return tsl::errors::Internal( - absl::StrCat("LHLO opcode ", xla::HloOpcodeString(instr->opcode()), - " is not supported.")); - } -} - -tsl::Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto* op, EmitOp(instr)); - if (op) { - lhlo_to_hlo_[op] = instr; - } - return tsl::OkStatus(); -} - -tsl::StatusOr LhloDialectEmitter::EmitSortOp( - const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); - auto* sort_instr = xla::Cast(instr); - sort.setDimensionAttr( - builder_.getI64IntegerAttr(sort_instr->sort_dimension())); - sort.setIsStableAttr(builder_.getBoolAttr(sort_instr->is_stable())); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *sort_instr->called_computations()[0], symbol_table_, - &sort.getComparator(), &builder_)); - return sort; -} - -// Walks MHLO::TupleOp recursively. -tsl::Status WalkTuplePostOrder( - Value v, const std::function& visitor) { - if (auto* op = v.getDefiningOp()) { - if (auto tuple = dyn_cast(op)) { - for (Value sub_v : tuple.getVal()) { - TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor)); - } - return ::tsl::OkStatus(); - } - } - return visitor(v); -} - -tsl::StatusOr LhloDialectEmitter::RewriteFusionOperand( - const HloInstruction* root, const Shape& shape, - xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { - if (shape.IsTuple()) { - llvm::SmallVector values; - for (int i = 0; i < shape.tuple_shapes_size(); ++i) { - shape_index->push_back(i); - TF_ASSIGN_OR_RETURN( - auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, - b, loc)); - values.push_back(v); - shape_index->pop_back(); - } - return Value(b->create(loc, values)); - } - TF_ASSIGN_OR_RETURN(Value memref, - GetOrCreateArrayView(root, shape, *shape_index)); - auto load = b->create(loc, memref); - if (shape.layout() != - xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { - llvm::SmallVector minor_to_major( - shape.layout().minor_to_major().begin(), - shape.layout().minor_to_major().end()); - load->setAttr("xla_shape", - b->getStringAttr(shape.ToString(/*print_layout=*/true))); - } - return load.getResult(); -} - -// Emit a lmhlo.fusion based on XLA HLO fusion. Structurally they are not neatly -// equivalent. Specifically, XLA HLO fusion: -// fused_computation { -// %p0 = parameter(0) -// %p1 = parameter(1) -// ... -// ROOT %ret = ... -// } -// will be converted to -// lmhlo.fusion() { // no explicit operands -// // capturing outside buffers -// %p0 = bufferization.to_tensor(%arg0) : memref<...> -> tensor<...> -// %p1 = bufferization.to_tensor(%arg1) : memref<...> -> tensor<...> -// ... -// tensor_store ..., %ret // store a tensor to a memref -// } -tsl::StatusOr LhloDialectEmitter::EmitFusionOp( - const HloInstruction* instr) { - Location loc = getLocation(instr); - - auto* fusion_instr = xla::Cast(instr); - - auto fusion = builder_.create(getLocation(instr)); - auto after_fusion = builder_.saveInsertionPoint(); - auto reverter = absl::MakeCleanup( - [this, after_fusion] { builder_.restoreInsertionPoint(after_fusion); }); - builder_ = mlir::OpBuilder(fusion); - - auto region_builder = OpBuilder::atBlockBegin(&fusion.getRegion().front()); - - llvm::SmallVector arguments; - for (int i = 0; i < instr->operands().size(); ++i) { - const HloInstruction* operand = instr->operand(i); - xla::ShapeIndex shape_index; - TF_ASSIGN_OR_RETURN( - auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index, - ®ion_builder, loc)); - arguments.push_back(arg); - } - - TF_ASSIGN_OR_RETURN(Value result, - xla::HloFunctionImporter::ImportInstructions( - *fusion_instr->fused_instructions_computation(), - arguments, symbol_table_, ®ion_builder)); - { - int i = 0; - llvm::SmallVector output; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); - TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { - auto materialize_op = - region_builder.create( - loc, v, output[i++]); - materialize_op.setWritable(true); - return ::tsl::OkStatus(); - })); - if (i != output.size()) { - return xla::InternalError("output sizes don't match"); - } - } - - // The fusion op might not have a backend-config. But we at least want to set - // the fusion kind, because LMHLO doesn't have this concept. - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); - if (backend_config.kind().empty() && - instr->opcode() == xla::HloOpcode::kFusion) { - backend_config.set_kind(std::string(ToString(instr->fusion_kind()))); - } - - TF_ASSIGN_OR_RETURN(std::string backend_config_str, - HloInstruction::BackendConfigToRawString(backend_config)); - fusion.setBackendConfigAttr(builder_.getStringAttr(backend_config_str)); - - // Fold GTE/Tuple pairs. - // - // Since the fused region refers to values in its parent region, we can't - // call applyPatternAndFoldGreedily. We optimize it manually. - // - // Only walk once, because post-ordering is exactly what we need for GTE - // optimizations. - fusion.getRegion().walk([](mhlo::GetTupleElementOp gte) { - SmallVector folded_values; - if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) { - gte.replaceAllUsesWith(folded_values[0]); - } - }); - - // Effectively a DCE on the region. - { - llvm::SmallVector ops; - fusion.getRegion().walk([&](mlir::Operation* op) { ops.push_back(op); }); - // Visit the user first. - std::reverse(ops.begin(), ops.end()); - for (auto op : ops) { - if (isOpTriviallyDead(op)) op->erase(); - } - } - - return fusion; -} - -tsl::StatusOr -LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr, - mlir::MLIRContext* context) { - auto* scatter_instr = xla::Cast(instr); - - const xla::ScatterDimensionNumbers& xla_scatter_dim = - scatter_instr->scatter_dimension_numbers(); - - auto get_i64_array = [](absl::Span container) { - return ArrayRef{container.data(), - static_cast(container.size())}; - }; - auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get( - context, get_i64_array(xla_scatter_dim.update_window_dims()), - get_i64_array(xla_scatter_dim.inserted_window_dims()), - get_i64_array(xla_scatter_dim.scatter_dims_to_operand_dims()), - xla_scatter_dim.index_vector_dim()); - return scatter_dimension_numbers; -} - -tsl::StatusOr LhloDialectEmitter::EmitScatterOp( - const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto scatter, - CreateOpWithoutAttrs(instr)); - - // copy attributes - auto* scatter_instr = xla::Cast(instr); - - TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers, - GetScatterDimensionNumbers(instr, builder_.getContext())); - scatter.setScatterDimensionNumbersAttr(scatter_dimension_numbers); - scatter.setIndicesAreSortedAttr( - builder_.getBoolAttr(scatter_instr->indices_are_sorted())); - scatter.setUniqueIndicesAttr( - builder_.getBoolAttr(scatter_instr->unique_indices())); - - // import update computation as region - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *scatter_instr->called_computations()[0], symbol_table_, - &scatter.getUpdateComputation(), &builder_)); - - return scatter; -} - -tsl::StatusOr -LhloDialectEmitter::EmitSelectAndScatterOp(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto select_and_scatter, - CreateOpWithoutAttrs(instr)); - - // copy attributes - auto* select_and_scatter_instr = - xla::Cast(instr); - const xla::Window& window = select_and_scatter_instr->window(); - - if (xla::window_util::HasDilation(window)) { - return tsl::errors::Unimplemented( - "Dilation for SelectAndScatter is not supported"); - } - - select_and_scatter.setWindowDimensionsAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.size()); - })); - select_and_scatter.setWindowStridesAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.stride()); - })); - select_and_scatter.setPaddingAttr( - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.padding_low()); - })); - - // import select and scatter computation as region - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *select_and_scatter_instr->select(), symbol_table_, - &select_and_scatter.getSelect(), &builder_)); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *select_and_scatter_instr->scatter(), symbol_table_, - &select_and_scatter.getScatter(), &builder_)); - return select_and_scatter; -} - -tsl::StatusOr LhloDialectEmitter::EmitCustomCallOp( - const HloInstruction* instr) { - auto* custom_call_instr = xla::Cast(instr); - - if (xla::gpu::IsCustomCallToCusolver(*instr)) { - return EmitCholesky(custom_call_instr); - } - - if (xla::gpu::IsLegacyCublasMatmul(*instr)) { - return EmitGemm(custom_call_instr); - } - - if (xla::gpu::IsCublasLtMatmul(*instr)) { - return EmitCublasLtMatmul(custom_call_instr); - } - - if (xla::gpu::IsCublasLtMatmulF8(*instr)) { - return EmitCublasLtMatmulF8(custom_call_instr); - } - - if (xla::gpu::IsCustomCallToDnnConvolution(*instr)) { - return EmitDnnConvolution(custom_call_instr); - } - - if (xla::gpu::IsCudnnConvolutionReorder(*instr)) { - return EmitDnnConvolutionReorderVectorized(custom_call_instr); - } - - if (xla::gpu::IsCustomCallToDnnNorm(*instr)) { - return EmitDnnNorm(custom_call_instr); - } - - if (xla::gpu::IsFwdCustomCallTofMHA(*instr)) { - return EmitDnnfMHA(custom_call_instr); - } - if (xla::gpu::IsBwdCustomCallTofMHA(*instr)) { - return EmitDnnfMHABackward(custom_call_instr); - } - if (xla::gpu::IsCubDeviceRadixSort(*instr)) { - return EmitCubDeviceRadixSort(custom_call_instr); - } - - // For custom call, if there are any token operands or results, they will not - // be represented in LHLO so we need to remember the mapping. First create - // operands where each token is replaced with a null Value. - llvm::SmallVector operands; - size_t num_arguments, num_results; - TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/std::nullopt, - TokenLoweringMode::kUseNull, operands, - num_arguments, num_results)); - - // Now check if any of the operands is Null, which would indicate the presence - // of a token in the input or output. - bool has_token = llvm::any_of(operands, [](Value v) { return !v; }); - - lmhlo::CustomCallTargetArgMappingAttr target_mapping; - if (has_token) { - // If there was a token, squeeze all the non-token arguments and results - // (in-place) and remember the mapping. - int next_index = 0; - llvm::SmallVector arg_to_target_arg_mapping; - for (int i = 0; i < num_arguments; ++i) { - if (operands[i]) { - arg_to_target_arg_mapping.push_back(i); - operands[next_index++] = operands[i]; - } - } - // Size of arg_to_target_arg_mapping is the number of arguments in LHLO. - llvm::SmallVector result_to_target_result_mapping; - for (int i = num_arguments; i < operands.size(); ++i) { - if (operands[i]) { - result_to_target_result_mapping.push_back(i - num_arguments); - operands[next_index++] = operands[i]; - } - } - - // Build the mapping attribute. - target_mapping = lmhlo::CustomCallTargetArgMappingAttr::get( - builder_.getContext(), num_arguments, num_results, - arg_to_target_arg_mapping, result_to_target_result_mapping); - - // Drop the remaining operands and adjust num_arguments and num_results - // for LMHLO creation. - operands.resize(next_index); - num_arguments = arg_to_target_arg_mapping.size(); - num_results = result_to_target_result_mapping.size(); - } - - auto custom_call = CreateOpWithoutAttrs(instr, operands); - TF_ASSIGN_OR_RETURN( - auto mlir_api_version, - ConvertCustomCallApiVersion(custom_call_instr->api_version())); - custom_call.setCallTargetNameAttr( - builder_.getStringAttr(custom_call_instr->custom_call_target())); - custom_call.setApiVersionAttr(mhlo::CustomCallApiVersionAttr::get( - builder_.getContext(), mlir_api_version)); - - // For typed custom calls we need to parse user-defined attributes back to the - // dictionary attribute, and then add them back to the custom call op. - if (mlir_api_version == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) { - if (custom_call_instr->opaque().empty()) { - auto empty = mlir::DictionaryAttr::get(builder_.getContext()); - custom_call.setBackendConfigAttr(empty); - } else { - mlir::Attribute attr = mlir::parseAttribute(custom_call_instr->opaque(), - builder_.getContext()); - TF_RET_CHECK(attr.isa()) - << "Couldn't parse backend config into a dictionary attribute"; - custom_call.setBackendConfigAttr(attr); - } - } else { - custom_call.setBackendConfigAttr( - builder_.getStringAttr(custom_call_instr->opaque())); - } - - const int32_t segments[2] = {static_cast(num_arguments), - static_cast(num_results)}; - custom_call->setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(segments)); - if (target_mapping) custom_call.setTargetArgMappingAttr(target_mapping); - - for (int i = 0; i < custom_call_instr->called_computations().size(); ++i) { - auto& region = custom_call->getRegion(i); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *custom_call_instr->called_computation(), symbol_table_, ®ion, - &builder_)); - } - - return custom_call.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCholesky( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN(auto cholesky_op, - CreateOpWithoutAttrs(custom_call)); - TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options, - custom_call->backend_config()); - cholesky_op.setIsLowerAttr(builder_.getBoolAttr(options.lower())); - return cholesky_op; -} - -namespace { - -mhlo::DotDimensionNumbersAttr GetDotDimensionNumbersAttr( - const OpBuilder& builder, const xla::DotDimensionNumbers& hlo_dims) { - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - return mhlo::DotDimensionNumbersAttr::get( - builder.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()), - arrayref(hlo_dims.rhs_batch_dimensions()), - arrayref(hlo_dims.lhs_contracting_dimensions()), - arrayref(hlo_dims.rhs_contracting_dimensions())); -} - -template -void SetMatmulAttributes(OpT op, const xla::gpu::GemmBackendConfig& config, - OpBuilder& builder) { - op.setDotDimensionNumbersAttr( - GetDotDimensionNumbersAttr(builder, config.dot_dimension_numbers())); - op.setAlphaRealAttr(builder.getF64FloatAttr(config.alpha_real())); - op.setAlphaImagAttr(builder.getF64FloatAttr(config.alpha_imag())); - op.setBetaAttr(builder.getF64FloatAttr(config.beta())); - if (config.algorithm_case() == - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder.getI64IntegerAttr(config.selected_algorithm())); - } - op.setPrecisionConfigAttr( - xla::ConvertPrecisionConfig(&config.precision_config(), &builder)); - op.setGradXAttr(builder.getBoolAttr(config.grad_x())); - op.setGradYAttr(builder.getBoolAttr(config.grad_y())); -} - -tsl::StatusOr AsLhloEpilogue( - xla::gpu::GemmBackendConfig_Epilogue epilogue) { - switch (epilogue) { - case xla::gpu::GemmBackendConfig::DEFAULT: - return lmhlo_gpu::CublasLtMatmulEpilogue::Default; - case xla::gpu::GemmBackendConfig::RELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::Relu; - case xla::gpu::GemmBackendConfig::GELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::Gelu; - case xla::gpu::GemmBackendConfig::GELU_AUX: - return lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux; - case xla::gpu::GemmBackendConfig::BIAS: - return lmhlo_gpu::CublasLtMatmulEpilogue::Bias; - case xla::gpu::GemmBackendConfig::BIAS_RELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu; - case xla::gpu::GemmBackendConfig::BIAS_GELU: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu; - case xla::gpu::GemmBackendConfig::BIAS_GELU_AUX: - return lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux; - default: - return xla::InternalError("unknown epilogue"); - } -} - -tsl::StatusOr AsLhloFusedMhaDagSignature( - xla::gpu::CudnnfMHAKind kind) { - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBmmBmm: - return lmhlo_gpu::FusedMhaDagSignature::Default; - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasMaskSoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleMaskSoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::SoftmaxDropout; - case xla::gpu::CudnnfMHAKind::kSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::Softmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmax; - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return lmhlo_gpu::FusedMhaDagSignature::ScaleBiasSoftmaxDropout; - default: - return xla::InternalError("unknown cudnn fmha fwd kind"); - } -} -tsl::StatusOr -AsLhloFusedMhaBackwardDagSignature(xla::gpu::CudnnfMHAKind kind) { - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardScaleBiasSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasSoftmaxDropout; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature:: - BackwardScaleBiasMaskSoftmaxDropout; - break; - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax; - break; - case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: - return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout; - break; - default: - return xla::InternalError("unknown cudnn fmha bwd kind"); - } -} -} // namespace - -tsl::StatusOr LhloDialectEmitter::EmitGemm( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const config, - custom_call->backend_config()); - - if (custom_call->operand_count() == 2) { - TF_RET_CHECK(config.beta() == 0.); - } else if (custom_call->operand_count() != 3) { - return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); - } - - // GEMM may have two or three operands. However, in the three operand case, - // the third operand is updated in-place, so we treat that as an output here. - TF_ASSIGN_OR_RETURN( - lmhlo_gpu::GEMMOp op, - CreateOpWithoutAttrs(custom_call, - /*num_operands=*/2)); - - SetMatmulAttributes(op, config, builder_); - return op.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmul( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const config, - custom_call->backend_config()); - - bool has_matrix_bias = config.beta() != 0.; - - TF_ASSIGN_OR_RETURN( - bool has_vector_bias, - xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); - - TF_ASSIGN_OR_RETURN( - bool has_aux_output, - xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(config.epilogue())); - - TF_RET_CHECK(custom_call->operand_count() == - 2 + int{has_matrix_bias} + int{has_vector_bias}); - - xla::ShapeIndex output_index = - has_aux_output ? xla::ShapeIndex{0} : xla::ShapeIndex{}; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - if (has_matrix_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - } else { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - - if (has_vector_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView( - custom_call->operand(has_matrix_bias ? 3 : 2), &operands)); - } - - if (has_aux_output) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1})); - } - - auto op = - CreateOpWithoutAttrs(custom_call, operands); - SetMatmulAttributes(op, config, builder_); - - int32_t operand_sizes[] = { - 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_aux_output ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue, - AsLhloEpilogue(config.epilogue())); - op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get( - builder_.getContext(), epilogue)); - - // Use the first algorithm by default (i.e. fastest according to heuristics). - if (config.algorithm_case() != - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder_.getI64IntegerAttr(0)); - } - - return op.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCublasLtMatmulF8( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const config, - custom_call->backend_config()); - - int ops_num = custom_call->operand_count(); - TF_RET_CHECK(ops_num == 6 || ops_num == 7 || ops_num == 8); - TF_ASSIGN_OR_RETURN( - bool has_vector_bias, - xla::gpu::gpublas_lt::EpilogueAddsVectorBias(config.epilogue())); - - bool has_damax = custom_call->shape().IsTuple(); - bool has_matrix_bias = config.beta() != 0.; - xla::ShapeIndex output_index = - has_damax ? xla::ShapeIndex{0} : xla::ShapeIndex{}; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - - int a_scale_index = has_matrix_bias ? 3 : 2; - if (has_matrix_bias) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - } else { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - } - - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 1), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 2), &operands)); - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, output_index)); - - if (has_vector_bias) { - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(a_scale_index + 4), &operands)); - } - if (has_damax) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands, {1})); - } - auto op = CreateOpWithoutAttrs(custom_call, - operands); - - SetMatmulAttributes(op, config, builder_); - int32_t operand_sizes[] = { - 1, 1, 1, 1, 1, 1, 1, 1, has_vector_bias ? 1 : 0, has_damax ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - TF_ASSIGN_OR_RETURN(lmhlo_gpu::CublasLtMatmulEpilogue epilogue, - AsLhloEpilogue(config.epilogue())); - op.setEpilogueAttr(lmhlo_gpu::CublasLtMatmulEpilogueAttr::get( - builder_.getContext(), epilogue)); - - // Use the first algorithm by default (i.e. fastest according to heuristics). - if (config.algorithm_case() != - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr(builder_.getI64IntegerAttr(0)); - } - - return op.getOperation(); -} - -static tsl::StatusOr GetLHLOActivation( - stream_executor::dnn::ActivationMode activation) { - switch (activation) { - case stream_executor::dnn::kNone: - return mlir::lmhlo_gpu::Activation::None; - case stream_executor::dnn::kSigmoid: - return mlir::lmhlo_gpu::Activation::Sigmoid; - case stream_executor::dnn::kRelu: - return mlir::lmhlo_gpu::Activation::Relu; - case stream_executor::dnn::kRelu6: - return mlir::lmhlo_gpu::Activation::Relu6; - case stream_executor::dnn::kReluX: - return mlir::lmhlo_gpu::Activation::ReluX; - case stream_executor::dnn::kTanh: - return mlir::lmhlo_gpu::Activation::Tanh; - case stream_executor::dnn::kBandPass: - return mlir::lmhlo_gpu::Activation::BandPass; - case stream_executor::dnn::kElu: - return mlir::lmhlo_gpu::Activation::Elu; - case stream_executor::dnn::kLeakyRelu: - return mlir::lmhlo_gpu::Activation::LeakyRelu; - default: - return xla::InternalError("Unknown activation"); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const backend_config, - custom_call->backend_config()); - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnConvKind kind, - xla::gpu::GetCudnnConvKind(custom_call)); - - auto get_layout_attribute = [&](const xla::Layout& layout) { - std::vector minor_to_major(layout.minor_to_major_size()); - absl::c_transform(layout.minor_to_major(), minor_to_major.begin(), - [](int64_t x) { return static_cast(x); }); - return minor_to_major; - }; - - auto set_common_conv_attributes = [&, this](auto op) -> Operation* { - const xla::Window& window = custom_call->window(); - // Window size for Cudnn Conv is same as the kernel size. - NamedAttrList attrs(op->getAttrDictionary()); - DenseIntElementsAttr window_strides; - attrs.set(op.getWindowStridesAttrName(), - window_strides = GetWindowElements( - window, [](const xla::WindowDimension& dim) { - return static_cast(dim.stride()); - })); - // Cudnn Conv requires low and high padding to be equal. - attrs.set(op.getPaddingAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.padding_low()); - })); - // LHS dilation is encoded in base_dilation of the backend config. - // RHS dilation is encoded in window_dilation of the backend config. - attrs.set(op.getLhsDilationAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.base_dilation()); - })); - attrs.set(op.getRhsDilationAttrName(), - GetWindowElements(window, [](const xla::WindowDimension& dim) { - return static_cast(dim.window_dilation()); - })); - // Setup window reversal. - auto window_reversal = llvm::to_vector<4>(llvm::map_range( - window.dimensions(), - [](const xla::WindowDimension& dim) { return dim.window_reversal(); })); - auto type = RankedTensorType::get(window_strides.getType().getShape(), - builder_.getIntegerType(/*width=*/1)); - attrs.set(op.getWindowReversalAttrName(), - DenseElementsAttr::get(type, window_reversal)); - - attrs.set(op.getDimensionNumbersAttrName(), - xla::ConvertConvDimensionNumbers( - custom_call->convolution_dimension_numbers(), &builder_)); - attrs.set(op.getFeatureGroupCountAttrName(), - builder_.getI64IntegerAttr(custom_call->feature_group_count())); - attrs.set(op.getBatchGroupCountAttrName(), - builder_.getI64IntegerAttr(custom_call->batch_group_count())); - attrs.set(op.getPrecisionConfigAttrName(), - xla::ConvertPrecisionConfig(&custom_call->precision_config(), - &builder_)); - attrs.set(op.getResultScaleAttrName(), - builder_.getF64FloatAttr(backend_config.conv_result_scale())); - - const auto& algorithm = backend_config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - - auto config = mlir::lmhlo_gpu::ConvolutionBackendConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), - - algorithm.math_type() == - stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH, - knob_ids, knob_values, algorithm.is_cudnn_frontend(), - backend_config.reordered_int8_nchw_vect(), - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1, - get_layout_attribute(custom_call->operand(0)->shape().layout()), - get_layout_attribute(custom_call->operand(1)->shape().layout()), - get_layout_attribute(custom_call->shape().tuple_shapes(0).layout())); - attrs.set(op.getBackendConfigAttrName(), config); - op->setAttrs(attrs.getDictionary(op->getContext())); - - return op.getOperation(); - }; - - auto set_activation = [&, this](auto op) -> tsl::Status { - auto se_activation = static_cast( - backend_config.activation_mode()); - TF_ASSIGN_OR_RETURN(mlir::lmhlo_gpu::Activation activation, - GetLHLOActivation(se_activation)); - auto activation_attr = ::mlir::lmhlo_gpu::ActivationAttr::get( - getLocation(custom_call).getContext(), activation); - op.setActivationModeAttr(activation_attr); - return ::tsl::OkStatus(); - }; - - switch (kind) { - case xla::gpu::CudnnConvKind::kForward: { - TF_ASSIGN_OR_RETURN( - auto cnn_forward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_forward); - } - case xla::gpu::CudnnConvKind::kBackwardInput: { - TF_ASSIGN_OR_RETURN( - auto cnn_backward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_backward); - } - case xla::gpu::CudnnConvKind::kBackwardFilter: { - TF_ASSIGN_OR_RETURN( - auto cnn_backward, - CreateOpWithoutAttrs(custom_call)); - return set_common_conv_attributes(cnn_backward); - } - case xla::gpu::CudnnConvKind::kForwardActivation: { - // Fused conv can be either with side input or without. - if (custom_call->operand_count() == 3) { - TF_ASSIGN_OR_RETURN( - auto cnn_fused, - CreateOpWithoutAttrs(custom_call)); - TF_RETURN_IF_ERROR(set_activation(cnn_fused)); - cnn_fused.setLeakyreluAlphaAttr( - builder_.getF64FloatAttr(backend_config.leakyrelu_alpha())); - return set_common_conv_attributes(cnn_fused); - } - - TF_RET_CHECK(custom_call->operand_count() == 4); - TF_ASSIGN_OR_RETURN( - auto cnn_fused_side_input, - CreateOpWithoutAttrs( - custom_call)); - cnn_fused_side_input.setSideInputScaleAttr( - builder_.getF64FloatAttr(backend_config.side_input_scale())); - TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input)); - return set_common_conv_attributes(cnn_fused_side_input); - } - case xla::gpu::CudnnConvKind::kForwardGraph: { - const int32_t n_binary_operands = custom_call->operand_count() - 2; - const int32_t n_aux_outputs = - custom_call->shape().tuple_shapes_size() - 2; - TF_ASSIGN_OR_RETURN( - auto cnn_graph, - CreateOpWithoutAttrs(custom_call)); - cnn_graph.setSerializedGraph(backend_config.serialized_graph()); - cnn_graph.setNAuxOutputs(n_aux_outputs); - int32_t operand_sizes[] = {1, 1, n_binary_operands, 1, n_aux_outputs, 1}; - cnn_graph->setAttr(cnn_graph.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - return set_common_conv_attributes(cnn_graph); - } - } -} - -tsl::StatusOr -LhloDialectEmitter::EmitDnnConvolutionReorderVectorized( - const HloCustomCallInstruction* custom_call) { - auto set_common_attributes = [&, this](auto op) -> Operation* { - // Output shape defines the filter, it must have NCHW_VECT_C layout. - Shape shape = custom_call->shape(); - if (shape.IsTuple()) { - shape = shape.tuple_shapes(0); - } - - CHECK_EQ(shape.rank(), 5); - CHECK_EQ(shape.dimensions(4), 32); - llvm::SmallVector nchw = { - shape.dimensions(0), shape.dimensions(1) * 32, shape.dimensions(2), - shape.dimensions(3)}; - op->setAttr("filter_dims", GetI64DenseElementsAttr(nchw)); - - return op.getOperation(); - }; - - if (custom_call->operand_count() > 1) { - TF_ASSIGN_OR_RETURN( - auto reorder_filter_and_bias, - CreateOpWithoutAttrs( - custom_call)); - return set_common_attributes(reorder_filter_and_bias); - } else { - TF_ASSIGN_OR_RETURN( - auto reorder_filter, - CreateOpWithoutAttrs(custom_call)); - return set_common_attributes(reorder_filter); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnNorm( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const backend_config, - custom_call->backend_config()); - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - - auto norm = - CreateOpWithoutAttrs(custom_call, operands); - norm.setEpsilonAttr(builder_.getF64FloatAttr(backend_config.epsilon())); - - const auto& algorithm = backend_config.algorithm(); - auto norm_algo_config = mlir::lmhlo_gpu::NormAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), - algorithm.has_workspace_size() ? algorithm.workspace_size().value() : -1); - norm.setAlgorithmConfigAttr(norm_algo_config); - - std::vector operand_minor_to_major; - - auto get_minor_to_major = - [&operand_minor_to_major](const xla::Layout& layout) -> void { - std::vector minor_to_major(layout.minor_to_major_size()); - absl::c_transform(layout.minor_to_major(), minor_to_major.begin(), - [](int64_t x) { return static_cast(x); }); - operand_minor_to_major.insert(operand_minor_to_major.end(), - minor_to_major.begin(), minor_to_major.end()); - }; - - // Store the layout information of all operands and outputs. - for (HloInstruction* operand : custom_call->operands()) { - get_minor_to_major(operand->shape().layout()); - } - for (int i = 0; i < custom_call->shape().tuple_shapes_size() - 1; ++i) { - get_minor_to_major(custom_call->shape().tuple_shapes(i).layout()); - } - - norm.setOperandLayoutsAttr(builder_.getI64ArrayAttr(llvm::ArrayRef{ - operand_minor_to_major.data(), operand_minor_to_major.size()})); - - bool has_aux_outputs = custom_call->shape().tuple_shapes_size() == 4; - int32_t operand_sizes[] = {1, 1, 1, 1, has_aux_outputs, has_aux_outputs, 1}; - norm->setAttr(norm.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - return norm.getOperation(); -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnfMHA( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const config, - custom_call->backend_config()); - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - bool has_mask = false; - bool has_bias = false; - - auto set_common_fmha_attributes = - [&, this](auto op) -> tsl::StatusOr { - TF_ASSIGN_OR_RETURN(lmhlo_gpu::FusedMhaDagSignature fused_mha_dag_signature, - AsLhloFusedMhaDagSignature(kind)); - op.setFusedMhaDagAttr(lmhlo_gpu::FusedMhaDagSignatureAttr::get( - builder_.getContext(), fused_mha_dag_signature)); - op.setBmm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_dot_dimension_numbers())); - op.setBmm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_dot_dimension_numbers())); - - const auto& algorithm = config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - auto fmha_algo_config = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), knob_ids, knob_values, - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1); - op.setAlgorithmConfigAttr(fmha_algo_config); - - auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - auto intermediate_tensor_dims = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.dimensions())); - op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); - - auto intermediate_tensor_layout = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.layout().minor_to_major())); - op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); - op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - int32_t operand_sizes[] = {1, - 1, - 1, - has_mask ? 1 : 0, - has_bias ? 1 : 0, - 1, - 1, - has_activation ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - // set is flash attention here - op.setIsFlashAttentionAttr( - builder_.getBoolAttr(config.is_flash_attention())); - // set is causal mask here - op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); - return op.getOperation(); - }; - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBmmBmm: - case xla::gpu::CudnnfMHAKind::kSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_mask = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_mask = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_mask = true; - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_mask = true; - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - has_bias = true; - return set_common_fmha_attributes(fmha); - } - case xla::gpu::CudnnfMHAKind::kScaleBiasSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha = - CreateOpWithoutAttrs(custom_call, operands); - fmha.setDropoutRateAttr(builder_.getF64FloatAttr(config.dropout_rate())); - fmha.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - has_bias = true; - return set_common_fmha_attributes(fmha); - } - default: - return xla::InternalError("Unknown forward fused MHA call."); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( - const HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto const config, - custom_call->backend_config()); - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - - bool is_flash_attention = config.is_flash_attention(); - bool has_dbias = - custom_call->shape().tuple_shapes().size() == 6 && !is_flash_attention; - bool has_mask = false; - bool has_bias = false; - - auto set_common_fmha_backward_attributes = - [&, this](auto op) -> tsl::StatusOr { - TF_ASSIGN_OR_RETURN(lmhlo_gpu::FusedMhaBackwardDagSignature - fused_mha_backward_dag_signature, - AsLhloFusedMhaBackwardDagSignature(kind)); - op.setFusedMhaDagAttr(lmhlo_gpu::FusedMhaBackwardDagSignatureAttr::get( - builder_.getContext(), fused_mha_backward_dag_signature)); - op.setBmm1GradGemm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_grad_gemm1_dot_dimension_numbers())); - op.setBmm1GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm1_grad_gemm2_dot_dimension_numbers())); - op.setBmm2GradGemm1DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_grad_gemm1_dot_dimension_numbers())); - op.setBmm2GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( - builder_, config.bmm2_grad_gemm2_dot_dimension_numbers())); - - auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - auto intermediate_tensor_dims = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.dimensions())); - op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); - - auto intermediate_tensor_layout = builder_.getI64ArrayAttr( - arrayref(intermediate_tensor_shape.layout().minor_to_major())); - op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); - - op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - - int32_t operand_sizes[] = {1, - 1, - 1, - 1, - 1, - has_mask ? 1 : 0, - has_bias ? 1 : 0, - is_flash_attention ? 1 : 0, // fwd_output - 1, - 1, - 1, - is_flash_attention ? 0 : 1, // d_S - is_flash_attention ? 1 : 0, // softmax_sum - is_flash_attention ? 1 : 0, // d_Q_accum - 1, - has_dbias ? 1 : 0}; - op->setAttr(op.getOperandSegmentSizeAttr(), - builder_.getDenseI32ArrayAttr(operand_sizes)); - - // set is flash attention here - op.setIsFlashAttentionAttr( - builder_.getBoolAttr(config.is_flash_attention())); - // set is causal mask here - op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); - const auto& algorithm = config.algorithm(); - std::vector knob_ids; - std::vector knob_values; - for (const auto& entry : algorithm.tuning_knobs()) { - knob_ids.push_back(entry.first); - knob_values.push_back(entry.second); - } - auto fmha_algo_config = mlir::lmhlo_gpu::FusedMHAAlgorithmConfigAttr::get( - builder_.getContext(), algorithm.algo_id(), knob_ids, knob_values, - algorithm.has_workspace_size() ? algorithm.workspace_size().value() - : -1); - op.setAlgorithmConfigAttr(fmha_algo_config); - return op.getOperation(); - }; - - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(3), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(4), &operands)); - - switch (kind) { - case xla::gpu::CudnnfMHAKind::kBackwardBmmBmm: - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: { - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(7), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - return set_common_fmha_backward_attributes(fmha_backward); - } - - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: { - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); - // push fwd output for bwd here if it is flash attention - if (config.is_flash_attention()) { - has_bias = true; - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(6), &operands)); // bias - TF_RETURN_IF_ERROR( - GetOrCreateView(custom_call->operand(7), &operands)); // fwd_output - } - TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); - has_mask = true; - auto fmha_backward = CreateOpWithoutAttrs( - custom_call, operands); - fmha_backward.setDropoutRateAttr( - builder_.getF64FloatAttr(config.dropout_rate())); - fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); - return set_common_fmha_backward_attributes(fmha_backward); - } - - default: - return xla::InternalError("Unknown backward fused MHA call."); - } -} - -xla::StatusOr LhloDialectEmitter::EmitCubDeviceRadixSort( - const xla::HloCustomCallInstruction* custom_call) { - TF_ASSIGN_OR_RETURN( - auto radix_sort_op, - CreateOpWithoutAttrs(custom_call)); - TF_ASSIGN_OR_RETURN(xla::SortOptions options, - custom_call->backend_config()); - radix_sort_op.setDescendingAttr(builder_.getBoolAttr(options.descending())); - return radix_sort_op.getOperation(); -} - -// Convert an XLA HLO constant to a global_memref + get_global_memref pair. -tsl::StatusOr LhloDialectEmitter::EmitConstant( - const HloInstruction* instr) { - auto& instr_slice = instr_slices_[std::make_pair(instr, xla::ShapeIndex())]; - if (instr_slice) { - return dyn_cast(instr_slice.getDefiningOp()); - } - - // Insert a global_memref in the module. - Location loc = getLocation(instr); - - auto const_instr = xla::Cast(instr); - TF_RET_CHECK(const_instr->shape().IsArray() && - const_instr->shape().is_static()); - TF_ASSIGN_OR_RETURN(Type type, xla::ConvertShapeToType( - const_instr->shape(), builder_)); - auto memref_type = type.dyn_cast(); - TF_RET_CHECK(memref_type != nullptr); - - TF_ASSIGN_OR_RETURN( - DenseElementsAttr initial_value, - CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_)); - - std::string constant_name = xla::llvm_ir::ConstantNameToGlobalName( - xla::llvm_ir::SanitizeConstantName(instr->name())); - - // Insert the global memref at the top level. - { - OpBuilder::InsertionGuard guard(builder_); - builder_.clearInsertionPoint(); - auto global_var = builder_.create( - loc, constant_name, builder_.getStringAttr("private"), memref_type, - initial_value, true, /*alignment=*/IntegerAttr()); - symbol_table_.insert(global_var); - global_var.getOperation()->moveBefore(&module_.front()); - - // For operations that do not fold this constant value in their codegen, we - // still need to materialize it into a buffer. Since buffer allocation is - // already done, annotate the global_memref with the information to get to - // the allocated buffer slice for this constant if need be. - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(instr)); - global_var->setAttr( - "lmhlo.alloc", - builder_.getIndexAttr(allocations_.find(slice.allocation()) - ->second.cast() - .getArgNumber())); - TF_RET_CHECK(slice.offset() == 0) - << "Each constant should have its own allocation from BufferAssignment"; - TF_RET_CHECK(slice.allocation()->size() == slice.size()) - << "Each constant should have its own allocation from BufferAssignment"; - } - - auto get_global_memref = - builder_.create(loc, memref_type, constant_name); - - // Update the cache to remember this value. - instr_slice = get_global_memref; - return get_global_memref; -} - -namespace { -template -void SetupChannelIdAttribute(OpT op, const xla::HloChannelInstruction* instr, - mlir::Builder builder) { - if (instr->channel_id().has_value()) { - op.setChannelIdAttr(mlir::mhlo::ChannelHandleAttr::get( - builder.getContext(), *instr->channel_id(), 0)); - } -} - -template -tsl::Status SetupCommonCollectiveOpAttributes(OpT op, - const HloInstruction* instr, - mlir::OpBuilder& builder) { - auto* collective = xla::Cast(instr); - auto replica_groups_attr = xla::HloFunctionImporter::ConvertReplicaGroups( - collective->replica_groups(), &builder); - op->setAttr(replica_groups_attr.getName(), replica_groups_attr.getValue()); - op.setConstrainLayoutAttr( - builder.getBoolAttr(collective->constrain_layout())); - SetupChannelIdAttribute(op, collective, builder); - return ::tsl::OkStatus(); -} -} // namespace - -template -tsl::StatusOr LhloDialectEmitter::EmitDoneOp( - const xla::HloInstruction* instr) { - auto token = ret_tokens_.extract(instr->operand(0)); - TF_RET_CHECK(token) << "didn't find " << OpT::getOperationName().str() - << " token"; - return builder_.create(getLocation(instr), /*resultTypes=*/std::nullopt, - token.mapped()); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllToAllStartOp(const xla::HloInstruction* instr) { - // All the input of async-done (which wraps the all-to-all) are also - // listed as outputs, so we just create operands for the outputs. - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_to_all_start_op = - builder_.create(loc, result_types, operands); - - auto* all_to_all = xla::Cast( - instr->async_wrapped_instruction()); - TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes(all_to_all_start_op, - all_to_all, builder_)); - if (all_to_all->split_dimension().has_value()) { - all_to_all_start_op.setSplitDimensionAttr( - builder_.getI64IntegerAttr(*all_to_all->split_dimension())); - } - all_to_all_start_op.setIsSync(IsSyncCollective(instr)); - all_to_all_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_to_all_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-to-all-start already lowered"; - return all_to_all_start_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitAllToAllDoneOp( - const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllGatherStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - // In all-gather-start HLO, all inputs are also outputs of the HLO. In LMHLO - // though, we list the inputs and outputs just once. In the HLO result, - // the inputs are listed first, followed by outputs, which matches the order - // of operands we need for LMHLO AllGatherOp. - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_gather_start_op = - builder_.create(loc, result_types, operands); - - auto* all_gather = xla::Cast(instr); - TF_RETURN_IF_ERROR( - SetupCommonCollectiveOpAttributes(all_gather_start_op, instr, builder_)); - all_gather_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(all_gather->use_global_device_ids())); - all_gather_start_op.setAllGatherDimensionAttr( - builder_.getI64IntegerAttr(all_gather->all_gather_dimension())); - all_gather_start_op.setIsSync(IsSyncCollective(instr)); - all_gather_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_gather_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-gather-start already lowered"; - return all_gather_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllGatherDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllReduceStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - for (const HloInstruction* operand : instr->operands()) { - TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands)); - } - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto all_reduce_start_op = - builder_.create(loc, result_types, operands); - - auto* all_reduce = xla::Cast(instr); - TF_RETURN_IF_ERROR( - SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_)); - all_reduce_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(all_reduce->use_global_device_ids())); - all_reduce_start_op.setIsSync(IsSyncCollective(instr)); - all_reduce_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *instr->called_computations()[0], symbol_table_, - &all_reduce_start_op.getComputation(), &builder_)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, all_reduce_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "all-reduce-start already lowered"; - return all_reduce_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitAllReduceDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr LhloDialectEmitter::EmitAsyncStartOp( - const xla::HloInstruction* instr) { - const xla::HloAsyncInstruction* async = - xla::Cast(instr); - - switch (async->async_wrapped_opcode()) { - case xla::HloOpcode::kReduceScatter: - return EmitReduceScatterStartOp(instr); - case xla::HloOpcode::kAllToAll: - return EmitAllToAllStartOp(instr); - default: - return tsl::errors::InvalidArgument( - "Unexpected instruction %s wrapped in %s", - xla::HloOpcodeString(async->async_wrapped_opcode()), - HloOpcodeString(instr->opcode())); - } -} - -tsl::StatusOr LhloDialectEmitter::EmitAsyncDoneOp( - const xla::HloInstruction* instr) { - const xla::HloAsyncInstruction* async = - xla::Cast(instr); - switch (async->async_wrapped_opcode()) { - case xla::HloOpcode::kReduceScatter: - return EmitReduceScatterDoneOp(instr); - case xla::HloOpcode::kAllToAll: - return EmitAllToAllDoneOp(instr); - default: - return tsl::errors::InvalidArgument( - "Unexpected instruction %s wrapped in %s", - xla::HloOpcodeString(async->async_wrapped_opcode()), - HloOpcodeString(instr->opcode())); - } -} - -tsl::StatusOr -LhloDialectEmitter::EmitReduceScatterStartOp(const xla::HloInstruction* instr) { - // All the input of async-done (which wraps the reduce-scatter) are also - // listed as outputs, so we just create operands for the outputs. - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto reduce_scatter_start_op = - builder_.create(loc, result_types, - operands); - - auto* reduce_scatter = xla::Cast( - instr->async_wrapped_instruction()); - TF_RETURN_IF_ERROR(SetupCommonCollectiveOpAttributes( - reduce_scatter_start_op, reduce_scatter, builder_)); - reduce_scatter_start_op.setUseGlobalDeviceIdsAttr( - builder_.getBoolAttr(reduce_scatter->use_global_device_ids())); - reduce_scatter_start_op.setScatterDimensionAttr( - builder_.getI64IntegerAttr(reduce_scatter->scatter_dimension())); - reduce_scatter_start_op.setIsSync(IsSyncCollective(instr)); - reduce_scatter_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( - *reduce_scatter->to_apply(), symbol_table_, - &reduce_scatter_start_op.getComputation(), &builder_)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, reduce_scatter_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "reduce-scatter-start already lowered"; - return reduce_scatter_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitReduceScatterDoneOp(const xla::HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr -LhloDialectEmitter::EmitCollectivePermuteStartOp(const HloInstruction* instr) { - llvm::SmallVector operands; - for (const HloInstruction* operand : instr->operands()) { - TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands)); - } - // Ignore the aliased first output and TPU-specific outputs. - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1})); - - mlir::Location loc = getLocation(instr); - mlir::Type token_type = mlir::mhlo::TokenType::get(builder_.getContext()); - std::array result_types = {token_type}; - auto permute_start_op = builder_.create( - loc, result_types, operands); - - auto* permute = xla::Cast(instr); - SetupChannelIdAttribute(permute_start_op, permute, builder_); - mlir::NamedAttribute source_target_pairs_attr = - xla::HloFunctionImporter::ConvertSourceTargetPairs( - permute->source_target_pairs(), &builder_); - permute_start_op->setAttr(source_target_pairs_attr.getName(), - source_target_pairs_attr.getValue()); - permute_start_op.setIsSync(IsSyncCollective(instr)); - permute_start_op.setNoParallelCustomCall( - NoParallelCustomCallCollective(instr)); - - auto [_, was_inserted] = - ret_tokens_.insert({instr, permute_start_op.getToken()}); - TF_RET_CHECK(was_inserted) << "collective-permute-start already lowered"; - return permute_start_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitCollectivePermuteDoneOp(const HloInstruction* instr) { - return EmitDoneOp(instr); -} - -tsl::StatusOr LhloDialectEmitter::EmitInfeedOp( - const HloInstruction* instr) { - const HloInfeedInstruction* infeed = xla::Cast(instr); - // HLO Infeed instruction has a single operand of token type and a tuple - // with buffers and a token as its output. LMHLO Infeed operation does not - // need the token operand or result, so drop it. - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0})); - auto infeed_op = CreateOpWithoutAttrs(instr, operands); - infeed_op.setConfigAttr(builder_.getStringAttr(infeed->infeed_config())); - return infeed_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitOutfeedOp( - const HloInstruction* instr) { - const HloOutfeedInstruction* outfeed = - xla::Cast(instr); - // HLO outfeed instruction has 2 operands, the source and a token, and a - // single token output. LMHLO Outfeed does not need the token operand and - // result, do drop it. - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands)); - auto outfeed_op = CreateOpWithoutAttrs(instr, operands); - outfeed_op.setConfigAttr(builder_.getStringAttr(outfeed->outfeed_config())); - return outfeed_op; -} - -tsl::StatusOr -LhloDialectEmitter::EmitRngGetAndUpdateStateOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN( - auto rng, CreateOpWithoutAttrs(instr)); - auto hlo_rng = xla::Cast(instr); - rng.setDeltaAttr(builder_.getI64IntegerAttr(hlo_rng->delta())); - return rng; -} - -tsl::StatusOr LhloDialectEmitter::EmitFftOp( - const HloInstruction* instr) { - auto hlo_fft = xla::Cast(instr); - TF_ASSIGN_OR_RETURN(auto fft, CreateOpWithoutAttrs(instr)); - TF_ASSIGN_OR_RETURN(mlir::mhlo::FftType fft_type, - xla::ConvertFftType(hlo_fft->fft_type())); - fft.setFftTypeAttr( - mlir::mhlo::FftTypeAttr::get(builder_.getContext(), fft_type)); - fft.setFftLengthAttr(GetI64DenseElementsAttr(instr->fft_length())); - return fft; -} - -tsl::StatusOr -LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) { - auto hlo_triangular_solve = - xla::Cast(instr); - TF_ASSIGN_OR_RETURN(auto triangular_solve, - CreateOpWithoutAttrs(instr)); - const xla::TriangularSolveOptions& options = - hlo_triangular_solve->triangular_solve_options(); - triangular_solve.setLeftSideAttr(builder_.getBoolAttr(options.left_side())); - triangular_solve.setLowerAttr(builder_.getBoolAttr(options.lower())); - triangular_solve.setUnitDiagonalAttr( - builder_.getBoolAttr(options.unit_diagonal())); - TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose, - xla::ConvertTranspose(options.transpose_a())); - triangular_solve.setTransposeAAttr( - mlir::mhlo::TransposeAttr::get(builder_.getContext(), transpose)); - triangular_solve.setLayoutAAttr( - GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_)); - triangular_solve.setLayoutBAttr( - GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_)); - triangular_solve.setLayoutOutputAttr( - GetLayoutAttribute(instr->shape().layout(), &builder_)); - return triangular_solve; -} - -tsl::StatusOr LhloDialectEmitter::EmitBitcast( - const xla::HloInstruction* instr) { - // XLA buffer assignment should assign the same slice to a bitcast input and - // output. - const xla::ShapeIndex top_index; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice, - assignment_.GetUniqueSlice(instr, top_index)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice, - assignment_.GetUniqueSlice(instr->operand(0), top_index)); - - if (input_slice != result_slice) { - return tsl::errors::InvalidArgument( - "Bitcast input and result slice should be same"); - } - return nullptr; -} - -mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute( - const xla::Layout& layout, Builder* builder) { - llvm::SmallVector minor_to_major(layout.minor_to_major().begin(), - layout.minor_to_major().end()); - return builder->getIndexTensorAttr(minor_to_major); -} - -tsl::Status LhloDialectEmitter::ImportAsLmhloRegion( - xla::HloComputation* computation, mlir::Region* region) { - auto after = builder_.saveInsertionPoint(); - auto reverter = absl::MakeCleanup( - [this, after] { builder_.restoreInsertionPoint(after); }); - - builder_ = OpBuilder(region); - xla::HloModule* hlo_module = computation->parent(); - if (!hlo_module->has_schedule()) { - return tsl::errors::Unimplemented( - "Missing sequential order for the computation"); - } - const xla::HloInstructionSequence* schedule = - &hlo_module->schedule().sequence(computation); - TF_RETURN_IF_ERROR( - computation->AcceptOrdered(this, schedule->instructions())); - builder_.create(builder_.getUnknownLoc()); - return ::tsl::OkStatus(); -} - -tsl::StatusOr LhloDialectEmitter::EmitCaseOp( - const HloInstruction* instr) { - Location loc = getLocation(instr); - llvm::SmallVector operands; - size_t num_arguments, num_results; - TF_RETURN_IF_ERROR(CreateOperands(instr, 1, TokenLoweringMode::kUseNull, - operands, num_arguments, num_results)); - - auto case_op = - builder_.create(loc, operands[0], instr->branch_count()); - - for (int i = 0; i < instr->branch_count(); i++) { - case_op.getBranches()[i].push_back(new mlir::Block()); - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[i], - &case_op.getBranches()[i])); - } - - return case_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitWhileOp( - const xla::HloInstruction* instr) { - Location loc = getLocation(instr); - SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView( - instr->called_computations()[1]->root_instruction(), &operands)); - TF_RET_CHECK(operands.size() == 1); - - TF_ASSIGN_OR_RETURN(auto config, - instr->backend_config()); - mlir::IntegerAttr trip_count; - if (config.has_known_trip_count()) { - trip_count = builder_.getI64IntegerAttr(config.known_trip_count().n()); - } - lmhlo::WhileOp while_op = - builder_.create(loc, operands[0], trip_count); - - while_op.getCond().push_back(new mlir::Block()); - while_op.getBody().push_back(new mlir::Block()); - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[1], - &while_op.getCond())); - - TF_RETURN_IF_ERROR(ImportAsLmhloRegion(instr->called_computations()[0], - &while_op.getBody())); - - return while_op; -} - -// TODO(b/264291989): Use enum to define the host transfer type (channel type). -template -static void CopyChannelAttrs(OpBuilder& b, Instr* instr, OpTy op, - int host_transfer_type) { - op.setIsHostTransferAttr(b.getBoolAttr(instr->is_host_transfer())); - op.setChannelHandleAttr(mlir::mhlo::ChannelHandleAttr::get( - b.getContext(), *instr->channel_id(), - instr->is_host_transfer() ? host_transfer_type : /*DEVICE_TO_DEVICE*/ 1)); -} - -template -static void CopyFrontendAttrs(OpBuilder& b, Instr* instr, OpTy op) { - llvm::SmallVector frontend_attrs; - for (auto& [name, value] : instr->frontend_attributes().map()) { - frontend_attrs.push_back(b.getNamedAttr(name, b.getStringAttr(value))); - } - op->setAttr(b.getStringAttr("frontend_attributes"), - b.getDictionaryAttr(frontend_attrs)); -} - -tsl::StatusOr LhloDialectEmitter::EmitSendOp( - const xla::HloInstruction* instr) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(0), &operands)); - - auto token = mhlo::TokenType::get(builder_.getContext()); - auto send_op = builder_.create(getLocation(instr), - TypeRange(token), operands); - - // Set point-to-point op communication attributes. - auto* send = xla::Cast(instr); - CopyChannelAttrs(builder_, send, send_op, /*host_transfer_type=*/2); - CopyFrontendAttrs(builder_, send, send_op); - - auto [_, emplaced] = ret_tokens_.try_emplace(instr, send_op.getToken()); - TF_RET_CHECK(emplaced) << "send already lowered"; - return send_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitSendDoneOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto send_done_op, EmitDoneOp(instr)); - // Copy send-done attributes. - auto* send_done = xla::Cast(instr); - CopyChannelAttrs(builder_, send_done, send_done_op, - /*host_transfer_type=*/2); - - return send_done_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitRecvOp( - const xla::HloInstruction* instr) { - llvm::SmallVector operands; - TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, {0})); - - auto token = mhlo::TokenType::get(builder_.getContext()); - auto recv_op = builder_.create(getLocation(instr), - TypeRange(token), operands); - - // Set point-to-point op communication attributes. - auto* recv = xla::Cast(instr); - CopyChannelAttrs(builder_, recv, recv_op, /*host_transfer_type=*/3); - CopyFrontendAttrs(builder_, recv, recv_op); - - auto [_, emplaced] = ret_tokens_.try_emplace(instr, recv_op.getToken()); - TF_RET_CHECK(emplaced) << "recv already lowered"; - return recv_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitRecvDoneOp( - const xla::HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto recv_done_op, EmitDoneOp(instr)); - // Copy recv-done attributes. - auto* recv_done = xla::Cast(instr); - CopyChannelAttrs(builder_, recv_done, recv_done_op, - /*host_transfer_type=*/3); - - return recv_done_op; -} - -tsl::StatusOr LhloDialectEmitter::EmitCommandBufferOp( - const xla::HloInstruction* instr) { - const std::vector called_computations = - instr->called_computations(); - if (called_computations.size() != 1) { - return absl::InternalError( - "Command buffer calls must have one called computation"); - } - - if (!absl::StartsWith(called_computations[0]->name(), "command_buffer")) { - return absl::InternalError("Called computation must be a command buffer"); - } - return builder_.create(getLocation(instr)); -} - -// Sets builder insertion point for a new `memref.view` operation in the parent -// function. We create just one `memref.view` operation for every unique -// subspan of allocation, and because first use of the slice can be inside a -// block nested in a control flow operation, we have to find an insertion point -// in the parent function. Returns insertion guard for the original insertion -// point. -static tsl::StatusOr SetArrayViewInsertionPoint( - OpBuilder& builder) { - OpBuilder::InsertionGuard guard(builder); - - Operation* parent = builder.getInsertionBlock()->getParentOp(); - while (!isa(parent)) { - builder.setInsertionPoint(parent); - if ((parent = parent->getParentOp()) == nullptr) - return absl::InternalError( - "Can't find an insertion point for memref.view operation"); - } - - return guard; -} - -tsl::StatusOr LhloDialectEmitter::GetOrCreateArrayView( - const xla::HloInstruction* instr, const xla::Shape& current_shape, - const xla::ShapeIndex& shape_index) { - // For constants, the cache is managed inside EmitConstant since it can - // be called either from here or when we see a top-level HloConstant instr. - if (instr->IsConstant() && shape_index.empty()) { - TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr)); - return constant_memref; - } - - // Cache generated ViewOp and StaticMemRefCastOp by (instruction, - // shape_index). - auto& instr_slice = instr_slices_[std::make_pair(instr, shape_index)]; - if (instr_slice) { - return instr_slice; - } - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(instr, shape_index)); - - // If the shape happens to have dynamic dimensions, create the memref using - // the underlying static shape. - // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape - // but static bounds in MLIR. - xla::Shape static_shape = xla::ShapeUtil::MakeStaticShape(current_shape); - - // Try to find allocation slice with the same physical shape so that we always - // have only one memref.view operation covering the same buffer subspan. All - // reinterpret casts into different layouts will use the same source memref. - xla::Shape physical_shape = - xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - static_shape); - - // Initialize values in `allocation_slices_` before taking references, - // otherwise we can invalidate them and trigger asan errors below. - auto static_shape_key = std::make_pair(slice, static_shape); - auto physical_shape_key = std::make_pair(slice, physical_shape); - allocation_slices_[static_shape_key]; - allocation_slices_[physical_shape_key]; - - // Check if we already have a memref.view for a given slice and shape. - auto& allocation_slice = allocation_slices_[static_shape_key]; - if (allocation_slice) { - return instr_slice = allocation_slice; - } - - TF_ASSIGN_OR_RETURN(Type out_type, xla::ConvertShapeToType( - static_shape, builder_)); - TF_ASSIGN_OR_RETURN( - Type physical_out_type, - xla::ConvertShapeToType(physical_shape, builder_)); - - // Try to find an insertion point for a new memref.view operation. - TF_ASSIGN_OR_RETURN(auto guard, SetArrayViewInsertionPoint(builder_)); - - // TODO(timshen): revisit location handling. - Location loc = builder_.getUnknownLoc(); - - // Creates new memref.view operation with a `physical_shape`. - auto create_physical_slice = [&]() -> Value { - Value alloc = allocations_[slice.allocation()]; - Value byte_shift = - builder_.create(alloc.getLoc(), slice.offset()); - - // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp - // produce the physical shape (where dimensions are ordered in major to - // minor) first, then follow up with a MemRefReinterpretCast to cast the - // resulting memref to the original layout. - return builder_.create(loc, physical_out_type, alloc, - byte_shift, - /*sizes=*/ValueRange()); - }; - - // Reuse existing physical slice if it exists, otherwise build a new - // memref.view operation and cache it. - auto& physical_slice = allocation_slices_[physical_shape_key]; - if (!physical_slice) { - physical_slice = create_physical_slice(); - } - - // Start from a physical slice as a result, and maybe reinterpret cast it into - // logical shape. - Value result = physical_slice; - - if (result.getType() != out_type) { - int64_t out_offset; - SmallVector out_strides; - auto out_memref_type = out_type.dyn_cast(); - if (!out_memref_type) - return tsl::errors::Internal( - "Expected memref type when creating a view for leaf type of a " - "tuple."); - if (failed(getStridesAndOffset(out_memref_type, out_strides, out_offset))) - return tsl::errors::Internal( - "Failed to get strides and offset from the output type."); - result = builder_.create( - loc, out_memref_type, result, out_offset, out_memref_type.getShape(), - out_strides); - } - - return instr_slice = allocation_slice = result; -} - -tsl::Status LhloDialectEmitter::GetOrCreateViewImpl( - const HloInstruction* instr, const Shape& current_shape, - xla::ShapeIndex* current_shape_index, SmallVectorImpl* values, - TokenLoweringMode token_mode) { - if (current_shape.IsTuple()) { - for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) { - current_shape_index->push_back(i); - TF_RETURN_IF_ERROR( - GetOrCreateViewImpl(instr, current_shape.tuple_shapes(i), - current_shape_index, values, token_mode)); - current_shape_index->pop_back(); - } - return ::tsl::OkStatus(); - } - if (current_shape.IsArray()) { - TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape, - *current_shape_index)); - values->push_back(v); - return ::tsl::OkStatus(); - } - if (current_shape.IsToken()) { - switch (token_mode) { - case TokenLoweringMode::kFailToLower: - return tsl::errors::Internal( - "Unexpected token kind for %s and shape index %s", - instr->ToString(), current_shape_index->ToString()); - - case TokenLoweringMode::kUseNull: - values->push_back(Value{}); - return ::tsl::OkStatus(); - } - } - return tsl::errors::Internal( - "Unexpected shape kind for %s and shape index %s", instr->ToString(), - current_shape_index->ToString()); -} - -// Returns a view for the result of an instruction. -// We first get a view for the slice in the allocation, and then may need to -// create another view to adjust the slice for the shape of the instruction. -tsl::Status LhloDialectEmitter::GetOrCreateView( - const HloInstruction* instr, SmallVectorImpl* values, - const xla::ShapeIndex& result_subset, TokenLoweringMode token_mode) { - xla::ShapeIndex shape_index = result_subset; - const Shape& sub_shape = - xla::ShapeUtil::GetSubshape(instr->shape(), shape_index); - return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values, - token_mode); -} - -tsl::Status LhloDialectEmitter::Initialize( - std::vector* ordered_allocations) { - TF_RET_CHECK(computation_.IsEntryComputation()); - - mlir::IntegerAttr unique_id = - builder_.getI32IntegerAttr(computation_.parent()->unique_id()); - module_->setAttr("hlo.unique_id", unique_id); - llvm::StringRef function_name = - computation_.name().empty() ? "__compute" - : llvm::StringRef(computation_.name().data(), - computation_.name().size()); - - // Create the function as () -> (), we'll compute the arguments from the - // buffer allocation and update the type then. - auto func_op = func::FuncOp::create(builder_.getUnknownLoc(), function_name, - builder_.getFunctionType({}, {})); - - { - // This is an optional attribute used by the XLA backend. If the resulting - // LMHLO doesn't go through XLA, this is not needed. - const Shape& shape = computation_.root_instruction()->shape(); - func_op->setAttr( - "result_xla_shape", - builder_.getStringAttr(shape.ToString(/*print_layout=*/true))); - } - Block* block = func_op.addEntryBlock(); - - for (const BufferAllocation& alloc : assignment_.Allocations()) { - if (!alloc.is_thread_local()) { - ordered_allocations->push_back(&alloc); - } - } - - if (computation_.IsEntryComputation()) { - // Sort the rather arbitrarily ordered allocations to match the input/output - // parameters. Specifically we want to sort buffer allocations in the - // following order: - // * Parameters always order before non-parameters. - // * Different parameters order by parameter number. - // * Different allocations for the same parameter order by the shape index. - // - // TODO(timshen): there should be only one non-parameter buffer, the temp - // buffer. Check on that. - const auto allocation_comparator = [](const BufferAllocation* lhs, - const BufferAllocation* rhs) { - if (lhs->is_entry_computation_parameter() != - rhs->is_entry_computation_parameter()) { - return lhs->is_entry_computation_parameter() > - rhs->is_entry_computation_parameter(); - } - if (lhs->is_entry_computation_parameter()) { - return std::tuple( - lhs->parameter_number(), lhs->param_shape_index()) < - std::tuple( - rhs->parameter_number(), rhs->param_shape_index()); - } - return false; - }; - - std::stable_sort(ordered_allocations->begin(), ordered_allocations->end(), - allocation_comparator); - } - - absl::flat_hash_map> - allocation_to_output_info; - TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( - computation_.root_instruction()->shape(), - [&](const Shape& sub_shape, xla::ShapeIndex index) -> tsl::Status { - TF_ASSIGN_OR_RETURN( - auto slice, - assignment_.GetUniqueSlice(computation_.root_instruction(), index)); - const BufferAllocation* alloc = slice.allocation(); - TF_RET_CHECK(slice.offset() == 0); - TF_RET_CHECK(slice.size() == alloc->size()); - allocation_to_output_info[alloc] = std::make_pair(&sub_shape, index); - return ::tsl::OkStatus(); - })); - - // The function signature will be composed of: - // - one memref for each of the parameters. - // - one memref for each other buffer allocation. - llvm::SmallVector args_attrs; - auto it = ordered_allocations->begin(); - while (it != ordered_allocations->end()) { - const BufferAllocation* alloc = *it; - // There are optional attributes to help the program run through XLA. XLA - // defines ExecutionInput and ExecutionOutput structures to carry - // input-output type and buffer information, therefore any information they - // need (mainly the type structure, potentially containing tuples) to be - // preserved. They are not needed if the generated LMHLO is not sent to XLA. - NamedAttrList arg_attr_list; - mlir::Type arg_type = MemRefType::get({alloc->size()}, i8_type_); - - // Propagate source location information for every HLOInstruction that - // uses this allocation. - std::vector buf_locs; - buf_locs.reserve(alloc->assigned_buffers().size()); - for (const auto& entry : alloc->assigned_buffers()) { - const xla::HloValue* hlo_value = entry.first; - buf_locs.push_back(getLocation(hlo_value->instruction())); - } - mlir::Location loc = builder_.getFusedLoc(buf_locs); - - if (alloc->is_entry_computation_parameter()) { - arg_attr_list.set("lmhlo.params", - builder_.getIndexAttr(alloc->parameter_number())); - if (!alloc->param_shape_index().empty()) { - arg_attr_list.set("lmhlo.param_shape_index", - builder_.getI64TensorAttr(llvm::ArrayRef( - alloc->param_shape_index().begin(), - alloc->param_shape_index().end()))); - } - } - // Optional: an attribute for optimization. If a kernel uses this - // allocation, but the allocation has lmhlo.constant_name, then the kernel - // will instead use the global value indicated by the name for potentially - // more optimizations (e.g. constant propagation). - if (alloc->is_constant()) { - arg_attr_list.set( - "lmhlo.constant_name", - builder_.getStringAttr( - xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc))); - } - auto iter = allocation_to_output_info.find(alloc); - if (iter != allocation_to_output_info.end()) { - const Shape* sub_shape = iter->second.first; - const xla::ShapeIndex& shape_index = iter->second.second; - if (!sub_shape->IsArray()) { - it = ordered_allocations->erase(it); - continue; - } - arg_attr_list.set("lmhlo.output_index", - builder_.getI64TensorAttr(llvm::ArrayRef( - shape_index.begin(), shape_index.end()))); - if (auto alias = computation_.parent() - ->input_output_alias_config() - .GetAliasedParameter(shape_index)) { - if (alias->must_alias()) { - arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr()); - } - } - } - block->addArgument(arg_type, loc); - allocations_[alloc] = block->getArguments().back(); - args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext())); - it++; - } - - FunctionType function_type = - builder_.getFunctionType(block->getArgumentTypes(), {}); - func_op.setType(function_type); - func_op.setAllArgAttrs(args_attrs); - - symbol_table_.insert(func_op); - builder_.setInsertionPointToEnd(block); - - auto return_op = - builder_.create(builder_.getUnknownLoc()); - builder_ = OpBuilder(return_op); - - return ::tsl::OkStatus(); -} - -tsl::Status HloToLhloModule( - const BufferAssignment& assignment, const HloModule& hlo_module, - ModuleOp module, std::vector* ordered_allocations, - absl::flat_hash_map* - lhlo_to_hlo_map) { - module.getContext() - ->loadDialect(); - - module->setLoc(mlir::NameLoc::get( - mlir::StringAttr::get(module.getContext(), hlo_module.name()))); - - // Store the HloModule's unique_id in the MLIR module. - Builder builder(module.getContext()); - module->setAttr("mhlo.unique_id", - builder.getI64IntegerAttr(hlo_module.unique_id())); - - const HloComputation* computation = hlo_module.entry_computation(); - - LhloDialectEmitter emitter(assignment, *computation, module); - TF_RETURN_IF_ERROR(emitter.Initialize(ordered_allocations)); - - const xla::HloInstructionSequence* schedule = - &hlo_module.schedule().sequence(computation); - - if (!schedule) { - return tsl::errors::Unimplemented( - "Missing sequential order for the computation"); - } - BaseScopedDiagnosticHandler status_handler(module.getContext()); - - const std::vector& ordering = schedule->instructions(); - TF_RETURN_IF_ERROR(computation->AcceptOrdered(&emitter, ordering)); - TF_RETURN_IF_ERROR(status_handler.ConsumeStatus()); - - (void)mlir::verify(module); - - if (lhlo_to_hlo_map) { - auto map = emitter.ConsumeLhloToHloMap(); - std::swap(*lhlo_to_hlo_map, map); - } - return status_handler.ConsumeStatus(); -} - -OwningOpRef HloTextToLhloTranslateFunction( - llvm::StringRef input, MLIRContext* context) { - tsl::StatusOr> maybe_module = - xla::ParseAndReturnUnverifiedModule( - absl::string_view(input.data(), input.size())); - TF_CHECK_OK(maybe_module.status()); - - OwningOpRef module = - xla::llvm_ir::CreateMlirModuleOp(UnknownLoc::get(context)); - - TF_CHECK_OK( - ConvertHloToLmhlo(std::move(maybe_module).value(), module.get(), "Host")); - - return module; -} - -} // namespace mlir diff --git a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h b/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h deleted file mode 100644 index 6a2f5f6998946..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h +++ /dev/null @@ -1,348 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ -#define XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ - -#include -#include -#include -#include - -#include "absl/types/optional.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "xla/service/buffer_assignment.h" -#include "xla/shape_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace mlir { - -// This class will process an HloModule with the supplied BufferAssignment and -// populate the MLIR ModuleOp with the computation converted in the LHLO -// dialect. -class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault { - public: - // Initializes internal data structures. It must be called before calling any - // of the visitors. - tsl::Status Initialize( - std::vector* ordered_allocations); - - LhloDialectEmitter(const xla::BufferAssignment& assignment, - const xla::HloComputation& computation, ModuleOp module) - : assignment_(assignment), - computation_(computation), - module_(module), - symbol_table_(module), - builder_(module.getContext()), - i8_type_(builder_.getIntegerType(8)) {} - - tsl::StatusOr EmitOp(const xla::HloInstruction* instr); - - static tsl::StatusOr - GetScatterDimensionNumbers(const xla::HloInstruction* instr, - mlir::MLIRContext* context); - - absl::flat_hash_map - ConsumeLhloToHloMap() { - return std::move(lhlo_to_hlo_); - } - - private: - tsl::StatusOr EmitSortOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitFusionOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitScatterOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitSelectAndScatterOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitCustomCallOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitCholesky( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitGemm( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitCublasLtMatmul( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitCublasLtMatmulF8( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnConvolution( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnConvolutionReorderVectorized( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnBatchNorm( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitDnnfMHA( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitDnnfMHABackward( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitDnnNorm( - const xla::HloCustomCallInstruction* custom_call); - xla::StatusOr EmitCubDeviceRadixSort( - const xla::HloCustomCallInstruction* custom_call); - tsl::StatusOr EmitConstant( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitInfeedOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitOutfeedOp( - const xla::HloInstruction* instr); - - template - tsl::StatusOr EmitDoneOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitAllToAllStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllToAllDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllGatherStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllGatherDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllReduceStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAllReduceDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAsyncStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitAsyncDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitReduceScatterStartOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitReduceScatterDoneOp( - const xla::HloInstruction* instr); - tsl::StatusOr - EmitCollectivePermuteStartOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitCollectivePermuteDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitRngGetAndUpdateStateOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitFftOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitTriangularSolveOp( - const xla::HloInstruction* instr); - tsl::StatusOr EmitBitcast(const xla::HloInstruction* instr); - - tsl::StatusOr EmitCaseOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitWhileOp(const xla::HloInstruction* instr); - - tsl::StatusOr EmitSendOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitSendDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitRecvOp(const xla::HloInstruction* instr); - tsl::StatusOr EmitRecvDoneOp( - const xla::HloInstruction* instr); - - tsl::StatusOr EmitCommandBufferOp( - const xla::HloInstruction* instr); - - tsl::Status ImportAsLmhloRegion(xla::HloComputation* computation, - mlir::Region* region); - - // Since LMHLO dialect does not define token types, this enum controls how - // token operand/results from XLA:HLO are lowered to MLIR. - enum class TokenLoweringMode { - kFailToLower, // Fail lowering if token inputs are encountered. - kUseNull, // Use a null Value in the operand list for each token. - // kSkip, // Skip any token inputs or outputs (not yet needed) - }; - - // Create LHLO operation operands given an XLA HLO instruction. By default, - // all XLA HLO operands and results are converted to MLIR and appended to - // `operands`. If `num_operands` is specified, only the first `num_operand` - // operands of the instruction are converted to MLIR. The function returns the - // actual number of operands and results generated for MLIR in `num_arguments` - // and `num_results`. - tsl::Status CreateOperands(const xla::HloInstruction* instr, - std::optional num_operands, - TokenLoweringMode token_mode, - SmallVectorImpl& operands, - size_t& num_arguments, size_t& num_results); - - template - tsl::StatusOr CreateOpWithoutAttrs( - const xla::HloInstruction* instr, - std::optional num_operands = std::nullopt) { - size_t unused; - return CreateOpWithoutAttrs(instr, unused, unused, num_operands); - } - - template - tsl::StatusOr CreateOpWithoutAttrs( - const xla::HloInstruction* instr, size_t& num_arguments, - size_t& num_results, std::optional num_operands = std::nullopt); - - template - OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr, - ValueRange operands); - - template - DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) { - return builder_.getI64TensorAttr( - {container.data(), static_cast(container.size())}); - } - - DenseIntElementsAttr GetWindowElements( - const xla::Window& window, - std::function getter) { - llvm::SmallVector elements; - elements.reserve(window.dimensions_size()); - for (const xla::WindowDimension& dim : window.dimensions()) { - elements.push_back(getter(dim)); - } - return GetI64DenseElementsAttr(elements); - } - - static mlir::DenseIntElementsAttr GetLayoutAttribute( - const xla::Layout& layout, Builder* builder); - - tsl::Status DefaultAction(const xla::HloInstruction* instr) final; - - // Computation parameters don't need any specific handling when they are - // visited, they are already processed when we enter a new computation. - tsl::Status HandleParameter(const xla::HloInstruction* instr) final { - return ::tsl::OkStatus(); - } - - // Helper function that recursively visits the tuple structure in - // `current_shape`, and reconstruct a matching lmhlo::TupleOp. - // Each leaf node is converted to an std.view op with corresponding offsets. - // If no tuple presents, it simply returns a view of the buffer. - tsl::Status GetOrCreateViewImpl(const xla::HloInstruction* instr, - const xla::Shape& current_shape, - xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values, - TokenLoweringMode token_mode); - - // Helper function to create view/tuple of views to a buffer for a given - // instruction result. `result_subset` can be used to for instructions that - // have a tuple result and MLIR conversion needs to convert only one of the - // tuple elements. Note that if needed, this can be extended to take a list of - // ShapeIndex values in case we need finer control on what elements of the - // output tuple to be converted to MLIR. - tsl::Status GetOrCreateView( - const xla::HloInstruction* instr, SmallVectorImpl* values, - const xla::ShapeIndex& result_subset = {}, - TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower); - - tsl::StatusOr GetOrCreateArrayView( - const xla::HloInstruction* instr, const xla::Shape& current_shape, - const xla::ShapeIndex& current_shape_index); - - tsl::StatusOr RewriteFusionOperand(const xla::HloInstruction* root, - const xla::Shape& shape, - xla::ShapeIndex* shape_index, - OpBuilder* b, Location loc); - - // Return an MLIR location for an HLO instruction. - Location getLocation(const xla::HloInstruction* inst) { - return NameLoc::get(builder_.getStringAttr(inst->name())); - } - - // This map provides access to MLIR buffers for each HLO buffer allocation. - // The MLIR buffers are all `memref<{size}xi8>` and correspond to function - // parameters. It is populated at the beginning of the processing with all - // the buffer allocations and is unchanged afterward. Every HLOInstruction - // is using a "slice" of the buffer allocation and providing shape, layout, - // and Dtype. An MLIR view is used separately to model slices into the - // allocations (see below). - llvm::DenseMap allocations_; - - // This map provides access to MLIR buffers constructed from memref arguments - // (allocations) using memref.view operation at the given offset (defined by - // slice) and result type (defined by shape). By using this cache we guarantee - // that we have a unique memref.view operation corresponding to each - // allocation slice. - absl::flat_hash_map, - Value> - allocation_slices_; - - // This map provides access to MLIR buffers for each HLO instruction, keyed - // instruction identity. A slice is contained in a BufferAllocation, and has - // an offset and a size. - // - // As for why we don't use HloInstruction*, see GetOrCreateView(), but - // mostly we want to leverage better of the aliased buffers. - // - // If the HloInstruction is a tuple, all leaf nodes are stored flattened. - // Otherwise, there will be a single buffer. - // - // An MLIR buffer is either an input parameter, or a ViewOp in the case - // where the slice is only part of its allocation. - // - // `instr_slices_` is populated lazily in the `GetOrCreateView()` helper as we - // process every instruction. - absl::flat_hash_map, - Value> - instr_slices_; - - // The BufferAssignment computed by XLA ahead of time. - const xla::BufferAssignment& assignment_; - - // The HLO module that will be converted. - const xla::HloComputation& computation_; - - // This is the MLIR module in which a function will be created for every HLO - // computation. - ModuleOp module_; - - // SymbolTable associated with the module. New functions should be added using - // this to avoid name conflicts. - mlir::SymbolTable symbol_table_; - - // The builder keeps track of the current insertion point in the MLIR - // module. - OpBuilder builder_; - // Convenient "cached" access to this widely used MLIR type (i8). - Type i8_type_; - - // Map ops returning tokens to their output (async collectives start ops, and - // point-to-point communication ops), to connect the correct done op. - absl::flat_hash_map ret_tokens_; - - // Maps each LHLO op created directly by this emitter to the corresponding HLO - // instruction. - // Note: this does not contain ops that are inside the bodies of fusions. - absl::flat_hash_map - lhlo_to_hlo_; -}; - -// Populate the MLIR `module` with the computation from the `hlo_module` using -// the provided buffer `assignment`. The returned `Status` indicates success -// or failure in the conversion. -// `lhlo_to_hlo_map`, if non-null, is populated with a mapping from generated -// top-level MLIR operations to the original HLO instructions. "top-level" means -// that ops inside the bodies of fusions are not included (but all fusions are). -// Store buffer allocations from buffer assignment in the order of inputs to the -// LMHLO entry function. -tsl::Status HloToLhloModule( - const xla::BufferAssignment& assignment, const xla::HloModule& hlo_module, - ModuleOp module, - std::vector* ordered_allocation, - absl::flat_hash_map* - lhlo_to_hlo_map = nullptr); - -OwningOpRef HloTextToLhloTranslateFunction( - llvm::StringRef input, MLIRContext* context); - -} // namespace mlir - -#endif // XLA_TRANSLATE_MHLO_TO_LHLO_WITH_XLA_MHLO_TO_LHLO_WITH_XLA_H_ diff --git a/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD b/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD deleted file mode 100644 index 1d3593a2942b6..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("@tsl//tsl:tsl.default.bzl", "filegroup") -load("//xla:glob_lit_test.bzl", "glob_lit_tests") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -glob_lit_tests( - name = "all_tests", - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = [ - "mlir", - "hlo", - "hlotxt", - ], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//xla/translate:xla-translate", - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-gpu-opt", - "//xla/translate/mhlo_to_lhlo_with_xla:xla-translate-opt", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ], -) diff --git a/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt b/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt deleted file mode 100644 index 6af4c9df2f4da..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt +++ /dev/null @@ -1,736 +0,0 @@ -// RUN: xla-translate -split-input-file -hlo-text-to-lhlo %s | FileCheck %s - -HloModule TestModule - -// CHECK-LABEL: func @TestComputation - -FusedComputation { - // CHECK: to_tensor {{.*}} {xla_shape = "f32[3,2]{0,1}"} - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} add(x, x) -} - -ENTRY TestComputation { - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation -} - -// ----- - -HloModule ScatterModule - -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - ROOT rhs = s32[] parameter(1) -} - -// CHECK-LABEL: func @main -// CHECK: "lmhlo.scatter" -// CHECK: indices_are_sorted = false -// CHECK: update_window_dims = [1] -// CHECK: inserted_window_dims = [0] -// CHECK: scatter_dims_to_operand_dims = [0] -// CHECK: index_vector_dim = 1 -// CHECK: unique_indices = false -// CHECK: ^bb0(%[[ARG5:.*]]: tensor, %[[ARG6:.*]]: tensor): -// CHECK: mhlo.return %[[ARG6]] -// CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> () -ENTRY main { - operand = s32[3,3] parameter(0) - indices = s32[2] parameter(1) - updates = s32[2,3] parameter(2) - ROOT scatter_op = s32[3,3] scatter(operand, indices, updates), - to_apply=update_s32, - update_window_dims={1}, - inserted_window_dims={0}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -// ----- - -HloModule SelectAndScatter - -%ge_F32 (lhs.5: f32[], rhs.6: f32[]) -> pred[] { - %lhs.5 = f32[] parameter(0) - %rhs.6 = f32[] parameter(1) - ROOT %compare.7 = pred[] compare(f32[] %lhs.5, f32[] %rhs.6), direction=GE -} - -%add_F32 (lhs.9: f32[], rhs.10: f32[]) -> f32[] { - %lhs.9 = f32[] parameter(0) - %rhs.10 = f32[] parameter(1) - ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10) -} - -// CHECK-LABEL: module -// CHECK: memref.global "private" constant @[[$GLOBAL:.*]] : memref = dense<0.000000e+00> -// CHECK-LABEL: func @main -// CHECK: %[[GLOBAL_MEMREF:.*]] = memref.get_global @[[$GLOBAL]] : memref -// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}}) -// CHECK: padding = dense<0> : tensor<1xi64> -// CHECK: window_dimensions = dense<3> : tensor<1xi64> -// CHECK: window_strides = dense<3> : tensor<1xi64> -// CHECK: ^bb0(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor): -// CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG0]], %[[ARG1]] -// CHECK: mhlo.return %[[COMPARE]] : tensor -// CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): -// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]] -// CHECK: mhlo.return %[[ADD]] : tensor -// CHECK: (memref<6xf32>, memref<2xf32>, memref, memref<6xf32>) -> () -ENTRY main () -> f32[6] { - %operand = f32[6]{0} parameter(0) - %source = f32[2]{0} parameter(1) - %init = f32[] constant(0) - ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32 -} - -// ----- - -HloModule SliceToDynamic - -// CHECK-LABEL: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: backend_config = "", call_target_name = "SliceToDynamic" -// CHECK-SAME: operandSegmentSizes = array -// CHECK-NOT: target_arg_mapping -// CHECK: (memref<2x2x2xi32>, memref, memref, memref, memref<2x2x2xi32>) -> () -ENTRY main { - %param = s32[2,2,2] parameter(0) - %static = s32[] parameter(1) - %dynamic = s32[] parameter(2) - ROOT %custom-call = s32[2,<=2, 2] custom-call(s32[2,2,2] %param, - s32[] %static, - s32[] %dynamic, - s32[] %static), - custom_call_target="SliceToDynamic", - backend_config="" -} - -// ----- - -HloModule Cholesky - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cholesky" -// CHECK-SAME: is_lower = true -ENTRY main { - %param = f32[3,3] parameter(0) - ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param), - custom_call_target="__cusolver$cholesky", - operand_layout_constraints={f32[3,3]}, - backend_config="{\"lower\":true}" -} - -// ----- - -HloModule Gemm - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.gemm" -// CHECK-SAME: algorithm = 7 : i64 -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 0.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -ENTRY main { - %A = f32[2,2]{1,0} parameter(0) - %B = f32[2,2]{1,0} parameter(1) - ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B), - custom_call_target="__cublas$gemm", - backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"HIGH\",\"HIGHEST\"]},\"selected_algorithm\":\"7\"}" -} - -// ----- - -HloModule CublasLtMatmul - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cublas.lt.matmul" -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 0.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - -ENTRY main { - %A = f32[2,2]{1,0} parameter(0) - %B = f32[2,2]{1,0} parameter(1) - ROOT %custom-call = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B), custom_call_target="__cublas$lt$matmul", - backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}}" -} - -// ----- - -HloModule CublasLtMatmulF8 - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.cublas.lt.matmul.f8" -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 1.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo, #mhlo] -// CHECK: (memref<16x16xf8E4M3FN>, memref<16x16xf8E4M3FN>, memref<16x16xf16>, memref, memref, memref, memref, memref<16x16xf8E4M3FN>, memref) -> () - -ENTRY main { - %A = f8e4m3fn[16,16]{1,0} parameter(0) - %B = f8e4m3fn[16,16]{1,0} parameter(1) - %C = f16[16,16]{1,0} parameter(2) - %A_SCALE = f32[] parameter(3) - %B_SCALE = f32[] parameter(4) - %C_SCALE = f32[] parameter(5) - %D_SCALE = f32[] parameter(6) - ROOT %custom-call = (f8e4m3fn[16,16]{1,0}, f32[]) custom-call(f8e4m3fn[16,16]{1,0} %A, f8e4m3fn[16,16]{1,0} %B, f16[16,16]{1,0} %C, f32[] %A_SCALE, f32[] %B_SCALE, f32[] %C_SCALE, f32[] %D_SCALE), custom_call_target="__cublas$lt$matmul$f8", - backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1.0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}}" -} - - -// ----- - -HloModule AsyncAllReduce - -// Test all-reduce-async -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -// CHECK-LABEL: func @test_async_all_reduce -// CHECK-SAME: [[BUFFER:%.*]]: memref<32xi8> -%test_async_all_reduce { - param0 = f32[8] parameter(0) - // CHECK: [[VIEW:%.*]] = memref.view [[BUFFER]]{{.*}} : memref<32xi8> to memref<8xf32> - // CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW]], [[VIEW]]) - // CHECK-SAME: channel_id = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) - // CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]]) - start = f32[8] all-reduce-start(param0), - channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add - ROOT done = f32[8] all-reduce-done(start) -} - -// ----- - -HloModule AsyncAllReduceTwoOperands - -// Test all-reduce-async -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -// CHECK-LABEL: func @test_async_all_reduce_two_operands -// CHECK-SAME: [[BUFFER0:%.*]]: memref<32xi8> -// CHECK-SAME: [[BUFFER1:%.*]]: memref<36xi8> -%test_async_all_reduce_two_operands { - param0 = f32[8] parameter(0) - param1 = f32[9] parameter(1) - // CHECK: [[VIEW0:%.*]] = memref.view [[BUFFER0]]{{.*}} : memref<32xi8> to memref<8xf32> - // CHECK: [[VIEW1:%.*]] = memref.view [[BUFFER1]]{{.*}} : memref<36xi8> to memref<9xf32> - // CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW0]], [[VIEW1]], [[VIEW0]], [[VIEW1]]) - // CHECK-SAME: channel_id = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64> - // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] - // CHECK: mhlo.return [[ADD]] : tensor - // CHECK: }) - // CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]]) - start = (f32[8], f32[9]) all-reduce-start(param0, param1), - channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add - ROOT done = (f32[8], f32[9]) all-reduce-done(start) -} - -// ----- - -HloModule ConvForward - -// CHECK-LABEL: func @main -// CHECK: lmhlo_gpu.conv_forward -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [0, 0], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]} -// CHECK-SAME: algorithm = 2 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [3, 2, 1, 0] -// CKECK-SAME: operand_1_layout = [3, 2, 1, 0] -// CHECK-SAME: result_layout = [3, 2, 1, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK: (memref<4x256x3x3xf32>, memref<256x256x2x2xf32>, memref<4x256x2x2xf32>, memref<65536xui8>) -ENTRY main { - %input = f32[4,256,3,3]{3,2,1,0} parameter(0) - %filter = f32[256,256,2,2]{3,2,1,0} parameter(1) - ROOT %custom-call.1 = (f32[4,256,2,2]{3,2, 1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter), - window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, - custom_call_target="__cudnn$convForward", - backend_config="{\"algorithm\": {\"algo_id\":\"2\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" -} - -// ----- - -// CHECK: func @main -// CHECK: lmhlo_gpu.conv_forward_fused -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} -// CHECK-SAME: activation_mode = #lmhlo_gpu -// CHECK-SAME: algorithm = 0 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [1, 3, 2, 0] -// CHECK-SAME: operand_1_layout = [2, 1, 0, 3] -// CHECK-SAME: result_layout = [1, 3, 2, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: precision_config = [#mhlo, #mhlo, #mhlo] -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> () - -HloModule FusedConvForward - -ENTRY main { - %input = f16[1,17,9,9]{1,3,2,0} parameter(0) - %filter = f16[3,3,17,32]{2,1,0,3} parameter(1) - %bias = f16[32]{0} parameter(2) - ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\": {\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}" -} - -// ----- - -// CHECK: func @main -// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input -// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1] -// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} -// CHECK-SAME: activation_mode = #lmhlo_gpu -// CHECK-SAME: algorithm = 0 -// CHECK-SAME: tensor_ops_enabled = false -// CHECK-SAME: operand_0_layout = [1, 3, 2, 0] -// CHECK-SAME: operand_1_layout = [2, 1, 0, 3] -// CHECK-SAME: result_layout = [1, 3, 2, 0] -// CHECK-SAME: batch_group_count = 1 : i64 -// CHECK-SAME: feature_group_count = 1 : i64 -// CHECK-SAME: precision_config = [#mhlo, #mhlo, #mhlo, #mhlo] -// CHECK-SAME: result_scale = 1.000000e+00 : f64 -// CHECK-SAME: side_input_scale = 1.000000e+00 -// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> () - -HloModule FusedConvForwardSideInput - -ENTRY main { - %input = f16[1,17,9,9]{1,3,2,0} parameter(0) - %filter = f16[3,3,17,32]{2,1,0,3} parameter(1) - %bias = f16[32]{0} parameter(2) - %side = f16[32]{0} parameter(3) - ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}" -} - -// ----- - -HloModule Infeed - -// CHECK: func @main -// CHECK: "lmhlo.infeed" -// CHECK-SAME: (memref<3xf32>) -> () -ENTRY main { - %tok = token[] parameter(0) - ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok) -} - -// ----- - -HloModule Outfeed - -// CHECK: func @main -// CHECK: "lmhlo.outfeed" -// CHECK-SAME: config = "" -// CHECK-SAME: (memref<3xf32>) -> () -ENTRY main { - %source = f32[3] parameter(0) - %tok = token[] parameter(1) - ROOT %o = token[] outfeed(f32[3] %source, token[] %tok) -} - -// ----- - -HloModule Outfeed - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "foo" -// CHECK: "lmhlo.outfeed" -// CHECK-SAME: config = "" -// CHECK-SAME: (memref<3xf32>, memref<5xf16>) -> () -ENTRY main { - %tok = token[] parameter(0) - %tuple = (f32[3], f16[5]) custom-call(),custom_call_target="foo" - ROOT %o = token[] outfeed((f32[3], f16[5]) %tuple, token[] %tok) -} - -// ----- - -HloModule TestModule - -// CHECK: func @main -// CHECK: "lmhlo.rng_get_and_update_state"(%{{.*}}) <{delta = 131072 : i64}> : (memref<2xui64>) -> () -ENTRY main { - ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072 -} - -// ----- - -HloModule TestReplicaId - -// CHECK: func @main -// CHECK: "lmhlo.replica_id" -// CHECK-SAME: (memref) -> () -ENTRY main { - ROOT %replica_id = u32[] replica-id() -} - -// ----- - -HloModule fft - -// CHECK: func @main -// CHECK: "lmhlo.fft" -// CHECK-SAME: fft_length = dense<[8, 32]> : tensor<2xi64> -// CHECK-SAME: fft_type = #mhlo -ENTRY main { - %input = c64[5,8,32] parameter(0) - ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32} -} - -// ----- - -HloModule TriangularSolve_module - -// CHECK: func @main -// CHECK: "lmhlo.triangular_solve" -// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex> -// CHECK-SAME: left_side = false -// CHECK-SAME: lower = true -// CHECK-SAME: transpose_a = #mhlo -// CHECK-SAME: unit_diagonal = false -ENTRY main { - %a = f32[4,4]{1,0} parameter(0) - %b = f32[3,4]{1,0} parameter(1) - ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE -} - -// ----- - -HloModule CustomCallWithTypedFFIBackendConfig - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: api_version = 4 : i32 -// CHECK-SAME: backend_config = { -// CHECK-SAME: user_attr0 = 123 : i32 -// CHECK-SAME: user_attr1 = dense<42> : tensor -// CHECK-SAME: } -// CHECK-SAME: num_args = 1 -// CHECK-SAME: num_results = 2 -// CHECK-SAME: args_to_target_args = [] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - ROOT %call = (f32[3], token[]) custom-call (%tok), custom_call_target="foo", - api_version=API_VERSION_TYPED_FFI, - backend_config="{user_attr0 = 123 : i32, user_attr1 = dense<42> : tensor}" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 1 -// CHECK-SAME: num_results = 2 -// CHECK-SAME: args_to_target_args = [] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - ROOT %call = (f32[3], token[]) custom-call (%tok), custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 3 -// CHECK-SAME: num_results = 3 -// CHECK-SAME: args_to_target_args = [1] -// CHECK-SAME: results_to_target_results = [0, 2] -ENTRY main { - %tok = token[] parameter(0) - %input = f32[5,8,32] parameter(1) - ROOT %call = (f32[3]{0}, token[], f32[3]) custom-call (%tok, %input, %tok), - custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 3 -// CHECK-SAME: num_results = 1 -// CHECK-SAME: args_to_target_args = [1] -// CHECK-SAME: results_to_target_results = [0] -ENTRY main { - %tok = token[] parameter(0) - %input = f32[5,8,32] parameter(1) - ROOT %call = f32[3] custom-call (%tok, %input, %tok), - custom_call_target="foo", - backend_config="" -} - -// ----- - -HloModule CustomCallWithTokens - -// CHECK: func @main -// CHECK: "lmhlo.custom_call" -// CHECK: num_args = 1 -// CHECK-SAME: num_results = 4 -// CHECK-SAME: args_to_target_args = [0] -// CHECK-SAME: results_to_target_results = [1] -ENTRY main { - %input = f32[5,8,32] parameter(0) - ROOT %call = (token[], f32[3]{0}, token[], token[]) custom-call (%input), - custom_call_target="foo", - backend_config="" -} - -// ----- -// CHECK: func @main -// CHECK: "lmhlo.while"(%{{.*}}) ({ -HloModule WhileConstantCondition - -%body { - ROOT %parameter.5 = (f32[5]{0}) parameter(0) -} - -%cond { - %parameter.12 = (f32[5]{0}) parameter(0) - ROOT %constant_1 = pred[] constant(false) -} - -ENTRY %main (parameter.1: f32[5]) -> (f32[5]) { - %parameter.1 = f32[5]{0} parameter(0) - %tuple = (f32[5]{0}) tuple(f32[5]{0} %parameter.1) - ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body -} - -// ----- - -HloModule CustomCallNoComputation - -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "__custom" - -ENTRY main { - param = f32[] parameter(0) - ROOT cr = f32[] custom-call(param), custom_call_target="__custom" -} - -// ----- - -HloModule CustomCallWithComputation - -// CHECK: "lmhlo.custom_call" -// CHECK: call_target_name = "__custom" -// CHECK: %0 = mhlo.add -// CHECK: mhlo.return %0 - -computation1 { - param_0 = f32[] parameter(0) - ROOT r = f32[] add(param_0, param_0) -} - -ENTRY main { - param = f32[] parameter(0) - ROOT cr = f32[] custom-call(param), custom_call_target="__custom", - to_apply=computation1 -} - -// ----- - -HloModule Send - -// CHECK: func @main -// CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.params = 1 : index} -// CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][] -// CHECK: %[[TOKEN:.*]] = "lmhlo.send"(%[[VIEW]]) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK: frontend_attributes = {_xla_dcn_recv_channel = "2", -// CHECK: _xla_host_transfer_handler_name = "undef", -// CHECK: _xla_host_transfer_rendezvous = "undef"} -// CHECK: is_host_transfer = true -// CHECK: : (memref<4xf32>) -> !mhlo.token -// CHECK: "lmhlo.send_done"(%0) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK is_host_transfer = true -// CHECK: : (!mhlo.token) -> () -ENTRY main { - %tok = token[] parameter(0) - %buf = f32[4]{0} parameter(1) - %send = (f32[4]{0}, u32[], token[]) send(f32[4]{0} %buf, token[] %tok), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_dcn_recv_channel="2",_xla_host_transfer_handler_name="undef",_xla_host_transfer_rendezvous="undef"} - ROOT %send-done = token[] send-done((f32[4]{0}, u32[], token[]) %send), channel_id=1, is_host_transfer=true -} - -// ----- - -HloModule Recv - -// CHECK: func @main -// CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.output_index = dense<0> : tensor<1xi64>} -// CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][] -// CHECK: %[[TOKEN:.*]] = "lmhlo.recv"(%[[VIEW]]) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK: frontend_attributes = {_xla_host_transfer_handler_name = "undef", -// CHECK: _xla_host_transfer_rendezvous = "undef"} -// CHECK: is_host_transfer = true -// CHECK: : (memref<4xf32>) -> !mhlo.token -// CHECK: "lmhlo.recv_done"(%0) -// CHECK: channel_handle = #mhlo.channel_handle, -// CHECK is_host_transfer = true -// CHECK: : (!mhlo.token) -> () -ENTRY main { - %tok = token[] parameter(0) - %recv = (f32[4]{0}, u32[], token[]) recv(token[] %tok), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="undef",_xla_host_transfer_rendezvous="undef"} - ROOT %recv-done = (f32[4]{0}, token[]) recv-done((f32[4]{0}, u32[], token[]) %recv), channel_id=1, is_host_transfer=true -} - -// ----- - -HloModule TestAllGatherAsync - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1} - ROOT ag = f32[10,80] all-gather-done(ags) -} - -// ----- - -HloModule AsyncReduceScatter - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.reduce_scatter_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> -// CHECK-SAME: scatter_dimension = 0 -// CHECK-SAME: use_global_device_ids = false -// CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]] -// CHECK: mhlo.return [[ADD]] : tensor -// CHECK: }) : -// CHECK ""lmhlo_gpu.reduce_scatter_done"(%[[TOKEN]]) - -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -reduce_scatter { - p0 = f32[8] parameter(0) - ROOT result = f32[4] reduce-scatter(p0), replica_groups={{0,1}}, - dimensions={0}, to_apply=add -} - -ENTRY main { - input = f32[8] parameter(0) - rs-start = ((f32[8]), f32[4]) async-start(input), calls=reduce_scatter - ROOT rs-done = f32[4] async-done(rs-start), calls=reduce_scatter -} - -// ----- - -HloModule AsyncAllToAll - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_to_all_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> -// CHECK ""lmhlo_gpu.all_to_all_done"(%[[TOKEN]]) - -all_to_all { - p0 = f32[128,4] parameter(0) - ROOT a2a = f32[128,4] all-to-all(p0), replica_groups={{0,1}} -} - -ENTRY main { - p0 = f32[128,4] parameter(0) - a2a-start = ((f32[128,4]), f32[128,4]) async-start(p0), calls=all_to_all - ROOT a2a-done = f32[128,4] async-done(a2a-start), calls=all_to_all -} - -// ----- - -HloModule TestAllGatherAsyncWithSyncFlagFalse - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME: is_sync = false -// CHECK-SAME: no_parallel_custom_call = false -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1} - ROOT ag = f32[10,80] all-gather-done(ags) -} - -// ----- - -HloModule TestAllGatherAsyncWithSyncFlagTrue - -// CHECK: func @main -// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < -// CHECK-SAME: all_gather_dimension = 1 : i64 -// CHECK-SAME: is_sync = true -// CHECK-SAME: no_parallel_custom_call = true -// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -// CHECK-SAME: use_global_device_ids = false -// CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) -ENTRY main { - param0 = f32[10,20] parameter(0) - ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1}, backend_config="{\"is_sync\":true, \"no_parallel_custom_call\":true}" - ROOT ag = f32[10,80] all-gather-done(ags) -} diff --git a/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt b/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt deleted file mode 100644 index 0c7fc220f73af..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/tests/no_opt_ops.hlotxt +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: xla-translate -split-input-file -hlo-text-to-lhlo %s | FileCheck %s - -HloModule indexed_conditional - -%Negate (x: f32[]) -> f32[] { - %x = f32[] parameter(0) - ROOT %negate = f32[] negate(f32[] %x) -} - -%NegateCond (x: f32[]) -> f32[] { - %x = f32[] parameter(0) - ROOT %negate = f32[] fusion(f32[] %x), kind=kLoop, calls=%Negate -} - -%Identity (y: f32[]) -> f32[] { - %y = f32[] parameter(0) - ROOT %copy = f32[] copy(f32[] %y) -} - -%IdentityCond (x: f32[]) -> f32[] { - %y = f32[] parameter(0) - ROOT %copy = f32[] fusion(f32[] %y), kind=kLoop, calls=%Identity -} - -%Floor (z: f32[]) -> f32[] { - %z = f32[] parameter(0) - ROOT %floor = f32[] floor(f32[] %z) -} - -%FloorCond (x: f32[]) -> f32[] { - %z = f32[] parameter(0) - ROOT %floor = f32[] fusion(f32[] %z), kind=kLoop, calls=%Floor -} - -// CHECK: %{{.*}} = memref.view -// CHECK: "lmhlo.case"(%{{.*}}) ({ -// CHECK: mhlo.negate -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.copy -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.floor -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }) : (memref) -> () - -ENTRY %Parameters1.v4 () -> (f32[]) { - %constant = s32[] parameter(0) - %constant.1 = f32[] parameter(1) - %constant.2 = f32[] parameter(2) - %constant.3 = f32[] parameter(3) - %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%NegateCond, %IdentityCond, %FloorCond} - ROOT %t = (f32[]) tuple(%conditional) -} - -// ----- - -HloModule WhileWithScalarS32Result_module - -%Add (a: s32[], b: s32[]) -> s32[] { - %a = s32[] parameter(0) - %b = s32[] parameter(1) - ROOT %add = s32[] add(s32[] %a, s32[] %b) -} - -%body.v3 (prev.1: s32[]) -> s32[] { - %constant = s32[] constant(1) - %prev.1 = s32[] parameter(0) - ROOT %add = s32[] fusion(s32[] %constant, s32[] %prev.1), kind=kLoop, calls=%Add -} - -%Compare (a: s32[], b: s32[]) -> pred[] { - %a = s32[] parameter(0) - %b = s32[] parameter(1) - ROOT %greater-than = pred[] compare(s32[] %a, s32[] %b), direction=GT -} - -%condition.v3 (prev.2: s32[]) -> pred[] { - %constant.1 = s32[] constant(5) - %prev.2 = s32[] parameter(0) - ROOT %greater-than = pred[] fusion(s32[] %constant.1, s32[] %prev.2), kind=kLoop, calls=%Compare -} - -// CHECK: %{{.*}} = memref.view -// CHECK: "lmhlo.while"(%{{.*}}) ({ -// CHECK: mhlo.compare -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }, { -// CHECK: mhlo.add -// CHECK: "lmhlo.terminator"() : () -> () -// CHECK: }) : (memref) -> () -ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { - %constant.2 = s32[] constant(0) - ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3 -} diff --git a/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt b/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt deleted file mode 100644 index d74ec4e3434c9..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: xla-translate -hlo-text-to-lhlo %s | FileCheck %s - -HloModule TestModule - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> - -Fusion { - x = f32[3, 2]{1,0} parameter(0) - ROOT x.copy = f32[3, 2]{0,1} copy(x) -} - -// CHECK: func @TestComputation -ENTRY TestComputation { - x = f32[3, 2]{1,0} parameter(0) - - // CHECK: %[[VIEW:.*]] = memref.view {{.*}} : memref<24xi8> to memref<3x2xf32> - // CHECK: "lmhlo.fusion"() <{backend_config = "{{.*}}"}> ({ - // CHECK: %[[VAL2:.*]] = bufferization.to_tensor %[[VIEW]] : memref<3x2xf32> - // CHECK: %[[VAL3:.*]] = mhlo.copy %[[VAL2]] { - // CHECK-SAME: result_layout = dense<[0, 1]> - // CHECK-SAME: xla_shape = "f32[3,2]{0,1}" - // CHECK-SAME: } : tensor<3x2xf32> - // CHECK: bufferization.materialize_in_destination %[[VAL3:.*]] in - // CHECK-SAME: writable %{{.*}} : (tensor<3x2xf32>, memref<3x2xf32, #[[MAP]]>) - // CHECK: "lmhlo.terminator"() : () -> () - // CHECK: }) : () -> () - ROOT fusion = f32[3, 2]{0,1} fusion(f32[3, 2]{1,0} x), kind=kLoop, calls=Fusion -} diff --git a/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc b/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc deleted file mode 100644 index 3ed2141a4ad3a..0000000000000 --- a/xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" - -//----------------------------------------------------------------------------// -// Hooks for tf-mlir-translate -//----------------------------------------------------------------------------/ - -// MHLO doesn't support explicit layouts, while XLA service does. -// TODO(timshen): remove it once MHLO supports explicit layouts. -static mlir::TranslateToMLIRRegistration HloTextToLhloMlirTranslate( - "hlo-text-to-lhlo", "hlo-text-to-lhlo", - [](llvm::StringRef input, mlir::MLIRContext* context) { - return mlir::HloTextToLhloTranslateFunction(input, context); - }); diff --git a/xla/translate/xla_translate_main.cc b/xla/translate/xla_translate_main.cc index d444667d3f593..241c8aa0993fd 100644 --- a/xla/translate/xla_translate_main.cc +++ b/xla/translate/xla_translate_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,13 @@ limitations under the License. #include #include +#include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc b/xla/translate/xla_translate_opt_main.cc similarity index 84% rename from xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc rename to xla/translate/xla_translate_opt_main.cc index 7f48958d4b92d..6b286f34cabd8 100644 --- a/xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc +++ b/xla/translate/xla_translate_opt_main.cc @@ -13,21 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/InitLLVM.h" #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "tsl/platform/init_main.h" int main(int argc, char **argv) { - // TODO(jreiffers): Move this to a more appropriate place. It is used by both - // translate/mhlo_to_lhlo_with_xla and mlir/framework for testing. - llvm::InitLLVM y(argc, argv); int dummyArgc = 1; tsl::port::InitMain(argv[0], &dummyArgc, &argv); diff --git a/xla/tsl/c/BUILD b/xla/tsl/c/BUILD new file mode 100644 index 0000000000000..5a93daf0433fe --- /dev/null +++ b/xla/tsl/c/BUILD @@ -0,0 +1,117 @@ +# Description: +# C API for TensorFlow, for use by client language bindings. + +load("@tsl//tsl:tsl.bzl", "internal_visibility", "tsl_copts", "tsl_gpu_library") + +# buildifier: disable=same-origin-load +load("@tsl//tsl:tsl.default.bzl", "filegroup") +load("@tsl//tsl/platform:build_config.bzl", "tsl_cc_test") +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +# ----------------------------------------------------------------------------- +# Public targets + +filegroup( + name = "headers", + srcs = [ + "tsl_status.h", + ], + visibility = internal_visibility(["//tensorflow:__subpackages__"]), +) + +filegroup( + name = "srcs", + srcs = glob( + [ + "*.cc", + "*.h", + ], + exclude = [ + "*test*", + ], + ), + visibility = internal_visibility([ + "//tensorflow/c:__subpackages__", + ]), +) + +tsl_gpu_library( + name = "c_api", + hdrs = [ + "tsl_status.h", + ], + copts = tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + ":tsl_status_internal", + ], +) + +tsl_gpu_library( + name = "tsl_status_internal", + hdrs = [ + "tsl_status.h", + "tsl_status_internal.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@tsl//tsl/platform:status", + ], +) + +cc_library( + name = "tsl_status", + srcs = ["tsl_status.cc"], + hdrs = ["tsl_status.h"], + visibility = ["//visibility:public"], + deps = [ + ":tsl_status_internal", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + ], +) + +tsl_cc_test( + name = "tsl_status_test", + srcs = ["tsl_status_test.cc"], + deps = [ + ":tsl_status", + ":tsl_status_internal", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tsl_status_headers", + hdrs = ["tsl_status.h"], + visibility = ["//visibility:public"], +) + +tsl_gpu_library( + name = "tsl_status_helper", + srcs = ["tsl_status_helper.cc"], + hdrs = ["tsl_status_helper.h"], + visibility = ["//visibility:public"], + deps = [ + ":tsl_status", + ":tsl_status_internal", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + ], +) + +filegroup( + name = "tsl_status_internal_headers", + srcs = ["tsl_status_internal.h"], + visibility = internal_visibility([ + "//tensorflow/c:__subpackages__", + ]), +) diff --git a/third_party/tsl/tsl/c/tsl_status.cc b/xla/tsl/c/tsl_status.cc similarity index 95% rename from third_party/tsl/tsl/c/tsl_status.cc rename to xla/tsl/c/tsl_status.cc index 5be154e52c99e..d6f86a71aa7ae 100644 --- a/third_party/tsl/tsl/c/tsl_status.cc +++ b/xla/tsl/c/tsl_status.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status.h" #include -#include "tsl/c/tsl_status_internal.h" +#include "xla/tsl/c/tsl_status_internal.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -31,7 +31,7 @@ void TSL_DeleteStatus(TSL_Status* s) { delete s; } void TSL_SetStatus(TSL_Status* s, TSL_Code code, const char* msg) { if (code == TSL_OK) { - s->status = ::tsl::OkStatus(); + s->status = absl::OkStatus(); return; } s->status = diff --git a/third_party/tsl/tsl/c/tsl_status.h b/xla/tsl/c/tsl_status.h similarity index 96% rename from third_party/tsl/tsl/c/tsl_status.h rename to xla/tsl/c/tsl_status.h index 6e332747c3b3f..017456015ada1 100644 --- a/third_party/tsl/tsl/c/tsl_status.h +++ b/xla/tsl/c/tsl_status.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_C_TSL_STATUS_H_ -#define TENSORFLOW_TSL_C_TSL_STATUS_H_ +#ifndef XLA_TSL_C_TSL_STATUS_H_ +#define XLA_TSL_C_TSL_STATUS_H_ #ifdef __cplusplus extern "C" { @@ -89,4 +89,4 @@ extern const char* TSL_Message(const TSL_Status* s); } /* end extern "C" */ #endif -#endif // TENSORFLOW_TSL_C_TSL_STATUS_H_ +#endif // XLA_TSL_C_TSL_STATUS_H_ diff --git a/third_party/tsl/tsl/c/tsl_status_helper.cc b/xla/tsl/c/tsl_status_helper.cc similarity index 97% rename from third_party/tsl/tsl/c/tsl_status_helper.cc rename to xla/tsl/c/tsl_status_helper.cc index bd2d8bce86492..ca1c8b2dbe322 100644 --- a/third_party/tsl/tsl/c/tsl_status_helper.cc +++ b/xla/tsl/c/tsl_status_helper.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/c/tsl_status_helper.h" +#include "xla/tsl/c/tsl_status_helper.h" -#include "tsl/c/tsl_status_internal.h" +#include "xla/tsl/c/tsl_status_internal.h" #include "tsl/platform/errors.h" namespace tsl { diff --git a/third_party/tsl/tsl/c/tsl_status_helper.h b/xla/tsl/c/tsl_status_helper.h similarity index 83% rename from third_party/tsl/tsl/c/tsl_status_helper.h rename to xla/tsl/c/tsl_status_helper.h index 80cc80ffab321..905785dc67838 100644 --- a/third_party/tsl/tsl/c/tsl_status_helper.h +++ b/xla/tsl/c/tsl_status_helper.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_C_TSL_STATUS_HELPER_H_ -#define TENSORFLOW_TSL_C_TSL_STATUS_HELPER_H_ +#ifndef XLA_TSL_C_TSL_STATUS_HELPER_H_ +#define XLA_TSL_C_TSL_STATUS_HELPER_H_ #include -#include "tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status.h" #include "tsl/platform/status.h" namespace tsl { @@ -29,4 +29,4 @@ absl::StatusCode StatusCodeFromTSLCode(TSL_Code code); } // namespace tsl -#endif // TENSORFLOW_TSL_C_TSL_STATUS_HELPER_H_ +#endif // XLA_TSL_C_TSL_STATUS_HELPER_H_ diff --git a/third_party/tsl/tsl/c/tsl_status_internal.h b/xla/tsl/c/tsl_status_internal.h similarity index 83% rename from third_party/tsl/tsl/c/tsl_status_internal.h rename to xla/tsl/c/tsl_status_internal.h index 8da822511cbaf..132adc62dac66 100644 --- a/third_party/tsl/tsl/c/tsl_status_internal.h +++ b/xla/tsl/c/tsl_status_internal.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_C_TSL_STATUS_INTERNAL_H_ -#define TENSORFLOW_TSL_C_TSL_STATUS_INTERNAL_H_ +#ifndef XLA_TSL_C_TSL_STATUS_INTERNAL_H_ +#define XLA_TSL_C_TSL_STATUS_INTERNAL_H_ #include "tsl/platform/status.h" @@ -22,7 +22,7 @@ limitations under the License. // and should not be depended on. struct TSL_Status { - tsl::Status status; + absl::Status status; }; -#endif // TENSORFLOW_TSL_C_TSL_STATUS_INTERNAL_H_ +#endif // XLA_TSL_C_TSL_STATUS_INTERNAL_H_ diff --git a/third_party/tsl/tsl/c/tsl_status_test.cc b/xla/tsl/c/tsl_status_test.cc similarity index 95% rename from third_party/tsl/tsl/c/tsl_status_test.cc rename to xla/tsl/c/tsl_status_test.cc index 22ccc697dc233..b4518644f837f 100644 --- a/third_party/tsl/tsl/c/tsl_status_test.cc +++ b/xla/tsl/c/tsl_status_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/c/tsl_status.h" +#include "xla/tsl/c/tsl_status.h" #include #include #include -#include "tsl/c/tsl_status_internal.h" +#include "xla/tsl/c/tsl_status_internal.h" #include "tsl/platform/errors.h" #include "tsl/platform/test.h" diff --git a/xla/tsl/cuda/BUILD.bazel b/xla/tsl/cuda/BUILD.bazel new file mode 100644 index 0000000000000..def74d56e8ca7 --- /dev/null +++ b/xla/tsl/cuda/BUILD.bazel @@ -0,0 +1,297 @@ +# Description: +# Stubs for dynamically loading CUDA. + +load( + "@tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) +load( + "@tsl//tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", + "if_cuda_is_configured", +) +load("//xla/tsl/cuda:stub.bzl", "cuda_stub") + +package( + licenses = ["notice"], +) + +cuda_stub( + name = "cublas", + srcs = ["cublas.symbols"], +) + +cc_library( + name = "cublas", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cublas_stub.cc", + "cublas.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags( + "nvidia/cublas/lib", + )), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cublas.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@com_google_absl//absl/container:flat_hash_set", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cublasLt", + srcs = ["cublasLt.symbols"], +) + +cc_library( + name = "cublas_lt", + srcs = if_cuda_is_configured([ + "cublasLt_stub.cc", + "cublasLt.tramp.S", + ]), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cublasLt.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cuda", + srcs = ["cuda.symbols"], +) + +cc_library( + name = "cuda", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cuda_stub.cc", + "cuda.tramp.S", + ]), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cuda.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cudart", + srcs = ["cudart.symbols"], +) + +cc_library( + name = "cudart", # buildifier: disable=duplicated-name + srcs = select({ + # include dynamic loading implementation only when if_cuda_is_configured and build dynamically + "@tsl//tsl:is_cuda_enabled_and_oss": [ + "cudart.tramp.S", + "cudart_stub.cc", + ], + "//conditions:default": [], + }), + linkopts = select({ + "@tsl//tsl:is_cuda_enabled_and_oss": cuda_rpath_flags("nvidia/cuda_runtime/lib"), + "//conditions:default": [], + }), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cudart.inc"], + visibility = ["//visibility:public"], + deps = select({ + "@tsl//tsl:is_cuda_enabled_and_oss": [ + ":cuda", + "@com_google_absl//absl/container:flat_hash_set", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:load_library", + "@tsl//tsl/platform:logging", + ], + "//conditions:default": [], + }), +) + +cuda_stub( + name = "cudnn", + srcs = ["cudnn.symbols"], +) + +cc_library( + name = "cudnn", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cudnn_stub.cc", + "cudnn.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cudnn/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cudnn.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@com_google_absl//absl/container:flat_hash_map", + "@local_config_cuda//cuda:cudnn_header", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cc_library( + name = "nccl_rpath", + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tensorrt_rpath", + linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")), + visibility = ["//visibility:public"], +) + +cuda_stub( + name = "cufft", + srcs = ["cufft.symbols"], +) + +cc_library( + name = "cufft", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cufft_stub.cc", + "cufft.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cufft/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cufft.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cupti", + srcs = ["cupti.symbols"], +) + +cc_library( + name = "cupti", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cupti_stub.cc", + "cupti.tramp.S", + ]), + data = if_cuda_is_configured(["@local_config_cuda//cuda:cupti_dsos"]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cuda_cupti/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cupti.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cupti_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cusolver", + srcs = ["cusolver.symbols"], +) + +cc_library( + name = "cusolver", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cusolver_stub.cc", + "cusolver.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusolver/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cusolver.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "cusparse", + srcs = ["cusparse.symbols"], +) + +cc_library( + name = "cusparse", # buildifier: disable=duplicated-name + srcs = if_cuda_is_configured([ + "cusparse_stub.cc", + "cusparse.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cusparse/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["cusparse.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) + +cuda_stub( + name = "nccl", + srcs = ["nccl.symbols"], +) + +cc_library( + name = "nccl_stub", + srcs = if_cuda_is_configured([ + "nccl_stub.cc", + "nccl.tramp.S", + ]), + linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), + local_defines = [ + "IMPLIB_EXPORT_SHIMS=1", + ], + textual_hdrs = ["nccl.inc"], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + "@com_google_absl//absl/container:flat_hash_set", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_nccl//:nccl_headers", + "@tsl//tsl/platform:dso_loader", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:load_library", + ]), +) diff --git a/third_party/tsl/tsl/cuda/cublas.symbols b/xla/tsl/cuda/cublas.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cublas.symbols rename to xla/tsl/cuda/cublas.symbols diff --git a/third_party/tsl/tsl/cuda/cublasLt.symbols b/xla/tsl/cuda/cublasLt.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cublasLt.symbols rename to xla/tsl/cuda/cublasLt.symbols diff --git a/third_party/tsl/tsl/cuda/cublasLt_stub.cc b/xla/tsl/cuda/cublasLt_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cublasLt_stub.cc rename to xla/tsl/cuda/cublasLt_stub.cc index df4e73bebc126..db60995d59fa5 100644 --- a/third_party/tsl/tsl/cuda/cublasLt_stub.cc +++ b/xla/tsl/cuda/cublasLt_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. @@ -33,15 +34,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cublasLt.inc" +#include "xla/tsl/cuda/cublasLt.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cublas_stub.cc b/xla/tsl/cuda/cublas_stub.cc similarity index 97% rename from third_party/tsl/tsl/cuda/cublas_stub.cc rename to xla/tsl/cuda/cublas_stub.cc index 814d64d75d8d6..a4b7fcbb828b6 100644 --- a/third_party/tsl/tsl/cuda/cublas_stub.cc +++ b/xla/tsl/cuda/cublas_stub.cc @@ -24,7 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLAS API by forwarding to cuBLAS loaded from the DSO. // Note that it does not implement the v1 interface. @@ -43,15 +44,14 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char *kSymbols[] = { -#include "tsl/cuda/cublas.inc" +#include "xla/tsl/cuda/cublas.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char *); diff --git a/third_party/tsl/tsl/cuda/cuda.symbols b/xla/tsl/cuda/cuda.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cuda.symbols rename to xla/tsl/cuda/cuda.symbols diff --git a/third_party/tsl/tsl/cuda/cuda_stub.cc b/xla/tsl/cuda/cuda_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cuda_stub.cc rename to xla/tsl/cuda/cuda_stub.cc index a199d4cc70044..e33535c16e33c 100644 --- a/third_party/tsl/tsl/cuda/cuda_stub.cc +++ b/xla/tsl/cuda/cuda_stub.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUDA driver API by forwarding to CUDA loaded from the DSO. @@ -36,15 +37,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cuda.inc" +#include "xla/tsl/cuda/cuda.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cudart.symbols b/xla/tsl/cuda/cudart.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cudart.symbols rename to xla/tsl/cuda/cudart.symbols diff --git a/third_party/tsl/tsl/cuda/cudart_stub.cc b/xla/tsl/cuda/cudart_stub.cc similarity index 92% rename from third_party/tsl/tsl/cuda/cudart_stub.cc rename to xla/tsl/cuda/cudart_stub.cc index a3797b5c751cd..7064a72541eef 100644 --- a/third_party/tsl/tsl/cuda/cudart_stub.cc +++ b/xla/tsl/cuda/cudart_stub.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" namespace { void *GetDsoHandle() { @@ -39,13 +40,13 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; - auto env = tsl::Env::Default(); - env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError(); + tsl::internal::GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol) + .IgnoreError(); return symbol; } const char *kSymbols[] = { -#include "tsl/cuda/cudart.inc" +#include "xla/tsl/cuda/cudart.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char *); diff --git a/third_party/tsl/tsl/cuda/cudnn.symbols b/xla/tsl/cuda/cudnn.symbols similarity index 98% rename from third_party/tsl/tsl/cuda/cudnn.symbols rename to xla/tsl/cuda/cudnn.symbols index 2c4dbd71030b3..95c46295e1dcb 100644 --- a/third_party/tsl/tsl/cuda/cudnn.symbols +++ b/xla/tsl/cuda/cudnn.symbols @@ -3,6 +3,7 @@ cudnnActivationForward cudnnAddTensor cudnnAdvInferVersionCheck cudnnAdvTrainVersionCheck +cudnnAdvVersionCheck cudnnBackendCreateDescriptor cudnnBackendDestroyDescriptor cudnnBackendExecute @@ -20,6 +21,7 @@ cudnnCTCLoss cudnnCTCLoss_v8 cudnnCnnInferVersionCheck cudnnCnnTrainVersionCheck +cudnnCnnVersionCheck cudnnConvolutionBackwardBias cudnnConvolutionBackwardData cudnnConvolutionBackwardFilter @@ -175,6 +177,7 @@ cudnnGetTensorNdDescriptor cudnnGetTensorSizeInBytes cudnnGetTensorTransformDescriptor cudnnGetVersion +cudnnGraphVersionCheck cudnnIm2Col cudnnInitTransformDest cudnnLRNCrossChannelBackward @@ -189,6 +192,7 @@ cudnnNormalizationForwardTraining cudnnOpTensor cudnnOpsInferVersionCheck cudnnOpsTrainVersionCheck +cudnnOpsVersionCheck cudnnPoolingBackward cudnnPoolingForward cudnnQueryRuntimeError diff --git a/third_party/tsl/tsl/cuda/cudnn_stub.cc b/xla/tsl/cuda/cudnn_stub.cc similarity index 93% rename from third_party/tsl/tsl/cuda/cudnn_stub.cc rename to xla/tsl/cuda/cudnn_stub.cc index f3cab179eb0b7..192009c9e8728 100644 --- a/third_party/tsl/tsl/cuda/cudnn_stub.cc +++ b/xla/tsl/cuda/cudnn_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "third_party/gpus/cudnn/cudnn.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuDNN API by forwarding to cuDNN loaded from the DSO. @@ -38,15 +39,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cudnn.inc" +#include "xla/tsl/cuda/cudnn.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cufft.symbols b/xla/tsl/cuda/cufft.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cufft.symbols rename to xla/tsl/cuda/cufft.symbols diff --git a/third_party/tsl/tsl/cuda/cufft_stub.cc b/xla/tsl/cuda/cufft_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cufft_stub.cc rename to xla/tsl/cuda/cufft_stub.cc index 8f5c1b0d68733..ea7b08f882189 100644 --- a/third_party/tsl/tsl/cuda/cufft_stub.cc +++ b/xla/tsl/cuda/cufft_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cufftXt.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuFFT API by forwarding to cuFFT loaded from the DSO. @@ -37,15 +38,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cufft.inc" +#include "xla/tsl/cuda/cufft.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cupti.symbols b/xla/tsl/cuda/cupti.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cupti.symbols rename to xla/tsl/cuda/cupti.symbols diff --git a/third_party/tsl/tsl/cuda/cupti_stub.cc b/xla/tsl/cuda/cupti_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cupti_stub.cc rename to xla/tsl/cuda/cupti_stub.cc index 9e632010d83a7..01d13a8ea7d4f 100644 --- a/third_party/tsl/tsl/cuda/cupti_stub.cc +++ b/xla/tsl/cuda/cupti_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUPTI API by forwarding to CUPTI loaded from the DSO. @@ -38,15 +39,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cupti.inc" +#include "xla/tsl/cuda/cupti.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cusolver.symbols b/xla/tsl/cuda/cusolver.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cusolver.symbols rename to xla/tsl/cuda/cusolver.symbols diff --git a/third_party/tsl/tsl/cuda/cusolver_stub.cc b/xla/tsl/cuda/cusolver_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cusolver_stub.cc rename to xla/tsl/cuda/cusolver_stub.cc index d11601b3bd421..d76526042582e 100644 --- a/third_party/tsl/tsl/cuda/cusolver_stub.cc +++ b/xla/tsl/cuda/cusolver_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolverSp.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusolver API by forwarding to cusolver loaded from the DSO. @@ -38,15 +39,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cusolver.inc" +#include "xla/tsl/cuda/cusolver.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/cusparse.symbols b/xla/tsl/cuda/cusparse.symbols similarity index 100% rename from third_party/tsl/tsl/cuda/cusparse.symbols rename to xla/tsl/cuda/cusparse.symbols diff --git a/third_party/tsl/tsl/cuda/cusparse_stub.cc b/xla/tsl/cuda/cusparse_stub.cc similarity index 91% rename from third_party/tsl/tsl/cuda/cusparse_stub.cc rename to xla/tsl/cuda/cusparse_stub.cc index 16141e51e2613..b8327024e6720 100644 --- a/third_party/tsl/tsl/cuda/cusparse_stub.cc +++ b/xla/tsl/cuda/cusparse_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusparse API by forwarding to cusparse loaded from the DSO. @@ -37,15 +38,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/cusparse.inc" +#include "xla/tsl/cuda/cusparse.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/third_party/tsl/tsl/cuda/nccl.symbols b/xla/tsl/cuda/nccl.symbols similarity index 87% rename from third_party/tsl/tsl/cuda/nccl.symbols rename to xla/tsl/cuda/nccl.symbols index 0d6552dafe023..e5164825373af 100644 --- a/third_party/tsl/tsl/cuda/nccl.symbols +++ b/xla/tsl/cuda/nccl.symbols @@ -11,6 +11,8 @@ ncclCommGetAsyncError ncclCommInitAll ncclCommInitRank ncclCommInitRankConfig +ncclCommDeregister +ncclCommRegister ncclCommSplit ncclCommUserRank ncclGetErrorString @@ -19,6 +21,8 @@ ncclGetUniqueId ncclGetVersion ncclGroupEnd ncclGroupStart +ncclMemAlloc +ncclMemFree ncclRecv ncclRedOpCreatePreMulSum ncclRedOpDestroy @@ -38,6 +42,8 @@ pncclCommGetAsyncError pncclCommInitAll pncclCommInitRank pncclCommInitRankConfig +pncclCommDeregister +pncclCommRegister pncclCommSplit pncclCommUserRank pncclGetErrorString @@ -46,6 +52,8 @@ pncclGetUniqueId pncclGetVersion pncclGroupEnd pncclGroupStart +pncclMemAlloc +pncclMemFree pncclRecv pncclRedOpCreatePreMulSum pncclRedOpDestroy diff --git a/third_party/tsl/tsl/cuda/nccl_stub.cc b/xla/tsl/cuda/nccl_stub.cc similarity index 93% rename from third_party/tsl/tsl/cuda/nccl_stub.cc rename to xla/tsl/cuda/nccl_stub.cc index 0ebae2f3c2b2e..f3895da245176 100644 --- a/third_party/tsl/tsl/cuda/nccl_stub.cc +++ b/xla/tsl/cuda/nccl_stub.cc @@ -18,7 +18,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/nccl/nccl.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the nccl API by forwarding to nccl loaded from a DSO. @@ -40,15 +41,14 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; } const char* kSymbols[] = { -#include "tsl/cuda/nccl.inc" +#include "xla/tsl/cuda/nccl.inc" }; constexpr size_t kNumSymbols = sizeof(kSymbols) / sizeof(const char*); diff --git a/xla/tsl/cuda/stub.bzl b/xla/tsl/cuda/stub.bzl new file mode 100644 index 0000000000000..23ec8a86e5063 --- /dev/null +++ b/xla/tsl/cuda/stub.bzl @@ -0,0 +1,27 @@ +"""Macros to generate CUDA library stubs from a list of symbols.""" + +def cuda_stub(name, srcs): + """Generates a CUDA stub from a list of symbols. + + Generates two files: + * library.inc, which contains a list of symbols suitable for inclusion by + C++, and + * library.tramp.S, which contains assembly-language trampolines for each + symbol. + """ + native.genrule( + name = "{}_stub_gen".format(name), + srcs = srcs, + tools = ["//third_party/implib_so:make_stub"], + outs = [ + "{}.inc".format(name), + "{}.tramp.S".format(name), + ], + tags = ["gpu"], + cmd = select({ + "@tsl//tsl:linux_aarch64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target aarch64", + "@tsl//tsl:linux_x86_64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target x86_64", + "@tsl//tsl:linux_ppc64le": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target powerpc64le", + "//conditions:default": "NOT_IMPLEMENTED_FOR_THIS_PLATFORM_OR_ARCHITECTURE", + }), + ) diff --git a/xla/tsl/mkl/BUILD b/xla/tsl/mkl/BUILD new file mode 100644 index 0000000000000..9e42f75a1ffdc --- /dev/null +++ b/xla/tsl/mkl/BUILD @@ -0,0 +1,146 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load( + "@tsl//tsl:tsl.bzl", + "clean_dep", +) + +licenses(["notice"]) # 3-Clause BSD + +config_setting( + name = "build_with_mkl", + define_values = { + "build_with_mkl": "true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "build_with_mkl_lnx_x64", + define_values = { + "build_with_mkl": "true", + }, + values = { + "cpu": "k8", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "build_with_mkl_lnx_openmp", + constraint_values = [ + "@platforms//os:linux", + ], + define_values = { + "build_with_mkl": "true", + "build_with_openmp": "true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "build_with_mkl_windows_openmp", + constraint_values = [ + "@platforms//os:windows", + ], + define_values = { + "build_with_mkl": "true", + "build_with_openmp": "true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "build_with_mkl_aarch64", + define_values = { + "build_with_mkl_aarch64": "true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "enable_mkl", + define_values = { + "enable_mkl": "true", + "build_with_mkl": "true", + }, + visibility = ["//visibility:public"], +) + +filegroup( + name = "LICENSE", + srcs = [ + "MKL_LICENSE", + "@llvm_openmp//:LICENSE.txt", + ], + visibility = ["//visibility:public"], +) + +# TODO(Intel-tf) Remove the following 3 calls to cc_library and replace all uses +# of mkl_libs_* with @llvm_openmp//:libiomp5.* directly. + +cc_library( + name = "mkl_libs_linux", + srcs = [ + "@llvm_openmp//:libiomp5.so", + ], + hdrs = ["@llvm_openmp//:config_omp"], + target_compatible_with = select({ + "@xla//xla/tsl/mkl:build_with_mkl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], +) + +# MacOS build configuration is provided for completness, it has not been tested +cc_library( + name = "mkl_libs_darwin", + srcs = [ + "@llvm_openmp//:libiomp5.dylib", + ], + hdrs = ["@llvm_openmp//:config_omp"], + target_compatible_with = select({ + "@xla//xla/tsl/mkl:build_with_mkl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "mkl_libs_windows", + srcs = [ + "@llvm_openmp//:libiomp5md.dll", + ], + hdrs = ["@llvm_openmp//:config_omp"], + target_compatible_with = select({ + "@xla//xla/tsl/mkl:build_with_mkl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "intel_binary_blob", + target_compatible_with = select({ + "@xla//xla/tsl/mkl:build_with_mkl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], + deps = select({ + clean_dep("@tsl//tsl:linux_x86_64"): [ + ":mkl_libs_linux", + ], + clean_dep("@tsl//tsl:macos"): [ + ":mkl_libs_darwin", + ], + clean_dep("@tsl//tsl:windows"): [ + ":mkl_libs_windows", + ], + "//conditions:default": [], + }), +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + visibility = ["//visibility:public"], +) diff --git a/third_party/tsl/tsl/mkl/LICENSE b/xla/tsl/mkl/LICENSE similarity index 100% rename from third_party/tsl/tsl/mkl/LICENSE rename to xla/tsl/mkl/LICENSE diff --git a/third_party/tsl/tsl/mkl/MKL_LICENSE b/xla/tsl/mkl/MKL_LICENSE similarity index 100% rename from third_party/tsl/tsl/mkl/MKL_LICENSE rename to xla/tsl/mkl/MKL_LICENSE diff --git a/xla/tsl/mkl/build_defs.bzl b/xla/tsl/mkl/build_defs.bzl new file mode 100644 index 0000000000000..9ebeeb14ff047 --- /dev/null +++ b/xla/tsl/mkl/build_defs.bzl @@ -0,0 +1,161 @@ +"""Starlark macros for MKL. + +if_mkl is a conditional to check if we are building with MKL. +if_mkl_ml is a conditional to check if we are building with MKL-ML. +if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode. +if_mkl_lnx_x64 is a conditional to check for MKL +if_enable_mkl is a conditional to check if building with MKL and MKL is enabled. + +mkl_repository is a repository rule for creating MKL repository rule that can +be pointed to either a local folder, or download it from the internet. +mkl_repository depends on the following environment variables: + * `TF_MKL_ROOT`: The root folder where a copy of libmkl is located. +""" + +_TF_MKL_ROOT = "TF_MKL_ROOT" + +def if_mkl(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with oneDNN. + + OneDNN gets built if we are building on platforms that support oneDNN + (x86 linux/windows) or if specifcially configured to use oneDNN. + + Args: + if_true: expression to evaluate if building with oneDNN. + if_false: expression to evaluate if building without oneDNN. + + Returns: + a select evaluating to either if_true or if_false as appropriate. + + TODO(intel-tf): + the first "if_true" line is kept because non-x86 platforms (e.g., ARM) + may need it. It may be deleted in future with refactoring. + """ + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_aarch64": if_true, + "@tsl//tsl:linux_x86_64": if_true, + "@tsl//tsl:windows": if_true, + "//conditions:default": if_false, + }) + +def if_mkl_ml(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with MKL-ML. + + Args: + if_true: expression to evaluate if building with MKL-ML. + if_false: expression to evaluate if building without MKL-ML + (i.e. without MKL at all, or with MKL-DNN only). + + Returns: + a select evaluating to either if_true or if_false as appropriate. + """ + return select({ + "@tsl//third_party/mkl_dnn:build_with_mkl_opensource": if_false, + "@xla//xla/tsl/mkl:build_with_mkl": if_true, + "//conditions:default": if_false, + }) + +def if_mkl_lnx_x64(if_true, if_false = []): + """Shorthand to select() if building with MKL and the target is Linux x86-64. + + Args: + if_true: expression to evaluate if building with MKL is enabled and the + target platform is Linux x86-64. + if_false: expression to evaluate if building without MKL or for a + different platform. + + Returns: + a select evaluating to either if_true or if_false as appropriate. + """ + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_lnx_x64": if_true, + "//conditions:default": if_false, + }) + +def if_enable_mkl(if_true, if_false = []): + """Shorthand to select() if we are building with MKL and MKL is enabled. + + This is only effective when built with MKL. + + Args: + if_true: expression to evaluate if building with MKL and MKL is enabled + if_false: expression to evaluate if building without MKL or MKL is not enabled. + + Returns: + A select evaluating to either if_true or if_false as appropriate. + """ + return select({ + "@xla//xla/tsl/mkl:enable_mkl": if_true, + "//conditions:default": if_false, + }) + +def mkl_deps(): + """Returns the correct set of oneDNN library dependencies. + + Shorthand for select() to pull in the correct set of oneDNN library deps + depending on the platform. x86 Linux/Windows with or without --config=mkl + will always build with oneDNN library. + + Returns: + a select evaluating to a list of library dependencies, suitable for + inclusion in the deps attribute of rules. + """ + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_aarch64": ["@mkl_dnn_acl_compatible//:mkl_dnn_acl"], + "@tsl//tsl:linux_x86_64": ["@onednn//:mkl_dnn"], + "@tsl//tsl:windows": ["@onednn//:mkl_dnn"], + "//conditions:default": [], + }) + +def onednn_v3_define(): + """Returns a define to build with oneDNN v3.x if it is enabled. + + Returns: + A define to build with oneDNN v3.x for Linux x86 and Windows x86 builds. + An empty list of all other cases (include ARM builds). + """ + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_aarch64": ["-DENABLE_ONEDNN_V3"], + "@tsl//tsl:linux_x86_64": ["-DENABLE_ONEDNN_V3"], + "@tsl//tsl:windows": ["-DENABLE_ONEDNN_V3"], + "//conditions:default": [], + }) + +def _enable_local_mkl(repository_ctx): + return _TF_MKL_ROOT in repository_ctx.os.environ + +def _mkl_autoconf_impl(repository_ctx): + """Implementation of the local_mkl_autoconf repository rule.""" + + if _enable_local_mkl(repository_ctx): + # Symlink lib and include local folders. + mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT] + mkl_lib_path = "%s/lib" % mkl_root + repository_ctx.symlink(mkl_lib_path, "lib") + mkl_include_path = "%s/include" % mkl_root + repository_ctx.symlink(mkl_include_path, "include") + mkl_license_path = "%s/license.txt" % mkl_root + repository_ctx.symlink(mkl_license_path, "license.txt") + else: + # setup remote mkl repository. + repository_ctx.download_and_extract( + repository_ctx.attr.urls, + sha256 = repository_ctx.attr.sha256, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + + # Also setup BUILD file. + repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD") + +mkl_repository = repository_rule( + implementation = _mkl_autoconf_impl, + environ = [ + _TF_MKL_ROOT, + ], + attrs = { + "build_file": attr.label(), + "urls": attr.string_list(default = []), + "sha256": attr.string(default = ""), + "strip_prefix": attr.string(default = ""), + }, +) diff --git a/third_party/tsl/tsl/python/lib/core/BUILD b/xla/tsl/python/lib/core/BUILD similarity index 100% rename from third_party/tsl/tsl/python/lib/core/BUILD rename to xla/tsl/python/lib/core/BUILD diff --git a/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc similarity index 93% rename from third_party/tsl/tsl/python/lib/core/ml_dtypes.cc rename to xla/tsl/python/lib/core/ml_dtypes.cc index 795c772ca39ba..d138c00bf9e6d 100644 --- a/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/python/lib/core/ml_dtypes.h" +#include "xla/tsl/python/lib/core/ml_dtypes.h" #include #include +// Must be included first to ensure `NPY_NO_DEPRECATED_API` is defined. +// clang-format off +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +// clang-format on #include "numpy/ndarraytypes.h" #include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "pybind11/gil.h" // from @pybind11 #include "pybind11/numpy.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 -#include "tsl/python/lib/core/numpy.h" // IWYU pragma: keep namespace tsl { namespace ml_dtypes { @@ -71,6 +74,7 @@ struct MlDtypesInitInfo { numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { + py::gil_scoped_acquire acquire; py::print(e.what()); init_valid = false; } diff --git a/third_party/tsl/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h similarity index 90% rename from third_party/tsl/tsl/python/lib/core/ml_dtypes.h rename to xla/tsl/python/lib/core/ml_dtypes.h index f2b93ebee41a2..bf9eab2200a76 100644 --- a/third_party/tsl/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ -#define TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#ifndef XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#define XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ // Registers all custom types from the python ml_dtypes package. // https://github.com/jax-ml/ml_dtypes @@ -47,4 +47,4 @@ inline int GetBfloat16TypeNum() { return GetNumpyDtypes().bfloat16; } } // namespace ml_dtypes } // namespace tsl -#endif // TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#endif // XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ diff --git a/third_party/tsl/tsl/python/lib/core/numpy.cc b/xla/tsl/python/lib/core/numpy.cc similarity index 95% rename from third_party/tsl/tsl/python/lib/core/numpy.cc rename to xla/tsl/python/lib/core/numpy.cc index 3013a1a7c68d4..3f54df1281c2d 100644 --- a/third_party/tsl/tsl/python/lib/core/numpy.cc +++ b/xla/tsl/python/lib/core/numpy.cc @@ -17,7 +17,7 @@ limitations under the License. // ImportNumpy function to populate it. #define XLA_IMPORT_NUMPY -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace tsl { diff --git a/third_party/tsl/tsl/python/lib/core/numpy.h b/xla/tsl/python/lib/core/numpy.h similarity index 80% rename from third_party/tsl/tsl/python/lib/core/numpy.h rename to xla/tsl/python/lib/core/numpy.h index 8cbe3ab74e20d..6a5a6a6486ccf 100644 --- a/third_party/tsl/tsl/python/lib/core/numpy.h +++ b/xla/tsl/python/lib/core/numpy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ -#define TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#ifndef XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#define XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ #ifdef PyArray_Type #error "Numpy cannot be included before numpy.h." @@ -29,13 +29,16 @@ limitations under the License. #define NO_IMPORT_ARRAY #endif +// clang-format off // Place `` before to avoid build failure in macOS. #include #include +// clang-format on -#include "numpy/arrayobject.h" -#include "numpy/ufuncobject.h" +#include "numpy/arrayobject.h" // IWYU pragma: export +#include "numpy/npy_common.h" // IWYU pragma: export +#include "numpy/ufuncobject.h" // IWYU pragma: export namespace tsl { @@ -47,4 +50,4 @@ void ImportNumpy(); } // namespace tsl -#endif // TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#endif // XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ diff --git a/xla/tsl/util/BUILD b/xla/tsl/util/BUILD new file mode 100644 index 0000000000000..0e17cdca1c7d1 --- /dev/null +++ b/xla/tsl/util/BUILD @@ -0,0 +1,344 @@ +# Description: +# Tensor Standard Libraries. +# +# The libraries in this package are not allowed to have ANY dependencies +# to other TF components outside of TSL. + +load( + "@tsl//tsl:tsl.bzl", + "check_deps", + "internal_visibility", + "tsl_copts", +) +load("@tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load( + "@tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", +) +load( + "@tsl//tsl/platform:build_config_root.bzl", + "if_static", +) +load( + "@tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +filegroup( + name = "mobile_srcs_no_runtime", + srcs = [ + "byte_swap_array.cc", + "byte_swap_array.h", + ], +) + +filegroup( + name = "mobile_srcs_only_runtime", + srcs = [ + "command_line_flags.cc", + "command_line_flags.h", + "determinism.cc", + "determinism.h", + "device_name_utils.cc", + "device_name_utils.h", + "env_var.cc", + "env_var.h", + "use_cudnn.cc", + "use_cudnn.h", + ], +) + +filegroup( + name = "determnism_hdr", + srcs = [ + "determinism.h", + ], + compatible_with = get_compatible_with_portable(), + visibility = internal_visibility([ + "//tensorflow:__subpackages__", + "//tensorflow/core/util:__pkg__", + ]), +) + +filegroup( + name = "framework_internal_private_hdrs", + srcs = [ + "byte_swap_array.h", + "command_line_flags.h", + "device_name_utils.h", + "env_var.h", + "stat_summarizer_options.h", + "stats_calculator.h", + "use_cudnn.h", + ], +) + +filegroup( + name = "framework_internal_impl_srcs", + srcs = [ + "use_cudnn.cc", + ], +) + +filegroup( + name = "lib_internal_public_hdrs", + srcs = [ + "command_line_flags.h", + "env_var.h", + "use_cudnn.h", + ], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), +) + +filegroup( + name = "determinism_hdr", + srcs = [ + "determinism.h", + ], + compatible_with = get_compatible_with_portable(), + visibility = internal_visibility([ + "//tensorflow:__subpackages__", + "//tensorflow/core/util:__pkg__", + ]), +) + +filegroup( + name = "framework_srcs", + srcs = [ + "device_name_utils.h", + "stat_summarizer_options.h", + "use_cudnn.h", + ], +) + +cc_library( + name = "byte_swap_array", + srcs = ["byte_swap_array.cc"], + hdrs = ["byte_swap_array.h"], + deps = [ + "@tsl//tsl/platform:byte_order", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + ], +) + +cc_library( + name = "determinism_hdr_lib", + hdrs = [":determinism_hdr"], + compatible_with = get_compatible_with_portable(), + # TODO(b/298501506): narrow this in a way that won't break TAP + visibility = ["//visibility:public"], +) + +# Note: This rule should not be used as a dependency for kernels. Use the +# "determinism_for_kernels" rule below instead. +cc_library( + name = "determinism", + srcs = ["determinism.cc"], + hdrs = ["determinism.h"], + copts = tsl_copts(), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), + deps = [ + ":env_var", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:mutex", + ], + alwayslink = 1, +) + +# This alias should be used as a dependency for kernels which use determinism, +# as well any other rules which are in the same shared library as the kernels. +# This rule does not include the determinism.cc file for nonstatic builds. The +# reason is that for nonstatic builds, the shared object which contains the +# kernels (e.g. _pywrap_tensorflow_internal.so) must not contain the global +# variable in determinism.cc, since the global variable is already in +# libtensorflow_framework.so. +# +# To test that determinism.cc is not improperly included in the shared object +# which contains the kernels, you can run the "determinism_check_deps" rule +# below. +alias( + name = "determinism_for_kernels", + actual = if_static(":determinism", ":determinism_hdr_lib"), + visibility = internal_visibility(["//tensorflow:__subpackages__"]), +) + +check_deps( + name = "determinism_check_deps", + disallowed_deps = if_static( + [], + otherwise = [":determinism"], + ), + deps = [ + ], +) + +cc_library( + name = "determinism_test_util", + hdrs = [":determinism_test_util.h"], + data = [ + # Adding this data dependency ensures determinism_check_deps is run + # whenever determinism tests are run. + ":determinism_check_deps", + ], + deps = [":determinism"], +) + +cc_library( + name = "env_var", + srcs = ["env_var.cc"], + hdrs = ["env_var.h"], + deps = [ + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:str_util", + "@tsl//tsl/platform:strcat", + "@tsl//tsl/platform:stringpiece", + "@tsl//tsl/platform:types", + ], +) + +cc_library( + name = "reporter", + srcs = ["reporter.cc"], + hdrs = ["reporter.h"], + visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "@tsl//tsl:__subpackages__", + ]), + deps = [ + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:env_impl", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:macros", + "@tsl//tsl/platform:mutex", + "@tsl//tsl/platform:str_util", + "@tsl//tsl/platform:types", + "@tsl//tsl/protobuf:test_log_proto_cc", + ], +) + +cc_library( + name = "stats_calculator_portable", + srcs = [ + "stats_calculator.cc", + ], + hdrs = [ + "stat_summarizer_options.h", + "stats_calculator.h", + ], + copts = tsl_copts(), + visibility = internal_visibility([ + "@tsl//tsl:internal", + ]), +) + +tsl_cc_test( + name = "stats_calculator_test", + srcs = ["stats_calculator_test.cc"], + deps = [ + ":stats_calculator_portable", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "device_name_utils", + srcs = ["device_name_utils.cc"], + hdrs = ["device_name_utils.h"], + deps = [ + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:stringpiece", + ], +) + +tsl_cc_test( + name = "device_name_utils_test", + size = "small", + srcs = ["device_name_utils_test.cc"], + deps = [ + ":device_name_utils", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:strcat", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + deps = [ + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:str_util", + "@tsl//tsl/platform:stringpiece", + "@tsl//tsl/platform:stringprintf", + "@tsl//tsl/platform:types", + ], +) + +filegroup( + name = "test_hdrs", + testonly = 1, + srcs = [ + "reporter.h", + ], + visibility = internal_visibility(["//tensorflow/core/util:__pkg__"]), +) + +filegroup( + name = "onednn_util_hdrs", + srcs = [ + "onednn_threadpool.h", + ], + visibility = internal_visibility([ + "//xla:__subpackages__", + "//tensorflow/core:__pkg__", + "//tensorflow/core/framework:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), +) + +filegroup( + name = "android_test_hdrs", + testonly = 1, + srcs = [ + "reporter.h", + ], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), +) + +filegroup( + name = "android_test_srcs", + testonly = 1, + srcs = [ + "reporter.cc", + ":android_test_hdrs", + ], + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "//tensorflow/core/util:__pkg__", + ]), +) diff --git a/third_party/tsl/tsl/util/byte_swap_array.cc b/xla/tsl/util/byte_swap_array.cc similarity index 86% rename from third_party/tsl/tsl/util/byte_swap_array.cc rename to xla/tsl/util/byte_swap_array.cc index e77e4bab8defc..2c80e8cb928d0 100644 --- a/third_party/tsl/tsl/util/byte_swap_array.cc +++ b/xla/tsl/util/byte_swap_array.cc @@ -13,34 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/byte_swap_array.h" +#include "xla/tsl/util/byte_swap_array.h" #include "tsl/platform/errors.h" namespace tsl { -Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) { +absl::Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) { if (bytes_per_elem == 1) { // No-op - return OkStatus(); + return absl::OkStatus(); } else if (bytes_per_elem == 2) { auto array_16 = reinterpret_cast(array); for (int i = 0; i < array_len; i++) { array_16[i] = BYTE_SWAP_16(array_16[i]); } - return OkStatus(); + return absl::OkStatus(); } else if (bytes_per_elem == 4) { auto array_32 = reinterpret_cast(array); for (int i = 0; i < array_len; i++) { array_32[i] = BYTE_SWAP_32(array_32[i]); } - return OkStatus(); + return absl::OkStatus(); } else if (bytes_per_elem == 8) { auto array_64 = reinterpret_cast(array); for (int i = 0; i < array_len; i++) { array_64[i] = BYTE_SWAP_64(array_64[i]); } - return OkStatus(); + return absl::OkStatus(); } else { return errors::Unimplemented("Byte-swapping of ", bytes_per_elem, "-byte values not supported."); diff --git a/third_party/tsl/tsl/util/byte_swap_array.h b/xla/tsl/util/byte_swap_array.h similarity index 94% rename from third_party/tsl/tsl/util/byte_swap_array.h rename to xla/tsl/util/byte_swap_array.h index ad7e34efcd51f..43ee9f77a38ba 100644 --- a/third_party/tsl/tsl/util/byte_swap_array.h +++ b/xla/tsl/util/byte_swap_array.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ -#define TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#ifndef XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#define XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ #include "tsl/platform/byte_order.h" #include "tsl/platform/errors.h" @@ -97,8 +97,8 @@ namespace tsl { // // Returns: OkStatus() on success, -1 otherwise // -Status ByteSwapArray(char *array, size_t bytes_per_elem, int array_len); +absl::Status ByteSwapArray(char *array, size_t bytes_per_elem, int array_len); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#endif // XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ diff --git a/third_party/tsl/tsl/util/command_line_flags.cc b/xla/tsl/util/command_line_flags.cc similarity index 99% rename from third_party/tsl/tsl/util/command_line_flags.cc rename to xla/tsl/util/command_line_flags.cc index 5e316e9ae9fc6..f5a97a50eb198 100644 --- a/third_party/tsl/tsl/util/command_line_flags.cc +++ b/xla/tsl/util/command_line_flags.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/command_line_flags.h" +#include "xla/tsl/util/command_line_flags.h" #include #include diff --git a/third_party/tsl/tsl/util/command_line_flags.h b/xla/tsl/util/command_line_flags.h similarity index 97% rename from third_party/tsl/tsl/util/command_line_flags.h rename to xla/tsl/util/command_line_flags.h index 2710de5753cd0..d4b3efd662a94 100644 --- a/third_party/tsl/tsl/util/command_line_flags.h +++ b/xla/tsl/util/command_line_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ -#define TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#ifndef XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#define XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ #include #include @@ -145,4 +145,4 @@ class Flags { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#endif // XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ diff --git a/third_party/tsl/tsl/util/determinism.cc b/xla/tsl/util/determinism.cc similarity index 96% rename from third_party/tsl/tsl/util/determinism.cc rename to xla/tsl/util/determinism.cc index b9a5abd9af40d..6089cc96458dc 100644 --- a/third_party/tsl/tsl/util/determinism.cc +++ b/xla/tsl/util/determinism.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/determinism.h" +#include "xla/tsl/util/determinism.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/mutex.h" -#include "tsl/util/env_var.h" namespace tsl { diff --git a/third_party/tsl/tsl/util/determinism.h b/xla/tsl/util/determinism.h similarity index 86% rename from third_party/tsl/tsl/util/determinism.h rename to xla/tsl/util/determinism.h index fff5b195845a3..2f1861ed60a23 100644 --- a/third_party/tsl/tsl/util/determinism.h +++ b/xla/tsl/util/determinism.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DETERMINISM_H_ -#define TENSORFLOW_TSL_UTIL_DETERMINISM_H_ +#ifndef XLA_TSL_UTIL_DETERMINISM_H_ +#define XLA_TSL_UTIL_DETERMINISM_H_ namespace tsl { @@ -24,4 +24,4 @@ void EnableOpDeterminism(bool enabled); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DETERMINISM_H_ +#endif // XLA_TSL_UTIL_DETERMINISM_H_ diff --git a/third_party/tsl/tsl/util/determinism_test_util.h b/xla/tsl/util/determinism_test_util.h similarity index 84% rename from third_party/tsl/tsl/util/determinism_test_util.h rename to xla/tsl/util/determinism_test_util.h index e458dc9cdacc5..34b4552bb62d6 100644 --- a/third_party/tsl/tsl/util/determinism_test_util.h +++ b/xla/tsl/util/determinism_test_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ -#define TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#ifndef XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#define XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ -#include "tsl/util/determinism.h" +#include "xla/tsl/util/determinism.h" namespace tsl { namespace test { @@ -35,4 +35,4 @@ class DeterministicOpsScope { } // namespace test } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#endif // XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ diff --git a/third_party/tsl/tsl/util/device_name_utils.cc b/xla/tsl/util/device_name_utils.cc similarity index 94% rename from third_party/tsl/tsl/util/device_name_utils.cc rename to xla/tsl/util/device_name_utils.cc index 0920532c62edd..6812071505ae0 100644 --- a/third_party/tsl/tsl/util/device_name_utils.cc +++ b/xla/tsl/util/device_name_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/device_name_utils.h" +#include "xla/tsl/util/device_name_utils.h" #include @@ -205,9 +205,9 @@ void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename, } // namespace /* static */ -Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname, - StringPiece basename, - string* canonical_name) { +absl::Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname, + StringPiece basename, + string* canonical_name) { *canonical_name = ""; ParsedName parsed_basename; if (!ParseFullName(basename, &parsed_basename)) { @@ -225,12 +225,12 @@ Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname, if (ParseLocalName(fullname, &parsed_name)) { CompleteName(parsed_basename, &parsed_name); *canonical_name = ParsedNameToString(parsed_name); - return OkStatus(); + return absl::OkStatus(); } if (ParseFullName(fullname, &parsed_name)) { CompleteName(parsed_basename, &parsed_name); *canonical_name = ParsedNameToString(parsed_name); - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Could not parse ", fullname, " into a device " @@ -341,9 +341,10 @@ bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern, } namespace { -Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, - const DeviceNameUtils::ParsedName& other, - bool allow_soft_placement, bool override_conflicts) { +absl::Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, + const DeviceNameUtils::ParsedName& other, + bool allow_soft_placement, + bool override_conflicts) { const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString; if (other.has_job) { if (target->has_job && target->job != other.job) { @@ -393,7 +394,7 @@ Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, } else { target->has_id = false; target->has_type = false; - return OkStatus(); + return absl::OkStatus(); } } else { target->has_type = other.has_type; @@ -412,7 +413,7 @@ Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, target->id = other.id; } else { target->has_id = false; - return OkStatus(); + return absl::OkStatus(); } } else { target->has_id = other.has_id; @@ -420,22 +421,22 @@ Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace /* static */ -Status DeviceNameUtils::MergeDevNames(ParsedName* target, - const ParsedName& other, - bool allow_soft_placement) { +absl::Status DeviceNameUtils::MergeDevNames(ParsedName* target, + const ParsedName& other, + bool allow_soft_placement) { return MergeDevNamesImpl(target, other, allow_soft_placement, /*override_conflicts=*/false); } /* static */ -Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target, - const ParsedName& other) { +absl::Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target, + const ParsedName& other) { return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true, /*override_conflicts=*/true); } @@ -604,7 +605,7 @@ std::vector DeviceNameUtils::GetLocalNamesForDeviceMappings( } } -/*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName( +/*static*/ absl::Status DeviceNameUtils::DeviceNameToCpuDeviceName( const string& device_name, string* host_device_name) { DeviceNameUtils::ParsedName device; if (!DeviceNameUtils::ParseFullName(device_name, &device)) { @@ -615,7 +616,7 @@ std::vector DeviceNameUtils::GetLocalNamesForDeviceMappings( device.id = 0; device.has_id = true; *host_device_name = DeviceNameUtils::ParsedNameToString(device); - return OkStatus(); + return absl::OkStatus(); } std::ostream& operator<<(std::ostream& os, diff --git a/third_party/tsl/tsl/util/device_name_utils.h b/xla/tsl/util/device_name_utils.h similarity index 90% rename from third_party/tsl/tsl/util/device_name_utils.h rename to xla/tsl/util/device_name_utils.h index 51dcca37ac054..90152831444b7 100644 --- a/third_party/tsl/tsl/util/device_name_utils.h +++ b/xla/tsl/util/device_name_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ -#define TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#ifndef XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#define XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ #include @@ -131,9 +131,9 @@ class DeviceNameUtils { template friend H AbslHashValue(H h, const ParsedName& n) { - return H::combine(std::move(h), n.has_job, n.job, n.has_replica, - n.replica, n.has_task, n.task, n.has_type, n.type, - n.has_id, n.id); + return H::combine(std::move(h), n.has_job ? n.job : "", + n.has_replica ? n.replica : 0, n.has_task ? n.task : 0, + n.has_type ? n.type : "", n.has_id ? n.id : 0); } bool has_job = false; @@ -168,9 +168,9 @@ class DeviceNameUtils { // and local versions of the device spec. Returns the newer version of the // device spec. If we were unable to interpret / parse "fullname" returns // an error and *canonical_name is set to "". - static Status CanonicalizeDeviceName(StringPiece fullname, - StringPiece basename, - std::string* canonical_name); + static absl::Status CanonicalizeDeviceName(StringPiece fullname, + StringPiece basename, + std::string* canonical_name); // Returns true if "name" specifies any non-trivial constraint on the device. static bool HasSomeDetails(const ParsedName& name) { @@ -202,15 +202,16 @@ class DeviceNameUtils { // Merges the device specifications in "*target" and "other", and // stores the result in "*target". Returns OK if "*target" and // "other" are compatible, otherwise returns an error. - static Status MergeDevNames(ParsedName* target, const ParsedName& other) { + static absl::Status MergeDevNames(ParsedName* target, + const ParsedName& other) { return MergeDevNames(target, other, false); } - static Status MergeDevNames(ParsedName* target, const ParsedName& other, - bool allow_soft_placement); + static absl::Status MergeDevNames(ParsedName* target, const ParsedName& other, + bool allow_soft_placement); // Same as MergeDevNames with allow_soft_placement=true, but instead of // clearing conflicting fields, overrides them with `other`'s values. - static Status MergeOverrideDevNames(ParsedName* target, - const ParsedName& other); + static absl::Status MergeOverrideDevNames(ParsedName* target, + const ParsedName& other); // Merges the device specifications in "*target" and "other", and // stores the result in "*target" by setting all unset values in target with @@ -271,8 +272,8 @@ class DeviceNameUtils { // Returns name of the CPU:0 device on the same host as the device // `device_name`. - static Status DeviceNameToCpuDeviceName(const std::string& device_name, - std::string* host_device_name); + static absl::Status DeviceNameToCpuDeviceName(const std::string& device_name, + std::string* host_device_name); static bool CompareFullNames(const StringPiece& a, const StringPiece& b) { ParsedName parsed_a; @@ -291,4 +292,4 @@ std::ostream& operator<<(std::ostream& os, } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#endif // XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/third_party/tsl/tsl/util/device_name_utils_test.cc b/xla/tsl/util/device_name_utils_test.cc similarity index 99% rename from third_party/tsl/tsl/util/device_name_utils_test.cc rename to xla/tsl/util/device_name_utils_test.cc index dce1fc5807604..1f5f5114550d4 100644 --- a/third_party/tsl/tsl/util/device_name_utils_test.cc +++ b/xla/tsl/util/device_name_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/device_name_utils.h" +#include "xla/tsl/util/device_name_utils.h" #include @@ -426,7 +426,7 @@ static void MergeDevNamesHelperAllowSoftPlacement( static void MergeDevNamesError(const string& name_a, const string& name_b, const string& expected_error_substr) { DeviceNameUtils::ParsedName target_a = Name(name_a); - Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b)); + absl::Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b)); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), expected_error_substr)) << s; } @@ -607,7 +607,7 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) { TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName("CPU:0", basename, &canonical_name)); EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:0", canonical_name); - Status s = DeviceNameUtils::CanonicalizeDeviceName( + absl::Status s = DeviceNameUtils::CanonicalizeDeviceName( "/job:foo/task:0/replica/cpu:1", basename, &canonical_name); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_EQ("", canonical_name); @@ -617,7 +617,7 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) { // Try out malformed basenames. string fullname = "/device:CPU:0"; - Status s = DeviceNameUtils::CanonicalizeDeviceName( + absl::Status s = DeviceNameUtils::CanonicalizeDeviceName( fullname, "/device:CPU:0", &canonical_name); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_EQ("", canonical_name); diff --git a/xla/tsl/util/env_var.cc b/xla/tsl/util/env_var.cc new file mode 100644 index 0000000000000..217e851ff175b --- /dev/null +++ b/xla/tsl/util/env_var.cc @@ -0,0 +1,98 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/util/env_var.h" + +#include + +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/numbers.h" +#include "tsl/platform/str_util.h" +#include "tsl/platform/strcat.h" + +namespace tsl { + +absl::Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, + bool* value) { + *value = default_val; + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); + if (tf_env_var_val == nullptr) { + return absl::OkStatus(); + } + string str_value = absl::AsciiStrToLower(tf_env_var_val); + if (str_value == "0" || str_value == "false") { + *value = false; + return absl::OkStatus(); + } else if (str_value == "1" || str_value == "true") { + *value = true; + return absl::OkStatus(); + } + return errors::InvalidArgument(strings::StrCat( + "Failed to parse the env-var ${", env_var_name, "} into bool: ", + tf_env_var_val, ". Use the default value: ", default_val)); +} + +absl::Status ReadInt64FromEnvVar(StringPiece env_var_name, int64_t default_val, + int64_t* value) { + *value = default_val; + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); + if (tf_env_var_val == nullptr) { + return absl::OkStatus(); + } + if (strings::safe_strto64(tf_env_var_val, value)) { + return absl::OkStatus(); + } + return errors::InvalidArgument(strings::StrCat( + "Failed to parse the env-var ${", env_var_name, "} into int64: ", + tf_env_var_val, ". Use the default value: ", default_val)); +} + +absl::Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val, + float* value) { + *value = default_val; + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); + if (tf_env_var_val == nullptr) { + return absl::OkStatus(); + } + if (strings::safe_strtof(tf_env_var_val, value)) { + return absl::OkStatus(); + } + return errors::InvalidArgument(strings::StrCat( + "Failed to parse the env-var ${", env_var_name, "} into float: ", + tf_env_var_val, ". Use the default value: ", default_val)); +} + +absl::Status ReadStringFromEnvVar(StringPiece env_var_name, + StringPiece default_val, string* value) { + const char* tf_env_var_val = getenv(string(env_var_name).c_str()); + if (tf_env_var_val != nullptr) { + *value = tf_env_var_val; + } else { + *value = string(default_val); + } + return absl::OkStatus(); +} + +absl::Status ReadStringsFromEnvVar(StringPiece env_var_name, + StringPiece default_val, + std::vector* value) { + string str_val; + TF_RETURN_IF_ERROR(ReadStringFromEnvVar(env_var_name, default_val, &str_val)); + *value = str_util::Split(str_val, ','); + return absl::OkStatus(); +} + +} // namespace tsl diff --git a/xla/tsl/util/env_var.h b/xla/tsl/util/env_var.h new file mode 100644 index 0000000000000..d3b422bd2bb6d --- /dev/null +++ b/xla/tsl/util/env_var.h @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_UTIL_ENV_VAR_H_ +#define XLA_TSL_UTIL_ENV_VAR_H_ + +#include "tsl/platform/status.h" +#include "tsl/platform/stringpiece.h" +#include "tsl/platform/types.h" + +namespace tsl { + +// Returns a boolean into "value" from the environmental variable +// "env_var_name". If it is unset, the default value is used. A string "0" or a +// case insensitive "false" is interpreted as false. A string "1" or a case +// insensitive "true" is interpreted as true. Otherwise, an error status is +// returned. +absl::Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, + bool* value); + +// Returns an int64 into "value" from the environmental variable "env_var_name". +// If it is unset, the default value is used. +// If the string cannot be parsed into int64, an error status is returned. +absl::Status ReadInt64FromEnvVar(StringPiece env_var_name, int64_t default_val, + int64_t* value); +// Returns a float into "value" from the environmental variable "env_var_name". +// If it is unset, the default value is used. +// If the string cannot be parsed into float, an error status is returned. +absl::Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val, + float* value); + +// Returns a string into "value" from the environmental variable "env_var_name". +// If it is unset, the default value is used. +absl::Status ReadStringFromEnvVar(StringPiece env_var_name, + StringPiece default_val, std::string* value); + +// Returns a comma separated string into "value" from the environmental variable +// "env_var_name". If it is unset, the default value is comma split and used. +absl::Status ReadStringsFromEnvVar(StringPiece env_var_name, + StringPiece default_val, + std::vector* value); + +} // namespace tsl + +#endif // XLA_TSL_UTIL_ENV_VAR_H_ diff --git a/third_party/tsl/tsl/util/onednn_threadpool.h b/xla/tsl/util/onednn_threadpool.h similarity index 94% rename from third_party/tsl/tsl/util/onednn_threadpool.h rename to xla/tsl/util/onednn_threadpool.h index 82fbec738f00e..0c81806352f86 100644 --- a/third_party/tsl/tsl/util/onednn_threadpool.h +++ b/xla/tsl/util/onednn_threadpool.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ -#define TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#ifndef XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#define XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ #ifdef INTEL_MKL #include @@ -151,6 +151,16 @@ class OneDnnThreadPool : public threadpool_iface { ~OneDnnThreadPool() {} + static void set_onednn_max_threads(int num_threads) { +#if DNNL_VERSION_MAJOR >= 3 || \ + (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) +#ifndef DNNL_AARCH64_USE_ACL + dnnl_threadpool_interop_set_max_concurrency(num_threads); +#endif // DNNL_AARCH64_USE_ACL +#endif // DNNL_VERSION_MAJOR >= 3 || + // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + } + private: Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; int num_threads_ = 1; // Execute in caller thread. @@ -159,13 +169,7 @@ class OneDnnThreadPool : public threadpool_iface { inline void set_num_and_max_threads(int num_threads) { num_threads_ = num_threads == -1 ? eigen_interface_->NumThreads() : num_threads; -#if DNNL_VERSION_MAJOR >= 3 || \ - (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) -#ifndef DNNL_AARCH64_USE_ACL - dnnl_threadpool_interop_set_max_concurrency(num_threads_); -#endif // DNNL_AARCH64_USE_ACL -#endif // DNNL_VERSION_MAJOR >= 3 || - // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) + set_onednn_max_threads(num_threads_); } }; @@ -178,6 +182,7 @@ class OneDnnThreadPool { OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface) {} OneDnnThreadPool(Eigen::ThreadPoolInterface* eigen_interface, bool can_use_caller_thread, int num_threads = -1) {} + static void set_onednn_max_threads(int num_threads) {} }; #endif // !ENABLE_ONEDNN_OPENMP @@ -185,4 +190,4 @@ class OneDnnThreadPool { } // namespace tsl #endif // INTEL_MKL -#endif // TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#endif // XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ diff --git a/xla/tsl/util/proto/BUILD b/xla/tsl/util/proto/BUILD new file mode 100644 index 0000000000000..c5926d91002dd --- /dev/null +++ b/xla/tsl/util/proto/BUILD @@ -0,0 +1,21 @@ +load( + "@tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "proto_utils", + hdrs = ["proto_utils.h"], + deps = [ + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf_headers", + ], +) diff --git a/third_party/tsl/tsl/util/proto/proto_utils.h b/xla/tsl/util/proto/proto_utils.h similarity index 90% rename from third_party/tsl/tsl/util/proto/proto_utils.h rename to xla/tsl/util/proto/proto_utils.h index 9a1dee8eed522..2762f4df0e8af 100644 --- a/third_party/tsl/tsl/util/proto/proto_utils.h +++ b/xla/tsl/util/proto/proto_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ -#define TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#ifndef XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#define XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ #include "google/protobuf/duration.pb.h" #include "absl/time/time.h" @@ -39,4 +39,4 @@ inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { } // namespace proto_utils } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#endif // XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ diff --git a/xla/tsl/util/reporter.cc b/xla/tsl/util/reporter.cc new file mode 100644 index 0000000000000..1d08abf7b2e6c --- /dev/null +++ b/xla/tsl/util/reporter.cc @@ -0,0 +1,105 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/util/reporter.h" + +#include "tsl/platform/errors.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/str_util.h" + +namespace tsl { + +TestReportFile::TestReportFile(const string& fname, const string& test_name) + : closed_(true), fname_(fname), test_name_(test_name) {} + +absl::Status TestReportFile::Append(const string& content) { + if (closed_) return absl::OkStatus(); + return log_file_->Append(content); +} + +absl::Status TestReportFile::Close() { + if (closed_) return absl::OkStatus(); + closed_ = true; + return log_file_->Close(); +} + +absl::Status TestReportFile::Initialize() { + if (fname_.empty()) { + return absl::OkStatus(); + } + string mangled_fname = strings::StrCat( + fname_, absl::StrJoin(str_util::Split(test_name_, '/'), "__")); + Env* env = Env::Default(); + if (env->FileExists(mangled_fname).ok()) { + return errors::InvalidArgument( + "Cannot create TestReportFile, file exists: ", mangled_fname); + } + TF_RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_)); + TF_RETURN_IF_ERROR(log_file_->Flush()); + + closed_ = false; + return absl::OkStatus(); +} + +TestReporter::TestReporter(const string& fname, const string& test_name) + : report_file_(fname, test_name) { + benchmark_entry_.set_name(test_name); +} + +absl::Status TestReporter::Close() { + if (report_file_.IsClosed()) return absl::OkStatus(); + + tensorflow::BenchmarkEntries entries; + *entries.add_entry() = benchmark_entry_; + TF_RETURN_IF_ERROR(report_file_.Append(entries.SerializeAsString())); + benchmark_entry_.Clear(); + + return report_file_.Close(); +} + +absl::Status TestReporter::Benchmark(int64_t iters, double cpu_time, + double wall_time, double throughput) { + if (report_file_.IsClosed()) return absl::OkStatus(); + benchmark_entry_.set_iters(iters); + benchmark_entry_.set_cpu_time(cpu_time / iters); + benchmark_entry_.set_wall_time(wall_time / iters); + benchmark_entry_.set_throughput(throughput); + return absl::OkStatus(); +} + +absl::Status TestReporter::SetProperty(const string& name, + const string& value) { + if (report_file_.IsClosed()) return absl::OkStatus(); + (*benchmark_entry_.mutable_extras())[name].set_string_value(value); + return absl::OkStatus(); +} + +absl::Status TestReporter::SetProperty(const string& name, double value) { + if (report_file_.IsClosed()) return absl::OkStatus(); + (*benchmark_entry_.mutable_extras())[name].set_double_value(value); + return absl::OkStatus(); +} + +absl::Status TestReporter::AddMetric(const string& name, double value) { + if (report_file_.IsClosed()) return absl::OkStatus(); + auto* metric = benchmark_entry_.add_metrics(); + metric->set_name(name); + metric->set_value(value); + return absl::OkStatus(); +} + +absl::Status TestReporter::Initialize() { return report_file_.Initialize(); } + +} // namespace tsl diff --git a/third_party/tsl/tsl/util/reporter.h b/xla/tsl/util/reporter.h similarity index 88% rename from third_party/tsl/tsl/util/reporter.h rename to xla/tsl/util/reporter.h index d020e94fae127..6d2969d404e93 100644 --- a/third_party/tsl/tsl/util/reporter.h +++ b/xla/tsl/util/reporter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_REPORTER_H_ -#define TENSORFLOW_TSL_UTIL_REPORTER_H_ +#ifndef XLA_TSL_UTIL_REPORTER_H_ +#define XLA_TSL_UTIL_REPORTER_H_ #include #include @@ -37,13 +37,13 @@ class TestReportFile { // Initialize the TestReportFile. If the reporting env flag is set, // try to create the reporting file. Fails if the file already exists. - Status Initialize(); + absl::Status Initialize(); // Append the report file w/ 'content'. - Status Append(const string& content); + absl::Status Append(const string& content); // Close the report file. - Status Close(); + absl::Status Close(); bool IsClosed() const { return closed_; } @@ -91,29 +91,29 @@ class TestReporter { // Initialize the TestReporter. If the reporting env flag is set, // try to create the reporting file. Fails if the file already exists. - Status Initialize(); + absl::Status Initialize(); // Finalize the report. If the reporting env flag is set, // flush the reporting file and close it. // Once Close is called, no other methods should be called other // than Close and the destructor. - Status Close(); + absl::Status Close(); // Set the report to be a Benchmark and log the given parameters. // Only does something if the reporting env flag is set. // Does not guarantee the report is written. Use Close() to // enforce I/O operations. - Status Benchmark(int64_t iters, double cpu_time, double wall_time, - double throughput); + absl::Status Benchmark(int64_t iters, double cpu_time, double wall_time, + double throughput); // Set property on Benchmark to the given value. - Status SetProperty(const string& name, double value); + absl::Status SetProperty(const string& name, double value); // Set property on Benchmark to the given value. - Status SetProperty(const string& name, const string& value); + absl::Status SetProperty(const string& name, const string& value); // Add the given value to the metrics on the Benchmark. - Status AddMetric(const string& name, double value); + absl::Status AddMetric(const string& name, double value); // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! ~TestReporter() { Close().IgnoreError(); } // Autoclose in destructor. @@ -131,4 +131,4 @@ class TestReporter { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_REPORTER_H_ +#endif // XLA_TSL_UTIL_REPORTER_H_ diff --git a/third_party/tsl/tsl/util/stat_summarizer_options.h b/xla/tsl/util/stat_summarizer_options.h similarity index 88% rename from third_party/tsl/tsl/util/stat_summarizer_options.h rename to xla/tsl/util/stat_summarizer_options.h index e07de6e8d5d9d..c3ed6ffd7e48b 100644 --- a/third_party/tsl/tsl/util/stat_summarizer_options.h +++ b/xla/tsl/util/stat_summarizer_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ -#define TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#ifndef XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#define XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ namespace tsl { // Used to control the output of the statistics summarizer; struct StatSummarizerOptions { @@ -41,4 +41,4 @@ struct StatSummarizerOptions { }; } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#endif // XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ diff --git a/third_party/tsl/tsl/util/stats_calculator.cc b/xla/tsl/util/stats_calculator.cc similarity index 99% rename from third_party/tsl/tsl/util/stats_calculator.cc rename to xla/tsl/util/stats_calculator.cc index 99ab1e3e7c6bc..cdfa46c94417c 100644 --- a/third_party/tsl/tsl/util/stats_calculator.cc +++ b/xla/tsl/util/stats_calculator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/stats_calculator.h" +#include "xla/tsl/util/stats_calculator.h" #include #include diff --git a/third_party/tsl/tsl/util/stats_calculator.h b/xla/tsl/util/stats_calculator.h similarity index 96% rename from third_party/tsl/tsl/util/stats_calculator.h rename to xla/tsl/util/stats_calculator.h index 5c23f432971c2..84045fb6ceece 100644 --- a/third_party/tsl/tsl/util/stats_calculator.h +++ b/xla/tsl/util/stats_calculator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ -#define TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ +#ifndef XLA_TSL_UTIL_STATS_CALCULATOR_H_ +#define XLA_TSL_UTIL_STATS_CALCULATOR_H_ #include @@ -26,7 +26,7 @@ limitations under the License. #include #include -#include "tsl/util/stat_summarizer_options.h" +#include "xla/tsl/util/stat_summarizer_options.h" namespace tsl { @@ -198,4 +198,4 @@ class StatsCalculator { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ +#endif // XLA_TSL_UTIL_STATS_CALCULATOR_H_ diff --git a/third_party/tsl/tsl/util/stats_calculator_test.cc b/xla/tsl/util/stats_calculator_test.cc similarity index 98% rename from third_party/tsl/tsl/util/stats_calculator_test.cc rename to xla/tsl/util/stats_calculator_test.cc index 9093701e4478c..d58186630598f 100644 --- a/third_party/tsl/tsl/util/stats_calculator_test.cc +++ b/xla/tsl/util/stats_calculator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/stats_calculator.h" +#include "xla/tsl/util/stats_calculator.h" #include diff --git a/third_party/tsl/tsl/util/use_cudnn.cc b/xla/tsl/util/use_cudnn.cc similarity index 98% rename from third_party/tsl/tsl/util/use_cudnn.cc rename to xla/tsl/util/use_cudnn.cc index 3156a319b73b3..a3e1b4d25d266 100644 --- a/third_party/tsl/tsl/util/use_cudnn.cc +++ b/xla/tsl/util/use_cudnn.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/use_cudnn.h" +#include "xla/tsl/util/use_cudnn.h" #include +#include "xla/tsl/util/env_var.h" #include "tsl/platform/str_util.h" #include "tsl/platform/stringpiece.h" -#include "tsl/util/env_var.h" #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" diff --git a/third_party/tsl/tsl/util/use_cudnn.h b/xla/tsl/util/use_cudnn.h similarity index 92% rename from third_party/tsl/tsl/util/use_cudnn.h rename to xla/tsl/util/use_cudnn.h index 738e727e4c780..41c29b256f7be 100644 --- a/third_party/tsl/tsl/util/use_cudnn.h +++ b/xla/tsl/util/use_cudnn.h @@ -15,8 +15,8 @@ limitations under the License. // The utility to check Cudnn dependency and set Cudnn-related flags. -#ifndef TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ -#define TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ +#ifndef XLA_TSL_UTIL_USE_CUDNN_H_ +#define XLA_TSL_UTIL_USE_CUDNN_H_ #include @@ -40,4 +40,4 @@ bool ShouldCudnnGroupedConvolutionBeUsed(const int32_t filter_rows, const int32_t out_depth); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ +#endif // XLA_TSL_UTIL_USE_CUDNN_H_ diff --git a/xla/types.h b/xla/types.h index 1910ea89e2ee4..f1891b9a8c076 100644 --- a/xla/types.h +++ b/xla/types.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,11 +19,11 @@ limitations under the License. #include #include #include +#include #include -#include "absl/strings/str_format.h" -#include "Eigen/Core" // from @eigen_archive -#include "ml_dtypes/include/int4.h" // from @ml_dtypes +#include "Eigen/Core" // from @eigen_archive // IWYU pragma: export +#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export namespace xla { @@ -41,18 +41,26 @@ struct is_complex> : std::true_type {}; template inline constexpr bool is_complex_v = is_complex::value; +template +struct is_specialized_floating_point + : std::bool_constant::is_specialized && + !std::numeric_limits::is_integer> {}; + template inline constexpr bool is_specialized_floating_point_v = - std::numeric_limits::is_specialized && - !std::numeric_limits::is_integer; + is_specialized_floating_point::value; + +template +struct is_specialized_integral + : std::bool_constant::is_specialized && + std::numeric_limits::is_integer> {}; template inline constexpr bool is_specialized_integral_v = - std::numeric_limits::is_specialized && - std::numeric_limits::is_integer; + is_specialized_integral::value; -using u4 = ml_dtypes::uint4; -using s4 = ml_dtypes::int4; +using u4 = tsl::uint4; +using s4 = tsl::int4; } // namespace xla @@ -60,12 +68,12 @@ using s4 = ml_dtypes::int4; namespace ml_dtypes { template void AbslStringify(Sink& sink, const xla::s4& i) { - absl::Format(&sink, "%d", static_cast(i)); + sink.Append(std::to_string(static_cast(i))); } template void AbslStringify(Sink& sink, const xla::u4& i) { - absl::Format(&sink, "%d", static_cast(i)); + sink.Append(std::to_string(static_cast(i))); } } // namespace ml_dtypes @@ -112,6 +120,16 @@ struct make_specialized_signed { template using make_specialized_signed_t = typename make_specialized_signed::type; +template +struct has_negative_zero + : std::bool_constant::is_iec559> {}; + +template <> +struct has_negative_zero : std::bool_constant {}; + +template +inline constexpr bool has_negative_zero_v = has_negative_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/types_test.cc b/xla/types_test.cc index d9ae35205800a..a51b86ea51b2a 100644 --- a/xla/types_test.cc +++ b/xla/types_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/union_find.h b/xla/union_find.h index 134bc3da874c1..38ae5a98eb092 100644 --- a/xla/union_find.h +++ b/xla/union_find.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/util.cc b/xla/util.cc index 429ab9d4a060b..c4abd63d8c7f9 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ limitations under the License. #include -#include +#include #include #include #include @@ -456,51 +456,6 @@ bool DistinctNumbersAreConsecutiveIfSorted(absl::Span seq) { seq.size() - 1; } -// Utility function to split a double-precision float (F64) into a pair of F32s. -// For a p-bit number, and a splitting point (p/2) <= s <= (p - 1), the -// algorithm produces a (p - s)-bit value 'hi' and a non-overlapping (s - 1)-bit -// value 'lo'. See Theorem 4 in [1] (attributed to Dekker) or [2] for the -// original theorem by Dekker. -// -// For double-precision F64s, which contain a 53 bit mantissa (52 of them -// explicit), we can represent the most significant 49 digits as the unevaluated -// sum of two single-precision floats 'hi' and 'lo'. The 'hi' float stores the -// most significant 24 bits and the sign bit of 'lo' together with its mantissa -// store the remaining 25 bits. The exponent of the resulting representation is -// still restricted to 8 bits of F32. -// -// References: -// [1] A. Thall, Extended-Precision Floating-Point Numbers for GPU Computation, -// SIGGRAPH Research Posters, 2006. -// (http://andrewthall.org/papers/df64_qf128.pdf) -// [2] T. J. Dekker, A floating point technique for extending the available -// precision, Numerische Mathematik, vol. 18, pp. 224–242, 1971. -std::pair SplitF64ToF32(double x) { - const float x_f32 = static_cast(x); - - // Early return if x is an infinity or NaN. - if (!std::isfinite(x_f32)) { - // Only values within the range of F32 are supported, unless it is infinity. - // Small values with large negative exponents would be rounded to zero. - if (std::isfinite(x)) { - LOG(WARNING) << "Out of range F64 constant detected: " << x; - } - return std::make_pair(x_f32, 0.0f); - } - - // The high float is simply the double rounded to the nearest float. Because - // we are rounding to nearest with ties to even, the error introduced in - // rounding is less than half an ULP in the high ULP. - const float hi = x_f32; - // We can compute the low term using Sterbenz' lemma: If a and b are two - // positive floating point numbers and a/2 ≤ b ≤ 2a, then their difference can - // be computed exactly. - // Note: the difference is computed exactly but is rounded to the nearest - // float which will introduce additional error. - const float lo = static_cast(x - static_cast(hi)); - return std::make_pair(hi, lo); -} - void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { @@ -525,64 +480,4 @@ void UnpackInt4(absl::Span input, absl::Span output) { } } -/*static*/ MaybeOwningThreadPool MaybeOwningThreadPool::GetOrCreate( - int parallelism, tsl::thread::ThreadPool* default_thread_pool, - int default_parallelism) { - CHECK_GE(parallelism, 0); - CHECK_GE(default_parallelism, 1); - - auto create_thread_pool = [&](int num_threads) { - CHECK_GE(num_threads, 1); - return std::make_unique(tsl::Env::Default(), "", - num_threads); - }; - - switch (parallelism) { - case 0: - if (default_thread_pool == nullptr && default_parallelism > 1) { - return MaybeOwningThreadPool(create_thread_pool(default_parallelism)); - } - return MaybeOwningThreadPool(default_thread_pool); - case 1: - return MaybeOwningThreadPool(nullptr); - default: - return MaybeOwningThreadPool(create_thread_pool(parallelism)); - } -} - -MaybeOwningThreadPool::MaybeOwningThreadPool() : thread_pool_(nullptr) {} - -MaybeOwningThreadPool::MaybeOwningThreadPool( - tsl::thread::ThreadPool* thread_pool) - : thread_pool_(thread_pool) {} - -MaybeOwningThreadPool::MaybeOwningThreadPool( - std::unique_ptr thread_pool) - : thread_pool_(std::move(thread_pool)) {} - -tsl::thread::ThreadPool* MaybeOwningThreadPool::get() { - if (std::holds_alternative(thread_pool_)) { - return std::get(thread_pool_); - } - return std::get>(thread_pool_).get(); -} - -const tsl::thread::ThreadPool* MaybeOwningThreadPool::get() const { - return const_cast(this)->get(); -} - -tsl::thread::ThreadPool* MaybeOwningThreadPool::operator->() { - tsl::thread::ThreadPool* thread_pool = get(); - CHECK_NE(thread_pool, nullptr); - return thread_pool; -} - -const tsl::thread::ThreadPool* MaybeOwningThreadPool::operator->() const { - return const_cast(this)->operator->(); -} - -MaybeOwningThreadPool::operator bool() const { return get() != nullptr; } - -bool MaybeOwningThreadPool::operator!() const { return get() == nullptr; } - } // namespace xla diff --git a/xla/util.h b/xla/util.h index 9894a7b46e4d4..a5910d0ce4197 100644 --- a/xla/util.h +++ b/xla/util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -47,14 +48,14 @@ limitations under the License. #include "Eigen/Core" // from @eigen_archive #include "xla/status.h" #include "xla/status_macros.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/bfloat16.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" // IWYU pragma: keep -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" -#include "tsl/platform/threadpool.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { @@ -211,87 +212,156 @@ void StridedCopy(D* dest, int64_t dest_stride, const S* src, int64_t src_stride, Status AddStatus(Status prior, absl::string_view context); Status AppendStatus(Status prior, absl::string_view context); -// Status error shorthands -- StrFormat's the arguments to be used as an error -// message and returns a status in the canonical error space. -template -Status InvalidArgument(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::InvalidArgument(absl::StrFormat(format, args...))); -} -template -Status Unimplemented(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::Unimplemented(absl::StrFormat(format, args...))); -} -template -Status InternalError(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::Internal(absl::StrFormat(format, args...))); -} -template -Status FailedPrecondition(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::FailedPrecondition(absl::StrFormat(format, args...))); -} -template -Status Cancelled(const absl::FormatSpec& format, const Args&... args) { - return WithLogBacktrace( - tsl::errors::Cancelled(absl::StrFormat(format, args...))); -} -template -Status ResourceExhausted(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::ResourceExhausted(absl::StrFormat(format, args...))); -} -template -Status NotFound(const absl::FormatSpec& format, const Args&... args) { - return WithLogBacktrace( - tsl::errors::NotFound(absl::StrFormat(format, args...))); -} -template -Status Unavailable(const absl::FormatSpec& format, - const Args&... args) { - return WithLogBacktrace( - tsl::errors::Unavailable(absl::StrFormat(format, args...))); -} -template -Status Unknown(const absl::FormatSpec& format, const Args&... args) { - return WithLogBacktrace( - tsl::errors::Unknown(absl::StrFormat(format, args...))); -} -template -Status Internal(const absl::FormatSpec& format, const Args&... args) { - return WithLogBacktrace( - tsl::errors::Internal(absl::StrFormat(format, args...))); -} - -template -Status InvalidArgumentStrCat(Args&&... concat) { - return WithLogBacktrace( - tsl::errors::InvalidArgument(std::forward(concat)...)); -} +// The following three macros define a common set of code for creating +// absl::Status errors with the given error_type, with the addition of adding +// absl::SourceLocation if it's available (PLATFORM_GOOGLE). They're a +// complicated by the need to use #ifdefs within the code. This would be the +// equivalent code for ResourceExhausted if a #define macro could have embedded +// #ifdef directives: +// +// template +// struct ResourceExhausted { +// Status status; +// #if defined(PLATFORM_GOOGLE) +// // NOLINTNEXTLINE(google-explicit-constructor) +// ResourceExhausted(const absl::FormatSpec& format, Args&&... args, +// absl::SourceLocation loc = +// absl::SourceLocation::current()) +// : status(WithLogBacktrace( +// absl::ResourceExhaustedError(absl::StrFormat(format, args...)) +// .WithSourceLocation(loc))) {} +// #else +// ResourceExhaustedStrCat(Args&&... concat) +// : status(WithLogBacktrace( +// absl::ResourceExhaustedError(absl::StrFormat(format, args...))) +// {} +// #endif +// +// // NOLINTNEXTLINE(google-explicit-constructor) +// operator Status() const { return status; } +// }; +// +#define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_PREFIX(error_type) \ + template \ + struct error_type { \ + Status status; +#define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_SUFFIX(error_type) \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ + operator Status() const { return status; } \ + } \ + ; \ + /*Deduction guide to make variadic arguments play nice with default */ \ + /* absl::SourceLocation argument. */ \ + template \ + error_type(const absl::FormatSpec& format, \ + Args&&...) -> error_type; + +#if defined(PLATFORM_GOOGLE) +#define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(error_type) \ + XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_PREFIX(error_type) \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ + error_type(const absl::FormatSpec& format, Args&&... args, \ + absl::SourceLocation loc = absl::SourceLocation::current()) \ + : status(WithLogBacktrace( \ + absl::error_type##Error(absl::StrFormat(format, args...)) \ + .WithSourceLocation(loc))) {} \ + XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_SUFFIX(error_type) +#else +#define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(error_type) \ + template \ + Status error_type(const absl::FormatSpec& format, \ + const Args&... args) { \ + return WithLogBacktrace( \ + absl::error_type##Error(absl::StrFormat(format, args...))); \ + } +#endif -template -Status UnimplementedStrCat(Args&&... concat) { - return WithLogBacktrace( - tsl::errors::Unimplemented(std::forward(concat)...)); -} +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(Cancelled); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(FailedPrecondition); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(Internal); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(InvalidArgument); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(NotFound); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(ResourceExhausted); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(Unavailable); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(Unimplemented); +XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE(Unknown); + +#undef XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE +#undef XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_PREFIX +#undef XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_SUFFIX + +// The following three macros define a common set of code for creating +// absl::Status errors with the given error_type, with the addition of adding +// absl::SourceLocation if it's available (PLATFORM_GOOGLE). They're a +// complicated by the need to use #ifdefs within the code. This would be the +// equivalent code for ResourceExhausted if a #define macro could have embedded +// #ifdef directives: +// +// template +// struct ResourceExhaustedStrCat { +// Status status; +// #if defined(PLATFORM_GOOGLE) +// // NOLINTNEXTLINE(google-explicit-constructor) +// ResourceExhaustedStrCat(Args&&... concat, absl::SourceLocation loc = +// absl::SourceLocation::current()) +// : status(WithLogBacktrace( +// absl::ResourceExhaustedError(absl::StrCat( +// std::forward(concat)...)) +// .WithSourceLocation(loc))) {} +// #else +// ResourceExhaustedStrCat(Args&&... concat) +// : status(WithLogBacktrace( +// absl::ResourceExhaustedError(absl::StrCat( +// std::forward(concat)...)))) +// {} +// #endif +// +// // NOLINTNEXTLINE(google-explicit-constructor) +// operator Status() const { return status; } +// }; +// +#define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ + template \ + struct error_type##StrCat { \ + Status status; \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ +#define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ + operator Status() const { return status; } \ + } \ + ; \ + /*Deduction guide to make variadic arguments play nice with default */ \ + /* absl::SourceLocation argument. */ \ + template \ + error_type##StrCat(Args&&...)->error_type##StrCat; + +#if defined(PLATFORM_GOOGLE) +#define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(error_type) \ + XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ + error_type##StrCat(Args&&... concat, absl::SourceLocation loc = \ + absl::SourceLocation::current()) \ + : status( \ + WithLogBacktrace(absl::error_type##Error( \ + absl::StrCat(std::forward(concat)...)) \ + .WithSourceLocation(loc))) {} \ + XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) +#else +#define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(error_type) \ + XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ + error_type##StrCat(Args&&... concat) \ + : status(WithLogBacktrace(absl::error_type##Error( \ + absl::StrCat(std::forward(concat)...)))) {} \ + XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) +#endif -template -Status InternalErrorStrCat(Args&&... concat) { - return WithLogBacktrace(tsl::errors::Internal(std::forward(concat)...)); -} +XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(ResourceExhausted); +XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(InvalidArgument); +XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(Unimplemented); +XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(Internal); -template -Status ResourceExhaustedStrCat(Args&&... concat) { - return WithLogBacktrace( - tsl::errors::ResourceExhausted(std::forward(concat)...)); -} +#undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE +#undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX +#undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX // Splits the lines of the original, replaces leading whitespace with the prefix // given by "indentation", and returns the string joined by newlines again. As a @@ -562,9 +632,7 @@ auto SignAndMagnitude(T x) { BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const BitType x_bits = Eigen::numext::bit_cast(x); const BitType x_sign = x_bits ^ x_abs_bits; - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v) { + if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. if (x_sign) { @@ -729,16 +797,6 @@ Status EraseElementFromVector(std::vector* container, const T& value) { return OkStatus(); } -// Utility function which splits a double-precision float (F64) into a pair of -// single-precision floating point numbers. The most significant 49 bits (out of -// the total 53 available) in the mantissa of the F64 is represented as the -// unevaluated sum of two non-overlapping single-precision F32s; the 'high' part -// contains 24 bits in its mantissa, and the 'low' part contains 25 bits in its -// sign bit and its mantissa. -// Note: The resulting representation can still only represent 8-bit exponent -// range that is available in F32s (out of a total of 11 exponent bits in F64s). -std::pair SplitF64ToF32(double x); - // Takes a sequence of unpacked int4 values, such that every byte stores one // int4 value in the low-order four bits, and packs them so every byte stores // two int4 values. 'input' should have num_elements bytes; 'output' should have @@ -766,37 +824,6 @@ inline bool HloPredicateFalse(const HloInstruction*) { return false; } using Vector2 = std::array; using Vector3 = std::array; -// A class for storing either an owned thread pool or a non-owning pointer to an -// external thread pool. -class MaybeOwningThreadPool { - public: - // Gets or creates a thread pool. - // - // See the code for the logic. - static MaybeOwningThreadPool GetOrCreate( - int parallelism, tsl::thread::ThreadPool* default_thread_pool, - int default_parallelism); - - // Not owning (nullptr). - MaybeOwningThreadPool(); - // Not owning. - explicit MaybeOwningThreadPool(tsl::thread::ThreadPool* thread_pool); - // Owning. - explicit MaybeOwningThreadPool( - std::unique_ptr thread_pool); - tsl::thread::ThreadPool* get(); - const tsl::thread::ThreadPool* get() const; - tsl::thread::ThreadPool* operator->(); - const tsl::thread::ThreadPool* operator->() const; - explicit operator bool() const; - bool operator!() const; - - private: - std::variant> - thread_pool_; -}; - } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/xla/util_test.cc b/xla/util_test.cc index 3dfb2e8682af7..555521332c26e 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,8 +25,8 @@ limitations under the License. #include "xla/test.h" #include "xla/types.h" -#include "tsl/platform/float8.h" #include "tsl/platform/logging.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { namespace { @@ -183,14 +183,6 @@ TEST(UtilTest, RoundTripFpToString) { "-nan(0x1)"); } -TEST(UtilTest, SplitF64ToF32) { - // Overflowing the F32 exponent in SplitF64ToF32 should result in a pair of - // [∞,0]. - EXPECT_EQ(SplitF64ToF32(std::numeric_limits::max()).first, - std::numeric_limits::infinity()); - EXPECT_EQ(SplitF64ToF32(std::numeric_limits::max()).second, 0.0f); -} - namespace { template void TotalOrderHelper(T x, T y) { diff --git a/xla/window_util.cc b/xla/window_util.cc index d4ab29b6de40d..471aec16511f3 100644 --- a/xla/window_util.cc +++ b/xla/window_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/window_util.h b/xla/window_util.h index c691056958154..b94074ae2b993 100644 --- a/xla/window_util.h +++ b/xla/window_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/window_util_test.cc b/xla/window_util_test.cc index b5901973c2a54..a5d7ac7a265d3 100644 --- a/xla/window_util_test.cc +++ b/xla/window_util_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/xla/xla.bzl b/xla/xla.bzl index 5793b69910bac..9aa2a58c4e3f1 100644 --- a/xla/xla.bzl +++ b/xla/xla.bzl @@ -6,12 +6,11 @@ load( ) load( "@tsl//tsl:tsl.bzl", - "if_oss", "tsl_copts", - _tsl_clean_dep = "clean_dep", ) load( "@tsl//tsl/platform:build_config_root.bzl", + "if_static", "tf_exec_properties", ) load( @@ -19,17 +18,6 @@ load( "if_cuda_is_configured", ) -def clean_dep(target): - """Returns string to 'target' in @{org_tensorflow,xla} repository. - - This is distinct from the clean_dep which appears in @{org_tensorflow,tsl}. - TODO(ddunleavy,jakeharmon): figure out what to do with this after vendoring. - """ - - # A repo-relative label is resolved relative to the file in which the - # Label() call appears, i.e. @org_tensorflow. - return str(Label(target)) - def xla_py_proto_library(**_kwargs): # Note: we don't currently define a proto library target for Python in OSS. pass @@ -45,39 +33,39 @@ def xla_py_test_deps(): # `framework_shared_object` in the bazelrc all of this should be able to go # away. The problem is making sure that all these impl deps are `if_static`'d # appropriately throughout XLA. -_XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_oss([_tsl_clean_dep("@com_google_protobuf//:protobuf")]) + [ - clean_dep("//xla:xla_proto_cc_impl"), - clean_dep("//xla:xla_data_proto_cc_impl"), - clean_dep("//xla/service:hlo_proto_cc_impl"), - clean_dep("//xla/service:buffer_assignment_proto_cc_impl"), - clean_dep("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"), - clean_dep("//xla/service/gpu:backend_configs_cc_impl"), - clean_dep("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"), - clean_dep("//xla/stream_executor:device_description_proto_cc_impl"), - clean_dep("//xla/stream_executor:device_id_utils"), - clean_dep("//xla/stream_executor:stream_executor_impl"), - clean_dep("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"), - clean_dep("//xla/stream_executor/gpu:gpu_init_impl"), - clean_dep("@tsl//tsl/profiler/utils:time_utils_impl"), - clean_dep("@tsl//tsl/profiler/backends/cpu:annotation_stack_impl"), - clean_dep("@tsl//tsl/profiler/backends/cpu:traceme_recorder_impl"), - clean_dep("@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), - clean_dep("@tsl//tsl/profiler/protobuf:xplane_proto_cc_impl"), - clean_dep("//xla:autotune_results_proto_cc_impl"), - clean_dep("//xla:autotuning_proto_cc_impl"), - clean_dep("@tsl//tsl/protobuf:protos_all_cc_impl"), - clean_dep("@tsl//tsl/platform:env_impl"), - clean_dep("@tsl//tsl/framework:allocator"), - clean_dep("@tsl//tsl/framework:allocator_registry_impl"), - clean_dep("@tsl//tsl/util:determinism"), -] + if_cuda_is_configured([ - clean_dep("//xla/stream_executor/cuda:cuda_stream"), - clean_dep("//xla/stream_executor/cuda:all_runtime"), - clean_dep("//xla/stream_executor/cuda:stream_executor_cuda"), +_XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ + Label("//xla:autotune_results_proto_cc_impl"), + Label("//xla:autotuning_proto_cc_impl"), + Label("//xla:xla_data_proto_cc_impl"), + Label("//xla:xla_proto_cc_impl"), + Label("//xla/service:buffer_assignment_proto_cc_impl"), + Label("//xla/service:hlo_proto_cc_impl"), + Label("//xla/service/gpu:backend_configs_cc_impl"), + Label("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"), + Label("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"), + Label("//xla/stream_executor:device_description_proto_cc_impl"), + Label("//xla/stream_executor:stream_executor_impl"), + Label("//xla/stream_executor/gpu:gpu_init_impl"), + "@com_google_protobuf//:protobuf", + "@tsl//tsl/framework:allocator_registry_impl", + "@tsl//tsl/framework:allocator", + "@tsl//tsl/platform:env_impl", + "@tsl//tsl/profiler/backends/cpu:annotation_stack_impl", + "@tsl//tsl/profiler/backends/cpu:traceme_recorder_impl", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", + "@tsl//tsl/profiler/utils:time_utils_impl", + "@tsl//tsl/protobuf:protos_all_cc_impl", +]) + if_cuda_is_configured([ + Label("//xla/stream_executor/cuda:all_runtime"), + Label("//xla/stream_executor/cuda:cuda_stream"), + Label("//xla/stream_executor/cuda:stream_executor_cuda"), + Label("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"), ]) + if_rocm_is_configured([ - clean_dep("//xla/stream_executor/gpu:gpu_stream"), - clean_dep("//xla/stream_executor/rocm:all_runtime"), - clean_dep("//xla/stream_executor/rocm:stream_executor_rocm"), + Label("//xla/stream_executor/gpu:gpu_stream"), + Label("//xla/stream_executor/rocm:all_runtime"), + Label("//xla/stream_executor/rocm:stream_executor_rocm"), + "//xla/tsl/util:determinism", ]) def xla_cc_binary(deps = [], copts = tsl_copts(), **kwargs): @@ -92,10 +80,10 @@ def xla_cc_test(name, deps = [], **kwargs): ) def auto_sharding_deps(): - return ["//xla/hlo/experimental/auto_sharding:auto_sharding_impl"] + return [Label("//xla/hlo/experimental/auto_sharding:auto_sharding_impl")] def auto_sharding_solver_deps(): - return ["//xla/hlo/experimental/auto_sharding:auto_sharding_solver_impl"] + return [Label("//xla/hlo/experimental/auto_sharding:auto_sharding_solver_impl")] def xla_export_hlo_deps(): return [] diff --git a/xla/xla.proto b/xla/xla.proto index 0a8608599838f..3ccef6436ce51 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -172,7 +172,12 @@ message DebugOptions { // useful when accelerating structured sparsity. int32 xla_cpu_sparse_cuda_threads = 207; - // Allows xla to increase the output precision of floating point operations. + // Allows xla to increase the output precision of floating point operations + // and all floating-point conversions to be simplified, including those + // that affect the numerics. The `FloatNormalization` pass inserts many + // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the + // `AlgebraicSimplifier`, as that will only simplify conversions that are + // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy. bool xla_allow_excess_precision = 122; // Crashes the program when any kind of verification fails, instead of just @@ -362,6 +367,7 @@ message DebugOptions { // Overrides normal multi-threaded compilation setting to use this many // threads. Setting to 0 (the default value) means no enforcement. int32 xla_gpu_force_compilation_parallelism = 147; + bool xla_gpu_enable_llvm_module_compilation_parallelism = 268; // Guarantees run-to-run determinism. At present, the HLO ops Scatter and // SelectAndScatter do not have deterministic XLA:GPU implementations. @@ -374,6 +380,7 @@ message DebugOptions { // Convert synchronous collective ops into asynchronous. bool xla_gpu_enable_async_collectives = 238; bool xla_gpu_enable_async_all_reduce = 152; + bool xla_gpu_enable_async_collective_broadcast = 278; bool xla_gpu_enable_async_collective_permute = 183; bool xla_gpu_enable_async_all_gather = 199; bool xla_gpu_enable_async_reduce_scatter = 200; @@ -445,19 +452,11 @@ message DebugOptions { // if `xla_gpu_enable_custom_fusion` set to true. string xla_gpu_enable_custom_fusions_re = 264; - // If true, use OpenXLA runtime for XLA:GPU backend. That is, use IREE VM - // as a host executable, optional CUDA HAL for dispatching device kernels and - // custom modules for integration with libraries required for running - // XLA:GPU programs. - // - // Note: this mode disables thunks and the "classic" gpu runtime, which - // is defined above. - bool xla_gpu_enable_gpu2_runtime = 233; + // If true, use XLA runtime for XLA:GPU backend. + bool xla_gpu_enable_address_computation_fusion = 105; - // If true, use OpenXLA hardware abstraction layer (aka CUDA HAL) to dispatch - // device kernels, otherwise use StreamExecutor kernel launch APIs. Has any - // effect only if `xla_gpu_enable_gpu2_runtime` is set to true. - bool xla_gpu_enable_gpu2_hal = 234; + reserved 233; // was xla_gpu_enable_gpu2_runtime + reserved 234; // was xla_gpu_enable_gpu2_hal // Timeout in seconds before terminating jobs that are stuck in a NCCL // Rendezvous. Negative value disables the timeout and will not terminate. @@ -470,14 +469,18 @@ message DebugOptions { // Whether to use cuBLASLt for GEMMs on GPUs. bool xla_gpu_enable_cublaslt = 166; - // Commands are categorized into four types: FUSION represents regular fusion - // kernels. CUBLAS, CUDNN, and NCCL represent library calls. + // Commands are categorized into 5 types: + // FUSION represents regular fusion kernels. + // CUBLAS, CUDNN, and COLLECTIVES represent library calls. + // CONDITIONALS represents control flow. enum CommandBufferCmdType { INVALID = 0; FUSION = 1; CUBLAS = 2; CUDNN = 3; - NCCL = 4; + COLLECTIVES = 4; + CONDITIONALS = 5; + CUSTOM_CALL = 6; } // Determine the types of commands that are recorded into command buffers. @@ -501,11 +504,6 @@ message DebugOptions { // executed to free space on device. int32 xla_gpu_graph_eviction_timeout_seconds = 230; - // Allocate temp buffers once during the first execution of an executable. - // Reuse the allocated buffers in subsequent executions. Executables cannot - // run concurrently if this is enabled. - bool xla_gpu_enable_persistent_temp_buffers = 206; - // Size threshold (in megabytes) for the GPU redzone scratch allocator. int64 xla_gpu_redzone_scratch_max_megabytes = 167; @@ -519,12 +517,8 @@ message DebugOptions { // scratch), so this can be multiplied by quite a lot. int64 xla_gpu_redzone_padding_bytes = 228; - // Allows all floating-point conversions to be simplified, including those - // that affect the numerics. The `FloatNormalization` pass inserts many - // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the - // `AlgebraicSimplifier`, as that will only simplify conversions that are - // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy. - bool xla_gpu_simplify_all_fp_conversions = 168; + // Deprecated. Use xla_allow_excess_precision instead. + bool xla_gpu_simplify_all_fp_conversions = 168 [deprecated = true]; // An experimental option to force all layouts present in the // after-optimizations HLO to be descending, e.g. @@ -564,6 +558,7 @@ message DebugOptions { bool xla_gpu_enable_latency_hiding_scheduler = 186; bool xla_gpu_enable_highest_priority_async_stream = 216; bool xla_gpu_enable_analytical_latency_estimator = 255; + bool xla_gpu_enable_linear_program_scheduler = 286; bool xla_gpu_lhs_enable_gpu_async_tracker = 204; string xla_gpu_pgle_profile_file_or_directory_path = 210; @@ -592,6 +587,9 @@ message DebugOptions { bool xla_gpu_enable_cudnn_int8x32_convolution_reordering = 189; + // Creates triton fusion for all supported gemms. + // To make sure only triton gemm is chosen by the autotuner run with + // `xla_gpu_cublas_fallback` set to false. bool xla_gpu_triton_gemm_any = 190; reserved 211; // Was xla_gpu_enable_dot_strength_reduction @@ -674,7 +672,65 @@ message DebugOptions { // Threshold to enable windowed einsum (collective matmul) in MB. int64 xla_gpu_threshold_for_windowed_einsum_mib = 265; - // Next id: 266 + // Enables currently disabled features within Triton for Hopper. + bool xla_gpu_enable_triton_hopper = 266; + + // Enable NCCL user buffers. + bool xla_gpu_enable_nccl_user_buffers = 267; + + // Enable NCCL communicator splitting. + bool xla_gpu_enable_nccl_comm_splitting = 272; + + // Enable NCCL per stream communicators. + bool xla_gpu_enable_nccl_per_stream_comms = 276; + + // If enabled, uses the libnvptxcompiler library to compile PTX to cuBIN. + bool xla_gpu_enable_libnvptxcompiler = 269; + + bool xla_gpu_enable_dot_strength_reduction = 270; + // Whether to use multiple compute streams to run windowed einsum. + bool xla_gpu_multi_streamed_windowed_einsum = 280; + + // If enabled, uses bf16_6way gemm to compute F32 gemm. + bool xla_gpu_enable_bf16_6way_gemm = 271; + + // If enabled, uses bf16_3way gemm to compute F32 gemm. + bool xla_gpu_enable_bf16_3way_gemm = 279; + + // Specify the maximum number of channels(SMs) NCCL + // will use for collective operations. + int64 xla_gpu_nccl_collective_max_nchannels = 273; + + // Specify the maximum number of channels(SMs) NCCL + // will use for p2p operations. + int64 xla_gpu_nccl_p2p_max_nchannels = 274; + + bool xla_gpu_enable_mlir_emitters = 275; + // The maximum number of kernels to emit with MLIR. Unlimited if 0. + int64 xla_gpu_max_mlir_kernels = 281; + // The number of initial kernels to not emit with MLIR. Only supported kernels + // are counted. + int64 xla_gpu_skip_mlir_kernels = 282; + + // Threshold to rewrite matmul to cuBLAS or Triton (minumum combined number of + // elements of both matrices in non-batch dimensions to be considered for a + // rewrite). + int64 xla_gpu_gemm_rewrite_size_threshold = 283; + + // If true, will require complete AOT autotuning results; in the case of + // missing AOT result, the model will not be compiled or executed, a + // `NotFound` error will be returned. + bool xla_gpu_require_complete_aot_autotune_results = 284; + + // Let GEMM fusion autotuning probe cuDNN as a backend. + // Current levels: + // 0: Disabled. + // 1: Fusions of GEMM, elementwise, transpose/reshape operations. + // 2: + Broadcasts. + // 3: + Nontrivial noncontracting dimension reshapes/transposes. + int32 xla_gpu_cudnn_gemm_fusion_level = 285; + + // Next id: 287 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -689,7 +745,8 @@ message DebugOptions { // xla_gpu_enable_experimental_block_size // xla_gpu_graph_level // xla_gpu_single_wave_autotuning - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242; + // xla_gpu_enable_persistent_temp_buffers + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206; } // Contains flags which affects the GPU compilation result. @@ -711,7 +768,7 @@ message ShardableValueUpdatePairProto { // will have an effect on every platform. // // When adding new fields, keep in mind that boolean fields default to false. -// Next id: 23. +// Next id: 24. message ExecutionOptions { // This optional field's layout is used as a hint when storing the output of // this computation. Subsequent transfers of this output array to the client @@ -773,6 +830,18 @@ message ExecutionOptions { reserved 13; // Was broadcast_replicated_parameters_via_collectives + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 23; + // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -811,7 +880,7 @@ message ExecutionOptions { // Serialization of HloModuleConfig. See the C++ class definition for // descriptions of each field. // There are no guarantees of backwards or forwards compatibility. -// Next id: 33. +// Next id: 34. message HloModuleConfigProto { enum FusionConfigCollection { OFF = 0; // Do not collect configuration. @@ -859,6 +928,7 @@ message HloModuleConfigProto { repeated BoolList phase_ordering_config = 24; int32 phase_index = 25; reserved 26; // Was flag_config + repeated bool allow_spmd_sharding_propagation_to_parameters = 33; repeated bool allow_spmd_sharding_propagation_to_output = 27; map analysis_allowance_map = 28; xla.PrecisionConfig.Precision matrix_unit_operand_precision = 29; diff --git a/xla/xla_data.proto b/xla/xla_data.proto index d2691293070ad..5bcdb0aa08320 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -181,6 +181,16 @@ message TileProto { repeated int64 dimensions = 1; } +// Describes how data should be split between different memories. +message SplitConfigProto { + // The dimension that is split. + int64 dimension = 1; + // The indices where each split point occurs. For example, if the dimension + // size is 1024, a split_indices value of {512} indicates a two-way split of + // data through the middle. + repeated int64 split_indices = 2; +} + // A layout describes how the array is placed in (1D) memory space. This // includes the minor-to-major ordering of dimensions within a shape. // @@ -216,6 +226,15 @@ message LayoutProto { // error. repeated TileProto tiles = 6; + // The shape is padded at the end to multiple of, in terms of number of + // elements. This is useful when tiling does not bring the shape to certain + // desired granules. Tiling effectively pads/reshapes/transposes the shape + // to another shape. This field pads the total number of elements of that + // new shape to a multiple of certain number of elements. This is useful such + // as we want a layout which does not tile the data but still requires it to + // be padded to certain number of elements. + int64 tail_padding_alignment_in_elements = 16; + // (Optional) Bit size of each element. When unspecified or being 0, default // to ShapeUtil::ByteSizeOfPrimitiveType. int64 element_size_in_bits = 7; @@ -244,6 +263,10 @@ message LayoutProto { // dynamic shape, e.g. a result of SliceToDynamic. int64 dynamic_shape_metadata_prefix_bytes = 15; + // The split configurations which describe if/how the data is split between + // different memories. + repeated SplitConfigProto split_configs = 17; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and // LayoutUtil::Hash appropriately to account for the new field. @@ -371,16 +394,11 @@ message OpMetadata { // Deprecated, use [ProfileInfo][profile_type] instead. repeated ProfileType profile_type = 5 [deprecated = true]; - // HloPassMetadata.pass_id of the pass that created this HLO instruction - // object. Should never be copied between HLO instructions. Zero if unset and - // -1 if the instruction was created before HLO passes began. - int64 creation_pass_id = 6; + reserved 6; + reserved "creation_pass_id"; - // HloPassMetadata.pass_id of the pass that created the logical functionality - // that this HLO instruction represents. Should be copied between HLO - // instructions that correspond across compilation passes. Zero if unset and - // -1 if the instruction was created before HLO passes began. - int64 logical_creation_pass_id = 7; + reserved 7; + reserved "logical_creation_pass_id"; // The footprint of the generated code for the instruction. int64 size_of_generated_code_in_bytes = 8; @@ -452,6 +470,10 @@ message ExecutionProfile { // Whether this profile was drawn from a cache of profiles instead of from // execution on the hardware. bool profile_cache_hit = 7; + + // Whether a warm-up run of the computation was executed before the + // measured execution. + bool warmup_run_executed = 8; } // Handle given to a user that represents an execution that the user launched @@ -710,6 +732,36 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; } +enum SparsityType { + SPARSITY_INVALID = 0; + + // Structured N:M sparsity. + SPARSITY_STRUCTURED_N_M = 1; + + // Next: 2 +} + +// Contains sparsity metadata for a sparse dot operation. +// The only supported type atm is structured 2:4 sparsity, which is natively +// supported on NVidia GPUs. +// Restrictions: +// - only one operand of the dot operation may be sparse; +// - only the contracting dimension may be sparse. +message SparsityDescriptor { + SparsityType type = 1; + + // Sparse operand index (0 or 1). + int32 index = 2; + // Sparse dimension number. + int32 dimension = 3; + + // Structured N:M sparsity (N < M). + int32 n = 4; + int32 m = 5; + + // Next: 6 +} + enum RandomDistribution { RNG_INVALID = 0; @@ -899,9 +951,65 @@ message PrecisionConfig { // Next: 4 } + + // The algorithm used to evaluate the instruction. + // + // The naming convention for the dot instruction is + // ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE + // and ACCUM_TYPE correspond to the types in the "primitive dot operations" + // (such as TensorCore operations) and NUM_OPS is the number of such + // operations used per "primitive tile". When the NUM_OPS + // field is skipped, it is assumed to be 1. The types mentioned in the name + // are independent of the storage types. + // + // In general ATYPE and BTYPE are the precisions that the LHS and RHS of the + // operation are rounded to and ACCUMTYPE is the accumulation type. If a + // backend does not support the given algorithm, an error is raised. The + // Algorithm enum is intended to eventually replace the Precision enum. + // + enum Algorithm { + // If the algorithm is `ALG_UNSET`, we will decide the algorithm based on + // the operand_precision values (for now). + ALG_UNSET = 0; + // The storage type can be any 8-bit floating point type. + ALG_DOT_ANY_F8_ANY_F8_F32 = 1; + // The storage type can be any 8-bit floating point type. Intermediate + // results will not periodically be promoted to a higher precision. This + // corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's + // maxNumImpreciseAcc=32 setting may be similar. + ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM = 2; + ALG_DOT_F16_F16_F16 = 3; + ALG_DOT_F16_F16_F32 = 4; + ALG_DOT_BF16_BF16_BF16 = 5; + ALG_DOT_BF16_BF16_F32 = 6; + // An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better + // precision. + ALG_DOT_BF16_BF16_F32_X3 = 7; + // An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_BF16_BF16_F32_X6 = 8; + ALG_DOT_TF32_TF32_F32 = 9; + // An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_TF32_TF32_F32_X3 = 10; + ALG_DOT_F32_F32_F32 = 11; + ALG_DOT_F64_F64_F64 = 12; + + // Next: 13 + } + repeated Precision operand_precision = 1; - // Next: 2 + // Currently doesn't do anything, but we plan to support it for dot and + // possibly more instructions. + // + // TODO(b/316147294): Support this on GPU and add this to StableHLO as well. + // + // If this is set, then `operand_precision` should be set to DEFAULT and it + // will be ignored. + Algorithm algorithm = 2; + + // Next: 3 } // Describes whether all data-parallelism replicas will receive the same